""" 10+2 Tied Transformer for English → Malay Translation ======================================================= An asymmetric encoder-decoder Transformer built on ``torch.nn.Transformer``. Architecture (redesigned for efficient T4 GPU training & inference): d_model = 512 (embedding dimension, head_dim = 64) n_head = 8 (attention heads) encoder layers = 10 (deep encoder for source understanding) decoder layers = 2 (shallow decoder for fast generation) d_ff = 2048 (feed-forward inner dimension) dropout = 0.1 norm_first = True (pre-norm for training stability) shared embeddings = True (single vocab, en+ms share Latin script) tied output proj. = True (output reuses embedding weights) Key design choices (see architecture_report.md for full rationale): • **Asymmetric depth (Kasai et al., 2021):** Encoder depth drives translation quality; decoder depth can be aggressively reduced with minimal quality loss and ~3× faster inference. • **Shared vocabulary:** English and Malay both use Latin script with significant lexical overlap (loanwords, numbers, proper nouns). A joint BPE naturally captures cross-lingual subword patterns. • **Tied output projection (Press & Wolf, 2017):** The decoder's output linear layer reuses the shared embedding matrix, saving ~26M params and acting as a regulariser. • **Pre-layer normalisation (Xiong et al., 2020):** Essential for stable training of a 10-layer encoder. Places LayerNorm before each sublayer. • Uses PyTorch's native ``nn.Transformer`` to keep FlashAttention / SDPA fused kernels active (PyTorch 2.0+). """ from __future__ import annotations import math from typing import Optional import torch import torch.nn as nn # --------------------------------------------------------------------------- # Positional Encoding (sinusoidal, from "Attention Is All You Need") # --------------------------------------------------------------------------- class PositionalEncoding(nn.Module): """ Inject positional information via fixed sinusoidal signals. PE(pos, 2i) = sin(pos / 10000^{2i / d_model}) PE(pos, 2i+1) = cos(pos / 10000^{2i / d_model}) """ def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1): super().__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) # (max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1).float() # (max_len, 1) div_term = torch.exp( torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) ) # (d_model/2,) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) # (1, max_len, d_model) self.register_buffer("pe", pe) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (batch, seq_len, d_model) Returns: (batch, seq_len, d_model) with positional encoding added. """ x = x + self.pe[:, : x.size(1)] return self.dropout(x) # --------------------------------------------------------------------------- # Full Transformer Model (10+2 Tied) # --------------------------------------------------------------------------- class TransformerTranslator(nn.Module): """ Asymmetric encoder-decoder Transformer with shared/tied embeddings. Parameters ---------- vocab_size : int Size of the shared source+target vocabulary. d_model : int Embedding / hidden dimension. n_head : int Number of attention heads. num_encoder_layers : int Number of encoder blocks (default 10). num_decoder_layers : int Number of decoder blocks (default 2). d_ff : int Feed-forward inner dimension. dropout : float Dropout rate. 

# ---------------------------------------------------------------------------
# Full Transformer Model (10+2 Tied)
# ---------------------------------------------------------------------------
class TransformerTranslator(nn.Module):
    """
    Asymmetric encoder-decoder Transformer with shared/tied embeddings.

    Parameters
    ----------
    vocab_size : int
        Size of the shared source+target vocabulary.
    d_model : int
        Embedding / hidden dimension.
    n_head : int
        Number of attention heads.
    num_encoder_layers : int
        Number of encoder blocks (default 10).
    num_decoder_layers : int
        Number of decoder blocks (default 2).
    d_ff : int
        Feed-forward inner dimension.
    dropout : float
        Dropout rate.
    max_len : int
        Maximum sequence length for positional encoding.
    pad_idx : int
        Padding token ID (used to create padding masks).
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_head: int = 8,
        num_encoder_layers: int = 10,
        num_decoder_layers: int = 2,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 512,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model

        # --- Shared embedding (one matrix for both enc & dec) -------------
        self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.embed_scale = math.sqrt(d_model)

        # --- Core Transformer (asymmetric, pre-norm) ----------------------
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=n_head,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
            norm_first=True,  # pre-layer norm for stability
        )

        # --- Tied output projection (reuses embedding weights) ------------
        # No separate nn.Linear: forward() applies torch.nn.functional.linear
        # with the shared embedding weights; only the bias is a new parameter.
        self.output_bias = nn.Parameter(torch.zeros(vocab_size))

        # --- Initialize weights --------------------------------------------
        self._init_weights()

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
        """Shared embedding + scale + positional encoding."""
        return self.pos_encoding(self.shared_embedding(tokens) * self.embed_scale)

    def _init_weights(self):
        """Scaled-normal initialization (std = d_model^-0.5) for the shared embedding."""
        nn.init.normal_(self.shared_embedding.weight, mean=0, std=self.d_model ** -0.5)
        # Zero out padding embedding
        with torch.no_grad():
            self.shared_embedding.weight[self.pad_idx].zero_()

    # ------------------------------------------------------------------
    # Mask utilities
    # ------------------------------------------------------------------
    @staticmethod
    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
        """
        Causal mask for the decoder: prevents attending to future positions.
        Returns a (sz, sz) boolean mask where True = blocked.
        """
        return torch.triu(torch.ones(sz, sz, device=device, dtype=torch.bool), diagonal=1)

    def _make_pad_mask(self, x: torch.Tensor) -> torch.Tensor:
        """
        Create a padding mask: True where token == pad_idx.
        Shape: (batch, seq_len)
        """
        return x == self.pad_idx
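
    # For reference, a sketch of the mask this produces:
    # generate_square_subsequent_mask(3, device) yields
    #
    #     [[False,  True,  True],
    #      [False, False,  True],
    #      [False, False, False]]
    #
    # i.e. position i may only attend to positions <= i; True entries are
    # blocked, matching the boolean-mask convention of ``nn.Transformer``.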

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            src: (batch, src_len) source token IDs.
            tgt: (batch, tgt_len) target token IDs (teacher-forced).
        Returns:
            logits: (batch, tgt_len, vocab_size)
        """
        # Build masks if not provided
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)

        # Causal mask for decoder
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)

        # Shared embeddings for both encoder and decoder
        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)

        # Transformer forward
        out = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )  # (batch, tgt_len, d_model)

        # Tied output projection: logits = out @ embedding_weights.T + bias
        logits = torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)
        return logits

    # ------------------------------------------------------------------
    # Inference helpers
    # ------------------------------------------------------------------
    def encode(self, src: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run only the encoder. Returns memory: (batch, src_len, d_model)."""
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        src_emb = self._embed(src)
        return self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run only the decoder given encoder memory. Returns logits."""
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)
        tgt_emb = self._embed(tgt)
        out = self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)


# ---------------------------------------------------------------------------
# Helper: count parameters
# ---------------------------------------------------------------------------
def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# ---------------------------------------------------------------------------
# Helper: build model
# ---------------------------------------------------------------------------
def build_model(
    vocab_size: int,
    pad_idx: int = 0,
    device: Optional[torch.device] = None,
    **kwargs,
) -> TransformerTranslator:
    """
    Build and return a TransformerTranslator with default hyperparameters.
    Any kwarg (d_model, n_head, etc.) overrides the default.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TransformerTranslator(
        vocab_size=vocab_size,
        pad_idx=pad_idx,
        **kwargs,
    ).to(device)
    return model
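

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): builds a toy-sized model, runs one
# teacher-forced forward pass, then greedily decodes a few steps with the
# encode()/decode() helpers. The vocab size and the BOS/EOS/PAD IDs (1, 2, 0)
# are placeholder assumptions, not values prescribed by this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model = build_model(
        vocab_size=1000,
        pad_idx=0,
        device=torch.device("cpu"),
        num_encoder_layers=2,   # shrunk from 10 for the demo
        num_decoder_layers=1,   # shrunk from 2 for the demo
    )
    model.eval()
    print(f"trainable parameters: {count_parameters(model):,}")

    src = torch.randint(3, 1000, (2, 7))      # (batch=2, src_len=7), IDs >= 3 avoid PAD/BOS/EOS
    tgt_in = torch.randint(3, 1000, (2, 5))   # teacher-forced decoder input
    with torch.no_grad():
        logits = model(src, tgt_in)
    print("forward logits shape:", tuple(logits.shape))   # (2, 5, 1000)

    # Greedy decoding with the inference helpers (assumed BOS=1, EOS=2).
    bos_id, eos_id, max_new_tokens = 1, 2, 10
    with torch.no_grad():
        memory = model.encode(src)
        src_pad_mask = model._make_pad_mask(src)
        ys = torch.full((src.size(0), 1), bos_id, dtype=torch.long)
        for _ in range(max_new_tokens):
            step_logits = model.decode(ys, memory, memory_key_padding_mask=src_pad_mask)
            next_token = step_logits[:, -1].argmax(dim=-1, keepdim=True)
            ys = torch.cat([ys, next_token], dim=1)
            if (next_token == eos_id).all():
                break
    print("greedy output IDs:", ys.tolist())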