# TaskCLIP / models / ScoreFunction_HDC.py
# Author: HanningChen
# Last change: "Add noise" (commit 398e700)
import torch
import numpy as np
import copy
import torch.nn.functional as F
from torch.nn import Parameter
def _fake_quant_sym(x: torch.Tensor, bits: int, eps: float = 1e-8) -> torch.Tensor:
bits = int(bits)
if bits >= 32:
return x
if bits == 16:
# pick fp16; if you prefer bf16: x.to(torch.bfloat16).to(torch.float32)
return x.to(torch.float16).to(torch.float32)
if bits == 1:
return torch.sign(x)
# signed symmetric levels: [-Qmax, Qmax]
Qmax = (1 << (bits - 1)) - 1
# per-row scale (last dim); works for both (N,d) and (...,d)
max_abs = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
scale = max_abs / Qmax
q = torch.round(x / scale).clamp(-Qmax, Qmax)
return (q * scale).to(x.dtype)
def qlinear(x: torch.Tensor, layer: torch.nn.Linear, bits: int) -> torch.Tensor:
"""Quantize BOTH activation and weight, then do linear in float."""
if int(bits) >= 32:
return layer(x)
if int(bits) == 16:
# do true fp16 compute-ish (still uses PyTorch kernels)
x16 = x.to(torch.float16)
w16 = layer.weight.to(torch.float16)
b16 = None if layer.bias is None else layer.bias.to(torch.float16)
y16 = F.linear(x16, w16, b16)
return y16.to(torch.float32)
xq = _fake_quant_sym(x, bits)
wq = _fake_quant_sym(layer.weight, bits)
b = layer.bias # keep bias float (common & stable)
y = F.linear(xq, wq, b)
return _fake_quant_sym(y, bits)
class HDReason(torch.nn.Module):
    """Graph-style hyperdimensional reasoning block.

    Projects the input into q/k/v vertex embeddings, lifts them into a
    D-dimensional hypervector space with a frozen random encoder, builds a
    soft adjacency matrix via scaled dot-product attention, and blends the
    graph aggregation with a skip branch (0.3 / 0.7 mix).  ``forward``
    supports fake quantization of activations and weights to
    ``quant_bits`` bits.
    """

    def __init__(self, d=10, D=256):
        """
        Args:
            d: per-vertex input feature size (number of words).
            D: hypervector dimensionality.
        """
        super().__init__()
        self.d = d
        self.D = D
        self.vertex_d = 64
        self.q_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.k_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.v_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.HDC_encoder = torch.nn.Linear(self.vertex_d, self.D)
        # BUGFIX: the original assigned ``self.HDC_encoder.requires_grad
        # = False``, which only sets a plain Python attribute on the
        # Module and leaves its Parameters trainable.  requires_grad_()
        # actually freezes the encoder's weight and bias.
        self.HDC_encoder.requires_grad_(False)
        self.Linear = torch.nn.Linear(d, self.vertex_d)
        # NOTE(review): scale uses D (the encoder output dim), which is
        # the feature dim of q/k at the matmul in forward().
        self.scale = self.D ** -0.5
        # TODO: may need to change the activation choice.
        self.activation0 = torch.nn.ReLU()
        self.activation1 = torch.nn.ReLU()

    def forward(self, x, quant_bits: int = 32):
        """Run the reasoning block with optional fake quantization.

        Args:
            x: input tensor whose last dim is ``d``; the second-to-last
               dim indexes vertices (bboxes).
            quant_bits: bit-width for fake quantization; >= 32 disables it.

        Returns:
            Tensor with the same leading dims as ``x`` and last dim ``D``.
        """
        b = int(quant_bits)
        # q path: project -> ReLU -> HDC encode -> ReLU.
        q = self.activation0(qlinear(x, self.q_proj, b))
        q = self.activation1(qlinear(q, self.HDC_encoder, b))
        # k path (same pipeline as q).
        k = self.activation0(qlinear(x, self.k_proj, b))
        k = self.activation1(qlinear(k, self.HDC_encoder, b))
        q = _fake_quant_sym(q * self.scale, b)
        k = _fake_quant_sym(k, b)
        # Soft adjacency: scaled dot-product + softmax.  Softmax is
        # numerically sensitive at low bit-widths, so it runs in fp32 and
        # only its input/output are quantized.
        adj = _fake_quant_sym(q @ k.transpose(-2, -1), b)
        adj = adj.softmax(dim=-1)
        adj = _fake_quant_sym(adj, b)
        # v path: vertex hypervectors.
        v = self.activation0(qlinear(x, self.v_proj, b))
        v = self.activation1(qlinear(v, self.HDC_encoder, b))
        v = _fake_quant_sym(v, b)
        # GrapHD memorization: aggregate neighbours through the adjacency.
        out = _fake_quant_sym(adj @ v, b)
        # Skip/mix branch: fixed 0.3 / 0.7 blend with a direct encoding
        # of x (no second ReLU here, matching the original formulation).
        base = qlinear(self.activation0(qlinear(x, self.Linear, b)),
                       self.HDC_encoder, b)
        out = _fake_quant_sym(out * 0.3 + 0.7 * base, b)
        return out
class ScoreFunctionHDC(torch.nn.Module):
    """Scores bbox/word association rows with an HD reasoning backbone.

    Input of shape (N_bbox, N_words) is passed through :class:`HDReason`,
    layer-normalized, and reduced by a small MLP (D -> D/2 -> D/8 -> 1)
    to a sigmoid score in (0, 1) per row.  Supports fake quantization via
    ``quant_bits``.  (The old commented-out forward that was kept as a
    bare class-level string literal has been removed.)
    """

    def __init__(self, N_words=20, HDV_D=512) -> None:
        """
        Args:
            N_words: number of words per row (input feature size).
            HDV_D: hypervector dimensionality used by HDReason.
        """
        super().__init__()
        self.D = HDV_D
        self.N_words = N_words
        self.norm = torch.nn.LayerNorm(self.D)
        self.HDReason = HDReason(d=self.N_words, D=self.D)
        self.Linear2 = torch.nn.Linear(self.D, self.D // 2)
        self.Linear3 = torch.nn.Linear(self.D // 2, self.D // 8)
        self.Linear4 = torch.nn.Linear(self.D // 8, 1)
        self.Activation1 = torch.nn.ReLU()
        self.Activation2 = torch.nn.Sigmoid()
        # Extra learnable scalar bias added on top of Linear4's own bias.
        self.register_parameter('bias', Parameter(torch.zeros(1)))

    def forward(self, x, quant_bits: int = 32):
        """Compute per-row scores.

        Args:
            x: tensor of shape (..., N_words) — one row per bbox.
            quant_bits: bit-width for fake quantization; >= 32 disables it.

        Returns:
            Sigmoid scores with last dim 1.
        """
        b = int(quant_bits)
        # Quantize the raw input once so the whole path is consistent.
        if b < 32:
            x = _fake_quant_sym(x, b)
        output = self.HDReason(x, quant_bits=b)
        # LayerNorm stays in fp32: it is numerically fragile at low bit.
        output = self.norm(output)
        output = self.Activation1(output)
        # qlinear already quantizes its own output for b < 32, and ReLU
        # keeps quantized values on the grid, so the extra activation
        # re-quantization is only applied below fp16.
        if b < 16:
            output = _fake_quant_sym(output, b)
        output = qlinear(output, self.Linear2, b)
        output = self.Activation1(output)
        if b < 16:
            output = _fake_quant_sym(output, b)
        output = qlinear(output, self.Linear3, b)
        output = self.Activation1(output)
        if b < 16:
            output = _fake_quant_sym(output, b)
        output = qlinear(output, self.Linear4, b) + self.bias
        output = self.Activation2(output)
        return output