# TaskCLIP / models/TaskCLIP.py
# HanningChen — commit 398e700: "Add noise"
import torch
from .Transformer import TransformerDecoderLayer
from .Transformer import TransformerCrossDecoderLayer
from .Transformer import TransformerDecoder
from .ScoreFunction import ScoreFunction
from .Adapter import Adapter
from .ScoreFunction_HDC import ScoreFunctionHDC
class TaskCLIP(torch.nn.Module):
    """Score visual-side embeddings against text-side embeddings.

    ``tgt`` and ``image_embedding`` both pass through the same
    ``vision_adapter``; ``memory`` passes through ``text_adapter``. ``tgt`` is
    then refined by attention over the adapted image embedding, and both
    streams go through a transformer decoder. The decoder outputs are compared
    with a dot-product similarity matrix. That matrix may be perturbed with
    synthetic hardware noise and optionally min-max rescaled. It is finally
    passed to ``ScoreFunction`` or ``ScoreFunctionHDC``.
    """

    def __init__(self,
                 model_config,
                 normalize_before=False,
                 device="cuda:1") -> None:
        """Build the adapters, decoder and score function from ``model_config``.

        Args:
            model_config: Mapping read with the keys ``num_layers``,
                ``return_intermediate``, ``d_model``, ``nhead``,
                ``dim_feedforward``, ``dropout``, ``N_words``, ``activation``,
                ``ratio_text`` / ``ratio_image`` / ``ratio_glob`` (residual
                blend weights used in ``forward``), ``norm_before``,
                ``norm_after``, ``MAX_VAL`` / ``MIN_VAL`` (target range of
                ``Norm``), ``cross_attention`` (truthy selects
                ``TransformerCrossDecoderLayer``), ``score_function``
                (``'HDC'`` selects ``ScoreFunctionHDC``) and, for HDC only,
                ``HDV_D``.
            normalize_before: Forwarded to the decoder-layer constructor.
            device: Stored on ``self.device``; not read elsewhere in this
                class. NOTE(review): the ``"cuda:1"`` default looks
                machine-specific — confirm.
        """
        super().__init__()
        self.num_layers = model_config['num_layers']
        self.return_intermediate = model_config['return_intermediate']
        self.d_model = model_config['d_model']
        self.nhead = model_config['nhead']
        self.dim_feedforward = model_config['dim_feedforward']
        self.dropout = model_config['dropout']
        self.N_words = model_config['N_words']
        self.activation = model_config['activation']
        self.ratio_text = model_config['ratio_text']
        self.ratio_image = model_config['ratio_image']
        self.ratio_glob = model_config['ratio_glob']
        self.norm_before = model_config['norm_before']
        self.norm_after = model_config['norm_after']
        self.MAX_Val = model_config['MAX_VAL']
        self.MIN_Val = model_config['MIN_VAL']
        self.normalize_before = normalize_before
        self.device = device
        self.decoder_norm = torch.nn.LayerNorm(self.d_model)
        if model_config['cross_attention']:
            self.decoder_layer = TransformerCrossDecoderLayer(self.d_model,
                                                              self.nhead,
                                                              self.dim_feedforward,
                                                              self.dropout,
                                                              self.activation,
                                                              self.normalize_before)
        else:
            self.decoder_layer = TransformerDecoderLayer(self.d_model,
                                                         self.nhead,
                                                         self.dim_feedforward,
                                                         self.dropout,
                                                         self.activation,
                                                         self.normalize_before)
        self.decoder = TransformerDecoder(self.decoder_layer,
                                          self.num_layers,
                                          self.decoder_norm,
                                          return_intermediate=self.return_intermediate)
        self.vision_adapter = Adapter(self.d_model)
        self.text_adapter = Adapter(self.d_model)
        self.glob_adapter = torch.nn.MultiheadAttention(self.d_model,
                                                        self.nhead,
                                                        dropout=self.dropout)
        self.score_function_name = model_config["score_function"]
        if model_config['score_function'] != 'HDC':
            self.ScoreFunction = ScoreFunction(N_words=self.N_words)
        else:
            self.ScoreFunction = ScoreFunctionHDC(N_words=self.N_words, HDV_D=int(model_config['HDV_D']))
        # Not read inside this class; presumably consumed by callers — confirm.
        self.threshold = 0.1

    def _apply_hw_noise(self, score_raw: torch.Tensor, dist: str, width_0_100: int, strength_0_100: int) -> torch.Tensor:
        """Add synthetic zero-mean "hardware" noise to a similarity matrix.

        The noise amplitude is ``MAX_WIDTH * (width / 100) * (strength / 100)``.
        Both knobs are clamped to [0, 100], so they act as two multiplicative
        percentages of ``MAX_WIDTH``.

        Args:
            score_raw: Floating-point tensor to perturb; it is not modified
                in place.
            dist: ``"gaussian"`` (std = amplitude), ``"uniform"`` (values in
                [-amplitude, amplitude)), ``"laplace"`` (Laplace(0, amplitude))
                or ``"none"``. Case-insensitive; ``None`` means ``"none"``.
                Any other value is ignored and the input is returned unchanged.
            width_0_100: Width percentage, converted with ``int`` and clamped
                to [0, 100].
            strength_0_100: Strength percentage, converted with ``int`` and
                clamped to [0, 100].

        Returns:
            ``score_raw`` itself when no noise applies, otherwise a new tensor
            ``score_raw + noise``.
        """
        dist = (dist or "none").lower()
        w = max(0, min(100, int(width_0_100)))
        s = max(0, min(100, int(strength_0_100)))
        if dist == "none" or w == 0 or s == 0:
            return score_raw
        # Tune this constant to match the desired "device noise" magnitude.
        # score_raw is a dot-product similarity matrix, so its typical scale
        # depends on the embeddings.
        MAX_WIDTH = 5.0
        amplitude = (w / 100.0) * MAX_WIDTH * (s / 100.0)
        if dist == "gaussian":
            noise = torch.randn_like(score_raw) * amplitude
        elif dist == "uniform":
            noise = (torch.rand_like(score_raw) * 2.0 - 1.0) * amplitude
        elif dist == "laplace":
            # Inverse-CDF sampling of Laplace(0, amplitude). u is drawn from
            # [finfo.eps - 1, 1), as torch.distributions.Laplace does, so
            # 1 - |u| never reaches 0. Shifting a [0, 1) draw by -0.5 can hit
            # |u| == 0.5 exactly, which gives log1p(-1) = -inf.
            u = torch.empty_like(score_raw).uniform_(torch.finfo(score_raw.dtype).eps - 1.0, 1.0)
            noise = -amplitude * torch.sign(u) * torch.log1p(-u.abs())
        else:
            return score_raw
        return score_raw + noise

    def forward(
            self,
            tgt,
            memory,
            image_embedding,
            norm=False,
            hw_noise_dist: str = "none",
            hw_noise_width: int = 0,
            hw_noise_strength: int = 0,
            hdc_bits: int = 32):
        """Score visual-side embeddings (``tgt``) against text-side ones (``memory``).

        Args:
            tgt: Embeddings fed through ``vision_adapter``. NOTE(review): the
                decoder output must be 2-D for ``torch.mm`` below; presumably
                ``(num_items, d_model)`` — confirm.
            memory: Embeddings fed through ``text_adapter``.
            image_embedding: Fed through ``vision_adapter`` and used as key and
                value when ``glob_adapter`` attends from ``tgt``.
            norm: Unused; kept for call-site compatibility.
            hw_noise_dist: Noise distribution, forwarded to ``_apply_hw_noise``.
            hw_noise_width: Noise width percentage, forwarded to
                ``_apply_hw_noise``.
            hw_noise_strength: Noise strength percentage, forwarded to
                ``_apply_hw_noise``. With the defaults no noise is added.
            hdc_bits: Passed as ``quant_bits`` when the score function is
                ``'HDC'``; ignored otherwise.

        Returns:
            ``(tgt_new, memory_new, score_res, score_raw)``. ``score_raw`` is
            ``tgt_new @ memory_new.T`` after noise and the optional ``Norm``
            rescaling. ``score_res`` is the score function applied to it.
        """
        if self.norm_before:
            # L2-normalise out of place so the caller's tensors are left
            # untouched and leaf tensors that require grad do not raise.
            # NOTE(review): image_embedding is not normalised — confirm intended.
            tgt = tgt / tgt.norm(dim=-1, keepdim=True)
            memory = memory / memory.norm(dim=-1, keepdim=True)
        # Residual blends: ratio_* of the adapted features, the rest original.
        adapted_tgt = self.vision_adapter(tgt)
        tgt = self.ratio_image * adapted_tgt + (1 - self.ratio_image) * tgt
        adapted_img = self.vision_adapter(image_embedding)
        image_embedding_temp = self.ratio_image * adapted_img + (1 - self.ratio_image) * image_embedding
        # NOTE(review): unlike the vision branch, the text adapter output is
        # indexed with [0]. If Adapter returns a tensor, this selects index 0
        # of its first dimension — confirm against the shape of `memory`.
        adapted_mem = self.text_adapter(memory)[0]
        memory = self.ratio_text * adapted_mem + (1 - self.ratio_text) * memory
        # MultiheadAttention returns (output, weights); only the output is blended.
        attended = self.glob_adapter(tgt, image_embedding_temp, image_embedding_temp)[0]
        tgt = self.ratio_glob * attended + (1 - self.ratio_glob) * tgt
        tgt_new, memory_new = self.decoder(tgt, memory, None)
        score_raw = torch.mm(tgt_new, memory_new.T)
        # Optional synthetic hardware noise on the raw similarities.
        score_raw = self._apply_hw_noise(score_raw, hw_noise_dist, hw_noise_width, hw_noise_strength)
        if self.norm_after:
            score_raw = self.Norm(score_raw)
        if self.score_function_name == 'HDC':
            score_res = self.ScoreFunction(score_raw, quant_bits=hdc_bits)
        else:
            score_res = self.ScoreFunction(score_raw)
        return tgt_new, memory_new, score_res, score_raw

    def Norm(self, score):
        """Min-max rescale ``score`` into [MIN_Val, MAX_Val].

        The min and max are taken over the whole tensor. A constant tensor
        (zero span) maps to ``MIN_Val`` everywhere instead of NaN.
        """
        min_val = score.min()
        span = score.max() - min_val
        # Replace a zero span by 1 (numerator is then 0); torch.where avoids a
        # Python-side branch on a tensor value.
        span = torch.where(span > 0, span, torch.ones_like(span))
        return self.MIN_Val + ((score - min_val) * (self.MAX_Val - self.MIN_Val)) / span