|
|
import os |
|
|
from collections import OrderedDict |
|
|
|
|
|
import torch |
|
|
from mmcv.cnn.bricks import DropPath |
|
|
from torch import nn |
|
|
from transformers import CLIPTokenizer |
|
|
|
|
|
from .utils import get_prompt_templates |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LanguageEncoder(nn.Module): |
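    """CLIP-style text encoder.

    Wraps a ``CLIPTokenizer`` and a causal Transformer, and projects the
    resulting text features into a joint vision-language embedding space so
    that they can be scored against visual embeddings via
    ``compute_similarity``.

    Illustrative usage (``visual_embeds`` is a hypothetical
    ``(B, N, dim_projection)`` tensor; pretrained weights are assumed to be
    loaded beforehand)::

        encoder = LanguageEncoder().cuda()
        encoder.get_mean_embeds(['cat', 'dog'], name='default')
        logits = encoder.compute_similarity(visual_embeds)  # (B, N, 2)
    """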
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
tokenizer='openai/clip-vit-base-patch32', |
|
|
dim_lang=512, |
|
|
dim_projection=512, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
        # Set tokenizers parallelism explicitly to avoid the HuggingFace
        # fork/parallelism warning in multi-worker data loading.
        os.environ['TOKENIZERS_PARALLELISM'] = 'true'
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
        # Reuse the EOS token as the CLS token; CLIP pools the text feature
        # at the EOS position.
        self.tokenizer.add_special_tokens(
            {'cls_token': self.tokenizer.eos_token})
|
|
|
|
|
max_token_num = self.tokenizer.model_max_length |
|
|
self.lang_encoder = Transformer(max_token_num, |
|
|
self.tokenizer.vocab_size, dim_lang) |
|
|
|
|
|
        # Projection from the language feature width to the shared
        # vision-language embedding space.
        self.lang_proj = nn.Parameter(torch.empty(dim_lang, dim_projection))
        self.max_token_num = max_token_num
        # Learnable temperature for the similarity logits (as in CLIP).
        self.logit_scale = nn.Parameter(torch.ones([]))
|
|
|
|
|
@torch.no_grad() |
|
|
def get_mean_embeds(self, class_names, name='default'): |
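        """Build and cache per-class text embeddings.

        Each class name is formatted with every prompt template, the prompts
        are encoded, and the normalized class embeddings are averaged over
        templates and re-normalized. The stacked result is cached as
        ``{name}_text_embeddings``.
        """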
|
|
|
|
|
def extract_mean_emb(txts): |
|
|
tokens = self.tokenizer( |
|
|
txts, |
|
|
padding='max_length', |
|
|
truncation=True, |
|
|
max_length=self.max_token_num, |
|
|
return_tensors='pt') |
|
|
clss_embedding, _ = self.forward_language( |
|
|
(tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), |
|
|
norm=True, |
|
|
with_token_embed=False) |
|
|
clss_embedding = clss_embedding.mean(dim=0) |
|
|
clss_embedding /= clss_embedding.norm() |
|
|
return clss_embedding |
|
|
|
|
|
templates = get_prompt_templates() |
|
|
|
|
|
clss_embeddings = [] |
|
|
for clss in class_names: |
|
|
            # Strip the panoptic suffixes before filling the prompt templates.
            clean_name = clss.replace('-other', '').replace(
                '-merged', '').replace('-stuff', '')
            txts = [template.format(clean_name) for template in templates]
|
|
clss_embeddings.append(extract_mean_emb(txts)) |
|
|
|
|
|
text_emb = torch.stack(clss_embeddings, dim=0) |
|
|
setattr(self, '{}_text_embeddings'.format(name), text_emb) |
|
|
|
|
|
def get_text_embeds(self, txts, name='grounding', norm=False): |
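        """Tokenize and encode ``txts``.

        Returns (and caches as ``{name}_token_embeddings``) a dict with the
        tokenizer output, the per-token embeddings and the pooled class
        embeddings.
        """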
|
|
tokens = self.tokenizer( |
|
|
txts, |
|
|
padding='max_length', |
|
|
truncation=True, |
|
|
max_length=self.max_token_num, |
|
|
return_tensors='pt') |
|
|
tokens = {key: value.cuda() for key, value in tokens.items()} |
|
|
class_emb, token_emb = self.forward_language( |
|
|
(tokens['input_ids'], tokens['attention_mask']), norm=norm) |
|
|
ret = { |
|
|
'tokens': tokens, |
|
|
'token_emb': token_emb, |
|
|
'class_emb': class_emb, |
|
|
} |
|
|
setattr(self, '{}_token_embeddings'.format(name), ret) |
|
|
return ret |
|
|
|
|
|
def get_sot_token(self, device): |
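        """Return a ``(1, 77)`` tensor filled with the CLIP start-of-text
        token id (49406); 77 is the CLIP context length."""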
|
|
|
|
|
|
|
|
return torch.tensor([[49406] * 77], device=device) |
|
|
|
|
|
def compute_similarity(self, v_emb, name='default'): |
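        """Score visual embeddings against the cached
        ``{name}_text_embeddings``: L2-normalize ``v_emb`` and return cosine
        similarities scaled by ``exp(logit_scale)``."""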
|
|
v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) |
|
|
t_emb = getattr(self, '{}_text_embeddings'.format(name)) |
|
|
output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose( |
|
|
1, 2) |
|
|
return output |
|
|
|
|
|
def forward_language(self, |
|
|
texts, |
|
|
norm=False, |
|
|
with_token_embed=True, |
|
|
with_cls_embed=True): |
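        """Encode tokenized text.

        ``texts`` is an ``(input_ids, attention_mask)`` tuple. Returns
        ``(class_embed, hidden_embed)``: the pooled embedding taken at the EOS
        position and the per-token embeddings, both projected by
        ``lang_proj`` and optionally L2-normalized. Either entry is ``None``
        when the corresponding ``with_cls_embed`` / ``with_token_embed`` flag
        is False.
        """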
|
|
x = self.lang_encoder(*texts) |
|
|
hidden_x = x['last_hidden_state'] |
|
|
|
|
|
        class_embed = None
        if with_cls_embed:
            # Pool at the EOS/CLS position: the EOS token id is the largest
            # id in the CLIP vocabulary, so argmax over the ids gives the
            # first EOS position.
            class_embed = hidden_x[torch.arange(hidden_x.size(0)),
                                   texts[0].argmax(dim=-1)]
            class_embed = class_embed @ self.lang_proj
            if norm:
                class_embed = class_embed / (
                    class_embed.norm(dim=-1, keepdim=True) + 1e-7)
|
|
|
|
|
hidden_embed = None |
|
|
if with_token_embed: |
|
|
hidden_embed = hidden_x @ self.lang_proj |
|
|
if norm: |
|
|
hidden_embed = hidden_embed / ( |
|
|
hidden_embed.norm(dim=-1, keepdim=True) + 1e-7) |
|
|
|
|
|
return class_embed, hidden_embed |
|
|
|
|
|
|
|
|
class Transformer(nn.Module): |
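    """Minimal CLIP-style Transformer text encoder.

    Token embeddings plus learned positional embeddings, followed by
    ``layers`` residual attention blocks and a final LayerNorm. With
    ``autoregressive=True`` a causal attention mask is applied; otherwise the
    ``attention_mask`` given to ``forward`` is converted into a key padding
    mask.
    """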
|
|
|
|
|
def __init__(self, |
|
|
context_length, |
|
|
vocab_size, |
|
|
width, |
|
|
layers: int = 12, |
|
|
heads: int = 8, |
|
|
drop_path: float = 0.0, |
|
|
                 autoregressive: bool = True):
|
|
super().__init__() |
|
|
|
|
|
self.token_embedding = nn.Embedding(vocab_size, width) |
|
|
|
|
|
self.context_length = context_length |
|
|
self.positional_embedding = nn.Parameter( |
|
|
torch.empty(self.context_length, width)) |
|
|
|
|
|
self.width = width |
|
|
self.layers = layers |
|
|
        self.autoregressive = autoregressive
|
|
        attn_mask = self.build_attention_mask() if autoregressive else None
|
|
        # Per-layer stochastic depth (DropPath) rates, linearly spaced from 0.
        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
|
|
self.resblocks = nn.ModuleList([ |
|
|
ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) |
|
|
for i in range(layers) |
|
|
]) |
|
|
|
|
|
self.ln_final = LayerNorm(width) |
|
|
|
|
|
@property |
|
|
def dim_out(self): |
|
|
return self.width |
|
|
|
|
|
def build_attention_mask(self): |
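        """Return a causal mask with ``-inf`` above the diagonal and zeros
        elsewhere, so every position attends only to itself and earlier
        tokens."""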
|
|
|
|
|
|
|
|
|
|
|
mask = torch.empty(self.context_length, self.context_length) |
|
|
mask.fill_(float('-inf')) |
|
|
mask.triu_(1) |
|
|
return mask |
|
|
|
|
|
def forward(self, input_ids, attention_mask=None): |
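        """Encode ``input_ids`` of shape ``(B, L)``; returns a dict with
        ``last_hidden_state`` of shape ``(B, L, width)``."""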
|
|
        # A key padding mask is only needed for bidirectional attention; the
        # autoregressive variant relies on the causal mask instead.
        key_padding_mask = None
        if not self.autoregressive and attention_mask is not None:
            key_padding_mask = (attention_mask == 0)
        x = self.token_embedding(input_ids)
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND: nn.MultiheadAttention is sequence-first
        for block in self.resblocks:
            x = block(x, key_padding_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
|
|
|
|
|
x = self.ln_final(x) |
|
|
|
|
|
return {'last_hidden_state': x} |
|
|
|
|
|
|
|
|
class LayerNorm(nn.Module): |
|
|
|
|
|
def __init__(self, hidden_size, eps=1e-12): |
|
|
"""Construct a layernorm module in the TF style (epsilon inside the |
|
|
square root).""" |
|
|
super(LayerNorm, self).__init__() |
|
|
self.weight = nn.Parameter(torch.ones(hidden_size)) |
|
|
self.bias = nn.Parameter(torch.zeros(hidden_size)) |
|
|
self.variance_epsilon = eps |
|
|
|
|
|
def forward(self, x): |
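        # Compute mean/variance in fp32 for numerical stability, then cast
        # back to the input dtype.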
|
|
pdtype = x.dtype |
|
|
x = x.float() |
|
|
u = x.mean(-1, keepdim=True) |
|
|
s = (x - u).pow(2).mean(-1, keepdim=True) |
|
|
x = (x - u) / torch.sqrt(s + self.variance_epsilon) |
|
|
return self.weight * x.to(pdtype) + self.bias |
|
|
|
|
|
|
|
|
class QuickGELU(nn.Module): |
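    """Sigmoid-based GELU approximation used by CLIP:
    ``x * sigmoid(1.702 * x)``."""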
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
return x * torch.sigmoid(1.702 * x) |
|
|
|
|
|
|
|
|
class ResidualAttentionBlock(nn.Module): |
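    """Pre-LayerNorm Transformer block: multi-head self-attention followed by
    a 4x-expansion MLP, with an optional attention mask and DropPath
    (stochastic depth) on both residual branches."""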
|
|
|
|
|
def __init__(self, |
|
|
d_model: int, |
|
|
n_head: int, |
|
|
attn_mask: torch.Tensor = None, |
|
|
drop_path: float = 0.0): |
|
|
super().__init__() |
|
|
|
|
|
self.attn = nn.MultiheadAttention(d_model, n_head) |
|
|
self.ln_1 = LayerNorm(d_model) |
|
|
self.mlp = nn.Sequential( |
|
|
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), |
|
|
('gelu', QuickGELU()), |
|
|
('c_proj', nn.Linear(d_model * 4, d_model))])) |
|
|
self.ln_2 = LayerNorm(d_model) |
|
|
self.attn_mask = attn_mask |
|
|
self.drop_path = DropPath( |
|
|
drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
|
|
def attention(self, |
|
|
x: torch.Tensor, |
|
|
key_padding_mask: torch.Tensor = None): |
|
|
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
|
|
|
|
|
return self.attn( |
|
|
x, |
|
|
x, |
|
|
x, |
|
|
key_padding_mask=key_padding_mask, |
|
|
need_weights=False, |
|
|
attn_mask=self.attn_mask)[0] |
|
|
|
|
|
def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): |
|
|
x = x + self.drop_path( |
|
|
self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)) |
|
|
x = x + self.drop_path(self.mlp(self.ln_2(x))) |
|
|
return x |
|
|
|