dk2430098's picture
Upload folder using huggingface_hub
928b74f verified
"""
training/evaluate.py
---------------------
Full System and Branch-Level Evaluation Script
STATUS: COMPLETE
Usage:
cd ImageForensics-Detect/
# Evaluate the full fusion system (all branches):
python training/evaluate.py
# Evaluate a specific branch only:
python training/evaluate.py --branch spectral
python training/evaluate.py --branch edge
python training/evaluate.py --branch cnn
python training/evaluate.py --branch vit
python training/evaluate.py --branch diffusion
Reports:
- Accuracy, Precision, Recall, F1-Score (per-class and macro)
- Confusion Matrix (console + saved PNG)
- ROC-AUC curve (saved PNG)
- Per-sample CSV export
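All outputs are saved under outputs/: evaluation_<branch>.csv, confusion_matrix_<branch>.png and roc_curve_<branch>.png (<branch> is "full" when evaluating the fusion system).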
"""
import sys
import json
import argparse
import numpy as np
from pathlib import Path
from tqdm import tqdm
# Repository root: this script lives one directory below it (training/).
ROOT = Path(__file__).parent.parent
# Every evaluation artefact (PNG plots, per-sample CSV) is written here;
# the directory is created at import time if it does not exist yet.
OUTPUTS_DIR = ROOT.joinpath("outputs")
OUTPUTS_DIR.mkdir(exist_ok=True)
# Put the repository root first on the import path so the project-local
# packages (training, utils, branches, fusion) can be imported.
sys.path.insert(0, str(ROOT))
from training.dataset_loader import discover_dataset, split_dataset
from utils.image_utils import load_image_from_path
# ─────────────────────────────────────────────────────────────────
# Branch Evaluators
# ─────────────────────────────────────────────────────────────────
def predict_single(img: np.ndarray, branch: str) -> float:
    """Score one image with a single detector branch or the fused system.

    Args:
        img: Image passed unchanged to the branch runner(s); in this script
            it is the value returned by ``load_image_from_path``.
        branch: One of ``"spectral"``, ``"edge"``, ``"diffusion"``,
            ``"cnn"``, ``"vit"``, or ``"full"`` (all five branches run and
            combined with ``fuse_branches``).

    Returns:
        The ``"prob_fake"`` entry of the branch result (or of the fused
        result for ``"full"``).

    Raises:
        ValueError: If ``branch`` is not one of the names above.
    """
    # Branch modules are imported lazily, so scoring a single branch only
    # imports that branch's module.
    if branch == "full":
        from branches.spectral_branch import run_spectral_branch
        from branches.edge_branch import run_edge_branch
        from branches.cnn_branch import run_cnn_branch
        from branches.vit_branch import run_vit_branch
        from branches.diffusion_branch import run_diffusion_branch
        from fusion.fusion import fuse_branches

        # Ordered table: branches run (and are keyed) in exactly this order.
        runners = (
            ("spectral", run_spectral_branch),
            ("edge", run_edge_branch),
            ("cnn", run_cnn_branch),
            ("vit", run_vit_branch),
            ("diffusion", run_diffusion_branch),
        )
        fused = fuse_branches({name: run(img) for name, run in runners})
        return fused["prob_fake"]

    if branch == "spectral":
        from branches.spectral_branch import run_spectral_branch as run
    elif branch == "edge":
        from branches.edge_branch import run_edge_branch as run
    elif branch == "diffusion":
        from branches.diffusion_branch import run_diffusion_branch as run
    elif branch == "cnn":
        from branches.cnn_branch import run_cnn_branch as run
    elif branch == "vit":
        from branches.vit_branch import run_vit_branch as run
    else:
        raise ValueError(f"Unknown branch: {branch}")
    return run(img)["prob_fake"]
# ─────────────────────────────────────────────────────────────────
# Main Evaluation
# ─────────────────────────────────────────────────────────────────
def evaluate(branch: str = "full"):
    """Evaluate one detector branch, or the full fusion system, on the test split.

    The dataset is discovered with ``discover_dataset`` and split with
    ``split_dataset``; only the ``"test"`` split is scored. Every test image
    is loaded with ``load_image_from_path`` and scored with
    ``predict_single``; a sample is predicted fake when ``prob_fake >= 0.5``.

    Images whose loading or scoring raises are skipped with a warning and
    are excluded from every metric and from the CSV export.

    Console output: accuracy, macro precision/recall/F1, ROC-AUC, the
    per-class classification report and the confusion matrix.

    Files written to ``OUTPUTS_DIR``:
        - ``confusion_matrix_<branch>.png``
        - ``roc_curve_<branch>.png`` (only when ROC-AUC is defined)
        - ``evaluation_<branch>.csv`` with columns
          ``path, true_label, prob_fake, predicted`` (one row per scored image)

    Args:
        branch: ``"full"`` for the fused system, or one of ``"spectral"``,
            ``"edge"``, ``"cnn"``, ``"vit"``, ``"diffusion"``.

    Returns:
        dict with keys ``"accuracy"``, ``"precision"``, ``"recall"``,
        ``"f1"`` (macro averages) and ``"auc"`` (NaN when ROC-AUC is
        undefined, e.g. the scored ground truth contains a single class).

    Exits the process with status 1 when the test split is empty or when
    no test image could be scored.
    """
    from sklearn.metrics import (
        accuracy_score, precision_score, recall_score, f1_score,
        confusion_matrix, roc_auc_score, roc_curve, classification_report,
    )
    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend: figures are only saved to disk
    import matplotlib.pyplot as plt
    import seaborn as sns
    import csv

    # Label encoding used throughout: 0 = real, 1 = AI-generated. Passing it
    # explicitly keeps the confusion matrix 2x2 and the report aligned with
    # its two target names even when one class is absent from the results.
    class_ids = [0, 1]

    print(f"\n{'='*60}")
    print(f" Evaluating: {branch.upper()} branch")
    print(f"{'='*60}")

    paths, labels = discover_dataset()
    splits = split_dataset(paths, labels)
    test_paths, test_labels = splits["test"]
    if len(test_paths) == 0:
        print("❌ No test images found.")
        sys.exit(1)
    # Counted with a generator so this works for lists and NumPy arrays alike.
    n_real = sum(1 for y in test_labels if y == 0)
    n_fake = sum(1 for y in test_labels if y == 1)
    print(f"Test set: {len(test_paths)} images "
          f"({n_real} real, {n_fake} fake)\n")

    # Parallel per-sample lists. A sample is appended to all four only after it
    # was scored successfully, so they always stay aligned with each other
    # (test_paths does not, as soon as any image is skipped).
    scored_paths, probs, preds, gt = [], [], [], []
    n_skipped = 0
    for path, label in tqdm(zip(test_paths, test_labels), total=len(test_paths),
                            desc=f"Evaluating [{branch}]"):
        try:
            img = load_image_from_path(path)
            prob_fake = float(predict_single(img, branch))
        except Exception as e:
            # Deliberately best-effort: one unreadable image must not abort the run.
            n_skipped += 1
            tqdm.write(f" ⚠ Skipped {path}: {e}")  # does not break the progress bar
            continue
        scored_paths.append(path)
        probs.append(prob_fake)
        preds.append(1 if prob_fake >= 0.5 else 0)
        gt.append(label)

    if not gt:
        print("❌ No test images could be evaluated (every image was skipped).")
        sys.exit(1)
    if n_skipped:
        print(f"\n ⚠ {n_skipped}/{len(test_paths)} test images skipped; "
              f"metrics cover the remaining {len(gt)}.")

    # ── Metrics ──────────────────────────────────────────────────
    acc = accuracy_score(gt, preds)
    prec_mac = precision_score(gt, preds, average="macro", zero_division=0)
    rec_mac = recall_score(gt, preds, average="macro", zero_division=0)
    f1_mac = f1_score(gt, preds, average="macro", zero_division=0)
    try:
        auc = roc_auc_score(gt, probs)
    except ValueError:
        # Raised e.g. when the scored ground truth contains only one class.
        auc = float("nan")
    # Rows = actual, columns = predicted, both ordered as class_ids.
    cm = confusion_matrix(gt, preds, labels=class_ids)

    print(f"\n Accuracy : {acc:.4f}")
    print(f" Precision : {prec_mac:.4f} (macro)")
    print(f" Recall : {rec_mac:.4f} (macro)")
    print(f" F1-Score : {f1_mac:.4f} (macro)")
    print(f" ROC-AUC : {auc:.4f}")
    print("\n Classification Report:")
    print(classification_report(gt, preds, labels=class_ids,
                                target_names=["Real", "AI-Generated"],
                                zero_division=0))
    print(f"\n Confusion Matrix:\n {cm}")

    # ── Save Confusion Matrix Plot ────────────────────────────────
    fig, ax = plt.subplots(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Real", "AI-Gen"],
                yticklabels=["Real", "AI-Gen"], ax=ax)
    ax.set_title(f"Confusion Matrix — {branch.upper()} Branch")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    fig.tight_layout()
    cm_path = OUTPUTS_DIR / f"confusion_matrix_{branch}.png"
    fig.savefig(cm_path, dpi=150)
    plt.close(fig)
    print(f"\n✓ Confusion matrix saved → {cm_path}")

    # ── Save ROC-AUC Curve (only meaningful when AUC is defined) ──
    if not np.isnan(auc):
        fpr, tpr, _ = roc_curve(gt, probs)
        fig2, ax2 = plt.subplots(figsize=(5, 4))
        ax2.plot(fpr, tpr, label=f"AUC = {auc:.4f}")
        ax2.plot([0, 1], [0, 1], "k--")  # chance-level diagonal
        ax2.set_xlabel("False Positive Rate")
        ax2.set_ylabel("True Positive Rate")
        ax2.set_title(f"ROC Curve — {branch.upper()} Branch")
        ax2.legend()
        fig2.tight_layout()
        roc_path = OUTPUTS_DIR / f"roc_curve_{branch}.png"
        fig2.savefig(roc_path, dpi=150)
        plt.close(fig2)
        print(f"✓ ROC curve saved → {roc_path}")

    # ── Export CSV (one row per successfully scored image) ───────
    csv_path = OUTPUTS_DIR / f"evaluation_{branch}.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["path", "true_label", "prob_fake", "predicted"])
        for img_path, true_label, prob, pred in zip(scored_paths, gt, probs, preds):
            writer.writerow([img_path, true_label, round(prob, 4), pred])
    print(f"✓ Per-sample results saved → {csv_path}")

    return {
        "accuracy": acc, "precision": prec_mac,
        "recall": rec_mac, "f1": f1_mac, "auc": auc,
    }
if __name__ == "__main__":
    # Command-line entry point: choose which branch to score, then evaluate it.
    cli = argparse.ArgumentParser(description="Evaluate ImageForensics-Detect")
    cli.add_argument(
        "--branch",
        type=str,
        default="full",
        choices=["full", "spectral", "edge", "cnn", "vit", "diffusion"],
        help="Which branch to evaluate (default: full fusion)",
    )
    cli_args = cli.parse_args()
    evaluate(branch=cli_args.branch)