OliverPerrin committed on
Commit f672f40 · 1 Parent(s): aefdcf0

Finished Decoder Implementation, as well as prediction heads and multitask

src/models/decoder.py CHANGED
@@ -1,28 +1,16 @@
1
  """
2
- Transformer Decoder layout (Pre-LN)
3
-
4
- Contents:
5
- - create_causal_mask: utility to build a causal (subsequent) mask
6
- - TransformerDecoderLayer: one decoder block (masked self-attn, cross-attn, FFN)
7
- - TransformerDecoder: embedding/pos-encoding + stack of decoder layers + generation helpers
8
-
9
- Notes / conventions:
10
- - Pre-LN (LayerNorm before each sublayer) is assumed for stability (consistent with your encoder).
11
- - MultiHeadAttention, FeedForward, PositionalEncoding are expected to live in src/models
12
- (you already implemented them).
13
- - Masks use boolean semantics: True = allowed, False = masked.
14
- - The decoder API supports:
15
- - inputs: token ids (LongTensor, (B, T)) or embeddings ((B, T, d_model))
16
- - memory: encoder outputs (B, S, d_model)
17
- - mask arguments: tgt_mask (causal/padding), memory_mask (encoder padding)
18
- - collect_attn: return attention maps per layer if requested
19
- - Generation helpers (greedy) are skeletons that you can extend to beam search or caching.
20
-
21
- TODO status keys:
22
- - [IMPLEMENT] : core implementation required
23
- - [OPTIONAL] : useful enhancement (caching, beam search, advanced scheduling)
24
- """
25
 
 
 
 
 
 
26
  from typing import Optional, Tuple, List, Union, Dict
27
  import math
28
  import torch
@@ -35,47 +23,34 @@ from .positional_encoding import PositionalEncoding
35
 
36
  def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
37
  """
38
- Create a square causal mask of shape (seq_len, seq_len).
39
- True indicates allowed positions; False indicates masked (future) positions.
40
-
41
- Returns:
42
- mask: torch.BoolTensor of shape (seq_len, seq_len)
43
  """
44
- # return a mask with True on and below diagonal, False above diagonal
45
- # The torch.trui function does masking, which is the idea of zeroing all the values in a matrix below a certain diagonal
46
- mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)
47
- # mask has True above diagonal (to be masked). Want True for allowed, so invert:
48
- return ~mask # shape (seq_len, seq_len) or (T, T)
49
 
50
 
51
  class TransformerDecoderLayer(nn.Module):
52
  """
53
- One decoder layer with:
54
- - Masked self-attention (query/key/value = tgt)
55
- - Encoder-Decoder cross-attention (query = tgt, key/value = memory)
56
- - Position-wise FeedForward
57
- - Pre-LN + residuals + dropout
58
-
59
- Args:
60
- d_model: model hidden size
61
- num_heads: number of attention heads
62
- d_ff: ff intermediate size
63
- dropout: dropout for residuals / FFN
64
  """
65
 
66
  def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
67
  super().__init__()
68
- # NOTE: instantiate internal MHA with dropout=0.0 and manage dropout at layer-level
69
  self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)
70
  self.cross_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)
71
  self.ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)
72
 
73
- # LayerNorms (Pre-LN)
74
  self.norm1 = nn.LayerNorm(d_model)
75
  self.norm2 = nn.LayerNorm(d_model)
76
  self.norm3 = nn.LayerNorm(d_model)
77
 
78
- # Dropouts applied after sublayers (on sublayer outputs before residual add)
79
  self.dropout1 = nn.Dropout(dropout)
80
  self.dropout2 = nn.Dropout(dropout)
81
  self.dropout3 = nn.Dropout(dropout)
@@ -88,46 +63,51 @@ class TransformerDecoderLayer(nn.Module):
88
  memory_mask: Optional[torch.Tensor] = None,
89
  ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
90
  """
91
- Forward pass for one decoder layer.
92
-
93
  Args:
94
- tgt: (batch, tgt_len, d_model)
95
- memory: (batch, src_len, d_model) -- encoder outputs
96
- tgt_mask: optional (batch, tgt_len, tgt_len) or (batch, 1, tgt_len, tgt_len)
97
- memory_mask: optional (batch, src_len, src_len) or (batch, 1, tgt_len, src_len)
98
 
99
  Returns:
100
- output: (batch, tgt_len, d_model)
101
- attn_maps: dict with keys 'self' and 'cross' containing attention tensors
102
- shapes: (batch, num_heads, tgt_len, tgt_len) and (batch, num_heads, tgt_len, src_len)
103
  """
104
- # TODO [IMPLEMENT] Self-attention (Pre-LN)
105
- # x_norm = self.norm1(tgt)
106
- # self_out, self_attn = self.self_attn(x_norm, x_norm, x_norm, tgt_mask)
107
- tgt = tgt + self.dropout1(self_out)
108
 
109
- # TODO [IMPLEMENT] Cross-attention (Pre-LN)
110
- # x_norm = self.norm2(tgt)
111
- # cross_out, cross_attn = self.cross_attn(x_norm, memory, memory, memory_mask)
112
- # tgt = tgt + self.dropout2(cross_out)
113
 
114
- # TODO [IMPLEMENT] Feed-forward (Pre-LN)
115
- # x_norm = self.norm3(tgt)
116
- # ffn_out = self.ffn(x_norm)
117
- # tgt = tgt + self.dropout3(ffn_out)
118
 
119
- # TODO [RETURN] Return (tgt, {"self": self_attn, "cross": cross_attn})
120
- raise NotImplementedError("TransformerDecoderLayer.forward: implement Pre-LN pipeline")
121
 
122
 
123
  class TransformerDecoder(nn.Module):
124
  """
125
- Full decoder: token embedding + positional encoding + stack of decoder layers.
126
- Also supports simple greedy decoding.
127
 
128
- Args:
129
- vocab_size: for token embeddings (if using token ids)
130
- d_model, num_layers, num_heads, d_ff, dropout, max_len, pad_token_id: same semantics as encoder
131
  """
132
 
133
  def __init__(
@@ -146,37 +126,25 @@ class TransformerDecoder(nn.Module):
146
  self.d_model = d_model
147
  self.pad_token_id = pad_token_id
148
 
149
- # Token embedding (used if inputs are token ids)
150
  self.embedding = nn.Embedding(vocab_size, d_model)
151
-
152
- # Positional encoding
153
  self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)
154
 
155
- # Decoder layers
156
  self.layers = nn.ModuleList(
157
- [
158
- TransformerDecoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=dropout)
159
- for _ in range(num_layers)
160
- ]
161
  )
162
 
163
- # Final layer norm for Pre-LN stacks
164
  self.final_norm = nn.LayerNorm(d_model)
165
-
166
- # Output projection to vocabulary (logits)
167
  self.output_projection = nn.Linear(d_model, vocab_size)
168
-
169
- # Input dropout (after pos encoding)
170
  self.input_dropout = nn.Dropout(dropout)
171
 
172
  def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
173
  """
174
- Build (batch, seq, seq) boolean mask from input ids and pad_token_id.
175
- True = allowed, False = masked.
176
  """
177
  assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
178
- pad_mask = (input_ids != self.pad_token_id) # (B, S)
179
- attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2) # (B, S, S)
180
  return attn_mask
181
 
182
  def forward(
@@ -188,21 +156,13 @@ class TransformerDecoder(nn.Module):
188
  collect_attn: bool = False,
189
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]]:
190
  """
191
- Forward pass for the decoder stack.
192
-
193
  Args:
194
- inputs: token ids (B, T) or embeddings (B, T, d_model)
195
- memory: encoder outputs (B, S, d_model)
196
- tgt_mask: optional mask for decoder self-attention. If None, a causal mask will be created.
197
- Mask shapes: (B, T, T) or (B, 1, T, T)
198
- memory_mask: optional mask over memory (B, S, S) or (B, 1, T, S)
199
- collect_attn: if True returns (logits/outputs, [per-layer attn dicts])
200
-
201
- Returns:
202
- logits: (B, T, vocab_size) or (B, T, d_model) if you prefer returning hidden states
203
- or (logits, attn_list) if collect_attn True
204
  """
205
- # Inputs: if token ids, embed and scale; else assume embeddings
206
  if inputs.dim() == 2: # token ids
207
  x = self.embedding(inputs) * math.sqrt(self.d_model)
208
  elif inputs.dim() == 3:
@@ -210,47 +170,48 @@ class TransformerDecoder(nn.Module):
210
  else:
211
  raise ValueError("inputs must be (B, T) token ids or (B, T, d_model) embeddings")
212
 
213
- # Positional encoding + dropout
214
  x = self.pos_encoder(x)
215
  x = self.input_dropout(x)
216
 
217
- # Build tgt_mask if not provided: combine causal mask and padding mask if available
218
- seq_len = x.size(1)
 
219
  if tgt_mask is None:
220
- # base causal mask (T, T)
221
- causal = create_causal_mask(seq_len, device=x.device) # [TODO implement]
222
- # expand to batch dim later if padding present
223
  if inputs.dim() == 2 and self.pad_token_id is not None:
224
- padding_mask = self._build_padding_mask_from_ids(inputs) # (B, T, T)
225
- # combine: True only where both causal and padding allow attention
226
- # TODO: ensure shapes align; broadcast causal to (1, T, T) then & with padding_mask
227
- raise NotImplementedError("tgt_mask construction: combine causal + padding_mask")
228
  else:
229
- # TODO: Broadcast causal to (1, T, T) or (B, 1, T, T) depending on downstream expectations
230
- raise NotImplementedError("tgt_mask construction: broadcast causal mask for batch")
 
 
 
231
 
232
- # Ensure memory_mask is boolean on correct device if provided
233
  if memory_mask is not None:
234
  memory_mask = memory_mask.to(dtype=torch.bool, device=x.device)
235
 
236
  attn_list: List[Dict[str, torch.Tensor]] = []
237
 
238
- # Pass through layers
239
  for layer in self.layers:
240
  x, attn = layer(x, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
241
  if collect_attn:
242
  attn_list.append(attn)
243
 
244
- x = self.final_norm(x) # Pre-LN final normalization
245
-
246
  logits = self.output_projection(x) # (B, T, vocab)
 
247
  if collect_attn:
248
  return logits, attn_list
249
  return logits
250
 
251
- # ---------------------------------------------------------------------
252
- # Generation / inference helpers (skeletons)
253
- # ---------------------------------------------------------------------
254
  def greedy_decode(
255
  self,
256
  memory: torch.Tensor,
@@ -258,26 +219,32 @@ class TransformerDecoder(nn.Module):
258
  start_token_id: int,
259
  end_token_id: Optional[int] = None,
260
  device: Optional[torch.device] = None,
261
- ) -> torch.LongTensor:
262
  """
263
- Greedy autoregressive decoding using the decoder stack.
264
-
265
- Args:
266
- memory: encoder outputs (B, S, d_model)
267
- max_len: maximum target length to generate
268
- start_token_id: BOS token id
269
- end_token_id: optional EOS token id to stop early
270
- Returns:
271
- generated: (B, T_out) long tensor of token ids
272
  """
273
- # TODO [IMPLEMENT]:
274
- # - Start with tensor of shape (B, 1) filled with start_token_id
275
- # - Repeatedly call decoder.forward in incremental mode (or full forward with causal mask)
276
- # - At each step pick argmax over logits and append to sequence
277
- # - Stop if all sequences produced end_token_id or max_len reached
278
- raise NotImplementedError("greedy_decode: implement autoregressive generation loop")
279
-
280
- # Optional: incremental step method with caching of past keys/values for speed
281
  def step(
282
  self,
283
  last_token_ids: torch.LongTensor,
@@ -285,16 +252,141 @@ class TransformerDecoder(nn.Module):
285
  cache: Optional[Dict] = None,
286
  ) -> Tuple[torch.Tensor, Dict]:
287
  """
288
- Single-step decoder that returns logits for the next token given last_token_ids.
289
 
290
  Args:
291
- last_token_ids: (B, 1) tokens at current time step
292
- memory: encoder outputs
293
- cache: optional dict storing per-layer cached keys/values
294
 
295
  Returns:
296
- logits: (B, vocab_size)
297
- new_cache: updated cache
298
  """
299
- # TODO [OPTIONAL]: implement fast incremental decoding caching keys/values per layer
300
- raise NotImplementedError("step: incremental decoding (optional optimization)")
 
 
 
 
1
  """
2
+ Transformer Decoder (Pre-LN) - implementation.
3
+
4
+ Implements:
5
+ - create_causal_mask
6
+ - TransformerDecoderLayer
7
+ - TransformerDecoder (stack + naive greedy decoding)
8
 
9
+ Conventions:
10
+ - Masks are boolean: True = allowed, False = masked.
11
+ - MultiHeadAttention expects masks broadcastable to (B, num_heads, T_q, T_k).
12
+ - This decoder uses Pre-LN (LayerNorm before each sublayer).
13
+ """
14
  from typing import Optional, Tuple, List, Union, Dict
15
  import math
16
  import torch
 
23
 
24
  def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
25
  """
26
+ Create a (seq_len, seq_len) causal mask where entry (i, j) is True iff
27
+ j <= i (query at i may attend to keys up to i).
 
 
 
28
  """
29
+ # torch.triu(..., diagonal=1) is True above the diagonal. Invert to get allowed positions.
30
+ mask = ~torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)
31
+ return mask # shape: (T, T)
 
 
32
 
33
 
34
  class TransformerDecoderLayer(nn.Module):
35
  """
36
+ Single decoder layer (Pre-LN):
37
+ 1) Masked self-attention
38
+ 2) Cross-attention (encoder -> decoder)
39
+ 3) Feed-forward
40
+ Returns the updated tgt and a dict of attention maps.
41
  """
42
 
43
  def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
44
  super().__init__()
45
+ # use internal MHA dropout = 0.0; the layer handles dropout after sublayers
46
  self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)
47
  self.cross_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)
48
  self.ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)
49
 
 
50
  self.norm1 = nn.LayerNorm(d_model)
51
  self.norm2 = nn.LayerNorm(d_model)
52
  self.norm3 = nn.LayerNorm(d_model)
53
 
 
54
  self.dropout1 = nn.Dropout(dropout)
55
  self.dropout2 = nn.Dropout(dropout)
56
  self.dropout3 = nn.Dropout(dropout)
 
63
  memory_mask: Optional[torch.Tensor] = None,
64
  ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
65
  """
 
 
66
  Args:
67
+ tgt: (B, T, d_model)
68
+ memory: (B, S, d_model)
69
+ tgt_mask: optional mask for self-attn - shape (B, T, T) or (B, 1, T, T)
70
+ memory_mask: optional mask for cross-attn - shape (B, S) or (B, 1, S) or (B, 1, T, S)
71
 
72
  Returns:
73
+ (tgt_out, {"self": self_attn_weights, "cross": cross_attn_weights})
 
 
74
  """
75
+ # Ensure masks are on same device and boolean
76
+ if tgt_mask is not None:
77
+ tgt_mask = tgt_mask.to(dtype=torch.bool, device=tgt.device)
78
+ if memory_mask is not None:
79
+ memory_mask = memory_mask.to(dtype=torch.bool, device=tgt.device)
80
+ # If memory_mask is provided as (B, S) (per-key padding), expand to (B, 1, 1, S)
81
+ if memory_mask.dim() == 2:
82
+ memory_mask = memory_mask.unsqueeze(1).unsqueeze(1) # (B,1,1,S)
83
+ # If it's (B, S, S) or (B, 1, S, S) leave as-is; if (B, T, S) convert to (B,1,T,S)
84
+ elif memory_mask.dim() == 3 and memory_mask.shape[1] != 1:
85
+ # assume (B, T, S) -> make (B, 1, T, S)
86
+ memory_mask = memory_mask.unsqueeze(1)
87
+
88
+ # --- Masked self-attention (Pre-LN) ---
89
+ x_norm = self.norm1(tgt)
90
+ self_out, self_attn = self.self_attn(x_norm, x_norm, x_norm, tgt_mask)
91
+ tgt = tgt + self.dropout1(self_out)
92
 
93
+ # --- Cross-attention (Pre-LN) ---
94
+ x_norm = self.norm2(tgt)
95
+ cross_out, cross_attn = self.cross_attn(x_norm, memory, memory, memory_mask)
96
+ tgt = tgt + self.dropout2(cross_out)
97
 
98
+ # --- Feed-forward (Pre-LN) ---
99
+ x_norm = self.norm3(tgt)
100
+ ffn_out = self.ffn(x_norm)
101
+ tgt = tgt + self.dropout3(ffn_out)
102
 
103
+ return tgt, {"self": self_attn, "cross": cross_attn}
 
104
 
105
 
106
  class TransformerDecoder(nn.Module):
107
  """
108
+ Decoder stack with token embeddings and positional encoding.
 
109
 
110
+ Forward returns logits (B, T, vocab_size) by default; if collect_attn=True returns (logits, attn_list).
 
 
111
  """
112
 
113
  def __init__(
 
126
  self.d_model = d_model
127
  self.pad_token_id = pad_token_id
128
 
 
129
  self.embedding = nn.Embedding(vocab_size, d_model)
 
 
130
  self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)
131
 
 
132
  self.layers = nn.ModuleList(
133
+ [TransformerDecoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=dropout)
134
+ for _ in range(num_layers)]
 
 
135
  )
136
 
 
137
  self.final_norm = nn.LayerNorm(d_model)
 
 
138
  self.output_projection = nn.Linear(d_model, vocab_size)
 
 
139
  self.input_dropout = nn.Dropout(dropout)
140
 
141
  def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
142
  """
143
+ Convert input ids to (B, T, T) boolean mask where True = allowed.
 
144
  """
145
  assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
146
+ pad_mask = (input_ids != self.pad_token_id) # (B, T)
147
+ attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2) # (B, T, T)
148
  return attn_mask
149
 
150
  def forward(
 
156
  collect_attn: bool = False,
157
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]]:
158
  """
 
 
159
  Args:
160
+ inputs: (B, T) token ids or (B, T, d_model) embeddings
161
+ memory: (B, S, d_model)
162
+ tgt_mask: optional; if None, will create (causal [+ padding if ids available])
163
+ memory_mask: optional; if provided as (B, S) will be expanded to (B, 1, 1, S)
 
 
 
 
 
 
164
  """
165
+ # Prepare embeddings
166
  if inputs.dim() == 2: # token ids
167
  x = self.embedding(inputs) * math.sqrt(self.d_model)
168
  elif inputs.dim() == 3:
 
170
  else:
171
  raise ValueError("inputs must be (B, T) token ids or (B, T, d_model) embeddings")
172
 
 
173
  x = self.pos_encoder(x)
174
  x = self.input_dropout(x)
175
 
176
+ B, T, _ = x.shape
177
+
178
+ # Build target mask if not provided: combine causal + padding (if available)
179
  if tgt_mask is None:
180
+ causal = create_causal_mask(T, device=x.device) # (T, T)
 
 
181
  if inputs.dim() == 2 and self.pad_token_id is not None:
182
+ pad_pairwise = self._build_padding_mask_from_ids(inputs) # (B, T, T)
183
+ combined = pad_pairwise & causal.unsqueeze(0) # (B, T, T)
184
+ tgt_mask = combined.unsqueeze(1) # (B, 1, T, T) -> broadcast to heads
 
185
  else:
186
+ # No per-batch padding info: broadcast causal to (1, 1, T, T)
187
+ tgt_mask = causal.unsqueeze(0).unsqueeze(1) # (1, 1, T, T)
188
+ else:
189
+ # Ensure boolean and device alignment; accept (B, T, T) or (B,1,T,T) or (1,1,T,T)
190
+ tgt_mask = tgt_mask.to(dtype=torch.bool, device=x.device)
191
 
192
+ # Normalize memory_mask dtype/device and expand simple shapes
193
  if memory_mask is not None:
194
  memory_mask = memory_mask.to(dtype=torch.bool, device=x.device)
195
+ if memory_mask.dim() == 2: # (B, S) -> (B, 1, 1, S)
196
+ memory_mask = memory_mask.unsqueeze(1).unsqueeze(1)
197
+ elif memory_mask.dim() == 3: # (B, T, S) -> (B, 1, T, S)
198
+ memory_mask = memory_mask.unsqueeze(1)
199
 
200
  attn_list: List[Dict[str, torch.Tensor]] = []
201
 
202
+ # Pass through decoder layers
203
  for layer in self.layers:
204
  x, attn = layer(x, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
205
  if collect_attn:
206
  attn_list.append(attn)
207
 
208
+ x = self.final_norm(x)
 
209
  logits = self.output_projection(x) # (B, T, vocab)
210
+
211
  if collect_attn:
212
  return logits, attn_list
213
  return logits
214
 
 
 
 
215
  def greedy_decode(
216
  self,
217
  memory: torch.Tensor,
 
219
  start_token_id: int,
220
  end_token_id: Optional[int] = None,
221
  device: Optional[torch.device] = None,
222
+ ) -> torch.Tensor:
223
  """
224
+ Naive greedy decoding: repeatedly run the decoder on the growing prefix.
225
+ Not optimized (recomputes full decoder each step) but simple and correct.
226
  """
227
+ if device is None:
228
+ device = memory.device
229
+ B = memory.size(0)
230
+ generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
231
+
232
+ for _ in range(max_len - 1):
233
+ logits = self.forward(generated, memory, collect_attn=False) # (B, L, V)
234
+ assert isinstance(logits, torch.Tensor) # type narrowing
235
+ next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True) # (B, 1)
236
+ generated = torch.cat([generated, next_token], dim=1)
237
+
238
+ if end_token_id is not None:
239
+ # stop if all sequences ended
240
+ if (generated[:, -1] == end_token_id).all():
241
+ break
242
+
243
+ return generated
244
+
245
+ # -----------------------------
246
+ # Incremental single-step API
247
+ # -----------------------------
248
  def step(
249
  self,
250
  last_token_ids: torch.LongTensor,
 
252
  cache: Optional[Dict] = None,
253
  ) -> Tuple[torch.Tensor, Dict]:
254
  """
255
+ Run one autoregressive step.
256
 
257
  Args:
258
+ last_token_ids: (B, 1) last generated token ids
259
+ memory: encoder outputs (B, S, d_model)
260
+ cache: optional dict with previous cached keys/values and 'past_length'.
261
 
262
  Returns:
263
+ logits: (B, vocab_size) logits for the next-token prediction
264
+ new_cache: updated cache dictionary
265
  """
266
+ device = memory.device
267
+ B = last_token_ids.size(0)
268
+
269
+ if cache is None:
270
+ cache = {}
271
+ past_len = int(cache.get("past_length", 0))
272
+
273
+ # 1) Embed last token and add positional encoding for position `past_len`
274
+ x = self.embedding(last_token_ids) * math.sqrt(self.d_model) # (B,1,d)
275
+ # Use positional encoding buffer directly (avoid dropout in pos_encoder)
276
+ # pos_encoder.pe expected shape (1, max_len, d_model)
277
+ if hasattr(self.pos_encoder, "pe"):
278
+ pe = self.pos_encoder.pe # (1, max_len, d_model)
279
+ pos_idx = past_len
280
+ if pos_idx >= pe.size(1):
281
+ raise RuntimeError(f"pos_idx {pos_idx} exceeds max_len {pe.size(1)}")
282
+ x = x + pe[:, pos_idx:pos_idx + 1, :].to(device)
283
+ else:
284
+ # fallback: call pos_encoder and rely on its dropout (less ideal)
285
+ x = self.pos_encoder(x)
286
+
287
+ # We will update new_cache incrementally
288
+ new_cache = dict(cache) # shallow copy
289
+ new_cache["past_length"] = past_len + 1
290
+
291
+ # Optional: memory_mask could be supplied in cache under 'memory_mask'
292
+ memory_mask = new_cache.get("memory_mask", None)
293
+ if memory_mask is not None:
294
+ memory_mask = memory_mask.to(dtype=torch.bool, device=device)
295
+ # expand (B, S) -> (B,1,1,S) if necessary
296
+ if memory_mask.dim() == 2:
297
+ memory_mask = memory_mask.unsqueeze(1).unsqueeze(1)
298
+ elif memory_mask.dim() == 3:
299
+ memory_mask = memory_mask.unsqueeze(1)
300
+
301
+ # Iterate layers, updating caches and computing output for current token only
302
+ layer_input = x # (B,1,d_model)
303
+ for i, layer in enumerate(self.layers):
304
+ # -------------------
305
+ # 1) Self-attention (incremental)
306
+ # -------------------
307
+ # Normalize input for pre-LN
308
+ x_norm = layer.norm1(layer_input) # (B,1,d)
309
+
310
+ # Project Q,K,V for the new token
311
+ Q_new = layer.self_attn.W_Q(x_norm) # (B,1,d_model)
312
+ K_new = layer.self_attn.W_K(x_norm)
313
+ V_new = layer.self_attn.W_V(x_norm)
314
+
315
+ # Reshape into heads: (B, num_heads, 1, d_k)
316
+ B_, Lq, _ = Q_new.shape
317
+ num_heads = layer.self_attn.num_heads
318
+ d_k = layer.self_attn.d_k
319
+ Qh = Q_new.view(B_, Lq, num_heads, d_k).transpose(1, 2) # (B, num_heads, 1, d_k)
320
+ Kh = K_new.view(B_, Lq, num_heads, d_k).transpose(1, 2)
321
+ Vh = V_new.view(B_, Lq, num_heads, d_k).transpose(1, 2)
322
+
323
+ # Retrieve cached keys/values for self-attn (if exist)
324
+ cache_k = cache.get(f"self_k_{i}", None)
325
+ cache_v = cache.get(f"self_v_{i}", None)
326
+ if cache_k is None or cache_v is None:
327
+ K_all = Kh # (B, H, 1, d_k)
328
+ V_all = Vh
329
+ else:
330
+ # concat along sequence dim (dim=2)
331
+ K_all = torch.cat([cache_k.to(device), Kh], dim=2)
332
+ V_all = torch.cat([cache_v.to(device), Vh], dim=2)
333
+
334
+ # Store updated caches
335
+ new_cache[f"self_k_{i}"] = K_all
336
+ new_cache[f"self_v_{i}"] = V_all
337
+
338
+ # Compute attention for the new token: Query length = 1, Key length = K_all.size(2)
339
+ attn_out_heads, self_attn_w = layer.self_attn.attention(Qh, K_all, V_all, mask=None)
340
+ # attn_out_heads: (B, H, 1, d_k)
341
+ # concat heads, project out
342
+ attn_out = attn_out_heads.transpose(1, 2).contiguous().view(B_, 1, num_heads * d_k)
343
+ attn_out = layer.self_attn.W_O(attn_out) # (B,1,d_model)
344
+ layer_output = layer_input + layer.dropout1(attn_out)
345
+
346
+ # -------------------
347
+ # 2) Cross-attention (use cached memory projections if available)
348
+ # -------------------
349
+ x_norm2 = layer.norm2(layer_output) # (B,1,d)
350
+ # Ensure memory K/V are cached per layer
351
+ mem_k = cache.get(f"mem_k_{i}", None)
352
+ mem_v = cache.get(f"mem_v_{i}", None)
353
+ if mem_k is None or mem_v is None:
354
+ # project memory once for this layer and cache it
355
+ # memory: (B, S, d_model)
356
+ MK = layer.cross_attn.W_K(memory) # (B, S, d_model)
357
+ MV = layer.cross_attn.W_V(memory)
358
+ Bm, S, _ = MK.shape
359
+ MKh = MK.view(Bm, S, layer.cross_attn.num_heads, layer.cross_attn.d_k).transpose(1, 2) # (B,H,S,d_k)
360
+ MVh = MV.view(Bm, S, layer.cross_attn.num_heads, layer.cross_attn.d_k).transpose(1, 2)
361
+ mem_k = MKh
362
+ mem_v = MVh
363
+ new_cache[f"mem_k_{i}"] = mem_k
364
+ new_cache[f"mem_v_{i}"] = mem_v
365
+ else:
366
+ mem_k = mem_k.to(device)
367
+ mem_v = mem_v.to(device)
368
+
369
+ Qc = layer.cross_attn.W_Q(x_norm2) # (B,1,d_model)
370
+ Qch = Qc.view(B, 1, layer.cross_attn.num_heads, layer.cross_attn.d_k).transpose(1, 2) # (B,H,1,d_k)
371
+
372
+ cross_out_heads, cross_attn_w = layer.cross_attn.attention(Qch, mem_k, mem_v, mask=memory_mask)
373
+ cross_out = cross_out_heads.transpose(1, 2).contiguous().view(B, 1, layer.cross_attn.num_heads * layer.cross_attn.d_k)
374
+ cross_out = layer.cross_attn.W_O(cross_out) # (B,1,d_model)
375
+ layer_output = layer_output + layer.dropout2(cross_out)
376
+
377
+ # -------------------
378
+ # 3) Feed-forward (incremental)
379
+ # -------------------
380
+ x_norm3 = layer.norm3(layer_output)
381
+ ffn_out = layer.ffn(x_norm3) # (B,1,d_model)
382
+ layer_output = layer_output + layer.dropout3(ffn_out)
383
+
384
+ # Prepare for next layer
385
+ layer_input = layer_output
386
+
387
+ # Final norm + output projection (for this single time step)
388
+ out_norm = self.final_norm(layer_input) # (B,1,d_model)
389
+ logits = self.output_projection(out_norm) # (B,1,vocab)
390
+ logits = logits.squeeze(1) # (B, vocab)
391
+
392
+ return logits, new_cache
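
Note on decoder.step(): the implementation above reaches into the attention module's internals (W_Q, W_K, W_V, W_O, attention(), num_heads, d_k) and into PositionalEncoding's pe buffer. Those attribute names are assumptions about the existing src/models components rather than something this commit defines. A minimal sketch of the interface the decoder expects, for orientation only (not the project's actual module):

    # Hypothetical sketch of the MultiHeadAttention interface assumed by decoder.step().
    import math
    import torch
    import torch.nn as nn

    class MultiHeadAttentionSketch(nn.Module):
        def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
            super().__init__()
            assert d_model % num_heads == 0
            self.num_heads, self.d_k = num_heads, d_model // num_heads
            self.W_Q = nn.Linear(d_model, d_model)
            self.W_K = nn.Linear(d_model, d_model)
            self.W_V = nn.Linear(d_model, d_model)
            self.W_O = nn.Linear(d_model, d_model)

        def attention(self, q, k, v, mask=None):
            # q, k, v: (B, H, T_q or T_k, d_k); mask is boolean, True = allowed,
            # broadcastable to (B, H, T_q, T_k); returns (context, weights)
            scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
            if mask is not None:
                scores = scores.masked_fill(~mask, float("-inf"))
            weights = scores.softmax(dim=-1)
            return weights @ v, weights

        def forward(self, query, key, value, mask=None):
            B, T_q, _ = query.shape
            q = self.W_Q(query).view(B, T_q, self.num_heads, self.d_k).transpose(1, 2)
            k = self.W_K(key).view(B, key.size(1), self.num_heads, self.d_k).transpose(1, 2)
            v = self.W_V(value).view(B, value.size(1), self.num_heads, self.d_k).transpose(1, 2)
            out, weights = self.attention(q, k, v, mask=mask)
            out = out.transpose(1, 2).contiguous().view(B, T_q, self.num_heads * self.d_k)
            return self.W_O(out), weights

PositionalEncoding is likewise assumed to expose a pe buffer of shape (1, max_len, d_model), which step() indexes directly so that no dropout is applied at generation time.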
src/models/heads.py CHANGED
@@ -0,0 +1,151 @@
 
1
+ """
2
+ Prediction heads for Transformer models.
3
+
4
+ Includes:
5
+ - ClassificationHead: sequence-level classification with simple pooling (mean/cls/max).
6
+ - TokenClassificationHead: per-token classification (e.g., NER).
7
+ - LMHead: language-modeling head mapping hidden states to vocabulary logits. Optional weight tying to an Embedding.
8
+ - ProjectionHead: small projection MLP for representation learning / contrastive heads.
9
+
10
+ Keep these heads minimal, well-tested, and easy to compose on top of encoder/decoder outputs.
11
+ """
12
+ from typing import Optional, Literal
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+
18
+ class ClassificationHead(nn.Module):
19
+ """
20
+ Sequence-level classification head.
21
+
22
+ Args:
23
+ d_model: hidden size from encoder/decoder
24
+ num_labels: number of output classes
25
+ pooler: one of 'mean', 'cls', 'max' - how to pool the sequence
26
+ dropout: dropout probability before final linear layer
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ d_model: int,
32
+ num_labels: int,
33
+ pooler: Literal["mean", "cls", "max"] = "mean",
34
+ dropout: float = 0.1,
35
+ ):
36
+ super().__init__()
37
+ assert pooler in ("mean", "cls", "max"), "pooler must be 'mean'|'cls'|'max'"
38
+ self.pooler = pooler
39
+ self.dropout = nn.Dropout(dropout)
40
+ self.out_proj = nn.Linear(d_model, num_labels)
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ """
44
+ x: (batch, seq_len, d_model)
45
+ returns: (batch, num_labels)
46
+ """
47
+ if self.pooler == "mean":
48
+ pooled = x.mean(dim=1)
49
+ elif self.pooler == "cls":
50
+ pooled = x[:, 0, :]
51
+ else: # max
52
+ pooled, _ = x.max(dim=1)
53
+ pooled = self.dropout(pooled)
54
+ return self.out_proj(pooled)
55
+
56
+
57
+ class TokenClassificationHead(nn.Module):
58
+ """
59
+ Per-token classification head. Useful for NER, POS, etc.
60
+
61
+ Args:
62
+ d_model: hidden size
63
+ num_labels: number of per-token classes
64
+ dropout: dropout probability applied before the linear layer
65
+ """
66
+
67
+ def __init__(self, d_model: int, num_labels: int, dropout: float = 0.1):
68
+ super().__init__()
69
+ self.dropout = nn.Dropout(dropout)
70
+ self.out_proj = nn.Linear(d_model, num_labels)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ """
74
+ x: (batch, seq_len, d_model)
75
+ returns: (batch, seq_len, num_labels)
76
+ """
77
+ x = self.dropout(x)
78
+ return self.out_proj(x)
79
+
80
+
81
+ class LMHead(nn.Module):
82
+ """
83
+ Language modeling head: maps hidden states to logits over vocabulary.
84
+
85
+ Args:
86
+ d_model: hidden size
87
+ vocab_size: vocabulary size
88
+ tie_embedding: optional nn.Embedding instance to tie weights with
89
+ """
90
+
91
+ def __init__(self, d_model: int, vocab_size: int, tie_embedding: Optional[nn.Embedding] = None):
92
+ super().__init__()
93
+ self.vocab_size = vocab_size
94
+ self.d_model = d_model
95
+ self.proj = nn.Linear(d_model, vocab_size, bias=True)
96
+
97
+ if tie_embedding is not None:
98
+ # Validate sizes
99
+ assert tie_embedding.num_embeddings == vocab_size, "vocab size mismatch for weight tying"
100
+ assert tie_embedding.embedding_dim == d_model, "embedding dim must match d_model for weight tying"
101
+ # Tie weights: point the projection weight to the embedding weight Tensor
102
+ # Remove the existing projection parameter in favor of the embedding weight
103
+ # This keeps the same Parameter object, so updates affect both modules.
104
+ self.proj.weight = tie_embedding.weight
105
+
106
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
107
+ """
108
+ hidden_states: (batch, seq_len, d_model)
109
+ returns logits: (batch, seq_len, vocab_size)
110
+ """
111
+ return self.proj(hidden_states)
112
+
113
+
114
+ class ProjectionHead(nn.Module):
115
+ """
116
+ Simple projection head for representation learning.
117
+
118
+ Args:
119
+ d_model: input dimension
120
+ proj_dim: output projection dimension
121
+ hidden_dim: intermediate dimension (optional)
122
+ dropout: dropout probability
123
+ """
124
+
125
+ def __init__(self, d_model: int, proj_dim: int = 128, hidden_dim: Optional[int] = None, dropout: float = 0.1):
126
+ super().__init__()
127
+ if hidden_dim is None:
128
+ hidden_dim = max(d_model, proj_dim)
129
+ self.net = nn.Sequential(
130
+ nn.Linear(d_model, hidden_dim),
131
+ nn.GELU(),
132
+ nn.Dropout(dropout),
133
+ nn.Linear(hidden_dim, proj_dim),
134
+ )
135
+
136
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
137
+ """
138
+ x: (batch, d_model) or (batch, seq_len, d_model) - both supported.
139
+ Returns:
140
+ If input is 3D: (batch, seq_len, proj_dim)
141
+ If input is 2D: (batch, proj_dim)
142
+ """
143
+ orig_dim = x.dim()
144
+ if orig_dim == 3:
145
+ B, T, D = x.shape
146
+ out = self.net(x.view(B * T, D))
147
+ return out.view(B, T, -1)
148
+ elif orig_dim == 2:
149
+ return self.net(x)
150
+ else:
151
+ raise ValueError("Input must be 2D or 3D tensor")
src/models/multitask.py CHANGED
@@ -0,0 +1,198 @@
 
1
+ """
2
+ Multitask model composition utilities.
3
+
4
+ Provides:
5
+ - MultiTaskModel: lightweight wrapper to compose an encoder and/or decoder with
6
+ multiple task heads (classification, token classification, LM head, etc.)
7
+ - add_head / remove_head helpers
8
+ - forward(task_name, ...) that routes inputs to the correct sub-modules
9
+ - compute_loss helper that uses common losses and ignore_index support
10
+
11
+ Design goals:
12
+ - Keep composition simple and explicit (use named heads per task)
13
+ - Support encoder-only tasks (classification, token classification) and
14
+ seq2seq tasks (encoder -> decoder -> LMHead)
15
+ - Minimal dependencies on training loop; return logits and (optionally) loss
16
+ """
17
+ from typing import Optional, Dict, Any, Tuple
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ # Import your components
24
+ from .encoder import TransformerEncoder
25
+ from .decoder import TransformerDecoder
26
+ from .heads import ClassificationHead, TokenClassificationHead, LMHead
27
+
28
+
29
+ class MultiTaskModel(nn.Module):
30
+ """
31
+ Compose encoder/decoder and task heads.
32
+
33
+ Usage patterns:
34
+ - Encoder-only classification:
35
+ mt = MultiTaskModel(encoder=enc)
36
+ mt.add_head("sentiment", ClassificationHead(...))
37
+ logits = mt.forward("sentiment", {"input_ids": src_ids})
38
+ - Seq2seq LM:
39
+ mt = MultiTaskModel(encoder=enc, decoder=dec)
40
+ mt.add_head("summarize", LMHead(...))
41
+ logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids})
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ encoder: Optional[TransformerEncoder] = None,
47
+ decoder: Optional[TransformerDecoder] = None,
48
+ ):
49
+ super().__init__()
50
+ self.encoder = encoder
51
+ self.decoder = decoder
52
+ self.heads: Dict[str, nn.Module] = {}
53
+
54
+ def add_head(self, name: str, module: nn.Module) -> None:
55
+ """Register a head under a task name."""
56
+ if name in self.heads:
57
+ raise ValueError(f"Head '{name}' already exists")
58
+ self.heads[name] = module
59
+ self.add_module(f"head_{name}", module)
60
+
61
+ def remove_head(self, name: str) -> None:
62
+ """Remove a registered head."""
63
+ if name not in self.heads:
64
+ raise KeyError(name)
65
+ del self._modules[f"head_{name}"]
66
+ del self.heads[name]
67
+
68
+ def forward(
69
+ self,
70
+ task: str,
71
+ inputs: Dict[str, torch.Tensor],
72
+ return_loss: bool = False,
73
+ loss_kwargs: Optional[Dict[str, Any]] = None,
74
+ ) -> Any:
75
+ """
76
+ Route inputs to appropriate model components and head.
77
+
78
+ Args:
79
+ task: registered head name
80
+ inputs: dictionary; common keys:
81
+ - For encoder tasks: "input_ids" or "embeddings" (B, S) or (B, S, d)
82
+ - For seq2seq: "src_ids" (B,S) or "src_embeddings", and "tgt_ids" (B,T) or "tgt_embeddings"
83
+ when computing training loss, pass "labels" (B,T) for LM
84
+ return_loss: if True and labels provided, returns (loss, logits)
85
+ loss_kwargs: forwarded to compute_loss (e.g., ignore_index)
86
+
87
+ Returns:
88
+ logits (or (loss, logits) if return_loss True)
89
+ """
90
+ if task not in self.heads:
91
+ raise KeyError(f"Unknown task/head '{task}'")
92
+
93
+ head = self.heads[task]
94
+ loss_kwargs = loss_kwargs or {}
95
+
96
+ # Encoder-only heads expect encoder outputs
97
+ if isinstance(head, (ClassificationHead, TokenClassificationHead)):
98
+ if self.encoder is None:
99
+ raise RuntimeError("Encoder is required for encoder-side heads")
100
+ # accept either input_ids or embeddings
101
+ if "input_ids" in inputs:
102
+ enc_out = self.encoder(inputs["input_ids"])
103
+ elif "embeddings" in inputs:
104
+ enc_out = self.encoder(inputs["embeddings"])
105
+ else:
106
+ raise ValueError("inputs must contain 'input_ids' or 'embeddings' for encoder tasks")
107
+ logits = head(enc_out)
108
+
109
+ if return_loss:
110
+ labels = inputs.get("labels", None)
111
+ if labels is None:
112
+ raise ValueError("return_loss=True requires 'labels' in inputs")
113
+ loss = self.compute_loss_for_head(head, logits, labels, **loss_kwargs)
114
+ return loss, logits
115
+ return logits
116
+
117
+ # LM/seq2seq head: run encoder -> decoder -> lm head
118
+ if isinstance(head, LMHead):
119
+ if self.encoder is None or self.decoder is None:
120
+ raise RuntimeError("Both encoder and decoder are required for LM-style heads")
121
+
122
+ # Build encoder memory
123
+ if "src_ids" in inputs:
124
+ memory = self.encoder(inputs["src_ids"])
125
+ elif "src_embeddings" in inputs:
126
+ memory = self.encoder(inputs["src_embeddings"])
127
+ else:
128
+ raise ValueError("inputs must contain 'src_ids' or 'src_embeddings' for seq2seq tasks")
129
+
130
+ # If training / teacher forcing: expect tgt_ids (shifted by caller) or embeddings
131
+ if "tgt_ids" in inputs:
132
+ decoder_inputs = inputs["tgt_ids"]
133
+ elif "tgt_embeddings" in inputs:
134
+ decoder_inputs = inputs["tgt_embeddings"]
135
+ else:
136
+ # For generation time you may call decoder.greedy_decode separately.
137
+ # Here we don't attempt to generate when labels not provided.
138
+ raise ValueError("Seq2seq tasks require 'tgt_ids' or 'tgt_embeddings' for training forward")
139
+
140
+ # Run decoder. Decoder returns logits shaped (B, T, vocab) in this codebase.
141
+ decoder_out = self.decoder(decoder_inputs, memory)
142
+
143
+ # If decoder already returned logits matching the head vocab size, use them directly.
144
+ # Otherwise, assume decoder returned hidden states and let the head project them.
145
+ if isinstance(decoder_out, torch.Tensor) and decoder_out.shape[-1] == head.vocab_size:
146
+ logits = decoder_out
147
+ else:
148
+ logits = head(decoder_out)
149
+
150
+ if return_loss:
151
+ labels = inputs.get("labels", None)
152
+ if labels is None:
153
+ raise ValueError("return_loss=True requires 'labels' in inputs for seq2seq")
154
+ loss = self.compute_loss_for_head(head, logits, labels, **loss_kwargs)
155
+ return loss, logits
156
+ return logits
157
+
158
+ # Otherwise unsupported head type
159
+ raise RuntimeError(f"Unsupported head type: {type(head)}")
160
+
161
+ def compute_loss_for_head(
162
+ self,
163
+ head: nn.Module,
164
+ logits: torch.Tensor,
165
+ labels: torch.Tensor,
166
+ ignore_index: int = -100,
167
+ ) -> torch.Tensor:
168
+ """
169
+ Default loss dispatch:
170
+ - ClassificationHead: CrossEntropy on (B, num_labels)
171
+ - TokenClassificationHead: CrossEntropy per token (flattened)
172
+ - LMHead: CrossEntropy per token (flattened), ignore_index supported
173
+
174
+ Returns scalar loss.
175
+ """
176
+ if isinstance(head, ClassificationHead):
177
+ # logits: (B, num_labels) or (B, num_labels) direct
178
+ loss = F.cross_entropy(logits, labels.long())
179
+ return loss
180
+
181
+ if isinstance(head, TokenClassificationHead):
182
+ # logits: (B, T, C), labels: (B, T)
183
+ B, T, C = logits.shape
184
+ loss = F.cross_entropy(logits.view(B * T, C), labels.view(B * T).long(), ignore_index=ignore_index)
185
+ return loss
186
+
187
+ if isinstance(head, LMHead):
188
+ # logits: (B, T, V), labels: (B, T)
189
+ B, T, V = logits.shape
190
+ loss = F.cross_entropy(logits.view(B * T, V), labels.view(B * T).long(), ignore_index=ignore_index)
191
+ return loss
192
+
193
+ # Generic fall-back: try CrossEntropy on final dim
194
+ if logits.dim() == 2:
195
+ return F.cross_entropy(logits, labels.long())
196
+
197
+ # If we can't determine, raise
198
+ raise RuntimeError("Cannot compute loss for unknown head type")
tests/test_models/test_decoder.py ADDED
@@ -0,0 +1,152 @@
 
1
+ import torch
2
+ import pytest
3
+ from src.models.decoder import (
4
+ create_causal_mask,
5
+ TransformerDecoderLayer,
6
+ TransformerDecoder,
7
+ )
8
+
9
+
10
+ def test_create_causal_mask_properties():
11
+ mask = create_causal_mask(5)
12
+ assert mask.shape == (5, 5)
13
+ # diagonal and below should be True
14
+ for i in range(5):
15
+ for j in range(5):
16
+ if j <= i:
17
+ assert mask[i, j].item() is True
18
+ else:
19
+ assert mask[i, j].item() is False
20
+
21
+
22
+ def test_decoder_layer_shapes_and_grad():
23
+ torch.manual_seed(0)
24
+ d_model, num_heads, d_ff = 32, 4, 64
25
+ batch_size, tgt_len, src_len = 2, 6, 7
26
+
27
+ layer = TransformerDecoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0)
28
+ tgt = torch.randn(batch_size, tgt_len, d_model, requires_grad=True)
29
+ memory = torch.randn(batch_size, src_len, d_model)
30
+
31
+ # No masks
32
+ out, attn = layer(tgt, memory, tgt_mask=None, memory_mask=None)
33
+ assert out.shape == (batch_size, tgt_len, d_model)
34
+ assert isinstance(attn, dict)
35
+ assert "self" in attn and "cross" in attn
36
+ assert attn["self"].shape == (batch_size, num_heads, tgt_len, tgt_len)
37
+ assert attn["cross"].shape == (batch_size, num_heads, tgt_len, src_len)
38
+
39
+ # Backprop works
40
+ loss = out.sum()
41
+ loss.backward()
42
+ grads = [p.grad for p in layer.parameters() if p.requires_grad]
43
+ assert any(g is not None for g in grads)
44
+
45
+
46
+ def test_decoder_layer_causal_mask_blocks_future():
47
+ torch.manual_seed(1)
48
+ d_model, num_heads, d_ff = 48, 6, 128
49
+ batch_size, tgt_len, src_len = 1, 5, 5
50
+
51
+ layer = TransformerDecoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0)
52
+ # random tgt embeddings; with dropout=0.0 the layer output is deterministic
53
+ tgt = torch.randn(batch_size, tgt_len, d_model)
54
+ memory = torch.randn(batch_size, src_len, d_model)
55
+
56
+ causal = create_causal_mask(tgt_len, device=tgt.device) # (T, T)
57
+ tgt_mask = causal.unsqueeze(0) # (1, T, T) -> layer will handle unsqueeze to heads
58
+
59
+ out, attn = layer(tgt, memory, tgt_mask=tgt_mask, memory_mask=None)
60
+ self_attn = attn["self"].detach()
61
+ # Ensure upper triangle of attention weights is zero (no future attention)
62
+ # For each head and query i, keys j>i should be zero
63
+ B, H, Tq, Tk = self_attn.shape
64
+ for i in range(Tq):
65
+ for j in range(i + 1, Tk):
66
+ assert torch.allclose(self_attn[:, :, i, j], torch.zeros(B, H)), \
67
+ f"Found nonzero attention to future position {j} from query {i}"
68
+
69
+
70
+ def test_decoder_stack_and_greedy_decode_shapes():
71
+ torch.manual_seed(2)
72
+ vocab_size = 30
73
+ d_model = 32
74
+ num_layers = 2
75
+ num_heads = 4
76
+ d_ff = 128
77
+ batch_size = 2
78
+ src_len = 7
79
+ max_tgt = 6
80
+
81
+ decoder = TransformerDecoder(
82
+ vocab_size=vocab_size,
83
+ d_model=d_model,
84
+ num_layers=num_layers,
85
+ num_heads=num_heads,
86
+ d_ff=d_ff,
87
+ dropout=0.0,
88
+ max_len=max_tgt,
89
+ pad_token_id=0,
90
+ )
91
+
92
+ # Random memory from encoder
93
+ memory = torch.randn(batch_size, src_len, d_model)
94
+
95
+ # Greedy decode: should produce (B, <= max_tgt)
96
+ generated = decoder.greedy_decode(memory, max_len=max_tgt, start_token_id=1, end_token_id=None)
97
+ assert generated.shape[0] == batch_size
98
+ assert generated.shape[1] <= max_tgt
99
+ assert (generated[:, 0] == 1).all() # starts with start token
100
+
101
+ # Also test forward with embeddings and collect_attn
102
+ embeddings = torch.randn(batch_size, max_tgt, d_model)
103
+ logits, attn_list = decoder(embeddings, memory, collect_attn=True)
104
+ assert logits.shape == (batch_size, max_tgt, vocab_size)
105
+ assert isinstance(attn_list, list)
106
+ assert len(attn_list) == num_layers
107
+ for attn in attn_list:
108
+ assert "self" in attn and "cross" in attn
109
+
110
+
111
+ def test_decoder_train_eval_dropout_behavior():
112
+ torch.manual_seed(3)
113
+ vocab_size = 40
114
+ d_model = 32
115
+ num_layers = 2
116
+ num_heads = 4
117
+ d_ff = 128
118
+ batch_size = 2
119
+ src_len = 6
120
+ tgt_len = 5
121
+
122
+ decoder = TransformerDecoder(
123
+ vocab_size=vocab_size,
124
+ d_model=d_model,
125
+ num_layers=num_layers,
126
+ num_heads=num_heads,
127
+ d_ff=d_ff,
128
+ dropout=0.4,
129
+ max_len=tgt_len,
130
+ pad_token_id=0,
131
+ )
132
+
133
+ # token ids with padding possible
134
+ input_ids = torch.randint(1, vocab_size, (batch_size, tgt_len), dtype=torch.long)
135
+ input_ids[0, -1] = 0
136
+
137
+ memory = torch.randn(batch_size, src_len, d_model)
138
+
139
+ decoder.train()
140
+ out1 = decoder(input_ids, memory)
141
+ out2 = decoder(input_ids, memory)
142
+ # With dropout in train mode, outputs should usually differ
143
+ assert not torch.allclose(out1, out2)
144
+
145
+ decoder.eval()
146
+ out3 = decoder(input_ids, memory)
147
+ out4 = decoder(input_ids, memory)
148
+ assert torch.allclose(out3, out4)
149
+
150
+
151
+ if __name__ == "__main__":
152
+ pytest.main([__file__, "-q"])
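
The causal-mask test above can assert exact zeros because masked positions are filled with -inf before the softmax inside the attention module, so their post-softmax weights are exactly 0. A quick standalone check of that property, assuming the same True-means-allowed mask convention:

    import torch

    scores = torch.randn(1, 1, 3, 3)
    allowed = torch.tril(torch.ones(3, 3, dtype=torch.bool))   # True = may attend
    weights = scores.masked_fill(~allowed, float("-inf")).softmax(dim=-1)
    assert torch.all(weights[..., ~allowed] == 0)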
tests/test_models/test_decoder_step.py ADDED
@@ -0,0 +1,98 @@
 
1
+ import torch
2
+ import pytest
3
+ from typing import Any, Dict, cast
4
+ from src.models.decoder import TransformerDecoder
5
+
6
+
7
+ def test_step_equivalence_with_greedy_decode():
8
+ torch.manual_seed(7)
9
+ vocab_size = 25
10
+ d_model = 32
11
+ num_layers = 2
12
+ num_heads = 4
13
+ d_ff = 64
14
+ batch_size = 2
15
+ src_len = 6
16
+ max_tgt = 6
17
+
18
+ decoder = TransformerDecoder(
19
+ vocab_size=vocab_size,
20
+ d_model=d_model,
21
+ num_layers=num_layers,
22
+ num_heads=num_heads,
23
+ d_ff=d_ff,
24
+ dropout=0.0,
25
+ max_len=max_tgt,
26
+ pad_token_id=0,
27
+ )
28
+
29
+ memory = torch.randn(batch_size, src_len, d_model)
30
+
31
+ # 1) Get greedy sequence from naive greedy_decode
32
+ greedy = decoder.greedy_decode(memory, max_len=max_tgt, start_token_id=1, end_token_id=None)
33
+
34
+ # 2) Reproduce the same sequence with step() using cache
35
+ cache: Dict[str, Any] = {"past_length": 0}
36
+ generated = torch.full((batch_size, 1), 1, dtype=torch.long)
37
+ for _ in range(max_tgt - 1):
38
+ last_token = generated[:, -1:].to(memory.device)
39
+ logits, cache = decoder.step(cast(torch.LongTensor, last_token), memory, cache=cache)
40
+ next_token = logits.argmax(dim=-1, keepdim=True)
41
+ generated = torch.cat([generated, next_token], dim=1)
42
+
43
+ # Compare shapes & that sequences are identical
44
+ assert generated.shape == greedy.shape
45
+ assert torch.equal(generated, greedy)
46
+
47
+
48
+ def test_step_cache_growth_and_shapes():
49
+ torch.manual_seed(9)
50
+ vocab_size = 20
51
+ d_model = 24
52
+ num_layers = 3
53
+ num_heads = 4
54
+ d_ff = 64
55
+ batch_size = 1
56
+ src_len = 5
57
+ steps = 4
58
+ max_tgt = 8
59
+
60
+ decoder = TransformerDecoder(
61
+ vocab_size=vocab_size,
62
+ d_model=d_model,
63
+ num_layers=num_layers,
64
+ num_heads=num_heads,
65
+ d_ff=d_ff,
66
+ dropout=0.0,
67
+ max_len=max_tgt,
68
+ pad_token_id=0,
69
+ )
70
+
71
+ memory = torch.randn(batch_size, src_len, d_model)
72
+
73
+ cache: Dict[str, Any] = {"past_length": 0}
74
+ last = torch.full((batch_size, 1), 1, dtype=torch.long)
75
+ for step_idx in range(steps):
76
+ logits, cache = decoder.step(cast(torch.LongTensor, last), memory, cache=cache)
77
+ # check updated past_length
78
+ assert cache["past_length"] == step_idx + 1
79
+ # check cached per-layer keys exist and have expected shape (B, H, seq_len, d_k)
80
+ for i in range(num_layers):
81
+ k = cache.get(f"self_k_{i}")
82
+ v = cache.get(f"self_v_{i}")
83
+ assert k is not None and v is not None
84
+ # seq_len should equal past_length
85
+ assert k.shape[2] == cache["past_length"]
86
+ # shapes match
87
+ assert k.shape[0] == batch_size
88
+ assert v.shape[0] == batch_size
89
+ # advance last token for next loop
90
+ last = logits.argmax(dim=-1, keepdim=True)
91
+
92
+ # Also ensure memory projections cached
93
+ for i in range(num_layers):
94
+ assert f"mem_k_{i}" in cache and f"mem_v_{i}" in cache
95
+ mem_k = cache[f"mem_k_{i}"]
96
+ mem_v = cache[f"mem_v_{i}"]
97
+ assert mem_k.shape[0] == batch_size
98
+ assert mem_k.shape[2] == src_len # seq length of memory
tests/test_models/test_heads.py ADDED
@@ -0,0 +1,104 @@
 
1
+ import torch
2
+ import pytest
3
+ import torch.nn as nn
4
+ from src.models.heads import (
5
+ ClassificationHead,
6
+ TokenClassificationHead,
7
+ LMHead,
8
+ ProjectionHead,
9
+ )
10
+
11
+
12
+ def test_classification_head_shapes_and_dropout():
13
+ torch.manual_seed(0)
14
+ d_model = 64
15
+ num_labels = 5
16
+ batch_size = 3
17
+ seq_len = 10
18
+
19
+ head = ClassificationHead(d_model=d_model, num_labels=num_labels, pooler="mean", dropout=0.5)
20
+ head.train()
21
+ x = torch.randn(batch_size, seq_len, d_model)
22
+
23
+ out1 = head(x)
24
+ out2 = head(x)
25
+ # With dropout in train mode, outputs should usually differ
26
+ assert out1.shape == (batch_size, num_labels)
27
+ assert out2.shape == (batch_size, num_labels)
28
+ assert not torch.allclose(out1, out2)
29
+
30
+ head.eval()
31
+ out3 = head(x)
32
+ out4 = head(x)
33
+ assert torch.allclose(out3, out4), "Eval mode should be deterministic"
34
+
35
+
36
+ def test_token_classification_head_shapes_and_grads():
37
+ torch.manual_seed(1)
38
+ d_model = 48
39
+ num_labels = 7
40
+ batch_size = 2
41
+ seq_len = 6
42
+
43
+ head = TokenClassificationHead(d_model=d_model, num_labels=num_labels, dropout=0.0)
44
+ x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
45
+ out = head(x)
46
+ assert out.shape == (batch_size, seq_len, num_labels)
47
+
48
+ loss = out.sum()
49
+ loss.backward()
50
+ grads = [p.grad for name, p in head.named_parameters() if p.requires_grad]
51
+ assert any(g is not None for g in grads)
52
+
53
+
54
+ def test_lm_head_tie_weights_and_shapes():
55
+ torch.manual_seed(2)
56
+ vocab_size = 50
57
+ d_model = 32
58
+ batch_size = 2
59
+ seq_len = 4
60
+
61
+ embedding = nn.Embedding(vocab_size, d_model)
62
+ lm_tied = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=embedding)
63
+ lm_untied = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=None)
64
+
65
+ hidden = torch.randn(batch_size, seq_len, d_model)
66
+
67
+ # Shapes
68
+ logits_tied = lm_tied(hidden)
69
+ logits_untied = lm_untied(hidden)
70
+ assert logits_tied.shape == (batch_size, seq_len, vocab_size)
71
+ assert logits_untied.shape == (batch_size, seq_len, vocab_size)
72
+
73
+ # Weight tying: projection weight should be the same object as embedding.weight
74
+ assert lm_tied.proj.weight is embedding.weight
75
+
76
+ # Grad flows through tied weights
77
+ loss = logits_tied.sum()
78
+ loss.backward()
79
+ assert embedding.weight.grad is not None
80
+
81
+
82
+ def test_projection_head_2d_and_3d_behavior_and_grad():
83
+ torch.manual_seed(3)
84
+ d_model = 40
85
+ proj_dim = 16
86
+ batch_size = 2
87
+ seq_len = 5
88
+
89
+ head = ProjectionHead(d_model=d_model, proj_dim=proj_dim, hidden_dim=64, dropout=0.0)
90
+ # 2D input
91
+ vec = torch.randn(batch_size, d_model, requires_grad=True)
92
+ out2 = head(vec)
93
+ assert out2.shape == (batch_size, proj_dim)
94
+
95
+ # 3D input
96
+ seq = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
97
+ out3 = head(seq)
98
+ assert out3.shape == (batch_size, seq_len, proj_dim)
99
+
100
+ # Grad flow
101
+ loss = out3.sum()
102
+ loss.backward()
103
+ grads = [p.grad for p in head.parameters() if p.requires_grad]
104
+ assert any(g is not None for g in grads)
tests/test_models/test_multitask.py ADDED
@@ -0,0 +1,102 @@
 
1
+ import torch
2
+ import pytest
3
+ from src.models.encoder import TransformerEncoder
4
+ from src.models.decoder import TransformerDecoder
5
+ from src.models.heads import ClassificationHead, LMHead, TokenClassificationHead
6
+ from src.models.multitask import MultiTaskModel
7
+
8
+
9
+ def test_multitask_encoder_classification_forward_and_loss():
10
+ torch.manual_seed(0)
11
+ vocab_size = 30
12
+ d_model = 32
13
+ num_layers = 2
14
+ num_heads = 4
15
+ d_ff = 64
16
+ batch_size = 3
17
+ seq_len = 8
18
+ num_labels = 5
19
+
20
+ enc = TransformerEncoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
21
+ num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=seq_len, pad_token_id=0)
22
+
23
+ mt = MultiTaskModel(encoder=enc)
24
+ head = ClassificationHead(d_model=d_model, num_labels=num_labels, pooler="mean", dropout=0.0)
25
+ mt.add_head("sentiment", head)
26
+
27
+ input_ids = torch.randint(1, vocab_size, (batch_size, seq_len), dtype=torch.long)
28
+ labels = torch.randint(0, num_labels, (batch_size,), dtype=torch.long)
29
+
30
+ logits = mt.forward("sentiment", {"input_ids": input_ids})
31
+ assert logits.shape == (batch_size, num_labels)
32
+
33
+ loss, logits2 = mt.forward("sentiment", {"input_ids": input_ids, "labels": labels}, return_loss=True)
34
+ assert loss.item() >= 0
35
+ # grads
36
+ loss.backward()
37
+ grads = [p.grad for p in mt.parameters() if p.requires_grad]
38
+ assert any(g is not None for g in grads)
39
+
40
+
41
+ def test_multitask_seq2seq_lm_forward_and_loss():
42
+ torch.manual_seed(1)
43
+ vocab_size = 40
44
+ d_model = 32
45
+ num_layers = 2
46
+ num_heads = 4
47
+ d_ff = 64
48
+ batch_size = 2
49
+ src_len = 7
50
+ tgt_len = 6
51
+
52
+ enc = TransformerEncoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
53
+ num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=src_len, pad_token_id=0)
54
+ dec = TransformerDecoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
55
+ num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=tgt_len, pad_token_id=0)
56
+ mt = MultiTaskModel(encoder=enc, decoder=dec)
57
+ lm_head = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=None)
58
+ mt.add_head("summarize", lm_head)
59
+
60
+ src_ids = torch.randint(1, vocab_size, (batch_size, src_len), dtype=torch.long)
61
+ # for training: provide decoder inputs (typically shifted right) and labels
62
+ tgt_ids = torch.randint(1, vocab_size, (batch_size, tgt_len), dtype=torch.long)
63
+ labels = tgt_ids.clone()
64
+
65
+ logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids})
66
+ assert logits.shape == (batch_size, tgt_len, vocab_size)
67
+
68
+ loss, logits2 = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids, "labels": labels}, return_loss=True)
69
+ assert loss.item() >= 0
70
+ loss.backward()
71
+ grads = [p.grad for p in mt.parameters() if p.requires_grad]
72
+ assert any(g is not None for g in grads)
73
+
74
+
75
+ def test_token_classification_forward_and_loss():
76
+ torch.manual_seed(2)
77
+ vocab_size = 20
78
+ d_model = 24
79
+ num_layers = 2
80
+ num_heads = 4
81
+ d_ff = 64
82
+ batch_size = 2
83
+ seq_len = 5
84
+ num_labels = 7
85
+
86
+ enc = TransformerEncoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
87
+ num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=seq_len, pad_token_id=0)
88
+ mt = MultiTaskModel(encoder=enc)
89
+ head = TokenClassificationHead(d_model=d_model, num_labels=num_labels, dropout=0.0)
90
+ mt.add_head("ner", head)
91
+
92
+ input_ids = torch.randint(1, vocab_size, (batch_size, seq_len), dtype=torch.long)
93
+ labels = torch.randint(0, num_labels, (batch_size, seq_len), dtype=torch.long)
94
+
95
+ logits = mt.forward("ner", {"input_ids": input_ids})
96
+ assert logits.shape == (batch_size, seq_len, num_labels)
97
+
98
+ loss, logits2 = mt.forward("ner", {"input_ids": input_ids, "labels": labels}, return_loss=True)
99
+ assert loss.item() >= 0
100
+ loss.backward()
101
+ grads = [p.grad for p in mt.parameters() if p.requires_grad]
102
+ assert any(g is not None for g in grads)
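
The new tests import the package as src.models, so they assume pytest is run from the repository root (each file also exposes a pytest.main entry point). A typical invocation would be:

    pytest tests/test_models -q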