Spaces:
Sleeping
Sleeping
| import torch | |
| from .Transformer import TransformerDecoderLayer | |
| from .Transformer import TransformerCrossDecoderLayer | |
| from .Transformer import TransformerDecoder | |
| from .ScoreFunction import ScoreFunction | |
| from .Adapter import Adapter | |
| from .ScoreFunction_HDC import ScoreFunctionHDC | |
class TaskCLIP(torch.nn.Module):
    """Transformer-decoder matching head over CLIP-style embeddings.

    Refines target (vision) and memory (text) embeddings with adapters and a
    transformer decoder, then produces a raw dot-product similarity matrix
    plus a score computed by a configurable score function (plain or HDC).

    Args:
        model_config: dict of hyper-parameters (see keys read in ``__init__``).
        normalize_before: passed through to the decoder layer (pre-norm flag).
        device: device string; stored but not used to move modules here —
            NOTE(review): caller appears responsible for ``.to(device)``.
    """

    def __init__(self,
                 model_config,
                 normalize_before=False,
                 device="cuda:1") -> None:
        super().__init__()
        # --- transformer decoder hyper-parameters ---
        self.num_layers = model_config['num_layers']
        self.return_intermediate = model_config['return_intermediate']
        self.d_model = model_config['d_model']
        self.nhead = model_config['nhead']
        self.dim_feedforward = model_config['dim_feedforward']
        self.dropout = model_config['dropout']
        self.N_words = model_config['N_words']
        self.activation = model_config['activation']
        # --- residual mixing ratios for the adapter branches ---
        self.ratio_text = model_config['ratio_text']
        self.ratio_image = model_config['ratio_image']
        self.ratio_glob = model_config['ratio_glob']
        # --- normalisation switches and output range used by Norm() ---
        self.norm_before = model_config['norm_before']
        self.norm_after = model_config['norm_after']
        self.MAX_Val = model_config['MAX_VAL']
        self.MIN_Val = model_config['MIN_VAL']
        self.normalize_before = normalize_before
        self.device = device
        self.decoder_norm = torch.nn.LayerNorm(self.d_model)
        # Choose cross-attention vs. plain decoder layer from config.
        if model_config['cross_attention']:
            self.decoder_layer = TransformerCrossDecoderLayer(self.d_model,
                                                              self.nhead,
                                                              self.dim_feedforward,
                                                              self.dropout,
                                                              self.activation,
                                                              self.normalize_before)
        else:
            self.decoder_layer = TransformerDecoderLayer(self.d_model,
                                                         self.nhead,
                                                         self.dim_feedforward,
                                                         self.dropout,
                                                         self.activation,
                                                         self.normalize_before)
        self.decoder = TransformerDecoder(self.decoder_layer,
                                          self.num_layers,
                                          self.decoder_norm,
                                          return_intermediate=self.return_intermediate)
        self.vision_adapter = Adapter(self.d_model)
        self.text_adapter = Adapter(self.d_model)
        # Global-context branch: attends targets over the full image embedding.
        self.glob_adapter = torch.nn.MultiheadAttention(self.d_model,
                                                        self.nhead,
                                                        dropout=self.dropout)
        self.score_function_name = model_config["score_function"]
        if model_config['score_function'] != 'HDC':
            self.ScoreFunction = ScoreFunction(N_words=self.N_words)
        else:
            self.ScoreFunction = ScoreFunctionHDC(N_words=self.N_words,
                                                  HDV_D=int(model_config['HDV_D']))
        self.threshold = 0.1

    def _apply_hw_noise(self, score_raw: torch.Tensor, dist: str,
                        width_0_100: int, strength_0_100: int) -> torch.Tensor:
        """Add simulated hardware noise to the raw similarity matrix.

        Args:
            score_raw: dot-product similarity matrix.
            dist: noise distribution — "gaussian", "uniform", "laplace",
                or "none"/unknown for a no-op.
            width_0_100: noise width as a percentage (clamped to [0, 100]).
            strength_0_100: noise strength as a percentage (clamped to [0, 100]).

        Returns:
            ``score_raw`` unchanged when noise is disabled, otherwise
            ``score_raw + noise`` with amplitude ``eps = width * strength * MAX_WIDTH``.
        """
        dist = (dist or "none").lower()
        w = max(0, min(100, int(width_0_100)))
        s = max(0, min(100, int(strength_0_100)))
        if dist == "none" or w == 0 or s == 0:
            return score_raw
        # Tune this constant to match the desired "device noise" magnitude;
        # the appropriate scale depends on the embedding magnitudes.
        MAX_WIDTH = 5.0
        base = (w / 100.0) * MAX_WIDTH
        scale = s / 100.0
        eps = base * scale
        if dist == "gaussian":
            noise = torch.randn_like(score_raw) * eps
        elif dist == "uniform":
            noise = (torch.rand_like(score_raw) * 2.0 - 1.0) * eps
        elif dist == "laplace":
            # Laplace(0, b=eps): sample via the inverse CDF.
            u = torch.rand_like(score_raw) - 0.5
            # Clamp |u| strictly below 0.5: rand() can return exactly 0.0,
            # which would make log1p(-1) = -inf.
            a = torch.abs(u).clamp(max=0.5 - 1e-7)
            noise = -eps * torch.sign(u) * torch.log1p(-2.0 * a)
        else:
            # Unknown distribution name: treat as disabled.
            return score_raw
        return score_raw + noise

    def forward(
            self,
            tgt,
            memory,
            image_embedding,
            norm=False,
            hw_noise_dist: str = "none",
            hw_noise_width: int = 0,
            hw_noise_strength: int = 0,
            hdc_bits: int = 32):
        """Compute refined embeddings and task similarity scores.

        Args:
            tgt: target/vision embeddings (2-D, last dim = d_model —
                TODO confirm against caller).
            memory: text embeddings (same embedding dim as ``tgt``).
            image_embedding: global image embedding used as key/value for the
                global-attention branch.
            norm: unused; kept for backward compatibility with existing callers.
            hw_noise_dist / hw_noise_width / hw_noise_strength: forwarded to
                :meth:`_apply_hw_noise`.
            hdc_bits: quantisation bit-width, used only for the HDC score function.

        Returns:
            Tuple ``(tgt_new, memory_new, score_res, score_raw)``.
        """
        if self.norm_before:
            # Out-of-place division: the original `/=` mutated the caller's
            # tensors and fails on autograd leaf tensors that require grad.
            tgt = tgt / tgt.norm(dim=-1, keepdim=True)
            memory = memory / memory.norm(dim=-1, keepdim=True)
        # Vision adapter with residual mixing on both targets and the
        # global image embedding.
        x = self.vision_adapter(tgt)
        tgt = self.ratio_image * x + (1 - self.ratio_image) * tgt
        x0 = self.vision_adapter(image_embedding)
        image_embedding_temp = self.ratio_image * x0 + (1 - self.ratio_image) * image_embedding
        # Text adapter (returns a tuple; first element is the features).
        y = self.text_adapter(memory)[0]
        memory = self.ratio_text * y + (1 - self.ratio_text) * memory
        # Global branch: cross-attend targets over the adapted image embedding.
        tgt = (self.ratio_glob
               * self.glob_adapter(tgt, image_embedding_temp, image_embedding_temp)[0]
               + (1 - self.ratio_glob) * tgt)
        tgt_new, memory_new = self.decoder(tgt, memory, None)
        score_raw = torch.mm(tgt_new, memory_new.T)
        # Optionally inject simulated hardware noise into the similarities.
        score_raw = self._apply_hw_noise(score_raw, hw_noise_dist,
                                         hw_noise_width, hw_noise_strength)
        if self.norm_after:
            score_raw = self.Norm(score_raw)
        if self.score_function_name == 'HDC':
            score_res = self.ScoreFunction(score_raw, quant_bits=hdc_bits)
        else:
            score_res = self.ScoreFunction(score_raw)
        return tgt_new, memory_new, score_res, score_raw

    def Norm(self, score):
        """Min-max rescale ``score`` into ``[MIN_Val, MAX_Val]``.

        A constant input (max == min) is mapped to ``MIN_Val`` everywhere
        instead of dividing by zero (which produced NaN/inf).
        """
        min_val = score.min()
        max_val = score.max()
        span = max_val - min_val
        if span == 0:
            return torch.full_like(score, self.MIN_Val)
        return self.MIN_Val + ((score - min_val) * (self.MAX_Val - self.MIN_Val)) / span