# Source: Hugging Face Space snapshot (Space build status at capture time: "Runtime error").
| import torch | |
| import torch.nn as nn | |
| from easydict import EasyDict as edict | |
class Block(nn.Module):
    """Placeholder cross-modal encoder block.

    Currently a stub: it only keeps a reference to the configuration object
    (see the commented-out ``i2t_encoder`` / ``t2i_encoder`` stacks in
    ``Model.__init__``); no layers are defined yet.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
class Model(torch.nn.Module):
    """CLIP wrapper exposing token-level features plus a symmetric contrastive loss.

    Assumes the OpenAI ``clip`` package layout: ``clip_model.visual.*`` for the
    image tower and ``clip_model.transformer`` / ``token_embedding`` /
    ``ln_final`` / ``text_projection`` for the text tower. The two forward
    passes are re-implemented here so callers get the full token sequences in
    addition to the pooled (CLS / EOT) embeddings.
    """

    def __init__(self, clip_model, config):
        super().__init__()
        self.clip_model = clip_model
        # Optional cross-modal encoder stacks — currently disabled.
        # (Original had a copy-paste typo: t2i_encoder was sized by
        # i2t_encoder_layers; corrected below.)
        # if config.i2t_encoder_layers > 0:
        #     self.i2t_encoder = nn.ModuleList([Block(config) for _ in range(config.i2t_encoder_layers)])
        # if config.t2i_encoder_layers > 0:
        #     self.t2i_encoder = nn.ModuleList([Block(config) for _ in range(config.t2i_encoder_layers)])
        self.config = config

    def img_forward(self, x: torch.Tensor):
        """Run the CLIP visual tower.

        Args:
            x: image batch [N, 3, H, W] whose patch grid matches
               ``visual.positional_embedding``.

        Returns:
            Tuple of (token features after ``ln_post``, [N, grid**2 + 1, width];
            CLS embedding, projected to the joint space when ``visual.proj``
            exists, [N, embed_dim]).
        """
        vis = self.clip_model.visual
        x = vis.conv1(x)  # [N, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [N, width, grid ** 2]
        x = x.permute(0, 2, 1)  # [N, grid ** 2, width]
        # Prepend the learned class embedding, broadcast over the batch.
        cls_embed = vis.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_embed, x], dim=1)  # [N, grid ** 2 + 1, width]
        x = x + vis.positional_embedding.to(x.dtype)
        x = vis.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = vis.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = vis.ln_post(x)  # [N, L, D]
        # BUGFIX: the original applied ln_post a second time to the CLS row,
        # but x is already normalized on the line above (OpenAI CLIP applies
        # ln_post exactly once). Take the CLS token directly.
        cls_token = x[:, 0, :]
        if vis.proj is not None:
            cls_token = cls_token @ vis.proj
        return x, cls_token

    def txt_forward(self, text):
        """Run the CLIP text tower.

        Args:
            text: tokenized batch [N, n_ctx] of int token ids.

        Returns:
            Tuple of (token features after ``ln_final``, [N, n_ctx, d_model];
            projected EOT embedding, [N, embed_dim]).
        """
        dtype = self.clip_model.dtype
        x = self.clip_model.token_embedding(text).type(dtype)  # [N, n_ctx, d_model]
        x = x + self.clip_model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(dtype)
        # Pool at the EOT embedding (the eot token has the highest id in each
        # sequence, so argmax locates it).
        eot = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.clip_model.text_projection
        return x, eot

    def var_img_forward(self, image):
        """Image forward that also accepts a pair of views.

        A 5-D input [N, 2, 3, H, W] is treated as two views whose features and
        tokens are averaged; a 4-D input goes straight through. Only the pooled
        token is L2-normalized (matching the text side); the full token
        features are returned unnormalized.
        """
        if len(image.shape) == 5:
            img_features1, img_token1 = self.img_forward(image[:, 0, ...])
            img_features2, img_token2 = self.img_forward(image[:, 1, ...])
            img_token = (img_token1 + img_token2) / 2
            img_features = (img_features1 + img_features2) / 2
        else:
            img_features, img_token = self.img_forward(image)
        img_token = img_token / img_token.norm(dim=-1, keepdim=True)
        return img_features, img_token

    def var_txt_forward(self, text):
        """Text forward returning (token features, L2-normalized EOT token)."""
        txt_features, txt_token = self.txt_forward(text)
        txt_token = txt_token / txt_token.norm(dim=-1, keepdim=True)
        return txt_features, txt_token

    @staticmethod
    def _contrastive_loss(logit_scale, img_tokens, txt_tokens):
        """Symmetric CLIP InfoNCE loss over an aligned (image, text) batch.

        Row i of ``img_tokens`` is the positive for row i of ``txt_tokens``;
        everything else in the batch is a negative.
        """
        logits_for_imgs = logit_scale * img_tokens @ txt_tokens.t()
        logits_for_txts = logits_for_imgs.t()
        ground_truth = torch.arange(
            img_tokens.shape[0], dtype=torch.long, device=img_tokens.device)
        ce = torch.nn.CrossEntropyLoss()
        return (ce(logits_for_imgs, ground_truth) + ce(logits_for_txts, ground_truth)) / 2

    def forward(self, image, text, past_img_tokens=None, past_txt_tokens=None):
        """Encode a batch and compute the CLIP contrastive loss.

        When ``past_img_tokens``/``past_txt_tokens`` are given, the current
        batch's pooled tokens are appended to them and the loss is taken over
        the enlarged set (more negatives per positive).

        Returns a dict with pooled tokens, full token features, the loss, and
        the (possibly extended) past-token caches.
        """
        # TODO: aggregate past img and txt tokens
        img_features, img_token = self.var_img_forward(image)
        txt_features, txt_token = self.var_txt_forward(text)
        logit_scale = self.clip_model.logit_scale.exp()
        if past_img_tokens is not None:
            # NOTE(review): past tokens are presumably detached by the caller;
            # if not, gradients flow into stale graph state — confirm upstream.
            past_img_tokens = torch.cat([past_img_tokens, img_token], dim=0)
            past_txt_tokens = torch.cat([past_txt_tokens, txt_token], dim=0)
            loss = self._contrastive_loss(logit_scale, past_img_tokens, past_txt_tokens)
        else:
            loss = self._contrastive_loss(logit_scale, img_token, txt_token)
        return dict(
            img_token=img_token,
            txt_token=txt_token,
            img_features=img_features,
            txt_features=txt_features,
            loss=loss,
            past_img_tokens=past_img_tokens,
            past_txt_tokens=past_txt_tokens,
        )