File size: 1,847 Bytes
f2f112a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from .Transformer import TransformerDecoderLayer
from .Transformer import TransformerDecoder
from .ScoreFunction import ScoreFunction

class CoCoTask_Model(torch.nn.Module):
    """Score every ``tgt`` row against every ``memory`` row.

    ``forward`` computes the dot-product matrix ``tgt @ memory.T``, linearly
    rescales it into ``[MIN_Val, MAX_Val]`` with ``Norm`` and feeds the result
    to ``ScoreFunction``.

    A ``TransformerDecoder`` is constructed in ``__init__`` but ``forward``
    currently bypasses it and uses its inputs unchanged. Its ``LayerNorm``
    (and presumably the decoder itself) are still registered submodules.
    """

    def __init__(self,
                 num_layers=3,
                 norm=None,
                 return_intermediate=False,
                 d_model=1024,
                 nhead=8,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu",
                 normalize_before=False,
                 device="cuda:1") -> None:
        """Build the (currently bypassed) decoder stack and the score head.

        Args:
            num_layers: number of layers passed to ``TransformerDecoder``.
            norm: NOTE(review): unused. The decoder always receives its own
                ``LayerNorm(d_model)``; kept for signature compatibility.
            return_intermediate: forwarded to ``TransformerDecoder``.
            d_model, nhead, dim_feedforward, dropout, activation,
            normalize_before: forwarded to ``TransformerDecoderLayer``.
                ``d_model`` also sizes the decoder ``LayerNorm``.
            device: NOTE(review): unused. The module is not moved to this
                device here; callers must ``.to(...)`` it themselves.
        """
        super().__init__()
        self.decoder_norm = torch.nn.LayerNorm(d_model)
        self.decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
        self.decoder = TransformerDecoder(self.decoder_layer, num_layers, self.decoder_norm, return_intermediate=return_intermediate)
        # Target interval for Norm(): min score -> MIN_Val, max score -> MAX_Val.
        self.MAX_Val = 30
        self.MIN_Val = 10

        # Project-local scoring head; its exact transform is defined elsewhere.
        self.ScoreFunction = ScoreFunction()

        # NOTE(review): not read anywhere in this class; presumably consumed
        # by callers to binarize scores — confirm before removing.
        self.threshold = 0.2

    def forward(self, tgt, memory):
        """Score all ``tgt`` rows against all ``memory`` rows.

        Args:
            tgt: 2-D tensor (required by ``torch.mm``), one query per row.
            memory: 2-D tensor with the same number of columns as ``tgt``.

        Returns:
            ``(tgt_new, memory_new, score_res, score_raw)``, where:
            - ``tgt_new`` and ``memory_new`` are currently the inputs unchanged
              (the decoder is bypassed);
            - ``score_raw`` is the ``(len(tgt), len(memory))`` similarity matrix
              *after* rescaling by ``Norm``, despite its name;
            - ``score_res`` is ``ScoreFunction(score_raw)``.
        """
        # Decoder intentionally bypassed; the previous call was:
        #     tgt_new, memory_new = self.decoder(tgt, memory, None)
        tgt_new, memory_new = tgt, memory
        score_raw = torch.mm(tgt_new, memory_new.T)
        score_raw = self.Norm(score_raw)
        score_res = self.ScoreFunction(score_raw)
        return tgt_new, memory_new, score_res, score_raw

    def Norm(self, score):
        """Min-max rescale ``score`` linearly into ``[MIN_Val, MAX_Val]``.

        The minimum and maximum are taken over the whole tensor (not per row).
        The global minimum maps to ``MIN_Val`` and the global maximum to
        ``MAX_Val``.

        If every entry is equal, the range is zero. Rather than computing 0/0
        (NaN everywhere), every entry then maps to ``MIN_Val``.
        """
        min_val = score.min()
        max_val = score.max()
        span = max_val - min_val
        # Replace a zero range by 1: the numerator is 0 in that case, so the
        # result is exactly MIN_Val. torch.where avoids a host-device sync.
        span = torch.where(span == 0, torch.ones_like(span), span)
        return self.MIN_Val + ((score - min_val) * (self.MAX_Val - self.MIN_Val)) / span