"""
Generation script for Circuit Transformer.

Usage:
    python circuits/generate.py --checkpoint circuits/checkpoints/latest.pt --prompt "Once upon a time"
"""
|
| | import argparse
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| |
|
| | from transformers import AutoTokenizer
|
| |
|
| | from .config import CircuitConfig
|
| | from .model import CircuitTransformer
|
| | from .mirrored import MirroredConfig, MirroredTransformer
|
| | from .graft_g2lu import load_g2lu_model
|
| | from .layers import build_word_start_table
|
| | from .data import get_tokenizer
|
| |
|
| |
|
def parse_args():
    """Build and parse the command-line options for text generation."""
    parser = argparse.ArgumentParser(description="Generate text with Circuit Transformer")
    add = parser.add_argument

    add("--checkpoint", type=str, required=True, help="Path to checkpoint")
    add("--prompt", type=str, default="", help="Prompt text")
    add("--max-tokens", type=int, default=100, help="Max tokens to generate")
    add("--temperature", type=float, default=0.8, help="Sampling temperature")
    add("--top-k", type=int, default=50, help="Top-k filtering")
    add("--top-p", type=float, default=0.9, help="Nucleus sampling threshold")
    add("--repetition-penalty", type=float, default=1.0, help="Repetition penalty (1.0=off, 1.3=default for slot models)")
    add("--gpu", type=int, default=0, help="GPU index")
    add("--no-cache", action="store_true", help="Disable KV cache")

    return parser.parse_args()
|
| |
|
| | def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
|
| | """Migrate checkpoint state_dict to match current model architecture.
|
| |
|
| | Handles upgrades like SwiGLU → MirroredSwiGLU (dual_gate_middle).
|
| | """
|
| | if any(k.startswith("_orig_mod.") for k in state_dict):
|
| | state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
|
| |
|
| | model_keys = set(model.state_dict().keys())
|
| | ckpt_keys = set(state_dict.keys())
|
| |
|
| | missing = model_keys - ckpt_keys
|
| | unexpected = ckpt_keys - model_keys
|
| |
|
| | print(unexpected)
|
| |
|
| | if not missing and not unexpected:
|
| | return state_dict
|
| |
|
| | migrated = dict(state_dict)
|
| | migrations = []
|
| |
|
| |
|
| | for key in list(unexpected):
|
| | if ".ffn.gate_expand.weight" in key:
|
| | new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
|
| | if new_key in missing:
|
| | migrated[new_key] = migrated.pop(key)
|
| | missing.discard(new_key)
|
| | unexpected.discard(key)
|
| | migrations.append(f" {key} → {new_key}")
|
| | if ".ffn.gate_compress.weight" in key:
|
| | new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
|
| | if new_key in missing:
|
| | migrated[new_key] = migrated.pop(key)
|
| | missing.discard(new_key)
|
| | unexpected.discard(key)
|
| | migrations.append(f" {key} → {new_key}")
|
| |
|
| | if migrations:
|
| | print(f"State dict migration ({len(migrations)} keys renamed):")
|
| | for m in migrations:
|
| | print(m)
|
| |
|
| | still_missing = model_keys - set(migrated.keys())
|
| | if still_missing:
|
| | print(f" New parameters (freshly initialized): {len(still_missing)}")
|
| | for k in sorted(still_missing):
|
| | print(f" {k}")
|
| |
|
| | return migrated
|
| |
|
def generate():
    """CLI entry point: load a checkpoint and generate text from a prompt.

    Dispatches on the checkpoint's ``model_type`` field to build the right
    architecture (G²LU graft, folded Llama, mirrored, or standard circuit
    transformer), then samples ``--max-tokens`` tokens and prints the result.
    """
    args = parse_args()

    # Prefer the requested GPU; fall back to CPU when CUDA is unavailable.
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    print(f"Loading checkpoint: {args.checkpoint}")
    # weights_only=False: the checkpoint carries config dicts alongside tensors.
    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)

    # Checkpoints without "model_type" are treated as the standard architecture.
    model_type = checkpoint.get("model_type", "standard")
    is_folded = model_type == "folded"

    if model_type == "graft_g2lu":
        # The graft loader handles its own weight restore and device placement.
        model = load_g2lu_model(args.checkpoint, device=device)
        model.eval()
        pretrained_name = checkpoint.get("pretrained_name", "unknown")
        print(f"Architecture: G²LU Graft ({pretrained_name}, {len(model.g2lu_mlps)}L)")
        tokenizer_name = checkpoint.get("tokenizer_name", pretrained_name)
        tokenizer = get_tokenizer(tokenizer_name)
    elif is_folded:
        # Lazy import so the grafting package is only needed for folded runs.
        from grafting.fold_llama import FoldedLlama
        model = FoldedLlama.load_from_checkpoint(args.checkpoint, device=device)
        model.eval()
        fold_cfg = model.config
        print(f"Architecture: FoldedLlama ({fold_cfg.model_name}, "
              f"{fold_cfg.n_expand}E+{fold_cfg.n_middle}M+{fold_cfg.n_compress}C)")
        tokenizer = AutoTokenizer.from_pretrained(fold_cfg.model_name, trust_remote_code=True)
    else:
        if model_type == "mirrored":
            # Drop the legacy dual_gate_middle key so old configs still parse.
            # NOTE(review): the pop only fires for truthy values, so a stored
            # falsy value is passed through — presumably MirroredConfig
            # tolerates it; confirm.
            if checkpoint["config"].get("dual_gate_middle"):
                checkpoint["config"].pop("dual_gate_middle")
            config = MirroredConfig.from_dict(checkpoint["config"])
            model = MirroredTransformer(config).to(device)
            print(f"Architecture: MirroredTransformer ({model.total_virtual_layers} virtual layers)")
        else:
            config = CircuitConfig.from_dict(checkpoint["config"])
            model = CircuitTransformer(config).to(device)
            print(f"Architecture: CircuitTransformer ({config.num_layers} layers)")

        # Rename/strip legacy keys before loading weights into the fresh model.
        state_dict = _migrate_state_dict(checkpoint["model"], model)
        model.load_state_dict(state_dict)
        model.eval()
        tokenizer_name = checkpoint.get("tokenizer_name", "gpt2")
        tokenizer = get_tokenizer(tokenizer_name)

    # Optional word-position RoPE table — only for circuit/mirrored models.
    word_start_table_device = None
    if model_type not in ("graft_g2lu", "folded"):
        ckpt_config = checkpoint.get("config", {})
        word_rope_dims = ckpt_config.get("word_rope_dims", 0)
        if word_rope_dims > 0:
            word_start_table_device = build_word_start_table(tokenizer, len(tokenizer)).to(device)
            print(f"Word-position RoPE: {word_rope_dims} dims")

    if args.prompt:
        prompt_ids = tokenizer.encode(args.prompt, return_tensors="pt").to(device)
    else:
        # Empty prompt: seed generation from a single EOS token.
        prompt_ids = torch.tensor([[tokenizer.eos_token_id]], device=device)

    print(f"\nPrompt: {args.prompt or '<empty>'}")
    print(f"Prompt tokens: {prompt_ids.shape[1]}")
    print(f"Generating {args.max_tokens} tokens...")
    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Top-p: {args.top_p}")
    print("-" * 50)

    with torch.no_grad():
        gen_kwargs = dict(
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            use_cache=not args.no_cache,
        )
        # Only forward the penalty when active, so generate() implementations
        # that don't accept the kwarg are not handed it by default.
        if args.repetition_penalty != 1.0:
            gen_kwargs["repetition_penalty"] = args.repetition_penalty

        if model_type == "graft_g2lu":
            # HF-style generate() defaults to greedy decoding; request sampling
            # whenever any sampling knob deviates from its neutral value.
            if args.temperature > 0 and args.temperature != 1.0:
                gen_kwargs["do_sample"] = True
            elif args.top_p < 1.0 or args.top_k > 0:
                gen_kwargs["do_sample"] = True

        if word_start_table_device is not None:
            gen_kwargs["word_start_table"] = word_start_table_device

        output_ids = model.generate(prompt_ids, **gen_kwargs)

    # output_ids includes the prompt; decode everything as one string.
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(generated_text)
    print("-" * 50)
    print(f"Total tokens: {output_ids.shape[1]}")
| |
|
| |
|
# Script entry point.  NOTE(review): the relative imports above suggest this
# is normally run as a module (python -m ...) — confirm against the Usage note.
if __name__ == "__main__":
    generate()
|
| |
|