docs: translate all Korean comments and docstrings to English
Convert all Korean-language comments, docstrings, and inline annotations
to English across 38 source files and CLAUDE.md. Update code conventions
to require English for all code, comments, docstrings, and git commit messages.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- CLAUDE.md +56 -56
- llm_lab/__init__.py +9 -9
- llm_lab/config/__init__.py +1 -1
- llm_lab/config/data_config.py +21 -21
- llm_lab/config/eval_config.py +5 -5
- llm_lab/config/model_config.py +19 -19
- llm_lab/config/train_config.py +49 -49
- llm_lab/data/__init__.py +1 -1
- llm_lab/data/dataset.py +60 -59
- llm_lab/data/diagnostics.py +42 -42
- llm_lab/data/pipeline.py +35 -35
- llm_lab/data/tokenizer.py +53 -53
- llm_lab/evaluation/__init__.py +1 -1
- llm_lab/evaluation/attention_viz.py +23 -23
- llm_lab/evaluation/checklist.py +36 -36
- llm_lab/evaluation/dynamics.py +38 -38
- llm_lab/evaluation/full_evaluator.py +48 -48
- llm_lab/evaluation/generation.py +49 -49
- llm_lab/evaluation/perplexity.py +34 -34
- llm_lab/evaluation/runner.py +5 -5
- llm_lab/evaluation/scaling.py +28 -28
- llm_lab/model/__init__.py +1 -1
- llm_lab/model/attention.py +27 -27
- llm_lab/model/feedforward.py +17 -17
- llm_lab/model/llm_model.py +49 -49
- llm_lab/model/norm.py +15 -15
- llm_lab/model/rope.py +31 -29
- llm_lab/model/transformer_block.py +12 -12
- llm_lab/model/utils.py +11 -11
- llm_lab/training/__init__.py +1 -1
- llm_lab/training/checkpoint.py +40 -40
- llm_lab/training/metrics.py +19 -19
- llm_lab/training/optimizer.py +15 -15
- llm_lab/training/runner.py +14 -14
- llm_lab/training/scheduler.py +18 -18
- llm_lab/training/trainer.py +68 -67
- llm_lab/utils/__init__.py +1 -1
- llm_lab/utils/device.py +20 -20
- llm_lab/utils/seed.py +2 -2
CLAUDE.md
CHANGED
|
@@ -1,78 +1,78 @@
|
|
| 1 |
# LLM-1B-Lab
|
| 2 |
|
| 3 |
-
1.1B parameter LLaMA-style Decoder-Only Transformer
|
| 4 |
-
|
| 5 |
|
| 6 |
-
##
|
| 7 |
|
| 8 |
```
|
| 9 |
LLM_Foundation_Model/
|
| 10 |
├── CLAUDE.md
|
| 11 |
├── requirements.txt
|
| 12 |
-
├── llm_lab/ # Python
|
| 13 |
│ ├── __init__.py
|
| 14 |
-
│ ├── config/ #
|
| 15 |
-
│ │ ├── model_config.py # ModelConfig (debug_10m / small_100m / base_1b
|
| 16 |
-
│ │ ├── data_config.py # DataConfig (
|
| 17 |
-
│ │ ├── train_config.py # TrainConfig (LR,
|
| 18 |
-
│ │ └── eval_config.py # EvalConfig (
|
| 19 |
-
│ ├── model/ #
|
| 20 |
│ │ ├── norm.py # RMSNorm
|
| 21 |
│ │ ├── rope.py # RotaryPositionalEmbedding (RoPE)
|
| 22 |
│ │ ├── attention.py # GroupedQueryAttention (GQA)
|
| 23 |
│ │ ├── feedforward.py # SwiGLUFeedForward
|
| 24 |
│ │ ├── transformer_block.py # TransformerBlock (Pre-LN)
|
| 25 |
-
│ │ ├── llm_model.py # LLMModel (
|
| 26 |
│ │ └── utils.py # count_parameters_detailed, estimate_memory_gb
|
| 27 |
-
│ ├── data/ #
|
| 28 |
│ │ ├── tokenizer.py # Tokenizer (SentencePiece / BPE / HuggingFace)
|
| 29 |
│ │ ├── dataset.py # PackedStreamingDataset, ValidationDataset, _collate_fn
|
| 30 |
│ │ ├── pipeline.py # create_train_dataloader, setup_data_pipeline
|
| 31 |
│ │ └── diagnostics.py # DataPipelineDiagnostics
|
| 32 |
-
│ ├── training/ #
|
| 33 |
│ │ ├── scheduler.py # CosineWarmupScheduler
|
| 34 |
-
│ │ ├── checkpoint.py # CheckpointManager (Google Drive
|
| 35 |
-
│ │ ├── metrics.py # MetricsTracker (wandb
|
| 36 |
-
│ │ ├── optimizer.py # create_optimizer (weight decay
|
| 37 |
│ │ ├── trainer.py # Trainer (gradient accumulation, mixed precision)
|
| 38 |
-
│ │ └── runner.py # start_training (
|
| 39 |
-
│ ├── evaluation/ #
|
| 40 |
-
│ │ ├── perplexity.py # PerplexityEvaluator (
|
| 41 |
-
│ │ ├── generation.py # GenerationEvaluator (
|
| 42 |
│ │ ├── scaling.py # ScalingAnalyzer (Chinchilla Scaling Law)
|
| 43 |
-
│ │ ├── dynamics.py # TrainingDynamicsAnalyzer (Loss/LR/Grad
|
| 44 |
-
│ │ ├── attention_viz.py # AttentionVisualizer (
|
| 45 |
-
│ │ ├── full_evaluator.py # FullEvaluator (
|
| 46 |
-
│ │ ├── checklist.py # InsightChecklist (
|
| 47 |
-
│ │ └── runner.py # run_evaluation (
|
| 48 |
-
│ └── utils/ #
|
| 49 |
│ ├── device.py # auto_configure, get_device, detect_gpu_info
|
| 50 |
│ └── seed.py # set_seed
|
| 51 |
-
├── notebooks/ # Jupyter
|
| 52 |
│ ├── 01_data_pipeline.ipynb
|
| 53 |
│ ├── 02_model.ipynb
|
| 54 |
│ ├── 03_training.ipynb
|
| 55 |
│ └── 04_evaluation.ipynb
|
| 56 |
-
└── _archive/ #
|
| 57 |
├── llm-1b-model.py
|
| 58 |
├── llm-1b-data-pipeline.py
|
| 59 |
├── llm-1b-trainer.py
|
| 60 |
└── llm-1b-evaluation.py
|
| 61 |
```
|
| 62 |
|
| 63 |
-
##
|
| 64 |
|
| 65 |
-
- **
|
| 66 |
-
- **
|
| 67 |
-
- **
|
| 68 |
-
- **
|
| 69 |
-
- **
|
| 70 |
-
- **
|
| 71 |
|
| 72 |
-
##
|
| 73 |
|
| 74 |
```
|
| 75 |
-
config (
|
| 76 |
↓
|
| 77 |
utils → config
|
| 78 |
↓
|
|
@@ -85,13 +85,13 @@ training → config, utils
|
|
| 85 |
evaluation → config
|
| 86 |
```
|
| 87 |
|
| 88 |
-
##
|
| 89 |
|
| 90 |
-
|
|
| 91 |
-
|--------|---------|-----|--------|-------|----------|------|
|
| 92 |
-
| `debug_10m` | ~10M | 256 | 6 | 8 | 4 |
|
| 93 |
-
| `small_100m` | ~100M | 768 | 12 | 12 | 4 |
|
| 94 |
-
| `base_1b` | ~1.1B | 2048 | 22 | 32 | 8 |
|
| 95 |
|
| 96 |
## Quick Start
|
| 97 |
|
|
@@ -102,30 +102,30 @@ from llm_lab.data import setup_data_pipeline
|
|
| 102 |
from llm_lab.training import start_training
|
| 103 |
from llm_lab.evaluation import run_evaluation
|
| 104 |
|
| 105 |
-
# 1.
|
| 106 |
model = LLMModel(ModelConfig.base_1b())
|
| 107 |
|
| 108 |
-
# 2.
|
| 109 |
tok, train_dl, val_dl = setup_data_pipeline("pretrained")
|
| 110 |
|
| 111 |
-
# 3.
|
| 112 |
trainer = start_training(model, train_dl, val_dl)
|
| 113 |
|
| 114 |
-
# 4.
|
| 115 |
report = run_evaluation(model, tok, val_dl,
|
| 116 |
metrics_history=trainer.metrics.history)
|
| 117 |
```
|
| 118 |
|
| 119 |
-
##
|
| 120 |
|
| 121 |
-
- **
|
| 122 |
-
- **
|
| 123 |
-
- **
|
| 124 |
-
- **
|
| 125 |
-
- **
|
| 126 |
|
| 127 |
-
##
|
| 128 |
|
| 129 |
-
- `torch`
|
| 130 |
- `pip install torch datasets tokenizers sentencepiece transformers wandb matplotlib numpy`
|
| 131 |
-
-
|
|
|
|
| 1 |
# LLM-1B-Lab
|
| 2 |
|
| 3 |
+
Educational implementation of a 1.1B parameter LLaMA-style Decoder-Only Transformer.
|
| 4 |
+
Designed so beginners in deep learning can experience training and evaluating an LLM from scratch.
|
| 5 |
|
| 6 |
+
## Project Structure
|
| 7 |
|
| 8 |
```
|
| 9 |
LLM_Foundation_Model/
|
| 10 |
├── CLAUDE.md
|
| 11 |
├── requirements.txt
|
| 12 |
+
├── llm_lab/ # Python package (core code)
|
| 13 |
│ ├── __init__.py
|
| 14 |
+
│ ├── config/ # Configuration dataclasses
|
| 15 |
+
│ │ ├── model_config.py # ModelConfig (debug_10m / small_100m / base_1b presets)
|
| 16 |
+
│ │ ├── data_config.py # DataConfig (dataset, tokenizer, batch settings)
|
| 17 |
+
│ │ ├── train_config.py # TrainConfig (LR, scheduler, checkpoint, wandb)
|
| 18 |
+
│ │ └── eval_config.py # EvalConfig (evaluation parameters)
|
| 19 |
+
│ ├── model/ # Model architecture
|
| 20 |
│ │ ├── norm.py # RMSNorm
|
| 21 |
│ │ ├── rope.py # RotaryPositionalEmbedding (RoPE)
|
| 22 |
│ │ ├── attention.py # GroupedQueryAttention (GQA)
|
| 23 |
│ │ ├── feedforward.py # SwiGLUFeedForward
|
| 24 |
│ │ ├── transformer_block.py # TransformerBlock (Pre-LN)
|
| 25 |
+
│ │ ├── llm_model.py # LLMModel (full model + generate)
|
| 26 |
│ │ └── utils.py # count_parameters_detailed, estimate_memory_gb
|
| 27 |
+
│ ├── data/ # Data pipeline
|
| 28 |
│ │ ├── tokenizer.py # Tokenizer (SentencePiece / BPE / HuggingFace)
|
| 29 |
│ │ ├── dataset.py # PackedStreamingDataset, ValidationDataset, _collate_fn
|
| 30 |
│ │ ├── pipeline.py # create_train_dataloader, setup_data_pipeline
|
| 31 |
│ │ └── diagnostics.py # DataPipelineDiagnostics
|
| 32 |
+
│ ├── training/ # Training loop
|
| 33 |
│ │ ├── scheduler.py # CosineWarmupScheduler
|
| 34 |
+
│ │ ├── checkpoint.py # CheckpointManager (Google Drive support)
|
| 35 |
+
│ │ ├── metrics.py # MetricsTracker (wandb integration)
|
| 36 |
+
│ │ ├── optimizer.py # create_optimizer (weight decay separation)
|
| 37 |
│ │ ├── trainer.py # Trainer (gradient accumulation, mixed precision)
|
| 38 |
+
│ │ └── runner.py # start_training (one-line helper)
|
| 39 |
+
│ ├── evaluation/ # Evaluation & analysis
|
| 40 |
+
│ │ ├── perplexity.py # PerplexityEvaluator (including per-position loss)
|
| 41 |
+
│ │ ├── generation.py # GenerationEvaluator (various prompts)
|
| 42 |
│ │ ├── scaling.py # ScalingAnalyzer (Chinchilla Scaling Law)
|
| 43 |
+
│ │ ├── dynamics.py # TrainingDynamicsAnalyzer (Loss/LR/Grad visualization)
|
| 44 |
+
│ │ ├── attention_viz.py # AttentionVisualizer (per-head heatmap)
|
| 45 |
+
│ │ ├── full_evaluator.py # FullEvaluator (comprehensive evaluation + report)
|
| 46 |
+
│ │ ├── checklist.py # InsightChecklist (training insight checklist)
|
| 47 |
+
│ │ └── runner.py # run_evaluation (one-line helper)
|
| 48 |
+
│ └── utils/ # Common utilities
|
| 49 |
│ ├── device.py # auto_configure, get_device, detect_gpu_info
|
| 50 |
│ └── seed.py # set_seed
|
| 51 |
+
├── notebooks/ # Jupyter notebooks (configuration + execution)
|
| 52 |
│ ├── 01_data_pipeline.ipynb
|
| 53 |
│ ├── 02_model.ipynb
|
| 54 |
│ ├── 03_training.ipynb
|
| 55 |
│ └── 04_evaluation.ipynb
|
| 56 |
+
└── _archive/ # Original single-file backups
|
| 57 |
├── llm-1b-model.py
|
| 58 |
├── llm-1b-data-pipeline.py
|
| 59 |
├── llm-1b-trainer.py
|
| 60 |
└── llm-1b-evaluation.py
|
| 61 |
```
|
| 62 |
|
| 63 |
+
## Tech Stack
|
| 64 |
|
| 65 |
+
- **Model**: LLaMA-style Decoder-Only Transformer (RMSNorm, RoPE, GQA, SwiGLU, Weight Tying)
|
| 66 |
+
- **Training**: Gradient Accumulation, Mixed Precision (bf16/fp16), Cosine LR + Warmup, Activation Checkpointing
|
| 67 |
+
- **Data**: HuggingFace Streaming (FineWeb-Edu), BPE tokenizer, sequence packing
|
| 68 |
+
- **Checkpoint**: Auto save/restore to Google Drive (Colab Pro+ environment)
|
| 69 |
+
- **Evaluation**: Perplexity, text generation, Scaling Law, Attention visualization
|
| 70 |
+
- **Target Environment**: Google Colab Pro+ (A100 40GB)
|
| 71 |
|
| 72 |
+
## Dependency Graph (no cycles)
|
| 73 |
|
| 74 |
```
|
| 75 |
+
config (no dependencies)
|
| 76 |
↓
|
| 77 |
utils → config
|
| 78 |
↓
|
|
|
|
| 85 |
evaluation → config
|
| 86 |
```
|
| 87 |
|
| 88 |
+
## Model Presets
|
| 89 |
|
| 90 |
+
| Preset | Parameters | dim | layers | heads | kv_heads | Purpose |
|
| 91 |
+
|--------|-----------|-----|--------|-------|----------|---------|
|
| 92 |
+
| `debug_10m` | ~10M | 256 | 6 | 8 | 4 | Fast validation/debug |
|
| 93 |
+
| `small_100m` | ~100M | 768 | 12 | 12 | 4 | Intermediate experiments |
|
| 94 |
+
| `base_1b` | ~1.1B | 2048 | 22 | 16 | 4 | Full-scale training |
|
| 95 |
|
| 96 |
## Quick Start
|
| 97 |
|
|
|
|
| 102 |
from llm_lab.training import start_training
|
| 103 |
from llm_lab.evaluation import run_evaluation
|
| 104 |
|
| 105 |
+
# 1. Model
|
| 106 |
model = LLMModel(ModelConfig.base_1b())
|
| 107 |
|
| 108 |
+
# 2. Data
|
| 109 |
tok, train_dl, val_dl = setup_data_pipeline("pretrained")
|
| 110 |
|
| 111 |
+
# 3. Training
|
| 112 |
trainer = start_training(model, train_dl, val_dl)
|
| 113 |
|
| 114 |
+
# 4. Evaluation
|
| 115 |
report = run_evaluation(model, tok, val_dl,
|
| 116 |
metrics_history=trainer.metrics.history)
|
| 117 |
```
|
| 118 |
|
| 119 |
+
## Code Conventions
|
| 120 |
|
| 121 |
+
- **Language**: All code, comments, docstrings, and git commit messages must be written in English
|
| 122 |
+
- **Type hints**: Use typing annotations on all functions
|
| 123 |
+
- **Import order**: stdlib → torch → llm_lab (absolute) → local (relative)
|
| 124 |
+
- **Dataclasses**: All configurations defined as `@dataclass` with defaults
|
| 125 |
+
- **Error handling**: Optional dependencies (matplotlib, wandb, etc.) wrapped in `try/except ImportError`
|
| 126 |
|
| 127 |
+
## Notes
|
| 128 |
|
| 129 |
+
- `torch` may not be installed locally (assumes Colab Pro+ runtime)
|
| 130 |
- `pip install torch datasets tokenizers sentencepiece transformers wandb matplotlib numpy`
|
| 131 |
+
- The logic in the original 4 files (`_archive/`) and the modularized `llm_lab/` package is identical (only import paths changed)
|
llm_lab/__init__.py
CHANGED
|
@@ -1,16 +1,16 @@
|
|
| 1 |
"""
|
| 2 |
LLM-1B-Lab: 1B Parameter LLaMA-style Transformer (from scratch)
|
| 3 |
================================================================
|
| 4 |
-
|
| 5 |
-
|
| 6 |
|
| 7 |
-
|
| 8 |
-
llm_lab.config —
|
| 9 |
-
llm_lab.model —
|
| 10 |
-
llm_lab.data —
|
| 11 |
-
llm_lab.training —
|
| 12 |
-
llm_lab.evaluation —
|
| 13 |
-
llm_lab.utils —
|
| 14 |
|
| 15 |
Quick Start:
|
| 16 |
from llm_lab.config import ModelConfig, DataConfig, TrainConfig
|
|
|
|
| 1 |
"""
|
| 2 |
LLM-1B-Lab: 1B Parameter LLaMA-style Transformer (from scratch)
|
| 3 |
================================================================
|
| 4 |
+
An educational implementation for deep learning beginners.
|
| 5 |
+
Each component includes detailed comments explaining "why" things are done this way.
|
| 6 |
|
| 7 |
+
Module structure:
|
| 8 |
+
llm_lab.config — All configurations (ModelConfig, DataConfig, TrainConfig, EvalConfig)
|
| 9 |
+
llm_lab.model — Model architecture (RMSNorm, RoPE, GQA, SwiGLU, Transformer)
|
| 10 |
+
llm_lab.data — Data pipeline (tokenizer, streaming, packing)
|
| 11 |
+
llm_lab.training — Training loop (Trainer, scheduler, checkpoint)
|
| 12 |
+
llm_lab.evaluation — Evaluation (Perplexity, generation, Scaling Law, Attention)
|
| 13 |
+
llm_lab.utils — Common utilities (device detection, seed)
|
| 14 |
|
| 15 |
Quick Start:
|
| 16 |
from llm_lab.config import ModelConfig, DataConfig, TrainConfig
|
llm_lab/config/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
from .model_config import ModelConfig
|
| 3 |
from .data_config import DataConfig
|
| 4 |
from .train_config import TrainConfig
|
|
|
|
| 1 |
+
"""Config module — manages all hyperparameters in one place."""
|
| 2 |
from .model_config import ModelConfig
|
| 3 |
from .data_config import DataConfig
|
| 4 |
from .train_config import TrainConfig
|
llm_lab/config/data_config.py
CHANGED
|
@@ -4,38 +4,38 @@ from typing import Optional
|
|
| 4 |
|
| 5 |
@dataclass
|
| 6 |
class DataConfig:
|
| 7 |
-
"""
|
| 8 |
|
| 9 |
-
Colab Pro+
|
| 10 |
-
- Streaming
|
| 11 |
-
-
|
| 12 |
-
-
|
| 13 |
"""
|
| 14 |
-
# ──
|
| 15 |
dataset_name: str = "HuggingFaceFW/fineweb-edu"
|
| 16 |
-
dataset_subset: str = "sample-10BT" # 10B
|
| 17 |
dataset_split: str = "train"
|
| 18 |
-
text_column: str = "text" #
|
| 19 |
|
| 20 |
-
# ──
|
| 21 |
-
tokenizer_type: str = "sentencepiece" # "sentencepiece"
|
| 22 |
-
#
|
| 23 |
tokenizer_path: Optional[str] = None
|
| 24 |
vocab_size: int = 32_000
|
| 25 |
|
| 26 |
-
# ──
|
| 27 |
max_seq_len: int = 2048
|
| 28 |
-
#
|
| 29 |
use_eos_separator: bool = True
|
| 30 |
|
| 31 |
-
# ──
|
| 32 |
-
batch_size: int = 4 # micro batch (GPU
|
| 33 |
-
num_workers: int = 2 #
|
| 34 |
-
prefetch_factor: int = 4 #
|
| 35 |
|
| 36 |
-
# ──
|
| 37 |
-
tokenizer_train_samples: int = 50_000 #
|
| 38 |
tokenizer_save_dir: str = "./tokenizer"
|
| 39 |
|
| 40 |
-
# ──
|
| 41 |
-
val_ratio: float = 0.001 #
|
|
|
|
| 4 |
|
| 5 |
@dataclass
|
| 6 |
class DataConfig:
|
| 7 |
+
"""Data pipeline configuration.
|
| 8 |
|
| 9 |
+
Default values optimized for Colab Pro+ environment constraints:
|
| 10 |
+
- Streaming mode to minimize disk usage
|
| 11 |
+
- Sequence packing to maximize GPU utilization without padding
|
| 12 |
+
- On-the-fly preprocessing to save memory
|
| 13 |
"""
|
| 14 |
+
# ── Dataset ──
|
| 15 |
dataset_name: str = "HuggingFaceFW/fineweb-edu"
|
| 16 |
+
dataset_subset: str = "sample-10BT" # 10B token sample
|
| 17 |
dataset_split: str = "train"
|
| 18 |
+
text_column: str = "text" # column name containing text
|
| 19 |
|
| 20 |
+
# ── Tokenizer ──
|
| 21 |
+
tokenizer_type: str = "sentencepiece" # "sentencepiece" or "hf"
|
| 22 |
+
# path to a pretrained tokenizer (trains a new one if not provided)
|
| 23 |
tokenizer_path: Optional[str] = None
|
| 24 |
vocab_size: int = 32_000
|
| 25 |
|
| 26 |
+
# ── Sequence ──
|
| 27 |
max_seq_len: int = 2048
|
| 28 |
+
# whether to use a document separator token (marks document boundaries during packing)
|
| 29 |
use_eos_separator: bool = True
|
| 30 |
|
| 31 |
+
# ── Batch ──
|
| 32 |
+
batch_size: int = 4 # micro batch (per GPU)
|
| 33 |
+
num_workers: int = 2 # number of DataLoader workers
|
| 34 |
+
prefetch_factor: int = 4 # number of batches to prefetch
|
| 35 |
|
| 36 |
+
# ── Tokenizer training settings (when training from scratch) ──
|
| 37 |
+
tokenizer_train_samples: int = 50_000 # number of documents to use for training
|
| 38 |
tokenizer_save_dir: str = "./tokenizer"
|
| 39 |
|
| 40 |
+
# ── Validation data ──
|
| 41 |
+
val_ratio: float = 0.001 # use 0.1% of total data for validation
|
llm_lab/config/eval_config.py
CHANGED
|
@@ -3,18 +3,18 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
@dataclass
|
| 5 |
class EvalConfig:
|
| 6 |
-
"""
|
| 7 |
# ── Perplexity ──
|
| 8 |
eval_batch_size: int = 4
|
| 9 |
-
max_eval_batches: int = 100 #
|
| 10 |
|
| 11 |
-
# ──
|
| 12 |
max_new_tokens: int = 200
|
| 13 |
temperature: float = 0.8
|
| 14 |
top_k: int = 50
|
| 15 |
top_p: float = 0.9
|
| 16 |
-
num_samples: int = 3 #
|
| 17 |
|
| 18 |
-
# ──
|
| 19 |
save_dir: str = "./eval_results"
|
| 20 |
plot_dpi: int = 150
|
|
|
|
| 3 |
|
| 4 |
@dataclass
|
| 5 |
class EvalConfig:
|
| 6 |
+
"""Evaluation parameters."""
|
| 7 |
# ── Perplexity ──
|
| 8 |
eval_batch_size: int = 4
|
| 9 |
+
max_eval_batches: int = 100 # maximum number of evaluation batches
|
| 10 |
|
| 11 |
+
# ── Generation ──
|
| 12 |
max_new_tokens: int = 200
|
| 13 |
temperature: float = 0.8
|
| 14 |
top_k: int = 50
|
| 15 |
top_p: float = 0.9
|
| 16 |
+
num_samples: int = 3 # number of generations per prompt
|
| 17 |
|
| 18 |
+
# ── Output ──
|
| 19 |
save_dir: str = "./eval_results"
|
| 20 |
plot_dpi: int = 150
|
llm_lab/config/model_config.py
CHANGED
|
@@ -3,37 +3,37 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
@dataclass
|
| 5 |
class ModelConfig:
|
| 6 |
-
"""
|
| 7 |
|
| 8 |
-
|
| 9 |
-
- debug: ~10M (
|
| 10 |
-
- small: ~100M (
|
| 11 |
-
- base: ~1.1B (
|
| 12 |
"""
|
| 13 |
vocab_size: int = 32_000
|
| 14 |
-
hidden_dim: int = 2048 # d_model:
|
| 15 |
-
num_layers: int = 22 #
|
| 16 |
-
num_heads: int = 16 #
|
| 17 |
-
num_kv_heads: int = 4 # Key/Value
|
| 18 |
-
intermediate_dim: int = 5632 # FFN
|
| 19 |
-
max_seq_len: int = 2048 #
|
| 20 |
-
dropout: float = 0.0 #
|
| 21 |
-
rope_theta: float = 10000.0 # RoPE
|
| 22 |
norm_eps: float = 1e-6 # RMSNorm epsilon
|
| 23 |
|
| 24 |
@property
|
| 25 |
def head_dim(self) -> int:
|
| 26 |
-
"""
|
| 27 |
return self.hidden_dim // self.num_heads
|
| 28 |
|
| 29 |
@property
|
| 30 |
def num_kv_groups(self) -> int:
|
| 31 |
-
"""
|
| 32 |
return self.num_heads // self.num_kv_heads
|
| 33 |
|
| 34 |
@classmethod
|
| 35 |
def debug_10m(cls) -> "ModelConfig":
|
| 36 |
-
"""~10M
|
| 37 |
return cls(
|
| 38 |
hidden_dim=256, num_layers=6, num_heads=8,
|
| 39 |
num_kv_heads=4, intermediate_dim=704, max_seq_len=512,
|
|
@@ -41,7 +41,7 @@ class ModelConfig:
|
|
| 41 |
|
| 42 |
@classmethod
|
| 43 |
def small_100m(cls) -> "ModelConfig":
|
| 44 |
-
"""~100M
|
| 45 |
return cls(
|
| 46 |
hidden_dim=768, num_layers=12, num_heads=12,
|
| 47 |
num_kv_heads=4, intermediate_dim=2048, max_seq_len=1024,
|
|
@@ -49,5 +49,5 @@ class ModelConfig:
|
|
| 49 |
|
| 50 |
@classmethod
|
| 51 |
def base_1b(cls) -> "ModelConfig":
|
| 52 |
-
"""~1.1B
|
| 53 |
-
return cls() #
|
|
|
|
| 3 |
|
| 4 |
@dataclass
|
| 5 |
class ModelConfig:
|
| 6 |
+
"""Manages model hyperparameters as a single dataclass.
|
| 7 |
|
| 8 |
+
Scale-specific presets:
|
| 9 |
+
- debug: ~10M (for pipeline validation)
|
| 10 |
+
- small: ~100M (for intermediate validation)
|
| 11 |
+
- base: ~1.1B (final target)
|
| 12 |
"""
|
| 13 |
vocab_size: int = 32_000
|
| 14 |
+
hidden_dim: int = 2048 # d_model: base dimension of the model
|
| 15 |
+
num_layers: int = 22 # number of Transformer blocks
|
| 16 |
+
num_heads: int = 16 # number of Query heads
|
| 17 |
+
num_kv_heads: int = 4 # number of Key/Value heads (GQA)
|
| 18 |
+
intermediate_dim: int = 5632 # FFN intermediate dimension (≈ 2.75 × hidden_dim)
|
| 19 |
+
max_seq_len: int = 2048 # maximum sequence length
|
| 20 |
+
dropout: float = 0.0 # typically 0 during pretraining
|
| 21 |
+
rope_theta: float = 10000.0 # RoPE frequency base
|
| 22 |
norm_eps: float = 1e-6 # RMSNorm epsilon
|
| 23 |
|
| 24 |
@property
|
| 25 |
def head_dim(self) -> int:
|
| 26 |
+
"""Dimension of each attention head."""
|
| 27 |
return self.hidden_dim // self.num_heads
|
| 28 |
|
| 29 |
@property
|
| 30 |
def num_kv_groups(self) -> int:
|
| 31 |
+
"""Number of Q heads per KV head in GQA."""
|
| 32 |
return self.num_heads // self.num_kv_heads
|
| 33 |
|
| 34 |
@classmethod
|
| 35 |
def debug_10m(cls) -> "ModelConfig":
|
| 36 |
+
"""~10M parameters - for fast debugging."""
|
| 37 |
return cls(
|
| 38 |
hidden_dim=256, num_layers=6, num_heads=8,
|
| 39 |
num_kv_heads=4, intermediate_dim=704, max_seq_len=512,
|
|
|
|
| 41 |
|
| 42 |
@classmethod
|
| 43 |
def small_100m(cls) -> "ModelConfig":
|
| 44 |
+
"""~100M parameters - for intermediate validation."""
|
| 45 |
return cls(
|
| 46 |
hidden_dim=768, num_layers=12, num_heads=12,
|
| 47 |
num_kv_heads=4, intermediate_dim=2048, max_seq_len=1024,
|
|
|
|
| 49 |
|
| 50 |
@classmethod
|
| 51 |
def base_1b(cls) -> "ModelConfig":
|
| 52 |
+
"""~1.1B parameters - final training target."""
|
| 53 |
+
return cls() # defaults are the 1B configuration
|
llm_lab/config/train_config.py
CHANGED
|
@@ -6,97 +6,97 @@ import torch
|
|
| 6 |
|
| 7 |
@dataclass
|
| 8 |
class TrainConfig:
|
| 9 |
-
"""
|
| 10 |
|
| 11 |
-
Colab Pro+ (A100 40GB)
|
| 12 |
-
|
| 13 |
"""
|
| 14 |
|
| 15 |
-
# ──
|
| 16 |
learning_rate: float = 3e-4
|
| 17 |
-
"""Peak LR.
|
| 18 |
-
GPT-3
|
| 19 |
125M → 6e-4, 350M → 3e-4, 1.3B → 2e-4
|
| 20 |
-
|
| 21 |
|
| 22 |
min_learning_rate: float = 3e-5
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
|
| 26 |
weight_decay: float = 0.1
|
| 27 |
-
"""
|
| 28 |
-
|
| 29 |
|
| 30 |
beta1: float = 0.9
|
| 31 |
beta2: float = 0.95
|
| 32 |
-
"""Adam
|
| 33 |
-
|
| 34 |
|
| 35 |
adam_eps: float = 1e-8
|
| 36 |
grad_clip: float = 1.0
|
| 37 |
-
"""Gradient Clipping:
|
| 38 |
-
|
| 39 |
|
| 40 |
-
# ──
|
| 41 |
warmup_steps: int = 2000
|
| 42 |
-
"""Warmup:
|
| 43 |
-
|
| 44 |
-
-
|
| 45 |
-
-
|
| 46 |
-
- 2000
|
| 47 |
|
| 48 |
total_steps: int = 20_000
|
| 49 |
-
"""
|
| 50 |
-
10B tokens / (128 batch × 2048 seq_len) ≈ 38,000
|
| 51 |
-
|
| 52 |
|
| 53 |
-
# ──
|
| 54 |
micro_batch_size: int = 4
|
| 55 |
-
"""
|
| 56 |
-
|
| 57 |
|
| 58 |
gradient_accumulation_steps: int = 32
|
| 59 |
-
"""
|
| 60 |
-
|
| 61 |
-
-
|
| 62 |
-
- LLM
|
| 63 |
-
-
|
| 64 |
|
| 65 |
# ── Mixed Precision ──
|
| 66 |
dtype: str = "bfloat16"
|
| 67 |
-
"""bfloat16:
|
| 68 |
-
exponent
|
| 69 |
-
|
| 70 |
|
| 71 |
-
# ──
|
| 72 |
checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints"
|
| 73 |
-
"""Google Drive
|
| 74 |
|
| 75 |
checkpoint_interval: int = 500
|
| 76 |
-
"""
|
| 77 |
-
|
| 78 |
-
|
| 79 |
|
| 80 |
max_checkpoints: int = 3
|
| 81 |
-
"""
|
| 82 |
-
|
| 83 |
|
| 84 |
-
# ──
|
| 85 |
log_interval: int = 10
|
| 86 |
-
"""
|
| 87 |
|
| 88 |
eval_interval: int = 500
|
| 89 |
-
"""
|
| 90 |
|
| 91 |
eval_steps: int = 20
|
| 92 |
-
"""
|
| 93 |
|
| 94 |
# ── wandb ──
|
| 95 |
wandb_project: str = "llm-1b-lab"
|
| 96 |
wandb_run_name: Optional[str] = None
|
| 97 |
use_wandb: bool = True
|
| 98 |
|
| 99 |
-
# ──
|
| 100 |
seed: int = 42
|
| 101 |
|
| 102 |
@property
|
|
@@ -105,8 +105,8 @@ class TrainConfig:
|
|
| 105 |
|
| 106 |
@property
|
| 107 |
def tokens_per_step(self) -> int:
|
| 108 |
-
"""
|
| 109 |
-
# max_seq_len
|
| 110 |
return self.effective_batch_size * 2048
|
| 111 |
|
| 112 |
@property
|
|
|
|
| 6 |
|
| 7 |
@dataclass
|
| 8 |
class TrainConfig:
|
| 9 |
+
"""Training hyperparameters and infrastructure configuration.
|
| 10 |
|
| 11 |
+
Default values optimized for Colab Pro+ (A100 40GB).
|
| 12 |
+
Each value includes an explanation of why it was chosen.
|
| 13 |
"""
|
| 14 |
|
| 15 |
+
# ── Optimization ──
|
| 16 |
learning_rate: float = 3e-4
|
| 17 |
+
"""Peak LR. 3e-4 is the standard for 1B-scale models.
|
| 18 |
+
The GPT-3 paper reports optimal LRs by model size:
|
| 19 |
125M → 6e-4, 350M → 3e-4, 1.3B → 2e-4
|
| 20 |
+
Our model (1.1B) starts at 3e-4; lower to 2e-4 if unstable."""
|
| 21 |
|
| 22 |
min_learning_rate: float = 3e-5
|
| 23 |
+
"""Minimum point of cosine decay. Typically 10% of peak.
|
| 24 |
+
Too low causes stagnation in later training; too high causes unstable convergence."""
|
| 25 |
|
| 26 |
weight_decay: float = 0.1
|
| 27 |
+
"""L2 regularization for AdamW. 0.1 is the LLM standard.
|
| 28 |
+
Not applied to embeddings and biases (by convention)."""
|
| 29 |
|
| 30 |
beta1: float = 0.9
|
| 31 |
beta2: float = 0.95
|
| 32 |
+
"""Adam momentum coefficients. β2=0.95 is more stable than β2=0.999 for LLM training.
|
| 33 |
+
With large batches and long training, a β2 that is too large slows adaptation."""
|
| 34 |
|
| 35 |
adam_eps: float = 1e-8
|
| 36 |
grad_clip: float = 1.0
|
| 37 |
+
"""Gradient Clipping: rescales gradients when their norm exceeds 1.0.
|
| 38 |
+
Prevents gradient spikes that occur during early training or with noisy data."""
|
| 39 |
|
| 40 |
+
# ── Scheduling ──
|
| 41 |
warmup_steps: int = 2000
|
| 42 |
+
"""Warmup: linearly increases LR from 0 to peak over the first 2000 steps.
|
| 43 |
+
Why is this necessary?
|
| 44 |
+
- Initial weights are random → large LR causes unstable updates
|
| 45 |
+
- Starting with a small LR lets the model find its direction before full training
|
| 46 |
+
- 2000 is roughly ~10% of total training steps (empirical rule)."""
|
| 47 |
|
| 48 |
total_steps: int = 20_000
|
| 49 |
+
"""Total number of training steps.
|
| 50 |
+
10B tokens / (128 batch × 2048 seq_len) ≈ 38,000, but
|
| 51 |
+
~20,000 effective steps when accounting for gradient accumulation."""
|
| 52 |
|
| 53 |
+
# ── Batch ──
|
| 54 |
micro_batch_size: int = 4
|
| 55 |
+
"""Batch size loaded onto the GPU at once.
|
| 56 |
+
4 is a safe upper bound for a 1B model in bf16 on an A100 40GB."""
|
| 57 |
|
| 58 |
gradient_accumulation_steps: int = 32
|
| 59 |
+
"""Number of gradient accumulation steps. Effective batch = 4 × 32 = 128.
|
| 60 |
+
Why is a large batch beneficial?
|
| 61 |
+
- More stable gradient estimates (reduced noise)
|
| 62 |
+
- LLM training typically uses an effective batch of 128–512
|
| 63 |
+
- When memory is limited, increase this and reduce micro_batch."""
|
| 64 |
|
| 65 |
# ── Mixed Precision ──
|
| 66 |
dtype: str = "bfloat16"
|
| 67 |
+
"""bfloat16: supported on A100, numerically more stable than fp16.
|
| 68 |
+
Uses the same number of exponent bits as fp32 → lower risk of overflow/underflow.
|
| 69 |
+
Change to 'float16' when falling back to T4/V100."""
|
| 70 |
|
| 71 |
+
# ── Checkpointing ──
|
| 72 |
checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints"
|
| 73 |
+
"""Google Drive path. Preserved even when the Colab session expires."""
|
| 74 |
|
| 75 |
checkpoint_interval: int = 500
|
| 76 |
+
"""Save a checkpoint every 500 steps.
|
| 77 |
+
Roughly every ~30 minutes on an A100. Too frequent causes I/O overhead;
|
| 78 |
+
too infrequent risks large losses when the session expires."""
|
| 79 |
|
| 80 |
max_checkpoints: int = 3
|
| 81 |
+
"""Number of rolling checkpoints to retain; oldest are deleted first.
|
| 82 |
+
One checkpoint ≈ 8–10 GB → 3 checkpoints ≈ ~30 GB."""
|
| 83 |
|
| 84 |
+
# ── Logging ──
|
| 85 |
log_interval: int = 10
|
| 86 |
+
"""Log to console and wandb every 10 steps."""
|
| 87 |
|
| 88 |
eval_interval: int = 500
|
| 89 |
+
"""Measure validation loss every 500 steps."""
|
| 90 |
|
| 91 |
eval_steps: int = 20
|
| 92 |
+
"""Number of batches to use during validation. 20 × 4 × 2048 ≈ 160K tokens."""
|
| 93 |
|
| 94 |
# ── wandb ──
|
| 95 |
wandb_project: str = "llm-1b-lab"
|
| 96 |
wandb_run_name: Optional[str] = None
|
| 97 |
use_wandb: bool = True
|
| 98 |
|
| 99 |
+
# ── Reproducibility ──
|
| 100 |
seed: int = 42
|
| 101 |
|
| 102 |
@property
|
|
|
|
| 105 |
|
| 106 |
@property
|
| 107 |
def tokens_per_step(self) -> int:
|
| 108 |
+
"""Number of tokens processed per optimizer step."""
|
| 109 |
+
# max_seq_len is injected externally (see ModelConfig)
|
| 110 |
return self.effective_batch_size * 2048
|
| 111 |
|
| 112 |
@property
|
llm_lab/data/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
from .tokenizer import Tokenizer
|
| 3 |
from .dataset import PackedStreamingDataset, ValidationDataset
|
| 4 |
from .pipeline import create_train_dataloader, train_tokenizer_from_dataset, setup_data_pipeline
|
|
|
|
| 1 |
+
"""Data pipeline module — tokenizer, streaming, and sequence packing."""
|
| 2 |
from .tokenizer import Tokenizer
|
| 3 |
from .dataset import PackedStreamingDataset, ValidationDataset
|
| 4 |
from .pipeline import create_train_dataloader, train_tokenizer_from_dataset, setup_data_pipeline
|
llm_lab/data/dataset.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
from typing import Iterator, List, Dict, Optional
|
| 4 |
|
|
@@ -10,24 +10,24 @@ from .tokenizer import Tokenizer
|
|
| 10 |
|
| 11 |
|
| 12 |
class PackedStreamingDataset(IterableDataset):
|
| 13 |
-
"""Streaming +
|
| 14 |
|
| 15 |
-
|
| 16 |
-
-
|
| 17 |
-
-
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
→ [
|
| 22 |
|
| 23 |
-
|
| 24 |
-
- FineWeb-Edu 10B
|
| 25 |
-
- Colab
|
| 26 |
-
- Streaming:
|
| 27 |
|
| 28 |
-
|
| 29 |
-
-
|
| 30 |
-
-
|
| 31 |
"""
|
| 32 |
|
| 33 |
def __init__(
|
|
@@ -45,15 +45,16 @@ class PackedStreamingDataset(IterableDataset):
|
|
| 45 |
self.max_seq_len = config.max_seq_len
|
| 46 |
|
| 47 |
def _load_dataset(self, num_shards: int = 1, shard_index: int = 0):
|
| 48 |
-
"""HuggingFace
|
| 49 |
|
| 50 |
Args:
|
| 51 |
-
num_shards:
|
| 52 |
-
shard_index:
|
| 53 |
|
| 54 |
-
|
| 55 |
-
num_shards=4
|
| 56 |
-
|
|
|
|
| 57 |
"""
|
| 58 |
from datasets import load_dataset
|
| 59 |
|
|
@@ -61,87 +62,87 @@ class PackedStreamingDataset(IterableDataset):
|
|
| 61 |
self.config.dataset_name,
|
| 62 |
name=self.config.dataset_subset,
|
| 63 |
split=self.config.dataset_split,
|
| 64 |
-
streaming=True, #
|
| 65 |
trust_remote_code=True,
|
| 66 |
)
|
| 67 |
|
| 68 |
-
#
|
| 69 |
-
#
|
| 70 |
if num_shards > 1:
|
| 71 |
ds = ds.shard(num_shards=num_shards, index=shard_index)
|
| 72 |
|
| 73 |
-
#
|
| 74 |
ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
|
| 75 |
|
| 76 |
return ds
|
| 77 |
|
| 78 |
def _tokenize_and_pack(self, dataset) -> Iterator[Dict[str, torch.Tensor]]:
|
| 79 |
-
"""
|
| 80 |
|
| 81 |
Yields:
|
| 82 |
{"input_ids": (max_seq_len,), "targets": (max_seq_len,)}
|
| 83 |
|
| 84 |
-
targets = input_ids
|
| 85 |
input_ids: [A, B, C, D, E]
|
| 86 |
targets: [B, C, D, E, F]
|
| 87 |
-
→
|
| 88 |
"""
|
| 89 |
-
buffer: List[int] = [] #
|
| 90 |
|
| 91 |
for example in dataset:
|
| 92 |
text = example[self.config.text_column]
|
| 93 |
if not text or not text.strip():
|
| 94 |
continue
|
| 95 |
|
| 96 |
-
#
|
| 97 |
token_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
| 98 |
|
| 99 |
if not token_ids:
|
| 100 |
continue
|
| 101 |
|
| 102 |
-
# EOS
|
| 103 |
if self.config.use_eos_separator:
|
| 104 |
token_ids.append(self.tokenizer.eos_id)
|
| 105 |
|
| 106 |
-
#
|
| 107 |
buffer.extend(token_ids)
|
| 108 |
|
| 109 |
-
#
|
| 110 |
-
# +1
|
| 111 |
while len(buffer) >= self.max_seq_len + 1:
|
| 112 |
-
# max_seq_len + 1
|
| 113 |
chunk = buffer[: self.max_seq_len + 1]
|
| 114 |
buffer = buffer[self.max_seq_len + 1 :]
|
| 115 |
|
| 116 |
-
# input_ids:
|
| 117 |
input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
|
| 118 |
-
# targets:
|
| 119 |
targets = torch.tensor(chunk[1:], dtype=torch.long)
|
| 120 |
|
| 121 |
yield {"input_ids": input_ids, "targets": targets}
|
| 122 |
|
| 123 |
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
| 124 |
-
"""
|
| 125 |
|
| 126 |
-
|
| 127 |
-
-
|
| 128 |
-
-
|
| 129 |
|
| 130 |
-
|
| 131 |
-
Worker 0:
|
| 132 |
-
Worker 1:
|
| 133 |
-
Worker 2:
|
| 134 |
-
Worker 3:
|
| 135 |
"""
|
| 136 |
worker_info = torch.utils.data.get_worker_info()
|
| 137 |
|
| 138 |
if worker_info is not None:
|
| 139 |
-
#
|
| 140 |
num_shards = worker_info.num_workers
|
| 141 |
shard_index = worker_info.id
|
| 142 |
worker_seed = self.seed + worker_info.id
|
| 143 |
else:
|
| 144 |
-
#
|
| 145 |
num_shards = 1
|
| 146 |
shard_index = 0
|
| 147 |
worker_seed = self.seed
|
|
@@ -153,10 +154,10 @@ class PackedStreamingDataset(IterableDataset):
|
|
| 153 |
|
| 154 |
|
| 155 |
class ValidationDataset:
|
| 156 |
-
"""
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
"""
|
| 161 |
|
| 162 |
def __init__(
|
|
@@ -174,10 +175,10 @@ class ValidationDataset:
|
|
| 174 |
self._prepare(seed)
|
| 175 |
|
| 176 |
def _prepare(self, seed: int):
|
| 177 |
-
"""
|
| 178 |
from datasets import load_dataset
|
| 179 |
|
| 180 |
-
print(f"[Validation] {self.num_samples}
|
| 181 |
|
| 182 |
ds = load_dataset(
|
| 183 |
self.config.dataset_name,
|
|
@@ -186,7 +187,7 @@ class ValidationDataset:
|
|
| 186 |
streaming=True,
|
| 187 |
trust_remote_code=True,
|
| 188 |
)
|
| 189 |
-
#
|
| 190 |
ds = ds.shuffle(seed=seed, buffer_size=5_000)
|
| 191 |
|
| 192 |
buffer: List[int] = []
|
|
@@ -217,10 +218,10 @@ class ValidationDataset:
|
|
| 217 |
})
|
| 218 |
count += 1
|
| 219 |
|
| 220 |
-
print(f"[Validation] {len(self.samples)}
|
| 221 |
|
| 222 |
def get_dataloader(self, batch_size: int) -> DataLoader:
|
| 223 |
-
"""
|
| 224 |
return DataLoader(
|
| 225 |
self.samples,
|
| 226 |
batch_size=batch_size,
|
|
@@ -231,10 +232,10 @@ class ValidationDataset:
|
|
| 231 |
|
| 232 |
|
| 233 |
def _collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
| 234 |
-
"""
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
"""
|
| 239 |
return {
|
| 240 |
"input_ids": torch.stack([s["input_ids"] for s in batch]),
|
|
|
|
| 1 |
+
"""Streaming dataset — sequence packing and validation dataset."""
|
| 2 |
|
| 3 |
from typing import Iterator, List, Dict, Optional
|
| 4 |
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class PackedStreamingDataset(IterableDataset):
|
| 13 |
+
"""Streaming + sequence packing dataset.
|
| 14 |
|
| 15 |
+
Why sequence packing?
|
| 16 |
+
- Naive approach: truncate each document to max_seq_len with padding → wastes GPU
|
| 17 |
+
- Sequence packing: concatenate multiple documents to fill max_seq_len → 100% utilization
|
| 18 |
|
| 19 |
+
How it works:
|
| 20 |
+
Doc1 (300 tokens) + Doc2 (1500 tokens) + Doc3 (248 tokens) = 2048 tokens
|
| 21 |
+
→ [Doc1][EOS][Doc2][EOS][Doc3][EOS][... no padding, fits exactly]
|
| 22 |
|
| 23 |
+
Why streaming?
|
| 24 |
+
- FineWeb-Edu 10B samples: tens of GB even when compressed
|
| 25 |
+
- Full download not feasible on Colab disk limit (~200GB)
|
| 26 |
+
- Streaming: reads from the network only as much as needed
|
| 27 |
|
| 28 |
+
Notes for training:
|
| 29 |
+
- EOS token inserted at document boundaries so the model recognizes end-of-document
|
| 30 |
+
- EOS naturally serves as a boundary marker without cross-attention masking
|
| 31 |
"""
|
| 32 |
|
| 33 |
def __init__(
|
|
|
|
| 45 |
self.max_seq_len = config.max_seq_len
|
| 46 |
|
| 47 |
def _load_dataset(self, num_shards: int = 1, shard_index: int = 0):
|
| 48 |
+
"""Loads the HuggingFace dataset in streaming mode.
|
| 49 |
|
| 50 |
Args:
|
| 51 |
+
num_shards: Total number of shards (= DataLoader num_workers)
|
| 52 |
+
shard_index: The shard index this worker is responsible for (0 ~ num_shards-1)
|
| 53 |
|
| 54 |
+
Sharding principle:
|
| 55 |
+
With num_shards=4, the stream is split into 4 equal parts so each worker
|
| 56 |
+
processes a distinct 1/4. Shuffling is applied after sharding so there is
|
| 57 |
+
no document overlap between workers.
|
| 58 |
"""
|
| 59 |
from datasets import load_dataset
|
| 60 |
|
|
|
|
| 62 |
self.config.dataset_name,
|
| 63 |
name=self.config.dataset_subset,
|
| 64 |
split=self.config.dataset_split,
|
| 65 |
+
streaming=True, # Key: streaming mode
|
| 66 |
trust_remote_code=True,
|
| 67 |
)
|
| 68 |
|
| 69 |
+
# Full partitioning (sharding): worker i processes only 1/num_shards of the stream
|
| 70 |
+
# Must be applied before shuffling so each worker has a non-overlapping set of documents
|
| 71 |
if num_shards > 1:
|
| 72 |
ds = ds.shard(num_shards=num_shards, index=shard_index)
|
| 73 |
|
| 74 |
+
# Shuffle (approximate buffer-based shuffle in streaming mode)
|
| 75 |
ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
|
| 76 |
|
| 77 |
return ds
|
| 78 |
|
| 79 |
def _tokenize_and_pack(self, dataset) -> Iterator[Dict[str, torch.Tensor]]:
|
| 80 |
+
"""Tokenizes documents and packs them into sequences.
|
| 81 |
|
| 82 |
Yields:
|
| 83 |
{"input_ids": (max_seq_len,), "targets": (max_seq_len,)}
|
| 84 |
|
| 85 |
+
targets = input_ids shifted by one position:
|
| 86 |
input_ids: [A, B, C, D, E]
|
| 87 |
targets: [B, C, D, E, F]
|
| 88 |
+
→ The model sees A and predicts B, sees B and predicts C, ...
|
| 89 |
"""
|
| 90 |
+
buffer: List[int] = [] # Token buffer
|
| 91 |
|
| 92 |
for example in dataset:
|
| 93 |
text = example[self.config.text_column]
|
| 94 |
if not text or not text.strip():
|
| 95 |
continue
|
| 96 |
|
| 97 |
+
# Tokenize (without special tokens)
|
| 98 |
token_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
| 99 |
|
| 100 |
if not token_ids:
|
| 101 |
continue
|
| 102 |
|
| 103 |
+
# Append EOS token (marks document boundary)
|
| 104 |
if self.config.use_eos_separator:
|
| 105 |
token_ids.append(self.tokenizer.eos_id)
|
| 106 |
|
| 107 |
+
# Add to buffer
|
| 108 |
buffer.extend(token_ids)
|
| 109 |
|
| 110 |
+
# Generate sequences once the buffer is full enough
|
| 111 |
+
# +1 is needed to generate targets (input + next token)
|
| 112 |
while len(buffer) >= self.max_seq_len + 1:
|
| 113 |
+
# Extract max_seq_len + 1 tokens
|
| 114 |
chunk = buffer[: self.max_seq_len + 1]
|
| 115 |
buffer = buffer[self.max_seq_len + 1 :]
|
| 116 |
|
| 117 |
+
# input_ids: from the first to the second-to-last token
|
| 118 |
input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
|
| 119 |
+
# targets: from the second to the last token (shifted by one)
|
| 120 |
targets = torch.tensor(chunk[1:], dtype=torch.long)
|
| 121 |
|
| 122 |
yield {"input_ids": input_ids, "targets": targets}
|
| 123 |
|
| 124 |
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
| 125 |
+
"""Iterator called by DataLoader.
|
| 126 |
|
| 127 |
+
Multi-worker support (full partitioning approach):
|
| 128 |
+
- Previous: all workers read the same stream with different seeds → possible document duplication
|
| 129 |
+
- Improved: ds.shard() splits the stream into num_workers parts → no document overlap between workers
|
| 130 |
|
| 131 |
+
Example (num_workers=4, total N documents):
|
| 132 |
+
Worker 0: docs 0, 4, 8, 12, ... (N/4 docs)
|
| 133 |
+
Worker 1: docs 1, 5, 9, 13, ... (N/4 docs)
|
| 134 |
+
Worker 2: docs 2, 6, 10, 14, ... (N/4 docs)
|
| 135 |
+
Worker 3: docs 3, 7, 11, 15, ... (N/4 docs)
|
| 136 |
"""
|
| 137 |
worker_info = torch.utils.data.get_worker_info()
|
| 138 |
|
| 139 |
if worker_info is not None:
|
| 140 |
+
# Full partitioning: assign a shard per worker + independent shuffle seed
|
| 141 |
num_shards = worker_info.num_workers
|
| 142 |
shard_index = worker_info.id
|
| 143 |
worker_seed = self.seed + worker_info.id
|
| 144 |
else:
|
| 145 |
+
# Single process: process the full stream without sharding
|
| 146 |
num_shards = 1
|
| 147 |
shard_index = 0
|
| 148 |
worker_seed = self.seed
|
|
|
|
| 154 |
|
| 155 |
|
| 156 |
class ValidationDataset:
|
| 157 |
+
"""Validation dataset.
|
| 158 |
|
| 159 |
+
Pre-fetches a fixed amount of data from the streaming dataset and stores it in memory.
|
| 160 |
+
Consistent data across evaluations is necessary for meaningful comparisons between epochs.
|
| 161 |
"""
|
| 162 |
|
| 163 |
def __init__(
|
|
|
|
| 175 |
self._prepare(seed)
|
| 176 |
|
| 177 |
def _prepare(self, seed: int):
|
| 178 |
+
"""Pre-extracts validation samples from the dataset."""
|
| 179 |
from datasets import load_dataset
|
| 180 |
|
| 181 |
+
print(f"[Validation] Preparing {self.num_samples} validation samples...")
|
| 182 |
|
| 183 |
ds = load_dataset(
|
| 184 |
self.config.dataset_name,
|
|
|
|
| 187 |
streaming=True,
|
| 188 |
trust_remote_code=True,
|
| 189 |
)
|
| 190 |
+
# Use a different seed and skip the beginning to avoid overlap with training data
|
| 191 |
ds = ds.shuffle(seed=seed, buffer_size=5_000)
|
| 192 |
|
| 193 |
buffer: List[int] = []
|
|
|
|
| 218 |
})
|
| 219 |
count += 1
|
| 220 |
|
| 221 |
+
print(f"[Validation] {len(self.samples)} samples ready")
|
| 222 |
|
| 223 |
def get_dataloader(self, batch_size: int) -> DataLoader:
|
| 224 |
+
"""Returns a validation DataLoader."""
|
| 225 |
return DataLoader(
|
| 226 |
self.samples,
|
| 227 |
batch_size=batch_size,
|
|
|
|
| 232 |
|
| 233 |
|
| 234 |
def _collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
| 235 |
+
"""Combines samples in a batch into a single tensor.
|
| 236 |
|
| 237 |
+
Because of sequence packing, all samples have the same length (max_seq_len),
|
| 238 |
+
so no additional padding is needed.
|
| 239 |
"""
|
| 240 |
return {
|
| 241 |
"input_ids": torch.stack([s["input_ids"] for s in batch]),
|
llm_lab/data/diagnostics.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
import time
|
| 4 |
from typing import Dict
|
|
@@ -11,13 +11,13 @@ from .tokenizer import Tokenizer
|
|
| 11 |
|
| 12 |
|
| 13 |
class DataPipelineDiagnostics:
|
| 14 |
-
"""
|
| 15 |
|
| 16 |
-
|
| 17 |
-
1)
|
| 18 |
-
2)
|
| 19 |
-
3)
|
| 20 |
-
4)
|
| 21 |
"""
|
| 22 |
|
| 23 |
@staticmethod
|
|
@@ -26,11 +26,11 @@ class DataPipelineDiagnostics:
|
|
| 26 |
config: DataConfig,
|
| 27 |
num_samples: int = 1000,
|
| 28 |
):
|
| 29 |
-
"""
|
| 30 |
from datasets import load_dataset
|
| 31 |
|
| 32 |
print("\n" + "=" * 60)
|
| 33 |
-
print("
|
| 34 |
print("=" * 60)
|
| 35 |
|
| 36 |
ds = load_dataset(
|
|
@@ -59,24 +59,24 @@ class DataPipelineDiagnostics:
|
|
| 59 |
|
| 60 |
avg_tokens = sum(token_counts) / len(token_counts)
|
| 61 |
avg_chars = sum(char_counts) / len(char_counts)
|
| 62 |
-
compression_ratio = avg_chars / avg_tokens #
|
| 63 |
|
| 64 |
-
print(f"
|
| 65 |
-
print(f"
|
| 66 |
-
print(f"
|
| 67 |
-
print(f"
|
| 68 |
-
print(f"
|
| 69 |
-
print(f"
|
| 70 |
|
| 71 |
-
#
|
| 72 |
test_text = "The quick brown fox jumps over the lazy dog."
|
| 73 |
encoded = tokenizer.encode(test_text)
|
| 74 |
decoded = tokenizer.decode(encoded)
|
| 75 |
roundtrip_ok = test_text.strip() in decoded.strip()
|
| 76 |
-
print(f"\n
|
| 77 |
-
print(f"
|
| 78 |
-
print(f"
|
| 79 |
-
print(f"
|
| 80 |
|
| 81 |
@staticmethod
|
| 82 |
def benchmark_throughput(
|
|
@@ -84,13 +84,13 @@ class DataPipelineDiagnostics:
|
|
| 84 |
num_batches: int = 50,
|
| 85 |
seq_len: int = 2048,
|
| 86 |
):
|
| 87 |
-
"""
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
"""
|
| 92 |
print("\n" + "=" * 60)
|
| 93 |
-
print("
|
| 94 |
print("=" * 60)
|
| 95 |
|
| 96 |
total_tokens = 0
|
|
@@ -110,23 +110,23 @@ class DataPipelineDiagnostics:
|
|
| 110 |
elapsed = time.time() - start_time
|
| 111 |
tps = total_tokens / elapsed
|
| 112 |
|
| 113 |
-
print(f"\n
|
| 114 |
-
print(f"
|
| 115 |
-
print(f"
|
| 116 |
-
print(f"
|
| 117 |
-
print(f"\n
|
| 118 |
if tps > 80_000:
|
| 119 |
-
print(f"
|
| 120 |
elif tps > 30_000:
|
| 121 |
-
print(f"
|
| 122 |
else:
|
| 123 |
-
print(f"
|
| 124 |
|
| 125 |
@staticmethod
|
| 126 |
def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: Tokenizer):
|
| 127 |
-
"""
|
| 128 |
print("\n" + "=" * 60)
|
| 129 |
-
print("
|
| 130 |
print("=" * 60)
|
| 131 |
|
| 132 |
input_ids = batch["input_ids"]
|
|
@@ -135,19 +135,19 @@ class DataPipelineDiagnostics:
|
|
| 135 |
print(f" input_ids shape: {input_ids.shape}")
|
| 136 |
print(f" targets shape: {targets.shape}")
|
| 137 |
print(f" dtype: {input_ids.dtype}")
|
| 138 |
-
print(f"
|
| 139 |
|
| 140 |
-
#
|
| 141 |
shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
|
| 142 |
-
print(f" Shift
|
| 143 |
|
| 144 |
-
# EOS
|
| 145 |
eos_count = (input_ids == tokenizer.eos_id).sum().item()
|
| 146 |
total_tokens = input_ids.numel()
|
| 147 |
-
print(f" EOS
|
| 148 |
|
| 149 |
-
#
|
| 150 |
first_sample = input_ids[0][:100].tolist()
|
| 151 |
decoded_preview = tokenizer.decode(first_sample)
|
| 152 |
-
print(f"\n
|
| 153 |
print(f" {decoded_preview[:300]}...")
|
|
|
|
| 1 |
+
"""Data pipeline diagnostic tools."""
|
| 2 |
|
| 3 |
import time
|
| 4 |
from typing import Dict
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class DataPipelineDiagnostics:
|
| 14 |
+
"""Diagnoses the performance and quality of the data pipeline.
|
| 15 |
|
| 16 |
+
Items to verify before training:
|
| 17 |
+
1) Tokenizer quality: average tokens/document, unknown token ratio
|
| 18 |
+
2) Packing efficiency: actual token ratio vs. padding ratio
|
| 19 |
+
3) Throughput: tokens/sec (check for data loading bottlenecks)
|
| 20 |
+
4) Batch shape: correctness of shape and dtype
|
| 21 |
"""
|
| 22 |
|
| 23 |
@staticmethod
|
|
|
|
| 26 |
config: DataConfig,
|
| 27 |
num_samples: int = 1000,
|
| 28 |
):
|
| 29 |
+
"""Diagnoses tokenizer quality."""
|
| 30 |
from datasets import load_dataset
|
| 31 |
|
| 32 |
print("\n" + "=" * 60)
|
| 33 |
+
print("Tokenizer Quality Diagnostics")
|
| 34 |
print("=" * 60)
|
| 35 |
|
| 36 |
ds = load_dataset(
|
|
|
|
| 59 |
|
| 60 |
avg_tokens = sum(token_counts) / len(token_counts)
|
| 61 |
avg_chars = sum(char_counts) / len(char_counts)
|
| 62 |
+
compression_ratio = avg_chars / avg_tokens # Characters per token ratio
|
| 63 |
|
| 64 |
+
print(f" Documents analyzed: {len(token_counts):,}")
|
| 65 |
+
print(f" Average tokens/document: {avg_tokens:.1f}")
|
| 66 |
+
print(f" Average chars/document: {avg_chars:.1f}")
|
| 67 |
+
print(f" Compression ratio (chars/token): {compression_ratio:.2f}")
|
| 68 |
+
print(f" -> 3.5~4.5 is normal for English")
|
| 69 |
+
print(f" Min tokens: {min(token_counts)}, Max: {max(token_counts)}")
|
| 70 |
|
| 71 |
+
# Round-trip decode test
|
| 72 |
test_text = "The quick brown fox jumps over the lazy dog."
|
| 73 |
encoded = tokenizer.encode(test_text)
|
| 74 |
decoded = tokenizer.decode(encoded)
|
| 75 |
roundtrip_ok = test_text.strip() in decoded.strip()
|
| 76 |
+
print(f"\n Round-trip test: {'PASSED' if roundtrip_ok else 'FAILED'}")
|
| 77 |
+
print(f" Original: {test_text}")
|
| 78 |
+
print(f" Encoded: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
|
| 79 |
+
print(f" Decoded: {decoded}")
|
| 80 |
|
| 81 |
@staticmethod
|
| 82 |
def benchmark_throughput(
|
|
|
|
| 84 |
num_batches: int = 50,
|
| 85 |
seq_len: int = 2048,
|
| 86 |
):
|
| 87 |
+
"""Measures data loading throughput.
|
| 88 |
|
| 89 |
+
A key diagnostic to determine whether data loading is the bottleneck in GPU training.
|
| 90 |
+
Goal: data loading should be faster than GPU computation (data loading != bottleneck).
|
| 91 |
"""
|
| 92 |
print("\n" + "=" * 60)
|
| 93 |
+
print("Data Loading Throughput Benchmark")
|
| 94 |
print("=" * 60)
|
| 95 |
|
| 96 |
total_tokens = 0
|
|
|
|
| 110 |
elapsed = time.time() - start_time
|
| 111 |
tps = total_tokens / elapsed
|
| 112 |
|
| 113 |
+
print(f"\n Total batches: {num_batches}")
|
| 114 |
+
print(f" Total tokens: {total_tokens:,}")
|
| 115 |
+
print(f" Elapsed time: {elapsed:.2f}s")
|
| 116 |
+
print(f" Average throughput: {tps:,.0f} tokens/sec")
|
| 117 |
+
print(f"\n A100 training throughput reference ~50-80K tokens/sec:")
|
| 118 |
if tps > 80_000:
|
| 119 |
+
print(f" Data loading is not the bottleneck")
|
| 120 |
elif tps > 30_000:
|
| 121 |
+
print(f" Borderline - consider increasing num_workers")
|
| 122 |
else:
|
| 123 |
+
print(f" Data loading is the bottleneck! Adjust num_workers/prefetch")
|
| 124 |
|
| 125 |
@staticmethod
|
| 126 |
def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: Tokenizer):
|
| 127 |
+
"""Inspects a single batch in detail."""
|
| 128 |
print("\n" + "=" * 60)
|
| 129 |
+
print("Batch Detailed Inspection")
|
| 130 |
print("=" * 60)
|
| 131 |
|
| 132 |
input_ids = batch["input_ids"]
|
|
|
|
| 135 |
print(f" input_ids shape: {input_ids.shape}")
|
| 136 |
print(f" targets shape: {targets.shape}")
|
| 137 |
print(f" dtype: {input_ids.dtype}")
|
| 138 |
+
print(f" value range: [{input_ids.min().item()}, {input_ids.max().item()}]")
|
| 139 |
|
| 140 |
+
# Verify shift relationship: targets[i] == input_ids[i+1]
|
| 141 |
shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
|
| 142 |
+
print(f" Shift consistency: {shift_correct*100:.1f}% (should be 100%)")
|
| 143 |
|
| 144 |
+
# EOS token distribution (document boundaries)
|
| 145 |
eos_count = (input_ids == tokenizer.eos_id).sum().item()
|
| 146 |
total_tokens = input_ids.numel()
|
| 147 |
+
print(f" EOS token count: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")
|
| 148 |
|
| 149 |
+
# Decode preview of the first sample
|
| 150 |
first_sample = input_ids[0][:100].tolist()
|
| 151 |
decoded_preview = tokenizer.decode(first_sample)
|
| 152 |
+
print(f"\n First sample decoded (first 100 tokens):")
|
| 153 |
print(f" {decoded_preview[:300]}...")
|
llm_lab/data/pipeline.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
from typing import Optional
|
| 4 |
|
|
@@ -15,12 +15,12 @@ def create_train_dataloader(
|
|
| 15 |
config: DataConfig,
|
| 16 |
seed: int = 42,
|
| 17 |
) -> DataLoader:
|
| 18 |
-
"""
|
| 19 |
|
| 20 |
Returns:
|
| 21 |
-
|
| 22 |
|
| 23 |
-
|
| 24 |
dataloader = create_train_dataloader(tokenizer, config)
|
| 25 |
for step, batch in enumerate(dataloader):
|
| 26 |
input_ids = batch["input_ids"].to(device) # (B, seq_len)
|
|
@@ -40,7 +40,7 @@ def create_train_dataloader(
|
|
| 40 |
batch_size=config.batch_size,
|
| 41 |
num_workers=config.num_workers,
|
| 42 |
prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
|
| 43 |
-
pin_memory=True, # GPU
|
| 44 |
collate_fn=_collate_fn,
|
| 45 |
)
|
| 46 |
|
|
@@ -48,17 +48,17 @@ def create_train_dataloader(
|
|
| 48 |
|
| 49 |
|
| 50 |
def train_tokenizer_from_dataset(config: DataConfig) -> Tokenizer:
|
| 51 |
-
"""
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
"""
|
| 56 |
from datasets import load_dataset
|
| 57 |
|
| 58 |
-
print(f"[Train Tokenizer] {config.dataset_name}
|
| 59 |
-
print(f"[Train Tokenizer]
|
| 60 |
|
| 61 |
-
#
|
| 62 |
ds = load_dataset(
|
| 63 |
config.dataset_name,
|
| 64 |
name=config.dataset_subset,
|
|
@@ -77,9 +77,9 @@ def train_tokenizer_from_dataset(config: DataConfig) -> Tokenizer:
|
|
| 77 |
yield text
|
| 78 |
count += 1
|
| 79 |
if count % 10_000 == 0:
|
| 80 |
-
print(f" ... {count:,}
|
| 81 |
|
| 82 |
-
#
|
| 83 |
tokenizer = Tokenizer(config)
|
| 84 |
tokenizer.train_bpe(text_iterator(), save_dir=config.tokenizer_save_dir)
|
| 85 |
|
|
@@ -91,38 +91,38 @@ def setup_data_pipeline(
|
|
| 91 |
tokenizer_path: Optional[str] = None,
|
| 92 |
config: Optional[DataConfig] = None,
|
| 93 |
) -> tuple:
|
| 94 |
-
"""
|
| 95 |
|
| 96 |
Args:
|
| 97 |
tokenizer_mode:
|
| 98 |
-
"train_new" -
|
| 99 |
-
"load_trained" -
|
| 100 |
-
"pretrained" -
|
| 101 |
tokenizer_path:
|
| 102 |
-
"train_new"
|
| 103 |
-
"load_trained"
|
| 104 |
-
"pretrained"
|
| 105 |
|
| 106 |
Returns:
|
| 107 |
(tokenizer, train_dataloader, val_dataloader)
|
| 108 |
|
| 109 |
-
|
| 110 |
-
#
|
| 111 |
tok, train_dl, val_dl = setup_data_pipeline("train_new")
|
| 112 |
|
| 113 |
-
#
|
| 114 |
tok, train_dl, val_dl = setup_data_pipeline("load_trained", "./tokenizer")
|
| 115 |
|
| 116 |
-
#
|
| 117 |
tok, train_dl, val_dl = setup_data_pipeline("pretrained")
|
| 118 |
"""
|
| 119 |
config = config or DataConfig()
|
| 120 |
|
| 121 |
print("=" * 60)
|
| 122 |
-
print("
|
| 123 |
print("=" * 60)
|
| 124 |
|
| 125 |
-
# ── Step 1:
|
| 126 |
tokenizer = Tokenizer(config)
|
| 127 |
|
| 128 |
if tokenizer_mode == "train_new":
|
|
@@ -136,21 +136,21 @@ def setup_data_pipeline(
|
|
| 136 |
else:
|
| 137 |
raise ValueError(f"Unknown tokenizer_mode: {tokenizer_mode}")
|
| 138 |
|
| 139 |
-
# ── Step 2:
|
| 140 |
-
print("\n[DataLoader]
|
| 141 |
train_dataloader = create_train_dataloader(tokenizer, config)
|
| 142 |
|
| 143 |
-
# ── Step 3:
|
| 144 |
-
print("\n[DataLoader]
|
| 145 |
val_dataset = ValidationDataset(tokenizer, config, num_samples=100)
|
| 146 |
val_dataloader = val_dataset.get_dataloader(batch_size=config.batch_size)
|
| 147 |
|
| 148 |
print("\n" + "=" * 60)
|
| 149 |
-
print("
|
| 150 |
-
print(f"
|
| 151 |
-
print(f"
|
| 152 |
-
print(f"
|
| 153 |
-
print(f"
|
| 154 |
print("=" * 60)
|
| 155 |
|
| 156 |
return tokenizer, train_dataloader, val_dataloader
|
|
|
|
| 1 |
+
"""Data pipeline integration — DataLoader creation, tokenizer training, and Quick Start."""
|
| 2 |
|
| 3 |
from typing import Optional
|
| 4 |
|
|
|
|
| 15 |
config: DataConfig,
|
| 16 |
seed: int = 42,
|
| 17 |
) -> DataLoader:
|
| 18 |
+
"""Creates a training DataLoader.
|
| 19 |
|
| 20 |
Returns:
|
| 21 |
+
An infinitely repeating streaming DataLoader
|
| 22 |
|
| 23 |
+
Usage:
|
| 24 |
dataloader = create_train_dataloader(tokenizer, config)
|
| 25 |
for step, batch in enumerate(dataloader):
|
| 26 |
input_ids = batch["input_ids"].to(device) # (B, seq_len)
|
|
|
|
| 40 |
batch_size=config.batch_size,
|
| 41 |
num_workers=config.num_workers,
|
| 42 |
prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
|
| 43 |
+
pin_memory=True, # Improves GPU transfer speed
|
| 44 |
collate_fn=_collate_fn,
|
| 45 |
)
|
| 46 |
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
def train_tokenizer_from_dataset(config: DataConfig) -> Tokenizer:
|
| 51 |
+
"""Trains a BPE tokenizer from the dataset.
|
| 52 |
|
| 53 |
+
There is no need to use the entire dataset; 50K documents is sufficient,
|
| 54 |
+
since the tokenizer vocab only needs to reflect the statistics of the full data.
|
| 55 |
"""
|
| 56 |
from datasets import load_dataset
|
| 57 |
|
| 58 |
+
print(f"[Train Tokenizer] Training tokenizer from {config.dataset_name}")
|
| 59 |
+
print(f"[Train Tokenizer] Number of training documents: {config.tokenizer_train_samples:,}")
|
| 60 |
|
| 61 |
+
# Create text iterator
|
| 62 |
ds = load_dataset(
|
| 63 |
config.dataset_name,
|
| 64 |
name=config.dataset_subset,
|
|
|
|
| 77 |
yield text
|
| 78 |
count += 1
|
| 79 |
if count % 10_000 == 0:
|
| 80 |
+
print(f" ... {count:,} documents processed")
|
| 81 |
|
| 82 |
+
# Train tokenizer
|
| 83 |
tokenizer = Tokenizer(config)
|
| 84 |
tokenizer.train_bpe(text_iterator(), save_dir=config.tokenizer_save_dir)
|
| 85 |
|
|
|
|
| 91 |
tokenizer_path: Optional[str] = None,
|
| 92 |
config: Optional[DataConfig] = None,
|
| 93 |
) -> tuple:
|
| 94 |
+
"""Sets up the data pipeline in one call.
|
| 95 |
|
| 96 |
Args:
|
| 97 |
tokenizer_mode:
|
| 98 |
+
"train_new" - Train a new BPE tokenizer
|
| 99 |
+
"load_trained" - Load a previously trained tokenizer
|
| 100 |
+
"pretrained" - Use a pretrained HuggingFace tokenizer
|
| 101 |
tokenizer_path:
|
| 102 |
+
"train_new" -> Save directory (default: ./tokenizer)
|
| 103 |
+
"load_trained" -> Path to the saved tokenizer
|
| 104 |
+
"pretrained" -> HF model name (default: mistralai/Mistral-7B-v0.1)
|
| 105 |
|
| 106 |
Returns:
|
| 107 |
(tokenizer, train_dataloader, val_dataloader)
|
| 108 |
|
| 109 |
+
Example usage (Colab):
|
| 110 |
+
# Method 1: Train a new tokenizer
|
| 111 |
tok, train_dl, val_dl = setup_data_pipeline("train_new")
|
| 112 |
|
| 113 |
+
# Method 2: Load an existing tokenizer
|
| 114 |
tok, train_dl, val_dl = setup_data_pipeline("load_trained", "./tokenizer")
|
| 115 |
|
| 116 |
+
# Method 3: Use a pretrained tokenizer (simplest)
|
| 117 |
tok, train_dl, val_dl = setup_data_pipeline("pretrained")
|
| 118 |
"""
|
| 119 |
config = config or DataConfig()
|
| 120 |
|
| 121 |
print("=" * 60)
|
| 122 |
+
print("Data Pipeline Setup")
|
| 123 |
print("=" * 60)
|
| 124 |
|
| 125 |
+
# ── Step 1: Tokenizer ──
|
| 126 |
tokenizer = Tokenizer(config)
|
| 127 |
|
| 128 |
if tokenizer_mode == "train_new":
|
|
|
|
| 136 |
else:
|
| 137 |
raise ValueError(f"Unknown tokenizer_mode: {tokenizer_mode}")
|
| 138 |
|
| 139 |
+
# ── Step 2: Training DataLoader ──
|
| 140 |
+
print("\n[DataLoader] Creating training DataLoader...")
|
| 141 |
train_dataloader = create_train_dataloader(tokenizer, config)
|
| 142 |
|
| 143 |
+
# ── Step 3: Validation DataLoader ──
|
| 144 |
+
print("\n[DataLoader] Creating validation DataLoader...")
|
| 145 |
val_dataset = ValidationDataset(tokenizer, config, num_samples=100)
|
| 146 |
val_dataloader = val_dataset.get_dataloader(batch_size=config.batch_size)
|
| 147 |
|
| 148 |
print("\n" + "=" * 60)
|
| 149 |
+
print("Data pipeline setup complete!")
|
| 150 |
+
print(f" Tokenizer vocab: {tokenizer.vocab_size:,}")
|
| 151 |
+
print(f" Sequence length: {config.max_seq_len}")
|
| 152 |
+
print(f" Batch size: {config.batch_size}")
|
| 153 |
+
print(f" Tokens/batch: {config.batch_size * config.max_seq_len:,}")
|
| 154 |
print("=" * 60)
|
| 155 |
|
| 156 |
return tokenizer, train_dataloader, val_dataloader
|
llm_lab/data/tokenizer.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
import os
|
| 4 |
import json
|
|
@@ -8,23 +8,23 @@ from llm_lab.config import DataConfig
|
|
| 8 |
|
| 9 |
|
| 10 |
class Tokenizer:
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
1)
|
| 15 |
-
2)
|
| 16 |
-
3)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
- BPE
|
| 20 |
-
|
| 21 |
-
-
|
| 22 |
-
|
| 23 |
-
BPE(Byte Pair Encoding)
|
| 24 |
-
1)
|
| 25 |
-
2)
|
| 26 |
-
3)
|
| 27 |
-
→
|
| 28 |
"""
|
| 29 |
|
| 30 |
def __init__(self, config: DataConfig):
|
|
@@ -32,17 +32,17 @@ class Tokenizer:
|
|
| 32 |
self._tokenizer = None
|
| 33 |
self.vocab_size = config.vocab_size
|
| 34 |
|
| 35 |
-
#
|
| 36 |
self.bos_id: int = 1 # Beginning of Sequence
|
| 37 |
self.eos_id: int = 2 # End of Sequence
|
| 38 |
self.pad_id: int = 0 # Padding
|
| 39 |
|
| 40 |
# ────────────────────────────────────────────────
|
| 41 |
-
#
|
| 42 |
# ────────────────────────────────────────────────
|
| 43 |
|
| 44 |
def load_sentencepiece(self, model_path: str):
|
| 45 |
-
"""
|
| 46 |
import sentencepiece as spm
|
| 47 |
|
| 48 |
self._tokenizer = spm.SentencePieceProcessor()
|
|
@@ -55,23 +55,23 @@ class Tokenizer:
|
|
| 55 |
self._encode_fn = self._tokenizer.Encode
|
| 56 |
self._decode_fn = self._tokenizer.Decode
|
| 57 |
|
| 58 |
-
print(f"[Tokenizer] SentencePiece
|
| 59 |
|
| 60 |
# ────────────────────────────────────────────────
|
| 61 |
-
#
|
| 62 |
# ────────────────────────────────────────────────
|
| 63 |
|
| 64 |
def train_bpe(self, text_iterator: Iterator[str], save_dir: Optional[str] = None):
|
| 65 |
-
"""BPE
|
| 66 |
|
| 67 |
Args:
|
| 68 |
-
text_iterator:
|
| 69 |
-
save_dir:
|
| 70 |
|
| 71 |
-
|
| 72 |
-
-
|
| 73 |
-
-
|
| 74 |
-
- 32K
|
| 75 |
"""
|
| 76 |
from tokenizers import Tokenizer as HFTokenizer
|
| 77 |
from tokenizers.models import BPE
|
|
@@ -79,27 +79,27 @@ class Tokenizer:
|
|
| 79 |
from tokenizers.pre_tokenizers import ByteLevel
|
| 80 |
from tokenizers.processors import TemplateProcessing
|
| 81 |
|
| 82 |
-
print("[Tokenizer] BPE
|
| 83 |
|
| 84 |
-
# BPE
|
| 85 |
tokenizer = HFTokenizer(BPE(unk_token="<unk>"))
|
| 86 |
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
|
| 87 |
|
| 88 |
-
#
|
| 89 |
special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
|
| 90 |
|
| 91 |
-
#
|
| 92 |
trainer = BpeTrainer(
|
| 93 |
vocab_size=self.config.vocab_size,
|
| 94 |
special_tokens=special_tokens,
|
| 95 |
-
min_frequency=2, #
|
| 96 |
show_progress=True,
|
| 97 |
)
|
| 98 |
|
| 99 |
-
#
|
| 100 |
tokenizer.train_from_iterator(text_iterator, trainer=trainer)
|
| 101 |
|
| 102 |
-
#
|
| 103 |
tokenizer.post_processor = TemplateProcessing(
|
| 104 |
single="<s> $A </s>",
|
| 105 |
special_tokens=[("<s>", 1), ("</s>", 2)],
|
|
@@ -114,11 +114,11 @@ class Tokenizer:
|
|
| 114 |
self._encode_fn = lambda text: tokenizer.encode(text).ids
|
| 115 |
self._decode_fn = lambda ids: tokenizer.decode(ids)
|
| 116 |
|
| 117 |
-
#
|
| 118 |
save_dir = save_dir or self.config.tokenizer_save_dir
|
| 119 |
os.makedirs(save_dir, exist_ok=True)
|
| 120 |
tokenizer.save(os.path.join(save_dir, "tokenizer.json"))
|
| 121 |
-
#
|
| 122 |
meta = {
|
| 123 |
"vocab_size": self.vocab_size,
|
| 124 |
"bos_id": self.bos_id,
|
|
@@ -128,23 +128,23 @@ class Tokenizer:
|
|
| 128 |
with open(os.path.join(save_dir, "tokenizer_meta.json"), "w") as f:
|
| 129 |
json.dump(meta, f, indent=2)
|
| 130 |
|
| 131 |
-
print(f"[Tokenizer]
|
| 132 |
-
print(f"[Tokenizer]
|
| 133 |
|
| 134 |
# ────────────────────────────────────────────────
|
| 135 |
-
#
|
| 136 |
# ────────────────────────────────────────────────
|
| 137 |
|
| 138 |
def load_pretrained_hf(self, name_or_path: str = "meta-llama/Llama-2-7b-hf"):
|
| 139 |
-
"""
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
"""
|
| 145 |
from transformers import AutoTokenizer
|
| 146 |
|
| 147 |
-
print(f"[Tokenizer] HF
|
| 148 |
tokenizer = AutoTokenizer.from_pretrained(name_or_path)
|
| 149 |
|
| 150 |
self._tokenizer = tokenizer
|
|
@@ -156,10 +156,10 @@ class Tokenizer:
|
|
| 156 |
self._encode_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
|
| 157 |
self._decode_fn = lambda ids: tokenizer.decode(ids)
|
| 158 |
|
| 159 |
-
print(f"[Tokenizer]
|
| 160 |
|
| 161 |
def load_trained_hf(self, path: str):
|
| 162 |
-
"""
|
| 163 |
from tokenizers import Tokenizer as HFTokenizer
|
| 164 |
|
| 165 |
tokenizer = HFTokenizer.from_file(os.path.join(path, "tokenizer.json"))
|
|
@@ -175,21 +175,21 @@ class Tokenizer:
|
|
| 175 |
self._encode_fn = lambda text: tokenizer.encode(text).ids
|
| 176 |
self._decode_fn = lambda ids: tokenizer.decode(ids)
|
| 177 |
|
| 178 |
-
print(f"[Tokenizer]
|
| 179 |
|
| 180 |
# ────────────────────────────────────────────────
|
| 181 |
-
#
|
| 182 |
# ────────────────────────────────────────────────
|
| 183 |
|
| 184 |
def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
|
| 185 |
-
"""
|
| 186 |
ids = self._encode_fn(text)
|
| 187 |
if add_special_tokens:
|
| 188 |
ids = [self.bos_id] + ids + [self.eos_id]
|
| 189 |
return ids
|
| 190 |
|
| 191 |
def decode(self, ids: List[int]) -> str:
|
| 192 |
-
"""
|
| 193 |
return self._decode_fn(ids)
|
| 194 |
|
| 195 |
def __len__(self) -> int:
|
|
|
|
| 1 |
+
"""Tokenizer wrapper — SentencePiece / HuggingFace BPE integration."""
|
| 2 |
|
| 3 |
import os
|
| 4 |
import json
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class Tokenizer:
|
| 11 |
+
"""Unified tokenizer wrapper.
|
| 12 |
+
|
| 13 |
+
Supports three methods:
|
| 14 |
+
1) Load an existing SentencePiece model
|
| 15 |
+
2) Train a new tokenizer using the HuggingFace tokenizers library
|
| 16 |
+
3) Load a pretrained HF tokenizer (e.g., LLaMA tokenizer)
|
| 17 |
+
|
| 18 |
+
Why not implement from scratch?
|
| 19 |
+
- Training a BPE tokenizer involves large-scale text statistics processing,
|
| 20 |
+
which has little direct relevance to understanding model architecture.
|
| 21 |
+
- However, understanding how a tokenizer works (BPE merge rules) is still important.
|
| 22 |
+
|
| 23 |
+
BPE (Byte Pair Encoding) core principle:
|
| 24 |
+
1) Split text into byte/character units
|
| 25 |
+
2) Repeatedly merge the most frequent adjacent pair
|
| 26 |
+
3) Repeat until vocab_size is reached
|
| 27 |
+
→ Frequent words become a single token; rare words are split into multiple tokens
|
| 28 |
"""
|
| 29 |
|
| 30 |
def __init__(self, config: DataConfig):
|
|
|
|
| 32 |
self._tokenizer = None
|
| 33 |
self.vocab_size = config.vocab_size
|
| 34 |
|
| 35 |
+
# Special token IDs (set after initialization)
|
| 36 |
self.bos_id: int = 1 # Beginning of Sequence
|
| 37 |
self.eos_id: int = 2 # End of Sequence
|
| 38 |
self.pad_id: int = 0 # Padding
|
| 39 |
|
| 40 |
# ────────────────────────────────────────────────
|
| 41 |
+
# Method 1: Load a SentencePiece model
|
| 42 |
# ────────────────────────────────────────────────
|
| 43 |
|
| 44 |
def load_sentencepiece(self, model_path: str):
|
| 45 |
+
"""Loads an existing SentencePiece model."""
|
| 46 |
import sentencepiece as spm
|
| 47 |
|
| 48 |
self._tokenizer = spm.SentencePieceProcessor()
|
|
|
|
| 55 |
self._encode_fn = self._tokenizer.Encode
|
| 56 |
self._decode_fn = self._tokenizer.Decode
|
| 57 |
|
| 58 |
+
print(f"[Tokenizer] SentencePiece loaded: vocab_size={self.vocab_size}")
|
| 59 |
|
| 60 |
# ────────────────────────────────────────────────
|
| 61 |
+
# Method 2: Train a BPE tokenizer with HuggingFace tokenizers
|
| 62 |
# ────────────────────────────────────────────────
|
| 63 |
|
| 64 |
def train_bpe(self, text_iterator: Iterator[str], save_dir: Optional[str] = None):
|
| 65 |
+
"""Trains a BPE tokenizer from scratch.
|
| 66 |
|
| 67 |
Args:
|
| 68 |
+
text_iterator: Iterator that yields training text strings
|
| 69 |
+
save_dir: Directory path to save the trained tokenizer
|
| 70 |
|
| 71 |
+
Key insights:
|
| 72 |
+
- Larger vocab_size: common expressions become 1 token → shorter sequences
|
| 73 |
+
- Smaller vocab_size: saves embedding parameters, but sequences get longer
|
| 74 |
+
- 32K is a good balance point for English
|
| 75 |
"""
|
| 76 |
from tokenizers import Tokenizer as HFTokenizer
|
| 77 |
from tokenizers.models import BPE
|
|
|
|
| 79 |
from tokenizers.pre_tokenizers import ByteLevel
|
| 80 |
from tokenizers.processors import TemplateProcessing
|
| 81 |
|
| 82 |
+
print("[Tokenizer] Starting BPE tokenizer training...")
|
| 83 |
|
| 84 |
+
# Create BPE model
|
| 85 |
tokenizer = HFTokenizer(BPE(unk_token="<unk>"))
|
| 86 |
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
|
| 87 |
|
| 88 |
+
# Define special tokens
|
| 89 |
special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
|
| 90 |
|
| 91 |
+
# Configure trainer
|
| 92 |
trainer = BpeTrainer(
|
| 93 |
vocab_size=self.config.vocab_size,
|
| 94 |
special_tokens=special_tokens,
|
| 95 |
+
min_frequency=2, # Only merge pairs that appear at least twice
|
| 96 |
show_progress=True,
|
| 97 |
)
|
| 98 |
|
| 99 |
+
# Run training
|
| 100 |
tokenizer.train_from_iterator(text_iterator, trainer=trainer)
|
| 101 |
|
| 102 |
+
# Post-processing: automatically add BOS/EOS
|
| 103 |
tokenizer.post_processor = TemplateProcessing(
|
| 104 |
single="<s> $A </s>",
|
| 105 |
special_tokens=[("<s>", 1), ("</s>", 2)],
|
|
|
|
| 114 |
self._encode_fn = lambda text: tokenizer.encode(text).ids
|
| 115 |
self._decode_fn = lambda ids: tokenizer.decode(ids)
|
| 116 |
|
| 117 |
+
# Save
|
| 118 |
save_dir = save_dir or self.config.tokenizer_save_dir
|
| 119 |
os.makedirs(save_dir, exist_ok=True)
|
| 120 |
tokenizer.save(os.path.join(save_dir, "tokenizer.json"))
|
| 121 |
+
# Save metadata
|
| 122 |
meta = {
|
| 123 |
"vocab_size": self.vocab_size,
|
| 124 |
"bos_id": self.bos_id,
|
|
|
|
| 128 |
with open(os.path.join(save_dir, "tokenizer_meta.json"), "w") as f:
|
| 129 |
json.dump(meta, f, indent=2)
|
| 130 |
|
| 131 |
+
print(f"[Tokenizer] Training complete: vocab_size={self.vocab_size}")
|
| 132 |
+
print(f"[Tokenizer] Saved to: {save_dir}")
|
| 133 |
|
| 134 |
# ────────────────────────────────────────────────
|
| 135 |
+
# Method 3: Load a pretrained HF tokenizer
|
| 136 |
# ────────────────────────────────────────────────
|
| 137 |
|
| 138 |
def load_pretrained_hf(self, name_or_path: str = "meta-llama/Llama-2-7b-hf"):
|
| 139 |
+
"""Loads a pretrained tokenizer from HuggingFace.
|
| 140 |
|
| 141 |
+
The simplest method. The LLaMA tokenizer has a 32K vocab and is BPE-based.
|
| 142 |
+
Note: meta-llama models may require HF approval to access.
|
| 143 |
+
Alternative: mistralai/Mistral-7B-v0.1 (no approval required)
|
| 144 |
"""
|
| 145 |
from transformers import AutoTokenizer
|
| 146 |
|
| 147 |
+
print(f"[Tokenizer] Loading HF tokenizer: {name_or_path}")
|
| 148 |
tokenizer = AutoTokenizer.from_pretrained(name_or_path)
|
| 149 |
|
| 150 |
self._tokenizer = tokenizer
|
|
|
|
| 156 |
self._encode_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
|
| 157 |
self._decode_fn = lambda ids: tokenizer.decode(ids)
|
| 158 |
|
| 159 |
+
print(f"[Tokenizer] Loaded: vocab_size={self.vocab_size}")
|
| 160 |
|
| 161 |
def load_trained_hf(self, path: str):
|
| 162 |
+
"""Reloads a tokenizer previously trained with train_bpe()."""
|
| 163 |
from tokenizers import Tokenizer as HFTokenizer
|
| 164 |
|
| 165 |
tokenizer = HFTokenizer.from_file(os.path.join(path, "tokenizer.json"))
|
|
|
|
| 175 |
self._encode_fn = lambda text: tokenizer.encode(text).ids
|
| 176 |
self._decode_fn = lambda ids: tokenizer.decode(ids)
|
| 177 |
|
| 178 |
+
print(f"[Tokenizer] Loaded: vocab_size={self.vocab_size}")
|
| 179 |
|
| 180 |
# ────────────────────────────────────────────────
|
| 181 |
+
# Common interface
|
| 182 |
# ────────────────────────────────────────────────
|
| 183 |
|
| 184 |
def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
|
| 185 |
+
"""Text → list of token IDs."""
|
| 186 |
ids = self._encode_fn(text)
|
| 187 |
if add_special_tokens:
|
| 188 |
ids = [self.bos_id] + ids + [self.eos_id]
|
| 189 |
return ids
|
| 190 |
|
| 191 |
def decode(self, ids: List[int]) -> str:
|
| 192 |
+
"""List of token IDs → text."""
|
| 193 |
return self._decode_fn(ids)
|
| 194 |
|
| 195 |
def __len__(self) -> int:
|
llm_lab/evaluation/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
from .perplexity import PerplexityEvaluator
|
| 4 |
from .generation import GenerationEvaluator
|
|
|
|
| 1 |
+
"""Evaluation module — Perplexity, text generation, Scaling Law, Attention visualization."""
|
| 2 |
|
| 3 |
from .perplexity import PerplexityEvaluator
|
| 4 |
from .generation import GenerationEvaluator
|
llm_lab/evaluation/attention_viz.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Attention
|
| 2 |
|
| 3 |
import math
|
| 4 |
from pathlib import Path
|
|
@@ -18,15 +18,15 @@ except ImportError:
|
|
| 18 |
|
| 19 |
|
| 20 |
class AttentionVisualizer:
|
| 21 |
-
"""
|
| 22 |
|
| 23 |
-
|
| 24 |
-
- Causal Mask:
|
| 25 |
-
-
|
| 26 |
-
-
|
| 27 |
|
| 28 |
-
|
| 29 |
-
→
|
| 30 |
"""
|
| 31 |
|
| 32 |
def __init__(self, save_dir: str = "./eval_results"):
|
|
@@ -41,10 +41,10 @@ class AttentionVisualizer:
|
|
| 41 |
layer_idx: int = 0,
|
| 42 |
device: torch.device = torch.device("cpu"),
|
| 43 |
) -> torch.Tensor:
|
| 44 |
-
"""
|
| 45 |
|
| 46 |
-
|
| 47 |
-
attention
|
| 48 |
|
| 49 |
Returns:
|
| 50 |
attention_weights: (num_heads, seq_len, seq_len)
|
|
@@ -52,10 +52,10 @@ class AttentionVisualizer:
|
|
| 52 |
model.eval()
|
| 53 |
captured_attn = {}
|
| 54 |
|
| 55 |
-
#
|
| 56 |
target_layer = model.layers[layer_idx].attention
|
| 57 |
|
| 58 |
-
# scaled_dot_product_attention
|
| 59 |
original_forward = target_layer.forward
|
| 60 |
|
| 61 |
def hooked_forward(x, mask=None, position_offset=0):
|
|
@@ -72,7 +72,7 @@ class AttentionVisualizer:
|
|
| 72 |
k = target_layer._repeat_kv(k)
|
| 73 |
v = target_layer._repeat_kv(v)
|
| 74 |
|
| 75 |
-
#
|
| 76 |
scale = 1.0 / math.sqrt(hd)
|
| 77 |
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
|
| 78 |
|
|
@@ -81,13 +81,13 @@ class AttentionVisualizer:
|
|
| 81 |
scores.masked_fill_(causal.unsqueeze(0).unsqueeze(0), float("-inf"))
|
| 82 |
|
| 83 |
attn_weights = F.softmax(scores, dim=-1)
|
| 84 |
-
captured_attn["weights"] = attn_weights[0].cpu() #
|
| 85 |
|
| 86 |
out = torch.matmul(attn_weights, v)
|
| 87 |
out = out.transpose(1, 2).contiguous().view(B, S, -1)
|
| 88 |
return target_layer.o_proj(out)
|
| 89 |
|
| 90 |
-
#
|
| 91 |
target_layer.forward = hooked_forward
|
| 92 |
|
| 93 |
try:
|
|
@@ -105,13 +105,13 @@ class AttentionVisualizer:
|
|
| 105 |
save_path: Optional[str] = None,
|
| 106 |
title: str = "Attention Weights",
|
| 107 |
):
|
| 108 |
-
"""
|
| 109 |
if not HAS_MATPLOTLIB:
|
| 110 |
-
print("⚠️ matplotlib
|
| 111 |
return
|
| 112 |
|
| 113 |
weights = attn_weights[head_idx].numpy()
|
| 114 |
-
max_len = min(len(tokens), 50) #
|
| 115 |
weights = weights[:max_len, :max_len]
|
| 116 |
display_tokens = tokens[:max_len]
|
| 117 |
|
|
@@ -132,7 +132,7 @@ class AttentionVisualizer:
|
|
| 132 |
|
| 133 |
save_path = save_path or str(self.save_dir / f"attention_head{head_idx}.png")
|
| 134 |
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 135 |
-
print(f" 📊 Attention
|
| 136 |
plt.close(fig)
|
| 137 |
|
| 138 |
def plot_multi_head_summary(
|
|
@@ -141,7 +141,7 @@ class AttentionVisualizer:
|
|
| 141 |
num_heads_to_show: int = 8,
|
| 142 |
save_path: Optional[str] = None,
|
| 143 |
):
|
| 144 |
-
"""
|
| 145 |
if not HAS_MATPLOTLIB:
|
| 146 |
return
|
| 147 |
|
|
@@ -162,7 +162,7 @@ class AttentionVisualizer:
|
|
| 162 |
ax.set_xticks([])
|
| 163 |
ax.set_yticks([])
|
| 164 |
|
| 165 |
-
#
|
| 166 |
for idx in range(n_heads, rows * cols):
|
| 167 |
r, c = idx // cols, idx % cols
|
| 168 |
axes[r, c].axis("off")
|
|
@@ -172,5 +172,5 @@ class AttentionVisualizer:
|
|
| 172 |
|
| 173 |
save_path = save_path or str(self.save_dir / "attention_multi_head.png")
|
| 174 |
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 175 |
-
print(f" 📊
|
| 176 |
plt.close(fig)
|
|
|
|
| 1 |
+
"""Attention pattern visualization."""
|
| 2 |
|
| 3 |
import math
|
| 4 |
from pathlib import Path
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
class AttentionVisualizer:
|
| 21 |
+
"""Visualizes attention patterns.
|
| 22 |
|
| 23 |
+
Learning insights:
|
| 24 |
+
- Causal Mask: lower-triangular pattern (future tokens cannot be attended to)
|
| 25 |
+
- Head specialization: some heads focus locally (adjacent), others globally (distant tokens)
|
| 26 |
+
- Syntactic patterns: high attention on verb→subject, pronoun→antecedent, etc.
|
| 27 |
|
| 28 |
+
Note: Storing the full attention of a 1B model causes out-of-memory!
|
| 29 |
+
→ Visualize only selected layers/heads.
|
| 30 |
"""
|
| 31 |
|
| 32 |
def __init__(self, save_dir: str = "./eval_results"):
|
|
|
|
| 41 |
layer_idx: int = 0,
|
| 42 |
device: torch.device = torch.device("cpu"),
|
| 43 |
) -> torch.Tensor:
|
| 44 |
+
"""Extracts attention weights from a specific layer.
|
| 45 |
|
| 46 |
+
Temporarily modifies the model's attention module to
|
| 47 |
+
capture attention weights.
|
| 48 |
|
| 49 |
Returns:
|
| 50 |
attention_weights: (num_heads, seq_len, seq_len)
|
|
|
|
| 52 |
model.eval()
|
| 53 |
captured_attn = {}
|
| 54 |
|
| 55 |
+
# Capture attention weights via hook
|
| 56 |
target_layer = model.layers[layer_idx].attention
|
| 57 |
|
| 58 |
+
# Replace scaled_dot_product_attention with a manual implementation
|
| 59 |
original_forward = target_layer.forward
|
| 60 |
|
| 61 |
def hooked_forward(x, mask=None, position_offset=0):
|
|
|
|
| 72 |
k = target_layer._repeat_kv(k)
|
| 73 |
v = target_layer._repeat_kv(v)
|
| 74 |
|
| 75 |
+
# Manual attention computation (for weight extraction)
|
| 76 |
scale = 1.0 / math.sqrt(hd)
|
| 77 |
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
|
| 78 |
|
|
|
|
| 81 |
scores.masked_fill_(causal.unsqueeze(0).unsqueeze(0), float("-inf"))
|
| 82 |
|
| 83 |
attn_weights = F.softmax(scores, dim=-1)
|
| 84 |
+
captured_attn["weights"] = attn_weights[0].cpu() # first batch only
|
| 85 |
|
| 86 |
out = torch.matmul(attn_weights, v)
|
| 87 |
out = out.transpose(1, 2).contiguous().view(B, S, -1)
|
| 88 |
return target_layer.o_proj(out)
|
| 89 |
|
| 90 |
+
# Apply hook
|
| 91 |
target_layer.forward = hooked_forward
|
| 92 |
|
| 93 |
try:
|
|
|
|
| 105 |
save_path: Optional[str] = None,
|
| 106 |
title: str = "Attention Weights",
|
| 107 |
):
|
| 108 |
+
"""Draws an attention heatmap."""
|
| 109 |
if not HAS_MATPLOTLIB:
|
| 110 |
+
print("⚠️ matplotlib required")
|
| 111 |
return
|
| 112 |
|
| 113 |
weights = attn_weights[head_idx].numpy()
|
| 114 |
+
max_len = min(len(tokens), 50) # display at most 50 tokens
|
| 115 |
weights = weights[:max_len, :max_len]
|
| 116 |
display_tokens = tokens[:max_len]
|
| 117 |
|
|
|
|
| 132 |
|
| 133 |
save_path = save_path or str(self.save_dir / f"attention_head{head_idx}.png")
|
| 134 |
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 135 |
+
print(f" 📊 Attention visualization saved: {save_path}")
|
| 136 |
plt.close(fig)
|
| 137 |
|
| 138 |
def plot_multi_head_summary(
|
|
|
|
| 141 |
num_heads_to_show: int = 8,
|
| 142 |
save_path: Optional[str] = None,
|
| 143 |
):
|
| 144 |
+
"""Summarizes and compares attention patterns across multiple heads."""
|
| 145 |
if not HAS_MATPLOTLIB:
|
| 146 |
return
|
| 147 |
|
|
|
|
| 162 |
ax.set_xticks([])
|
| 163 |
ax.set_yticks([])
|
| 164 |
|
| 165 |
+
# Hide empty subplots
|
| 166 |
for idx in range(n_heads, rows * cols):
|
| 167 |
r, c = idx // cols, idx % cols
|
| 168 |
axes[r, c].axis("off")
|
|
|
|
| 172 |
|
| 173 |
save_path = save_path or str(self.save_dir / "attention_multi_head.png")
|
| 174 |
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 175 |
+
print(f" 📊 Multi-head summary saved: {save_path}")
|
| 176 |
plt.close(fig)
|
llm_lab/evaluation/checklist.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
from typing import Any, Dict, Optional
|
| 4 |
|
| 5 |
|
| 6 |
class InsightChecklist:
|
| 7 |
-
"""
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
"""
|
| 12 |
|
| 13 |
@staticmethod
|
|
@@ -15,9 +15,9 @@ class InsightChecklist:
|
|
| 15 |
report: Dict[str, Any],
|
| 16 |
metrics_history: Optional[Dict[str, list]] = None,
|
| 17 |
):
|
| 18 |
-
"""
|
| 19 |
print("\n" + "=" * 70)
|
| 20 |
-
print("✅
|
| 21 |
print("=" * 70)
|
| 22 |
|
| 23 |
checks = {
|
|
@@ -26,74 +26,74 @@ class InsightChecklist:
|
|
| 26 |
"manual": [],
|
| 27 |
}
|
| 28 |
|
| 29 |
-
# ──
|
| 30 |
|
| 31 |
-
# 1. Loss
|
| 32 |
if report.get("perplexity", {}).get("loss", 99) < 4.0:
|
| 33 |
-
checks["passed"].append("
|
| 34 |
else:
|
| 35 |
-
checks["failed"].append("
|
| 36 |
|
| 37 |
-
# 2. Loss
|
| 38 |
spikes = report.get("training_dynamics", {}).get("loss", {}).get("spikes", [])
|
| 39 |
if len(spikes) < 5:
|
| 40 |
-
checks["passed"].append(f"Loss
|
| 41 |
else:
|
| 42 |
-
checks["failed"].append(f"Loss
|
| 43 |
|
| 44 |
-
# 3.
|
| 45 |
if report.get("position_losses"):
|
| 46 |
early = report["position_losses"]["early_avg"]
|
| 47 |
late = report["position_losses"]["late_avg"]
|
| 48 |
if early > late:
|
| 49 |
-
checks["passed"].append("
|
| 50 |
else:
|
| 51 |
-
checks["failed"].append("
|
| 52 |
|
| 53 |
-
# 4.
|
| 54 |
rep = report.get("generation", {}).get("avg_metrics", {}).get("repetition_rate", 1.0)
|
| 55 |
if rep < 0.3:
|
| 56 |
-
checks["passed"].append(f"
|
| 57 |
else:
|
| 58 |
-
checks["failed"].append(f"
|
| 59 |
|
| 60 |
-
# 5. Gradient
|
| 61 |
if metrics_history and metrics_history.get("grad_norm"):
|
| 62 |
gnorms = metrics_history["grad_norm"]
|
| 63 |
clip_rate = sum(1 for g in gnorms if g >= 0.99) / max(len(gnorms), 1)
|
| 64 |
if clip_rate < 0.3:
|
| 65 |
-
checks["passed"].append(f"Gradient
|
| 66 |
else:
|
| 67 |
-
checks["failed"].append(f"Gradient
|
| 68 |
|
| 69 |
-
# ──
|
| 70 |
manual_items = [
|
| 71 |
-
"
|
| 72 |
-
"
|
| 73 |
-
"
|
| 74 |
-
"SwiGLU
|
| 75 |
-
"Learning Rate Warmup
|
| 76 |
-
"Gradient Accumulation
|
| 77 |
-
"Mixed Precision(bf16)
|
| 78 |
-
"
|
| 79 |
]
|
| 80 |
checks["manual"] = manual_items
|
| 81 |
|
| 82 |
-
# ──
|
| 83 |
total_auto = len(checks["passed"]) + len(checks["failed"])
|
| 84 |
passed_auto = len(checks["passed"])
|
| 85 |
|
| 86 |
-
print(f"\n
|
| 87 |
for item in checks["passed"]:
|
| 88 |
print(f" ✅ {item}")
|
| 89 |
for item in checks["failed"]:
|
| 90 |
print(f" ❌ {item}")
|
| 91 |
|
| 92 |
-
print(f"\n
|
| 93 |
for i, item in enumerate(manual_items, 1):
|
| 94 |
print(f" {i}. [ ] {item}")
|
| 95 |
|
| 96 |
-
print(f"\n
|
| 97 |
-
f"(
|
| 98 |
|
| 99 |
return checks
|
|
|
|
| 1 |
+
"""Training insight checklist validator."""
|
| 2 |
|
| 3 |
from typing import Any, Dict, Optional
|
| 4 |
|
| 5 |
|
| 6 |
class InsightChecklist:
|
| 7 |
+
"""Automatically and manually validates the training insight checklist defined in the PRD.
|
| 8 |
|
| 9 |
+
Items that can be automatically validated are judged based on metrics,
|
| 10 |
+
while manual items are presented as questions.
|
| 11 |
"""
|
| 12 |
|
| 13 |
@staticmethod
|
|
|
|
| 15 |
report: Dict[str, Any],
|
| 16 |
metrics_history: Optional[Dict[str, list]] = None,
|
| 17 |
):
|
| 18 |
+
"""Runs the checklist."""
|
| 19 |
print("\n" + "=" * 70)
|
| 20 |
+
print("✅ Training Insight Checklist")
|
| 21 |
print("=" * 70)
|
| 22 |
|
| 23 |
checks = {
|
|
|
|
| 26 |
"manual": [],
|
| 27 |
}
|
| 28 |
|
| 29 |
+
# ── Automatic validation ──
|
| 30 |
|
| 31 |
+
# 1. Loss convergence
|
| 32 |
if report.get("perplexity", {}).get("loss", 99) < 4.0:
|
| 33 |
+
checks["passed"].append("Model Loss converged below 4.0")
|
| 34 |
else:
|
| 35 |
+
checks["failed"].append("Model Loss has not converged below 4.0")
|
| 36 |
|
| 37 |
+
# 2. Loss spikes
|
| 38 |
spikes = report.get("training_dynamics", {}).get("loss", {}).get("spikes", [])
|
| 39 |
if len(spikes) < 5:
|
| 40 |
+
checks["passed"].append(f"Loss spikes: {len(spikes)} (< 5)")
|
| 41 |
else:
|
| 42 |
+
checks["failed"].append(f"Loss spikes: {len(spikes)} (>= 5, stability improvement needed)")
|
| 43 |
|
| 44 |
+
# 3. Per-position loss pattern
|
| 45 |
if report.get("position_losses"):
|
| 46 |
early = report["position_losses"]["early_avg"]
|
| 47 |
late = report["position_losses"]["late_avg"]
|
| 48 |
if early > late:
|
| 49 |
+
checks["passed"].append("Per-position loss decrease pattern confirmed (context utilization)")
|
| 50 |
else:
|
| 51 |
+
checks["failed"].append("Per-position loss pattern abnormal (context not utilized?)")
|
| 52 |
|
| 53 |
+
# 4. Generation repetition rate
|
| 54 |
rep = report.get("generation", {}).get("avg_metrics", {}).get("repetition_rate", 1.0)
|
| 55 |
if rep < 0.3:
|
| 56 |
+
checks["passed"].append(f"Generation repetition rate {rep:.1%} (< 30%)")
|
| 57 |
else:
|
| 58 |
+
checks["failed"].append(f"Generation repetition rate {rep:.1%} (>= 30%, adjust temperature/top_p)")
|
| 59 |
|
| 60 |
+
# 5. Gradient clipping rate
|
| 61 |
if metrics_history and metrics_history.get("grad_norm"):
|
| 62 |
gnorms = metrics_history["grad_norm"]
|
| 63 |
clip_rate = sum(1 for g in gnorms if g >= 0.99) / max(len(gnorms), 1)
|
| 64 |
if clip_rate < 0.3:
|
| 65 |
+
checks["passed"].append(f"Gradient clipping rate {clip_rate:.1%} (healthy)")
|
| 66 |
else:
|
| 67 |
+
checks["failed"].append(f"Gradient clipping rate {clip_rate:.1%} (too frequent)")
|
| 68 |
|
| 69 |
+
# ── Manual verification items ──
|
| 70 |
manual_items = [
|
| 71 |
+
"Can you explain the individual roles of Q, K, and V in Self-Attention?",
|
| 72 |
+
"Do you understand the mathematical principle by which RoPE encodes positional information?",
|
| 73 |
+
"Can you explain the mechanism by which GQA saves memory compared to MHA?",
|
| 74 |
+
"Do you understand how SwiGLU's gating mechanism differs from a ReLU FFN?",
|
| 75 |
+
"Did you experience why Learning Rate Warmup is necessary?",
|
| 76 |
+
"Do you understand the principle by which Gradient Accumulation simulates a large batch?",
|
| 77 |
+
"Have you measured the memory-speed effect of Mixed Precision (bf16)?",
|
| 78 |
+
"Do you understand the memory-compute trade-off of Activation Checkpointing?",
|
| 79 |
]
|
| 80 |
checks["manual"] = manual_items
|
| 81 |
|
| 82 |
+
# ── Output ──
|
| 83 |
total_auto = len(checks["passed"]) + len(checks["failed"])
|
| 84 |
passed_auto = len(checks["passed"])
|
| 85 |
|
| 86 |
+
print(f"\n Automatic validation: {passed_auto}/{total_auto} passed")
|
| 87 |
for item in checks["passed"]:
|
| 88 |
print(f" ✅ {item}")
|
| 89 |
for item in checks["failed"]:
|
| 90 |
print(f" ❌ {item}")
|
| 91 |
|
| 92 |
+
print(f"\n Manual verification ({len(manual_items)} items):")
|
| 93 |
for i, item in enumerate(manual_items, 1):
|
| 94 |
print(f" {i}. [ ] {item}")
|
| 95 |
|
| 96 |
+
print(f"\n Total progress: {passed_auto}/{total_auto + len(manual_items)} "
|
| 97 |
+
f"(including manual items)")
|
| 98 |
|
| 99 |
return checks
|
llm_lab/evaluation/dynamics.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
import math
|
| 4 |
from pathlib import Path
|
|
@@ -14,13 +14,13 @@ except ImportError:
|
|
| 14 |
|
| 15 |
|
| 16 |
class TrainingDynamicsAnalyzer:
|
| 17 |
-
"""
|
| 18 |
|
| 19 |
-
|
| 20 |
-
- Loss
|
| 21 |
-
- LR
|
| 22 |
-
- Gradient Norm:
|
| 23 |
-
-
|
| 24 |
"""
|
| 25 |
|
| 26 |
def __init__(self, save_dir: str = "./eval_results"):
|
|
@@ -28,21 +28,21 @@ class TrainingDynamicsAnalyzer:
|
|
| 28 |
self.save_dir.mkdir(parents=True, exist_ok=True)
|
| 29 |
|
| 30 |
def analyze_metrics(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
|
| 31 |
-
"""
|
| 32 |
|
| 33 |
Args:
|
| 34 |
-
metrics_history: Trainer.metrics.history
|
| 35 |
|
| 36 |
Returns:
|
| 37 |
-
|
| 38 |
"""
|
| 39 |
print("\n" + "=" * 70)
|
| 40 |
-
print("🔬
|
| 41 |
print("=" * 70)
|
| 42 |
|
| 43 |
analysis = {}
|
| 44 |
|
| 45 |
-
# ── Loss
|
| 46 |
if metrics_history.get("train_loss"):
|
| 47 |
losses = metrics_history["train_loss"]
|
| 48 |
analysis["loss"] = {
|
|
@@ -52,7 +52,7 @@ class TrainingDynamicsAnalyzer:
|
|
| 52 |
"total_reduction": round(losses[0] - losses[-1], 4),
|
| 53 |
}
|
| 54 |
|
| 55 |
-
#
|
| 56 |
spikes = []
|
| 57 |
for i in range(1, len(losses)):
|
| 58 |
if losses[i] > losses[i-1] * 1.5:
|
|
@@ -61,17 +61,17 @@ class TrainingDynamicsAnalyzer:
|
|
| 61 |
|
| 62 |
analysis["loss"]["spikes"] = spikes
|
| 63 |
|
| 64 |
-
print(f"\n 📉 Loss
|
| 65 |
-
print(f"
|
| 66 |
-
print(f"
|
| 67 |
-
print(f"
|
| 68 |
-
print(f"
|
| 69 |
-
print(f"
|
| 70 |
if spikes:
|
| 71 |
for s in spikes[:5]:
|
| 72 |
print(f" Step {s['step']}: Loss = {s['loss']}")
|
| 73 |
|
| 74 |
-
# ── Gradient Norm
|
| 75 |
if metrics_history.get("grad_norm"):
|
| 76 |
gnorms = metrics_history["grad_norm"]
|
| 77 |
analysis["grad_norm"] = {
|
|
@@ -81,14 +81,14 @@ class TrainingDynamicsAnalyzer:
|
|
| 81 |
"clipped_pct": round(sum(1 for g in gnorms if g >= 0.99) / len(gnorms) * 100, 1),
|
| 82 |
}
|
| 83 |
|
| 84 |
-
print(f"\n 📐 Gradient Norm
|
| 85 |
-
print(f"
|
| 86 |
-
print(f"
|
| 87 |
-
print(f"
|
| 88 |
if analysis["grad_norm"]["clipped_pct"] > 30:
|
| 89 |
-
print(f" ⚠️
|
| 90 |
|
| 91 |
-
# ──
|
| 92 |
if metrics_history.get("tokens_per_sec"):
|
| 93 |
tps = metrics_history["tokens_per_sec"]
|
| 94 |
tps_valid = [t for t in tps if t > 0]
|
|
@@ -100,10 +100,10 @@ class TrainingDynamicsAnalyzer:
|
|
| 100 |
"max": round(max(tps_valid)),
|
| 101 |
}
|
| 102 |
|
| 103 |
-
print(f"\n ⚡
|
| 104 |
-
print(f"
|
| 105 |
-
print(f"
|
| 106 |
-
print(f"
|
| 107 |
|
| 108 |
return analysis
|
| 109 |
|
|
@@ -112,9 +112,9 @@ class TrainingDynamicsAnalyzer:
|
|
| 112 |
metrics_history: Dict[str, list],
|
| 113 |
save_path: Optional[str] = None,
|
| 114 |
):
|
| 115 |
-
"""
|
| 116 |
if not HAS_MATPLOTLIB:
|
| 117 |
-
print("⚠️ matplotlib
|
| 118 |
return
|
| 119 |
|
| 120 |
fig, axes = plt.subplots(2, 2, figsize=(16, 10))
|
|
@@ -129,7 +129,7 @@ class TrainingDynamicsAnalyzer:
|
|
| 129 |
metrics_history["train_loss"],
|
| 130 |
color="#2563eb", alpha=0.6, linewidth=0.8, label="Train Loss")
|
| 131 |
|
| 132 |
-
#
|
| 133 |
if len(metrics_history["train_loss"]) > 20:
|
| 134 |
window = min(50, len(metrics_history["train_loss"]) // 5)
|
| 135 |
smoothed = self._moving_average(metrics_history["train_loss"], window)
|
|
@@ -192,7 +192,7 @@ class TrainingDynamicsAnalyzer:
|
|
| 192 |
|
| 193 |
save_path = save_path or str(self.save_dir / "training_curves.png")
|
| 194 |
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 195 |
-
print(f"\n 📊
|
| 196 |
plt.close(fig)
|
| 197 |
|
| 198 |
def plot_position_loss(
|
|
@@ -200,7 +200,7 @@ class TrainingDynamicsAnalyzer:
|
|
| 200 |
position_losses: List[float],
|
| 201 |
save_path: Optional[str] = None,
|
| 202 |
):
|
| 203 |
-
"""
|
| 204 |
if not HAS_MATPLOTLIB:
|
| 205 |
return
|
| 206 |
|
|
@@ -215,7 +215,7 @@ class TrainingDynamicsAnalyzer:
|
|
| 215 |
ax.set_title("Loss by Position (earlier positions have less context)", fontsize=13, fontweight="bold")
|
| 216 |
ax.grid(True, alpha=0.3)
|
| 217 |
|
| 218 |
-
#
|
| 219 |
if len(position_losses) > 100:
|
| 220 |
early_avg = sum(position_losses[:50]) / 50
|
| 221 |
late_avg = sum(position_losses[-200:]) / 200
|
|
@@ -229,12 +229,12 @@ class TrainingDynamicsAnalyzer:
|
|
| 229 |
|
| 230 |
save_path = save_path or str(self.save_dir / "position_loss.png")
|
| 231 |
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 232 |
-
print(f" 📊
|
| 233 |
plt.close(fig)
|
| 234 |
|
| 235 |
@staticmethod
|
| 236 |
def _moving_average(data: list, window: int) -> list:
|
| 237 |
-
"""
|
| 238 |
result = []
|
| 239 |
for i in range(window - 1, len(data)):
|
| 240 |
avg = sum(data[i - window + 1 : i + 1]) / window
|
|
|
|
| 1 |
+
"""Training dynamics analyzer."""
|
| 2 |
|
| 3 |
import math
|
| 4 |
from pathlib import Path
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
class TrainingDynamicsAnalyzer:
|
| 17 |
+
"""Analyzes and visualizes training metrics.
|
| 18 |
|
| 19 |
+
Analysis items:
|
| 20 |
+
- Loss curve: Convergence patterns, spike detection
|
| 21 |
+
- LR schedule: Warmup + Cosine decay verification
|
| 22 |
+
- Gradient Norm: Training stability, explosion/vanishing detection
|
| 23 |
+
- Throughput: tokens/sec stability, bottleneck detection
|
| 24 |
"""
|
| 25 |
|
| 26 |
def __init__(self, save_dir: str = "./eval_results"):
|
|
|
|
| 28 |
self.save_dir.mkdir(parents=True, exist_ok=True)
|
| 29 |
|
| 30 |
def analyze_metrics(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
|
| 31 |
+
"""Analyzes training metrics.
|
| 32 |
|
| 33 |
Args:
|
| 34 |
+
metrics_history: Trainer.metrics.history dictionary
|
| 35 |
|
| 36 |
Returns:
|
| 37 |
+
Analysis results
|
| 38 |
"""
|
| 39 |
print("\n" + "=" * 70)
|
| 40 |
+
print("🔬 Training Dynamics Analysis")
|
| 41 |
print("=" * 70)
|
| 42 |
|
| 43 |
analysis = {}
|
| 44 |
|
| 45 |
+
# ── Loss analysis ──
|
| 46 |
if metrics_history.get("train_loss"):
|
| 47 |
losses = metrics_history["train_loss"]
|
| 48 |
analysis["loss"] = {
|
|
|
|
| 52 |
"total_reduction": round(losses[0] - losses[-1], 4),
|
| 53 |
}
|
| 54 |
|
| 55 |
+
# Spike detection (sudden increase of 50% or more compared to previous value)
|
| 56 |
spikes = []
|
| 57 |
for i in range(1, len(losses)):
|
| 58 |
if losses[i] > losses[i-1] * 1.5:
|
|
|
|
| 61 |
|
| 62 |
analysis["loss"]["spikes"] = spikes
|
| 63 |
|
| 64 |
+
print(f"\n 📉 Loss Analysis:")
|
| 65 |
+
print(f" Initial: {analysis['loss']['initial']:.4f}")
|
| 66 |
+
print(f" Final: {analysis['loss']['final']:.4f}")
|
| 67 |
+
print(f" Minimum: {analysis['loss']['minimum']:.4f}")
|
| 68 |
+
print(f" Reduction: {analysis['loss']['total_reduction']:.4f}")
|
| 69 |
+
print(f" Spikes: {len(spikes)}")
|
| 70 |
if spikes:
|
| 71 |
for s in spikes[:5]:
|
| 72 |
print(f" Step {s['step']}: Loss = {s['loss']}")
|
| 73 |
|
| 74 |
+
# ── Gradient Norm analysis ──
|
| 75 |
if metrics_history.get("grad_norm"):
|
| 76 |
gnorms = metrics_history["grad_norm"]
|
| 77 |
analysis["grad_norm"] = {
|
|
|
|
| 81 |
"clipped_pct": round(sum(1 for g in gnorms if g >= 0.99) / len(gnorms) * 100, 1),
|
| 82 |
}
|
| 83 |
|
| 84 |
+
print(f"\n 📐 Gradient Norm Analysis:")
|
| 85 |
+
print(f" Mean: {analysis['grad_norm']['mean']:.4f}")
|
| 86 |
+
print(f" Max: {analysis['grad_norm']['max']:.4f}")
|
| 87 |
+
print(f" Clipping rate: {analysis['grad_norm']['clipped_pct']:.1f}%")
|
| 88 |
if analysis["grad_norm"]["clipped_pct"] > 30:
|
| 89 |
+
print(f" ⚠️ Clipping is frequent → consider lowering LR or extending warmup")
|
| 90 |
|
| 91 |
+
# ── Throughput analysis ──
|
| 92 |
if metrics_history.get("tokens_per_sec"):
|
| 93 |
tps = metrics_history["tokens_per_sec"]
|
| 94 |
tps_valid = [t for t in tps if t > 0]
|
|
|
|
| 100 |
"max": round(max(tps_valid)),
|
| 101 |
}
|
| 102 |
|
| 103 |
+
print(f"\n ⚡ Throughput Analysis:")
|
| 104 |
+
print(f" Mean: {analysis['throughput']['mean']:,} tokens/sec")
|
| 105 |
+
print(f" StdDev: {analysis['throughput']['std']:,}")
|
| 106 |
+
print(f" Range: [{analysis['throughput']['min']:,}, {analysis['throughput']['max']:,}]")
|
| 107 |
|
| 108 |
return analysis
|
| 109 |
|
|
|
|
| 112 |
metrics_history: Dict[str, list],
|
| 113 |
save_path: Optional[str] = None,
|
| 114 |
):
|
| 115 |
+
"""Visualizes training curves as a 4-panel chart."""
|
| 116 |
if not HAS_MATPLOTLIB:
|
| 117 |
+
print("⚠️ matplotlib required: pip install matplotlib")
|
| 118 |
return
|
| 119 |
|
| 120 |
fig, axes = plt.subplots(2, 2, figsize=(16, 10))
|
|
|
|
| 129 |
metrics_history["train_loss"],
|
| 130 |
color="#2563eb", alpha=0.6, linewidth=0.8, label="Train Loss")
|
| 131 |
|
| 132 |
+
# Moving average (smoothing)
|
| 133 |
if len(metrics_history["train_loss"]) > 20:
|
| 134 |
window = min(50, len(metrics_history["train_loss"]) // 5)
|
| 135 |
smoothed = self._moving_average(metrics_history["train_loss"], window)
|
|
|
|
| 192 |
|
| 193 |
save_path = save_path or str(self.save_dir / "training_curves.png")
|
| 194 |
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 195 |
+
print(f"\n 📊 Training curves saved: {save_path}")
|
| 196 |
plt.close(fig)
|
| 197 |
|
| 198 |
def plot_position_loss(
|
|
|
|
| 200 |
position_losses: List[float],
|
| 201 |
save_path: Optional[str] = None,
|
| 202 |
):
|
| 203 |
+
"""Visualizes loss distribution by position."""
|
| 204 |
if not HAS_MATPLOTLIB:
|
| 205 |
return
|
| 206 |
|
|
|
|
| 215 |
ax.set_title("Loss by Position (earlier positions have less context)", fontsize=13, fontweight="bold")
|
| 216 |
ax.grid(True, alpha=0.3)
|
| 217 |
|
| 218 |
+
# Mark key regions
|
| 219 |
if len(position_losses) > 100:
|
| 220 |
early_avg = sum(position_losses[:50]) / 50
|
| 221 |
late_avg = sum(position_losses[-200:]) / 200
|
|
|
|
| 229 |
|
| 230 |
save_path = save_path or str(self.save_dir / "position_loss.png")
|
| 231 |
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 232 |
+
print(f" 📊 Position loss saved: {save_path}")
|
| 233 |
plt.close(fig)
|
| 234 |
|
| 235 |
@staticmethod
|
| 236 |
def _moving_average(data: list, window: int) -> list:
|
| 237 |
+
"""Compute moving average."""
|
| 238 |
result = []
|
| 239 |
for i in range(window - 1, len(data)):
|
| 240 |
avg = sum(data[i - window + 1 : i + 1]) / window
|
llm_lab/evaluation/full_evaluator.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
import json
|
| 4 |
import time
|
|
@@ -17,9 +17,9 @@ from .attention_viz import AttentionVisualizer
|
|
| 17 |
|
| 18 |
|
| 19 |
class FullEvaluator:
|
| 20 |
-
"""
|
| 21 |
|
| 22 |
-
|
| 23 |
```python
|
| 24 |
evaluator = FullEvaluator(model, tokenizer, val_dataloader, device)
|
| 25 |
report = evaluator.run_full_evaluation()
|
|
@@ -48,24 +48,24 @@ class FullEvaluator:
|
|
| 48 |
self.save_dir.mkdir(parents=True, exist_ok=True)
|
| 49 |
|
| 50 |
def run_full_evaluation(self) -> Dict[str, Any]:
|
| 51 |
-
"""
|
| 52 |
report = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")}
|
| 53 |
|
| 54 |
print("\n" + "=" * 70)
|
| 55 |
-
print("🔍
|
| 56 |
print("=" * 70)
|
| 57 |
|
| 58 |
# ── 1. Perplexity ──
|
| 59 |
print("\n" + "━" * 40)
|
| 60 |
-
print("Phase 1/4: Perplexity
|
| 61 |
print("━" * 40)
|
| 62 |
ppl_evaluator = PerplexityEvaluator(self.config)
|
| 63 |
report["perplexity"] = ppl_evaluator.evaluate(
|
| 64 |
self.model, self.val_dataloader, self.device, self.dtype
|
| 65 |
)
|
| 66 |
|
| 67 |
-
#
|
| 68 |
-
print("\n
|
| 69 |
position_losses = ppl_evaluator.evaluate_per_position(
|
| 70 |
self.model, self.val_dataloader, self.device, self.dtype
|
| 71 |
)
|
|
@@ -74,13 +74,13 @@ class FullEvaluator:
|
|
| 74 |
"late_avg": round(sum(position_losses[-200:]) / max(len(position_losses[-200:]), 1), 4),
|
| 75 |
}
|
| 76 |
|
| 77 |
-
#
|
| 78 |
dynamics = TrainingDynamicsAnalyzer(str(self.save_dir))
|
| 79 |
dynamics.plot_position_loss(position_losses, str(self.save_dir / "position_loss.png"))
|
| 80 |
|
| 81 |
-
# ── 2.
|
| 82 |
print("\n" + "━" * 40)
|
| 83 |
-
print("Phase 2/4:
|
| 84 |
print("━" * 40)
|
| 85 |
gen_evaluator = GenerationEvaluator(self.config)
|
| 86 |
gen_results = gen_evaluator.generate_samples(
|
|
@@ -91,52 +91,52 @@ class FullEvaluator:
|
|
| 91 |
"avg_metrics": self._average_gen_metrics(gen_results),
|
| 92 |
}
|
| 93 |
|
| 94 |
-
# ── 3.
|
| 95 |
if self.metrics_history:
|
| 96 |
print("\n" + "━" * 40)
|
| 97 |
-
print("Phase 3/4:
|
| 98 |
print("━" * 40)
|
| 99 |
report["training_dynamics"] = dynamics.analyze_metrics(self.metrics_history)
|
| 100 |
dynamics.plot_training_curves(self.metrics_history,
|
| 101 |
str(self.save_dir / "training_curves.png"))
|
| 102 |
else:
|
| 103 |
-
print("\n Phase 3/4:
|
| 104 |
|
| 105 |
-
# ── 4. Attention
|
| 106 |
print("\n" + "━" * 40)
|
| 107 |
-
print("Phase 4/4: Attention
|
| 108 |
print("━" * 40)
|
| 109 |
try:
|
| 110 |
self._visualize_attention_sample()
|
| 111 |
except Exception as e:
|
| 112 |
-
print(f" ⚠️ Attention
|
| 113 |
|
| 114 |
-
# ──
|
| 115 |
report_path = self.save_dir / "eval_report.json"
|
| 116 |
with open(report_path, "w") as f:
|
| 117 |
json.dump(report, f, indent=2, default=str)
|
| 118 |
-
print(f"\n📋
|
| 119 |
|
| 120 |
-
# ──
|
| 121 |
self._print_summary(report)
|
| 122 |
|
| 123 |
return report
|
| 124 |
|
| 125 |
def _visualize_attention_sample(self):
|
| 126 |
-
"""
|
| 127 |
viz = AttentionVisualizer(str(self.save_dir))
|
| 128 |
|
| 129 |
sample_text = "The cat sat on the mat and looked at the bird."
|
| 130 |
token_ids = self.tokenizer.encode(sample_text, add_special_tokens=False)
|
| 131 |
input_tensor = torch.tensor([token_ids], dtype=torch.long)
|
| 132 |
|
| 133 |
-
#
|
| 134 |
tokens_str = []
|
| 135 |
for tid in token_ids:
|
| 136 |
decoded = self.tokenizer.decode([tid])
|
| 137 |
tokens_str.append(decoded.replace("\n", "\\n"))
|
| 138 |
|
| 139 |
-
# Layer 0 attention
|
| 140 |
attn_weights = viz.extract_attention(
|
| 141 |
self.model, input_tensor, layer_idx=0, device=self.device
|
| 142 |
)
|
|
@@ -150,7 +150,7 @@ class FullEvaluator:
|
|
| 150 |
|
| 151 |
@staticmethod
|
| 152 |
def _average_gen_metrics(gen_results: List[Dict]) -> Dict[str, float]:
|
| 153 |
-
"""
|
| 154 |
if not gen_results:
|
| 155 |
return {}
|
| 156 |
|
|
@@ -165,9 +165,9 @@ class FullEvaluator:
|
|
| 165 |
}
|
| 166 |
|
| 167 |
def _print_summary(self, report: Dict[str, Any]):
|
| 168 |
-
"""
|
| 169 |
print("\n" + "=" * 70)
|
| 170 |
-
print("📋
|
| 171 |
print("=" * 70)
|
| 172 |
|
| 173 |
# Perplexity
|
|
@@ -177,44 +177,44 @@ class FullEvaluator:
|
|
| 177 |
print(f" Loss: {ppl['loss']:.4f}")
|
| 178 |
print(f" PPL: {ppl['perplexity']:.2f}")
|
| 179 |
|
| 180 |
-
#
|
| 181 |
ppl_val = ppl["perplexity"]
|
| 182 |
if ppl_val < 20:
|
| 183 |
-
grade = "🌟
|
| 184 |
elif ppl_val < 35:
|
| 185 |
-
grade = "✅
|
| 186 |
elif ppl_val < 60:
|
| 187 |
-
grade = "⚠️
|
| 188 |
else:
|
| 189 |
-
grade = "❌
|
| 190 |
-
print(f"
|
| 191 |
|
| 192 |
-
#
|
| 193 |
if "position_losses" in report:
|
| 194 |
pl = report["position_losses"]
|
| 195 |
-
print(f"\n 📍
|
| 196 |
-
print(f"
|
| 197 |
-
print(f"
|
| 198 |
-
print(f"
|
| 199 |
|
| 200 |
-
#
|
| 201 |
if "generation" in report and report["generation"].get("avg_metrics"):
|
| 202 |
gm = report["generation"]["avg_metrics"]
|
| 203 |
-
print(f"\n ✍️
|
| 204 |
-
print(f"
|
| 205 |
-
print(f"
|
| 206 |
-
print(f"
|
| 207 |
|
| 208 |
-
#
|
| 209 |
if "training_dynamics" in report:
|
| 210 |
td = report["training_dynamics"]
|
| 211 |
if "loss" in td:
|
| 212 |
-
print(f"\n 📉
|
| 213 |
-
print(f" Loss
|
| 214 |
-
print(f"
|
| 215 |
|
| 216 |
-
#
|
| 217 |
-
print(f"\n 📂
|
| 218 |
for f in sorted(self.save_dir.glob("*")):
|
| 219 |
size = f.stat().st_size / 1024
|
| 220 |
print(f" {f.name} ({size:.1f} KB)")
|
|
|
|
| 1 |
+
"""Comprehensive evaluation runner."""
|
| 2 |
|
| 3 |
import json
|
| 4 |
import time
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
class FullEvaluator:
|
| 20 |
+
"""Runs all evaluations at once and generates a report.
|
| 21 |
|
| 22 |
+
Usage:
|
| 23 |
```python
|
| 24 |
evaluator = FullEvaluator(model, tokenizer, val_dataloader, device)
|
| 25 |
report = evaluator.run_full_evaluation()
|
|
|
|
| 48 |
self.save_dir.mkdir(parents=True, exist_ok=True)
|
| 49 |
|
| 50 |
def run_full_evaluation(self) -> Dict[str, Any]:
|
| 51 |
+
"""Runs the full evaluation."""
|
| 52 |
report = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")}
|
| 53 |
|
| 54 |
print("\n" + "=" * 70)
|
| 55 |
+
print("🔍 Starting comprehensive evaluation")
|
| 56 |
print("=" * 70)
|
| 57 |
|
| 58 |
# ── 1. Perplexity ──
|
| 59 |
print("\n" + "━" * 40)
|
| 60 |
+
print("Phase 1/4: Perplexity measurement")
|
| 61 |
print("━" * 40)
|
| 62 |
ppl_evaluator = PerplexityEvaluator(self.config)
|
| 63 |
report["perplexity"] = ppl_evaluator.evaluate(
|
| 64 |
self.model, self.val_dataloader, self.device, self.dtype
|
| 65 |
)
|
| 66 |
|
| 67 |
+
# Per-position loss
|
| 68 |
+
print("\n Measuring per-position loss...")
|
| 69 |
position_losses = ppl_evaluator.evaluate_per_position(
|
| 70 |
self.model, self.val_dataloader, self.device, self.dtype
|
| 71 |
)
|
|
|
|
| 74 |
"late_avg": round(sum(position_losses[-200:]) / max(len(position_losses[-200:]), 1), 4),
|
| 75 |
}
|
| 76 |
|
| 77 |
+
# Per-position loss visualization
|
| 78 |
dynamics = TrainingDynamicsAnalyzer(str(self.save_dir))
|
| 79 |
dynamics.plot_position_loss(position_losses, str(self.save_dir / "position_loss.png"))
|
| 80 |
|
| 81 |
+
# ── 2. Text generation ──
|
| 82 |
print("\n" + "━" * 40)
|
| 83 |
+
print("Phase 2/4: Text generation")
|
| 84 |
print("━" * 40)
|
| 85 |
gen_evaluator = GenerationEvaluator(self.config)
|
| 86 |
gen_results = gen_evaluator.generate_samples(
|
|
|
|
| 91 |
"avg_metrics": self._average_gen_metrics(gen_results),
|
| 92 |
}
|
| 93 |
|
| 94 |
+
# ── 3. Training dynamics analysis ──
|
| 95 |
if self.metrics_history:
|
| 96 |
print("\n" + "━" * 40)
|
| 97 |
+
print("Phase 3/4: Training dynamics analysis")
|
| 98 |
print("━" * 40)
|
| 99 |
report["training_dynamics"] = dynamics.analyze_metrics(self.metrics_history)
|
| 100 |
dynamics.plot_training_curves(self.metrics_history,
|
| 101 |
str(self.save_dir / "training_curves.png"))
|
| 102 |
else:
|
| 103 |
+
print("\n Phase 3/4: Skipped (no metrics_history)")
|
| 104 |
|
| 105 |
+
# ── 4. Attention visualization (sample) ──
|
| 106 |
print("\n" + "━" * 40)
|
| 107 |
+
print("Phase 4/4: Attention visualization")
|
| 108 |
print("━" * 40)
|
| 109 |
try:
|
| 110 |
self._visualize_attention_sample()
|
| 111 |
except Exception as e:
|
| 112 |
+
print(f" ⚠️ Attention visualization failed: {e}")
|
| 113 |
|
| 114 |
+
# ── Save report ──
|
| 115 |
report_path = self.save_dir / "eval_report.json"
|
| 116 |
with open(report_path, "w") as f:
|
| 117 |
json.dump(report, f, indent=2, default=str)
|
| 118 |
+
print(f"\n📋 Report saved: {report_path}")
|
| 119 |
|
| 120 |
+
# ── Print summary ──
|
| 121 |
self._print_summary(report)
|
| 122 |
|
| 123 |
return report
|
| 124 |
|
| 125 |
def _visualize_attention_sample(self):
|
| 126 |
+
"""Visualizes attention using a sample text."""
|
| 127 |
viz = AttentionVisualizer(str(self.save_dir))
|
| 128 |
|
| 129 |
sample_text = "The cat sat on the mat and looked at the bird."
|
| 130 |
token_ids = self.tokenizer.encode(sample_text, add_special_tokens=False)
|
| 131 |
input_tensor = torch.tensor([token_ids], dtype=torch.long)
|
| 132 |
|
| 133 |
+
# Token strings (for visualization labels)
|
| 134 |
tokens_str = []
|
| 135 |
for tid in token_ids:
|
| 136 |
decoded = self.tokenizer.decode([tid])
|
| 137 |
tokens_str.append(decoded.replace("\n", "\\n"))
|
| 138 |
|
| 139 |
+
# Extract Layer 0 attention
|
| 140 |
attn_weights = viz.extract_attention(
|
| 141 |
self.model, input_tensor, layer_idx=0, device=self.device
|
| 142 |
)
|
|
|
|
| 150 |
|
| 151 |
@staticmethod
|
| 152 |
def _average_gen_metrics(gen_results: List[Dict]) -> Dict[str, float]:
|
| 153 |
+
"""Average generation metrics across all prompts."""
|
| 154 |
if not gen_results:
|
| 155 |
return {}
|
| 156 |
|
|
|
|
| 165 |
}
|
| 166 |
|
| 167 |
def _print_summary(self, report: Dict[str, Any]):
|
| 168 |
+
"""Prints the final summary."""
|
| 169 |
print("\n" + "=" * 70)
|
| 170 |
+
print("📋 Evaluation Summary Report")
|
| 171 |
print("=" * 70)
|
| 172 |
|
| 173 |
# Perplexity
|
|
|
|
| 177 |
print(f" Loss: {ppl['loss']:.4f}")
|
| 178 |
print(f" PPL: {ppl['perplexity']:.2f}")
|
| 179 |
|
| 180 |
+
# Grade assessment
|
| 181 |
ppl_val = ppl["perplexity"]
|
| 182 |
if ppl_val < 20:
|
| 183 |
+
grade = "🌟 Excellent (Strong)"
|
| 184 |
elif ppl_val < 35:
|
| 185 |
+
grade = "✅ Good"
|
| 186 |
elif ppl_val < 60:
|
| 187 |
+
grade = "⚠️ Fair"
|
| 188 |
else:
|
| 189 |
+
grade = "❌ Poor (more training needed)"
|
| 190 |
+
print(f" Grade: {grade}")
|
| 191 |
|
| 192 |
+
# Per-position loss
|
| 193 |
if "position_losses" in report:
|
| 194 |
pl = report["position_losses"]
|
| 195 |
+
print(f"\n 📍 Per-position Loss:")
|
| 196 |
+
print(f" Early (0-50): {pl['early_avg']:.4f}")
|
| 197 |
+
print(f" Late (-200): {pl['late_avg']:.4f}")
|
| 198 |
+
print(f" Context effect: {pl['early_avg'] - pl['late_avg']:.4f} reduction")
|
| 199 |
|
| 200 |
+
# Generation quality
|
| 201 |
if "generation" in report and report["generation"].get("avg_metrics"):
|
| 202 |
gm = report["generation"]["avg_metrics"]
|
| 203 |
+
print(f"\n ✍️ Generation Quality:")
|
| 204 |
+
print(f" Avg length: {gm.get('avg_length', 0):.0f} chars")
|
| 205 |
+
print(f" Repetition rate: {gm.get('repetition_rate', 0):.1%}")
|
| 206 |
+
print(f" Lexical diversity: {gm.get('lexical_diversity', 0):.3f}")
|
| 207 |
|
| 208 |
+
# Training dynamics
|
| 209 |
if "training_dynamics" in report:
|
| 210 |
td = report["training_dynamics"]
|
| 211 |
if "loss" in td:
|
| 212 |
+
print(f"\n 📉 Training Dynamics:")
|
| 213 |
+
print(f" Loss reduction: {td['loss']['initial']:.4f} → {td['loss']['final']:.4f}")
|
| 214 |
+
print(f" Spikes: {len(td['loss']['spikes'])}")
|
| 215 |
|
| 216 |
+
# Generated files
|
| 217 |
+
print(f"\n 📂 Output files:")
|
| 218 |
for f in sorted(self.save_dir.glob("*")):
|
| 219 |
size = f.stat().st_size / 1024
|
| 220 |
print(f" {f.name} ({size:.1f} KB)")
|
llm_lab/evaluation/generation.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
from typing import Any, Dict, List, Optional
|
| 4 |
|
|
@@ -9,47 +9,47 @@ from llm_lab.config import EvalConfig
|
|
| 9 |
|
| 10 |
|
| 11 |
class GenerationEvaluator:
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
1)
|
| 16 |
-
2)
|
| 17 |
-
3)
|
| 18 |
-
4)
|
| 19 |
-
5)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
-
|
| 23 |
-
-
|
| 24 |
-
-
|
| 25 |
-
-
|
| 26 |
"""
|
| 27 |
|
| 28 |
-
#
|
| 29 |
DEFAULT_PROMPTS = [
|
| 30 |
-
# ──
|
| 31 |
"The theory of relativity states that",
|
| 32 |
"In the history of computer science,",
|
| 33 |
"The human brain is remarkable because",
|
| 34 |
|
| 35 |
-
# ──
|
| 36 |
"To understand machine learning, one must first",
|
| 37 |
"The water cycle begins when",
|
| 38 |
"Photosynthesis is the process by which",
|
| 39 |
|
| 40 |
-
# ──
|
| 41 |
"Once upon a time, in a small village near the mountains,",
|
| 42 |
"The detective looked at the evidence and realized that",
|
| 43 |
|
| 44 |
-
# ──
|
| 45 |
"def fibonacci(n):\n \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
|
| 46 |
"The most important data structures in programming are",
|
| 47 |
|
| 48 |
-
# ──
|
| 49 |
"The capital of France is",
|
| 50 |
"Water boils at a temperature of",
|
| 51 |
|
| 52 |
-
# ──
|
| 53 |
("Artificial intelligence has transformed many industries. "
|
| 54 |
"In healthcare, AI is used for diagnosis and drug discovery. "
|
| 55 |
"In finance, it powers algorithmic trading and fraud detection. "
|
|
@@ -68,7 +68,7 @@ class GenerationEvaluator:
|
|
| 68 |
prompts: Optional[List[str]] = None,
|
| 69 |
verbose: bool = True,
|
| 70 |
) -> List[Dict[str, Any]]:
|
| 71 |
-
"""
|
| 72 |
|
| 73 |
Returns:
|
| 74 |
[{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
|
|
@@ -79,7 +79,7 @@ class GenerationEvaluator:
|
|
| 79 |
|
| 80 |
if verbose:
|
| 81 |
print("\n" + "=" * 70)
|
| 82 |
-
print("📝
|
| 83 |
print("=" * 70)
|
| 84 |
|
| 85 |
for idx, prompt in enumerate(prompts):
|
|
@@ -91,17 +91,17 @@ class GenerationEvaluator:
|
|
| 91 |
|
| 92 |
if verbose:
|
| 93 |
print(f"\n{'─'*60}")
|
| 94 |
-
print(f"
|
| 95 |
print(f" \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"")
|
| 96 |
print(f"{'─'*60}")
|
| 97 |
|
| 98 |
-
#
|
| 99 |
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
| 100 |
input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
|
| 101 |
|
| 102 |
all_texts = []
|
| 103 |
for sample_idx in range(self.config.num_samples):
|
| 104 |
-
#
|
| 105 |
generated_ids = model.generate(
|
| 106 |
input_tensor,
|
| 107 |
max_new_tokens=self.config.max_new_tokens,
|
|
@@ -110,7 +110,7 @@ class GenerationEvaluator:
|
|
| 110 |
top_p=self.config.top_p,
|
| 111 |
)
|
| 112 |
|
| 113 |
-
#
|
| 114 |
new_ids = generated_ids[0][len(prompt_ids):].tolist()
|
| 115 |
generated_text = tokenizer.decode(new_ids)
|
| 116 |
all_texts.append(generated_text)
|
|
@@ -118,23 +118,23 @@ class GenerationEvaluator:
|
|
| 118 |
prompt_results["generations"].append(generated_text)
|
| 119 |
|
| 120 |
if verbose:
|
| 121 |
-
print(f"\n ✍️
|
| 122 |
-
#
|
| 123 |
display_text = generated_text[:500]
|
| 124 |
for line in display_text.split("\n"):
|
| 125 |
print(f" {line}")
|
| 126 |
if len(generated_text) > 500:
|
| 127 |
-
print(f" ... (
|
| 128 |
|
| 129 |
-
#
|
| 130 |
prompt_results["metrics"] = self._compute_generation_metrics(all_texts)
|
| 131 |
|
| 132 |
if verbose and prompt_results["metrics"]:
|
| 133 |
m = prompt_results["metrics"]
|
| 134 |
-
print(f"\n 📊
|
| 135 |
-
f"
|
| 136 |
-
f"
|
| 137 |
-
f"
|
| 138 |
|
| 139 |
results.append(prompt_results)
|
| 140 |
|
|
@@ -142,23 +142,23 @@ class GenerationEvaluator:
|
|
| 142 |
|
| 143 |
@staticmethod
|
| 144 |
def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
|
| 145 |
-
"""
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
- avg_length:
|
| 149 |
-
- avg_word_count:
|
| 150 |
-
- repetition_rate: n-gram
|
| 151 |
-
- lexical_diversity:
|
| 152 |
-
- sample_diversity:
|
| 153 |
"""
|
| 154 |
if not texts:
|
| 155 |
return {}
|
| 156 |
|
| 157 |
-
#
|
| 158 |
lengths = [len(t) for t in texts]
|
| 159 |
word_counts = [len(t.split()) for t in texts]
|
| 160 |
|
| 161 |
-
#
|
| 162 |
rep_rates = []
|
| 163 |
for text in texts:
|
| 164 |
words = text.lower().split()
|
|
@@ -167,9 +167,9 @@ class GenerationEvaluator:
|
|
| 167 |
continue
|
| 168 |
ngrams = [tuple(words[i:i+4]) for i in range(len(words)-3)]
|
| 169 |
unique_ratio = len(set(ngrams)) / len(ngrams) if ngrams else 1.0
|
| 170 |
-
rep_rates.append(1.0 - unique_ratio) #
|
| 171 |
|
| 172 |
-
#
|
| 173 |
diversities = []
|
| 174 |
for text in texts:
|
| 175 |
words = text.lower().split()
|
|
@@ -178,7 +178,7 @@ class GenerationEvaluator:
|
|
| 178 |
else:
|
| 179 |
diversities.append(0.0)
|
| 180 |
|
| 181 |
-
#
|
| 182 |
sample_div = 0.0
|
| 183 |
if len(texts) > 1:
|
| 184 |
word_sets = [set(t.lower().split()) for t in texts]
|
|
|
|
| 1 |
+
"""Text generation evaluator."""
|
| 2 |
|
| 3 |
from typing import Any, Dict, List, Optional
|
| 4 |
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class GenerationEvaluator:
|
| 12 |
+
"""Evaluates text quality by generating from various prompts.
|
| 13 |
+
|
| 14 |
+
Evaluation perspectives:
|
| 15 |
+
1) Grammatical accuracy: Does it generate grammatically correct English sentences?
|
| 16 |
+
2) Coherence: Does it maintain context continuity?
|
| 17 |
+
3) Diversity: Does it produce different outputs for the same prompt?
|
| 18 |
+
4) Repetition avoidance: Does it avoid repeating the same phrases?
|
| 19 |
+
5) Knowledge expression: Is knowledge from the training data reflected?
|
| 20 |
+
|
| 21 |
+
Realistic expectations for a 1B model:
|
| 22 |
+
- Generates grammatically correct English sentences ✅
|
| 23 |
+
- Maintains coherence within short paragraphs ✅
|
| 24 |
+
- Complex reasoning or extended logical chains ❌ (requires a larger model)
|
| 25 |
+
- Factual accuracy is not guaranteed ⚠️
|
| 26 |
"""
|
| 27 |
|
| 28 |
+
# Test prompts from various domains
|
| 29 |
DEFAULT_PROMPTS = [
|
| 30 |
+
# ── General knowledge ──
|
| 31 |
"The theory of relativity states that",
|
| 32 |
"In the history of computer science,",
|
| 33 |
"The human brain is remarkable because",
|
| 34 |
|
| 35 |
+
# ── Explanation / Education ──
|
| 36 |
"To understand machine learning, one must first",
|
| 37 |
"The water cycle begins when",
|
| 38 |
"Photosynthesis is the process by which",
|
| 39 |
|
| 40 |
+
# ── Narrative / Story ──
|
| 41 |
"Once upon a time, in a small village near the mountains,",
|
| 42 |
"The detective looked at the evidence and realized that",
|
| 43 |
|
| 44 |
+
# ── Code / Technical ──
|
| 45 |
"def fibonacci(n):\n \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
|
| 46 |
"The most important data structures in programming are",
|
| 47 |
|
| 48 |
+
# ── Short completion ──
|
| 49 |
"The capital of France is",
|
| 50 |
"Water boils at a temperature of",
|
| 51 |
|
| 52 |
+
# ── Long context ──
|
| 53 |
("Artificial intelligence has transformed many industries. "
|
| 54 |
"In healthcare, AI is used for diagnosis and drug discovery. "
|
| 55 |
"In finance, it powers algorithmic trading and fraud detection. "
|
|
|
|
| 68 |
prompts: Optional[List[str]] = None,
|
| 69 |
verbose: bool = True,
|
| 70 |
) -> List[Dict[str, Any]]:
|
| 71 |
+
"""Generates text for each prompt.
|
| 72 |
|
| 73 |
Returns:
|
| 74 |
[{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
|
|
|
|
| 79 |
|
| 80 |
if verbose:
|
| 81 |
print("\n" + "=" * 70)
|
| 82 |
+
print("📝 Text Generation Evaluation")
|
| 83 |
print("=" * 70)
|
| 84 |
|
| 85 |
for idx, prompt in enumerate(prompts):
|
|
|
|
| 91 |
|
| 92 |
if verbose:
|
| 93 |
print(f"\n{'─'*60}")
|
| 94 |
+
print(f"Prompt [{idx+1}/{len(prompts)}]:")
|
| 95 |
print(f" \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"")
|
| 96 |
print(f"{'─'*60}")
|
| 97 |
|
| 98 |
+
# Encode prompt
|
| 99 |
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
| 100 |
input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
|
| 101 |
|
| 102 |
all_texts = []
|
| 103 |
for sample_idx in range(self.config.num_samples):
|
| 104 |
+
# Generate
|
| 105 |
generated_ids = model.generate(
|
| 106 |
input_tensor,
|
| 107 |
max_new_tokens=self.config.max_new_tokens,
|
|
|
|
| 110 |
top_p=self.config.top_p,
|
| 111 |
)
|
| 112 |
|
| 113 |
+
# Decode (only the part after the prompt)
|
| 114 |
new_ids = generated_ids[0][len(prompt_ids):].tolist()
|
| 115 |
generated_text = tokenizer.decode(new_ids)
|
| 116 |
all_texts.append(generated_text)
|
|
|
|
| 118 |
prompt_results["generations"].append(generated_text)
|
| 119 |
|
| 120 |
if verbose:
|
| 121 |
+
print(f"\n ✍️ Generation #{sample_idx+1}:")
|
| 122 |
+
# Clean output (including newlines)
|
| 123 |
display_text = generated_text[:500]
|
| 124 |
for line in display_text.split("\n"):
|
| 125 |
print(f" {line}")
|
| 126 |
if len(generated_text) > 500:
|
| 127 |
+
print(f" ... (total {len(generated_text)} characters)")
|
| 128 |
|
| 129 |
+
# Generation quality metrics
|
| 130 |
prompt_results["metrics"] = self._compute_generation_metrics(all_texts)
|
| 131 |
|
| 132 |
if verbose and prompt_results["metrics"]:
|
| 133 |
m = prompt_results["metrics"]
|
| 134 |
+
print(f"\n 📊 Metrics: "
|
| 135 |
+
f"avg_length={m['avg_length']:.0f} chars, "
|
| 136 |
+
f"repetition_rate={m['repetition_rate']:.1%}, "
|
| 137 |
+
f"lexical_diversity={m['lexical_diversity']:.2f}")
|
| 138 |
|
| 139 |
results.append(prompt_results)
|
| 140 |
|
|
|
|
| 142 |
|
| 143 |
@staticmethod
|
| 144 |
def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
|
| 145 |
+
"""Computes quality metrics for generated text.
|
| 146 |
+
|
| 147 |
+
Metrics:
|
| 148 |
+
- avg_length: Average generation length (characters)
|
| 149 |
+
- avg_word_count: Average word count
|
| 150 |
+
- repetition_rate: n-gram repetition rate (lower is better)
|
| 151 |
+
- lexical_diversity: Ratio of unique words (higher means more diverse)
|
| 152 |
+
- sample_diversity: Diversity across samples (how different are different generations)
|
| 153 |
"""
|
| 154 |
if not texts:
|
| 155 |
return {}
|
| 156 |
|
| 157 |
+
# Length
|
| 158 |
lengths = [len(t) for t in texts]
|
| 159 |
word_counts = [len(t.split()) for t in texts]
|
| 160 |
|
| 161 |
+
# Repetition rate (based on 4-grams)
|
| 162 |
rep_rates = []
|
| 163 |
for text in texts:
|
| 164 |
words = text.lower().split()
|
|
|
|
| 167 |
continue
|
| 168 |
ngrams = [tuple(words[i:i+4]) for i in range(len(words)-3)]
|
| 169 |
unique_ratio = len(set(ngrams)) / len(ngrams) if ngrams else 1.0
|
| 170 |
+
rep_rates.append(1.0 - unique_ratio) # repetition rate = 1 - unique ratio
|
| 171 |
|
| 172 |
+
# Lexical diversity (Type-Token Ratio)
|
| 173 |
diversities = []
|
| 174 |
for text in texts:
|
| 175 |
words = text.lower().split()
|
|
|
|
| 178 |
else:
|
| 179 |
diversities.append(0.0)
|
| 180 |
|
| 181 |
+
# Inter-sample diversity (inverse of Jaccard similarity)
|
| 182 |
sample_div = 0.0
|
| 183 |
if len(texts) > 1:
|
| 184 |
word_sets = [set(t.lower().split()) for t in texts]
|
llm_lab/evaluation/perplexity.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Perplexity(PPL)
|
| 2 |
|
| 3 |
import math
|
| 4 |
import time
|
|
@@ -13,26 +13,26 @@ from llm_lab.config import EvalConfig
|
|
| 13 |
|
| 14 |
|
| 15 |
class PerplexityEvaluator:
|
| 16 |
-
"""Perplexity(PPL)
|
| 17 |
|
| 18 |
-
Perplexity
|
| 19 |
PPL = exp(average cross-entropy loss)
|
| 20 |
|
| 21 |
-
|
| 22 |
-
- PPL = 1:
|
| 23 |
-
- PPL = 10:
|
| 24 |
-
- PPL = 100:
|
| 25 |
-
- PPL = 32000:
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
-
|
| 29 |
-
-
|
| 30 |
-
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
-
|
| 34 |
-
-
|
| 35 |
-
-
|
| 36 |
"""
|
| 37 |
|
| 38 |
def __init__(self, config: EvalConfig):
|
|
@@ -47,14 +47,14 @@ class PerplexityEvaluator:
|
|
| 47 |
dtype: torch.dtype = torch.bfloat16,
|
| 48 |
desc: str = "Evaluation",
|
| 49 |
) -> Dict[str, float]:
|
| 50 |
-
"""
|
| 51 |
|
| 52 |
Returns:
|
| 53 |
{
|
| 54 |
-
"loss":
|
| 55 |
"perplexity": exp(loss),
|
| 56 |
-
"num_tokens":
|
| 57 |
-
"num_batches":
|
| 58 |
}
|
| 59 |
"""
|
| 60 |
model.eval()
|
|
@@ -76,7 +76,7 @@ class PerplexityEvaluator:
|
|
| 76 |
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 77 |
logits, _ = model(input_ids)
|
| 78 |
|
| 79 |
-
#
|
| 80 |
# logits: (B, S, V) → (B*S, V)
|
| 81 |
# targets: (B, S) → (B*S,)
|
| 82 |
loss_per_token = F.cross_entropy(
|
|
@@ -86,7 +86,7 @@ class PerplexityEvaluator:
|
|
| 86 |
reduction="none",
|
| 87 |
)
|
| 88 |
|
| 89 |
-
#
|
| 90 |
valid_mask = (targets.view(-1) != -100)
|
| 91 |
valid_tokens = valid_mask.sum().item()
|
| 92 |
|
|
@@ -100,7 +100,7 @@ class PerplexityEvaluator:
|
|
| 100 |
|
| 101 |
elapsed = time.time() - start_time
|
| 102 |
avg_loss = total_loss / max(total_tokens, 1)
|
| 103 |
-
perplexity = math.exp(min(avg_loss, 100)) #
|
| 104 |
|
| 105 |
results = {
|
| 106 |
"loss": round(avg_loss, 4),
|
|
@@ -113,8 +113,8 @@ class PerplexityEvaluator:
|
|
| 113 |
print(f" ────────────────────────────────")
|
| 114 |
print(f" Loss: {results['loss']:.4f}")
|
| 115 |
print(f" Perplexity: {results['perplexity']:.2f}")
|
| 116 |
-
print(f"
|
| 117 |
-
print(f"
|
| 118 |
|
| 119 |
return results
|
| 120 |
|
|
@@ -127,12 +127,12 @@ class PerplexityEvaluator:
|
|
| 127 |
dtype: torch.dtype = torch.bfloat16,
|
| 128 |
max_batches: int = 50,
|
| 129 |
) -> List[float]:
|
| 130 |
-
"""
|
| 131 |
|
| 132 |
-
|
| 133 |
-
-
|
| 134 |
-
-
|
| 135 |
-
-
|
| 136 |
"""
|
| 137 |
model.eval()
|
| 138 |
seq_len = None
|
|
@@ -155,7 +155,7 @@ class PerplexityEvaluator:
|
|
| 155 |
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 156 |
logits, _ = model(input_ids)
|
| 157 |
|
| 158 |
-
# (B, S)
|
| 159 |
loss_per_token = F.cross_entropy(
|
| 160 |
logits.view(-1, logits.size(-1)),
|
| 161 |
targets.view(-1),
|
|
@@ -167,6 +167,6 @@ class PerplexityEvaluator:
|
|
| 167 |
position_loss_sum += (loss_per_token * valid_mask).sum(dim=0)
|
| 168 |
position_count += valid_mask.sum(dim=0)
|
| 169 |
|
| 170 |
-
#
|
| 171 |
position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist()
|
| 172 |
return position_avg_loss
|
|
|
|
| 1 |
+
"""Perplexity (PPL) evaluator."""
|
| 2 |
|
| 3 |
import math
|
| 4 |
import time
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class PerplexityEvaluator:
|
| 16 |
+
"""Measures Perplexity (PPL).
|
| 17 |
|
| 18 |
+
What is Perplexity?
|
| 19 |
PPL = exp(average cross-entropy loss)
|
| 20 |
|
| 21 |
+
Intuitive meaning:
|
| 22 |
+
- PPL = 1: Perfect prediction (impossible)
|
| 23 |
+
- PPL = 10: Equivalent to picking from 10 candidates each time
|
| 24 |
+
- PPL = 100: Equivalent to picking from 100 candidates (close to random)
|
| 25 |
+
- PPL = 32000: Random selection from the entire vocab (initial random model)
|
| 26 |
+
|
| 27 |
+
Good benchmark for a 1B model (English web text):
|
| 28 |
+
- Trained on 5B tokens: PPL ~30-40
|
| 29 |
+
- Trained on 10B tokens: PPL ~20-30
|
| 30 |
+
- Trained on 20B tokens: PPL ~15-25
|
| 31 |
+
|
| 32 |
+
Measurement method:
|
| 33 |
+
- Compute cross-entropy over all tokens in the validation dataset
|
| 34 |
+
- Average per token, then apply exp()
|
| 35 |
+
- Padding tokens are excluded (ignore_index=-100)
|
| 36 |
"""
|
| 37 |
|
| 38 |
def __init__(self, config: EvalConfig):
|
|
|
|
| 47 |
dtype: torch.dtype = torch.bfloat16,
|
| 48 |
desc: str = "Evaluation",
|
| 49 |
) -> Dict[str, float]:
|
| 50 |
+
"""Measures Perplexity.
|
| 51 |
|
| 52 |
Returns:
|
| 53 |
{
|
| 54 |
+
"loss": average cross-entropy loss,
|
| 55 |
"perplexity": exp(loss),
|
| 56 |
+
"num_tokens": total number of tokens used for evaluation,
|
| 57 |
+
"num_batches": number of batches used for evaluation,
|
| 58 |
}
|
| 59 |
"""
|
| 60 |
model.eval()
|
|
|
|
| 76 |
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 77 |
logits, _ = model(input_ids)
|
| 78 |
|
| 79 |
+
# Per-token cross-entropy (reduction='none')
|
| 80 |
# logits: (B, S, V) → (B*S, V)
|
| 81 |
# targets: (B, S) → (B*S,)
|
| 82 |
loss_per_token = F.cross_entropy(
|
|
|
|
| 86 |
reduction="none",
|
| 87 |
)
|
| 88 |
|
| 89 |
+
# Count only valid tokens that are not -100
|
| 90 |
valid_mask = (targets.view(-1) != -100)
|
| 91 |
valid_tokens = valid_mask.sum().item()
|
| 92 |
|
|
|
|
| 100 |
|
| 101 |
elapsed = time.time() - start_time
|
| 102 |
avg_loss = total_loss / max(total_tokens, 1)
|
| 103 |
+
perplexity = math.exp(min(avg_loss, 100)) # prevent overflow
|
| 104 |
|
| 105 |
results = {
|
| 106 |
"loss": round(avg_loss, 4),
|
|
|
|
| 113 |
print(f" ────────────────────────────────")
|
| 114 |
print(f" Loss: {results['loss']:.4f}")
|
| 115 |
print(f" Perplexity: {results['perplexity']:.2f}")
|
| 116 |
+
print(f" Eval tokens: {total_tokens:,}")
|
| 117 |
+
print(f" Elapsed: {elapsed:.1f}s")
|
| 118 |
|
| 119 |
return results
|
| 120 |
|
|
|
|
| 127 |
dtype: torch.dtype = torch.bfloat16,
|
| 128 |
max_batches: int = 50,
|
| 129 |
) -> List[float]:
|
| 130 |
+
"""Measures loss per position within a sequence.
|
| 131 |
|
| 132 |
+
Learning insight:
|
| 133 |
+
- Positions 0~10: Higher loss (insufficient context)
|
| 134 |
+
- Positions 100+: Loss stabilizes lower (context is leveraged)
|
| 135 |
+
- This pattern demonstrates the Transformer's in-context learning capability
|
| 136 |
"""
|
| 137 |
model.eval()
|
| 138 |
seq_len = None
|
|
|
|
| 155 |
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 156 |
logits, _ = model(input_ids)
|
| 157 |
|
| 158 |
+
# Per-token loss in shape (B, S)
|
| 159 |
loss_per_token = F.cross_entropy(
|
| 160 |
logits.view(-1, logits.size(-1)),
|
| 161 |
targets.view(-1),
|
|
|
|
| 167 |
position_loss_sum += (loss_per_token * valid_mask).sum(dim=0)
|
| 168 |
position_count += valid_mask.sum(dim=0)
|
| 169 |
|
| 170 |
+
# Average loss per position
|
| 171 |
position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist()
|
| 172 |
return position_avg_loss
|
llm_lab/evaluation/runner.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
from typing import Any, Dict, Optional
|
| 4 |
|
|
@@ -20,13 +20,13 @@ def run_evaluation(
|
|
| 20 |
metrics_history: Optional[Dict[str, list]] = None,
|
| 21 |
config: Optional[EvalConfig] = None,
|
| 22 |
) -> Dict[str, Any]:
|
| 23 |
-
"""
|
| 24 |
|
| 25 |
-
|
| 26 |
```python
|
| 27 |
from llm_lab.evaluation import run_evaluation
|
| 28 |
|
| 29 |
-
#
|
| 30 |
report = run_evaluation(
|
| 31 |
model=trainer.model,
|
| 32 |
tokenizer=tokenizer,
|
|
@@ -50,7 +50,7 @@ def run_evaluation(
|
|
| 50 |
|
| 51 |
report = evaluator.run_full_evaluation()
|
| 52 |
|
| 53 |
-
#
|
| 54 |
InsightChecklist.run_checklist(report, metrics_history)
|
| 55 |
|
| 56 |
return report
|
|
|
|
| 1 |
+
"""Evaluation runner helper (Quick Start)."""
|
| 2 |
|
| 3 |
from typing import Any, Dict, Optional
|
| 4 |
|
|
|
|
| 20 |
metrics_history: Optional[Dict[str, list]] = None,
|
| 21 |
config: Optional[EvalConfig] = None,
|
| 22 |
) -> Dict[str, Any]:
|
| 23 |
+
"""Runs all evaluations in one call.
|
| 24 |
|
| 25 |
+
Usage (Colab):
|
| 26 |
```python
|
| 27 |
from llm_lab.evaluation import run_evaluation
|
| 28 |
|
| 29 |
+
# After training is complete
|
| 30 |
report = run_evaluation(
|
| 31 |
model=trainer.model,
|
| 32 |
tokenizer=tokenizer,
|
|
|
|
| 50 |
|
| 51 |
report = evaluator.run_full_evaluation()
|
| 52 |
|
| 53 |
+
# Insight checklist
|
| 54 |
InsightChecklist.run_checklist(report, metrics_history)
|
| 55 |
|
| 56 |
return report
|
llm_lab/evaluation/scaling.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Scaling Law
|
| 2 |
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Any, Dict, List, Optional
|
|
@@ -19,17 +19,17 @@ except ImportError:
|
|
| 19 |
|
| 20 |
|
| 21 |
class ScalingAnalyzer:
|
| 22 |
-
"""10M → 100M → 1B
|
| 23 |
|
| 24 |
Chinchilla Scaling Law (2022):
|
| 25 |
-
-
|
| 26 |
-
- Loss ∝ N^(-α) × D^(-β) (N=
|
| 27 |
-
- α ≈ 0.076, β ≈ 0.095 (
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
-
|
| 31 |
-
-
|
| 32 |
-
-
|
| 33 |
"""
|
| 34 |
|
| 35 |
def __init__(self, save_dir: str = "./eval_results"):
|
|
@@ -40,7 +40,7 @@ class ScalingAnalyzer:
|
|
| 40 |
self,
|
| 41 |
model_results: List[Dict[str, Any]],
|
| 42 |
) -> Dict[str, Any]:
|
| 43 |
-
"""
|
| 44 |
|
| 45 |
Args:
|
| 46 |
model_results: [
|
|
@@ -50,25 +50,25 @@ class ScalingAnalyzer:
|
|
| 50 |
]
|
| 51 |
|
| 52 |
Returns:
|
| 53 |
-
|
| 54 |
"""
|
| 55 |
if len(model_results) < 2:
|
| 56 |
-
print("⚠️ Scaling
|
| 57 |
return {}
|
| 58 |
|
| 59 |
print("\n" + "=" * 70)
|
| 60 |
-
print("📈 Scaling Law
|
| 61 |
print("=" * 70)
|
| 62 |
|
| 63 |
-
# ──
|
| 64 |
-
print(f"\n {'
|
| 65 |
print(f" {'─'*52}")
|
| 66 |
for r in model_results:
|
| 67 |
params_str = f"{r['params']/1e6:.0f}M" if r["params"] < 1e9 else f"{r['params']/1e9:.1f}B"
|
| 68 |
tokens_str = f"{r['tokens']/1e9:.1f}B"
|
| 69 |
print(f" {r['name']:<8} {params_str:>12} {tokens_str:>10} {r['loss']:>8.4f} {r['ppl']:>8.2f}")
|
| 70 |
|
| 71 |
-
# ── Scaling
|
| 72 |
analysis = {"models": model_results, "scaling_efficiency": []}
|
| 73 |
|
| 74 |
for i in range(1, len(model_results)):
|
|
@@ -89,17 +89,17 @@ class ScalingAnalyzer:
|
|
| 89 |
analysis["scaling_efficiency"].append(efficiency)
|
| 90 |
|
| 91 |
print(f"\n {prev['name']} → {curr['name']}:")
|
| 92 |
-
print(f"
|
| 93 |
-
print(f" Loss
|
| 94 |
-
print(f" PPL
|
| 95 |
|
| 96 |
-
# ── Chinchilla
|
| 97 |
-
print(f"\n Chinchilla
|
| 98 |
for r in model_results:
|
| 99 |
actual_ratio = r["tokens"] / r["params"]
|
| 100 |
-
status = "✅
|
| 101 |
-
print(f" {r['name']}:
|
| 102 |
-
f"(
|
| 103 |
|
| 104 |
analysis["chinchilla_ratios"] = [
|
| 105 |
{"name": r["name"], "ratio": round(r["tokens"] / r["params"], 1)}
|
|
@@ -113,9 +113,9 @@ class ScalingAnalyzer:
|
|
| 113 |
model_results: List[Dict[str, Any]],
|
| 114 |
save_path: Optional[str] = None,
|
| 115 |
):
|
| 116 |
-
"""
|
| 117 |
if not HAS_MATPLOTLIB or not HAS_NUMPY:
|
| 118 |
-
print("⚠️ matplotlib/numpy
|
| 119 |
return
|
| 120 |
|
| 121 |
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
|
@@ -149,5 +149,5 @@ class ScalingAnalyzer:
|
|
| 149 |
|
| 150 |
save_path = save_path or str(self.save_dir / "scaling_curves.png")
|
| 151 |
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 152 |
-
print(f"\n 📊 Scaling
|
| 153 |
plt.close(fig)
|
|
|
|
| 1 |
+
"""Scaling Law analyzer."""
|
| 2 |
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Any, Dict, List, Optional
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
class ScalingAnalyzer:
|
| 22 |
+
"""Analyzes Scaling Law across 10M → 100M → 1B models.
|
| 23 |
|
| 24 |
Chinchilla Scaling Law (2022):
|
| 25 |
+
- Optimal training: tokens ≈ 20 × number of parameters
|
| 26 |
+
- Loss ∝ N^(-α) × D^(-β) (N=parameters, D=data)
|
| 27 |
+
- α ≈ 0.076, β ≈ 0.095 (per the paper)
|
| 28 |
+
|
| 29 |
+
Purpose of this analysis:
|
| 30 |
+
- Verify whether our model follows the Scaling Law
|
| 31 |
+
- Predict the effect of larger models / more data
|
| 32 |
+
- Understand the optimal allocation of compute resources
|
| 33 |
"""
|
| 34 |
|
| 35 |
def __init__(self, save_dir: str = "./eval_results"):
|
|
|
|
| 40 |
self,
|
| 41 |
model_results: List[Dict[str, Any]],
|
| 42 |
) -> Dict[str, Any]:
|
| 43 |
+
"""Comparatively analyzes results across multiple model sizes.
|
| 44 |
|
| 45 |
Args:
|
| 46 |
model_results: [
|
|
|
|
| 50 |
]
|
| 51 |
|
| 52 |
Returns:
|
| 53 |
+
Analysis result dictionary
|
| 54 |
"""
|
| 55 |
if len(model_results) < 2:
|
| 56 |
+
print("⚠️ Scaling analysis requires results from at least 2 models.")
|
| 57 |
return {}
|
| 58 |
|
| 59 |
print("\n" + "=" * 70)
|
| 60 |
+
print("📈 Scaling Law Analysis")
|
| 61 |
print("=" * 70)
|
| 62 |
|
| 63 |
+
# ── Results table ──
|
| 64 |
+
print(f"\n {'Model':<8} {'Parameters':>12} {'Tokens':>10} {'Loss':>8} {'PPL':>8}")
|
| 65 |
print(f" {'─'*52}")
|
| 66 |
for r in model_results:
|
| 67 |
params_str = f"{r['params']/1e6:.0f}M" if r["params"] < 1e9 else f"{r['params']/1e9:.1f}B"
|
| 68 |
tokens_str = f"{r['tokens']/1e9:.1f}B"
|
| 69 |
print(f" {r['name']:<8} {params_str:>12} {tokens_str:>10} {r['loss']:>8.4f} {r['ppl']:>8.2f}")
|
| 70 |
|
| 71 |
+
# ── Scaling efficiency calculation ──
|
| 72 |
analysis = {"models": model_results, "scaling_efficiency": []}
|
| 73 |
|
| 74 |
for i in range(1, len(model_results)):
|
|
|
|
| 89 |
analysis["scaling_efficiency"].append(efficiency)
|
| 90 |
|
| 91 |
print(f"\n {prev['name']} → {curr['name']}:")
|
| 92 |
+
print(f" Parameters ×{param_ratio:.1f}")
|
| 93 |
+
print(f" Loss reduction: {loss_reduction:.4f}")
|
| 94 |
+
print(f" PPL reduction: {ppl_reduction*100:.1f}%")
|
| 95 |
|
| 96 |
+
# ── Chinchilla optimality check ──
|
| 97 |
+
print(f"\n Chinchilla optimality check (tokens ≈ 20 × parameters):")
|
| 98 |
for r in model_results:
|
| 99 |
actual_ratio = r["tokens"] / r["params"]
|
| 100 |
+
status = "✅ Optimal range" if 15 <= actual_ratio <= 25 else "⚠️ Out of range"
|
| 101 |
+
print(f" {r['name']}: tokens/parameters = {actual_ratio:.1f}x "
|
| 102 |
+
f"(optimal: 20x) {status}")
|
| 103 |
|
| 104 |
analysis["chinchilla_ratios"] = [
|
| 105 |
{"name": r["name"], "ratio": round(r["tokens"] / r["params"], 1)}
|
|
|
|
| 113 |
model_results: List[Dict[str, Any]],
|
| 114 |
save_path: Optional[str] = None,
|
| 115 |
):
|
| 116 |
+
"""Visualizes scaling curves."""
|
| 117 |
if not HAS_MATPLOTLIB or not HAS_NUMPY:
|
| 118 |
+
print("⚠️ matplotlib/numpy required: pip install matplotlib numpy")
|
| 119 |
return
|
| 120 |
|
| 121 |
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
|
|
|
| 149 |
|
| 150 |
save_path = save_path or str(self.save_dir / "scaling_curves.png")
|
| 151 |
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 152 |
+
print(f"\n 📊 Scaling curves saved: {save_path}")
|
| 153 |
plt.close(fig)
|
llm_lab/model/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
from .norm import RMSNorm
|
| 3 |
from .rope import RotaryPositionalEmbedding
|
| 4 |
from .attention import GroupedQueryAttention
|
|
|
|
| 1 |
+
"""Model architecture module — LLaMA-style Decoder-Only Transformer."""
|
| 2 |
from .norm import RMSNorm
|
| 3 |
from .rope import RotaryPositionalEmbedding
|
| 4 |
from .attention import GroupedQueryAttention
|
llm_lab/model/attention.py
CHANGED
|
@@ -11,20 +11,20 @@ from .rope import RotaryPositionalEmbedding
|
|
| 11 |
|
| 12 |
|
| 13 |
class GroupedQueryAttention(nn.Module):
|
| 14 |
-
"""GQA:
|
| 15 |
|
| 16 |
MHA vs GQA vs MQA:
|
| 17 |
-
- MHA (Multi-Head Attention): Q, K, V
|
| 18 |
-
- MQA (Multi-Query Attention): K, V
|
| 19 |
-
- GQA (Grouped Query Attention): K, V
|
| 20 |
-
→ MHA
|
| 21 |
|
| 22 |
-
|
| 23 |
-
Q
|
| 24 |
-
K/V
|
| 25 |
-
→ Q
|
| 26 |
|
| 27 |
-
Attention
|
| 28 |
Attention(Q, K, V) = softmax(Q·K^T / √d_k) · V
|
| 29 |
"""
|
| 30 |
|
|
@@ -36,14 +36,14 @@ class GroupedQueryAttention(nn.Module):
|
|
| 36 |
self.num_kv_heads = config.num_kv_heads
|
| 37 |
self.num_kv_groups = config.num_kv_groups # num_heads // num_kv_heads
|
| 38 |
|
| 39 |
-
# Q/K/V
|
| 40 |
# Q: hidden_dim → num_heads × head_dim
|
| 41 |
self.q_proj = nn.Linear(config.hidden_dim, config.num_heads * self.head_dim, bias=False)
|
| 42 |
-
# K, V: hidden_dim → num_kv_heads × head_dim (
|
| 43 |
self.k_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
|
| 44 |
self.v_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
|
| 45 |
|
| 46 |
-
#
|
| 47 |
self.o_proj = nn.Linear(config.num_heads * self.head_dim, config.hidden_dim, bias=False)
|
| 48 |
|
| 49 |
# RoPE
|
|
@@ -51,7 +51,7 @@ class GroupedQueryAttention(nn.Module):
|
|
| 51 |
dim=self.head_dim, max_seq_len=config.max_seq_len, theta=config.rope_theta
|
| 52 |
)
|
| 53 |
|
| 54 |
-
# Attention dropout (
|
| 55 |
self.attn_dropout = nn.Dropout(config.dropout)
|
| 56 |
|
| 57 |
def forward(
|
|
@@ -64,7 +64,7 @@ class GroupedQueryAttention(nn.Module):
|
|
| 64 |
Args:
|
| 65 |
x: (batch_size, seq_len, hidden_dim)
|
| 66 |
mask: (seq_len, seq_len) causal mask
|
| 67 |
-
position_offset:
|
| 68 |
|
| 69 |
Returns:
|
| 70 |
(batch_size, seq_len, hidden_dim)
|
|
@@ -72,13 +72,13 @@ class GroupedQueryAttention(nn.Module):
|
|
| 72 |
B, S, _ = x.shape
|
| 73 |
|
| 74 |
# ──────────────────────────────────────────────
|
| 75 |
-
# Step 1: Q, K, V
|
| 76 |
# ──────────────────────────────────────────────
|
| 77 |
q = self.q_proj(x) # (B, S, num_heads × head_dim)
|
| 78 |
k = self.k_proj(x) # (B, S, num_kv_heads × head_dim)
|
| 79 |
v = self.v_proj(x) # (B, S, num_kv_heads × head_dim)
|
| 80 |
|
| 81 |
-
#
|
| 82 |
q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
| 83 |
# → (B, num_heads, S, head_dim)
|
| 84 |
k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
|
@@ -86,16 +86,16 @@ class GroupedQueryAttention(nn.Module):
|
|
| 86 |
v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 87 |
|
| 88 |
# ──────────────────────────────────────────────
|
| 89 |
-
# Step 2: RoPE
|
| 90 |
# ──────────────────────────────────────────────
|
| 91 |
-
#
|
| 92 |
-
# "
|
| 93 |
q, k = self.rope(q, k, position_offset)
|
| 94 |
|
| 95 |
# ──────────────────────────────────────────────
|
| 96 |
-
# Step 3: GQA - KV
|
| 97 |
# ──────────────────────────────────────────────
|
| 98 |
-
# num_kv_heads=4 → num_heads=16:
|
| 99 |
if self.num_kv_groups > 1:
|
| 100 |
k = self._repeat_kv(k) # (B, num_heads, S, head_dim)
|
| 101 |
v = self._repeat_kv(v)
|
|
@@ -103,17 +103,17 @@ class GroupedQueryAttention(nn.Module):
|
|
| 103 |
# ──────────────────────────────────────────────
|
| 104 |
# Step 4: Scaled Dot-Product Attention
|
| 105 |
# ──────────────────────────────────────────────
|
| 106 |
-
# PyTorch >= 2.0
|
| 107 |
attn_out = F.scaled_dot_product_attention(
|
| 108 |
q, k, v,
|
| 109 |
attn_mask=mask,
|
| 110 |
dropout_p=self.config.dropout if self.training else 0.0,
|
| 111 |
-
is_causal=(mask is None), #
|
| 112 |
)
|
| 113 |
# → (B, num_heads, S, head_dim)
|
| 114 |
|
| 115 |
# ──────────────────────────────────────────────
|
| 116 |
-
# Step 5:
|
| 117 |
# ──────────────────────────────────────────────
|
| 118 |
attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
|
| 119 |
# → (B, S, num_heads × head_dim)
|
|
@@ -121,11 +121,11 @@ class GroupedQueryAttention(nn.Module):
|
|
| 121 |
return self.o_proj(attn_out) # → (B, S, hidden_dim)
|
| 122 |
|
| 123 |
def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
|
| 124 |
-
"""KV
|
| 125 |
|
| 126 |
(B, num_kv_heads, S, head_dim) → (B, num_heads, S, head_dim)
|
| 127 |
|
| 128 |
-
|
| 129 |
[kv0, kv1, kv2, kv3] → [kv0,kv0,kv0,kv0, kv1,kv1,kv1,kv1, ...]
|
| 130 |
"""
|
| 131 |
B, H_kv, S, D = x.shape
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class GroupedQueryAttention(nn.Module):
|
| 14 |
+
"""GQA: A memory-efficient variant of Multi-Head Attention.
|
| 15 |
|
| 16 |
MHA vs GQA vs MQA:
|
| 17 |
+
- MHA (Multi-Head Attention): Q, K, V all have num_heads → high memory usage
|
| 18 |
+
- MQA (Multi-Query Attention): K, V share a single head → risk of quality degradation
|
| 19 |
+
- GQA (Grouped Query Attention): K, V are grouped into num_kv_heads
|
| 20 |
+
→ a middle ground between MHA and MQA, good quality-efficiency balance
|
| 21 |
|
| 22 |
+
Example (num_heads=16, num_kv_heads=4):
|
| 23 |
+
Q heads: [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15]
|
| 24 |
+
K/V groups: [ 0 , 1 , 2 , 3 ]
|
| 25 |
+
→ 4 Q heads share 1 K/V head
|
| 26 |
|
| 27 |
+
Attention formula:
|
| 28 |
Attention(Q, K, V) = softmax(Q·K^T / √d_k) · V
|
| 29 |
"""
|
| 30 |
|
|
|
|
| 36 |
self.num_kv_heads = config.num_kv_heads
|
| 37 |
self.num_kv_groups = config.num_kv_groups # num_heads // num_kv_heads
|
| 38 |
|
| 39 |
+
# Q/K/V projections
|
| 40 |
# Q: hidden_dim → num_heads × head_dim
|
| 41 |
self.q_proj = nn.Linear(config.hidden_dim, config.num_heads * self.head_dim, bias=False)
|
| 42 |
+
# K, V: hidden_dim → num_kv_heads × head_dim (smaller than Q!)
|
| 43 |
self.k_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
|
| 44 |
self.v_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
|
| 45 |
|
| 46 |
+
# Output projection: merge all head outputs back to hidden_dim
|
| 47 |
self.o_proj = nn.Linear(config.num_heads * self.head_dim, config.hidden_dim, bias=False)
|
| 48 |
|
| 49 |
# RoPE
|
|
|
|
| 51 |
dim=self.head_dim, max_seq_len=config.max_seq_len, theta=config.rope_theta
|
| 52 |
)
|
| 53 |
|
| 54 |
+
# Attention dropout (typically 0 during pretraining)
|
| 55 |
self.attn_dropout = nn.Dropout(config.dropout)
|
| 56 |
|
| 57 |
def forward(
|
|
|
|
| 64 |
Args:
|
| 65 |
x: (batch_size, seq_len, hidden_dim)
|
| 66 |
mask: (seq_len, seq_len) causal mask
|
| 67 |
+
position_offset: position offset (used during inference)
|
| 68 |
|
| 69 |
Returns:
|
| 70 |
(batch_size, seq_len, hidden_dim)
|
|
|
|
| 72 |
B, S, _ = x.shape
|
| 73 |
|
| 74 |
# ──────────────────────────────────────────────
|
| 75 |
+
# Step 1: Q, K, V projections
|
| 76 |
# ──────────────────────────────────────────────
|
| 77 |
q = self.q_proj(x) # (B, S, num_heads × head_dim)
|
| 78 |
k = self.k_proj(x) # (B, S, num_kv_heads × head_dim)
|
| 79 |
v = self.v_proj(x) # (B, S, num_kv_heads × head_dim)
|
| 80 |
|
| 81 |
+
# Reshape into multi-head form
|
| 82 |
q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
| 83 |
# → (B, num_heads, S, head_dim)
|
| 84 |
k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
|
|
|
| 86 |
v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 87 |
|
| 88 |
# ──────────────────────────────────────────────
|
| 89 |
+
# Step 2: Apply RoPE (to Q and K only! Not to V)
|
| 90 |
# ──────────────────────────────────────────────
|
| 91 |
+
# Positional information should only affect "where to attend" (Q·K),
|
| 92 |
+
# not "what to retrieve" (V).
|
| 93 |
q, k = self.rope(q, k, position_offset)
|
| 94 |
|
| 95 |
# ──────────────────────────────────────────────
|
| 96 |
+
# Step 3: GQA - expand KV heads (repeat)
|
| 97 |
# ──────────────────────────────────────────────
|
| 98 |
+
# num_kv_heads=4 → num_heads=16: repeat each KV 4 times
|
| 99 |
if self.num_kv_groups > 1:
|
| 100 |
k = self._repeat_kv(k) # (B, num_heads, S, head_dim)
|
| 101 |
v = self._repeat_kv(v)
|
|
|
|
| 103 |
# ──────────────────────────────────────────────
|
| 104 |
# Step 4: Scaled Dot-Product Attention
|
| 105 |
# ──────────────────────────────────────────────
|
| 106 |
+
# Uses PyTorch >= 2.0's optimized implementation (Flash Attention applied automatically)
|
| 107 |
attn_out = F.scaled_dot_product_attention(
|
| 108 |
q, k, v,
|
| 109 |
attn_mask=mask,
|
| 110 |
dropout_p=self.config.dropout if self.training else 0.0,
|
| 111 |
+
is_causal=(mask is None), # apply automatic causal masking when no mask is provided
|
| 112 |
)
|
| 113 |
# → (B, num_heads, S, head_dim)
|
| 114 |
|
| 115 |
# ──────────────────────────────────────────────
|
| 116 |
+
# Step 5: Merge heads + output projection
|
| 117 |
# ──────────────────────────────────────────────
|
| 118 |
attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
|
| 119 |
# → (B, S, num_heads × head_dim)
|
|
|
|
| 121 |
return self.o_proj(attn_out) # → (B, S, hidden_dim)
|
| 122 |
|
| 123 |
def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
|
| 124 |
+
"""Repeat KV heads to match the number of Q heads.
|
| 125 |
|
| 126 |
(B, num_kv_heads, S, head_dim) → (B, num_heads, S, head_dim)
|
| 127 |
|
| 128 |
+
Example: num_kv_heads=4, num_kv_groups=4
|
| 129 |
[kv0, kv1, kv2, kv3] → [kv0,kv0,kv0,kv0, kv1,kv1,kv1,kv1, ...]
|
| 130 |
"""
|
| 131 |
B, H_kv, S, D = x.shape
|
llm_lab/model/feedforward.py
CHANGED
|
@@ -8,41 +8,41 @@ from llm_lab.config import ModelConfig
|
|
| 8 |
|
| 9 |
|
| 10 |
class SwiGLUFeedForward(nn.Module):
|
| 11 |
-
"""SwiGLU: Gated Linear Unit with Swish
|
| 12 |
|
| 13 |
-
|
| 14 |
FFN(x) = ReLU(x·W1 + b1)·W2 + b2
|
| 15 |
-
→
|
| 16 |
|
| 17 |
SwiGLU FFN:
|
| 18 |
SwiGLU(x) = (Swish(x·W_gate) ⊙ (x·W_up)) · W_down
|
| 19 |
-
→
|
| 20 |
|
| 21 |
-
|
| 22 |
-
- Swish(x) = x · sigmoid(x):
|
| 23 |
-
-
|
| 24 |
-
-
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(self, config: ModelConfig):
|
| 32 |
super().__init__()
|
| 33 |
-
#
|
| 34 |
self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
|
| 35 |
-
#
|
| 36 |
self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
|
| 37 |
-
#
|
| 38 |
self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)
|
| 39 |
|
| 40 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 41 |
# SwiGLU(x) = (Swish(gate(x)) ⊙ up(x)) · down
|
| 42 |
#
|
| 43 |
-
# 1) gate:
|
| 44 |
gate = F.silu(self.gate_proj(x)) # silu = Swish = x * sigmoid(x)
|
| 45 |
-
# 2) up:
|
| 46 |
up = self.up_proj(x)
|
| 47 |
-
# 3) element-wise
|
| 48 |
return self.down_proj(gate * up)
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class SwiGLUFeedForward(nn.Module):
|
| 11 |
+
"""SwiGLU: Gated Linear Unit with Swish activation function.
|
| 12 |
|
| 13 |
+
Standard FFN:
|
| 14 |
FFN(x) = ReLU(x·W1 + b1)·W2 + b2
|
| 15 |
+
→ simple nonlinear transformation
|
| 16 |
|
| 17 |
SwiGLU FFN:
|
| 18 |
SwiGLU(x) = (Swish(x·W_gate) ⊙ (x·W_up)) · W_down
|
| 19 |
+
→ controls information flow via a gating mechanism
|
| 20 |
|
| 21 |
+
Why is SwiGLU better?
|
| 22 |
+
- Swish(x) = x · sigmoid(x): smooth activation, allows some negative values
|
| 23 |
+
- The gate vector learns "which information to let through"
|
| 24 |
+
- Consistently reported to outperform ReLU FFN in PaLM, LLaMA, etc.
|
| 25 |
|
| 26 |
+
Note: Having two up-projections (W_gate and W_up) means
|
| 27 |
+
1.5x the parameters of a standard FFN, but intermediate_dim is
|
| 28 |
+
adjusted to match the total parameter count.
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(self, config: ModelConfig):
|
| 32 |
super().__init__()
|
| 33 |
+
# Gate projection: hidden_dim → intermediate_dim
|
| 34 |
self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
|
| 35 |
+
# Up projection: hidden_dim → intermediate_dim
|
| 36 |
self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
|
| 37 |
+
# Down projection: intermediate_dim → hidden_dim
|
| 38 |
self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)
|
| 39 |
|
| 40 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 41 |
# SwiGLU(x) = (Swish(gate(x)) ⊙ up(x)) · down
|
| 42 |
#
|
| 43 |
+
# 1) gate: decides which information to pass through (Swish activation)
|
| 44 |
gate = F.silu(self.gate_proj(x)) # silu = Swish = x * sigmoid(x)
|
| 45 |
+
# 2) up: projects information to a higher dimension
|
| 46 |
up = self.up_proj(x)
|
| 47 |
+
# 3) element-wise multiplication (gating) → project back to original dimension
|
| 48 |
return self.down_proj(gate * up)
|
llm_lab/model/llm_model.py
CHANGED
|
@@ -13,19 +13,19 @@ from .transformer_block import TransformerBlock
|
|
| 13 |
|
| 14 |
|
| 15 |
class LLMModel(nn.Module):
|
| 16 |
-
"""1B
|
| 17 |
|
| 18 |
-
|
| 19 |
Input Token IDs
|
| 20 |
→ Token Embedding
|
| 21 |
→ [TransformerBlock] × num_layers (+ Activation Checkpointing)
|
| 22 |
-
→ RMSNorm (
|
| 23 |
→ Linear Head (→ vocab logits)
|
| 24 |
|
| 25 |
Weight Tying:
|
| 26 |
-
-
|
| 27 |
-
-
|
| 28 |
-
-
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(self, config: ModelConfig):
|
|
@@ -41,29 +41,29 @@ class LLMModel(nn.Module):
|
|
| 41 |
for i in range(config.num_layers)
|
| 42 |
])
|
| 43 |
|
| 44 |
-
# ──
|
| 45 |
self.final_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
|
| 46 |
|
| 47 |
-
# ──
|
| 48 |
self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
|
| 49 |
-
# Weight Tying: lm_head
|
| 50 |
self.lm_head.weight = self.token_embedding.weight
|
| 51 |
|
| 52 |
-
#
|
| 53 |
self._init_weights()
|
| 54 |
|
| 55 |
def _init_weights(self):
|
| 56 |
-
"""
|
| 57 |
|
| 58 |
-
|
| 59 |
-
-
|
| 60 |
-
-
|
| 61 |
-
-
|
| 62 |
|
| 63 |
-
GPT-2
|
| 64 |
-
-
|
| 65 |
- Residual projection: N(0, 0.02 / √(2 × num_layers))
|
| 66 |
-
→
|
| 67 |
"""
|
| 68 |
std = 0.02
|
| 69 |
residual_std = std / math.sqrt(2 * self.config.num_layers)
|
|
@@ -76,7 +76,7 @@ class LLMModel(nn.Module):
|
|
| 76 |
elif isinstance(module, nn.Embedding):
|
| 77 |
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 78 |
|
| 79 |
-
#
|
| 80 |
for layer in self.layers:
|
| 81 |
nn.init.normal_(layer.attention.o_proj.weight, mean=0.0, std=residual_std)
|
| 82 |
nn.init.normal_(layer.feed_forward.down_proj.weight, mean=0.0, std=residual_std)
|
|
@@ -89,55 +89,55 @@ class LLMModel(nn.Module):
|
|
| 89 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 90 |
"""
|
| 91 |
Args:
|
| 92 |
-
input_ids: (batch_size, seq_len) -
|
| 93 |
-
targets: (batch_size, seq_len) -
|
| 94 |
-
position_offset:
|
| 95 |
|
| 96 |
Returns:
|
| 97 |
logits: (batch_size, seq_len, vocab_size)
|
| 98 |
-
loss:
|
| 99 |
"""
|
| 100 |
B, S = input_ids.shape
|
| 101 |
|
| 102 |
# ── Step 1: Token Embedding ──
|
| 103 |
-
#
|
| 104 |
h = self.token_embedding(input_ids) # (B, S, hidden_dim)
|
| 105 |
|
| 106 |
# ── Step 2: Transformer Blocks ──
|
| 107 |
-
# Activation Checkpointing:
|
| 108 |
-
# (
|
| 109 |
for layer in self.layers:
|
| 110 |
if self.training and torch.is_grad_enabled():
|
| 111 |
-
# Activation Checkpointing
|
| 112 |
h = torch.utils.checkpoint.checkpoint(
|
| 113 |
layer, h, None, position_offset,
|
| 114 |
-
use_reentrant=False, # PyTorch >= 2.0
|
| 115 |
)
|
| 116 |
else:
|
| 117 |
h = layer(h, mask=None, position_offset=position_offset)
|
| 118 |
|
| 119 |
-
# ── Step 3:
|
| 120 |
h = self.final_norm(h)
|
| 121 |
|
| 122 |
-
# ── Step 4:
|
| 123 |
logits = self.lm_head(h) # (B, S, vocab_size)
|
| 124 |
|
| 125 |
-
# ── Step 5:
|
| 126 |
loss = None
|
| 127 |
if targets is not None:
|
| 128 |
-
# Cross-Entropy Loss:
|
| 129 |
# logits: (B, S, V) → (B*S, V)
|
| 130 |
# targets: (B, S) → (B*S,)
|
| 131 |
loss = F.cross_entropy(
|
| 132 |
logits.view(-1, self.config.vocab_size),
|
| 133 |
targets.view(-1),
|
| 134 |
-
ignore_index=-100, #
|
| 135 |
)
|
| 136 |
|
| 137 |
return logits, loss
|
| 138 |
|
| 139 |
def count_parameters(self, trainable_only: bool = True) -> int:
|
| 140 |
-
"""
|
| 141 |
if trainable_only:
|
| 142 |
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 143 |
return sum(p.numel() for p in self.parameters())
|
|
@@ -151,50 +151,50 @@ class LLMModel(nn.Module):
|
|
| 151 |
top_k: int = 50,
|
| 152 |
top_p: float = 0.9,
|
| 153 |
) -> torch.Tensor:
|
| 154 |
-
"""
|
| 155 |
|
| 156 |
-
Autoregressive
|
| 157 |
|
| 158 |
Args:
|
| 159 |
-
input_ids: (1, prompt_len) -
|
| 160 |
-
max_new_tokens:
|
| 161 |
-
temperature:
|
| 162 |
-
top_k:
|
| 163 |
-
top_p:
|
| 164 |
"""
|
| 165 |
self.eval()
|
| 166 |
generated = input_ids
|
| 167 |
|
| 168 |
for _ in range(max_new_tokens):
|
| 169 |
-
#
|
| 170 |
ctx = generated[:, -self.config.max_seq_len:]
|
| 171 |
|
| 172 |
# Forward pass
|
| 173 |
logits, _ = self(ctx)
|
| 174 |
-
#
|
| 175 |
next_logits = logits[:, -1, :] / temperature
|
| 176 |
|
| 177 |
-
# ── Top-K
|
| 178 |
if top_k > 0:
|
| 179 |
top_k_values, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
|
| 180 |
min_top_k = top_k_values[:, -1].unsqueeze(-1)
|
| 181 |
next_logits = next_logits.masked_fill(next_logits < min_top_k, float("-inf"))
|
| 182 |
|
| 183 |
-
# ── Top-P (Nucleus)
|
| 184 |
if top_p < 1.0:
|
| 185 |
sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
|
| 186 |
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 187 |
-
#
|
| 188 |
remove_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
|
| 189 |
sorted_logits[remove_mask] = float("-inf")
|
| 190 |
-
#
|
| 191 |
next_logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)
|
| 192 |
|
| 193 |
-
#
|
| 194 |
probs = F.softmax(next_logits, dim=-1)
|
| 195 |
next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
|
| 196 |
|
| 197 |
-
#
|
| 198 |
generated = torch.cat([generated, next_token], dim=1)
|
| 199 |
|
| 200 |
return generated
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class LLMModel(nn.Module):
|
| 16 |
+
"""1B parameter LLaMA-style Decoder-Only Transformer.
|
| 17 |
|
| 18 |
+
Overall structure:
|
| 19 |
Input Token IDs
|
| 20 |
→ Token Embedding
|
| 21 |
→ [TransformerBlock] × num_layers (+ Activation Checkpointing)
|
| 22 |
+
→ RMSNorm (final)
|
| 23 |
→ Linear Head (→ vocab logits)
|
| 24 |
|
| 25 |
Weight Tying:
|
| 26 |
+
- Shares weights between the input Embedding and the output Linear Head
|
| 27 |
+
- Saves parameters (~65M) while maintaining or improving performance
|
| 28 |
+
- Intuition: "representing word meaning" and "predicting words" use the same space
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(self, config: ModelConfig):
|
|
|
|
| 41 |
for i in range(config.num_layers)
|
| 42 |
])
|
| 43 |
|
| 44 |
+
# ── Final normalization ──
|
| 45 |
self.final_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
|
| 46 |
|
| 47 |
+
# ── Output head (Weight Tying) ──
|
| 48 |
self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
|
| 49 |
+
# Weight Tying: lm_head weights = token_embedding weights
|
| 50 |
self.lm_head.weight = self.token_embedding.weight
|
| 51 |
|
| 52 |
+
# Weight initialization
|
| 53 |
self._init_weights()
|
| 54 |
|
| 55 |
def _init_weights(self):
|
| 56 |
+
"""Weight initialization strategy.
|
| 57 |
|
| 58 |
+
Why does initialization matter?
|
| 59 |
+
- Too large: activation explosion → NaN
|
| 60 |
+
- Too small: gradient vanishing → training stagnation
|
| 61 |
+
- Proper initialization: keeps output variance consistent across layers
|
| 62 |
|
| 63 |
+
GPT-2 style initialization:
|
| 64 |
+
- General Linear: N(0, 0.02)
|
| 65 |
- Residual projection: N(0, 0.02 / √(2 × num_layers))
|
| 66 |
+
→ reduces residual contribution as depth increases for stability
|
| 67 |
"""
|
| 68 |
std = 0.02
|
| 69 |
residual_std = std / math.sqrt(2 * self.config.num_layers)
|
|
|
|
| 76 |
elif isinstance(module, nn.Embedding):
|
| 77 |
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 78 |
|
| 79 |
+
# Apply scaled-down initialization to residual projection layers
|
| 80 |
for layer in self.layers:
|
| 81 |
nn.init.normal_(layer.attention.o_proj.weight, mean=0.0, std=residual_std)
|
| 82 |
nn.init.normal_(layer.feed_forward.down_proj.weight, mean=0.0, std=residual_std)
|
|
|
|
| 89 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 90 |
"""
|
| 91 |
Args:
|
| 92 |
+
input_ids: (batch_size, seq_len) - token IDs
|
| 93 |
+
targets: (batch_size, seq_len) - ground-truth token IDs (during training)
|
| 94 |
+
position_offset: position offset (during inference)
|
| 95 |
|
| 96 |
Returns:
|
| 97 |
logits: (batch_size, seq_len, vocab_size)
|
| 98 |
+
loss: scalar (when targets are provided) or None
|
| 99 |
"""
|
| 100 |
B, S = input_ids.shape
|
| 101 |
|
| 102 |
# ── Step 1: Token Embedding ──
|
| 103 |
+
# Convert each token ID into a vector of dimension hidden_dim
|
| 104 |
h = self.token_embedding(input_ids) # (B, S, hidden_dim)
|
| 105 |
|
| 106 |
# ── Step 2: Transformer Blocks ──
|
| 107 |
+
# Activation Checkpointing: saves memory during training
|
| 108 |
+
# (does not store intermediate activations; recomputes them during backward)
|
| 109 |
for layer in self.layers:
|
| 110 |
if self.training and torch.is_grad_enabled():
|
| 111 |
+
# Apply Activation Checkpointing
|
| 112 |
h = torch.utils.checkpoint.checkpoint(
|
| 113 |
layer, h, None, position_offset,
|
| 114 |
+
use_reentrant=False, # recommended for PyTorch >= 2.0
|
| 115 |
)
|
| 116 |
else:
|
| 117 |
h = layer(h, mask=None, position_offset=position_offset)
|
| 118 |
|
| 119 |
+
# ── Step 3: Final normalization ──
|
| 120 |
h = self.final_norm(h)
|
| 121 |
|
| 122 |
+
# ── Step 4: Compute output logits ──
|
| 123 |
logits = self.lm_head(h) # (B, S, vocab_size)
|
| 124 |
|
| 125 |
+
# ── Step 5: Compute loss (during training) ──
|
| 126 |
loss = None
|
| 127 |
if targets is not None:
|
| 128 |
+
# Cross-Entropy Loss: next-token prediction
|
| 129 |
# logits: (B, S, V) → (B*S, V)
|
| 130 |
# targets: (B, S) → (B*S,)
|
| 131 |
loss = F.cross_entropy(
|
| 132 |
logits.view(-1, self.config.vocab_size),
|
| 133 |
targets.view(-1),
|
| 134 |
+
ignore_index=-100, # ignore padding tokens
|
| 135 |
)
|
| 136 |
|
| 137 |
return logits, loss
|
| 138 |
|
| 139 |
def count_parameters(self, trainable_only: bool = True) -> int:
|
| 140 |
+
"""Count the number of model parameters."""
|
| 141 |
if trainable_only:
|
| 142 |
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 143 |
return sum(p.numel() for p in self.parameters())
|
|
|
|
| 151 |
top_k: int = 50,
|
| 152 |
top_p: float = 0.9,
|
| 153 |
) -> torch.Tensor:
|
| 154 |
+
"""Text generation (inference).
|
| 155 |
|
| 156 |
+
Autoregressive generation: predicts and appends one token at a time.
|
| 157 |
|
| 158 |
Args:
|
| 159 |
+
input_ids: (1, prompt_len) - initial prompt
|
| 160 |
+
max_new_tokens: maximum number of tokens to generate
|
| 161 |
+
temperature: controls sharpness of probability distribution (lower = more conservative)
|
| 162 |
+
top_k: consider only the top k tokens by probability
|
| 163 |
+
top_p: consider only tokens up to cumulative probability p (nucleus sampling)
|
| 164 |
"""
|
| 165 |
self.eval()
|
| 166 |
generated = input_ids
|
| 167 |
|
| 168 |
for _ in range(max_new_tokens):
|
| 169 |
+
# Truncate if current sequence exceeds max_seq_len
|
| 170 |
ctx = generated[:, -self.config.max_seq_len:]
|
| 171 |
|
| 172 |
# Forward pass
|
| 173 |
logits, _ = self(ctx)
|
| 174 |
+
# Use only the last token's logits (next-token prediction)
|
| 175 |
next_logits = logits[:, -1, :] / temperature
|
| 176 |
|
| 177 |
+
# ── Top-K filtering ──
|
| 178 |
if top_k > 0:
|
| 179 |
top_k_values, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
|
| 180 |
min_top_k = top_k_values[:, -1].unsqueeze(-1)
|
| 181 |
next_logits = next_logits.masked_fill(next_logits < min_top_k, float("-inf"))
|
| 182 |
|
| 183 |
+
# ── Top-P (Nucleus) filtering ──
|
| 184 |
if top_p < 1.0:
|
| 185 |
sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
|
| 186 |
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 187 |
+
# Remove tokens where cumulative probability exceeds top_p
|
| 188 |
remove_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
|
| 189 |
sorted_logits[remove_mask] = float("-inf")
|
| 190 |
+
# Restore original order
|
| 191 |
next_logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)
|
| 192 |
|
| 193 |
+
# Sample from probability distribution
|
| 194 |
probs = F.softmax(next_logits, dim=-1)
|
| 195 |
next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
|
| 196 |
|
| 197 |
+
# Append generated token
|
| 198 |
generated = torch.cat([generated, next_token], dim=1)
|
| 199 |
|
| 200 |
return generated
|
llm_lab/model/norm.py
CHANGED
|
@@ -5,36 +5,36 @@ import torch.nn as nn
|
|
| 5 |
|
| 6 |
|
| 7 |
class RMSNorm(nn.Module):
|
| 8 |
-
"""RMSNorm:
|
| 9 |
|
| 10 |
-
|
| 11 |
-
-
|
| 12 |
-
-
|
| 13 |
-
- bias
|
| 14 |
|
| 15 |
-
|
| 16 |
RMSNorm(x) = (x / RMS(x)) * γ
|
| 17 |
RMS(x) = sqrt(mean(x²) + ε)
|
| 18 |
|
| 19 |
-
|
| 20 |
-
→
|
| 21 |
-
→
|
| 22 |
"""
|
| 23 |
|
| 24 |
def __init__(self, dim: int, eps: float = 1e-6):
|
| 25 |
super().__init__()
|
| 26 |
self.eps = eps
|
| 27 |
-
# γ (gamma):
|
| 28 |
self.weight = nn.Parameter(torch.ones(dim))
|
| 29 |
|
| 30 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 31 |
-
# 1)
|
| 32 |
-
#
|
| 33 |
x_float = x.float()
|
| 34 |
|
| 35 |
-
# 2)
|
| 36 |
rms = torch.rsqrt(x_float.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
| 37 |
-
# rsqrt = 1/sqrt(x) →
|
| 38 |
|
| 39 |
-
# 3)
|
| 40 |
return (x_float * rms).to(x.dtype) * self.weight
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class RMSNorm(nn.Module):
|
| 8 |
+
"""RMSNorm: A lightweight alternative to LayerNorm.
|
| 9 |
|
| 10 |
+
Differences from standard LayerNorm:
|
| 11 |
+
- Does not subtract the mean → saves computation
|
| 12 |
+
- Normalizes using RMS (Root Mean Square) instead of variance
|
| 13 |
+
- No bias parameter
|
| 14 |
|
| 15 |
+
Formula:
|
| 16 |
RMSNorm(x) = (x / RMS(x)) * γ
|
| 17 |
RMS(x) = sqrt(mean(x²) + ε)
|
| 18 |
|
| 19 |
+
Why is normalization necessary?
|
| 20 |
+
→ Stacking layers deeply causes activation values to explode or vanish.
|
| 21 |
+
→ Normalization keeps the input to each layer within a stable range.
|
| 22 |
"""
|
| 23 |
|
| 24 |
def __init__(self, dim: int, eps: float = 1e-6):
|
| 25 |
super().__init__()
|
| 26 |
self.eps = eps
|
| 27 |
+
# γ (gamma): learnable scale parameter, initialized to 1
|
| 28 |
self.weight = nn.Parameter(torch.ones(dim))
|
| 29 |
|
| 30 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 31 |
+
# 1) Cast input to float32 for numerical stability
|
| 32 |
+
# Computing the sum of squares in bf16/fp16 risks overflow
|
| 33 |
x_float = x.float()
|
| 34 |
|
| 35 |
+
# 2) Compute RMS: sqrt(mean(x²) + ε)
|
| 36 |
rms = torch.rsqrt(x_float.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
| 37 |
+
# rsqrt = 1/sqrt(x) → replaces division with multiplication (faster)
|
| 38 |
|
| 39 |
+
# 3) Normalize, restore original dtype, and apply scale
|
| 40 |
return (x_float * rms).to(x.dtype) * self.weight
|
llm_lab/model/rope.py
CHANGED
|
@@ -7,21 +7,23 @@ import torch.nn as nn
|
|
| 7 |
|
| 8 |
|
| 9 |
class RotaryPositionalEmbedding(nn.Module):
|
| 10 |
-
"""RoPE:
|
| 11 |
|
| 12 |
-
|
| 13 |
-
-
|
| 14 |
-
|
| 15 |
-
-
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
-
|
| 19 |
-
|
| 20 |
-
-
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
θ_i = theta^(-2i/d) (i = 0, 1, ..., d/2-1)
|
| 24 |
-
RoPE(x, pos) = x
|
| 25 |
"""
|
| 26 |
|
| 27 |
def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
|
|
@@ -30,16 +32,16 @@ class RotaryPositionalEmbedding(nn.Module):
|
|
| 30 |
self.max_seq_len = max_seq_len
|
| 31 |
self.theta = theta
|
| 32 |
|
| 33 |
-
#
|
| 34 |
# freqs[i] = 1 / (theta^(2i/dim)), i = 0, 1, ..., dim/2-1
|
| 35 |
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
|
| 36 |
self.register_buffer("freqs", freqs, persistent=False)
|
| 37 |
|
| 38 |
-
# (max_seq_len, dim/2)
|
| 39 |
self._build_cache(max_seq_len)
|
| 40 |
|
| 41 |
def _build_cache(self, seq_len: int):
|
| 42 |
-
"""cos/sin
|
| 43 |
t = torch.arange(seq_len, device=self.freqs.device, dtype=torch.float32)
|
| 44 |
# outer product: (seq_len,) × (dim/2,) → (seq_len, dim/2)
|
| 45 |
angles = torch.outer(t, self.freqs)
|
|
@@ -49,23 +51,23 @@ class RotaryPositionalEmbedding(nn.Module):
|
|
| 49 |
def forward(
|
| 50 |
self, q: torch.Tensor, k: torch.Tensor, position_offset: int = 0
|
| 51 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 52 |
-
"""
|
| 53 |
|
| 54 |
Args:
|
| 55 |
q: (batch, num_heads, seq_len, head_dim)
|
| 56 |
k: (batch, num_kv_heads, seq_len, head_dim)
|
| 57 |
-
position_offset:
|
| 58 |
|
| 59 |
Returns:
|
| 60 |
-
|
| 61 |
"""
|
| 62 |
seq_len = q.shape[2]
|
| 63 |
|
| 64 |
-
#
|
| 65 |
if position_offset + seq_len > self.cos_cached.shape[0]:
|
| 66 |
self._build_cache(position_offset + seq_len)
|
| 67 |
|
| 68 |
-
#
|
| 69 |
cos = self.cos_cached[position_offset : position_offset + seq_len] # (seq_len, dim/2)
|
| 70 |
sin = self.sin_cached[position_offset : position_offset + seq_len]
|
| 71 |
|
|
@@ -77,27 +79,27 @@ class RotaryPositionalEmbedding(nn.Module):
|
|
| 77 |
def _apply_rotation(
|
| 78 |
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
| 79 |
) -> torch.Tensor:
|
| 80 |
-
"""
|
| 81 |
|
| 82 |
-
2D
|
| 83 |
[cos θ, -sin θ] [x1] [x1·cos θ - x2·sin θ]
|
| 84 |
[sin θ, cos θ] [x2] = [x1·sin θ + x2·cos θ]
|
| 85 |
|
| 86 |
-
|
| 87 |
"""
|
| 88 |
# x: (batch, heads, seq_len, head_dim)
|
| 89 |
-
#
|
| 90 |
-
x_even = x[..., 0::2] #
|
| 91 |
-
x_odd = x[..., 1::2] #
|
| 92 |
|
| 93 |
-
#
|
| 94 |
cos = cos.unsqueeze(0).unsqueeze(0)
|
| 95 |
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 96 |
|
| 97 |
-
#
|
| 98 |
rotated_even = x_even * cos - x_odd * sin
|
| 99 |
rotated_odd = x_even * sin + x_odd * cos
|
| 100 |
|
| 101 |
-
#
|
| 102 |
out = torch.stack([rotated_even, rotated_odd], dim=-1)
|
| 103 |
-
return out.flatten(-2) #
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class RotaryPositionalEmbedding(nn.Module):
|
| 10 |
+
"""RoPE: Relative positional encoding using rotation matrices.
|
| 11 |
|
| 12 |
+
Core idea:
|
| 13 |
+
- Each dimension pair (2i, 2i+1) is treated as coordinates in a 2D plane,
|
| 14 |
+
and is rotated by an angle proportional to the position.
|
| 15 |
+
- The attention score (Q·K) between two tokens depends only on their
|
| 16 |
+
relative distance.
|
| 17 |
|
| 18 |
+
Why RoPE?
|
| 19 |
+
- Absolute positional embeddings: add a fixed vector at each position
|
| 20 |
+
→ difficult to generalize to longer sequences
|
| 21 |
+
- Relative positional embeddings: complex implementation, extra parameters needed
|
| 22 |
+
- RoPE: encodes relative position information naturally with no extra parameters
|
| 23 |
|
| 24 |
+
Formula:
|
| 25 |
θ_i = theta^(-2i/d) (i = 0, 1, ..., d/2-1)
|
| 26 |
+
RoPE(x, pos) = rotate x in each dimension pair by pos × θ_i
|
| 27 |
"""
|
| 28 |
|
| 29 |
def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
|
|
|
|
| 32 |
self.max_seq_len = max_seq_len
|
| 33 |
self.theta = theta
|
| 34 |
|
| 35 |
+
# Pre-compute frequency vector (no training needed → register as buffer)
|
| 36 |
# freqs[i] = 1 / (theta^(2i/dim)), i = 0, 1, ..., dim/2-1
|
| 37 |
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
|
| 38 |
self.register_buffer("freqs", freqs, persistent=False)
|
| 39 |
|
| 40 |
+
# Pre-compute cos/sin table of shape (max_seq_len, dim/2)
|
| 41 |
self._build_cache(max_seq_len)
|
| 42 |
|
| 43 |
def _build_cache(self, seq_len: int):
|
| 44 |
+
"""Pre-compute and cache cos/sin values."""
|
| 45 |
t = torch.arange(seq_len, device=self.freqs.device, dtype=torch.float32)
|
| 46 |
# outer product: (seq_len,) × (dim/2,) → (seq_len, dim/2)
|
| 47 |
angles = torch.outer(t, self.freqs)
|
|
|
|
| 51 |
def forward(
|
| 52 |
self, q: torch.Tensor, k: torch.Tensor, position_offset: int = 0
|
| 53 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 54 |
+
"""Apply rotary transformation to Q and K.
|
| 55 |
|
| 56 |
Args:
|
| 57 |
q: (batch, num_heads, seq_len, head_dim)
|
| 58 |
k: (batch, num_kv_heads, seq_len, head_dim)
|
| 59 |
+
position_offset: sequence start position offset (used with KV cache during inference)
|
| 60 |
|
| 61 |
Returns:
|
| 62 |
+
(q_rotated, k_rotated) with rotary transformation applied
|
| 63 |
"""
|
| 64 |
seq_len = q.shape[2]
|
| 65 |
|
| 66 |
+
# Extend cache if needed
|
| 67 |
if position_offset + seq_len > self.cos_cached.shape[0]:
|
| 68 |
self._build_cache(position_offset + seq_len)
|
| 69 |
|
| 70 |
+
# Slice cos/sin values for the current positions
|
| 71 |
cos = self.cos_cached[position_offset : position_offset + seq_len] # (seq_len, dim/2)
|
| 72 |
sin = self.sin_cached[position_offset : position_offset + seq_len]
|
| 73 |
|
|
|
|
| 79 |
def _apply_rotation(
|
| 80 |
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
| 81 |
) -> torch.Tensor:
|
| 82 |
+
"""Apply rotation transformation.
|
| 83 |
|
| 84 |
+
2D rotation matrix:
|
| 85 |
[cos θ, -sin θ] [x1] [x1·cos θ - x2·sin θ]
|
| 86 |
[sin θ, cos θ] [x2] = [x1·sin θ + x2·cos θ]
|
| 87 |
|
| 88 |
+
Implemented efficiently using vectorized operations.
|
| 89 |
"""
|
| 90 |
# x: (batch, heads, seq_len, head_dim)
|
| 91 |
+
# Separate even/odd indices: (x0, x1, x2, x3, ...) → (x0, x2, ...), (x1, x3, ...)
|
| 92 |
+
x_even = x[..., 0::2] # even indices
|
| 93 |
+
x_odd = x[..., 1::2] # odd indices
|
| 94 |
|
| 95 |
+
# Adjust dimensions for broadcasting: (seq_len, dim/2) → (1, 1, seq_len, dim/2)
|
| 96 |
cos = cos.unsqueeze(0).unsqueeze(0)
|
| 97 |
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 98 |
|
| 99 |
+
# Apply rotation
|
| 100 |
rotated_even = x_even * cos - x_odd * sin
|
| 101 |
rotated_odd = x_even * sin + x_odd * cos
|
| 102 |
|
| 103 |
+
# Re-interleave: (even0, odd0, even1, odd1, ...)
|
| 104 |
out = torch.stack([rotated_even, rotated_odd], dim=-1)
|
| 105 |
+
return out.flatten(-2) # Merge last two dimensions to restore original shape
|
llm_lab/model/transformer_block.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Transformer Block (
|
| 2 |
|
| 3 |
from typing import Optional
|
| 4 |
|
|
@@ -12,32 +12,32 @@ from .feedforward import SwiGLUFeedForward
|
|
| 12 |
|
| 13 |
|
| 14 |
class TransformerBlock(nn.Module):
|
| 15 |
-
"""
|
| 16 |
|
| 17 |
-
|
| 18 |
x → RMSNorm → Attention → + (residual) → RMSNorm → FFN → + (residual) → out
|
| 19 |
|
| 20 |
Pre-Norm vs Post-Norm:
|
| 21 |
-
- Post-Norm (
|
| 22 |
-
→
|
| 23 |
-
- Pre-Norm (GPT-2
|
| 24 |
-
→ gradient
|
| 25 |
|
| 26 |
-
Residual Connection
|
| 27 |
-
-
|
| 28 |
-
-
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(self, config: ModelConfig, layer_idx: int):
|
| 32 |
super().__init__()
|
| 33 |
self.layer_idx = layer_idx
|
| 34 |
|
| 35 |
-
# Pre-Norm:
|
| 36 |
self.attn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
|
| 37 |
# Self-Attention
|
| 38 |
self.attention = GroupedQueryAttention(config)
|
| 39 |
|
| 40 |
-
# Pre-Norm:
|
| 41 |
self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
|
| 42 |
# Feed-Forward Network
|
| 43 |
self.feed_forward = SwiGLUFeedForward(config)
|
|
|
|
| 1 |
+
"""Transformer Block (a single layer)."""
|
| 2 |
|
| 3 |
from typing import Optional
|
| 4 |
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class TransformerBlock(nn.Module):
|
| 15 |
+
"""A single Transformer decoder block.
|
| 16 |
|
| 17 |
+
Structure (Pre-Norm style):
|
| 18 |
x → RMSNorm → Attention → + (residual) → RMSNorm → FFN → + (residual) → out
|
| 19 |
|
| 20 |
Pre-Norm vs Post-Norm:
|
| 21 |
+
- Post-Norm (original Transformer): LayerNorm applied after the residual
|
| 22 |
+
→ training instability in deep models
|
| 23 |
+
- Pre-Norm (standard since GPT-2): LayerNorm applied before the sublayer
|
| 24 |
+
→ smooth gradient flow, stable training
|
| 25 |
|
| 26 |
+
Role of Residual Connection:
|
| 27 |
+
- Adds the input to the output → a "highway" that lets gradients skip layers
|
| 28 |
+
- The key reason training is feasible even with 22 stacked layers
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(self, config: ModelConfig, layer_idx: int):
|
| 32 |
super().__init__()
|
| 33 |
self.layer_idx = layer_idx
|
| 34 |
|
| 35 |
+
# Pre-Norm: normalization before Attention
|
| 36 |
self.attn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
|
| 37 |
# Self-Attention
|
| 38 |
self.attention = GroupedQueryAttention(config)
|
| 39 |
|
| 40 |
+
# Pre-Norm: normalization before FFN
|
| 41 |
self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
|
| 42 |
# Feed-Forward Network
|
| 43 |
self.feed_forward = SwiGLUFeedForward(config)
|
llm_lab/model/utils.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|
| 12 |
|
| 13 |
|
| 14 |
def count_parameters_detailed(model: "LLMModel") -> dict:
|
| 15 |
-
"""
|
| 16 |
total = 0
|
| 17 |
breakdown = {}
|
| 18 |
|
|
@@ -21,7 +21,7 @@ def count_parameters_detailed(model: "LLMModel") -> dict:
|
|
| 21 |
breakdown["token_embedding"] = emb_params
|
| 22 |
total += emb_params
|
| 23 |
|
| 24 |
-
#
|
| 25 |
layer_total = 0
|
| 26 |
layer_detail = {}
|
| 27 |
layer = model.layers[0]
|
|
@@ -40,7 +40,7 @@ def count_parameters_detailed(model: "LLMModel") -> dict:
|
|
| 40 |
breakdown["final_norm"] = norm_params
|
| 41 |
total += norm_params
|
| 42 |
|
| 43 |
-
# LM head (weight tying
|
| 44 |
breakdown["lm_head"] = "weight tying (0 additional)"
|
| 45 |
breakdown["total"] = total
|
| 46 |
|
|
@@ -48,12 +48,12 @@ def count_parameters_detailed(model: "LLMModel") -> dict:
|
|
| 48 |
|
| 49 |
|
| 50 |
def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict:
|
| 51 |
-
"""
|
| 52 |
|
| 53 |
Args:
|
| 54 |
-
dtype_bytes: 2 (bf16/fp16)
|
| 55 |
"""
|
| 56 |
-
#
|
| 57 |
emb = config.vocab_size * config.hidden_dim
|
| 58 |
per_layer = (
|
| 59 |
config.hidden_dim * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim # QKV
|
|
@@ -67,11 +67,11 @@ def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: in
|
|
| 67 |
optimizer_gb = total_params * 8 / 1e9 # AdamW: 2 states × fp32
|
| 68 |
gradient_gb = total_params * dtype_bytes / 1e9
|
| 69 |
|
| 70 |
-
#
|
| 71 |
-
#
|
| 72 |
activation_gb = (
|
| 73 |
-
batch_size * config.max_seq_len * config.hidden_dim * 4 #
|
| 74 |
-
* math.sqrt(config.num_layers) #
|
| 75 |
/ 1e9
|
| 76 |
)
|
| 77 |
|
|
|
|
| 1 |
+
"""Model utility functions."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def count_parameters_detailed(model: "LLMModel") -> dict:
|
| 15 |
+
"""Print a detailed breakdown of the model's parameter count by component."""
|
| 16 |
total = 0
|
| 17 |
breakdown = {}
|
| 18 |
|
|
|
|
| 21 |
breakdown["token_embedding"] = emb_params
|
| 22 |
total += emb_params
|
| 23 |
|
| 24 |
+
# Per layer
|
| 25 |
layer_total = 0
|
| 26 |
layer_detail = {}
|
| 27 |
layer = model.layers[0]
|
|
|
|
| 40 |
breakdown["final_norm"] = norm_params
|
| 41 |
total += norm_params
|
| 42 |
|
| 43 |
+
# LM head (weight tying, so 0 additional parameters)
|
| 44 |
breakdown["lm_head"] = "weight tying (0 additional)"
|
| 45 |
breakdown["total"] = total
|
| 46 |
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict:
|
| 51 |
+
"""Estimate GPU memory usage of the model.
|
| 52 |
|
| 53 |
Args:
|
| 54 |
+
dtype_bytes: 2 (bf16/fp16) or 4 (fp32)
|
| 55 |
"""
|
| 56 |
+
# Approximate parameter count
|
| 57 |
emb = config.vocab_size * config.hidden_dim
|
| 58 |
per_layer = (
|
| 59 |
config.hidden_dim * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim # QKV
|
|
|
|
| 67 |
optimizer_gb = total_params * 8 / 1e9 # AdamW: 2 states × fp32
|
| 68 |
gradient_gb = total_params * dtype_bytes / 1e9
|
| 69 |
|
| 70 |
+
# Activation memory (assuming activation checkpointing is applied)
|
| 71 |
+
# Rough estimate: batch_size × seq_len × hidden_dim × num_layers × factor
|
| 72 |
activation_gb = (
|
| 73 |
+
batch_size * config.max_seq_len * config.hidden_dim * 4 # bytes
|
| 74 |
+
* math.sqrt(config.num_layers) # effect of checkpointing
|
| 75 |
/ 1e9
|
| 76 |
)
|
| 77 |
|
llm_lab/training/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
from .scheduler import CosineWarmupScheduler
|
| 3 |
from .checkpoint import CheckpointManager
|
| 4 |
from .metrics import MetricsTracker
|
|
|
|
| 1 |
+
"""Training module — Gradient Accumulation, Mixed Precision, checkpointing, wandb logging."""
|
| 2 |
from .scheduler import CosineWarmupScheduler
|
| 3 |
from .checkpoint import CheckpointManager
|
| 4 |
from .metrics import MetricsTracker
|
llm_lab/training/checkpoint.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
import json
|
| 4 |
import shutil
|
|
@@ -13,22 +13,22 @@ from llm_lab.config import TrainConfig
|
|
| 13 |
|
| 14 |
|
| 15 |
class CheckpointManager:
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
-
|
| 20 |
-
- Google Drive
|
| 21 |
-
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
- model_state_dict:
|
| 25 |
-
- optimizer_state_dict:
|
| 26 |
-
- step:
|
| 27 |
-
- best_val_loss:
|
| 28 |
-
- config:
|
| 29 |
-
- rng_states:
|
| 30 |
-
- metrics_history:
|
| 31 |
-
- wandb_run_id: wandb
|
| 32 |
"""
|
| 33 |
|
| 34 |
def __init__(self, config: TrainConfig):
|
|
@@ -46,20 +46,20 @@ class CheckpointManager:
|
|
| 46 |
metrics_history: Dict[str, list],
|
| 47 |
wandb_run_id: Optional[str] = None,
|
| 48 |
):
|
| 49 |
-
"""
|
| 50 |
ckpt_path = self.checkpoint_dir / f"step_{step:06d}"
|
| 51 |
ckpt_path.mkdir(parents=True, exist_ok=True)
|
| 52 |
|
| 53 |
-
print(f"\n💾
|
| 54 |
start = time.time()
|
| 55 |
|
| 56 |
-
# 1)
|
| 57 |
torch.save(model.state_dict(), ckpt_path / "model.pt")
|
| 58 |
|
| 59 |
-
# 2)
|
| 60 |
torch.save(optimizer.state_dict(), ckpt_path / "optimizer.pt")
|
| 61 |
|
| 62 |
-
# 3)
|
| 63 |
meta = {
|
| 64 |
"step": step,
|
| 65 |
"best_val_loss": best_val_loss,
|
|
@@ -69,10 +69,10 @@ class CheckpointManager:
|
|
| 69 |
with open(ckpt_path / "meta.json", "w") as f:
|
| 70 |
json.dump(meta, f, indent=2)
|
| 71 |
|
| 72 |
-
# 4)
|
| 73 |
torch.save(metrics_history, ckpt_path / "metrics.pt")
|
| 74 |
|
| 75 |
-
# 5)
|
| 76 |
rng_states = {
|
| 77 |
"python": torch.random.get_rng_state(),
|
| 78 |
"cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
|
|
@@ -81,9 +81,9 @@ class CheckpointManager:
|
|
| 81 |
|
| 82 |
elapsed = time.time() - start
|
| 83 |
ckpt_size = sum(f.stat().st_size for f in ckpt_path.rglob("*")) / 1e9
|
| 84 |
-
print(f"
|
| 85 |
|
| 86 |
-
#
|
| 87 |
self._cleanup_old_checkpoints()
|
| 88 |
|
| 89 |
def load_latest(
|
|
@@ -92,42 +92,42 @@ class CheckpointManager:
|
|
| 92 |
optimizer: Optional[torch.optim.Optimizer] = None,
|
| 93 |
device: torch.device = torch.device("cpu"),
|
| 94 |
) -> Dict[str, Any]:
|
| 95 |
-
"""
|
| 96 |
|
| 97 |
Returns:
|
| 98 |
{"step", "best_val_loss", "wandb_run_id", "metrics_history"}
|
| 99 |
-
|
| 100 |
"""
|
| 101 |
ckpt_path = self._find_latest()
|
| 102 |
if ckpt_path is None:
|
| 103 |
-
print("[Checkpoint]
|
| 104 |
return None
|
| 105 |
|
| 106 |
-
print(f"\n📂
|
| 107 |
start = time.time()
|
| 108 |
|
| 109 |
-
# 1)
|
| 110 |
model_state = torch.load(ckpt_path / "model.pt", map_location=device, weights_only=True)
|
| 111 |
model.load_state_dict(model_state)
|
| 112 |
-
del model_state #
|
| 113 |
|
| 114 |
-
# 2)
|
| 115 |
if optimizer is not None:
|
| 116 |
optim_state = torch.load(ckpt_path / "optimizer.pt", map_location=device, weights_only=True)
|
| 117 |
optimizer.load_state_dict(optim_state)
|
| 118 |
del optim_state
|
| 119 |
|
| 120 |
-
# 3)
|
| 121 |
with open(ckpt_path / "meta.json", "r") as f:
|
| 122 |
meta = json.load(f)
|
| 123 |
|
| 124 |
-
# 4)
|
| 125 |
metrics_history = {}
|
| 126 |
metrics_path = ckpt_path / "metrics.pt"
|
| 127 |
if metrics_path.exists():
|
| 128 |
metrics_history = torch.load(metrics_path, weights_only=False)
|
| 129 |
|
| 130 |
-
# 5)
|
| 131 |
rng_path = ckpt_path / "rng_states.pt"
|
| 132 |
if rng_path.exists():
|
| 133 |
rng_states = torch.load(rng_path, weights_only=False)
|
|
@@ -136,7 +136,7 @@ class CheckpointManager:
|
|
| 136 |
torch.cuda.set_rng_state(rng_states["cuda"])
|
| 137 |
|
| 138 |
elapsed = time.time() - start
|
| 139 |
-
print(f"
|
| 140 |
|
| 141 |
return {
|
| 142 |
"step": meta["step"],
|
|
@@ -146,14 +146,14 @@ class CheckpointManager:
|
|
| 146 |
}
|
| 147 |
|
| 148 |
def _find_latest(self) -> Optional[Path]:
|
| 149 |
-
"""
|
| 150 |
ckpts = sorted(self.checkpoint_dir.glob("step_*"))
|
| 151 |
return ckpts[-1] if ckpts else None
|
| 152 |
|
| 153 |
def _cleanup_old_checkpoints(self):
|
| 154 |
-
"""
|
| 155 |
ckpts = sorted(self.checkpoint_dir.glob("step_*"))
|
| 156 |
while len(ckpts) > self.max_checkpoints:
|
| 157 |
old = ckpts.pop(0)
|
| 158 |
-
print(f" 🗑️
|
| 159 |
shutil.rmtree(old)
|
|
|
|
| 1 |
+
"""Training state save/restore manager."""
|
| 2 |
|
| 3 |
import json
|
| 4 |
import shutil
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class CheckpointManager:
|
| 16 |
+
"""Training state save/restore manager.
|
| 17 |
+
|
| 18 |
+
Why checkpoints matter in Colab:
|
| 19 |
+
- Session expiry (up to ~24 hours) causes all in-memory state to be lost
|
| 20 |
+
- Saving to Google Drive enables continuous training across sessions
|
| 21 |
+
- Optimizer state must be saved to preserve AdamW momentum
|
| 22 |
+
|
| 23 |
+
Saved contents:
|
| 24 |
+
- model_state_dict: model weights
|
| 25 |
+
- optimizer_state_dict: optimizer state (m, v momentum)
|
| 26 |
+
- step: current training step
|
| 27 |
+
- best_val_loss: lowest validation loss
|
| 28 |
+
- config: training configuration (for reproducibility)
|
| 29 |
+
- rng_states: random seed state (full reproducibility)
|
| 30 |
+
- metrics_history: training metrics history
|
| 31 |
+
- wandb_run_id: wandb run ID (for logging continuity)
|
| 32 |
"""
|
| 33 |
|
| 34 |
def __init__(self, config: TrainConfig):
|
|
|
|
| 46 |
metrics_history: Dict[str, list],
|
| 47 |
wandb_run_id: Optional[str] = None,
|
| 48 |
):
|
| 49 |
+
"""Saves a checkpoint."""
|
| 50 |
ckpt_path = self.checkpoint_dir / f"step_{step:06d}"
|
| 51 |
ckpt_path.mkdir(parents=True, exist_ok=True)
|
| 52 |
|
| 53 |
+
print(f"\n💾 Saving checkpoint: {ckpt_path}")
|
| 54 |
start = time.time()
|
| 55 |
|
| 56 |
+
# 1) Model weights (saved as-is in bf16)
|
| 57 |
torch.save(model.state_dict(), ckpt_path / "model.pt")
|
| 58 |
|
| 59 |
+
# 2) Optimizer state (includes fp32 momentum, can be large)
|
| 60 |
torch.save(optimizer.state_dict(), ckpt_path / "optimizer.pt")
|
| 61 |
|
| 62 |
+
# 3) Training metadata
|
| 63 |
meta = {
|
| 64 |
"step": step,
|
| 65 |
"best_val_loss": best_val_loss,
|
|
|
|
| 69 |
with open(ckpt_path / "meta.json", "w") as f:
|
| 70 |
json.dump(meta, f, indent=2)
|
| 71 |
|
| 72 |
+
# 4) Metrics history
|
| 73 |
torch.save(metrics_history, ckpt_path / "metrics.pt")
|
| 74 |
|
| 75 |
+
# 5) Random states (for full reproducibility)
|
| 76 |
rng_states = {
|
| 77 |
"python": torch.random.get_rng_state(),
|
| 78 |
"cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
|
|
|
|
| 81 |
|
| 82 |
elapsed = time.time() - start
|
| 83 |
ckpt_size = sum(f.stat().st_size for f in ckpt_path.rglob("*")) / 1e9
|
| 84 |
+
print(f" Save complete: {ckpt_size:.2f} GB, {elapsed:.1f}s")
|
| 85 |
|
| 86 |
+
# Remove old checkpoints (rolling)
|
| 87 |
self._cleanup_old_checkpoints()
|
| 88 |
|
| 89 |
def load_latest(
|
|
|
|
| 92 |
optimizer: Optional[torch.optim.Optimizer] = None,
|
| 93 |
device: torch.device = torch.device("cpu"),
|
| 94 |
) -> Dict[str, Any]:
|
| 95 |
+
"""Loads the most recent checkpoint.
|
| 96 |
|
| 97 |
Returns:
|
| 98 |
{"step", "best_val_loss", "wandb_run_id", "metrics_history"}
|
| 99 |
+
or None if no checkpoint exists
|
| 100 |
"""
|
| 101 |
ckpt_path = self._find_latest()
|
| 102 |
if ckpt_path is None:
|
| 103 |
+
print("[Checkpoint] No saved checkpoint found. Starting from scratch.")
|
| 104 |
return None
|
| 105 |
|
| 106 |
+
print(f"\n📂 Loading checkpoint: {ckpt_path}")
|
| 107 |
start = time.time()
|
| 108 |
|
| 109 |
+
# 1) Model weights
|
| 110 |
model_state = torch.load(ckpt_path / "model.pt", map_location=device, weights_only=True)
|
| 111 |
model.load_state_dict(model_state)
|
| 112 |
+
del model_state # free memory
|
| 113 |
|
| 114 |
+
# 2) Optimizer state
|
| 115 |
if optimizer is not None:
|
| 116 |
optim_state = torch.load(ckpt_path / "optimizer.pt", map_location=device, weights_only=True)
|
| 117 |
optimizer.load_state_dict(optim_state)
|
| 118 |
del optim_state
|
| 119 |
|
| 120 |
+
# 3) Metadata
|
| 121 |
with open(ckpt_path / "meta.json", "r") as f:
|
| 122 |
meta = json.load(f)
|
| 123 |
|
| 124 |
+
# 4) Metrics history
|
| 125 |
metrics_history = {}
|
| 126 |
metrics_path = ckpt_path / "metrics.pt"
|
| 127 |
if metrics_path.exists():
|
| 128 |
metrics_history = torch.load(metrics_path, weights_only=False)
|
| 129 |
|
| 130 |
+
# 5) Restore random states
|
| 131 |
rng_path = ckpt_path / "rng_states.pt"
|
| 132 |
if rng_path.exists():
|
| 133 |
rng_states = torch.load(rng_path, weights_only=False)
|
|
|
|
| 136 |
torch.cuda.set_rng_state(rng_states["cuda"])
|
| 137 |
|
| 138 |
elapsed = time.time() - start
|
| 139 |
+
print(f" Load complete: step={meta['step']}, {elapsed:.1f}s")
|
| 140 |
|
| 141 |
return {
|
| 142 |
"step": meta["step"],
|
|
|
|
| 146 |
}
|
| 147 |
|
| 148 |
def _find_latest(self) -> Optional[Path]:
|
| 149 |
+
"""Finds the path of the most recent checkpoint."""
|
| 150 |
ckpts = sorted(self.checkpoint_dir.glob("step_*"))
|
| 151 |
return ckpts[-1] if ckpts else None
|
| 152 |
|
| 153 |
def _cleanup_old_checkpoints(self):
|
| 154 |
+
"""Removes old checkpoints (rolling)."""
|
| 155 |
ckpts = sorted(self.checkpoint_dir.glob("step_*"))
|
| 156 |
while len(ckpts) > self.max_checkpoints:
|
| 157 |
old = ckpts.pop(0)
|
| 158 |
+
print(f" 🗑️ Removing old checkpoint: {old.name}")
|
| 159 |
shutil.rmtree(old)
|
llm_lab/training/metrics.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
from typing import Dict, Optional
|
| 4 |
|
|
@@ -8,16 +8,16 @@ from llm_lab.config import TrainConfig
|
|
| 8 |
|
| 9 |
|
| 10 |
class MetricsTracker:
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
- train/loss:
|
| 15 |
-
- train/lr:
|
| 16 |
-
- train/grad_norm:
|
| 17 |
-
- train/tokens_per_sec:
|
| 18 |
-
- train/gpu_mem_gb: GPU
|
| 19 |
-
- val/loss:
|
| 20 |
-
- val/perplexity:
|
| 21 |
"""
|
| 22 |
|
| 23 |
def __init__(self, config: TrainConfig):
|
|
@@ -33,13 +33,13 @@ class MetricsTracker:
|
|
| 33 |
"val_ppl": [],
|
| 34 |
}
|
| 35 |
|
| 36 |
-
# wandb
|
| 37 |
self.wandb_run = None
|
| 38 |
if config.use_wandb:
|
| 39 |
self._init_wandb()
|
| 40 |
|
| 41 |
def _init_wandb(self, resume_id: Optional[str] = None):
|
| 42 |
-
"""wandb
|
| 43 |
try:
|
| 44 |
import wandb
|
| 45 |
|
|
@@ -51,16 +51,16 @@ class MetricsTracker:
|
|
| 51 |
resume="allow",
|
| 52 |
config=self.config.__dict__,
|
| 53 |
)
|
| 54 |
-
print(f"[wandb]
|
| 55 |
except ImportError:
|
| 56 |
-
print("[wandb]
|
| 57 |
self.config.use_wandb = False
|
| 58 |
except Exception as e:
|
| 59 |
-
print(f"[wandb]
|
| 60 |
self.config.use_wandb = False
|
| 61 |
|
| 62 |
def resume_wandb(self, run_id: str):
|
| 63 |
-
"""
|
| 64 |
if self.config.use_wandb:
|
| 65 |
self._init_wandb(resume_id=run_id)
|
| 66 |
|
|
@@ -73,7 +73,7 @@ class MetricsTracker:
|
|
| 73 |
tokens_per_sec: float,
|
| 74 |
gpu_mem_gb: float,
|
| 75 |
):
|
| 76 |
-
"""
|
| 77 |
self.history["step"].append(step)
|
| 78 |
self.history["train_loss"].append(loss)
|
| 79 |
self.history["learning_rate"].append(lr)
|
|
@@ -93,7 +93,7 @@ class MetricsTracker:
|
|
| 93 |
}, step=step)
|
| 94 |
|
| 95 |
def log_eval(self, step: int, val_loss: float, val_ppl: float):
|
| 96 |
-
"""
|
| 97 |
self.history["val_loss"].append(val_loss)
|
| 98 |
self.history["val_ppl"].append(val_ppl)
|
| 99 |
|
|
|
|
| 1 |
+
"""Training metrics tracking and logging."""
|
| 2 |
|
| 3 |
from typing import Dict, Optional
|
| 4 |
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class MetricsTracker:
|
| 11 |
+
"""Tracks and logs training metrics.
|
| 12 |
+
|
| 13 |
+
Tracked items:
|
| 14 |
+
- train/loss: training loss (Cross-Entropy)
|
| 15 |
+
- train/lr: current learning rate
|
| 16 |
+
- train/grad_norm: gradient L2 norm
|
| 17 |
+
- train/tokens_per_sec: throughput
|
| 18 |
+
- train/gpu_mem_gb: GPU memory usage
|
| 19 |
+
- val/loss: validation loss
|
| 20 |
+
- val/perplexity: validation perplexity (= exp(loss))
|
| 21 |
"""
|
| 22 |
|
| 23 |
def __init__(self, config: TrainConfig):
|
|
|
|
| 33 |
"val_ppl": [],
|
| 34 |
}
|
| 35 |
|
| 36 |
+
# wandb initialization
|
| 37 |
self.wandb_run = None
|
| 38 |
if config.use_wandb:
|
| 39 |
self._init_wandb()
|
| 40 |
|
| 41 |
def _init_wandb(self, resume_id: Optional[str] = None):
|
| 42 |
+
"""Initializes wandb (supports continuous logging across sessions)."""
|
| 43 |
try:
|
| 44 |
import wandb
|
| 45 |
|
|
|
|
| 51 |
resume="allow",
|
| 52 |
config=self.config.__dict__,
|
| 53 |
)
|
| 54 |
+
print(f"[wandb] Initialized: {self.wandb_run.url}")
|
| 55 |
except ImportError:
|
| 56 |
+
print("[wandb] Not installed. Using console logging only.")
|
| 57 |
self.config.use_wandb = False
|
| 58 |
except Exception as e:
|
| 59 |
+
print(f"[wandb] Initialization failed: {e}. Using console logging only.")
|
| 60 |
self.config.use_wandb = False
|
| 61 |
|
| 62 |
def resume_wandb(self, run_id: str):
|
| 63 |
+
"""Resumes logging from a previous wandb run."""
|
| 64 |
if self.config.use_wandb:
|
| 65 |
self._init_wandb(resume_id=run_id)
|
| 66 |
|
|
|
|
| 73 |
tokens_per_sec: float,
|
| 74 |
gpu_mem_gb: float,
|
| 75 |
):
|
| 76 |
+
"""Records training step metrics."""
|
| 77 |
self.history["step"].append(step)
|
| 78 |
self.history["train_loss"].append(loss)
|
| 79 |
self.history["learning_rate"].append(lr)
|
|
|
|
| 93 |
}, step=step)
|
| 94 |
|
| 95 |
def log_eval(self, step: int, val_loss: float, val_ppl: float):
|
| 96 |
+
"""Records validation metrics."""
|
| 97 |
self.history["val_loss"].append(val_loss)
|
| 98 |
self.history["val_ppl"].append(val_ppl)
|
| 99 |
|
llm_lab/training/optimizer.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""AdamW
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
|
@@ -7,19 +7,19 @@ from llm_lab.config import TrainConfig
|
|
| 7 |
|
| 8 |
|
| 9 |
def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW:
|
| 10 |
-
"""
|
| 11 |
|
| 12 |
-
Weight Decay
|
| 13 |
-
-
|
| 14 |
-
-
|
| 15 |
|
| 16 |
-
|
| 17 |
-
- Weight Decay
|
| 18 |
-
-
|
| 19 |
-
-
|
| 20 |
-
- 1D
|
| 21 |
"""
|
| 22 |
-
#
|
| 23 |
decay_params = []
|
| 24 |
no_decay_params = []
|
| 25 |
|
|
@@ -27,7 +27,7 @@ def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW
|
|
| 27 |
if not param.requires_grad:
|
| 28 |
continue
|
| 29 |
|
| 30 |
-
# 1D
|
| 31 |
if param.dim() <= 1 or "embedding" in name:
|
| 32 |
no_decay_params.append(param)
|
| 33 |
else:
|
|
@@ -40,15 +40,15 @@ def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW
|
|
| 40 |
|
| 41 |
n_decay = sum(p.numel() for p in decay_params)
|
| 42 |
n_no_decay = sum(p.numel() for p in no_decay_params)
|
| 43 |
-
print(f"[Optimizer] Decay
|
| 44 |
-
print(f"[Optimizer] No-decay
|
| 45 |
|
| 46 |
optimizer = torch.optim.AdamW(
|
| 47 |
param_groups,
|
| 48 |
lr=config.learning_rate,
|
| 49 |
betas=(config.beta1, config.beta2),
|
| 50 |
eps=config.adam_eps,
|
| 51 |
-
fused=torch.cuda.is_available(), # CUDA fused AdamW (
|
| 52 |
)
|
| 53 |
|
| 54 |
return optimizer
|
|
|
|
| 1 |
+
"""AdamW optimizer creation with Weight Decay separation."""
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW:
|
| 10 |
+
"""Creates an AdamW optimizer.
|
| 11 |
|
| 12 |
+
Weight Decay separation rules:
|
| 13 |
+
- Apply decay: Linear weights (attention proj, FFN, etc.)
|
| 14 |
+
- No decay: Embeddings, LayerNorm/RMSNorm, Bias
|
| 15 |
|
| 16 |
+
Why separate?
|
| 17 |
+
- Weight Decay penalizes large weights to prevent overfitting
|
| 18 |
+
- However, applying it to Norm scale parameters interferes with normalization
|
| 19 |
+
- Applying it to Embeddings causes rare token representations to shrink toward 0
|
| 20 |
+
- It is convention to exclude 1D parameters (bias, norm weight) from decay
|
| 21 |
"""
|
| 22 |
+
# Separate parameters into decay / no-decay groups
|
| 23 |
decay_params = []
|
| 24 |
no_decay_params = []
|
| 25 |
|
|
|
|
| 27 |
if not param.requires_grad:
|
| 28 |
continue
|
| 29 |
|
| 30 |
+
# 1D tensors (bias, norm weight) or embedding → no decay
|
| 31 |
if param.dim() <= 1 or "embedding" in name:
|
| 32 |
no_decay_params.append(param)
|
| 33 |
else:
|
|
|
|
| 40 |
|
| 41 |
n_decay = sum(p.numel() for p in decay_params)
|
| 42 |
n_no_decay = sum(p.numel() for p in no_decay_params)
|
| 43 |
+
print(f"[Optimizer] Decay parameters: {n_decay:,} ({n_decay/1e6:.1f}M)")
|
| 44 |
+
print(f"[Optimizer] No-decay parameters: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)")
|
| 45 |
|
| 46 |
optimizer = torch.optim.AdamW(
|
| 47 |
param_groups,
|
| 48 |
lr=config.learning_rate,
|
| 49 |
betas=(config.beta1, config.beta2),
|
| 50 |
eps=config.adam_eps,
|
| 51 |
+
fused=torch.cuda.is_available(), # CUDA fused AdamW (faster)
|
| 52 |
)
|
| 53 |
|
| 54 |
return optimizer
|
llm_lab/training/runner.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Optional
|
|
@@ -20,49 +20,49 @@ def start_training(
|
|
| 20 |
seq_len: int = 2048,
|
| 21 |
auto_config: bool = True,
|
| 22 |
) -> Trainer:
|
| 23 |
-
"""
|
| 24 |
|
| 25 |
-
|
| 26 |
```python
|
| 27 |
from model import LLMModel, ModelConfig
|
| 28 |
from data_pipeline import setup_data_pipeline, DataConfig
|
| 29 |
from trainer import start_training, TrainConfig
|
| 30 |
|
| 31 |
-
# 1.
|
| 32 |
model_config = ModelConfig.base_1b()
|
| 33 |
model = LLMModel(model_config)
|
| 34 |
|
| 35 |
-
# 2.
|
| 36 |
tok, train_dl, val_dl = setup_data_pipeline("pretrained")
|
| 37 |
|
| 38 |
-
# 3.
|
| 39 |
trainer = start_training(model, train_dl, val_dl)
|
| 40 |
```
|
| 41 |
"""
|
| 42 |
config = config or TrainConfig()
|
| 43 |
|
| 44 |
-
# GPU
|
| 45 |
if auto_config:
|
| 46 |
config = auto_configure(config)
|
| 47 |
|
| 48 |
-
# Google Drive
|
| 49 |
if "/content/drive" in config.checkpoint_dir:
|
| 50 |
drive_path = Path("/content/drive/MyDrive")
|
| 51 |
if not drive_path.exists():
|
| 52 |
-
print("\n⚠️ Google Drive
|
| 53 |
-
print("
|
| 54 |
-
print("
|
| 55 |
config.checkpoint_dir = "./checkpoints"
|
| 56 |
|
| 57 |
-
#
|
| 58 |
torch.manual_seed(config.seed)
|
| 59 |
if torch.cuda.is_available():
|
| 60 |
torch.cuda.manual_seed(config.seed)
|
| 61 |
|
| 62 |
-
# Trainer
|
| 63 |
trainer = Trainer(model, train_dataloader, val_dataloader, config, seq_len)
|
| 64 |
|
| 65 |
-
#
|
| 66 |
trainer.train()
|
| 67 |
|
| 68 |
return trainer
|
|
|
|
| 1 |
+
"""Training execution helper (Quick Start)."""
|
| 2 |
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Optional
|
|
|
|
| 20 |
seq_len: int = 2048,
|
| 21 |
auto_config: bool = True,
|
| 22 |
) -> Trainer:
|
| 23 |
+
"""Starts training (one-line execution).
|
| 24 |
|
| 25 |
+
Usage (Colab):
|
| 26 |
```python
|
| 27 |
from model import LLMModel, ModelConfig
|
| 28 |
from data_pipeline import setup_data_pipeline, DataConfig
|
| 29 |
from trainer import start_training, TrainConfig
|
| 30 |
|
| 31 |
+
# 1. Create model
|
| 32 |
model_config = ModelConfig.base_1b()
|
| 33 |
model = LLMModel(model_config)
|
| 34 |
|
| 35 |
+
# 2. Data pipeline
|
| 36 |
tok, train_dl, val_dl = setup_data_pipeline("pretrained")
|
| 37 |
|
| 38 |
+
# 3. Start training (automatic checkpoint restoration)
|
| 39 |
trainer = start_training(model, train_dl, val_dl)
|
| 40 |
```
|
| 41 |
"""
|
| 42 |
config = config or TrainConfig()
|
| 43 |
|
| 44 |
+
# Auto-detect GPU and adjust configuration
|
| 45 |
if auto_config:
|
| 46 |
config = auto_configure(config)
|
| 47 |
|
| 48 |
+
# Check Google Drive mount (Colab)
|
| 49 |
if "/content/drive" in config.checkpoint_dir:
|
| 50 |
drive_path = Path("/content/drive/MyDrive")
|
| 51 |
if not drive_path.exists():
|
| 52 |
+
print("\n⚠️ Google Drive is not mounted!")
|
| 53 |
+
print(" Run in Colab: from google.colab import drive; drive.mount('/content/drive')")
|
| 54 |
+
print(" Switching to local path.")
|
| 55 |
config.checkpoint_dir = "./checkpoints"
|
| 56 |
|
| 57 |
+
# Set reproducibility seed
|
| 58 |
torch.manual_seed(config.seed)
|
| 59 |
if torch.cuda.is_available():
|
| 60 |
torch.cuda.manual_seed(config.seed)
|
| 61 |
|
| 62 |
+
# Create Trainer (includes automatic checkpoint restoration)
|
| 63 |
trainer = Trainer(model, train_dataloader, val_dataloader, config, seq_len)
|
| 64 |
|
| 65 |
+
# Run training
|
| 66 |
trainer.train()
|
| 67 |
|
| 68 |
return trainer
|
llm_lab/training/scheduler.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Cosine Annealing with Linear Warmup
|
| 2 |
|
| 3 |
import math
|
| 4 |
|
|
@@ -10,22 +10,22 @@ from llm_lab.config import TrainConfig
|
|
| 10 |
class CosineWarmupScheduler:
|
| 11 |
"""Cosine Annealing with Linear Warmup.
|
| 12 |
|
| 13 |
-
LR
|
| 14 |
┌─── peak_lr ───────╲
|
| 15 |
│ ╲ cosine decay
|
| 16 |
│ warmup (linear) ╲
|
| 17 |
│/ ╲_______ min_lr
|
| 18 |
└──────────────────────────────────→ steps
|
| 19 |
|
| 20 |
-
|
| 21 |
-
- Step decay:
|
| 22 |
-
- Linear decay:
|
| 23 |
-
- Cosine:
|
| 24 |
-
- GPT-3, LLaMA,
|
| 25 |
|
| 26 |
-
|
| 27 |
-
PyTorch
|
| 28 |
-
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(self, config: TrainConfig):
|
|
@@ -35,33 +35,33 @@ class CosineWarmupScheduler:
|
|
| 35 |
self.total_steps = config.total_steps
|
| 36 |
|
| 37 |
def get_lr(self, step: int) -> float:
|
| 38 |
-
"""
|
| 39 |
|
| 40 |
Args:
|
| 41 |
-
step:
|
| 42 |
|
| 43 |
Returns:
|
| 44 |
-
|
| 45 |
"""
|
| 46 |
# Phase 1: Linear Warmup
|
| 47 |
if step < self.warmup_steps:
|
| 48 |
-
#
|
| 49 |
return self.peak_lr * (step / self.warmup_steps)
|
| 50 |
|
| 51 |
# Phase 2: Cosine Decay
|
| 52 |
-
#
|
| 53 |
decay_steps = self.total_steps - self.warmup_steps
|
| 54 |
progress = (step - self.warmup_steps) / max(decay_steps, 1)
|
| 55 |
-
progress = min(progress, 1.0) #
|
| 56 |
|
| 57 |
-
# Cosine
|
| 58 |
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
|
| 59 |
lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
|
| 60 |
|
| 61 |
return lr
|
| 62 |
|
| 63 |
def set_lr(self, optimizer: torch.optim.Optimizer, step: int):
|
| 64 |
-
"""
|
| 65 |
lr = self.get_lr(step)
|
| 66 |
for param_group in optimizer.param_groups:
|
| 67 |
param_group["lr"] = lr
|
|
|
|
| 1 |
+
"""Cosine Annealing with Linear Warmup scheduler."""
|
| 2 |
|
| 3 |
import math
|
| 4 |
|
|
|
|
| 10 |
class CosineWarmupScheduler:
|
| 11 |
"""Cosine Annealing with Linear Warmup.
|
| 12 |
|
| 13 |
+
LR curve:
|
| 14 |
┌─── peak_lr ───────╲
|
| 15 |
│ ╲ cosine decay
|
| 16 |
│ warmup (linear) ╲
|
| 17 |
│/ ╲_______ min_lr
|
| 18 |
└──────────────────────────────────→ steps
|
| 19 |
|
| 20 |
+
Why Cosine Decay?
|
| 21 |
+
- Step decay: sudden LR drop → unstable loss
|
| 22 |
+
- Linear decay: LR decreases too quickly in the later stages
|
| 23 |
+
- Cosine: smooth decay, maintains appropriate LR even in the late training phase
|
| 24 |
+
- Used by most LLMs including GPT-3, LLaMA, and Chinchilla
|
| 25 |
|
| 26 |
+
Implementation note:
|
| 27 |
+
PyTorch has built-in schedulers (e.g., CosineAnnealingLR), but
|
| 28 |
+
a custom implementation is more flexible for warmup + min_lr + checkpoint restoration.
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(self, config: TrainConfig):
|
|
|
|
| 35 |
self.total_steps = config.total_steps
|
| 36 |
|
| 37 |
def get_lr(self, step: int) -> float:
|
| 38 |
+
"""Returns the learning rate for the current step.
|
| 39 |
|
| 40 |
Args:
|
| 41 |
+
step: Current optimizer step (0-indexed)
|
| 42 |
|
| 43 |
Returns:
|
| 44 |
+
Learning rate (float)
|
| 45 |
"""
|
| 46 |
# Phase 1: Linear Warmup
|
| 47 |
if step < self.warmup_steps:
|
| 48 |
+
# Linear increase from 0 to peak_lr
|
| 49 |
return self.peak_lr * (step / self.warmup_steps)
|
| 50 |
|
| 51 |
# Phase 2: Cosine Decay
|
| 52 |
+
# Progress ratio after warmup (0.0 → 1.0)
|
| 53 |
decay_steps = self.total_steps - self.warmup_steps
|
| 54 |
progress = (step - self.warmup_steps) / max(decay_steps, 1)
|
| 55 |
+
progress = min(progress, 1.0) # safety clamp
|
| 56 |
|
| 57 |
+
# Cosine formula: min_lr + 0.5 × (peak - min) × (1 + cos(π × progress))
|
| 58 |
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
|
| 59 |
lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
|
| 60 |
|
| 61 |
return lr
|
| 62 |
|
| 63 |
def set_lr(self, optimizer: torch.optim.Optimizer, step: int):
|
| 64 |
+
"""Updates the learning rate of the optimizer."""
|
| 65 |
lr = self.get_lr(step)
|
| 66 |
for param_group in optimizer.param_groups:
|
| 67 |
param_group["lr"] = lr
|
llm_lab/training/trainer.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""LLM
|
| 2 |
|
| 3 |
import math
|
| 4 |
import time
|
|
@@ -16,9 +16,9 @@ from .optimizer import create_optimizer
|
|
| 16 |
|
| 17 |
|
| 18 |
class Trainer:
|
| 19 |
-
"""LLM
|
| 20 |
|
| 21 |
-
|
| 22 |
```
|
| 23 |
for step in range(total_steps):
|
| 24 |
# ── Gradient Accumulation Loop ──
|
|
@@ -27,22 +27,22 @@ class Trainer:
|
|
| 27 |
with autocast(bf16):
|
| 28 |
logits, loss = model(input_ids, targets)
|
| 29 |
scaled_loss = loss / accumulation_steps
|
| 30 |
-
scaled_loss.backward() #
|
| 31 |
|
| 32 |
-
# ── Optimizer Step (accumulation
|
| 33 |
clip_grad_norm(model, max_norm=1.0)
|
| 34 |
optimizer.step()
|
| 35 |
optimizer.zero_grad()
|
| 36 |
scheduler.set_lr(optimizer, step)
|
| 37 |
```
|
| 38 |
|
| 39 |
-
Gradient Accumulation
|
| 40 |
-
-
|
| 41 |
-
-
|
| 42 |
-
-
|
| 43 |
-
-
|
| 44 |
-
-
|
| 45 |
-
|
| 46 |
"""
|
| 47 |
|
| 48 |
def __init__(
|
|
@@ -56,52 +56,52 @@ class Trainer:
|
|
| 56 |
self.config = config
|
| 57 |
self.seq_len = seq_len
|
| 58 |
|
| 59 |
-
# ──
|
| 60 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 61 |
-
print(f"[Trainer]
|
| 62 |
if torch.cuda.is_available():
|
| 63 |
print(f"[Trainer] GPU: {torch.cuda.get_device_name()}")
|
| 64 |
-
print(f"[Trainer] GPU
|
| 65 |
|
| 66 |
-
# ──
|
| 67 |
self.model = model.to(self.device)
|
| 68 |
-
# torch.compile: PyTorch 2.0+
|
| 69 |
if torch.cuda.is_available() and hasattr(torch, "compile"):
|
| 70 |
-
print("[Trainer] torch.compile
|
| 71 |
self.model = torch.compile(self.model)
|
| 72 |
|
| 73 |
-
# ──
|
| 74 |
self.train_dataloader = train_dataloader
|
| 75 |
self.val_dataloader = val_dataloader
|
| 76 |
self.train_iter = iter(train_dataloader)
|
| 77 |
|
| 78 |
-
# ──
|
| 79 |
self.optimizer = create_optimizer(self.model, config)
|
| 80 |
|
| 81 |
-
# ──
|
| 82 |
self.scheduler = CosineWarmupScheduler(config)
|
| 83 |
|
| 84 |
-
# ──
|
| 85 |
self.ckpt_manager = CheckpointManager(config)
|
| 86 |
|
| 87 |
-
# ──
|
| 88 |
self.metrics = MetricsTracker(config)
|
| 89 |
|
| 90 |
-
# ──
|
| 91 |
self.global_step = 0
|
| 92 |
self.best_val_loss = float("inf")
|
| 93 |
self.tokens_seen = 0
|
| 94 |
|
| 95 |
# ── Mixed Precision ──
|
| 96 |
-
# bf16
|
| 97 |
self.use_amp = config.dtype != "float32"
|
| 98 |
self.amp_dtype = config.torch_dtype
|
| 99 |
|
| 100 |
-
# ──
|
| 101 |
self._try_resume()
|
| 102 |
|
| 103 |
def _try_resume(self):
|
| 104 |
-
"""
|
| 105 |
result = self.ckpt_manager.load_latest(
|
| 106 |
self.model, self.optimizer, self.device
|
| 107 |
)
|
|
@@ -111,20 +111,20 @@ class Trainer:
|
|
| 111 |
self.best_val_loss = result["best_val_loss"]
|
| 112 |
self.metrics.history = result.get("metrics_history", self.metrics.history)
|
| 113 |
|
| 114 |
-
# wandb
|
| 115 |
if result.get("wandb_run_id"):
|
| 116 |
self.metrics.resume_wandb(result["wandb_run_id"])
|
| 117 |
|
| 118 |
self.tokens_seen = self.global_step * self.config.effective_batch_size * self.seq_len
|
| 119 |
-
print(f"[Trainer]
|
| 120 |
f"tokens={self.tokens_seen/1e9:.2f}B, "
|
| 121 |
f"best_val_loss={self.best_val_loss:.4f}")
|
| 122 |
|
| 123 |
def _get_next_batch(self) -> Dict[str, torch.Tensor]:
|
| 124 |
-
"""
|
| 125 |
|
| 126 |
-
Streaming DataLoader
|
| 127 |
-
|
| 128 |
"""
|
| 129 |
try:
|
| 130 |
batch = next(self.train_iter)
|
|
@@ -138,14 +138,14 @@ class Trainer:
|
|
| 138 |
}
|
| 139 |
|
| 140 |
def _train_step(self) -> Tuple[float, float]:
|
| 141 |
-
"""
|
| 142 |
|
| 143 |
Returns:
|
| 144 |
(loss, grad_norm)
|
| 145 |
"""
|
| 146 |
self.model.train()
|
| 147 |
self.optimizer.zero_grad(set_to_none=True)
|
| 148 |
-
# set_to_none=True:
|
| 149 |
|
| 150 |
total_loss = 0.0
|
| 151 |
|
|
@@ -157,16 +157,16 @@ class Trainer:
|
|
| 157 |
with torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
|
| 158 |
logits, loss = self.model(batch["input_ids"], batch["targets"])
|
| 159 |
|
| 160 |
-
# Loss
|
| 161 |
scaled_loss = loss / self.config.gradient_accumulation_steps
|
| 162 |
total_loss += loss.item()
|
| 163 |
|
| 164 |
-
# Backward (
|
| 165 |
scaled_loss.backward()
|
| 166 |
|
| 167 |
# ── Gradient Clipping ──
|
| 168 |
-
#
|
| 169 |
-
# norm
|
| 170 |
grad_norm = torch.nn.utils.clip_grad_norm_(
|
| 171 |
self.model.parameters(),
|
| 172 |
max_norm=self.config.grad_clip,
|
|
@@ -175,7 +175,7 @@ class Trainer:
|
|
| 175 |
# ── Optimizer Step ──
|
| 176 |
self.optimizer.step()
|
| 177 |
|
| 178 |
-
# ── LR
|
| 179 |
self.scheduler.set_lr(self.optimizer, self.global_step)
|
| 180 |
|
| 181 |
avg_loss = total_loss / self.config.gradient_accumulation_steps
|
|
@@ -183,13 +183,13 @@ class Trainer:
|
|
| 183 |
|
| 184 |
@torch.no_grad()
|
| 185 |
def _evaluate(self) -> Tuple[float, float]:
|
| 186 |
-
"""
|
| 187 |
|
| 188 |
Perplexity = exp(loss)
|
| 189 |
-
-
|
| 190 |
-
- PPL 100 →
|
| 191 |
-
- PPL 20 →
|
| 192 |
-
- PPL 10 →
|
| 193 |
"""
|
| 194 |
if self.val_dataloader is None:
|
| 195 |
return float("inf"), float("inf")
|
|
@@ -212,36 +212,37 @@ class Trainer:
|
|
| 212 |
num_batches += 1
|
| 213 |
|
| 214 |
avg_loss = total_loss / max(num_batches, 1)
|
| 215 |
-
perplexity = math.exp(min(avg_loss, 20)) # overflow
|
| 216 |
|
| 217 |
return avg_loss, perplexity
|
| 218 |
|
| 219 |
def train(self):
|
| 220 |
-
"""
|
| 221 |
|
| 222 |
-
|
| 223 |
-
|
|
|
|
| 224 |
"""
|
| 225 |
config = self.config
|
| 226 |
|
| 227 |
print("\n" + "=" * 70)
|
| 228 |
-
print("🚀
|
| 229 |
print("=" * 70)
|
| 230 |
-
print(f"
|
| 231 |
-
print(f"
|
| 232 |
print(f" Effective batch size: {config.effective_batch_size}")
|
| 233 |
-
print(f"
|
| 234 |
-
print(f"
|
| 235 |
print(f" Mixed Precision: {config.dtype}")
|
| 236 |
print(f" Gradient Accumulation: {config.gradient_accumulation_steps}")
|
| 237 |
-
print(f"
|
| 238 |
print("=" * 70 + "\n")
|
| 239 |
|
| 240 |
step_start_time = time.time()
|
| 241 |
tokens_at_log_start = self.tokens_seen
|
| 242 |
|
| 243 |
# ════════════════════════════════════════════
|
| 244 |
-
#
|
| 245 |
# ════════════════════════════════════════════
|
| 246 |
|
| 247 |
while self.global_step < config.total_steps:
|
|
@@ -257,21 +258,21 @@ class Trainer:
|
|
| 257 |
tokens_delta = self.tokens_seen - tokens_at_log_start
|
| 258 |
tokens_per_sec = tokens_delta / max(elapsed, 1e-6)
|
| 259 |
|
| 260 |
-
# GPU
|
| 261 |
gpu_mem_gb = 0.0
|
| 262 |
if torch.cuda.is_available():
|
| 263 |
gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9
|
| 264 |
|
| 265 |
-
#
|
| 266 |
current_lr = self.scheduler.get_lr(self.global_step)
|
| 267 |
|
| 268 |
-
#
|
| 269 |
remaining_steps = config.total_steps - self.global_step
|
| 270 |
steps_per_sec = config.log_interval / max(elapsed, 1e-6)
|
| 271 |
eta_seconds = remaining_steps / max(steps_per_sec, 1e-6)
|
| 272 |
eta_hours = eta_seconds / 3600
|
| 273 |
|
| 274 |
-
#
|
| 275 |
print(
|
| 276 |
f" Step {self.global_step:>6d}/{config.total_steps} │ "
|
| 277 |
f"Loss {loss:.4f} │ "
|
|
@@ -283,7 +284,7 @@ class Trainer:
|
|
| 283 |
f"Tokens {self.tokens_seen/1e9:.2f}B"
|
| 284 |
)
|
| 285 |
|
| 286 |
-
# wandb
|
| 287 |
self.metrics.log_train_step(
|
| 288 |
step=self.global_step,
|
| 289 |
loss=loss,
|
|
@@ -324,19 +325,19 @@ class Trainer:
|
|
| 324 |
)
|
| 325 |
|
| 326 |
# ════════════════════════════════════════════
|
| 327 |
-
#
|
| 328 |
# ════════════════════════════════════════════
|
| 329 |
|
| 330 |
print("\n" + "=" * 70)
|
| 331 |
-
print("🎉
|
| 332 |
print("=" * 70)
|
| 333 |
-
print(f"
|
| 334 |
-
print(f"
|
| 335 |
-
print(f"
|
| 336 |
-
print(f"
|
| 337 |
print("=" * 70)
|
| 338 |
|
| 339 |
-
#
|
| 340 |
self.ckpt_manager.save(
|
| 341 |
model=self.model,
|
| 342 |
optimizer=self.optimizer,
|
|
|
|
| 1 |
+
"""LLM pretraining trainer."""
|
| 2 |
|
| 3 |
import math
|
| 4 |
import time
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class Trainer:
|
| 19 |
+
"""LLM pretraining trainer.
|
| 20 |
|
| 21 |
+
Core structure of the training loop:
|
| 22 |
```
|
| 23 |
for step in range(total_steps):
|
| 24 |
# ── Gradient Accumulation Loop ──
|
|
|
|
| 27 |
with autocast(bf16):
|
| 28 |
logits, loss = model(input_ids, targets)
|
| 29 |
scaled_loss = loss / accumulation_steps
|
| 30 |
+
scaled_loss.backward() # accumulate gradients
|
| 31 |
|
| 32 |
+
# ── Optimizer Step (after accumulation completes) ──
|
| 33 |
clip_grad_norm(model, max_norm=1.0)
|
| 34 |
optimizer.step()
|
| 35 |
optimizer.zero_grad()
|
| 36 |
scheduler.set_lr(optimizer, step)
|
| 37 |
```
|
| 38 |
|
| 39 |
+
What is Gradient Accumulation?
|
| 40 |
+
- Used when a large batch cannot fit into GPU memory all at once
|
| 41 |
+
- Run forward/backward multiple times with small micro_batches → accumulate gradients
|
| 42 |
+
- Perform optimizer step once after accumulation is complete
|
| 43 |
+
- Effectively equivalent to training with a large effective batch size
|
| 44 |
+
- Reason for dividing loss by accumulation_steps:
|
| 45 |
+
to compute the mean of gradients (average, not sum)
|
| 46 |
"""
|
| 47 |
|
| 48 |
def __init__(
|
|
|
|
| 56 |
self.config = config
|
| 57 |
self.seq_len = seq_len
|
| 58 |
|
| 59 |
+
# ── Device setup ──
|
| 60 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 61 |
+
print(f"[Trainer] Device: {self.device}")
|
| 62 |
if torch.cuda.is_available():
|
| 63 |
print(f"[Trainer] GPU: {torch.cuda.get_device_name()}")
|
| 64 |
+
print(f"[Trainer] GPU Memory: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
|
| 65 |
|
| 66 |
+
# ── Model ──
|
| 67 |
self.model = model.to(self.device)
|
| 68 |
+
# torch.compile: PyTorch 2.0+ graph optimization (10-30% speed improvement)
|
| 69 |
if torch.cuda.is_available() and hasattr(torch, "compile"):
|
| 70 |
+
print("[Trainer] Applying torch.compile...")
|
| 71 |
self.model = torch.compile(self.model)
|
| 72 |
|
| 73 |
+
# ── Data ──
|
| 74 |
self.train_dataloader = train_dataloader
|
| 75 |
self.val_dataloader = val_dataloader
|
| 76 |
self.train_iter = iter(train_dataloader)
|
| 77 |
|
| 78 |
+
# ── Optimizer ──
|
| 79 |
self.optimizer = create_optimizer(self.model, config)
|
| 80 |
|
| 81 |
+
# ── Scheduler ──
|
| 82 |
self.scheduler = CosineWarmupScheduler(config)
|
| 83 |
|
| 84 |
+
# ── Checkpoint ──
|
| 85 |
self.ckpt_manager = CheckpointManager(config)
|
| 86 |
|
| 87 |
+
# ── Metrics ──
|
| 88 |
self.metrics = MetricsTracker(config)
|
| 89 |
|
| 90 |
+
# ── Training state ──
|
| 91 |
self.global_step = 0
|
| 92 |
self.best_val_loss = float("inf")
|
| 93 |
self.tokens_seen = 0
|
| 94 |
|
| 95 |
# ── Mixed Precision ──
|
| 96 |
+
# bf16 does not require GradScaler (only needed for fp16)
|
| 97 |
self.use_amp = config.dtype != "float32"
|
| 98 |
self.amp_dtype = config.torch_dtype
|
| 99 |
|
| 100 |
+
# ── Attempt automatic resume ──
|
| 101 |
self._try_resume()
|
| 102 |
|
| 103 |
def _try_resume(self):
|
| 104 |
+
"""Automatically restores from a previous checkpoint if one exists."""
|
| 105 |
result = self.ckpt_manager.load_latest(
|
| 106 |
self.model, self.optimizer, self.device
|
| 107 |
)
|
|
|
|
| 111 |
self.best_val_loss = result["best_val_loss"]
|
| 112 |
self.metrics.history = result.get("metrics_history", self.metrics.history)
|
| 113 |
|
| 114 |
+
# Resume wandb logging continuously
|
| 115 |
if result.get("wandb_run_id"):
|
| 116 |
self.metrics.resume_wandb(result["wandb_run_id"])
|
| 117 |
|
| 118 |
self.tokens_seen = self.global_step * self.config.effective_batch_size * self.seq_len
|
| 119 |
+
print(f"[Trainer] Resuming training: step={self.global_step}, "
|
| 120 |
f"tokens={self.tokens_seen/1e9:.2f}B, "
|
| 121 |
f"best_val_loss={self.best_val_loss:.4f}")
|
| 122 |
|
| 123 |
def _get_next_batch(self) -> Dict[str, torch.Tensor]:
|
| 124 |
+
"""Fetches the next training batch.
|
| 125 |
|
| 126 |
+
Since a Streaming DataLoader has no epoch concept,
|
| 127 |
+
a new iterator is created when StopIteration is raised.
|
| 128 |
"""
|
| 129 |
try:
|
| 130 |
batch = next(self.train_iter)
|
|
|
|
| 138 |
}
|
| 139 |
|
| 140 |
def _train_step(self) -> Tuple[float, float]:
|
| 141 |
+
"""Performs one optimizer step.
|
| 142 |
|
| 143 |
Returns:
|
| 144 |
(loss, grad_norm)
|
| 145 |
"""
|
| 146 |
self.model.train()
|
| 147 |
self.optimizer.zero_grad(set_to_none=True)
|
| 148 |
+
# set_to_none=True: sets gradients to None → saves memory
|
| 149 |
|
| 150 |
total_loss = 0.0
|
| 151 |
|
|
|
|
| 157 |
with torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
|
| 158 |
logits, loss = self.model(batch["input_ids"], batch["targets"])
|
| 159 |
|
| 160 |
+
# Loss scaling: to compute the mean over the effective batch
|
| 161 |
scaled_loss = loss / self.config.gradient_accumulation_steps
|
| 162 |
total_loss += loss.item()
|
| 163 |
|
| 164 |
+
# Backward (accumulate gradients)
|
| 165 |
scaled_loss.backward()
|
| 166 |
|
| 167 |
# ── Gradient Clipping ──
|
| 168 |
+
# Treat all parameter gradients as a single vector and compute L2 norm
|
| 169 |
+
# If norm exceeds max_norm, scale down proportionally
|
| 170 |
grad_norm = torch.nn.utils.clip_grad_norm_(
|
| 171 |
self.model.parameters(),
|
| 172 |
max_norm=self.config.grad_clip,
|
|
|
|
| 175 |
# ── Optimizer Step ──
|
| 176 |
self.optimizer.step()
|
| 177 |
|
| 178 |
+
# ── LR Update ──
|
| 179 |
self.scheduler.set_lr(self.optimizer, self.global_step)
|
| 180 |
|
| 181 |
avg_loss = total_loss / self.config.gradient_accumulation_steps
|
|
|
|
| 183 |
|
| 184 |
@torch.no_grad()
|
| 185 |
def _evaluate(self) -> Tuple[float, float]:
|
| 186 |
+
"""Measures Loss and Perplexity on the validation data.
|
| 187 |
|
| 188 |
Perplexity = exp(loss)
|
| 189 |
+
- Intuition: "how many candidates does the model choose the next token from on average"
|
| 190 |
+
- PPL 100 → equivalent to uniformly choosing 1 out of 100
|
| 191 |
+
- PPL 20 → 1 out of 20 (fairly good)
|
| 192 |
+
- PPL 10 → predicting with high confidence
|
| 193 |
"""
|
| 194 |
if self.val_dataloader is None:
|
| 195 |
return float("inf"), float("inf")
|
|
|
|
| 212 |
num_batches += 1
|
| 213 |
|
| 214 |
avg_loss = total_loss / max(num_batches, 1)
|
| 215 |
+
perplexity = math.exp(min(avg_loss, 20)) # prevent overflow (exp(20) ≈ 500M)
|
| 216 |
|
| 217 |
return avg_loss, perplexity
|
| 218 |
|
| 219 |
def train(self):
|
| 220 |
+
"""Main training loop.
|
| 221 |
|
| 222 |
+
This method runs the entire training process.
|
| 223 |
+
Even if interrupted by a Colab session expiry,
|
| 224 |
+
training will automatically resume from the last checkpoint.
|
| 225 |
"""
|
| 226 |
config = self.config
|
| 227 |
|
| 228 |
print("\n" + "=" * 70)
|
| 229 |
+
print("🚀 Training started")
|
| 230 |
print("=" * 70)
|
| 231 |
+
print(f" Total steps: {config.total_steps:,}")
|
| 232 |
+
print(f" Start step: {self.global_step}")
|
| 233 |
print(f" Effective batch size: {config.effective_batch_size}")
|
| 234 |
+
print(f" Tokens/step: {config.effective_batch_size * self.seq_len:,}")
|
| 235 |
+
print(f" Total training tokens (estimated): {config.total_steps * config.effective_batch_size * self.seq_len / 1e9:.1f}B")
|
| 236 |
print(f" Mixed Precision: {config.dtype}")
|
| 237 |
print(f" Gradient Accumulation: {config.gradient_accumulation_steps}")
|
| 238 |
+
print(f" Checkpoint: {config.checkpoint_dir}")
|
| 239 |
print("=" * 70 + "\n")
|
| 240 |
|
| 241 |
step_start_time = time.time()
|
| 242 |
tokens_at_log_start = self.tokens_seen
|
| 243 |
|
| 244 |
# ════════════════════════════════════════════
|
| 245 |
+
# Main loop
|
| 246 |
# ════════════════════════════════════════════
|
| 247 |
|
| 248 |
while self.global_step < config.total_steps:
|
|
|
|
| 258 |
tokens_delta = self.tokens_seen - tokens_at_log_start
|
| 259 |
tokens_per_sec = tokens_delta / max(elapsed, 1e-6)
|
| 260 |
|
| 261 |
+
# GPU memory
|
| 262 |
gpu_mem_gb = 0.0
|
| 263 |
if torch.cuda.is_available():
|
| 264 |
gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9
|
| 265 |
|
| 266 |
+
# Current LR
|
| 267 |
current_lr = self.scheduler.get_lr(self.global_step)
|
| 268 |
|
| 269 |
+
# Estimate remaining time
|
| 270 |
remaining_steps = config.total_steps - self.global_step
|
| 271 |
steps_per_sec = config.log_interval / max(elapsed, 1e-6)
|
| 272 |
eta_seconds = remaining_steps / max(steps_per_sec, 1e-6)
|
| 273 |
eta_hours = eta_seconds / 3600
|
| 274 |
|
| 275 |
+
# Console output
|
| 276 |
print(
|
| 277 |
f" Step {self.global_step:>6d}/{config.total_steps} │ "
|
| 278 |
f"Loss {loss:.4f} │ "
|
|
|
|
| 284 |
f"Tokens {self.tokens_seen/1e9:.2f}B"
|
| 285 |
)
|
| 286 |
|
| 287 |
+
# wandb logging
|
| 288 |
self.metrics.log_train_step(
|
| 289 |
step=self.global_step,
|
| 290 |
loss=loss,
|
|
|
|
| 325 |
)
|
| 326 |
|
| 327 |
# ════════════════════════════════════════════
|
| 328 |
+
# Training complete
|
| 329 |
# ════════════════════════════════════════════
|
| 330 |
|
| 331 |
print("\n" + "=" * 70)
|
| 332 |
+
print("🎉 Training complete!")
|
| 333 |
print("=" * 70)
|
| 334 |
+
print(f" Total steps: {self.global_step:,}")
|
| 335 |
+
print(f" Total tokens: {self.tokens_seen/1e9:.2f}B")
|
| 336 |
+
print(f" Best Val Loss: {self.best_val_loss:.4f}")
|
| 337 |
+
print(f" Best Val PPL: {math.exp(min(self.best_val_loss, 20)):.2f}")
|
| 338 |
print("=" * 70)
|
| 339 |
|
| 340 |
+
# Save final checkpoint
|
| 341 |
self.ckpt_manager.save(
|
| 342 |
model=self.model,
|
| 343 |
optimizer=self.optimizer,
|
llm_lab/utils/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
from .device import get_device, detect_gpu_info, auto_configure
|
| 3 |
from .seed import set_seed
|
| 4 |
|
|
|
|
| 1 |
+
"""Common utilities — device detection, seed configuration."""
|
| 2 |
from .device import get_device, detect_gpu_info, auto_configure
|
| 3 |
from .seed import set_seed
|
| 4 |
|
llm_lab/utils/device.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
from typing import TYPE_CHECKING
|
|
@@ -10,15 +10,15 @@ if TYPE_CHECKING:
|
|
| 10 |
|
| 11 |
|
| 12 |
def get_device() -> torch.device:
|
| 13 |
-
"""
|
| 14 |
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
|
| 16 |
|
| 17 |
def detect_gpu_info() -> dict:
|
| 18 |
-
"""GPU
|
| 19 |
|
| 20 |
Returns:
|
| 21 |
-
{"name": str, "memory_gb": float}
|
| 22 |
"""
|
| 23 |
if not torch.cuda.is_available():
|
| 24 |
return {}
|
|
@@ -29,16 +29,16 @@ def detect_gpu_info() -> dict:
|
|
| 29 |
|
| 30 |
|
| 31 |
def auto_configure(config: "TrainConfig") -> "TrainConfig":
|
| 32 |
-
"""
|
| 33 |
|
| 34 |
-
Colab Pro+
|
| 35 |
-
T4
|
| 36 |
|
| 37 |
Returns:
|
| 38 |
-
|
| 39 |
"""
|
| 40 |
if not torch.cuda.is_available():
|
| 41 |
-
print("⚠️ GPU
|
| 42 |
config.dtype = "float32"
|
| 43 |
config.micro_batch_size = 1
|
| 44 |
config.gradient_accumulation_steps = 4
|
|
@@ -47,37 +47,37 @@ def auto_configure(config: "TrainConfig") -> "TrainConfig":
|
|
| 47 |
gpu_name = torch.cuda.get_device_name().lower()
|
| 48 |
gpu_mem = torch.cuda.get_device_properties(0).total_mem / 1e9
|
| 49 |
|
| 50 |
-
print(f"\n🔍 GPU
|
| 51 |
|
| 52 |
if "a100" in gpu_name:
|
| 53 |
-
# A100 40GB:
|
| 54 |
-
print(" → A100
|
| 55 |
config.dtype = "bfloat16"
|
| 56 |
config.micro_batch_size = 4
|
| 57 |
|
| 58 |
elif "v100" in gpu_name:
|
| 59 |
-
# V100 16GB: bf16
|
| 60 |
-
print(" → V100
|
| 61 |
config.dtype = "float16"
|
| 62 |
config.micro_batch_size = 2
|
| 63 |
-
config.gradient_accumulation_steps = 64 # effective batch
|
| 64 |
|
| 65 |
elif "t4" in gpu_name:
|
| 66 |
-
# T4 16GB: bf16
|
| 67 |
-
print(" → T4
|
| 68 |
config.dtype = "float16"
|
| 69 |
config.micro_batch_size = 1
|
| 70 |
config.gradient_accumulation_steps = 128
|
| 71 |
|
| 72 |
elif "l4" in gpu_name:
|
| 73 |
-
# L4 24GB: bf16
|
| 74 |
-
print(" → L4
|
| 75 |
config.dtype = "bfloat16"
|
| 76 |
config.micro_batch_size = 2
|
| 77 |
config.gradient_accumulation_steps = 64
|
| 78 |
|
| 79 |
else:
|
| 80 |
-
print(f" →
|
| 81 |
if gpu_mem >= 30:
|
| 82 |
config.micro_batch_size = 4
|
| 83 |
elif gpu_mem >= 16:
|
|
|
|
| 1 |
+
"""Device detection and auto-configuration utilities."""
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
from typing import TYPE_CHECKING
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def get_device() -> torch.device:
|
| 13 |
+
"""Returns the available device (cuda or cpu)."""
|
| 14 |
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
|
| 16 |
|
| 17 |
def detect_gpu_info() -> dict:
|
| 18 |
+
"""Returns GPU name and memory information.
|
| 19 |
|
| 20 |
Returns:
|
| 21 |
+
{"name": str, "memory_gb": float} or an empty dict if no GPU is available
|
| 22 |
"""
|
| 23 |
if not torch.cuda.is_available():
|
| 24 |
return {}
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def auto_configure(config: "TrainConfig") -> "TrainConfig":
|
| 32 |
+
"""Automatically adjusts configuration based on GPU type.
|
| 33 |
|
| 34 |
+
In Colab Pro+, an A100 is not always assigned.
|
| 35 |
+
If a T4 or V100 is assigned, configuration is automatically adjusted.
|
| 36 |
|
| 37 |
Returns:
|
| 38 |
+
Adjusted TrainConfig
|
| 39 |
"""
|
| 40 |
if not torch.cuda.is_available():
|
| 41 |
+
print("⚠️ No GPU found! Running in CPU mode (very slow)")
|
| 42 |
config.dtype = "float32"
|
| 43 |
config.micro_batch_size = 1
|
| 44 |
config.gradient_accumulation_steps = 4
|
|
|
|
| 47 |
gpu_name = torch.cuda.get_device_name().lower()
|
| 48 |
gpu_mem = torch.cuda.get_device_properties(0).total_mem / 1e9
|
| 49 |
|
| 50 |
+
print(f"\n🔍 GPU detected: {torch.cuda.get_device_name()} ({gpu_mem:.1f} GB)")
|
| 51 |
|
| 52 |
if "a100" in gpu_name:
|
| 53 |
+
# A100 40GB: use default settings (optimal)
|
| 54 |
+
print(" → A100 detected: using default settings (bf16, batch=4)")
|
| 55 |
config.dtype = "bfloat16"
|
| 56 |
config.micro_batch_size = 4
|
| 57 |
|
| 58 |
elif "v100" in gpu_name:
|
| 59 |
+
# V100 16GB: bf16 not supported, reduce batch size
|
| 60 |
+
print(" → V100 detected: fp16 mode, reduced batch size")
|
| 61 |
config.dtype = "float16"
|
| 62 |
config.micro_batch_size = 2
|
| 63 |
+
config.gradient_accumulation_steps = 64 # maintain effective batch size
|
| 64 |
|
| 65 |
elif "t4" in gpu_name:
|
| 66 |
+
# T4 16GB: bf16 not supported, smaller batch
|
| 67 |
+
print(" → T4 detected: fp16 mode, minimum batch size")
|
| 68 |
config.dtype = "float16"
|
| 69 |
config.micro_batch_size = 1
|
| 70 |
config.gradient_accumulation_steps = 128
|
| 71 |
|
| 72 |
elif "l4" in gpu_name:
|
| 73 |
+
# L4 24GB: bf16 supported
|
| 74 |
+
print(" → L4 detected: bf16 mode, adjusted batch size")
|
| 75 |
config.dtype = "bfloat16"
|
| 76 |
config.micro_batch_size = 2
|
| 77 |
config.gradient_accumulation_steps = 64
|
| 78 |
|
| 79 |
else:
|
| 80 |
+
print(f" → Unknown GPU. Adjusting settings based on memory")
|
| 81 |
if gpu_mem >= 30:
|
| 82 |
config.micro_batch_size = 4
|
| 83 |
elif gpu_mem >= 16:
|
llm_lab/utils/seed.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
-
"""
|
| 2 |
import torch
|
| 3 |
|
| 4 |
|
| 5 |
def set_seed(seed: int = 42):
|
| 6 |
-
"""
|
| 7 |
torch.manual_seed(seed)
|
| 8 |
if torch.cuda.is_available():
|
| 9 |
torch.cuda.manual_seed(seed)
|
|
|
|
| 1 |
+
"""Seed utility for reproducibility."""
|
| 2 |
import torch
|
| 3 |
|
| 4 |
|
| 5 |
def set_seed(seed: int = 42):
|
| 6 |
+
"""Set seed for reproducibility."""
|
| 7 |
torch.manual_seed(seed)
|
| 8 |
if torch.cuda.is_available():
|
| 9 |
torch.cuda.manual_seed(seed)
|