Spaces:
Sleeping
Sleeping
File size: 6,052 Bytes
f2f112a 398e700 f2f112a 398e700 f2f112a 398e700 f2f112a 398e700 f2f112a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import torch
from .Transformer import TransformerDecoderLayer
from .Transformer import TransformerCrossDecoderLayer
from .Transformer import TransformerDecoder
from .ScoreFunction import ScoreFunction
from .Adapter import Adapter
from .ScoreFunction_HDC import ScoreFunctionHDC
class TaskCLIP(torch.nn.Module):
    """Adapter + transformer-decoder head that scores ``tgt`` against ``memory``.

    ``tgt`` (and ``image_embedding``) pass through ``vision_adapter``;
    ``memory`` passes through ``text_adapter``. A transformer decoder then
    produces refined ``tgt``/``memory`` whose dot-product matrix is scored by
    ``ScoreFunction`` (or ``ScoreFunctionHDC`` when ``score_function == 'HDC'``).

    Args:
        model_config: mapping read for the keys ``num_layers``,
            ``return_intermediate``, ``d_model``, ``nhead``,
            ``dim_feedforward``, ``dropout``, ``N_words``, ``activation``,
            ``ratio_text``, ``ratio_image``, ``ratio_glob``, ``norm_before``,
            ``norm_after``, ``MAX_VAL``, ``MIN_VAL``, ``cross_attention``,
            ``score_function`` and (only for HDC) ``HDV_D``.
        normalize_before: forwarded to the decoder-layer constructor.
        device: stored on ``self.device``; this class does not use it to
            place tensors itself.
    """

    def __init__(self,
                 model_config,
                 normalize_before=False,
                 device="cuda:1") -> None:
        super().__init__()
        self.num_layers = model_config['num_layers']
        self.return_intermediate = model_config['return_intermediate']
        self.d_model = model_config['d_model']
        self.nhead = model_config['nhead']
        self.dim_feedforward = model_config['dim_feedforward']
        self.dropout = model_config['dropout']
        self.N_words = model_config['N_words']
        self.activation = model_config['activation']
        # Blend weights between adapter output and the un-adapted input.
        self.ratio_text = model_config['ratio_text']
        self.ratio_image = model_config['ratio_image']
        self.ratio_glob = model_config['ratio_glob']
        # norm_before: L2-normalize tgt/memory at the start of forward().
        # norm_after: min-max rescale the raw score matrix via Norm().
        self.norm_before = model_config['norm_before']
        self.norm_after = model_config['norm_after']
        self.MAX_Val = model_config['MAX_VAL']
        self.MIN_Val = model_config['MIN_VAL']
        self.normalize_before = normalize_before
        self.device = device
        self.decoder_norm = torch.nn.LayerNorm(self.d_model)
        if model_config['cross_attention']:
            self.decoder_layer = TransformerCrossDecoderLayer(self.d_model,
                                                              self.nhead,
                                                              self.dim_feedforward,
                                                              self.dropout,
                                                              self.activation,
                                                              self.normalize_before)
        else:
            self.decoder_layer = TransformerDecoderLayer(self.d_model,
                                                         self.nhead,
                                                         self.dim_feedforward,
                                                         self.dropout,
                                                         self.activation,
                                                         self.normalize_before)
        self.decoder = TransformerDecoder(self.decoder_layer,
                                          self.num_layers,
                                          self.decoder_norm,
                                          return_intermediate=self.return_intermediate)
        self.vision_adapter = Adapter(self.d_model)
        self.text_adapter = Adapter(self.d_model)
        self.glob_adapter = torch.nn.MultiheadAttention(self.d_model,
                                                        self.nhead,
                                                        dropout=self.dropout)
        self.score_function_name = model_config["score_function"]
        if model_config['score_function'] != 'HDC':
            self.ScoreFunction = ScoreFunction(N_words=self.N_words)
        else:
            self.ScoreFunction = ScoreFunctionHDC(N_words=self.N_words, HDV_D=int(model_config['HDV_D']))
        # Not read inside this class; presumably consumed by callers — verify.
        self.threshold = 0.1

    def _apply_hw_noise(self, score_raw: torch.Tensor, dist: str, width_0_100: int, strength_0_100: int) -> torch.Tensor:
        """Return ``score_raw`` plus synthetic "hardware" noise.

        Args:
            score_raw: floating-point score matrix to perturb.
            dist: ``"gaussian"``, ``"uniform"``, ``"laplace"`` or ``"none"``
                (case-insensitive; ``None`` means ``"none"``). Any other value
                returns ``score_raw`` unchanged.
            width_0_100: noise width, clamped to [0, 100].
            strength_0_100: noise strength, clamped to [0, 100].

        Returns:
            ``score_raw`` itself when no noise applies, otherwise a new tensor.
            The noise scale is ``(width/100) * MAX_WIDTH * (strength/100)``:
            the std for gaussian, the half-range for uniform, and the scale
            ``b`` for laplace.
        """
        dist = (dist or "none").lower()
        w = max(0, min(100, int(width_0_100)))
        s = max(0, min(100, int(strength_0_100)))
        if dist == "none" or w == 0 or s == 0:
            return score_raw
        # Tune this constant to match the desired "device noise" magnitude.
        # score_raw is a dot-product similarity matrix, so its typical scale
        # depends on the embeddings.
        MAX_WIDTH = 5.0
        base = (w / 100.0) * MAX_WIDTH
        scale = (s / 100.0)
        eps = base * scale
        if dist == "gaussian":
            noise = torch.randn_like(score_raw) * eps
        elif dist == "uniform":
            noise = (torch.rand_like(score_raw) * 2.0 - 1.0) * eps
        elif dist == "laplace":
            # Laplace(0, b): sample via inverse-CDF.
            u = torch.rand_like(score_raw) - 0.5
            # rand_like samples [0, 1), so |u| can be exactly 0.5 and
            # log1p(-1) = -inf. Cap 2|u| just below 1 to keep noise finite;
            # every other sample is unaffected.
            two_abs_u = (2.0 * torch.abs(u)).clamp_max(1.0 - torch.finfo(u.dtype).eps)
            noise = -eps * torch.sign(u) * torch.log1p(-two_abs_u)
        else:
            return score_raw
        return score_raw + noise

    def forward(
            self,
            tgt,
            memory,
            image_embedding,
            norm=False,
            hw_noise_dist: str = "none",
            hw_noise_width: int = 0,
            hw_noise_strength: int = 0,
            hdc_bits: int = 32):
        """Adapt, decode and score ``tgt`` against ``memory``.

        Args:
            tgt: vision-side embeddings; assumed 2-D (rows, d_model) since the
                result feeds ``torch.mm`` — TODO confirm.
            memory: text-side embeddings; assumed 2-D (rows, d_model).
            image_embedding: used as key/value of ``glob_adapter`` attention.
            norm: accepted for interface compatibility; not used here.
            hw_noise_dist, hw_noise_width, hw_noise_strength: forwarded to
                ``_apply_hw_noise``.
            hdc_bits: passed as ``quant_bits`` to the HDC score function only.

        Returns:
            ``(tgt_new, memory_new, score_res, score_raw)``.

        The caller's ``tgt``/``memory`` tensors are not modified.
        """
        if self.norm_before:
            # Out-of-place: in-place `/=` would overwrite the caller's tensors
            # and fails on leaf tensors that require grad.
            tgt = tgt / tgt.norm(dim=-1, keepdim=True)
            memory = memory / memory.norm(dim=-1, keepdim=True)
        x = self.vision_adapter(tgt)
        tgt = self.ratio_image * x + (1 - self.ratio_image) * tgt
        x0 = self.vision_adapter(image_embedding)
        image_embedding_temp = self.ratio_image * x0 + (1 - self.ratio_image) * image_embedding
        # NOTE(review): unlike the vision_adapter calls above, the text adapter
        # output is indexed with [0]. If Adapter returns a plain tensor, this
        # keeps only its first row and broadcasts it over every memory row —
        # confirm this is intended (Adapter is defined elsewhere).
        y = self.text_adapter(memory)[0]
        memory = self.ratio_text * y + (1 - self.ratio_text) * memory
        tgt = self.ratio_glob * self.glob_adapter(tgt, image_embedding_temp, image_embedding_temp)[0] + (1 - self.ratio_glob) * tgt
        tgt_new, memory_new = self.decoder(tgt, memory, None)
        score_raw = torch.mm(tgt_new, memory_new.T)
        # Optional synthetic hardware noise (no-op with the default arguments).
        score_raw = self._apply_hw_noise(score_raw, hw_noise_dist, hw_noise_width, hw_noise_strength)
        if self.norm_after:
            score_raw = self.Norm(score_raw)
        if self.score_function_name == 'HDC':
            score_res = self.ScoreFunction(score_raw, quant_bits=hdc_bits)
        else:
            score_res = self.ScoreFunction(score_raw)
        return tgt_new, memory_new, score_res, score_raw

    def Norm(self, score):
        """Linearly rescale ``score`` so its global min/max map to ``MIN_Val``/``MAX_Val``.

        A constant input (max == min) would divide 0 by 0 and give NaN. In that
        case every entry maps to ``MIN_Val``.
        """
        min_val = score.min()
        max_val = score.max()
        span = max_val - min_val
        # Use 1 in place of a zero span: the numerator is then 0 everywhere, so
        # the result is MIN_Val rather than NaN. torch.where avoids a host sync.
        safe_span = torch.where(span > 0, span, torch.ones_like(span))
        res = self.MIN_Val + ((score - min_val) * (self.MAX_Val - self.MIN_Val)) / safe_span
        return res