import torch
import torch.nn as nn
import math
from config import HParams


class VisionTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        # Patch Embedding: 16x16 non-overlapping patches
        self.conv1 = nn.Conv2d(3, HParams.VISION_WIDTH, kernel_size=16, stride=16, bias=False)
        scale = HParams.VISION_WIDTH ** -0.5
        self.cls = nn.Parameter(scale * torch.randn(HParams.VISION_WIDTH))
        self.pos = nn.Parameter(scale * torch.randn((HParams.IMAGE_SIZE // 16) ** 2 + 1, HParams.VISION_WIDTH))
        self.ln_pre = nn.LayerNorm(HParams.VISION_WIDTH)

        # Transformer Encoder
        layer = nn.TransformerEncoderLayer(
            d_model=HParams.VISION_WIDTH,
            nhead=HParams.VISION_HEADS,
            dim_feedforward=HParams.VISION_WIDTH * HParams.FFN_MULT,
            activation="gelu",
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(layer, HParams.VISION_LAYERS)
        self.ln_post = nn.LayerNorm(HParams.VISION_WIDTH)
        self.proj = nn.Parameter(scale * torch.randn(HParams.VISION_WIDTH, HParams.EMBED_DIM))

    def forward(self, x):
        x = self.conv1(x)                 # [B, C, H/16, W/16]
        x = x.flatten(2).transpose(1, 2)  # [B, L, C]
        cls = self.cls.expand(x.size(0), 1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos
        x = self.ln_pre(x)
        x = self.transformer(x)
        x = self.ln_post(x[:, 0])         # take CLS token
        return x @ self.proj

class TextTransformer(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token = nn.Embedding(vocab_size, HParams.TEXT_WIDTH)
        self.pos = nn.Parameter(torch.randn(HParams.MAX_TOKENS, HParams.TEXT_WIDTH))
        layer = nn.TransformerEncoderLayer(
            d_model=HParams.TEXT_WIDTH,
            nhead=HParams.TEXT_HEADS,
            dim_feedforward=HParams.TEXT_WIDTH * HParams.FFN_MULT,
            activation="gelu",
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(layer, HParams.TEXT_LAYERS)
        self.ln = nn.LayerNorm(HParams.TEXT_WIDTH)
        self.proj = nn.Parameter(torch.randn(HParams.TEXT_WIDTH, HParams.EMBED_DIM))

    def forward(self, input_ids, attention_mask):
        x = self.token(input_ids)
        x = x + self.pos[:x.size(1)]
        # Masking (ignore padding positions)
        mask = attention_mask == 0
        x = self.transformer(x, src_key_padding_mask=mask)
        x = self.ln(x)
        # Take the embedding at the EOT (end-of-text) token, i.e. the last non-padding position
        eot_indices = attention_mask.sum(dim=1) - 1
        x = x[torch.arange(x.size(0)), eot_indices]
        return x @ self.proj

class CLIP(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.visual = VisionTransformer()
        self.text = TextTransformer(vocab_size)
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

    def forward(self, img, text_ids, text_mask):
        # NOTE: At inference time the image and text encoders are usually run
        # separately; a single forward is kept here so the full model can be
        # trained end to end.
        img_features = self.visual(img)
        text_features = self.text(text_ids, text_mask)
        img_features = img_features / img_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logit_scale = self.logit_scale.exp().clamp(max=100)
        logits_per_image = logit_scale * img_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text
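

# --- Usage sketch (illustrative, not part of the original training code) ---
# A minimal example of how the model above might be exercised, assuming
# config.HParams defines IMAGE_SIZE, MAX_TOKENS and the width/head/layer
# settings used above; the vocabulary size and batch size below are made-up
# values for illustration only.
if __name__ == "__main__":
    vocab_size = 49408  # hypothetical vocabulary size
    model = CLIP(vocab_size)

    batch = 4
    images = torch.randn(batch, 3, HParams.IMAGE_SIZE, HParams.IMAGE_SIZE)
    text_ids = torch.randint(0, vocab_size, (batch, HParams.MAX_TOKENS))
    text_mask = torch.ones(batch, HParams.MAX_TOKENS, dtype=torch.long)

    logits_per_image, logits_per_text = model(images, text_ids, text_mask)
    print(logits_per_image.shape)  # [batch, batch] similarity matrix

    # The usual CLIP objective is a symmetric cross-entropy over these logits,
    # with the diagonal entries as the matching image-text pairs.
    targets = torch.arange(batch)
    loss = (nn.functional.cross_entropy(logits_per_image, targets) +
            nn.functional.cross_entropy(logits_per_text, targets)) / 2
    print(loss.item())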