import math
from functools import cache, partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # used by TransformerBlock; `import torch` alone may not load it

from .config import CodsworthConfig


@cache
def _get_flash_attn_func():
    """Return ``flash_attn.flash_attn_func``, or ``None`` if flash-attn is not installed.

    The import is attempted once and memoized instead of on every forward pass.
    """
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        return None
    return flash_attn_func


class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) - https://arxiv.org/abs/2104.09864

    Caches cos/sin tables for positions ``[0, max_position_embeddings)`` and
    rebuilds them for a longer length when a caller requests more positions.
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Rotation frequency for each (even, odd) channel pair: base^(-2i/dim).
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(max_position_embeddings, torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len: int, dtype: torch.dtype) -> None:
        """(Re)build the cos/sin tables for positions ``[0, seq_len)``, stored in ``dtype``.

        Angles are computed in float32 even if the module was cast to half
        precision, so rebuilt tables keep full positional resolution.
        """
        inv_freq = self.inv_freq.float()
        t = torch.arange(seq_len, device=inv_freq.device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate the frequencies so the table lines up with rotate_half's
        # [first half | second half] channel split.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(
        self,
        seq_len: int,
        device: torch.device,
        dtype: Optional[torch.dtype] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` tables for positions ``[0, seq_len)``.

        Args:
            seq_len: number of positions needed; the cache is grown if it is
                shorter than this.
            device: device to move the tables to.
            dtype: dtype of the returned tables; defaults to the cache dtype.

        Returns:
            Two tensors of shape ``(seq_len, 2 * (dim // 2))``.
        """
        if seq_len > self.cos_cached.shape[0]:
            self._set_cos_sin_cache(seq_len, self.cos_cached.dtype)
        if dtype is None:
            dtype = self.cos_cached.dtype
        return (
            self.cos_cached[:seq_len].to(device=device, dtype=dtype, non_blocking=True),
            self.sin_cached[:seq_len].to(device=device, dtype=dtype, non_blocking=True),
        )


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last-dim halves ``[x1, x2]`` to ``[-x2, x1]`` (the RoPE 90° rotation)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate ``q`` by the angles encoded in ``cos``/``sin``.

    ``cos`` and ``sin`` must broadcast against ``q``.
    """
    return (q * cos) + (rotate_half(q) * sin)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional RoPE and flash-attention.

    The flash-attention kernel is used only in eval mode, without an
    ``attention_mask``, on CUDA float16/bfloat16 inputs, and when ``flash_attn``
    is importable. Every other case uses the explicit softmax implementation.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        bias: bool = True,
        max_position_embeddings: int = 2048,
        use_rope: bool = True,
        rope_theta: float = 10000.0,
        use_flash_attention: bool = True,
    ):
        super().__init__()
        # qkv_proj emits 3 * embed_dim features that are reshaped into
        # (num_heads, head_dim) per projection, so the sizes must agree.
        if num_heads * head_dim != embed_dim:
            raise ValueError(
                f"embed_dim ({embed_dim}) must equal num_heads * head_dim "
                f"({num_heads} * {head_dim} = {num_heads * head_dim})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout
        self.use_flash_attention = use_flash_attention

        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.use_rope = use_rope
        if use_rope:
            self.rotary_emb = RotaryPositionalEmbedding(
                head_dim, max_position_embeddings, rope_theta
            )

        self.scale = head_dim ** -0.5
        self.causal_mask = None  # not read anywhere; kept for attribute compatibility

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``x`` of shape ``(batch, seq, embed_dim)``.

        Args:
            x: input activations.
            attention_mask: optional *additive* mask. It is added to the
                ``(batch, heads, seq, seq)`` attention scores, so it must
                broadcast to that shape.
            position_ids: optional RoPE positions, shape ``(seq,)`` or
                ``(batch, seq)``. Defaults to ``0 .. seq - 1``.

        Returns:
            Tensor of shape ``(batch, seq, embed_dim)``.
        """
        batch_size, seq_len, _ = x.shape

        q, k, v = self.qkv_proj(x).split(self.embed_dim, dim=-1)
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        if self.use_rope:
            cos, sin = self._rotary_cos_sin(seq_len, position_ids, x)
            q = apply_rotary_pos_emb(q, cos, sin)
            k = apply_rotary_pos_emb(k, cos, sin)

        if self._can_use_flash_attention(q, attention_mask):
            return self._flash_attention_forward(q, k, v, batch_size, seq_len)
        return self._standard_attention_forward(q, k, v, batch_size, seq_len, attention_mask)

    def _rotary_cos_sin(
        self,
        seq_len: int,
        position_ids: Optional[torch.Tensor],
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return RoPE tables that broadcast against ``(batch, heads, seq, head_dim)``."""
        if position_ids is None:
            return self.rotary_emb(seq_len, x.device, x.dtype)

        # Make sure the table covers the largest requested position.
        needed = seq_len
        if position_ids.numel() > 0:
            needed = max(seq_len, int(position_ids.max()) + 1)
        cos, sin = self.rotary_emb(needed, x.device, x.dtype)
        position_ids = position_ids.to(cos.device)
        cos, sin = cos[position_ids], sin[position_ids]
        if cos.dim() == 3:
            # (batch, seq, dim) -> (batch, 1, seq, dim) to broadcast over heads.
            cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
        return cos, sin

    def _can_use_flash_attention(
        self,
        q: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> bool:
        """Decide whether the flash-attention kernel can serve this call."""
        return (
            self.use_flash_attention
            and not self.training
            # The flash path applies only the causal mask; any extra mask
            # forces the explicit path so it is not silently dropped.
            and attention_mask is None
            and q.is_cuda
            and q.dtype in (torch.float16, torch.bfloat16)
            and _get_flash_attn_func() is not None
        )

    def _flash_attention_forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        batch_size: int,
        seq_len: int,
    ) -> torch.Tensor:
        """Causal attention through ``flash_attn_func``; inputs are (batch, heads, seq, head_dim).

        Falls back to the explicit implementation (without an extra mask)
        when flash-attn is not installed.
        """
        flash_attn_func = _get_flash_attn_func()
        if flash_attn_func is None:
            return self._standard_attention_forward(q, k, v, batch_size, seq_len, None)

        # flash_attn_func expects (batch, seq, heads, head_dim) and returns the same layout.
        out = flash_attn_func(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            dropout_p=self.dropout if self.training else 0.0,
            softmax_scale=self.scale,
            causal=True,
        )
        # The output is already (batch, seq, heads, head_dim), so merge the
        # heads directly. A transpose here would interleave heads and positions.
        out = out.reshape(batch_size, seq_len, self.embed_dim)
        return self.out_proj(out)

    def _standard_attention_forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        batch_size: int,
        seq_len: int,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Explicit softmax attention with a causal mask plus an optional additive mask."""
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # True above the diagonal = future positions that must not be attended to.
        future = torch.triu(
            torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool),
            diagonal=1,
        )
        attn_weights = attn_weights.masked_fill(future, float("-inf"))

        attn_probs = F.softmax(attn_weights, dim=-1)
        if self.dropout > 0.0 and self.training:
            attn_probs = F.dropout(attn_probs, p=self.dropout)

        out = torch.matmul(attn_probs, v)
        # (batch, heads, seq, head_dim) -> (batch, seq, embed)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(out)


class FeedForward(nn.Module):
    """Feed-forward network with SwiGLU activation.

    ``gate_proj`` produces both the gate and the "up" branch in one matmul
    (``2 * ffn_hidden_dim`` outputs, split in half).

    NOTE(review): ``up_proj`` is never used by ``forward``. It is kept only so
    existing checkpoints keep loading with ``strict=True``. Its parameters get
    no gradient, which may require ``find_unused_parameters=True`` under DDP.
    """

    def __init__(self, embed_dim: int, ffn_hidden_dim: int, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = ffn_hidden_dim
        self.gate_proj = nn.Linear(embed_dim, ffn_hidden_dim * 2, bias=bias)
        self.up_proj = nn.Linear(embed_dim, ffn_hidden_dim, bias=bias)  # unused, see class note
        self.down_proj = nn.Linear(ffn_hidden_dim, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(dropout(silu(gate) * up))`` for ``x`` of last dim ``embed_dim``."""
        gate, up = self.gate_proj(x).chunk(2, dim=-1)
        gate = self.act_fn(gate)
        return self.down_proj(self.dropout(gate * up))


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention and feed-forward, each with a residual."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        ffn_hidden_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        ffn_dropout: float = 0.0,
        bias: bool = True,
        max_position_embeddings: int = 2048,
        use_rope: bool = True,
        rope_theta: float = 10000.0,
        use_flash_attention: bool = True,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim

        self.input_layernorm = nn.LayerNorm(embed_dim, bias=bias)
        self.self_attention = CausalSelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            dropout=attention_dropout,
            bias=bias,
            max_position_embeddings=max_position_embeddings,
            use_rope=use_rope,
            rope_theta=rope_theta,
            use_flash_attention=use_flash_attention,
        )

        self.post_attention_layernorm = nn.LayerNorm(embed_dim, bias=bias)
        self.feed_forward = FeedForward(
            embed_dim=embed_dim,
            ffn_hidden_dim=ffn_hidden_dim,
            dropout=ffn_dropout,
            bias=bias,
        )

        self.use_gradient_checkpointing = use_gradient_checkpointing
        # Residual dropout, applied to both sub-layer outputs.
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply ``x + attn(ln(x))`` then ``x + ffn(ln(x))``."""
        residual = x
        x = self.input_layernorm(x)

        if self.use_gradient_checkpointing and self.training:
            x = self._gradient_checkpointed_forward(x, attention_mask, position_ids)
        else:
            x = self.self_attention(x, attention_mask, position_ids)

        x = residual + self.dropout(x)

        residual = x
        x = self.post_attention_layernorm(x)
        x = self.feed_forward(x)
        x = residual + self.dropout(x)

        return x

    def _gradient_checkpointed_forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run self-attention under activation checkpointing (recomputed in backward)."""
        return torch.utils.checkpoint.checkpoint(
            self.self_attention,
            x,
            attention_mask,
            position_ids,
            use_reentrant=False,
        )


class CodsworthTransformer(nn.Module):
    """Codsworth decoder-only transformer language model with tied input/output embeddings."""

    def __init__(self, config: CodsworthConfig):
        super().__init__()
        self.config = config

        self.vocab_embedding = nn.Embedding(
            config.vocab_size, config.embedding_dim, padding_idx=config.pad_token_id
        )
        self.embedding_dropout = nn.Dropout(config.embedding_dropout)

        self.layers = nn.ModuleList([
            TransformerBlock(
                embed_dim=config.embedding_dim,
                num_heads=config.num_heads,
                head_dim=config.head_dim,
                ffn_hidden_dim=config.ffn_hidden_dim,
                dropout=config.dropout,
                attention_dropout=config.attention_dropout,
                ffn_dropout=config.ffn_dropout,
                bias=config.use_bias,
                max_position_embeddings=config.max_position_embeddings,
                use_rope=config.use_rope,
                rope_theta=config.rope_theta,
                use_flash_attention=config.use_flash_attention,
                use_gradient_checkpointing=config.use_gradient_checkpointing,
            )
            for _ in range(config.num_layers)
        ])

        self.final_layernorm = nn.LayerNorm(config.embedding_dim, bias=config.use_bias)
        self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)

        self.apply(self._init_weights)
        # Tie AFTER init. Tying first made lm_head's Linear init (applied after
        # the embedding's) overwrite the shared weight and un-zero the padding row.
        self._tie_weights()

    def _tie_weights(self):
        """Share one weight tensor between the input embedding and the output head."""
        self.lm_head.weight = self.vocab_embedding.weight

    def _init_weights(self, module: nn.Module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero biases and the padding row."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()

    def _hidden_states(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run embeddings, every block and the final LayerNorm; return post-norm states."""
        hidden_states = self.embedding_dropout(self.vocab_embedding(input_ids))
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, position_ids)
        return self.final_layernorm(hidden_states)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict[str, Optional[torch.Tensor]]:
        """Compute logits (and the LM loss when ``labels`` is given).

        Args:
            input_ids: token ids of shape ``(batch, seq)``.
            attention_mask: optional additive mask forwarded to every attention
                layer (it is added to the attention scores).
            position_ids: optional RoPE positions forwarded to every layer.
            labels: optional targets aligned with ``input_ids``. They are
                shifted by one inside the loss.

        Returns:
            A dict with ``"logits"`` ``(batch, seq, vocab)``, ``"loss"``
            (``None`` without labels) and ``"hidden_states"`` (the output of
            the final LayerNorm).
        """
        hidden_states = self._hidden_states(input_ids, attention_mask, position_ids)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        return {
            "logits": logits,
            "loss": loss,
            "hidden_states": hidden_states,
        }

    def _compute_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Mean next-token cross-entropy.

        Targets equal to ``config.pad_token_id`` (when set) and targets equal
        to -100 are both ignored.
        """
        vocab_size = logits.shape[-1]
        # Position t predicts token t + 1.
        shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
        shift_labels = labels[:, 1:].reshape(-1)

        # Map padding onto -100 so a single ignore_index covers both cases.
        # This also works when pad_token_id is None, which the previous
        # ignore_index=pad_token_id could not handle.
        pad_token_id = self.config.pad_token_id
        if pad_token_id is not None:
            shift_labels = shift_labels.masked_fill(shift_labels == pad_token_id, -100)

        return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="mean")

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Autoregressively extend ``input_ids`` by up to ``max_new_tokens`` tokens.

        Args:
            input_ids: prompt ids of shape ``(batch, prompt_len)``.
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature; ``0`` selects greedy decoding.
            top_k: if positive, sample only among the ``top_k`` highest logits.
            top_p: if given, nucleus sampling with cumulative-probability cutoff.
            eos_token_id: once a row emits this id it is finished. Its later
                positions are filled with ``config.pad_token_id`` (or
                ``eos_token_id`` when no pad id is configured), and generation
                stops when every row is finished.

        Returns:
            The full sequence (prompt plus generated tokens), shape
            ``(batch, prompt_len + n_generated)``.

        Raises:
            ValueError: if ``temperature`` is negative.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        was_training = self.training
        self.eval()
        try:
            max_context = self.config.max_position_embeddings
            fill_id = (
                self.config.pad_token_id
                if self.config.pad_token_id is not None
                else eos_token_id
            )
            generated = input_ids.clone()
            finished = torch.zeros(
                input_ids.shape[0], dtype=torch.bool, device=input_ids.device
            )

            for _ in range(max_new_tokens):
                # Crop only the model input; the returned sequence keeps the full prompt.
                context = generated[:, -max_context:]
                # Only the last position's logits are needed.
                next_token_logits = self.lm_head(self._hidden_states(context)[:, -1, :])
                next_token = self._sample_next_token(
                    next_token_logits, temperature, top_k, top_p
                )

                if eos_token_id is not None:
                    next_token = next_token.masked_fill(finished.unsqueeze(-1), fill_id)
                    finished |= next_token.squeeze(-1) == eos_token_id

                generated = torch.cat([generated, next_token], dim=1)

                if eos_token_id is not None and finished.all():
                    break

            return generated
        finally:
            # Restore the caller's train/eval mode instead of leaving the model in eval.
            self.train(was_training)

    @staticmethod
    def _sample_next_token(
        logits: torch.Tensor,
        temperature: float,
        top_k: Optional[int],
        top_p: Optional[float],
    ) -> torch.Tensor:
        """Pick one token id per row from ``(batch, vocab)`` logits; returns ``(batch, 1)``."""
        if temperature == 0:
            return logits.argmax(dim=-1, keepdim=True)

        logits = logits / temperature

        if top_k is not None and top_k > 0:
            k = min(top_k, logits.shape[-1])
            kth_largest = torch.topk(logits, k).values[..., -1:]
            logits = logits.masked_fill(logits < kth_largest, float("-inf"))

        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept, and
            # always keep the most likely token.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            # Map the mask from sorted order back to vocabulary order in one scatter.
            to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
            logits = logits.masked_fill(to_remove, float("-inf"))

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def encode(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the final-LayerNorm hidden states (same as ``forward(...)["hidden_states"]``).

        The vocabulary projection is skipped because it is not needed.
        """
        return self._hidden_states(input_ids, attention_mask, position_ids)

    @torch.no_grad()
    def decode(
        self,
        hidden_states: torch.Tensor,
        apply_final_layernorm: bool = False,
    ) -> torch.Tensor:
        """Project hidden states to vocabulary logits.

        ``encode`` and ``forward`` already return states after
        ``final_layernorm``, so by default only ``lm_head`` is applied and
        ``decode(encode(x))`` matches ``forward(x)["logits"]``. Pass
        ``apply_final_layernorm=True`` for states taken before the final norm.
        """
        if apply_final_layernorm:
            hidden_states = self.final_layernorm(hidden_states)
        return self.lm_head(hidden_states)

    def get_num_params(self, trainable_only: bool = False) -> int:
        """Count parameters (the tied embedding/head weight is counted once)."""
        if trainable_only:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.parameters())