""" Configuration for Circuit Transformer experiments. """ from dataclasses import dataclass, field import argparse @dataclass class CircuitConfig: """Configuration for CircuitTransformer model and training.""" # Model architecture vocab_size: int = 50257 # GPT-2 tokenizer hidden_size: int = 256 num_heads: int = 8 num_kv_heads: int | None = None # GQA: None = same as num_heads (MHA) num_layers: int = 6 max_seq_len: int = 512 dropout: float = 0.0 # Training batch_size: int = 32 learning_rate: float = 3e-4 min_lr: float = 0.0 weight_decay: float = 0.1 warmup_steps: int = 100 epochs: int = 10 grad_clip: float = 1.0 reset: bool = False # Hardware gpu: int = 0 fp16: bool = True bf16: bool = False compile: bool = False # Logging log_every: int = 50 save_every: int = 5000 checkpoint_dir: str = "./circuits/checkpoints" def __post_init__(self): assert self.hidden_size % self.num_heads == 0, \ f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})" if self.num_kv_heads is not None: assert self.num_heads % self.num_kv_heads == 0, \ f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})" # Presets @classmethod def tiny(cls) -> "CircuitConfig": """~2M params""" return cls(hidden_size=128, num_heads=4, num_layers=4) @classmethod def small(cls) -> "CircuitConfig": """~10M params""" return cls(hidden_size=256, num_heads=8, num_layers=6) @classmethod def medium(cls) -> "CircuitConfig": """~50M params""" return cls(hidden_size=512, num_heads=8, num_layers=12) @classmethod def medium_plus(cls) -> "CircuitConfig": """~50M params""" return cls(hidden_size=512, num_heads=8, num_layers=15) @classmethod def medium_wide_9(cls) -> "CircuitConfig": """~50M params""" return cls(hidden_size=640, num_heads=10, num_layers=9) @classmethod def medium_wide_11(cls) -> "CircuitConfig": """~50M params""" return cls(hidden_size=640, num_heads=10, num_layers=11) @classmethod def medium_large(cls) -> "CircuitConfig": """~90M params""" return cls(hidden_size=768, num_heads=12, num_layers=12) @classmethod def large(cls) -> "CircuitConfig": return cls(hidden_size=1280, num_heads=20, num_layers=11) # Auxiliary objectives aux_skip_k: int = 0 # skip-ahead prediction distance (0 = disabled) aux_skip_weight: float = 0.1 # weight for auxiliary skip loss # Word-position RoPE (SemRoPE) word_rope_dims: int = 0 # head dims for word-position RoPE (0 = disabled) word_rope_base: float = 10.0 # frequency base for word-position RoPE # Factorized embedding / MLP head embed_dim: int = 0 # factorized embedding dim (0 = use hidden_size) head_dim: int = 0 # MLP head intermediate dim (0 = linear head) def to_dict(self) -> dict: """Convert to dictionary for serialization.""" d = { "vocab_size": self.vocab_size, "hidden_size": self.hidden_size, "num_heads": self.num_heads, "num_layers": self.num_layers, "max_seq_len": self.max_seq_len, "dropout": self.dropout, } if self.num_kv_heads is not None: d["num_kv_heads"] = self.num_kv_heads if self.aux_skip_k > 0: d["aux_skip_k"] = self.aux_skip_k d["aux_skip_weight"] = self.aux_skip_weight if self.word_rope_dims > 0: d["word_rope_dims"] = self.word_rope_dims d["word_rope_base"] = self.word_rope_base if self.embed_dim > 0: d["embed_dim"] = self.embed_dim if self.head_dim > 0: d["head_dim"] = self.head_dim return d @classmethod def from_dict(cls, d: dict) -> "CircuitConfig": """Create from dictionary.""" return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__}) def parse_args() -> tuple[CircuitConfig, argparse.Namespace]: """Parse CLI 


def parse_args() -> tuple[CircuitConfig, argparse.Namespace]:
    """Parse CLI arguments and return config + extra args."""
    parser = argparse.ArgumentParser(description="Circuit Transformer Training")

    # Data
    parser.add_argument("--data", type=str, required=True,
                        help="Data source: path/to/file.txt, path/to/dir/, or hf:dataset_name")
    parser.add_argument("--text-column", type=str, default="text",
                        help="Column name for HF datasets (default: text)")
    parser.add_argument("--data-format", type=str, choices=["text", "chat"], default="text",
                        help="Data format: text (single column) or chat (system + conversations)")
    parser.add_argument("--num-samples", type=int, default=None,
                        help="Limit samples from HF dataset")
    parser.add_argument("--cache-dir", type=str, default="./circuits/.cache",
                        help="Cache directory for tokenized data")
    parser.add_argument("--no-cache", action="store_true",
                        help="Disable data caching")
    parser.add_argument("--val-split", type=float, default=0.05,
                        help="Fraction of data for validation (default: 0.05, 0 to disable)")

    # Model architecture
    # TODO: Remove `slot_mirrored`
    parser.add_argument("--arch", type=str, choices=["standard", "mirrored", "graft_g2lu"],
                        default="standard", help="Model architecture (default: standard)")
    parser.add_argument("--preset", type=str,
                        choices=["tiny", "small", "medium", "medium_plus", "medium_large",
                                 "medium_wide_9", "medium_wide_11", "large"],
                        help="Use preset configuration")
    parser.add_argument("--dims", type=int, default=None, help="Hidden size")
    parser.add_argument("--layers", type=int, default=None, help="Number of layers")
    parser.add_argument("--heads", type=int, default=None, help="Number of attention heads")
    parser.add_argument("--kv-heads", type=int, default=None,
                        help="Number of KV heads for GQA (default: same as --heads for MHA)")
    parser.add_argument("--context-length", type=int, default=None, help="Max sequence length")
    parser.add_argument("--dropout", type=float, default=None, help="Dropout rate")
    parser.add_argument("--tokenizer", type=str, default="gpt2",
                        help="Tokenizer to use (default: gpt2, e.g. facebook/MobileLLM-125M)")

    # Mirrored architecture specific
    parser.add_argument("--n-middle", type=int, default=2,
                        help="Unique middle layers for mirrored arch (default: 2)")
    parser.add_argument("--share-attention", action="store_true", default=True,
                        help="Share attention weights between mirror pairs (default)")
    parser.add_argument("--no-share-attention", dest="share_attention", action="store_false",
                        help="Separate attention weights per direction")

    # G²LU gating
    parser.add_argument("--no-g2lu", action="store_true",
                        help="Disable G²LU (use vanilla SwiGLU in mirrored arch)")

    # Auxiliary objectives
    parser.add_argument("--aux-skip", type=int, default=0,
                        help="Skip-ahead prediction distance (0 = disabled, e.g. 5 predicts t+5)")
    parser.add_argument("--aux-weight", type=float, default=0.1,
                        help="Weight for auxiliary skip loss (default: 0.1)")

    # Word-position RoPE (SemRoPE)
    parser.add_argument("--word-rope-dims", type=int, default=0,
                        help="Head dims dedicated to word-position RoPE (0=disabled, try 8 or 16)")
    parser.add_argument("--word-rope-base", type=float, default=10.0,
                        help="Frequency base for word-position RoPE (default: 10.0)")

    # Factorized embedding / MLP head
    parser.add_argument("--embed-dim", type=int, default=0,
                        help="Factorized embedding dim (0=use hidden_size, e.g. 256)")
    parser.add_argument("--head-dim", type=int, default=0,
                        help="MLP head intermediate dim (0=linear head, e.g. 512)")
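
    # Example invocations (illustrative; "train.py" is an assumed entry-point
    # name, not taken from this file):
    #   python train.py --data data.txt --preset small --bf16
    #   python train.py --data hf:wikitext --preset medium --kv-heads 2 --compile --grad-accum 4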
    # G²LU gate grafting
    parser.add_argument("--pretrained", type=str, default=None,
                        help="HuggingFace model for graft_g2lu (e.g. meta-llama/Llama-3.2-1B)")
    parser.add_argument("--align-weight", type=float, default=1.0,
                        help="Alignment loss weight for G²LU grafting (default: 1.0)")
    parser.add_argument("--graft-warmup", type=int, default=500,
                        help="Blend warmup steps: SwiGLU→G²LU transition (default: 500)")

    # Training
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None, help="Learning rate")
    parser.add_argument("--min-lr", type=float, default=None,
                        help="Minimum learning rate for cosine decay (default: 0)")
    parser.add_argument("--weight-decay", type=float, default=None)
    parser.add_argument("--warmup-steps", type=int, default=None)
    parser.add_argument("--grad-clip", type=float, default=None)
    parser.add_argument("--grad-accum", type=int, default=1,
                        help="Gradient accumulation steps (effective batch = batch_size * grad_accum)")

    # Hardware
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--fp16", action="store_true",
                        help="Use FP16 mixed precision (with GradScaler)")
    parser.add_argument("--bf16", action="store_true",
                        help="Use BF16 mixed precision (no scaler needed)")
    parser.add_argument("--no-fp16", action="store_true",
                        help="Disable mixed precision (FP32)")
    parser.add_argument("--compile", action="store_true", help="Use torch.compile")

    # Logging/Checkpointing
    parser.add_argument("--log-every", type=int, default=None)
    parser.add_argument("--save-every", type=int, default=None)
    parser.add_argument("--val-every", type=int, default=0,
                        help="Run validation every N steps (0 = only at epoch end)")
    parser.add_argument("--checkpoint-dir", type=str, default=None)
    parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--reset", action="store_true", default=False,
                        help="When resuming, reset the step counter and optimizer state")

    args = parser.parse_args()

    # Build config from preset or defaults
    if args.preset:
        config = getattr(CircuitConfig, args.preset)()
    else:
        config = CircuitConfig()

    # Override with explicit args
    if args.dims is not None: config.hidden_size = args.dims
    if args.layers is not None: config.num_layers = args.layers
    if args.heads is not None: config.num_heads = args.heads
    if args.kv_heads is not None: config.num_kv_heads = args.kv_heads
    if args.context_length is not None: config.max_seq_len = args.context_length
    if args.dropout is not None: config.dropout = args.dropout
    if args.epochs is not None: config.epochs = args.epochs
    if args.batch_size is not None: config.batch_size = args.batch_size
    if args.lr is not None: config.learning_rate = args.lr
    if args.min_lr is not None: config.min_lr = args.min_lr
    if args.weight_decay is not None: config.weight_decay = args.weight_decay
    if args.warmup_steps is not None: config.warmup_steps = args.warmup_steps
    if args.grad_clip is not None: config.grad_clip = args.grad_clip
    if args.log_every is not None: config.log_every = args.log_every
    if args.save_every is not None: config.save_every = args.save_every
    if args.checkpoint_dir is not None: config.checkpoint_dir = args.checkpoint_dir

    # Auxiliary objectives
    if args.aux_skip > 0:
        config.aux_skip_k = args.aux_skip
        config.aux_skip_weight = args.aux_weight

    # Word-position RoPE
    if args.word_rope_dims > 0:
        config.word_rope_dims = args.word_rope_dims
        config.word_rope_base = args.word_rope_base

    # Factorized embedding / MLP head
    if args.embed_dim > 0:
        config.embed_dim = args.embed_dim
    if args.head_dim > 0:
        config.head_dim = args.head_dim

    config.gpu = args.gpu
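
    # Safeguard (added sketch): assigning to dataclass fields after
    # construction bypasses __post_init__, so re-run the divisibility
    # checks now that the CLI overrides have been applied.
    config.__post_init__()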

    # Precision flags: --bf16 wins over --no-fp16, which wins over --fp16;
    # with no flag given, the dataclass default (fp16=True) is kept.
    if args.bf16:
        config.bf16 = True
        config.fp16 = False
    elif args.no_fp16:
        config.fp16 = False
        config.bf16 = False
    elif args.fp16:
        config.fp16 = True
        config.bf16 = False
    config.compile = args.compile
    config.reset = args.reset

    return config, args
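

if __name__ == "__main__":
    # Illustrative smoke test; the original entry point is the training
    # script, so running this module directly is only a sanity check.
    cfg = CircuitConfig.small()
    print(f"small preset: ~{estimate_params(cfg) / 1e6:.1f}M params (rough estimate)")
    restored = CircuitConfig.from_dict(cfg.to_dict())
    assert restored.hidden_size == cfg.hidden_size and restored.num_layers == cfg.num_layers
    print("to_dict/from_dict round-trip OK:", restored.to_dict())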