import torch
from .Transformer import TransformerDecoderLayer, TransformerDecoder
from .ScoreFunction import ScoreFunction
from .Adapter import Adapter

class TaskCLIP(torch.nn.Module):
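    """Task-adaptive CLIP head.

    Adapts vision and text embeddings with residual adapters, fuses the
    vision queries with the global image embedding via cross attention,
    runs a transformer decoder over both streams, and scores vision-text
    pairs with a ScoreFunction head.
    """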
    def __init__(self, 
                 model_config, 
                 normalize_before=False,
                 device = "cuda:1") -> None:
        super().__init__()
        self.num_layers = model_config['num_layers']
        self.return_intermediate = model_config['return_intermediate']
        self.d_model = model_config['d_model']
        self.nhead = model_config['nhead']
        self.dim_feedforward = model_config['dim_feedforward']
        self.dropout = model_config['dropout']
        self.N_words = model_config['N_words']
        self.activation = model_config['activation']
        self.ratio_text = model_config['ratio_text']
        self.ratio_image = model_config['ratio_image']
        self.ratio_glob = model_config['ratio_glob']
        self.norm_before = model_config['norm_before']
        self.norm_after = model_config['norm_after']
        self.MAX_Val = model_config['MAX_VAL']
        self.MIN_Val = model_config['MIN_VAL']
        self.normalize_before = normalize_before
        self.device = device

        self.decoder_norm = torch.nn.LayerNorm(self.d_model)
        self.decoder_layer = TransformerDecoderLayer(self.d_model, 
                                                     self.nhead, 
                                                     self.dim_feedforward, 
                                                     self.dropout, 
                                                     self.activation, 
                                                     self.normalize_before)
        self.decoder = TransformerDecoder(self.decoder_layer, 
                                          self.num_layers, 
                                          self.decoder_norm, 
                                          return_intermediate=self.return_intermediate)
        #self.adapter_norm = torch.nn.LayerNorm(self.d_model)
        self.vision_adapter = Adapter(self.d_model)
        self.text_adapter = Adapter(self.d_model)
        self.glob_adapter = torch.nn.MultiheadAttention(self.d_model, 
                                                        self.nhead, 
                                                        dropout=self.dropout)
        self.ScoreFunction = ScoreFunction(N_words=self.N_words)
        self.threshold = 0.1
    
    def forward(self, tgt, memory, image_embedding, norm=False):
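        """Run the adapter/decoder pipeline and score vision-text pairs.

        Assumed unbatched 2-D shapes (implied by the torch.mm call below):
        tgt is (num_queries, d_model), memory is (num_tokens, d_model), and
        image_embedding is (num_patches, d_model). The `norm` argument is
        currently unused; normalization is controlled by the norm_before /
        norm_after config flags instead.

        Returns (tgt_new, memory_new, score_res, score_raw).
        """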
        if self.norm_before:
            # Normalize out of place so the caller's tensors are not mutated.
            tgt = tgt / tgt.norm(dim=-1, keepdim=True)
            memory = memory / memory.norm(dim=-1, keepdim=True)
        # Residual vision adapter on the queries and on the global image embedding.
        x = self.vision_adapter(tgt)
        tgt = self.ratio_image * x + (1 - self.ratio_image) * tgt
        x0 = self.vision_adapter(image_embedding)
        image_embedding_temp = self.ratio_image * x0 + (1 - self.ratio_image) * image_embedding
        # Residual text adapter; Adapter returns a tensor (cf. vision_adapter above),
        # so no tuple indexing is needed here.
        y = self.text_adapter(memory)
        memory = self.ratio_text * y + (1 - self.ratio_text) * memory
        # Cross-attend the queries to the adapted image embedding; MultiheadAttention
        # returns (output, attention weights), so keep only the output.
        glob, _ = self.glob_adapter(tgt, image_embedding_temp, image_embedding_temp)
        tgt = self.ratio_glob * glob + (1 - self.ratio_glob) * tgt
        tgt_new, memory_new = self.decoder(tgt, memory, None)
        # Pairwise similarity between decoded queries and text tokens.
        score_raw = torch.mm(tgt_new, memory_new.T)
        if self.norm_after:
            score_raw = self.Norm(score_raw)
        score_res = self.ScoreFunction(score_raw)
        return tgt_new, memory_new, score_res, score_raw

    def Norm(self, score):
        """Min-max rescale `score` into the [MIN_Val, MAX_Val] range."""
        min_val = score.min()
        max_val = score.max()
        # Clamp the denominator so a constant score matrix (max == min)
        # does not cause a division by zero.
        denom = (max_val - min_val).clamp(min=1e-8)
        res = self.MIN_Val + (score - min_val) * (self.MAX_Val - self.MIN_Val) / denom
        return res
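
if __name__ == "__main__":
    # Minimal smoke-test sketch. Everything below is illustrative: the config
    # values and tensor shapes are assumptions, and the custom Transformer,
    # Adapter, and ScoreFunction modules are assumed to accept the unbatched
    # 2-D (sequence, d_model) layout implied by the torch.mm call in forward().
    # Run with `python -m <package>.TaskCLIP` so the relative imports resolve.
    example_config = {
        'num_layers': 2,
        'return_intermediate': False,
        'd_model': 512,
        'nhead': 8,
        'dim_feedforward': 2048,
        'dropout': 0.1,
        'N_words': 77,
        'activation': "relu",
        'ratio_text': 0.2,
        'ratio_image': 0.2,
        'ratio_glob': 0.5,
        'norm_before': True,
        'norm_after': True,
        'MAX_VAL': 1.0,
        'MIN_VAL': 0.0,
    }
    model = TaskCLIP(example_config, device="cpu")
    tgt = torch.randn(10, 512)              # vision queries
    memory = torch.randn(77, 512)           # text token embeddings
    image_embedding = torch.randn(49, 512)  # image patch features
    tgt_new, memory_new, score_res, score_raw = model(tgt, memory, image_embedding)
    print(score_raw.shape)  # expected: (10, 77)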