"""
fusion/fusion.py
-----------------
Certainty-Weighted Probabilistic Fusion Module

STATUS: COMPLETE

Combines the outputs from all 5 forensic branches into a single, stable
final prediction. Branches with higher certainty/confidence get
proportionally more weight in the final decision.

Formula:
    final_prob_fake = Σ(confidence_i × prob_fake_i) / Σ(confidence_i)

Every branch weight is floored at 0.01 so that a zero-confidence branch
never disappears entirely. If all branches are at 0 confidence (e.g. models
untrained + degenerate input), the summed weight stays below
MIN_TOTAL_CONFIDENCE, and the system falls back to equal-weight averaging
and flags the result as low-certainty.

Output:
    {
        "prediction"    : "Real" | "AI-Generated",
        "confidence"    : float (0–100, percentage),
        "prob_fake"     : float [0, 1],
        "branches"      : dict of per-branch scores,
        "fused_weight"  : dict of per-branch effective weights,
        "low_certainty" : bool — True if system is uncertain
    }
"""

import math
from typing import Any, Dict

import numpy as np

# Threshold for final classification (>= this → AI-Generated)
FAKE_THRESHOLD = 0.50

# Minimum total confidence to consider result reliable
MIN_TOTAL_CONFIDENCE = 0.10

# Branches that take part in the fusion, in reporting order.
BRANCH_NAMES = ("spectral", "edge", "cnn", "vit", "diffusion")

# Weight floor applied to every branch so none vanishes from the fusion.
_MIN_BRANCH_WEIGHT = 0.01


def _unit_interval(value: Any, default: float) -> float:
    """
    Coerce a branch score to a float clamped to [0, 1].

    Returns `default` when the value is None, cannot be converted to a
    float, or is NaN. Infinite values are clamped like any other
    out-of-range number.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    if math.isnan(number):
        # NaN would survive clamping and poison every downstream average.
        return default
    return min(max(number, 0.0), 1.0)


def fuse_branches(branch_outputs: Dict[str, dict]) -> Dict[str, Any]:
    """
    Perform certainty-weighted probabilistic fusion of all branch predictions.

    Args:
        branch_outputs : dict mapping branch name → branch output dict.
            Each branch output should contain:
              - "prob_fake"  : float in [0, 1]
              - "confidence" : float in [0, 1]
            A missing or None branch, and missing / None / NaN / non-numeric
            values, are treated as a neutral vote (prob_fake 0.5,
            confidence 0.0) instead of crashing or leaking NaN into the
            result.

    Returns:
        Fusion result dict (see module docstring for schema).
    """
    probs = []
    weights = []
    branch_info = {}

    for name in BRANCH_NAMES:
        # `or {}` also covers a branch explicitly reported as None.
        branch = branch_outputs.get(name) or {}
        prob_fake = _unit_interval(branch.get("prob_fake", 0.5), 0.5)
        confidence = _unit_interval(branch.get("confidence", 0.0), 0.0)

        # If confidence is exactly 0 (untrained branch), still give it a tiny
        # weight so it doesn't completely disappear, but near-zero.
        probs.append(prob_fake)
        weights.append(max(confidence, _MIN_BRANCH_WEIGHT))

        branch_info[name] = {
            "prob_fake": round(prob_fake, 4),
            "confidence": round(confidence, 4),
            "label": "AI-Generated" if prob_fake >= FAKE_THRESHOLD else "Real",
        }

    weights_arr = np.array(weights, dtype=np.float64)
    probs_arr = np.array(probs, dtype=np.float64)
    total_weight = weights_arr.sum()

    if total_weight < MIN_TOTAL_CONFIDENCE:
        # All branches uncertain — equal weight fallback
        final_prob_fake = float(np.mean(probs_arr))
        low_certainty = True
        equal_share = round(1.0 / len(BRANCH_NAMES), 4)
        effective_weights = {name: equal_share for name in BRANCH_NAMES}
    else:
        # Certainty-weighted average
        final_prob_fake = float(np.dot(weights_arr, probs_arr) / total_weight)
        # Partial certainty; cast numpy.bool_ → Python bool so the result
        # stays JSON serializable.
        low_certainty = bool(total_weight < 1.0)
        norm_weights = weights_arr / total_weight
        effective_weights = {
            name: round(float(w), 4)
            for name, w in zip(BRANCH_NAMES, norm_weights)
        }

    final_prob_fake = float(np.clip(final_prob_fake, 0.0, 1.0))

    # Final label
    prediction = "AI-Generated" if final_prob_fake >= FAKE_THRESHOLD else "Real"

    # Confidence expressed as percentage (distance from 0.5, scaled to 100%)
    confidence_pct = round(abs(final_prob_fake - 0.5) * 2.0 * 100.0, 2)
    # Show at least some confidence for demo purposes
    confidence_pct = max(confidence_pct, 5.0)

    return {
        "prediction": prediction,
        "confidence": confidence_pct,  # e.g. 97.1
        "prob_fake": round(final_prob_fake, 4),
        "branches": branch_info,
        "fused_weight": effective_weights,
        "low_certainty": low_certainty,
    }


def format_result_for_display(fusion_result: dict) -> str:
    """
    Format fusion result as a human-readable string for CLI/debug output.

    Args:
        fusion_result : dict shaped like the return value of fuse_branches().

    Returns:
        Multi-line report: headline prediction, one row per branch, and a
        warning line when the result is flagged as low-certainty.
    """
    rule = "=" * 60
    lines = [
        f"\n{rule}",
        " ImageForensics-Detect — Analysis Result",
        rule,
        f" Prediction : {fusion_result['prediction']}",
        f" Confidence : {fusion_result['confidence']:.1f}%",
        f" Prob (Fake) : {fusion_result['prob_fake']:.4f}",
        "─" * 60,
        " Branch-Level Scores:",
    ]

    fused_weights = fusion_result["fused_weight"]
    for name, info in fusion_result["branches"].items():
        # Branches without a recorded weight are shown with weight 0.
        weight = fused_weights.get(name, 0.0)
        lines.append(
            f" {name:12s} prob_fake={info['prob_fake']:.4f} "
            f"confidence={info['confidence']:.4f} weight={weight:.4f} "
            f"→ {info['label']}"
        )

    if fusion_result.get("low_certainty"):
        lines.append("\n ⚠ Low certainty — some branches may be untrained.")

    lines.append(rule)
    return "\n".join(lines)