""" coco_loader.py — Load precomputed MS-COCO CLIP cross-modal attribution results. Results live in a flat directory (e.g. ygao15/image/results_mm/) with naming: coco_{id}_summary.txt coco_{id}_original.png coco_{id}_segmap.png coco_{id}_overlay.png Unlike the medical benchmark (which uses subdirectories per example), these are all flat files in one directory. """ from __future__ import annotations import base64 import os import re from collections import OrderedDict from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import numpy as np from .medical_loader import ( parse_summary_txt, extract_segment_regions, _image_to_b64, apply_method_to_clip_summary, ) # --------------------------------------------------------------------------- # Results directory resolution # --------------------------------------------------------------------------- _COCO_BASE = Path(__file__).resolve().parent.parent / "results" def _resolve_coco_results_dir() -> Path: """Resolve COCO results directory, preferring dotmask results.""" env = os.environ.get("ATTRLLM_COCO_RESULTS_DIR") if env: return Path(env) # Prefer dotmask-fixed results, fall back to original dotmask = _COCO_BASE / "coco_mm_dotmask" if dotmask.is_dir() and any(dotmask.iterdir()): return dotmask return _COCO_BASE / "coco_mm" # --------------------------------------------------------------------------- # Example registry # --------------------------------------------------------------------------- COCO_EXAMPLES: Dict[str, Dict[str, Any]] = OrderedDict([ ("coco_281447", { "title": "Horse in Pasture", "caption": "Horse in fenced pasture with others grazing on grasses.", }), ("coco_402992", { "title": "Cattle in Grass", "caption": "Cattle lie in the grass to chew their cud.", }), ("coco_133233", { "title": "Marina with Boats", "caption": "A marina filled with boats floating in crystal blue water.", }), ("coco_307172", { "title": "Baked Dish on Plate", "caption": "A baked dish on a plate being touched by a woman.", }), ("coco_448256", { "title": "Men by Car", "caption": "Three men stand next to a car with its hood open.", }), ]) # --------------------------------------------------------------------------- # Cross-modal pair filtering (same logic as medical _build_all_cross_modal_pairs # but sourced from per_token_cross_modal instead of all_mobius) # --------------------------------------------------------------------------- def _build_coco_cross_modal_pairs( summary: Dict[str, Any], *, mobius_sidecar: Optional[Dict[str, Any]] = None, method: str = "shapley", ) -> List[Dict[str, Any]]: """ Build significant cross-modal pairs. When ``mobius_sidecar`` is present, derives pairs from the stored Mobius dict using the requested method (shapley/banzhaf/influence). Otherwise falls back to the per-token section of the summary. Applies: |score| > 10% of max, top-3 per segment. """ if mobius_sidecar is not None: from .medical_loader import _derive_cross_pairs_from_sidecar, _filter_and_rank_cross_pairs derived = _derive_cross_pairs_from_sidecar(mobius_sidecar, method=method) if derived: return _filter_and_rank_cross_pairs(derived) pairs = summary.get("per_token_cross_modal", []) if not pairs: return [ {"pair": item["pair"], "value": item["value"]} for item in summary.get("cross_modal_interactions", []) ] max_abs = max(abs(p["value"]) for p in pairs) if max_abs == 0: return [] threshold = max_abs * 0.10 significant = [p for p in pairs if abs(p["value"]) > threshold] from collections import defaultdict by_seg: Dict[str, List[Dict]] = defaultdict(list) for p in significant: seg = p["pair"][0] by_seg[seg].append(p) result = [] for seg, items in by_seg.items(): items.sort(key=lambda x: abs(x["value"]), reverse=True) result.extend(items[:3]) result.sort(key=lambda x: abs(x["value"]), reverse=True) return result # --------------------------------------------------------------------------- # Influence matrix builder from full_matrix_scores # --------------------------------------------------------------------------- def _build_influence_matrix( summary: Dict[str, Any], ) -> Tuple[np.ndarray, List[str], List[str]]: """Build seg_labels, tok_labels, and influence matrix from full_matrix_scores.""" scores = summary.get("full_matrix_scores", []) # Get ordered labels from the summary's region/token value lists seg_labels = [v["label"] for v in summary.get("image_region_values", [])] tok_labels = [v["label"] for v in summary.get("token_values", [])] if not scores or not seg_labels or not tok_labels: return np.zeros((len(seg_labels), len(tok_labels))), seg_labels, tok_labels # Build index maps seg_idx = {label: i for i, label in enumerate(seg_labels)} # Token labels can repeat (e.g. two "tok:a"), so use position-based matching tok_idx: Dict[str, List[int]] = {} for i, label in enumerate(tok_labels): tok_idx.setdefault(label, []).append(i) matrix = np.zeros((len(seg_labels), len(tok_labels))) # Track which tok column to assign next for each seg-tok pair tok_counters: Dict[Tuple[str, str], int] = {} for entry in scores: seg, tok = entry["pair"] val = entry["value"] si = seg_idx.get(seg) if si is None: continue cols = tok_idx.get(tok, []) if not cols: continue key = (seg, tok) counter = tok_counters.get(key, 0) if counter < len(cols): matrix[si, cols[counter]] = val tok_counters[key] = counter + 1 return matrix, seg_labels, tok_labels # --------------------------------------------------------------------------- # Main loader # --------------------------------------------------------------------------- def load_coco_example( example_id: str, results_dir: Optional[Path] = None, *, method: str = "shapley", ) -> Dict[str, Any]: """ Load all data for a precomputed MS-COCO example. Returns a dict compatible with the benchmark tab's data contract. When a ``coco_{id}_mobius_dict.json`` sidecar is present, cross-modal pairs are derived fresh for the requested method (shapley/banzhaf/influence) and segment/token values are overwritten in place. """ if results_dir is None: results_dir = _resolve_coco_results_dir() results_dir = Path(results_dir) # Extract numeric ID from example_id (e.g. "coco_56350" -> "56350") num_id = example_id.replace("coco_", "") # Build flat file paths prefix = results_dir / f"coco_{num_id}" summary_path = Path(f"{prefix}_summary.txt") original_path = Path(f"{prefix}_original.png") segmap_path = Path(f"{prefix}_segmap.png") overlay_path = Path(f"{prefix}_overlay.png") mobius_path = Path(f"{prefix}_mobius_dict.json") if not summary_path.exists(): raise FileNotFoundError(f"COCO summary not found: {summary_path}") meta = COCO_EXAMPLES.get(example_id, {}) # Parse summary summary = parse_summary_txt(summary_path) # Load Mobius sidecar if present (enables method toggle) mobius_sidecar: Optional[Dict[str, Any]] = None if mobius_path.exists(): import json as _json try: with open(mobius_path, "r") as _f: mobius_sidecar = _json.load(_f) except Exception: mobius_sidecar = None apply_method_to_clip_summary(summary, mobius_sidecar, method) # Build cross-modal pairs (method-aware when sidecar exists) all_cross_modal_pairs = _build_coco_cross_modal_pairs( summary, mobius_sidecar=mobius_sidecar, method=method, ) # Build influence matrix influence_matrix, seg_labels, tok_labels = _build_influence_matrix(summary) # Image paths and b64 image_paths = {} image_b64 = {} for name, path in [("original", original_path), ("segmap", segmap_path), ("overlay", overlay_path)]: if path.exists(): image_paths[name] = str(path) image_b64[name] = _image_to_b64(str(path)) else: image_paths[name] = "" image_b64[name] = "" caption = summary.get("caption") or meta.get("caption", "") # Masked image browser data — list "seg_N (solo)", "seg_N (removed)" in # ascending numeric order based on the PNGs on disk, so the dropdown is # predictable (seg_0 solo, seg_0 removed, seg_1 solo, ...) regardless of # the seg_labels ordering coming out of the summary file. masked_dir = results_dir / f"coco_{num_id}_masked_lama" region_choices: List[str] = [] if masked_dir.is_dir(): seg_indices: set = set() for entry in masked_dir.iterdir(): m = re.match(r"^seg_(\d+)_(solo|removed)\.png$", entry.name) if m: seg_indices.add(int(m.group(1))) region_choices = ["all_masked"] for idx in sorted(seg_indices): region_choices.append(f"seg_{idx} (solo)") region_choices.append(f"seg_{idx} (removed)") return { "example_id": example_id, "meta": meta, "caption": caption, "method": method, "has_clip": True, "has_mobius": mobius_sidecar is not None, "summary": summary, "mobius_sidecar": mobius_sidecar, "image_paths": image_paths, "image_b64": image_b64, "seg_labels": seg_labels, "tok_labels": tok_labels, "influence_matrix": influence_matrix, "all_cross_modal_pairs": all_cross_modal_pairs, "region_choices": region_choices, "masked_dir": str(masked_dir), } def get_coco_masked_image_path( example_id: str, choice: str, results_dir: Optional[Path] = None, ) -> Optional[str]: """ Return the file path for a COCO masked image based on dropdown choice. COCO masked images live in a flat subdirectory: /coco__masked_lama/{all_masked,seg_X_solo,seg_X_removed}.png choice is one of: "all_masked", "seg_0 (solo)", "seg_0 (removed)", etc. """ if results_dir is None: results_dir = _resolve_coco_results_dir() results_dir = Path(results_dir) num_id = example_id.replace("coco_", "") masked_dir = results_dir / f"coco_{num_id}_masked_lama" if choice == "all_masked": p = masked_dir / "all_masked.png" return str(p) if p.exists() else None m = re.match(r"^(seg_\d+)\s+\((solo|removed)\)$", choice) if not m: return None filename = f"{m.group(1)}_{m.group(2)}.png" p = masked_dir / filename return str(p) if p.exists() else None