File size: 6,052 Bytes
f2f112a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398e700
f2f112a
 
 
 
 
 
398e700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2f112a
 
 
 
 
 
 
 
 
 
 
 
398e700
 
f2f112a
 
398e700
 
 
 
f2f112a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
from .Transformer import TransformerDecoderLayer
from .Transformer import TransformerCrossDecoderLayer
from .Transformer import TransformerDecoder
from .ScoreFunction import ScoreFunction
from .Adapter import Adapter
from .ScoreFunction_HDC import ScoreFunctionHDC

class TaskCLIP(torch.nn.Module):
    """Decoder head that scores adapted ``tgt`` rows against adapted ``memory`` rows.

    Pipeline (see ``forward``):
      1. ``tgt`` and ``image_embedding`` are blended with the vision ``Adapter``
         output; ``memory`` is blended with the text ``Adapter`` output.
      2. ``tgt`` attends to the adapted image embedding (``glob_adapter``).
      3. A ``TransformerDecoder`` processes ``(tgt, memory)``.
      4. Every decoded ``tgt`` row is dot-multiplied with every decoded
         ``memory`` row; the score matrix may be perturbed with synthetic
         hardware noise, min-max rescaled, and is passed to a score function.

    Args:
        model_config: Mapping providing the keys read in ``__init__``
            (``num_layers``, ``d_model``, ``nhead``, ``cross_attention``,
            ``score_function``, ``MAX_VAL``, ``MIN_VAL``, ...; ``HDV_D`` only
            when ``score_function == 'HDC'``).
        normalize_before: Forwarded to the decoder layer constructor.
        device: Stored as ``self.device``; this class itself does not read it.
    """

    # Full-scale noise amplitude (in raw-score units) reached when both the
    # width and strength knobs of ``_apply_hw_noise`` are at 100. ``score_raw``
    # is a dot-product similarity matrix, so tune this to the typical scale of
    # your embeddings (override on a subclass or instance if needed).
    HW_NOISE_MAX_WIDTH = 5.0

    def __init__(self,
                 model_config,
                 normalize_before=False,
                 device="cuda:1") -> None:
        super().__init__()
        # Hyper-parameters copied from the config mapping.
        self.num_layers = model_config['num_layers']
        self.return_intermediate = model_config['return_intermediate']
        self.d_model = model_config['d_model']
        self.nhead = model_config['nhead']
        self.dim_feedforward = model_config['dim_feedforward']
        self.dropout = model_config['dropout']
        self.N_words = model_config['N_words']
        self.activation = model_config['activation']
        # Residual blend weights for the text / vision / global-attention adapters.
        self.ratio_text = model_config['ratio_text']
        self.ratio_image = model_config['ratio_image']
        self.ratio_glob = model_config['ratio_glob']
        self.norm_before = model_config['norm_before']
        self.norm_after = model_config['norm_after']
        # Target range for ``Norm`` (min-max rescaling of the score matrix).
        self.MAX_Val = model_config['MAX_VAL']
        self.MIN_Val = model_config['MIN_VAL']
        self.normalize_before = normalize_before
        self.device = device

        self.decoder_norm = torch.nn.LayerNorm(self.d_model)
        # Both decoder-layer flavours are built with the same positional args.
        layer_cls = (TransformerCrossDecoderLayer
                     if model_config['cross_attention']
                     else TransformerDecoderLayer)
        self.decoder_layer = layer_cls(self.d_model,
                                       self.nhead,
                                       self.dim_feedforward,
                                       self.dropout,
                                       self.activation,
                                       self.normalize_before)

        self.decoder = TransformerDecoder(self.decoder_layer,
                                          self.num_layers,
                                          self.decoder_norm,
                                          return_intermediate=self.return_intermediate)
        self.vision_adapter = Adapter(self.d_model)
        self.text_adapter = Adapter(self.d_model)
        self.glob_adapter = torch.nn.MultiheadAttention(self.d_model,
                                                        self.nhead,
                                                        dropout=self.dropout)
        self.score_function_name = model_config["score_function"]
        if self.score_function_name != 'HDC':
            self.ScoreFunction = ScoreFunction(N_words=self.N_words)
        else:
            self.ScoreFunction = ScoreFunctionHDC(N_words=self.N_words, HDV_D=int(model_config['HDV_D']))
        self.threshold = 0.1

    def _apply_hw_noise(self, score_raw: torch.Tensor, dist: str, width_0_100: int, strength_0_100: int) -> torch.Tensor:
        """Return ``score_raw`` plus synthetic zero-mean noise emulating device non-idealities.

        Args:
            score_raw: Floating-point score tensor to perturb.
            dist: ``"gaussian"``, ``"uniform"``, ``"laplace"`` or ``"none"``
                (case-insensitive; ``None`` is treated as ``"none"``). Any
                other value disables noise.
            width_0_100: Noise width knob, clamped to [0, 100].
            strength_0_100: Noise strength knob, clamped to [0, 100].

        Returns:
            ``score_raw`` itself when noise is disabled, otherwise a new tensor.
            The noise scale ``eps = HW_NOISE_MAX_WIDTH * width/100 * strength/100``
            is the standard deviation (gaussian), the half-range (uniform), or
            the scale ``b`` (laplace).
        """
        dist = (dist or "none").lower()
        w = max(0, min(100, int(width_0_100)))
        s = max(0, min(100, int(strength_0_100)))

        if dist == "none" or w == 0 or s == 0:
            return score_raw

        eps = (w / 100.0) * self.HW_NOISE_MAX_WIDTH * (s / 100.0)

        if dist == "gaussian":
            noise = torch.randn_like(score_raw) * eps
        elif dist == "uniform":
            noise = (torch.rand_like(score_raw) * 2.0 - 1.0) * eps
        elif dist == "laplace":
            # Laplace(0, b=eps) via inverse CDF with u ~ U[-0.5, 0.5).
            u = torch.rand_like(score_raw) - 0.5
            # rand_like can return exactly 0, making |u| == 0.5 and
            # log1p(-1) == -inf; keep |u| strictly below 0.5.
            abs_u = torch.abs(u).clamp(max=0.5 - torch.finfo(u.dtype).eps)
            noise = -eps * torch.sign(u) * torch.log1p(-2.0 * abs_u)
        else:
            return score_raw

        return score_raw + noise

    def forward(
        self,
        tgt,
        memory,
        image_embedding,
        norm=False,
        hw_noise_dist: str = "none",
        hw_noise_width: int = 0,
        hw_noise_strength: int = 0,
        hdc_bits: int = 32):
        """Decode ``(tgt, memory)`` and score every ``tgt`` row against every ``memory`` row.

        Args:
            tgt: Features passed through the vision adapter and global attention.
            memory: Features passed through the text adapter.
            image_embedding: Context features attended to by ``tgt``.
            norm: Accepted for interface compatibility; not read here.
            hw_noise_dist, hw_noise_width, hw_noise_strength: Forwarded to
                ``_apply_hw_noise``.
            hdc_bits: Passed as ``quant_bits`` to the HDC score function only.

        Returns:
            ``(tgt_new, memory_new, score_res, score_raw)``. ``score_raw`` is
            ``tgt_new @ memory_new.T`` after optional noise and ``Norm``.
            ``score_res`` is the score function applied to ``score_raw``.
        """
        if self.norm_before:
            # Out-of-place: do not modify the caller's tensors, and avoid
            # autograd errors from in-place updates of tensors needing grad.
            tgt = tgt / tgt.norm(dim=-1, keepdim=True)
            memory = memory / memory.norm(dim=-1, keepdim=True)
        # Residual blend of adapter output and input, weighted by ratio_image.
        x = self.vision_adapter(tgt)
        tgt = self.ratio_image * x + (1 - self.ratio_image) * tgt
        x0 = self.vision_adapter(image_embedding)
        image_embedding_temp = self.ratio_image * x0 + (1 - self.ratio_image) * image_embedding
        # NOTE(review): unlike the vision adapter calls above, the text adapter
        # output is indexed with [0], although both are the same ``Adapter``
        # class. Confirm this is intended rather than selecting only the first row.
        y = self.text_adapter(memory)[0]
        memory = self.ratio_text * y + (1 - self.ratio_text) * memory
        # MultiheadAttention returns (output, weights); keep the output.
        attended = self.glob_adapter(tgt, image_embedding_temp, image_embedding_temp)[0]
        tgt = self.ratio_glob * attended + (1 - self.ratio_glob) * tgt
        tgt_new, memory_new = self.decoder(tgt, memory, None)
        # Pairwise dot products: entry (i, j) = <tgt_new[i], memory_new[j]>.
        score_raw = torch.mm(tgt_new, memory_new.T)
        score_raw = self._apply_hw_noise(score_raw, hw_noise_dist, hw_noise_width, hw_noise_strength)
        if self.norm_after:
            score_raw = self.Norm(score_raw)
        if self.score_function_name == 'HDC':
            score_res = self.ScoreFunction(score_raw, quant_bits=hdc_bits)
        else:
            score_res = self.ScoreFunction(score_raw)
        return tgt_new, memory_new, score_res, score_raw

    def Norm(self, score):
        """Min-max rescale ``score`` to [MIN_Val, MAX_Val] using its global min and max.

        A constant ``score`` (max == min) maps to ``MIN_Val`` everywhere
        instead of producing NaN from 0/0.
        """
        min_val = score.min()
        max_val = score.max()
        span = max_val - min_val
        # When span == 0 the numerator is 0 as well, so dividing by 1 yields MIN_Val.
        span = torch.where(span > 0, span, torch.ones_like(span))
        return self.MIN_Val + (score - min_val) * (self.MAX_Val - self.MIN_Val) / span