| import torch |
| import torch.utils.data |
| import torch.nn as nn |
|
|
def get_model(params):
    """Build the network selected by ``params['model']``.

    Args:
        params: Mapping read with the keys ``'model'``, ``'input_dim'`` and
            ``'num_classes'``. ``'num_filts'`` and ``'depth'`` are read as
            well when the model is ``'ResidualFCNet'``.

    Returns:
        A ``ResidualFCNet`` or a ``LinNet`` instance.

    Raises:
        NotImplementedError: If ``params['model']`` names neither model.
    """
    requested = params['model']

    if requested == 'ResidualFCNet':
        return ResidualFCNet(
            params['input_dim'],
            params['num_classes'],
            params['num_filts'],
            params['depth'],
        )

    if requested == 'LinNet':
        return LinNet(params['input_dim'], params['num_classes'])

    raise NotImplementedError('Invalid model specified.')
|
|
class ResLayer(nn.Module):
    """Fully connected residual block that keeps the feature width.

    The input goes through linear -> ReLU -> dropout -> linear -> ReLU, and
    the result is added back onto the untouched input (skip connection).
    """

    def __init__(self, linear_size):
        """Create the block.

        Args:
            linear_size: Width used for both the input and the output of
                each of the two linear layers.
        """
        super().__init__()
        self.l_size = linear_size
        # Activations and dropout carry no parameters.
        self.nonlin1 = nn.ReLU(inplace=True)
        self.nonlin2 = nn.ReLU(inplace=True)
        self.dropout1 = nn.Dropout()
        # Two square linear maps; w1 is created before w2.
        self.w1 = nn.Linear(linear_size, linear_size)
        self.w2 = nn.Linear(linear_size, linear_size)

    def forward(self, x):
        """Return ``x`` plus the output of the two-layer residual branch."""
        hidden = self.dropout1(self.nonlin1(self.w1(x)))
        branch = self.nonlin2(self.w2(hidden))
        return x + branch
|
|
class ResidualFCNet(nn.Module):
    """Fully connected residual network with a linear per-class output head.

    ``feats`` is a linear layer followed by a ReLU and ``depth`` ``ResLayer``
    blocks; ``class_emb`` is a linear layer (no bias by default) mapping the
    resulting features to one score per class. Scores are passed through a
    sigmoid before being returned.
    """

    def __init__(self, num_inputs, num_classes, num_filts, depth=4):
        """Build the feature extractor and the class head.

        Args:
            num_inputs: Size of the last dimension of the input.
            num_classes: Number of output classes.
            num_filts: Width of the hidden feature representation.
            depth: Number of ``ResLayer`` blocks stacked after the input
                layer.
        """
        super().__init__()
        self.inc_bias = False
        self.class_emb = nn.Linear(num_filts, num_classes, bias=self.inc_bias)
        layers = [nn.Linear(num_inputs, num_filts), nn.ReLU(inplace=True)]
        layers.extend(ResLayer(num_filts) for _ in range(depth))
        self.feats = nn.Sequential(*layers)

    def forward(self, x, class_of_interest=None, return_feats=False):
        """Run the network.

        Args:
            x: Input tensor whose last dimension has size ``num_inputs``.
            class_of_interest: If given, only the class row(s) selected by
                this index are evaluated instead of all classes.
            return_feats: If true, return the output of ``feats`` and skip
                the class head entirely (takes precedence over
                ``class_of_interest``).

        Returns:
            The feature tensor when ``return_feats`` is true, otherwise the
            sigmoid of the class scores.
        """
        loc_emb = self.feats(x)
        if return_feats:
            return loc_emb
        if class_of_interest is None:
            class_pred = self.class_emb(loc_emb)
        else:
            class_pred = self.eval_single_class(loc_emb, class_of_interest)
        return torch.sigmoid(class_pred)

    def eval_single_class(self, x, class_of_interest):
        """Score features against the selected row(s) of the class head.

        Args:
            x: Feature tensor as produced by ``feats``.
            class_of_interest: Index into the first dimension of
                ``class_emb.weight`` (and of the bias when ``inc_bias`` is
                set).

        Returns:
            Raw (pre-sigmoid) scores for the selected class(es).
        """
        weight = self.class_emb.weight[class_of_interest, :]
        # ``.t()`` leaves a 1-D row untouched and transposes a 2-D selection.
        # ``.T`` is deprecated for tensors that are not 2-D, which is exactly
        # what an integer class index produces.
        scores = torch.matmul(x, weight.t())
        if self.inc_bias:
            scores = scores + self.class_emb.bias[class_of_interest]
        return scores
|
|
class LinNet(nn.Module):
    """Single linear class head applied directly to the input.

    ``feats`` is an identity mapping, so the "features" are the inputs
    themselves; ``class_emb`` is a linear layer (no bias by default) that
    produces one score per class. Scores are passed through a sigmoid before
    being returned.
    """

    def __init__(self, num_inputs, num_classes):
        """Build the model.

        Args:
            num_inputs: Size of the last dimension of the input.
            num_classes: Number of output classes.
        """
        super().__init__()
        self.num_layers = 0
        self.inc_bias = False
        self.class_emb = nn.Linear(num_inputs, num_classes, bias=self.inc_bias)
        self.feats = nn.Identity()

    def forward(self, x, class_of_interest=None, return_feats=False):
        """Run the model.

        Args:
            x: Input tensor whose last dimension has size ``num_inputs``.
            class_of_interest: If given, only the class row(s) selected by
                this index are evaluated instead of all classes.
            return_feats: If true, return the output of ``feats`` (the input,
                unchanged) and skip the class head; takes precedence over
                ``class_of_interest``.

        Returns:
            The feature tensor when ``return_feats`` is true, otherwise the
            sigmoid of the class scores.
        """
        loc_emb = self.feats(x)
        if return_feats:
            return loc_emb
        if class_of_interest is None:
            class_pred = self.class_emb(loc_emb)
        else:
            class_pred = self.eval_single_class(loc_emb, class_of_interest)
        return torch.sigmoid(class_pred)

    def eval_single_class(self, x, class_of_interest):
        """Score features against the selected row(s) of the class head.

        Args:
            x: Feature tensor (here identical to the model input).
            class_of_interest: Index into the first dimension of
                ``class_emb.weight`` (and of the bias when ``inc_bias`` is
                set).

        Returns:
            Raw (pre-sigmoid) scores for the selected class(es).
        """
        weight = self.class_emb.weight[class_of_interest, :]
        # ``.t()`` leaves a 1-D row untouched and transposes a 2-D selection.
        # ``.T`` is deprecated for tensors that are not 2-D, which is exactly
        # what an integer class index produces.
        scores = torch.matmul(x, weight.t())
        if self.inc_bias:
            scores = scores + self.class_emb.bias[class_of_interest]
        return scores
|
|