# DeepFakeDetector-demo / test_new.py
# Last change by guard2PFE: "Update test_new.py" (commit 9545ee0, verified).
import argparse
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import os
import warnings
import json
from pathlib import Path
from timm.models import create_model
import my_models # registers TALL_SWIN
import utils
from video_dataset import VideoDataSet
from video_dataset_aug import get_augmentor, build_dataflow
from video_dataset_config import get_dataset_config, DATASET_CONFIG
from sklearn.metrics import (
accuracy_score, balanced_accuracy_score,
precision_recall_fscore_support,
confusion_matrix, classification_report,
roc_auc_score, roc_curve,
average_precision_score, precision_recall_curve
)
import matplotlib.pyplot as plt
# Hide every UserWarning process-wide so the evaluation log stays readable.
warnings.simplefilter("ignore", category=UserWarning)
def str2bool(value):
    """Convert a command-line boolean spelling to a real ``bool``.

    ``argparse`` keeps the raw string when no ``type`` is given, so an option
    like ``--dense_sampling False`` would otherwise yield the truthy string
    ``"False"``. This converter maps the usual spellings to booleans.

    Args:
        value: Raw option text. A ``bool`` is returned unchanged.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised spelling.
            argparse reports this as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the option parser shared by the evaluation entry point.

    The parser is created with ``add_help=False`` so that it can be used as a
    ``parents=[...]`` component of another ``ArgumentParser``.

    Returns:
        argparse.ArgumentParser: Parser holding every evaluation option.
    """
    parser = argparse.ArgumentParser('DeiT evaluation script', add_help=False)
    parser.add_argument('--model', default='TALL_SWIN', type=str)
    parser.add_argument('--model_name', default="TALL_SWIN")
    parser.add_argument('--batch-size', default=2, type=int)

    # Dataset parameters
    parser.add_argument('--data_txt_dir', type=str, default='##path_for_dataset_txt##')
    parser.add_argument('--data_dir', type=str, default="##path_for_dataset##")
    parser.add_argument('--dataset', default='ffpp', choices=list(DATASET_CONFIG.keys()))
    parser.add_argument('--duration', default=1, type=int)
    parser.add_argument('--frames_per_group', default=1, type=int)
    # The value-taking boolean options below go through str2bool. Without it,
    # any value passed on the command line (even "False") would be a
    # non-empty, truthy string.
    parser.add_argument('--threed_data', default=False, type=str2bool,
                        help='true/false')
    parser.add_argument('--input_size', default=224, type=int)
    parser.add_argument('--disable_scaleup', action='store_true')
    parser.add_argument('--random_sampling', action='store_true')
    parser.add_argument('--dense_sampling', default=True, type=str2bool,
                        help='true/false')
    parser.add_argument('--augmentor_ver', default='v1', type=str, choices=['v1', 'v2'])
    parser.add_argument('--scale_range', default=[256, 320], type=int, nargs="+")
    parser.add_argument('--modality', default='rgb', type=str)
    parser.add_argument('--use_lmdb', default=False, type=str2bool,
                        help='true/false')
    parser.add_argument('--use_pyav', default=False, type=str2bool,
                        help='true/false')

    # Temporal module / model params
    parser.add_argument('--pretrained', action='store_true', default=False)
    parser.add_argument('--temporal_module_name', default=None, type=str,
                        choices=['ResNet3d', 'TAM', 'TTAM', 'TSM', 'TTSM', 'MSA'])
    parser.add_argument('--temporal_attention_only', action='store_true', default=False)
    parser.add_argument('--no_token_mask', action='store_true', default=False)
    parser.add_argument('--temporal_heads_scale', default=1.0, type=float)
    parser.add_argument('--temporal_mlp_scale', default=1.0, type=float)
    parser.add_argument('--rel_pos', action='store_true', default=False)
    parser.add_argument('--temporal_pooling', type=str, default=None,
                        choices=['avg', 'max', 'conv', 'depthconv'])
    parser.add_argument('--bottleneck', default=None, choices=['regular', 'dw'])
    parser.add_argument('--window_size', default=7, type=int)
    parser.add_argument('--thumbnail_rows', default=3, type=int)
    parser.add_argument('--hpe_to_token', default=False, action='store_true')
    parser.add_argument('--drop', type=float, default=0.0)
    parser.add_argument('--drop-path', type=float, default=0.1)
    parser.add_argument('--drop-block', type=float, default=None)

    # Runtime
    parser.add_argument('--output_dir', default="./output")
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--num_crops', default=1, type=int, choices=[1, 3, 5, 10])
    parser.add_argument('--num_clips', default=3, type=int)
    parser.add_argument('--world_size', default=1, type=int)
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--dist_url', default='env://')

    # Checkpoint / evaluation outputs
    parser.add_argument('--initial_checkpoint', type=str, default='',
                        help='path to .pth/.pth.tar checkpoint (expects key "model")')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='threshold to decide class 1 (fake) from prob[:,1]')
    parser.add_argument('--metrics_out', default='', type=str,
                        help='folder to save metrics.json and plots (default: output_dir)')
    parser.add_argument('--save_plots', action='store_true',
                        help='save cm.png / roc.png / pr.png')
    return parser
@torch.no_grad()
def eval_with_outputs(data_loader, model, device, threshold: float = 0.5):
    """Run *model* over *data_loader* and collect binary labels, scores and predictions.

    The model is put in eval mode and run without gradient tracking. Class
    index 1 is the positive ("fake") class. Its softmax probability is the
    score, and a sample is predicted positive when that score is
    ``>= threshold``.

    If the model returns more logit rows than there are targets (for example
    one row per clip), the rows are assumed to be grouped contiguously per
    sample (``[B*K, C] -> [B, K, C]``). The logits are averaged over ``K``
    before the softmax.

    Args:
        data_loader: Iterable of ``(samples, targets)`` tensor pairs.
        model: Module returning logits with at least two columns.
        device: Device each batch is moved to.
        threshold: Decision threshold applied to ``P(class 1)``.

    Returns:
        tuple: ``(y_true, y_score, y_pred)`` as 1-D numpy arrays with dtypes
        int / float / int. All three are empty if the loader yields no samples.

    Raises:
        RuntimeError: If the number of logit rows is not a multiple of the
            batch size.
    """
    model.eval()
    thr = float(threshold)
    true_chunks, score_chunks, pred_chunks = [], [], []

    for samples, targets in data_loader:
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        logits = model(samples)  # [B, C] or [B*K, C]

        batch = targets.shape[0]
        if batch == 0:
            # Nothing to score. This also avoids a modulo-by-zero below.
            continue
        rows = logits.shape[0]
        if rows != batch:
            if rows % batch != 0:
                raise RuntimeError(
                    f"logits batch ({rows}) is not a multiple of target batch ({batch})."
                )
            # reshape (unlike view) also accepts non-contiguous model outputs.
            logits = logits.reshape(batch, rows // batch, -1).mean(dim=1)

        # Softmax in fp32. This keeps half-precision scores accurate, and numpy
        # cannot convert bfloat16 tensors at all.
        p1 = torch.softmax(logits.float(), dim=1)[:, 1]
        true_chunks.append(targets.cpu().numpy())
        score_chunks.append(p1.cpu().numpy())
        pred_chunks.append((p1 >= thr).long().cpu().numpy())

    def _join(chunks, dtype):
        # np.concatenate([]) raises, so an empty loader gets an empty array.
        return np.concatenate(chunks).astype(dtype) if chunks else np.empty(0, dtype=dtype)

    return _join(true_chunks, int), _join(score_chunks, float), _join(pred_chunks, int)
def plot_confusion(cm, out_path, labels=None):
    """Save a heat-map of the confusion matrix *cm* to *out_path*.

    Rows are true classes and columns are predicted classes, following the
    layout of sklearn's ``confusion_matrix``. Each cell is annotated with its
    count. The text is black on light cells and white on dark cells so it
    stays readable whatever the colormap.

    Args:
        cm: 2-D array-like of counts.
        out_path: Destination image path. matplotlib infers the format from
            the extension.
        labels: Optional tick labels, one per class. Defaults to the integer
            class indices.
    """
    cm = np.asarray(cm)
    n_rows, n_cols = cm.shape
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        image = ax.imshow(cm)
        ax.set_title("Confusion Matrix")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        # Use integer ticks. imshow's defaults include fractional positions,
        # which are meaningless for class indices.
        ax.set_xticks(range(n_cols))
        ax.set_yticks(range(n_rows))
        if labels is not None:
            ax.set_xticklabels(list(labels))
            ax.set_yticklabels(list(labels))
        for (i, j), v in np.ndenumerate(cm):
            r, g, b, _ = image.cmap(image.norm(v))
            luminance = 0.299 * r + 0.587 * g + 0.114 * b
            ax.text(j, i, str(v), ha="center", va="center",
                    color="black" if luminance > 0.5 else "white")
        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
def plot_roc(y, scores, out_path):
    """Save a ROC curve for binary labels *y* and positive-class *scores*.

    The legend reports the ROC-AUC. A dashed diagonal marks chance level.
    The figure is written to *out_path* at 200 dpi and then closed.
    """
    false_pos, true_pos, _unused_thresholds = roc_curve(y, scores)
    area = roc_auc_score(y, scores)

    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(false_pos, true_pos, label=f"AUC={area:.4f}")
    ax.plot([0, 1], [0, 1], "--", label="Chance")
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
def plot_pr(y, scores, out_path):
    """Save a precision-recall curve for binary labels *y* and positive-class *scores*.

    The legend reports sklearn's average precision (AP). The figure is
    written to *out_path* at 200 dpi and then closed.
    """
    precision, recall, _unused_thresholds = precision_recall_curve(y, scores)
    ap_value = average_precision_score(y, scores)

    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(recall, precision, label=f"AP={ap_value:.4f}")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
def _norm_stats(model):
    """Return ``(mean, std)`` from the model's ``default_cfg``, falling back to 0.5 per channel."""
    # Unwrap only if a wrapper is actually present. This script never wraps the
    # model (e.g. in DDP), so ``model.module`` does not exist even when
    # args.distributed is True.
    base = getattr(model, "module", model)
    cfg = getattr(base, "default_cfg", None) or {}
    mean = cfg["mean"] if "mean" in cfg else (0.5, 0.5, 0.5)
    std = cfg["std"] if "std" in cfg else (0.5, 0.5, 0.5)
    return mean, std


def _load_weights(model, ckpt_path):
    """Load the checkpoint at *ckpt_path* into *model* in place, reporting non-strict mismatches."""
    # NOTE: torch.load unpickles the file. Only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    # Many checkpoints are {"model": state_dict, ...}.
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        utils.load_checkpoint(model, checkpoint["model"])
        return
    # Otherwise treat it as a bare state_dict. A non-strict load silently
    # ignores mismatched keys, so surface them; evaluating a partially
    # initialised model is otherwise invisible.
    result = model.load_state_dict(checkpoint, strict=False)
    missing = list(getattr(result, "missing_keys", []))
    unexpected = list(getattr(result, "unexpected_keys", []))
    if missing or unexpected:
        print(f"WARNING: non-strict checkpoint load: {len(missing)} missing key(s), "
              f"{len(unexpected)} unexpected key(s).")
        if missing:
            print("  missing (first 10):", missing[:10])
        if unexpected:
            print("  unexpected (first 10):", unexpected[:10])


def _binary_metrics(y_true, y_score, y_pred):
    """Compute threshold-based and ranking metrics for labels in {0, 1}.

    ROC-AUC and average precision need both classes in *y_true*. When only one
    class is present they are returned as NaN instead of raising. The confusion
    matrix is always 2x2 (``labels=[0, 1]``).
    """
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    both_classes = np.unique(y_true).size > 1
    return {
        "acc": accuracy_score(y_true, y_pred),
        "balanced_acc": balanced_accuracy_score(y_true, y_pred),
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "roc_auc": roc_auc_score(y_true, y_score) if both_classes else float("nan"),
        "pr_auc": average_precision_score(y_true, y_score) if both_classes else float("nan"),
        "confusion_matrix": confusion_matrix(y_true, y_pred, labels=[0, 1]),
    }


def _json_float(value):
    """Convert *value* to a JSON-safe float, mapping NaN to ``None`` (JSON ``null``)."""
    value = float(value)
    return None if np.isnan(value) else value


def main(args):
    """Evaluate ``--initial_checkpoint`` on the dataset's validation list.

    The steps are:
      1. Build the model.
      2. Build the validation dataloader.
      3. Load the weights.
      4. Run inference with ``--threshold``.
      5. Print the metrics and a classification report.
      6. Write ``metrics.json``, ``eval_outputs.npz`` and, with
         ``--save_plots``, ``cm.png`` / ``roc.png`` / ``pr.png`` into
         ``--metrics_out`` (or ``--output_dir`` when that is empty).

    Raises:
        RuntimeError: If ``--initial_checkpoint`` is empty, or if no
            validation samples are evaluated.
    """
    utils.init_distributed_mode(args)
    print(args)
    # Fail fast: building the model and data pipeline is pointless without weights.
    if not args.initial_checkpoint:
        raise RuntimeError("Please pass --initial_checkpoint pointing to the model checkpoint.")

    device = torch.device(args.device)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    (num_classes, _train_list_name, val_list_name, _test_list_name,
     filename_seperator, image_tmpl, filter_video, _label_file) = \
        get_dataset_config(args.dataset, args.use_lmdb)
    args.num_classes = num_classes
    args.input_channels = 3 if args.modality == 'rgb' else 2 * 5

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        duration=args.duration,
        hpe_to_token=args.hpe_to_token,
        rel_pos=args.rel_pos,
        window_size=args.window_size,
        thumbnail_rows=args.thumbnail_rows,
        token_mask=not args.no_token_mask,
        online_learning=False,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        use_checkpoint=False
    )
    model.to(device)
    mean, std = _norm_stats(model)

    # Dataset: validation list.
    val_list = os.path.join(args.data_txt_dir, val_list_name)
    val_augmentor = get_augmentor(
        False, args.input_size, mean, std, args.disable_scaleup,
        threed_data=args.threed_data, version=args.augmentor_ver,
        scale_range=args.scale_range, num_clips=args.num_clips,
        num_crops=args.num_crops, dataset=args.dataset
    )
    dataset_val = VideoDataSet(
        args.data_dir, val_list,
        args.duration, args.frames_per_group,
        num_clips=args.num_clips,
        modality=args.modality,
        dense_sampling=args.dense_sampling,
        image_tmpl=image_tmpl,
        transform=val_augmentor,
        is_train=False, test_mode=False,
        seperator=filename_seperator, filter_video=filter_video
    )
    data_loader_val = build_dataflow(
        dataset_val, is_train=False, batch_size=args.batch_size,
        workers=args.num_workers, is_distributed=args.distributed
    )

    _load_weights(model, args.initial_checkpoint)

    y_true, y_score, y_pred = eval_with_outputs(
        data_loader_val, model, device, threshold=args.threshold
    )
    if len(y_true) == 0:
        raise RuntimeError(f"No validation samples were evaluated (list: {val_list}).")

    metrics = _binary_metrics(y_true, y_score, y_pred)
    cm = metrics["confusion_matrix"]
    print(f"\nN={len(y_true)} | thr={args.threshold:.3f}")
    print(f"acc={metrics['acc']:.4f} | bacc={metrics['balanced_acc']:.4f} | "
          f"prec={metrics['precision']:.4f} | rec={metrics['recall']:.4f} | "
          f"f1={metrics['f1']:.4f} | roc_auc={metrics['roc_auc']:.4f} | "
          f"pr_auc={metrics['pr_auc']:.4f}")
    print(classification_report(y_true, y_pred, labels=[0, 1], digits=4, zero_division=0))

    # NOTE(review): with is_distributed=True, build_dataflow presumably shards
    # the data per rank and nothing here gathers results, so these metrics
    # would be per-rank. Only rank 0 writes files, so ranks do not overwrite
    # each other's output.
    if utils.get_rank() != 0:
        return

    # A blank or whitespace-only --metrics_out falls back to --output_dir.
    outdir = (args.metrics_out or "").strip() or args.output_dir or "."
    os.makedirs(outdir, exist_ok=True)
    out_json = {
        "threshold": float(args.threshold),
        "acc": _json_float(metrics["acc"]),
        "balanced_acc": _json_float(metrics["balanced_acc"]),
        "precision": _json_float(metrics["precision"]),
        "recall": _json_float(metrics["recall"]),
        "f1": _json_float(metrics["f1"]),
        "roc_auc": _json_float(metrics["roc_auc"]),
        "pr_auc": _json_float(metrics["pr_auc"]),
        "confusion_matrix": cm.tolist(),
        "n": int(len(y_true)),
    }
    with open(os.path.join(outdir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(out_json, f, indent=2)
    np.savez(os.path.join(outdir, "eval_outputs.npz"),
             y_true=y_true, y_score=y_score, y_pred=y_pred)

    if args.save_plots:
        plot_confusion(cm, os.path.join(outdir, "cm.png"))
        # ROC/PR curves are undefined when y_true holds a single class.
        if np.unique(y_true).size > 1:
            plot_roc(y_true, y_score, os.path.join(outdir, "roc.png"))
            plot_pr(y_true, y_score, os.path.join(outdir, "pr.png"))
        else:
            print("Skipping ROC/PR plots: y_true contains a single class.")
        print(f"\n✔ Plots + metrics saved in: {os.path.abspath(outdir)}")
    else:
        print(f"\n✔ Metrics saved in: {os.path.abspath(os.path.join(outdir, 'metrics.json'))}")
if __name__ == '__main__':
    # get_args_parser() is built with add_help=False, so it is attached as a
    # parent of a fresh parser that owns -h/--help.
    args = argparse.ArgumentParser(
        'DeiT evaluation script', parents=[get_args_parser()]
    ).parse_args()
    # Create the output directory up front, unless --output_dir is empty.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)