| """Derive Shapley/Banzhaf/Influence attributions from a stored Mobius dict. |
| |
| Multimodal pipelines persist raw Mobius coefficients to disk (see |
| precompute/run_medical.py and attribution/set_mm.py). This module is the single |
| place the Gradio UI dispatches on method, mirroring the text-mode pattern at |
| visualization/app.py:165-167. |
| |
| Key property: influence singletons are `Σ |f̂(B)|²` (sums of squared Fourier |
| magnitudes) and are therefore always non-negative by construction. No |
| positive-only filtering is needed; switching to influence gives clinicians a |
| clean importance score. |
| """ |
| from __future__ import annotations |
|
|
| from typing import Dict, List, Tuple |
|
|
| from attribution.utils import ( |
| banzhaf_interactions, |
| influence_interactions, |
| mobius_to_banzhaf, |
| mobius_to_influence, |
| mobius_to_shapley, |
| shapley_interactions, |
| ) |
|
|
|
|
# Attribution methods that can be derived from a stored Mobius dict. The
# tuple is also echoed back in "unsupported method" error messages.
SUPPORTED_METHODS = (
    "shapley",
    "banzhaf",
    "influence",
)

# Mobius coefficients keyed by a tuple of player indices; the empty tuple
# holds the baseline term (see parse_mobius_dict).
MobiusDict = Dict[Tuple[int, ...], float]
|
|
|
|
def parse_mobius_dict(raw: Dict[str, float]) -> MobiusDict:
    """Parse a JSON-serialized Mobius dict back into tuple-keyed form.

    Keys are comma-joined player-index strings, e.g. ``"0,3"`` -> ``(0, 3)``.
    The empty key ``""`` maps to the empty tuple (the baseline term). Index
    order inside a key is kept exactly as written, and values are coerced
    with ``float``.

    Keys that do not parse as comma-separated integers (e.g. ``"a,b"``, or
    ``"1,"`` with a trailing comma) are skipped. This is a best-effort parse,
    so one bad entry does not break the caller. A warning is logged so the
    dropped coefficients are visible instead of disappearing silently.

    Parameters
    ----------
    raw : dict[str, float]
        String-keyed Mobius coefficients as loaded from JSON.

    Returns
    -------
    dict[tuple[int, ...], float]
        Tuple-keyed Mobius coefficients.

    Raises
    ------
    ValueError, TypeError
        If a value cannot be converted by ``float``.
    """
    import logging

    out: MobiusDict = {}
    skipped: List[str] = []
    for key, val in raw.items():
        if key == "":
            out[()] = float(val)
            continue
        try:
            idxs = tuple(int(p) for p in key.split(","))
        except ValueError:
            # Keep the original best-effort behaviour (skip the entry), but
            # remember it so the loss can be reported below.
            skipped.append(key)
            continue
        out[idxs] = float(val)

    if skipped:
        logging.getLogger(__name__).warning(
            "parse_mobius_dict: ignored %d malformed key(s), e.g. %r",
            len(skipped),
            skipped[0],
        )
    return out
|
|
|
|
def derive_method_values(
    mobius_dict: Dict[str, float] | MobiusDict,
    method: str,
    *,
    interaction_order: int = 2,
    top_k_interactions: int = 15,
) -> Tuple[Dict[int, float], List[Dict]]:
    """Compute per-feature attributions and top interactions for one method.

    Parameters
    ----------
    mobius_dict : dict
        Mobius coefficients, either string-keyed (as loaded from JSON) or
        already tuple-keyed. The key type is decided from the first key.
    method : str
        Case-insensitive; one of "shapley", "banzhaf", "influence". A falsy
        value falls back to "shapley".
    interaction_order : int
        Forwarded as ``order=`` to the method's interaction helper
        (default 2, i.e. pairwise).
    top_k_interactions : int
        Forwarded as ``top_k=`` to the interaction helper; 0 for unlimited.

    Returns
    -------
    singleton_values : dict[int, float]
        Feature index -> attribution value, taken from the size-1 keys of the
        method's value mapping.
    interactions : list of dict
        One dict per interaction term, with keys ``indices`` (sorted list of
        int), ``value`` (float) and ``order`` (number of indices).

    Raises
    ------
    ValueError
        If ``method`` is not in SUPPORTED_METHODS.
    """
    method = (method or "shapley").lower()
    if method not in SUPPORTED_METHODS:
        raise ValueError(f"Unsupported method: {method!r}. Expected one of {SUPPORTED_METHODS}.")

    # JSON round-trips turn tuple keys into strings; convert them back.
    needs_parsing = bool(mobius_dict) and not isinstance(next(iter(mobius_dict)), tuple)
    mobius = parse_mobius_dict(mobius_dict) if needs_parsing else mobius_dict

    # method -> (singleton value transform, interaction extractor)
    to_values, to_interactions = {
        "shapley": (mobius_to_shapley, shapley_interactions),
        "banzhaf": (mobius_to_banzhaf, banzhaf_interactions),
        "influence": (mobius_to_influence, influence_interactions),
    }[method]

    raw_values = to_values(mobius)
    raw_inter = to_interactions(mobius, order=interaction_order, top_k=top_k_interactions)

    singleton_values: Dict[int, float] = {
        int(subset[0]): float(score)
        for subset, score in raw_values.items()
        if len(subset) == 1
    }

    interactions: List[Dict] = [
        {
            "indices": [int(i) for i in sorted(loc)],
            "value": float(val),
            "order": len(loc),
        }
        for loc, val in raw_inter
    ]

    return singleton_values, interactions
|
|
|
|
def derive_cross_modal_pairs(
    mobius_dict: Dict[str, float] | MobiusDict,
    n_img: int,
    method: str = "influence",
) -> List[Dict]:
    """Return every segment x token pair with its method-specific strength.

    Player indices in ``[0, n_img)`` are segments and indices ``>= n_img``
    are tokens. Only order-2 interaction terms whose two indices fall on
    opposite sides of that boundary are kept; this is the inline equivalent
    of a ``player_filter``. Influence values are always non-negative, while
    Shapley/Banzhaf values can be signed.

    Parameters
    ----------
    mobius_dict : dict
        Mobius coefficients, either string-keyed (as loaded from JSON) or
        already tuple-keyed. The key type is decided from the first key.
    n_img : int
        Number of segment players; the boundary between segment and token
        indices.
    method : str
        Case-insensitive; one of "shapley", "banzhaf", "influence". A falsy
        value falls back to "influence".

    Returns
    -------
    list of dict
        Each dict has keys ``seg_index`` (index in ``[0, n_img)``),
        ``token_index`` (player index minus ``n_img``), ``value`` (float) and
        ``method`` (the normalized method name).

    Raises
    ------
    ValueError
        If ``method`` is not in SUPPORTED_METHODS.
    """
    method = (method or "influence").lower()
    if method not in SUPPORTED_METHODS:
        raise ValueError(f"Unsupported method: {method!r}. Expected one of {SUPPORTED_METHODS}.")

    # JSON round-trips turn tuple keys into strings; convert them back.
    if mobius_dict and not isinstance(next(iter(mobius_dict)), tuple):
        mobius = parse_mobius_dict(mobius_dict)
    else:
        mobius = mobius_dict

    # top_k=0 asks the helper for all order-2 terms, not a truncated list.
    if method == "shapley":
        raw = shapley_interactions(mobius, order=2, top_k=0)
    elif method == "banzhaf":
        raw = banzhaf_interactions(mobius, order=2, top_k=0)
    else:
        raw = influence_interactions(mobius, order=2, top_k=0)

    pairs: List[Dict] = []
    for loc, val in raw:
        if len(loc) != 2:
            continue
        a, b = sorted(int(i) for i in loc)
        # Because a <= b, the pair crosses the boundary exactly when
        # a < n_img <= b. So `a` is always the segment and `b` the token;
        # no swap case exists.
        if not (a < n_img <= b):
            continue
        pairs.append(
            {
                "seg_index": a,
                "token_index": b - n_img,
                "value": float(val),
                "method": method,
            }
        )
    return pairs
|
|