import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class LinearSVDO(nn.Module):
    """Fully connected layer with sparse variational dropout.

    Each weight ``W[i, j]`` has a learned log standard deviation
    ``log_sigma[i, j]``. The derived quantity
    ``log_alpha = 2 * log_sigma - 2 * log|W|`` (clamped to ``[-10, 10]``)
    decides which weights are kept: a weight is kept when its ``log_alpha``
    is below ``alpha_threshold``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        alpha_threshold: Weights with ``log_alpha >= alpha_threshold`` are
            zeroed in evaluation mode and by ``get_pruned_weights``.
        theta_threshold: Minimum ``|W|`` for a weight to be counted by
            ``get_num_remained_weights``.
        device: Stored on the instance as ``self.device``; it is not used to
            place the parameters.
    """

    # Coefficients of the KL approximation used in ``kl_reg``.
    _K1 = 0.63576
    _K2 = 1.8732
    _K3 = 1.48695

    def __init__(self, in_features, out_features, alpha_threshold,
                 theta_threshold, device):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha_threshold = alpha_threshold
        self.theta_threshold = theta_threshold
        self.device = device

        self.W = Parameter(torch.empty(out_features, in_features))
        self.log_sigma = Parameter(torch.empty(out_features, in_features))
        # Bias keeps its (1, out_features) shape so existing checkpoints load.
        self.bias = Parameter(torch.empty(1, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise bias to 0, W ~ N(0, 0.02) and log_sigma to -5."""
        self.bias.data.zero_()
        self.W.data.normal_(0, 0.02)
        self.log_sigma.data.fill_(-5)

    def _compute_log_alpha(self):
        """Return the clamped ``log_alpha`` for the current parameters."""
        # 1e-16 guards against log(0) for weights that are exactly zero.
        log_alpha = self.log_sigma * 2.0 - 2.0 * torch.log(
            1e-16 + torch.abs(self.W)
        )
        return torch.clamp(log_alpha, -10, 10)

    def _keep_mask(self):
        """Return a boolean mask of weights with log_alpha below threshold."""
        return self._compute_log_alpha() < self.alpha_threshold

    def forward(self, x):
        """Apply the layer to ``x``.

        In training mode the output is sampled with the local
        reparameterization trick: mean ``x W^T + b`` plus Gaussian noise
        whose variance is ``(x * x) @ exp(2 * log_sigma)^T``. In evaluation
        mode the output is deterministic and uses the pruned weights.

        Also stores the current clamped ``log_alpha`` on ``self.log_alpha``.
        """
        # Kept as an attribute for callers that read it after a forward pass.
        self.log_alpha = self._compute_log_alpha()

        if self.training:
            lrt_mean = F.linear(x, self.W) + self.bias
            # Variance of each pre-activation is sum_j x_j^2 * sigma_ij^2;
            # the epsilon keeps sqrt differentiable when the variance is 0.
            lrt_var = F.linear(x * x, torch.exp(2.0 * self.log_sigma))
            lrt_std = torch.sqrt(lrt_var + 1e-8)
            eps = torch.randn_like(lrt_std)
            return lrt_mean + lrt_std * eps

        keep = (self.log_alpha < self.alpha_threshold).float()
        return F.linear(x, self.W * keep) + self.bias

    def get_pruned_weights(self):
        """Return ``W`` with weights at or above ``alpha_threshold`` zeroed."""
        return self.W * self._keep_mask().float()

    def get_num_remained_weights(self):
        """Return the count of kept weights with ``|W| > theta_threshold``."""
        large_enough = torch.abs(self.W) > self.theta_threshold
        return (self._keep_mask() & large_enough).sum().item()

    def kl_reg(self):
        """Return the KL regulariser summed over all weights (0-dim tensor).

        Computed from the current parameters, so it can be called before any
        forward pass and on whichever device the parameters live.
        """
        log_alpha = self._compute_log_alpha()
        neg_kl = (
            self._K1 * torch.sigmoid(self._K2 + self._K3 * log_alpha)
            - 0.5 * torch.log1p(torch.exp(-log_alpha))
        )
        return -torch.sum(neg_kl)