"""
EE Sanity Check Script
Run this locally (not on HF Spaces) to verify the transform is correct.

Usage:
  python debug_ee.py --original Qwen/Qwen3-0.6B --ee your/model-dp-ee --seed 424242
"""
import argparse

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_sigma(hidden_size, seed):
    rng = np.random.default_rng(seed)
    sigma = rng.permutation(hidden_size)
    sigma_inv = np.argsort(sigma)
    return sigma, sigma_inv
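
def _permutation_equivariance_demo(hidden_size=8, seed=0):
    """Illustrative sketch only (not called by run_check), assuming the EE transform
    conjugates square hidden-to-hidden weights by a permutation matrix P:
    with W' = P @ W @ P.T and x' = P @ x, we must get W' @ x' == P @ (W @ x)."""
    sigma, _ = get_sigma(hidden_size, seed)
    rng = np.random.default_rng(seed + 1)
    W = rng.normal(size=(hidden_size, hidden_size))
    x = rng.normal(size=hidden_size)
    W_perm = W[sigma][:, sigma]  # permute output rows and input columns
    assert np.allclose(W_perm @ x[sigma], (W @ x)[sigma])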

def run_check(original_name, ee_name, seed, prompt="Hello, how are you?"):
    print(f"\n{'='*60}")
    print(f"Original : {original_name}")
    print(f"EE model : {ee_name}")
    print(f"Seed     : {seed}")
    print(f"Prompt   : {prompt}")
    print('='*60)

    tokenizer = AutoTokenizer.from_pretrained(original_name, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids

    print("\n[1] Loading original model...")
    orig = AutoModelForCausalLM.from_pretrained(
        original_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
    )
    orig.eval()

    print("[2] Loading EE model...")
    ee = AutoModelForCausalLM.from_pretrained(
        ee_name, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
    )
    ee.eval()

    hidden_size = orig.config.hidden_size
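    # sigma permutes the hidden dimension (used below to "encrypt" the input embeddings);
    # sigma_inv is its inverse, computed for reference but not needed by these checks.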
    sigma, sigma_inv = get_sigma(hidden_size, seed)

    # --- Check 1: Does the EE embed layer match original? ---
    orig_embed = orig.model.embed_tokens.weight.data
    ee_embed   = ee.model.embed_tokens.weight.data
    embed_match = torch.allclose(orig_embed, ee_embed, atol=1e-3)
    print(f"\n[CHECK 1] Embed layers identical: {embed_match}")
    if not embed_match:
        diff = (orig_embed - ee_embed).abs().max().item()
        print(f"  ⚠️  Max diff: {diff:.6f} — EE embed was permuted, this BREAKS client-side encryption")
        print(f"  → Re-run transform with the embed layer skipped (see transform_fix.py)")

    # --- Check 2: Run plain forward on original ---
    print("\n[CHECK 2] Running plain forward on original...")
    with torch.no_grad():
        plain_embeds = orig.model.embed_tokens(input_ids)
        orig_out = orig(inputs_embeds=plain_embeds, output_hidden_states=False)
    orig_logits = orig_out.logits  # (1, seq, vocab)

    # --- Check 3: Run encrypted forward on EE model ---
    print("[CHECK 3] Running encrypted forward on EE model...")
    with torch.no_grad():
        encrypted_embeds = plain_embeds[..., sigma]
        ee_out = ee(inputs_embeds=encrypted_embeds, output_hidden_states=False)
    ee_logits = ee_out.logits
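    # Logits (not hidden states) are compared: the EE model's internal activations are
    # expected to be permuted versions of the original's, and only the final logits
    # should line up exactly.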

    # --- Check 4: Do logits match? ---
    logit_match = torch.allclose(orig_logits, ee_logits, atol=1e-1)
    max_diff = (orig_logits - ee_logits).abs().max().item()
    print(f"\n[CHECK 4] Logits match (atol=0.1): {logit_match}")
    print(f"  Max logit diff: {max_diff:.4f}")
    if not logit_match:
        print("  ⚠️  Logits differ — equivariance is BROKEN")
        # Find where it breaks: RoPE is the usual suspect for permutation transforms.
        print("\n  Diagnosing: checking if RoPE is the culprit...")
        print(f"  RoPE rotates features within each head_dim slice, not in the full hidden space ({hidden_size}).")
        print("  If the q_proj/k_proj output rows were permuted, the head_dim slices fed to RoPE")
        print("  are scrambled → broken attention.")

    # --- Check 5: Greedy decode comparison ---
    print("\n[CHECK 5] Greedy decode comparison (10 tokens)...")
    with torch.no_grad():
        orig_ids = orig.generate(input_ids, max_new_tokens=10, do_sample=False)
        ee_ids   = ee.generate(inputs_embeds=encrypted_embeds,
                               attention_mask=inputs.attention_mask,
                               max_new_tokens=10, do_sample=False,
                               pad_token_id=tokenizer.eos_token_id)

    # generate() called with only inputs_embeds returns just the newly generated tokens
    # (no prompt), so strip the prompt from the original's ids before comparing.
    orig_text = tokenizer.decode(orig_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
    ee_text   = tokenizer.decode(ee_ids[0], skip_special_tokens=True)
    print(f"  Original output : {repr(orig_text)}")
    print(f"  EE model output : {repr(ee_text)}")
    print(f"  Match: {orig_text == ee_text}")

    if orig_text != ee_text:
        print("\n  ⚠️  OUTPUTS DIFFER. Most likely causes in order:")
        print("  1. Embed layer was permuted in EE model (Check 1 above)")
        print("  2. RoPE disruption — q_proj/k_proj output rows were permuted")
        print("     FIX: do NOT permute output rows of q_proj and k_proj")
        print("     because their outputs are split into heads for RoPE rotation")
        print("  3. Model on Hub is stale — re-run transform and re-push")

    print(f"\n{'='*60}\n")
    return embed_match and logit_match

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--original", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--ee",       default="broadfield-dev/Qwen3-0.6B-dp-ee")
    parser.add_argument("--seed",     type=int, default=424242)
    parser.add_argument("--prompt",   default="Hello, how are you?")
    args = parser.parse_args()
    run_check(args.original, args.ee, args.seed, args.prompt)