# Provenance (Hugging Face upload metadata, kept as comments so the file parses):
# dkumar15 — Upload training_code/model/config.py with huggingface_hub
# revision 1dd2382 (verified)
"""
Configuration for 1B parameter LLaMA-style Transformer model.
Architecture: Decoder-only Transformer with RoPE, GQA, SwiGLU, RMSNorm.
"""
from dataclasses import dataclass
@dataclass
class ModelConfig:
    """Hyperparameters for a ~1B-param decoder-only LLaMA-style Transformer."""
    vocab_size: int = 32000
    hidden_dim: int = 2048
    intermediate_dim: int = 5504  # ~2.7x hidden for SwiGLU (adjusted for param count)
    num_layers: int = 22
    num_attention_heads: int = 32
    num_kv_heads: int = 8  # GQA: 4 query heads per KV head
    max_seq_len: int = 2048
    rope_theta: float = 10000.0
    rms_norm_eps: float = 1e-5
    dropout: float = 0.0  # No dropout (modern practice for pretraining)
    tie_word_embeddings: bool = False

    @property
    def head_dim(self) -> int:
        """Per-head width: hidden dim split evenly across the query heads."""
        return self.hidden_dim // self.num_attention_heads

    @property
    def num_params_approx(self) -> int:
        """Rough parameter count: embeddings + per-layer weights + LM head."""
        d = self.hidden_dim
        hd = self.head_dim
        # Q and O projections span all query heads; K and V cover only the
        # (smaller) set of KV heads under grouped-query attention.
        q_and_o = 2 * d * hd * self.num_attention_heads
        k_and_v = 2 * d * hd * self.num_kv_heads
        attn_per_layer = q_and_o + k_and_v
        # SwiGLU FFN carries three weight matrices: gate, up, and down.
        ffn_per_layer = 3 * d * self.intermediate_dim
        # Two RMSNorm weight vectors per layer (pre-attention, pre-FFN).
        norm_per_layer = 2 * d
        per_layer = attn_per_layer + ffn_per_layer + norm_per_layer
        embed = self.vocab_size * d
        # Output projection is a separate matrix unless tied to the embeddings.
        lm_head = 0 if self.tie_word_embeddings else self.vocab_size * d
        # Trailing `d` is the final RMSNorm before the head.
        return embed + self.num_layers * per_layer + d + lm_head
@dataclass
class TrainConfig:
    """Training-run settings: storage paths, batch geometry, the WSD learning-rate
    schedule, optimizer hyperparameters, and logging/system knobs."""
    # Paths — checkpoints and cached data live on the shared /jfs mount;
    # logs stay on the local home volume.
    checkpoint_dir: str = "/jfs/deepak-kumar/checkpoints"
    data_cache_dir: str = "/jfs/deepak-kumar/data"
    log_dir: str = "/home/jovyan/training/logs"
    # Training
    total_tokens: int = 20_000_000_000  # 20B tokens
    batch_size_per_gpu: int = 8
    gradient_accumulation_steps: int = 8  # effective batch = 8 * 8 * 8 = 512 seqs (implies 8 data-parallel ranks — confirm against launcher)
    max_seq_len: int = 2048
    # WSD Schedule (warmup-stable-decay): peak LR, decay floor, warmup length.
    learning_rate: float = 3e-4
    min_lr: float = 3e-5
    warmup_steps: int = 1000
    weight_decay: float = 0.1
    beta1: float = 0.9  # presumably AdamW betas — verify against optimizer construction
    beta2: float = 0.95
    grad_clip: float = 1.0  # NOTE(review): looks like a global-norm clip threshold — confirm in train loop
    # Logging — intervals presumably counted in optimizer steps; verify in trainer
    log_interval: int = 10
    save_interval: int = 1000
    eval_interval: int = 500
    # System
    num_workers: int = 4  # presumably dataloader worker processes — verify
    seed: int = 42
    bf16: bool = True  # presumably enables bfloat16 mixed precision — verify