File size: 11,019 Bytes
3e72399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bc76fd
 
 
 
953646d
 
 
3e72399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb1d352
 
 
 
3e72399
 
 
bb1d352
 
 
 
 
3e72399
bb1d352
 
 
3e72399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""
coco_loader.py — Load precomputed MS-COCO CLIP cross-modal attribution results.

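Results live in a flat directory (e.g. ygao15/image/results_mm/) with naming:
    coco_{id}_summary.txt
    coco_{id}_original.png
    coco_{id}_segmap.png
    coco_{id}_overlay.png
    coco_{id}_mobius_dict.json   (optional; enables the attribution-method toggle)
    coco_{id}_masked_lama/       (optional; all_masked.png, seg_N_solo.png,
                                  seg_N_removed.png)

Unlike the medical benchmark (which uses subdirectories per example), the
per-example files sit side by side in one directory; only the optional masked
images get a subdirectory of their own.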
"""
from __future__ import annotations

import base64
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from .medical_loader import (
    parse_summary_txt,
    extract_segment_regions,
    _image_to_b64,
    apply_method_to_clip_summary,
)

# ---------------------------------------------------------------------------
# Results directory resolution
# ---------------------------------------------------------------------------

# Default root for precomputed results: a "results" directory located next to
# the directory that contains this module (i.e. <module dir>/../results).
_COCO_BASE = Path(__file__).resolve().parent.parent / "results"


def _resolve_coco_results_dir() -> Path:
    """
    Pick the directory that holds the precomputed COCO results.

    Resolution order:
      1. ``ATTRLLM_COCO_RESULTS_DIR`` environment variable, when set and
         non-empty (returned as-is, without checking that it exists).
      2. ``<_COCO_BASE>/coco_mm_dotmask`` when it is a directory containing at
         least one entry.
      3. ``<_COCO_BASE>/coco_mm`` otherwise.
    """
    override = os.environ.get("ATTRLLM_COCO_RESULTS_DIR")
    if override:
        return Path(override)

    # The dotmask-fixed results win over the original run, but only when the
    # directory actually has something in it.
    preferred = _COCO_BASE / "coco_mm_dotmask"
    has_content = preferred.is_dir() and next(preferred.iterdir(), None) is not None
    return preferred if has_content else _COCO_BASE / "coco_mm"


# ---------------------------------------------------------------------------
# Example registry
# ---------------------------------------------------------------------------

# Registry of the bundled COCO examples, in display order. Each key is the
# example id ("coco_<numeric id>"); each value carries a short "title" and the
# "caption" used as a fallback when the summary file provides none.
COCO_EXAMPLES: Dict[str, Dict[str, Any]] = OrderedDict(
    (example_key, {"title": title, "caption": caption})
    for example_key, title, caption in (
        (
            "coco_281447",
            "Horse in Pasture",
            "Horse in fenced pasture with others grazing on grasses.",
        ),
        (
            "coco_402992",
            "Cattle in Grass",
            "Cattle lie in the grass to chew their cud.",
        ),
        (
            "coco_133233",
            "Marina with Boats",
            "A marina filled with boats floating in crystal blue water.",
        ),
        (
            "coco_307172",
            "Baked Dish on Plate",
            "A baked dish on a plate being touched by a woman.",
        ),
        (
            "coco_448256",
            "Men by Car",
            "Three men stand next to a car with its hood open.",
        ),
    )
)


# ---------------------------------------------------------------------------
# Cross-modal pair filtering (same logic as medical _build_all_cross_modal_pairs
# but sourced from per_token_cross_modal instead of all_mobius)
# ---------------------------------------------------------------------------

def _build_coco_cross_modal_pairs(
    summary: Dict[str, Any],
    *,
    mobius_sidecar: Optional[Dict[str, Any]] = None,
    method: str = "shapley",
) -> List[Dict[str, Any]]:
    """
    Select the significant (segment, token) cross-modal pairs for display.

    Sources, in order of preference:
      1. ``mobius_sidecar`` (when not ``None``): pairs are derived for
         ``method`` via ``_derive_cross_pairs_from_sidecar`` and, if that
         yields anything, returned after ``_filter_and_rank_cross_pairs``.
      2. ``summary["per_token_cross_modal"]``: keep entries whose ``|value|``
         exceeds 10% of the largest ``|value|``, keep at most the three
         strongest per segment (``pair[0]``), then order everything by
         ``|value|`` descending. Returns ``[]`` when every value is zero.
      3. ``summary["cross_modal_interactions"]``: used unfiltered when the
         per-token list is missing or empty; each entry is reduced to its
         ``"pair"`` and ``"value"`` keys.
    """

    def _magnitude(entry: Dict[str, Any]) -> float:
        return abs(entry["value"])

    if mobius_sidecar is not None:
        from .medical_loader import (
            _derive_cross_pairs_from_sidecar,
            _filter_and_rank_cross_pairs,
        )
        sidecar_pairs = _derive_cross_pairs_from_sidecar(mobius_sidecar, method=method)
        if sidecar_pairs:
            return _filter_and_rank_cross_pairs(sidecar_pairs)

    token_pairs = summary.get("per_token_cross_modal", [])
    if not token_pairs:
        coarse = summary.get("cross_modal_interactions", [])
        return [{"pair": entry["pair"], "value": entry["value"]} for entry in coarse]

    peak = max(_magnitude(entry) for entry in token_pairs)
    if peak == 0:
        return []
    cutoff = peak * 0.10

    # Bucket the entries that clear the cutoff by their segment label,
    # preserving first-seen order of the segments.
    per_segment: Dict[str, List[Dict]] = {}
    for entry in token_pairs:
        if _magnitude(entry) > cutoff:
            per_segment.setdefault(entry["pair"][0], []).append(entry)

    kept: List[Dict[str, Any]] = []
    for members in per_segment.values():
        strongest_first = sorted(members, key=_magnitude, reverse=True)
        kept.extend(strongest_first[:3])

    kept.sort(key=_magnitude, reverse=True)
    return kept


# ---------------------------------------------------------------------------
# Influence matrix builder from full_matrix_scores
# ---------------------------------------------------------------------------

def _build_influence_matrix(
    summary: Dict[str, Any],
) -> Tuple[np.ndarray, List[str], List[str]]:
    """
    Assemble the segment-by-token influence matrix from a parsed summary.

    Row labels come from ``summary["image_region_values"]`` and column labels
    from ``summary["token_values"]`` (each entry's ``"label"``). Cells are
    filled from ``summary["full_matrix_scores"]``, whose entries carry a
    ``"pair"`` of (segment label, token label) and a ``"value"``.

    Returns ``(matrix, seg_labels, tok_labels)``. The matrix is all zeros when
    scores or either label list is missing/empty. Score entries naming an
    unknown segment or token are ignored.
    """
    seg_labels = [row["label"] for row in summary.get("image_region_values", [])]
    tok_labels = [row["label"] for row in summary.get("token_values", [])]
    matrix = np.zeros((len(seg_labels), len(tok_labels)))

    scores = summary.get("full_matrix_scores", [])
    if not (scores and seg_labels and tok_labels):
        return matrix, seg_labels, tok_labels

    row_of = dict(zip(seg_labels, range(len(seg_labels))))

    # A token label may occur several times in the caption (e.g. two "tok:a"),
    # so remember every column that carries a given label, left to right.
    columns_of: Dict[str, List[int]] = {}
    for col, label in enumerate(tok_labels):
        columns_of.setdefault(label, []).append(col)

    # For each (segment, token) label pair, how many of its columns have been
    # filled so far: the k-th score for a pair goes to the k-th matching column.
    filled: Dict[Tuple[str, str], int] = {}

    for entry in scores:
        seg, tok = entry["pair"]
        val = entry["value"]

        row = row_of.get(seg)
        if row is None:
            continue
        candidates = columns_of.get(tok)
        if not candidates:
            continue

        nth = filled.get((seg, tok), 0)
        if nth >= len(candidates):
            # More scores than columns for this pair: surplus is dropped.
            continue
        matrix[row, candidates[nth]] = val
        filled[(seg, tok)] = nth + 1

    return matrix, seg_labels, tok_labels


# ---------------------------------------------------------------------------
# Main loader
# ---------------------------------------------------------------------------

def _load_mobius_sidecar(mobius_path: Path) -> Optional[Dict[str, Any]]:
    """
    Read the optional Mobius JSON sidecar at ``mobius_path``.

    Returns the decoded JSON content, or ``None`` when the file is missing,
    cannot be read, or does not contain valid JSON. Loading is deliberately
    best-effort: a broken sidecar only disables the method toggle.
    """
    import json as _json

    try:
        # JSON is read as UTF-8 explicitly so decoding does not depend on the
        # platform's locale encoding.
        with open(mobius_path, "r", encoding="utf-8") as _f:
            return _json.load(_f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: malformed JSON
        # (JSONDecodeError) or undecodable bytes (UnicodeDecodeError).
        return None


def _list_masked_region_choices(masked_dir: Path) -> List[str]:
    """
    Build the masked-image dropdown choices from the PNGs found on disk.

    Returns ``[]`` when ``masked_dir`` is not a directory. Otherwise returns
    ``"all_masked"`` followed by ``"seg_N (solo)"`` / ``"seg_N (removed)"``
    for every N that has a ``seg_N_solo.png`` or ``seg_N_removed.png`` file,
    in ascending numeric order.
    """
    if not masked_dir.is_dir():
        return []

    pattern = re.compile(r"^seg_(\d+)_(solo|removed)\.png$")
    seg_indices = set()
    for entry in masked_dir.iterdir():
        m = pattern.match(entry.name)
        if m:
            seg_indices.add(int(m.group(1)))

    choices = ["all_masked"]
    for idx in sorted(seg_indices):
        choices.append(f"seg_{idx} (solo)")
        choices.append(f"seg_{idx} (removed)")
    return choices


def load_coco_example(
    example_id: str,
    results_dir: Optional[Path] = None,
    *,
    method: str = "shapley",
) -> Dict[str, Any]:
    """
    Load all data for a precomputed MS-COCO example.

    Args:
        example_id: Example identifier, with or without the ``"coco_"``
            prefix (e.g. ``"coco_56350"`` or ``"56350"``).
        results_dir: Directory holding the flat ``coco_{id}_*`` files.
            Defaults to ``_resolve_coco_results_dir()``.
        method: Attribution method (shapley/banzhaf/influence) forwarded to
            ``apply_method_to_clip_summary`` and the cross-modal pair builder.

    Returns:
        A dict compatible with the benchmark tab's data contract, with keys
        ``example_id``, ``meta``, ``caption``, ``method``, ``has_clip``,
        ``has_mobius``, ``summary``, ``mobius_sidecar``, ``image_paths``,
        ``image_b64``, ``seg_labels``, ``tok_labels``, ``influence_matrix``,
        ``all_cross_modal_pairs``, ``region_choices`` and ``masked_dir``.
        Missing images are reported as empty strings in ``image_paths`` /
        ``image_b64``.

        When a ``coco_{id}_mobius_dict.json`` sidecar is present, cross-modal
        pairs are derived fresh for the requested method and the summary is
        handed to ``apply_method_to_clip_summary`` together with the sidecar
        (which is expected to overwrite segment/token values in place).

    Raises:
        FileNotFoundError: if ``coco_{id}_summary.txt`` does not exist.
    """
    if results_dir is None:
        results_dir = _resolve_coco_results_dir()
    results_dir = Path(results_dir)

    # Extract numeric ID from example_id (e.g. "coco_56350" -> "56350").
    # Only a leading "coco_" is stripped; str.replace would also remove any
    # later occurrence of the substring and point at the wrong files.
    if example_id.startswith("coco_"):
        num_id = example_id[len("coco_"):]
    else:
        num_id = example_id

    # Build flat file paths
    stem = f"coco_{num_id}"
    summary_path = results_dir / f"{stem}_summary.txt"
    original_path = results_dir / f"{stem}_original.png"
    segmap_path = results_dir / f"{stem}_segmap.png"
    overlay_path = results_dir / f"{stem}_overlay.png"
    mobius_path = results_dir / f"{stem}_mobius_dict.json"

    if not summary_path.exists():
        raise FileNotFoundError(f"COCO summary not found: {summary_path}")

    meta = COCO_EXAMPLES.get(example_id, {})

    # Parse summary
    summary = parse_summary_txt(summary_path)

    # Load Mobius sidecar if present (enables method toggle)
    mobius_sidecar = _load_mobius_sidecar(mobius_path)

    apply_method_to_clip_summary(summary, mobius_sidecar, method)

    # Build cross-modal pairs (method-aware when sidecar exists)
    all_cross_modal_pairs = _build_coco_cross_modal_pairs(
        summary, mobius_sidecar=mobius_sidecar, method=method,
    )

    # Build influence matrix
    influence_matrix, seg_labels, tok_labels = _build_influence_matrix(summary)

    # Image paths and b64 (empty strings stand in for missing files)
    image_paths: Dict[str, str] = {}
    image_b64: Dict[str, str] = {}
    for name, path in (
        ("original", original_path),
        ("segmap", segmap_path),
        ("overlay", overlay_path),
    ):
        if path.exists():
            image_paths[name] = str(path)
            image_b64[name] = _image_to_b64(str(path))
        else:
            image_paths[name] = ""
            image_b64[name] = ""

    # The caption parsed from the summary wins; the registry is the fallback.
    caption = summary.get("caption") or meta.get("caption", "")

    # Masked image browser data — listed in ascending numeric order based on
    # the PNGs on disk, so the dropdown is predictable (seg_0 solo, seg_0
    # removed, seg_1 solo, ...) regardless of the seg_labels ordering coming
    # out of the summary file.
    masked_dir = results_dir / f"coco_{num_id}_masked_lama"
    region_choices = _list_masked_region_choices(masked_dir)

    return {
        "example_id": example_id,
        "meta": meta,
        "caption": caption,
        "method": method,
        "has_clip": True,
        "has_mobius": mobius_sidecar is not None,
        "summary": summary,
        "mobius_sidecar": mobius_sidecar,
        "image_paths": image_paths,
        "image_b64": image_b64,
        "seg_labels": seg_labels,
        "tok_labels": tok_labels,
        "influence_matrix": influence_matrix,
        "all_cross_modal_pairs": all_cross_modal_pairs,
        "region_choices": region_choices,
        "masked_dir": str(masked_dir),
    }


def get_coco_masked_image_path(
    example_id: str, choice: str, results_dir: Optional[Path] = None,
) -> Optional[str]:
    """
    Return the file path for a COCO masked image based on dropdown choice.

    COCO masked images live in a flat subdirectory:
        <results_dir>/coco_<num_id>_masked_lama/{all_masked,seg_X_solo,seg_X_removed}.png

    Args:
        example_id: Example identifier, with or without the ``"coco_"`` prefix.
        choice: One of ``"all_masked"``, ``"seg_0 (solo)"``,
            ``"seg_0 (removed)"``, etc.
        results_dir: Directory holding the results. Defaults to
            ``_resolve_coco_results_dir()``.

    Returns:
        The image path as a string, or ``None`` when ``choice`` is not
        recognised or the corresponding file does not exist.
    """
    if results_dir is None:
        results_dir = _resolve_coco_results_dir()
    results_dir = Path(results_dir)

    # Strip only a leading "coco_"; str.replace would also remove any later
    # occurrence of the substring and resolve to the wrong directory.
    if example_id.startswith("coco_"):
        num_id = example_id[len("coco_"):]
    else:
        num_id = example_id
    masked_dir = results_dir / f"coco_{num_id}_masked_lama"

    if choice == "all_masked":
        filename = "all_masked.png"
    else:
        # "seg_3 (solo)" -> "seg_3_solo.png"
        m = re.match(r"^(seg_\d+)\s+\((solo|removed)\)$", choice)
        if not m:
            return None
        filename = f"{m.group(1)}_{m.group(2)}.png"

    p = masked_dir / filename
    return str(p) if p.exists() else None