"""
benchmark_interaction.py — Interactive cross-modal visualization for the benchmark tab.
Creates a self-contained HTML component that shows:
- Image with real UnSAM segment outlines (clickable via hidden canvas label map)
- Caption tokens below the image (clickable, colored by Shapley value)
- Arrows/lines connecting image regions to their most interacting tokens
- Click a region → highlights linked tokens; click a token → highlights linked regions
"""
from __future__ import annotations
import json
import uuid
from html import escape
from typing import Any, Dict, List, Optional, Sequence, Tuple
def _value_to_color(value: float, max_abs: float, single_color: bool = False) -> str:
if max_abs <= 0:
return "rgb(225, 225, 223)"
norm = min(1.0, abs(value) / max_abs)
if single_color:
base = (225, 225, 223)
target = (52, 102, 177) # influence neutral blue
elif value >= 0:
base = (225, 225, 223)
target = (1, 109, 1)
else:
base = (225, 225, 223)
target = (221, 19, 19)
r = int(round(base[0] + (target[0] - base[0]) * norm))
g = int(round(base[1] + (target[1] - base[1]) * norm))
b = int(round(base[2] + (target[2] - base[2]) * norm))
return f"rgb({r}, {g}, {b})"
def _value_to_rgba(value: float, max_abs: float, alpha: float = 0.5) -> str:
if max_abs <= 0:
return "rgba(200, 200, 200, 0.1)"
norm = min(1.0, abs(value) / max_abs)
if value >= 0:
return f"rgba(1, 109, 1, {alpha * norm:.2f})"
else:
return f"rgba(221, 19, 19, {alpha * norm:.2f})"
def create_benchmark_interaction_html(
image_b64: str,
clip_summary: Optional[Dict[str, Any]],
vllm_logprob: Optional[Dict[str, Any]],
caption: str = "",
all_cross_modal_pairs: Optional[List[Dict[str, Any]]] = None,
segmap_b64: str = "",
overlay_b64: str = "",
segment_bboxes: Optional[List[Optional[Dict[str, float]]]] = None,
label_map_b64: str = "",
image_width: int = 0,
image_height: int = 0,
title: str = "Cross-Modal Interaction View",
method_label: str = "Shapley",
) -> str:
"""
Build a self-contained HTML + JS component for the benchmark tab.
Shows:
- Image panel (left) with real UnSAM segment outlines and clickable regions
- Token panel (right) with clickable colored tokens
- SVG arrow layer connecting regions to tokens on click
- Cross-modal interaction list
When segment_bboxes and label_map_b64 are provided, uses canvas-based
pixel-level click detection for accurate segment selection matching
the real UnSAM segmentation. Falls back to horizontal strips otherwise.
"""
if not image_b64 and not clip_summary:
return "
No data available. Select an example.
"
view_id = f"bm-{uuid.uuid4().hex[:8]}"
is_influence = (method_label or "").lower() == "influence"
# Prepare image URL
img_url = image_b64
if img_url and not img_url.startswith("data:"):
img_url = f"data:image/png;base64,{img_url}"
# Extract CLIP data
regions = [] # {index, label, value}
tokens = [] # {index, label, value}
cross_interactions = [] # {seg, tok, value}
n_segs = 0
if clip_summary:
raw_items = clip_summary.get("image_region_values", [])
total_regions = len(raw_items)
grid_guess = int(round(total_regions ** 0.5))
looks_like_patch_grid = (grid_guess * grid_guess == total_regions) and all(
str(it.get("label", "")).startswith("patch_")
or str(it.get("label", "")).isdigit()
for it in raw_items
)
for item in raw_items:
# Resolve the segment number:
# seg_6 → 6 (UnSAM)
# patch_1_2 → 1*grid+2 (raw patch-grid; assumes grid=sqrt(n))
# "7" → 7 - 1 = 6 (post-rename patch-grid)
raw_label = str(item["label"])
seg_num = n_segs # fallback: sequential
if raw_label.startswith("seg_"):
try:
seg_num = int(raw_label.split("_", 1)[1])
except (ValueError, IndexError):
pass
elif raw_label.startswith("patch_"):
try:
_, r_str, c_str = raw_label.split("_", 2)
seg_num = int(r_str) * grid_guess + int(c_str)
except (ValueError, IndexError):
pass
elif raw_label.isdigit():
try:
seg_num = int(raw_label) - 1
except ValueError:
pass
# Display label: in patch-grid mode always show "1".."N" in reading
# order so the overlay doesn't leak raw "patch_r_c" text.
display_label = str(seg_num + 1) if looks_like_patch_grid else raw_label
regions.append({
"index": seg_num,
"label": display_label,
"value": item["value"],
"type": "segment",
})
n_segs += 1
# Build a lookup from CLIP token labels to values.
# Also build a ##-stripped version for subword matching.
clip_tok_values: Dict[str, float] = {}
for item in clip_summary.get("token_values", []):
tok_label = item["label"].replace("tok:", "")
clip_tok_values[tok_label] = item["value"]
# Collect tokens with ## stripped for substring matching
clip_tok_set = set(clip_tok_values.keys())
# Include ALL words from the full caption, not just CLIP's top-k
if caption:
words = caption.replace(".", " .").replace(",", " ,").replace("(", " (").replace(")", " )").split()
for i, word in enumerate(words):
value = clip_tok_values.get(word, 0.0)
matched_tok = word if value != 0.0 else None
if value == 0.0:
value = clip_tok_values.get(word.lower(), 0.0)
matched_tok = word.lower() if value != 0.0 else None
if value == 0.0:
# Sum all matching subword tokens (strip ## before matching)
total = 0.0
for tok in clip_tok_set:
tok_clean = tok.lstrip("#")
if len(tok_clean) >= 3 and tok_clean.lower() in word.lower():
total += clip_tok_values[tok]
matched_tok = tok
value = total
tokens.append({
"index": i,
"label": word,
"value": value,
"_matched_tok": matched_tok,
})
else:
for i, item in enumerate(clip_summary.get("token_values", [])):
tok_label = item["label"].replace("tok:", "").lstrip("#")
tokens.append({
"index": i,
"label": tok_label,
"value": item["value"],
"_matched_tok": tok_label,
})
# Use ALL cross-modal pairs if provided, else fall back to top-5.
# Map subword token labels to whole caption words.
from .medical_charts import _tok_to_word
cross_source = all_cross_modal_pairs or clip_summary.get("cross_modal_interactions", [])
def _seg_display(seg_raw: str) -> str:
# Normalize cross-pair segment labels the same way we normalized
# region labels above — otherwise arrows can't match regions.
s = str(seg_raw)
if looks_like_patch_grid and s.startswith("patch_"):
try:
_, rr, cc = s.split("_", 2)
return str(int(rr) * grid_guess + int(cc) + 1)
except (ValueError, IndexError):
return s
return s
for item in cross_source:
cross_interactions.append({
"seg": _seg_display(item["pair"][0]),
"tok": _tok_to_word(item["pair"][1], caption) if caption else item["pair"][1].replace("tok:", "").lstrip("#"),
"value": item["value"],
})
# Determine if we have real segment bounding boxes
has_real_bboxes = (
segment_bboxes is not None
and len(segment_bboxes) == n_segs
and any(b is not None for b in segment_bboxes)
)
has_label_map = bool(label_map_b64)
# Build region overlays: real bboxes if available, else horizontal strips
# When a pixel-accurate label map is available, skip the rectangular div
# overlays entirely — the segmap image already shows real segment contours.
# We still render small labels at segment centers for identification.
region_overlays_html = ""
max_abs_r = max((abs(r["value"]) for r in regions), default=1.0) or 1.0
if has_label_map and has_real_bboxes and n_segs > 0:
# Label-map mode: no rectangular divs, just center labels.
# When segments look like a patch grid (perfect-square count and numeric
# or patch-style labels), show reading-order numbers 1..N so the overlay
# reads left-to-right top-to-bottom without the raw "patch_r_c" noise.
grid = int(round(n_segs ** 0.5))
is_patch_grid = (grid * grid == n_segs) and all(
str(r["label"]).isdigit() or str(r["label"]).startswith("patch_")
for r in regions
)
for r in regions:
r_idx = r["index"]
bbox = segment_bboxes[r_idx] if r_idx < len(segment_bboxes) else None
if bbox is None:
continue
r_label = escape(str(r_idx + 1) if is_patch_grid else r["label"])
r_value = r["value"]
region_overlays_html += (
f""
f"{r_label}"
)
elif has_real_bboxes and n_segs > 0:
# No label map — use rectangular bounding box divs as fallback
for r in regions:
r_idx = r["index"]
bbox = segment_bboxes[r_idx] if r_idx < len(segment_bboxes) else None
if bbox is None:
continue
r_label = escape(r["label"])
r_value = r["value"]
if is_influence:
norm = min(1.0, abs(r_value) / max_abs_r) if max_abs_r else 0.0
bg = f"rgba(52,102,177,{0.25*norm:.2f})"
border_color = "rgba(52,102,177,0.8)"
val_str = f"{r_value:.2f}"
title_str = f"{r_label}: {r_value:.3f}"
else:
bg = _value_to_rgba(r_value, max_abs_r, 0.25)
border_color = "rgba(1,109,1,0.8)" if r_value >= 0 else "rgba(200,40,40,0.8)"
val_str = f"{r_value:+.2f}"
title_str = f"{r_label}: {r_value:+.3f}"
region_overlays_html += (
f"