"""Derive Shapley/Banzhaf/Influence attributions from a stored Mobius dict.

Multimodal pipelines persist raw Mobius coefficients to disk (see
precompute/run_medical.py and attribution/set_mm.py). This module is the single
place where the Gradio UI dispatches on the attribution method, mirroring the
text-mode pattern at visualization/app.py:165-167.

Key property: influence singletons are ``Σ |f̂(B)|²`` (sums of squared Fourier
magnitudes) and are therefore always non-negative by construction. No
positive-only filtering is needed, so switching to influence gives clinicians a
clean importance score.
"""

from __future__ import annotations

from typing import Any, Callable, Dict, List, Tuple

from attribution.utils import (
    banzhaf_interactions,
    influence_interactions,
    mobius_to_banzhaf,
    mobius_to_influence,
    mobius_to_shapley,
    shapley_interactions,
)

SUPPORTED_METHODS = ("shapley", "banzhaf", "influence")

MobiusDict = Dict[Tuple[int, ...], float]


def parse_mobius_dict(raw: Dict[str, float]) -> MobiusDict:
    """Parse the JSON-serialized Mobius dict back into tuple-keyed form.

    Keys are comma-joined index strings (e.g. ``"0,3"`` -> ``(0, 3)``); the
    empty key maps to the empty tuple (baseline). Keys that cannot be parsed as
    comma-separated integers are skipped rather than raising, so one malformed
    entry does not abort loading the whole dict. Values are coerced to float.
    """
    out: MobiusDict = {}
    for key, val in raw.items():
        if key == "":
            out[()] = float(val)
            continue
        try:
            idxs = tuple(int(p) for p in key.split(","))
        except ValueError:
            # Deliberate best-effort: ignore keys that are not index lists.
            continue
        out[idxs] = float(val)
    return out


def _normalize_method(method: str, default: str) -> str:
    """Lower-case *method* (using *default* when it is falsy) and validate it.

    Raises
    ------
    ValueError
        If the normalized method is not in ``SUPPORTED_METHODS``.
    """
    normalized = (method or default).lower()
    if normalized not in SUPPORTED_METHODS:
        raise ValueError(
            f"Unsupported method: {normalized!r}. Expected one of {SUPPORTED_METHODS}."
        )
    return normalized


def _coerce_mobius(mobius_dict: Dict[str, float] | MobiusDict) -> MobiusDict:
    """Return a tuple-keyed Mobius dict, parsing JSON-style string keys if needed.

    Only the first key is inspected, so the dict is assumed to be uniformly
    keyed (all tuples or all strings). An empty dict is returned unchanged.
    """
    if mobius_dict and not isinstance(next(iter(mobius_dict)), tuple):
        return parse_mobius_dict(mobius_dict)  # type: ignore[arg-type]
    return mobius_dict  # type: ignore[return-value]


def _method_fns(method: str) -> Tuple[Callable[..., Any], Callable[..., Any]]:
    """Return ``(singleton_fn, interaction_fn)`` for an already-validated method."""
    if method == "shapley":
        return mobius_to_shapley, shapley_interactions
    if method == "banzhaf":
        return mobius_to_banzhaf, banzhaf_interactions
    return mobius_to_influence, influence_interactions


def derive_method_values(
    mobius_dict: Dict[str, float] | MobiusDict,
    method: str,
    *,
    interaction_order: int = 2,
    top_k_interactions: int = 15,
) -> Tuple[Dict[int, float], List[Dict]]:
    """Return (singleton_values, interactions) for the requested method.

    Parameters
    ----------
    mobius_dict : dict
        Either string-keyed (JSON-serialized) or tuple-keyed Mobius coefficients.
    method : str
        One of "shapley", "banzhaf", "influence" (case-insensitive; a falsy
        value falls back to "shapley").
    interaction_order : int
        Forwarded as ``order=`` to the method's interaction helper
        (default 2, i.e. pairwise).
    top_k_interactions : int
        Cap for interactions; 0 for unlimited.

    Returns
    -------
    singleton_values : dict[int, float]
        Feature index → attribution value.
    interactions : list of dict
        Each dict has keys: {indices: list[int], value: float, order: int}.

    Raises
    ------
    ValueError
        If *method* is not one of ``SUPPORTED_METHODS``.
    """
    method = _normalize_method(method, "shapley")
    mobius = _coerce_mobius(mobius_dict)
    to_values, to_interactions = _method_fns(method)

    raw_values = to_values(mobius)
    raw_inter = to_interactions(
        mobius, order=interaction_order, top_k=top_k_interactions
    )

    # Only length-1 keys are singleton attributions; higher-order keys are skipped.
    singleton_values: Dict[int, float] = {
        int(key[0]): float(val) for key, val in raw_values.items() if len(key) == 1
    }
    interactions: List[Dict] = [
        {
            "indices": [int(i) for i in sorted(loc)],
            "value": float(val),
            "order": len(loc),
        }
        for loc, val in raw_inter
    ]
    return singleton_values, interactions


def derive_cross_modal_pairs(
    mobius_dict: Dict[str, float] | MobiusDict,
    n_img: int,
    method: str = "influence",
) -> List[Dict]:
    """Return every segment × token pair with its method-specific strength.

    `player_filter` equivalents are handled inline: we keep pairs where one
    index is in [0, n_img) and the other is in [n_img, ...). Influence values
    are always non-negative; Shapley/Banzhaf can be signed.

    Returns
    -------
    list of dict
        Each dict has keys ``seg_index`` (image index in [0, n_img)),
        ``token_index`` (text index, offset by ``n_img``), ``value`` and
        ``method``.

    Raises
    ------
    ValueError
        If *method* is not one of ``SUPPORTED_METHODS``.
    """
    method = _normalize_method(method, "influence")
    mobius = _coerce_mobius(mobius_dict)
    _, to_interactions = _method_fns(method)
    raw = to_interactions(mobius, order=2, top_k=0)

    pairs: List[Dict] = []
    for loc, val in raw:
        if len(loc) != 2:
            continue
        a, b = sorted(int(i) for i in loc)
        # Sorting guarantees a <= b, so a cross-modal pair always has its image
        # segment at a and its text token at b; same-modality pairs fail this test.
        if not (a < n_img <= b):
            continue
        pairs.append(
            {
                "seg_index": a,
                "token_index": b - n_img,
                "value": float(val),
                "method": method,
            }
        )
    return pairs