from __future__ import annotations import math from dataclasses import dataclass from itertools import combinations from typing import Dict, List, Sequence, Tuple, Union import numpy as np @dataclass class AttributionResult: """Container for a single feature or interaction attribution value.""" feature: Union[str, Tuple[str, ...]] value: float method: str interaction_order: int def powerset(loc_tuple: Tuple[int, ...], max_order: int | None = None) -> List[Tuple[int, ...]]: """ Generate the powerset of a sparse location tuple up to a specified maximum order. Parameters: - loc_tuple: A sparse tuple of feature indices (e.g., (0, 2) means features 0 and 2). - max_order: The maximum order of the powerset (default is None). Returns: - A list of tuples representing the powerset in sparse format. """ # loc_tuple is already sparse (contains feature indices) # e.g., (0, 2) means features 0 and 2 are active if max_order is None: max_order = len(loc_tuple) # Generate all subsets of the sparse indices up to max_order tuples = [] for r in range(max_order + 1): for subset in combinations(loc_tuple, r): tuples.append(subset) return tuples def fourier_to_mobius(fourier_dict: Dict[Tuple[int, ...], float]) -> Dict[Tuple[int, ...], float]: """ Convert Fourier coefficients to Mobius coefficients. Parameters: - fourier_dict: A dictionary of Fourier coefficients. Returns: - A dictionary of Mobius coefficients. """ if len(fourier_dict) == 0: return {} else: unscaled_mobius_dict = {} for loc, coef in fourier_dict.items(): real_coef = np.real(coef) for subset in powerset(loc): if subset in unscaled_mobius_dict: unscaled_mobius_dict[subset] += real_coef else: unscaled_mobius_dict[subset] = real_coef # multiply each entry by (-2)^(cardinality) return {loc: val * np.power(-2.0, len(loc)) for loc, val in unscaled_mobius_dict.items() if np.abs(val) > 1e-12} def mobius_to_fourier(mobius_dict: Dict[Tuple[int, ...], float]) -> Dict[Tuple[int, ...], float]: """ Convert Mobius coefficients to Fourier coefficients. Parameters: - mobius_dict: A dictionary of Mobius coefficients. Returns: - A dictionary of Fourier coefficients. """ if len(mobius_dict) == 0: return {} else: unscaled_fourier_dict = {} for loc, coef in mobius_dict.items(): # Sparse tuples store feature indices, so cardinality is len(loc), not sum(loc). real_coef = np.real(coef) / (2 ** len(loc)) for subset in powerset(loc): if subset in unscaled_fourier_dict: unscaled_fourier_dict[subset] += real_coef else: unscaled_fourier_dict[subset] = real_coef # multiply each entry by (-1)^(cardinality) return {loc: val * np.power(-1.0, len(loc)) for loc, val in unscaled_fourier_dict.items() if np.abs(val) > 1e-12} def mobius_to_shapley_ii( mobius_dict: Dict[Tuple[int, ...], complex], max_order: int | None = None, **kwargs, ) -> Dict[Tuple[int, ...], float]: """ Convert Mobius coefficients to Shapley interaction indices. WARNING: This function can be expensive (exponential in feature count). For large feature sets, consider using sparse pairwise extraction instead. Args: mobius_dict: Mobius coefficients keyed by sparse tuples of indices. max_order: Maximum interaction order to compute (limits powerset size) """ sii_dict: Dict[Tuple[int, ...], float] = {} for loc, coef in mobius_dict.items(): loc = tuple(sorted(loc)) real_coef = float(np.real(coef)) t_size = len(loc) for subset in powerset(loc, max_order=max_order): contribution = real_coef / (1 + t_size - len(subset)) sii_dict[subset] = sii_dict.get(subset, 0.0) + contribution return sii_dict def mobius_to_banzhaf_ii( mobius_dict: Dict[Tuple[int, ...], complex], max_order: int | None = None, **kwargs, ) -> Dict[Tuple[int, ...], float]: """ Convert Mobius coefficients to Banzhaf interaction indices. WARNING: This function can be expensive (exponential in feature count). For large feature sets, consider using sparse pairwise extraction instead. Args: mobius_dict: Mobius coefficients keyed by sparse tuples of indices. max_order: Maximum interaction order to compute (limits powerset size) """ bii_dict: Dict[Tuple[int, ...], float] = {} for loc, coef in mobius_dict.items(): loc = tuple(sorted(loc)) real_coef = float(np.real(coef)) t_size = len(loc) for subset in powerset(loc, max_order=max_order): contribution = real_coef / math.pow(2.0, t_size - len(subset)) bii_dict[subset] = bii_dict.get(subset, 0.0) + contribution return bii_dict def mobius_to_influence_ii( mobius_dict: Dict[Tuple[int, ...], complex], max_order: int | None = None, **kwargs, ) -> Dict[Tuple[int, ...], float]: """ Convert Mobius coefficients to Influence interaction indices. WARNING: This function can be expensive (exponential in feature count). For large feature sets, consider using sparse pairwise extraction instead. Args: mobius_dict: Mobius coefficients keyed by sparse tuples of indices. max_order: Maximum interaction order to compute (limits powerset size) """ subset_influences: Dict[Tuple[int, ...], float] = {} fourier_dict = mobius_to_fourier(mobius_dict) # Step 1: Precompute the squared magnitude for every location # (Doing this once saves us from recalculating abs()**2 inside the loop) squared_coefs = {loc: abs(coef) ** 2 for loc, coef in fourier_dict.items()} # Step 2: For every nonzero location, calculate its subset influence for target_loc in fourier_dict.keys(): target_set = set(target_loc) influence = 0.0 # Sum the squared coefficients of any location that overlaps with the target for loc, sq_coef in squared_coefs.items(): if not target_set.isdisjoint(loc): influence += sq_coef subset_influences[target_loc] = influence return subset_influences def _filter_order(values: Dict[Tuple[int, ...], float], order: int) -> Dict[Tuple[int, ...], float]: return { loc: val for loc, val in values.items() if len(loc) == order and not math.isclose(val, 0.0) } def _top_k(items: Dict[Tuple[int, ...], float], top_k: int) -> List[Tuple[Tuple[int, ...], float]]: sorted_items = sorted(items.items(), key=lambda kv: abs(kv[1]), reverse=True) if top_k is None or top_k <= 0: return sorted_items return sorted_items[:top_k] def mobius_to_shapley(mobius_dict: Dict[Tuple[int, ...], complex]) -> Dict[Tuple[int, ...], float]: """ Convert Mobius coefficients to singleton Shapley values. Optimized: Only computes up to order=1 (singletons) to avoid exponential blow-up. """ # Only compute singletons - no need for higher-order powersets sii = mobius_to_shapley_ii(mobius_dict, max_order=1) return _filter_order(sii, order=1) def shapley_interactions( mobius_dict: Dict[Tuple[int, ...], complex], order: int = 2, top_k: int = 10, ) -> List[Tuple[Tuple[int, ...], float]]: """ Extract the top-k Shapley interaction indices of a given order. For order=2 (pairwise), uses sparse superset aggregation to avoid exponential blow-up. For order>2, falls back to full conversion (may be expensive). Formula for pairwise (|S|=2): φ_shapley(S) = Σ_{T ⊇ S} m(T) / (|T| - |S| + 1) """ if order == 2: # SPARSE PATH: Sparse superset aggregation for pairwise interactions # Formula: φ_shapley(S) = Σ_{T ⊇ S} m(T) / (|T| - |S| + 1) = Σ_{T ⊇ S} m(T) / (|T| - 1) assert order == 2, "Only pairwise supported for sparse path" if not mobius_dict: return [] scores: Dict[Tuple[int, ...], float] = {} for T, mval in mobius_dict.items(): T = tuple(sorted(T)) k = len(T) if k < order: continue # Weight: 1 / (|T| - |S| + 1) = 1 / (k - 2 + 1) = 1 / (k - 1) weight = 1.0 / (k - order + 1) for subset in combinations(T, order): scores[subset] = scores.get(subset, 0.0) + weight * float(np.real(mval)) return _top_k(scores, top_k) else: # FULL PATH: Use traditional conversion for higher orders sii = mobius_to_shapley_ii(mobius_dict) filtered = _filter_order(sii, order=order) return _top_k(filtered, top_k) def mobius_to_banzhaf(mobius_dict: Dict[Tuple[int, ...], complex]) -> Dict[Tuple[int, ...], float]: """ Convert Mobius coefficients to singleton Banzhaf values. Optimized: Only computes up to order=1 (singletons) to avoid exponential blow-up. """ # Only compute singletons - no need for higher-order powersets bii = mobius_to_banzhaf_ii(mobius_dict, max_order=1) return _filter_order(bii, order=1) def banzhaf_interactions( mobius_dict: Dict[Tuple[int, ...], complex], order: int = 2, top_k: int = 10, ) -> List[Tuple[Tuple[int, ...], float]]: """ Extract the top-k Banzhaf interaction indices of a given order. For order=2 (pairwise), uses sparse superset aggregation to avoid exponential blow-up. For order>2, falls back to full conversion (may be expensive). Formula for pairwise (|S|=2): β_banzhaf(S) = Σ_{T ⊇ S} m(T) / 2^(|T| - |S|) """ if order == 2: # SPARSE PATH: Sparse superset aggregation for pairwise interactions # Formula: β_banzhaf(S) = Σ_{T ⊇ S} m(T) / 2^(|T| - |S|) = Σ_{T ⊇ S} m(T) / 2^(|T| - 2) assert order == 2, "Only pairwise supported for sparse path" if not mobius_dict: return [] scores: Dict[Tuple[int, ...], float] = {} for T, mval in mobius_dict.items(): T = tuple(sorted(T)) k = len(T) if k < order: continue # Weight: 1 / 2^(|T| - |S|) = 1 / 2^(k - 2) weight = 1.0 / (2 ** (k - order)) for subset in combinations(T, order): scores[subset] = scores.get(subset, 0.0) + weight * float(np.real(mval)) return _top_k(scores, top_k) else: # FULL PATH: Use traditional conversion for higher orders bii = mobius_to_banzhaf_ii(mobius_dict) filtered = _filter_order(bii, order=order) return _top_k(filtered, top_k) def mobius_to_banzhaf(mobius_dict: Dict[Tuple[int, ...], complex]) -> Dict[Tuple[int, ...], float]: """ Convert Mobius coefficients to singleton Banzhaf values. Optimized: Only computes up to order=1 (singletons) to avoid exponential blow-up. """ # Only compute singletons - no need for higher-order powersets bii = mobius_to_banzhaf_ii(mobius_dict, max_order=1) return _filter_order(bii, order=1) def mobius_to_influence(mobius_dict: Dict[Tuple[int, ...], complex]) -> Dict[Tuple[int, ...], float]: """ Convert Mobius coefficients to singleton Influence values. Optimized: Only computes up to order=1 (singletons) to avoid exponential blow-up. """ # Only compute singletons - no need for higher-order powersets iii = mobius_to_influence_ii(mobius_dict, max_order=1) return _filter_order(iii, order=1) def influence_interactions( mobius_dict: Dict[Tuple[int, ...], complex], order: int = 2, top_k: int = 10, ) -> List[Tuple[Tuple[int, ...], float]]: """ Extract the top-k Influence interaction indices of a given order. For order=2 (pairwise), uses sparse superset aggregation to avoid exponential blow-up. For order>2, falls back to full conversion (may be expensive). Formula for pairwise (|S|=2): β_banzhaf(S) = Σ_{T ⊇ S} m(T) / 2^(|T| - |S|) """ # FULL PATH: Use traditional conversion for higher orders iii = mobius_to_influence_ii(mobius_dict) filtered = _filter_order(iii, order=order) return _top_k(filtered, top_k) def convert_to_heatmap_data( attrs: Sequence[AttributionResult], feature_order: Sequence[str] | None = None, ) -> Dict[str, Sequence]: """ Convert pairwise attribution results into a token × token interaction matrix. Parameters ---------- attrs: Sequence of attribution entries where ``feature`` is either a single token (diagonal contribution) or a tuple of tokens describing an interaction. feature_order: Optional list specifying the axis ordering. When omitted, tokens are added according to first appearance in ``attrs``. """ if not attrs: # Safe default when there are no interactions return { "matrix": [], "features": list(feature_order) if feature_order else [], "methods": [], "orders": [], "min": 0.0, "max": 0.0, } if feature_order: tokens = list(feature_order) else: tokens = [] for attr in attrs: feats = attr.feature if isinstance(attr.feature, (tuple, list)) else (attr.feature,) for token in feats: if token not in tokens: tokens.append(token) size = len(tokens) if size == 0: return {"matrix": [], "features": [], "methods": []} index = {token: idx for idx, token in enumerate(tokens)} matrix = [[0.0 for _ in range(size)] for _ in range(size)] methods: List[str] = [] orders: List[int] = [] for attr in attrs: feats = attr.feature if isinstance(attr.feature, (tuple, list)) else (attr.feature,) if not feats: continue methods.append(attr.method) orders.append(attr.interaction_order) if len(feats) == 1: token = feats[0] if token in index: idx = index[token] matrix[idx][idx] = float(attr.value) continue if len(feats) > 2: # For higher-order interactions, distribute the value across all ordered pairs. iterator = [ (index[a], index[b]) for i, a in enumerate(feats) for j, b in enumerate(feats) if i != j and a in index and b in index ] else: a, b = feats if a not in index or b not in index: continue iterator = [(index[a], index[b]), (index[b], index[a])] for i, j in iterator: matrix[i][j] = float(attr.value) flat_values = [value for row in matrix for value in row] if not flat_values: flat_values = [0.0] return { "matrix": matrix, "features": tokens, "methods": methods, "orders": orders, "min": float(min(flat_values)), "max": float(max(flat_values)), }