import torch
import torch.nn as nn
import math

from config import HParams


class VisionTransformer(nn.Module):
    def __init__(self):
        super().__init__()

        # Patch embedding: split the image into non-overlapping 16x16 patches
        # and project each patch to VISION_WIDTH channels.
        self.conv1 = nn.Conv2d(3, HParams.VISION_WIDTH, kernel_size=16, stride=16, bias=False)

        scale = HParams.VISION_WIDTH ** -0.5
        self.cls = nn.Parameter(scale * torch.randn(HParams.VISION_WIDTH))
        self.pos = nn.Parameter(scale * torch.randn((HParams.IMAGE_SIZE // 16) ** 2 + 1, HParams.VISION_WIDTH))
        self.ln_pre = nn.LayerNorm(HParams.VISION_WIDTH)

        layer = nn.TransformerEncoderLayer(
            d_model=HParams.VISION_WIDTH,
            nhead=HParams.VISION_HEADS,
            dim_feedforward=HParams.VISION_WIDTH * HParams.FFN_MULT,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, HParams.VISION_LAYERS)

        self.ln_post = nn.LayerNorm(HParams.VISION_WIDTH)
        self.proj = nn.Parameter(scale * torch.randn(HParams.VISION_WIDTH, HParams.EMBED_DIM))

    def forward(self, x):
        x = self.conv1(x)                 # (B, width, grid, grid)
        x = x.flatten(2).transpose(1, 2)  # (B, grid**2, width)

        # Prepend the learned class token and add positional embeddings.
        cls = self.cls.expand(x.size(0), 1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos

        x = self.ln_pre(x)
        x = self.transformer(x)
        x = self.ln_post(x[:, 0])         # pool the class token

        return x @ self.proj


class TextTransformer(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token = nn.Embedding(vocab_size, HParams.TEXT_WIDTH)
        self.pos = nn.Parameter(torch.randn(HParams.MAX_TOKENS, HParams.TEXT_WIDTH))

        layer = nn.TransformerEncoderLayer(
            d_model=HParams.TEXT_WIDTH,
            nhead=HParams.TEXT_HEADS,
            dim_feedforward=HParams.TEXT_WIDTH * HParams.FFN_MULT,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, HParams.TEXT_LAYERS)
        self.ln = nn.LayerNorm(HParams.TEXT_WIDTH)
        self.proj = nn.Parameter(torch.randn(HParams.TEXT_WIDTH, HParams.EMBED_DIM))

    def forward(self, input_ids, attention_mask):
        x = self.token(input_ids)
        x = x + self.pos[:x.size(1)]

        # src_key_padding_mask expects True at padding positions.
        mask = attention_mask == 0

        x = self.transformer(x, src_key_padding_mask=mask)
        x = self.ln(x)

        # Pool the last non-padding token (end-of-text) of each sequence.
        eot_indices = attention_mask.sum(dim=1) - 1
        x = x[torch.arange(x.size(0), device=x.device), eot_indices]
        return x @ self.proj


class CLIP(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.visual = VisionTransformer()
        self.text = TextTransformer(vocab_size)
        # Learnable temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

    def forward(self, img, text_ids, text_mask):
        img_features = self.visual(img)
        text_features = self.text(text_ids, text_mask)

        # L2-normalise so the pairwise dot products are cosine similarities.
        img_features = img_features / img_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp().clamp(max=100)
        logits_per_image = logit_scale * img_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text
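

if __name__ == "__main__":
    # Minimal smoke test and loss sketch, not part of the model itself.
    # It assumes HParams defines IMAGE_SIZE, MAX_TOKENS, and the widths/layers
    # used above; the vocabulary size, batch size, and the symmetric
    # cross-entropy below are illustrative choices following the usual CLIP
    # training recipe, not values taken from this repository.
    import torch.nn.functional as F

    vocab_size = 49408  # assumption: size of a BPE-style tokenizer vocabulary
    batch_size = 4
    model = CLIP(vocab_size)

    images = torch.randn(batch_size, 3, HParams.IMAGE_SIZE, HParams.IMAGE_SIZE)
    text_ids = torch.randint(0, vocab_size, (batch_size, HParams.MAX_TOKENS))
    text_mask = torch.ones(batch_size, HParams.MAX_TOKENS, dtype=torch.long)

    logits_per_image, logits_per_text = model(images, text_ids, text_mask)

    # Symmetric contrastive loss: the i-th image matches the i-th caption.
    targets = torch.arange(batch_size)
    loss = (F.cross_entropy(logits_per_image, targets)
            + F.cross_entropy(logits_per_text, targets)) / 2
    print(logits_per_image.shape, loss.item())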