#!/usr/bin/env python3
"""
Generation script for Circuit Transformer.

Usage:
    python circuits/generate.py --checkpoint circuits/checkpoints/latest.pt --prompt "Once upon a time"
"""
import argparse

import torch
import torch.nn as nn
from transformers import AutoTokenizer

from .config import CircuitConfig
from .model import CircuitTransformer
from .mirrored import MirroredConfig, MirroredTransformer
from .graft_g2lu import load_g2lu_model
from .layers import build_word_start_table
from .data import get_tokenizer


def parse_args():
    parser = argparse.ArgumentParser(description="Generate text with Circuit Transformer")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint")
    parser.add_argument("--prompt", type=str, default="", help="Prompt text")
    parser.add_argument("--max-tokens", type=int, default=100, help="Max tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    parser.add_argument("--top-k", type=int, default=50, help="Top-k filtering")
    parser.add_argument("--top-p", type=float, default=0.9, help="Nucleus sampling threshold")
    parser.add_argument("--repetition-penalty", type=float, default=1.0,
                        help="Repetition penalty (1.0=off, 1.3=default for slot models)")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    parser.add_argument("--no-cache", action="store_true", help="Disable KV cache")
    return parser.parse_args()


def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
    """Migrate a checkpoint state_dict to match the current model architecture.

    Strips the ``_orig_mod.`` prefix left by torch.compile and handles renames
    such as the SwiGLU → MirroredSwiGLU (dual_gate_middle) upgrade.
    """
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}

    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys
    if not missing and not unexpected:
        return state_dict  # perfect match, no migration needed

    migrated = dict(state_dict)
    migrations = []

    # dual_gate_middle upgrade: checkpoint gate_expand/gate_compress → model w3/w4
    for key in list(unexpected):
        if ".ffn.gate_expand.weight" in key:
            new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                unexpected.discard(key)
                migrations.append(f"  {key} → {new_key}")
        if ".ffn.gate_compress.weight" in key:
            new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                unexpected.discard(key)
                migrations.append(f"  {key} → {new_key}")

    if migrations:
        print(f"State dict migration ({len(migrations)} keys renamed):")
        for m in migrations:
            print(m)

    # Report any remaining missing keys (these stay freshly initialized)
    still_missing = model_keys - set(migrated.keys())
    if still_missing:
        print(f"  New parameters (freshly initialized): {len(still_missing)}")
        for k in sorted(still_missing):
            print(f"    {k}")

    return migrated
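
# A minimal, hypothetical sketch (not used by the CLI) of what the migration
# above does. The _Demo* modules below are invented solely to produce
# ".ffn.gate_expand.weight"-style key patterns that the rename matches on;
# real checkpoints come from the actual transformer classes in this package.
def _migration_example():
    class _DemoFFN(nn.Module):
        def __init__(self):
            super().__init__()
            self.w3 = nn.Linear(4, 8, bias=False)  # current expand projection name
            self.w4 = nn.Linear(8, 4, bias=False)  # current compress projection name

    class _DemoBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.ffn = _DemoFFN()

    class _DemoModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([_DemoBlock()])

    model = _DemoModel()
    # A checkpoint saved under the legacy gate_expand/gate_compress names.
    legacy = {
        "layers.0.ffn.gate_expand.weight": torch.randn(8, 4),
        "layers.0.ffn.gate_compress.weight": torch.randn(4, 8),
    }
    migrated = _migrate_state_dict(legacy, model)
    model.load_state_dict(migrated)  # keys now match the model's w3/w4 names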
def generate():
    args = parse_args()

    # Set up device
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Load checkpoint
    print(f"Loading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)

    # Reconstruct config and model based on architecture type
    model_type = checkpoint.get("model_type", "standard")
    is_folded = model_type == "folded"

    if model_type == "graft_g2lu":
        model = load_g2lu_model(args.checkpoint, device=device)
        model.eval()
        pretrained_name = checkpoint.get("pretrained_name", "unknown")
        print(f"Architecture: G²LU Graft ({pretrained_name}, {len(model.g2lu_mlps)}L)")
        tokenizer_name = checkpoint.get("tokenizer_name", pretrained_name)
        tokenizer = get_tokenizer(tokenizer_name)
    elif is_folded:
        from grafting.fold_llama import FoldedLlama

        model = FoldedLlama.load_from_checkpoint(args.checkpoint, device=device)
        model.eval()
        fold_cfg = model.config
        print(f"Architecture: FoldedLlama ({fold_cfg.model_name}, "
              f"{fold_cfg.n_expand}E+{fold_cfg.n_middle}M+{fold_cfg.n_compress}C)")
        tokenizer = AutoTokenizer.from_pretrained(fold_cfg.model_name, trust_remote_code=True)
    else:
        if model_type == "mirrored":
            # Drop the legacy dual_gate_middle flag before building the config.
            checkpoint["config"].pop("dual_gate_middle", None)
            config = MirroredConfig.from_dict(checkpoint["config"])
            model = MirroredTransformer(config).to(device)
            print(f"Architecture: MirroredTransformer ({model.total_virtual_layers} virtual layers)")
        else:
            config = CircuitConfig.from_dict(checkpoint["config"])
            model = CircuitTransformer(config).to(device)
            print(f"Architecture: CircuitTransformer ({config.num_layers} layers)")

        # Strip the _orig_mod. prefix from torch.compile'd checkpoints and
        # apply any key renames before loading.
        state_dict = _migrate_state_dict(checkpoint["model"], model)
        model.load_state_dict(state_dict)
        model.eval()

        tokenizer_name = checkpoint.get("tokenizer_name", "gpt2")
        tokenizer = get_tokenizer(tokenizer_name)

    # Build the word-position table if the model uses SemRoPE
    word_start_table_device = None
    if model_type not in ("graft_g2lu", "folded"):
        ckpt_config = checkpoint.get("config", {})
        word_rope_dims = ckpt_config.get("word_rope_dims", 0)
        if word_rope_dims > 0:
            word_start_table_device = build_word_start_table(tokenizer, len(tokenizer)).to(device)
            print(f"Word-position RoPE: {word_rope_dims} dims")

    # Tokenize the prompt
    if args.prompt:
        prompt_ids = tokenizer.encode(args.prompt, return_tensors="pt").to(device)
    else:
        # Start from the BOS/EOS token when no prompt is given
        prompt_ids = torch.tensor([[tokenizer.eos_token_id]], device=device)

    print(f"\nPrompt: {args.prompt or '(empty)'}")
    print(f"Prompt tokens: {prompt_ids.shape[1]}")
    print(f"Generating {args.max_tokens} tokens...")
    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Top-p: {args.top_p}")
    print("-" * 50)

    # Generate
    with torch.no_grad():
        gen_kwargs = dict(
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            use_cache=not args.no_cache,
        )
        if args.repetition_penalty != 1.0:
            gen_kwargs["repetition_penalty"] = args.repetition_penalty
        # HF models need do_sample=True for temperature/top_k/top_p to take effect
        if model_type == "graft_g2lu":
            if args.temperature > 0 and args.temperature != 1.0:
                gen_kwargs["do_sample"] = True
            elif args.top_p < 1.0 or args.top_k > 0:
                gen_kwargs["do_sample"] = True
        if word_start_table_device is not None:
            gen_kwargs["word_start_table"] = word_start_table_device
        output_ids = model.generate(prompt_ids, **gen_kwargs)

    # Decode and print
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(generated_text)
    print("-" * 50)
    print(f"Total tokens: {output_ids.shape[1]}")


if __name__ == "__main__":
    generate()