File size: 11,508 Bytes
e7f17a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 |
"""
10+2 Tied Transformer for English → Malay Translation
=======================================================
An asymmetric encoder-decoder Transformer built on ``torch.nn.Transformer``.
Architecture (redesigned for efficient T4 GPU training & inference):
d_model = 512 (embedding dimension, head_dim = 64)
n_head = 8 (attention heads)
encoder layers = 10 (deep encoder for source understanding)
decoder layers = 2 (shallow decoder for fast generation)
d_ff = 2048 (feed-forward inner dimension)
dropout = 0.1
norm_first = True (pre-norm for training stability)
shared embeddings = True (single vocab, en+ms share Latin script)
tied output proj. = True (output reuses embedding weights)
Key design choices (see architecture_report.md for full rationale):
• **Asymmetric depth (Kasai et al., 2021):** Encoder depth drives
translation quality; decoder depth can be aggressively reduced
with minimal quality loss and ~3× faster inference.
• **Shared vocabulary:** English and Malay both use Latin script with
significant lexical overlap (loanwords, numbers, proper nouns).
A joint BPE naturally captures cross-lingual subword patterns.
• **Tied output projection (Press & Wolf, 2017):** The decoder's output
linear layer reuses the shared embedding matrix. This saves vocab_size × d_model
parameters (~26M for a ~50k-token vocabulary at d_model = 512) and acts as
a regulariser.
• **Pre-layer normalisation (Xiong et al., 2020):** Essential for stable
training of a 10-layer encoder. Places LayerNorm before each sublayer.
• Uses PyTorch's native ``nn.Transformer`` to keep FlashAttention /
SDPA fused kernels active (PyTorch 2.0+).
"""
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Positional Encoding (sinusoidal, from "Attention Is All You Need")
# ---------------------------------------------------------------------------
class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal position signals to a batch of embeddings.

        PE(pos, 2i)   = sin(pos / 10000^{2i / d_model})
        PE(pos, 2i+1) = cos(pos / 10000^{2i / d_model})

    The table is precomputed once for ``max_len`` positions and registered
    as a buffer (not a trainable parameter). Dropout is applied to the sum.
    An odd ``d_model`` is supported: the last (even-indexed) channel gets a
    sine with no matching cosine channel.
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd channels. For odd d_model that
        # is one fewer than the number of frequencies, so drop the last one.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            (batch, seq_len, d_model) with positional encoding added.
        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Fail with an explicit message instead of an opaque broadcast error.
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {max_len}"
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)
# ---------------------------------------------------------------------------
# Full Transformer Model (10+2 Tied)
# ---------------------------------------------------------------------------
class TransformerTranslator(nn.Module):
    """
    Asymmetric encoder-decoder Transformer with shared/tied embeddings.

    A single ``nn.Embedding`` feeds both the encoder and the decoder. Its
    weight matrix is reused (transposed) as the output projection, so
    ``logits = decoder_out @ shared_embedding.weight.T + output_bias``.

    Parameters
    ----------
    vocab_size : int
        Size of the shared source+target vocabulary.
    d_model : int
        Embedding / hidden dimension.
    n_head : int
        Number of attention heads.
    num_encoder_layers : int
        Number of encoder blocks (default 10).
    num_decoder_layers : int
        Number of decoder blocks (default 2).
    d_ff : int
        Feed-forward inner dimension.
    dropout : float
        Dropout rate (used inside the Transformer and after positional encoding).
    max_len : int
        Maximum sequence length for positional encoding.
    pad_idx : int
        Padding token ID. It is used to build padding masks, and its
        embedding row is zero-initialised.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_head: int = 8,
        num_encoder_layers: int = 10,
        num_decoder_layers: int = 2,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 512,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model

        # --- Shared embedding (one matrix for encoder, decoder and output) ---
        self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.embed_scale = math.sqrt(d_model)

        # --- Core Transformer (asymmetric, pre-norm) ---------------------
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=n_head,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
            norm_first=True,  # pre-layer norm for stability
        )

        # --- Tied output projection --------------------------------------
        # There is no separate nn.Linear. _project() reuses
        # shared_embedding.weight, so only a per-token bias is owned here.
        self.output_bias = nn.Parameter(torch.zeros(vocab_size))

        # --- Initialize weights ------------------------------------------
        self._init_weights()

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
        """Look up shared embeddings, scale by sqrt(d_model), add positions."""
        return self.pos_encoding(self.shared_embedding(tokens) * self.embed_scale)

    def _project(self, hidden: torch.Tensor) -> torch.Tensor:
        """Tied output projection: hidden @ shared_embedding.weight.T + output_bias."""
        return torch.nn.functional.linear(hidden, self.shared_embedding.weight, self.output_bias)

    def _init_weights(self):
        """
        Re-initialise the shared embedding from N(0, d_model ** -0.5).

        With this std, the multiplication by sqrt(d_model) in ``_embed``
        gives embedding inputs of roughly unit scale. ``normal_`` overwrites
        every row, so the padding row is zeroed again afterwards. Parameters
        of ``self.transformer`` are not modified here.
        """
        nn.init.normal_(self.shared_embedding.weight, mean=0, std=self.d_model ** -0.5)
        # Zero out padding embedding
        with torch.no_grad():
            self.shared_embedding.weight[self.pad_idx].zero_()

    # ------------------------------------------------------------------
    # Mask utilities
    # ------------------------------------------------------------------
    @staticmethod
    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
        """
        Causal mask for the decoder: prevents attending to future positions.
        Returns a (sz, sz) boolean mask where True = blocked.
        """
        return torch.triu(torch.ones(sz, sz, device=device, dtype=torch.bool), diagonal=1)

    def _make_pad_mask(self, x: torch.Tensor) -> torch.Tensor:
        """
        Create a padding mask: True where token == pad_idx.
        Shape: same as ``x``, i.e. (batch, seq_len) for token-ID batches.
        """
        return x == self.pad_idx

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Teacher-forced forward pass through encoder and decoder.

        Args:
            src: (batch, src_len) source token IDs.
            tgt: (batch, tgt_len) target token IDs (teacher-forced).
            src_key_padding_mask: optional (batch, src_len) bool mask, True = pad.
                Derived from ``pad_idx`` when omitted. It is also used as the
                decoder's memory padding mask.
            tgt_key_padding_mask: optional (batch, tgt_len) bool mask, True = pad.
                Derived from ``pad_idx`` when omitted.
        Returns:
            logits: (batch, tgt_len, vocab_size)
        """
        # Build masks if not provided
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)

        # Causal mask for decoder
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)

        # Shared embeddings for both encoder and decoder
        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)

        out = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )  # (batch, tgt_len, d_model)

        return self._project(out)

    # ------------------------------------------------------------------
    # Inference helpers
    # ------------------------------------------------------------------
    def encode(self, src: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Run only the encoder. Returns memory: (batch, src_len, d_model).

        ``src_key_padding_mask`` is derived from ``pad_idx`` when omitted.
        Keep it and pass it to ``decode`` as ``memory_key_padding_mask``.
        """
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        src_emb = self._embed(src)
        return self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run only the decoder given encoder memory.

        Args:
            tgt: (batch, tgt_len) target token IDs generated so far.
            memory: encoder output from ``encode``.
            tgt_key_padding_mask: derived from ``pad_idx`` when omitted.
            memory_key_padding_mask: source padding mask. Unlike ``forward``,
                this is NOT derived automatically, because the source tokens
                are not available here. When it is ``None``, padded source
                positions are attended to.
        Returns:
            logits: (batch, tgt_len, vocab_size)
        """
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)
        tgt_emb = self._embed(tgt)
        out = self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self._project(out)
# ---------------------------------------------------------------------------
# Helper: count parameters
# ---------------------------------------------------------------------------
def count_parameters(model: nn.Module) -> int:
    """
    Count the scalar entries of every trainable parameter in ``model``.

    Parameters with ``requires_grad=False`` are excluded. Each parameter
    object yielded by ``model.parameters()`` is counted once.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)
# ---------------------------------------------------------------------------
# Helper: build model
# ---------------------------------------------------------------------------
def build_model(
    vocab_size: int,
    pad_idx: int = 0,
    device: Optional[torch.device] = None,
    **kwargs,
) -> TransformerTranslator:
    """
    Instantiate a TransformerTranslator and place it on a device.

    Args:
        vocab_size: Size of the shared source+target vocabulary.
        pad_idx: Padding token ID forwarded to the model.
        device: Target device. When ``None``, CUDA is used if available,
            otherwise the CPU.
        **kwargs: Extra constructor arguments (``d_model``, ``n_head``, ...)
            that override TransformerTranslator's defaults.

    Returns:
        The constructed model, already moved to the chosen device.
    """
    if device is not None:
        target_device = device
    elif torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")
    translator = TransformerTranslator(vocab_size=vocab_size, pad_idx=pad_idx, **kwargs)
    return translator.to(target_device)
|