| """ |
| fusion/fusion.py |
| ----------------- |
| Certainty-Weighted Probabilistic Fusion Module |
| STATUS: COMPLETE |
| |
| Combines the outputs from all 5 forensic branches into a single, |
| stable final prediction. Branches with higher certainty/confidence |
| get proportionally more weight in the final decision. |
| |
Formula:
    final_prob_fake = Σ(w_i × prob_fake_i) / Σ(w_i),   where w_i = max(confidence_i, 0.01)
| |
If the branches are collectively near-zero confidence (summed weight Σ(w_i) below
MIN_TOTAL_CONFIDENCE, e.g. models untrained + degenerate input), the system falls
back to equal-weight averaging and flags the result as low-certainty.
| |
Output:
    {
        "prediction"    : "Real" | "AI-Generated",
        "confidence"    : float (5–100, percentage; floored at 5),
        "prob_fake"     : float [0, 1],
        "branches"      : dict of per-branch scores,
        "fused_weight"  : dict of per-branch effective weights (normalized to sum to 1),
        "low_certainty" : bool — True if the summed branch weight Σ(w_i) is below 1.0
    }
| """ |
|
|
| import numpy as np |
| from typing import Dict, Any |
|
|
|
|
| |
# Decision boundary on a fake-probability: values >= this are labelled
# "AI-Generated", anything below is labelled "Real". Applied both to each
# individual branch and to the fused result.
FAKE_THRESHOLD = 0.50


# If the summed per-branch effective weights fall below this value, the
# weighted average is abandoned in favour of a plain mean and the result is
# flagged as low-certainty. Every branch contributes at least 0.01 (the weight
# floor in fuse_branches), so with five branches the minimum total is 0.05.
MIN_TOTAL_CONFIDENCE = 0.10


def _read_score(branch: dict, key: str, default: float) -> float:
    """
    Read ``branch[key]`` as a float clipped to [0, 1].

    Missing, ``None`` and NaN values are replaced by ``default`` so that a
    single failed branch cannot turn the whole fused result into NaN.
    """
    value = branch.get(key)
    if value is None:
        return default
    value = float(value)
    if np.isnan(value):
        return default
    return float(np.clip(value, 0.0, 1.0))


def fuse_branches(branch_outputs: Dict[str, dict]) -> Dict[str, Any]:
    """
    Perform certainty-weighted probabilistic fusion of all branch predictions.

    Each branch's fake-probability is weighted by its confidence, floored at
    0.01 so that no branch is ignored entirely:

        final_prob_fake = sum(w_i * prob_fake_i) / sum(w_i),
        w_i = max(confidence_i, 0.01)

    If sum(w_i) < MIN_TOTAL_CONFIDENCE, the weighted average is replaced by an
    equal-weight mean and the result is flagged as low-certainty.

    Args:
        branch_outputs : dict mapping branch name -> branch output dict.
            Only "spectral", "edge", "cnn", "vit" and "diffusion" are read;
            any other keys are ignored. Each branch output may contain:
              - "prob_fake"  : float, clipped to [0, 1]; defaults to 0.5
              - "confidence" : float, clipped to [0, 1]; defaults to 0.0
            A missing or ``None`` branch, and a missing/None/NaN value, fall
            back to the defaults above.

    Returns:
        Fusion result dict (see module docstring for schema).
    """
    branch_names = ["spectral", "edge", "cnn", "vit", "diffusion"]

    probs = []
    weights = []
    branch_info = {}

    for name in branch_names:
        # A branch may be absent or explicitly None (e.g. it failed to run).
        branch = branch_outputs.get(name) or {}
        prob_fake = _read_score(branch, "prob_fake", 0.5)
        confidence = _read_score(branch, "confidence", 0.0)

        # Floor the weight so a zero-confidence branch still contributes a
        # little and the weighted mean is always defined.
        effective_weight = max(confidence, 0.01)

        probs.append(prob_fake)
        weights.append(effective_weight)

        branch_info[name] = {
            "prob_fake": round(prob_fake, 4),
            "confidence": round(confidence, 4),
            "label": "AI-Generated" if prob_fake >= FAKE_THRESHOLD else "Real",
        }

    weights_arr = np.array(weights, dtype=np.float64)
    probs_arr = np.array(probs, dtype=np.float64)
    total_weight = weights_arr.sum()

    if total_weight < MIN_TOTAL_CONFIDENCE:
        # Essentially no branch is confident: fall back to a plain mean.
        final_prob_fake = float(np.mean(probs_arr))
        low_certainty = True
        equal_share = round(1.0 / len(branch_names), 4)
        effective_weights = {name: equal_share for name in branch_names}
    else:
        final_prob_fake = float(np.dot(weights_arr, probs_arr) / total_weight)
        # Heuristic: flag as uncertain when the summed weight is below 1.0,
        # i.e. the mean effective weight is under 1 / len(branch_names).
        low_certainty = bool(total_weight < 1.0)
        norm_weights = weights_arr / total_weight
        effective_weights = {
            name: round(float(w), 4) for name, w in zip(branch_names, norm_weights)
        }

    # Guard against tiny floating-point excursions outside [0, 1].
    final_prob_fake = float(np.clip(final_prob_fake, 0.0, 1.0))

    prediction = "AI-Generated" if final_prob_fake >= FAKE_THRESHOLD else "Real"

    # Confidence is the distance of the fused probability from 0.5, rescaled
    # to a percentage and floored at 5%.
    confidence_pct = round(abs(final_prob_fake - 0.5) * 2.0 * 100.0, 2)
    confidence_pct = max(confidence_pct, 5.0)

    return {
        "prediction": prediction,
        "confidence": confidence_pct,
        "prob_fake": round(final_prob_fake, 4),
        "branches": branch_info,
        "fused_weight": effective_weights,
        "low_certainty": low_certainty,
    }
|
|
|
|
def format_result_for_display(fusion_result: dict) -> str:
    """
    Format a fusion result as a human-readable string for CLI/debug output.

    Args:
        fusion_result : fusion result dict. Reads the keys "prediction",
            "confidence", "prob_fake", "branches" and "fused_weight";
            "low_certainty" is optional. Branches missing from
            "fused_weight" are shown with weight 0.0.

    Returns:
        Multi-line report string. Decoration uses ASCII characters only, so
        the text renders the same regardless of the console's encoding.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    lines = [
        f"\n{heavy_rule}",
        " ImageForensics-Detect - Analysis Result",
        heavy_rule,
        f" Prediction : {fusion_result['prediction']}",
        f" Confidence : {fusion_result['confidence']:.1f}%",
        f" Prob (Fake) : {fusion_result['prob_fake']:.4f}",
        light_rule,
        " Branch-Level Scores:",
    ]

    fused_weights = fusion_result["fused_weight"]
    for name, info in fusion_result["branches"].items():
        weight = fused_weights.get(name, 0.0)
        lines.append(
            f" {name:12s} prob_fake={info['prob_fake']:.4f} "
            f"confidence={info['confidence']:.4f} weight={weight:.4f} "
            f"-> {info['label']}"
        )

    if fusion_result.get("low_certainty"):
        lines.append("\n [!] Low certainty - some branches may be untrained.")

    lines.append(heavy_rule)
    return "\n".join(lines)
|
|