AttrLLM / visualization /coco_loader.py
Stephentao-30
COCO Masked Image Browser: order Region/View dropdown by seg index
bb1d352
"""
coco_loader.py — Load precomputed MS-COCO CLIP cross-modal attribution results.
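Results live in a flat directory (by default ``results/coco_mm_dotmask`` if it
is non-empty, else ``results/coco_mm``, where ``results/`` sits one level above
this package; override with the ``ATTRLLM_COCO_RESULTS_DIR`` environment
variable) with naming:
coco_{id}_summary.txt
coco_{id}_original.png
coco_{id}_segmap.png
coco_{id}_overlay.png
coco_{id}_mobius_dict.json   (optional; enables the method toggle)
Unlike the medical benchmark (which uses subdirectories per example), these
are flat files in one directory. The only per-example subdirectory is
coco_{id}_masked_lama/, which holds the PNGs for the masked image browser.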
"""
from __future__ import annotations
import base64
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from .medical_loader import (
parse_summary_txt,
extract_segment_regions,
_image_to_b64,
apply_method_to_clip_summary,
)
# ---------------------------------------------------------------------------
# Results directory resolution
# ---------------------------------------------------------------------------
_COCO_BASE = Path(__file__).resolve().parent.parent / "results"


def _resolve_coco_results_dir() -> Path:
    """Return the directory that holds the precomputed COCO results.

    Resolution order:

    1. ``ATTRLLM_COCO_RESULTS_DIR``, when set to a non-empty string.
    2. ``_COCO_BASE / "coco_mm_dotmask"`` (the dotmask-fixed results), when it
       is a directory containing at least one entry.
    3. ``_COCO_BASE / "coco_mm"`` otherwise. This path is returned even if it
       does not exist.
    """
    override = os.environ.get("ATTRLLM_COCO_RESULTS_DIR")
    if override:
        return Path(override)
    preferred = _COCO_BASE / "coco_mm_dotmask"
    # An empty dotmask directory counts as "not generated yet".
    has_entries = preferred.is_dir() and next(preferred.iterdir(), None) is not None
    return preferred if has_entries else _COCO_BASE / "coco_mm"
# ---------------------------------------------------------------------------
# Example registry
# ---------------------------------------------------------------------------
# Registry of precomputed COCO examples, keyed by example id ("coco_<num>").
# Each value carries a short display "title" and the reference "caption".
# OrderedDict keeps the insertion order below.
COCO_EXAMPLES: Dict[str, Dict[str, Any]] = OrderedDict(
    (example_id, {"title": title, "caption": caption})
    for example_id, title, caption in (
        ("coco_281447", "Horse in Pasture",
         "Horse in fenced pasture with others grazing on grasses."),
        ("coco_402992", "Cattle in Grass",
         "Cattle lie in the grass to chew their cud."),
        ("coco_133233", "Marina with Boats",
         "A marina filled with boats floating in crystal blue water."),
        ("coco_307172", "Baked Dish on Plate",
         "A baked dish on a plate being touched by a woman."),
        ("coco_448256", "Men by Car",
         "Three men stand next to a car with its hood open."),
    )
)
# ---------------------------------------------------------------------------
# Cross-modal pair filtering (same logic as medical _build_all_cross_modal_pairs,
# but sourced from the Mobius sidecar when present, otherwise from
# per_token_cross_modal instead of all_mobius)
# ---------------------------------------------------------------------------
def _build_coco_cross_modal_pairs(
    summary: Dict[str, Any],
    *,
    mobius_sidecar: Optional[Dict[str, Any]] = None,
    method: str = "shapley",
) -> List[Dict[str, Any]]:
    """
    Return the significant segment-token cross-modal pairs for display.

    Sources, in order of preference:

    * ``mobius_sidecar`` (when given and it yields any pairs): pairs are
      derived for ``method`` via the medical loader's
      ``_derive_cross_pairs_from_sidecar`` and passed through its
      ``_filter_and_rank_cross_pairs``.
    * ``summary["per_token_cross_modal"]``: entries are kept only when
      ``|value|`` exceeds 10% of the largest ``|value|``. Of those, the
      top 3 by ``|value|`` per segment (``pair[0]``) are kept, and the
      result is sorted by ``|value|`` descending. An all-zero list yields
      ``[]``.
    * ``summary["cross_modal_interactions"]``: used only when the per-token
      list is missing or empty. It is returned as ``{"pair", "value"}`` dicts
      WITHOUT the threshold / top-3 filtering.
    """
    if mobius_sidecar is not None:
        from .medical_loader import _derive_cross_pairs_from_sidecar, _filter_and_rank_cross_pairs
        derived = _derive_cross_pairs_from_sidecar(mobius_sidecar, method=method)
        if derived:
            return _filter_and_rank_cross_pairs(derived)

    per_token = summary.get("per_token_cross_modal", [])
    if not per_token:
        return [
            {"pair": entry["pair"], "value": entry["value"]}
            for entry in summary.get("cross_modal_interactions", [])
        ]

    def magnitude(entry: Dict[str, Any]) -> float:
        return abs(entry["value"])

    peak = max(map(magnitude, per_token))
    if peak == 0:
        return []
    cutoff = peak * 0.10

    # Group surviving entries by segment label; dicts keep first-seen order,
    # which fixes the tie order of the final stable sort.
    by_segment: Dict[str, List[Dict]] = {}
    for entry in per_token:
        if magnitude(entry) > cutoff:
            by_segment.setdefault(entry["pair"][0], []).append(entry)

    kept: List[Dict[str, Any]] = []
    for bucket in by_segment.values():
        kept.extend(sorted(bucket, key=magnitude, reverse=True)[:3])
    return sorted(kept, key=magnitude, reverse=True)
# ---------------------------------------------------------------------------
# Influence matrix builder from full_matrix_scores
# ---------------------------------------------------------------------------
def _build_influence_matrix(
    summary: Dict[str, Any],
) -> Tuple[np.ndarray, List[str], List[str]]:
    """
    Assemble the segment x token influence matrix from ``full_matrix_scores``.

    Row order follows the ``"label"`` fields of
    ``summary["image_region_values"]``. Column order follows the ``"label"``
    fields of ``summary["token_values"]``.

    A token label may occur several times (e.g. two ``"tok:a"``). The k-th
    score seen for a given ``(segment, token)`` pair is written to the k-th
    column carrying that token label. Surplus scores, and scores whose
    segment or token label is unknown, are dropped. Cells without a score
    stay 0.

    Returns ``(matrix, seg_labels, tok_labels)``. The matrix is all zeros
    when there are no scores or no labels.
    """
    entries = summary.get("full_matrix_scores", [])
    seg_labels = [item["label"] for item in summary.get("image_region_values", [])]
    tok_labels = [item["label"] for item in summary.get("token_values", [])]
    matrix = np.zeros((len(seg_labels), len(tok_labels)))
    if not (entries and seg_labels and tok_labels):
        return matrix, seg_labels, tok_labels

    row_for = {label: row for row, label in enumerate(seg_labels)}
    columns_for: Dict[str, List[int]] = {}
    for col, label in enumerate(tok_labels):
        columns_for.setdefault(label, []).append(col)

    # Number of scores already placed for each (segment, token) pair.
    placed: Dict[Tuple[str, str], int] = {}
    for entry in entries:
        seg, tok = entry["pair"]
        value = entry["value"]
        row = row_for.get(seg)
        if row is None:
            continue
        columns = columns_for.get(tok)
        if not columns:
            continue
        nth = placed.get((seg, tok), 0)
        if nth >= len(columns):
            continue
        matrix[row, columns[nth]] = value
        placed[(seg, tok)] = nth + 1
    return matrix, seg_labels, tok_labels
# ---------------------------------------------------------------------------
# Main loader
# ---------------------------------------------------------------------------
# Filenames of per-segment masked images inside coco_{id}_masked_lama/.
_MASKED_SEG_PNG_RE = re.compile(r"^seg_(\d+)_(solo|removed)\.png$")


def _load_mobius_sidecar(path: Path) -> Optional[Dict[str, Any]]:
    """Best-effort load of a ``coco_{id}_mobius_dict.json`` sidecar.

    Returns the decoded JSON object. Returns ``None`` when the file is
    missing or unreadable, is not valid UTF-8 JSON, or its top-level value
    is not a JSON object.
    """
    import json

    if not path.exists():
        return None
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
        with open(path, "r", encoding="utf-8") as fh:
            data = json.load(fh)
    except (OSError, ValueError):
        # JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        # A broken sidecar only disables the method toggle.
        return None
    return data if isinstance(data, dict) else None


def _load_coco_images(
    named_paths: List[Tuple[str, Path]],
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Return ``(image_paths, image_b64)`` keyed by name.

    Each ``(name, path)`` pair maps to the path string and its base64 payload.
    A missing file maps to ``""`` in both dicts.
    """
    image_paths: Dict[str, str] = {}
    image_b64: Dict[str, str] = {}
    for name, path in named_paths:
        if path.exists():
            image_paths[name] = str(path)
            image_b64[name] = _image_to_b64(str(path))
        else:
            image_paths[name] = ""
            image_b64[name] = ""
    return image_paths, image_b64


def _masked_region_choices(masked_dir: Path) -> List[str]:
    """Build the masked-image browser's Region/View dropdown options.

    Returns ``[]`` when ``masked_dir`` is not a directory. Otherwise returns
    ``"all_masked"`` followed by ``"seg_N (solo)"`` and ``"seg_N (removed)"``
    for every N found in a ``seg_N_solo.png`` / ``seg_N_removed.png``
    filename. Entries are in ascending numeric order (seg_2 before seg_10),
    independent of the summary's seg_labels order. Both views are listed for
    N even if only one of its two files exists.
    """
    if not masked_dir.is_dir():
        return []
    seg_indices = set()
    for entry in masked_dir.iterdir():
        match = _MASKED_SEG_PNG_RE.match(entry.name)
        if match:
            seg_indices.add(int(match.group(1)))
    choices = ["all_masked"]
    for idx in sorted(seg_indices):
        choices.append(f"seg_{idx} (solo)")
        choices.append(f"seg_{idx} (removed)")
    return choices


def load_coco_example(
    example_id: str,
    results_dir: Optional[Path] = None,
    *,
    method: str = "shapley",
) -> Dict[str, Any]:
    """
    Load all data for a precomputed MS-COCO example.

    Args:
        example_id: Example key such as ``"coco_281447"``. A leading
            ``"coco_"`` is stripped to get the numeric id used in filenames.
        results_dir: Directory holding the flat result files. Defaults to
            ``_resolve_coco_results_dir()``.
        method: Attribution method (shapley/banzhaf/influence). It is
            forwarded to ``apply_method_to_clip_summary`` and to the
            sidecar-based cross-modal pair derivation.

    Returns a dict compatible with the benchmark tab's data contract. When a
    ``coco_{id}_mobius_dict.json`` sidecar is present and decodes to a JSON
    object, cross-modal pairs are derived fresh for the requested method and
    segment/token values are overwritten in place.

    Raises:
        FileNotFoundError: if ``coco_{id}_summary.txt`` does not exist.
    """
    if results_dir is None:
        results_dir = _resolve_coco_results_dir()
    results_dir = Path(results_dir)

    # Strip only a leading "coco_" (e.g. "coco_56350" -> "56350").
    # str.replace would also delete the substring anywhere else in the id.
    num_id = example_id[len("coco_"):] if example_id.startswith("coco_") else example_id
    stem = f"coco_{num_id}"
    summary_path = results_dir / f"{stem}_summary.txt"
    original_path = results_dir / f"{stem}_original.png"
    segmap_path = results_dir / f"{stem}_segmap.png"
    overlay_path = results_dir / f"{stem}_overlay.png"
    mobius_path = results_dir / f"{stem}_mobius_dict.json"

    if not summary_path.exists():
        raise FileNotFoundError(f"COCO summary not found: {summary_path}")

    # Copy so callers mutating the returned "meta" cannot corrupt the registry.
    meta = dict(COCO_EXAMPLES.get(example_id, {}))
    summary = parse_summary_txt(summary_path)

    # A present sidecar enables the method toggle.
    mobius_sidecar = _load_mobius_sidecar(mobius_path)
    apply_method_to_clip_summary(summary, mobius_sidecar, method)

    all_cross_modal_pairs = _build_coco_cross_modal_pairs(
        summary, mobius_sidecar=mobius_sidecar, method=method,
    )
    influence_matrix, seg_labels, tok_labels = _build_influence_matrix(summary)
    image_paths, image_b64 = _load_coco_images([
        ("original", original_path),
        ("segmap", segmap_path),
        ("overlay", overlay_path),
    ])

    caption = summary.get("caption") or meta.get("caption", "")

    masked_dir = results_dir / f"{stem}_masked_lama"
    region_choices = _masked_region_choices(masked_dir)

    return {
        "example_id": example_id,
        "meta": meta,
        "caption": caption,
        "method": method,
        "has_clip": True,
        "has_mobius": mobius_sidecar is not None,
        "summary": summary,
        "mobius_sidecar": mobius_sidecar,
        "image_paths": image_paths,
        "image_b64": image_b64,
        "seg_labels": seg_labels,
        "tok_labels": tok_labels,
        "influence_matrix": influence_matrix,
        "all_cross_modal_pairs": all_cross_modal_pairs,
        "region_choices": region_choices,
        "masked_dir": str(masked_dir),
    }
def get_coco_masked_image_path(
    example_id: str, choice: str, results_dir: Optional[Path] = None,
) -> Optional[str]:
    """
    Map a masked-image dropdown choice to an existing PNG path.

    The images live in ``<results_dir>/coco_<num_id>_masked_lama/``. There,
    ``"all_masked"`` maps to ``all_masked.png``, and ``"seg_N (solo)"`` /
    ``"seg_N (removed)"`` map to ``seg_N_solo.png`` / ``seg_N_removed.png``.

    Returns the path as a string, or ``None`` if the choice is not
    recognised or the file does not exist. ``results_dir`` defaults to
    ``_resolve_coco_results_dir()``.
    """
    base = Path(_resolve_coco_results_dir() if results_dir is None else results_dir)
    num_id = example_id.replace("coco_", "")
    masked_dir = base / f"coco_{num_id}_masked_lama"

    if choice == "all_masked":
        filename = "all_masked.png"
    else:
        parsed = re.match(r"^(seg_\d+)\s+\((solo|removed)\)$", choice)
        if parsed is None:
            return None
        seg_name, variant = parsed.groups()
        filename = f"{seg_name}_{variant}.png"

    candidate = masked_dir / filename
    return str(candidate) if candidate.exists() else None