"""
10+2 Tied Transformer for English → Malay Translation
=======================================================

An asymmetric encoder-decoder Transformer built on ``torch.nn.Transformer``.

Architecture (redesigned for efficient T4 GPU training & inference):

    d_model           = 512   (embedding dimension, head_dim = 64)
    n_head            = 8     (attention heads)
    encoder layers    = 10    (deep encoder for source understanding)
    decoder layers    = 2     (shallow decoder for fast generation)
    d_ff              = 2048  (feed-forward inner dimension)
    dropout           = 0.1
    norm_first        = True  (pre-norm for training stability)
    shared embeddings = True  (single vocab, en+ms share Latin script)
    tied output proj. = True  (output reuses embedding weights)

Key design choices (see architecture_report.md for full rationale):

• **Asymmetric depth (Kasai et al., 2021):** Encoder depth drives
  translation quality; decoder depth can be aggressively reduced
  with minimal quality loss and ~3× faster inference.
• **Shared vocabulary:** English and Malay both use Latin script with
  significant lexical overlap (loanwords, numbers, proper nouns).
  A joint BPE naturally captures cross-lingual subword patterns.
• **Tied output projection (Press & Wolf, 2017):** The decoder's output
  linear layer reuses the shared embedding matrix, saving ~26M params
  and acting as a regulariser.
• **Pre-layer normalisation (Xiong et al., 2020):** Essential for stable
  training of a 10-layer encoder. Places LayerNorm before each sublayer.
• Uses PyTorch's native ``nn.Transformer`` to keep FlashAttention /
  SDPA fused kernels active (PyTorch 2.0+).
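
Example
-------
A minimal usage sketch; the vocabulary size, device and tensor shapes below
are illustrative placeholders, not values fixed by this module::

    model = build_model(vocab_size=32_000, pad_idx=0, device=torch.device("cpu"))
    src = torch.randint(1, 32_000, (8, 24))   # (batch, src_len) source token IDs
    tgt = torch.randint(1, 32_000, (8, 20))   # (batch, tgt_len) teacher-forced target IDs
    logits = model(src, tgt)                  # (8, 20, 32_000)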
"""

from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """
    Inject positional information via fixed sinusoidal signals.

        PE(pos, 2i)   = sin(pos / 10000^{2i / d_model})
        PE(pos, 2i+1) = cos(pos / 10000^{2i / d_model})
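
    For example, with dropout disabled, the first position of any sequence gets
    the interleaved pattern ``[sin 0, cos 0, sin 0, cos 0] = [0, 1, 0, 1]``
    (an illustrative check only, not part of the model configuration)::

        pe = PositionalEncoding(d_model=4, dropout=0.0)
        pe(torch.zeros(1, 3, 4))[0, 0]   # tensor([0., 1., 0., 1.])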
"""

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the sinusoidal table once; shape (max_len, d_model).
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over the batch
        # Buffer: moved by .to() and saved in state_dict, but not trained.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            (batch, seq_len, d_model) with positional encoding added.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class TransformerTranslator(nn.Module):
    """
    Asymmetric encoder-decoder Transformer with shared/tied embeddings.

    Parameters
    ----------
    vocab_size : int
        Size of the shared source+target vocabulary.
    d_model : int
        Embedding / hidden dimension.
    n_head : int
        Number of attention heads.
    num_encoder_layers : int
        Number of encoder blocks (default 10).
    num_decoder_layers : int
        Number of decoder blocks (default 2).
    d_ff : int
        Feed-forward inner dimension.
    dropout : float
        Dropout rate.
    max_len : int
        Maximum sequence length for positional encoding.
    pad_idx : int
        Padding token ID (used to create padding masks).
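
    Example
    -------
    A construction sketch using the defaults above (the vocabulary size is an
    illustrative placeholder)::

        model = TransformerTranslator(vocab_size=32_000)
        len(model.transformer.encoder.layers)   # 10
        len(model.transformer.decoder.layers)   # 2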
"""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_head: int = 8,
        num_encoder_layers: int = 10,
        num_decoder_layers: int = 2,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 512,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model

        # Single embedding table shared by the encoder, the decoder and
        # (via weight tying) the output projection.
        self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.embed_scale = math.sqrt(d_model)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=n_head,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )

        # Bias for the tied output projection; the weight matrix itself is
        # self.shared_embedding.weight, reused in forward() and decode().
        self.output_bias = nn.Parameter(torch.zeros(vocab_size))

        self._init_weights()

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
        """Shared embedding + scale + positional encoding."""
        return self.pos_encoding(self.shared_embedding(tokens) * self.embed_scale)
def _init_weights(self):
|
|
|
"""Xavier-uniform initialization for embeddings."""
|
|
|
nn.init.normal_(self.shared_embedding.weight, mean=0, std=self.d_model ** -0.5)
|
|
|
|
|
|
with torch.no_grad():
|
|
|
self.shared_embedding.weight[self.pad_idx].zero_()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    @staticmethod
    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
        """
        Causal mask for the decoder: prevents attending to future positions.
        Returns a (sz, sz) boolean mask where True = blocked.
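
        For example (sz=3; purely illustrative)::

            generate_square_subsequent_mask(3, torch.device("cpu"))
            # tensor([[False,  True,  True],
            #         [False, False,  True],
            #         [False, False, False]])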
"""
        return torch.triu(torch.ones(sz, sz, device=device, dtype=torch.bool), diagonal=1)

    def _make_pad_mask(self, x: torch.Tensor) -> torch.Tensor:
        """
        Create a padding mask: True where token == pad_idx.
        Shape: (batch, seq_len)
        """
        return x == self.pad_idx

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            src: (batch, src_len) source token IDs.
            tgt: (batch, tgt_len) target token IDs (teacher-forced).

        Returns:
            logits: (batch, tgt_len, vocab_size)
        """
        # Build padding masks from pad_idx unless the caller supplies them.
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)

        # Causal mask so each target position only attends to earlier positions.
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)

        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)

        out = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )

        # Tied output projection: reuse the shared embedding matrix as the
        # output weight (Press & Wolf, 2017).
        logits = torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)
        return logits

    def encode(self, src: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run only the encoder. Returns memory: (batch, src_len, d_model)."""
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        src_emb = self._embed(src)
        return self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run only the decoder given encoder memory. Returns logits."""
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)
        tgt_emb = self._embed(tgt)
        out = self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        # Same tied projection as in forward().
        return torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)


def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def build_model(
    vocab_size: int,
    pad_idx: int = 0,
    device: Optional[torch.device] = None,
    **kwargs,
) -> TransformerTranslator:
    """
    Build and return a TransformerTranslator with default hyperparameters.

    Any kwarg (d_model, n_head, etc.) overrides the default.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = TransformerTranslator(
        vocab_size=vocab_size,
        pad_idx=pad_idx,
        **kwargs,
    ).to(device)

    return model
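

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only; not part of the training
# pipeline). The vocabulary size, BOS/EOS/PAD IDs and sequence lengths below
# are arbitrary placeholders. It exercises the teacher-forced forward pass and
# shows how the encode()/decode() split lets greedy inference run the 10-layer
# encoder once while looping only the 2-layer decoder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    PAD, BOS, EOS, VOCAB = 0, 1, 2, 1000  # placeholder special-token IDs / vocab size
    model = build_model(vocab_size=VOCAB, pad_idx=PAD, device=torch.device("cpu"))
    model.eval()
    print(f"trainable parameters: {count_parameters(model):,}")

    src = torch.randint(3, VOCAB, (2, 12))  # (batch, src_len) random "source" tokens
    tgt = torch.randint(3, VOCAB, (2, 10))  # (batch, tgt_len) random "target" tokens
    print("teacher-forced logits:", tuple(model(src, tgt).shape))  # (2, 10, 1000)

    # Greedy decoding sketch: encode once, then extend the hypothesis one token
    # at a time with the shallow decoder.
    with torch.no_grad():
        memory = model.encode(src)
        memory_pad_mask = src.eq(PAD)
        ys = torch.full((src.size(0), 1), BOS, dtype=torch.long)
        for _ in range(20):  # placeholder maximum output length
            logits = model.decode(ys, memory, memory_key_padding_mask=memory_pad_mask)
            next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
            ys = torch.cat([ys, next_tok], dim=1)
            if bool((next_tok == EOS).all()):
                break
    print("greedy hypothesis shape:", tuple(ys.shape))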