# Source: Prisma/generate.py (author: y3i12; commit 56e82ec, "Initial commit")
#!/usr/bin/env python3
"""
Generation script for Circuit Transformer.
Usage (run as a module; this file uses package-relative imports):
python -m circuits.generate --checkpoint circuits/checkpoints/latest.pt --prompt "Once upon a time"
"""
import argparse
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from .config import CircuitConfig
from .model import CircuitTransformer
from .mirrored import MirroredConfig, MirroredTransformer
from .graft_g2lu import load_g2lu_model
from .layers import build_word_start_table
from .data import get_tokenizer
def parse_args(argv=None):
    """Parse the command-line options for text generation.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal CLI
            behavior; passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with the generation options.
    """
    parser = argparse.ArgumentParser(description="Generate text with Circuit Transformer")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint")
    parser.add_argument("--prompt", type=str, default="", help="Prompt text")
    parser.add_argument("--max-tokens", type=int, default=100, help="Max tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    parser.add_argument("--top-k", type=int, default=50, help="Top-k filtering")
    parser.add_argument("--top-p", type=float, default=0.9, help="Nucleus sampling threshold")
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        default=1.0,
        # The default really is 1.0 (off); 1.3 is only a suggested value.
        help="Repetition penalty (1.0=off; e.g. 1.3 for slot models)",
    )
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    parser.add_argument("--no-cache", action="store_true", help="Disable KV cache")
    return parser.parse_args(argv)
def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
    """Adapt a checkpoint state_dict so its keys match ``model``.

    Two fix-ups are applied:

    * The ``_orig_mod.`` prefix that ``torch.compile`` adds is stripped from
      every key that carries it.
    * FFN gate weights stored under the dual-gate names are renamed to the
      names this model uses: ``.ffn.gate_expand.weight`` -> ``.ffn.w3.weight``
      and ``.ffn.gate_compress.weight`` -> ``.ffn.w4.weight``. A rename only
      happens when the target key is one the model is actually missing.

    Nothing else is added, dropped or validated; the caller's
    ``load_state_dict`` decides whether leftover mismatches are an error.
    When any mismatch exists, a summary of renamed keys, model keys still
    absent, and checkpoint keys still unused is printed.

    Args:
        state_dict: Checkpoint weights. Not modified.
        model: Module whose ``state_dict()`` keys define the target naming.

    Returns:
        The (prefix-stripped) state_dict itself when it already matches the
        model exactly, otherwise a new dict with the renamed keys.
    """
    compile_prefix = "_orig_mod."
    if any(k.startswith(compile_prefix) for k in state_dict):
        state_dict = {k.removeprefix(compile_prefix): v for k, v in state_dict.items()}

    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys
    if not missing and not unexpected:
        return state_dict  # perfect match, no migration needed

    # (checkpoint-side suffix, model-side suffix) pairs.
    renames = (
        (".ffn.gate_expand.weight", ".ffn.w3.weight"),
        (".ffn.gate_compress.weight", ".ffn.w4.weight"),
    )
    migrated = dict(state_dict)
    migrations = []
    # sorted() gives a stable log order and a copy we can iterate while
    # discarding from `unexpected`.
    for key in sorted(unexpected):
        for old_suffix, new_suffix in renames:
            if old_suffix not in key:
                continue
            new_key = key.replace(old_suffix, new_suffix)
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                unexpected.discard(key)
                migrations.append(f"  {key} -> {new_key}")
            break

    if migrations:
        print(f"State dict migration ({len(migrations)} keys renamed):")
        for line in migrations:
            print(line)

    # Report whatever is still unmatched so a strict load failure is easy to diagnose.
    migrated_keys = set(migrated)
    still_missing = model_keys - migrated_keys
    if still_missing:
        print(f"  Model keys missing from checkpoint: {len(still_missing)}")
        for k in sorted(still_missing):
            print(f"    {k}")
    still_unused = migrated_keys - model_keys
    if still_unused:
        print(f"  Checkpoint keys not used by model: {len(still_unused)}")
        for k in sorted(still_unused):
            print(f"    {k}")
    return migrated
def _load_model_and_tokenizer(checkpoint: dict, checkpoint_path: str, device: torch.device):
    """Rebuild the model described by ``checkpoint`` and select its tokenizer.

    Dispatches on ``checkpoint["model_type"]`` (default ``"standard"``):
    ``graft_g2lu`` and ``folded`` checkpoints are re-loaded from
    ``checkpoint_path`` by their own loaders; ``mirrored`` and ``standard``
    are constructed from ``checkpoint["config"]`` and filled from
    ``checkpoint["model"]`` after key migration.

    Returns:
        ``(model, tokenizer)`` with the model on ``device`` and in eval mode.
    """
    model_type = checkpoint.get("model_type", "standard")

    if model_type == "graft_g2lu":
        model = load_g2lu_model(checkpoint_path, device=device)
        model.eval()
        pretrained_name = checkpoint.get("pretrained_name", "unknown")
        print(f"Architecture: G²LU Graft ({pretrained_name}, {len(model.g2lu_mlps)}L)")
        tokenizer = get_tokenizer(checkpoint.get("tokenizer_name", pretrained_name))
        return model, tokenizer

    if model_type == "folded":
        # Imported lazily: only folded checkpoints need this dependency.
        from grafting.fold_llama import FoldedLlama

        model = FoldedLlama.load_from_checkpoint(checkpoint_path, device=device)
        model.eval()
        fold_cfg = model.config
        print(f"Architecture: FoldedLlama ({fold_cfg.model_name}, "
              f"{fold_cfg.n_expand}E+{fold_cfg.n_middle}M+{fold_cfg.n_compress}C)")
        tokenizer = AutoTokenizer.from_pretrained(fold_cfg.model_name, trust_remote_code=True)
        return model, tokenizer

    if model_type == "mirrored":
        config_dict = checkpoint["config"]
        # Build without dual_gate_middle (mutates the checkpoint config, as before);
        # presumably its gate_expand/gate_compress weights are the ones
        # _migrate_state_dict renames to w3/w4 — confirm.
        if config_dict.get("dual_gate_middle"):
            config_dict.pop("dual_gate_middle")
        config = MirroredConfig.from_dict(config_dict)
        model = MirroredTransformer(config).to(device)
        print(f"Architecture: MirroredTransformer ({model.total_virtual_layers} virtual layers)")
    else:
        config = CircuitConfig.from_dict(checkpoint["config"])
        model = CircuitTransformer(config).to(device)
        print(f"Architecture: CircuitTransformer ({config.num_layers} layers)")

    # Strip torch.compile's _orig_mod. prefix and rename legacy keys before a strict load.
    state_dict = _migrate_state_dict(checkpoint["model"], model)
    model.load_state_dict(state_dict)
    model.eval()
    tokenizer = get_tokenizer(checkpoint.get("tokenizer_name", "gpt2"))
    return model, tokenizer


def _initial_prompt_ids(tokenizer, prompt: str, device: torch.device) -> torch.Tensor:
    """Encode ``prompt`` as a single-row id tensor on ``device``.

    An empty prompt is seeded with the tokenizer's EOS token, falling back to
    BOS when the tokenizer defines no EOS.

    Raises:
        ValueError: if the prompt is empty and the tokenizer has neither token.
    """
    if prompt:
        return tokenizer.encode(prompt, return_tensors="pt").to(device)
    start_id = tokenizer.eos_token_id
    if start_id is None:
        start_id = tokenizer.bos_token_id
    if start_id is None:
        raise ValueError("Empty --prompt requires a tokenizer with an EOS or BOS token")
    return torch.tensor([[start_id]], device=device)


def _build_generation_kwargs(args: argparse.Namespace, model_type: str, word_start_table) -> dict:
    """Translate CLI options into keyword arguments for ``model.generate``."""
    gen_kwargs = dict(
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        use_cache=not args.no_cache,
    )
    if args.repetition_penalty != 1.0:
        gen_kwargs["repetition_penalty"] = args.repetition_penalty
    if model_type == "graft_g2lu":
        # HF generate() only applies temperature/top-k/top-p when do_sample=True,
        # and rejects a non-positive temperature while sampling. Treat
        # temperature > 0 as sampling and temperature <= 0 as greedy decoding.
        if args.temperature > 0:
            gen_kwargs["do_sample"] = True
        else:
            gen_kwargs["do_sample"] = False
            for knob in ("temperature", "top_k", "top_p"):
                gen_kwargs.pop(knob)
    if word_start_table is not None:
        gen_kwargs["word_start_table"] = word_start_table
    return gen_kwargs


def generate():
    """CLI entry point: load a checkpoint, generate a continuation of --prompt, print it."""
    args = parse_args()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    print(f"Loading checkpoint: {args.checkpoint}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects — only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    model_type = checkpoint.get("model_type", "standard")
    model, tokenizer = _load_model_and_tokenizer(checkpoint, args.checkpoint, device)

    # Build word-position table if model uses SemRoPE (word_rope_dims > 0 in config).
    word_start_table_device = None
    if model_type not in ("graft_g2lu", "folded"):
        word_rope_dims = checkpoint.get("config", {}).get("word_rope_dims", 0)
        if word_rope_dims > 0:
            word_start_table_device = build_word_start_table(tokenizer, len(tokenizer)).to(device)
            print(f"Word-position RoPE: {word_rope_dims} dims")

    prompt_ids = _initial_prompt_ids(tokenizer, args.prompt, device)
    print(f"\nPrompt: {args.prompt or '<empty>'}")
    print(f"Prompt tokens: {prompt_ids.shape[1]}")
    print(f"Generating {args.max_tokens} tokens...")
    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Top-p: {args.top_p}")
    print("-" * 50)

    gen_kwargs = _build_generation_kwargs(args, model_type, word_start_table_device)
    with torch.no_grad():
        output_ids = model.generate(prompt_ids, **gen_kwargs)

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(generated_text)
    print("-" * 50)
    print(f"Total tokens: {output_ids.shape[1]}")
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run one generation pass.
    generate()