"""
Advanced Transformer Architecture Components
State-of-the-art implementations for Claude Opus 4 scale training
"""
import math
import warnings
from typing import Optional, Tuple, Union, Dict, Any
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # explicit import for torch.utils.checkpoint.checkpoint below
try:
from flash_attn import flash_attn_func
FLASH_ATTENTION_AVAILABLE = True
except ImportError:
FLASH_ATTENTION_AVAILABLE = False
@dataclass
class ModelConfig:
"""Advanced model configuration for Claude Opus 4 scale models"""
vocab_size: int = 100352
n_positions: int = 8192
n_embd: int = 4096
n_layer: int = 64
n_head: int = 32
n_kv_head: int = 8
rotary_dim: int = 128
intermediate_size: int = 14336
activation: str = "swiglu"
norm_type: str = "rmsnorm"
norm_eps: float = 1e-6
dropout: float = 0.0
attention_dropout: float = 0.0
residual_dropout: float = 0.05
embed_dropout: float = 0.05
tie_word_embeddings: bool = False
use_cache: bool = True
attention_bias: bool = False
mlp_bias: bool = False
flash_attention: bool = True
sliding_window: Optional[int] = 4096
rope_theta: float = 500000.0
rope_scaling: Optional[Dict] = None
gradient_checkpointing: bool = True
max_position_embeddings: int = 8192
def __post_init__(self):
"""Validate configuration after initialization"""
# Check n_head is divisible by n_kv_head
if self.n_head % self.n_kv_head != 0:
raise ValueError(
f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head})"
)
# Check rotary_dim
head_dim = self.n_embd // self.n_head
if self.rotary_dim > head_dim:
raise ValueError(
f"rotary_dim ({self.rotary_dim}) cannot exceed head_dim ({head_dim})"
)
# Warn about suboptimal settings
if self.flash_attention and not FLASH_ATTENTION_AVAILABLE:
warnings.warn(
"flash_attention=True but Flash Attention not installed. "
"Install with: pip install flash-attn --no-build-isolation"
)
if self.gradient_checkpointing and self.use_cache:
warnings.warn(
"gradient_checkpointing=True with use_cache=True may cause issues. "
"Cache will be disabled during training with gradient checkpointing."
)
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = float(eps)
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Positional Embedding with enhanced stability for long contexts"""
def __init__(self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, scaling_factor: float = 1.0):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.scaling_factor = scaling_factor
# Use float64 for better numerical precision
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float64) / self.dim))
self.register_buffer("inv_freq", inv_freq.float(), persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_cache(self, x, seq_len):
# Apply scaling for extrapolation
scaled_seq_len = int(seq_len * self.scaling_factor)
if scaled_seq_len != self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != x.device:
self._seq_len_cached = scaled_seq_len
t = torch.arange(scaled_seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
# Use float32 for cos/sin computation (more stable)
self._cos_cached = emb.cos().to(x.dtype)[None, None, :, :]
self._sin_cached = emb.sin().to(x.dtype)[None, None, :, :]
def forward(self, x, seq_len=None):
if seq_len is None:
seq_len = x.shape[-2]
self._update_cos_sin_cache(x, seq_len)
return self._cos_cached[:, :, :seq_len, ...], self._sin_cached[:, :, :seq_len, ...]
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
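# With the half-split layout used below (cos = [c, c] and sin = [s, s] along the
# last dim), q * cos + rotate_half(q) * sin rotates each channel pair
# (x1, x2) = (q[j], q[j + dim/2]) by the angle t * inv_freq[j]:
#   (x1, x2) -> (x1 * cos - x2 * sin, x1 * sin + x2 * cos)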
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary positional embedding to query and key tensors."""
    # Rotate only the first rotary_dim channels (the width of cos/sin) and pass the
    # remaining channels through unchanged, so partial rotary (rotary_dim < head_dim)
    # also works.
    rot_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    k_rot, k_pass = k[..., :rot_dim], k[..., rot_dim:]
    q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
    k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
    return q_embed, k_embed
class SwiGLU(nn.Module):
"""SwiGLU activation function"""
def __init__(self, config: ModelConfig):
super().__init__()
self.hidden_size = config.n_embd
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
# Mark for scaled initialization
self.down_proj.scale_init = True
def forward(self, x):
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention with Flash Attention support"""
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
self.hidden_size = config.n_embd
self.num_heads = config.n_head
self.num_kv_heads = config.n_kv_head
self.head_dim = self.hidden_size // self.num_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self.rotary_emb = RotaryPositionalEmbedding(
config.rotary_dim,
max_position_embeddings=config.n_positions,
base=int(config.rope_theta),
scaling_factor=1.0
)
self.attention_dropout = nn.Dropout(config.attention_dropout)
def _repeat_kv(self, hidden_states, n_rep):
"""Repeat key/value heads to match query heads"""
batch, num_kv_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
def forward(self, hidden_states, attention_mask=None, use_cache=False, past_key_value=None):
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Apply rotary embeddings at the correct absolute positions: when a KV cache
        # is present, the new tokens start at offset past_len rather than 0
        past_len = past_key_value[0].shape[2] if past_key_value is not None else 0
        cos, sin = self.rotary_emb(value_states, seq_len=past_len + q_len)
        cos = cos[:, :, past_len:past_len + q_len]
        sin = sin[:, :, past_len:past_len + q_len]
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        # Handle past key values for generation (the cache holds already-rotated,
        # pre-repeat KV so head counts stay consistent across decoding steps)
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        present_key_value = (key_states, value_states) if use_cache else None
        # Repeat k/v heads if n_kv_heads < n_heads
        key_states = self._repeat_kv(key_states, self.num_kv_groups)
        value_states = self._repeat_kv(value_states, self.num_kv_groups)
        # Flash Attention if available (this path applies its own causal /
        # sliding-window masking and does not consume the additive attention_mask)
        if FLASH_ATTENTION_AVAILABLE and self.config.flash_attention:
            window = (
                (self.config.sliding_window, self.config.sliding_window)
                if self.config.sliding_window is not None
                else (-1, -1)  # (-1, -1) disables windowing in flash_attn
            )
            attn_output = flash_attn_func(
                query_states.transpose(1, 2),
                key_states.transpose(1, 2),
                value_states.transpose(1, 2),
                dropout_p=self.attention_dropout.p if self.training else 0.0,
                causal=True,
                window_size=window,
            ).transpose(1, 2)
        else:
            # Prefer PyTorch SDPA for stability and memory efficiency
            try:
                # SDPA expects (batch, heads, seq, dim).
                # Pass the additive mask through directly, cast to the query dtype:
                # the model builds it with finite fill values (-1e4 under fp16/bf16),
                # so thresholding it into a boolean mask would silently unmask
                # padded positions.
                sdpa_mask = None
                if attention_mask is not None:
                    sdpa_mask = attention_mask.to(query_states.dtype)
                attn_output = F.scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=sdpa_mask,
                    dropout_p=self.attention_dropout.p if self.training else 0.0,
                    is_causal=(sdpa_mask is None),  # causal only when no combined mask is given
                )
except Exception:
# Fallback to manual attention with NaN protection
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask.to(attn_weights.dtype)
# CRITICAL FIX: Clamp before softmax to prevent all -inf rows (NaN)
mask_value = -1e4 if attn_weights.dtype in (torch.float16, torch.bfloat16) else -1e9
attn_weights = torch.clamp(attn_weights, min=mask_value, max=1e4)
# Use float32 for softmax stability
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
# Add small epsilon to prevent exact zeros
attn_weights = attn_weights + 1e-10
attn_weights = self.attention_dropout(attn_weights)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
        return attn_output, present_key_value
class TransformerBlock(nn.Module):
"""Advanced Transformer Block"""
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
# Normalization
if config.norm_type == "rmsnorm":
self.input_layernorm = RMSNorm(config.n_embd, eps=float(config.norm_eps))
self.post_attention_layernorm = RMSNorm(config.n_embd, eps=float(config.norm_eps))
else:
self.input_layernorm = nn.LayerNorm(config.n_embd, eps=float(config.norm_eps))
self.post_attention_layernorm = nn.LayerNorm(config.n_embd, eps=float(config.norm_eps))
# Attention
self.self_attn = GroupedQueryAttention(config)
# MLP
if config.activation == "swiglu":
self.mlp = SwiGLU(config)
else:
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, config.intermediate_size, bias=config.mlp_bias),
nn.GELU() if config.activation == "gelu" else nn.ReLU(),
nn.Linear(config.intermediate_size, config.n_embd, bias=config.mlp_bias)
)
self.residual_dropout = nn.Dropout(config.residual_dropout)
def forward(self, hidden_states, attention_mask=None, use_cache=False, past_key_value=None):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
attn_output, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
use_cache=use_cache,
past_key_value=past_key_value,
)
# Residual connection
hidden_states = residual + self.residual_dropout(attn_output)
# MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(hidden_states)
hidden_states = residual + self.residual_dropout(mlp_output)
return hidden_states, present_key_value
class AdvancedGPTModel(nn.Module):
"""Advanced GPT Model with modern improvements"""
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
# Embeddings
self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
self.embed_dropout = nn.Dropout(config.embed_dropout)
# Transformer blocks
self.layers = nn.ModuleList([
TransformerBlock(config) for _ in range(config.n_layer)
])
# Final norm
if config.norm_type == "rmsnorm":
self.norm = RMSNorm(config.n_embd, eps=float(config.norm_eps))
else:
self.norm = nn.LayerNorm(config.n_embd, eps=float(config.norm_eps))
# Output projection
if config.tie_word_embeddings:
self.lm_head = None
else:
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
"""Initialize weights using improved scaled initialization"""
if isinstance(module, nn.Linear):
# Use truncated normal for better convergence
std = 0.02
if hasattr(module, 'scale_init') and module.scale_init:
# Scale down residual layers (GPT-3/LLaMA style)
std /= math.sqrt(2 * self.config.n_layer)
torch.nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-2*std, b=2*std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if hasattr(module, 'padding_idx') and module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def forward(self, input_ids, attention_mask=None, past_key_values=None, labels=None, use_cache=None):
batch_size, seq_length = input_ids.shape
use_cache = use_cache if use_cache is not None else self.config.use_cache
# Embeddings
hidden_states = self.embed_tokens(input_ids)
hidden_states = self.embed_dropout(hidden_states)
# Build final 4D additive attention mask combining causal and key-side padding masks
# Use finite large negatives for stability with fp16/bf16
if hidden_states.dtype in (torch.float16, torch.bfloat16):
mask_value = -1e4
else:
mask_value = -1e9
# Causal mask: shape (1, 1, seq, seq)
causal_mask = torch.triu(
torch.full((seq_length, seq_length), mask_value, device=input_ids.device),
diagonal=1
)[None, None, :, :]
if attention_mask is not None:
# attention_mask expected as (batch, seq) with 1 for tokens, 0 for padding
attn1d = attention_mask.to(hidden_states.dtype)
# Key-side padding mask: (batch, 1, 1, seq)
key_mask = (1 - attn1d) * mask_value
key_mask = key_mask[:, None, None, :]
# Combine (do NOT apply query-side mask to avoid fully masked rows and NaNs)
attention_mask = causal_mask + key_mask
else:
attention_mask = causal_mask
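        # Example: with attention_mask = [1, 1, 0] and m = mask_value, the key-side
        # mask is [0, 0, m]; added to the causal rows [0, m, m] / [0, 0, m] / [0, 0, 0],
        # every query position still keeps at least one unmasked key, so the softmax
        # never sees a fully masked row.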
        # Run the token representations through the transformer layers
past_key_values = past_key_values if past_key_values is not None else [None] * len(self.layers)
present_key_values = []
for i, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
if self.config.gradient_checkpointing and self.training:
# CRITICAL FIX: Disable cache during gradient checkpointing
# Checkpointing discards activations, incompatible with caching
hidden_states, _ = torch.utils.checkpoint.checkpoint(
layer,
hidden_states,
attention_mask,
False, # Force use_cache=False
None, # Force past_key_value=None
use_reentrant=False,
)
present_key_value = None
else:
hidden_states, present_key_value = layer(
hidden_states, attention_mask, use_cache, past_key_value
)
# Only append cache if not using gradient checkpointing
if use_cache and not (self.config.gradient_checkpointing and self.training):
present_key_values.append(present_key_value)
hidden_states = self.norm(hidden_states)
# Compute logits
if self.lm_head is not None:
logits = self.lm_head(hidden_states)
else:
logits = F.linear(hidden_states, self.embed_tokens.weight)
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return {
'loss': loss,
'logits': logits,
'past_key_values': present_key_values if use_cache else None,
'hidden_states': hidden_states,
}
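# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the training pipeline):
# runs a forward pass on CPU with a deliberately tiny config, since the default
# ModelConfig is far too large to instantiate casually. The hyperparameter
# values below are arbitrary test settings, not recommended ones.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tiny_config = ModelConfig(
        vocab_size=1000,
        n_positions=128,
        max_position_embeddings=128,
        n_embd=64,
        n_layer=2,
        n_head=4,
        n_kv_head=2,
        rotary_dim=16,  # equals head_dim = 64 // 4
        intermediate_size=128,
        flash_attention=False,
        gradient_checkpointing=False,
        sliding_window=None,
    )
    model = AdvancedGPTModel(tiny_config)
    input_ids = torch.randint(0, tiny_config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
    n_params = sum(p.numel() for p in model.parameters())
    print(
        f"params: {n_params:,}  "
        f"logits: {tuple(outputs['logits'].shape)}  "
        f"loss: {outputs['loss'].item():.4f}"
    )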