bmeyer2025 committed on
Commit
1e86f73
·
verified ·
1 Parent(s): 2d3ab90

Upload src/attention.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/attention.py +83 -0
src/attention.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Milestone 2: Single-head causal self-attention.
3
+
4
+ Implements scaled dot-product attention with:
5
+ - Separate Q, K, V linear projections
6
+ - Causal mask (lower-triangular) so each position can only attend to past tokens
7
+ - Dropout on the attention weights
8
+
9
+ Key formula:
10
+ Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+
18
class Head(nn.Module):
    """One head of causal (autoregressive) scaled dot-product self-attention.

    Computes softmax(Q @ K^T / sqrt(d_k)) @ V over the input sequence, with a
    lower-triangular mask so position t attends only to positions <= t, and
    dropout applied to the attention weights.
    """

    def __init__(self, head_size: int, n_embd: int, block_size: int, dropout: float = 0.1):
        super().__init__()
        # Separate bias-free projections from the embedding space into this
        # head's subspace.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

        # Causal mask: ones on and below the diagonal, zeros above.
        # A buffer (not a Parameter) moves with the module between devices
        # but receives no gradient updates.
        mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("tril", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention head to x of shape (batch, seq_len, n_embd).

        Returns a tensor of shape (batch, seq_len, head_size).
        """
        seq_len = x.shape[1]

        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Attention logits, scaled by 1/sqrt(d_k) to keep their variance
        # independent of the head size: (B, T, head_size) @ (B, head_size, T)
        # -> (B, T, T).
        scale = k.shape[-1] ** -0.5
        logits = (q @ k.transpose(-2, -1)) * scale

        # Forbid attending to the future: entries above the diagonal become
        # -inf, which softmax maps to zero probability.
        causal = self.tril[:seq_len, :seq_len]
        logits = logits.masked_fill(causal == 0, float("-inf"))

        attn = F.softmax(logits, dim=-1)  # (B, T, T)
        attn = self.dropout(attn)

        # Probability-weighted sum of value vectors.
        return attn @ v  # (B, T, head_size)
59
+
60
+
61
# ── Quick sanity check ────────────────────────────────────────────────────────
if __name__ == "__main__":
    from tokenizer import DEVICE, BLOCK_SIZE, get_batch

    embed_dim = 32
    head_dim = 16
    n_batch = 4

    attn_head = Head(head_size=head_dim, n_embd=embed_dim, block_size=BLOCK_SIZE).to(DEVICE)

    # Random embeddings stand in for real token/position embeddings, which
    # don't exist yet at this milestone.
    fake_embeddings = torch.randn(n_batch, BLOCK_SIZE, embed_dim, device=DEVICE)
    result = attn_head(fake_embeddings)

    print(f"Input shape: {fake_embeddings.shape}")
    print(f"Output shape: {result.shape} (expected [4, {BLOCK_SIZE}, {head_dim}])")

    # Causality sanity: the registered mask should be lower-triangular, so
    # each position can only attend to itself and earlier positions.
    corner = attn_head.tril[:8, :8]
    print(f"\nCausal mask (8x8 top-left corner):")
    print(corner.int())
    print("\nMilestone 2 OK: single-head causal self-attention works.")