Upload source/PLAN_hybrid_3b_fixes.md with huggingface_hub

#34
by somebody-to-love - opened
Files changed (1)
  1. source/PLAN_hybrid_3b_fixes.md +498 -0
source/PLAN_hybrid_3b_fixes.md ADDED
# FRANKENSTALLM-H 3B Hybrid Model: Inspection Findings and Fix Execution Guide

> **Written**: 2026-03-05
> **Purpose**: Before Phase 2 verification, fix the 6 issues found and leave the code in an immediately runnable state.
> **Reference this document in the next session and execute it directly.**

---

## Issue Summary (6 items)

| # | Severity | Issue | File | Impact |
|---|----------|-------|------|--------|
| 1 | **CRITICAL** | Mamba blocks have no FFN (channel mixer) | `model/mamba_block.py` | 37/40 layers lack capacity |
| 2 | **HIGH** | `n_groups=1` (the Nemotron standard is 8) | `configs/hybrid_3b.yaml` | Reduced B/C projection expressiveness |
| 3 | **HIGH** | No hybrid-architecture startup log | `train/pretrain.py` | Hard to debug and monitor |
| 4 | **MEDIUM** | No architecture validation on checkpoint resume | `train/utils.py` | Wrong weights can load silently |
| 5 | **MEDIUM** | No NaN/Inf detection in selective_scan | `model/mamba_block.py` | Numerical instability cannot be diagnosed |
| 6 | **LOW** | No input shape validation in selective_scan | `model/mamba_block.py` | Ambiguous error messages |

---

## Implementation Order and Dependencies

```
Step 1 (add FFN) ← do first; architecture change
├── 1a. model/config.py: add mamba_d_ffn field
├── 1b. model/mamba_block.py: add FFN sublayer
├── 1c. model/transformer.py: pass constructor args + fix _init_weights
└── 1d. configs/hybrid_3b.yaml: add mamba_d_ffn=4608

Step 2 (n_groups) ← independent of Step 1, can run in parallel
└── configs/hybrid_3b.yaml: n_groups=8

Step 3 (logging) ← after Step 1 (parameter counts must be accurate)
└── train/pretrain.py: add hybrid info to the startup banner

Step 4 (checkpoint validation) ← independent
└── train/utils.py: config comparison logic in load_checkpoint

Step 5-6 (NaN detection + shape validation) ← independent
└── model/mamba_block.py: selective_scan function
```

**Parallelizable**: Step 1 and Step 2 overlap only in the YAML (merge at the end).
Step 4 and Step 5-6 can also run independently in parallel.

---

## Step 1: Add an FFN to Mamba2Block (CRITICAL)

### Background

- Mamba2Block currently has only the SSM (sequence mixer) and no FFN (channel mixer)
- In Nemotron-H, every Mamba layer is followed by an MLP
- With 37 of 40 layers lacking an FFN, channel-wise feature mixing is impossible in most of the stack
- **Decided**: `mamba_d_ffn = 4608` (d_model × 1.5), total parameters ~4.5B, VRAM ~80GB/GPU (sanity-checked in the sketch below)

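Those numbers can be sanity-checked with rough arithmetic. A minimal sketch, assuming d_model = 3072 (implied by 4608 = d_model × 1.5) and a standard three-projection SwiGLU, which is an assumption about `model/layers.py` rather than something confirmed here:

```python
# Rough arithmetic behind mamba_d_ffn = 4608 and the ~4.5B total.
# Assumes d_model = 3072 (from 4608 = d_model * 1.5) and that the
# project's SwiGLU uses the standard three projections (gate, up,
# down); the latter is an assumption about model/layers.py.
d_model = 3072
d_ffn = 4608
n_mamba_layers = 37   # Mamba layers in the 40-layer hybrid stack

per_layer = 3 * d_model * d_ffn        # gate + up + down, bias=False
total_added = n_mamba_layers * per_layer

print(f"per-layer FFN params : {per_layer / 1e6:.1f}M")    # ~42.5M
print(f"added over 37 layers : {total_added / 1e9:.2f}B")  # ~1.57B
# ~3B existing + ~1.57B added lands at the ~4.5B target quoted above.
```
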
### 1a. Edit `model/config.py`

**Location**: inside the LMConfig dataclass (after line 61)

**Field to add** (after the existing `mamba_chunk_size`):
```python
mamba_d_ffn: Optional[int] = None  # FFN dim for Mamba blocks (None → d_ffn)
```

**Add to `__post_init__`** (line 86, after the hybrid validation block):
```python
# Mamba FFN dimension: default to d_ffn if not specified
if self.mamba_d_ffn is None:
    self.mamba_d_ffn = self.d_ffn
```

**Add to `to_dict()`** (after the existing mamba_chunk_size entry):
```python
"mamba_d_ffn": self.mamba_d_ffn,
```

### 1b. Edit `model/mamba_block.py`

**Import change** (line 19):
```python
# Before:
from .layers import RMSNorm

# After:
from .layers import RMSNorm, SwiGLU
```

**`Mamba2Block.__init__` signature change** (lines 128-137):
```python
# Before:
def __init__(
    self,
    d_model: int,
    d_state: int = 128,
    head_dim: int = 64,
    expand: int = 2,
    conv_kernel: int = 4,
    n_groups: int = 1,
    chunk_size: int = 256,
) -> None:

# After:
def __init__(
    self,
    d_model: int,
    d_state: int = 128,
    head_dim: int = 64,
    expand: int = 2,
    conv_kernel: int = 4,
    n_groups: int = 1,
    chunk_size: int = 256,
    d_ffn: int = 0,
    bias: bool = False,
) -> None:
```

**Add the FFN sublayer** (line 192, after `self.out_proj`):
```python
# --- FFN sublayer (channel mixer) ---
if d_ffn > 0:
    self.ffn_norm = RMSNorm(d_model)
    self.ffn = SwiGLU(d_model, d_ffn, bias=bias)
else:
    self.ffn_norm = None
    self.ffn = None
```

**Edit `forward()`** (line 280):
```python
# Before:
return residual + self.out_proj(y)

# After:
x = residual + self.out_proj(y)
# FFN sublayer (channel mixer)
if self.ffn is not None:
    x = x + self.ffn(self.ffn_norm(x))
return x
```

### 1c. Edit `model/transformer.py`

**Mamba2Block constructor call change** (lines 124-132):
```python
# Before:
layers.append(Mamba2Block(
    d_model=config.d_model,
    d_state=config.mamba_d_state,
    head_dim=config.mamba_head_dim,
    expand=config.mamba_expand,
    conv_kernel=config.mamba_conv_kernel,
    n_groups=config.mamba_n_groups,
    chunk_size=config.mamba_chunk_size,
))

# After:
layers.append(Mamba2Block(
    d_model=config.d_model,
    d_state=config.mamba_d_state,
    head_dim=config.mamba_head_dim,
    expand=config.mamba_expand,
    conv_kernel=config.mamba_conv_kernel,
    n_groups=config.mamba_n_groups,
    chunk_size=config.mamba_chunk_size,
    d_ffn=config.mamba_d_ffn,
    bias=config.bias,
))
```

**Edit `_init_weights`** (lines 180-182):
```python
# Before:
# Mamba2Block handles its own parameter init (A_log, D, dt_bias, etc.)
if isinstance(module, Mamba2Block):
    return

# After (delete these 3 lines):
# Reason for deletion: after adding the FFN, the nn.Linear modules inside
# SwiGLU need initialization. A_log, D, and dt_bias are nn.Parameter, so
# they never match the isinstance(nn.Linear) check and are skipped
# automatically (they are initialized directly in Mamba2Block.__init__).
```
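
For reference, the pattern this deletion relies on looks roughly like the sketch below; the project's actual `_init_weights` may differ in details such as the init std, so treat it as illustrative only:

```python
import torch.nn as nn

def _init_weights(module: nn.Module) -> None:
    # Only nn.Linear / nn.Embedding modules are touched; bare nn.Parameter
    # attributes (A_log, D, dt_bias) are never visited here, so the SSM
    # init done in Mamba2Block.__init__ survives model.apply(...).
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)

# model.apply(_init_weights) now also reaches the SwiGLU Linears inside
# Mamba2Block, which is exactly why the early return had to go.
```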

### 1d. Edit `configs/hybrid_3b.yaml`

```yaml
# Add after mamba_chunk_size: 256:
mamba_d_ffn: 4608
```

### Step 1 Verification

```bash
cd /PROJECT/0325120031_A/ghong/taketimes/llm-bang
CUDA_VISIBLE_DEVICES=0 python -c "
import torch, sys
sys.path.insert(0, '.')
from model import LLM, LMConfig

config = LMConfig.from_yaml('configs/hybrid_3b.yaml')
print(f'mamba_d_ffn = {config.mamba_d_ffn}')

model = LLM(config)
total = sum(p.numel() for p in model.parameters())
print(f'Total params: {total:,} ({total/1e9:.2f}B)')

# Forward test
x = torch.randint(0, 64000, (1, 128))
logits, loss = model(x, targets=x)
print(f'Forward OK: logits shape={logits.shape}, loss={loss.item():.4f}')

# Backward test
loss.backward()
grads_ok = all(p.grad is not None for p in model.parameters() if p.requires_grad)
print(f'Backward OK: all grads exist = {grads_ok}')
"
# Expected output: Total params ~4.5B, Forward/Backward OK
```

---

## Step 2: Fix n_groups

### `configs/hybrid_3b.yaml`

```yaml
# Before:
mamba_n_groups: 1

# After:
mamba_n_groups: 8
```

### Verification

n_heads (= d_inner / head_dim = 6144 / 64 = 96) % 8 == 0 ✓
Also covered by the Step 1 verification script (the assertion lives in `__init__`); the arithmetic can be checked standalone, as in the sketch below.
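
A minimal sketch of that check, using the values quoted above (d_model implied by Step 1, expand/head_dim from the Mamba2Block defaults, n_groups from this step):

```python
# Standalone check of the divisibility requirement.
d_model, expand, head_dim, n_groups = 3072, 2, 64, 8

d_inner = expand * d_model          # 6144
n_heads = d_inner // head_dim       # 96
assert n_heads % n_groups == 0, "each group must own a whole number of heads"

# With n_groups=8 the B/C projections have shape (B, L, 8, d_state):
# 8 independent state projections, each shared by 96 // 8 = 12 heads,
# instead of a single projection shared by all 96 heads.
print(n_heads, n_heads // n_groups)  # 96 12
```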

---

## Step 3: Add a Hybrid-Architecture Startup Log

### Edit `train/pretrain.py`

**Location**: add after lines 296-297 (where the model parameter count is printed)

```python
if is_main_process():
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")
    print(f"LMConfig: {lm_config}")

    # --- Addition starts here ---
    if lm_config.use_hybrid:
        pattern = lm_config.hybrid_pattern.split()
        m_count = sum(1 for p in pattern if p == 'M')
        a_count = sum(1 for p in pattern if p == 'A')
        mamba_params = sum(
            p.numel() for n, p in model.named_parameters()
            if 'layers.' in n and pattern[int(n.split('.')[1])] == 'M'
        )
        attn_params = sum(
            p.numel() for n, p in model.named_parameters()
            if 'layers.' in n and pattern[int(n.split('.')[1])] == 'A'
        )
        other_params = total_params - mamba_params - attn_params
        print(
            f"  arch     : Hybrid Mamba-Transformer\n"
            f"  layers   : {m_count} Mamba + {a_count} Attention = {len(pattern)} total\n"
            f"  params   : Mamba {mamba_params/1e6:.0f}M + "
            f"Attn {attn_params/1e6:.0f}M + Other {other_params/1e6:.0f}M\n"
            f"  mamba cfg: d_state={lm_config.mamba_d_state}, "
            f"head_dim={lm_config.mamba_head_dim}, "
            f"expand={lm_config.mamba_expand}, "
            f"n_groups={lm_config.mamba_n_groups}, "
            f"d_ffn={lm_config.mamba_d_ffn}"
        )
    # --- End of addition ---
```
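
Note that the attribution above relies on parameter names of the form `layers.<idx>.<rest>`. A quick standalone check of that parsing, with made-up names (a prefix such as `model.layers.0...` would break `int(n.split('.')[1])`):

```python
# Standalone check of the name-based layer attribution (names made up).
pattern = "M M M A M".split()   # toy 5-layer pattern

names = [
    "layers.0.in_proj.weight",      # -> Mamba layer
    "layers.3.attn.q_proj.weight",  # -> Attention layer
    "tok_embeddings.weight",        # -> neither, lands in other_params
]
for n in names:
    if "layers." in n:
        print(n, "->", pattern[int(n.split(".")[1])])
    else:
        print(n, "-> other")
```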

### Verification

When running the Step 1 verification, confirm that the hybrid info appears in the log.

---

## Step 4: Validate Architecture on Checkpoint Resume

### `train/utils.py`: edit `load_checkpoint()`

**Location**: add immediately before line 179 (`raw_model.load_state_dict(...)`)

```python
# --- Architecture validation ---
config_path = ckpt_dir / "config.yaml"
if config_path.exists() and hasattr(raw_model, "config"):
    with open(config_path, "r", encoding="utf-8") as f:
        saved_cfg = yaml.safe_load(f)
    current_cfg = raw_model.config.to_dict()
    critical_keys = [
        "d_model", "n_layers", "n_heads", "n_kv_heads", "vocab_size",
        "use_hybrid", "hybrid_pattern",
    ]
    mismatches = []
    for key in critical_keys:
        saved_val = saved_cfg.get(key)
        current_val = current_cfg.get(key)
        if saved_val is not None and saved_val != current_val:
            mismatches.append(
                f"  {key}: checkpoint={saved_val} vs current={current_val}"
            )
    if mismatches:
        raise ValueError(
            f"Checkpoint architecture mismatch!\n"
            f"Checkpoint dir: {ckpt_dir}\n"
            + "\n".join(mismatches)
            + "\nUse --config matching the checkpoint, or start fresh."
        )
# --- End architecture validation ---
```

**Note**: `yaml` is already imported at `train/utils.py` line 23.
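
The comparison logic can be exercised without a real checkpoint. A minimal sketch with two hand-built config dicts (values are hypothetical):

```python
# Exercise the mismatch detection with hand-built dicts (values hypothetical).
saved_cfg   = {"d_model": 3072, "n_layers": 40, "use_hybrid": True}
current_cfg = {"d_model": 2048, "n_layers": 40, "use_hybrid": True}

critical_keys = ["d_model", "n_layers", "use_hybrid"]
mismatches = [
    f"  {k}: checkpoint={saved_cfg.get(k)} vs current={current_cfg.get(k)}"
    for k in critical_keys
    if saved_cfg.get(k) is not None and saved_cfg.get(k) != current_cfg.get(k)
]
print("\n".join(mismatches) or "configs match")
# -> "  d_model: checkpoint=3072 vs current=2048"
```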

### Verification

```bash
# Deliberately try to resume with a different config
CUDA_VISIBLE_DEVICES=0 python train/pretrain.py \
  --config configs/small.yaml \
  --train_data data/3b_train.bin \
  --resume checkpoints/hybrid_3b_run1/checkpoint-0001000
# Expected: ValueError "Checkpoint architecture mismatch!" is raised
```

---

## Step 5: NaN/Inf Detection in selective_scan

### `model/mamba_block.py`: edit `selective_scan()`

**Location**: add after line 94 (`y[:, t, :, :] = y_t.to(x.dtype)`)

```python
# Periodic NaN/Inf check (every 512 steps, < 1% overhead)
if t % 512 == 511:
    if not torch.isfinite(h).all():
        raise RuntimeError(
            f"NaN/Inf in Mamba SSM state at timestep {t}/{seq_len}. "
            f"h stats: min={h.min().item():.4e}, max={h.max().item():.4e}, "
            f"A_log range=[{A_log.min().item():.4f}, {A_log.max().item():.4f}]"
        )
```

### Verification

```bash
CUDA_VISIBLE_DEVICES=0 python -c "
import torch, sys
sys.path.insert(0, '.')
from model.mamba_block import Mamba2Block

block = Mamba2Block(d_model=256, d_state=64, head_dim=32, d_ffn=384)
x = torch.randn(1, 1024, 256)

# Normal case
y = block(x)
print(f'Normal: output shape={y.shape}, finite={torch.isfinite(y).all()}')

# NaN injection test
block.A_log.data.fill_(100.0)  # very large value -> exp(100) overflows
try:
    y = block(x)
    print('WARNING: NaN not detected!')
except RuntimeError as e:
    print(f'NaN correctly detected: {e}')
"
```

---

## Step 6: Input Shape Validation in selective_scan

### `model/mamba_block.py`: edit `selective_scan()`

**Location**: add immediately before line 49 (`batch, seq_len, n_heads, head_dim = x.shape`)

```python
# Input shape validation
assert x.ndim == 4, f"x expected 4D (B,L,n_heads,head_dim), got {x.shape}"
assert dt.ndim == 3, f"dt expected 3D (B,L,n_heads), got {dt.shape}"
assert B.ndim == 4, f"B expected 4D (B,L,n_groups,d_state), got {B.shape}"
assert C.ndim == 4, f"C expected 4D (B,L,n_groups,d_state), got {C.shape}"
```
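
If stricter checks turn out to be useful later, agreement of the batch/sequence dims across the four inputs could be asserted in the same style. An optional extension, not part of Step 6 itself:

```python
# Optional extension (not part of Step 6): also check that the batch
# and sequence dims agree across all four inputs of selective_scan.
assert x.shape[:2] == dt.shape[:2] == B.shape[:2] == C.shape[:2], (
    f"batch/seq mismatch: x={x.shape}, dt={dt.shape}, B={B.shape}, C={C.shape}"
)
assert B.shape[2:] == C.shape[2:], (
    f"B/C group/state mismatch: B={B.shape}, C={C.shape}"
)
```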

---

## Final Verification Procedure (after all Steps are done)

### 1. Model creation + Forward/Backward (single GPU)

```bash
cd /PROJECT/0325120031_A/ghong/taketimes/llm-bang
CUDA_VISIBLE_DEVICES=0 python -c "
import torch, sys
sys.path.insert(0, '.')
from model import LLM, LMConfig

config = LMConfig.from_yaml('configs/hybrid_3b.yaml')
model = LLM(config).cuda()

total = sum(p.numel() for p in model.parameters())
print(f'Total params: {total:,} ({total/1e9:.2f}B)')
assert 4.0e9 < total < 5.0e9, f'Expected ~4.5B params, got {total/1e9:.2f}B'

# Forward
x = torch.randint(0, 64000, (2, 512)).cuda()
logits, loss = model(x, targets=x)
print(f'Forward: logits={logits.shape}, loss={loss.item():.4f}')

# Backward
loss.backward()
no_grad = [n for n, p in model.named_parameters() if p.requires_grad and p.grad is None]
assert len(no_grad) == 0, f'Missing gradients: {no_grad}'
print(f'Backward: all {sum(1 for p in model.parameters() if p.requires_grad)} params have grad')

# VRAM
print(f'VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB allocated')
"
```

### 2. DDP 8-GPU test (10 steps)

```bash
cd /PROJECT/0325120031_A/ghong/taketimes/llm-bang
torchrun --nproc_per_node=8 --master_port=29501 train/pretrain.py \
  --config configs/hybrid_3b.yaml \
  --train_data data/3b_train.bin \
  --batch_size 2 \
  --lr 1e-4 \
  --warmup_steps 5 \
  --grad_accum 1 \
  --max_steps 10 \
  --checkpoint_dir /tmp/hybrid_test_ckpt \
  --use_fp8
# Expected: 10 steps complete, checkpoint saved, hybrid info in the startup banner
```

### 3. Checkpoint resume test

```bash
# Resume from the checkpoint saved in test 2
torchrun --nproc_per_node=8 --master_port=29501 train/pretrain.py \
  --config configs/hybrid_3b.yaml \
  --train_data data/3b_train.bin \
  --batch_size 2 \
  --lr 1e-4 \
  --warmup_steps 5 \
  --grad_accum 1 \
  --max_steps 20 \
  --checkpoint_dir /tmp/hybrid_test_ckpt \
  --resume /tmp/hybrid_test_ckpt/checkpoint-0000010 \
  --use_fp8
# Expected: training continues from step 10 through step 20
```

---

## Things Intentionally Not Fixed

- **Sequential scan performance**: the Python for-loop is slow, but replacing it is a large structural change; implement chunked SSD as a separate task
- **FP8 + Mamba mixing**: the current design (Mamba=bf16, Attention=FP8) is correct; te.fp8_autocast affects only Transformer Engine modules (illustrated in the sketch after this list)
- **DDP settings**: find_unused_parameters=False and gradient_as_bucket_view=True are both fine as-is
- **Pure Transformer mode**: with use_hybrid=False the existing behavior is preserved (backward compatible)
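
The `te.fp8_autocast` point can be illustrated as follows. A minimal sketch, assuming Transformer Engine is installed and an FP8-capable (Hopper+) GPU is available; module and argument names follow TE's public API, and the project's actual wrapping may differ:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# te.Linear participates in FP8; a plain nn.Linear does not, even inside
# the autocast context, which is why the Mamba path (plain PyTorch
# modules) stays in bf16 while the attention path (te modules) runs FP8.
te_linear = te.Linear(256, 256, params_dtype=torch.bfloat16).cuda()
nn_linear = torch.nn.Linear(256, 256, dtype=torch.bfloat16).cuda()

x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    y_fp8 = te_linear(x)   # GEMM executes in FP8
    y_bf16 = nn_linear(x)  # untouched by fp8_autocast, stays bf16
```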

---

## Summary of Files to Modify

| File | Step | Change |
|------|------|--------|
| `model/config.py` | 1a | `mamba_d_ffn` field + `__post_init__` + `to_dict()` |
| `model/mamba_block.py` | 1b, 5, 6 | SwiGLU import, FFN sublayer, NaN detection, shape validation |
| `model/transformer.py` | 1c | pass d_ffn/bias to the Mamba2Block constructor, fix `_init_weights` |
| `configs/hybrid_3b.yaml` | 1d, 2 | `mamba_d_ffn: 4608`, `mamba_n_groups: 8` |
| `train/pretrain.py` | 3 | hybrid startup log |
| `train/utils.py` | 4 | architecture validation in `load_checkpoint()` |

---

## Execution Instructions (for the next session)

Reference this document and issue the following command:

> "Execute Steps 1-6 of this document (hashed-drifting-harp.md) in order.
> Run Steps 1+2 in parallel and Steps 3-6 independently.
> After each Step completes, run its verification, and
> once everything is done, run all 3 stages of the final verification procedure."