"""inference_k96.py — generate text with the K=96 grouped Gemma-4 E2B model.

Loads:
  - Base Gemma-4 E2B (bf16) via gemma4_hf
  - GroupedMaskedMLP at K_groups=96, K_active=48 (d=0.50), s50 cluster assignments
  - Int4 QAT (group_size=32)
  - LoRA r128 alpha=128 on up_proj/down_proj
  - State dict from checkpoints/Sw_grouped_50_K96_lora_long.pt

Verification: prints the config and checks per-layer K_groups/K_active to
confirm 96 groups are active.

Usage:
    python scripts/inference_k96.py \
        --checkpoint checkpoints/Sw_grouped_50_K96_lora_long.pt \
        --prompt "The capital of France is"
"""
import argparse
import os
import sys

import torch

# Put the parent of this script's directory at the front of sys.path BEFORE
# the project imports below (presumably that is where gemma4_hf / rung6_moe_g4 /
# rung8_grouped_g4 live — confirm against the repository layout).
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_SCRIPT_DIR))

from gemma4_hf import load_gemma4, DEVICE, N_LAYERS  # noqa: E402
from rung6_moe_g4 import wrap_int4, Int4QuantLinear, wrap_lora  # noqa: E402
from rung8_grouped_g4 import install_grouped, GroupedMaskedMLP  # noqa: E402
def build_model(checkpoint_path: str,
                group_assignments_dir: str = "logs/groups",
                group_tag: str = "s50"):
    """Build the K=96 grouped model and load its weights from a checkpoint.

    The steps run in a fixed order:
      1. load the base model and freeze every parameter;
      2. read the checkpoint and its ``config`` dict;
      3. install GroupedMaskedMLP on the layers and load the state dict;
      4. optionally wrap modules for int4 QAT and LoRA (driven by the config);
      5. load the state dict a second time so parameters created by the
         wrappers also receive their checkpoint values.

    Args:
        checkpoint_path: Path to a checkpoint containing ``"config"`` and
            ``"student_state"`` entries.
        group_assignments_dir: Directory forwarded to ``install_grouped``.
        group_tag: Assignment tag forwarded to ``install_grouped``.

    Returns:
        A 4-tuple ``(model, tokenizer, cfg, mlp_modules)`` where
        ``mlp_modules`` is whatever ``install_grouped`` returned.

    Raises:
        ValueError: If the checkpoint config does not have ``K_groups == 96``.
        KeyError: If the checkpoint lacks ``"config"``, ``"student_state"`` or
            the config lacks ``"density"``.
        RuntimeError: If, after the second load, any still-missing key looks
            like a LoRA or quantization parameter.
    """
    print("Loading base Gemma-4 E2B...")
    model, tokenizer = load_gemma4()
    # This script only runs inference, so no parameter needs gradients.
    for p in model.parameters():
        p.requires_grad_(False)

    print(f"Loading checkpoint metadata from {checkpoint_path}...")
    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    cfg = ckpt["config"]

    K_groups = cfg.get("K_groups")
    if K_groups != 96:
        raise ValueError(f"Expected K_groups=96, got {K_groups} from checkpoint")
    density = cfg["density"]
    # Prefer an explicit K_active; otherwise derive it from the density.
    K_active = cfg.get("K_active") or max(1, round(K_groups * density))
    print(f" K_groups={K_groups} K_active={K_active} density={density:.3f}")

    print(f"Installing GroupedMaskedMLP (K={K_groups}, K_active={K_active}) on {N_LAYERS} layers...")
    mlp_modules = install_grouped(model,
                                  K_groups=K_groups, K_active=K_active,
                                  group_assignments_dir=group_assignments_dir,
                                  group_tag=group_tag,
                                  freeze_base=False)

    # First pass, before any int4/LoRA wrapper exists. strict=False tolerates
    # keys that do not match the current module tree.
    missing, unexpected = model.load_state_dict(ckpt["student_state"], strict=False)
    print(f" load: missing={len(missing)} unexpected={len(unexpected)}")

    # Copy tau from the checkpoint config (default 0.01) onto every grouped MLP.
    tau = cfg.get("tau", 0.01)
    for m in mlp_modules:
        m.tau = tau

    if cfg.get("int4_qat"):
        # The group size is a class-level attribute, set before wrapping.
        Int4QuantLinear._group_size = cfg.get("int4_group_size", 32)
        n = wrap_int4(model)
        print(f" int4 QAT: wrapped {n} Linear modules (group_size={Int4QuantLinear._group_size})")

    if cfg.get("use_lora") or cfg.get("gate_lora_train"):
        # Resolve rank/alpha once so the log line reports the values actually
        # used, including the defaults when the config omits them.
        lora_rank = cfg.get("lora_rank", 16)
        lora_alpha = cfg.get("lora_alpha", 16.0)
        ts = cfg.get("lora_targets", "")
        targets = tuple(t.strip() for t in ts.split(",") if t.strip()) if ts else ()
        lora_kwargs = {"rank": lora_rank, "alpha": lora_alpha}
        if targets:
            # Only override wrap_lora's own default when targets were given.
            lora_kwargs["target_substrings"] = targets
        n_lora, n_lora_p = wrap_lora(model, **lora_kwargs)
        print(f" LoRA: rank={lora_rank} alpha={lora_alpha} "
              f"({n_lora} modules, {n_lora_p/1e6:.2f}M params)")

    # Second pass, after the wrappers have been applied.
    missing2, unexp2 = model.load_state_dict(ckpt["student_state"], strict=False)
    print(f" re-load after wrappers: missing={len(missing2)} unexpected={len(unexp2)}")

    # A missing key that looks like a LoRA or quantization parameter means that
    # weight was never loaded from the checkpoint; fail loudly instead of
    # generating with it.
    suspicious = [k for k in missing2
                  if any(s in k for s in ("lora_a", "lora_b", "lora_A", "lora_B",
                                          "scale", "zero", "qweight"))]
    if suspicious:
        raise RuntimeError(
            f"After wrap_int4/wrap_lora, {len(suspicious)} expected weights are still "
            f"unloaded (would default to random init): {suspicious[:5]}...")

    model.eval()
    return model, tokenizer, cfg, mlp_modules
def verify_grouped_routing(model, expected_K=96, expected_density=0.50):
    """Check that every layer's MLP is a correctly configured GroupedMaskedMLP.

    Walks ``model.layers[i].mlp`` for ``i in range(N_LAYERS)`` instead of
    trusting a previously returned list, so an MLP that was replaced by a
    later wrapper is caught. For each layer it checks the module type,
    ``K_groups``, ``K_active`` and that the ``group_assignments`` ids are
    present, non-empty and within ``[0, expected_K)``. On success it prints a
    group-size summary for layer 0.

    Args:
        model: Model exposing ``model.layers[i].mlp``.
        expected_K: Required ``K_groups`` on every layer.
        expected_density: Used to derive the required ``K_active`` as
            ``max(1, round(expected_K * expected_density))``.

    Raises:
        RuntimeError: If any layer fails a check (all issues are printed first).
    """
    print(f"\n=== Verifying grouped routing on {N_LAYERS} layers (walking model.layers) ===")
    issues = []
    expected_K_active = max(1, round(expected_K * expected_density))
    for i in range(N_LAYERS):
        m = model.layers[i].mlp
        if not isinstance(m, GroupedMaskedMLP):
            issues.append(f"Layer {i}: not GroupedMaskedMLP, got {type(m).__name__}")
            continue
        if m.K_groups != expected_K:
            issues.append(f"Layer {i}: K_groups={m.K_groups}, expected {expected_K}")
        if m.K_active != expected_K_active:
            issues.append(f"Layer {i}: K_active={m.K_active}, expected {expected_K_active}")
        # Treat an attribute that exists but is None the same as a missing one.
        assignments = getattr(m, "group_assignments", None)
        if assignments is None:
            issues.append(f"Layer {i}: missing group_assignments buffer")
            continue
        # min()/max() raise on an empty tensor; report it as an issue instead.
        if assignments.numel() == 0:
            issues.append(f"Layer {i}: group_assignments buffer is empty")
            continue
        n_unique = assignments.unique().numel()
        min_id = assignments.min().item()
        max_id = assignments.max().item()
        # Negative ids would also make the torch.bincount summary below fail.
        if min_id < 0:
            issues.append(f"Layer {i}: min group id {min_id} < 0")
        if max_id >= expected_K:
            issues.append(f"Layer {i}: max group id {max_id} >= K_groups {expected_K}")
        if n_unique > expected_K:
            issues.append(f"Layer {i}: {n_unique} unique groups > expected {expected_K}")
    if issues:
        print(" FAIL:")
        for s in issues:
            print(f" {s}")
        raise RuntimeError("Verification failed")
    # All layers passed; summarize the group-size distribution of layer 0.
    m0 = model.layers[0].mlp
    counts = torch.bincount(m0.group_assignments, minlength=m0.K_groups)
    print(f" L0: K_groups={m0.K_groups} K_active={m0.K_active} "
          f"D_FFN={m0.group_assignments.numel()} "
          f"group_size_min={counts.min().item()} max={counts.max().item()} mean={counts.float().mean().item():.1f}")
    print(f" ALL {N_LAYERS} layers verified — K={expected_K}, K_active={expected_K_active}")
@torch.no_grad()
def generate(model, tokenizer, prompt: str, max_new_tokens: int = 60,
             temperature: float = 0.0, use_chat_template: bool = True):
    """Generate a continuation for ``prompt`` via ``model.inner.generate``.

    Args:
        model: Wrapper object that exposes the generate-capable model as
            ``model.inner``.
        tokenizer: Tokenizer used for templating, encoding and decoding.
        prompt: User text. With ``use_chat_template=True`` it is wrapped as a
            single user message; otherwise it is fed as raw text (e.g. for
            completions).
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: ``> 0`` enables sampling with this temperature and
            ``top_p=0.9``; anything else decodes greedily.
        use_chat_template: Whether to apply the tokenizer's chat template.

    Returns:
        ``(full, response)``: ``full`` is the whole output sequence decoded
        with special tokens kept; ``response`` is only the newly generated
        tokens decoded with special tokens skipped.

    Raises:
        RuntimeError: If ``model`` has no ``inner`` attribute.
    """
    if not hasattr(model, "inner"):
        raise RuntimeError("Model lacks .inner; cannot use HF generate")
    if use_chat_template:
        formatted = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False, add_generation_prompt=True)
    else:
        formatted = prompt
    # A chat-templated string already carries the special tokens the template
    # emits, so tokenizing it with add_special_tokens=True would duplicate
    # them (e.g. a second BOS). Raw prompts keep the tokenizer default (True).
    inputs = tokenizer(formatted, return_tensors="pt",
                       add_special_tokens=not use_chat_template).to(DEVICE)
    in_len = inputs["input_ids"].shape[1]
    do_sample = temperature > 0.0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # Sampling options are only passed when sampling is enabled.
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = 0.9
    out_ids = model.inner.generate(**inputs, **gen_kwargs)
    full = tokenizer.decode(out_ids[0], skip_special_tokens=False)
    # Slice off the prompt tokens so the response holds only new text.
    response = tokenizer.decode(out_ids[0][in_len:], skip_special_tokens=True)
    return full, response
def main():
    """CLI entry point: build the model, verify its routing, then generate once.

    Parses the command line, prints the prompt settings, runs one generation
    and prints both the clean response and the full decoded sequence.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--checkpoint", default="checkpoints/Sw_grouped_50_K96_lora_long.pt")
    cli.add_argument("--group_assignments_dir", default="logs/groups")
    cli.add_argument("--group_tag", default="s50")
    cli.add_argument("--prompt", default="What is the capital of France? Answer in one short sentence.")
    cli.add_argument("--max_new_tokens", type=int, default=60)
    cli.add_argument("--temperature", type=float, default=0.0)
    cli.add_argument("--no_chat_template", action="store_true",
                     help="Feed raw prompt without chat template (for completions)")
    opts = cli.parse_args()
    chat_template_on = not opts.no_chat_template

    # The list of grouped MLP modules is returned too but is not needed here.
    model, tokenizer, cfg, _mlp_modules = build_model(
        checkpoint_path=opts.checkpoint,
        group_assignments_dir=opts.group_assignments_dir,
        group_tag=opts.group_tag,
    )

    # Raises if any layer is not configured as expected.
    verify_grouped_routing(model, expected_K=96, expected_density=cfg["density"])

    print(
        "\n=== Generation ===",
        f"Prompt: {opts.prompt!r}",
        f"Chat template: {chat_template_on}",
        f"Generating up to {opts.max_new_tokens} tokens (temp={opts.temperature})...",
        sep="\n",
    )
    full, response = generate(
        model,
        tokenizer,
        opts.prompt,
        max_new_tokens=opts.max_new_tokens,
        temperature=opts.temperature,
        use_chat_template=chat_template_on,
    )
    print(
        "\n--- Response ---",
        response,
        "--- Full (with special tokens) ---",
        full,
        "--- End ---",
        sep="\n",
    )


if __name__ == "__main__":
    main()