# NOTE(review): the original file carried web-scrape artifacts here
# ("Spaces:", "Build error" x2) — replaced with this comment so the
# module remains valid Python. Script: single-video fake detection demo.
| import os | |
| import torch | |
| import torchvision.transforms as transforms | |
| import torchvision.transforms.functional as TF | |
| from torchvision.io import read_video | |
| import torch.utils.data | |
| import numpy as np | |
| from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score | |
| import pickle | |
| from tqdm import tqdm | |
| from datetime import datetime | |
| from copy import deepcopy | |
| from dataset_paths import DATASET_PATHS | |
| import random | |
| from datasetss import create_test_dataloader | |
| from utilss.logger import create_logger | |
| import options | |
| from networks.validator import Validator | |
def get_model():
    """Build a Validator from the parsed test options and load its checkpoint.

    Returns:
        Validator: model with weights restored from ``val_opt.ckpt``.
    """
    val_opt = options.TestOptions().parse(print_options=False)
    # Ensure the output directory exists; kept for parity with the
    # (commented-out) logger setup even though nothing is written yet.
    output_dir = os.path.join(val_opt.output, val_opt.name)
    os.makedirs(output_dir, exist_ok=True)
    # logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
    print("working...")  # plain string: no placeholders, f-string was unnecessary
    model = Validator(val_opt)
    model.load_state_dict(val_opt.ckpt)
    print("ckpt loaded!")
    return model
def detect_video(video_path, model):
    """Score a single video file for fakeness.

    Args:
        video_path: path to the video file (anything ``str()``-convertible).
        model: Validator whose ``clip_model.preprocess`` transform and inner
            ``model`` network are used for inference.

    Returns:
        float: sigmoid score in [0, 1] taken from the first element of the
        prediction — presumably P(fake); confirm against training labels.
    """
    frames, _, _ = read_video(str(video_path), pts_unit='sec')
    # Use at most the first 16 frames; shorter videos pass through as-is.
    frames = frames[:16]
    frames = frames.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
    # Preprocess each frame with the CLIP transform, then stack back into a
    # (T, C, H', W') batch. torch.stack replaces the original
    # cat-of-unsqueezed-frames with an identical result; the redundant
    # torch.as_tensor on an existing tensor is dropped.
    video_frames = torch.stack(
        [model.clip_model.preprocess(TF.to_pil_image(frame)) for frame in frames]
    )
    with torch.no_grad():
        # The label tensor is a dummy 0 — set_input expects (frames, labels).
        model.set_input([video_frames, torch.tensor([0])])
        pred = model.model(model.input).view(-1).unsqueeze(1).sigmoid()
    return pred[0].item()
if __name__ == '__main__':
    # Demo: score one hard-coded video. (Commented-out duplicate of
    # get_model/detect_video that previously lived here has been removed.)
    video_path = '../../dataset/MSRVTT/videos/all/video1.mp4'
    model = get_model()
    pred = detect_video(video_path, model)
    # Threshold at 0.5: scores above are reported as "Fake".
    if pred > 0.5:
        print(f"Fake: {pred*100:.2f}%")
    else:
        print(f"Real: {(1-pred)*100:.2f}%")