dk2430098's picture
Upload folder using huggingface_hub
928b74f verified
"""
fusion/fusion.py
-----------------
Certainty-Weighted Probabilistic Fusion Module
STATUS: COMPLETE
Combines the outputs from all 5 forensic branches into a single,
stable final prediction. Branches with higher certainty/confidence
get proportionally more weight in the final decision.
Formula:
final_prob_fake = Σ(w_i × prob_fake_i) / Σ(w_i),  with w_i = max(confidence_i, 0.01)
If the total weight Σ(w_i) is below MIN_TOTAL_CONFIDENCE (e.g. all branches at 0
confidence because the models are untrained or the input is degenerate), the system
falls back to equal-weight averaging and flags the result as low-certainty.
Output:
{
"prediction" : "Real" | "AI-Generated",
"confidence" : float (0–100, percentage; floored at 5),
"prob_fake" : float [0, 1],
"branches" : dict of per-branch scores,
"fused_weight" : dict of per-branch effective weights,
"low_certainty" : bool — True if the system is uncertain
}
"""
import numpy as np
from typing import Dict, Any
# Threshold for final classification (prob_fake >= this → "AI-Generated").
FAKE_THRESHOLD: float = 0.50
# Minimum total effective weight (sum of per-branch weights, each floored at
# 0.01) below which fusion falls back to equal-weight averaging.
MIN_TOTAL_CONFIDENCE: float = 0.10
def _unit_interval(value: Any, default: float) -> float:
    """
    Coerce a raw branch score to a Python float in [0, 1].

    Args:
        value   : raw value read from a branch output (None, a Python number,
                  or a numpy scalar).
        default : value returned when *value* is None, cannot be converted to
                  a single float, or is NaN.

    Returns:
        float in [0, 1]; out-of-range values (including ±inf) are clipped.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    # NaN passes through np.clip unchanged and would poison the weighted sum,
    # so treat it exactly like a missing score.
    if np.isnan(number):
        return default
    return float(np.clip(number, 0.0, 1.0))


def fuse_branches(branch_outputs: Dict[str, dict]) -> Dict[str, Any]:
    """
    Perform certainty-weighted probabilistic fusion of all branch predictions.

    The branches "spectral", "edge", "cnn", "vit" and "diffusion" are each
    weighted by max(confidence, 0.01). If the summed weight is below
    MIN_TOTAL_CONFIDENCE, the branches are averaged with equal weight instead.

    Args:
        branch_outputs : dict mapping branch name → branch output dict.
                         Each branch output should contain:
                           - "prob_fake"  : float in [0, 1]
                           - "confidence" : float in [0, 1]
                         Missing branches, branches reported as None, and
                         missing / None / NaN / non-numeric scores fall back to
                         prob_fake = 0.5 and confidence = 0.0. Out-of-range
                         values are clipped to [0, 1]. Keys other than the five
                         branch names above are ignored.

    Returns:
        Fusion result dict (see module docstring for schema).
        "low_certainty" is True when the equal-weight fallback is used or the
        summed effective weight is below 1.0. "confidence" is
        |prob_fake - 0.5| * 200, floored at 5.0.
    """
    branch_names = ["spectral", "edge", "cnn", "vit", "diffusion"]
    # A None mapping means no branch reported anything.
    branch_outputs = branch_outputs or {}

    probs = []
    weights = []
    branch_info = {}
    for name in branch_names:
        # A failed branch may be reported as None rather than omitted.
        branch = branch_outputs.get(name) or {}
        prob_fake = _unit_interval(branch.get("prob_fake"), 0.5)
        confidence = _unit_interval(branch.get("confidence"), 0.0)
        # A zero-confidence (e.g. untrained) branch keeps a tiny weight so it
        # doesn't disappear entirely, while barely influencing the result.
        effective_weight = max(confidence, 0.01)
        probs.append(prob_fake)
        weights.append(effective_weight)
        branch_info[name] = {
            "prob_fake": round(prob_fake, 4),
            "confidence": round(confidence, 4),
            "label": "AI-Generated" if prob_fake >= FAKE_THRESHOLD else "Real",
        }

    weights_arr = np.array(weights, dtype=np.float64)
    probs_arr = np.array(probs, dtype=np.float64)
    total_weight = weights_arr.sum()

    if total_weight < MIN_TOTAL_CONFIDENCE:
        # All branches uncertain — equal-weight fallback.
        final_prob_fake = float(np.mean(probs_arr))
        low_certainty = True  # Python bool, JSON serializable
        effective_weights = {n: round(1.0 / len(branch_names), 4) for n in branch_names}
    else:
        # Certainty-weighted average.
        final_prob_fake = float(np.dot(weights_arr, probs_arr) / total_weight)
        # Partial certainty; bool() converts numpy.bool_ → Python bool for JSON.
        low_certainty = bool(total_weight < 1.0)
        norm_weights = weights_arr / total_weight
        effective_weights = {
            name: round(float(w), 4) for name, w in zip(branch_names, norm_weights)
        }
    final_prob_fake = float(np.clip(final_prob_fake, 0.0, 1.0))

    # Final label
    prediction = "AI-Generated" if final_prob_fake >= FAKE_THRESHOLD else "Real"

    # Confidence expressed as percentage (distance from 0.5, scaled to 100%).
    confidence_pct = round(abs(final_prob_fake - 0.5) * 2.0 * 100.0, 2)
    # Show at least some confidence for demo purposes.
    # NOTE(review): this floor means the reported confidence is never below
    # 5%, even when prob_fake is exactly 0.5.
    confidence_pct = max(confidence_pct, 5.0)

    return {
        "prediction": prediction,
        "confidence": confidence_pct,  # e.g. 97.1
        "prob_fake": round(final_prob_fake, 4),
        "branches": branch_info,
        "fused_weight": effective_weights,
        "low_certainty": low_certainty,
    }
def format_result_for_display(fusion_result: dict) -> str:
    """
    Format fusion result as a human-readable string for CLI/debug output.

    Args:
        fusion_result : dict with the keys "prediction", "confidence",
                        "prob_fake" and "branches" (name → dict with
                        "prob_fake", "confidence", "label"). "fused_weight"
                        (name → weight) and "low_certainty" are optional; a
                        branch with no listed weight is shown with weight 0.

    Returns:
        Multi-line report: banner, summary, one line per branch, and a warning
        line when "low_certainty" is truthy.
    """
    lines = [
        f"\n{'='*60}",
        " ImageForensics-Detect — Analysis Result",
        f"{'='*60}",
        f" Prediction : {fusion_result['prediction']}",
        f" Confidence : {fusion_result['confidence']:.1f}%",
        f" Prob (Fake) : {fusion_result['prob_fake']:.4f}",
        f"{'─'*60}",
        " Branch-Level Scores:",
    ]
    # Tolerate results that carry no weight table (e.g. hand-built dicts).
    fused_weight = fusion_result.get("fused_weight") or {}
    for name, info in fusion_result["branches"].items():
        weight = fused_weight.get(name, 0.0)
        lines.append(
            f" {name:12s} prob_fake={info['prob_fake']:.4f} "
            f"confidence={info['confidence']:.4f} weight={weight:.4f} "
            f"→ {info['label']}"
        )
    if fusion_result.get("low_certainty"):
        lines.append("\n ⚠ Low certainty — some branches may be untrained.")
    lines.append("="*60)
    return "\n".join(lines)