File size: 5,711 Bytes
02f6c65
41550e9
02f6c65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41550e9
 
 
 
02f6c65
 
 
41550e9
02f6c65
41550e9
 
02f6c65
 
41550e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02f6c65
41550e9
02f6c65
41550e9
 
02f6c65
 
41550e9
 
 
 
 
02f6c65
41550e9
02f6c65
 
 
 
41550e9
 
 
 
 
 
 
 
65fc5c6
319b3f9
65fc5c6
 
 
1318c99
65fc5c6
 
319b3f9
 
 
 
 
65fc5c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
EE Sanity Check + Layer Diagnostics
Usage:
  python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242
"""
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse

def get_sigma(hidden_size, seed):
    """Derive the hidden-dimension permutation and its inverse from a seed.

    Returns a pair of numpy int arrays ``(sigma, sigma_inv)`` where
    ``sigma_inv`` undoes ``sigma`` (i.e. ``sigma[sigma_inv]`` is sorted).
    """
    generator = np.random.default_rng(seed)
    perm = generator.permutation(hidden_size)
    return perm, np.argsort(perm)

def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
    """Compare an original HF causal-LM against its permuted ("EE") copy.

    Loads both checkpoints on CPU in float32, then:
      1. checks whether the embedding tables are identical;
      2. for every parameter that differs, reconstructs it by undoing the
         column / row / both permutation and reports which one matches,
         flagging RoPE layers (q_proj/k_proj) whose rows were permuted and
         square hidden-size layers missing the full permutation;
      4. verifies logit equivariance (permuted embeds into the EE model
         should reproduce the original logits);
      5. greedy-decodes 10 tokens from both models for a visual check.

    Args:
        original_name: HF model id/path of the unmodified model.
        ee_name:       HF model id/path of the permuted model.
        seed:          seed used to derive the permutation (must match the
                       one used to build the EE checkpoint).
        prompt:        text prompt used for the logit/decoding checks.

    Side effects: prints a diagnostic report; downloads models if not cached.
    """
    print(f"\n{'='*60}")
    tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt")

    print("[1] Loading models...")
    orig = AutoModelForCausalLM.from_pretrained(original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
    ee   = AutoModelForCausalLM.from_pretrained(ee_name,       torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
    orig.eval(); ee.eval()

    hidden_size = orig.config.hidden_size
    sigma, sigma_inv = get_sigma(hidden_size, seed)
    # argsort(sigma_inv) == sigma for any permutation; compute it once here
    # instead of re-deriving it for every differing parameter below.
    inv = np.argsort(sigma_inv)
    print(f"hidden_size={hidden_size}, seed={seed}")

    # --- CHECK 1: Embed layers ---
    embed_match = torch.allclose(orig.model.embed_tokens.weight.data, ee.model.embed_tokens.weight.data, atol=1e-3)
    print(f"\n[CHECK 1] Embed layers identical: {embed_match}")

    # --- LAYER DIFF: print every layer that differs and HOW ---
    print("\n[LAYER DIFF] Comparing every named parameter...")
    ROPE_OUTPUT_LAYERS = {"q_proj", "k_proj"}
    issues = []
    for (name_o, param_o), (name_e, param_e) in zip(orig.named_parameters(), ee.named_parameters()):
        assert name_o == name_e
        if torch.allclose(param_o.data, param_e.data, atol=1e-3):
            continue  # unchanged — skip

        layer = name_o.split(".")[-2]  # "q_proj", "embed_tokens", etc.
        shape = tuple(param_o.shape)

        if param_o.dim() == 2:
            # Undo each candidate permutation of the EE weight and see which
            # reconstruction recovers the original parameter.
            changed_cols = torch.allclose(param_o.data, param_e.data[:, inv], atol=1e-3)
            changed_rows = torch.allclose(param_o.data, param_e.data[inv, :], atol=1e-3)
            changed_both = torch.allclose(param_o.data, param_e.data[inv, :][:, inv], atol=1e-3)

            if changed_both:
                what = "BOTH rows+cols"
            elif changed_cols:
                what = "cols only"
            elif changed_rows:
                what = "rows only"
            else:
                what = "UNKNOWN permutation"

            flag = ""
            if layer in ROPE_OUTPUT_LAYERS and ("BOTH" in what or "rows" in what):
                # RoPE acts on the q/k output dimension, so permuting rows
                # there breaks rotary position encoding.
                flag = "  ⚠️  BAD: RoPE layer has rows permuted!"
                issues.append(f"{name_o}: rows permuted on RoPE layer")
            elif layer not in ROPE_OUTPUT_LAYERS and shape[0] == hidden_size and shape[1] == hidden_size and "BOTH" not in what:
                flag = f"  ⚠️  BAD: square hidden layer should have BOTH permuted"
                issues.append(f"{name_o}: square layer missing full permutation")

            print(f"  {layer:20s} {str(shape):20s}{what}{flag}")

        elif param_o.dim() == 1:
            print(f"  {layer:20s} {str(shape):20s} → 1D (norm/bias)")

    # --- CHECK 4: Logits ---
    print("\n[CHECK 4] Equivariance test...")
    with torch.no_grad():
        plain_embeds = orig.model.embed_tokens(inputs.input_ids)
        encrypted_embeds = plain_embeds[..., sigma]
        orig_logits = orig(inputs_embeds=plain_embeds).logits
        ee_logits   = ee(inputs_embeds=encrypted_embeds).logits

    max_diff = (orig_logits - ee_logits).abs().max().item()
    match = max_diff < 0.5
    print(f"  Max logit diff: {max_diff:.4f}{'✅ PASS' if match else '❌ FAIL'}")

    # --- CHECK 5: Decode ---
    print("\n[CHECK 5] Greedy decode (10 tokens)...")
    with torch.no_grad():
        orig_ids = orig.generate(inputs.input_ids, max_new_tokens=10, do_sample=False)
        ee_ids   = ee.generate(inputs_embeds=encrypted_embeds,
                               attention_mask=inputs.attention_mask,
                               max_new_tokens=10, do_sample=False,
                               pad_token_id=tokenizer.eos_token_id)
    print(f"  Original : {repr(tokenizer.decode(orig_ids[0], skip_special_tokens=True))}")
    print(f"  EE model : {repr(tokenizer.decode(ee_ids[0],   skip_special_tokens=True))}")

    if issues:
        print(f"\n⚠️  {len(issues)} issue(s) found:")
        for i in issues: print(f"  - {i}")
    else:
        print("\n✅ No layer issues detected")

if __name__ == "__main__":
    # CLI matching the usage string in the module docstring. Defaults
    # reproduce the previous hard-coded run, so a bare
    # `python debug_ee.py` behaves exactly as before.
    parser = argparse.ArgumentParser(description="EE Sanity Check + Layer Diagnostics")
    parser.add_argument("--original", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--ee",       default="broadfield-dev/Qwen3-0.6B-dp-ee")
    parser.add_argument("--seed",     type=int, default=424242)
    parser.add_argument("--prompt",   default="Hello, how are you?")
    args = parser.parse_args()
    run_check(args.original, args.ee, args.seed, args.prompt)