#!/usr/bin/env python3
r"""inference_k96.py — generate text with the K=96 grouped Gemma-4 E2B model.

Loads:
  - Base Gemma-4 E2B (bf16) via gemma4_hf
  - GroupedMaskedMLP at K_groups=96, K_active=48 (d=0.50), s50 cluster assignments
  - Int4 QAT (group_size=32)
  - LoRA r128 alpha=128 on up_proj/down_proj
  - State dict from checkpoints/Sw_grouped_50_K96_lora_long.pt

Verification: prints config + per-layer K_groups/K_active to confirm 96 groups active.

Usage:
  python scripts/inference_k96.py \
      --checkpoint checkpoints/Sw_grouped_50_K96_lora_long.pt \
      --prompt "The capital of France is"
"""
import argparse
import os
import sys

import torch

# Make modules in the parent of this script's directory (the repo root when run
# as scripts/inference_k96.py) importable before pulling in project code.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gemma4_hf import load_gemma4, DEVICE, N_LAYERS  # noqa: E402
from rung6_moe_g4 import wrap_int4, Int4QuantLinear, wrap_lora  # noqa: E402
from rung8_grouped_g4 import install_grouped, GroupedMaskedMLP  # noqa: E402

# Number of FFN groups this script is built to serve.
EXPECTED_K_GROUPS = 96

# Substrings of state-dict keys that belong to LoRA / int4 wrappers. If any such
# key is still missing after the post-wrapper reload, the model would silently
# run with freshly initialised wrapper tensors.
_WRAPPER_KEY_MARKERS = ("lora_a", "lora_b", "lora_A", "lora_B", "scale", "zero", "qweight")


def _default_k_active(K_groups, density):
    """Return the K_active implied by a density: round(K_groups * density), at least 1."""
    return max(1, round(K_groups * density))


def build_model(checkpoint_path: str, group_assignments_dir: str = "logs/groups",
                group_tag: str = "s50", expected_K_groups: int = EXPECTED_K_GROUPS):
    """Build the grouped Gemma-4 model and load checkpoint weights.

    Steps:
      1. Load the base model and freeze all of its parameters.
      2. Read the checkpoint and check that its config declares expected_K_groups.
      3. Install GroupedMaskedMLP on every layer and load the projection weights.
      4. Set tau on every grouped MLP.
      5. Optionally wrap with int4 QAT and/or LoRA, as the config requests.
      6. Re-load the state dict so the wrapper tensors are populated.

    Args:
        checkpoint_path: checkpoint file containing "config" and "student_state".
        group_assignments_dir: directory passed through to install_grouped.
        group_tag: cluster-assignment tag passed through to install_grouped.
        expected_K_groups: K_groups the checkpoint config must declare (default 96).

    Returns:
        (model, tokenizer, cfg, mlp_modules), where mlp_modules is the list
        returned by install_grouped.

    Raises:
        ValueError: the checkpoint config's K_groups != expected_K_groups.
        RuntimeError: LoRA/int4 tensors are still missing after the reload.
    """
    print("Loading base Gemma-4 E2B...")
    model, tokenizer = load_gemma4()
    for p in model.parameters():
        p.requires_grad_(False)

    print(f"Loading checkpoint metadata from {checkpoint_path}...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # (kept because the checkpoint may hold non-tensor config objects). Only
    # load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    cfg = ckpt["config"]

    # Sanity check: this MUST be a checkpoint trained with the expected group count.
    K_groups = cfg.get("K_groups")
    if K_groups != expected_K_groups:
        raise ValueError(f"Expected K_groups={expected_K_groups}, got {K_groups} from checkpoint")
    density = cfg["density"]
    # An explicit K_active in the config wins; otherwise derive it from density.
    K_active = cfg.get("K_active") or _default_k_active(K_groups, density)
    print(f" K_groups={K_groups} K_active={K_active} density={density:.3f}")

    print(f"Installing GroupedMaskedMLP (K={K_groups}, K_active={K_active}) on {N_LAYERS} layers...")
    mlp_modules = install_grouped(model, K_groups=K_groups, K_active=K_active,
                                  group_assignments_dir=group_assignments_dir,
                                  group_tag=group_tag, freeze_base=False)

    # First, partial load (projection weights); wrapper tensors do not exist yet.
    missing, unexpected = model.load_state_dict(ckpt["student_state"], strict=False)
    print(f" load: missing={len(missing)} unexpected={len(unexpected)}")

    # Sigmoid-relaxation temperature; 0.01 when the config does not record one.
    tau = cfg.get("tau", 0.01)
    for m in mlp_modules:
        m.tau = tau

    if cfg.get("int4_qat"):
        # Class-level setting read by the int4 wrappers created by wrap_int4.
        Int4QuantLinear._group_size = cfg.get("int4_group_size", 32)
        n = wrap_int4(model)
        print(f" int4 QAT: wrapped {n} Linear modules (group_size={Int4QuantLinear._group_size})")

    if cfg.get("use_lora") or cfg.get("gate_lora_train"):
        ts = cfg.get("lora_targets", "")
        targets = tuple(t.strip() for t in ts.split(",") if t.strip()) if ts else None
        lora_kwargs = {"rank": cfg.get("lora_rank", 16), "alpha": cfg.get("lora_alpha", 16.0)}
        # Without explicit targets, let wrap_lora use its own default target set.
        if targets:
            lora_kwargs["target_substrings"] = targets
        n_lora, n_lora_p = wrap_lora(model, **lora_kwargs)
        # Log the values actually used (config value or the default above).
        print(f" LoRA: rank={lora_kwargs['rank']} alpha={lora_kwargs['alpha']} "
              f"({n_lora} modules, {n_lora_p/1e6:.2f}M params)")

    # Re-load state to populate LoRA + int4 tensors created by the wrappers.
    missing2, unexp2 = model.load_state_dict(ckpt["student_state"], strict=False)
    print(f" re-load after wrappers: missing={len(missing2)} unexpected={len(unexp2)}")

    # Hard guard: any LoRA/int4 tensor missing from the load means we'd silently
    # serve a model with random LoRA weights or wrong int4 scales.
    suspicious = [k for k in missing2 if any(s in k for s in _WRAPPER_KEY_MARKERS)]
    if suspicious:
        raise RuntimeError(
            f"After wrap_int4/wrap_lora, {len(suspicious)} expected weights are still "
            f"unloaded (would default to random init): {suspicious[:5]}...")

    model.eval()
    return model, tokenizer, cfg, mlp_modules


def verify_grouped_routing(model, expected_K=96, expected_density=0.50, expected_K_active=None):
    """Confirm every layer's MLP is a GroupedMaskedMLP with the expected grouping.

    Walks model.layers directly (rather than a list returned at install time), so
    it also catches a later wrapper that silently replaced an MLP.

    Args:
        model: model exposing model.layers[i].mlp for i in range(N_LAYERS).
        expected_K: required K_groups on every layer.
        expected_density: used to derive the expected K_active when
            expected_K_active is not given.
        expected_K_active: required K_active. A falsy value (None/0) means
            "derive from density", mirroring build_model's fallback.

    Raises:
        RuntimeError: any layer fails a check (all issues are printed first).
    """
    if not expected_K_active:
        expected_K_active = _default_k_active(expected_K, expected_density)
    print(f"\n=== Verifying grouped routing on {N_LAYERS} layers (walking model.layers) ===")

    issues = []
    for i in range(N_LAYERS):
        m = model.layers[i].mlp
        if not isinstance(m, GroupedMaskedMLP):
            issues.append(f"Layer {i}: not GroupedMaskedMLP, got {type(m).__name__}")
            continue
        if m.K_groups != expected_K:
            issues.append(f"Layer {i}: K_groups={m.K_groups}, expected {expected_K}")
        if m.K_active != expected_K_active:
            issues.append(f"Layer {i}: K_active={m.K_active}, expected {expected_K_active}")
        if not hasattr(m, "group_assignments"):
            issues.append(f"Layer {i}: missing group_assignments buffer")
            continue
        assignments = m.group_assignments
        # .min()/.max() raise on empty tensors, so report emptiness explicitly.
        if assignments.numel() == 0:
            issues.append(f"Layer {i}: empty group_assignments")
            continue
        # Valid ids lie in [0, expected_K); this also bounds the unique-id count.
        min_id = assignments.min().item()
        max_id = assignments.max().item()
        if min_id < 0:
            issues.append(f"Layer {i}: negative group id {min_id}")
        if max_id >= expected_K:
            issues.append(f"Layer {i}: max group id {max_id} >= K_groups {expected_K}")

    if issues:
        print(" FAIL:")
        for s in issues:
            print(f" {s}")
        raise RuntimeError("Verification failed")

    # Layer-0 group-size histogram as a quick human-readable summary.
    m0 = model.layers[0].mlp
    counts = torch.bincount(m0.group_assignments, minlength=m0.K_groups)
    print(f" L0: K_groups={m0.K_groups} K_active={m0.K_active} "
          f"D_FFN={m0.group_assignments.numel()} "
          f"group_size_min={counts.min().item()} max={counts.max().item()} "
          f"mean={counts.float().mean().item():.1f}")
    print(f" ALL {N_LAYERS} layers verified — K={expected_K}, K_active={expected_K_active}")


@torch.no_grad()
def generate(model, tokenizer, prompt: str, max_new_tokens: int = 60,
             temperature: float = 0.0, use_chat_template: bool = True, top_p: float = 0.9):
    """Generate a continuation with HF generate() on model.inner (KV cache + sampling).

    With use_chat_template=True the prompt is wrapped as a single user message
    via the tokenizer's chat template, with the generation prompt appended. Set
    use_chat_template=False to feed raw text (e.g. for completions).

    Args:
        temperature: 0.0 means greedy decoding; > 0 enables sampling.
        top_p: nucleus-sampling threshold, used only when sampling.

    Returns:
        (full, response): full is the whole sequence decoded with special tokens
        kept; response is only the newly generated tokens, special tokens skipped.

    Raises:
        RuntimeError: model has no .inner attribute.
    """
    if not hasattr(model, "inner"):
        raise RuntimeError("Model lacks .inner; cannot use HF generate")

    if use_chat_template:
        formatted = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False, add_generation_prompt=True)
    else:
        formatted = prompt
    # Text rendered by apply_chat_template(tokenize=False) already carries the
    # template's special tokens (Gemma templates begin with BOS). Tokenizing it
    # with add_special_tokens=True would prepend a second BOS, so special tokens
    # are only added for raw prompts.
    inputs = tokenizer(formatted, return_tensors="pt",
                       add_special_tokens=not use_chat_template).to(DEVICE)
    in_len = inputs["input_ids"].shape[1]

    do_sample = temperature > 0.0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    out_ids = model.inner.generate(**inputs, **gen_kwargs)
    full = tokenizer.decode(out_ids[0], skip_special_tokens=False)
    # Slice off the prompt so the response holds only the new tokens.
    response = tokenizer.decode(out_ids[0][in_len:], skip_special_tokens=True)
    return full, response


def main():
    """CLI entry point: build, verify, then generate one response and print it."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/Sw_grouped_50_K96_lora_long.pt")
    parser.add_argument("--group_assignments_dir", default="logs/groups")
    parser.add_argument("--group_tag", default="s50")
    parser.add_argument("--prompt",
                        default="What is the capital of France? Answer in one short sentence.")
    parser.add_argument("--max_new_tokens", type=int, default=60)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--no_chat_template", action="store_true",
                        help="Feed raw prompt without chat template (for completions)")
    args = parser.parse_args()

    model, tokenizer, cfg, _mlp_modules = build_model(
        checkpoint_path=args.checkpoint,
        group_assignments_dir=args.group_assignments_dir,
        group_tag=args.group_tag)
    # Pass the checkpoint's K_active (if recorded) so verification matches what
    # build_model actually installed, instead of re-deriving it from density.
    verify_grouped_routing(model, expected_K=EXPECTED_K_GROUPS,
                           expected_density=cfg["density"],
                           expected_K_active=cfg.get("K_active"))

    print("\n=== Generation ===")
    print(f"Prompt: {args.prompt!r}")
    print(f"Chat template: {not args.no_chat_template}")
    print(f"Generating up to {args.max_new_tokens} tokens (temp={args.temperature})...")
    full, response = generate(model, tokenizer, args.prompt,
                              max_new_tokens=args.max_new_tokens,
                              temperature=args.temperature,
                              use_chat_template=not args.no_chat_template)
    print("\n--- Response ---")
    print(response)
    print("--- Full (with special tokens) ---")
    print(full)
    print("--- End ---")


if __name__ == "__main__":
    main()