import os
from collections import OrderedDict
import torch
from mmcv.cnn.bricks import DropPath
from torch import nn
from transformers import CLIPTokenizer
from .utils import get_prompt_templates
# modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/language/vlpencoder.py # noqa
class LanguageEncoder(nn.Module):
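    """CLIP-style language encoder.

    Wraps a HuggingFace ``CLIPTokenizer`` and a causal text ``Transformer``,
    and projects the encoded features into a joint vision-language embedding
    space through ``lang_proj``. Note that ``lang_proj`` is allocated with
    ``torch.empty``, so it is presumably expected to be overwritten by
    pretrained weights before use.
    """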
def __init__(
self,
tokenizer='openai/clip-vit-base-patch32',
dim_lang=512,
dim_projection=512,
):
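        """Build the tokenizer, the text Transformer and the projection.

        Args:
            tokenizer (str): Name of the pretrained CLIP tokenizer to load
                from HuggingFace.
            dim_lang (int): Width of the text Transformer.
            dim_projection (int): Dimension of the projected text embeddings.
        """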
super().__init__()
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
self.tokenizer.add_special_tokens(
{'cls_token': self.tokenizer.eos_token})
max_token_num = self.tokenizer.model_max_length
self.lang_encoder = Transformer(max_token_num,
self.tokenizer.vocab_size, dim_lang)
self.lang_proj = nn.Parameter(torch.empty(dim_lang, dim_projection))
self.max_token_num = max_token_num
self.logit_scale = nn.Parameter(torch.ones([]))
@torch.no_grad()
def get_mean_embeds(self, class_names, name='default'):
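        """Compute prompt-ensembled class embeddings.

        Each class name is expanded with every prompt template, the
        per-prompt class embeddings are averaged and re-normalized, and the
        result is cached on the module as ``{name}_text_embeddings``. The
        tokenized inputs are moved to GPU, so a CUDA device is assumed.
        """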
def extract_mean_emb(txts):
tokens = self.tokenizer(
txts,
padding='max_length',
truncation=True,
max_length=self.max_token_num,
return_tensors='pt')
clss_embedding, _ = self.forward_language(
(tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()),
norm=True,
with_token_embed=False)
clss_embedding = clss_embedding.mean(dim=0)
clss_embedding /= clss_embedding.norm()
return clss_embedding
templates = get_prompt_templates()
clss_embeddings = []
for clss in class_names:
txts = [
template.format(
clss.replace('-other',
'').replace('-merged',
'').replace('-stuff', ''))
for template in templates
]
clss_embeddings.append(extract_mean_emb(txts))
text_emb = torch.stack(clss_embeddings, dim=0)
setattr(self, '{}_text_embeddings'.format(name), text_emb)
def get_text_embeds(self, txts, name='grounding', norm=False):
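        """Tokenize raw texts and return class / token embeddings.

        The result dict (tokens, per-token embeddings and per-text class
        embeddings) is also cached on the module as
        ``{name}_token_embeddings``.
        """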
tokens = self.tokenizer(
txts,
padding='max_length',
truncation=True,
max_length=self.max_token_num,
return_tensors='pt')
tokens = {key: value.cuda() for key, value in tokens.items()}
class_emb, token_emb = self.forward_language(
(tokens['input_ids'], tokens['attention_mask']), norm=norm)
ret = {
'tokens': tokens,
'token_emb': token_emb,
'class_emb': class_emb,
}
setattr(self, '{}_token_embeddings'.format(name), ret)
return ret
def get_sot_token(self, device):
# 49406: CLIP SOT token <|startoftext|>
# 77: CLIP context_length
return torch.tensor([[49406] * 77], device=device)
def compute_similarity(self, v_emb, name='default'):
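        """Compute the scaled cosine similarity between visual embeddings
        and the cached ``{name}_text_embeddings``."""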
v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
t_emb = getattr(self, '{}_text_embeddings'.format(name))
output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(
1, 2)
return output
def forward_language(self,
texts,
norm=False,
with_token_embed=True,
with_cls_embed=True):
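        """Encode tokenized texts.

        The class embedding is taken from the hidden state at the EOS
        position (``input_ids.argmax(dim=-1)`` works because the EOS token
        has the largest id in the CLIP vocabulary); both the class and token
        embeddings are projected with ``lang_proj``.
        """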
x = self.lang_encoder(*texts)
hidden_x = x['last_hidden_state']
class_embed = None
if with_cls_embed:
class_embed = hidden_x[torch.arange(hidden_x.size(0)),
texts[0].argmax(dim=-1)]
class_embed = class_embed @ self.lang_proj
if norm:
class_embed = class_embed / (
class_embed.norm(dim=-1, keepdim=True) + 1e-7)
hidden_embed = None
if with_token_embed:
hidden_embed = hidden_x @ self.lang_proj
if norm:
hidden_embed = hidden_embed / (
hidden_embed.norm(dim=-1, keepdim=True) + 1e-7)
return class_embed, hidden_embed
class Transformer(nn.Module):
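    """CLIP-style text Transformer.

    A stack of pre-LN residual attention blocks with learned positional
    embeddings. When ``autoregressive`` is True a causal attention mask is
    applied; otherwise padding positions from ``attention_mask`` are ignored
    via ``key_padding_mask``.
    """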
def __init__(self,
context_length,
vocab_size,
width,
layers: int = 12,
heads: int = 8,
drop_path: float = 0.0,
                 autoregressive: bool = True):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, width)
self.context_length = context_length
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, width))
self.width = width
self.layers = layers
        self.autoregressive = autoregressive
        attn_mask = self.build_attention_mask() if autoregressive else None
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
self.resblocks = nn.ModuleList([
ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
for i in range(layers)
])
self.ln_final = LayerNorm(width)
@property
def dim_out(self):
return self.width
def build_attention_mask(self):
# lazily create causal attention mask,
# with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the diagonal and below; keep -inf above
return mask
def forward(self, input_ids, attention_mask=None):
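        """Run the text Transformer.

        Args:
            input_ids (Tensor): Token ids of shape (batch, context_length).
            attention_mask (Tensor, optional): 1 for real tokens, 0 for
                padding; only used in the non-autoregressive setting.
        """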
        key_padding_mask = (attention_mask == 0) if (
            not self.autoregressive and attention_mask is not None) else None
x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
for block in self.resblocks:
x = block(x, key_padding_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
return {'last_hidden_state': x}
class LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the
square root)."""
super(LayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
pdtype = x.dtype
x = x.float()
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x.to(pdtype) + self.bias
class QuickGELU(nn.Module):
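    """Sigmoid-based approximation of GELU used by CLIP:
    ``x * sigmoid(1.702 * x)``."""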
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
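    """Pre-LN Transformer block: multi-head self-attention followed by a 4x
    MLP with QuickGELU, each wrapped in a residual connection with optional
    stochastic depth (``DropPath``)."""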
def __init__(self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None,
drop_path: float = 0.0):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model))]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
def attention(self,
x: torch.Tensor,
key_padding_mask: torch.Tensor = None):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
if self.attn_mask is not None else None
return self.attn(
x,
x,
x,
key_padding_mask=key_padding_mask,
need_weights=False,
attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
x = x + self.drop_path(
self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
x = x + self.drop_path(self.mlp(self.ln_2(x)))
return x
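# Minimal usage sketch: assumes a CUDA device (tokenized inputs are moved to
# GPU inside get_text_embeds / get_mean_embeds) and assumes pretrained
# weights are loaded first, since lang_proj and the positional embedding are
# created with torch.empty.
#
#   encoder = LanguageEncoder().cuda()
#   encoder.get_mean_embeds(['cat', 'dog'], name='default')
#   out = encoder.get_text_embeds(['a photo of a cat'], name='grounding',
#                                 norm=True)
#   out['class_emb']  # (num_texts, dim_projection)
#   out['token_emb']  # (num_texts, max_token_num, dim_projection)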