from __future__ import absolute_import, print_function
import torch
from torch import nn
import math
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import numpy as np

class MemoryUnit(nn.Module):
    """Memory bank of mem_dim slots of dimension fea_dim, addressed by attention."""
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
        super(MemoryUnit, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim))  # M x C
        self.bias = None
        self.shrink_thres = shrink_thres
        self.reset_parameters()

    def reset_parameters(self):
        # uniform initialization in [-1/sqrt(C), 1/sqrt(C)], matching nn.Linear
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, period_score):
        # input: T x C feature rows; period_score: B x P per-sample period scores
        # (P is assumed to be 126, matching the hard-coded divisor below)
        score, indices = torch.max(period_score, dim=1)
        # map each sample's dominant period bin to a memory slot index
        indices = torch.floor((indices.float() / 126) * self.mem_dim).cpu().numpy().astype(int)
        att_weight = F.linear(input, self.weight)  # Fea x Mem^T, (TxC) x (CxM) = TxM
        # Period-aware attention enhancement. The original code indexed an
        # undefined variable 'i'; as a fix, the first sample's period is used
        # for all rows. TODO: implement per-sample period enhancement.
        if len(indices) > 0:
            i = 0  # use the first sample's dominant period
            # clamp the 15-slot window around the period-matched slot to a valid range
            start_idx = max(0, indices[i] - 7)
            end_idx = min(self.mem_dim, indices[i] + 8)
            if start_idx < end_idx:
                # boost attention toward memory slots near the dominant period
                att_weight[:, start_idx:end_idx] = att_weight[:, start_idx:end_idx] * (1.0 + score[i].item())
        att_weight = F.softmax(att_weight, dim=1)  # TxM
        # ReLU-based hard shrinkage for positive values, then re-normalize so
        # the sparsified attention weights still sum to 1
        if self.shrink_thres > 0:
            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
            att_weight = F.normalize(att_weight, p=1, dim=1)
        mem_trans = self.weight.permute(1, 0)  # Mem^T, M x C
        output = F.linear(att_weight, mem_trans)  # (TxM) x (MxC) = TxC
        return {'output': output, 'att': att_weight}

    def extra_repr(self):
        # bugfix: the original formatted `self.fea_dim is not None` (a bool)
        return 'mem_dim={}, fea_dim={}'.format(self.mem_dim, self.fea_dim)
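
# Usage sketch for MemoryUnit on its own (shapes are illustrative assumptions,
# not taken from the original training configuration):
#   mem = MemoryUnit(mem_dim=100, fea_dim=128)
#   feats = torch.randn(64, 128)                        # T x C feature rows
#   pscore = torch.softmax(torch.randn(1, 126), dim=1)  # 126-bin period scores
#   out = mem(feats, pscore)
#   out['output']: (64, 128) memory-reconstructed features
#   out['att']:    (64, 100) sparse addressing weights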

# N x C x H x W -> (N*H*W) x C -> address the memory -> (N*H*W) x C -> N x C x H x W
class MemModule(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025, device='cuda'):
        # note: `device` is accepted for API compatibility but not used here
        super(MemModule, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.shrink_thres = shrink_thres
        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)

    def forward(self, input, period_score):
        s = input.data.shape
        l = len(s)
        # move the channel dim (dim 1) to the last position, then flatten to rows of C
        if l == 3:
            x = input.permute(0, 2, 1)
        elif l == 4:
            x = input.permute(0, 2, 3, 1)
        elif l == 5:
            x = input.permute(0, 2, 3, 4, 1)
        else:
            # the original printed 'wrong feature map size' and then crashed on a list
            raise ValueError('wrong feature map size: expected 3, 4 or 5 dims, got {}'.format(l))
        x = x.contiguous()
        x = x.view(-1, s[1])
| # | |
| y_and = self.memory(x,period_score) | |
| # | |
| y = y_and['output'] | |
| att = y_and['att'] | |
        # restore the original layout: channel dim back to position 1
        if l == 3:
            y = y.view(s[0], s[2], s[1])
            y = y.permute(0, 2, 1)
            att = att.view(s[0], s[2], self.mem_dim)
            att = att.permute(0, 2, 1)
        elif l == 4:
            y = y.view(s[0], s[2], s[3], s[1])
            y = y.permute(0, 3, 1, 2)
            att = att.view(s[0], s[2], s[3], self.mem_dim)
            att = att.permute(0, 3, 1, 2)
        elif l == 5:
            y = y.view(s[0], s[2], s[3], s[4], s[1])
            y = y.permute(0, 4, 1, 2, 3)
            att = att.view(s[0], s[2], s[3], s[4], self.mem_dim)
            att = att.permute(0, 4, 1, 2, 3)
        # no else branch needed: invalid sizes already raise above
        return {'output': y, 'att': att}
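
# Shape walkthrough for the 3-D case (illustrative assumption: N=2, C=128, T=64,
# mem_dim=100):
#   input N x C x T -> permute -> N x T x C -> view -> (N*T) x C
#   -> memory addressing -> output (N*T) x C, att (N*T) x 100
#   -> view/permute back -> output N x C x T, att N x 100 x T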

# ReLU-based hard shrinkage, intended for non-negative attention weights:
# entries above lambd pass through (approximately) unchanged, entries at or
# below lambd are driven to zero; epsilon guards against division by zero.
def hard_shrink_relu(input, lambd=0, epsilon=1e-12):
    output = (F.relu(input - lambd) * input) / (torch.abs(input - lambd) + epsilon)
    return output
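
# Minimal smoke test, runnable on CPU. All sizes here are assumptions chosen for
# illustration; the original repo's dimensions (aside from the hard-coded 126
# period bins in MemoryUnit.forward) may differ.
if __name__ == '__main__':
    mem_module = MemModule(mem_dim=100, fea_dim=128, shrink_thres=0.0025)
    feat = torch.randn(2, 128, 64)                      # N x C x T
    pscore = torch.softmax(torch.randn(2, 126), dim=1)  # N x 126 period bins
    res = mem_module(feat, pscore)
    print(res['output'].shape)  # torch.Size([2, 128, 64])
    print(res['att'].shape)     # torch.Size([2, 100, 64])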