ncylich commited on
Commit
19797bb
·
verified ·
1 Parent(s): eca1ebb

Upload inference_k96.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference_k96.py +209 -0
inference_k96.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """inference_k96.py — generate text with the K=96 grouped Gemma-4 E2B model.
3
+
4
+ Loads:
5
+ - Base Gemma-4 E2B (bf16) via gemma4_hf
6
+ - GroupedMaskedMLP at K_groups=96, K_active=48 (d=0.50), s50 cluster assignments
7
+ - Int4 QAT (group_size=32)
8
+ - LoRA r128 alpha=128 on up_proj/down_proj
9
+ - State dict from checkpoints/Sw_grouped_50_K96_lora_long.pt
10
+
11
+ Verification: prints config + per-layer K_groups/K_active to confirm 96 groups active.
12
+
13
+ Usage:
14
+ python scripts/inference_k96.py \
15
+ --checkpoint checkpoints/Sw_grouped_50_K96_lora_long.pt \
16
+ --prompt "The capital of France is"
17
+ """
18
+ import argparse, os, sys
19
+ import torch
20
+
21
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
22
+
23
+ from gemma4_hf import load_gemma4, DEVICE, N_LAYERS
24
+ from rung6_moe_g4 import wrap_int4, Int4QuantLinear, wrap_lora
25
+ from rung8_grouped_g4 import install_grouped, GroupedMaskedMLP
26
+
27
+
28
+ def build_model(checkpoint_path: str,
29
+ group_assignments_dir: str = "logs/groups",
30
+ group_tag: str = "s50"):
31
+ """Build the K=96 grouped model and load weights. Returns (model, tokenizer, cfg)."""
32
+ print(f"Loading base Gemma-4 E2B...")
33
+ model, tokenizer = load_gemma4()
34
+ for p in model.parameters():
35
+ p.requires_grad_(False)
36
+
37
+ print(f"Loading checkpoint metadata from {checkpoint_path}...")
38
+ ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
39
+ cfg = ckpt["config"]
40
+
41
+ # Sanity check: this MUST be a K=96 grouped checkpoint
42
+ K_groups = cfg.get("K_groups")
43
+ if K_groups != 96:
44
+ raise ValueError(f"Expected K_groups=96, got {K_groups} from checkpoint")
45
+ K_active = cfg.get("K_active") or max(1, round(K_groups * cfg["density"]))
46
+ density = cfg["density"]
47
+ print(f" K_groups={K_groups} K_active={K_active} density={density:.3f}")
48
+
49
+ # Install GroupedMaskedMLP at K=96 with s50 cluster assignments
50
+ print(f"Installing GroupedMaskedMLP (K={K_groups}, K_active={K_active}) on {N_LAYERS} layers...")
51
+ mlp_modules = install_grouped(model,
52
+ K_groups=K_groups, K_active=K_active,
53
+ group_assignments_dir=group_assignments_dir,
54
+ group_tag=group_tag,
55
+ freeze_base=False)
56
+ # Load partial state (proj weights)
57
+ missing, unexpected = model.load_state_dict(ckpt["student_state"], strict=False)
58
+ print(f" load: missing={len(missing)} unexpected={len(unexpected)}")
59
+
60
+ # Set tau (sigmoid relaxation temperature) to converged value
61
+ tau = cfg.get("tau", 0.01)
62
+ for m in mlp_modules:
63
+ m.tau = tau
64
+
65
+ # Apply Int4 QAT wrappers
66
+ if cfg.get("int4_qat"):
67
+ Int4QuantLinear._group_size = cfg.get("int4_group_size", 32)
68
+ n = wrap_int4(model)
69
+ print(f" int4 QAT: wrapped {n} Linear modules (group_size={Int4QuantLinear._group_size})")
70
+
71
+ # Apply LoRA
72
+ if cfg.get("use_lora") or cfg.get("gate_lora_train"):
73
+ ts = cfg.get("lora_targets", "")
74
+ targets = tuple(t.strip() for t in ts.split(",") if t.strip()) if ts else None
75
+ if targets:
76
+ n_lora, n_lora_p = wrap_lora(model,
77
+ rank=cfg.get("lora_rank", 16),
78
+ alpha=cfg.get("lora_alpha", 16.0),
79
+ target_substrings=targets)
80
+ else:
81
+ n_lora, n_lora_p = wrap_lora(model,
82
+ rank=cfg.get("lora_rank", 16),
83
+ alpha=cfg.get("lora_alpha", 16.0))
84
+ print(f" LoRA: rank={cfg.get('lora_rank')} alpha={cfg.get('lora_alpha')} "
85
+ f"({n_lora} modules, {n_lora_p/1e6:.2f}M params)")
86
+
87
+ # Re-load state to populate LoRA + int4 buffers
88
+ missing2, unexp2 = model.load_state_dict(ckpt["student_state"], strict=False)
89
+ print(f" re-load after wrappers: missing={len(missing2)} unexpected={len(unexp2)}")
90
+
91
+ # Hard guard: any LoRA/int4 buffer missing from the load means we'd silently
92
+ # serve a model with random LoRA weights or wrong int4 scales.
93
+ suspicious = [k for k in missing2
94
+ if any(s in k for s in ("lora_a", "lora_b", "lora_A", "lora_B",
95
+ "scale", "zero", "qweight"))]
96
+ if suspicious:
97
+ raise RuntimeError(
98
+ f"After wrap_int4/wrap_lora, {len(suspicious)} expected weights are still "
99
+ f"unloaded (would default to random init): {suspicious[:5]}...")
100
+
101
+ model.eval()
102
+ return model, tokenizer, cfg, mlp_modules
103
+
104
+
105
+ def verify_grouped_routing(model, expected_K=96, expected_density=0.50):
106
+ """Re-walk model.layers and confirm every MLP is GroupedMaskedMLP with K_groups==expected_K.
107
+ Reading from model.layers (not a returned list) catches any later wrapper that may have
108
+ silently replaced an MLP."""
109
+ print(f"\n=== Verifying grouped routing on {N_LAYERS} layers (walking model.layers) ===")
110
+ issues = []
111
+ expected_K_active = max(1, round(expected_K * expected_density))
112
+ for i in range(N_LAYERS):
113
+ m = model.layers[i].mlp
114
+ if not isinstance(m, GroupedMaskedMLP):
115
+ issues.append(f"Layer {i}: not GroupedMaskedMLP, got {type(m).__name__}")
116
+ continue
117
+ if m.K_groups != expected_K:
118
+ issues.append(f"Layer {i}: K_groups={m.K_groups}, expected {expected_K}")
119
+ if m.K_active != expected_K_active:
120
+ issues.append(f"Layer {i}: K_active={m.K_active}, expected {expected_K_active}")
121
+ if not hasattr(m, "group_assignments"):
122
+ issues.append(f"Layer {i}: missing group_assignments buffer")
123
+ continue
124
+ n_unique = m.group_assignments.unique().numel()
125
+ max_id = m.group_assignments.max().item()
126
+ if max_id >= expected_K:
127
+ issues.append(f"Layer {i}: max group id {max_id} >= K_groups {expected_K}")
128
+ if n_unique > expected_K:
129
+ issues.append(f"Layer {i}: {n_unique} unique groups > expected {expected_K}")
130
+ if issues:
131
+ print(" FAIL:")
132
+ for s in issues:
133
+ print(f" {s}")
134
+ raise RuntimeError("Verification failed")
135
+ m0 = model.layers[0].mlp
136
+ counts = torch.bincount(m0.group_assignments, minlength=m0.K_groups)
137
+ print(f" L0: K_groups={m0.K_groups} K_active={m0.K_active} "
138
+ f"D_FFN={m0.group_assignments.numel()} "
139
+ f"group_size_min={counts.min().item()} max={counts.max().item()} mean={counts.float().mean().item():.1f}")
140
+ print(f" ALL {N_LAYERS} layers verified — K={expected_K}, K_active={expected_K_active}")
141
+
142
+
143
+ @torch.no_grad()
144
+ def generate(model, tokenizer, prompt: str, max_new_tokens: int = 60,
145
+ temperature: float = 0.0, use_chat_template: bool = True):
146
+ """Use HF's generate() on the inner model with proper KV-cache + sampling.
147
+ For Gemma-4-IT, applies the chat template (turns the prompt into a user message).
148
+ Set use_chat_template=False to feed raw text (e.g. for completions)."""
149
+ if not hasattr(model, "inner"):
150
+ raise RuntimeError("Model lacks .inner; cannot use HF generate")
151
+ if use_chat_template:
152
+ formatted = tokenizer.apply_chat_template(
153
+ [{"role": "user", "content": prompt}],
154
+ tokenize=False, add_generation_prompt=True)
155
+ else:
156
+ formatted = prompt
157
+ inputs = tokenizer(formatted, return_tensors="pt").to(DEVICE)
158
+ in_len = inputs["input_ids"].shape[1]
159
+ do_sample = temperature > 0.0
160
+ gen_kwargs = dict(
161
+ max_new_tokens=max_new_tokens,
162
+ do_sample=do_sample,
163
+ pad_token_id=tokenizer.eos_token_id,
164
+ )
165
+ if do_sample:
166
+ gen_kwargs["temperature"] = temperature
167
+ gen_kwargs["top_p"] = 0.9
168
+ out_ids = model.inner.generate(**inputs, **gen_kwargs)
169
+ full = tokenizer.decode(out_ids[0], skip_special_tokens=False)
170
+ response = tokenizer.decode(out_ids[0][in_len:], skip_special_tokens=True)
171
+ return full, response
172
+
173
+
174
+ def main():
175
+ parser = argparse.ArgumentParser()
176
+ parser.add_argument("--checkpoint", default="checkpoints/Sw_grouped_50_K96_lora_long.pt")
177
+ parser.add_argument("--group_assignments_dir", default="logs/groups")
178
+ parser.add_argument("--group_tag", default="s50")
179
+ parser.add_argument("--prompt", default="What is the capital of France? Answer in one short sentence.")
180
+ parser.add_argument("--max_new_tokens", type=int, default=60)
181
+ parser.add_argument("--temperature", type=float, default=0.0)
182
+ parser.add_argument("--no_chat_template", action="store_true",
183
+ help="Feed raw prompt without chat template (for completions)")
184
+ args = parser.parse_args()
185
+
186
+ model, tokenizer, cfg, mlp_modules = build_model(
187
+ checkpoint_path=args.checkpoint,
188
+ group_assignments_dir=args.group_assignments_dir,
189
+ group_tag=args.group_tag)
190
+
191
+ verify_grouped_routing(model, expected_K=96, expected_density=cfg["density"])
192
+
193
+ print(f"\n=== Generation ===")
194
+ print(f"Prompt: {args.prompt!r}")
195
+ print(f"Chat template: {not args.no_chat_template}")
196
+ print(f"Generating up to {args.max_new_tokens} tokens (temp={args.temperature})...")
197
+ full, response = generate(model, tokenizer, args.prompt,
198
+ max_new_tokens=args.max_new_tokens,
199
+ temperature=args.temperature,
200
+ use_chat_template=not args.no_chat_template)
201
+ print(f"\n--- Response ---")
202
+ print(response)
203
+ print(f"--- Full (with special tokens) ---")
204
+ print(full)
205
+ print(f"--- End ---")
206
+
207
+
208
+ if __name__ == "__main__":
209
+ main()