| """Dataset helpers for Experiment 2 (CoT / multi-hop faithfulness). |
| |
| Named dataset_utils to avoid collision with the HF `datasets` package. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import random |
| import re |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, Dict, Iterable, List, Optional |
|
|
| from attribution_datasets import ( |
| AttributionExample, |
| MoreHopQAAttributionDataset, |
| RulerAttributionDataset, |
| ) |
|
|
|
|
| @dataclass |
| class CachedExample: |
| prompt: str |
| target: Optional[str] |
| indices_to_explain: Optional[List[int]] |
| attr_mask_indices: Optional[List[int]] |
| sink_span: Optional[List[int]] |
| thinking_span: Optional[List[int]] |
| metadata: Dict[str, Any] |
|
|
|
|
| def read_cached_jsonl(path: Path) -> List[CachedExample]: |
| examples: List[CachedExample] = [] |
| with path.open("r", encoding="utf-8") as f: |
| for line in f: |
| if not line.strip(): |
| continue |
| obj = json.loads(line) |
| examples.append( |
| CachedExample( |
| prompt=obj["prompt"], |
| target=obj.get("target"), |
| indices_to_explain=obj.get("indices_to_explain"), |
| attr_mask_indices=obj.get("attr_mask_indices"), |
| sink_span=obj.get("sink_span"), |
| thinking_span=obj.get("thinking_span"), |
| metadata=obj.get("metadata", {}), |
| ) |
| ) |
| return examples |
|
|
|
|
| def load_cached(path: Path, sample: Optional[int] = None, seed: int = 42) -> List[CachedExample]: |
| ex = read_cached_jsonl(path) |
| if sample is not None and sample < len(ex): |
| random.Random(seed).shuffle(ex) |
| ex = ex[:sample] |
| return ex |
|
|
|
|
| def load_ruler(path: Path, sample: Optional[int] = None, seed: int = 42) -> List[CachedExample]: |
| ds = RulerAttributionDataset(path) |
| examples: List[CachedExample] = [] |
| ex_iter: Iterable[AttributionExample] = ds |
| if sample is not None and sample < len(ds): |
| ex_iter = list(ds) |
| random.Random(seed).shuffle(ex_iter) |
| ex_iter = ex_iter[:sample] |
| for ex in ex_iter: |
| examples.append( |
| CachedExample( |
| prompt=ex.prompt, |
| target=ex.target, |
| indices_to_explain=ex.indices_to_explain, |
| attr_mask_indices=ex.attr_mask_indices, |
| sink_span=None, |
| thinking_span=None, |
| metadata=ex.metadata, |
| ) |
| ) |
| return examples |
|
|
|
|
| def load_morehopqa( |
| path: str | Path = "./data/with_human_verification.json", sample: Optional[int] = None, seed: int = 42 |
| ) -> List[CachedExample]: |
| ds = MoreHopQAAttributionDataset(path) |
| ex_iter: Iterable[AttributionExample] = ds |
| if sample is not None and sample < len(ds): |
| ex_iter = list(ds) |
| random.Random(seed).shuffle(ex_iter) |
| ex_iter = ex_iter[:sample] |
| examples: List[CachedExample] = [] |
| for ex in ex_iter: |
| examples.append( |
| CachedExample( |
| prompt=ex.prompt, |
| target=None, |
| indices_to_explain=ex.indices_to_explain, |
| attr_mask_indices=ex.attr_mask_indices, |
| sink_span=None, |
| thinking_span=None, |
| metadata=ex.metadata, |
| ) |
| ) |
| return examples |
|
|
|
|
| def auto_find_ruler(task: str) -> Optional[Path]: |
| length_dirs = ["4096", "8192", "16384", "32768", "65536", "131072"] |
| base = Path("data/ruler_multihop") |
| for ld in length_dirs: |
| cand = base / ld / task / "validation.jsonl" |
| if cand.exists(): |
| return cand |
| return None |
|
|
|
|
| def dataset_from_name(name: str) -> Optional[Path]: |
| if name == "hotpotqa_long": |
| return auto_find_ruler("hotpotqa_long") |
| if name.startswith("vt_"): |
| return auto_find_ruler(name) |
| if name.startswith("niah"): |
| return auto_find_ruler(name) |
| p = Path(name) |
| if p.exists(): |
| return p |
| return None |
|
|
|
|
| _BOX_PATTERN = re.compile(r"\\box(?:ed)?\s*[\{{](.*?)[\}}]", flags=re.DOTALL) |
|
|
|
|
| def _find_box_span(text: str) -> Optional[tuple[int, int, str]]: |
| """Return (start_char, end_char, answer_text) for the last \\boxed block.""" |
| matches = list(_BOX_PATTERN.finditer(text)) |
| if not matches: |
| return None |
| m = matches[-1] |
| return m.start(0), m.end(0), m.group(1).strip() |
|
|
|
|
| def extract_boxed_answer(text: str) -> Optional[str]: |
| """Extract the answer string inside the last \\boxed{} block.""" |
| match = _find_box_span(text) |
| return match[2] if match else None |
|
|
|
|
| def _find_answer_span(text: str, answer: str) -> Optional[tuple[int, int]]: |
| """Return (start_char, end_char) for the last occurrence of `answer` in text.""" |
| if not answer or not text: |
| return None |
| start = text.rfind(answer) |
| if start == -1: |
| return None |
| return start, start + len(answer) |
|
|
|
|
| def split_boxed_generation(text: str) -> Optional[tuple[str, str, str]]: |
| """Return (thinking_text, boxed_segment, boxed_answer) if format matches.""" |
| if not text: |
| return None |
| match = _find_box_span(text) |
| if not match: |
| return None |
|
|
| start_char, end_char, boxed_inner = match |
| boxed_segment = text[start_char:end_char].strip() |
| thinking_text = text[:start_char].strip() |
| trailing = text[end_char:].strip() |
|
|
| if not boxed_inner or not boxed_segment: |
| return None |
| if trailing: |
| return None |
| if not thinking_text: |
| return None |
|
|
| return thinking_text, boxed_segment, boxed_inner |
|
|
|
|
| def attach_spans_from_answer( |
| example: CachedExample, tokenizer, answer_text: Optional[str] = None |
| ) -> CachedExample: |
| """Attach sink/thinking spans by locating the (plain) answer in `target`. |
| |
| `answer_text` should be the extracted boxed answer; falls back to metadata or |
| parsing the target when omitted. Works even when the target no longer keeps |
| the \\box{} wrapper. |
| """ |
| tgt = example.target or "" |
| answer = (answer_text or "").strip() |
| if not answer: |
| answer = (example.metadata.get("boxed_answer") or extract_boxed_answer(tgt) or "").strip() |
|
|
| metadata = dict(example.metadata) |
| if answer: |
| metadata.setdefault("boxed_answer", answer) |
|
|
| if tokenizer is None or not tgt or not answer: |
| return CachedExample( |
| prompt=example.prompt, |
| target=example.target, |
| indices_to_explain=example.indices_to_explain, |
| attr_mask_indices=example.attr_mask_indices, |
| sink_span=example.sink_span, |
| thinking_span=example.thinking_span, |
| metadata=metadata, |
| ) |
|
|
| span = _find_answer_span(tgt, answer) |
| if span is None: |
| return CachedExample( |
| prompt=example.prompt, |
| target=example.target, |
| indices_to_explain=example.indices_to_explain, |
| attr_mask_indices=example.attr_mask_indices, |
| sink_span=example.sink_span, |
| thinking_span=example.thinking_span, |
| metadata=metadata, |
| ) |
|
|
| span_start_char, span_end_char = span |
| gen_ids = tokenizer(tgt, add_special_tokens=False, return_offsets_mapping=True) |
| sink_tokens: List[int] = [] |
| for idx, (s, e) in enumerate(gen_ids["offset_mapping"]): |
| |
| if s < span_end_char and e > span_start_char: |
| sink_tokens.append(idx) |
| if not sink_tokens: |
| return CachedExample( |
| prompt=example.prompt, |
| target=example.target, |
| indices_to_explain=example.indices_to_explain, |
| attr_mask_indices=example.attr_mask_indices, |
| sink_span=example.sink_span, |
| thinking_span=example.thinking_span, |
| metadata=metadata, |
| ) |
|
|
| sink_span = [min(sink_tokens), max(sink_tokens)] |
| thinking_end = max(0, sink_span[0] - 1) |
| thinking_span = [0, thinking_end] if thinking_end >= 0 else sink_span |
|
|
| return CachedExample( |
| prompt=example.prompt, |
| target=example.target, |
| indices_to_explain=example.indices_to_explain, |
| attr_mask_indices=example.attr_mask_indices, |
| sink_span=example.sink_span or sink_span, |
| thinking_span=example.thinking_span or thinking_span, |
| metadata=metadata, |
| ) |
|
|
|
|
| def attach_spans_from_boxed(example: CachedExample, tokenizer) -> CachedExample: |
| """Backward-compatible wrapper that first looks for \\box{} then falls back to answer text.""" |
| tgt = example.target |
| match = _find_box_span(tgt) if tgt else None |
| boxed_answer = match[2] if match else None |
| return attach_spans_from_answer(example, tokenizer, boxed_answer) |
|
|
|
|
| def ruler_gold_prompt_token_indices(example: CachedExample, tokenizer) -> List[int]: |
| """Return token indices (prompt-side) that overlap RULER `needle_spans` in metadata. |
| |
| The returned indices are with respect to `tokenizer(" " + example.prompt, add_special_tokens=False)`, |
| matching the attribution pipeline's leading-space convention. |
| """ |
| needle_spans = (example.metadata or {}).get("needle_spans") or [] |
| if not isinstance(needle_spans, list) or not needle_spans: |
| return [] |
|
|
| prompt_text = " " + (example.prompt or "") |
| enc = tokenizer(prompt_text, add_special_tokens=False, return_offsets_mapping=True) |
| offsets = enc.get("offset_mapping") |
| if offsets is None: |
| raise ValueError("Tokenizer does not provide offset_mapping; cannot map needle_spans to tokens.") |
|
|
| spans: List[tuple[int, int]] = [] |
| for item in needle_spans: |
| if not isinstance(item, dict): |
| continue |
| raw = item.get("span") |
| if not (isinstance(raw, list) and len(raw) == 2): |
| continue |
| try: |
| start = int(raw[0]) + 1 |
| end = int(raw[1]) + 1 |
| except Exception: |
| continue |
| if end > start: |
| spans.append((start, end)) |
|
|
| if not spans: |
| return [] |
|
|
| gold: set[int] = set() |
| for tok_idx, off in enumerate(offsets): |
| if off is None: |
| continue |
| try: |
| s, e = int(off[0]), int(off[1]) |
| except Exception: |
| continue |
| if e <= s: |
| continue |
| for span_start, span_end in spans: |
| if s < span_end and e > span_start: |
| gold.add(tok_idx) |
| break |
|
|
| return sorted(gold) |
|
|
|
|
| class DatasetLoader: |
| """Thin loader that resolves and samples datasets for exp2.""" |
|
|
| def __init__(self, seed: int = 42, data_root: Path | str = Path("exp/exp2/data")) -> None: |
| self.seed = seed |
| self.data_root = Path(data_root) |
|
|
| def _sample(self, items: List[CachedExample], sample: Optional[int]) -> List[CachedExample]: |
| if sample is not None and sample < len(items): |
| rnd = random.Random(self.seed) |
| rnd.shuffle(items) |
| items = items[:sample] |
| return items |
|
|
| def _cached_path(self, name: str) -> Optional[Path]: |
| path = self.data_root / f"{name}.jsonl" |
| return path if path.exists() else None |
|
|
| def load(self, name: str, sample: Optional[int] = None) -> List[CachedExample]: |
| |
| cached_path = self._cached_path(name) |
| if cached_path: |
| return self._sample(load_cached(cached_path), sample) |
|
|
| return self.load_raw(name, sample=sample) |
|
|
| def load_raw(self, name: str, sample: Optional[int] = None) -> List[CachedExample]: |
| def _looks_like_json_array(path: Path) -> bool: |
| try: |
| with path.open("r", encoding="utf-8") as f: |
| while True: |
| ch = f.read(1) |
| if not ch: |
| return False |
| if ch.isspace(): |
| continue |
| return ch == "[" |
| except OSError: |
| return False |
|
|
| |
| if name == "morehopqa": |
| ex = load_morehopqa() |
| for item in ex: |
| if "answer" in item.metadata: |
| item.metadata.setdefault("reference_answer", item.metadata["answer"]) |
| return self._sample(ex, sample) |
|
|
| |
| p = Path(name) |
| if p.exists() and _looks_like_json_array(p): |
| ex = load_morehopqa(p) |
| for item in ex: |
| if "answer" in item.metadata: |
| item.metadata.setdefault("reference_answer", item.metadata["answer"]) |
| return self._sample(ex, sample) |
|
|
| |
| resolved = dataset_from_name(name) |
| if resolved is None: |
| raise FileNotFoundError(f"Could not resolve dataset {name}") |
| ex = load_ruler(resolved) |
| for item in ex: |
| outputs = item.metadata.get("outputs") or [] |
| if outputs: |
| item.metadata.setdefault("reference_answer", ", ".join(outputs)) |
| if item.target and "reference_answer" not in item.metadata: |
| item.metadata["reference_answer"] = item.target |
| return self._sample(ex, sample) |
|
|