# Extraction artifact (submission-viewer header, not code): "Spaces:", "Runtime error", "Runtime error".
| import copy | |
| import math | |
| from typing import Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from torch.nn import CrossEntropyLoss | |
| from transformers.activations import ACT2FN | |
| from einops import rearrange | |
| from transformers.models.t5.configuration_t5 import T5Config | |
| from transformers.modeling_utils import ModuleUtilsMixin | |
| from einops import rearrange, reduce | |
class FeedForward(nn.Module):
    """Pre-LayerNorm feed-forward sub-layer with a residual connection.

    Computes ``x + dropout(wo(dropout(gelu(wi(layer_norm(x))))))``.
    """

    def __init__(self, config: T5Config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN["gelu"]
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(self, x):
        # Pre-norm, expand to d_ff, activate, dropout, project back to d_model.
        hidden = self.layer_norm(x)
        hidden = self.act(self.wi(hidden))
        hidden = self.wo(self.dropout(hidden))
        # Residual connection around the whole sub-layer (second dropout on the branch).
        return x + self.dropout(hidden)
class Attention(nn.Module):
    """Multi-head attention in the T5 style (no 1/sqrt(d) score scaling).

    Acts as self-attention when ``x_kv`` is None and as cross-attention when
    ``x_kv`` supplies the key/value sequence. When ``has_relative_attention_bias``
    is True the layer owns a bucketed relative-position-bias embedding; the
    computed bias is returned so later layers can reuse it.
    """

    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        memory_position - query_position -> bucket_idx.
        If bidirectional=False, then positive relative positions are invalid.
        We use smaller buckets for small absolute relative_position
        and larger buckets for larger absolute relative_positions.
        * All relative positions >=max_distance map to the same bucket.
        * All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        # BUGFIX: this function takes no `self` but was declared as a plain
        # method; `compute_bias` calls it as `self._relative_position_bucket(...)`,
        # which previously raised a TypeError. It must be a @staticmethod.
        relative_buckets = 0
        if bidirectional:
            # Split buckets between positive (future) and negative (past) offsets.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) *
                                                  (num_buckets - max_exact)).to(torch.long)
        relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias, shape (1, num_heads, query_length, key_length)."""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x, mask=None, x_kv=None, pos_bias=None):
        """
        Self-attention (if x_kv is None) or cross-attention over x_kv.

        mask is an additive mask already broadcastable to
        (batch_size, n_heads, seq_length, key_length): 0 = attend,
        large negative = ignore. Returns (attn_output, pos_bias).
        """
        # Input is (batch_size, seq_length, dim)
        batch_size, seq_length = x.shape[:2]
        real_seq_length = seq_length
        key_length = real_seq_length if x_kv is None else x_kv.shape[1]
        reshape = lambda states: rearrange(states, 'b s (h d) -> b h s d', h=self.n_heads)
        unshape = lambda states: rearrange(states, 'b h s d -> b s (h d)')
        q = reshape(self.q(x))  # (batch_size, n_heads, seq_length, dim_per_head)
        k = reshape(self.k(x if x_kv is None else x_kv))
        v = reshape(self.v(x if x_kv is None else x_kv))
        # compute raw scores; no 1/sqrt(d) scaling by T5 convention (see class docstring)
        scores = torch.matmul(q, k.transpose(3, 2))
        if pos_bias is None:
            if not self.has_relative_attention_bias:
                # Layers without their own bias table contribute a zero bias.
                pos_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
            else:
                pos_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
            if mask is not None:
                # Fold the additive attention mask into the bias once.
                pos_bias = pos_bias + mask  # (batch_size, n_heads, seq_length, key_length)
        position_bias_masked = pos_bias
        scores += position_bias_masked
        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (B, H, seq_length, key_length)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = unshape(torch.matmul(attn_weights, v))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)
        return (attn_output, pos_bias)
class LayerSelfAttention(nn.Module):
    """Residual self-attention sub-layer: x + dropout(attn(layer_norm(x)))."""

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, mask=None, pos_bias=None):
        normed = self.layer_norm(x)
        attn_out, attn_bias = self.SelfAttention(normed, mask=mask, pos_bias=pos_bias)
        # Residual around the attention branch; return the (possibly newly
        # computed) position bias so later layers can reuse it.
        return (x + self.dropout(attn_out), attn_bias)
class LayerCrossAttention(nn.Module):
    """Residual cross-attention sub-layer: x + dropout(attn(layer_norm(x), x_kv))."""

    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = Attention(config, has_relative_attention_bias=False)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, x_kv, mask=None, pos_bias=None):
        normed = self.layer_norm(x)
        attn_out, attn_bias = self.EncDecAttention(normed, mask=mask, x_kv=x_kv, pos_bias=pos_bias)
        # Residual around the cross-attention branch; attn_bias is passed back
        # to the caller so it can be shared across layers.
        return (x + self.dropout(attn_out), attn_bias)
class Block(nn.Module):
    """One transformer block: self-attention, optional cross-attention
    (decoder with context only), then feed-forward."""

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(LayerCrossAttention(config))
        self.layer.append(FeedForward(config))

    def forward(self, x, mask=None, pos_bias=None, context=None, context_mask=None, context_pos_bias=None):
        hidden, self_bias = self.layer[0](x, mask=mask, pos_bias=pos_bias)
        cross_bias = None
        if self.is_decoder and context is not None:
            hidden, cross_bias = self.layer[1](
                hidden,
                x_kv=context,
                mask=context_mask,
                pos_bias=context_pos_bias,
            )
        # Feed-forward is always the last sub-layer regardless of decoder-ness.
        hidden = self.layer[-1](hidden)
        return (hidden, self_bias, cross_bias)
class Stack(nn.Module):
    """A stack of transformer Blocks (encoder- or decoder-style) with T5-style
    shared position biases.

    Only the first Block owns a relative-attention-bias embedding; the bias it
    computes on the first layer is threaded through and reused by every
    following layer.
    """

    def __init__(self, config, is_decoder=True, has_embedding=False, generate_causal_mask=False):
        super().__init__()
        self.config = config
        if has_embedding:
            # Token embedding only exists when this stack consumes raw input_ids.
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.is_decoder = is_decoder
        self.dtype = torch.float32
        self.generate_causal_mask = generate_causal_mask
        # bool(i == 0): only the first block carries the bias table.
        self.block = nn.ModuleList([Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)])
        self.final_layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        input_ids=None,
        dec_hidden_states=None,
        enc_hidden_states=None,
        dec_attention_mask=None,
        enc_attention_mask=None,
    ):
        """Run the stack and return a 1-tuple ``(hidden_states,)``.

        Provide either ``input_ids`` (requires has_embedding=True) or
        pre-computed ``dec_hidden_states``. ``enc_hidden_states`` enables
        cross-attention when this stack is a decoder. Missing masks default
        to "attend everywhere".
        """
        input_shape = input_ids.size() if input_ids is not None else dec_hidden_states.shape[:-1]
        batch_size, seq_length = input_shape
        if input_ids is not None:
            input_ids = input_ids.view(-1, input_shape[-1])
            inputs_embeds = self.embed_tokens(input_ids)
        else:
            inputs_embeds = dec_hidden_states
        mask_seq_length = seq_length
        if dec_attention_mask is None:
            dec_attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        if self.is_decoder and enc_attention_mask is None and enc_hidden_states is not None:
            encoder_seq_length = enc_hidden_states.shape[1]
            enc_attention_mask = torch.ones(batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long)
        # Broadcastable additive self-attention mask (0 = keep, large negative = drop);
        # causal if this stack was built with generate_causal_mask=True.
        extended_attention_mask = self.get_extended_attention_mask(dec_attention_mask, input_shape)
        # Cross-attention mask, only relevant for decoders with encoder context.
        if self.is_decoder and enc_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = enc_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if enc_attention_mask is None:
                enc_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(enc_attention_mask)
        else:
            encoder_extended_attention_mask = None
        pos_bias = None
        context_pos_bias = None
        hidden_states = self.dropout(inputs_embeds)
        for layer_module in self.block:
            hidden_states, pos_bias, layer_context_pos_bias = layer_module(
                hidden_states,
                mask=extended_attention_mask,
                pos_bias=pos_bias,
                context=enc_hidden_states,
                context_mask=encoder_extended_attention_mask,
                context_pos_bias=context_pos_bias,
            )
            # Position biases are computed by the first layer and shared with
            # the rest via pos_bias / context_pos_bias.
            if self.is_decoder and enc_hidden_states is not None:
                context_pos_bias = layer_context_pos_bias
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return (hidden_states,)

    def invert_attention_mask(self, attention_mask):
        """
        Input: 1 for attend, 0 for masked/ignored.
        Output: 0 for attend, a large negative number for masked/ignored,
        so the result can be added directly to attention logits.
        [B, L] -> [B, 1, 1, L]
        [B, L, L] -> [B, 1, L, L]
        """
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            # Previously an unexpected rank fell through to an UnboundLocalError.
            raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})")
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
        return extended_attention_mask

    def get_extended_attention_mask(self, attention_mask, input_shape, device=None, dtype=None):
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
        attention_mask: 1 for attend, 0 for masked/ignored
        Return: The extended attention mask: 0 for attend, large negative for masked/ignored
        [B, L] -> [B, 1, 1, L]
        [B, L, L] -> [B, 1, L, L]
        """
        # NOTE(review): defaulting to attention_mask.dtype assumes a floating
        # mask — torch.finfo below raises on integer dtypes. Callers in this
        # file pass float masks; confirm before passing long/int masks.
        dtype = dtype if dtype else attention_mask.dtype
        # If input [B, query_length, key_length] -> [B, 1, query_length, key_length]
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder and self.generate_causal_mask:
                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(input_shape, attention_mask, device)
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})")
        # Input: valid = 1, padding = 0
        # Output: valid = 0, padding = large negative
        # => then we can add it to the attention logits
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask
class Model(torch.nn.Module):
    """CLIP dual encoder wrapped with optional transformer decoder heads:
    image-text matching (ITM) decoders over text/image features and an
    optional generative caption decoder."""

    def __init__(self, clip_model, config):
        super().__init__()
        self.clip_model = clip_model
        self.config = config
        if self.config.has_extra_txt_decoder:
            # Text-side ITM: text features cross-attend to image features.
            self.txt_decoder = Stack(config.extra_decoder)
            self.itm_txt_head = torch.nn.Linear(config.extra_decoder.d_model, 2)
        if self.config.has_extra_img_decoder:
            # Image-side ITM: image features cross-attend to text features.
            self.img_decoder = Stack(config.extra_decoder)
            self.itm_img_head = torch.nn.Linear(config.extra_decoder.d_model, 2)
        if self.config.has_extra_mix_decoder:
            self.mix_decoder = Stack(config.extra_decoder)
            self.mix_itm_head = torch.nn.Linear(config.extra_decoder.d_model, 2)
        if self.config.has_extra_gen_decoder:
            # Generative decoder consumes token ids with a causal mask.
            self.gen_decoder = Stack(config.extra_decoder, has_embedding=True, generate_causal_mask=True)
            self.gen_head = torch.nn.Linear(config.extra_decoder.d_model, config.vocab_size)
        # (removed a duplicate `self.config = config` assignment)

    def img_forward(self, x: torch.Tensor):  # [N, 3, H, W]
        """Run the CLIP visual tower; returns (patch-token sequence, CLS token)."""
        x = self.clip_model.visual.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [self.clip_model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x],
            dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.clip_model.visual.positional_embedding.to(x.dtype)
        x = self.clip_model.visual.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.visual.ln_post(x)  # [NLD]
        if self.clip_model.visual.proj is not None:
            proj = self.clip_model.visual.proj[None, :, :]
            x = (x @ proj)
        cls_token = x[:, 0, :]
        return x, cls_token

    def txt_forward(self, text):
        """Run the CLIP text tower; returns (token sequence, EOT embedding)."""
        dtype = self.clip_model.dtype
        x = self.clip_model.token_embedding(text).type(dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.clip_model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(dtype)
        proj = self.clip_model.text_projection[None, :, :]
        x = (x @ proj)
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
        return x, eot  # [NLD]

    def var_img_forward(self, image):
        """Image encoding with L2-normalized CLS token; a 5-D input is treated
        as two views whose features are averaged."""
        if len(image.shape) == 5:
            img_features1, img_token1 = self.img_forward(image[:, 0, ...])
            img_features2, img_token2 = self.img_forward(image[:, 1, ...])
            img_token = (img_token1 + img_token2) / 2
            img_features = (img_features1 + img_features2) / 2
        else:
            img_features, img_token = self.img_forward(image)
        img_token = img_token / img_token.norm(dim=-1, keepdim=True)
        return img_features, img_token

    def var_txt_forward(self, text):
        """Text encoding with L2-normalized EOT token."""
        txt_features, txt_token = self.txt_forward(text)
        txt_token = txt_token / txt_token.norm(dim=-1, keepdim=True)
        return txt_features, txt_token

    def get_device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

    def get_features(self, image=None, text_ids=None):
        """Encode whichever of image / text_ids is given; returns a dict of
        features, pooled tokens and attention masks (text mask from pad id 0)."""
        outputs = {}
        if image is not None:
            img_features, img_token = self.var_img_forward(image)
            outputs['img_features'] = img_features
            outputs['img_token'] = img_token
            # All image tokens are valid.
            outputs['img_mask'] = torch.ones_like(img_features[:, :, 0])
        if text_ids is not None:
            txt_features, txt_token = self.var_txt_forward(text_ids)
            outputs['txt_features'] = txt_features
            outputs['txt_token'] = txt_token
            outputs['txt_mask'] = (text_ids != 0).to(txt_features.dtype)
        return outputs

    def get_prediction(self, img_features, txt_features, img_mask=None, txt_mask=None, decoder="txt_decoder", **kwargs):
        """Score image/text feature pairs with the requested ITM decoder head."""
        outputs = {}
        if decoder == 'txt_decoder':
            hidden_states = self.txt_decoder(
                dec_hidden_states=txt_features,
                enc_hidden_states=img_features,
                enc_attention_mask=img_mask,
                dec_attention_mask=txt_mask,
            )
            outputs['itm_txt_logits'] = self.itm_txt_head(hidden_states[0][:, 0, :])
            outputs['itm_txt_probs'] = torch.softmax(outputs['itm_txt_logits'], dim=-1)
        if decoder == 'img_decoder':
            hidden_states = self.img_decoder(
                dec_hidden_states=img_features,
                enc_hidden_states=txt_features,
                enc_attention_mask=txt_mask,
                dec_attention_mask=img_mask,
            )
            outputs['itm_img_logits'] = self.itm_img_head(hidden_states[0][:, 0, :])
            outputs['itm_img_probs'] = torch.softmax(outputs['itm_img_logits'], dim=-1)
        return outputs

    def forward(self, image, text, itm_text=None, itm_labels=None, gen_inputs=None, gen_labels=None):
        """Encode image/text pairs and run whichever extra decoders are enabled.

        itm_labels / gen_labels are accepted for interface compatibility but no
        loss is computed here — presumably the caller computes it from the
        returned logits (TODO confirm).
        """
        img_features, img_token = self.var_img_forward(image)
        txt_features, txt_token = self.var_txt_forward(text)
        outputs = dict(
            img_token=img_token,
            txt_token=txt_token,
            img_features=img_features,
            txt_features=txt_features,
        )
        itm_txt_features = None
        itm_txt_mask = None
        if itm_text is not None:
            # BUGFIX: these were computed unconditionally and crashed when
            # itm_text was None, even though every use below is guarded.
            itm_txt_features, _ = self.var_txt_forward(itm_text)
            itm_txt_mask = (itm_text != 0).to(itm_txt_features.dtype)
        if self.config.has_extra_txt_decoder and itm_text is not None:
            itm_img_features = img_features
            itm_txt_states = self.txt_decoder(
                dec_hidden_states=itm_txt_features,
                enc_hidden_states=itm_img_features,
                enc_attention_mask=None,
                dec_attention_mask=itm_txt_mask,
            )
            outputs['itm_txt_logits'] = self.itm_txt_head(itm_txt_states[0][:, 0])
        if self.config.has_extra_img_decoder and itm_text is not None:
            itm_img_features = img_features
            itm_img_states = self.img_decoder(
                dec_hidden_states=itm_img_features,
                enc_hidden_states=itm_txt_features,
                enc_attention_mask=itm_txt_mask,
                dec_attention_mask=None,
            )
            outputs['itm_img_logits'] = self.itm_img_head(itm_img_states[0][:, 0])
        if self.config.has_extra_mix_decoder:
            # Not implemented yet.
            pass
        if self.config.has_extra_gen_decoder:
            # BUGFIX: Stack.forward accepts no `labels` kwarg; passing
            # labels=gen_labels raised a TypeError at runtime.
            gen_features = self.gen_decoder(
                input_ids=gen_inputs,
                enc_hidden_states=img_features,
                enc_attention_mask=None,
                dec_attention_mask=None,
            )
            outputs['gen_logits'] = self.gen_head(gen_features[0])
        return outputs
if __name__ == "__main__":
    # Smoke test: build a Model around a CPU CLIP backbone using the example
    # decoder config. Requires the project repo on sys.path plus the
    # `omegaconf` and `clip` packages (environment-specific; the hard-coded
    # path below is the author's workstation).
    import sys
    from omegaconf import OmegaConf
    sys.path.append("/home/quang/workspace/traffic_var")
    from config.examples import with_decoder_config as config
    config.has_extra_txt_decoder = True
    print(OmegaConf.to_yaml(config))
    import clip
    def get_resolution(model):
        # Full CLIP models expose input_resolution on the visual tower;
        # vision-only modules expose it directly.
        return model.visual.input_resolution if hasattr(model, 'visual') else model.input_resolution
    model, _ = clip.load(config.clip_model, jit=False, device="cpu")
    config.img_size = get_resolution(model)
    model = Model(model, config)