gemma4-e2b-grouped-k96 / inference_k96.py
ncylich's picture
Upload inference_k96.py with huggingface_hub
b460126 verified
#!/usr/bin/env python3
"""inference_k96.py — generate text with the K=96 grouped Gemma-4 E2B model.
Loads:
- Base Gemma-4 E2B (bf16) via gemma4_hf
- GroupedMaskedMLP at K_groups=96, K_active=48 (d=0.50), s50 cluster assignments
- Int4 QAT (group_size=32)
- LoRA r128 alpha=128 on up_proj/down_proj
- State dict from checkpoints/Sw_grouped_50_K96_lora_long.pt
Verification: prints the config, checks every layer's K_groups/K_active and group ids, and prints layer-0 group-size stats to confirm all 96 groups are in use.
Usage:
python scripts/inference_k96.py \
--checkpoint checkpoints/Sw_grouped_50_K96_lora_long.pt \
--prompt "The capital of France is"
"""
import argparse, os, sys
import torch
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gemma4_hf import load_gemma4, DEVICE, N_LAYERS
from rung6_moe_g4 import wrap_int4, Int4QuantLinear, wrap_lora
from rung8_grouped_g4 import install_grouped, GroupedMaskedMLP
def _parse_lora_targets(spec):
    """Normalize a LoRA target spec to a tuple of substrings, or None when empty.

    Accepts either a comma-separated string (the format the original code parsed)
    or an iterable of strings. Blank entries are dropped.
    """
    if not spec:
        return None
    parts = spec.split(",") if isinstance(spec, str) else spec
    targets = tuple(t.strip() for t in parts if t and t.strip())
    return targets or None


def build_model(checkpoint_path: str,
                group_assignments_dir: str = "logs/groups",
                group_tag: str = "s50"):
    """Build the K=96 grouped Gemma-4 E2B model and load checkpoint weights.

    Steps, in order:
      1. Load base model + tokenizer via ``load_gemma4`` and freeze all parameters.
      2. Read the checkpoint's ``config`` / ``student_state`` and require K_groups == 96.
      3. Install GroupedMaskedMLP on every layer from the on-disk cluster assignments.
      4. Load the state dict (non-strict) and set every grouped MLP's ``tau``.
      5. Optionally apply Int4 QAT and/or LoRA wrappers, as requested by ``config``.
      6. Re-load the state dict so wrapper tensors are populated, and fail loudly
         if any LoRA/int4 tensor is still missing.

    Args:
        checkpoint_path: path to a ``torch.save``'d dict containing ``config`` and
            ``student_state`` entries.
        group_assignments_dir: directory forwarded to ``install_grouped``.
        group_tag: assignment tag forwarded to ``install_grouped`` (e.g. "s50").

    Returns:
        (model, tokenizer, cfg, mlp_modules): the model in eval mode, its tokenizer,
        the checkpoint's config dict, and the list returned by ``install_grouped``.

    Raises:
        ValueError: if the checkpoint's K_groups is not 96.
        KeyError: if the config has no ``density`` entry.
        RuntimeError: if LoRA/int4 tensors are still missing after the second load.
    """
    print("Loading base Gemma-4 E2B...")
    model, tokenizer = load_gemma4()
    for p in model.parameters():
        p.requires_grad_(False)

    print(f"Loading checkpoint metadata from {checkpoint_path}...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    cfg = ckpt["config"]
    state = ckpt["student_state"]
    # Drop the reference to any other checkpoint entries (e.g. optimizer state,
    # if present) so they can be freed; only cfg and state are used below.
    del ckpt

    # Sanity check: this MUST be a K=96 grouped checkpoint.
    K_groups = cfg.get("K_groups")
    if K_groups != 96:
        raise ValueError(f"Expected K_groups=96, got {K_groups} from checkpoint")
    density = cfg["density"]
    # Prefer an explicitly stored K_active; otherwise derive it from density.
    K_active = cfg.get("K_active") or max(1, round(K_groups * density))
    print(f" K_groups={K_groups} K_active={K_active} density={density:.3f}")

    # Install GroupedMaskedMLP at K=96 with the tagged cluster assignments.
    print(f"Installing GroupedMaskedMLP (K={K_groups}, K_active={K_active}) on {N_LAYERS} layers...")
    mlp_modules = install_grouped(model,
                                  K_groups=K_groups, K_active=K_active,
                                  group_assignments_dir=group_assignments_dir,
                                  group_tag=group_tag,
                                  freeze_base=False)

    # First (partial) load: populates the projection weights of the new MLPs.
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f" load: missing={len(missing)} unexpected={len(unexpected)}")

    # Sigmoid relaxation temperature; 0.01 is used when the config has none.
    tau = cfg.get("tau", 0.01)
    for m in mlp_modules:
        m.tau = tau

    if cfg.get("int4_qat"):
        # Assigned on the class, so this setting is global, not per-model.
        Int4QuantLinear._group_size = cfg.get("int4_group_size", 32)
        n = wrap_int4(model)
        print(f" int4 QAT: wrapped {n} Linear modules (group_size={Int4QuantLinear._group_size})")

    if cfg.get("use_lora") or cfg.get("gate_lora_train"):
        lora_kwargs = dict(rank=cfg.get("lora_rank", 16),
                           alpha=cfg.get("lora_alpha", 16.0))
        targets = _parse_lora_targets(cfg.get("lora_targets", ""))
        if targets:
            lora_kwargs["target_substrings"] = targets
        n_lora, n_lora_p = wrap_lora(model, **lora_kwargs)
        print(f" LoRA: rank={cfg.get('lora_rank')} alpha={cfg.get('lora_alpha')} "
              f"({n_lora} modules, {n_lora_p/1e6:.2f}M params)")

    # Re-load so the parameters/buffers created by the wrappers get populated.
    missing2, unexp2 = model.load_state_dict(state, strict=False)
    print(f" re-load after wrappers: missing={len(missing2)} unexpected={len(unexp2)}")

    # Hard guard: any LoRA/int4 tensor missing from the load means we'd silently
    # serve a model with random LoRA weights or wrong int4 scales.
    lora_markers = ("lora_a", "lora_b", "lora_A", "lora_B")
    suspicious = [k for k in missing2
                  if any(s in k for s in lora_markers + ("scale", "zero", "qweight"))]
    if suspicious:
        raise RuntimeError(
            f"After wrap_int4/wrap_lora, {len(suspicious)} expected weights are still "
            f"unloaded (would default to random init): {suspicious[:5]}...")

    # The reverse mismatch: the checkpoint carries LoRA tensors that no module
    # accepted (e.g. LoRA not enabled in config). Warn instead of silently dropping.
    dropped = [k for k in unexp2 if any(s in k for s in lora_markers)]
    if dropped:
        print(f" WARNING: {len(dropped)} LoRA tensors in the checkpoint were not "
              f"loaded (LoRA not enabled in config?): {dropped[:5]}...")

    model.eval()
    return model, tokenizer, cfg, mlp_modules
def verify_grouped_routing(model, expected_K=96, expected_density=0.50,
                           expected_K_active=None):
    """Re-walk model.layers and confirm every MLP is a correctly sized GroupedMaskedMLP.

    Reading from ``model.layers`` (not the list returned by ``install_grouped``)
    catches any later wrapper that may have silently replaced an MLP.

    Args:
        model: object exposing ``layers[i].mlp`` for ``i in range(N_LAYERS)``.
        expected_K: required ``K_groups`` on every layer.
        expected_density: used to derive the expected ``K_active`` as
            ``max(1, round(expected_K * expected_density))`` when
            ``expected_K_active`` is not given.
        expected_K_active: explicit expected ``K_active``. Pass the checkpoint's
            stored value when it was not derived from density.

    Raises:
        RuntimeError: if any layer fails a check. All issues are printed first.
    """
    print(f"\n=== Verifying grouped routing on {N_LAYERS} layers (walking model.layers) ===")
    if expected_K_active is None:
        expected_K_active = max(1, round(expected_K * expected_density))
    issues = []
    for i in range(N_LAYERS):
        m = model.layers[i].mlp
        if not isinstance(m, GroupedMaskedMLP):
            issues.append(f"Layer {i}: not GroupedMaskedMLP, got {type(m).__name__}")
            continue
        if m.K_groups != expected_K:
            issues.append(f"Layer {i}: K_groups={m.K_groups}, expected {expected_K}")
        if m.K_active != expected_K_active:
            issues.append(f"Layer {i}: K_active={m.K_active}, expected {expected_K_active}")
        if not hasattr(m, "group_assignments"):
            issues.append(f"Layer {i}: missing group_assignments buffer")
            continue
        assignments = m.group_assignments
        # .min()/.max() raise on an empty tensor; report it as an issue instead.
        if assignments.numel() == 0:
            issues.append(f"Layer {i}: group_assignments is empty")
            continue
        min_id = assignments.min().item()
        max_id = assignments.max().item()
        # Negative ids would make the torch.bincount below fail with an opaque error.
        if min_id < 0:
            issues.append(f"Layer {i}: negative group id {min_id}")
        if max_id >= expected_K:
            issues.append(f"Layer {i}: max group id {max_id} >= K_groups {expected_K}")
        n_unique = assignments.unique().numel()
        if n_unique > expected_K:
            issues.append(f"Layer {i}: {n_unique} unique groups > expected {expected_K}")
    if issues:
        print(" FAIL:")
        for s in issues:
            print(f" {s}")
        raise RuntimeError("Verification failed")
    # Summary statistics are reported for layer 0 only.
    m0 = model.layers[0].mlp
    counts = torch.bincount(m0.group_assignments, minlength=m0.K_groups)
    print(f" L0: K_groups={m0.K_groups} K_active={m0.K_active} "
          f"D_FFN={m0.group_assignments.numel()} "
          f"group_size_min={counts.min().item()} max={counts.max().item()} mean={counts.float().mean().item():.1f}")
    print(f" ALL {N_LAYERS} layers verified — K={expected_K}, K_active={expected_K_active}")
@torch.no_grad()
def generate(model, tokenizer, prompt: str, max_new_tokens: int = 60,
             temperature: float = 0.0, use_chat_template: bool = True,
             top_p: float = 0.9):
    """Generate a continuation with HF's generate() on ``model.inner``.

    When ``use_chat_template`` is True, the prompt is wrapped as a single user
    message via the tokenizer's chat template, with a generation prompt appended.
    When it is False, the raw text is fed (e.g. for completions).

    Args:
        model: wrapper exposing an HF model as ``.inner``.
        tokenizer: the matching HF tokenizer.
        prompt: user prompt text.
        max_new_tokens: maximum number of tokens to generate.
        temperature: 0 (or below) selects greedy decoding. Above 0, enables
            sampling at this temperature.
        use_chat_template: whether to render the prompt through the chat template.
        top_p: nucleus-sampling threshold. Only used when sampling.

    Returns:
        (full, response): ``full`` is the decoded prompt+completion with special
        tokens kept. ``response`` is only the newly generated tokens, with special
        tokens skipped.

    Raises:
        RuntimeError: if ``model`` has no ``.inner`` attribute.
    """
    if not hasattr(model, "inner"):
        raise RuntimeError("Model lacks .inner; cannot use HF generate")
    if use_chat_template:
        formatted = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False, add_generation_prompt=True)
    else:
        formatted = prompt
    # Text rendered by a chat template already carries its special tokens (e.g. BOS).
    # Tokenizing it with add_special_tokens=True would duplicate them.
    inputs = tokenizer(formatted, return_tensors="pt",
                       add_special_tokens=not use_chat_template).to(DEVICE)
    in_len = inputs["input_ids"].shape[1]

    do_sample = temperature > 0.0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p
    out_ids = model.inner.generate(**inputs, **gen_kwargs)

    full = tokenizer.decode(out_ids[0], skip_special_tokens=False)
    # Slice off the prompt tokens so only the newly generated text is returned.
    response = tokenizer.decode(out_ids[0][in_len:], skip_special_tokens=True)
    return full, response
def main(argv=None):
    """Command-line entry point: build the K=96 model, verify routing, then generate once.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``. Intended for
            programmatic use; the default keeps normal CLI behavior.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/Sw_grouped_50_K96_lora_long.pt")
    parser.add_argument("--group_assignments_dir", default="logs/groups")
    parser.add_argument("--group_tag", default="s50")
    parser.add_argument("--prompt", default="What is the capital of France? Answer in one short sentence.")
    parser.add_argument("--max_new_tokens", type=int, default=60)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--no_chat_template", action="store_true",
                        help="Feed raw prompt without chat template (for completions)")
    args = parser.parse_args(argv)

    # The fourth return value (installed MLP modules) is not needed here.
    model, tokenizer, cfg, _ = build_model(
        checkpoint_path=args.checkpoint,
        group_assignments_dir=args.group_assignments_dir,
        group_tag=args.group_tag)
    verify_grouped_routing(model, expected_K=96, expected_density=cfg["density"])

    use_chat_template = not args.no_chat_template
    print("\n=== Generation ===")
    print(f"Prompt: {args.prompt!r}")
    print(f"Chat template: {use_chat_template}")
    print(f"Generating up to {args.max_new_tokens} tokens (temp={args.temperature})...")
    full, response = generate(model, tokenizer, args.prompt,
                              max_new_tokens=args.max_new_tokens,
                              temperature=args.temperature,
                              use_chat_template=use_chat_template)
    print("\n--- Response ---")
    print(response)
    print("--- Full (with special tokens) ---")
    print(full)
    print("--- End ---")


if __name__ == "__main__":
    main()