| import os |
| import os.path as osp |
| import pdb |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| |
| import json |
| from packaging import version |
| import torch.distributed as dist |
|
|
|
|
class OD_model(nn.Module):
    """Binary classifier that scores groups of ``order_num`` row embeddings.

    Rows are (optionally) refined by ``num_od_layer`` ``OD_Layer_linear``
    blocks, split into ``order_num`` equal groups along dim 0, placed side by
    side along the feature axis, fused back to ``hidden_size`` and mapped to a
    single logit. Training uses ``BCEWithLogitsLoss``, so ``predict`` returns
    raw logits.

    Attributes read from ``args`` in ``__init__``: ``order_num``, ``od_type``,
    ``num_od_layer``, ``hidden_size``, ``hidden_dropout_prob``.

    NOTE(review): only ``od_type == 'linear_cat'`` is implemented. For any other
    value ``order_dense_1`` (and ``layer``) are never created, so
    ``encode``/``predict``/``forward`` fail with ``AttributeError``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.order_num = args.order_num
        if args.od_type == 'linear_cat':
            # Fuses the ``order_num`` side-by-side group embeddings back down
            # to ``hidden_size``.
            self.order_dense_1 = nn.Linear(args.hidden_size * self.order_num, args.hidden_size)
            if self.args.num_od_layer > 0:
                self.layer = nn.ModuleList(
                    [OD_Layer_linear(args) for _ in range(args.num_od_layer)]
                )
        # Scoring head: one logit per fused row.
        self.order_dense_2 = nn.Linear(args.hidden_size, 1)
        # Attribute name (sic) kept so external references keep working.
        self.actication = nn.LeakyReLU()
        self.bn = torch.nn.BatchNorm1d(args.hidden_size)
        self.dp = nn.Dropout(p=args.hidden_dropout_prob)
        # predict() returns raw logits; the sigmoid is folded into the loss.
        self.loss_func = nn.BCEWithLogitsLoss()

    def forward(self, input, labels):
        """Compute the BCE-with-logits loss for a batch.

        Args:
            input: stacked rows accepted by ``encode`` (``order_num`` equal
                groups along dim 0).
            labels: one binary target per output row of ``predict``. Integer
                or bool tensors are accepted; they are cast to the logits'
                dtype because ``BCEWithLogitsLoss`` requires floating targets.

        Returns:
            ``(loss, loss_dic)``: the scalar loss tensor, and a dict whose
            ``'order_loss'`` entry is the loss as a Python float for logging.
        """
        loss_dic = {}
        pre = self.predict(input)
        # (N,) -> (N, 1) to match the logits, and to a floating dtype.
        target = labels.unsqueeze(1).to(pre.dtype)
        loss = self.loss_func(pre, target)
        loss_dic['order_loss'] = loss.item()
        return loss, loss_dic

    def encode(self, input):
        """Refine rows, regroup them side by side and fuse to ``hidden_size``.

        ``input`` is split into ``order_num`` chunks along dim 0 and the chunks
        are concatenated along dim 1. This produces rows of width
        ``hidden_size * order_num``, as ``order_dense_1`` expects.

        NOTE(review): assumes ``input`` is 2-D ``(rows, hidden_size)`` with
        ``rows`` divisible by ``order_num``; otherwise the chunks have unequal
        sizes and ``torch.cat`` raises. Confirm at call sites.
        """
        if self.args.num_od_layer > 0:
            for layer_module in self.layer:
                input = layer_module(input)
        # Split into exactly ``order_num`` groups so the concatenated width
        # matches order_dense_1 (a fixed 2 only worked when order_num == 2).
        groups = torch.chunk(input, self.order_num, dim=0)
        emb = torch.cat(groups, dim=1)
        return self.actication(self.order_dense_1(self.dp(emb)))

    def predict(self, input):
        """Return one raw logit per fused row, shape ``(rows // order_num, 1)``."""
        return self.order_dense_2(self.bn(self.encode(input)))

    def right_caculate(self, input, labels, threshold=0.5):
        """Count predictions on the same side of the decision boundary as their label.

        A sample counts as correct when ``score >= threshold`` and
        ``label >= 0.5``, or when ``score < threshold`` and ``label < 0.5``.
        NaN scores or labels are never counted as correct.

        NOTE(review): ``predict`` returns logits (the model trains with
        ``BCEWithLogitsLoss``). If callers pass those logits directly, the
        default ``threshold=0.5`` corresponds to probability
        ``sigmoid(0.5) ~= 0.62``, not 0.5. Pass ``threshold=0.0`` or apply
        ``torch.sigmoid`` first. Confirm which one callers intend.

        Args:
            input: scores of shape ``(N, 1)`` (tensor or array-like).
            labels: ``N`` targets; values ``>= 0.5`` count as positive.
            threshold: decision threshold applied to ``input``.

        Returns:
            The number of correct predictions, as a Python ``int``.

        Raises:
            ValueError: if ``labels`` does not provide exactly one target per
                score.
        """
        scores = torch.as_tensor(input).squeeze(1)
        # Compare on the scores' device instead of copying both to host lists.
        targets = torch.as_tensor(labels, device=scores.device)
        if targets.shape != scores.shape:
            raise ValueError(
                f"expected {tuple(scores.shape)} labels, got {tuple(targets.shape)}"
            )
        # Two explicit comparisons per side (not ``==`` on booleans) keep NaN
        # entries counted as wrong, matching the element-wise rule above.
        hits = ((scores >= threshold) & (targets >= 0.5)) | (
            (scores < threshold) & (targets < 0.5)
        )
        return int(hits.sum().item())
|
|
|
|
class OD_Layer_linear(nn.Module):
    """Refinement block that keeps the feature width at ``hidden_size``.

    It is stacked inside ``OD_model`` and applies, in order: dropout, a
    ``hidden_size -> hidden_size`` linear map, ``BatchNorm1d`` and
    ``LeakyReLU``.

    Attributes read from ``args``: ``hidden_size``, ``hidden_dropout_prob``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        width = args.hidden_size
        self.dense = nn.Linear(width, width)
        # Attribute name (sic) kept so external references keep working.
        self.actication = nn.LeakyReLU()
        self.bn = torch.nn.BatchNorm1d(width)
        self.dropout = nn.Dropout(p=args.hidden_dropout_prob)

    def forward(self, input):
        """Return ``LeakyReLU(BN(Linear(Dropout(input))))``; the shape is unchanged."""
        hidden = self.dropout(input)
        hidden = self.dense(hidden)
        hidden = self.bn(hidden)
        return self.actication(hidden)
|
|