import copy
import math
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from einops import rearrange
from transformers.models.t5.configuration_t5 import T5Config
from transformers.modeling_utils import ModuleUtilsMixin
from einops import rearrange, reduce


class FeedForward(nn.Module):
    """Pre-LayerNorm transformer feed-forward block with a residual connection.

    Computes ``x + dropout(wo(dropout(gelu(wi(layer_norm(x))))))``.
    Unlike stock T5 this uses ``nn.LayerNorm`` (with bias/mean-centering)
    and a fixed GELU activation rather than T5's RMSNorm / configured act.
    """

    def __init__(self, config: T5Config):
        super().__init__()
        # d_model -> d_ff -> d_model, both projections bias-free (T5 convention).
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN["gelu"]
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(self, x):
        # Pre-norm: normalize first, then project; dropout applied both inside
        # the MLP and on the residual branch.
        x_hidden = self.wo(self.dropout(self.act(self.wi(self.layer_norm(x)))))
        return x + self.dropout(x_hidden)


class Attention(nn.Module):
    """T5-style multi-head attention with optional relative position bias.

    Used both as self-attention (``x_kv is None``) and cross-attention
    (``x_kv`` provides keys/values). No softmax pre-scaling (Mesh-TF /
    T5 convention: the initialization absorbs the 1/sqrt(d) factor).
    """

    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        # Only the first layer of a Stack owns the bias table; other layers
        # reuse the bias tensor passed in via ``pos_bias``.
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            # One learned scalar per (bucket, head).
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        memory_position - query_position -> bucket_idx.

        If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions.

        * All relative positions >= max_distance map to the same bucket.
        * All relative positions <= -max_distance map to the same bucket.

        This should allow for more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range
            [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets handle positive offsets, half negative.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            # Causal case: clamp future (positive) offsets to 0, negate so the
            # remaining values are non-negative distances into the past.
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).to(torch.long)
        relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias.

        Returns a tensor of shape (1, num_heads, query_length, key_length)
        that is added to the raw attention scores.
        """
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x, mask=None, x_kv=None, pos_bias=None):
        """
        Self-attention (if x_kv is None) or attention over source sentence (provided by x_kv).

        NOTE(review): when ``mask`` is given, it is folded into ``pos_bias``
        (additive, already in logit space) and the *combined* tensor is
        returned — callers that reuse the returned bias on later layers rely
        on this, matching the HF T5 implementation.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        batch_size, seq_length = x.shape[:2]
        real_seq_length = seq_length
        key_length = real_seq_length if x_kv is None else x_kv.shape[1]

        reshape = lambda states: rearrange(states, 'b s (h d) -> b h s d', h=self.n_heads)
        unshape = lambda states: rearrange(states, 'b h s d -> b s (h d)')

        q = reshape(self.q(x))  # (batch_size, n_heads, seq_length, dim_per_head)
        k = reshape(self.k(x if x_kv is None else x_kv))
        v = reshape(self.v(x if x_kv is None else x_kv))

        # compute scores (no 1/sqrt(d) scaling — see class docstring)
        scores = torch.matmul(q, k.transpose(3, 2))

        if pos_bias is None:
            if not self.has_relative_attention_bias:
                # Layers without a bias table contribute a zero bias so the
                # mask can still be folded in uniformly.
                pos_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
            else:
                pos_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
            if mask is not None:
                # Additive mask: 0 where attending, large-negative where masked.
                pos_bias = pos_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        position_bias_masked = pos_bias
        scores += position_bias_masked
        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (B, H, seq_length, key_length)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)  # (B, H, seq_length, key_length)

        attn_output = unshape(torch.matmul(attn_weights, v))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)
        return (attn_output, pos_bias)


class LayerSelfAttention(nn.Module):
    """Residual pre-norm self-attention sublayer: x + drop(attn(ln(x)))."""

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, mask=None, pos_bias=None):
        # x + drop(attn(ln(x)))
        h = self.layer_norm(x)
        outputs = self.SelfAttention(h, mask=mask, pos_bias=pos_bias)
        x = x + self.dropout(outputs[0])
        return (x, outputs[1])  # outputs[1] is pos_bias


class LayerCrossAttention(nn.Module):
    """Residual pre-norm cross-attention sublayer: x + drop(attn(ln(x), x_kv))."""

    def __init__(self, config):
        super().__init__()
        # Cross-attention never owns a relative-bias table (T5 convention).
        self.EncDecAttention = Attention(config, has_relative_attention_bias=False)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, x_kv, mask=None, pos_bias=None):
        # x + drop(attn(ln(x), x_kv))
        h = self.layer_norm(x)
        outputs = self.EncDecAttention(h, mask=mask, x_kv=x_kv, pos_bias=pos_bias)
        x = x + self.dropout(outputs[0])
        return (x, outputs[1])  # outputs[1] is pos_bias


class Block(nn.Module):
    """One transformer block: self-attn [+ cross-attn if decoder] + FFN."""

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        # NOTE(review): reads config.is_decoder, not the is_decoder argument
        # of the enclosing Stack — the two can disagree; verify intent.
        self.layer = nn.ModuleList()
        self.layer.append(LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(LayerCrossAttention(config))
        self.layer.append(FeedForward(config))

    def forward(self, x, mask=None, pos_bias=None, context=None, context_mask=None, context_pos_bias=None):
        self_attention_outputs = self.layer[0](x, mask=mask, pos_bias=pos_bias)
        hidden_states = self_attention_outputs[0]

        do_cross_attention = self.is_decoder and context is not None
        if do_cross_attention:
            cross_attention_outputs = self.layer[1](
                hidden_states,
                x_kv=context,
                mask=context_mask,
                pos_bias=context_pos_bias,
            )
            hidden_states = cross_attention_outputs[0]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # Return the (mask-fused) position biases so the caller can reuse
        # them for subsequent layers.
        pos_bias = self_attention_outputs[1]
        context_pos_bias = cross_attention_outputs[1] if do_cross_attention else None
        return (hidden_states, pos_bias, context_pos_bias)


class Stack(nn.Module):
    """A stack of T5-style Blocks, usable as encoder or decoder.

    Only the first Block holds a relative-attention-bias table; its computed
    bias is threaded through the remaining layers.
    """

    def __init__(self, config, is_decoder=True, has_embedding=False, generate_causal_mask=False):
        super().__init__()
        self.config = config
        if has_embedding:
            # Token embedding only exists when this stack consumes input_ids.
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.is_decoder = is_decoder
        # Fixed dtype used when building additive masks; presumably fine
        # because masks are built from float ones() — confirm for fp16 runs.
        self.dtype = torch.float32
        self.generate_causal_mask = generate_causal_mask
        self.block = nn.ModuleList([Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)])
        self.final_layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        input_ids=None,
        dec_hidden_states=None,
        enc_hidden_states=None,
        dec_attention_mask=None,
        enc_attention_mask=None,
    ):
        """Run the stack.

        Exactly one of ``input_ids`` (token ids, embedded here) or
        ``dec_hidden_states`` (pre-computed embeddings) must be provided.
        Returns a 1-tuple ``(hidden_states,)``.
        """
        input_shape = input_ids.size() if input_ids is not None else dec_hidden_states.shape[:-1]
        batch_size, seq_length = input_shape

        if input_ids is not None:
            input_ids = input_ids.view(-1, input_shape[-1])
            inputs_embeds = self.embed_tokens(input_ids)
        else:
            inputs_embeds = dec_hidden_states

        # required mask seq length can be calculated via length of past
        mask_seq_length = seq_length

        if dec_attention_mask is None:
            # Default: attend everywhere (no padding).
            dec_attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        if self.is_decoder and enc_attention_mask is None and enc_hidden_states is not None:
            encoder_seq_length = enc_hidden_states.shape[1]
            enc_attention_mask = torch.ones(batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long)

        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length] ourselves in which case
        # we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(dec_attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and enc_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = enc_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if enc_attention_mask is None:
                enc_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(enc_attention_mask)
        else:
            encoder_extended_attention_mask = None

        pos_bias = None
        context_pos_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, layer_module in enumerate(self.block):
            layer_outputs = layer_module(
                hidden_states,
                mask=extended_attention_mask,  # [1, 1, 1, 1 ] [B, L]
                pos_bias=pos_bias,
                context=enc_hidden_states,
                context_mask=encoder_extended_attention_mask,
                context_pos_bias=context_pos_bias,
            )
            # layer_outputs is a tuple with:
            # Insert a None placeholder (would be the past key/value slot in
            # HF T5) so downstream indices match that layout.
            layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
            hidden_states, present_key_value_state = layer_outputs[:2]  # [B, L, D], None
            # We share the position biases between the layers - the first layer store them
            pos_bias = layer_outputs[2]  # [B, H, L, L]
            if self.is_decoder and enc_hidden_states is not None:
                context_pos_bias = layer_outputs[3]

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return (hidden_states,)

    def invert_attention_mask(self, attention_mask):
        """
        Input: 1 for attend, 0 for masked/ignored
        Output: 0 for attend, -1e30 for masked/ignored. Then we can add it to the attention logits.

        [B, L] -> [B, 1, 1, L]
        [B, L, L] -> [B, 1, L, L]
        """
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        if attention_mask.dim() == 2:
            extended_attention_mask = attention_mask[:, None, None, :]
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
        return extended_attention_mask

    def get_extended_attention_mask(self, attention_mask, input_shape, device=None, dtype=None):
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        attention_mask: 1 for attend, 0 for masked/ignored

        Return:
            The extended attention mask: 0 for attend, -1e30 for masked/ignored
            [B, L] -> [B, 1, 1, L]
            [B, L, L] -> [B, 1, L, L]
        """
        dtype = dtype if dtype else attention_mask.dtype
        # If input [B, query_length, key_length] -> [B, 1, query_length, key_length]
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder and self.generate_causal_mask:
                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(input_shape, attention_mask, device)
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})")

        # Input: valid = 1, padding = 0
        # Output: valid = 0, padding = -1e30
        # => then we can add it to the attention logits
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask
class Model(torch.nn.Module):
    """CLIP backbone plus optional T5-style decoder heads.

    Wraps a CLIP model and, depending on config flags, attaches extra
    decoders (Stack) for image-text-matching (ITM) classification and
    caption generation.
    """

    def __init__(self, clip_model, config):
        super().__init__()
        self.clip_model = clip_model
        self.config = config  # stored once (was assigned twice before)
        if self.config.has_extra_txt_decoder:
            # Text features as queries, image features as cross-attn context.
            self.txt_decoder = Stack(config.extra_decoder)
            self.itm_txt_head = torch.nn.Linear(config.extra_decoder.d_model, 2)
        if self.config.has_extra_img_decoder:
            # Image features as queries, text features as cross-attn context.
            self.img_decoder = Stack(config.extra_decoder)
            self.itm_img_head = torch.nn.Linear(config.extra_decoder.d_model, 2)
        if self.config.has_extra_mix_decoder:
            self.mix_decoder = Stack(config.extra_decoder)
            self.mix_itm_head = torch.nn.Linear(config.extra_decoder.d_model, 2)
        if self.config.has_extra_gen_decoder:
            # Causal decoder with its own token embedding, for generation.
            self.gen_decoder = Stack(config.extra_decoder, has_embedding=True, generate_causal_mask=True)
            self.gen_head = torch.nn.Linear(config.extra_decoder.d_model, config.vocab_size)

    def img_forward(self, x: torch.Tensor):
        """Run CLIP's ViT and return (all patch tokens, CLS token).

        Input: [N, 3, H, W] images; output tokens are projected by
        ``visual.proj`` when present.
        """
        x = self.clip_model.visual.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Prepend the class embedding token.
        x = torch.cat(
            [self.clip_model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x],
            dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.clip_model.visual.positional_embedding.to(x.dtype)
        x = self.clip_model.visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.clip_model.visual.ln_post(x)  # [NLD]
        if self.clip_model.visual.proj is not None:
            proj = self.clip_model.visual.proj[None, :, :]
            x = (x @ proj)
        cls_token = x[:, 0, :]
        return x, cls_token

    def txt_forward(self, text):
        """Run CLIP's text transformer; return (all tokens, EOT token)."""
        dtype = self.clip_model.dtype
        x = self.clip_model.token_embedding(text).type(dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.clip_model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(dtype)
        proj = self.clip_model.text_projection[None, :, :]
        x = (x @ proj)
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
        return x, eot  # [NLD]

    def var_img_forward(self, image):
        """img_forward with support for a 2-view input [N, 2, 3, H, W]:
        features/tokens of the two views are averaged. The CLS token is
        L2-normalized."""
        if len(image.shape) == 5:
            img_features1, img_token1 = self.img_forward(image[:, 0, ...])
            img_features2, img_token2 = self.img_forward(image[:, 1, ...])
            img_token = (img_token1 + img_token2) / 2
            img_features = (img_features1 + img_features2) / 2
        else:
            img_features, img_token = self.img_forward(image)
        img_token = img_token / img_token.norm(dim=-1, keepdim=True)
        return img_features, img_token

    def var_txt_forward(self, text):
        """txt_forward with an L2-normalized EOT token."""
        txt_features, txt_token = self.txt_forward(text)
        txt_token = txt_token / txt_token.norm(dim=-1, keepdim=True)
        return txt_features, txt_token

    def get_device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

    def get_features(self, image=None, text_ids=None):
        """Encode whichever modalities are given; returns a dict of
        features, pooled tokens, and attention masks (text mask assumes
        pad id 0)."""
        outputs = {}
        if image is not None:
            img_features, img_token = self.var_img_forward(image)
            outputs['img_features'] = img_features
            outputs['img_token'] = img_token
            # All image tokens are valid (no padding).
            outputs['img_mask'] = torch.ones_like(img_features[:, :, 0])
        if text_ids is not None:
            txt_features, txt_token = self.var_txt_forward(text_ids)
            outputs['txt_features'] = txt_features
            outputs['txt_token'] = txt_token
            outputs['txt_mask'] = (text_ids != 0).to(txt_features.dtype)
        return outputs

    def get_prediction(self, img_features, txt_features, img_mask=None, txt_mask=None, decoder="txt_decoder", **kwargs):
        """Run one ITM decoder on precomputed features; returns logits and
        softmax probabilities keyed by decoder name."""
        outputs = {}
        if decoder == 'txt_decoder':
            hidden_states = self.txt_decoder(
                dec_hidden_states=txt_features,
                enc_hidden_states=img_features,
                enc_attention_mask=img_mask,
                dec_attention_mask=txt_mask,
            )
            # Classify from the first (CLS-position) token.
            outputs['itm_txt_logits'] = self.itm_txt_head(hidden_states[0][:, 0, :])
            outputs['itm_txt_probs'] = torch.softmax(outputs['itm_txt_logits'], dim=-1)
        if decoder == 'img_decoder':
            hidden_states = self.img_decoder(
                dec_hidden_states=img_features,
                enc_hidden_states=txt_features,
                enc_attention_mask=txt_mask,
                dec_attention_mask=img_mask,
            )
            outputs['itm_img_logits'] = self.itm_img_head(hidden_states[0][:, 0, :])
            outputs['itm_img_probs'] = torch.softmax(outputs['itm_img_logits'], dim=-1)
        return outputs

    def forward(self, image, text, itm_text=None, itm_labels=None, gen_inputs=None, gen_labels=None):
        """Full training forward: CLIP features plus any enabled decoder heads.

        ``itm_labels``/``gen_labels`` are accepted for API compatibility but
        no loss is computed here.
        """
        img_features, img_token = self.var_img_forward(image)
        txt_features, txt_token = self.var_txt_forward(text)

        # BUG FIX: itm_text defaults to None but was previously embedded
        # unconditionally, crashing on the default. Only encode when given.
        itm_txt_features = None
        itm_txt_mask = None
        if itm_text is not None:
            itm_txt_features, _ = self.var_txt_forward(itm_text)
            itm_txt_mask = (itm_text != 0).to(itm_txt_features.dtype)

        outputs = dict(
            img_token=img_token,
            txt_token=txt_token,
            img_features=img_features,
            txt_features=txt_features,
        )

        if self.config.has_extra_txt_decoder and itm_text is not None:
            itm_img_features = img_features
            itm_txt_states = self.txt_decoder(
                dec_hidden_states=itm_txt_features,
                enc_hidden_states=itm_img_features,
                enc_attention_mask=None,
                dec_attention_mask=itm_txt_mask,
            )
            outputs['itm_txt_logits'] = self.itm_txt_head(itm_txt_states[0][:, 0])

        if self.config.has_extra_img_decoder and itm_text is not None:
            itm_img_features = img_features
            itm_img_states = self.img_decoder(
                dec_hidden_states=itm_img_features,
                enc_hidden_states=itm_txt_features,
                enc_attention_mask=itm_txt_mask,
                dec_attention_mask=None,
            )
            outputs['itm_img_logits'] = self.itm_img_head(itm_img_states[0][:, 0])

        if self.config.has_extra_mix_decoder:
            pass

        if self.config.has_extra_gen_decoder:
            # BUG FIX: Stack.forward has no `labels` parameter; passing
            # labels=gen_labels raised TypeError. Logits only here; the
            # caller computes the generation loss from gen_labels.
            gen_features = self.gen_decoder(
                input_ids=gen_inputs,
                enc_hidden_states=img_features,
                enc_attention_mask=None,
                dec_attention_mask=None,
            )
            outputs['gen_logits'] = self.gen_head(gen_features[0])
        return outputs


if __name__ == "__main__":
    import sys
    from omegaconf import OmegaConf
    sys.path.append("/home/quang/workspace/traffic_var")
    from config.examples import with_decoder_config as config
    config.has_extra_txt_decoder = True
    print(OmegaConf.to_yaml(config))
    import clip

    def get_resolution(model):
        # ViT/ResNet CLIP variants expose the resolution in different places.
        return model.visual.input_resolution if hasattr(model, 'visual') else model.input_resolution

    model, _ = clip.load(config.clip_model, jit=False, device="cpu")
    config.img_size = get_resolution(model)
    model = Model(model, config)