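# NOTE: the header text below came from the file-hosting page, not from the
# original source code; it is kept here only as a comment:
#   jangwon-kim-cocel's picture / Upload 11 files / 96170c3 verified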
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class LinearSVDO(nn.Module):
    """Fully connected layer with sparse variational dropout (SVDO).

    Each weight has a learned mean ``W`` and a learned log standard deviation
    ``log_sigma``.  The per-weight quantity
    ``log_alpha = 2 * log_sigma - 2 * log(|W|)`` (clamped to ``[-10, 10]``)
    decides pruning: weights whose ``log_alpha`` is not below
    ``alpha_threshold`` are zeroed at evaluation time.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        alpha_threshold: Weights with ``log_alpha >= alpha_threshold`` are
            masked out in eval mode and in ``get_pruned_weights``.
        theta_threshold: Magnitude threshold on ``|W|`` used only by
            ``get_num_remained_weights``.
        device: Stored on the instance as ``self.device``.  The parameters
            themselves are not moved by this class.
    """

    # Constants of the sigmoid-based KL approximation used in ``kl_reg``.
    _K1 = 0.63576
    _K2 = 1.8732
    _K3 = 1.48695

    def __init__(self, in_features, out_features, alpha_threshold, theta_threshold, device):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha_threshold = alpha_threshold
        self.theta_threshold = theta_threshold
        self.device = device

        # Allocated uninitialised; reset_parameters() fills every entry.
        self.W = Parameter(torch.empty(out_features, in_features))
        self.log_sigma = Parameter(torch.empty(out_features, in_features))
        self.bias = Parameter(torch.empty(1, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise bias to 0, ``W`` ~ N(0, 0.02) and ``log_sigma`` to -5."""
        self.bias.data.zero_()
        self.W.data.normal_(0, 0.02)
        self.log_sigma.data.fill_(-5)

    def _compute_log_alpha(self):
        """Return the clamped ``log_alpha`` for the *current* parameters."""
        # 1e-16 keeps the logarithm finite when a weight is exactly zero.
        raw = self.log_sigma * 2.0 - 2.0 * torch.log(1e-16 + torch.abs(self.W))
        return torch.clamp(raw, -10, 10)

    def forward(self, x):
        """Apply the layer to ``x``.

        In training mode the output is sampled with the local
        reparameterization trick (Gaussian noise on the pre-activations).
        In eval mode the deterministic pruned weights are used.

        As a side effect, ``self.log_alpha`` is refreshed on every call.
        """
        self.log_alpha = self._compute_log_alpha()

        if self.training:
            lrt_mean = F.linear(x, self.W) + self.bias
            # Pre-activation variance is sum_i x_i^2 * sigma_ij^2; the square
            # root must be taken of that sum, not of x^2 beforehand.
            lrt_var = F.linear(x * x, torch.exp(2.0 * self.log_sigma))
            lrt_std = torch.sqrt(lrt_var + 1e-8)  # eps avoids sqrt(0) gradients
            eps = torch.randn_like(lrt_std)
            return lrt_mean + lrt_std * eps

        keep = (self.log_alpha < self.alpha_threshold).to(self.W.dtype)
        return F.linear(x, self.W * keep) + self.bias

    def get_pruned_weights(self):
        """Return ``W`` with entries whose ``log_alpha >= alpha_threshold`` zeroed.

        ``log_alpha`` is recomputed from the current parameters, so this works
        before the first ``forward`` call and never uses a stale mask.
        """
        keep = (self._compute_log_alpha() < self.alpha_threshold).to(self.W.dtype)
        return self.W * keep

    def get_num_remained_weights(self):
        """Count weights with ``log_alpha < alpha_threshold`` and ``|W| > theta_threshold``.

        NOTE(review): the magnitude threshold is applied here only; the eval
        forward pass and ``get_pruned_weights`` mask on ``log_alpha`` alone.

        Returns:
            The count as a Python ``int``.
        """
        alpha_ok = self._compute_log_alpha() < self.alpha_threshold
        magnitude_ok = torch.abs(self.W) > self.theta_threshold
        return (alpha_ok & magnitude_ok).sum().item()

    def kl_reg(self):
        """Return the (approximate) KL regulariser as a scalar tensor.

        Computed from the current parameters, so it does not require a prior
        ``forward`` call.  The constants are plain floats, which keeps the
        result on whatever device/dtype the parameters live on.
        """
        log_alpha = self._compute_log_alpha()
        neg_kl = (
            self._K1 * torch.sigmoid(self._K2 + self._K3 * log_alpha)
            - 0.5 * torch.log1p(torch.exp(-log_alpha))
        )
        return -torch.sum(neg_kl)