import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip.transformer import VisionTransformer

from .gem_utils import SelfSelfAttention, GEMResidualBlock, modified_vit_forward

class GEMWrapper(nn.Module):
    def __init__(self, model, tokenizer, depth=7, ss_attn_iter=1, ss_attn_temp=None):
        super(GEMWrapper, self).__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.depth = depth
        self.ss_attn_iter = ss_attn_iter
        self.ss_attn_temp = ss_attn_temp
        self.patch_size = self.model.visual.patch_size[0]
        self.apply_gem()

    def apply_gem(self):
        # Replace self-attention with self-self attention in the last (depth - 1) blocks
        for i in range(1, self.depth):
            # Extract info from the original ViT
            num_heads = self.model.visual.transformer.resblocks[-i].attn.num_heads
            dim = int(self.model.visual.transformer.resblocks[-i].attn.head_dim * num_heads)
            qkv_bias = True
            # Init the self-self attention layer
            ss_attn = SelfSelfAttention(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                        ss_attn_iter=self.ss_attn_iter, ss_attn_temp=self.ss_attn_temp)
            # Copy necessary weights
            ss_attn.qkv.weight.data = self.model.visual.transformer.resblocks[-i].attn.in_proj_weight.clone()
            ss_attn.qkv.bias.data = self.model.visual.transformer.resblocks[-i].attn.in_proj_bias.clone()
            ss_attn.proj.weight.data = self.model.visual.transformer.resblocks[-i].attn.out_proj.weight.clone()
            ss_attn.proj.bias.data = self.model.visual.transformer.resblocks[-i].attn.out_proj.bias.clone()
            # Swap the original Attention with our SelfSelfAttention
            self.model.visual.transformer.resblocks[-i].attn = ss_attn
            # Wrap Residual block to handle SelfSelfAttention outputs
            self.model.visual.transformer.resblocks[-i] = GEMResidualBlock(self.model.visual.transformer.resblocks[-i])
        # Modify ViT's forward function
        self.model.visual.forward = modified_vit_forward.__get__(self.model.visual, VisionTransformer)
        return

    def encode_text(self, text: list):
        prompts = [f'a photo of a {cls}.' for cls in text]
        tokenized_prompts = self.tokenizer(prompts).to(self.model.visual.proj.device)
        text_embedding = self.model.encode_text(tokenized_prompts)
        text_embedding = F.normalize(text_embedding, dim=-1)
        return text_embedding.unsqueeze(0)  # [1, num_prompt, dim]

    def min_max(self, logits):
        # Per-image, per-prompt min-max normalization of the heat maps to [0, 1]
        B, num_prompt = logits.shape[:2]
        logits_min = logits.reshape(B, num_prompt, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
        logits_max = logits.reshape(B, num_prompt, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
        logits = (logits - logits_min) / (logits_max - logits_min)
        return logits

    def forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool = False):
        """
        :param image: torch.Tensor [1, 3, H, W]
        :param text: list[str] - class names / text prompts
        :param normalize: bool - if True performs min-max normalization
        :param return_ori: bool - if True uses the features from the original visual encoder
        """
        # Image
        H, W = image.shape[-2:]
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [1, N, dim]

        # Text
        text_embeddings = self.encode_text(text)  # [1, num_prompt, dim]

        # Image-Text matching (drop the CLS token, keep patch tokens only)
        img_txt_matching = image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [1, N, num_prompt]
        img_txt_matching = rearrange(img_txt_matching, 'b (h w) c -> b c h w',
                                     h=H // self.patch_size, w=W // self.patch_size)  # [1, num_prompt, h, w]

        # Interpolate back to the input resolution
        img_txt_matching = F.interpolate(img_txt_matching, size=(H, W), mode='bilinear')  # [1, num_prompt, H, W]

        # Heat Maps
        if normalize:
            img_txt_matching = self.min_max(img_txt_matching)
        return img_txt_matching

    def batched_forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool = False):
        """
        :param image: torch.Tensor [B, 3, H, W]
        :param text: list[list[str]] - one list of prompts per image
        :param normalize: bool - if True performs min-max normalization
        :param return_ori: bool - if True uses the features from the original visual encoder
        """
        L = len(text)
        cumm_idx = np.cumsum([len(t) for t in text]).tolist()
        B, _, H, W = image.shape
        assert B == L, f'Number of prompt lists L: {L} should be the same as number of images B: {B}.'

        # Image
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [B, N, dim]

        # Text
        flatten_text = [t for sub_text in text for t in sub_text]
        text_embeddings = self.encode_text(flatten_text)  # [1, total_num_prompt, dim]

        # Image-Text matching (drop the CLS token, keep patch tokens only)
        img_txt_matching = 100 * image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [B, N, total_num_prompt]
        img_txt_matching = rearrange(img_txt_matching, 'b (h w) c -> b c h w',
                                     h=H // self.patch_size, w=W // self.patch_size)  # [B, total_num_prompt, h, w]

        # Interpolate back to the input resolution
        img_txt_matching = F.interpolate(img_txt_matching, size=(H, W), mode='bilinear')  # [B, total_num_prompt, H, W]

        # Heat Maps
        if normalize:
            img_txt_matching = self.min_max(img_txt_matching)  # [B, total_num_prompt, H, W]

        # Unflatten: split along the prompt dim and keep, for each image, only its own prompts
        img_txt_matching = torch.tensor_split(img_txt_matching, cumm_idx[:-1], dim=1)
        img_txt_matching = [itm[i] for i, itm in enumerate(img_txt_matching)]
        return img_txt_matching
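

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file; added for illustration).
# It assumes this module lives inside the GEM package (so the relative import
# of gem_utils resolves, e.g. run with `python -m <package>.<this_module>`) and
# that an open_clip ViT checkpoint is available. The model name, pretrained
# tag, and 'example.jpg' path below are placeholders, not values from the
# original source.
if __name__ == '__main__':
    import open_clip
    from PIL import Image

    model, _, preprocess = open_clip.create_model_and_transforms(
        'ViT-B-16', pretrained='laion2b_s34b_b88k')
    tokenizer = open_clip.get_tokenizer('ViT-B-16')
    gem = GEMWrapper(model=model, tokenizer=tokenizer, depth=7)
    gem.eval()

    image = preprocess(Image.open('example.jpg')).unsqueeze(0)  # [1, 3, H, W]
    with torch.no_grad():
        heatmaps = gem(image, text=['cat', 'dog'])  # [1, 2, H, W] heat maps in [0, 1]
    print(heatmaps.shape)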