"""
10+2 Tied Transformer for English → Malay Translation
=======================================================
An asymmetric encoder-decoder Transformer built on ``torch.nn.Transformer``.
Architecture (redesigned for efficient T4 GPU training & inference):
d_model = 512 (embedding dimension, head_dim = 64)
n_head = 8 (attention heads)
encoder layers = 10 (deep encoder for source understanding)
decoder layers = 2 (shallow decoder for fast generation)
d_ff = 2048 (feed-forward inner dimension)
dropout = 0.1
norm_first = True (pre-norm for training stability)
shared embeddings = True (single vocab, en+ms share Latin script)
tied output proj. = True (output reuses embedding weights)
Key design choices (see architecture_report.md for full rationale):
• **Asymmetric depth (Kasai et al., 2021):** Encoder depth drives
translation quality; decoder depth can be aggressively reduced
with minimal quality loss and ~3× faster inference.
• **Shared vocabulary:** English and Malay both use Latin script with
significant lexical overlap (loanwords, numbers, proper nouns).
A joint BPE naturally captures cross-lingual subword patterns.
• **Tied output projection (Press & Wolf, 2017):** The decoder's output
  linear layer reuses the shared embedding matrix, saving vocab_size ×
  d_model weights (~8.4M at the 16K joint BPE vocabulary) and acting
  as a regulariser.
• **Pre-layer normalisation (Xiong et al., 2020):** Essential for stable
training of a 10-layer encoder. Places LayerNorm before each sublayer.
• Uses PyTorch's native ``nn.Transformer`` to keep FlashAttention /
SDPA fused kernels active (PyTorch 2.0+).
"""
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Positional Encoding (sinusoidal, from "Attention Is All You Need")
# ---------------------------------------------------------------------------
class PositionalEncoding(nn.Module):
"""
Inject positional information via fixed sinusoidal signals.
PE(pos, 2i) = sin(pos / 10000^{2i / d_model})
PE(pos, 2i+1) = cos(pos / 10000^{2i / d_model})
"""
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model) # (max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float() # (max_len, 1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
) # (d_model/2,)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, max_len, d_model)
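        # Buffer, not Parameter: saved in state_dict and moved by .to(device),
        # but excluded from gradient updates (the encoding is fixed).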
self.register_buffer("pe", pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, seq_len, d_model)
Returns:
(batch, seq_len, d_model) with positional encoding added.
"""
x = x + self.pe[:, : x.size(1)]
return self.dropout(x)
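# Quick shape check (illustrative only):
#   pe = PositionalEncoding(d_model=512)
#   out = pe(torch.zeros(2, 10, 512))  # -> (2, 10, 512), positions encoded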
# ---------------------------------------------------------------------------
# Full Transformer Model (10+2 Tied)
# ---------------------------------------------------------------------------
class TransformerTranslator(nn.Module):
"""
Asymmetric encoder-decoder Transformer with shared/tied embeddings.
Parameters
----------
vocab_size : int
Size of the shared source+target vocabulary.
d_model : int
Embedding / hidden dimension.
n_head : int
Number of attention heads.
num_encoder_layers : int
Number of encoder blocks (default 10).
num_decoder_layers : int
Number of decoder blocks (default 2).
d_ff : int
Feed-forward inner dimension.
dropout : float
Dropout rate.
max_len : int
Maximum sequence length for positional encoding.
pad_idx : int
Padding token ID (used to create padding masks).
"""
def __init__(
self,
vocab_size: int,
d_model: int = 512,
n_head: int = 8,
num_encoder_layers: int = 10,
num_decoder_layers: int = 2,
d_ff: int = 2048,
dropout: float = 0.1,
max_len: int = 512,
pad_idx: int = 0,
):
super().__init__()
self.pad_idx = pad_idx
self.d_model = d_model
# --- Shared embedding (one matrix for both enc & dec) -------------
self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
self.embed_scale = math.sqrt(d_model)
# --- Core Transformer (asymmetric, pre-norm) ----------------------
self.transformer = nn.Transformer(
d_model=d_model,
nhead=n_head,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=d_ff,
dropout=dropout,
batch_first=True,
norm_first=True, # pre-layer norm for stability
)
# --- Tied output projection (reuses embedding weights) ------------
# No separate nn.Linear — forward() uses F.linear with shared weights
self.output_bias = nn.Parameter(torch.zeros(vocab_size))
# --- Initialize weights -------------------------------------------
self._init_weights()
def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
"""Shared embedding + scale + positional encoding."""
return self.pos_encoding(self.shared_embedding(tokens) * self.embed_scale)
def _init_weights(self):
"""Xavier-uniform initialization for embeddings."""
nn.init.normal_(self.shared_embedding.weight, mean=0, std=self.d_model ** -0.5)
# Zero out padding embedding
with torch.no_grad():
self.shared_embedding.weight[self.pad_idx].zero_()
# ------------------------------------------------------------------
# Mask utilities
# ------------------------------------------------------------------
@staticmethod
def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
"""
Causal mask for the decoder: prevents attending to future positions.
Returns a (sz, sz) boolean mask where True = blocked.
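        Example (sz=3):
            [[False,  True,  True],
             [False, False,  True],
             [False, False, False]]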
"""
return torch.triu(torch.ones(sz, sz, device=device, dtype=torch.bool), diagonal=1)
def _make_pad_mask(self, x: torch.Tensor) -> torch.Tensor:
"""
Create a padding mask: True where token == pad_idx.
Shape: (batch, seq_len)
"""
return x == self.pad_idx
# ------------------------------------------------------------------
# Forward
# ------------------------------------------------------------------
def forward(
self,
src: torch.Tensor,
tgt: torch.Tensor,
src_key_padding_mask: Optional[torch.Tensor] = None,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
        Args:
            src: (batch, src_len) source token IDs.
            tgt: (batch, tgt_len) target token IDs (teacher-forced).
            src_key_padding_mask: optional (batch, src_len) bool mask,
                True at padded positions; built from ``pad_idx`` if omitted.
            tgt_key_padding_mask: optional (batch, tgt_len) bool mask,
                built from ``pad_idx`` if omitted.
        Returns:
            logits: (batch, tgt_len, vocab_size)
"""
# Build masks if not provided
if src_key_padding_mask is None:
src_key_padding_mask = self._make_pad_mask(src)
if tgt_key_padding_mask is None:
tgt_key_padding_mask = self._make_pad_mask(tgt)
# Causal mask for decoder
tgt_len = tgt.size(1)
tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)
# Shared embeddings for both encoder and decoder
src_emb = self._embed(src)
tgt_emb = self._embed(tgt)
# Transformer forward
out = self.transformer(
src=src_emb,
tgt=tgt_emb,
tgt_mask=tgt_mask,
src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=src_key_padding_mask,
) # (batch, tgt_len, d_model)
# Tied output projection: logits = out @ embedding_weights.T + bias
        logits = F.linear(out, self.shared_embedding.weight, self.output_bias)
return logits
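    # Usage note (illustrative, not part of the original module): a typical
    # training objective is token-level cross-entropy with padding ignored,
    # scoring shifted targets under teacher forcing, e.g.
    #   logits = model(src, tgt[:, :-1])
    #   loss = F.cross_entropy(
    #       logits.reshape(-1, logits.size(-1)),
    #       tgt[:, 1:].reshape(-1),
    #       ignore_index=model.pad_idx,
    #       label_smoothing=0.1,   # assumed; a common choice for NMT
    #   )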
# ------------------------------------------------------------------
# Inference helpers
# ------------------------------------------------------------------
def encode(self, src: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Run only the encoder. Returns memory: (batch, src_len, d_model)."""
if src_key_padding_mask is None:
src_key_padding_mask = self._make_pad_mask(src)
src_emb = self._embed(src)
return self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)
def decode(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Run only the decoder given encoder memory. Returns logits."""
if tgt_key_padding_mask is None:
tgt_key_padding_mask = self._make_pad_mask(tgt)
tgt_len = tgt.size(1)
tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)
tgt_emb = self._embed(tgt)
out = self.transformer.decoder(
tgt_emb,
memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
)
        return F.linear(out, self.shared_embedding.weight, self.output_bias)
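    # ------------------------------------------------------------------
    # Greedy decoding (illustrative sketch, not part of the original
    # module). Assumes the tokenizer defines BOS/EOS IDs distinct from
    # pad_idx; a production system would more likely use beam search.
    # ------------------------------------------------------------------
    @torch.no_grad()
    def greedy_decode(
        self,
        src: torch.Tensor,
        bos_id: int,
        eos_id: int,
        max_len: int = 128,
    ) -> torch.Tensor:
        """Left-to-right argmax decoding. Returns (batch, <=max_len) token IDs."""
        self.eval()
        src_pad_mask = self._make_pad_mask(src)
        memory = self.encode(src, src_key_padding_mask=src_pad_mask)
        ys = torch.full((src.size(0), 1), bos_id, dtype=torch.long, device=src.device)
        finished = torch.zeros(src.size(0), dtype=torch.bool, device=src.device)
        for _ in range(max_len - 1):
            logits = self.decode(ys, memory, memory_key_padding_mask=src_pad_mask)
            next_tok = logits[:, -1].argmax(dim=-1)            # (batch,)
            next_tok = next_tok.masked_fill(finished, eos_id)  # freeze done rows
            ys = torch.cat([ys, next_tok.unsqueeze(1)], dim=1)
            finished |= next_tok == eos_id
            if finished.all():
                break
        return ys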
# ---------------------------------------------------------------------------
# Helper: count parameters
# ---------------------------------------------------------------------------
def count_parameters(model: nn.Module) -> int:
"""Return the number of trainable parameters."""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# ---------------------------------------------------------------------------
# Helper: build model
# ---------------------------------------------------------------------------
def build_model(
vocab_size: int,
pad_idx: int = 0,
device: Optional[torch.device] = None,
**kwargs,
) -> TransformerTranslator:
"""
Build and return a TransformerTranslator with default hyperparameters.
Any kwarg (d_model, n_head, etc.) overrides the default.
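    Example (vocab size illustrative):
        >>> model = build_model(vocab_size=16_000, num_encoder_layers=10)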
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerTranslator(
vocab_size=vocab_size,
pad_idx=pad_idx,
**kwargs,
).to(device)
return model
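# ---------------------------------------------------------------------------
# Smoke test (illustrative; the 16_000 vocab below is an assumption matching
# the 16K joint BPE vocabulary described in the module docstring)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = build_model(vocab_size=16_000, device=torch.device("cpu"))
    print(f"Trainable parameters: {count_parameters(model):,}")
    src = torch.randint(1, 16_000, (2, 12))  # random IDs, avoiding pad_idx=0
    tgt = torch.randint(1, 16_000, (2, 9))
    logits = model(src, tgt)
    assert logits.shape == (2, 9, 16_000)
    print("forward OK:", tuple(logits.shape))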