Spaces:
Running
Running
| import torch | |
| from utils.encode.quantizer import LinearQuantizer | |
| import math | |
| from scipy.special import gamma as Gamma | |
| import numpy as np | |
| import dask.array as da | |
class DeepShape:
    """Estimate and reshape the Generalized Gaussian Distribution (GGD) of a
    model's flattened weight values.

    Relies on two precomputed lookup tables saved as torch tensors:
      - gamma_table: candidate GGD shape values
      - rho_table:   the rho moment statistic matching each shape value
    The rho statistic of the observed weights is matched against rho_table to
    pick the nearest tabulated shape.
    """

    def __init__(self, gamma_table_path='utils/gamma_table.pt',
                 rho_table_path='utils/rho_table.pt'):
        """Load the rho->shape lookup tables.

        Args:
            gamma_table_path: path of the saved shape-value table
                (default keeps the original hard-coded location).
            rho_table_path: path of the saved rho-statistic table.
        """
        self.gamma_table = torch.load(gamma_table_path)
        self.rho_table = torch.load(rho_table_path)

    @staticmethod
    def _flatten_model_params(model):
        """Concatenate every parameter of `model` into one detached 1-D tensor."""
        return torch.cat([p.flatten() for p in model.parameters()]).detach()

    def _cut_tail(self, params_q, params_org, adj_minnum, num_bits):
        """Drop rarely occurring large-magnitude weights.

        `param_max` is the smallest positive quantized value whose occurrence
        count is <= adj_minnum; every original weight with
        |w| > param_max / 2**num_bits is discarded.

        Returns:
            (kept_weights, param_max) — param_max as a plain int on the
            quantized integer grid.
        """
        elements, counts = torch.unique(params_q, return_counts=True)
        order = torch.argsort(counts, descending=True)
        elements, counts = elements[order], counts[order]
        # NOTE(review): torch.min raises on an empty selection, i.e. when no
        # positive quantized value is rare enough — same failure mode as the
        # original code.
        param_max = int(torch.min(
            elements[(counts <= adj_minnum) & (elements > 0)]).long())
        kept = params_org[
            torch.abs(params_org) <= (float(param_max) / (2 ** num_bits))]
        return kept, param_max

    def _estimate_ggd(self, samples):
        """Estimate GGD parameters (mu, shape, beta) from a 1-D sample tensor.

        Matches the moment statistic rho = n * sum(x^2) / sum(|x|)^2 against
        the precomputed rho_table and returns the corresponding tabulated
        shape, plus the GGD scale beta derived from the measured std.
        """
        n = len(samples)
        energy = torch.sum(torch.pow(samples, 2))   # raw sum of squares
        abs_sum = torch.sum(torch.abs(samples))     # raw sum of |x|
        self.gamma_table = self.gamma_table.to(samples.device)
        self.rho_table = self.rho_table.to(samples.device)
        rho = n * energy / abs_sum ** 2
        pos = torch.argmin(torch.abs(rho - self.rho_table)).item()
        shape = self.gamma_table[pos].item()
        std = torch.sqrt(energy / n)
        # GGD scale parameter for the chosen shape and measured std
        beta = math.sqrt(Gamma(1 / shape) / Gamma(3 / shape)) * std
        mu = torch.mean(samples)
        return mu, shape, beta

    def Calc_GG_params(self, model, adj_minnum=0, num_bits=13):
        """Estimate GGD parameters of a model's weights.

        Args:
            model: torch module whose parameters are analyzed.
            adj_minnum: if > 0, quantized values occurring <= adj_minnum
                times define a magnitude cutoff and the tail beyond it is
                ignored; 0 keeps all weights.
            num_bits: bit width passed to the LinearQuantizer (default 13,
                matching the original hard-coded value).

        Returns:
            (mu, shape, beta) of the fitted GGD.
        """
        params = self._flatten_model_params(model)
        # keep a full-precision copy first — quant() may modify `params`
        # in place (TODO confirm against LinearQuantizer)
        params_org = params.clone()
        lq = LinearQuantizer(params, num_bits)
        params = lq.quant(params)
        if adj_minnum > 0:
            elements_cut, _ = self._cut_tail(params, params_org,
                                             adj_minnum, num_bits)
        else:
            elements_cut = params_org
        mu, shape, beta = self._estimate_ggd(elements_cut)
        print("mu:", mu)
        print('shape:', shape)
        print('beta', (beta))
        return mu, shape, beta

    def GGD_deepshape(self, model, shape_scale=0.8, std_scale=0.6,
                      adj_minnum=1000, num_bits=13):
        """Remap the model's weights onto a narrower/sharper GGD in place.

        Fits a GGD to the current weights, scales its shape by `shape_scale`
        and its scale beta by `std_scale`, then reassigns the (rank-ordered)
        quantized weights so their histogram follows the new GGD. The model's
        parameters are overwritten.

        Args:
            model: torch module to modify in place.
            shape_scale: multiplier applied to the estimated GGD shape.
            std_scale: multiplier applied to the estimated GGD scale beta.
            adj_minnum: tail-cut threshold as in Calc_GG_params; if <= 0 the
                remap window degenerates to param_max = 0.
            num_bits: LinearQuantizer bit width (default 13, as before).

        Returns:
            (new_mu, new_shape, new_beta) in float weight units.
        """
        scale = 2 ** num_bits
        params = self._flatten_model_params(model)
        # full-precision copy before quantization (quant() may modify in place)
        params_org = params.clone()
        lq = LinearQuantizer(params, num_bits)
        params = lq.quant(params)
        if adj_minnum > 0:
            elements_cut, param_max = self._cut_tail(params, params_org,
                                                     adj_minnum, num_bits)
        else:
            elements_cut = params_org
            param_max = 0
        mu_est, shape, beta = self._estimate_ggd(elements_cut)
        print("org mu:", mu_est)
        print('org shape:', shape)
        print('org beta', beta)
        # move onto the quantized integer grid
        beta = beta * scale
        mu_est = int(mu_est * scale)
        # indices of quantized params inside the remap window,
        # ordered ascending by value
        if adj_minnum > 0:
            window = ((params >= mu_est - param_max)
                      & (params <= mu_est + param_max))
            adj_indices = torch.nonzero(window, as_tuple=False).squeeze()
            adj_indices = adj_indices[
                torch.argsort(params[window], descending=False)]
        else:
            adj_indices = torch.argsort(params, descending=False)
        adj_num = len(adj_indices)
        new_params = params.clone()
        new_shape = shape * shape_scale
        new_beta = beta * std_scale
        # BUGFIX: the original guarded `beta` here, which is never used again;
        # `new_beta` is the divisor below, so a degenerate (zero) estimate
        # produced inf/nan ratios. Guard the value actually divided by.
        if new_beta <= 0:
            new_beta = 1
        # target (unnormalized) GGD density on the integer grid of the window
        x = torch.arange(mu_est - param_max, mu_est + param_max + 1,
                         device=params.device)
        new_ratio = torch.exp(
            -torch.pow(torch.abs(x.float() - mu_est) / new_beta, new_shape))
        new_ratio = new_ratio / torch.sum(new_ratio)
        # how many (rank-ordered) weights land on each grid value; truncation
        # leaves a small remainder of weights unchanged, as in the original
        new_num = (adj_num * new_ratio).long()
        num_temp = 0
        for i in range(0, 2 * param_max + 1):
            new_params[adj_indices[num_temp: num_temp + new_num[i]]] = \
                i + mu_est - param_max
            num_temp += new_num[i]
        new_params = new_params.float() / scale
        # write remapped values back into the model in parameter order
        offset = 0
        for name, param in model.named_parameters():
            orig_shape = param.data.shape
            count = param.data.numel()
            param.data = new_params[offset: offset + count].reshape(orig_shape)
            offset += count
        print("new mu:", float(mu_est) / scale)
        print('new_shape:', new_shape)
        print('new beta', float(new_beta) / scale)
        return float(mu_est) / scale, new_shape, float(new_beta) / scale