# Source: mindi-backup / src / model_architecture / code_transformer.py
# (Provenance from scraped GitHub page: author "Mindigenous",
#  commit "Initial full project backup with Git LFS", 53f0cc2)
"""
Component 4: Transformer model architecture for code generation.
This module defines a decoder-only transformer built from scratch in PyTorch.
It is modular through configuration so model size can be scaled up/down.
"""
from __future__ import annotations

import math
from dataclasses import asdict, dataclass
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
@dataclass
class ModelConfig:
    """Hyperparameters describing one transformer model variant.

    Every field has a default so size presets only override what differs.
    """

    # Vocabulary size from tokenizer.
    vocab_size: int = 50_000
    # Maximum context length in tokens.
    max_seq_len: int = 2048
    # Core hidden size of transformer.
    d_model: int = 1152
    # Number of transformer blocks.
    n_layers: int = 23
    # Number of attention heads.
    n_heads: int = 16
    # Feed-forward hidden size.
    d_ff: int = 4608
    # Dropout for regularization.
    dropout: float = 0.1
    # Whether to tie token embedding and LM head weights.
    tie_embeddings: bool = True
    # Enable gradient checkpointing to reduce VRAM usage during training.
    gradient_checkpointing: bool = False
    # Initialization standard deviation.
    init_std: float = 0.02
    # Epsilon for layer normalization stability.
    rms_norm_eps: float = 1e-5

    @property
    def head_dim(self) -> int:
        """Per-head width: d_model split evenly across n_heads.

        Raises:
            ValueError: if d_model is not divisible by n_heads.
        """
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads.")
        return self.d_model // self.n_heads
def get_model_presets() -> Dict[str, ModelConfig]:
    """Return the named standard model-size presets."""
    overrides = {
        "small_180m": dict(d_model=896, n_layers=18, n_heads=14, d_ff=3584),
        "medium_420m": dict(d_model=1152, n_layers=23, n_heads=16, d_ff=4608),
        "large_800m": dict(d_model=1536, n_layers=24, n_heads=16, d_ff=6144),
    }
    return {name: ModelConfig(**kwargs) for name, kwargs in overrides.items()}
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    Scales the last dimension by the reciprocal of its RMS and applies a
    learned per-channel gain, as used in many modern LLMs.
    """

    def __init__(self, dim: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt of the mean square; eps guards against division by zero.
        mean_square = x.square().mean(-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.eps)
        return normalized * self.weight
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE).

    Precomputes cos/sin tables up to max_seq_len and rotates each
    even/odd channel pair of the query/key vectors by a
    position-dependent angle, injecting token order into attention.
    """

    def __init__(self, head_dim: int, max_seq_len: int) -> None:
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for rotary embeddings.")
        # One frequency per channel pair, geometrically spaced (base 10000).
        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        inv_freq = 10000 ** -exponents
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)  # [max_seq_len, head_dim/2]
        # Tables are cheap to rebuild, so keep them out of the state_dict.
        self.register_buffer("cos_cached", angles.cos(), persistent=False)
        self.register_buffer("sin_cached", angles.sin(), persistent=False)

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Slice to the active length and add broadcast dims -> [1,1,S,H/2].
        cos = self.cos_cached[None, None, :seq_len]
        sin = self.sin_cached[None, None, :seq_len]
        return self._apply_rotary(q, cos, sin), self._apply_rotary(k, cos, sin)

    @staticmethod
    def _apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # 2-D rotation applied to interleaved (even, odd) channel pairs.
        even = x[..., 0::2]
        odd = x[..., 1::2]
        rotated = torch.stack(
            (even * cos - odd * sin, even * sin + odd * cos), dim=-1
        )
        return rotated.flatten(-2)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention for autoregressive code generation.

    Projects the input into per-head queries/keys/values, applies rotary
    position embeddings, then runs PyTorch's fused scaled-dot-product
    attention with causal masking always enforced.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.rotary = RotaryEmbedding(head_dim=self.head_dim, max_seq_len=config.max_seq_len)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            x: Activations of shape [batch, seq_len, d_model].
            attn_mask: Optional extra mask (e.g. padding), broadcastable to
                [batch, n_heads, seq_len, seq_len]. Boolean masks keep
                positions where True; float masks are added to the logits.
                Causality is enforced on top of any supplied mask.

        Returns:
            Tensor of shape [batch, seq_len, d_model].
        """
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        q, k = self.rotary(q, k, seq_len=seq_len)
        dropout_p = self.dropout.p if self.training else 0.0
        if attn_mask is None:
            # Fast path: let the fused kernel construct the causal mask.
            out = F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, is_causal=True, scale=self.scale
            )
        else:
            # BUG FIX: scaled_dot_product_attention raises an error when an
            # explicit attn_mask is combined with is_causal=True, so fold the
            # causal constraint into the supplied mask instead of passing both.
            causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).tril()
            if attn_mask.dtype == torch.bool:
                merged = attn_mask & causal
            else:
                merged = attn_mask.masked_fill(~causal, float("-inf"))
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=merged, dropout_p=dropout_p, is_causal=False, scale=self.scale
            )
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.o_proj(out)
class FeedForward(nn.Module):
"""
Two-layer feed-forward network with GELU activation.
"""
def __init__(self, config: ModelConfig) -> None:
super().__init__()
self.fc1 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.fc2 = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = F.gelu(x, approximate="tanh")
x = self.fc2(x)
x = self.dropout(x)
return x
class TransformerBlock(nn.Module):
    """Pre-norm transformer block.

    Two residual sub-layers: RMSNorm -> attention, then RMSNorm -> MLP.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.norm1 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.attn = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.ffn = FeedForward(config)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        attn_out = self.attn(self.norm1(x), attn_mask=attn_mask)
        x = x + attn_out
        ffn_out = self.ffn(self.norm2(x))
        return x + ffn_out
class CodeTransformerLM(nn.Module):
    """Full decoder-only language model for code generation.

    Pipeline: token embedding -> dropout -> n_layers TransformerBlocks ->
    final RMSNorm -> linear LM head (optionally weight-tied to the token
    embedding).
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.norm_final = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if config.tie_embeddings:
            # Share one weight matrix between the input embedding and the
            # output head; must happen before init so both see the same init.
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        # Keep initialization stable for deep networks. All Linear layers
        # here are bias-free, so only weights need initializing.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)

    def enable_gradient_checkpointing(self, enabled: bool = True) -> None:
        """Toggle gradient checkpointing (trades recompute for VRAM)."""
        self.config.gradient_checkpointing = enabled

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the model and optionally compute the LM loss.

        Args:
            input_ids: Token ids of shape [batch, seq_len].
            labels: Optional targets of the same shape; positions set to
                -100 are ignored by the loss.
            attn_mask: Optional attention mask forwarded to every block.

        Returns:
            Dict with "logits" of shape [batch, seq_len, vocab_size] and,
            when labels are given, a scalar "loss".

        Raises:
            ValueError: if input_ids is not 2-D or the sequence exceeds
                the configured context window.
        """
        if input_ids.dim() != 2:
            raise ValueError("input_ids must be shape [batch, seq_len].")
        seq_len = input_ids.size(1)
        # Fail fast with a clear message rather than an opaque shape error
        # from the rotary cos/sin cache deeper in the stack.
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len "
                f"{self.config.max_seq_len}."
            )
        x = self.dropout(self.embed_tokens(input_ids))
        for block in self.blocks:
            if self.config.gradient_checkpointing and self.training:
                # Non-reentrant checkpointing recomputes activations in the
                # backward pass to cut peak VRAM.
                x = torch.utils.checkpoint.checkpoint(block, x, attn_mask, use_reentrant=False)
            else:
                x = block(x, attn_mask=attn_mask)
        x = self.norm_final(x)
        logits = self.lm_head(x)
        out: Dict[str, torch.Tensor] = {"logits": logits}
        if labels is not None:
            # Next-token prediction: logits at position t predict token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            out["loss"] = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return out

    def estimate_num_parameters(self) -> int:
        """Return the total trainable parameter count."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def summary(self) -> Dict[str, object]:
        """Return a simple structured summary for logs/CLI."""
        return {
            "config": asdict(self.config),
            "num_parameters": self.estimate_num_parameters(),
        }