import copy

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Parameter


class HDReason(torch.nn.Module):
    """Attention-style graph reasoning block projected into hyperdimensional space.

    Three linear projections (query / key / value) map the ``d``-sized last
    input dimension to ``vertex_d`` features. A shared ``HDC_encoder`` (frozen
    at construction) then lifts those features to ``D``-dimensional
    hypervectors. The scaled query/key similarities are softmax-normalised
    into an adjacency matrix, which aggregates the value hypervectors. The
    aggregate is blended 0.3 / 0.7 with a residual encoding of the input.

    Args:
        d: Size of the input's last dimension.
        D: Hypervector dimensionality (size of the output's last dimension).
        vertex_d: Width of the intermediate projections fed to the encoder.
    """

    def __init__(self, d=10, D=256, vertex_d=64):
        super().__init__()
        self.d = d
        self.D = D
        self.vertex_d = vertex_d
        self.q_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.k_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.v_proj = torch.nn.Linear(self.d, self.vertex_d)
        self.HDC_encoder = torch.nn.Linear(self.vertex_d, self.D)
        # Freeze the encoder's weight and bias. Assigning
        # ``module.requires_grad = False`` merely sets an unused Python
        # attribute on the Module and leaves its Parameters trainable;
        # ``requires_grad_`` is the API that flips the flag on every parameter.
        self.HDC_encoder.requires_grad_(False)
        self.Linear = torch.nn.Linear(self.d, self.vertex_d)
        # Scaled dot-product factor for the D-dimensional q/k hypervectors.
        # TODO: may need to change.
        self.scale = self.D ** -0.5
        self.activation0 = torch.nn.ReLU()
        self.activation1 = torch.nn.ReLU()

    def _hypervector(self, proj, x):
        """Project ``x`` with ``proj``, apply ReLU, encode with ``HDC_encoder``, apply ReLU."""
        return self.activation1(self.HDC_encoder(self.activation0(proj(x))))

    def forward(self, x):
        """Run graph reasoning over the rows of ``x``.

        Args:
            x: Tensor whose last dimension has size ``d`` and which has at
                least two dimensions. ``transpose(-2, -1)`` treats the
                second-to-last axis as the set of graph vertices.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is ``D``.
        """
        # NOTE: build adjacency graph from scaled query/key similarities.
        q = self._hypervector(self.q_proj, x) * self.scale
        k = self._hypervector(self.k_proj, x)
        adj = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        # NOTE: vertex hypervectors.
        v = self._hypervector(self.v_proj, x)
        # NOTE: GrapHD memorization -- aggregate neighbour hypervectors, then
        # blend with a residual encoding of the input. The residual branch
        # intentionally skips the final ReLU, matching the original model.
        out = adj @ v
        residual = self.HDC_encoder(self.activation0(self.Linear(x)))
        return out * 0.3 + 0.7 * residual


class ScoreFunctionHDC(torch.nn.Module):
    """Score each input row with ``HDReason`` followed by a small MLP head.

    Args:
        N_words: Size of the input's last dimension (passed to ``HDReason`` as ``d``).
        HDV_D: Hypervector dimensionality used by ``HDReason`` and the head.
    """

    def __init__(self, N_words=20, HDV_D=512) -> None:
        super().__init__()
        self.D = HDV_D
        self.N_words = N_words
        self.norm = torch.nn.LayerNorm(self.D)
        self.HDReason = HDReason(d=self.N_words, D=self.D)
        # MLP head narrowing D -> D/2 -> D/8 -> 1.
        self.Linear2 = torch.nn.Linear(self.D, self.D // 2)
        self.Linear3 = torch.nn.Linear(self.D // 2, self.D // 8)
        self.Linear4 = torch.nn.Linear(self.D // 8, 1)
        self.Activation1 = torch.nn.ReLU()
        self.Activation2 = torch.nn.Sigmoid()
        # Extra learnable scalar offset added on top of Linear4's own bias.
        self.register_parameter('bias', Parameter(torch.zeros(1)))

    def forward(self, x):
        """Return per-row scores in (0, 1).

        Args:
            x: Tensor whose last dimension has size ``N_words``. The original
                notes describe it as N_bbox x N_word.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by 1,
            passed through a sigmoid.
        """
        output = self.HDReason(x)
        output = self.Activation1(self.norm(output))
        output = self.Activation1(self.Linear2(output))
        output = self.Activation1(self.Linear3(output))
        output = self.Linear4(output) + self.bias
        return self.Activation2(output)