# encode.py
from dataclasses import dataclass
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class EncoderConfig:
# Vocabulary size for source language (set from tokenizer)
src_vocab_size: int
# Model dimensions
embed_dim: int = 512
ff_hidden_dim: int = 2048
num_heads: int = 8
num_layers: int = 6
# Regularization
dropout: float = 0.1
# Max sequence length for positional embeddings
max_position_embeddings: int = 1024
# Special tokens
pad_token_id: int = 0
    # Std dev of the normal distribution used for weight init (small init helps stability)
    init_range: float = 0.02
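
# Example instantiation (hypothetical sizes; src_vocab_size normally comes
# from the tokenizer):
#   config = EncoderConfig(src_vocab_size=32000, max_position_embeddings=512)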
class TokenPositionalEmbedding(nn.Module):
"""
Token embedding + learned positional embedding.
Shapes:
- input_ids: [B, S]
- return: [B, S, D]
"""
def __init__(self, vocab_size: int, embed_dim: int,
max_position_embeddings: int, pad_token_id: int, dropout: float):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
self.pos_embedding = nn.Embedding(max_position_embeddings, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
batch_size, seq_len = input_ids.shape
device = input_ids.device
        # Fail fast if the sequence outgrows the learned position table.
        assert seq_len <= self.pos_embedding.num_embeddings, \
            "seq_len exceeds max_position_embeddings"
        # Absolute positions 0..S-1, broadcast across the batch: [B, S]
        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, seq_len)
x = self.token_embedding(input_ids) + self.pos_embedding(positions)
return self.dropout(x) # [B, S, D]
class MultiHeadSelfAttention(nn.Module):
"""
Standard MHA (Q=K=V) with padding mask support.
Shapes:
- x: [B, S, D]
    - key_padding_mask: [B, S]; True/1 marks real tokens to keep, False/0 marks
      PAD (non-bool masks are converted to a bool keep-mask internally)
- return: [B, S, D]
"""
def __init__(self, embed_dim: int, num_heads: int, dropout: float):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
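        # Scaled dot-product attention divides scores by sqrt(head_dim):
        #   attn = softmax(Q @ K^T * scale) @ V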
self.scale = 1.0 / math.sqrt(self.head_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.attn_dropout = nn.Dropout(dropout)
def forward(self, x: torch.FloatTensor, key_padding_mask: torch.Tensor) -> torch.FloatTensor:
B, S, D = x.shape
# Project to multihead Q, K, V: [B, S, H*Hd] -> [B, H, S, Hd]
q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
# Attention scores: [B, H, S, Hd] @ [B, H, Hd, S] -> [B, H, S, S]
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# Build broadcastable mask over keys dimension: [B, 1, 1, S]
# key_padding_mask is 1/True for valid tokens; 0/False for PADs.
if key_padding_mask.dtype != torch.bool:
keep_mask = key_padding_mask != 0
else:
keep_mask = key_padding_mask
keep_mask = keep_mask.unsqueeze(1).unsqueeze(1) # [B,1,1,S]
        # Mask PAD keys by setting their scores to -inf so softmax assigns them
        # zero weight. (Caveat: a row with no valid keys, i.e. an all-PAD
        # sequence, would softmax over all -inf and produce NaNs.)
        attn_scores = attn_scores.masked_fill(~keep_mask, float("-inf"))
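        # e.g. S=4 with mask [1, 1, 1, 0]: column 3 of every [S, S] score
        # matrix becomes -inf and receives zero weight after the softmax.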
attn_weights = F.softmax(attn_scores, dim=-1)
attn_weights = self.attn_dropout(attn_weights)
# Weighted sum of values: [B, H, S, S] @ [B, H, S, Hd] -> [B, H, S, Hd]
attn_output = torch.matmul(attn_weights, v)
# Merge heads: [B, H, S, Hd] -> [B, S, H*Hd=D]
attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, D)
return self.out_proj(attn_output)
class FeedForward(nn.Module):
"""
Position-wise MLP applied to each position independently.
Shapes:
- x: [B, S, D] -> [B, S, D]
"""
def __init__(self, embed_dim: int, hidden_dim: int, dropout: float):
super().__init__()
self.fc1 = nn.Linear(embed_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
self.activation = nn.GELU()
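        # GELU here, rather than the ReLU of the original Transformer FFN.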
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
x = self.fc1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.fc2(x)
return self.dropout(x)
class EncoderBlock(nn.Module):
"""
One Pre-LN encoder block: LN -> MHA -> resid, then LN -> FFN -> resid.
"""
def __init__(self, embed_dim: int, num_heads: int, ff_hidden_dim: int, dropout: float):
super().__init__()
self.ln1 = nn.LayerNorm(embed_dim)
self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
self.dropout1 = nn.Dropout(dropout)
self.ln2 = nn.LayerNorm(embed_dim)
self.ff = FeedForward(embed_dim, ff_hidden_dim, dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x: torch.FloatTensor, key_padding_mask: torch.Tensor) -> torch.FloatTensor:
# Self-attention sub-layer (Pre-LN)
attn_out = self.self_attn(self.ln1(x), key_padding_mask=key_padding_mask)
x = x + self.dropout1(attn_out)
# Feedforward sub-layer (Pre-LN)
ff_out = self.ff(self.ln2(x))
x = x + self.dropout2(ff_out)
return x
class Encoder(nn.Module):
"""
Full encoder: embeddings -> N blocks -> final LayerNorm.
Forward signature:
encoder_hidden_states = Encoder(config)(src_input_ids, src_attention_mask)
"""
def __init__(self, config: EncoderConfig):
super().__init__()
self.config = config
assert config.embed_dim % config.num_heads == 0, "embed_dim must be divisible by num_heads"
self.embeddings = TokenPositionalEmbedding(
vocab_size=config.src_vocab_size,
embed_dim=config.embed_dim,
max_position_embeddings=config.max_position_embeddings,
pad_token_id=config.pad_token_id,
dropout=config.dropout,
)
self.layers = nn.ModuleList([
EncoderBlock(
embed_dim=config.embed_dim,
num_heads=config.num_heads,
ff_hidden_dim=config.ff_hidden_dim,
dropout=config.dropout,
)
for _ in range(config.num_layers)
])
self.final_ln = nn.LayerNorm(config.embed_dim)
self.apply(self._init_weights)
def _init_weights(self, module: nn.Module) -> None:
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=self.config.init_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.init_range)
            # Re-zero the PAD row: the normal_ init above overwrites the zeros
            # that padding_idx set at construction time.
if module.padding_idx is not None:
with torch.no_grad():
module.weight[module.padding_idx].fill_(0.0)
    def _ensure_mask_dtype(self, mask: torch.Tensor) -> torch.Tensor:
        # Accept bool or 0/1 integer masks; return bool where True means "keep".
        return mask.bool() if mask.dtype != torch.bool else mask
def forward(
self,
src_input_ids: torch.LongTensor, # [B, S]
src_attention_mask: torch.Tensor, # [B, S] (1/True=token, 0/False=PAD)
) -> torch.FloatTensor:
x = self.embeddings(src_input_ids) # [B, S, D]
keep_mask = self._ensure_mask_dtype(src_attention_mask)
for layer in self.layers:
x = layer(x, key_padding_mask=keep_mask)
x = self.final_ln(x)
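        # Zero out hidden states at PAD positions so padding carries no signal
        # into downstream consumers (e.g. decoder cross-attention).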
x = x * keep_mask.unsqueeze(-1)
return x # [B, S, D]
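
if __name__ == "__main__":
    # Minimal smoke test (a sketch with hypothetical sizes, not part of the
    # original module): build a tiny encoder and push a padded batch through
    # it to confirm the shapes documented in the class docstrings.
    torch.manual_seed(0)
    config = EncoderConfig(
        src_vocab_size=100, embed_dim=64, ff_hidden_dim=128,
        num_heads=4, num_layers=2,
    )
    encoder = Encoder(config).eval()  # disable dropout for a deterministic check
    input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])       # 0 = pad_token_id
    attention_mask = (input_ids != config.pad_token_id).long()   # 1=token, 0=PAD
    with torch.no_grad():
        hidden = encoder(input_ids, attention_mask)
    assert hidden.shape == (2, 4, config.embed_dim)  # [B, S, D]
    # PAD positions are zeroed by the final mask multiply in Encoder.forward.
    assert torch.all(hidden[0, 3] == 0) and torch.all(hidden[1, 2:] == 0)
    print("ok:", tuple(hidden.shape))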