from constants import *

import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbeddings(nn.Module):
    """Splits an image into non-overlapping patches and projects each patch to hidden_dim."""

    def __init__(self, patch_size=PATCH_SIZE, hidden_dim=HIDDEN_DIM):
        super().__init__()
        # A conv whose kernel size and stride both equal the patch size is
        # equivalent to cutting the image into patches and applying a shared
        # linear projection to each one.
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        X = self.conv(X)       # (B, hidden_dim, H / patch_size, W / patch_size)
        X = X.flatten(2)       # (B, hidden_dim, num_patches)
        X = X.transpose(1, 2)  # (B, num_patches, hidden_dim)
        return X

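# Illustrative shape check; the concrete numbers below are assumptions for the
# example only (the real values come from IMAGE_SIZE, PATCH_SIZE and HIDDEN_DIM
# in constants.py):
#   patches = PatchEmbeddings(patch_size=16, hidden_dim=512)
#   x = torch.randn(2, 3, 96, 96)   # (B, C, H, W)
#   patches(x).shape                # torch.Size([2, 36, 512]) == (B, (96 // 16) ** 2, hidden_dim)

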
class Head(nn.Module):
    """A single self-attention head, optionally causal (decoder-style)."""

    def __init__(self, n_embd, head_size, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder

        if self.is_decoder:
            # Lower-triangular mask so each position can only attend to itself
            # and earlier positions. Not stored in the state dict.
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH, dtype=torch.bool))
                     .view(1, CONTEXT_LENGTH, CONTEXT_LENGTH),
                persistent=False,
            )

    def forward(self, x, attention_mask=None):
        B, T, C = x.shape

        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Scaled dot-product attention scores: (B, T, T)
        wei = q @ k.transpose(-2, -1) * (k.size(-1) ** -0.5)

        if self.is_decoder:
            # Apply the causal mask, cropped to the current sequence length.
            causal_mask = self.bias[:, :T, :T]
            wei = wei.masked_fill(causal_mask == 0, float('-inf'))

        if attention_mask is not None:
            # attention_mask: (B, T) with 1 for real tokens and 0 for padding.
            # Broadcast to (B, 1, T) so padded keys are ignored by every query.
            mask = attention_mask.unsqueeze(1)
            wei = wei.masked_fill(mask == 0, float('-inf'))

        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        out = wei @ v  # (B, T, head_size)
        return out

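# Illustrative use of the padding mask; the numbers are assumptions for the
# example only, and it assumes CONTEXT_LENGTH in constants.py is at least 5:
#   head = Head(n_embd=512, head_size=64, is_decoder=True)
#   tokens = torch.randn(2, 5, 512)            # (B, T, n_embd)
#   pad = torch.tensor([[1, 1, 1, 0, 0],
#                       [1, 1, 1, 1, 1]])      # 0 marks padded positions
#   head(tokens, attention_mask=pad).shape     # torch.Size([2, 5, 64])

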
class MultiHeadAttention(nn.Module):
    """Runs several attention heads in parallel and projects the concatenated outputs."""

    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        assert n_embd % num_heads == 0, "n_embd must be divisible by num_heads"
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList([
            Head(n_embd, head_size, dropout, is_decoder)
            for _ in range(num_heads)
        ])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder

    def forward(self, x, attention_mask=None):
        # The padding mask is only passed through in decoder mode; with
        # is_decoder=False (the ViT encoder path) it is ignored.
        out = torch.cat(
            [h(x, attention_mask=attention_mask if self.is_decoder else None) for h in self.heads],
            dim=-1,
        )
        out = self.dropout(self.proj(out))
        return out

class FeedForward(nn.Module):
    """A simple linear layer followed by a non-linearity."""

    def __init__(self, n_embd, dropout=DROPOUT):
        super().__init__()
        # Standard transformer MLP: expand to 4x the embedding size, apply GELU,
        # project back down, then dropout.
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Transformer block: communication (attention) followed by computation (feed-forward)."""

    def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = FeedForward(n_embd, dropout)
        self.is_decoder = is_decoder

    def forward(self, x, attention_mask=None):
        # Pre-norm residual connections around attention and the feed-forward net.
        x = x + self.attn(self.ln1(x), attention_mask=attention_mask if self.is_decoder else None)
        x = x + self.ffn(self.ln2(x))
        return x

class ViT(nn.Module):
    """Vision Transformer encoder: patch embeddings, a CLS token and learned position
    embeddings, followed by a stack of transformer blocks."""

    def __init__(self, img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, num_hiddens=HIDDEN_DIM,
                 num_heads=NUM_HEADS, num_blks=NUM_LAYERS, emb_dropout=DROPOUT, blk_dropout=DROPOUT):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(patch_size, num_hiddens)
        # Learnable classification token, prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_patches = (img_size // patch_size) ** 2
        # Learned position embeddings for the CLS token plus every patch.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens) * 0.02)
        self.dropout = nn.Dropout(emb_dropout)
        self.blocks = nn.ModuleList([
            Block(num_hiddens, num_heads, blk_dropout, is_decoder=False)
            for _ in range(num_blks)
        ])
        self.layer_norm = nn.LayerNorm(num_hiddens)

    def forward(self, X):
        x = self.patch_embedding(X)                              # (B, num_patches, num_hiddens)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)                    # (B, num_patches + 1, num_hiddens)
        x = x + self.pos_embedding
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        x = self.layer_norm(x)
        return x

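# Illustrative usage; output shapes depend on the values defined in constants.py:
#   vit = ViT()
#   images = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE)
#   vit(images).shape   # (2, (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1, HIDDEN_DIM)

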
class MultiModalProjector(nn.Module):
    """Maps ViT image embeddings into the text embedding space with a small MLP."""

    def __init__(self, image_embed_dim=HIDDEN_DIM, text_embed_dim=HIDDEN_DIM, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_embed_dim, text_embed_dim * 4),
            nn.GELU(),
            nn.Linear(text_embed_dim * 4, text_embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
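

# Minimal smoke test chaining the vision encoder and the projector. It only
# checks tensor shapes and assumes IMAGE_SIZE, PATCH_SIZE and HIDDEN_DIM are
# defined in constants.py; run this module directly to execute it.
if __name__ == "__main__":
    vit = ViT()
    projector = MultiModalProjector()

    images = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE)  # dummy batch of RGB images
    image_embeds = vit(images)                          # (2, num_patches + 1, HIDDEN_DIM)
    projected = projector(image_embeds)                 # (2, num_patches + 1, HIDDEN_DIM)

    print("ViT output:      ", tuple(image_embeds.shape))
    print("Projector output:", tuple(projected.shape))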