| """ |
| training/evaluate.py |
| --------------------- |
| Full System and Branch-Level Evaluation Script |
| STATUS: COMPLETE |
| |
| Usage: |
| cd ImageForensics-Detect/ |
| # Evaluate the full fusion system (all branches): |
| python training/evaluate.py |
| |
| # Evaluate a specific branch only: |
| python training/evaluate.py --branch spectral |
| python training/evaluate.py --branch edge |
| python training/evaluate.py --branch cnn |
| python training/evaluate.py --branch vit |
| python training/evaluate.py --branch diffusion |
| |
| Reports: |
| - Accuracy, Precision, Recall, F1-Score (per-class and macro) |
| - Confusion Matrix (console + saved PNG) |
| - ROC-AUC curve (saved PNG) |
| - Per-sample CSV export |
| |
| All output saved to: outputs/evaluation_<branch_or_full>.csv |
| """ |
|
|
| import sys |
| import json |
| import argparse |
| import numpy as np |
| from pathlib import Path |
| from tqdm import tqdm |
|
|
# Repository root (the parent of training/).  Inserted at the front of
# sys.path so the script can be launched directly as
# `python training/evaluate.py` and still resolve the project-level
# packages imported below (training.*, utils.*) and lazily inside
# predict_single (branches.*, fusion.*).
ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))


# Destination for every artifact this script writes (confusion-matrix /
# ROC PNGs and the per-sample CSV).  Created eagerly so later save calls
# cannot fail on a missing directory.
OUTPUTS_DIR = ROOT / "outputs"
OUTPUTS_DIR.mkdir(exist_ok=True)
|
|
| from training.dataset_loader import discover_dataset, split_dataset |
| from utils.image_utils import load_image_from_path |
|
|
|
|
| |
| |
| |
|
|
def predict_single(img: np.ndarray, branch: str) -> float:
    """Return the probability that *img* is AI-generated.

    Parameters
    ----------
    img : np.ndarray
        Decoded image array (as produced by ``load_image_from_path``).
    branch : str
        ``"spectral"``, ``"edge"``, ``"cnn"``, ``"vit"``, ``"diffusion"``
        to run a single branch, or ``"full"`` to run every branch and
        fuse the outputs.

    Returns
    -------
    float
        The ``prob_fake`` value reported by the selected branch (or by
        the fusion module in ``"full"`` mode).

    Raises
    ------
    ValueError
        If *branch* is not one of the recognised names.
    """
    # Branch modules are imported lazily so that evaluating one branch
    # never pays the import cost (model loading, etc.) of the others.
    if branch == "full":
        from branches.spectral_branch import run_spectral_branch
        from branches.edge_branch import run_edge_branch
        from branches.cnn_branch import run_cnn_branch
        from branches.vit_branch import run_vit_branch
        from branches.diffusion_branch import run_diffusion_branch
        from fusion.fusion import fuse_branches

        branch_outputs = {
            "spectral": run_spectral_branch(img),
            "edge": run_edge_branch(img),
            "cnn": run_cnn_branch(img),
            "vit": run_vit_branch(img),
            "diffusion": run_diffusion_branch(img),
        }
        return fuse_branches(branch_outputs)["prob_fake"]

    # Single-branch mode: bind the matching runner, then call it once.
    if branch == "spectral":
        from branches.spectral_branch import run_spectral_branch as runner
    elif branch == "edge":
        from branches.edge_branch import run_edge_branch as runner
    elif branch == "cnn":
        from branches.cnn_branch import run_cnn_branch as runner
    elif branch == "vit":
        from branches.vit_branch import run_vit_branch as runner
    elif branch == "diffusion":
        from branches.diffusion_branch import run_diffusion_branch as runner
    else:
        raise ValueError(f"Unknown branch: {branch}")
    return runner(img)["prob_fake"]
|
|
|
|
| |
| |
| |
|
|
def evaluate(branch: str = "full"):
    """Evaluate one branch (or the full fusion system) on the test split.

    Runs the selected predictor over every test image, then:
      - prints accuracy / precision / recall / F1 (macro) and ROC-AUC,
      - prints the per-class classification report and confusion matrix,
      - saves confusion-matrix and ROC-curve PNGs under ``outputs/``,
      - writes a per-sample CSV (path, true label, prob_fake, prediction).

    Parameters
    ----------
    branch : str
        One of ``"full"``, ``"spectral"``, ``"edge"``, ``"cnn"``,
        ``"vit"``, ``"diffusion"`` (default: ``"full"``).

    Returns
    -------
    dict
        Macro summary metrics: ``accuracy``, ``precision``, ``recall``,
        ``f1``, ``auc`` (``auc`` is NaN when it cannot be computed).
    """
    # Heavy dependencies are imported lazily so importing this module
    # stays cheap and does not require sklearn/matplotlib/seaborn.
    from sklearn.metrics import (
        accuracy_score, precision_score, recall_score, f1_score,
        confusion_matrix, roc_auc_score, classification_report
    )
    import matplotlib
    matplotlib.use("Agg")  # headless backend — we only save PNGs
    import matplotlib.pyplot as plt
    import seaborn as sns
    import csv

    print(f"\n{'='*60}")
    print(f" Evaluating: {branch.upper()} branch")
    print(f"{'='*60}")

    paths, labels = discover_dataset()
    splits = split_dataset(paths, labels)
    test_paths, test_labels = splits["test"]

    if len(test_paths) == 0:
        print("✗ No test images found.")
        sys.exit(1)

    print(f"Test set: {len(test_paths)} images "
          f"({test_labels.count(0)} real, {test_labels.count(1)} fake)\n")

    # kept_paths tracks which images actually produced a prediction, so
    # the CSV export below stays aligned even when some images are
    # skipped (BUG FIX: previously test_paths was zipped against the
    # filtered result lists, misaligning every row after a skip).
    probs, preds, gt, kept_paths = [], [], [], []

    for path, label in tqdm(zip(test_paths, test_labels), total=len(test_paths),
                            desc=f"Evaluating [{branch}]"):
        try:
            img = load_image_from_path(path)
            prob_fake = predict_single(img, branch)
        except Exception as e:
            # Best-effort: one unreadable/failing image must not abort
            # the whole evaluation run.
            print(f" ⚠ Skipped {path}: {e}")
            continue
        probs.append(prob_fake)
        preds.append(1 if prob_fake >= 0.5 else 0)  # 0.5 decision threshold
        gt.append(label)
        kept_paths.append(path)

    # If every image failed there is nothing to score — bail out instead
    # of crashing inside sklearn with an empty-input error.
    if not gt:
        print("✗ All test images failed to evaluate.")
        sys.exit(1)

    acc = accuracy_score(gt, preds)
    prec_mac = precision_score(gt, preds, average="macro", zero_division=0)
    rec_mac = recall_score(gt, preds, average="macro", zero_division=0)
    f1_mac = f1_score(gt, preds, average="macro", zero_division=0)
    prec_cls = precision_score(gt, preds, average=None, zero_division=0)
    rec_cls = recall_score(gt, preds, average=None, zero_division=0)
    f1_cls = f1_score(gt, preds, average=None, zero_division=0)

    # roc_auc_score raises ValueError when only one class is present in
    # gt; fall back to NaN so the rest of the report still prints.
    try:
        auc = roc_auc_score(gt, probs)
    except ValueError:
        auc = float("nan")

    cm = confusion_matrix(gt, preds)

    print(f"\n Accuracy : {acc:.4f}")
    print(f" Precision : {prec_mac:.4f} (macro)")
    print(f" Recall : {rec_mac:.4f} (macro)")
    print(f" F1-Score : {f1_mac:.4f} (macro)")
    print(f" ROC-AUC : {auc:.4f}")
    print(f"\n Classification Report:")
    print(classification_report(gt, preds, target_names=["Real", "AI-Generated"]))
    print(f"\n Confusion Matrix:\n {cm}")

    # --- Confusion-matrix heatmap ------------------------------------
    fig, ax = plt.subplots(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Real", "AI-Gen"],
                yticklabels=["Real", "AI-Gen"], ax=ax)
    ax.set_title(f"Confusion Matrix - {branch.upper()} Branch")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    plt.tight_layout()
    cm_path = OUTPUTS_DIR / f"confusion_matrix_{branch}.png"
    fig.savefig(cm_path, dpi=150)
    plt.close(fig)  # close the specific figure, not just "current"
    print(f"\n✓ Confusion matrix saved → {cm_path}")

    # --- ROC curve (only when AUC was computable) --------------------
    if not np.isnan(auc):
        from sklearn.metrics import roc_curve
        fpr, tpr, _ = roc_curve(gt, probs)
        fig2, ax2 = plt.subplots(figsize=(5, 4))
        ax2.plot(fpr, tpr, label=f"AUC = {auc:.4f}")
        ax2.plot([0, 1], [0, 1], "k--")  # chance diagonal
        ax2.set_xlabel("False Positive Rate")
        ax2.set_ylabel("True Positive Rate")
        ax2.set_title(f"ROC Curve - {branch.upper()} Branch")
        ax2.legend()
        plt.tight_layout()
        roc_path = OUTPUTS_DIR / f"roc_curve_{branch}.png"
        fig2.savefig(roc_path, dpi=150)
        plt.close(fig2)
        print(f"✓ ROC curve saved → {roc_path}")

    # --- Per-sample CSV ----------------------------------------------
    csv_path = OUTPUTS_DIR / f"evaluation_{branch}.csv"
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["path", "true_label", "prob_fake", "predicted"])
        # Zip over kept_paths (not test_paths): skipped images are
        # excluded from gt/probs/preds, and using test_paths here would
        # pair every subsequent path with the wrong prediction.
        for p, true_label, prob, pred in zip(kept_paths, gt, probs, preds):
            writer.writerow([p, true_label, round(prob, 4), pred])
    print(f"✓ Per-sample results saved → {csv_path}")

    return {
        "accuracy": acc, "precision": prec_mac,
        "recall": rec_mac, "f1": f1_mac, "auc": auc,
    }
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: pick a branch (default = full fusion) and run it.
    cli = argparse.ArgumentParser(description="Evaluate ImageForensics-Detect")
    cli.add_argument(
        "--branch",
        type=str,
        default="full",
        choices=["full", "spectral", "edge", "cnn", "vit", "diffusion"],
        help="Which branch to evaluate (default: full fusion)",
    )
    opts = cli.parse_args()
    evaluate(branch=opts.branch)
|
|