TaskCLIP / models /TaskCLIP.py
HanningChen
Initial HF Space: FastAPI + HTML (no weights yet)
f2f112a
raw
history blame
4.43 kB
import torch
from .Transformer import TransformerDecoderLayer
from .Transformer import TransformerCrossDecoderLayer
from .Transformer import TransformerDecoder
from .ScoreFunction import ScoreFunction
from .Adapter import Adapter
from .ScoreFunction_HDC import ScoreFunctionHDC
class TaskCLIP(torch.nn.Module):
    """Scores candidate image embeddings against text embeddings.

    The module blends adapter outputs into the incoming embeddings, attends
    the candidates over a separate image embedding, runs a transformer
    decoder over the (candidate, text) pair and turns the resulting
    similarity matrix into scores with a pluggable score function.

    ``model_config`` is a mapping that must provide the keys
    ``num_layers``, ``return_intermediate``, ``d_model``, ``nhead``,
    ``dim_feedforward``, ``dropout``, ``N_words``, ``activation``,
    ``ratio_text``, ``ratio_image``, ``ratio_glob``, ``norm_before``,
    ``norm_after``, ``MAX_VAL``, ``MIN_VAL``, ``cross_attention`` and
    ``score_function``; ``HDV_D`` is additionally read when
    ``score_function`` equals ``'HDC'``.
    """

    def __init__(self,
                 model_config,
                 normalize_before=False,
                 device="cuda:1") -> None:
        """Build the adapters, the decoder stack and the score function.

        Args:
            model_config: Mapping with the keys listed in the class docstring.
            normalize_before: Forwarded to the decoder layer constructor.
            device: Only stored on the instance; nothing in this class moves
                parameters or tensors to it.
        """
        super().__init__()
        self.num_layers = model_config['num_layers']
        self.return_intermediate = model_config['return_intermediate']
        self.d_model = model_config['d_model']
        self.nhead = model_config['nhead']
        self.dim_feedforward = model_config['dim_feedforward']
        self.dropout = model_config['dropout']
        self.N_words = model_config['N_words']
        self.activation = model_config['activation']
        # Blend weights: 1.0 keeps only the adapted signal, 0.0 keeps only
        # the unmodified input (see the mixing expressions in ``forward``).
        self.ratio_text = model_config['ratio_text']
        self.ratio_image = model_config['ratio_image']
        self.ratio_glob = model_config['ratio_glob']
        self.norm_before = model_config['norm_before']
        self.norm_after = model_config['norm_after']
        # Target range used by ``Norm`` when ``norm_after`` is enabled.
        self.MAX_Val = model_config['MAX_VAL']
        self.MIN_Val = model_config['MIN_VAL']
        self.normalize_before = normalize_before
        self.device = device

        self.decoder_norm = torch.nn.LayerNorm(self.d_model)
        # Both layer variants take the same constructor arguments; only the
        # class differs.
        if model_config['cross_attention']:
            layer_cls = TransformerCrossDecoderLayer
        else:
            layer_cls = TransformerDecoderLayer
        self.decoder_layer = layer_cls(self.d_model,
                                       self.nhead,
                                       self.dim_feedforward,
                                       self.dropout,
                                       self.activation,
                                       self.normalize_before)
        self.decoder = TransformerDecoder(self.decoder_layer,
                                          self.num_layers,
                                          self.decoder_norm,
                                          return_intermediate=self.return_intermediate)

        self.vision_adapter = Adapter(self.d_model)
        self.text_adapter = Adapter(self.d_model)
        self.glob_adapter = torch.nn.MultiheadAttention(self.d_model,
                                                        self.nhead,
                                                        dropout=self.dropout)

        if model_config['score_function'] == 'HDC':
            self.ScoreFunction = ScoreFunctionHDC(N_words=self.N_words,
                                                  HDV_D=int(model_config['HDV_D']))
        else:
            self.ScoreFunction = ScoreFunction(N_words=self.N_words)
        # Stored for external use; not read anywhere in this class.
        self.threshold = 0.1

    def forward(self, tgt, memory, image_embedding, norm=False):
        """Score ``tgt`` embeddings against ``memory`` embeddings.

        Args:
            tgt: Candidate embeddings; blended with the vision adapter output
                and used as the attention query.
            memory: Text embeddings; blended with the text adapter output.
            image_embedding: Embedding used as key and value of the global
                attention step.
            norm: Unused; kept for backward compatibility with callers.

        Returns:
            Tuple ``(tgt_new, memory_new, score_res, score_raw)``: the two
            decoder outputs, the score function output, and the similarity
            matrix ``tgt_new @ memory_new.T`` (min-max rescaled when
            ``norm_after`` is set).
        """
        if self.norm_before:
            # Out-of-place on purpose: an in-place ``/=`` would overwrite the
            # caller's tensors and fails on leaf tensors that require grad.
            # ``normalize`` also clamps the norm, so all-zero rows stay zero
            # instead of becoming NaN.
            tgt = torch.nn.functional.normalize(tgt, dim=-1)
            memory = torch.nn.functional.normalize(memory, dim=-1)

        adapted_tgt = self.vision_adapter(tgt)
        tgt = self.ratio_image * adapted_tgt + (1 - self.ratio_image) * tgt

        adapted_image = self.vision_adapter(image_embedding)
        image_embedding_temp = (self.ratio_image * adapted_image
                                + (1 - self.ratio_image) * image_embedding)

        # NOTE(review): the text adapter output is indexed with [0] while the
        # vision adapter output above is used whole -- confirm against
        # Adapter's return value and the expected shape of ``memory``.
        adapted_text = self.text_adapter(memory)[0]
        memory = self.ratio_text * adapted_text + (1 - self.ratio_text) * memory

        # MultiheadAttention returns (attn_output, attn_weights); only the
        # attended output is blended back into the query.
        attended = self.glob_adapter(tgt, image_embedding_temp, image_embedding_temp)[0]
        tgt = self.ratio_glob * attended + (1 - self.ratio_glob) * tgt

        tgt_new, memory_new = self.decoder(tgt, memory, None)

        # torch.mm only accepts 2-D operands, so both decoder outputs must be
        # matrices at this point.
        score_raw = torch.mm(tgt_new, memory_new.T)
        if self.norm_after:
            score_raw = self.Norm(score_raw)
        score_res = self.ScoreFunction(score_raw)
        return tgt_new, memory_new, score_res, score_raw

    def Norm(self, score):
        """Min-max rescale ``score`` into ``[MIN_Val, MAX_Val]``.

        The minimum and maximum are taken over the whole tensor. When every
        entry is equal the span is zero; it is replaced by one so the result
        is ``MIN_Val`` everywhere instead of NaN.
        """
        min_val = score.min()
        max_val = score.max()
        span = max_val - min_val
        # Guard the 0/0 case without touching results for non-constant input.
        safe_span = torch.where(span == 0, torch.ones_like(span), span)
        res = self.MIN_Val + ((score - min_val) * (self.MAX_Val - self.MIN_Val)) / safe_span
        return res