import torch

from .TextTransformer import TransformerDecoderLayer
from .TextTransformer import TransformerDecoder
from .ScoreFunction import ScoreFunction


class CoCoTask_Model(torch.nn.Module):
    """Transformer-decoder model that scores every ``tgt`` row against every ``memory`` row.

    ``forward`` runs both inputs through a ``TransformerDecoder``, takes the
    matrix product of the refined outputs, rescales that matrix to the
    ``[MIN_Val, MAX_Val]`` interval and passes it to ``ScoreFunction``.
    """

    def __init__(
        self,
        num_layers=3,
        norm=None,
        return_intermediate=False,
        d_model=1024,
        nhead=8,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        device="cuda:1",
    ) -> None:
        """Build the decoder stack and the score head.

        Args:
            num_layers: Number of decoder layers handed to ``TransformerDecoder``.
            norm: Not read by this constructor; a ``LayerNorm(d_model)`` is
                always created instead. Kept for interface compatibility.
            return_intermediate: Forwarded to ``TransformerDecoder``.
            d_model: Feature width used by the layer norm and decoder layer.
            nhead: Forwarded to ``TransformerDecoderLayer``.
            dim_feedforward: Forwarded to ``TransformerDecoderLayer``.
            dropout: Forwarded to ``TransformerDecoderLayer``.
            activation: Forwarded to ``TransformerDecoderLayer``.
            normalize_before: Forwarded to ``TransformerDecoderLayer``.
            device: Not read by this constructor; the caller is responsible
                for moving the module. Kept for interface compatibility.
        """
        super().__init__()
        self.decoder_norm = torch.nn.LayerNorm(d_model)
        self.decoder_layer = TransformerDecoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
        )
        self.decoder = TransformerDecoder(
            self.decoder_layer,
            num_layers,
            self.decoder_norm,
            return_intermediate=return_intermediate,
        )

        # Bounds of the interval that Norm() rescales raw scores into.
        self.MAX_Val = 30
        self.MIN_Val = 10

        # NOTE: a commented-out MLP head
        # (Linear(10, 64) -> ReLU -> Linear(64, 1) -> Sigmoid) used to sit
        # here; ScoreFunction is the head that is actually applied.
        self.ScoreFunction = ScoreFunction()

        # Not read anywhere in this class; presumably consumed by callers
        # when thresholding the scores -- TODO confirm.
        self.threshold = 0.4

    def forward(self, tgt, memory):
        """Refine ``tgt``/``memory`` with the decoder and score every pair.

        Args:
            tgt: First input passed to the decoder.
            memory: Second input passed to the decoder.

        Returns:
            A 4-tuple ``(tgt_new, memory_new, score_res, score_raw)`` where
            ``tgt_new`` and ``memory_new`` are the decoder outputs,
            ``score_raw`` is ``tgt_new @ memory_new.T`` rescaled by ``Norm``
            and ``score_res`` is ``ScoreFunction(score_raw)``.
        """
        # The third positional argument is passed as None; presumably an
        # optional mask -- TODO confirm against TransformerDecoder.
        tgt_new, memory_new = self.decoder(tgt, memory, None)

        # torch.mm requires 2-D operands, so both decoder outputs must be
        # matrices sharing their last dimension.
        score_raw = torch.mm(tgt_new, memory_new.T)
        score_raw = self.Norm(score_raw)
        score_res = self.ScoreFunction(score_raw)
        return tgt_new, memory_new, score_res, score_raw

    def Norm(self, score: torch.Tensor) -> torch.Tensor:
        """Min-max rescale ``score`` into ``[MIN_Val, MAX_Val]``.

        The minimum and maximum are taken over the whole tensor, not per row
        or column. When every element is equal the span is zero; instead of
        dividing 0 by 0 (which yields NaN everywhere) the result is a tensor
        filled with ``MIN_Val``.

        Args:
            score: Tensor of raw scores; must be non-empty.

        Returns:
            A tensor of the same shape as ``score`` with values rescaled to
            the ``[MIN_Val, MAX_Val]`` interval.
        """
        min_val = score.min()
        max_val = score.max()
        span = max_val - min_val

        # Swap a zero span for 1: the numerator is all zeros in that case,
        # so the result collapses to MIN_Val rather than NaN. A non-zero
        # span is used as-is, keeping the original arithmetic unchanged.
        safe_span = torch.where(span > 0, span, torch.ones_like(span))

        res = self.MIN_Val + ((score - min_val) * (self.MAX_Val - self.MIN_Val)) / safe_span
        return res