Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import copy | |
| import torch.nn.functional as F | |
| from torch.nn import Parameter | |
| def _fake_quant_sym(x: torch.Tensor, bits: int, eps: float = 1e-8) -> torch.Tensor: | |
| bits = int(bits) | |
| if bits >= 32: | |
| return x | |
| if bits == 16: | |
| # pick fp16; if you prefer bf16: x.to(torch.bfloat16).to(torch.float32) | |
| return x.to(torch.float16).to(torch.float32) | |
| if bits == 1: | |
| return torch.sign(x) | |
| # signed symmetric levels: [-Qmax, Qmax] | |
| Qmax = (1 << (bits - 1)) - 1 | |
| # per-row scale (last dim); works for both (N,d) and (...,d) | |
| max_abs = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) | |
| scale = max_abs / Qmax | |
| q = torch.round(x / scale).clamp(-Qmax, Qmax) | |
| return (q * scale).to(x.dtype) | |
| def qlinear(x: torch.Tensor, layer: torch.nn.Linear, bits: int) -> torch.Tensor: | |
| """Quantize BOTH activation and weight, then do linear in float.""" | |
| if int(bits) >= 32: | |
| return layer(x) | |
| if int(bits) == 16: | |
| # do true fp16 compute-ish (still uses PyTorch kernels) | |
| x16 = x.to(torch.float16) | |
| w16 = layer.weight.to(torch.float16) | |
| b16 = None if layer.bias is None else layer.bias.to(torch.float16) | |
| y16 = F.linear(x16, w16, b16) | |
| return y16.to(torch.float32) | |
| xq = _fake_quant_sym(x, bits) | |
| wq = _fake_quant_sym(layer.weight, bits) | |
| b = layer.bias # keep bias float (common & stable) | |
| y = F.linear(xq, wq, b) | |
| return _fake_quant_sym(y, bits) | |
class HDReason(torch.nn.Module):
    """Hyperdimensional graph-reasoning block.

    Builds a soft adjacency matrix from HDC-encoded query/key
    projections of the input, aggregates HDC-encoded value vectors
    through it (GrapHD-style memorization), and blends the result with
    a linear skip branch (0.3 attention / 0.7 skip). Every linear layer
    can be run fake-quantized via ``qlinear``.
    """

    def __init__(self, d=10, D=256):
        """
        Args:
            d: Input feature dimension (e.g. words per item).
            D: Hypervector dimension produced by the HDC encoder.
        """
        super().__init__()
        self.d = d
        self.D = D
        self.vertex_d = 64  # intermediate vertex-embedding width
        self.q_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.k_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.v_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.HDC_encoder = torch.nn.Linear(self.vertex_d, self.D)
        # Freeze the HDC encoder as a fixed projection.
        # BUG FIX: the original wrote ``self.HDC_encoder.requires_grad = False``,
        # which merely attaches a plain attribute to the Module and leaves its
        # parameters trainable; ``Module.requires_grad_`` actually freezes them.
        self.HDC_encoder.requires_grad_(False)
        self.Linear = torch.nn.Linear(d, self.vertex_d)
        self.scale = self.D ** -0.5  # attention temperature, 1/sqrt(D)
        # TODO: may need to change the activation choice.
        self.activation0 = torch.nn.ReLU()
        self.activation1 = torch.nn.ReLU()

    def forward(self, x, quant_bits: int = 32):
        """Forward pass with optional fake quantization.

        Args:
            x: Input of shape ``(..., N, d)`` — presumably N nodes with
                d features each; confirm against the caller.
            quant_bits: Bit-width for weights/activations; 32 disables
                quantization entirely inside ``qlinear``.

        Returns:
            Tensor of shape ``(..., N, D)``.
        """
        b = int(quant_bits)
        # q path: project -> ReLU -> HDC encode -> ReLU
        q = qlinear(x, self.q_proj, b)
        q = self.activation0(q)
        q = qlinear(q, self.HDC_encoder, b)
        q = self.activation1(q)
        # k path (same pipeline as q)
        k = qlinear(x, self.k_proj, b)
        k = self.activation0(k)
        k = qlinear(k, self.HDC_encoder, b)
        k = self.activation1(k)
        q = _fake_quant_sym(q * self.scale, b)
        k = _fake_quant_sym(k, b)
        # Soft adjacency: scaled q·kᵀ, then softmax over the last dim.
        adj = _fake_quant_sym(q @ k.transpose(-2, -1), b)
        # Softmax is sensitive at low bit-widths; keep it in fp32 but
        # quantize its output.
        adj = adj.softmax(dim=-1)
        adj = _fake_quant_sym(adj, b)
        # v path: vertex hypervectors
        v = qlinear(x, self.v_proj, b)
        v = self.activation0(v)
        v = qlinear(v, self.HDC_encoder, b)
        v = self.activation1(v)
        v = _fake_quant_sym(v, b)
        # GrapHD memorization: aggregate values through the adjacency.
        out = _fake_quant_sym(adj @ v, b)
        # Skip/mix branch through the shared (frozen) HDC encoder.
        base = qlinear(x, self.Linear, b)
        base = self.activation0(base)
        base = qlinear(base, self.HDC_encoder, b)
        out = _fake_quant_sym(out * 0.3 + 0.7 * base, b)
        return out
class ScoreFunctionHDC(torch.nn.Module):
    """Score head over an HDC reasoning backbone.

    Runs ``HDReason`` on an ``(N_bbox, N_words)`` input, normalizes,
    then applies a three-layer MLP (D -> D/2 -> D/8 -> 1) with an extra
    learnable scalar bias and a final sigmoid. Supports optional fake
    quantization throughout via ``quant_bits``.
    """

    def __init__(self, N_words=20, HDV_D=512) -> None:
        """
        Args:
            N_words: Input feature dimension per item.
            HDV_D: Hypervector dimension used by the backbone and head.
        """
        super().__init__()
        self.D = HDV_D
        self.N_words = N_words
        self.norm = torch.nn.LayerNorm(self.D)
        self.HDReason = HDReason(d=self.N_words, D=self.D)
        self.Linear2 = torch.nn.Linear(self.D, self.D // 2)
        self.Linear3 = torch.nn.Linear(self.D // 2, self.D // 8)
        self.Linear4 = torch.nn.Linear(self.D // 8, 1)
        self.Activation1 = torch.nn.ReLU()
        self.Activation2 = torch.nn.Sigmoid()
        self.register_parameter('bias', Parameter(torch.zeros(1)))

    def forward(self, x, quant_bits: int = 32):
        """Forward pass; ``quant_bits`` selects the fake-quant width."""
        nbits = int(quant_bits)

        def requant(t):
            # Re-quantize activations only at sub-fp16 widths; at 16 and
            # 32 bits ``qlinear`` already handles precision in-kernel.
            return _fake_quant_sym(t, nbits) if nbits < 16 else t

        # Input activation quant (optional but keeps the path consistent).
        if nbits < 32:
            x = _fake_quant_sym(x, nbits)
        h = self.HDReason(x, quant_bits=nbits)
        h = self.norm(h)  # LayerNorm usually best left in fp32
        h = requant(self.Activation1(h))
        h = requant(self.Activation1(qlinear(h, self.Linear2, nbits)))
        h = requant(self.Activation1(qlinear(h, self.Linear3, nbits)))
        h = qlinear(h, self.Linear4, nbits) + self.bias
        return self.Activation2(h)