Vjeong Claude Sonnet 4.6 committed on
Commit 858e8b2 · 1 Parent(s): 2367a60

docs: translate all Korean comments and docstrings to English


Convert all Korean-language comments, docstrings, and inline annotations
to English across 38 source files and CLAUDE.md. Update code conventions
to require English for all code, comments, docstrings, and git commit messages.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

CLAUDE.md CHANGED
@@ -1,78 +1,78 @@
  # LLM-1B-Lab
 
- 1.1B parameter LLaMA-style Decoder-Only Transformer 교육용 구현.
- 딥러닝 초보자가 처음부터 끝까지 LLM을 학습하고 평가하는 과정을 경험할 수 있도록 설계됨.
 
- ## 프로젝트 구조
 
  ```
  LLM_Foundation_Model/
  ├── CLAUDE.md
  ├── requirements.txt
- ├── llm_lab/ # Python 패키지 (핵심 코드)
  │ ├── __init__.py
- │ ├── config/ # 설정 데이터클래스
- │ │ ├── model_config.py # ModelConfig (debug_10m / small_100m / base_1b 프리셋)
- │ │ ├── data_config.py # DataConfig (데이터셋, 토크나이저, 배치 설정)
- │ │ ├── train_config.py # TrainConfig (LR, 스케줄러, 체크포인트, wandb)
- │ │ └── eval_config.py # EvalConfig (평가 파라미터)
- │ ├── model/ # 모델 아키텍처
  │ │ ├── norm.py # RMSNorm
  │ │ ├── rope.py # RotaryPositionalEmbedding (RoPE)
  │ │ ├── attention.py # GroupedQueryAttention (GQA)
  │ │ ├── feedforward.py # SwiGLUFeedForward
  │ │ ├── transformer_block.py # TransformerBlock (Pre-LN)
- │ │ ├── llm_model.py # LLMModel (전체 모델 + generate)
  │ │ └── utils.py # count_parameters_detailed, estimate_memory_gb
- │ ├── data/ # 데이터 파이프라인
  │ │ ├── tokenizer.py # Tokenizer (SentencePiece / BPE / HuggingFace)
  │ │ ├── dataset.py # PackedStreamingDataset, ValidationDataset, _collate_fn
  │ │ ├── pipeline.py # create_train_dataloader, setup_data_pipeline
  │ │ └── diagnostics.py # DataPipelineDiagnostics
- │ ├── training/ # 학습 루프
  │ │ ├── scheduler.py # CosineWarmupScheduler
- │ │ ├── checkpoint.py # CheckpointManager (Google Drive 지원)
- │ │ ├── metrics.py # MetricsTracker (wandb 연동)
- │ │ ├── optimizer.py # create_optimizer (weight decay 분리)
  │ │ ├── trainer.py # Trainer (gradient accumulation, mixed precision)
- │ │ └── runner.py # start_training (한 줄 실행 헬퍼)
- │ ├── evaluation/ # 평가 & 분석
- │ │ ├── perplexity.py # PerplexityEvaluator (위치별 Loss 포함)
- │ │ ├── generation.py # GenerationEvaluator (다양한 프롬프트)
  │ │ ├── scaling.py # ScalingAnalyzer (Chinchilla Scaling Law)
- │ │ ├── dynamics.py # TrainingDynamicsAnalyzer (Loss/LR/Grad 시각화)
- │ │ ├── attention_viz.py # AttentionVisualizer (헤드별 heatmap)
- │ │ ├── full_evaluator.py # FullEvaluator (종합 평가 + 리포트)
- │ │ ├── checklist.py # InsightChecklist (학습 인사이트 체크리스트)
- │ │ └── runner.py # run_evaluation (한 줄 실행 헬퍼)
- │ └── utils/ # 공통 유틸리티
  │ ├── device.py # auto_configure, get_device, detect_gpu_info
  │ └── seed.py # set_seed
- ├── notebooks/ # Jupyter 노트북 (설정 + 실행)
  │ ├── 01_data_pipeline.ipynb
  │ ├── 02_model.ipynb
  │ ├── 03_training.ipynb
  │ └── 04_evaluation.ipynb
- └── _archive/ # 원본 단일파일 백업
  ├── llm-1b-model.py
  ├── llm-1b-data-pipeline.py
  ├── llm-1b-trainer.py
  └── llm-1b-evaluation.py
  ```
 
- ## 기술 스택
 
- - **모델**: LLaMA-style Decoder-Only Transformer (RMSNorm, RoPE, GQA, SwiGLU, Weight Tying)
- - **학습**: Gradient Accumulation, Mixed Precision (bf16/fp16), Cosine LR + Warmup, Activation Checkpointing
- - **데이터**: HuggingFace Streaming (FineWeb-Edu), BPE 토크나이저, 시퀀스 패킹
- - **체크포인트**: Google Drive 자동 저장/복원 (Colab Pro+ 환경)
- - **평가**: Perplexity, 텍스트 생성, Scaling Law, Attention 시각화
- - **타겟 환경**: Google Colab Pro+ (A100 40GB)
 
- ## 의존성 그래프 (순환 없음)
 
  ```
- config (의존성 없음)
 
  utils → config
 
@@ -85,13 +85,13 @@ training → config, utils
  evaluation → config
  ```
 
- ## 모델 프리셋
 
- | 프리셋 | 파라미터 | dim | layers | heads | kv_heads | 용도 |
- |--------|---------|-----|--------|-------|----------|------|
- | `debug_10m` | ~10M | 256 | 6 | 8 | 4 | 빠른 검증/디버그 |
- | `small_100m` | ~100M | 768 | 12 | 12 | 4 | 중간 실험 |
- | `base_1b` | ~1.1B | 2048 | 22 | 32 | 8 | 본격 학습 |
 
  ## Quick Start
 
@@ -102,30 +102,30 @@ from llm_lab.data import setup_data_pipeline
  from llm_lab.training import start_training
  from llm_lab.evaluation import run_evaluation
 
- # 1. 모델
  model = LLMModel(ModelConfig.base_1b())
 
- # 2. 데이터
  tok, train_dl, val_dl = setup_data_pipeline("pretrained")
 
- # 3. 학습
  trainer = start_training(model, train_dl, val_dl)
 
- # 4. 평가
  report = run_evaluation(model, tok, val_dl,
  metrics_history=trainer.metrics.history)
  ```
 
- ## 코드 컨벤션
 
- - **언어**: 코드는 영어, 주석/독스트링은 한국어 (교육적 설명 포함)
- - **타입 힌트**: 모든 함수에 typing 어노테이션 사용
- - **import 순서**: stdlib → torch → llm_lab (절대 경로) → 로컬 (상대 경로)
- - **데이터클래스**: 모든 설정은 `@dataclass`로 정의, 기본값 포함
- - **에러 처리**: 외부 의존성(matplotlib, wandb 등)은 `try/except ImportError`로 선택적 사용
 
- ## 주의사항
 
- - `torch`는 로컬 환경에 설치되어 있지 않을 수 있음 (Colab Pro+에서 실행 전제)
  - `pip install torch datasets tokenizers sentencepiece transformers wandb matplotlib numpy`
- - 원본 4개 파일(`_archive/`)과 모듈화된 `llm_lab/` 패키지의 로직은 동일 (import 경로만 변경)
 
  # LLM-1B-Lab
 
+ Educational implementation of a 1.1B parameter LLaMA-style Decoder-Only Transformer.
+ Designed so beginners in deep learning can experience training and evaluating an LLM from scratch.
 
+ ## Project Structure
 
  ```
  LLM_Foundation_Model/
  ├── CLAUDE.md
  ├── requirements.txt
+ ├── llm_lab/ # Python package (core code)
  │ ├── __init__.py
+ │ ├── config/ # Configuration dataclasses
+ │ │ ├── model_config.py # ModelConfig (debug_10m / small_100m / base_1b presets)
+ │ │ ├── data_config.py # DataConfig (dataset, tokenizer, batch settings)
+ │ │ ├── train_config.py # TrainConfig (LR, scheduler, checkpoint, wandb)
+ │ │ └── eval_config.py # EvalConfig (evaluation parameters)
+ │ ├── model/ # Model architecture
  │ │ ├── norm.py # RMSNorm
  │ │ ├── rope.py # RotaryPositionalEmbedding (RoPE)
  │ │ ├── attention.py # GroupedQueryAttention (GQA)
  │ │ ├── feedforward.py # SwiGLUFeedForward
  │ │ ├── transformer_block.py # TransformerBlock (Pre-LN)
+ │ │ ├── llm_model.py # LLMModel (full model + generate)
  │ │ └── utils.py # count_parameters_detailed, estimate_memory_gb
+ │ ├── data/ # Data pipeline
  │ │ ├── tokenizer.py # Tokenizer (SentencePiece / BPE / HuggingFace)
  │ │ ├── dataset.py # PackedStreamingDataset, ValidationDataset, _collate_fn
  │ │ ├── pipeline.py # create_train_dataloader, setup_data_pipeline
  │ │ └── diagnostics.py # DataPipelineDiagnostics
+ │ ├── training/ # Training loop
  │ │ ├── scheduler.py # CosineWarmupScheduler
+ │ │ ├── checkpoint.py # CheckpointManager (Google Drive support)
+ │ │ ├── metrics.py # MetricsTracker (wandb integration)
+ │ │ ├── optimizer.py # create_optimizer (weight decay separation)
  │ │ ├── trainer.py # Trainer (gradient accumulation, mixed precision)
+ │ │ └── runner.py # start_training (one-line helper)
+ │ ├── evaluation/ # Evaluation & analysis
+ │ │ ├── perplexity.py # PerplexityEvaluator (including per-position loss)
+ │ │ ├── generation.py # GenerationEvaluator (various prompts)
  │ │ ├── scaling.py # ScalingAnalyzer (Chinchilla Scaling Law)
+ │ │ ├── dynamics.py # TrainingDynamicsAnalyzer (Loss/LR/Grad visualization)
+ │ │ ├── attention_viz.py # AttentionVisualizer (per-head heatmap)
+ │ │ ├── full_evaluator.py # FullEvaluator (comprehensive evaluation + report)
+ │ │ ├── checklist.py # InsightChecklist (training insight checklist)
+ │ │ └── runner.py # run_evaluation (one-line helper)
+ │ └── utils/ # Common utilities
  │ ├── device.py # auto_configure, get_device, detect_gpu_info
  │ └── seed.py # set_seed
+ ├── notebooks/ # Jupyter notebooks (configuration + execution)
  │ ├── 01_data_pipeline.ipynb
  │ ├── 02_model.ipynb
  │ ├── 03_training.ipynb
  │ └── 04_evaluation.ipynb
+ └── _archive/ # Original single-file backups
  ├── llm-1b-model.py
  ├── llm-1b-data-pipeline.py
  ├── llm-1b-trainer.py
  └── llm-1b-evaluation.py
  ```
 
+ ## Tech Stack
 
+ - **Model**: LLaMA-style Decoder-Only Transformer (RMSNorm, RoPE, GQA, SwiGLU, Weight Tying)
+ - **Training**: Gradient Accumulation, Mixed Precision (bf16/fp16), Cosine LR + Warmup, Activation Checkpointing
+ - **Data**: HuggingFace Streaming (FineWeb-Edu), BPE tokenizer, sequence packing
+ - **Checkpoint**: Auto save/restore to Google Drive (Colab Pro+ environment)
+ - **Evaluation**: Perplexity, text generation, Scaling Law, Attention visualization
+ - **Target Environment**: Google Colab Pro+ (A100 40GB)
 
+ ## Dependency Graph (no cycles)
 
  ```
+ config (no dependencies)
 
  utils → config
 
  evaluation → config
  ```
 
+ ## Model Presets
 
+ | Preset | Parameters | dim | layers | heads | kv_heads | Purpose |
+ |--------|-----------|-----|--------|-------|----------|---------|
+ | `debug_10m` | ~10M | 256 | 6 | 8 | 4 | Fast validation/debug |
+ | `small_100m` | ~100M | 768 | 12 | 12 | 4 | Intermediate experiments |
+ | `base_1b` | ~1.1B | 2048 | 22 | 32 | 8 | Full-scale training |
 
  ## Quick Start
 
  from llm_lab.training import start_training
  from llm_lab.evaluation import run_evaluation
 
+ # 1. Model
  model = LLMModel(ModelConfig.base_1b())
 
+ # 2. Data
  tok, train_dl, val_dl = setup_data_pipeline("pretrained")
 
+ # 3. Training
  trainer = start_training(model, train_dl, val_dl)
 
+ # 4. Evaluation
  report = run_evaluation(model, tok, val_dl,
  metrics_history=trainer.metrics.history)
  ```
 
+ ## Code Conventions
 
+ - **Language**: All code, comments, docstrings, and git commit messages must be written in English
+ - **Type hints**: Use typing annotations on all functions
+ - **Import order**: stdlib → torch → llm_lab (absolute) → local (relative)
+ - **Dataclasses**: All configurations defined as `@dataclass` with defaults
+ - **Error handling**: Optional dependencies (matplotlib, wandb, etc.) wrapped in `try/except ImportError`
 
+ ## Notes
 
+ - `torch` may not be installed locally (assumes Colab Pro+ runtime)
  - `pip install torch datasets tokenizers sentencepiece transformers wandb matplotlib numpy`
+ - The logic in the original 4 files (`_archive/`) and the modularized `llm_lab/` package is identical (only import paths changed)
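The parameter counts in the preset table can be sanity-checked with a back-of-the-envelope calculation. A minimal sketch, assuming the LLaMA-style layout the README names (weight-tied input/output embedding, biasless GQA projections, three SwiGLU matrices); `approx_params` is an illustrative helper, not the package's `count_parameters_detailed`:

```python
def approx_params(vocab, dim, n_layers, n_heads, n_kv_heads, inter):
    """Rough LLaMA-style parameter count (weight-tied embedding, biasless linears)."""
    head_dim = dim // n_heads
    embed = vocab * dim                          # shared input/output embedding
    attn = 2 * dim * dim                         # Wq and Wo
    attn += 2 * dim * (n_kv_heads * head_dim)    # Wk and Wv (GQA: fewer KV heads)
    ffn = 3 * dim * inter                        # SwiGLU: gate, up, down projections
    norms = 2 * dim * n_layers + dim             # RMSNorm weights (2 per block + final)
    return embed + n_layers * (attn + ffn) + norms

# Preset values copied from the table above
for name, args in {
    "debug_10m":  (32_000, 256, 6, 8, 4, 704),
    "small_100m": (32_000, 768, 12, 12, 4, 2048),
    "base_1b":    (32_000, 2048, 22, 32, 8, 5632),
}.items():
    print(f"{name}: ~{approx_params(*args) / 1e6:.0f}M")
```

For `base_1b` this lands at roughly 1.06B, consistent with the ~1.1B figure in the table; the smaller presets come out in the same ballpark as their labels.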
llm_lab/__init__.py CHANGED
@@ -1,16 +1,16 @@
  """
  LLM-1B-Lab: 1B Parameter LLaMA-style Transformer (from scratch)
  ================================================================
- 딥러닝 초보자를 위한 학습용 구현.
- 각 컴포넌트에 상세 주석을 달아 "왜 이렇게 하는지"를 설명합니다.
 
- 모듈 구조:
- llm_lab.config — 모든 설정 (ModelConfig, DataConfig, TrainConfig, EvalConfig)
- llm_lab.model — 모델 아키텍처 (RMSNorm, RoPE, GQA, SwiGLU, Transformer)
- llm_lab.data — 데이터 파이프라인 (토크나이저, 스트리밍, 패킹)
- llm_lab.training — 학습 루프 (Trainer, 스케줄러, 체크포인트)
- llm_lab.evaluation — 평가 (Perplexity, 생성, Scaling Law, Attention)
- llm_lab.utils — 공통 유틸리티 (디바이스 감지, 시드)
 
  Quick Start:
  from llm_lab.config import ModelConfig, DataConfig, TrainConfig
 
  """
  LLM-1B-Lab: 1B Parameter LLaMA-style Transformer (from scratch)
  ================================================================
+ An educational implementation for deep learning beginners.
+ Each component includes detailed comments explaining "why" things are done this way.
 
+ Module structure:
+ llm_lab.config — All configurations (ModelConfig, DataConfig, TrainConfig, EvalConfig)
+ llm_lab.model — Model architecture (RMSNorm, RoPE, GQA, SwiGLU, Transformer)
+ llm_lab.data — Data pipeline (tokenizer, streaming, packing)
+ llm_lab.training — Training loop (Trainer, scheduler, checkpoint)
+ llm_lab.evaluation — Evaluation (Perplexity, generation, Scaling Law, Attention)
+ llm_lab.utils — Common utilities (device detection, seed)
 
  Quick Start:
  from llm_lab.config import ModelConfig, DataConfig, TrainConfig
llm_lab/config/__init__.py CHANGED
@@ -1,4 +1,4 @@
- """설정(Config) 모듈 — 모든 하이퍼파라미터를 한 곳에서 관리합니다."""
  from .model_config import ModelConfig
  from .data_config import DataConfig
  from .train_config import TrainConfig
 
+ """Config module—manages all hyperparameters in one place."""
  from .model_config import ModelConfig
  from .data_config import DataConfig
  from .train_config import TrainConfig
llm_lab/config/data_config.py CHANGED
@@ -4,38 +4,38 @@ from typing import Optional
 
  @dataclass
  class DataConfig:
- """데이터 파이프라인 설정.
 
- Colab Pro+ 환경 제약을 고려한 기본값:
- - Streaming 모드로 디스크 사용 최소화
- - 시퀀스 패킹으로 패딩 없이 GPU 활용률 극대화
- - 전처리를 on-the-fly로 수행하여 메모리 절약
  """
- # ── 데이터셋 ──
  dataset_name: str = "HuggingFaceFW/fineweb-edu"
- dataset_subset: str = "sample-10BT" # 10B 토큰 샘플
  dataset_split: str = "train"
- text_column: str = "text" # 텍스트가 담긴 컬럼명
 
- # ── 토크나이저 ──
- tokenizer_type: str = "sentencepiece" # "sentencepiece" 또는 "hf"
- # 사전 학습된 토크나이저 경로 (없으면 새로 학습)
  tokenizer_path: Optional[str] = None
  vocab_size: int = 32_000
 
- # ── 시퀀스 ──
  max_seq_len: int = 2048
- # 문서 구분 토큰 사용 여부 (패킹 시 문서 경계 표시)
  use_eos_separator: bool = True
 
- # ── 배치 ──
- batch_size: int = 4 # micro batch (GPU당)
- num_workers: int = 2 # DataLoader 워커 수
- prefetch_factor: int = 4 # 미리 준비할 배치 수
 
- # ── 토크나이저 학습 설정 (새로 학습 시) ──
- tokenizer_train_samples: int = 50_000 # 학습에 사용할 문서 수
  tokenizer_save_dir: str = "./tokenizer"
 
- # ── 검증 데이터 ──
- val_ratio: float = 0.001 # 전체의 0.1%를 검증용으로
 
 
  @dataclass
  class DataConfig:
+ """Data pipeline configuration.
 
+ Default values optimized for Colab Pro+ environment constraints:
+ - Streaming mode to minimize disk usage
+ - Sequence packing to maximize GPU utilization without padding
+ - On-the-fly preprocessing to save memory
  """
+ # ── Dataset ──
  dataset_name: str = "HuggingFaceFW/fineweb-edu"
+ dataset_subset: str = "sample-10BT" # 10B token sample
  dataset_split: str = "train"
+ text_column: str = "text" # column name containing text
 
+ # ── Tokenizer ──
+ tokenizer_type: str = "sentencepiece" # "sentencepiece" or "hf"
+ # path to a pretrained tokenizer (trains a new one if not provided)
  tokenizer_path: Optional[str] = None
  vocab_size: int = 32_000
 
+ # ── Sequence ──
  max_seq_len: int = 2048
+ # whether to use a document separator token (marks document boundaries during packing)
  use_eos_separator: bool = True
 
+ # ── Batch ──
+ batch_size: int = 4 # micro batch (per GPU)
+ num_workers: int = 2 # number of DataLoader workers
+ prefetch_factor: int = 4 # number of batches to prefetch
 
+ # ── Tokenizer training settings (when training from scratch) ──
+ tokenizer_train_samples: int = 50_000 # number of documents to use for training
  tokenizer_save_dir: str = "./tokenizer"
 
+ # ── Validation data ──
+ val_ratio: float = 0.001 # use 0.1% of total data for validation
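The `use_eos_separator` and `max_seq_len` settings above drive the sequence packing that `PackedStreamingDataset` performs: documents are concatenated with an EOS separator and emitted as fixed-length chunks with targets shifted by one. A minimal pure-Python sketch of that behaviour; `eos_id=2` and the toy token lists are made-up illustration values:

```python
from typing import Iterator, List, Tuple

def pack(docs: List[List[int]], max_seq_len: int, eos_id: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Concatenate tokenized docs with EOS separators and yield
    (input_ids, targets) pairs of exactly max_seq_len tokens each."""
    buffer: List[int] = []
    for token_ids in docs:
        buffer.extend(token_ids + [eos_id])        # EOS marks the document boundary
        while len(buffer) >= max_seq_len + 1:      # +1 so targets can be shifted
            chunk, buffer = buffer[: max_seq_len + 1], buffer[max_seq_len + 1 :]
            yield chunk[:-1], chunk[1:]            # next-token prediction pair

docs = [[10, 11, 12], [20, 21], [30, 31, 32, 33]]
for inp, tgt in pack(docs, max_seq_len=4, eos_id=2):
    print(inp, tgt)
```

Every emitted pair has exactly `max_seq_len` tokens with no padding, which is the point of packing; leftover tokens shorter than a full chunk stay in the buffer.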
llm_lab/config/eval_config.py CHANGED
@@ -3,18 +3,18 @@ from dataclasses import dataclass
 
  @dataclass
  class EvalConfig:
- """평가 파라미터."""
  # ── Perplexity ──
  eval_batch_size: int = 4
- max_eval_batches: int = 100 # 최대 평가 배치 수
 
- # ── 생성 ──
  max_new_tokens: int = 200
  temperature: float = 0.8
  top_k: int = 50
  top_p: float = 0.9
- num_samples: int = 3 # 프롬프트당 생성 횟수
 
- # ── 출력 ──
  save_dir: str = "./eval_results"
  plot_dpi: int = 150
 
 
  @dataclass
  class EvalConfig:
+ """Evaluation parameters."""
  # ── Perplexity ──
  eval_batch_size: int = 4
+ max_eval_batches: int = 100 # maximum number of evaluation batches
 
+ # ── Generation ──
  max_new_tokens: int = 200
  temperature: float = 0.8
  top_k: int = 50
  top_p: float = 0.9
+ num_samples: int = 3 # number of generations per prompt
 
+ # ── Output ──
  save_dir: str = "./eval_results"
  plot_dpi: int = 150
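For readers unfamiliar with the generation parameters above, here is a minimal pure-Python sketch of how `temperature`, `top_k`, and `top_p` are typically combined when sampling the next token. This is an illustration of the standard technique, not the project's actual `generate` in `llm_model.py`; `sample_next` and the toy logits are hypothetical:

```python
import math
import random

def sample_next(logits, temperature=0.8, top_k=50, top_p=0.9, rng=None):
    """Pick a next-token id: temperature-scale the logits, keep the top_k,
    then keep the smallest prefix whose probability mass reaches top_p."""
    rng = rng or random.Random(0)
    scaled = sorted(((l / temperature, i) for i, l in enumerate(logits)), reverse=True)
    scaled = scaled[:top_k]                           # top-k filter
    m = scaled[0][0]                                  # subtract max for stability
    probs = [math.exp(l - m) for l, _ in scaled]
    total = sum(probs)
    kept, mass = [], 0.0
    for (_, i), p in zip(scaled, (p / total for p in probs)):
        kept.append((i, p))                           # top-p (nucleus) filter
        mass += p
        if mass >= top_p:
            break
    total = sum(p for _, p in kept)                   # renormalize and sample
    r = rng.random() * total
    for i, p in kept:
        r -= p
        if r <= 0:
            return i
    return kept[-1][0]

print(sample_next([2.0, 1.0, 0.5, -1.0], top_k=3, top_p=0.9))
```

Lower `temperature` sharpens the distribution, `top_k` caps how many candidates survive, and `top_p` trims the low-probability tail adaptively.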
llm_lab/config/model_config.py CHANGED
@@ -3,37 +3,37 @@ from dataclasses import dataclass
 
  @dataclass
  class ModelConfig:
- """모델 하이퍼파라미터를 하나의 데이터클래스로 관리합니다.
 
- 규모별 프리셋:
- - debug: ~10M (파이프라인 검증용)
- - small: ~100M (중간 검증용)
- - base: ~1.1B (최종 목표)
  """
  vocab_size: int = 32_000
- hidden_dim: int = 2048 # d_model: 모델의 기본 차원
- num_layers: int = 22 # Transformer 블록 수
- num_heads: int = 16 # Query 헤드 수
- num_kv_heads: int = 4 # Key/Value 헤드 수 (GQA)
- intermediate_dim: int = 5632 # FFN 중간 차원 (≈ 2.75 × hidden_dim)
- max_seq_len: int = 2048 # 최대 시퀀스 길이
- dropout: float = 0.0 # Pretraining에서는 보통 0 사용
- rope_theta: float = 10000.0 # RoPE 주파수 베이스
  norm_eps: float = 1e-6 # RMSNorm epsilon
 
  @property
  def head_dim(self) -> int:
- """각 어텐션 헤드의 차원."""
  return self.hidden_dim // self.num_heads
 
  @property
  def num_kv_groups(self) -> int:
- """GQA에서 하나의 KV 헤드가 담당하는 Q 헤드 수."""
  return self.num_heads // self.num_kv_heads
 
  @classmethod
  def debug_10m(cls) -> "ModelConfig":
- """~10M 파라미터 - 빠른 디버깅용."""
  return cls(
  hidden_dim=256, num_layers=6, num_heads=8,
  num_kv_heads=4, intermediate_dim=704, max_seq_len=512,
@@ -41,7 +41,7 @@ class ModelConfig:
 
  @classmethod
  def small_100m(cls) -> "ModelConfig":
- """~100M 파라미터 - 중간 검증용."""
  return cls(
  hidden_dim=768, num_layers=12, num_heads=12,
  num_kv_heads=4, intermediate_dim=2048, max_seq_len=1024,
@@ -49,5 +49,5 @@ class ModelConfig:
 
  @classmethod
  def base_1b(cls) -> "ModelConfig":
- """~1.1B 파라미터 - 최종 학습 목표."""
- return cls() # 기본값이 1B 설정
 
 
  @dataclass
  class ModelConfig:
+ """Manages model hyperparameters as a single dataclass.
 
+ Scale-specific presets:
+ - debug: ~10M (for pipeline validation)
+ - small: ~100M (for intermediate validation)
+ - base: ~1.1B (final target)
  """
  vocab_size: int = 32_000
+ hidden_dim: int = 2048 # d_model: base dimension of the model
+ num_layers: int = 22 # number of Transformer blocks
+ num_heads: int = 16 # number of Query heads
+ num_kv_heads: int = 4 # number of Key/Value heads (GQA)
+ intermediate_dim: int = 5632 # FFN intermediate dimension (≈ 2.75 × hidden_dim)
+ max_seq_len: int = 2048 # maximum sequence length
+ dropout: float = 0.0 # typically 0 during pretraining
+ rope_theta: float = 10000.0 # RoPE frequency base
  norm_eps: float = 1e-6 # RMSNorm epsilon
 
  @property
  def head_dim(self) -> int:
+ """Dimension of each attention head."""
  return self.hidden_dim // self.num_heads
 
  @property
  def num_kv_groups(self) -> int:
+ """Number of Q heads per KV head in GQA."""
  return self.num_heads // self.num_kv_heads
 
  @classmethod
  def debug_10m(cls) -> "ModelConfig":
+ """~10M parameters - for fast debugging."""
  return cls(
  hidden_dim=256, num_layers=6, num_heads=8,
  num_kv_heads=4, intermediate_dim=704, max_seq_len=512,
 
  @classmethod
  def small_100m(cls) -> "ModelConfig":
+ """~100M parameters - for intermediate validation."""
  return cls(
  hidden_dim=768, num_layers=12, num_heads=12,
  num_kv_heads=4, intermediate_dim=2048, max_seq_len=1024,
 
  @classmethod
  def base_1b(cls) -> "ModelConfig":
+ """~1.1B parameters - final training target."""
+ return cls() # defaults are the 1B configuration
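The `head_dim` / `num_kv_groups` relationship documented in these properties is easiest to see with concrete numbers: under GQA, each KV head serves `num_heads // num_kv_heads` consecutive query heads. A small sketch of that index mapping (standing in for the group-wise K/V repetition that `attention.py` presumably performs; `kv_head_for_query_head` is an illustrative helper, not part of the package):

```python
def kv_head_for_query_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Which KV head serves a given query head under grouped-query attention."""
    group_size = num_heads // num_kv_heads   # this is num_kv_groups in ModelConfig
    return q_head // group_size

# With 32 query heads sharing 8 KV heads (the base_1b preset in the README table),
# query heads come in groups of 4 per KV head.
mapping = [kv_head_for_query_head(q, 32, 8) for q in range(32)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1, ..., 7, 7, 7, 7]
```

This is why GQA shrinks the KV cache: only `num_kv_heads` K/V projections are stored, and each is reused by a whole group of query heads.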
llm_lab/config/train_config.py CHANGED
@@ -6,97 +6,97 @@ import torch
 
  @dataclass
  class TrainConfig:
- """학습 하이퍼파라미터 + 인프라 설정.
 
- Colab Pro+ (A100 40GB) 기준 최적화된 기본값.
- 모든 값에 '왜 이 값인지' 설명을 포함합니다.
  """
 
- # ── 최적화 ──
  learning_rate: float = 3e-4
- """Peak LR. 1B 모델 기준 3e-4가 표준.
- GPT-3 논문에서 모델 크기별 최적 LR을 제시:
  125M → 6e-4, 350M → 3e-4, 1.3B → 2e-4
- 우리 모델(1.1B)은 3e-4에서 시작, 불안정하면 2e-4로 하향."""
 
  min_learning_rate: float = 3e-5
- """Cosine decay 최저점. 보통 peak의 10%.
- 너무 낮으면 학습 후반 정체, 너무 높으면 수렴 불안정."""
 
  weight_decay: float = 0.1
- """AdamW의 L2 정규화. 0.1이 LLM 표준.
- Embedding과 Bias에는 적용하지 않음 (관례)."""
 
  beta1: float = 0.9
  beta2: float = 0.95
- """Adam 모멘텀 계수. β2=0.95가 LLM 학습에서 β2=0.999보다 안정적.
- 큰 배치 + 긴 학습에서 β2가 너무 크면 적응 속도가 느림."""
 
  adam_eps: float = 1e-8
  grad_clip: float = 1.0
- """Gradient Clipping: gradient norm이 1.0을 초과하면 스케일링.
- 학습 초반이나 노이즈 데이터에서 발생하는 gradient spike 방지."""
 
- # ── 스케줄링 ──
  warmup_steps: int = 2000
- """Warmup: 처음 2000 스텝 동안 LR을 0에서 peak까지 선형 증가.
- 왜 필요한가?
- - 초기 가중치가 랜덤 → 큰 LR은 불안정한 업데이트 유발
- - 작은 LR로 시작해 모델이 '방향'을 잡게 한 후 본격 학습
- - 2000은 전체 학습의 ~10%로 적당 (경험적 규칙)."""
 
  total_steps: int = 20_000
- """총 학습 스텝 수.
- 10B tokens / (128 batch × 2048 seq_len) ≈ 38,000 이지만,
- gradient accumulation 포함 시 effective step 기준 ~20,000."""
 
- # ── 배치 ──
  micro_batch_size: int = 4
- """GPU에 한 번에 올리는 배치 크기.
- A100 40GB에서 1B 모델 bf16 기준 4가 안전한 상한."""
 
  gradient_accumulation_steps: int = 32
- """Gradient 누적 횟수. Effective batch = 4 × 32 = 128.
- 왜 큰 배치가 좋은가?
- - gradient 추정이 안정적 (노이즈 감소)
- - LLM 학습은 보통 effective batch 128~512
- - 메모리 부족 시 이 값을 늘리고 micro_batch를 줄임."""
 
  # ── Mixed Precision ──
  dtype: str = "bfloat16"
- """bfloat16: A100에서 지원, fp16보다 수치 안정성 우수.
- exponent 비트가 fp32와 동일 → overflow/underflow 위험 적음.
- T4/V100 폴백 시 'float16'으로 변경."""
 
- # ── 체크포인트 ──
  checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints"
- """Google Drive 경로. Colab 세션 만료 시에도 보존됨."""
 
  checkpoint_interval: int = 500
- """500 스텝마다 체크포인트 저장.
- A100 기준 ~30분 간격. 너무 잦으면 I/O 오버헤드,
- 너무 드물면 세션 만료 시 손실 큼."""
 
  max_checkpoints: int = 3
- """롤링 보관 수. 오래된 것부터 삭제.
- 체크포인트 1개 ≈ 8-10GB → 3개면 ~30GB."""
 
- # ── 로깅 ──
  log_interval: int = 10
- """10 스텝마다 콘솔 + wandb 로깅."""
 
  eval_interval: int = 500
- """500 스텝마다 검증 Loss 측정."""
 
  eval_steps: int = 20
- """검증 시 사용할 배치 수. 20 × 4 × 2048 ≈ 160K 토큰."""
 
  # ── wandb ──
  wandb_project: str = "llm-1b-lab"
  wandb_run_name: Optional[str] = None
  use_wandb: bool = True
 
- # ── 재현성 ──
  seed: int = 42
 
  @property
@@ -105,8 +105,8 @@ class TrainConfig:
 
  @property
  def tokens_per_step(self) -> int:
- """한 optimizer step당 처리 토큰 수."""
- # max_seq_len은 외부에서 주입 (ModelConfig 참조)
  return self.effective_batch_size * 2048
 
  @property
 
 
  @dataclass
  class TrainConfig:
+ """Training hyperparameters and infrastructure configuration.
 
+ Default values optimized for Colab Pro+ (A100 40GB).
+ Each value includes an explanation of why it was chosen.
  """
 
+ # ── Optimization ──
  learning_rate: float = 3e-4
+ """Peak LR. 3e-4 is the standard for 1B-scale models.
+ The GPT-3 paper reports optimal LRs by model size:
  125M → 6e-4, 350M → 3e-4, 1.3B → 2e-4
+ Our model (1.1B) starts at 3e-4; lower to 2e-4 if unstable."""
 
  min_learning_rate: float = 3e-5
+ """Minimum point of cosine decay. Typically 10% of peak.
+ Too low causes stagnation in later training; too high causes unstable convergence."""
 
  weight_decay: float = 0.1
+ """L2 regularization for AdamW. 0.1 is the LLM standard.
+ Not applied to embeddings and biases (by convention)."""
 
  beta1: float = 0.9
  beta2: float = 0.95
+ """Adam momentum coefficients. β2=0.95 is more stable than β2=0.999 for LLM training.
+ With large batches and long training, a β2 that is too large slows adaptation."""
 
  adam_eps: float = 1e-8
  grad_clip: float = 1.0
+ """Gradient Clipping: rescales gradients when their norm exceeds 1.0.
+ Prevents gradient spikes that occur during early training or with noisy data."""
 
+ # ── Scheduling ──
  warmup_steps: int = 2000
+ """Warmup: linearly increases LR from 0 to peak over the first 2000 steps.
+ Why is this necessary?
+ - Initial weights are random → large LR causes unstable updates
+ - Starting with a small LR lets the model find its direction before full training
+ - 2000 is roughly ~10% of total training steps (empirical rule)."""
 
  total_steps: int = 20_000
+ """Total number of training steps.
+ 10B tokens / (128 batch × 2048 seq_len) ≈ 38,000, but
+ ~20,000 effective steps when accounting for gradient accumulation."""
 
+ # ── Batch ──
  micro_batch_size: int = 4
+ """Batch size loaded onto the GPU at once.
+ 4 is a safe upper bound for a 1B model in bf16 on an A100 40GB."""
 
  gradient_accumulation_steps: int = 32
+ """Number of gradient accumulation steps. Effective batch = 4 × 32 = 128.
+ Why is a large batch beneficial?
+ - More stable gradient estimates (reduced noise)
+ - LLM training typically uses an effective batch of 128–512
+ - When memory is limited, increase this and reduce micro_batch."""
 
  # ── Mixed Precision ──
  dtype: str = "bfloat16"
+ """bfloat16: supported on A100, numerically more stable than fp16.
+ Uses the same number of exponent bits as fp32 → lower risk of overflow/underflow.
+ Change to 'float16' when falling back to T4/V100."""
 
+ # ── Checkpointing ──
  checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints"
+ """Google Drive path. Preserved even when the Colab session expires."""
 
  checkpoint_interval: int = 500
+ """Save a checkpoint every 500 steps.
+ Roughly every ~30 minutes on an A100. Too frequent causes I/O overhead;
+ too infrequent risks large losses when the session expires."""
 
  max_checkpoints: int = 3
+ """Number of rolling checkpoints to retain; oldest are deleted first.
+ One checkpoint ≈ 8–10 GB → 3 checkpoints ≈ ~30 GB."""
 
+ # ── Logging ──
  log_interval: int = 10
+ """Log to console and wandb every 10 steps."""
 
  eval_interval: int = 500
+ """Measure validation loss every 500 steps."""
 
  eval_steps: int = 20
+ """Number of batches to use during validation. 20 × 4 × 2048 ≈ 160K tokens."""
 
  # ── wandb ──
  wandb_project: str = "llm-1b-lab"
  wandb_run_name: Optional[str] = None
  use_wandb: bool = True
 
+ # ── Reproducibility ──
  seed: int = 42
 
  @property
 
  @property
  def tokens_per_step(self) -> int:
+ """Number of tokens processed per optimizer step."""
+ # max_seq_len is injected externally (see ModelConfig)
  return self.effective_batch_size * 2048
 
  @property
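The schedule documented under `warmup_steps`, `total_steps`, and `min_learning_rate` (linear ramp from 0 to peak, then cosine decay down to the minimum) corresponds to the usual warmup-cosine formula. A sketch, assuming the package's `CosineWarmupScheduler` follows this standard shape; `lr_at` is an illustrative stand-in, not its API:

```python
import math

def lr_at(step, peak=3e-4, min_lr=3e-5, warmup=2000, total=20_000):
    """LR schedule: linear warmup from 0 to peak, then cosine decay to min_lr."""
    if step < warmup:
        return peak * step / warmup                  # linear ramp
    progress = (step - warmup) / (total - warmup)    # 0 → 1 after warmup
    return min_lr + 0.5 * (peak - min_lr) * (1 + math.cos(math.pi * progress))

# Defaults above match TrainConfig: LR peaks at step 2000, decays to 3e-5 by step 20,000
print(lr_at(0), lr_at(2000), lr_at(20_000))
```

Note how the cosine term equals `peak` exactly at the end of warmup (cos(0) = 1) and `min_lr` exactly at `total_steps` (cos(π) = -1), so there is no discontinuity where the two phases meet.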
llm_lab/data/__init__.py CHANGED
@@ -1,4 +1,4 @@
- """데이터 파이프라인 모듈 — 토크나이저, 스트리밍, 시퀀스 패킹."""
  from .tokenizer import Tokenizer
  from .dataset import PackedStreamingDataset, ValidationDataset
  from .pipeline import create_train_dataloader, train_tokenizer_from_dataset, setup_data_pipeline
 
+ """Data pipeline module—tokenizer, streaming, and sequence packing."""
  from .tokenizer import Tokenizer
  from .dataset import PackedStreamingDataset, ValidationDataset
  from .pipeline import create_train_dataloader, train_tokenizer_from_dataset, setup_data_pipeline
llm_lab/data/dataset.py CHANGED
@@ -1,4 +1,4 @@
1
- """스트리밍 데이터셋시퀀스 패킹, 검증 데이터셋."""
2
 
3
  from typing import Iterator, List, Dict, Optional
4
 
@@ -10,24 +10,24 @@ from .tokenizer import Tokenizer
10
 
11
 
12
  class PackedStreamingDataset(IterableDataset):
13
- """Streaming + 시퀀스 패킹 데이터셋.
14
 
15
- 시퀀스 패킹인가?
16
- - 일반적 방법: 문서를 max_seq_len으로 잘라 패딩GPU 낭비
17
- - 시퀀스 패킹: 여러 문서를 이어붙여 max_seq_len을 채움 → 100% 활용
18
 
19
- 동작 방식:
20
- 문서1 (300 토큰) + 문서2 (1500 토큰) + 문서3 (248 토큰) = 2048 토큰
21
- → [문서1][EOS][문서2][EOS][문서3][EOS][...패딩 없이 맞춤]
22
 
23
- Streaming인가?
24
- - FineWeb-Edu 10B 샘플: 압축 상태에서도 수십 GB
25
- - Colab 디스크 한계 (~200GB)에서 전체 다운로드 불가
26
- - Streaming: 필요한 만큼만 네트워크에서 읽어옴
27
 
28
- 학습 주의사항:
29
- - 시퀀스 문서 경계에 EOS 토큰 삽입으로 모델이 문서 끝을 인식
30
- - Cross-Attention 마스크 없이도 EOS가 자연스러운 경계 역할
31
  """
32
 
33
  def __init__(
@@ -45,15 +45,16 @@ class PackedStreamingDataset(IterableDataset):
45
  self.max_seq_len = config.max_seq_len
46
 
47
  def _load_dataset(self, num_shards: int = 1, shard_index: int = 0):
48
- """HuggingFace 데이터셋을 스트리밍 모드로 로드합니다.
49
 
50
  Args:
51
- num_shards: 전체 샤드 (= DataLoader num_workers)
52
- shard_index: 워커가 담당할 샤드 번호 (0 ~ num_shards-1)
53
 
54
- 샤딩 원리:
55
- num_shards=4 스트림을 4등분하여 워커가 서로 다른 1/4만 처리.
56
- 셔플은 샤딩 이후에 적용하므로 워커 문서 중복이 없음.
 
57
  """
58
  from datasets import load_dataset
59
 
@@ -61,87 +62,87 @@ class PackedStreamingDataset(IterableDataset):
61
  self.config.dataset_name,
62
  name=self.config.dataset_subset,
63
  split=self.config.dataset_split,
64
- streaming=True, # 핵심: 스트리밍 모드
65
  trust_remote_code=True,
66
  )
67
 
68
- # 완전 분할(샤딩): 워커 i 전체 스트림의 1/num_shards 구간만 처리
69
- # 반드시 셔플 전에 적용해야 워커가 겹치지 않는 문서 집합을 가짐
70
  if num_shards > 1:
71
  ds = ds.shard(num_shards=num_shards, index=shard_index)
72
 
73
- # 셔플 (스트리밍에서는 버퍼 기반 근사 셔플)
74
  ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
75
 
76
  return ds
77
 
78
  def _tokenize_and_pack(self, dataset) -> Iterator[Dict[str, torch.Tensor]]:
79
- """문서를 토크나이즈하고 시퀀스 패킹합니다.
80
 
81
  Yields:
82
  {"input_ids": (max_seq_len,), "targets": (max_seq_len,)}
83
 
84
- targets = input_ids shift:
85
  input_ids: [A, B, C, D, E]
86
  targets: [B, C, D, E, F]
87
- 모델은 A 보고 B를 예측, B 보고 C를 예측, ...
88
  """
89
- buffer: List[int] = [] # 토큰 버퍼
90
 
91
  for example in dataset:
92
  text = example[self.config.text_column]
93
  if not text or not text.strip():
94
  continue
95
 
96
- # 토크나이즈 (특수 토큰 없이)
97
  token_ids = self.tokenizer.encode(text, add_special_tokens=False)
98
 
99
  if not token_ids:
100
  continue
101
 
102
- # EOS 토큰 추가 (문서 경계 표시)
103
  if self.config.use_eos_separator:
104
  token_ids.append(self.tokenizer.eos_id)
105
 
106
- # 버퍼에 추가
107
  buffer.extend(token_ids)
108
 
109
- # 버퍼가 충분히 차면 시퀀스 생성
110
- # +1 targets 생성을 위해 (input + 다음 토큰)
111
  while len(buffer) >= self.max_seq_len + 1:
112
- # max_seq_len + 1 만큼 꺼냄
113
  chunk = buffer[: self.max_seq_len + 1]
114
  buffer = buffer[self.max_seq_len + 1 :]
115
 
116
- # input_ids: 처음 ~ 끝에서 번째
117
  input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
118
- # targets: 번째 ~ ( shift)
119
  targets = torch.tensor(chunk[1:], dtype=torch.long)
120
 
121
  yield {"input_ids": input_ids, "targets": targets}
122
 
123
  def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
124
- """DataLoader가 호출하는 이터레이터.
125
 
126
- 멀티 워커 지원 (완전 분할 방식):
127
- - 이전: 모든 워커가 동일한 스트림을 읽고 시드만 달리함문서 중복 가능
128
- - 개선: ds.shard() 스트림을 num_workers등분워커 문서 중복 없음
129
 
130
- 예시 (num_workers=4, 전체 문서 N):
131
- Worker 0: 문서 0, 4, 8, 12, ... (N/4)
132
- Worker 1: 문서 1, 5, 9, 13, ... (N/4)
133
- Worker 2: 문서 2, 6, 10, 14, ... (N/4)
134
- Worker 3: 문서 3, 7, 11, 15, ... (N/4)
135
  """
136
  worker_info = torch.utils.data.get_worker_info()
137
 
138
  if worker_info is not None:
139
- # 완전 분할: 워커별 샤드 할당 + 독립적인 셔플 시드
140
  num_shards = worker_info.num_workers
141
  shard_index = worker_info.id
142
  worker_seed = self.seed + worker_info.id
143
  else:
144
- # 단일 프로세스: 샤딩 없이 전체 스트림 처리
145
  num_shards = 1
146
  shard_index = 0
147
  worker_seed = self.seed
@@ -153,10 +154,10 @@ class PackedStreamingDataset(IterableDataset):
153
 
154
 
155
  class ValidationDataset:
156
- """검증용 데이터셋.
157
 
158
- Streaming 데이터셋에서 일정량을 미리 가져와 메모리에 저장합니다.
159
- 에폭 동일한 데이터로 평가해야 비교가 의미 있기 때문입니다.
160
  """
161
 
162
  def __init__(
@@ -174,10 +175,10 @@ class ValidationDataset:
174
  self._prepare(seed)
175
 
176
  def _prepare(self, seed: int):
177
- """데이터셋에서 검증 샘플을 미리 추출합니다."""
178
  from datasets import load_dataset
179
 
180
- print(f"[Validation] {self.num_samples} 검증 샘플 준비 중...")
181
 
182
  ds = load_dataset(
183
  self.config.dataset_name,
@@ -186,7 +187,7 @@ class ValidationDataset:
186
  streaming=True,
187
  trust_remote_code=True,
188
  )
189
- # Use a different seed and skip the beginning to avoid overlap with training data
190
  ds = ds.shuffle(seed=seed, buffer_size=5_000)
191
 
192
  buffer: List[int] = []
@@ -217,10 +218,10 @@ class ValidationDataset:
217
  })
218
  count += 1
219
 
220
- print(f"[Validation] {len(self.samples)} samples ready")
221
 
222
  def get_dataloader(self, batch_size: int) -> DataLoader:
223
- """Returns a validation DataLoader."""
224
  return DataLoader(
225
  self.samples,
226
  batch_size=batch_size,
@@ -231,10 +232,10 @@ class ValidationDataset:
231
 
232
 
233
  def _collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
234
- """Combines samples in a batch into a single tensor.
235

236
- Because of sequence packing, all samples have the same length (max_seq_len),
237
- so no additional padding is needed.
238
  """
239
  return {
240
  "input_ids": torch.stack([s["input_ids"] for s in batch]),
 
1
+ """Streaming dataset — sequence packing and validation dataset."""
2
 
3
  from typing import Iterator, List, Dict, Optional
4
 
 
10
 
11
 
12
  class PackedStreamingDataset(IterableDataset):
13
+ """Streaming + sequence packing dataset.
14
 
15
+ Why sequence packing?
16
+ - Naive approach: truncate each document to max_seq_len with padding → wastes GPU compute
17
+ - Sequence packing: concatenate multiple documents to fill max_seq_len → 100% utilization
18
 
19
+ How it works:
20
+ Doc1 (300 tokens) + Doc2 (1500 tokens) + Doc3 (248 tokens) = 2048 tokens
21
+ → [Doc1][EOS][Doc2][EOS][Doc3][EOS][... no padding, fits exactly]
22
 
23
+ Why streaming?
24
+ - FineWeb-Edu 10B samples: tens of GB even when compressed
25
+ - Full download is not feasible within Colab's disk limit (~200GB)
26
+ - Streaming: reads from the network only as much as needed
27
 
28
+ Notes for training:
29
+ - EOS token inserted at document boundaries so the model recognizes end-of-document
30
+ - EOS naturally serves as a boundary marker without cross-attention masking
31
  """
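The packing rule described in this docstring can be sketched in isolation as a sanity check (EOS_ID, MAX_SEQ_LEN, and the token IDs below are made-up illustration values, not the project's configuration):

```python
# Minimal sketch of sequence packing: concatenate tokenized documents with
# an EOS separator, then emit fixed-length chunks with no padding.
EOS_ID = 2
MAX_SEQ_LEN = 8

def pack(docs):
    buffer = []
    for doc in docs:
        buffer.extend(doc + [EOS_ID])          # EOS marks each document boundary
        while len(buffer) >= MAX_SEQ_LEN + 1:  # +1 so targets can be shifted
            chunk = buffer[: MAX_SEQ_LEN + 1]
            buffer = buffer[MAX_SEQ_LEN + 1:]
            yield chunk[:-1], chunk[1:]        # (input_ids, targets)

docs = [[10, 11, 12], [20, 21, 22, 23, 24], [30, 31, 32, 33]]
pairs = list(pack(docs))
```

Note how every emitted pair satisfies the shift relation `inputs[1:] == targets[:-1]`, and leftover tokens simply wait in the buffer for the next documents.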
32
 
33
  def __init__(
 
45
  self.max_seq_len = config.max_seq_len
46
 
47
  def _load_dataset(self, num_shards: int = 1, shard_index: int = 0):
48
+ """Loads the HuggingFace dataset in streaming mode.
49
 
50
  Args:
51
+ num_shards: Total number of shards (= DataLoader num_workers)
52
+ shard_index: The shard index this worker is responsible for (0 ~ num_shards-1)
53
 
54
+ Sharding principle:
55
+ With num_shards=4, the stream is split into 4 equal parts so each worker
56
+ processes a distinct 1/4. Shuffling is applied after sharding so there is
57
+ no document overlap between workers.
58
  """
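The partitioning idea can be illustrated with a hypothetical round-robin split. This is a sketch of the principle only, not the actual `datasets` `shard()` implementation:

```python
# Round-robin partition: worker `shard_index` keeps every num_shards-th document.
def shard(stream, num_shards, shard_index):
    for i, doc in enumerate(stream):
        if i % num_shards == shard_index:
            yield doc

docs = list(range(12))
parts = [list(shard(docs, 4, w)) for w in range(4)]
```

Each worker's share is disjoint, and the union of all shares recovers the full stream, which is exactly the "no document overlap" property claimed above.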
59
  from datasets import load_dataset
60
 
 
62
  self.config.dataset_name,
63
  name=self.config.dataset_subset,
64
  split=self.config.dataset_split,
65
+ streaming=True, # Key: streaming mode
66
  trust_remote_code=True,
67
  )
68
 
69
+ # Full partitioning (sharding): worker i processes only 1/num_shards of the stream
70
+ # Must be applied before shuffling so each worker has a non-overlapping set of documents
71
  if num_shards > 1:
72
  ds = ds.shard(num_shards=num_shards, index=shard_index)
73
 
74
+ # Shuffle (approximate buffer-based shuffle in streaming mode)
75
  ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
76
 
77
  return ds
78
 
79
  def _tokenize_and_pack(self, dataset) -> Iterator[Dict[str, torch.Tensor]]:
80
+ """Tokenizes documents and packs them into sequences.
81
 
82
  Yields:
83
  {"input_ids": (max_seq_len,), "targets": (max_seq_len,)}
84
 
85
+ targets = input_ids shifted by one position:
86
  input_ids: [A, B, C, D, E]
87
  targets: [B, C, D, E, F]
88
+ The model sees A and predicts B, sees B and predicts C, ...
89
  """
90
+ buffer: List[int] = [] # Token buffer
91
 
92
  for example in dataset:
93
  text = example[self.config.text_column]
94
  if not text or not text.strip():
95
  continue
96
 
97
+ # Tokenize (without special tokens)
98
  token_ids = self.tokenizer.encode(text, add_special_tokens=False)
99
 
100
  if not token_ids:
101
  continue
102
 
103
+ # Append EOS token (marks document boundary)
104
  if self.config.use_eos_separator:
105
  token_ids.append(self.tokenizer.eos_id)
106
 
107
+ # Add to buffer
108
  buffer.extend(token_ids)
109
 
110
+ # Generate sequences once the buffer is full enough
111
+ # +1 is needed to generate targets (input + next token)
112
  while len(buffer) >= self.max_seq_len + 1:
113
+ # Extract max_seq_len + 1 tokens
114
  chunk = buffer[: self.max_seq_len + 1]
115
  buffer = buffer[self.max_seq_len + 1 :]
116
 
117
+ # input_ids: from the first to the second-to-last token
118
  input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
119
+ # targets: from the second to the last token (shifted by one)
120
  targets = torch.tensor(chunk[1:], dtype=torch.long)
121
 
122
  yield {"input_ids": input_ids, "targets": targets}
123
 
124
  def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
125
+ """Iterator called by DataLoader.
126
 
127
+ Multi-worker support (full partitioning approach):
128
+ - Previous: all workers read the same stream with different seeds → possible document duplication

129
+ - Improved: ds.shard() splits the stream into num_workers parts → no document overlap between workers
130
 
131
+ Example (num_workers=4, total N documents):
132
+ Worker 0: docs 0, 4, 8, 12, ... (N/4 docs)
133
+ Worker 1: docs 1, 5, 9, 13, ... (N/4 docs)
134
+ Worker 2: docs 2, 6, 10, 14, ... (N/4 docs)
135
+ Worker 3: docs 3, 7, 11, 15, ... (N/4 docs)
136
  """
137
  worker_info = torch.utils.data.get_worker_info()
138
 
139
  if worker_info is not None:
140
+ # Full partitioning: assign a shard per worker + independent shuffle seed
141
  num_shards = worker_info.num_workers
142
  shard_index = worker_info.id
143
  worker_seed = self.seed + worker_info.id
144
  else:
145
+ # Single process: process the full stream without sharding
146
  num_shards = 1
147
  shard_index = 0
148
  worker_seed = self.seed
 
154
 
155
 
156
  class ValidationDataset:
157
+ """Validation dataset.
158
 
159
+ Pre-fetches a fixed amount of data from the streaming dataset and stores it in memory.
160
+ Consistent data across evaluations is necessary for meaningful comparisons between epochs.
161
  """
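The pre-fetch idea reduces to materializing a fixed prefix of the stream; a minimal sketch using a stand-in stream (the real class tokenizes and packs before storing):

```python
import itertools

# Materialize a fixed prefix so every evaluation sees identical samples.
def take_validation(stream, n):
    return list(itertools.islice(stream, n))

val = take_validation(iter(range(1000)), 5)
```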
162
 
163
  def __init__(
 
175
  self._prepare(seed)
176
 
177
  def _prepare(self, seed: int):
178
+ """Pre-extracts validation samples from the dataset."""
179
  from datasets import load_dataset
180
 
181
+ print(f"[Validation] Preparing {self.num_samples} validation samples...")
182
 
183
  ds = load_dataset(
184
  self.config.dataset_name,
 
187
  streaming=True,
188
  trust_remote_code=True,
189
  )
190
+ # Use a different seed and skip the beginning to avoid overlap with training data
191
  ds = ds.shuffle(seed=seed, buffer_size=5_000)
192
 
193
  buffer: List[int] = []
 
218
  })
219
  count += 1
220
 
221
+ print(f"[Validation] {len(self.samples)} samples ready")
222
 
223
  def get_dataloader(self, batch_size: int) -> DataLoader:
224
+ """Returns a validation DataLoader."""
225
  return DataLoader(
226
  self.samples,
227
  batch_size=batch_size,
 
232
 
233
 
234
  def _collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
235
+ """Combines samples in a batch into a single tensor.
236
 
237
+ Because of sequence packing, all samples have the same length (max_seq_len),
238
+ so no additional padding is needed.
239
  """
240
  return {
241
  "input_ids": torch.stack([s["input_ids"] for s in batch]),
llm_lab/data/diagnostics.py CHANGED
@@ -1,4 +1,4 @@
1
- """Data pipeline diagnostic tools."""
2
 
3
  import time
4
  from typing import Dict
@@ -11,13 +11,13 @@ from .tokenizer import Tokenizer
11
 
12
 
13
  class DataPipelineDiagnostics:
14
- """Diagnoses the performance and quality of the data pipeline.
15

16
- Items to verify before training:
17
- 1) Tokenizer quality: average tokens/document, unknown token ratio
18
- 2) Packing efficiency: actual token ratio vs. padding ratio
19
- 3) Throughput: tokens/sec (check for data loading bottlenecks)
20
- 4) Batch shape: correctness of shape and dtype
21
  """
22
 
23
  @staticmethod
@@ -26,11 +26,11 @@ class DataPipelineDiagnostics:
26
  config: DataConfig,
27
  num_samples: int = 1000,
28
  ):
29
- """Diagnoses tokenizer quality."""
30
  from datasets import load_dataset
31
 
32
  print("\n" + "=" * 60)
33
- print("📊 Tokenizer Quality Diagnostics")
34
  print("=" * 60)
35
 
36
  ds = load_dataset(
@@ -59,24 +59,24 @@ class DataPipelineDiagnostics:
59
 
60
  avg_tokens = sum(token_counts) / len(token_counts)
61
  avg_chars = sum(char_counts) / len(char_counts)
62
- compression_ratio = avg_chars / avg_tokens # Characters per token ratio
63
 
64
- print(f" Documents analyzed: {len(token_counts):,}")
65
- print(f" Average tokens/document: {avg_tokens:.1f}")
66
- print(f" Average chars/document: {avg_chars:.1f}")
67
- print(f" Compression ratio (chars/token): {compression_ratio:.2f}")
68
- print(f" → 3.5~4.5 is normal for English")
69
- print(f" Min tokens: {min(token_counts)}, Max: {max(token_counts)}")
70
 
71
- # Round-trip decode test
72
  test_text = "The quick brown fox jumps over the lazy dog."
73
  encoded = tokenizer.encode(test_text)
74
  decoded = tokenizer.decode(encoded)
75
  roundtrip_ok = test_text.strip() in decoded.strip()
76
- print(f"\n Round-trip test: {'✅ PASSED' if roundtrip_ok else '❌ FAILED'}")
77
- print(f" Original: {test_text}")
78
- print(f" Encoded: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
79
- print(f" Decoded: {decoded}")
80
 
81
  @staticmethod
82
  def benchmark_throughput(
@@ -84,13 +84,13 @@ class DataPipelineDiagnostics:
84
  num_batches: int = 50,
85
  seq_len: int = 2048,
86
  ):
87
- """Measures data loading throughput.
88

89
- A key diagnostic to determine whether data loading is the bottleneck in GPU training.
90
- Goal: data loading should be faster than GPU computation (data loading ≠ bottleneck).
91
  """
92
  print("\n" + "=" * 60)
93
- print("Data Loading Throughput Benchmark")
94
  print("=" * 60)
95
 
96
  total_tokens = 0
@@ -110,23 +110,23 @@ class DataPipelineDiagnostics:
110
  elapsed = time.time() - start_time
111
  tps = total_tokens / elapsed
112
 
113
- print(f"\n Total batches: {num_batches}")
114
- print(f" Total tokens: {total_tokens:,}")
115
- print(f" Elapsed time: {elapsed:.2f}s")
116
- print(f" Average throughput: {tps:,.0f} tokens/sec")
117
- print(f"\n 💡 A100 training throughput reference ~50-80K tokens/sec:")
118
  if tps > 80_000:
119
- print(f" Data loading is not the bottleneck")
120
  elif tps > 30_000:
121
- print(f" ⚠️ Borderline - consider increasing num_workers")
122
  else:
123
- print(f" Data loading is the bottleneck! Adjust num_workers/prefetch")
124
 
125
  @staticmethod
126
  def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: Tokenizer):
127
- """Inspects a single batch in detail."""
128
  print("\n" + "=" * 60)
129
- print("🔍 Batch Detailed Inspection")
130
  print("=" * 60)
131
 
132
  input_ids = batch["input_ids"]
@@ -135,19 +135,19 @@ class DataPipelineDiagnostics:
135
  print(f" input_ids shape: {input_ids.shape}")
136
  print(f" targets shape: {targets.shape}")
137
  print(f" dtype: {input_ids.dtype}")
138
- print(f" value range: [{input_ids.min().item()}, {input_ids.max().item()}]")
139
 
140
- # Verify shift relationship: targets[i] == input_ids[i+1]
141
  shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
142
- print(f" Shift consistency: {shift_correct*100:.1f}% (should be 100%)")
143
 
144
- # EOS token distribution (document boundaries)
145
  eos_count = (input_ids == tokenizer.eos_id).sum().item()
146
  total_tokens = input_ids.numel()
147
- print(f" EOS token count: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")
148
 
149
- # Decode preview of the first sample
150
  first_sample = input_ids[0][:100].tolist()
151
  decoded_preview = tokenizer.decode(first_sample)
152
- print(f"\n First sample decoded (first 100 tokens):")
153
  print(f" {decoded_preview[:300]}...")
 
1
+ """Data pipeline diagnostic tools."""
2
 
3
  import time
4
  from typing import Dict
 
11
 
12
 
13
  class DataPipelineDiagnostics:
14
+ """Diagnoses the performance and quality of the data pipeline.
15
 
16
+ Items to verify before training:
17
+ 1) Tokenizer quality: average tokens/document, unknown token ratio
18
+ 2) Packing efficiency: actual token ratio vs. padding ratio
19
+ 3) Throughput: tokens/sec (check for data loading bottlenecks)
20
+ 4) Batch shape: correctness of shape and dtype
21
  """
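The chars-per-token compression ratio in item 1 can be illustrated with a toy stand-in tokenizer (a whitespace split; a real BPE tokenizer on English text would land in the 3.5~4.5 range):

```python
# Toy illustration of the chars/token compression ratio diagnostic.
def compression_ratio(texts, encode):
    total_chars = sum(len(t) for t in texts)
    total_tokens = sum(len(encode(t)) for t in texts)
    return total_chars / total_tokens

ratio = compression_ratio(["aaaa bbbb"], str.split)
```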
22
 
23
  @staticmethod
 
26
  config: DataConfig,
27
  num_samples: int = 1000,
28
  ):
29
+ """Diagnoses tokenizer quality."""
30
  from datasets import load_dataset
31
 
32
  print("\n" + "=" * 60)
33
+ print("Tokenizer Quality Diagnostics")
34
  print("=" * 60)
35
 
36
  ds = load_dataset(
 
59
 
60
  avg_tokens = sum(token_counts) / len(token_counts)
61
  avg_chars = sum(char_counts) / len(char_counts)
62
+ compression_ratio = avg_chars / avg_tokens # Characters per token ratio
63
 
64
+ print(f" Documents analyzed: {len(token_counts):,}")
65
+ print(f" Average tokens/document: {avg_tokens:.1f}")
66
+ print(f" Average chars/document: {avg_chars:.1f}")
67
+ print(f" Compression ratio (chars/token): {compression_ratio:.2f}")
68
+ print(f" -> 3.5~4.5 is normal for English")
69
+ print(f" Min tokens: {min(token_counts)}, Max: {max(token_counts)}")
70
 
71
+ # Round-trip decode test
72
  test_text = "The quick brown fox jumps over the lazy dog."
73
  encoded = tokenizer.encode(test_text)
74
  decoded = tokenizer.decode(encoded)
75
  roundtrip_ok = test_text.strip() in decoded.strip()
76
+ print(f"\n Round-trip test: {'PASSED' if roundtrip_ok else 'FAILED'}")
77
+ print(f" Original: {test_text}")
78
+ print(f" Encoded: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
79
+ print(f" Decoded: {decoded}")
80
 
81
  @staticmethod
82
  def benchmark_throughput(
 
84
  num_batches: int = 50,
85
  seq_len: int = 2048,
86
  ):
87
+ """Measures data loading throughput.
88
 
89
+ A key diagnostic to determine whether data loading is the bottleneck in GPU training.
90
+ Goal: data loading should be faster than GPU computation (data loading != bottleneck).
91
  """
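A minimal sketch of the tokens/sec measurement, with an iterable of toy batches standing in for the DataLoader (batch contents and seq_len here are illustrative):

```python
import time

# Count tokens drained from `batches` and divide by wall-clock time.
def measure_throughput(batches, seq_len):
    total_tokens = 0
    start = time.time()
    for batch in batches:
        total_tokens += len(batch) * seq_len   # every packed sample is seq_len long
    elapsed = max(time.time() - start, 1e-9)   # guard against zero division
    return total_tokens, total_tokens / elapsed

total, tps = measure_throughput([[0] * 4] * 10, 2048)
```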
92
  print("\n" + "=" * 60)
93
+ print("Data Loading Throughput Benchmark")
94
  print("=" * 60)
95
 
96
  total_tokens = 0
 
110
  elapsed = time.time() - start_time
111
  tps = total_tokens / elapsed
112
 
113
+ print(f"\n Total batches: {num_batches}")
114
+ print(f" Total tokens: {total_tokens:,}")
115
+ print(f" Elapsed time: {elapsed:.2f}s")
116
+ print(f" Average throughput: {tps:,.0f} tokens/sec")
117
+ print(f"\n A100 training throughput reference ~50-80K tokens/sec:")
118
  if tps > 80_000:
119
+ print(f" Data loading is not the bottleneck")
120
  elif tps > 30_000:
121
+ print(f" Borderline - consider increasing num_workers")
122
  else:
123
+ print(f" Data loading is the bottleneck! Adjust num_workers/prefetch")
124
 
125
  @staticmethod
126
  def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: Tokenizer):
127
+ """Inspects a single batch in detail."""
128
  print("\n" + "=" * 60)
129
+ print("Batch Detailed Inspection")
130
  print("=" * 60)
131
 
132
  input_ids = batch["input_ids"]
 
135
  print(f" input_ids shape: {input_ids.shape}")
136
  print(f" targets shape: {targets.shape}")
137
  print(f" dtype: {input_ids.dtype}")
138
+ print(f" value range: [{input_ids.min().item()}, {input_ids.max().item()}]")
139
 
140
+ # Verify shift relationship: targets[i] == input_ids[i+1]
141
  shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
142
+ print(f" Shift consistency: {shift_correct*100:.1f}% (should be 100%)")
143
 
144
+ # EOS token distribution (document boundaries)
145
  eos_count = (input_ids == tokenizer.eos_id).sum().item()
146
  total_tokens = input_ids.numel()
147
+ print(f" EOS token count: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")
148
 
149
+ # Decode preview of the first sample
150
  first_sample = input_ids[0][:100].tolist()
151
  decoded_preview = tokenizer.decode(first_sample)
152
+ print(f"\n First sample decoded (first 100 tokens):")
153
  print(f" {decoded_preview[:300]}...")
llm_lab/data/pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- """Data pipeline integration — DataLoader creation, tokenizer training, and Quick Start."""
2
 
3
  from typing import Optional
4
 
@@ -15,12 +15,12 @@ def create_train_dataloader(
15
  config: DataConfig,
16
  seed: int = 42,
17
  ) -> DataLoader:
18
- """Creates a training DataLoader.
19

20
  Returns:
21
- An infinitely repeating streaming DataLoader
22

23
- Usage:
24
  dataloader = create_train_dataloader(tokenizer, config)
25
  for step, batch in enumerate(dataloader):
26
  input_ids = batch["input_ids"].to(device) # (B, seq_len)
@@ -40,7 +40,7 @@ def create_train_dataloader(
40
  batch_size=config.batch_size,
41
  num_workers=config.num_workers,
42
  prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
43
- pin_memory=True, # Improves GPU transfer speed
44
  collate_fn=_collate_fn,
45
  )
46
 
@@ -48,17 +48,17 @@ def create_train_dataloader(
48
 
49
 
50
  def train_tokenizer_from_dataset(config: DataConfig) -> Tokenizer:
51
- """Trains a BPE tokenizer from the dataset.
52

53
- There is no need to use the entire dataset; 50K documents is sufficient,
54
- since the tokenizer vocab only needs to reflect the statistics of the full data.
55
  """
56
  from datasets import load_dataset
57
 
58
- print(f"[Train Tokenizer] Training tokenizer from {config.dataset_name}")
59
- print(f"[Train Tokenizer] Number of training documents: {config.tokenizer_train_samples:,}")
60
 
61
- # Create text iterator
62
  ds = load_dataset(
63
  config.dataset_name,
64
  name=config.dataset_subset,
@@ -77,9 +77,9 @@ def train_tokenizer_from_dataset(config: DataConfig) -> Tokenizer:
77
  yield text
78
  count += 1
79
  if count % 10_000 == 0:
80
- print(f" ... {count:,} documents processed")
81
 
82
- # Train tokenizer
83
  tokenizer = Tokenizer(config)
84
  tokenizer.train_bpe(text_iterator(), save_dir=config.tokenizer_save_dir)
85
 
@@ -91,38 +91,38 @@ def setup_data_pipeline(
91
  tokenizer_path: Optional[str] = None,
92
  config: Optional[DataConfig] = None,
93
  ) -> tuple:
94
- """Sets up the data pipeline in one call.
95
 
96
  Args:
97
  tokenizer_mode:
98
- "train_new" - Train a new BPE tokenizer
99
- "load_trained" - Load a previously trained tokenizer
100
- "pretrained" - Use a pretrained HuggingFace tokenizer
101
  tokenizer_path:
102
- "train_new" → Save directory (default: ./tokenizer)
103
- "load_trained" → Path to the saved tokenizer
104
- "pretrained" → HF model name (default: mistralai/Mistral-7B-v0.1)
105
 
106
  Returns:
107
  (tokenizer, train_dataloader, val_dataloader)
108
 
109
- Example usage (Colab):
110
- # Method 1: Train a new tokenizer
111
  tok, train_dl, val_dl = setup_data_pipeline("train_new")
112

113
- # Method 2: Load an existing tokenizer
114
  tok, train_dl, val_dl = setup_data_pipeline("load_trained", "./tokenizer")
115

116
- # Method 3: Use a pretrained tokenizer (simplest)
117
  tok, train_dl, val_dl = setup_data_pipeline("pretrained")
118
  """
119
  config = config or DataConfig()
120
 
121
  print("=" * 60)
122
- print("🚀 Data Pipeline Setup")
123
  print("=" * 60)
124
 
125
- # ── Step 1: Tokenizer ──
126
  tokenizer = Tokenizer(config)
127
 
128
  if tokenizer_mode == "train_new":
@@ -136,21 +136,21 @@ def setup_data_pipeline(
136
  else:
137
  raise ValueError(f"Unknown tokenizer_mode: {tokenizer_mode}")
138
 
139
- # ── Step 2: Training DataLoader ──
140
- print("\n[DataLoader] Creating training DataLoader...")
141
  train_dataloader = create_train_dataloader(tokenizer, config)
142
 
143
- # ── Step 3: Validation DataLoader ──
144
- print("\n[DataLoader] Creating validation DataLoader...")
145
  val_dataset = ValidationDataset(tokenizer, config, num_samples=100)
146
  val_dataloader = val_dataset.get_dataloader(batch_size=config.batch_size)
147
 
148
  print("\n" + "=" * 60)
149
- print("Data pipeline setup complete!")
150
- print(f" Tokenizer vocab: {tokenizer.vocab_size:,}")
151
- print(f" Sequence length: {config.max_seq_len}")
152
- print(f" Batch size: {config.batch_size}")
153
- print(f" Tokens/batch: {config.batch_size * config.max_seq_len:,}")
154
  print("=" * 60)
155
 
156
  return tokenizer, train_dataloader, val_dataloader
 
1
+ """Data pipeline integration — DataLoader creation, tokenizer training, and Quick Start."""
2
 
3
  from typing import Optional
4
 
 
15
  config: DataConfig,
16
  seed: int = 42,
17
  ) -> DataLoader:
18
+ """Creates a training DataLoader.
19
 
20
  Returns:
21
+ An infinitely repeating streaming DataLoader
22
 
23
+ Usage:
24
  dataloader = create_train_dataloader(tokenizer, config)
25
  for step, batch in enumerate(dataloader):
26
  input_ids = batch["input_ids"].to(device) # (B, seq_len)
 
40
  batch_size=config.batch_size,
41
  num_workers=config.num_workers,
42
  prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
43
+ pin_memory=True, # Improves GPU transfer speed
44
  collate_fn=_collate_fn,
45
  )
46
 
 
48
 
49
 
50
  def train_tokenizer_from_dataset(config: DataConfig) -> Tokenizer:
51
+ """Trains a BPE tokenizer from the dataset.
52
 
53
+ There is no need to use the entire dataset; 50K documents is sufficient,
54
+ since the tokenizer vocab only needs to reflect the statistics of the full data.
55
  """
56
  from datasets import load_dataset
57
 
58
+ print(f"[Train Tokenizer] Training tokenizer from {config.dataset_name}")
59
+ print(f"[Train Tokenizer] Number of training documents: {config.tokenizer_train_samples:,}")
60
 
61
+ # Create text iterator
62
  ds = load_dataset(
63
  config.dataset_name,
64
  name=config.dataset_subset,
 
77
  yield text
78
  count += 1
79
  if count % 10_000 == 0:
80
+ print(f" ... {count:,} documents processed")
81
 
82
+ # Train tokenizer
83
  tokenizer = Tokenizer(config)
84
  tokenizer.train_bpe(text_iterator(), save_dir=config.tokenizer_save_dir)
85
 
 
91
  tokenizer_path: Optional[str] = None,
92
  config: Optional[DataConfig] = None,
93
  ) -> tuple:
94
+ """Sets up the data pipeline in one call.
95
 
96
  Args:
97
  tokenizer_mode:
98
+ "train_new" - Train a new BPE tokenizer
99
+ "load_trained" - Load a previously trained tokenizer
100
+ "pretrained" - Use a pretrained HuggingFace tokenizer
101
  tokenizer_path:
102
+ "train_new" -> Save directory (default: ./tokenizer)
103
+ "load_trained" -> Path to the saved tokenizer
104
+ "pretrained" -> HF model name (default: mistralai/Mistral-7B-v0.1)
105
 
106
  Returns:
107
  (tokenizer, train_dataloader, val_dataloader)
108
 
109
+ Example usage (Colab):
110
+ # Method 1: Train a new tokenizer
111
  tok, train_dl, val_dl = setup_data_pipeline("train_new")
112
 
113
+ # Method 2: Load an existing tokenizer
114
  tok, train_dl, val_dl = setup_data_pipeline("load_trained", "./tokenizer")
115
 
116
+ # Method 3: Use a pretrained tokenizer (simplest)
117
  tok, train_dl, val_dl = setup_data_pipeline("pretrained")
118
  """
119
  config = config or DataConfig()
120
 
121
  print("=" * 60)
122
+ print("Data Pipeline Setup")
123
  print("=" * 60)
124
 
125
+ # ── Step 1: Tokenizer ──
126
  tokenizer = Tokenizer(config)
127
 
128
  if tokenizer_mode == "train_new":
 
136
  else:
137
  raise ValueError(f"Unknown tokenizer_mode: {tokenizer_mode}")
138
 
139
+ # ── Step 2: Training DataLoader ──
140
+ print("\n[DataLoader] Creating training DataLoader...")
141
  train_dataloader = create_train_dataloader(tokenizer, config)
142
 
143
+ # ── Step 3: Validation DataLoader ──
144
+ print("\n[DataLoader] Creating validation DataLoader...")
145
  val_dataset = ValidationDataset(tokenizer, config, num_samples=100)
146
  val_dataloader = val_dataset.get_dataloader(batch_size=config.batch_size)
147
 
148
  print("\n" + "=" * 60)
149
+ print("Data pipeline setup complete!")
150
+ print(f" Tokenizer vocab: {tokenizer.vocab_size:,}")
151
+ print(f" Sequence length: {config.max_seq_len}")
152
+ print(f" Batch size: {config.batch_size}")
153
+ print(f" Tokens/batch: {config.batch_size * config.max_seq_len:,}")
154
  print("=" * 60)
155
 
156
  return tokenizer, train_dataloader, val_dataloader
llm_lab/data/tokenizer.py CHANGED
@@ -1,4 +1,4 @@
1
- """Tokenizer wrapper — SentencePiece / HuggingFace BPE integration."""
2
 
3
  import os
4
  import json
@@ -8,23 +8,23 @@ from llm_lab.config import DataConfig
8
 
9
 
10
  class Tokenizer:
11
- """Unified tokenizer wrapper.
12
-
13
- Supports three methods:
14
- 1) Load an existing SentencePiece model
15
- 2) Train a new tokenizer using the HuggingFace tokenizers library
16
- 3) Load a pretrained HF tokenizer (e.g., LLaMA tokenizer)
17
-
18
- Why not implement from scratch?
19
- - Training a BPE tokenizer involves large-scale text statistics processing,
20
- which has little direct relevance to understanding model architecture.
21
- - However, understanding how a tokenizer works (BPE merge rules) is still important.
22
-
23
- BPE (Byte Pair Encoding) core principle:
24
- 1) Split text into byte/character units
25
- 2) Repeatedly merge the most frequent adjacent pair
26
- 3) Repeat until vocab_size is reached
27
- Frequent words become a single token; rare words are split into multiple tokens
28
  """
29
 
30
  def __init__(self, config: DataConfig):
@@ -32,17 +32,17 @@ class Tokenizer:
32
  self._tokenizer = None
33
  self.vocab_size = config.vocab_size
34
 
35
- # Special token IDs (set after initialization)
36
  self.bos_id: int = 1 # Beginning of Sequence
37
  self.eos_id: int = 2 # End of Sequence
38
  self.pad_id: int = 0 # Padding
39
 
40
  # ────────────────────────────────────────────────
41
- # Method 1: Load a SentencePiece model
42
  # ────────────────────────────────────────────────
43
 
44
  def load_sentencepiece(self, model_path: str):
45
- """Loads an existing SentencePiece model."""
46
  import sentencepiece as spm
47
 
48
  self._tokenizer = spm.SentencePieceProcessor()
@@ -55,23 +55,23 @@ class Tokenizer:
55
  self._encode_fn = self._tokenizer.Encode
56
  self._decode_fn = self._tokenizer.Decode
57
 
58
- print(f"[Tokenizer] SentencePiece loaded: vocab_size={self.vocab_size}")
59
 
60
  # ────────────────────────────────────────────────
61
- # Method 2: Train a BPE tokenizer with HuggingFace tokenizers
62
  # ────────────────────────────────────────────────
63
 
64
  def train_bpe(self, text_iterator: Iterator[str], save_dir: Optional[str] = None):
65
- """Trains a BPE tokenizer from scratch.
66

67
  Args:
68
- text_iterator: Iterator that yields training text strings
69
- save_dir: Directory path to save the trained tokenizer
70

71
- Key insights:
72
- - Larger vocab_size: common expressions become 1 token → shorter sequences
73
- - Smaller vocab_size: saves embedding parameters, but sequences get longer
74
- - 32K is a good balance point for English
75
  """
76
  from tokenizers import Tokenizer as HFTokenizer
77
  from tokenizers.models import BPE
@@ -79,27 +79,27 @@ class Tokenizer:
79
  from tokenizers.pre_tokenizers import ByteLevel
80
  from tokenizers.processors import TemplateProcessing
81
 
82
- print("[Tokenizer] Starting BPE tokenizer training...")
83
 
84
- # Create BPE model
85
  tokenizer = HFTokenizer(BPE(unk_token="<unk>"))
86
  tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
87
 
88
- # Define special tokens
89
  special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
90
 
91
- # Configure trainer
92
  trainer = BpeTrainer(
93
  vocab_size=self.config.vocab_size,
94
  special_tokens=special_tokens,
95
- min_frequency=2, # Only merge pairs that appear at least twice
96
  show_progress=True,
97
  )
98
 
99
- # Run training
100
  tokenizer.train_from_iterator(text_iterator, trainer=trainer)
101
 
102
- # Post-processing: automatically add BOS/EOS
103
  tokenizer.post_processor = TemplateProcessing(
104
  single="<s> $A </s>",
105
  special_tokens=[("<s>", 1), ("</s>", 2)],
@@ -114,11 +114,11 @@ class Tokenizer:
114
  self._encode_fn = lambda text: tokenizer.encode(text).ids
115
  self._decode_fn = lambda ids: tokenizer.decode(ids)
116
 
117
- # Save
118
  save_dir = save_dir or self.config.tokenizer_save_dir
119
  os.makedirs(save_dir, exist_ok=True)
120
  tokenizer.save(os.path.join(save_dir, "tokenizer.json"))
121
- # Save metadata
122
  meta = {
123
  "vocab_size": self.vocab_size,
124
  "bos_id": self.bos_id,
@@ -128,23 +128,23 @@ class Tokenizer:
128
  with open(os.path.join(save_dir, "tokenizer_meta.json"), "w") as f:
129
  json.dump(meta, f, indent=2)
130
 
131
- print(f"[Tokenizer] Training complete: vocab_size={self.vocab_size}")
132
- print(f"[Tokenizer] Saved to: {save_dir}")
133
 
134
  # ────────────────────────────────────────────────
135
- # Method 3: Load a pretrained HF tokenizer
136
  # ────────────────────────────────────────────────
137
 
138
  def load_pretrained_hf(self, name_or_path: str = "meta-llama/Llama-2-7b-hf"):
139
- """Loads a pretrained tokenizer from HuggingFace.
140

141
- The simplest method. The LLaMA tokenizer has a 32K vocab and is BPE-based.
142
- Note: meta-llama models may require HF approval to access.
143
- Alternative: mistralai/Mistral-7B-v0.1 (no approval required)
144
  """
145
  from transformers import AutoTokenizer
146
 
147
- print(f"[Tokenizer] Loading HF tokenizer: {name_or_path}")
148
  tokenizer = AutoTokenizer.from_pretrained(name_or_path)
149
 
150
  self._tokenizer = tokenizer
@@ -156,10 +156,10 @@ class Tokenizer:
156
  self._encode_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
157
  self._decode_fn = lambda ids: tokenizer.decode(ids)
158
 
159
- print(f"[Tokenizer] Loaded: vocab_size={self.vocab_size}")
160
 
161
  def load_trained_hf(self, path: str):
162
- """Reloads a tokenizer previously trained with train_bpe()."""
163
  from tokenizers import Tokenizer as HFTokenizer
164
 
165
  tokenizer = HFTokenizer.from_file(os.path.join(path, "tokenizer.json"))
@@ -175,21 +175,21 @@ class Tokenizer:
175
  self._encode_fn = lambda text: tokenizer.encode(text).ids
176
  self._decode_fn = lambda ids: tokenizer.decode(ids)
177
 
178
- print(f"[Tokenizer] Loaded: vocab_size={self.vocab_size}")
179
 
180
  # ────────────────────────────────────────────────
181
- # Common interface
182
  # ────────────────────────────────────────────────
183
 
184
  def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
185
- """Text → list of token IDs."""
186
  ids = self._encode_fn(text)
187
  if add_special_tokens:
188
  ids = [self.bos_id] + ids + [self.eos_id]
189
  return ids
190
 
191
  def decode(self, ids: List[int]) -> str:
192
- """List of token IDs → text."""
193
  return self._decode_fn(ids)
194
 
195
  def __len__(self) -> int:
 
1
+ """Tokenizer wrapper — SentencePiece / HuggingFace BPE integration."""
2
 
3
  import os
4
  import json
 
8
 
9
 
10
  class Tokenizer:
11
+ """Unified tokenizer wrapper.
12
+
13
+ Supports three methods:
14
+ 1) Load an existing SentencePiece model
15
+ 2) Train a new tokenizer using the HuggingFace tokenizers library
16
+ 3) Load a pretrained HF tokenizer (e.g., LLaMA tokenizer)
17
+
18
+ Why not implement from scratch?
19
+ - Training a BPE tokenizer involves large-scale text statistics processing,
20
+ which has little direct relevance to understanding model architecture.
21
+ - However, understanding how a tokenizer works (BPE merge rules) is still important.
22
+
23
+ BPE (Byte Pair Encoding) core principle:
24
+ 1) Split text into byte/character units
25
+ 2) Repeatedly merge the most frequent adjacent pair
26
+ 3) Repeat until vocab_size is reached
27
+ Frequent words become a single token; rare words are split into multiple tokens
28
  """
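One merge step of the BPE core principle above can be sketched as follows (principle only; a real trainer also records the merge table and repeats until vocab_size is reached):

```python
from collections import Counter

# Find the most frequent adjacent symbol pair in the current token sequence.
def most_frequent_pair(tokens):
    return Counter(zip(tokens, tokens[1:])).most_common(1)[0][0]

# Replace every occurrence of `pair` with a single fused symbol.
def merge(tokens, pair):
    out, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            out.append(tokens[i] + tokens[i + 1])  # fuse into one symbol
            i += 2
        else:
            out.append(tokens[i])
            i += 1
    return out

tokens = list("abab ab")
tokens = merge(tokens, most_frequent_pair(tokens))
```

Starting from characters, the frequent pair ("a", "b") is fused first, which is how common substrings end up as single vocabulary entries.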
29
 
30
  def __init__(self, config: DataConfig):
 
32
  self._tokenizer = None
33
  self.vocab_size = config.vocab_size
34
 
35
+ # Special token IDs (set after initialization)
36
  self.bos_id: int = 1 # Beginning of Sequence
37
  self.eos_id: int = 2 # End of Sequence
38
  self.pad_id: int = 0 # Padding
39
 
40
  # ────────────────────────────────────────────────
41
+ # Method 1: Load a SentencePiece model
42
  # ────────────────────────────────────────────────
43
 
44
  def load_sentencepiece(self, model_path: str):
45
+ """Loads an existing SentencePiece model."""
46
  import sentencepiece as spm
47
 
48
  self._tokenizer = spm.SentencePieceProcessor()
 
55
  self._encode_fn = self._tokenizer.Encode
56
  self._decode_fn = self._tokenizer.Decode
57
 
58
+ print(f"[Tokenizer] SentencePiece loaded: vocab_size={self.vocab_size}")
59
 
60
  # ────────────────────────────────────────────────
61
+ # Method 2: Train a BPE tokenizer with HuggingFace tokenizers
62
  # ────────────────────────────────────────────────
63
 
64
  def train_bpe(self, text_iterator: Iterator[str], save_dir: Optional[str] = None):
65
+ """Trains a BPE tokenizer from scratch.
66
 
67
  Args:
68
+ text_iterator: Iterator that yields training text strings
69
+ save_dir: Directory path to save the trained tokenizer
70
 
71
+ Key insights:
72
+ - Larger vocab_size: common expressions become 1 token → shorter sequences
73
+ - Smaller vocab_size: saves embedding parameters, but sequences get longer
74
+ - 32K is a good balance point for English
75
  """
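A back-of-envelope for the embedding-parameter side of this trade-off (d_model=2048 is an assumed value for illustration, not the project's config):

```python
# Embedding parameter count grows linearly with vocab_size.
def embedding_params(vocab_size, d_model, tied=True):
    n = vocab_size * d_model          # input embedding matrix
    return n if tied else 2 * n       # untied models add an output projection

small = embedding_params(32_000, 2048)    # 32K vocab
large = embedding_params(128_000, 2048)   # 4x larger vocab
```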
76
  from tokenizers import Tokenizer as HFTokenizer
77
  from tokenizers.models import BPE
 
79
  from tokenizers.pre_tokenizers import ByteLevel
80
  from tokenizers.processors import TemplateProcessing
81
 
82
+ print("[Tokenizer] Starting BPE tokenizer training...")
83
 
84
+ # Create BPE model
85
  tokenizer = HFTokenizer(BPE(unk_token="<unk>"))
86
  tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
87
 
88
+ # Define special tokens
89
  special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
90
 
91
+ # Configure trainer
92
  trainer = BpeTrainer(
93
  vocab_size=self.config.vocab_size,
94
  special_tokens=special_tokens,
95
+ min_frequency=2, # Only merge pairs that appear at least twice
96
  show_progress=True,
97
  )
98
 
99
+ # Run training
100
  tokenizer.train_from_iterator(text_iterator, trainer=trainer)
101
 
102
+ # Post-processing: automatically add BOS/EOS
103
  tokenizer.post_processor = TemplateProcessing(
104
  single="<s> $A </s>",
105
  special_tokens=[("<s>", 1), ("</s>", 2)],
 
114
  self._encode_fn = lambda text: tokenizer.encode(text).ids
115
  self._decode_fn = lambda ids: tokenizer.decode(ids)
116
 
117
+ # Save
118
  save_dir = save_dir or self.config.tokenizer_save_dir
119
  os.makedirs(save_dir, exist_ok=True)
120
  tokenizer.save(os.path.join(save_dir, "tokenizer.json"))
121
+ # Save metadata
122
  meta = {
123
  "vocab_size": self.vocab_size,
124
  "bos_id": self.bos_id,
 
128
  with open(os.path.join(save_dir, "tokenizer_meta.json"), "w") as f:
129
  json.dump(meta, f, indent=2)
130
 
131
+ print(f"[Tokenizer] Training complete: vocab_size={self.vocab_size}")
132
+ print(f"[Tokenizer] Saved to: {save_dir}")
133
 
134
  # ────────────────────────────────────────────────
135
+ # Method 3: Load a pretrained HF tokenizer
136
  # ────────────────────────────────────────────────
137
 
138
  def load_pretrained_hf(self, name_or_path: str = "meta-llama/Llama-2-7b-hf"):
139
+ """Loads a pretrained tokenizer from HuggingFace.
140
 
141
+ The simplest method. The LLaMA tokenizer has a 32K vocab and is BPE-based.
142
+ Note: meta-llama models may require HF approval to access.
143
+ Alternative: mistralai/Mistral-7B-v0.1 (no approval required)
144
  """
145
  from transformers import AutoTokenizer
146
 
147
+ print(f"[Tokenizer] Loading HF tokenizer: {name_or_path}")
148
  tokenizer = AutoTokenizer.from_pretrained(name_or_path)
149
 
150
  self._tokenizer = tokenizer
 
156
  self._encode_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
157
  self._decode_fn = lambda ids: tokenizer.decode(ids)
158
 
159
+ print(f"[Tokenizer] Loaded: vocab_size={self.vocab_size}")
160
 
161
  def load_trained_hf(self, path: str):
162
+ """Reloads a tokenizer previously trained with train_bpe()."""
163
  from tokenizers import Tokenizer as HFTokenizer
164
 
165
  tokenizer = HFTokenizer.from_file(os.path.join(path, "tokenizer.json"))
 
175
  self._encode_fn = lambda text: tokenizer.encode(text).ids
176
  self._decode_fn = lambda ids: tokenizer.decode(ids)
177
 
178
+ print(f"[Tokenizer] Loaded: vocab_size={self.vocab_size}")
179
 
180
  # ────────────────────────────────────────────────
181
+ # Common interface
182
  # ────────────────────────────────────────────────
183
 
184
  def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
185
+ """Text → list of token IDs."""
186
  ids = self._encode_fn(text)
187
  if add_special_tokens:
188
  ids = [self.bos_id] + ids + [self.eos_id]
189
  return ids
190
 
191
  def decode(self, ids: List[int]) -> str:
192
+ """List of token IDs → text."""
193
  return self._decode_fn(ids)
194
 
195
  def __len__(self) -> int:
llm_lab/evaluation/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- """평가 모듈 — Perplexity, 텍스트 생성, Scaling Law, Attention 시각화."""
2
 
3
  from .perplexity import PerplexityEvaluator
4
  from .generation import GenerationEvaluator
 
1
+ """Evaluation module — Perplexity, text generation, Scaling Law, Attention visualization."""
2
 
3
  from .perplexity import PerplexityEvaluator
4
  from .generation import GenerationEvaluator
llm_lab/evaluation/attention_viz.py CHANGED
@@ -1,4 +1,4 @@
1
- """Attention 패턴 시각화."""
2
 
3
  import math
4
  from pathlib import Path
@@ -18,15 +18,15 @@ except ImportError:
18
 
19
 
20
  class AttentionVisualizer:
21
- """Attention 패턴을 시각화합니다.
22
 
23
- 학습 포인트:
24
- - Causal Mask: 하삼각 패턴 (미래 토큰은 없음)
25
- - 헤드별 역할 분화: 일부는 로컬(인접), 일부는 글로벌( 토큰) 주목
26
- - 구문론적 패턴: 동사주어, 대명사선행사 등에 높은 attention
27
 
28
- 주의: 1B 모델의 전체 attention 저장하면 메모리 부족!
29
- 특정 레이어/헤드만 선택적으로 시각화합니다.
30
  """
31
 
32
  def __init__(self, save_dir: str = "./eval_results"):
@@ -41,10 +41,10 @@ class AttentionVisualizer:
41
  layer_idx: int = 0,
42
  device: torch.device = torch.device("cpu"),
43
  ) -> torch.Tensor:
44
- """특정 레이어의 attention weight를 추출합니다.
45
 
46
- 모델의 attention 모듈을 일시적으로 수정하여
47
- attention weight를 캡처합니다.
48
 
49
  Returns:
50
  attention_weights: (num_heads, seq_len, seq_len)
@@ -52,10 +52,10 @@ class AttentionVisualizer:
52
  model.eval()
53
  captured_attn = {}
54
 
55
- # Hook으로 attention weight 캡처
56
  target_layer = model.layers[layer_idx].attention
57
 
58
- # scaled_dot_product_attention 수동 구현으로 대체
59
  original_forward = target_layer.forward
60
 
61
  def hooked_forward(x, mask=None, position_offset=0):
@@ -72,7 +72,7 @@ class AttentionVisualizer:
72
  k = target_layer._repeat_kv(k)
73
  v = target_layer._repeat_kv(v)
74
 
75
- # 수동 attention 계산 (weight 추출용)
76
  scale = 1.0 / math.sqrt(hd)
77
  scores = torch.matmul(q, k.transpose(-2, -1)) * scale
78
 
@@ -81,13 +81,13 @@ class AttentionVisualizer:
81
  scores.masked_fill_(causal.unsqueeze(0).unsqueeze(0), float("-inf"))
82
 
83
  attn_weights = F.softmax(scores, dim=-1)
84
- captured_attn["weights"] = attn_weights[0].cpu() # 배치만
85
 
86
  out = torch.matmul(attn_weights, v)
87
  out = out.transpose(1, 2).contiguous().view(B, S, -1)
88
  return target_layer.o_proj(out)
89
 
90
- # Hook 적용
91
  target_layer.forward = hooked_forward
92
 
93
  try:
@@ -105,13 +105,13 @@ class AttentionVisualizer:
105
  save_path: Optional[str] = None,
106
  title: str = "Attention Weights",
107
  ):
108
- """Attention heatmap을 그립니다."""
109
  if not HAS_MATPLOTLIB:
110
- print("⚠️ matplotlib 필요합니다")
111
  return
112
 
113
  weights = attn_weights[head_idx].numpy()
114
- max_len = min(len(tokens), 50) # 최대 50 토큰만 표시
115
  weights = weights[:max_len, :max_len]
116
  display_tokens = tokens[:max_len]
117
 
@@ -132,7 +132,7 @@ class AttentionVisualizer:
132
 
133
  save_path = save_path or str(self.save_dir / f"attention_head{head_idx}.png")
134
  fig.savefig(save_path, dpi=150, bbox_inches="tight")
135
- print(f" 📊 Attention 시각화 저장: {save_path}")
136
  plt.close(fig)
137
 
138
  def plot_multi_head_summary(
@@ -141,7 +141,7 @@ class AttentionVisualizer:
141
  num_heads_to_show: int = 8,
142
  save_path: Optional[str] = None,
143
  ):
144
- """여러 헤드의 attention 패턴을 요약 비교합니다."""
145
  if not HAS_MATPLOTLIB:
146
  return
147
 
@@ -162,7 +162,7 @@ class AttentionVisualizer:
162
  ax.set_xticks([])
163
  ax.set_yticks([])
164
 
165
- # subplot 숨기기
166
  for idx in range(n_heads, rows * cols):
167
  r, c = idx // cols, idx % cols
168
  axes[r, c].axis("off")
@@ -172,5 +172,5 @@ class AttentionVisualizer:
172
 
173
  save_path = save_path or str(self.save_dir / "attention_multi_head.png")
174
  fig.savefig(save_path, dpi=150, bbox_inches="tight")
175
- print(f" 📊 멀티 헤드 요약 저장: {save_path}")
176
  plt.close(fig)
 
1
+ """Attention pattern visualization."""
2
 
3
  import math
4
  from pathlib import Path
 
18
 
19
 
20
  class AttentionVisualizer:
21
+ """Visualizes attention patterns.
22
 
23
+ Learning insights:
24
+ - Causal Mask: lower-triangular pattern (future tokens cannot be attended to)
25
+ - Head specialization: some heads focus locally (adjacent), others globally (distant tokens)
26
+ - Syntactic patterns: high attention on verb→subject, pronoun→antecedent, etc.
27
 
28
+ Note: Storing the full attention of a 1B model will run out of memory!
29
+ Visualize only selected layers/heads.
30
  """
31
 
32
  def __init__(self, save_dir: str = "./eval_results"):
 
41
  layer_idx: int = 0,
42
  device: torch.device = torch.device("cpu"),
43
  ) -> torch.Tensor:
44
+ """Extracts attention weights from a specific layer.
45
 
46
+ Temporarily modifies the model's attention module to
47
+ capture attention weights.
48
 
49
  Returns:
50
  attention_weights: (num_heads, seq_len, seq_len)
 
52
  model.eval()
53
  captured_attn = {}
54
 
55
+ # Capture attention weights via hook
56
  target_layer = model.layers[layer_idx].attention
57
 
58
+ # Replace scaled_dot_product_attention with a manual implementation
59
  original_forward = target_layer.forward
60
 
61
  def hooked_forward(x, mask=None, position_offset=0):
 
72
  k = target_layer._repeat_kv(k)
73
  v = target_layer._repeat_kv(v)
74
 
75
+ # Manual attention computation (for weight extraction)
76
  scale = 1.0 / math.sqrt(hd)
77
  scores = torch.matmul(q, k.transpose(-2, -1)) * scale
78
 
 
81
  scores.masked_fill_(causal.unsqueeze(0).unsqueeze(0), float("-inf"))
82
 
83
  attn_weights = F.softmax(scores, dim=-1)
84
+ captured_attn["weights"] = attn_weights[0].cpu() # first batch only
85
 
86
  out = torch.matmul(attn_weights, v)
87
  out = out.transpose(1, 2).contiguous().view(B, S, -1)
88
  return target_layer.o_proj(out)
89
 
90
+ # Apply hook
91
  target_layer.forward = hooked_forward
92
 
93
  try:
 
105
  save_path: Optional[str] = None,
106
  title: str = "Attention Weights",
107
  ):
108
+ """Draws an attention heatmap."""
109
  if not HAS_MATPLOTLIB:
110
+ print("⚠️ matplotlib required")
111
  return
112
 
113
  weights = attn_weights[head_idx].numpy()
114
+ max_len = min(len(tokens), 50) # display at most 50 tokens
115
  weights = weights[:max_len, :max_len]
116
  display_tokens = tokens[:max_len]
117
 
 
132
 
133
  save_path = save_path or str(self.save_dir / f"attention_head{head_idx}.png")
134
  fig.savefig(save_path, dpi=150, bbox_inches="tight")
135
+ print(f" 📊 Attention visualization saved: {save_path}")
136
  plt.close(fig)
137
 
138
  def plot_multi_head_summary(
 
141
  num_heads_to_show: int = 8,
142
  save_path: Optional[str] = None,
143
  ):
144
+ """Summarizes and compares attention patterns across multiple heads."""
145
  if not HAS_MATPLOTLIB:
146
  return
147
 
 
162
  ax.set_xticks([])
163
  ax.set_yticks([])
164
 
165
+ # Hide empty subplots
166
  for idx in range(n_heads, rows * cols):
167
  r, c = idx // cols, idx % cols
168
  axes[r, c].axis("off")
 
172
 
173
  save_path = save_path or str(self.save_dir / "attention_multi_head.png")
174
  fig.savefig(save_path, dpi=150, bbox_inches="tight")
175
+ print(f" 📊 Multi-head summary saved: {save_path}")
176
  plt.close(fig)
llm_lab/evaluation/checklist.py CHANGED
@@ -1,13 +1,13 @@
1
- """학습 인사이트 체크리스트 검증기."""
2
 
3
  from typing import Any, Dict, Optional
4
 
5
 
6
  class InsightChecklist:
7
- """PRD에 정의된 학습 인사이트 체크리스트를 자동/수동으로 검증합니다.
8
 
9
- 자동 검증 가능 항목은 메트릭 기반으로 판정하고,
10
- 수동 항목은 질문으로 제시합니다.
11
  """
12
 
13
  @staticmethod
@@ -15,9 +15,9 @@ class InsightChecklist:
15
  report: Dict[str, Any],
16
  metrics_history: Optional[Dict[str, list]] = None,
17
  ):
18
- """체크리스트를 실행합니다."""
19
  print("\n" + "=" * 70)
20
- print("✅ 학습 인사이트 체크리스트")
21
  print("=" * 70)
22
 
23
  checks = {
@@ -26,74 +26,74 @@ class InsightChecklist:
26
  "manual": [],
27
  }
28
 
29
- # ── 자동 검증 ──
30
 
31
- # 1. Loss 수렴
32
  if report.get("perplexity", {}).get("loss", 99) < 4.0:
33
- checks["passed"].append("모델 Loss 4.0 이하로 수렴")
34
  else:
35
- checks["failed"].append("모델 Loss 4.0 이하로 미수렴")
36
 
37
- # 2. Loss 스파이크
38
  spikes = report.get("training_dynamics", {}).get("loss", {}).get("spikes", [])
39
  if len(spikes) < 5:
40
- checks["passed"].append(f"Loss 스파이크 {len(spikes)} (< 5)")
41
  else:
42
- checks["failed"].append(f"Loss 스파이크 {len(spikes)} ( 5, 안정성 개선 필요)")
43
 
44
- # 3. 위치별 Loss 패턴
45
  if report.get("position_losses"):
46
  early = report["position_losses"]["early_avg"]
47
  late = report["position_losses"]["late_avg"]
48
  if early > late:
49
- checks["passed"].append("위치별 Loss 감소 패턴 확인 (컨텍스트 활용)")
50
  else:
51
- checks["failed"].append("위치별 Loss 패턴 이상 (컨텍스트 미활용?)")
52
 
53
- # 4. 생성 반복률
54
  rep = report.get("generation", {}).get("avg_metrics", {}).get("repetition_rate", 1.0)
55
  if rep < 0.3:
56
- checks["passed"].append(f"생성 반복률 {rep:.1%} (< 30%)")
57
  else:
58
- checks["failed"].append(f"생성 반복률 {rep:.1%} ( 30%, temperature/top_p 조정)")
59
 
60
- # 5. Gradient 클리핑 비율
61
  if metrics_history and metrics_history.get("grad_norm"):
62
  gnorms = metrics_history["grad_norm"]
63
  clip_rate = sum(1 for g in gnorms if g >= 0.99) / max(len(gnorms), 1)
64
  if clip_rate < 0.3:
65
- checks["passed"].append(f"Gradient 클리핑 비율 {clip_rate:.1%} (건강)")
66
  else:
67
- checks["failed"].append(f"Gradient 클리핑 비율 {clip_rate:.1%} (너무 잦음)")
68
 
69
- # ── 수동 확인 항목 ──
70
  manual_items = [
71
- "Self-Attention에서 Q, K, V 각각의 역할을 설명할 수 있는가?",
72
- "RoPE가 위치 정보를 인코딩하는 수학적 원리를 이해하는가?",
73
- "GQA가 MHA 대비 메모리를 절약하는 메커니즘을 설명할 있는가?",
74
- "SwiGLU 게이팅 메커니즘이 ReLU FFN과 어떻게 다른지 이해하는가?",
75
- "Learning Rate Warmup 필요한지 체감했는가?",
76
- "Gradient Accumulation 배치를 시뮬레이션하는 원리를 이해하는가?",
77
- "Mixed Precision(bf16)의 메모리-속도 효과를 측정했는가?",
78
- "Activation Checkpointing의 메모리-연산 트레이드오프를 이해하는가?",
79
  ]
80
  checks["manual"] = manual_items
81
 
82
- # ── 출력 ──
83
  total_auto = len(checks["passed"]) + len(checks["failed"])
84
  passed_auto = len(checks["passed"])
85
 
86
- print(f"\n 자동 검증: {passed_auto}/{total_auto} 통과")
87
  for item in checks["passed"]:
88
  print(f" ✅ {item}")
89
  for item in checks["failed"]:
90
  print(f" ❌ {item}")
91
 
92
- print(f"\n 수동 확인 ({len(manual_items)} 항목):")
93
  for i, item in enumerate(manual_items, 1):
94
  print(f" {i}. [ ] {item}")
95
 
96
- print(f"\n 진행률: {passed_auto}/{total_auto + len(manual_items)} "
97
- f"(수동 항목 포함 시)")
98
 
99
  return checks
 
1
+ """Training insight checklist validator."""
2
 
3
  from typing import Any, Dict, Optional
4
 
5
 
6
  class InsightChecklist:
7
+ """Automatically and manually validates the training insight checklist defined in the PRD.
8
 
9
+ Items that can be automatically validated are judged based on metrics,
10
+ while manual items are presented as questions.
11
  """
12
 
13
  @staticmethod
 
15
  report: Dict[str, Any],
16
  metrics_history: Optional[Dict[str, list]] = None,
17
  ):
18
+ """Runs the checklist."""
19
  print("\n" + "=" * 70)
20
+ print("✅ Training Insight Checklist")
21
  print("=" * 70)
22
 
23
  checks = {
 
26
  "manual": [],
27
  }
28
 
29
+ # ── Automatic validation ──
30
 
31
+ # 1. Loss convergence
32
  if report.get("perplexity", {}).get("loss", 99) < 4.0:
33
+ checks["passed"].append("Model Loss converged below 4.0")
34
  else:
35
+ checks["failed"].append("Model Loss has not converged below 4.0")
36
 
37
+ # 2. Loss spikes
38
  spikes = report.get("training_dynamics", {}).get("loss", {}).get("spikes", [])
39
  if len(spikes) < 5:
40
+ checks["passed"].append(f"Loss spikes: {len(spikes)} (< 5)")
41
  else:
42
+ checks["failed"].append(f"Loss spikes: {len(spikes)} (>= 5, stability improvement needed)")
43
 
44
+ # 3. Per-position loss pattern
45
  if report.get("position_losses"):
46
  early = report["position_losses"]["early_avg"]
47
  late = report["position_losses"]["late_avg"]
48
  if early > late:
49
+ checks["passed"].append("Per-position loss decrease pattern confirmed (context utilization)")
50
  else:
51
+ checks["failed"].append("Per-position loss pattern abnormal (context not utilized?)")
52
 
53
+ # 4. Generation repetition rate
54
  rep = report.get("generation", {}).get("avg_metrics", {}).get("repetition_rate", 1.0)
55
  if rep < 0.3:
56
+ checks["passed"].append(f"Generation repetition rate {rep:.1%} (< 30%)")
57
  else:
58
+ checks["failed"].append(f"Generation repetition rate {rep:.1%} (>= 30%, adjust temperature/top_p)")
59
 
60
+ # 5. Gradient clipping rate
61
  if metrics_history and metrics_history.get("grad_norm"):
62
  gnorms = metrics_history["grad_norm"]
63
  clip_rate = sum(1 for g in gnorms if g >= 0.99) / max(len(gnorms), 1)
64
  if clip_rate < 0.3:
65
+ checks["passed"].append(f"Gradient clipping rate {clip_rate:.1%} (healthy)")
66
  else:
67
+ checks["failed"].append(f"Gradient clipping rate {clip_rate:.1%} (too frequent)")
68
 
69
+ # ── Manual verification items ──
70
  manual_items = [
71
+ "Can you explain the individual roles of Q, K, and V in Self-Attention?",
72
+ "Do you understand the mathematical principle by which RoPE encodes positional information?",
73
+ "Can you explain the mechanism by which GQA saves memory compared to MHA?",
74
+ "Do you understand how SwiGLU's gating mechanism differs from a ReLU FFN?",
75
+ "Did you experience firsthand why Learning Rate Warmup is necessary?",
76
+ "Do you understand the principle by which Gradient Accumulation simulates a large batch?",
77
+ "Have you measured the memory-speed effect of Mixed Precision (bf16)?",
78
+ "Do you understand the memory-compute trade-off of Activation Checkpointing?",
79
  ]
80
  checks["manual"] = manual_items
81
 
82
+ # ── Output ──
83
  total_auto = len(checks["passed"]) + len(checks["failed"])
84
  passed_auto = len(checks["passed"])
85
 
86
+ print(f"\n Automatic validation: {passed_auto}/{total_auto} passed")
87
  for item in checks["passed"]:
88
  print(f" ✅ {item}")
89
  for item in checks["failed"]:
90
  print(f" ❌ {item}")
91
 
92
+ print(f"\n Manual verification ({len(manual_items)} items):")
93
  for i, item in enumerate(manual_items, 1):
94
  print(f" {i}. [ ] {item}")
95
 
96
+ print(f"\n Total progress: {passed_auto}/{total_auto + len(manual_items)} "
97
+ f"(including manual items)")
98
 
99
  return checks
llm_lab/evaluation/dynamics.py CHANGED
@@ -1,4 +1,4 @@
1
- """학습 역학 분석기."""
2
 
3
  import math
4
  from pathlib import Path
@@ -14,13 +14,13 @@ except ImportError:
14
 
15
 
16
  class TrainingDynamicsAnalyzer:
17
- """학습 과정의 메트릭을 분석하고 시각화합니다.
18
 
19
- 분석 항목:
20
- - Loss 곡선: 수렴 패턴, 스파이크 감지
21
- - LR 스케줄: Warmup + Cosine decay 확인
22
- - Gradient Norm: 학습 안정성, 폭발/소멸 감지
23
- - 처리량: tokens/sec 안정성, 병목 감지
24
  """
25
 
26
  def __init__(self, save_dir: str = "./eval_results"):
@@ -28,21 +28,21 @@ class TrainingDynamicsAnalyzer:
28
  self.save_dir.mkdir(parents=True, exist_ok=True)
29
 
30
  def analyze_metrics(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
31
- """학습 메트릭을 분석합니다.
32
 
33
  Args:
34
- metrics_history: Trainer.metrics.history 딕셔너리
35
 
36
  Returns:
37
- 분석 결과
38
  """
39
  print("\n" + "=" * 70)
40
- print("🔬 학습 역학 분석")
41
  print("=" * 70)
42
 
43
  analysis = {}
44
 
45
- # ── Loss 분석 ──
46
  if metrics_history.get("train_loss"):
47
  losses = metrics_history["train_loss"]
48
  analysis["loss"] = {
@@ -52,7 +52,7 @@ class TrainingDynamicsAnalyzer:
52
  "total_reduction": round(losses[0] - losses[-1], 4),
53
  }
54
 
55
- # 스파이크 감지 (이전 대비 50% 이상 급증)
56
  spikes = []
57
  for i in range(1, len(losses)):
58
  if losses[i] > losses[i-1] * 1.5:
@@ -61,17 +61,17 @@ class TrainingDynamicsAnalyzer:
61
 
62
  analysis["loss"]["spikes"] = spikes
63
 
64
- print(f"\n 📉 Loss 분석:")
65
- print(f" 초기: {analysis['loss']['initial']:.4f}")
66
- print(f" 최종: {analysis['loss']['final']:.4f}")
67
- print(f" 최소: {analysis['loss']['minimum']:.4f}")
68
- print(f" 감소: {analysis['loss']['total_reduction']:.4f}")
69
- print(f" 스파이크: {len(spikes)}")
70
  if spikes:
71
  for s in spikes[:5]:
72
  print(f" Step {s['step']}: Loss = {s['loss']}")
73
 
74
- # ── Gradient Norm 분석 ──
75
  if metrics_history.get("grad_norm"):
76
  gnorms = metrics_history["grad_norm"]
77
  analysis["grad_norm"] = {
@@ -81,14 +81,14 @@ class TrainingDynamicsAnalyzer:
81
  "clipped_pct": round(sum(1 for g in gnorms if g >= 0.99) / len(gnorms) * 100, 1),
82
  }
83
 
84
- print(f"\n 📐 Gradient Norm 분석:")
85
- print(f" 평균: {analysis['grad_norm']['mean']:.4f}")
86
- print(f" 최대: {analysis['grad_norm']['max']:.4f}")
87
- print(f" 클리핑 비율: {analysis['grad_norm']['clipped_pct']:.1f}%")
88
  if analysis["grad_norm"]["clipped_pct"] > 30:
89
- print(f" ⚠️ 클리핑이 잦음LR 하향 또는 warmup 연장 고려")
90
 
91
- # ── 처리량 분석 ──
92
  if metrics_history.get("tokens_per_sec"):
93
  tps = metrics_history["tokens_per_sec"]
94
  tps_valid = [t for t in tps if t > 0]
@@ -100,10 +100,10 @@ class TrainingDynamicsAnalyzer:
100
  "max": round(max(tps_valid)),
101
  }
102
 
103
- print(f"\n ⚡ 처리량 분석:")
104
- print(f" 평균: {analysis['throughput']['mean']:,} tokens/sec")
105
- print(f" 표준편차: {analysis['throughput']['std']:,}")
106
- print(f" 범위: [{analysis['throughput']['min']:,}, {analysis['throughput']['max']:,}]")
107
 
108
  return analysis
109
 
@@ -112,9 +112,9 @@ class TrainingDynamicsAnalyzer:
112
  metrics_history: Dict[str, list],
113
  save_path: Optional[str] = None,
114
  ):
115
- """학습 곡선을 4-panel 차트로 시각화합니다."""
116
  if not HAS_MATPLOTLIB:
117
- print("⚠️ matplotlib 필요합니다: pip install matplotlib")
118
  return
119
 
120
  fig, axes = plt.subplots(2, 2, figsize=(16, 10))
@@ -129,7 +129,7 @@ class TrainingDynamicsAnalyzer:
129
  metrics_history["train_loss"],
130
  color="#2563eb", alpha=0.6, linewidth=0.8, label="Train Loss")
131
 
132
- # 이동 평균 (스무딩)
133
  if len(metrics_history["train_loss"]) > 20:
134
  window = min(50, len(metrics_history["train_loss"]) // 5)
135
  smoothed = self._moving_average(metrics_history["train_loss"], window)
@@ -192,7 +192,7 @@ class TrainingDynamicsAnalyzer:
192
 
193
  save_path = save_path or str(self.save_dir / "training_curves.png")
194
  fig.savefig(save_path, dpi=150, bbox_inches="tight")
195
- print(f"\n 📊 학습 곡선 저장: {save_path}")
196
  plt.close(fig)
197
 
198
  def plot_position_loss(
@@ -200,7 +200,7 @@ class TrainingDynamicsAnalyzer:
200
  position_losses: List[float],
201
  save_path: Optional[str] = None,
202
  ):
203
- """위치별 Loss 분포를 시각화합니다."""
204
  if not HAS_MATPLOTLIB:
205
  return
206
 
@@ -215,7 +215,7 @@ class TrainingDynamicsAnalyzer:
215
  ax.set_title("Loss by Position (earlier positions have less context)", fontsize=13, fontweight="bold")
216
  ax.grid(True, alpha=0.3)
217
 
218
- # 주요 구간 표시
219
  if len(position_losses) > 100:
220
  early_avg = sum(position_losses[:50]) / 50
221
  late_avg = sum(position_losses[-200:]) / 200
@@ -229,12 +229,12 @@ class TrainingDynamicsAnalyzer:
229
 
230
  save_path = save_path or str(self.save_dir / "position_loss.png")
231
  fig.savefig(save_path, dpi=150, bbox_inches="tight")
232
- print(f" 📊 위치별 Loss 저장: {save_path}")
233
  plt.close(fig)
234
 
235
  @staticmethod
236
  def _moving_average(data: list, window: int) -> list:
237
- """이동 평균 계산."""
238
  result = []
239
  for i in range(window - 1, len(data)):
240
  avg = sum(data[i - window + 1 : i + 1]) / window
 
1
+ """Training dynamics analyzer."""
2
 
3
  import math
4
  from pathlib import Path
 
14
 
15
 
16
  class TrainingDynamicsAnalyzer:
17
+ """Analyzes and visualizes training metrics.
18
 
19
+ Analysis items:
20
+ - Loss curve: Convergence patterns, spike detection
21
+ - LR schedule: Warmup + Cosine decay verification
22
+ - Gradient Norm: Training stability, explosion/vanishing detection
23
+ - Throughput: tokens/sec stability, bottleneck detection
24
  """
25
 
26
  def __init__(self, save_dir: str = "./eval_results"):
 
28
  self.save_dir.mkdir(parents=True, exist_ok=True)
29
 
30
  def analyze_metrics(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
31
+ """Analyzes training metrics.
32
 
33
  Args:
34
+ metrics_history: Trainer.metrics.history dictionary
35
 
36
  Returns:
37
+ Analysis results
38
  """
39
  print("\n" + "=" * 70)
40
+ print("🔬 Training Dynamics Analysis")
41
  print("=" * 70)
42
 
43
  analysis = {}
44
 
45
+ # ── Loss analysis ──
46
  if metrics_history.get("train_loss"):
47
  losses = metrics_history["train_loss"]
48
  analysis["loss"] = {
 
52
  "total_reduction": round(losses[0] - losses[-1], 4),
53
  }
54
 
55
+ # Spike detection (sudden increase of 50% or more compared to previous value)
56
  spikes = []
57
  for i in range(1, len(losses)):
58
  if losses[i] > losses[i-1] * 1.5:
 
61
 
62
  analysis["loss"]["spikes"] = spikes
63
 
64
+ print(f"\n 📉 Loss Analysis:")
65
+ print(f" Initial: {analysis['loss']['initial']:.4f}")
66
+ print(f" Final: {analysis['loss']['final']:.4f}")
67
+ print(f" Minimum: {analysis['loss']['minimum']:.4f}")
68
+ print(f" Reduction: {analysis['loss']['total_reduction']:.4f}")
69
+ print(f" Spikes: {len(spikes)}")
70
  if spikes:
71
  for s in spikes[:5]:
72
  print(f" Step {s['step']}: Loss = {s['loss']}")
73
 
74
+ # ── Gradient Norm analysis ──
75
  if metrics_history.get("grad_norm"):
76
  gnorms = metrics_history["grad_norm"]
77
  analysis["grad_norm"] = {
 
81
  "clipped_pct": round(sum(1 for g in gnorms if g >= 0.99) / len(gnorms) * 100, 1),
82
  }
83
 
84
+ print(f"\n 📐 Gradient Norm Analysis:")
85
+ print(f" Mean: {analysis['grad_norm']['mean']:.4f}")
86
+ print(f" Max: {analysis['grad_norm']['max']:.4f}")
87
+ print(f" Clipping rate: {analysis['grad_norm']['clipped_pct']:.1f}%")
88
  if analysis["grad_norm"]["clipped_pct"] > 30:
89
+ print(f" ⚠️ Clipping is frequent → consider lowering LR or extending warmup")
90
 
91
+ # ── Throughput analysis ──
92
  if metrics_history.get("tokens_per_sec"):
93
  tps = metrics_history["tokens_per_sec"]
94
  tps_valid = [t for t in tps if t > 0]
 
100
  "max": round(max(tps_valid)),
101
  }
102
 
103
+ print(f"\n ⚡ Throughput Analysis:")
104
+ print(f" Mean: {analysis['throughput']['mean']:,} tokens/sec")
105
+ print(f" StdDev: {analysis['throughput']['std']:,}")
106
+ print(f" Range: [{analysis['throughput']['min']:,}, {analysis['throughput']['max']:,}]")
107
 
108
  return analysis
109
 
 
112
  metrics_history: Dict[str, list],
113
  save_path: Optional[str] = None,
114
  ):
115
+ """Visualizes training curves as a 4-panel chart."""
116
  if not HAS_MATPLOTLIB:
117
+ print("⚠️ matplotlib required: pip install matplotlib")
118
  return
119
 
120
  fig, axes = plt.subplots(2, 2, figsize=(16, 10))
 
129
  metrics_history["train_loss"],
130
  color="#2563eb", alpha=0.6, linewidth=0.8, label="Train Loss")
131
 
132
+ # Moving average (smoothing)
133
  if len(metrics_history["train_loss"]) > 20:
134
  window = min(50, len(metrics_history["train_loss"]) // 5)
135
  smoothed = self._moving_average(metrics_history["train_loss"], window)
 
192
 
193
  save_path = save_path or str(self.save_dir / "training_curves.png")
194
  fig.savefig(save_path, dpi=150, bbox_inches="tight")
195
+ print(f"\n 📊 Training curves saved: {save_path}")
196
  plt.close(fig)
197
 
198
  def plot_position_loss(
 
200
  position_losses: List[float],
201
  save_path: Optional[str] = None,
202
  ):
203
+ """Visualizes loss distribution by position."""
204
  if not HAS_MATPLOTLIB:
205
  return
206
 
 
215
  ax.set_title("Loss by Position (earlier positions have less context)", fontsize=13, fontweight="bold")
216
  ax.grid(True, alpha=0.3)
217
 
218
+ # Mark key regions
219
  if len(position_losses) > 100:
220
  early_avg = sum(position_losses[:50]) / 50
221
  late_avg = sum(position_losses[-200:]) / 200
 
229
 
230
  save_path = save_path or str(self.save_dir / "position_loss.png")
231
  fig.savefig(save_path, dpi=150, bbox_inches="tight")
232
+ print(f" 📊 Position loss saved: {save_path}")
233
  plt.close(fig)
234
 
235
  @staticmethod
236
  def _moving_average(data: list, window: int) -> list:
237
+ """Compute moving average."""
238
  result = []
239
  for i in range(window - 1, len(data)):
240
  avg = sum(data[i - window + 1 : i + 1]) / window
llm_lab/evaluation/full_evaluator.py CHANGED
@@ -1,4 +1,4 @@
1
- """종합 평가 실행기."""
2
 
3
  import json
4
  import time
@@ -17,9 +17,9 @@ from .attention_viz import AttentionVisualizer
17
 
18
 
19
  class FullEvaluator:
20
- """모든 평가를 번에 실행하고 리포트를 생성합니다.
21
 
22
- 사용법:
23
  ```python
24
  evaluator = FullEvaluator(model, tokenizer, val_dataloader, device)
25
  report = evaluator.run_full_evaluation()
@@ -48,24 +48,24 @@ class FullEvaluator:
48
  self.save_dir.mkdir(parents=True, exist_ok=True)
49
 
50
  def run_full_evaluation(self) -> Dict[str, Any]:
51
- """전체 평가를 실행합니다."""
52
  report = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")}
53
 
54
  print("\n" + "=" * 70)
55
- print("🔍 종합 평가 시작")
56
  print("=" * 70)
57
 
58
  # ── 1. Perplexity ──
59
  print("\n" + "━" * 40)
60
- print("Phase 1/4: Perplexity 측정")
61
  print("━" * 40)
62
  ppl_evaluator = PerplexityEvaluator(self.config)
63
  report["perplexity"] = ppl_evaluator.evaluate(
64
  self.model, self.val_dataloader, self.device, self.dtype
65
  )
66
 
67
- # 위치별 Loss
68
- print("\n 위치별 Loss 측정 중...")
69
  position_losses = ppl_evaluator.evaluate_per_position(
70
  self.model, self.val_dataloader, self.device, self.dtype
71
  )
@@ -74,13 +74,13 @@ class FullEvaluator:
74
  "late_avg": round(sum(position_losses[-200:]) / max(len(position_losses[-200:]), 1), 4),
75
  }
76
 
77
- # 위치별 Loss 시각화
78
  dynamics = TrainingDynamicsAnalyzer(str(self.save_dir))
79
  dynamics.plot_position_loss(position_losses, str(self.save_dir / "position_loss.png"))
80
 
81
- # ── 2. 텍스트 생성 ──
82
  print("\n" + "━" * 40)
83
- print("Phase 2/4: 텍스트 생성")
84
  print("━" * 40)
85
  gen_evaluator = GenerationEvaluator(self.config)
86
  gen_results = gen_evaluator.generate_samples(
@@ -91,52 +91,52 @@ class FullEvaluator:
91
  "avg_metrics": self._average_gen_metrics(gen_results),
92
  }
93
 
94
- # ── 3. 학습 역학 분석 ──
95
  if self.metrics_history:
96
  print("\n" + "━" * 40)
97
- print("Phase 3/4: 학습 역학 분석")
98
  print("━" * 40)
99
  report["training_dynamics"] = dynamics.analyze_metrics(self.metrics_history)
100
  dynamics.plot_training_curves(self.metrics_history,
101
  str(self.save_dir / "training_curves.png"))
102
  else:
103
- print("\n Phase 3/4: 건너뜀 (metrics_history 없음)")
104
 
105
- # ── 4. Attention 시각화 (샘플) ──
106
  print("\n" + "━" * 40)
107
- print("Phase 4/4: Attention 시각화")
108
  print("━" * 40)
109
  try:
110
  self._visualize_attention_sample()
111
  except Exception as e:
112
- print(f" ⚠️ Attention 시각화 실패: {e}")
113
 
114
- # ── 리포트 저장 ──
115
  report_path = self.save_dir / "eval_report.json"
116
  with open(report_path, "w") as f:
117
  json.dump(report, f, indent=2, default=str)
118
- print(f"\n📋 리포트 저장: {report_path}")
119
 
120
- # ── 요약 출력 ──
121
  self._print_summary(report)
122
 
123
  return report
124
 
125
  def _visualize_attention_sample(self):
126
- """샘플 텍스트로 attention 시각화합니다."""
127
  viz = AttentionVisualizer(str(self.save_dir))
128
 
129
  sample_text = "The cat sat on the mat and looked at the bird."
130
  token_ids = self.tokenizer.encode(sample_text, add_special_tokens=False)
131
  input_tensor = torch.tensor([token_ids], dtype=torch.long)
132
 
133
- # 토큰 문자열 (시각화 라벨용)
134
  tokens_str = []
135
  for tid in token_ids:
136
  decoded = self.tokenizer.decode([tid])
  tokens_str.append(decoded.replace("\n", "\\n"))
 
- # Layer 0 attention 추출
  attn_weights = viz.extract_attention(
  self.model, input_tensor, layer_idx=0, device=self.device
  )
@@ -150,7 +150,7 @@ class FullEvaluator:
 
  @staticmethod
  def _average_gen_metrics(gen_results: List[Dict]) -> Dict[str, float]:
- """모든 프롬프트의 생성 메트릭 평균."""
  if not gen_results:
  return {}
 
@@ -165,9 +165,9 @@ class FullEvaluator:
  }
 
  def _print_summary(self, report: Dict[str, Any]):
- """최종 요약을 출력합니다."""
  print("\n" + "=" * 70)
- print("📋 평가 요약 리포트")
  print("=" * 70)
 
  # Perplexity
@@ -177,44 +177,44 @@ class FullEvaluator:
  print(f" Loss: {ppl['loss']:.4f}")
  print(f" PPL: {ppl['perplexity']:.2f}")
 
- # 등급 판정
  ppl_val = ppl["perplexity"]
  if ppl_val < 20:
- grade = "🌟 우수 (Strong)"
  elif ppl_val < 35:
- grade = "✅ 양호 (Good)"
  elif ppl_val < 60:
- grade = "⚠️ 보통 (Fair)"
  else:
- grade = "❌ 미흡 (학습 추가 필요)"
- print(f" 등급: {grade}")
 
- # 위치별 Loss
  if "position_losses" in report:
  pl = report["position_losses"]
- print(f"\n 📍 위치별 Loss:")
- print(f" 초반 (0-50): {pl['early_avg']:.4f}")
- print(f" 후반 (-200): {pl['late_avg']:.4f}")
- print(f" 컨텍스트 효과: {pl['early_avg'] - pl['late_avg']:.4f} 감소")
 
- # 생성 품질
  if "generation" in report and report["generation"].get("avg_metrics"):
  gm = report["generation"]["avg_metrics"]
- print(f"\n ✍️ 생성 품질:")
- print(f" 평균 길이: {gm.get('avg_length', 0):.0f}자")
- print(f" 반복률: {gm.get('repetition_rate', 0):.1%}")
- print(f" 어휘 다양성: {gm.get('lexical_diversity', 0):.3f}")
 
- # 학습 역학
  if "training_dynamics" in report:
  td = report["training_dynamics"]
  if "loss" in td:
- print(f"\n 📉 학습 역학:")
- print(f" Loss 감소: {td['loss']['initial']:.4f} → {td['loss']['final']:.4f}")
- print(f" 스파이크: {len(td['loss']['spikes'])}회")
 
- # 생성된 파일
- print(f"\n 📂 결과 파일:")
  for f in sorted(self.save_dir.glob("*")):
  size = f.stat().st_size / 1024
  print(f" {f.name} ({size:.1f} KB)")
 
+ """Comprehensive evaluation runner."""
 
  import json
  import time
 
 
  class FullEvaluator:
+ """Runs all evaluations at once and generates a report.
 
+ Usage:
  ```python
  evaluator = FullEvaluator(model, tokenizer, val_dataloader, device)
  report = evaluator.run_full_evaluation()
 
  self.save_dir.mkdir(parents=True, exist_ok=True)
 
  def run_full_evaluation(self) -> Dict[str, Any]:
+ """Runs the full evaluation."""
  report = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")}
 
  print("\n" + "=" * 70)
+ print("🔍 Starting comprehensive evaluation")
  print("=" * 70)
 
  # ── 1. Perplexity ──
  print("\n" + "━" * 40)
+ print("Phase 1/4: Perplexity measurement")
  print("━" * 40)
  ppl_evaluator = PerplexityEvaluator(self.config)
  report["perplexity"] = ppl_evaluator.evaluate(
  self.model, self.val_dataloader, self.device, self.dtype
  )
 
+ # Per-position loss
+ print("\n Measuring per-position loss...")
  position_losses = ppl_evaluator.evaluate_per_position(
  self.model, self.val_dataloader, self.device, self.dtype
  )
 
  "late_avg": round(sum(position_losses[-200:]) / max(len(position_losses[-200:]), 1), 4),
  }
 
+ # Per-position loss visualization
  dynamics = TrainingDynamicsAnalyzer(str(self.save_dir))
  dynamics.plot_position_loss(position_losses, str(self.save_dir / "position_loss.png"))
 
+ # ── 2. Text generation ──
  print("\n" + "━" * 40)
+ print("Phase 2/4: Text generation")
  print("━" * 40)
  gen_evaluator = GenerationEvaluator(self.config)
  gen_results = gen_evaluator.generate_samples(
 
  "avg_metrics": self._average_gen_metrics(gen_results),
  }
 
+ # ── 3. Training dynamics analysis ──
  if self.metrics_history:
  print("\n" + "━" * 40)
+ print("Phase 3/4: Training dynamics analysis")
  print("━" * 40)
  report["training_dynamics"] = dynamics.analyze_metrics(self.metrics_history)
  dynamics.plot_training_curves(self.metrics_history,
  str(self.save_dir / "training_curves.png"))
  else:
+ print("\n Phase 3/4: Skipped (no metrics_history)")
 
+ # ── 4. Attention visualization (sample) ──
  print("\n" + "━" * 40)
+ print("Phase 4/4: Attention visualization")
  print("━" * 40)
  try:
  self._visualize_attention_sample()
  except Exception as e:
+ print(f" ⚠️ Attention visualization failed: {e}")
 
+ # ── Save report ──
  report_path = self.save_dir / "eval_report.json"
  with open(report_path, "w") as f:
  json.dump(report, f, indent=2, default=str)
+ print(f"\n📋 Report saved: {report_path}")
 
+ # ── Print summary ──
  self._print_summary(report)
 
  return report
 
  def _visualize_attention_sample(self):
+ """Visualizes attention using a sample text."""
  viz = AttentionVisualizer(str(self.save_dir))
 
  sample_text = "The cat sat on the mat and looked at the bird."
  token_ids = self.tokenizer.encode(sample_text, add_special_tokens=False)
  input_tensor = torch.tensor([token_ids], dtype=torch.long)
 
+ # Token strings (for visualization labels)
  tokens_str = []
  for tid in token_ids:
  decoded = self.tokenizer.decode([tid])
  tokens_str.append(decoded.replace("\n", "\\n"))
 
+ # Extract Layer 0 attention
  attn_weights = viz.extract_attention(
  self.model, input_tensor, layer_idx=0, device=self.device
  )
 
  @staticmethod
  def _average_gen_metrics(gen_results: List[Dict]) -> Dict[str, float]:
+ """Average generation metrics across all prompts."""
  if not gen_results:
  return {}
 
  }
 
  def _print_summary(self, report: Dict[str, Any]):
+ """Prints the final summary."""
  print("\n" + "=" * 70)
+ print("📋 Evaluation Summary Report")
  print("=" * 70)
 
  # Perplexity
 
  print(f" Loss: {ppl['loss']:.4f}")
  print(f" PPL: {ppl['perplexity']:.2f}")
 
+ # Grade assessment
  ppl_val = ppl["perplexity"]
  if ppl_val < 20:
+ grade = "🌟 Excellent (Strong)"
  elif ppl_val < 35:
+ grade = "✅ Good"
  elif ppl_val < 60:
+ grade = "⚠️ Fair"
  else:
+ grade = "❌ Poor (more training needed)"
+ print(f" Grade: {grade}")
 
+ # Per-position loss
  if "position_losses" in report:
  pl = report["position_losses"]
+ print(f"\n 📍 Per-position Loss:")
+ print(f" Early (0-50): {pl['early_avg']:.4f}")
+ print(f" Late (-200): {pl['late_avg']:.4f}")
+ print(f" Context effect: {pl['early_avg'] - pl['late_avg']:.4f} reduction")
 
+ # Generation quality
  if "generation" in report and report["generation"].get("avg_metrics"):
  gm = report["generation"]["avg_metrics"]
+ print(f"\n ✍️ Generation Quality:")
+ print(f" Avg length: {gm.get('avg_length', 0):.0f} chars")
+ print(f" Repetition rate: {gm.get('repetition_rate', 0):.1%}")
+ print(f" Lexical diversity: {gm.get('lexical_diversity', 0):.3f}")
 
+ # Training dynamics
  if "training_dynamics" in report:
  td = report["training_dynamics"]
  if "loss" in td:
+ print(f"\n 📉 Training Dynamics:")
+ print(f" Loss reduction: {td['loss']['initial']:.4f} → {td['loss']['final']:.4f}")
+ print(f" Spikes: {len(td['loss']['spikes'])}")
 
+ # Generated files
+ print(f"\n 📂 Output files:")
  for f in sorted(self.save_dir.glob("*")):
  size = f.stat().st_size / 1024
  print(f" {f.name} ({size:.1f} KB)")
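The summary's perplexity bands above translate directly into a tiny helper. This sketch is illustrative and not part of the commit; the function name `grade_from_ppl` is hypothetical, but the thresholds match `_print_summary`:

```python
def grade_from_ppl(ppl: float) -> str:
    """Map a perplexity value to the report's qualitative grade."""
    if ppl < 20:
        return "Excellent"
    elif ppl < 35:
        return "Good"
    elif ppl < 60:
        return "Fair"
    return "Poor"


print(grade_from_ppl(18.5))  # → Excellent
print(grade_from_ppl(40.0))  # → Fair
```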
llm_lab/evaluation/generation.py CHANGED
@@ -1,4 +1,4 @@
- """텍스트 생성 평가기."""
 
  from typing import Any, Dict, List, Optional
 
@@ -9,47 +9,47 @@ from llm_lab.config import EvalConfig
 
 
  class GenerationEvaluator:
- """다양한 프롬프트로 텍스트를 생성하여 품질을 평가합니다.
-
- 평가 관점:
- 1) 문법적 정확성: 영어 문법에 맞는 문장을 생성하는가?
- 2) 일관성: 문맥을 유지하며 이어가는가?
- 3) 다양성: 같은 프롬프트에 다른 결과를 생성하는가?
- 4) 반복 회피: 같은 구절을 반복하지 않는가?
- 5) 지식 표현: 학습 데이터의 지식이 반영되는가?
-
- 1B 모델의 현실적 기대치:
- - 문법적으로 올바른 영어 문장 생성
- - 짧은 문단 일관성 유지
- - 복잡한 추론이나 논리 전개 ❌ (더 큰 모델 필요)
- - 사실적 정확성은 보장 안 됨 ⚠️
  """
 
- # 다양한 도메인의 테스트 프롬프트
  DEFAULT_PROMPTS = [
- # ── 일반 지식 ──
  "The theory of relativity states that",
  "In the history of computer science,",
  "The human brain is remarkable because",
 
- # ── 설명/교육 ──
  "To understand machine learning, one must first",
  "The water cycle begins when",
  "Photosynthesis is the process by which",
 
- # ── 서사/스토리 ──
  "Once upon a time, in a small village near the mountains,",
  "The detective looked at the evidence and realized that",
 
- # ── 코드/기술 ──
  "def fibonacci(n):\n \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
  "The most important data structures in programming are",
 
- # ── 짧은 완성 ──
  "The capital of France is",
  "Water boils at a temperature of",
 
- # ── 긴 문맥 ──
  ("Artificial intelligence has transformed many industries. "
  "In healthcare, AI is used for diagnosis and drug discovery. "
  "In finance, it powers algorithmic trading and fraud detection. "
@@ -68,7 +68,7 @@ class GenerationEvaluator:
  prompts: Optional[List[str]] = None,
  verbose: bool = True,
  ) -> List[Dict[str, Any]]:
- """프롬프트별로 텍스트를 생성합니다.
 
  Returns:
  [{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
@@ -79,7 +79,7 @@ class GenerationEvaluator:
 
  if verbose:
  print("\n" + "=" * 70)
- print("📝 텍스트 생성 평가")
  print("=" * 70)
 
  for idx, prompt in enumerate(prompts):
@@ -91,17 +91,17 @@ class GenerationEvaluator:
 
  if verbose:
  print(f"\n{'─'*60}")
- print(f"프롬프트 [{idx+1}/{len(prompts)}]:")
  print(f" \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"")
  print(f"{'─'*60}")
 
- # 프롬프트 인코딩
  prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
  input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
 
  all_texts = []
  for sample_idx in range(self.config.num_samples):
- # 생성
  generated_ids = model.generate(
  input_tensor,
  max_new_tokens=self.config.max_new_tokens,
@@ -110,7 +110,7 @@ class GenerationEvaluator:
  top_p=self.config.top_p,
  )
 
- # 디코딩 (프롬프트 이후 부분만)
  new_ids = generated_ids[0][len(prompt_ids):].tolist()
  generated_text = tokenizer.decode(new_ids)
  all_texts.append(generated_text)
@@ -118,23 +118,23 @@ class GenerationEvaluator:
  prompt_results["generations"].append(generated_text)
 
  if verbose:
- print(f"\n ✍️ 생성 #{sample_idx+1}:")
- # 깔끔한 출력 (줄바꿈 포함)
  display_text = generated_text[:500]
  for line in display_text.split("\n"):
  print(f" {line}")
  if len(generated_text) > 500:
- print(f" ... (총 {len(generated_text)} 문자)")
 
- # 생성 품질 메트릭
  prompt_results["metrics"] = self._compute_generation_metrics(all_texts)
 
  if verbose and prompt_results["metrics"]:
  m = prompt_results["metrics"]
- print(f"\n 📊 메트릭: "
- f"평균 길이={m['avg_length']:.0f}자, "
- f"반복률={m['repetition_rate']:.1%}, "
- f"어휘 다양성={m['lexical_diversity']:.2f}")
 
  results.append(prompt_results)
 
@@ -142,23 +142,23 @@ class GenerationEvaluator:
 
  @staticmethod
  def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
- """생성 텍스트의 품질 메트릭을 계산합니다.
-
- 메트릭:
- - avg_length: 평균 생성 길이 (문자)
- - avg_word_count: 평균 단어 수
- - repetition_rate: n-gram 반복률 (낮을수록 좋음)
- - lexical_diversity: 고유 단어 비율 (높을수록 다양)
- - sample_diversity: 샘플 다양성 (다른 생성끼리 얼마나 다른가)
  """
  if not texts:
  return {}
 
- # 길이
  lengths = [len(t) for t in texts]
  word_counts = [len(t.split()) for t in texts]
 
- # 반복률 (4-gram 기준)
  rep_rates = []
  for text in texts:
  words = text.lower().split()
@@ -167,9 +167,9 @@ class GenerationEvaluator:
  continue
  ngrams = [tuple(words[i:i+4]) for i in range(len(words)-3)]
  unique_ratio = len(set(ngrams)) / len(ngrams) if ngrams else 1.0
- rep_rates.append(1.0 - unique_ratio) # 반복률 = 1 - 고유비율
 
- # 어휘 다양성 (Type-Token Ratio)
  diversities = []
  for text in texts:
  words = text.lower().split()
@@ -178,7 +178,7 @@ class GenerationEvaluator:
  else:
  diversities.append(0.0)
 
- # 샘플 다양성 (자카드 유사도의 역)
  sample_div = 0.0
  if len(texts) > 1:
  word_sets = [set(t.lower().split()) for t in texts]
 
+ """Text generation evaluator."""
 
  from typing import Any, Dict, List, Optional
 
 
  class GenerationEvaluator:
+ """Evaluates text quality by generating from various prompts.
+
+ Evaluation perspectives:
+ 1) Grammatical accuracy: Does it generate grammatically correct English sentences?
+ 2) Coherence: Does it maintain context continuity?
+ 3) Diversity: Does it produce different outputs for the same prompt?
+ 4) Repetition avoidance: Does it avoid repeating the same phrases?
+ 5) Knowledge expression: Is knowledge from the training data reflected?
+
+ Realistic expectations for a 1B model:
+ - Generates grammatically correct English sentences
+ - Maintains coherence within short paragraphs
+ - Complex reasoning or extended logical chains ❌ (requires a larger model)
+ - Factual accuracy is not guaranteed ⚠️
  """
 
+ # Test prompts from various domains
  DEFAULT_PROMPTS = [
+ # ── General knowledge ──
  "The theory of relativity states that",
  "In the history of computer science,",
  "The human brain is remarkable because",
 
+ # ── Explanation / Education ──
  "To understand machine learning, one must first",
  "The water cycle begins when",
  "Photosynthesis is the process by which",
 
+ # ── Narrative / Story ──
  "Once upon a time, in a small village near the mountains,",
  "The detective looked at the evidence and realized that",
 
+ # ── Code / Technical ──
  "def fibonacci(n):\n \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
  "The most important data structures in programming are",
 
+ # ── Short completion ──
  "The capital of France is",
  "Water boils at a temperature of",
 
+ # ── Long context ──
  ("Artificial intelligence has transformed many industries. "
  "In healthcare, AI is used for diagnosis and drug discovery. "
  "In finance, it powers algorithmic trading and fraud detection. "
 
  prompts: Optional[List[str]] = None,
  verbose: bool = True,
  ) -> List[Dict[str, Any]]:
+ """Generates text for each prompt.
 
  Returns:
  [{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
 
 
  if verbose:
  print("\n" + "=" * 70)
+ print("📝 Text Generation Evaluation")
  print("=" * 70)
 
  for idx, prompt in enumerate(prompts):
 
 
  if verbose:
  print(f"\n{'─'*60}")
+ print(f"Prompt [{idx+1}/{len(prompts)}]:")
  print(f" \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"")
  print(f"{'─'*60}")
 
+ # Encode prompt
  prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
  input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
 
  all_texts = []
  for sample_idx in range(self.config.num_samples):
+ # Generate
  generated_ids = model.generate(
  input_tensor,
  max_new_tokens=self.config.max_new_tokens,
 
  top_p=self.config.top_p,
  )
 
+ # Decode (only the part after the prompt)
  new_ids = generated_ids[0][len(prompt_ids):].tolist()
  generated_text = tokenizer.decode(new_ids)
  all_texts.append(generated_text)
 
  prompt_results["generations"].append(generated_text)
 
  if verbose:
+ print(f"\n ✍️ Generation #{sample_idx+1}:")
+ # Clean output (including newlines)
  display_text = generated_text[:500]
  for line in display_text.split("\n"):
  print(f" {line}")
  if len(generated_text) > 500:
+ print(f" ... (total {len(generated_text)} characters)")
 
+ # Generation quality metrics
  prompt_results["metrics"] = self._compute_generation_metrics(all_texts)
 
  if verbose and prompt_results["metrics"]:
  m = prompt_results["metrics"]
+ print(f"\n 📊 Metrics: "
+ f"avg_length={m['avg_length']:.0f} chars, "
+ f"repetition_rate={m['repetition_rate']:.1%}, "
+ f"lexical_diversity={m['lexical_diversity']:.2f}")
 
  results.append(prompt_results)
 
 
  @staticmethod
  def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
+ """Computes quality metrics for generated text.
+
+ Metrics:
+ - avg_length: Average generation length (characters)
+ - avg_word_count: Average word count
+ - repetition_rate: n-gram repetition rate (lower is better)
+ - lexical_diversity: Ratio of unique words (higher means more diverse)
+ - sample_diversity: Diversity across samples (how different are different generations)
  """
  if not texts:
  return {}
 
+ # Length
  lengths = [len(t) for t in texts]
  word_counts = [len(t.split()) for t in texts]
 
+ # Repetition rate (based on 4-grams)
  rep_rates = []
  for text in texts:
  words = text.lower().split()
 
  continue
  ngrams = [tuple(words[i:i+4]) for i in range(len(words)-3)]
  unique_ratio = len(set(ngrams)) / len(ngrams) if ngrams else 1.0
+ rep_rates.append(1.0 - unique_ratio) # repetition rate = 1 - unique ratio
 
+ # Lexical diversity (Type-Token Ratio)
  diversities = []
  for text in texts:
  words = text.lower().split()
 
  else:
  diversities.append(0.0)
 
+ # Inter-sample diversity (inverse of Jaccard similarity)
  sample_div = 0.0
  if len(texts) > 1:
  word_sets = [set(t.lower().split()) for t in texts]
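The metric definitions in `_compute_generation_metrics` can be exercised standalone. A minimal sketch with hypothetical function names, using the same 4-gram and type-token-ratio definitions as the code above:

```python
def repetition_rate(text: str, n: int = 4) -> float:
    """1 - (unique n-grams / total n-grams); higher means more repetition."""
    words = text.lower().split()
    if len(words) < n:
        return 0.0
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    return 1.0 - len(set(ngrams)) / len(ngrams)


def lexical_diversity(text: str) -> float:
    """Type-Token Ratio: unique words / total words."""
    words = text.lower().split()
    return len(set(words)) / len(words) if words else 0.0


# A deliberately repetitive sample: the second half repeats the first
sample = "the cat sat on the mat the cat sat on the mat"
print(round(repetition_rate(sample), 3))
print(round(lexical_diversity(sample), 3))
```

For this sample, 3 of the 9 four-grams are repeats, so the repetition rate is 1/3, and 5 unique words over 12 tokens gives a type-token ratio of about 0.42.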
llm_lab/evaluation/perplexity.py CHANGED
@@ -1,4 +1,4 @@
- """Perplexity(PPL) 평가기."""
 
  import math
  import time
@@ -13,26 +13,26 @@ from llm_lab.config import EvalConfig
 
 
  class PerplexityEvaluator:
- """Perplexity(PPL)를 측정합니다.
 
- Perplexity란?
  PPL = exp(average cross-entropy loss)
 
- 직관적 의미:
- - PPL = 1: 완벽한 예측 (불가능)
- - PPL = 10: 매번 10개 후보 중 고르는 수준
- - PPL = 100: 100개 후보 중 고르는 수준 (무작위에 가까움)
- - PPL = 32000: vocab 전체에서 랜덤 선택 (초기 랜덤 모델)
-
- 좋은 1B 모델 기준 (영어 텍스트):
- - 5B 토큰 학습: PPL ~30-40
- - 10B 토큰 학습: PPL ~20-30
- - 20B 토큰 학습: PPL ~15-25
-
- 측정 방법:
- - 검증 데이터셋의 모든 토큰에 대해 cross-entropy 계산
- - 토큰 단위 평균 후 exp() 적용
- - 패딩 토큰은 제외 (ignore_index=-100)
  """
 
  def __init__(self, config: EvalConfig):
@@ -47,14 +47,14 @@ class PerplexityEvaluator:
  dtype: torch.dtype = torch.bfloat16,
  desc: str = "Evaluation",
  ) -> Dict[str, float]:
- """Perplexity를 측정합니다.
 
  Returns:
  {
- "loss": 평균 cross-entropy loss,
  "perplexity": exp(loss),
- "num_tokens": 평가에 사용된 토큰 수,
- "num_batches": 평가에 사용된 배치 수,
  }
  """
  model.eval()
@@ -76,7 +76,7 @@ class PerplexityEvaluator:
  with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
  logits, _ = model(input_ids)
 
- # 토큰별 cross-entropy (reduction='none')
  # logits: (B, S, V) → (B*S, V)
  # targets: (B, S) → (B*S,)
  loss_per_token = F.cross_entropy(
@@ -86,7 +86,7 @@ class PerplexityEvaluator:
  reduction="none",
  )
 
- # -100이 아닌 유효 토큰만 카운트
  valid_mask = (targets.view(-1) != -100)
  valid_tokens = valid_mask.sum().item()
 
@@ -100,7 +100,7 @@ class PerplexityEvaluator:
 
  elapsed = time.time() - start_time
  avg_loss = total_loss / max(total_tokens, 1)
- perplexity = math.exp(min(avg_loss, 100)) # overflow 방지
 
  results = {
  "loss": round(avg_loss, 4),
@@ -113,8 +113,8 @@ class PerplexityEvaluator:
  print(f" ────────────────────────────────")
  print(f" Loss: {results['loss']:.4f}")
  print(f" Perplexity: {results['perplexity']:.2f}")
- print(f" 평가 토큰: {total_tokens:,}")
- print(f" 소요 시간: {elapsed:.1f}초")
 
  return results
 
@@ -127,12 +127,12 @@ class PerplexityEvaluator:
  dtype: torch.dtype = torch.bfloat16,
  max_batches: int = 50,
  ) -> List[float]:
- """시퀀스 위치별 Loss를 측정합니다.
 
- 학습 포인트:
- - 위치 0~10: Loss가 높음 (문맥이 부족)
- - 위치 100+: Loss가 안정적으로 낮아짐 (문맥 활용)
- - 이 패턴이 Transformer의 in-context learning 능력을 보여줌
  """
  model.eval()
  seq_len = None
@@ -155,7 +155,7 @@ class PerplexityEvaluator:
  with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
  logits, _ = model(input_ids)
 
- # (B, S) 형태의 토큰별 loss
  loss_per_token = F.cross_entropy(
  logits.view(-1, logits.size(-1)),
  targets.view(-1),
@@ -167,6 +167,6 @@ class PerplexityEvaluator:
  position_loss_sum += (loss_per_token * valid_mask).sum(dim=0)
  position_count += valid_mask.sum(dim=0)
 
- # 위치별 평균 loss
  position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist()
  return position_avg_loss
 
+ """Perplexity (PPL) evaluator."""
 
  import math
  import time
 
 
  class PerplexityEvaluator:
+ """Measures Perplexity (PPL).
 
+ What is Perplexity?
  PPL = exp(average cross-entropy loss)
 
+ Intuitive meaning:
+ - PPL = 1: Perfect prediction (impossible)
+ - PPL = 10: Equivalent to picking from 10 candidates each time
+ - PPL = 100: Equivalent to picking from 100 candidates (close to random)
+ - PPL = 32000: Random selection from the entire vocab (initial random model)
+
+ Good benchmark for a 1B model (English web text):
+ - Trained on 5B tokens: PPL ~30-40
+ - Trained on 10B tokens: PPL ~20-30
+ - Trained on 20B tokens: PPL ~15-25
+
+ Measurement method:
+ - Compute cross-entropy over all tokens in the validation dataset
+ - Average per token, then apply exp()
+ - Padding tokens are excluded (ignore_index=-100)
  """
 
  def __init__(self, config: EvalConfig):
 
  dtype: torch.dtype = torch.bfloat16,
  desc: str = "Evaluation",
  ) -> Dict[str, float]:
+ """Measures Perplexity.
 
  Returns:
  {
+ "loss": average cross-entropy loss,
  "perplexity": exp(loss),
+ "num_tokens": total number of tokens used for evaluation,
+ "num_batches": number of batches used for evaluation,
  }
  """
  model.eval()
 
  with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
  logits, _ = model(input_ids)
 
+ # Per-token cross-entropy (reduction='none')
  # logits: (B, S, V) → (B*S, V)
  # targets: (B, S) → (B*S,)
  loss_per_token = F.cross_entropy(
 
  reduction="none",
  )
 
+ # Count only valid tokens that are not -100
  valid_mask = (targets.view(-1) != -100)
  valid_tokens = valid_mask.sum().item()
 
 
  elapsed = time.time() - start_time
  avg_loss = total_loss / max(total_tokens, 1)
+ perplexity = math.exp(min(avg_loss, 100)) # prevent overflow
 
  results = {
  "loss": round(avg_loss, 4),
 
  print(f" ────────────────────────────────")
  print(f" Loss: {results['loss']:.4f}")
  print(f" Perplexity: {results['perplexity']:.2f}")
+ print(f" Eval tokens: {total_tokens:,}")
+ print(f" Elapsed: {elapsed:.1f}s")
 
  return results
 
 
  dtype: torch.dtype = torch.bfloat16,
  max_batches: int = 50,
  ) -> List[float]:
+ """Measures loss per position within a sequence.
 
+ Learning insight:
+ - Positions 0~10: Higher loss (insufficient context)
+ - Positions 100+: Loss stabilizes lower (context is leveraged)
+ - This pattern demonstrates the Transformer's in-context learning capability
  """
  model.eval()
  seq_len = None
 
  with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
  logits, _ = model(input_ids)
 
+ # Per-token loss in shape (B, S)
  loss_per_token = F.cross_entropy(
  logits.view(-1, logits.size(-1)),
  targets.view(-1),
 
  position_loss_sum += (loss_per_token * valid_mask).sum(dim=0)
  position_count += valid_mask.sum(dim=0)
 
+ # Average loss per position
  position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist()
  return position_avg_loss
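The docstring's definition above (PPL = exp of the mean per-token cross-entropy, with padding positions excluded) can be checked without torch. A dependency-free sketch; the `perplexity` helper name and the per-token loss numbers are illustrative only:

```python
import math


def perplexity(token_losses, mask):
    """PPL = exp(total loss / number of valid tokens).

    `mask` marks valid tokens; masked-out positions play the role of
    ignore_index=-100 targets in the evaluator.
    """
    total = sum(loss for loss, valid in zip(token_losses, mask) if valid)
    count = sum(mask)
    avg = total / max(count, 1)
    return math.exp(min(avg, 100))  # clamp to prevent overflow, as in the evaluator


losses = [3.2, 2.8, 3.0, 9.9]       # last entry is a padding position
valid = [True, True, True, False]   # padding excluded from the average
print(perplexity(losses, valid))    # exp(mean of 3.2, 2.8, 3.0) = exp(3.0) ≈ 20.09
```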
llm_lab/evaluation/runner.py CHANGED
@@ -1,4 +1,4 @@
- """평가 실행 헬퍼 (Quick Start)."""
 
  from typing import Any, Dict, Optional
 
@@ -20,13 +20,13 @@ def run_evaluation(
  metrics_history: Optional[Dict[str, list]] = None,
  config: Optional[EvalConfig] = None,
  ) -> Dict[str, Any]:
- """평가를 한 번에 실행합니다.
 
- 사용법 (Colab):
  ```python
  from llm_lab.evaluation import run_evaluation
 
- # 학습 완료 후
  report = run_evaluation(
  model=trainer.model,
  tokenizer=tokenizer,
@@ -50,7 +50,7 @@ def run_evaluation(
 
  report = evaluator.run_full_evaluation()
 
- # 인사이트 체크리스트
  InsightChecklist.run_checklist(report, metrics_history)
 
  return report
 
+ """Evaluation runner helper (Quick Start)."""
 
  from typing import Any, Dict, Optional
 
 
  metrics_history: Optional[Dict[str, list]] = None,
  config: Optional[EvalConfig] = None,
  ) -> Dict[str, Any]:
+ """Runs all evaluations in one call.
 
+ Usage (Colab):
  ```python
  from llm_lab.evaluation import run_evaluation
 
+ # After training is complete
  report = run_evaluation(
  model=trainer.model,
  tokenizer=tokenizer,
 
  report = evaluator.run_full_evaluation()
 
+ # Insight checklist
  InsightChecklist.run_checklist(report, metrics_history)
 
  return report
llm_lab/evaluation/scaling.py CHANGED
@@ -1,4 +1,4 @@
- """Scaling Law 분석기."""
 
  from pathlib import Path
  from typing import Any, Dict, List, Optional
@@ -19,17 +19,17 @@ except ImportError:
 
 
  class ScalingAnalyzer:
- """10M → 100M → 1B 모델의 Scaling Law를 분석합니다.
 
  Chinchilla Scaling Law (2022):
- - 최적 학습: 토큰 ≈ 20 × 파라미터 수
- - Loss ∝ N^(-α) × D^(-β) (N=파라미터, D=데이터)
- - α ≈ 0.076, β ≈ 0.095 (논문 기준)
-
- 이 분석의 목적:
- - 우리 모델이 Scaling Law를 따르는지 확인
- - 더 큰 모델/더 많은 데이터의 효과를 예측
- - 컴퓨팅 자원 배분의 최적점 이해
  """
 
  def __init__(self, save_dir: str = "./eval_results"):
@@ -40,7 +40,7 @@ class ScalingAnalyzer:
  self,
  model_results: List[Dict[str, Any]],
  ) -> Dict[str, Any]:
- """여러 모델 크기의 결과를 비교 분석합니다.
 
  Args:
  model_results: [
@@ -50,25 +50,25 @@ class ScalingAnalyzer:
  ]
 
  Returns:
- 분석 결과 딕셔너리
  """
  if len(model_results) < 2:
- print("⚠️ Scaling 분석에는 최소 2개 모델 결과가 필요합니다.")
  return {}
 
  print("\n" + "=" * 70)
- print("📈 Scaling Law 분석")
  print("=" * 70)
 
- # ── 결과 테이블 ──
- print(f"\n {'모델':<8} {'파라미터':>12} {'토큰':>10} {'Loss':>8} {'PPL':>8}")
  print(f" {'─'*52}")
  for r in model_results:
  params_str = f"{r['params']/1e6:.0f}M" if r["params"] < 1e9 else f"{r['params']/1e9:.1f}B"
  tokens_str = f"{r['tokens']/1e9:.1f}B"
  print(f" {r['name']:<8} {params_str:>12} {tokens_str:>10} {r['loss']:>8.4f} {r['ppl']:>8.2f}")
 
- # ── Scaling 효율 계산 ──
  analysis = {"models": model_results, "scaling_efficiency": []}
 
  for i in range(1, len(model_results)):
@@ -89,17 +89,17 @@ class ScalingAnalyzer:
  analysis["scaling_efficiency"].append(efficiency)
 
  print(f"\n {prev['name']} → {curr['name']}:")
- print(f" 파라미터 ×{param_ratio:.1f}")
- print(f" Loss 감소: {loss_reduction:.4f}")
- print(f" PPL 감소: {ppl_reduction*100:.1f}%")
 
- # ── Chinchilla 최적성 체크 ──
- print(f"\n Chinchilla 최적성 체크 (토큰 ≈ 20 × 파라미터):")
  for r in model_results:
  actual_ratio = r["tokens"] / r["params"]
- status = "✅ 최적 범위" if 15 <= actual_ratio <= 25 else "⚠️ 범위 밖"
- print(f" {r['name']}: 토큰/파라미터 = {actual_ratio:.1f}x "
- f"(최적: 20x) {status}")
 
  analysis["chinchilla_ratios"] = [
  {"name": r["name"], "ratio": round(r["tokens"] / r["params"], 1)}
@@ -113,9 +113,9 @@ class ScalingAnalyzer:
  model_results: List[Dict[str, Any]],
  save_path: Optional[str] = None,
  ):
- """Scaling 곡선을 시각화합니다."""
  if not HAS_MATPLOTLIB or not HAS_NUMPY:
- print("⚠️ matplotlib/numpy가 필요합니다: pip install matplotlib numpy")
  return
 
  fig, axes = plt.subplots(1, 2, figsize=(14, 5))
@@ -149,5 +149,5 @@ class ScalingAnalyzer:
 
  save_path = save_path or str(self.save_dir / "scaling_curves.png")
  fig.savefig(save_path, dpi=150, bbox_inches="tight")
- print(f"\n 📊 Scaling 곡선 저장: {save_path}")
  plt.close(fig)
 
+ """Scaling Law analyzer."""
 
  from pathlib import Path
  from typing import Any, Dict, List, Optional
 
 
  class ScalingAnalyzer:
+ """Analyzes Scaling Law across 10M → 100M → 1B models.
 
  Chinchilla Scaling Law (2022):
+ - Optimal training: tokens ≈ 20 × number of parameters
+ - Loss ∝ N^(-α) × D^(-β) (N=parameters, D=data)
+ - α ≈ 0.076, β ≈ 0.095 (per the paper)
+
+ Purpose of this analysis:
+ - Verify whether our model follows the Scaling Law
+ - Predict the effect of larger models / more data
+ - Understand the optimal allocation of compute resources
  """
 
  def __init__(self, save_dir: str = "./eval_results"):
 
  self,
  model_results: List[Dict[str, Any]],
  ) -> Dict[str, Any]:
+ """Comparatively analyzes results across multiple model sizes.
 
  Args:
  model_results: [
 
  ]
 
  Returns:
+ Analysis result dictionary
  """
  if len(model_results) < 2:
+ print("⚠️ Scaling analysis requires results from at least 2 models.")
  return {}
 
  print("\n" + "=" * 70)
+ print("📈 Scaling Law Analysis")
  print("=" * 70)
 
+ # ── Results table ──
+ print(f"\n {'Model':<8} {'Parameters':>12} {'Tokens':>10} {'Loss':>8} {'PPL':>8}")
  print(f" {'─'*52}")
  for r in model_results:
  params_str = f"{r['params']/1e6:.0f}M" if r["params"] < 1e9 else f"{r['params']/1e9:.1f}B"
  tokens_str = f"{r['tokens']/1e9:.1f}B"
  print(f" {r['name']:<8} {params_str:>12} {tokens_str:>10} {r['loss']:>8.4f} {r['ppl']:>8.2f}")
 
+ # ── Scaling efficiency calculation ──
  analysis = {"models": model_results, "scaling_efficiency": []}
 
  for i in range(1, len(model_results)):
 
  analysis["scaling_efficiency"].append(efficiency)
 
  print(f"\n {prev['name']} → {curr['name']}:")
+ print(f" Parameters ×{param_ratio:.1f}")
+ print(f" Loss reduction: {loss_reduction:.4f}")
+ print(f" PPL reduction: {ppl_reduction*100:.1f}%")
 
+ # ── Chinchilla optimality check ──
+ print(f"\n Chinchilla optimality check (tokens ≈ 20 × parameters):")
  for r in model_results:
  actual_ratio = r["tokens"] / r["params"]
+ status = "✅ Optimal range" if 15 <= actual_ratio <= 25 else "⚠️ Out of range"
+ print(f" {r['name']}: tokens/parameters = {actual_ratio:.1f}x "
+ f"(optimal: 20x) {status}")
 
  analysis["chinchilla_ratios"] = [
  {"name": r["name"], "ratio": round(r["tokens"] / r["params"], 1)}
 
  model_results: List[Dict[str, Any]],
  save_path: Optional[str] = None,
  ):
+ """Visualizes scaling curves."""
  if not HAS_MATPLOTLIB or not HAS_NUMPY:
+ print("⚠️ matplotlib/numpy required: pip install matplotlib numpy")
  return
 
  fig, axes = plt.subplots(1, 2, figsize=(14, 5))
 
  save_path = save_path or str(self.save_dir / "scaling_curves.png")
  fig.savefig(save_path, dpi=150, bbox_inches="tight")
+ print(f"\n 📊 Scaling curves saved: {save_path}")
  plt.close(fig)
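The Chinchilla check above reduces to a tokens-per-parameter ratio. A minimal sketch (the `chinchilla_status` helper name is hypothetical; the 15x-25x band mirrors the code above):

```python
def chinchilla_status(params: float, tokens: float) -> str:
    """Flag whether tokens ≈ 20 × parameters (optimal band taken as 15x-25x)."""
    ratio = tokens / params
    return "optimal" if 15 <= ratio <= 25 else "out of range"


# A 1B-parameter model trained on 20B tokens sits at the 20x optimum;
# the same model on only 5B tokens is under-trained by this criterion.
print(chinchilla_status(1e9, 20e9))  # → optimal
print(chinchilla_status(1e9, 5e9))   # → out of range
```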
llm_lab/model/__init__.py CHANGED
@@ -1,4 +1,4 @@
- """모델 아키텍처 모듈 — LLaMA-style Decoder-Only Transformer."""
  from .norm import RMSNorm
  from .rope import RotaryPositionalEmbedding
  from .attention import GroupedQueryAttention
 
+ """Model architecture module — LLaMA-style Decoder-Only Transformer."""
  from .norm import RMSNorm
  from .rope import RotaryPositionalEmbedding
  from .attention import GroupedQueryAttention
llm_lab/model/attention.py CHANGED
@@ -11,20 +11,20 @@ from .rope import RotaryPositionalEmbedding


  class GroupedQueryAttention(nn.Module):
-     """GQA: Multi-Head Attention의 메모리 효율적 변형.
+     """GQA: A memory-efficient variant of Multi-Head Attention.

      MHA vs GQA vs MQA:
-     - MHA (Multi-Head Attention): Q, K, V 모두 num_heads개 → 메모리 ↑
-     - MQA (Multi-Query Attention): K, V를 1개 헤드로 공유 → 품질 저하 우려
-     - GQA (Grouped Query Attention): K, V를 num_kv_heads개로 그룹화
-       → MHA와 MQA의 중간, 좋은 품질-효율 균형
+     - MHA (Multi-Head Attention): Q, K, V all have num_heads → high memory usage
+     - MQA (Multi-Query Attention): K, V share a single head → risk of quality degradation
+     - GQA (Grouped Query Attention): K, V are grouped into num_kv_heads
+       → a middle ground between MHA and MQA, good quality-efficiency balance

-     예시 (num_heads=16, num_kv_heads=4):
-       Q 헤드:     [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15]
-       K/V 그룹:   [   0   ,    1   ,     2    ,      3     ]
-       → Q 헤드 4개가 K/V 헤드 1개를 공유
+     Example (num_heads=16, num_kv_heads=4):
+       Q heads:    [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15]
+       K/V groups: [   0   ,    1   ,     2    ,      3     ]
+       → 4 Q heads share 1 K/V head

-     Attention 수식:
+     Attention formula:
          Attention(Q, K, V) = softmax(Q·K^T / √d_k) · V
      """

@@ -36,14 +36,14 @@ class GroupedQueryAttention(nn.Module):
          self.num_kv_heads = config.num_kv_heads
          self.num_kv_groups = config.num_kv_groups  # num_heads // num_kv_heads

-         # Q/K/V 프로젝션
+         # Q/K/V projections
          # Q: hidden_dim → num_heads × head_dim
          self.q_proj = nn.Linear(config.hidden_dim, config.num_heads * self.head_dim, bias=False)
-         # K, V: hidden_dim → num_kv_heads × head_dim (Q보다 작음!)
+         # K, V: hidden_dim → num_kv_heads × head_dim (smaller than Q!)
          self.k_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
          self.v_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)

-         # 출력 프로젝션: 모든 헤드의 출력을 다시 hidden_dim으로
+         # Output projection: merge all head outputs back to hidden_dim
          self.o_proj = nn.Linear(config.num_heads * self.head_dim, config.hidden_dim, bias=False)

          # RoPE
@@ -51,7 +51,7 @@ class GroupedQueryAttention(nn.Module):
              dim=self.head_dim, max_seq_len=config.max_seq_len, theta=config.rope_theta
          )

-         # Attention dropout (pretraining에서는 보통 0)
+         # Attention dropout (typically 0 during pretraining)
          self.attn_dropout = nn.Dropout(config.dropout)

      def forward(
@@ -64,7 +64,7 @@ class GroupedQueryAttention(nn.Module):
          Args:
              x: (batch_size, seq_len, hidden_dim)
              mask: (seq_len, seq_len) causal mask
-             position_offset: 위치 오프셋 (추론 시 사용)
+             position_offset: position offset (used during inference)

          Returns:
              (batch_size, seq_len, hidden_dim)
@@ -72,13 +72,13 @@ class GroupedQueryAttention(nn.Module):
          B, S, _ = x.shape

          # ──────────────────────────────────────────────
-         # Step 1: Q, K, V 프로젝션
+         # Step 1: Q, K, V projections
          # ──────────────────────────────────────────────
          q = self.q_proj(x)  # (B, S, num_heads × head_dim)
          k = self.k_proj(x)  # (B, S, num_kv_heads × head_dim)
          v = self.v_proj(x)  # (B, S, num_kv_heads × head_dim)

-         # 멀티헤드 형태로 reshape
+         # Reshape into multi-head form
          q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
          # → (B, num_heads, S, head_dim)
          k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
@@ -86,16 +86,16 @@ class GroupedQueryAttention(nn.Module):
          v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

          # ──────────────────────────────────────────────
-         # Step 2: RoPE 적용 (Q, K에만! V에는 적용하지 않음)
+         # Step 2: Apply RoPE (to Q and K only! Not to V)
          # ──────────────────────────────────────────────
-         # 위치 정보는 "어디를 볼지"(Q·K)에만 영향을 줘야 하고,
-         # "무엇을 가져올지"(V)에는 영향을 주면 안 됩니다.
+         # Positional information should only affect "where to attend" (Q·K),
+         # not "what to retrieve" (V).
          q, k = self.rope(q, k, position_offset)

          # ──────────────────────────────────────────────
-         # Step 3: GQA - KV 헤드 확장 (repeat)
+         # Step 3: GQA - expand KV heads (repeat)
          # ──────────────────────────────────────────────
-         # num_kv_heads=4 → num_heads=16: 각 KV를 4번 반복
+         # num_kv_heads=4 → num_heads=16: repeat each KV 4 times
          if self.num_kv_groups > 1:
              k = self._repeat_kv(k)  # (B, num_heads, S, head_dim)
              v = self._repeat_kv(v)
@@ -103,17 +103,17 @@ class GroupedQueryAttention(nn.Module):
          # ──────────────────────────────────────────────
          # Step 4: Scaled Dot-Product Attention
          # ──────────────────────────────────────────────
-         # PyTorch >= 2.0의 최적화된 구현 사용 (Flash Attention 자동 적용)
+         # Uses PyTorch >= 2.0's optimized implementation (Flash Attention applied automatically)
          attn_out = F.scaled_dot_product_attention(
              q, k, v,
              attn_mask=mask,
              dropout_p=self.config.dropout if self.training else 0.0,
-             is_causal=(mask is None),  # mask가 없으면 자동 causal masking
+             is_causal=(mask is None),  # apply automatic causal masking when no mask is provided
          )
          # → (B, num_heads, S, head_dim)

          # ──────────────────────────────────────────────
-         # Step 5: 헤드 합치기 + 출력 프로젝션
+         # Step 5: Merge heads + output projection
          # ──────────────────────────────────────────────
          attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
          # → (B, S, num_heads × head_dim)
@@ -121,11 +121,11 @@ class GroupedQueryAttention(nn.Module):
          return self.o_proj(attn_out)  # → (B, S, hidden_dim)

      def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
-         """KV 헤드를 Q 헤드 수에 맞게 반복합니다.
+         """Repeat KV heads to match the number of Q heads.

          (B, num_kv_heads, S, head_dim) → (B, num_heads, S, head_dim)

-         예: num_kv_heads=4, num_kv_groups=4
+         Example: num_kv_heads=4, num_kv_groups=4
              [kv0, kv1, kv2, kv3] → [kv0,kv0,kv0,kv0, kv1,kv1,kv1,kv1, ...]
          """
          B, H_kv, S, D = x.shape
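The KV-head expansion that `_repeat_kv` performs in the diff above can be sketched as a standalone function. This is a minimal illustration, not the project's module — tensor shapes (`B=2`, `num_kv_heads=4`, `S=16`, `head_dim=64`) are chosen arbitrarily; the repeat order matches the docstring's `[kv0,kv0,kv0,kv0, kv1,...]` layout and is equivalent to `repeat_interleave` along the head dimension:

```python
import torch

def repeat_kv(x: torch.Tensor, num_kv_groups: int) -> torch.Tensor:
    """Expand KV heads to match the number of Q heads:
    (B, num_kv_heads, S, D) -> (B, num_kv_heads * num_kv_groups, S, D),
    repeating each KV head num_kv_groups times in a row."""
    B, H_kv, S, D = x.shape
    # Insert a group axis, broadcast it, then fold it into the head axis
    x = x[:, :, None, :, :].expand(B, H_kv, num_kv_groups, S, D)
    return x.reshape(B, H_kv * num_kv_groups, S, D)

k = torch.randn(2, 4, 16, 64)   # num_kv_heads = 4
k_expanded = repeat_kv(k, 4)    # -> 16 heads, matching num_heads
```

The expand/reshape pair avoids materializing an intermediate copy until the final `reshape`, which is why this pattern is common in LLaMA-style codebases.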
llm_lab/model/feedforward.py CHANGED
@@ -8,41 +8,41 @@ from llm_lab.config import ModelConfig


  class SwiGLUFeedForward(nn.Module):
-     """SwiGLU: Gated Linear Unit with Swish 활성화 함수.
+     """SwiGLU: Gated Linear Unit with Swish activation function.

-     기존 FFN:
+     Standard FFN:
          FFN(x) = ReLU(x·W1 + b1)·W2 + b2
-         → 단순한 비선형 변환
+         → simple nonlinear transformation

      SwiGLU FFN:
          SwiGLU(x) = (Swish(x·W_gate) ⊙ (x·W_up)) · W_down
-         → 게이팅 메커니즘으로 정보 흐름을 제어
+         → controls information flow via a gating mechanism

-     왜 SwiGLU가 더 좋은가?
-     - Swish(x) = x · sigmoid(x): 부드러운 활성화, 음수 영역 일부 허용
-     - Gate 벡터가 "어떤 정보를 통과시킬지" 학습
-     - PaLM, LLaMA 등에서 ReLU FFN 대비 일관된 성능 향상 보고
+     Why is SwiGLU better?
+     - Swish(x) = x · sigmoid(x): smooth activation, allows some negative values
+     - The gate vector learns "which information to let through"
+     - Consistently reported to outperform ReLU FFN in PaLM, LLaMA, etc.

-     참고: W_gate와 W_up 두 개의 up-projection이 있어서
-     파라미터 수가 기존 FFN 대비 1.5배이지만, intermediate_dim을
-     조정하여 파라미터 수를 맞춥니다.
+     Note: Having two up-projections (W_gate and W_up) means
+     1.5x the parameters of a standard FFN, but intermediate_dim is
+     adjusted to match the total parameter count.
      """

      def __init__(self, config: ModelConfig):
          super().__init__()
-         # 게이트 프로젝션: hidden_dim → intermediate_dim
+         # Gate projection: hidden_dim → intermediate_dim
          self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
-         # 업 프로젝션: hidden_dim → intermediate_dim
+         # Up projection: hidden_dim → intermediate_dim
          self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
-         # 다운 프로젝션: intermediate_dim → hidden_dim
+         # Down projection: intermediate_dim → hidden_dim
          self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          # SwiGLU(x) = (Swish(gate(x)) ⊙ up(x)) · down
          #
-         # 1) gate: 어떤 정보를 통과시킬지 결정 (Swish 활성화)
+         # 1) gate: decides which information to pass through (Swish activation)
          gate = F.silu(self.gate_proj(x))  # silu = Swish = x * sigmoid(x)
-         # 2) up: 정보를 고차원으로 사영
+         # 2) up: projects information to a higher dimension
          up = self.up_proj(x)
-         # 3) element-wise 곱 (게이팅) → 다시 원래 차원으로
+         # 3) element-wise multiplication (gating) → project back to original dimension
          return self.down_proj(gate * up)
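The three-projection layout translated above can be reproduced without the project's `ModelConfig`, to make the shapes and the parameter-count note concrete. This is a minimal sketch with arbitrary dimensions (`hidden_dim=64`, `intermediate_dim=172`), not the repository's class:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """Minimal SwiGLU FFN: down(silu(gate(x)) * up(x)), all bias-free."""
    def __init__(self, hidden_dim: int, intermediate_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate decides what passes; up carries the content; down projects back
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

ffn = SwiGLUFeedForward(hidden_dim=64, intermediate_dim=172)
out = ffn(torch.randn(2, 10, 64))
# Total parameters: 3 matrices of hidden_dim x intermediate_dim each
n_params = sum(p.numel() for p in ffn.parameters())
```

The 3× matrix count (vs. 2× for a ReLU FFN) is exactly why `intermediate_dim` is shrunk — often to roughly 8/3 of `hidden_dim` in LLaMA-family models — to hold total parameters constant.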
llm_lab/model/llm_model.py CHANGED
@@ -13,19 +13,19 @@ from .transformer_block import TransformerBlock


  class LLMModel(nn.Module):
-     """1B 파라미터 LLaMA-style Decoder-Only Transformer.
+     """1B parameter LLaMA-style Decoder-Only Transformer.

-     전체 구조:
+     Overall structure:
          Input Token IDs
          → Token Embedding
          → [TransformerBlock] × num_layers (+ Activation Checkpointing)
-         → RMSNorm (최종)
+         → RMSNorm (final)
          → Linear Head (→ vocab logits)

      Weight Tying:
-     - 입력 Embedding과 출력 Linear Head의 가중치를 공유
-     - 파라미터 절약 (~65M) + 성능 유지/향상
-     - 직관: "단어의 의미 표현"과 "단어 예측"이 같은 공간을 사용
+     - Shares weights between the input Embedding and the output Linear Head
+     - Saves parameters (~65M) while maintaining or improving performance
+     - Intuition: "representing word meaning" and "predicting words" use the same space
      """

      def __init__(self, config: ModelConfig):
@@ -41,29 +41,29 @@ class LLMModel(nn.Module):
              for i in range(config.num_layers)
          ])

-         # ── 최종 정규화 ──
+         # ── Final normalization ──
          self.final_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)

-         # ── 출력 헤드 (Weight Tying) ──
+         # ── Output head (Weight Tying) ──
          self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
-         # Weight Tying: lm_head 가중치 = token_embedding 가중치
+         # Weight Tying: lm_head weights = token_embedding weights
          self.lm_head.weight = self.token_embedding.weight

-         # 가중치 초기화
+         # Weight initialization
          self._init_weights()

      def _init_weights(self):
-         """가중치 초기화 전략.
+         """Weight initialization strategy.

-         왜 초기화가 중요한가?
-         - 너무 크면: 활성화 폭발 → NaN
-         - 너무 작으면: gradient 소멸 → 학습 정체
-         - 적절한 초기화: 각 레이어의 출력 분산을 일정하게 유지
+         Why does initialization matter?
+         - Too large: activation explosion → NaN
+         - Too small: gradient vanishing → training stagnation
+         - Proper initialization: keeps output variance consistent across layers

-         GPT-2 스타일 초기화:
-         - 일반 Linear: N(0, 0.02)
+         GPT-2 style initialization:
+         - General Linear: N(0, 0.02)
          - Residual projection: N(0, 0.02 / √(2 × num_layers))
-           → 레이어가 깊어질수록 residual 기여를 줄여 안정화
+           → reduces residual contribution as depth increases for stability
          """
          std = 0.02
          residual_std = std / math.sqrt(2 * self.config.num_layers)
@@ -76,7 +76,7 @@ class LLMModel(nn.Module):
          elif isinstance(module, nn.Embedding):
              nn.init.normal_(module.weight, mean=0.0, std=std)

-         # Residual projection 레이어에 축소된 초기화 적용
+         # Apply scaled-down initialization to residual projection layers
          for layer in self.layers:
              nn.init.normal_(layer.attention.o_proj.weight, mean=0.0, std=residual_std)
              nn.init.normal_(layer.feed_forward.down_proj.weight, mean=0.0, std=residual_std)
@@ -89,55 +89,55 @@ class LLMModel(nn.Module):
      ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
          """
          Args:
-             input_ids: (batch_size, seq_len) - 토큰 ID
-             targets: (batch_size, seq_len) - 정답 토큰 ID (학습 시)
-             position_offset: 위치 오프셋 (추론 시)
+             input_ids: (batch_size, seq_len) - token IDs
+             targets: (batch_size, seq_len) - ground-truth token IDs (during training)
+             position_offset: position offset (during inference)

          Returns:
              logits: (batch_size, seq_len, vocab_size)
-             loss: 스칼라 (targets 제공 시) 또는 None
+             loss: scalar (when targets are provided) or None
          """
          B, S = input_ids.shape

          # ── Step 1: Token Embedding ──
-         # 각 토큰 ID를 hidden_dim 차원의 벡터로 변환
+         # Convert each token ID into a vector of dimension hidden_dim
          h = self.token_embedding(input_ids)  # (B, S, hidden_dim)

          # ── Step 2: Transformer Blocks ──
-         # Activation Checkpointing: 학습 시 메모리 절약
-         # (중간 활성화를 저장하지 않고, backward 시 재계산)
+         # Activation Checkpointing: saves memory during training
+         # (does not store intermediate activations; recomputes them during backward)
          for layer in self.layers:
              if self.training and torch.is_grad_enabled():
-                 # Activation Checkpointing 적용
+                 # Apply Activation Checkpointing
                  h = torch.utils.checkpoint.checkpoint(
                      layer, h, None, position_offset,
-                     use_reentrant=False,  # PyTorch >= 2.0 권장
+                     use_reentrant=False,  # recommended for PyTorch >= 2.0
                  )
              else:
                  h = layer(h, mask=None, position_offset=position_offset)

-         # ── Step 3: 최종 정규화 ──
+         # ── Step 3: Final normalization ──
          h = self.final_norm(h)

-         # ── Step 4: 출력 로짓 계산 ──
+         # ── Step 4: Compute output logits ──
          logits = self.lm_head(h)  # (B, S, vocab_size)

-         # ── Step 5: Loss 계산 (학습 시) ──
+         # ── Step 5: Compute loss (during training) ──
          loss = None
          if targets is not None:
-             # Cross-Entropy Loss: 다음 토큰 예측
+             # Cross-Entropy Loss: next-token prediction
              # logits: (B, S, V) → (B*S, V)
              # targets: (B, S) → (B*S,)
              loss = F.cross_entropy(
                  logits.view(-1, self.config.vocab_size),
                  targets.view(-1),
-                 ignore_index=-100,  # 패딩 토큰 무시
+                 ignore_index=-100,  # ignore padding tokens
              )

          return logits, loss

      def count_parameters(self, trainable_only: bool = True) -> int:
-         """모델 파라미터 수 계산."""
+         """Count the number of model parameters."""
          if trainable_only:
              return sum(p.numel() for p in self.parameters() if p.requires_grad)
          return sum(p.numel() for p in self.parameters())
@@ -151,50 +151,50 @@ class LLMModel(nn.Module):
          top_k: int = 50,
          top_p: float = 0.9,
      ) -> torch.Tensor:
-         """텍스트 생성 (추론).
+         """Text generation (inference).

-         Autoregressive 생성: 한 토큰씩 예측하여 이어붙이기.
+         Autoregressive generation: predicts and appends one token at a time.

          Args:
-             input_ids: (1, prompt_len) - 초기 프롬프트
-             max_new_tokens: 생성할 최대 토큰 수
-             temperature: 확률 분포 날카로움 조절 (낮을수록 보수적)
-             top_k: 확률 상위 k개만 고려
-             top_p: 누적 확률 p까지만 고려 (nucleus sampling)
+             input_ids: (1, prompt_len) - initial prompt
+             max_new_tokens: maximum number of tokens to generate
+             temperature: controls sharpness of probability distribution (lower = more conservative)
+             top_k: consider only the top k tokens by probability
+             top_p: consider only tokens up to cumulative probability p (nucleus sampling)
          """
          self.eval()
          generated = input_ids

          for _ in range(max_new_tokens):
-             # 현재 시퀀스가 max_seq_len을 초과하면 잘라내기
+             # Truncate if current sequence exceeds max_seq_len
              ctx = generated[:, -self.config.max_seq_len:]

              # Forward pass
              logits, _ = self(ctx)
-             # 마지막 토큰의 logits만 사용 (다음 토큰 예측)
+             # Use only the last token's logits (next-token prediction)
              next_logits = logits[:, -1, :] / temperature

-             # ── Top-K 필터링 ──
+             # ── Top-K filtering ──
              if top_k > 0:
                  top_k_values, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
                  min_top_k = top_k_values[:, -1].unsqueeze(-1)
                  next_logits = next_logits.masked_fill(next_logits < min_top_k, float("-inf"))

-             # ── Top-P (Nucleus) 필터링 ──
+             # ── Top-P (Nucleus) filtering ──
              if top_p < 1.0:
                  sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-                 # 누적 확률이 top_p를 초과하는 토큰 제거
+                 # Remove tokens where cumulative probability exceeds top_p
                  remove_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
                  sorted_logits[remove_mask] = float("-inf")
-                 # 원래 순서로 복원
+                 # Restore original order
                  next_logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)

-             # 확률 분포에서 샘플링
+             # Sample from probability distribution
              probs = F.softmax(next_logits, dim=-1)
              next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

-             # 생성된 토큰 이어붙이기
+             # Append generated token
              generated = torch.cat([generated, next_token], dim=1)

          return generated
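The top-k / top-p filtering steps inside `generate()` can be isolated and run on a toy distribution to see exactly which logits survive. This sketch restructures the scatter step slightly (scattering into a fresh `-inf` tensor rather than into `sorted_logits` itself, which is equivalent for the kept entries and easier to reason about); the toy logits are arbitrary:

```python
import torch
import torch.nn.functional as F

def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    """Top-k then top-p (nucleus) filtering over the last dimension."""
    if top_k > 0:
        # Mask everything below the k-th largest logit
        kth = torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        cum = torch.cumsum(probs, dim=-1)
        # Drop a token once the mass *before* it already reaches top_p
        remove = cum - probs >= top_p
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    return logits

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
top3 = filter_logits(logits, top_k=3)       # two smallest logits are masked
nucleus = filter_logits(logits, top_p=0.6)  # the top token alone carries > 0.6 mass here
```

Note the `cum - probs` trick: it guarantees the highest-probability token is always kept, even when its probability alone exceeds `top_p`.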
llm_lab/model/norm.py CHANGED
@@ -5,36 +5,36 @@ import torch.nn as nn


  class RMSNorm(nn.Module):
-     """RMSNorm: LayerNorm의 경량화 버전.
+     """RMSNorm: A lightweight alternative to LayerNorm.

-     일반 LayerNorm과의 차이:
-     - 평균(mean)을 빼지 않음 → 연산 절약
-     - 분산 대신 RMS(Root Mean Square)로 정규화
-     - bias 파라미터 없음
+     Differences from standard LayerNorm:
+     - Does not subtract the mean → saves computation
+     - Normalizes using RMS (Root Mean Square) instead of variance
+     - No bias parameter

-     수식:
+     Formula:
          RMSNorm(x) = (x / RMS(x)) * γ
          RMS(x) = sqrt(mean(x²) + ε)

-     왜 정규화가 필요한가?
-     레이어를 깊게 쌓으면 활성화 값의 스케일이 폭발하거나 소멸합니다.
-     정규화로 각 레이어의 입력을 안정적인 범위로 유지합니다.
+     Why is normalization necessary?
+     Stacking layers deeply causes activation values to explode or vanish.
+     Normalization keeps the input to each layer within a stable range.
      """

      def __init__(self, dim: int, eps: float = 1e-6):
          super().__init__()
          self.eps = eps
-         # γ (gamma): 학습 가능한 스케일 파라미터, 1로 초기화
+         # γ (gamma): learnable scale parameter, initialized to 1
          self.weight = nn.Parameter(torch.ones(dim))

      def forward(self, x: torch.Tensor) -> torch.Tensor:
-         # 1) 입력을 float32로 변환 (수치 안정성)
-         #    bf16/fp16 상태에서 제곱합을 구하면 오버플로우 위험
+         # 1) Cast input to float32 for numerical stability
+         #    Computing the sum of squares in bf16/fp16 risks overflow
          x_float = x.float()

-         # 2) RMS 계산: sqrt(mean(x²) + ε)
+         # 2) Compute RMS: sqrt(mean(x²) + ε)
          rms = torch.rsqrt(x_float.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-         # rsqrt = 1/sqrt(x) → 나눗셈 대신 곱셈으로 대체 (더 빠름)
+         # rsqrt = 1/sqrt(x) → replaces division with multiplication (faster)

-         # 3) 정규화 후 원래 dtype으로 복원, 스케일 적용
+         # 3) Normalize, restore original dtype, and apply scale
          return (x_float * rms).to(x.dtype) * self.weight
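The normalization property the docstring above claims — each vector is rescaled to unit RMS — is easy to check numerically. This is a self-contained copy of the translated module (dimensions `(2, 4, 8)` and the ×100 input scale are arbitrary test values):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """RMSNorm as in the diff: no mean subtraction, no bias, float32 internals."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_float = x.float()
        rms = torch.rsqrt(x_float.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x_float * rms).to(x.dtype) * self.weight

norm = RMSNorm(8)
x = torch.randn(2, 4, 8) * 100.0   # deliberately large-scale activations
y = norm(x)
# With weight initialized to 1, every output vector should have RMS ≈ 1
rms_out = y.pow(2).mean(dim=-1).sqrt()
```

Scaling the input by 100 and still getting unit-RMS output is exactly the stabilizing effect described in the docstring.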
llm_lab/model/rope.py CHANGED
@@ -7,21 +7,23 @@ import torch.nn as nn


  class RotaryPositionalEmbedding(nn.Module):
-     """RoPE: 회전 행렬을 이용한 상대 위치 인코딩.
+     """RoPE: Relative positional encoding using rotation matrices.

-     핵심 아이디어:
-     - 각 차원 쌍 (2i, 2i+1)을 2D 평면의 좌표로 보고,
-       위치(position)에 비례한 각도만큼 회전시킵니다.
-     - 두 토큰의 어텐션 스코어(Q·K)는 상대 거리에만 의존하게 됩니다.
+     Core idea:
+     - Each dimension pair (2i, 2i+1) is treated as coordinates in a 2D plane,
+       and is rotated by an angle proportional to the position.
+     - The attention score (Q·K) between two tokens depends only on their
+       relative distance.

-     왜 RoPE인가?
-     - 절대 위치 임베딩: 위치에 고정 벡터를 더함 → 길이 일반화 어려움
-     - 상대 위치 임베딩: 구현 복잡, 추가 파라미터 필요
-     - RoPE: 파라미터 없이, 자연스럽게 상대 위치 정보 인코딩
+     Why RoPE?
+     - Absolute positional embeddings: add a fixed vector at each position
+       → difficult to generalize to longer sequences
+     - Relative positional embeddings: complex implementation, extra parameters needed
+     - RoPE: encodes relative position information naturally with no extra parameters

-     수식:
+     Formula:
          θ_i = theta^(-2i/d)  (i = 0, 1, ..., d/2-1)
-     RoPE(x, pos) = x의 각 차원 쌍에서 pos × θ_i 만큼 회전
+     RoPE(x, pos) = rotate x in each dimension pair by pos × θ_i
      """

      def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
@@ -30,16 +32,16 @@ class RotaryPositionalEmbedding(nn.Module):
          self.max_seq_len = max_seq_len
          self.theta = theta

-         # 주파수 벡터 미리 계산 (학습 불필요 → buffer로 등록)
+         # Pre-compute frequency vector (no training needed → register as buffer)
          # freqs[i] = 1 / (theta^(2i/dim)),  i = 0, 1, ..., dim/2-1
          freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
          self.register_buffer("freqs", freqs, persistent=False)

-         # (max_seq_len, dim/2) 크기의 cos/sin 테이블 미리 계산
+         # Pre-compute cos/sin table of shape (max_seq_len, dim/2)
          self._build_cache(max_seq_len)

      def _build_cache(self, seq_len: int):
-         """cos/sin 값을 미리 계산하여 캐싱합니다."""
+         """Pre-compute and cache cos/sin values."""
          t = torch.arange(seq_len, device=self.freqs.device, dtype=torch.float32)
          # outer product: (seq_len,) × (dim/2,) → (seq_len, dim/2)
          angles = torch.outer(t, self.freqs)
@@ -49,23 +51,23 @@ class RotaryPositionalEmbedding(nn.Module):
      def forward(
          self, q: torch.Tensor, k: torch.Tensor, position_offset: int = 0
      ) -> Tuple[torch.Tensor, torch.Tensor]:
-         """Q, K에 회전 변환을 적용합니다.
+         """Apply rotary transformation to Q and K.

          Args:
              q: (batch, num_heads, seq_len, head_dim)
              k: (batch, num_kv_heads, seq_len, head_dim)
-             position_offset: 시퀀스 시작 위치 오프셋 (추론 시 KV 캐시 사용 시)
+             position_offset: sequence start position offset (used with KV cache during inference)

          Returns:
-             회전 변환이 적용된 (q_rotated, k_rotated)
+             (q_rotated, k_rotated) with rotary transformation applied
          """
          seq_len = q.shape[2]

-         # 필요 시 캐시 확장
+         # Extend cache if needed
          if position_offset + seq_len > self.cos_cached.shape[0]:
              self._build_cache(position_offset + seq_len)

-         # 현재 위치에 해당하는 cos/sin 슬라이스
+         # Slice cos/sin values for the current positions
          cos = self.cos_cached[position_offset : position_offset + seq_len]  # (seq_len, dim/2)
          sin = self.sin_cached[position_offset : position_offset + seq_len]
@@ -77,27 +79,27 @@ class RotaryPositionalEmbedding(nn.Module):
      def _apply_rotation(
          x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
      ) -> torch.Tensor:
-         """회전 변환 적용.
+         """Apply rotation transformation.

-         2D 회전 행렬:
+         2D rotation matrix:
              [cos θ, -sin θ] [x1]   [x1·cos θ - x2·sin θ]
              [sin θ,  cos θ] [x2] = [x1·sin θ + x2·cos θ]

-         이를 벡터 연산으로 효율적으로 구현합니다.
+         Implemented efficiently using vectorized operations.
          """
          # x: (batch, heads, seq_len, head_dim)
-         # 짝수/홀수 인덱스를 분리: (x0, x1, x2, x3, ...) → (x0, x2, ...), (x1, x3, ...)
-         x_even = x[..., 0::2]  # 짝수 인덱스
-         x_odd = x[..., 1::2]   # 홀수 인덱스
+         # Separate even/odd indices: (x0, x1, x2, x3, ...) → (x0, x2, ...), (x1, x3, ...)
+         x_even = x[..., 0::2]  # even indices
+         x_odd = x[..., 1::2]   # odd indices

-         # 브로드캐스팅을 위해 차원 맞춤: (seq_len, dim/2) → (1, 1, seq_len, dim/2)
+         # Adjust dimensions for broadcasting: (seq_len, dim/2) → (1, 1, seq_len, dim/2)
          cos = cos.unsqueeze(0).unsqueeze(0)
          sin = sin.unsqueeze(0).unsqueeze(0)

-         # 회전 적용
+         # Apply rotation
          rotated_even = x_even * cos - x_odd * sin
          rotated_odd = x_even * sin + x_odd * cos

-         # 다시 인터리빙: (even0, odd0, even1, odd1, ...)
+         # Re-interleave: (even0, odd0, even1, odd1, ...)
          out = torch.stack([rotated_even, rotated_odd], dim=-1)
-         return out.flatten(-2)  # 마지막 두 차원을 합쳐 원래 shape 복원
+         return out.flatten(-2)  # Merge last two dimensions to restore original shape
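The relative-position claim in the RoPE docstring — that Q·K after rotation depends only on the distance between positions — can be verified with a small standalone rotation function. This sketch operates on flat `(seq, dim)` tensors rather than the module's 4D layout; `dim=64` and the position pairs are arbitrary:

```python
import torch

def rope_rotate(x: torch.Tensor, pos: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    """Rotate each (even, odd) dimension pair of x by pos * θ_i.
    x: (seq, dim), pos: (seq,) float positions."""
    dim = x.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
    angles = pos[:, None] * freqs[None, :]                            # (seq, dim/2)
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.stack([x_even * cos - x_odd * sin,
                       x_even * sin + x_odd * cos], dim=-1)
    return out.flatten(-2)

q = torch.randn(64)
k = torch.randn(64)

def score(pos_q: float, pos_k: float) -> torch.Tensor:
    """Dot-product attention score between rotated q and k at given positions."""
    rq = rope_rotate(q[None], torch.tensor([pos_q]))
    rk = rope_rotate(k[None], torch.tensor([pos_k]))
    return (rq * rk).sum()
```

Since rotations are orthogonal, they also preserve vector norms — positional encoding comes for free without distorting magnitudes.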
llm_lab/model/transformer_block.py CHANGED
@@ -1,4 +1,4 @@
1
- """Transformer Block (하나의 레이어)."""
2
 
3
  from typing import Optional
4
 
@@ -12,32 +12,32 @@ from .feedforward import SwiGLUFeedForward
12
 
13
 
14
  class TransformerBlock(nn.Module):
15
- """하나의 Transformer 디코더 블록.
16
 
17
- 구조 (Pre-Norm 방식):
18
  x → RMSNorm → Attention → + (residual) → RMSNorm → FFN → + (residual) → out
19
 
20
  Pre-Norm vs Post-Norm:
21
- - Post-Norm (원래 Transformer): LayerNorm이 residual 이후
22
- → 깊은 모델에서 학습 불안정
23
- - Pre-Norm (GPT-2 이후 표준): LayerNorm이 sublayer 이전
24
- → gradient 흐름이 원활, 학습이 안정적
25
 
26
- Residual Connection의 역할:
27
- - 입력을 출력에 더함 → gradient가 레이어를 건너뛸 수 있는 "고속도로"
28
- - 22개 레이어를 쌓아도 학습이 가능한 핵심 이유
29
  """
30
 
31
  def __init__(self, config: ModelConfig, layer_idx: int):
32
  super().__init__()
33
  self.layer_idx = layer_idx
34
 
35
- # Pre-Norm: Attention 전 정규화
36
  self.attn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
37
  # Self-Attention
38
  self.attention = GroupedQueryAttention(config)
39
 
40
- # Pre-Norm: FFN 전 정규화
41
  self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
42
  # Feed-Forward Network
43
  self.feed_forward = SwiGLUFeedForward(config)
 
1
+ """Transformer Block (a single layer)."""
2
 
3
  from typing import Optional
4
 
 
12
 
13
 
14
  class TransformerBlock(nn.Module):
15
+ """A single Transformer decoder block.
16
 
17
+ Structure (Pre-Norm style):
18
  x → RMSNorm → Attention → + (residual) → RMSNorm → FFN → + (residual) → out
19
 
20
  Pre-Norm vs Post-Norm:
21
+ - Post-Norm (original Transformer): LayerNorm applied after the residual
22
+ → training instability in deep models
23
+ - Pre-Norm (standard since GPT-2): LayerNorm applied before the sublayer
24
+ → smooth gradient flow, stable training
25
 
26
+ Role of Residual Connection:
27
+ - Adds the input to the output → a "highway" that lets gradients skip layers
28
+ - The key reason training is feasible even with 22 stacked layers
29
  """
30
 
31
  def __init__(self, config: ModelConfig, layer_idx: int):
32
  super().__init__()
33
  self.layer_idx = layer_idx
34
 
35
+ # Pre-Norm: normalization before Attention
36
  self.attn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
37
  # Self-Attention
38
  self.attention = GroupedQueryAttention(config)
39
 
40
+ # Pre-Norm: normalization before FFN
41
  self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
42
  # Feed-Forward Network
43
  self.feed_forward = SwiGLUFeedForward(config)
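The Pre-Norm wiring described in the docstring can be sketched with stand-in sublayers (`PreNormBlock` and the `nn.Linear` stand-ins for attention/FFN are illustrative, not the repo's classes):

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Pre-Norm residual wiring: x + Sublayer(Norm(x)), applied twice."""
    def __init__(self, dim: int):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)       # RMSNorm in the actual model
        self.attention = nn.Linear(dim, dim)     # stand-in for GroupedQueryAttention
        self.ffn_norm = nn.LayerNorm(dim)
        self.feed_forward = nn.Linear(dim, dim)  # stand-in for SwiGLUFeedForward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attn_norm(x))    # residual 1
        x = x + self.feed_forward(self.ffn_norm(x))  # residual 2
        return x
```

If both sublayers output zero, the block reduces to the identity — the gradient "highway" the docstring refers to.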
llm_lab/model/utils.py CHANGED
@@ -1,4 +1,4 @@
1
- """모델 유틸리티 함수."""
2
 
3
  from __future__ import annotations
4
 
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
12
 
13
 
14
  def count_parameters_detailed(model: "LLMModel") -> dict:
15
- """모델의 파라미터 수를 컴포넌트별로 상세 출력합니다."""
16
  total = 0
17
  breakdown = {}
18
 
@@ -21,7 +21,7 @@ def count_parameters_detailed(model: "LLMModel") -> dict:
21
  breakdown["token_embedding"] = emb_params
22
  total += emb_params
23
 
24
- # 레이어
25
  layer_total = 0
26
  layer_detail = {}
27
  layer = model.layers[0]
@@ -40,7 +40,7 @@ def count_parameters_detailed(model: "LLMModel") -> dict:
40
  breakdown["final_norm"] = norm_params
41
  total += norm_params
42
 
43
- # LM head (weight tying이므로 실제 추가 파라미터 0)
44
  breakdown["lm_head"] = "weight tying (0 additional)"
45
  breakdown["total"] = total
46
 
@@ -48,12 +48,12 @@ def count_parameters_detailed(model: "LLMModel") -> dict:
48
 
49
 
50
  def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict:
51
- """모델의 GPU 메모리 사용량을 추정합니다.
52
 
53
  Args:
54
- dtype_bytes: 2 (bf16/fp16) 또는 4 (fp32)
55
  """
56
- # 대략적인 파라미터 수 계산
57
  emb = config.vocab_size * config.hidden_dim
58
  per_layer = (
59
  config.hidden_dim * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim # QKV
@@ -67,11 +67,11 @@ def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: in
67
  optimizer_gb = total_params * 8 / 1e9 # AdamW: 2 states × fp32
68
  gradient_gb = total_params * dtype_bytes / 1e9
69
 
70
- # 활성화 메모리 (activation checkpointing 적용 가정)
71
- # 대략적 추정: batch_size × seq_len × hidden_dim × num_layers × factor
72
  activation_gb = (
73
- batch_size * config.max_seq_len * config.hidden_dim * 4 # 바이트
74
- * math.sqrt(config.num_layers) # checkpointing 효과
75
  / 1e9
76
  )
77
 
 
1
+ """Model utility functions."""
2
 
3
  from __future__ import annotations
4
 
 
12
 
13
 
14
  def count_parameters_detailed(model: "LLMModel") -> dict:
15
+ """Print a detailed breakdown of the model's parameter count by component."""
16
  total = 0
17
  breakdown = {}
18
 
 
21
  breakdown["token_embedding"] = emb_params
22
  total += emb_params
23
 
24
+ # Per layer
25
  layer_total = 0
26
  layer_detail = {}
27
  layer = model.layers[0]
 
40
  breakdown["final_norm"] = norm_params
41
  total += norm_params
42
 
43
+ # LM head (weight tying, so 0 additional parameters)
44
  breakdown["lm_head"] = "weight tying (0 additional)"
45
  breakdown["total"] = total
46
 
 
48
 
49
 
50
  def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict:
51
+ """Estimate GPU memory usage of the model.
52
 
53
  Args:
54
+ dtype_bytes: 2 (bf16/fp16) or 4 (fp32)
55
  """
56
+ # Approximate parameter count
57
  emb = config.vocab_size * config.hidden_dim
58
  per_layer = (
59
  config.hidden_dim * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim # QKV
 
67
  optimizer_gb = total_params * 8 / 1e9 # AdamW: 2 states × fp32
68
  gradient_gb = total_params * dtype_bytes / 1e9
69
 
70
+ # Activation memory (assuming activation checkpointing is applied)
71
+ # Rough estimate: batch_size × seq_len × hidden_dim × num_layers × factor
72
  activation_gb = (
73
+ batch_size * config.max_seq_len * config.hidden_dim * 4 # bytes
74
+ * math.sqrt(config.num_layers) # effect of checkpointing
75
  / 1e9
76
  )
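The static part of the budget in `estimate_memory_gb` is simple arithmetic: weights and gradients at `dtype_bytes` each, plus two fp32 AdamW states (8 bytes) per parameter. A standalone sketch of just that part (the helper name is ours):

```python
def static_training_memory_gb(total_params: int, dtype_bytes: int = 2) -> dict:
    """Weights + gradients + AdamW optimizer states, in GB (activations excluded)."""
    weights_gb = total_params * dtype_bytes / 1e9
    gradient_gb = total_params * dtype_bytes / 1e9
    optimizer_gb = total_params * 8 / 1e9  # AdamW: m and v states, both fp32
    return {
        "weights_gb": weights_gb,
        "gradient_gb": gradient_gb,
        "optimizer_gb": optimizer_gb,
        "total_gb": weights_gb + gradient_gb + optimizer_gb,
    }
```

For a 1.1B-parameter model in bf16 this already comes to ~13 GB before any activations, which is also why the optimizer state dominates checkpoint size.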
77
 
llm_lab/training/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- """학습 모듈 — Gradient Accumulation, Mixed Precision, 체크포인트, wandb 로깅."""
2
  from .scheduler import CosineWarmupScheduler
3
  from .checkpoint import CheckpointManager
4
  from .metrics import MetricsTracker
 
1
+ """Training module — Gradient Accumulation, Mixed Precision, checkpointing, wandb logging."""
2
  from .scheduler import CosineWarmupScheduler
3
  from .checkpoint import CheckpointManager
4
  from .metrics import MetricsTracker
llm_lab/training/checkpoint.py CHANGED
@@ -1,4 +1,4 @@
1
- """학습 상태 저장/복원 관리자."""
2
 
3
  import json
4
  import shutil
@@ -13,22 +13,22 @@ from llm_lab.config import TrainConfig
13
 
14
 
15
  class CheckpointManager:
16
- """학습 상태 저장/복원 관리자.
17
-
18
- Colab에서 체크포인트가 중요한 이유:
19
- - 세션 만료 (최대 ~24시간) 시 모든 메모리 상태 소멸
20
- - Google Drive에 저장하면 세션 간 연속 학습 가능
21
- - 옵티마이저 상태까지 저장해야 AdamW 모멘텀이 유지됨
22
-
23
- 저장 내용:
24
- - model_state_dict: 모델 가중치
25
- - optimizer_state_dict: 옵티마이저 상태 (m, v 모멘텀)
26
- - step: 현재 학습 스텝
27
- - best_val_loss: 최저 검증 Loss
28
- - config: 학습 설정 (재현성)
29
- - rng_states: 랜덤 시드 상태 (완전 재현)
30
- - metrics_history: 학습 메트릭 기록
31
- - wandb_run_id: wandb 실행 ID (로깅 연속성)
32
  """
33
 
34
  def __init__(self, config: TrainConfig):
@@ -46,20 +46,20 @@ class CheckpointManager:
46
  metrics_history: Dict[str, list],
47
  wandb_run_id: Optional[str] = None,
48
  ):
49
- """체크포인트를 저장합니다."""
50
  ckpt_path = self.checkpoint_dir / f"step_{step:06d}"
51
  ckpt_path.mkdir(parents=True, exist_ok=True)
52
 
53
- print(f"\n💾 체크포인트 저장: {ckpt_path}")
54
  start = time.time()
55
 
56
- # 1) 모델 가중치 (bf16 상태 그대로)
57
  torch.save(model.state_dict(), ckpt_path / "model.pt")
58
 
59
- # 2) 옵티마이저 상태 (fp32 모멘텀 포함, 크기 큼)
60
  torch.save(optimizer.state_dict(), ckpt_path / "optimizer.pt")
61
 
62
- # 3) 학습 메타 정보
63
  meta = {
64
  "step": step,
65
  "best_val_loss": best_val_loss,
@@ -69,10 +69,10 @@ class CheckpointManager:
69
  with open(ckpt_path / "meta.json", "w") as f:
70
  json.dump(meta, f, indent=2)
71
 
72
- # 4) 메트릭 기록
73
  torch.save(metrics_history, ckpt_path / "metrics.pt")
74
 
75
- # 5) 랜덤 상태 (완전 재현을 위해)
76
  rng_states = {
77
  "python": torch.random.get_rng_state(),
78
  "cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
@@ -81,9 +81,9 @@ class CheckpointManager:
81
 
82
  elapsed = time.time() - start
83
  ckpt_size = sum(f.stat().st_size for f in ckpt_path.rglob("*")) / 1e9
84
- print(f" 저장 완료: {ckpt_size:.2f} GB, {elapsed:.1f}초")
85
 
86
- # 오래된 체크포인트 삭제 (롤링)
87
  self._cleanup_old_checkpoints()
88
 
89
  def load_latest(
@@ -92,42 +92,42 @@ class CheckpointManager:
92
  optimizer: Optional[torch.optim.Optimizer] = None,
93
  device: torch.device = torch.device("cpu"),
94
  ) -> Dict[str, Any]:
95
- """가장 최근 체크포인트를 로드합니다.
96
 
97
  Returns:
98
  {"step", "best_val_loss", "wandb_run_id", "metrics_history"}
99
- 또는 체크포인트가 없으면 None
100
  """
101
  ckpt_path = self._find_latest()
102
  if ckpt_path is None:
103
- print("[Checkpoint] 저장된 체크포인트 없음. 처음부터 시작합니다.")
104
  return None
105
 
106
- print(f"\n📂 체크포인트 로드: {ckpt_path}")
107
  start = time.time()
108
 
109
- # 1) 모델 가중치
110
  model_state = torch.load(ckpt_path / "model.pt", map_location=device, weights_only=True)
111
  model.load_state_dict(model_state)
112
- del model_state # 메모리 해제
113
 
114
- # 2) 옵티마이저 상태
115
  if optimizer is not None:
116
  optim_state = torch.load(ckpt_path / "optimizer.pt", map_location=device, weights_only=True)
117
  optimizer.load_state_dict(optim_state)
118
  del optim_state
119
 
120
- # 3) 메타 정보
121
  with open(ckpt_path / "meta.json", "r") as f:
122
  meta = json.load(f)
123
 
124
- # 4) 메트릭 기록
125
  metrics_history = {}
126
  metrics_path = ckpt_path / "metrics.pt"
127
  if metrics_path.exists():
128
  metrics_history = torch.load(metrics_path, weights_only=False)
129
 
130
- # 5) 랜덤 상태 복원
131
  rng_path = ckpt_path / "rng_states.pt"
132
  if rng_path.exists():
133
  rng_states = torch.load(rng_path, weights_only=False)
@@ -136,7 +136,7 @@ class CheckpointManager:
136
  torch.cuda.set_rng_state(rng_states["cuda"])
137
 
138
  elapsed = time.time() - start
139
- print(f" 로드 완료: step={meta['step']}, {elapsed:.1f}초")
140
 
141
  return {
142
  "step": meta["step"],
@@ -146,14 +146,14 @@ class CheckpointManager:
146
  }
147
 
148
  def _find_latest(self) -> Optional[Path]:
149
- """가장 최근 체크포인트 경로를 찾습니다."""
150
  ckpts = sorted(self.checkpoint_dir.glob("step_*"))
151
  return ckpts[-1] if ckpts else None
152
 
153
  def _cleanup_old_checkpoints(self):
154
- """오래된 체크포인트를 삭제합니다 (롤링)."""
155
  ckpts = sorted(self.checkpoint_dir.glob("step_*"))
156
  while len(ckpts) > self.max_checkpoints:
157
  old = ckpts.pop(0)
158
- print(f" 🗑️ 오래된 체크포인트 삭제: {old.name}")
159
  shutil.rmtree(old)
 
1
+ """Training state save/restore manager."""
2
 
3
  import json
4
  import shutil
 
13
 
14
 
15
  class CheckpointManager:
16
+ """Training state save/restore manager.
17
+
18
+ Why checkpoints matter in Colab:
19
+ - Session expiry (up to ~24 hours) causes all in-memory state to be lost
20
+ - Saving to Google Drive enables continuous training across sessions
21
+ - Optimizer state must be saved to preserve AdamW momentum
22
+
23
+ Saved contents:
24
+ - model_state_dict: model weights
25
+ - optimizer_state_dict: optimizer state (m, v momentum)
26
+ - step: current training step
27
+ - best_val_loss: lowest validation loss
28
+ - config: training configuration (for reproducibility)
29
+ - rng_states: random seed state (full reproducibility)
30
+ - metrics_history: training metrics history
31
+ - wandb_run_id: wandb run ID (for logging continuity)
32
  """
33
 
34
  def __init__(self, config: TrainConfig):
 
46
  metrics_history: Dict[str, list],
47
  wandb_run_id: Optional[str] = None,
48
  ):
49
+ """Saves a checkpoint."""
50
  ckpt_path = self.checkpoint_dir / f"step_{step:06d}"
51
  ckpt_path.mkdir(parents=True, exist_ok=True)
52
 
53
+ print(f"\n💾 Saving checkpoint: {ckpt_path}")
54
  start = time.time()
55
 
56
+ # 1) Model weights (saved as-is in bf16)
57
  torch.save(model.state_dict(), ckpt_path / "model.pt")
58
 
59
+ # 2) Optimizer state (includes fp32 momentum, can be large)
60
  torch.save(optimizer.state_dict(), ckpt_path / "optimizer.pt")
61
 
62
+ # 3) Training metadata
63
  meta = {
64
  "step": step,
65
  "best_val_loss": best_val_loss,
 
69
  with open(ckpt_path / "meta.json", "w") as f:
70
  json.dump(meta, f, indent=2)
71
 
72
+ # 4) Metrics history
73
  torch.save(metrics_history, ckpt_path / "metrics.pt")
74
 
75
+ # 5) Random states (for full reproducibility)
76
  rng_states = {
77
  "python": torch.random.get_rng_state(),
78
  "cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
 
81
 
82
  elapsed = time.time() - start
83
  ckpt_size = sum(f.stat().st_size for f in ckpt_path.rglob("*")) / 1e9
84
+ print(f" Save complete: {ckpt_size:.2f} GB, {elapsed:.1f}s")
85
 
86
+ # Remove old checkpoints (rolling)
87
  self._cleanup_old_checkpoints()
88
 
89
  def load_latest(
 
92
  optimizer: Optional[torch.optim.Optimizer] = None,
93
  device: torch.device = torch.device("cpu"),
94
  ) -> Dict[str, Any]:
95
+ """Loads the most recent checkpoint.
96
 
97
  Returns:
98
  {"step", "best_val_loss", "wandb_run_id", "metrics_history"}
99
+ or None if no checkpoint exists
100
  """
101
  ckpt_path = self._find_latest()
102
  if ckpt_path is None:
103
+ print("[Checkpoint] No saved checkpoint found. Starting from scratch.")
104
  return None
105
 
106
+ print(f"\n📂 Loading checkpoint: {ckpt_path}")
107
  start = time.time()
108
 
109
+ # 1) Model weights
110
  model_state = torch.load(ckpt_path / "model.pt", map_location=device, weights_only=True)
111
  model.load_state_dict(model_state)
112
+ del model_state # free memory
113
 
114
+ # 2) Optimizer state
115
  if optimizer is not None:
116
  optim_state = torch.load(ckpt_path / "optimizer.pt", map_location=device, weights_only=True)
117
  optimizer.load_state_dict(optim_state)
118
  del optim_state
119
 
120
+ # 3) Metadata
121
  with open(ckpt_path / "meta.json", "r") as f:
122
  meta = json.load(f)
123
 
124
+ # 4) Metrics history
125
  metrics_history = {}
126
  metrics_path = ckpt_path / "metrics.pt"
127
  if metrics_path.exists():
128
  metrics_history = torch.load(metrics_path, weights_only=False)
129
 
130
+ # 5) Restore random states
131
  rng_path = ckpt_path / "rng_states.pt"
132
  if rng_path.exists():
133
  rng_states = torch.load(rng_path, weights_only=False)
 
136
  torch.cuda.set_rng_state(rng_states["cuda"])
137
 
138
  elapsed = time.time() - start
139
+ print(f" Load complete: step={meta['step']}, {elapsed:.1f}s")
140
 
141
  return {
142
  "step": meta["step"],
 
146
  }
147
 
148
  def _find_latest(self) -> Optional[Path]:
149
+ """Finds the path of the most recent checkpoint."""
150
  ckpts = sorted(self.checkpoint_dir.glob("step_*"))
151
  return ckpts[-1] if ckpts else None
152
 
153
  def _cleanup_old_checkpoints(self):
154
+ """Removes old checkpoints (rolling)."""
155
  ckpts = sorted(self.checkpoint_dir.glob("step_*"))
156
  while len(ckpts) > self.max_checkpoints:
157
  old = ckpts.pop(0)
158
+ print(f" 🗑️ Removing old checkpoint: {old.name}")
159
  shutil.rmtree(old)
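The rolling cleanup relies on the zero-padded `step_{step:06d}` directory names, so lexicographic sorting equals numeric step order. A self-contained sketch of that logic (the function name is ours):

```python
import shutil
from pathlib import Path

def cleanup_old_checkpoints(checkpoint_dir: Path, max_checkpoints: int) -> list:
    """Delete the oldest step_* directories until at most max_checkpoints remain."""
    ckpts = sorted(checkpoint_dir.glob("step_*"))  # zero-padding => lexicographic == numeric
    removed = []
    while len(ckpts) > max_checkpoints:
        old = ckpts.pop(0)      # oldest checkpoint first
        shutil.rmtree(old)
        removed.append(old.name)
    return removed
```

Without the zero padding, `step_1000` would sort before `step_200` and the wrong checkpoint would be deleted.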
llm_lab/training/metrics.py CHANGED
@@ -1,4 +1,4 @@
1
- """학습 메트릭 추적 로깅."""
2
 
3
  from typing import Dict, Optional
4
 
@@ -8,16 +8,16 @@ from llm_lab.config import TrainConfig
8
 
9
 
10
  class MetricsTracker:
11
- """학습 메트릭을 추적하고 로깅합니다.
12
-
13
- 추적 항목:
14
- - train/loss: 학습 Loss (Cross-Entropy)
15
- - train/lr: 현재 학습률
16
- - train/grad_norm: Gradient L2 Norm
17
- - train/tokens_per_sec: 처리량
18
- - train/gpu_mem_gb: GPU 메모리 사용량
19
- - val/loss: 검증 Loss
20
- - val/perplexity: 검증 Perplexity (= exp(loss))
21
  """
22
 
23
  def __init__(self, config: TrainConfig):
@@ -33,13 +33,13 @@ class MetricsTracker:
33
  "val_ppl": [],
34
  }
35
 
36
- # wandb 초기화
37
  self.wandb_run = None
38
  if config.use_wandb:
39
  self._init_wandb()
40
 
41
  def _init_wandb(self, resume_id: Optional[str] = None):
42
- """wandb 초기화 (세션 연속 로깅 지원)."""
43
  try:
44
  import wandb
45
 
@@ -51,16 +51,16 @@ class MetricsTracker:
51
  resume="allow",
52
  config=self.config.__dict__,
53
  )
54
- print(f"[wandb] 초기화 완료: {self.wandb_run.url}")
55
  except ImportError:
56
- print("[wandb] 설치되지 않음. 콘솔 로깅만 사용합니다.")
57
  self.config.use_wandb = False
58
  except Exception as e:
59
- print(f"[wandb] 초기화 실패: {e}. 콘솔 로깅만 사용합니다.")
60
  self.config.use_wandb = False
61
 
62
  def resume_wandb(self, run_id: str):
63
- """이전 wandb 실행을 이어서 로깅합니다."""
64
  if self.config.use_wandb:
65
  self._init_wandb(resume_id=run_id)
66
 
@@ -73,7 +73,7 @@ class MetricsTracker:
73
  tokens_per_sec: float,
74
  gpu_mem_gb: float,
75
  ):
76
- """학습 스텝 메트릭을 기록합니다."""
77
  self.history["step"].append(step)
78
  self.history["train_loss"].append(loss)
79
  self.history["learning_rate"].append(lr)
@@ -93,7 +93,7 @@ class MetricsTracker:
93
  }, step=step)
94
 
95
  def log_eval(self, step: int, val_loss: float, val_ppl: float):
96
- """검증 메트릭을 기록합니다."""
97
  self.history["val_loss"].append(val_loss)
98
  self.history["val_ppl"].append(val_ppl)
99
 
 
1
+ """Training metrics tracking and logging."""
2
 
3
  from typing import Dict, Optional
4
 
 
8
 
9
 
10
  class MetricsTracker:
11
+ """Tracks and logs training metrics.
12
+
13
+ Tracked items:
14
+ - train/loss: training loss (Cross-Entropy)
15
+ - train/lr: current learning rate
16
+ - train/grad_norm: gradient L2 norm
17
+ - train/tokens_per_sec: throughput
18
+ - train/gpu_mem_gb: GPU memory usage
19
+ - val/loss: validation loss
20
+ - val/perplexity: validation perplexity (= exp(loss))
21
  """
22
 
23
  def __init__(self, config: TrainConfig):
 
33
  "val_ppl": [],
34
  }
35
 
36
+ # wandb initialization
37
  self.wandb_run = None
38
  if config.use_wandb:
39
  self._init_wandb()
40
 
41
  def _init_wandb(self, resume_id: Optional[str] = None):
42
+ """Initializes wandb (supports continuous logging across sessions)."""
43
  try:
44
  import wandb
45
 
 
51
  resume="allow",
52
  config=self.config.__dict__,
53
  )
54
+ print(f"[wandb] Initialized: {self.wandb_run.url}")
55
  except ImportError:
56
+ print("[wandb] Not installed. Using console logging only.")
57
  self.config.use_wandb = False
58
  except Exception as e:
59
+ print(f"[wandb] Initialization failed: {e}. Using console logging only.")
60
  self.config.use_wandb = False
61
 
62
  def resume_wandb(self, run_id: str):
63
+ """Resumes logging from a previous wandb run."""
64
  if self.config.use_wandb:
65
  self._init_wandb(resume_id=run_id)
66
 
 
73
  tokens_per_sec: float,
74
  gpu_mem_gb: float,
75
  ):
76
+ """Records training step metrics."""
77
  self.history["step"].append(step)
78
  self.history["train_loss"].append(loss)
79
  self.history["learning_rate"].append(lr)
 
93
  }, step=step)
94
 
95
  def log_eval(self, step: int, val_loss: float, val_ppl: float):
96
+ """Records validation metrics."""
97
  self.history["val_loss"].append(val_loss)
98
  self.history["val_ppl"].append(val_ppl)
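The `val/perplexity` item is just the exponential of the mean token-level cross-entropy, as noted in the docstring:

```python
import math

def perplexity(mean_ce_loss: float) -> float:
    """Perplexity = exp(loss): effective number of equally likely next-token choices."""
    return math.exp(mean_ce_loss)
```

A loss of ln(100) ≈ 4.61 therefore corresponds to choosing uniformly among 100 candidates.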
99
 
llm_lab/training/optimizer.py CHANGED
@@ -1,4 +1,4 @@
1
- """AdamW 옵티마이저 생성 (Weight Decay 분리)."""
2
 
3
  import torch
4
  import torch.nn as nn
@@ -7,19 +7,19 @@ from llm_lab.config import TrainConfig
7
 
8
 
9
  def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW:
10
- """AdamW 옵티마이저를 생성합니다.
11
 
12
- Weight Decay 분리 규칙:
13
- - Decay 적용: Linear 가중치 (attention proj, FFN 등)
14
- - Decay 미적용: Embedding, LayerNorm/RMSNorm, Bias
15
 
16
- 왜 분리하는가?
17
- - Weight Decay는 큰 가중치에 패널티를 주어 과적합 방지
18
- - 하지만 Norm scale 파라미터에 적용하면 정규화 효과를 방해
19
- - Embedding에 적용하면 희귀 토큰의 표현이 0으로 수축
20
- - 1D 파라미터(bias, norm weight)는 decay에서 제외하는 것이 관례
21
  """
22
- # 파라미터를 decay/no-decay 그룹으로 분리
23
  decay_params = []
24
  no_decay_params = []
25
 
@@ -27,7 +27,7 @@ def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW
27
  if not param.requires_grad:
28
  continue
29
 
30
- # 1D 텐서(bias, norm weight) 또는 embedding → no decay
31
  if param.dim() <= 1 or "embedding" in name:
32
  no_decay_params.append(param)
33
  else:
@@ -40,15 +40,15 @@ def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW
40
 
41
  n_decay = sum(p.numel() for p in decay_params)
42
  n_no_decay = sum(p.numel() for p in no_decay_params)
43
- print(f"[Optimizer] Decay 파라미터: {n_decay:,} ({n_decay/1e6:.1f}M)")
44
- print(f"[Optimizer] No-decay 파라미터: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)")
45
 
46
  optimizer = torch.optim.AdamW(
47
  param_groups,
48
  lr=config.learning_rate,
49
  betas=(config.beta1, config.beta2),
50
  eps=config.adam_eps,
51
- fused=torch.cuda.is_available(), # CUDA fused AdamW (더 빠름)
52
  )
53
 
54
  return optimizer
 
1
+ """AdamW optimizer creation with Weight Decay separation."""
2
 
3
  import torch
4
  import torch.nn as nn
 
7
 
8
 
9
  def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW:
10
+ """Creates an AdamW optimizer.
11
 
12
+ Weight Decay separation rules:
13
+ - Apply decay: Linear weights (attention proj, FFN, etc.)
14
+ - No decay: Embeddings, LayerNorm/RMSNorm, Bias
15
 
16
+ Why separate?
17
+ - Weight Decay penalizes large weights to prevent overfitting
18
+ - However, applying it to Norm scale parameters interferes with normalization
19
+ - Applying it to Embeddings causes rare token representations to shrink toward 0
20
+ - It is convention to exclude 1D parameters (bias, norm weight) from decay
21
  """
22
+ # Separate parameters into decay / no-decay groups
23
  decay_params = []
24
  no_decay_params = []
25
 
 
27
  if not param.requires_grad:
28
  continue
29
 
30
+ # 1D tensors (bias, norm weight) or embedding → no decay
31
  if param.dim() <= 1 or "embedding" in name:
32
  no_decay_params.append(param)
33
  else:
 
40
 
41
  n_decay = sum(p.numel() for p in decay_params)
42
  n_no_decay = sum(p.numel() for p in no_decay_params)
43
+ print(f"[Optimizer] Decay parameters: {n_decay:,} ({n_decay/1e6:.1f}M)")
44
+ print(f"[Optimizer] No-decay parameters: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)")
45
 
46
  optimizer = torch.optim.AdamW(
47
  param_groups,
48
  lr=config.learning_rate,
49
  betas=(config.beta1, config.beta2),
50
  eps=config.adam_eps,
51
+ fused=torch.cuda.is_available(), # CUDA fused AdamW (faster)
52
  )
53
 
54
  return optimizer
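The grouping rule can be checked on a toy module. This helper mirrors `create_optimizer`'s rule but returns parameter names instead of tensors so the split is easy to inspect (`split_decay_groups` is our name):

```python
import torch.nn as nn

def split_decay_groups(model: nn.Module):
    """Same rule as create_optimizer: 1D tensors and embeddings get no weight decay."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.dim() <= 1 or "embedding" in name:
            no_decay.append(name)  # bias, norm weights, token embedding
        else:
            decay.append(name)     # 2D linear / projection weights
    return decay, no_decay
```

Note that the rule matches the substring "embedding" in the parameter name, so it depends on the module attribute being called `token_embedding`, as in this repo.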
llm_lab/training/runner.py CHANGED
@@ -1,4 +1,4 @@
1
- """학습 실행 헬퍼 (Quick Start)."""
2
 
3
  from pathlib import Path
4
  from typing import Optional
@@ -20,49 +20,49 @@ def start_training(
20
  seq_len: int = 2048,
21
  auto_config: bool = True,
22
  ) -> Trainer:
23
- """학습을 시작합니다 ( 줄 실행).
24
 
25
- 사용법 (Colab):
26
  ```python
27
  from model import LLMModel, ModelConfig
28
  from data_pipeline import setup_data_pipeline, DataConfig
29
  from trainer import start_training, TrainConfig
30
 
31
- # 1. 모델 생성
32
  model_config = ModelConfig.base_1b()
33
  model = LLMModel(model_config)
34
 
35
- # 2. 데이터 파이프라인
36
  tok, train_dl, val_dl = setup_data_pipeline("pretrained")
37
 
38
- # 3. 학습 시작 (체크포인트 자동 복원)
39
  trainer = start_training(model, train_dl, val_dl)
40
  ```
41
  """
42
  config = config or TrainConfig()
43
 
44
- # GPU 자동 감지 및 설정 조정
45
  if auto_config:
46
  config = auto_configure(config)
47
 
48
- # Google Drive 마운트 확인 (Colab)
49
  if "/content/drive" in config.checkpoint_dir:
50
  drive_path = Path("/content/drive/MyDrive")
51
  if not drive_path.exists():
52
- print("\n⚠️ Google Drive가 마운트되지 않았습니다!")
53
- print(" Colab에서 실행: from google.colab import drive; drive.mount('/content/drive')")
54
- print(" 로컬 경로로 변경합니다.")
55
  config.checkpoint_dir = "./checkpoints"
56
 
57
- # 재현성 시드 설정
58
  torch.manual_seed(config.seed)
59
  if torch.cuda.is_available():
60
  torch.cuda.manual_seed(config.seed)
61
 
62
- # Trainer 생성 (체크포인트 자동 복원 포함)
63
  trainer = Trainer(model, train_dataloader, val_dataloader, config, seq_len)
64
 
65
- # 학습 실행
66
  trainer.train()
67
 
68
  return trainer
 
1
+ """Training execution helper (Quick Start)."""
2
 
3
  from pathlib import Path
4
  from typing import Optional
 
20
  seq_len: int = 2048,
21
  auto_config: bool = True,
22
  ) -> Trainer:
23
+ """Starts training (one-line execution).
24
 
25
+ Usage (Colab):
26
  ```python
27
  from model import LLMModel, ModelConfig
28
  from data_pipeline import setup_data_pipeline, DataConfig
29
  from trainer import start_training, TrainConfig
30
 
31
+ # 1. Create model
32
  model_config = ModelConfig.base_1b()
33
  model = LLMModel(model_config)
34
 
35
+ # 2. Data pipeline
36
  tok, train_dl, val_dl = setup_data_pipeline("pretrained")
37
 
38
+ # 3. Start training (automatic checkpoint restoration)
39
  trainer = start_training(model, train_dl, val_dl)
40
  ```
41
  """
42
  config = config or TrainConfig()
43
 
44
+ # Auto-detect GPU and adjust configuration
45
  if auto_config:
46
  config = auto_configure(config)
47
 
48
+ # Check Google Drive mount (Colab)
49
  if "/content/drive" in config.checkpoint_dir:
50
  drive_path = Path("/content/drive/MyDrive")
51
  if not drive_path.exists():
52
+ print("\n⚠️ Google Drive is not mounted!")
53
+ print(" Run in Colab: from google.colab import drive; drive.mount('/content/drive')")
54
+ print(" Switching to local path.")
55
  config.checkpoint_dir = "./checkpoints"
56
 
57
+ # Set reproducibility seed
58
  torch.manual_seed(config.seed)
59
  if torch.cuda.is_available():
60
  torch.cuda.manual_seed(config.seed)
61
 
62
+ # Create Trainer (includes automatic checkpoint restoration)
63
  trainer = Trainer(model, train_dataloader, val_dataloader, config, seq_len)
64
 
65
+ # Run training
66
  trainer.train()
67
 
68
  return trainer
llm_lab/training/scheduler.py CHANGED
@@ -1,4 +1,4 @@
1
- """Cosine Annealing with Linear Warmup 스케줄러."""
2
 
3
  import math
4
 
@@ -10,22 +10,22 @@ from llm_lab.config import TrainConfig
10
  class CosineWarmupScheduler:
11
  """Cosine Annealing with Linear Warmup.
12
 
13
- LR 곡선:
14
  ┌─── peak_lr ───────╲
15
  │ ╲ cosine decay
16
  │ warmup (linear) ╲
17
  │/ ╲_______ min_lr
18
  └──────────────────────────────────→ steps
19
 
20
- 왜 Cosine Decay인가?
21
- - Step decay: 갑작스러운 LR 하락 → Loss 불안정
22
- - Linear decay: 후반부 LR 너무 빨리 감소
23
- - Cosine: 부드러운 감소, 학습 후반에도 적절한 LR 유지
24
- - GPT-3, LLaMA, Chinchilla 등 대부분의 LLM이 사용
25
 
26
- 구현 참고:
27
- PyTorch 내장 스케줄러(CosineAnnealingLR 등)도 있지만,
28
- warmup + min_lr + 체크포인트 복원을 위해 직접 구현이 더 유연합니다.
29
  """
30
 
31
  def __init__(self, config: TrainConfig):
@@ -35,33 +35,33 @@ class CosineWarmupScheduler:
35
  self.total_steps = config.total_steps
36
 
37
  def get_lr(self, step: int) -> float:
38
- """현재 step에 해당하는 학습률을 반환합니다.
39
 
40
  Args:
41
- step: 현재 optimizer step (0-indexed)
42
 
43
  Returns:
44
- 학습률 (float)
45
  """
46
  # Phase 1: Linear Warmup
47
  if step < self.warmup_steps:
48
- # 0 → peak_lr 선형 증가
49
  return self.peak_lr * (step / self.warmup_steps)
50
 
51
  # Phase 2: Cosine Decay
52
- # warmup 이후 남은 진행률 (0.0 → 1.0)
53
  decay_steps = self.total_steps - self.warmup_steps
54
  progress = (step - self.warmup_steps) / max(decay_steps, 1)
55
- progress = min(progress, 1.0) # 안전장치
56
 
57
- # Cosine 공식: min_lr + 0.5 × (peak - min) × (1 + cos(π × progress))
58
  cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
59
  lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
60
 
61
  return lr
62
 
63
  def set_lr(self, optimizer: torch.optim.Optimizer, step: int):
64
- """Optimizer의 학습률을 업데이트합니다."""
65
  lr = self.get_lr(step)
66
  for param_group in optimizer.param_groups:
67
  param_group["lr"] = lr
 
1
+ """Cosine Annealing with Linear Warmup scheduler."""
2
 
3
  import math
4
 
 
10
  class CosineWarmupScheduler:
11
  """Cosine Annealing with Linear Warmup.
12
 
13
+ LR curve:
14
  ┌─── peak_lr ───────╲
15
  │ ╲ cosine decay
16
  │ warmup (linear) ╲
17
  │/ ╲_______ min_lr
18
  └──────────────────────────────────→ steps
19
 
20
+ Why Cosine Decay?
21
+ - Step decay: sudden LR drop → unstable loss
22
+ - Linear decay: LR decreases too quickly in the later stages
23
+ - Cosine: smooth decay, maintains appropriate LR even in the late training phase
24
+ - Used by most LLMs including GPT-3, LLaMA, and Chinchilla
25
 
26
+ Implementation note:
27
+ PyTorch has built-in schedulers (e.g., CosineAnnealingLR), but
28
+ a custom implementation is more flexible for warmup + min_lr + checkpoint restoration.
29
  """
30
 
31
  def __init__(self, config: TrainConfig):
 
35
  self.total_steps = config.total_steps
36
 
37
  def get_lr(self, step: int) -> float:
38
+ """Returns the learning rate for the current step.
39
 
40
  Args:
41
+ step: Current optimizer step (0-indexed)
42
 
43
  Returns:
44
+ Learning rate (float)
45
  """
46
  # Phase 1: Linear Warmup
47
  if step < self.warmup_steps:
48
+ # Linear increase from 0 to peak_lr
49
  return self.peak_lr * (step / self.warmup_steps)
50
 
51
  # Phase 2: Cosine Decay
52
+ # Progress ratio after warmup (0.0 → 1.0)
53
  decay_steps = self.total_steps - self.warmup_steps
54
  progress = (step - self.warmup_steps) / max(decay_steps, 1)
55
+ progress = min(progress, 1.0) # safety clamp
56
 
57
+ # Cosine formula: min_lr + 0.5 × (peak - min) × (1 + cos(π × progress))
58
  cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
59
  lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
60
 
61
  return lr
62
 
63
  def set_lr(self, optimizer: torch.optim.Optimizer, step: int):
64
+ """Updates the learning rate of the optimizer."""
65
  lr = self.get_lr(step)
66
  for param_group in optimizer.param_groups:
67
  param_group["lr"] = lr
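The two phases of `get_lr` can be reproduced as a pure function and spot-checked at the boundary points (`warmup_cosine_lr` is our name for this sketch):

```python
import math

def warmup_cosine_lr(step: int, warmup_steps: int, total_steps: int,
                     peak_lr: float, min_lr: float) -> float:
    if step < warmup_steps:                   # Phase 1: linear warmup, 0 -> peak_lr
        return peak_lr * (step / warmup_steps)
    decay_steps = total_steps - warmup_steps  # Phase 2: cosine decay to min_lr
    progress = min((step - warmup_steps) / max(decay_steps, 1), 1.0)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

At `step == warmup_steps` the cosine term is 1, so the decay phase starts exactly at `peak_lr`; at `step == total_steps` it is -1, giving `min_lr`.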
llm_lab/training/trainer.py CHANGED
@@ -1,4 +1,4 @@
1
- """LLM 사전학습 트레이너."""
2
 
3
  import math
4
  import time
@@ -16,9 +16,9 @@ from .optimizer import create_optimizer
16
 
17
 
18
  class Trainer:
19
- """LLM 사전학습 트레이너.
20
 
21
- 학습 루프의 핵심 구조:
22
  ```
23
  for step in range(total_steps):
24
  # ── Gradient Accumulation Loop ──
@@ -27,22 +27,22 @@ class Trainer:
27
  with autocast(bf16):
28
  logits, loss = model(input_ids, targets)
29
  scaled_loss = loss / accumulation_steps
30
- scaled_loss.backward() # gradient 누적
31
 
32
- # ── Optimizer Step (accumulation 완료 후) ──
33
  clip_grad_norm(model, max_norm=1.0)
34
  optimizer.step()
35
  optimizer.zero_grad()
36
  scheduler.set_lr(optimizer, step)
37
  ```
38
 
39
- Gradient Accumulation이란?
40
- - GPU 메모리에 큰 배치를 한 번에 올릴 수 없을 때
41
- - 작은 micro_batch로 여러 번 forward/backward → gradient를 누적
42
- - 누적 후 한 번에 optimizer step
43
- - 결과적으로 effective_batch와 동일한 효과
44
- - Loss를 accumulation_steps로 나누는 이유:
45
- gradient의 평균을 구하기 위해 (합이 아닌 평균)
46
  """
47
 
48
  def __init__(
@@ -56,52 +56,52 @@ class Trainer:
56
  self.config = config
57
  self.seq_len = seq_len
58
 
59
- # ── 디바이스 설정 ──
60
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
- print(f"[Trainer] 디바이스: {self.device}")
62
  if torch.cuda.is_available():
63
  print(f"[Trainer] GPU: {torch.cuda.get_device_name()}")
64
- print(f"[Trainer] GPU 메모리: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
65
 
66
- # ── 모델 ──
67
  self.model = model.to(self.device)
68
- # torch.compile: PyTorch 2.0+ 그래프 최적화 (속도 10-30% 향상)
69
  if torch.cuda.is_available() and hasattr(torch, "compile"):
70
- print("[Trainer] torch.compile 적용 중...")
71
  self.model = torch.compile(self.model)
72
 
73
- # ── 데이터 ──
74
  self.train_dataloader = train_dataloader
75
  self.val_dataloader = val_dataloader
76
  self.train_iter = iter(train_dataloader)
77
 
78
- # ── 옵티마이저 ──
79
  self.optimizer = create_optimizer(self.model, config)
80
 
81
- # ── 스케줄러 ──
82
  self.scheduler = CosineWarmupScheduler(config)
83
 
84
- # ── 체크포인트 ──
85
  self.ckpt_manager = CheckpointManager(config)
86
 
87
- # ── 메트릭 ──
88
  self.metrics = MetricsTracker(config)
89
 
90
- # ── 학습 상태 ──
91
  self.global_step = 0
92
  self.best_val_loss = float("inf")
93
  self.tokens_seen = 0
94
 
95
  # ── Mixed Precision ──
96
- # bf16은 GradScaler가 불필요 (fp16일 때만 필요)
97
  self.use_amp = config.dtype != "float32"
98
  self.amp_dtype = config.torch_dtype
99
 
100
- # ── 자동 복원 시도 ──
101
  self._try_resume()
102
 
103
  def _try_resume(self):
104
- """이전 체크포인트가 있으면 자동으로 복원합니다."""
105
  result = self.ckpt_manager.load_latest(
106
  self.model, self.optimizer, self.device
107
  )
@@ -111,20 +111,20 @@ class Trainer:
111
  self.best_val_loss = result["best_val_loss"]
112
  self.metrics.history = result.get("metrics_history", self.metrics.history)
113
 
114
- # wandb 연속 로깅
115
  if result.get("wandb_run_id"):
116
  self.metrics.resume_wandb(result["wandb_run_id"])
117
 
118
  self.tokens_seen = self.global_step * self.config.effective_batch_size * self.seq_len
119
- print(f"[Trainer] 학습 재개: step={self.global_step}, "
120
  f"tokens={self.tokens_seen/1e9:.2f}B, "
121
  f"best_val_loss={self.best_val_loss:.4f}")
122
 
123
  def _get_next_batch(self) -> Dict[str, torch.Tensor]:
124
- """다음 학습 배치를 가져옵니다.
125
 
126
- Streaming DataLoader 에폭 개념이 없으므로,
127
- StopIteration 이터레이터를 생성합니다.
128
  """
129
  try:
130
  batch = next(self.train_iter)
@@ -138,14 +138,14 @@ class Trainer:
138
  }
139
 
140
  def _train_step(self) -> Tuple[float, float]:
141
- """하나의 optimizer step을 수행합니다.
142
 
143
  Returns:
144
  (loss, grad_norm)
145
  """
146
  self.model.train()
147
  self.optimizer.zero_grad(set_to_none=True)
148
- # set_to_none=True: gradient를 None으로 설정메모리 절약
149
 
150
  total_loss = 0.0
151
 
@@ -157,16 +157,16 @@ class Trainer:
157
  with torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
158
  logits, loss = self.model(batch["input_ids"], batch["targets"])
159
 
160
- # Loss 스케일링: effective batch의 평균을 위해
161
  scaled_loss = loss / self.config.gradient_accumulation_steps
162
  total_loss += loss.item()
163
 
164
- # Backward (gradient 누적)
165
  scaled_loss.backward()
166
 
167
  # ── Gradient Clipping ──
168
- # 모든 파라미터의 gradient를 하나의 벡터로 보고 L2 norm 계산
169
- # norm max_norm 초과하면 비례적으로 스케일 다운
170
  grad_norm = torch.nn.utils.clip_grad_norm_(
171
  self.model.parameters(),
172
  max_norm=self.config.grad_clip,
@@ -175,7 +175,7 @@ class Trainer:
175
  # ── Optimizer Step ──
176
  self.optimizer.step()
177
 
178
- # ── LR 업데이트 ──
179
  self.scheduler.set_lr(self.optimizer, self.global_step)
180
 
181
  avg_loss = total_loss / self.config.gradient_accumulation_steps
@@ -183,13 +183,13 @@ class Trainer:
183
 
184
  @torch.no_grad()
185
  def _evaluate(self) -> Tuple[float, float]:
186
- """검증 데이터에서 Loss Perplexity 측정합니다.
187
 
188
  Perplexity = exp(loss)
189
- - 직관: "모델이 다음 토큰을 평균 개의 후보 중에서 고르는가"
190
- - PPL 100 → 100개 1개를 균일하게 고르는 수준
191
- - PPL 20 → 20개 1개 수준 ( 좋음)
192
- - PPL 10 → 매우 자신있게 예측
193
  """
194
  if self.val_dataloader is None:
195
  return float("inf"), float("inf")
@@ -212,36 +212,37 @@ class Trainer:
212
  num_batches += 1
213
 
214
  avg_loss = total_loss / max(num_batches, 1)
215
- perplexity = math.exp(min(avg_loss, 20)) # overflow 방지 (exp(20) ≈ 5억)
216
 
217
  return avg_loss, perplexity
218
 
219
  def train(self):
220
- """메인 학습 루프.
221
 
222
- 메서드가 전체 학습을 실행합니다.
223
- Colab 세션 만료 중단되어도 체크포인트에서 자동 재개됩니다.
 
224
  """
225
  config = self.config
226
 
227
  print("\n" + "=" * 70)
228
- print("🚀 학습 시작")
229
  print("=" * 70)
230
- print(f" 스텝: {config.total_steps:,}")
231
- print(f" 시작 스텝: {self.global_step}")
232
  print(f" Effective batch size: {config.effective_batch_size}")
233
- print(f" 토큰/스텝: {config.effective_batch_size * self.seq_len:,}")
234
- print(f" 학습 토큰 (예상): {config.total_steps * config.effective_batch_size * self.seq_len / 1e9:.1f}B")
235
  print(f" Mixed Precision: {config.dtype}")
236
  print(f" Gradient Accumulation: {config.gradient_accumulation_steps}")
237
- print(f" 체크포인트: {config.checkpoint_dir}")
238
  print("=" * 70 + "\n")
239
 
240
  step_start_time = time.time()
241
  tokens_at_log_start = self.tokens_seen
242
 
243
  # ════════════════════════════════════════════
244
- # 메인 루프
245
  # ════════════════════════════════════════════
246
 
247
  while self.global_step < config.total_steps:
@@ -257,21 +258,21 @@ class Trainer:
257
  tokens_delta = self.tokens_seen - tokens_at_log_start
258
  tokens_per_sec = tokens_delta / max(elapsed, 1e-6)
259
 
260
- # GPU 메모리
261
  gpu_mem_gb = 0.0
262
  if torch.cuda.is_available():
263
  gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9
264
 
265
- # 현재 LR
266
  current_lr = self.scheduler.get_lr(self.global_step)
267
 
268
- # 남은 시간 추정
269
  remaining_steps = config.total_steps - self.global_step
270
  steps_per_sec = config.log_interval / max(elapsed, 1e-6)
271
  eta_seconds = remaining_steps / max(steps_per_sec, 1e-6)
272
  eta_hours = eta_seconds / 3600
273
 
274
- # 콘솔 출력
275
  print(
276
  f" Step {self.global_step:>6d}/{config.total_steps} │ "
277
  f"Loss {loss:.4f} │ "
@@ -283,7 +284,7 @@ class Trainer:
283
  f"Tokens {self.tokens_seen/1e9:.2f}B"
284
  )
285
 
286
- # wandb 로깅
287
  self.metrics.log_train_step(
288
  step=self.global_step,
289
  loss=loss,
@@ -324,19 +325,19 @@ class Trainer:
324
  )
325
 
326
  # ════════════════════════════════════════════
327
- # 학습 완료
328
  # ════════════════════════════════════════════
329
 
330
  print("\n" + "=" * 70)
331
- print("🎉 학습 완료!")
332
  print("=" * 70)
333
- print(f" 스텝: {self.global_step:,}")
334
- print(f" 토큰: {self.tokens_seen/1e9:.2f}B")
335
- print(f" 최저 Val Loss: {self.best_val_loss:.4f}")
336
- print(f" 최저 Val PPL: {math.exp(min(self.best_val_loss, 20)):.2f}")
337
  print("=" * 70)
338
 
339
- # 최종 체크포인트 저장
340
  self.ckpt_manager.save(
341
  model=self.model,
342
  optimizer=self.optimizer,
 
+ """LLM pretraining trainer."""

  import math
  import time

  class Trainer:
+ """LLM pretraining trainer.

+ Core structure of the training loop:
  ```
  for step in range(total_steps):
  # ── Gradient Accumulation Loop ──

  with autocast(bf16):
  logits, loss = model(input_ids, targets)
  scaled_loss = loss / accumulation_steps
+ scaled_loss.backward() # accumulate gradients

+ # ── Optimizer Step (after accumulation completes) ──
  clip_grad_norm(model, max_norm=1.0)
  optimizer.step()
  optimizer.zero_grad()
  scheduler.set_lr(optimizer, step)
  ```

+ What is Gradient Accumulation?
+ - Used when a large batch cannot fit into GPU memory all at once
+ - Run forward/backward multiple times with small micro_batches and accumulate gradients
+ - Perform the optimizer step once after accumulation is complete
+ - Effectively equivalent to training with a large effective batch size
+ - Reason for dividing loss by accumulation_steps:
+ to compute the mean of gradients (average, not sum)
  """

  def __init__(

  self.config = config
  self.seq_len = seq_len

+ # ── Device setup ──
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"[Trainer] Device: {self.device}")
  if torch.cuda.is_available():
  print(f"[Trainer] GPU: {torch.cuda.get_device_name()}")
+ print(f"[Trainer] GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

+ # ── Model ──
  self.model = model.to(self.device)
+ # torch.compile: PyTorch 2.0+ graph optimization (10-30% speed improvement)
  if torch.cuda.is_available() and hasattr(torch, "compile"):
+ print("[Trainer] Applying torch.compile...")
  self.model = torch.compile(self.model)

+ # ── Data ──
  self.train_dataloader = train_dataloader
  self.val_dataloader = val_dataloader
  self.train_iter = iter(train_dataloader)

+ # ── Optimizer ──
  self.optimizer = create_optimizer(self.model, config)

+ # ── Scheduler ──
  self.scheduler = CosineWarmupScheduler(config)

+ # ── Checkpoint ──
  self.ckpt_manager = CheckpointManager(config)

+ # ── Metrics ──
  self.metrics = MetricsTracker(config)

+ # ── Training state ──
  self.global_step = 0
  self.best_val_loss = float("inf")
  self.tokens_seen = 0

  # ── Mixed Precision ──
+ # bf16 does not require GradScaler (only needed for fp16)
  self.use_amp = config.dtype != "float32"
  self.amp_dtype = config.torch_dtype

+ # ── Attempt automatic resume ──
  self._try_resume()

  def _try_resume(self):
+ """Automatically restores from a previous checkpoint if one exists."""
  result = self.ckpt_manager.load_latest(
  self.model, self.optimizer, self.device
  )

  self.best_val_loss = result["best_val_loss"]
  self.metrics.history = result.get("metrics_history", self.metrics.history)

+ # Resume wandb logging continuously
  if result.get("wandb_run_id"):
  self.metrics.resume_wandb(result["wandb_run_id"])

  self.tokens_seen = self.global_step * self.config.effective_batch_size * self.seq_len
+ print(f"[Trainer] Resuming training: step={self.global_step}, "
  f"tokens={self.tokens_seen/1e9:.2f}B, "
  f"best_val_loss={self.best_val_loss:.4f}")

  def _get_next_batch(self) -> Dict[str, torch.Tensor]:
+ """Fetches the next training batch.

+ Since a Streaming DataLoader has no epoch concept,
+ a new iterator is created when StopIteration is raised.
  """
  try:
  batch = next(self.train_iter)

  }

  def _train_step(self) -> Tuple[float, float]:
+ """Performs one optimizer step.

  Returns:
  (loss, grad_norm)
  """
  self.model.train()
  self.optimizer.zero_grad(set_to_none=True)
+ # set_to_none=True: sets gradients to None to save memory

  total_loss = 0.0

  with torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
  logits, loss = self.model(batch["input_ids"], batch["targets"])

+ # Loss scaling: to compute the mean over the effective batch
  scaled_loss = loss / self.config.gradient_accumulation_steps
  total_loss += loss.item()

+ # Backward (accumulate gradients)
  scaled_loss.backward()

  # ── Gradient Clipping ──
+ # Treat all parameter gradients as a single vector and compute the L2 norm
+ # If the norm exceeds max_norm, scale down proportionally
  grad_norm = torch.nn.utils.clip_grad_norm_(
  self.model.parameters(),
  max_norm=self.config.grad_clip,

  # ── Optimizer Step ──
  self.optimizer.step()

+ # ── LR Update ──
  self.scheduler.set_lr(self.optimizer, self.global_step)

  avg_loss = total_loss / self.config.gradient_accumulation_steps

  @torch.no_grad()
  def _evaluate(self) -> Tuple[float, float]:
+ """Measures Loss and Perplexity on the validation data.

  Perplexity = exp(loss)
+ - Intuition: "how many candidates does the model choose the next token from on average"
+ - PPL 100 → equivalent to uniformly choosing 1 out of 100
+ - PPL 20 → 1 out of 20 (fairly good)
+ - PPL 10 → predicting with high confidence
  """
  if self.val_dataloader is None:
  return float("inf"), float("inf")

  num_batches += 1

  avg_loss = total_loss / max(num_batches, 1)
+ perplexity = math.exp(min(avg_loss, 20)) # prevent overflow (exp(20) ≈ 500M)

  return avg_loss, perplexity

  def train(self):
+ """Main training loop.

+ This method runs the entire training process.
+ Even if interrupted by a Colab session expiry,
+ training will automatically resume from the last checkpoint.
  """
  config = self.config

  print("\n" + "=" * 70)
+ print("🚀 Training started")
  print("=" * 70)
+ print(f" Total steps: {config.total_steps:,}")
+ print(f" Start step: {self.global_step}")
  print(f" Effective batch size: {config.effective_batch_size}")
+ print(f" Tokens/step: {config.effective_batch_size * self.seq_len:,}")
+ print(f" Total training tokens (estimated): {config.total_steps * config.effective_batch_size * self.seq_len / 1e9:.1f}B")
  print(f" Mixed Precision: {config.dtype}")
  print(f" Gradient Accumulation: {config.gradient_accumulation_steps}")
+ print(f" Checkpoint: {config.checkpoint_dir}")
  print("=" * 70 + "\n")

  step_start_time = time.time()
  tokens_at_log_start = self.tokens_seen

  # ════════════════════════════════════════════
+ # Main loop
  # ════════════════════════════════════════════

  while self.global_step < config.total_steps:

  tokens_delta = self.tokens_seen - tokens_at_log_start
  tokens_per_sec = tokens_delta / max(elapsed, 1e-6)

+ # GPU memory
  gpu_mem_gb = 0.0
  if torch.cuda.is_available():
  gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9

+ # Current LR
  current_lr = self.scheduler.get_lr(self.global_step)

+ # Estimate remaining time
  remaining_steps = config.total_steps - self.global_step
  steps_per_sec = config.log_interval / max(elapsed, 1e-6)
  eta_seconds = remaining_steps / max(steps_per_sec, 1e-6)
  eta_hours = eta_seconds / 3600

+ # Console output
  print(
  f" Step {self.global_step:>6d}/{config.total_steps} │ "
  f"Loss {loss:.4f} │ "

  f"Tokens {self.tokens_seen/1e9:.2f}B"
  )

+ # wandb logging
  self.metrics.log_train_step(
  step=self.global_step,
  loss=loss,

  )

  # ════════════════════════════════════════════
+ # Training complete
  # ════════════════════════════════════════════

  print("\n" + "=" * 70)
+ print("🎉 Training complete!")
  print("=" * 70)
+ print(f" Total steps: {self.global_step:,}")
+ print(f" Total tokens: {self.tokens_seen/1e9:.2f}B")
+ print(f" Best Val Loss: {self.best_val_loss:.4f}")
+ print(f" Best Val PPL: {math.exp(min(self.best_val_loss, 20)):.2f}")
  print("=" * 70)

+ # Save final checkpoint
  self.ckpt_manager.save(
  model=self.model,
  optimizer=self.optimizer,
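Two arithmetic details the Trainer relies on, the loss scaling in `_train_step` and the capped exponential in `_evaluate`, can be checked with plain numbers. A minimal sketch; the loss values are made up for illustration:

```python
import math

# Hypothetical per-micro-batch losses, for illustration only.
micro_losses = [2.0, 1.5, 1.0, 0.5]
accumulation_steps = len(micro_losses)

# Summing loss / accumulation_steps (as _train_step does before backward)
# equals the mean loss over the effective batch.
summed_scaled = sum(l / accumulation_steps for l in micro_losses)
mean_loss = sum(micro_losses) / len(micro_losses)
assert abs(summed_scaled - mean_loss) < 1e-12  # both are 1.25

# Perplexity as computed in _evaluate: exp(loss), with the argument capped
# at 20 so a divergent loss cannot overflow the float.
def perplexity(avg_loss: float) -> float:
    return math.exp(min(avg_loss, 20))

print(perplexity(1.25))    # ≈ 3.49: choosing among ~3.5 candidates on average
print(perplexity(1000.0))  # capped at exp(20) ≈ 4.85e8
```

The same capping trick appears again in the final summary print, where `best_val_loss` is exponentiated.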
llm_lab/utils/__init__.py CHANGED
@@ -1,4 +1,4 @@
- """공통 유틸리티: 디바이스 감지, 시드 설정."""
  from .device import get_device, detect_gpu_info, auto_configure
  from .seed import set_seed
 
+ """Common utilities: device detection, seed configuration."""
  from .device import get_device, detect_gpu_info, auto_configure
  from .seed import set_seed
 
llm_lab/utils/device.py CHANGED
@@ -1,4 +1,4 @@
- """디바이스 감지 및 자동 설정 유틸리티."""
  from __future__ import annotations

  from typing import TYPE_CHECKING
@@ -10,15 +10,15 @@ if TYPE_CHECKING:


  def get_device() -> torch.device:
- """사용 가능한 디바이스(cuda 또는 cpu)를 반환합니다."""
  return torch.device("cuda" if torch.cuda.is_available() else "cpu")


  def detect_gpu_info() -> dict:
- """GPU 이름과 메모리 정보를 반환합니다.

  Returns:
- {"name": str, "memory_gb": float} 또는 GPU가 없으면 빈 dict
  """
  if not torch.cuda.is_available():
  return {}
@@ -29,16 +29,16 @@ def detect_gpu_info() -> dict:


  def auto_configure(config: "TrainConfig") -> "TrainConfig":
- """GPU 종류에 따라 설정을 자동 조정합니다.

- Colab Pro+에서 A100이 항상 배정되지는 않습니다.
- T4나 V100이 배정될 경우 자동으로 설정을 조정합니다.

  Returns:
- 조정된 TrainConfig
  """
  if not torch.cuda.is_available():
- print("⚠️ GPU 없음! CPU 모드 (매우 느림)")
  config.dtype = "float32"
  config.micro_batch_size = 1
  config.gradient_accumulation_steps = 4
@@ -47,37 +47,37 @@ def auto_configure(config: "TrainConfig") -> "TrainConfig":
  gpu_name = torch.cuda.get_device_name().lower()
  gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9

- print(f"\n🔍 GPU 감지: {torch.cuda.get_device_name()} ({gpu_mem:.1f} GB)")

  if "a100" in gpu_name:
- # A100 40GB: 기본 설정 그대로 (최적)
- print(" → A100 감지: 기본 설정 사용 (bf16, batch=4)")
  config.dtype = "bfloat16"
  config.micro_batch_size = 4

  elif "v100" in gpu_name:
- # V100 16GB: bf16 미지원, 배치 축소
- print(" → V100 감지: fp16 모드, 배치 축소")
  config.dtype = "float16"
  config.micro_batch_size = 2
- config.gradient_accumulation_steps = 64 # effective batch 유지

  elif "t4" in gpu_name:
- # T4 16GB: bf16 미지원, 작은 배치
- print(" → T4 감지: fp16 모드, 최소 배치")
  config.dtype = "float16"
  config.micro_batch_size = 1
  config.gradient_accumulation_steps = 128

  elif "l4" in gpu_name:
- # L4 24GB: bf16 지원
- print(" → L4 감지: bf16 모드, 배치 조정")
  config.dtype = "bfloat16"
  config.micro_batch_size = 2
  config.gradient_accumulation_steps = 64

  else:
- print(" → 알 수 없는 GPU. 메모리 기준으로 설정 조정")
  if gpu_mem >= 30:
  config.micro_batch_size = 4
  elif gpu_mem >= 16:
 
+ """Device detection and auto-configuration utilities."""
  from __future__ import annotations

  from typing import TYPE_CHECKING


  def get_device() -> torch.device:
+ """Returns the available device (cuda or cpu)."""
  return torch.device("cuda" if torch.cuda.is_available() else "cpu")


  def detect_gpu_info() -> dict:
+ """Returns GPU name and memory information.

  Returns:
+ {"name": str, "memory_gb": float} or an empty dict if no GPU is available
  """
  if not torch.cuda.is_available():
  return {}


  def auto_configure(config: "TrainConfig") -> "TrainConfig":
+ """Automatically adjusts configuration based on GPU type.

+ In Colab Pro+, an A100 is not always assigned.
+ If a T4 or V100 is assigned, the configuration is automatically adjusted.

  Returns:
+ Adjusted TrainConfig
  """
  if not torch.cuda.is_available():
+ print("⚠️ No GPU found! Running in CPU mode (very slow)")
  config.dtype = "float32"
  config.micro_batch_size = 1
  config.gradient_accumulation_steps = 4

  gpu_name = torch.cuda.get_device_name().lower()
  gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9

+ print(f"\n🔍 GPU detected: {torch.cuda.get_device_name()} ({gpu_mem:.1f} GB)")

  if "a100" in gpu_name:
+ # A100 40GB: use default settings (optimal)
+ print(" → A100 detected: using default settings (bf16, batch=4)")
  config.dtype = "bfloat16"
  config.micro_batch_size = 4

  elif "v100" in gpu_name:
+ # V100 16GB: bf16 not supported, reduce batch size
+ print(" → V100 detected: fp16 mode, reduced batch size")
  config.dtype = "float16"
  config.micro_batch_size = 2
+ config.gradient_accumulation_steps = 64 # maintain effective batch size

  elif "t4" in gpu_name:
+ # T4 16GB: bf16 not supported, smaller batch
+ print(" → T4 detected: fp16 mode, minimum batch size")
  config.dtype = "float16"
  config.micro_batch_size = 1
  config.gradient_accumulation_steps = 128

  elif "l4" in gpu_name:
+ # L4 24GB: bf16 supported
+ print(" → L4 detected: bf16 mode, adjusted batch size")
  config.dtype = "bfloat16"
  config.micro_batch_size = 2
  config.gradient_accumulation_steps = 64

  else:
+ print(" → Unknown GPU. Adjusting settings based on memory")
  if gpu_mem >= 30:
  config.micro_batch_size = 4
  elif gpu_mem >= 16:
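The name-based if/elif dispatch in `auto_configure` can also be read as a lookup table. A minimal sketch with hypothetical helper names; the preset values mirror the branches shown above, and `None` marks the one setting the A100 branch leaves at its default:

```python
# (dtype, micro_batch_size, gradient_accumulation_steps) per GPU substring.
# None means "leave the config's default untouched" (the A100 branch above
# does not change gradient_accumulation_steps).
GPU_PRESETS = {
    "a100": ("bfloat16", 4, None),
    "v100": ("float16", 2, 64),
    "t4": ("float16", 1, 128),
    "l4": ("bfloat16", 2, 64),
}

def lookup_preset(device_name: str):
    """Return the matching preset tuple, or None for an unknown GPU."""
    name = device_name.lower()
    for key, preset in GPU_PRESETS.items():
        if key in name:  # substring match, like the `"a100" in gpu_name` checks
            return preset
    return None  # unknown GPU: fall back to memory-based adjustment

print(lookup_preset("NVIDIA A100-SXM4-40GB"))  # ('bfloat16', 4, None)
print(lookup_preset("Tesla T4"))               # ('float16', 1, 128)
```

Substring matching keeps the same behavior as the original chain for full device-name strings like `"NVIDIA A100-SXM4-40GB"`.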
llm_lab/utils/seed.py CHANGED
@@ -1,9 +1,9 @@
- """재현성을 위한 시드 유틸리티."""
  import torch


  def set_seed(seed: int = 42):
- """재현성을 위한 시드 설정."""
  torch.manual_seed(seed)
  if torch.cuda.is_available():
  torch.cuda.manual_seed(seed)
 
+ """Seed utility for reproducibility."""
  import torch


  def set_seed(seed: int = 42):
+ """Set seed for reproducibility."""
  torch.manual_seed(seed)
  if torch.cuda.is_available():
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed(seed)