Spaces:
Runtime error
Runtime error
| import math | |
| from argparse import Namespace | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import models | |
| from models import register | |
| import numpy as np | |
class ExpansionNet(nn.Module):
    """MLP applied over the last dimension of ``x``, followed by L2 normalization.

    The layer stack is ``Linear -> ReLU`` for each width in ``hidden_list``,
    then a final ``Linear`` to ``out_dim`` with no activation.

    Args:
        args: Namespace-like object with attributes ``in_dim`` (input feature
            size), ``out_dim`` (output feature size) and ``hidden_list``
            (iterable of hidden widths, or ``None`` for no hidden layers).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.in_dim = args.in_dim
        self.out_dim = args.out_dim
        # None means "no hidden layers"; iterating None would raise TypeError
        # (and None is the default used by make_ExpansionNet).
        self.hidden_list = [] if args.hidden_list is None else args.hidden_list

        layers = []
        last_dim = self.in_dim
        for hidden in self.hidden_list:
            layers.append(nn.Linear(last_dim, hidden))
            layers.append(nn.ReLU())
            last_dim = hidden
        layers.append(nn.Linear(last_dim, self.out_dim))
        # Module names (layers.0, layers.1, ...) are unchanged, so existing
        # state_dicts still load.
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Map the last dimension of ``x`` to ``out_dim`` and L2-normalize it.

        Args:
            x: Tensor of shape ``(..., in_dim)``; the original code was only
                used with 3-D input ``(b, n, in_dim)``, which still works.

        Returns:
            Tensor of shape ``(..., out_dim)`` whose last-dimension vectors
            have unit L2 norm (up to ``F.normalize``'s epsilon).
        """
        *lead_shape, c = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. after transpose,
        # are accepted; it is a no-copy view whenever that is possible.
        flat = x.reshape(-1, c)
        logits = self.layers(flat)
        out = F.normalize(logits, dim=-1)
        return out.reshape(*lead_shape, self.out_dim)
def make_ExpansionNet(in_dim=580, out_dim=10, hidden_list=None):
    """Build an ``ExpansionNet`` from plain keyword arguments.

    Args:
        in_dim: Size of the input feature (last) dimension.
        out_dim: Size of the output feature dimension.
        hidden_list: Iterable of hidden-layer widths, or ``None`` for a
            network with no hidden layers (a single ``Linear``).

    Returns:
        A freshly constructed ``ExpansionNet``.
    """
    args = Namespace(
        in_dim=in_dim,
        out_dim=out_dim,
        # ExpansionNet iterates hidden_list, so None would raise TypeError;
        # copy to avoid aliasing the caller's list.
        hidden_list=[] if hidden_list is None else list(hidden_list),
    )
    return ExpansionNet(args)