Spaces:
Paused
Paused
| import torch | |
| import numpy as np | |
| import copy | |
| import torch.nn.functional as F | |
| from torch.nn import Parameter | |
class HDReason(torch.nn.Module):
    """Graph-style reasoning layer built around a shared, frozen HDC encoder.

    Every input row is projected to a ``vertex_d``-dim embedding along four
    paths (query, key, value, skip). Each embedding is lifted to ``D`` dims by
    the shared ``HDC_encoder`` linear layer, whose parameters are frozen.
    A softmax adjacency computed from query/key hypervectors aggregates the
    value hypervectors. The result is blended with the skip path: 30% graph,
    70% skip.

    Args:
        d: Size of the last (feature) dimension of the input.
        D: Hypervector dimensionality, i.e. the size of the output's last
            dimension.
        vertex_d: Width of the intermediate vertex embedding. Defaults to 64,
            the value that was previously hard-coded.
    """

    def __init__(self, d=10, D=256, vertex_d=64):
        super().__init__()
        self.d = d
        self.D = D
        self.vertex_d = vertex_d
        # Submodules are created in the original order so that seeded
        # parameter initialization is unchanged.
        self.q_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.k_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.v_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.HDC_encoder = torch.nn.Linear(self.vertex_d, self.D)
        # Freeze the HDC projection. Assigning `module.requires_grad = False`
        # only creates a plain attribute and leaves the weights trainable;
        # `requires_grad_` clears the flag on every parameter of the module.
        self.HDC_encoder.requires_grad_(False)
        self.Linear = torch.nn.Linear(self.d, self.vertex_d)
        # Attention temperature 1/sqrt(D): q and k are D-dimensional.
        self.scale = self.D ** -0.5
        # TODO: the choice of activations may need to change.
        self.activation0 = torch.nn.ReLU()
        self.activation1 = torch.nn.ReLU()

    def _hypervector(self, proj, x):
        """Project ``x`` with ``proj``, then lift it through the HDC encoder (ReLU on both sides)."""
        return self.activation1(self.HDC_encoder(self.activation0(proj(x))))

    def forward(self, x):
        """Return ``(..., N, D)`` hypervectors for an input of shape ``(..., N, d)``.

        Assumes ``x`` has at least two dimensions (rows x features). The
        adjacency is an ``N x N`` row-wise softmax over the second-to-last axis.
        """
        # Build the adjacency graph from scaled query and key hypervectors.
        q = self._hypervector(self.q_proj, x) * self.scale
        k = self._hypervector(self.k_proj, x)
        adj = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        # Vertex hypervectors.
        v = self._hypervector(self.v_proj, x)
        # GrapHD memorization: aggregate the neighbours, then blend with the
        # skip path. The skip path has no trailing activation.
        out = adj @ v
        skip = self.HDC_encoder(self.activation0(self.Linear(x)))
        return out * 0.3 + 0.7 * skip
class ScoreFunctionHDC(torch.nn.Module):
    """Scoring head: HDReason hypervectors -> LayerNorm -> MLP -> sigmoid score.

    Args:
        N_words: Feature size of each input row (passed to ``HDReason`` as ``d``).
        HDV_D: Hypervector dimensionality shared by ``HDReason``, the
            LayerNorm and the first MLP layer.
    """

    def __init__(self, N_words=20, HDV_D=512) -> None:
        super().__init__()
        self.D = HDV_D
        self.N_words = N_words
        width = self.D
        self.norm = torch.nn.LayerNorm(width)
        self.HDReason = HDReason(d=self.N_words, D=width)
        # Funnel-shaped MLP: D -> D/2 -> D/8 -> 1.
        self.Linear2 = torch.nn.Linear(width, width // 2)
        self.Linear3 = torch.nn.Linear(width // 2, width // 8)
        self.Linear4 = torch.nn.Linear(width // 8, 1)
        self.Activation1 = torch.nn.ReLU()
        self.Activation2 = torch.nn.Sigmoid()
        # Extra learnable scalar added to the final logit, in addition to
        # Linear4's own bias.
        self.bias = Parameter(torch.zeros(1))

    def forward(self, x):
        """Map input rows to sigmoid scores with a trailing size-1 dimension.

        Per the original notes, the input is N_bbox x N_word.
        """
        hidden = self.Activation1(self.norm(self.HDReason(x)))
        # The two hidden MLP layers share the same ReLU module.
        for layer in (self.Linear2, self.Linear3):
            hidden = self.Activation1(layer(hidden))
        logit = self.Linear4(hidden) + self.bias
        return self.Activation2(logit)