"""HTML helpers for visualizing hop-wise IFR/AttnLRP attributions."""
from __future__ import annotations
import math
from typing import Any, Dict, List, Optional, Sequence
from html import escape
TOKEN_SCALE_QUANTILE = 0.995
def _robust_abs_max(scores: Sequence[float], *, quantile: float = TOKEN_SCALE_QUANTILE) -> float:
"""Return a robust abs max to avoid a single outlier washing out the colormap.
Uses a high quantile (default: p99.5) over |scores|. Top outliers saturate.
"""
abs_vals: List[float] = []
for x in scores:
try:
v = float(x)
except Exception:
continue
if math.isnan(v):
continue
abs_vals.append(abs(v))
if not abs_vals:
return 0.0
abs_vals.sort()
q = float(quantile)
if q < 0.0:
q = 0.0
if q > 1.0:
q = 1.0
idx = int(q * (len(abs_vals) - 1))
return float(abs_vals[idx])
def _color_for_score(score: float, max_score: float) -> str:
if max_score <= 0:
return "background-color: rgba(245,245,245,0.7);"
ratio = min(1.0, score / (max_score + 1e-12))
r = 255
g = int(235 - 90 * ratio)
b = int(220 - 160 * ratio)
alpha = 0.25 + 0.55 * ratio
return f"background-color: rgba({r}, {g}, {b}, {alpha});"
def _render_sentence_list(title: str, sentences: Sequence[str], scores: Sequence[float], max_score: float) -> str:
"""Render a titled block with one row per (sentence, score) pair.

Pairs come from ``zip(sentences, scores)``, so output stops at the shorter of
the two sequences. Each row carries the score formatted to four decimals and
the HTML-escaped sentence; the title is HTML-escaped as well.
"""
rows: List[str] = []
for sent, sc in zip(sentences, scores):
# Color intensity follows |score| relative to the caller-supplied max_score.
# NOTE(review): `style` is not referenced in the row text visible here;
# presumably it belongs in the row markup -- confirm against the real file.
style = _color_for_score(abs(float(sc)), max_score)
rows.append(
f'
{sc:.4f}'
f'{escape(sent)}
'
)
return f"""
{escape(title)}
{''.join(rows)}
"""
def _render_tokens(
tokens: Sequence[str],
scores: Sequence[float],
max_score: float,
roles: Sequence[str],
) -> str:
"""Render every token as one fragment and return the fragments concatenated.

``scores`` and ``roles`` may be shorter than ``tokens``: a missing score is
treated as 0.0 and a missing role as ``"gen"``. Token text is HTML-escaped.
"""
spans: List[str] = []
# Replace a non-positive scale with a tiny positive one so that
# _color_for_score does not take its fixed grey branch.
if max_score <= 0:
max_score = 1e-8
for idx, tok in enumerate(tokens):
score = float(scores[idx]) if idx < len(scores) else 0.0
style = _color_for_score(abs(score), max_score)
role = roles[idx] if idx < len(roles) else "gen"
safe_tok = escape(tok)
# NOTE(review): `style` and `role` are not referenced in the literal visible
# below; presumably they belong in the span markup -- confirm.
spans.append(
f'{safe_tok}'
)
return "".join(spans)
def _render_top_table(top_items: List[Dict[str, Any]]) -> str:
"""Render the ranked top-sentence listing.

Each item is read via ``item['idx']``, ``item['score']`` (formatted to four
decimals) and ``item['sentence']`` (HTML-escaped); a missing key raises
``KeyError``. Ranks are numbered from 1 in the order given. An empty or
falsy ``top_items`` returns the "No attribution mass." placeholder instead.
"""
if not top_items:
return "No attribution mass.
"
# NOTE(review): `header` is an empty string in this view; presumably it holds
# the table header markup in the real file -- confirm.
header = ""
body_rows = []
for rank, item in enumerate(top_items, start=1):
body_rows.append(
f"{rank}{item['idx']}"
f"{item['score']:.4f}{escape(item['sentence'])}
"
)
return f"{header}{''.join(body_rows)}
"
def render_case_html(
case_meta: Dict[str, Any],
*,
token_view_raw: Dict[str, Any],
token_view_prompt: Dict[str, Any],
context: Optional[Dict[str, Any]] = None,
hops_sent: Optional[Sequence[Dict[str, Any]]] = None,
) -> str:
"""Assemble the case-study HTML string from two token views and optional sentence data.

One section is built per entry of ``token_view_prompt["hops"]``; each section
holds the raw token heatmap, the prompt-only token heatmap and, when a
sentence view is available, the sentence lists and the top-sentence table.

Args:
    case_meta: Read with ``.get`` for ``mode`` (default ``"ft"``),
        ``ifr_view`` (default ``"aggregate"``), ``sink_span``,
        ``panel_titles``, ``thinking_ratios``, ``n_hops`` and the three
        ``attnlrp_*`` settings.
    token_view_raw: Read for ``hops`` (entries with ``token_scores``,
        ``total_mass``, ``token_score_max``), ``label``, ``tokens``, ``roles``.
    token_view_prompt: Same keys as ``token_view_raw``.
    context: Optional; read for ``prompt_sentences`` and
        ``generation_sentences``.
    hops_sent: Optional per-hop dicts read for ``sentence_scores_raw`` and
        ``top_sentences``.

Returns:
    The assembled HTML string.

Raises:
    ValueError: If the two token views do not have the same number of hops.
"""
# The sentence view is rendered only when both inputs are truthy.
has_sentence_view = bool(context) and bool(hops_sent)
prompt_len = len((context or {}).get("prompt_sentences") or []) if has_sentence_view else 0
gen_len = len((context or {}).get("generation_sentences") or []) if has_sentence_view else 0
prompt_max = 0.0
gen_max = 0.0
# Shared color scales across all hops: the first prompt_len sentence scores
# are the prompt part, the remainder is the generation part.
# NOTE(review): these are signed maxima, while _render_sentence_list colors by
# abs(score); an all-negative slice gives a non-positive max and grey rows --
# confirm whether max(|score|) was intended.
if has_sentence_view:
prompt_max = max(
(
max(h["sentence_scores_raw"][:prompt_len])
for h in (hops_sent or [])
if h.get("sentence_scores_raw") and h["sentence_scores_raw"][:prompt_len]
),
default=0.0,
)
gen_max = max(
(
max(h["sentence_scores_raw"][prompt_len:])
for h in (hops_sent or [])
if h.get("sentence_scores_raw") and h["sentence_scores_raw"][prompt_len:]
),
default=0.0,
)
raw_hops = token_view_raw.get("hops", []) or []
prompt_hops = token_view_prompt.get("hops", []) or []
# Panels are paired by index below, so the two views must line up.
if len(raw_hops) != len(prompt_hops):
raise ValueError(
"token_view_raw and token_view_prompt must have the same number of panels: "
f"raw={len(raw_hops)} prompt={len(prompt_hops)}"
)
hop_sections: List[str] = []
hop_count = len(prompt_hops)
mode = case_meta.get("mode", "ft")
ifr_view = case_meta.get("ifr_view", "aggregate")
sink_span = case_meta.get("sink_span")
panel_titles = case_meta.get("panel_titles")
# Title for one panel: an explicit non-None entry of case_meta["panel_titles"]
# wins; otherwise the title is derived from `mode`.
def _panel_title(panel_idx: int) -> str:
if isinstance(panel_titles, list) and panel_idx < len(panel_titles):
try:
title = panel_titles[panel_idx]
except Exception:
title = None
if title is not None:
return str(title)
if mode in ("ft", "ft_improve", "ft_split_hop", "ifr_in_all_gen", "ft_attnlrp"):
return f"Hop {panel_idx}"
if mode == "ifr_all_positions_output_only":
return f"IFR output-only panel {panel_idx}"
if mode == "ifr_all_positions":
return f"IFR all-positions panel {panel_idx}"
if mode == "attnlrp":
return "AttnLRP (sink-span aggregate)"
return "IFR (sink-span aggregate)"
for hop_idx in range(hop_count):
raw_entry = raw_hops[hop_idx]
raw_scores = raw_entry.get("token_scores") or []
raw_mass = float(raw_entry.get("total_mass", 0.0))
# Scale fallback chain: robust quantile -> stored token_score_max -> 1e-8.
raw_scale = _robust_abs_max(raw_scores)
if raw_scale <= 0:
raw_scale = float(raw_entry.get("token_score_max") or 0.0)
if raw_scale <= 0:
raw_scale = 1e-8
prompt_entry = prompt_hops[hop_idx]
prompt_scores = prompt_entry.get("token_scores") or []
prompt_mass = float(prompt_entry.get("total_mass", 0.0))
# Same fallback chain for the prompt-only view; each view gets its own scale.
prompt_scale = _robust_abs_max(prompt_scores)
if prompt_scale <= 0:
prompt_scale = float(prompt_entry.get("token_score_max") or 0.0)
if prompt_scale <= 0:
prompt_scale = 1e-8
tok_raw_html = f"""
{escape(token_view_raw.get("label", "Pre-trim token-level heatmap (full)"))}
{_render_tokens(token_view_raw.get("tokens", []), raw_scores, raw_scale, token_view_raw.get("roles", []))}
"""
tok_prompt_html = f"""
{escape(token_view_prompt.get("label", "Prompt-only token-level heatmap"))}
{_render_tokens(token_view_prompt.get("tokens", []), prompt_scores, prompt_scale, token_view_prompt.get("roles", []))}
"""
sentence_html = ""
top_html = ""
if has_sentence_view and hop_idx < len(hops_sent or []):
hop = (hops_sent or [])[hop_idx]
# raw_scores / prompt_scores are rebound to sentence-level values here; the
# token heatmaps above were already rendered from the token-level values.
raw_scores = hop.get("sentence_scores_raw") or []
prompt_scores = raw_scores[:prompt_len]
gen_scores = raw_scores[prompt_len:]
# Sentence view is not used by the current case-study runner; keep the path for completeness.
sentence_html = f"""
{_render_sentence_list('Prompt sentences', (context or {}).get('prompt_sentences') or [], prompt_scores, prompt_max)}
{_render_sentence_list('Generation sentences', (context or {}).get('generation_sentences') or [], gen_scores, gen_max)}
"""
top_html = f"""
Top sentences (all)
{_render_top_table(hop.get('top_sentences') or [])}
"""
hop_sections.append(
f"""
{tok_raw_html}
{tok_prompt_html}
{sentence_html}
{top_html}
"""
)
thinking_ratios = case_meta.get("thinking_ratios") or []
ratios_str = ", ".join(f"{r:.4f}" for r in thinking_ratios) if thinking_ratios else "N/A"
# Human-readable label for the mode; unrecognized modes fall back to str(mode).
# NOTE(review): "ft_improve" and "ft_split_hop" are accepted by _panel_title
# above but have no entry here or in the view_key chain below -- confirm.
if mode == "ft":
mode_label = "FT Multi-hop (IFR)"
elif mode == "ifr_in_all_gen":
mode_label = "IFR In-all-gen (multi-hop)"
elif mode == "ifr":
mode_label = "IFR Standard"
elif mode == "ifr_all_positions":
mode_label = "IFR All-positions"
elif mode == "ifr_all_positions_output_only":
mode_label = "IFR Output-only (all positions)"
elif mode == "attnlrp":
mode_label = "AttnLRP"
elif mode == "ft_attnlrp":
mode_label = "FT Multi-hop (AttnLRP)"
else:
mode_label = str(mode)
# Mode-dependent key/value pair describing the view.
if mode in ("ft", "ifr_in_all_gen", "ft_attnlrp"):
view_key = "Recursive hops"
view_val = case_meta.get("n_hops")
elif mode in ("ifr", "ifr_all_positions", "ifr_all_positions_output_only"):
view_key = "IFR view"
view_val = ifr_view
elif mode == "attnlrp":
view_key = "AttnLRP view"
view_val = "ft_hop0_span_aggregate"
else:
view_key = "View"
view_val = "N/A"
# Shows the quantile as a percentile with one decimal (0.995 -> "p99.5").
scale_row = f"Token scale: per-panel per-view p{int(TOKEN_SCALE_QUANTILE*1000)/10:.1f}(|score|)
"
neg_handling = case_meta.get("attnlrp_neg_handling")
norm_mode = case_meta.get("attnlrp_norm_mode")
ratio_enabled = case_meta.get("attnlrp_ratio_enabled")
# Optional rows: the first two only when truthy, the last whenever it is not None.
attn_rows = []
if neg_handling:
attn_rows.append(f"FT-AttnLRP neg_handling: {escape(str(neg_handling))}
")
if norm_mode:
attn_rows.append(f"FT-AttnLRP norm_mode: {escape(str(norm_mode))}
")
if ratio_enabled is not None:
attn_rows.append(f"FT-AttnLRP ratio_enabled: {escape(str(bool(ratio_enabled)))}
")
# NOTE(review): `header` and `style` look empty in this view, and several
# values computed above (gen_len, sink_span, raw_mass, prompt_mass, ratios_str,
# view_key, view_val, scale_row, attn_rows, _panel_title) are not referenced in
# the visible text; presumably the literals contain markup that uses them --
# confirm against the real file.
header = f"""
"""
style = """
"""
title = f"{mode_label} Case Study"
html = f"""
{escape(title)}
{style}
{header}
{''.join(hop_sections)}
"""
return html
def _render_sentence_spans(title: str, sentences: Sequence[str], scores: Sequence[float]) -> str:
"""Render sentences as colored fragments under an HTML-escaped title.

The color scale is the largest |score| in ``scores`` (0.0 when empty).
Sentences without a matching score are treated as 0.0; sentence text is
HTML-escaped.
"""
max_abs = max((abs(float(x)) for x in scores), default=0.0)
spans: List[str] = []
for idx, sentence in enumerate(sentences):
score = float(scores[idx]) if idx < len(scores) else 0.0
# NOTE(review): `style` is not referenced in the literal visible below;
# presumably it belongs in the span markup -- confirm.
style = _color_for_score(abs(score), max_abs)
spans.append(
f'{escape(sentence)}'
)
return f"""
{escape(title)}
{''.join(spans)}
"""
def _render_token_spans(title: str, tokens: Sequence[str], scores: Sequence[float]) -> str:
"""Render tokens as colored fragments under an HTML-escaped title.

The color scale is the largest |score| in ``scores`` (0.0 when empty), not
the quantile-based scale from _robust_abs_max. Tokens without a matching
score are treated as 0.0; token text is HTML-escaped.
"""
max_abs = max((abs(float(x)) for x in scores), default=0.0)
spans: List[str] = []
for idx, tok in enumerate(tokens):
score = float(scores[idx]) if idx < len(scores) else 0.0
# NOTE(review): `style` is not referenced in the literal visible below;
# presumably it belongs in the span markup -- confirm.
style = _color_for_score(abs(score), max_abs)
spans.append(
f'{escape(tok)}'
)
return f"""
{escape(title)}
{''.join(spans)}
"""
def render_mas_sentence_html(
case_meta: Dict[str, Any],
*,
prompt_sentences: Sequence[str],
panels: Sequence[Dict[str, Any]],
generation: Optional[str] = None,
) -> str:
"""Render MAS sentence-level diagnostics (attribution / pure ablation / guided marginal).

Args:
    case_meta: Read with ``.get`` for ``attr_method_label`` / ``attr_method``
        (title), ``base_score`` and the three ``attnlrp_*`` settings.
    prompt_sentences: Sentences shown in every panel.
    panels: One dict per panel, read for a label (``variant_label``,
        ``panel_label`` or ``variant``), ``metrics``, ``attr_weights``,
        ``pure_sentence_deltas_raw``, ``guided_sentence_deltas_raw`` (falling
        back to ``sentence_deltas_raw``) and ``sorted_attr_indices``.
    generation: Optional text; a block is added only for a non-empty ``str``.

Returns:
    The assembled HTML string.
"""
method_label = case_meta.get("attr_method_label") or case_meta.get("attr_method") or "Unknown method"
title = f"MAS Sentence Study ({method_label})"
neg_handling = case_meta.get("attnlrp_neg_handling")
norm_mode = case_meta.get("attnlrp_norm_mode")
ratio_enabled = case_meta.get("attnlrp_ratio_enabled")
# Optional rows: the first two only when truthy, the last whenever it is not None.
attn_rows = []
if neg_handling:
attn_rows.append(f"FT-AttnLRP neg_handling: {escape(str(neg_handling))}
")
if norm_mode:
attn_rows.append(f"FT-AttnLRP norm_mode: {escape(str(norm_mode))}
")
if ratio_enabled is not None:
attn_rows.append(f"FT-AttnLRP ratio_enabled: {escape(str(bool(ratio_enabled)))}
")
# The base-score row is produced only for int/float values.
base_score = case_meta.get("base_score")
base_score_row = f"Base score: {float(base_score):.6f}
" if isinstance(base_score, (int, float)) else ""
gen_block = ""
if isinstance(generation, str) and generation:
gen_block = f"""
Generation (scored)
{escape(generation)}
"""
# NOTE(review): `header` looks empty in this view, and attn_rows /
# base_score_row are not referenced in the visible text; presumably the literal
# contains markup that uses them -- confirm against the real file.
header = f"""
"""
panel_sections: List[str] = []
for panel in panels:
label = panel.get("variant_label") or panel.get("panel_label") or panel.get("variant") or "Panel"
metrics = panel.get("metrics") or {}
# Only RISE, MAS and RISE+AP are listed, in that order; numeric values get
# four decimals, anything else is shown as-is.
metrics_str = " | ".join(
f"{k}: {float(metrics[k]):.4f}" if isinstance(metrics.get(k), (int, float)) else f"{k}: {metrics.get(k)}"
for k in ("RISE", "MAS", "RISE+AP")
if k in metrics
)
attr_weights = panel.get("attr_weights") or []
pure_deltas = panel.get("pure_sentence_deltas_raw") or []
guided_deltas = panel.get("guided_sentence_deltas_raw") or panel.get("sentence_deltas_raw") or []
rank_order = panel.get("sorted_attr_indices") or []
rank_str = ", ".join(str(int(x)) for x in rank_order) if rank_order else "N/A"
# NOTE(review): label, metrics_str and rank_str are not referenced in the
# panel text visible below; presumably they belong in the panel markup -- confirm.
panel_sections.append(
f"""
{_render_sentence_spans("Method attribution (sentence weights)", prompt_sentences, attr_weights)}
{_render_sentence_spans("Pure sentence ablation (base − score)", prompt_sentences, pure_deltas)}
{_render_sentence_spans("Attribution-guided MAS marginal (path deltas)", prompt_sentences, guided_deltas)}
"""
)
style = """
"""
html = f"""
{escape(title)}
{style}
{header}
{gen_block}
{''.join(panel_sections)}
"""
return html
def render_mas_token_html(
case_meta: Dict[str, Any],
*,
prompt_tokens: Sequence[str],
panels: Sequence[Dict[str, Any]],
generation: Optional[str] = None,
) -> str:
"""Render MAS token-level diagnostics (attribution weights + guided marginal deltas).

Args:
    case_meta: Read with ``.get`` for ``attr_method_label`` / ``attr_method``
        (title), ``base_score`` and the three ``attnlrp_*`` settings.
    prompt_tokens: Tokens shown in every panel.
    panels: One dict per panel, read for a label (``variant_label``,
        ``panel_label`` or ``variant``), ``metrics``, ``attr_weights``,
        ``token_deltas_raw`` and ``sorted_attr_indices``.
    generation: Optional text; a block is added only for a non-empty ``str``.

Returns:
    The assembled HTML string.
"""
method_label = case_meta.get("attr_method_label") or case_meta.get("attr_method") or "Unknown method"
title = f"MAS Token Study ({method_label})"
neg_handling = case_meta.get("attnlrp_neg_handling")
norm_mode = case_meta.get("attnlrp_norm_mode")
ratio_enabled = case_meta.get("attnlrp_ratio_enabled")
# Optional rows: the first two only when truthy, the last whenever it is not None.
attn_rows = []
if neg_handling:
attn_rows.append(f"FT-AttnLRP neg_handling: {escape(str(neg_handling))}
")
if norm_mode:
attn_rows.append(f"FT-AttnLRP norm_mode: {escape(str(norm_mode))}
")
if ratio_enabled is not None:
attn_rows.append(f"FT-AttnLRP ratio_enabled: {escape(str(bool(ratio_enabled)))}
")
# The base-score row is produced only for int/float values.
base_score = case_meta.get("base_score")
base_score_row = f"Base score: {float(base_score):.6f}
" if isinstance(base_score, (int, float)) else ""
gen_block = ""
if isinstance(generation, str) and generation:
gen_block = f"""
Generation (scored)
{escape(generation)}
"""
# NOTE(review): `header` looks empty in this view, and attn_rows /
# base_score_row are not referenced in the visible text; presumably the literal
# contains markup that uses them -- confirm against the real file.
header = f"""
"""
panel_sections: List[str] = []
for panel in panels:
label = panel.get("variant_label") or panel.get("panel_label") or panel.get("variant") or "Panel"
metrics = panel.get("metrics") or {}
# Only RISE, MAS and RISE+AP are listed, in that order; numeric values get
# four decimals, anything else is shown as-is.
metrics_str = " | ".join(
f"{k}: {float(metrics[k]):.4f}" if isinstance(metrics.get(k), (int, float)) else f"{k}: {metrics.get(k)}"
for k in ("RISE", "MAS", "RISE+AP")
if k in metrics
)
attr_weights = panel.get("attr_weights") or []
guided_deltas = panel.get("token_deltas_raw") or []
rank_order = panel.get("sorted_attr_indices") or []
rank_str = ", ".join(str(int(x)) for x in rank_order) if rank_order else "N/A"
# NOTE(review): label, metrics_str and rank_str are not referenced in the
# panel text visible below; presumably they belong in the panel markup -- confirm.
panel_sections.append(
f"""
{_render_token_spans("Method attribution (token weights)", prompt_tokens, attr_weights)}
{_render_token_spans("Attribution-guided MAS marginal (path deltas)", prompt_tokens, guided_deltas)}
"""
)
style = """
"""
html = f"""
{escape(title)}
{style}
{header}
{gen_block}
{''.join(panel_sections)}
"""
return html