| import math |
| from typing import Optional |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from functools import partial |
|
|
| from .config import CodsworthConfig |
|
|
|
|
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) - https://arxiv.org/abs/2104.09864

    Precomputes cos/sin tables where row ``p`` holds the rotation angles for
    position ``p``. The tables start with ``max_position_embeddings`` rows and
    are rebuilt on demand when a longer sequence is requested, rather than
    silently returning a table shorter than the sequence.

    Args:
        dim: Rotary dimension (the per-head size of queries/keys).
        max_position_embeddings: Number of positions cached up front.
        base: Frequency base of the inverse-frequency spectrum.
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # One inverse frequency per pair of channels: base ** (-2i / dim).
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self._set_cos_sin_cache(max_position_embeddings, torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len: int, dtype: torch.dtype):
        """(Re)build ``cos_cached`` / ``sin_cached`` for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions to cache.
            dtype: Dtype the tables are stored in.
        """
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Duplicate the angles so the table covers both halves of the rotary dim.
        emb = torch.cat((freqs, freqs), dim=-1)
        # Non-persistent: derived data, recomputed rather than stored in state_dict.
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype = None):
        """Return ``(cos, sin)`` for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions needed; the cache grows if it is too short.
            device: Device to move the tables to.
            dtype: Output dtype; defaults to the cache dtype.

        Returns:
            Two tensors with ``seq_len`` rows (``dim`` columns for even ``dim``).
        """
        if seq_len > self.cos_cached.shape[0]:
            # Extend in the cache's current dtype instead of truncating the result.
            self._set_cos_sin_cache(seq_len, self.cos_cached.dtype)
        if dtype is None:
            dtype = self.cos_cached.dtype
        return (
            self.cos_cached[:seq_len].to(device=device, dtype=dtype, non_blocking=True),
            self.sin_cached[:seq_len].to(device=device, dtype=dtype, non_blocking=True),
        )
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last-dimension halves ``(a, b)`` of ``x`` to ``(-b, a)``.

    The split point is ``x.shape[-1] // 2``, so for an odd size the leading
    half is the shorter one.
    """
    half = x.shape[-1] // 2
    lower, upper = x[..., :half], x[..., half:]
    return torch.cat([upper.neg(), lower], dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate ``q`` by the RoPE angles whose cosines/sines are ``cos``/``sin``.

    ``cos`` and ``sin`` are combined with ``q`` by plain elementwise
    multiplication, so they must broadcast against ``q``.
    """
    rotated = rotate_half(q)
    return q * cos + rotated * sin
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with optional RoPE and flash attention.

    Queries, keys and values come from one fused projection of width
    ``3 * num_heads * head_dim``. The heads are merged back and projected to
    ``embed_dim``. The ``flash_attn`` kernel is attempted only in eval mode,
    with no ``attention_mask``, and for CUDA fp16/bf16 inputs. In every other
    case, or when the package is not installed, an explicit softmax
    implementation is used.

    Args:
        embed_dim: Width of the input and output hidden states.
        num_heads: Number of attention heads.
        head_dim: Width of each head.
        dropout: Dropout probability on the attention probabilities (training only).
        bias: Whether the projections carry a bias.
        max_position_embeddings: Forwarded to ``RotaryPositionalEmbedding``.
        use_rope: Apply rotary position embeddings to queries and keys.
        rope_theta: RoPE frequency base.
        use_flash_attention: Allow the ``flash_attn`` fast path (see above).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        bias: bool = True,
        max_position_embeddings: int = 2048,
        use_rope: bool = True,
        rope_theta: float = 10000.0,
        use_flash_attention: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout
        self.use_flash_attention = use_flash_attention

        # The head reshapes need num_heads * head_dim features per projection.
        # This equals embed_dim in the common case, where the layers are unchanged.
        self.qkv_proj = nn.Linear(embed_dim, self.inner_dim * 3, bias=bias)
        self.out_proj = nn.Linear(self.inner_dim, embed_dim, bias=bias)

        self.use_rope = use_rope
        if use_rope:
            self.rotary_emb = RotaryPositionalEmbedding(
                head_dim, max_position_embeddings, rope_theta
            )

        self.scale = head_dim ** -0.5
        # Unused by this class; kept so code that inspects the attribute still works.
        self.causal_mask = None

    @property
    def inner_dim(self) -> int:
        """Total width of all heads concatenated (``num_heads * head_dim``)."""
        return self.num_heads * self.head_dim

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend causally over ``x``.

        Args:
            x: Hidden states of shape (batch, seq_len, embed_dim).
            attention_mask: Optional additive mask added to the attention
                scores. It must broadcast to (batch, num_heads, seq_len, seq_len).
            position_ids: Accepted for interface compatibility but unused.
                RoPE always uses positions ``0 .. seq_len - 1``.

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).
        """
        batch_size, seq_len, _ = x.shape

        q, k, v = self.qkv_proj(x).split(self.inner_dim, dim=-1)

        # (batch, seq, inner) -> (batch, heads, seq, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        if self.use_rope:
            cos, sin = self.rotary_emb(seq_len, x.device, x.dtype)
            q = apply_rotary_pos_emb(q, cos, sin)
            k = apply_rotary_pos_emb(k, cos, sin)

        if self._can_use_flash(q, attention_mask):
            return self._flash_attention_forward(q, k, v, batch_size, seq_len)
        return self._standard_attention_forward(q, k, v, batch_size, seq_len, attention_mask)

    def _can_use_flash(self, q: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> bool:
        """Return True when the ``flash_attn`` fast path may be attempted."""
        return (
            self.use_flash_attention
            and not self.training
            # flash_attn_func is called without a padding mask, so a mask must
            # take the explicit path instead of being silently dropped.
            and attention_mask is None
            # The flash kernel only accepts CUDA fp16/bf16 tensors.
            and q.is_cuda
            and q.dtype in (torch.float16, torch.bfloat16)
        )

    def _flash_attention_forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        batch_size: int,
        seq_len: int,
    ) -> torch.Tensor:
        """Causal attention via ``flash_attn_func``.

        Falls back to the explicit path if the package is missing. ``q``, ``k``
        and ``v`` arrive as (batch, heads, seq, head_dim).
        """
        try:
            from flash_attn import flash_attn_func
        except ImportError:
            return self._standard_attention_forward(q, k, v, batch_size, seq_len, None)

        # flash_attn_func takes and returns (batch, seq, heads, head_dim).
        out = flash_attn_func(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            dropout_p=self.dropout if self.training else 0.0,
            softmax_scale=self.scale,
            causal=True,
        )

        # The output is already sequence-major, so merge heads directly.
        # Transposing here would interleave heads and positions.
        out = out.reshape(batch_size, seq_len, self.inner_dim)
        return self.out_proj(out)

    def _standard_attention_forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        batch_size: int,
        seq_len: int,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Explicit softmax attention with a causal mask.

        ``q``, ``k`` and ``v`` are (batch, heads, seq, head_dim).
        """
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # True strictly above the diagonal, i.e. future keys for each query.
        future = torch.triu(
            torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool),
            diagonal=1,
        )
        attn_weights = attn_weights.masked_fill(future, float("-inf"))

        attn_probs = F.softmax(attn_weights, dim=-1)
        if self.dropout > 0.0 and self.training:
            attn_probs = F.dropout(attn_probs, p=self.dropout)

        out = torch.matmul(attn_probs, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.inner_dim)
        return self.out_proj(out)
|
|
|
|
class FeedForward(nn.Module):
    """Feed-forward network with SwiGLU activation.

    ``gate_proj`` is a fused projection of width ``2 * ffn_hidden_dim`` whose
    two halves are the gate and the "up" branch. The output is
    ``down_proj(dropout(silu(gate) * up))``.

    Args:
        embed_dim: Input and output width.
        ffn_hidden_dim: Width of the gated hidden layer.
        dropout: Dropout probability applied to the gated hidden activations.
        bias: Whether the linear layers carry a bias.
    """

    def __init__(self, embed_dim: int, ffn_hidden_dim: int, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = ffn_hidden_dim

        # Fused [gate | up] projection. A separate up_proj would never be used.
        self.gate_proj = nn.Linear(embed_dim, ffn_hidden_dim * 2, bias=bias)
        self.down_proj = nn.Linear(ffn_hidden_dim, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU block to the last dimension of ``x``."""
        gate, up = self.gate_proj(x).chunk(2, dim=-1)
        return self.down_proj(self.dropout(self.act_fn(gate) * up))

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Older checkpoints contain the dead ``up_proj`` layer. Drop its tensors
        # so those checkpoints still load with strict=True.
        for suffix in ("up_proj.weight", "up_proj.bias"):
            state_dict.pop(prefix + suffix, None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: causal self-attention, then feed-forward.

    Each sub-layer sees a LayerNorm-ed input. Its output passes through
    ``dropout`` and is added back to the residual stream.

    Args:
        embed_dim: Hidden-state width.
        num_heads: Number of attention heads.
        head_dim: Width of each attention head.
        ffn_hidden_dim: Hidden width of the feed-forward network.
        dropout: Dropout applied to both sub-layer outputs before the residual add.
        attention_dropout: Dropout on attention probabilities.
        ffn_dropout: Dropout inside the feed-forward network.
        bias: Whether LayerNorms and projections carry a bias.
        max_position_embeddings: Forwarded to the attention layer.
        use_rope: Forwarded to the attention layer.
        rope_theta: Forwarded to the attention layer.
        use_flash_attention: Forwarded to the attention layer.
        use_gradient_checkpointing: Recompute the attention sub-layer in the
            backward pass instead of storing its activations (training only).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        ffn_hidden_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        ffn_dropout: float = 0.0,
        bias: bool = True,
        max_position_embeddings: int = 2048,
        use_rope: bool = True,
        rope_theta: float = 10000.0,
        use_flash_attention: bool = True,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim

        self.input_layernorm = nn.LayerNorm(embed_dim, bias=bias)
        self.self_attention = CausalSelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            dropout=attention_dropout,
            bias=bias,
            max_position_embeddings=max_position_embeddings,
            use_rope=use_rope,
            rope_theta=rope_theta,
            use_flash_attention=use_flash_attention,
        )

        self.post_attention_layernorm = nn.LayerNorm(embed_dim, bias=bias)
        self.feed_forward = FeedForward(
            embed_dim=embed_dim,
            ffn_hidden_dim=ffn_hidden_dim,
            dropout=ffn_dropout,
            bias=bias,
        )

        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention and feed-forward sub-layers with residual connections.

        ``attention_mask`` and ``position_ids`` are passed unchanged to the
        attention layer.
        """
        residual = x
        x = self.input_layernorm(x)

        if self.use_gradient_checkpointing and self.training:
            x = self._gradient_checkpointed_forward(x, attention_mask, position_ids)
        else:
            x = self.self_attention(x, attention_mask, position_ids)

        x = residual + self.dropout(x)

        residual = x
        x = self.post_attention_layernorm(x)
        x = self.feed_forward(x)
        x = residual + self.dropout(x)

        return x

    def _gradient_checkpointed_forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run the attention sub-layer under activation checkpointing."""
        # `import torch` does not guarantee the torch.utils.checkpoint submodule
        # is loaded, so import it explicitly rather than relying on a side effect.
        from torch.utils.checkpoint import checkpoint

        return checkpoint(
            self.self_attention,
            x,
            attention_mask,
            position_ids,
            use_reentrant=False,
        )
|
|
|
|
class CodsworthTransformer(nn.Module):
    """Codsworth causal transformer language model.

    Token embeddings (with dropout) pass through ``config.num_layers``
    ``TransformerBlock`` layers and a final LayerNorm. An ``lm_head`` whose
    weight is tied to the token embedding then produces vocabulary logits.
    """

    def __init__(self, config: CodsworthConfig):
        super().__init__()
        self.config = config

        self.vocab_embedding = nn.Embedding(
            config.vocab_size,
            config.embedding_dim,
            padding_idx=config.pad_token_id,
        )
        self.embedding_dropout = nn.Dropout(config.embedding_dropout)

        self.layers = nn.ModuleList([
            TransformerBlock(
                embed_dim=config.embedding_dim,
                num_heads=config.num_heads,
                head_dim=config.head_dim,
                ffn_hidden_dim=config.ffn_hidden_dim,
                dropout=config.dropout,
                attention_dropout=config.attention_dropout,
                ffn_dropout=config.ffn_dropout,
                bias=config.use_bias,
                max_position_embeddings=config.max_position_embeddings,
                use_rope=config.use_rope,
                rope_theta=config.rope_theta,
                use_flash_attention=config.use_flash_attention,
                use_gradient_checkpointing=config.use_gradient_checkpointing,
            )
            for _ in range(config.num_layers)
        ])

        self.final_layernorm = nn.LayerNorm(config.embedding_dim, bias=config.use_bias)

        self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)

        # Initialise before tying. If tying came first, the nn.Linear branch of
        # _init_weights would re-draw the shared matrix and undo the zeroed
        # padding row written by the nn.Embedding branch.
        self.apply(self._init_weights)
        self._tie_weights()

    def _tie_weights(self):
        """Share one weight matrix between ``lm_head`` and ``vocab_embedding``."""
        self.lm_head.weight = self.vocab_embedding.weight

    def _init_weights(self, module: nn.Module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and biases to zero.

        An embedding's padding row (if any) is zeroed.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict[str, Optional[torch.Tensor]]:
        """Run the model.

        Args:
            input_ids: Token ids; the attention layers expect (batch, seq_len).
            attention_mask: Passed unchanged to every layer.
            position_ids: Passed unchanged to every layer.
            labels: Optional targets aligned with ``input_ids``. When given,
                the next-token loss is computed (see ``_compute_loss``).

        Returns:
            Dict with ``"logits"``, ``"loss"`` (``None`` without labels) and
            ``"hidden_states"``. The hidden states are taken after ``final_layernorm``.
        """
        hidden_states = self.embedding_dropout(self.vocab_embedding(input_ids))

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, position_ids)

        hidden_states = self.final_layernorm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = self._compute_loss(logits, labels) if labels is not None else None

        return {
            "logits": logits,
            "loss": loss,
            "hidden_states": hidden_states,
        }

    def _compute_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Mean next-token cross-entropy: position ``t`` predicts ``labels[:, t + 1]``.

        Targets equal to ``config.pad_token_id`` (when set) are ignored, and so
        are targets equal to -100. This also works when ``pad_token_id`` is
        ``None``, which would otherwise be an invalid ``ignore_index``.
        """
        vocab_size = logits.shape[-1]
        shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
        shift_labels = labels[:, 1:].reshape(-1)

        pad_id = self.config.pad_token_id
        if pad_id is not None:
            # Fold padding into the single ignore_index used below.
            shift_labels = shift_labels.masked_fill(shift_labels == pad_id, -100)

        return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="mean")

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Autoregressively append up to ``max_new_tokens`` tokens to ``input_ids``.

        Runs without autograd and in eval mode. The caller's train/eval mode is
        restored afterwards.

        Args:
            input_ids: Prompt token ids of shape (batch, prompt_len).
            max_new_tokens: Maximum number of tokens to append.
            temperature: Logits are divided by this value. ``0`` selects greedy
                (argmax) decoding.
            top_k: Keep only the ``top_k`` largest logits. ``None`` or ``0`` disables this.
            top_p: Nucleus-sampling probability mass. ``None`` disables this.
            eos_token_id: If set, a row that emits it is finished. Its later
                positions are filled with ``eos_token_id``, and generation stops
                once every row has finished.

        Returns:
            The full prompt followed by the generated tokens. Only the model input
            is cropped to ``config.max_position_embeddings``; the returned
            sequence is never truncated.
        """
        was_training = self.training
        self.eval()
        try:
            generated = input_ids.clone()
            max_ctx = self.config.max_position_embeddings
            finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)

            for _ in range(max_new_tokens):
                # Feed at most the last max_ctx tokens while keeping the full history.
                context = generated[:, -max_ctx:]
                next_token_logits = self.forward(context)["logits"][:, -1, :]
                next_token = self._select_next_token(next_token_logits, temperature, top_k, top_p)

                if eos_token_id is not None:
                    # Rows that already finished keep emitting EOS.
                    next_token = next_token.masked_fill(finished.unsqueeze(-1), eos_token_id)
                    finished |= next_token.squeeze(-1) == eos_token_id

                generated = torch.cat([generated, next_token], dim=1)

                if eos_token_id is not None and bool(finished.all()):
                    break

            return generated
        finally:
            self.train(was_training)

    @staticmethod
    def _select_next_token(
        logits: torch.Tensor,
        temperature: float,
        top_k: Optional[int],
        top_p: Optional[float],
    ) -> torch.Tensor:
        """Pick one token id per row from (batch, vocab) logits.

        Returns a (batch, 1) tensor.
        """
        if temperature == 0:
            # Greedy decoding. top-k/top-p filtering can never remove the argmax.
            return logits.argmax(dim=-1, keepdim=True)

        logits = logits / temperature

        if top_k is not None and top_k > 0:
            k = min(top_k, logits.shape[-1])
            kth_largest = torch.topk(logits, k).values[..., -1:]
            logits = logits.masked_fill(logits < kth_largest, float("-inf"))

        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses top_p is kept; always keep the best.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            # Map the mask from sorted order back to vocabulary order in one op.
            to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
            logits = logits.masked_fill(to_remove, float("-inf"))

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def encode(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``forward(input_ids)["hidden_states"]`` without autograd.

        These are the hidden states after ``final_layernorm``. Dropout follows
        the module's current train/eval mode.
        """
        outputs = self.forward(input_ids)
        return outputs["hidden_states"]

    @torch.no_grad()
    def decode(
        self,
        hidden_states: torch.Tensor,
        apply_final_layernorm: bool = False,
    ) -> torch.Tensor:
        """Project hidden states to vocabulary logits with ``lm_head``.

        ``forward`` (and therefore ``encode``) already returns states after
        ``final_layernorm``, so the norm is not applied a second time by
        default. As a result, ``decode(encode(ids))`` equals
        ``forward(ids)["logits"]`` in eval mode. Pass
        ``apply_final_layernorm=True`` for states taken before the final norm.
        """
        if apply_final_layernorm:
            hidden_states = self.final_layernorm(hidden_states)
        return self.lm_head(hidden_states)

    def get_num_params(self, trainable_only: bool = False) -> int:
        """Count parameter elements, optionally only those with ``requires_grad``.

        ``self.parameters()`` yields each shared tensor once, so the tied
        embedding/lm_head matrix is counted a single time.
        """
        params = self.parameters()
        if trainable_only:
            return sum(p.numel() for p in params if p.requires_grad)
        return sum(p.numel() for p in params)