import torch import numpy as np import copy import torch.nn.functional as F from torch.nn import Parameter def _fake_quant_sym(x: torch.Tensor, bits: int, eps: float = 1e-8) -> torch.Tensor: bits = int(bits) if bits >= 32: return x if bits == 16: # pick fp16; if you prefer bf16: x.to(torch.bfloat16).to(torch.float32) return x.to(torch.float16).to(torch.float32) if bits == 1: return torch.sign(x) # signed symmetric levels: [-Qmax, Qmax] Qmax = (1 << (bits - 1)) - 1 # per-row scale (last dim); works for both (N,d) and (...,d) max_abs = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) scale = max_abs / Qmax q = torch.round(x / scale).clamp(-Qmax, Qmax) return (q * scale).to(x.dtype) def qlinear(x: torch.Tensor, layer: torch.nn.Linear, bits: int) -> torch.Tensor: """Quantize BOTH activation and weight, then do linear in float.""" if int(bits) >= 32: return layer(x) if int(bits) == 16: # do true fp16 compute-ish (still uses PyTorch kernels) x16 = x.to(torch.float16) w16 = layer.weight.to(torch.float16) b16 = None if layer.bias is None else layer.bias.to(torch.float16) y16 = F.linear(x16, w16, b16) return y16.to(torch.float32) xq = _fake_quant_sym(x, bits) wq = _fake_quant_sym(layer.weight, bits) b = layer.bias # keep bias float (common & stable) y = F.linear(xq, wq, b) return _fake_quant_sym(y, bits) class HDReason(torch.nn.Module): def __init__(self, d=10, D=256): super().__init__() self.d = d self.D = D self.vertex_d = 64 self.q_proj = torch.nn.Linear(self.d, self.vertex_d) self.k_proj = torch.nn.Linear(self.d, self.vertex_d) self.v_proj = torch.nn.Linear(self.d, self.vertex_d) self.HDC_encoder = torch.nn.Linear(self.vertex_d, self.D) self.HDC_encoder.requires_grad = False self.Linear = torch.nn.Linear(d, self.vertex_d) self.scale = self.D ** -0.5 #TODO: May need to chaneg self.activation0 = torch.nn.ReLU() self.activation1 = torch.nn.ReLU() """ def forward(self, x): #NOTE: build adjacency graph q = self.activation1(self.HDC_encoder(self.activation0(self.q_proj(x)))) k = self.activation1(self.HDC_encoder(self.activation0(self.k_proj(x)))) q = q * self.scale adj = q @ k.transpose(-2, -1) adj = adj.softmax(dim=-1) #NOTE: vertex hypervector v = self.activation1(self.HDC_encoder(self.activation0(self.v_proj(x)))) #NOTE: GrapHD memorization out = adj @ v out = out*0.3 + 0.7*self.HDC_encoder(self.activation0(self.Linear(x))) return out """ def forward(self, x, quant_bits: int = 32): b = int(quant_bits) # q path q = qlinear(x, self.q_proj, b) q = self.activation0(q) q = qlinear(q, self.HDC_encoder, b) q = self.activation1(q) # k path k = qlinear(x, self.k_proj, b) k = self.activation0(k) k = qlinear(k, self.HDC_encoder, b) k = self.activation1(k) q = _fake_quant_sym(q * self.scale, b) k = _fake_quant_sym(k, b) # adj matmul + softmax adj = _fake_quant_sym(q @ k.transpose(-2, -1), b) # softmax is sensitive at low-bit; keep it in fp32 but quantize output adj = adj.softmax(dim=-1) adj = _fake_quant_sym(adj, b) # v path v = qlinear(x, self.v_proj, b) v = self.activation0(v) v = qlinear(v, self.HDC_encoder, b) v = self.activation1(v) v = _fake_quant_sym(v, b) out = _fake_quant_sym(adj @ v, b) # skip/mix branch base = qlinear(x, self.Linear, b) base = self.activation0(base) base = qlinear(base, self.HDC_encoder, b) out = _fake_quant_sym(out * 0.3 + 0.7 * base, b) return out class ScoreFunctionHDC(torch.nn.Module): def __init__(self, N_words=20, HDV_D=512) -> None: super().__init__() self.D = HDV_D self.N_words = N_words self.norm = torch.nn.LayerNorm(self.D) self.HDReason = HDReason(d=self.N_words, D=self.D) self.Linear2 = torch.nn.Linear(self.D, self.D // 2) self.Linear3 = torch.nn.Linear(self.D // 2, self.D // 8) self.Linear4 = torch.nn.Linear(self.D // 8, 1) self.Activation1 = torch.nn.ReLU() self.Activation2 = torch.nn.Sigmoid() self.register_parameter('bias',Parameter(torch.zeros(1))) """ def forward(self, x): #NOTE: input has shape NxN_word #NOTE: N_bbox x N_word output = self.HDReason(x) output = self.norm(output) output = self.Activation1(output) output = self.Linear2(output) output = self.Activation1(output) output = self.Linear3(output) output = self.Activation1(output) output = self.Linear4(output) + self.bias output = self.Activation2(output) return output """ def forward(self, x, quant_bits: int = 32): b = int(quant_bits) # input activation quant (optional but consistent) if b < 32: x = _fake_quant_sym(x, b) output = self.HDReason(x, quant_bits=b) output = self.norm(output) # LayerNorm usually best left fp32 output = self.Activation1(output) if b < 16: output = _fake_quant_sym(output, b) output = qlinear(output, self.Linear2, b) output = self.Activation1(output) if b < 16: output = _fake_quant_sym(output, b) output = qlinear(output, self.Linear3, b) output = self.Activation1(output) if b < 16: output = _fake_quant_sym(output, b) output = qlinear(output, self.Linear4, b) + self.bias output = self.Activation2(output) return output