"""HTML helpers for visualizing hop-wise IFR/AttnLRP attributions.

Public entry points:

* ``render_case_html`` -- per-panel token heatmaps (plus an optional sentence
  view) for IFR / AttnLRP case studies.
* ``render_mas_sentence_html`` -- MAS sentence-level diagnostics.
* ``render_mas_token_html`` -- MAS token-level diagnostics.

Each returns a complete, self-contained HTML document as a string; text taken
from the inputs is HTML-escaped before it is embedded.
"""

from __future__ import annotations

import math
from html import escape
from typing import Any, Dict, List, Optional, Sequence

TOKEN_SCALE_QUANTILE = 0.995

# NOTE(review): the element structure, class names and stylesheet below are a
# reconstruction (the previous markup was not recoverable); the visible text of
# every label is unchanged. Confirm the layout matches expectations.
_PAGE_STYLE = """<style>
body { font-family: -apple-system, "Segoe UI", Helvetica, Arial, sans-serif; margin: 24px; color: #222; }
h1 { font-size: 20px; margin: 0 0 8px 0; }
.header { margin-bottom: 18px; }
.meta { font-size: 13px; color: #555; margin: 2px 0; }
.panel { border: 1px solid #ddd; border-radius: 6px; padding: 12px 14px; margin-bottom: 18px; }
.panel-title { font-size: 16px; font-weight: 600; margin-bottom: 4px; }
.section { margin-top: 10px; }
.section-title { font-size: 14px; font-weight: 600; margin-bottom: 6px; }
.tokens { font-family: Menlo, Consolas, monospace; font-size: 13px; line-height: 1.9; white-space: pre-wrap; word-break: break-word; }
.tok { border-radius: 2px; }
.sentences { line-height: 1.9; }
.sent-span { border-radius: 3px; padding: 1px 2px; }
.sent-row { display: flex; gap: 10px; padding: 3px 6px; margin: 2px 0; border-radius: 4px; }
.sent-score { font-family: Menlo, Consolas, monospace; min-width: 64px; }
.muted { color: #888; font-style: italic; }
table.top-table { border-collapse: collapse; font-size: 13px; }
table.top-table th, table.top-table td { border: 1px solid #ddd; padding: 3px 8px; text-align: left; vertical-align: top; }
.generation { white-space: pre-wrap; font-family: Menlo, Consolas, monospace; font-size: 13px; background: #fafafa; border: 1px solid #eee; padding: 8px; }
</style>"""

# Page-title labels for the known ``case_meta["mode"]`` values; other modes
# are shown verbatim.
_MODE_LABELS: Dict[str, str] = {
    "ft": "FT Multi-hop (IFR)",
    "ifr_in_all_gen": "IFR In-all-gen (multi-hop)",
    "ifr": "IFR Standard",
    "ifr_all_positions": "IFR All-positions",
    "ifr_all_positions_output_only": "IFR Output-only (all positions)",
    "attnlrp": "AttnLRP",
    "ft_attnlrp": "FT Multi-hop (AttnLRP)",
}

# Modes whose panels are captioned "Hop <i>".
_HOP_TITLE_MODES = ("ft", "ft_improve", "ft_split_hop", "ifr_in_all_gen", "ft_attnlrp")


def _quantile_label() -> str:
    """Return the display tag of TOKEN_SCALE_QUANTILE, e.g. ``"p99.5"``."""
    return f"p{int(TOKEN_SCALE_QUANTILE * 1000) / 10:.1f}"


def _robust_abs_max(scores: Sequence[float], *, quantile: float = TOKEN_SCALE_QUANTILE) -> float:
    """Return a robust abs max to avoid a single outlier washing out the colormap.

    Uses a high quantile (default: p99.5) over ``|scores|``; values above it
    saturate. Entries that cannot be converted to ``float`` and non-finite
    values (NaN, +/-inf) are skipped -- an infinite scale would render every
    token with the palest color. ``quantile`` is clamped to [0, 1] (NaN -> 0).
    Returns 0.0 when no usable value remains.
    """
    abs_vals: List[float] = []
    for x in scores:
        try:
            v = float(x)
        except (TypeError, ValueError, OverflowError):
            continue
        if math.isfinite(v):
            abs_vals.append(abs(v))
    if not abs_vals:
        return 0.0
    abs_vals.sort()
    # max() keeps its first argument when compared against NaN, so a NaN
    # quantile becomes 0.0 instead of making int() raise below.
    q = min(1.0, max(0.0, float(quantile)))
    idx = int(q * (len(abs_vals) - 1))
    return float(abs_vals[idx])


def _finite_abs_max(scores: Sequence[float]) -> float:
    """Return ``max(|score|)`` over ``scores`` ignoring NaN; 0.0 when empty."""
    best = 0.0
    for x in scores:
        v = abs(float(x))
        if v > best:  # always False for NaN, so NaN never becomes the scale
            best = v
    return best


def _color_for_score(score: float, max_score: float) -> str:
    """Return an inline CSS background for ``score`` on a [0, max_score] ramp.

    Ratio 0 is pale peach, ratio 1 saturated orange. Callers pass
    ``abs(score)``; the ratio is clamped to [0, 1] (NaN -> 0) so the RGB
    channels and alpha always stay in range. A non-positive ``max_score``
    yields a neutral grey.
    """
    if max_score <= 0:
        return "background-color: rgba(245,245,245,0.7);"
    ratio = score / (max_score + 1e-12)
    if not ratio > 0.0:  # negative, zero or NaN
        ratio = 0.0
    elif ratio > 1.0:
        ratio = 1.0
    r = 255
    g = int(235 - 90 * ratio)
    b = int(220 - 160 * ratio)
    alpha = 0.25 + 0.55 * ratio
    return f"background-color: rgba({r}, {g}, {b}, {alpha});"


def _meta(fragment_html: str) -> str:
    """Wrap an already-escaped HTML fragment in a metadata line."""
    return f'<div class="meta">{fragment_html}</div>'


def _section(title: str, inner_html: str) -> str:
    """Return a titled section; ``title`` is escaped, ``inner_html`` is not."""
    return (
        '<div class="section">'
        f'<div class="section-title">{escape(str(title))}</div>'
        f"{inner_html}</div>"
    )


def _render_sentence_list(title: str, sentences: Sequence[str], scores: Sequence[float], max_score: float) -> str:
    """Render one row per sentence, each shaded by ``|score|`` against ``max_score``.

    Rows stop at the shorter of ``sentences`` and ``scores``.
    """
    rows: List[str] = []
    for sent, sc in zip(sentences, scores):
        value = float(sc)
        style = _color_for_score(abs(value), max_score)
        rows.append(
            f'<div class="sent-row" style="{style}">'
            f'<span class="sent-score">{value:.4f}</span>'
            f'<span class="sent-text">{escape(str(sent))}</span></div>'
        )
    return _section(title, "".join(rows))


def _render_tokens(
    tokens: Sequence[str],
    scores: Sequence[float],
    max_score: float,
    roles: Sequence[str],
) -> str:
    """Render tokens as inline spans shaded by ``|score|`` against ``max_score``.

    Missing scores default to 0.0 and missing roles to ``"gen"``; the role and
    the score are exposed via ``data-role`` and a hover tooltip.
    """
    if max_score <= 0:
        max_score = 1e-8
    spans: List[str] = []
    for idx, tok in enumerate(tokens):
        score = float(scores[idx]) if idx < len(scores) else 0.0
        role_html = escape(str(roles[idx] if idx < len(roles) else "gen"))
        style = _color_for_score(abs(score), max_score)
        spans.append(
            f'<span class="tok" data-role="{role_html}" style="{style}" '
            f'title="idx={idx} role={role_html} score={score:.6g}">'
            f"{escape(str(tok))}</span>"
        )
    return "".join(spans)


def _render_top_table(top_items: List[Dict[str, Any]]) -> str:
    """Render the ranked top-sentence table (items carry ``idx``, ``score``, ``sentence``)."""
    if not top_items:
        return '<div class="muted">No attribution mass.</div>'
    header = "<tr><th>Rank</th><th>Idx</th><th>Score</th><th>Sentence</th></tr>"
    body_rows = [
        f"<tr><td>{rank}</td><td>{escape(str(item['idx']))}</td>"
        f"<td>{float(item['score']):.4f}</td>"
        f"<td>{escape(str(item['sentence']))}</td></tr>"
        for rank, item in enumerate(top_items, start=1)
    ]
    return (
        f'<table class="top-table"><thead>{header}</thead>'
        f"<tbody>{''.join(body_rows)}</tbody></table>"
    )


def _attnlrp_rows(case_meta: Dict[str, Any]) -> List[str]:
    """Header lines for the optional ``attnlrp_*`` settings present in ``case_meta``."""
    rows: List[str] = []
    neg_handling = case_meta.get("attnlrp_neg_handling")
    if neg_handling:
        rows.append(_meta(f"FT-AttnLRP neg_handling: {escape(str(neg_handling))}"))
    norm_mode = case_meta.get("attnlrp_norm_mode")
    if norm_mode:
        rows.append(_meta(f"FT-AttnLRP norm_mode: {escape(str(norm_mode))}"))
    ratio_enabled = case_meta.get("attnlrp_ratio_enabled")
    if ratio_enabled is not None:
        rows.append(_meta(f"FT-AttnLRP ratio_enabled: {escape(str(bool(ratio_enabled)))}"))
    return rows


def _dataset_row(case_meta: Dict[str, Any]) -> str:
    """Header line with the (escaped) dataset name and example index."""
    dataset = escape(str(case_meta.get("dataset")))
    index = escape(str(case_meta.get("index")))
    return _meta(f"Dataset: {dataset} | index: {index}")


def _span_rows(case_meta: Dict[str, Any]) -> List[str]:
    """Header lines for the sink and thinking spans."""
    return [
        _meta(f"Sink span (gen idx): {escape(str(case_meta.get('sink_span')))}"),
        _meta(f"Thinking span (gen idx): {escape(str(case_meta.get('thinking_span')))}"),
    ]


def _base_score_row(case_meta: Dict[str, Any]) -> str:
    """Header line for a numeric ``base_score``; empty string otherwise."""
    base_score = case_meta.get("base_score")
    if isinstance(base_score, (int, float)):
        return _meta(f"Base score: {float(base_score):.6f}")
    return ""


def _generation_block(generation: Optional[str]) -> str:
    """Section showing the scored generation text; empty for a missing/blank value."""
    if not (isinstance(generation, str) and generation):
        return ""
    return _section("Generation (scored)", f'<div class="generation">{escape(generation)}</div>')


def _metrics_str(metrics: Dict[str, Any]) -> str:
    """Format the RISE / MAS / RISE+AP entries present in ``metrics`` (unescaped)."""
    parts: List[str] = []
    for key in ("RISE", "MAS", "RISE+AP"):
        if key not in metrics:
            continue
        value = metrics.get(key)
        if isinstance(value, (int, float)):
            parts.append(f"{key}: {float(value):.4f}")
        else:
            parts.append(f"{key}: {value}")
    return " | ".join(parts)


def _rank_str(rank_order: Sequence[Any]) -> str:
    """Comma-separated integer ranking, or ``"N/A"`` when empty."""
    return ", ".join(str(int(x)) for x in rank_order) if rank_order else "N/A"


def _wrap_page(title: str, *body_parts: str) -> str:
    """Assemble a standalone HTML document from pre-rendered body fragments."""
    return (
        "<!DOCTYPE html>\n"
        '<html lang="en"><head><meta charset="utf-8">'
        f"<title>{escape(title)}</title>{_PAGE_STYLE}</head>"
        f"<body>{''.join(body_parts)}</body></html>\n"
    )


def _panel_title(panel_idx: int, mode: Any, panel_titles: Any) -> str:
    """Caption of panel ``panel_idx``: explicit non-None title first, else mode-derived."""
    if isinstance(panel_titles, list) and panel_idx < len(panel_titles):
        title = panel_titles[panel_idx]
        if title is not None:
            return str(title)
    if mode in _HOP_TITLE_MODES:
        return f"Hop {panel_idx}"
    if mode == "ifr_all_positions_output_only":
        return f"IFR output-only panel {panel_idx}"
    if mode == "ifr_all_positions":
        return f"IFR all-positions panel {panel_idx}"
    if mode == "attnlrp":
        return "AttnLRP (sink-span aggregate)"
    return "IFR (sink-span aggregate)"


def _mode_label(mode: Any) -> str:
    """Human-readable label for ``mode``; unknown modes are shown verbatim."""
    if isinstance(mode, str) and mode in _MODE_LABELS:
        return _MODE_LABELS[mode]
    return str(mode)


def _view_row(mode: Any, case_meta: Dict[str, Any]) -> tuple:
    """Return the ``(key, value)`` pair describing the attribution view for ``mode``."""
    if mode in ("ft", "ifr_in_all_gen", "ft_attnlrp"):
        return "Recursive hops", case_meta.get("n_hops")
    if mode in ("ifr", "ifr_all_positions", "ifr_all_positions_output_only"):
        return "IFR view", case_meta.get("ifr_view", "aggregate")
    if mode == "attnlrp":
        return "AttnLRP view", "ft_hop0_span_aggregate"
    return "View", "N/A"


def _panel_scale(entry: Dict[str, Any], scores: Sequence[float]) -> float:
    """Colormap scale of one token panel: robust quantile, then ``token_score_max``, then 1e-8."""
    scale = _robust_abs_max(scores)
    if scale <= 0:
        scale = float(entry.get("token_score_max") or 0.0)
    return scale if scale > 0 else 1e-8


def _token_view_html(view: Dict[str, Any], default_label: str, scores: Sequence[float], scale: float) -> str:
    """Render one token heatmap section for ``view`` with the given scores and scale."""
    tokens_html = _render_tokens(view.get("tokens") or [], scores, scale, view.get("roles") or [])
    return _section(view.get("label", default_label), f'<div class="tokens">{tokens_html}</div>')


def _sentence_abs_max(hops: Sequence[Dict[str, Any]], start: int, stop: Optional[int]) -> float:
    """Abs max of ``sentence_scores_raw[start:stop]`` across all hops (0.0 if none)."""
    return max(
        (_finite_abs_max((hop.get("sentence_scores_raw") or [])[start:stop]) for hop in hops),
        default=0.0,
    )


def render_case_html(
    case_meta: Dict[str, Any],
    *,
    token_view_raw: Dict[str, Any],
    token_view_prompt: Dict[str, Any],
    context: Optional[Dict[str, Any]] = None,
    hops_sent: Optional[Sequence[Dict[str, Any]]] = None,
) -> str:
    """Render a multi-panel token-attribution case study as an HTML page.

    Args:
        case_meta: Run metadata. Keys read: ``mode`` (default ``"ft"``),
            ``ifr_view`` (default ``"aggregate"``), ``panel_titles``,
            ``n_hops``, ``dataset``, ``index``, ``sink_span``,
            ``thinking_span``, ``thinking_ratios`` and optional ``attnlrp_*``.
        token_view_raw: Full (pre-trim) token view: ``tokens``, ``roles``,
            optional ``label`` and one ``hops`` entry per panel with
            ``token_scores``, ``total_mass`` and ``token_score_max``.
        token_view_prompt: Prompt-only token view with the same layout.
        context: Optional ``prompt_sentences`` / ``generation_sentences``; the
            sentence view is shown only when both this and ``hops_sent`` are
            non-empty.
        hops_sent: Optional per-panel ``sentence_scores_raw`` (prompt
            sentences first, then generation sentences) and ``top_sentences``.

    Returns:
        A complete HTML document.

    Raises:
        ValueError: if the two token views have different panel counts.
    """
    raw_hops = token_view_raw.get("hops", []) or []
    prompt_hops = token_view_prompt.get("hops", []) or []
    if len(raw_hops) != len(prompt_hops):
        raise ValueError(
            "token_view_raw and token_view_prompt must have the same number of panels: "
            f"raw={len(raw_hops)} prompt={len(prompt_hops)}"
        )

    sent_hops = list(hops_sent or [])
    ctx = context or {}
    has_sentence_view = bool(context) and bool(sent_hops)
    prompt_sentences = ctx.get("prompt_sentences") or []
    gen_sentences = ctx.get("generation_sentences") or []
    prompt_len = len(prompt_sentences) if has_sentence_view else 0
    # The sentence view shares one scale per side across all panels. It is an
    # abs max because rows are shaded by |score|; a signed max would turn
    # all-negative panels grey.
    prompt_max = gen_max = 0.0
    if has_sentence_view:
        prompt_max = _sentence_abs_max(sent_hops, 0, prompt_len)
        gen_max = _sentence_abs_max(sent_hops, prompt_len, None)

    mode = case_meta.get("mode", "ft")
    panel_titles = case_meta.get("panel_titles")
    q_label = _quantile_label()

    hop_sections: List[str] = []
    for hop_idx, (raw_entry, prompt_entry) in enumerate(zip(raw_hops, prompt_hops)):
        # Token heatmaps are scaled per panel and per view.
        raw_scores = raw_entry.get("token_scores") or []
        prompt_scores = prompt_entry.get("token_scores") or []
        raw_scale = _panel_scale(raw_entry, raw_scores)
        prompt_scale = _panel_scale(prompt_entry, prompt_scores)
        raw_mass = float(raw_entry.get("total_mass") or 0.0)
        prompt_mass = float(prompt_entry.get("total_mass") or 0.0)
        mass_line = (
            f"raw mass: {raw_mass:.6f} | raw scale({q_label} abs): {raw_scale:.6g}"
            " &nbsp;|&nbsp; "
            f"prompt mass: {prompt_mass:.6f} | prompt scale({q_label} abs): {prompt_scale:.6g}"
        )
        parts = [
            f'<div class="panel-title">{escape(_panel_title(hop_idx, mode, panel_titles))}</div>',
            _meta(mass_line),
            _token_view_html(token_view_raw, "Pre-trim token-level heatmap (full)", raw_scores, raw_scale),
            _token_view_html(
                token_view_prompt, "Prompt-only token-level heatmap", prompt_scores, prompt_scale
            ),
        ]
        if has_sentence_view and hop_idx < len(sent_hops):
            # Sentence view is not used by the current case-study runner; kept for completeness.
            hop = sent_hops[hop_idx]
            sent_scores = hop.get("sentence_scores_raw") or []
            parts.append(
                _render_sentence_list("Prompt sentences", prompt_sentences, sent_scores[:prompt_len], prompt_max)
            )
            parts.append(
                _render_sentence_list("Generation sentences", gen_sentences, sent_scores[prompt_len:], gen_max)
            )
            parts.append(_section("Top sentences (all)", _render_top_table(hop.get("top_sentences") or [])))
        hop_sections.append(f'<div class="panel">{"".join(parts)}</div>')

    thinking_ratios = case_meta.get("thinking_ratios") or []
    ratios_str = ", ".join(f"{r:.4f}" for r in thinking_ratios) if thinking_ratios else "N/A"
    mode_label = _mode_label(mode)
    view_key, view_val = _view_row(mode, case_meta)
    header_rows = [
        f"<h1>{escape(mode_label)} Case Study</h1>",
        _dataset_row(case_meta),
        *_span_rows(case_meta),
        _meta(f"Panels: {len(raw_hops)}"),
        _meta(f"{escape(str(view_key))}: {escape(str(view_val))}"),
        _meta(f"Token scale: per-panel per-view {q_label}(|score|)"),
        *_attnlrp_rows(case_meta),
        _meta(f"Thinking ratios: {escape(ratios_str)}"),
    ]
    header = f'<div class="header">{"".join(header_rows)}</div>'
    return _wrap_page(f"{mode_label} Case Study", header, *hop_sections)


def _render_sentence_spans(title: str, sentences: Sequence[str], scores: Sequence[float]) -> str:
    """Render sentences inline, shaded by ``|score|`` against the section's abs max."""
    max_abs = _finite_abs_max(scores)
    spans: List[str] = []
    for idx, sentence in enumerate(sentences):
        score = float(scores[idx]) if idx < len(scores) else 0.0
        style = _color_for_score(abs(score), max_abs)
        spans.append(
            f'<span class="sent-span" style="{style}" title="#{idx} score={score:.6g}">'
            f"{escape(str(sentence))}</span>"
        )
    return _section(title, f'<div class="sentences">{" ".join(spans)}</div>')


def _render_token_spans(title: str, tokens: Sequence[str], scores: Sequence[float]) -> str:
    """Render tokens inline, shaded by ``|score|`` against the section's abs max."""
    max_abs = _finite_abs_max(scores)
    spans: List[str] = []
    for idx, tok in enumerate(tokens):
        score = float(scores[idx]) if idx < len(scores) else 0.0
        style = _color_for_score(abs(score), max_abs)
        spans.append(
            f'<span class="tok" style="{style}" title="idx={idx} score={score:.6g}">'
            f"{escape(str(tok))}</span>"
        )
    return _section(title, f'<div class="tokens">{"".join(spans)}</div>')


def _method_label(case_meta: Dict[str, Any]) -> str:
    """Display name of the attribution method (label, then id, then a placeholder)."""
    return case_meta.get("attr_method_label") or case_meta.get("attr_method") or "Unknown method"


def _mas_header(case_meta: Dict[str, Any], title: str, count_rows: Sequence[str]) -> str:
    """Shared header of the MAS pages; ``count_rows`` are inserted after the span lines."""
    rows = [
        f"<h1>{escape(title)}</h1>",
        _dataset_row(case_meta),
        _meta(f"Attribution method: {escape(str(case_meta.get('attr_method')))}"),
        *_span_rows(case_meta),
        *count_rows,
        *_attnlrp_rows(case_meta),
        _base_score_row(case_meta),
    ]
    return f'<div class="header">{"".join(rows)}</div>'


def _mas_panel_html(panel: Dict[str, Any], views_html: str) -> str:
    """Wrap pre-rendered MAS views with the panel's label, metrics and rank order."""
    label = panel.get("variant_label") or panel.get("panel_label") or panel.get("variant") or "Panel"
    metrics_str = _metrics_str(panel.get("metrics") or {})
    rank_str = _rank_str(panel.get("sorted_attr_indices") or [])
    return (
        '<div class="panel">'
        f'<div class="panel-title">{escape(str(label))}</div>'
        f"{_meta(escape(metrics_str))}"
        f"{views_html}"
        f"{_meta('Rank order: ' + escape(rank_str))}"
        "</div>"
    )


def render_mas_sentence_html(
    case_meta: Dict[str, Any],
    *,
    prompt_sentences: Sequence[str],
    panels: Sequence[Dict[str, Any]],
    generation: Optional[str] = None,
) -> str:
    """Render MAS sentence-level diagnostics (attribution / pure ablation / guided marginal).

    Each panel may carry ``attr_weights``, ``pure_sentence_deltas_raw`` and
    ``guided_sentence_deltas_raw`` (falling back to ``sentence_deltas_raw``),
    plus ``metrics`` and ``sorted_attr_indices``. Each view is scaled to its
    own abs max.
    """
    title = f"MAS Sentence Study ({_method_label(case_meta)})"
    header = _mas_header(case_meta, title, [_meta(f"Panels: {len(panels)}")])
    panel_sections: List[str] = []
    for panel in panels:
        attr_weights = panel.get("attr_weights") or []
        pure_deltas = panel.get("pure_sentence_deltas_raw") or []
        guided_deltas = panel.get("guided_sentence_deltas_raw") or panel.get("sentence_deltas_raw") or []
        views_html = (
            _render_sentence_spans("Method attribution (sentence weights)", prompt_sentences, attr_weights)
            + _render_sentence_spans("Pure sentence ablation (base − score)", prompt_sentences, pure_deltas)
            + _render_sentence_spans(
                "Attribution-guided MAS marginal (path deltas)", prompt_sentences, guided_deltas
            )
        )
        panel_sections.append(_mas_panel_html(panel, views_html))
    return _wrap_page(title, header, _generation_block(generation), *panel_sections)


def render_mas_token_html(
    case_meta: Dict[str, Any],
    *,
    prompt_tokens: Sequence[str],
    panels: Sequence[Dict[str, Any]],
    generation: Optional[str] = None,
) -> str:
    """Render MAS token-level diagnostics (attribution weights + guided marginal deltas).

    Each panel may carry ``attr_weights`` and ``token_deltas_raw`` plus
    ``metrics`` and ``sorted_attr_indices``. Each view is scaled to its own
    abs max.
    """
    title = f"MAS Token Study ({_method_label(case_meta)})"
    header = _mas_header(
        case_meta,
        title,
        [_meta(f"Prompt tokens: {len(prompt_tokens)}"), _meta(f"Panels: {len(panels)}")],
    )
    panel_sections: List[str] = []
    for panel in panels:
        attr_weights = panel.get("attr_weights") or []
        guided_deltas = panel.get("token_deltas_raw") or []
        views_html = _render_token_spans(
            "Method attribution (token weights)", prompt_tokens, attr_weights
        ) + _render_token_spans("Attribution-guided MAS marginal (path deltas)", prompt_tokens, guided_deltas)
        panel_sections.append(_mas_panel_html(panel, views_html))
    return _wrap_page(title, header, _generation_block(generation), *panel_sections)