Spaces:
Build error
Build error
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn import BatchNorm1d, Linear, ReLU | |
| from .bert_model import BertForSequenceEncoder | |
| from torch.nn import BatchNorm1d, Linear, ReLU | |
| from .bert_model import BertForSequenceEncoder | |
| from torch.autograd import Variable | |
| import numpy as np | |
def kernal_mus(n_kernels):
    """
    Build the list of Gaussian kernel centres (mu values).

    The first kernel is the exact-match kernel and is centred at 1. The
    remaining kernels sit at the midpoints of equal-width bins that tile the
    score range [-1, 1], listed from the highest bin downwards.

    :param n_kernels: number of kernels (including the exact-match kernel).
    :return: a list of n_kernels centres, exact-match centre first.
    """
    if n_kernels == 1:
        # Only the exact-match kernel was requested.
        return [1]

    # The n_kernels - 1 soft-match bins share the [-1, 1] range equally.
    width = 2.0 / (n_kernels - 1)

    # Midpoint of the top-most bin.
    centre = 1 - width / 2
    centres = [1, centre]

    # Step down one bin width at a time for the remaining bins.
    for _ in range(n_kernels - 2):
        centre = centre - width
        centres.append(centre)
    return centres
def kernel_sigmas(n_kernels):
    """
    Get the sigma (standard deviation) for each Gaussian kernel.

    The first kernel is the exact-match kernel and gets a very small sigma so
    that it only responds to scores at its centre; every other kernel uses a
    sigma of 0.1.

    :param n_kernels: number of kernels (including the exact-match kernel).
    :return: l_sigma, a list of sigmas with the exact-match sigma first.
    """
    # NOTE: no bin size is computed here. The sigmas do not depend on it, and
    # dividing by (n_kernels - 1) would raise ZeroDivisionError when only the
    # exact-match kernel is requested (n_kernels == 1).
    l_sigma = [0.001]  # exact match: small variance -> exact match
    if n_kernels == 1:
        return l_sigma
    l_sigma += [0.1] * (n_kernels - 1)
    return l_sigma
class inference_model(nn.Module):
    """
    Scoring head on top of an encoder.

    The wrapped encoder is called with the input, mask and segment tensors
    and must return a 2-tuple; only the second element is used. That element
    is passed through dropout and a linear layer mapping bert_hidden_dim
    features to one value, and the result is squashed with tanh.
    """

    def __init__(self, bert_model, args):
        """
        :param bert_model: encoder callable returning a 2-tuple when called
            with (inp_tensor, msk_tensor, seg_tensor).
        :param args: configuration object providing bert_hidden_dim, dropout,
            max_len and num_labels attributes.
        """
        super().__init__()

        # Plain configuration values copied from args.
        self.bert_hidden_dim = args.bert_hidden_dim
        self.max_len = args.max_len
        self.num_labels = args.num_labels

        # Sub-modules are registered in the order dropout, encoder, projection.
        self.dropout = nn.Dropout(args.dropout)
        self.pred_model = bert_model
        self.proj_match = nn.Linear(self.bert_hidden_dim, 1)

    def forward(self, inp_tensor, msk_tensor, seg_tensor):
        """
        Score each encoded input.

        :param inp_tensor: first positional argument forwarded to the encoder.
        :param msk_tensor: second positional argument forwarded to the encoder.
        :param seg_tensor: third positional argument forwarded to the encoder.
        :return: tensor of tanh-squashed scores with the trailing
            size-1 projection dimension removed.
        """
        # The first encoder output is ignored.
        # NOTE(review): the second output is presumably a pooled sequence
        # representation of size bert_hidden_dim -- confirm against the encoder.
        _, encoded = self.pred_model(inp_tensor, msk_tensor, seg_tensor)

        projected = self.proj_match(self.dropout(encoded))
        return torch.tanh(projected.squeeze(-1))