""" training/evaluate.py --------------------- Full System and Branch-Level Evaluation Script STATUS: COMPLETE Usage: cd ImageForensics-Detect/ # Evaluate the full fusion system (all branches): python training/evaluate.py # Evaluate a specific branch only: python training/evaluate.py --branch spectral python training/evaluate.py --branch edge python training/evaluate.py --branch cnn python training/evaluate.py --branch vit python training/evaluate.py --branch diffusion Reports: - Accuracy, Precision, Recall, F1-Score (per-class and macro) - Confusion Matrix (console + saved PNG) - ROC-AUC curve (saved PNG) - Per-sample CSV export All output saved to: outputs/evaluation_.csv """ import sys import json import argparse import numpy as np from pathlib import Path from tqdm import tqdm ROOT = Path(__file__).parent.parent sys.path.insert(0, str(ROOT)) OUTPUTS_DIR = ROOT / "outputs" OUTPUTS_DIR.mkdir(exist_ok=True) from training.dataset_loader import discover_dataset, split_dataset from utils.image_utils import load_image_from_path # ───────────────────────────────────────────────────────────────── # Branch Evaluators # ───────────────────────────────────────────────────────────────── def predict_single(img: np.ndarray, branch: str) -> float: """Run a single branch and return prob_fake.""" if branch == "spectral": from branches.spectral_branch import run_spectral_branch return run_spectral_branch(img)["prob_fake"] elif branch == "edge": from branches.edge_branch import run_edge_branch return run_edge_branch(img)["prob_fake"] elif branch == "diffusion": from branches.diffusion_branch import run_diffusion_branch return run_diffusion_branch(img)["prob_fake"] elif branch == "cnn": from branches.cnn_branch import run_cnn_branch return run_cnn_branch(img)["prob_fake"] elif branch == "vit": from branches.vit_branch import run_vit_branch return run_vit_branch(img)["prob_fake"] elif branch == "full": from branches.spectral_branch import run_spectral_branch from branches.edge_branch import run_edge_branch from branches.cnn_branch import run_cnn_branch from branches.vit_branch import run_vit_branch from branches.diffusion_branch import run_diffusion_branch from fusion.fusion import fuse_branches outs = { "spectral": run_spectral_branch(img), "edge": run_edge_branch(img), "cnn": run_cnn_branch(img), "vit": run_vit_branch(img), "diffusion": run_diffusion_branch(img), } return fuse_branches(outs)["prob_fake"] else: raise ValueError(f"Unknown branch: {branch}") # ───────────────────────────────────────────────────────────────── # Main Evaluation # ───────────────────────────────────────────────────────────────── def evaluate(branch: str = "full"): from sklearn.metrics import ( accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score, classification_report ) import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import seaborn as sns import csv print(f"\n{'='*60}") print(f" Evaluating: {branch.upper()} branch") print(f"{'='*60}") paths, labels = discover_dataset() splits = split_dataset(paths, labels) test_paths, test_labels = splits["test"] if len(test_paths) == 0: print("❌ No test images found.") sys.exit(1) print(f"Test set: {len(test_paths)} images " f"({test_labels.count(0)} real, {test_labels.count(1)} fake)\n") probs, preds, gt = [], [], [] for path, label in tqdm(zip(test_paths, test_labels), total=len(test_paths), desc=f"Evaluating [{branch}]"): try: img = load_image_from_path(path) prob_fake = predict_single(img, branch) probs.append(prob_fake) preds.append(1 if prob_fake >= 0.5 else 0) gt.append(label) except Exception as e: print(f" ⚠ Skipped {path}: {e}") # ── Metrics ────────────────────────────────────────────────── acc = accuracy_score(gt, preds) prec_mac = precision_score(gt, preds, average="macro", zero_division=0) rec_mac = recall_score(gt, preds, average="macro", zero_division=0) f1_mac = f1_score(gt, preds, average="macro", zero_division=0) prec_cls = precision_score(gt, preds, average=None, zero_division=0) rec_cls = recall_score(gt, preds, average=None, zero_division=0) f1_cls = f1_score(gt, preds, average=None, zero_division=0) try: auc = roc_auc_score(gt, probs) except Exception: auc = float("nan") cm = confusion_matrix(gt, preds) print(f"\n Accuracy : {acc:.4f}") print(f" Precision : {prec_mac:.4f} (macro)") print(f" Recall : {rec_mac:.4f} (macro)") print(f" F1-Score : {f1_mac:.4f} (macro)") print(f" ROC-AUC : {auc:.4f}") print(f"\n Classification Report:") print(classification_report(gt, preds, target_names=["Real", "AI-Generated"])) print(f"\n Confusion Matrix:\n {cm}") # ── Save Confusion Matrix Plot ──────────────────────────────── fig, ax = plt.subplots(figsize=(5, 4)) sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=["Real", "AI-Gen"], yticklabels=["Real", "AI-Gen"], ax=ax) ax.set_title(f"Confusion Matrix — {branch.upper()} Branch") ax.set_xlabel("Predicted") ax.set_ylabel("Actual") plt.tight_layout() cm_path = OUTPUTS_DIR / f"confusion_matrix_{branch}.png" fig.savefig(cm_path, dpi=150) plt.close() print(f"\n✓ Confusion matrix saved → {cm_path}") # ── Save ROC-AUC Curve ──────────────────────────────────────── if not np.isnan(auc): from sklearn.metrics import roc_curve fpr, tpr, _ = roc_curve(gt, probs) fig2, ax2 = plt.subplots(figsize=(5, 4)) ax2.plot(fpr, tpr, label=f"AUC = {auc:.4f}") ax2.plot([0, 1], [0, 1], "k--") ax2.set_xlabel("False Positive Rate") ax2.set_ylabel("True Positive Rate") ax2.set_title(f"ROC Curve — {branch.upper()} Branch") ax2.legend() plt.tight_layout() roc_path = OUTPUTS_DIR / f"roc_curve_{branch}.png" fig2.savefig(roc_path, dpi=150) plt.close() print(f"✓ ROC curve saved → {roc_path}") # ── Export CSV ──────────────────────────────────────────────── csv_path = OUTPUTS_DIR / f"evaluation_{branch}.csv" with open(csv_path, "w", newline="") as f: writer = csv.writer(f) writer.writerow(["path", "true_label", "prob_fake", "predicted"]) for p, l, pr, pd in zip(test_paths, gt, probs, preds): writer.writerow([p, l, round(pr, 4), pd]) print(f"✓ Per-sample results saved → {csv_path}") return { "accuracy": acc, "precision": prec_mac, "recall": rec_mac, "f1": f1_mac, "auc": auc, } if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluate ImageForensics-Detect") parser.add_argument( "--branch", type=str, default="full", choices=["full", "spectral", "edge", "cnn", "vit", "diffusion"], help="Which branch to evaluate (default: full fusion)" ) args = parser.parse_args() evaluate(branch=args.branch)