AttrLLM / visualization /attribution_derive.py
Qingpeng Kong
clean initial state
3e72399
"""Derive Shapley/Banzhaf/Influence attributions from a stored Mobius dict.
Multimodal pipelines persist raw Mobius coefficients to disk (see
precompute/run_medical.py and attribution/set_mm.py). This module is the single
place where the Gradio UI dispatches on the attribution method, mirroring the text-mode pattern at
visualization/app.py:165-167.
Key property: influence singletons are `Σ |f̂(B)|²` (sums of squared Fourier
magnitudes) and are therefore always non-negative by construction. No
positive-only filtering is needed; switching to influence gives clinicians a
clean importance score.
"""
from __future__ import annotations
from typing import Dict, List, Tuple
from attribution.utils import (
banzhaf_interactions,
influence_interactions,
mobius_to_banzhaf,
mobius_to_influence,
mobius_to_shapley,
shapley_interactions,
)
# Attribution method names accepted by derive_method_values and
# derive_cross_modal_pairs (compared against the lower-cased caller input).
SUPPORTED_METHODS = ("shapley", "banzhaf", "influence")
# Tuple-keyed Mobius coefficients: tuple of feature indices -> coefficient.
# The empty tuple () holds the baseline term (see parse_mobius_dict).
MobiusDict = Dict[Tuple[int, ...], float]
def parse_mobius_dict(raw: Dict[str, float]) -> MobiusDict:
    """Rebuild a tuple-keyed Mobius dict from its JSON-serialized form.

    Each key of *raw* is a comma-separated list of integer feature indices
    (e.g. ``"0,3"`` -> ``(0, 3)``). The empty string maps to the empty
    tuple, which carries the baseline coefficient. Keys whose parts cannot
    be parsed with ``int`` are silently dropped. Every kept value is coerced
    with ``float``.
    """
    parsed: MobiusDict = {}
    for text_key, coeff in raw.items():
        if text_key != "":
            try:
                indices = tuple(map(int, text_key.split(",")))
            except ValueError:
                # Best-effort: ignore keys that are not comma-joined integers.
                continue
        else:
            indices = ()
        parsed[indices] = float(coeff)
    return parsed
def derive_method_values(
    mobius_dict: Dict[str, float] | MobiusDict,
    method: str,
    *,
    interaction_order: int = 2,
    top_k_interactions: int = 15,
) -> Tuple[Dict[int, float], List[Dict]]:
    """Compute singleton attributions and interactions for one method.

    Parameters
    ----------
    mobius_dict : dict
        Mobius coefficients, keyed either by comma-joined index strings
        (JSON-serialized) or by index tuples. The first key decides which.
    method : str
        "shapley", "banzhaf" or "influence" (case-insensitive). A falsy
        value falls back to "shapley".
    interaction_order : int
        Interaction order forwarded to the interaction helper (default 2).
    top_k_interactions : int
        Maximum number of interactions to return; 0 means no cap.

    Returns
    -------
    singleton_values : dict[int, float]
        Feature index -> attribution value, taken from the length-1 keys of
        the method's value dict.
    interactions : list of dict
        One dict per interaction with keys {indices: list[int] (sorted),
        value: float, order: int}.

    Raises
    ------
    ValueError
        If *method* is not one of SUPPORTED_METHODS.
    """
    chosen = (method or "shapley").lower()
    if chosen not in SUPPORTED_METHODS:
        raise ValueError(f"Unsupported method: {chosen!r}. Expected one of {SUPPORTED_METHODS}.")

    # JSON round-trips turn tuple keys into strings; detect that from the first key.
    needs_parsing = bool(mobius_dict) and not isinstance(next(iter(mobius_dict)), tuple)
    if needs_parsing:
        mobius = parse_mobius_dict(mobius_dict)  # type: ignore[arg-type]
    else:
        mobius = mobius_dict  # type: ignore[assignment]

    # method -> (singleton-value transform, interaction transform)
    dispatch = {
        "shapley": (mobius_to_shapley, shapley_interactions),
        "banzhaf": (mobius_to_banzhaf, banzhaf_interactions),
        "influence": (mobius_to_influence, influence_interactions),
    }
    to_values, to_interactions = dispatch[chosen]
    raw_values = to_values(mobius)
    raw_inter = to_interactions(mobius, order=interaction_order, top_k=top_k_interactions)

    singleton_values: Dict[int, float] = {
        int(key[0]): float(val) for key, val in raw_values.items() if len(key) == 1
    }
    interactions: List[Dict] = [
        {
            "indices": [int(i) for i in sorted(loc)],
            "value": float(val),
            "order": len(loc),
        }
        for loc, val in raw_inter
    ]
    return singleton_values, interactions
def derive_cross_modal_pairs(
    mobius_dict: Dict[str, float] | MobiusDict,
    n_img: int,
    method: str = "influence",
) -> List[Dict]:
    """Return every segment x token pair with its method-specific strength.

    Player indices in ``[0, n_img)`` are treated as segments and indices in
    ``[n_img, ...)`` as tokens; this inline split stands in for a
    ``player_filter``. Only order-2 interactions that span both ranges are
    kept; same-side pairs are dropped. Influence values are always
    non-negative; Shapley/Banzhaf values can be signed.

    Parameters
    ----------
    mobius_dict : dict
        Either string-keyed (JSON-serialized) or tuple-keyed Mobius
        coefficients. The first key decides which.
    n_img : int
        Number of segment players. Token indices are reported relative to it
        (``token_index = player_index - n_img``).
    method : str
        One of "shapley", "banzhaf", "influence" (case-insensitive). A falsy
        value falls back to "influence".

    Returns
    -------
    list of dict
        Each dict has keys {seg_index: int, token_index: int, value: float,
        method: str}.

    Raises
    ------
    ValueError
        If *method* is not one of SUPPORTED_METHODS.
    """
    method = (method or "influence").lower()
    if method not in SUPPORTED_METHODS:
        raise ValueError(f"Unsupported method: {method!r}. Expected one of {SUPPORTED_METHODS}.")
    if mobius_dict and not isinstance(next(iter(mobius_dict)), tuple):
        mobius = parse_mobius_dict(mobius_dict)  # type: ignore[arg-type]
    else:
        mobius = mobius_dict  # type: ignore[assignment]

    # top_k=0 requests every pairwise interaction (no cap).
    if method == "shapley":
        raw = shapley_interactions(mobius, order=2, top_k=0)
    elif method == "banzhaf":
        raw = banzhaf_interactions(mobius, order=2, top_k=0)
    else:
        raw = influence_interactions(mobius, order=2, top_k=0)

    pairs: List[Dict] = []
    for loc, val in raw:
        if len(loc) != 2:
            continue
        a, b = sorted(int(i) for i in loc)
        # After sorting a <= b, so a cross-modal pair always has the segment
        # in `a` (< n_img) and the token in `b` (>= n_img). Everything else
        # is a same-side pair and is skipped.
        if not (a < n_img <= b):
            continue
        pairs.append(
            {
                "seg_index": a,
                "token_index": b - n_img,
                "value": float(val),
                "method": method,
            }
        )
    return pairs