from constants import *

import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbeddings(nn.Module):
    def __init__(self, patch_size=PATCH_SIZE, hidden_dim=HIDDEN_DIM):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        X = self.conv(X)       # (B, C, H/P, W/P)
        X = X.flatten(2)       # (B, C, N) where N = (H/P)*(W/P)
        X = X.transpose(1, 2)  # (B, N, C)
        return X


class Head(nn.Module):
    def __init__(self, n_embd, head_size, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder
        # The causal mask is registered with persistent=False so it is not saved in the state_dict.
        if self.is_decoder:
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH, dtype=torch.bool))
                     .view(1, CONTEXT_LENGTH, CONTEXT_LENGTH),
                persistent=False,
            )

    def forward(self, x, attention_mask=None):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        v = self.value(x)  # (B, T, hs)

        # Compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * (k.size(-1) ** -0.5)  # (B, T, hs) @ (B, hs, T) -> (B, T, T)

        if self.is_decoder:
            # Apply the causal mask, sliced to the current length in case T < CONTEXT_LENGTH.
            causal_mask = self.bias[:, :T, :T]
            wei = wei.masked_fill(causal_mask == 0, float('-inf'))

        if attention_mask is not None:
            # Apply the padding mask (for text tokens). attention_mask has shape (B, T), where T is
            # the key dimension; unsqueeze it to (B, 1, T) so the boolean condition broadcasts over
            # the query dimension of wei, which has shape (B, T_query, T_key).
            mask = attention_mask.unsqueeze(1)  # (B, 1, T)
            wei = wei.masked_fill(mask == 0, float('-inf'))

        # Normalize to attention weights, apply dropout, then take the weighted aggregation of values.
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        assert n_embd % num_heads == 0
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList([
            Head(n_embd, head_size, dropout, is_decoder) for _ in range(num_heads)
        ])
        self.proj = nn.Linear(n_embd, n_embd)  # n_embd = num_heads * head_size
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder

    def forward(self, x, attention_mask=None):
        # Pass the attention_mask only when this is a decoder block operating on the combined sequence.
        out = torch.cat(
            [h(x, attention_mask=attention_mask if self.is_decoder else None) for h in self.heads],
            dim=-1,
        )
        out = self.dropout(self.proj(out))
        return out


class FeedForward(nn.Module):
    """A simple linear layer followed by a non-linearity."""

    def __init__(self, n_embd, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),  # GELU instead of ReLU, as is common in transformers
            nn.Linear(4 * n_embd, n_embd),  # Projection back to the residual stream
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """Transformer block: communication followed by computation."""

    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = FeedForward(n_embd, dropout)
        self.is_decoder = is_decoder

    def forward(self, x, attention_mask=None):
        # Pre-norm residual connections; the attention_mask is only relevant for decoder blocks.
        x = x + self.attn(self.ln1(x), attention_mask=attention_mask if self.is_decoder else None)
        x = x + self.ffn(self.ln2(x))
        return x


class ViT(nn.Module):
    def __init__(self, img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, num_hiddens=HIDDEN_DIM,
                 num_heads=NUM_HEADS, num_blks=NUM_LAYERS, emb_dropout=DROPOUT, blk_dropout=DROPOUT):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_patches = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens) * 0.02)  # Small init
        self.dropout = nn.Dropout(emb_dropout)
        # ViT blocks are NOT decoders (no causal mask).
        self.blocks = nn.ModuleList([
            Block(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)
        ])
        self.layer_norm = nn.LayerNorm(num_hiddens)  # Final LN

    def forward(self, X):
        x = self.patch_embedding(X)  # (B, N, C)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, C)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, N+1, C)
        x = x + self.pos_embedding  # Add positional embedding via broadcasting
        x = self.dropout(x)
        for block in self.blocks:
            # ViT blocks do not need an attention_mask.
            x = block(x)
        x = self.layer_norm(x)  # Apply final layer norm
        return x


class MultiModalProjector(nn.Module):
    """Projects the image embedding dimension to the text embedding dimension."""

    def __init__(self, image_embed_dim=HIDDEN_DIM, text_embed_dim=HIDDEN_DIM, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_embed_dim, text_embed_dim * 4),  # Intermediate expansion
            nn.GELU(),
            nn.Linear(text_embed_dim * 4, text_embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
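
# Minimal smoke test (a sketch, not part of the training pipeline). It assumes constants.py defines
# IMAGE_SIZE, PATCH_SIZE, HIDDEN_DIM, NUM_HEADS, NUM_LAYERS, DROPOUT, and CONTEXT_LENGTH as used
# above, with HIDDEN_DIM divisible by NUM_HEADS. It pushes a random image batch through the ViT and
# the MultiModalProjector and checks the expected output shapes.
if __name__ == "__main__":
    batch_size = 2
    images = torch.randn(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE)

    vit = ViT()
    projector = MultiModalProjector()

    img_tokens = vit(images)           # (B, num_patches + 1, HIDDEN_DIM)
    projected = projector(img_tokens)  # (B, num_patches + 1, HIDDEN_DIM)

    num_patches = (IMAGE_SIZE // PATCH_SIZE) ** 2
    assert img_tokens.shape == (batch_size, num_patches + 1, HIDDEN_DIM)
    assert projected.shape == (batch_size, num_patches + 1, HIDDEN_DIM)
    print("ViT output:", tuple(img_tokens.shape), "projector output:", tuple(projected.shape))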