from __future__ import absolute_import, print_function

import math

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter


class MemoryUnit(nn.Module):
    """Memory bank of ``mem_dim`` learned slots addressed by attention.

    Each input row is scored against every memory slot (``F.linear`` with
    ``self.weight``). The logits in a window of slots chosen from
    ``period_score`` are boosted. The logits are then turned into attention
    weights with a softmax and optionally sparsified with
    :func:`hard_shrink_relu` plus L1 renormalisation. Finally they are used
    to read back a weighted sum of the memory slots.

    Args:
        mem_dim: Number of memory slots (M).
        fea_dim: Feature size of each slot and of each input row (C).
        shrink_thres: Threshold for the hard-shrink sparsification of the
            attention weights; a value ``<= 0`` disables it.
    """

    # Divisor mapping the argmax index of ``period_score`` onto a memory slot
    # via floor(index / PERIOD_INDEX_SCALE * mem_dim). Presumably the largest
    # index period_score can produce along dim 1 -- TODO confirm with caller.
    PERIOD_INDEX_SCALE = 126
    # The boosted window covers slots [centre - WINDOW_BEFORE,
    # centre + WINDOW_AFTER), clipped to [0, mem_dim).
    WINDOW_BEFORE = 7
    WINDOW_AFTER = 8

    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
        super(MemoryUnit, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        # One row per memory slot: M x C (filled in by reset_parameters).
        self.weight = Parameter(torch.empty(self.mem_dim, self.fea_dim))
        self.bias = None
        self.shrink_thres = shrink_thres
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the memory uniformly in [-1/sqrt(C), 1/sqrt(C)]."""
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, period_score):
        """Read from memory for every row of ``input``.

        Args:
            input: 2-D tensor of shape (T, fea_dim), e.g. the flattened
                feature map produced by :class:`MemModule`.
            period_score: tensor reduced with ``torch.max(..., dim=1)``;
                presumably (batch, num_period_candidates) -- TODO confirm.

        Returns:
            dict with ``'output'`` of shape (T, fea_dim), the memory
            read-out, and ``'att'`` of shape (T, mem_dim), the attention
            weights used to produce it.
        """
        score, indices = torch.max(period_score, dim=1)
        # Slot-similarity logits: (T x C) x (C x M) = T x M.
        att_weight = F.linear(input, self.weight)

        # NOTE: only the first batch element's period is used, and its boost
        # is applied to every row, including rows of other batch elements.
        # TODO: apply each sample's own period to its own rows.
        if len(indices) > 0:
            # Same float ops as mapping the whole batch, but only the single
            # index actually used is transferred to the host.
            centre = int(torch.floor(
                indices[0] / self.PERIOD_INDEX_SCALE * self.mem_dim).item())
            start_idx = max(0, centre - self.WINDOW_BEFORE)
            end_idx = min(self.mem_dim, centre + self.WINDOW_AFTER)
            if start_idx < end_idx:
                boost = score[0].item()
                window = att_weight[:, start_idx:end_idx]
                # Scale the window's logits by (1 + boost). The right-hand
                # side is fully built before the in-place slice assignment.
                att_weight[:, start_idx:end_idx] = window + window.clone() * boost

        att_weight = F.softmax(att_weight, dim=1)  # T x M

        if self.shrink_thres > 0:
            # Zero out small attention weights, then rescale each row to unit
            # L1 norm (entries are non-negative, so rows sum to 1 again).
            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
            att_weight = F.normalize(att_weight, p=1, dim=1)

        # (T x M) x (M x C) = T x C
        output = F.linear(att_weight, self.weight.permute(1, 0))
        return {'output': output, 'att': att_weight}

    def extra_repr(self):
        return 'mem_dim={}, fea_dim={}'.format(self.mem_dim, self.fea_dim)


class MemModule(nn.Module):
    """Apply a :class:`MemoryUnit` at every position of a channels-first map.

    An input of shape (N, C, *spatial) with at least one spatial dimension is
    flattened to (N * prod(spatial), C), passed through the memory, and
    reshaped back. The attention map is returned as (N, mem_dim, *spatial).

    Args:
        mem_dim: Number of memory slots.
        fea_dim: Channel count C of the input feature maps.
        shrink_thres: Forwarded to :class:`MemoryUnit`.
        device: Accepted for backward compatibility; not used by this module.
    """

    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025, device='cuda'):
        super(MemModule, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.shrink_thres = shrink_thres
        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)

    def forward(self, input, period_score):
        """Return ``{'output': (N, C, *S), 'att': (N, mem_dim, *S)}``.

        Raises:
            ValueError: if ``input`` has fewer than 3 dimensions.
        """
        shape = input.shape
        ndim = len(shape)
        if ndim < 3:
            raise ValueError(
                'wrong feature map size: expected (N, C, *spatial) with at '
                'least one spatial dim, got shape {}'.format(tuple(shape)))

        # (N, C, *S) -> (N, *S, C) -> (N * prod(S), C)
        to_channels_last = (0,) + tuple(range(2, ndim)) + (1,)
        flat = input.permute(*to_channels_last).contiguous().view(-1, shape[1])

        read = self.memory(flat, period_score)

        # (N * prod(S), D) -> (N, *S, D) -> (N, D, *S)
        to_channels_first = (0, ndim - 1) + tuple(range(1, ndim - 1))
        spatial = tuple(shape[2:])
        y = read['output'].view(shape[0], *spatial, shape[1])
        y = y.permute(*to_channels_first)
        att = read['att'].view(shape[0], *spatial, self.mem_dim)
        att = att.permute(*to_channels_first)
        return {'output': y, 'att': att}


def hard_shrink_relu(input, lambd=0, epsilon=1e-12):
    """ReLU-based hard shrinkage, intended for positive inputs.

    Computes ``relu(input - lambd) * input / (|input - lambd| + epsilon)``:
    entries at or below ``lambd`` become 0, and entries above it come out
    approximately unchanged (up to the ``epsilon`` in the denominator).
    """
    output = (F.relu(input - lambd) * input) / (torch.abs(input - lambd) + epsilon)
    return output