Spaces:
Running
Running
| import argparse | |
| import numpy as np | |
| import torch | |
| import torch.backends.cudnn as cudnn | |
| import os | |
| import warnings | |
| import json | |
| from pathlib import Path | |
| from timm.models import create_model | |
| import my_models # registers TALL_SWIN | |
| import utils | |
| from video_dataset import VideoDataSet | |
| from video_dataset_aug import get_augmentor, build_dataflow | |
| from video_dataset_config import get_dataset_config, DATASET_CONFIG | |
| from sklearn.metrics import ( | |
| accuracy_score, balanced_accuracy_score, | |
| precision_recall_fscore_support, | |
| confusion_matrix, classification_report, | |
| roc_auc_score, roc_curve, | |
| average_precision_score, precision_recall_curve | |
| ) | |
| import matplotlib.pyplot as plt | |
# Ignore every UserWarning (and any subclass of it) for the rest of the process.
# NOTE(review): this likely also hides library metric warnings that subclass
# UserWarning (e.g. undefined-metric warnings) — confirm that is intended.
warnings.simplefilter("ignore", category=UserWarning)
def _str2bool(value):
    """Parse a command-line boolean (``true/false``, ``yes/no``, ``1/0``, ...).

    Without an explicit ``type``, argparse stores whatever string the user
    passes, so ``--use_lmdb False`` would yield the truthy string "False".

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {'1', 'true', 't', 'yes', 'y', 'on'}:
        return True
    if text in {'0', 'false', 'f', 'no', 'n', 'off'}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the evaluation argument parser (without ``-h/--help``).

    ``add_help=False`` lets the entry point reuse this parser through
    ``argparse.ArgumentParser(..., parents=[get_args_parser()])``.

    The value-taking boolean options (``--threed_data``, ``--dense_sampling``,
    ``--use_lmdb``, ``--use_pyav``) accept ``true/false``, ``yes/no`` or
    ``1/0`` (case-insensitive). Giving the flag with no value means ``True``.
    """
    parser = argparse.ArgumentParser('DeiT evaluation script', add_help=False)
    parser.add_argument('--model', default='TALL_SWIN', type=str)
    parser.add_argument('--model_name', default="TALL_SWIN")
    parser.add_argument('--batch-size', default=2, type=int)

    # Dataset parameters
    parser.add_argument('--data_txt_dir', type=str, default='##path_for_dataset_txt##')
    parser.add_argument('--data_dir', type=str, default="##path_for_dataset##")
    parser.add_argument('--dataset', default='ffpp', choices=list(DATASET_CONFIG.keys()))
    parser.add_argument('--duration', default=1, type=int)
    parser.add_argument('--frames_per_group', default=1, type=int)
    parser.add_argument('--threed_data', default=False, type=_str2bool, nargs='?', const=True)
    parser.add_argument('--input_size', default=224, type=int)
    parser.add_argument('--disable_scaleup', action='store_true')
    parser.add_argument('--random_sampling', action='store_true')
    parser.add_argument('--dense_sampling', default=True, type=_str2bool, nargs='?', const=True)
    parser.add_argument('--augmentor_ver', default='v1', type=str, choices=['v1', 'v2'])
    parser.add_argument('--scale_range', default=[256, 320], type=int, nargs="+")
    parser.add_argument('--modality', default='rgb', type=str)
    parser.add_argument('--use_lmdb', default=False, type=_str2bool, nargs='?', const=True)
    parser.add_argument('--use_pyav', default=False, type=_str2bool, nargs='?', const=True)

    # temporal module / model params
    parser.add_argument('--pretrained', action='store_true', default=False)
    parser.add_argument('--temporal_module_name', default=None, type=str,
                        choices=['ResNet3d', 'TAM', 'TTAM', 'TSM', 'TTSM', 'MSA'])
    parser.add_argument('--temporal_attention_only', action='store_true', default=False)
    parser.add_argument('--no_token_mask', action='store_true', default=False)
    parser.add_argument('--temporal_heads_scale', default=1.0, type=float)
    parser.add_argument('--temporal_mlp_scale', default=1.0, type=float)
    parser.add_argument('--rel_pos', action='store_true', default=False)
    parser.add_argument('--temporal_pooling', type=str, default=None,
                        choices=['avg', 'max', 'conv', 'depthconv'])
    parser.add_argument('--bottleneck', default=None, choices=['regular', 'dw'])
    parser.add_argument('--window_size', default=7, type=int)
    parser.add_argument('--thumbnail_rows', default=3, type=int)
    parser.add_argument('--hpe_to_token', default=False, action='store_true')
    parser.add_argument('--drop', type=float, default=0.0)
    parser.add_argument('--drop-path', type=float, default=0.1)
    parser.add_argument('--drop-block', type=float, default=None)

    # runtime
    parser.add_argument('--output_dir', default="./output")
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--num_crops', default=1, type=int, choices=[1, 3, 5, 10])
    parser.add_argument('--num_clips', default=3, type=int)
    parser.add_argument('--world_size', default=1, type=int)
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--dist_url', default='env://')

    # checkpoint / metrics output
    parser.add_argument('--initial_checkpoint', type=str, default='',
                        help='path to .pth/.pth.tar checkpoint (expects key "model")')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='threshold to decide class 1 (fake) from prob[:,1]')
    parser.add_argument('--metrics_out', default='', type=str,
                        help='folder to save metrics.json and plots (default: output_dir)')
    parser.add_argument('--save_plots', action='store_true',
                        help='save cm.png / roc.png / pr.png')
    return parser
def eval_with_outputs(data_loader, model, device, threshold: float = 0.5):
    """Run ``model`` over ``data_loader`` and collect binary-classification outputs.

    Args:
        data_loader: iterable yielding ``(samples, targets)`` batches.
            Presumably ``targets`` is a 1-D tensor of 0/1 labels — TODO confirm.
        model: module returning logits of shape ``[B, C]`` or ``[B * K, C]``.
            In the latter case the ``K`` rows belonging to each sample are
            averaged before the softmax.
        device: device the samples are moved to before the forward pass.
        threshold: a sample is predicted as class 1 (fake) when its class-1
            softmax probability is ``>= threshold``.

    Returns:
        ``(y_true, y_score, y_pred)``: 1-D NumPy arrays of int labels, float
        class-1 probabilities, and int predictions. All three are empty when
        ``data_loader`` yields no batches.

    Raises:
        RuntimeError: if the number of logit rows is not a multiple of the
            number of targets in the batch.
    """
    model.eval()
    thr = float(threshold)
    true_parts, score_parts, pred_parts = [], [], []

    # Inference only: without no_grad autograd would build and keep a graph
    # for every forward pass, wasting memory (and risking OOM on video models).
    with torch.no_grad():
        for samples, targets in data_loader:
            samples = samples.to(device, non_blocking=True)
            # Labels are only needed on the host, so skip the device round-trip.
            targets = torch.as_tensor(targets)
            logits = model(samples)  # [B,2] or [B*K,2]

            # If logits came per-clip, aggregate per video.
            B = targets.shape[0]
            if logits.shape[0] != B:
                if logits.shape[0] % B != 0:
                    raise RuntimeError(
                        f"logits batch ({logits.shape[0]}) is not a multiple of target batch ({B})."
                    )
                K = logits.shape[0] // B
                # NOTE(review): assumes the K rows of one video are contiguous
                # (sample-major order) — confirm against the model's output layout.
                # reshape (not view) also copes with non-contiguous logits.
                logits = logits.reshape(B, K, -1).mean(dim=1)  # [B,2]

            p1 = torch.softmax(logits, dim=1)[:, 1]  # class 1 (fake) score
            # >>> THIS is the THRESHOLD <<<
            hat = (p1 >= thr).long()

            true_parts.append(targets.detach().cpu().numpy())
            score_parts.append(p1.detach().cpu().numpy())
            pred_parts.append(hat.detach().cpu().numpy())

    if not true_parts:
        # np.concatenate([]) raises; return well-typed empty arrays instead.
        return (np.empty(0, dtype=int), np.empty(0, dtype=float),
                np.empty(0, dtype=int))

    y_true = np.concatenate(true_parts).astype(int)
    y_score = np.concatenate(score_parts).astype(float)
    y_pred = np.concatenate(pred_parts).astype(int)
    return y_true, y_score, y_pred
def plot_confusion(cm, out_path):
    """Save confusion matrix ``cm`` as an annotated heat map to ``out_path``.

    Rows are drawn against the "True" axis and columns against the
    "Predicted" axis. Every cell is labelled with its value. The image is
    written at 200 dpi and the figure is closed afterwards.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.imshow(cm)
    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    # Text position is (x=column, y=row), hence the swapped index order.
    for (row, col), count in np.ndenumerate(cm):
        ax.text(col, row, str(count), ha="center", va="center")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
def plot_roc(y, scores, out_path):
    """Save the ROC curve of ``scores`` against labels ``y`` to ``out_path``.

    The legend reports the ROC-AUC, and a dashed diagonal marks chance level.
    Raises whatever ``roc_auc_score`` raises (e.g. for a single-class ``y``).
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y, scores)
    area = roc_auc_score(y, scores)

    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(false_pos_rate, true_pos_rate, label=f"AUC={area:.4f}")
    ax.plot([0, 1], [0, 1], "--", label="Chance")
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
def plot_pr(y, scores, out_path):
    """Save the precision-recall curve of ``scores`` against ``y`` to ``out_path``.

    Recall is on the x-axis and precision on the y-axis. The legend reports
    the average precision (AP).
    """
    precision, recall, _thresholds = precision_recall_curve(y, scores)
    avg_precision = average_precision_score(y, scores)

    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(recall, precision, label=f"AP={avg_precision:.4f}")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
def _model_norm_stats(model):
    """Return ``(mean, std)`` from the model's ``default_cfg``, defaulting to 0.5 per channel.

    Looks through a ``.module`` wrapper when present. This avoids assuming the
    model was wrapped just because the run is distributed.
    """
    core = getattr(model, 'module', model)
    cfg = getattr(core, 'default_cfg', None) or {}
    mean = cfg['mean'] if 'mean' in cfg else (0.5, 0.5, 0.5)
    std = cfg['std'] if 'std' in cfg else (0.5, 0.5, 0.5)
    return mean, std


def _load_weights(model, checkpoint_path):
    """Load evaluation weights from ``checkpoint_path`` into ``model``.

    ``{"model": state_dict, ...}`` checkpoints go through ``utils.load_checkpoint``.
    Anything else is treated as a raw state dict: a ``"state_dict"`` key is
    unwrapped, and a ``module.`` prefix shared by every key is stripped. The
    result is loaded non-strictly, and missing/unexpected keys are reported.

    Raises:
        RuntimeError: if the checkpoint matches none of the model's parameters,
            which would otherwise silently evaluate randomly initialised weights.
    """
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
    # rejects checkpoints that pickle extra objects (e.g. an argparse.Namespace);
    # pass weights_only=False only for trusted files.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # many checkpoints come as {"model": state_dict, ...}
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        utils.load_checkpoint(model, checkpoint["model"])
        return

    # Otherwise it should be a direct state_dict.
    state_dict = checkpoint
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    if isinstance(state_dict, dict) and state_dict and all(
            isinstance(key, str) and key.startswith('module.') for key in state_dict):
        # Saved from a DistributedDataParallel-wrapped model.
        state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}

    incompatible = model.load_state_dict(state_dict, strict=False)
    n_model_keys = len(model.state_dict())
    if n_model_keys and len(incompatible.missing_keys) == n_model_keys:
        raise RuntimeError(
            f"Checkpoint {checkpoint_path!r} shares no parameter names with the model; "
            "refusing to evaluate randomly initialised weights."
        )
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"Loaded checkpoint non-strictly: {len(incompatible.missing_keys)} missing, "
              f"{len(incompatible.unexpected_keys)} unexpected keys.")


def _binary_metrics(y_true, y_score, y_pred):
    """Compute the thresholded and ranking metrics written to metrics.json.

    ROC-AUC and PR-AUC are returned as ``None`` when ``y_true`` has fewer than
    two classes, because ``roc_auc_score`` raises in that case. The confusion
    matrix is always 2x2 (label order ``[0, 1]``), even if a class is absent.
    """
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    has_both_classes = np.unique(y_true).size == 2
    return {
        "acc": float(accuracy_score(y_true, y_pred)),
        "balanced_acc": float(balanced_accuracy_score(y_true, y_pred)),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "roc_auc": float(roc_auc_score(y_true, y_score)) if has_both_classes else None,
        "pr_auc": float(average_precision_score(y_true, y_score)) if has_both_classes else None,
        "confusion_matrix": confusion_matrix(y_true, y_pred, labels=[0, 1]),
    }


def _fmt_metric(value):
    """Format a metric for the console: 4 decimals, or "n/a" when undefined (None)."""
    return "n/a" if value is None else f"{value:.4f}"


def main(args):
    """Evaluate ``--initial_checkpoint`` on the validation list and save the results.

    Steps:
        1. Set up distributed mode and seeding.
        2. Build the model and the validation loader.
        3. Load the weights and run :func:`eval_with_outputs`.
        4. Print the metrics.
        5. On rank 0 only, write ``metrics.json`` and ``eval_outputs.npz`` (and
           plots with ``--save_plots``) to ``--metrics_out``, falling back to
           ``--output_dir``.

    Raises:
        RuntimeError: if no checkpoint is given, the checkpoint matches no model
            parameters, or the validation loader yields no samples.
    """
    # Fail fast: building the model and dataset is pointless without weights.
    if not args.initial_checkpoint:
        raise RuntimeError("Please pass --initial_checkpoint pointing to the model checkpoint.")

    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    (num_classes, _train_list_name, val_list_name, _test_list_name,
     filename_seperator, image_tmpl, filter_video, _label_file) = \
        get_dataset_config(args.dataset, args.use_lmdb)
    args.num_classes = num_classes
    args.input_channels = 3 if args.modality == 'rgb' else 2 * 5

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        duration=args.duration,
        hpe_to_token=args.hpe_to_token,
        rel_pos=args.rel_pos,
        window_size=args.window_size,
        thumbnail_rows=args.thumbnail_rows,
        token_mask=not args.no_token_mask,
        online_learning=False,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        use_checkpoint=False
    )
    model.to(device)

    # The model is never wrapped here, so read default_cfg from it directly.
    mean, std = _model_norm_stats(model)

    # dataset (validation list)
    val_list = os.path.join(args.data_txt_dir, val_list_name)
    val_augmentor = get_augmentor(
        False, args.input_size, mean, std, args.disable_scaleup,
        threed_data=args.threed_data, version=args.augmentor_ver,
        scale_range=args.scale_range, num_clips=args.num_clips,
        num_crops=args.num_crops, dataset=args.dataset
    )
    dataset_val = VideoDataSet(
        args.data_dir, val_list,
        args.duration, args.frames_per_group,
        num_clips=args.num_clips,
        modality=args.modality,
        dense_sampling=args.dense_sampling,
        image_tmpl=image_tmpl,
        transform=val_augmentor,
        is_train=False, test_mode=False,
        seperator=filename_seperator, filter_video=filter_video
    )
    data_loader_val = build_dataflow(
        dataset_val, is_train=False, batch_size=args.batch_size,
        workers=args.num_workers, is_distributed=args.distributed
    )

    _load_weights(model, args.initial_checkpoint)

    # eval
    y_true, y_score, y_pred = eval_with_outputs(
        data_loader_val, model, device, threshold=args.threshold
    )
    if len(y_true) == 0:
        raise RuntimeError(
            "The validation loader yielded no samples; check --data_dir / --data_txt_dir."
        )

    metrics = _binary_metrics(y_true, y_score, y_pred)
    print(f"\nN={len(y_true)} | thr={args.threshold:.3f}")
    print(f"acc={metrics['acc']:.4f} | bacc={metrics['balanced_acc']:.4f} | "
          f"prec={metrics['precision']:.4f} | rec={metrics['recall']:.4f} | "
          f"f1={metrics['f1']:.4f} | roc_auc={_fmt_metric(metrics['roc_auc'])} | "
          f"pr_auc={_fmt_metric(metrics['pr_auc'])}")
    print(classification_report(y_true, y_pred, digits=4, zero_division=0))

    # Only one process writes, so ranks do not race on the same output files.
    if utils.get_rank() != 0:
        return
    if getattr(args, 'distributed', False):
        print("NOTE: distributed run - the metrics above cover this process's loader only; "
              "if build_dataflow shards the data across ranks they are not whole-split metrics.")

    # A blank/whitespace --metrics_out falls back to output_dir (and then cwd).
    outdir = (args.metrics_out or "").strip() or args.output_dir or "."
    os.makedirs(outdir, exist_ok=True)

    cm = metrics["confusion_matrix"]
    out_json = {
        "threshold": float(args.threshold),
        "acc": metrics["acc"],
        "balanced_acc": metrics["balanced_acc"],
        "precision": metrics["precision"],
        "recall": metrics["recall"],
        "f1": metrics["f1"],
        "roc_auc": metrics["roc_auc"],  # null when y_true has a single class
        "pr_auc": metrics["pr_auc"],    # null when y_true has a single class
        "confusion_matrix": cm.tolist(),
        "n": int(len(y_true)),
    }
    with open(os.path.join(outdir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(out_json, f, indent=2)
    np.savez(os.path.join(outdir, "eval_outputs.npz"),
             y_true=y_true, y_score=y_score, y_pred=y_pred)

    if args.save_plots:
        plot_confusion(cm, os.path.join(outdir, "cm.png"))
        if metrics["roc_auc"] is not None:
            plot_roc(y_true, y_score, os.path.join(outdir, "roc.png"))
            plot_pr(y_true, y_score, os.path.join(outdir, "pr.png"))
        else:
            print("Skipping roc.png / pr.png: y_true contains a single class.")
        print(f"\n✔ Plots + metrics saved in: {os.path.abspath(outdir)}")
    else:
        print(f"\n✔ Metrics saved in: {os.path.abspath(os.path.join(outdir, 'metrics.json'))}")
if __name__ == '__main__':
    # Wrap the shared, help-less argument definitions in a parser that owns -h/--help.
    parser = argparse.ArgumentParser(
        'DeiT evaluation script',
        parents=[get_args_parser()],
    )
    args = parser.parse_args()
    # Create the output folder up front (including parents) when one is configured.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    main(args)