"""Evaluate a video classification checkpoint on a dataset's validation list.

The model is created through ``timm.models.create_model`` (``my_models`` is
imported for its registration side effect). It is run over the validation
list returned by ``get_dataset_config``. Binary metrics are computed with
class index 1 treated as "fake". The script writes ``metrics.json`` and
``eval_outputs.npz`` and, optionally, confusion-matrix / ROC / PR plots.
"""
import argparse
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import os
import warnings
import json
from pathlib import Path

from timm.models import create_model
import my_models  # registers TALL_SWIN
import utils
from video_dataset import VideoDataSet
from video_dataset_aug import get_augmentor, build_dataflow
from video_dataset_config import get_dataset_config, DATASET_CONFIG

from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_recall_fscore_support,
    confusion_matrix, classification_report, roc_auc_score, roc_curve,
    average_precision_score, precision_recall_curve
)
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore", category=UserWarning)

# Fallback normalisation used when the model config does not provide one.
_DEFAULT_MEAN_STD = (0.5, 0.5, 0.5)
# Class labels of the binary task; pinned so the confusion matrix is always 2x2.
_BINARY_LABELS = [0, 1]


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    Without a ``type``, argparse stores the raw string, so ``--flag False``
    would be the truthy string ``"False"``. Non-string defaults pass through.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Return the argument parser (without ``-h``) for this evaluation script."""
    parser = argparse.ArgumentParser('DeiT evaluation script', add_help=False)
    parser.add_argument('--model', default='TALL_SWIN', type=str)
    parser.add_argument('--model_name', default="TALL_SWIN")
    parser.add_argument('--batch-size', default=2, type=int)

    # Dataset parameters
    parser.add_argument('--data_txt_dir', type=str, default='##path_for_dataset_txt##')
    parser.add_argument('--data_dir', type=str, default="##path_for_dataset##")
    parser.add_argument('--dataset', default='ffpp', choices=list(DATASET_CONFIG.keys()))
    parser.add_argument('--duration', default=1, type=int)
    parser.add_argument('--frames_per_group', default=1, type=int)
    parser.add_argument('--threed_data', default=False, type=_str2bool)
    parser.add_argument('--input_size', default=224, type=int)
    parser.add_argument('--disable_scaleup', action='store_true')
    parser.add_argument('--random_sampling', action='store_true')
    parser.add_argument('--dense_sampling', default=True, type=_str2bool)
    parser.add_argument('--augmentor_ver', default='v1', type=str, choices=['v1', 'v2'])
    parser.add_argument('--scale_range', default=[256, 320], type=int, nargs="+")
    parser.add_argument('--modality', default='rgb', type=str)
    parser.add_argument('--use_lmdb', default=False, type=_str2bool)
    parser.add_argument('--use_pyav', default=False, type=_str2bool)

    # temporal module / model params
    parser.add_argument('--pretrained', action='store_true', default=False)
    parser.add_argument('--temporal_module_name', default=None, type=str,
                        choices=['ResNet3d', 'TAM', 'TTAM', 'TSM', 'TTSM', 'MSA'])
    parser.add_argument('--temporal_attention_only', action='store_true', default=False)
    parser.add_argument('--no_token_mask', action='store_true', default=False)
    parser.add_argument('--temporal_heads_scale', default=1.0, type=float)
    parser.add_argument('--temporal_mlp_scale', default=1.0, type=float)
    parser.add_argument('--rel_pos', action='store_true', default=False)
    parser.add_argument('--temporal_pooling', type=str, default=None,
                        choices=['avg', 'max', 'conv', 'depthconv'])
    parser.add_argument('--bottleneck', default=None, choices=['regular', 'dw'])
    parser.add_argument('--window_size', default=7, type=int)
    parser.add_argument('--thumbnail_rows', default=3, type=int)
    parser.add_argument('--hpe_to_token', default=False, action='store_true')
    parser.add_argument('--drop', type=float, default=0.0)
    parser.add_argument('--drop-path', type=float, default=0.1)
    parser.add_argument('--drop-block', type=float, default=None)

    # runtime
    parser.add_argument('--output_dir', default="./output")
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--num_crops', default=1, type=int, choices=[1, 3, 5, 10])
    parser.add_argument('--num_clips', default=3, type=int)
    parser.add_argument('--world_size', default=1, type=int)
    # Newer torch launchers pass --local-rank; accept both spellings.
    parser.add_argument("--local_rank", "--local-rank", type=int)
    parser.add_argument('--dist_url', default='env://')

    # checkpoint
    parser.add_argument('--initial_checkpoint', type=str, default='',
                        help='path to .pth/.pth.tar checkpoint (expects key "model")')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='threshold to decide class 1 (fake) from prob[:,1]')
    parser.add_argument('--metrics_out', default='', type=str,
                        help='folder to save metrics.json and plots (default: output_dir)')
    parser.add_argument('--save_plots', action='store_true',
                        help='save cm.png / roc.png / pr.png')
    return parser


def _concat_or_empty(chunks, dtype):
    """Concatenate per-batch arrays, or return an empty array when there are none."""
    if not chunks:
        return np.empty(0, dtype=dtype)
    return np.concatenate(chunks).astype(dtype)


@torch.no_grad()
def eval_with_outputs(data_loader, model, device, threshold: float = 0.5):
    """Run ``model`` over ``data_loader`` and collect binary predictions.

    Args:
        data_loader: iterable of ``(samples, targets)`` batches.
        model: network returning logits; column 1 of the softmax is used as
            the class-1 ("fake") score, so the head must have >= 2 outputs.
        device: device the batches are moved to.
        threshold: a sample is predicted as class 1 when its class-1
            probability is ``>= threshold``.

    Returns:
        ``(y_true, y_score, y_pred)`` numpy arrays holding integer labels, float
        class-1 probabilities and integer predictions. All three are empty when
        the loader yields no batches.

    Raises:
        RuntimeError: if the number of logit rows is not a multiple of the
            number of targets in the batch.
    """
    model.eval()
    thr = float(threshold)
    y_true, y_score, y_pred = [], [], []

    for samples, targets in data_loader:
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        logits = model(samples)  # [B, C] or [B*K, C]

        # If the model emitted one row per clip, average the clips of each video.
        batch = targets.shape[0]
        if logits.shape[0] != batch:
            if logits.shape[0] % batch != 0:
                raise RuntimeError(
                    f"logits batch ({logits.shape[0]}) is not a multiple of target batch ({batch})."
                )
            clips = logits.shape[0] // batch
            # NOTE(review): assumes the rows are grouped per video
            # (v0c0, v0c1, ..., v1c0, ...) -- confirm against the model's output layout.
            logits = logits.reshape(batch, clips, -1).mean(dim=1)  # [B, C]

        probs = torch.softmax(logits, dim=1)
        p_fake = probs[:, 1]
        # Decision rule: class 1 (fake) iff its probability reaches the threshold.
        pred = (p_fake >= thr).long()

        y_true.append(targets.cpu().numpy())
        y_score.append(p_fake.cpu().numpy())
        y_pred.append(pred.cpu().numpy())

    return (
        _concat_or_empty(y_true, int),
        _concat_or_empty(y_score, float),
        _concat_or_empty(y_pred, int),
    )


def _gather_across_ranks(*arrays):
    """Concatenate numpy arrays from every process of the default process group.

    Returns ``arrays`` unchanged when ``torch.distributed`` is not initialised.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return arrays
    # NOTE(review): if the loader's distributed sampler pads the dataset to a
    # multiple of the world size, a few videos may be counted twice -- confirm
    # against build_dataflow.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, arrays)
    return tuple(np.concatenate([part[i] for part in gathered]) for i in range(len(arrays)))


def _ranking_metric(metric_fn, y_true, y_score):
    """Return ``metric_fn(y_true, y_score)`` as float, or NaN if only one class is present."""
    if np.unique(y_true).size < 2:
        return float("nan")
    return float(metric_fn(y_true, y_score))


def _json_float(value):
    """Convert to a JSON-safe float; NaN becomes ``None`` (JSON ``null``)."""
    value = float(value)
    return None if np.isnan(value) else value


def plot_confusion(cm, out_path):
    """Save a heat-map of confusion matrix ``cm`` with per-cell counts to ``out_path``."""
    plt.figure(figsize=(6, 5))
    plt.imshow(cm)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()


def plot_roc(y, scores, out_path):
    """Save the ROC curve (with AUC in the legend) of ``scores`` against ``y``."""
    fpr, tpr, _ = roc_curve(y, scores)
    auc = roc_auc_score(y, scores)
    plt.figure(figsize=(7, 6))
    plt.plot(fpr, tpr, label=f"AUC={auc:.4f}")
    plt.plot([0, 1], [0, 1], "--", label="Chance")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.legend(loc="best")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()


def plot_pr(y, scores, out_path):
    """Save the precision-recall curve (with average precision in the legend)."""
    p, r, _ = precision_recall_curve(y, scores)
    ap = average_precision_score(y, scores)
    plt.figure(figsize=(7, 6))
    plt.plot(r, p, label=f"AP={ap:.4f}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="best")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()


def _normalization_stats(model):
    """Return ``(mean, std)`` from the model's ``default_cfg``, or 0.5 fallbacks."""
    # The model is not wrapped in DistributedDataParallel in this script, so
    # only unwrap ``.module`` if one is actually present.
    base_model = getattr(model, "module", model)
    default_cfg = getattr(base_model, "default_cfg", None) or {}
    mean = default_cfg.get('mean', _DEFAULT_MEAN_STD)
    std = default_cfg.get('std', _DEFAULT_MEAN_STD)
    return mean, std


def _load_weights(model, checkpoint_path):
    """Load ``checkpoint_path`` into ``model`` in place.

    A dict with a ``"model"`` key goes through ``utils.load_checkpoint``.
    Anything else is treated as a bare state dict and loaded non-strictly,
    with any missing/unexpected keys reported.

    Raises:
        RuntimeError: if no checkpoint path was given.
    """
    if not checkpoint_path:
        raise RuntimeError("Please pass --initial_checkpoint pointing to the model checkpoint.")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # many checkpoints come as {"model": state_dict, ...}
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        utils.load_checkpoint(model, checkpoint["model"])
        return

    state_dict = checkpoint
    # State dicts saved from a DDP-wrapped model carry a "module." prefix that
    # the bare model here would not match.
    prefix = "module."
    if isinstance(state_dict, dict) and any(
            isinstance(k, str) and k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if isinstance(k, str) and k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"[warn] checkpoint/model mismatch: {len(result.missing_keys)} missing keys, "
              f"{len(result.unexpected_keys)} unexpected keys")
        if result.missing_keys:
            print(f"[warn] first missing keys: {result.missing_keys[:10]}")
        if result.unexpected_keys:
            print(f"[warn] first unexpected keys: {result.unexpected_keys[:10]}")


def _build_val_loader(args, mean, std, val_list_name, filename_seperator, image_tmpl, filter_video):
    """Build the evaluation data loader over the dataset's validation list."""
    val_list = os.path.join(args.data_txt_dir, val_list_name)
    val_augmentor = get_augmentor(
        False, args.input_size, mean, std, args.disable_scaleup,
        threed_data=args.threed_data, version=args.augmentor_ver,
        scale_range=args.scale_range, num_clips=args.num_clips,
        num_crops=args.num_crops, dataset=args.dataset,
    )
    dataset_val = VideoDataSet(
        args.data_dir, val_list, args.duration, args.frames_per_group,
        num_clips=args.num_clips, modality=args.modality,
        dense_sampling=args.dense_sampling, image_tmpl=image_tmpl,
        transform=val_augmentor, is_train=False, test_mode=False,
        seperator=filename_seperator, filter_video=filter_video,
    )
    return build_dataflow(
        dataset_val, is_train=False, batch_size=args.batch_size,
        workers=args.num_workers, is_distributed=args.distributed,
    )


def _compute_metrics(y_true, y_score, y_pred):
    """Compute the binary metrics reported by this script (class 1 = positive)."""
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    return {
        "acc": float(accuracy_score(y_true, y_pred)),
        "balanced_acc": float(balanced_accuracy_score(y_true, y_pred)),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        # Ranking metrics are undefined with a single class: NaN instead of crashing.
        "roc_auc": _ranking_metric(roc_auc_score, y_true, y_score),
        "pr_auc": _ranking_metric(average_precision_score, y_true, y_score),
    }


def main(args):
    """Evaluate ``args.initial_checkpoint`` and write metrics (rank 0 only).

    Raises:
        RuntimeError: if no checkpoint is given or the loader yields no samples.
    """
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    (num_classes, train_list_name, val_list_name, test_list_name,
     filename_seperator, image_tmpl, filter_video, label_file) = \
        get_dataset_config(args.dataset, args.use_lmdb)
    args.num_classes = num_classes
    args.input_channels = 3 if args.modality == 'rgb' else 2 * 5

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        duration=args.duration,
        hpe_to_token=args.hpe_to_token,
        rel_pos=args.rel_pos,
        window_size=args.window_size,
        thumbnail_rows=args.thumbnail_rows,
        token_mask=not args.no_token_mask,
        online_learning=False,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        use_checkpoint=False,
    )
    model.to(device)

    mean, std = _normalization_stats(model)
    data_loader_val = _build_val_loader(
        args, mean, std, val_list_name, filename_seperator, image_tmpl, filter_video
    )
    _load_weights(model, args.initial_checkpoint)

    # eval
    y_true, y_score, y_pred = eval_with_outputs(
        data_loader_val, model, device, threshold=args.threshold
    )
    # In distributed mode each rank only saw its shard; merge before scoring.
    y_true, y_score, y_pred = _gather_across_ranks(y_true, y_score, y_pred)
    if utils.get_rank() != 0:
        return
    if y_true.size == 0:
        raise RuntimeError("The validation loader produced no samples; check --data_dir / --data_txt_dir.")

    metrics = _compute_metrics(y_true, y_score, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=_BINARY_LABELS)

    print(f"\nN={len(y_true)} | thr={args.threshold:.3f}")
    print(f"acc={metrics['acc']:.4f} | bacc={metrics['balanced_acc']:.4f} | "
          f"prec={metrics['precision']:.4f} | rec={metrics['recall']:.4f} | "
          f"f1={metrics['f1']:.4f} | roc_auc={metrics['roc_auc']:.4f} | "
          f"pr_auc={metrics['pr_auc']:.4f}")
    print(classification_report(y_true, y_pred, digits=4, zero_division=0))

    # A blank/whitespace --metrics_out falls back to output_dir (then cwd).
    outdir = (args.metrics_out or "").strip() or args.output_dir or "."
    os.makedirs(outdir, exist_ok=True)

    out_json = {"threshold": float(args.threshold)}
    out_json.update({key: _json_float(value) for key, value in metrics.items()})
    out_json["confusion_matrix"] = cm.tolist()
    out_json["n"] = int(len(y_true))
    with open(os.path.join(outdir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(out_json, f, indent=2)

    np.savez(os.path.join(outdir, "eval_outputs.npz"),
             y_true=y_true, y_score=y_score, y_pred=y_pred)

    if args.save_plots:
        plot_confusion(cm, os.path.join(outdir, "cm.png"))
        if np.unique(y_true).size >= 2:
            plot_roc(y_true, y_score, os.path.join(outdir, "roc.png"))
            plot_pr(y_true, y_score, os.path.join(outdir, "pr.png"))
        else:
            print("[warn] only one class present in y_true; skipping ROC/PR plots.")
        print(f"\n✔ Plots + metrics saved in: {os.path.abspath(outdir)}")
    else:
        print(f"\n✔ Metrics saved in: {os.path.abspath(os.path.join(outdir, 'metrics.json'))}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DeiT evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)