#!/usr/bin/env python3
"""Unified evaluation script for the 7 VFM baselines."""
from __future__ import annotations

import argparse
import json
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
from torch.utils.data import DataLoader, Dataset

from models import LOADERS, MODEL_SPECS, canonical_model_name, default_checkpoint_path, load_model

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".JPG", ".JPEG", ".PNG")


class BinaryFolderDataset(Dataset):
    """Real/fake image dataset built from two folders (label 0 = real, 1 = fake)."""

    def __init__(self, real_dir: str, fake_dir: str, transform, max_samples: int | None = None):
        self.transform = transform
        real_paths = self._get_image_files(real_dir)
        fake_paths = self._get_image_files(fake_dir)
        if max_samples is not None:
            real_paths = real_paths[:max_samples]
            fake_paths = fake_paths[:max_samples]
        self.image_paths = real_paths + fake_paths
        self.labels = [0] * len(real_paths) + [1] * len(fake_paths)

    @staticmethod
    def _get_image_files(folder: str):
        folder = Path(folder)
        images = []
        for extension in IMAGE_EXTENSIONS:
            images.extend(folder.rglob(f"*{extension}"))
        # De-duplicate in case the filesystem matches extensions case-insensitively.
        return sorted(set(images))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        image = Image.open(image_path).convert("RGB")
        return self.transform(image), self.labels[index], str(image_path)


def evaluate(
    model,
    transform,
    real_dir: str,
    fake_dir: str,
    batch_size: int,
    num_workers: int,
    max_samples: int | None,
):
    dataset = BinaryFolderDataset(real_dir, fake_dir, transform, max_samples=max_samples)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    device = next(model.parameters()).device

    y_true = []
    y_prob = []
    y_pred = []
    paths = []
    with torch.no_grad():
        for images, labels, batch_paths in dataloader:
            images = images.to(device)
            logits = model(images)
            # Probability of the "fake" class (index 1); threshold at 0.5 for hard predictions.
            probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
            preds = (probs > 0.5).astype(int)
            y_true.extend(labels.numpy().tolist())
            y_prob.extend(probs.tolist())
            y_pred.extend(preds.tolist())
            paths.extend(batch_paths)

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    y_pred = np.asarray(y_pred)

    metrics = {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "real_accuracy": float(accuracy_score(y_true[y_true == 0], y_pred[y_true == 0])),
        "fake_accuracy": float(accuracy_score(y_true[y_true == 1], y_pred[y_true == 1])),
    }
    # AUC and AP are only defined when both classes are present.
    if len(np.unique(y_true)) > 1:
        metrics["auc"] = float(roc_auc_score(y_true, y_prob))
        metrics["ap"] = float(average_precision_score(y_true, y_prob))

    samples = [
        {
            "path": path,
            "label": int(label),
            "prob_fake": float(prob),
            "pred": int(pred),
        }
        for path, label, prob, pred in zip(paths, y_true, y_prob, y_pred)
    ]
    return {"metrics": metrics, "samples": samples}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="all",
        help="One of: all, metacliplin, metaclip2lin, sigliplin, siglip2lin, pelin, dinov2lin, dinov3lin",
    )
    parser.add_argument("--real-dir", required=True)
    parser.add_argument("--fake-dir", required=True)
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="Optional explicit checkpoint path for single-model evaluation",
    )
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--device", default=None)
    parser.add_argument("--save-json", default=None)
    args = parser.parse_args()
    model_names = list(LOADERS.keys()) if args.model == "all" else [canonical_model_name(args.model)]

    results = {}
    for model_name in model_names:
        # An explicit --checkpoint only applies when a single model is evaluated.
        checkpoint = args.checkpoint if args.model != "all" and args.checkpoint else default_checkpoint_path(model_name)
        checkpoint = Path(checkpoint)
        # Report the checkpoint relative to this script's directory when possible.
        try:
            checkpoint_for_output = str(checkpoint.relative_to(Path(__file__).resolve().parent))
        except ValueError:
            checkpoint_for_output = str(checkpoint)

        model, transform = load_model(model_name, checkpoint_path=checkpoint, device=args.device)
        result = evaluate(
            model=model,
            transform=transform,
            real_dir=args.real_dir,
            fake_dir=args.fake_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            max_samples=args.max_samples,
        )
        results[model_name] = {
            "paper_name": MODEL_SPECS[model_name]["paper_name"],
            "checkpoint": checkpoint_for_output,
            **result,
        }
        # Free GPU memory before loading the next model.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    output = json.dumps(results, indent=2, ensure_ascii=False)
    print(output)
    if args.save_json:
        Path(args.save_json).write_text(output + "\n", encoding="utf-8")


if __name__ == "__main__":
    main()
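
# Example invocation (a sketch: the filename "evaluate.py" and the directory layout
# below are assumptions, not fixed by this repository):
#
#   python evaluate.py --model sigliplin \
#       --real-dir data/test/real --fake-dir data/test/fake \
#       --batch-size 16 --save-json results/sigliplin.json
#
# With --model all, every entry in LOADERS is evaluated against its default
# checkpoint, and the printed JSON maps each model name to its "paper_name",
# "checkpoint", "metrics", and per-image "samples".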