LLM-1B-Lab / _archive /llm-1b-model.py
Vjeong's picture
Initial commit: LLM-1B-Lab project setup
8a58ffe
"""
LLM-1B-Lab: 1B Parameter LLaMA-style Transformer (from scratch)
================================================================
딥러닝 초보자를 위한 학습용 구현.
각 컴포넌트에 상세 주석을 달아 "왜 이렇게 하는지"를 설명합니다.
아키텍처 요약:
- Decoder-Only Transformer (Causal LM)
- RMSNorm (Pre-Normalization)
- Rotary Positional Embedding (RoPE)
- Grouped Query Attention (GQA)
- SwiGLU Feed-Forward Network
- Weight Tying (Embedding ↔ Output Head)
"""
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================================
# 1. 모델 설정 (Config)
# ============================================================================
@dataclass
class ModelConfig:
"""모델 하이퍼파라미터를 하나의 데이터클래스로 관리합니다.
규모별 프리셋:
- debug: ~10M (파이프라인 검증용)
- small: ~100M (중간 검증용)
- base: ~1.1B (최종 목표)
"""
vocab_size: int = 32_000
hidden_dim: int = 2048 # d_model: 모델의 기본 차원
num_layers: int = 22 # Transformer 블록 수
num_heads: int = 16 # Query 헤드 수
num_kv_heads: int = 4 # Key/Value 헤드 수 (GQA)
intermediate_dim: int = 5632 # FFN 중간 차원 (≈ 2.75 × hidden_dim)
max_seq_len: int = 2048 # 최대 시퀀스 길이
dropout: float = 0.0 # Pretraining에서는 보통 0 사용
rope_theta: float = 10000.0 # RoPE 주파수 베이스
norm_eps: float = 1e-6 # RMSNorm epsilon
@property
def head_dim(self) -> int:
"""각 어텐션 헤드의 차원."""
return self.hidden_dim // self.num_heads
@property
def num_kv_groups(self) -> int:
"""GQA에서 하나의 KV 헤드가 담당하는 Q 헤드 수."""
return self.num_heads // self.num_kv_heads
@classmethod
def debug_10m(cls) -> "ModelConfig":
"""~10M 파라미터 - 빠른 디버깅용."""
return cls(
hidden_dim=256, num_layers=6, num_heads=8,
num_kv_heads=4, intermediate_dim=704, max_seq_len=512,
)
@classmethod
def small_100m(cls) -> "ModelConfig":
"""~100M 파라미터 - 중간 검증용."""
return cls(
hidden_dim=768, num_layers=12, num_heads=12,
num_kv_heads=4, intermediate_dim=2048, max_seq_len=1024,
)
@classmethod
def base_1b(cls) -> "ModelConfig":
"""~1.1B 파라미터 - 최종 학습 목표."""
return cls() # 기본값이 1B 설정
# ============================================================================
# 2. RMSNorm (Root Mean Square Layer Normalization)
# ============================================================================
class RMSNorm(nn.Module):
"""RMSNorm: LayerNorm의 경량화 버전.
일반 LayerNorm과의 차이:
- 평균(mean)을 빼지 않음 → 연산 절약
- 분산 대신 RMS(Root Mean Square)로 정규화
- bias 파라미터 없음
수식:
RMSNorm(x) = (x / RMS(x)) * γ
RMS(x) = sqrt(mean(x²) + ε)
왜 정규화가 필요한가?
→ 레이어를 깊게 쌓으면 활성화 값의 스케일이 폭발하거나 소멸합니다.
→ 정규화로 각 레이어의 입력을 안정적인 범위로 유지합니다.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
# γ (gamma): 학습 가능한 스케일 파라미터, 1로 초기화
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 1) 입력을 float32로 변환 (수치 안정성)
# bf16/fp16 상태에서 제곱합을 구하면 오버플로우 위험
x_float = x.float()
# 2) RMS 계산: sqrt(mean(x²) + ε)
rms = torch.rsqrt(x_float.pow(2).mean(dim=-1, keepdim=True) + self.eps)
# rsqrt = 1/sqrt(x) → 나눗셈 대신 곱셈으로 대체 (더 빠름)
# 3) 정규화 후 원래 dtype으로 복원, 스케일 적용
return (x_float * rms).to(x.dtype) * self.weight
# ============================================================================
# 3. Rotary Positional Embedding (RoPE)
# ============================================================================
class RotaryPositionalEmbedding(nn.Module):
"""RoPE: 회전 행렬을 이용한 상대 위치 인코딩.
핵심 아이디어:
- 각 차원 쌍(2i, 2i+1)을 2D 평면의 좌표로 보고,
위치(position)에 비례한 각도만큼 회전시킵니다.
- 두 토큰의 어텐션 스코어(Q·K)는 상대 거리에만 의존하게 됩니다.
왜 RoPE인가?
- 절대 위치 임베딩: 각 위치에 고정 벡터를 더함 → 길이 일반화 어려움
- 상대 위치 임베딩: 구현 복잡, 추가 파라미터 필요
- RoPE: 파라미터 없이, 자연스럽게 상대 위치 정보 인코딩
수식:
θ_i = theta^(-2i/d) (i = 0, 1, ..., d/2-1)
RoPE(x, pos) = x를 각 차원 쌍에서 pos × θ_i 만큼 회전
"""
def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.theta = theta
# 주파수 벡터 미리 계산 (학습 불필요 → buffer로 등록)
# freqs[i] = 1 / (theta^(2i/dim)), i = 0, 1, ..., dim/2-1
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("freqs", freqs, persistent=False)
# (max_seq_len, dim/2) 크기의 cos/sin 테이블 미리 계산
self._build_cache(max_seq_len)
def _build_cache(self, seq_len: int):
"""cos/sin 값을 미리 계산하여 캐싱합니다."""
t = torch.arange(seq_len, device=self.freqs.device, dtype=torch.float32)
# outer product: (seq_len,) × (dim/2,) → (seq_len, dim/2)
angles = torch.outer(t, self.freqs)
self.register_buffer("cos_cached", angles.cos(), persistent=False)
self.register_buffer("sin_cached", angles.sin(), persistent=False)
def forward(
self, q: torch.Tensor, k: torch.Tensor, position_offset: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Q, K에 회전 변환을 적용합니다.
Args:
q: (batch, num_heads, seq_len, head_dim)
k: (batch, num_kv_heads, seq_len, head_dim)
position_offset: 시퀀스 시작 위치 오프셋 (추론 시 KV 캐시 사용 시)
Returns:
회전 변환이 적용된 (q_rotated, k_rotated)
"""
seq_len = q.shape[2]
# 필요 시 캐시 확장
if position_offset + seq_len > self.cos_cached.shape[0]:
self._build_cache(position_offset + seq_len)
# 현재 위치에 해당하는 cos/sin 슬라이스
cos = self.cos_cached[position_offset : position_offset + seq_len] # (seq_len, dim/2)
sin = self.sin_cached[position_offset : position_offset + seq_len]
q_rotated = self._apply_rotation(q, cos, sin)
k_rotated = self._apply_rotation(k, cos, sin)
return q_rotated, k_rotated
@staticmethod
def _apply_rotation(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
"""회전 변환 적용.
2D 회전 행렬:
[cos θ, -sin θ] [x1] [x1·cos θ - x2·sin θ]
[sin θ, cos θ] [x2] = [x1·sin θ + x2·cos θ]
이를 벡터 연산으로 효율적으로 구현합니다.
"""
# x: (batch, heads, seq_len, head_dim)
# 짝수/홀수 인덱스를 분리: (x0, x1, x2, x3, ...) → (x0, x2, ...), (x1, x3, ...)
x_even = x[..., 0::2] # 짝수 인덱스
x_odd = x[..., 1::2] # 홀수 인덱스
# 브로드캐스팅을 위해 차원 맞춤: (seq_len, dim/2) → (1, 1, seq_len, dim/2)
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
# 회전 적용
rotated_even = x_even * cos - x_odd * sin
rotated_odd = x_even * sin + x_odd * cos
# 다시 인터리빙: (even0, odd0, even1, odd1, ...)
out = torch.stack([rotated_even, rotated_odd], dim=-1)
return out.flatten(-2) # 마지막 두 차원을 합쳐 원래 shape 복원
# ============================================================================
# 4. Grouped Query Attention (GQA)
# ============================================================================
class GroupedQueryAttention(nn.Module):
"""GQA: Multi-Head Attention의 메모리 효율적 변형.
MHA vs GQA vs MQA:
- MHA (Multi-Head Attention): Q, K, V 모두 num_heads개 → 메모리 큼
- MQA (Multi-Query Attention): K, V는 1개 헤드 공유 → 품질 저하 우려
- GQA (Grouped Query Attention): K, V를 num_kv_heads개로 그룹화
→ MHA와 MQA의 중간, 좋은 품질-효율 균형
예시 (num_heads=16, num_kv_heads=4):
Q 헤드: [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15]
K/V 그룹: [ 0 , 1 , 2 , 3 ]
→ Q 헤드 4개가 K/V 헤드 1개를 공유
Attention 수식:
Attention(Q, K, V) = softmax(Q·K^T / √d_k) · V
"""
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
self.head_dim = config.head_dim
self.num_heads = config.num_heads
self.num_kv_heads = config.num_kv_heads
self.num_kv_groups = config.num_kv_groups # num_heads // num_kv_heads
# Q/K/V 프로젝션
# Q: hidden_dim → num_heads × head_dim
self.q_proj = nn.Linear(config.hidden_dim, config.num_heads * self.head_dim, bias=False)
# K, V: hidden_dim → num_kv_heads × head_dim (Q보다 작음!)
self.k_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
# 출력 프로젝션: 모든 헤드의 출력을 다시 hidden_dim으로
self.o_proj = nn.Linear(config.num_heads * self.head_dim, config.hidden_dim, bias=False)
# RoPE
self.rope = RotaryPositionalEmbedding(
dim=self.head_dim, max_seq_len=config.max_seq_len, theta=config.rope_theta
)
# Attention dropout (pretraining에서는 보통 0)
self.attn_dropout = nn.Dropout(config.dropout)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
position_offset: int = 0,
) -> torch.Tensor:
"""
Args:
x: (batch_size, seq_len, hidden_dim)
mask: (seq_len, seq_len) causal mask
position_offset: 위치 오프셋 (추론 시 사용)
Returns:
(batch_size, seq_len, hidden_dim)
"""
B, S, _ = x.shape
# ──────────────────────────────────────────────
# Step 1: Q, K, V 프로젝션
# ──────────────────────────────────────────────
q = self.q_proj(x) # (B, S, num_heads × head_dim)
k = self.k_proj(x) # (B, S, num_kv_heads × head_dim)
v = self.v_proj(x) # (B, S, num_kv_heads × head_dim)
# 멀티헤드 형태로 reshape
q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
# → (B, num_heads, S, head_dim)
k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
# → (B, num_kv_heads, S, head_dim)
v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
# ──────────────────────────────────────────────
# Step 2: RoPE 적용 (Q, K에만! V에는 적용하지 않음)
# ──────────────────────────────────────────────
# 위치 정보는 "어디를 볼지"(Q·K)에만 영향을 줘야 하고,
# "무엇을 가져올지"(V)에는 영향을 주면 안 됩니다.
q, k = self.rope(q, k, position_offset)
# ──────────────────────────────────────────────
# Step 3: GQA - KV 헤드 확장 (repeat)
# ──────────────────────────────────────────────
# num_kv_heads=4 → num_heads=16: 각 KV를 4번 반복
if self.num_kv_groups > 1:
k = self._repeat_kv(k) # (B, num_heads, S, head_dim)
v = self._repeat_kv(v)
# ──────────────────────────────────────────────
# Step 4: Scaled Dot-Product Attention
# ──────────────────────────────────────────────
# PyTorch >= 2.0의 최적화된 구현 사용 (Flash Attention 자동 적용)
attn_out = F.scaled_dot_product_attention(
q, k, v,
attn_mask=mask,
dropout_p=self.config.dropout if self.training else 0.0,
is_causal=(mask is None), # mask가 없으면 자동 causal masking
)
# → (B, num_heads, S, head_dim)
# ──────────────────────────────────────────────
# Step 5: 헤드 합치기 + 출력 프로젝션
# ──────────────────────────────────────────────
attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
# → (B, S, num_heads × head_dim)
return self.o_proj(attn_out) # → (B, S, hidden_dim)
def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
"""KV 헤드를 Q 헤드 수에 맞게 반복합니다.
(B, num_kv_heads, S, head_dim) → (B, num_heads, S, head_dim)
예: num_kv_heads=4, num_kv_groups=4
[kv0, kv1, kv2, kv3] → [kv0,kv0,kv0,kv0, kv1,kv1,kv1,kv1, ...]
"""
B, H_kv, S, D = x.shape
x = x[:, :, None, :, :] # (B, H_kv, 1, S, D)
x = x.expand(B, H_kv, self.num_kv_groups, S, D) # (B, H_kv, groups, S, D)
return x.reshape(B, self.num_heads, S, D)
# ============================================================================
# 5. SwiGLU Feed-Forward Network
# ============================================================================
class SwiGLUFeedForward(nn.Module):
"""SwiGLU: Gated Linear Unit with Swish 활성화 함수.
기존 FFN:
FFN(x) = ReLU(x·W1 + b1)·W2 + b2
→ 단순한 비선형 변환
SwiGLU FFN:
SwiGLU(x) = (Swish(x·W_gate) ⊙ (x·W_up)) · W_down
→ 게이팅 메커니즘으로 정보 흐름을 제어
왜 SwiGLU가 더 좋은가?
- Swish(x) = x · sigmoid(x): 부드러운 활성화, 음수 영역 일부 허용
- Gate 벡터가 "어떤 정보를 통과시킬지" 학습
- PaLM, LLaMA 등에서 ReLU FFN 대비 일관된 성능 향상 보고
참고: W_gate와 W_up 두 개의 up-projection이 있어서
파라미터 수가 기존 FFN 대비 1.5배이지만, intermediate_dim을
조정하여 총 파라미터 수를 맞춥니다.
"""
def __init__(self, config: ModelConfig):
super().__init__()
# 게이트 프로젝션: hidden_dim → intermediate_dim
self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
# 업 프로젝션: hidden_dim → intermediate_dim
self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
# 다운 프로젝션: intermediate_dim → hidden_dim
self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# SwiGLU(x) = (Swish(gate(x)) ⊙ up(x)) · down
#
# 1) gate: 어떤 정보를 통과시킬지 결정 (Swish 활성화)
gate = F.silu(self.gate_proj(x)) # silu = Swish = x * sigmoid(x)
# 2) up: 정보를 고차원으로 사영
up = self.up_proj(x)
# 3) element-wise 곱 (게이팅) → 다시 원래 차원으로
return self.down_proj(gate * up)
# ============================================================================
# 6. Transformer Block (하나의 레이어)
# ============================================================================
class TransformerBlock(nn.Module):
"""하나의 Transformer 디코더 블록.
구조 (Pre-Norm 방식):
x → RMSNorm → Attention → + (residual) → RMSNorm → FFN → + (residual) → out
Pre-Norm vs Post-Norm:
- Post-Norm (원래 Transformer): LayerNorm이 residual 이후
→ 깊은 모델에서 학습 불안정
- Pre-Norm (GPT-2 이후 표준): LayerNorm이 sublayer 이전
→ gradient 흐름이 원활, 학습이 안정적
Residual Connection의 역할:
- 입력을 출력에 더함 → gradient가 레이어를 건너뛸 수 있는 "고속도로"
- 22개 레이어를 쌓아도 학습이 가능한 핵심 이유
"""
def __init__(self, config: ModelConfig, layer_idx: int):
super().__init__()
self.layer_idx = layer_idx
# Pre-Norm: Attention 전 정규화
self.attn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
# Self-Attention
self.attention = GroupedQueryAttention(config)
# Pre-Norm: FFN 전 정규화
self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
# Feed-Forward Network
self.feed_forward = SwiGLUFeedForward(config)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
position_offset: int = 0,
) -> torch.Tensor:
"""
Args:
x: (batch_size, seq_len, hidden_dim)
Returns:
(batch_size, seq_len, hidden_dim)
"""
# ── Attention sublayer with residual ──
# h = x + Attention(RMSNorm(x))
h = x + self.attention(self.attn_norm(x), mask, position_offset)
# ── FFN sublayer with residual ──
# out = h + FFN(RMSNorm(h))
out = h + self.feed_forward(self.ffn_norm(h))
return out
# ============================================================================
# 7. Full Transformer Model (LLaMA-style)
# ============================================================================
class LLMModel(nn.Module):
"""1B 파라미터 LLaMA-style Decoder-Only Transformer.
전체 구조:
Input Token IDs
→ Token Embedding
→ [TransformerBlock] × num_layers (+ Activation Checkpointing)
→ RMSNorm (최종)
→ Linear Head (→ vocab logits)
Weight Tying:
- 입력 Embedding과 출력 Linear Head의 가중치를 공유
- 파라미터 수 절약 (~65M) + 성능 유지/향상
- 직관: "단어의 의미 표현"과 "단어 예측"이 같은 공간을 사용
"""
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
# ── Token Embedding ──
self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
# ── Transformer Blocks ──
self.layers = nn.ModuleList([
TransformerBlock(config, layer_idx=i)
for i in range(config.num_layers)
])
# ── 최종 정규화 ──
self.final_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
# ── 출력 헤드 (Weight Tying) ──
self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
# Weight Tying: lm_head의 가중치 = token_embedding의 가중치
self.lm_head.weight = self.token_embedding.weight
# 가중치 초기화
self._init_weights()
def _init_weights(self):
"""가중치 초기화 전략.
왜 초기화가 중요한가?
- 너무 크면: 활성화 폭발 → NaN
- 너무 작으면: gradient 소멸 → 학습 정체
- 적절한 초기화: 각 레이어의 출력 분산을 일정하게 유지
GPT-2 스타일 초기화:
- 일반 Linear: N(0, 0.02)
- Residual projection: N(0, 0.02 / √(2 × num_layers))
→ 레이어가 깊어질수록 residual 기여를 줄여 안정화
"""
std = 0.02
residual_std = std / math.sqrt(2 * self.config.num_layers)
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=std)
# Residual projection 레이어에 축소된 초기화 적용
for layer in self.layers:
nn.init.normal_(layer.attention.o_proj.weight, mean=0.0, std=residual_std)
nn.init.normal_(layer.feed_forward.down_proj.weight, mean=0.0, std=residual_std)
def forward(
self,
input_ids: torch.Tensor,
targets: Optional[torch.Tensor] = None,
position_offset: int = 0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
input_ids: (batch_size, seq_len) - 토큰 ID
targets: (batch_size, seq_len) - 정답 토큰 ID (학습 시)
position_offset: 위치 오프셋 (추론 시)
Returns:
logits: (batch_size, seq_len, vocab_size)
loss: 스칼라 (targets 제공 시) 또는 None
"""
B, S = input_ids.shape
# ── Step 1: Token Embedding ──
# 각 토큰 ID를 hidden_dim 차원의 벡터로 변환
h = self.token_embedding(input_ids) # (B, S, hidden_dim)
# ── Step 2: Transformer Blocks ──
# Activation Checkpointing: 학습 시 메모리 절약
# (중간 활성화를 저장하지 않고, backward 시 재계산)
for layer in self.layers:
if self.training and torch.is_grad_enabled():
# Activation Checkpointing 적용
h = torch.utils.checkpoint.checkpoint(
layer, h, None, position_offset,
use_reentrant=False, # PyTorch >= 2.0 권장
)
else:
h = layer(h, mask=None, position_offset=position_offset)
# ── Step 3: 최종 정규화 ──
h = self.final_norm(h)
# ── Step 4: 출력 로짓 계산 ──
logits = self.lm_head(h) # (B, S, vocab_size)
# ── Step 5: Loss 계산 (학습 시) ──
loss = None
if targets is not None:
# Cross-Entropy Loss: 다음 토큰 예측
# logits: (B, S, V) → (B*S, V)
# targets: (B, S) → (B*S,)
loss = F.cross_entropy(
logits.view(-1, self.config.vocab_size),
targets.view(-1),
ignore_index=-100, # 패딩 토큰 무시
)
return logits, loss
def count_parameters(self, trainable_only: bool = True) -> int:
"""모델 파라미터 수 계산."""
if trainable_only:
return sum(p.numel() for p in self.parameters() if p.requires_grad)
return sum(p.numel() for p in self.parameters())
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 100,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 0.9,
) -> torch.Tensor:
"""텍스트 생성 (추론).
Autoregressive 생성: 한 토큰씩 예측하여 이어붙이기.
Args:
input_ids: (1, prompt_len) - 초기 프롬프트
max_new_tokens: 생성할 최대 토큰 수
temperature: 확률 분포 날카로움 조절 (낮을수록 보수적)
top_k: 확률 상위 k개만 고려
top_p: 누적 확률 p까지만 고려 (nucleus sampling)
"""
self.eval()
generated = input_ids
for _ in range(max_new_tokens):
# 현재 시퀀스가 max_seq_len을 초과하면 잘라내기
ctx = generated[:, -self.config.max_seq_len:]
# Forward pass
logits, _ = self(ctx)
# 마지막 토큰의 logits만 사용 (다음 토큰 예측)
next_logits = logits[:, -1, :] / temperature
# ── Top-K 필터링 ──
if top_k > 0:
top_k_values, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
min_top_k = top_k_values[:, -1].unsqueeze(-1)
next_logits = next_logits.masked_fill(next_logits < min_top_k, float("-inf"))
# ── Top-P (Nucleus) 필터링 ──
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# 누적 확률이 top_p를 초과하는 토큰 제거
remove_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
sorted_logits[remove_mask] = float("-inf")
# 원래 순서로 복원
next_logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)
# 확률 분포에서 샘플링
probs = F.softmax(next_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
# 생성된 토큰 이어붙이기
generated = torch.cat([generated, next_token], dim=1)
return generated
# ============================================================================
# 8. 유틸리티 함수
# ============================================================================
def count_parameters_detailed(model: LLMModel) -> dict:
"""모델의 파라미터 수를 컴포넌트별로 상세 출력합니다."""
total = 0
breakdown = {}
# Embedding
emb_params = model.token_embedding.weight.numel()
breakdown["token_embedding"] = emb_params
total += emb_params
# 각 레이어
layer_total = 0
layer_detail = {}
layer = model.layers[0]
for name, param in layer.named_parameters():
layer_detail[name] = param.numel()
layer_total += param.numel()
breakdown["per_layer"] = layer_detail
breakdown["per_layer_total"] = layer_total
breakdown["all_layers_total"] = layer_total * len(model.layers)
total += layer_total * len(model.layers)
# Final norm
norm_params = model.final_norm.weight.numel()
breakdown["final_norm"] = norm_params
total += norm_params
# LM head (weight tying이므로 실제 추가 파라미터 0)
breakdown["lm_head"] = "weight tying (0 additional)"
breakdown["total"] = total
return breakdown
def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict:
"""모델의 GPU 메모리 사용량을 추정합니다.
Args:
dtype_bytes: 2 (bf16/fp16) 또는 4 (fp32)
"""
# 대략적인 파라미터 수 계산
emb = config.vocab_size * config.hidden_dim
per_layer = (
config.hidden_dim * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim # QKV
+ config.num_heads * config.head_dim * config.hidden_dim # O proj
+ 3 * config.hidden_dim * config.intermediate_dim # SwiGLU (gate + up + down)
+ 2 * config.hidden_dim # 2 × RMSNorm
)
total_params = emb + per_layer * config.num_layers + config.hidden_dim
model_gb = total_params * dtype_bytes / 1e9
optimizer_gb = total_params * 8 / 1e9 # AdamW: 2 states × fp32
gradient_gb = total_params * dtype_bytes / 1e9
# 활성화 메모리 (activation checkpointing 적용 가정)
# 대략적 추정: batch_size × seq_len × hidden_dim × num_layers × factor
activation_gb = (
batch_size * config.max_seq_len * config.hidden_dim * 4 # 바이트
* math.sqrt(config.num_layers) # checkpointing 효과
/ 1e9
)
return {
"total_parameters": total_params,
"model_weights_gb": round(model_gb, 2),
"optimizer_states_gb": round(optimizer_gb, 2),
"gradients_gb": round(gradient_gb, 2),
"activations_estimated_gb": round(activation_gb, 2),
"total_estimated_gb": round(model_gb + optimizer_gb + gradient_gb + activation_gb, 2),
}
# ============================================================================
# 9. 검증 스크립트 (실행 시)
# ============================================================================
if __name__ == "__main__":
print("=" * 70)
print("LLM-1B-Lab: 모델 아키텍처 검증")
print("=" * 70)
# ── 디버그 모델 (10M) 테스트 ──
print("\n[1] Debug Model (~10M params)")
cfg_debug = ModelConfig.debug_10m()
model_debug = LLMModel(cfg_debug)
n_params = model_debug.count_parameters()
print(f" 파라미터 수: {n_params:,} ({n_params / 1e6:.1f}M)")
# Forward pass 테스트
dummy_input = torch.randint(0, cfg_debug.vocab_size, (2, 64))
dummy_target = torch.randint(0, cfg_debug.vocab_size, (2, 64))
logits, loss = model_debug(dummy_input, dummy_target)
print(f" Input shape: {dummy_input.shape}")
print(f" Logits shape: {logits.shape}")
print(f" Loss: {loss.item():.4f}")
# 초기 loss ≈ ln(vocab_size) ≈ ln(32000) ≈ 10.37 이면 정상
expected_loss = math.log(cfg_debug.vocab_size)
print(f" Expected initial loss ≈ ln({cfg_debug.vocab_size}) = {expected_loss:.2f}")
# ── 1B 모델 파라미터 수 확인 ──
print("\n[2] Base Model (~1B params) — 파라미터 수만 확인")
cfg_1b = ModelConfig.base_1b()
# 메모리가 부족할 수 있으므로 meta device에서 생성
with torch.device("meta"):
model_1b = LLMModel(cfg_1b)
n_params_1b = model_1b.count_parameters()
print(f" 파라미터 수: {n_params_1b:,} ({n_params_1b / 1e6:.1f}M ≈ {n_params_1b / 1e9:.2f}B)")
# 상세 파라미터 분해
print("\n[3] 파라미터 상세 분해 (1B)")
detail = count_parameters_detailed(model_1b)
print(f" Token Embedding: {detail['token_embedding']:,}")
print(f" Per Layer Total: {detail['per_layer_total']:,}")
print(f" All Layers ({cfg_1b.num_layers}): {detail['all_layers_total']:,}")
print(f" Final Norm: {detail['final_norm']:,}")
print(f" LM Head: {detail['lm_head']}")
print(f" ────────────────────────")
print(f" TOTAL: {detail['total']:,}")
# 메모리 추정
print("\n[4] GPU 메모리 추정 (A100 40GB, bf16, batch_size=4)")
mem = estimate_memory_gb(cfg_1b, batch_size=4, dtype_bytes=2)
print(f" 모델 가중치: {mem['model_weights_gb']} GB")
print(f" 옵티마이저: {mem['optimizer_states_gb']} GB")
print(f" 기울기: {mem['gradients_gb']} GB")
print(f" 활성화 (추정): {mem['activations_estimated_gb']} GB")
print(f" ────────────────────────")
print(f" 총 추정: {mem['total_estimated_gb']} GB")
# 텍스트 생성 테스트 (디버그 모델)
print("\n[5] 텍스트 생성 테스트 (10M debug model, 랜덤 가중치)")
prompt = torch.randint(0, cfg_debug.vocab_size, (1, 10))
generated = model_debug.generate(prompt, max_new_tokens=20, temperature=1.0, top_k=50)
print(f" Prompt length: {prompt.shape[1]}")
print(f" Generated length: {generated.shape[1]}")
print(f" Generated token IDs: {generated[0].tolist()}")
print("\n" + "=" * 70)
print("✅ 모든 검증 통과!")
print("=" * 70)