File size: 13,335 Bytes
55b60a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
"""Dataset helpers for Experiment 2 (CoT / multi-hop faithfulness).

Named dataset_utils to avoid collision with the HF `datasets` package.
"""

from __future__ import annotations

import json
import random
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

from attribution_datasets import (
    AttributionExample,
    MoreHopQAAttributionDataset,
    RulerAttributionDataset,
)


@dataclass
class CachedExample:
    """One prepared attribution example (one row of a cached JSONL file).

    Only ``prompt`` and ``metadata`` are always present; every other field may
    be ``None`` when the source dataset does not provide it.
    """

    prompt: str
    target: str | None  # text whose tokens the sink/thinking spans index into
    indices_to_explain: list[int] | None
    attr_mask_indices: list[int] | None
    sink_span: list[int] | None  # inclusive [first, last] target-token indices of the answer
    thinking_span: list[int] | None  # inclusive [first, last] target-token indices before the answer
    metadata: dict[str, Any]  # free-form per-example information


def read_cached_jsonl(path: Path) -> List[CachedExample]:
    """Parse a JSONL cache file into ``CachedExample`` objects.

    Blank (whitespace-only) lines are skipped. ``prompt`` is required; every
    other key is optional and defaults to ``None``. ``metadata`` defaults to an
    empty dict, including when it is stored as JSON ``null``, so downstream
    ``metadata.get(...)`` calls are always safe.

    Args:
        path: JSONL file to read (a ``str`` is also accepted).

    Raises:
        KeyError: a row has no ``"prompt"`` key.
        json.JSONDecodeError: a non-blank line is not valid JSON; the message
            names the file and the 1-based line number.
    """
    examples: List[CachedExample] = []
    with Path(path).open("r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as err:
                # Keep the exception type but say *where* the bad row is.
                raise json.JSONDecodeError(f"{path}:{lineno}: {err.msg}", err.doc, err.pos) from err
            examples.append(
                CachedExample(
                    prompt=obj["prompt"],
                    target=obj.get("target"),
                    indices_to_explain=obj.get("indices_to_explain"),
                    attr_mask_indices=obj.get("attr_mask_indices"),
                    sink_span=obj.get("sink_span"),
                    thinking_span=obj.get("thinking_span"),
                    metadata=obj.get("metadata") or {},
                )
            )
    return examples


def load_cached(path: Path, sample: Optional[int] = None, seed: int = 42) -> List[CachedExample]:
    """Read a cached JSONL file, optionally keeping a seeded random subset.

    When ``sample`` is given and smaller than the number of rows, the rows are
    shuffled with ``random.Random(seed)`` and the first ``sample`` are kept.
    Otherwise every row is returned in file order.
    """
    rows = read_cached_jsonl(path)
    if sample is None or sample >= len(rows):
        return rows
    rng = random.Random(seed)
    rng.shuffle(rows)
    return rows[:sample]


def load_ruler(path: Path, sample: Optional[int] = None, seed: int = 42) -> List[CachedExample]:
    """Load a RULER-style dataset through ``RulerAttributionDataset``.

    When ``sample`` is given and smaller than the dataset, a
    ``random.Random(seed)`` shuffle picks ``sample`` examples. Otherwise all
    examples are returned in the dataset's iteration order. Each example keeps
    its ``target``; sink/thinking spans are left unset.
    """
    dataset = RulerAttributionDataset(path)
    source: Iterable[AttributionExample] = dataset
    if sample is not None and sample < len(dataset):
        pool = list(dataset)
        random.Random(seed).shuffle(pool)
        source = pool[:sample]
    return [
        CachedExample(
            prompt=item.prompt,
            target=item.target,
            indices_to_explain=item.indices_to_explain,
            attr_mask_indices=item.attr_mask_indices,
            sink_span=None,
            thinking_span=None,
            metadata=item.metadata,
        )
        for item in source
    ]


def load_morehopqa(
    path: str | Path = "./data/with_human_verification.json", sample: Optional[int] = None, seed: int = 42
) -> List[CachedExample]:
    """Load MoreHopQA through ``MoreHopQAAttributionDataset``.

    When ``sample`` is given and smaller than the dataset, a
    ``random.Random(seed)`` shuffle picks ``sample`` examples. Otherwise all
    examples are returned in the dataset's iteration order. ``target`` is left
    as ``None`` and sink/thinking spans are unset.
    """
    dataset = MoreHopQAAttributionDataset(path)
    source: Iterable[AttributionExample] = dataset
    if sample is not None and sample < len(dataset):
        pool = list(dataset)
        random.Random(seed).shuffle(pool)
        source = pool[:sample]
    return [
        CachedExample(
            prompt=item.prompt,
            target=None,
            indices_to_explain=item.indices_to_explain,
            attr_mask_indices=item.attr_mask_indices,
            sink_span=None,
            thinking_span=None,
            metadata=item.metadata,
        )
        for item in source
    ]


def auto_find_ruler(
    task: str,
    base: Path | str = Path("data/ruler_multihop"),
    length_dirs: Iterable[str] = ("4096", "8192", "16384", "32768", "65536", "131072"),
) -> Optional[Path]:
    """Return the first existing ``<base>/<length>/<task>/validation.jsonl``.

    Args:
        task: Task directory name (e.g. ``"hotpotqa_long"``).
        base: Root of the RULER multi-hop tree. A relative path is resolved
            against the current working directory.
        length_dirs: Context-length directory names, probed in order. With the
            default (ascending lengths) the shortest available variant wins.
            Non-str items are converted with ``str()``.

    Returns:
        The first candidate file that exists, or ``None`` if none does.
    """
    root = Path(base)
    for length in length_dirs:
        candidate = root / str(length) / task / "validation.jsonl"
        if candidate.exists():
            return candidate
    return None


def dataset_from_name(name: str) -> Optional[Path]:
    """Resolve a dataset name or filesystem path to a data file.

    ``"hotpotqa_long"`` and names starting with ``vt_`` or ``niah`` are looked
    up with ``auto_find_ruler`` and are never treated as paths. Any other name
    is treated as a path and returned only if it exists; otherwise ``None``.
    """
    if name == "hotpotqa_long" or name.startswith(("vt_", "niah")):
        return auto_find_ruler(name)
    candidate = Path(name)
    return candidate if candidate.exists() else None


_BOX_PATTERN = re.compile(r"\\box(?:ed)?\s*[\{{](.*?)[\}}]", flags=re.DOTALL)


def _find_box_span(text: str) -> Optional[tuple[int, int, str]]:
    """Return (start_char, end_char, answer_text) for the last \\boxed block."""
    matches = list(_BOX_PATTERN.finditer(text))
    if not matches:
        return None
    m = matches[-1]
    return m.start(0), m.end(0), m.group(1).strip()


def extract_boxed_answer(text: str) -> Optional[str]:
    """Return the stripped answer inside the last \\boxed{} block, or None if absent."""
    span = _find_box_span(text)
    if span is None:
        return None
    _, _, answer = span
    return answer


def _find_answer_span(text: str, answer: str) -> Optional[tuple[int, int]]:
    """Return (start_char, end_char) for the last occurrence of `answer` in text."""
    if not answer or not text:
        return None
    start = text.rfind(answer)
    if start == -1:
        return None
    return start, start + len(answer)


def split_boxed_generation(text: str) -> Optional[tuple[str, str, str]]:
    """Split a generation into (thinking_text, boxed_segment, boxed_answer).

    All three parts are whitespace-stripped. Returns ``None`` unless all of
    the following hold:

    * ``text`` is non-empty;
    * its last \\boxed block has a non-empty answer;
    * non-empty reasoning precedes that block;
    * nothing but whitespace follows it.
    """
    found = _find_box_span(text) if text else None
    if found is None:
        return None

    box_start, box_end, answer = found
    thinking = text[:box_start].strip()
    segment = text[box_start:box_end].strip()
    after = text[box_end:].strip()

    well_formed = bool(answer) and bool(segment) and not after and bool(thinking)
    return (thinking, segment, answer) if well_formed else None


def attach_spans_from_answer(
    example: CachedExample, tokenizer, answer_text: Optional[str] = None
) -> CachedExample:
    """Attach sink/thinking token spans by locating the (plain) answer in `target`.

    `answer_text` should be the extracted boxed answer. When it is empty or
    None, the answer falls back to ``metadata["boxed_answer"]`` and then to the
    last ``\\boxed{}`` block in the target. This works even when the target no
    longer keeps the \\box{} wrapper.

    The answer's last occurrence in the target is mapped to target-token
    indices via the tokenizer's ``offset_mapping`` (tokenized without special
    tokens):

    * ``sink_span`` = inclusive ``[first, last]`` tokens overlapping the answer;
    * ``thinking_span`` = ``[0, first - 1]``, i.e. everything before the
      answer. If the answer starts at token 0 there is no preceding text, and
      a copy of the sink span is used instead.

    Spans already present on ``example`` are kept. A new ``CachedExample`` is
    always returned. Its metadata is a copy, with ``"boxed_answer"`` filled in
    when an answer is known; ``example`` itself is not modified. Spans are left
    unchanged when there is no tokenizer, target, or answer, or when the answer
    cannot be located.
    """
    from dataclasses import replace  # local import: only this helper needs it

    tgt = example.target or ""
    metadata = dict(example.metadata or {})

    answer = (answer_text or "").strip()
    if not answer:
        fallback = metadata.get("boxed_answer") or extract_boxed_answer(tgt) or ""
        # str(): cached metadata may hold a non-string (e.g. numeric) answer.
        answer = str(fallback).strip()
    if answer:
        metadata.setdefault("boxed_answer", answer)

    if tokenizer is None or not tgt or not answer:
        return replace(example, metadata=metadata)

    span = _find_answer_span(tgt, answer)
    if span is None:
        return replace(example, metadata=metadata)

    span_start_char, span_end_char = span
    encoding = tokenizer(tgt, add_special_tokens=False, return_offsets_mapping=True)
    # Keep every token whose [s, e) character range overlaps the answer.
    sink_tokens = [
        idx
        for idx, (s, e) in enumerate(encoding["offset_mapping"])
        if s < span_end_char and e > span_start_char
    ]
    if not sink_tokens:
        return replace(example, metadata=metadata)

    sink_span = [min(sink_tokens), max(sink_tokens)]
    thinking_end = sink_span[0] - 1
    # Answer at token 0 means there is no thinking text: reuse the sink span
    # (copied so the two fields never alias the same list).
    thinking_span = [0, thinking_end] if thinking_end >= 0 else list(sink_span)

    return replace(
        example,
        sink_span=example.sink_span or sink_span,
        thinking_span=example.thinking_span or thinking_span,
        metadata=metadata,
    )


def attach_spans_from_boxed(example: CachedExample, tokenizer) -> CachedExample:
    """Backward-compatible wrapper: use the last \\box{} answer in the target, if any.

    Delegates to ``attach_spans_from_answer``. When the target has no boxed
    block, ``None`` is passed and that function applies its own fallbacks.
    """
    boxed_answer: Optional[str] = None
    if example.target:
        found = _find_box_span(example.target)
        if found is not None:
            boxed_answer = found[2]
    return attach_spans_from_answer(example, tokenizer, boxed_answer)


def ruler_gold_prompt_token_indices(example: CachedExample, tokenizer) -> List[int]:
    """Return token indices (prompt-side) that overlap RULER `needle_spans` in metadata.

    ``example.metadata["needle_spans"]`` is read as a list of dicts whose
    ``"span"`` entry is a 2-element list or tuple of int-convertible character
    offsets into ``example.prompt``. Spans are treated as half-open
    ``[start, end)``. Malformed entries and empty spans are skipped.

    The returned indices are with respect to
    `tokenizer(" " + example.prompt, add_special_tokens=False)`, matching the
    attribution pipeline's leading-space convention. Every span is therefore
    shifted by one character before comparison.

    Returns:
        Ascending, de-duplicated token indices. ``[]`` when there are no usable
        spans; in that case the prompt is not tokenized at all.

    Raises:
        ValueError: usable spans exist but the tokenizer result has no
            ``offset_mapping``.
    """
    needle_spans = (example.metadata or {}).get("needle_spans") or []
    if not isinstance(needle_spans, list) or not needle_spans:
        return []

    spans: List[tuple[int, int]] = []
    for item in needle_spans:
        if not isinstance(item, dict):
            continue
        raw = item.get("span")
        if not (isinstance(raw, (list, tuple)) and len(raw) == 2):
            continue
        try:
            start = int(raw[0]) + 1  # shift for leading space in prompt_text
            end = int(raw[1]) + 1
        except (TypeError, ValueError, OverflowError):
            continue
        if end > start:
            spans.append((start, end))

    if not spans:
        # Nothing to map: skip tokenizing what may be a very long prompt.
        return []

    prompt_text = " " + (example.prompt or "")
    enc = tokenizer(prompt_text, add_special_tokens=False, return_offsets_mapping=True)
    offsets = enc.get("offset_mapping")
    if offsets is None:
        raise ValueError("Tokenizer does not provide offset_mapping; cannot map needle_spans to tokens.")

    gold: List[int] = []
    for tok_idx, off in enumerate(offsets):
        if off is None:
            continue
        try:
            s, e = int(off[0]), int(off[1])
        except (TypeError, ValueError, IndexError):
            continue
        if e <= s:
            continue  # zero-width offsets cover no prompt characters
        if any(s < span_end and e > span_start for span_start, span_end in spans):
            gold.append(tok_idx)
    # enumerate() visits each index once in ascending order: already sorted/unique.
    return gold


class DatasetLoader:
    """Thin loader that resolves and samples datasets for exp2.

    ``load`` prefers a prepared ``<data_root>/<name>.jsonl`` cache and
    otherwise defers to ``load_raw``. Sampling is a ``self.seed``-seeded
    shuffle followed by taking a prefix.
    """

    def __init__(self, seed: int = 42, data_root: Path | str = Path("exp/exp2/data")) -> None:
        self.seed = seed
        self.data_root = Path(data_root)

    def _sample(self, items: List[CachedExample], sample: Optional[int]) -> List[CachedExample]:
        """Return ``sample`` items chosen by a ``self.seed``-seeded shuffle.

        The input list is never mutated. When ``sample`` is None or not smaller
        than ``len(items)``, ``items`` itself is returned.
        """
        if sample is None or sample >= len(items):
            return items
        shuffled = list(items)  # copy: do not reorder the caller's list
        random.Random(self.seed).shuffle(shuffled)
        return shuffled[:sample]

    def _cached_path(self, name: str) -> Optional[Path]:
        """Return ``<data_root>/<name>.jsonl`` if it exists, else None."""
        path = self.data_root / f"{name}.jsonl"
        return path if path.exists() else None

    def load(self, name: str, sample: Optional[int] = None) -> List[CachedExample]:
        """Load ``name``, preferring the prepared cache over the raw source."""
        # 1) Prefer prepared cache under exp/exp2/data
        cached_path = self._cached_path(name)
        if cached_path:
            return self._sample(load_cached(cached_path), sample)

        return self.load_raw(name, sample=sample)

    @staticmethod
    def _looks_like_json_array(path: Path) -> bool:
        """True when the first non-whitespace character of ``path`` is "[".

        A leading UTF-8 BOM is ignored. Unreadable or non-UTF-8 files yield False.
        """
        try:
            with path.open("r", encoding="utf-8-sig") as f:
                while True:
                    chunk = f.read(4096)
                    if not chunk:
                        return False
                    stripped = chunk.lstrip()
                    if stripped:
                        return stripped[0] == "["
        except (OSError, UnicodeDecodeError):
            return False

    @staticmethod
    def _add_morehopqa_references(items: List[CachedExample]) -> List[CachedExample]:
        """Copy ``metadata["answer"]`` to ``"reference_answer"`` unless already set (in place)."""
        for item in items:
            if "answer" in item.metadata:
                item.metadata.setdefault("reference_answer", item.metadata["answer"])
        return items

    @staticmethod
    def _add_ruler_references(items: List[CachedExample]) -> List[CachedExample]:
        """Derive ``metadata["reference_answer"]`` for RULER-style rows (in place).

        ``outputs`` (a list, or a single string) is joined with ", ". Otherwise
        a non-empty ``target`` is used. An existing value is never overwritten.
        """
        for item in items:
            outputs = item.metadata.get("outputs") or []
            if isinstance(outputs, str):
                outputs = [outputs]  # avoid joining the string's characters
            if outputs:
                item.metadata.setdefault("reference_answer", ", ".join(str(o) for o in outputs))
            if item.target and "reference_answer" not in item.metadata:
                item.metadata["reference_answer"] = item.target
        return items

    def load_raw(self, name: str, sample: Optional[int] = None) -> List[CachedExample]:
        """Load ``name`` from its raw source, bypassing the prepared cache.

        Sources, checked in order:

        * ``"morehopqa"``: the default MoreHopQA file, via ``load_morehopqa()``.
        * An existing path whose content starts with "[": MoreHopQA JSON at
          that path.
        * Anything else: resolved with ``dataset_from_name`` and loaded with
          ``load_ruler``.

        Raises:
            FileNotFoundError: ``name`` cannot be resolved to a dataset file.
        """
        # MoreHopQA
        if name == "morehopqa":
            return self._sample(self._add_morehopqa_references(load_morehopqa()), sample)

        # Allow passing the raw MoreHopQA JSON path directly.
        p = Path(name)
        if p.exists() and self._looks_like_json_array(p):
            return self._sample(self._add_morehopqa_references(load_morehopqa(p)), sample)

        # RULER / HotpotQA / niah / vt (all go through RulerAttributionDataset)
        resolved = dataset_from_name(name)
        if resolved is None:
            raise FileNotFoundError(f"Could not resolve dataset {name}")
        return self._sample(self._add_ruler_references(load_ruler(resolved)), sample)