"""

10+2 Tied Transformer for English → Malay Translation

=======================================================

An asymmetric encoder-decoder Transformer built on ``torch.nn.Transformer``.



Architecture (redesigned for efficient T4 GPU training & inference):

    d_model            = 512   (embedding dimension, head_dim = 64)

    n_head             = 8     (attention heads)

    encoder layers     = 10    (deep encoder for source understanding)

    decoder layers     = 2     (shallow decoder for fast generation)

    d_ff               = 2048  (feed-forward inner dimension)

    dropout            = 0.1

    norm_first         = True  (pre-norm for training stability)

    shared embeddings  = True  (single vocab, en+ms share Latin script)

    tied output proj.  = True  (output reuses embedding weights)



Key design choices (see architecture_report.md for full rationale):

  • **Asymmetric depth (Kasai et al., 2021):** Encoder depth drives

    translation quality; decoder depth can be aggressively reduced

    with minimal quality loss and ~3× faster inference.

  • **Shared vocabulary:** English and Malay both use Latin script with

    significant lexical overlap (loanwords, numbers, proper nouns).

    A joint BPE naturally captures cross-lingual subword patterns.

  • **Tied output projection (Press & Wolf, 2017):** The decoder's output

    linear layer reuses the shared embedding matrix, saving ~26M params

    and acting as a regulariser.

  • **Pre-layer normalisation (Xiong et al., 2020):** Essential for stable

    training of a 10-layer encoder.  Places LayerNorm before each sublayer.

  • Uses PyTorch's native ``nn.Transformer`` to keep FlashAttention /

    SDPA fused kernels active (PyTorch 2.0+).

"""

from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn as nn


# ---------------------------------------------------------------------------
# Positional Encoding (sinusoidal, from "Attention Is All You Need")
# ---------------------------------------------------------------------------
class PositionalEncoding(nn.Module):
    """

    Inject positional information via fixed sinusoidal signals.



    PE(pos, 2i)   = sin(pos / 10000^{2i / d_model})

    PE(pos, 2i+1) = cos(pos / 10000^{2i / d_model})

    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)                       # (max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float() # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )                                                         # (d_model/2,)
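        # Note: exp(-(2i) * ln(10000) / d_model) == 1 / 10000^(2i / d_model),
        # so this reproduces the PE formula above in a numerically stable way.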

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)                                      # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """

        Args:

            x: (batch, seq_len, d_model)

        Returns:

            (batch, seq_len, d_model) with positional encoding added.

        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


# ---------------------------------------------------------------------------
# Full Transformer Model (10+2 Tied)
# ---------------------------------------------------------------------------
class TransformerTranslator(nn.Module):
    """

    Asymmetric encoder-decoder Transformer with shared/tied embeddings.



    Parameters

    ----------

    vocab_size : int

        Size of the shared source+target vocabulary.

    d_model : int

        Embedding / hidden dimension.

    n_head : int

        Number of attention heads.

    num_encoder_layers : int

        Number of encoder blocks (default 10).

    num_decoder_layers : int

        Number of decoder blocks (default 2).

    d_ff : int

        Feed-forward inner dimension.

    dropout : float

        Dropout rate.

    max_len : int

        Maximum sequence length for positional encoding.

    pad_idx : int

        Padding token ID (used to create padding masks).

    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_head: int = 8,
        num_encoder_layers: int = 10,
        num_decoder_layers: int = 2,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 512,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model

        # --- Shared embedding (one matrix for both enc & dec) -------------
        self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.embed_scale = math.sqrt(d_model)

        # --- Core Transformer (asymmetric, pre-norm) ----------------------
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=n_head,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
            norm_first=True,            # pre-layer norm for stability
        )

        # --- Tied output projection (reuses embedding weights) ------------
        # No separate nn.Linear — forward() uses F.linear with shared weights
        self.output_bias = nn.Parameter(torch.zeros(vocab_size))
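        # Because the projection reuses self.shared_embedding.weight, gradients
        # from the output logits flow back into the same matrix used for input
        # embeddings (Press & Wolf, 2017).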

        # --- Initialize weights -------------------------------------------
        self._init_weights()

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
        """Shared embedding + scale + positional encoding."""
        return self.pos_encoding(self.shared_embedding(tokens) * self.embed_scale)

    def _init_weights(self):
        """Xavier-uniform initialization for embeddings."""
        nn.init.normal_(self.shared_embedding.weight, mean=0, std=self.d_model ** -0.5)
        # Zero out padding embedding
        with torch.no_grad():
            self.shared_embedding.weight[self.pad_idx].zero_()

    # ------------------------------------------------------------------
    # Mask utilities
    # ------------------------------------------------------------------
    @staticmethod
    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
        """

        Causal mask for the decoder: prevents attending to future positions.

        Returns a (sz, sz) boolean mask where True = blocked.

        """
        return torch.triu(torch.ones(sz, sz, device=device, dtype=torch.bool), diagonal=1)

    def _make_pad_mask(self, x: torch.Tensor) -> torch.Tensor:
        """

        Create a padding mask: True where token == pad_idx.

        Shape: (batch, seq_len)

        """
        return x == self.pad_idx

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """

        Args:

            src: (batch, src_len) source token IDs.

            tgt: (batch, tgt_len) target token IDs (teacher-forced).



        Returns:

            logits: (batch, tgt_len, vocab_size)

        """
        # Build masks if not provided
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)

        # Causal mask for decoder
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)

        # Shared embeddings for both encoder and decoder
        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)

        # Transformer forward
        out = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )  # (batch, tgt_len, d_model)

        # Tied output projection: logits = out @ embedding_weights.T + bias
        logits = torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)
        return logits

    # ------------------------------------------------------------------
    # Inference helpers
    # ------------------------------------------------------------------
    def encode(self, src: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run only the encoder. Returns memory: (batch, src_len, d_model)."""
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        src_emb = self._embed(src)
        return self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run only the decoder given encoder memory. Returns logits."""
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)
        tgt_emb = self._embed(tgt)
        out = self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)
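
    # ------------------------------------------------------------------
    # Greedy decoding sketch (illustrative, not part of the training path)
    # ------------------------------------------------------------------
    # A minimal example of how encode()/decode() compose at inference time.
    # The bos_idx / eos_idx token IDs are assumptions; they depend on the
    # tokenizer actually used with this model.
    @torch.no_grad()
    def greedy_decode(
        self,
        src: torch.Tensor,
        bos_idx: int,
        eos_idx: int,
        max_len: int = 128,
    ) -> torch.Tensor:
        """Greedy autoregressive decoding. Returns (batch, <=max_len) token IDs."""
        src_key_padding_mask = self._make_pad_mask(src)
        memory = self.encode(src, src_key_padding_mask)
        ys = torch.full((src.size(0), 1), bos_idx, dtype=torch.long, device=src.device)
        finished = torch.zeros(src.size(0), dtype=torch.bool, device=src.device)
        for _ in range(max_len - 1):
            # Re-run the shallow 2-layer decoder on the full prefix each step.
            logits = self.decode(ys, memory, memory_key_padding_mask=src_key_padding_mask)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)      # (batch, 1)
            # Keep already-finished sequences padded with EOS.
            next_token = next_token.masked_fill(finished.unsqueeze(1), eos_idx)
            ys = torch.cat([ys, next_token], dim=1)
            finished |= next_token.squeeze(1) == eos_idx
            if finished.all():
                break
        return ys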


# ---------------------------------------------------------------------------
# Helper: count parameters
# ---------------------------------------------------------------------------
def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# ---------------------------------------------------------------------------
# Helper: build model
# ---------------------------------------------------------------------------
def build_model(
    vocab_size: int,
    pad_idx: int = 0,
    device: Optional[torch.device] = None,
    **kwargs,
) -> TransformerTranslator:
    """

    Build and return a TransformerTranslator with default hyperparameters.



    Any kwarg (d_model, n_head, etc.) overrides the default.

    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = TransformerTranslator(
        vocab_size=vocab_size,
        pad_idx=pad_idx,
        **kwargs,
    ).to(device)

    return model
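

# ---------------------------------------------------------------------------
# Smoke test (illustrative only; vocab size and shapes are placeholder values)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build a model with an arbitrary vocab size and run a tiny forward pass
    # to confirm shapes. The 32k vocab here is an assumption, not the real
    # tokenizer size.
    model = build_model(vocab_size=32_000, device=torch.device("cpu"))
    print(f"Trainable parameters: {count_parameters(model):,}")

    src = torch.randint(1, 32_000, (2, 12))   # (batch=2, src_len=12), no pad tokens
    tgt = torch.randint(1, 32_000, (2, 9))    # (batch=2, tgt_len=9)
    logits = model(src, tgt)
    assert logits.shape == (2, 9, 32_000)
    print("Forward pass OK:", tuple(logits.shape))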