import base64
import copy
import inspect
import io
import json
import math
import os
import random
import re
from itertools import combinations
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional

import requests

# Backend base URL (override with ATTRLLM_BACKEND_URL).
BACKEND_URL = os.getenv("ATTRLLM_BACKEND_URL", "http://127.0.0.1:8000")

# Point Gradio's temp directory at ./.gradio_tmp unless the environment already
# chose one, and make sure the directory exists before Gradio is imported.
_DEFAULT_GRADIO_DIR = Path(os.environ.get("GRADIO_TEMP_DIR", Path.cwd() / ".gradio_tmp"))
os.environ.setdefault("GRADIO_TEMP_DIR", str(_DEFAULT_GRADIO_DIR))
_DEFAULT_GRADIO_DIR.mkdir(parents=True, exist_ok=True)

_DEFAULT_REQUEST_TIMEOUT = 900.0


def _get_request_timeout() -> float:
    """Return the backend request timeout in seconds.

    Reads ``ATTRLLM_REQUEST_TIMEOUT``. Returns 900 seconds when the variable is
    unset, empty or not a number. The default is also used when the value is not
    a finite positive number: ``requests``/``urllib3`` reject timeouts <= 0, and
    NaN/inf are not meaningful timeouts.
    """
    value = os.getenv("ATTRLLM_REQUEST_TIMEOUT")
    if not value:
        return _DEFAULT_REQUEST_TIMEOUT
    try:
        timeout = float(value)
    except ValueError:
        return _DEFAULT_REQUEST_TIMEOUT
    if not math.isfinite(timeout) or timeout <= 0:
        return _DEFAULT_REQUEST_TIMEOUT
    return timeout


def _env_flag(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset. Otherwise returns True only for
    1/true/yes/y/on (case-insensitive, surrounding whitespace ignored); any other
    value, including an empty string, is False.
    """
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in {"1", "true", "yes", "y", "on"}


def _is_hf_spaces() -> bool:
    """Return True when SPACE_ID or HF_SPACE is set to a non-empty value (Hugging Face Spaces)."""
    return bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE"))


def _supports_kwarg(callable_obj, kwarg_name: str) -> bool:
    """Return whether a callable appears to accept a named keyword argument."""
    try:
        return kwarg_name in inspect.signature(callable_obj).parameters
    except (TypeError, ValueError):
        return False


def _public_only_mode() -> bool:
    """Return the ATTRLLM_PUBLIC_ONLY flag (default False)."""
    # Keep the text tab visible on Spaces unless explicitly overridden.
    return _env_flag("ATTRLLM_PUBLIC_ONLY", False)


def _mm_only_mode() -> bool:
    """Return the ATTRLLM_MM_ONLY flag (default False)."""
    return _env_flag("ATTRLLM_MM_ONLY", False)


def _show_auxiliary_tabs() -> bool:
    """Return the ATTRLLM_SHOW_AUX_TABS flag (default False)."""
    return _env_flag("ATTRLLM_SHOW_AUX_TABS", False)


def _public_results_file(
    dataset_key: str,
    ex_id: str,
    scalarizer: str,
    level: str,
    method: str,
) -> Path:
    """Return ``<results>/public/<dataset_key>/<ex_id>/<scalarizer>/<level>/<method>.json``.

    ``<results>`` comes from ``_get_results_dir()``. The file is not required to exist.
    """
    results_dir = _get_results_dir()
    return (
        results_dir
        / "public"
        / dataset_key
        / ex_id
        / scalarizer
        / level
        / f"{method}.json"
    )


def _reference_results_file(
    model_size: str,
    dataset_key: str,
    ex_id: str,
    scalarizer: str,
    level: str,
) -> Path:
    """Return ``<results>/reference_answer/<model_size>/<dataset_key>/<ex_id>/<scalarizer>/<level>.json``.

    ``<results>`` comes from ``_get_results_dir()``. The file is not required to exist.
    """
    results_dir = _get_results_dir()
    return (
        results_dir
        / "reference_answer"
        / model_size
        / dataset_key
        / ex_id
        / scalarizer
        / f"{level}.json"
    )


def _find_available_model_size(
    dataset_key: str,
    ex_id: str,
    scalarizer: str,
    level: str,
) -> Optional[str]:
    """Return the first size (checking large, medium, then small) whose reference result file exists, else None."""
    for size in ("large", "medium", "small"):
        if _reference_results_file(size, dataset_key, ex_id, scalarizer, level).exists():
            return size
    return None


# Fallback order when requested (scalarizer, level) is not present (e.g. on HF Space with partial results).
_FALLBACK_SCALARIZER_LEVELS: List[Tuple[str, str]] = [ ("geomean_jointprob", "word"), ("semantic_similarity", "word"), ("geomean_jointprob", "sentence"), ("semantic_similarity", "sentence"), ("geomean_jointprob", "paragraph"), ("semantic_similarity", "paragraph"), ] def _find_any_available_result( dataset_key: str, ex_id: str, get_res: Any, method: str = "shapley", ) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[Dict]]: """Try (model_size, scalarizer, level) fallbacks; return (size, scalarizer, level, result_dict) or (None,)*4.""" for size in ("small", "medium", "large"): for scalarizer, level in _FALLBACK_SCALARIZER_LEVELS: try: result = get_res(size, dataset_key, ex_id, scalarizer=scalarizer, feature_level=level) or {} payload = result.get(method, {}) if payload and (payload.get("features") or payload.get("heatmap")): return (size, scalarizer, level, result) except Exception: continue return (None, None, None, None) def _parse_sparse_key(raw_key: str) -> Tuple[int, ...]: key = str(raw_key).strip() if not key: return () return tuple(int(part) for part in key.split(",") if part != "") def _normalize_public_payload_fallback(data: Dict[str, Any], method: str, top_k: int = 10) -> Dict[str, Any]: """Convert your JSON (features list + meta + mobius_dict) to UI display format. 
mobius_dict can be empty.""" if not isinstance(data, dict): return {} features = data.get("features") mobius_raw = data.get("mobius_dict") if isinstance(data.get("mobius_dict"), dict) else {} if not isinstance(features, list) or not features: return {} method = (method or "shapley").lower() if method not in {"shapley", "banzhaf", "influence"}: method = "shapley" mobius_sparse: Dict[Tuple[int, ...], float] = {} for key, raw_val in mobius_raw.items(): try: val = float(raw_val) except Exception: continue try: loc = _parse_sparse_key(str(key)) except Exception: continue mobius_sparse[tuple(sorted(loc))] = val token_scores: Dict[str, float] = {} index_scores: Dict[int, float] = {} pairwise_acc: Dict[Tuple[int, int], float] = {} if mobius_sparse and mobius_to_shapley is not None: if method == "shapley": singleton_dict = mobius_to_shapley(mobius_sparse) pair_list = shapley_interactions(mobius_sparse, order=2, top_k=top_k) or [] elif method == "banzhaf": singleton_dict = mobius_to_banzhaf(mobius_sparse) pair_list = banzhaf_interactions(mobius_sparse, order=2, top_k=top_k) or [] elif mobius_to_influence is not None and influence_interactions is not None: singleton_dict = mobius_to_influence(mobius_sparse) pair_list = influence_interactions(mobius_sparse, order=2, top_k=top_k) or [] else: singleton_dict = {} pair_list = [] for loc, val in singleton_dict.items(): if len(loc) != 1: continue idx = int(loc[0]) if 0 <= idx < len(features): feat_name = str(features[idx]) val_f = float(val) token_scores[feat_name] = token_scores.get(feat_name, 0.0) + val_f index_scores[idx] = index_scores.get(idx, 0.0) + val_f for loc, val in pair_list: if len(loc) != 2: continue i, j = int(loc[0]), int(loc[1]) if 0 <= i < len(features) and 0 <= j < len(features): key = (i, j) if i <= j else (j, i) pairwise_acc[key] = float(val) else: # Best-effort fallback when attribution helpers are unavailable. 
for loc, val in mobius_sparse.items(): k = len(loc) if k == 0: continue if method == "shapley": sw = 1.0 / float(k) elif method == "banzhaf": sw = 1.0 / float(2 ** (k - 1)) else: sw = 1.0 / float(k) for idx in loc: if 0 <= idx < len(features): feat_name = str(features[idx]) token_scores[feat_name] = token_scores.get(feat_name, 0.0) + sw * val index_scores[idx] = index_scores.get(idx, 0.0) + sw * val if k >= 2: if method == "shapley": pw = 1.0 / float(k - 1) elif method == "banzhaf": pw = 1.0 / float(2 ** (k - 2)) else: pw = 1.0 / float(k - 1) for i, j in combinations(sorted(loc), 2): pairwise_acc[(i, j)] = pairwise_acc.get((i, j), 0.0) + pw * val unique_feature_labels = [str(x) for x in features] sorted_pairs = sorted(pairwise_acc.items(), key=lambda kv: abs(kv[1]), reverse=True) if top_k and top_k > 0: sorted_pairs = sorted_pairs[:top_k] pairwise = { "%s|%s" % (unique_feature_labels[i], unique_feature_labels[j]): float(v) for (i, j), v in sorted_pairs if 0 <= i < len(unique_feature_labels) and 0 <= j < len(unique_feature_labels) } pairwise_interactions = [ {"features": [unique_feature_labels[i], unique_feature_labels[j]], "value": float(v)} for (i, j), v in sorted_pairs if 0 <= i < len(unique_feature_labels) and 0 <= j < len(unique_feature_labels) ] normalized = dict(data) normalized["token_scores"] = token_scores normalized["pairwise"] = pairwise normalized["pairwise_interactions"] = pairwise_interactions normalized["features"] = [ {"feature": str(features[i]), "value": float(index_scores.get(i, 0.0)), "index": i} for i in range(len(features)) ] normalized["feature_texts"] = [str(x) for x in features] return normalized def _public_get_model_answer_short_from_file( model_size: str, dataset: str, ex_id: str, scalarizer: str = "geomean_jointprob", feature_level: str = "word", ) -> Dict[str, Any]: """Load model_answer_short payload (the wrong-answer attribution) for the Public Mode dual-heatmap branch. 
Returns a per-method dict shaped like `_public_get_result_from_file`, or {} when the file is missing. """ results_dir = _get_results_dir() path = ( results_dir / "model_answer_short" / model_size / dataset / ex_id / scalarizer / f"{feature_level}.json" ) if not path.exists(): return {} try: with path.open("r", encoding="utf-8") as f: data = json.load(f) except Exception: return {} if not isinstance(data, dict): return {} norm_s = _normalize_public_payload_fallback(copy.deepcopy(data), "shapley") norm_b = _normalize_public_payload_fallback(copy.deepcopy(data), "banzhaf") norm_i = _normalize_public_payload_fallback(copy.deepcopy(data), "influence") if not norm_s and not norm_b and not norm_i: return {} return { "shapley": norm_s, "banzhaf": norm_b, "influence": norm_i, "meta": { "dataset": dataset, "example_id": ex_id, "model_size": model_size, "target_mode": "model_answer_short", "source_layout": "results/model_answer_short/{model_size}/{dataset}/{example_id}/{scalarizer}/{feature_level}.json", }, } def _public_get_result_from_file( model_size: str, dataset: str, ex_id: str, scalarizer: Optional[str] = None, feature_level: Optional[str] = None, ) -> Dict[str, Any]: """Load reference_answer result from disk when loader.results.get_result_by_id is unavailable (e.g. 
on Space).""" scalarizer = (scalarizer or "").strip() feature_level = (feature_level or "").strip() if not scalarizer or not feature_level: return {} levels_to_try = [feature_level] + [l for l in ("word", "sentence", "paragraph") if l != feature_level] for lvl in levels_to_try: path = _reference_results_file(model_size, dataset, ex_id, scalarizer, lvl) if not path.exists() or os.getenv("SPACE_ID"): try: from loader.results import _maybe_download_from_space path = _maybe_download_from_space(path, force_download=True) or path except Exception: pass if not path.exists(): continue try: with path.open("r", encoding="utf-8") as f: data = json.load(f) except Exception: continue if not isinstance(data, dict): continue # Your JSON: features (list) + meta + mobius_dict (can be empty). Always convert to UI format. norm_s = _normalize_public_payload_fallback(copy.deepcopy(data), "shapley") norm_b = _normalize_public_payload_fallback(copy.deepcopy(data), "banzhaf") norm_i = _normalize_public_payload_fallback(copy.deepcopy(data), "influence") if not norm_s and not norm_b and not norm_i: continue return { "shapley": norm_s, "banzhaf": norm_b, "influence": norm_i, "meta": { "dataset": dataset, "example_id": ex_id, "model_size": model_size, "source_layout": "results/reference_answer/{model_size}/{dataset}/{example_id}/{scalarizer}/{feature_level}.json", }, } return {} _FALLBACK_DATASET_FILES: Dict[str, str] = { "bar_exam": "BarExam_qa.csv", "causal_judgment": "bbh_causal_judgement.csv", "snarks": "bbh_snarks.csv", "bbq_disamb": "BBQ_disamb.csv", "cnn_dailymail": "CNN_dailymail.csv", "drop": "drop.csv", "esnli": "eSNLI.csv", "fever": "fever.csv", "hotpot_qa": "hotpot_qa.csv", "medical_qa": "medical_qa.csv", } def _fallback_datasets_dir() -> Path: return (_REPO_ROOT / "datasets").resolve() def _fallback_pick_first_nonempty(raw: Dict[str, str], candidates: List[str]) -> str: for c in candidates: val = raw.get(c) if val is not None and str(val).strip() != "": return str(val) return "" 
def _fallback_load_dataset(dataset_key: str, max_rows: int = 10) -> List[Dict[str, str]]:
    """Load up to *max_rows* examples for *dataset_key* directly from its bundled CSV.

    Each row becomes ``{"id", "context", "prompt"}``, plus ``"answer"`` when one
    is found. Common column names are probed in order for each field.

    Returns [] when the dataset key is unknown, the CSV file is missing, or
    *max_rows* is not positive.
    """
    import csv

    filename = _FALLBACK_DATASET_FILES.get(dataset_key)
    if not filename or max_rows <= 0:
        return []
    path = _fallback_datasets_dir() / filename
    if not path.exists():
        return []
    rows: List[Dict[str, str]] = []
    with path.open("r", encoding="utf-8", errors="replace", newline="") as f:
        reader = csv.DictReader(f)
        for i, raw in enumerate(reader, start=1):
            # Rows without an id-like column get a 1-based positional id.
            ex_id = raw.get("id") or raw.get("example_id") or raw.get("uid") or f"example_{i}"
            context = _fallback_pick_first_nonempty(raw, [
                "Context", "context", "passage", "article", "story", "premise",
                "paragraph", "document", "sentence1", "sent1", "background",
            ])
            prompt = _fallback_pick_first_nonempty(raw, [
                "Prompt", "prompt", "question", "input", "query", "sentence2",
                "sent2", "hypothesis", "qa_question", "title",
            ])
            answer = _fallback_pick_first_nonempty(raw, [
                "Answer", "answer", "target", "gold", "label", "output",
                "reference", "highlights",
            ])
            ex = {"id": str(ex_id), "context": context, "prompt": prompt}
            if answer:
                ex["answer"] = answer
            rows.append(ex)
            if len(rows) >= max_rows:
                break
    return rows


REQUEST_TIMEOUT = _get_request_timeout()

# (UI label, scalarizer key) pairs.
SCALARIZER_CHOICES = [
    ("Semantic Similarity (y vs y_S)", "semantic_similarity"),
    ("LogProb", "logprob"),
    ("JointProb", "jointprob"),
    ("GeoMean JointProb", "geomean_jointprob"),
    ("Half SimLog", "half_simlog"),
]
PUBLIC_SCALARIZER_CHOICES = [
    ("Semantic Similarity", "semantic_similarity"),
    ("Perplexity", "geomean_jointprob"),
]
DATASET_DISPLAY_LABELS = {
    "bar_exam": "Bar Exam Questions",
    "bbq_disamb": "BBQ Disambiguation",
    "causal_judgment": "Causal Judgment",
    "cnn_dailymail": "CNN / DailyMail Summaries",
    "drop": "DROP Reading Comprehension",
    "esnli": "e-SNLI Natural Language Inference",
    "fever": "FEVER Fact Checking",
    "hotpot_qa": "HotpotQA Multi-hop Questions",
    "medical_qa": "Medical Questions",
    "snarks": "Snarks",
}

import sys

# Put the repository root (parent of this package's directory) on sys.path, so
# that top-level packages imported below (``loader``, ``attribution``) can
# resolve when they live there.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))


def _get_results_dir() -> Path:
    """Resolve the results directory.

    Resolution order:
    1. ``ATTRLLM_RESULTS_DIR`` when set.
    2. ``<repo>/results`` when it exists.
    3. On HF Spaces only, ``<cwd>/results`` when it exists.
    4. Otherwise ``<repo>/results``, even if it is missing.
    """
    env_dir = os.getenv("ATTRLLM_RESULTS_DIR")
    if env_dir:
        return Path(env_dir).resolve()
    default = (_REPO_ROOT / "results").resolve()
    if default.exists():
        return default
    if _is_hf_spaces():
        cwd_results = (Path.cwd() / "results").resolve()
        if cwd_results.exists():
            return cwd_results
    return default


import gradio as gr
from PIL import Image

from .components.model_selector import (
    create_model_selector,
    create_multimodal_model_selector,
    create_feature_level_selector,
    create_attribution_method_toggle,
)
from .components.example_browser import create_dataset_selector, create_example_browser
from .components.results_display import create_results_display, update
from .plotting.heatmap import create_interactive_text_heatmap
from .plotting.interactions import (
    plot_top_interactions,
    plot_interaction_matrix,
    create_interaction_token_view,
)
from .plotting.text_interactions import create_text_interaction_html
from .plotting.mm_interactions import create_multimodal_interaction_html
from .plotting.coalition_viewer import compute_coalition_viewer_data, render_coalition_viewer_html
from .build_info import BUILD_ID, BUILD_TS

# Medical image precomputed results (optional)
try:
    from .medical_loader import (
        MEDICAL_EXAMPLES,
        load_medical_example,
        get_masked_image_path,
        BENCHMARK_EXAMPLES,
        get_examples_by_modality,
        list_available_modalities,
        load_benchmark_example,
        extract_segment_regions,
    )
    from .plotting.medical_charts import (
        create_shapley_bar_chart,
        create_influence_heatmap,
        create_cross_modal_bar_chart,
        draw_grid_overlay,
        draw_segment_labels,
        generate_interpretation_text,
        rename_patch_labels,
        align_segments_to_reference,
        remap_region_values,
        merge_subword_token_values,
        _tok_to_word,
    )
    from .plotting.benchmark_interaction import create_benchmark_interaction_html
    _MEDICAL_AVAILABLE = True
except ImportError:
    _MEDICAL_AVAILABLE = False
    MEDICAL_EXAMPLES = {}
    BENCHMARK_EXAMPLES = {}

# MIMIC-CXR precomputed results (optional)
try:
    from .mimic_loader import (
        MIMIC_EXAMPLES,
        load_mimic_example,
        get_mimic_image_path,
    )
    _MIMIC_AVAILABLE = bool(MIMIC_EXAMPLES)
except ImportError:
    _MIMIC_AVAILABLE = False
    MIMIC_EXAMPLES = {}

# Dermoscopy ISIC precomputed results (optional)
try:
    from .isic_loader import (
        ISIC_EXAMPLES,
        load_isic_example,
        get_isic_image_path,
    )
    _ISIC_AVAILABLE = bool(ISIC_EXAMPLES)
except ImportError:
    _ISIC_AVAILABLE = False
    ISIC_EXAMPLES = {}

# MS-COCO precomputed results (optional)
try:
    from .coco_loader import COCO_EXAMPLES, load_coco_example, get_coco_masked_image_path
    _COCO_AVAILABLE = True
except ImportError:
    _COCO_AVAILABLE = False
    COCO_EXAMPLES = {}

# CLIP cross-modal pipeline for live compute (optional — runs on CPU)
try:
    from attribution.set_mm import (
        PipelineConfig,
        CrossModalCLIPScorer,
        ImageRegion,
        TokenPlayer,
        featurise,
        tokenise_caption,
        build_cross_modal_set_function,
        run_proxyspex,
        mobius_to_shapley,
        mobius_to_banzhaf,
        extract_interactions,
        extract_cross_per_token,
        apply_image_mask,
        render_overlay,
        render_segmentation_map,
        mask_token_ids,
    )
    _CLIP_PIPELINE_AVAILABLE = True
except ImportError:
    _CLIP_PIPELINE_AVAILABLE = False

# Module-level cache for CLIP scorers (keyed by model name)
_clip_scorer_cache: Dict[str, Any] = {}


def _raise_backend_error(resp: requests.Response, label: str) -> None:
    """Raise ``gr.Error`` describing a failed backend call.

    Uses the JSON ``detail`` field of the response when present, otherwise the raw
    response body.
    """
    detail = resp.text
    try:
        detail = resp.json().get("detail", detail)
    except (ValueError, AttributeError):
        # Body is not JSON, or is JSON but not an object: keep the raw text.
        pass
    raise gr.Error(f"{label} failed ({resp.status_code}). {detail}")


# backend API imports
try:
    # loader data APIs are required for public mode
    from loader.data import (
        get_example_by_id,
        get_examples,
        list_datasets,
        list_datasets_with_display_names,
        list_dataset_display_names,
        get_dataset_display_name,
        get_dataset_key_from_display_name,
    )
except Exception:  # pragma: no cover - optional at runtime
    get_example_by_id = None
    get_examples = None
    list_datasets = None
    list_datasets_with_display_names = None
    list_dataset_display_names = None
    get_dataset_display_name = None
    get_dataset_key_from_display_name = None

try:
    from loader.results import get_result_by_id
except Exception:  # pragma: no cover
    get_result_by_id = None

try:
    from loader.models import get_model
except Exception:  # pragma: no cover
    get_model = None

try:
    # attribution stack is optional (dev mode)
    from attribution.masker import get_masker, mask_text
    from attribution.proxyspex import run_proxyspex
    from attribution.image_masker import supports_superpixel
    from attribution.utils import (
        influence_interactions,
        mobius_to_influence,
        mobius_to_shapley,
        shapley_interactions,
        mobius_to_banzhaf,
        banzhaf_interactions,
    )
except Exception:  # pragma: no cover
    # NOTE(review): this also resets run_proxyspex / mobius_to_shapley /
    # mobius_to_banzhaf that the optional attribution.set_mm import above may have
    # bound, even when _CLIP_PIPELINE_AVAILABLE is True — confirm this is intended.
    get_masker = None
    mask_text = None
    run_proxyspex = None
    supports_superpixel = None
    influence_interactions = None
    mobius_to_influence = None
    mobius_to_shapley = None
    shapley_interactions = None
    mobius_to_banzhaf = None
    banzhaf_interactions = None

_ANSWER_FIELDS = (
    "correct_answer",
    "answer",
    "target",
    "completion",
    "label",
)

_ALLOWED_METHODS = {"shapley", "banzhaf", "influence"}
_ALLOWED_LEVELS = {"word", "sentence", "paragraph"}


def _ensure_backend(name: str, fn: Optional[Any]):
    """Return *fn*, or raise RuntimeError naming *name* when it is None (optional import failed)."""
    if fn is None:
        raise RuntimeError(
            f"{name} is unavailable. Ensure the backend modules are installed and importable."
        )
    return fn


def _html_component(label: str, visible: bool = True) -> gr.HTML:
    """Create a ``gr.HTML`` with sanitization disabled.

    Falls back to default sanitization on Gradio versions whose ``gr.HTML``
    rejects the ``sanitize_html`` keyword.
    """
    try:
        return gr.HTML(label=label, sanitize_html=False, visible=visible)
    except TypeError:
        return gr.HTML(label=label, visible=visible)


def _encode_image_to_b64(image: Image.Image) -> str:
    """Return *image* serialized as PNG and base64-encoded (ASCII str)."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _extract_answer(record: Dict[str, Any]) -> str:
    """Return the first present answer-like field of *record* as a string, else "".

    Fields are checked in ``_ANSWER_FIELDS`` order. Numeric values such as a class
    label of 0 count as present, even though they are falsy.
    """
    for field in _ANSWER_FIELDS:
        val = record.get(field)
        is_number = isinstance(val, (int, float)) and not isinstance(val, bool)
        if is_number or val:
            return str(val)
    return ""


def _coerce_feature_tuple(raw_key: Any) -> Tuple[str, ...]:
    """Convert an interaction key into a tuple of feature-name strings.

    Lists and tuples are converted element-wise. A string is split on the first
    separator found among ``·``, ``|``, ``,`` and ``×``, dropping blank parts.
    Any other value becomes a 1-tuple of its ``str()``.
    """
    if isinstance(raw_key, (tuple, list)):
        return tuple(str(item) for item in raw_key)
    if isinstance(raw_key, str):
        for sep in ("·", "|", ",", "×"):
            if sep in raw_key:
                parts = [chunk.strip() for chunk in raw_key.split(sep) if chunk.strip()]
                if parts:
                    return tuple(parts)
        return (raw_key.strip(),)
    return (str(raw_key),)


def _normalize_interactions(raw: Any) -> List[Tuple[Tuple[str, ...], float]]:
    """
    Make a best-effort guess at interaction structure.

    Supported shapes:
      - { key: float }
      - { key: {"value": float, "score": ...} }
      - [ (key, float), ... ]
      - [ (key, {"value": float}), ... ]
      - [ {"features": [...], "value": float}, ... ]  (this is mostly handled elsewhere)

    Entries whose value cannot be converted to float are skipped.
    """
    if raw is None:
        return []
    if isinstance(raw, dict):
        # e.g. { key: float } or { key: {"value": ...} }
        items: List[Any] = list(raw.items())
    elif isinstance(raw, list):
        items = list(raw)
    else:
        return []

    normalized: List[Tuple[Tuple[str, ...], float]] = []
    for item in items:
        if isinstance(item, dict):
            # Dict-style item with explicit fields.
            feats = item.get("features") or item.get("indices") or item.get("pair") or item.get("key")
            val = item.get("value", item.get("score", 0.0))
        else:
            # Tuple/list pair (feats, value).
            try:
                feats, val = item
            except Exception:
                continue
        # If value itself is a dict, dig out "value" / "score".
        if isinstance(val, dict):
            val = val.get("value", val.get("score", 0.0))
        try:
            numeric = float(val)
        except Exception:
            continue
        feats_tuple = _coerce_feature_tuple(feats)
        if feats_tuple:
            normalized.append((feats_tuple, numeric))
    return normalized


def _resolve_marginals(payload: Dict[str, Any]) -> Dict[str, float]:
    """Return the first dict among marginals/token_scores/values/scores, with float values.

    Entries whose value is not convertible to float are dropped. Returns {} when
    none of these keys holds a dict.
    """
    for key in ("marginals", "token_scores", "values", "scores"):
        data = payload.get(key)
        if isinstance(data, dict):
            normalized: Dict[str, float] = {}
            for k, v in data.items():
                try:
                    normalized[str(k)] = float(v)
                except Exception:
                    continue
            return normalized
    return {}


def _resolve_features(payload: Dict[str, Any], marginals: Dict[str, float]) -> List[str]:
    """Return ``payload["features"]`` as strings if it is a list, else the marginal keys, else []."""
    features = payload.get("features")
    if isinstance(features, list):
        return [str(f) for f in features]
    if marginals:
        return list(marginals.keys())
    return []


def _first_float_field(item: Dict[str, Any]) -> Optional[float]:
    """Return the first of value/score/attribution/weight in *item* convertible to float, else None."""
    for key_val in ("value", "score", "attribution", "weight"):
        if key_val in item:
            try:
                return float(item[key_val])
            except (TypeError, ValueError, OverflowError):
                continue
    return None


def _extract_interactions_from_response(
    data_int: Dict[str, Any],
    method: str,
    features: List[str],
) -> List[Tuple[Tuple[str, ...], float]]:
    """Pull (feature-name tuple, value) interactions out of a backend response.

    The method's sub-block (``data_int[method]`` or *data_int* itself) is searched
    for interactions under interactions/pairwise_interactions/interactions_2/
    pairwise/data. The shapes below are tried in turn, and the first that yields
    entries wins:

    1. A list of dicts.
    2. A list of ``(features, value)`` pairs.
    3. A dict keyed by index strings such as ``"(0,2)"``.
    4. As a last resort, every number found by walking the structure.

    Integer indices are mapped to names through *features*.
    """
    inter_list: List[Tuple[Tuple[str, ...], float]] = []
    method_key = (method or "shapley").lower()
    method_block = data_int.get(method_key) or data_int

    raw_interactions = None
    if isinstance(method_block, dict):
        for key in ("interactions", "pairwise_interactions", "interactions_2", "pairwise", "data"):
            if key in method_block:
                raw_interactions = method_block.get(key)
                break
        if raw_interactions is None:
            raw_interactions = method_block
    else:
        raw_interactions = method_block

    n_features = len(features)

    # List-of-dicts or list-of-pairs shape
    if isinstance(raw_interactions, list) and raw_interactions:
        first = raw_interactions[0]
        if isinstance(first, dict):
            for item in raw_interactions:
                if not isinstance(item, dict):
                    continue
                feats = (
                    item.get("feature_list")
                    or item.get("features")
                    or item.get("indices")
                    or item.get("pair")
                    or []
                )
                val = _first_float_field(item)
                if val is None:
                    continue
                if isinstance(feats, (list, tuple)) and feats and isinstance(feats[0], int):
                    feat_names = tuple(
                        features[i] for i in feats if isinstance(i, int) and 0 <= i < n_features
                    )
                else:
                    feat_names = _coerce_feature_tuple(feats)
                if feat_names:
                    inter_list.append((feat_names, val))
        elif isinstance(first, (list, tuple)) and len(first) == 2:
            for item in raw_interactions:
                if not isinstance(item, (list, tuple)) or len(item) != 2:
                    continue
                feats_raw, val_raw = item
                try:
                    val = float(val_raw)
                except (TypeError, ValueError, OverflowError):
                    continue
                feat_names: Tuple[str, ...] = ()
                if isinstance(feats_raw, (list, tuple)) and feats_raw:
                    if all(isinstance(i, int) for i in feats_raw):
                        feat_names = tuple(features[i] for i in feats_raw if 0 <= i < n_features)
                    else:
                        feat_names = _coerce_feature_tuple(feats_raw)
                elif isinstance(feats_raw, str):
                    feat_names = _coerce_feature_tuple(feats_raw)
                if feat_names:
                    inter_list.append((feat_names, val))

    # Dict shape, e.g. {"(0,2)": 528.0, ...}
    if not inter_list and isinstance(raw_interactions, dict):
        metadata_keys = {"method", "order", "scalarizer", "embedding_model"}
        for k, v in raw_interactions.items():
            k_str = str(k)
            if k_str in metadata_keys:
                continue
            val = None
            if isinstance(v, (int, float)):
                val = float(v)
            elif isinstance(v, dict):
                val = _first_float_field(v)
            if val is None:
                continue
            idxs = [int(x) for x in re.findall(r"\d+", k_str)]
            names = [features[idx] for idx in idxs if 0 <= idx < n_features]
            feat_names = tuple(names) if names else _coerce_feature_tuple(k_str)
            inter_list.append((feat_names, val))

    # Flatten numerics arbitrarily (last resort)
    if not inter_list and raw_interactions is not None:
        flat: List[Tuple[Tuple[str, ...], float]] = []

        def _collect(obj: Any, prefix: Tuple[str, ...] = ()) -> None:
            if isinstance(obj, (int, float)):
                flat.append((prefix or ("",), float(obj)))
            elif isinstance(obj, list):
                for i, item in enumerate(obj):
                    _collect(item, prefix + (f"[{i}]",))
            elif isinstance(obj, dict):
                for kk, vv in obj.items():
                    _collect(vv, prefix + (str(kk),))

        _collect(raw_interactions)
        inter_list = flat

    return inter_list


def _labels_from_regions(regions: List[Dict[str, Any]]) -> List[str]:
    """Build one label per region, placed at each region's ``index``.

    Unset slots get "Region <n>" (1-based). Regions with an invalid or
    out-of-range index are ignored.
    """
    labels: List[str] = [""] * len(regions)
    for region in regions:
        try:
            idx = int(region.get("index", 0))
        except (AttributeError, TypeError, ValueError):
            continue
        if idx < 0 or idx >= len(labels):
            continue
        labels[idx] = str(region.get("label") or f"Region {idx + 1}")
    for idx, label in enumerate(labels):
        if not label:
            labels[idx] = f"Region {idx + 1}"
    return labels


def _labels_for_indices(indices: Any, labels: List[str]) -> List[str]:
    """Map *indices* to *labels*, dropping non-integer or out-of-range entries.

    Negative indices are rejected rather than wrapped to the end of the list.
    """
    out: List[str] = []
    for raw in indices:
        try:
            idx = int(raw)
        except (TypeError, ValueError):
            continue
        if 0 <= idx < len(labels):
            out.append(labels[idx])
    return out


def _interaction_dicts_to_pairs(
    interactions: List[Dict[str, Any]],
    labels: List[str],
    *,
    order: int | None = None,
) -> List[Tuple[Tuple[str, ...], float]]:
    """Convert ``{"indices", "value"}`` dicts to (label tuple, value) pairs.

    When *order* is given, only interactions with exactly *order* indices are
    kept. Entries with no indices or a non-numeric value are skipped.
    """
    pairs: List[Tuple[Tuple[str, ...], float]] = []
    for item in interactions:
        indices = item.get("indices")
        if not indices:
            continue
        if order is not None and len(indices) != order:
            continue
        try:
            value = float(item.get("value", 0.0))
        except (TypeError, ValueError, OverflowError):
            continue
        feats = tuple(_labels_for_indices(indices, labels))
        if feats:
            pairs.append((feats, value))
    return pairs


def _interaction_dicts_to_table(
    interactions: List[Dict[str, Any]],
    labels: List[str],
) -> List[List[Any]]:
    """Convert ``{"indices", "value"}`` dicts to ``[joined labels, value, order]`` table rows."""
    rows: List[List[Any]] = []
    for item in interactions:
        indices = item.get("indices")
        if not indices:
            continue
        try:
            value = float(item.get("value", 0.0))
        except (TypeError, ValueError, OverflowError):
            continue
        feats = _labels_for_indices(indices, labels)
        if feats:
            rows.append([" × ".join(feats), value, len(indices)])
    return rows


def _feature_display_label(
    feature: Dict[str, Any],
    region_labels: List[str],
) -> str:
    """Return a display label for a multimodal feature entry.

    Image features use ``region_labels[ref_index]`` when that index is valid.
    Otherwise the text after the first ``:`` in ``feature["feature"]`` is used,
    or the whole string when there is no ``:``.
    """
    raw = str(feature.get("feature", ""))
    modality = feature.get("modality") or ""
    try:
        ref_index = int(feature.get("ref_index", 0))
    except (TypeError, ValueError):
        ref_index = -1  # unusable index: fall through to the text label
    label = raw.split(":", 1)[1] if ":" in raw else raw
    if modality == "image" and 0 <= ref_index < len(region_labels):
        return region_labels[ref_index]
    return label or raw


def _extract_feature_series(payload: Dict[str, Any]) -> Tuple[List[str], List[float]]:
    """
    Try to recover an ordered pair of (feature labels, values) from a backend payload.

    Sources are tried in order:
    1. A ``features`` list of dicts.
    2. ``heatmap`` tokens/values of equal length.
    3. The marginals (see ``_resolve_marginals``).

    This keeps duplicates in order (appending suffixes later) so word-level
    tokens don't collapse to a single entry. Returns ([], []) when nothing is
    found.
    """
    features: List[str] = []
    values: List[float] = []

    feature_entries = payload.get("features")
    if isinstance(feature_entries, list) and feature_entries and isinstance(feature_entries[0], dict):
        for idx, entry in enumerate(feature_entries, start=1):
            raw_feat = (
                entry.get("feature")
                or entry.get("token")
                or entry.get("text")
                or entry.get("label")
                or ""
            )
            if not raw_feat:
                raw_feat = f"feature_{idx}"
            val = entry.get("value")
            if val is None:
                for key in ("score", "attribution", "weight"):
                    if key in entry:
                        val = entry[key]
                        break
            try:
                values.append(float(val if val is not None else 0.0))
            except Exception:
                values.append(0.0)
            features.append(str(raw_feat))

    if not features:
        heat = payload.get("heatmap")
        if not isinstance(heat, dict):
            heat = {}
        tokens = heat.get("tokens") or heat.get("features")
        scores = heat.get("values") or heat.get("scores")
        if isinstance(tokens, list) and isinstance(scores, list) and len(tokens) == len(scores):
            features = [
                str(token if token is not None else f"feature_{idx + 1}")
                for idx, token in enumerate(tokens)
            ]
            tmp_vals: List[float] = []
            for score in scores:
                try:
                    tmp_vals.append(float(score))
                except Exception:
                    tmp_vals.append(0.0)
            values = tmp_vals

    if not features:
        marginals = _resolve_marginals(payload)
        if marginals:
            features = list(marginals.keys())
            values = [float(marginals[key]) for key in features]

    if not features:
        return [], []

    return _assign_unique_labels(features), values


def _resolve_interactions(payload: Dict[str, Any], order: int) -> List[Tuple[Tuple[str, ...], float]]:
    """Return the first non-empty interaction list stored for *order*.

    Checks ``interactions_<order>`` first, then order-specific aliases.
    """
    candidates = [f"interactions_{order}"]
    if order == 2:
        candidates += ["pairwise", "pairwise_interactions", "interactions2"]
    elif order == 3:
        candidates += ["higher_order", "triple_interactions", "interactions3"]
    for key in candidates:
        normalized = _normalize_interactions(payload.get(key))
        if normalized:
            return normalized
    return []


def _fallback_pairwise_from_values(
    features: List[str],
    values: List[float],
    max_edges: int = 40,
) -> List[Tuple[Tuple[str, ...], float]]:
    """
    Generate synthetic pairwise links by connecting neighboring tokens.

    Used when the backend provides no explicit interactions. Each edge's weight
    is the mean of its two endpoint values. The *max_edges* strongest edges by
    |weight| are returned.
    """
    n = min(len(features), len(values))
    if n < 2:
        return []
    edges: List[Tuple[Tuple[str, ...], float]] = [
        ((features[idx], features[idx + 1]), 0.5 * (values[idx] + values[idx + 1]))
        for idx in range(n - 1)
    ]
    edges.sort(key=lambda item: abs(item[1]), reverse=True)
    return edges[:max_edges]


def _resolve_pairwise(
    payload: Dict[str, Any],
    features: Optional[List[str]] = None,
    feature_values: Optional[List[float]] = None,
) -> List[Tuple[Tuple[str, ...], float]]:
    """Convenience helper to always pull order-2 interactions if present.

    Falls back, in order, to:
    1. The 2-feature entries of a mixed ``interactions`` list.
    2. Neighbour links synthesized from *features*/*feature_values*.
    """
    pairwise = _resolve_interactions(payload, 2)
    if pairwise:
        return pairwise
    # Some payloads store generic "interactions" lists that mix orders.
    normalized = _normalize_interactions(payload.get("interactions"))
    if normalized:
        return [item for item in normalized if len(item[0]) == 2]
    if features and feature_values:
        return _fallback_pairwise_from_values(features, feature_values)
    return []


def _normalize_method(method: Optional[str]) -> str:
    """Lower-case *method*; unknown or empty values become "shapley"."""
    method = (method or "shapley").lower()
    return method if method in _ALLOWED_METHODS else "shapley"


def _normalize_level(level: Optional[str]) -> str:
    """Lower-case *level*; unknown or empty values become "sentence"."""
    level = (level or "sentence").lower()
    return level if level in _ALLOWED_LEVELS else "sentence"


def _normalize_model_size(model_size: Optional[str]) -> str:
    """Map a free-form model-size string to small/medium/large.

    Substrings are checked in that order, so "Qwen-small" becomes "small".
    The default is "small".
    """
    lowered = (model_size or "small").strip().lower()
    for size in ("small", "medium", "large"):
        if size in lowered:
            return size
    return "small"


def _assign_unique_labels(chunks: List[str]) -> List[str]:
    """Collapse whitespace in each chunk and suffix repeats with " (n)".

    Blank chunks get a positional placeholder ``feature_<n>`` (1-based, matching
    ``_extract_feature_series``), so no label is ever empty.
    """
    counts: Dict[str, int] = {}
    labels: List[str] = []
    for idx, chunk in enumerate(chunks):
        normalized = " ".join((chunk or "").split())
        if not normalized:
            normalized = f"feature_{idx + 1}"
        counts[normalized] = counts.get(normalized, 0) + 1
        suffix = f" ({counts[normalized]})" if counts[normalized] > 1 else ""
        labels.append(f"{normalized}{suffix}")
    return labels


def _strip_occurrence_suffix(text: str) -> str:
    """Remove a trailing " (n)" occurrence suffix added by ``_assign_unique_labels``."""
    text = text or ""
    if text.endswith(")") and " (" in text:
        base, _, tail = text.rpartition(" (")
        if tail[:-1].isdigit():
            return base
    return text


def _pairwise_to_index_interactions(
    pairwise: List[Tuple[Tuple[str, ...], float]],
    features: List[str],
) -> List[Dict[str, Any]]:
    """Convert (feature pair, value) tuples to ``{"indices": [i, j], "value"}`` dicts.

    Each pair member is resolved to an index as follows:
    1. A number, or a string that parses as an int, is used as the index itself.
    2. Otherwise the exact label is looked up in *features*.
    3. Failing that, the first feature with the same base label (suffix
       stripped) is used.

    Pairs that cannot be resolved to in-range indices are dropped.
    """
    feature_index = {feat: idx for idx, feat in enumerate(features)}
    base_index: Dict[str, int] = {}
    for idx, feat in enumerate(features):
        base_index.setdefault(_strip_occurrence_suffix(feat), idx)

    def _lookup(name: Any) -> Optional[int]:
        # Explicit None check: index 0 is a valid hit and must not fall through.
        hit = feature_index.get(name)
        if hit is None:
            hit = base_index.get(_strip_occurrence_suffix(str(name)))
        return hit

    interactions: List[Dict[str, Any]] = []
    for feats, val in pairwise:
        if len(feats) != 2:
            continue
        a, b = feats
        if isinstance(a, (int, float)) and isinstance(b, (int, float)):
            a_idx: Optional[int] = int(a)
            b_idx: Optional[int] = int(b)
        else:
            try:
                a_idx = int(str(a))
                b_idx = int(str(b))
            except ValueError:
                a_idx = _lookup(a)
                b_idx = _lookup(b)
        if a_idx is None or b_idx is None:
            continue
        if not (0 <= a_idx < len(features) and 0 <= b_idx < len(features)):
            continue
        interactions.append({"indices": [a_idx, b_idx], "value": float(val)})
    return interactions


def _locate_spans(text: str, segments: List[str]) -> List[Tuple[int, int]]:
    """Find consecutive character spans of *segments* in *text*, scanning left to right.

    Empty segments are skipped. A segment that is not found is placed at the
    current cursor (best effort).
    """
    spans: List[Tuple[int, int]] = []
    cursor = 0
    for segment in segments:
        if not segment:
            continue
        idx = text.find(segment, cursor)
        if idx == -1:
            idx = cursor
        end = idx + len(segment)
        spans.append((idx, end))
        cursor = end
    return spans


def _chunk_text_for_visualization(
    context: str,
    level: str,
) -> Tuple[List[str], List[Tuple[int, int]], str]:
    """
    Split input text into feature chunks and spans for visualization.

    *level* is normalized first:
    - word: whitespace-delimited tokens.
    - paragraph: blocks separated by blank lines.
    - sentence (the default): splits after . ! or ?

    Falls back to the demo text if context is empty. Returns
    (unique chunk labels, character spans, the text actually chunked).
    """
    text = context or _DEMO_TEXT
    level = _normalize_level(level)

    if level == "word":
        matches = list(re.finditer(r"\S+", text))
        chunks = [m.group(0) for m in matches]
        spans = [(m.start(), m.end()) for m in matches]
    else:
        pattern = r"\n\s*\n+" if level == "paragraph" else r"(?<=[.!?])\s+"
        parts = [seg for seg in re.split(pattern, text) if seg.strip()]
        spans = _locate_spans(text, parts)
        chunks = parts[: len(spans)]

    if not chunks:
        chunks = [text]
        spans = [(0, len(text))]

    return _assign_unique_labels(chunks), spans, text


def _generate_synthetic_marginals(
    features: List[str],
    rng: random.Random,
) -> Dict[str, float]:
    """Produce mock per-feature scores for demo mode.

    Each score combines three terms, rounded to 4 decimal places:
    - longer labels score higher;
    - earlier positions score higher;
    - uniform noise in ±0.25 drawn from *rng*.
    """
    if not features:
        return {}
    max_len = max(len(f) for f in features) or 1
    marginals: Dict[str, float] = {}
    denom = max(1, len(features) - 1)
    for idx, feat in enumerate(features):
        length_factor = len(feat) / max_len
        position_factor = 1 - (idx / denom if denom else 0)
        noise = rng.uniform(-0.25, 0.25)
        value = (length_factor - 0.5) * 0.6 + (position_factor - 0.5) * 0.4 + noise
        marginals[feat] = round(value, 4)
    return marginals


def _generate_synthetic_interactions(
    features: List[str],
    marginals: Dict[str, float],
    rng: random.Random,
) -> Dict[int, List[Tuple[Tuple[str, ...], float]]]:
    """Produce mock order-2 and order-3 interactions between adjacent features.

    Each value is the mean of the members' marginals plus noise in ±0.1.
    """
    interactions: Dict[int, List[Tuple[Tuple[str, ...], float]]] = {2: [], 3: []}
    for i in range(len(features) - 1):
        pair = (features[i], features[i + 1])
        base = (marginals.get(pair[0], 0.0) + marginals.get(pair[1], 0.0)) / 2
        interactions[2].append((pair, round(base + rng.uniform(-0.1, 0.1), 4)))
    for i in range(len(features) - 2):
        triple = (features[i], features[i + 1], features[i + 2])
        base = sum(marginals.get(feat, 0.0) for feat in triple) / 3
        interactions[3].append((triple, round(base + rng.uniform(-0.1, 0.1), 4)))
    return interactions


def _synthetic_attribution_pipeline(
    context: str,
    prompt: str,
    answer: str,
    *,
    method: str,
    level: str,
    order: int,
    reason: Optional[str] = None,
) -> Tuple[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]:
    """Render mock attribution results when the real backend is unavailable.

    The RNG is seeded from (text, method, level, order) with a stable CRC32
    digest. The built-in ``hash()`` of strings is salted per process, which
    would make the "seeded" mock output change between runs.

    Returns whatever ``update(...)`` from ``components.results_display`` returns.
    """
    import zlib

    text_source = context or prompt or answer or _DEMO_TEXT
    features, spans, text = _chunk_text_for_visualization(text_source, level)
    seed_key = repr((text_source, method, level, order)).encode("utf-8", "surrogatepass")
    rng = random.Random(zlib.crc32(seed_key) & 0xFFFFFFFF)
    marginals = _generate_synthetic_marginals(features, rng)
    interactions = _generate_synthetic_interactions(features, marginals, rng)
    feature_values = [marginals.get(f, 0.0) for f in features]

    html = None
    if len(spans) == len(features):
        html = create_interactive_text_heatmap(
            text,
            spans,
            feature_values,
            method=method,
        )

    meta = {
        "mode": "synthetic",
        "reason": reason or "Attribution backend unavailable; showing mock data.",
        "method": method,
        "feature_level": level,
        "interaction_order": order,
        "feature_count": len(features),
    }

    inter_list = interactions.get(order, [])
    pairwise_for_tokens = interactions.get(2, []) if order != 2 else inter_list
    if not pairwise_for_tokens:
        pairwise_for_tokens = _fallback_pairwise_from_values(features, feature_values)

    text_interaction_html = create_text_interaction_html(
        features,
        feature_values,
        _pairwise_to_index_interactions(pairwise_for_tokens, features),
        method=method,
        top_k=20,
        threshold=0.0,
    )

    figs = {
        "interactions": plot_top_interactions(inter_list, order=order, method=method),
    }

    return update(
        figs=figs,
        meta=meta,
        html=html,
        interaction_text_html=text_interaction_html,
        scoring_target_source="answer_input" if answer else "model_output",
        scoring_target_text=answer or "",
        reference_answer=answer or "",
        unmasked_answer="",
        debug_scores=None,
        scalarizer_used="logprob",
        score_full=None,
        score_empty=None,
        y_len_tokens=None,
    )


# def _compute_live_attributions(**kwargs) -> Tuple[Any, Any, Any, Any, Any]:
#     """
#     Placeholder for the real ProxySPEX + perplexity pipeline.
#     Raises until the attribution backend is implemented.
# """ # missing = [ # name # for name, fn in { # "get_model": get_model, # "get_masker": get_masker, # "mask_text": mask_text, # "run_proxyspex": run_proxyspex, # "mobius_to_shapley": mobius_to_shapley, # "mobius_to_banzhaf": mobius_to_banzhaf, # "shapley_interactions": shapley_interactions, # "banzhaf_interactions": banzhaf_interactions, # }.items() # if fn is None # ] # if missing: # raise RuntimeError( # "Missing backend dependencies: " + ", ".join(sorted(missing)) # ) # raise NotImplementedError( # "Live attribution pipeline not wired yet. Integrate once ProxySPEX is ready." # ) def _compute_live_attributions( *, context: str, prompt: str, correct_answer: str, model_size: str, level: str, method: str, order: int, scalarizer: str = "logprob", embedding_model: str | None = None, progress=None, ) -> Tuple[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]: """ Call the FastAPI /api/attributions + /api/interactions backends and turn the JSON into figures / table / HTML for Gradio. This version is very defensive and tries hard to extract interactions from whatever shape the backend returns. """ method = _normalize_method(method) level = _normalize_level(level) order = 3 if int(order or 2) >= 3 else 2 context = context or "" prompt = prompt or "" correct_answer = correct_answer or "" text_source = context or prompt or correct_answer or _DEMO_TEXT payload = { "context": context, "answer": correct_answer, "reference_answer": correct_answer, "prompt": prompt, "method": method, "mask_level": level, "order": int(order), "model_size": model_size, "scalarizer": scalarizer, "embedding_model": embedding_model, "debug": False, } if progress is not None: progress(0.1, desc="Calling attribution backend") # ---------- 1. 
/api/attributions ---------- url_attr = BACKEND_URL.rstrip("/") + "/api/attributions" try: resp_attr = requests.post(url_attr, json=payload, timeout=REQUEST_TIMEOUT) except requests.exceptions.ReadTimeout as exc: raise gr.Error( "Attribution request timed out. The backend may still be running. " "Consider reducing feature granularity or set ATTRLLM_REQUEST_TIMEOUT to a higher value." ) from exc if resp_attr.status_code >= 400: _raise_backend_error(resp_attr, "Attribution request") data_attr = resp_attr.json() if progress is not None: progress(0.35, desc="Received attribution payload") # ---------- 2. FEATURES + MARGINAL VALUES ---------- features, feature_values = _extract_feature_series(data_attr) if not features: features = [""] feature_values = [0.0] marginals = {feat: float(feature_values[idx]) for idx, feat in enumerate(features)} # ---------- 3. /api/interactions ---------- if progress is not None: progress(0.45, desc="Calling interactions backend") url_int = BACKEND_URL.rstrip("/") + "/api/interactions" try: resp_int = requests.post(url_int, json=payload, timeout=REQUEST_TIMEOUT) except requests.exceptions.ReadTimeout as exc: raise gr.Error( "Interaction request timed out. The backend may still be running. " "Consider reducing order or set ATTRLLM_REQUEST_TIMEOUT to a higher value." 
) from exc if resp_int.status_code >= 400: _raise_backend_error(resp_int, "Interaction request") data_int = resp_int.json() # DEBUG: see top-level keys print("data_int keys:", list(data_int.keys())) inter_list_all = _extract_interactions_from_response(data_int, method, features) pairwise_for_network = [item for item in inter_list_all if len(item[0]) == 2] used_pairwise_fallback = False inter_list = inter_list_all if inter_list: filtered: List[Tuple[Tuple[str, ...], float]] = [] for feats, val in inter_list: if len(feats) == order: filtered.append((feats, val)) if filtered: inter_list = filtered if order != 2 and not pairwise_for_network: try: payload_pair = dict(payload) payload_pair["order"] = 2 try: resp_pair = requests.post(url_int, json=payload_pair, timeout=REQUEST_TIMEOUT) except requests.exceptions.ReadTimeout as exc: raise gr.Error( "Interaction request timed out. The backend may still be running. " "Consider reducing order or set ATTRLLM_REQUEST_TIMEOUT to a higher value." ) from exc if resp_pair.status_code >= 400: _raise_backend_error(resp_pair, "Interaction request") data_pair = resp_pair.json() pairwise_for_network = [ item for item in _extract_interactions_from_response(data_pair, method, features) if len(item[0]) == 2 ] except Exception as exc: print("Pairwise interaction fetch failed:", exc) if not pairwise_for_network: if method == "influence": pairwise_for_network = [] else: used_pairwise_fallback = True pairwise_for_network = _fallback_pairwise_from_values(features, feature_values) print("LIVE features:", features) print("LIVE inter_list (first 3):", inter_list[:3]) if method == "influence": top_singletons = sorted( list(zip(features, feature_values)), key=lambda kv: abs(float(kv[1])), reverse=True, )[:10] top_pairs = sorted( pairwise_for_network, key=lambda kv: abs(float(kv[1])), reverse=True, )[:10] print( "[influence-ui-debug] " f"pairwise_source={'fallback_neighbors' if used_pairwise_fallback else 'backend'} " f"feature_count={len(features)} 
pair_count={len(pairwise_for_network)}" , flush=True) print("[influence-ui-debug] top_singletons:", top_singletons, flush=True) print("[influence-ui-debug] top_pairwise:", top_pairs, flush=True) text_interaction_html = create_text_interaction_html( features, feature_values, _pairwise_to_index_interactions(pairwise_for_network, features), method=method, top_k=20, threshold=0.0, ) # ---------- 4. RESCALE VERY SMALL VALUES ---------- max_abs = max((abs(v) for v in marginals.values()), default=0.0) scale = 1.0 if 0 < max_abs < 1e-3: scale = 1e3 if scale != 1.0: marginals = {k: v * scale for k, v in marginals.items()} inter_list = [(feats, val * scale) for feats, val in inter_list] feature_values = [val * scale for val in feature_values] # ---------- 5. INLINE TEXT HEATMAP ---------- spans = None masking = data_attr.get("masking") or data_attr.get("mask") or {} if isinstance(masking, dict): spans = masking.get("feature_spans") or masking.get("spans") html = None if spans and len(spans) == len(feature_values): html = create_interactive_text_heatmap( context or text_source, spans, feature_values, method=method, ) # ---------- 6. 
PLOTS + TABLES + META ---------- inter_fig = plot_top_interactions(inter_list, order=order, method=method) if progress is not None: progress(0.8, desc="Rendering visualizations") y_len_tokens = data_attr.get("y_len_tokens") scoring_target_source = data_attr.get("scoring_target_source") or "model_output" scoring_target_text = data_attr.get("scoring_target_text") if scoring_target_text is None: scoring_target_text = correct_answer or data_attr.get("y_full") or "" meta = { "mode": "live", "backend_url_attr": url_attr, "backend_url_int": url_int, "method": method, "feature_level": level, "interaction_order": order, "model_size": model_size, "feature_count": len(features), "max_abs_value": max_abs, "scale_applied": scale, "scalarizer": data_attr.get("scalarizer_used", payload.get("scalarizer")), "scoring_target_source": scoring_target_source, "scoring_target_text_preview": str(scoring_target_text)[:200], "score_full": data_attr.get("score_full"), "score_empty": data_attr.get("score_empty"), "y_len_tokens": y_len_tokens, "logprob_full": data_attr.get("logprob_full"), "logprob_empty": data_attr.get("logprob_empty"), "min_logprob_seen": data_attr.get("min_logprob_seen"), "reference_answer_received": data_attr.get("reference_answer_received"), "answer_received": data_attr.get("answer_received"), "raw_attr_keys": list(data_attr.keys()), "raw_int_keys": list(data_int.keys()), } reference_answer = correct_answer unmasked_answer = data_attr.get("y_full") or data_attr.get("unmasked_answer") or "" debug_scores = data_attr.get("debug_scores") or None interaction_chips_html = create_interaction_token_view( features, feature_values, pairwise_for_network, method=method, layout="sentence" if level == "sentence" else "token", ) figs = { "interactions": inter_fig, } if progress is not None: progress(1.0, desc="Done") return update( figs=figs, meta=meta, html=html, interaction_html=interaction_chips_html, interaction_text_html=text_interaction_html, 
scoring_target_source=scoring_target_source, scoring_target_text=str(scoring_target_text), reference_answer=reference_answer, unmasked_answer=unmasked_answer, debug_scores=debug_scores, scalarizer_used=data_attr.get("scalarizer_used", payload.get("scalarizer")), score_full=data_attr.get("score_full"), score_empty=data_attr.get("score_empty"), y_len_tokens=y_len_tokens, ) # ═══════════════════════════════════════════════════════════════════════════ # CLIP-based live compute helpers (Custom Image / Custom Multimodal tabs) # ═══════════════════════════════════════════════════════════════════════════ _CLIP_MODEL_MAP: Dict[str, str] = { "CLIP (openai/clip-vit-base-patch32)": "openai/clip-vit-base-patch32", "BiomedCLIP": "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224", } def _get_clip_scorer(model_display: str) -> "CrossModalCLIPScorer": """Load or return cached CLIP scorer. Apply dot-mask fix.""" model_name = _CLIP_MODEL_MAP.get(model_display, model_display) if model_name in _clip_scorer_cache: return _clip_scorer_cache[model_name] import torch as _torch device = "cuda" if _torch.cuda.is_available() else "cpu" cfg = PipelineConfig(clip_model_name=model_name, device=device) scorer = CrossModalCLIPScorer(cfg) # Dot-mask fix: use "." (ID 269) instead of EOS to avoid CLIP argmax-pooling shift _neutral_ids = scorer.processor.tokenizer.encode(".", add_special_tokens=False) if _neutral_ids: scorer.unk_token_id = _neutral_ids[0] _clip_scorer_cache[model_name] = scorer return scorer def _run_clip_attribution( image: Image.Image, caption: str, clip_model: str, seg_mode: str, grid_size: int, method: str, seed: int, progress=None, ) -> Dict[str, Any]: """ Core CLIP cross-modal attribution pipeline shared by both custom tabs. Returns a dict with regions, token_players, values, interactions, overlay images, masked images, and influence matrix. """ import numpy as np if not _CLIP_PIPELINE_AVAILABLE: raise gr.Error( "CLIP pipeline not available. 
Ensure attribution.set_mm is importable " "(requires transformers, lightgbm, numpy, scipy)." ) method = _normalize_method(method) # Check LaMa availability, fall back to blur try: from simple_lama_inpainting import SimpleLama # noqa: F401 mask_style = "lama" except ImportError: mask_style = "blur" import torch as _torch device = "cuda" if _torch.cuda.is_available() else "cpu" model_name = _CLIP_MODEL_MAP.get(clip_model, clip_model) cfg = PipelineConfig( mode="patch" if seg_mode == "Patch Grid" else "unsam", grid_size=int(grid_size), mask_style=mask_style, clip_model_name=model_name, max_tokens=15, method=method, max_order=2, top_k_interactions=15, random_seed=int(seed), device=device, ) if progress is not None: progress(0.05, desc="Loading CLIP model...") scorer = _get_clip_scorer(clip_model) # Step 1: Featurise image if progress is not None: progress(0.10, desc="Segmenting image...") try: regions = featurise(image, cfg) except Exception as exc: if seg_mode != "Patch Grid": raise gr.Error( f"UnSAM segmentation failed: {exc}. " "Try using 'Patch Grid' instead." 
) from exc raise # Step 2: Tokenise caption if progress is not None: progress(0.15, desc="Tokenising caption...") token_players, full_token_ids = tokenise_caption( caption, scorer.processor, cfg, offset=len(regions) ) n_img = len(regions) n_tok = len(token_players) n_total = n_img + n_tok # Step 3: Build cross-modal set function if progress is not None: progress(0.20, desc="Building set function...") game = build_cross_modal_set_function( image, regions, token_players, full_token_ids, scorer, cfg ) # Step 4: Run ProxySPEX (run_proxyspex wraps the set function for 2D batch calls) _raw_labels = [r.label for r in regions] + [tp.label for tp in token_players] # all_labels is rebuilt after tok_vals disambiguation below; _raw_labels for ProxySPEX if progress is not None: progress(0.25, desc=f"Running ProxySPEX (n={n_total})...") mobius = run_proxyspex(game, _raw_labels, max_order=2, seed=int(seed)) # Step 5: Derive Shapley/Banzhaf values if progress is not None: progress(0.70, desc="Computing values...") if method == "banzhaf": values = mobius_to_banzhaf(mobius) else: values = mobius_to_shapley(mobius) # Split into image and token values # Disambiguate duplicate labels (e.g., two "the" tokens) by appending #N img_vals = {regions[i].label: float(values.get((i,), 0.0)) for i in range(n_img)} tok_vals = {} _tok_label_counts: Dict[str, int] = {} for j, tp in enumerate(token_players): label = tp.label count = _tok_label_counts.get(label, 0) _tok_label_counts[label] = count + 1 key = f"{label}#{count}" if count > 0 else label tok_vals[key] = float(values.get((n_img + j,), 0.0)) # Rebuild all_labels with disambiguated token labels all_labels = list(img_vals.keys()) + list(tok_vals.keys()) # Step 6: Extract interactions interactions = extract_interactions(mobius, order=2, top_k=15) cross_per_token, cross_global_top5 = extract_cross_per_token(mobius, n_img, n_tok) # Image-image and token-token interactions img_filter = lambda loc: all(i < n_img for i in loc) tok_filter = lambda 
loc: all(i >= n_img for i in loc) interactions_img = extract_interactions(mobius, order=2, top_k=10, player_filter=img_filter) interactions_tok = extract_interactions(mobius, order=2, top_k=10, player_filter=tok_filter) # Cross-modal interactions (for bar chart) cross_filter = lambda loc: any(i < n_img for i in loc) and any(i >= n_img for i in loc) cross_interactions = extract_interactions(mobius, order=2, top_k=15, player_filter=cross_filter) # Step 7: Build influence matrix [n_img x n_tok] influence_matrix = np.zeros((n_img, n_tok)) for loc, val in cross_interactions: img_indices = [i for i in loc if i < n_img] tok_indices = [i - n_img for i in loc if i >= n_img] for ii in img_indices: for tj in tok_indices: if 0 <= ii < n_img and 0 <= tj < n_tok: influence_matrix[ii, tj] += float(val) # Step 8: Render overlay and segmap if progress is not None: progress(0.75, desc="Rendering overlay...") img_val_list = [float(values.get((i,), 0.0)) for i in range(n_img)] overlay_rgba = render_overlay(image, regions, img_val_list) base_rgba = image.convert("RGBA") overlay_img = Image.alpha_composite(base_rgba, overlay_rgba).convert("RGB") overlay_b64 = _encode_image_to_b64(overlay_img) segmap_img = render_segmentation_map(image, regions) segmap_b64 = _encode_image_to_b64(segmap_img) # Step 9: Build segment bboxes (% coordinates for interactive view) w, h = image.size segment_bboxes = [] for reg in regions: x0, y0, x1, y1 = reg.bbox segment_bboxes.append({ "x0_pct": 100.0 * x0 / w, "y0_pct": 100.0 * y0 / h, "w_pct": 100.0 * (x1 - x0) / w, "h_pct": 100.0 * (y1 - y0) / h, "cx_pct": 100.0 * (x0 + x1) / 2 / w, "cy_pct": 100.0 * (y0 + y1) / 2 / h, }) # Step 10: Generate masked images for browser if progress is not None: progress(0.80, desc="Generating masked images...") masked_images: Dict[str, Image.Image] = {} for i, reg in enumerate(regions): # "removed" — mask only this region coal_removed = [1] * n_img coal_removed[i] = 0 removed_img = apply_image_mask( image, regions, 
coal_removed, style=cfg.mask_style, blur_radius=cfg.blur_radius, cfg=cfg, ) masked_images[f"{reg.label} removed"] = removed_img if progress is not None: progress(0.90, desc="Done computing.") return { "regions": regions, "token_players": token_players, "all_labels": all_labels, "image_values": img_vals, "token_values": tok_vals, "values": values, "mobius": mobius, "interactions": interactions, "interactions_img": interactions_img, "interactions_tok": interactions_tok, "cross_interactions": cross_interactions, "cross_per_token": cross_per_token, "cross_global_top5": cross_global_top5, "influence_matrix": influence_matrix, "overlay_img": overlay_img, "overlay_b64": overlay_b64, "segmap_img": segmap_img, "segmap_b64": segmap_b64, "segment_bboxes": segment_bboxes, "masked_images": masked_images, "method": method, "n_img": n_img, "n_tok": n_tok, "mask_style": mask_style, "seg_mode": seg_mode, "grid_size": int(grid_size), } def _build_masked_choices(masked_images: Dict[str, Image.Image]) -> List[str]: """Return sorted list of masked image choice labels.""" return sorted(masked_images.keys()) def _on_masked_image_select(choice: str, state: Dict) -> Optional[Image.Image]: """Return the masked PIL image for a dropdown choice.""" if not state or not choice: return None return state.get(choice) def _compute_image_attributions_clip( image: Image.Image, caption: str, clip_model: str, seg_mode: str, grid_size: int, method: str, seed: int, progress=None, ): """Compute image-only attributions using CLIP pipeline. 
Returns UI outputs.""" if image is None: raise gr.Error("Please upload an image.") if not caption or not caption.strip(): raise gr.Error("Please provide a caption or description.") result = _run_clip_attribution( image, caption.strip(), clip_model, seg_mode, int(grid_size), method, int(seed or 0), progress=progress, ) # Build region bar chart seg_labels = list(result["image_values"].keys()) seg_vals = list(result["image_values"].values()) region_chart = create_shapley_bar_chart(seg_labels, seg_vals, "Region Attribution") # Build masked image state and dropdown choices masked_state = result["masked_images"] choices = _build_masked_choices(masked_state) meta = { "mode": "image_clip", "method": result["method"], "clip_model": clip_model, "seg_mode": result["seg_mode"], "grid_size": result["grid_size"], "mask_style": result["mask_style"], "n_regions": result["n_img"], "n_tokens": result["n_tok"], } if progress is not None: progress(1.0, desc="Done") # Returns: original_img, overlay_img, region_chart, masked_dropdown, masked_img, masked_state, meta return ( image, result["overlay_img"], region_chart, gr.update(choices=choices, value=choices[0] if choices else None), masked_state.get(choices[0]) if choices else None, masked_state, meta, ) def _compute_mm_attributions_clip( image: Image.Image, caption: str, clip_model: str, seg_mode: str, grid_size: int, method: str, seed: int, progress=None, ): """Compute cross-modal attributions using CLIP pipeline. 
Returns UI outputs.""" import numpy as np if image is None: raise gr.Error("Please upload an image.") if not caption or not caption.strip(): raise gr.Error("Please provide a caption or description.") result = _run_clip_attribution( image, caption.strip(), clip_model, seg_mode, int(grid_size), method, int(seed or 0), progress=progress, ) all_labels = result["all_labels"] n_img = result["n_img"] n_tok = result["n_tok"] # Region bar chart seg_labels = list(result["image_values"].keys()) seg_vals = list(result["image_values"].values()) region_chart = create_shapley_bar_chart(seg_labels, seg_vals, "Region Attribution") # Token bar chart tok_labels = list(result["token_values"].keys()) tok_vals = list(result["token_values"].values()) token_chart = create_shapley_bar_chart(tok_labels, tok_vals, "Token Attribution") # Cross-modal bar chart — expects List[Tuple[Tuple[str, str], float]] cross_pairs = [] for loc, val in result["cross_interactions"]: img_parts = [all_labels[i] for i in loc if i < n_img] tok_parts = [all_labels[i] for i in loc if i >= n_img] if img_parts and tok_parts: cross_pairs.append(((img_parts[0], tok_parts[0]), float(val))) cross_chart = create_cross_modal_bar_chart(cross_pairs, "Cross-Modal Interactions", top_k=15) # Influence heatmap heatmap = create_influence_heatmap( seg_labels, tok_labels, result["influence_matrix"], "Influence Heatmap (Regions x Tokens)" ) # Interactive cross-modal HTML view # Build clip_summary dict matching what benchmark_interaction expects clip_summary = { "image_region_values": [ {"label": seg_labels[i], "value": float(seg_vals[i])} for i in range(n_img) ], "token_values": [ {"label": tok_labels[j], "value": float(tok_vals[j])} for j in range(n_tok) ], "cross_modal_interactions": [ {"label": " x ".join(all_labels[i] for i in loc), "value": float(val)} for loc, val in result["cross_global_top5"] ], } image_b64 = _encode_image_to_b64(image) interaction_html = create_benchmark_interaction_html( image_b64=image_b64, 
clip_summary=clip_summary, vllm_logprob=None, caption=caption, all_cross_modal_pairs=[ { "pair": ( all_labels[loc[0]] if loc[0] < n_img else all_labels[loc[1]], all_labels[loc[1]] if loc[1] >= n_img else all_labels[loc[0]], ), "value": float(val), } for loc, val in result["cross_interactions"] ], segmap_b64=result["segmap_b64"], overlay_b64=result["overlay_b64"], segment_bboxes=result["segment_bboxes"], label_map_b64="", image_width=image.size[0], image_height=image.size[1], title="Cross-Modal Interaction View", ) # Masked image state masked_state = result["masked_images"] choices = _build_masked_choices(masked_state) meta = { "mode": "multimodal_clip", "method": result["method"], "clip_model": clip_model, "seg_mode": result["seg_mode"], "grid_size": result["grid_size"], "mask_style": result["mask_style"], "n_regions": n_img, "n_tokens": n_tok, } if progress is not None: progress(1.0, desc="Done") # Returns: original_img, overlay_img, region_chart, token_chart, # cross_chart, heatmap, interaction_html, # masked_dropdown, masked_img, masked_state, meta return ( image, result["overlay_img"], region_chart, token_chart, cross_chart, heatmap, interaction_html, gr.update(choices=choices, value=choices[0] if choices else None), masked_state.get(choices[0]) if choices else None, masked_state, meta, ) def on_select_example( dataset, ex_id, model_size, order, method, scalarizer=None, feature_level=None, ): """ Public mode handler: load a precomputed example and render figures. 
Args: dataset (str): dataset name ex_id (str): example id model_size (str): "small" | "medium" | "large" order (int): interaction order (2 or 3) method (str): "shapley" | "banzhaf" | "influence" Returns: tuple ordered as: ( context, prompt, answer, interactions_plot, interactions_token_html, text_html, meta_json, ) """ get_res = get_result_by_id if get_result_by_id is not None else _public_get_result_from_file model_size = _normalize_model_size(model_size) example = {"context": "", "prompt": "", "answer": ""} if get_example_by_id is not None: try: example = get_example_by_id(dataset, ex_id) except Exception: pass result = get_res( model_size, dataset, ex_id, scalarizer=scalarizer, feature_level=feature_level, ) or {} payload = result.get(method, {}) # Your JSON: features (list of strings) + mobius_dict. Convert to UI format if needed. feats = payload.get("features") if isinstance(payload, dict) else None if isinstance(feats, list) and feats and not isinstance(feats[0], dict): payload = _normalize_public_payload_fallback(payload, method) features, feature_values = _extract_feature_series(payload) if not features: features = [""] feature_values = [0.0] # Influence scores are non-negative (squared Fourier coefficients) if method == "influence": feature_values = [abs(v) for v in feature_values] marginals = {feat: float(feature_values[idx]) for idx, feat in enumerate(features)} interactions = _resolve_interactions(payload, order) if method == "influence": interactions = [(feats, abs(val)) for feats, val in interactions] pairwise = _resolve_interactions(payload, 2) if not pairwise: mixed = payload.get("interactions") normalized = _normalize_interactions(mixed) if normalized: pairwise = [item for item in normalized if len(item[0]) == 2] pairwise = [(feats, abs(val)) for feats, val in pairwise] else: pairwise = _resolve_pairwise(payload, features, feature_values) if method == "influence": top_singletons = sorted( list(zip(features, feature_values)), key=lambda kv: 
abs(float(kv[1])), reverse=True, )[:10] top_pairs = sorted( pairwise, key=lambda kv: abs(float(kv[1])), reverse=True, )[:10] print( "[influence-ui-debug][public] " f"dataset={dataset} ex_id={ex_id} feature_count={len(features)} pair_count={len(pairwise)}" , flush=True) print("[influence-ui-debug][public] top_singletons:", top_singletons, flush=True) print("[influence-ui-debug][public] top_pairwise:", top_pairs, flush=True) payload_level = ( payload.get("mask_level") or payload.get("feature_level") or payload.get("level") or (result.get("meta", {}) if isinstance(result, dict) else {}).get("feature_level") ) layout_mode = "sentence" if _normalize_level(payload_level) == "sentence" else "token" inter = plot_top_interactions(interactions, order=order, method=method) spans = payload.get("feature_spans") or payload.get("spans") if not spans: # Precomputed JSON payloads may not include explicit spans. # Reconstruct spans from context + feature level so Text View can render. _, fallback_spans, _ = _chunk_text_for_visualization( example.get("context", ""), _normalize_level(payload_level), ) if fallback_spans and len(fallback_spans) == len(feature_values): spans = fallback_spans html = None if spans and len(spans) == len(feature_values): html = create_interactive_text_heatmap( example.get("context", ""), spans, feature_values, method=method, ) # Compute the wrong-answer payload up-front so the dual heatmap branch # (which rewrites text_interaction_html below) has it ready. 
_wrong_values_for_dual: Optional[List[float]] = None _wrong_pairwise_for_dual: Optional[List[Any]] = None _wrong_features_for_dual: Optional[List[str]] = None try: from visualization.wrong_answer_examples import has_wrong_answer_view as _has_wrong_view except Exception: _has_wrong_view = None _is_wrong_view = bool( _has_wrong_view is not None and html is not None and spans and _has_wrong_view(dataset, ex_id, scalarizer or "", feature_level or "") ) if _is_wrong_view: wrong_result = _public_get_model_answer_short_from_file( model_size, dataset, ex_id, scalarizer or "geomean_jointprob", feature_level or "word", ) wrong_payload = wrong_result.get(method, {}) if wrong_result else {} wrong_features_local, wrong_values_local = _extract_feature_series(wrong_payload) if method == "influence": wrong_values_local = [abs(v) for v in wrong_values_local] if wrong_features_local and len(wrong_values_local) == len(feature_values): # Build wrong-side pairwise edges, mirroring the GT logic above. wrong_pairwise = _resolve_interactions(wrong_payload, 2) if not wrong_pairwise: mixed = wrong_payload.get("interactions") if isinstance(wrong_payload, dict) else None normalized = _normalize_interactions(mixed) if normalized: wrong_pairwise = [item for item in normalized if len(item[0]) == 2] if method == "influence": wrong_pairwise = [(f, abs(v)) for f, v in (wrong_pairwise or [])] else: # Best-effort: if no explicit pairwise, derive from wrong feature values if not wrong_pairwise: wrong_pairwise = _resolve_pairwise(wrong_payload, wrong_features_local, wrong_values_local) _wrong_values_for_dual = wrong_values_local _wrong_pairwise_for_dual = wrong_pairwise or [] _wrong_features_for_dual = wrong_features_local else: _is_wrong_view = False meta = { "dataset": dataset, "example_id": ex_id, "model_size": model_size, "method": method, "order": order, "feature_count": len(features), "payload_keys": sorted(payload.keys()), } if "meta" in result: meta["source_meta"] = result["meta"] 
interaction_chips_html = create_interaction_token_view( features, feature_values, pairwise or [item for item in interactions if len(item[0]) == 2], method=method, layout=layout_mode, ) text_interaction_html = create_text_interaction_html( features, feature_values, _pairwise_to_index_interactions( pairwise or [item for item in interactions if len(item[0]) == 2], features, ), method=method, top_k=20, threshold=0.0, ) # For the 30 wrong-answer examples, replace the visible Text Interaction # view with two chip+arc panels side-by-side (vs Ground Truth | vs Model # Answer (Wrong)) plus a single shared legend + RAW TEXT below. if ( _is_wrong_view and _wrong_values_for_dual is not None and _wrong_features_for_dual is not None ): gt_view = create_text_interaction_html( features, feature_values, _pairwise_to_index_interactions( pairwise or [item for item in interactions if len(item[0]) == 2], features, ), method=method, top_k=20, threshold=0.0, ) wrong_view = create_text_interaction_html( _wrong_features_for_dual, _wrong_values_for_dual, _pairwise_to_index_interactions( _wrong_pairwise_for_dual or [], _wrong_features_for_dual, ), method=method, top_k=20, threshold=0.0, ) method_label = (method or "attribution").title() gt_max_abs = max((abs(v) for v in feature_values), default=0.0) or 1.0 wrong_max_abs = max((abs(v) for v in _wrong_values_for_dual), default=0.0) or 1.0 from html import escape as _escape raw_text = example.get("context", "") or "" raw_text_html = _escape(raw_text).replace("\n", "
") if raw_text else "" # CSS scoped to .dual-heatmap-row hides the per-side legend so we can # show one shared legend below; tightens the per-card max width so two # views fit comfortably side-by-side. dual_css = ( "" ) shared_block = ( '
' '
' f'{method_label} legend' '
' 'Negative' '
' 'Positive' '
' '

' f'Ground-truth max |value| = {gt_max_abs:.4f}; ' f'wrong-answer max |value| = {wrong_max_abs:.4f}. ' 'Hover tokens for exact scores.' '

' '
' '
' 'Raw text' f'

{raw_text_html or "No context available."}

' '
' '
' ) text_interaction_html = ( f'{dual_css}' '
' '
' '
vs Ground Truth
' f'{gt_view}' '
' '
' '
vs Model Answer (Wrong)
' f'{wrong_view}' '
' '
' f'{shared_block}' ) print( f"[wrong-answer] dual chip+lines view rendered for {dataset}/{ex_id} " f"(gt_features={len(features)} wrong_features={len(_wrong_features_for_dual)} " f"gt_pairs={len(pairwise or [])} wrong_pairs={len(_wrong_pairwise_for_dual or [])})", flush=True, ) figs = { "interactions": inter, } outputs = update( figs=figs, meta=meta, html=html, interaction_html=interaction_chips_html, interaction_text_html=text_interaction_html, ) return ( example.get("context", ""), example.get("prompt", ""), _extract_answer(example), *outputs, ) def on_click_compute( context, prompt, correct_answer, model_size, level, method, scalarizer, embedding_model, progress=gr.Progress(track_tqdm=True), ): # """ # Developer mode handler: compute (or mock) attributions and render figures. # """ # method = _normalize_method(method) # level = _normalize_level(level) # order = 3 if int(order or 2) >= 3 else 2 # context = context or "" # prompt = prompt or "" # correct_answer = correct_answer or "" # try: # return _compute_live_attributions( # context=context, # prompt=prompt, # correct_answer=correct_answer, # model_size=model_size, # level=level, # method=method, # order=order, # progress=progress, # ) # except Exception as exc: # pragma: no cover - best-effort fallback # return _synthetic_attribution_pipeline( # context, # prompt, # correct_answer, # method=method, # level=level, # order=order, # reason=str(exc), # ) method = _normalize_method(method) level = _normalize_level(level) model_size = _normalize_model_size(model_size) order = 2 context = context or "" prompt = prompt or "" correct_answer = correct_answer or "" return _compute_live_attributions( context=context, prompt=prompt, correct_answer=correct_answer, model_size=model_size, level=level, method=method, order=order, scalarizer=scalarizer, embedding_model=embedding_model, progress=progress, ) # --------------------------------------------------------------------------- # Multimodal precomputed example handlers 
# (MIMIC-CXR, ISIC, MS-COCO)
# ---------------------------------------------------------------------------

# ── Shared helpers for the precomputed medical/multimodal tabs ─────────────


def _med_method_wrappers(method_label: Optional[str]):
    """Normalize a UI method label and build label-aware plotting wrappers.

    Args:
        method_label: UI value such as ``"Influence"`` or ``"Shapley"``;
            ``None``/empty falls back to ``"Influence"``.

    Returns:
        ``(method, chart_fn, interaction_html_fn)``:
        ``method`` is the lower-cased label handed to the example loaders;
        ``chart_fn`` wraps ``create_shapley_bar_chart`` and replaces the word
        "Shapley" in the title with the display name; ``interaction_html_fn``
        wraps ``create_benchmark_interaction_html``. Both wrappers default the
        ``method_label`` keyword to ``"Influence"`` or ``"Shapley"``.
    """
    method = (method_label or "Influence").lower()
    method_display = "Influence" if method == "influence" else "Shapley"
    base_chart = create_shapley_bar_chart
    base_html = create_benchmark_interaction_html

    def chart_fn(labels, values, title="Shapley Values", **kwargs):
        kwargs.setdefault("method_label", method_display)
        return base_chart(
            labels, values, title.replace("Shapley", method_display), **kwargs
        )

    def interaction_html_fn(**kwargs):
        kwargs.setdefault("method_label", method_display)
        return base_html(**kwargs)

    return method, chart_fn, interaction_html_fn


def _med_file_b64(path: Optional[str]) -> str:
    """Return the ASCII base64 encoding of the file at *path*.

    Returns ``""`` when *path* is empty or does not exist on disk.
    """
    if not path or not os.path.exists(path):
        return ""
    with open(path, "rb") as fh:
        return base64.b64encode(fh.read()).decode("ascii")


def _med_segment_regions(original_path, segmap_path, n_segments: int):
    """Best-effort wrapper around ``extract_segment_regions``.

    Returns ``(segment_bboxes, label_map_b64)``, or ``(None, "")`` when an
    input is missing, there are no segments, or extraction fails.
    """
    if not (original_path and segmap_path and n_segments > 0):
        return None, ""
    try:
        bboxes, label_map_b64 = extract_segment_regions(
            original_path, segmap_path, n_segments)
    except Exception:
        # Labels are drawn without segment geometry when extraction fails.
        return None, ""
    return bboxes, label_map_b64


def _med_biomedclip_views(data: Dict[str, Any], caption: str, chart_fn, interaction_html_fn):
    """Build the BiomedCLIP overlay, bar charts and interaction HTML.

    Returns ``(overlay_labeled, token_plot, region_plot, interaction_html)``.
    When ``data["has_biomedclip"]`` is falsy this is ``(None, None, None, "")``.
    """
    overlay_labeled = token_plot = region_plot = None
    interaction_html = ""
    if not data.get("has_biomedclip"):
        return overlay_labeled, token_plot, region_plot, interaction_html

    bc = data["biomedclip"]
    summary = bc["summary"]
    image_paths = bc["image_paths"]
    overlay_raw = image_paths.get("overlay")
    original_path = image_paths.get("original", "")
    segmap_path = image_paths.get("segmap", "")
    region_values = summary.get("image_region_values", [])

    bboxes, label_map_b64 = _med_segment_regions(
        original_path, segmap_path, len(region_values))

    if overlay_raw:
        overlay_labeled = draw_segment_labels(
            overlay_raw, region_values,
            segment_bboxes=bboxes, label_map_b64=label_map_b64,
            original_path=original_path)

    region_labels = [v["label"] for v in region_values]
    if region_labels:
        region_plot = chart_fn(
            region_labels, [v["value"] for v in region_values],
            "BiomedCLIP — Image Region Shapley Values")

    merged = merge_subword_token_values(summary.get("token_values", []), caption)
    token_labels = [v["label"] for v in merged]
    if token_labels:
        token_plot = chart_fn(
            token_labels, [v["value"] for v in merged],
            "BiomedCLIP — Caption Word Shapley Values")

    # Interactive cross-modal HTML (the segmap is embedded inline as base64).
    image_b64 = bc.get("image_b64", {})
    original_b64 = image_b64.get("original", "")
    overlay_b64 = image_b64.get("overlay", "")
    all_cross = bc.get("all_cross_modal_pairs", [])
    interaction_html = interaction_html_fn(
        image_b64=original_b64,
        clip_summary=summary,
        vllm_logprob=None,
        caption=caption,
        all_cross_modal_pairs=all_cross,
        segmap_b64=_med_file_b64(segmap_path),
        overlay_b64=overlay_b64,
        segment_bboxes=bboxes,
        label_map_b64=label_map_b64,
        title="BiomedCLIP Cross-Modal Interaction View — click segments or words",
    )
    return overlay_labeled, token_plot, region_plot, interaction_html


def _med_llavamed_views(data: Dict[str, Any], original_img, chart_fn):
    """Build LLaVA-Med UnSAM overlays and bar charts.

    Two separate overlays are drawn (one colored by Log-Prob values, one by
    Generation values) because the signs often differ between the methods.

    Returns ``(lp_overlay, lp_plot, gen_overlay, gen_plot)``; entries are
    ``None`` when the corresponding result is absent.
    """
    lp_overlay = lp_plot = gen_overlay = gen_plot = None
    has_lp = data.get("has_llavamed_unsam_logprob")
    has_gen = data.get("has_llavamed_unsam_gen")
    if not (has_lp or has_gen):
        return lp_overlay, lp_plot, gen_overlay, gen_plot

    segmap_path = data.get("llavamed_unsam_segmap_path", "")
    original_path = data.get("llavamed_unsam_original_path", "") or (original_img or "")

    bboxes, label_map_b64 = None, ""
    if segmap_path and original_path:
        # Both methods share one segmentation; count segments from whichever
        # result is present (log-prob preferred).
        source_key = "llavamed_unsam_logprob" if has_lp else "llavamed_unsam_gen"
        n_segments = len(data[source_key].get("image_region_values", []))
        bboxes, label_map_b64 = _med_segment_regions(
            original_path, segmap_path, n_segments)

    if has_lp:
        lp_payload = data["llavamed_unsam_logprob"]
        lp_values = rename_patch_labels(lp_payload.get("image_region_values", []))
        if lp_values:
            lp_plot = chart_fn(
                [v["label"] for v in lp_values],
                [v["value"] for v in lp_values],
                "LLaVA-Med Log-Prob — Segment Shapley Values",
            )
        overlay_path = lp_payload.get("overlay_path", "")
        if overlay_path:
            lp_overlay = draw_segment_labels(
                overlay_path, lp_values,
                segment_bboxes=bboxes, label_map_b64=label_map_b64,
                original_path=original_path)

    if has_gen:
        gen_payload = data["llavamed_unsam_gen"]
        gen_values = rename_patch_labels(gen_payload.get("image_region_values", []))
        if gen_values:
            gen_plot = chart_fn(
                [v["label"] for v in gen_values],
                [v["value"] for v in gen_values],
                "LLaVA-Med Generation — Segment Shapley Values",
            )
        # Fall back to the log-prob overlay as the base image and recolor it
        # by the generation values.
        overlay_path = (gen_payload.get("overlay_path", "")
                        or data.get("llavamed_unsam_logprob", {}).get("overlay_path", ""))
        if overlay_path:
            gen_overlay = draw_segment_labels(
                overlay_path, gen_values,
                segment_bboxes=bboxes, label_map_b64=label_map_b64,
                original_path=original_path)

    return lp_overlay, lp_plot, gen_overlay, gen_plot


def _med_interpretation(
    data: Dict[str, Any],
    *,
    modality: str,
    category: str,
    caption: str,
    dataset_label: str,
    reference_label: str,
    log_tag: str,
) -> str:
    """Return the Markdown interpretation shown under a medical example.

    Uses ``generate_interpretation_text`` when possible. If no precomputed
    result flags are set at all, returns a "no results yet" notice that names
    *dataset_label* and *reference_label* (e.g. "report" / "caption").
    """
    interpretation = ""
    try:
        bc_data = (
            data.get("biomedclip", {}).get("summary")
            if data.get("has_biomedclip") else None
        )
        interpretation = generate_interpretation_text(
            clip_summary=bc_data,
            vllm_logprob=(
                data.get("llavamed_unsam_logprob")
                if data.get("has_llavamed_unsam_logprob") else None
            ),
            modality=modality,
            body_part=category,
            caption=caption,
            cross_method_name="BiomedCLIP",
            vlm_method_name="LLaVA-Med",
            vlm_region_type="UnSAM segments",
        )
    except Exception as exc:
        # The text is optional; keep the tab usable but leave a trace.
        print(f"[{log_tag}] interpretation text unavailable: {exc}")

    result_flags = ("has_biomedclip", "has_llavamed_unsam_logprob",
                    "has_llavamed_unsam_gen", "has_clip")
    if not any(data.get(k) for k in result_flags):
        interpretation = (
            "**No precomputed attribution results yet.**\n\n"
            f"Run the attribution pipeline on this {dataset_label} example to see results here. "
            f"The image and {reference_label} are shown above for reference."
        )
    return interpretation


def _med_results_state(bc_overlay, bc_region_plot, lp_overlay, lp_plot,
                       gen_overlay, gen_plot) -> Dict[str, Dict[str, Any]]:
    """Map method name -> {"overlay", "plot"} for methods that produced an overlay."""
    state: Dict[str, Dict[str, Any]] = {}
    for name, overlay, plot in (
        ("BiomedCLIP Cross-Modal", bc_overlay, bc_region_plot),
        ("LLaVA-Med Log-Prob", lp_overlay, lp_plot),
        ("LLaVA-Med Generation", gen_overlay, gen_plot),
    ):
        if overlay:
            state[name] = {"overlay": overlay, "plot": plot}
    return state


def _med_meta(example_id, category: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """Metadata dict shown in the JSON panel of the medical tabs."""
    return {
        "example_id": example_id,
        "category": category,
        "has_biomedclip": data.get("has_biomedclip", False),
        "has_llavamed_unsam_logprob": data.get("has_llavamed_unsam_logprob", False),
        "has_llavamed_unsam_gen": data.get("has_llavamed_unsam_gen", False),
    }


def _med_compare_methods(method_a, method_b, results_state):
    """Return ``(overlay_a, overlay_b, plot_a, plot_b)`` from a results state.

    Returns four ``None`` values if either method or the state is missing.
    """
    if not method_a or not method_b or not results_state:
        return None, None, None, None
    a = results_state.get(method_a, {})
    b = results_state.get(method_b, {})
    return a.get("overlay"), b.get("overlay"), a.get("plot"), b.get("plot")


# ── MIMIC-CXR Tab Handlers ────────────────────────────────────────────────

_MIMIC_METHOD_NAMES = [
    "BiomedCLIP Cross-Modal",
    "LLaVA-Med Log-Prob",
    "LLaVA-Med Generation",
]


def _on_select_mimic_example(example_id, method_label: str = "Influence"):
    """Load a MIMIC-CXR example and return data for the MIMIC tab.

    Returns 15 outputs, in order: caption, original image, findings,
    interpretation, BiomedCLIP overlay, BiomedCLIP token plot, BiomedCLIP
    region plot, LLaVA-Med log-prob overlay, log-prob plot, LLaVA-Med
    generation overlay, generation plot, BiomedCLIP interaction HTML,
    metadata dict, results state (for the compare dropdowns) and a no-op
    ``gr.update()`` for the compare dropdown. When the dataset is unavailable,
    no example is selected or loading fails, returns ``""`` followed by 14
    ``None`` values.
    """
    n_outputs = 15
    empty = tuple([""] + [None] * (n_outputs - 1))
    if not _MIMIC_AVAILABLE or not example_id:
        return empty

    method, chart_fn, interaction_html_fn = _med_method_wrappers(method_label)

    try:
        data = load_mimic_example(example_id, method=method)
    except Exception as exc:
        print(f"[mimic] Error loading {example_id}: {exc}")
        return empty

    caption = data.get("caption", "")
    findings = data.get("findings", "")
    original_img = data.get("original_image_path")
    category = data.get("meta", {}).get("category", "")

    bc_overlay, bc_token_plot, bc_region_plot, bc_html = _med_biomedclip_views(
        data, caption, chart_fn, interaction_html_fn)
    lp_overlay, lp_plot, gen_overlay, gen_plot = _med_llavamed_views(
        data, original_img, chart_fn)

    interpretation = _med_interpretation(
        data,
        modality="Chest X-ray",
        category=category,
        caption=caption,
        dataset_label="MIMIC-CXR",
        reference_label="report",
        log_tag="mimic",
    )
    results_state = _med_results_state(
        bc_overlay, bc_region_plot, lp_overlay, lp_plot, gen_overlay, gen_plot)

    return (
        caption,
        original_img,
        findings,
        interpretation,
        bc_overlay,
        bc_token_plot,
        bc_region_plot,
        lp_overlay,
        lp_plot,
        gen_overlay,
        gen_plot,
        bc_html,
        _med_meta(example_id, category, data),
        results_state,
        gr.update(),  # compare dropdown placeholder
    )


def _on_mimic_compare_methods(method_a, method_b, results_state):
    """Pick two MIMIC methods from state and display side by side."""
    return _med_compare_methods(method_a, method_b, results_state)


# ── ISIC Dermoscopy Tab Handlers ──────────────────────────────────────────

_ISIC_METHOD_NAMES = [
    "BiomedCLIP Cross-Modal",
    "LLaVA-Med Log-Prob",
    "LLaVA-Med Generation",
]


def _on_select_isic_example(example_id, method_label: str = "Influence"):
    """Load an ISIC dermoscopy example and return data for the ISIC tab.

    Mirrors ``_on_select_mimic_example`` except that ISIC has no separate
    "findings" field. That slot is omitted entirely, so this handler returns
    14 outputs: caption, original image, interpretation, BiomedCLIP overlay,
    token plot, region plot, LLaVA-Med log-prob overlay and plot, generation
    overlay and plot, interaction HTML, metadata, results state and a no-op
    ``gr.update()`` for the compare dropdown. When the dataset is unavailable,
    no example is selected or loading fails, returns ``""`` followed by 13
    ``None`` values.
    """
    n_outputs = 14
    empty = tuple([""] + [None] * (n_outputs - 1))
    if not _ISIC_AVAILABLE or not example_id:
        return empty

    method, chart_fn, interaction_html_fn = _med_method_wrappers(method_label)

    try:
        data = load_isic_example(example_id, method=method)
    except Exception as exc:
        print(f"[isic] Error loading {example_id}: {exc}")
        return empty

    caption = data.get("caption", "")
    original_img = data.get("original_image_path")
    category = data.get("meta", {}).get("category", "")

    bc_overlay, bc_token_plot, bc_region_plot, bc_html = _med_biomedclip_views(
        data, caption, chart_fn, interaction_html_fn)
    lp_overlay, lp_plot, gen_overlay, gen_plot = _med_llavamed_views(
        data, original_img, chart_fn)

    interpretation = _med_interpretation(
        data,
        modality="Dermoscopy",
        category=category,
        caption=caption,
        dataset_label="ISIC",
        reference_label="caption",
        log_tag="isic",
    )
    results_state = _med_results_state(
        bc_overlay, bc_region_plot, lp_overlay, lp_plot, gen_overlay, gen_plot)

    return (
        caption,
        original_img,
        interpretation,
        bc_overlay,
        bc_token_plot,
        bc_region_plot,
        lp_overlay,
        lp_plot,
        gen_overlay,
        gen_plot,
        bc_html,
        _med_meta(example_id, category, data),
        results_state,
        gr.update(),  # compare dropdown placeholder
    )


def _on_isic_compare_methods(method_a, method_b, results_state):
    """Pick two ISIC methods from state and display side by side."""
    return _med_compare_methods(method_a, method_b, results_state)


def _on_select_coco_example(example_id, method_label: str = "Influence"):
    """Load a precomputed MS-COCO example and return outputs for the COCO tab.

    Returns 12 outputs: caption, original image, labeled overlay, interaction
    HTML, token plot, region plot, cross-modal plot, cross-modal table rows,
    heatmap, note (Markdown), first masked image and a ``gr.update`` for the
    masked-image dropdown. On unavailability or load failure, returns ``""``,
    ten ``None`` values and a no-op ``gr.update()``.
    """
    from html import escape as _escape

    n_outputs = 12
    empty = ("",) + (None,) * (n_outputs - 2) + (gr.update(),)
    if not _COCO_AVAILABLE or not _MEDICAL_AVAILABLE or not example_id:
        return empty

    method, chart_fn, interaction_html_fn = _med_method_wrappers(method_label)

    try:
        data = load_coco_example(example_id, method=method)
    except Exception as exc:
        print(f"[coco] Error loading {example_id}: {exc}")
        return empty

    caption = data.get("caption", "")
    summary = data.get("summary", {})
    # Single guarded lookup: a payload without "image_paths" must not KeyError.
    image_paths = data.get("image_paths", {})
    original_img = image_paths.get("original")
    overlay_img = image_paths.get("overlay")
    region_values = summary.get("image_region_values", [])

    # Segment bboxes from the segmentation map.
    segment_bboxes, label_map_b64 = _med_segment_regions(
        image_paths.get("original", ""), image_paths.get("segmap", ""),
        len(region_values))

    # Draw segment labels on the overlay; keep the unlabeled overlay on failure.
    overlay_labeled = overlay_img
    if overlay_img:
        try:
            labeled = draw_segment_labels(
                overlay_img, region_values, segment_bboxes=segment_bboxes)
            if labeled:
                overlay_labeled = labeled
        except Exception as exc:
            print(f"[coco] Could not label overlay for {example_id}: {exc}")

    # Bar charts.
    r_labels = [v["label"] for v in region_values]
    r_values = [v["value"] for v in region_values]
    region_plot = chart_fn(
        r_labels, r_values, "CLIP — Image Region Shapley Values") if r_labels else None

    merged_toks = merge_subword_token_values(summary.get("token_values", []), caption)
    t_labels = [v["label"] for v in merged_toks]
    t_values = [v["value"] for v in merged_toks]
    token_plot = chart_fn(
        t_labels, t_values, "CLIP — Caption Word Shapley Values") if t_labels else None

    # Cross-modal pairs: chart over all pairs, table of the first 30.
    all_cross = data.get("all_cross_modal_pairs", [])
    cross_plot = None
    cross_table = []
    if all_cross:
        cross_pairs = [
            ((item["pair"][0], _tok_to_word(item["pair"][1], caption)), item["value"])
            for item in all_cross
        ]
        cross_plot = create_cross_modal_bar_chart(
            cross_pairs, "CLIP — Top Image x Word Interactions", top_k=20)
        cross_table = [
            [item["pair"][0], _tok_to_word(item["pair"][1], caption),
             f"{item['value']:+.4f}"]
            for item in all_cross[:30]
        ]

    # Heatmap (influence_matrix is expected to expose ``.size``, e.g. an ndarray).
    heatmap = None
    influence_matrix = data.get("influence_matrix")
    tok_labels_hm = [t.replace("tok:", "").lstrip("#") for t in data.get("tok_labels", [])]
    if influence_matrix is not None and influence_matrix.size > 0:
        heatmap = create_influence_heatmap(
            data.get("seg_labels", r_labels), tok_labels_hm, influence_matrix,
            "Image Regions x Caption Words — Influence Scores")

    # Interactive cross-modal HTML.
    image_b64 = data.get("image_b64", {})
    original_b64 = image_b64.get("original", "")
    overlay_b64 = image_b64.get("overlay", "")
    segmap_b64 = image_b64.get("segmap", "")
    try:
        interaction_html = interaction_html_fn(
            image_b64=original_b64,
            clip_summary=summary,
            vllm_logprob=None,
            caption=caption,
            all_cross_modal_pairs=all_cross,
            segmap_b64=segmap_b64,
            overlay_b64=overlay_b64,
            segment_bboxes=segment_bboxes,
            label_map_b64=label_map_b64,
            title="MS-COCO — Click a region or word to explore interactions",
        )
    except Exception as exc:
        # Escape the exception text: it is rendered as raw HTML.
        interaction_html = (
            "<div style='padding:12px;color:#c0392b;'>"
            f"Error building interaction view: {_escape(str(exc))}"
            "</div>"
        )

    note = (
        "**Note:** These results used the original UNK mask token "
        "(same as `<|endoftext|>`, CLIP token ID 49407). "
        "A first-token dominance artifact may be visible in the token Shapley chart. "
        "This will be corrected when scaling to 100 images with the dot-mask fix."
    )

    # Masked Image Browser.
    region_choices = data.get("region_choices", [])
    masked_dd_update = gr.update(
        choices=region_choices,
        value=region_choices[0] if region_choices else None,
    )

    # Pre-load the first masked image so the viewer isn't blank.
    first_masked_img = None
    if region_choices:
        try:
            first_masked_img = get_coco_masked_image_path(example_id, region_choices[0])
        except Exception as exc:
            print(f"[coco] No masked image for {example_id}/{region_choices[0]}: {exc}")

    return (
        caption,
        original_img,
        overlay_labeled,
        interaction_html,
        token_plot,
        region_plot,
        cross_plot,
        cross_table,
        heatmap,
        note,
        first_masked_img,   # masked image viewer
        masked_dd_update,   # masked dropdown choices
    )


def _on_select_coco_masked(example_id, choice):
    """Return a masked image path for the COCO Masked Image Browser."""
    if not _COCO_AVAILABLE or not example_id or not choice:
        return None
    return get_coco_masked_image_path(example_id, choice)


def on_click_image_compute(
    image,
    caption,
    clip_model,
    seg_mode,
    grid_size,
    method,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Run live CLIP image attributions for the image tab."""
    return _compute_image_attributions_clip(
        image=image,
        caption=caption,
        clip_model=clip_model,
        seg_mode=seg_mode,
        grid_size=grid_size,
        method=method,
        seed=seed,
        progress=progress,
    )


def on_click_mm_compute(
    image,
    caption,
    clip_model,
    seg_mode,
    grid_size,
    method,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Run live CLIP multimodal (image x caption) attributions."""
    return _compute_mm_attributions_clip(
        image=image,
        caption=caption,
        clip_model=clip_model,
        seg_mode=seg_mode,
        grid_size=grid_size,
        method=method,
        seed=seed,
        progress=progress,
    )


# ---------------------------------------------------------------------------
# Demo helpers (used to quickly validate visualization components locally)
#
--------------------------------------------------------------------------- _DEMO_TEXT = "The quick brown fox jumps over the lazy dog in a sunny meadow." _DEMO_FEATURES = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "in", "a", "sunny", "meadow"] _DEMO_SPANS = [ (0, 3), (4, 9), (10, 15), (16, 19), (20, 25), (26, 30), (31, 34), (35, 39), (40, 43), (44, 46), (47, 48), (49, 54), (55, 61) ] _DEMO_ATTRIBUTIONS: Dict[str, Dict[str, float]] = { "shapley": { "The": -0.04, "quick": 0.18, "brown": 0.12, "fox": 0.27, "jumps": 0.15, "over": 0.05, "the": -0.02, "lazy": -0.11, "dog": -0.07, "in": 0.03, "a": 0.02, "sunny": 0.09, "meadow": 0.21, } } _DEMO_ATTRIBUTIONS["banzhaf"] = { token: round(value * 0.8, 3) for token, value in _DEMO_ATTRIBUTIONS["shapley"].items() } _DEMO_ATTRIBUTIONS["influence"] = { token: round(abs(value), 3) for token, value in _DEMO_ATTRIBUTIONS["shapley"].items() } _DEMO_INTERACTIONS_2: List[Tuple[Tuple[str, ...], float]] = [ (("quick", "fox"), 0.24), (("fox", "jumps"), 0.19), (("sunny", "meadow"), 0.22), (("lazy", "dog"), -0.17), (("the", "lazy"), -0.12), ] _DEMO_INTERACTIONS_3: List[Tuple[Tuple[str, ...], float]] = [ (("quick", "brown", "fox"), 0.28), (("fox", "jumps", "over"), 0.18), (("sunny", "meadow", "dog"), 0.11), (("the", "lazy", "dog"), -0.21), ] _DEMO_INTERACTION_MATRIX: List[Tuple[Tuple[int, int], float]] = [ ((1, 3), 0.23), ((3, 4), 0.17), ((7, 8), -0.18), ((11, 12), 0.2), ((2, 5), 0.09), ] _DEMO_DATASETS = { "squad_demo": [ [ "The quick brown fox jumps over the lazy dog.", "Who jumps over the dog?", "The quick brown fox", ], [ "AttrLLM explains attributions for large language models.", "What does AttrLLM explain?", "Attributions", ], ], "truthfulqa_demo": [ [ "Water boils at 100 degrees Celsius at sea level.", "At what temperature does water boil?", "100 degrees Celsius", ] ], } def _render_demo(method: str = "shapley"): method = (method or "shapley").lower() order = 2 attributions = _DEMO_ATTRIBUTIONS.get(method, 
_DEMO_ATTRIBUTIONS["shapley"]) interactions = _DEMO_INTERACTIONS_3 if order == 3 else _DEMO_INTERACTIONS_2 interactions_fig = plot_top_interactions(interactions, order=order, method=method) demo_pairwise = _DEMO_INTERACTIONS_2 or _fallback_pairwise_from_values( _DEMO_FEATURES, [attributions[token] for token in _DEMO_FEATURES], ) text_html = create_interactive_text_heatmap( _DEMO_TEXT, _DEMO_SPANS, [attributions[token] for token in _DEMO_FEATURES], method=method, ) text_interaction_html = create_text_interaction_html( _DEMO_FEATURES, [attributions[token] for token in _DEMO_FEATURES], [ {"indices": [i, j], "value": float(val)} for (i, j), val in _DEMO_INTERACTION_MATRIX ], method=method, top_k=20, threshold=0.0, ) meta = { "method": method, "order": order, "feature_count": len(_DEMO_FEATURES), "scalarizer": "logprob", } return update( figs={ "interactions": interactions_fig, }, meta=meta, html=text_html, interaction_text_html=text_interaction_html, scoring_target_source="model_output", scoring_target_text="", reference_answer="", unmasked_answer="", debug_scores=None, scalarizer_used="logprob", score_full=None, score_empty=None, y_len_tokens=None, ) def _render_additional_plots(method: str = "shapley"): return plot_interaction_matrix(_DEMO_FEATURES, _DEMO_INTERACTION_MATRIX) def _records_for_dataset(dataset_name: str) -> List[Dict[str, Any]]: if get_examples is not None: try: records = get_examples(dataset_name, n=10) if records: return records except KeyError: pass except Exception: pass fallback_csv = _fallback_load_dataset(dataset_name, max_rows=10) if fallback_csv: return fallback_csv fallback = [] for idx, row in enumerate(_DEMO_DATASETS.get(dataset_name, []), start=1): context, prompt, answer = row fallback.append( { "id": f"{dataset_name}_demo_{idx}", "context": context, "prompt": prompt, "correct_answer": answer, } ) return fallback def _available_datasets() -> List[str]: if list_datasets is not None: try: datasets = list_datasets() if datasets: return 
datasets except Exception: pass fallback = [k for k, v in _FALLBACK_DATASET_FILES.items() if (_fallback_datasets_dir() / v).exists()] if fallback: return sorted(fallback) return list(_DEMO_DATASETS.keys()) def _format_examples(records: List[Dict[str, Any]]) -> List[List[str]]: formatted = [] for rec in records: formatted.append([ rec.get("context", ""), rec.get("prompt", ""), rec.get("correct_answer") or rec.get("answer") or rec.get("target") or "", ]) return formatted def _load_examples_for_demo(dataset_name: str): # Convert display name to internal key if needed if get_dataset_key_from_display_name is not None: dataset_key = get_dataset_key_from_display_name(dataset_name) else: dataset_key = dataset_name records = _records_for_dataset(dataset_key) formatted = _format_examples(records) samples = formatted if formatted else _DEMO_DATASETS.get(dataset_key, []) return gr.update(samples=samples or []) def _resolve_example_fields(record: Dict[str, Any]) -> Tuple[str, str, str]: context = record.get("context", "") prompt = record.get("prompt", "") answer = ( record.get("correct_answer") or record.get("answer") or record.get("target") or "" ) return context, prompt, answer def _resolve_dataset_key(dataset_name: str) -> str: if dataset_name in _available_datasets(): return dataset_name for key, label in DATASET_DISPLAY_LABELS.items(): if dataset_name == label: return key if get_dataset_key_from_display_name is not None: return get_dataset_key_from_display_name(dataset_name) return dataset_name def _dataset_choice_labels(dataset_keys: List[str]) -> List[str]: labels: List[str] = [] for key in dataset_keys: if get_dataset_display_name is not None: try: labels.append(get_dataset_display_name(key)) continue except Exception: pass labels.append(DATASET_DISPLAY_LABELS.get(key, key.replace("_", " ").title())) return labels def _resolve_example_index(example_number: Any, records: List[Dict[str, Any]]) -> int: if not records: return 0 try: index = int(example_number) - 1 except 
Exception: index = 0 return max(0, min(index, len(records) - 1))


def _resolve_example_id(example_number: Any, records: List[Dict[str, Any]]) -> str:
    """Return the example identifier for a 1-based slider position.

    In public-only mode the id is synthesized as ``example_<n>`` without
    looking at ``records``. Otherwise the position is mapped to a list index
    by ``_resolve_example_index`` and that record's ``id`` is returned; when
    the record has no truthy ``id`` (or ``records`` is empty) the fallback is
    ``example_<index + 1>``.
    """
    if _public_only_mode():
        # NOTE(review): unlike _build_model_answer_panel, this int() call is
        # not guarded, so a non-numeric example_number raises ValueError here.
        return f"example_{int(example_number or 1)}"
    index = _resolve_example_index(example_number, records)
    record = records[index] if records else {}
    return str(record.get("id") or f"example_{index + 1}")


def _build_model_answer_panel(dataset_name: str, example_number: Any) -> str:
    """Render the "Model's Answer" + justification HTML for one example.

    Only the 30 wrong-answer examples listed in
    ``visualization.wrong_answer_examples.WRONG_ANSWER_EXAMPLES`` (membership
    is tested with the key ``(dataset_key, "example_<n>")``) get a panel.
    Everything else returns an empty string so the gr.HTML slot stays
    visually empty.

    The panel data is read from
    ``<results_dir>/model_answers/small/<dataset_key>/example_<n>.json``.
    A missing allow-list module, a missing file, or an unreadable/invalid
    JSON file also yields ``""``.
    """
    try:
        from visualization.wrong_answer_examples import WRONG_ANSWER_EXAMPLES
    except Exception:
        # Best-effort panel: without the allow-list there is nothing to show.
        return ""
    dataset_key = _resolve_dataset_key(dataset_name) if dataset_name else ""
    try:
        ex_id = f"example_{int(example_number or 1)}"
    except Exception:
        ex_id = "example_1"
    if (dataset_key, ex_id) not in WRONG_ANSWER_EXAMPLES:
        return ""
    path = (
        _get_results_dir() / "model_answers" / "small" / dataset_key / f"{ex_id}.json"
    )
    if not path.exists():
        return ""
    try:
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception:
        return ""
    from html import escape as _escape
    # NOTE(review): these fields are assumed to be strings (or falsy); a
    # truthy non-string value would raise AttributeError on .strip().
    letter = (data.get("model_answer_parsed") or "").strip()
    raw = (data.get("model_answer_raw") or "").strip()
    gt_letter = (data.get("ground_truth_letter") or data.get("ground_truth") or "").strip()
    is_match = bool(data.get("is_match"))
    similarity = data.get("similarity")
    try:
        # Non-numeric similarity values fall back to an em dash placeholder.
        sim_str = f"{float(similarity):.3f}" if similarity is not None else "—"
    except Exception:
        sim_str = "—"
    # NOTE(review): chip_bg / chip_fg are not referenced by the HTML literals
    # returned below, and those literals span raw line breaks. This suggests
    # their markup (which presumably used these colors) was lost; confirm
    # against the original source.
    if is_match:
        chip_bg, chip_fg, chip_text = "#e7f6ec", "#1f8d4a", "✓ MATCH"
    else:
        chip_bg, chip_fg, chip_text = "#fdecea", "#c0392b", "✗ MISMATCH"
    # Keep only the text after a "Justification:" marker. The marker is
    # matched case-insensitively at the start of the answer and
    # case-sensitively anywhere else; otherwise the whole answer is shown.
    justification = raw
    if raw.lower().startswith("justification:"):
        justification = raw.split(":", 1)[1].strip()
    elif "Justification:" in raw:
        justification = raw.split("Justification:", 1)[1].strip()
    return (
        '
' '
' 'Model\'s Answer' f'{_escape(letter or "—")}' f'vs Ground Truth: ' f'{_escape(gt_letter or "—")}' f'{chip_text}' f'sim={sim_str}' '
' '
' f'Justification' f'{_escape(justification) if justification else "No justification captured."}' '
' '
'
    )


def _load_examples_for_slider(dataset_name: str):
    """Load a dataset's records and reset the example slider to position 1.

    Returns a 5-tuple: a ``gr.update`` for the slider, the record list, and
    the (context, prompt, answer) fields of the first record. The text fields
    are empty strings when there are no records. The slider maximum is capped
    at 10 and stays at 10 when the dataset yields no records.
    """
    dataset_key = _resolve_dataset_key(dataset_name)
    records = _records_for_dataset(dataset_key)
    # `len(records) or 10` keeps a 10-position slider even when no records load.
    slider_max = max(1, min(10, len(records) or 10))
    context = prompt = answer = ""
    if records:
        context, prompt, answer = _resolve_example_fields(records[0])
    slider_update = gr.update(minimum=1, maximum=slider_max, step=1, value=1)
    return slider_update, records, context, prompt, answer


def _update_example_preview(example_number: Any, records):
    """Return (context, prompt, answer) for the selected slider position.

    Returns three empty strings when ``records`` is empty.
    """
    if not records:
        return "", "", ""
    index = _resolve_example_index(example_number, records)
    return _resolve_example_fields(records[index])


def _results_output_list(results: Dict[str, Any]) -> List[Any]:
    """Flatten a results-display mapping into the ordered output list.

    The order here must match the order of values returned by the callbacks
    wired to these outputs (14 entries).
    """
    return [
        results["interactions"],
        results["interactions_tokens_html"],
        results["interactions_text_html"],
        results["text_html"],
        results["meta"],
        results["scoring_target_source"],
        results["scoring_target_text"],
        results["reference_answer"],
        results["unmasked_answer"],
        results["debug_scores"],
        results["scalarizer_used"],
        results["score_full"],
        results["score_empty"],
        results["y_len_tokens"],
    ]


def build_demo_app() -> gr.Blocks:
    """Build the standalone "AttrLLM Visualization Demo" Gradio app.

    The app has a dataset dropdown, an example browser, model/scalarizer/
    feature-level/method controls, and a button. The button renders demo
    visuals via ``_render_demo`` and ``_render_additional_plots``. Custom CSS
    is passed only when the installed ``gr.Blocks`` accepts a ``css`` kwarg.
    """
    datasets = _available_datasets()
    # NOTE(review): default_dataset is not referenced later in this function.
    default_dataset = datasets[0] if datasets else "demo"
    # Apply the same colorful CSS theme as the main app.
    custom_css = """ .gradio-container { font-family: 'Inter', -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue", Arial, sans-serif !important; background: linear-gradient(135deg, #fef5f0 0%, #f0e8ff 50%, #e8f5ff 100%) !important; padding: 24px !important; } .gradio-container h1, .gradio-container h2 { background: linear-gradient(135deg, #ff6b6b 0%, #ee5a6f 30%, #c44569 60%, #6c5ce7 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; font-weight: 900; font-size: 42px !important; margin: 20px 0 16px 0; letter-spacing: -0.03em; } label, .gr-label { font-weight: 700 !important; font-size: 16px !important; color: #2d1f4a !important; } .gr-button { border-radius: 16px !important; font-weight: 
700 !important; font-size: 17px !important; padding: 16px 32px !important; background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important; color: white !important; border: none !important; } .gr-box, .gr-input, .gr-dropdown, .gr-textbox { border-radius: 14px !important; border: 3px solid #e8dff5 !important; font-size: 17px !important; } .gr-markdown p { font-size: 17px !important; font-weight: 500 !important; } """
    _demo_kwargs = {"title": "AttrLLM Visualization Demo"}
    # Older/newer Gradio versions differ in whether Blocks accepts `css`.
    if _supports_kwarg(gr.Blocks, "css"):
        _demo_kwargs["css"] = custom_css
    with gr.Blocks(**_demo_kwargs) as demo:
        gr.Markdown(
            "# 🎨 AttrLLM Visualization Demo\n\n"
            "**Preview the attribution widgets** before wiring real backends. "
            "Use the controls below to explore the interface."
        )
        with gr.Row():
            with gr.Column(scale=1):
                # Prepare initial choices and value before creating component
                initial_choices = _dataset_choice_labels(datasets)
                initial_value = initial_choices[0] if initial_choices else None
                dataset_selector = gr.Dropdown(
                    choices=initial_choices,
                    value=initial_value,
                    label="Dataset",
                    interactive=True,
                    allow_custom_value=False,
                    elem_id="dataset-selector-demo",
                    elem_classes=["bubble-select"],
                )
                example_browser = create_example_browser()
            with gr.Column(scale=1):
                model_selector = create_model_selector()
                scalarizer_selector = gr.Dropdown(
                    choices=SCALARIZER_CHOICES,
                    value="logprob",
                    label="Scalarizer",
                    interactive=True,
                )
                embedding_model_box = gr.Textbox(
                    label="Embedding Model (for scalarizer=embedding)",
                    value="Qwen/Qwen3-Embedding-0.6B",
                    lines=1,
                )
                feature_level_selector = create_feature_level_selector()
                method_toggle = create_attribution_method_toggle()
        # Repopulate the example browser both on dataset change and on page load.
        dataset_selector.change(
            fn=_load_examples_for_demo,
            inputs=dataset_selector,
            outputs=example_browser,
        )
        demo.load(
            fn=_load_examples_for_demo,
            inputs=[dataset_selector],
            outputs=[example_browser],
        )
        # NOTE(review): assumed to sit at Blocks level below the control row —
        # confirm placement.
        render_button = gr.Button("Render Demo Visuals", variant="primary")
        outputs = create_results_display()
        extra_matrix = gr.Plot(label="Interaction Matrix (demo)")
        # Two independent click handlers: one fills the results display, the
        # other fills the extra interaction-matrix plot.
        render_button.click(
            fn=_render_demo,
            inputs=[method_toggle],
            outputs=_results_output_list(outputs),
        )
        render_button.click(
            fn=_render_additional_plots,
            inputs=[method_toggle],
            outputs=[extra_matrix],
        )
    return demo


def _patch_gradio_schema_generation() -> None:
    """Prevent Gradio 5.x /info crash caused by additionalProperties: true in schemas.

    Replace ``gradio_client.utils._json_schema_to_python_type`` and
    ``gradio_client.utils.json_schema_to_python_type`` with wrappers. Each
    wrapper first normalizes the schema, mapping boolean sub-schemas to
    ``{}`` (``True``) or ``{"type": "null"}`` (``False``).

    The patch is silently skipped when ``gradio_client`` cannot be imported
    or either function is missing/not callable. It is applied at most once;
    the ``_attrllm_schema_patch`` flag is set on the module.
    """
    try:
        from gradio_client import utils as client_utils
    except Exception:
        return
    if getattr(client_utils, "_attrllm_schema_patch", False):
        return
    original_inner = getattr(client_utils, "_json_schema_to_python_type", None)
    original_outer = getattr(client_utils, "json_schema_to_python_type", None)
    if not callable(original_inner) or not callable(original_outer):
        return

    def _normalize_schema(schema):
        """Return a copy of ``schema`` with boolean sub-schemas replaced.

        Lists are normalized element-wise and non-dict, non-list values are
        returned unchanged. For dicts, a shallow copy is taken and the
        schema-valued keywords listed below are normalized recursively.
        """
        if isinstance(schema, bool):
            return {} if schema else {"type": "null"}
        if isinstance(schema, list):
            return [_normalize_schema(item) for item in schema]
        if not isinstance(schema, dict):
            return schema
        normalized = dict(schema)
        # Only a boolean additionalProperties is rewritten here. A dict value
        # is left as-is; presumably it is normalized later when the original
        # converter recurses through the patched module-level name — confirm.
        if isinstance(normalized.get("additionalProperties"), bool):
            normalized["additionalProperties"] = _normalize_schema(normalized["additionalProperties"])
        # Keywords whose value is a mapping of name -> sub-schema.
        for key in ("properties", "$defs", "definitions", "patternProperties"):
            value = normalized.get(key)
            if isinstance(value, dict):
                normalized[key] = {k: _normalize_schema(v) for k, v in value.items()}
        # Keywords whose value is a single sub-schema (or a list, for legacy items).
        for key in ("items", "contains", "not", "if", "then", "else"):
            if key in normalized:
                normalized[key] = _normalize_schema(normalized[key])
        # Keywords whose value is a list of sub-schemas.
        for key in ("anyOf", "allOf", "oneOf", "prefixItems"):
            value = normalized.get(key)
            if isinstance(value, list):
                normalized[key] = [_normalize_schema(item) for item in value]
        return normalized

    client_utils._json_schema_to_python_type = lambda s, d=None: original_inner(_normalize_schema(s), d)
    client_utils.json_schema_to_python_type = lambda s: original_outer(_normalize_schema(s))
    client_utils._attrllm_schema_patch = True


# Applied once at import time.
_patch_gradio_schema_generation()


def build_app() -> gr.Blocks: datasets = 
_available_datasets() default_dataset = datasets[0] if datasets else "" public_only = _public_only_mode() mm_only = _mm_only_mode() # Custom CSS for prettier UI - Inspired by modern, colorful design custom_css = """ /* Main container styling - Warm gradient background */ .gradio-container { font-family: 'Inter', -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue", Arial, sans-serif !important; background: linear-gradient(135deg, #fef5f0 0%, #f0e8ff 50%, #e8f5ff 100%) !important; padding: 24px !important; } /* Header styling - Large, bold, colorful */ .gradio-container h1 { background: linear-gradient(135deg, #ff6b6b 0%, #ee5a6f 30%, #c44569 60%, #6c5ce7 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; font-weight: 900; font-size: 48px !important; margin: 20px 0 16px 0; letter-spacing: -0.03em; text-align: left; } .gradio-container h2 { background: linear-gradient(135deg, #ff6b6b 0%, #ee5a6f 30%, #c44569 60%, #6c5ce7 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; font-weight: 900; font-size: 42px !important; margin: 20px 0 16px 0; letter-spacing: -0.03em; } .gradio-container h3 { color: #2d1f4a; font-weight: 800; font-size: 24px !important; margin: 24px 0 16px 0; } /* Tab styling - Bold and colorful */ .tab-nav { border: none !important; background: transparent !important; gap: 8px !important; padding: 8px 0 !important; } .tab-nav button { font-size: 18px !important; font-weight: 700 !important; padding: 16px 32px !important; border-radius: 16px !important; transition: all 0.3s ease !important; border: 3px solid #e0d0f0 !important; background: white !important; color: #6c5ce7 !important; margin-right: 8px !important; } .tab-nav button:hover { background: #f8f4ff !important; border-color: #b8a8db !important; transform: translateY(-2px) !important; } .tab-nav button.selected { background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important; color: 
white !important; border: 3px solid #6c5ce7 !important; box-shadow: 0 6px 20px rgba(108, 92, 231, 0.3) !important; } /* Button styling - Vibrant and interactive */ .gr-button { border-radius: 16px !important; font-weight: 700 !important; font-size: 17px !important; padding: 16px 32px !important; transition: all 0.3s cubic-bezier(0.34, 1.56, 0.64, 1) !important; box-shadow: 0 6px 20px rgba(108, 92, 231, 0.2) !important; border: none !important; } .gr-button-primary { background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important; color: white !important; } .gr-button-secondary { background: linear-gradient(135deg, #fd79a8 0%, #ff7675 100%) !important; color: white !important; } .gr-button:hover { transform: translateY(-3px) scale(1.02) !important; box-shadow: 0 10px 30px rgba(108, 92, 231, 0.35) !important; } .gr-button-primary:hover { background: linear-gradient(135deg, #5e4ec7 0%, #9089e8 100%) !important; } /* Input/Dropdown styling - Clear and modern */ .gr-box, .gr-input, .gr-dropdown { border-radius: 14px !important; border: 3px solid #e8dff5 !important; background: white !important; font-size: 17px !important; padding: 12px 16px !important; transition: all 0.3s ease !important; font-weight: 500 !important; } .gr-box:focus, .gr-input:focus, .gr-dropdown:focus { border-color: #6c5ce7 !important; box-shadow: 0 0 0 4px rgba(108, 92, 231, 0.15) !important; transform: translateY(-1px) !important; } /* Textbox styling - Larger text */ .gr-textbox { border-radius: 16px !important; border: 3px solid #e8dff5 !important; font-size: 17px !important; line-height: 1.6 !important; } .gr-textbox textarea { font-size: 17px !important; line-height: 1.6 !important; padding: 14px !important; } .gr-textbox:focus-within { border-color: #6c5ce7 !important; box-shadow: 0 6px 24px rgba(108, 92, 231, 0.2) !important; } /* Radio button styling - Colorful pills */ .gr-radio { gap: 12px !important; } .gr-radio label { font-size: 17px !important; font-weight: 600 !important; 
padding: 14px 28px !important; border-radius: 14px !important; border: 3px solid #e8dff5 !important; transition: all 0.3s ease !important; background: white !important; cursor: pointer !important; } .gr-radio label:hover { border-color: #b8a8db !important; background: #faf8ff !important; transform: translateY(-2px) !important; box-shadow: 0 4px 12px rgba(108, 92, 231, 0.15) !important; } .gr-radio input:checked + label { background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important; color: white !important; border-color: #6c5ce7 !important; font-weight: 800 !important; box-shadow: 0 6px 20px rgba(108, 92, 231, 0.3) !important; } /* Panel/Accordion styling - Clean cards */ .gr-panel { border-radius: 20px !important; border: 3px solid #e8dff5 !important; padding: 24px !important; background: white !important; box-shadow: 0 6px 24px rgba(108, 92, 231, 0.1) !important; margin: 16px 0 !important; } .gr-accordion { border-radius: 18px !important; border: 3px solid #e8dff5 !important; background: white !important; } /* Label styling - Bold and readable */ label, .gr-label { font-weight: 700 !important; font-size: 16px !important; color: #2d1f4a !important; margin-bottom: 10px !important; letter-spacing: -0.01em !important; } /* Dropdown options */ .gr-dropdown-menu { border-radius: 14px !important; border: 3px solid #e8dff5 !important; box-shadow: 0 8px 32px rgba(108, 92, 231, 0.15) !important; font-size: 17px !important; } .gr-dropdown-menu .item { font-size: 17px !important; padding: 12px 16px !important; font-weight: 500 !important; } .gr-dropdown-menu .item:hover { background: linear-gradient(135deg, #f3f0ff 0%, #e8f5ff 100%) !important; } /* Plot container - Prominent */ .gr-plot { border-radius: 20px !important; border: 3px solid #e8dff5 !important; overflow: hidden !important; box-shadow: 0 8px 30px rgba(108, 92, 231, 0.12) !important; background: white !important; width: 100% !important; } /* Force the inner Plotly canvas + svg to fill its container so 
the Bar View doesn't render in a half-width column when the Text Interaction view above it is wide. */ .gr-plot .js-plotly-plot, .gr-plot .plot-container, .gr-plot .svg-container, .gr-plot .main-svg { width: 100% !important; max-width: 100% !important; } .interaction-stack > .gradio-plot, .interaction-stack > .block.gradio-plot, .interaction-stack .gr-plot { width: 100% !important; max-width: 100% !important; flex: 1 1 100% !important; } /* JSON viewer */ .gr-json { border-radius: 16px !important; border: 3px solid #e8dff5 !important; background: #faf8ff !important; padding: 20px !important; font-family: 'Monaco', 'Menlo', 'Consolas', monospace !important; font-size: 15px !important; } /* Column styling */ .gr-column { padding: 20px !important; } /* Row styling */ .gr-row { gap: 24px !important; margin: 12px 0 !important; } /* Markdown content - Larger, more readable */ .gr-markdown { line-height: 1.8 !important; color: #2d1f4a !important; } .gr-markdown p { font-size: 17px !important; margin: 12px 0 !important; font-weight: 500 !important; } .gr-markdown strong { font-weight: 800 !important; color: #6c5ce7 !important; } /* Status/info messages - Colorful notifications */ .gr-info { border-radius: 16px !important; border-left: 5px solid #6c5ce7 !important; background: linear-gradient(135deg, #f8f6ff 0%, #f0f4ff 100%) !important; padding: 18px 24px !important; font-size: 16px !important; font-weight: 600 !important; color: #2d1f4a !important; box-shadow: 0 4px 16px rgba(108, 92, 231, 0.1) !important; } /* Error messages */ .gr-error { border-radius: 16px !important; border-left: 5px solid #ff6b6b !important; background: linear-gradient(135deg, #fff5f5 0%, #ffe8e8 100%) !important; padding: 18px 24px !important; font-size: 16px !important; font-weight: 600 !important; color: #c44569 !important; } /* Loading spinner */ .loading { border: 4px solid #f3f0ff !important; border-top: 4px solid #6c5ce7 !important; } /* Scrollbar styling */ ::-webkit-scrollbar { width: 12px 
!important; height: 12px !important; } ::-webkit-scrollbar-track { background: #f8f6ff !important; border-radius: 10px !important; } ::-webkit-scrollbar-thumb { background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important; border-radius: 10px !important; } ::-webkit-scrollbar-thumb:hover { background: linear-gradient(135deg, #5e4ec7 0%, #9089e8 100%) !important; } .results-shell { margin-top: 16px !important; background: transparent !important; border: none !important; border-radius: 0 !important; padding: 0 !important; box-shadow: none !important; } .results-shell, .results-shell > div, .results-shell .gr-group, .results-shell .gr-box, .results-shell .gr-panel, .results-shell .block { background: transparent !important; border: none !important; box-shadow: none !important; } .interaction-stack { gap: 20px !important; padding: 0 8px 6px !important; } .interaction-stack h3 { margin-left: 24px !important; margin-bottom: 12px !important; } .public-controls { align-items: stretch !important; gap: 20px !important; margin-top: 8px !important; } .control-card { background: linear-gradient(180deg, rgba(255, 255, 255, 0.82) 0%, rgba(250, 246, 255, 0.96) 100%) !important; border: 2px solid rgba(224, 208, 240, 0.78) !important; border-radius: 26px !important; padding: 18px 20px 14px !important; box-shadow: 0 14px 30px rgba(108, 92, 231, 0.07) !important; } .control-card-primary { background: linear-gradient(180deg, rgba(255, 255, 255, 0.86) 0%, rgba(244, 248, 255, 0.96) 100%) !important; } .control-card-secondary { background: linear-gradient(180deg, rgba(255, 255, 255, 0.86) 0%, rgba(250, 244, 255, 0.96) 100%) !important; } .control-card .gradio-container, .control-card .gr-group { background: transparent !important; } .control-card > div, .control-card .block, .control-card .wrap, .control-card .gr-form, .control-card .form { background: transparent !important; border: none !important; box-shadow: none !important; } .control-card .gr-box, .control-card .gr-panel 
{ background: transparent !important; box-shadow: none !important; } .bubble-select { border: 3px solid #8f5cff !important; border-radius: 18px !important; box-shadow: 0 8px 20px rgba(143, 92, 255, 0.10) !important; transition: box-shadow 0.2s ease, border-color 0.2s ease !important; } .bubble-select:focus-within { border-color: #7a3dff !important; box-shadow: 0 0 0 4px rgba(143, 92, 255, 0.14), 0 10px 24px rgba(143, 92, 255, 0.16) !important; } .example-id-slider { margin-top: 8px !important; padding: 10px 2px 2px !important; } .example-id-slider input[type="range"] { accent-color: #4f7cff !important; } .example-id-slider .number-input, .example-id-slider input[type="number"] { border-radius: 16px !important; border: 2px solid #d8dcee !important; background: linear-gradient(180deg, #ffffff 0%, #f7f9ff 100%) !important; font-weight: 700 !important; min-width: 72px !important; } .example-id-slider .wrap { gap: 14px !important; } @media (prefers-color-scheme: dark) { .gradio-container { background: radial-gradient(circle at top, #1e2a44 0%, #0d1422 52%, #090f19 100%) !important; } .gradio-container h3, label, .gr-label, .gr-markdown, .gr-markdown p { color: #e8eefc !important; } .gr-markdown strong { color: #cbd7ff !important; } .tab-nav button, .gr-box, .gr-input, .gr-dropdown, .gr-textbox, .gr-panel, .gr-accordion, .gr-plot, .gr-json { background: rgba(16, 24, 39, 0.88) !important; border-color: rgba(148, 163, 184, 0.24) !important; color: #e8eefc !important; } .tab-nav button { color: #d7e1ff !important; } .tab-nav button:hover { background: rgba(37, 52, 79, 0.96) !important; border-color: rgba(199, 210, 254, 0.36) !important; } .gr-radio label { background: rgba(16, 24, 39, 0.9) !important; border-color: rgba(148, 163, 184, 0.26) !important; color: #e8eefc !important; } .gr-radio label:hover { background: rgba(37, 52, 79, 0.96) !important; } .gr-textbox textarea, .gr-input input { background: transparent !important; color: #e8eefc !important; } .gr-dropdown-menu 
{ background: #101827 !important; border-color: rgba(148, 163, 184, 0.24) !important; } .gr-dropdown-menu .item { color: #e8eefc !important; } .gr-dropdown-menu .item:hover { background: rgba(37, 52, 79, 0.96) !important; } .gr-plot .main-svg, .gr-plot .svg-container, .gr-plot .plot-container, .gr-plot .user-select-none { background: transparent !important; } .gr-plot .xtick text, .gr-plot .ytick text, .gr-plot .gtitle text, .gr-plot .xtitle text, .gr-plot .ytitle text, .gr-plot .annotation-text, .gr-plot .legend text { fill: #e8eefc !important; color: #e8eefc !important; } .gr-plot .gridlayer path, .gr-plot .zerolinelayer path, .gr-plot .xlines-above path, .gr-plot .ylines-above path { stroke: rgba(148, 163, 184, 0.22) !important; } .gr-info { background: linear-gradient(135deg, rgba(30, 41, 59, 0.95) 0%, rgba(17, 24, 39, 0.95) 100%) !important; color: #dbe7ff !important; border-left-color: #9db4ff !important; } .control-card { background: linear-gradient(180deg, rgba(16, 24, 39, 0.9) 0%, rgba(18, 28, 45, 0.96) 100%) !important; border-color: rgba(148, 163, 184, 0.2) !important; box-shadow: 0 18px 36px rgba(0, 0, 0, 0.24) !important; } .results-shell, .results-shell > div, .results-shell .gr-group, .results-shell .gr-box, .results-shell .gr-panel, .results-shell .block { background: transparent !important; border: none !important; box-shadow: none !important; } .bubble-select { border-color: #a06cff !important; box-shadow: 0 10px 24px rgba(143, 92, 255, 0.18) !important; } .bubble-select:focus-within { border-color: #c29cff !important; box-shadow: 0 0 0 4px rgba(143, 92, 255, 0.16), 0 12px 26px rgba(143, 92, 255, 0.22) !important; } .example-id-slider .number-input, .example-id-slider input[type="number"] { background: linear-gradient(180deg, #162031 0%, #111827 100%) !important; border-color: rgba(148, 163, 184, 0.24) !important; color: #e8eefc !important; } .gr-error { background: linear-gradient(135deg, rgba(68, 18, 32, 0.95) 0%, rgba(39, 12, 20, 0.95) 100%) 
!important; color: #ffd5dc !important; } ::-webkit-scrollbar-track { background: #111827 !important; } } """ _app_kwargs = {"title": "LLM Reasoning Explorer Studio"} if _supports_kwarg(gr.Blocks, "css"): _app_kwargs["css"] = custom_css with gr.Blocks(**_app_kwargs) as app: gr.Markdown( "# LLM Reasoning Explorer Studio\n\n" "**Explore attribution results and feature interactions** with our interactive visualization tools. " "Browse pre-computed examples or analyze your own text in real-time with powerful AI insights." ) gr.Markdown(f"**Build:** {BUILD_ID} ({BUILD_TS})") example_state = gr.State([]) with (gr.Column(visible=not mm_only) if (public_only or mm_only) else gr.Tab("Public Mode")): with gr.Accordion("How to Use", open=False): gr.Markdown( "1. **Select a dataset** from 10 available datasets (100 total examples, 10 per dataset)\n" "2. **Choose a model** to compare: Qwen3-4B, Qwen3-30B, or Mistral-7B\n" "3. **Pick a scoring method:** Perplexity or Semantic Similarity\n" "4. **Set the feature level:** Word, Sentence, or Paragraph\n" "5. **Choose an attribution method:** Shapley, Banzhaf, or Influence\n" "6. **View results** in the Text Interaction View (inline highlights) and Bar View (ranked interactions)" ) with gr.Row(elem_classes=["public-controls"]): with gr.Column(scale=1, elem_classes=["control-card", "control-card-primary"]): # Prepare initial choices and value before creating component initial_choices = _dataset_choice_labels(datasets) # In mm_only mode the text attribution tab is hidden — no default value # prevents the .change() callback from firing on page load. 
_preferred_default = "BBQ Disambiguation" if mm_only: initial_value = None elif _preferred_default in initial_choices: initial_value = _preferred_default else: initial_value = initial_choices[0] if initial_choices else None dataset_selector = gr.Dropdown( choices=initial_choices, value=initial_value, label="Dataset", interactive=True, allow_custom_value=False, elem_id="dataset-selector", elem_classes=["bubble-select"], ) example_selector = gr.Slider( label="Example ID", minimum=1, maximum=10, step=1, value=1, interactive=True, elem_classes=["example-id-slider"], ) with gr.Column(scale=1, elem_classes=["control-card", "control-card-secondary"]): model_selector = create_model_selector() scalarizer_selector = gr.Dropdown( choices=PUBLIC_SCALARIZER_CHOICES, value="geomean_jointprob", label="Scalarizer", interactive=True, elem_classes=["bubble-select"], ) public_feature_level_selector = create_feature_level_selector(value="word") method_toggle = create_attribution_method_toggle() with gr.Accordion("Example Preview", open=True): with gr.Row(): with gr.Column(scale=3): context_box = gr.Textbox( label="Context", lines=8, interactive=False, ) with gr.Column(scale=2): prompt_box = gr.Textbox( label="Prompt", lines=4, interactive=False, ) answer_box = gr.Textbox( label="Ground Truth Answer", lines=3, interactive=False, ) # Empty for examples outside the 30-pair allow-list; renders # the model's parsed letter + justification for the others. 
try: model_answer_html = gr.HTML(value="", sanitize_html=False) except TypeError: model_answer_html = gr.HTML(value="") public_results = create_results_display() def _public_mode_compute( dataset, example_number, records, model_size, scalarizer, feature_level, method, progress=gr.Progress(track_tqdm=True), ): if mm_only: return tuple([None] * 14) if not dataset: raise gr.Error("Please select a dataset.") if not example_number: raise gr.Error("Please select an example.") dataset_key = _resolve_dataset_key(dataset) ex_id = _resolve_example_id(example_number, records) method = _normalize_method(method) level = _normalize_level(feature_level) model_size = _normalize_model_size(model_size) # Prefer precomputed results: use loader if available, else load from file (Space-friendly). get_res = get_result_by_id if get_result_by_id is not None else _public_get_result_from_file result = get_res( model_size, dataset_key, ex_id, scalarizer=scalarizer, feature_level=level, ) or {} payload = result.get(method, {}) if not payload: alt_size = _find_available_model_size(dataset_key, ex_id, scalarizer, level) if alt_size and alt_size != model_size: result = get_res( alt_size, dataset_key, ex_id, scalarizer=scalarizer, feature_level=level, ) or {} payload = result.get(method, {}) if payload: model_size = alt_size # If still no payload, try any available (model_size, scalarizer, level) for this example if not payload: alt_size, alt_scalarizer, alt_level, result = _find_any_available_result( dataset_key, ex_id, get_res, method ) if alt_size and alt_scalarizer and alt_level and result: payload = result.get(method, {}) model_size, scalarizer, level = alt_size, alt_scalarizer, alt_level if payload and (payload.get("features") or payload.get("heatmap")): _, _, _, *outputs = on_select_example( dataset_key, ex_id, model_size, 2, method, scalarizer=scalarizer, feature_level=level, ) return outputs # Public-only mode: do not attempt live compute if _public_only_mode() or get_example_by_id is 
None: expected_ref = _reference_results_file(model_size, dataset_key, ex_id, scalarizer, level) raise gr.Error( "No precomputed results found.\n\n" f"Expected (reference_answer):\n{expected_ref}\n\n" "On Hugging Face Space: make sure the 'results' folder is in your repo " "(commit & push it). If you use Git LFS, enable 'LFS' in Space Settings → " "Repository and ensure files are pulled. You can also try another " "scalarizer (e.g. Perplexity) or feature level (e.g. word)." ) # Fallback to live compute if no precomputed payload or non-word level get_ex = _ensure_backend("loader.data.get_example_by_id", get_example_by_id) record = get_ex(dataset_key, ex_id) context = record.get("context", "") prompt = record.get("prompt", "") answer = _extract_answer(record) return _compute_live_attributions( context=context, prompt=prompt, correct_answer=answer, model_size=model_size, scalarizer=scalarizer, embedding_model=None, level=level, method=method, order=2, progress=progress, ) public_preview_outputs = [context_box, prompt_box, answer_box] public_compute_inputs = [ dataset_selector, example_selector, example_state, model_selector, scalarizer_selector, public_feature_level_selector, method_toggle, ] public_compute_outputs = _results_output_list(public_results) dataset_change_event = dataset_selector.change( fn=_load_examples_for_slider, inputs=[dataset_selector], outputs=[ example_selector, example_state, context_box, prompt_box, answer_box, ], queue=False, ).then( fn=_build_model_answer_panel, inputs=[dataset_selector, example_selector], outputs=[model_answer_html], queue=False, ) load_event = app.load( fn=_load_examples_for_slider, inputs=[dataset_selector], outputs=[ example_selector, example_state, context_box, prompt_box, answer_box, ], ).then( fn=_build_model_answer_panel, inputs=[dataset_selector, example_selector], outputs=[model_answer_html], queue=False, ) dataset_change_event.then( fn=_public_mode_compute, inputs=public_compute_inputs, 
outputs=public_compute_outputs, show_progress="full", ) load_event.then( fn=_public_mode_compute, inputs=public_compute_inputs, outputs=public_compute_outputs, show_progress="full", ) example_selector.release( fn=_update_example_preview, inputs=[example_selector, example_state], outputs=public_preview_outputs, queue=False, ).then( fn=_build_model_answer_panel, inputs=[dataset_selector, example_selector], outputs=[model_answer_html], queue=False, ).then( fn=_public_mode_compute, inputs=public_compute_inputs, outputs=public_compute_outputs, show_progress="full", ) for component in ( model_selector, scalarizer_selector, public_feature_level_selector, method_toggle, ): component.change( fn=_public_mode_compute, inputs=public_compute_inputs, outputs=public_compute_outputs, show_progress="full", ) # ── MULTIMODAL TAB ────────────────────────────────────────── with gr.Tab("Multimodal"): with gr.Accordion("How to Use", open=False): gr.Markdown( "1. **Choose a dataset** from the three sub-tabs:\n" " - **MIMIC-CXR (10 Samples)** — chest X-rays across 10 pathology categories\n" " - **Dermoscopy ISIC (10 Samples)** — skin-lesion dermoscopy across 8 diagnostic classes\n" " - **MS-COCO (5 Samples)** — natural-image cross-modal benchmark\n" "2. **Pick an example** from the dropdown (each is an image + caption pair)\n" "3. **Choose an attribution method:** Influence (default, non-negative — clearer for clinicians) or Shapley (signed)\n" "4. **Read the four panels side-by-side:**\n" " - **Interactive Cross-Modal View** — click any image patch or caption word to see its strongest cross-modal partners\n" " - **BiomedCLIP Cross-Modal Attribution** — patch-level overlay + bar charts (cosine-similarity scoring)\n" " - **LLaVA-Med Attribution** — log-prob and generation Shapley charts from the medical 7B VLM\n" " - **Compare Two Methods Side-by-Side** — pick any two of the above to overlay their rankings\n" "5. 
**Hover token chips and patches** for exact attribution values; hover SVG arcs for pairwise interaction strength" ) with gr.Tab("MIMIC-CXR (10 Samples)"): gr.Markdown( "**10-sample MIMIC-CXR chest X-ray attribution benchmark** " "(10 pathology categories). \n" "Source: [MIMIC-CXR-JPG](https://huggingface.co/datasets/itsanmolgupta/mimic-cxr-dataset-cleaned) " "— de-identified chest radiographs from Beth Israel Deaconess Medical Center. \n" "Each example has a radiology report (impression = caption, findings = detail)." ) # Build (category_name, example_id) choices so picking a # pathology directly loads its example (1:1 mapping). _mimic_choices = ( [(v["category"], k) for k, v in MIMIC_EXAMPLES.items()] if _MIMIC_AVAILABLE else [] ) mimic_selector = gr.Dropdown( choices=_mimic_choices, value=None, label="Filter by Pathology", interactive=True, ) mimic_method_toggle = gr.Radio( choices=["Influence", "Shapley"], value="Influence", label="Attribution method", info=( "Influence (default) is always positive — clearer for clinicians. " "Shapley is signed (green = supports caption, red = detracts)." ), interactive=True, ) mimic_caption = gr.Textbox( label="Radiology Impression (Caption)", interactive=False, lines=2, ) with gr.Accordion("Full Radiology Findings", open=False): mimic_findings = gr.Textbox( label="Detailed Findings", interactive=False, lines=5, ) # ── Original Image ──────────────────────────────── mimic_original = gr.Image(label="Chest X-ray", type="filepath") mimic_interpretation = gr.Markdown( value="*Select an example above to see the attribution analysis.*", label="Interpretation", ) # ── Table of Contents ───────────────────────────── _mimic_pill = ( 'style="display:inline-block;padding:6px 14px;background:#e3f2fd;' 'border-radius:16px;text-decoration:none;color:#1565c0;font-size:0.9em;' 'border:1px solid #bbdefb;"' ) _mimic_toc_html = ( '
' 'Jump to Section:' '
' ) gr.HTML(value=_mimic_toc_html) mimic_results_state = gr.State({}) # ════════════════════════════════════════════════════ # ── Interactive Cross-Modal View ─────────────────── # ════════════════════════════════════════════════════ with gr.Column(elem_id="mimic-interactive"): gr.Markdown("---\n### BiomedCLIP Cross-Modal Interaction View — click segments or words") gr.Markdown( "**How to use:** Click any **image region** to see which caption words " "it connects to, or click a **word** to see which regions activate. \n" "**Green** arrows = positive interaction. **Red** arrows = negative." ) mimic_biomedclip_interaction_html = _html_component( "BiomedCLIP Cross-Modal Interaction View") # ════════════════════════════════════════════════════ # ── BiomedCLIP Cross-Modal Attribution ───────────── # ════════════════════════════════════════════════════ with gr.Column(elem_id="mimic-method-biomedclip"): gr.Markdown("---\n### BiomedCLIP Cross-Modal Attribution") gr.Markdown( "**What it does:** Uses [BiomedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) " "— a CLIP model trained on **15 million biomedical figure-caption pairs** — " "to jointly score image regions (via UnSAM segmentation) and caption tokens. \n" "**How to read:** **Green** = positive Shapley value (contributes to alignment). " "**Red** = negative (hurts alignment)." 
) mimic_biomedclip_overlay = gr.Image( label="BiomedCLIP Overlay (labeled segments)", type="filepath") mimic_biomedclip_token_plot = gr.Plot( label="BiomedCLIP — Caption Word Shapley Values") mimic_biomedclip_region_plot = gr.Plot( label="BiomedCLIP — Image Region Shapley Values") # ════════════════════════════════════════════════════ # ── LLaVA-Med Attribution (UnSAM Segments) ───────── # ════════════════════════════════════════════════════ with gr.Column(elem_id="mimic-method-llavamed"): gr.Markdown("---\n### LLaVA-Med Attribution (4×4 Patch Grid, P1–P16)") gr.Markdown( "**What it does:** Uses [LLaVA-Med](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) " "— a **7B parameter** medical VLM — evaluated over a **uniform 4×4 patch grid** " "(16 cells labeled **P1–P16**, row-major). \n" "**Two scoring approaches:** \n" "- **Log-Prob:** How removing a region affects confidence in the correct caption \n" "- **Generation:** How removing a region changes what the model describes" ) gr.Markdown( "Each method colors segments by its own Shapley values — " "**green** = positive, **red** = negative. Signs often differ " "between Log-Prob and Generation, so each has its own overlay." 
) with gr.Row(equal_height=True): mimic_llavamed_unsam_lp_overlay = gr.Image( label="LLaVA-Med Log-Prob — Overlay", type="filepath", height=600) mimic_llavamed_unsam_gen_overlay = gr.Image( label="LLaVA-Med Generation — Overlay", type="filepath", height=600) with gr.Row(): mimic_llavamed_unsam_lp_plot = gr.Plot( label="LLaVA-Med Log-Prob — Segment Shapley Values") mimic_llavamed_unsam_gen_plot = gr.Plot( label="LLaVA-Med Generation — Segment Shapley Values") # ════════════════════════════════════════════════════ # ── Compare Two Methods Side-by-Side ────────────── # ════════════════════════════════════════════════════ with gr.Column(elem_id="mimic-compare"): gr.Markdown("---\n### Compare Two Methods Side-by-Side") gr.Markdown( "Select two methods to compare their attribution overlays " "and Shapley value distributions on the same image." ) with gr.Row(): mimic_compare_method_a = gr.Dropdown( choices=_MIMIC_METHOD_NAMES, label="Method A", interactive=True, ) mimic_compare_method_b = gr.Dropdown( choices=_MIMIC_METHOD_NAMES, label="Method B", interactive=True, ) with gr.Row(): mimic_compare_img_a = gr.Image(label="Method A — Overlay", type="filepath") mimic_compare_img_b = gr.Image(label="Method B — Overlay", type="filepath") with gr.Row(): mimic_compare_plot_a = gr.Plot(label="Method A — Shapley Values") mimic_compare_plot_b = gr.Plot(label="Method B — Shapley Values") mimic_meta = gr.JSON(label="Example Info", visible=False) _mimic_outputs = [ mimic_caption, mimic_original, mimic_findings, mimic_interpretation, mimic_biomedclip_overlay, mimic_biomedclip_token_plot, mimic_biomedclip_region_plot, mimic_llavamed_unsam_lp_overlay, mimic_llavamed_unsam_lp_plot, mimic_llavamed_unsam_gen_overlay, mimic_llavamed_unsam_gen_plot, mimic_biomedclip_interaction_html, mimic_meta, mimic_results_state, mimic_compare_method_a, ] mimic_selector.change( fn=_on_select_mimic_example, inputs=[mimic_selector, mimic_method_toggle], outputs=_mimic_outputs, ) mimic_method_toggle.change( 
fn=_on_select_mimic_example, inputs=[mimic_selector, mimic_method_toggle], outputs=_mimic_outputs, ) # Wire: comparison dropdowns -> side-by-side display for _mimic_cmp_dd in [mimic_compare_method_a, mimic_compare_method_b]: _mimic_cmp_dd.change( fn=_on_mimic_compare_methods, inputs=[mimic_compare_method_a, mimic_compare_method_b, mimic_results_state], outputs=[mimic_compare_img_a, mimic_compare_img_b, mimic_compare_plot_a, mimic_compare_plot_b], ) # ── ISIC Dermoscopy Tab ──────────────────────── with gr.Tab("Dermoscopy ISIC (10 Samples)"): gr.Markdown( "**10-sample ISIC-2019 dermoscopy attribution benchmark** " "(8 diagnostic classes: MEL × 2, NV × 2, BCC, AK, BKL, DF, VASC, SCC). \n" "Source: [ISIC_2019_224](https://huggingface.co/datasets/MKZuziak/ISIC_2019_224) " "— dermoscopic skin-lesion images from the International Skin Imaging Collaboration. \n" "Captions are synthesized from class labels (clinical descriptions of each diagnosis)." ) _isic_choices = ( [(v["category"], k) for k, v in ISIC_EXAMPLES.items()] if _ISIC_AVAILABLE else [] ) isic_selector = gr.Dropdown( choices=_isic_choices, value=None, label="Filter by Diagnosis", interactive=True, ) isic_method_toggle = gr.Radio( choices=["Influence", "Shapley"], value="Influence", label="Attribution method", info=( "Influence (default) is always positive — clearer for clinicians. " "Shapley is signed (green = supports caption, red = detracts)." ), interactive=True, ) isic_caption = gr.Textbox( label="Diagnostic Caption", interactive=False, lines=3, ) isic_original = gr.Image(label="Dermoscopic Image", type="filepath") isic_interpretation = gr.Markdown( value="*Select an example above to see the attribution analysis.*", label="Interpretation", ) _isic_pill = ( 'style="display:inline-block;padding:6px 14px;background:#e3f2fd;' 'border-radius:16px;text-decoration:none;color:#1565c0;font-size:0.9em;' 'border:1px solid #bbdefb;"' ) _isic_toc_html = ( '
' 'Jump to Section:' '
' ) gr.HTML(value=_isic_toc_html) isic_results_state = gr.State({}) with gr.Column(elem_id="isic-interactive"): gr.Markdown("---\n### BiomedCLIP Cross-Modal Interaction View — click segments or words") gr.Markdown( "**How to use:** Click any **image region** to see which caption words " "it connects to, or click a **word** to see which regions activate." ) isic_biomedclip_interaction_html = _html_component( "BiomedCLIP Cross-Modal Interaction View") with gr.Column(elem_id="isic-method-biomedclip"): gr.Markdown("---\n### BiomedCLIP Cross-Modal Attribution") gr.Markdown( "**What it does:** Uses [BiomedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) " "to jointly score dermoscopic image regions (via UnSAM segmentation) " "and caption tokens. \n" "**How to read:** **Influence** bars (default) show positive importance. " "Switch to **Shapley** above for signed values (green/red)." ) isic_biomedclip_overlay = gr.Image( label="BiomedCLIP Overlay (labeled segments)", type="filepath") isic_biomedclip_token_plot = gr.Plot( label="BiomedCLIP — Caption Word Values") isic_biomedclip_region_plot = gr.Plot( label="BiomedCLIP — Image Region Values") with gr.Column(elem_id="isic-method-llavamed"): gr.Markdown("---\n### LLaVA-Med Attribution (4×4 Patch Grid, P1–P16)") gr.Markdown( "**What it does:** Uses [LLaVA-Med](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) " "— a **7B parameter** medical VLM — evaluated over a **uniform 4×4 patch grid** " "(16 cells labeled **P1–P16**, row-major). 
\n" "**Two scoring approaches:** \n" "- **Log-Prob:** How removing a region affects confidence in the caption \n" "- **Generation:** How removing a region changes what the model describes" ) with gr.Row(equal_height=True): isic_llavamed_unsam_lp_overlay = gr.Image( label="LLaVA-Med Log-Prob — Overlay", type="filepath", height=600) isic_llavamed_unsam_gen_overlay = gr.Image( label="LLaVA-Med Generation — Overlay", type="filepath", height=600) with gr.Row(): isic_llavamed_unsam_lp_plot = gr.Plot( label="LLaVA-Med Log-Prob — Segment Values") isic_llavamed_unsam_gen_plot = gr.Plot( label="LLaVA-Med Generation — Segment Values") with gr.Column(elem_id="isic-compare"): gr.Markdown("---\n### Compare Two Methods Side-by-Side") gr.Markdown( "Select two methods to compare their attribution overlays " "and value distributions on the same image." ) with gr.Row(): isic_compare_method_a = gr.Dropdown( choices=_ISIC_METHOD_NAMES, label="Method A", interactive=True, ) isic_compare_method_b = gr.Dropdown( choices=_ISIC_METHOD_NAMES, label="Method B", interactive=True, ) with gr.Row(): isic_compare_img_a = gr.Image(label="Method A — Overlay", type="filepath") isic_compare_img_b = gr.Image(label="Method B — Overlay", type="filepath") with gr.Row(): isic_compare_plot_a = gr.Plot(label="Method A — Values") isic_compare_plot_b = gr.Plot(label="Method B — Values") isic_meta = gr.JSON(label="Example Info", visible=False) _isic_outputs = [ isic_caption, isic_original, isic_interpretation, isic_biomedclip_overlay, isic_biomedclip_token_plot, isic_biomedclip_region_plot, isic_llavamed_unsam_lp_overlay, isic_llavamed_unsam_lp_plot, isic_llavamed_unsam_gen_overlay, isic_llavamed_unsam_gen_plot, isic_biomedclip_interaction_html, isic_meta, isic_results_state, isic_compare_method_a, ] isic_selector.change( fn=_on_select_isic_example, inputs=[isic_selector, isic_method_toggle], outputs=_isic_outputs, ) isic_method_toggle.change( fn=_on_select_isic_example, inputs=[isic_selector, 
isic_method_toggle], outputs=_isic_outputs, ) for _isic_cmp_dd in [isic_compare_method_a, isic_compare_method_b]: _isic_cmp_dd.change( fn=_on_isic_compare_methods, inputs=[isic_compare_method_a, isic_compare_method_b, isic_results_state], outputs=[isic_compare_img_a, isic_compare_img_b, isic_compare_plot_a, isic_compare_plot_b], ) # ── MS-COCO Tab ───────────────────────────────── with gr.Tab("MS-COCO (5 Samples)"): gr.Markdown( "**CLIP cross-modal attribution on MS-COCO natural images.** \n" "Click an **image region** or **caption word** below to explore " "which parts of the image and text are most strongly linked via " "CLIP's visual-language similarity score." ) _coco_choices = ( [(v["title"], k) for k, v in COCO_EXAMPLES.items()] if _COCO_AVAILABLE else [] ) _coco_default = _coco_choices[0][1] if _coco_choices else None coco_selector = gr.Radio( choices=_coco_choices, value=_coco_default, label="Select MS-COCO Example", interactive=True, ) coco_method_toggle = gr.Radio( choices=["Influence", "Shapley"], value="Influence", label="Attribution method", info="Influence (default) is always positive. Shapley is signed.", interactive=True, ) coco_caption = gr.Textbox( label="Caption", interactive=False, lines=2, ) gr.Markdown("---\n#### Interactive Cross-Modal View") gr.Markdown( "Click a colored **image region** (left) to highlight the caption " "words it interacts with, or click a **word** (right) to highlight " "linked regions. Green = positive, red = negative." 
) coco_interaction_html = _html_component( "COCO Cross-Modal Interaction View") gr.Markdown("---\n#### Attribution Details") with gr.Row(): coco_original = gr.Image( label="Original Image", type="filepath") coco_overlay = gr.Image( label="CLIP Overlay (labeled segments)", type="filepath") with gr.Row(): coco_token_plot = gr.Plot( label="Caption Word Shapley Values") coco_region_plot = gr.Plot( label="Image Region Shapley Values") with gr.Row(): coco_cross_plot = gr.Plot( label="Top Image x Word Interactions") coco_cross_table = gr.Dataframe( headers=["Image Region", "Caption Word", "Score"], label="Cross-Modal Interaction Table", interactive=False, ) with gr.Accordion("Influence Heatmap (Regions x Words)", open=False): coco_heatmap = gr.Plot( label="Full Heatmap: Regions x Caption Words") gr.Markdown("---\n#### Masked Image Browser") gr.Markdown( "Browse ablation images: **solo** shows only the selected region " "(everything else inpainted away); **removed** shows the image with " "that region inpainted out." ) with gr.Row(): coco_masked_dd = gr.Dropdown( choices=[], label="Region / View", interactive=True, ) coco_masked_img = gr.Image( label="Masked View", type="filepath") coco_note = gr.Markdown(value="") _coco_outputs = [ coco_caption, coco_original, coco_overlay, coco_interaction_html, coco_token_plot, coco_region_plot, coco_cross_plot, coco_cross_table, coco_heatmap, coco_note, coco_masked_img, coco_masked_dd, ] coco_selector.change( fn=_on_select_coco_example, inputs=[coco_selector, coco_method_toggle], outputs=_coco_outputs, ) coco_method_toggle.change( fn=_on_select_coco_example, inputs=[coco_selector, coco_method_toggle], outputs=_coco_outputs, ) coco_masked_dd.change( fn=_on_select_coco_masked, inputs=[coco_selector, coco_masked_dd], outputs=[coco_masked_img], ) # NOTE: auto-load removed — too much data on startup crashes the browser. # Users select an example via the Radio to trigger loading. gr.HTML( '
' '

Contributors — University of California, Berkeley

' '

' 'Stephen Tao · Loader Layer · ' 'stephen_tao@berkeley.edu' 'Yiting Gao · Attribution Layer · ' 'yg2025@berkeley.edu' 'Qingpeng Kong · Visualization Layer · ' 'qpkong@berkeley.edu' '

' '

' 'Advisor: Kannan Ramchandran · ' 'kannanr@berkeley.edu' 'Mentor: Landon Butler · ' 'landonb@berkeley.edu' '

' ) # Stash CSS for Gradio 6.x launch() (Blocks(css=) is deprecated in 6.x) app._custom_css = custom_css return app def _launch_kwargs(app_or_demo, **kwargs): """Build common launch kwargs, injecting CSS for Gradio 6.x.""" lk = dict( server_name=kwargs.pop("server_name", os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")), server_port=int(kwargs.pop("server_port", os.getenv("GRADIO_SERVER_PORT", "7860"))), share=kwargs.pop("share", _env_flag("GRADIO_SHARE", False)), show_error=kwargs.pop("show_error", True), ) css = getattr(app_or_demo, "_custom_css", None) if css and _supports_kwarg(app_or_demo.launch, "css"): lk["css"] = css lk.update(kwargs) return lk def launch_demo(**kwargs): demo = build_demo_app() demo.launch(**_launch_kwargs(demo, **kwargs)) def launch_app(**kwargs): app = build_app() app.launch(**_launch_kwargs(app, **kwargs)) if __name__ == "__main__": launch_app()