"""
coalition_viewer.py — Step-through attribution process viewer.

For small n (<= 10): exhaustive 2^n coalitions sorted by popcount.
For large n (> 10): sequential build-up — add one feature at a time
(left-to-right), showing n+1 steps.

Renders an interactive HTML+CSS+JS widget that lets users step through each
coalition to see:
  - which features are active (visible) vs masked
  - the f(S) score for that coalition
  - running Shapley values converging toward their final values
    (exhaustive mode: a Shapley-weighted mean of the marginals seen so far,
    which equals the exact Shapley value once every coalition has been shown)
"""

from __future__ import annotations

import json
import math
import uuid
from itertools import product
from typing import Any, Dict, List, Optional, Tuple

# ---------------------------------------------------------------------------
# Data computation
# ---------------------------------------------------------------------------

_MAX_FEATURES = 100  # hard cap for safety
_EXHAUSTIVE_MAX_FEATURES = 10  # largest n for which all 2^n coalitions are enumerated
_BAR_DISPLAY_LIMIT = 15  # bars shown in sequential (word-level) mode: top-|phi| features
_ROW_HEIGHT = 42  # SVG row height per bar; shared by the Python sizing and the JS


def _parse_mobius_key(key: str) -> Tuple[int, ...]:
    """Parse a Mobius dict key like '0,2' -> (0, 2), '' -> ().

    Empty components (e.g. the trailing comma in '0,2,') are ignored instead
    of raising ``ValueError``.
    """
    return tuple(int(part) for part in key.split(",") if part.strip())


def _is_subset(sub: Tuple[int, ...], sup_set: set) -> bool:
    """Check if every element of *sub* appears in *sup_set*."""
    return all(s in sup_set for s in sub)


def _f_score_from_mobius(
    active: Tuple[int, ...],
    mobius: Dict[Tuple[int, ...], float],
) -> float:
    """Compute f(S) = sum over T subset of S of m(T), from the Mobius dict."""
    active_set = set(active)
    return sum(
        (m_val for t_key, m_val in mobius.items() if _is_subset(t_key, active_set)),
        0.0,
    )


def _compute_final_shapley(
    n: int,
    mobius: Dict[Tuple[int, ...], float],
) -> List[float]:
    """Return phi_i = sum over T containing i of m(T) / |T|.

    The empty coalition carries no attribution. Indices outside ``[0, n)`` are
    ignored (previously a negative index silently credited the last feature).
    """
    shapley = [0.0] * n
    for t_key, m_val in mobius.items():
        if not t_key:
            continue
        share = m_val / len(t_key)
        for i in t_key:
            if 0 <= i < n:
                shapley[i] += share
    return shapley


def _shapley_weights(n: int) -> List[float]:
    """Return ``w`` with ``w[s] = (s-1)! (n-s)! / n!`` for sizes s = 1..n.

    ``w[s]`` is the Shapley weight of the marginal f(S) - f(S without i) for a
    coalition S of size s that contains i. Summed over all such S the weights
    total exactly 1, so a weighted mean over every coalition equals phi_i.
    ``w[0]`` is an unused placeholder (the empty coalition has no marginals).
    """
    n_fact = math.factorial(n)
    return [0.0] + [
        math.factorial(s - 1) * math.factorial(n - s) / n_fact
        for s in range(1, n + 1)
    ]


def _build_exhaustive_coalitions(
    n: int,
    features: List[str],
    mobius: Dict[Tuple[int, ...], float],
) -> List[Dict[str, Any]]:
    """All 2^n coalitions sorted by popcount then lex.

    ``running_shapley`` is the Shapley-weighted mean of each feature's
    marginals seen so far (not the plain mean, which would converge to the
    Banzhaf index instead of the Shapley value).
    """
    all_masks = sorted(product((0, 1), repeat=n), key=lambda m: (sum(m), m))

    # Precompute all scores once; each marginal needs the score of mask - {i}.
    scores = {
        mask: _f_score_from_mobius(
            tuple(i for i, bit in enumerate(mask) if bit), mobius
        )
        for mask in all_masks
    }
    weights = _shapley_weights(n)
    weighted_sum = [0.0] * n
    weight_total = [0.0] * n

    coalitions = []
    for step_idx, mask in enumerate(all_masks):
        active = [i for i, bit in enumerate(mask) if bit]
        visible_parts = [features[i] for i in active]
        masked_text = " ".join(visible_parts) if visible_parts else "[ALL MASKED]"
        score = scores[mask]
        weight = weights[len(active)]

        marginals_this_step: Dict[int, float] = {}
        for i in active:
            without_i = mask[:i] + (0,) + mask[i + 1:]
            marginal = score - scores[without_i]
            marginals_this_step[i] = marginal
            weighted_sum[i] += weight * marginal
            weight_total[i] += weight

        running_shapley = [
            weighted_sum[i] / weight_total[i] if weight_total[i] > 0 else 0.0
            for i in range(n)
        ]

        coalitions.append({
            "step": step_idx + 1,
            "mask": list(mask),
            "active": active,
            "masked_text": masked_text,
            "f_score": score,
            "running_shapley": running_shapley,
            "marginals_this_step": {str(k): v for k, v in marginals_this_step.items()},
        })

    return coalitions


def _build_sequential_coalitions(
    n: int,
    features: List[str],
    mobius: Dict[Tuple[int, ...], float],
    final_shapley: List[float],
) -> List[Dict[str, Any]]:
    """Sequential build-up: add one feature at a time (n+1 steps).

    Step k (0-based) has features ``0..k-1`` active. Running Shapley values
    reveal the final Shapley values progressively — each feature's bar
    appears when that feature is added. This avoids the problem of
    one-permutation marginals having opposite signs to the true Shapley values.
    """
    coalitions = []
    prev_score = 0.0
    for step_idx in range(n + 1):
        current_active = list(range(step_idx))
        mask = [1] * step_idx + [0] * (n - step_idx)
        score = _f_score_from_mobius(tuple(current_active), mobius)

        visible_parts = [features[i] for i in current_active]
        masked_text = " ".join(visible_parts) if visible_parts else "[ALL MASKED]"

        # Marginal of the feature just added, reusing the previous step's score.
        marginals_this_step: Dict[int, float] = {}
        if step_idx > 0:
            marginals_this_step[step_idx - 1] = score - prev_score
        prev_score = score

        # Features not yet added show 0.
        running_shapley = [0.0] * n
        for i in current_active:
            running_shapley[i] = final_shapley[i]

        coalitions.append({
            "step": step_idx + 1,
            "mask": mask,
            "active": current_active,
            "masked_text": masked_text,
            "f_score": score,
            "running_shapley": running_shapley,
            "marginals_this_step": {str(k): v for k, v in marginals_this_step.items()},
        })

    return coalitions


def compute_coalition_viewer_data(
    features: List[str],
    mobius_dict_raw: Dict[str, float],
) -> Optional[Dict[str, Any]]:
    """Build the full data bundle for the coalition step-through viewer.

    - n <= 10: exhaustive 2^n coalitions
    - 10 < n <= 100: sequential left-to-right build-up (n+1 steps)
    - n == 0 or n > 100: returns None (safety cap)

    Args:
        features: Feature texts, in index order.
        mobius_dict_raw: Mobius coefficients keyed by comma-separated feature
            indices ('' is the empty coalition).

    Returns:
        A JSON-serializable dict consumed by ``render_coalition_viewer_html``,
        or None when n is out of range.
    """
    n = len(features)
    if n == 0 or n > _MAX_FEATURES:
        return None

    mobius: Dict[Tuple[int, ...], float] = {
        _parse_mobius_key(k): v for k, v in mobius_dict_raw.items()
    }
    final_shapley = _compute_final_shapley(n, mobius)

    if n <= _EXHAUSTIVE_MAX_FEATURES:
        coalitions = _build_exhaustive_coalitions(n, features, mobius)
        mode = "exhaustive"
    else:
        coalitions = _build_sequential_coalitions(n, features, mobius, final_shapley)
        mode = "sequential"

    all_scores = [c["f_score"] for c in coalitions]

    return {
        "features": features,
        "coalitions": coalitions,
        "final_shapley": final_shapley,
        "f_min": min(all_scores),
        "f_max": max(all_scores),
        "shapley_min": min(final_shapley),
        "shapley_max": max(final_shapley),
        "mode": mode,
    }


# ---------------------------------------------------------------------------
# HTML rendering
# ---------------------------------------------------------------------------


def _escape_html(text: Any) -> str:
    """Escape ``&``, ``<`` and ``>`` for safe use as HTML text content."""
    return str(text).replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")


def _json_for_script(obj: Any) -> str:
    """Serialize *obj* as JSON that is safe to embed inside a <script> element.

    ``<``, ``>`` and ``&`` only occur inside JSON strings, where the \\u escapes
    are equivalent; escaping them prevents a feature text such as
    ``</script>`` from terminating the script element early.
    """
    return (
        json.dumps(obj)
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )


def _build_css(root_id: str) -> str:
    """Return the widget stylesheet, scoped to ``#root_id``."""
    return f"""
#{root_id} {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background: #f8f9fb; border: 1px solid #e0e0e0; border-radius: 12px; padding: 20px; color: #1a1a2e; max-width: 100%; overflow: hidden; }}
#{root_id} .apv-header {{ display: flex; align-items: center; justify-content: space-between; margin-bottom: 16px; flex-wrap: wrap; gap: 8px; }}
#{root_id} .apv-title {{ font-size: 16px; font-weight: 700; color: #1a1a2e; }}
#{root_id} .apv-mode-badge {{ font-size: 11px; padding: 2px 8px; border-radius: 10px; background: #e8e0f0; color: #5b2d98; font-weight: 600; }}
#{root_id} .apv-step-label {{ font-size: 13px; color: #666; min-width: 80px; text-align: center; }}
#{root_id} .apv-controls {{ display: flex; align-items: center; gap: 6px; }}
#{root_id} .apv-btn {{ background: #fff; border: 1px solid #ccc; color: #333; border-radius: 6px; padding: 5px 12px; cursor: pointer; font-size: 13px; transition: background 0.15s; }}
#{root_id} .apv-btn:hover {{ background: #f0ecf5; border-color: #7c3aed; }}
#{root_id} .apv-btn-play {{ background: #7c3aed; border-color: #7c3aed; color: #fff; min-width: 70px; }}
#{root_id} .apv-btn-play:hover {{ background: #6b2ed0; }}
#{root_id} .apv-body {{ display: flex; gap: 20px; flex-wrap: wrap; }}
#{root_id} .apv-text-panel {{ flex: 1.2; min-width: 280px; }}
#{root_id} .apv-shapley-panel {{ flex: 1; min-width: 280px; }}
#{root_id} .apv-section-title {{ font-size: 13px; font-weight: 600; color: #888; margin-bottom: 10px; text-transform: uppercase; letter-spacing: 0.5px; }}
/* Block-level features (sentence/paragraph) */
#{root_id} .apv-sentence {{ padding: 8px 12px; margin-bottom: 6px; border-radius: 6px; font-size: 13px; line-height: 1.5; transition: all 0.35s ease; border-left: 3px solid transparent; }}
#{root_id} .apv-sentence-active {{ background: rgba(124, 58, 237, 0.1); border-left-color: #7c3aed; color: #1a1a2e; opacity: 1; }}
#{root_id} .apv-sentence-masked {{ background: rgba(200, 200, 200, 0.2); color: #bbb; opacity: 0.45; text-decoration: line-through; text-decoration-color: rgba(180, 180, 180, 0.5); }}
#{root_id} .apv-sentence-label {{ font-size: 10px; font-weight: 700; color: #999; margin-right: 6px; }}
/* Inline word tokens */
#{root_id} .apv-tokens-wrap {{ line-height: 2.0; padding: 8px; background: #fff; border: 1px solid #e8e8e8; border-radius: 8px; max-height: 260px; overflow-y: auto; }}
#{root_id} .apv-token {{ display: inline-block; padding: 2px 5px; margin: 2px 1px; border-radius: 4px; font-size: 13px; transition: all 0.25s ease; cursor: default; }}
#{root_id} .apv-token-active {{ background: rgba(124, 58, 237, 0.15); color: #1a1a2e; font-weight: 500; box-shadow: 0 0 0 1px rgba(124, 58, 237, 0.3); }}
#{root_id} .apv-token-masked {{ background: transparent; color: #ccc; text-decoration: line-through; text-decoration-color: rgba(180, 180, 180, 0.5); }}
#{root_id} .apv-token-just-added {{ background: rgba(124, 58, 237, 0.3); color: #1a1a2e; font-weight: 700; box-shadow: 0 0 0 2px #7c3aed; }}
#{root_id} .apv-score-row {{ display: flex; align-items: center; gap: 12px; margin-top: 14px; padding: 10px 12px; background: #f0ecf5; border-radius: 8px; }}
#{root_id} .apv-score-label {{ font-size: 12px; color: #666; white-space: nowrap; }}
#{root_id} .apv-score-value {{ font-size: 18px; font-weight: 700; font-family: 'SF Mono', 'Fira Code', monospace; min-width: 100px; }}
#{root_id} .apv-score-bar-track {{ flex: 1; height: 8px; background: #e0dce8; border-radius: 4px; overflow: hidden; }}
#{root_id} .apv-score-bar-fill {{ height: 100%; border-radius: 4px; transition: width 0.35s ease, background 0.35s ease; }}
#{root_id} .apv-bar-chart {{ width: 100%; }}
#{root_id} .apv-marginal-note {{ font-size: 14px; color: #555; margin-top: 10px; min-height: 20px; font-style: italic; line-height: 1.4; }}
#{root_id} .apv-coalition-label {{ font-size: 14px; font-weight: 600; color: #444; margin-top: 10px; font-family: 'SF Mono', 'Fira Code', monospace; }}
#{root_id} .apv-btn-reset {{ background: #fff; border: 1px solid #ccc; color: #888; font-size: 12px; }}
#{root_id} .apv-btn-reset:hover {{ background: #f0ecf5; border-color: #7c3aed; color: #333; }}
#{root_id} .apv-loader {{ width: 1px; height: 1px; position: absolute; opacity: 0; }}
"""


def _build_js(
    *,
    root_id: str,
    data_id: str,
    start_step: int,
    is_word_level: bool,
    bar_display_limit: int,
) -> str:
    """Return the widget's JavaScript (an IIFE bound to ``#root_id``)."""
    return f"""
(function() {{
  var root = document.getElementById('{root_id}');
  // Both the <script> element and the loader <img> may run this code; init once.
  if (!root || root.getAttribute('data-apv-ready') === '1') return;
  root.setAttribute('data-apv-ready', '1');

  var DATA = JSON.parse(document.getElementById('{data_id}').textContent);
  var N_STEPS = DATA.coalitions.length;
  var currentStep = {start_step};
  var playTimer = null;
  var n = DATA.features.length;
  var isWordLevel = {str(is_word_level).lower()};
  var barDisplayLimit = {bar_display_limit};
  var POS_COLOR = '#7c3aed', NEG_COLOR = '#dd1313';

  var stepLabel = root.querySelector('.apv-step-label');
  var playBtn = root.querySelector('.apv-btn-play');
  var scoreValue = root.querySelector('.apv-score-value');
  var scoreFill = root.querySelector('.apv-score-bar-fill');
  var marginalNote = root.querySelector('.apv-marginal-note');
  var coalitionLabel = root.querySelector('.apv-coalition-label');
  var svg = root.querySelector('.apv-bar-chart');

  // Feature elements never change, so look them up once.
  var featureEls = [];
  for (var fi = 0; fi < n; fi++) {{
    featureEls.push(root.querySelector('[data-idx="' + fi + '"]'));
  }}

  function escapeXml(s) {{
    return String(s).replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;');
  }}

  function truncate(s, maxLen, tail) {{
    return s.length > maxLen ? s.substring(0, maxLen) + tail : s;
  }}

  // Bars to draw: word-level shows the top-N by |final Shapley|; final
  // values never change, so this ordering is computed once.
  var displayIdx = [];
  if (isWordLevel) {{
    var order = DATA.final_shapley.map(function(v, i) {{ return [Math.abs(v), i]; }});
    order.sort(function(a, b) {{ return b[0] - a[0]; }});
    for (var oi = 0; oi < Math.min(barDisplayLimit, order.length); oi++) {{
      displayIdx.push(order[oi][1]);
    }}
  }} else {{
    for (var di = 0; di < n; di++) displayIdx.push(di);
  }}

  // Stable axis range over every value ever drawn (final targets and all
  // running values), so bars neither jump between steps nor overflow.
  var stableMin = 0, stableMax = 0;
  function widen(v) {{
    if (v < stableMin) stableMin = v;
    if (v > stableMax) stableMax = v;
  }}
  DATA.final_shapley.forEach(function(v) {{ widen(v); }});
  DATA.coalitions.forEach(function(c) {{
    c.running_shapley.forEach(function(v) {{ widen(v); }});
  }});
  var stablePad = Math.max(Math.abs(stableMax), Math.abs(stableMin)) * 0.25;
  if (stablePad < 1e-10) stablePad = 0.001;
  stableMin -= stablePad;
  stableMax += stablePad;

  function renderBars(running, finalVals, marginals) {{
    var dispN = displayIdx.length;
    var rowH = {_ROW_HEIGHT};
    var W = 340, H = 30 + dispN * rowH;
    svg.setAttribute('viewBox', '0 0 ' + W + ' ' + H);

    var vMin = stableMin, vMax = stableMax;
    var range = vMax - vMin;
    if (range < 1e-12) range = 0.002;
    var mL = 10, barArea = W - mL - 10;
    function toX(v) {{ return mL + ((v - vMin) / range) * barArea; }}
    var zeroX = toX(0);

    var out = '';
    out += '<line x1="' + zeroX + '" y1="5" x2="' + zeroX + '" y2="' + (H - 18) +
           '" stroke="#bbb" stroke-width="1"/>';

    for (var k = 0; k < dispN; k++) {{
      var i = displayIdx[k];
      var y = 15 + k * rowH;
      var barH = 20;
      var rv = running[i];
      var fv = finalVals[i];
      var isNew = i.toString() in marginals;

      // Ghost bar (dashed outline at the final Shapley target)
      var fvX = toX(fv);
      out += '<rect x="' + Math.min(zeroX, fvX) + '" y="' + y + '" width="' +
             Math.abs(fvX - zeroX) + '" height="' + barH + '" rx="3" fill="none" stroke="' +
             (fv >= 0 ? POS_COLOR : NEG_COLOR) + '" stroke-width="1" stroke-dasharray="4,3"/>';

      // Running bar
      var rvX = toX(rv);
      var bC = rv >= 0 ? POS_COLOR : NEG_COLOR;
      var op = (rv === 0 && !isNew) ? '0.1' : '0.75';
      out += '<rect x="' + Math.min(zeroX, rvX) + '" y="' + y + '" width="' +
             Math.abs(rvX - zeroX) + '" height="' + barH + '" rx="3" fill="' + bC +
             '" opacity="' + op + '"/>';

      // Highlight dot on features that contributed a marginal this step
      if (isNew) {{
        out += '<circle cx="' + rvX + '" cy="' + (y + barH / 2) +
               '" r="4" fill="#f59e0b" stroke="#fff" stroke-width="1.5"/>';
      }}

      // Label + value on the same line below the bar
      var label = escapeXml(truncate(DATA.features[i], 25, '..'));
      var prefix = isWordLevel ? ('W' + i) : ('S' + i);
      var labelY = y + barH + 13;
      out += '<text x="' + mL + '" y="' + labelY + '" font-size="11" fill="#444">' +
             prefix + ': ' + label + '</text>';
      var vT = (rv >= 0 ? '+' : '') + rv.toFixed(6);
      out += '<text x="' + (W - 10) + '" y="' + labelY +
             '" font-size="11" text-anchor="end" font-family="monospace" fill="' + bC + '">' +
             vT + '</text>';
    }}

    out += '<text x="' + (W / 2) + '" y="' + (H - 4) +
           '" font-size="10" text-anchor="middle" fill="#999">Dashed = final Shapley target' +
           (isWordLevel ? ' (top ' + barDisplayLimit + ' by |value|)' : '') + '</text>';
    svg.innerHTML = out;
  }}

  function render(step) {{
    var c = DATA.coalitions[step];
    stepLabel.textContent = 'Step ' + c.step + ' / ' + N_STEPS;
    if (isWordLevel) {{
      coalitionLabel.textContent = 'Words active: ' + c.active.length + ' / ' + n;
    }} else {{
      coalitionLabel.textContent = 'Coalition: (' + c.mask.join(', ') + ')';
    }}

    // Update features (inline tokens or block sentences)
    for (var i = 0; i < n; i++) {{
      var el = featureEls[i];
      if (!el) continue;
      if (isWordLevel) {{
        if (i.toString() in c.marginals_this_step) {{
          el.className = 'apv-token apv-token-just-added';
        }} else if (c.mask[i]) {{
          el.className = 'apv-token apv-token-active';
        }} else {{
          el.className = 'apv-token apv-token-masked';
        }}
      }} else {{
        el.className = c.mask[i] ? 'apv-sentence apv-sentence-active'
                                 : 'apv-sentence apv-sentence-masked';
      }}
    }}

    // f(S) score
    var score = c.f_score;
    var color = score >= 0 ? POS_COLOR : NEG_COLOR;
    scoreValue.textContent = score.toFixed(6);
    scoreValue.style.color = color;
    var range = DATA.f_max - DATA.f_min;
    var pct = range > 0 ? ((score - DATA.f_min) / range) * 100 : 50;
    scoreFill.style.width = Math.max(2, pct) + '%';
    scoreFill.style.background = color;

    renderBars(c.running_shapley, DATA.final_shapley, c.marginals_this_step);

    // Marginal note (textContent, so no escaping needed)
    var keys = Object.keys(c.marginals_this_step);
    if (keys.length > 0) {{
      var parts = keys.slice(0, 3).map(function(k) {{
        var v = c.marginals_this_step[k];
        var sign = v >= 0 ? '+' : '';
        var label = truncate(DATA.features[parseInt(k, 10)], 20, '...');
        return '"' + label + '": ' + sign + v.toFixed(6);
      }});
      marginalNote.textContent = 'Marginal: ' + parts.join('; ');
    }} else {{
      marginalNote.textContent = 'No active features \\u2014 baseline score.';
    }}
  }}

  function stopPlaying() {{
    if (playTimer) {{
      clearInterval(playTimer);
      playTimer = null;
    }}
    playBtn.textContent = '\\u25B6 Play';
  }}

  // Controls
  root.querySelector('.apv-btn-reset').addEventListener('click', function() {{
    stopPlaying();
    currentStep = 0;
    render(currentStep);
  }});
  root.querySelector('.apv-btn-prev').addEventListener('click', function() {{
    currentStep = Math.max(0, currentStep - 1);
    render(currentStep);
  }});
  root.querySelector('.apv-btn-next').addEventListener('click', function() {{
    currentStep = Math.min(N_STEPS - 1, currentStep + 1);
    render(currentStep);
  }});
  playBtn.addEventListener('click', function() {{
    if (playTimer) {{
      stopPlaying();
      return;
    }}
    playBtn.textContent = '\\u23F8 Pause';
    var interval = isWordLevel ? 400 : 1500;
    playTimer = setInterval(function() {{
      // Stop if our root element was removed from the DOM (dataset switched)
      if (!document.getElementById('{root_id}')) {{
        clearInterval(playTimer);
        playTimer = null;
        return;
      }}
      currentStep = (currentStep + 1) % N_STEPS;
      render(currentStep);
      if (currentStep === N_STEPS - 1) stopPlaying();
    }}, interval);
  }});

  render(currentStep);
}})();
"""


def _build_feature_html(features: List[str], is_word_level: bool) -> str:
    """Return the feature markup: inline tokens (word-level) or sentence blocks."""
    if is_word_level:
        tokens = "".join(
            f'<span class="apv-token apv-token-masked" data-idx="{i}">'
            f"{_escape_html(feat)}</span>"
            for i, feat in enumerate(features)
        )
        return f'<div class="apv-tokens-wrap">{tokens}</div>'
    return "".join(
        f'<div class="apv-sentence apv-sentence-masked" data-idx="{i}">'
        f'<span class="apv-sentence-label">S{i}</span>{_escape_html(feat)}</div>\n'
        for i, feat in enumerate(features)
    )


def render_coalition_viewer_html(
    data_bundle: Dict[str, Any],
    highlighted_step: int = 0,
) -> str:
    """Render a self-contained HTML+CSS+JS coalition step viewer.

    Args:
        data_bundle: Output of ``compute_coalition_viewer_data``.
        highlighted_step: 0-based step shown initially; clamped into range.

    Returns:
        An HTML fragment. The JS runs either as a normal <script> or, in hosts
        that do not execute inserted scripts, via the hidden loader image's
        ``onerror`` handler; a ready-flag on the root prevents double init.
    """
    uid = uuid.uuid4().hex[:8]
    script_id = f"apv-script-{uid}"
    loader_id = f"apv-loader-{uid}"
    data_id = f"apv-data-{uid}"
    root_id = f"apv-root-{uid}"

    features = data_bundle["features"]
    n = len(features)
    n_steps = len(data_bundle["coalitions"])
    mode = data_bundle.get("mode", "exhaustive")
    # Follow the bundle's mode; fall back to the size rule for bundles built
    # without a "mode" key.
    if "mode" in data_bundle:
        is_word_level = mode == "sequential"
    else:
        is_word_level = n > _EXHAUSTIVE_MAX_FEATURES

    # An out-of-range step would make DATA.coalitions[step] undefined in JS.
    start_step = min(max(int(highlighted_step), 0), max(n_steps - 1, 0))

    # Word-level shows only the top-N Shapley bars rather than all n.
    bar_display_limit = min(n, _BAR_DISPLAY_LIMIT) if is_word_level else n
    svg_height = 30 + bar_display_limit * _ROW_HEIGHT

    css = _build_css(root_id)
    js_code = _build_js(
        root_id=root_id,
        data_id=data_id,
        start_step=start_step,
        is_word_level=is_word_level,
        bar_display_limit=bar_display_limit,
    )
    data_json = _json_for_script(data_bundle)
    feature_html = _build_feature_html(features, is_word_level)
    mode_label = (
        "Sequential Build-up" if mode == "sequential"
        else f"All 2\u207f = {n_steps} Coalitions"
    )
    loader_js = (
        f"this.onerror=null;var s=document.getElementById('{script_id}');"
        f"if(s){{(new Function(s.textContent))();}}"
    )

    parts = [
        f'<div id="{root_id}">',
        f"<style>{css}</style>",
        '<div class="apv-header">',
        '<div class="apv-title">',
        "Attribution Process Viewer ",
        f'<span class="apv-mode-badge">{mode_label}</span>',
        "</div>",
        '<div class="apv-controls">',
        '<button type="button" class="apv-btn apv-btn-reset">\u21BA Reset</button>',
        '<button type="button" class="apv-btn apv-btn-prev">\u25C0 Prev</button>',
        '<button type="button" class="apv-btn apv-btn-play">\u25B6 Play</button>',
        '<button type="button" class="apv-btn apv-btn-next">Next \u25B6</button>',
        f'<span class="apv-step-label">Step {start_step + 1} / {n_steps}</span>',
        "</div>",
        "</div>",
        '<div class="apv-body">',
        '<div class="apv-text-panel">',
        '<div class="apv-section-title">Masked Context</div>',
        feature_html,
        '<div class="apv-score-row">',
        '<span class="apv-score-label">f(S) =</span>',
        '<span class="apv-score-value">0.000000</span>',
        '<div class="apv-score-bar-track"><div class="apv-score-bar-fill"></div></div>',
        "</div>",
        '<div class="apv-coalition-label"></div>',
        '<div class="apv-marginal-note"></div>',
        "</div>",
        '<div class="apv-shapley-panel">',
        '<div class="apv-section-title">Running Shapley Values</div>',
        f'<svg class="apv-bar-chart" xmlns="http://www.w3.org/2000/svg" '
        f'viewBox="0 0 340 {svg_height}"></svg>',
        "</div>",
        "</div>",
        f'<script type="application/json" id="{data_id}">{data_json}</script>',
        f'<script id="{script_id}">{js_code}</script>',
        f'<img id="{loader_id}" class="apv-loader" src="data:," alt="" '
        f'onerror="{loader_js}">',
        "</div>",
    ]
    return "".join(parts)