| """ |
| EE Sanity Check + Layer Diagnostics |
| Usage: |
| python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242 |
| """ |
| import torch |
| import numpy as np |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import argparse |
|
|
def get_sigma(hidden_size, seed):
    """Derive the hidden-dimension permutation sigma (and its inverse) from the seed."""
    rng = np.random.default_rng(seed)
    sigma = rng.permutation(hidden_size)
    sigma_inv = np.argsort(sigma)  # inverse permutation: sigma[sigma_inv] == arange(hidden_size)
    return sigma, sigma_inv
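
# Round-trip sanity (hypothetical values; the actual sigma depends on the seed):
#   sigma, sigma_inv = get_sigma(4, 0)  # e.g. sigma = [2, 0, 3, 1]
#   x[sigma][sigma_inv] == x            # un-permuting restores the original order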
|
|
def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
    print(f"\n{'='*60}")
    print(f"EE sanity check: {original_name} vs {ee_name}")
    tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt")

    print("[1] Loading models...")
    orig = AutoModelForCausalLM.from_pretrained(
        original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
    )
    ee = AutoModelForCausalLM.from_pretrained(
        ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
    )
    orig.eval(); ee.eval()

    hidden_size = orig.config.hidden_size
    sigma, sigma_inv = get_sigma(hidden_size, seed)
    print(f"hidden_size={hidden_size}, seed={seed}")
|
|
    embed_match = torch.allclose(
        orig.model.embed_tokens.weight.data, ee.model.embed_tokens.weight.data, atol=1e-3
    )
    print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
|
|
| print("\n[LAYER DIFF] Comparing every named parameter...") |
| ROPE_OUTPUT_LAYERS = {"q_proj", "k_proj"} |
| issues = [] |
| for (name_o, param_o), (name_e, param_e) in zip(orig.named_parameters(), ee.named_parameters()): |
| assert name_o == name_e |
| if torch.allclose(param_o.data, param_e.data, atol=1e-3): |
| continue |
|
|
        layer = name_o.split(".")[-2]  # e.g. "q_proj" from "model.layers.0.self_attn.q_proj.weight"
        shape = tuple(param_o.shape)
|
|
        if param_o.dim() == 2:
            # np.argsort(sigma_inv) == sigma for any permutation, so index with sigma
            # directly to undo the hypothesised permutation along each axis.
            reconstructed_cols = param_e.data[:, sigma]
            changed_cols = torch.allclose(param_o.data, reconstructed_cols, atol=1e-3)

            reconstructed_rows = param_e.data[sigma, :]
            changed_rows = torch.allclose(param_o.data, reconstructed_rows, atol=1e-3)

            reconstructed_both = param_e.data[sigma, :][:, sigma]
            changed_both = torch.allclose(param_o.data, reconstructed_both, atol=1e-3)
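            # For y = W @ x with the input x permuted by sigma, only W's columns must be
            # permuted to compensate; a layer whose output is also permuted needs its rows
            # moved as well. Square hidden-to-hidden maps should therefore show BOTH.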
|
|
            if changed_both:
                what = "BOTH rows+cols"
            elif changed_cols:
                what = "cols only"
            elif changed_rows:
                what = "rows only"
            else:
                what = "UNKNOWN permutation"

            flag = ""
            if layer in ROPE_OUTPUT_LAYERS and (changed_both or changed_rows):
                flag = " ⚠️ BAD: RoPE layer has rows permuted!"
                issues.append(f"{name_o}: rows permuted on RoPE layer")
            elif layer not in ROPE_OUTPUT_LAYERS and shape == (hidden_size, hidden_size) and not changed_both:
                flag = " ⚠️ BAD: square hidden layer should have BOTH permuted"
                issues.append(f"{name_o}: square layer missing full permutation")

            print(f" {layer:20s} {str(shape):20s} → {what}{flag}")
|
|
        elif param_o.dim() == 1:
            print(f" {layer:20s} {str(shape):20s} → 1D (norm/bias)")
|
|
| print("\n[CHECK 4] Equivariance test...") |
| with torch.no_grad(): |
| plain_embeds = orig.model.embed_tokens(inputs.input_ids) |
| encrypted_embeds = plain_embeds[..., sigma] |
| orig_logits = orig(inputs_embeds=plain_embeds).logits |
| ee_logits = ee(inputs_embeds=encrypted_embeds).logits |
|
|
| max_diff = (orig_logits - ee_logits).abs().max().item() |
| match = max_diff < 0.5 |
| print(f" Max logit diff: {max_diff:.4f} → {'✅ PASS' if match else '❌ FAIL'}") |
|
|
| print("\n[CHECK 5] Greedy decode (10 tokens)...") |
| with torch.no_grad(): |
| orig_ids = orig.generate(inputs.input_ids, max_new_tokens=10, do_sample=False) |
| ee_ids = ee.generate(inputs_embeds=encrypted_embeds, |
| attention_mask=inputs.attention_mask, |
| max_new_tokens=10, do_sample=False, |
| pad_token_id=tokenizer.eos_token_id) |
| print(f" Original : {repr(tokenizer.decode(orig_ids[0], skip_special_tokens=True))}") |
| print(f" EE model : {repr(tokenizer.decode(ee_ids[0], skip_special_tokens=True))}") |
|
|
    if issues:
        print(f"\n⚠️ {len(issues)} issue(s) found:")
        for issue in issues:
            print(f" - {issue}")
    else:
        print("\n✅ No layer issues detected")
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--original", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--ee", default="broadfield-dev/Qwen3-0.6B-dp-ee")
    parser.add_argument("--seed", type=int, default=424242)
    parser.add_argument("--prompt", default="Hello, how are you?")
    args = parser.parse_args()
    run_check(args.original, args.ee, args.seed, args.prompt)