pathcosmos committed on
Commit
29fc577
·
verified ·
1 Parent(s): 42ca925

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/combined_preference.jsonl filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,162 +1,146 @@
1
  ---
2
  language:
3
- - ko
4
- - en
5
  license: mit
6
  tags:
7
- - mamba2
8
- - hybrid
9
- - korean
10
- - causal-lm
 
 
 
 
 
11
  pipeline_tag: text-generation
12
  ---
13
 
14
- # EVAFRILL-Mo-3B
15
 
16
- EVAFRILL-Mo-3B is a 2.94B-parameter **hybrid Mamba-2 + Transformer** language
17
- model optimised for Korean, trained from scratch on 55 billion tokens of
18
- Korean-dominant multilingual text.
19
 
20
- > **EVAFRILL-Mo** stands for *Efficient Variably-Architected Fusion of
21
- > Recurrent and Integrated Linear Layers for Language Model-based Output* — a
22
- > custom architecture inspired by [Nemotron-H](https://arxiv.org/abs/2501.14587)
23
- > that replaces most self-attention layers with Mamba-2 SSM blocks, achieving
24
- > linear-time inference without sacrificing generation quality.
25
 
26
- ---
 
27
 
28
  ## Architecture
29
 
30
- | Property | Value |
31
- |---|---|
32
- | Total parameters | ~2.94 B |
33
- | Layers | 26 (24 × Mamba-2 + 2 × Attention) |
34
- | Hidden size | 3072 |
35
- | Attention heads | 24 (GQA, 8 KV heads) |
36
- | FFN size | 9 216 |
37
- | Mamba-2 state dim | 128 |
38
- | Mamba-2 head dim | 64 |
39
- | Vocab size | 64 000 |
40
- | Max sequence length | 4 096 |
41
- | RoPE theta | 500 000 |
42
-
43
- The layer pattern places attention blocks at positions 9 and 18 (zero-indexed),
44
- mirroring the Nemotron-H 8B dense design scaled to 3B parameters. All other
45
- layers use Mamba-2 with SwiGLU FFN (mamba_d_ffn = 4 608). Attention layers use
46
- full SwiGLU FFN (d_ffn = 9 216).
47
 
48
- ---
49
 
50
- ## Training
51
-
52
- ### Pretraining
53
- - **Tokens**: 55 B (319 772 steps, effective batch ≈ 172 K tokens)
54
- - **Hardware**: 8× NVIDIA B200 (183 GB each), ~62 hours
55
- - **Optimizer**: AdamW, lr=2e-4, cosine decay, warmup 2 000 steps
56
- - **Precision**: FP8 (TransformerEngine MXFP8) + BF16 embedding
57
- - **Data**: Korean web corpus, Wikipedia, books, code (Korean-dominant)
58
-
59
- ### Supervised Fine-Tuning (SFT)
60
- - **Steps**: 65 000 (≈ 1 epoch on 2.44M instruction samples)
61
- - **Effective batch**: 56 (2 per GPU × 7 GPU × 4 grad_accum)
62
- - **LR**: 1e-5 (pretrain/30, catastrophic-forgetting guard)
63
- - **NEFTune alpha**: 5.0 (repetition degeneracy mitigation)
64
- - **Data**: Combined Korean instruction set (filtered, 2.44M samples)
65
-
66
- ### Direct Preference Optimisation (DPO)
67
- - **Rounds**: 2-round DPO (Nemotron-H style)
68
- - Round 1: 3 000 steps, beta=0.1, lr=5e-7, LoRA rank=32
69
- - Round 2: 2 000 steps, beta=0.05, lr=1e-7, LoRA rank=32
70
- - **Hardware**: 1× NVIDIA H100 MIG 3g.40gb (~42 GB VRAM)
71
- - **Method**: Native LoRA DPO (no TRL dependency)
72
-
73
- ### SLERP Merge
74
- The final checkpoint is produced by **spherical linear interpolation (SLERP)**
75
- between the SFT-v2 and DPO-round-2 checkpoints (ratio 0.5), combining the
76
- instruction-following strengths of both stages.
77
 
78
- ---
79
 
80
- ## Evaluation (SLERP checkpoint, lm-eval-harness)
81
 
82
- | Benchmark | Metric | Score |
83
- |---|---|---|
84
- | HellaSwag | acc_norm | 0.42 |
85
- | ARC-Challenge | acc_norm | 0.22 |
86
- | ARC-Easy | acc_norm | 0.28 |
87
- | Belebele (kor_Hang) | acc | 0.30 |
88
- | Global-MMLU-ko (full) | acc | 0.233 |
89
- | — Humanities | acc | 0.242 |
90
- | — STEM | acc | 0.237 |
91
- | — Social Sciences | acc | 0.221 |
92
- | — Other | acc | 0.229 |
93
 
94
- *Evaluated on 100-sample subsets per task. Numbers reflect the final
95
- SLERP-merged checkpoint.*
96
 
97
- ---
98
 
99
  ## Usage
100
 
101
  ```python
102
- # Requires: pip install transformers tokenizers safetensors
103
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
104
  import torch
 
 
105
 
106
- model_id = "pathcosmos/EVAFRILL-Mo-3B"
107
-
108
- tokenizer = AutoTokenizer.from_pretrained(model_id)
109
- model = AutoModelForCausalLM.from_pretrained(
110
- model_id,
111
- torch_dtype=torch.bfloat16,
112
- device_map="auto",
113
- )
114
-
115
- # Chat-style prompt
116
- prompt = "<|user|>\n안녕하세요! 자기소개를 해 주세요.\n<|assistant|>\n"
117
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
118
-
119
- with torch.no_grad():
120
- output = model.generate(
121
- **inputs,
122
- max_new_tokens=256,
123
- temperature=0.8,
124
- top_p=0.9,
125
- do_sample=True,
126
- repetition_penalty=1.1,
127
- )
128
-
129
- print(tokenizer.decode(output[0], skip_special_tokens=False))
130
- ```
131
 
132
- ---
 
133
 
134
  ## Limitations
135
 
136
- - This model is an **experimental research checkpoint**, not a production system.
137
- - Korean is the dominant language; English and other languages are secondary.
138
- - The custom architecture (`evafrill-mo`) requires either
139
- (a) the original project code for full inference, or
140
- (b) a compatible HuggingFace integration that understands Mamba-2 hybrid layers.
141
- The exported `model.safetensors` preserves the native weight layout.
142
- - Benchmark numbers were evaluated on small (100-sample) subsets and should be
143
- treated as rough estimates.
144
-
145
- ---
146
-
147
- ## Citation
148
 
149
- ```bibtex
150
- @misc{evafrill-mo-3b-2026,
151
- title = {EVAFRILL-Mo-3B: A Hybrid Mamba-2 + Transformer LLM for Korean},
152
- author = {pathcosmos},
153
- year = {2026},
154
- url = {https://huggingface.co/pathcosmos/EVAFRILL-Mo-3B},
155
- }
156
- ```
157
 
158
- ---
 
159
 
160
  ## License
161
 
162
- [MIT](LICENSE)
 
1
  ---
2
  language:
3
+ - ko
4
+ - en
5
  license: mit
6
  tags:
7
+ - mamba2
8
+ - hybrid
9
+ - transformer
10
+ - korean
11
+ - from-scratch
12
+ - dpo
13
+ - slerp
14
+ - orpo
15
+ library_name: pytorch
16
  pipeline_tag: text-generation
17
  ---
18
 
19
+ # EVAFRILL-Mo 3B — Hybrid Mamba-2 + Transformer
20
 
21
+ **A 3-billion-parameter hybrid Mamba-2 + Transformer language model built from scratch.**
 
 
22
 
23
+ Inspired by the NVIDIA [Nemotron-H](https://arxiv.org/abs/2504.03624) architecture. Pretrained on 55B tokens across Korean, English, code, and math using 7× NVIDIA B200 GPUs.
 
 
 
 
24
 
25
+ ## Model Variants
26
+
27
+ This repository contains **7 model versions** representing each stage of the training pipeline, plus training data and scripts for full reproducibility.
28
+
29
+ | Variant | Directory | Size | Description | Recommended |
30
+ |---------|-----------|------|-------------|:-----------:|
31
+ | **SLERP** | `slerp/` | 6.3GB | SFT + DPO merged (α=0.5) | ⭐ **Yes** |
32
+ | Pretrain | `pretrain/` | 12.6GB | Base model (319K steps, 55B tokens) | |
33
+ | SFT v2 | `sft-v2/` | 6.3GB | Instruction-tuned (65K steps) | |
34
+ | DPO R1 | `dpo-r1/` | 6.3GB | Preference-aligned Round 1 | |
35
+ | DPO R2 | `dpo-r2/` | 6.3GB | Conservative fine-tuning Round 2 | |
36
+ | ORPO | `orpo/` | 6.3GB | SFT+alignment simultaneous (experimental) | |
37
+ | DPO R3 | `dpo-r3/` | 6.3GB | Repetition-targeted (experimental) | |
38
+
39
+ ## Training Pipeline
40
+
41
+ ```
42
+ Pretrain (55B tokens, 7×B200, 60h)
43
+     ↓
44
+ SFT v2 (65K steps, H100 MIG, 5 days)
45
+     ↓
46
+ DPO Round 1 (3K steps, LoRA, loss 0.693→0.565)
47
+     ↓
48
+ DPO Round 2 (2K steps, conservative, loss 0.692→0.689)
49
+     ↓
50
+ SLERP Merge (α=0.5, SFT 50% + DPO 50%) ← RECOMMENDED
51
+
52
+ ORPO Experiment (10K steps, alternative approach)
53
+
54
+ DPO Round 3 (1K steps, repetition-targeted experiment)
55
+ ```
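The SLERP merge step above can be sketched per tensor. This is an illustrative sketch only (not the repository's `merge_checkpoints.py`), assuming each checkpoint is a plain state dict and falling back to linear interpolation when two tensors are nearly parallel:

```python
import torch

def slerp(a: torch.Tensor, b: torch.Tensor, alpha: float = 0.5, eps: float = 1e-8) -> torch.Tensor:
    """Spherical linear interpolation between two weight tensors."""
    a_flat, b_flat = a.flatten().float(), b.flatten().float()
    cos_omega = torch.dot(a_flat, b_flat) / (a_flat.norm() * b_flat.norm() + eps)
    omega = torch.acos(cos_omega.clamp(-1.0, 1.0))
    if omega.abs() < 1e-4:
        # Near-parallel tensors: SLERP degenerates, fall back to plain lerp
        return (1 - alpha) * a + alpha * b
    sin_omega = torch.sin(omega)
    wa = torch.sin((1 - alpha) * omega) / sin_omega
    wb = torch.sin(alpha * omega) / sin_omega
    return (wa * a_flat + wb * b_flat).reshape(a.shape).to(a.dtype)

def slerp_state_dicts(sft: dict, dpo: dict, alpha: float = 0.5) -> dict:
    """Merge two checkpoints key by key (assumes identical key sets)."""
    return {k: slerp(sft[k], dpo[k], alpha) for k in sft}
```

With `alpha=0.5` this weights the SFT and DPO checkpoints equally, as in the pipeline above.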
56
 
57
  ## Architecture
58
 
59
+ ```
60
+ Type: Hybrid Mamba-2 + Transformer
61
+ Parameters: 2.94B (2,975,397,632)
62
+ Layers: 26 (24× Mamba-2 SSM + 2× Attention GQA)
63
+ d_model: 3,072
64
+ Vocabulary: 64,000 (custom SentencePiece)
65
+ Max seq length: 4,096
66
+ ```
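A quick back-of-envelope check of the numbers above (a sketch only; it ignores weight tying and the Mamba-2 blocks): the token embedding alone is 64,000 × 3,072 parameters, and under GQA each of the 24 query heads has dimension 3,072 / 24 = 128, with K/V projecting to 8 × 128 dims.

```python
d_model, vocab, n_heads, n_kv_heads = 3072, 64_000, 24, 8

embed_params = vocab * d_model      # token embedding matrix: 196,608,000
head_dim = d_model // n_heads       # per-head dimension: 128
kv_dim = n_kv_heads * head_dim      # K/V projection width under GQA: 1024

print(embed_params, head_dim, kv_dim)
```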
 
 
 
 
 
 
 
 
 
67
 
68
+ ## Benchmark Results (SLERP, recommended model)
69
 
70
+ | Metric | Value |
71
+ |--------|-------|
72
+ | Greedy 3-gram repetition | 74.5% (→ 5.5% with rep_penalty=1.2) |
73
+ | hellaswag (0-shot) | 34.6% |
74
+ | arc_easy (0-shot) | 32.0% |
75
+ | belebele_kor (0-shot) | 23.6% |
76
+ | global_mmlu_ko (0-shot) | 23.7% |
 
77
 
78
+ **Recommended inference**: `temperature=0.7, repetition_penalty=1.2`
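These settings can be reproduced manually when sampling. Below is a minimal sketch of CTRL-style repetition penalty followed by temperature scaling on a logits vector; this is an assumption about how the penalty is applied, and the project's own generation code may differ:

```python
import torch

def adjust_logits(logits: torch.Tensor, generated_ids: list[int],
                  temperature: float = 0.7, repetition_penalty: float = 1.2) -> torch.Tensor:
    """Apply CTRL-style repetition penalty, then temperature, to a 1-D logits vector."""
    logits = logits.clone()
    for t in set(generated_ids):
        # Penalize tokens already generated: shrink positive logits,
        # push negative logits further down
        if logits[t] > 0:
            logits[t] = logits[t] / repetition_penalty
        else:
            logits[t] = logits[t] * repetition_penalty
    return logits / temperature
```

Sampling from `softmax(adjust_logits(...))` then approximates the recommended decoding setup.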
79
 
80
+ ## SFT→DPO→SLERP vs ORPO Comparison
81
 
82
+ | Metric | SLERP | ORPO | Winner |
83
+ |--------|:-----:|:----:|:------:|
84
+ | Greedy repetition | **74.5%** | 87.1% | SLERP |
85
+ | Chat quality | Fluent | Broken | SLERP |
86
+ | hellaswag | **39.0%** | 35.0% | SLERP |
87
+ | Training time | 5d+8h | **12.8h** | ORPO |
 
 
 
 
 
88
 
89
+ ORPO's main weakness here is insufficient SFT learning: 10K steps versus the dedicated SFT stage's 65K.
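For context, the ORPO objective augments the plain SFT loss with an odds-ratio term. This is a sketch following the published ORPO formulation, where `lam` corresponds to `lambda_or` in `configs/orpo_3b_1gpu.yaml`; the repository's `orpo_native.py` may differ in details:

```python
import torch
import torch.nn.functional as F

def orpo_loss(sft_loss: torch.Tensor, chosen_logp: torch.Tensor,
              rejected_logp: torch.Tensor, lam: float = 1.0) -> torch.Tensor:
    """ORPO: SFT loss plus an odds-ratio penalty on the preference pair.

    chosen_logp / rejected_logp are length-normalized sequence log-probs.
    """
    def log_odds(logp: torch.Tensor) -> torch.Tensor:
        # log[p / (1 - p)], computed stably from log p
        return logp - torch.log1p(-torch.exp(logp))
    ratio = log_odds(chosen_logp) - log_odds(rejected_logp)
    return sft_loss + lam * (-F.logsigmoid(ratio).mean())
```

Because the odds-ratio term is only a regularizer on top of SFT, 10K ORPO steps carry far less instruction-following signal than 65K dedicated SFT steps.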
 
90
 
91
+ ## Repository Structure
92
+
93
+ ```
94
+ ├── slerp/ # ⭐ Recommended final model
95
+ ├── pretrain/ # Base pretrained model
96
+ ├── sft-v2/ # SFT instruction-tuned
97
+ ├── dpo-r1/ # DPO Round 1 + LoRA weights
98
+ ├── dpo-r2/ # DPO Round 2 + LoRA weights
99
+ ├── orpo/ # ORPO experiment + LoRA weights
100
+ ├── dpo-r3/ # DPO Round 3
101
+ ├── data/ # Preference datasets for reproducibility
102
+ │ ├── combined_preference.jsonl (684K pairs, 2.6GB)
103
+ │ └── repetition_preference.jsonl (105 pairs, self-generated)
104
+ ├── configs/ # Training YAML configs
105
+ │ ├── korean_3b_sft_1gpu.yaml
106
+ │ ├── dpo_3b_1gpu.yaml
107
+ │ └── orpo_3b_1gpu.yaml
108
+ └── scripts/ # Training & evaluation code
109
+     ├── dpo.py, orpo_native.py, sft.py
110
+     ├── lora.py, merge_checkpoints.py
111
+     ├── evafrill_eval.py
112
+     └── generate_repetition_preference.py
113
+ ```
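The preference files can be inspected with a few lines of Python. This sketch assumes the common `prompt`/`chosen`/`rejected` JSONL schema, which the actual files may not use:

```python
import json

def iter_preference_pairs(path):
    """Yield (prompt, chosen, rejected) triples from a preference JSONL file."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            ex = json.loads(line)
            yield ex["prompt"], ex["chosen"], ex["rejected"]
```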
114
 
115
  ## Usage
116
 
117
  ```python
118
+ # This is a custom architecture — use the project's native loading code
119
+ # Clone: https://github.com/pathcosmos/EVAFRILL-Mo
120
+
121
  import torch
122
+ from model.transformer import LLM
123
+ from tokenizers import Tokenizer
124
 
125
+ model = LLM.from_pretrained("checkpoints/3b_dpo/checkpoint-slerp")
126
+ model = model.to(device="cuda:0", dtype=torch.bfloat16)
127
+ model.eval()
 
128
 
129
+ tok = Tokenizer.from_file("tokenizer/korean_sp/tokenizer.json")
130
+ ```
131
 
132
  ## Limitations
133
 
134
+ - **3B scale**: Factual accuracy and complex reasoning are limited
135
+ - **GGUF/Ollama**: Not possible due to custom hybrid Mamba-2 architecture
136
+ - **vLLM**: Theoretically possible but requires custom weight key mapping
137
+ - **Greedy repetition**: ~74.5% without rep_penalty (use rep_penalty=1.2)
 
 
 
 
 
 
 
 
138
 
139
+ ## Links
 
 
 
 
 
 
 
140
 
141
+ - **GitHub**: [pathcosmos/EVAFRILL-Mo](https://github.com/pathcosmos/EVAFRILL-Mo)
142
+ - **Paper reference**: [Nemotron-H](https://arxiv.org/abs/2504.03624)
143
 
144
  ## License
145
 
146
+ MIT
configs/dpo_3b_1gpu.yaml ADDED
@@ -0,0 +1,31 @@
1
+ # EVAFRILL-Mo 3B DPO — Single GPU (H100 MIG 3g.40gb, 42.3GB VRAM)
2
+ #
3
+ # Base model: checkpoints/3b_sft_v2/checkpoint-best (SFT-completed model)
4
+ # Method: LoRA DPO (native, no TRL)
5
+ #
6
+ # [Design rationale]
7
+ # - GPU: H100 PCIe MIG 3g.40gb (42.3GB VRAM)
8
+ # - LoRA DPO VRAM budget: base(6GB) + LoRA(0.3GB) + optim(0.2GB) + act(10GB) + ref_fwd(6GB) ≈ 22GB
9
+ # - BF16 + Gradient Checkpointing (FP8 unsupported)
10
+ # - eff_batch: 1 × 16 grad_accum = 16
11
+ # - Nemotron-H-style 2-round DPO
12
+
13
+ train:
14
+ # Round 1 settings (for Round 2, change to max_steps=2000, beta=0.05, lr=1e-7)
15
+ max_steps: 3000
16
+ batch_size: 1
17
+ grad_accum_steps: 16 # eff_batch = 16
18
+ lr: 5.0e-7 # DPO uses a much lower lr than SFT
19
+ weight_decay: 0.01
20
+ warmup_steps: 100
21
+ max_length: 1024 # seq_len capped by VRAM constraints
22
+ beta: 0.1 # DPO temperature
23
+
24
+ # LoRA settings
25
+ use_lora: true
26
+ lora_rank: 32
27
+ lora_alpha: 64
28
+
29
+ # Saving/logging
30
+ save_interval: 500
31
+ log_interval: 10
configs/korean_3b_sft_1gpu.yaml ADDED
@@ -0,0 +1,31 @@
1
+ # EVAFRILL-Mo 3B SFT — Single GPU (H100 MIG 3g.40gb, 42.3GB VRAM)
2
+ #
3
+ # Base model: checkpoints/3b_final/checkpoint-0319772
4
+ # Fresh start from pretrained checkpoint
5
+ #
6
+ # [Design rationale - 2026-03-17, optimized 2026-03-17]
7
+ # - GPU: H100 PCIe MIG 3g.40gb (42.3GB VRAM, 46 SMs)
8
+ # - CPU: 45 cores (cgroup), RAM: 200GB (cgroup)
9
+ # - BF16 + Gradient Checkpointing (no FP8; MIG NVML limitation)
10
+ # - Benchmark result: bs=4 ga=7 @ 27.7GB VRAM (68.7%), 5,475 tok/s (+10% vs bs=1)
11
+ # - eff_batch: 4 × 1GPU × 7 grad_accum = 28
12
+ # - 1 epoch: 3,774,413 / 28 ≈ 134,800 steps → max_steps=135000
13
+ # - Estimated time: 135,000 steps × ~10.5s/step ≈ ~391 hours ≈ ~16 days
14
+
15
+ train:
16
+ max_steps: 135000 # ≈ 1 epoch on 3.77M samples, eff_batch=28
17
+ batch_size: 4 # benchmark optimum: bs=4 @ 27.7GB (68.7% VRAM)
18
+ grad_accum_steps: 7 # eff_batch=28 (4×7); switching bs 1→4 gives +10% tok/s
19
+ lr: 7.0e-6 # sqrt(28/56) * 1e-5 ≈ 7e-6 (square-root scaling rule)
20
+ weight_decay: 0.01
21
+ warmup_steps: 500
22
+ max_grad_norm: 1.0
23
+ log_interval: 10
24
+ save_interval: 2000 # save frequently for crash recovery
25
+ eval_interval: 5000 # one validation pass takes ~5 min; keep overhead low
26
+ neftune_alpha: 5.0 # NEFTune noise injection (mitigates repetition degeneracy)
27
+ max_val_batches: 500 # cap validation batches (speed optimization)
28
+
29
+ tokenizer:
30
+ vocab_size: 64000
31
+ type: sentencepiece_unigram
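The batch and learning-rate arithmetic in the comments above can be re-derived directly:

```python
import math

batch_size, grad_accum, n_gpus = 4, 7, 1
eff_batch = batch_size * n_gpus * grad_accum          # 28

samples = 3_774_413
steps_per_epoch = samples / eff_batch                  # ≈ 134,800 steps per epoch

base_lr, base_eff_batch = 1e-5, 56
lr = math.sqrt(eff_batch / base_eff_batch) * base_lr   # ≈ 7.07e-6, rounded to 7e-6
```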
configs/orpo_3b_1gpu.yaml ADDED
@@ -0,0 +1,27 @@
1
+ # EVAFRILL-Mo 3B ORPO — Single GPU (H100 MIG 3g.40gb, 42.3GB VRAM)
2
+ #
3
+ # Base model: checkpoints/3b_final/checkpoint-0319772 (Pretrained, NOT SFT)
4
+ # Method: ORPO (SFT + Odds Ratio Preference) with LoRA
5
+ #
6
+ # [Design rationale]
7
+ # - ORPO learns SFT + alignment simultaneously → start from the pretrained model
8
+ # - No reference model needed → saves VRAM vs DPO
9
+ # - LoRA rank=32: base(6GB) + LoRA(0.3GB) + optim(0.2GB) + act(~8GB) ≈ 15GB
10
+ # - eff_batch: 1 × 16 grad_accum = 16
11
+
12
+ train:
13
+ max_steps: 10000
14
+ batch_size: 1
15
+ grad_accum_steps: 16
16
+ lr: 5.0e-6
17
+ weight_decay: 0.01
18
+ warmup_steps: 500
19
+ max_length: 1024
20
+ lambda_or: 1.0
21
+
22
+ use_lora: true
23
+ lora_rank: 32
24
+ lora_alpha: 64
25
+
26
+ save_interval: 1000
27
+ log_interval: 10
data/combined_preference.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32b0c2a5ca3a523a22882c2d828917a3eb543605f04b51db2bab16e6bd262f95
3
+ size 2721356504
data/repetition_preference.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
dpo-r1/config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "vocab_size": 64000,
3
+ "d_model": 3072,
4
+ "n_layers": 26,
5
+ "n_heads": 24,
6
+ "n_kv_heads": 8,
7
+ "d_ffn": 9216,
8
+ "max_seq_len": 4096,
9
+ "rope_theta": 500000.0,
10
+ "dropout": 0.0,
11
+ "bias": false,
12
+ "use_flash_attn": true,
13
+ "use_fp8": false,
14
+ "use_hybrid": true,
15
+ "hybrid_pattern": "M M M M M M M M M M M M A M M M M M M M M M M M A M",
16
+ "mamba_d_state": 128,
17
+ "mamba_head_dim": 64,
18
+ "mamba_expand": 2,
19
+ "mamba_conv_kernel": 4,
20
+ "mamba_n_groups": 8,
21
+ "mamba_d_ffn": 4608,
22
+ "mamba_chunk_size": 256,
23
+ "model_type": "evafrill-mo",
24
+ "architectures": [
25
+ "EvafrillMoForCausalLM"
26
+ ],
27
+ "_variant": "dpo-r1",
28
+ "_description": "DPO Round 1 (3K steps, loss 0.693->0.565)"
29
+ }
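The `hybrid_pattern` string is the layer layout itself; parsing it shows 26 layers with the two attention blocks at zero-indexed positions 12 and 24:

```python
pattern = "M M M M M M M M M M M M A M M M M M M M M M M M A M"
layers = pattern.split()
attn_positions = [i for i, t in enumerate(layers) if t == "A"]

print(len(layers), attn_positions)
```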
dpo-r1/lora_weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f161ffc4138d61116dedb9e29176497ff85adbed7374a1b7fd35f9672d21245
3
+ size 42909589
dpo-r1/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a07011b219031b7cab8985a4c0dc811aa4758f44c79694da12b744059f77cd99
3
+ size 6301164272
dpo-r1/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
dpo-r2/config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "vocab_size": 64000,
3
+ "d_model": 3072,
4
+ "n_layers": 26,
5
+ "n_heads": 24,
6
+ "n_kv_heads": 8,
7
+ "d_ffn": 9216,
8
+ "max_seq_len": 4096,
9
+ "rope_theta": 500000.0,
10
+ "dropout": 0.0,
11
+ "bias": false,
12
+ "use_flash_attn": true,
13
+ "use_fp8": false,
14
+ "use_hybrid": true,
15
+ "hybrid_pattern": "M M M M M M M M M M M M A M M M M M M M M M M M A M",
16
+ "mamba_d_state": 128,
17
+ "mamba_head_dim": 64,
18
+ "mamba_expand": 2,
19
+ "mamba_conv_kernel": 4,
20
+ "mamba_n_groups": 8,
21
+ "mamba_d_ffn": 4608,
22
+ "mamba_chunk_size": 256,
23
+ "model_type": "evafrill-mo",
24
+ "architectures": [
25
+ "EvafrillMoForCausalLM"
26
+ ],
27
+ "_variant": "dpo-r2",
28
+ "_description": "DPO Round 2 (2K steps, conservative fine-tuning)"
29
+ }
dpo-r2/lora_weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eeda7e71c6d2f69ea5ed8a02759ee2ac1f6feadcbf6e440fd3f9da919628f947
3
+ size 42909589
dpo-r2/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c64ad3c2c26b3e706e58011a8a3eb8dba3f5cee2b0ea62eaa0083ad3c4b7685e
3
+ size 6301164272
dpo-r2/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
dpo-r3/config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "vocab_size": 64000,
3
+ "d_model": 3072,
4
+ "n_layers": 26,
5
+ "n_heads": 24,
6
+ "n_kv_heads": 8,
7
+ "d_ffn": 9216,
8
+ "max_seq_len": 4096,
9
+ "rope_theta": 500000.0,
10
+ "dropout": 0.0,
11
+ "bias": false,
12
+ "use_flash_attn": true,
13
+ "use_fp8": false,
14
+ "use_hybrid": true,
15
+ "hybrid_pattern": "M M M M M M M M M M M M A M M M M M M M M M M M A M",
16
+ "mamba_d_state": 128,
17
+ "mamba_head_dim": 64,
18
+ "mamba_expand": 2,
19
+ "mamba_conv_kernel": 4,
20
+ "mamba_n_groups": 8,
21
+ "mamba_d_ffn": 4608,
22
+ "mamba_chunk_size": 256,
23
+ "model_type": "evafrill-mo",
24
+ "architectures": [
25
+ "EvafrillMoForCausalLM"
26
+ ],
27
+ "_variant": "dpo-r3",
28
+ "_description": "DPO Round 3 (1K steps, repetition-targeted experiment)"
29
+ }
dpo-r3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d49c4459ae2db0f5fb8ef761c94afd06b6f518e07e7208407effe1298fa9bc20
3
+ size 6301164272
dpo-r3/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
orpo/config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "vocab_size": 64000,
3
+ "d_model": 3072,
4
+ "n_layers": 26,
5
+ "n_heads": 24,
6
+ "n_kv_heads": 8,
7
+ "d_ffn": 9216,
8
+ "max_seq_len": 4096,
9
+ "rope_theta": 500000.0,
10
+ "dropout": 0.0,
11
+ "bias": false,
12
+ "use_flash_attn": true,
13
+ "use_fp8": false,
14
+ "use_hybrid": true,
15
+ "hybrid_pattern": "M M M M M M M M M M M M A M M M M M M M M M M M A M",
16
+ "mamba_d_state": 128,
17
+ "mamba_head_dim": 64,
18
+ "mamba_expand": 2,
19
+ "mamba_conv_kernel": 4,
20
+ "mamba_n_groups": 8,
21
+ "mamba_d_ffn": 4608,
22
+ "mamba_chunk_size": 256,
23
+ "model_type": "evafrill-mo",
24
+ "architectures": [
25
+ "EvafrillMoForCausalLM"
26
+ ],
27
+ "_variant": "orpo",
28
+ "_description": "ORPO experiment (10K steps, SFT+alignment simultaneous)"
29
+ }
orpo/lora_weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a06dd455a91aa303be57a004db0ef2889027cf9b2a7854137e7e7a0aefbfeeda
3
+ size 42909589
orpo/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46b8ee870223a391c08ff08a363ac68948df0b13971910fb9ba3c85678f3f6e5
3
+ size 6301164272
orpo/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
pretrain/config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "vocab_size": 64000,
3
+ "d_model": 3072,
4
+ "n_layers": 26,
5
+ "n_heads": 24,
6
+ "n_kv_heads": 8,
7
+ "d_ffn": 9216,
8
+ "max_seq_len": 4096,
9
+ "rope_theta": 500000.0,
10
+ "dropout": 0.0,
11
+ "bias": false,
12
+ "use_flash_attn": true,
13
+ "use_fp8": true,
14
+ "use_hybrid": true,
15
+ "hybrid_pattern": "M M M M M M M M M M M M A M M M M M M M M M M M A M",
16
+ "mamba_d_state": 128,
17
+ "mamba_head_dim": 64,
18
+ "mamba_expand": 2,
19
+ "mamba_conv_kernel": 4,
20
+ "mamba_n_groups": 8,
21
+ "mamba_d_ffn": 4608,
22
+ "mamba_chunk_size": 256,
23
+ "model_type": "evafrill-mo",
24
+ "architectures": [
25
+ "EvafrillMoForCausalLM"
26
+ ],
27
+ "_variant": "pretrain",
28
+ "_description": "Pretrained base model (319K steps, 55B tokens, Chinchilla 93%)"
29
+ }
pretrain/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4e8796eb71489901cded57f3981889b2cd57f06552c03cf814c15fc52ad69df
3
+ size 12602306810
pretrain/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
scripts/dpo.py ADDED
@@ -0,0 +1,477 @@
1
+ """
2
+ train/dpo.py — Direct Preference Optimization (DPO) training.
3
+
4
+ Native DPO implementation (no TRL dependency) for EVAFRILL-Mo hybrid models.
5
+ Supports LoRA adapters for memory-efficient training on single GPU.
6
+
7
+ Launch:
8
+ python train/dpo.py \
9
+ --sft_checkpoint checkpoints/3b_sft_v2/checkpoint-best \
10
+ --dpo_data data/preference/combined_preference.jsonl \
11
+ --config configs/h100_mig/dpo_3b_1gpu.yaml \
12
+ --device cuda:0
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import os
19
+ import random
20
+ import signal
21
+ import shutil
22
+ import sys
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ from torch.utils.data import DataLoader, RandomSampler
30
+
31
+ torch.backends.cuda.matmul.allow_tf32 = True
32
+ torch.backends.cudnn.allow_tf32 = True
33
+ torch.set_float32_matmul_precision("high")
34
+
35
+ _PROJECT_ROOT = Path(__file__).resolve().parent.parent
36
+ if str(_PROJECT_ROOT) not in sys.path:
37
+ sys.path.insert(0, str(_PROJECT_ROOT))
38
+
39
+ from model import LLM
40
+ from model.lora import apply_lora, get_lora_params, merge_lora, save_lora
41
+ from data.dpo_dataset import DPODataset, dpo_collate_fn
42
+ from train.utils import (
43
+ get_cosine_schedule_with_warmup,
44
+ is_main_process,
45
+ save_checkpoint,
46
+ load_checkpoint,
47
+ )
48
+
49
+
50
+ def parse_args() -> argparse.Namespace:
51
+ parser = argparse.ArgumentParser(description="DPO Training for EVAFRILL-Mo")
52
+
53
+ # Paths
54
+ parser.add_argument("--sft_checkpoint", type=Path, required=True,
55
+ help="Path to SFT checkpoint directory")
56
+ parser.add_argument("--dpo_data", type=Path, required=True,
57
+ help="Path to preference JSONL data")
58
+ parser.add_argument("--checkpoint_dir", type=Path, default=Path("checkpoints/3b_dpo"),
59
+ help="Output checkpoint directory")
60
+ parser.add_argument("--resume", type=Path, default=None)
61
+ parser.add_argument("--tokenizer", type=Path, default=None)
62
+ parser.add_argument("--log_file", type=Path, default=None)
63
+ parser.add_argument("--config", type=Path, default=None)
64
+
65
+ # DPO hyperparameters
66
+ parser.add_argument("--beta", type=float, default=0.1, help="DPO temperature")
67
+ parser.add_argument("--max_steps", type=int, default=3000)
68
+ parser.add_argument("--batch_size", type=int, default=1)
69
+ parser.add_argument("--grad_accum", type=int, default=16)
70
+ parser.add_argument("--lr", type=float, default=5e-7)
71
+ parser.add_argument("--weight_decay", type=float, default=0.01)
72
+ parser.add_argument("--warmup_steps", type=int, default=100)
73
+ parser.add_argument("--max_length", type=int, default=1024)
74
+ parser.add_argument("--seed", type=int, default=42)
75
+
76
+ # LoRA
77
+ parser.add_argument("--use_lora", action="store_true", default=True)
78
+ parser.add_argument("--lora_rank", type=int, default=32)
79
+ parser.add_argument("--lora_alpha", type=float, default=64.0)
80
+
81
+ # Infra
82
+ parser.add_argument("--device", type=str, default=None)
83
+ parser.add_argument("--save_interval", type=int, default=500)
84
+ parser.add_argument("--log_interval", type=int, default=10)
85
+ parser.add_argument("--num_workers", type=int, default=4)
86
+
87
+ args, _ = parser.parse_known_args()
88
+
89
+ # Load YAML config
90
+ if args.config is not None:
91
+ if not args.config.exists():
92
+ raise FileNotFoundError(f"Config not found: {args.config}")
93
+ import yaml
94
+ with open(args.config) as f:
95
+ cfg = yaml.safe_load(f)
96
+ train_cfg = cfg.get("train", {})
97
+ yaml_map = {
98
+ "max_steps": "max_steps", "batch_size": "batch_size",
99
+ "grad_accum_steps": "grad_accum", "lr": "lr",
100
+ "weight_decay": "weight_decay", "warmup_steps": "warmup_steps",
101
+ "beta": "beta", "max_length": "max_length",
102
+ "save_interval": "save_interval", "log_interval": "log_interval",
103
+ "use_lora": "use_lora", "lora_rank": "lora_rank", "lora_alpha": "lora_alpha",
104
+ }
105
+ defaults = {}
106
+ for yk, ak in yaml_map.items():
107
+ if yk in train_cfg:
108
+ defaults[ak] = train_cfg[yk]
109
+ if defaults:
110
+ parser.set_defaults(**defaults)
111
+
112
+ return parser.parse_args()
113
+
114
+
115
+ def set_seed(seed: int) -> None:
116
+ random.seed(seed)
117
+ np.random.seed(seed)
118
+ torch.manual_seed(seed)
119
+ torch.cuda.manual_seed_all(seed)
120
+
121
+
122
+ def compute_log_probs(
123
+ model: nn.Module,
124
+ input_ids: torch.Tensor,
125
+ labels: torch.Tensor,
126
+ ) -> torch.Tensor:
127
+ """Compute sum of log probabilities over non-masked tokens.
128
+
129
+ Args:
130
+ model: The LLM model
131
+ input_ids: (B, T) token ids
132
+ labels: (B, T) target ids, -1 for masked positions
133
+
134
+ Returns:
135
+ (B,) sum of log probs per sample
136
+ """
137
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
138
+ logits, _ = model(input_ids) # (B, T, V)
139
+
140
+ # Shift: predict next token
141
+ # logits[:, :-1] predicts labels[:, 1:]
142
+ # But our labels already have the shifted targets (same as SFT convention)
143
+ # labels[i] = token_id means input_ids[i] should predict labels[i]
144
+ log_probs = F.log_softmax(logits.float(), dim=-1) # (B, T, V)
145
+
146
+ # Gather log probs for target tokens
147
+ # For each position, get log_prob of the label token
148
+ mask = labels != -1 # (B, T)
149
+ # Clamp labels for gather (replace -1 with 0, will be masked out)
150
+ safe_labels = labels.clamp(min=0) # (B, T)
151
+ per_token_logps = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1) # (B, T)
152
+ per_token_logps = per_token_logps * mask.float() # zero out masked positions
153
+
154
+ return per_token_logps.sum(dim=-1) # (B,)
155
+
156
+
157
+ def dpo_loss(
158
+ policy_chosen_logps: torch.Tensor,
159
+ policy_rejected_logps: torch.Tensor,
160
+ ref_chosen_logps: torch.Tensor,
161
+ ref_rejected_logps: torch.Tensor,
162
+ beta: float = 0.1,
163
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
164
+ """Compute DPO loss.
165
+
166
+ Returns:
167
+ (loss, chosen_rewards, rejected_rewards)
168
+ """
169
+ chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
170
+ rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
171
+
172
+ logits = chosen_rewards - rejected_rewards # (B,)
173
+ loss = -F.logsigmoid(logits).mean()
174
+
175
+ return loss, chosen_rewards.detach().mean(), rejected_rewards.detach().mean()
176
+
177
+
178
+ def _resolve_tokenizer_path(args: argparse.Namespace) -> Path:
179
+ if args.tokenizer is not None:
180
+ return Path(args.tokenizer)
181
+ ckpt_tok = args.sft_checkpoint / "tokenizer.json"
182
+ if ckpt_tok.exists():
183
+ return ckpt_tok
184
+ default_tok = _PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json"
185
+ if default_tok.exists():
186
+ return default_tok
187
+ raise FileNotFoundError("Cannot find tokenizer.json")
188
+
189
+
190
+ def main() -> None:
191
+ args = parse_args()
192
+ set_seed(args.seed)
193
+
194
+ # Device setup
195
+ if args.device:
196
+ device = torch.device(args.device)
197
+ elif torch.cuda.is_available():
198
+ device = torch.device("cuda:0")
199
+ else:
200
+ device = torch.device("cpu")
201
+
202
+ # Validate checkpoint
203
+ if not args.sft_checkpoint.exists():
204
+ raise FileNotFoundError(f"SFT checkpoint not found: {args.sft_checkpoint}")
205
+
206
+ # Load SFT model as policy
207
+ print(f"Loading SFT model from {args.sft_checkpoint}...")
208
+ model = LLM.from_pretrained(args.sft_checkpoint)
209
+ model.config.use_fp8 = False # H100 MIG: BF16 only
210
+ model = model.to(device=device, dtype=torch.bfloat16)
211
+
212
+ # Enable gradient checkpointing
213
+ if hasattr(model, 'gradient_checkpointing_enable'):
214
+ model.gradient_checkpointing_enable()
215
+ print("[INFO] Gradient checkpointing enabled")
216
+
217
+ # Reference model: for LoRA DPO, the reference is the base SFT weights
218
+ # (with LoRA adapters disabled) and the policy is base + LoRA. Reference
219
+ # log probs are computed on the fly with adapters disabled, so no separate
220
+ # reference copy is kept in memory. Since LoRA initializes to zero,
221
+ # policy == reference at step 0.
222
+
223
+
224
+ # Apply LoRA
225
+ if args.use_lora:
226
+ n_lora_params = apply_lora(model, rank=args.lora_rank, alpha=args.lora_alpha)
227
+ lora_params = get_lora_params(model)
228
+ print(f"[INFO] LoRA: {n_lora_params:,} trainable params")
229
+ else:
230
+ # Full fine-tuning (risky for VRAM)
231
+ lora_params = None
232
+
233
+ total_params = sum(p.numel() for p in model.parameters())
234
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
235
+ print(f"Total params: {total_params:,}, Trainable: {trainable_params:,}")
236
+
237
+ # Tokenizer
238
+ tokenizer_path = _resolve_tokenizer_path(args)
239
+ print(f"Loading tokenizer from {tokenizer_path}")
240
+ from tokenizers import Tokenizer
241
+ tokenizer = Tokenizer.from_file(str(tokenizer_path))
242
+
243
+ # Dataset
244
+ train_dataset = DPODataset(
245
+ data_path=args.dpo_data,
246
+ tokenizer=tokenizer,
247
+ max_seq_len=args.max_length,
248
+ )
249
+
250
+ train_loader = DataLoader(
251
+ train_dataset,
252
+ batch_size=args.batch_size,
253
+ sampler=RandomSampler(train_dataset),
254
+ num_workers=args.num_workers,
255
+ pin_memory=True,
256
+ drop_last=True,
257
+ collate_fn=dpo_collate_fn,
258
+ prefetch_factor=2,
259
+ persistent_workers=True,
260
+ )
261
+
262
+ # Optimizer — only LoRA params if using LoRA
263
+ if lora_params is not None:
264
+ optimizer = torch.optim.AdamW(
265
+ lora_params,
266
+ lr=args.lr,
267
+ betas=(0.9, 0.95),
268
+ weight_decay=args.weight_decay,
269
+ fused=torch.cuda.is_available(),
270
+ )
271
+ else:
272
+ optimizer = torch.optim.AdamW(
273
+ [p for p in model.parameters() if p.requires_grad],
274
+ lr=args.lr,
275
+ betas=(0.9, 0.95),
276
+ weight_decay=args.weight_decay,
277
+ fused=torch.cuda.is_available(),
278
+ )
279
+
280
+ scheduler = get_cosine_schedule_with_warmup(
281
+ optimizer=optimizer,
282
+ warmup_steps=args.warmup_steps,
283
+ total_steps=args.max_steps,
284
+ )
285
+
286
+ # Resume
287
+ start_step = 0
288
+ if args.resume is not None:
289
+ start_step, _ = load_checkpoint(args.resume, model, optimizer, scheduler)
290
+ print(f"Resumed from step {start_step}")
291
+
292
+ # Checkpoint dir
293
+ args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
294
+
295
+ # Copy tokenizer
296
+ dest_tok = args.checkpoint_dir / "tokenizer.json"
297
+ if not dest_tok.exists():
298
+ shutil.copy2(str(tokenizer_path), str(dest_tok))
299
+
300
+ # Log file
301
+ log_fh = None
302
+ if args.log_file:
303
+ Path(args.log_file).parent.mkdir(parents=True, exist_ok=True)
304
+ log_fh = open(args.log_file, "a", encoding="utf-8", buffering=1)
305
+
306
+ def log(msg: str, level: str = "INFO"):
307
+ import datetime
308
+ ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
309
+ line = f"[{ts}] [{level}] {msg}"
310
+ print(line)
311
+ if log_fh:
312
+ log_fh.write(line + "\n")
313
+
314
+ # Banner
315
+ eff_batch = args.batch_size * args.grad_accum
316
+ log(f"{'='*60}")
317
+ log(f"DPO Training — EVAFRILL-Mo 3B")
318
+ log(f" SFT ckpt: {args.sft_checkpoint}")
319
+ log(f" DPO data: {args.dpo_data} ({len(train_dataset):,} samples)")
320
+ log(f" LoRA: rank={args.lora_rank} alpha={args.lora_alpha}")
321
+ log(f" beta={args.beta}, lr={args.lr:.2e}, eff_batch={eff_batch}")
322
+ log(f" max_steps={args.max_steps}, max_length={args.max_length}")
323
+ log(f" device={device}")
324
+ log(f"{'='*60}")
325
+
326
+ # Training loop
327
+ import time
328
+ model.train()
329
+ loader_iter = iter(train_loader)
330
+ epoch = 0
331
+
332
+ def next_batch():
333
+ nonlocal loader_iter, epoch
334
+ try:
335
+ return next(loader_iter)
336
+ except StopIteration:
337
+ epoch += 1
338
+ loader_iter = iter(train_loader)
339
+ return next(loader_iter)
340
+
341
+ shutdown_requested = False
342
+ def shutdown_handler(signum, frame):
343
+ nonlocal shutdown_requested
344
+ shutdown_requested = True
345
+ log(f"Shutdown signal received ({signum})", "WARN")
346
+
347
+ signal.signal(signal.SIGHUP, shutdown_handler)
348
+ signal.signal(signal.SIGTERM, shutdown_handler)
349
+
350
+ t0 = time.perf_counter()
351
+ running_loss = 0.0
352
+ running_chosen_reward = 0.0
353
+ running_rejected_reward = 0.0
354
+ log_step_count = 0
355
+
356
+ for step in range(start_step, args.max_steps):
357
+ optimizer.zero_grad(set_to_none=True)
358
+ accum_loss = 0.0
359
+
360
+ for micro in range(args.grad_accum):
361
+ batch = next_batch()
362
+ chosen_ids = batch[0].to(device, dtype=torch.long, non_blocking=True)
363
+ chosen_labels = batch[1].to(device, dtype=torch.long, non_blocking=True)
364
+ rejected_ids = batch[2].to(device, dtype=torch.long, non_blocking=True)
365
+ rejected_labels = batch[3].to(device, dtype=torch.long, non_blocking=True)
366
+
367
+ # Policy log probs (with LoRA active)
368
+ policy_chosen_logps = compute_log_probs(model, chosen_ids, chosen_labels)
369
+ policy_rejected_logps = compute_log_probs(model, rejected_ids, rejected_labels)
370
+
371
+ # Reference log probs (LoRA disabled): temporarily zero lora_B so the
+ # adapters contribute nothing, then restore them afterwards.
+ from model.lora import LoRALinear  # hoisted out of the per-module loops
+ with torch.no_grad():
+ # Save and zero LoRA B matrices
+ if args.use_lora:
+ saved_B = []
+ for m in model.modules():
+ if isinstance(m, LoRALinear):
+ saved_B.append(m.lora_B.data.clone())
+ m.lora_B.data.zero_()
+ 
+ ref_chosen_logps = compute_log_probs(model, chosen_ids, chosen_labels)
+ ref_rejected_logps = compute_log_probs(model, rejected_ids, rejected_labels)
+ 
+ # Restore LoRA B matrices
+ if args.use_lora:
+ idx = 0
+ for m in model.modules():
+ if isinstance(m, LoRALinear):
+ m.lora_B.data.copy_(saved_B[idx])
+ idx += 1
394
+
395
+ # DPO loss
396
+ loss, chosen_reward, rejected_reward = dpo_loss(
397
+ policy_chosen_logps, policy_rejected_logps,
398
+ ref_chosen_logps, ref_rejected_logps,
399
+ beta=args.beta,
400
+ )
401
+
402
+ scaled_loss = loss / args.grad_accum
403
+ scaled_loss.backward()
404
+ accum_loss += loss.item()
405
+
406
+ # Gradient clipping
407
+ grad_norm = torch.nn.utils.clip_grad_norm_(
408
+ [p for p in model.parameters() if p.requires_grad], 1.0
409
+ ).item()
410
+
411
+ optimizer.step()
412
+ scheduler.step()
413
+
414
+ avg_loss = accum_loss / args.grad_accum
415
+ running_loss += avg_loss
416
+ running_chosen_reward += chosen_reward.item()
417
+ running_rejected_reward += rejected_reward.item()
418
+ log_step_count += 1
419
+
420
+ # Shutdown check
421
+ if shutdown_requested:
422
+ log(f"Graceful shutdown at step {step + 1}", "WARN")
423
+ save_checkpoint(model, optimizer, scheduler, step + 1, avg_loss, str(args.checkpoint_dir))
424
+ if args.use_lora:
425
+ save_lora(model, args.checkpoint_dir / f"lora-{step+1:07d}")
426
+ break
427
+
428
+ # Logging
429
+ if (step + 1) % args.log_interval == 0:
430
+ t1 = time.perf_counter()
431
+ elapsed = t1 - t0
432
+ avg_l = running_loss / log_step_count
433
+ avg_cr = running_chosen_reward / log_step_count
434
+ avg_rr = running_rejected_reward / log_step_count
435
+ margin = avg_cr - avg_rr
436
+ lr = scheduler.get_last_lr()[0]
437
+ mem_gb = torch.cuda.memory_allocated() / 1e9
438
+
439
+ log(f"step {step+1:>6d} | loss {avg_l:.4f} | "
440
+ f"margin {margin:.4f} (c={avg_cr:.3f} r={avg_rr:.3f}) | "
441
+ f"lr {lr:.2e} | gnorm {grad_norm:.3f} | mem {mem_gb:.1f}GB")
442
+
443
+ running_loss = 0.0
444
+ running_chosen_reward = 0.0
445
+ running_rejected_reward = 0.0
446
+ log_step_count = 0
447
+ t0 = t1
448
+
449
+ # Save checkpoint
450
+ if (step + 1) % args.save_interval == 0:
451
+ ckpt_path = save_checkpoint(
452
+ model, optimizer, scheduler, step + 1, avg_loss, str(args.checkpoint_dir)
453
+ )
454
+ if args.use_lora:
455
+ save_lora(model, args.checkpoint_dir / f"lora-{step+1:07d}")
456
+ log(f"Checkpoint saved -> {ckpt_path}")
457
+
458
+ # Final save
459
+ final_path = save_checkpoint(
460
+ model, optimizer, scheduler, args.max_steps, avg_loss, str(args.checkpoint_dir)
461
+ )
462
+ if args.use_lora:
463
+ save_lora(model, args.checkpoint_dir / "lora-final")
464
+ # Also merge and save merged model
465
+ log("Merging LoRA weights into base model...")
466
+ merge_lora(model)
467
+ model.save_pretrained(args.checkpoint_dir / "checkpoint-merged")
468
+ log(f"Merged model saved -> {args.checkpoint_dir / 'checkpoint-merged'}")
469
+
470
+ log(f"DPO training complete. Final checkpoint -> {final_path}")
471
+
472
+ if log_fh:
473
+ log_fh.close()
474
+
475
+
476
+ if __name__ == "__main__":
477
+ main()
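The `dpo_loss` helper invoked in the training loop above is defined earlier in the script, outside this hunk. For reference, a minimal scalar sketch of the standard DPO objective with the same `(loss, chosen_reward, rejected_reward)` return shape might look like the following; the exact reduction and tensor handling in the real helper are assumptions here:

```python
import math

def dpo_loss_sketch(policy_chosen_logp, policy_rejected_logp,
                    ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Scalar sketch of the DPO objective for one preference pair.

    Inputs are summed response log-probs under the policy and the
    frozen reference. Rewards are beta-scaled policy/reference log-ratios.
    """
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # loss = -log(sigmoid(margin)): pushed down as the chosen margin grows.
    loss = -math.log(1.0 / (1.0 + math.exp(-margin)))
    return loss, chosen_reward, rejected_reward

# At step 0 of LoRA DPO the policy equals the reference, so both rewards
# are 0 and the loss is exactly -log(sigmoid(0)) = log(2) ~ 0.6931.
loss0, cr0, rr0 = dpo_loss_sketch(-10.0, -12.0, -10.0, -12.0)
```

This also explains the training-log `margin` column above: it is the running mean of `chosen_reward - rejected_reward`, which should drift positive as training progresses.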
scripts/evafrill_eval.py ADDED
@@ -0,0 +1,749 @@
+ """
+ EVAFRILL-Mo 3B — comprehensive evaluation pipeline
+ ==================================================
+ 
+ Phase 1: PPL (sequential on 1 GPU, 16 validation sets)
+ Phase 2: generation quality + repetition-rate analysis (cuda:0)
+ Phase 3: calibration (cuda:0)
+ Phase 4: lm-eval benchmarks — via a custom wrapper
+ (belebele_kor_Hang, global_mmlu_full_ko, hellaswag, arc_easy, arc_challenge, kmmlu)
+ 
+ Usage:
+ cd /home/ghong/project-ghong/taketimes/llm-star
+ python eval/evafrill_eval.py
+ python eval/evafrill_eval.py --skip-phase4
+ python eval/evafrill_eval.py --checkpoint checkpoints/3b_final/checkpoint-0319772
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import json
22
+ import math
23
+ import os
24
+ import sys
25
+ import time
26
+ from collections import Counter
27
+ from datetime import datetime
28
+ from pathlib import Path
29
+ from typing import Dict, List, Optional, Tuple
30
+
31
+ import numpy as np
32
+ import torch
33
+ import torch.nn.functional as F
34
+ from torch.utils.data import DataLoader, Dataset
35
+ from tqdm import tqdm
36
+
37
+ _PROJECT_ROOT = Path(__file__).resolve().parent.parent
38
+ if str(_PROJECT_ROOT) not in sys.path:
39
+ sys.path.insert(0, str(_PROJECT_ROOT))
40
+
41
+ from model.transformer import LLM # noqa: E402
42
+ from tokenizers import Tokenizer # noqa: E402
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Constants
46
+ # ---------------------------------------------------------------------------
47
+ DEFAULT_CHECKPOINT = str(_PROJECT_ROOT / "checkpoints" / "3b_final" / "checkpoint-0319772")
48
+ TOKENIZER_PATH = str(_PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json")
49
+ DATA_DIR = _PROJECT_ROOT / "data"
50
+ OUTPUT_DIR = _PROJECT_ROOT / "eval" / "outputs"
51
+
52
+ # GPUs available
53
+ N_GPUS = 1
54
+ GPU_IDS = [0]
55
+
56
+ # Korean generation prompts (15)
57
+ PROMPTS = [
58
+ "대한민국의 수도는",
59
+ "인공지능이란",
60
+ "한국의 전통 음식 중에서",
61
+ "지구 온난화의 주요 원인은",
62
+ "프로그래밍을 배우려면",
63
+ "조선시대에는",
64
+ "물리학에서 에너지란",
65
+ "한국어는 세계에서",
66
+ "경제 성장을 위해서는",
67
+ "우주 탐사의 역사를 보면",
68
+ "머신러닝과 딥러닝의 차이는",
69
+ "한국 문학의 대표적인 작품으로는",
70
+ "양자 컴퓨터란",
71
+ "건강한 식습관을 위해서는",
72
+ "세계 2차 대전 이후",
73
+ ]
74
+
75
+ # PPL tasks: GPU id → list of val files
76
+ PPL_TASKS: Dict[int, List[str]] = {
77
+ 0: [
78
+ "3b_val.bin",
79
+ "korean_c4_val.bin", "korean_val.bin",
80
+ "hplt_ko_val.bin", "cc100_ko_val.bin",
81
+ "korean_wiki_val.bin", "korean_namuwiki_val.bin",
82
+ "cosmo_auto_math_text_val.bin", "cosmo_stories_val.bin", "cosmo_web_v2_val.bin",
83
+ "cosmo_stanford_val.bin", "cosmo_khanacademy_val.bin", "cosmo_openstax_val.bin", "cosmo_wikihow_val.bin",
84
+ "mathpile_val.bin", "open_web_math_val.bin",
85
+ ],
86
+ }
87
+
88
+
89
+ # ===========================================================================
90
+ # Argument parsing
91
+ # ===========================================================================
92
+
93
+ def parse_args() -> argparse.Namespace:
94
+ parser = argparse.ArgumentParser(description="EVAFRILL-Mo 종합 평가")
95
+ parser.add_argument("--checkpoint", default=DEFAULT_CHECKPOINT)
96
+ parser.add_argument("--output-dir", default=None)
97
+ parser.add_argument("--seq-len", type=int, default=2048)
98
+ parser.add_argument("--stride", type=int, default=512)
99
+ parser.add_argument("--batch-size", type=int, default=2)
100
+ parser.add_argument("--max-new-tokens", type=int, default=256)
101
+ parser.add_argument("--skip-phase1", action="store_true")
102
+ parser.add_argument("--skip-phase2", action="store_true")
103
+ parser.add_argument("--skip-phase3", action="store_true")
104
+ parser.add_argument("--skip-phase4", action="store_true")
105
+ parser.add_argument("--limit", type=int, default=None,
106
+ help="Limit examples per lm-eval task (for fast testing)")
107
+ parser.add_argument("--exclude-tasks", type=str, default=None,
108
+ help="Comma-separated lm-eval tasks to exclude (e.g. kmmlu)")
109
+ return parser.parse_args()
110
+
111
+
112
+ # ===========================================================================
113
+ # Sliding-window PPL dataset
114
+ # ===========================================================================
115
+
116
+ class BinDataset(Dataset):
117
+ def __init__(self, path: str, seq_len: int, stride: int):
118
+ data = np.fromfile(path, dtype=np.uint16)
119
+ self.data = torch.from_numpy(data.astype(np.int64))
120
+ self.seq_len = seq_len
121
+ self.stride = stride
122
+ self.indices = list(range(0, max(1, len(self.data) - seq_len), stride))
123
+
124
+ def __len__(self):
125
+ return len(self.indices)
126
+
127
+ def __getitem__(self, idx):
128
+ start = self.indices[idx]
129
+ chunk = self.data[start: start + self.seq_len + 1]
130
+ if len(chunk) < self.seq_len + 1:
131
+ chunk = F.pad(chunk, (0, self.seq_len + 1 - len(chunk)))
132
+ return chunk[:-1], chunk[1:]
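One caveat with the sliding-window dataset above: whenever `stride < seq_len`, overlapping windows score the same target positions more than once, so the reported PPL is a windowed approximation rather than an exact corpus PPL. A small self-contained sketch (mirroring the `range(0, max(1, n - seq_len), stride)` index generation above, but not code from this repo) makes the overlap visible:

```python
def window_starts(n_tokens, seq_len, stride):
    """Start offsets used by the sliding-window dataset above."""
    return list(range(0, max(1, n_tokens - seq_len), stride))

def tokens_scored_per_position(n_tokens, seq_len, stride):
    """How many windows score each target position (1 == no double counting).

    Each window starting at s scores targets s+1 .. s+seq_len, matching
    the (chunk[:-1], chunk[1:]) input/target split.
    """
    counts = [0] * n_tokens
    for s in window_starts(n_tokens, seq_len, stride):
        for pos in range(s + 1, min(s + seq_len + 1, n_tokens)):
            counts[pos] += 1
    return counts

overlapped = tokens_scored_per_position(10, 4, 2)  # stride < seq_len
exact = tokens_scored_per_position(10, 4, 4)       # stride == seq_len
```

With `stride == seq_len` every position is scored at most once; a stride-correct evaluator would instead score only the last `stride` targets of each overlapping window.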
133
+
134
+
135
+ # ===========================================================================
136
+ # PPL worker (called in-process by run_phase1; kept self-contained so it
+ # could also run in a separate process)
137
+ # ===========================================================================
138
+
139
+ def _ppl_worker(
140
+ checkpoint: str,
141
+ gpu_id: int,
142
+ val_files: List[str],
143
+ data_dir: str,
144
+ seq_len: int,
145
+ stride: int,
146
+ batch_size: int,
147
+ ) -> Dict[str, float]:
148
+ """Compute PPL for several val files on one GPU."""
149
+ import torch
150
+ import sys
151
+ from pathlib import Path
152
+ sys.path.insert(0, str(Path(checkpoint).parent.parent.parent)) # project root
153
+
154
+ from model.transformer import LLM # noqa
155
+
156
+ device = f"cuda:{gpu_id}"
157
+ model = LLM.from_pretrained(checkpoint)
158
+ model = model.to(device=device, dtype=torch.bfloat16)
159
+ model.eval()
160
+
161
+ results = {}
162
+ for fname in val_files:
163
+ fpath = Path(data_dir) / fname
164
+ if not fpath.exists():
165
+ results[fname.replace("_val.bin", "")] = None
166
+ continue
167
+
168
+ ds = BinDataset(str(fpath), seq_len, stride)
169
+ loader = DataLoader(ds, batch_size=batch_size, num_workers=0, pin_memory=True)
170
+
171
+ total_nll = 0.0
172
+ total_tokens = 0
173
+ with torch.no_grad():
174
+ for x, y in loader:
175
+ x, y = x.to(device), y.to(device)
176
+ logits, _ = model(x)
177
+ loss = F.cross_entropy(
178
+ logits.reshape(-1, logits.size(-1)),
179
+ y.reshape(-1),
180
+ reduction="sum",
181
+ ignore_index=0,
182
+ )
183
+ valid = (y != 0).sum().item()
184
+ total_nll += loss.item()
185
+ total_tokens += valid
186
+
187
+ ppl = math.exp(total_nll / max(total_tokens, 1))
188
+ key = fname.replace("_val.bin", "")
189
+ results[key] = round(ppl, 4)
190
+ print(f"[GPU {gpu_id}] {key}: PPL={ppl:.4f}", flush=True)
191
+
192
+ return results
193
+
194
+
195
+ # ===========================================================================
196
+ # Phase 1: PPL (sequential on a single GPU)
197
+ # ===========================================================================
198
+
199
+ def run_phase1(checkpoint: str, seq_len: int, stride: int, batch_size: int) -> Dict[str, float]:
200
+ print("\n" + "=" * 60)
201
+ print("Phase 1: PPL 평가 (1-GPU 순차)")
202
+ print("=" * 60)
203
+ t0 = time.time()
204
+
205
+ existing = [f for f in PPL_TASKS[0] if (DATA_DIR / f).exists()]
206
+ if not existing:
207
+ print(" 평가할 val 파일 없음")
208
+ return {}
209
+
210
+ all_results = _ppl_worker(
211
+ checkpoint=checkpoint,
212
+ gpu_id=0,
213
+ val_files=existing,
214
+ data_dir=str(DATA_DIR),
215
+ seq_len=seq_len,
216
+ stride=stride,
217
+ batch_size=batch_size,
218
+ )
219
+
220
+ elapsed = time.time() - t0
221
+ print(f"\n Phase 1 완료 ({elapsed:.1f}s)")
222
+ return all_results
223
+
224
+
225
+ # ===========================================================================
226
+ # Phase 2: generation quality + repetition rate
227
+ # ===========================================================================
228
+
229
+ def _ngram_repetition(tokens: List[int], n: int) -> float:
230
+ if len(tokens) < n:
231
+ return 0.0
232
+ ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
233
+ total = len(ngrams)
234
+ unique = len(set(ngrams))
235
+ return round(1.0 - unique / total, 4) if total > 0 else 0.0
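The repetition metric above is the share of duplicate n-grams: a fully looping sequence approaches 1.0 while a non-repeating one scores 0.0. A self-contained copy of the same definition, for illustration:

```python
from typing import List

def ngram_repetition(tokens: List[int], n: int) -> float:
    """Share of duplicate n-grams: 1 - (#unique n-grams / #n-grams)."""
    if len(tokens) < n:
        return 0.0
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    return round(1.0 - len(set(ngrams)) / len(ngrams), 4)

# A looping sequence has few unique trigrams relative to its length:
looping = [1, 2, 3] * 3        # 7 trigrams, only 3 unique -> 1 - 3/7
fresh = list(range(9))         # every trigram unique -> 0.0
```

This is why the Phase 2 log prints `rep3` next to each greedy sample: greedy decoding on an undertrained LM tends to cycle, which shows up directly as a high 3-gram repetition share.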
236
+
237
+
238
+ def run_phase2(checkpoint: str, max_new_tokens: int) -> List[Dict]:
239
+ print("\n" + "=" * 60)
240
+ print("Phase 2: 생성 품질 + 반복률")
241
+ print("=" * 60)
242
+
243
+ device = "cuda:0"
244
+ model = LLM.from_pretrained(checkpoint)
245
+ model = model.to(device=device, dtype=torch.bfloat16)
246
+ model.eval()
247
+
248
+ tok = Tokenizer.from_file(TOKENIZER_PATH)
249
+
250
+ results = []
251
+ configs = [
252
+ ("greedy", 0.0, 1.0),
253
+ ("t0.7", 0.7, 1.0),
254
+ ("t0.7_r1.2", 0.7, 1.2),
255
+ ("t0.9_r1.1", 0.9, 1.1),
256
+ ]
257
+
258
+ for prompt in PROMPTS:
259
+ ids = tok.encode(prompt).ids
260
+ x = torch.tensor([ids], dtype=torch.long, device=device)
261
+
262
+ row = {"prompt": prompt, "configs": {}}
263
+ for cfg_name, temp, rep_pen in configs:
264
+ with torch.no_grad():
265
+ generated = list(ids)
266
+ for _ in range(max_new_tokens):
267
+ inp = torch.tensor([generated[-2048:]], dtype=torch.long, device=device)
268
+ logits, _ = model(inp)
269
+ logits = logits[:, -1, :]
270
+
271
+ # Repetition penalty (CTRL-style): dividing a *negative* logit by
+ # rep_pen > 1 would raise it, so multiply instead in that case.
+ if rep_pen != 1.0:
+ for tok_id in set(generated[-64:]):
+ if logits[0, tok_id] > 0:
+ logits[0, tok_id] /= rep_pen
+ else:
+ logits[0, tok_id] *= rep_pen
275
+
276
+ if temp == 0.0:
277
+ next_tok = logits.argmax(dim=-1).item()
278
+ else:
279
+ probs = torch.softmax(logits / temp, dim=-1)
280
+ next_tok = torch.multinomial(probs[0], 1).item()
281
+
282
+ generated.append(next_tok)
283
+ if next_tok in (tok.token_to_id("</s>"), tok.token_to_id("<eos>"), 2):
284
+ break
285
+
286
+ new_ids = generated[len(ids):]
287
+ text = tok.decode(new_ids)
288
+ rep3 = _ngram_repetition(new_ids, 3)
289
+ rep4 = _ngram_repetition(new_ids, 4)
290
+ eos_hit = new_ids[-1] in (2,) if new_ids else False
291
+
292
+ row["configs"][cfg_name] = {
293
+ "text": text,
294
+ "tokens": len(new_ids),
295
+ "3gram_rep": rep3,
296
+ "4gram_rep": rep4,
297
+ "eos": eos_hit,
298
+ }
299
+
300
+ results.append(row)
301
+ greedy = row["configs"]["greedy"]
302
+ print(f"\n[{prompt}]")
303
+ print(f" greedy({greedy['tokens']}tok, rep3={greedy['3gram_rep']:.2%}): {greedy['text'][:120]}")
304
+
305
+ del model
306
+ torch.cuda.empty_cache()
307
+ return results
308
+
309
+
310
+ # ===========================================================================
311
+ # Phase 3: Calibration
312
+ # ===========================================================================
313
+
314
+ def run_phase3(checkpoint: str) -> Dict:
315
+ print("\n" + "=" * 60)
316
+ print("Phase 3: Calibration 체크")
317
+ print("=" * 60)
318
+
319
+ device = "cuda:0"
320
+ model = LLM.from_pretrained(checkpoint)
321
+ model = model.to(device=device, dtype=torch.bfloat16)
322
+ model.eval()
323
+
324
+ val_path = DATA_DIR / "3b_val.bin"
325
+ if not val_path.exists():
326
+ print(" 3b_val.bin 없음 — 스킵")
327
+ return {}
328
+
329
+ ds = BinDataset(str(val_path), seq_len=512, stride=256)
330
+ loader = DataLoader(ds, batch_size=8, num_workers=0)
331
+
332
+ top1 = top5 = top10 = total = 0
333
+ mean_probs, mean_entropies = [], []
334
+
335
+ CALIB_TOKENS = 50_000
336
+ token_count = 0
337
+
338
+ with torch.no_grad():
339
+ for x, y in loader:
340
+ x, y = x.to(device), y.to(device)
341
+ logits, _ = model(x)
342
+ probs = torch.softmax(logits, dim=-1)
343
+
344
+ mask = (y != 0)
345
+ labels = y[mask]
346
+ p = probs[mask]
347
+
348
+ ranks = (p > p.gather(1, labels.unsqueeze(1))).sum(dim=1)
349
+ top1 += (ranks < 1).sum().item()
350
+ top5 += (ranks < 5).sum().item()
351
+ top10 += (ranks < 10).sum().item()
352
+
353
+ chosen_p = p.gather(1, labels.unsqueeze(1)).squeeze(1)
354
+ mean_probs.append(chosen_p.mean().item())
355
+
356
+ ent = -(p * (p + 1e-10).log()).sum(dim=-1) # p already masked → 1D
357
+ mean_entropies.append(ent.mean().item())
358
+
359
+ total += labels.size(0)
360
+ token_count += labels.size(0)
361
+ if token_count >= CALIB_TOKENS:
362
+ break
363
+
364
+ result = {
365
+ "top1_acc": round(top1 / total, 4),
366
+ "top5_acc": round(top5 / total, 4),
367
+ "top10_acc": round(top10 / total, 4),
368
+ "mean_prob": round(float(np.mean(mean_probs)), 4),
369
+ "mean_entropy": round(float(np.mean(mean_entropies)), 4),
370
+ "total_tokens": total,
371
+ }
372
+ print(f" Top-1: {result['top1_acc']:.2%} Top-5: {result['top5_acc']:.2%} Top-10: {result['top10_acc']:.2%}")
373
+ print(f" Mean prob: {result['mean_prob']:.4f} Entropy: {result['mean_entropy']:.4f}")
374
+
375
+ del model
376
+ torch.cuda.empty_cache()
377
+ return result
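The rank computation in Phase 3 counts how many vocabulary entries have strictly higher probability than the gold token, so rank 0 means the gold token is the argmax. A small pure-Python sketch of the same top-k logic (illustrative, not code from this repo):

```python
def topk_hits(prob_rows, labels, k):
    """Count rows whose gold label falls within the top-k probabilities.

    Mirrors `ranks = (p > p_gold).sum(dim=1)` above: the gold token's rank
    is the number of entries with strictly higher probability, and the row
    is a hit when rank < k.
    """
    hits = 0
    for probs, gold in zip(prob_rows, labels):
        rank = sum(1 for p in probs if p > probs[gold])
        if rank < k:
            hits += 1
    return hits

rows = [[0.1, 0.6, 0.3],   # gold=1 -> rank 0 (argmax)
        [0.5, 0.2, 0.3]]   # gold=2 -> rank 1 (only 0.5 is higher)
labels = [1, 2]
```

Dividing these hit counts by the number of scored tokens gives the `top1_acc`/`top5_acc`/`top10_acc` fields in the report.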
378
+
379
+
380
+ # ===========================================================================
381
+ # Phase 4: lm-eval benchmarks (custom wrapper)
382
+ # ===========================================================================
383
+
384
+ def run_phase4(checkpoint: str, limit: int = None, exclude_tasks: str = None) -> Dict:
385
+ print("\n" + "=" * 60)
386
+ print("Phase 4: lm-eval 벤치마크")
387
+ print("=" * 60)
388
+
389
+ try:
390
+ import lm_eval
391
+ from lm_eval.api.model import LM as BaseLM
392
+ from lm_eval.api.instance import Instance
393
+ from lm_eval import evaluator
394
+ except ImportError:
395
+ print(" lm-eval 미설치 — 스킵 (pip install lm-eval)")
396
+ return {}
397
+
398
+ device = "cuda:0"
399
+
400
+ class EvafrillLM(BaseLM):
401
+ """Wrapper that plugs EVAFRILL-Mo into lm-eval-harness."""
402
+
403
+ def __init__(self, checkpoint: str, device: str, batch_size: int = 8):
404
+ super().__init__()
405
+ self._model = LLM.from_pretrained(checkpoint)
406
+ self._model = self._model.to(device=device, dtype=torch.bfloat16)
407
+ self._model.eval()
408
+ self._tok = Tokenizer.from_file(TOKENIZER_PATH)
409
+ self._device = device
410
+ self._batch_size = batch_size
411
+ self._max_len = 4096
412
+
413
+ @property
414
+ def eot_token_id(self) -> int:
415
+ return 2 # </s>
416
+
417
+ @property
418
+ def max_length(self) -> int:
419
+ return self._max_len
420
+
421
+ @property
422
+ def max_gen_toks(self) -> int:
423
+ return 256
424
+
425
+ @property
426
+ def batch_size(self) -> int:
427
+ return self._batch_size
428
+
429
+ @property
430
+ def device(self):
431
+ return self._device
432
+
433
+ def tok_encode(self, string: str) -> List[int]:
434
+ return self._tok.encode(string).ids
435
+
436
+ def tok_decode(self, tokens) -> str:
437
+ return self._tok.decode(list(tokens))
438
+
439
+ def _model_call(self, inps: torch.Tensor) -> torch.Tensor:
440
+ with torch.no_grad():
441
+ logits, _ = self._model(inps.to(self._device))
442
+ return logits
443
+
444
+ def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
445
+ results = []
446
+ for req in requests:
447
+ ctx, cont = req.args[0], req.args[1]
448
+ ctx_ids = self.tok_encode(ctx)
449
+ cont_ids = self.tok_encode(cont)
450
+
451
+ all_ids = ctx_ids + cont_ids
452
+ if len(all_ids) > self._max_len:
453
+ all_ids = all_ids[-self._max_len:]
454
+ # adjust cont boundary
455
+ cont_start = len(all_ids) - len(cont_ids)
456
+ else:
457
+ cont_start = len(ctx_ids)
458
+
459
+ inp = torch.tensor([all_ids[:-1]], dtype=torch.long)
460
+ tgt = torch.tensor([all_ids[1:]], dtype=torch.long)
461
+
462
+ logits = self._model_call(inp)
463
+ log_probs = F.log_softmax(logits, dim=-1)
464
+
465
+ # sum log-probs over continuation tokens
466
+ cont_log_prob = 0.0
467
+ is_greedy = True
468
+ for i, t in enumerate(cont_ids):
469
+ pos = cont_start - 1 + i
470
+ if pos >= log_probs.size(1):
471
+ break
472
+ cont_log_prob += log_probs[0, pos, t].item()
473
+ pred = log_probs[0, pos].argmax().item()
474
+ if pred != t:
475
+ is_greedy = False
476
+
477
+ results.append((cont_log_prob, is_greedy))
478
+ return results
479
+
480
+ def loglikelihood_rolling(self, requests) -> List[float]:
481
+ results = []
482
+ for req in requests:
483
+ text = req.args[0]
484
+ ids = self.tok_encode(text)
485
+ total_nll = 0.0
486
+ for start in range(0, len(ids) - 1, self._max_len - 1):
487
+ chunk = ids[start: start + self._max_len]
488
+ if len(chunk) < 2:
489
+ break
490
+ inp = torch.tensor([chunk[:-1]], dtype=torch.long)
491
+ tgt = torch.tensor([chunk[1:]], dtype=torch.long)
492
+ logits = self._model_call(inp)
493
+ nll = F.cross_entropy(
494
+ logits[0], tgt[0].to(self._device), reduction="sum"
495
+ ).item()
496
+ total_nll += nll
497
+ results.append(-total_nll)
498
+ return results
499
+
500
+ def generate_until(self, requests) -> List[str]:
501
+ results = []
502
+ for req in requests:
503
+ ctx = req.args[0]
504
+ gen_kwargs = req.args[1] if len(req.args) > 1 else {}
505
+ until = gen_kwargs.get("until", [])
506
+ max_gen = gen_kwargs.get("max_gen_toks", self.max_gen_toks)
507
+ temp = gen_kwargs.get("temperature", 0.0)
508
+
509
+ ids = self.tok_encode(ctx)
510
+ generated = list(ids)
511
+
512
+ with torch.no_grad():
513
+ for _ in range(max_gen):
514
+ inp = torch.tensor(
515
+ [generated[-self._max_len:]], dtype=torch.long
516
+ )
517
+ logits = self._model_call(inp)[:, -1:, :].squeeze(1)
518
+ if temp == 0.0:
519
+ next_tok = logits.argmax(dim=-1).item()
520
+ else:
521
+ probs = torch.softmax(logits / temp, dim=-1)
522
+ next_tok = torch.multinomial(probs[0], 1).item()
523
+ generated.append(next_tok)
524
+ if next_tok == self.eot_token_id:
525
+ break
526
+ decoded_new = self.tok_decode(generated[len(ids):])
527
+ if any(stop in decoded_new for stop in until):
528
+ break
529
+
530
+ new_text = self.tok_decode(generated[len(ids):])
531
+ for stop in until:
532
+ if stop in new_text:
533
+ new_text = new_text[:new_text.index(stop)]
534
+ results.append(new_text)
535
+ return results
536
+
537
+ lm = EvafrillLM(checkpoint, device=device, batch_size=2)
538
+
539
+ tasks = [
540
+ "belebele_kor_Hang",
541
+ "global_mmlu_full_ko",
542
+ "hellaswag",
543
+ "arc_easy",
544
+ "arc_challenge",
545
+ "kmmlu",
546
+ ]
547
+
548
+ if exclude_tasks:
549
+ excluded = {t.strip() for t in exclude_tasks.split(",")}
550
+ tasks = [t for t in tasks if t not in excluded]
551
+ print(f" 제외: {', '.join(excluded)}")
552
+
553
+ print(f" 태스크: {', '.join(tasks)}")
554
+ print(" (belebele/mmlu: 한국어, hellaswag/arc: 영어)")
555
+ if limit:
556
+ print(f" limit: {limit} examples/task")
557
+
558
+ try:
559
+ results = evaluator.simple_evaluate(
560
+ model=lm,
561
+ tasks=tasks,
562
+ num_fewshot=0,
563
+ batch_size=2,
564
+ log_samples=False,
565
+ limit=limit,
566
+ )
567
+ return results.get("results", {})
568
+ except Exception as e:
569
+ print(f" lm-eval 오류: {e}")
570
+ import traceback; traceback.print_exc()
571
+ return {}
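For the multiple-choice tasks above, lm-eval calls the wrapper's `loglikelihood` once per answer option and picks the highest-scoring continuation. Given the `(logprob, is_greedy)` tuples the wrapper returns, option selection reduces to an argmax; a sketch of that harness-side logic (assumed, not code from this repo):

```python
def pick_option(option_logprobs):
    """Index of the answer option with the highest continuation log-prob."""
    best_idx, best_lp = 0, float("-inf")
    for i, (lp, _is_greedy) in enumerate(option_logprobs):
        if lp > best_lp:
            best_idx, best_lp = i, lp
    return best_idx

# One (logprob, is_greedy) tuple per answer option, as loglikelihood returns:
scored = [(-12.3, False), (-8.1, False), (-15.0, False), (-9.9, True)]
```

Accuracy is then the fraction of questions where the picked option matches the gold answer, which is why a 4-way task has the 0.25 random baseline listed in the report table.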
572
+
573
+
574
+ # ===========================================================================
575
+ # Report generation
576
+ # ===========================================================================
577
+
578
+ def generate_report(
579
+ checkpoint: str,
580
+ output_dir: Path,
581
+ ppl: Dict,
582
+ gen: List[Dict],
583
+ calib: Dict,
584
+ bench: Dict,
585
+ elapsed: float,
586
+ ) -> Path:
587
+ now = datetime.now().strftime("%Y-%m-%d %H:%M")
588
+ run_tag = datetime.now().strftime("%Y%m%d_%H%M")
589
+ report_path = _PROJECT_ROOT / "reports" / f"{run_tag}_EVAFRILL_EVAL_REPORT.md"
590
+ report_path.parent.mkdir(parents=True, exist_ok=True)
591
+
592
+ lines = [
593
+ "# EVAFRILL-Mo 3B — 종합 평가 보고서",
594
+ "",
595
+ f"- **평가 일시**: {now}",
596
+ f"- **체크포인트**: `{Path(checkpoint).name}`",
597
+ f"- **총 소요 시간**: {elapsed/60:.1f}분",
598
+ "",
599
+ "---",
600
+ "",
601
+ "## 1. Executive Summary",
602
+ "",
603
+ ]
604
+
605
+ # PPL summary
606
+ if ppl:
607
+ # Parenthesize the boolean: `and` binds tighter than `or`, so the original
+ # expression let None values through for hplt/cc100 sets and could crash np.mean.
+ ko_vals = [v for k, v in ppl.items()
+ if v is not None and ("korean" in k or "hplt" in k or "cc100" in k)]
+ avg_ko = float(np.mean(ko_vals)) if ko_vals else float("nan")
608
+ lines += [
609
+ "### PPL (주요 셋)",
610
+ "",
611
+ "| 데이터셋 | PPL |",
612
+ "|---------|-----|",
613
+ ]
614
+ for k, v in sorted(ppl.items()):
615
+ if v is not None:
616
+ lines.append(f"| {k} | {v:.4f} |")
617
+ lines.append("")
618
+
619
+ # Generation summary
620
+ if gen:
621
+ greedy_reps = [r["configs"]["greedy"]["3gram_rep"] for r in gen if "greedy" in r["configs"]]
622
+ greedy_eos = [r["configs"]["greedy"]["eos"] for r in gen if "greedy" in r["configs"]]
623
+ t07r12_reps = [r["configs"].get("t0.7_r1.2", {}).get("3gram_rep", None) for r in gen]
624
+ t07r12_reps = [x for x in t07r12_reps if x is not None]
625
+
626
+ lines += [
627
+ "### 생성 품질 요약",
628
+ "",
629
+ f"| 설정 | 평균 3-gram 반복률 | EOS 종료율 |",
630
+ f"|------|-------------------|-----------|",
631
+ f"| greedy | {np.mean(greedy_reps):.2%} | {np.mean(greedy_eos):.0%} |",
632
+ ]
633
+ if t07r12_reps:
634
+ t07r12_eos = [r["configs"].get("t0.7_r1.2", {}).get("eos", False) for r in gen]
635
+ lines.append(f"| temp=0.7 rep=1.2 | {np.mean(t07r12_reps):.2%} | {np.mean(t07r12_eos):.0%} |")
636
+ lines.append("")
637
+
638
+ # Calibration
639
+ if calib:
640
+ lines += [
641
+ "### Calibration",
642
+ "",
643
+ f"| Top-1 | Top-5 | Top-10 |",
644
+ f"|-------|-------|--------|",
645
+ f"| {calib['top1_acc']:.2%} | {calib['top5_acc']:.2%} | {calib['top10_acc']:.2%} |",
646
+ "",
647
+ ]
648
+
649
+ # Benchmarks
650
+ if bench:
651
+ lines += [
652
+ "### lm-eval 벤치마크",
653
+ "",
654
+ "| 태스크 | Accuracy | 랜덤 기준 |",
655
+ "|--------|----------|----------|",
656
+ ]
657
+ random_baseline = {
658
+ "belebele_kor_Hang": 0.25,
659
+ "global_mmlu_full_ko": 0.25,
660
+ "hellaswag": 0.25,
661
+ "arc_easy": 0.25,
662
+ "arc_challenge": 0.25,
663
+ "kmmlu": 0.25,
664
+ }
665
+ for task, res in bench.items():
666
+ acc = res.get("acc,none", res.get("acc", "N/A"))
667
+ rb = random_baseline.get(task, "?")
668
+ lines.append(f"| {task} | {acc:.4f} | {rb} |")
669
+ lines.append("")
670
+
671
+ # Generation samples
672
+ if gen:
673
+ lines += ["## 2. 생성 샘플 (Greedy)", ""]
674
+ for r in gen:
675
+ gcfg = r["configs"].get("greedy", {})
676
+ lines += [
677
+ f"**[{r['prompt']}]**",
678
+ f"> {gcfg.get('text', '')[:200]}",
679
+ f"> *EOS={gcfg.get('eos')}, 3gram_rep={gcfg.get('3gram_rep', 0):.2%}, tokens={gcfg.get('tokens')}*",
680
+ "",
681
+ ]
682
+
683
+ report_path.write_text("\n".join(lines), encoding="utf-8")
684
+ print(f"\n 보고서 저장: {report_path}")
685
+
686
+ # JSON 결과도 저장
687
+ json_path = output_dir / "evafrill_eval_results.json"
688
+ json_path.parent.mkdir(parents=True, exist_ok=True)
689
+ with open(json_path, "w", encoding="utf-8") as f:
690
+ json.dump({"ppl": ppl, "calib": calib, "bench": bench}, f, ensure_ascii=False, indent=2)
691
+
692
+ return report_path
693
+
694
+
695
+ # ===========================================================================
696
+ # Main
697
+ # ===========================================================================
698
+
699
+ def main():
700
+ args = parse_args()
701
+ t_start = time.time()
702
+
703
+ run_tag = datetime.now().strftime("%Y%m%d_%H%M")
704
+ output_dir = Path(args.output_dir) if args.output_dir else (
705
+ _PROJECT_ROOT / "eval" / "outputs" / f"evafrill_eval_{run_tag}"
706
+ )
707
+ output_dir.mkdir(parents=True, exist_ok=True)
708
+
709
+ print("=" * 60)
710
+ print("EVAFRILL-Mo 3B 종합 평가 시작")
711
+ print(f"체크포인트: {args.checkpoint}")
712
+ print(f"출력 디렉토리: {output_dir}")
713
+ print("=" * 60)
714
+
715
+ ppl_results = {}
716
+ gen_results = []
717
+ calib_results = {}
718
+ bench_results = {}
719
+
720
+ if not args.skip_phase1:
721
+ ppl_results = run_phase1(
722
+ args.checkpoint, args.seq_len, args.stride, args.batch_size
723
+ )
724
+
725
+ if not args.skip_phase2:
726
+ gen_results = run_phase2(args.checkpoint, args.max_new_tokens)
727
+
728
+ if not args.skip_phase3:
729
+ calib_results = run_phase3(args.checkpoint)
730
+
731
+ if not args.skip_phase4:
732
+ bench_results = run_phase4(args.checkpoint, limit=args.limit,
733
+ exclude_tasks=args.exclude_tasks)
734
+
735
+ elapsed = time.time() - t_start
736
+ report_path = generate_report(
737
+ args.checkpoint, output_dir,
738
+ ppl_results, gen_results, calib_results, bench_results,
739
+ elapsed,
740
+ )
741
+
742
+ print("\n" + "=" * 60)
743
+ print(f"평가 완료! 총 {elapsed/60:.1f}분")
744
+ print(f"보고서: {report_path}")
745
+ print("=" * 60)
746
+
747
+
748
+ if __name__ == "__main__":
749
+ main()
scripts/generate_repetition_preference.py ADDED
@@ -0,0 +1,480 @@
+#!/usr/bin/env python3
+"""
+data/generate_repetition_preference.py — Self-play preference data targeting repetition.
+
+Generates (prompt, chosen, rejected) pairs by:
+  - rejected: greedy decoding (temp=0, rep_penalty=1.0) → tends to repeat
+  - chosen: sampling with repetition penalty (temp=0.7, rep_penalty=1.2) → cleaner
+
+Only keeps pairs where rejected has a strictly higher 3-gram repetition rate than chosen.
+
+Usage:
+    python3 data/generate_repetition_preference.py \
+        --checkpoint checkpoints/3b_dpo/checkpoint-slerp
+
+    python3 data/generate_repetition_preference.py \
+        --checkpoint checkpoints/3b_dpo/checkpoint-slerp \
+        --output data/preference/repetition_preference.jsonl \
+        --num_prompts 100 \
+        --max_tokens 256
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import math
+import os
+import sys
+import time
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+
+_PROJECT_ROOT = Path(__file__).resolve().parent.parent
+if str(_PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(_PROJECT_ROOT))
+
+from model import LLM  # noqa: E402
+from tokenizers import Tokenizer  # noqa: E402
+
+# ---------------------------------------------------------------------------
+# Korean prompt bank — 100+ diverse prompts
+# ---------------------------------------------------------------------------
+
+# 15 existing eval prompts (completion style → wrapped in chat template)
+_EVAL_PROMPTS = [
+    "대한민국의 수도는 어디인지 설명해주세요.",
+    "인공지능이란 무엇인지 자세히 설명해주세요.",
+    "한국의 전통 음식 중에서 대표적인 것들을 소개해주세요.",
+    "지구 온난화의 주요 원인은 무엇인가요?",
+    "프로그래밍을 배우려면 어떻게 시작해야 하나요?",
+    "조선시대에는 어떤 일들이 있었나요?",
+    "물리학에서 에너지란 무엇인지 설명해주세요.",
+    "한국어는 세계에서 어떤 특징을 가지고 있나요?",
+    "경제 성장을 위해서는 무엇이 필요한가요?",
+    "우주 탐사의 역사를 간단히 설명해주세요.",
+    "머신러닝과 딥러닝의 차이는 무엇인가요?",
+    "한국 문학의 대표적인 작품으로는 어떤 것들이 있나요?",
+    "양자 컴퓨터란 무엇인지 설명해주세요.",
+    "건강한 식습관을 위해서는 어떻게 해야 하나요?",
+    "세계 2차 대전 이후 세계는 어떻게 변했나요?",
+]
+
+# Additional diverse prompts (90 more)
+_EXTRA_PROMPTS = [
+    # Daily conversation
+    "오늘 날씨가 좋은데 뭐 하면 좋을까요?",
+    "주말에 뭐 하면 좋을지 추천해주세요.",
+    "좋은 하루를 시작하는 방법을 알려주세요.",
+    "집에서 할 수 있는 취미 활동을 추천해주세요.",
+    "친구와 싸웠을 때 어떻게 화해하면 좋을까요?",
+    "외로움을 느낄 때 어떻게 극복할 수 있나요?",
+    "시간 관리를 잘 하는 방법을 알려주세요.",
+    "아침 일찍 일어나는 습관을 만들려면 어떻게 해야 하나요?",
+    "새로운 도시로 이사했을 때 적응하는 방법은?",
+    "카페에서 혼자 시간 보내는 것의 장점은 무엇인가요?",
+
+    # Knowledge - science
+    "DNA가 무엇인지 설명해주세요.",
+    "블랙홀이란 무엇인가요?",
+    "진화론이란 무엇인지 간단히 설명해주세요.",
+    "기후 변화가 생태계에 미치는 영향은 무엇인가요?",
+    "인체의 면역 시스템은 어떻게 작동하나요?",
+    "빛의 속도는 왜 중요한가요?",
+    "원자와 분자의 차이점은 무엇인가요?",
+    "광합성이란 무엇인지 설명해주세요.",
+    "중력파란 무엇인가요?",
+    "줄기세포 치료란 무엇이며 어떻게 활용되나요?",
+
+    # Knowledge - history & society
+    "한국의 역사에서 가장 중요한 사건은 무엇인가요?",
+    "민주주의란 무엇인지 설명해주세요.",
+    "산업혁명이 세계에 미친 영향은 무엇인가요?",
+    "냉전이란 무엇이었나요?",
+    "한국 전쟁의 원인과 결과를 설명해주세요.",
+    "세계화란 무엇이며 어떤 영향을 미치나요?",
+    "인권이란 무엇이며 왜 중요한가요?",
+    "실크로드가 역사적으로 중요한 이유는 무엇인가요?",
+    "르네상스 시대는 어떤 시기였나요?",
+    "한국의 독립운동에 대해 설명해주세요.",
+
+    # Advice - career & study
+    "취업 면접 잘 보는 방법은 무엇인가요?",
+    "이력서를 잘 쓰는 방법을 알려주세요.",
+    "대학 생활을 알차게 보내는 방법은?",
+    "공부 집중력을 높이는 방법을 알려주세요.",
+    "외국어를 빠르게 배우는 방법은 무엇인가요?",
+    "직장에서 상사와 잘 지내는 방법은?",
+    "프리랜서로 일하면 어떤 장단점이 있나요?",
+    "자기소개서를 잘 쓰는 팁을 알려주세요.",
+    "독서 습관을 기르는 방법은 무엇인가요?",
+    "수학을 잘하기 위한 공부 방법은?",
+
+    # Advice - health & psychology
+    "스트레스 해소 방법을 알려주세요.",
+    "우울감을 극복하는 방법은 무엇인가요?",
+    "규칙적인 운동 습관을 만드는 방법은?",
+    "수면의 질을 높이는 방법을 알려주세요.",
+    "번아웃을 예방하는 방법은 무엇인가요?",
+    "마음의 평화를 찾는 방법은?",
+    "자존감을 높이는 방법을 알려주세요.",
+    "명상을 시작하려면 어떻게 해야 하나요?",
+    "건강한 체중을 유지하는 방법은?",
+    "디지털 중독을 극복하는 방법을 알려주세요.",
+
+    # Creative writing
+    "짧은 동화를 하나 만들어주세요.",
+    "봄에 대한 시를 써주세요.",
+    "미래 도시를 배경으로 한 짧은 이야기를 써주세요.",
+    "바다에 관한 짧은 수필을 써주세요.",
+    "고양이를 주인공으로 한 짧은 이야기를 만들어주세요.",
+    "가을 풍경을 묘사하는 글을 써주세요.",
+    "우정에 관한 짧은 시를 써주세요.",
+    "엄마에게 보내는 편지를 써주세요.",
+    "미래의 나에게 쓰는 편지를 작성해주세요.",
+    "어린 시절 추억에 관한 짧은 글을 써주세요.",
+
+    # Tech & IT
+    "클라우드 컴퓨팅이란 무엇인가요?",
+    "블록체인이 무엇인지 설명해주세요.",
+    "사이버 보안이 왜 중요한가요?",
+    "빅데이터란 무엇이며 어떻게 활용되나요?",
+    "5G 기술이 가져올 변화는 무엇인가요?",
+    "인터넷 검색 엔진은 어떻게 작동하나요?",
+    "스마트폰이 생활에 미친 영향은 무엇인가요?",
+    "가상현실과 증강현실의 차이는 무엇인가요?",
+    "자율주행 자동차 기술은 어디까지 왔나요?",
+    "오픈소스 소프트웨어란 무엇인가요?",
+
+    # Culture & arts
+    "K-팝이 세계적으로 인기를 얻은 이유는 무엇인가요?",
+    "한국 영화가 세계 시장에서 주목받는 이유는?",
+    "전통 예술과 현대 예술의 차이는 무엇인가요?",
+    "음악이 감정에 미치는 영향은 무엇인가요?",
+    "독서가 삶에 미치는 긍정적인 영향은?",
+    "미술 감상을 잘 하는 방법을 알려주세요.",
+    "한국 전통 음악인 국악의 특징은 무엇인가요?",
+    "영화 비평을 잘 쓰는 방법은?",
+    "여행이 사람을 성장시키는 이유는 무엇인가요?",
+    "사진 찍기를 잘 하는 팁을 알려주세요.",
+
+    # Environment & society
+    "환경 보호를 위해 개인이 할 수 있는 일은?",
+    "재활용의 중요성과 방법을 설명해주세요.",
+    "채식주의의 장단점은 무엇인가요?",
+    "동물 복지란 무엇이며 왜 중요한가요?",
+    "지속 가능한 발전이란 무엇인가요?",
+    "노령화 사회가 가져오는 문제점은 무엇인가요?",
+    "교육 불평등을 해소하는 방법은?",
+    "빈곤 문제를 해결하기 위한 방법은?",
+    "다문화 사회에서 공존하는 방법은?",
+    "봉사 활동이 사회에 미치는 영향은 무엇인가요?",
+]
+
+ALL_PROMPTS = _EVAL_PROMPTS + _EXTRA_PROMPTS  # 15 + 90 = 105
+
+CHAT_TEMPLATE = "<|user|>\n{prompt}\n<|assistant|>\n"
+
+EOS_TOKEN_ID = 2
+
+
+# ---------------------------------------------------------------------------
+# Repetition metric
+# ---------------------------------------------------------------------------
+
+def compute_ngram_repetition_rate(tokens: List[int], n: int = 3) -> float:
+    """Fraction of n-gram positions that are repeats of an earlier occurrence."""
+    if len(tokens) < n:
+        return 0.0
+    ngrams = [tuple(tokens[i: i + n]) for i in range(len(tokens) - n + 1)]
+    if not ngrams:
+        return 0.0
+    seen: set = set()
+    repeated = 0
+    for ng in ngrams:
+        if ng in seen:
+            repeated += 1
+        seen.add(ng)
+    return repeated / len(ngrams)
+
+
+# ---------------------------------------------------------------------------
+# Generation
+# ---------------------------------------------------------------------------
+
+@torch.inference_mode()
+def generate(
+    model: torch.nn.Module,
+    input_ids: torch.Tensor,
+    max_new_tokens: int,
+    temperature: float,
+    repetition_penalty: float,
+    eos_token_id: int,
+) -> List[int]:
+    """Auto-regressive generation with optional repetition penalty.
+
+    Args:
+        model: LLM instance already on device
+        input_ids: (1, T) prompt token ids
+        max_new_tokens: max tokens to generate
+        temperature: sampling temperature (0 = greedy)
+        repetition_penalty: penalty > 1 reduces prob of previously seen tokens
+        eos_token_id: stop generation when this token is produced
+
+    Returns:
+        List of generated token ids (not including the prompt).
+    """
+    device = input_ids.device
+    generated: List[int] = []
+    current_ids = input_ids.clone()  # (1, T)
+
+    for _ in range(max_new_tokens):
+        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits, _ = model(current_ids)  # (1, T, V)
+
+        next_logits = logits[0, -1, :].float()  # (V,)
+
+        # Repetition penalty: discount logits for already-generated tokens
+        if repetition_penalty != 1.0:
+            all_seen_ids = current_ids[0].tolist() + generated
+            for token_id in set(all_seen_ids):
+                if token_id < next_logits.shape[0]:
+                    if next_logits[token_id] < 0:
+                        next_logits[token_id] *= repetition_penalty
+                    else:
+                        next_logits[token_id] /= repetition_penalty
+
+        # Sample / greedy
+        if temperature == 0.0:
+            next_token = int(next_logits.argmax())
+        else:
+            next_logits = next_logits / temperature
+            probs = F.softmax(next_logits, dim=-1)
+            next_token = int(torch.multinomial(probs, num_samples=1).item())
+
+        generated.append(next_token)
+
+        if next_token == eos_token_id:
+            break
+
+        # Append to context
+        next_tensor = torch.tensor([[next_token]], dtype=torch.long, device=device)
+        current_ids = torch.cat([current_ids, next_tensor], dim=1)
+
+    return generated
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Generate self-play repetition preference data"
+    )
+    parser.add_argument(
+        "--checkpoint",
+        type=Path,
+        default=Path("checkpoints/3b_dpo/checkpoint-slerp"),
+        help="Path to model checkpoint directory",
+    )
+    parser.add_argument(
+        "--output",
+        type=Path,
+        default=Path("data/preference/repetition_preference.jsonl"),
+        help="Output JSONL path",
+    )
+    parser.add_argument(
+        "--num_prompts",
+        type=int,
+        default=None,
+        help="How many prompts to use (default: all ~100)",
+    )
+    parser.add_argument(
+        "--max_tokens",
+        type=int,
+        default=256,
+        help="Max new tokens per generation",
+    )
+    parser.add_argument(
+        "--tokenizer",
+        type=Path,
+        default=None,
+        help="Path to tokenizer.json (default: auto-resolve)",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda:0",
+        help="Torch device string",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed for reproducibility",
+    )
+    parser.add_argument(
+        "--min_rep_diff",
+        type=float,
+        default=0.0,
+        help="Minimum difference (rejected_rep - chosen_rep) to keep a pair (default: >0)",
+    )
+    return parser.parse_args()
+
+
+def _resolve_tokenizer(args: argparse.Namespace) -> Path:
+    if args.tokenizer is not None:
+        return Path(args.tokenizer)
+    # Try checkpoint dir first
+    ckpt_tok = args.checkpoint / "tokenizer.json"
+    if ckpt_tok.exists():
+        return ckpt_tok
+    # Fall back to project default
+    default_tok = _PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json"
+    if default_tok.exists():
+        return default_tok
+    raise FileNotFoundError(
+        "Cannot find tokenizer.json — specify with --tokenizer"
+    )
+
+
+def main() -> None:
+    args = parse_args()
+
+    # Reproducibility
+    torch.manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+
+    # Prompts
+    prompts = ALL_PROMPTS
+    if args.num_prompts is not None:
+        prompts = prompts[: args.num_prompts]
+    print(f"[INFO] Using {len(prompts)} prompts")
+
+    # Device
+    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+    print(f"[INFO] Device: {device}")
+
+    # Tokenizer
+    tokenizer_path = _resolve_tokenizer(args)
+    print(f"[INFO] Loading tokenizer from {tokenizer_path}")
+    tokenizer = Tokenizer.from_file(str(tokenizer_path))
+
+    # Model
+    checkpoint_path = _PROJECT_ROOT / args.checkpoint if not args.checkpoint.is_absolute() else args.checkpoint
+    if not checkpoint_path.exists():
+        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+    print(f"[INFO] Loading model from {checkpoint_path} ...")
+    t0 = time.perf_counter()
+    model = LLM.from_pretrained(checkpoint_path)
+    model = model.to(device=device, dtype=torch.bfloat16)
+    model.eval()
+    print(f"[INFO] Model loaded in {time.perf_counter() - t0:.1f}s")
+
+    # Output dir
+    output_path = _PROJECT_ROOT / args.output if not args.output.is_absolute() else args.output
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    # Stats
+    valid_pairs = 0
+    skipped = 0
+    total_rejected_rep = 0.0
+    total_chosen_rep = 0.0
+
+    t_start = time.perf_counter()
+
+    with open(output_path, "w", encoding="utf-8") as fout:
+        for idx, prompt_text in enumerate(prompts):
+            prompt_str = CHAT_TEMPLATE.format(prompt=prompt_text)
+
+            # Tokenize prompt
+            encoding = tokenizer.encode(prompt_str)
+            prompt_ids = encoding.ids
+            if not prompt_ids:
+                print(f"  [{idx+1}/{len(prompts)}] SKIP: empty tokenization for prompt")
+                skipped += 1
+                continue
+
+            input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
+
+            # --- Generate REJECTED: greedy, no rep penalty ---
+            rej_tokens = generate(
+                model=model,
+                input_ids=input_ids,
+                max_new_tokens=args.max_tokens,
+                temperature=0.0,
+                repetition_penalty=1.0,
+                eos_token_id=EOS_TOKEN_ID,
+            )
+
+            # --- Generate CHOSEN: sampling + rep penalty ---
+            cho_tokens = generate(
+                model=model,
+                input_ids=input_ids,
+                max_new_tokens=args.max_tokens,
+                temperature=0.7,
+                repetition_penalty=1.2,
+                eos_token_id=EOS_TOKEN_ID,
+            )
+
+            # Decode (strip EOS)
+            rej_clean = [t for t in rej_tokens if t != EOS_TOKEN_ID]
+            cho_clean = [t for t in cho_tokens if t != EOS_TOKEN_ID]
+
+            rej_text = tokenizer.decode(rej_clean)
+            cho_text = tokenizer.decode(cho_clean)
+
+            # Compute 3-gram repetition rates on generated tokens
+            rej_rep = compute_ngram_repetition_rate(rej_clean, n=3)
+            cho_rep = compute_ngram_repetition_rate(cho_clean, n=3)
+
+            # Filter: only keep if rejected is more repetitive than chosen
+            diff = rej_rep - cho_rep
+            if diff <= args.min_rep_diff:
+                status = "SKIP"
+                skipped += 1
+            else:
+                status = "KEEP"
+                valid_pairs += 1
+                total_rejected_rep += rej_rep
+                total_chosen_rep += cho_rep
+                record = {
+                    "prompt": prompt_str,
+                    "chosen": cho_text,
+                    "rejected": rej_text,
+                }
+                fout.write(json.dumps(record, ensure_ascii=False) + "\n")
+
+            elapsed = time.perf_counter() - t_start
+            print(
+                f"  [{idx+1:3d}/{len(prompts)}] {status:4s} "
+                f"rej_rep={rej_rep:.3f} cho_rep={cho_rep:.3f} diff={diff:+.3f} "
+                f"| rej_len={len(rej_clean)} cho_len={len(cho_clean)} "
+                f"| elapsed={elapsed:.1f}s"
+            )
+
+    # Summary
+    elapsed_total = time.perf_counter() - t_start
+    print()
+    print("=" * 60)
+    print(f"Generation complete in {elapsed_total:.1f}s")
+    print(f"  Total prompts processed : {len(prompts)}")
+    print(f"  Valid pairs kept        : {valid_pairs}")
+    print(f"  Skipped (rep filter)    : {skipped}")
+    if valid_pairs > 0:
+        avg_rej = total_rejected_rep / valid_pairs
+        avg_cho = total_chosen_rep / valid_pairs
+        print(f"  Avg rejected 3-gram rep : {avg_rej:.4f}")
+        print(f"  Avg chosen 3-gram rep   : {avg_cho:.4f}")
+        print(f"  Avg improvement         : {avg_rej - avg_cho:+.4f}")
+    print(f"  Output saved to         : {output_path}")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
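As a quick sanity check of the 3-gram metric this script filters on, here is a small standalone sketch that mirrors the logic of `compute_ngram_repetition_rate` (the function name here is just for illustration; it is not part of the commit):

```python
from typing import List


def ngram_repetition_rate(tokens: List[int], n: int = 3) -> float:
    # Fraction of n-gram positions that repeat an earlier occurrence,
    # same counting rule as compute_ngram_repetition_rate above.
    if len(tokens) < n:
        return 0.0
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    seen: set = set()
    repeated = 0
    for ng in ngrams:
        if ng in seen:
            repeated += 1
        seen.add(ng)
    return repeated / len(ngrams)


print(ngram_repetition_rate([1, 2, 3, 4, 5, 6, 7]))  # all 3-grams unique → 0.0
print(ngram_repetition_rate([1, 2, 3] * 3))          # heavy looping → 4/7 ≈ 0.571
print(ngram_repetition_rate([1, 2]))                 # shorter than n → 0.0
```

A fully looping sequence thus scores well above a varied one, which is exactly the gap the chosen/rejected filter thresholds on.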
scripts/lora.py ADDED
@@ -0,0 +1,240 @@
+"""
+model/lora.py — LoRA (Low-Rank Adaptation) for EVAFRILL-Mo hybrid models.
+
+Injects trainable low-rank adapters into:
+  - Attention layers: qkv_proj, out_proj
+  - Mamba-2 layers: in_proj, out_proj
+
+Usage:
+    model = LLM.from_pretrained(checkpoint)
+    apply_lora(model, rank=32, alpha=64)
+    # Only LoRA params are trainable; base model is frozen
+
+    # After training, merge LoRA weights back:
+    merge_lora(model)
+
+    # Or save/load LoRA weights separately:
+    save_lora(model, path)
+    load_lora(model, path)
+"""
+
+from __future__ import annotations
+
+import math
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .attention import MultiHeadAttention
+from .mamba_block import Mamba2Block
+
+
+class LoRALinear(nn.Module):
+    """LoRA adapter wrapping an existing nn.Linear layer.
+
+    Computes: output = original_linear(x) + (alpha/rank) * x @ A^T @ B^T
+    where A: (rank, in_features), B: (out_features, rank)
+    """
+
+    def __init__(
+        self,
+        original: nn.Linear,
+        rank: int = 32,
+        alpha: float = 64.0,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.original = original
+        self.rank = rank
+        self.alpha = alpha
+        self.scaling = alpha / rank
+
+        in_features = original.in_features
+        out_features = original.out_features
+
+        # A: down-projection (in_features → rank)
+        # Create on same device/dtype as original weights
+        _dev = original.weight.device
+        _dt = original.weight.dtype
+        self.lora_A = nn.Parameter(torch.empty(rank, in_features, device=_dev, dtype=_dt))
+        # B: up-projection (rank → out_features)
+        self.lora_B = nn.Parameter(torch.zeros(out_features, rank, device=_dev, dtype=_dt))
+
+        # Initialize A with kaiming uniform, B with zeros
+        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+        # B is already zeros → initial LoRA output is zero
+
+        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+
+        # Freeze original weights
+        original.weight.requires_grad = False
+        if original.bias is not None:
+            original.bias.requires_grad = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Original forward
+        result = self.original(x)
+        # LoRA path: x → dropout → A → B → scale
+        lora_out = self.dropout(x)
+        lora_out = F.linear(lora_out, self.lora_A)  # (..., rank)
+        lora_out = F.linear(lora_out, self.lora_B)  # (..., out_features)
+        return result + lora_out * self.scaling
+
+    def merge_weights(self) -> None:
+        """Merge LoRA weights into the original linear layer permanently."""
+        with torch.no_grad():
+            # W' = W + scaling * B @ A
+            self.original.weight.add_(
+                (self.lora_B @ self.lora_A) * self.scaling
+            )
+
+    @property
+    def weight(self) -> torch.Tensor:
+        """Access original weight for compatibility."""
+        return self.original.weight
+
+    @property
+    def bias(self) -> Optional[torch.Tensor]:
+        return self.original.bias
+
+
+def apply_lora(
+    model: nn.Module,
+    rank: int = 32,
+    alpha: float = 64.0,
+    dropout: float = 0.0,
+    target_modules: Optional[list[str]] = None,
+) -> int:
+    """Apply LoRA adapters to a model, freeze base weights.
+
+    Args:
+        model: The LLM model (raw, not DDP-wrapped).
+        rank: LoRA rank (default 32).
+        alpha: LoRA scaling factor (default 64).
+        dropout: Dropout on LoRA path (default 0).
+        target_modules: List of module attribute names to adapt.
+            Default: ["qkv_proj", "out_proj", "in_proj"]
+            (covers both Attention and Mamba layers).
+
+    Returns:
+        Number of LoRA parameters added.
+    """
+    if target_modules is None:
+        target_modules = ["qkv_proj", "out_proj", "in_proj"]
+
+    # First, freeze ALL parameters
+    for param in model.parameters():
+        param.requires_grad = False
+
+    lora_count = 0
+    total_lora_params = 0
+
+    for name, module in model.named_modules():
+        # Attention and Mamba-2 layers get identical handling
+        if isinstance(module, (MultiHeadAttention, Mamba2Block)):
+            for attr in target_modules:
+                if hasattr(module, attr):
+                    original = getattr(module, attr)
+                    if isinstance(original, nn.Linear):
+                        lora_layer = LoRALinear(original, rank=rank, alpha=alpha, dropout=dropout)
+                        setattr(module, attr, lora_layer)
+                        params = rank * original.in_features + original.out_features * rank
+                        total_lora_params += params
+                        lora_count += 1
+
+    print(f"[LoRA] Applied {lora_count} adapters, {total_lora_params:,} trainable params "
+          f"(rank={rank}, alpha={alpha})")
+    return total_lora_params
+
+
+def merge_lora(model: nn.Module) -> int:
+    """Merge all LoRA weights back into base model and remove LoRA layers.
+
+    Returns:
+        Number of LoRA layers merged.
+    """
+    merged = 0
+    for name, module in model.named_modules():
+        if isinstance(module, (MultiHeadAttention, Mamba2Block)):
+            for attr in ["qkv_proj", "out_proj", "in_proj"]:
+                if hasattr(module, attr):
+                    layer = getattr(module, attr)
+                    if isinstance(layer, LoRALinear):
+                        layer.merge_weights()
+                        setattr(module, attr, layer.original)
+                        merged += 1
+
+    # Unfreeze all parameters after merging
+    for param in model.parameters():
+        param.requires_grad = True
+
+    print(f"[LoRA] Merged {merged} adapters back into base model")
+    return merged
+
+
+def get_lora_params(model: nn.Module) -> list[nn.Parameter]:
+    """Get all LoRA trainable parameters."""
+    params = []
+    for module in model.modules():
+        if isinstance(module, LoRALinear):
+            params.append(module.lora_A)
+            params.append(module.lora_B)
+    return params
+
+
+def save_lora(model: nn.Module, path: str | Path) -> Path:
+    """Save only the LoRA adapter weights."""
+    path = Path(path)
+    path.mkdir(parents=True, exist_ok=True)
+
+    lora_state = {}
+    for name, module in model.named_modules():
+        if isinstance(module, LoRALinear):
+            lora_state[f"{name}.lora_A"] = module.lora_A.data.cpu()
+            lora_state[f"{name}.lora_B"] = module.lora_B.data.cpu()
+
+    save_path = path / "lora_weights.pt"
+    torch.save(lora_state, save_path)
+    n_params = sum(v.numel() for v in lora_state.values())
+    size_mb = save_path.stat().st_size / 1e6
+    print(f"[LoRA] Saved {len(lora_state)} tensors ({n_params:,} params, {size_mb:.1f} MB) → {save_path}")
+    return save_path
+
+
+def load_lora(model: nn.Module, path: str | Path) -> int:
+    """Load LoRA adapter weights. LoRA layers must already be applied."""
+    path = Path(path)
+    lora_file = path / "lora_weights.pt" if path.is_dir() else path
+    lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)
+
+    loaded = 0
+    for name, module in model.named_modules():
+        if isinstance(module, LoRALinear):
+            a_key = f"{name}.lora_A"
+            b_key = f"{name}.lora_B"
+            if a_key in lora_state and b_key in lora_state:
+                module.lora_A.data.copy_(lora_state[a_key])
+                module.lora_B.data.copy_(lora_state[b_key])
+                loaded += 1
+
+    print(f"[LoRA] Loaded {loaded} adapter weight pairs from {lora_file}")
+    return loaded
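The correctness of `merge_weights()` above rests on the identity that folding `scaling * B @ A` into the base weight reproduces the adapter forward pass exactly. A minimal numpy sketch of that identity (illustrative names and shapes; not part of the commit, and using plain matrices in place of `nn.Linear`):

```python
import numpy as np

rng = np.random.default_rng(0)
rank, alpha = 4, 8.0
scaling = alpha / rank            # same scaling rule as LoRALinear
in_f, out_f = 16, 12

W = rng.normal(size=(out_f, in_f))   # frozen base weight
A = rng.normal(size=(rank, in_f))    # lora_A (down-projection)
B = rng.normal(size=(out_f, rank))   # lora_B (up-projection)
x = rng.normal(size=(3, in_f))       # a small batch of activations

# Adapter forward: base path plus scaled low-rank path
y_adapter = x @ W.T + scaling * (x @ A.T) @ B.T
# Merged forward: W' = W + scaling * B @ A (what merge_weights() folds in)
y_merged = x @ (W + scaling * (B @ A)).T

print(np.allclose(y_adapter, y_merged))
```

Because the two paths are linear in `x`, the merge is exact up to floating-point rounding, so inference after `merge_lora` needs no extra matmuls.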
scripts/merge_checkpoints.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ scripts/merge_checkpoints.py — Slerp (Spherical Linear Interpolation) checkpoint merge.
4
+
5
+ Merges two model checkpoints (e.g., SFT + DPO) using SLERP interpolation
6
+ to balance knowledge retention and alignment improvement.
7
+
8
+ Reference: Nemotron-H paper — SLERP merging reduces alignment tax.
9
+
10
+ Usage:
11
+ python scripts/merge_checkpoints.py \
12
+ --ckpt_a checkpoints/3b_sft_v2/checkpoint-best \
13
+ --ckpt_b checkpoints/3b_dpo/checkpoint-merged \
14
+ --output checkpoints/3b_dpo/checkpoint-slerp \
15
+ --alpha 0.5
16
+
17
+ alpha=0.0 → pure ckpt_a (SFT)
18
+ alpha=1.0 → pure ckpt_b (DPO)
19
+ alpha=0.5 → equal blend (recommended starting point)
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import argparse
25
+ import math
26
+ import shutil
27
+ from pathlib import Path
28
+
29
+ import torch
30
+ import yaml
31
+
32
+
33
+ def slerp(t: float, v0: torch.Tensor, v1: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
34
+ """Spherical linear interpolation between two tensors.
35
+
36
+ Args:
37
+ t: Interpolation factor in [0, 1]. 0 → v0, 1 → v1.
38
+ v0: First tensor (flattened internally).
39
+ v1: Second tensor (same shape as v0).
40
+ eps: Small value to avoid division by zero.
41
+
42
+ Returns:
43
+ Interpolated tensor with the same shape as v0.
44
+ """
45
+ original_shape = v0.shape
46
+ v0_flat = v0.flatten().float()
47
+ v1_flat = v1.flatten().float()
48
+
49
+ # Normalize
50
+ v0_norm = v0_flat / (v0_flat.norm() + eps)
51
+ v1_norm = v1_flat / (v1_flat.norm() + eps)
52
+
53
+ # Cosine of angle between vectors
54
+ cos_omega = torch.dot(v0_norm, v1_norm).clamp(-1.0, 1.0)
55
+
56
+ # If vectors are very similar, fall back to linear interpolation
57
+ if abs(cos_omega.item()) > 0.9995:
58
+ result = (1.0 - t) * v0_flat + t * v1_flat
59
+ return result.reshape(original_shape).to(v0.dtype)
60
+
61
+ omega = torch.acos(cos_omega)
62
+ sin_omega = torch.sin(omega)
63
+
64
+ s0 = torch.sin((1.0 - t) * omega) / sin_omega
65
+ s1 = torch.sin(t * omega) / sin_omega
66
+
67
+ # Interpolate using original (non-normalized) vectors to preserve scale
68
+ result = s0 * v0_flat + s1 * v1_flat
69
+ return result.reshape(original_shape).to(v0.dtype)
70
+
71
+
72
+ def lerp(t: float, v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
73
+ """Simple linear interpolation."""
74
+ return ((1.0 - t) * v0.float() + t * v1.float()).to(v0.dtype)
75
+
76
+
77
+ def merge_state_dicts(
78
+ sd_a: dict[str, torch.Tensor],
79
+ sd_b: dict[str, torch.Tensor],
80
+ alpha: float = 0.5,
81
+ method: str = "slerp",
82
+ ) -> dict[str, torch.Tensor]:
83
+ """Merge two state dicts using SLERP or LERP.
84
+
85
+ Args:
86
+ sd_a: State dict A (e.g., SFT model).
87
+ sd_b: State dict B (e.g., DPO model).
88
+ alpha: Interpolation factor. 0 → A, 1 → B.
89
+ method: "slerp" or "lerp".
90
+
91
+ Returns:
92
+ Merged state dict.
93
+ """
94
+ interp_fn = slerp if method == "slerp" else lerp
95
+
96
+ merged = {}
97
+ keys_a = set(sd_a.keys())
98
+ keys_b = set(sd_b.keys())
99
+
100
+ common = keys_a & keys_b
101
+ only_a = keys_a - keys_b
102
+ only_b = keys_b - keys_a
103
+
104
+ if only_a:
105
+ print(f"[WARN] {len(only_a)} keys only in ckpt_a (kept as-is)")
106
+ if only_b:
107
+ print(f"[WARN] {len(only_b)} keys only in ckpt_b (kept as-is)")
108
+
109
+ for key in sorted(common):
110
+ va = sd_a[key]
111
+ vb = sd_b[key]
112
+
113
+ if va.shape != vb.shape:
114
+ print(f"[WARN] Shape mismatch for {key}: {va.shape} vs {vb.shape}, keeping ckpt_a")
115
+ merged[key] = va
116
+ continue
117
+
118
+ # Only interpolate float parameters (skip int buffers, etc.)
119
+ if va.is_floating_point() and va.numel() > 1:
120
+ merged[key] = interp_fn(alpha, va, vb)
121
+ else:
122
+ merged[key] = va # Keep from ckpt_a for non-float/scalar
123
+
124
+ # Include keys unique to each
125
+ for key in only_a:
126
+ merged[key] = sd_a[key]
127
+ for key in only_b:
128
+ merged[key] = sd_b[key]
129
+
130
+ return merged
131
+
132
+
133
+ def main():
134
+     parser = argparse.ArgumentParser(description="SLERP checkpoint merge")
+     parser.add_argument("--ckpt_a", type=Path, required=True,
+                         help="Path to checkpoint A (e.g., SFT)")
+     parser.add_argument("--ckpt_b", type=Path, required=True,
+                         help="Path to checkpoint B (e.g., DPO)")
+     parser.add_argument("--output", type=Path, required=True,
+                         help="Output checkpoint directory")
+     parser.add_argument("--alpha", type=float, default=0.5,
+                         help="Interpolation factor (0=A, 1=B, default 0.5)")
+     parser.add_argument("--method", choices=["slerp", "lerp"], default="slerp",
+                         help="Interpolation method (default: slerp)")
+     args = parser.parse_args()
+
+     print(f"Merge: {args.ckpt_a.name} ←({1-args.alpha:.1%})— ({args.alpha:.1%})→ {args.ckpt_b.name}")
+     print(f"Method: {args.method}, alpha={args.alpha}")
+
+     # Load state dicts
+     print("Loading checkpoint A...")
+     sd_a = torch.load(args.ckpt_a / "model.pt", map_location="cpu", weights_only=True)
+     print(f"  {len(sd_a)} keys loaded")
+
+     print("Loading checkpoint B...")
+     sd_b = torch.load(args.ckpt_b / "model.pt", map_location="cpu", weights_only=True)
+     print(f"  {len(sd_b)} keys loaded")
+
+     # Merge
+     print("Merging...")
+     merged_sd = merge_state_dicts(sd_a, sd_b, alpha=args.alpha, method=args.method)
+     print(f"  {len(merged_sd)} keys in merged state dict")
+
+     # Save
+     args.output.mkdir(parents=True, exist_ok=True)
+     torch.save(merged_sd, args.output / "model.pt")
+
+     # Copy config from ckpt_a
+     config_src = args.ckpt_a / "config.yaml"
+     if config_src.exists():
+         shutil.copy2(str(config_src), str(args.output / "config.yaml"))
+
+     # Copy tokenizer if available
+     for tok_name in ["tokenizer.json", "tokenizer.model"]:
+         tok_src = args.ckpt_a / tok_name
+         if tok_src.exists():
+             shutil.copy2(str(tok_src), str(args.output / tok_name))
+
+     # Write merge metadata
+     meta = {
+         "ckpt_a": str(args.ckpt_a),
+         "ckpt_b": str(args.ckpt_b),
+         "alpha": args.alpha,
+         "method": args.method,
+     }
+     with open(args.output / "merge_info.yaml", "w") as f:
+         yaml.safe_dump(meta, f)
+
+     size_mb = (args.output / "model.pt").stat().st_size / 1e6
+     print(f"\nMerged checkpoint saved → {args.output} ({size_mb:.0f} MB)")
+
+
+ if __name__ == "__main__":
+     main()
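The `merge_state_dicts` helper called above is defined earlier in the file and is not shown in this excerpt. For readers who want the core idea, here is a minimal SLERP sketch on flat Python lists — a hypothetical illustration of the technique, not the repository's implementation, which operates per-tensor on state dicts:

```python
import math

def slerp(a, b, alpha, eps=1e-7):
    # Spherical linear interpolation between two flat weight vectors.
    # Falls back to plain lerp when the vectors are (near-)colinear,
    # where sin(omega) -> 0 makes the spherical formula unstable.
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    dot = sum(x * y for x, y in zip(a, b)) / (na * nb)
    dot = max(-1.0, min(1.0, dot))        # guard acos against rounding
    omega = math.acos(dot)
    if abs(math.sin(omega)) < eps:
        return [(1 - alpha) * x + alpha * y for x, y in zip(a, b)]
    s = math.sin(omega)
    w_a = math.sin((1 - alpha) * omega) / s
    w_b = math.sin(alpha * omega) / s
    return [w_a * x + w_b * y for x, y in zip(a, b)]
```

Unlike lerp, SLERP preserves the norm when interpolating between unit-norm directions, which is why it is the default `--method` above.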
scripts/orpo_native.py ADDED
@@ -0,0 +1,636 @@
+ """
+ scripts/orpo_native.py — ORPO (Odds Ratio Preference Optimization) training.
+
+ Native ORPO implementation (no TRL, no HuggingFace Trainer) for EVAFRILL-Mo
+ hybrid Mamba-2+Transformer models. Unlike DPO, ORPO requires NO reference model
+ and performs SFT + alignment in a single training stage, making it ideal for
+ starting from a raw pretrained checkpoint.
+
+ Reference: Hong et al., "ORPO: Monolithic Preference Optimization without
+ Reference Model" (2024), https://arxiv.org/abs/2403.07691
+
+ Loss:
+     L_ORPO = L_SFT + λ * L_OR
+     L_SFT  = CrossEntropy(chosen_logits, chosen_labels)
+     L_OR   = -E[log σ(log(odds_chosen / odds_rejected))]
+     odds(x) = P(x) / (1 - P(x)),  P(x) = exp(avg_log_prob(x))
+
+ Launch:
+     python scripts/orpo_native.py \
+         --pretrained_checkpoint checkpoints/3b_final/checkpoint-0319772 \
+         --preference_data data/preference/combined_preference.jsonl \
+         --config configs/h100_mig/dpo_3b_1gpu.yaml \
+         --device cuda:0
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import datetime
+ import os
+ import random
+ import signal
+ import shutil
+ import sys
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader, RandomSampler
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.set_float32_matmul_precision("high")
+
+ _PROJECT_ROOT = Path(__file__).resolve().parent.parent
+ if str(_PROJECT_ROOT) not in sys.path:
+     sys.path.insert(0, str(_PROJECT_ROOT))
+
+ from model import LLM
+ from model.lora import apply_lora, get_lora_params, merge_lora, save_lora
+ from data.dpo_dataset import DPODataset, dpo_collate_fn
+ from train.utils import (
+     get_cosine_schedule_with_warmup,
+     is_main_process,
+     save_checkpoint,
+     load_checkpoint,
+ )
+
+
+ # ---------------------------------------------------------------------------
+ # Argument parsing
+ # ---------------------------------------------------------------------------
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="ORPO Training for EVAFRILL-Mo")
+
+     # Paths
+     parser.add_argument("--pretrained_checkpoint", type=Path, required=True,
+                         help="Path to pretrained model checkpoint directory "
+                              "(e.g. checkpoints/3b_final/checkpoint-0319772)")
+     parser.add_argument("--preference_data", type=Path, required=True,
+                         help="Path to preference JSONL data (prompt/chosen/rejected)")
+     parser.add_argument("--checkpoint_dir", type=Path, default=Path("checkpoints/3b_orpo"),
+                         help="Output checkpoint directory (default: checkpoints/3b_orpo)")
+     parser.add_argument("--resume", type=Path, default=None,
+                         help="Resume training from an existing ORPO checkpoint directory")
+     parser.add_argument("--tokenizer", type=Path, default=None,
+                         help="Path to tokenizer.json (auto-detected if omitted)")
+     parser.add_argument("--log_file", type=Path, default=None,
+                         help="Append logs to this file in addition to stdout")
+     parser.add_argument("--config", type=Path, default=None,
+                         help="YAML config to load defaults from (train: section)")
+
+     # ORPO hyperparameters
+     parser.add_argument("--lambda_or", type=float, default=1.0,
+                         help="ORPO odds-ratio loss weight λ (default: 1.0)")
+     parser.add_argument("--max_steps", type=int, default=3000,
+                         help="Total optimisation steps (default: 3000)")
+     parser.add_argument("--batch_size", type=int, default=1,
+                         help="Per-step micro-batch size (default: 1)")
+     parser.add_argument("--grad_accum", type=int, default=16,
+                         help="Gradient accumulation steps (default: 16)")
+     parser.add_argument("--lr", type=float, default=5e-6,
+                         help="Peak learning rate (default: 5e-6; higher than DPO because "
+                              "ORPO starts from pretrained, not SFT)")
+     parser.add_argument("--weight_decay", type=float, default=0.01)
+     parser.add_argument("--warmup_steps", type=int, default=100)
+     parser.add_argument("--max_length", type=int, default=1024)
+     parser.add_argument("--seed", type=int, default=42)
+
+     # LoRA
+     parser.add_argument("--use_lora", action="store_true", default=True,
+                         help="Use LoRA adapters for memory-efficient training (default: on)")
+     parser.add_argument("--lora_rank", type=int, default=32)
+     parser.add_argument("--lora_alpha", type=float, default=64.0)
+
+     # Infrastructure
+     parser.add_argument("--device", type=str, default=None,
+                         help="Device string, e.g. cuda:0 (auto-detected if omitted)")
+     parser.add_argument("--save_interval", type=int, default=500)
+     parser.add_argument("--log_interval", type=int, default=10)
+     parser.add_argument("--num_workers", type=int, default=4)
+
+     args, _ = parser.parse_known_args()
+
+     # Load YAML config and apply as defaults (CLI flags override YAML)
+     if args.config is not None:
+         if not args.config.exists():
+             raise FileNotFoundError(f"Config not found: {args.config}")
+         import yaml
+         with open(args.config) as f:
+             cfg = yaml.safe_load(f)
+         train_cfg = cfg.get("train", {})
+         yaml_map = {
+             "max_steps": "max_steps",
+             "batch_size": "batch_size",
+             "grad_accum_steps": "grad_accum",
+             "lr": "lr",
+             "weight_decay": "weight_decay",
+             "warmup_steps": "warmup_steps",
+             "lambda_or": "lambda_or",
+             "max_length": "max_length",
+             "save_interval": "save_interval",
+             "log_interval": "log_interval",
+             "use_lora": "use_lora",
+             "lora_rank": "lora_rank",
+             "lora_alpha": "lora_alpha",
+         }
+         defaults: dict = {}
+         for yk, ak in yaml_map.items():
+             if yk in train_cfg:
+                 defaults[ak] = train_cfg[yk]
+         if defaults:
+             parser.set_defaults(**defaults)
+
+     return parser.parse_args()
+
+
+ # ---------------------------------------------------------------------------
+ # Utilities
+ # ---------------------------------------------------------------------------
+
+ def set_seed(seed: int) -> None:
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+
+ def _resolve_tokenizer_path(args: argparse.Namespace) -> Path:
+     """Find tokenizer.json: explicit flag > checkpoint dir > project default."""
+     if args.tokenizer is not None:
+         return Path(args.tokenizer)
+     ckpt_tok = args.pretrained_checkpoint / "tokenizer.json"
+     if ckpt_tok.exists():
+         return ckpt_tok
+     default_tok = _PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json"
+     if default_tok.exists():
+         return default_tok
+     raise FileNotFoundError(
+         "Cannot find tokenizer.json. Provide --tokenizer or place it in the checkpoint dir."
+     )
+
+
+ # ---------------------------------------------------------------------------
+ # ORPO loss
+ # ---------------------------------------------------------------------------
+
+ def get_avg_log_prob(
+     logits: torch.Tensor,
+     labels: torch.Tensor,
+ ) -> torch.Tensor:
+     """Compute average log probability over non-masked (response) tokens.
+
+     Args:
+         logits: (B, T, V) raw model logits — already in float32.
+         labels: (B, T) token ids; -1 marks prompt/padding positions to ignore.
+
+     Returns:
+         (B,) mean log probability over response tokens per sample.
+         Returns 0 for samples where no response token is present (shouldn't
+         happen with well-formed data, but guarded for safety).
+     """
+     log_probs = F.log_softmax(logits.float(), dim=-1)  # (B, T, V)
+
+     mask = labels != -1                # (B, T) True = response token
+     safe_labels = labels.clamp(min=0)  # replace -1 with 0 for gather
+     per_token_logps = log_probs.gather(
+         -1, safe_labels.unsqueeze(-1)
+     ).squeeze(-1)                      # (B, T)
+
+     # Zero out masked positions
+     per_token_logps = per_token_logps * mask.float()  # (B, T)
+
+     # Average over response tokens; clamp denominator to avoid div-by-zero
+     n_tokens = mask.float().sum(dim=-1).clamp(min=1.0)  # (B,)
+     return per_token_logps.sum(dim=-1) / n_tokens       # (B,)
+
+
+ def compute_orpo_loss(
+     model: nn.Module,
+     chosen_ids: torch.Tensor,
+     chosen_labels: torch.Tensor,
+     rejected_ids: torch.Tensor,
+     rejected_labels: torch.Tensor,
+     lambda_or: float = 1.0,
+     vocab_size: int | None = None,
+ ) -> tuple[torch.Tensor, float, float]:
+     """Compute ORPO loss = SFT loss + λ * OR loss.
+
+     No reference model is needed. The SFT loss trains the model to generate
+     chosen responses; the OR loss simultaneously teaches the model to prefer
+     chosen over rejected by maximising the log odds ratio.
+
+     Args:
+         model: The policy model (frozen base + trainable LoRA).
+         chosen_ids: (B, T) token ids for chosen sequences.
+         chosen_labels: (B, T) labels for chosen; -1 on prompt tokens.
+         rejected_ids: (B, T) token ids for rejected sequences.
+         rejected_labels: (B, T) labels for rejected; -1 on prompt tokens.
+         lambda_or: Weight of the OR loss term (paper default = 1.0).
+         vocab_size: Vocabulary size for reshape; inferred from logits if None.
+
+     Returns:
+         (total_loss, sft_loss_scalar, or_loss_scalar)
+     """
+     # -----------------------------------------------------------------------
+     # 1. Forward pass — chosen
+     # -----------------------------------------------------------------------
+     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+         chosen_logits, _ = model(chosen_ids)  # (B, T, V)
+
+     # Infer vocab size from logits if not given
+     V = chosen_logits.size(-1) if vocab_size is None else vocab_size
+
+     # SFT loss: next-token prediction on response positions only.
+     # logits[:, :-1] predicts labels[:, 1:] (standard causal shift).
+     sft_logits = chosen_logits[:, :-1].contiguous().reshape(-1, V).float()
+     sft_targets = chosen_labels[:, 1:].contiguous().reshape(-1)
+
+     # F.cross_entropy ignores index -1 via ignore_index; -1 covers prompt tokens
+     # AND the last padding position shifted out of the window.
+     sft_loss: torch.Tensor = F.cross_entropy(sft_logits, sft_targets, ignore_index=-1)
+
+     # Average log-prob over response tokens (used for OR computation)
+     # Labels are NOT shifted here — get_avg_log_prob handles the alignment
+     # by using labels directly as targets at each position.
+     chosen_avg_logp: torch.Tensor = get_avg_log_prob(chosen_logits.float(), chosen_labels)
+
+     # -----------------------------------------------------------------------
+     # 2. Forward pass — rejected
+     # -----------------------------------------------------------------------
+     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+         rejected_logits, _ = model(rejected_ids)  # (B, T, V)
+
+     rejected_avg_logp: torch.Tensor = get_avg_log_prob(rejected_logits.float(), rejected_labels)
+
+     # -----------------------------------------------------------------------
+     # 3. Odds ratio loss
+     #
+     # odds(x) = P(x) / (1 - P(x))
+     # log odds = log P(x) - log(1 - P(x)) = log P(x) - log1p(-exp(log P(x)))
+     #
+     # We use log1p(-exp(·)) with clamping to keep values numerically stable:
+     #   - avg_log_prob is always ≤ 0
+     #   - exp(avg_log_prob) ∈ (0, 1] → 1 - exp ∈ [0, 1)
+     #   - clamp to avoid log(0) when avg_log_prob ≈ 0 (very high confidence)
+     # -----------------------------------------------------------------------
+     # Clamp to (-33, -1e-6): upper bound avoids 1-exp≈0 → log(0); lower keeps
+     # values finite (exp(-33) ≈ 5e-15, no underflow in float32).
+     eps_low, eps_high = -33.0, -1e-6
+
+     chosen_avg_logp_clamped = chosen_avg_logp.clamp(eps_low, eps_high)
+     rejected_avg_logp_clamped = rejected_avg_logp.clamp(eps_low, eps_high)
+
+     log_odds_chosen = chosen_avg_logp_clamped - torch.log1p(-chosen_avg_logp_clamped.exp())
+     log_odds_rejected = rejected_avg_logp_clamped - torch.log1p(-rejected_avg_logp_clamped.exp())
+
+     log_odds_ratio = log_odds_chosen - log_odds_rejected  # (B,)
+     or_loss: torch.Tensor = -F.logsigmoid(log_odds_ratio).mean()
+
+     # -----------------------------------------------------------------------
+     # 4. Combined loss
+     # -----------------------------------------------------------------------
+     total_loss = sft_loss + lambda_or * or_loss
+
+     return total_loss, sft_loss.item(), or_loss.item()
+
+
+ # ---------------------------------------------------------------------------
+ # Main training loop
+ # ---------------------------------------------------------------------------
+
+ def main() -> None:
+     args = parse_args()
+     set_seed(args.seed)
+
+     # ------------------------------------------------------------------
+     # Device
+     # ------------------------------------------------------------------
+     if args.device:
+         device = torch.device(args.device)
+     elif torch.cuda.is_available():
+         device = torch.device("cuda:0")
+     else:
+         device = torch.device("cpu")
+
+     # ------------------------------------------------------------------
+     # Load pretrained model
+     # ------------------------------------------------------------------
+     if not args.pretrained_checkpoint.exists():
+         raise FileNotFoundError(
+             f"Pretrained checkpoint not found: {args.pretrained_checkpoint}"
+         )
+
+     print(f"Loading pretrained model from {args.pretrained_checkpoint} ...")
+     model: nn.Module = LLM.from_pretrained(args.pretrained_checkpoint)
+     model.config.use_fp8 = False  # H100 MIG: BF16 only; B200 may set fp8 via config
+     model = model.to(device=device, dtype=torch.bfloat16)
+
+     # Gradient checkpointing — reduces VRAM at cost of ~20% speed
+     if hasattr(model, "gradient_checkpointing_enable"):
+         model.gradient_checkpointing_enable()
+         print("[INFO] Gradient checkpointing enabled")
+
+     # ------------------------------------------------------------------
+     # LoRA
+     # ------------------------------------------------------------------
+     if args.use_lora:
+         n_lora = apply_lora(model, rank=args.lora_rank, alpha=args.lora_alpha)
+         lora_params = get_lora_params(model)
+         print(f"[INFO] LoRA: {n_lora:,} trainable params "
+               f"(rank={args.lora_rank}, alpha={args.lora_alpha})")
+     else:
+         lora_params = None
+         print("[INFO] Full fine-tuning (all parameters trainable)")
+
+     total_params = sum(p.numel() for p in model.parameters())
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print(f"Total params: {total_params:,} | Trainable: {trainable_params:,}")
+
+     # ------------------------------------------------------------------
+     # Tokenizer
+     # ------------------------------------------------------------------
+     tokenizer_path = _resolve_tokenizer_path(args)
+     print(f"Loading tokenizer from {tokenizer_path}")
+     from tokenizers import Tokenizer
+     tokenizer = Tokenizer.from_file(str(tokenizer_path))
+
+     # ------------------------------------------------------------------
+     # Dataset & DataLoader
+     # ------------------------------------------------------------------
+     train_dataset = DPODataset(
+         data_path=args.preference_data,
+         tokenizer=tokenizer,
+         max_seq_len=args.max_length,
+     )
+     if len(train_dataset) == 0:
+         raise ValueError(f"Preference dataset is empty: {args.preference_data}")
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=args.batch_size,
+         sampler=RandomSampler(train_dataset),
+         num_workers=args.num_workers,
+         pin_memory=True,
+         drop_last=True,
+         collate_fn=dpo_collate_fn,
+         prefetch_factor=2,
+         persistent_workers=(args.num_workers > 0),
+     )
+
+     # ------------------------------------------------------------------
+     # Optimizer
+     # ------------------------------------------------------------------
+     if lora_params is not None:
+         opt_params = lora_params
+     else:
+         opt_params = [p for p in model.parameters() if p.requires_grad]
+
+     optimizer = torch.optim.AdamW(
+         opt_params,
+         lr=args.lr,
+         betas=(0.9, 0.95),
+         weight_decay=args.weight_decay,
+         fused=torch.cuda.is_available(),
+     )
+
+     scheduler = get_cosine_schedule_with_warmup(
+         optimizer=optimizer,
+         warmup_steps=args.warmup_steps,
+         total_steps=args.max_steps,
+     )
+
+     # ------------------------------------------------------------------
+     # Resume
+     # ------------------------------------------------------------------
+     start_step = 0
+     if args.resume is not None:
+         if not args.resume.exists():
+             raise FileNotFoundError(f"Resume checkpoint not found: {args.resume}")
+         start_step, _ = load_checkpoint(args.resume, model, optimizer, scheduler)
+         print(f"Resumed from step {start_step}")
+
+     # ------------------------------------------------------------------
+     # Output directory & tokenizer copy
+     # ------------------------------------------------------------------
+     args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+     dest_tok = args.checkpoint_dir / "tokenizer.json"
+     if not dest_tok.exists():
+         shutil.copy2(str(tokenizer_path), str(dest_tok))
+
+     # ------------------------------------------------------------------
+     # Logger
+     # ------------------------------------------------------------------
+     log_fh = None
+     if args.log_file:
+         Path(args.log_file).parent.mkdir(parents=True, exist_ok=True)
+         log_fh = open(args.log_file, "a", encoding="utf-8", buffering=1)
+
+     def log(msg: str, level: str = "INFO") -> None:
+         ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         line = f"[{ts}] [{level}] {msg}"
+         print(line, flush=True)
+         if log_fh:
+             log_fh.write(line + "\n")
+
+     # ------------------------------------------------------------------
+     # Training banner
+     # ------------------------------------------------------------------
+     eff_batch = args.batch_size * args.grad_accum
+     log("=" * 65)
+     log("ORPO Training — EVAFRILL-Mo")
+     log(f"  Pretrained ckpt : {args.pretrained_checkpoint}")
+     log(f"  Preference data : {args.preference_data} ({len(train_dataset):,} samples)")
+     log(f"  LoRA            : rank={args.lora_rank} alpha={args.lora_alpha} "
+         f"enabled={args.use_lora}")
+     log(f"  lambda_or={args.lambda_or}, lr={args.lr:.2e}, eff_batch={eff_batch}")
+     log(f"  max_steps={args.max_steps}, warmup={args.warmup_steps}, "
+         f"max_len={args.max_length}")
+     log(f"  device={device}")
+     log("=" * 65)
+
+     # ------------------------------------------------------------------
+     # Graceful shutdown handler
+     # ------------------------------------------------------------------
+     shutdown_requested = False
+
+     def shutdown_handler(signum, frame):
+         nonlocal shutdown_requested
+         shutdown_requested = True
+         log(f"Shutdown signal received (sig={signum}). Saving checkpoint ...", "WARN")
+
+     signal.signal(signal.SIGTERM, shutdown_handler)
+     signal.signal(signal.SIGINT, shutdown_handler)
+     try:
+         signal.signal(signal.SIGHUP, shutdown_handler)
+     except AttributeError:
+         pass  # Windows does not have SIGHUP
+
+     # ------------------------------------------------------------------
+     # Data iterator (infinite, cycling through epochs)
+     # ------------------------------------------------------------------
+     import time
+
+     epoch = 0
+     loader_iter = iter(train_loader)
+
+     def next_batch() -> tuple[torch.Tensor, ...]:
+         nonlocal loader_iter, epoch
+         try:
+             return next(loader_iter)
+         except StopIteration:
+             epoch += 1
+             log(f"--- Epoch {epoch} begin ---")
+             loader_iter = iter(train_loader)
+             return next(loader_iter)
+
+     # ------------------------------------------------------------------
+     # Training loop
+     # ------------------------------------------------------------------
+     model.train()
+
+     # Running statistics (reset every log_interval steps)
+     running_total_loss = 0.0
+     running_sft_loss = 0.0
+     running_or_loss = 0.0
+     log_step_count = 0
+     t0 = time.perf_counter()
+
+     # Keep track of the last loss value for the final checkpoint call
+     avg_loss = float("nan")
+
+     for step in range(start_step, args.max_steps):
+         optimizer.zero_grad(set_to_none=True)
+
+         accum_total = 0.0
+         accum_sft = 0.0
+         accum_or = 0.0
+
+         # ---- Gradient accumulation ----------------------------------------
+         for _micro in range(args.grad_accum):
+             batch = next_batch()
+             chosen_ids = batch[0].to(device, dtype=torch.long, non_blocking=True)
+             chosen_labels = batch[1].to(device, dtype=torch.long, non_blocking=True)
+             rejected_ids = batch[2].to(device, dtype=torch.long, non_blocking=True)
+             rejected_labels = batch[3].to(device, dtype=torch.long, non_blocking=True)
+
+             loss, sft_l, or_l = compute_orpo_loss(
+                 model,
+                 chosen_ids, chosen_labels,
+                 rejected_ids, rejected_labels,
+                 lambda_or=args.lambda_or,
+             )
+
+             scaled_loss = loss / args.grad_accum
+             scaled_loss.backward()
+
+             accum_total += loss.item()
+             accum_sft += sft_l
+             accum_or += or_l
+
+         # ---- Gradient clipping --------------------------------------------
+         grad_norm = torch.nn.utils.clip_grad_norm_(
+             [p for p in model.parameters() if p.requires_grad],
+             max_norm=1.0,
+         ).item()
+
+         optimizer.step()
+         scheduler.step()
+
+         # ---- Accumulate stats ---------------------------------------------
+         avg_total = accum_total / args.grad_accum
+         avg_sft = accum_sft / args.grad_accum
+         avg_or = accum_or / args.grad_accum
+
+         running_total_loss += avg_total
+         running_sft_loss += avg_sft
+         running_or_loss += avg_or
+         log_step_count += 1
+         avg_loss = avg_total  # for use in checkpoint call
+
+         # ---- Graceful shutdown check --------------------------------------
+         if shutdown_requested:
+             log(f"Graceful shutdown at step {step + 1}", "WARN")
+             ckpt_path = save_checkpoint(
+                 model, optimizer, scheduler,
+                 step + 1, avg_loss, str(args.checkpoint_dir)
+             )
+             if args.use_lora:
+                 save_lora(model, args.checkpoint_dir / f"lora-{step+1:07d}")
+             log(f"Checkpoint saved -> {ckpt_path}")
+             break
+
+         # ---- Logging ------------------------------------------------------
+         if (step + 1) % args.log_interval == 0:
+             t1 = time.perf_counter()
+             elapsed = t1 - t0
+
+             mean_total = running_total_loss / log_step_count
+             mean_sft = running_sft_loss / log_step_count
+             mean_or = running_or_loss / log_step_count
+             lr_now = scheduler.get_last_lr()[0]
+             mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
+             sps = log_step_count / max(elapsed, 1e-6)  # steps per second
+
+             log(
+                 f"step {step+1:>6d}/{args.max_steps} | "
+                 f"loss {mean_total:.4f} "
+                 f"(sft {mean_sft:.4f} or {mean_or:.4f}) | "
+                 f"lr {lr_now:.2e} | "
+                 f"gnorm {grad_norm:.3f} | "
+                 f"mem {mem_gb:.1f}GB | "
+                 f"{sps:.2f}step/s"
+             )
+
+             running_total_loss = 0.0
+             running_sft_loss = 0.0
+             running_or_loss = 0.0
+             log_step_count = 0
+             t0 = t1
+
+         # ---- Periodic checkpoint ------------------------------------------
+         if (step + 1) % args.save_interval == 0:
+             ckpt_path = save_checkpoint(
+                 model, optimizer, scheduler,
+                 step + 1, avg_loss, str(args.checkpoint_dir)
+             )
+             if args.use_lora:
+                 save_lora(model, args.checkpoint_dir / f"lora-{step+1:07d}")
+             log(f"Checkpoint saved -> {ckpt_path}")
+
+     # -----------------------------------------------------------------------
+     # Final checkpoint
+     # -----------------------------------------------------------------------
+     if not shutdown_requested:
+         final_path = save_checkpoint(
+             model, optimizer, scheduler,
+             args.max_steps, avg_loss, str(args.checkpoint_dir)
+         )
+         if args.use_lora:
+             save_lora(model, args.checkpoint_dir / "lora-final")
+         log(f"Final checkpoint -> {final_path}")
+
+     # -----------------------------------------------------------------------
+     # LoRA merge + save merged model
+     # -----------------------------------------------------------------------
+     if args.use_lora:
+         log("Merging LoRA weights into base model ...")
+         merge_lora(model)
+         merged_dir = args.checkpoint_dir / "checkpoint-merged"
+         model.save_pretrained(merged_dir)
+         # Also copy tokenizer into merged dir for easy inference
+         shutil.copy2(str(dest_tok), str(merged_dir / "tokenizer.json"))
+         log(f"Merged model saved -> {merged_dir}")
+
+     log("ORPO training complete.")
+
+     if log_fh:
+         log_fh.close()
+
+
+ if __name__ == "__main__":
+     main()
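The odds-ratio arithmetic in `compute_orpo_loss` can be sanity-checked in isolation. The following dependency-free sketch applies the same clamping scheme to scalar average log-probs (toy numbers and hypothetical helper names, not repository code):

```python
import math

def log_odds(avg_logp, lo=-33.0, hi=-1e-6):
    # log odds(P) = log P - log(1 - P), computed from an average log-prob
    # via log1p(-exp(.)); clamped exactly as in compute_orpo_loss.
    x = min(max(avg_logp, lo), hi)
    return x - math.log1p(-math.exp(x))

def or_loss(chosen_logp, rejected_logp):
    # -log sigmoid(log odds ratio) == softplus(-ratio);
    # small when chosen is much more likely than rejected.
    ratio = log_odds(chosen_logp) - log_odds(rejected_logp)
    return math.log1p(math.exp(-ratio))
```

With equal log-probs the loss sits at log 2; it shrinks as the chosen response becomes relatively more likely, which is the gradient signal ORPO adds on top of the SFT term.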
scripts/sft.py ADDED
@@ -0,0 +1,854 @@
1
+ """
2
+ train/sft.py — Supervised Fine-Tuning (SFT) entry point.
3
+
4
+ Loads a pretrained checkpoint and fine-tunes it on instruction/conversation
5
+ data using SFTDataset, which masks prompt tokens with ignore_index=-1 so only
6
+ the assistant response tokens contribute to the loss.
7
+
8
+ Launch single-GPU:
9
+ python train/sft.py \\
10
+ --base_checkpoint checkpoints/korean_1b_fp8_run1/checkpoint-0034000 \\
11
+ --sft_data data/sft/train.jsonl \\
12
+ --device cuda:0
13
+
14
+ Launch multi-GPU (DDP via torchrun, 7 GPU):
15
+ torchrun --nproc_per_node=7 train/sft.py \\
16
+ --base_checkpoint checkpoints/3b_final/checkpoint-0319772 \\
17
+ --sft_data data/sft_combined/train_filtered.jsonl
18
+
19
+ KEY DIFFERENCES from pretrain.py:
20
+ - Loads weights from a pretrained checkpoint via LLM.from_pretrained()
21
+ - Uses SFTDataset (JSONL instruction data) instead of PackedDataset
22
+ - Lower default learning rate (2e-5 vs 2e-4)
23
+ - Fewer default steps (3000 vs 100000)
24
+ - Copies tokenizer.json to checkpoint_dir for easy deployment
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import argparse
30
+ import os
31
+ import random
32
+ import signal
33
+ import shutil
34
+ import sys
35
+ from pathlib import Path
36
+
37
+ import numpy as np
38
+ import torch
39
+ import torch.nn.functional as F
40
+ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
41
+
42
+ # B200 Tensor Core 최대 활용: TF32 matmul + cuDNN
43
+ torch.backends.cuda.matmul.allow_tf32 = True
44
+ torch.backends.cudnn.allow_tf32 = True
45
+ torch.set_float32_matmul_precision("high") # TF32 precision for fp32 matmul
46
+
47
+ # Allow imports from the project root regardless of working directory.
48
+ _PROJECT_ROOT = Path(__file__).resolve().parent.parent
49
+ if str(_PROJECT_ROOT) not in sys.path:
50
+ sys.path.insert(0, str(_PROJECT_ROOT))
51
+
52
+ from model import LLM
53
+ from train.trainer import TrainConfig, Trainer
54
+ from train.utils import (
55
+ cleanup_ddp,
56
+ get_cosine_schedule_with_warmup,
57
+ is_main_process,
58
+ load_checkpoint,
59
+ setup_ddp,
60
+ )
61
+
62
+ # ---------------------------------------------------------------------------
63
+ # Optional TransformerEngine import (FP8 support)
64
+ # ---------------------------------------------------------------------------
65
+ try:
66
+ import transformer_engine.pytorch as te # type: ignore[import]
67
+ HAS_TE = True
68
+ except ImportError:
69
+ te = None # type: ignore[assignment]
70
+ HAS_TE = False
71
+
72
+
73
+ # ---------------------------------------------------------------------------
74
+ # Argument parsing
75
+ # ---------------------------------------------------------------------------
76
+
77
+
78
+ def parse_args() -> argparse.Namespace:
79
+ parser = argparse.ArgumentParser(
80
+ description="Supervised Fine-Tuning (SFT) of a pretrained decoder-only LLM.",
81
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
82
+ )
83
+
84
+ # --- Required paths -----------------------------------------------------
85
+ parser.add_argument(
86
+ "--base_checkpoint",
87
+ type=Path,
88
+ required=True,
89
+ help=(
90
+ "Path to the pretrained checkpoint directory. "
91
+ "Must contain model.pt and config.yaml (produced by save_checkpoint)."
92
+ ),
93
+ )
94
+ parser.add_argument(
95
+ "--sft_data",
96
+ type=Path,
97
+ required=True,
98
+ help="Path to the JSONL SFT training data file.",
99
+ )
100
+
101
+ # --- Optional paths -----------------------------------------------------
102
+ parser.add_argument(
103
+ "--val_data",
104
+ type=Path,
105
+ default=None,
106
+ help="Optional path to JSONL SFT validation data file.",
107
+ )
108
+ parser.add_argument(
109
+ "--checkpoint_dir",
110
+ type=Path,
111
+ default=Path("checkpoints/korean_1b_sft"),
112
+ help="Root directory for saving SFT checkpoints.",
113
+ )
114
+ parser.add_argument(
115
+ "--resume",
116
+ type=Path,
117
+ default=None,
118
+ help="Path to an SFT checkpoint directory to resume fine-tuning from.",
119
+ )
120
+ parser.add_argument(
121
+ "--tokenizer",
122
+ type=Path,
123
+ default=None,
124
+ help=(
125
+ "Override path to tokenizer.json. "
126
+ "Defaults to <base_checkpoint>/tokenizer.json, "
127
+ "then falls back to tokenizer/korean_sp/tokenizer.json."
128
+ ),
129
+ )
130
+ parser.add_argument(
131
+ "--log_file",
132
+ type=Path,
133
+ default=None,
134
+ help=(
135
+ "Path to a text file for structured training logs (rank-0 only). "
136
+ "If omitted, logs go only to stdout."
137
+ ),
138
+ )
139
+
140
+ # --- Training hyper-parameters ------------------------------------------
141
+ parser.add_argument(
142
+ "--max_steps",
143
+ type=int,
144
+ default=3000,
145
+ help="Total number of optimiser steps.",
146
+ )
147
+ parser.add_argument(
148
+ "--batch_size",
149
+ type=int,
150
+ default=4,
151
+ help="Per-GPU micro-batch size.",
152
+ )
153
+ parser.add_argument(
154
+ "--lr",
155
+ type=float,
156
+ default=2e-5,
157
+ help=(
158
+ "Peak learning rate. "
159
+ "SFT uses a much lower lr than pretraining (2e-5 vs 2e-4) "
160
+ "to preserve pretrained representations."
161
+ ),
162
+ )
163
+ parser.add_argument(
164
+ "--weight_decay",
165
+ type=float,
166
+ default=0.01,
167
+ help="AdamW weight decay. Lower than pretrain (0.01 vs 0.1).",
168
+ )
169
+ parser.add_argument(
170
+ "--warmup_steps",
171
+ type=int,
172
+ default=100,
173
+ help="Number of linear LR warmup steps.",
174
+ )
175
+ parser.add_argument(
176
+ "--grad_accum",
177
+ type=int,
178
+ default=2,
179
+ help="Gradient accumulation steps.",
180
+ )
181
+ parser.add_argument(
182
+ "--seed",
183
+ type=int,
184
+ default=42,
185
+ help="Base random seed (rank offset is added automatically in DDP).",
186
+ )
187
+ parser.add_argument(
188
+ "--use_fp8",
189
+ action="store_true",
190
+ default=False,
191
+ help=(
192
+ "Enable TransformerEngine FP8 training "
193
+ "(requires B200/H100, uses MXFP8BlockScaling)."
194
+ ),
195
+ )
196
+
197
+ # --- Single-GPU device override (ignored when using torchrun) -----------
198
+ parser.add_argument(
199
+ "--device",
200
+ type=str,
201
+ default=None,
202
+ help=(
203
+ "Explicit device string (e.g. 'cuda:0'). "
204
+ "Ignored when running under torchrun (DDP auto-assigns devices)."
205
+ ),
206
+ )
207
+
208
+     parser.add_argument(
+         "--config", type=Path, default=None,
+         help="YAML config file. Values under the 'train:' section are used as CLI defaults.",
+     )
+     parser.add_argument("--save_interval", type=int, default=500, help="Checkpoint save interval (steps).")
+     parser.add_argument("--eval_interval", type=int, default=250, help="Validation eval interval (steps).")
+     parser.add_argument("--neftune_alpha", type=float, default=5.0, help="NEFTune noise magnitude (0 to disable).")
+     parser.add_argument("--no_fp8", action="store_true", default=False, help="Force-disable FP8 even if the pretrained config has use_fp8=True.")
+     parser.add_argument("--num_workers", type=int, default=4, help="Number of DataLoader worker processes.")
+     parser.add_argument("--max_val_batches", type=int, default=0, help="Max validation batches (0 = unlimited).")
+
+     # First pass: just read --config.
+     args, _ = parser.parse_known_args()
+
+     # Load the YAML config and apply its values as CLI defaults.
+     if args.config is not None:
+         if not args.config.exists():
+             raise FileNotFoundError(f"Config file not found: {args.config}")
+         import yaml
+         with open(args.config, "r") as f:
+             yaml_cfg = yaml.safe_load(f)
+         train_section = yaml_cfg.get("train", {})
+         yaml_to_arg = {
+             "max_steps": "max_steps",
+             "batch_size": "batch_size",
+             "lr": "lr",
+             "weight_decay": "weight_decay",
+             "warmup_steps": "warmup_steps",
+             "grad_accum_steps": "grad_accum",
+             "save_interval": "save_interval",
+             "eval_interval": "eval_interval",
+             "neftune_alpha": "neftune_alpha",
+             "max_val_batches": "max_val_batches",
+         }
+         new_defaults = {}
+         for yaml_key, arg_name in yaml_to_arg.items():
+             if yaml_key in train_section:
+                 new_defaults[arg_name] = train_section[yaml_key]
+         if new_defaults:
+             parser.set_defaults(**new_defaults)
+
+     return parser.parse_args()
250
+
+
+ # ---------------------------------------------------------------------------
+ # Seed helper
+ # ---------------------------------------------------------------------------
+
+
+ def set_seed(seed: int) -> None:
+     """Set deterministic seeds for Python, NumPy, and PyTorch."""
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
263
+
+
+ # ---------------------------------------------------------------------------
+ # Optimizer parameter groups
+ # (Copied from pretrain.py to avoid circular import; identical logic)
+ # ---------------------------------------------------------------------------
+
+
+ def build_optimizer_param_groups(
+     model: torch.nn.Module,
+     weight_decay: float,
+ ) -> list[dict]:
+     """
+     Split parameters into two groups:
+       - decay group   : weight tensors with ndim >= 2 (Linear, etc.)
+       - no-decay group: bias, LayerNorm/RMSNorm weights, and embedding weights
+
+     This follows standard practice (e.g. GPT-style training).
+     """
+     decay_params: list[torch.nn.Parameter] = []
+     no_decay_params: list[torch.nn.Parameter] = []
+
+     # Module types whose parameters should never be decayed.
+     no_decay_module_types = (
+         torch.nn.Embedding,
+         torch.nn.LayerNorm,
+     )
+     # Also skip any parameter whose name ends with '.bias'.
+     no_decay_name_suffixes = ("bias",)
+
+     # Collect module-level exclusions.
+     no_decay_module_params: set[int] = set()
+     for module in model.modules():
+         if isinstance(module, no_decay_module_types):
+             for param in module.parameters(recurse=False):
+                 no_decay_module_params.add(id(param))
+
+     seen: set[int] = set()
+     for name, param in model.named_parameters():
+         if not param.requires_grad:
+             continue
+         if id(param) in seen:
+             continue
+         seen.add(id(param))
+
+         if (
+             id(param) in no_decay_module_params
+             or any(name.endswith(sfx) for sfx in no_decay_name_suffixes)
+             or param.ndim < 2
+         ):
+             no_decay_params.append(param)
+         else:
+             decay_params.append(param)
+
+     return [
+         {"params": decay_params, "weight_decay": weight_decay},
+         {"params": no_decay_params, "weight_decay": 0.0},
+     ]
321
+
+
+ # ---------------------------------------------------------------------------
+ # Tokenizer resolution helper
+ # ---------------------------------------------------------------------------
+
+
+ def _resolve_tokenizer_path(args: argparse.Namespace) -> Path:
+     """
+     Determine the tokenizer path in priority order:
+       1. Explicit --tokenizer argument
+       2. tokenizer.json inside the base_checkpoint directory
+       3. Project default: tokenizer/korean_sp/tokenizer.json
+     """
+     if args.tokenizer is not None:
+         p = Path(args.tokenizer)
+         if not p.exists():
+             raise FileNotFoundError(f"Tokenizer not found at --tokenizer path: {p}")
+         return p
+
+     ckpt_tok = args.base_checkpoint / "tokenizer.json"
+     if ckpt_tok.exists():
+         return ckpt_tok
+
+     default_tok = _PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json"
+     if default_tok.exists():
+         return default_tok
+
+     raise FileNotFoundError(
+         "Could not locate tokenizer.json. Tried:\n"
+         f"  1. {ckpt_tok}\n"
+         f"  2. {default_tok}\n"
+         "Use --tokenizer to specify an explicit path."
+     )
355
+
+
+ # ---------------------------------------------------------------------------
+ # Dynamic padding collate function
+ # ---------------------------------------------------------------------------
+
+
+ def dynamic_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Collate function that pads each batch to its own maximum sequence length
+     instead of a fixed global max_seq_len. This reduces wasted FLOPs on
+     short sequences and speeds up SFT, where response lengths are highly
+     variable.
+
+     Pads to the batch-local max, aligned to 64 tokens (for Flash Attention
+     efficiency), with a floor of 512 tokens so micro-batches are not too short.
+
+     Args:
+         batch: List of ``(input_ids, labels)`` tuples from SFTDataset.
+
+     Returns:
+         Tuple of ``(input_ids, labels, attention_mask)`` tensors shaped
+         ``[B, max_len]``.
+         ``input_ids`` is right-padded with 0 (pad token).
+         ``labels`` is right-padded with -1 (cross-entropy ignore_index).
+         ``attention_mask`` is 1 for real tokens, 0 for padding.
+     """
+     # 64-token alignment with a minimum floor of 512.
+     raw_max = max(item[0].size(0) for item in batch)
+     max_len = max(512, ((raw_max + 63) // 64) * 64)
+
+     input_ids_list, labels_list, mask_list = [], [], []
+     for ids, labs in batch:
+         pad_len = max_len - ids.size(0)
+         input_ids_list.append(F.pad(ids, (0, pad_len), value=0))
+         labels_list.append(F.pad(labs, (0, pad_len), value=-1))
+         mask_list.append(
+             F.pad(torch.ones(ids.size(0), dtype=torch.long), (0, pad_len), value=0)
+         )
+
+     return (
+         torch.stack(input_ids_list),
+         torch.stack(labels_list),
+         torch.stack(mask_list),
+     )
400
+
+
+ # ---------------------------------------------------------------------------
+ # NEFTune helper
+ # ---------------------------------------------------------------------------
+
+
+ def add_neftune_hook(model: torch.nn.Module, noise_alpha: float = 10.0):
+     """
+     Register a forward hook on the model's input embedding layer that adds
+     uniform noise scaled by noise_alpha during training (NEFTune).
+
+     Reference: "NEFTune: Noisy Embeddings Improve Instruction Finetuning"
+     (Jain et al., 2023). https://arxiv.org/abs/2310.05914
+
+     Args:
+         model: Raw (non-DDP) model instance.
+         noise_alpha: Noise magnitude parameter (paper default: 10).
+
+     Returns:
+         The hook handle (call ``handle.remove()`` to deactivate), or None if
+         the embedding layer could not be located.
+     """
+     # Unwrap DDP if needed.
+     raw = model.module if hasattr(model, "module") else model
+
+     # 1) Try the standard HuggingFace accessor first.
+     embedding: torch.nn.Embedding | None = None
+     if hasattr(raw, "get_input_embeddings"):
+         try:
+             emb = raw.get_input_embeddings()
+             if isinstance(emb, torch.nn.Embedding):
+                 embedding = emb
+         except Exception:
+             pass
+
+     # 2) Fallback: walk common attribute paths found in open-source LLMs.
+     if embedding is None:
+         for attr_path in [
+             "embedding",
+             "embed_tokens",
+             "token_embedding",
+             "wte",
+             "word_embeddings",
+             "tok_embeddings",
+             "transformer.wte",
+             "model.embed_tokens",
+             "model.embedding",
+         ]:
+             obj = raw
+             for part in attr_path.split("."):
+                 obj = getattr(obj, part, None)
+                 if obj is None:
+                     break
+             if obj is not None and isinstance(obj, torch.nn.Embedding):
+                 embedding = obj
+                 break
+
+     if embedding is None:
+         print("[WARN] NEFTune: embedding layer not found; NEFTune disabled")
+         return None
+
+     print(
+         f"[INFO] NEFTune: registered hook on {type(embedding).__name__} "
+         f"(shape={tuple(embedding.weight.shape)}, alpha={noise_alpha})"
+     )
+
+     def _hook(
+         module: torch.nn.Module,
+         inp: tuple,
+         out: torch.Tensor,
+     ) -> torch.Tensor:
+         if module.training:
+             # out shape: [B, seq_len, d_model]
+             mag = noise_alpha / ((out.size(1) * out.size(2)) ** 0.5)
+             out = out + torch.empty_like(out).uniform_(-mag, mag)
+         return out
+
+     return embedding.register_forward_hook(_hook)
479
+
+
+ # ---------------------------------------------------------------------------
+ # Main
+ # ---------------------------------------------------------------------------
+
+
+ def main() -> None:
+     args = parse_args()
+
+     # ---- Distributed setup -------------------------------------------------
+     is_ddp = "RANK" in os.environ
+     rank = 0
+     local_rank = 0
+     world_size = 1
+
+     if is_ddp:
+         rank, local_rank, world_size, device = setup_ddp()
+     else:
+         # Single-GPU: honour the --device flag, else pick cuda:0 or cpu.
+         if args.device is not None:
+             device = torch.device(args.device)
+         elif torch.cuda.is_available():
+             device = torch.device("cuda:0")
+         else:
+             device = torch.device("cpu")
+
+     # Per-rank seed so data shuffling differs across replicas.
+     set_seed(args.seed + rank)
+
+     # ---- NUMA affinity for optimal GPU↔CPU memory locality ---------------
+     # B200 topology: GPU 0-3 → NUMA node 0 (cores 0-35)
+     #                GPU 4-6 → NUMA node 1 (cores 36-71)  [7-GPU setup]
+     try:
+         if local_rank < 4:
+             os.sched_setaffinity(0, set(range(0, 36)))   # NUMA node 0
+         else:
+             os.sched_setaffinity(0, set(range(36, 72)))  # NUMA node 1
+         if is_main_process():
+             print(f"NUMA affinity: rank {rank} (GPU {local_rank}) → "
+                   f"{'NUMA0 cores 0-35' if local_rank < 4 else 'NUMA1 cores 36-71'}")
+     except (AttributeError, OSError) as e:
+         if is_main_process():
+             print(f"[WARN] NUMA affinity failed: {e}")
523
+
+     # ---- Validate base checkpoint ------------------------------------------
+     if not args.base_checkpoint.exists():
+         raise FileNotFoundError(
+             f"Base checkpoint directory not found: {args.base_checkpoint}"
+         )
+     for required_file in ("model.pt", "config.yaml"):
+         if not (args.base_checkpoint / required_file).exists():
+             raise FileNotFoundError(
+                 f"Expected {required_file} inside base checkpoint: {args.base_checkpoint}"
+             )
+
+     # ---- Load pretrained model ---------------------------------------------
+     # LLM.from_pretrained() reads config.yaml + model.pt and returns the model on CPU.
+     # We move it to the target device immediately after loading.
+     #
+     # NOTE: fp8_model_init() is intentionally NOT used here (same as pretrain.py).
+     # MXFP8Tensor weights are incompatible with DDP's _broadcast_coalesced.
+     # Weights stay in float32; TransformerEngine quantizes on-the-fly inside fp8_autocast.
+     model = LLM.from_pretrained(args.base_checkpoint)
+
+     # FP8 override: --no_fp8 forces BF16 even if the pretrained config had
+     # use_fp8=True; --use_fp8 enables FP8 if the pretrained config had it disabled.
+     if args.no_fp8:
+         model.config.use_fp8 = False
+     elif args.use_fp8:
+         model.config.use_fp8 = True
+
+     # Move the model to the target device in bfloat16 (more memory-efficient than
+     # fp32 for fine-tuning, and required when BF16 autocast + TE are active).
+     model = model.to(device=device, dtype=torch.bfloat16)
+
+     # ---- Gradient checkpointing --------------------------------------------
+     # Trades activation memory for recomputation during the backward pass.
+     # Especially useful for large models / long sequences in SFT.
+     if hasattr(model, "gradient_checkpointing_enable"):
+         model.gradient_checkpointing_enable()
+         if rank == 0:
+             print("[INFO] Gradient checkpointing enabled")
+
+     # FP8 alignment check: (batch_size × seq_len) must be divisible by 8.
+     if model.config.use_fp8:
+         seq_len = model.config.max_seq_len
+         if (args.batch_size * seq_len) % 8 != 0:
+             raise ValueError(
+                 f"FP8: batch_size × max_seq_len = {args.batch_size} × {seq_len} "
+                 f"= {args.batch_size * seq_len} must be divisible by 8."
+             )
+
+     if is_main_process():
+         total_params = sum(p.numel() for p in model.parameters())
+         print(f"Pretrained model loaded: {total_params:,} parameters")
+         print(f"LMConfig: {model.config}")
576
+
+     # ---- Wrap in DDP -------------------------------------------------------
+     if is_ddp:
+         from torch.nn.parallel import DistributedDataParallel as DDP
+
+         model = DDP(
+             model,
+             device_ids=[local_rank],
+             output_device=local_rank,
+             gradient_as_bucket_view=True,
+             bucket_cap_mb=800,
+             find_unused_parameters=False,
+         )
+
+     # ---- Tokenizer ---------------------------------------------------------
+     tokenizer_path = _resolve_tokenizer_path(args)
+     if is_main_process():
+         print(f"Loading tokenizer from: {tokenizer_path}")
+
+     # Use the fast tokenizers library (same as the rest of the project).
+     from tokenizers import Tokenizer  # type: ignore[import]
+     tokenizer = Tokenizer.from_file(str(tokenizer_path))
+
+     # ---- Dataset & DataLoader ----------------------------------------------
+     # Import SFTDataset (created separately alongside this file).
+     # SFTDataset returns (input_ids, targets) where prompt token positions in
+     # targets are filled with -1. Trainer._compute_loss already uses
+     # ignore_index=-1, so only response tokens contribute to the gradient.
+     from data.sft_dataset import SFTDataset  # type: ignore[import]
+
+     train_dataset = SFTDataset(
+         data_path=args.sft_data,
+         tokenizer=tokenizer,
+         max_seq_len=model.config.max_seq_len
+         if not isinstance(model, torch.nn.parallel.DistributedDataParallel)
+         else model.module.config.max_seq_len,
+     )
+
+     if is_ddp:
+         train_sampler: DistributedSampler | RandomSampler = DistributedSampler(
+             train_dataset,
+             num_replicas=world_size,
+             rank=rank,
+             shuffle=True,
+             seed=args.seed,
+         )
+     else:
+         # A sampler is provided, so the DataLoader must not also shuffle.
+         train_sampler = RandomSampler(train_dataset)
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=args.batch_size,
+         sampler=train_sampler,
+         # SFT datasets are typically small enough that 2–4 workers suffice;
+         # the default of 4 balances I/O with CPU parsing overhead from JSONL.
+         num_workers=args.num_workers,
+         pin_memory=True,
+         drop_last=True,
+         prefetch_factor=2,
+         persistent_workers=True,
+         collate_fn=dynamic_collate_fn,
+     )
+
+     # Optional validation loader. It is passed to the Trainer below as
+     # ``val_loader``; validation runs every ``eval_interval`` steps, capped
+     # at ``max_val_batches`` batches when that value is non-zero.
+     val_loader: DataLoader | None = None
+     if args.val_data is not None:
+         if not args.val_data.exists():
+             raise FileNotFoundError(f"Validation data not found: {args.val_data}")
+         val_dataset = SFTDataset(
+             data_path=args.val_data,
+             tokenizer=tokenizer,
+             max_seq_len=train_dataset.max_seq_len,
+         )
+         val_loader = DataLoader(
+             val_dataset,
+             batch_size=args.batch_size,
+             shuffle=False,
+             num_workers=2,
+             pin_memory=True,
+             drop_last=False,
+             collate_fn=dynamic_collate_fn,
+         )
664
+         if is_main_process():
+             print(f"Validation dataset: {len(val_dataset):,} samples")
+
+     # ---- Optimizer ---------------------------------------------------------
+     # Use the same two-group split (weight_decay / no weight_decay) as pretrain.
+     # Unwrap DDP to get the raw model's parameters.
+     raw_model = getattr(model, "module", model)
+     param_groups = build_optimizer_param_groups(raw_model, args.weight_decay)
+     optimizer = torch.optim.AdamW(
+         param_groups,
+         lr=args.lr,
+         betas=(0.9, 0.95),
+         eps=1e-8,
+         fused=torch.cuda.is_available(),  # Use the fused kernel when on CUDA.
+     )
+
+     # ---- TrainConfig -------------------------------------------------------
+     # Set use_fp8 from the (possibly overridden) model config so Trainer builds
+     # the correct FP8 recipe and wraps forward passes in fp8_autocast.
+     use_fp8 = raw_model.config.use_fp8
+
+     train_config = TrainConfig(
+         max_steps=args.max_steps,
+         checkpoint_dir=str(args.checkpoint_dir),
+         grad_accum_steps=args.grad_accum,
+         use_fp8=use_fp8,
+         log_file=str(args.log_file) if args.log_file is not None else None,
+         save_interval=args.save_interval,
+         log_interval=10,
+         eval_interval=args.eval_interval,
+         max_val_batches=args.max_val_batches,
+     )
+
+     # ---- LR Scheduler ------------------------------------------------------
+     scheduler = get_cosine_schedule_with_warmup(
+         optimizer=optimizer,
+         warmup_steps=args.warmup_steps,
+         total_steps=train_config.max_steps,
+     )
+
+     # ---- Resume from SFT checkpoint ----------------------------------------
+     # When --resume is given we restore the SFT optimizer/scheduler state as
+     # well so learning rate, momentum buffers, etc. are correctly restored.
+     # NOTE: This resumes SFT training, NOT the pretrain checkpoint.
+     # The pretrain weights were already loaded above via from_pretrained().
+     start_step = 0
+     if args.resume is not None:
+         if not args.resume.exists():
+             raise FileNotFoundError(f"Resume checkpoint not found: {args.resume}")
+         start_step, resume_loss = load_checkpoint(
+             path=args.resume,
+             model=model,
+             optimizer=optimizer,
+             scheduler=scheduler,
+         )
+         if is_main_process():
+             print(f"Resumed SFT from {args.resume} at step {start_step} (loss={resume_loss:.4f})")
+
+     if args.resume is not None and isinstance(train_sampler, DistributedSampler):
+         steps_per_epoch = len(train_loader)
+         approx_epoch = start_step // steps_per_epoch if steps_per_epoch > 0 else 0
+         train_sampler.set_epoch(approx_epoch)
+         if is_main_process():
+             print(f"[INFO] Resume: sampler epoch set to {approx_epoch}")
+
+     # ---- Checkpoint directory ----------------------------------------------
+     args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+     # ---- Copy tokenizer to checkpoint dir for easy deployment later --------
+     # This mirrors the tokenizer into the SFT checkpoint root so that the
+     # final checkpoint directory is self-contained for convert_to_hf.py, etc.
+     if is_main_process():
+         dest_tok = args.checkpoint_dir / "tokenizer.json"
+         if not dest_tok.exists():
+             shutil.copy2(str(tokenizer_path), str(dest_tok))
+             print(f"Tokenizer copied to {dest_tok}")
+
+     # ---- Trainer -----------------------------------------------------------
+     trainer = Trainer(
+         model=model,
+         train_loader=train_loader,
+         optimizer=optimizer,
+         scheduler=scheduler,
+         config=train_config,
+         device=device,
+         rank=rank,
+         sampler=train_sampler if is_ddp else None,
+         val_loader=val_loader,
+     )
753
+
+     # ---- Signal handlers for graceful shutdown -----------------------------
+     # ``signal`` is already imported at module level; register handlers for
+     # SIGHUP/SIGTERM so interrupted runs shut down through the Trainer.
+     _trainer_ref = trainer
+
+     def _graceful_shutdown_handler(signum, frame):
+         sig_name = signal.Signals(signum).name
+         if is_main_process():
+             import datetime as _dt
+             ts = _dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+             msg = (
+                 f"[{ts}] [SIGNAL] Received {sig_name} (signum={signum}). "
+                 f"Initiating graceful shutdown..."
+             )
+             print(f"\n{msg}")
+             if args.log_file is not None:
+                 try:
+                     with open(args.log_file, "a", encoding="utf-8") as f:
+                         f.write(msg + "\n")
+                 except Exception:
+                     pass
+         _trainer_ref.request_shutdown(sig_name)
+
+     for _sig in (signal.SIGHUP, signal.SIGTERM):
+         signal.signal(_sig, _graceful_shutdown_handler)
779
+
+     # ---- SFT banner --------------------------------------------------------
+     if is_main_process():
+         import datetime
+
+         inner_config = raw_model.config
+         eff_batch_seqs = args.batch_size * args.grad_accum * world_size
+         eff_tokens_per_step = eff_batch_seqs * inner_config.max_seq_len
+         train_samples = len(train_dataset)
+         precision_label = "FP8 (MXFP8BlockScaling)" if use_fp8 else "BF16"
+         nccl_debug = os.environ.get("NCCL_DEBUG", "not set")
+         omp_threads = os.environ.get("OMP_NUM_THREADS", "not set")
+
+         print(
+             f"\n{'='*70}\n"
+             f" LLM Supervised Fine-Tuning — "
+             f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
+             f"{'='*70}\n"
+             f" base ckpt : {args.base_checkpoint}\n"
+             f" sft data  : {args.sft_data} ({train_samples:,} samples)\n"
+             f" model     : {inner_config.num_params:,} params | "
+             f"d_model={inner_config.d_model} n_layers={inner_config.n_layers}\n"
+             f" precision : {precision_label}\n"
+             f" GPUs      : {world_size} | batch/GPU={args.batch_size} "
+             f"grad_accum={args.grad_accum}\n"
+             f" eff_batch : {eff_batch_seqs} seqs "
+             f"= {eff_tokens_per_step:,} tok/step\n"
+             f" max_steps : {train_config.max_steps:,}\n"
+             f" lr        : {args.lr:.2e} "
+             f"warmup={args.warmup_steps} weight_decay={args.weight_decay}\n"
+             f" ckpt_dir  : {args.checkpoint_dir}\n"
+             f" env       : OMP_NUM_THREADS={omp_threads} NCCL_DEBUG={nccl_debug}\n"
+             f"{'='*70}\n"
+         )
813
+
+     # ---- NEFTune -----------------------------------------------------------
+     # Add uniform noise to embeddings during training to improve instruction
+     # following (Jain et al., 2023). The hook is registered on the raw
+     # (non-DDP) model so it survives DDP's internal module wrapping.
+     neftune_alpha = args.neftune_alpha
+     neftune_handle = add_neftune_hook(raw_model, noise_alpha=neftune_alpha)
+     if rank == 0:
+         if neftune_handle is not None:
+             print(f"[INFO] NEFTune enabled (noise_alpha={neftune_alpha})")
+         else:
+             print("[WARN] NEFTune disabled - embedding layer not found")
+
+     # ---- Train -------------------------------------------------------------
+     try:
+         trainer.train(start_step=start_step)
+     except KeyboardInterrupt:
+         if is_main_process():
+             print("\n[INFO] SFT interrupted by user (KeyboardInterrupt).")
+     except Exception:
+         import traceback
+         if is_main_process():
+             tb = traceback.format_exc()
+             print(f"\n[ERROR] SFT failed at rank {rank}:\n{tb}")
+             if args.log_file is not None:
+                 import datetime
+                 with open(args.log_file, "a", encoding="utf-8") as f:
+                     f.write(
+                         f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
+                         f"[FATAL] {tb}\n"
+                     )
+         raise
+     finally:
+         # Remove the NEFTune hook so the model is clean for inference/saving.
+         if neftune_handle is not None:
+             neftune_handle.remove()
+         if is_ddp:
+             cleanup_ddp()
+
+
+ if __name__ == "__main__":
+     main()
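Two numeric rules are baked into the script above: the batch-local padded length in `dynamic_collate_fn` (round up to a multiple of 64, with a 512-token floor) and the NEFTune noise magnitude `alpha / sqrt(seq_len * d_model)` inside `_hook`. The snippet below is a standalone illustration of that arithmetic, not a file from this repository:

```python
import math

def padded_len(raw_max: int, align: int = 64, floor: int = 512) -> int:
    # Same arithmetic as dynamic_collate_fn: round the batch-local max
    # length up to a multiple of `align`, never going below `floor`.
    return max(floor, ((raw_max + align - 1) // align) * align)

def neftune_mag(noise_alpha: float, seq_len: int, d_model: int) -> float:
    # Same scale as the NEFTune _hook closure: alpha / sqrt(L * d).
    return noise_alpha / math.sqrt(seq_len * d_model)

print(padded_len(130))   # short batch is padded up to the 512-token floor
print(padded_len(700))   # 700 rounds up to 704, the next multiple of 64
print(neftune_mag(5.0, 512, 3072))
```

Note how the floor dominates for short batches while the 64-alignment dominates for long ones, which is exactly the trade-off the collate function's docstring describes.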
sft-v2/config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "vocab_size": 64000,
+   "d_model": 3072,
+   "n_layers": 26,
+   "n_heads": 24,
+   "n_kv_heads": 8,
+   "d_ffn": 9216,
+   "max_seq_len": 4096,
+   "rope_theta": 500000.0,
+   "dropout": 0.0,
+   "bias": false,
+   "use_flash_attn": true,
+   "use_fp8": false,
+   "use_hybrid": true,
+   "hybrid_pattern": "M M M M M M M M M M M M A M M M M M M M M M M M A M",
+   "mamba_d_state": 128,
+   "mamba_head_dim": 64,
+   "mamba_expand": 2,
+   "mamba_conv_kernel": 4,
+   "mamba_n_groups": 8,
+   "mamba_d_ffn": 4608,
+   "mamba_chunk_size": 256,
+   "model_type": "evafrill-mo",
+   "architectures": [
+     "EvafrillMoForCausalLM"
+   ],
+   "_variant": "sft-v2",
+   "_description": "SFT v2 (65K steps, val_loss 1.79, early stop)"
+ }
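The `hybrid_pattern` string in this config encodes the layer layout one token per layer: `M` for a Mamba-2 block, `A` for a self-attention block. A quick standalone sanity check (not repository code) confirms it agrees with `"n_layers": 26`:

```python
# The pattern string copied verbatim from config.json above.
pattern = "M M M M M M M M M M M M A M M M M M M M M M M M A M".split()

n_layers = len(pattern)                       # should match "n_layers": 26
n_mamba = pattern.count("M")                  # Mamba-2 blocks
attn_idx = [i for i, t in enumerate(pattern) if t == "A"]  # attention positions

print(n_layers, n_mamba, attn_idx)  # 26 24 [12, 24]
```

So this config places its two attention blocks at zero-indexed layers 12 and 24, with the remaining 24 layers being Mamba-2 blocks.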
sft-v2/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:144b4a8154e3c5fc13fa8d029d8016fe330fe8dc9083762a7ddbab62103d7073
+ size 6301164272
sft-v2/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
slerp/config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "vocab_size": 64000,
+   "d_model": 3072,
+   "n_layers": 26,
+   "n_heads": 24,
+   "n_kv_heads": 8,
+   "d_ffn": 9216,
+   "max_seq_len": 4096,
+   "rope_theta": 500000.0,
+   "dropout": 0.0,
+   "bias": false,
+   "use_flash_attn": true,
+   "use_fp8": false,
+   "use_hybrid": true,
+   "hybrid_pattern": "M M M M M M M M M M M M A M M M M M M M M M M M A M",
+   "mamba_d_state": 128,
+   "mamba_head_dim": 64,
+   "mamba_expand": 2,
+   "mamba_conv_kernel": 4,
+   "mamba_n_groups": 8,
+   "mamba_d_ffn": 4608,
+   "mamba_chunk_size": 256,
+   "model_type": "evafrill-mo",
+   "architectures": [
+     "EvafrillMoForCausalLM"
+   ],
+   "_variant": "slerp",
+   "_description": "SLERP merge alpha=0.5 (RECOMMENDED FINAL MODEL)"
+ }
slerp/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b7fedbd0d0f8e33a1fb5e6c4e8e9393f729cc77b364d431e522857ce6a1c8d56
+ size 6301164272
slerp/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
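The `slerp` variant's `_description` says it is a spherical linear interpolation (SLERP) merge of two checkpoints at alpha=0.5. The merge script itself is not part of this upload, so the sketch below is only an illustration of per-vector SLERP in plain Python; actual checkpoint merges typically apply this tensor by tensor across two state dicts:

```python
import math

def slerp(a: list[float], b: list[float], alpha: float = 0.5) -> list[float]:
    """Spherical linear interpolation between two flat weight vectors."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    dot = sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)
    dot = max(-1.0, min(1.0, dot))  # clamp against floating-point drift
    omega = math.acos(dot)          # angle between the two vectors
    if omega < 1e-6:
        # Nearly parallel vectors: plain LERP is numerically safer.
        return [(1 - alpha) * x + alpha * y for x, y in zip(a, b)]
    so = math.sin(omega)
    wa = math.sin((1 - alpha) * omega) / so
    wb = math.sin(alpha * omega) / so
    return [wa * x + wb * y for x, y in zip(a, b)]

# Midpoint between orthogonal unit vectors stays on the unit sphere,
# which is the property that distinguishes SLERP from plain averaging:
mid = slerp([1.0, 0.0], [0.0, 1.0], 0.5)
```

Unlike linear averaging, SLERP preserves the interpolated vector's norm along the great-circle path, which is the usual motivation for using it when merging fine-tuned checkpoints.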