from __future__ import annotations

import json
import math
import uuid
from html import escape
from typing import Any, Dict, List, Sequence
# Endpoints of the diverging palette, each an (R, G, B) triple of 0-255 ints:
# most negative score, most positive score, and the zero/neutral midpoint.
_NEGATIVE_RGB = (221, 19, 19)
_POSITIVE_RGB = (74, 28, 135)
_NEUTRAL_RGB = (235, 228, 244)


def _interpolated_rgb(value: float, max_abs: float) -> tuple[int, int, int]:
    """Map a signed score onto the negative -> neutral -> positive palette.

    Args:
        value: Score to colour.
        max_abs: Normalisation scale; ``value / max_abs`` is clamped to
            ``[-1, 1]`` before interpolation.

    Returns:
        An ``(r, g, b)`` triple. ``_NEUTRAL_RGB`` is returned unchanged when
        ``max_abs`` is zero or negative.
    """
    if max_abs <= 0:
        return _NEUTRAL_RGB

    clamped = max(-1.0, min(1.0, value / max_abs))
    # Shift [-1, 1] to [0, 1]; 0.5 is the neutral midpoint of the palette.
    position = (clamped + 1.0) / 2.0

    if position < 0.5:
        start, stop = _NEGATIVE_RGB, _NEUTRAL_RGB
        fraction = position * 2.0
    else:
        start, stop = _NEUTRAL_RGB, _POSITIVE_RGB
        fraction = (position - 0.5) * 2.0

    # Linear blend per channel between the two endpoints of the chosen half.
    return tuple(
        int(round(lo + (hi - lo) * fraction)) for lo, hi in zip(start, stop)
    )
def _strip_occurrence_suffix(text: str) -> str:
    """Remove a trailing `` (N)`` marker, where N is all digits, from a label.

    Falsy input (``None``, ``""``) yields ``""``. Labels whose final
    parenthesised group is not purely numeric are returned unchanged.
    """
    label = text if text else ""
    if not label.endswith(")") or " (" not in label:
        return label

    # Split on the LAST " (" so only the final parenthesised group is examined.
    head, _sep, remainder = label.rpartition(" (")
    counter = remainder[:-1]  # drop the closing ")"
    return head if counter.isdigit() else label
def _value_to_color(value: float, max_abs: float) -> str:
    """Return the palette colour for ``value`` as a CSS ``rgb(r, g, b)`` string."""
    red, green, blue = _interpolated_rgb(value, max_abs)
    return f"rgb({red}, {green}, {blue})"
def _text_color_for_value(value: float, max_abs: float) -> str:
    """Choose a readable text colour for the background assigned to ``value``.

    Returns white (``#ffffff``) when the weighted brightness of the background
    falls below 0.42, otherwise a near-black (``#020617``).
    """
    red, green, blue = _interpolated_rgb(value, max_abs)
    # Weighted channel sum scaled to 0-1; green dominates perceived brightness.
    luminance = (0.2126 * red + 0.7152 * green + 0.0722 * blue) / 255.0
    return "#ffffff" if luminance < 0.42 else "#020617"
def _coerce_marginals(features: Sequence[str], marginals: Any) -> List[float]:
    """Align per-token marginal scores with ``features`` as a list of floats.

    Args:
        features: Token labels; the result has one entry per label, in order.
        marginals: One of
            * a ``dict`` keyed by feature label (missing keys count as 0.0),
            * a ``list``/``tuple`` indexed by position (missing tail is 0.0),
            * ``None`` or any other type, which yields all zeros.

    Returns:
        One float per feature. Any individual entry that cannot be converted
        with ``float()`` is replaced by 0.0 in both the dict and the sequence
        form, so a single bad value never aborts rendering.
    """

    def _as_float(raw: Any) -> float:
        # Best-effort conversion shared by both input shapes.
        try:
            return float(raw)
        except (TypeError, ValueError, OverflowError):
            return 0.0

    if isinstance(marginals, dict):
        return [_as_float(marginals.get(name, 0.0)) for name in features]

    if isinstance(marginals, (list, tuple)):
        available = len(marginals)
        # Positions past the end of ``marginals`` are padded with zeros.
        return [
            _as_float(marginals[idx]) if idx < available else 0.0
            for idx, _ in enumerate(features)
        ]

    # None and unsupported container types: no information, so all zeros.
    return [0.0 for _ in features]
def _sanitize_edges(
    interactions: Sequence[Dict[str, Any]],
    feature_count: int,
    top_k: int,
    threshold: float,
) -> List[Dict[str, Any]]:
    """Validate, deduplicate and rank pairwise interaction records.

    Args:
        interactions: Records expected to look like
            ``{"indices": [i, j], "value": v}``. Anything malformed is skipped
            rather than raising.
        feature_count: Number of tokens; both indices must lie in
            ``[0, feature_count)`` and differ from each other.
        top_k: Maximum number of edges to return. Negative values are treated
            as 0; ``None`` is tolerated and keeps every edge.
        threshold: Edges with ``abs(value) < threshold`` are dropped.

    Returns:
        Up to ``top_k`` dicts ``{"i": lo, "j": hi, "value": v}`` with
        ``lo < hi``, ordered by descending ``abs(value)``. When the same
        unordered pair appears more than once, the first valid occurrence wins.
    """
    edges: List[Dict[str, Any]] = []
    seen: set[tuple[int, int]] = set()

    for item in interactions:
        if not isinstance(item, dict):
            continue
        indices = item.get("indices")
        if indices is None:
            continue
        try:
            # ``len`` sits inside the guard so a non-sized ``indices`` (e.g. a
            # bare int) is skipped like any other malformed record.
            if len(indices) != 2:
                continue
            i = int(indices[0])
            j = int(indices[1])
            value = float(item.get("value", 0.0))
        except (TypeError, ValueError, IndexError, KeyError, OverflowError):
            continue

        # NaN always fails ``abs(value) < threshold`` and has no defined sort
        # position; infinities cannot be normalised. Neither is drawable.
        if not math.isfinite(value):
            continue
        if i == j or not (0 <= i < feature_count and 0 <= j < feature_count):
            continue
        if abs(value) < threshold:
            continue

        # Deduplicate (i, j) and (j, i): canonical form has the smaller index first.
        canonical = (min(i, j), max(i, j))
        if canonical in seen:
            continue
        seen.add(canonical)
        edges.append({"i": canonical[0], "j": canonical[1], "value": value})

    edges.sort(key=lambda edge: abs(edge["value"]), reverse=True)

    # A negative slice bound would drop edges from the tail instead of
    # limiting the count, so clamp it to "no edges".
    if top_k is not None and top_k < 0:
        top_k = 0
    return edges[:top_k]
# Endpoints of the influence palette as (R, G, B): light gold for the lowest
# score, deep orange for the highest.
_INFLUENCE_LOW_RGB = (255, 244, 210)
_INFLUENCE_HIGH_RGB = (200, 100, 30)


def _influence_rgb(value: float, max_abs: float) -> tuple[int, int, int]:
    """Interpolate the gold -> orange palette; callers must ensure ``max_abs > 0``."""
    # Clamp on BOTH sides: without the lower bound a negative value pushes the
    # channels above 255 and produces an invalid CSS colour.
    norm = max(0.0, min(1.0, value / max_abs))
    return tuple(
        int(low - (low - high) * norm)
        for low, high in zip(_INFLUENCE_LOW_RGB, _INFLUENCE_HIGH_RGB)
    )


def _influence_token_color(value: float, max_abs: float) -> str:
    """Gold-to-orange gradient for influence (non-negative) values.

    Args:
        value: Influence score; values outside ``[0, max_abs]`` are clamped.
        max_abs: Normalisation scale.

    Returns:
        A CSS ``rgb(r, g, b)`` string. When ``max_abs`` is zero or negative
        the neutral palette colour is returned instead.
    """
    if max_abs <= 0:
        return f"rgb({_NEUTRAL_RGB[0]}, {_NEUTRAL_RGB[1]}, {_NEUTRAL_RGB[2]})"
    red, green, blue = _influence_rgb(value, max_abs)
    return f"rgb({red}, {green}, {blue})"


def _influence_text_color(value: float, max_abs: float) -> str:
    """Text color for influence tokens based on background luminance.

    Returns white (``#ffffff``) when the weighted brightness of the token's
    background falls below 0.42, otherwise near-black (``#020617``), which is
    also the result when ``max_abs`` is zero or negative.

    NOTE(review): with the current palette endpoints the darkest background
    (deep orange) has brightness of about 0.456, so the white branch is not
    reached; it only matters if the palette is changed.
    """
    if max_abs <= 0:
        return "#020617"
    red, green, blue = _influence_rgb(value, max_abs)
    luminance = (0.2126 * red + 0.7152 * green + 0.0722 * blue) / 255.0
    return "#ffffff" if luminance < 0.42 else "#020617"
def create_text_interaction_html(
features: List[str],
marginals: Any,
interactions: List[Dict[str, Any]],
*,
method: str = "shapley",
top_k: int = 20,
threshold: float = 0.0,
) -> str:
if not features:
return "
No tokens available.
"
is_influence = method.lower() == "influence"
values = _coerce_marginals(features, marginals)
if is_influence:
values = [abs(v) for v in values]
max_abs = max((abs(v) for v in values), default=0.0)
edges = _sanitize_edges(interactions or [], len(features), top_k, threshold)
if is_influence:
for edge in edges:
edge["value"] = abs(edge["value"])
edge_note = "" if edges else "
No interactions to display.
"
method_label = (method or "score").strip().title()
influence_note = (
"
"
"Influence scores are always non-negative — they represent squared Fourier "
"coefficients of the Möbius transform, measuring each token's contribution magnitude.
"
if is_influence else ""
)
view_id = f"text-interaction-{uuid.uuid4().hex[:8]}"
tokens_html = []
for idx, token in enumerate(features):
value = values[idx]
bg = _value_to_color(value, max_abs)
fg = _text_color_for_value(value, max_abs)
token_label = escape(_strip_occurrence_suffix(str(token)))
if is_influence:
tooltip = escape(f"{method_label} token {idx + 1}: {value:.4f}")
else:
tooltip = escape(f"{method_label} token {idx + 1}: {value:+.4f}")
tokens_html.append(
""
f"{token_label}"
""
)
payload = {
"edges": edges,
"max_abs": max((abs(edge["value"]) for edge in edges), default=0.0),
}
data_blob = json.dumps(payload)
displayed_edge_count = len(edges)
script_id = f"{view_id}-script"
loader_id = f"{view_id}-loader"
js_code = (
# layout constants for curved routing and hover focus
"const LANE_SPACING=20;" # separation between lanes
"const MAX_LANES=12;" # allow more lanes before shrinking spacing
"const EXTRA_FOR_DISTANCE=0.08;" # multiplies span in px for extra lift
"const MAX_ARC_EXTRA=90;" # cap extra lift based on distance
"const MAX_ARC_LIFT=160;" # cap total arc height
"const TOP_BASE=45;" # base lift for first top lane
"const BOTTOM_BASE=45;" # base drop for first bottom lane
"const TOP_MARGIN=16;" # extra safety margin
"const BOTTOM_MARGIN=16;"
"const TOKEN_BAND=42;" # estimated token row thickness
"const DEBUG_LANES=false;"
# ── getBestAnchors: rectangle-aware anchor selection ──────────────
# Each token chip is a rectangle. For every edge we pick the pair of
# border anchor points (left/right/top/bottom midpoints) that yields
# the shortest Euclidean distance, with a small penalty to discourage
# "bottom" anchors (which cause ugly under-loops) when better options
# exist.
"function _anchors(b){"
"return ["
"{x:b.right, y:b.cy, side:'right'},"
"{x:b.left, y:b.cy, side:'left'},"
"{x:b.cx, y:b.top, side:'top'},"
"{x:b.cx, y:b.bottom, side:'bottom'}"
"];}"
"function getBestAnchors(boxA,boxB){"
"const aa=_anchors(boxA),bb=_anchors(boxB);"
"let best=null,bestScore=Infinity;"
"for(const a of aa){for(const b of bb){"
"let dx=a.x-b.x,dy=a.y-b.y;"
"let d=Math.sqrt(dx*dx+dy*dy);"
# Small penalty for bottom anchors to avoid under-loops.
"if(a.side==='bottom')d+=18;"
"if(b.side==='bottom')d+=18;"
# Penalise paths that would cross through one of the chips:
# if anchor A is on the left of boxA but boxB is to the right (and
# vice-versa) – i.e. the path would have to go backward through A.
"if(a.side==='left' && b.x>boxA.cx)d+=30;"
"if(a.side==='right' && b.xboxB.cx)d+=30;"
"if(b.side==='right' && a.x[el.dataset.index,el]));"
"function trackAdj(a,b){"
"if(!adjacency.has(a))adjacency.set(a,new Set());"
"adjacency.get(a).add(b);"
"}"
"function applyHover(){"
"const focus=hoverToken;"
"const focusKey=focus===null?null:String(focus);"
"if(focusKey===null){"
"tokens.forEach(t=>t.classList.remove('token-dim','token-focus'));"
"edgeEls.forEach(e=>{e.classList.remove('edge-dim','edge-focus');});"
"return;"
"}"
"const neighbors=adjacency.get(focusKey)||new Set();"
"tokens.forEach(t=>{"
"const idx=t.dataset.index;"
"const active=idx===focusKey||neighbors.has(idx);"
"t.classList.toggle('token-focus',active);"
"t.classList.toggle('token-dim',!active);"
"});"
"edgeEls.forEach(e=>{"
"const active=e.dataset.start===focusKey||e.dataset.end===focusKey;"
"e.classList.toggle('edge-focus',active);"
"e.classList.toggle('edge-dim',!active);"
"});"
"}"
"function measureBoxes(){"
"const wrap=wrapper.getBoundingClientRect();"
"const out={};"
"tokens.forEach(el=>{"
"const box=el.getBoundingClientRect();"
"const left=box.left-wrap.left;"
"const top=box.top-wrap.top;"
"const w=box.width;"
"const h=box.height;"
"out[el.dataset.index]={"
"left:left,right:left+w,top:top,bottom:top+h,"
"cx:left+w/2,cy:top+h/2,w:w,h:h};"
"});"
"return out;"
"}"
"function lanePack(edgeList){"
"edgeList.sort((a,b)=>{"
"const spanA=a.endX-a.startX;"
"const spanB=b.endX-b.startX;"
"if(spanA!==spanB){return spanB-spanA;}"
"if(a.startX!==b.startX){return a.startX-b.startX;}"
"return a.endX-b.endX;});"
"const lanes=[];"
"edgeList.forEach(edge=>{"
"let lane=0; let placed=false;"
"for(; lane iv.endX)){overlap=true;break;}"
"}"
"if(!overlap){edge.lane=lane;lanes[lane].push(edge);placed=true;break;}"
"}"
"if(!placed){edge.lane=lanes.length;lanes.push([edge]);}"
"});"
"return lanes.length||1;"
"}"
"function draw(){"
"let rect=wrapper.getBoundingClientRect();"
"if(rect.width===0||rect.height===0){return;}"
"while(svg.firstChild){svg.removeChild(svg.firstChild);}edgeEls.length=0;"
"adjacency.clear();"
"let boxes=measureBoxes();"
"const clipId='ti-clip-" f"{view_id}" "';"
"let defs=svg.querySelector('defs');"
"if(!defs){defs=document.createElementNS('http://www.w3.org/2000/svg','defs');svg.appendChild(defs);}"
"let clip=defs.querySelector(`#${clipId}`);"
"if(!clip){clip=document.createElementNS('http://www.w3.org/2000/svg','clipPath');clip.setAttribute('id',clipId);defs.appendChild(clip);}"
"let clipRect=clip.querySelector('rect');"
"if(!clipRect){clipRect=document.createElementNS('http://www.w3.org/2000/svg','rect');clip.appendChild(clipRect);}"
"clipRect.setAttribute('x','0');clipRect.setAttribute('y','0');"
"clipRect.setAttribute('width',rect.width);clipRect.setAttribute('height',rect.height);"
"let g=svg.querySelector('g.interaction-edges-group');"
"if(!g){g=document.createElementNS('http://www.w3.org/2000/svg','g');g.classList.add('interaction-edges-group');svg.appendChild(g);}"
"g.setAttribute('clip-path',`url(#${clipId})`);"
"while(g.firstChild){g.removeChild(g.firstChild);}"
"const posEdges=[], negEdges=[];"
"edges.forEach(edge=>{"
"const bA=boxes[String(edge.i)];"
"const bB=boxes[String(edge.j)];"
"if(!bA||!bB){return;}"
# Compute token distance (index difference) for path strategy.
"const tokenDist=Math.abs(edge.j-edge.i);"
"const rec=Object.assign({"
"startX:Math.min(bA.cx,bB.cx),"
"endX:Math.max(bA.cx,bB.cx),"
"spanPx:Math.abs(bA.cx-bB.cx),"
"tokenDist:tokenDist"
"}, edge);"
"(edge.value>=0?posEdges:negEdges).push(rec);"
"});"
"const posLaneCount=lanePack(posEdges);"
"const negLaneCount=lanePack(negEdges);"
"const clampExtra=(span)=>Math.min(span*EXTRA_FOR_DISTANCE,MAX_ARC_EXTRA);"
"function laneSpacing(count){"
"if(count<=8){return LANE_SPACING;}"
"return Math.max(10, LANE_SPACING*8/count);"
"}"
"const posSpacing=laneSpacing(posLaneCount);"
"const negSpacing=laneSpacing(negLaneCount);"
"const maxPosLane=posEdges.reduce((m,e)=>Math.max(m,e.lane||0),-1);"
"const maxNegLane=negEdges.reduce((m,e)=>Math.max(m,e.lane||0),-1);"
"const maxPosExtra=posEdges.reduce((m,e)=>Math.max(m,clampExtra(e.spanPx||0)),0);"
"const maxNegExtra=negEdges.reduce((m,e)=>Math.max(m,clampExtra(e.spanPx||0)),0);"
"const topNeeded=TOP_BASE + maxPosExtra + Math.max(0,maxPosLane+1)*posSpacing + TOP_MARGIN;"
"const bottomNeeded=BOTTOM_BASE + maxNegExtra + Math.max(0,maxNegLane+1)*negSpacing + BOTTOM_MARGIN;"
"const padTop=Math.max(24, topNeeded);"
"const padBottom=Math.max(24, bottomNeeded);"
"const desiredHeight=padTop + padBottom + TOKEN_BAND;"
"wrapper.style.paddingTop=`${padTop}px`;"
"wrapper.style.paddingBottom=`${padBottom}px`;"
"wrapper.style.minHeight=`${desiredHeight}px`;"
"rect=wrapper.getBoundingClientRect();"
"boxes=measureBoxes();"
"const tokenYs2=Object.values(boxes).map(c=>c.cy);"
"const tokenRowY2=tokenYs2.length?tokenYs2.reduce((a,b)=>a+b,0)/tokenYs2.length:rect.height/2;"
"svg.setAttribute('width',rect.width);"
"svg.setAttribute('height',rect.height);"
"svg.setAttribute('viewBox',`0 0 ${rect.width} ${rect.height}`);"
"clipRect.setAttribute('width',rect.width);clipRect.setAttribute('height',rect.height);"
"const renderList=posEdges.concat(negEdges);"
"renderList.forEach(edge=>{"
"const bA=boxes[String(edge.i)];"
"const bB=boxes[String(edge.j)];"
"if(!bA||!bB){return;}"
"const labelA=(tokenMap.get(String(edge.i))||{}).textContent||'';"
"const labelB=(tokenMap.get(String(edge.j))||{}).textContent||'';"
"const tokenDist=edge.tokenDist||Math.abs(edge.j-edge.i);"
# Pick the best anchor points on the chip borders.
"const anchors=getBestAnchors(bA,bB);"
"if(!anchors){return;}"
"const pathD=makeSmartPath(anchors.start.x,anchors.start.y,"
"anchors.end.x,anchors.end.y,anchors.start.side,anchors.end.side,tokenDist);"
"if(DEBUG_LANES){"
"console.log({i:edge.i,j:edge.j,sideA:anchors.start.side,sideB:anchors.end.side,tokenDist,lane:edge.lane});"
"}"
"const path=document.createElementNS('http://www.w3.org/2000/svg','path');"
"path.setAttribute('d',pathD);"
"const norm=maxAbs?Math.min(1,Math.abs(edge.value)/maxAbs):0;"
"const width=1+5*norm;"
"const opacity=0.25+0.55*norm;"
"const color=edge.value>=0?'#4a1c87':'#dd1313';"
"path.setAttribute('stroke',color);"
"path.setAttribute('stroke-width',width.toFixed(2));"
"path.setAttribute('opacity',opacity.toFixed(2));"
"path.setAttribute('stroke-linecap','round');"
"path.setAttribute('stroke-linejoin','round');"
"path.classList.add('interaction-edge');"
"path.style.setProperty('--edge-width',width.toFixed(2));"
"path.style.setProperty('--edge-opacity',opacity.toFixed(2));"
"path.dataset.start=String(edge.i);"
"path.dataset.end=String(edge.j);"
"trackAdj(path.dataset.start,path.dataset.end);"
"trackAdj(path.dataset.end,path.dataset.start);"
"const title=document.createElementNS('http://www.w3.org/2000/svg','title');"
"title.textContent=`${labelA} x ${labelB} : ${edge.value.toFixed(3)}`;"
"path.appendChild(title);"
"g.appendChild(path);edgeEls.push(path);"
"});"
"applyHover();"
"}"
"tokens.forEach(token=>{"
"token.addEventListener('mouseenter',()=>{"
"hoverToken=token.dataset.index;"
"applyHover();"
"});"
"token.addEventListener('mouseleave',()=>{"
"hoverToken=null;"
"applyHover();"
"});"
"});"
"const schedule=()=>{window.requestAnimationFrame(draw);};"
"schedule();"
"if(window.ResizeObserver){"
"const ro=new ResizeObserver(schedule);ro.observe(wrapper);"
"}else{window.addEventListener('resize',schedule);}"
"if(document.fonts&&document.fonts.ready){document.fonts.ready.then(schedule);}"
)
return (
""
f"
"
"
"
"
"
"
"
"
Text Interaction View
"
f"
Tokens: {len(features)}
"
"
"
f"
Top-{displayed_edge_count} interactions shown; adjust threshold to declutter.
"
"
"
"
"
""
f"
{''.join(tokens_html)}
"
"
"
f"{edge_note}"
"
"
"
"
"
"
f"{method_label} legend"
"
"
f"{'Low' if is_influence else 'Negative'}"
""
f"{'High' if is_influence else 'Positive'}"
"
"
f"
Normalized by max |value| = {max_abs:.4f}. "
"Hover tokens for exact scores.