import torch
import torch.nn as nn
import torch.nn.functional as F

from constants import *


class PatchEmbeddings(nn.Module):
    """Splits an image into non-overlapping patches and projects each patch to hidden_dim."""
    def __init__(self, patch_size=PATCH_SIZE, hidden_dim=HIDDEN_DIM):
        super().__init__()
        # A conv with kernel_size == stride == patch_size extracts and embeds each patch in one step
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        X = self.conv(X)       # (B, C, H/P, W/P)
        X = X.flatten(2)       # (B, C, N) where N = (H/P)*(W/P)
        X = X.transpose(1, 2)  # (B, N, C)
        return X


class Head(nn.Module):
    """A single attention head, either causal (decoder) or bidirectional (encoder)."""
    def __init__(self, n_embd, head_size, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder
        # The causal mask is registered with persistent=False so it is not saved in the state_dict
        if self.is_decoder:
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH, dtype=torch.bool))
                .view(1, CONTEXT_LENGTH, CONTEXT_LENGTH),
                persistent=False,
            )

    def forward(self, x, attention_mask=None):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        v = self.value(x)  # (B, T, hs)
        # Compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * (k.size(-1) ** -0.5)  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        if self.is_decoder:
            # Apply the causal mask, sliced to the current length in case T < CONTEXT_LENGTH
            causal_mask = self.bias[:, :T, :T]
            wei = wei.masked_fill(causal_mask == 0, float('-inf'))
        if attention_mask is not None:
            # Apply the padding mask (for text tokens). attention_mask has shape (B, T_key);
            # unsqueeze it to (B, 1, T_key) so it broadcasts across the query dimension of
            # wei (B, T_query, T_key), masking out every key position where the mask is 0.
            mask = attention_mask.unsqueeze(1)  # (B, 1, T)
            wei = wei.masked_fill(mask == 0, float('-inf'))
        # Normalize the scores and apply dropout to the attention weights
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        # Weighted aggregation of the values
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out


class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, concatenated and projected back to n_embd."""
    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        assert n_embd % num_heads == 0
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList([
            Head(n_embd, head_size, dropout, is_decoder)
            for _ in range(num_heads)
        ])
        self.proj = nn.Linear(n_embd, n_embd)  # n_embd == num_heads * head_size
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder  # Store is_decoder status

    def forward(self, x, attention_mask=None):
        # Forward the attention_mask only for decoder blocks handling the combined sequence
        out = torch.cat(
            [h(x, attention_mask=attention_mask if self.is_decoder else None) for h in self.heads],
            dim=-1,
        )
        out = self.dropout(self.proj(out))
        return out


class FeedForward(nn.Module):
    """A simple linear layer followed by a non-linearity."""
    def __init__(self, n_embd, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),  # GELU rather than ReLU, as is common in transformers
            nn.Linear(4 * n_embd, n_embd),  # Projection back to the residual stream
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """Transformer block: communication (attention) followed by computation (feed-forward)."""
    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = FeedForward(n_embd, dropout)
        self.is_decoder = is_decoder  # Store is_decoder status

    def forward(self, x, attention_mask=None):
        # Pre-norm residual connections; the attention_mask is only forwarded for decoder blocks
        x = x + self.attn(self.ln1(x), attention_mask=attention_mask if self.is_decoder else None)
        x = x + self.ffn(self.ln2(x))
        return x


class ViT(nn.Module):
    """Vision Transformer encoder: patch embeddings + CLS token + learned positional embeddings."""
    def __init__(self, img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, num_hiddens=HIDDEN_DIM,
                 num_heads=NUM_HEADS, num_blks=NUM_LAYERS, emb_dropout=DROPOUT, blk_dropout=DROPOUT):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_patches = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens) * 0.02)  # Small init
        self.dropout = nn.Dropout(emb_dropout)
        # ViT blocks are NOT decoders (no causal mask)
        self.blocks = nn.ModuleList([
            Block(num_hiddens, num_heads, blk_dropout, is_decoder=False)
            for _ in range(num_blks)
        ])
        self.layer_norm = nn.LayerNorm(num_hiddens)  # Final layer norm

    def forward(self, X):
        x = self.patch_embedding(X)  # (B, N, C)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, C)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, N+1, C)
        # Add positional embedding (broadcast over the batch dimension)
        x = x + self.pos_embedding
        x = self.dropout(x)
        for block in self.blocks:
            # ViT blocks don't need an attention_mask
            x = block(x)
        x = self.layer_norm(x)  # Apply the final layer norm
        return x


class MultiModalProjector(nn.Module):
    """Projects image embeddings into the text embedding space via a small MLP."""
    def __init__(self, image_embed_dim=HIDDEN_DIM, text_embed_dim=HIDDEN_DIM, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_embed_dim, text_embed_dim * 4),  # Intermediate expansion
            nn.GELU(),
            nn.Linear(text_embed_dim * 4, text_embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
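

# ---------------------------------------------------------------------------
# Quick shape-check sketch (not part of the original module). It assumes that
# constants.py defines IMAGE_SIZE, PATCH_SIZE, HIDDEN_DIM, NUM_HEADS and
# CONTEXT_LENGTH (>= 8), and that IMAGE_SIZE is divisible by PATCH_SIZE.
if __name__ == "__main__":
    vit = ViT()
    projector = MultiModalProjector()
    images = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE)  # (B, 3, H, W)
    img_embeds = vit(images)                            # (B, (H/P)*(W/P) + 1, HIDDEN_DIM)
    projected = projector(img_embeds)                   # (B, (H/P)*(W/P) + 1, HIDDEN_DIM)
    print(img_embeds.shape, projected.shape)

    # Exercise the decoder-style attention path with a padding mask:
    # the (B, T) mask broadcasts to (B, 1, T) over the attention scores.
    head = Head(HIDDEN_DIM, HIDDEN_DIM // NUM_HEADS, is_decoder=True)
    tokens = torch.randn(2, 8, HIDDEN_DIM)              # (B, T, n_embd) with T <= CONTEXT_LENGTH
    pad_mask = torch.ones(2, 8, dtype=torch.long)
    pad_mask[:, -2:] = 0                                 # treat the last two positions as padding
    print(head(tokens, attention_mask=pad_mask).shape)  # (B, T, head_size)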