#!/usr/bin/env python3
"""
Generation script for Circuit Transformer.
Usage:
python -m circuits.generate --checkpoint circuits/checkpoints/latest.pt --prompt "Once upon a time"
"""
import argparse
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from .config import CircuitConfig
from .model import CircuitTransformer
from .mirrored import MirroredConfig, MirroredTransformer
from .graft_g2lu import load_g2lu_model
from .layers import build_word_start_table
from .data import get_tokenizer
def parse_args():
    """Build and evaluate the command-line parser for a generation run."""
    p = argparse.ArgumentParser(description="Generate text with Circuit Transformer")
    p.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint")
    # Optional sampling/decoding knobs: (flag, type, default, help).
    optional_specs = (
        ("--prompt", str, "", "Prompt text"),
        ("--max-tokens", int, 100, "Max tokens to generate"),
        ("--temperature", float, 0.8, "Sampling temperature"),
        ("--top-k", int, 50, "Top-k filtering"),
        ("--top-p", float, 0.9, "Nucleus sampling threshold"),
        ("--repetition-penalty", float, 1.0, "Repetition penalty (1.0=off, 1.3=default for slot models)"),
        ("--gpu", int, 0, "GPU index"),
    )
    for flag, flag_type, default, help_text in optional_specs:
        p.add_argument(flag, type=flag_type, default=default, help=help_text)
    p.add_argument("--no-cache", action="store_true", help="Disable KV cache")
    return p.parse_args()
def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
"""Migrate checkpoint state_dict to match current model architecture.
Handles upgrades like SwiGLU → MirroredSwiGLU (dual_gate_middle).
"""
if any(k.startswith("_orig_mod.") for k in state_dict):
state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
model_keys = set(model.state_dict().keys())
ckpt_keys = set(state_dict.keys())
missing = model_keys - ckpt_keys
unexpected = ckpt_keys - model_keys
print(unexpected)
if not missing and not unexpected:
return state_dict # perfect match, no migration needed
migrated = dict(state_dict)
migrations = []
# SwiGLU → MirroredSwiGLU: w3 → gate_expand (dual_gate_middle upgrade)
for key in list(unexpected):
if ".ffn.gate_expand.weight" in key:
new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
if new_key in missing:
migrated[new_key] = migrated.pop(key)
missing.discard(new_key)
unexpected.discard(key)
migrations.append(f" {key} → {new_key}")
if ".ffn.gate_compress.weight" in key:
new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
if new_key in missing:
migrated[new_key] = migrated.pop(key)
missing.discard(new_key)
unexpected.discard(key)
migrations.append(f" {key} → {new_key}")
if migrations:
print(f"State dict migration ({len(migrations)} keys renamed):")
for m in migrations:
print(m)
# Report remaining missing keys (freshly initialized)
still_missing = model_keys - set(migrated.keys())
if still_missing:
print(f" New parameters (freshly initialized): {len(still_missing)}")
for k in sorted(still_missing):
print(f" {k}")
return migrated
def generate():
    """CLI entry point: load a checkpoint, then sample and print text.

    Dispatches on the checkpoint's "model_type" field to reconstruct one of
    four architectures (graft_g2lu, folded, mirrored, or standard circuit),
    resolves the matching tokenizer, and runs ``model.generate()`` with the
    sampling settings parsed from the command line.
    """
    args = parse_args()
    # Setup device: CUDA index from --gpu when available, else CPU.
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    # Load the full checkpoint dict (weights + config metadata).
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    print(f"Loading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    # Reconstruct config and model based on architecture type
    model_type = checkpoint.get("model_type", "standard")
    is_folded = model_type == "folded"
    if model_type == "graft_g2lu":
        # G²LU graft: the helper handles weight loading and device placement.
        model = load_g2lu_model(args.checkpoint, device=device)
        model.eval()
        pretrained_name = checkpoint.get("pretrained_name", "unknown")
        print(f"Architecture: G²LU Graft ({pretrained_name}, {len(model.g2lu_mlps)}L)")
        tokenizer_name = checkpoint.get("tokenizer_name", pretrained_name)
        tokenizer = get_tokenizer(tokenizer_name)
    elif is_folded:
        # Lazy import: the grafting package is only needed for folded checkpoints.
        from grafting.fold_llama import FoldedLlama
        model = FoldedLlama.load_from_checkpoint(args.checkpoint, device=device)
        model.eval()
        fold_cfg = model.config
        print(f"Architecture: FoldedLlama ({fold_cfg.model_name}, "
              f"{fold_cfg.n_expand}E+{fold_cfg.n_middle}M+{fold_cfg.n_compress}C)")
        tokenizer = AutoTokenizer.from_pretrained(fold_cfg.model_name, trust_remote_code=True)
    else:
        if model_type == "mirrored":
            # NOTE(review): dual_gate_middle is stripped before
            # MirroredConfig.from_dict — presumably a legacy config flag the
            # current config class no longer accepts; confirm.
            if checkpoint["config"].get("dual_gate_middle"):
                checkpoint["config"].pop("dual_gate_middle")
            config = MirroredConfig.from_dict(checkpoint["config"])
            model = MirroredTransformer(config).to(device)
            print(f"Architecture: MirroredTransformer ({model.total_virtual_layers} virtual layers)")
        else:
            config = CircuitConfig.from_dict(checkpoint["config"])
            model = CircuitTransformer(config).to(device)
            print(f"Architecture: CircuitTransformer ({config.num_layers} layers)")
        # Normalize checkpoint keys (strips torch.compile's _orig_mod. prefix,
        # renames legacy FFN gate keys) before loading.
        state_dict = _migrate_state_dict(checkpoint["model"], model)
        model.load_state_dict(state_dict)
        model.eval()
        tokenizer_name = checkpoint.get("tokenizer_name", "gpt2")
        tokenizer = get_tokenizer(tokenizer_name)
    # Build word-position table if model uses SemRoPE
    word_start_table_device = None
    if model_type not in ("graft_g2lu", "folded"):
        # assumes checkpoint["config"] is a plain dict here — TODO confirm
        ckpt_config = checkpoint.get("config", {})
        word_rope_dims = ckpt_config.get("word_rope_dims", 0)
        if word_rope_dims > 0:
            word_start_table_device = build_word_start_table(tokenizer, len(tokenizer)).to(device)
            print(f"Word-position RoPE: {word_rope_dims} dims")
    # Tokenize prompt
    if args.prompt:
        prompt_ids = tokenizer.encode(args.prompt, return_tensors="pt").to(device)
    else:
        # Start with BOS/EOS token
        prompt_ids = torch.tensor([[tokenizer.eos_token_id]], device=device)
    print(f"\nPrompt: {args.prompt or '<empty>'}")
    print(f"Prompt tokens: {prompt_ids.shape[1]}")
    print(f"Generating {args.max_tokens} tokens...")
    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Top-p: {args.top_p}")
    print("-" * 50)
    # Generate under no_grad: inference only, no autograd bookkeeping.
    with torch.no_grad():
        gen_kwargs = dict(
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            use_cache=not args.no_cache,
        )
        # Only forward the penalty when it actually changes sampling (1.0 = off).
        if args.repetition_penalty != 1.0:
            gen_kwargs["repetition_penalty"] = args.repetition_penalty
        # HF models need do_sample=True for temperature/top_k/top_p
        if model_type == "graft_g2lu":
            if args.temperature > 0 and args.temperature != 1.0:
                gen_kwargs["do_sample"] = True
            elif args.top_p < 1.0 or args.top_k > 0:
                gen_kwargs["do_sample"] = True
        if word_start_table_device is not None:
            gen_kwargs["word_start_table"] = word_start_table_device
        output_ids = model.generate(prompt_ids, **gen_kwargs)
    # Decode and print
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(generated_text)
    print("-" * 50)
    print(f"Total tokens: {output_ids.shape[1]}")
# Script entry point (run via `python -m`, see module docstring).
if __name__ == "__main__":
    generate()