| import base64 |
| import copy |
| import inspect |
| import io |
| import json |
| import os |
| import random |
| import re |
| from itertools import combinations |
| from pathlib import Path |
| from typing import Dict, Any, List, Tuple, Optional |
| import requests |
|
|
# Base URL of the attribution backend; override with ATTRLLM_BACKEND_URL.
BACKEND_URL = os.getenv("ATTRLLM_BACKEND_URL", "http://127.0.0.1:8000")


# Temp directory exported as GRADIO_TEMP_DIR. An already-set GRADIO_TEMP_DIR
# wins; otherwise ".gradio_tmp" under the current working directory is used.
_DEFAULT_GRADIO_DIR = Path(os.environ.get("GRADIO_TEMP_DIR", Path.cwd() / ".gradio_tmp"))
os.environ.setdefault("GRADIO_TEMP_DIR", str(_DEFAULT_GRADIO_DIR))
# Import-time side effect: create the directory (and any parents) if missing.
_DEFAULT_GRADIO_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
def _get_request_timeout() -> float:
    """Return the backend request timeout in seconds.

    The value is read from ``ATTRLLM_REQUEST_TIMEOUT``. The default of 900
    seconds is used when the variable is unset, empty, not parseable as a
    float, or not a finite positive number (``0``, negatives, ``nan`` and
    ``inf`` cannot be used as a socket timeout).
    """
    default = 900.0
    value = os.getenv("ATTRLLM_REQUEST_TIMEOUT")
    if not value:
        return default
    try:
        timeout = float(value)
    except ValueError:
        return default
    # NaN fails both comparisons, so it is rejected together with inf and <= 0.
    if not (0.0 < timeout < float("inf")):
        return default
    return timeout
|
|
|
|
def _env_flag(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as an on/off switch.

    Returns *default* when the variable is not set. Otherwise the value is
    trimmed and lower-cased, and the flag is on only for one of
    ``1``, ``true``, ``yes``, ``y`` or ``on``.
    """
    raw = os.getenv(name)
    if raw is not None:
        truthy_spellings = ("1", "true", "yes", "y", "on")
        return raw.strip().lower() in truthy_spellings
    return default
|
|
|
|
def _is_hf_spaces() -> bool:
    """Return True when ``SPACE_ID`` or ``HF_SPACE`` is set to a non-empty value."""
    return any(os.getenv(var) for var in ("SPACE_ID", "HF_SPACE"))
|
|
|
|
def _supports_kwarg(callable_obj, kwarg_name: str) -> bool:
    """Return whether *callable_obj* declares a parameter named *kwarg_name*.

    Only explicitly named parameters are considered. Objects whose signature
    cannot be inspected yield False.
    """
    try:
        signature = inspect.signature(callable_obj)
        accepted = kwarg_name in signature.parameters
    except (TypeError, ValueError):
        accepted = False
    return accepted
|
|
|
|
def _public_only_mode() -> bool:
    """Report whether ``ATTRLLM_PUBLIC_ONLY`` is switched on (off when unset)."""
    enabled = _env_flag("ATTRLLM_PUBLIC_ONLY", default=False)
    return enabled
|
|
|
|
def _mm_only_mode() -> bool:
    """Report whether ``ATTRLLM_MM_ONLY`` is switched on (off when unset)."""
    enabled = _env_flag("ATTRLLM_MM_ONLY", default=False)
    return enabled
|
|
|
|
def _show_auxiliary_tabs() -> bool:
    """Report whether ``ATTRLLM_SHOW_AUX_TABS`` is switched on (off when unset)."""
    enabled = _env_flag("ATTRLLM_SHOW_AUX_TABS", default=False)
    return enabled
|
|
|
|
def _public_results_file(
    dataset_key: str,
    ex_id: str,
    scalarizer: str,
    level: str,
    method: str,
) -> Path:
    """Build the path of a public result file.

    Layout: ``<results>/public/<dataset_key>/<ex_id>/<scalarizer>/<level>/<method>.json``
    where ``<results>`` comes from ``_get_results_dir()``. The file is not
    checked for existence.
    """
    segments = ("public", dataset_key, ex_id, scalarizer, level, f"{method}.json")
    return _get_results_dir().joinpath(*segments)
|
|
|
|
def _reference_results_file(
    model_size: str,
    dataset_key: str,
    ex_id: str,
    scalarizer: str,
    level: str,
) -> Path:
    """Build the path of a reference-answer result file.

    Layout:
    ``<results>/reference_answer/<model_size>/<dataset_key>/<ex_id>/<scalarizer>/<level>.json``
    where ``<results>`` comes from ``_get_results_dir()``. The file is not
    checked for existence.
    """
    segments = (
        "reference_answer",
        model_size,
        dataset_key,
        ex_id,
        scalarizer,
        f"{level}.json",
    )
    return _get_results_dir().joinpath(*segments)
|
|
|
|
def _find_available_model_size(
    dataset_key: str,
    ex_id: str,
    scalarizer: str,
    level: str,
) -> Optional[str]:
    """Return the first model size that has a reference result file on disk.

    Sizes are probed in the order large, medium, small. None is returned when
    no file exists for any of them.
    """
    preference = ("large", "medium", "small")
    return next(
        (
            size
            for size in preference
            if _reference_results_file(size, dataset_key, ex_id, scalarizer, level).exists()
        ),
        None,
    )
|
|
|
|
| |
# (scalarizer, feature level) combinations tried in order by
# _find_any_available_result: levels go word -> sentence -> paragraph, and
# within each level geomean_jointprob is tried before semantic_similarity.
_FALLBACK_SCALARIZER_LEVELS: List[Tuple[str, str]] = [
    (scalarizer, level)
    for level in ("word", "sentence", "paragraph")
    for scalarizer in ("geomean_jointprob", "semantic_similarity")
]
|
|
|
|
def _find_any_available_result(
    dataset_key: str,
    ex_id: str,
    get_res: Any,
    method: str = "shapley",
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[Dict]]:
    """Search fallback (model_size, scalarizer, level) combinations for a usable result.

    ``get_res`` is called as
    ``get_res(size, dataset_key, ex_id, scalarizer=..., feature_level=...)``
    for sizes small, medium, large and every pair in
    ``_FALLBACK_SCALARIZER_LEVELS``. A result is usable when its *method*
    entry is non-empty and carries ``features`` or ``heatmap``. Any exception
    raised while fetching or inspecting a candidate skips that candidate.

    Returns ``(size, scalarizer, level, result_dict)`` for the first usable
    candidate, or ``(None, None, None, None)``.
    """
    for model_size in ("small", "medium", "large"):
        for candidate_scalarizer, candidate_level in _FALLBACK_SCALARIZER_LEVELS:
            try:
                fetched = get_res(
                    model_size,
                    dataset_key,
                    ex_id,
                    scalarizer=candidate_scalarizer,
                    feature_level=candidate_level,
                )
                result = fetched or {}
                payload = result.get(method, {})
                usable = bool(payload) and bool(
                    payload.get("features") or payload.get("heatmap")
                )
            except Exception:
                continue
            if usable:
                return (model_size, candidate_scalarizer, candidate_level, result)
    return (None, None, None, None)
|
|
|
|
def _parse_sparse_key(raw_key: str) -> Tuple[int, ...]:
    """Parse a comma-separated index key such as ``"0,3"`` into ``(0, 3)``.

    A blank key gives the empty tuple. Empty segments between commas are
    ignored. A segment that is not an integer raises ValueError.
    """
    text = str(raw_key).strip()
    if text == "":
        return ()
    segments = [segment for segment in text.split(",") if segment]
    return tuple(map(int, segments))
|
|
|
|
def _normalize_public_payload_fallback(data: Dict[str, Any], method: str, top_k: int = 10) -> Dict[str, Any]:
    """Convert a stored result (features list + meta + mobius_dict) to the UI display format.

    Args:
        data: Parsed JSON result. ``features`` must be a non-empty list;
            ``mobius_dict`` (optional, may be empty) maps comma-separated
            index keys such as ``"0,3"`` to numeric coefficients.
        method: ``"shapley"``, ``"banzhaf"`` or ``"influence"``; anything else
            (including None/empty) is treated as ``"shapley"``.
        top_k: Maximum number of pairwise interactions kept, ranked by
            absolute value. A falsy or non-positive value keeps all pairs.

    Returns:
        A shallow copy of *data* extended with ``token_scores`` (label ->
        score, duplicate labels summed), ``pairwise`` (``"a|b"`` -> value),
        ``pairwise_interactions`` (list of ``{"features", "value"}``),
        ``features`` (list of ``{"feature", "value", "index"}``) and
        ``feature_texts``. Returns ``{}`` when *data* is not a dict or has no
        usable ``features`` list.

    When the optional attribution helpers are importable
    (``mobius_to_shapley`` is not None) they compute the scores; otherwise a
    built-in per-coalition weighting is applied.
    """
    if not isinstance(data, dict):
        return {}
    features = data.get("features")
    if not isinstance(features, list) or not features:
        return {}
    mobius_raw = data.get("mobius_dict")
    if not isinstance(mobius_raw, dict):
        mobius_raw = {}

    method = (method or "shapley").lower()
    if method not in {"shapley", "banzhaf", "influence"}:
        method = "shapley"

    n_features = len(features)
    feature_labels = [str(x) for x in features]

    # Parse "i,j,..." keys into sorted index tuples; unparsable entries are dropped.
    mobius_sparse: Dict[Tuple[int, ...], float] = {}
    for key, raw_val in mobius_raw.items():
        try:
            val = float(raw_val)
            loc = _parse_sparse_key(str(key))
        except (TypeError, ValueError, OverflowError):
            continue
        mobius_sparse[tuple(sorted(loc))] = val

    token_scores: Dict[str, float] = {}
    index_scores: Dict[int, float] = {}
    pairwise_acc: Dict[Tuple[int, int], float] = {}

    if mobius_sparse and mobius_to_shapley is not None:
        # Attribution helpers are available: delegate the conversion to them.
        if method == "shapley":
            singleton_dict = mobius_to_shapley(mobius_sparse)
            pair_list = shapley_interactions(mobius_sparse, order=2, top_k=top_k) or []
        elif method == "banzhaf":
            singleton_dict = mobius_to_banzhaf(mobius_sparse)
            pair_list = banzhaf_interactions(mobius_sparse, order=2, top_k=top_k) or []
        elif mobius_to_influence is not None and influence_interactions is not None:
            singleton_dict = mobius_to_influence(mobius_sparse)
            pair_list = influence_interactions(mobius_sparse, order=2, top_k=top_k) or []
        else:
            singleton_dict = {}
            pair_list = []

        for loc, val in singleton_dict.items():
            if len(loc) != 1:
                continue
            idx = int(loc[0])
            if 0 <= idx < n_features:
                name = feature_labels[idx]
                val_f = float(val)
                token_scores[name] = token_scores.get(name, 0.0) + val_f
                index_scores[idx] = index_scores.get(idx, 0.0) + val_f

        for loc, val in pair_list:
            if len(loc) != 2:
                continue
            i, j = int(loc[0]), int(loc[1])
            if 0 <= i < n_features and 0 <= j < n_features:
                pair_key = (i, j) if i <= j else (j, i)
                pairwise_acc[pair_key] = float(val)
    else:
        # No helpers (or no coefficients): spread each coalition's coefficient
        # over its members and its member pairs with a method-specific weight.
        for loc, val in mobius_sparse.items():
            k = len(loc)
            if k == 0:
                continue
            if method == "banzhaf":
                singleton_weight = 1.0 / float(2 ** (k - 1))
            else:  # shapley and influence share the same weighting here
                singleton_weight = 1.0 / float(k)
            for idx in loc:
                if 0 <= idx < n_features:
                    name = feature_labels[idx]
                    token_scores[name] = token_scores.get(name, 0.0) + singleton_weight * val
                    index_scores[idx] = index_scores.get(idx, 0.0) + singleton_weight * val
            if k >= 2:
                if method == "banzhaf":
                    pair_weight = 1.0 / float(2 ** (k - 2))
                else:
                    pair_weight = 1.0 / float(k - 1)
                for i, j in combinations(sorted(loc), 2):
                    # Drop out-of-range pairs here so they cannot occupy
                    # top_k slots and crowd out displayable pairs.
                    if i < 0 or j >= n_features:
                        continue
                    pairwise_acc[(i, j)] = pairwise_acc.get((i, j), 0.0) + pair_weight * val

    sorted_pairs = sorted(pairwise_acc.items(), key=lambda kv: abs(kv[1]), reverse=True)
    if top_k and top_k > 0:
        sorted_pairs = sorted_pairs[:top_k]

    # Every stored pair is already within range (checked at insertion).
    pairwise = {
        f"{feature_labels[i]}|{feature_labels[j]}": float(v)
        for (i, j), v in sorted_pairs
    }
    pairwise_interactions = [
        {"features": [feature_labels[i], feature_labels[j]], "value": float(v)}
        for (i, j), v in sorted_pairs
    ]

    normalized = dict(data)
    normalized["token_scores"] = token_scores
    normalized["pairwise"] = pairwise
    normalized["pairwise_interactions"] = pairwise_interactions
    normalized["features"] = [
        {"feature": label, "value": float(index_scores.get(i, 0.0)), "index": i}
        for i, label in enumerate(feature_labels)
    ]
    normalized["feature_texts"] = list(feature_labels)
    return normalized
|
|
|
|
def _public_get_model_answer_short_from_file(
    model_size: str,
    dataset: str,
    ex_id: str,
    scalarizer: str = "geomean_jointprob",
    feature_level: str = "word",
) -> Dict[str, Any]:
    """Load the ``model_answer_short`` payload for one example from disk.

    Reads
    ``<results>/model_answer_short/<model_size>/<dataset>/<ex_id>/<scalarizer>/<feature_level>.json``
    and normalizes it once per attribution method. The return value is shaped
    like that of ``_public_get_result_from_file``: keys ``shapley``,
    ``banzhaf``, ``influence`` and ``meta``. ``{}`` is returned when the file
    is missing, unreadable, not a JSON object, or normalizes to nothing for
    every method.
    """
    path = _get_results_dir().joinpath(
        "model_answer_short",
        model_size,
        dataset,
        ex_id,
        scalarizer,
        f"{feature_level}.json",
    )
    if not path.exists():
        return {}
    try:
        with path.open("r", encoding="utf-8") as handle:
            data = json.load(handle)
    except Exception:
        return {}
    if not isinstance(data, dict):
        return {}

    # Each method gets its own deep copy of the raw payload.
    per_method = {
        name: _normalize_public_payload_fallback(copy.deepcopy(data), name)
        for name in ("shapley", "banzhaf", "influence")
    }
    if not any(per_method.values()):
        return {}

    meta = {
        "dataset": dataset,
        "example_id": ex_id,
        "model_size": model_size,
        "target_mode": "model_answer_short",
        "source_layout": "results/model_answer_short/{model_size}/{dataset}/{example_id}/{scalarizer}/{feature_level}.json",
    }
    return {**per_method, "meta": meta}
|
|
|
|
def _public_get_result_from_file(
    model_size: str,
    dataset: str,
    ex_id: str,
    scalarizer: Optional[str] = None,
    feature_level: Optional[str] = None,
) -> Dict[str, Any]:
    """Load a reference_answer result straight from the results directory.

    The requested feature level is tried first, then the remaining levels out
    of word, sentence, paragraph. When a file is missing locally, or whenever
    ``SPACE_ID`` is set, ``loader.results._maybe_download_from_space`` is
    called (best effort; failures are ignored) before the local file is
    re-checked. The first file that parses to a JSON object and normalizes to
    something for at least one method wins.

    Returns a dict with ``shapley``, ``banzhaf``, ``influence`` and ``meta``,
    or ``{}`` when the scalarizer/level is blank or nothing usable is found.
    """
    scalarizer = (scalarizer or "").strip()
    feature_level = (feature_level or "").strip()
    if not (scalarizer and feature_level):
        return {}

    levels_to_try = [feature_level]
    levels_to_try.extend(
        other for other in ("word", "sentence", "paragraph") if other != feature_level
    )

    for lvl in levels_to_try:
        path = _reference_results_file(model_size, dataset, ex_id, scalarizer, lvl)
        if not path.exists() or os.getenv("SPACE_ID"):
            try:
                from loader.results import _maybe_download_from_space
                path = _maybe_download_from_space(path, force_download=True) or path
            except Exception:
                pass
        if not path.exists():
            continue
        try:
            with path.open("r", encoding="utf-8") as handle:
                data = json.load(handle)
        except Exception:
            continue
        if not isinstance(data, dict):
            continue

        # Each method gets its own deep copy of the raw payload.
        per_method = {
            name: _normalize_public_payload_fallback(copy.deepcopy(data), name)
            for name in ("shapley", "banzhaf", "influence")
        }
        if not any(per_method.values()):
            continue

        meta = {
            "dataset": dataset,
            "example_id": ex_id,
            "model_size": model_size,
            "source_layout": "results/reference_answer/{model_size}/{dataset}/{example_id}/{scalarizer}/{feature_level}.json",
        }
        return {**per_method, "meta": meta}
    return {}
|
|
|
|
# Dataset key -> CSV filename. _fallback_load_dataset resolves the filename
# under _fallback_datasets_dir(); unknown keys load as an empty dataset.
_FALLBACK_DATASET_FILES: Dict[str, str] = {
    "bar_exam": "BarExam_qa.csv",
    "causal_judgment": "bbh_causal_judgement.csv",
    "snarks": "bbh_snarks.csv",
    "bbq_disamb": "BBQ_disamb.csv",
    "cnn_dailymail": "CNN_dailymail.csv",
    "drop": "drop.csv",
    "esnli": "eSNLI.csv",
    "fever": "fever.csv",
    "hotpot_qa": "hotpot_qa.csv",
    "medical_qa": "medical_qa.csv",
}
|
|
|
|
def _fallback_datasets_dir() -> Path:
    """Return the resolved ``datasets`` directory under the repository root."""
    datasets_root = _REPO_ROOT / "datasets"
    return datasets_root.resolve()
|
|
|
|
def _fallback_pick_first_nonempty(raw: Dict[str, str], candidates: List[str]) -> str:
    """Return the first candidate column of *raw* that holds non-blank text.

    Columns are checked in the order given. The value is returned as a string
    without trimming. ``""`` is returned when every candidate is missing,
    None, or whitespace-only.
    """
    for column in candidates:
        cell = raw.get(column)
        if cell is None:
            continue
        text = str(cell)
        if text.strip():
            return text
    return ""
|
|
|
|
def _fallback_load_dataset(dataset_key: str, max_rows: int = 10) -> List[Dict[str, str]]:
    """Load up to *max_rows* examples for *dataset_key* from its bundled CSV.

    The CSV filename comes from ``_FALLBACK_DATASET_FILES`` and is resolved
    under ``_fallback_datasets_dir()``. Each row becomes a dict with ``id``,
    ``context`` and ``prompt``; ``answer`` is added only when an answer-like
    column holds text. Column names are matched against fixed alias lists,
    first non-blank match wins. Rows without an id column get
    ``example_<row number>``.

    Returns an empty list for an unknown dataset key, a missing file, or a
    non-positive *max_rows*.
    """
    import csv

    filename = _FALLBACK_DATASET_FILES.get(dataset_key)
    if not filename:
        return []
    path = _fallback_datasets_dir() / filename
    if not path.exists():
        return []
    # Previously the limit was only checked after appending, so max_rows <= 0
    # still produced one row.
    if max_rows <= 0:
        return []

    context_columns = [
        "Context", "context",
        "passage", "article", "story", "premise",
        "paragraph", "document", "sentence1", "sent1", "background",
    ]
    prompt_columns = [
        "Prompt", "prompt",
        "question", "input", "query",
        "sentence2", "sent2", "hypothesis",
        "qa_question", "title",
    ]
    answer_columns = [
        "Answer", "answer",
        "target", "gold", "label", "output", "reference",
        "highlights",
    ]

    rows: List[Dict[str, str]] = []
    # "utf-8-sig" decodes plain UTF-8 identically but strips a leading BOM,
    # which would otherwise become part of the first header name.
    with path.open("r", encoding="utf-8-sig", errors="replace", newline="") as f:
        reader = csv.DictReader(f)
        for i, raw in enumerate(reader, start=1):
            ex_id = raw.get("id") or raw.get("example_id") or raw.get("uid") or f"example_{i}"
            ex = {
                "id": str(ex_id),
                "context": _fallback_pick_first_nonempty(raw, context_columns),
                "prompt": _fallback_pick_first_nonempty(raw, prompt_columns),
            }
            answer = _fallback_pick_first_nonempty(raw, answer_columns)
            if answer:
                ex["answer"] = answer
            rows.append(ex)
            if len(rows) >= max_rows:
                break
    return rows
|
|
|
|
# Resolved once at import time; later changes to ATTRLLM_REQUEST_TIMEOUT are
# not picked up.
REQUEST_TIMEOUT = _get_request_timeout()


# (label, scalarizer key) pairs -- presumably the choices of a UI selector;
# the consuming component is not visible in this part of the file.
SCALARIZER_CHOICES = [
    ("Semantic Similarity (y vs y_S)", "semantic_similarity"),
    ("LogProb", "logprob"),
    ("JointProb", "jointprob"),
    ("GeoMean JointProb", "geomean_jointprob"),
    ("Half SimLog", "half_simlog"),
]


# Reduced (label, scalarizer key) list; note that the "Perplexity" label maps
# to the "geomean_jointprob" key.
PUBLIC_SCALARIZER_CHOICES = [
    ("Semantic Similarity", "semantic_similarity"),
    ("Perplexity", "geomean_jointprob"),
]


# Dataset key -> human-readable name. The keys are the same set used by
# _FALLBACK_DATASET_FILES.
DATASET_DISPLAY_LABELS = {
    "bar_exam": "Bar Exam Questions",
    "bbq_disamb": "BBQ Disambiguation",
    "causal_judgment": "Causal Judgment",
    "cnn_dailymail": "CNN / DailyMail Summaries",
    "drop": "DROP Reading Comprehension",
    "esnli": "e-SNLI Natural Language Inference",
    "fever": "FEVER Fact Checking",
    "hotpot_qa": "HotpotQA Multi-hop Questions",
    "medical_qa": "Medical Questions",
    "snarks": "Snarks",
}
|
|
import sys
# Repository root: the parent of the directory that contains this file.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root first on sys.path (if absent) so the top-level
# `loader` and `attribution` packages imported further down can be found.
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
|
|
|
|
def _get_results_dir() -> Path:
    """Locate the results directory.

    Resolution order:
    1. ``ATTRLLM_RESULTS_DIR`` if set (returned resolved, existence unchecked);
    2. ``<repo root>/results`` if it exists;
    3. on HF Spaces only, ``<cwd>/results`` if it exists;
    4. otherwise ``<repo root>/results`` even though it does not exist.
    """
    override = os.getenv("ATTRLLM_RESULTS_DIR")
    if override:
        return Path(override).resolve()

    repo_results = (_REPO_ROOT / "results").resolve()
    if not repo_results.exists() and _is_hf_spaces():
        cwd_results = Path.cwd().joinpath("results").resolve()
        if cwd_results.exists():
            return cwd_results
    return repo_results
|
|
|
|
| import gradio as gr |
| from PIL import Image |
|
|
| from .components.model_selector import ( |
| create_model_selector, |
| create_multimodal_model_selector, |
| create_feature_level_selector, |
| create_attribution_method_toggle, |
| ) |
| from .components.example_browser import create_dataset_selector, create_example_browser |
| from .components.results_display import create_results_display, update |
| from .plotting.heatmap import create_interactive_text_heatmap |
| from .plotting.interactions import ( |
| plot_top_interactions, |
| plot_interaction_matrix, |
| create_interaction_token_view, |
| ) |
| from .plotting.text_interactions import create_text_interaction_html |
| from .plotting.mm_interactions import create_multimodal_interaction_html |
| from .plotting.coalition_viewer import compute_coalition_viewer_data, render_coalition_viewer_html |
| from .build_info import BUILD_ID, BUILD_TS |
|
|
| |
| try: |
| from .medical_loader import ( |
| MEDICAL_EXAMPLES, |
| load_medical_example, |
| get_masked_image_path, |
| BENCHMARK_EXAMPLES, |
| get_examples_by_modality, |
| list_available_modalities, |
| load_benchmark_example, |
| extract_segment_regions, |
| ) |
| from .plotting.medical_charts import ( |
| create_shapley_bar_chart, |
| create_influence_heatmap, |
| create_cross_modal_bar_chart, |
| draw_grid_overlay, |
| draw_segment_labels, |
| generate_interpretation_text, |
| rename_patch_labels, |
| align_segments_to_reference, |
| remap_region_values, |
| merge_subword_token_values, |
| _tok_to_word, |
| ) |
| from .plotting.benchmark_interaction import create_benchmark_interaction_html |
| _MEDICAL_AVAILABLE = True |
| except ImportError: |
| _MEDICAL_AVAILABLE = False |
| MEDICAL_EXAMPLES = {} |
| BENCHMARK_EXAMPLES = {} |
|
|
| |
| try: |
| from .mimic_loader import ( |
| MIMIC_EXAMPLES, |
| load_mimic_example, |
| get_mimic_image_path, |
| ) |
| _MIMIC_AVAILABLE = bool(MIMIC_EXAMPLES) |
| except ImportError: |
| _MIMIC_AVAILABLE = False |
| MIMIC_EXAMPLES = {} |
|
|
| |
| try: |
| from .isic_loader import ( |
| ISIC_EXAMPLES, |
| load_isic_example, |
| get_isic_image_path, |
| ) |
| _ISIC_AVAILABLE = bool(ISIC_EXAMPLES) |
| except ImportError: |
| _ISIC_AVAILABLE = False |
| ISIC_EXAMPLES = {} |
|
|
| |
| try: |
| from .coco_loader import COCO_EXAMPLES, load_coco_example, get_coco_masked_image_path |
| _COCO_AVAILABLE = True |
| except ImportError: |
| _COCO_AVAILABLE = False |
| COCO_EXAMPLES = {} |
|
|
| |
# Optional CLIP cross-modal pipeline. On ImportError only the availability
# flag is set; none of the names listed below are defined in that case.
# NOTE(review): run_proxyspex, mobius_to_shapley and mobius_to_banzhaf are
# bound again later in this module by the attribution.proxyspex /
# attribution.utils import group (and set to None when that group fails to
# import), so the set_mm versions imported here are replaced during module
# import -- confirm which implementation the CLIP code paths expect.
try:
    from attribution.set_mm import (
        PipelineConfig,
        CrossModalCLIPScorer,
        ImageRegion,
        TokenPlayer,
        featurise,
        tokenise_caption,
        build_cross_modal_set_function,
        run_proxyspex,
        mobius_to_shapley,
        mobius_to_banzhaf,
        extract_interactions,
        extract_cross_per_token,
        apply_image_mask,
        render_overlay,
        render_segmentation_map,
        mask_token_ids,
    )
    _CLIP_PIPELINE_AVAILABLE = True
except ImportError:
    _CLIP_PIPELINE_AVAILABLE = False



# Module-level cache keyed by string; presumably holds CLIP scorer instances
# (where it is filled is not visible in this part of the file).
_clip_scorer_cache: Dict[str, Any] = {}
|
|
|
|
def _raise_backend_error(resp: requests.Response, label: str) -> None:
    """Raise ``gr.Error`` describing a failed backend response.

    The message contains *label*, the HTTP status code and a detail string:
    the ``detail`` field of the JSON body when it can be read, otherwise the
    raw response text. This function always raises.
    """
    body_text = resp.text
    try:
        detail = resp.json().get("detail", body_text)
    except Exception:
        # Body is not JSON, or not a JSON object: fall back to the raw text.
        detail = body_text
    raise gr.Error(f"{label} failed ({resp.status_code}). {detail}")
|
|
| |
| try: |
| from loader.data import ( |
| get_example_by_id, |
| get_examples, |
| list_datasets, |
| list_datasets_with_display_names, |
| list_dataset_display_names, |
| get_dataset_display_name, |
| get_dataset_key_from_display_name, |
| ) |
| except Exception: |
| get_example_by_id = None |
| get_examples = None |
| list_datasets = None |
| list_datasets_with_display_names = None |
| list_dataset_display_names = None |
| get_dataset_display_name = None |
| get_dataset_key_from_display_name = None |
|
|
| try: |
| from loader.results import get_result_by_id |
| except Exception: |
| get_result_by_id = None |
|
|
| try: |
| from loader.models import get_model |
| except Exception: |
| get_model = None |
|
|
# Optional attribution backend. If anything in this group fails to import,
# every name is set to None so callers can feature-test before use
# (e.g. `mobius_to_shapley is not None`).
# NOTE(review): run_proxyspex, mobius_to_shapley and mobius_to_banzhaf are also
# imported from attribution.set_mm earlier in this module; this group rebinds
# them (or resets them to None on failure) -- confirm that is intended.
try:
    from attribution.masker import get_masker, mask_text
    from attribution.proxyspex import run_proxyspex
    from attribution.image_masker import supports_superpixel
    from attribution.utils import (
        influence_interactions,
        mobius_to_influence,
        mobius_to_shapley,
        shapley_interactions,
        mobius_to_banzhaf,
        banzhaf_interactions,
    )
except Exception:
    get_masker = None
    mask_text = None
    run_proxyspex = None
    supports_superpixel = None
    influence_interactions = None
    mobius_to_influence = None
    mobius_to_shapley = None
    shapley_interactions = None
    mobius_to_banzhaf = None
    banzhaf_interactions = None
|
|
|
|
# Record fields probed, in this priority order, by _extract_answer.
_ANSWER_FIELDS = (
    "correct_answer",
    "answer",
    "target",
    "completion",
    "label",
)
# Values accepted by _normalize_method; anything else becomes "shapley".
_ALLOWED_METHODS = {"shapley", "banzhaf", "influence"}
# Values accepted by _normalize_level; anything else becomes "sentence".
_ALLOWED_LEVELS = {"word", "sentence", "paragraph"}
|
|
|
|
def _ensure_backend(name: str, fn: Optional[Any]):
    """Return *fn* unchanged, or raise RuntimeError naming *name* when it is None.

    Used to guard optional backend callables that are set to None when their
    import failed.
    """
    if fn is not None:
        return fn
    raise RuntimeError(
        f"{name} is unavailable. Ensure the backend modules are installed and importable."
    )
|
|
|
|
def _html_component(label: str, visible: bool = True) -> gr.HTML:
    """Create a ``gr.HTML`` component, requesting ``sanitize_html=False``.

    If the constructor rejects the call with TypeError (e.g. the argument is
    not supported), the component is created again without ``sanitize_html``.
    """
    options = {"label": label, "sanitize_html": False, "visible": visible}
    try:
        return gr.HTML(**options)
    except TypeError:
        del options["sanitize_html"]
        return gr.HTML(**options)
|
|
|
|
def _encode_image_to_b64(image: Image.Image) -> str:
    """Serialize *image* as PNG and return the bytes base64-encoded as text."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
|
|
|
|
def _extract_answer(record: Dict[str, Any]) -> str:
    """Return the first truthy answer-like field of *record* as a string.

    Fields are probed in the order given by ``_ANSWER_FIELDS``. ``""`` is
    returned when none of them holds a truthy value.
    """
    for candidate in map(record.get, _ANSWER_FIELDS):
        if candidate:
            return str(candidate)
    return ""
|
|
|
|
def _coerce_feature_tuple(raw_key: Any) -> Tuple[str, ...]:
    """Turn an interaction key of any supported shape into a tuple of strings.

    - tuple/list: each element is stringified;
    - str: split on the first separator (out of ``·``, ``|``, ``,``, ``×``, in
      that priority) that yields at least one non-blank part; if none does,
      the trimmed string becomes a 1-tuple;
    - anything else: a 1-tuple holding ``str(raw_key)``.
    """
    if isinstance(raw_key, (tuple, list)):
        return tuple(map(str, raw_key))
    if not isinstance(raw_key, str):
        return (str(raw_key),)

    for separator in ("·", "|", ",", "×"):
        if separator not in raw_key:
            continue
        pieces = tuple(
            chunk.strip() for chunk in raw_key.split(separator) if chunk.strip()
        )
        if pieces:
            return pieces
    return (raw_key.strip(),)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
def _normalize_interactions(raw: Any) -> List[Tuple[Tuple[str, ...], float]]:
    """
    Make a best-effort guess at interaction structure.

    Supported shapes:
    - { key: float }
    - { key: {"value": float, "score": ...} }
    - [ (key, float), ... ]
    - [ (key, {"value": float}), ... ]
    - [ {"features": [...], "value": float}, ... ]  (this is mostly handled elsewhere)

    Returns a list of ``(feature_tuple, value)``. Entries are skipped when
    they cannot be unpacked, carry no feature key at all, have a
    non-numeric value, or coerce to an empty feature tuple. Any other input
    type (including None) gives ``[]``.
    """
    if isinstance(raw, dict):
        items: List[Any] = list(raw.items())
    elif isinstance(raw, list):
        items = list(raw)
    else:
        return []

    normalized: List[Tuple[Tuple[str, ...], float]] = []

    for item in items:
        if isinstance(item, dict):
            feats = item.get("features") or item.get("indices") or item.get("pair") or item.get("key")
            val = item.get("value", item.get("score", 0.0))
        else:
            try:
                feats, val = item
            except (TypeError, ValueError):
                continue

        # Without this guard a missing feature key was stringified into a
        # fabricated ("None",) interaction.
        if feats is None:
            continue

        if isinstance(val, dict):
            val = val.get("value", val.get("score", 0.0))

        try:
            numeric = float(val)
        except (TypeError, ValueError, OverflowError):
            continue

        feats_tuple = _coerce_feature_tuple(feats)
        if feats_tuple:
            normalized.append((feats_tuple, numeric))

    return normalized
|
|
def _resolve_marginals(payload: Dict[str, Any]) -> Dict[str, float]:
    """Return per-feature scores from the first dict-valued score field.

    Fields are probed in the order ``marginals``, ``token_scores``, ``values``,
    ``scores``. Keys are stringified and values converted to float; entries
    that cannot be converted are dropped. ``{}`` is returned when no probed
    field is a dict.
    """
    probed = (
        payload.get(field)
        for field in ("marginals", "token_scores", "values", "scores")
    )
    table = next((value for value in probed if isinstance(value, dict)), None)
    if table is None:
        return {}

    scores: Dict[str, float] = {}
    for name, score in table.items():
        try:
            scores[str(name)] = float(score)
        except Exception:
            continue
    return scores
|
|
|
|
def _resolve_features(payload: Dict[str, Any], marginals: Dict[str, float]) -> List[str]:
    """Return feature names: the payload's ``features`` list, else the marginal keys.

    Elements of a ``features`` list are stringified. When ``features`` is not
    a list, the keys of *marginals* are used; ``[]`` if that is empty too.
    """
    listed = payload.get("features")
    if isinstance(listed, list):
        return list(map(str, listed))
    return list(marginals.keys()) if marginals else []
|
|
|
|
# Fields that may hold an interaction's numeric value, in lookup priority.
_INTERACTION_VALUE_KEYS = ("value", "score", "attribution", "weight")
# Runs of digits inside a dict key, interpreted as feature indices.
_DIGIT_RUN_RE = re.compile(r"\d+")


def _first_numeric_field(mapping: Dict[str, Any]) -> Optional[float]:
    """Return the first value-like field of *mapping* that converts to float, else None."""
    for field in _INTERACTION_VALUE_KEYS:
        if field in mapping:
            try:
                return float(mapping[field])
            except (TypeError, ValueError, OverflowError):
                continue
    return None


def _extract_interactions_from_response(
    data_int: Dict[str, Any],
    method: str,
    features: List[str],
) -> List[Tuple[Tuple[str, ...], float]]:
    """Pull ``(feature_names, value)`` interactions out of a backend response.

    The response is searched under ``data_int[method]`` (falling back to
    ``data_int`` itself), then under the first present key out of
    ``interactions``, ``pairwise_interactions``, ``interactions_2``,
    ``pairwise``, ``data``. Three shapes are recognised, in this order:

    1. a list of dicts (features under ``feature_list``/``features``/
       ``indices``/``pair``) or a list of ``(features, value)`` pairs;
    2. a dict mapping a key to a number or to a dict holding a value field;
       digits found in the key are treated as indices into *features*;
    3. as a last resort, every number found anywhere in the structure, labelled
       by its path.

    Integer feature references are translated through *features*;
    out-of-range indices are dropped. Malformed entries are skipped.
    """
    inter_list: List[Tuple[Tuple[str, ...], float]] = []

    method_key = (method or "shapley").lower()
    method_block = data_int.get(method_key) or data_int

    raw_interactions = None
    if isinstance(method_block, dict):
        for key in ("interactions", "pairwise_interactions", "interactions_2", "pairwise", "data"):
            if key in method_block:
                raw_interactions = method_block.get(key)
                break
    if raw_interactions is None:
        raw_interactions = method_block

    n_features = len(features)

    # Shape 1: list of dicts, or list of (features, value) pairs.
    if isinstance(raw_interactions, list) and raw_interactions:
        first = raw_interactions[0]
        if isinstance(first, dict):
            for item in raw_interactions:
                # Only the first element was type-checked originally; a stray
                # non-dict entry used to raise AttributeError.
                if not isinstance(item, dict):
                    continue
                feats = (
                    item.get("feature_list")
                    or item.get("features")
                    or item.get("indices")
                    or item.get("pair")
                    or []
                )
                val = _first_numeric_field(item)
                if val is None:
                    continue

                if isinstance(feats, list) and feats and isinstance(feats[0], int):
                    feat_names = tuple(
                        features[i] for i in feats
                        if isinstance(i, int) and 0 <= i < n_features
                    )
                else:
                    feat_names = _coerce_feature_tuple(feats)

                if feat_names:
                    inter_list.append((feat_names, val))
        elif isinstance(first, (list, tuple)) and len(first) == 2:
            for item in raw_interactions:
                if not isinstance(item, (list, tuple)) or len(item) != 2:
                    continue
                feats_raw, val_raw = item
                try:
                    val = float(val_raw)
                except Exception:
                    continue

                feat_names: Tuple[str, ...] = ()
                if isinstance(feats_raw, (list, tuple)) and feats_raw:
                    if all(isinstance(i, int) for i in feats_raw):
                        feat_names = tuple(
                            features[i] for i in feats_raw
                            if 0 <= i < n_features
                        )
                    else:
                        feat_names = _coerce_feature_tuple(feats_raw)
                elif isinstance(feats_raw, str):
                    feat_names = _coerce_feature_tuple(feats_raw)

                if feat_names:
                    inter_list.append((feat_names, val))

    # Shape 2: dict of key -> number / key -> {"value": ...}.
    if not inter_list and isinstance(raw_interactions, dict):
        metadata_keys = {"method", "order", "scalarizer", "embedding_model"}
        for k, v in raw_interactions.items():
            k_str = str(k)
            if k_str in metadata_keys:
                continue
            val = None
            if isinstance(v, (int, float)):
                val = float(v)
            elif isinstance(v, dict):
                val = _first_numeric_field(v)
            if val is None:
                continue

            try:
                idxs = [int(x) for x in _DIGIT_RUN_RE.findall(k_str)]
            except ValueError:
                idxs = []

            names = [features[idx] for idx in idxs if 0 <= idx < n_features]
            # No usable index in the key: fall back to splitting the key text.
            feat_names = tuple(names) if names else _coerce_feature_tuple(k_str)

            inter_list.append((feat_names, val))

    # Shape 3: collect every number in the structure, labelled by its path.
    if not inter_list and raw_interactions is not None:
        flat: List[Tuple[Tuple[str, ...], float]] = []

        def _collect(obj: Any, prefix: Tuple[str, ...] = ()) -> None:
            if isinstance(obj, (int, float)):
                flat.append((prefix or ("<interaction>",), float(obj)))
            elif isinstance(obj, list):
                for i, item in enumerate(obj):
                    _collect(item, prefix + (f"[{i}]",))
            elif isinstance(obj, dict):
                for kk, vv in obj.items():
                    _collect(vv, prefix + (str(kk),))

        _collect(raw_interactions)
        inter_list = flat

    return inter_list
|
|
|
|
def _labels_from_regions(regions: List[Dict[str, Any]]) -> List[str]:
    """Build one display label per region slot.

    Each region places its ``label`` at its ``index`` (default 0). Regions
    with an unreadable or out-of-range index are ignored; later regions
    overwrite earlier ones at the same index. Slots left without a label get
    ``"Region <n>"`` with 1-based numbering.
    """
    slot_count = len(regions)
    assigned = {}
    for region in regions:
        try:
            slot = int(region.get("index", 0))
        except Exception:
            continue
        if 0 <= slot < slot_count:
            assigned[slot] = str(region.get("label") or f"Region {slot + 1}")
    return [assigned.get(slot) or f"Region {slot + 1}" for slot in range(slot_count)]
|
|
|
|
def _interaction_dicts_to_pairs(
    interactions: List[Dict[str, Any]],
    labels: List[str],
    *,
    order: Optional[int] = None,
) -> List[Tuple[Tuple[str, ...], float]]:
    """Convert index-based interaction dicts into ``(labels, value)`` pairs.

    Args:
        interactions: Dicts with ``indices`` (positions into *labels*) and
            ``value`` (defaults to 0.0 when absent).
        labels: Display label for each index.
        order: When given, only interactions with exactly this many indices
            are kept.

    Interactions with no indices, a non-numeric value, or a non-integer index
    are skipped. Indices outside ``[0, len(labels))`` are dropped from an
    interaction; negative indices no longer wrap around to the end of
    *labels*. An interaction left with no labels is omitted.
    """
    pairs: List[Tuple[Tuple[str, ...], float]] = []
    n_labels = len(labels)
    for item in interactions:
        indices = item.get("indices")
        if not indices:
            continue
        if order is not None and len(indices) != order:
            continue
        try:
            value = float(item.get("value", 0.0))
            positions = [int(i) for i in indices]
        except (TypeError, ValueError, OverflowError):
            continue
        feats = tuple(labels[p] for p in positions if 0 <= p < n_labels)
        if feats:
            pairs.append((feats, value))
    return pairs
|
|
|
|
def _interaction_dicts_to_table(
    interactions: List[Dict[str, Any]],
    labels: List[str],
) -> List[List[Any]]:
    """Convert index-based interaction dicts into table rows.

    Each row is ``[joined labels, value, number of indices]`` with the labels
    joined by ``" × "``. The index count refers to the original ``indices``
    list, including any indices that were dropped as out of range.

    Interactions with no indices, a non-numeric value, or a non-integer index
    are skipped. Indices outside ``[0, len(labels))`` are dropped; negative
    indices no longer wrap around to the end of *labels*. An interaction left
    with no labels produces no row.
    """
    rows: List[List[Any]] = []
    n_labels = len(labels)
    for item in interactions:
        indices = item.get("indices")
        if not indices:
            continue
        try:
            value = float(item.get("value", 0.0))
            positions = [int(i) for i in indices]
        except (TypeError, ValueError, OverflowError):
            continue
        feats = [labels[p] for p in positions if 0 <= p < n_labels]
        if feats:
            rows.append([" × ".join(feats), value, len(indices)])
    return rows
|
|
|
|
def _feature_display_label(
    feature: Dict[str, Any],
    region_labels: List[str],
) -> str:
    """Return the label to display for one feature entry.

    Image features (``modality == "image"``) use ``region_labels[ref_index]``
    when that index is valid. Otherwise the text after the first ``":"`` of
    the ``feature`` field is used, or the whole field when that text is empty
    or there is no ``":"``.

    ``ref_index`` is only read for image features, and a null or non-numeric
    value falls back to the text label instead of raising.
    """
    raw = str(feature.get("feature", ""))
    label = raw.split(":", 1)[1] if ":" in raw else raw

    if (feature.get("modality") or "") == "image":
        try:
            ref_index = int(feature.get("ref_index", 0))
        except (TypeError, ValueError):
            ref_index = -1  # forces the text fallback below
        if 0 <= ref_index < len(region_labels):
            return region_labels[ref_index]
    return label or raw
|
|
|
|
def _extract_feature_series(payload: Dict[str, Any]) -> Tuple[List[str], List[float]]:
    """
    Recover parallel lists of (feature labels, values) from a backend payload.

    Sources are tried in order until one yields features:
    1. ``features`` as a list of dicts (label from ``feature``/``token``/
       ``text``/``label``, value from ``value`` or else ``score``/
       ``attribution``/``weight``);
    2. ``heatmap`` with equally long ``tokens``/``features`` and
       ``values``/``scores`` lists;
    3. the marginals found by ``_resolve_marginals``.

    Values that cannot be converted become 0.0. Order and duplicates are
    kept; labels are made unique via ``_assign_unique_labels`` so repeated
    word-level tokens do not collapse into one entry. Returns ``([], [])``
    when nothing is found.
    """

    def _as_float(candidate: Any) -> float:
        try:
            return float(candidate)
        except Exception:
            return 0.0

    labels: List[str] = []
    series: List[float] = []

    entries = payload.get("features")
    if isinstance(entries, list) and entries and isinstance(entries[0], dict):
        for position, entry in enumerate(entries, start=1):
            name = (
                entry.get("feature")
                or entry.get("token")
                or entry.get("text")
                or entry.get("label")
                or f"feature_{position}"
            )
            raw_value = entry.get("value")
            if raw_value is None:
                # Use the first alternative field that is present at all.
                alt_field = next(
                    (field for field in ("score", "attribution", "weight") if field in entry),
                    None,
                )
                if alt_field is not None:
                    raw_value = entry[alt_field]
            series.append(_as_float(0.0 if raw_value is None else raw_value))
            labels.append(str(name))

    if not labels:
        heat = payload.get("heatmap") or {}
        tokens = heat.get("tokens") or heat.get("features")
        heat_scores = heat.get("values") or heat.get("scores")
        same_length_lists = (
            isinstance(tokens, list)
            and isinstance(heat_scores, list)
            and len(tokens) == len(heat_scores)
        )
        if same_length_lists:
            labels = [
                str(token if token is not None else f"feature_{idx + 1}")
                for idx, token in enumerate(tokens)
            ]
            series = [_as_float(score) for score in heat_scores]

    if not labels:
        marginals = _resolve_marginals(payload)
        if marginals:
            labels = list(marginals.keys())
            series = [float(marginals[name]) for name in labels]

    if not labels:
        return [], []

    return _assign_unique_labels(labels), series
|
|
|
|
def _resolve_interactions(payload: Dict[str, Any], order: int) -> List[Tuple[Tuple[str, ...], float]]:
    """Return the interactions of the given *order* stored in *payload*.

    ``interactions_<order>`` is checked first, followed by order-specific
    aliases for orders 2 and 3. The first key whose content normalizes to a
    non-empty list wins; ``[]`` is returned otherwise.
    """
    aliases: Tuple[str, ...] = ()
    if order == 2:
        aliases = ("pairwise", "pairwise_interactions", "interactions2")
    elif order == 3:
        aliases = ("higher_order", "triple_interactions", "interactions3")

    for key in (f"interactions_{order}", *aliases):
        found = _normalize_interactions(payload.get(key))
        if found:
            return found
    return []
|
|
|
|
def _fallback_pairwise_from_values(
    features: List[str],
    values: List[float],
    max_edges: int = 40,
) -> List[Tuple[Tuple[str, ...], float]]:
    """
    Build synthetic pairwise links between neighbouring tokens.

    Each adjacent pair is weighted by the mean of its two values. Links are
    ranked by absolute weight (largest first) and cut to *max_edges*. Used
    when the backend provides no explicit interactions. Fewer than two
    aligned features/values give ``[]``.
    """
    usable = min(len(features), len(values))
    if usable < 2:
        return []
    links: List[Tuple[Tuple[str, ...], float]] = [
        ((features[i], features[i + 1]), 0.5 * (values[i] + values[i + 1]))
        for i in range(usable - 1)
    ]
    ranked = sorted(links, key=lambda link: abs(link[1]), reverse=True)
    return ranked[:max_edges]
|
|
|
|
def _resolve_pairwise(
    payload: Dict[str, Any],
    features: Optional[List[str]] = None,
    feature_values: Optional[List[float]] = None,
) -> List[Tuple[Tuple[str, ...], float]]:
    """Return order-2 interactions for ``payload``, synthesizing them if needed.

    Resolution order:
      1. Keys dedicated to pairwise data (via ``_resolve_interactions``).
      2. Order-2 entries filtered out of the mixed ``"interactions"`` list.
      3. Neighbour edges synthesized from ``features`` / ``feature_values``.

    Args:
        payload: Result dictionary to search.
        features: Optional feature labels used for the synthetic fallback.
        feature_values: Optional per-feature scores used for the fallback.

    Returns:
        A list of ``((a, b), value)`` tuples; empty when nothing is available
        and no fallback inputs were supplied.
    """
    pairwise = _resolve_interactions(payload, 2)
    if pairwise:
        return pairwise

    mixed = _normalize_interactions(payload.get("interactions"))
    pairs_only = [item for item in mixed if len(item[0]) == 2] if mixed else []
    if pairs_only:
        return pairs_only

    # A mixed list that holds only higher-order entries must not suppress the
    # fallback: previously that case returned [] even when fallback inputs
    # were available.
    if features and feature_values:
        return _fallback_pairwise_from_values(features, feature_values)
    return []
|
|
|
|
def _normalize_method(method: Optional[str]) -> str:
    """Lower-case ``method`` and fall back to ``"shapley"`` when it is empty
    or not listed in ``_ALLOWED_METHODS``."""
    candidate = method.lower() if method else "shapley"
    if candidate in _ALLOWED_METHODS:
        return candidate
    return "shapley"
|
|
|
|
def _normalize_level(level: Optional[str]) -> str:
    """Lower-case ``level`` and fall back to ``"sentence"`` when it is empty
    or not listed in ``_ALLOWED_LEVELS``."""
    candidate = level.lower() if level else "sentence"
    if candidate in _ALLOWED_LEVELS:
        return candidate
    return "sentence"
|
|
|
|
| def _normalize_model_size(model_size: Optional[str]) -> str: |
| raw = (model_size or "small").strip() |
| lowered = raw.lower() |
| if lowered in {"small", "medium", "large"}: |
| return lowered |
| if "small" in lowered: |
| return "small" |
| if "medium" in lowered: |
| return "medium" |
| if "large" in lowered: |
| return "large" |
| return "small" |
|
|
|
|
| def _assign_unique_labels(chunks: List[str]) -> List[str]: |
| counts: Dict[str, int] = {} |
| labels: List[str] = [] |
| for idx, chunk in enumerate(chunks): |
| normalized = " ".join((chunk or "").split()) |
| if not normalized: |
| normalized = f"<chunk {idx + 1}>" |
| counts[normalized] = counts.get(normalized, 0) + 1 |
| suffix = f" ({counts[normalized]})" if counts[normalized] > 1 else "" |
| labels.append(f"{normalized}{suffix}") |
| return labels |
|
|
|
|
| def _strip_occurrence_suffix(text: str) -> str: |
| text = text or "" |
| if text.endswith(")") and " (" in text: |
| base, _, tail = text.rpartition(" (") |
| if tail[:-1].isdigit(): |
| return base |
| return text |
|
|
|
|
def _pairwise_to_index_interactions(
    pairwise: List[Tuple[Tuple[str, ...], float]],
    features: List[str],
) -> List[Dict[str, Any]]:
    """Convert label- or index-based pairs into index-based interaction dicts.

    Each pair member is resolved to a position in ``features``:

    * numeric members (``int``/``float``) are used as indices directly;
    * strings that parse as integers are treated as indices when both land
      inside ``features``; otherwise they are retried as labels, so
      numeric-looking tokens such as ``"2021"`` are not lost;
    * any other value is matched against the exact label first, then against
      the label with its ``" (n)"`` occurrence suffix stripped (first
      occurrence wins).

    Pairs that are not of length two, cannot be resolved, or fall outside
    ``features`` are skipped.

    Args:
        pairwise: ``((a, b), value)`` tuples.
        features: Ordered feature labels that indices refer to.

    Returns:
        A list of ``{"indices": [i, j], "value": float}`` dictionaries.
    """
    feature_count = len(features)
    feature_index = {feat: idx for idx, feat in enumerate(features)}
    base_index: Dict[str, int] = {}
    for idx, feat in enumerate(features):
        base_index.setdefault(_strip_occurrence_suffix(feat), idx)

    def _lookup(label: Any) -> Optional[int]:
        # Explicit None test: index 0 is a valid hit and must not be treated
        # as "missing" the way ``a or b`` would.
        found = feature_index.get(label)
        if found is None:
            found = base_index.get(_strip_occurrence_suffix(str(label)))
        return found

    def _in_range(position: int) -> bool:
        return 0 <= position < feature_count

    interactions: List[Dict[str, Any]] = []
    for feats, val in pairwise:
        if len(feats) != 2:
            continue
        a, b = feats
        if isinstance(a, (int, float)) and isinstance(b, (int, float)):
            try:
                a_idx, b_idx = int(a), int(b)
            except (ValueError, OverflowError):
                # int() rejects NaN / infinity; such a pair has no position.
                continue
        else:
            try:
                a_idx, b_idx = int(str(a)), int(str(b))
            except ValueError:
                a_idx, b_idx = _lookup(a), _lookup(b)
            else:
                if not (_in_range(a_idx) and _in_range(b_idx)):
                    # Not usable as indices: the strings may be numeric labels.
                    a_idx, b_idx = _lookup(a), _lookup(b)
        if a_idx is None or b_idx is None:
            continue
        if not (_in_range(a_idx) and _in_range(b_idx)):
            continue
        interactions.append({"indices": [a_idx, b_idx], "value": float(val)})
    return interactions
|
|
|
|
| def _locate_spans(text: str, segments: List[str]) -> List[Tuple[int, int]]: |
| spans: List[Tuple[int, int]] = [] |
| cursor = 0 |
| for segment in segments: |
| if not segment: |
| continue |
| idx = text.find(segment, cursor) |
| if idx == -1: |
| idx = cursor |
| end = idx + len(segment) |
| spans.append((idx, end)) |
| cursor = end |
| return spans |
|
|
|
|
def _chunk_text_for_visualization(
    context: str,
    level: str,
) -> Tuple[List[str], List[Tuple[int, int]], str]:
    """Split text into labelled chunks plus their character spans.

    ``_DEMO_TEXT`` is used when ``context`` is empty. After ``level`` is
    normalized, ``"word"`` splits on runs of non-whitespace, ``"paragraph"``
    on blank lines, and every other level on whitespace following ``.``,
    ``!`` or ``?``. If nothing is produced, the whole text becomes one chunk.

    Returns:
        ``(labels, spans, text)``: unique labels from
        ``_assign_unique_labels``, their ``(start, end)`` offsets, and the
        text that was actually chunked.
    """
    text = context if context else _DEMO_TEXT
    granularity = _normalize_level(level)

    if granularity == "word":
        chunks: List[str] = []
        spans: List[Tuple[int, int]] = []
        for hit in re.finditer(r"\S+", text):
            chunks.append(hit.group(0))
            spans.append((hit.start(), hit.end()))
    else:
        splitter = r"\n\s*\n+" if granularity == "paragraph" else r"(?<=[.!?])\s+"
        pieces = [seg for seg in re.split(splitter, text) if seg.strip()]
        spans = _locate_spans(text, pieces)
        chunks = pieces[: len(spans)]

    if not chunks:
        chunks, spans = [text], [(0, len(text))]

    return _assign_unique_labels(chunks), spans, text
|
|
|
|
| def _generate_synthetic_marginals( |
| features: List[str], |
| rng: random.Random, |
| ) -> Dict[str, float]: |
| if not features: |
| return {} |
| max_len = max(len(f) for f in features) or 1 |
| marginals: Dict[str, float] = {} |
| denom = max(1, len(features) - 1) |
| for idx, feat in enumerate(features): |
| length_factor = len(feat) / max_len |
| position_factor = 1 - (idx / denom if denom else 0) |
| noise = rng.uniform(-0.25, 0.25) |
| value = (length_factor - 0.5) * 0.6 + (position_factor - 0.5) * 0.4 + noise |
| marginals[feat] = round(value, 4) |
| return marginals |
|
|
|
|
| def _generate_synthetic_interactions( |
| features: List[str], |
| marginals: Dict[str, float], |
| rng: random.Random, |
| ) -> Dict[int, List[Tuple[Tuple[str, ...], float]]]: |
| interactions: Dict[int, List[Tuple[Tuple[str, ...], float]]] = {2: [], 3: []} |
| for i in range(len(features) - 1): |
| pair = (features[i], features[i + 1]) |
| base = (marginals.get(pair[0], 0.0) + marginals.get(pair[1], 0.0)) / 2 |
| interactions[2].append((pair, round(base + rng.uniform(-0.1, 0.1), 4))) |
| for i in range(len(features) - 2): |
| triple = (features[i], features[i + 1], features[i + 2]) |
| base = sum(marginals.get(feat, 0.0) for feat in triple) / 3 |
| interactions[3].append((triple, round(base + rng.uniform(-0.1, 0.1), 4))) |
| return interactions |
|
|
|
|
def _synthetic_attribution_pipeline(
    context: str,
    prompt: str,
    answer: str,
    *,
    method: str,
    level: str,
    order: int,
    reason: Optional[str] = None,
) -> Tuple[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]:
    """Render mock attributions when no real backend result is available.

    The first non-empty value of ``context``, ``prompt``, ``answer`` or
    ``_DEMO_TEXT`` is chunked at ``level``. Marginals and neighbour
    interactions are then generated from a random generator seeded by the
    inputs, so the same request yields the same mock data.

    Args:
        context: Preferred text to chunk.
        prompt: Used when ``context`` is empty.
        answer: Used when both are empty; also echoed as the reference answer.
        method: Attribution method name, forwarded to the plotting helpers.
        level: Feature granularity, forwarded to the chunker.
        order: Interaction order selecting which synthetic list is plotted.
        reason: Optional explanation stored in the metadata.

    Returns:
        Whatever ``update(...)`` returns for the assembled figures, HTML and
        metadata.
    """
    # Local import keeps this block self-contained; zlib is standard library.
    import zlib

    text_source = context or prompt or answer or _DEMO_TEXT
    features, spans, text = _chunk_text_for_visualization(text_source, level)

    # Built-in hash() of strings is salted per interpreter process, so it
    # would give different mock data after every restart. CRC32 of the repr
    # is stable across processes and platforms.
    seed_material = repr((text_source, method, level, order)).encode("utf-8")
    seed = zlib.crc32(seed_material) & 0xFFFFFFFF
    rng = random.Random(seed)
    marginals = _generate_synthetic_marginals(features, rng)
    interactions = _generate_synthetic_interactions(features, marginals, rng)

    # Per-feature values in display order, shared by every renderer below.
    feature_values = [marginals.get(f, 0.0) for f in features]

    html = None
    if len(spans) == len(features):
        html = create_interactive_text_heatmap(
            text,
            spans,
            list(feature_values),
            method=method,
        )

    meta = {
        "mode": "synthetic",
        "reason": reason or "Attribution backend unavailable; showing mock data.",
        "method": method,
        "feature_level": level,
        "interaction_order": order,
        "feature_count": len(features),
    }

    inter_list = interactions.get(order, [])
    # The token view always shows pairwise links, whatever order is plotted.
    pairwise_for_tokens = interactions.get(2, []) if order != 2 else inter_list
    if not pairwise_for_tokens:
        pairwise_for_tokens = _fallback_pairwise_from_values(
            features,
            list(feature_values),
        )
    text_interaction_html = create_text_interaction_html(
        features,
        list(feature_values),
        _pairwise_to_index_interactions(pairwise_for_tokens, features),
        method=method,
        top_k=20,
        threshold=0.0,
    )
    figs = {
        "interactions": plot_top_interactions(inter_list, order=order, method=method),
    }

    return update(
        figs=figs,
        meta=meta,
        html=html,
        interaction_text_html=text_interaction_html,
        scoring_target_source="answer_input" if answer else "model_output",
        scoring_target_text=answer or "",
        reference_answer=answer or "",
        unmasked_answer="",
        debug_scores=None,
        scalarizer_used="logprob",
        score_full=None,
        score_empty=None,
        y_len_tokens=None,
    )
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
def _compute_live_attributions(
    *,
    context: str,
    prompt: str,
    correct_answer: str,
    model_size: str,
    level: str,
    method: str,
    order: int,
    scalarizer: str = "logprob",
    embedding_model: str | None = None,
    progress=None,
) -> Tuple[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]:
    """
    Call the backend /api/attributions and /api/interactions endpoints and
    turn the JSON responses into figures / HTML / metadata for the UI.

    Steps, in order:
      1. Normalize ``method``/``level`` and clamp ``order`` to 2 or 3.
      2. POST the request to ``/api/attributions`` and extract the feature
         labels and values (a single placeholder feature is used if none).
      3. POST the same payload to ``/api/interactions``; keep only entries of
         the requested order when any exist.
      4. Make sure pairwise links exist for the token views: re-query with
         ``order=2`` if needed, then fall back to neighbour links (skipped
         for the ``"influence"`` method).
      5. Rescale very small values, render the visualizations and pack
         everything through ``update(...)``.

    Args:
        context, prompt, correct_answer: Request texts; ``None`` becomes "".
        model_size, scalarizer, embedding_model: Forwarded to the backend.
        level: Feature granularity, sent as ``mask_level``.
        method: Attribution method name.
        order: Interaction order; values >= 3 become 3, everything else 2.
        progress: Optional callable ``progress(fraction, desc=...)``.

    Returns:
        Whatever ``update(...)`` returns.

    Raises:
        gr.Error: When a backend request hits a read timeout. HTTP status
            codes >= 400 are passed to ``_raise_backend_error``.
    """
    method = _normalize_method(method)
    level = _normalize_level(level)
    order = 3 if int(order or 2) >= 3 else 2

    context = context or ""
    prompt = prompt or ""
    correct_answer = correct_answer or ""
    text_source = context or prompt or correct_answer or _DEMO_TEXT

    # The same payload is sent to both endpoints.
    payload = {
        "context": context,
        "answer": correct_answer,
        "reference_answer": correct_answer,
        "prompt": prompt,
        "method": method,
        "mask_level": level,
        "order": int(order),
        "model_size": model_size,
        "scalarizer": scalarizer,
        "embedding_model": embedding_model,
        "debug": False,
    }

    if progress is not None:
        progress(0.1, desc="Calling attribution backend")

    # --- 1) attribution request ---
    url_attr = BACKEND_URL.rstrip("/") + "/api/attributions"
    try:
        resp_attr = requests.post(url_attr, json=payload, timeout=REQUEST_TIMEOUT)
    # Only read timeouts are translated; other request errors propagate as-is.
    except requests.exceptions.ReadTimeout as exc:
        raise gr.Error(
            "Attribution request timed out. The backend may still be running. "
            "Consider reducing feature granularity or set ATTRLLM_REQUEST_TIMEOUT to a higher value."
        ) from exc
    if resp_attr.status_code >= 400:
        _raise_backend_error(resp_attr, "Attribution request")
    data_attr = resp_attr.json()

    if progress is not None:
        progress(0.35, desc="Received attribution payload")

    # --- 2) per-feature values ---
    features, feature_values = _extract_feature_series(data_attr)
    if not features:
        # Placeholder so the downstream renderers always get one feature.
        features = ["<no features>"]
        feature_values = [0.0]
    marginals = {feat: float(feature_values[idx]) for idx, feat in enumerate(features)}

    # --- 3) interaction request ---
    if progress is not None:
        progress(0.45, desc="Calling interactions backend")

    url_int = BACKEND_URL.rstrip("/") + "/api/interactions"
    try:
        resp_int = requests.post(url_int, json=payload, timeout=REQUEST_TIMEOUT)
    except requests.exceptions.ReadTimeout as exc:
        raise gr.Error(
            "Interaction request timed out. The backend may still be running. "
            "Consider reducing order or set ATTRLLM_REQUEST_TIMEOUT to a higher value."
        ) from exc
    if resp_int.status_code >= 400:
        _raise_backend_error(resp_int, "Interaction request")
    data_int = resp_int.json()

    # Debug output on stdout.
    print("data_int keys:", list(data_int.keys()))

    inter_list_all = _extract_interactions_from_response(data_int, method, features)
    pairwise_for_network = [item for item in inter_list_all if len(item[0]) == 2]
    used_pairwise_fallback = False

    # Keep only interactions of the requested order; if none match, the
    # unfiltered list is kept.
    inter_list = inter_list_all
    if inter_list:
        filtered: List[Tuple[Tuple[str, ...], float]] = []
        for feats, val in inter_list:
            if len(feats) == order:
                filtered.append((feats, val))
        if filtered:
            inter_list = filtered

    # The token views need pairwise links even when a higher order was asked
    # for, so re-query the backend with order=2. This is best effort: any
    # failure, including the gr.Error raised on timeout inside, is caught by
    # the outer handler and only printed.
    if order != 2 and not pairwise_for_network:
        try:
            payload_pair = dict(payload)
            payload_pair["order"] = 2
            try:
                resp_pair = requests.post(url_int, json=payload_pair, timeout=REQUEST_TIMEOUT)
            except requests.exceptions.ReadTimeout as exc:
                raise gr.Error(
                    "Interaction request timed out. The backend may still be running. "
                    "Consider reducing order or set ATTRLLM_REQUEST_TIMEOUT to a higher value."
                ) from exc
            if resp_pair.status_code >= 400:
                _raise_backend_error(resp_pair, "Interaction request")
            data_pair = resp_pair.json()
            pairwise_for_network = [
                item for item in _extract_interactions_from_response(data_pair, method, features)
                if len(item[0]) == 2
            ]
        except Exception as exc:
            print("Pairwise interaction fetch failed:", exc)
    if not pairwise_for_network:
        if method == "influence":
            # No synthetic neighbour links for the influence method.
            pairwise_for_network = []
        else:
            used_pairwise_fallback = True
            pairwise_for_network = _fallback_pairwise_from_values(features, feature_values)

    print("LIVE features:", features)
    print("LIVE inter_list (first 3):", inter_list[:3])
    if method == "influence":
        # Extra diagnostics for the influence method (top 10 by magnitude).
        top_singletons = sorted(
            list(zip(features, feature_values)),
            key=lambda kv: abs(float(kv[1])),
            reverse=True,
        )[:10]
        top_pairs = sorted(
            pairwise_for_network,
            key=lambda kv: abs(float(kv[1])),
            reverse=True,
        )[:10]
        print(
            "[influence-ui-debug] "
            f"pairwise_source={'fallback_neighbors' if used_pairwise_fallback else 'backend'} "
            f"feature_count={len(features)} pair_count={len(pairwise_for_network)}"
        , flush=True)
        print("[influence-ui-debug] top_singletons:", top_singletons, flush=True)
        print("[influence-ui-debug] top_pairwise:", top_pairs, flush=True)

    # Built BEFORE the rescaling below, so it receives unscaled values.
    text_interaction_html = create_text_interaction_html(
        features,
        feature_values,
        _pairwise_to_index_interactions(pairwise_for_network, features),
        method=method,
        top_k=20,
        threshold=0.0,
    )

    # --- 4) rescale tiny values: if the largest |marginal| is below 1e-3,
    # multiply by 1e3 ---
    max_abs = max((abs(v) for v in marginals.values()), default=0.0)
    scale = 1.0
    if 0 < max_abs < 1e-3:
        scale = 1e3
    if scale != 1.0:
        # NOTE(review): ``marginals`` is not read again after this point, and
        # ``pairwise_for_network`` is NOT rescaled although it is later passed
        # to create_interaction_token_view together with the rescaled
        # ``feature_values`` -- confirm whether that mismatch is intended.
        marginals = {k: v * scale for k, v in marginals.items()}
        inter_list = [(feats, val * scale) for feats, val in inter_list]
        feature_values = [val * scale for val in feature_values]

    # --- 5) text heatmap, only when the backend supplied one span per value ---
    spans = None
    masking = data_attr.get("masking") or data_attr.get("mask") or {}
    if isinstance(masking, dict):
        spans = masking.get("feature_spans") or masking.get("spans")

    html = None
    if spans and len(spans) == len(feature_values):
        html = create_interactive_text_heatmap(
            context or text_source,
            spans,
            feature_values,
            method=method,
        )

    # --- 6) interaction bar plot ---
    inter_fig = plot_top_interactions(inter_list, order=order, method=method)

    if progress is not None:
        progress(0.8, desc="Rendering visualizations")

    y_len_tokens = data_attr.get("y_len_tokens")
    scoring_target_source = data_attr.get("scoring_target_source") or "model_output"
    scoring_target_text = data_attr.get("scoring_target_text")
    if scoring_target_text is None:
        # Fall back to the supplied answer, then to the backend's ``y_full``.
        scoring_target_text = correct_answer or data_attr.get("y_full") or ""

    # Diagnostic metadata; most entries are copied straight from the response.
    meta = {
        "mode": "live",
        "backend_url_attr": url_attr,
        "backend_url_int": url_int,
        "method": method,
        "feature_level": level,
        "interaction_order": order,
        "model_size": model_size,
        "feature_count": len(features),
        "max_abs_value": max_abs,
        "scale_applied": scale,
        "scalarizer": data_attr.get("scalarizer_used", payload.get("scalarizer")),
        "scoring_target_source": scoring_target_source,
        "scoring_target_text_preview": str(scoring_target_text)[:200],
        "score_full": data_attr.get("score_full"),
        "score_empty": data_attr.get("score_empty"),
        "y_len_tokens": y_len_tokens,
        "logprob_full": data_attr.get("logprob_full"),
        "logprob_empty": data_attr.get("logprob_empty"),
        "min_logprob_seen": data_attr.get("min_logprob_seen"),
        "reference_answer_received": data_attr.get("reference_answer_received"),
        "answer_received": data_attr.get("answer_received"),
        "raw_attr_keys": list(data_attr.keys()),
        "raw_int_keys": list(data_int.keys()),
    }
    reference_answer = correct_answer
    unmasked_answer = data_attr.get("y_full") or data_attr.get("unmasked_answer") or ""
    debug_scores = data_attr.get("debug_scores") or None

    # Uses the (possibly rescaled) feature values with the unscaled pairs.
    interaction_chips_html = create_interaction_token_view(
        features,
        feature_values,
        pairwise_for_network,
        method=method,
        layout="sentence" if level == "sentence" else "token",
    )
    figs = {
        "interactions": inter_fig,
    }

    if progress is not None:
        progress(1.0, desc="Done")

    return update(
        figs=figs,
        meta=meta,
        html=html,
        interaction_html=interaction_chips_html,
        interaction_text_html=text_interaction_html,
        scoring_target_source=scoring_target_source,
        scoring_target_text=str(scoring_target_text),
        reference_answer=reference_answer,
        unmasked_answer=unmasked_answer,
        debug_scores=debug_scores,
        scalarizer_used=data_attr.get("scalarizer_used", payload.get("scalarizer")),
        score_full=data_attr.get("score_full"),
        score_empty=data_attr.get("score_empty"),
        y_len_tokens=y_len_tokens,
    )
|
|
|
|
| |
| |
| |
|
|
# Maps the model name shown in the UI to the model identifier handed to the
# CLIP pipeline. Names missing from this map are passed through unchanged by
# the lookups below (``_CLIP_MODEL_MAP.get(name, name)``).
_CLIP_MODEL_MAP: Dict[str, str] = {
    "CLIP (openai/clip-vit-base-patch32)": "openai/clip-vit-base-patch32",
    "BiomedCLIP": "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
}
|
|
|
|
def _get_clip_scorer(model_display: str) -> "CrossModalCLIPScorer":
    """Return the CLIP scorer for ``model_display``, building it on first use.

    The display name is translated through ``_CLIP_MODEL_MAP`` (unknown names
    are used as-is) and scorers are memoized in ``_clip_scorer_cache`` under
    the resolved model name. A new scorer runs on CUDA when torch reports it
    is available, otherwise on CPU. After construction, ``unk_token_id`` is
    replaced by the first token id the tokenizer produces for ``"."``, when it
    produces any (the "dot-mask fix").
    """
    model_name = _CLIP_MODEL_MAP.get(model_display, model_display)
    try:
        return _clip_scorer_cache[model_name]
    except KeyError:
        pass

    import torch as _torch

    device = "cuda" if _torch.cuda.is_available() else "cpu"
    scorer = CrossModalCLIPScorer(PipelineConfig(clip_model_name=model_name, device=device))

    dot_token_ids = scorer.processor.tokenizer.encode(".", add_special_tokens=False)
    if dot_token_ids:
        scorer.unk_token_id = dot_token_ids[0]

    _clip_scorer_cache[model_name] = scorer
    return scorer
|
|
|
|
def _run_clip_attribution(
    image: Image.Image,
    caption: str,
    clip_model: str,
    seg_mode: str,
    grid_size: int,
    method: str,
    seed: int,
    progress=None,
) -> Dict[str, Any]:
    """
    Core CLIP cross-modal attribution pipeline shared by both custom tabs.

    Players are the image regions followed by the caption tokens: region ``i``
    has player index ``i`` and token ``j`` has index ``n_img + j``.

    Args:
        image: Input image.
        caption: Caption text to tokenise.
        clip_model: Display name or model id (resolved via ``_CLIP_MODEL_MAP``).
        seg_mode: ``"Patch Grid"`` selects patch mode; anything else selects
            the ``"unsam"`` mode.
        grid_size: Grid size forwarded to the pipeline config.
        method: Attribution method; ``"banzhaf"`` uses Banzhaf values, every
            other normalized method uses Shapley values.
        seed: Random seed forwarded to the config and to ProxySPEX.
        progress: Optional callable ``progress(fraction, desc=...)``.

    Returns:
        A dict with regions, token_players, values, interactions,
        overlay images, masked images, and influence matrix.

    Raises:
        gr.Error: If the CLIP pipeline is unavailable, or if segmentation
            fails in a mode other than "Patch Grid".
    """
    import numpy as np

    if not _CLIP_PIPELINE_AVAILABLE:
        raise gr.Error(
            "CLIP pipeline not available. Ensure attribution.set_mm is importable "
            "(requires transformers, lightgbm, numpy, scipy)."
        )

    method = _normalize_method(method)

    # The import is only an availability probe: use "lama" masking when the
    # optional package can be imported, otherwise "blur".
    try:
        from simple_lama_inpainting import SimpleLama
        mask_style = "lama"
    except ImportError:
        mask_style = "blur"

    import torch as _torch
    device = "cuda" if _torch.cuda.is_available() else "cpu"
    model_name = _CLIP_MODEL_MAP.get(clip_model, clip_model)

    cfg = PipelineConfig(
        mode="patch" if seg_mode == "Patch Grid" else "unsam",
        grid_size=int(grid_size),
        mask_style=mask_style,
        clip_model_name=model_name,
        max_tokens=15,
        method=method,
        max_order=2,
        top_k_interactions=15,
        random_seed=int(seed),
        device=device,
    )

    if progress is not None:
        progress(0.05, desc="Loading CLIP model...")
    scorer = _get_clip_scorer(clip_model)

    # 1) Segment the image into regions.
    if progress is not None:
        progress(0.10, desc="Segmenting image...")
    try:
        regions = featurise(image, cfg)
    except Exception as exc:
        # Only non-grid segmentation gets a friendly hint; grid failures
        # are re-raised unchanged.
        if seg_mode != "Patch Grid":
            raise gr.Error(
                f"UnSAM segmentation failed: {exc}. "
                "Try using 'Patch Grid' instead."
            ) from exc
        raise

    # 2) Tokenise the caption; token players are offset past the regions.
    if progress is not None:
        progress(0.15, desc="Tokenising caption...")
    token_players, full_token_ids = tokenise_caption(
        caption, scorer.processor, cfg, offset=len(regions)
    )

    n_img = len(regions)
    n_tok = len(token_players)
    n_total = n_img + n_tok

    # 3) Build the set function over all players.
    if progress is not None:
        progress(0.20, desc="Building set function...")
    game = build_cross_modal_set_function(
        image, regions, token_players, full_token_ids, scorer, cfg
    )

    # 4) Fit the interaction model (order fixed at 2 here).
    _raw_labels = [r.label for r in regions] + [tp.label for tp in token_players]

    if progress is not None:
        progress(0.25, desc=f"Running ProxySPEX (n={n_total})...")
    mobius = run_proxyspex(game, _raw_labels, max_order=2, seed=int(seed))

    # 5) Convert Mobius coefficients to per-player values.
    if progress is not None:
        progress(0.70, desc="Computing values...")
    if method == "banzhaf":
        values = mobius_to_banzhaf(mobius)
    else:
        values = mobius_to_shapley(mobius)

    # Per-modality value dicts. Repeated token labels get a "#<k>" suffix so
    # each keeps its own key.
    # NOTE(review): region labels are used as keys without de-duplication;
    # presumably they are unique -- verify against ``featurise``.
    img_vals = {regions[i].label: float(values.get((i,), 0.0)) for i in range(n_img)}
    tok_vals = {}
    _tok_label_counts: Dict[str, int] = {}
    for j, tp in enumerate(token_players):
        label = tp.label
        count = _tok_label_counts.get(label, 0)
        _tok_label_counts[label] = count + 1
        key = f"{label}#{count}" if count > 0 else label
        tok_vals[key] = float(values.get((n_img + j,), 0.0))

    # Labels in player-index order: regions first, then tokens.
    all_labels = list(img_vals.keys()) + list(tok_vals.keys())

    # 6) Top interactions overall, plus the cross-modal summaries.
    interactions = extract_interactions(mobius, order=2, top_k=15)
    cross_per_token, cross_global_top5 = extract_cross_per_token(mobius, n_img, n_tok)

    # Within-modality interactions: all players image-side, or all token-side.
    img_filter = lambda loc: all(i < n_img for i in loc)
    tok_filter = lambda loc: all(i >= n_img for i in loc)
    interactions_img = extract_interactions(mobius, order=2, top_k=10, player_filter=img_filter)
    interactions_tok = extract_interactions(mobius, order=2, top_k=10, player_filter=tok_filter)

    # Cross-modal interactions: at least one region AND at least one token.
    cross_filter = lambda loc: any(i < n_img for i in loc) and any(i >= n_img for i in loc)
    cross_interactions = extract_interactions(mobius, order=2, top_k=15, player_filter=cross_filter)

    # Accumulate the cross-modal interactions into a regions x tokens matrix.
    influence_matrix = np.zeros((n_img, n_tok))
    for loc, val in cross_interactions:
        img_indices = [i for i in loc if i < n_img]
        tok_indices = [i - n_img for i in loc if i >= n_img]
        for ii in img_indices:
            for tj in tok_indices:
                if 0 <= ii < n_img and 0 <= tj < n_tok:
                    influence_matrix[ii, tj] += float(val)

    # 7) Render the value overlay on top of the original image.
    if progress is not None:
        progress(0.75, desc="Rendering overlay...")
    img_val_list = [float(values.get((i,), 0.0)) for i in range(n_img)]
    overlay_rgba = render_overlay(image, regions, img_val_list)
    base_rgba = image.convert("RGBA")
    overlay_img = Image.alpha_composite(base_rgba, overlay_rgba).convert("RGB")
    overlay_b64 = _encode_image_to_b64(overlay_img)

    segmap_img = render_segmentation_map(image, regions)
    segmap_b64 = _encode_image_to_b64(segmap_img)

    # Region bounding boxes as percentages of the image size.
    w, h = image.size
    segment_bboxes = []
    for reg in regions:
        x0, y0, x1, y1 = reg.bbox
        segment_bboxes.append({
            "x0_pct": 100.0 * x0 / w,
            "y0_pct": 100.0 * y0 / h,
            "w_pct": 100.0 * (x1 - x0) / w,
            "h_pct": 100.0 * (y1 - y0) / h,
            "cx_pct": 100.0 * (x0 + x1) / 2 / w,
            "cy_pct": 100.0 * (y0 + y1) / 2 / h,
        })

    # 8) One masked image per region.
    if progress is not None:
        progress(0.80, desc="Generating masked images...")
    masked_images: Dict[str, Image.Image] = {}
    for i, reg in enumerate(regions):
        # Coalition vector with all regions set to 1 except region ``i``.
        coal_removed = [1] * n_img
        coal_removed[i] = 0
        removed_img = apply_image_mask(
            image, regions, coal_removed, style=cfg.mask_style,
            blur_radius=cfg.blur_radius, cfg=cfg,
        )
        masked_images[f"{reg.label} removed"] = removed_img

    if progress is not None:
        progress(0.90, desc="Done computing.")

    return {
        "regions": regions,
        "token_players": token_players,
        "all_labels": all_labels,
        "image_values": img_vals,
        "token_values": tok_vals,
        "values": values,
        "mobius": mobius,
        "interactions": interactions,
        "interactions_img": interactions_img,
        "interactions_tok": interactions_tok,
        "cross_interactions": cross_interactions,
        "cross_per_token": cross_per_token,
        "cross_global_top5": cross_global_top5,
        "influence_matrix": influence_matrix,
        "overlay_img": overlay_img,
        "overlay_b64": overlay_b64,
        "segmap_img": segmap_img,
        "segmap_b64": segmap_b64,
        "segment_bboxes": segment_bboxes,
        "masked_images": masked_images,
        "method": method,
        "n_img": n_img,
        "n_tok": n_tok,
        "mask_style": mask_style,
        "seg_mode": seg_mode,
        "grid_size": int(grid_size),
    }
|
|
|
|
def _build_masked_choices(masked_images: Dict[str, Image.Image]) -> List[str]:
    """List the labels of ``masked_images`` in ascending string order."""
    labels = list(masked_images)
    labels.sort()
    return labels
|
|
|
|
def _on_masked_image_select(choice: str, state: Dict) -> Optional[Image.Image]:
    """Look up the masked image stored in ``state`` under ``choice``.

    Returns ``None`` when either argument is empty or the choice is unknown.
    """
    if state and choice:
        return state.get(choice)
    return None
|
|
|
|
def _compute_image_attributions_clip(
    image: Image.Image,
    caption: str,
    clip_model: str,
    seg_mode: str,
    grid_size: int,
    method: str,
    seed: int,
    progress=None,
):
    """Run the CLIP pipeline and assemble the outputs of the image-only tab.

    Raises:
        gr.Error: If no image was supplied or the caption is blank.

    Returns:
        A 7-tuple: the input image, the overlay image, the region bar chart,
        a dropdown update listing the masked-image choices, the first masked
        image (or ``None``), the full masked-image mapping, and a metadata dict.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    cleaned_caption = caption.strip() if caption else ""
    if not cleaned_caption:
        raise gr.Error("Please provide a caption or description.")

    result = _run_clip_attribution(
        image,
        cleaned_caption,
        clip_model,
        seg_mode,
        int(grid_size),
        method,
        int(seed or 0),
        progress=progress,
    )

    # Bar chart of per-region values.
    image_values = result["image_values"]
    region_chart = create_shapley_bar_chart(
        list(image_values.keys()),
        list(image_values.values()),
        "Region Attribution",
    )

    # Dropdown contents and the preview shown by default.
    masked_state = result["masked_images"]
    choices = _build_masked_choices(masked_state)
    first_choice = choices[0] if choices else None
    first_preview = masked_state.get(choices[0]) if choices else None

    meta = {
        "mode": "image_clip",
        "method": result["method"],
        "clip_model": clip_model,
        "seg_mode": result["seg_mode"],
        "grid_size": result["grid_size"],
        "mask_style": result["mask_style"],
        "n_regions": result["n_img"],
        "n_tokens": result["n_tok"],
    }

    if progress is not None:
        progress(1.0, desc="Done")

    return (
        image,
        result["overlay_img"],
        region_chart,
        gr.update(choices=choices, value=first_choice),
        first_preview,
        masked_state,
        meta,
    )
|
|
|
|
def _compute_mm_attributions_clip(
    image: Image.Image,
    caption: str,
    clip_model: str,
    seg_mode: str,
    grid_size: int,
    method: str,
    seed: int,
    progress=None,
):
    """Run the CLIP pipeline and assemble the outputs of the cross-modal tab.

    Player indices below ``n_img`` are image regions; indices at or above
    ``n_img`` are caption tokens. ``all_labels`` is indexed by player index.

    Raises:
        gr.Error: If no image was supplied or the caption is blank.

    Returns:
        An 11-tuple: the input image, the overlay image, the region chart,
        the token chart, the cross-modal chart, the influence heatmap, the
        interaction HTML, a dropdown update with the masked-image choices,
        the first masked image (or ``None``), the masked-image mapping, and
        a metadata dict.
    """
    # The former function-scope ``import numpy as np`` was unused here and has
    # been dropped, so input validation now runs before any heavy import.
    if image is None:
        raise gr.Error("Please upload an image.")
    if not caption or not caption.strip():
        raise gr.Error("Please provide a caption or description.")

    result = _run_clip_attribution(
        image, caption.strip(), clip_model, seg_mode, int(grid_size),
        method, int(seed or 0), progress=progress,
    )

    all_labels = result["all_labels"]
    n_img = result["n_img"]
    n_tok = result["n_tok"]

    # Per-region bar chart.
    seg_labels = list(result["image_values"].keys())
    seg_vals = list(result["image_values"].values())
    region_chart = create_shapley_bar_chart(seg_labels, seg_vals, "Region Attribution")

    # Per-token bar chart.
    tok_labels = list(result["token_values"].keys())
    tok_vals = list(result["token_values"].values())
    token_chart = create_shapley_bar_chart(tok_labels, tok_vals, "Token Attribution")

    # Cross-modal chart: one (region label, token label) pair per interaction.
    cross_pairs = []
    for loc, val in result["cross_interactions"]:
        img_parts = [all_labels[i] for i in loc if i < n_img]
        tok_parts = [all_labels[i] for i in loc if i >= n_img]
        if img_parts and tok_parts:
            cross_pairs.append(((img_parts[0], tok_parts[0]), float(val)))
    cross_chart = create_cross_modal_bar_chart(cross_pairs, "Cross-Modal Interactions", top_k=15)

    # Regions x tokens heatmap.
    heatmap = create_influence_heatmap(
        seg_labels, tok_labels, result["influence_matrix"],
        "Influence Heatmap (Regions x Tokens)"
    )

    # Summary structure consumed by the interactive HTML view.
    clip_summary = {
        "image_region_values": [
            {"label": seg_labels[i], "value": float(seg_vals[i])} for i in range(n_img)
        ],
        "token_values": [
            {"label": tok_labels[j], "value": float(tok_vals[j])} for j in range(n_tok)
        ],
        "cross_modal_interactions": [
            {"label": " x ".join(all_labels[i] for i in loc), "value": float(val)}
            for loc, val in result["cross_global_top5"]
        ],
    }
    image_b64 = _encode_image_to_b64(image)
    interaction_html = create_benchmark_interaction_html(
        image_b64=image_b64,
        clip_summary=clip_summary,
        vllm_logprob=None,
        caption=caption,
        all_cross_modal_pairs=[
            {
                # Order each pair as (region label, token label) whichever
                # way round the two player indices are stored.
                "pair": (
                    all_labels[loc[0]] if loc[0] < n_img else all_labels[loc[1]],
                    all_labels[loc[1]] if loc[1] >= n_img else all_labels[loc[0]],
                ),
                "value": float(val),
            }
            for loc, val in result["cross_interactions"]
        ],
        segmap_b64=result["segmap_b64"],
        overlay_b64=result["overlay_b64"],
        segment_bboxes=result["segment_bboxes"],
        label_map_b64="",
        image_width=image.size[0],
        image_height=image.size[1],
        title="Cross-Modal Interaction View",
    )

    # Dropdown contents for the masked-image preview.
    masked_state = result["masked_images"]
    choices = _build_masked_choices(masked_state)

    meta = {
        "mode": "multimodal_clip",
        "method": result["method"],
        "clip_model": clip_model,
        "seg_mode": result["seg_mode"],
        "grid_size": result["grid_size"],
        "mask_style": result["mask_style"],
        "n_regions": n_img,
        "n_tokens": n_tok,
    }

    if progress is not None:
        progress(1.0, desc="Done")

    return (
        image,
        result["overlay_img"],
        region_chart,
        token_chart,
        cross_chart,
        heatmap,
        interaction_html,
        gr.update(choices=choices, value=choices[0] if choices else None),
        masked_state.get(choices[0]) if choices else None,
        masked_state,
        meta,
    )
|
|
| def on_select_example( |
| dataset, |
| ex_id, |
| model_size, |
| order, |
| method, |
| scalarizer=None, |
| feature_level=None, |
| ): |
| """ |
| Public mode handler: load a precomputed example and render figures. |
| |
| Args: |
| dataset (str): dataset name |
| ex_id (str): example id |
| model_size (str): "small" | "medium" | "large" |
| order (int): interaction order (2 or 3) |
| method (str): "shapley" | "banzhaf" | "influence" |
| |
| Returns: |
| tuple ordered as: |
| ( |
| context, |
| prompt, |
| answer, |
| interactions_plot, |
| interactions_token_html, |
| text_html, |
| meta_json, |
| ) |
| """ |
| get_res = get_result_by_id if get_result_by_id is not None else _public_get_result_from_file |
| model_size = _normalize_model_size(model_size) |
| example = {"context": "", "prompt": "", "answer": ""} |
| if get_example_by_id is not None: |
| try: |
| example = get_example_by_id(dataset, ex_id) |
| except Exception: |
| pass |
| result = get_res( |
| model_size, |
| dataset, |
| ex_id, |
| scalarizer=scalarizer, |
| feature_level=feature_level, |
| ) or {} |
| payload = result.get(method, {}) |
| |
| feats = payload.get("features") if isinstance(payload, dict) else None |
| if isinstance(feats, list) and feats and not isinstance(feats[0], dict): |
| payload = _normalize_public_payload_fallback(payload, method) |
|
|
| features, feature_values = _extract_feature_series(payload) |
| if not features: |
| features = ["<no features>"] |
| feature_values = [0.0] |
| |
| if method == "influence": |
| feature_values = [abs(v) for v in feature_values] |
| marginals = {feat: float(feature_values[idx]) for idx, feat in enumerate(features)} |
| interactions = _resolve_interactions(payload, order) |
| if method == "influence": |
| interactions = [(feats, abs(val)) for feats, val in interactions] |
| pairwise = _resolve_interactions(payload, 2) |
| if not pairwise: |
| mixed = payload.get("interactions") |
| normalized = _normalize_interactions(mixed) |
| if normalized: |
| pairwise = [item for item in normalized if len(item[0]) == 2] |
| pairwise = [(feats, abs(val)) for feats, val in pairwise] |
| else: |
| pairwise = _resolve_pairwise(payload, features, feature_values) |
| if method == "influence": |
| top_singletons = sorted( |
| list(zip(features, feature_values)), |
| key=lambda kv: abs(float(kv[1])), |
| reverse=True, |
| )[:10] |
| top_pairs = sorted( |
| pairwise, |
| key=lambda kv: abs(float(kv[1])), |
| reverse=True, |
| )[:10] |
| print( |
| "[influence-ui-debug][public] " |
| f"dataset={dataset} ex_id={ex_id} feature_count={len(features)} pair_count={len(pairwise)}" |
| , flush=True) |
| print("[influence-ui-debug][public] top_singletons:", top_singletons, flush=True) |
| print("[influence-ui-debug][public] top_pairwise:", top_pairs, flush=True) |
| payload_level = ( |
| payload.get("mask_level") |
| or payload.get("feature_level") |
| or payload.get("level") |
| or (result.get("meta", {}) if isinstance(result, dict) else {}).get("feature_level") |
| ) |
| layout_mode = "sentence" if _normalize_level(payload_level) == "sentence" else "token" |
|
|
| inter = plot_top_interactions(interactions, order=order, method=method) |
|
|
| spans = payload.get("feature_spans") or payload.get("spans") |
| if not spans: |
| |
| |
| _, fallback_spans, _ = _chunk_text_for_visualization( |
| example.get("context", ""), |
| _normalize_level(payload_level), |
| ) |
| if fallback_spans and len(fallback_spans) == len(feature_values): |
| spans = fallback_spans |
| html = None |
| if spans and len(spans) == len(feature_values): |
| html = create_interactive_text_heatmap( |
| example.get("context", ""), |
| spans, |
| feature_values, |
| method=method, |
| ) |
|
|
| |
| |
| _wrong_values_for_dual: Optional[List[float]] = None |
| _wrong_pairwise_for_dual: Optional[List[Any]] = None |
| _wrong_features_for_dual: Optional[List[str]] = None |
| try: |
| from visualization.wrong_answer_examples import has_wrong_answer_view as _has_wrong_view |
| except Exception: |
| _has_wrong_view = None |
| _is_wrong_view = bool( |
| _has_wrong_view is not None |
| and html is not None |
| and spans |
| and _has_wrong_view(dataset, ex_id, scalarizer or "", feature_level or "") |
| ) |
| if _is_wrong_view: |
| wrong_result = _public_get_model_answer_short_from_file( |
| model_size, dataset, ex_id, scalarizer or "geomean_jointprob", |
| feature_level or "word", |
| ) |
| wrong_payload = wrong_result.get(method, {}) if wrong_result else {} |
| wrong_features_local, wrong_values_local = _extract_feature_series(wrong_payload) |
| if method == "influence": |
| wrong_values_local = [abs(v) for v in wrong_values_local] |
| if wrong_features_local and len(wrong_values_local) == len(feature_values): |
| |
| wrong_pairwise = _resolve_interactions(wrong_payload, 2) |
| if not wrong_pairwise: |
| mixed = wrong_payload.get("interactions") if isinstance(wrong_payload, dict) else None |
| normalized = _normalize_interactions(mixed) |
| if normalized: |
| wrong_pairwise = [item for item in normalized if len(item[0]) == 2] |
| if method == "influence": |
| wrong_pairwise = [(f, abs(v)) for f, v in (wrong_pairwise or [])] |
| else: |
| |
| if not wrong_pairwise: |
| wrong_pairwise = _resolve_pairwise(wrong_payload, wrong_features_local, wrong_values_local) |
| _wrong_values_for_dual = wrong_values_local |
| _wrong_pairwise_for_dual = wrong_pairwise or [] |
| _wrong_features_for_dual = wrong_features_local |
| else: |
| _is_wrong_view = False |
|
|
| meta = { |
| "dataset": dataset, |
| "example_id": ex_id, |
| "model_size": model_size, |
| "method": method, |
| "order": order, |
| "feature_count": len(features), |
| "payload_keys": sorted(payload.keys()), |
| } |
| if "meta" in result: |
| meta["source_meta"] = result["meta"] |
|
|
| interaction_chips_html = create_interaction_token_view( |
| features, |
| feature_values, |
| pairwise or [item for item in interactions if len(item[0]) == 2], |
| method=method, |
| layout=layout_mode, |
| ) |
| text_interaction_html = create_text_interaction_html( |
| features, |
| feature_values, |
| _pairwise_to_index_interactions( |
| pairwise or [item for item in interactions if len(item[0]) == 2], |
| features, |
| ), |
| method=method, |
| top_k=20, |
| threshold=0.0, |
| ) |
|
|
| |
| |
| |
| if ( |
| _is_wrong_view |
| and _wrong_values_for_dual is not None |
| and _wrong_features_for_dual is not None |
| ): |
| gt_view = create_text_interaction_html( |
| features, |
| feature_values, |
| _pairwise_to_index_interactions( |
| pairwise or [item for item in interactions if len(item[0]) == 2], |
| features, |
| ), |
| method=method, |
| top_k=20, |
| threshold=0.0, |
| ) |
| wrong_view = create_text_interaction_html( |
| _wrong_features_for_dual, |
| _wrong_values_for_dual, |
| _pairwise_to_index_interactions( |
| _wrong_pairwise_for_dual or [], |
| _wrong_features_for_dual, |
| ), |
| method=method, |
| top_k=20, |
| threshold=0.0, |
| ) |
|
|
| method_label = (method or "attribution").title() |
| gt_max_abs = max((abs(v) for v in feature_values), default=0.0) or 1.0 |
| wrong_max_abs = max((abs(v) for v in _wrong_values_for_dual), default=0.0) or 1.0 |
|
|
| from html import escape as _escape |
| raw_text = example.get("context", "") or "" |
| raw_text_html = _escape(raw_text).replace("\n", "<br/>") if raw_text else "" |
|
|
| |
| |
| |
| dual_css = ( |
| "<style>" |
| ".dual-heatmap-row{display:grid;grid-template-columns:1fr 1fr;gap:16px;align-items:start;}" |
| ".dual-heatmap-row .text-interaction-side-panel{display:none !important;}" |
| ".dual-heatmap-row .text-interaction-root{flex:1 1 100%;}" |
| ".dual-heatmap-row .text-interaction-card{flex:1 1 100%;max-width:100%;}" |
| ".dual-heatmap-shared{margin-top:16px;display:grid;grid-template-columns:1fr 1fr;gap:16px;}" |
| ".dual-heatmap-shared .shared-card{background:#f8f5ff;border:1px solid #e2d6f3;" |
| "border-radius:12px;padding:12px 14px;box-shadow:0 4px 10px rgba(80,50,140,0.05);}" |
| ".dual-heatmap-shared .shared-legend-bar{display:flex;align-items:center;gap:8px;margin:6px 0;}" |
| ".dual-heatmap-shared .shared-legend-label{font-size:12px;color:#6f5a72;text-transform:uppercase;letter-spacing:.04em;}" |
| ".dual-heatmap-shared .shared-legend-gradient{flex:1;height:10px;border-radius:999px;" |
| "background:linear-gradient(90deg,#dd1313 0%,#d8c6f0 50%,#4a1c87 100%);}" |
| ".dual-heatmap-shared .shared-legend-note{font-size:12px;color:#6f5a72;margin:4px 0 0 0;}" |
| ".dual-heatmap-shared .shared-raw-text p{margin:6px 0 0 0;line-height:1.5;color:#3a2b4a;}" |
| "@media (prefers-color-scheme: dark){" |
| ".dual-heatmap-shared .shared-card{background:#111a2b;border-color:#33435f;}" |
| ".dual-heatmap-shared .shared-legend-label,.dual-heatmap-shared .shared-legend-note," |
| ".dual-heatmap-shared .shared-raw-text p{color:#a9b6cb;}}" |
| "@media (max-width: 900px){" |
| ".dual-heatmap-row,.dual-heatmap-shared{grid-template-columns:1fr;}}" |
| "</style>" |
| ) |
|
|
| shared_block = ( |
| '<div class="dual-heatmap-shared">' |
| '<div class="shared-card">' |
| f'<strong>{method_label} legend</strong>' |
| '<div class="shared-legend-bar">' |
| '<span class="shared-legend-label">Negative</span>' |
| '<div class="shared-legend-gradient"></div>' |
| '<span class="shared-legend-label">Positive</span>' |
| '</div>' |
| '<p class="shared-legend-note">' |
| f'Ground-truth max |value| = {gt_max_abs:.4f}; ' |
| f'wrong-answer max |value| = {wrong_max_abs:.4f}. ' |
| 'Hover tokens for exact scores.' |
| '</p>' |
| '</div>' |
| '<div class="shared-card shared-raw-text">' |
| '<strong>Raw text</strong>' |
| f'<p>{raw_text_html or "<em>No context available.</em>"}</p>' |
| '</div>' |
| '</div>' |
| ) |
|
|
| text_interaction_html = ( |
| f'{dual_css}' |
| '<div class="dual-heatmap-row">' |
| '<div>' |
| '<div class="heatmap-caption" ' |
| 'style="font-weight:600;margin-bottom:6px;">vs Ground Truth</div>' |
| f'{gt_view}' |
| '</div>' |
| '<div>' |
| '<div class="heatmap-caption" ' |
| 'style="font-weight:600;margin-bottom:6px;">vs Model Answer (Wrong)</div>' |
| f'{wrong_view}' |
| '</div>' |
| '</div>' |
| f'{shared_block}' |
| ) |
| print( |
| f"[wrong-answer] dual chip+lines view rendered for {dataset}/{ex_id} " |
| f"(gt_features={len(features)} wrong_features={len(_wrong_features_for_dual)} " |
| f"gt_pairs={len(pairwise or [])} wrong_pairs={len(_wrong_pairwise_for_dual or [])})", |
| flush=True, |
| ) |
|
|
| figs = { |
| "interactions": inter, |
| } |
| outputs = update( |
| figs=figs, |
| meta=meta, |
| html=html, |
| interaction_html=interaction_chips_html, |
| interaction_text_html=text_interaction_html, |
| ) |
| return ( |
| example.get("context", ""), |
| example.get("prompt", ""), |
| _extract_answer(example), |
| *outputs, |
| ) |
|
|
|
|
def on_click_compute(
    context,
    prompt,
    correct_answer,
    model_size,
    level,
    method,
    scalarizer,
    embedding_model,
    progress=gr.Progress(track_tqdm=True),
):
    """Handle a click on the live-attribution compute button.

    Normalises the method, level and model-size selections, substitutes empty
    strings for missing text inputs, and delegates to
    ``_compute_live_attributions`` with the interaction order fixed at 2.

    Args:
        context: context text; a falsy value is replaced by ``""``.
        prompt: prompt text; a falsy value is replaced by ``""``.
        correct_answer: reference answer; a falsy value is replaced by ``""``.
        model_size: selection passed through ``_normalize_model_size``.
        level: selection passed through ``_normalize_level``.
        method: selection passed through ``_normalize_method``.
        scalarizer: forwarded unchanged.
        embedding_model: forwarded unchanged.
        progress: progress tracker, forwarded unchanged.

    Returns:
        Whatever ``_compute_live_attributions`` returns.
    """
    # Normalise the selections first: method, then level, then model size.
    chosen_method = _normalize_method(method)
    chosen_level = _normalize_level(level)
    chosen_size = _normalize_model_size(model_size)

    request = dict(
        context=context or "",
        prompt=prompt or "",
        correct_answer=correct_answer or "",
        model_size=chosen_size,
        level=chosen_level,
        method=chosen_method,
        order=2,  # this handler always requests pairwise interactions
        scalarizer=scalarizer,
        embedding_model=embedding_model,
        progress=progress,
    )
    return _compute_live_attributions(**request)
|
|
|
|
| |
| |
| |
|
|
| |
|
|
| _MIMIC_METHOD_NAMES = [ |
| "BiomedCLIP Cross-Modal", |
| "LLaVA-Med Log-Prob", |
| "LLaVA-Med Generation", |
| ] |
|
|
|
|
def _on_select_mimic_example(example_id, method_label: str = "Influence"):
    """Load a MIMIC-CXR example and return data for the MIMIC tab.

    Args:
        example_id: identifier handed to ``load_mimic_example``. A falsy value
            yields the empty result.
        method_label: attribution method label. It is lower-cased before being
            passed to the loader; charts and views are labelled "Influence"
            when it equals "influence" and "Shapley" otherwise.

    Returns:
        A 15-tuple ordered as:
            (
                caption,
                original_image_path,
                findings,
                interpretation,
                biomedclip_overlay_labeled,
                biomedclip_token_plot,
                biomedclip_region_plot,
                llavamed_logprob_overlay,
                llavamed_logprob_plot,
                llavamed_generation_overlay,
                llavamed_generation_plot,
                biomedclip_interaction_html,
                meta_dict,
                results_state,
                gr.update(),
            )
        ``results_state`` maps a method display name to
        ``{"overlay": ..., "plot": ...}`` for every method that produced an
        overlay. When MIMIC support is unavailable, ``example_id`` is falsy,
        or loading fails, the tuple is ``""`` followed by fourteen ``None``.
    """
    n_outputs = 15
    empty = tuple([""] + [None] * (n_outputs - 1))
    if not _MIMIC_AVAILABLE or not example_id:
        return empty

    method = (method_label or "Influence").lower()
    method_display = "Influence" if method == "influence" else "Shapley"

    # Module-level renderers, looked up by name so the local wrappers below
    # can inject the method label without shadowing them.
    _base_chart = globals()["create_shapley_bar_chart"]
    _base_html = globals()["create_benchmark_interaction_html"]

    def _bar_chart(labels, values, title="Shapley Values", **kwargs):
        # Titles are written with "Shapley"; swap in the active method name.
        kwargs.setdefault("method_label", method_display)
        return _base_chart(labels, values, title.replace("Shapley", method_display), **kwargs)

    def _interaction_view(**kwargs):
        kwargs.setdefault("method_label", method_display)
        return _base_html(**kwargs)

    try:
        data = load_mimic_example(example_id, method=method)
    except Exception as exc:
        # Report the failure instead of swallowing it; the UI still gets the
        # empty result.
        print(f"[mimic] Error loading {example_id}: {exc}")
        return empty

    caption = data.get("caption", "")
    findings = data.get("findings", "")
    original_img = data.get("original_image_path")
    category = data.get("meta", {}).get("category", "")
    has_logprob = data.get("has_llavamed_unsam_logprob")
    has_gen = data.get("has_llavamed_unsam_gen")

    # ---- BiomedCLIP cross-modal results ----
    biomedclip_overlay_labeled = None
    biomedclip_region_plot = None
    biomedclip_token_plot = None
    biomedclip_interaction_html = ""

    if data.get("has_biomedclip"):
        biomedclip = data["biomedclip"]
        bc_summary = biomedclip["summary"]
        bc_paths = biomedclip["image_paths"]
        bc_overlay_raw = bc_paths.get("overlay")
        bc_original = bc_paths.get("original", "")
        bc_segmap = bc_paths.get("segmap", "")
        bc_regions = bc_summary.get("image_region_values", [])
        bc_n_segs = len(bc_regions)

        # Segment geometry is optional: on failure the views are built
        # without bounding boxes / label map.
        bc_bboxes, bc_label_map_b64 = None, ""
        if bc_original and bc_segmap and bc_n_segs > 0:
            try:
                bc_bboxes, bc_label_map_b64 = extract_segment_regions(
                    bc_original, bc_segmap, bc_n_segs)
            except Exception as exc:
                print(f"[mimic] Segment extraction failed for {example_id}: {exc}")

        if bc_overlay_raw:
            biomedclip_overlay_labeled = draw_segment_labels(
                bc_overlay_raw,
                bc_regions,
                segment_bboxes=bc_bboxes,
                label_map_b64=bc_label_map_b64,
                original_path=bc_original,
            )

        bc_r_labels = [v["label"] for v in bc_regions]
        bc_r_values = [v["value"] for v in bc_regions]
        if bc_r_labels:
            biomedclip_region_plot = _bar_chart(
                bc_r_labels, bc_r_values, "BiomedCLIP — Image Region Shapley Values")

        bc_merged = merge_subword_token_values(bc_summary.get("token_values", []), caption)
        bc_t_labels = [v["label"] for v in bc_merged]
        bc_t_values = [v["value"] for v in bc_merged]
        if bc_t_labels:
            biomedclip_token_plot = _bar_chart(
                bc_t_labels, bc_t_values, "BiomedCLIP — Caption Word Shapley Values")

        # The segmentation map is embedded as base64 when the file exists.
        bc_segmap_b64 = ""
        if bc_segmap and os.path.exists(bc_segmap):
            with open(bc_segmap, "rb") as segmap_file:
                bc_segmap_b64 = base64.b64encode(segmap_file.read()).decode("ascii")

        biomedclip_interaction_html = _interaction_view(
            image_b64=biomedclip.get("image_b64", {}).get("original", ""),
            clip_summary=bc_summary,
            vllm_logprob=None,
            caption=caption,
            all_cross_modal_pairs=biomedclip.get("all_cross_modal_pairs", []),
            segmap_b64=bc_segmap_b64,
            overlay_b64=biomedclip.get("image_b64", {}).get("overlay", ""),
            segment_bboxes=bc_bboxes,
            label_map_b64=bc_label_map_b64,
            title="BiomedCLIP Cross-Modal Interaction View — click segments or words",
        )

    # ---- LLaVA-Med (log-prob / generation) results ----
    llavamed_unsam_lp_overlay_img = None
    llavamed_unsam_gen_overlay_img = None
    llavamed_unsam_lp_plot = None
    llavamed_unsam_gen_plot = None

    if has_logprob or has_gen:
        lu_segmap = data.get("llavamed_unsam_segmap_path", "")
        lu_original = data.get("llavamed_unsam_original_path", "") or (original_img or "")
        lu_bboxes, lu_label_map_b64 = None, ""
        if lu_segmap and lu_original:
            # Segment count comes from the log-prob result when present,
            # otherwise from the generation result.
            seg_source = data["llavamed_unsam_logprob"] if has_logprob else data["llavamed_unsam_gen"]
            n_lu_segs = len(seg_source.get("image_region_values", []))
            if n_lu_segs > 0:
                try:
                    lu_bboxes, lu_label_map_b64 = extract_segment_regions(
                        lu_original, lu_segmap, n_lu_segs)
                except Exception as exc:
                    print(f"[mimic] Segment extraction failed for {example_id}: {exc}")

        if has_logprob:
            logprob_result = data["llavamed_unsam_logprob"]
            lu_lp = rename_patch_labels(logprob_result.get("image_region_values", []))
            if lu_lp:
                llavamed_unsam_lp_plot = _bar_chart(
                    [v["label"] for v in lu_lp],
                    [v["value"] for v in lu_lp],
                    "LLaVA-Med Log-Prob — Segment Shapley Values",
                )
            overlay_path = logprob_result.get("overlay_path", "")
            if overlay_path:
                llavamed_unsam_lp_overlay_img = draw_segment_labels(
                    overlay_path,
                    lu_lp,
                    segment_bboxes=lu_bboxes,
                    label_map_b64=lu_label_map_b64,
                    original_path=lu_original,
                )

        if has_gen:
            gen_result = data["llavamed_unsam_gen"]
            lu_gen = rename_patch_labels(gen_result.get("image_region_values", []))
            if lu_gen:
                llavamed_unsam_gen_plot = _bar_chart(
                    [v["label"] for v in lu_gen],
                    [v["value"] for v in lu_gen],
                    "LLaVA-Med Generation — Segment Shapley Values",
                )
            # Fall back to the log-prob overlay when the generation result
            # does not carry its own.
            overlay_path = (gen_result.get("overlay_path", "")
                            or data.get("llavamed_unsam_logprob", {}).get("overlay_path", ""))
            if overlay_path:
                llavamed_unsam_gen_overlay_img = draw_segment_labels(
                    overlay_path,
                    lu_gen,
                    segment_bboxes=lu_bboxes,
                    label_map_b64=lu_label_map_b64,
                    original_path=lu_original,
                )

    # ---- Interpretation text (best effort) ----
    interpretation = ""
    try:
        bc_data = data.get("biomedclip", {}).get("summary") if data.get("has_biomedclip") else None
        interpretation = generate_interpretation_text(
            clip_summary=bc_data,
            vllm_logprob=data.get("llavamed_unsam_logprob") if has_logprob else None,
            modality="Chest X-ray",
            body_part=category,
            caption=caption,
            cross_method_name="BiomedCLIP",
            vlm_method_name="LLaVA-Med",
            vlm_region_type="UnSAM segments",
        )
    except Exception as exc:
        print(f"[mimic] Interpretation text failed for {example_id}: {exc}")

    result_flags = ("has_biomedclip", "has_llavamed_unsam_logprob",
                    "has_llavamed_unsam_gen", "has_clip")
    if not any(data.get(flag) for flag in result_flags):
        interpretation = (
            "**No precomputed attribution results yet.**\n\n"
            "Run the attribution pipeline on this MIMIC-CXR example to see results here. "
            "The image and report are shown above for reference."
        )

    # Only methods that produced an overlay are offered for comparison.
    _results_state = {}
    if biomedclip_overlay_labeled:
        _results_state["BiomedCLIP Cross-Modal"] = {
            "overlay": biomedclip_overlay_labeled, "plot": biomedclip_region_plot}
    if llavamed_unsam_lp_overlay_img:
        _results_state["LLaVA-Med Log-Prob"] = {
            "overlay": llavamed_unsam_lp_overlay_img, "plot": llavamed_unsam_lp_plot}
    if llavamed_unsam_gen_overlay_img:
        _results_state["LLaVA-Med Generation"] = {
            "overlay": llavamed_unsam_gen_overlay_img, "plot": llavamed_unsam_gen_plot}

    return (
        caption,
        original_img,
        findings,
        interpretation,
        biomedclip_overlay_labeled,
        biomedclip_token_plot,
        biomedclip_region_plot,
        llavamed_unsam_lp_overlay_img,
        llavamed_unsam_lp_plot,
        llavamed_unsam_gen_overlay_img,
        llavamed_unsam_gen_plot,
        biomedclip_interaction_html,
        {
            "example_id": example_id,
            "category": category,
            "has_biomedclip": data.get("has_biomedclip", False),
            "has_llavamed_unsam_logprob": data.get("has_llavamed_unsam_logprob", False),
            "has_llavamed_unsam_gen": data.get("has_llavamed_unsam_gen", False),
        },
        _results_state,
        gr.update(),
    )
|
|
|
|
| def _on_mimic_compare_methods(method_a, method_b, results_state): |
| """Pick two MIMIC methods from state and display side by side.""" |
| if not method_a or not method_b or not results_state: |
| return None, None, None, None |
| a = results_state.get(method_a, {}) |
| b = results_state.get(method_b, {}) |
| return a.get("overlay"), b.get("overlay"), a.get("plot"), b.get("plot") |
|
|
|
|
| |
|
|
| _ISIC_METHOD_NAMES = [ |
| "BiomedCLIP Cross-Modal", |
| "LLaVA-Med Log-Prob", |
| "LLaVA-Med Generation", |
| ] |
|
|
|
|
def _on_select_isic_example(example_id, method_label: str = "Influence"):
    """Load an ISIC dermoscopy example and return data for the ISIC tab.

    Follows the same layout as ``_on_select_mimic_example`` except that there
    is no "findings" output, so this handler returns 14 values.

    Args:
        example_id: identifier handed to ``load_isic_example``. A falsy value
            yields the empty result.
        method_label: attribution method label. It is lower-cased before being
            passed to the loader; charts and views are labelled "Influence"
            when it equals "influence" and "Shapley" otherwise.

    Returns:
        A 14-tuple ordered as:
            (
                caption,
                original_image_path,
                interpretation,
                biomedclip_overlay_labeled,
                biomedclip_token_plot,
                biomedclip_region_plot,
                llavamed_logprob_overlay,
                llavamed_logprob_plot,
                llavamed_generation_overlay,
                llavamed_generation_plot,
                biomedclip_interaction_html,
                meta_dict,
                results_state,
                gr.update(),
            )
        ``results_state`` maps a method display name to
        ``{"overlay": ..., "plot": ...}`` for every method that produced an
        overlay. When ISIC support is unavailable, ``example_id`` is falsy,
        or loading fails, the tuple is ``""`` followed by thirteen ``None``.
    """
    n_outputs = 14
    empty = tuple([""] + [None] * (n_outputs - 1))
    if not _ISIC_AVAILABLE or not example_id:
        return empty

    method = (method_label or "Influence").lower()
    method_display = "Influence" if method == "influence" else "Shapley"

    # Module-level renderers, looked up by name so the local wrappers below
    # can inject the method label without shadowing them.
    _base_chart = globals()["create_shapley_bar_chart"]
    _base_html = globals()["create_benchmark_interaction_html"]

    def _bar_chart(labels, values, title="Shapley Values", **kwargs):
        # Titles are written with "Shapley"; swap in the active method name.
        kwargs.setdefault("method_label", method_display)
        return _base_chart(labels, values, title.replace("Shapley", method_display), **kwargs)

    def _interaction_view(**kwargs):
        kwargs.setdefault("method_label", method_display)
        return _base_html(**kwargs)

    try:
        data = load_isic_example(example_id, method=method)
    except Exception as exc:
        # Report the failure instead of swallowing it; the UI still gets the
        # empty result.
        print(f"[isic] Error loading {example_id}: {exc}")
        return empty

    caption = data.get("caption", "")
    original_img = data.get("original_image_path")
    category = data.get("meta", {}).get("category", "")
    has_logprob = data.get("has_llavamed_unsam_logprob")
    has_gen = data.get("has_llavamed_unsam_gen")

    # ---- BiomedCLIP cross-modal results ----
    biomedclip_overlay_labeled = None
    biomedclip_region_plot = None
    biomedclip_token_plot = None
    biomedclip_interaction_html = ""

    if data.get("has_biomedclip"):
        biomedclip = data["biomedclip"]
        bc_summary = biomedclip["summary"]
        bc_paths = biomedclip["image_paths"]
        bc_overlay_raw = bc_paths.get("overlay")
        bc_original = bc_paths.get("original", "")
        bc_segmap = bc_paths.get("segmap", "")
        bc_regions = bc_summary.get("image_region_values", [])
        bc_n_segs = len(bc_regions)

        # Segment geometry is optional: on failure the views are built
        # without bounding boxes / label map.
        bc_bboxes, bc_label_map_b64 = None, ""
        if bc_original and bc_segmap and bc_n_segs > 0:
            try:
                bc_bboxes, bc_label_map_b64 = extract_segment_regions(
                    bc_original, bc_segmap, bc_n_segs)
            except Exception as exc:
                print(f"[isic] Segment extraction failed for {example_id}: {exc}")

        if bc_overlay_raw:
            biomedclip_overlay_labeled = draw_segment_labels(
                bc_overlay_raw,
                bc_regions,
                segment_bboxes=bc_bboxes,
                label_map_b64=bc_label_map_b64,
                original_path=bc_original,
            )

        bc_r_labels = [v["label"] for v in bc_regions]
        bc_r_values = [v["value"] for v in bc_regions]
        if bc_r_labels:
            biomedclip_region_plot = _bar_chart(
                bc_r_labels, bc_r_values, "BiomedCLIP — Image Region Shapley Values")

        bc_merged = merge_subword_token_values(bc_summary.get("token_values", []), caption)
        bc_t_labels = [v["label"] for v in bc_merged]
        bc_t_values = [v["value"] for v in bc_merged]
        if bc_t_labels:
            biomedclip_token_plot = _bar_chart(
                bc_t_labels, bc_t_values, "BiomedCLIP — Caption Word Shapley Values")

        # The segmentation map is embedded as base64 when the file exists.
        bc_segmap_b64 = ""
        if bc_segmap and os.path.exists(bc_segmap):
            with open(bc_segmap, "rb") as segmap_file:
                bc_segmap_b64 = base64.b64encode(segmap_file.read()).decode("ascii")

        biomedclip_interaction_html = _interaction_view(
            image_b64=biomedclip.get("image_b64", {}).get("original", ""),
            clip_summary=bc_summary,
            vllm_logprob=None,
            caption=caption,
            all_cross_modal_pairs=biomedclip.get("all_cross_modal_pairs", []),
            segmap_b64=bc_segmap_b64,
            overlay_b64=biomedclip.get("image_b64", {}).get("overlay", ""),
            segment_bboxes=bc_bboxes,
            label_map_b64=bc_label_map_b64,
            title="BiomedCLIP Cross-Modal Interaction View — click segments or words",
        )

    # ---- LLaVA-Med (log-prob / generation) results ----
    llavamed_unsam_lp_overlay_img = None
    llavamed_unsam_gen_overlay_img = None
    llavamed_unsam_lp_plot = None
    llavamed_unsam_gen_plot = None

    if has_logprob or has_gen:
        lu_segmap = data.get("llavamed_unsam_segmap_path", "")
        lu_original = data.get("llavamed_unsam_original_path", "") or (original_img or "")
        lu_bboxes, lu_label_map_b64 = None, ""
        if lu_segmap and lu_original:
            # Segment count comes from the log-prob result when present,
            # otherwise from the generation result.
            seg_source = data["llavamed_unsam_logprob"] if has_logprob else data["llavamed_unsam_gen"]
            n_lu_segs = len(seg_source.get("image_region_values", []))
            if n_lu_segs > 0:
                try:
                    lu_bboxes, lu_label_map_b64 = extract_segment_regions(
                        lu_original, lu_segmap, n_lu_segs)
                except Exception as exc:
                    print(f"[isic] Segment extraction failed for {example_id}: {exc}")

        if has_logprob:
            logprob_result = data["llavamed_unsam_logprob"]
            lu_lp = rename_patch_labels(logprob_result.get("image_region_values", []))
            if lu_lp:
                llavamed_unsam_lp_plot = _bar_chart(
                    [v["label"] for v in lu_lp],
                    [v["value"] for v in lu_lp],
                    "LLaVA-Med Log-Prob — Segment Shapley Values",
                )
            overlay_path = logprob_result.get("overlay_path", "")
            if overlay_path:
                llavamed_unsam_lp_overlay_img = draw_segment_labels(
                    overlay_path,
                    lu_lp,
                    segment_bboxes=lu_bboxes,
                    label_map_b64=lu_label_map_b64,
                    original_path=lu_original,
                )

        if has_gen:
            gen_result = data["llavamed_unsam_gen"]
            lu_gen = rename_patch_labels(gen_result.get("image_region_values", []))
            if lu_gen:
                llavamed_unsam_gen_plot = _bar_chart(
                    [v["label"] for v in lu_gen],
                    [v["value"] for v in lu_gen],
                    "LLaVA-Med Generation — Segment Shapley Values",
                )
            # Fall back to the log-prob overlay when the generation result
            # does not carry its own.
            overlay_path = (gen_result.get("overlay_path", "")
                            or data.get("llavamed_unsam_logprob", {}).get("overlay_path", ""))
            if overlay_path:
                llavamed_unsam_gen_overlay_img = draw_segment_labels(
                    overlay_path,
                    lu_gen,
                    segment_bboxes=lu_bboxes,
                    label_map_b64=lu_label_map_b64,
                    original_path=lu_original,
                )

    # ---- Interpretation text (best effort) ----
    interpretation = ""
    try:
        bc_data = data.get("biomedclip", {}).get("summary") if data.get("has_biomedclip") else None
        interpretation = generate_interpretation_text(
            clip_summary=bc_data,
            vllm_logprob=data.get("llavamed_unsam_logprob") if has_logprob else None,
            modality="Dermoscopy",
            body_part=category,
            caption=caption,
            cross_method_name="BiomedCLIP",
            vlm_method_name="LLaVA-Med",
            vlm_region_type="UnSAM segments",
        )
    except Exception as exc:
        print(f"[isic] Interpretation text failed for {example_id}: {exc}")

    result_flags = ("has_biomedclip", "has_llavamed_unsam_logprob",
                    "has_llavamed_unsam_gen", "has_clip")
    if not any(data.get(flag) for flag in result_flags):
        interpretation = (
            "**No precomputed attribution results yet.**\n\n"
            "Run the attribution pipeline on this ISIC example to see results here. "
            "The image and caption are shown above for reference."
        )

    # Only methods that produced an overlay are offered for comparison.
    _results_state = {}
    if biomedclip_overlay_labeled:
        _results_state["BiomedCLIP Cross-Modal"] = {
            "overlay": biomedclip_overlay_labeled, "plot": biomedclip_region_plot}
    if llavamed_unsam_lp_overlay_img:
        _results_state["LLaVA-Med Log-Prob"] = {
            "overlay": llavamed_unsam_lp_overlay_img, "plot": llavamed_unsam_lp_plot}
    if llavamed_unsam_gen_overlay_img:
        _results_state["LLaVA-Med Generation"] = {
            "overlay": llavamed_unsam_gen_overlay_img, "plot": llavamed_unsam_gen_plot}

    return (
        caption,
        original_img,
        interpretation,
        biomedclip_overlay_labeled,
        biomedclip_token_plot,
        biomedclip_region_plot,
        llavamed_unsam_lp_overlay_img,
        llavamed_unsam_lp_plot,
        llavamed_unsam_gen_overlay_img,
        llavamed_unsam_gen_plot,
        biomedclip_interaction_html,
        {
            "example_id": example_id,
            "category": category,
            "has_biomedclip": data.get("has_biomedclip", False),
            "has_llavamed_unsam_logprob": data.get("has_llavamed_unsam_logprob", False),
            "has_llavamed_unsam_gen": data.get("has_llavamed_unsam_gen", False),
        },
        _results_state,
        gr.update(),
    )
|
|
|
|
| def _on_isic_compare_methods(method_a, method_b, results_state): |
| """Pick two ISIC methods from state and display side by side.""" |
| if not method_a or not method_b or not results_state: |
| return None, None, None, None |
| a = results_state.get(method_a, {}) |
| b = results_state.get(method_b, {}) |
| return a.get("overlay"), b.get("overlay"), a.get("plot"), b.get("plot") |
|
|
|
|
def _on_select_coco_example(example_id, method_label: str = "Influence"):
    """Load a precomputed MS-COCO example and return outputs for the COCO tab.

    Args:
        example_id: identifier handed to ``load_coco_example``. A falsy value
            yields the empty result.
        method_label: attribution method label. It is lower-cased before being
            passed to the loader; charts and views are labelled "Influence"
            when it equals "influence" and "Shapley" otherwise.

    Returns:
        A 12-tuple ordered as:
            (
                caption,
                original_image_path,
                overlay_labeled,
                interaction_html,
                token_plot,
                region_plot,
                cross_plot,
                cross_table,
                heatmap,
                note,
                first_masked_image,
                masked_dropdown_update,
            )
        When ``_COCO_AVAILABLE`` or ``_MEDICAL_AVAILABLE`` is falsy,
        ``example_id`` is falsy, or loading fails, the tuple is ``""``, ten
        ``None`` values and a bare ``gr.update()``.
    """
    from html import escape

    n_outputs = 12
    empty = ("",) + (None,) * (n_outputs - 2) + (gr.update(),)
    if not _COCO_AVAILABLE or not _MEDICAL_AVAILABLE or not example_id:
        return empty

    method = (method_label or "Influence").lower()
    method_display = "Influence" if method == "influence" else "Shapley"

    # Module-level renderers, looked up by name so the local wrappers below
    # can inject the method label without shadowing them.
    _base_chart = globals()["create_shapley_bar_chart"]
    _base_html = globals()["create_benchmark_interaction_html"]

    def _bar_chart(labels, values, title="Shapley Values", **kwargs):
        # Titles are written with "Shapley"; swap in the active method name.
        kwargs.setdefault("method_label", method_display)
        return _base_chart(labels, values, title.replace("Shapley", method_display), **kwargs)

    def _interaction_view(**kwargs):
        kwargs.setdefault("method_label", method_display)
        return _base_html(**kwargs)

    try:
        data = load_coco_example(example_id, method=method)
    except Exception as exc:
        print(f"[coco] Error loading {example_id}: {exc}")
        return empty

    caption = data.get("caption", "")
    summary = data.get("summary", {})
    # Read the path mapping once and tolerate a missing/None entry, so every
    # lookup below behaves the same way.
    image_paths = data.get("image_paths") or {}
    original_img = image_paths.get("original")
    overlay_img = image_paths.get("overlay")
    clip_segmap = image_paths.get("segmap", "")
    region_values = summary.get("image_region_values", [])
    n_segs = len(region_values)

    # Segment geometry is optional: on failure the views are built without
    # bounding boxes / label map.
    segment_bboxes, label_map_b64 = None, ""
    if original_img and clip_segmap and n_segs > 0:
        try:
            segment_bboxes, label_map_b64 = extract_segment_regions(
                original_img, clip_segmap, n_segs)
        except Exception as exc:
            print(f"[coco] Segment extraction failed for {example_id}: {exc}")

    # Fall back to the unlabelled overlay when labelling fails or yields
    # nothing.
    overlay_labeled = overlay_img
    if overlay_img:
        try:
            labeled = draw_segment_labels(
                overlay_img,
                region_values,
                segment_bboxes=segment_bboxes,
            )
            if labeled:
                overlay_labeled = labeled
        except Exception as exc:
            print(f"[coco] Overlay labelling failed for {example_id}: {exc}")

    # Per-region and per-word bar charts.
    r_labels = [v["label"] for v in region_values]
    r_values = [v["value"] for v in region_values]
    region_plot = (
        _bar_chart(r_labels, r_values, "CLIP — Image Region Shapley Values")
        if r_labels else None
    )

    merged_toks = merge_subword_token_values(summary.get("token_values", []), caption)
    t_labels = [v["label"] for v in merged_toks]
    t_values = [v["value"] for v in merged_toks]
    token_plot = (
        _bar_chart(t_labels, t_values, "CLIP — Caption Word Shapley Values")
        if t_labels else None
    )

    # Cross-modal (region x word) interactions: chart plus a table of the
    # first 30 pairs.
    all_cross = data.get("all_cross_modal_pairs", [])
    cross_plot = None
    cross_table = []
    if all_cross:
        cross_pairs = [
            ((item["pair"][0], _tok_to_word(item["pair"][1], caption)), item["value"])
            for item in all_cross
        ]
        cross_plot = create_cross_modal_bar_chart(
            cross_pairs, "CLIP — Top Image x Word Interactions", top_k=20)
        # Reuse the already-resolved words instead of mapping tokens again.
        cross_table = [
            [segment, word, f"{value:+.4f}"]
            for (segment, word), value in cross_pairs[:30]
        ]

    # Influence heatmap, only when a non-empty matrix is present.
    heatmap = None
    influence_matrix = data.get("influence_matrix")
    tok_labels_hm = [t.replace("tok:", "").lstrip("#") for t in data.get("tok_labels", [])]
    if influence_matrix is not None and influence_matrix.size > 0:
        heatmap = create_influence_heatmap(
            data.get("seg_labels", r_labels), tok_labels_hm, influence_matrix,
            "Image Regions x Caption Words — Influence Scores")

    images_b64 = data.get("image_b64") or {}
    interaction_html = ""
    try:
        interaction_html = _interaction_view(
            image_b64=images_b64.get("original", ""),
            clip_summary=summary,
            vllm_logprob=None,
            caption=caption,
            all_cross_modal_pairs=all_cross,
            segmap_b64=images_b64.get("segmap", ""),
            overlay_b64=images_b64.get("overlay", ""),
            segment_bboxes=segment_bboxes,
            label_map_b64=label_map_b64,
            title="MS-COCO — Click a region or word to explore interactions",
        )
    except Exception as exc:
        # The message is rendered as HTML, so escape the exception text.
        interaction_html = f"<p>Error building interaction view: {escape(str(exc))}</p>"

    note = (
        "**Note:** These results used the original UNK mask token "
        "(same as `<|endoftext|>`, CLIP token ID 49407). "
        "A first-token dominance artifact may be visible in the token Shapley chart. "
        "This will be corrected when scaling to 100 images with the dot-mask fix."
    )

    # Masked-image browser: populate the dropdown and preselect its first
    # entry, loading the matching image when possible.
    region_choices = data.get("region_choices", [])
    masked_dd_update = gr.update(
        choices=region_choices,
        value=region_choices[0] if region_choices else None,
    )
    first_masked_img = None
    if region_choices:
        try:
            first_masked_img = get_coco_masked_image_path(example_id, region_choices[0])
        except Exception as exc:
            print(f"[coco] Masked image lookup failed for {example_id}: {exc}")

    return (
        caption,
        original_img,
        overlay_labeled,
        interaction_html,
        token_plot,
        region_plot,
        cross_plot,
        cross_table,
        heatmap,
        note,
        first_masked_img,
        masked_dd_update,
    )
|
|
|
|
def _on_select_coco_masked(example_id, choice):
    """Look up the masked image for one region choice of a COCO example.

    Returns ``None`` when COCO support is flagged unavailable or when either
    the example id or the choice is empty; otherwise returns whatever
    ``get_coco_masked_image_path`` yields for the pair.
    """
    if _COCO_AVAILABLE and example_id and choice:
        return get_coco_masked_image_path(example_id, choice)
    return None
|
|
|
|
def on_click_image_compute(
    image,
    caption,
    clip_model,
    seg_mode,
    grid_size,
    method,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Forward the handler inputs to ``_compute_image_attributions_clip``.

    Every argument, including the Gradio progress tracker, is passed through
    by keyword and the callee's result is returned unchanged.
    """
    forwarded = {
        "image": image,
        "caption": caption,
        "clip_model": clip_model,
        "seg_mode": seg_mode,
        "grid_size": grid_size,
        "method": method,
        "seed": seed,
        "progress": progress,
    }
    return _compute_image_attributions_clip(**forwarded)
|
|
|
|
def on_click_mm_compute(
    image,
    caption,
    clip_model,
    seg_mode,
    grid_size,
    method,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Forward the handler inputs to ``_compute_mm_attributions_clip``.

    Every argument, including the Gradio progress tracker, is passed through
    by keyword and the callee's result is returned unchanged.
    """
    forwarded = {
        "image": image,
        "caption": caption,
        "clip_model": clip_model,
        "seg_mode": seg_mode,
        "grid_size": grid_size,
        "method": method,
        "seed": seed,
        "progress": progress,
    }
    return _compute_mm_attributions_clip(**forwarded)
|
|
|
|
| |
| |
| |
|
|
# Canned sentence rendered by the demo widgets.
_DEMO_TEXT = "The quick brown fox jumps over the lazy dog in a sunny meadow."

# Word tokens of the sentence and their (start, end) character offsets. Both
# come from the same regex so the two lists always line up.
_DEMO_FEATURES = re.findall(r"\w+", _DEMO_TEXT)
_DEMO_SPANS = [word.span() for word in re.finditer(r"\w+", _DEMO_TEXT)]

# Hand-written "shapley" scores, listed in the same order as _DEMO_FEATURES.
_DEMO_ATTRIBUTIONS: Dict[str, Dict[str, float]] = {
    "shapley": dict(
        zip(
            _DEMO_FEATURES,
            (
                -0.04,  # The
                0.18,   # quick
                0.12,   # brown
                0.27,   # fox
                0.15,   # jumps
                0.05,   # over
                -0.02,  # the
                -0.11,  # lazy
                -0.07,  # dog
                0.03,   # in
                0.02,   # a
                0.09,   # sunny
                0.21,   # meadow
            ),
        )
    ),
}
# The other two methods are derived from the shapley scores: "banzhaf" scales
# them by 0.8 and "influence" takes the magnitude; both round to 3 decimals.
_DEMO_ATTRIBUTIONS.update(
    banzhaf={
        tok: round(score * 0.8, 3)
        for tok, score in _DEMO_ATTRIBUTIONS["shapley"].items()
    },
    influence={
        tok: round(abs(score), 3)
        for tok, score in _DEMO_ATTRIBUTIONS["shapley"].items()
    },
)

# Canned pairwise interactions: (token pair, score).
_DEMO_INTERACTIONS_2: List[Tuple[Tuple[str, ...], float]] = [
    (("quick", "fox"), 0.24),
    (("fox", "jumps"), 0.19),
    (("sunny", "meadow"), 0.22),
    (("lazy", "dog"), -0.17),
    (("the", "lazy"), -0.12),
]

# Canned third-order interactions: (token triple, score).
_DEMO_INTERACTIONS_3: List[Tuple[Tuple[str, ...], float]] = [
    (("quick", "brown", "fox"), 0.28),
    (("fox", "jumps", "over"), 0.18),
    (("sunny", "meadow", "dog"), 0.11),
    (("the", "lazy", "dog"), -0.21),
]

# Sparse pairwise scores keyed by index pairs into _DEMO_FEATURES.
_DEMO_INTERACTION_MATRIX: List[Tuple[Tuple[int, int], float]] = [
    ((1, 3), 0.23),
    ((3, 4), 0.17),
    ((7, 8), -0.18),
    ((11, 12), 0.2),
    ((2, 5), 0.09),
]

# Built-in example rows per demo dataset, each row being
# [context, prompt, answer].
_DEMO_DATASETS = {
    "squad_demo": [
        [
            "The quick brown fox jumps over the lazy dog.",
            "Who jumps over the dog?",
            "The quick brown fox",
        ],
        [
            "AttrLLM explains attributions for large language models.",
            "What does AttrLLM explain?",
            "Attributions",
        ],
    ],
    "truthfulqa_demo": [
        [
            "Water boils at 100 degrees Celsius at sea level.",
            "At what temperature does water boil?",
            "100 degrees Celsius",
        ],
    ],
}
|
|
|
|
def _render_demo(method: str = "shapley"):
    """Assemble the canned demo outputs for one attribution method.

    Args:
        method: Attribution method name. A falsy value is treated as
            "shapley" and the name is lower-cased. When the name has no entry
            in ``_DEMO_ATTRIBUTIONS`` the "shapley" scores are used, but the
            name itself is still passed on to the plotting helpers.

    Returns:
        Whatever ``update`` returns for the interactions figure, the text
        heatmap HTML, the text-interaction HTML and the metadata dict built
        here; all scoring-related fields are passed blank or ``None``.
    """
    method = (method or "shapley").lower()
    # The demo currently always renders pairwise interactions; the order-3
    # table stays selectable through this single constant.
    order = 2
    attributions = _DEMO_ATTRIBUTIONS.get(method, _DEMO_ATTRIBUTIONS["shapley"])
    interactions = _DEMO_INTERACTIONS_3 if order == 3 else _DEMO_INTERACTIONS_2

    # Per-feature scores in display order, shared by both text views.
    feature_values = [attributions[token] for token in _DEMO_FEATURES]

    interactions_fig = plot_top_interactions(interactions, order=order, method=method)
    text_html = create_interactive_text_heatmap(
        _DEMO_TEXT,
        _DEMO_SPANS,
        feature_values,
        method=method,
    )
    text_interaction_html = create_text_interaction_html(
        _DEMO_FEATURES,
        feature_values,
        [
            {"indices": [i, j], "value": float(val)}
            for (i, j), val in _DEMO_INTERACTION_MATRIX
        ],
        method=method,
        top_k=20,
        threshold=0.0,
    )

    meta = {
        "method": method,
        "order": order,
        "feature_count": len(_DEMO_FEATURES),
        "scalarizer": "logprob",
    }

    return update(
        figs={
            "interactions": interactions_fig,
        },
        meta=meta,
        html=text_html,
        interaction_text_html=text_interaction_html,
        scoring_target_source="model_output",
        scoring_target_text="",
        reference_answer="",
        unmasked_answer="",
        debug_scores=None,
        scalarizer_used="logprob",
        score_full=None,
        score_empty=None,
        y_len_tokens=None,
    )
|
|
|
|
def _render_additional_plots(method: str = "shapley"):
    """Return the demo interaction-matrix figure.

    ``method`` is accepted so the handler can share inputs with the other
    render callback, but the figure does not depend on it.
    """
    matrix_figure = plot_interaction_matrix(_DEMO_FEATURES, _DEMO_INTERACTION_MATRIX)
    return matrix_figure
|
|
|
|
def _records_for_dataset(dataset_name: str) -> List[Dict[str, Any]]:
    """Return up to ten example records for ``dataset_name``.

    Sources are tried in order and the first non-empty one wins:
    ``get_examples`` (when available; any error it raises is ignored), then
    ``_fallback_load_dataset``, then the rows in ``_DEMO_DATASETS``. An
    unknown dataset with no source yields an empty list.
    """
    if get_examples is not None:
        # Best effort: a failing loader just means we fall through.
        try:
            loaded = get_examples(dataset_name, n=10)
            usable = bool(loaded)
        except Exception:
            usable = False
        if usable:
            return loaded

    csv_rows = _fallback_load_dataset(dataset_name, max_rows=10)
    if csv_rows:
        return csv_rows

    # Last resort: wrap the built-in demo rows as records with 1-based ids.
    return [
        {
            "id": f"{dataset_name}_demo_{position}",
            "context": context,
            "prompt": prompt,
            "correct_answer": answer,
        }
        for position, (context, prompt, answer) in enumerate(
            _DEMO_DATASETS.get(dataset_name, []), start=1
        )
    ]
|
|
|
|
def _available_datasets() -> List[str]:
    """List the dataset keys the UI can offer.

    Preference order: the non-empty result of ``list_datasets`` (errors are
    ignored), then the sorted keys of ``_FALLBACK_DATASET_FILES`` whose file
    exists under ``_fallback_datasets_dir()``, then the keys of
    ``_DEMO_DATASETS``.
    """
    if list_datasets is not None:
        try:
            registered = list_datasets()
            found = bool(registered)
        except Exception:
            found = False
        if found:
            return registered

    on_disk = []
    for key, filename in _FALLBACK_DATASET_FILES.items():
        if (_fallback_datasets_dir() / filename).exists():
            on_disk.append(key)
    if on_disk:
        return sorted(on_disk)

    return list(_DEMO_DATASETS)
|
|
|
|
def _format_examples(records: List[Dict[str, Any]]) -> List[List[str]]:
    """Convert records into ``[context, prompt, answer]`` rows.

    The answer is the first truthy value among the ``correct_answer``,
    ``answer`` and ``target`` keys, or ``""`` when none is set. Missing
    ``context``/``prompt`` keys become ``""``.
    """
    return [
        [
            rec.get("context", ""),
            rec.get("prompt", ""),
            rec.get("correct_answer") or rec.get("answer") or rec.get("target") or "",
        ]
        for rec in records
    ]
|
|
|
|
def _load_examples_for_demo(dataset_name: str):
    """Return a Gradio update carrying the example rows for a dataset choice.

    The dropdown value may be a display label rather than a dataset key, so it
    is resolved with ``_resolve_dataset_key`` (the same resolution the slider
    loader uses). That keeps labels working even when
    ``get_dataset_key_from_display_name`` is unavailable.

    Args:
        dataset_name: Dataset key or display label selected in the UI.

    Returns:
        ``gr.update(samples=...)`` with the formatted records, falling back to
        the ``_DEMO_DATASETS`` rows for the key, or an empty list.
    """
    dataset_key = _resolve_dataset_key(dataset_name)
    rows = _format_examples(_records_for_dataset(dataset_key))
    if not rows:
        rows = _DEMO_DATASETS.get(dataset_key, [])
    return gr.update(samples=rows or [])
|
|
|
|
def _resolve_example_fields(record: Dict[str, Any]) -> Tuple[str, str, str]:
    """Extract ``(context, prompt, answer)`` from one example record.

    ``context`` and ``prompt`` default to ``""`` when absent. The answer is
    the first truthy value among ``correct_answer``, ``answer`` and
    ``target``, or ``""`` if none is set.
    """
    context = record.get("context", "")
    prompt = record.get("prompt", "")
    answer = ""
    for key in ("correct_answer", "answer", "target"):
        candidate = record.get(key)
        if candidate:
            answer = candidate
            break
    return context, prompt, answer
|
|
|
|
def _resolve_dataset_key(dataset_name: str) -> str:
    """Map a UI dataset value (key or display label) to a dataset key.

    Resolution order: the name itself if it is an available dataset; the key
    whose ``DATASET_DISPLAY_LABELS`` entry equals the name; the result of
    ``get_dataset_key_from_display_name`` when that helper exists; otherwise
    the name unchanged.
    """
    if dataset_name in _available_datasets():
        return dataset_name

    not_found = object()
    key_for_label = next(
        (key for key, label in DATASET_DISPLAY_LABELS.items() if dataset_name == label),
        not_found,
    )
    if key_for_label is not not_found:
        return key_for_label

    if get_dataset_key_from_display_name is None:
        return dataset_name
    return get_dataset_key_from_display_name(dataset_name)
|
|
|
|
def _dataset_choice_labels(dataset_keys: List[str]) -> List[str]:
    """Return one display label per dataset key, in the same order.

    ``get_dataset_display_name`` is preferred when available. If it is missing
    or raises, the label comes from ``DATASET_DISPLAY_LABELS``, defaulting to
    the key with underscores replaced by spaces and title-cased.
    """

    def _label_for(key: str) -> str:
        if get_dataset_display_name is not None:
            try:
                return get_dataset_display_name(key)
            except Exception:
                pass  # fall back to the static label table below
        return DATASET_DISPLAY_LABELS.get(key, key.replace("_", " ").title())

    return [_label_for(key) for key in dataset_keys]
|
|
|
|
def _resolve_example_index(example_number: Any, records: List[Dict[str, Any]]) -> int:
    """Turn a 1-based example number into a valid 0-based index.

    Returns 0 when ``records`` is empty or the number cannot be converted to
    ``int``; otherwise the index is clamped to ``[0, len(records) - 1]``.
    """
    if not records:
        return 0
    try:
        zero_based = int(example_number) - 1
    except Exception:
        return 0
    last = len(records) - 1
    if zero_based < 0:
        return 0
    return last if zero_based > last else zero_based
|
|
|
|
def _resolve_example_id(example_number: Any, records: List[Dict[str, Any]]) -> str:
    """Return the identifier of the selected example.

    In public-only mode the id is ``example_<n>`` built directly from the
    example number and ``records`` is not consulted. A missing or
    non-numeric number falls back to 1, matching the fallback used by
    ``_build_model_answer_panel``, so bad input no longer raises.

    Otherwise the number is resolved to an index into ``records`` and the
    record's ``id`` is used, defaulting to ``example_<index + 1>`` when the
    record has no id or there are no records.
    """
    if _public_only_mode():
        try:
            number = int(example_number or 1)
        except (TypeError, ValueError, OverflowError):
            number = 1
        return f"example_{number}"
    index = _resolve_example_index(example_number, records)
    record = records[index] if records else {}
    return str(record.get("id") or f"example_{index + 1}")
|
|
|
|
def _build_model_answer_panel(dataset_name: str, example_number: Any) -> str:
    """Render the "Model's Answer" + "Justification" HTML panel.

    The panel is produced only for (dataset, example) pairs listed in
    ``WRONG_ANSWER_EXAMPLES`` whose answer file exists under
    ``<results>/model_answers/small/<dataset>/<example>.json``. In every other
    case an empty string is returned so the ``gr.HTML`` slot stays visually
    empty. That includes a missing examples module, an unreadable or
    malformed file, and a JSON payload that is not an object.

    Args:
        dataset_name: Dataset key or display label; may be empty.
        example_number: 1-based example number; anything that cannot be
            converted to ``int`` is treated as 1.

    Returns:
        The panel HTML with all file-derived text escaped, or ``""``.
    """
    from html import escape as _escape

    # Optional project module: any import-time failure means "no panel".
    try:
        from visualization.wrong_answer_examples import WRONG_ANSWER_EXAMPLES
    except Exception:
        return ""

    dataset_key = _resolve_dataset_key(dataset_name) if dataset_name else ""
    try:
        ex_id = f"example_{int(example_number or 1)}"
    except (TypeError, ValueError, OverflowError):
        ex_id = "example_1"
    if (dataset_key, ex_id) not in WRONG_ANSWER_EXAMPLES:
        return ""

    path = (
        _get_results_dir() / "model_answers" / "small" / dataset_key / f"{ex_id}.json"
    )
    if not path.exists():
        return ""
    try:
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers both JSON syntax errors and bad UTF-8.
        return ""
    if not isinstance(data, dict):
        return ""

    def _text(*keys: str) -> str:
        """Return the first truthy value among ``keys`` as stripped text."""
        for key in keys:
            value = data.get(key)
            if value:
                # str() keeps numeric answers from crashing on .strip().
                return str(value).strip()
        return ""

    letter = _text("model_answer_parsed")
    raw = _text("model_answer_raw")
    gt_letter = _text("ground_truth_letter", "ground_truth")
    is_match = bool(data.get("is_match"))
    similarity = data.get("similarity")
    try:
        sim_str = f"{float(similarity):.3f}" if similarity is not None else "—"
    except (TypeError, ValueError):
        sim_str = "—"

    if is_match:
        chip_bg, chip_fg, chip_text = "#e7f6ec", "#1f8d4a", "✓ MATCH"
    else:
        chip_bg, chip_fg, chip_text = "#fdecea", "#c0392b", "✗ MISMATCH"

    # Drop a leading "Justification:" marker (any case), or everything up to
    # and including an embedded "Justification:" marker.
    justification = raw
    if raw.lower().startswith("justification:"):
        justification = raw.split(":", 1)[1].strip()
    elif "Justification:" in raw:
        justification = raw.split("Justification:", 1)[1].strip()

    return (
        '<div style="display:flex;flex-direction:column;gap:8px;'
        'background:#fdf7ff;border:1px solid #e2d6f3;border-radius:10px;'
        'padding:12px 14px;margin-top:6px;">'
        '<div style="display:flex;align-items:center;gap:10px;flex-wrap:wrap;">'
        '<strong style="color:#4a1c87;">Model\'s Answer</strong>'
        f'<span style="background:#fff;border:1px solid #d8c6f0;border-radius:6px;'
        f'padding:2px 8px;font-weight:600;color:#4a1c87;">{_escape(letter or "—")}</span>'
        f'<span style="color:#6f5a72;font-size:12px;">vs Ground Truth: '
        f'<strong>{_escape(gt_letter or "—")}</strong></span>'
        f'<span style="background:{chip_bg};color:{chip_fg};border-radius:999px;'
        f'padding:2px 10px;font-size:12px;font-weight:600;">{chip_text}</span>'
        f'<span style="color:#6f5a72;font-size:12px;">sim={sim_str}</span>'
        '</div>'
        '<div style="font-size:13px;color:#3a2b4a;line-height:1.55;'
        'background:#fff;border:1px solid #ece4f8;border-radius:8px;padding:10px 12px;">'
        f'<strong style="display:block;margin-bottom:4px;color:#4a1c87;">Justification</strong>'
        f'{_escape(justification) if justification else "<em>No justification captured.</em>"}'
        '</div>'
        '</div>'
    )
|
|
|
|
def _load_examples_for_slider(dataset_name: str):
    """Load a dataset's records and prime the example slider.

    Returns a 5-tuple: a ``gr.update`` that resets the slider to 1 with a
    maximum between 1 and 10 (10 when there are no records), the records
    themselves, and the context/prompt/answer of the first record (empty
    strings when there are no records).
    """
    records = _records_for_dataset(_resolve_dataset_key(dataset_name))
    upper_bound = max(1, min(10, len(records) or 10))
    first_fields = _resolve_example_fields(records[0]) if records else ("", "", "")
    context, prompt, answer = first_fields
    slider_update = gr.update(minimum=1, maximum=upper_bound, step=1, value=1)
    return slider_update, records, context, prompt, answer
|
|
|
|
def _update_example_preview(example_number: Any, records):
    """Return ``(context, prompt, answer)`` for the selected example.

    Yields three empty strings when there are no records; otherwise the
    example number is clamped to a valid index before the fields are read.
    """
    if records:
        selected = records[_resolve_example_index(example_number, records)]
        return _resolve_example_fields(selected)
    return "", "", ""
|
|
|
|
def _results_output_list(results: Dict[str, Any]) -> List[Any]:
    """Flatten the results mapping into the fixed output order.

    The returned list holds the values of the keys below, in this exact
    order; a missing key raises ``KeyError``.
    """
    output_order = (
        "interactions",
        "interactions_tokens_html",
        "interactions_text_html",
        "text_html",
        "meta",
        "scoring_target_source",
        "scoring_target_text",
        "reference_answer",
        "unmasked_answer",
        "debug_scores",
        "scalarizer_used",
        "score_full",
        "score_empty",
        "y_len_tokens",
    )
    return [results[key] for key in output_order]
|
|
|
|
def build_demo_app() -> gr.Blocks:
    """Build the Gradio Blocks app that previews the visualization widgets.

    The app offers a dataset dropdown wired to the example browser, the model,
    scalarizer, feature-level and method controls, and a render button that
    fills the results display and an extra interaction-matrix plot from the
    canned demo data.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    datasets = _available_datasets()

    custom_css = """
    .gradio-container {
        font-family: 'Inter', -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue", Arial, sans-serif !important;
        background: linear-gradient(135deg, #fef5f0 0%, #f0e8ff 50%, #e8f5ff 100%) !important;
        padding: 24px !important;
    }

    .gradio-container h1, .gradio-container h2 {
        background: linear-gradient(135deg, #ff6b6b 0%, #ee5a6f 30%, #c44569 60%, #6c5ce7 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
        font-weight: 900;
        font-size: 42px !important;
        margin: 20px 0 16px 0;
        letter-spacing: -0.03em;
    }

    label, .gr-label {
        font-weight: 700 !important;
        font-size: 16px !important;
        color: #2d1f4a !important;
    }

    .gr-button {
        border-radius: 16px !important;
        font-weight: 700 !important;
        font-size: 17px !important;
        padding: 16px 32px !important;
        background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important;
        color: white !important;
        border: none !important;
    }

    .gr-box, .gr-input, .gr-dropdown, .gr-textbox {
        border-radius: 14px !important;
        border: 3px solid #e8dff5 !important;
        font-size: 17px !important;
    }

    .gr-markdown p {
        font-size: 17px !important;
        font-weight: 500 !important;
    }
    """

    # Only pass ``css`` when this Gradio version's Blocks accepts it.
    _demo_kwargs = {"title": "AttrLLM Visualization Demo"}
    if _supports_kwarg(gr.Blocks, "css"):
        _demo_kwargs["css"] = custom_css
    with gr.Blocks(**_demo_kwargs) as demo:
        gr.Markdown(
            "# 🎨 AttrLLM Visualization Demo\n\n"
            "**Preview the attribution widgets** before wiring real backends. "
            "Use the controls below to explore the interface."
        )

        with gr.Row():
            with gr.Column(scale=1):
                # The dropdown shows display labels; handlers map them back to keys.
                initial_choices = _dataset_choice_labels(datasets)
                initial_value = initial_choices[0] if initial_choices else None

                dataset_selector = gr.Dropdown(
                    choices=initial_choices,
                    value=initial_value,
                    label="Dataset",
                    interactive=True,
                    allow_custom_value=False,
                    elem_id="dataset-selector-demo",
                    elem_classes=["bubble-select"],
                )
                example_browser = create_example_browser()
            with gr.Column(scale=1):
                model_selector = create_model_selector()
                scalarizer_selector = gr.Dropdown(
                    choices=SCALARIZER_CHOICES,
                    value="logprob",
                    label="Scalarizer",
                    interactive=True,
                )
                embedding_model_box = gr.Textbox(
                    label="Embedding Model (for scalarizer=embedding)",
                    value="Qwen/Qwen3-Embedding-0.6B",
                    lines=1,
                )
                feature_level_selector = create_feature_level_selector()
                method_toggle = create_attribution_method_toggle()

        # Refresh the example browser on dataset change and on first page load.
        dataset_selector.change(
            fn=_load_examples_for_demo,
            inputs=dataset_selector,
            outputs=example_browser,
        )
        demo.load(
            fn=_load_examples_for_demo,
            inputs=[dataset_selector],
            outputs=[example_browser],
        )

        render_button = gr.Button("Render Demo Visuals", variant="primary")

        outputs = create_results_display()
        extra_matrix = gr.Plot(label="Interaction Matrix (demo)")

        # Two handlers on the same click: the main results and the extra plot.
        render_button.click(
            fn=_render_demo,
            inputs=[method_toggle],
            outputs=_results_output_list(outputs),
        )

        render_button.click(
            fn=_render_additional_plots,
            inputs=[method_toggle],
            outputs=[extra_matrix],
        )

    return demo
|
|
|
|
def _patch_gradio_schema_generation() -> None:
    """Prevent Gradio 5.x /info crash caused by additionalProperties: true in schemas.

    Wraps ``gradio_client.utils._json_schema_to_python_type`` and
    ``json_schema_to_python_type`` so boolean sub-schemas are replaced by
    dict schemas before the originals run. Does nothing when
    ``gradio_client`` cannot be imported, when either function is missing, or
    when the patch marker is already set.
    """
    try:
        from gradio_client import utils as client_utils
    except Exception:
        return
    if getattr(client_utils, "_attrllm_schema_patch", False):
        return
    original_inner = getattr(client_utils, "_json_schema_to_python_type", None)
    original_outer = getattr(client_utils, "json_schema_to_python_type", None)
    if not (callable(original_inner) and callable(original_outer)):
        return

    # Keywords grouped by the shape of the sub-schema(s) they hold.
    mapping_keywords = ("properties", "$defs", "definitions", "patternProperties")
    single_keywords = ("items", "contains", "not", "if", "then", "else")
    sequence_keywords = ("anyOf", "allOf", "oneOf", "prefixItems")

    def _normalize_schema(schema):
        # Boolean schemas: True -> empty schema, False -> {"type": "null"}.
        if isinstance(schema, bool):
            return {} if schema else {"type": "null"}
        if isinstance(schema, list):
            return [_normalize_schema(item) for item in schema]
        if not isinstance(schema, dict):
            return schema

        normalized = {}
        for key, value in schema.items():
            if key == "additionalProperties":
                # Only the boolean form is rewritten here.
                if isinstance(value, bool):
                    value = _normalize_schema(value)
            elif key in mapping_keywords:
                if isinstance(value, dict):
                    value = {k: _normalize_schema(v) for k, v in value.items()}
            elif key in single_keywords:
                value = _normalize_schema(value)
            elif key in sequence_keywords:
                if isinstance(value, list):
                    value = [_normalize_schema(item) for item in value]
            normalized[key] = value
        return normalized

    def _patched_inner(s, d=None):
        return original_inner(_normalize_schema(s), d)

    def _patched_outer(s):
        return original_outer(_normalize_schema(s))

    client_utils._json_schema_to_python_type = _patched_inner
    client_utils.json_schema_to_python_type = _patched_outer
    client_utils._attrllm_schema_patch = True


# Apply the patch once at import time.
_patch_gradio_schema_generation()
|
|
|
|
| def build_app() -> gr.Blocks: |
| datasets = _available_datasets() |
| default_dataset = datasets[0] if datasets else "" |
| public_only = _public_only_mode() |
| mm_only = _mm_only_mode() |
|
|
| |
| custom_css = """ |
| /* Main container styling - Warm gradient background */ |
| .gradio-container { |
| font-family: 'Inter', -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue", Arial, sans-serif !important; |
| background: linear-gradient(135deg, #fef5f0 0%, #f0e8ff 50%, #e8f5ff 100%) !important; |
| padding: 24px !important; |
| } |
| |
| /* Header styling - Large, bold, colorful */ |
| .gradio-container h1 { |
| background: linear-gradient(135deg, #ff6b6b 0%, #ee5a6f 30%, #c44569 60%, #6c5ce7 100%); |
| -webkit-background-clip: text; |
| -webkit-text-fill-color: transparent; |
| background-clip: text; |
| font-weight: 900; |
| font-size: 48px !important; |
| margin: 20px 0 16px 0; |
| letter-spacing: -0.03em; |
| text-align: left; |
| } |
| |
| .gradio-container h2 { |
| background: linear-gradient(135deg, #ff6b6b 0%, #ee5a6f 30%, #c44569 60%, #6c5ce7 100%); |
| -webkit-background-clip: text; |
| -webkit-text-fill-color: transparent; |
| background-clip: text; |
| font-weight: 900; |
| font-size: 42px !important; |
| margin: 20px 0 16px 0; |
| letter-spacing: -0.03em; |
| } |
| |
| .gradio-container h3 { |
| color: #2d1f4a; |
| font-weight: 800; |
| font-size: 24px !important; |
| margin: 24px 0 16px 0; |
| } |
| |
| /* Tab styling - Bold and colorful */ |
| .tab-nav { |
| border: none !important; |
| background: transparent !important; |
| gap: 8px !important; |
| padding: 8px 0 !important; |
| } |
| |
| .tab-nav button { |
| font-size: 18px !important; |
| font-weight: 700 !important; |
| padding: 16px 32px !important; |
| border-radius: 16px !important; |
| transition: all 0.3s ease !important; |
| border: 3px solid #e0d0f0 !important; |
| background: white !important; |
| color: #6c5ce7 !important; |
| margin-right: 8px !important; |
| } |
| |
| .tab-nav button:hover { |
| background: #f8f4ff !important; |
| border-color: #b8a8db !important; |
| transform: translateY(-2px) !important; |
| } |
| |
| .tab-nav button.selected { |
| background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important; |
| color: white !important; |
| border: 3px solid #6c5ce7 !important; |
| box-shadow: 0 6px 20px rgba(108, 92, 231, 0.3) !important; |
| } |
| |
| /* Button styling - Vibrant and interactive */ |
| .gr-button { |
| border-radius: 16px !important; |
| font-weight: 700 !important; |
| font-size: 17px !important; |
| padding: 16px 32px !important; |
| transition: all 0.3s cubic-bezier(0.34, 1.56, 0.64, 1) !important; |
| box-shadow: 0 6px 20px rgba(108, 92, 231, 0.2) !important; |
| border: none !important; |
| } |
| |
| .gr-button-primary { |
| background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important; |
| color: white !important; |
| } |
| |
| .gr-button-secondary { |
| background: linear-gradient(135deg, #fd79a8 0%, #ff7675 100%) !important; |
| color: white !important; |
| } |
| |
| .gr-button:hover { |
| transform: translateY(-3px) scale(1.02) !important; |
| box-shadow: 0 10px 30px rgba(108, 92, 231, 0.35) !important; |
| } |
| |
| .gr-button-primary:hover { |
| background: linear-gradient(135deg, #5e4ec7 0%, #9089e8 100%) !important; |
| } |
| |
| /* Input/Dropdown styling - Clear and modern */ |
| .gr-box, .gr-input, .gr-dropdown { |
| border-radius: 14px !important; |
| border: 3px solid #e8dff5 !important; |
| background: white !important; |
| font-size: 17px !important; |
| padding: 12px 16px !important; |
| transition: all 0.3s ease !important; |
| font-weight: 500 !important; |
| } |
| |
| .gr-box:focus, .gr-input:focus, .gr-dropdown:focus { |
| border-color: #6c5ce7 !important; |
| box-shadow: 0 0 0 4px rgba(108, 92, 231, 0.15) !important; |
| transform: translateY(-1px) !important; |
| } |
| |
| /* Textbox styling - Larger text */ |
| .gr-textbox { |
| border-radius: 16px !important; |
| border: 3px solid #e8dff5 !important; |
| font-size: 17px !important; |
| line-height: 1.6 !important; |
| } |
| |
| .gr-textbox textarea { |
| font-size: 17px !important; |
| line-height: 1.6 !important; |
| padding: 14px !important; |
| } |
| |
| .gr-textbox:focus-within { |
| border-color: #6c5ce7 !important; |
| box-shadow: 0 6px 24px rgba(108, 92, 231, 0.2) !important; |
| } |
| |
| /* Radio button styling - Colorful pills */ |
| .gr-radio { |
| gap: 12px !important; |
| } |
| |
| .gr-radio label { |
| font-size: 17px !important; |
| font-weight: 600 !important; |
| padding: 14px 28px !important; |
| border-radius: 14px !important; |
| border: 3px solid #e8dff5 !important; |
| transition: all 0.3s ease !important; |
| background: white !important; |
| cursor: pointer !important; |
| } |
| |
| .gr-radio label:hover { |
| border-color: #b8a8db !important; |
| background: #faf8ff !important; |
| transform: translateY(-2px) !important; |
| box-shadow: 0 4px 12px rgba(108, 92, 231, 0.15) !important; |
| } |
| |
| .gr-radio input:checked + label { |
| background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important; |
| color: white !important; |
| border-color: #6c5ce7 !important; |
| font-weight: 800 !important; |
| box-shadow: 0 6px 20px rgba(108, 92, 231, 0.3) !important; |
| } |
| |
| /* Panel/Accordion styling - Clean cards */ |
| .gr-panel { |
| border-radius: 20px !important; |
| border: 3px solid #e8dff5 !important; |
| padding: 24px !important; |
| background: white !important; |
| box-shadow: 0 6px 24px rgba(108, 92, 231, 0.1) !important; |
| margin: 16px 0 !important; |
| } |
| |
| .gr-accordion { |
| border-radius: 18px !important; |
| border: 3px solid #e8dff5 !important; |
| background: white !important; |
| } |
| |
| /* Label styling - Bold and readable */ |
| label, .gr-label { |
| font-weight: 700 !important; |
| font-size: 16px !important; |
| color: #2d1f4a !important; |
| margin-bottom: 10px !important; |
| letter-spacing: -0.01em !important; |
| } |
| |
| /* Dropdown options */ |
| .gr-dropdown-menu { |
| border-radius: 14px !important; |
| border: 3px solid #e8dff5 !important; |
| box-shadow: 0 8px 32px rgba(108, 92, 231, 0.15) !important; |
| font-size: 17px !important; |
| } |
| |
| .gr-dropdown-menu .item { |
| font-size: 17px !important; |
| padding: 12px 16px !important; |
| font-weight: 500 !important; |
| } |
| |
| .gr-dropdown-menu .item:hover { |
| background: linear-gradient(135deg, #f3f0ff 0%, #e8f5ff 100%) !important; |
| } |
| |
| /* Plot container - Prominent */ |
| .gr-plot { |
| border-radius: 20px !important; |
| border: 3px solid #e8dff5 !important; |
| overflow: hidden !important; |
| box-shadow: 0 8px 30px rgba(108, 92, 231, 0.12) !important; |
| background: white !important; |
| width: 100% !important; |
| } |
| |
| /* Force the inner Plotly canvas + svg to fill its container so the Bar |
| View doesn't render in a half-width column when the Text Interaction |
| view above it is wide. */ |
| .gr-plot .js-plotly-plot, |
| .gr-plot .plot-container, |
| .gr-plot .svg-container, |
| .gr-plot .main-svg { |
| width: 100% !important; |
| max-width: 100% !important; |
| } |
| .interaction-stack > .gradio-plot, |
| .interaction-stack > .block.gradio-plot, |
| .interaction-stack .gr-plot { |
| width: 100% !important; |
| max-width: 100% !important; |
| flex: 1 1 100% !important; |
| } |
| |
| /* JSON viewer */ |
| .gr-json { |
| border-radius: 16px !important; |
| border: 3px solid #e8dff5 !important; |
| background: #faf8ff !important; |
| padding: 20px !important; |
| font-family: 'Monaco', 'Menlo', 'Consolas', monospace !important; |
| font-size: 15px !important; |
| } |
| |
| /* Column styling */ |
| .gr-column { |
| padding: 20px !important; |
| } |
| |
| /* Row styling */ |
| .gr-row { |
| gap: 24px !important; |
| margin: 12px 0 !important; |
| } |
| |
| /* Markdown content - Larger, more readable */ |
| .gr-markdown { |
| line-height: 1.8 !important; |
| color: #2d1f4a !important; |
| } |
| |
| .gr-markdown p { |
| font-size: 17px !important; |
| margin: 12px 0 !important; |
| font-weight: 500 !important; |
| } |
| |
| .gr-markdown strong { |
| font-weight: 800 !important; |
| color: #6c5ce7 !important; |
| } |
| |
| /* Status/info messages - Colorful notifications */ |
| .gr-info { |
| border-radius: 16px !important; |
| border-left: 5px solid #6c5ce7 !important; |
| background: linear-gradient(135deg, #f8f6ff 0%, #f0f4ff 100%) !important; |
| padding: 18px 24px !important; |
| font-size: 16px !important; |
| font-weight: 600 !important; |
| color: #2d1f4a !important; |
| box-shadow: 0 4px 16px rgba(108, 92, 231, 0.1) !important; |
| } |
| |
| /* Error messages */ |
| .gr-error { |
| border-radius: 16px !important; |
| border-left: 5px solid #ff6b6b !important; |
| background: linear-gradient(135deg, #fff5f5 0%, #ffe8e8 100%) !important; |
| padding: 18px 24px !important; |
| font-size: 16px !important; |
| font-weight: 600 !important; |
| color: #c44569 !important; |
| } |
| |
| /* Loading spinner */ |
| .loading { |
| border: 4px solid #f3f0ff !important; |
| border-top: 4px solid #6c5ce7 !important; |
| } |
| |
| /* Scrollbar styling */ |
| ::-webkit-scrollbar { |
| width: 12px !important; |
| height: 12px !important; |
| } |
| |
| ::-webkit-scrollbar-track { |
| background: #f8f6ff !important; |
| border-radius: 10px !important; |
| } |
| |
| ::-webkit-scrollbar-thumb { |
| background: linear-gradient(135deg, #6c5ce7 0%, #a29bfe 100%) !important; |
| border-radius: 10px !important; |
| } |
| |
| ::-webkit-scrollbar-thumb:hover { |
| background: linear-gradient(135deg, #5e4ec7 0%, #9089e8 100%) !important; |
| } |
| |
| .results-shell { |
| margin-top: 16px !important; |
| background: transparent !important; |
| border: none !important; |
| border-radius: 0 !important; |
| padding: 0 !important; |
| box-shadow: none !important; |
| } |
| |
| .results-shell, |
| .results-shell > div, |
| .results-shell .gr-group, |
| .results-shell .gr-box, |
| .results-shell .gr-panel, |
| .results-shell .block { |
| background: transparent !important; |
| border: none !important; |
| box-shadow: none !important; |
| } |
| |
| .interaction-stack { |
| gap: 20px !important; |
| padding: 0 8px 6px !important; |
| } |
| |
| .interaction-stack h3 { |
| margin-left: 24px !important; |
| margin-bottom: 12px !important; |
| } |
| |
| .public-controls { |
| align-items: stretch !important; |
| gap: 20px !important; |
| margin-top: 8px !important; |
| } |
| |
| .control-card { |
| background: linear-gradient(180deg, rgba(255, 255, 255, 0.82) 0%, rgba(250, 246, 255, 0.96) 100%) !important; |
| border: 2px solid rgba(224, 208, 240, 0.78) !important; |
| border-radius: 26px !important; |
| padding: 18px 20px 14px !important; |
| box-shadow: 0 14px 30px rgba(108, 92, 231, 0.07) !important; |
| } |
| |
| .control-card-primary { |
| background: linear-gradient(180deg, rgba(255, 255, 255, 0.86) 0%, rgba(244, 248, 255, 0.96) 100%) !important; |
| } |
| |
| .control-card-secondary { |
| background: linear-gradient(180deg, rgba(255, 255, 255, 0.86) 0%, rgba(250, 244, 255, 0.96) 100%) !important; |
| } |
| |
| .control-card .gradio-container, |
| .control-card .gr-group { |
| background: transparent !important; |
| } |
| |
| .control-card > div, |
| .control-card .block, |
| .control-card .wrap, |
| .control-card .gr-form, |
| .control-card .form { |
| background: transparent !important; |
| border: none !important; |
| box-shadow: none !important; |
| } |
| |
| .control-card .gr-box, |
| .control-card .gr-panel { |
| background: transparent !important; |
| box-shadow: none !important; |
| } |
| |
| .bubble-select { |
| border: 3px solid #8f5cff !important; |
| border-radius: 18px !important; |
| box-shadow: 0 8px 20px rgba(143, 92, 255, 0.10) !important; |
| transition: box-shadow 0.2s ease, border-color 0.2s ease !important; |
| } |
| |
| .bubble-select:focus-within { |
| border-color: #7a3dff !important; |
| box-shadow: 0 0 0 4px rgba(143, 92, 255, 0.14), 0 10px 24px rgba(143, 92, 255, 0.16) !important; |
| } |
| |
| .example-id-slider { |
| margin-top: 8px !important; |
| padding: 10px 2px 2px !important; |
| } |
| |
| .example-id-slider input[type="range"] { |
| accent-color: #4f7cff !important; |
| } |
| |
| .example-id-slider .number-input, |
| .example-id-slider input[type="number"] { |
| border-radius: 16px !important; |
| border: 2px solid #d8dcee !important; |
| background: linear-gradient(180deg, #ffffff 0%, #f7f9ff 100%) !important; |
| font-weight: 700 !important; |
| min-width: 72px !important; |
| } |
| |
| .example-id-slider .wrap { |
| gap: 14px !important; |
| } |
| |
| @media (prefers-color-scheme: dark) { |
| .gradio-container { |
| background: radial-gradient(circle at top, #1e2a44 0%, #0d1422 52%, #090f19 100%) !important; |
| } |
| |
| .gradio-container h3, |
| label, .gr-label, |
| .gr-markdown, |
| .gr-markdown p { |
| color: #e8eefc !important; |
| } |
| |
| .gr-markdown strong { |
| color: #cbd7ff !important; |
| } |
| |
| .tab-nav button, |
| .gr-box, .gr-input, .gr-dropdown, .gr-textbox, |
| .gr-panel, .gr-accordion, |
| .gr-plot, .gr-json { |
| background: rgba(16, 24, 39, 0.88) !important; |
| border-color: rgba(148, 163, 184, 0.24) !important; |
| color: #e8eefc !important; |
| } |
| |
| .tab-nav button { |
| color: #d7e1ff !important; |
| } |
| |
| .tab-nav button:hover { |
| background: rgba(37, 52, 79, 0.96) !important; |
| border-color: rgba(199, 210, 254, 0.36) !important; |
| } |
| |
| .gr-radio label { |
| background: rgba(16, 24, 39, 0.9) !important; |
| border-color: rgba(148, 163, 184, 0.26) !important; |
| color: #e8eefc !important; |
| } |
| |
| .gr-radio label:hover { |
| background: rgba(37, 52, 79, 0.96) !important; |
| } |
| |
| .gr-textbox textarea, |
| .gr-input input { |
| background: transparent !important; |
| color: #e8eefc !important; |
| } |
| |
| .gr-dropdown-menu { |
| background: #101827 !important; |
| border-color: rgba(148, 163, 184, 0.24) !important; |
| } |
| |
| .gr-dropdown-menu .item { |
| color: #e8eefc !important; |
| } |
| |
| .gr-dropdown-menu .item:hover { |
| background: rgba(37, 52, 79, 0.96) !important; |
| } |
| |
| .gr-plot .main-svg, |
| .gr-plot .svg-container, |
| .gr-plot .plot-container, |
| .gr-plot .user-select-none { |
| background: transparent !important; |
| } |
| |
| .gr-plot .xtick text, |
| .gr-plot .ytick text, |
| .gr-plot .gtitle text, |
| .gr-plot .xtitle text, |
| .gr-plot .ytitle text, |
| .gr-plot .annotation-text, |
| .gr-plot .legend text { |
| fill: #e8eefc !important; |
| color: #e8eefc !important; |
| } |
| |
| .gr-plot .gridlayer path, |
| .gr-plot .zerolinelayer path, |
| .gr-plot .xlines-above path, |
| .gr-plot .ylines-above path { |
| stroke: rgba(148, 163, 184, 0.22) !important; |
| } |
| |
| .gr-info { |
| background: linear-gradient(135deg, rgba(30, 41, 59, 0.95) 0%, rgba(17, 24, 39, 0.95) 100%) !important; |
| color: #dbe7ff !important; |
| border-left-color: #9db4ff !important; |
| } |
| |
| .control-card { |
| background: linear-gradient(180deg, rgba(16, 24, 39, 0.9) 0%, rgba(18, 28, 45, 0.96) 100%) !important; |
| border-color: rgba(148, 163, 184, 0.2) !important; |
| box-shadow: 0 18px 36px rgba(0, 0, 0, 0.24) !important; |
| } |
| |
| .results-shell, |
| .results-shell > div, |
| .results-shell .gr-group, |
| .results-shell .gr-box, |
| .results-shell .gr-panel, |
| .results-shell .block { |
| background: transparent !important; |
| border: none !important; |
| box-shadow: none !important; |
| } |
| |
| .bubble-select { |
| border-color: #a06cff !important; |
| box-shadow: 0 10px 24px rgba(143, 92, 255, 0.18) !important; |
| } |
| |
| .bubble-select:focus-within { |
| border-color: #c29cff !important; |
| box-shadow: 0 0 0 4px rgba(143, 92, 255, 0.16), 0 12px 26px rgba(143, 92, 255, 0.22) !important; |
| } |
| |
| .example-id-slider .number-input, |
| .example-id-slider input[type="number"] { |
| background: linear-gradient(180deg, #162031 0%, #111827 100%) !important; |
| border-color: rgba(148, 163, 184, 0.24) !important; |
| color: #e8eefc !important; |
| } |
| |
| .gr-error { |
| background: linear-gradient(135deg, rgba(68, 18, 32, 0.95) 0%, rgba(39, 12, 20, 0.95) 100%) !important; |
| color: #ffd5dc !important; |
| } |
| |
| ::-webkit-scrollbar-track { |
| background: #111827 !important; |
| } |
| } |
| """ |
|
|
| _app_kwargs = {"title": "LLM Reasoning Explorer Studio"} |
| if _supports_kwarg(gr.Blocks, "css"): |
| _app_kwargs["css"] = custom_css |
| with gr.Blocks(**_app_kwargs) as app: |
| gr.Markdown( |
| "# LLM Reasoning Explorer Studio\n\n" |
| "**Explore attribution results and feature interactions** with our interactive visualization tools. " |
| "Browse pre-computed examples or analyze your own text in real-time with powerful AI insights." |
| ) |
| gr.Markdown(f"**Build:** {BUILD_ID} ({BUILD_TS})") |
|
|
| example_state = gr.State([]) |
|
|
| with (gr.Column(visible=not mm_only) if (public_only or mm_only) else gr.Tab("Public Mode")): |
| with gr.Accordion("How to Use", open=False): |
| gr.Markdown( |
| "1. **Select a dataset** from 10 available datasets (100 total examples, 10 per dataset)\n" |
| "2. **Choose a model** to compare: Qwen3-4B, Qwen3-30B, or Mistral-7B\n" |
| "3. **Pick a scoring method:** Perplexity or Semantic Similarity\n" |
| "4. **Set the feature level:** Word, Sentence, or Paragraph\n" |
| "5. **Choose an attribution method:** Shapley, Banzhaf, or Influence\n" |
| "6. **View results** in the Text Interaction View (inline highlights) and Bar View (ranked interactions)" |
| ) |
| with gr.Row(elem_classes=["public-controls"]): |
| with gr.Column(scale=1, elem_classes=["control-card", "control-card-primary"]): |
| |
| initial_choices = _dataset_choice_labels(datasets) |
| |
| |
| _preferred_default = "BBQ Disambiguation" |
| if mm_only: |
| initial_value = None |
| elif _preferred_default in initial_choices: |
| initial_value = _preferred_default |
| else: |
| initial_value = initial_choices[0] if initial_choices else None |
|
|
| dataset_selector = gr.Dropdown( |
| choices=initial_choices, |
| value=initial_value, |
| label="Dataset", |
| interactive=True, |
| allow_custom_value=False, |
| elem_id="dataset-selector", |
| elem_classes=["bubble-select"], |
| ) |
|
|
| example_selector = gr.Slider( |
| label="Example ID", |
| minimum=1, |
| maximum=10, |
| step=1, |
| value=1, |
| interactive=True, |
| elem_classes=["example-id-slider"], |
| ) |
| with gr.Column(scale=1, elem_classes=["control-card", "control-card-secondary"]): |
| model_selector = create_model_selector() |
| scalarizer_selector = gr.Dropdown( |
| choices=PUBLIC_SCALARIZER_CHOICES, |
| value="geomean_jointprob", |
| label="Scalarizer", |
| interactive=True, |
| elem_classes=["bubble-select"], |
| ) |
| public_feature_level_selector = create_feature_level_selector(value="word") |
| method_toggle = create_attribution_method_toggle() |
|
|
| with gr.Accordion("Example Preview", open=True): |
| with gr.Row(): |
| with gr.Column(scale=3): |
| context_box = gr.Textbox( |
| label="Context", |
| lines=8, |
| interactive=False, |
| ) |
| with gr.Column(scale=2): |
| prompt_box = gr.Textbox( |
| label="Prompt", |
| lines=4, |
| interactive=False, |
| ) |
| answer_box = gr.Textbox( |
| label="Ground Truth Answer", |
| lines=3, |
| interactive=False, |
| ) |
| |
| |
| try: |
| model_answer_html = gr.HTML(value="", sanitize_html=False) |
| except TypeError: |
| model_answer_html = gr.HTML(value="") |
|
|
| public_results = create_results_display() |
|
|
def _public_mode_compute(
    dataset,
    example_number,
    records,
    model_size,
    scalarizer,
    feature_level,
    method,
    progress=gr.Progress(track_tqdm=True),
):
    """Produce the Public Mode result values for the current selections.

    Precomputed results are tried in up to three stages, stopping at the
    first non-empty payload stored under the normalized ``method``:

    1. the requested model size, scalarizer and feature level;
    2. the size returned by ``_find_available_model_size`` (same scalarizer
       and level), when it is truthy and differs from the requested size;
    3. the combination returned by ``_find_any_available_result``.

    A payload carrying ``"features"`` or ``"heatmap"`` is rendered through
    ``on_select_example``. Otherwise the function either raises (public-only
    mode, or no example loader available) or falls back to
    ``_compute_live_attributions`` on the example's record.

    Args:
        dataset: Value of the dataset dropdown.
        example_number: Value of the example slider.
        records: Stored example records, forwarded to ``_resolve_example_id``.
        model_size: Selected model size (normalized before use).
        scalarizer: Selected scalarizer.
        feature_level: Selected feature level (normalized before use).
        method: Selected attribution method (normalized before use).
        progress: Gradio progress tracker, forwarded to the live computation.

    Returns:
        A tuple of 14 ``None`` values when ``mm_only`` is set; otherwise the
        values returned by ``on_select_example`` after its first three, or
        the return value of ``_compute_live_attributions``.

    Raises:
        gr.Error: If no dataset or example is selected, or if nothing
            precomputed was found and live computation is not available.
    """
    if mm_only:
        # NOTE(review): 14 presumably equals the number of result outputs
        # wired to this callback -- confirm against the outputs list.
        return (None,) * 14
    if not dataset:
        raise gr.Error("Please select a dataset.")
    if not example_number:
        raise gr.Error("Please select an example.")

    dataset_key = _resolve_dataset_key(dataset)
    ex_id = _resolve_example_id(example_number, records)
    method = _normalize_method(method)
    level = _normalize_level(feature_level)
    model_size = _normalize_model_size(model_size)

    # Use the backend lookup when one is available, else the file-based one.
    fetch_result = (
        _public_get_result_from_file if get_result_by_id is None else get_result_by_id
    )
    lookup_options = {"scalarizer": scalarizer, "feature_level": level}

    # Stage 1: exactly what the user asked for.
    found = fetch_result(model_size, dataset_key, ex_id, **lookup_options) or {}
    payload = found.get(method, {})

    # Stage 2: same scalarizer/level, but a different model size.
    if not payload:
        candidate_size = _find_available_model_size(dataset_key, ex_id, scalarizer, level)
        if candidate_size and candidate_size != model_size:
            found = fetch_result(candidate_size, dataset_key, ex_id, **lookup_options) or {}
            payload = found.get(method, {})
            if payload:
                model_size = candidate_size

    # Stage 3: any stored combination for this example.
    if not payload:
        any_size, any_scalarizer, any_level, found = _find_any_available_result(
            dataset_key, ex_id, fetch_result, method
        )
        if any_size and any_scalarizer and any_level and found:
            payload = found.get(method, {})
            model_size, scalarizer, level = any_size, any_scalarizer, any_level

    has_display_data = payload and (payload.get("features") or payload.get("heatmap"))
    if has_display_data:
        # The first three returned values are not part of the results display.
        _, _, _, *display_values = on_select_example(
            dataset_key,
            ex_id,
            model_size,
            2,
            method,
            scalarizer=scalarizer,
            feature_level=level,
        )
        return display_values

    # Nothing precomputed: live computation needs an example loader and is
    # not attempted in public-only mode.
    if _public_only_mode() or get_example_by_id is None:
        expected_ref = _reference_results_file(model_size, dataset_key, ex_id, scalarizer, level)
        raise gr.Error(
            "No precomputed results found.\n\n"
            f"Expected (reference_answer):\n{expected_ref}\n\n"
            "On Hugging Face Space: make sure the 'results' folder is in your repo "
            "(commit & push it). If you use Git LFS, enable 'LFS' in Space Settings → "
            "Repository and ensure files are pulled. You can also try another "
            "scalarizer (e.g. Perplexity) or feature level (e.g. word)."
        )

    load_example = _ensure_backend("loader.data.get_example_by_id", get_example_by_id)
    record = load_example(dataset_key, ex_id)
    context = record.get("context", "")
    prompt = record.get("prompt", "")
    answer = _extract_answer(record)

    return _compute_live_attributions(
        context=context,
        prompt=prompt,
        correct_answer=answer,
        model_size=model_size,
        scalarizer=scalarizer,
        embedding_model=None,
        level=level,
        method=method,
        order=2,
        progress=progress,
    )
|
|
| public_preview_outputs = [context_box, prompt_box, answer_box] |
| public_compute_inputs = [ |
| dataset_selector, |
| example_selector, |
| example_state, |
| model_selector, |
| scalarizer_selector, |
| public_feature_level_selector, |
| method_toggle, |
| ] |
| public_compute_outputs = _results_output_list(public_results) |
|
|
| dataset_change_event = dataset_selector.change( |
| fn=_load_examples_for_slider, |
| inputs=[dataset_selector], |
| outputs=[ |
| example_selector, |
| example_state, |
| context_box, |
| prompt_box, |
| answer_box, |
| ], |
| queue=False, |
| ).then( |
| fn=_build_model_answer_panel, |
| inputs=[dataset_selector, example_selector], |
| outputs=[model_answer_html], |
| queue=False, |
| ) |
| load_event = app.load( |
| fn=_load_examples_for_slider, |
| inputs=[dataset_selector], |
| outputs=[ |
| example_selector, |
| example_state, |
| context_box, |
| prompt_box, |
| answer_box, |
| ], |
| ).then( |
| fn=_build_model_answer_panel, |
| inputs=[dataset_selector, example_selector], |
| outputs=[model_answer_html], |
| queue=False, |
| ) |
|
|
| dataset_change_event.then( |
| fn=_public_mode_compute, |
| inputs=public_compute_inputs, |
| outputs=public_compute_outputs, |
| show_progress="full", |
| ) |
| load_event.then( |
| fn=_public_mode_compute, |
| inputs=public_compute_inputs, |
| outputs=public_compute_outputs, |
| show_progress="full", |
| ) |
|
|
| example_selector.release( |
| fn=_update_example_preview, |
| inputs=[example_selector, example_state], |
| outputs=public_preview_outputs, |
| queue=False, |
| ).then( |
| fn=_build_model_answer_panel, |
| inputs=[dataset_selector, example_selector], |
| outputs=[model_answer_html], |
| queue=False, |
| ).then( |
| fn=_public_mode_compute, |
| inputs=public_compute_inputs, |
| outputs=public_compute_outputs, |
| show_progress="full", |
| ) |
|
|
| for component in ( |
| model_selector, |
| scalarizer_selector, |
| public_feature_level_selector, |
| method_toggle, |
| ): |
| component.change( |
| fn=_public_mode_compute, |
| inputs=public_compute_inputs, |
| outputs=public_compute_outputs, |
| show_progress="full", |
| ) |
|
|
| |
| with gr.Tab("Multimodal"): |
|
|
| with gr.Accordion("How to Use", open=False): |
| gr.Markdown( |
| "1. **Choose a dataset** from the three sub-tabs:\n" |
| " - **MIMIC-CXR (10 Samples)** — chest X-rays across 10 pathology categories\n" |
| " - **Dermoscopy ISIC (10 Samples)** — skin-lesion dermoscopy across 8 diagnostic classes\n" |
| " - **MS-COCO (5 Samples)** — natural-image cross-modal benchmark\n" |
| "2. **Pick an example** from the dropdown (each is an image + caption pair)\n" |
| "3. **Choose an attribution method:** Influence (default, non-negative — clearer for clinicians) or Shapley (signed)\n" |
| "4. **Read the four panels side-by-side:**\n" |
| " - **Interactive Cross-Modal View** — click any image patch or caption word to see its strongest cross-modal partners\n" |
| " - **BiomedCLIP Cross-Modal Attribution** — patch-level overlay + bar charts (cosine-similarity scoring)\n" |
| " - **LLaVA-Med Attribution** — log-prob and generation Shapley charts from the medical 7B VLM\n" |
| " - **Compare Two Methods Side-by-Side** — pick any two of the above to overlay their rankings\n" |
| "5. **Hover token chips and patches** for exact attribution values; hover SVG arcs for pairwise interaction strength" |
| ) |
|
|
| with gr.Tab("MIMIC-CXR (10 Samples)"): |
| gr.Markdown( |
| "**10-sample MIMIC-CXR chest X-ray attribution benchmark** " |
| "(10 pathology categories). \n" |
| "Source: [MIMIC-CXR-JPG](https://huggingface.co/datasets/itsanmolgupta/mimic-cxr-dataset-cleaned) " |
| "— de-identified chest radiographs from Beth Israel Deaconess Medical Center. \n" |
| "Each example has a radiology report (impression = caption, findings = detail)." |
| ) |
|
|
| |
| |
| _mimic_choices = ( |
| [(v["category"], k) for k, v in MIMIC_EXAMPLES.items()] |
| if _MIMIC_AVAILABLE else [] |
| ) |
|
|
| mimic_selector = gr.Dropdown( |
| choices=_mimic_choices, |
| value=None, |
| label="Filter by Pathology", |
| interactive=True, |
| ) |
|
|
| mimic_method_toggle = gr.Radio( |
| choices=["Influence", "Shapley"], |
| value="Influence", |
| label="Attribution method", |
| info=( |
| "Influence (default) is always positive — clearer for clinicians. " |
| "Shapley is signed (green = supports caption, red = detracts)." |
| ), |
| interactive=True, |
| ) |
|
|
| mimic_caption = gr.Textbox( |
| label="Radiology Impression (Caption)", |
| interactive=False, |
| lines=2, |
| ) |
| with gr.Accordion("Full Radiology Findings", open=False): |
| mimic_findings = gr.Textbox( |
| label="Detailed Findings", |
| interactive=False, |
| lines=5, |
| ) |
|
|
| |
| mimic_original = gr.Image(label="Chest X-ray", type="filepath") |
|
|
| mimic_interpretation = gr.Markdown( |
| value="*Select an example above to see the attribution analysis.*", |
| label="Interpretation", |
| ) |
|
|
| |
| _mimic_pill = ( |
| 'style="display:inline-block;padding:6px 14px;background:#e3f2fd;' |
| 'border-radius:16px;text-decoration:none;color:#1565c0;font-size:0.9em;' |
| 'border:1px solid #bbdefb;"' |
| ) |
| _mimic_toc_html = ( |
| '<div style="background:#f8f9fa;border:1px solid #dee2e6;border-radius:8px;' |
| 'padding:16px;margin:12px 0;">' |
| '<strong style="font-size:1.05em;">Jump to Section:</strong>' |
| '<div style="display:flex;flex-wrap:wrap;gap:8px;margin-top:10px;">' |
| f'<a href="#mimic-method-biomedclip" {_mimic_pill}>BiomedCLIP Cross-Modal</a>' |
| f'<a href="#mimic-method-llavamed" {_mimic_pill}>LLaVA-Med (UnSAM)</a>' |
| f'<a href="#mimic-interactive" {_mimic_pill}>Interactive View</a>' |
| f'<a href="#mimic-compare" {_mimic_pill}>Compare Methods</a>' |
| '</div></div>' |
| ) |
| gr.HTML(value=_mimic_toc_html) |
|
|
| mimic_results_state = gr.State({}) |
|
|
| |
| |
| |
| with gr.Column(elem_id="mimic-interactive"): |
| gr.Markdown("---\n### BiomedCLIP Cross-Modal Interaction View — click segments or words") |
| gr.Markdown( |
| "**How to use:** Click any **image region** to see which caption words " |
| "it connects to, or click a **word** to see which regions activate. \n" |
| "**Green** arrows = positive interaction. **Red** arrows = negative." |
| ) |
| mimic_biomedclip_interaction_html = _html_component( |
| "BiomedCLIP Cross-Modal Interaction View") |
|
|
| |
| |
| |
| with gr.Column(elem_id="mimic-method-biomedclip"): |
| gr.Markdown("---\n### BiomedCLIP Cross-Modal Attribution") |
| gr.Markdown( |
| "**What it does:** Uses [BiomedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) " |
| "— a CLIP model trained on **15 million biomedical figure-caption pairs** — " |
| "to jointly score image regions (via UnSAM segmentation) and caption tokens. \n" |
| "**How to read:** **Green** = positive Shapley value (contributes to alignment). " |
| "**Red** = negative (hurts alignment)." |
| ) |
| mimic_biomedclip_overlay = gr.Image( |
| label="BiomedCLIP Overlay (labeled segments)", type="filepath") |
| mimic_biomedclip_token_plot = gr.Plot( |
| label="BiomedCLIP — Caption Word Shapley Values") |
| mimic_biomedclip_region_plot = gr.Plot( |
| label="BiomedCLIP — Image Region Shapley Values") |
|
|
| |
| |
| |
| with gr.Column(elem_id="mimic-method-llavamed"): |
| gr.Markdown("---\n### LLaVA-Med Attribution (4×4 Patch Grid, P1–P16)") |
| gr.Markdown( |
| "**What it does:** Uses [LLaVA-Med](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) " |
| "— a **7B parameter** medical VLM — evaluated over a **uniform 4×4 patch grid** " |
| "(16 cells labeled **P1–P16**, row-major). \n" |
| "**Two scoring approaches:** \n" |
| "- **Log-Prob:** How removing a region affects confidence in the correct caption \n" |
| "- **Generation:** How removing a region changes what the model describes" |
| ) |
| gr.Markdown( |
| "Each method colors segments by its own Shapley values — " |
| "**green** = positive, **red** = negative. Signs often differ " |
| "between Log-Prob and Generation, so each has its own overlay." |
| ) |
| with gr.Row(equal_height=True): |
| mimic_llavamed_unsam_lp_overlay = gr.Image( |
| label="LLaVA-Med Log-Prob — Overlay", |
| type="filepath", height=600) |
| mimic_llavamed_unsam_gen_overlay = gr.Image( |
| label="LLaVA-Med Generation — Overlay", |
| type="filepath", height=600) |
| with gr.Row(): |
| mimic_llavamed_unsam_lp_plot = gr.Plot( |
| label="LLaVA-Med Log-Prob — Segment Shapley Values") |
| mimic_llavamed_unsam_gen_plot = gr.Plot( |
| label="LLaVA-Med Generation — Segment Shapley Values") |
|
|
| |
| |
| |
| with gr.Column(elem_id="mimic-compare"): |
| gr.Markdown("---\n### Compare Two Methods Side-by-Side") |
| gr.Markdown( |
| "Select two methods to compare their attribution overlays " |
| "and Shapley value distributions on the same image." |
| ) |
| with gr.Row(): |
| mimic_compare_method_a = gr.Dropdown( |
| choices=_MIMIC_METHOD_NAMES, |
| label="Method A", |
| interactive=True, |
| ) |
| mimic_compare_method_b = gr.Dropdown( |
| choices=_MIMIC_METHOD_NAMES, |
| label="Method B", |
| interactive=True, |
| ) |
| with gr.Row(): |
| mimic_compare_img_a = gr.Image(label="Method A — Overlay", type="filepath") |
| mimic_compare_img_b = gr.Image(label="Method B — Overlay", type="filepath") |
| with gr.Row(): |
| mimic_compare_plot_a = gr.Plot(label="Method A — Shapley Values") |
| mimic_compare_plot_b = gr.Plot(label="Method B — Shapley Values") |
|
|
| mimic_meta = gr.JSON(label="Example Info", visible=False) |
|
|
| _mimic_outputs = [ |
| mimic_caption, mimic_original, mimic_findings, mimic_interpretation, |
| mimic_biomedclip_overlay, mimic_biomedclip_token_plot, |
| mimic_biomedclip_region_plot, mimic_llavamed_unsam_lp_overlay, |
| mimic_llavamed_unsam_lp_plot, mimic_llavamed_unsam_gen_overlay, |
| mimic_llavamed_unsam_gen_plot, mimic_biomedclip_interaction_html, |
| mimic_meta, mimic_results_state, mimic_compare_method_a, |
| ] |
| mimic_selector.change( |
| fn=_on_select_mimic_example, |
| inputs=[mimic_selector, mimic_method_toggle], |
| outputs=_mimic_outputs, |
| ) |
| mimic_method_toggle.change( |
| fn=_on_select_mimic_example, |
| inputs=[mimic_selector, mimic_method_toggle], |
| outputs=_mimic_outputs, |
| ) |
|
|
| |
| for _mimic_cmp_dd in [mimic_compare_method_a, mimic_compare_method_b]: |
| _mimic_cmp_dd.change( |
| fn=_on_mimic_compare_methods, |
| inputs=[mimic_compare_method_a, mimic_compare_method_b, |
| mimic_results_state], |
| outputs=[mimic_compare_img_a, mimic_compare_img_b, |
| mimic_compare_plot_a, mimic_compare_plot_b], |
| ) |
|
|
| |
| with gr.Tab("Dermoscopy ISIC (10 Samples)"): |
| gr.Markdown( |
| "**10-sample ISIC-2019 dermoscopy attribution benchmark** " |
| "(8 diagnostic classes: MEL × 2, NV × 2, BCC, AK, BKL, DF, VASC, SCC). \n" |
| "Source: [ISIC_2019_224](https://huggingface.co/datasets/MKZuziak/ISIC_2019_224) " |
| "— dermoscopic skin-lesion images from the International Skin Imaging Collaboration. \n" |
| "Captions are synthesized from class labels (clinical descriptions of each diagnosis)." |
| ) |
|
|
| _isic_choices = ( |
| [(v["category"], k) for k, v in ISIC_EXAMPLES.items()] |
| if _ISIC_AVAILABLE else [] |
| ) |
|
|
| isic_selector = gr.Dropdown( |
| choices=_isic_choices, |
| value=None, |
| label="Filter by Diagnosis", |
| interactive=True, |
| ) |
|
|
| isic_method_toggle = gr.Radio( |
| choices=["Influence", "Shapley"], |
| value="Influence", |
| label="Attribution method", |
| info=( |
| "Influence (default) is always positive — clearer for clinicians. " |
| "Shapley is signed (green = supports caption, red = detracts)." |
| ), |
| interactive=True, |
| ) |
|
|
| isic_caption = gr.Textbox( |
| label="Diagnostic Caption", |
| interactive=False, |
| lines=3, |
| ) |
|
|
| isic_original = gr.Image(label="Dermoscopic Image", type="filepath") |
|
|
| isic_interpretation = gr.Markdown( |
| value="*Select an example above to see the attribution analysis.*", |
| label="Interpretation", |
| ) |
|
|
| _isic_pill = ( |
| 'style="display:inline-block;padding:6px 14px;background:#e3f2fd;' |
| 'border-radius:16px;text-decoration:none;color:#1565c0;font-size:0.9em;' |
| 'border:1px solid #bbdefb;"' |
| ) |
| _isic_toc_html = ( |
| '<div style="background:#f8f9fa;border:1px solid #dee2e6;border-radius:8px;' |
| 'padding:16px;margin:12px 0;">' |
| '<strong style="font-size:1.05em;">Jump to Section:</strong>' |
| '<div style="display:flex;flex-wrap:wrap;gap:8px;margin-top:10px;">' |
| f'<a href="#isic-method-biomedclip" {_isic_pill}>BiomedCLIP Cross-Modal</a>' |
| f'<a href="#isic-method-llavamed" {_isic_pill}>LLaVA-Med (UnSAM)</a>' |
| f'<a href="#isic-interactive" {_isic_pill}>Interactive View</a>' |
| f'<a href="#isic-compare" {_isic_pill}>Compare Methods</a>' |
| '</div></div>' |
| ) |
| gr.HTML(value=_isic_toc_html) |
|
|
| isic_results_state = gr.State({}) |
|
|
| with gr.Column(elem_id="isic-interactive"): |
| gr.Markdown("---\n### BiomedCLIP Cross-Modal Interaction View — click segments or words") |
| gr.Markdown( |
| "**How to use:** Click any **image region** to see which caption words " |
| "it connects to, or click a **word** to see which regions activate." |
| ) |
| isic_biomedclip_interaction_html = _html_component( |
| "BiomedCLIP Cross-Modal Interaction View") |
|
|
| with gr.Column(elem_id="isic-method-biomedclip"): |
| gr.Markdown("---\n### BiomedCLIP Cross-Modal Attribution") |
| gr.Markdown( |
| "**What it does:** Uses [BiomedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) " |
| "to jointly score dermoscopic image regions (via UnSAM segmentation) " |
| "and caption tokens. \n" |
| "**How to read:** **Influence** bars (default) show positive importance. " |
| "Switch to **Shapley** above for signed values (green/red)." |
| ) |
| isic_biomedclip_overlay = gr.Image( |
| label="BiomedCLIP Overlay (labeled segments)", type="filepath") |
| isic_biomedclip_token_plot = gr.Plot( |
| label="BiomedCLIP — Caption Word Values") |
| isic_biomedclip_region_plot = gr.Plot( |
| label="BiomedCLIP — Image Region Values") |
|
|
| with gr.Column(elem_id="isic-method-llavamed"): |
| gr.Markdown("---\n### LLaVA-Med Attribution (4×4 Patch Grid, P1–P16)") |
| gr.Markdown( |
| "**What it does:** Uses [LLaVA-Med](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) " |
| "— a **7B parameter** medical VLM — evaluated over a **uniform 4×4 patch grid** " |
| "(16 cells labeled **P1–P16**, row-major). \n" |
| "**Two scoring approaches:** \n" |
| "- **Log-Prob:** How removing a region affects confidence in the caption \n" |
| "- **Generation:** How removing a region changes what the model describes" |
| ) |
| with gr.Row(equal_height=True): |
| isic_llavamed_unsam_lp_overlay = gr.Image( |
| label="LLaVA-Med Log-Prob — Overlay", |
| type="filepath", height=600) |
| isic_llavamed_unsam_gen_overlay = gr.Image( |
| label="LLaVA-Med Generation — Overlay", |
| type="filepath", height=600) |
| with gr.Row(): |
| isic_llavamed_unsam_lp_plot = gr.Plot( |
| label="LLaVA-Med Log-Prob — Segment Values") |
| isic_llavamed_unsam_gen_plot = gr.Plot( |
| label="LLaVA-Med Generation — Segment Values") |
|
|
| with gr.Column(elem_id="isic-compare"): |
| gr.Markdown("---\n### Compare Two Methods Side-by-Side") |
| gr.Markdown( |
| "Select two methods to compare their attribution overlays " |
| "and value distributions on the same image." |
| ) |
| with gr.Row(): |
| isic_compare_method_a = gr.Dropdown( |
| choices=_ISIC_METHOD_NAMES, |
| label="Method A", |
| interactive=True, |
| ) |
| isic_compare_method_b = gr.Dropdown( |
| choices=_ISIC_METHOD_NAMES, |
| label="Method B", |
| interactive=True, |
| ) |
| with gr.Row(): |
| isic_compare_img_a = gr.Image(label="Method A — Overlay", type="filepath") |
| isic_compare_img_b = gr.Image(label="Method B — Overlay", type="filepath") |
| with gr.Row(): |
| isic_compare_plot_a = gr.Plot(label="Method A — Values") |
| isic_compare_plot_b = gr.Plot(label="Method B — Values") |
|
|
| isic_meta = gr.JSON(label="Example Info", visible=False) |
|
|
| _isic_outputs = [ |
| isic_caption, isic_original, isic_interpretation, |
| isic_biomedclip_overlay, isic_biomedclip_token_plot, |
| isic_biomedclip_region_plot, isic_llavamed_unsam_lp_overlay, |
| isic_llavamed_unsam_lp_plot, isic_llavamed_unsam_gen_overlay, |
| isic_llavamed_unsam_gen_plot, isic_biomedclip_interaction_html, |
| isic_meta, isic_results_state, isic_compare_method_a, |
| ] |
| isic_selector.change( |
| fn=_on_select_isic_example, |
| inputs=[isic_selector, isic_method_toggle], |
| outputs=_isic_outputs, |
| ) |
| isic_method_toggle.change( |
| fn=_on_select_isic_example, |
| inputs=[isic_selector, isic_method_toggle], |
| outputs=_isic_outputs, |
| ) |
|
|
| for _isic_cmp_dd in [isic_compare_method_a, isic_compare_method_b]: |
| _isic_cmp_dd.change( |
| fn=_on_isic_compare_methods, |
| inputs=[isic_compare_method_a, isic_compare_method_b, |
| isic_results_state], |
| outputs=[isic_compare_img_a, isic_compare_img_b, |
| isic_compare_plot_a, isic_compare_plot_b], |
| ) |
|
|
| |
| with gr.Tab("MS-COCO (5 Samples)"): |
| gr.Markdown( |
| "**CLIP cross-modal attribution on MS-COCO natural images.** \n" |
| "Click an **image region** or **caption word** below to explore " |
| "which parts of the image and text are most strongly linked via " |
| "CLIP's visual-language similarity score." |
| ) |
|
|
| _coco_choices = ( |
| [(v["title"], k) for k, v in COCO_EXAMPLES.items()] |
| if _COCO_AVAILABLE else [] |
| ) |
| _coco_default = _coco_choices[0][1] if _coco_choices else None |
| coco_selector = gr.Radio( |
| choices=_coco_choices, |
| value=_coco_default, |
| label="Select MS-COCO Example", |
| interactive=True, |
| ) |
| coco_method_toggle = gr.Radio( |
| choices=["Influence", "Shapley"], |
| value="Influence", |
| label="Attribution method", |
| info="Influence (default) is always positive. Shapley is signed.", |
| interactive=True, |
| ) |
| coco_caption = gr.Textbox( |
| label="Caption", interactive=False, lines=2, |
| ) |
|
|
| gr.Markdown("---\n#### Interactive Cross-Modal View") |
| gr.Markdown( |
| "Click a colored **image region** (left) to highlight the caption " |
| "words it interacts with, or click a **word** (right) to highlight " |
| "linked regions. Green = positive, red = negative." |
| ) |
| coco_interaction_html = _html_component( |
| "COCO Cross-Modal Interaction View") |
|
|
| gr.Markdown("---\n#### Attribution Details") |
| with gr.Row(): |
| coco_original = gr.Image( |
| label="Original Image", type="filepath") |
| coco_overlay = gr.Image( |
| label="CLIP Overlay (labeled segments)", type="filepath") |
|
|
| with gr.Row(): |
| coco_token_plot = gr.Plot( |
| label="Caption Word Shapley Values") |
| coco_region_plot = gr.Plot( |
| label="Image Region Shapley Values") |
|
|
| with gr.Row(): |
| coco_cross_plot = gr.Plot( |
| label="Top Image x Word Interactions") |
| coco_cross_table = gr.Dataframe( |
| headers=["Image Region", "Caption Word", "Score"], |
| label="Cross-Modal Interaction Table", |
| interactive=False, |
| ) |
|
|
| with gr.Accordion("Influence Heatmap (Regions x Words)", |
| open=False): |
| coco_heatmap = gr.Plot( |
| label="Full Heatmap: Regions x Caption Words") |
|
|
| gr.Markdown("---\n#### Masked Image Browser") |
| gr.Markdown( |
| "Browse ablation images: **solo** shows only the selected region " |
| "(everything else inpainted away); **removed** shows the image with " |
| "that region inpainted out." |
| ) |
| with gr.Row(): |
| coco_masked_dd = gr.Dropdown( |
| choices=[], |
| label="Region / View", |
| interactive=True, |
| ) |
| coco_masked_img = gr.Image( |
| label="Masked View", type="filepath") |
|
|
| coco_note = gr.Markdown(value="") |
|
|
| _coco_outputs = [ |
| coco_caption, |
| coco_original, |
| coco_overlay, |
| coco_interaction_html, |
| coco_token_plot, |
| coco_region_plot, |
| coco_cross_plot, |
| coco_cross_table, |
| coco_heatmap, |
| coco_note, |
| coco_masked_img, |
| coco_masked_dd, |
| ] |
| coco_selector.change( |
| fn=_on_select_coco_example, |
| inputs=[coco_selector, coco_method_toggle], |
| outputs=_coco_outputs, |
| ) |
| coco_method_toggle.change( |
| fn=_on_select_coco_example, |
| inputs=[coco_selector, coco_method_toggle], |
| outputs=_coco_outputs, |
| ) |
| coco_masked_dd.change( |
| fn=_on_select_coco_masked, |
| inputs=[coco_selector, coco_masked_dd], |
| outputs=[coco_masked_img], |
| ) |
|
|
| |
| |
|
|
| gr.HTML( |
| '<div style="margin-top:32px;padding:20px;border-top:1px solid #e5e7eb;text-align:center;">' |
| '<p style="font-weight:600;margin-bottom:10px;">Contributors — University of California, Berkeley</p>' |
| '<p style="display:flex;justify-content:center;gap:40px;flex-wrap:wrap;font-size:0.9em;">' |
| '<span><strong>Stephen Tao</strong> · Loader Layer · ' |
| '<a href="mailto:stephen_tao@berkeley.edu" style="color:#6366f1;">stephen_tao@berkeley.edu</a></span>' |
| '<span><strong>Yiting Gao</strong> · Attribution Layer · ' |
| '<a href="mailto:yg2025@berkeley.edu" style="color:#6366f1;">yg2025@berkeley.edu</a></span>' |
| '<span><strong>Qingpeng Kong</strong> · Visualization Layer · ' |
| '<a href="mailto:qpkong@berkeley.edu" style="color:#6366f1;">qpkong@berkeley.edu</a></span>' |
| '</p>' |
| '<p style="display:flex;justify-content:center;gap:40px;flex-wrap:wrap;font-size:0.9em;margin-top:6px;">' |
| '<span><strong>Advisor:</strong> Kannan Ramchandran · ' |
| '<a href="mailto:kannanr@berkeley.edu" style="color:#6366f1;">kannanr@berkeley.edu</a></span>' |
| '<span><strong>Mentor:</strong> Landon Butler · ' |
| '<a href="mailto:landonb@berkeley.edu" style="color:#6366f1;">landonb@berkeley.edu</a></span>' |
| '</p></div>' |
| ) |
|
|
| |
| app._custom_css = custom_css |
| return app |
|
|
|
|
| def _launch_kwargs(app_or_demo, **kwargs): |
| """Build common launch kwargs, injecting CSS for Gradio 6.x.""" |
| lk = dict( |
| server_name=kwargs.pop("server_name", os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")), |
| server_port=int(kwargs.pop("server_port", os.getenv("GRADIO_SERVER_PORT", "7860"))), |
| share=kwargs.pop("share", _env_flag("GRADIO_SHARE", False)), |
| show_error=kwargs.pop("show_error", True), |
| ) |
| css = getattr(app_or_demo, "_custom_css", None) |
| if css and _supports_kwarg(app_or_demo.launch, "css"): |
| lk["css"] = css |
| lk.update(kwargs) |
| return lk |
|
|
|
|
def launch_demo(**kwargs):
    """Build the demo app and launch it.

    The app comes from ``build_demo_app()``. ``kwargs`` are run through
    ``_launch_kwargs`` and the resulting options are passed to the app's
    ``launch`` method.
    """
    demo_app = build_demo_app()
    launch_options = _launch_kwargs(demo_app, **kwargs)
    demo_app.launch(**launch_options)
|
|
|
|
def launch_app(**kwargs):
    """Build the main app and launch it.

    The app comes from ``build_app()``. ``kwargs`` are run through
    ``_launch_kwargs`` and the resulting options are passed to the app's
    ``launch`` method.
    """
    main_app = build_app()
    launch_options = _launch_kwargs(main_app, **kwargs)
    main_app.launch(**launch_options)
|
|
|
|
# Script entry point: running this file directly launches the app via
# launch_app() with no overrides, so the _launch_kwargs defaults apply.
if __name__ == "__main__":
    launch_app()
|
|