File size: 13,835 Bytes
55b60a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
#!/usr/bin/env python3
"""Map exp2 cached JSONL token spans across tokenizers (Qwen -> Llama).

Background
----------
`exp/exp2/run_exp.py` expects cached datasets to provide token-level generation spans:

  - indices_to_explain: [start_tok, end_tok] (generation-token indices; closed interval)
  - sink_span / thinking_span: same tokenizer convention as indices_to_explain

These spans are computed under a specific tokenizer (often Qwen3-8B). When switching
to a different model/tokenizer (e.g., Llama-3.1-8B-Instruct), the stored spans can
become out-of-range and crash exp2 attribution (IndexError in token-span checks).

This script remaps spans by:
  1) Tokenizing `target` with the OLD tokenizer to obtain offset_mapping
  2) Converting the OLD token span into a character span in `target`
  3) Tokenizing `target` with the NEW tokenizer and mapping the character span back
     into NEW token indices

Outputs are written under `exp/exp5/data/` by default, keeping `exp/exp2/` untouched.
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

from transformers import AutoTokenizer


# Repository root, taken as the third ancestor directory of this file's resolved path.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Put the repo root at the front of sys.path (only if it is not already listed);
# presumably so that the `exp.exp2...` import used later in this file resolves
# when the script is run directly -- confirm against the repo layout.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))


def _split_args(values: Iterable[str]) -> List[str]:
    """Flatten CLI values into a list of non-empty, whitespace-trimmed tokens.

    Every value is converted with ``str`` and split on commas, so both
    ``--flag a b`` and ``--flag a,b`` style inputs produce ``["a", "b"]``.
    Empty pieces (e.g. from ``"a,,b"`` or a trailing comma) are dropped and
    the original order is preserved.
    """
    trimmed = (chunk.strip() for raw in values for chunk in str(raw).split(","))
    return [piece for piece in trimmed if piece]


def _load_tokenizer(tokenizer_model: str):
    """Load a tokenizer from a local path if it exists, otherwise by model name.

    When ``tokenizer_model`` names an existing filesystem path, the tokenizer is
    loaded from that path with ``local_files_only=True``. Otherwise the string is
    handed to ``AutoTokenizer.from_pretrained`` unchanged.
    """
    local_dir = Path(tokenizer_model)
    if not local_dir.exists():
        # May require network access; keep as fallback for environments that allow it.
        return AutoTokenizer.from_pretrained(tokenizer_model)
    return AutoTokenizer.from_pretrained(local_dir.as_posix(), local_files_only=True)


def _is_token_span(span: Any) -> bool:
    """Return True if ``span`` is a well-formed closed token span ``[start, end]``.

    A valid span is a ``list`` of exactly two integers with
    ``0 <= start <= end``. Booleans are rejected explicitly: ``bool`` is a
    subclass of ``int``, so a JSON value such as ``[false, true]`` would
    otherwise be accepted as the span ``[0, 1]``.
    """
    return (
        isinstance(span, list)
        and len(span) == 2
        and all(isinstance(x, int) and not isinstance(x, bool) for x in span)
        and 0 <= span[0] <= span[1]
    )


def _pick_old_span(obj: Dict[str, Any]) -> Optional[List[int]]:
    """Return a copy of the first valid stored span, or None if there is none.

    ``indices_to_explain`` is preferred; ``sink_span`` is consulted only when
    the former is missing or is not a valid token span.
    """
    for key in ("indices_to_explain", "sink_span"):
        candidate = obj.get(key)
        if _is_token_span(candidate):
            return list(candidate)
    return None


def _offsets_to_char_span(offsets: Any, token_span: List[int]) -> Optional[Tuple[int, int]]:
    """Translate a closed token span into a half-open character span.

    ``offsets`` must be a list holding one ``(char_start, char_end)`` pair per
    token. Tokens inside ``token_span`` whose entry is malformed or covers no
    characters (``char_end <= char_start``) are ignored. The result is
    ``(smallest char_start, largest char_end)`` over the remaining tokens.

    Returns None when ``offsets`` is not a list, when the span's end index lies
    beyond the last token, or when no token in the span covers any character.
    """
    if not isinstance(offsets, list):
        return None
    first_tok, last_tok = token_span
    if last_tok >= len(offsets):
        return None

    covered: List[Tuple[int, int]] = []
    for tok_idx in range(first_tok, last_tok + 1):
        entry = offsets[tok_idx]
        if not isinstance(entry, (list, tuple)) or len(entry) != 2:
            continue
        try:
            begin, finish = int(entry[0]), int(entry[1])
        except Exception:
            continue
        if begin < finish:
            covered.append((begin, finish))

    if not covered:
        return None
    return min(begin for begin, _ in covered), max(finish for _, finish in covered)


def _char_span_to_token_span(offsets: Any, char_span: Tuple[int, int]) -> Optional[List[int]]:
    """Translate a half-open character span into a closed token span by overlap.

    A token is selected when its ``(char_start, char_end)`` range intersects
    ``char_span``. Malformed entries and tokens that cover no characters are
    skipped. The result is ``[lowest selected index, highest selected index]``.

    Returns None when ``offsets`` is not a list, when ``char_span`` is empty,
    or when no token overlaps it.
    """
    if not isinstance(offsets, list):
        return None
    lo, hi = int(char_span[0]), int(char_span[1])
    if hi <= lo:
        return None

    first_hit: Optional[int] = None
    last_hit: Optional[int] = None
    for position, entry in enumerate(offsets):
        if not isinstance(entry, (list, tuple)) or len(entry) != 2:
            continue
        try:
            tok_begin, tok_end = int(entry[0]), int(entry[1])
        except Exception:
            continue
        if tok_end <= tok_begin:
            continue
        if tok_begin < hi and tok_end > lo:
            # Indices arrive in ascending order, so the first hit is the minimum
            # and the most recent hit is the maximum.
            if first_hit is None:
                first_hit = position
            last_hit = position

    if first_hit is None:
        return None
    return [first_hit, last_hit]


def _validate_span_with_eos(tokenizer, target: str, token_span: List[int]) -> bool:
    """Check that ``token_span`` fits inside the tokenization of ``target`` plus EOS.

    ``target`` is concatenated with the tokenizer's ``eos_token`` (or nothing if
    the tokenizer has none) and tokenized without special tokens. The span is
    valid when ``0 <= start <= end < number_of_tokens``.
    """
    eos_text = tokenizer.eos_token or ""
    encoded = tokenizer(target + eos_text, add_special_tokens=False)
    n_tokens = int(len(encoded.input_ids))
    return 0 <= token_span[0] <= token_span[1] and token_span[1] < n_tokens


def _answer_field_to_text(value: Any) -> str:
    """Normalize one metadata answer field to stripped text ("" if unusable).

    Strings are stripped. Real numbers (``int``/``float``, but not ``bool``) are
    converted with ``str`` so that numeric answers such as ``42`` or ``0`` are
    usable. Every other type (None, lists, dicts, bools, ...) yields "".
    """
    if isinstance(value, str):
        return value.strip()
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return str(value)
    return ""


def _guess_answer_text(obj: Dict[str, Any]) -> Optional[str]:
    """Guess the final-answer text of a cached example.

    Candidates are tried in order:
      1. ``metadata["boxed_answer"]``
      2. ``metadata["reference_answer"]``
      3. the last line of ``target`` (after stripping surrounding whitespace)

    Metadata answers may be strings or plain numbers; numbers are converted to
    text instead of raising ``AttributeError`` on ``.strip()``, and a numeric
    ``0`` is no longer discarded as "empty".

    Returns:
        The first non-empty candidate, or None when nothing usable is found.
    """
    meta = obj.get("metadata") or {}
    if isinstance(meta, dict):
        for key in ("boxed_answer", "reference_answer"):
            text = _answer_field_to_text(meta.get(key))
            if text:
                return text
    tgt = obj.get("target")
    if isinstance(tgt, str) and tgt.strip():
        # Common exp2 cache convention: last line is the final answer.
        last_line = tgt.strip().splitlines()[-1].strip()
        return last_line or None
    return None


def _fallback_map_via_answer_text(
    obj: Dict[str, Any],
    *,
    new_tokenizer,
) -> Optional[List[int]]:
    """Recompute the sink span under ``new_tokenizer`` from the guessed answer text.

    Builds a ``CachedExample`` from ``obj`` with all spans cleared, passes it to
    ``attach_spans_from_answer`` together with the answer text guessed by
    ``_guess_answer_text``, and returns a copy of the resulting ``sink_span``.

    Returns None when ``target`` is missing/empty/not a string, or when the
    helper does not produce a valid ``[start, end]`` token span.
    """
    target_text = obj.get("target")
    if not (isinstance(target_text, str) and target_text):
        return None

    from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer  # lazy import

    guessed_answer = _guess_answer_text(obj)
    example = CachedExample(
        prompt=str(obj.get("prompt") or ""),
        target=target_text,
        indices_to_explain=None,
        attr_mask_indices=obj.get("attr_mask_indices"),
        sink_span=None,
        thinking_span=None,
        metadata=obj.get("metadata") or {},
    )
    enriched = attach_spans_from_answer(example, new_tokenizer, guessed_answer)
    sink = enriched.sink_span
    # A None sink span fails the validity check as well, so one test covers both cases.
    return list(sink) if _is_token_span(sink) else None


def _with_new_span(obj: Dict[str, Any], new_span: List[int], method: str) -> Dict[str, Any]:
    """Return a shallow copy of ``obj`` carrying ``new_span`` and a provenance tag.

    ``indices_to_explain`` and ``sink_span`` are both set to ``new_span``;
    ``thinking_span`` becomes everything before it (``[0, start - 1]``) or None
    when the span starts at token 0. ``metadata`` is copied (or created) and
    records ``method`` under ``exp5_span_map_method``; the input is not mutated.
    """
    mapped = dict(obj)
    mapped["indices_to_explain"] = new_span
    mapped["sink_span"] = new_span
    mapped["thinking_span"] = [0, new_span[0] - 1] if new_span[0] > 0 else None
    meta = mapped.get("metadata")
    meta = dict(meta) if isinstance(meta, dict) else {}
    meta["exp5_span_map_method"] = method
    mapped["metadata"] = meta
    return mapped


def _map_one_obj(
    obj: Dict[str, Any],
    *,
    old_tokenizer,
    new_tokenizer,
    allow_fallback_answer: bool,
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Remap one cached example's token spans from the old to the new tokenizer.

    Primary path: the stored span is converted to a character span through the
    old tokenizer's offsets, then mapped to the new tokenizer's tokens by
    overlap. When a step fails and ``allow_fallback_answer`` is set, the span is
    recomputed by locating the answer text in ``target`` instead.

    Args:
        obj: One decoded JSONL record; needs a non-empty string ``target`` and a
            valid ``indices_to_explain`` or ``sink_span``.
        old_tokenizer: Tokenizer under which the stored span was computed.
        new_tokenizer: Tokenizer to map the span into.
        allow_fallback_answer: Enable the answer-text fallback.

    Returns:
        ``(mapped, None)`` on success, where ``mapped`` is a copy of ``obj`` with
        updated spans and ``metadata["exp5_span_map_method"]`` set to
        ``"token_span_char_align"`` or ``"answer_text"`` according to the method
        that actually produced the final span; ``(None, reason)`` on failure.
    """
    target = obj.get("target")
    if not isinstance(target, str) or not target:
        return None, "missing_target"

    old_span = _pick_old_span(obj)
    if old_span is None:
        return None, "missing_old_span"

    # 1) Old token span -> char span in target.
    old_enc = old_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
    char_span = _offsets_to_char_span(old_enc.get("offset_mapping"), old_span)
    if char_span is None:
        if not allow_fallback_answer:
            return None, "old_span_to_char_failed"
        new_span = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
        if new_span is None:
            return None, "fallback_answer_failed"
        if not _validate_span_with_eos(new_tokenizer, target, new_span):
            return None, "fallback_answer_span_invalid"
        return _with_new_span(obj, new_span, "answer_text"), None

    # 2) Char span -> new token span.
    method = "token_span_char_align"
    new_enc = new_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
    new_span = _char_span_to_token_span(new_enc.get("offset_mapping"), char_span)
    if new_span is None:
        if not allow_fallback_answer:
            return None, "char_to_new_span_failed"
        new_span = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
        if new_span is None:
            return None, "fallback_answer_failed"
        # The span now comes from the answer text, so label it accordingly.
        method = "answer_text"

    if not _validate_span_with_eos(new_tokenizer, target, new_span):
        if not allow_fallback_answer:
            return None, "new_span_invalid"
        fb = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
        if fb is None or not _validate_span_with_eos(new_tokenizer, target, fb):
            return None, "fallback_answer_span_invalid"
        new_span = fb
        method = "answer_text"

    return _with_new_span(obj, new_span, method), None


def _read_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
    """Lazily yield one dict per non-blank line of a UTF-8 JSONL file.

    Blank or whitespace-only lines are skipped (they still count towards the
    reported line number).

    Raises:
        RuntimeError: if a line is not valid JSON, or decodes to something other
            than a JSON object. The message includes ``path`` and the 1-based
            line number.
    """
    with path.open("r", encoding="utf-8") as handle:
        for lineno, raw in enumerate(handle, start=1):
            if raw.strip():
                try:
                    record = json.loads(raw)
                except json.JSONDecodeError as exc:  # pragma: no cover
                    raise RuntimeError(f"Invalid JSON at {path}:{lineno}: {exc}") from exc
                if isinstance(record, dict):
                    yield record
                else:
                    raise RuntimeError(f"Expected JSON object per line at {path}:{lineno}.")


def _write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> int:
    """Write ``rows`` to ``path`` as UTF-8 JSONL and return the number written.

    Missing parent directories are created. An existing file is truncated.
    Non-ASCII characters are written verbatim (``ensure_ascii=False``).
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    written = 0
    with path.open("w", encoding="utf-8") as sink:
        # enumerate(start=1) leaves `written` at the count of rows emitted;
        # it stays 0 when `rows` is empty.
        for written, row in enumerate(rows, start=1):
            sink.write(f"{json.dumps(row, ensure_ascii=False)}\n")
    return written


def _default_old_tokenizer() -> str:
    """Return the default path of the tokenizer assumed to have produced the stored spans."""
    # Repo defaults used in exp2 README examples for span extraction.
    qwen_path = "/opt/share/models/Qwen/Qwen3-8B"
    return qwen_path


def _default_new_tokenizer() -> str:
    """Return the default path of the tokenizer that spans are mapped into."""
    llama_path = "/opt/share/models/meta-llama/Llama-3.1-8B-Instruct"
    return llama_path


def main() -> None:
    """Command-line entry point: remap token spans in one or more cached JSONL files.

    For every input file, each record is passed through ``_map_one_obj``;
    successfully mapped records are written to ``<out_dir>/<input file name>``.
    Records that cannot be mapped are dropped and counted per failure reason,
    unless ``--strict`` is given, in which case the first failure aborts the run.
    A JSON record is printed per input file, followed by an indented JSON
    summary of the whole run.

    Raises:
        SystemExit: on a missing input file, an existing output file without
            ``--overwrite``, a mapping failure under ``--strict``, or an
            internal written/mapped count mismatch.
    """
    parser = argparse.ArgumentParser("Map exp2 cache token spans from an old tokenizer to a new tokenizer.")
    parser.add_argument(
        "--in_jsonl",
        type=str,
        nargs="+",
        required=True,
        help="One or more exp2 cached JSONL files (comma-separated also accepted).",
    )
    parser.add_argument(
        "--out_dir", type=str, default="exp/exp5/data", help="Output directory for mapped JSONL files."
    )
    parser.add_argument(
        "--old_tokenizer_model",
        type=str,
        default=_default_old_tokenizer(),
        help="Tokenizer used to produce the original token spans (default: Qwen3-8B local path).",
    )
    parser.add_argument(
        "--new_tokenizer_model",
        type=str,
        default=_default_new_tokenizer(),
        help="Tokenizer to map spans into (default: Llama-3.1-8B-Instruct local path).",
    )
    parser.add_argument("--strict", action="store_true", help="Fail on the first example that cannot be mapped.")
    parser.add_argument(
        "--allow_fallback_answer",
        action="store_true",
        help=(
            "If span alignment fails, try to recompute spans by locating metadata.boxed_answer in target "
            "(useful when caches were not built with the assumed old tokenizer)."
        ),
    )
    parser.add_argument("--overwrite", action="store_true", help="Overwrite output files if they already exist.")
    opts = parser.parse_args()

    # Normalize the options once instead of re-coercing them inside the loops.
    sources = [Path(item) for item in _split_args(opts.in_jsonl)]
    destination_dir = Path(opts.out_dir)
    old_name = str(opts.old_tokenizer_model)
    new_name = str(opts.new_tokenizer_model)
    strict = bool(opts.strict)
    overwrite = bool(opts.overwrite)
    use_fallback = bool(opts.allow_fallback_answer)

    old_tok = _load_tokenizer(old_name)
    new_tok = _load_tokenizer(new_name)

    # exp2 convention: ensure a pad token exists for downstream perturbation.
    if new_tok.pad_token is None and new_tok.eos_token is not None:
        new_tok.pad_token = new_tok.eos_token

    summary: Dict[str, Any] = {
        "old_tokenizer_model": old_name,
        "new_tokenizer_model": new_name,
        "datasets": [],
    }

    for src in sources:
        if not src.exists():
            raise SystemExit(f"Missing input JSONL: {src}")
        dst = destination_dir / src.name
        if dst.exists() and not overwrite:
            raise SystemExit(f"Refusing to overwrite existing output: {dst} (use --overwrite)")

        failures: Dict[str, int] = {}
        kept: List[Dict[str, Any]] = []
        seen = 0
        for seen, row in enumerate(_read_jsonl(src), start=1):
            mapped, err = _map_one_obj(
                row,
                old_tokenizer=old_tok,
                new_tokenizer=new_tok,
                allow_fallback_answer=use_fallback,
            )
            if err is None and mapped is not None:
                kept.append(mapped)
                continue
            reason = err or "unknown_error"
            failures[reason] = failures.get(reason, 0) + 1
            if strict:
                raise SystemExit(f"Failed to map {src} example #{seen}: {err}")

        written = _write_jsonl(dst, kept)
        if written != len(kept):  # pragma: no cover
            raise SystemExit(f"Internal error: written={written} != mapped_ok={len(kept)}")

        # Every record seen is either kept or dropped, so the drop count is the difference.
        record = {
            "in_jsonl": str(src),
            "out_jsonl": str(dst),
            "total": int(seen),
            "mapped_ok": int(len(kept)),
            "dropped": int(seen - len(kept)),
            "errors": failures,
        }
        summary["datasets"].append(record)
        print(json.dumps(record, ensure_ascii=False))

    # Human-readable compact summary at end.
    print(json.dumps(summary, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()