import torch

from .Transformer import TransformerDecoderLayer
from .Transformer import TransformerDecoder
from .ScoreFunction import ScoreFunction
from .Adapter import Adapter


class TaskCLIP(torch.nn.Module):
    """CLIP-style task head.

    Adapts vision/text embeddings with residual adapters, fuses the target
    tokens with a global image embedding via cross-attention, runs a
    Transformer decoder over (tgt, memory), and produces a pairwise score
    matrix between the decoded target and memory embeddings.
    """

    def __init__(self, model_config, normalize_before=False, device="cuda:1") -> None:
        """
        Args:
            model_config: dict providing 'num_layers', 'return_intermediate',
                'd_model', 'nhead', 'dim_feedforward', 'dropout', 'N_words',
                'activation', 'ratio_text', 'ratio_image', 'ratio_glob',
                'norm_before', 'norm_after', 'MAX_VAL', 'MIN_VAL'.
            normalize_before: forwarded to TransformerDecoderLayer
                (pre-norm vs post-norm layer layout).
            device: device string. NOTE(review): stored but never used in
                this class's visible code — confirm whether callers rely on it.
        """
        super().__init__()
        self.num_layers = model_config['num_layers']
        self.return_intermediate = model_config['return_intermediate']
        self.d_model = model_config['d_model']
        self.nhead = model_config['nhead']
        self.dim_feedforward = model_config['dim_feedforward']
        self.dropout = model_config['dropout']
        self.N_words = model_config['N_words']
        self.activation = model_config['activation']
        # Mixing ratios for the residual adapter / global-attention branches.
        self.ratio_text = model_config['ratio_text']
        self.ratio_image = model_config['ratio_image']
        self.ratio_glob = model_config['ratio_glob']
        # Whether to L2-normalize inputs before / min-max rescale scores after.
        self.norm_before = model_config['norm_before']
        self.norm_after = model_config['norm_after']
        self.MAX_Val = model_config['MAX_VAL']
        self.MIN_Val = model_config['MIN_VAL']
        self.normalize_before = normalize_before
        self.device = device

        self.decoder_norm = torch.nn.LayerNorm(self.d_model)
        self.decoder_layer = TransformerDecoderLayer(
            self.d_model,
            self.nhead,
            self.dim_feedforward,
            self.dropout,
            self.activation,
            self.normalize_before,
        )
        self.decoder = TransformerDecoder(
            self.decoder_layer,
            self.num_layers,
            self.decoder_norm,
            return_intermediate=self.return_intermediate,
        )
        # self.adapter_norm = torch.nn.LayerNorm(self.d_model)
        self.vision_adapter = Adapter(self.d_model)
        self.text_adapter = Adapter(self.d_model)
        self.glob_adapter = torch.nn.MultiheadAttention(
            self.d_model, self.nhead, dropout=self.dropout
        )
        self.ScoreFunction = ScoreFunction(N_words=self.N_words)
        # NOTE(review): threshold is stored but unused in this class's
        # visible code — confirm whether external callers read it.
        self.threshold = 0.1

    def forward(self, tgt, memory, image_embedding, norm=False):
        """Fuse target/memory embeddings and score their pairwise similarity.

        Args:
            tgt: target (vision-side) embeddings fed to the decoder.
            memory: text-side embeddings fed to the decoder.
            image_embedding: global image embedding used as key/value in the
                cross-attention branch.
            norm: NOTE(review): accepted but never read; normalization is
                driven by self.norm_before / self.norm_after instead — confirm.

        Returns:
            (tgt_new, memory_new, score_res, score_raw): decoded embeddings,
            the ScoreFunction output, and the raw (optionally rescaled)
            similarity matrix.
        """
        if self.norm_before:
            # FIX: out-of-place division. The original `tgt /= ...` mutated
            # the caller's tensors in place (and in-place ops on leaf tensors
            # that require grad raise a RuntimeError under autograd).
            tgt = tgt / tgt.norm(dim=-1, keepdim=True)
            memory = memory / memory.norm(dim=-1, keepdim=True)

        # Residual vision adapter on the target tokens.
        x = self.vision_adapter(tgt)
        tgt = self.ratio_image * x + (1 - self.ratio_image) * tgt

        # Same adapter, same mixing ratio, on the global image embedding.
        x0 = self.vision_adapter(image_embedding)
        image_embedding_temp = (
            self.ratio_image * x0 + (1 - self.ratio_image) * image_embedding
        )

        # NOTE(review): text_adapter output is indexed with [0] while
        # vision_adapter output is not — confirm Adapter's return type
        # (tuple vs tensor) is really asymmetric here.
        y = self.text_adapter(memory)[0]
        memory = self.ratio_text * y + (1 - self.ratio_text) * memory

        # Cross-attend target tokens to the adapted image embedding
        # (MultiheadAttention returns (attn_output, attn_weights)).
        glob, _ = self.glob_adapter(tgt, image_embedding_temp, image_embedding_temp)
        tgt = self.ratio_glob * glob + (1 - self.ratio_glob) * tgt

        tgt_new, memory_new = self.decoder(tgt, memory, None)

        # Pairwise similarity between decoded target and memory embeddings.
        score_raw = torch.mm(tgt_new, memory_new.T)
        if self.norm_after:
            score_raw = self.Norm(score_raw)
        score_res = self.ScoreFunction(score_raw)
        return tgt_new, memory_new, score_res, score_raw

    def Norm(self, score):
        """Min-max rescale `score` into [self.MIN_Val, self.MAX_Val].

        FIX: guards against a constant input (max == min), which divided by
        zero (NaN/inf) in the original implementation; such an input now maps
        uniformly to the lower bound.
        """
        min_val = score.min()
        max_val = score.max()
        span = max_val - min_val
        if span == 0:
            return torch.full_like(score, self.MIN_Val)
        return self.MIN_Val + (score - min_val) * (self.MAX_Val - self.MIN_Val) / span