broadfield-dev's picture
Update debug_ee.py
41550e9 verified
raw
history blame
5.71 kB
"""
EE Sanity Check + Layer Diagnostics
Usage:
python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242
"""
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
def get_sigma(hidden_size, seed):
    """Regenerate the seeded hidden-dimension permutation and its inverse.

    Args:
        hidden_size: number of hidden dimensions to permute.
        seed: seed for ``np.random.default_rng``; the same seed always yields
            the same permutation.

    Returns:
        ``(sigma, sigma_inv)`` where ``sigma`` is a permutation of
        ``range(hidden_size)`` and ``sigma_inv`` is its inverse, i.e.
        ``sigma_inv[sigma[i]] == i`` for every ``i``.
    """
    generator = np.random.default_rng(seed)
    forward = generator.permutation(hidden_size)
    # Scatter positions into their targets: inverse[forward[i]] = i.
    # For a permutation this is exactly what np.argsort(forward) returns.
    inverse = np.empty_like(forward)
    inverse[forward] = np.arange(hidden_size)
    return forward, inverse
def _classify_2d(w_orig, w_ee, undo, hidden_size, atol):
    """Describe how a 2D EE weight maps back onto the original via ``undo``.

    Only axes whose length equals ``hidden_size`` are candidates for the
    permutation. Indexing any other axis with a hidden_size-long index would
    produce a tensor of the wrong shape, and ``torch.allclose`` would raise.
    Non-square projections such as q_proj or down_proj hit exactly this case.

    Args:
        w_orig: weight from the original model.
        w_ee: same-shaped weight from the EE model.
        undo: long index tensor of length ``hidden_size`` that should map the
            EE weight back onto the original.
        hidden_size: model hidden size.
        atol: absolute tolerance for ``torch.allclose``.

    Returns:
        "BOTH rows+cols", "cols only" or "rows only", checked in that priority
        order. Returns None when no combination reproduces ``w_orig``.
    """
    rows_hidden = w_orig.shape[0] == hidden_size
    cols_hidden = w_orig.shape[1] == hidden_size
    if rows_hidden and cols_hidden and torch.allclose(w_orig, w_ee[undo][:, undo], atol=atol):
        return "BOTH rows+cols"
    if cols_hidden and torch.allclose(w_orig, w_ee[:, undo], atol=atol):
        return "cols only"
    if rows_hidden and torch.allclose(w_orig, w_ee[undo, :], atol=atol):
        return "rows only"
    return None


def _diff_layers(orig, ee, sigma, sigma_inv, hidden_size, atol):
    """Print how every differing parameter of ``ee`` relates to ``orig``.

    Returns:
        A list of human-readable issue strings for suspicious parameters.

    Raises:
        ValueError: if the two models' named parameters do not line up
            one-to-one (different count or different names).
    """
    print("\n[LAYER DIFF] Comparing every named parameter...")
    rope_output_layers = {"q_proj", "k_proj"}
    # CHECK 4 feeds the EE model x_enc = x[..., sigma]. A weight consistent with
    # that input is W[:, sigma] (input side) or W[sigma, :] (output side), and
    # indexing it with sigma_inv recovers W.
    undo = torch.as_tensor(sigma_inv, dtype=torch.long)
    # Indexing with sigma instead detects weights permuted in the opposite
    # direction, which would be inconsistent with the CHECK 4 input.
    redo = torch.as_tensor(sigma, dtype=torch.long)

    params_o = list(orig.named_parameters())
    params_e = list(ee.named_parameters())
    if len(params_o) != len(params_e):
        raise ValueError(
            f"parameter count mismatch: original has {len(params_o)}, EE has {len(params_e)}"
        )

    issues = []
    for (name_o, param_o), (name_e, param_e) in zip(params_o, params_e):
        if name_o != name_e:
            raise ValueError(f"parameter name mismatch: {name_o!r} vs {name_e!r}")
        w_o, w_e = param_o.data, param_e.data
        parts = name_o.split(".")
        layer = parts[-2] if len(parts) > 1 else parts[-1]  # "q_proj", "embed_tokens", ...
        shape = tuple(w_o.shape)

        if w_o.shape != w_e.shape:
            print(f"  {layer:20s} {str(shape):20s} → shape mismatch: EE has {tuple(w_e.shape)}")
            issues.append(f"{name_o}: shape {shape} vs {tuple(w_e.shape)}")
            continue
        if torch.allclose(w_o, w_e, atol=atol):
            continue  # unchanged — skip

        if w_o.dim() == 2:
            what = _classify_2d(w_o, w_e, undo, hidden_size, atol)
            if what is None:
                inverse_what = _classify_2d(w_o, w_e, redo, hidden_size, atol)
                if inverse_what is not None:
                    what = f"{inverse_what} [inverse σ]"
                    issues.append(
                        f"{name_o}: permuted in the opposite direction to the CHECK 4 input"
                    )
                else:
                    what = "UNKNOWN permutation"
            flag = ""
            if layer in rope_output_layers and "rows" in what:
                flag = "  ⚠️ BAD: RoPE layer has rows permuted!"
                issues.append(f"{name_o}: rows permuted on RoPE layer")
            elif (layer not in rope_output_layers and shape[0] == hidden_size
                  and shape[1] == hidden_size and "BOTH" not in what):
                flag = "  ⚠️ BAD: square hidden layer should have BOTH permuted"
                issues.append(f"{name_o}: square layer missing full permutation")
            print(f"  {layer:20s} {str(shape):20s} → {what}{flag}")
        elif w_o.dim() == 1:
            if w_o.shape[0] == hidden_size and torch.allclose(w_o, w_e[undo], atol=atol):
                note = "permuted by σ"
            else:
                note = "changed, not a σ permutation"
            print(f"  {layer:20s} {str(shape):20s} → 1D (norm/bias): {note}")
        else:
            print(f"  {layer:20s} {str(shape):20s} → changed ({w_o.dim()}D, not analysed)")
    return issues


def run_check(original_name, ee_name, seed, prompt="Hello, how are you?", *,
              atol=1e-3, logit_tol=0.5, max_new_tokens=10):
    """Compare an original causal LM with its EE (hidden-dim permuted) copy.

    Loads both models on CPU in float32 and regenerates the permutation with
    ``get_sigma(hidden_size, seed)``. It then:
      * CHECK 1: reports whether the embedding tables are identical.
      * LAYER DIFF: classifies how each changed parameter relates to the
        original under the permutation.
      * CHECK 4: compares logits for the prompt fed as plain embeddings to the
        original model and as sigma-permuted embeddings to the EE model.
      * CHECK 5: greedy-decodes from both models.
    Finally it prints a summary of detected issues. Nothing is returned.

    Args:
        original_name: model id/path of the original model (also used for the
            tokenizer).
        ee_name: model id/path of the EE model.
        seed: seed used to regenerate the permutation.
        prompt: text used for CHECK 4 and CHECK 5.
        atol: absolute tolerance for weight comparisons (previously fixed at 1e-3).
        logit_tol: maximum absolute logit difference that counts as PASS.
        max_new_tokens: number of tokens to generate in CHECK 5.

    Raises:
        ValueError: if the two models' named parameters do not line up.
    """
    print(f"\n{'='*60}")
    tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    print("[1] Loading models...")
    orig = AutoModelForCausalLM.from_pretrained(original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
    ee = AutoModelForCausalLM.from_pretrained(ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
    orig.eval()
    ee.eval()
    hidden_size = orig.config.hidden_size
    sigma, sigma_inv = get_sigma(hidden_size, seed)
    print(f"hidden_size={hidden_size}, seed={seed}")

    # --- CHECK 1: Embed layers ---
    embed_match = torch.allclose(orig.model.embed_tokens.weight.data, ee.model.embed_tokens.weight.data, atol=atol)
    print(f"\n[CHECK 1] Embed layers identical: {embed_match}")

    # --- LAYER DIFF: print every layer that differs and HOW ---
    issues = _diff_layers(orig, ee, sigma, sigma_inv, hidden_size, atol)

    # --- CHECK 4: Logits ---
    print("\n[CHECK 4] Equivariance test...")
    perm = torch.as_tensor(sigma, dtype=torch.long)
    with torch.no_grad():
        plain_embeds = orig.model.embed_tokens(inputs.input_ids)
        encrypted_embeds = plain_embeds[..., perm]
        orig_logits = orig(inputs_embeds=plain_embeds, attention_mask=inputs.attention_mask).logits
        ee_logits = ee(inputs_embeds=encrypted_embeds, attention_mask=inputs.attention_mask).logits
    max_diff = (orig_logits - ee_logits).abs().max().item()
    match = max_diff < logit_tol
    print(f"  Max logit diff: {max_diff:.4f}  {'✅ PASS' if match else '❌ FAIL'}")

    # --- CHECK 5: Decode ---
    print(f"\n[CHECK 5] Greedy decode ({max_new_tokens} tokens)...")
    # NOTE(review): after the first step, generate() embeds newly produced
    # tokens with ee.model.embed_tokens. CHECK 1 expects that table to equal the
    # original (un-permuted) one, so later EE tokens may not see permuted inputs.
    # Divergence after token 1 may reflect this rather than bad weights.
    with torch.no_grad():
        orig_ids = orig.generate(inputs.input_ids,
                                 attention_mask=inputs.attention_mask,
                                 max_new_tokens=max_new_tokens, do_sample=False,
                                 pad_token_id=tokenizer.eos_token_id)
        ee_ids = ee.generate(inputs_embeds=encrypted_embeds,
                             attention_mask=inputs.attention_mask,
                             max_new_tokens=max_new_tokens, do_sample=False,
                             pad_token_id=tokenizer.eos_token_id)
    # Given input_ids, generate() echoes the prompt. Given only inputs_embeds it
    # returns just the new tokens. Strip the prompt so both lines are comparable.
    prompt_len = inputs.input_ids.shape[1]
    print(f"  Original : {repr(tokenizer.decode(orig_ids[0, prompt_len:], skip_special_tokens=True))}")
    print(f"  EE model : {repr(tokenizer.decode(ee_ids[0], skip_special_tokens=True))}")

    if issues:
        print(f"\n⚠️ {len(issues)} issue(s) found:")
        for issue in issues:
            print(f"  - {issue}")
    else:
        print("\n✅ No layer issues detected")
if __name__ == "__main__":
    # CLI as documented in the module docstring. The defaults reproduce the
    # previously hard-coded configuration, so `python debug_ee.py` with no
    # arguments still checks the same models with the same seed and prompt.
    parser = argparse.ArgumentParser(description="EE sanity check + layer diagnostics")
    parser.add_argument("--original", default="Qwen/Qwen3-0.6B",
                        help="model id/path of the original model (also supplies the tokenizer)")
    parser.add_argument("--ee", default="broadfield-dev/Qwen3-0.6B-dp-ee",
                        help="model id/path of the EE (permuted) model")
    parser.add_argument("--seed", type=int, default=424242,
                        help="seed used to regenerate the hidden-dimension permutation")
    parser.add_argument("--prompt", default="Hello, how are you?",
                        help="prompt used for the logit and decode checks")
    args = parser.parse_args()
    run_check(args.original, args.ee, args.seed, args.prompt)