Spaces:
Runtime error
Runtime error
| import math | |
| from argparse import Namespace | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from models import register | |
class gen_basis(nn.Module):
    """Supply five basis weight tensors and five matching bias tensors.

    Calling the module returns ``(basis, bias)``, two lists of five tensors
    each.  When ``state == 'train'`` the tensors are freshly created, randomly
    initialised ``nn.Parameter`` objects registered on this module as
    ``w0``..``w4`` and ``bias1``..``bias5``.  For any other ``state`` they are
    read from the checkpoint file at ``path``.

    Args:
        args: Object exposing the attributes ``basis_num``, ``hidden``,
            ``state`` and ``path``.
    """

    # Column count of ``w0`` is ``hidden * _FIRST_LAYER_WIDTH``.
    # NOTE(review): presumably the input feature width seen by the first
    # layer -- confirm against the code that consumes these tensors.
    _FIRST_LAYER_WIDTH = 580
    # ``w4`` has ``_LAST_LAYER_WIDTH * hidden`` columns and ``bias5`` has
    # ``_LAST_LAYER_WIDTH`` columns.
    # NOTE(review): presumably the output channel count -- confirm.
    _LAST_LAYER_WIDTH = 3

    def __init__(self, args):
        super(gen_basis, self).__init__()
        self.basis_num = args.basis_num
        self.hidden = args.hidden
        self.state = args.state
        self.path = args.path

    def init_basis_bias(self):
        """Create, register and randomly initialise the weights and biases.

        Each weight ``w{i}`` has shape ``(basis_num, cols)`` and is filled with
        ``kaiming_uniform_(a=sqrt(5))``.  Each bias is filled uniformly in
        ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]`` where ``fan_in`` is computed from
        the weight at the same position.

        Every call builds brand-new parameters and rebinds the module
        attributes, discarding any previously created ones.

        Returns:
            tuple: ``(basis, bias)`` where ``basis`` is ``[w0, .., w4]`` and
            ``bias`` is ``[bias1, .., bias5]``.
        """
        hidden = self.hidden
        weight_cols = [
            hidden * self._FIRST_LAYER_WIDTH,
            hidden * hidden,
            hidden * hidden,
            hidden * hidden,
            self._LAST_LAYER_WIDTH * hidden,
        ]
        bias_cols = [hidden, hidden, hidden, hidden, self._LAST_LAYER_WIDTH]

        # All weights are registered and initialised before any bias so that
        # parameter order and random-number consumption order stay
        # w0..w4 followed by bias1..bias5.
        basis = []
        for idx, cols in enumerate(weight_cols):
            weight = nn.Parameter(
                torch.empty(self.basis_num, cols), requires_grad=True
            )
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            setattr(self, 'w{}'.format(idx), weight)
            basis.append(weight)

        bias = []
        # Bias attribute names are 1-based (bias1 pairs with w0, and so on).
        for idx, (cols, weight) in enumerate(zip(bias_cols, basis), start=1):
            offset = nn.Parameter(
                torch.empty(self.basis_num, cols), requires_grad=True
            )
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(offset, -bound, bound)
            setattr(self, 'bias{}'.format(idx), offset)
            bias.append(offset)

        return basis, bias

    def load_basis_for_test_kaiming(self, path):
        """Read the stored weights and biases from a checkpoint file.

        The file must deserialize to a mapping in which
        ``ckpt['model']['sd']`` holds the keys ``basis.w0``..``basis.w4`` and
        ``basis.bias1``..``basis.bias5``.

        Args:
            path: Location of the checkpoint, passed to ``torch.load``.

        Returns:
            tuple: ``(basis, bias)``, two lists of five stored tensors each.

        Raises:
            KeyError: If an expected key is missing from the checkpoint.
        """
        # On a host without CUDA, remap storages to the CPU so a checkpoint
        # saved from GPU tensors can still be read.  With CUDA available the
        # default placement is kept, exactly as before.
        map_location = None if torch.cuda.is_available() else 'cpu'
        state_dict = torch.load(path, map_location=map_location)['model']['sd']

        basis = [state_dict['basis.w{}'.format(i)] for i in range(5)]
        bias = [state_dict['basis.bias{}'.format(i)] for i in range(1, 6)]

        torch.cuda.empty_cache()
        return basis, bias

    def forward(self):
        """Return ``(basis, bias)``: new parameters in 'train' state, else loaded ones."""
        if self.state == 'train':
            print('init_basis_use_kaiming')
            return self.init_basis_bias()
        print('load_basis_from_model')
        return self.load_basis_for_test_kaiming(self.path)
def make_basis(basis_num=10, hidden=16, state=None, path=None):
    """Build a ``gen_basis`` module from plain keyword arguments.

    The four values are packed into an ``argparse.Namespace`` and handed to
    the ``gen_basis`` constructor unchanged.

    Args:
        basis_num: Stored as ``args.basis_num``.
        hidden: Stored as ``args.hidden``.
        state: Stored as ``args.state``; ``gen_basis`` creates new parameters
            when this equals ``'train'`` and loads from ``path`` otherwise.
        path: Stored as ``args.path``.

    Returns:
        gen_basis: The constructed module.
    """
    config = Namespace(
        basis_num=basis_num,
        hidden=hidden,
        state=state,
        path=path,
    )
    return gen_basis(config)