# AttrLLM / visualization / plotting / coalition_viewer.py
# Qingpeng Kong
# clean initial state
# 3e72399
"""
coalition_viewer.py — Step-through attribution process viewer.
For small n (≤ 10): Exhaustive 2^n coalitions sorted by popcount.
For large n (> 10): Sequential build-up — add one feature at a time
(left-to-right), showing n+1 steps.
Renders an interactive HTML+CSS+JS widget that lets users step through
each coalition to see:
- Which features are active (visible) vs masked
- The f(S) score for that coalition
- Running Shapley values converging toward their final values (in sequential
  mode, each feature's final Shapley value is revealed when it is added)
"""
from __future__ import annotations
import json
import uuid
from itertools import product
from typing import Any, Dict, List, Optional, Tuple
# ---------------------------------------------------------------------------
# Data computation
# ---------------------------------------------------------------------------
# Largest feature count the viewer supports; compute_coalition_viewer_data
# returns None for inputs with more features than this.
_MAX_FEATURES = 100 # hard cap for safety
def _parse_mobius_key(key: str) -> Tuple[int, ...]:
"""Parse a Mobius dict key like '0,2' → (0, 2), '' → ()."""
if not key.strip():
return ()
return tuple(int(x) for x in key.split(","))
def _is_subset(sub: Tuple[int, ...], sup_set: set) -> bool:
"""Check if every element of *sub* appears in *sup_set*."""
return all(s in sup_set for s in sub)
def _f_score_from_mobius(
active: Tuple[int, ...],
mobius: Dict[Tuple[int, ...], float],
) -> float:
"""Compute f(S) = Σ_{T ⊆ S} m(T) from the Mobius dict."""
active_set = set(active)
total = 0.0
for t_key, m_val in mobius.items():
if _is_subset(t_key, active_set):
total += m_val
return total
def _compute_final_shapley(
n: int,
mobius: Dict[Tuple[int, ...], float],
) -> List[float]:
"""φ_i = Σ_{T∋i} m(T) / |T|"""
shapley = [0.0] * n
for t_key, m_val in mobius.items():
if len(t_key) == 0:
continue
share = m_val / len(t_key)
for i in t_key:
if i < n:
shapley[i] += share
return shapley
def _build_exhaustive_coalitions(
    n: int,
    features: List[str],
    mobius: Dict[Tuple[int, ...], float],
) -> List[Dict[str, Any]]:
    """Enumerate all 2^n coalitions, sorted by size (popcount) then lexicographically.

    Each step records the coalition mask, its active feature indices, the text
    of the active features, f(S), the marginal contribution
    f(S) - f(S without i) of every active feature i, and a running Shapley
    estimate for every feature.

    The running estimate is the Shapley-*weighted* mean of the marginals seen
    so far. For a coalition S of size k that contains i, the Shapley weight is
    (k-1)! (n-k)! / n!. These weights sum to 1 over all coalitions containing
    i, so after the final step the estimate equals the exact Shapley value
    (Σ_{T∋i} m(T)/|T|, the dashed target drawn by the viewer). A plain
    unweighted mean of marginals would converge to the Banzhaf value instead,
    which generally differs from the Shapley value once n ≥ 3.

    Args:
        n: Number of features.
        features: Feature texts, indexed 0..n-1.
        mobius: Mobius coefficients keyed by tuples of feature indices.

    Returns:
        One dict per coalition with keys ``step`` (1-based), ``mask``,
        ``active``, ``masked_text``, ``f_score``, ``running_shapley`` and
        ``marginals_this_step`` (stringified feature index -> marginal).
    """
    from math import factorial

    all_masks = sorted(product((0, 1), repeat=n), key=lambda m: (sum(m), m))

    # f(S) for every coalition, computed once so marginals are dict lookups.
    scores: Dict[Tuple[int, ...], float] = {
        mask: _f_score_from_mobius(
            tuple(i for i, bit in enumerate(mask) if bit), mobius
        )
        for mask in all_masks
    }

    # Shapley weight indexed by coalition size k (feature i included), k = 1..n.
    # Index 0 is unused: the empty coalition contains no feature.
    n_fact = factorial(n)
    size_weight = [0.0] + [
        factorial(k - 1) * factorial(n - k) / n_fact for k in range(1, n + 1)
    ]

    weighted_sum = [0.0] * n  # Σ weight * marginal over coalitions seen so far
    weight_seen = [0.0] * n  # Σ weight over those same coalitions

    coalitions: List[Dict[str, Any]] = []
    for step_idx, mask in enumerate(all_masks):
        active = [i for i, bit in enumerate(mask) if bit]
        # NOTE: despite the key name, "masked_text" holds the text of the
        # features that remain visible (active) in this coalition.
        visible_parts = [features[i] for i in active]
        masked_text = " ".join(visible_parts) if visible_parts else "[ALL MASKED]"
        score = scores[mask]
        weight = size_weight[len(active)]

        marginals_this_step: Dict[int, float] = {}
        for i in active:
            mask_without = list(mask)
            mask_without[i] = 0
            marginal = score - scores[tuple(mask_without)]
            marginals_this_step[i] = marginal
            weighted_sum[i] += weight * marginal
            weight_seen[i] += weight

        # Normalising by the weight seen so far keeps early estimates on the
        # scale of a marginal; at the last step the total weight is 1.
        running_shapley = [
            weighted_sum[i] / weight_seen[i] if weight_seen[i] > 0 else 0.0
            for i in range(n)
        ]
        coalitions.append({
            "step": step_idx + 1,
            "mask": list(mask),
            "active": active,
            "masked_text": masked_text,
            "f_score": score,
            "running_shapley": running_shapley,
            "marginals_this_step": {str(k): v for k, v in marginals_this_step.items()},
        })
    return coalitions
def _build_sequential_coalitions(
    n: int,
    features: List[str],
    mobius: Dict[Tuple[int, ...], float],
    final_shapley: List[float],
) -> List[Dict[str, Any]]:
    """Sequential build-up: add features left to right, one per step (n+1 steps).

    Step 0 is the empty coalition; step k activates features 0..k-1. The
    marginal recorded at each step belongs to the feature just added along
    this single left-to-right ordering.

    Running Shapley values are the *final* Shapley values revealed
    progressively: each feature's bar appears when that feature is added and
    features not yet added show 0. This avoids the problem of one-permutation
    marginals having opposite signs to the true Shapley values.

    Args:
        n: Number of features.
        features: Feature texts, indexed 0..n-1.
        mobius: Mobius coefficients keyed by tuples of feature indices.
        final_shapley: Exact Shapley value per feature.

    Returns:
        One dict per step with keys ``step`` (1-based), ``mask``, ``active``,
        ``masked_text``, ``f_score``, ``running_shapley`` and
        ``marginals_this_step`` (stringified feature index -> marginal).
    """
    coalitions: List[Dict[str, Any]] = []
    prev_score = 0.0
    for step_idx in range(n + 1):
        # At step k the first k features (indices 0..k-1) are active.
        current_active = list(range(step_idx))
        mask = [1] * step_idx + [0] * (n - step_idx)
        score = _f_score_from_mobius(tuple(current_active), mobius)
        # "masked_text" holds the text of the currently visible features.
        visible_parts = [features[i] for i in current_active]
        masked_text = " ".join(visible_parts) if visible_parts else "[ALL MASKED]"

        marginals_this_step: Dict[int, float] = {}
        if step_idx > 0:
            # prev_score is f of the previous step's coalition, i.e. this
            # coalition minus the feature just added; reusing it avoids a
            # second scan of the Mobius dict per step.
            marginals_this_step[step_idx - 1] = score - prev_score
        prev_score = score

        running_shapley = [
            final_shapley[i] if i < step_idx else 0.0 for i in range(n)
        ]
        coalitions.append({
            "step": step_idx + 1,
            "mask": mask,
            "active": current_active,
            "masked_text": masked_text,
            "f_score": score,
            "running_shapley": running_shapley,
            "marginals_this_step": {str(k): v for k, v in marginals_this_step.items()},
        })
    return coalitions
def compute_coalition_viewer_data(
    features: List[str],
    mobius_dict_raw: Dict[str, float],
) -> Optional[Dict[str, Any]]:
    """Assemble everything the coalition step-through viewer needs.

    Mode selection by feature count n:
    - n ≤ 10: every one of the 2^n coalitions ("exhaustive")
    - 10 < n ≤ 100: left-to-right build-up with n+1 steps ("sequential")
    - n == 0 or n > 100: ``None`` (nothing to show / safety cap)

    *mobius_dict_raw* maps comma-separated index strings (``"0,2"``, or
    ``""`` for the empty set) to Mobius coefficients.
    """
    n = len(features)
    if not n or n > _MAX_FEATURES:
        return None

    mobius: Dict[Tuple[int, ...], float] = {
        _parse_mobius_key(raw_key): coeff
        for raw_key, coeff in mobius_dict_raw.items()
    }
    final_shapley = _compute_final_shapley(n, mobius)

    if n > 10:
        mode = "sequential"
        coalitions = _build_sequential_coalitions(n, features, mobius, final_shapley)
    else:
        mode = "exhaustive"
        coalitions = _build_exhaustive_coalitions(n, features, mobius)

    f_scores = [entry["f_score"] for entry in coalitions]
    return {
        "features": features,
        "coalitions": coalitions,
        "final_shapley": final_shapley,
        "f_min": min(f_scores),
        "f_max": max(f_scores),
        "shapley_min": min(final_shapley),
        "shapley_max": max(final_shapley),
        "mode": mode,
    }
# ---------------------------------------------------------------------------
# HTML rendering
# ---------------------------------------------------------------------------
def render_coalition_viewer_html(
    data_bundle: Dict[str, Any],
    highlighted_step: int = 0,
) -> str:
    """Render a self-contained HTML+CSS+JS coalition step viewer.

    Args:
        data_bundle: Bundle as produced by ``compute_coalition_viewer_data``.
            This function reads ``features``, ``coalitions`` and ``mode``;
            the embedded script additionally reads ``final_shapley``,
            ``f_min``, ``f_max`` and the per-step coalition fields.
        highlighted_step: 0-based coalition index shown initially. Values
            outside ``[0, len(coalitions) - 1]`` are clamped into range so the
            widget never tries to render a non-existent step.

    Returns:
        An HTML fragment. Element ids carry a random uuid suffix so that ids
        do not collide between widgets. The JavaScript is stored in a
        ``text/plain`` script element and executed from an ``<img onload>``
        handler via ``new Function``.

    Feature text is HTML-escaped (quotes included, since it also goes into
    ``title`` attributes), and the JSON payload has ``<``, ``>`` and ``&``
    replaced by JSON unicode escapes (e.g. ``\\u003c``) so text such as
    ``</script>`` cannot terminate the embedding script element early.
    """
    # Local import keeps this block self-contained (no new module-level import).
    from html import escape as _escape_html

    uid = uuid.uuid4().hex[:8]
    script_id = f"apv-script-{uid}"
    loader_id = f"apv-loader-{uid}"
    data_id = f"apv-data-{uid}"
    root_id = f"apv-root-{uid}"
    features = data_bundle["features"]
    n = len(features)
    n_steps = len(data_bundle["coalitions"])
    mode = data_bundle.get("mode", "exhaustive")
    is_word_level = n > 10

    # Clamp the initial step so DATA.coalitions[currentStep] always exists.
    start_step = min(max(int(highlighted_step), 0), max(n_steps - 1, 0))

    # json.dumps leaves '<', '>' and '&' as-is; inside a <script> element a
    # feature such as "</script>" would close the element and the remainder
    # would be parsed as live HTML. \uXXXX escapes are valid JSON, so
    # JSON.parse in the browser restores the original text.
    data_json = (
        json.dumps(data_bundle)
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )

    # For word-level, use inline token spans instead of block divs
    # and a horizontal bar chart layout for Shapley
    css = f"""
#{root_id} {{
  font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
  background: #f8f9fb;
  border: 1px solid #e0e0e0;
  border-radius: 12px;
  padding: 20px;
  color: #1a1a2e;
  max-width: 100%;
  overflow: hidden;
}}
#{root_id} .apv-header {{
  display: flex;
  align-items: center;
  justify-content: space-between;
  margin-bottom: 16px;
  flex-wrap: wrap;
  gap: 8px;
}}
#{root_id} .apv-title {{
  font-size: 16px;
  font-weight: 700;
  color: #1a1a2e;
}}
#{root_id} .apv-mode-badge {{
  font-size: 11px;
  padding: 2px 8px;
  border-radius: 10px;
  background: #e8e0f0;
  color: #5b2d98;
  font-weight: 600;
}}
#{root_id} .apv-step-label {{
  font-size: 13px;
  color: #666;
  min-width: 80px;
  text-align: center;
}}
#{root_id} .apv-controls {{
  display: flex;
  align-items: center;
  gap: 6px;
}}
#{root_id} .apv-btn {{
  background: #fff;
  border: 1px solid #ccc;
  color: #333;
  border-radius: 6px;
  padding: 5px 12px;
  cursor: pointer;
  font-size: 13px;
  transition: background 0.15s;
}}
#{root_id} .apv-btn:hover {{
  background: #f0ecf5;
  border-color: #7c3aed;
}}
#{root_id} .apv-btn-play {{
  background: #7c3aed;
  border-color: #7c3aed;
  color: #fff;
  min-width: 70px;
}}
#{root_id} .apv-btn-play:hover {{
  background: #6b2ed0;
}}
#{root_id} .apv-body {{
  display: flex;
  gap: 20px;
  flex-wrap: wrap;
}}
#{root_id} .apv-text-panel {{
  flex: 1.2;
  min-width: 280px;
}}
#{root_id} .apv-shapley-panel {{
  flex: 1;
  min-width: 280px;
}}
#{root_id} .apv-section-title {{
  font-size: 13px;
  font-weight: 600;
  color: #888;
  margin-bottom: 10px;
  text-transform: uppercase;
  letter-spacing: 0.5px;
}}
/* Block-level features (sentence/paragraph) */
#{root_id} .apv-sentence {{
  padding: 8px 12px;
  margin-bottom: 6px;
  border-radius: 6px;
  font-size: 13px;
  line-height: 1.5;
  transition: all 0.35s ease;
  border-left: 3px solid transparent;
}}
#{root_id} .apv-sentence-active {{
  background: rgba(124, 58, 237, 0.1);
  border-left-color: #7c3aed;
  color: #1a1a2e;
  opacity: 1;
}}
#{root_id} .apv-sentence-masked {{
  background: rgba(200, 200, 200, 0.2);
  color: #bbb;
  opacity: 0.45;
  text-decoration: line-through;
  text-decoration-color: rgba(180, 180, 180, 0.5);
}}
#{root_id} .apv-sentence-label {{
  font-size: 10px;
  font-weight: 700;
  color: #999;
  margin-right: 6px;
}}
/* Inline word tokens */
#{root_id} .apv-tokens-wrap {{
  line-height: 2.0;
  padding: 8px;
  background: #fff;
  border: 1px solid #e8e8e8;
  border-radius: 8px;
  max-height: 260px;
  overflow-y: auto;
}}
#{root_id} .apv-token {{
  display: inline-block;
  padding: 2px 5px;
  margin: 2px 1px;
  border-radius: 4px;
  font-size: 13px;
  transition: all 0.25s ease;
  cursor: default;
}}
#{root_id} .apv-token-active {{
  background: rgba(124, 58, 237, 0.15);
  color: #1a1a2e;
  font-weight: 500;
  box-shadow: 0 0 0 1px rgba(124, 58, 237, 0.3);
}}
#{root_id} .apv-token-masked {{
  background: transparent;
  color: #ccc;
  text-decoration: line-through;
  text-decoration-color: rgba(180, 180, 180, 0.5);
}}
#{root_id} .apv-token-just-added {{
  background: rgba(124, 58, 237, 0.3);
  color: #1a1a2e;
  font-weight: 700;
  box-shadow: 0 0 0 2px #7c3aed;
}}
#{root_id} .apv-score-row {{
  display: flex;
  align-items: center;
  gap: 12px;
  margin-top: 14px;
  padding: 10px 12px;
  background: #f0ecf5;
  border-radius: 8px;
}}
#{root_id} .apv-score-label {{
  font-size: 12px;
  color: #666;
  white-space: nowrap;
}}
#{root_id} .apv-score-value {{
  font-size: 18px;
  font-weight: 700;
  font-family: 'SF Mono', 'Fira Code', monospace;
  min-width: 100px;
}}
#{root_id} .apv-score-bar-track {{
  flex: 1;
  height: 8px;
  background: #e0dce8;
  border-radius: 4px;
  overflow: hidden;
}}
#{root_id} .apv-score-bar-fill {{
  height: 100%;
  border-radius: 4px;
  transition: width 0.35s ease, background 0.35s ease;
}}
#{root_id} .apv-bar-chart {{
  width: 100%;
}}
#{root_id} .apv-marginal-note {{
  font-size: 14px;
  color: #555;
  margin-top: 10px;
  min-height: 20px;
  font-style: italic;
  line-height: 1.4;
}}
#{root_id} .apv-coalition-label {{
  font-size: 14px;
  font-weight: 600;
  color: #444;
  margin-top: 10px;
  font-family: 'SF Mono', 'Fira Code', monospace;
}}
#{root_id} .apv-btn-reset {{
  background: #fff;
  border: 1px solid #ccc;
  color: #888;
  font-size: 12px;
}}
#{root_id} .apv-btn-reset:hover {{
  background: #f0ecf5;
  border-color: #7c3aed;
  color: #333;
}}
#{root_id} .apv-loader {{
  width: 1px;
  height: 1px;
  position: absolute;
  opacity: 0;
}}
"""
    # For word-level, show top-N Shapley bars rather than all n bars
    # and use inline tokens instead of block divs
    bar_display_limit = min(n, 15) if is_word_level else n
    js_code = f"""
(function() {{
  var root = document.getElementById('{root_id}');
  if (!root) return;
  var DATA = JSON.parse(document.getElementById('{data_id}').textContent);
  var N_STEPS = DATA.coalitions.length;
  var currentStep = {start_step};
  var playTimer = null;
  var n = DATA.features.length;
  var isWordLevel = {str(is_word_level).lower()};
  var barDisplayLimit = {bar_display_limit};
  var stepLabel = root.querySelector('.apv-step-label');
  var playBtn = root.querySelector('.apv-btn-play');
  var scoreValue = root.querySelector('.apv-score-value');
  var scoreFill = root.querySelector('.apv-score-bar-fill');
  var marginalNote = root.querySelector('.apv-marginal-note');
  var coalitionLabel = root.querySelector('.apv-coalition-label');
  var svg = root.querySelector('.apv-bar-chart');

  function render(step) {{
    var c = DATA.coalitions[step];
    stepLabel.textContent = 'Step ' + c.step + ' / ' + N_STEPS;
    if (isWordLevel) {{
      coalitionLabel.textContent = 'Words active: ' + c.active.length + ' / ' + n;
    }} else {{
      coalitionLabel.textContent = 'Coalition: (' + c.mask.join(', ') + ')';
    }}
    // Update features (inline tokens or block sentences)
    for (var i = 0; i < n; i++) {{
      var el = root.querySelector('[data-idx="' + i + '"]');
      if (!el) continue;
      if (isWordLevel) {{
        // Check if this token was JUST added this step
        if (i.toString() in c.marginals_this_step) {{
          el.className = 'apv-token apv-token-just-added';
        }} else if (c.mask[i]) {{
          el.className = 'apv-token apv-token-active';
        }} else {{
          el.className = 'apv-token apv-token-masked';
        }}
      }} else {{
        if (c.mask[i]) {{
          el.className = 'apv-sentence apv-sentence-active';
        }} else {{
          el.className = 'apv-sentence apv-sentence-masked';
        }}
      }}
    }}
    // f(S) score
    var score = c.f_score;
    var color = score >= 0 ? '#7c3aed' : '#dd1313';
    scoreValue.textContent = score.toFixed(6);
    scoreValue.style.color = color;
    var range = DATA.f_max - DATA.f_min;
    var pct = range > 0 ? ((score - DATA.f_min) / range) * 100 : 50;
    scoreFill.style.width = Math.max(2, pct) + '%';
    scoreFill.style.background = color;
    // Shapley bar chart
    renderBars(c.running_shapley, DATA.final_shapley, c.marginals_this_step);
    // Marginal note
    var keys = Object.keys(c.marginals_this_step);
    if (keys.length > 0) {{
      var parts = keys.slice(0, 3).map(function(k) {{
        var v = c.marginals_this_step[k];
        var sign = v >= 0 ? '+' : '';
        var label = DATA.features[parseInt(k)];
        if (label.length > 20) label = label.substring(0, 20) + '...';
        return '"' + label + '": ' + sign + v.toFixed(6);
      }});
      marginalNote.textContent = 'Marginal: ' + parts.join('; ');
    }} else {{
      marginalNote.textContent = 'No active features \\u2014 baseline score.';
    }}
  }}

  // Precompute stable axis range from ALL final Shapley values (never changes between steps)
  var stableMin = 0, stableMax = 0;
  for (var i = 0; i < n; i++) {{
    if (DATA.final_shapley[i] < stableMin) stableMin = DATA.final_shapley[i];
    if (DATA.final_shapley[i] > stableMax) stableMax = DATA.final_shapley[i];
  }}
  var stablePad = Math.max(Math.abs(stableMax), Math.abs(stableMin)) * 0.25;
  if (stablePad < 1e-10) stablePad = 0.001;
  stableMin = stableMin - stablePad;
  stableMax = stableMax + stablePad;

  function renderBars(running, final_vals, marginals) {{
    // For word-level, show the top-N by |final Shapley| value
    var indices = [];
    if (isWordLevel) {{
      var sorted = final_vals.map(function(v, i) {{ return [Math.abs(v), i]; }});
      sorted.sort(function(a, b) {{ return b[0] - a[0]; }});
      for (var k = 0; k < Math.min(barDisplayLimit, sorted.length); k++) {{
        indices.push(sorted[k][1]);
      }}
    }} else {{
      for (var k = 0; k < n; k++) indices.push(k);
    }}
    var dispN = indices.length;
    var rowH = 42;
    var W = 340, H = 30 + dispN * rowH;
    svg.setAttribute('viewBox', '0 0 ' + W + ' ' + H);
    // Use stable axis range so bars don't jump between steps
    var vMin = stableMin;
    var vMax = stableMax;
    var range = vMax - vMin;
    if (range < 1e-12) range = 0.002;
    var mL = 10, barArea = W - mL - 10;
    var zeroX = mL + ((-vMin) / range) * barArea;
    var html = '';
    html += '<line x1="' + zeroX + '" y1="5" x2="' + zeroX + '" y2="' + (H - 5) + '" stroke="#ccc" stroke-width="1" stroke-dasharray="3,3"/>';
    for (var k = 0; k < dispN; k++) {{
      var i = indices[k];
      var y = 15 + k * rowH;
      var barH = 20;
      var rv = running[i];
      var fv = final_vals[i];
      // Ghost bar (final target)
      var fvX = mL + ((fv - vMin) / range) * barArea;
      var gX = Math.min(zeroX, fvX);
      var gW = Math.abs(fvX - zeroX);
      var gC = fv >= 0 ? '#7c3aed' : '#dd1313';
      html += '<rect x="' + gX + '" y="' + y + '" width="' + Math.max(gW, 0.5) + '" height="' + barH + '" fill="none" stroke="' + gC + '" stroke-width="1.5" stroke-dasharray="4,3" opacity="0.35" rx="3"/>';
      // Running bar
      var rvX = mL + ((rv - vMin) / range) * barArea;
      var bX = Math.min(zeroX, rvX);
      var bW = Math.abs(rvX - zeroX);
      var bC = rv >= 0 ? '#7c3aed' : '#dd1313';
      var op = (rv === 0 && !(i.toString() in marginals)) ? '0.1' : '0.75';
      html += '<rect x="' + bX + '" y="' + y + '" width="' + Math.max(bW, 0.5) + '" height="' + barH + '" fill="' + bC + '" opacity="' + op + '" rx="3"/>';
      // Highlight dot
      if (i.toString() in marginals) {{
        html += '<circle cx="' + rvX + '" cy="' + (y + barH / 2) + '" r="3.5" fill="#7c3aed" stroke="#fff" stroke-width="1"/>';
      }}
      // Label + Value on the SAME line below the bar
      var label = DATA.features[i];
      if (label.length > 25) label = label.substring(0, 25) + '..';
      // Escape '&' first so the entities produced below are not double-escaped.
      label = label.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;');
      var prefix = isWordLevel ? ('W' + i) : ('S' + i);
      var labelY = y + barH + 13;
      html += '<text x="' + mL + '" y="' + labelY + '" fill="#666" font-size="10" font-weight="500">' + prefix + ': ' + label + '</text>';
      var sign = rv >= 0 ? '+' : '';
      var vT = sign + rv.toFixed(6);
      html += '<text x="' + (W - 10) + '" y="' + labelY + '" fill="#555" font-size="10" font-family="monospace" text-anchor="end">' + vT + '</text>';
    }}
    html += '<text x="' + (W - 10) + '" y="' + (H - 2) + '" fill="#aaa" font-size="8" text-anchor="end">Dashed = final Shapley target'
      + (isWordLevel ? ' (top ' + barDisplayLimit + ' by |value|)' : '')
      + '</text>';
    svg.innerHTML = html;
  }}

  // Controls
  root.querySelector('.apv-btn-reset').addEventListener('click', function() {{
    if (playTimer) {{
      clearInterval(playTimer);
      playTimer = null;
      playBtn.textContent = '\\u25B6 Play';
    }}
    currentStep = 0;
    render(currentStep);
  }});
  root.querySelector('.apv-btn-prev').addEventListener('click', function() {{
    currentStep = Math.max(0, currentStep - 1);
    render(currentStep);
  }});
  root.querySelector('.apv-btn-next').addEventListener('click', function() {{
    currentStep = Math.min(N_STEPS - 1, currentStep + 1);
    render(currentStep);
  }});
  playBtn.addEventListener('click', function() {{
    if (playTimer) {{
      clearInterval(playTimer);
      playTimer = null;
      playBtn.textContent = '\\u25B6 Play';
    }} else {{
      playBtn.textContent = '\\u23F8 Pause';
      var interval = isWordLevel ? 400 : 1500;
      playTimer = setInterval(function() {{
        // Stop if our root element was removed from DOM (dataset switched)
        if (!document.getElementById('{root_id}')) {{
          clearInterval(playTimer);
          playTimer = null;
          return;
        }}
        currentStep = (currentStep + 1) % N_STEPS;
        render(currentStep);
        if (currentStep === N_STEPS - 1) {{
          clearInterval(playTimer);
          playTimer = null;
          playBtn.textContent = '\\u25B6 Play';
        }}
      }}, interval);
    }}
  }});
  render(currentStep);
}})();
"""
    # Build feature elements — inline tokens for word-level, block divs otherwise.
    # html.escape(quote=True) also escapes quotes, which matters for title="...".
    if is_word_level:
        token_spans = []
        for i, feat in enumerate(features):
            escaped = _escape_html(feat)
            token_spans.append(
                f'<span class="apv-token apv-token-masked" data-idx="{i}"'
                f' title="W{i}: {escaped}">{escaped}</span>'
            )
        feature_html = '<div class="apv-tokens-wrap">' + "".join(token_spans) + '</div>'
    else:
        sentence_divs = []
        for i, feat in enumerate(features):
            escaped = _escape_html(feat)
            sentence_divs.append(
                f'<div class="apv-sentence apv-sentence-masked" data-idx="{i}">'
                f'<span class="apv-sentence-label">S{i}</span>{escaped}</div>\n'
            )
        feature_html = "".join(sentence_divs)

    svg_height = 30 + bar_display_limit * 42
    mode_label = "Sequential Build-up" if mode == "sequential" else f"All 2\u207f = {n_steps} Coalitions"
    widget_html = (
        f'<div id="{root_id}">'
        f"<style>{css}</style>"
        f'<div class="apv-header">'
        f'<div>'
        f'<span class="apv-title">Attribution Process Viewer</span> '
        f'<span class="apv-mode-badge">{mode_label}</span>'
        f'</div>'
        f'<div class="apv-controls">'
        f'<button class="apv-btn apv-btn-reset" title="Reset to step 1">&#8634; Reset</button>'
        f'<button class="apv-btn apv-btn-prev">&#9664; Prev</button>'
        f'<button class="apv-btn apv-btn-play">&#9654; Play</button>'
        f'<button class="apv-btn apv-btn-next">Next &#9654;</button>'
        f'<span class="apv-step-label">Step 1 / {n_steps}</span>'
        f'</div></div>'
        f'<div class="apv-body">'
        f'<div class="apv-text-panel">'
        f'<div class="apv-section-title">Masked Context</div>'
        f'{feature_html}'
        f'<div class="apv-score-row">'
        f'<span class="apv-score-label">f(S) =</span>'
        f'<span class="apv-score-value">0.000000</span>'
        f'<div class="apv-score-bar-track"><div class="apv-score-bar-fill" style="width:0%"></div></div>'
        f'</div>'
        f'<div class="apv-coalition-label"></div>'
        f'<div class="apv-marginal-note"></div>'
        f'</div>'
        f'<div class="apv-shapley-panel">'
        f'<div class="apv-section-title">Running Shapley Values</div>'
        f'<svg class="apv-bar-chart" viewBox="0 0 340 {svg_height}" '
        f'xmlns="http://www.w3.org/2000/svg"></svg>'
        f'</div>'
        f'</div>'
        f'<script type="text/plain" id="{data_id}">{data_json}</script>'
        f'<img class="apv-loader" id="{loader_id}" alt="" '
        'src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///ywAAAAAAQABAAACAUwAOw==" '
        f'onload="(function(){{var s=document.getElementById(\'{script_id}\');'
        'if(!s||!s.textContent){return;}try{(new Function(s.textContent))();}catch(e){'
        'console.warn(\'APV init failed\',e);}})()" />'
        f'<script type="text/plain" id="{script_id}">{js_code}</script>'
        f'</div>'
    )
    return widget_html