Upload src/attention.py with huggingface_hub
Browse files- src/attention.py +83 -0
src/attention.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Milestone 2: Single-head causal self-attention.
|
| 3 |
+
|
| 4 |
+
Implements scaled dot-product attention with:
|
| 5 |
+
- Separate Q, K, V linear projections
|
| 6 |
+
- Causal mask (lower-triangular) so each position can only attend to past tokens
|
| 7 |
+
- Dropout on the attention weights
|
| 8 |
+
|
| 9 |
+
Key formula:
|
| 10 |
+
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Head(nn.Module):
    """Single head of causal (masked) self-attention.

    Computes scaled dot-product attention:

        Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V

    A lower-triangular mask ensures each position attends only to itself
    and earlier positions, so the head can be used autoregressively.

    Args:
        head_size: Dimensionality of the Q/K/V projections (d_k).
        n_embd: Dimensionality of the input embeddings.
        block_size: Maximum sequence length the causal mask supports.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, head_size: int, n_embd: int, block_size: int, dropout: float = 0.1):
        super().__init__()
        # Bias-free projections, as is conventional for attention heads.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

        # Causal mask: lower triangle of 1s, upper triangle of 0s.
        # Registered as a buffer so it moves with the model (to/from device)
        # but is NOT a learnable parameter.
        self.register_buffer(
            "tril",
            torch.tril(torch.ones(block_size, block_size)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention to ``x``.

        Args:
            x: Input of shape (batch, seq_len, n_embd), seq_len <= block_size.

        Returns:
            Tensor of shape (batch, seq_len, head_size).

        Raises:
            ValueError: If seq_len exceeds the block_size the mask was built
                for (otherwise the slice below would produce an opaque
                broadcasting error inside ``masked_fill``).
        """
        T = x.shape[1]  # sequence length; batch/channel dims are not needed here
        if T > self.tril.shape[0]:
            raise ValueError(
                f"Sequence length {T} exceeds block_size {self.tril.shape[0]}"
            )

        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        head_size = k.shape[-1]

        # Scaled dot-product attention scores:
        # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
        scores = q @ k.transpose(-2, -1) * (head_size ** -0.5)

        # Apply causal mask: positions that shouldn't be attended to get -inf,
        # which softmax turns into 0 probability.
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float("-inf"))

        weights = F.softmax(scores, dim=-1)  # (B, T, T)
        weights = self.dropout(weights)

        # Weighted sum of values
        out = weights @ v  # (B, T, head_size)
        return out
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ── Quick sanity check ────────────────────────────────────────────────────────
if __name__ == "__main__":
    from tokenizer import DEVICE, BLOCK_SIZE

    n_embd = 32
    head_size = 16
    batch_size = 4

    head = Head(head_size=head_size, n_embd=n_embd, block_size=BLOCK_SIZE).to(DEVICE)
    # Disable dropout so the causality check below is deterministic.
    head.eval()

    # Use random embeddings (we don't have the full model yet)
    x = torch.randn(batch_size, BLOCK_SIZE, n_embd, device=DEVICE)
    out = head(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {out.shape} (expected [4, {BLOCK_SIZE}, {head_size}])")

    # Verify causality functionally: output at position t must NOT depend on
    # positions > t, so perturbing the second half of the sequence must leave
    # the first half of the output unchanged.
    t = BLOCK_SIZE // 2
    x_perturbed = x.clone()
    x_perturbed[:, t:, :] += 1.0
    out_perturbed = head(x_perturbed)
    assert torch.allclose(out[:, :t], out_perturbed[:, :t], atol=1e-5), \
        "Causality violated: earlier outputs changed when future tokens changed"

    # Also show the mask itself for visual confirmation.
    tril = head.tril[:8, :8]
    print(f"\nCausal mask (8x8 top-left corner):")
    print(tril.int())
    print("\nMilestone 2 OK: single-head causal self-attention works.")