import torch
import torch.nn as nn
from easydict import EasyDict as edict


class Block(nn.Module):
    """Placeholder transformer block; currently only stores its config."""

    def __init__(self, config):
        super().__init__()
        self.config = config


class Model(torch.nn.Module):
    """Wraps a pretrained CLIP model and exposes token-level features plus the
    pooled (projected) CLS / EOT embeddings for images and text."""

    def __init__(self, clip_model, config):
        super().__init__()
        self.clip_model = clip_model
        # NOTE(review): the commented-out cross-modal encoders originally used
        # config.i2t_encoder_layers for *both* directions (copy-paste slip);
        # if revived, t2i_encoder must use config.t2i_encoder_layers.
        # if config.i2t_encoder_layers > 0:
        #     self.i2t_encoder = nn.ModuleList([Block(config) for _ in range(config.i2t_encoder_layers)])
        # if config.t2i_encoder_layers > 0:
        #     self.t2i_encoder = nn.ModuleList([Block(config) for _ in range(config.t2i_encoder_layers)])
        self.config = config

    def img_forward(self, x: torch.Tensor):
        """Run CLIP's visual tower.

        Args:
            x: image batch, [N, 3, H, W].

        Returns:
            (token features [N, L, D], projected CLS token [N, proj_dim]).
        """
        x = self.clip_model.visual.conv1(x)  # [N, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [N, width, grid ** 2]
        x = x.permute(0, 2, 1)  # [N, grid ** 2, width]
        # Prepend the learned class embedding to every sequence.
        x = torch.cat(
            [self.clip_model.visual.class_embedding.to(x.dtype)
             + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x],
            dim=1)  # [N, grid ** 2 + 1, width]
        x = x + self.clip_model.visual.positional_embedding.to(x.dtype)
        x = self.clip_model.visual.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.visual.ln_post(x)  # [N, L, D]
        # BUGFIX: the original applied ln_post a *second* time to x[:, 0, :]
        # even though x was already normalized on the line above; upstream CLIP
        # applies ln_post exactly once. Take the CLS slice directly.
        cls_token = x[:, 0, :]
        if self.clip_model.visual.proj is not None:
            cls_token = cls_token @ self.clip_model.visual.proj
        return x, cls_token

    def txt_forward(self, text):
        """Run CLIP's text tower.

        Args:
            text: tokenized text, [N, n_ctx] (EOT has the highest token id).

        Returns:
            (token features [N, n_ctx, D], projected EOT token [N, proj_dim]).
        """
        dtype = self.clip_model.dtype
        x = self.clip_model.token_embedding(text).type(dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.clip_model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(dtype)
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.clip_model.text_projection
        return x, eot  # [NLD]
def var_img_forward(self, image): if len(image.shape) == 5: img_features1, img_token1 = self.img_forward(image[:, 0, ...]) img_features2, img_token2 = self.img_forward(image[:, 1, ...]) img_token = (img_token1 + img_token2) / 2 img_features = (img_features1 + img_features2) / 2 else: img_features, img_token = self.img_forward(image) img_token = img_token / img_token.norm(dim=-1, keepdim=True) return img_features, img_token def var_txt_forward(self, text): txt_features, txt_token = self.txt_forward(text) txt_token = txt_token / txt_token.norm(dim=-1, keepdim=True) return txt_features, txt_token def forward(self, image, text, past_img_tokens=None, past_txt_tokens=None): # TODO: aggregate past img and txt tokens img_features, img_token = self.var_img_forward(image) txt_features, txt_token = self.var_txt_forward(text) logit_scale = self.clip_model.logit_scale.exp() if past_img_tokens is not None: past_img_tokens = torch.cat([past_img_tokens, img_token], dim=0) past_txt_tokens = torch.cat([past_txt_tokens, txt_token], dim=0) batch_size = past_img_tokens.shape[0] ground_truth = torch.arange(batch_size, dtype=torch.long, device=img_token.device) logits_for_imgs = logit_scale * past_img_tokens @ past_txt_tokens.t() logits_for_txts = logits_for_imgs.t() # print(f"past_img_tokens: {past_img_tokens.shape}, past_txt_tokens: {past_txt_tokens.shape}") # CLIP Contrastive Learning Loss Function loss_img = torch.nn.CrossEntropyLoss() loss_txt = torch.nn.CrossEntropyLoss() loss = (loss_img(logits_for_imgs, ground_truth[:batch_size]) + loss_txt(logits_for_txts, ground_truth[:batch_size])) / 2 else: batch_size = img_token.shape[0] ground_truth = torch.arange(batch_size, dtype=torch.long, device=img_token.device) logits_for_imgs = logit_scale * img_token @ txt_token.t() logits_for_txts = logits_for_imgs.t() # CLIP Contrastive Learning Loss Function loss_img = torch.nn.CrossEntropyLoss() loss_txt = torch.nn.CrossEntropyLoss() loss = (loss_img(logits_for_imgs, ground_truth[:batch_size]) 
+ loss_txt(logits_for_txts, ground_truth[:batch_size])) / 2 return dict( img_token=img_token, txt_token=txt_token, img_features=img_features, txt_features=txt_features, loss=loss, past_img_tokens=past_img_tokens, past_txt_tokens=past_txt_tokens, )