| """ |
| Configuration for Circuit Transformer experiments. |
| """ |
|
|
| from dataclasses import dataclass, field |
| import argparse |
|
|
|
|
@dataclass
class CircuitConfig:
    """Configuration for CircuitTransformer model and training."""

    # Model architecture
    vocab_size: int = 50257
    hidden_size: int = 256
    num_heads: int = 8
    num_kv_heads: int | None = None  # None = standard MHA; set lower than num_heads for GQA
    num_layers: int = 6
    max_seq_len: int = 512
    dropout: float = 0.0

    # Training
    batch_size: int = 32
    learning_rate: float = 3e-4
    min_lr: float = 0.0
    weight_decay: float = 0.1
    warmup_steps: int = 100
    epochs: int = 10
    grad_clip: float = 1.0
    reset: bool = False  # when resuming, reset step counter and optimizer state

    # Device / precision
    gpu: int = 0
    fp16: bool = True
    bf16: bool = False
    compile: bool = False

    # Logging & checkpointing
    log_every: int = 50
    save_every: int = 5000
    checkpoint_dir: str = "./circuits/checkpoints"

    # Auxiliary skip-ahead prediction loss (0 = disabled)
    aux_skip_k: int = 0
    aux_skip_weight: float = 0.1

    # Word-position RoPE (0 = disabled)
    word_rope_dims: int = 0
    word_rope_base: float = 10.0

    # Factorized embedding / MLP output head (0 = disabled)
    embed_dim: int = 0
    head_dim: int = 0

    def __post_init__(self):
        assert self.hidden_size % self.num_heads == 0, \
            f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"
        if self.num_kv_heads is not None:
            assert self.num_heads % self.num_kv_heads == 0, \
                f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
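
    # For example (illustrative values, not defaults): 8 query heads sharing
    # 2 KV heads passes validation, while 3 KV heads would not divide evenly:
    #     CircuitConfig(num_heads=8, num_kv_heads=2)  # OK: 4 query heads per KV head
    #     CircuitConfig(num_heads=8, num_kv_heads=3)  # AssertionError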

    # Size presets
    @classmethod
    def tiny(cls) -> "CircuitConfig":
        """~2M params"""
        return cls(hidden_size=128, num_heads=4, num_layers=4)

    @classmethod
    def small(cls) -> "CircuitConfig":
        """~10M params"""
        return cls(hidden_size=256, num_heads=8, num_layers=6)

    @classmethod
    def medium(cls) -> "CircuitConfig":
        """~50M params"""
        return cls(hidden_size=512, num_heads=8, num_layers=12)

    @classmethod
    def medium_plus(cls) -> "CircuitConfig":
        """~50M params"""
        return cls(hidden_size=512, num_heads=8, num_layers=15)

    @classmethod
    def medium_wide_9(cls) -> "CircuitConfig":
        """~50M params"""
        return cls(hidden_size=640, num_heads=10, num_layers=9)

    @classmethod
    def medium_wide_11(cls) -> "CircuitConfig":
        """~50M params"""
        return cls(hidden_size=640, num_heads=10, num_layers=11)

    @classmethod
    def medium_large(cls) -> "CircuitConfig":
        """~90M params"""
        return cls(hidden_size=768, num_heads=12, num_layers=12)

    @classmethod
    def large(cls) -> "CircuitConfig":
        return cls(hidden_size=1280, num_heads=20, num_layers=11)
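
    # Preset usage, a minimal sketch; the dynamic lookup mirrors how
    # parse_args() resolves --preset:
    #     cfg = CircuitConfig.tiny()
    #     cfg = getattr(CircuitConfig, "medium_wide_9")()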

    def to_dict(self) -> dict:
        """Convert model fields to a dictionary for serialization."""
        d = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "max_seq_len": self.max_seq_len,
            "dropout": self.dropout,
        }
        if self.num_kv_heads is not None:
            d["num_kv_heads"] = self.num_kv_heads
        if self.aux_skip_k > 0:
            d["aux_skip_k"] = self.aux_skip_k
            d["aux_skip_weight"] = self.aux_skip_weight
        if self.word_rope_dims > 0:
            d["word_rope_dims"] = self.word_rope_dims
            d["word_rope_base"] = self.word_rope_base
        if self.embed_dim > 0:
            d["embed_dim"] = self.embed_dim
        if self.head_dim > 0:
            d["head_dim"] = self.head_dim
        return d

    @classmethod
    def from_dict(cls, d: dict) -> "CircuitConfig":
        """Create from dictionary, ignoring unknown keys."""
        return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
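
    # Serialization round-trip, a minimal sketch. Only architecture fields
    # survive, since to_dict() omits the training/device settings:
    #     cfg = CircuitConfig.small()
    #     restored = CircuitConfig.from_dict(cfg.to_dict())
    #     assert restored.hidden_size == cfg.hidden_size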


def parse_args() -> tuple[CircuitConfig, argparse.Namespace]:
    """Parse CLI arguments and return the config plus the full argparse namespace."""
    parser = argparse.ArgumentParser(description="Circuit Transformer Training")
| parser.add_argument("--data", type=str, required=True, |
| help="Data source: path/to/file.txt, path/to/dir/, or hf:dataset_name") |
| parser.add_argument("--text-column", type=str, default="text", |
| help="Column name for HF datasets (default: text)") |
| parser.add_argument("--data-format", type=str, choices=["text", "chat"], default="text", |
| help="Data format: text (single column) or chat (system + conversations)") |
| parser.add_argument("--num-samples", type=int, default=None, |
| help="Limit samples from HF dataset") |
| parser.add_argument("--cache-dir", type=str, default="./circuits/.cache", |
| help="Cache directory for tokenized data") |
| parser.add_argument("--no-cache", action="store_true", |
| help="Disable data caching") |
| parser.add_argument("--val-split", type=float, default=0.05, |
| help="Fraction of data for validation (default: 0.05, 0 to disable)") |
|
|
| |
| |
| parser.add_argument("--arch", type=str, choices=["standard", "mirrored", "graft_g2lu"], default="standard", |
| help="Model architecture (default: standard)") |
| parser.add_argument("--preset", type=str, choices=["tiny", "small", "medium", "medium_plus", "medium_large", "medium_wide_9", "medium_wide_11", "large"], |
| help="Use preset configuration") |
| parser.add_argument("--dims", type=int, default=None, help="Hidden size") |
| parser.add_argument("--layers", type=int, default=None, help="Number of layers") |
| parser.add_argument("--heads", type=int, default=None, help="Number of attention heads") |
| parser.add_argument("--kv-heads", type=int, default=None, |
| help="Number of KV heads for GQA (default: same as --heads for MHA)") |
| parser.add_argument("--context-length", type=int, default=None, help="Max sequence length") |
| parser.add_argument("--dropout", type=float, default=None, help="Dropout rate") |
| parser.add_argument("--tokenizer", type=str, default="gpt2", |
| help="Tokenizer to use (default: gpt2, e.g. facebook/MobileLLM-125M)") |
|
|
| |
| parser.add_argument("--n-middle", type=int, default=2, |
| help="Unique middle layers for mirrored arch (default: 2)") |
| parser.add_argument("--share-attention", action="store_true", default=True, |
| help="Share attention weights between mirror pairs (default)") |
| parser.add_argument("--no-share-attention", dest="share_attention", action="store_false", |
| help="Separate attention weights per direction") |
| |
| |
| parser.add_argument("--no-g2lu", action="store_true", |
| help="Disable G²LU (use vanilla SwiGLU in mirrored arch)") |
|
|
| |
| parser.add_argument("--aux-skip", type=int, default=0, |
| help="Skip-ahead prediction distance (0 = disabled, e.g. 5 predicts t+5)") |
| parser.add_argument("--aux-weight", type=float, default=0.1, |
| help="Weight for auxiliary skip loss (default: 0.1)") |
|
|
| |
| parser.add_argument("--word-rope-dims", type=int, default=0, |
| help="Head dims dedicated to word-position RoPE (0=disabled, try 8 or 16)") |
| parser.add_argument("--word-rope-base", type=float, default=10.0, |
| help="Frequency base for word-position RoPE (default: 10.0)") |
|
|
| |
| parser.add_argument("--embed-dim", type=int, default=0, |
| help="Factorized embedding dim (0=use hidden_size, e.g. 256)") |
| parser.add_argument("--head-dim", type=int, default=0, |
| help="MLP head intermediate dim (0=linear head, e.g. 512)") |
|
|
| |
| parser.add_argument("--pretrained", type=str, default=None, |
| help="HuggingFace model for graft_g2lu (e.g. meta-llama/Llama-3.2-1B)") |
| parser.add_argument("--align-weight", type=float, default=1.0, |
| help="Alignment loss weight for G²LU grafting (default: 1.0)") |
| parser.add_argument("--graft-warmup", type=int, default=500, |
| help="Blend warmup steps: SwiGLU→G²LU transition (default: 500)") |
|
|
| |
| parser.add_argument("--epochs", type=int, default=None) |
| parser.add_argument("--batch-size", type=int, default=None) |
| parser.add_argument("--lr", type=float, default=None, help="Learning rate") |
| parser.add_argument("--min-lr", type=float, default=None, |
| help="Minimum learning rate for cosine decay (default: 0)") |
| parser.add_argument("--weight-decay", type=float, default=None) |
| parser.add_argument("--warmup-steps", type=int, default=None) |
| parser.add_argument("--grad-clip", type=float, default=None) |
| parser.add_argument("--grad-accum", type=int, default=1, |
| help="Gradient accumulation steps (effective batch = batch_size * grad_accum)") |
|
|
| |
| parser.add_argument("--gpu", type=int, default=0) |
| parser.add_argument("--fp16", action="store_true", help="Use FP16 mixed precision (with GradScaler)") |
| parser.add_argument("--bf16", action="store_true", help="Use BF16 mixed precision (no scaler needed)") |
| parser.add_argument("--no-fp16", action="store_true", help="Disable mixed precision (FP32)") |
| parser.add_argument("--compile", action="store_true", help="Use torch.compile") |
|
|
| |
| parser.add_argument("--log-every", type=int, default=None) |
| parser.add_argument("--save-every", type=int, default=None) |
| parser.add_argument("--val-every", type=int, default=0, |
| help="Run validation every N steps (0 = only at epoch end)") |
| parser.add_argument("--checkpoint-dir", type=str, default=None) |
| parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint") |
| parser.add_argument("--reset", action="store_true", default=False, help="When resuming the training resets steps and optimizers") |
|
|
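
    # Example invocation, purely illustrative (script name and data path are
    # placeholders; all flags used here are defined above):
    #     python train.py --data ./corpus.txt --preset medium --bf16 --grad-accum 4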

    args = parser.parse_args()

    # Start from a preset if one was given, otherwise use defaults
    if args.preset:
        config = getattr(CircuitConfig, args.preset)()
    else:
        config = CircuitConfig()

    # Apply explicit CLI overrides
    if args.dims is not None:
        config.hidden_size = args.dims
    if args.layers is not None:
        config.num_layers = args.layers
    if args.heads is not None:
        config.num_heads = args.heads
    if args.kv_heads is not None:
        config.num_kv_heads = args.kv_heads
    if args.context_length is not None:
        config.max_seq_len = args.context_length
    if args.dropout is not None:
        config.dropout = args.dropout
    if args.epochs is not None:
        config.epochs = args.epochs
    if args.batch_size is not None:
        config.batch_size = args.batch_size
    if args.lr is not None:
        config.learning_rate = args.lr
    if args.min_lr is not None:
        config.min_lr = args.min_lr
    if args.weight_decay is not None:
        config.weight_decay = args.weight_decay
    if args.warmup_steps is not None:
        config.warmup_steps = args.warmup_steps
    if args.grad_clip is not None:
        config.grad_clip = args.grad_clip
    if args.log_every is not None:
        config.log_every = args.log_every
    if args.save_every is not None:
        config.save_every = args.save_every
    if args.checkpoint_dir is not None:
        config.checkpoint_dir = args.checkpoint_dir

    # Optional features
    if args.aux_skip > 0:
        config.aux_skip_k = args.aux_skip
        config.aux_skip_weight = args.aux_weight
    if args.word_rope_dims > 0:
        config.word_rope_dims = args.word_rope_dims
        config.word_rope_base = args.word_rope_base
    if args.embed_dim > 0:
        config.embed_dim = args.embed_dim
    if args.head_dim > 0:
        config.head_dim = args.head_dim

    # Device & precision: BF16 takes precedence over FP16; --no-fp16 forces FP32
    config.gpu = args.gpu
    if args.bf16:
        config.bf16 = True
        config.fp16 = False
    elif args.no_fp16:
        config.fp16 = False
        config.bf16 = False
    elif args.fp16:
        config.fp16 = True
        config.bf16 = False
    config.compile = args.compile
    config.reset = args.reset

    # Re-validate: __post_init__ ran before the CLI overrides were applied
    config.__post_init__()

    return config, args
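

if __name__ == "__main__":
    # Minimal smoke test, a sketch only (training presumably lives in a
    # separate script; the module name below is assumed):
    #     python config.py --data ./corpus.txt --preset small
    cfg, raw_args = parse_args()
    print(cfg.to_dict())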