Evaluation_Server / models /fused_model.py
DHPR's picture
Upload 25 files
f638d9c
import copy
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from einops import rearrange
from transformers.models.t5.configuration_t5 import T5Config
from transformers.modeling_utils import ModuleUtilsMixin
from einops import rearrange, reduce
class FeedForward(nn.Module):
    """Pre-norm position-wise feed-forward sub-layer with a residual connection.

    Computes ``x + dropout(wo(dropout(gelu(wi(layer_norm(x))))))``. Both
    projections are bias-free.
    """

    def __init__(self, config: T5Config):
        super().__init__()
        d_model, d_ff = config.d_model, config.d_ff
        self.wi = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        # NOTE: the activation is fixed to GELU; config.feed_forward_proj is not consulted.
        self.act = ACT2FN["gelu"]
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        normed = self.layer_norm(x)
        hidden = self.dropout(self.act(self.wi(normed)))
        # The same dropout module is applied to the expanded and projected activations.
        return x + self.dropout(self.wo(hidden))
class Attention(nn.Module):
    """T5-style multi-head attention with an optional bucketed relative-position bias.

    Performs self-attention when ``x_kv`` is None, otherwise attends from ``x``
    to ``x_kv``. Projections are bias-free and the logits are not scaled by
    1/sqrt(d_kv) before the softmax.
    """

    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        # A probability consumed by F.dropout in forward (not an nn.Dropout module).
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            # One learned scalar per (bucket, head).
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """Map ``memory_position - query_position`` offsets to bucket indices.

        Small distances get one bucket each; larger distances share
        logarithmically wider buckets, and every distance >= ``max_distance``
        lands in the last bucket. With ``bidirectional=True`` half of the
        buckets are reserved for positive offsets; otherwise positive offsets
        are treated as distance 0.

        Args:
            relative_position: integer tensor of offsets.
            bidirectional: whether attention can look in both directions.
            num_buckets: total number of buckets.
            max_distance: distance beyond which all offsets share one bucket.

        Returns:
            Integer tensor of the same shape with values in ``[0, num_buckets)``.
        """
        buckets = 0
        if bidirectional:
            num_buckets //= 2
            buckets += (relative_position > 0).to(torch.long) * num_buckets
            distance = torch.abs(relative_position)
        else:
            distance = -torch.min(relative_position, torch.zeros_like(relative_position))
        # distance is now non-negative

        # The first half of the buckets hold exact distances.
        max_exact = num_buckets // 2
        is_small = distance < max_exact

        # The second half are log-spaced up to max_distance. For small distances
        # (including 0, where log gives -inf) the value is discarded by torch.where.
        log_ratio = torch.log(distance.float() / max_exact) / math.log(max_distance / max_exact)
        large_bucket = max_exact + (log_ratio * (num_buckets - max_exact)).to(torch.long)
        large_bucket = torch.min(large_bucket, torch.full_like(large_bucket, num_buckets - 1))

        buckets += torch.where(is_small, distance, large_bucket)
        return buckets

    def compute_bias(self, query_length, key_length, device=None):
        """Return the learned relative-position bias, shaped (1, n_heads, query_length, key_length)."""
        device = self.relative_attention_bias.weight.device if device is None else device
        query_pos = torch.arange(query_length, dtype=torch.long, device=device)
        key_pos = torch.arange(key_length, dtype=torch.long, device=device)
        relative_position = key_pos[None, :] - query_pos[:, None]  # (query_length, key_length)
        bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=not self.is_decoder,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        bias = self.relative_attention_bias(bucket)  # (query_length, key_length, n_heads)
        return bias.permute([2, 0, 1]).unsqueeze(0)

    def forward(self, x, mask=None, x_kv=None, pos_bias=None):
        """Attend from ``x`` to ``x_kv``, or to ``x`` itself when ``x_kv`` is None.

        Args:
            x: query states (batch, seq_length, d_model).
            mask: optional additive mask, broadcastable to
                (batch, n_heads, seq_length, key_length). It is folded into the
                bias only when the bias is computed here (``pos_bias`` is None).
            x_kv: optional key/value states (batch, key_length, d_model).
            pos_bias: optional precomputed additive bias (e.g. from an earlier
                layer). When given it is used as-is, without re-adding ``mask``.

        Returns:
            ``(attn_output, pos_bias)``: the output (batch, seq_length, d_model)
            and the additive bias that was applied to the logits.
        """
        source = x if x_kv is None else x_kv
        seq_length = x.shape[1]
        key_length = source.shape[1]

        q = rearrange(self.q(x), 'b s (h d) -> b h s d', h=self.n_heads)
        k = rearrange(self.k(source), 'b s (h d) -> b h s d', h=self.n_heads)
        v = rearrange(self.v(source), 'b s (h d) -> b h s d', h=self.n_heads)

        # Raw dot products; no 1/sqrt(d_kv) scaling.
        scores = q @ k.transpose(3, 2)  # (B, H, seq_length, key_length)

        if pos_bias is None:
            if self.has_relative_attention_bias:
                pos_bias = self.compute_bias(seq_length, key_length, device=scores.device)
            else:
                pos_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
            if mask is not None:
                pos_bias = pos_bias + mask  # (B, H, seq_length, key_length)

        scores += pos_bias
        # Softmax in float32 for stability, then cast back to the logits' dtype.
        attn_weights = F.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
        merged = rearrange(attn_weights @ v, 'b h s d -> b s (h d)')
        return (self.o(merged), pos_bias)
class LayerSelfAttention(nn.Module):
    """Pre-norm self-attention sub-layer: ``x + dropout(attn(layer_norm(x)))``."""

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, mask=None, pos_bias=None):
        """Return ``(updated_x, pos_bias)``; ``pos_bias`` is the bias the attention applied."""
        attn_out, used_bias = self.SelfAttention(self.layer_norm(x), mask=mask, pos_bias=pos_bias)
        return (x + self.dropout(attn_out), used_bias)
class LayerCrossAttention(nn.Module):
    """Pre-norm cross-attention sub-layer: ``x + dropout(attn(layer_norm(x), x_kv))``.

    Only the queries are normalized; ``x_kv`` is consumed as given.
    """

    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = Attention(config, has_relative_attention_bias=False)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, x_kv, mask=None, pos_bias=None):
        """Return ``(updated_x, pos_bias)``; ``pos_bias`` is the bias the attention applied."""
        normed = self.layer_norm(x)
        attn_out, used_bias = self.EncDecAttention(normed, mask=mask, x_kv=x_kv, pos_bias=pos_bias)
        return (x + self.dropout(attn_out), used_bias)
class Block(nn.Module):
    """One transformer layer: self-attention, optional cross-attention, feed-forward.

    The cross-attention sub-layer is built only when ``config.is_decoder`` is
    true; ``self.layer`` therefore holds 3 modules for decoders and 2 otherwise.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        sublayers = [LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)]
        if self.is_decoder:
            sublayers.append(LayerCrossAttention(config))
        sublayers.append(FeedForward(config))
        self.layer = nn.ModuleList(sublayers)

    def forward(self, x, mask=None, pos_bias=None, context=None, context_mask=None, context_pos_bias=None):
        """Run the layer and return ``(hidden_states, pos_bias, context_pos_bias)``.

        ``context_pos_bias`` is None whenever cross-attention is skipped
        (non-decoder block, or ``context`` is None).
        """
        hidden_states, pos_bias = self.layer[0](x, mask=mask, pos_bias=pos_bias)

        if self.is_decoder and context is not None:
            hidden_states, context_pos_bias = self.layer[1](
                hidden_states,
                x_kv=context,
                mask=context_mask,
                pos_bias=context_pos_bias,
            )
        else:
            context_pos_bias = None

        # The feed-forward sub-layer is always last.
        hidden_states = self.layer[-1](hidden_states)
        return (hidden_states, pos_bias, context_pos_bias)
class Stack(nn.Module):
    """A stack of T5-style ``Block`` layers followed by a final LayerNorm.

    Consumes either precomputed features (``dec_hidden_states``) or token ids
    (``input_ids``; requires ``has_embedding=True``) and, when ``is_decoder`` is
    true, cross-attends to ``enc_hidden_states``. Only the first block owns a
    relative-position bias; the bias it returns is reused by all later blocks.
    """

    def __init__(self, config, is_decoder=True, has_embedding=False, generate_causal_mask=False):
        super().__init__()
        self.config = config
        if has_embedding:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        # NOTE(review): Block reads config.is_decoder, not this flag, to decide
        # whether it has a cross-attention sub-layer — keep the two consistent.
        self.is_decoder = is_decoder
        # Float dtype used to build additive attention masks.
        self.dtype = torch.float32
        self.generate_causal_mask = generate_causal_mask
        self.block = nn.ModuleList([Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)])
        self.final_layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        input_ids=None,
        dec_hidden_states=None,
        enc_hidden_states=None,
        dec_attention_mask=None,
        enc_attention_mask=None,
    ):
        """Run all blocks.

        Args:
            input_ids: optional token ids (batch, seq_length); takes precedence
                over ``dec_hidden_states``.
            dec_hidden_states: optional input features (batch, seq_length, d_model).
            enc_hidden_states: optional states to cross-attend to (batch, enc_length, d_model).
            dec_attention_mask: 1 = attend, 0 = ignore; defaults to all ones.
            enc_attention_mask: 1 = attend, 0 = ignore; defaults to all ones.

        Returns:
            A 1-tuple ``(hidden_states,)`` with the final normalized outputs.

        Raises:
            ValueError: if neither ``input_ids`` nor ``dec_hidden_states`` is given.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        elif dec_hidden_states is not None:
            input_shape = dec_hidden_states.shape[:-1]
        else:
            raise ValueError("Either input_ids or dec_hidden_states must be provided")
        batch_size, seq_length = input_shape

        if input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids.view(-1, input_shape[-1]))
        else:
            inputs_embeds = dec_hidden_states
        device = inputs_embeds.device

        if dec_attention_mask is None:
            dec_attention_mask = torch.ones(batch_size, seq_length, device=device)
        # [B, L] or [B, L, L] -> additive mask broadcastable over heads (plus a
        # causal mask when generate_causal_mask is set).
        extended_attention_mask = self.get_extended_attention_mask(dec_attention_mask, input_shape)

        encoder_extended_attention_mask = None
        if self.is_decoder and enc_hidden_states is not None:
            if enc_attention_mask is None:
                encoder_seq_length = enc_hidden_states.shape[1]
                enc_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
            encoder_extended_attention_mask = self.invert_attention_mask(enc_attention_mask)

        pos_bias = None
        context_pos_bias = None
        hidden_states = self.dropout(inputs_embeds)
        for layer_module in self.block:
            hidden_states, pos_bias, layer_context_pos_bias = layer_module(
                hidden_states,
                mask=extended_attention_mask,
                pos_bias=pos_bias,
                context=enc_hidden_states,
                context_mask=encoder_extended_attention_mask,
                context_pos_bias=context_pos_bias,
            )
            # Position biases (with the masks already folded in) computed by the
            # first block are shared by all following blocks.
            if self.is_decoder and enc_hidden_states is not None:
                context_pos_bias = layer_context_pos_bias

        hidden_states = self.final_layer_norm(hidden_states)
        return (self.dropout(hidden_states),)

    def invert_attention_mask(self, attention_mask):
        """Turn a 1/0 keep mask into an additive mask in ``self.dtype``.

        Input: 1 for attend, 0 for masked/ignored.
        Output: 0 for attend, ``finfo(self.dtype).min`` for masked/ignored.
        [B, L] -> [B, 1, 1, L]; [B, L, L] -> [B, 1, L, L].

        Raises:
            ValueError: if the mask is not 2-D or 3-D.
        """
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape}); expected 2 or 3 dims")
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        return (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

    def get_extended_attention_mask(self, attention_mask, input_shape, device=None, dtype=None):
        """Make a broadcastable additive mask so that masked (and, optionally, future) tokens are ignored.

        attention_mask: 1 for attend, 0 for masked/ignored.
        Returns 0 for attend and ``finfo(dtype).min`` for masked/ignored.
        [B, L] -> [B, 1, 1, L] (or [B, 1, L, L] with a causal mask); [B, L, L] -> [B, 1, L, L].

        ``dtype`` defaults to the mask's own dtype when it is floating point and
        to ``self.dtype`` otherwise, since ``torch.finfo`` rejects integer/bool dtypes.

        Raises:
            ValueError: if the mask is not 2-D or 3-D.
        """
        if dtype is None:
            dtype = attention_mask.dtype if attention_mask.is_floating_point() else self.dtype

        if attention_mask.dim() == 3:
            # [B, query_length, key_length] -> [B, 1, query_length, key_length]
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Padding mask [batch_size, seq_length]: combine with a causal mask
            # when requested, otherwise broadcast over heads and query positions.
            if self.config.is_decoder and self.generate_causal_mask:
                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(input_shape, attention_mask, device)
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})")

        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        return (1.0 - extended_attention_mask) * torch.finfo(dtype).min
class Model(torch.nn.Module):
    """CLIP backbone plus optional T5-style decoders for image-text matching and generation.

    Which decoders exist is controlled by the ``config.has_extra_*_decoder`` flags;
    each is a ``Stack`` built from ``config.extra_decoder``:

    * ``txt_decoder`` + ``itm_txt_head``: text tokens cross-attending to image tokens.
    * ``img_decoder`` + ``itm_img_head``: image tokens cross-attending to text tokens.
    * ``mix_decoder`` + ``mix_itm_head``: built, but not used by ``forward``.
    * ``gen_decoder`` + ``gen_head``: token-id decoder (causal mask) over image tokens.
    """

    def __init__(self, clip_model, config):
        super().__init__()
        self.clip_model = clip_model
        self.config = config
        # ITM heads are 2-way classifiers applied to the decoder's first output position.
        if self.config.has_extra_txt_decoder:
            self.txt_decoder = Stack(config.extra_decoder)
            self.itm_txt_head = torch.nn.Linear(config.extra_decoder.d_model, 2)
        if self.config.has_extra_img_decoder:
            self.img_decoder = Stack(config.extra_decoder)
            self.itm_img_head = torch.nn.Linear(config.extra_decoder.d_model, 2)
        if self.config.has_extra_mix_decoder:
            self.mix_decoder = Stack(config.extra_decoder)
            self.mix_itm_head = torch.nn.Linear(config.extra_decoder.d_model, 2)
        if self.config.has_extra_gen_decoder:
            self.gen_decoder = Stack(config.extra_decoder, has_embedding=True, generate_causal_mask=True)
            self.gen_head = torch.nn.Linear(config.extra_decoder.d_model, config.vocab_size)

    def img_forward(self, x: torch.Tensor):  # [N, 3, 224, 224]
        """Run the components of ``clip_model.visual`` step by step, keeping every token.

        Returns:
            ``(features, cls_token)``: ``features`` is [N, L, D] with the class
            token at position 0 (projected by ``visual.proj`` when it is not
            None); ``cls_token`` is ``features[:, 0, :]``.
        """
        visual = self.clip_model.visual
        x = visual.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        cls = visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = visual.ln_post(x)  # [NLD]
        if visual.proj is not None:
            # Project every token, not only the class token.
            x = x @ visual.proj[None, :, :]
        return x, x[:, 0, :]

    def txt_forward(self, text):
        """Run CLIP's text tower on token ids, keeping every token.

        Returns:
            ``(features, eot)``: projected features [N, L, D] and the feature at
            each row's ``argmax`` token id (the end-of-text token).
        """
        dtype = self.clip_model.dtype
        x = self.clip_model.token_embedding(text).type(dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.clip_model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(dtype)
        x = x @ self.clip_model.text_projection[None, :, :]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
        return x, eot  # [NLD]

    def var_img_forward(self, image):
        """Encode images and L2-normalize the pooled token.

        A 5-D input is treated as paired views: ``image[:, 0]`` and
        ``image[:, 1]`` are encoded separately and their features and tokens
        averaged. NOTE(review): any views beyond the second are ignored.
        """
        if len(image.shape) == 5:
            img_features1, img_token1 = self.img_forward(image[:, 0, ...])
            img_features2, img_token2 = self.img_forward(image[:, 1, ...])
            img_token = (img_token1 + img_token2) / 2
            img_features = (img_features1 + img_features2) / 2
        else:
            img_features, img_token = self.img_forward(image)
        img_token = img_token / img_token.norm(dim=-1, keepdim=True)
        return img_features, img_token

    def var_txt_forward(self, text):
        """Encode token ids and L2-normalize the pooled (end-of-text) token."""
        txt_features, txt_token = self.txt_forward(text)
        txt_token = txt_token / txt_token.norm(dim=-1, keepdim=True)
        return txt_features, txt_token

    def get_device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def get_features(self, image=None, text_ids=None):
        """Encode whichever of ``image`` / ``text_ids`` is given.

        Returns a dict with ``img_features``/``img_token``/``img_mask`` (mask
        all ones) and/or ``txt_features``/``txt_token``/``txt_mask`` (1 where
        the token id is non-zero).
        """
        outputs = {}
        if image is not None:
            img_features, img_token = self.var_img_forward(image)
            outputs['img_features'] = img_features
            outputs['img_token'] = img_token
            outputs['img_mask'] = torch.ones_like(img_features[:, :, 0])
        if text_ids is not None:
            txt_features, txt_token = self.var_txt_forward(text_ids)
            outputs['txt_features'] = txt_features
            outputs['txt_token'] = txt_token
            outputs['txt_mask'] = (text_ids != 0).to(txt_features.dtype)
        return outputs

    def get_prediction(self, img_features, txt_features, img_mask=None, txt_mask=None, decoder="txt_decoder", **kwargs):
        """Score image-text pairs with one ITM decoder.

        ``decoder`` selects ``'txt_decoder'`` (text attends to image) or
        ``'img_decoder'`` (image attends to text); any other value yields an
        empty dict. Returns ITM logits and softmax probabilities.
        """
        outputs = {}
        if decoder == 'txt_decoder':
            hidden_states = self.txt_decoder(
                dec_hidden_states=txt_features,
                enc_hidden_states=img_features,
                enc_attention_mask=img_mask,
                dec_attention_mask=txt_mask,
            )
            outputs['itm_txt_logits'] = self.itm_txt_head(hidden_states[0][:, 0, :])
            outputs['itm_txt_probs'] = torch.softmax(outputs['itm_txt_logits'], dim=-1)
        if decoder == 'img_decoder':
            hidden_states = self.img_decoder(
                dec_hidden_states=img_features,
                enc_hidden_states=txt_features,
                enc_attention_mask=txt_mask,
                dec_attention_mask=img_mask,
            )
            outputs['itm_img_logits'] = self.itm_img_head(hidden_states[0][:, 0, :])
            outputs['itm_img_probs'] = torch.softmax(outputs['itm_img_logits'], dim=-1)
        return outputs

    def forward(self, image, text, itm_text=None, itm_labels=None, gen_inputs=None, gen_labels=None):
        """Encode ``image``/``text`` and run whichever extra decoders are enabled.

        Args:
            image: image batch passed to ``var_img_forward``.
            text: token ids for the contrastive text branch.
            itm_text: optional token ids for image-text matching; the ITM
                decoders run only when it is given.
            itm_labels: accepted for interface compatibility; unused here.
            gen_inputs: optional token ids for the generation decoder; it runs
                only when it is given.
            gen_labels: accepted for interface compatibility; unused here
                (``Stack.forward`` takes no labels, so no loss is computed).

        Returns:
            Dict with ``img_token``, ``txt_token``, ``img_features``,
            ``txt_features`` and, when computed, ``itm_txt_logits``,
            ``itm_img_logits`` and ``gen_logits``.
        """
        img_features, img_token = self.var_img_forward(image)
        txt_features, txt_token = self.var_txt_forward(text)
        outputs = dict(
            img_token=img_token,
            txt_token=txt_token,
            img_features=img_features,
            txt_features=txt_features,
        )

        if itm_text is not None:
            itm_txt_features, _ = self.var_txt_forward(itm_text)
            # 1 for real tokens, 0 where the token id is 0.
            itm_txt_mask = (itm_text != 0).to(itm_txt_features.dtype)

            if self.config.has_extra_txt_decoder:
                itm_txt_states = self.txt_decoder(
                    dec_hidden_states=itm_txt_features,
                    enc_hidden_states=img_features,
                    enc_attention_mask=None,
                    dec_attention_mask=itm_txt_mask,
                )
                outputs['itm_txt_logits'] = self.itm_txt_head(itm_txt_states[0][:, 0])

            if self.config.has_extra_img_decoder:
                itm_img_states = self.img_decoder(
                    dec_hidden_states=img_features,
                    enc_hidden_states=itm_txt_features,
                    enc_attention_mask=itm_txt_mask,
                    dec_attention_mask=None,
                )
                outputs['itm_img_logits'] = self.itm_img_head(itm_img_states[0][:, 0])

        # NOTE: the mix decoder (has_extra_mix_decoder) is built in __init__ but not wired in here.

        if self.config.has_extra_gen_decoder and gen_inputs is not None:
            gen_features = self.gen_decoder(
                input_ids=gen_inputs,
                enc_hidden_states=img_features,
                enc_attention_mask=None,
                dec_attention_mask=None,
            )
            outputs['gen_logits'] = self.gen_head(gen_features[0])
        return outputs
if __name__ == "__main__":
    # Ad-hoc smoke run: load a CPU CLIP checkpoint and wrap it in Model.
    import sys
    from omegaconf import OmegaConf

    # NOTE(review): machine-specific path used to make the project `config` package importable.
    sys.path.append("/home/quang/workspace/traffic_var")
    from config.examples import with_decoder_config as config

    config.has_extra_txt_decoder = True
    print(OmegaConf.to_yaml(config))

    import clip

    def get_resolution(model):
        # Use the visual tower's resolution when present, else the model's own.
        return getattr(model, 'visual', model).input_resolution

    model, _ = clip.load(config.clip_model, jit=False, device="cpu")
    config.img_size = get_resolution(model)
    model = Model(model, config)