"""Map exp2 cached JSONL token spans across tokenizers (Qwen -> Llama).

Background
----------
`exp/exp2/run_exp.py` expects cached datasets to provide token-level generation spans:

- indices_to_explain: [start_tok, end_tok] (generation-token indices; closed interval)
- sink_span / thinking_span: same tokenizer convention as indices_to_explain

These spans are computed under a specific tokenizer (often Qwen3-8B). When switching
to a different model/tokenizer (e.g., Llama-3.1-8B-Instruct), the stored spans can
fall out of range and crash exp2 attribution (IndexError in token-span checks).

This script remaps spans by:
1) Tokenizing `target` with the OLD tokenizer to obtain an offset_mapping
2) Converting the OLD token span into a character span in `target`
3) Tokenizing `target` with the NEW tokenizer and mapping the character span back
   into NEW token indices

Outputs are written under `exp/exp5/data/` by default, keeping `exp/exp2/` untouched.
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

from transformers import AutoTokenizer

# Make the repo root importable so the lazy fallback import
# (`exp.exp2.dataset_utils`) resolves when this script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))


def _split_args(values: Iterable[str]) -> List[str]:
    """Flatten repeated and/or comma-separated CLI values into a clean list."""
    out: List[str] = []
    for v in values:
        for part in str(v).split(","):
            part = part.strip()
            if part:
                out.append(part)
    return out
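
# e.g. _split_args(["a.jsonl,b.jsonl", "c.jsonl"]) -> ["a.jsonl", "b.jsonl", "c.jsonl"]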


def _load_tokenizer(tokenizer_model: str):
    """Load from a local path when it exists, otherwise treat the value as a hub id."""
    path = Path(tokenizer_model)
    if path.exists():
        return AutoTokenizer.from_pretrained(path.as_posix(), local_files_only=True)
    return AutoTokenizer.from_pretrained(tokenizer_model)


def _is_token_span(span: Any) -> bool:
    """True iff `span` is a plausible closed token interval [start, end]."""
    return (
        isinstance(span, list)
        and len(span) == 2
        and all(isinstance(x, int) for x in span)
        and span[0] >= 0
        and span[1] >= span[0]
    )
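
# e.g. _is_token_span([3, 7]) -> True; _is_token_span([7, 3]) -> False (reversed);
# _is_token_span((3, 7)) -> False (only lists are accepted).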


def _pick_old_span(obj: Dict[str, Any]) -> Optional[List[int]]:
    """Prefer `indices_to_explain`, then `sink_span`; None if neither is usable."""
    span = obj.get("indices_to_explain")
    if _is_token_span(span):
        return list(span)
    span = obj.get("sink_span")
    if _is_token_span(span):
        return list(span)
    return None


def _offsets_to_char_span(offsets: Any, token_span: List[int]) -> Optional[Tuple[int, int]]:
    """Convert a token span [start,end] to a character span [char_start,char_end) using offsets."""
    if offsets is None:
        return None
    if not isinstance(offsets, list):
        return None
    start_tok, end_tok = token_span
    if end_tok >= len(offsets):
        return None

    # Collect the character bounds of every real (non-empty) token in the span.
    char_starts: List[int] = []
    char_ends: List[int] = []
    for idx in range(start_tok, end_tok + 1):
        off = offsets[idx]
        if off is None:
            continue
        if not (isinstance(off, (list, tuple)) and len(off) == 2):
            continue
        try:
            s, e = int(off[0]), int(off[1])
        except Exception:
            continue
        if e <= s:
            # Zero-width offsets (e.g., special tokens) carry no characters.
            continue
        char_starts.append(s)
        char_ends.append(e)

    if not char_starts or not char_ends:
        return None
    return min(char_starts), max(char_ends)
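
# Illustrative example with hypothetical offsets: given
#   offsets = [(0, 5), (5, 7), (7, 12)]
# the token span [1, 2] maps to the character span (5, 12).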


def _char_span_to_token_span(offsets: Any, char_span: Tuple[int, int]) -> Optional[List[int]]:
    """Convert a character span [char_start,char_end) to a token span [start,end] by overlap."""
    if offsets is None:
        return None
    if not isinstance(offsets, list):
        return None
    char_start, char_end = int(char_span[0]), int(char_span[1])
    if char_end <= char_start:
        return None

    # A token is a hit when its character range overlaps [char_start, char_end).
    hit: List[int] = []
    for tok_idx, off in enumerate(offsets):
        if off is None:
            continue
        if not (isinstance(off, (list, tuple)) and len(off) == 2):
            continue
        try:
            s, e = int(off[0]), int(off[1])
        except Exception:
            continue
        if e <= s:
            continue
        if s < char_end and e > char_start:
            hit.append(tok_idx)

    if not hit:
        return None
    return [min(hit), max(hit)]
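
# Illustrative example with hypothetical offsets: given
#   offsets = [(0, 5), (5, 7), (7, 12)]
# the character span (6, 9) overlaps tokens 1 and 2, yielding [1, 2].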


def _validate_span_with_eos(tokenizer, target: str, token_span: List[int]) -> bool:
    """Check the span fits within `target` + EOS (cached spans may include the EOS token)."""
    eos = tokenizer.eos_token or ""
    gen_ids = tokenizer(target + eos, add_special_tokens=False).input_ids
    return 0 <= token_span[0] <= token_span[1] < len(gen_ids)
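
# e.g. if `target` + EOS tokenizes to 40 ids, any span with 0 <= start <= end <= 39
# is accepted (the counts here are hypothetical).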


def _guess_answer_text(obj: Dict[str, Any]) -> Optional[str]:
    """Best-effort answer text: metadata.boxed_answer, then metadata.reference_answer,
    then the last non-empty line of `target`."""
    meta = obj.get("metadata") or {}
    if isinstance(meta, dict):
        boxed = str(meta.get("boxed_answer") or "").strip()
        if boxed:
            return boxed
        ref = str(meta.get("reference_answer") or "").strip()
        if ref:
            return ref
    tgt = obj.get("target")
    if isinstance(tgt, str) and tgt.strip():
        # Heuristic: the final non-empty line of the target usually holds the answer.
        last_line = tgt.strip().splitlines()[-1].strip()
        return last_line or None
    return None


def _fallback_map_via_answer_text(
    obj: Dict[str, Any],
    *,
    new_tokenizer,
) -> Optional[List[int]]:
    """Recompute the sink span by locating the answer text in `target` directly."""
    tgt = obj.get("target")
    if not isinstance(tgt, str) or not tgt:
        return None

    # Imported lazily: only required when the answer-text fallback is exercised.
    from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer

    answer_text = _guess_answer_text(obj)
    ex = CachedExample(
        prompt=str(obj.get("prompt") or ""),
        target=tgt,
        indices_to_explain=None,
        attr_mask_indices=obj.get("attr_mask_indices"),
        sink_span=None,
        thinking_span=None,
        metadata=obj.get("metadata") or {},
    )
    out = attach_spans_from_answer(ex, new_tokenizer, answer_text)
    if out.sink_span is None or not _is_token_span(out.sink_span):
        return None
    return list(out.sink_span)


def _map_one_obj(
    obj: Dict[str, Any],
    *,
    old_tokenizer,
    new_tokenizer,
    allow_fallback_answer: bool,
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    target = obj.get("target")
    if not isinstance(target, str) or not target:
        return None, "missing_target"

    old_span = _pick_old_span(obj)
    if old_span is None:
        return None, "missing_old_span"

    def _finalize(new_span: List[int], method: str) -> Tuple[Dict[str, Any], None]:
        """Build the output row: remapped spans plus a metadata tag naming the method."""
        mapped = dict(obj)
        mapped["indices_to_explain"] = new_span
        mapped["sink_span"] = new_span
        mapped["thinking_span"] = [0, new_span[0] - 1] if new_span[0] > 0 else None
        meta = mapped.get("metadata")
        meta = dict(meta) if isinstance(meta, dict) else {}
        meta["exp5_span_map_method"] = method
        mapped["metadata"] = meta
        return mapped, None

    # 1) OLD tokenizer: token span -> character span in `target`.
    old_enc = old_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
    old_offsets = old_enc.get("offset_mapping")  # offsets require a fast tokenizer
    char_span = _offsets_to_char_span(old_offsets, old_span)
    if char_span is None:
        if not allow_fallback_answer:
            return None, "old_span_to_char_failed"
        new_span = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
        if new_span is None:
            return None, "fallback_answer_failed"
        if not _validate_span_with_eos(new_tokenizer, target, new_span):
            return None, "fallback_answer_span_invalid"
        return _finalize(new_span, "answer_text")

    # 2) NEW tokenizer: character span -> token span.
    method = "token_span_char_align"
    new_enc = new_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
    new_offsets = new_enc.get("offset_mapping")
    new_span = _char_span_to_token_span(new_offsets, char_span)
    if new_span is None:
        if not allow_fallback_answer:
            return None, "char_to_new_span_failed"
        new_span = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
        if new_span is None:
            return None, "fallback_answer_failed"
        method = "answer_text"

    # 3) Validate against `target` + EOS under the NEW tokenizer.
    if not _validate_span_with_eos(new_tokenizer, target, new_span):
        if not allow_fallback_answer:
            return None, "new_span_invalid"
        fb = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
        if fb is None or not _validate_span_with_eos(new_tokenizer, target, fb):
            return None, "fallback_answer_span_invalid"
        new_span = fb
        method = "answer_text"

    return _finalize(new_span, method)
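
# Illustrative input/output (hypothetical values): an input row
#   {"target": "... the answer is 42.", "indices_to_explain": [117, 121], ...}
# comes back with indices_to_explain / sink_span re-indexed under the NEW
# tokenizer and metadata["exp5_span_map_method"] recording which method applied.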


def _read_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
    """Yield one JSON object per non-empty line; fail fast on malformed rows."""
    with path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as exc:
                raise RuntimeError(f"Invalid JSON at {path}:{line_no}: {exc}") from exc
            if not isinstance(obj, dict):
                raise RuntimeError(f"Expected a JSON object per line at {path}:{line_no}.")
            yield obj


def _write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> int:
    """Write rows as JSONL, creating parent directories; return the row count."""
    path.parent.mkdir(parents=True, exist_ok=True)
    count = 0
    with path.open("w", encoding="utf-8") as f:
        for obj in rows:
            f.write(json.dumps(obj, ensure_ascii=False) + "\n")
            count += 1
    return count


def _default_old_tokenizer() -> str:
    # Local snapshot path; override with --old_tokenizer_model if yours differs.
    return "/opt/share/models/Qwen/Qwen3-8B"


def _default_new_tokenizer() -> str:
    return "/opt/share/models/meta-llama/Llama-3.1-8B-Instruct"


def main() -> None:
    ap = argparse.ArgumentParser(
        description="Map exp2 cache token spans from an old tokenizer to a new tokenizer."
    )
    ap.add_argument(
        "--in_jsonl",
        type=str,
        nargs="+",
        required=True,
        help="One or more exp2 cached JSONL files (comma-separated also accepted).",
    )
    ap.add_argument(
        "--out_dir",
        type=str,
        default="exp/exp5/data",
        help="Output directory for mapped JSONL files.",
    )
    ap.add_argument(
        "--old_tokenizer_model",
        type=str,
        default=_default_old_tokenizer(),
        help="Tokenizer used to produce the original token spans (default: Qwen3-8B local path).",
    )
    ap.add_argument(
        "--new_tokenizer_model",
        type=str,
        default=_default_new_tokenizer(),
        help="Tokenizer to map spans into (default: Llama-3.1-8B-Instruct local path).",
    )
    ap.add_argument("--strict", action="store_true", help="Fail on the first example that cannot be mapped.")
    ap.add_argument(
        "--allow_fallback_answer",
        action="store_true",
        help=(
            "If span alignment fails, try to recompute spans by locating metadata.boxed_answer in target "
            "(useful when caches were not built with the assumed old tokenizer)."
        ),
    )
    ap.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite output files if they already exist.",
    )
    args = ap.parse_args()

    in_paths = [Path(p) for p in _split_args(args.in_jsonl)]
    out_dir = Path(args.out_dir)

    old_tok = _load_tokenizer(str(args.old_tokenizer_model))
    new_tok = _load_tokenizer(str(args.new_tokenizer_model))

    # Defensive: some tokenizers (e.g. Llama 3) ship without a pad token.
    if new_tok.pad_token is None and new_tok.eos_token is not None:
        new_tok.pad_token = new_tok.eos_token

    summary: Dict[str, Any] = {
        "old_tokenizer_model": str(args.old_tokenizer_model),
        "new_tokenizer_model": str(args.new_tokenizer_model),
        "datasets": [],
    }

    for in_path in in_paths:
        if not in_path.exists():
            raise SystemExit(f"Missing input JSONL: {in_path}")
        out_path = out_dir / in_path.name
        if out_path.exists() and not args.overwrite:
            raise SystemExit(f"Refusing to overwrite existing output: {out_path} (use --overwrite)")

        total = 0
        mapped_ok = 0
        dropped = 0
        errors: Dict[str, int] = {}

        mapped_rows: List[Dict[str, Any]] = []
        for obj in _read_jsonl(in_path):
            total += 1
            mapped, err = _map_one_obj(
                obj,
                old_tokenizer=old_tok,
                new_tokenizer=new_tok,
                allow_fallback_answer=args.allow_fallback_answer,
            )
            if err is not None or mapped is None:
                key = err or "unknown_error"
                errors[key] = errors.get(key, 0) + 1
                if args.strict:
                    raise SystemExit(f"Failed to map {in_path} example #{total}: {err}")
                dropped += 1
                continue
            mapped_ok += 1
            mapped_rows.append(mapped)

        written = _write_jsonl(out_path, mapped_rows)
        if written != mapped_ok:
            raise SystemExit(f"Internal error: written={written} != mapped_ok={mapped_ok}")

        record = {
            "in_jsonl": str(in_path),
            "out_jsonl": str(out_path),
            "total": total,
            "mapped_ok": mapped_ok,
            "dropped": dropped,
            "errors": errors,
        }
        summary["datasets"].append(record)
        print(json.dumps(record, ensure_ascii=False))

    # Machine-readable summary across all inputs.
    print(json.dumps(summary, ensure_ascii=False, indent=2))
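
# Example of the final summary printed by main() (all values are illustrative):
# {
#   "old_tokenizer_model": "/opt/share/models/Qwen/Qwen3-8B",
#   "new_tokenizer_model": "/opt/share/models/meta-llama/Llama-3.1-8B-Instruct",
#   "datasets": [{"in_jsonl": "...", "out_jsonl": "...", "total": 500,
#                 "mapped_ok": 487, "dropped": 13,
#                 "errors": {"new_span_invalid": 13}}]
# }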


if __name__ == "__main__":
    main()