# pointing_models_from_scratch / model_components.py

from constants import *
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbeddings(nn.Module):
    """Splits an image into non-overlapping patches and embeds each patch via a strided convolution."""
    def __init__(self, patch_size=PATCH_SIZE, hidden_dim=HIDDEN_DIM):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        X = self.conv(X)       # (B, C, H/P, W/P)
        X = X.flatten(2)       # (B, C, N) where N = (H/P)*(W/P)
        X = X.transpose(1, 2)  # (B, N, C)
        return X
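

# Hedged usage sketch (not part of the original upload): a small shape check for
# PatchEmbeddings. `_demo_patch_embeddings` is a hypothetical helper; IMAGE_SIZE,
# PATCH_SIZE and HIDDEN_DIM are assumed to come from constants.py as used above.
def _demo_patch_embeddings():
    patchify = PatchEmbeddings()
    tokens = patchify(torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE))
    # One token per patch: (B, (IMAGE_SIZE // PATCH_SIZE) ** 2, HIDDEN_DIM)
    assert tokens.shape == (2, (IMAGE_SIZE // PATCH_SIZE) ** 2, HIDDEN_DIM)
    return tokens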


class Head(nn.Module):
    """A single self-attention head with optional causal masking (decoder) and padding-mask support."""
    def __init__(self, n_embd, head_size, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder
        # The causal mask is registered with persistent=False so it is not saved in the state_dict.
        if self.is_decoder:
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH, dtype=torch.bool))
                    .view(1, CONTEXT_LENGTH, CONTEXT_LENGTH),
                persistent=False,
            )

    def forward(self, x, attention_mask=None):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        v = self.value(x)  # (B, T, hs)
        # Compute scaled attention scores ("affinities").
        wei = q @ k.transpose(-2, -1) * (k.size(-1) ** -0.5)  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        if self.is_decoder:
            # Apply the causal mask, sliced to the current sequence length in case T < CONTEXT_LENGTH.
            causal_mask = self.bias[:, :T, :T]
            wei = wei.masked_fill(causal_mask == 0, float('-inf'))
        if attention_mask is not None:
            # Apply the padding mask (for text tokens). attention_mask has shape (B, T),
            # indexed by key position; expanding it to (B, 1, T) lets it broadcast against
            # wei's (B, T_query, T_key) and mask out every key position where the mask is 0.
            mask = attention_mask.unsqueeze(1)  # (B, 1, T)
            wei = wei.masked_fill(mask == 0, float('-inf'))
        # Normalize the scores and apply dropout to the attention weights.
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        # Weighted aggregation of the values.
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out
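

# Hedged usage sketch (not part of the original upload): exercises a single decoder
# Head with a right-padded attention mask. `_demo_head` is a hypothetical helper and
# it assumes constants.py sets CONTEXT_LENGTH >= 16 and NUM_HEADS dividing HIDDEN_DIM.
def _demo_head():
    T = 16
    head = Head(n_embd=HIDDEN_DIM, head_size=HIDDEN_DIM // NUM_HEADS, is_decoder=True)
    x = torch.randn(2, T, HIDDEN_DIM)
    attention_mask = torch.ones(2, T)
    attention_mask[:, -4:] = 0  # mask out the last four (padding) key positions
    out = head(x, attention_mask=attention_mask)
    assert out.shape == (2, T, HIDDEN_DIM // NUM_HEADS)
    return out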


class MultiHeadAttention(nn.Module):
    """Runs several attention heads in parallel and projects their concatenated outputs back to n_embd."""
    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        assert n_embd % num_heads == 0
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList([
            Head(n_embd, head_size, dropout, is_decoder)
            for _ in range(num_heads)
        ])
        self.proj = nn.Linear(n_embd, n_embd)  # n_embd = num_heads * head_size
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder  # Store is_decoder status

    def forward(self, x, attention_mask=None):
        # Pass attention_mask only if this is a decoder block operating on the combined sequence.
        out = torch.cat([h(x, attention_mask=attention_mask if self.is_decoder else None) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
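

# Hedged usage sketch (not part of the original upload): multi-head attention in
# encoder mode keeps the (B, T, n_embd) shape; with is_decoder=False the attention_mask
# argument is ignored by design. `_demo_multi_head_attention` is a hypothetical helper.
def _demo_multi_head_attention():
    mha = MultiHeadAttention(HIDDEN_DIM)  # encoder mode: is_decoder defaults to False
    x = torch.randn(2, 16, HIDDEN_DIM)
    out = mha(x)
    assert out.shape == x.shape
    return out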


class FeedForward(nn.Module):
    """A simple linear layer followed by a non-linearity."""
    def __init__(self, n_embd, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),  # GELU instead of ReLU, as is common in transformers
            nn.Linear(4 * n_embd, n_embd),  # Projection back to the residual stream
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
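

# Hedged usage sketch (not part of the original upload): the feed-forward MLP expands
# to 4 * n_embd and contracts back, so the output shape matches the input.
# `_demo_feed_forward` is a hypothetical helper.
def _demo_feed_forward():
    ffn = FeedForward(HIDDEN_DIM)
    x = torch.randn(2, 16, HIDDEN_DIM)
    out = ffn(x)
    assert out.shape == x.shape
    return out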


class Block(nn.Module):
    """Transformer block: communication (attention) followed by computation (feed-forward)."""
    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = FeedForward(n_embd, dropout)
        self.is_decoder = is_decoder  # Store is_decoder status

    def forward(self, x, attention_mask=None):
        # Pre-norm residual layout; the attention_mask is forwarded only for decoder blocks.
        x = x + self.attn(self.ln1(x), attention_mask=attention_mask if self.is_decoder else None)
        x = x + self.ffn(self.ln2(x))
        return x
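

# Hedged usage sketch (not part of the original upload): a Block is shape-preserving,
# so blocks can be stacked freely. `_demo_block` is a hypothetical helper; HIDDEN_DIM
# is assumed to come from constants.py.
def _demo_block():
    blk = Block(HIDDEN_DIM)  # encoder-style block (no causal mask)
    x = torch.randn(2, 16, HIDDEN_DIM)
    out = blk(x)
    assert out.shape == x.shape
    return out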


class ViT(nn.Module):
    """Vision Transformer encoder: patch embedding + learnable [CLS] token + positional embeddings + encoder blocks."""
    def __init__(self, img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, num_hiddens=HIDDEN_DIM,
                 num_heads=NUM_HEADS, num_blks=NUM_LAYERS, emb_dropout=DROPOUT, blk_dropout=DROPOUT):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_patches = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens) * 0.02)  # Small init
        self.dropout = nn.Dropout(emb_dropout)
        # ViT blocks are NOT decoders (no causal mask).
        self.blocks = nn.ModuleList([Block(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)])
        self.layer_norm = nn.LayerNorm(num_hiddens)  # Final LayerNorm

    def forward(self, X):
        x = self.patch_embedding(X)                              # (B, N, C)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)   # (B, 1, C)
        x = torch.cat((cls_tokens, x), dim=1)                    # (B, N+1, C)
        # Add positional embeddings (broadcast over the batch dimension).
        x = x + self.pos_embedding
        x = self.dropout(x)
        for block in self.blocks:
            # Encoder blocks do not need an attention_mask.
            x = block(x)
        x = self.layer_norm(x)  # Apply the final layer norm
        return x
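

# Hedged usage sketch (not part of the original upload): the ViT returns one embedding
# per patch plus the leading [CLS] token. `_demo_vit` is a hypothetical helper; the
# default sizes are assumed to come from constants.py.
def _demo_vit():
    vit = ViT().eval()  # eval() disables dropout for a deterministic shape check
    with torch.no_grad():
        feats = vit(torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE))
    num_patches = (IMAGE_SIZE // PATCH_SIZE) ** 2
    assert feats.shape == (2, num_patches + 1, HIDDEN_DIM)  # [CLS] token sits at index 0
    return feats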


class MultiModalProjector(nn.Module):
    """Projects image embeddings into the text embedding space via an expand-then-contract MLP."""
    def __init__(self, image_embed_dim=HIDDEN_DIM, text_embed_dim=HIDDEN_DIM, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_embed_dim, text_embed_dim * 4),  # Intermediate expansion
            nn.GELU(),
            nn.Linear(text_embed_dim * 4, text_embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
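

# Hedged smoke test (not part of the original upload): runs the ViT encoder and the
# MultiModalProjector end to end and prints the resulting shapes. All dimensions are
# assumed to come from constants.py; the `_demo_*` helpers above are hypothetical.
if __name__ == "__main__":
    vit = ViT().eval()
    projector = MultiModalProjector().eval()
    with torch.no_grad():
        image_feats = vit(torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE))  # (B, N+1, HIDDEN_DIM)
        projected = projector(image_feats)                            # (B, N+1, HIDDEN_DIM)
    print("ViT features:", tuple(image_feats.shape))
    print("Projected features:", tuple(projected.shape))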