"""
EE Sanity Check Script

Run this locally (not on HF Spaces) to verify the transform is correct.

The check feeds the EE model embeddings that were produced by the ORIGINAL
model's embedding layer and then permuted along the hidden dimension by a
seeded permutation ``sigma``. It compares the resulting logits and a short
greedy decode against the original model run on the un-permuted embeddings.

Usage:
    python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242

Run with no arguments to use the defaults below.
"""
import argparse

import numpy as np
import torch

DEFAULT_ORIGINAL = "Qwen/Qwen3-0.6B"
DEFAULT_EE = "broadfield-dev/Qwen3-0.6B-dp-ee"
DEFAULT_SEED = 424242
DEFAULT_PROMPT = "Hello, how are you?"


def get_sigma(hidden_size, seed):
    """Return a seeded permutation of ``range(hidden_size)`` and its inverse.

    Args:
        hidden_size: Length of the permutation.
        seed: Seed for ``numpy.random.default_rng``; the same seed always
            yields the same permutation.

    Returns:
        ``(sigma, sigma_inv)`` as integer numpy arrays such that
        ``sigma[sigma_inv] == arange(hidden_size)``.
    """
    rng = np.random.default_rng(seed)
    sigma = rng.permutation(hidden_size)
    sigma_inv = np.argsort(sigma)
    return sigma, sigma_inv


def greedy_decode(model, embed_fn, input_ids, max_new_tokens, eos_token_id=None):
    """Greedy-decode by re-running the whole sequence through ``model`` each step.

    Unlike ``model.generate(inputs_embeds=...)``, every token, including the
    newly generated ones, is embedded by ``embed_fn``. ``generate`` only uses
    ``inputs_embeds`` for the prompt and embeds later tokens with the model's
    own embedding layer. For the EE model that layer is expected to be
    un-permuted, so ``generate`` would diverge after the first token even
    when the transform is correct.

    Args:
        model: Callable accepting ``inputs_embeds=`` and returning an object
            with a ``.logits`` tensor shaped (batch, seq, vocab).
        embed_fn: Maps a (batch, seq) LongTensor of token ids to the
            embeddings ``model`` should receive.
        input_ids: (batch, seq) prompt token ids.
        max_new_tokens: Maximum number of tokens to generate.
        eos_token_id: If given, stop early once every sequence emitted it.

    Returns:
        (batch, n_new) LongTensor containing only the generated tokens.
    """
    ids = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(inputs_embeds=embed_fn(ids)).logits
            next_ids = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_ids], dim=-1)
            if eos_token_id is not None and bool((next_ids == eos_token_id).all()):
                break
    # Strip the prompt so original and EE outputs are directly comparable.
    return ids[:, input_ids.shape[1]:]


def run_check(original_name, ee_name, seed, prompt=DEFAULT_PROMPT):
    """Compare an EE model against its original and print a diagnostic report.

    Args:
        original_name: Hub id or path of the original causal LM.
        ee_name: Hub id or path of the transformed (EE) model.
        seed: Seed used by the transform to build ``sigma`` (see get_sigma).
        prompt: Text prompt used for the forward and decode checks.

    Returns:
        True when the embedding layers are identical AND the single-forward
        logits match within atol=0.1; False otherwise.
    """
    # Imported lazily so get_sigma / greedy_decode can be used (and tested)
    # without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"\n{'='*60}")
    print(f"Original : {original_name}")
    print(f"EE model : {ee_name}")
    print(f"Seed     : {seed}")
    print(f"Prompt   : {prompt}")
    print('='*60)

    tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids

    print("\n[1] Loading original model...")
    orig = AutoModelForCausalLM.from_pretrained(
        original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
    )
    orig.eval()

    print("[2] Loading EE model...")
    ee = AutoModelForCausalLM.from_pretrained(
        ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
    )
    ee.eval()

    hidden_size = orig.config.hidden_size
    sigma, _ = get_sigma(hidden_size, seed)
    # Index with a torch LongTensor rather than a numpy array.
    sigma_t = torch.as_tensor(sigma, dtype=torch.long)

    orig_embed_layer = orig.get_input_embeddings()

    def plain_embed(ids):
        return orig_embed_layer(ids)

    def encrypted_embed(ids):
        # Client-side "encryption": permute the hidden dimension by sigma.
        return orig_embed_layer(ids)[..., sigma_t]

    # --- Check 1: Does the EE embed layer match original? ---
    orig_embed = orig_embed_layer.weight.data
    ee_embed = ee.get_input_embeddings().weight.data
    shapes_match = orig_embed.shape == ee_embed.shape
    embed_match = shapes_match and torch.allclose(orig_embed, ee_embed, atol=1e-3)
    print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
    if not shapes_match:
        print(f"   ⚠️ Shape mismatch: {tuple(orig_embed.shape)} vs {tuple(ee_embed.shape)}")
    elif not embed_match:
        diff = (orig_embed - ee_embed).abs().max().item()
        print(f"   ⚠️ Max diff: {diff:.6f} — EE embed was permuted, this BREAKS client-side encryption")
        print(f"   → Re-run transform with the embed layer skipped (see transform_fix.py)")

    # --- Check 2: Run plain forward on original ---
    print("\n[CHECK 2] Running plain forward on original...")
    with torch.no_grad():
        plain_embeds = plain_embed(input_ids)
        orig_logits = orig(inputs_embeds=plain_embeds, output_hidden_states=False).logits  # (1, seq, vocab)

    # --- Check 3: Run encrypted forward on EE model ---
    print("[CHECK 3] Running encrypted forward on EE model...")
    with torch.no_grad():
        encrypted_embeds = plain_embeds[..., sigma_t]
        ee_logits = ee(inputs_embeds=encrypted_embeds, output_hidden_states=False).logits

    # --- Check 4: Do logits match? ---
    if orig_logits.shape == ee_logits.shape:
        logit_match = torch.allclose(orig_logits, ee_logits, atol=1e-1)
        max_diff = (orig_logits - ee_logits).abs().max().item()
    else:
        logit_match = False
        max_diff = float("inf")
    print(f"\n[CHECK 4] Logits match (atol=0.1): {logit_match}")
    print(f"   Max logit diff: {max_diff:.4f}")
    if not logit_match:
        print("   ⚠️ Logits differ — equivariance is BROKEN")
        # Report the real head/hidden sizes instead of hard-coded numbers
        # (Qwen3 configs carry an explicit head_dim).
        num_heads = orig.config.num_attention_heads
        head_dim = getattr(orig.config, "head_dim", None) or hidden_size // num_heads
        print("\n   Diagnosing: checking if RoPE is the culprit...")
        print(f"   RoPE applies rotation in head_dim space ({head_dim}), not hidden space ({hidden_size})")
        print("   If q_proj/k_proj output is permuted (because output==hidden_size),")
        print("   the head_dim slices fed to RoPE will be scrambled → broken attention")

    # --- Check 5: Greedy decode comparison ---
    # Both models decode with the same manual loop; the EE model gets every
    # token (prompt and generated) through the permuted original embedding.
    print("\n[CHECK 5] Greedy decode comparison (10 tokens)...")
    eos_id = tokenizer.eos_token_id
    orig_new = greedy_decode(orig, plain_embed, input_ids, 10, eos_token_id=eos_id)
    ee_new = greedy_decode(ee, encrypted_embed, input_ids, 10, eos_token_id=eos_id)

    orig_text = tokenizer.decode(orig_new[0], skip_special_tokens=True)
    ee_text = tokenizer.decode(ee_new[0], skip_special_tokens=True)
    print(f"   Original output : {repr(orig_text)}")
    print(f"   EE model output : {repr(ee_text)}")
    print(f"   Match: {orig_text == ee_text}")

    if orig_text != ee_text:
        print("\n   ⚠️ OUTPUTS DIFFER. Most likely causes in order:")
        print("   1. Embed layer was permuted in EE model (Check 1 above)")
        print("   2. RoPE disruption — q_proj/k_proj output rows were permuted")
        print("      FIX: do NOT permute output rows of q_proj and k_proj")
        print("      because their outputs are split into heads for RoPE rotation")
        print("   3. Model on Hub is stale — re-run transform and re-push")

    print(f"\n{'='*60}\n")
    return embed_match and logit_match


def main(argv=None):
    """Parse CLI flags (all optional, defaulting to the Qwen3-0.6B setup) and run the check."""
    parser = argparse.ArgumentParser(
        description="Sanity-check an EE (hidden-permuted) model against its original."
    )
    parser.add_argument("--original", default=DEFAULT_ORIGINAL)
    parser.add_argument("--ee", default=DEFAULT_EE)
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED)
    parser.add_argument("--prompt", default=DEFAULT_PROMPT)
    args = parser.parse_args(argv)
    return run_check(args.original, args.ee, args.seed, args.prompt)


if __name__ == "__main__":
    main()