# NOTE(review): removed stray extraction artifacts ("Spaces:" / "Runtime error")
# that preceded the module source; they were not part of the code.
| from typing import Optional, List | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from open_clip.transformer import _expand_token, to_2tuple | |
def resample_abs_pos_embed(
        posemb,
        new_size: List[int],
        old_size: Optional[List[int]] = None,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True
):
    """Resample an absolute position embedding to a new spatial grid.

    Args:
        posemb: position embedding tensor of shape [1, prefix + H*W, dim].
        new_size: target (height, width) grid, or a single int for a square grid.
        old_size: source grid; if omitted, assumed square and inferred from
            ``posemb.shape[1] - num_prefix_tokens``.
        num_prefix_tokens: leading tokens (class token etc.) excluded from
            interpolation and re-attached unchanged.
        interpolation: mode passed to ``F.interpolate``.
        antialias: antialiasing flag (only meaningful for bilinear/bicubic).

    Returns:
        Resampled embedding of shape [1, prefix + new_H*new_W, dim]; the input
        tensor itself when the grid is unchanged.
    """
    def _pair(v):
        # Local stand-in for open_clip's to_2tuple: pass sequences through,
        # duplicate scalars.
        return tuple(v) if isinstance(v, (list, tuple)) else (v, v)

    # Sort out sizes; assume a square grid if old size was not provided.
    new_size = _pair(new_size)
    new_ntok = new_size[0] * new_size[1]
    if not old_size:
        old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
    old_size = _pair(old_size)
    if new_size == old_size:  # might not both be same container type
        return posemb

    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix, posemb = None, posemb

    # Do the interpolation in float32: bicubic interpolate is unsupported /
    # lossy in half precision (matches timm's reference implementation).
    orig_dtype = posemb.dtype
    posemb = posemb.float()
    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(
        posemb,
        size=new_size,
        mode=interpolation,
        # antialias is only valid for bilinear/bicubic modes.
        antialias=antialias and interpolation in ('bilinear', 'bicubic'),
    )
    posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
    posemb = posemb.to(orig_dtype)

    # Add back extra (class, etc.) prefix tokens.
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)
    return posemb
class SelfSelfAttention(nn.Module):
    """Self-self attention (GEM).

    Runs the ordinary q/k attention for the original path, and in parallel
    builds three "self-self" affinity maps (v/v, k/k, q/q), optionally
    sharpened over several iterations, each applied to the values; the three
    results are averaged for the GEM path.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ss_attn_iter=1,
                 ss_attn_temp=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Default scale is the standard 1/sqrt(head_dim).
        self.scale = qk_scale or head_dim ** -0.5
        self.ss_attn_iter = ss_attn_iter
        self.ss_attn_temp = ss_attn_temp
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_bias=None, prev_attn=None):
        # Input arrives as (L, N, D); work internally in (N, L, D).
        x = x.transpose(0, 1)
        batch, tokens, width = x.shape
        qkv = (self.qkv(x)
               .reshape(batch, tokens, 3, self.num_heads, width // self.num_heads)
               .permute(2, 0, 3, 1, 4))
        q, k, v = qkv.unbind(0)
        self.v_values = v  # stashed for downstream inspection

        # --- original self-attention path ---
        attn_ori = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        attn_ori = self.attn_drop(attn_ori)
        x_ori = (attn_ori @ v).transpose(1, 2).reshape(batch, tokens, width)
        x_ori = self.proj_drop(self.proj(x_ori))

        # --- GEM path ---
        if self.ss_attn_temp is None:
            # Inverse temperature derived from the mean token norm.
            pre_norm = torch.norm(x, dim=-1).mean(dim=-1, keepdim=True).unsqueeze(1).unsqueeze(-1)
            inv_temp = pre_norm * self.scale
        else:
            inv_temp = self.ss_attn_temp

        # Three parallel streams seeded from v, k and q (in that order).
        streams = [v, k, q]
        for _ in range(self.ss_attn_iter):
            streams = [F.normalize(s, dim=-1) for s in streams]
            streams = [((s @ s.transpose(-2, -1)) * inv_temp).softmax(dim=-1) @ s
                       for s in streams]

        # Final assignment of each stream's affinities to the values.
        streams = [F.normalize(s, dim=-1) for s in streams]
        routed = [((s @ s.transpose(-2, -1)) * inv_temp).softmax(dim=-1) @ v
                  for s in streams]
        xs = (routed[0] + routed[1] + routed[2]) / 3

        x = xs.transpose(1, 2).reshape(batch, tokens, width)
        x = self.proj_drop(self.proj(x))
        # Return (L, N, D) pair: [gem_output, original_output].
        return [x.transpose(0, 1), x_ori.transpose(0, 1)]
class GEMResidualBlock(nn.Module):
    """Wraps a CLIP residual attention block to carry two parallel streams.

    The wrapped block's attention is expected to return a (gem, original)
    pair; the original stream gets the usual attention + MLP residuals,
    while the GEM stream gets only the attention residual.
    """

    def __init__(self, res_block):
        super().__init__()
        self.res_block = res_block

    def forward(self,
                q_x: torch.Tensor,
                k_x: Optional[torch.Tensor] = None,
                v_x: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None,
                ):
        # A list input carries [gem_stream, original_stream] from an earlier block.
        x_gem, q_x = q_x if isinstance(q_x, list) else (q_x, q_x)

        gem_res, ori_res = self.res_block.attn(x=self.res_block.ln_1(q_x))
        gem_res = self.res_block.ls_1(gem_res)
        ori_res = self.res_block.ls_1(ori_res)

        # Original transformer path: attention residual, then MLP residual.
        x_ori = q_x + ori_res
        x_ori = x_ori + self.res_block.ls_2(self.res_block.mlp(self.res_block.ln_2(x_ori)))
        # GEM path: attention residual only (no MLP), by design.
        return [x_gem + gem_res, x_ori]
class GEMViT(nn.Module):
    """Thin wrapper holding a GEM-modified ViT.

    ``modified_vit_forward`` is written against the *wrapped* ViT's attributes
    (``conv1``, ``class_embedding``, ``transformer``, ...); presumably it is
    bound onto the ViT instance to replace its forward pass — confirm against
    the call site.
    """

    def __init__(self, vit):
        # BUG FIX: nn.Module.__init__ must run before any sub-module is
        # assigned; without it, `self.vit = vit` raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.vit = vit

    def modified_vit_forward(self, x: torch.Tensor):
        """Run the patched ViT forward, returning [gem_features, original_features].

        NOTE(review): `self` here must be the wrapped ViT itself, not this
        GEMViT wrapper (it reads `self.conv1` etc.).
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        grid_h, grid_w = x.shape[2:]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Prepend the class token: shape = [*, grid ** 2 + 1, width].
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
        if x.shape[1] != self.positional_embedding.shape[1]:
            # Input resolution differs from the training resolution:
            # interpolate the positional grid (class token kept as-is).
            pos_emb = resample_abs_pos_embed(self.positional_embedding.unsqueeze(0),
                                             new_size=[grid_h, grid_w],
                                             num_prefix_tokens=1,
                                             interpolation='bicubic',
                                             antialias=True)
        else:
            pos_emb = self.positional_embedding
        x = x + pos_emb.to(x.dtype)
        x = self.patch_dropout(x)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        # GEM-wrapped transformer yields the (gem, original) stream pair.
        x_gem, x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x_gem = x_gem.permute(1, 0, 2)  # LND -> NLD
        # Final norm and (optional) projection applied to both streams.
        x = self.ln_post(x)
        x_gem = self.ln_post(x_gem)
        if self.proj is not None:
            x = x @ self.proj
            x_gem = x_gem @ self.proj
        return [x_gem, x]