| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
class SemanticEmbedding(nn.Module):
    """Predict per-label probabilities from visual/token similarity scores.

    ``forward`` takes the dot product between one visual vector per sample
    and every token feature of that sample, pads the resulting score vector
    to a fixed length (``target_dim``) and maps it through a linear layer
    followed by a sigmoid.

    Args:
        args: Accepted for interface compatibility; not read by this class.
        mesh_dim: Output width of ``mesh_tf`` and part of ``w1``'s input width.
        report_dim: Output width of ``report_tf`` and part of ``w1``'s input
            width.
        embed_size: Input width of ``mesh_tf`` / ``report_tf`` and width of
            ``bn``, ``w1`` (output) and ``w2``.
        target_dim: Fixed length the token axis is padded to; also the input
            width of ``logit``. Defaults to the previously hard-coded 60.
        num_labels: Output width of ``logit``. Defaults to the previously
            hard-coded 31.
    """

    def __init__(self, args, mesh_dim=71, report_dim=761, embed_size=512,
                 target_dim=60, num_labels=31):
        super().__init__()
        # Set first so the classifier below is sized from it; the two values
        # must always agree or forward() fails with a shape mismatch.
        self.target_dim = target_dim

        # NOTE(review): mesh_tf, report_tf, bn, w1, w2, relu and dropout are
        # not used by forward() in this class. They are kept, in the original
        # order, so parameter names and seeded initialisation stay unchanged.
        self.mesh_tf = nn.Sequential(
            nn.Linear(embed_size, embed_size // 2),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 2, embed_size // 4),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 4, mesh_dim),
        )
        self.report_tf = nn.Sequential(
            nn.Linear(embed_size, embed_size // 2),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 2, embed_size // 4),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 4, report_dim),
        )
        self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.1)
        self.w1 = nn.Linear(in_features=mesh_dim + report_dim, out_features=embed_size)
        self.w2 = nn.Linear(in_features=embed_size, out_features=embed_size)
        self.relu = nn.ReLU()
        self.logit = nn.Linear(self.target_dim, num_labels)
        self.dropout = nn.Dropout(0.2)
        self.__init_weight()
        self.sigm = nn.Sigmoid()

    def __init_weight(self):
        """Re-initialise ``w1``/``w2``: weights ~ U(-0.1, 0.1), biases zero."""
        for layer in (self.w1, self.w2):
            nn.init.uniform_(layer.weight, -0.1, 0.1)
            nn.init.zeros_(layer.bias)

    def forward(self, avg, pred_output):
        """Return sigmoid label scores of shape ``(batch, num_labels)``.

        Args:
            avg: Visual vector per sample; assumed ``(batch, dim)`` — TODO
                confirm against the caller.
            pred_output: Rank-3 tensor ``(batch, seq_len, dim)`` whose last
                axis matches ``avg``'s.

        Returns:
            Tensor of sigmoid activations, one row per sample.
        """
        query = avg.unsqueeze(1)  # (batch, 1, dim)

        # Zero-pad the token axis (dim 1) up to target_dim; the last axis is
        # left untouched. NOTE(review): when seq_len > target_dim the amount
        # is negative and is passed to F.pad as-is — verify that behaviour is
        # what the caller expects.
        pad_len = self.target_dim - pred_output.shape[1]
        padded = F.pad(pred_output, (0, 0, 0, pad_len), 'constant', 0)

        keys = padded.permute(0, 2, 1)  # (batch, dim, target_dim)
        # One dot product per token position -> (batch, target_dim).
        scores = torch.matmul(query, keys).squeeze(1)
        return self.sigm(self.logit(scores))
class classfication(nn.Module):
    """Linear head with dropout on the input and a sigmoid on the output.

    Args:
        distiller_num: Number of output units of the linear layer.
        avg_dim: Size of the last axis of the input passed to ``forward``.
    """

    def __init__(self, distiller_num, avg_dim=1024):
        super().__init__()
        self.logit = nn.Linear(in_features=avg_dim, out_features=distiller_num)
        # Registered but not used by forward().
        self.relu = nn.ReLU()
        self.sigm = nn.Sigmoid()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, avg):
        """Apply dropout, the linear layer and a sigmoid to ``avg``, in that order."""
        return self.sigm(self.logit(self.dropout(avg)))