# AttrLLM/attribution/utils.py
# Qingpeng Kong
# clean initial state
# 3e72399
from __future__ import annotations
import math
from dataclasses import dataclass
from itertools import combinations
from typing import Dict, List, Sequence, Tuple, Union
import numpy as np
@dataclass
class AttributionResult:
    """
    Record holding one attribution score.

    The score may belong to a single feature or to an interaction between
    several features. The ``dataclass`` decorator supplies ``__init__``,
    ``__repr__`` and ``__eq__``; fields are positional in the order declared.
    """

    # Either one feature name, or a tuple of names for an interaction.
    feature: Union[str, Tuple[str, ...]]
    # The attribution score itself.
    value: float
    # Label of the attribution method that produced ``value``.
    method: str
    # Order of the interaction (presumably the number of features involved;
    # nothing in this file enforces that it matches ``feature``).
    interaction_order: int
def powerset(loc_tuple: Tuple[int, ...], max_order: int | None = None) -> List[Tuple[int, ...]]:
    """
    Enumerate the subsets of a sparse index tuple, smallest subsets first.

    Parameters:
    - loc_tuple: Sparse tuple of feature indices, e.g. (0, 2) means features 0 and 2.
    - max_order: Largest subset size to emit. ``None`` (the default) means
      every size up to ``len(loc_tuple)``.

    Returns:
    - A list of tuples, grouped by increasing size starting with the empty
      tuple. Within a size, subsets appear in the order produced by
      ``itertools.combinations``.
    """
    # Without an explicit cap, go all the way up to the full tuple.
    largest_size = len(loc_tuple) if max_order is None else max_order
    return [
        subset
        for size in range(largest_size + 1)
        for subset in combinations(loc_tuple, size)
    ]
def fourier_to_mobius(fourier_dict: Dict[Tuple[int, ...], float]) -> Dict[Tuple[int, ...], float]:
    """
    Turn Fourier coefficients into Mobius coefficients.

    For each subset T the result is (-2)^|T| times the sum of the real parts
    of the Fourier coefficients of all keys that contain T.

    Parameters:
    - fourier_dict: Fourier coefficients keyed by sparse tuples of feature indices.

    Returns:
    - A dictionary of Mobius coefficients. Subsets whose accumulated
      (unscaled) sum has magnitude at most 1e-12 are omitted. An empty
      input yields an empty dictionary.
    """
    # Pass 1: for every subset, add up the real coefficients of its supersets.
    superset_sums = {}
    for support, coefficient in fourier_dict.items():
        real_part = np.real(coefficient)
        for subset in powerset(support):
            if subset in superset_sums:
                superset_sums[subset] = superset_sums[subset] + real_part
            else:
                superset_sums[subset] = real_part

    # Pass 2: scale by (-2)^cardinality, dropping numerically-zero sums.
    mobius = {}
    for subset, total in superset_sums.items():
        if np.abs(total) > 1e-12:
            mobius[subset] = total * np.power(-2.0, len(subset))
    return mobius
def mobius_to_fourier(mobius_dict: Dict[Tuple[int, ...], float]) -> Dict[Tuple[int, ...], float]:
    """
    Turn Mobius coefficients into Fourier coefficients.

    For each subset S the result is (-1)^|S| times the sum, over all keys T
    that contain S, of the real part of m(T) divided by 2^|T|.

    Parameters:
    - mobius_dict: Mobius coefficients keyed by sparse tuples of feature indices.

    Returns:
    - A dictionary of Fourier coefficients. Subsets whose accumulated
      (unsigned) sum has magnitude at most 1e-12 are omitted. An empty
      input yields an empty dictionary.
    """
    # Pass 1: spread each Mobius coefficient, pre-divided by 2^|T|, over all
    # subsets of its support. Keys hold indices, so |T| is len(), not sum().
    scaled_sums = {}
    for support, coefficient in mobius_dict.items():
        share = np.real(coefficient) / (2 ** len(support))
        for subset in powerset(support):
            if subset in scaled_sums:
                scaled_sums[subset] = scaled_sums[subset] + share
            else:
                scaled_sums[subset] = share

    # Pass 2: apply the (-1)^cardinality sign, dropping numerically-zero sums.
    fourier = {}
    for subset, total in scaled_sums.items():
        if np.abs(total) > 1e-12:
            fourier[subset] = total * np.power(-1.0, len(subset))
    return fourier
def mobius_to_shapley_ii(
    mobius_dict: Dict[Tuple[int, ...], complex],
    max_order: int | None = None,
    **kwargs,
) -> Dict[Tuple[int, ...], float]:
    """
    Compute Shapley interaction indices from Mobius coefficients.

    Every Mobius term m(T) adds ``m(T) / (|T| - |S| + 1)`` to each subset S
    of T that ``powerset`` emits. Only the real part of each coefficient is
    used, and keys are sorted before their subsets are enumerated.

    WARNING: without ``max_order`` every subset of every key is visited,
    which grows exponentially with the key length.

    Args:
        mobius_dict: Mobius coefficients keyed by sparse tuples of indices.
        max_order: Largest subset size to produce (``None`` = no limit).
        **kwargs: Accepted and ignored.

    Returns:
        Mapping from subset (sorted tuple of indices) to its interaction index.
    """
    indices: Dict[Tuple[int, ...], float] = {}
    for support, coefficient in mobius_dict.items():
        members = tuple(sorted(support))
        real_part = float(np.real(coefficient))
        support_size = len(members)
        for subset in powerset(members, max_order=max_order):
            # Shapley weight for a superset term: 1 / (|T| - |S| + 1).
            share = real_part / (1 + support_size - len(subset))
            indices.setdefault(subset, 0.0)
            indices[subset] += share
    return indices
def mobius_to_banzhaf_ii(
    mobius_dict: Dict[Tuple[int, ...], complex],
    max_order: int | None = None,
    **kwargs,
) -> Dict[Tuple[int, ...], float]:
    """
    Compute Banzhaf interaction indices from Mobius coefficients.

    Every Mobius term m(T) adds ``m(T) / 2^(|T| - |S|)`` to each subset S of
    T that ``powerset`` emits. Only the real part of each coefficient is
    used, and keys are sorted before their subsets are enumerated.

    WARNING: without ``max_order`` every subset of every key is visited,
    which grows exponentially with the key length.

    Args:
        mobius_dict: Mobius coefficients keyed by sparse tuples of indices.
        max_order: Largest subset size to produce (``None`` = no limit).
        **kwargs: Accepted and ignored.

    Returns:
        Mapping from subset (sorted tuple of indices) to its interaction index.
    """
    indices: Dict[Tuple[int, ...], float] = {}
    for support, coefficient in mobius_dict.items():
        members = tuple(sorted(support))
        real_part = float(np.real(coefficient))
        support_size = len(members)
        for subset in powerset(members, max_order=max_order):
            # Banzhaf weight for a superset term: 1 / 2^(|T| - |S|).
            share = real_part / math.pow(2.0, support_size - len(subset))
            indices.setdefault(subset, 0.0)
            indices[subset] += share
    return indices
def mobius_to_influence_ii(
    mobius_dict: Dict[Tuple[int, ...], complex],
    max_order: int | None = None,
    **kwargs,
) -> Dict[Tuple[int, ...], float]:
    """
    Convert Mobius coefficients to Influence interaction indices.

    The Mobius coefficients are first converted with ``mobius_to_fourier``.
    The influence of a target subset is then the sum of squared Fourier
    coefficient magnitudes over every Fourier key that shares at least one
    index with the target.

    Only subsets that are themselves keys of the Fourier dictionary receive
    an entry. The empty tuple, if present, gets influence 0.0 because it
    overlaps nothing.

    WARNING: the overlap scan is quadratic in the number of Fourier keys, on
    top of the cost of ``mobius_to_fourier`` itself.

    Args:
        mobius_dict: Mobius coefficients keyed by sparse tuples of indices.
        max_order: Largest target size to compute. Targets with more indices
            are skipped. The sums themselves still range over Fourier keys
            of every size. ``None`` (default) computes all targets.
        **kwargs: Accepted and ignored.

    Returns:
        Mapping from target subset to its influence value.
    """
    fourier_dict = mobius_to_fourier(mobius_dict)

    # Square every coefficient once instead of inside the overlap scan.
    squared_coefs = {loc: abs(coef) ** 2 for loc, coef in fourier_dict.items()}

    subset_influences: Dict[Tuple[int, ...], float] = {}
    for target_loc in fourier_dict:
        # Honour the documented order limit: skip targets that are too large
        # before doing the (expensive) scan over all coefficients.
        if max_order is not None and len(target_loc) > max_order:
            continue
        target_set = set(target_loc)
        subset_influences[target_loc] = sum(
            (
                sq_coef
                for loc, sq_coef in squared_coefs.items()
                if not target_set.isdisjoint(loc)
            ),
            0.0,
        )
    return subset_influences
def _filter_order(values: Dict[Tuple[int, ...], float], order: int) -> Dict[Tuple[int, ...], float]:
return {
loc: val
for loc, val in values.items()
if len(loc) == order and not math.isclose(val, 0.0)
}
def _top_k(items: Dict[Tuple[int, ...], float], top_k: int) -> List[Tuple[Tuple[int, ...], float]]:
sorted_items = sorted(items.items(), key=lambda kv: abs(kv[1]), reverse=True)
if top_k is None or top_k <= 0:
return sorted_items
return sorted_items[:top_k]
def mobius_to_shapley(mobius_dict: Dict[Tuple[int, ...], complex]) -> Dict[Tuple[int, ...], float]:
    """
    Return the Shapley value of each individual feature.

    The interaction-index conversion is asked for subsets of size at most 1,
    so larger subsets are never enumerated. The result is then reduced to
    the single-index keys, omitting entries that ``_filter_order`` treats as
    zero.
    """
    low_order_indices = mobius_to_shapley_ii(mobius_dict, max_order=1)
    singletons = _filter_order(low_order_indices, order=1)
    return singletons
def shapley_interactions(
    mobius_dict: Dict[Tuple[int, ...], complex],
    order: int = 2,
    top_k: int = 10,
) -> List[Tuple[Tuple[int, ...], float]]:
    """
    Extract the top-k Shapley interaction indices of a given order.

    Formula: φ_shapley(S) = Σ_{T ⊇ S} m(T) / (|T| - |S| + 1)

    For order=2 (pairwise) the sum is built directly: each Mobius term only
    touches the pairs it contains. Every other order goes through
    ``mobius_to_shapley_ii`` capped at ``max_order=order``, so subsets larger
    than requested are never enumerated, and the result is reduced with
    ``_filter_order``.

    Args:
        mobius_dict: Mobius coefficients keyed by sparse tuples of indices.
            Only the real part of each coefficient is used.
        order: Size of the interactions to report.
        top_k: Number of entries to return, ranked by absolute value
            (``_top_k`` treats values <= 0 as "return all").

    Returns:
        List of ``(subset, value)`` pairs, largest magnitude first.
    """
    if order != 2:
        # GENERAL PATH: only subsets of size <= order can survive the filter
        # below, so cap the enumeration there instead of building the full
        # powerset of every key.
        sii = mobius_to_shapley_ii(mobius_dict, max_order=order)
        return _top_k(_filter_order(sii, order=order), top_k)

    # SPARSE PATH: superset aggregation for pairwise interactions.
    if not mobius_dict:
        return []
    scores: Dict[Tuple[int, ...], float] = {}
    for support, mval in mobius_dict.items():
        members = tuple(sorted(support))
        size = len(members)
        if size < order:
            # Too small to contain any pair.
            continue
        # Weight: 1 / (|T| - |S| + 1) = 1 / (size - 1) for pairs.
        weight = 1.0 / (size - order + 1)
        real_part = float(np.real(mval))
        for subset in combinations(members, order):
            scores[subset] = scores.get(subset, 0.0) + weight * real_part
    return _top_k(scores, top_k)
def mobius_to_banzhaf(mobius_dict: Dict[Tuple[int, ...], complex]) -> Dict[Tuple[int, ...], float]:
    """
    Return the Banzhaf value of each individual feature.

    The interaction-index conversion is asked for subsets of size at most 1,
    so larger subsets are never enumerated. The result is then reduced to
    the single-index keys, omitting entries that ``_filter_order`` treats as
    zero.
    """
    low_order_indices = mobius_to_banzhaf_ii(mobius_dict, max_order=1)
    singletons = _filter_order(low_order_indices, order=1)
    return singletons
def banzhaf_interactions(
    mobius_dict: Dict[Tuple[int, ...], complex],
    order: int = 2,
    top_k: int = 10,
) -> List[Tuple[Tuple[int, ...], float]]:
    """
    Extract the top-k Banzhaf interaction indices of a given order.

    Formula: β_banzhaf(S) = Σ_{T ⊇ S} m(T) / 2^(|T| - |S|)

    For order=2 (pairwise) the sum is built directly: each Mobius term only
    touches the pairs it contains. Every other order goes through
    ``mobius_to_banzhaf_ii`` capped at ``max_order=order``, so subsets larger
    than requested are never enumerated, and the result is reduced with
    ``_filter_order``.

    Args:
        mobius_dict: Mobius coefficients keyed by sparse tuples of indices.
            Only the real part of each coefficient is used.
        order: Size of the interactions to report.
        top_k: Number of entries to return, ranked by absolute value
            (``_top_k`` treats values <= 0 as "return all").

    Returns:
        List of ``(subset, value)`` pairs, largest magnitude first.
    """
    if order != 2:
        # GENERAL PATH: only subsets of size <= order can survive the filter
        # below, so cap the enumeration there instead of building the full
        # powerset of every key.
        bii = mobius_to_banzhaf_ii(mobius_dict, max_order=order)
        return _top_k(_filter_order(bii, order=order), top_k)

    # SPARSE PATH: superset aggregation for pairwise interactions.
    if not mobius_dict:
        return []
    scores: Dict[Tuple[int, ...], float] = {}
    for support, mval in mobius_dict.items():
        members = tuple(sorted(support))
        size = len(members)
        if size < order:
            # Too small to contain any pair.
            continue
        # Weight: 1 / 2^(|T| - |S|) = 1 / 2^(size - 2) for pairs.
        weight = 1.0 / (2 ** (size - order))
        real_part = float(np.real(mval))
        for subset in combinations(members, order):
            scores[subset] = scores.get(subset, 0.0) + weight * real_part
    return _top_k(scores, top_k)
def mobius_to_banzhaf(mobius_dict: Dict[Tuple[int, ...], complex]) -> Dict[Tuple[int, ...], float]:
    """
    Return the Banzhaf value of each individual feature.

    Requests Banzhaf interaction indices for subsets of size at most 1, then
    keeps the single-index keys, omitting entries that ``_filter_order``
    treats as zero.
    """
    # NOTE(review): a function with this exact name and the same statements is
    # already defined earlier in this file; this later definition is the one
    # bound at import time. One of the two copies can be removed.
    return _filter_order(mobius_to_banzhaf_ii(mobius_dict, max_order=1), order=1)
def mobius_to_influence(mobius_dict: Dict[Tuple[int, ...], complex]) -> Dict[Tuple[int, ...], float]:
    """
    Return the Influence value of each individual feature.

    Calls ``mobius_to_influence_ii`` with ``max_order=1`` and keeps the
    single-index keys of its result, omitting entries that ``_filter_order``
    treats as zero.
    """
    influence_indices = mobius_to_influence_ii(mobius_dict, max_order=1)
    singletons = _filter_order(influence_indices, order=1)
    return singletons
def influence_interactions(
    mobius_dict: Dict[Tuple[int, ...], complex],
    order: int = 2,
    top_k: int = 10,
) -> List[Tuple[Tuple[int, ...], float]]:
    """
    Extract the top-k Influence interaction indices of a given order.

    Unlike the Shapley and Banzhaf variants there is no dedicated pairwise
    path: every order goes through ``mobius_to_influence_ii``, and the result
    is reduced to keys of exactly ``order`` indices with ``_filter_order``.

    Args:
        mobius_dict: Mobius coefficients keyed by sparse tuples of indices.
        order: Size of the interactions to report.
        top_k: Number of entries to return, ranked by absolute value
            (``_top_k`` treats values <= 0 as "return all").

    Returns:
        List of ``(subset, value)`` pairs, largest magnitude first.
    """
    # Forward the requested order as the helper's order limit. Keys longer
    # than `order` are discarded by the filter below either way, so this
    # cannot change the result.
    iii = mobius_to_influence_ii(mobius_dict, max_order=order)
    filtered = _filter_order(iii, order=order)
    return _top_k(filtered, top_k)
def convert_to_heatmap_data(
    attrs: Sequence[AttributionResult],
    feature_order: Sequence[str] | None = None,
) -> Dict[str, Sequence]:
    """
    Convert attribution results into a token × token interaction matrix.

    Parameters
    ----------
    attrs:
        Sequence of attribution entries where ``feature`` is either a single token
        (diagonal contribution) or a tuple/list of tokens describing an interaction.
    feature_order:
        Optional list specifying the axis ordering. When omitted (or empty), tokens
        are added according to first appearance in ``attrs``.

    Returns
    -------
    A dictionary with the same six keys on every path:
    ``matrix`` (square list of lists), ``features`` (axis labels),
    ``methods`` and ``orders`` (one entry per attribution that has at least
    one feature), and ``min`` / ``max`` over all matrix cells.

    Notes
    -----
    * A single-token entry writes the diagonal cell; a pair writes both
      symmetric cells; a longer interaction writes its value into every
      ordered pair of its tokens.
    * Cells are overwritten, not accumulated: a later entry replaces an
      earlier one that targets the same cell.
    * Tokens missing from the axis are skipped, though the entry's method
      and order are still recorded.
    """

    def _features_of(attr) -> Sequence:
        # Normalise a bare token to a one-element tuple.
        feature = attr.feature
        return feature if isinstance(feature, (tuple, list)) else (feature,)

    def _empty_payload(features: List) -> Dict[str, Sequence]:
        # Single source of truth for the "nothing to plot" result shape.
        return {
            "matrix": [],
            "features": features,
            "methods": [],
            "orders": [],
            "min": 0.0,
            "max": 0.0,
        }

    if not attrs:
        # Safe default when there are no interactions
        return _empty_payload(list(feature_order) if feature_order else [])

    if feature_order:
        tokens = list(feature_order)
    else:
        # dict.fromkeys de-duplicates while preserving first-appearance
        # order, with constant-time membership instead of a list scan.
        tokens = list(
            dict.fromkeys(token for attr in attrs for token in _features_of(attr))
        )

    size = len(tokens)
    if size == 0:
        # Every entry had an empty feature tuple: same shape as the no-attrs case.
        return _empty_payload([])

    index = {token: idx for idx, token in enumerate(tokens)}
    matrix = [[0.0 for _ in range(size)] for _ in range(size)]
    methods: List[str] = []
    orders: List[int] = []

    for attr in attrs:
        feats = _features_of(attr)
        if not feats:
            continue
        methods.append(attr.method)
        orders.append(attr.interaction_order)

        if len(feats) == 1:
            # Singleton: diagonal cell, if the token is on the axis.
            idx = index.get(feats[0])
            if idx is not None:
                matrix[idx][idx] = float(attr.value)
            continue

        if len(feats) == 2:
            a, b = feats
            if a not in index or b not in index:
                continue
            cells = [(index[a], index[b]), (index[b], index[a])]
        else:
            # For higher-order interactions, distribute the value across all ordered pairs.
            cells = [
                (index[a], index[b])
                for i, a in enumerate(feats)
                for j, b in enumerate(feats)
                if i != j and a in index and b in index
            ]

        for row, col in cells:
            matrix[row][col] = float(attr.value)

    # size > 0 here, so the matrix always has at least one cell.
    flat_values = [value for row in matrix for value in row]
    return {
        "matrix": matrix,
        "features": tokens,
        "methods": methods,
        "orders": orders,
        "min": float(min(flat_values)),
        "max": float(max(flat_values)),
    }