from constants import *  # expects PATCH_SIZE, HIDDEN_DIM, DROPOUT, CONTEXT_LENGTH, NUM_HEADS, IMAGE_SIZE, NUM_LAYERS
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbeddings(nn.Module):
    """ splits an image into non-overlapping patches and projects each patch to hidden_dim
    using a strided convolution (kernel_size = stride = patch_size) """
    def __init__(self, patch_size=PATCH_SIZE, hidden_dim=HIDDEN_DIM):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        X = self.conv(X) # (B, C, H/P, W/P)
        X = X.flatten(2)  # (B, C, N) where N = (H/P)*(W/P)
        X = X.transpose(1, 2)  # (B, N, C)
        return X
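
# Shape sketch (a hedged example; IMAGE_SIZE, PATCH_SIZE and HIDDEN_DIM are assumed to come
# from constants.py, with IMAGE_SIZE divisible by PATCH_SIZE):
#   emb = PatchEmbeddings()
#   imgs = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE)   # (B, 3, H, W)
#   tokens = emb(imgs)                                  # (B, (IMAGE_SIZE // PATCH_SIZE) ** 2, HIDDEN_DIM)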

class Head(nn.Module):
    """ a single attention head; with is_decoder=True a causal mask restricts attention to
    earlier positions, and an optional padding mask can hide padded key positions """
    def __init__(self, n_embd, head_size, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder
        # causal mask is registered persistent=False so it's not saved in state_dict
        if self.is_decoder:
            self.register_buffer("bias", torch.tril(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH, dtype=torch.bool))
                                 .view(1, CONTEXT_LENGTH, CONTEXT_LENGTH), persistent=False)

    def forward(self, x, attention_mask=None):
        B, T, C = x.shape
        k = self.key(x)   # (B, T, hs)
        q = self.query(x) # (B, T, hs)
        v = self.value(x) # (B, T, hs)

        # Compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * (k.size(-1)**-0.5) # (B, T, hs) @ (B, hs, T) -> (B, T, T)

        if self.is_decoder:
            # Apply causal mask
            # Ensure the mask is sliced correctly if T < CONTEXT_LENGTH
            causal_mask = self.bias[:, :T, :T]
            wei = wei.masked_fill(causal_mask == 0, float('-inf'))

        if attention_mask is not None:
            # Padding mask for text tokens: attention_mask is (B, T_key) with 0 at padded positions.
            # Unsqueeze to (B, 1, T_key) so it broadcasts against wei (B, T_query, T_key) and
            # masks out padded keys for every query position.
            mask = attention_mask.unsqueeze(1)  # (B, 1, T)
            wei = wei.masked_fill(mask == 0, float('-inf'))

        # Normalize scores and apply dropout
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Perform weighted aggregation of values
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        # print(f"out shape = {out.shape}")
        return out
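
# Single-head sketch (hedged; assumes HIDDEN_DIM, NUM_HEADS and CONTEXT_LENGTH from constants.py,
# with NUM_HEADS dividing HIDDEN_DIM):
#   head = Head(n_embd=HIDDEN_DIM, head_size=HIDDEN_DIM // NUM_HEADS, is_decoder=True)
#   x = torch.randn(2, CONTEXT_LENGTH, HIDDEN_DIM)      # (B, T, n_embd)
#   out = head(x)                                       # (B, T, head_size), causally masked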

class MultiHeadAttention(nn.Module):
    """ several attention heads run in parallel; their outputs are concatenated back to n_embd
    and passed through an output projection """
    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        assert n_embd % num_heads == 0
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList([
            Head(n_embd, head_size, dropout, is_decoder)
            for _ in range(num_heads)
        ])
        self.proj = nn.Linear(n_embd, n_embd) # n_embd = num_heads * head_size
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder # Store is_decoder status

    def forward(self, x, attention_mask=None):
        # Only decoder blocks operate on the combined (image + text) sequence, so the padding
        # mask is forwarded to the heads only when is_decoder is True.
        out = torch.cat([h(x, attention_mask=attention_mask if self.is_decoder else None) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
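
# MultiHeadAttention keeps the model width: the concatenated heads have size
# num_heads * head_size == n_embd. Sketch (constants.py names assumed):
#   mha = MultiHeadAttention(n_embd=HIDDEN_DIM, num_heads=NUM_HEADS)
#   y = mha(torch.randn(2, 16, HIDDEN_DIM))             # (B, 16, HIDDEN_DIM)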


class FeedForward(nn.Module):
    """ a two-layer MLP with a GELU non-linearity (the transformer feed-forward block) """
    def __init__(self, n_embd, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(), # GELU non-linearity, commonly used in transformer feed-forward layers
            nn.Linear(4 * n_embd, n_embd), # Projection back to residual stream
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """
    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = FeedForward(n_embd, dropout)
        self.is_decoder = is_decoder # Store is_decoder status

    def forward(self, x, attention_mask=None):
        # Pre-norm residual connections; the padding mask is only forwarded for decoder blocks
        x = x + self.attn(self.ln1(x), attention_mask=attention_mask if self.is_decoder else None)
        x = x + self.ffn(self.ln2(x))
        # print(f"output shape = {x.shape}")
        return x

class ViT(nn.Module):
    """ vision transformer encoder: patch embedding, a learnable [CLS] token, learned positional
    embeddings, and a stack of non-causal pre-norm transformer blocks """
    def __init__(self, img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, num_hiddens=HIDDEN_DIM,
                 num_heads=NUM_HEADS, num_blks=NUM_LAYERS, emb_dropout=DROPOUT, blk_dropout=DROPOUT):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_patches = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens) * 0.02) # scaled-down init (std ~0.02)
        self.dropout = nn.Dropout(emb_dropout)
        # ViT blocks are NOT decoders (no causal mask)
        self.blocks = nn.ModuleList([Block(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)])
        self.layer_norm = nn.LayerNorm(num_hiddens) # Final LN

    def forward(self, X):
        x = self.patch_embedding(X) # (B, N, C)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # (B, 1, C)
        x = torch.cat((cls_tokens, x), dim=1) # (B, N+1, C)
        # Add positional embedding
        x = x + self.pos_embedding # Uses broadcasting
        x = self.dropout(x)
        for block in self.blocks:
            # ViT blocks don't need attention_mask
            x = block(x)
        x = self.layer_norm(x) # Apply final layer norm
        return x
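
# ViT usage sketch (hedged; assumes IMAGE_SIZE and PATCH_SIZE from constants.py):
#   vit = ViT()
#   feats = vit(torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE))  # (B, num_patches + 1, HIDDEN_DIM)
#   cls_feat = feats[:, 0]                                   # [CLS] token, a common image summary vector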

class MultiModalProjector(nn.Module):
    """ projects image embeddings (image_embed_dim) into the text embedding space (text_embed_dim)
    with a small MLP """
    def __init__(self, image_embed_dim=HIDDEN_DIM, text_embed_dim=HIDDEN_DIM, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_embed_dim, text_embed_dim * 4), # Intermediate expansion
            nn.GELU(),
            nn.Linear(text_embed_dim * 4, text_embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
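

# Minimal smoke test: run this file directly to check that the ViT and the projector wire together.
# This is only a sketch; it assumes the constants imported above and feeds random inputs.
if __name__ == "__main__":
    vit = ViT()
    projector = MultiModalProjector()
    imgs = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE)     # (B, 3, H, W)
    img_feats = vit(imgs)                                 # (B, num_patches + 1, HIDDEN_DIM)
    img_embeds = projector(img_feats)                     # (B, num_patches + 1, HIDDEN_DIM)
    print("image features:", tuple(img_feats.shape))
    print("projected embeddings:", tuple(img_embeds.shape))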