# TaskCLIP/models/TaskCLIP.py -- HanningChen, commit 398e700 ("Add noise"), 6.05 kB
import torch
from .Transformer import TransformerDecoderLayer
from .Transformer import TransformerCrossDecoderLayer
from .Transformer import TransformerDecoder
from .ScoreFunction import ScoreFunction
from .Adapter import Adapter
from .ScoreFunction_HDC import ScoreFunctionHDC
class TaskCLIP(torch.nn.Module):
    """Score target (image-side) embeddings against memory (text-side) embeddings.

    The module adapts both inputs, blends the target with an attention
    summary of ``image_embedding``, runs a transformer decoder, and feeds the
    resulting dot-product similarity matrix to a score function.

    ``model_config`` is a mapping that must provide the keys read in
    ``__init__``: ``num_layers``, ``return_intermediate``, ``d_model``,
    ``nhead``, ``dim_feedforward``, ``dropout``, ``N_words``, ``activation``,
    ``ratio_text``, ``ratio_image``, ``ratio_glob``, ``norm_before``,
    ``norm_after``, ``MAX_VAL``, ``MIN_VAL``, ``cross_attention``,
    ``score_function`` and, when ``score_function == 'HDC'``, ``HDV_D``.
    """

    # Noise magnitude reached when both width and strength are 100.
    # Tune this to match the desired "device noise" level; the right value
    # depends on the scale of the similarity scores produced by the embeddings.
    HW_NOISE_MAX_WIDTH = 5.0

    def __init__(self,
                 model_config,
                 normalize_before=False,
                 device="cuda:1") -> None:
        """Build the decoder, adapters and score function from ``model_config``.

        Args:
            model_config: Mapping with the keys listed in the class docstring.
            normalize_before: Forwarded to the decoder layer constructor.
            device: Stored on ``self.device`` only; nothing in this class moves
                parameters to it.
        """
        super().__init__()
        self.num_layers = model_config['num_layers']
        self.return_intermediate = model_config['return_intermediate']
        self.d_model = model_config['d_model']
        self.nhead = model_config['nhead']
        self.dim_feedforward = model_config['dim_feedforward']
        self.dropout = model_config['dropout']
        self.N_words = model_config['N_words']
        self.activation = model_config['activation']
        # Blend ratios between adapted and original features (see forward()).
        self.ratio_text = model_config['ratio_text']
        self.ratio_image = model_config['ratio_image']
        self.ratio_glob = model_config['ratio_glob']
        self.norm_before = model_config['norm_before']
        self.norm_after = model_config['norm_after']
        # Output range used by Norm() when norm_after is enabled.
        self.MAX_Val = model_config['MAX_VAL']
        self.MIN_Val = model_config['MIN_VAL']
        self.normalize_before = normalize_before
        self.device = device

        # Sub-modules are created in a fixed order so that parameter
        # initialisation and state_dict layout stay stable.
        self.decoder_norm = torch.nn.LayerNorm(self.d_model)
        if model_config['cross_attention']:
            layer_cls = TransformerCrossDecoderLayer
        else:
            layer_cls = TransformerDecoderLayer
        self.decoder_layer = layer_cls(self.d_model,
                                       self.nhead,
                                       self.dim_feedforward,
                                       self.dropout,
                                       self.activation,
                                       self.normalize_before)
        self.decoder = TransformerDecoder(self.decoder_layer,
                                          self.num_layers,
                                          self.decoder_norm,
                                          return_intermediate=self.return_intermediate)

        self.vision_adapter = Adapter(self.d_model)
        self.text_adapter = Adapter(self.d_model)
        self.glob_adapter = torch.nn.MultiheadAttention(self.d_model,
                                                        self.nhead,
                                                        dropout=self.dropout)

        self.score_function_name = model_config["score_function"]
        if self.score_function_name != 'HDC':
            self.ScoreFunction = ScoreFunction(N_words=self.N_words)
        else:
            self.ScoreFunction = ScoreFunctionHDC(N_words=self.N_words,
                                                  HDV_D=int(model_config['HDV_D']))
        self.threshold = 0.1

    def _apply_hw_noise(self, score_raw: torch.Tensor, dist: str, width_0_100: int, strength_0_100: int) -> torch.Tensor:
        """Add zero-mean random noise to ``score_raw``.

        Args:
            score_raw: Similarity scores to perturb.
            dist: ``"gaussian"``, ``"uniform"`` or ``"laplace"``
                (case-insensitive). ``None``, ``"none"`` or any other value
                disables noise.
            width_0_100: Noise width, clamped to ``[0, 100]``.
            strength_0_100: Noise strength, clamped to ``[0, 100]``.

        Returns:
            ``score_raw`` itself when noise is disabled (unknown ``dist`` or a
            zero width/strength); otherwise a new tensor ``score_raw + noise``
            whose noise scale is
            ``width/100 * HW_NOISE_MAX_WIDTH * strength/100``.
        """
        dist = (dist or "none").lower()
        width = max(0, min(100, int(width_0_100)))
        strength = max(0, min(100, int(strength_0_100)))
        if dist == "none" or width == 0 or strength == 0:
            return score_raw

        eps = (width / 100.0) * self.HW_NOISE_MAX_WIDTH * (strength / 100.0)

        if dist == "gaussian":
            # Normal noise with standard deviation eps.
            noise = torch.randn_like(score_raw) * eps
        elif dist == "uniform":
            # Uniform noise on [-eps, eps).
            noise = (torch.rand_like(score_raw) * 2.0 - 1.0) * eps
        elif dist == "laplace":
            # Laplace(0, eps) sampled through the inverse CDF.
            u = torch.rand_like(score_raw) - 0.5
            # rand_like can return exactly 0, giving |u| == 0.5 and
            # log1p(-1) == -inf; keep the argument strictly below 1.
            two_abs_u = (2.0 * torch.abs(u)).clamp(max=1.0 - torch.finfo(u.dtype).eps)
            noise = -eps * torch.sign(u) * torch.log1p(-two_abs_u)
        else:
            return score_raw
        return score_raw + noise

    def forward(
            self,
            tgt,
            memory,
            image_embedding,
            norm=False,
            hw_noise_dist: str = "none",
            hw_noise_width: int = 0,
            hw_noise_strength: int = 0,
            hdc_bits: int = 32):
        """Compute decoder outputs and the similarity scores between them.

        Args:
            tgt: Target-side embeddings; adapted with ``vision_adapter``.
            memory: Memory-side embeddings; adapted with ``text_adapter``.
            image_embedding: Embedding used as key/value for ``glob_adapter``.
            norm: Unused; kept for call compatibility.
            hw_noise_dist: Noise distribution, see ``_apply_hw_noise``.
            hw_noise_width: Noise width in ``[0, 100]``.
            hw_noise_strength: Noise strength in ``[0, 100]``.
            hdc_bits: Passed as ``quant_bits`` to the score function, only when
                the configured score function is ``'HDC'``.

        Returns:
            Tuple ``(tgt_new, memory_new, score_res, score_raw)`` where
            ``score_raw`` is ``tgt_new @ memory_new.T`` after optional noise
            and optional min-max rescaling, and ``score_res`` is the score
            function's output for it.
        """
        if self.norm_before:
            # Out-of-place on purpose: the caller's tensors must not be
            # modified, and in-place division fails on leaf tensors that
            # require grad.
            tgt = tgt / tgt.norm(dim=-1, keepdim=True)
            memory = memory / memory.norm(dim=-1, keepdim=True)

        # Residual-style blend of adapted and original features.
        adapted_tgt = self.vision_adapter(tgt)
        tgt = self.ratio_image * adapted_tgt + (1 - self.ratio_image) * tgt
        adapted_image = self.vision_adapter(image_embedding)
        image_embedding_temp = (self.ratio_image * adapted_image
                                + (1 - self.ratio_image) * image_embedding)
        # NOTE(review): the text adapter output is indexed with [0] while the
        # vision adapter outputs are used whole -- confirm against Adapter
        # that this is intended.
        adapted_memory = self.text_adapter(memory)[0]
        memory = self.ratio_text * adapted_memory + (1 - self.ratio_text) * memory

        # MultiheadAttention returns (attn_output, attn_weights); only the
        # output is blended into tgt.
        glob = self.glob_adapter(tgt, image_embedding_temp, image_embedding_temp)[0]
        tgt = self.ratio_glob * glob + (1 - self.ratio_glob) * tgt

        tgt_new, memory_new = self.decoder(tgt, memory, None)
        # torch.mm requires both decoder outputs to be 2-D matrices.
        score_raw = torch.mm(tgt_new, memory_new.T)

        score_raw = self._apply_hw_noise(score_raw, hw_noise_dist, hw_noise_width, hw_noise_strength)
        if self.norm_after:
            score_raw = self.Norm(score_raw)

        if self.score_function_name == 'HDC':
            score_res = self.ScoreFunction(score_raw, quant_bits=hdc_bits)
        else:
            score_res = self.ScoreFunction(score_raw)
        return tgt_new, memory_new, score_res, score_raw

    def Norm(self, score):
        """Min-max rescale ``score`` into ``[self.MIN_Val, self.MAX_Val]``.

        The minimum and maximum are taken over the whole tensor. If every
        entry is equal (zero span), all outputs are ``self.MIN_Val`` instead
        of the NaNs a plain division would produce.
        """
        min_val = score.min()
        max_val = score.max()
        span = max_val - min_val
        # Divide by 1 when the span is zero; the numerator is zero there too.
        safe_span = torch.where(span == 0, torch.ones_like(span), span)
        return self.MIN_Val + ((score - min_val) * (self.MAX_Val - self.MIN_Val)) / safe_span