FRANKENSTALLM-H 3B Hybrid Model: Inspection Findings and Fix Execution Guide

Written: 2026-03-05. Purpose: before Phase 2 validation, fix the 6 issues found and bring the code into an immediately runnable state. In the next session, reference this document and execute directly.


Issue Summary (6 items)

| # | Severity | Issue | File | Impact |
|---|----------|-------|------|--------|
| 1 | CRITICAL | Mamba blocks have no FFN (channel mixer) | model/mamba_block.py | 37/40 layers lack capacity |
| 2 | HIGH | n_groups=1 (Nemotron standard is 8) | configs/hybrid_3b.yaml | Reduced B/C projection expressiveness |
| 3 | HIGH | No hybrid-architecture startup log | train/pretrain.py | Hard to debug and monitor |
| 4 | MEDIUM | No architecture validation on checkpoint resume | train/utils.py | Wrong weights can be loaded |
| 5 | MEDIUM | No NaN/Inf detection in selective_scan | model/mamba_block.py | Numerical instability cannot be diagnosed |
| 6 | LOW | No input shape validation in selective_scan | model/mamba_block.py | Ambiguous error messages |

Implementation Order and Dependencies

Step 1 (add FFN) ← do this first; architecture change
  ├── 1a. model/config.py: add the mamba_d_ffn field
  ├── 1b. model/mamba_block.py: add the FFN sublayer
  ├── 1c. model/transformer.py: pass the new constructor args + fix _init_weights
  └── 1d. configs/hybrid_3b.yaml: add mamba_d_ffn=4608

Step 2 (n_groups) ← independent of Step 1, can run in parallel
  └── configs/hybrid_3b.yaml: n_groups=8

Step 3 (logging) ← after Step 1 is done (parameter counts must be accurate)
  └── train/pretrain.py: add hybrid info to the startup banner

Step 4 (checkpoint validation) ← independent
  └── train/utils.py: config comparison logic in load_checkpoint

Step 5-6 (NaN detection + shape validation) ← independent
  └── model/mamba_block.py: selective_scan function

Parallelizable: Step 1 and Step 2 only overlap in the YAML file (merge at the end). Step 4 and Steps 5-6 are also independent and can run in parallel.


Step 1: Add an FFN to Mamba2Block (CRITICAL)

Background

  • Mamba2Block has only the SSM (sequence mixer); there is no FFN (channel mixer)
  • In Nemotron-H, every Mamba layer is followed by an MLP
  • Right now 37 of the 40 layers have no FFN, so those layers cannot do channel-wise feature mixing
  • Decided: mamba_d_ffn = 4608 (d_model × 1.5), total parameters ~4.5B, VRAM ~80GB/GPU

1a. Modify model/config.py

Location: inside the LMConfig dataclass (after line 61)

Field to add (after the existing mamba_chunk_size):

    mamba_d_ffn: Optional[int] = None  # FFN dim for Mamba blocks (None → d_ffn)

__post_init__ addition (at line 86, after the hybrid validation block):

        # Mamba FFN dimension: default to d_ffn if not specified
        if self.mamba_d_ffn is None:
            self.mamba_d_ffn = self.d_ffn

to_dict() addition (after the existing mamba_chunk_size entry):

            "mamba_d_ffn": self.mamba_d_ffn,

1b. Modify model/mamba_block.py

Import change (line 19):

# before:
from .layers import RMSNorm

# after:
from .layers import RMSNorm, SwiGLU

Mamba2Block.__init__ signature change (lines 128-137):

# before:
    def __init__(
        self,
        d_model: int,
        d_state: int = 128,
        head_dim: int = 64,
        expand: int = 2,
        conv_kernel: int = 4,
        n_groups: int = 1,
        chunk_size: int = 256,
    ) -> None:

# after:
    def __init__(
        self,
        d_model: int,
        d_state: int = 128,
        head_dim: int = 64,
        expand: int = 2,
        conv_kernel: int = 4,
        n_groups: int = 1,
        chunk_size: int = 256,
        d_ffn: int = 0,
        bias: bool = False,
    ) -> None:

FFN sublayer addition (line 192, after self.out_proj):

        # --- FFN sublayer (channel mixer) ---
        if d_ffn > 0:
            self.ffn_norm = RMSNorm(d_model)
            self.ffn = SwiGLU(d_model, d_ffn, bias=bias)
        else:
            self.ffn_norm = None
            self.ffn = None

forward() change (line 280):

# before:
        return residual + self.out_proj(y)

# after:
        x = residual + self.out_proj(y)
        # FFN sublayer (channel mixer)
        if self.ffn is not None:
            x = x + self.ffn(self.ffn_norm(x))
        return x

1c. Modify model/transformer.py

Mamba2Block constructor call change (lines 124-132):

# before:
                    layers.append(Mamba2Block(
                        d_model=config.d_model,
                        d_state=config.mamba_d_state,
                        head_dim=config.mamba_head_dim,
                        expand=config.mamba_expand,
                        conv_kernel=config.mamba_conv_kernel,
                        n_groups=config.mamba_n_groups,
                        chunk_size=config.mamba_chunk_size,
                    ))

# after:
                    layers.append(Mamba2Block(
                        d_model=config.d_model,
                        d_state=config.mamba_d_state,
                        head_dim=config.mamba_head_dim,
                        expand=config.mamba_expand,
                        conv_kernel=config.mamba_conv_kernel,
                        n_groups=config.mamba_n_groups,
                        chunk_size=config.mamba_chunk_size,
                        d_ffn=config.mamba_d_ffn,
                        bias=config.bias,
                    ))

_init_weights change (lines 180-182):

# before:
        # Mamba2Block handles its own parameter init (A_log, D, dt_bias, etc.)
        if isinstance(module, Mamba2Block):
            return

# after (delete these 3 lines):
# Reason for the deletion: once the FFN is added, the nn.Linear layers inside SwiGLU need init.
# A_log, D, and dt_bias are nn.Parameter, so they do not match the isinstance(nn.Linear) check and
# are skipped automatically (they are initialized directly in Mamba2Block.__init__).
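
For orientation, a minimal sketch of what _init_weights is expected to do once the early return is gone (the actual init scheme in model/transformer.py may differ; this only illustrates why the deletion is safe):

import torch.nn as nn

# Sketch only: the real _init_weights may use different std values or special-cased scaling.
def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        # Now also reaches the SwiGLU projections inside Mamba2Block.
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    # A_log, D, dt_bias are bare nn.Parameter tensors; they never match the
    # nn.Linear check above and keep the values assigned in Mamba2Block.__init__.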

1d. Modify configs/hybrid_3b.yaml

# add after mamba_chunk_size: 256:
  mamba_d_ffn: 4608

Step 1 Verification

cd /PROJECT/0325120031_A/ghong/taketimes/llm-bang
CUDA_VISIBLE_DEVICES=0 python -c "
import torch, sys
sys.path.insert(0, '.')
from model import LLM, LMConfig

config = LMConfig.from_yaml('configs/hybrid_3b.yaml')
print(f'mamba_d_ffn = {config.mamba_d_ffn}')

model = LLM(config)
total = sum(p.numel() for p in model.parameters())
print(f'Total params: {total:,} ({total/1e9:.2f}B)')

# Forward test
x = torch.randint(0, 64000, (1, 128))
logits, loss = model(x, targets=x)
print(f'Forward OK: logits shape={logits.shape}, loss={loss.item():.4f}')

# Backward test
loss.backward()
grads_ok = all(p.grad is not None for p in model.parameters() if p.requires_grad)
print(f'Backward OK: all grads exist = {grads_ok}')
"
# Expected output: Total params ~4.5B, Forward/Backward OK

Step 2: Fix n_groups

configs/hybrid_3b.yaml

# before:
  mamba_n_groups: 1

# after:
  mamba_n_groups: 8

Verification

n_heads (= d_inner / head_dim = 6144 / 64 = 96) % 8 == 0 ✓ This is covered by the Step 1 verification script (the assertion lives in __init__).
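
Spelled out as a standalone check (values taken from the config above; the exact assertion inside Mamba2Block.__init__ may be worded differently):

# n_groups divisibility check, restated with the hybrid_3b values.
d_model, expand, head_dim, n_groups = 3072, 2, 64, 8
d_inner = expand * d_model       # 6144
n_heads = d_inner // head_dim    # 96
assert n_heads % n_groups == 0, f"n_heads={n_heads} not divisible by n_groups={n_groups}"
print(f"n_heads={n_heads}, heads per group={n_heads // n_groups}")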


Step 3: Add a hybrid architecture startup log

Modify train/pretrain.py

Location: add after lines 296-297 (the model parameter printout)

    if is_main_process():
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model parameters: {total_params:,}")
        print(f"LMConfig: {lm_config}")

        # --- additions start here ---
        if lm_config.use_hybrid:
            pattern = lm_config.hybrid_pattern.split()
            m_count = sum(1 for p in pattern if p == 'M')
            a_count = sum(1 for p in pattern if p == 'A')
            mamba_params = sum(
                p.numel() for n, p in model.named_parameters()
                if 'layers.' in n and pattern[int(n.split('.')[1])] == 'M'
            )
            attn_params = sum(
                p.numel() for n, p in model.named_parameters()
                if 'layers.' in n and pattern[int(n.split('.')[1])] == 'A'
            )
            other_params = total_params - mamba_params - attn_params
            print(
                f"  arch     : Hybrid Mamba-Transformer\n"
                f"  layers   : {m_count} Mamba + {a_count} Attention = {len(pattern)} total\n"
                f"  params   : Mamba {mamba_params/1e6:.0f}M + "
                f"Attn {attn_params/1e6:.0f}M + Other {other_params/1e6:.0f}M\n"
                f"  mamba cfg: d_state={lm_config.mamba_d_state}, "
                f"head_dim={lm_config.mamba_head_dim}, "
                f"expand={lm_config.mamba_expand}, "
                f"n_groups={lm_config.mamba_n_groups}, "
                f"d_ffn={lm_config.mamba_d_ffn}"
            )
        # --- end of additions ---

Verification

When running the Step 1 verification, confirm that the hybrid info is printed in the log.


Step 4: Architecture validation on checkpoint resume

Modify load_checkpoint() in train/utils.py

Location: add immediately before line 179 (raw_model.load_state_dict(...))

    # --- Architecture validation ---
    config_path = ckpt_dir / "config.yaml"
    if config_path.exists() and hasattr(raw_model, "config"):
        with open(config_path, "r", encoding="utf-8") as f:
            saved_cfg = yaml.safe_load(f)
        current_cfg = raw_model.config.to_dict()
        critical_keys = [
            "d_model", "n_layers", "n_heads", "n_kv_heads", "vocab_size",
            "use_hybrid", "hybrid_pattern",
        ]
        mismatches = []
        for key in critical_keys:
            saved_val = saved_cfg.get(key)
            current_val = current_cfg.get(key)
            if saved_val is not None and saved_val != current_val:
                mismatches.append(
                    f"  {key}: checkpoint={saved_val} vs current={current_val}"
                )
        if mismatches:
            raise ValueError(
                f"Checkpoint architecture mismatch!\n"
                f"Checkpoint dir: {ckpt_dir}\n"
                + "\n".join(mismatches)
                + "\nUse --config matching the checkpoint, or start fresh."
            )
    # --- End architecture validation ---

Note: yaml is already imported in train/utils.py at line 23.

Verification

# ์˜๋„์ ์œผ๋กœ ๋‹ค๋ฅธ config๋กœ resume ์‹œ๋„
CUDA_VISIBLE_DEVICES=0 python train/pretrain.py \
    --config configs/small.yaml \
    --train_data data/3b_train.bin \
    --resume checkpoints/hybrid_3b_run1/checkpoint-0001000
# Expected: ValueError "Checkpoint architecture mismatch!" is raised

Step 5: NaN/Inf detection in selective_scan

Modify selective_scan() in model/mamba_block.py

Location: add after line 94 (y[:, t, :, :] = y_t.to(x.dtype))

        # Periodic NaN/Inf check (every 512 steps, < 1% overhead)
        if t % 512 == 511:
            if not torch.isfinite(h).all():
                raise RuntimeError(
                    f"NaN/Inf in Mamba SSM state at timestep {t}/{seq_len}. "
                    f"h stats: min={h.min().item():.4e}, max={h.max().item():.4e}, "
                    f"A_log range=[{A_log.min().item():.4f}, {A_log.max().item():.4f}]"
                )

Verification

CUDA_VISIBLE_DEVICES=0 python -c "
import torch, sys
sys.path.insert(0, '.')
from model.mamba_block import Mamba2Block

block = Mamba2Block(d_model=256, d_state=64, head_dim=32, d_ffn=384)
x = torch.randn(1, 1024, 256)

# normal case
y = block(x)
print(f'Normal: output shape={y.shape}, finite={torch.isfinite(y).all()}')

# NaN injection test
block.A_log.data.fill_(100.0)  # very large value → exp(100) overflows
try:
    y = block(x)
    print('WARNING: NaN not detected!')
except RuntimeError as e:
    print(f'NaN correctly detected: {e}')
"

Step 6: Input shape validation in selective_scan

Modify selective_scan() in model/mamba_block.py

Location: add immediately before line 49 (batch, seq_len, n_heads, head_dim = x.shape)

    # Input shape validation
    assert x.ndim == 4, f"x expected 4D (B,L,n_heads,head_dim), got {x.shape}"
    assert dt.ndim == 3, f"dt expected 3D (B,L,n_heads), got {dt.shape}"
    assert B.ndim == 4, f"B expected 4D (B,L,n_groups,d_state), got {B.shape}"
    assert C.ndim == 4, f"C expected 4D (B,L,n_groups,d_state), got {C.shape}"

Final Verification Procedure (after all Steps are complete)

1. Model construction + Forward/Backward (single GPU)

cd /PROJECT/0325120031_A/ghong/taketimes/llm-bang
CUDA_VISIBLE_DEVICES=0 python -c "
import torch, sys
sys.path.insert(0, '.')
from model import LLM, LMConfig

config = LMConfig.from_yaml('configs/hybrid_3b.yaml')
model = LLM(config).cuda()

total = sum(p.numel() for p in model.parameters())
print(f'Total params: {total:,} ({total/1e9:.2f}B)')
assert 4.0e9 < total < 5.0e9, f'Expected ~4.5B params, got {total/1e9:.2f}B'

# Forward
x = torch.randint(0, 64000, (2, 512)).cuda()
logits, loss = model(x, targets=x)
print(f'Forward: logits={logits.shape}, loss={loss.item():.4f}')

# Backward
loss.backward()
no_grad = [n for n, p in model.named_parameters() if p.requires_grad and p.grad is None]
assert len(no_grad) == 0, f'Missing gradients: {no_grad}'
print(f'Backward: all {sum(1 for p in model.parameters() if p.requires_grad)} params have grad')

# VRAM
print(f'VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB allocated')
"

2. DDP 8-GPU test (10 steps)

cd /PROJECT/0325120031_A/ghong/taketimes/llm-bang
torchrun --nproc_per_node=8 --master_port=29501 train/pretrain.py \
    --config configs/hybrid_3b.yaml \
    --train_data data/3b_train.bin \
    --batch_size 2 \
    --lr 1e-4 \
    --warmup_steps 5 \
    --grad_accum 1 \
    --max_steps 10 \
    --checkpoint_dir /tmp/hybrid_test_ckpt \
    --use_fp8
# Expected: completes 10 steps, saves a checkpoint, and prints the hybrid info in the startup banner

3. Checkpoint resume test

# resume from the checkpoint saved by test 2 above
torchrun --nproc_per_node=8 --master_port=29501 train/pretrain.py \
    --config configs/hybrid_3b.yaml \
    --train_data data/3b_train.bin \
    --batch_size 2 \
    --lr 1e-4 \
    --warmup_steps 5 \
    --grad_accum 1 \
    --max_steps 20 \
    --checkpoint_dir /tmp/hybrid_test_ckpt \
    --resume /tmp/hybrid_test_ckpt/checkpoint-0000010 \
    --use_fp8
# Expected: training resumes at step 10 and continues through step 20

Not Modified (intentionally out of scope)

  • sequential scan ์„ฑ๋Šฅ: Python for-loop๋Š” ๋А๋ฆฌ์ง€๋งŒ ๊ตฌ์กฐ ๋ณ€๊ฒฝ์ด ํผ. ๋ณ„๋„ ํƒœ์Šคํฌ๋กœ chunked SSD ๊ตฌํ˜„
  • FP8 + Mamba ํ˜ผํ•ฉ: ํ˜„์žฌ ์„ค๊ณ„(Mamba=bf16, Attention=FP8)๊ฐ€ ์˜ฌ๋ฐ”๋ฆ„. te.fp8_autocast๋Š” te ๋ชจ๋“ˆ๋งŒ ์˜ํ–ฅ
  • DDP ์„ค์ •: find_unused_parameters=False, gradient_as_bucket_view=True ๋ชจ๋‘ ์ •์ƒ
  • pure Transformer ๋ชจ๋“œ: use_hybrid=False๋ฉด ๊ธฐ์กด ๋™์ž‘ ์œ ์ง€ (ํ•˜์œ„ ํ˜ธํ™˜)

Summary of Files to Modify

ํŒŒ์ผ Step ๋ณ€๊ฒฝ ๋‚ด์šฉ
model/config.py 1a mamba_d_ffn ํ•„๋“œ + __post_init__ + to_dict()
model/mamba_block.py 1b, 5, 6 SwiGLU import, FFN sublayer, NaN ๊ฐ์ง€, shape ๊ฒ€์ฆ
model/transformer.py 1c Mamba2Block ์ƒ์„ฑ์ž์— d_ffn/bias ์ „๋‹ฌ, _init_weights ์ˆ˜์ •
configs/hybrid_3b.yaml 1d, 2 mamba_d_ffn: 4608, mamba_n_groups: 8
train/pretrain.py 3 Hybrid startup ๋กœ๊ทธ
train/utils.py 4 load_checkpoint() ์•„ํ‚คํ…์ฒ˜ ๊ฒ€์ฆ

Execution Instructions (for the next session)

Reference this document and issue the following request:

"์ด ๋ฌธ์„œ(hashed-drifting-harp.md)์˜ Step 16์„ ์ˆœ์„œ๋Œ€๋กœ ์‹คํ–‰ํ•ด ์ค˜. Step 1+2๋Š” ๋ณ‘๋ ฌ๋กœ, Step 36์€ ๋…๋ฆฝ์ ์œผ๋กœ ์ง„ํ–‰. ๊ฐ Step ์™„๋ฃŒ ํ›„ ํ•ด๋‹น ๊ฒ€์ฆ์„ ์‹คํ–‰ํ•˜๊ณ , ์ „์ฒด ์™„๋ฃŒ ํ›„ ์ตœ์ข… ๊ฒ€์ฆ ์ ˆ์ฐจ 3๋‹จ๊ณ„๋ฅผ ๋ชจ๋‘ ์‹คํ–‰ํ•ด ์ค˜."