Spaces:
Runtime error
Runtime error
Commit ·
f672f40
1
Parent(s): aefdcf0
Finished Decoder Implementation, as well as prediction heads and multitask
Browse files- src/models/decoder.py +242 -150
- src/models/heads.py +151 -0
- src/models/multitask.py +198 -0
- tests/test_models/test_decoder.py +152 -0
- tests/test_models/test_decoder_step.py +98 -0
- tests/test_models/test_heads.py +104 -0
- tests/test_models/test_multitask.py +102 -0
src/models/decoder.py
CHANGED
|
@@ -1,28 +1,16 @@
|
|
| 1 |
"""
|
| 2 |
-
Transformer Decoder
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
- create_causal_mask
|
| 6 |
-
- TransformerDecoderLayer
|
| 7 |
-
- TransformerDecoder
|
| 8 |
-
|
| 9 |
-
Notes / conventions:
|
| 10 |
-
- Pre-LN (LayerNorm before each sublayer) is assumed for stability (consistent with your encoder).
|
| 11 |
-
- MultiHeadAttention, FeedForward, PositionalEncoding are expected to live in src/models
|
| 12 |
-
(you already implemented them).
|
| 13 |
-
- Masks use boolean semantics: True = allowed, False = masked.
|
| 14 |
-
- The decoder API supports:
|
| 15 |
-
- inputs: token ids (LongTensor, (B, T)) or embeddings ((B, T, d_model))
|
| 16 |
-
- memory: encoder outputs (B, S, d_model)
|
| 17 |
-
- mask arguments: tgt_mask (causal/padding), memory_mask (encoder padding)
|
| 18 |
-
- collect_attn: return attention maps per layer if requested
|
| 19 |
-
- Generation helpers (greedy) are skeletons that you can extend to beam search or caching.
|
| 20 |
-
|
| 21 |
-
TODO status keys:
|
| 22 |
-
- [IMPLEMENT] : core implementation required
|
| 23 |
-
- [OPTIONAL] : useful enhancement (caching, beam search, advanced scheduling)
|
| 24 |
-
"""
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
from typing import Optional, Tuple, List, Union, Dict
|
| 27 |
import math
|
| 28 |
import torch
|
|
@@ -35,47 +23,34 @@ from .positional_encoding import PositionalEncoding
|
|
| 35 |
|
| 36 |
def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
|
| 37 |
"""
|
| 38 |
-
Create a
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
Returns:
|
| 42 |
-
mask: torch.BoolTensor of shape (seq_len, seq_len)
|
| 43 |
"""
|
| 44 |
-
#
|
| 45 |
-
|
| 46 |
-
mask
|
| 47 |
-
# mask has True above diagonal (to be masked). Want True for allowed, so invert:
|
| 48 |
-
return ~mask # shape (seq_len, seq_len) or (T, T)
|
| 49 |
|
| 50 |
|
| 51 |
class TransformerDecoderLayer(nn.Module):
|
| 52 |
"""
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
d_model: model hidden size
|
| 61 |
-
num_heads: number of attention heads
|
| 62 |
-
d_ff: ff intermediate size
|
| 63 |
-
dropout: dropout for residuals / FFN
|
| 64 |
"""
|
| 65 |
|
| 66 |
def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
|
| 67 |
super().__init__()
|
| 68 |
-
#
|
| 69 |
self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)
|
| 70 |
self.cross_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)
|
| 71 |
self.ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)
|
| 72 |
|
| 73 |
-
# LayerNorms (Pre-LN)
|
| 74 |
self.norm1 = nn.LayerNorm(d_model)
|
| 75 |
self.norm2 = nn.LayerNorm(d_model)
|
| 76 |
self.norm3 = nn.LayerNorm(d_model)
|
| 77 |
|
| 78 |
-
# Dropouts applied after sublayers (on sublayer outputs before residual add)
|
| 79 |
self.dropout1 = nn.Dropout(dropout)
|
| 80 |
self.dropout2 = nn.Dropout(dropout)
|
| 81 |
self.dropout3 = nn.Dropout(dropout)
|
|
@@ -88,46 +63,51 @@ class TransformerDecoderLayer(nn.Module):
|
|
| 88 |
memory_mask: Optional[torch.Tensor] = None,
|
| 89 |
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
| 90 |
"""
|
| 91 |
-
Forward pass for one decoder layer.
|
| 92 |
-
|
| 93 |
Args:
|
| 94 |
-
tgt: (
|
| 95 |
-
memory: (
|
| 96 |
-
tgt_mask: optional (
|
| 97 |
-
memory_mask: optional (
|
| 98 |
|
| 99 |
Returns:
|
| 100 |
-
|
| 101 |
-
attn_maps: dict with keys 'self' and 'cross' containing attention tensors
|
| 102 |
-
shapes: (batch, num_heads, tgt_len, tgt_len) and (batch, num_heads, tgt_len, src_len)
|
| 103 |
"""
|
| 104 |
-
#
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
-
#
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
|
| 114 |
-
#
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
|
| 119 |
-
|
| 120 |
-
raise NotImplementedError("TransformerDecoderLayer.forward: implement Pre-LN pipeline")
|
| 121 |
|
| 122 |
|
| 123 |
class TransformerDecoder(nn.Module):
|
| 124 |
"""
|
| 125 |
-
|
| 126 |
-
Also supports simple greedy decoding.
|
| 127 |
|
| 128 |
-
|
| 129 |
-
vocab_size: for token embeddings (if using token ids)
|
| 130 |
-
d_model, num_layers, num_heads, d_ff, dropout, max_len, pad_token_id: same semantics as encoder
|
| 131 |
"""
|
| 132 |
|
| 133 |
def __init__(
|
|
@@ -146,37 +126,25 @@ class TransformerDecoder(nn.Module):
|
|
| 146 |
self.d_model = d_model
|
| 147 |
self.pad_token_id = pad_token_id
|
| 148 |
|
| 149 |
-
# Token embedding (used if inputs are token ids)
|
| 150 |
self.embedding = nn.Embedding(vocab_size, d_model)
|
| 151 |
-
|
| 152 |
-
# Positional encoding
|
| 153 |
self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)
|
| 154 |
|
| 155 |
-
# Decoder layers
|
| 156 |
self.layers = nn.ModuleList(
|
| 157 |
-
[
|
| 158 |
-
|
| 159 |
-
for _ in range(num_layers)
|
| 160 |
-
]
|
| 161 |
)
|
| 162 |
|
| 163 |
-
# Final layer norm for Pre-LN stacks
|
| 164 |
self.final_norm = nn.LayerNorm(d_model)
|
| 165 |
-
|
| 166 |
-
# Output projection to vocabulary (logits)
|
| 167 |
self.output_projection = nn.Linear(d_model, vocab_size)
|
| 168 |
-
|
| 169 |
-
# Input dropout (after pos encoding)
|
| 170 |
self.input_dropout = nn.Dropout(dropout)
|
| 171 |
|
| 172 |
def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 173 |
"""
|
| 174 |
-
|
| 175 |
-
True = allowed, False = masked.
|
| 176 |
"""
|
| 177 |
assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
|
| 178 |
-
pad_mask = (input_ids != self.pad_token_id) # (B,
|
| 179 |
-
attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2) # (B,
|
| 180 |
return attn_mask
|
| 181 |
|
| 182 |
def forward(
|
|
@@ -188,21 +156,13 @@ class TransformerDecoder(nn.Module):
|
|
| 188 |
collect_attn: bool = False,
|
| 189 |
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]]:
|
| 190 |
"""
|
| 191 |
-
Forward pass for the decoder stack.
|
| 192 |
-
|
| 193 |
Args:
|
| 194 |
-
inputs:
|
| 195 |
-
memory:
|
| 196 |
-
tgt_mask: optional
|
| 197 |
-
|
| 198 |
-
memory_mask: optional mask over memory (B, S, S) or (B, 1, T, S)
|
| 199 |
-
collect_attn: if True returns (logits/outputs, [per-layer attn dicts])
|
| 200 |
-
|
| 201 |
-
Returns:
|
| 202 |
-
logits: (B, T, vocab_size) or (B, T, d_model) if you prefer returning hidden states
|
| 203 |
-
or (logits, attn_list) if collect_attn True
|
| 204 |
"""
|
| 205 |
-
#
|
| 206 |
if inputs.dim() == 2: # token ids
|
| 207 |
x = self.embedding(inputs) * math.sqrt(self.d_model)
|
| 208 |
elif inputs.dim() == 3:
|
|
@@ -210,47 +170,48 @@ class TransformerDecoder(nn.Module):
|
|
| 210 |
else:
|
| 211 |
raise ValueError("inputs must be (B, T) token ids or (B, T, d_model) embeddings")
|
| 212 |
|
| 213 |
-
# Positional encoding + dropout
|
| 214 |
x = self.pos_encoder(x)
|
| 215 |
x = self.input_dropout(x)
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
|
|
|
| 219 |
if tgt_mask is None:
|
| 220 |
-
|
| 221 |
-
causal = create_causal_mask(seq_len, device=x.device) # [TODO implement]
|
| 222 |
-
# expand to batch dim later if padding present
|
| 223 |
if inputs.dim() == 2 and self.pad_token_id is not None:
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
raise NotImplementedError("tgt_mask construction: combine causal + padding_mask")
|
| 228 |
else:
|
| 229 |
-
#
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
-
#
|
| 233 |
if memory_mask is not None:
|
| 234 |
memory_mask = memory_mask.to(dtype=torch.bool, device=x.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
attn_list: List[Dict[str, torch.Tensor]] = []
|
| 237 |
|
| 238 |
-
# Pass through layers
|
| 239 |
for layer in self.layers:
|
| 240 |
x, attn = layer(x, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
|
| 241 |
if collect_attn:
|
| 242 |
attn_list.append(attn)
|
| 243 |
|
| 244 |
-
x = self.final_norm(x)
|
| 245 |
-
|
| 246 |
logits = self.output_projection(x) # (B, T, vocab)
|
|
|
|
| 247 |
if collect_attn:
|
| 248 |
return logits, attn_list
|
| 249 |
return logits
|
| 250 |
|
| 251 |
-
# ---------------------------------------------------------------------
|
| 252 |
-
# Generation / inference helpers (skeletons)
|
| 253 |
-
# ---------------------------------------------------------------------
|
| 254 |
def greedy_decode(
|
| 255 |
self,
|
| 256 |
memory: torch.Tensor,
|
|
@@ -258,26 +219,32 @@ class TransformerDecoder(nn.Module):
|
|
| 258 |
start_token_id: int,
|
| 259 |
end_token_id: Optional[int] = None,
|
| 260 |
device: Optional[torch.device] = None,
|
| 261 |
-
) -> torch.
|
| 262 |
"""
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
Args:
|
| 266 |
-
memory: encoder outputs (B, S, d_model)
|
| 267 |
-
max_len: maximum target length to generate
|
| 268 |
-
start_token_id: BOS token id
|
| 269 |
-
end_token_id: optional EOS token id to stop early
|
| 270 |
-
Returns:
|
| 271 |
-
generated: (B, T_out) long tensor of token ids
|
| 272 |
"""
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
def step(
|
| 282 |
self,
|
| 283 |
last_token_ids: torch.LongTensor,
|
|
@@ -285,16 +252,141 @@ class TransformerDecoder(nn.Module):
|
|
| 285 |
cache: Optional[Dict] = None,
|
| 286 |
) -> Tuple[torch.Tensor, Dict]:
|
| 287 |
"""
|
| 288 |
-
|
| 289 |
|
| 290 |
Args:
|
| 291 |
-
last_token_ids: (B, 1)
|
| 292 |
-
memory: encoder outputs
|
| 293 |
-
cache: optional dict
|
| 294 |
|
| 295 |
Returns:
|
| 296 |
-
logits: (B, vocab_size)
|
| 297 |
-
new_cache: updated cache
|
| 298 |
"""
|
| 299 |
-
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Transformer Decoder (Pre-LN) - implementation.
|
| 3 |
+
|
| 4 |
+
Implements:
|
| 5 |
+
- create_causal_mask
|
| 6 |
+
- TransformerDecoderLayer
|
| 7 |
+
- TransformerDecoder (stack + naive greedy decoding)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
Conventions:
|
| 10 |
+
- Masks are boolean: True = allowed, False = masked.
|
| 11 |
+
- MultiHeadAttention expects masks broadcastable to (B, num_heads, T_q, T_k).
|
| 12 |
+
- This decoder uses Pre-LN (LayerNorm before each sublayer).
|
| 13 |
+
"""
|
| 14 |
from typing import Optional, Tuple, List, Union, Dict
|
| 15 |
import math
|
| 16 |
import torch
|
|
|
|
| 23 |
|
| 24 |
def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Build a boolean causal attention mask.

    Entry (i, j) is True iff j <= i, i.e. the query at position i may
    attend to keys at positions 0..i (True = allowed, False = masked).

    Args:
        seq_len: target sequence length T.
        device: optional device for the returned tensor.

    Returns:
        torch.BoolTensor of shape (seq_len, seq_len).
    """
    # The lower triangle (diagonal included) is exactly the allowed region.
    ones = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
    return torch.tril(ones)  # (T, T)
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
class TransformerDecoderLayer(nn.Module):
|
| 35 |
"""
|
| 36 |
+
Single decoder layer (Pre-LN):
|
| 37 |
+
1) Masked self-attention
|
| 38 |
+
2) Cross-attention (encoder -> decoder)
|
| 39 |
+
3) Feed-forward
|
| 40 |
+
Returns the updated tgt and a dict of attention maps.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
"""
|
| 42 |
|
| 43 |
def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
    """
    Args:
        d_model: model hidden size.
        num_heads: number of attention heads.
        d_ff: feed-forward intermediate size.
        dropout: dropout applied to each sublayer output (residual path).
    """
    super().__init__()

    # Sublayers. The attention modules use internal dropout 0.0 because this
    # layer applies its own dropout on every sublayer output before the
    # residual add.
    self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)
    self.cross_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)
    self.ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

    # One Pre-LN LayerNorm per sublayer.
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)

    # One residual-path dropout per sublayer.
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.dropout3 = nn.Dropout(dropout)
|
|
|
|
| 63 |
memory_mask: Optional[torch.Tensor] = None,
|
| 64 |
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
| 65 |
"""
|
|
|
|
|
|
|
| 66 |
Args:
|
| 67 |
+
tgt: (B, T, d_model)
|
| 68 |
+
memory: (B, S, d_model)
|
| 69 |
+
tgt_mask: optional mask for self-attn - shape (B, T, T) or (B, 1, T, T)
|
| 70 |
+
memory_mask: optional mask for cross-attn - shape (B, S) or (B, 1, S) or (B, 1, T, S)
|
| 71 |
|
| 72 |
Returns:
|
| 73 |
+
(tgt_out, {"self": self_attn_weights, "cross": cross_attn_weights})
|
|
|
|
|
|
|
| 74 |
"""
|
| 75 |
+
# Ensure masks are on same device and boolean
|
| 76 |
+
if tgt_mask is not None:
|
| 77 |
+
tgt_mask = tgt_mask.to(dtype=torch.bool, device=tgt.device)
|
| 78 |
+
if memory_mask is not None:
|
| 79 |
+
memory_mask = memory_mask.to(dtype=torch.bool, device=tgt.device)
|
| 80 |
+
# If memory_mask is provided as (B, S) (per-key padding), expand to (B, 1, 1, S)
|
| 81 |
+
if memory_mask.dim() == 2:
|
| 82 |
+
memory_mask = memory_mask.unsqueeze(1).unsqueeze(1) # (B,1,1,S)
|
| 83 |
+
# If it's (B, S, S) or (B, 1, S, S) leave as-is; if (B, T, S) convert to (B,1,T,S)
|
| 84 |
+
elif memory_mask.dim() == 3 and memory_mask.shape[1] != 1:
|
| 85 |
+
# assume (B, T, S) -> make (B, 1, T, S)
|
| 86 |
+
memory_mask = memory_mask.unsqueeze(1)
|
| 87 |
+
|
| 88 |
+
# --- Masked self-attention (Pre-LN) ---
|
| 89 |
+
x_norm = self.norm1(tgt)
|
| 90 |
+
self_out, self_attn = self.self_attn(x_norm, x_norm, x_norm, tgt_mask)
|
| 91 |
+
tgt = tgt + self.dropout1(self_out)
|
| 92 |
|
| 93 |
+
# --- Cross-attention (Pre-LN) ---
|
| 94 |
+
x_norm = self.norm2(tgt)
|
| 95 |
+
cross_out, cross_attn = self.cross_attn(x_norm, memory, memory, memory_mask)
|
| 96 |
+
tgt = tgt + self.dropout2(cross_out)
|
| 97 |
|
| 98 |
+
# --- Feed-forward (Pre-LN) ---
|
| 99 |
+
x_norm = self.norm3(tgt)
|
| 100 |
+
ffn_out = self.ffn(x_norm)
|
| 101 |
+
tgt = tgt + self.dropout3(ffn_out)
|
| 102 |
|
| 103 |
+
return tgt, {"self": self_attn, "cross": cross_attn}
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
class TransformerDecoder(nn.Module):
|
| 107 |
"""
|
| 108 |
+
Decoder stack with token embeddings and positional encoding.
|
|
|
|
| 109 |
|
| 110 |
+
Forward returns logits (B, T, vocab_size) by default; if collect_attn=True returns (logits, attn_list).
|
|
|
|
|
|
|
| 111 |
"""
|
| 112 |
|
| 113 |
def __init__(
|
|
|
|
| 126 |
self.d_model = d_model
|
| 127 |
self.pad_token_id = pad_token_id
|
| 128 |
|
|
|
|
| 129 |
self.embedding = nn.Embedding(vocab_size, d_model)
|
|
|
|
|
|
|
| 130 |
self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)
|
| 131 |
|
|
|
|
| 132 |
self.layers = nn.ModuleList(
|
| 133 |
+
[TransformerDecoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=dropout)
|
| 134 |
+
for _ in range(num_layers)]
|
|
|
|
|
|
|
| 135 |
)
|
| 136 |
|
|
|
|
| 137 |
self.final_norm = nn.LayerNorm(d_model)
|
|
|
|
|
|
|
| 138 |
self.output_projection = nn.Linear(d_model, vocab_size)
|
|
|
|
|
|
|
| 139 |
self.input_dropout = nn.Dropout(dropout)
|
| 140 |
|
| 141 |
def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 142 |
"""
|
| 143 |
+
Convert input ids to (B, T, T) boolean mask where True = allowed.
|
|
|
|
| 144 |
"""
|
| 145 |
assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
|
| 146 |
+
pad_mask = (input_ids != self.pad_token_id) # (B, T)
|
| 147 |
+
attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2) # (B, T, T)
|
| 148 |
return attn_mask
|
| 149 |
|
| 150 |
def forward(
|
|
|
|
| 156 |
collect_attn: bool = False,
|
| 157 |
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]]:
|
| 158 |
"""
|
|
|
|
|
|
|
| 159 |
Args:
|
| 160 |
+
inputs: (B, T) token ids or (B, T, d_model) embeddings
|
| 161 |
+
memory: (B, S, d_model)
|
| 162 |
+
tgt_mask: optional; if None, will create (causal [+ padding if ids available])
|
| 163 |
+
memory_mask: optional; if provided as (B, S) will be expanded to (B, 1, 1, S)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
"""
|
| 165 |
+
# Prepare embeddings
|
| 166 |
if inputs.dim() == 2: # token ids
|
| 167 |
x = self.embedding(inputs) * math.sqrt(self.d_model)
|
| 168 |
elif inputs.dim() == 3:
|
|
|
|
| 170 |
else:
|
| 171 |
raise ValueError("inputs must be (B, T) token ids or (B, T, d_model) embeddings")
|
| 172 |
|
|
|
|
| 173 |
x = self.pos_encoder(x)
|
| 174 |
x = self.input_dropout(x)
|
| 175 |
|
| 176 |
+
B, T, _ = x.shape
|
| 177 |
+
|
| 178 |
+
# Build target mask if not provided: combine causal + padding (if available)
|
| 179 |
if tgt_mask is None:
|
| 180 |
+
causal = create_causal_mask(T, device=x.device) # (T, T)
|
|
|
|
|
|
|
| 181 |
if inputs.dim() == 2 and self.pad_token_id is not None:
|
| 182 |
+
pad_pairwise = self._build_padding_mask_from_ids(inputs) # (B, T, T)
|
| 183 |
+
combined = pad_pairwise & causal.unsqueeze(0) # (B, T, T)
|
| 184 |
+
tgt_mask = combined.unsqueeze(1) # (B, 1, T, T) -> broadcast to heads
|
|
|
|
| 185 |
else:
|
| 186 |
+
# No per-batch padding info: broadcast causal to (1, 1, T, T)
|
| 187 |
+
tgt_mask = causal.unsqueeze(0).unsqueeze(1) # (1, 1, T, T)
|
| 188 |
+
else:
|
| 189 |
+
# Ensure boolean and device alignment; accept (B, T, T) or (B,1,T,T) or (1,1,T,T)
|
| 190 |
+
tgt_mask = tgt_mask.to(dtype=torch.bool, device=x.device)
|
| 191 |
|
| 192 |
+
# Normalize memory_mask dtype/device and expand simple shapes
|
| 193 |
if memory_mask is not None:
|
| 194 |
memory_mask = memory_mask.to(dtype=torch.bool, device=x.device)
|
| 195 |
+
if memory_mask.dim() == 2: # (B, S) -> (B, 1, 1, S)
|
| 196 |
+
memory_mask = memory_mask.unsqueeze(1).unsqueeze(1)
|
| 197 |
+
elif memory_mask.dim() == 3: # (B, T, S) -> (B, 1, T, S)
|
| 198 |
+
memory_mask = memory_mask.unsqueeze(1)
|
| 199 |
|
| 200 |
attn_list: List[Dict[str, torch.Tensor]] = []
|
| 201 |
|
| 202 |
+
# Pass through decoder layers
|
| 203 |
for layer in self.layers:
|
| 204 |
x, attn = layer(x, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
|
| 205 |
if collect_attn:
|
| 206 |
attn_list.append(attn)
|
| 207 |
|
| 208 |
+
x = self.final_norm(x)
|
|
|
|
| 209 |
logits = self.output_projection(x) # (B, T, vocab)
|
| 210 |
+
|
| 211 |
if collect_attn:
|
| 212 |
return logits, attn_list
|
| 213 |
return logits
|
| 214 |
|
|
|
|
|
|
|
|
|
|
| 215 |
def greedy_decode(
|
| 216 |
self,
|
| 217 |
memory: torch.Tensor,
|
|
|
|
| 219 |
start_token_id: int,
|
| 220 |
end_token_id: Optional[int] = None,
|
| 221 |
device: Optional[torch.device] = None,
|
| 222 |
+
) -> torch.Tensor:
|
| 223 |
"""
|
| 224 |
+
Naive greedy decoding: repeatedly run the decoder on the growing prefix.
|
| 225 |
+
Not optimized (recomputes full decoder each step) but simple and correct.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
"""
|
| 227 |
+
if device is None:
|
| 228 |
+
device = memory.device
|
| 229 |
+
B = memory.size(0)
|
| 230 |
+
generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
|
| 231 |
+
|
| 232 |
+
for _ in range(max_len - 1):
|
| 233 |
+
logits = self.forward(generated, memory, collect_attn=False) # (B, L, V)
|
| 234 |
+
assert isinstance(logits, torch.Tensor) # type narrowing
|
| 235 |
+
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True) # (B, 1)
|
| 236 |
+
generated = torch.cat([generated, next_token], dim=1)
|
| 237 |
+
|
| 238 |
+
if end_token_id is not None:
|
| 239 |
+
# stop if all sequences ended
|
| 240 |
+
if (generated[:, -1] == end_token_id).all():
|
| 241 |
+
break
|
| 242 |
+
|
| 243 |
+
return generated
|
| 244 |
+
|
| 245 |
+
# -----------------------------
|
| 246 |
+
# Incremental single-step API
|
| 247 |
+
# -----------------------------
|
| 248 |
def step(
|
| 249 |
self,
|
| 250 |
last_token_ids: torch.LongTensor,
|
|
|
|
| 252 |
cache: Optional[Dict] = None,
|
| 253 |
) -> Tuple[torch.Tensor, Dict]:
|
| 254 |
"""
|
| 255 |
+
Run one autoregressive step.
|
| 256 |
|
| 257 |
Args:
|
| 258 |
+
last_token_ids: (B, 1) last generated token ids
|
| 259 |
+
memory: encoder outputs (B, S, d_model)
|
| 260 |
+
cache: optional dict with previous cached keys/values and 'past_length'.
|
| 261 |
|
| 262 |
Returns:
|
| 263 |
+
logits: (B, vocab_size) logits for the next-token prediction
|
| 264 |
+
new_cache: updated cache dictionary
|
| 265 |
"""
|
| 266 |
+
device = memory.device
|
| 267 |
+
B = last_token_ids.size(0)
|
| 268 |
+
|
| 269 |
+
if cache is None:
|
| 270 |
+
cache = {}
|
| 271 |
+
past_len = int(cache.get("past_length", 0))
|
| 272 |
+
|
| 273 |
+
# 1) Embed last token and add positional encoding for position `past_len`
|
| 274 |
+
x = self.embedding(last_token_ids) * math.sqrt(self.d_model) # (B,1,d)
|
| 275 |
+
# Use positional encoding buffer directly (avoid dropout in pos_encoder)
|
| 276 |
+
# pos_encoder.pe expected shape (1, max_len, d_model)
|
| 277 |
+
if hasattr(self.pos_encoder, "pe"):
|
| 278 |
+
pe = self.pos_encoder.pe # (1, max_len, d_model)
|
| 279 |
+
pos_idx = past_len
|
| 280 |
+
if pos_idx >= pe.size(1):
|
| 281 |
+
raise RuntimeError(f"pos_idx {pos_idx} exceeds max_len {pe.size(1)}")
|
| 282 |
+
x = x + pe[:, pos_idx:pos_idx + 1, :].to(device)
|
| 283 |
+
else:
|
| 284 |
+
# fallback: call pos_encoder and rely on its dropout (less ideal)
|
| 285 |
+
x = self.pos_encoder(x)
|
| 286 |
+
|
| 287 |
+
# We will update new_cache incrementally
|
| 288 |
+
new_cache = dict(cache) # shallow copy
|
| 289 |
+
new_cache["past_length"] = past_len + 1
|
| 290 |
+
|
| 291 |
+
# Optional: memory_mask could be supplied in cache under 'memory_mask'
|
| 292 |
+
memory_mask = new_cache.get("memory_mask", None)
|
| 293 |
+
if memory_mask is not None:
|
| 294 |
+
memory_mask = memory_mask.to(dtype=torch.bool, device=device)
|
| 295 |
+
# expand (B, S) -> (B,1,1,S) if necessary
|
| 296 |
+
if memory_mask.dim() == 2:
|
| 297 |
+
memory_mask = memory_mask.unsqueeze(1).unsqueeze(1)
|
| 298 |
+
elif memory_mask.dim() == 3:
|
| 299 |
+
memory_mask = memory_mask.unsqueeze(1)
|
| 300 |
+
|
| 301 |
+
# Iterate layers, updating caches and computing output for current token only
|
| 302 |
+
layer_input = x # (B,1,d_model)
|
| 303 |
+
for i, layer in enumerate(self.layers):
|
| 304 |
+
# -------------------
|
| 305 |
+
# 1) Self-attention (incremental)
|
| 306 |
+
# -------------------
|
| 307 |
+
# Normalize input for pre-LN
|
| 308 |
+
x_norm = layer.norm1(layer_input) # (B,1,d)
|
| 309 |
+
|
| 310 |
+
# Project Q,K,V for the new token
|
| 311 |
+
Q_new = layer.self_attn.W_Q(x_norm) # (B,1,d_model)
|
| 312 |
+
K_new = layer.self_attn.W_K(x_norm)
|
| 313 |
+
V_new = layer.self_attn.W_V(x_norm)
|
| 314 |
+
|
| 315 |
+
# Reshape into heads: (B, num_heads, 1, d_k)
|
| 316 |
+
B_, Lq, _ = Q_new.shape
|
| 317 |
+
num_heads = layer.self_attn.num_heads
|
| 318 |
+
d_k = layer.self_attn.d_k
|
| 319 |
+
Qh = Q_new.view(B_, Lq, num_heads, d_k).transpose(1, 2) # (B, num_heads, 1, d_k)
|
| 320 |
+
Kh = K_new.view(B_, Lq, num_heads, d_k).transpose(1, 2)
|
| 321 |
+
Vh = V_new.view(B_, Lq, num_heads, d_k).transpose(1, 2)
|
| 322 |
+
|
| 323 |
+
# Retrieve cached keys/values for self-attn (if exist)
|
| 324 |
+
cache_k = cache.get(f"self_k_{i}", None)
|
| 325 |
+
cache_v = cache.get(f"self_v_{i}", None)
|
| 326 |
+
if cache_k is None or cache_v is None:
|
| 327 |
+
K_all = Kh # (B, H, 1, d_k)
|
| 328 |
+
V_all = Vh
|
| 329 |
+
else:
|
| 330 |
+
# concat along sequence dim (dim=2)
|
| 331 |
+
K_all = torch.cat([cache_k.to(device), Kh], dim=2)
|
| 332 |
+
V_all = torch.cat([cache_v.to(device), Vh], dim=2)
|
| 333 |
+
|
| 334 |
+
# Store updated caches
|
| 335 |
+
new_cache[f"self_k_{i}"] = K_all
|
| 336 |
+
new_cache[f"self_v_{i}"] = V_all
|
| 337 |
+
|
| 338 |
+
# Compute attention for the new token: Query length = 1, Key length = K_all.size(2)
|
| 339 |
+
attn_out_heads, self_attn_w = layer.self_attn.attention(Qh, K_all, V_all, mask=None)
|
| 340 |
+
# attn_out_heads: (B, H, 1, d_k)
|
| 341 |
+
# concat heads, project out
|
| 342 |
+
attn_out = attn_out_heads.transpose(1, 2).contiguous().view(B_, 1, num_heads * d_k)
|
| 343 |
+
attn_out = layer.self_attn.W_O(attn_out) # (B,1,d_model)
|
| 344 |
+
layer_output = layer_input + layer.dropout1(attn_out)
|
| 345 |
+
|
| 346 |
+
# -------------------
|
| 347 |
+
# 2) Cross-attention (use cached memory projections if available)
|
| 348 |
+
# -------------------
|
| 349 |
+
x_norm2 = layer.norm2(layer_output) # (B,1,d)
|
| 350 |
+
# Ensure memory K/V are cached per layer
|
| 351 |
+
mem_k = cache.get(f"mem_k_{i}", None)
|
| 352 |
+
mem_v = cache.get(f"mem_v_{i}", None)
|
| 353 |
+
if mem_k is None or mem_v is None:
|
| 354 |
+
# project memory once for this layer and cache it
|
| 355 |
+
# memory: (B, S, d_model)
|
| 356 |
+
MK = layer.cross_attn.W_K(memory) # (B, S, d_model)
|
| 357 |
+
MV = layer.cross_attn.W_V(memory)
|
| 358 |
+
Bm, S, _ = MK.shape
|
| 359 |
+
MKh = MK.view(Bm, S, layer.cross_attn.num_heads, layer.cross_attn.d_k).transpose(1, 2) # (B,H,S,d_k)
|
| 360 |
+
MVh = MV.view(Bm, S, layer.cross_attn.num_heads, layer.cross_attn.d_k).transpose(1, 2)
|
| 361 |
+
mem_k = MKh
|
| 362 |
+
mem_v = MVh
|
| 363 |
+
new_cache[f"mem_k_{i}"] = mem_k
|
| 364 |
+
new_cache[f"mem_v_{i}"] = mem_v
|
| 365 |
+
else:
|
| 366 |
+
mem_k = mem_k.to(device)
|
| 367 |
+
mem_v = mem_v.to(device)
|
| 368 |
+
|
| 369 |
+
Qc = layer.cross_attn.W_Q(x_norm2) # (B,1,d_model)
|
| 370 |
+
Qch = Qc.view(B, 1, layer.cross_attn.num_heads, layer.cross_attn.d_k).transpose(1, 2) # (B,H,1,d_k)
|
| 371 |
+
|
| 372 |
+
cross_out_heads, cross_attn_w = layer.cross_attn.attention(Qch, mem_k, mem_v, mask=memory_mask)
|
| 373 |
+
cross_out = cross_out_heads.transpose(1, 2).contiguous().view(B, 1, layer.cross_attn.num_heads * layer.cross_attn.d_k)
|
| 374 |
+
cross_out = layer.cross_attn.W_O(cross_out) # (B,1,d_model)
|
| 375 |
+
layer_output = layer_output + layer.dropout2(cross_out)
|
| 376 |
+
|
| 377 |
+
# -------------------
|
| 378 |
+
# 3) Feed-forward (incremental)
|
| 379 |
+
# -------------------
|
| 380 |
+
x_norm3 = layer.norm3(layer_output)
|
| 381 |
+
ffn_out = layer.ffn(x_norm3) # (B,1,d_model)
|
| 382 |
+
layer_output = layer_output + layer.dropout3(ffn_out)
|
| 383 |
+
|
| 384 |
+
# Prepare for next layer
|
| 385 |
+
layer_input = layer_output
|
| 386 |
+
|
| 387 |
+
# Final norm + output projection (for this single time step)
|
| 388 |
+
out_norm = self.final_norm(layer_input) # (B,1,d_model)
|
| 389 |
+
logits = self.output_projection(out_norm) # (B,1,vocab)
|
| 390 |
+
logits = logits.squeeze(1) # (B, vocab)
|
| 391 |
+
|
| 392 |
+
return logits, new_cache
|
src/models/heads.py
CHANGED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prediction heads for Transformer models.
|
| 3 |
+
|
| 4 |
+
Includes:
|
| 5 |
+
- ClassificationHead: sequence-level classification with simple pooling (mean/cls/max).
|
| 6 |
+
- TokenClassificationHead: per-token classification (e.g., NER).
|
| 7 |
+
- LMHead: language-modeling head mapping hidden states to vocabulary logits. Optional weight tying to an Embedding.
|
| 8 |
+
- ProjectionHead: small projection MLP for representation learning / contrastive heads.
|
| 9 |
+
|
| 10 |
+
Keep these heads minimal, well-tested, and easy to compose on top of encoder/decoder outputs.
|
| 11 |
+
"""
|
| 12 |
+
from typing import Optional, Literal
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ClassificationHead(nn.Module):
    """
    Sequence-level classification head.

    Pools the sequence dimension (mean / first-token / max), applies dropout,
    and projects the pooled vector to class logits.

    Args:
        d_model: hidden size produced by the encoder/decoder.
        num_labels: number of output classes.
        pooler: pooling strategy - 'mean', 'cls' (first token) or 'max'.
        dropout: dropout probability applied to the pooled vector.
    """

    def __init__(
        self,
        d_model: int,
        num_labels: int,
        pooler: Literal["mean", "cls", "max"] = "mean",
        dropout: float = 0.1,
    ):
        super().__init__()
        assert pooler in ("mean", "cls", "max"), "pooler must be 'mean'|'cls'|'max'"
        self.pooler = pooler
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_model, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)

        Returns:
            (batch, num_labels) logits.
        """
        if self.pooler == "cls":
            pooled = x[:, 0, :]
        elif self.pooler == "max":
            pooled = x.max(dim=1).values
        else:  # 'mean'
            pooled = x.mean(dim=1)
        return self.out_proj(self.dropout(pooled))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TokenClassificationHead(nn.Module):
    """
    Per-token classification head (NER, POS tagging, ...).

    Args:
        d_model: hidden size.
        num_labels: number of per-token classes.
        dropout: dropout probability applied before the projection.
    """

    def __init__(self, d_model: int, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_model, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)

        Returns:
            (batch, seq_len, num_labels) per-token logits.
        """
        return self.out_proj(self.dropout(x))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class LMHead(nn.Module):
    """
    Language modeling head: maps hidden states to logits over the vocabulary.

    Args:
        d_model: hidden size.
        vocab_size: vocabulary size.
        tie_embedding: optional nn.Embedding instance to tie weights with;
            when given, the projection shares the embedding's weight
            Parameter so updates to one affect both modules.

    Raises:
        ValueError: if `tie_embedding` is given and its shape does not match
            (vocab_size, d_model).
    """

    def __init__(self, d_model: int, vocab_size: int, tie_embedding: Optional[nn.Embedding] = None):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.proj = nn.Linear(d_model, vocab_size, bias=True)

        if tie_embedding is not None:
            # Validate with explicit raises (not assert: assert is stripped
            # when Python runs with -O, silently skipping the check).
            if tie_embedding.num_embeddings != vocab_size:
                raise ValueError("vocab size mismatch for weight tying")
            if tie_embedding.embedding_dim != d_model:
                raise ValueError("embedding dim must match d_model for weight tying")
            # Tie weights: point the projection weight at the embedding's
            # weight Parameter. Same Parameter object, so gradient updates
            # affect both modules.
            self.proj.weight = tie_embedding.weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, d_model)

        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        return self.proj(hidden_states)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class ProjectionHead(nn.Module):
    """Simple MLP projection head for representation learning.

    Args:
        d_model: input dimension.
        proj_dim: output projection dimension.
        hidden_dim: intermediate dimension; defaults to max(d_model, proj_dim).
        dropout: dropout probability.
    """

    def __init__(self, d_model: int, proj_dim: int = 128, hidden_dim: Optional[int] = None, dropout: float = 0.1):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = max(d_model, proj_dim)
        self.net = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, proj_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project input onto the `proj_dim` space.

        Args:
            x: (batch, d_model) or (batch, seq_len, d_model) - both supported.

        Returns:
            If input is 3D: (batch, seq_len, proj_dim)
            If input is 2D: (batch, proj_dim)

        Raises:
            ValueError: if the input is not 2D or 3D.
        """
        if x.dim() not in (2, 3):
            raise ValueError("Input must be 2D or 3D tensor")
        # nn.Linear applies to the trailing dimension for any number of
        # leading dims, so no reshape is needed. This also fixes a crash the
        # previous `.view(B * T, D)` caused on non-contiguous 3D inputs
        # (e.g. the result of a transpose).
        return self.net(x)
|
src/models/multitask.py
CHANGED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multitask model composition utilities.
|
| 3 |
+
|
| 4 |
+
Provides:
|
| 5 |
+
- MultiTaskModel: lightweight wrapper to compose an encoder and/or decoder with
|
| 6 |
+
multiple task heads (classification, token classification, LM head, etc.)
|
| 7 |
+
- add_head / remove_head helpers
|
| 8 |
+
- forward(task_name, ...) that routes inputs to the correct sub-modules
|
| 9 |
+
- compute_loss helper that uses common losses and ignore_index support
|
| 10 |
+
|
| 11 |
+
Design goals:
|
| 12 |
+
- Keep composition simple and explicit (use named heads per task)
|
| 13 |
+
- Support encoder-only tasks (classification, token classification) and
|
| 14 |
+
seq2seq tasks (encoder -> decoder -> LMHead)
|
| 15 |
+
- Minimal dependencies on training loop; return logits and (optionally) loss
|
| 16 |
+
"""
|
| 17 |
+
from typing import Optional, Dict, Any, Tuple
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
# Import your components
|
| 24 |
+
from .encoder import TransformerEncoder
|
| 25 |
+
from .decoder import TransformerDecoder
|
| 26 |
+
from .heads import ClassificationHead, TokenClassificationHead, LMHead
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class MultiTaskModel(nn.Module):
    """
    Compose encoder/decoder and task heads.

    Usage patterns:
    - Encoder-only classification:
        mt = MultiTaskModel(encoder=enc)
        mt.add_head("sentiment", ClassificationHead(...))
        logits = mt.forward("sentiment", {"input_ids": src_ids})
    - Seq2seq LM:
        mt = MultiTaskModel(encoder=enc, decoder=dec)
        mt.add_head("summarize", LMHead(...))
        logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids})
    """

    def __init__(
        self,
        encoder: Optional[TransformerEncoder] = None,
        decoder: Optional[TransformerDecoder] = None,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Plain name -> head lookup. Parameter/state_dict registration is done
        # separately via add_module() in add_head(), not through this dict.
        self.heads: Dict[str, nn.Module] = {}

    def add_head(self, name: str, module: nn.Module) -> None:
        """Register a head under a task name.

        Raises:
            ValueError: if a head with this name is already registered.
        """
        if name in self.heads:
            raise ValueError(f"Head '{name}' already exists")
        self.heads[name] = module
        # Register with nn.Module so the head's parameters are tracked by
        # optimizers, .to(device), and state_dict().
        self.add_module(f"head_{name}", module)

    def remove_head(self, name: str) -> None:
        """Remove a registered head.

        Raises:
            KeyError: if no head with this name exists.
        """
        if name not in self.heads:
            raise KeyError(name)
        # Unregister from both the nn.Module registry and the lookup dict.
        del self._modules[f"head_{name}"]
        del self.heads[name]

    def forward(
        self,
        task: str,
        inputs: Dict[str, torch.Tensor],
        return_loss: bool = False,
        loss_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """
        Route inputs to appropriate model components and head.

        Args:
            task: registered head name
            inputs: dictionary; common keys:
                - For encoder tasks: "input_ids" or "embeddings" (B, S) or (B, S, d)
                - For seq2seq: "src_ids" (B,S) or "src_embeddings", and "tgt_ids" (B,T) or "tgt_embeddings"
                  when computing training loss, pass "labels" (B,T) for LM
            return_loss: if True and labels provided, returns (loss, logits)
            loss_kwargs: forwarded to compute_loss (e.g., ignore_index)

        Returns:
            logits (or (loss, logits) if return_loss True)

        Raises:
            KeyError: unknown task name.
            RuntimeError: required sub-module (encoder/decoder) missing, or
                unsupported head type.
            ValueError: required input keys or labels missing.
        """
        if task not in self.heads:
            raise KeyError(f"Unknown task/head '{task}'")

        head = self.heads[task]
        loss_kwargs = loss_kwargs or {}

        # Encoder-only heads expect encoder outputs
        if isinstance(head, (ClassificationHead, TokenClassificationHead)):
            if self.encoder is None:
                raise RuntimeError("Encoder is required for encoder-side heads")
            # accept either input_ids or embeddings
            if "input_ids" in inputs:
                enc_out = self.encoder(inputs["input_ids"])
            elif "embeddings" in inputs:
                enc_out = self.encoder(inputs["embeddings"])
            else:
                raise ValueError("inputs must contain 'input_ids' or 'embeddings' for encoder tasks")
            logits = head(enc_out)

            if return_loss:
                labels = inputs.get("labels", None)
                if labels is None:
                    raise ValueError("return_loss=True requires 'labels' in inputs")
                loss = self.compute_loss_for_head(head, logits, labels, **loss_kwargs)
                return loss, logits
            return logits

        # LM/seq2seq head: run encoder -> decoder -> lm head
        if isinstance(head, LMHead):
            if self.encoder is None or self.decoder is None:
                raise RuntimeError("Both encoder and decoder are required for LM-style heads")

            # Build encoder memory
            if "src_ids" in inputs:
                memory = self.encoder(inputs["src_ids"])
            elif "src_embeddings" in inputs:
                memory = self.encoder(inputs["src_embeddings"])
            else:
                raise ValueError("inputs must contain 'src_ids' or 'src_embeddings' for seq2seq tasks")

            # If training / teacher forcing: expect tgt_ids (shifted by caller) or embeddings
            if "tgt_ids" in inputs:
                decoder_inputs = inputs["tgt_ids"]
            elif "tgt_embeddings" in inputs:
                decoder_inputs = inputs["tgt_embeddings"]
            else:
                # For generation time you may call decoder.greedy_decode separately.
                # Here we don't attempt to generate when labels not provided.
                raise ValueError("Seq2seq tasks require 'tgt_ids' or 'tgt_embeddings' for training forward")

            # Run decoder. Decoder returns logits shaped (B, T, vocab) in this codebase.
            decoder_out = self.decoder(decoder_inputs, memory)

            # If decoder already returned logits matching the head vocab size, use them directly.
            # Otherwise, assume decoder returned hidden states and let the head project them.
            # NOTE(review): this shape heuristic is ambiguous when
            # vocab_size == d_model — confirm the decoder's return contract.
            if isinstance(decoder_out, torch.Tensor) and decoder_out.shape[-1] == head.vocab_size:
                logits = decoder_out
            else:
                logits = head(decoder_out)

            if return_loss:
                labels = inputs.get("labels", None)
                if labels is None:
                    raise ValueError("return_loss=True requires 'labels' in inputs for seq2seq")
                loss = self.compute_loss_for_head(head, logits, labels, **loss_kwargs)
                return loss, logits
            return logits

        # Otherwise unsupported head type
        raise RuntimeError(f"Unsupported head type: {type(head)}")

    def compute_loss_for_head(
        self,
        head: nn.Module,
        logits: torch.Tensor,
        labels: torch.Tensor,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """
        Default loss dispatch:
        - ClassificationHead: CrossEntropy on (B, num_labels)
        - TokenClassificationHead: CrossEntropy per token (flattened)
        - LMHead: CrossEntropy per token (flattened), ignore_index supported

        Args:
            head: the head module the logits came from (selects the loss).
            logits: model outputs; shape depends on head type (see above).
            labels: integer targets; cast to long before the loss.
            ignore_index: target value ignored for token-level losses.

        Returns scalar loss.
        """
        if isinstance(head, ClassificationHead):
            # logits: (B, num_labels) or (B, num_labels) direct
            loss = F.cross_entropy(logits, labels.long())
            return loss

        if isinstance(head, TokenClassificationHead):
            # logits: (B, T, C), labels: (B, T)
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B * T, C), labels.view(B * T).long(), ignore_index=ignore_index)
            return loss

        if isinstance(head, LMHead):
            # logits: (B, T, V), labels: (B, T)
            B, T, V = logits.shape
            loss = F.cross_entropy(logits.view(B * T, V), labels.view(B * T).long(), ignore_index=ignore_index)
            return loss

        # Generic fall-back: try CrossEntropy on final dim
        if logits.dim() == 2:
            return F.cross_entropy(logits, labels.long())

        # If we can't determine, raise
        raise RuntimeError("Cannot compute loss for unknown head type")
|
tests/test_models/test_decoder.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pytest
|
| 3 |
+
from src.models.decoder import (
|
| 4 |
+
create_causal_mask,
|
| 5 |
+
TransformerDecoderLayer,
|
| 6 |
+
TransformerDecoder,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_create_causal_mask_properties():
    """A causal mask allows each query to see itself and the past only."""
    mask = create_causal_mask(5)
    assert mask.shape == (5, 5)
    # Diagonal and below should be allowed (truthy); strictly above blocked.
    # Truthiness is used instead of `is True` / `is False` so the test does
    # not depend on the mask's dtype: for an int mask, `.item()` returns 1
    # and `1 is True` is False, which would wrongly fail the old assertion.
    for i in range(5):
        for j in range(5):
            if j <= i:
                assert mask[i, j].item()
            else:
                assert not mask[i, j].item()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_decoder_layer_shapes_and_grad():
    """Unmasked decoder layer: output/attention shapes and backprop sanity."""
    torch.manual_seed(0)
    d_model, num_heads, d_ff = 32, 4, 64
    batch_size, tgt_len, src_len = 2, 6, 7

    layer = TransformerDecoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0)
    tgt = torch.randn(batch_size, tgt_len, d_model, requires_grad=True)
    memory = torch.randn(batch_size, src_len, d_model)

    # No masks
    out, attn = layer(tgt, memory, tgt_mask=None, memory_mask=None)
    assert out.shape == (batch_size, tgt_len, d_model)
    assert isinstance(attn, dict)
    assert "self" in attn and "cross" in attn
    assert attn["self"].shape == (batch_size, num_heads, tgt_len, tgt_len)
    assert attn["cross"].shape == (batch_size, num_heads, tgt_len, src_len)

    # Backprop works
    loss = out.sum()
    loss.backward()
    grads = [p.grad for p in layer.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_decoder_layer_causal_mask_blocks_future():
    """With a causal tgt_mask, self-attention weight on any future key is 0."""
    torch.manual_seed(1)
    d_model, num_heads, d_ff = 48, 6, 128
    batch_size, tgt_len, src_len = 1, 5, 5

    layer = TransformerDecoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0)
    # create trivial increasing tgt embeddings so attention patterns are deterministic-ish
    tgt = torch.randn(batch_size, tgt_len, d_model)
    memory = torch.randn(batch_size, src_len, d_model)

    causal = create_causal_mask(tgt_len, device=tgt.device)  # (T, T)
    tgt_mask = causal.unsqueeze(0)  # (1, T, T) -> layer will handle unsqueeze to heads

    out, attn = layer(tgt, memory, tgt_mask=tgt_mask, memory_mask=None)
    self_attn = attn["self"].detach()
    # Ensure upper triangle of attention weights is zero (no future attention)
    # For each head and query i, keys j>i should be zero
    B, H, Tq, Tk = self_attn.shape
    for i in range(Tq):
        for j in range(i + 1, Tk):
            assert torch.allclose(self_attn[:, :, i, j], torch.zeros(B, H)), \
                f"Found nonzero attention to future position {j} from query {i}"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def test_decoder_stack_and_greedy_decode_shapes():
    """Greedy decoding output and embedding-input forward have expected shapes."""
    torch.manual_seed(2)
    vocab_size = 30
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 128
    batch_size = 2
    src_len = 7
    max_tgt = 6

    decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.0,
        max_len=max_tgt,
        pad_token_id=0,
    )

    # Random memory from encoder
    memory = torch.randn(batch_size, src_len, d_model)

    # Greedy decode: should produce (B, <= max_tgt)
    generated = decoder.greedy_decode(memory, max_len=max_tgt, start_token_id=1, end_token_id=None)
    assert generated.shape[0] == batch_size
    assert generated.shape[1] <= max_tgt
    assert (generated[:, 0] == 1).all()  # starts with start token

    # Also test forward with embeddings and collect_attn
    embeddings = torch.randn(batch_size, max_tgt, d_model)
    logits, attn_list = decoder(embeddings, memory, collect_attn=True)
    assert logits.shape == (batch_size, max_tgt, vocab_size)
    assert isinstance(attn_list, list)
    assert len(attn_list) == num_layers
    for attn in attn_list:
        assert "self" in attn and "cross" in attn
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def test_decoder_train_eval_dropout_behavior():
    """Dropout makes train-mode forwards stochastic; eval mode is deterministic."""
    torch.manual_seed(3)
    vocab_size = 40
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 128
    batch_size = 2
    src_len = 6
    tgt_len = 5

    decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.4,
        max_len=tgt_len,
        pad_token_id=0,
    )

    # token ids with padding possible
    input_ids = torch.randint(1, vocab_size, (batch_size, tgt_len), dtype=torch.long)
    input_ids[0, -1] = 0

    memory = torch.randn(batch_size, src_len, d_model)

    decoder.train()
    out1 = decoder(input_ids, memory)
    out2 = decoder(input_ids, memory)
    # With dropout in train mode, outputs should usually differ
    assert not torch.allclose(out1, out2)

    decoder.eval()
    out3 = decoder(input_ids, memory)
    out4 = decoder(input_ids, memory)
    assert torch.allclose(out3, out4)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
if __name__ == "__main__":
    # Allow running this test module directly, without invoking pytest from
    # the command line.
    pytest.main([__file__, "-q"])
|
tests/test_models/test_decoder_step.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pytest
|
| 3 |
+
from typing import Any, Dict, cast
|
| 4 |
+
from src.models.decoder import TransformerDecoder
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def test_step_equivalence_with_greedy_decode():
    """Incremental step() with a KV cache must reproduce greedy_decode exactly."""
    torch.manual_seed(7)
    vocab_size = 25
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 64
    batch_size = 2
    src_len = 6
    max_tgt = 6

    decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.0,
        max_len=max_tgt,
        pad_token_id=0,
    )

    memory = torch.randn(batch_size, src_len, d_model)

    # 1) Get greedy sequence from naive greedy_decode
    greedy = decoder.greedy_decode(memory, max_len=max_tgt, start_token_id=1, end_token_id=None)

    # 2) Reproduce the same sequence with step() using cache
    cache: Dict[str, Any] = {"past_length": 0}
    generated = torch.full((batch_size, 1), 1, dtype=torch.long)
    for _ in range(max_tgt - 1):
        last_token = generated[:, -1:].to(memory.device)
        logits, cache = decoder.step(cast(torch.LongTensor, last_token), memory, cache=cache)
        next_token = logits.argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)

    # Compare shapes & that sequences are identical
    assert generated.shape == greedy.shape
    assert torch.equal(generated, greedy)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def test_step_cache_growth_and_shapes():
    """The step() cache grows by one position per call and caches memory K/V."""
    torch.manual_seed(9)
    vocab_size = 20
    d_model = 24
    num_layers = 3
    num_heads = 4
    d_ff = 64
    batch_size = 1
    src_len = 5
    steps = 4
    max_tgt = 8

    decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.0,
        max_len=max_tgt,
        pad_token_id=0,
    )

    memory = torch.randn(batch_size, src_len, d_model)

    cache: Dict[str, Any] = {"past_length": 0}
    last = torch.full((batch_size, 1), 1, dtype=torch.long)
    for step_idx in range(steps):
        logits, cache = decoder.step(cast(torch.LongTensor, last), memory, cache=cache)
        # check updated past_length
        assert cache["past_length"] == step_idx + 1
        # check cached per-layer keys exist and have expected shape (B, H, seq_len, d_k)
        for i in range(num_layers):
            k = cache.get(f"self_k_{i}")
            v = cache.get(f"self_v_{i}")
            assert k is not None and v is not None
            # seq_len should equal past_length
            assert k.shape[2] == cache["past_length"]
            # shapes match
            assert k.shape[0] == batch_size
            assert v.shape[0] == batch_size
        # advance last token for next loop
        last = logits.argmax(dim=-1, keepdim=True)

    # Also ensure memory projections cached
    for i in range(num_layers):
        assert f"mem_k_{i}" in cache and f"mem_v_{i}" in cache
        mem_k = cache[f"mem_k_{i}"]
        mem_v = cache[f"mem_v_{i}"]
        assert mem_k.shape[0] == batch_size
        assert mem_k.shape[2] == src_len  # seq length of memory
|
tests/test_models/test_heads.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pytest
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from src.models.heads import (
|
| 5 |
+
ClassificationHead,
|
| 6 |
+
TokenClassificationHead,
|
| 7 |
+
LMHead,
|
| 8 |
+
ProjectionHead,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_classification_head_shapes_and_dropout():
    """ClassificationHead: output shape, train-mode stochasticity, eval determinism."""
    torch.manual_seed(0)
    d_model = 64
    num_labels = 5
    batch_size = 3
    seq_len = 10

    head = ClassificationHead(d_model=d_model, num_labels=num_labels, pooler="mean", dropout=0.5)
    head.train()
    x = torch.randn(batch_size, seq_len, d_model)

    out1 = head(x)
    out2 = head(x)
    # With dropout in train mode, outputs should usually differ
    assert out1.shape == (batch_size, num_labels)
    assert out2.shape == (batch_size, num_labels)
    assert not torch.allclose(out1, out2)

    head.eval()
    out3 = head(x)
    out4 = head(x)
    assert torch.allclose(out3, out4), "Eval mode should be deterministic"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_token_classification_head_shapes_and_grads():
    """TokenClassificationHead: per-token logits shape and gradient flow."""
    torch.manual_seed(1)
    d_model = 48
    num_labels = 7
    batch_size = 2
    seq_len = 6

    head = TokenClassificationHead(d_model=d_model, num_labels=num_labels, dropout=0.0)
    x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
    out = head(x)
    assert out.shape == (batch_size, seq_len, num_labels)

    loss = out.sum()
    loss.backward()
    grads = [p.grad for name, p in head.named_parameters() if p.requires_grad]
    assert any(g is not None for g in grads)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def test_lm_head_tie_weights_and_shapes():
    """LMHead: logits shape, shared Parameter on tying, grads reach the embedding."""
    torch.manual_seed(2)
    vocab_size = 50
    d_model = 32
    batch_size = 2
    seq_len = 4

    embedding = nn.Embedding(vocab_size, d_model)
    lm_tied = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=embedding)
    lm_untied = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=None)

    hidden = torch.randn(batch_size, seq_len, d_model)

    # Shapes
    logits_tied = lm_tied(hidden)
    logits_untied = lm_untied(hidden)
    assert logits_tied.shape == (batch_size, seq_len, vocab_size)
    assert logits_untied.shape == (batch_size, seq_len, vocab_size)

    # Weight tying: projection weight should be the same object as embedding.weight
    assert lm_tied.proj.weight is embedding.weight

    # Grad flows through tied weights
    loss = logits_tied.sum()
    loss.backward()
    assert embedding.weight.grad is not None
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def test_projection_head_2d_and_3d_behavior_and_grad():
    """ProjectionHead: accepts 2D and 3D inputs; gradients reach its parameters."""
    torch.manual_seed(3)
    d_model = 40
    proj_dim = 16
    batch_size = 2
    seq_len = 5

    head = ProjectionHead(d_model=d_model, proj_dim=proj_dim, hidden_dim=64, dropout=0.0)
    # 2D input
    vec = torch.randn(batch_size, d_model, requires_grad=True)
    out2 = head(vec)
    assert out2.shape == (batch_size, proj_dim)

    # 3D input
    seq = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
    out3 = head(seq)
    assert out3.shape == (batch_size, seq_len, proj_dim)

    # Grad flow
    loss = out3.sum()
    loss.backward()
    grads = [p.grad for p in head.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)
|
tests/test_models/test_multitask.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pytest
|
| 3 |
+
from src.models.encoder import TransformerEncoder
|
| 4 |
+
from src.models.decoder import TransformerDecoder
|
| 5 |
+
from src.models.heads import ClassificationHead, LMHead, TokenClassificationHead
|
| 6 |
+
from src.models.multitask import MultiTaskModel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_multitask_encoder_classification_forward_and_loss():
    """Encoder + ClassificationHead routing: logits shape, loss value, gradients."""
    torch.manual_seed(0)
    vocab_size = 30
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 64
    batch_size = 3
    seq_len = 8
    num_labels = 5

    enc = TransformerEncoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
                             num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=seq_len, pad_token_id=0)

    mt = MultiTaskModel(encoder=enc)
    head = ClassificationHead(d_model=d_model, num_labels=num_labels, pooler="mean", dropout=0.0)
    mt.add_head("sentiment", head)

    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len), dtype=torch.long)
    labels = torch.randint(0, num_labels, (batch_size,), dtype=torch.long)

    logits = mt.forward("sentiment", {"input_ids": input_ids})
    assert logits.shape == (batch_size, num_labels)

    loss, logits2 = mt.forward("sentiment", {"input_ids": input_ids, "labels": labels}, return_loss=True)
    assert loss.item() >= 0
    # grads
    loss.backward()
    grads = [p.grad for p in mt.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_multitask_seq2seq_lm_forward_and_loss():
    """Encoder -> decoder -> LMHead routing: logits shape, loss value, gradients."""
    torch.manual_seed(1)
    vocab_size = 40
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 64
    batch_size = 2
    src_len = 7
    tgt_len = 6

    enc = TransformerEncoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
                             num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=src_len, pad_token_id=0)
    dec = TransformerDecoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
                             num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=tgt_len, pad_token_id=0)
    mt = MultiTaskModel(encoder=enc, decoder=dec)
    lm_head = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=None)
    mt.add_head("summarize", lm_head)

    src_ids = torch.randint(1, vocab_size, (batch_size, src_len), dtype=torch.long)
    # for training: provide decoder inputs (typically shifted right) and labels
    tgt_ids = torch.randint(1, vocab_size, (batch_size, tgt_len), dtype=torch.long)
    labels = tgt_ids.clone()

    logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids})
    assert logits.shape == (batch_size, tgt_len, vocab_size)

    loss, logits2 = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids, "labels": labels}, return_loss=True)
    assert loss.item() >= 0
    loss.backward()
    grads = [p.grad for p in mt.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def test_token_classification_forward_and_loss():
    """Encoder + TokenClassificationHead routing: per-token logits, loss, gradients."""
    torch.manual_seed(2)
    vocab_size = 20
    d_model = 24
    num_layers = 2
    num_heads = 4
    d_ff = 64
    batch_size = 2
    seq_len = 5
    num_labels = 7

    enc = TransformerEncoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
                             num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=seq_len, pad_token_id=0)
    mt = MultiTaskModel(encoder=enc)
    head = TokenClassificationHead(d_model=d_model, num_labels=num_labels, dropout=0.0)
    mt.add_head("ner", head)

    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len), dtype=torch.long)
    labels = torch.randint(0, num_labels, (batch_size, seq_len), dtype=torch.long)

    logits = mt.forward("ner", {"input_ids": input_ids})
    assert logits.shape == (batch_size, seq_len, num_labels)

    loss, logits2 = mt.forward("ner", {"input_ids": input_ids, "labels": labels}, return_loss=True)
    assert loss.item() >= 0
    loss.backward()
    grads = [p.grad for p in mt.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)
|