Spaces:
Build error
Build error
| import copy | |
| import json | |
| import os | |
| import time | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| import tqdm | |
| from sklearn.metrics import * | |
| from tqdm import tqdm | |
| from transformers import AutoConfig, BertModel | |
| from transformers.models.bert.modeling_bert import BertLayer | |
| from zmq import device | |
| from .coattention import * | |
| from .layers import * | |
| from FakeVD.code_test.utils.metrics import * | |
class SVFENDModel(torch.nn.Module):
    """Multimodal classifier fusing title text, audio, image-frame and C3D features.

    The title is encoded with a frozen BERT. The text is co-attended first with
    the audio frames and then (in its updated form) with the image frames. Each
    modality is then mean-pooled to a single ``fea_dim`` vector. The four vectors
    are mixed by one Transformer encoder layer, mean-pooled again and mapped to
    two output scores.

    Args:
        bert_model: Accepted for backward compatibility but NOT used; the BERT
            weights are loaded from ``bert_path``.
        fea_dim: Width of the shared feature space every modality is projected to.
        dropout: Dropout probability for the projection heads and co-attention.
        bert_path: Directory passed to ``BertModel.from_pretrained``.
        vggish_path: Local ``torch.hub`` repository directory providing ``vggish``.
    """

    def __init__(self, bert_model, fea_dim, dropout,
                 bert_path="./FakeVD/Models/bert-base-chinese/",
                 vggish_path='./FakeVD/Models/torchvggish/'):
        super().__init__()
        # BERT parameters are frozen (requires_grad_(False)).
        self.bert = BertModel.from_pretrained(bert_path).requires_grad_(False)
        self.text_dim = 768         # input width of linear_text / linear_intro
        self.comment_dim = 768      # only used by linear_comment (unused in forward)
        self.img_dim = 4096         # per-frame feature width expected by linear_img
        self.video_dim = 4096       # C3D feature width expected by linear_video
        self.num_frames = 83        # visual_len configured for co_attention_tv
        self.num_audioframes = 50   # visual_len configured for co_attention_ta
        self.num_comments = 23      # not used inside this class
        self.dim = fea_dim
        self.num_heads = 4          # heads in both co-attention blocks
        self.dropout = dropout

        self.vggish_layer = torch.hub.load(vggish_path, 'vggish', source='local')
        # Keep only the second-to-last child of the hub model.
        # NOTE(review): in torchvggish this is presumably the `embeddings` MLP
        # (so inputs are already-flattened conv features) -- confirm.
        net_structure = list(self.vggish_layer.children())
        self.vggish_modified = nn.Sequential(*net_structure[-2:-1])

        # Text <-> audio and text <-> image cross-attention; sen_len is fixed at 512.
        self.co_attention_ta = co_attention(
            d_k=fea_dim, d_v=fea_dim, n_heads=self.num_heads, dropout=self.dropout,
            d_model=fea_dim, visual_len=self.num_audioframes, sen_len=512,
            fea_v=self.dim, fea_s=self.dim, pos=False)
        self.co_attention_tv = co_attention(
            d_k=fea_dim, d_v=fea_dim, n_heads=self.num_heads, dropout=self.dropout,
            d_model=fea_dim, visual_len=self.num_frames, sen_len=512,
            fea_v=self.dim, fea_s=self.dim, pos=False)
        # Self-attention over the four pooled modality tokens.
        self.trm = nn.TransformerEncoderLayer(d_model=self.dim, nhead=2, batch_first=True)

        # Per-modality projection heads. linear_comment and linear_intro are not used
        # in forward; they are kept (and created in the same order) so existing
        # checkpoints' state_dict keys and seeded initialisation stay unchanged.
        self.linear_text = nn.Sequential(
            torch.nn.Linear(self.text_dim, fea_dim), torch.nn.ReLU(), nn.Dropout(p=self.dropout))
        self.linear_comment = nn.Sequential(
            torch.nn.Linear(self.comment_dim, fea_dim), torch.nn.ReLU(), nn.Dropout(p=self.dropout))
        self.linear_img = nn.Sequential(
            torch.nn.Linear(self.img_dim, fea_dim), torch.nn.ReLU(), nn.Dropout(p=self.dropout))
        self.linear_video = nn.Sequential(
            torch.nn.Linear(self.video_dim, fea_dim), torch.nn.ReLU(), nn.Dropout(p=self.dropout))
        self.linear_intro = nn.Sequential(
            torch.nn.Linear(self.text_dim, fea_dim), torch.nn.ReLU(), nn.Dropout(p=self.dropout))
        self.linear_audio = nn.Sequential(
            torch.nn.Linear(fea_dim, fea_dim), torch.nn.ReLU(), nn.Dropout(p=self.dropout))
        self.classifier = nn.Linear(fea_dim, 2)

    def forward(self, **kwargs):
        """Score one batch.

        Required keyword tensors. The shapes below are inferred from the layer
        widths and should be confirmed against the data loader:
            title_inputid, title_mask: BERT token ids / attention mask, (batch, seq).
            audioframes: (batch, n_audio, d_in) input to ``vggish_modified``.
            frames: (batch, n_frames, img_dim) per-frame visual features.
            c3d: (batch, n_clips, video_dim) C3D clip features.

        Mask entries (``audioframes_masks``, ``frames_masks``, ``c3d_masks``) may be
        passed but are ignored. All pooling is a plain mean, so padded positions
        contribute to it.

        Returns:
            (output, fea): raw 2-class scores (batch, 2) and the fused
            (batch, fea_dim) feature that produced them.
        """
        # Title: frozen BERT token states projected to (batch, seq, fea_dim).
        title_inputid = kwargs['title_inputid']
        title_mask = kwargs['title_mask']
        fea_text = self.bert(title_inputid, attention_mask=title_mask)['last_hidden_state']
        fea_text = self.linear_text(fea_text)

        # Audio: project, then cross-attend with the text; the updated text is kept.
        fea_audio = self.linear_audio(self.vggish_modified(kwargs['audioframes']))
        fea_audio, fea_text = self.co_attention_ta(
            v=fea_audio, s=fea_text, v_len=fea_audio.shape[1], s_len=fea_text.shape[1])
        fea_audio = torch.mean(fea_audio, -2)

        # Image frames: cross-attend with the audio-conditioned text.
        fea_img = self.linear_img(kwargs['frames'])
        fea_img, fea_text = self.co_attention_tv(
            v=fea_img, s=fea_text, v_len=fea_img.shape[1], s_len=fea_text.shape[1])
        fea_img = torch.mean(fea_img, -2)
        fea_text = torch.mean(fea_text, -2)

        # C3D clips: projected and mean-pooled, no cross-attention.
        fea_video = torch.mean(self.linear_video(kwargs['c3d']), -2)

        # One token per modality -> (batch, 4, fea_dim), mixed by self-attention.
        fea = torch.stack((fea_text, fea_audio, fea_video, fea_img), dim=1)
        fea = self.trm(fea)
        fea = torch.mean(fea, -2)          # (batch, fea_dim)
        output = self.classifier(fea)      # (batch, 2)
        return output, fea