"""
Configuration for Circuit Transformer experiments.
"""
import argparse
from dataclasses import dataclass


@dataclass
class CircuitConfig:
"""Configuration for CircuitTransformer model and training."""
# Model architecture
vocab_size: int = 50257 # GPT-2 tokenizer
hidden_size: int = 256
num_heads: int = 8
num_kv_heads: int | None = None # GQA: None = same as num_heads (MHA)
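    # With GQA, query heads are split into groups that each share one KV head,
    # e.g. num_heads=8 with num_kv_heads=2 gives 4 query heads per KV head.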
num_layers: int = 6
max_seq_len: int = 512
dropout: float = 0.0
# Training
batch_size: int = 32
learning_rate: float = 3e-4
min_lr: float = 0.0
weight_decay: float = 0.1
warmup_steps: int = 100
epochs: int = 10
grad_clip: float = 1.0
reset: bool = False
# Hardware
gpu: int = 0
fp16: bool = True
bf16: bool = False
compile: bool = False
# Logging
log_every: int = 50
save_every: int = 5000
checkpoint_dir: str = "./circuits/checkpoints"
def __post_init__(self):
assert self.hidden_size % self.num_heads == 0, \
f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"
if self.num_kv_heads is not None:
assert self.num_heads % self.num_kv_heads == 0, \
f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
# Presets
@classmethod
def tiny(cls) -> "CircuitConfig":
"""~2M params"""
return cls(hidden_size=128, num_heads=4, num_layers=4)
@classmethod
def small(cls) -> "CircuitConfig":
"""~10M params"""
return cls(hidden_size=256, num_heads=8, num_layers=6)
@classmethod
def medium(cls) -> "CircuitConfig":
"""~50M params"""
return cls(hidden_size=512, num_heads=8, num_layers=12)
@classmethod
def medium_plus(cls) -> "CircuitConfig":
"""~50M params"""
return cls(hidden_size=512, num_heads=8, num_layers=15)
@classmethod
def medium_wide_9(cls) -> "CircuitConfig":
"""~50M params"""
return cls(hidden_size=640, num_heads=10, num_layers=9)
@classmethod
def medium_wide_11(cls) -> "CircuitConfig":
"""~50M params"""
return cls(hidden_size=640, num_heads=10, num_layers=11)
@classmethod
def medium_large(cls) -> "CircuitConfig":
"""~90M params"""
return cls(hidden_size=768, num_heads=12, num_layers=12)
@classmethod
def large(cls) -> "CircuitConfig":
return cls(hidden_size=1280, num_heads=20, num_layers=11)
# Auxiliary objectives
aux_skip_k: int = 0 # skip-ahead prediction distance (0 = disabled)
aux_skip_weight: float = 0.1 # weight for auxiliary skip loss
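    # e.g. aux_skip_k=5 adds a secondary objective predicting the token at t+5;
    # its loss is presumably added to the next-token loss, scaled by aux_skip_weight.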
# Word-position RoPE (SemRoPE)
word_rope_dims: int = 0 # head dims for word-position RoPE (0 = disabled)
word_rope_base: float = 10.0 # frequency base for word-position RoPE
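    # e.g. word_rope_dims=8 reserves 8 dims per head for a rotary encoding driven
    # by word position rather than token position (see the model code for wiring).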
# Factorized embedding / MLP head
embed_dim: int = 0 # factorized embedding dim (0 = use hidden_size)
head_dim: int = 0 # MLP head intermediate dim (0 = linear head)
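    # embed_dim > 0 presumably factorizes the token embedding into a
    # (vocab_size x embed_dim) table plus a projection to hidden_size (ALBERT-style);
    # head_dim > 0 replaces the linear LM head with a small MLP of that width.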
def to_dict(self) -> dict:
"""Convert to dictionary for serialization."""
d = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"num_heads": self.num_heads,
"num_layers": self.num_layers,
"max_seq_len": self.max_seq_len,
"dropout": self.dropout,
}
if self.num_kv_heads is not None:
d["num_kv_heads"] = self.num_kv_heads
if self.aux_skip_k > 0:
d["aux_skip_k"] = self.aux_skip_k
d["aux_skip_weight"] = self.aux_skip_weight
if self.word_rope_dims > 0:
d["word_rope_dims"] = self.word_rope_dims
d["word_rope_base"] = self.word_rope_base
if self.embed_dim > 0:
d["embed_dim"] = self.embed_dim
if self.head_dim > 0:
d["head_dim"] = self.head_dim
return d
@classmethod
def from_dict(cls, d: dict) -> "CircuitConfig":
"""Create from dictionary."""
return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
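

# Illustrative round trip (a sketch, not part of the training pipeline):
# to_dict() serializes only the architectural fields, so training
# hyperparameters fall back to their defaults when a config is restored.
#
#   cfg = CircuitConfig.small()
#   restored = CircuitConfig.from_dict(cfg.to_dict())
#   assert restored.hidden_size == cfg.hidden_size == 256

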
def parse_args() -> tuple[CircuitConfig, argparse.Namespace]:
"""Parse CLI arguments and return config + extra args."""
parser = argparse.ArgumentParser(description="Circuit Transformer Training")
# Data
parser.add_argument("--data", type=str, required=True,
help="Data source: path/to/file.txt, path/to/dir/, or hf:dataset_name")
parser.add_argument("--text-column", type=str, default="text",
help="Column name for HF datasets (default: text)")
parser.add_argument("--data-format", type=str, choices=["text", "chat"], default="text",
help="Data format: text (single column) or chat (system + conversations)")
parser.add_argument("--num-samples", type=int, default=None,
help="Limit samples from HF dataset")
parser.add_argument("--cache-dir", type=str, default="./circuits/.cache",
help="Cache directory for tokenized data")
parser.add_argument("--no-cache", action="store_true",
help="Disable data caching")
parser.add_argument("--val-split", type=float, default=0.05,
help="Fraction of data for validation (default: 0.05, 0 to disable)")
# Model architecture
# TODO: Remove `slot_mirrored`
parser.add_argument("--arch", type=str, choices=["standard", "mirrored", "graft_g2lu"], default="standard",
help="Model architecture (default: standard)")
parser.add_argument("--preset", type=str, choices=["tiny", "small", "medium", "medium_plus", "medium_large", "medium_wide_9", "medium_wide_11", "large"],
help="Use preset configuration")
parser.add_argument("--dims", type=int, default=None, help="Hidden size")
parser.add_argument("--layers", type=int, default=None, help="Number of layers")
parser.add_argument("--heads", type=int, default=None, help="Number of attention heads")
parser.add_argument("--kv-heads", type=int, default=None,
help="Number of KV heads for GQA (default: same as --heads for MHA)")
parser.add_argument("--context-length", type=int, default=None, help="Max sequence length")
parser.add_argument("--dropout", type=float, default=None, help="Dropout rate")
parser.add_argument("--tokenizer", type=str, default="gpt2",
help="Tokenizer to use (default: gpt2, e.g. facebook/MobileLLM-125M)")
# Mirrored architecture specific
parser.add_argument("--n-middle", type=int, default=2,
help="Unique middle layers for mirrored arch (default: 2)")
parser.add_argument("--share-attention", action="store_true", default=True,
help="Share attention weights between mirror pairs (default)")
parser.add_argument("--no-share-attention", dest="share_attention", action="store_false",
help="Separate attention weights per direction")
# G²LU gating
parser.add_argument("--no-g2lu", action="store_true",
help="Disable G²LU (use vanilla SwiGLU in mirrored arch)")
# Auxiliary objectives
parser.add_argument("--aux-skip", type=int, default=0,
help="Skip-ahead prediction distance (0 = disabled, e.g. 5 predicts t+5)")
parser.add_argument("--aux-weight", type=float, default=0.1,
help="Weight for auxiliary skip loss (default: 0.1)")
# Word-position RoPE (SemRoPE)
parser.add_argument("--word-rope-dims", type=int, default=0,
help="Head dims dedicated to word-position RoPE (0=disabled, try 8 or 16)")
parser.add_argument("--word-rope-base", type=float, default=10.0,
help="Frequency base for word-position RoPE (default: 10.0)")
# Factorized embedding / MLP head
parser.add_argument("--embed-dim", type=int, default=0,
help="Factorized embedding dim (0=use hidden_size, e.g. 256)")
parser.add_argument("--head-dim", type=int, default=0,
help="MLP head intermediate dim (0=linear head, e.g. 512)")
# G²LU gate grafting
parser.add_argument("--pretrained", type=str, default=None,
help="HuggingFace model for graft_g2lu (e.g. meta-llama/Llama-3.2-1B)")
parser.add_argument("--align-weight", type=float, default=1.0,
help="Alignment loss weight for G²LU grafting (default: 1.0)")
parser.add_argument("--graft-warmup", type=int, default=500,
help="Blend warmup steps: SwiGLU→G²LU transition (default: 500)")
# Training
parser.add_argument("--epochs", type=int, default=None)
parser.add_argument("--batch-size", type=int, default=None)
parser.add_argument("--lr", type=float, default=None, help="Learning rate")
parser.add_argument("--min-lr", type=float, default=None,
help="Minimum learning rate for cosine decay (default: 0)")
parser.add_argument("--weight-decay", type=float, default=None)
parser.add_argument("--warmup-steps", type=int, default=None)
parser.add_argument("--grad-clip", type=float, default=None)
parser.add_argument("--grad-accum", type=int, default=1,
help="Gradient accumulation steps (effective batch = batch_size * grad_accum)")
# Hardware
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--fp16", action="store_true", help="Use FP16 mixed precision (with GradScaler)")
parser.add_argument("--bf16", action="store_true", help="Use BF16 mixed precision (no scaler needed)")
parser.add_argument("--no-fp16", action="store_true", help="Disable mixed precision (FP32)")
parser.add_argument("--compile", action="store_true", help="Use torch.compile")
# Logging/Checkpointing
parser.add_argument("--log-every", type=int, default=None)
parser.add_argument("--save-every", type=int, default=None)
parser.add_argument("--val-every", type=int, default=0,
help="Run validation every N steps (0 = only at epoch end)")
parser.add_argument("--checkpoint-dir", type=str, default=None)
parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
parser.add_argument("--reset", action="store_true", default=False, help="When resuming the training resets steps and optimizers")
args = parser.parse_args()
# Build config from preset or defaults
if args.preset:
config = getattr(CircuitConfig, args.preset)()
else:
config = CircuitConfig()
# Override with explicit args
if args.dims is not None:
config.hidden_size = args.dims
if args.layers is not None:
config.num_layers = args.layers
if args.heads is not None:
config.num_heads = args.heads
if args.kv_heads is not None:
config.num_kv_heads = args.kv_heads
if args.context_length is not None:
config.max_seq_len = args.context_length
if args.dropout is not None:
config.dropout = args.dropout
if args.epochs is not None:
config.epochs = args.epochs
if args.batch_size is not None:
config.batch_size = args.batch_size
if args.lr is not None:
config.learning_rate = args.lr
if args.min_lr is not None:
config.min_lr = args.min_lr
if args.weight_decay is not None:
config.weight_decay = args.weight_decay
if args.warmup_steps is not None:
config.warmup_steps = args.warmup_steps
if args.grad_clip is not None:
config.grad_clip = args.grad_clip
if args.log_every is not None:
config.log_every = args.log_every
if args.save_every is not None:
config.save_every = args.save_every
if args.checkpoint_dir is not None:
config.checkpoint_dir = args.checkpoint_dir
# Auxiliary objectives
if args.aux_skip > 0:
config.aux_skip_k = args.aux_skip
config.aux_skip_weight = args.aux_weight
# Word-position RoPE
if args.word_rope_dims > 0:
config.word_rope_dims = args.word_rope_dims
config.word_rope_base = args.word_rope_base
# Factorized embedding / MLP head
if args.embed_dim > 0:
config.embed_dim = args.embed_dim
if args.head_dim > 0:
config.head_dim = args.head_dim
config.gpu = args.gpu
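    # Precision resolution: --bf16 takes priority, --no-fp16 forces full FP32,
    # and with no flag the dataclass default (fp16=True) is kept.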
if args.bf16:
config.bf16 = True
config.fp16 = False
elif args.no_fp16:
config.fp16 = False
config.bf16 = False
elif args.fp16:
config.fp16 = True
config.bf16 = False
config.compile = args.compile
config.reset = args.reset
return config, args
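

# Minimal smoke-test entry point (an illustrative sketch; the real entry point
# is presumably the training script that imports this module):
if __name__ == "__main__":
    cfg, cli = parse_args()
    print("config:", cfg.to_dict())
    print("data source:", cli.data)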