# Source: DHPR upload "Upload 25 files" (commit f638d9c)
import torch
import torch.nn as nn
from easydict import EasyDict as edict
class Block(nn.Module):
    """Placeholder cross-modal encoder block.

    Currently a stub: it only records the configuration it was built with.
    Intended for the (disabled) i2t/t2i encoder stacks in ``Model``.
    """

    def __init__(self, config):
        """Store *config* for later use; no layers are created yet."""
        super().__init__()
        self.config = config
class Model(torch.nn.Module):
def __init__(self, clip_model, config):
super().__init__()
self.clip_model = clip_model
# if config.i2t_encoder_layers > 0:
# self.i2t_encoder = nn.ModuleList([Block(config) for _ in range(config.i2t_encoder_layers)])
# if config.t2i_encoder_layers > 0:
# self.t2i_encoder = nn.ModuleList([Block(config) for _ in range(config.i2t_encoder_layers)])
self.config = config
def img_forward(self, x: torch.Tensor): # [N, 3, 224, 224]
x = self.clip_model.visual.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, gri d ** 2, width]
x = torch.cat(
[self.clip_model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x],
dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.clip_model.visual.positional_embedding.to(x.dtype)
x = self.clip_model.visual.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.clip_model.visual.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.clip_model.visual.ln_post(x) # [NLD]
cls_token = self.clip_model.visual.ln_post(x[:, 0, :])
if self.clip_model.visual.proj is not None:
cls_token = cls_token @ self.clip_model.visual.proj
return x, cls_token
def txt_forward(self, text):
dtype = self.clip_model.dtype
x = self.clip_model.token_embedding(text).type(dtype) # [batch_size, n_ctx, d_model]
x = x + self.clip_model.positional_embedding.type(dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.clip_model.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.clip_model.ln_final(x).type(dtype)
# take features from the eot embedding (eot_token is the highest number in each sequence)
eot = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.clip_model.text_projection
return x, eot # [NLD]
def var_img_forward(self, image):
if len(image.shape) == 5:
img_features1, img_token1 = self.img_forward(image[:, 0, ...])
img_features2, img_token2 = self.img_forward(image[:, 1, ...])
img_token = (img_token1 + img_token2) / 2
img_features = (img_features1 + img_features2) / 2
else:
img_features, img_token = self.img_forward(image)
img_token = img_token / img_token.norm(dim=-1, keepdim=True)
return img_features, img_token
def var_txt_forward(self, text):
txt_features, txt_token = self.txt_forward(text)
txt_token = txt_token / txt_token.norm(dim=-1, keepdim=True)
return txt_features, txt_token
def forward(self, image, text, past_img_tokens=None, past_txt_tokens=None):
# TODO: aggregate past img and txt tokens
img_features, img_token = self.var_img_forward(image)
txt_features, txt_token = self.var_txt_forward(text)
logit_scale = self.clip_model.logit_scale.exp()
if past_img_tokens is not None:
past_img_tokens = torch.cat([past_img_tokens, img_token], dim=0)
past_txt_tokens = torch.cat([past_txt_tokens, txt_token], dim=0)
batch_size = past_img_tokens.shape[0]
ground_truth = torch.arange(batch_size, dtype=torch.long, device=img_token.device)
logits_for_imgs = logit_scale * past_img_tokens @ past_txt_tokens.t()
logits_for_txts = logits_for_imgs.t()
# print(f"past_img_tokens: {past_img_tokens.shape}, past_txt_tokens: {past_txt_tokens.shape}")
# CLIP Contrastive Learning Loss Function
loss_img = torch.nn.CrossEntropyLoss()
loss_txt = torch.nn.CrossEntropyLoss()
loss = (loss_img(logits_for_imgs, ground_truth[:batch_size]) + loss_txt(logits_for_txts, ground_truth[:batch_size])) / 2
else:
batch_size = img_token.shape[0]
ground_truth = torch.arange(batch_size, dtype=torch.long, device=img_token.device)
logits_for_imgs = logit_scale * img_token @ txt_token.t()
logits_for_txts = logits_for_imgs.t()
# CLIP Contrastive Learning Loss Function
loss_img = torch.nn.CrossEntropyLoss()
loss_txt = torch.nn.CrossEntropyLoss()
loss = (loss_img(logits_for_imgs, ground_truth[:batch_size]) + loss_txt(logits_for_txts, ground_truth[:batch_size])) / 2
return dict(
img_token=img_token,
txt_token=txt_token,
img_features=img_features,
txt_features=txt_features,
loss=loss,
past_img_tokens=past_img_tokens,
past_txt_tokens=past_txt_tokens,
)