Spaces:
Build error
Build error
| import os | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
| from tqdm import tqdm | |
| from FakeVD.code_test.utils.metrics import * | |
| from FakeVD.code_test.models.SVFEND import SVFENDModel | |
| from FakeVD.code_test.utils.dataloader import SVFENDDataset | |
| from FakeVD.code_test.run import _init_fn, SVFEND_collate_fn | |
| # from VGGish_Feature_Extractor.my_vggish_folder_fun import vggish_audio | |
| from FakeVD.code_test.VGGish_Feature_Extractor.my_vggish_fun import vggish_audio, load_model_vggish | |
| from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import process_video as vgg19_frame | |
| from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import load_model_vgg19 | |
| from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import feature_extractor as c3d_video | |
| from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import load_model_c3d | |
| from FakeVD.code_test.Text_Feature_Extractor.main import video_work as asr_text | |
| from FakeVD.code_test.Text_Feature_Extractor.wav2text import wav2text | |
def load_model(checkpoint_path):
    """Build an SVFEND model and load its weights from *checkpoint_path*.

    The checkpoint tensors are mapped onto CUDA when available, otherwise onto
    the CPU, so a GPU-saved checkpoint can still be loaded on a CPU-only host.
    The module itself is not moved to a device here.

    Weights are loaded non-strictly: keys that are missing from, or unexpected
    in, the checkpoint are tolerated. A ``RuntimeWarning`` reports how many
    there were, so a badly mismatched checkpoint is not silently accepted.

    Args:
        checkpoint_path: Path to a state dict readable by ``torch.load``.

    Returns:
        The ``SVFENDModel`` instance, switched to eval mode.
    """
    import warnings

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SVFENDModel(bert_model='bert-base-chinese', fea_dim=128, dropout=0.1)
    state_dict = torch.load(checkpoint_path, map_location=device)
    # strict=False keeps the original tolerant loading behaviour; the returned
    # key lists are inspected so mismatches are surfaced instead of ignored.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {checkpoint_path!r} does not match SVFENDModel exactly: "
            f"{len(incompatible.missing_keys)} missing key(s), "
            f"{len(incompatible.unexpected_keys)} unexpected key(s).",
            RuntimeWarning,
        )
    model.eval()
    return model
def get_model(checkpoint_path='./FakeVD/code_test/checkpoints/SVFEND/SVFEND/_test_epoch4_0.7943'):
    """Load every model the detection pipeline needs and return them by name.

    Args:
        checkpoint_path: Location of the SVFEND detection checkpoint, passed
            through to ``load_model``.

    Returns:
        A dict with keys ``'model_main'`` (SVFEND detector), ``'model_vggish'``,
        ``'model_vgg19'``, ``'model_c3d'`` and ``'model_text'``.
    """
    # Dict-literal values are evaluated left to right, so the loaders run in
    # the same order as listed here.
    return {
        'model_main': load_model(checkpoint_path),
        'model_vggish': load_model_vggish(),
        'model_vgg19': load_model_vgg19(),
        'model_c3d': load_model_c3d(),
        'model_text': wav2text(),
    }
| # label = 0 if item['annotation']=='真' else 1 | |
def test(model, dataloader):
    """Run *model* over *dataloader* and collect labels, predictions and probabilities.

    The model is moved to CUDA when available (otherwise CPU) and put in eval
    mode. Inference runs without gradient tracking.

    Args:
        model: Module called as ``model(**batch)``. It must return a 2-tuple
            whose first element holds per-class scores along dim 1; the second
            element is ignored.
        dataloader: Iterable yielding mappings of tensors. Each mapping must
            contain a ``'label'`` entry; every entry is moved to the device and
            passed to the model as a keyword argument.

    Returns:
        ``(label, pred, prob)``, three lists with one item per sample:
        ground-truth labels, argmax class indices, and per-class softmax
        probabilities.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    label, pred, prob = [], [], []
    # Enter no_grad once for the whole pass instead of once per batch.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Build a device copy instead of mutating the collated batch in place.
            batch_data = {k: v.to(device) for k, v in batch.items()}
            batch_outputs, _ = model(**batch_data)
            _, batch_preds = torch.max(batch_outputs, 1)
            softmax_probs = F.softmax(batch_outputs, dim=1)
            # No .detach() needed: nothing under no_grad tracks gradients.
            label.extend(batch_data['label'].cpu().tolist())
            pred.extend(batch_preds.cpu().tolist())
            prob.extend(softmax_probs.cpu().tolist())
    return (label, pred, prob)
def _feature_paths(vid, feature_path):
    """Return the per-modality feature locations for video *vid* under *feature_path*."""
    return {
        'VGGish_audio': os.path.join(feature_path, vid + '.pkl'),
        'C3D_video': os.path.join(feature_path, 'C3D/'),
        'VGG19_frame': os.path.join(feature_path, 'VGG19/'),
        'ASR_text': os.path.join(feature_path, 'ASR/' + vid + '.json'),
    }


def _extract_features(models, video_file_path, data_paths):
    """Run every feature extractor on one video, writing to the locations in *data_paths*."""
    video_folder_path = os.path.dirname(video_file_path)
    video_file_name = os.path.basename(video_file_path)
    vggish_audio(models['model_vggish'], video_file_path, data_paths['VGGish_audio'])
    vgg19_frame(models['model_vgg19'], video_file_name, video_folder_path,
                data_paths['VGG19_frame'])
    c3d_video(models['model_c3d'], data_paths['C3D_video'], video_folder_path,
              video_file_name)
    # The ASR text extractor is given both the speech model and the VGGish model.
    asr_text(models['model_text'], models['model_vggish'], video_file_path,
             data_paths['ASR_text'])


def main(models,
         video_file_path,
         preprocessed_flag=False,
         feature_path='./FakeVD/code_test/preprocessed_feature'):
    """Classify one video as real ('真') or fake ('假') with the SVFEND model.

    Args:
        models: Dict with keys ``'model_main'``, ``'model_vggish'``,
            ``'model_vgg19'``, ``'model_c3d'`` and ``'model_text'``.
        video_file_path: Path to the video. Its file name without extension is
            used as the video ID.
        preprocessed_flag: If True, skip feature extraction and read existing
            features from *feature_path*.
        feature_path: Feature directory. The layout is VGGish audio at
            ``<feature_path>/<vid>.pkl``, C3D and VGG19 features under ``C3D/``
            and ``VGG19/``, and ASR text at ``ASR/<vid>.json``.

    Returns:
        The softmax probability of class 1 (fake). The label and both class
        probabilities are also printed.
    """
    vid = os.path.splitext(os.path.basename(video_file_path))[0]
    data_paths = _feature_paths(vid, feature_path)

    if not preprocessed_flag:
        _extract_features(models, video_file_path, data_paths)

    dataset = SVFENDDataset([vid], data_paths)
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            num_workers=0,
                            # Pinned memory only helps host->GPU copies; on a
                            # CPU-only host it just triggers a warning.
                            pin_memory=torch.cuda.is_available(),
                            shuffle=False,
                            worker_init_fn=_init_fn,
                            collate_fn=SVFEND_collate_fn)

    _, preds, probs = test(models['model_main'], dataloader)
    # Class 0 is real ('真'), class 1 is fake ('假').
    annotation = '真' if preds[0] == 0 else '假'
    real_prob = probs[0][0]
    fake_prob = probs[0][1]
    print(annotation, real_prob, fake_prob)
    return fake_prob
if __name__ == "__main__":
    # Demo run on one sample video. Features are extracted from scratch
    # (preprocessed_flag=False) before prediction.
    demo_video = "./FakeVD/dataset/videos_1/douyin_6700861687563570439.mp4"
    loaded_models = get_model()
    main(loaded_models, demo_video, preprocessed_flag=False)