Spaces:
Sleeping
Sleeping
| # encode.py | |
| from dataclasses import dataclass | |
| from typing import Optional, Tuple | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
@dataclass
class EncoderConfig:
    """Hyperparameters for the Transformer encoder.

    Declared as a dataclass so it can be built with keyword arguments, e.g.
    ``EncoderConfig(src_vocab_size=32000)``. ``src_vocab_size`` has no default
    and must be supplied; every other field falls back to the value below.
    """

    # Vocabulary size for source language (set from tokenizer)
    src_vocab_size: int

    # Model dimensions
    embed_dim: int = 512
    ff_hidden_dim: int = 2048
    num_heads: int = 8
    num_layers: int = 6

    # Regularization
    dropout: float = 0.1

    # Max sequence length for positional embeddings
    max_position_embeddings: int = 1024

    # Special tokens
    pad_token_id: int = 0

    # Initialization scale (optional, small init helps stability)
    init_range: float = 0.02
class TokenPositionalEmbedding(nn.Module):
    """
    Token embedding + learned positional embedding, followed by dropout.

    Shapes:
    - input_ids: [B, S]
    - return: [B, S, D]

    Raises:
        IndexError: if S exceeds the size of the positional embedding table
            (``max_position_embeddings``).
    """

    def __init__(self, vocab_size: int, embed_dim: int,
                 max_position_embeddings: int, pad_token_id: int, dropout: float):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
        self.pos_embedding = nn.Embedding(max_position_embeddings, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Fail early with a readable message instead of letting the positional
        # lookup below index past the end of its table.
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_position_embeddings ({max_len})"
            )

        # [B, S] absolute positions 0..S-1, repeated for every batch row
        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, seq_len)
        x = self.token_embedding(input_ids) + self.pos_embedding(positions)
        return self.dropout(x)  # [B, S, D]
class MultiHeadSelfAttention(nn.Module):
    """
    Standard multi-head self-attention (Q=K=V) with padding mask support.

    Shapes:
    - x: [B, S, D]
    - key_padding_mask: [B, S]; True / non-zero marks tokens to keep,
      False / 0 marks PAD. Non-bool masks are converted with ``!= 0``.
    - return: [B, S, D]

    PAD keys are masked with the most negative finite value of the score
    dtype rather than ``-inf``. For rows with at least one valid key the
    result is unchanged (masked keys get exactly zero weight); for rows whose
    keys are all PAD the softmax is uniform instead of NaN.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.FloatTensor, key_padding_mask: torch.Tensor) -> torch.FloatTensor:
        B, S, D = x.shape

        # Project to multihead Q, K, V: [B, S, H*Hd] -> [B, H, S, Hd]
        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores: [B, H, S, Hd] @ [B, H, Hd, S] -> [B, H, S, S]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Build a broadcastable keep-mask over the keys dimension: [B, 1, 1, S]
        if key_padding_mask.dtype != torch.bool:
            keep_mask = key_padding_mask != 0
        else:
            keep_mask = key_padding_mask
        keep_mask = keep_mask.unsqueeze(1).unsqueeze(1)  # [B,1,1,S]

        # Mask PAD keys with the lowest finite value (not -inf): softmax of an
        # all -inf row is NaN, whereas an all-finite row stays well defined.
        mask_value = torch.finfo(attn_scores.dtype).min
        attn_scores = attn_scores.masked_fill(~keep_mask, mask_value)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Weighted sum of values: [B, H, S, S] @ [B, H, S, Hd] -> [B, H, S, Hd]
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads: [B, H, S, Hd] -> [B, S, H*Hd=D]
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, D)
        return self.out_proj(attn_output)
class FeedForward(nn.Module):
    """
    Two-layer position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Every position is transformed independently.

    Shapes:
    - x: [B, S, D] -> [B, S, D]
    """

    def __init__(self, embed_dim: int, hidden_dim: int, dropout: float):
        super().__init__()
        # Expansion D -> hidden, then contraction hidden -> D.
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        # One dropout module is shared by both dropout sites in forward().
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        # Expand, apply the non-linearity, regularize.
        hidden = self.dropout(self.activation(self.fc1(x)))
        # Project back to the model width and regularize again.
        return self.dropout(self.fc2(hidden))
class EncoderBlock(nn.Module):
    """
    A single Pre-LN encoder layer.

    Each of the two sub-layers (self-attention, then feed-forward) normalizes
    its input first, and its dropped-out output is added back onto the
    residual stream.
    """

    def __init__(self, embed_dim: int, num_heads: int, ff_hidden_dim: int, dropout: float):
        super().__init__()
        # Attention sub-layer: norm, attention, residual dropout.
        self.ln1 = nn.LayerNorm(embed_dim)
        self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.dropout1 = nn.Dropout(dropout)
        # Feed-forward sub-layer: norm, MLP, residual dropout.
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = FeedForward(embed_dim, ff_hidden_dim, dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.FloatTensor, key_padding_mask: torch.Tensor) -> torch.FloatTensor:
        # Residual branch 1: x + Dropout(SelfAttn(LN(x)))
        normed = self.ln1(x)
        attended = self.self_attn(normed, key_padding_mask=key_padding_mask)
        hidden = x + self.dropout1(attended)

        # Residual branch 2: h + Dropout(FFN(LN(h)))
        transformed = self.ff(self.ln2(hidden))
        return hidden + self.dropout2(transformed)
class Encoder(nn.Module):
    """
    Full encoder: embeddings -> N blocks -> final LayerNorm -> zero PAD positions.

    Forward signature:
        encoder_hidden_states = Encoder(config)(src_input_ids, src_attention_mask)

    Shapes:
    - src_input_ids: [B, S]
    - src_attention_mask: [B, S] (1/True = token, 0/False = PAD)
    - return: [B, S, D], with every PAD position set to exactly zero
    """

    def __init__(self, config: "EncoderConfig"):
        super().__init__()
        self.config = config
        assert config.embed_dim % config.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embeddings = TokenPositionalEmbedding(
            vocab_size=config.src_vocab_size,
            embed_dim=config.embed_dim,
            max_position_embeddings=config.max_position_embeddings,
            pad_token_id=config.pad_token_id,
            dropout=config.dropout,
        )
        self.layers = nn.ModuleList([
            EncoderBlock(
                embed_dim=config.embed_dim,
                num_heads=config.num_heads,
                ff_hidden_dim=config.ff_hidden_dim,
                dropout=config.dropout,
            )
            for _ in range(config.num_layers)
        ])
        self.final_ln = nn.LayerNorm(config.embed_dim)

        # Re-initialize every Linear / Embedding submodule registered above.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize Linear/Embedding weights from N(0, init_range); zero biases and the pad row."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.init_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.init_range)
            # The normal init above overwrote the padding row; reset it to zero.
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].fill_(0.0)

    def _ensure_mask_dtype(self, mask: torch.Tensor) -> torch.Tensor:
        """Accept a bool or 0/1 mask; return a bool mask where True means "keep"."""
        return mask.bool() if mask.dtype != torch.bool else mask

    def forward(
        self,
        src_input_ids: torch.LongTensor,   # [B, S]
        src_attention_mask: torch.Tensor,  # [B, S] (1/True=token, 0/False=PAD)
    ) -> torch.FloatTensor:
        x = self.embeddings(src_input_ids)  # [B, S, D]
        keep_mask = self._ensure_mask_dtype(src_attention_mask)

        for layer in self.layers:
            x = layer(x, key_padding_mask=keep_mask)

        x = self.final_ln(x)

        # Zero PAD positions by overwriting rather than multiplying by the
        # mask: NaN * 0 is NaN, so a multiply would leak any non-finite value
        # sitting at a padded position into the output.
        x = x.masked_fill(~keep_mask.unsqueeze(-1), 0.0)
        return x  # [B, S, D]