Spaces:
Build error
Build error
| import os | |
| import torch | |
| import torchvision.transforms as transforms | |
| import torch.utils.data | |
| import numpy as np | |
| from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score | |
| import pickle | |
| from tqdm import tqdm | |
| from datetime import datetime | |
| from copy import deepcopy | |
| from dataset_paths import DATASET_PATHS | |
| import random | |
| from datasetss import create_test_dataloader | |
| from utilss.logger import create_logger | |
| import options | |
| from networks.validator import Validator | |
# Single seed shared by every random number generator seeded below.
SEED = 0


def set_seed():
    """Seed torch (CPU and CUDA), NumPy and the stdlib `random` module with SEED."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seed_rng in seeders:
        seed_rng(SEED)
# Three-value mean / std tables keyed by "imagenet" and "clip".
# NOTE(review): presumably per-channel (RGB) normalization statistics for the
# corresponding pretrained backbones -- confirm against the data pipeline.
MEAN = dict(
    imagenet=[0.485, 0.456, 0.406],
    clip=[0.48145466, 0.4578275, 0.40821073],
)

STD = dict(
    imagenet=[0.229, 0.224, 0.225],
    clip=[0.26862954, 0.26130258, 0.27577711],
)
def find_best_threshold(y_true, y_pred):
    """Return the score threshold that maximizes accuracy.

    A sample is predicted fake (1) when its score is ``>=`` the threshold and
    real (0) otherwise.

    Args:
        y_true: 1-D array-like of labels, 0 for real and 1 for fake.
        y_pred: 1-D array-like of scores, same length as ``y_true``.

    Returns:
        The midpoint between the highest real score and the lowest fake score
        when the two classes are perfectly separable; otherwise the candidate
        (taken from ``y_pred``) with the highest accuracy, the last such
        candidate winning ties. Returns 0 when ``y_pred`` is empty.
    """
    labels = np.asarray(y_true)
    scores = np.asarray(y_pred)
    N = labels.shape[0]

    # Split by label rather than by position, so the result no longer depends
    # on the samples being ordered "all real first, then all fake".
    is_fake = labels == 1
    real_scores = scores[labels == 0]
    fake_scores = scores[is_fake]

    # Perfectly separable case. Skipped when a class is missing, because
    # max()/min() of an empty array raises.
    if real_scores.size and fake_scores.size and real_scores.max() <= fake_scores.min():
        return (real_scores.max() + fake_scores.min()) / 2

    best_acc = 0
    best_thres = 0
    for thres in scores:
        # Compare against the untouched scores: binarizing a copy in place
        # (">= thres -> 1" then "< thres -> 0") re-zeroes the freshly written
        # ones whenever thres > 1.
        acc = ((scores >= thres) == is_fake).sum() / N
        if acc >= best_acc:
            best_thres = thres
            best_acc = acc
    return best_thres
def calculate_acc(y_true, y_pred, thres):
    """Compute accuracy at a given threshold for real, fake and all samples.

    A sample counts as predicted fake when its score is strictly greater than
    ``thres``.

    Args:
        y_true: array of labels, 0 for real and 1 for fake.
        y_pred: array of scores aligned with ``y_true``.
        thres: decision threshold applied to ``y_pred``.

    Returns:
        Tuple ``(r_acc, f_acc, acc)``: accuracy on the label-0 samples, on the
        label-1 samples, and on the whole set.
    """
    predicted_fake = y_pred > thres
    real_mask = y_true == 0
    fake_mask = y_true == 1

    r_acc = accuracy_score(y_true[real_mask], predicted_fake[real_mask])
    f_acc = accuracy_score(y_true[fake_mask], predicted_fake[fake_mask])
    acc = accuracy_score(y_true, predicted_fake)
    return r_acc, f_acc, acc
def validate(model, loader, logger, find_thres=False):
    """Score every batch in ``loader`` and summarize detection quality.

    For each batch the function calls ``model.set_input(batch)``, runs
    ``model.model(model.input)``, applies a sigmoid to the flattened output and
    pairs the resulting scores with the labels found in ``batch[1]``.

    Args:
        model: wrapper exposing ``set_input``, ``input`` and ``model``.
        loader: iterable of batches; ``batch[1]`` must hold the labels.
        logger: logger used to report the loader length.
        find_thres: when True, also search for the accuracy-maximizing
            threshold and report accuracies at that threshold.

    Returns:
        ``(ap, r_acc0, f_acc0, acc0)`` when ``find_thres`` is False, where the
        accuracies use a 0.5 threshold; otherwise
        ``(ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres)``.
    """
    labels, scores = [], []
    with torch.no_grad():
        logger.info("Length of dataset: %d" % len(loader))
        progress = tqdm(loader)
        for batch in progress:
            # Show the wall-clock time next to the progress bar.
            progress.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
            model.set_input(batch)
            logits = model.model(model.input).view(-1)
            scores.extend(logits.sigmoid().tolist())
            labels.extend(batch[1].flatten().tolist())

    y_true = np.array(labels)
    y_pred = np.array(scores)

    # To plot PR/ROC curves, dump the raw predictions here, e.g.:
    # torch.save( torch.stack( [torch.tensor(y_true), torch.tensor(y_pred)] ), 'baseline_predication_for_pr_roc_curve.pth' )

    # Average precision over the raw scores.
    ap = average_precision_score(y_true, y_pred)

    # Accuracies at the fixed 0.5 threshold.
    fixed_metrics = calculate_acc(y_true, y_pred, 0.5)
    if not find_thres:
        return (ap, *fixed_metrics)

    # Accuracies at the threshold that maximizes accuracy on this data.
    best_thres = find_best_threshold(y_true, y_pred)
    tuned_metrics = calculate_acc(y_true, y_pred, best_thres)
    return (ap, *fixed_metrics, *tuned_metrics, best_thres)
| # = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # | |
def recursively_read(rootdir, must_contain, exts=("png", "jpg", "JPEG", "jpeg", "bmp")):
    """Collect the paths of all image files found under ``rootdir``.

    Args:
        rootdir: directory to walk recursively.
        must_contain: substring that the joined path must contain; pass ``''``
            to accept every path.
        exts: accepted file extensions, without the leading dot. Matching is
            case-sensitive.

    Returns:
        List of paths (``os.path.join`` of the walked directory and file name),
        in ``os.walk`` order.
    """
    out = []
    for dirpath, _dirnames, filenames in os.walk(rootdir):
        for filename in filenames:
            # Use the real (last) extension: splitting on the first dot raises
            # on names without a dot, rejects "img.v2.png" and accepts
            # "photo.jpg.bak".
            extension = os.path.splitext(filename)[1][1:]
            full_path = os.path.join(dirpath, filename)
            if extension in exts and must_contain in full_path:
                out.append(full_path)
    return out
def get_list(path, must_contain=''):
    """Return the image paths referenced by ``path`` that contain ``must_contain``.

    Args:
        path: either a path containing ".pickle" (a pickled list of path
            strings) or a directory to scan with ``recursively_read``.
        must_contain: substring every returned path must contain; the default
            ``''`` keeps everything.

    Returns:
        List of matching path strings.
    """
    if ".pickle" not in path:
        return recursively_read(path, must_contain)

    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(path, 'rb') as handle:
        entries = pickle.load(handle)
    return [entry for entry in entries if must_contain in entry]
if __name__ == '__main__':
    # Evaluate a checkpoint on the test loader and append the metrics to text files.
    val_opt = options.TestOptions().parse()

    # Logs go to <output>/<name>.
    output_dir = os.path.join(val_opt.output, val_opt.name)
    os.makedirs(output_dir, exist_ok=True)
    logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
    logger.info(f"working dir: {output_dir}")

    model = Validator(val_opt)
    model.load_state_dict(val_opt.ckpt)
    logger.info("ckpt loaded!")

    val_loader = create_test_dataloader(val_opt, clip_model=None, transform=model.clip_model.preprocess)
    ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres = validate(model, val_loader, logger, find_thres=True)
    print(f"ap: {ap}, r_acc0: {r_acc0}, f_acc0: {f_acc0}, acc0:{acc0}, r_acc1: {r_acc1}, f_acc1: {f_acc1}, acc1: {acc1}, best_thres: {best_thres} ")

    # The metric files are appended under <name> (relative to the working
    # directory), which is not the directory created above. Create it first so
    # the writes do not fail with FileNotFoundError on a fresh run.
    # NOTE(review): writing to `output_dir` may have been the intent -- confirm.
    results_dir = val_opt.name
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    with open(os.path.join(results_dir, 'ap.txt'), 'a') as f:
        f.write(str(round(ap*100, 2))+'\n')
    with open(os.path.join(results_dir, 'acc0.txt'), 'a') as f:
        f.write(str(round(r_acc0*100, 2))+' '+str(round(f_acc0*100, 2))+' '+str(round(acc0*100, 2))+'\n')