Stephentao-30
Public Mode: force Bar View Plot to full container width
38ffa3d
import base64
import copy
import inspect
import io
import json
import os
import random
import re
from itertools import combinations
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional
import requests # new
# Base URL of the attribution backend; override with ATTRLLM_BACKEND_URL.
BACKEND_URL = os.getenv("ATTRLLM_BACKEND_URL", "http://127.0.0.1:8000")
# Temp directory for Gradio: an existing GRADIO_TEMP_DIR wins, otherwise use
# ".gradio_tmp" under the current working directory.
_DEFAULT_GRADIO_DIR = Path(os.environ.get("GRADIO_TEMP_DIR", Path.cwd() / ".gradio_tmp"))
# Export the choice without overriding a value the user already set. This runs
# before `import gradio` further down in this module.
os.environ.setdefault("GRADIO_TEMP_DIR", str(_DEFAULT_GRADIO_DIR))
# Create the directory at import time so it exists before it is first used.
_DEFAULT_GRADIO_DIR.mkdir(parents=True, exist_ok=True)
def _get_request_timeout() -> float:
value = os.getenv("ATTRLLM_REQUEST_TIMEOUT")
if not value:
return 900.0
try:
return float(value)
except ValueError:
return 900.0
def _env_flag(name: str, default: bool = False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "y", "on"}
def _is_hf_spaces() -> bool:
return bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))
def _supports_kwarg(callable_obj, kwarg_name: str) -> bool:
"""Return whether a callable appears to accept a named keyword argument."""
try:
return kwarg_name in inspect.signature(callable_obj).parameters
except (TypeError, ValueError):
return False
def _public_only_mode() -> bool:
    """Return the ATTRLLM_PUBLIC_ONLY flag (off by default).

    Defaulting to off keeps the text tab visible on Spaces unless the flag is
    explicitly set.
    """
    return _env_flag("ATTRLLM_PUBLIC_ONLY", default=False)
def _mm_only_mode() -> bool:
    """Return the ATTRLLM_MM_ONLY flag (off by default)."""
    return _env_flag("ATTRLLM_MM_ONLY", default=False)
def _show_auxiliary_tabs() -> bool:
    """Return the ATTRLLM_SHOW_AUX_TABS flag (off by default)."""
    return _env_flag("ATTRLLM_SHOW_AUX_TABS", default=False)
def _public_results_file(
    dataset_key: str,
    ex_id: str,
    scalarizer: str,
    level: str,
    method: str,
) -> Path:
    """Build the path of a public-mode result file.

    Layout: ``<results>/public/<dataset_key>/<ex_id>/<scalarizer>/<level>/<method>.json``.
    The path is only constructed here; its existence is not checked.
    """
    return _get_results_dir().joinpath(
        "public",
        dataset_key,
        ex_id,
        scalarizer,
        level,
        f"{method}.json",
    )
def _reference_results_file(
    model_size: str,
    dataset_key: str,
    ex_id: str,
    scalarizer: str,
    level: str,
) -> Path:
    """Build the path of a reference-answer result file.

    Layout: ``<results>/reference_answer/<model_size>/<dataset_key>/<ex_id>/<scalarizer>/<level>.json``.
    The path is only constructed here; its existence is not checked.
    """
    return _get_results_dir().joinpath(
        "reference_answer",
        model_size,
        dataset_key,
        ex_id,
        scalarizer,
        f"{level}.json",
    )
def _find_available_model_size(
    dataset_key: str,
    ex_id: str,
    scalarizer: str,
    level: str,
) -> Optional[str]:
    """Return the first model size that has a reference result file on disk.

    Sizes are probed in the order large, medium, small; ``None`` is returned
    when no file exists for any of them.
    """
    return next(
        (
            size
            for size in ("large", "medium", "small")
            if _reference_results_file(size, dataset_key, ex_id, scalarizer, level).exists()
        ),
        None,
    )
# Fallback order when the requested (scalarizer, level) is not present
# (e.g. on an HF Space that ships only partial results).
# Each entry is a (scalarizer, feature_level) pair; _find_any_available_result
# walks the list in this order for every model size.
_FALLBACK_SCALARIZER_LEVELS: List[Tuple[str, str]] = [
    ("geomean_jointprob", "word"),
    ("semantic_similarity", "word"),
    ("geomean_jointprob", "sentence"),
    ("semantic_similarity", "sentence"),
    ("geomean_jointprob", "paragraph"),
    ("semantic_similarity", "paragraph"),
]
def _find_any_available_result(
    dataset_key: str,
    ex_id: str,
    get_res: Any,
    method: str = "shapley",
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[Dict]]:
    """Search the fallback combinations for the first usable result.

    ``get_res`` is called as
    ``get_res(size, dataset_key, ex_id, scalarizer=..., feature_level=...)``
    for each model size (small, medium, large) and each pair in
    ``_FALLBACK_SCALARIZER_LEVELS``. A result is usable when its ``method``
    entry is non-empty and carries "features" or "heatmap". Any exception
    raised while probing a combination is treated as "not available".

    Returns ``(size, scalarizer, level, result)`` for the first hit, or four
    ``None`` values when nothing is found.
    """
    for size in ("small", "medium", "large"):
        for scalarizer, level in _FALLBACK_SCALARIZER_LEVELS:
            try:
                result = get_res(size, dataset_key, ex_id, scalarizer=scalarizer, feature_level=level) or {}
                payload = result.get(method, {})
                usable = bool(payload) and bool(payload.get("features") or payload.get("heatmap"))
            except Exception:
                continue
            if usable:
                return (size, scalarizer, level, result)
    return (None, None, None, None)
def _parse_sparse_key(raw_key: str) -> Tuple[int, ...]:
key = str(raw_key).strip()
if not key:
return ()
return tuple(int(part) for part in key.split(",") if part != "")
def _normalize_public_payload_fallback(data: Dict[str, Any], method: str, top_k: int = 10) -> Dict[str, Any]:
"""Convert your JSON (features list + meta + mobius_dict) to UI display format. mobius_dict can be empty."""
if not isinstance(data, dict):
return {}
features = data.get("features")
mobius_raw = data.get("mobius_dict") if isinstance(data.get("mobius_dict"), dict) else {}
if not isinstance(features, list) or not features:
return {}
method = (method or "shapley").lower()
if method not in {"shapley", "banzhaf", "influence"}:
method = "shapley"
mobius_sparse: Dict[Tuple[int, ...], float] = {}
for key, raw_val in mobius_raw.items():
try:
val = float(raw_val)
except Exception:
continue
try:
loc = _parse_sparse_key(str(key))
except Exception:
continue
mobius_sparse[tuple(sorted(loc))] = val
token_scores: Dict[str, float] = {}
index_scores: Dict[int, float] = {}
pairwise_acc: Dict[Tuple[int, int], float] = {}
if mobius_sparse and mobius_to_shapley is not None:
if method == "shapley":
singleton_dict = mobius_to_shapley(mobius_sparse)
pair_list = shapley_interactions(mobius_sparse, order=2, top_k=top_k) or []
elif method == "banzhaf":
singleton_dict = mobius_to_banzhaf(mobius_sparse)
pair_list = banzhaf_interactions(mobius_sparse, order=2, top_k=top_k) or []
elif mobius_to_influence is not None and influence_interactions is not None:
singleton_dict = mobius_to_influence(mobius_sparse)
pair_list = influence_interactions(mobius_sparse, order=2, top_k=top_k) or []
else:
singleton_dict = {}
pair_list = []
for loc, val in singleton_dict.items():
if len(loc) != 1:
continue
idx = int(loc[0])
if 0 <= idx < len(features):
feat_name = str(features[idx])
val_f = float(val)
token_scores[feat_name] = token_scores.get(feat_name, 0.0) + val_f
index_scores[idx] = index_scores.get(idx, 0.0) + val_f
for loc, val in pair_list:
if len(loc) != 2:
continue
i, j = int(loc[0]), int(loc[1])
if 0 <= i < len(features) and 0 <= j < len(features):
key = (i, j) if i <= j else (j, i)
pairwise_acc[key] = float(val)
else:
# Best-effort fallback when attribution helpers are unavailable.
for loc, val in mobius_sparse.items():
k = len(loc)
if k == 0:
continue
if method == "shapley":
sw = 1.0 / float(k)
elif method == "banzhaf":
sw = 1.0 / float(2 ** (k - 1))
else:
sw = 1.0 / float(k)
for idx in loc:
if 0 <= idx < len(features):
feat_name = str(features[idx])
token_scores[feat_name] = token_scores.get(feat_name, 0.0) + sw * val
index_scores[idx] = index_scores.get(idx, 0.0) + sw * val
if k >= 2:
if method == "shapley":
pw = 1.0 / float(k - 1)
elif method == "banzhaf":
pw = 1.0 / float(2 ** (k - 2))
else:
pw = 1.0 / float(k - 1)
for i, j in combinations(sorted(loc), 2):
pairwise_acc[(i, j)] = pairwise_acc.get((i, j), 0.0) + pw * val
unique_feature_labels = [str(x) for x in features]
sorted_pairs = sorted(pairwise_acc.items(), key=lambda kv: abs(kv[1]), reverse=True)
if top_k and top_k > 0:
sorted_pairs = sorted_pairs[:top_k]
pairwise = {
"%s|%s" % (unique_feature_labels[i], unique_feature_labels[j]): float(v)
for (i, j), v in sorted_pairs
if 0 <= i < len(unique_feature_labels) and 0 <= j < len(unique_feature_labels)
}
pairwise_interactions = [
{"features": [unique_feature_labels[i], unique_feature_labels[j]], "value": float(v)}
for (i, j), v in sorted_pairs
if 0 <= i < len(unique_feature_labels) and 0 <= j < len(unique_feature_labels)
]
normalized = dict(data)
normalized["token_scores"] = token_scores
normalized["pairwise"] = pairwise
normalized["pairwise_interactions"] = pairwise_interactions
normalized["features"] = [
{"feature": str(features[i]), "value": float(index_scores.get(i, 0.0)), "index": i}
for i in range(len(features))
]
normalized["feature_texts"] = [str(x) for x in features]
return normalized
def _public_get_model_answer_short_from_file(
    model_size: str,
    dataset: str,
    ex_id: str,
    scalarizer: str = "geomean_jointprob",
    feature_level: str = "word",
) -> Dict[str, Any]:
    """Load the model_answer_short payload (the wrong-answer attribution) used
    by the Public Mode dual-heatmap branch.

    The file is read from
    ``<results>/model_answer_short/<model_size>/<dataset>/<ex_id>/<scalarizer>/<feature_level>.json``
    and normalized once per method. Returns a dict with "shapley", "banzhaf",
    "influence" and "meta" keys, or ``{}`` when the file is missing, cannot be
    read as a JSON object, or normalizes to nothing for every method.
    """
    path = _get_results_dir().joinpath(
        "model_answer_short",
        model_size,
        dataset,
        ex_id,
        scalarizer,
        f"{feature_level}.json",
    )
    if not path.exists():
        return {}
    try:
        with path.open("r", encoding="utf-8") as handle:
            data = json.load(handle)
    except Exception:
        return {}
    if not isinstance(data, dict):
        return {}

    # Each method gets its own deep copy of the raw payload.
    per_method = {
        name: _normalize_public_payload_fallback(copy.deepcopy(data), name)
        for name in ("shapley", "banzhaf", "influence")
    }
    if not any(per_method.values()):
        return {}
    return {
        **per_method,
        "meta": {
            "dataset": dataset,
            "example_id": ex_id,
            "model_size": model_size,
            "target_mode": "model_answer_short",
            "source_layout": "results/model_answer_short/{model_size}/{dataset}/{example_id}/{scalarizer}/{feature_level}.json",
        },
    }
def _public_get_result_from_file(
    model_size: str,
    dataset: str,
    ex_id: str,
    scalarizer: Optional[str] = None,
    feature_level: Optional[str] = None,
) -> Dict[str, Any]:
    """Load a reference_answer result from disk when
    loader.results.get_result_by_id is unavailable (e.g. on a Space).

    The requested ``feature_level`` is tried first, followed by the remaining
    levels among word / sentence / paragraph. For each level the file is
    optionally fetched through ``loader.results._maybe_download_from_space``
    (when it is missing locally or ``SPACE_ID`` is set), then parsed and
    normalized per method. Returns the first level that yields data as a dict
    with "shapley", "banzhaf", "influence" and "meta" keys, otherwise ``{}``.
    """
    scalarizer = (scalarizer or "").strip()
    feature_level = (feature_level or "").strip()
    if not (scalarizer and feature_level):
        return {}

    candidate_levels = [feature_level]
    candidate_levels.extend(
        lvl for lvl in ("word", "sentence", "paragraph") if lvl != feature_level
    )
    for lvl in candidate_levels:
        path = _reference_results_file(model_size, dataset, ex_id, scalarizer, lvl)
        if not path.exists() or os.getenv("SPACE_ID"):
            # Download is best-effort: any failure leaves the local path in place.
            try:
                from loader.results import _maybe_download_from_space
                path = _maybe_download_from_space(path, force_download=True) or path
            except Exception:
                pass
        if not path.exists():
            continue
        try:
            with path.open("r", encoding="utf-8") as handle:
                data = json.load(handle)
        except Exception:
            continue
        if not isinstance(data, dict):
            continue
        # Stored JSON: features (list) + meta + mobius_dict (can be empty).
        # It is always converted to the UI format, once per method.
        per_method = {
            name: _normalize_public_payload_fallback(copy.deepcopy(data), name)
            for name in ("shapley", "banzhaf", "influence")
        }
        if not any(per_method.values()):
            continue
        return {
            **per_method,
            "meta": {
                "dataset": dataset,
                "example_id": ex_id,
                "model_size": model_size,
                "source_layout": "results/reference_answer/{model_size}/{dataset}/{example_id}/{scalarizer}/{feature_level}.json",
            },
        }
    return {}
# Dataset key -> CSV file name. _fallback_load_dataset resolves the file name
# against the directory returned by _fallback_datasets_dir().
_FALLBACK_DATASET_FILES: Dict[str, str] = {
    "bar_exam": "BarExam_qa.csv",
    "causal_judgment": "bbh_causal_judgement.csv",
    "snarks": "bbh_snarks.csv",
    "bbq_disamb": "BBQ_disamb.csv",
    "cnn_dailymail": "CNN_dailymail.csv",
    "drop": "drop.csv",
    "esnli": "eSNLI.csv",
    "fever": "fever.csv",
    "hotpot_qa": "hotpot_qa.csv",
    "medical_qa": "medical_qa.csv",
}
def _fallback_datasets_dir() -> Path:
    """Return the resolved ``datasets`` directory under the repository root."""
    return _REPO_ROOT.joinpath("datasets").resolve()
def _fallback_pick_first_nonempty(raw: Dict[str, str], candidates: List[str]) -> str:
for c in candidates:
val = raw.get(c)
if val is not None and str(val).strip() != "":
return str(val)
return ""
def _fallback_load_dataset(dataset_key: str, max_rows: int = 10) -> List[Dict[str, str]]:
    """Read up to ``max_rows`` examples for ``dataset_key`` from its CSV file.

    The file name comes from ``_FALLBACK_DATASET_FILES`` and is resolved
    against ``_fallback_datasets_dir()``. Every row becomes a dict with "id",
    "context" and "prompt", plus "answer" when a non-blank answer column is
    found. Column names differ between datasets, so each field is taken from
    the first non-blank column in a list of known aliases.

    Returns ``[]`` when ``max_rows`` is not positive, the dataset key is
    unknown, or the file does not exist.
    """
    import csv

    # Without this guard the loop below would still emit one row, because the
    # limit is only checked after a row has been appended.
    if max_rows <= 0:
        return []
    filename = _FALLBACK_DATASET_FILES.get(dataset_key)
    if not filename:
        return []
    path = _fallback_datasets_dir() / filename
    if not path.exists():
        return []

    # Alias lists are loop-invariant, so build them once.
    context_columns = [
        "Context", "context",
        "passage", "article", "story", "premise",
        "paragraph", "document", "sentence1", "sent1", "background",
    ]
    prompt_columns = [
        "Prompt", "prompt",
        "question", "input", "query",
        "sentence2", "sent2", "hypothesis",
        "qa_question", "title",
    ]
    answer_columns = [
        "Answer", "answer",
        "target", "gold", "label", "output", "reference",
        "highlights",
    ]

    rows: List[Dict[str, str]] = []
    with path.open("r", encoding="utf-8", errors="replace", newline="") as f:
        for i, raw in enumerate(csv.DictReader(f), start=1):
            # Fall back to a positional id when the row carries none.
            ex_id = raw.get("id") or raw.get("example_id") or raw.get("uid") or f"example_{i}"
            ex = {
                "id": str(ex_id),
                "context": _fallback_pick_first_nonempty(raw, context_columns),
                "prompt": _fallback_pick_first_nonempty(raw, prompt_columns),
            }
            answer = _fallback_pick_first_nonempty(raw, answer_columns)
            if answer:
                ex["answer"] = answer
            rows.append(ex)
            if len(rows) >= max_rows:
                break
    return rows
# Request timeout in seconds, resolved once at import time.
REQUEST_TIMEOUT = _get_request_timeout()
# (display label, scalarizer key) pairs for the full scalarizer list.
SCALARIZER_CHOICES = [
    ("Semantic Similarity (y vs y_S)", "semantic_similarity"),
    ("LogProb", "logprob"),
    ("JointProb", "jointprob"),
    ("GeoMean JointProb", "geomean_jointprob"),
    ("Half SimLog", "half_simlog"),
]
# Reduced (display label, scalarizer key) list; note that the "Perplexity"
# label maps to the "geomean_jointprob" key.
PUBLIC_SCALARIZER_CHOICES = [
    ("Semantic Similarity", "semantic_similarity"),
    ("Perplexity", "geomean_jointprob"),
]
# Dataset key -> human-readable name. The keys match _FALLBACK_DATASET_FILES.
DATASET_DISPLAY_LABELS = {
    "bar_exam": "Bar Exam Questions",
    "bbq_disamb": "BBQ Disambiguation",
    "causal_judgment": "Causal Judgment",
    "cnn_dailymail": "CNN / DailyMail Summaries",
    "drop": "DROP Reading Comprehension",
    "esnli": "e-SNLI Natural Language Inference",
    "fever": "FEVER Fact Checking",
    "hotpot_qa": "HotpotQA Multi-hop Questions",
    "medical_qa": "Medical Questions",
    "snarks": "Snarks",
}
import sys
# Repository root: the parent of the directory that contains this file.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root first on sys.path (once) so the absolute imports
# further down (`loader.*`, `attribution.*`) can be resolved from it.
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
def _get_results_dir() -> Path:
"""Resolve results directory: env, repo root, or on HF Space fallback to cwd/results."""
env_dir = os.getenv("ATTRLLM_RESULTS_DIR")
if env_dir:
return Path(env_dir).resolve()
default = (_REPO_ROOT / "results").resolve()
if default.exists():
return default
if _is_hf_spaces():
cwd_results = (Path.cwd() / "results").resolve()
if cwd_results.exists():
return cwd_results
return default
import gradio as gr
from PIL import Image
from .components.model_selector import (
create_model_selector,
create_multimodal_model_selector,
create_feature_level_selector,
create_attribution_method_toggle,
)
from .components.example_browser import create_dataset_selector, create_example_browser
from .components.results_display import create_results_display, update
from .plotting.heatmap import create_interactive_text_heatmap
from .plotting.interactions import (
plot_top_interactions,
plot_interaction_matrix,
create_interaction_token_view,
)
from .plotting.text_interactions import create_text_interaction_html
from .plotting.mm_interactions import create_multimodal_interaction_html
from .plotting.coalition_viewer import compute_coalition_viewer_data, render_coalition_viewer_html
from .build_info import BUILD_ID, BUILD_TS
# Medical image precomputed results (optional)
try:
from .medical_loader import (
MEDICAL_EXAMPLES,
load_medical_example,
get_masked_image_path,
BENCHMARK_EXAMPLES,
get_examples_by_modality,
list_available_modalities,
load_benchmark_example,
extract_segment_regions,
)
from .plotting.medical_charts import (
create_shapley_bar_chart,
create_influence_heatmap,
create_cross_modal_bar_chart,
draw_grid_overlay,
draw_segment_labels,
generate_interpretation_text,
rename_patch_labels,
align_segments_to_reference,
remap_region_values,
merge_subword_token_values,
_tok_to_word,
)
from .plotting.benchmark_interaction import create_benchmark_interaction_html
_MEDICAL_AVAILABLE = True
except ImportError:
_MEDICAL_AVAILABLE = False
MEDICAL_EXAMPLES = {}
BENCHMARK_EXAMPLES = {}
# MIMIC-CXR precomputed results (optional)
try:
from .mimic_loader import (
MIMIC_EXAMPLES,
load_mimic_example,
get_mimic_image_path,
)
_MIMIC_AVAILABLE = bool(MIMIC_EXAMPLES)
except ImportError:
_MIMIC_AVAILABLE = False
MIMIC_EXAMPLES = {}
# Dermoscopy ISIC precomputed results (optional)
try:
from .isic_loader import (
ISIC_EXAMPLES,
load_isic_example,
get_isic_image_path,
)
_ISIC_AVAILABLE = bool(ISIC_EXAMPLES)
except ImportError:
_ISIC_AVAILABLE = False
ISIC_EXAMPLES = {}
# MS-COCO precomputed results (optional)
try:
from .coco_loader import COCO_EXAMPLES, load_coco_example, get_coco_masked_image_path
_COCO_AVAILABLE = True
except ImportError:
_COCO_AVAILABLE = False
COCO_EXAMPLES = {}
# CLIP cross-modal pipeline for live compute (optional — runs on CPU)
try:
from attribution.set_mm import (
PipelineConfig,
CrossModalCLIPScorer,
ImageRegion,
TokenPlayer,
featurise,
tokenise_caption,
build_cross_modal_set_function,
run_proxyspex,
mobius_to_shapley,
mobius_to_banzhaf,
extract_interactions,
extract_cross_per_token,
apply_image_mask,
render_overlay,
render_segmentation_map,
mask_token_ids,
)
_CLIP_PIPELINE_AVAILABLE = True
except ImportError:
_CLIP_PIPELINE_AVAILABLE = False
# Module-level cache for CLIP scorers (keyed by model name)
_clip_scorer_cache: Dict[str, Any] = {}
def _raise_backend_error(resp: requests.Response, label: str) -> None:
    """Raise a ``gr.Error`` describing a failed backend response.

    The message uses the "detail" field of the JSON body when the body can be
    decoded, and the raw response text otherwise. This function always raises.
    """
    raw_text = resp.text
    try:
        detail = resp.json().get("detail", raw_text)
    except Exception:
        # Body is not JSON (or not a JSON object): fall back to the raw text.
        detail = raw_text
    raise gr.Error(f"{label} failed ({resp.status_code}). {detail}")
# backend API imports
try: # loader data APIs are required for public mode
from loader.data import (
get_example_by_id,
get_examples,
list_datasets,
list_datasets_with_display_names,
list_dataset_display_names,
get_dataset_display_name,
get_dataset_key_from_display_name,
)
except Exception: # pragma: no cover - optional at runtime
get_example_by_id = None
get_examples = None
list_datasets = None
list_datasets_with_display_names = None
list_dataset_display_names = None
get_dataset_display_name = None
get_dataset_key_from_display_name = None
try:
from loader.results import get_result_by_id
except Exception: # pragma: no cover
get_result_by_id = None
try:
from loader.models import get_model
except Exception: # pragma: no cover
get_model = None
try: # attribution stack is optional (dev mode)
from attribution.masker import get_masker, mask_text
from attribution.proxyspex import run_proxyspex
from attribution.image_masker import supports_superpixel
from attribution.utils import (
influence_interactions,
mobius_to_influence,
mobius_to_shapley,
shapley_interactions,
mobius_to_banzhaf,
banzhaf_interactions,
)
except Exception: # pragma: no cover
get_masker = None
mask_text = None
run_proxyspex = None
supports_superpixel = None
influence_interactions = None
mobius_to_influence = None
mobius_to_shapley = None
shapley_interactions = None
mobius_to_banzhaf = None
banzhaf_interactions = None
# Record keys probed, in this order, by _extract_answer.
_ANSWER_FIELDS = (
    "correct_answer",
    "answer",
    "target",
    "completion",
    "label",
)
# Values accepted by _normalize_method; anything else becomes "shapley".
_ALLOWED_METHODS = {"shapley", "banzhaf", "influence"}
# Values accepted by _normalize_level; anything else becomes "sentence".
_ALLOWED_LEVELS = {"word", "sentence", "paragraph"}
def _ensure_backend(name: str, fn: Optional[Any]):
if fn is None:
raise RuntimeError(
f"{name} is unavailable. Ensure the backend modules are installed and importable."
)
return fn
def _html_component(label: str, visible: bool = True) -> gr.HTML:
    """Create a ``gr.HTML`` component, passing ``sanitize_html=False`` when accepted.

    If the constructor raises ``TypeError`` with that argument, the component
    is created again without it.
    """
    common = {"label": label, "visible": visible}
    try:
        return gr.HTML(sanitize_html=False, **common)
    except TypeError:
        return gr.HTML(**common)
def _encode_image_to_b64(image: Image.Image) -> str:
    """Serialize ``image`` as PNG and return the bytes base64-encoded as text."""
    with io.BytesIO() as stream:
        image.save(stream, format="PNG")
        png_bytes = stream.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def _extract_answer(record: Dict[str, Any]) -> str:
    """Return the first truthy answer-like field of ``record`` as a string.

    Fields are probed in the order given by ``_ANSWER_FIELDS``; ``""`` is
    returned when none of them holds a truthy value.
    """
    candidates = (record.get(field) for field in _ANSWER_FIELDS)
    found = next((value for value in candidates if value), None)
    return "" if found is None else str(found)
def _coerce_feature_tuple(raw_key: Any) -> Tuple[str, ...]:
if isinstance(raw_key, tuple):
return tuple(str(item) for item in raw_key)
if isinstance(raw_key, list):
return tuple(str(item) for item in raw_key)
if isinstance(raw_key, str):
for sep in ("·", "|", ",", "×"):
if sep in raw_key:
parts = [chunk.strip() for chunk in raw_key.split(sep) if chunk.strip()]
if parts:
return tuple(parts)
return (raw_key.strip(),)
return (str(raw_key),)
# def _normalize_interactions(raw: Any) -> List[Tuple[Tuple[str, ...], float]]:
# items: List[Any]
# if raw is None:
# return []
# if isinstance(raw, dict):
# items = list(raw.items())
# else:
# items = list(raw)
# normalized: List[Tuple[Tuple[str, ...], float]] = []
# for feats, value in items:
# try:
# numeric = float(value)
# except Exception:
# continue
# normalized.append((_coerce_feature_tuple(feats), numeric))
# return normalized
def _normalize_interactions(raw: Any) -> List[Tuple[Tuple[str, ...], float]]:
"""
Make a best-effort guess at interaction structure.
Supported shapes:
- { key: float }
- { key: {"value": float, "score": ...} }
- [ (key, float), ... ]
- [ (key, {"value": float}), ... ]
- [ {"features": [...], "value": float}, ... ] (this is mostly handled elsewhere)
"""
if raw is None:
return []
items: List[Any] = []
if isinstance(raw, dict):
# e.g. { key: float } or { key: {"value": ...} }
for k, v in raw.items():
items.append((k, v))
elif isinstance(raw, list):
items = list(raw)
else:
return []
normalized: List[Tuple[Tuple[str, ...], float]] = []
for item in items:
# Case 1: dict-style item with explicit fields
if isinstance(item, dict):
feats = item.get("features") or item.get("indices") or item.get("pair") or item.get("key")
val = item.get("value", item.get("score", 0.0))
else:
# Case 2: tuple/list pair (feats, value)
try:
feats, val = item
except Exception:
continue
# If value itself is a dict, dig out "value" / "score"
if isinstance(val, dict):
val = val.get("value", val.get("score", 0.0))
try:
numeric = float(val)
except Exception:
continue
feats_tuple = _coerce_feature_tuple(feats)
if feats_tuple:
normalized.append((feats_tuple, numeric))
return normalized
def _resolve_marginals(payload: Dict[str, Any]) -> Dict[str, float]:
for key in ("marginals", "token_scores", "values", "scores"):
data = payload.get(key)
if isinstance(data, dict):
normalized: Dict[str, float] = {}
for k, v in data.items():
try:
normalized[str(k)] = float(v)
except Exception:
continue
return normalized
return {}
def _resolve_features(payload: Dict[str, Any], marginals: Dict[str, float]) -> List[str]:
features = payload.get("features")
if isinstance(features, list):
return [str(f) for f in features]
if marginals:
return list(marginals.keys())
return []
def _extract_interactions_from_response(
data_int: Dict[str, Any],
method: str,
features: List[str],
) -> List[Tuple[Tuple[str, ...], float]]:
inter_list: List[Tuple[Tuple[str, ...], float]] = []
method_key = (method or "shapley").lower()
method_block = data_int.get(method_key) or data_int
raw_interactions = None
if isinstance(method_block, dict):
for key in ("interactions", "pairwise_interactions", "interactions_2", "pairwise", "data"):
if key in method_block:
raw_interactions = method_block.get(key)
break
if raw_interactions is None:
raw_interactions = method_block
else:
raw_interactions = method_block
# List-of-dicts or list-of-pairs shape
if isinstance(raw_interactions, list) and raw_interactions:
if isinstance(raw_interactions[0], dict):
for item in raw_interactions:
feats = (
item.get("feature_list")
or item.get("features")
or item.get("indices")
or item.get("pair")
or []
)
val = None
for key_val in ("value", "score", "attribution", "weight"):
if key_val in item:
try:
val = float(item[key_val])
break
except Exception:
continue
if val is None:
continue
if isinstance(feats, list) and feats and isinstance(feats[0], int):
feat_names = tuple(
features[i] for i in feats
if isinstance(i, int) and 0 <= i < len(features)
)
else:
feat_names = _coerce_feature_tuple(feats)
if feat_names:
inter_list.append((feat_names, val))
elif (
isinstance(raw_interactions[0], (list, tuple))
and len(raw_interactions[0]) == 2
):
for item in raw_interactions:
if not isinstance(item, (list, tuple)) or len(item) != 2:
continue
feats_raw, val_raw = item
try:
val = float(val_raw)
except Exception:
continue
feat_names: Tuple[str, ...] = ()
if isinstance(feats_raw, (list, tuple)) and feats_raw:
if all(isinstance(i, int) for i in feats_raw):
feat_names = tuple(
features[i] for i in feats_raw
if 0 <= i < len(features)
)
else:
feat_names = _coerce_feature_tuple(feats_raw)
elif isinstance(feats_raw, str):
feat_names = _coerce_feature_tuple(feats_raw)
if feat_names:
inter_list.append((feat_names, val))
# Dict shape, e.g. {"(0,2)": 528.0, ...}
if not inter_list and isinstance(raw_interactions, dict):
metadata_keys = {"method", "order", "scalarizer", "embedding_model"}
for k, v in raw_interactions.items():
if str(k) in metadata_keys:
continue
val = None
if isinstance(v, (int, float)):
val = float(v)
elif isinstance(v, dict):
for key_val in ("value", "score", "attribution", "weight"):
if key_val in v:
try:
val = float(v[key_val])
break
except Exception:
continue
if val is None:
continue
k_str = str(k)
idxs = []
try:
import re as _re
idxs = [int(x) for x in _re.findall(r"\d+", k_str)]
except Exception:
idxs = []
if idxs:
names: List[str] = []
for idx in idxs:
if 0 <= idx < len(features):
names.append(features[idx])
if names:
feat_names = tuple(names)
else:
feat_names = _coerce_feature_tuple(k_str)
else:
feat_names = _coerce_feature_tuple(k_str)
inter_list.append((feat_names, val))
# Flatten numerics arbitrarily (last resort)
if not inter_list and raw_interactions is not None:
flat: List[Tuple[Tuple[str, ...], float]] = []
def _collect(obj: Any, prefix: Tuple[str, ...] = ()) -> None:
if isinstance(obj, (int, float)):
flat.append((prefix or ("<interaction>",), float(obj)))
elif isinstance(obj, list):
for i, item in enumerate(obj):
_collect(item, prefix + (f"[{i}]",))
elif isinstance(obj, dict):
for kk, vv in obj.items():
_collect(vv, prefix + (str(kk),))
_collect(raw_interactions)
inter_list = flat
return inter_list
def _labels_from_regions(regions: List[Dict[str, Any]]) -> List[str]:
labels: List[str] = [""] * len(regions)
for region in regions:
try:
idx = int(region.get("index", 0))
except Exception:
continue
if idx < 0 or idx >= len(labels):
continue
labels[idx] = str(region.get("label") or f"Region {idx + 1}")
for idx, label in enumerate(labels):
if not label:
labels[idx] = f"Region {idx + 1}"
return labels
def _interaction_dicts_to_pairs(
interactions: List[Dict[str, Any]],
labels: List[str],
*,
order: int | None = None,
) -> List[Tuple[Tuple[str, ...], float]]:
pairs: List[Tuple[Tuple[str, ...], float]] = []
for item in interactions:
indices = item.get("indices")
if not indices:
continue
if order is not None and len(indices) != order:
continue
try:
value = float(item.get("value", 0.0))
except Exception:
continue
feats = tuple(labels[int(i)] for i in indices if int(i) < len(labels))
if feats:
pairs.append((feats, value))
return pairs
def _interaction_dicts_to_table(
interactions: List[Dict[str, Any]],
labels: List[str],
) -> List[List[Any]]:
rows: List[List[Any]] = []
for item in interactions:
indices = item.get("indices")
if not indices:
continue
try:
value = float(item.get("value", 0.0))
except Exception:
continue
feats = [labels[int(i)] for i in indices if int(i) < len(labels)]
if feats:
rows.append([" × ".join(feats), value, len(indices)])
return rows
def _feature_display_label(
feature: Dict[str, Any],
region_labels: List[str],
) -> str:
raw = str(feature.get("feature", ""))
modality = feature.get("modality") or ""
ref_index = int(feature.get("ref_index", 0))
label = raw.split(":", 1)[1] if ":" in raw else raw
if modality == "image":
if 0 <= ref_index < len(region_labels):
return region_labels[ref_index]
return label or raw
def _extract_feature_series(payload: Dict[str, Any]) -> Tuple[List[str], List[float]]:
    """
    Recover an ordered pair of (feature labels, values) from a backend payload.

    Sources are tried in order until one yields features:
    1. ``payload["features"]`` as a list of dicts, with the name taken from
       "feature" / "token" / "text" / "label" and the value from "value" /
       "score" / "attribution" / "weight";
    2. ``payload["heatmap"]`` with parallel "tokens"/"features" and
       "values"/"scores" lists of equal length;
    3. the marginals dict resolved by ``_resolve_marginals``.

    Duplicates are kept in order and disambiguated with occurrence suffixes,
    so repeated word-level tokens do not collapse into one entry. Values that
    cannot be converted become 0.0. Returns ``([], [])`` when nothing is found.
    """

    def _as_float(value: Any) -> float:
        try:
            return float(value)
        except Exception:
            return 0.0

    names: List[str] = []
    values: List[float] = []

    entries = payload.get("features")
    if isinstance(entries, list) and entries and isinstance(entries[0], dict):
        for position, entry in enumerate(entries, start=1):
            name = ""
            for field in ("feature", "token", "text", "label"):
                candidate = entry.get(field)
                if candidate:
                    name = candidate
                    break
            if not name:
                name = f"feature_{position}"
            score = entry.get("value")
            if score is None:
                # Fall back to the first alternative value key that is present.
                for field in ("score", "attribution", "weight"):
                    if field in entry:
                        score = entry[field]
                        break
            values.append(_as_float(0.0 if score is None else score))
            names.append(str(name))

    if not names:
        heat = payload.get("heatmap") or {}
        tokens = heat.get("tokens") or heat.get("features")
        scores = heat.get("values") or heat.get("scores")
        if isinstance(tokens, list) and isinstance(scores, list) and len(tokens) == len(scores):
            names = [
                str(f"feature_{position + 1}" if token is None else token)
                for position, token in enumerate(tokens)
            ]
            values = [_as_float(score) for score in scores]

    if not names:
        marginals = _resolve_marginals(payload)
        if marginals:
            names = list(marginals.keys())
            values = [float(marginals[name]) for name in names]

    if not names:
        return [], []
    return _assign_unique_labels(names), values
def _resolve_interactions(payload: Dict[str, Any], order: int) -> List[Tuple[Tuple[str, ...], float]]:
    """Return normalized interactions of the given order from ``payload``.

    ``interactions_<order>`` is checked first, followed by order-specific
    aliases (orders 2 and 3 only). The first key whose content normalizes to a
    non-empty list wins; ``[]`` is returned when none does.
    """
    keys = [f"interactions_{order}"]
    if order == 2:
        keys.extend(["pairwise", "pairwise_interactions", "interactions2"])
    elif order == 3:
        keys.extend(["higher_order", "triple_interactions", "interactions3"])
    for key in keys:
        found = _normalize_interactions(payload.get(key))
        if found:
            return found
    return []
def _fallback_pairwise_from_values(
features: List[str],
values: List[float],
max_edges: int = 40,
) -> List[Tuple[Tuple[str, ...], float]]:
"""
Generate synthetic pairwise links by connecting neighboring tokens.
Used when the backend provides no explicit interactions.
"""
n = min(len(features), len(values))
if n < 2:
return []
edges: List[Tuple[Tuple[str, ...], float]] = []
for idx in range(n - 1):
weight = 0.5 * (values[idx] + values[idx + 1])
edges.append(((features[idx], features[idx + 1]), weight))
edges.sort(key=lambda item: abs(item[1]), reverse=True)
return edges[:max_edges]
def _resolve_pairwise(
    payload: Dict[str, Any],
    features: Optional[List[str]] = None,
    feature_values: Optional[List[float]] = None,
) -> List[Tuple[Tuple[str, ...], float]]:
    """Return order-2 interactions for ``payload``, with fallbacks.

    Resolution order:
    1. explicit order-2 interactions;
    2. the generic "interactions" entry, which may mix orders: if it
       normalizes to anything, only its two-feature items are returned;
    3. synthetic neighbor links built from ``features`` / ``feature_values``
       when both are provided and non-empty.
    """
    explicit = _resolve_interactions(payload, 2)
    if explicit:
        return explicit
    mixed = _normalize_interactions(payload.get("interactions"))
    if mixed:
        return [entry for entry in mixed if len(entry[0]) == 2]
    if not (features and feature_values):
        return []
    return _fallback_pairwise_from_values(features, feature_values)
def _normalize_method(method: Optional[str]) -> str:
    """Lower-case ``method`` and return it if allowed, otherwise "shapley"."""
    candidate = (method or "shapley").lower()
    if candidate in _ALLOWED_METHODS:
        return candidate
    return "shapley"
def _normalize_level(level: Optional[str]) -> str:
    """Lower-case ``level`` and return it if allowed, otherwise "sentence"."""
    candidate = (level or "sentence").lower()
    if candidate in _ALLOWED_LEVELS:
        return candidate
    return "sentence"
def _normalize_model_size(model_size: Optional[str]) -> str:
raw = (model_size or "small").strip()
lowered = raw.lower()
if lowered in {"small", "medium", "large"}:
return lowered
if "small" in lowered:
return "small"
if "medium" in lowered:
return "medium"
if "large" in lowered:
return "large"
return "small"
def _assign_unique_labels(chunks: List[str]) -> List[str]:
counts: Dict[str, int] = {}
labels: List[str] = []
for idx, chunk in enumerate(chunks):
normalized = " ".join((chunk or "").split())
if not normalized:
normalized = f"<chunk {idx + 1}>"
counts[normalized] = counts.get(normalized, 0) + 1
suffix = f" ({counts[normalized]})" if counts[normalized] > 1 else ""
labels.append(f"{normalized}{suffix}")
return labels
def _strip_occurrence_suffix(text: str) -> str:
text = text or ""
if text.endswith(")") and " (" in text:
base, _, tail = text.rpartition(" (")
if tail[:-1].isdigit():
return base
return text
def _pairwise_to_index_interactions(
    pairwise: List[Tuple[Tuple[str, ...], float]],
    features: List[str],
) -> List[Dict[str, Any]]:
    """Convert label-based pairs into ``{"indices": [i, j], "value": v}`` dicts.

    Each side of a pair is resolved to a position in ``features``:
    numbers (or two numeric strings) are taken as indices directly; otherwise
    the label is looked up by exact match and then by its text with any
    occurrence suffix removed. Items that are not pairs, cannot be resolved,
    or point outside ``features`` are dropped.
    """
    exact_index = {feat: idx for idx, feat in enumerate(features)}
    # First position of each label once its occurrence suffix is removed.
    base_index: Dict[str, int] = {}
    for idx, feat in enumerate(features):
        base_index.setdefault(_strip_occurrence_suffix(feat), idx)

    def _lookup(label: Any) -> Optional[int]:
        hit = exact_index.get(label)
        if hit:
            return hit
        return base_index.get(_strip_occurrence_suffix(str(label)))

    total = len(features)
    interactions: List[Dict[str, Any]] = []
    for feats, val in pairwise:
        if len(feats) != 2:
            continue
        first, second = feats
        if isinstance(first, (int, float)) and isinstance(second, (int, float)):
            first_idx: Optional[int] = int(first)
            second_idx: Optional[int] = int(second)
        else:
            try:
                first_idx = int(str(first))
                second_idx = int(str(second))
            except ValueError:
                # At least one side is not numeric: resolve both by label.
                first_idx = _lookup(first)
                second_idx = _lookup(second)
        if first_idx is None or second_idx is None:
            continue
        if not (0 <= first_idx < total and 0 <= second_idx < total):
            continue
        interactions.append({"indices": [first_idx, second_idx], "value": float(val)})
    return interactions
def _locate_spans(text: str, segments: List[str]) -> List[Tuple[int, int]]:
spans: List[Tuple[int, int]] = []
cursor = 0
for segment in segments:
if not segment:
continue
idx = text.find(segment, cursor)
if idx == -1:
idx = cursor
end = idx + len(segment)
spans.append((idx, end))
cursor = end
return spans
def _chunk_text_for_visualization(
    context: str,
    level: str,
) -> Tuple[List[str], List[Tuple[int, int]], str]:
    """
    Break text into labelled feature chunks plus their character spans.

    ``level`` is normalised first. "word" splits on runs of non-whitespace,
    "paragraph" splits on blank lines, and anything else splits after
    sentence-ending punctuation (``.``, ``!``, ``?``). When ``context`` is
    empty the module's demo text is used, and if splitting yields nothing the
    whole text becomes a single chunk.

    Returns:
        ``(features, spans, text)`` where ``features`` are the de-duplicated
        chunk labels, ``spans`` the matching ``(start, end)`` offsets into
        ``text``, and ``text`` the string that was actually split.
    """
    text = context or _DEMO_TEXT
    level = _normalize_level(level)

    if level == "word":
        chunks: List[str] = []
        spans: List[Tuple[int, int]] = []
        for hit in re.finditer(r"\S+", text):
            chunks.append(hit.group(0))
            spans.append((hit.start(), hit.end()))
    else:
        # Paragraphs break on blank lines; the default is sentence-level.
        splitter = r"\n\s*\n+" if level == "paragraph" else r"(?<=[.!?])\s+"
        pieces = [piece for piece in re.split(splitter, text) if piece.strip()]
        spans = _locate_spans(text, pieces)
        chunks = pieces[: len(spans)]

    if not chunks:
        chunks, spans = [text], [(0, len(text))]

    return _assign_unique_labels(chunks), spans, text
def _generate_synthetic_marginals(
features: List[str],
rng: random.Random,
) -> Dict[str, float]:
if not features:
return {}
max_len = max(len(f) for f in features) or 1
marginals: Dict[str, float] = {}
denom = max(1, len(features) - 1)
for idx, feat in enumerate(features):
length_factor = len(feat) / max_len
position_factor = 1 - (idx / denom if denom else 0)
noise = rng.uniform(-0.25, 0.25)
value = (length_factor - 0.5) * 0.6 + (position_factor - 0.5) * 0.4 + noise
marginals[feat] = round(value, 4)
return marginals
def _generate_synthetic_interactions(
features: List[str],
marginals: Dict[str, float],
rng: random.Random,
) -> Dict[int, List[Tuple[Tuple[str, ...], float]]]:
interactions: Dict[int, List[Tuple[Tuple[str, ...], float]]] = {2: [], 3: []}
for i in range(len(features) - 1):
pair = (features[i], features[i + 1])
base = (marginals.get(pair[0], 0.0) + marginals.get(pair[1], 0.0)) / 2
interactions[2].append((pair, round(base + rng.uniform(-0.1, 0.1), 4)))
for i in range(len(features) - 2):
triple = (features[i], features[i + 1], features[i + 2])
base = sum(marginals.get(feat, 0.0) for feat in triple) / 3
interactions[3].append((triple, round(base + rng.uniform(-0.1, 0.1), 4)))
return interactions
def _synthetic_attribution_pipeline(
    context: str,
    prompt: str,
    answer: str,
    *,
    method: str,
    level: str,
    order: int,
    reason: Optional[str] = None,
) -> Tuple[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]:
    """Build mock attribution outputs when no real backend result is available.

    The first non-empty value among ``context``, ``prompt``, ``answer`` and
    the module's demo text is chunked at ``level``. Synthetic marginals and
    adjacent-feature interactions are generated from a random generator
    seeded by ``(text, method, level, order)``, then rendered through the
    same visualisation helpers the live path uses.

    The seed is a CRC32 of the inputs rather than the built-in ``hash()``:
    string hashing is salted per interpreter process, so ``hash()`` would
    give different mock data for the same inputs after every restart.

    Args:
        context: Source passage; preferred text to chunk.
        prompt: Question or instruction; used when ``context`` is empty.
        answer: Target answer; used as text when both of the above are
            empty, and echoed back as the scoring target / reference answer.
        method: Attribution method name, forwarded to the renderers.
        level: Feature granularity, forwarded to the chunker.
        order: Interaction order to plot; only 2 and 3 have synthetic data.
        reason: Optional explanation stored in the metadata.

    Returns:
        Whatever the module-level ``update`` helper returns for these
        keyword arguments (annotated as a 15-tuple).
    """
    import zlib

    text_source = context or prompt or answer or _DEMO_TEXT
    features, spans, text = _chunk_text_for_visualization(text_source, level)

    # Process-independent 32-bit seed so the mock data is reproducible.
    seed_material = repr((text_source, method, level, order)).encode("utf-8")
    seed = zlib.crc32(seed_material) & 0xFFFFFFFF
    rng = random.Random(seed)

    marginals = _generate_synthetic_marginals(features, rng)
    interactions = _generate_synthetic_interactions(features, marginals, rng)
    feature_values = [marginals.get(f, 0.0) for f in features]

    html = None
    if len(spans) == len(features):
        html = create_interactive_text_heatmap(
            text,
            spans,
            feature_values,
            method=method,
        )

    meta = {
        "mode": "synthetic",
        "reason": reason or "Attribution backend unavailable; showing mock data.",
        "method": method,
        "feature_level": level,
        "interaction_order": order,
        "feature_count": len(features),
    }

    inter_list = interactions.get(order, [])
    # The token view always draws pairwise edges, whatever order is plotted.
    pairwise_for_tokens = interactions.get(2, [])
    if not pairwise_for_tokens:
        pairwise_for_tokens = _fallback_pairwise_from_values(features, feature_values)

    text_interaction_html = create_text_interaction_html(
        features,
        feature_values,
        _pairwise_to_index_interactions(pairwise_for_tokens, features),
        method=method,
        top_k=20,
        threshold=0.0,
    )
    figs = {
        "interactions": plot_top_interactions(inter_list, order=order, method=method),
    }
    return update(
        figs=figs,
        meta=meta,
        html=html,
        interaction_text_html=text_interaction_html,
        scoring_target_source="answer_input" if answer else "model_output",
        scoring_target_text=answer or "",
        reference_answer=answer or "",
        unmasked_answer="",
        debug_scores=None,
        scalarizer_used="logprob",
        score_full=None,
        score_empty=None,
        y_len_tokens=None,
    )
# def _compute_live_attributions(**kwargs) -> Tuple[Any, Any, Any, Any, Any]:
# """
# Placeholder for the real ProxySPEX + perplexity pipeline.
# Raises until the attribution backend is implemented.
# """
# missing = [
# name
# for name, fn in {
# "get_model": get_model,
# "get_masker": get_masker,
# "mask_text": mask_text,
# "run_proxyspex": run_proxyspex,
# "mobius_to_shapley": mobius_to_shapley,
# "mobius_to_banzhaf": mobius_to_banzhaf,
# "shapley_interactions": shapley_interactions,
# "banzhaf_interactions": banzhaf_interactions,
# }.items()
# if fn is None
# ]
# if missing:
# raise RuntimeError(
# "Missing backend dependencies: " + ", ".join(sorted(missing))
# )
# raise NotImplementedError(
# "Live attribution pipeline not wired yet. Integrate once ProxySPEX is ready."
# )
def _compute_live_attributions(
    *,
    context: str,
    prompt: str,
    correct_answer: str,
    model_size: str,
    level: str,
    method: str,
    order: int,
    scalarizer: str = "logprob",
    embedding_model: str | None = None,
    progress=None,
) -> Tuple[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]:
    """
    Call the FastAPI /api/attributions + /api/interactions backends and turn
    the JSON into figures / table / HTML for Gradio.
    This version is very defensive and tries hard to extract interactions
    from whatever shape the backend returns.

    Args:
        context, prompt, correct_answer: Text inputs; falsy values are sent
            as empty strings. ``correct_answer`` is sent both as "answer"
            and as "reference_answer".
        model_size: Forwarded to the backend unchanged.
        level: Feature granularity; normalised via ``_normalize_level``.
        method: Attribution method; normalised via ``_normalize_method``.
        order: Interaction order; clamped to 2 or 3.
        scalarizer: Scalarizer name forwarded to the backend.
        embedding_model: Optional embedding model name forwarded as-is.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``.

    Returns:
        The result of the module-level ``update`` helper (annotated as a
        15-tuple).

    Raises:
        gr.Error: When the attribution or the main interaction request hits
            a read timeout. Responses with status >= 400 are handed to
            ``_raise_backend_error`` (presumably raising; defined elsewhere).
    """
    method = _normalize_method(method)
    level = _normalize_level(level)
    # Anything >= 3 becomes 3; everything else (including falsy) becomes 2.
    order = 3 if int(order or 2) >= 3 else 2
    context = context or ""
    prompt = prompt or ""
    correct_answer = correct_answer or ""
    text_source = context or prompt or correct_answer or _DEMO_TEXT
    # The same payload is posted to both endpoints.
    payload = {
        "context": context,
        "answer": correct_answer,
        "reference_answer": correct_answer,
        "prompt": prompt,
        "method": method,
        "mask_level": level,
        "order": int(order),
        "model_size": model_size,
        "scalarizer": scalarizer,
        "embedding_model": embedding_model,
        "debug": False,
    }
    if progress is not None:
        progress(0.1, desc="Calling attribution backend")
    # ---------- 1. /api/attributions ----------
    url_attr = BACKEND_URL.rstrip("/") + "/api/attributions"
    try:
        resp_attr = requests.post(url_attr, json=payload, timeout=REQUEST_TIMEOUT)
    except requests.exceptions.ReadTimeout as exc:
        # Only read timeouts are translated; other request errors propagate.
        raise gr.Error(
            "Attribution request timed out. The backend may still be running. "
            "Consider reducing feature granularity or set ATTRLLM_REQUEST_TIMEOUT to a higher value."
        ) from exc
    if resp_attr.status_code >= 400:
        _raise_backend_error(resp_attr, "Attribution request")
    data_attr = resp_attr.json()
    if progress is not None:
        progress(0.35, desc="Received attribution payload")
    # ---------- 2. FEATURES + MARGINAL VALUES ----------
    features, feature_values = _extract_feature_series(data_attr)
    if not features:
        # Placeholder so downstream renderers always receive one feature.
        features = ["<no features>"]
        feature_values = [0.0]
    marginals = {feat: float(feature_values[idx]) for idx, feat in enumerate(features)}
    # ---------- 3. /api/interactions ----------
    if progress is not None:
        progress(0.45, desc="Calling interactions backend")
    url_int = BACKEND_URL.rstrip("/") + "/api/interactions"
    try:
        resp_int = requests.post(url_int, json=payload, timeout=REQUEST_TIMEOUT)
    except requests.exceptions.ReadTimeout as exc:
        raise gr.Error(
            "Interaction request timed out. The backend may still be running. "
            "Consider reducing order or set ATTRLLM_REQUEST_TIMEOUT to a higher value."
        ) from exc
    if resp_int.status_code >= 400:
        _raise_backend_error(resp_int, "Interaction request")
    data_int = resp_int.json()
    # DEBUG: see top-level keys
    print("data_int keys:", list(data_int.keys()))
    inter_list_all = _extract_interactions_from_response(data_int, method, features)
    # Order-2 entries feed the token/text interaction views regardless of
    # which order is plotted.
    pairwise_for_network = [item for item in inter_list_all if len(item[0]) == 2]
    used_pairwise_fallback = False
    inter_list = inter_list_all
    if inter_list:
        # Prefer interactions of the requested order; if none match, keep
        # the unfiltered list.
        filtered: List[Tuple[Tuple[str, ...], float]] = []
        for feats, val in inter_list:
            if len(feats) == order:
                filtered.append((feats, val))
        if filtered:
            inter_list = filtered
    if order != 2 and not pairwise_for_network:
        # Best-effort second request at order 2. Any failure here, including
        # the gr.Error raised for a timeout, is logged and swallowed.
        try:
            payload_pair = dict(payload)
            payload_pair["order"] = 2
            try:
                resp_pair = requests.post(url_int, json=payload_pair, timeout=REQUEST_TIMEOUT)
            except requests.exceptions.ReadTimeout as exc:
                raise gr.Error(
                    "Interaction request timed out. The backend may still be running. "
                    "Consider reducing order or set ATTRLLM_REQUEST_TIMEOUT to a higher value."
                ) from exc
            if resp_pair.status_code >= 400:
                _raise_backend_error(resp_pair, "Interaction request")
            data_pair = resp_pair.json()
            pairwise_for_network = [
                item for item in _extract_interactions_from_response(data_pair, method, features)
                if len(item[0]) == 2
            ]
        except Exception as exc:
            print("Pairwise interaction fetch failed:", exc)
    if not pairwise_for_network:
        if method == "influence":
            # No synthetic pairwise edges for influence.
            pairwise_for_network = []
        else:
            used_pairwise_fallback = True
            pairwise_for_network = _fallback_pairwise_from_values(features, feature_values)
    print("LIVE features:", features)
    print("LIVE inter_list (first 3):", inter_list[:3])
    if method == "influence":
        # Debug-only summary of the strongest singletons and pairs.
        top_singletons = sorted(
            list(zip(features, feature_values)),
            key=lambda kv: abs(float(kv[1])),
            reverse=True,
        )[:10]
        top_pairs = sorted(
            pairwise_for_network,
            key=lambda kv: abs(float(kv[1])),
            reverse=True,
        )[:10]
        print(
            "[influence-ui-debug] "
            f"pairwise_source={'fallback_neighbors' if used_pairwise_fallback else 'backend'} "
            f"feature_count={len(features)} pair_count={len(pairwise_for_network)}"
        , flush=True)
        print("[influence-ui-debug] top_singletons:", top_singletons, flush=True)
        print("[influence-ui-debug] top_pairwise:", top_pairs, flush=True)
    # Built before the rescale below, so it sees the unscaled feature values.
    text_interaction_html = create_text_interaction_html(
        features,
        feature_values,
        _pairwise_to_index_interactions(pairwise_for_network, features),
        method=method,
        top_k=20,
        threshold=0.0,
    )
    # ---------- 4. RESCALE VERY SMALL VALUES ----------
    max_abs = max((abs(v) for v in marginals.values()), default=0.0)
    scale = 1.0
    if 0 < max_abs < 1e-3:
        scale = 1e3
    if scale != 1.0:
        # NOTE(review): `marginals` is not read again after this point, and
        # `pairwise_for_network` is left unscaled while `feature_values` is
        # scaled - confirm the token view tolerates the mixed magnitudes.
        marginals = {k: v * scale for k, v in marginals.items()}
        inter_list = [(feats, val * scale) for feats, val in inter_list]
        feature_values = [val * scale for val in feature_values]
    # ---------- 5. INLINE TEXT HEATMAP ----------
    spans = None
    masking = data_attr.get("masking") or data_attr.get("mask") or {}
    if isinstance(masking, dict):
        spans = masking.get("feature_spans") or masking.get("spans")
    html = None
    # Only render when there is exactly one span per feature value.
    if spans and len(spans) == len(feature_values):
        html = create_interactive_text_heatmap(
            context or text_source,
            spans,
            feature_values,
            method=method,
        )
    # ---------- 6. PLOTS + TABLES + META ----------
    inter_fig = plot_top_interactions(inter_list, order=order, method=method)
    if progress is not None:
        progress(0.8, desc="Rendering visualizations")
    y_len_tokens = data_attr.get("y_len_tokens")
    scoring_target_source = data_attr.get("scoring_target_source") or "model_output"
    scoring_target_text = data_attr.get("scoring_target_text")
    if scoring_target_text is None:
        scoring_target_text = correct_answer or data_attr.get("y_full") or ""
    meta = {
        "mode": "live",
        "backend_url_attr": url_attr,
        "backend_url_int": url_int,
        "method": method,
        "feature_level": level,
        "interaction_order": order,
        "model_size": model_size,
        "feature_count": len(features),
        "max_abs_value": max_abs,
        "scale_applied": scale,
        "scalarizer": data_attr.get("scalarizer_used", payload.get("scalarizer")),
        "scoring_target_source": scoring_target_source,
        "scoring_target_text_preview": str(scoring_target_text)[:200],
        "score_full": data_attr.get("score_full"),
        "score_empty": data_attr.get("score_empty"),
        "y_len_tokens": y_len_tokens,
        "logprob_full": data_attr.get("logprob_full"),
        "logprob_empty": data_attr.get("logprob_empty"),
        "min_logprob_seen": data_attr.get("min_logprob_seen"),
        "reference_answer_received": data_attr.get("reference_answer_received"),
        "answer_received": data_attr.get("answer_received"),
        "raw_attr_keys": list(data_attr.keys()),
        "raw_int_keys": list(data_int.keys()),
    }
    reference_answer = correct_answer
    unmasked_answer = data_attr.get("y_full") or data_attr.get("unmasked_answer") or ""
    debug_scores = data_attr.get("debug_scores") or None
    interaction_chips_html = create_interaction_token_view(
        features,
        feature_values,
        pairwise_for_network,
        method=method,
        layout="sentence" if level == "sentence" else "token",
    )
    figs = {
        "interactions": inter_fig,
    }
    if progress is not None:
        progress(1.0, desc="Done")
    return update(
        figs=figs,
        meta=meta,
        html=html,
        interaction_html=interaction_chips_html,
        interaction_text_html=text_interaction_html,
        scoring_target_source=scoring_target_source,
        scoring_target_text=str(scoring_target_text),
        reference_answer=reference_answer,
        unmasked_answer=unmasked_answer,
        debug_scores=debug_scores,
        scalarizer_used=data_attr.get("scalarizer_used", payload.get("scalarizer")),
        score_full=data_attr.get("score_full"),
        score_empty=data_attr.get("score_empty"),
        y_len_tokens=y_len_tokens,
    )
# ═══════════════════════════════════════════════════════════════════════════
# CLIP-based live compute helpers (Custom Image / Custom Multimodal tabs)
# ═══════════════════════════════════════════════════════════════════════════
# Display name shown in the UI -> model identifier handed to PipelineConfig.
# Lookups fall back to the display string itself when it is not listed here.
_CLIP_MODEL_MAP: Dict[str, str] = {
    "CLIP (openai/clip-vit-base-patch32)": "openai/clip-vit-base-patch32",
    "BiomedCLIP": "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
}
def _get_clip_scorer(model_display: str) -> "CrossModalCLIPScorer":
    """Return the CLIP scorer for a UI model name, building it on first use.

    The display name is translated through ``_CLIP_MODEL_MAP`` (unknown names
    are used verbatim) and scorers are memoised per resolved model name in
    ``_clip_scorer_cache``. A newly built scorer gets its ``unk_token_id``
    replaced by the first token id of "." (the "dot-mask fix").
    """
    resolved_name = _CLIP_MODEL_MAP.get(model_display, model_display)
    if resolved_name in _clip_scorer_cache:
        return _clip_scorer_cache[resolved_name]

    import torch as _torch

    target_device = "cuda" if _torch.cuda.is_available() else "cpu"
    scorer = CrossModalCLIPScorer(
        PipelineConfig(clip_model_name=resolved_name, device=target_device)
    )

    # Dot-mask fix: mask with "." instead of EOS to avoid the CLIP
    # argmax-pooling shift (the original note cites token ID 269 for ".").
    neutral_ids = scorer.processor.tokenizer.encode(".", add_special_tokens=False)
    if neutral_ids:
        scorer.unk_token_id = neutral_ids[0]

    _clip_scorer_cache[resolved_name] = scorer
    return scorer
def _run_clip_attribution(
    image: Image.Image,
    caption: str,
    clip_model: str,
    seg_mode: str,
    grid_size: int,
    method: str,
    seed: int,
    progress=None,
) -> Dict[str, Any]:
    """
    Core CLIP cross-modal attribution pipeline shared by both custom tabs.
    Returns a dict with regions, token_players, values, interactions,
    overlay images, masked images, and influence matrix.

    Players are indexed with image regions first (0 .. n_img - 1) followed by
    caption tokens (n_img .. n_img + n_tok - 1).

    Args:
        image: Input image.
        caption: Caption text to tokenise into token players.
        clip_model: UI model name; resolved through ``_CLIP_MODEL_MAP``.
        seg_mode: "Patch Grid" selects patch mode; any other value selects
            "unsam".
        grid_size: Grid size passed to the pipeline config (cast to int).
        method: Attribution method; "banzhaf" uses Banzhaf values, anything
            else (after normalisation) uses Shapley values.
        seed: Random seed (cast to int).
        progress: Optional callable invoked as ``progress(fraction, desc=...)``.

    Raises:
        gr.Error: If the CLIP pipeline is unavailable, or if segmentation
            fails in a mode other than "Patch Grid". Patch-grid segmentation
            failures are re-raised unchanged.
    """
    import numpy as np
    if not _CLIP_PIPELINE_AVAILABLE:
        raise gr.Error(
            "CLIP pipeline not available. Ensure attribution.set_mm is importable "
            "(requires transformers, lightgbm, numpy, scipy)."
        )
    method = _normalize_method(method)
    # Check LaMa availability, fall back to blur
    try:
        from simple_lama_inpainting import SimpleLama  # noqa: F401
        mask_style = "lama"
    except ImportError:
        mask_style = "blur"
    import torch as _torch
    device = "cuda" if _torch.cuda.is_available() else "cpu"
    model_name = _CLIP_MODEL_MAP.get(clip_model, clip_model)
    cfg = PipelineConfig(
        mode="patch" if seg_mode == "Patch Grid" else "unsam",
        grid_size=int(grid_size),
        mask_style=mask_style,
        clip_model_name=model_name,
        max_tokens=15,
        method=method,
        max_order=2,
        top_k_interactions=15,
        random_seed=int(seed),
        device=device,
    )
    if progress is not None:
        progress(0.05, desc="Loading CLIP model...")
    # The scorer comes from the shared cache, keyed by the resolved model name.
    scorer = _get_clip_scorer(clip_model)
    # Step 1: Featurise image
    if progress is not None:
        progress(0.10, desc="Segmenting image...")
    try:
        regions = featurise(image, cfg)
    except Exception as exc:
        if seg_mode != "Patch Grid":
            raise gr.Error(
                f"UnSAM segmentation failed: {exc}. "
                "Try using 'Patch Grid' instead."
            ) from exc
        raise
    # Step 2: Tokenise caption
    if progress is not None:
        progress(0.15, desc="Tokenising caption...")
    # Token player indices start after the image regions (offset=len(regions)).
    token_players, full_token_ids = tokenise_caption(
        caption, scorer.processor, cfg, offset=len(regions)
    )
    n_img = len(regions)
    n_tok = len(token_players)
    n_total = n_img + n_tok
    # Step 3: Build cross-modal set function
    if progress is not None:
        progress(0.20, desc="Building set function...")
    game = build_cross_modal_set_function(
        image, regions, token_players, full_token_ids, scorer, cfg
    )
    # Step 4: Run ProxySPEX (run_proxyspex wraps the set function for 2D batch calls)
    _raw_labels = [r.label for r in regions] + [tp.label for tp in token_players]
    # all_labels is rebuilt after tok_vals disambiguation below; _raw_labels for ProxySPEX
    if progress is not None:
        progress(0.25, desc=f"Running ProxySPEX (n={n_total})...")
    mobius = run_proxyspex(game, _raw_labels, max_order=2, seed=int(seed))
    # Step 5: Derive Shapley/Banzhaf values
    if progress is not None:
        progress(0.70, desc="Computing values...")
    if method == "banzhaf":
        values = mobius_to_banzhaf(mobius)
    else:
        values = mobius_to_shapley(mobius)
    # Split into image and token values
    # Disambiguate duplicate labels (e.g., two "the" tokens) by appending #N
    # Singleton values are read with 1-tuple keys; missing entries count as 0.0.
    img_vals = {regions[i].label: float(values.get((i,), 0.0)) for i in range(n_img)}
    tok_vals = {}
    _tok_label_counts: Dict[str, int] = {}
    for j, tp in enumerate(token_players):
        label = tp.label
        count = _tok_label_counts.get(label, 0)
        _tok_label_counts[label] = count + 1
        # First occurrence keeps the bare label; repeats become "label#1", "label#2", ...
        key = f"{label}#{count}" if count > 0 else label
        tok_vals[key] = float(values.get((n_img + j,), 0.0))
    # Rebuild all_labels with disambiguated token labels
    # NOTE(review): img_vals is keyed by region label without disambiguation,
    # so repeated region labels would shorten this list and shift the
    # index -> label alignment - confirm region labels are unique.
    all_labels = list(img_vals.keys()) + list(tok_vals.keys())
    # Step 6: Extract interactions
    interactions = extract_interactions(mobius, order=2, top_k=15)
    cross_per_token, cross_global_top5 = extract_cross_per_token(mobius, n_img, n_tok)
    # Image-image and token-token interactions
    img_filter = lambda loc: all(i < n_img for i in loc)
    tok_filter = lambda loc: all(i >= n_img for i in loc)
    interactions_img = extract_interactions(mobius, order=2, top_k=10, player_filter=img_filter)
    interactions_tok = extract_interactions(mobius, order=2, top_k=10, player_filter=tok_filter)
    # Cross-modal interactions (for bar chart)
    # A coalition is cross-modal when it has at least one region and one token.
    cross_filter = lambda loc: any(i < n_img for i in loc) and any(i >= n_img for i in loc)
    cross_interactions = extract_interactions(mobius, order=2, top_k=15, player_filter=cross_filter)
    # Step 7: Build influence matrix [n_img x n_tok]
    # Only the top cross-modal interactions extracted above are accumulated;
    # all other cells stay at zero.
    influence_matrix = np.zeros((n_img, n_tok))
    for loc, val in cross_interactions:
        img_indices = [i for i in loc if i < n_img]
        tok_indices = [i - n_img for i in loc if i >= n_img]
        for ii in img_indices:
            for tj in tok_indices:
                if 0 <= ii < n_img and 0 <= tj < n_tok:
                    influence_matrix[ii, tj] += float(val)
    # Step 8: Render overlay and segmap
    if progress is not None:
        progress(0.75, desc="Rendering overlay...")
    img_val_list = [float(values.get((i,), 0.0)) for i in range(n_img)]
    overlay_rgba = render_overlay(image, regions, img_val_list)
    base_rgba = image.convert("RGBA")
    overlay_img = Image.alpha_composite(base_rgba, overlay_rgba).convert("RGB")
    overlay_b64 = _encode_image_to_b64(overlay_img)
    segmap_img = render_segmentation_map(image, regions)
    segmap_b64 = _encode_image_to_b64(segmap_img)
    # Step 9: Build segment bboxes (% coordinates for interactive view)
    w, h = image.size
    segment_bboxes = []
    for reg in regions:
        x0, y0, x1, y1 = reg.bbox
        segment_bboxes.append({
            "x0_pct": 100.0 * x0 / w,
            "y0_pct": 100.0 * y0 / h,
            "w_pct": 100.0 * (x1 - x0) / w,
            "h_pct": 100.0 * (y1 - y0) / h,
            "cx_pct": 100.0 * (x0 + x1) / 2 / w,
            "cy_pct": 100.0 * (y0 + y1) / 2 / h,
        })
    # Step 10: Generate masked images for browser
    if progress is not None:
        progress(0.80, desc="Generating masked images...")
    masked_images: Dict[str, Image.Image] = {}
    for i, reg in enumerate(regions):
        # "removed" — mask only this region
        # Coalition vector: 1 keeps a region, 0 masks it.
        coal_removed = [1] * n_img
        coal_removed[i] = 0
        removed_img = apply_image_mask(
            image, regions, coal_removed, style=cfg.mask_style,
            blur_radius=cfg.blur_radius, cfg=cfg,
        )
        masked_images[f"{reg.label} removed"] = removed_img
    if progress is not None:
        # Callers report 1.0 themselves once their own rendering is done.
        progress(0.90, desc="Done computing.")
    return {
        "regions": regions,
        "token_players": token_players,
        "all_labels": all_labels,
        "image_values": img_vals,
        "token_values": tok_vals,
        "values": values,
        "mobius": mobius,
        "interactions": interactions,
        "interactions_img": interactions_img,
        "interactions_tok": interactions_tok,
        "cross_interactions": cross_interactions,
        "cross_per_token": cross_per_token,
        "cross_global_top5": cross_global_top5,
        "influence_matrix": influence_matrix,
        "overlay_img": overlay_img,
        "overlay_b64": overlay_b64,
        "segmap_img": segmap_img,
        "segmap_b64": segmap_b64,
        "segment_bboxes": segment_bboxes,
        "masked_images": masked_images,
        "method": method,
        "n_img": n_img,
        "n_tok": n_tok,
        "mask_style": mask_style,
        "seg_mode": seg_mode,
        "grid_size": int(grid_size),
    }
def _build_masked_choices(masked_images: Dict[str, Image.Image]) -> List[str]:
    """List the masked-image labels in sorted order for the dropdown."""
    labels = list(masked_images.keys())
    labels.sort()
    return labels
def _on_masked_image_select(choice: str, state: Dict) -> Optional[Image.Image]:
    """Look up the masked image stored in ``state`` under ``choice``.

    Returns ``None`` when either argument is empty or the choice is unknown.
    """
    if state and choice:
        return state.get(choice)
    return None
def _compute_image_attributions_clip(
    image: Image.Image,
    caption: str,
    clip_model: str,
    seg_mode: str,
    grid_size: int,
    method: str,
    seed: int,
    progress=None,
):
    """Run the CLIP pipeline for the image tab and package its UI outputs.

    Raises ``gr.Error`` when no image is given or the caption is blank.

    Returns:
        A 7-tuple: original image, overlay image, region bar chart, dropdown
        update for the masked-image selector, the initially selected masked
        image (or ``None``), the masked-image state dict, and a metadata dict.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not (caption and caption.strip()):
        raise gr.Error("Please provide a caption or description.")

    result = _run_clip_attribution(
        image,
        caption.strip(),
        clip_model,
        seg_mode,
        int(grid_size),
        method,
        int(seed or 0),
        progress=progress,
    )

    # Region bar chart from the per-region values.
    region_values = result["image_values"]
    region_chart = create_shapley_bar_chart(
        list(region_values.keys()),
        list(region_values.values()),
        "Region Attribution",
    )

    # Masked images double as Gradio state and as the dropdown's options.
    masked_state = result["masked_images"]
    choices = _build_masked_choices(masked_state)

    meta = {
        "mode": "image_clip",
        "method": result["method"],
        "clip_model": clip_model,
        "seg_mode": result["seg_mode"],
        "grid_size": result["grid_size"],
        "mask_style": result["mask_style"],
        "n_regions": result["n_img"],
        "n_tokens": result["n_tok"],
    }

    if progress is not None:
        progress(1.0, desc="Done")

    overlay = result["overlay_img"]
    initial_choice = choices[0] if choices else None
    dropdown_update = gr.update(choices=choices, value=initial_choice)
    initial_masked = masked_state.get(choices[0]) if choices else None
    return (
        image,
        overlay,
        region_chart,
        dropdown_update,
        initial_masked,
        masked_state,
        meta,
    )
def _compute_mm_attributions_clip(
    image: Image.Image,
    caption: str,
    clip_model: str,
    seg_mode: str,
    grid_size: int,
    method: str,
    seed: int,
    progress=None,
):
    """Run the CLIP pipeline for the multimodal tab and package its UI outputs.

    Raises ``gr.Error`` when no image is given or the caption is blank.

    Returns:
        An 11-tuple: original image, overlay image, region chart, token
        chart, cross-modal chart, influence heatmap, interactive HTML view,
        dropdown update for the masked-image selector, the initially selected
        masked image (or ``None``), the masked-image state dict, and a
        metadata dict.
    """
    import numpy as np  # noqa: F401  (imported by the original; not referenced here)

    if image is None:
        raise gr.Error("Please upload an image.")
    if not (caption and caption.strip()):
        raise gr.Error("Please provide a caption or description.")

    result = _run_clip_attribution(
        image,
        caption.strip(),
        clip_model,
        seg_mode,
        int(grid_size),
        method,
        int(seed or 0),
        progress=progress,
    )
    labels = result["all_labels"]
    region_count = result["n_img"]
    token_count = result["n_tok"]

    # Per-region and per-token bar charts.
    region_names = list(result["image_values"].keys())
    region_scores = list(result["image_values"].values())
    region_chart = create_shapley_bar_chart(region_names, region_scores, "Region Attribution")

    token_names = list(result["token_values"].keys())
    token_scores = list(result["token_values"].values())
    token_chart = create_shapley_bar_chart(token_names, token_scores, "Token Attribution")

    # Cross-modal bar chart takes ((region_label, token_label), value) pairs.
    bar_pairs = []
    for members, weight in result["cross_interactions"]:
        image_side = [labels[m] for m in members if m < region_count]
        text_side = [labels[m] for m in members if m >= region_count]
        if image_side and text_side:
            bar_pairs.append(((image_side[0], text_side[0]), float(weight)))
    cross_chart = create_cross_modal_bar_chart(bar_pairs, "Cross-Modal Interactions", top_k=15)

    heatmap = create_influence_heatmap(
        region_names,
        token_names,
        result["influence_matrix"],
        "Influence Heatmap (Regions x Tokens)",
    )

    # Summary dict in the shape the benchmark interaction view consumes.
    region_entries = [
        {"label": region_names[i], "value": float(region_scores[i])}
        for i in range(region_count)
    ]
    token_entries = [
        {"label": token_names[j], "value": float(token_scores[j])}
        for j in range(token_count)
    ]
    top_cross_entries = [
        {"label": " x ".join(labels[m] for m in members), "value": float(weight)}
        for members, weight in result["cross_global_top5"]
    ]
    clip_summary = {
        "image_region_values": region_entries,
        "token_values": token_entries,
        "cross_modal_interactions": top_cross_entries,
    }

    image_b64 = _encode_image_to_b64(image)

    # Normalise every cross pair to (region label, token label) order.
    cross_pair_payload = []
    for members, weight in result["cross_interactions"]:
        first = members[0]
        second = members[1]
        region_label = labels[first] if first < region_count else labels[second]
        token_label = labels[second] if second >= region_count else labels[first]
        cross_pair_payload.append(
            {"pair": (region_label, token_label), "value": float(weight)}
        )

    interaction_html = create_benchmark_interaction_html(
        image_b64=image_b64,
        clip_summary=clip_summary,
        vllm_logprob=None,
        caption=caption,
        all_cross_modal_pairs=cross_pair_payload,
        segmap_b64=result["segmap_b64"],
        overlay_b64=result["overlay_b64"],
        segment_bboxes=result["segment_bboxes"],
        label_map_b64="",
        image_width=image.size[0],
        image_height=image.size[1],
        title="Cross-Modal Interaction View",
    )

    # Masked images double as Gradio state and as the dropdown's options.
    masked_state = result["masked_images"]
    choices = _build_masked_choices(masked_state)

    meta = {
        "mode": "multimodal_clip",
        "method": result["method"],
        "clip_model": clip_model,
        "seg_mode": result["seg_mode"],
        "grid_size": result["grid_size"],
        "mask_style": result["mask_style"],
        "n_regions": region_count,
        "n_tokens": token_count,
    }

    if progress is not None:
        progress(1.0, desc="Done")

    overlay = result["overlay_img"]
    initial_choice = choices[0] if choices else None
    dropdown_update = gr.update(choices=choices, value=initial_choice)
    initial_masked = masked_state.get(choices[0]) if choices else None
    return (
        image,
        overlay,
        region_chart,
        token_chart,
        cross_chart,
        heatmap,
        interaction_html,
        dropdown_update,
        initial_masked,
        masked_state,
        meta,
    )
def on_select_example(
dataset,
ex_id,
model_size,
order,
method,
scalarizer=None,
feature_level=None,
):
"""
Public mode handler: load a precomputed example and render figures.
Args:
dataset (str): dataset name
ex_id (str): example id
model_size (str): "small" | "medium" | "large"
order (int): interaction order (2 or 3)
method (str): "shapley" | "banzhaf" | "influence"
Returns:
tuple ordered as:
(
context,
prompt,
answer,
interactions_plot,
interactions_token_html,
text_html,
meta_json,
)
"""
get_res = get_result_by_id if get_result_by_id is not None else _public_get_result_from_file
model_size = _normalize_model_size(model_size)
example = {"context": "", "prompt": "", "answer": ""}
if get_example_by_id is not None:
try:
example = get_example_by_id(dataset, ex_id)
except Exception:
pass
result = get_res(
model_size,
dataset,
ex_id,
scalarizer=scalarizer,
feature_level=feature_level,
) or {}
payload = result.get(method, {})
# Your JSON: features (list of strings) + mobius_dict. Convert to UI format if needed.
feats = payload.get("features") if isinstance(payload, dict) else None
if isinstance(feats, list) and feats and not isinstance(feats[0], dict):
payload = _normalize_public_payload_fallback(payload, method)
features, feature_values = _extract_feature_series(payload)
if not features:
features = ["<no features>"]
feature_values = [0.0]
# Influence scores are non-negative (squared Fourier coefficients)
if method == "influence":
feature_values = [abs(v) for v in feature_values]
marginals = {feat: float(feature_values[idx]) for idx, feat in enumerate(features)}
interactions = _resolve_interactions(payload, order)
if method == "influence":
interactions = [(feats, abs(val)) for feats, val in interactions]
pairwise = _resolve_interactions(payload, 2)
if not pairwise:
mixed = payload.get("interactions")
normalized = _normalize_interactions(mixed)
if normalized:
pairwise = [item for item in normalized if len(item[0]) == 2]
pairwise = [(feats, abs(val)) for feats, val in pairwise]
else:
pairwise = _resolve_pairwise(payload, features, feature_values)
if method == "influence":
top_singletons = sorted(
list(zip(features, feature_values)),
key=lambda kv: abs(float(kv[1])),
reverse=True,
)[:10]
top_pairs = sorted(
pairwise,
key=lambda kv: abs(float(kv[1])),
reverse=True,
)[:10]
print(
"[influence-ui-debug][public] "
f"dataset={dataset} ex_id={ex_id} feature_count={len(features)} pair_count={len(pairwise)}"
, flush=True)
print("[influence-ui-debug][public] top_singletons:", top_singletons, flush=True)
print("[influence-ui-debug][public] top_pairwise:", top_pairs, flush=True)
payload_level = (
payload.get("mask_level")
or payload.get("feature_level")
or payload.get("level")
or (result.get("meta", {}) if isinstance(result, dict) else {}).get("feature_level")
)
layout_mode = "sentence" if _normalize_level(payload_level) == "sentence" else "token"
inter = plot_top_interactions(interactions, order=order, method=method)
spans = payload.get("feature_spans") or payload.get("spans")
if not spans:
# Precomputed JSON payloads may not include explicit spans.
# Reconstruct spans from context + feature level so Text View can render.
_, fallback_spans, _ = _chunk_text_for_visualization(
example.get("context", ""),
_normalize_level(payload_level),
)
if fallback_spans and len(fallback_spans) == len(feature_values):
spans = fallback_spans
html = None
if spans and len(spans) == len(feature_values):
html = create_interactive_text_heatmap(
example.get("context", ""),
spans,
feature_values,
method=method,
)
# Compute the wrong-answer payload up-front so the dual heatmap branch
# (which rewrites text_interaction_html below) has it ready.
_wrong_values_for_dual: Optional[List[float]] = None
_wrong_pairwise_for_dual: Optional[List[Any]] = None
_wrong_features_for_dual: Optional[List[str]] = None
try:
from visualization.wrong_answer_examples import has_wrong_answer_view as _has_wrong_view
except Exception:
_has_wrong_view = None
_is_wrong_view = bool(
_has_wrong_view is not None
and html is not None
and spans
and _has_wrong_view(dataset, ex_id, scalarizer or "", feature_level or "")
)
if _is_wrong_view:
wrong_result = _public_get_model_answer_short_from_file(
model_size, dataset, ex_id, scalarizer or "geomean_jointprob",
feature_level or "word",
)
wrong_payload = wrong_result.get(method, {}) if wrong_result else {}
wrong_features_local, wrong_values_local = _extract_feature_series(wrong_payload)
if method == "influence":
wrong_values_local = [abs(v) for v in wrong_values_local]
if wrong_features_local and len(wrong_values_local) == len(feature_values):
# Build wrong-side pairwise edges, mirroring the GT logic above.
wrong_pairwise = _resolve_interactions(wrong_payload, 2)
if not wrong_pairwise:
mixed = wrong_payload.get("interactions") if isinstance(wrong_payload, dict) else None
normalized = _normalize_interactions(mixed)
if normalized:
wrong_pairwise = [item for item in normalized if len(item[0]) == 2]
if method == "influence":
wrong_pairwise = [(f, abs(v)) for f, v in (wrong_pairwise or [])]
else:
# Best-effort: if no explicit pairwise, derive from wrong feature values
if not wrong_pairwise:
wrong_pairwise = _resolve_pairwise(wrong_payload, wrong_features_local, wrong_values_local)
_wrong_values_for_dual = wrong_values_local
_wrong_pairwise_for_dual = wrong_pairwise or []
_wrong_features_for_dual = wrong_features_local
else:
_is_wrong_view = False
meta = {
"dataset": dataset,
"example_id": ex_id,
"model_size": model_size,
"method": method,
"order": order,
"feature_count": len(features),
"payload_keys": sorted(payload.keys()),
}
if "meta" in result:
meta["source_meta"] = result["meta"]
interaction_chips_html = create_interaction_token_view(
features,
feature_values,
pairwise or [item for item in interactions if len(item[0]) == 2],
method=method,
layout=layout_mode,
)
text_interaction_html = create_text_interaction_html(
features,
feature_values,
_pairwise_to_index_interactions(
pairwise or [item for item in interactions if len(item[0]) == 2],
features,
),
method=method,
top_k=20,
threshold=0.0,
)
# For the 30 wrong-answer examples, replace the visible Text Interaction
# view with two chip+arc panels side-by-side (vs Ground Truth | vs Model
# Answer (Wrong)) plus a single shared legend + RAW TEXT below.
if (
_is_wrong_view
and _wrong_values_for_dual is not None
and _wrong_features_for_dual is not None
):
gt_view = create_text_interaction_html(
features,
feature_values,
_pairwise_to_index_interactions(
pairwise or [item for item in interactions if len(item[0]) == 2],
features,
),
method=method,
top_k=20,
threshold=0.0,
)
wrong_view = create_text_interaction_html(
_wrong_features_for_dual,
_wrong_values_for_dual,
_pairwise_to_index_interactions(
_wrong_pairwise_for_dual or [],
_wrong_features_for_dual,
),
method=method,
top_k=20,
threshold=0.0,
)
method_label = (method or "attribution").title()
gt_max_abs = max((abs(v) for v in feature_values), default=0.0) or 1.0
wrong_max_abs = max((abs(v) for v in _wrong_values_for_dual), default=0.0) or 1.0
from html import escape as _escape
raw_text = example.get("context", "") or ""
raw_text_html = _escape(raw_text).replace("\n", "<br/>") if raw_text else ""
# CSS scoped to .dual-heatmap-row hides the per-side legend so we can
# show one shared legend below; tightens the per-card max width so two
# views fit comfortably side-by-side.
dual_css = (
"<style>"
".dual-heatmap-row{display:grid;grid-template-columns:1fr 1fr;gap:16px;align-items:start;}"
".dual-heatmap-row .text-interaction-side-panel{display:none !important;}"
".dual-heatmap-row .text-interaction-root{flex:1 1 100%;}"
".dual-heatmap-row .text-interaction-card{flex:1 1 100%;max-width:100%;}"
".dual-heatmap-shared{margin-top:16px;display:grid;grid-template-columns:1fr 1fr;gap:16px;}"
".dual-heatmap-shared .shared-card{background:#f8f5ff;border:1px solid #e2d6f3;"
"border-radius:12px;padding:12px 14px;box-shadow:0 4px 10px rgba(80,50,140,0.05);}"
".dual-heatmap-shared .shared-legend-bar{display:flex;align-items:center;gap:8px;margin:6px 0;}"
".dual-heatmap-shared .shared-legend-label{font-size:12px;color:#6f5a72;text-transform:uppercase;letter-spacing:.04em;}"
".dual-heatmap-shared .shared-legend-gradient{flex:1;height:10px;border-radius:999px;"
"background:linear-gradient(90deg,#dd1313 0%,#d8c6f0 50%,#4a1c87 100%);}"
".dual-heatmap-shared .shared-legend-note{font-size:12px;color:#6f5a72;margin:4px 0 0 0;}"
".dual-heatmap-shared .shared-raw-text p{margin:6px 0 0 0;line-height:1.5;color:#3a2b4a;}"
"@media (prefers-color-scheme: dark){"
".dual-heatmap-shared .shared-card{background:#111a2b;border-color:#33435f;}"
".dual-heatmap-shared .shared-legend-label,.dual-heatmap-shared .shared-legend-note,"
".dual-heatmap-shared .shared-raw-text p{color:#a9b6cb;}}"
"@media (max-width: 900px){"
".dual-heatmap-row,.dual-heatmap-shared{grid-template-columns:1fr;}}"
"</style>"
)
shared_block = (
'<div class="dual-heatmap-shared">'
'<div class="shared-card">'
f'<strong>{method_label} legend</strong>'
'<div class="shared-legend-bar">'
'<span class="shared-legend-label">Negative</span>'
'<div class="shared-legend-gradient"></div>'
'<span class="shared-legend-label">Positive</span>'
'</div>'
'<p class="shared-legend-note">'
f'Ground-truth max |value| = {gt_max_abs:.4f}; '
f'wrong-answer max |value| = {wrong_max_abs:.4f}. '
'Hover tokens for exact scores.'
'</p>'
'</div>'
'<div class="shared-card shared-raw-text">'
'<strong>Raw text</strong>'
f'<p>{raw_text_html or "<em>No context available.</em>"}</p>'
'</div>'
'</div>'
)
text_interaction_html = (
f'{dual_css}'
'<div class="dual-heatmap-row">'
'<div>'
'<div class="heatmap-caption" '
'style="font-weight:600;margin-bottom:6px;">vs Ground Truth</div>'
f'{gt_view}'
'</div>'
'<div>'
'<div class="heatmap-caption" '
'style="font-weight:600;margin-bottom:6px;">vs Model Answer (Wrong)</div>'
f'{wrong_view}'
'</div>'
'</div>'
f'{shared_block}'
)
print(
f"[wrong-answer] dual chip+lines view rendered for {dataset}/{ex_id} "
f"(gt_features={len(features)} wrong_features={len(_wrong_features_for_dual)} "
f"gt_pairs={len(pairwise or [])} wrong_pairs={len(_wrong_pairwise_for_dual or [])})",
flush=True,
)
figs = {
"interactions": inter,
}
outputs = update(
figs=figs,
meta=meta,
html=html,
interaction_html=interaction_chips_html,
interaction_text_html=text_interaction_html,
)
return (
example.get("context", ""),
example.get("prompt", ""),
_extract_answer(example),
*outputs,
)
def on_click_compute(
    context,
    prompt,
    correct_answer,
    model_size,
    level,
    method,
    scalarizer,
    embedding_model,
    progress=gr.Progress(track_tqdm=True),
):
    """Developer-mode handler: normalize the form inputs and run live attribution.

    ``method``, ``level`` and ``model_size`` are passed through their
    ``_normalize_*`` helpers; falsy ``context``/``prompt``/``correct_answer``
    values become empty strings; the interaction order is fixed at 2.
    ``scalarizer``, ``embedding_model`` and ``progress`` are forwarded as-is.

    Returns:
        Whatever ``_compute_live_attributions`` returns. There is no fallback
        here: an exception raised by the live pipeline propagates to the caller.
    """
    # Normalizers run in the same order as before (method, level, model size).
    normalized_method = _normalize_method(method)
    normalized_level = _normalize_level(level)
    normalized_size = _normalize_model_size(model_size)

    request = dict(
        context=context or "",
        prompt=prompt or "",
        correct_answer=correct_answer or "",
        model_size=normalized_size,
        level=normalized_level,
        method=normalized_method,
        order=2,
        scalarizer=scalarizer,
        embedding_model=embedding_model,
        progress=progress,
    )
    return _compute_live_attributions(**request)
# ---------------------------------------------------------------------------
# Multimodal precomputed example handlers (MIMIC-CXR, ISIC, MS-COCO)
# ---------------------------------------------------------------------------
# ── MIMIC-CXR Tab Handlers ────────────────────────────────────────────────
_MIMIC_METHOD_NAMES = [
"BiomedCLIP Cross-Modal",
"LLaVA-Med Log-Prob",
"LLaVA-Med Generation",
]
def _on_select_mimic_example(example_id, method_label: str = "Influence"):
    """Load a precomputed MIMIC-CXR example and build the MIMIC tab outputs.

    Args:
        example_id: Identifier passed to ``load_mimic_example``. A falsy value
            short-circuits to the empty result.
        method_label: Attribution method label. It is lower-cased before being
            passed to the loader; anything other than "influence" is shown as
            "Shapley" in chart titles.

    Returns:
        A 15-tuple, in order: caption, original image path, findings,
        interpretation text, BiomedCLIP labeled overlay, BiomedCLIP token plot,
        BiomedCLIP region plot, LLaVA-Med log-prob overlay, LLaVA-Med log-prob
        plot, LLaVA-Med generation overlay, LLaVA-Med generation plot,
        BiomedCLIP interaction HTML, metadata dict, per-method results state
        (consumed by the comparison handler) and a no-op ``gr.update()``.
        When MIMIC support is unavailable, no example is selected, or loading
        fails, the tuple is ``""`` followed by fourteen ``None`` values.
    """
    n_outputs = 15
    empty = tuple([""] + [None] * (n_outputs - 1))
    if not _MIMIC_AVAILABLE or not example_id:
        return empty

    method = (method_label or "Influence").lower()
    method_display = "Influence" if method == "influence" else "Shapley"

    # Shadow the module-level builders with local wrappers so every call below
    # carries the selected method's display label.
    _base_chart = globals()["create_shapley_bar_chart"]
    _base_html = globals()["create_benchmark_interaction_html"]

    def create_shapley_bar_chart(labels, values, title="Shapley Values", **kwargs):  # noqa: F811
        kwargs.setdefault("method_label", method_display)
        return _base_chart(labels, values, title.replace("Shapley", method_display), **kwargs)

    def create_benchmark_interaction_html(**kwargs):  # noqa: F811
        kwargs.setdefault("method_label", method_display)
        return _base_html(**kwargs)

    try:
        data = load_mimic_example(example_id, method=method)
    except Exception as exc:
        # Report the failure instead of silently blanking the tab.
        print(f"[mimic] Error loading {example_id}: {exc}")
        return empty

    caption = data.get("caption", "")
    findings = data.get("findings", "")
    original_img = data.get("original_image_path")
    meta = data.get("meta", {})
    category = meta.get("category", "")

    # ── BiomedCLIP ───────────────────────────────────────────────────
    biomedclip_overlay_labeled = None
    biomedclip_region_plot = None
    biomedclip_token_plot = None
    biomedclip_interaction_html = ""
    if data.get("has_biomedclip"):
        bc_summary = data["biomedclip"]["summary"]
        bc_image_paths = data["biomedclip"]["image_paths"]
        bc_overlay_raw = bc_image_paths.get("overlay")
        bc_original = bc_image_paths.get("original", "")
        bc_segmap = bc_image_paths.get("segmap", "")
        bc_regions = bc_summary.get("image_region_values", [])
        bc_n_segs = len(bc_regions)

        bc_bboxes, bc_label_map_b64 = None, ""
        if bc_original and bc_segmap and bc_n_segs > 0:
            try:
                bc_bboxes, bc_label_map_b64 = extract_segment_regions(
                    bc_original, bc_segmap, bc_n_segs)
            except Exception:
                # Best-effort: keep the (None, "") defaults when extraction fails.
                pass

        if bc_overlay_raw:
            biomedclip_overlay_labeled = draw_segment_labels(
                bc_overlay_raw, bc_regions,
                segment_bboxes=bc_bboxes,
                label_map_b64=bc_label_map_b64,
                original_path=bc_original)

        bc_r_labels = [v["label"] for v in bc_regions]
        bc_r_values = [v["value"] for v in bc_regions]
        if bc_r_labels:
            biomedclip_region_plot = create_shapley_bar_chart(
                bc_r_labels, bc_r_values, "BiomedCLIP — Image Region Shapley Values")

        bc_merged = merge_subword_token_values(bc_summary.get("token_values", []), caption)
        bc_t_labels = [v["label"] for v in bc_merged]
        bc_t_values = [v["value"] for v in bc_merged]
        if bc_t_labels:
            biomedclip_token_plot = create_shapley_bar_chart(
                bc_t_labels, bc_t_values, "BiomedCLIP — Caption Word Shapley Values")

        # Interactive cross-modal HTML
        bc_image_b64 = data["biomedclip"].get("image_b64", {}).get("original", "")
        bc_overlay_b64 = data["biomedclip"].get("image_b64", {}).get("overlay", "")
        bc_all_cross = data["biomedclip"].get("all_cross_modal_pairs", [])
        bc_segmap_b64 = ""
        if bc_segmap and os.path.exists(bc_segmap):
            # Inline the segmentation map file as base64 for the HTML view.
            with open(bc_segmap, "rb") as segmap_file:
                bc_segmap_b64 = base64.b64encode(segmap_file.read()).decode("ascii")

        biomedclip_interaction_html = create_benchmark_interaction_html(
            image_b64=bc_image_b64,
            clip_summary=bc_summary,
            vllm_logprob=None,
            caption=caption,
            all_cross_modal_pairs=bc_all_cross,
            segmap_b64=bc_segmap_b64,
            overlay_b64=bc_overlay_b64,
            segment_bboxes=bc_bboxes,
            label_map_b64=bc_label_map_b64,
            title="BiomedCLIP Cross-Modal Interaction View — click segments or words",
        )

    # ── LLaVA-Med UnSAM ─────────────────────────────────────────────
    # Draw two separate overlays — one labeled with Log-Prob values, one with
    # Generation values — since the signs often differ between methods.
    llavamed_unsam_lp_overlay_img = None
    llavamed_unsam_gen_overlay_img = None
    llavamed_unsam_lp_plot = None
    llavamed_unsam_gen_plot = None
    if data.get("has_llavamed_unsam_logprob") or data.get("has_llavamed_unsam_gen"):
        lu_segmap = data.get("llavamed_unsam_segmap_path", "")
        lu_original = data.get("llavamed_unsam_original_path", "") or (original_img or "")
        lu_bboxes, lu_label_map_b64 = None, ""
        if lu_segmap and lu_original:
            # Segment count comes from the log-prob results when present,
            # otherwise from the generation results.
            n_lu_segs = 0
            if data.get("has_llavamed_unsam_logprob"):
                n_lu_segs = len(data["llavamed_unsam_logprob"].get("image_region_values", []))
            elif data.get("has_llavamed_unsam_gen"):
                n_lu_segs = len(data["llavamed_unsam_gen"].get("image_region_values", []))
            if n_lu_segs > 0:
                try:
                    lu_bboxes, lu_label_map_b64 = extract_segment_regions(
                        lu_original, lu_segmap, n_lu_segs)
                except Exception:
                    # Best-effort: keep the (None, "") defaults when extraction fails.
                    pass

        if data.get("has_llavamed_unsam_logprob"):
            lu_lp = rename_patch_labels(
                data["llavamed_unsam_logprob"].get("image_region_values", []))
            if lu_lp:
                llavamed_unsam_lp_plot = create_shapley_bar_chart(
                    [v["label"] for v in lu_lp],
                    [v["value"] for v in lu_lp],
                    "LLaVA-Med Log-Prob — Segment Shapley Values",
                )
            overlay_path = data["llavamed_unsam_logprob"].get("overlay_path", "")
            if overlay_path:
                llavamed_unsam_lp_overlay_img = draw_segment_labels(
                    overlay_path, lu_lp,
                    segment_bboxes=lu_bboxes,
                    label_map_b64=lu_label_map_b64,
                    original_path=lu_original)

        if data.get("has_llavamed_unsam_gen"):
            lu_gen = rename_patch_labels(
                data["llavamed_unsam_gen"].get("image_region_values", []))
            if lu_gen:
                llavamed_unsam_gen_plot = create_shapley_bar_chart(
                    [v["label"] for v in lu_gen],
                    [v["value"] for v in lu_gen],
                    "LLaVA-Med Generation — Segment Shapley Values",
                )
            # Prefer the generation overlay; fall back to the log-prob overlay
            # as the base image when no generation overlay path is recorded.
            overlay_path = (data["llavamed_unsam_gen"].get("overlay_path", "")
                            or data.get("llavamed_unsam_logprob", {}).get("overlay_path", ""))
            if overlay_path:
                llavamed_unsam_gen_overlay_img = draw_segment_labels(
                    overlay_path, lu_gen,
                    segment_bboxes=lu_bboxes,
                    label_map_b64=lu_label_map_b64,
                    original_path=lu_original)

    # ── Interpretation text ──────────────────────────────────────────
    interpretation = ""
    try:
        bc_data = data.get("biomedclip", {}).get("summary") if data.get("has_biomedclip") else None
        interpretation = generate_interpretation_text(
            clip_summary=bc_data,
            vllm_logprob=data.get("llavamed_unsam_logprob") if data.get("has_llavamed_unsam_logprob") else None,
            modality="Chest X-ray",
            body_part=category,
            caption=caption,
            cross_method_name="BiomedCLIP",
            vlm_method_name="LLaVA-Med",
            vlm_region_type="UnSAM segments",
        )
    except Exception:
        # Best-effort: the interpretation stays empty if generation fails.
        pass

    # If no precomputed results at all, show informative message
    if not any(data.get(k) for k in ("has_biomedclip", "has_llavamed_unsam_logprob",
                                     "has_llavamed_unsam_gen", "has_clip")):
        interpretation = (
            "**No precomputed attribution results yet.**\n\n"
            "Run the attribution pipeline on this MIMIC-CXR example to see results here. "
            "The image and report are shown above for reference."
        )

    # ── Build results state for comparison ───────────────────────────
    # Only methods that produced an overlay are offered for comparison.
    _results_state = {}
    if biomedclip_overlay_labeled:
        _results_state["BiomedCLIP Cross-Modal"] = {
            "overlay": biomedclip_overlay_labeled, "plot": biomedclip_region_plot}
    if llavamed_unsam_lp_overlay_img:
        _results_state["LLaVA-Med Log-Prob"] = {
            "overlay": llavamed_unsam_lp_overlay_img, "plot": llavamed_unsam_lp_plot}
    if llavamed_unsam_gen_overlay_img:
        _results_state["LLaVA-Med Generation"] = {
            "overlay": llavamed_unsam_gen_overlay_img, "plot": llavamed_unsam_gen_plot}

    return (
        caption,                          # 1
        original_img,                     # 2
        findings,                         # 3
        interpretation,                   # 4
        biomedclip_overlay_labeled,       # 5
        biomedclip_token_plot,            # 6
        biomedclip_region_plot,           # 7
        llavamed_unsam_lp_overlay_img,    # 8
        llavamed_unsam_lp_plot,           # 9
        llavamed_unsam_gen_overlay_img,   # 10
        llavamed_unsam_gen_plot,          # 11
        biomedclip_interaction_html,      # 12
        {                                 # 13 — metadata
            "example_id": example_id,
            "category": category,
            "has_biomedclip": data.get("has_biomedclip", False),
            "has_llavamed_unsam_logprob": data.get("has_llavamed_unsam_logprob", False),
            "has_llavamed_unsam_gen": data.get("has_llavamed_unsam_gen", False),
        },
        _results_state,                   # 14
        gr.update(),                      # 15 (placeholder)
    )
def _on_mimic_compare_methods(method_a, method_b, results_state):
"""Pick two MIMIC methods from state and display side by side."""
if not method_a or not method_b or not results_state:
return None, None, None, None
a = results_state.get(method_a, {})
b = results_state.get(method_b, {})
return a.get("overlay"), b.get("overlay"), a.get("plot"), b.get("plot")
# ── ISIC Dermoscopy Tab Handlers ──────────────────────────────────────────
_ISIC_METHOD_NAMES = [
"BiomedCLIP Cross-Modal",
"LLaVA-Med Log-Prob",
"LLaVA-Med Generation",
]
def _on_select_isic_example(example_id, method_label: str = "Influence"):
    """Load a precomputed ISIC dermoscopy example and build the ISIC tab outputs.

    Follows the same pipeline as the MIMIC-CXR handler, but ISIC has no
    separate "findings" field, so that slot is omitted and the result has 14
    values rather than 15.

    Args:
        example_id: Identifier passed to ``load_isic_example``. A falsy value
            short-circuits to the empty result.
        method_label: Attribution method label. It is lower-cased before being
            passed to the loader; anything other than "influence" is shown as
            "Shapley" in chart titles.

    Returns:
        A 14-tuple, in order: caption, original image path, interpretation
        text, BiomedCLIP labeled overlay, BiomedCLIP token plot, BiomedCLIP
        region plot, LLaVA-Med log-prob overlay, LLaVA-Med log-prob plot,
        LLaVA-Med generation overlay, LLaVA-Med generation plot, BiomedCLIP
        interaction HTML, metadata dict, per-method results state (consumed by
        the comparison handler) and a no-op ``gr.update()``. When ISIC support
        is unavailable, no example is selected, or loading fails, the tuple is
        ``""`` followed by thirteen ``None`` values.
    """
    n_outputs = 14
    empty = tuple([""] + [None] * (n_outputs - 1))
    if not _ISIC_AVAILABLE or not example_id:
        return empty

    method = (method_label or "Influence").lower()
    method_display = "Influence" if method == "influence" else "Shapley"

    # Shadow the module-level builders with local wrappers so every call below
    # carries the selected method's display label.
    _base_chart = globals()["create_shapley_bar_chart"]
    _base_html = globals()["create_benchmark_interaction_html"]

    def create_shapley_bar_chart(labels, values, title="Shapley Values", **kwargs):  # noqa: F811
        kwargs.setdefault("method_label", method_display)
        return _base_chart(labels, values, title.replace("Shapley", method_display), **kwargs)

    def create_benchmark_interaction_html(**kwargs):  # noqa: F811
        kwargs.setdefault("method_label", method_display)
        return _base_html(**kwargs)

    try:
        data = load_isic_example(example_id, method=method)
    except Exception as exc:
        # Report the failure instead of silently blanking the tab.
        print(f"[isic] Error loading {example_id}: {exc}")
        return empty

    caption = data.get("caption", "")
    original_img = data.get("original_image_path")
    meta = data.get("meta", {})
    category = meta.get("category", "")

    # ── BiomedCLIP ───────────────────────────────────────────────────
    biomedclip_overlay_labeled = None
    biomedclip_region_plot = None
    biomedclip_token_plot = None
    biomedclip_interaction_html = ""
    if data.get("has_biomedclip"):
        bc_summary = data["biomedclip"]["summary"]
        bc_image_paths = data["biomedclip"]["image_paths"]
        bc_overlay_raw = bc_image_paths.get("overlay")
        bc_original = bc_image_paths.get("original", "")
        bc_segmap = bc_image_paths.get("segmap", "")
        bc_regions = bc_summary.get("image_region_values", [])
        bc_n_segs = len(bc_regions)

        bc_bboxes, bc_label_map_b64 = None, ""
        if bc_original and bc_segmap and bc_n_segs > 0:
            try:
                bc_bboxes, bc_label_map_b64 = extract_segment_regions(
                    bc_original, bc_segmap, bc_n_segs)
            except Exception:
                # Best-effort: keep the (None, "") defaults when extraction fails.
                pass

        if bc_overlay_raw:
            biomedclip_overlay_labeled = draw_segment_labels(
                bc_overlay_raw, bc_regions,
                segment_bboxes=bc_bboxes,
                label_map_b64=bc_label_map_b64,
                original_path=bc_original)

        bc_r_labels = [v["label"] for v in bc_regions]
        bc_r_values = [v["value"] for v in bc_regions]
        if bc_r_labels:
            biomedclip_region_plot = create_shapley_bar_chart(
                bc_r_labels, bc_r_values, "BiomedCLIP — Image Region Shapley Values")

        bc_merged = merge_subword_token_values(bc_summary.get("token_values", []), caption)
        bc_t_labels = [v["label"] for v in bc_merged]
        bc_t_values = [v["value"] for v in bc_merged]
        if bc_t_labels:
            biomedclip_token_plot = create_shapley_bar_chart(
                bc_t_labels, bc_t_values, "BiomedCLIP — Caption Word Shapley Values")

        # Interactive cross-modal HTML
        bc_image_b64 = data["biomedclip"].get("image_b64", {}).get("original", "")
        bc_overlay_b64 = data["biomedclip"].get("image_b64", {}).get("overlay", "")
        bc_all_cross = data["biomedclip"].get("all_cross_modal_pairs", [])
        bc_segmap_b64 = ""
        if bc_segmap and os.path.exists(bc_segmap):
            # Inline the segmentation map file as base64 for the HTML view.
            with open(bc_segmap, "rb") as segmap_file:
                bc_segmap_b64 = base64.b64encode(segmap_file.read()).decode("ascii")

        biomedclip_interaction_html = create_benchmark_interaction_html(
            image_b64=bc_image_b64,
            clip_summary=bc_summary,
            vllm_logprob=None,
            caption=caption,
            all_cross_modal_pairs=bc_all_cross,
            segmap_b64=bc_segmap_b64,
            overlay_b64=bc_overlay_b64,
            segment_bboxes=bc_bboxes,
            label_map_b64=bc_label_map_b64,
            title="BiomedCLIP Cross-Modal Interaction View — click segments or words",
        )

    # ── LLaVA-Med UnSAM ─────────────────────────────────────────────
    llavamed_unsam_lp_overlay_img = None
    llavamed_unsam_gen_overlay_img = None
    llavamed_unsam_lp_plot = None
    llavamed_unsam_gen_plot = None
    if data.get("has_llavamed_unsam_logprob") or data.get("has_llavamed_unsam_gen"):
        lu_segmap = data.get("llavamed_unsam_segmap_path", "")
        lu_original = data.get("llavamed_unsam_original_path", "") or (original_img or "")
        lu_bboxes, lu_label_map_b64 = None, ""
        if lu_segmap and lu_original:
            # Segment count comes from the log-prob results when present,
            # otherwise from the generation results.
            n_lu_segs = 0
            if data.get("has_llavamed_unsam_logprob"):
                n_lu_segs = len(data["llavamed_unsam_logprob"].get("image_region_values", []))
            elif data.get("has_llavamed_unsam_gen"):
                n_lu_segs = len(data["llavamed_unsam_gen"].get("image_region_values", []))
            if n_lu_segs > 0:
                try:
                    lu_bboxes, lu_label_map_b64 = extract_segment_regions(
                        lu_original, lu_segmap, n_lu_segs)
                except Exception:
                    # Best-effort: keep the (None, "") defaults when extraction fails.
                    pass

        if data.get("has_llavamed_unsam_logprob"):
            lu_lp = rename_patch_labels(
                data["llavamed_unsam_logprob"].get("image_region_values", []))
            if lu_lp:
                llavamed_unsam_lp_plot = create_shapley_bar_chart(
                    [v["label"] for v in lu_lp],
                    [v["value"] for v in lu_lp],
                    "LLaVA-Med Log-Prob — Segment Shapley Values",
                )
            overlay_path = data["llavamed_unsam_logprob"].get("overlay_path", "")
            if overlay_path:
                llavamed_unsam_lp_overlay_img = draw_segment_labels(
                    overlay_path, lu_lp,
                    segment_bboxes=lu_bboxes,
                    label_map_b64=lu_label_map_b64,
                    original_path=lu_original)

        if data.get("has_llavamed_unsam_gen"):
            lu_gen = rename_patch_labels(
                data["llavamed_unsam_gen"].get("image_region_values", []))
            if lu_gen:
                llavamed_unsam_gen_plot = create_shapley_bar_chart(
                    [v["label"] for v in lu_gen],
                    [v["value"] for v in lu_gen],
                    "LLaVA-Med Generation — Segment Shapley Values",
                )
            # Prefer the generation overlay; fall back to the log-prob overlay
            # as the base image when no generation overlay path is recorded.
            overlay_path = (data["llavamed_unsam_gen"].get("overlay_path", "")
                            or data.get("llavamed_unsam_logprob", {}).get("overlay_path", ""))
            if overlay_path:
                llavamed_unsam_gen_overlay_img = draw_segment_labels(
                    overlay_path, lu_gen,
                    segment_bboxes=lu_bboxes,
                    label_map_b64=lu_label_map_b64,
                    original_path=lu_original)

    # ── Interpretation text ──────────────────────────────────────────
    interpretation = ""
    try:
        bc_data = data.get("biomedclip", {}).get("summary") if data.get("has_biomedclip") else None
        interpretation = generate_interpretation_text(
            clip_summary=bc_data,
            vllm_logprob=data.get("llavamed_unsam_logprob") if data.get("has_llavamed_unsam_logprob") else None,
            modality="Dermoscopy",
            body_part=category,
            caption=caption,
            cross_method_name="BiomedCLIP",
            vlm_method_name="LLaVA-Med",
            vlm_region_type="UnSAM segments",
        )
    except Exception:
        # Best-effort: the interpretation stays empty if generation fails.
        pass

    # If no precomputed results at all, show informative message
    if not any(data.get(k) for k in ("has_biomedclip", "has_llavamed_unsam_logprob",
                                     "has_llavamed_unsam_gen", "has_clip")):
        interpretation = (
            "**No precomputed attribution results yet.**\n\n"
            "Run the attribution pipeline on this ISIC example to see results here. "
            "The image and caption are shown above for reference."
        )

    # ── Results state for comparison dropdowns ──────────────────────
    # Only methods that produced an overlay are offered for comparison.
    _results_state = {}
    if biomedclip_overlay_labeled:
        _results_state["BiomedCLIP Cross-Modal"] = {
            "overlay": biomedclip_overlay_labeled, "plot": biomedclip_region_plot}
    if llavamed_unsam_lp_overlay_img:
        _results_state["LLaVA-Med Log-Prob"] = {
            "overlay": llavamed_unsam_lp_overlay_img, "plot": llavamed_unsam_lp_plot}
    if llavamed_unsam_gen_overlay_img:
        _results_state["LLaVA-Med Generation"] = {
            "overlay": llavamed_unsam_gen_overlay_img, "plot": llavamed_unsam_gen_plot}

    return (
        caption,                          # 1
        original_img,                     # 2
        interpretation,                   # 3
        biomedclip_overlay_labeled,       # 4
        biomedclip_token_plot,            # 5
        biomedclip_region_plot,           # 6
        llavamed_unsam_lp_overlay_img,    # 7
        llavamed_unsam_lp_plot,           # 8
        llavamed_unsam_gen_overlay_img,   # 9
        llavamed_unsam_gen_plot,          # 10
        biomedclip_interaction_html,      # 11
        {                                 # 12 — metadata
            "example_id": example_id,
            "category": category,
            "has_biomedclip": data.get("has_biomedclip", False),
            "has_llavamed_unsam_logprob": data.get("has_llavamed_unsam_logprob", False),
            "has_llavamed_unsam_gen": data.get("has_llavamed_unsam_gen", False),
        },
        _results_state,                   # 13
        gr.update(),                      # 14 (placeholder for compare dropdown)
    )
def _on_isic_compare_methods(method_a, method_b, results_state):
"""Pick two ISIC methods from state and display side by side."""
if not method_a or not method_b or not results_state:
return None, None, None, None
a = results_state.get(method_a, {})
b = results_state.get(method_b, {})
return a.get("overlay"), b.get("overlay"), a.get("plot"), b.get("plot")
def _on_select_coco_example(example_id, method_label: str = "Influence"):
    """Load a precomputed MS-COCO example and return outputs for the COCO tab.

    Args:
        example_id: Identifier passed to ``load_coco_example``. A falsy value
            short-circuits to the empty result.
        method_label: Attribution method label. It is lower-cased before being
            passed to the loader; anything other than "influence" is shown as
            "Shapley" in chart titles.

    Returns:
        A 12-tuple, in order: caption, original image path, labeled overlay,
        interaction HTML, token plot, region plot, cross-modal plot,
        cross-modal table rows, influence heatmap, note text, first masked
        image, and a ``gr.update`` carrying the masked-image dropdown choices.
        When COCO/medical support is unavailable, no example is selected, or
        loading fails, the tuple is ``""``, ten ``None`` values and a no-op
        ``gr.update()``.
    """
    n_outputs = 12
    empty = ("",) + (None,) * (n_outputs - 2) + (gr.update(),)
    if not _COCO_AVAILABLE or not _MEDICAL_AVAILABLE or not example_id:
        return empty

    method = (method_label or "Influence").lower()
    method_display = "Influence" if method == "influence" else "Shapley"

    # Shadow the module-level builders with local wrappers so every call below
    # carries the selected method's display label.
    _base_chart = globals()["create_shapley_bar_chart"]
    _base_html = globals()["create_benchmark_interaction_html"]

    def create_shapley_bar_chart(labels, values, title="Shapley Values", **kwargs):  # noqa: F811
        kwargs.setdefault("method_label", method_display)
        return _base_chart(labels, values, title.replace("Shapley", method_display), **kwargs)

    def create_benchmark_interaction_html(**kwargs):  # noqa: F811
        kwargs.setdefault("method_label", method_display)
        return _base_html(**kwargs)

    try:
        data = load_coco_example(example_id, method=method)
    except Exception as exc:
        print(f"[coco] Error loading {example_id}: {exc}")
        return empty

    caption = data.get("caption", "")
    summary = data.get("summary", {})
    # Read "image_paths" once, tolerating a missing (or None) entry, so the
    # segmap lookup below cannot raise KeyError outside any try block.
    image_paths = data.get("image_paths") or {}
    original_img = image_paths.get("original")
    overlay_img = image_paths.get("overlay")

    # Segment bboxes from segmap
    segment_bboxes, label_map_b64 = None, ""
    clip_original = image_paths.get("original", "")
    clip_segmap = image_paths.get("segmap", "")
    r_vals = summary.get("image_region_values", [])
    n_segs = len(r_vals)
    if clip_original and clip_segmap and n_segs > 0:
        try:
            segment_bboxes, label_map_b64 = extract_segment_regions(
                clip_original, clip_segmap, n_segs)
        except Exception:
            # Best-effort: keep the (None, "") defaults when extraction fails.
            pass

    # Draw segment labels on overlay; fall back to the unlabeled overlay.
    overlay_labeled = overlay_img
    if overlay_img:
        try:
            labeled = draw_segment_labels(
                overlay_img,
                r_vals,
                segment_bboxes=segment_bboxes,
            )
            if labeled:
                overlay_labeled = labeled
        except Exception:
            pass

    # Bar charts
    r_labels = [v["label"] for v in r_vals]
    r_values = [v["value"] for v in r_vals]
    region_plot = create_shapley_bar_chart(
        r_labels, r_values, "CLIP — Image Region Shapley Values") if r_labels else None

    t_vals = summary.get("token_values", [])
    merged_toks = merge_subword_token_values(t_vals, caption)
    t_labels = [v["label"] for v in merged_toks]
    t_values = [v["value"] for v in merged_toks]
    token_plot = create_shapley_bar_chart(
        t_labels, t_values, "CLIP — Caption Word Shapley Values") if t_labels else None

    # Cross-modal pairs + chart + table
    all_cross = data.get("all_cross_modal_pairs", [])
    cross_plot = None
    cross_table = []
    if all_cross:
        cross_pairs = [
            ((item["pair"][0], _tok_to_word(item["pair"][1], caption)), item["value"])
            for item in all_cross
        ]
        cross_plot = create_cross_modal_bar_chart(
            cross_pairs, "CLIP — Top Image x Word Interactions", top_k=20)
        # The table shows at most the first 30 pairs.
        cross_table = [
            [item["pair"][0], _tok_to_word(item["pair"][1], caption), f"{item['value']:+.4f}"]
            for item in all_cross[:30]
        ]

    # Heatmap
    heatmap = None
    influence_matrix = data.get("influence_matrix")
    tok_labels_hm = [t.replace("tok:", "").lstrip("#") for t in data.get("tok_labels", [])]
    if influence_matrix is not None and influence_matrix.size > 0:
        heatmap = create_influence_heatmap(
            data.get("seg_labels", r_labels), tok_labels_hm, influence_matrix,
            "Image Regions x Caption Words — Influence Scores")

    # Interactive cross-modal HTML
    image_b64 = data.get("image_b64", {}).get("original", "")
    overlay_b64 = data.get("image_b64", {}).get("overlay", "")
    segmap_b64 = data.get("image_b64", {}).get("segmap", "")
    interaction_html = ""
    try:
        interaction_html = create_benchmark_interaction_html(
            image_b64=image_b64,
            clip_summary=summary,
            vllm_logprob=None,
            caption=caption,
            all_cross_modal_pairs=all_cross,
            segmap_b64=segmap_b64,
            overlay_b64=overlay_b64,
            segment_bboxes=segment_bboxes,
            label_map_b64=label_map_b64,
            title="MS-COCO — Click a region or word to explore interactions",
        )
    except Exception as exc:
        # Escape the message: exception text may contain "<", ">" or "&",
        # which would otherwise be interpreted as markup.
        from html import escape as _escape
        interaction_html = f"<p>Error building interaction view: {_escape(str(exc))}</p>"

    note = (
        "**Note:** These results used the original UNK mask token "
        "(same as `<|endoftext|>`, CLIP token ID 49407). "
        "A first-token dominance artifact may be visible in the token Shapley chart. "
        "This will be corrected when scaling to 100 images with the dot-mask fix."
    )

    # Masked Image Browser
    region_choices = data.get("region_choices", [])
    masked_dd_update = gr.update(
        choices=region_choices,
        value=region_choices[0] if region_choices else None,
    )
    # Pre-load the image for the first choice so the viewer isn't blank
    first_masked_img = None
    if region_choices:
        try:
            first_masked_img = get_coco_masked_image_path(example_id, region_choices[0])
        except Exception:
            pass

    return (
        caption,              # 1
        original_img,         # 2
        overlay_labeled,      # 3
        interaction_html,     # 4
        token_plot,           # 5
        region_plot,          # 6
        cross_plot,           # 7
        cross_table,          # 8
        heatmap,              # 9
        note,                 # 10
        first_masked_img,     # 11 — masked image viewer
        masked_dd_update,     # 12 — masked dropdown choices
    )
def _on_select_coco_masked(example_id, choice):
    """Resolve the masked-image path shown in the COCO Masked Image Browser.

    Returns ``None`` when COCO support is unavailable or either argument is
    falsy; otherwise returns ``get_coco_masked_image_path(example_id, choice)``.
    """
    if _COCO_AVAILABLE and example_id and choice:
        return get_coco_masked_image_path(example_id, choice)
    return None
def on_click_image_compute(
    image,
    caption,
    clip_model,
    seg_mode,
    grid_size,
    method,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Click handler: forward the form values to ``_compute_image_attributions_clip``.

    Every argument is passed through unchanged, by keyword, and the callee's
    return value is returned as-is.
    """
    forwarded = {
        "image": image,
        "caption": caption,
        "clip_model": clip_model,
        "seg_mode": seg_mode,
        "grid_size": grid_size,
        "method": method,
        "seed": seed,
        "progress": progress,
    }
    return _compute_image_attributions_clip(**forwarded)
def on_click_mm_compute(
    image,
    caption,
    clip_model,
    seg_mode,
    grid_size,
    method,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Click handler: forward the form values to ``_compute_mm_attributions_clip``.

    Every argument is passed through unchanged, by keyword, and the callee's
    return value is returned as-is.
    """
    forwarded = {
        "image": image,
        "caption": caption,
        "clip_model": clip_model,
        "seg_mode": seg_mode,
        "grid_size": grid_size,
        "method": method,
        "seed": seed,
        "progress": progress,
    }
    return _compute_mm_attributions_clip(**forwarded)
# ---------------------------------------------------------------------------
# Demo helpers (used to quickly validate visualization components locally)
# ---------------------------------------------------------------------------
_DEMO_TEXT = "The quick brown fox jumps over the lazy dog in a sunny meadow."
_DEMO_FEATURES = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "in", "a", "sunny", "meadow"]
_DEMO_SPANS = [
(0, 3), (4, 9), (10, 15), (16, 19), (20, 25), (26, 30), (31, 34),
(35, 39), (40, 43), (44, 46), (47, 48), (49, 54), (55, 61)
]
_DEMO_ATTRIBUTIONS: Dict[str, Dict[str, float]] = {
"shapley": {
"The": -0.04,
"quick": 0.18,
"brown": 0.12,
"fox": 0.27,
"jumps": 0.15,
"over": 0.05,
"the": -0.02,
"lazy": -0.11,
"dog": -0.07,
"in": 0.03,
"a": 0.02,
"sunny": 0.09,
"meadow": 0.21,
}
}
_DEMO_ATTRIBUTIONS["banzhaf"] = {
token: round(value * 0.8, 3)
for token, value in _DEMO_ATTRIBUTIONS["shapley"].items()
}
_DEMO_ATTRIBUTIONS["influence"] = {
token: round(abs(value), 3)
for token, value in _DEMO_ATTRIBUTIONS["shapley"].items()
}
_DEMO_INTERACTIONS_2: List[Tuple[Tuple[str, ...], float]] = [
(("quick", "fox"), 0.24),
(("fox", "jumps"), 0.19),
(("sunny", "meadow"), 0.22),
(("lazy", "dog"), -0.17),
(("the", "lazy"), -0.12),
]
_DEMO_INTERACTIONS_3: List[Tuple[Tuple[str, ...], float]] = [
(("quick", "brown", "fox"), 0.28),
(("fox", "jumps", "over"), 0.18),
(("sunny", "meadow", "dog"), 0.11),
(("the", "lazy", "dog"), -0.21),
]
_DEMO_INTERACTION_MATRIX: List[Tuple[Tuple[int, int], float]] = [
((1, 3), 0.23),
((3, 4), 0.17),
((7, 8), -0.18),
((11, 12), 0.2),
((2, 5), 0.09),
]
_DEMO_DATASETS = {
"squad_demo": [
[
"The quick brown fox jumps over the lazy dog.",
"Who jumps over the dog?",
"The quick brown fox",
],
[
"AttrLLM explains attributions for large language models.",
"What does AttrLLM explain?",
"Attributions",
],
],
"truthfulqa_demo": [
[
"Water boils at 100 degrees Celsius at sea level.",
"At what temperature does water boil?",
"100 degrees Celsius",
]
],
}
def _render_demo(method: str = "shapley"):
    """Render the built-in demo sentence through the visualization components.

    Args:
        method: Attribution method name, lower-cased before use. A name with
            no entry in ``_DEMO_ATTRIBUTIONS`` uses the "shapley" scores, but
            the given name is still forwarded to the renderers and metadata.

    Returns:
        The result of ``update(...)`` populated with the top-interactions
        figure, the text heatmap HTML, the text interaction HTML and demo
        metadata; the scoring/debug fields are passed as empty placeholders.
    """
    method = (method or "shapley").lower()
    order = 2  # the demo always renders pairwise interactions
    attributions = _DEMO_ATTRIBUTIONS.get(method, _DEMO_ATTRIBUTIONS["shapley"])
    interactions = _DEMO_INTERACTIONS_3 if order == 3 else _DEMO_INTERACTIONS_2
    interactions_fig = plot_top_interactions(interactions, order=order, method=method)

    # Per-token scores in feature order, computed once and shared by both views.
    token_values = [attributions[token] for token in _DEMO_FEATURES]

    text_html = create_interactive_text_heatmap(
        _DEMO_TEXT,
        _DEMO_SPANS,
        token_values,
        method=method,
    )
    text_interaction_html = create_text_interaction_html(
        _DEMO_FEATURES,
        token_values,
        [
            {"indices": [i, j], "value": float(val)}
            for (i, j), val in _DEMO_INTERACTION_MATRIX
        ],
        method=method,
        top_k=20,
        threshold=0.0,
    )
    meta = {
        "method": method,
        "order": order,
        "feature_count": len(_DEMO_FEATURES),
        "scalarizer": "logprob",
    }
    return update(
        figs={
            "interactions": interactions_fig,
        },
        meta=meta,
        html=text_html,
        interaction_text_html=text_interaction_html,
        scoring_target_source="model_output",
        scoring_target_text="",
        reference_answer="",
        unmasked_answer="",
        debug_scores=None,
        scalarizer_used="logprob",
        score_full=None,
        score_empty=None,
        y_len_tokens=None,
    )
def _render_additional_plots(method: str = "shapley"):
    """Build the demo interaction-matrix plot.

    ``method`` is accepted but not used: the plot is always built from
    ``_DEMO_FEATURES`` and ``_DEMO_INTERACTION_MATRIX``.
    """
    matrix_plot = plot_interaction_matrix(_DEMO_FEATURES, _DEMO_INTERACTION_MATRIX)
    return matrix_plot
def _records_for_dataset(dataset_name: str) -> List[Dict[str, Any]]:
    """Return up to ten example records for ``dataset_name``.

    Sources are tried in order and the first non-empty one wins:
    ``get_examples`` (when available; any exception it raises is ignored),
    then ``_fallback_load_dataset``, then the built-in ``_DEMO_DATASETS`` rows
    converted to record dicts. The last source may yield an empty list.
    """
    if get_examples is not None:
        try:
            loaded = get_examples(dataset_name, n=10)
        except Exception:
            # Covers KeyError (unknown dataset) along with any other failure.
            loaded = None
        if loaded:
            return loaded

    from_csv = _fallback_load_dataset(dataset_name, max_rows=10)
    if from_csv:
        return from_csv

    return [
        {
            "id": f"{dataset_name}_demo_{number}",
            "context": context,
            "prompt": prompt,
            "correct_answer": answer,
        }
        for number, (context, prompt, answer) in enumerate(
            _DEMO_DATASETS.get(dataset_name, []), start=1
        )
    ]
def _available_datasets() -> List[str]:
    """List the dataset keys that can currently be served.

    Preference order: a non-empty ``list_datasets()`` result (exceptions are
    ignored), then the sorted keys of ``_FALLBACK_DATASET_FILES`` whose file
    exists under the fallback datasets directory, then the demo dataset names.
    """
    if list_datasets is not None:
        try:
            registered = list_datasets()
        except Exception:
            registered = None
        if registered:
            return registered

    on_disk = []
    for key, filename in _FALLBACK_DATASET_FILES.items():
        if (_fallback_datasets_dir() / filename).exists():
            on_disk.append(key)
    if on_disk:
        return sorted(on_disk)

    return list(_DEMO_DATASETS)
def _format_examples(records: List[Dict[str, Any]]) -> List[List[str]]:
formatted = []
for rec in records:
formatted.append([
rec.get("context", ""),
rec.get("prompt", ""),
rec.get("correct_answer")
or rec.get("answer")
or rec.get("target")
or "",
])
return formatted
def _load_examples_for_demo(dataset_name: str):
    """Build the ``gr.update`` that refreshes the demo example browser.

    ``dataset_name`` may be a display name; it is mapped to the internal key
    when ``get_dataset_key_from_display_name`` is available. If no formatted
    rows can be produced, the raw ``_DEMO_DATASETS`` entry is used instead.
    """
    to_key = get_dataset_key_from_display_name
    dataset_key = dataset_name if to_key is None else to_key(dataset_name)

    rows = _format_examples(_records_for_dataset(dataset_key))
    samples = rows or _DEMO_DATASETS.get(dataset_key, [])
    return gr.update(samples=samples or [])
def _resolve_example_fields(record: Dict[str, Any]) -> Tuple[str, str, str]:
context = record.get("context", "")
prompt = record.get("prompt", "")
answer = (
record.get("correct_answer")
or record.get("answer")
or record.get("target")
or ""
)
return context, prompt, answer
def _resolve_dataset_key(dataset_name: str) -> str:
    """Map a dataset key or display label to the internal dataset key.

    Resolution order: already a known key, then a label found in
    ``DATASET_DISPLAY_LABELS``, then ``get_dataset_key_from_display_name``
    (when available); otherwise the input is returned unchanged.
    """
    known_keys = _available_datasets()
    if dataset_name in known_keys:
        return dataset_name

    for key, label in DATASET_DISPLAY_LABELS.items():
        if dataset_name == label:
            return key

    if get_dataset_key_from_display_name is None:
        return dataset_name
    return get_dataset_key_from_display_name(dataset_name)
def _dataset_choice_labels(dataset_keys: List[str]) -> List[str]:
    """Return one display label per dataset key, in the same order."""

    def _label_for(key: str) -> str:
        # Prefer the loader's display name; on any failure (or when the
        # loader is absent) use the static label table, and finally a
        # title-cased version of the key itself.
        if get_dataset_display_name is not None:
            try:
                return get_dataset_display_name(key)
            except Exception:
                pass
        return DATASET_DISPLAY_LABELS.get(key, key.replace("_", " ").title())

    return [_label_for(key) for key in dataset_keys]
def _resolve_example_index(example_number: Any, records: List[Dict[str, Any]]) -> int:
if not records:
return 0
try:
index = int(example_number) - 1
except Exception:
index = 0
return max(0, min(index, len(records) - 1))
def _resolve_example_id(example_number: Any, records: List[Dict[str, Any]]) -> str:
    """Return the example id used to look up results.

    In public-only mode the id is always ``example_<n>`` built from the
    example number (defaulting to 1 when it is falsy). Otherwise the id comes
    from the selected record's ``id`` field, falling back to
    ``example_<index + 1>`` when the record has no usable id.
    """
    if _public_only_mode():
        return f"example_{int(example_number or 1)}"

    position = _resolve_example_index(example_number, records)
    chosen = records[position] if records else {}
    record_id = chosen.get("id")
    if record_id:
        return str(record_id)
    return f"example_{position + 1}"
def _build_model_answer_panel(dataset_name: str, example_number: Any) -> str:
"""Render Model's Answer + Justification HTML for the 30 wrong-answer
examples; return empty string for everything else so the gr.HTML slot
stays visually empty."""
try:
from visualization.wrong_answer_examples import WRONG_ANSWER_EXAMPLES
except Exception:
return ""
dataset_key = _resolve_dataset_key(dataset_name) if dataset_name else ""
try:
ex_id = f"example_{int(example_number or 1)}"
except Exception:
ex_id = "example_1"
if (dataset_key, ex_id) not in WRONG_ANSWER_EXAMPLES:
return ""
path = (
_get_results_dir() / "model_answers" / "small" / dataset_key / f"{ex_id}.json"
)
if not path.exists():
return ""
try:
with path.open("r", encoding="utf-8") as f:
data = json.load(f)
except Exception:
return ""
from html import escape as _escape
letter = (data.get("model_answer_parsed") or "").strip()
raw = (data.get("model_answer_raw") or "").strip()
gt_letter = (data.get("ground_truth_letter") or data.get("ground_truth") or "").strip()
is_match = bool(data.get("is_match"))
similarity = data.get("similarity")
try:
sim_str = f"{float(similarity):.3f}" if similarity is not None else "—"
except Exception:
sim_str = "—"
if is_match:
chip_bg, chip_fg, chip_text = "#e7f6ec", "#1f8d4a", "✓ MATCH"
else:
chip_bg, chip_fg, chip_text = "#fdecea", "#c0392b", "✗ MISMATCH"
# Split off the leading letter+rationale prefix for cleaner reading.
justification = raw
if raw.lower().startswith("justification:"):
justification = raw.split(":", 1)[1].strip()
elif "Justification:" in raw:
justification = raw.split("Justification:", 1)[1].strip()
return (
'<div style="display:flex;flex-direction:column;gap:8px;'
'background:#fdf7ff;border:1px solid #e2d6f3;border-radius:10px;'
'padding:12px 14px;margin-top:6px;">'
'<div style="display:flex;align-items:center;gap:10px;flex-wrap:wrap;">'
'<strong style="color:#4a1c87;">Model\'s Answer</strong>'
f'<span style="background:#fff;border:1px solid #d8c6f0;border-radius:6px;'
f'padding:2px 8px;font-weight:600;color:#4a1c87;">{_escape(letter or "—")}</span>'
f'<span style="color:#6f5a72;font-size:12px;">vs Ground Truth: '
f'<strong>{_escape(gt_letter or "—")}</strong></span>'
f'<span style="background:{chip_bg};color:{chip_fg};border-radius:999px;'
f'padding:2px 10px;font-size:12px;font-weight:600;">{chip_text}</span>'
f'<span style="color:#6f5a72;font-size:12px;">sim={sim_str}</span>'
'</div>'
'<div style="font-size:13px;color:#3a2b4a;line-height:1.55;'
'background:#fff;border:1px solid #ece4f8;border-radius:8px;padding:10px 12px;">'
f'<strong style="display:block;margin-bottom:4px;color:#4a1c87;">Justification</strong>'
f'{_escape(justification) if justification else "<em>No justification captured.</em>"}'
'</div>'
'</div>'
)
def _load_examples_for_slider(dataset_name: str):
    """Load a dataset's records and reset the example slider to the first one.

    Returns a 5-tuple: the slider ``gr.update`` (range 1..N, where N is the
    record count capped at 10, or 10 when there are no records), the record
    list, and the context/prompt/answer of the first record (empty strings
    when there are no records).
    """
    records = _records_for_dataset(_resolve_dataset_key(dataset_name))

    if records:
        context, prompt, answer = _resolve_example_fields(records[0])
    else:
        context, prompt, answer = "", "", ""

    upper_bound = max(1, min(10, len(records) or 10))
    return (
        gr.update(minimum=1, maximum=upper_bound, step=1, value=1),
        records,
        context,
        prompt,
        answer,
    )
def _update_example_preview(example_number: Any, records):
    """Return ``(context, prompt, answer)`` for the selected example.

    Yields three empty strings when there are no records to preview.
    """
    if records:
        selected = records[_resolve_example_index(example_number, records)]
        return _resolve_example_fields(selected)
    return "", "", ""
def _results_output_list(results: Dict[str, Any]) -> List[Any]:
return [
results["interactions"],
results["interactions_tokens_html"],
results["interactions_text_html"],
results["text_html"],
results["meta"],
results["scoring_target_source"],
results["scoring_target_text"],
results["reference_answer"],
results["unmasked_answer"],
results["debug_scores"],
results["scalarizer_used"],
results["score_full"],
results["score_empty"],
results["y_len_tokens"],
]
def build_demo_app() -> gr.Blocks:
    """Build the standalone demo app that previews the attribution widgets.

    The app shows the dataset/example controls and the results display, and
    renders them from the canned demo data via ``_render_demo`` and
    ``_render_additional_plots`` when the render button is clicked.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    datasets = _available_datasets()
    # Apply the same colorful CSS theme
    custom_css = """
.gradio-container {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue", Arial, sans-serif !important;
background: linear-gradient(135deg, #fef5f0 0%, #f0e8ff 50%, #e8f5ff 100%) !important;
padding: 24px !important;
}
.gradio-container h1, .gradio-container h2 {
background: linear-gradient(135deg, #ff6b6b 0%, #ee5a6f 30%, #c44569 60%, #6c5ce7 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
font-weight: 900;
font-size: 42px !important;
margin: 20px 0 16px 0;
letter-spacing: -0.03em;
}
label, .gr-label {
font-weight: 700 !important;
font-size: 16px !important;
color: #2d1f4a !important;
}
.gr-button {
border-radius: 16px !important;
font-weight: 700 !important;
font-size: 17px !important;
padding: 16px 32px !important;
background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important;
color: white !important;
border: none !important;
}
.gr-box, .gr-input, .gr-dropdown, .gr-textbox {
border-radius: 14px !important;
border: 3px solid #e8dff5 !important;
font-size: 17px !important;
}
.gr-markdown p {
font-size: 17px !important;
font-weight: 500 !important;
}
"""
    _demo_kwargs = {"title": "AttrLLM Visualization Demo"}
    # Only pass ``css`` when this Gradio version's Blocks accepts it.
    if _supports_kwarg(gr.Blocks, "css"):
        _demo_kwargs["css"] = custom_css
    with gr.Blocks(**_demo_kwargs) as demo:
        gr.Markdown(
            "# 🎨 AttrLLM Visualization Demo\n\n"
            "**Preview the attribution widgets** before wiring real backends. "
            "Use the controls below to explore the interface."
        )
        # NOTE(review): layout nesting below was reconstructed from a
        # whitespace-stripped source; confirm that the feature-level selector
        # and method toggle belong in the right-hand column.
        with gr.Row():
            with gr.Column(scale=1):
                # Prepare initial choices and value before creating component
                initial_choices = _dataset_choice_labels(datasets)
                initial_value = initial_choices[0] if initial_choices else None
                dataset_selector = gr.Dropdown(
                    choices=initial_choices,
                    value=initial_value,
                    label="Dataset",
                    interactive=True,
                    allow_custom_value=False,
                    elem_id="dataset-selector-demo",
                    elem_classes=["bubble-select"],
                )
                example_browser = create_example_browser()
            with gr.Column(scale=1):
                model_selector = create_model_selector()
                scalarizer_selector = gr.Dropdown(
                    choices=SCALARIZER_CHOICES,
                    value="logprob",
                    label="Scalarizer",
                    interactive=True,
                )
                embedding_model_box = gr.Textbox(
                    label="Embedding Model (for scalarizer=embedding)",
                    value="Qwen/Qwen3-Embedding-0.6B",
                    lines=1,
                )
                feature_level_selector = create_feature_level_selector()
                method_toggle = create_attribution_method_toggle()
        # Refresh the example browser when the dataset changes and once on
        # initial page load.
        dataset_selector.change(
            fn=_load_examples_for_demo,
            inputs=dataset_selector,
            outputs=example_browser,
        )
        demo.load(
            fn=_load_examples_for_demo,
            inputs=[dataset_selector],
            outputs=[example_browser],
        )
        render_button = gr.Button("Render Demo Visuals", variant="primary")
        outputs = create_results_display()
        extra_matrix = gr.Plot(label="Interaction Matrix (demo)")
        # Two handlers on the same click: one fills the shared results
        # display, the other fills the demo-only interaction matrix plot.
        render_button.click(
            fn=_render_demo,
            inputs=[method_toggle],
            outputs=_results_output_list(outputs),
        )
        render_button.click(
            fn=_render_additional_plots,
            inputs=[method_toggle],
            outputs=[extra_matrix],
        )
    return demo
def _patch_gradio_schema_generation() -> None:
"""Prevent Gradio 5.x /info crash caused by additionalProperties: true in schemas."""
try:
from gradio_client import utils as client_utils
except Exception:
return
if getattr(client_utils, "_attrllm_schema_patch", False):
return
original_inner = getattr(client_utils, "_json_schema_to_python_type", None)
original_outer = getattr(client_utils, "json_schema_to_python_type", None)
if not callable(original_inner) or not callable(original_outer):
return
def _normalize_schema(schema):
if isinstance(schema, bool):
return {} if schema else {"type": "null"}
if isinstance(schema, list):
return [_normalize_schema(item) for item in schema]
if not isinstance(schema, dict):
return schema
normalized = dict(schema)
if isinstance(normalized.get("additionalProperties"), bool):
normalized["additionalProperties"] = _normalize_schema(normalized["additionalProperties"])
for key in ("properties", "$defs", "definitions", "patternProperties"):
value = normalized.get(key)
if isinstance(value, dict):
normalized[key] = {k: _normalize_schema(v) for k, v in value.items()}
for key in ("items", "contains", "not", "if", "then", "else"):
if key in normalized:
normalized[key] = _normalize_schema(normalized[key])
for key in ("anyOf", "allOf", "oneOf", "prefixItems"):
value = normalized.get(key)
if isinstance(value, list):
normalized[key] = [_normalize_schema(item) for item in value]
return normalized
client_utils._json_schema_to_python_type = lambda s, d=None: original_inner(_normalize_schema(s), d)
client_utils.json_schema_to_python_type = lambda s: original_outer(_normalize_schema(s))
client_utils._attrllm_schema_patch = True
# Apply the Gradio schema patch as soon as this module is imported.
_patch_gradio_schema_generation()
def build_app() -> gr.Blocks:
datasets = _available_datasets()
default_dataset = datasets[0] if datasets else ""
public_only = _public_only_mode()
mm_only = _mm_only_mode()
# Custom CSS for prettier UI - Inspired by modern, colorful design
custom_css = """
/* Main container styling - Warm gradient background */
.gradio-container {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue", Arial, sans-serif !important;
background: linear-gradient(135deg, #fef5f0 0%, #f0e8ff 50%, #e8f5ff 100%) !important;
padding: 24px !important;
}
/* Header styling - Large, bold, colorful */
.gradio-container h1 {
background: linear-gradient(135deg, #ff6b6b 0%, #ee5a6f 30%, #c44569 60%, #6c5ce7 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
font-weight: 900;
font-size: 48px !important;
margin: 20px 0 16px 0;
letter-spacing: -0.03em;
text-align: left;
}
.gradio-container h2 {
background: linear-gradient(135deg, #ff6b6b 0%, #ee5a6f 30%, #c44569 60%, #6c5ce7 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
font-weight: 900;
font-size: 42px !important;
margin: 20px 0 16px 0;
letter-spacing: -0.03em;
}
.gradio-container h3 {
color: #2d1f4a;
font-weight: 800;
font-size: 24px !important;
margin: 24px 0 16px 0;
}
/* Tab styling - Bold and colorful */
.tab-nav {
border: none !important;
background: transparent !important;
gap: 8px !important;
padding: 8px 0 !important;
}
.tab-nav button {
font-size: 18px !important;
font-weight: 700 !important;
padding: 16px 32px !important;
border-radius: 16px !important;
transition: all 0.3s ease !important;
border: 3px solid #e0d0f0 !important;
background: white !important;
color: #6c5ce7 !important;
margin-right: 8px !important;
}
.tab-nav button:hover {
background: #f8f4ff !important;
border-color: #b8a8db !important;
transform: translateY(-2px) !important;
}
.tab-nav button.selected {
background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important;
color: white !important;
border: 3px solid #6c5ce7 !important;
box-shadow: 0 6px 20px rgba(108, 92, 231, 0.3) !important;
}
/* Button styling - Vibrant and interactive */
.gr-button {
border-radius: 16px !important;
font-weight: 700 !important;
font-size: 17px !important;
padding: 16px 32px !important;
transition: all 0.3s cubic-bezier(0.34, 1.56, 0.64, 1) !important;
box-shadow: 0 6px 20px rgba(108, 92, 231, 0.2) !important;
border: none !important;
}
.gr-button-primary {
background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important;
color: white !important;
}
.gr-button-secondary {
background: linear-gradient(135deg, #fd79a8 0%, #ff7675 100%) !important;
color: white !important;
}
.gr-button:hover {
transform: translateY(-3px) scale(1.02) !important;
box-shadow: 0 10px 30px rgba(108, 92, 231, 0.35) !important;
}
.gr-button-primary:hover {
background: linear-gradient(135deg, #5e4ec7 0%, #9089e8 100%) !important;
}
/* Input/Dropdown styling - Clear and modern */
.gr-box, .gr-input, .gr-dropdown {
border-radius: 14px !important;
border: 3px solid #e8dff5 !important;
background: white !important;
font-size: 17px !important;
padding: 12px 16px !important;
transition: all 0.3s ease !important;
font-weight: 500 !important;
}
.gr-box:focus, .gr-input:focus, .gr-dropdown:focus {
border-color: #6c5ce7 !important;
box-shadow: 0 0 0 4px rgba(108, 92, 231, 0.15) !important;
transform: translateY(-1px) !important;
}
/* Textbox styling - Larger text */
.gr-textbox {
border-radius: 16px !important;
border: 3px solid #e8dff5 !important;
font-size: 17px !important;
line-height: 1.6 !important;
}
.gr-textbox textarea {
font-size: 17px !important;
line-height: 1.6 !important;
padding: 14px !important;
}
.gr-textbox:focus-within {
border-color: #6c5ce7 !important;
box-shadow: 0 6px 24px rgba(108, 92, 231, 0.2) !important;
}
/* Radio button styling - Colorful pills */
.gr-radio {
gap: 12px !important;
}
.gr-radio label {
font-size: 17px !important;
font-weight: 600 !important;
padding: 14px 28px !important;
border-radius: 14px !important;
border: 3px solid #e8dff5 !important;
transition: all 0.3s ease !important;
background: white !important;
cursor: pointer !important;
}
.gr-radio label:hover {
border-color: #b8a8db !important;
background: #faf8ff !important;
transform: translateY(-2px) !important;
box-shadow: 0 4px 12px rgba(108, 92, 231, 0.15) !important;
}
.gr-radio input:checked + label {
background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important;
color: white !important;
border-color: #6c5ce7 !important;
font-weight: 800 !important;
box-shadow: 0 6px 20px rgba(108, 92, 231, 0.3) !important;
}
/* Panel/Accordion styling - Clean cards */
.gr-panel {
border-radius: 20px !important;
border: 3px solid #e8dff5 !important;
padding: 24px !important;
background: white !important;
box-shadow: 0 6px 24px rgba(108, 92, 231, 0.1) !important;
margin: 16px 0 !important;
}
.gr-accordion {
border-radius: 18px !important;
border: 3px solid #e8dff5 !important;
background: white !important;
}
/* Label styling - Bold and readable */
label, .gr-label {
font-weight: 700 !important;
font-size: 16px !important;
color: #2d1f4a !important;
margin-bottom: 10px !important;
letter-spacing: -0.01em !important;
}
/* Dropdown options */
.gr-dropdown-menu {
border-radius: 14px !important;
border: 3px solid #e8dff5 !important;
box-shadow: 0 8px 32px rgba(108, 92, 231, 0.15) !important;
font-size: 17px !important;
}
.gr-dropdown-menu .item {
font-size: 17px !important;
padding: 12px 16px !important;
font-weight: 500 !important;
}
.gr-dropdown-menu .item:hover {
background: linear-gradient(135deg, #f3f0ff 0%, #e8f5ff 100%) !important;
}
/* Plot container - Prominent */
.gr-plot {
border-radius: 20px !important;
border: 3px solid #e8dff5 !important;
overflow: hidden !important;
box-shadow: 0 8px 30px rgba(108, 92, 231, 0.12) !important;
background: white !important;
width: 100% !important;
}
/* Force the inner Plotly canvas + svg to fill its container so the Bar
View doesn't render in a half-width column when the Text Interaction
view above it is wide. */
.gr-plot .js-plotly-plot,
.gr-plot .plot-container,
.gr-plot .svg-container,
.gr-plot .main-svg {
width: 100% !important;
max-width: 100% !important;
}
.interaction-stack > .gradio-plot,
.interaction-stack > .block.gradio-plot,
.interaction-stack .gr-plot {
width: 100% !important;
max-width: 100% !important;
flex: 1 1 100% !important;
}
/* JSON viewer */
.gr-json {
border-radius: 16px !important;
border: 3px solid #e8dff5 !important;
background: #faf8ff !important;
padding: 20px !important;
font-family: 'Monaco', 'Menlo', 'Consolas', monospace !important;
font-size: 15px !important;
}
/* Column styling */
.gr-column {
padding: 20px !important;
}
/* Row styling */
.gr-row {
gap: 24px !important;
margin: 12px 0 !important;
}
/* Markdown content - Larger, more readable */
.gr-markdown {
line-height: 1.8 !important;
color: #2d1f4a !important;
}
.gr-markdown p {
font-size: 17px !important;
margin: 12px 0 !important;
font-weight: 500 !important;
}
.gr-markdown strong {
font-weight: 800 !important;
color: #6c5ce7 !important;
}
/* Status/info messages - Colorful notifications */
.gr-info {
border-radius: 16px !important;
border-left: 5px solid #6c5ce7 !important;
background: linear-gradient(135deg, #f8f6ff 0%, #f0f4ff 100%) !important;
padding: 18px 24px !important;
font-size: 16px !important;
font-weight: 600 !important;
color: #2d1f4a !important;
box-shadow: 0 4px 16px rgba(108, 92, 231, 0.1) !important;
}
/* Error messages */
.gr-error {
border-radius: 16px !important;
border-left: 5px solid #ff6b6b !important;
background: linear-gradient(135deg, #fff5f5 0%, #ffe8e8 100%) !important;
padding: 18px 24px !important;
font-size: 16px !important;
font-weight: 600 !important;
color: #c44569 !important;
}
/* Loading spinner */
.loading {
border: 4px solid #f3f0ff !important;
border-top: 4px solid #6c5ce7 !important;
}
/* Scrollbar styling */
::-webkit-scrollbar {
width: 12px !important;
height: 12px !important;
}
::-webkit-scrollbar-track {
background: #f8f6ff !important;
border-radius: 10px !important;
}
::-webkit-scrollbar-thumb {
background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important;
border-radius: 10px !important;
}
::-webkit-scrollbar-thumb:hover {
background: linear-gradient(135deg, #5e4ec7 0%, #9089e8 100%) !important;
}
.results-shell {
margin-top: 16px !important;
background: transparent !important;
border: none !important;
border-radius: 0 !important;
padding: 0 !important;
box-shadow: none !important;
}
.results-shell,
.results-shell > div,
.results-shell .gr-group,
.results-shell .gr-box,
.results-shell .gr-panel,
.results-shell .block {
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
.interaction-stack {
gap: 20px !important;
padding: 0 8px 6px !important;
}
.interaction-stack h3 {
margin-left: 24px !important;
margin-bottom: 12px !important;
}
.public-controls {
align-items: stretch !important;
gap: 20px !important;
margin-top: 8px !important;
}
.control-card {
background: linear-gradient(180deg, rgba(255, 255, 255, 0.82) 0%, rgba(250, 246, 255, 0.96) 100%) !important;
border: 2px solid rgba(224, 208, 240, 0.78) !important;
border-radius: 26px !important;
padding: 18px 20px 14px !important;
box-shadow: 0 14px 30px rgba(108, 92, 231, 0.07) !important;
}
.control-card-primary {
background: linear-gradient(180deg, rgba(255, 255, 255, 0.86) 0%, rgba(244, 248, 255, 0.96) 100%) !important;
}
.control-card-secondary {
background: linear-gradient(180deg, rgba(255, 255, 255, 0.86) 0%, rgba(250, 244, 255, 0.96) 100%) !important;
}
.control-card .gradio-container,
.control-card .gr-group {
background: transparent !important;
}
.control-card > div,
.control-card .block,
.control-card .wrap,
.control-card .gr-form,
.control-card .form {
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
.control-card .gr-box,
.control-card .gr-panel {
background: transparent !important;
box-shadow: none !important;
}
.bubble-select {
border: 3px solid #8f5cff !important;
border-radius: 18px !important;
box-shadow: 0 8px 20px rgba(143, 92, 255, 0.10) !important;
transition: box-shadow 0.2s ease, border-color 0.2s ease !important;
}
.bubble-select:focus-within {
border-color: #7a3dff !important;
box-shadow: 0 0 0 4px rgba(143, 92, 255, 0.14), 0 10px 24px rgba(143, 92, 255, 0.16) !important;
}
.example-id-slider {
margin-top: 8px !important;
padding: 10px 2px 2px !important;
}
.example-id-slider input[type="range"] {
accent-color: #4f7cff !important;
}
.example-id-slider .number-input,
.example-id-slider input[type="number"] {
border-radius: 16px !important;
border: 2px solid #d8dcee !important;
background: linear-gradient(180deg, #ffffff 0%, #f7f9ff 100%) !important;
font-weight: 700 !important;
min-width: 72px !important;
}
.example-id-slider .wrap {
gap: 14px !important;
}
@media (prefers-color-scheme: dark) {
.gradio-container {
background: radial-gradient(circle at top, #1e2a44 0%, #0d1422 52%, #090f19 100%) !important;
}
.gradio-container h3,
label, .gr-label,
.gr-markdown,
.gr-markdown p {
color: #e8eefc !important;
}
.gr-markdown strong {
color: #cbd7ff !important;
}
.tab-nav button,
.gr-box, .gr-input, .gr-dropdown, .gr-textbox,
.gr-panel, .gr-accordion,
.gr-plot, .gr-json {
background: rgba(16, 24, 39, 0.88) !important;
border-color: rgba(148, 163, 184, 0.24) !important;
color: #e8eefc !important;
}
.tab-nav button {
color: #d7e1ff !important;
}
.tab-nav button:hover {
background: rgba(37, 52, 79, 0.96) !important;
border-color: rgba(199, 210, 254, 0.36) !important;
}
.gr-radio label {
background: rgba(16, 24, 39, 0.9) !important;
border-color: rgba(148, 163, 184, 0.26) !important;
color: #e8eefc !important;
}
.gr-radio label:hover {
background: rgba(37, 52, 79, 0.96) !important;
}
.gr-textbox textarea,
.gr-input input {
background: transparent !important;
color: #e8eefc !important;
}
.gr-dropdown-menu {
background: #101827 !important;
border-color: rgba(148, 163, 184, 0.24) !important;
}
.gr-dropdown-menu .item {
color: #e8eefc !important;
}
.gr-dropdown-menu .item:hover {
background: rgba(37, 52, 79, 0.96) !important;
}
.gr-plot .main-svg,
.gr-plot .svg-container,
.gr-plot .plot-container,
.gr-plot .user-select-none {
background: transparent !important;
}
.gr-plot .xtick text,
.gr-plot .ytick text,
.gr-plot .gtitle text,
.gr-plot .xtitle text,
.gr-plot .ytitle text,
.gr-plot .annotation-text,
.gr-plot .legend text {
fill: #e8eefc !important;
color: #e8eefc !important;
}
.gr-plot .gridlayer path,
.gr-plot .zerolinelayer path,
.gr-plot .xlines-above path,
.gr-plot .ylines-above path {
stroke: rgba(148, 163, 184, 0.22) !important;
}
.gr-info {
background: linear-gradient(135deg, rgba(30, 41, 59, 0.95) 0%, rgba(17, 24, 39, 0.95) 100%) !important;
color: #dbe7ff !important;
border-left-color: #9db4ff !important;
}
.control-card {
background: linear-gradient(180deg, rgba(16, 24, 39, 0.9) 0%, rgba(18, 28, 45, 0.96) 100%) !important;
border-color: rgba(148, 163, 184, 0.2) !important;
box-shadow: 0 18px 36px rgba(0, 0, 0, 0.24) !important;
}
.results-shell,
.results-shell > div,
.results-shell .gr-group,
.results-shell .gr-box,
.results-shell .gr-panel,
.results-shell .block {
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
.bubble-select {
border-color: #a06cff !important;
box-shadow: 0 10px 24px rgba(143, 92, 255, 0.18) !important;
}
.bubble-select:focus-within {
border-color: #c29cff !important;
box-shadow: 0 0 0 4px rgba(143, 92, 255, 0.16), 0 12px 26px rgba(143, 92, 255, 0.22) !important;
}
.example-id-slider .number-input,
.example-id-slider input[type="number"] {
background: linear-gradient(180deg, #162031 0%, #111827 100%) !important;
border-color: rgba(148, 163, 184, 0.24) !important;
color: #e8eefc !important;
}
.gr-error {
background: linear-gradient(135deg, rgba(68, 18, 32, 0.95) 0%, rgba(39, 12, 20, 0.95) 100%) !important;
color: #ffd5dc !important;
}
::-webkit-scrollbar-track {
background: #111827 !important;
}
}
"""
_app_kwargs = {"title": "LLM Reasoning Explorer Studio"}
if _supports_kwarg(gr.Blocks, "css"):
_app_kwargs["css"] = custom_css
with gr.Blocks(**_app_kwargs) as app:
gr.Markdown(
"# LLM Reasoning Explorer Studio\n\n"
"**Explore attribution results and feature interactions** with our interactive visualization tools. "
"Browse pre-computed examples or analyze your own text in real-time with powerful AI insights."
)
gr.Markdown(f"**Build:** {BUILD_ID} ({BUILD_TS})")
example_state = gr.State([])
with (gr.Column(visible=not mm_only) if (public_only or mm_only) else gr.Tab("Public Mode")):
with gr.Accordion("How to Use", open=False):
gr.Markdown(
"1. **Select a dataset** from 10 available datasets (100 total examples, 10 per dataset)\n"
"2. **Choose a model** to compare: Qwen3-4B, Qwen3-30B, or Mistral-7B\n"
"3. **Pick a scoring method:** Perplexity or Semantic Similarity\n"
"4. **Set the feature level:** Word, Sentence, or Paragraph\n"
"5. **Choose an attribution method:** Shapley, Banzhaf, or Influence\n"
"6. **View results** in the Text Interaction View (inline highlights) and Bar View (ranked interactions)"
)
with gr.Row(elem_classes=["public-controls"]):
with gr.Column(scale=1, elem_classes=["control-card", "control-card-primary"]):
# Prepare initial choices and value before creating component
initial_choices = _dataset_choice_labels(datasets)
# In mm_only mode the text attribution tab is hidden — no default value
# prevents the .change() callback from firing on page load.
_preferred_default = "BBQ Disambiguation"
if mm_only:
initial_value = None
elif _preferred_default in initial_choices:
initial_value = _preferred_default
else:
initial_value = initial_choices[0] if initial_choices else None
dataset_selector = gr.Dropdown(
choices=initial_choices,
value=initial_value,
label="Dataset",
interactive=True,
allow_custom_value=False,
elem_id="dataset-selector",
elem_classes=["bubble-select"],
)
example_selector = gr.Slider(
label="Example ID",
minimum=1,
maximum=10,
step=1,
value=1,
interactive=True,
elem_classes=["example-id-slider"],
)
with gr.Column(scale=1, elem_classes=["control-card", "control-card-secondary"]):
model_selector = create_model_selector()
scalarizer_selector = gr.Dropdown(
choices=PUBLIC_SCALARIZER_CHOICES,
value="geomean_jointprob",
label="Scalarizer",
interactive=True,
elem_classes=["bubble-select"],
)
public_feature_level_selector = create_feature_level_selector(value="word")
method_toggle = create_attribution_method_toggle()
with gr.Accordion("Example Preview", open=True):
with gr.Row():
with gr.Column(scale=3):
context_box = gr.Textbox(
label="Context",
lines=8,
interactive=False,
)
with gr.Column(scale=2):
prompt_box = gr.Textbox(
label="Prompt",
lines=4,
interactive=False,
)
answer_box = gr.Textbox(
label="Ground Truth Answer",
lines=3,
interactive=False,
)
# Empty for examples outside the 30-pair allow-list; renders
# the model's parsed letter + justification for the others.
try:
model_answer_html = gr.HTML(value="", sanitize_html=False)
except TypeError:
model_answer_html = gr.HTML(value="")
public_results = create_results_display()
def _public_mode_compute(
dataset,
example_number,
records,
model_size,
scalarizer,
feature_level,
method,
progress=gr.Progress(track_tqdm=True),
):
# Event handler for the public text-attribution controls: resolve the
# selected example, prefer precomputed results (relaxing the requested
# model size / scalarizer / level if needed), and only fall back to a live
# computation when that is allowed and a backend example loader exists.
#
# Raises gr.Error when no dataset or example is selected, or when nothing
# precomputed is found and live compute is unavailable.
#
# In mm_only mode this handler is a no-op that returns 14 placeholders.
# NOTE(review): 14 presumably equals len(public_compute_outputs) — confirm
# whenever the results display gains or loses a component.
if mm_only:
return tuple([None] * 14)
if not dataset:
raise gr.Error("Please select a dataset.")
if not example_number:
raise gr.Error("Please select an example.")
# Translate UI-facing values into the canonical keys used for lookups.
dataset_key = _resolve_dataset_key(dataset)
ex_id = _resolve_example_id(example_number, records)
method = _normalize_method(method)
level = _normalize_level(feature_level)
model_size = _normalize_model_size(model_size)
# Prefer precomputed results: use loader if available, else load from file (Space-friendly).
get_res = get_result_by_id if get_result_by_id is not None else _public_get_result_from_file
result = get_res(
model_size,
dataset_key,
ex_id,
scalarizer=scalarizer,
feature_level=level,
) or {}
payload = result.get(method, {})
# First fallback: same scalarizer/level, but a different model size that
# has results for this example.
if not payload:
alt_size = _find_available_model_size(dataset_key, ex_id, scalarizer, level)
if alt_size and alt_size != model_size:
result = get_res(
alt_size,
dataset_key,
ex_id,
scalarizer=scalarizer,
feature_level=level,
) or {}
payload = result.get(method, {})
if payload:
model_size = alt_size
# If still no payload, try any available (model_size, scalarizer, level) for this example
if not payload:
alt_size, alt_scalarizer, alt_level, result = _find_any_available_result(
dataset_key, ex_id, get_res, method
)
if alt_size and alt_scalarizer and alt_level and result:
payload = result.get(method, {})
model_size, scalarizer, level = alt_size, alt_scalarizer, alt_level
# A payload only counts as usable if it carries features or a heatmap.
if payload and (payload.get("features") or payload.get("heatmap")):
# The first three values returned by on_select_example are discarded;
# presumably they are the preview fields already shown — confirm.
_, _, _, *outputs = on_select_example(
dataset_key,
ex_id,
model_size,
2,
method,
scalarizer=scalarizer,
feature_level=level,
)
return outputs
# Public-only mode: do not attempt live compute
if _public_only_mode() or get_example_by_id is None:
expected_ref = _reference_results_file(model_size, dataset_key, ex_id, scalarizer, level)
raise gr.Error(
"No precomputed results found.\n\n"
f"Expected (reference_answer):\n{expected_ref}\n\n"
"On Hugging Face Space: make sure the 'results' folder is in your repo "
"(commit & push it). If you use Git LFS, enable 'LFS' in Space Settings → "
"Repository and ensure files are pulled. You can also try another "
"scalarizer (e.g. Perplexity) or feature level (e.g. word)."
)
# Fallback to live compute when no usable precomputed payload was found.
get_ex = _ensure_backend("loader.data.get_example_by_id", get_example_by_id)
record = get_ex(dataset_key, ex_id)
context = record.get("context", "")
prompt = record.get("prompt", "")
answer = _extract_answer(record)
return _compute_live_attributions(
context=context,
prompt=prompt,
correct_answer=answer,
model_size=model_size,
scalarizer=scalarizer,
embedding_model=None,
level=level,
method=method,
order=2,
progress=progress,
)
public_preview_outputs = [context_box, prompt_box, answer_box]
public_compute_inputs = [
dataset_selector,
example_selector,
example_state,
model_selector,
scalarizer_selector,
public_feature_level_selector,
method_toggle,
]
public_compute_outputs = _results_output_list(public_results)
dataset_change_event = dataset_selector.change(
fn=_load_examples_for_slider,
inputs=[dataset_selector],
outputs=[
example_selector,
example_state,
context_box,
prompt_box,
answer_box,
],
queue=False,
).then(
fn=_build_model_answer_panel,
inputs=[dataset_selector, example_selector],
outputs=[model_answer_html],
queue=False,
)
load_event = app.load(
fn=_load_examples_for_slider,
inputs=[dataset_selector],
outputs=[
example_selector,
example_state,
context_box,
prompt_box,
answer_box,
],
).then(
fn=_build_model_answer_panel,
inputs=[dataset_selector, example_selector],
outputs=[model_answer_html],
queue=False,
)
dataset_change_event.then(
fn=_public_mode_compute,
inputs=public_compute_inputs,
outputs=public_compute_outputs,
show_progress="full",
)
load_event.then(
fn=_public_mode_compute,
inputs=public_compute_inputs,
outputs=public_compute_outputs,
show_progress="full",
)
example_selector.release(
fn=_update_example_preview,
inputs=[example_selector, example_state],
outputs=public_preview_outputs,
queue=False,
).then(
fn=_build_model_answer_panel,
inputs=[dataset_selector, example_selector],
outputs=[model_answer_html],
queue=False,
).then(
fn=_public_mode_compute,
inputs=public_compute_inputs,
outputs=public_compute_outputs,
show_progress="full",
)
for component in (
model_selector,
scalarizer_selector,
public_feature_level_selector,
method_toggle,
):
component.change(
fn=_public_mode_compute,
inputs=public_compute_inputs,
outputs=public_compute_outputs,
show_progress="full",
)
# ── MULTIMODAL TAB ──────────────────────────────────────────
with gr.Tab("Multimodal"):
with gr.Accordion("How to Use", open=False):
gr.Markdown(
"1. **Choose a dataset** from the three sub-tabs:\n"
" - **MIMIC-CXR (10 Samples)** — chest X-rays across 10 pathology categories\n"
" - **Dermoscopy ISIC (10 Samples)** — skin-lesion dermoscopy across 8 diagnostic classes\n"
" - **MS-COCO (5 Samples)** — natural-image cross-modal benchmark\n"
"2. **Pick an example** from the dropdown (each is an image + caption pair)\n"
"3. **Choose an attribution method:** Influence (default, non-negative — clearer for clinicians) or Shapley (signed)\n"
"4. **Read the four panels side-by-side:**\n"
" - **Interactive Cross-Modal View** — click any image patch or caption word to see its strongest cross-modal partners\n"
" - **BiomedCLIP Cross-Modal Attribution** — patch-level overlay + bar charts (cosine-similarity scoring)\n"
" - **LLaVA-Med Attribution** — log-prob and generation Shapley charts from the medical 7B VLM\n"
" - **Compare Two Methods Side-by-Side** — pick any two of the above to overlay their rankings\n"
"5. **Hover token chips and patches** for exact attribution values; hover SVG arcs for pairwise interaction strength"
)
with gr.Tab("MIMIC-CXR (10 Samples)"):
gr.Markdown(
"**10-sample MIMIC-CXR chest X-ray attribution benchmark** "
"(10 pathology categories). \n"
"Source: [MIMIC-CXR-JPG](https://huggingface.co/datasets/itsanmolgupta/mimic-cxr-dataset-cleaned) "
"— de-identified chest radiographs from Beth Israel Deaconess Medical Center. \n"
"Each example has a radiology report (impression = caption, findings = detail)."
)
# Build (category_name, example_id) choices so picking a
# pathology directly loads its example (1:1 mapping).
_mimic_choices = (
[(v["category"], k) for k, v in MIMIC_EXAMPLES.items()]
if _MIMIC_AVAILABLE else []
)
mimic_selector = gr.Dropdown(
choices=_mimic_choices,
value=None,
label="Filter by Pathology",
interactive=True,
)
mimic_method_toggle = gr.Radio(
choices=["Influence", "Shapley"],
value="Influence",
label="Attribution method",
info=(
"Influence (default) is always positive — clearer for clinicians. "
"Shapley is signed (green = supports caption, red = detracts)."
),
interactive=True,
)
mimic_caption = gr.Textbox(
label="Radiology Impression (Caption)",
interactive=False,
lines=2,
)
with gr.Accordion("Full Radiology Findings", open=False):
mimic_findings = gr.Textbox(
label="Detailed Findings",
interactive=False,
lines=5,
)
# ── Original Image ────────────────────────────────
mimic_original = gr.Image(label="Chest X-ray", type="filepath")
mimic_interpretation = gr.Markdown(
value="*Select an example above to see the attribution analysis.*",
label="Interpretation",
)
# ── Table of Contents ─────────────────────────────
_mimic_pill = (
'style="display:inline-block;padding:6px 14px;background:#e3f2fd;'
'border-radius:16px;text-decoration:none;color:#1565c0;font-size:0.9em;'
'border:1px solid #bbdefb;"'
)
_mimic_toc_html = (
'<div style="background:#f8f9fa;border:1px solid #dee2e6;border-radius:8px;'
'padding:16px;margin:12px 0;">'
'<strong style="font-size:1.05em;">Jump to Section:</strong>'
'<div style="display:flex;flex-wrap:wrap;gap:8px;margin-top:10px;">'
f'<a href="#mimic-method-biomedclip" {_mimic_pill}>BiomedCLIP Cross-Modal</a>'
f'<a href="#mimic-method-llavamed" {_mimic_pill}>LLaVA-Med (UnSAM)</a>'
f'<a href="#mimic-interactive" {_mimic_pill}>Interactive View</a>'
f'<a href="#mimic-compare" {_mimic_pill}>Compare Methods</a>'
'</div></div>'
)
gr.HTML(value=_mimic_toc_html)
mimic_results_state = gr.State({})
# ════════════════════════════════════════════════════
# ── Interactive Cross-Modal View ───────────────────
# ════════════════════════════════════════════════════
with gr.Column(elem_id="mimic-interactive"):
gr.Markdown("---\n### BiomedCLIP Cross-Modal Interaction View — click segments or words")
gr.Markdown(
"**How to use:** Click any **image region** to see which caption words "
"it connects to, or click a **word** to see which regions activate. \n"
"**Green** arrows = positive interaction. **Red** arrows = negative."
)
mimic_biomedclip_interaction_html = _html_component(
"BiomedCLIP Cross-Modal Interaction View")
# ════════════════════════════════════════════════════
# ── BiomedCLIP Cross-Modal Attribution ─────────────
# ════════════════════════════════════════════════════
with gr.Column(elem_id="mimic-method-biomedclip"):
gr.Markdown("---\n### BiomedCLIP Cross-Modal Attribution")
gr.Markdown(
"**What it does:** Uses [BiomedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) "
"— a CLIP model trained on **15 million biomedical figure-caption pairs** — "
"to jointly score image regions (via UnSAM segmentation) and caption tokens. \n"
"**How to read:** **Green** = positive Shapley value (contributes to alignment). "
"**Red** = negative (hurts alignment)."
)
mimic_biomedclip_overlay = gr.Image(
label="BiomedCLIP Overlay (labeled segments)", type="filepath")
mimic_biomedclip_token_plot = gr.Plot(
label="BiomedCLIP — Caption Word Shapley Values")
mimic_biomedclip_region_plot = gr.Plot(
label="BiomedCLIP — Image Region Shapley Values")
# ════════════════════════════════════════════════════
# ── LLaVA-Med Attribution (UnSAM Segments) ─────────
# ════════════════════════════════════════════════════
with gr.Column(elem_id="mimic-method-llavamed"):
gr.Markdown("---\n### LLaVA-Med Attribution (4×4 Patch Grid, P1–P16)")
gr.Markdown(
"**What it does:** Uses [LLaVA-Med](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) "
"— a **7B parameter** medical VLM — evaluated over a **uniform 4×4 patch grid** "
"(16 cells labeled **P1–P16**, row-major). \n"
"**Two scoring approaches:** \n"
"- **Log-Prob:** How removing a region affects confidence in the correct caption \n"
"- **Generation:** How removing a region changes what the model describes"
)
gr.Markdown(
"Each method colors segments by its own Shapley values — "
"**green** = positive, **red** = negative. Signs often differ "
"between Log-Prob and Generation, so each has its own overlay."
)
with gr.Row(equal_height=True):
mimic_llavamed_unsam_lp_overlay = gr.Image(
label="LLaVA-Med Log-Prob — Overlay",
type="filepath", height=600)
mimic_llavamed_unsam_gen_overlay = gr.Image(
label="LLaVA-Med Generation — Overlay",
type="filepath", height=600)
with gr.Row():
mimic_llavamed_unsam_lp_plot = gr.Plot(
label="LLaVA-Med Log-Prob — Segment Shapley Values")
mimic_llavamed_unsam_gen_plot = gr.Plot(
label="LLaVA-Med Generation — Segment Shapley Values")
# ════════════════════════════════════════════════════
# ── Compare Two Methods Side-by-Side ──────────────
# ════════════════════════════════════════════════════
with gr.Column(elem_id="mimic-compare"):
gr.Markdown("---\n### Compare Two Methods Side-by-Side")
gr.Markdown(
"Select two methods to compare their attribution overlays "
"and Shapley value distributions on the same image."
)
with gr.Row():
mimic_compare_method_a = gr.Dropdown(
choices=_MIMIC_METHOD_NAMES,
label="Method A",
interactive=True,
)
mimic_compare_method_b = gr.Dropdown(
choices=_MIMIC_METHOD_NAMES,
label="Method B",
interactive=True,
)
with gr.Row():
mimic_compare_img_a = gr.Image(label="Method A — Overlay", type="filepath")
mimic_compare_img_b = gr.Image(label="Method B — Overlay", type="filepath")
with gr.Row():
mimic_compare_plot_a = gr.Plot(label="Method A — Shapley Values")
mimic_compare_plot_b = gr.Plot(label="Method B — Shapley Values")
mimic_meta = gr.JSON(label="Example Info", visible=False)
_mimic_outputs = [
mimic_caption, mimic_original, mimic_findings, mimic_interpretation,
mimic_biomedclip_overlay, mimic_biomedclip_token_plot,
mimic_biomedclip_region_plot, mimic_llavamed_unsam_lp_overlay,
mimic_llavamed_unsam_lp_plot, mimic_llavamed_unsam_gen_overlay,
mimic_llavamed_unsam_gen_plot, mimic_biomedclip_interaction_html,
mimic_meta, mimic_results_state, mimic_compare_method_a,
]
mimic_selector.change(
fn=_on_select_mimic_example,
inputs=[mimic_selector, mimic_method_toggle],
outputs=_mimic_outputs,
)
mimic_method_toggle.change(
fn=_on_select_mimic_example,
inputs=[mimic_selector, mimic_method_toggle],
outputs=_mimic_outputs,
)
# Wire: comparison dropdowns -> side-by-side display
for _mimic_cmp_dd in [mimic_compare_method_a, mimic_compare_method_b]:
_mimic_cmp_dd.change(
fn=_on_mimic_compare_methods,
inputs=[mimic_compare_method_a, mimic_compare_method_b,
mimic_results_state],
outputs=[mimic_compare_img_a, mimic_compare_img_b,
mimic_compare_plot_a, mimic_compare_plot_b],
)
# ── ISIC Dermoscopy Tab ────────────────────────
with gr.Tab("Dermoscopy ISIC (10 Samples)"):
gr.Markdown(
"**10-sample ISIC-2019 dermoscopy attribution benchmark** "
"(8 diagnostic classes: MEL × 2, NV × 2, BCC, AK, BKL, DF, VASC, SCC). \n"
"Source: [ISIC_2019_224](https://huggingface.co/datasets/MKZuziak/ISIC_2019_224) "
"— dermoscopic skin-lesion images from the International Skin Imaging Collaboration. \n"
"Captions are synthesized from class labels (clinical descriptions of each diagnosis)."
)
_isic_choices = (
[(v["category"], k) for k, v in ISIC_EXAMPLES.items()]
if _ISIC_AVAILABLE else []
)
isic_selector = gr.Dropdown(
choices=_isic_choices,
value=None,
label="Filter by Diagnosis",
interactive=True,
)
isic_method_toggle = gr.Radio(
choices=["Influence", "Shapley"],
value="Influence",
label="Attribution method",
info=(
"Influence (default) is always positive — clearer for clinicians. "
"Shapley is signed (green = supports caption, red = detracts)."
),
interactive=True,
)
isic_caption = gr.Textbox(
label="Diagnostic Caption",
interactive=False,
lines=3,
)
isic_original = gr.Image(label="Dermoscopic Image", type="filepath")
isic_interpretation = gr.Markdown(
value="*Select an example above to see the attribution analysis.*",
label="Interpretation",
)
_isic_pill = (
'style="display:inline-block;padding:6px 14px;background:#e3f2fd;'
'border-radius:16px;text-decoration:none;color:#1565c0;font-size:0.9em;'
'border:1px solid #bbdefb;"'
)
_isic_toc_html = (
'<div style="background:#f8f9fa;border:1px solid #dee2e6;border-radius:8px;'
'padding:16px;margin:12px 0;">'
'<strong style="font-size:1.05em;">Jump to Section:</strong>'
'<div style="display:flex;flex-wrap:wrap;gap:8px;margin-top:10px;">'
f'<a href="#isic-method-biomedclip" {_isic_pill}>BiomedCLIP Cross-Modal</a>'
f'<a href="#isic-method-llavamed" {_isic_pill}>LLaVA-Med (UnSAM)</a>'
f'<a href="#isic-interactive" {_isic_pill}>Interactive View</a>'
f'<a href="#isic-compare" {_isic_pill}>Compare Methods</a>'
'</div></div>'
)
gr.HTML(value=_isic_toc_html)
isic_results_state = gr.State({})
with gr.Column(elem_id="isic-interactive"):
gr.Markdown("---\n### BiomedCLIP Cross-Modal Interaction View — click segments or words")
gr.Markdown(
"**How to use:** Click any **image region** to see which caption words "
"it connects to, or click a **word** to see which regions activate."
)
isic_biomedclip_interaction_html = _html_component(
"BiomedCLIP Cross-Modal Interaction View")
with gr.Column(elem_id="isic-method-biomedclip"):
gr.Markdown("---\n### BiomedCLIP Cross-Modal Attribution")
gr.Markdown(
"**What it does:** Uses [BiomedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) "
"to jointly score dermoscopic image regions (via UnSAM segmentation) "
"and caption tokens. \n"
"**How to read:** **Influence** bars (default) show positive importance. "
"Switch to **Shapley** above for signed values (green/red)."
)
isic_biomedclip_overlay = gr.Image(
label="BiomedCLIP Overlay (labeled segments)", type="filepath")
isic_biomedclip_token_plot = gr.Plot(
label="BiomedCLIP — Caption Word Values")
isic_biomedclip_region_plot = gr.Plot(
label="BiomedCLIP — Image Region Values")
with gr.Column(elem_id="isic-method-llavamed"):
gr.Markdown("---\n### LLaVA-Med Attribution (4×4 Patch Grid, P1–P16)")
gr.Markdown(
"**What it does:** Uses [LLaVA-Med](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) "
"— a **7B parameter** medical VLM — evaluated over a **uniform 4×4 patch grid** "
"(16 cells labeled **P1–P16**, row-major). \n"
"**Two scoring approaches:** \n"
"- **Log-Prob:** How removing a region affects confidence in the caption \n"
"- **Generation:** How removing a region changes what the model describes"
)
with gr.Row(equal_height=True):
isic_llavamed_unsam_lp_overlay = gr.Image(
label="LLaVA-Med Log-Prob — Overlay",
type="filepath", height=600)
isic_llavamed_unsam_gen_overlay = gr.Image(
label="LLaVA-Med Generation — Overlay",
type="filepath", height=600)
with gr.Row():
isic_llavamed_unsam_lp_plot = gr.Plot(
label="LLaVA-Med Log-Prob — Segment Values")
isic_llavamed_unsam_gen_plot = gr.Plot(
label="LLaVA-Med Generation — Segment Values")
with gr.Column(elem_id="isic-compare"):
gr.Markdown("---\n### Compare Two Methods Side-by-Side")
gr.Markdown(
"Select two methods to compare their attribution overlays "
"and value distributions on the same image."
)
with gr.Row():
isic_compare_method_a = gr.Dropdown(
choices=_ISIC_METHOD_NAMES,
label="Method A",
interactive=True,
)
isic_compare_method_b = gr.Dropdown(
choices=_ISIC_METHOD_NAMES,
label="Method B",
interactive=True,
)
with gr.Row():
isic_compare_img_a = gr.Image(label="Method A — Overlay", type="filepath")
isic_compare_img_b = gr.Image(label="Method B — Overlay", type="filepath")
with gr.Row():
isic_compare_plot_a = gr.Plot(label="Method A — Values")
isic_compare_plot_b = gr.Plot(label="Method B — Values")
isic_meta = gr.JSON(label="Example Info", visible=False)
_isic_outputs = [
isic_caption, isic_original, isic_interpretation,
isic_biomedclip_overlay, isic_biomedclip_token_plot,
isic_biomedclip_region_plot, isic_llavamed_unsam_lp_overlay,
isic_llavamed_unsam_lp_plot, isic_llavamed_unsam_gen_overlay,
isic_llavamed_unsam_gen_plot, isic_biomedclip_interaction_html,
isic_meta, isic_results_state, isic_compare_method_a,
]
isic_selector.change(
fn=_on_select_isic_example,
inputs=[isic_selector, isic_method_toggle],
outputs=_isic_outputs,
)
isic_method_toggle.change(
fn=_on_select_isic_example,
inputs=[isic_selector, isic_method_toggle],
outputs=_isic_outputs,
)
for _isic_cmp_dd in [isic_compare_method_a, isic_compare_method_b]:
_isic_cmp_dd.change(
fn=_on_isic_compare_methods,
inputs=[isic_compare_method_a, isic_compare_method_b,
isic_results_state],
outputs=[isic_compare_img_a, isic_compare_img_b,
isic_compare_plot_a, isic_compare_plot_b],
)
# ── MS-COCO Tab ─────────────────────────────────
with gr.Tab("MS-COCO (5 Samples)"):
gr.Markdown(
"**CLIP cross-modal attribution on MS-COCO natural images.** \n"
"Click an **image region** or **caption word** below to explore "
"which parts of the image and text are most strongly linked via "
"CLIP's visual-language similarity score."
)
_coco_choices = (
[(v["title"], k) for k, v in COCO_EXAMPLES.items()]
if _COCO_AVAILABLE else []
)
_coco_default = _coco_choices[0][1] if _coco_choices else None
coco_selector = gr.Radio(
choices=_coco_choices,
value=_coco_default,
label="Select MS-COCO Example",
interactive=True,
)
coco_method_toggle = gr.Radio(
choices=["Influence", "Shapley"],
value="Influence",
label="Attribution method",
info="Influence (default) is always positive. Shapley is signed.",
interactive=True,
)
coco_caption = gr.Textbox(
label="Caption", interactive=False, lines=2,
)
gr.Markdown("---\n#### Interactive Cross-Modal View")
gr.Markdown(
"Click a colored **image region** (left) to highlight the caption "
"words it interacts with, or click a **word** (right) to highlight "
"linked regions. Green = positive, red = negative."
)
coco_interaction_html = _html_component(
"COCO Cross-Modal Interaction View")
gr.Markdown("---\n#### Attribution Details")
with gr.Row():
coco_original = gr.Image(
label="Original Image", type="filepath")
coco_overlay = gr.Image(
label="CLIP Overlay (labeled segments)", type="filepath")
with gr.Row():
coco_token_plot = gr.Plot(
label="Caption Word Shapley Values")
coco_region_plot = gr.Plot(
label="Image Region Shapley Values")
with gr.Row():
coco_cross_plot = gr.Plot(
label="Top Image x Word Interactions")
coco_cross_table = gr.Dataframe(
headers=["Image Region", "Caption Word", "Score"],
label="Cross-Modal Interaction Table",
interactive=False,
)
with gr.Accordion("Influence Heatmap (Regions x Words)",
open=False):
coco_heatmap = gr.Plot(
label="Full Heatmap: Regions x Caption Words")
gr.Markdown("---\n#### Masked Image Browser")
gr.Markdown(
"Browse ablation images: **solo** shows only the selected region "
"(everything else inpainted away); **removed** shows the image with "
"that region inpainted out."
)
with gr.Row():
coco_masked_dd = gr.Dropdown(
choices=[],
label="Region / View",
interactive=True,
)
coco_masked_img = gr.Image(
label="Masked View", type="filepath")
coco_note = gr.Markdown(value="")
_coco_outputs = [
coco_caption,
coco_original,
coco_overlay,
coco_interaction_html,
coco_token_plot,
coco_region_plot,
coco_cross_plot,
coco_cross_table,
coco_heatmap,
coco_note,
coco_masked_img,
coco_masked_dd,
]
coco_selector.change(
fn=_on_select_coco_example,
inputs=[coco_selector, coco_method_toggle],
outputs=_coco_outputs,
)
coco_method_toggle.change(
fn=_on_select_coco_example,
inputs=[coco_selector, coco_method_toggle],
outputs=_coco_outputs,
)
coco_masked_dd.change(
fn=_on_select_coco_masked,
inputs=[coco_selector, coco_masked_dd],
outputs=[coco_masked_img],
)
# NOTE: auto-load removed — too much data on startup crashes the browser.
# Users select an example via the Radio to trigger loading.
gr.HTML(
'<div style="margin-top:32px;padding:20px;border-top:1px solid #e5e7eb;text-align:center;">'
'<p style="font-weight:600;margin-bottom:10px;">Contributors — University of California, Berkeley</p>'
'<p style="display:flex;justify-content:center;gap:40px;flex-wrap:wrap;font-size:0.9em;">'
'<span><strong>Stephen Tao</strong> · Loader Layer · '
'<a href="mailto:stephen_tao@berkeley.edu" style="color:#6366f1;">stephen_tao@berkeley.edu</a></span>'
'<span><strong>Yiting Gao</strong> · Attribution Layer · '
'<a href="mailto:yg2025@berkeley.edu" style="color:#6366f1;">yg2025@berkeley.edu</a></span>'
'<span><strong>Qingpeng Kong</strong> · Visualization Layer · '
'<a href="mailto:qpkong@berkeley.edu" style="color:#6366f1;">qpkong@berkeley.edu</a></span>'
'</p>'
'<p style="display:flex;justify-content:center;gap:40px;flex-wrap:wrap;font-size:0.9em;margin-top:6px;">'
'<span><strong>Advisor:</strong> Kannan Ramchandran · '
'<a href="mailto:kannanr@berkeley.edu" style="color:#6366f1;">kannanr@berkeley.edu</a></span>'
'<span><strong>Mentor:</strong> Landon Butler · '
'<a href="mailto:landonb@berkeley.edu" style="color:#6366f1;">landonb@berkeley.edu</a></span>'
'</p></div>'
)
# Stash CSS for Gradio 6.x launch() (Blocks(css=) is deprecated in 6.x)
app._custom_css = custom_css
return app
def _launch_kwargs(app_or_demo, **kwargs):
    """Build common launch kwargs, injecting CSS for Gradio 6.x.

    Each of ``server_name``, ``server_port``, ``share`` and ``show_error`` is
    taken from ``kwargs`` when the caller supplied it; otherwise it falls back
    to the matching ``GRADIO_*`` environment variable (where one exists) and
    then to a built-in default. Fallbacks are resolved lazily, so neither the
    environment nor the flag helper is consulted for values the caller passed.

    Args:
        app_or_demo: Object that will be launched. If it has a non-empty
            ``_custom_css`` attribute and its ``launch`` accepts a ``css``
            keyword, the CSS is forwarded as ``css``.
        **kwargs: Caller overrides. Keys other than the four handled above are
            copied through last, so they can also override ``css``.

    Returns:
        dict: Keyword arguments ready to be splatted into ``launch``.

    Raises:
        ValueError: If the resolved port is a string that is not an integer.
    """

    def _take(key, fallback):
        # Call the fallback only when the caller did not provide the key.
        return kwargs.pop(key) if key in kwargs else fallback()

    port = _take("server_port", lambda: os.getenv("GRADIO_SERVER_PORT", "7860"))
    lk = dict(
        server_name=_take(
            "server_name", lambda: os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
        ),
        # An explicit ``server_port=None`` is passed through unchanged instead
        # of crashing in ``int(None)``; launch() then picks the port itself.
        server_port=None if port is None else int(port),
        share=_take("share", lambda: _env_flag("GRADIO_SHARE", False)),
        show_error=_take("show_error", lambda: True),
    )
    css = getattr(app_or_demo, "_custom_css", None)
    if css and _supports_kwarg(app_or_demo.launch, "css"):
        lk["css"] = css
    # Remaining caller kwargs are applied last so they take precedence.
    lk.update(kwargs)
    return lk
def launch_demo(**kwargs):
    """Build the demo app and start its server.

    Keyword arguments are merged with the shared launch defaults by
    ``_launch_kwargs`` before being handed to ``launch``.
    """
    demo_app = build_demo_app()
    launch_options = _launch_kwargs(demo_app, **kwargs)
    demo_app.launch(**launch_options)
def launch_app(**kwargs):
    """Build the full app and start its server.

    Keyword arguments are merged with the shared launch defaults by
    ``_launch_kwargs`` before being handed to ``launch``.
    """
    full_app = build_app()
    launch_options = _launch_kwargs(full_app, **kwargs)
    full_app.launch(**launch_options)
# Script entry point: launch the full app with environment-driven defaults
# when this file is executed directly (not when it is imported).
if __name__ == "__main__":
launch_app()