# StarMist0012's picture
# Add files using upload-large-folder tool
# 3270dae verified
"""Standard Transformer language model implementation."""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from taoTrain.core import BaseModel
from taoTrain.config import ModelConfig
from .registry import register_architecture
# ============================================================================
# Components
# ============================================================================
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional embeddings added to the input.

    Even feature columns hold ``sin(pos * w_i)`` and odd feature columns hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / dim)``.  An odd ``dim`` is
    supported: the last (even) column simply has no cosine partner.
    """

    def __init__(self, dim: int, max_seq_length: int = 2048):
        """Precompute the ``(max_seq_length, dim)`` embedding table.

        Args:
            dim: Feature width of the embeddings (must match the input's last dim).
            max_seq_length: Longest sequence ``forward`` will accept.
        """
        super().__init__()
        self.dim = dim
        self.max_seq_length = max_seq_length

        pe = torch.zeros(max_seq_length, dim)
        pos = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: ceil(dim / 2) values.
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(pos * div_term)
        # There are dim // 2 odd columns, one fewer than the number of
        # frequencies when dim is odd, so truncate the frequency vector.
        pe[:, 1::2] = torch.cos(pos * div_term[: dim // 2])
        # Non-persistent: the table is deterministic, so keep it out of state_dict.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings to ``x``.

        Args:
            x: Input tensor of shape (batch, seq_len, dim).

        Returns:
            ``x`` plus the embeddings for positions ``0 .. seq_len - 1``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq_length``.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_seq_length:
            # Without this check the slice below silently truncates and the
            # addition fails with an unhelpful broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length "
                f"{self.max_seq_length}"
            )
        return x + self.pe[:seq_len]
class Attention(nn.Module):
    """Multi-head causal self-attention using ``F.scaled_dot_product_attention``.

    Q, K and V come from separate ``hidden_dim -> hidden_dim`` projections that
    are split into ``num_heads`` heads of width ``head_dim``.  Causal masking is
    always applied.  When a per-token ``attention_mask`` is supplied it is
    combined with the causal mask so that tokens do not attend to padding.
    """

    def __init__(self, config: "ModelConfig"):
        """Initialize attention from ``config``.

        Reads ``config.hidden_dim``, ``config.num_heads``, ``config.head_dim``
        and ``config.dropout``.

        Raises:
            ValueError: if ``hidden_dim`` is not divisible by ``num_heads`` or
                ``num_heads * head_dim != hidden_dim`` (the per-head reshape in
                ``forward`` could not succeed otherwise).
        """
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )
        if self.num_heads * self.head_dim != self.hidden_dim:
            raise ValueError(
                f"num_heads * head_dim ({self.num_heads} * {self.head_dim}) "
                f"must equal hidden_dim ({self.hidden_dim})"
            )

        # Linear projections
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.dropout_p = config.dropout

    @staticmethod
    def _build_attn_mask(
        attention_mask: torch.Tensor, seq_len: int, device: torch.device
    ) -> torch.Tensor:
        """Combine the causal mask with a key-padding mask for SDPA.

        Args:
            attention_mask: (batch, seq_len); non-zero entries mark tokens that
                may be attended to.
            seq_len: Sequence length of the current input.
            device: Device the mask must live on (the activations' device).

        Returns:
            Bool tensor of shape (batch, 1, seq_len, seq_len); ``True`` means
            "may attend".

        Raises:
            ValueError: if ``attention_mask`` is not (batch, seq_len).
        """
        if attention_mask.dim() != 2 or attention_mask.shape[1] != seq_len:
            raise ValueError(
                f"attention_mask must have shape (batch, {seq_len}), "
                f"got {tuple(attention_mask.shape)}"
            )
        causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
        # Each query may always attend to itself.  Otherwise, a padding query
        # whose visible keys are all padding would have a fully masked row, which
        # can make SDPA emit NaN.  That NaN could then spread through V.
        diagonal = torch.eye(seq_len, dtype=torch.bool, device=device)
        key_allowed = attention_mask.to(device=device, dtype=torch.bool)[:, None, None, :]
        return causal & (key_allowed | diagonal)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run causal multi-head attention.

        Args:
            x: Shape (batch, seq_len, hidden_dim).
            attention_mask: Optional (batch, seq_len); non-zero = real token,
                zero = padding that other tokens must not attend to.

        Returns:
            Shape (batch, seq_len, hidden_dim).
        """
        batch_size, seq_len, _ = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq_len, hidden_dim) -> (batch, num_heads, seq_len, head_dim)
            return t.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))
        dropout_p = self.dropout_p if self.training else 0.0

        if attention_mask is None:
            # Fast path: SDPA builds the causal mask itself.
            out = F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, is_causal=True
            )
        else:
            # SDPA rejects attn_mask together with is_causal=True, so the causal
            # constraint is folded into an explicit mask.  An explicit mask may
            # keep SDPA from selecting its fused flash kernel.
            mask = self._build_attn_mask(attention_mask, seq_len, x.device)
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=dropout_p
            )

        # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, hidden_dim)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim)
        return self.out_proj(out)
class SwiGLU(nn.Module):
    """Gated feed-forward unit with the SwiGLU nonlinearity.

    ``fc1`` widens the input to ``2 * out_dim`` features, which are split into
    a value half and a gate half.  ``value * silu(gate)`` then passes through
    dropout, and ``fc2`` maps it back to ``in_dim``.
    """

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.0):
        """Create the two projections and the dropout layer.

        Args:
            in_dim: Width of both the input and the output.
            out_dim: Width of the gated intermediate representation.
            dropout: Dropout probability applied to the gated activation.
        """
        super().__init__()
        widened = 2 * out_dim  # value and gate come out of a single matmul
        self.fc1 = nn.Linear(in_dim, widened)
        self.fc2 = nn.Linear(out_dim, in_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc2(dropout(value * silu(gate)))``; the output width equals the input width."""
        value, gate = self.fc1(x).chunk(2, dim=-1)
        gated = value * F.silu(gate)  # SiLU(g) = g * sigmoid(g), i.e. Swish
        return self.fc2(self.dropout(gated))
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer that delegates to a SwiGLU unit.

    The intermediate width comes from ``config.intermediate_dim`` and the
    dropout rate from ``config.dropout``.  Input and output width is
    ``config.hidden_dim``.
    """

    def __init__(self, config: ModelConfig):
        """Build the SwiGLU unit from ``config``."""
        super().__init__()
        hidden, inner, p_drop = config.hidden_dim, config.intermediate_dim, config.dropout
        self.swiglu = SwiGLU(hidden, inner, dropout=p_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU feed-forward transform to ``x``."""
        out = self.swiglu(x)
        return out
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: ``x + attn(LN(x))`` then ``x + ffn(LN(x))``."""

    def __init__(self, config: ModelConfig):
        """Create the two LayerNorms, the attention sub-layer and the FFN."""
        super().__init__()
        width = config.hidden_dim
        self.norm1 = nn.LayerNorm(width)
        self.attn = Attention(config)
        self.norm2 = nn.LayerNorm(width)
        self.ffn = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply both residual sub-layers, each on a normalized copy of the stream.

        ``attention_mask`` is passed straight through to the attention sub-layer.
        """
        normed = self.norm1(x)
        x = x + self.attn(normed, attention_mask=attention_mask)
        normed = self.norm2(x)
        return x + self.ffn(normed)
# ============================================================================
# Transformer LM
# ============================================================================
@register_architecture("transformer")
class TransformerLM(BaseModel):
    """Standard decoder-only Transformer language model.

    Pipeline: token embedding + sinusoidal positions -> dropout -> a stack of
    pre-norm ``TransformerBlock``s -> final LayerNorm -> LM head whose weight
    is shared with the token embedding.
    """

    def __init__(self, config: ModelConfig):
        """Build all sub-modules from ``config`` and initialize their weights."""
        super().__init__(config)

        # Embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.pos_embed = PositionalEmbedding(config.hidden_dim, max_seq_length=config.max_seq_length)
        self.dropout = nn.Dropout(config.dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_layers)
        ])

        # Final layer norm
        self.final_norm = nn.LayerNorm(config.hidden_dim)

        # Output projection
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        # Weight tying (always applied): the LM head reuses the embedding matrix.
        self.lm_head.weight = self.embed_tokens.weight

        self._init_weights()

    def _init_weights(self):
        """Re-initialize Linear and Embedding weights from N(0, init_std).

        Linear biases are zeroed.  LayerNorm modules keep PyTorch's default init.
        The tied LM-head/embedding weight is visited twice (once as a Linear,
        once as an Embedding) and simply ends up re-drawn from the same
        distribution.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=self.config.init_std)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, std=self.config.init_std)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict[str, Optional[torch.Tensor]]:
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: (batch_size, seq_len) token ids.
            attention_mask: (batch_size, seq_len); forwarded unchanged to every
                transformer block.
            labels: (batch_size, seq_len) target ids; entries equal to -100 are
                ignored by the loss.

        Returns:
            Dict with ``'logits'`` of shape (batch, seq_len, vocab_size) and
            ``'loss'`` (a scalar tensor, or ``None`` when ``labels`` is None).
        """
        x = self.embed_tokens(input_ids)
        x = self.dropout(self.pos_embed(x))

        for block in self.blocks:
            x = block(x, attention_mask=attention_mask)

        logits = self.lm_head(self.final_norm(x))  # (batch, seq_len, vocab_size)

        loss = None
        if labels is not None:
            # NOTE(review): logits[:, t] is scored against labels[:, t] with no
            # shift, which assumes the caller already shifted labels one step
            # left relative to input_ids.  Confirm against the data pipeline.
            # `reshape` (not `view`) so non-contiguous labels such as
            # `ids[:, 1:]` are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                reduction='mean',
                ignore_index=-100,
            )

        return {
            'logits': logits,
            'loss': loss,
        }