|
|
from torch import Tensor |
|
|
import torch.nn as nn |
|
|
from jaxtyping import Bool, Float |
|
|
import math |
|
|
|
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention sublayer.
    (Ref: "Attention Is All You Need", Section 3.2.2)

    Projects queries/keys/values, splits the model dimension into several
    heads, runs scaled dot-product attention per head in parallel, then
    merges the heads and applies a final output projection.

    Terminology (jaxtyping):
        B: batch_size
        T_q: target sequence length (query)
        T_k: source sequence length (key/value)
        D: d_model (model dimension)
        H: n_heads (number of heads)
        d_k: dimension of each head (d_model / n_heads)
    """

    def __init__(self, d_model: int, n_heads: int) -> None:
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model: int = d_model
        self.n_heads: int = n_heads
        # Width of each individual head.
        self.d_k: int = d_model // n_heads

        # One bias-free projection per role, plus the output projection.
        self.w_q: nn.Linear = nn.Linear(d_model, d_model, bias=False)
        self.w_k: nn.Linear = nn.Linear(d_model, d_model, bias=False)
        self.w_v: nn.Linear = nn.Linear(d_model, d_model, bias=False)
        self.w_o: nn.Linear = nn.Linear(d_model, d_model, bias=False)

        # Overwritten on each forward() call; useful for inspection/plots.
        self.attention_weights: Tensor | None = None

    @staticmethod
    def attention(
        query: Float[Tensor, "B H T_q d_k"],
        key: Float[Tensor, "B H T_k d_k"],
        value: Float[Tensor, "B H T_k d_k"],
        mask: Bool[Tensor, "... 1 T_q T_k"] | None,
    ) -> tuple[Float[Tensor, "B H T_q d_k"], Float[Tensor, "B H T_q T_k"]]:
        """
        Scaled Dot-Product Attention as a pure, stateless function.
        (Ref: "Attention Is All You Need", Equation 1)

        Args:
            query (Tensor): Query tensor.
            key (Tensor): Key tensor.
            value (Tensor): Value tensor.
            mask (Tensor | None): Optional mask (padding or look-ahead);
                positions where the mask is 0/False are suppressed.

        Returns:
            tuple[Tensor, Tensor]:
                - context_vector: The attended combination of `value`.
                - attention_weights: The softmax-normalized weights.
        """
        head_dim: int = query.shape[-1]

        # scores[b, h, i, j] = <q_i, k_j> / sqrt(d_k)
        scores: Tensor = query.matmul(key.transpose(-2, -1)) / math.sqrt(head_dim)

        if mask is not None:
            # -inf before softmax gives exactly zero weight to masked slots.
            scores = scores.masked_fill(mask == 0, float("-inf"))

        weights: Tensor = scores.softmax(dim=-1)
        context: Tensor = weights.matmul(value)
        return context, weights

    def forward(
        self,
        q: Float[Tensor, "B T_q D"],
        k: Float[Tensor, "B T_k D"],
        v: Float[Tensor, "B T_k D"],
        mask: Bool[Tensor, "... 1 T_q T_k"] | None = None,
    ) -> Float[Tensor, "B T_q D"]:
        """
        Forward pass for Multi-Head Attention.

        In Self-Attention (Encoder), q, k, and v are the same tensor.
        In Cross-Attention (Decoder), q comes from the Decoder while k and v
        come from the Encoder's output.

        Args:
            q: Query tensor.
            k: Key tensor.
            v: Value tensor.
            mask: Optional mask to apply (padding or look-ahead).

        Returns:
            The context vector after multi-head attention and output
            projection.
        """
        batch, len_q, _ = q.shape
        len_k = k.shape[1]

        def split_heads(x: Tensor, length: int) -> Tensor:
            # (B, T, D) -> (B, H, T, d_k): carve D into heads, heads first.
            return x.view(batch, length, self.n_heads, self.d_k).transpose(1, 2)

        heads_q = split_heads(self.w_q(q), len_q)
        heads_k = split_heads(self.w_k(k), len_k)
        heads_v = split_heads(self.w_v(v), len_k)

        context, self.attention_weights = self.attention(
            heads_q, heads_k, heads_v, mask
        )

        # Merge heads back: (B, H, T_q, d_k) -> (B, T_q, D).
        merged = context.transpose(1, 2).contiguous().view(batch, len_q, self.d_model)

        return self.w_o(merged)
|
|
|
|
|
|
|
|
class PositionwiseFeedForward(nn.Module):
    """
    Implements the Position-wise Feed-Forward Network (FFN) sublayer.
    (Ref: "Attention Is All You Need", Section 3.3)

    This is a two-layer MLP (Multi-Layer Perceptron) applied independently
    to each position in the sequence:

        FFN(x) = max(0, x * W_1 + b_1) * W_2 + b_2

    (i.e. Linear -> ReLU -> Linear)

    Terminology (jaxtyping):
        B: batch_size
        T: seq_len (context_length)
        D: d_model (model dimension)
        D_FF: d_ff (inner feed-forward dimension)
    """

    def __init__(self, d_model: int, d_ff: int) -> None:
        """
        Initializes the FFN.

        Args:
            d_model (int): Dimension of the model (e.g., 512).
            d_ff (int): Inner dimension of the FFN (e.g., 2048).
                Paper suggests d_ff = 4 * d_model.

        Note:
            Dropout is intentionally not applied inside this module; it is
            expected to be handled by the surrounding residual-connection
            wrapper. (A previous docstring mentioned a `dropout` parameter
            that never existed in the signature.)
        """
        super().__init__()

        # Expand D -> D_FF, apply the non-linearity, contract D_FF -> D.
        self.linear_1: nn.Linear = nn.Linear(d_model, d_ff)
        self.activation: nn.ReLU = nn.ReLU()
        self.linear_2: nn.Linear = nn.Linear(d_ff, d_model)

    def forward(self, x: Float[Tensor, "B T D"]) -> Float[Tensor, "B T D"]:
        """
        Forward pass for the FFN.
        Applies two linear transformations with a ReLU activation in between.

        Args:
            x: Input tensor from the previous sublayer
                (e.g., MultiHeadAttention output).

        Returns:
            Output tensor of the same shape.
        """
        return self.linear_2(self.activation(self.linear_1(x)))
|
|
|