File size: 13,835 Bytes
55b60a8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 | #!/usr/bin/env python3
"""Map exp2 cached JSONL token spans across tokenizers (Qwen -> Llama).
Background
----------
`exp/exp2/run_exp.py` expects cached datasets to provide token-level generation spans:
- indices_to_explain: [start_tok, end_tok] (generation-token indices; closed interval)
- sink_span / thinking_span: same tokenizer convention as indices_to_explain
These spans are computed under a specific tokenizer (often Qwen3-8B). When switching
to a different model/tokenizer (e.g., Llama-3.1-8B-Instruct), the stored spans can
become out-of-range and crash exp2 attribution (IndexError in token-span checks).
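This script remaps spans by:
1) Tokenizing `target` with the OLD tokenizer to obtain offset_mapping
2) Converting the OLD token span into a character span in `target`
3) Tokenizing `target` with the NEW tokenizer and mapping the character span back
   into NEW token indices
4) Writing the NEW span to both indices_to_explain and sink_span, and setting
   thinking_span to the tokens before it (null when the span starts at token 0)
With --allow_fallback_answer, examples whose spans cannot be aligned are instead
re-spanned from the answer text (metadata.boxed_answer / reference_answer, or the
last line of `target`) via exp2's attach_spans_from_answer.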
Outputs are written under `exp/exp5/data/` by default, keeping `exp/exp2/` untouched.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
from transformers import AutoTokenizer
# Repository root: two directory levels above the directory holding this file
# (presumably exp/exp5/<script>.py -> repo root). It is prepended to sys.path,
# presumably so the lazy `from exp.exp2.dataset_utils import ...` used by the
# answer-text fallback resolves when this script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
def _split_args(values: Iterable[str]) -> List[str]:
out: List[str] = []
for v in values:
for part in str(v).split(","):
part = part.strip()
if part:
out.append(part)
return out
def _load_tokenizer(tokenizer_model: str):
    """Load a tokenizer from a local directory or a model identifier.

    If ``tokenizer_model`` names an existing path, it is loaded with
    ``local_files_only=True`` so no download is attempted. Otherwise the string
    is handed to ``AutoTokenizer.from_pretrained`` unchanged, which may require
    network access.
    """
    local = Path(tokenizer_model)
    if not local.exists():
        # Not on disk: fall back to letting transformers resolve the identifier.
        return AutoTokenizer.from_pretrained(tokenizer_model)
    return AutoTokenizer.from_pretrained(local.as_posix(), local_files_only=True)
def _is_token_span(span: Any) -> bool:
return (
isinstance(span, list)
and len(span) == 2
and all(isinstance(x, int) for x in span)
and span[0] >= 0
and span[1] >= span[0]
)
def _pick_old_span(obj: Dict[str, Any]) -> Optional[List[int]]:
    """Return the stored OLD-tokenizer span to remap, or None if there is none.

    ``indices_to_explain`` is preferred; ``sink_span`` is consulted only when the
    former is not a valid token span. The result is a fresh list copy.
    """
    # Generator keeps the lookup lazy: sink_span is only read if needed.
    candidates = (obj.get(key) for key in ("indices_to_explain", "sink_span"))
    return next((list(span) for span in candidates if _is_token_span(span)), None)
def _offsets_to_char_span(offsets: Any, token_span: List[int]) -> Optional[Tuple[int, int]]:
"""Convert a token span [start,end] to a character span [char_start,char_end) using offsets."""
if offsets is None:
return None
if not isinstance(offsets, list):
return None
start_tok, end_tok = token_span
if end_tok >= len(offsets):
return None
char_starts: List[int] = []
char_ends: List[int] = []
for idx in range(start_tok, end_tok + 1):
off = offsets[idx]
if off is None:
continue
if not (isinstance(off, (list, tuple)) and len(off) == 2):
continue
try:
s, e = int(off[0]), int(off[1])
except Exception:
continue
if e <= s:
continue
char_starts.append(s)
char_ends.append(e)
if not char_starts or not char_ends:
return None
return min(char_starts), max(char_ends)
def _char_span_to_token_span(offsets: Any, char_span: Tuple[int, int]) -> Optional[List[int]]:
"""Convert a character span [char_start,char_end) to a token span [start,end] by overlap."""
if offsets is None:
return None
if not isinstance(offsets, list):
return None
char_start, char_end = int(char_span[0]), int(char_span[1])
if char_end <= char_start:
return None
hit: List[int] = []
for tok_idx, off in enumerate(offsets):
if off is None:
continue
if not (isinstance(off, (list, tuple)) and len(off) == 2):
continue
try:
s, e = int(off[0]), int(off[1])
except Exception:
continue
if e <= s:
continue
if s < char_end and e > char_start:
hit.append(int(tok_idx))
if not hit:
return None
return [min(hit), max(hit)]
def _validate_span_with_eos(tokenizer, target: str, token_span: List[int]) -> bool:
eos = tokenizer.eos_token or ""
gen_ids = tokenizer(target + eos, add_special_tokens=False).input_ids
gen_len = int(len(gen_ids))
return 0 <= token_span[0] <= token_span[1] < gen_len
def _guess_answer_text(obj: Dict[str, Any]) -> Optional[str]:
meta = obj.get("metadata") or {}
if isinstance(meta, dict):
boxed = (meta.get("boxed_answer") or "").strip()
if boxed:
return boxed
ref = (meta.get("reference_answer") or "").strip()
if ref:
return ref
tgt = obj.get("target")
if isinstance(tgt, str) and tgt.strip():
# Common exp2 cache convention: last line is the final answer.
last_line = tgt.strip().splitlines()[-1].strip()
return last_line or None
return None
def _fallback_map_via_answer_text(
    obj: Dict[str, Any],
    *,
    new_tokenizer,
) -> Optional[List[int]]:
    """Recompute the answer span under ``new_tokenizer`` from the answer text.

    Builds an exp2 ``CachedExample`` with all spans cleared and passes it to
    ``attach_spans_from_answer`` together with the text from
    ``_guess_answer_text``. Returns the resulting ``sink_span`` as a list, or
    ``None`` if ``target`` is missing or no valid span comes back.
    """
    target = obj.get("target")
    if not (isinstance(target, str) and target):
        return None
    # Imported lazily so exp2 is only needed when the fallback is actually used.
    from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer

    answer = _guess_answer_text(obj)
    example = CachedExample(
        prompt=str(obj.get("prompt") or ""),
        target=target,
        indices_to_explain=None,
        attr_mask_indices=obj.get("attr_mask_indices"),
        sink_span=None,
        thinking_span=None,
        metadata=obj.get("metadata") or {},
    )
    sink = attach_spans_from_answer(example, new_tokenizer, answer).sink_span
    if sink is not None and _is_token_span(sink):
        return list(sink)
    return None
def _span_mapped_obj(obj: Dict[str, Any], new_span: List[int], method: str) -> Dict[str, Any]:
    """Return a shallow copy of ``obj`` with every span field rewritten to ``new_span``.

    ``indices_to_explain`` and ``sink_span`` become ``new_span``. ``thinking_span``
    becomes the tokens before it (``None`` when the span starts at token 0).
    ``metadata`` is copied (a non-dict value becomes ``{}``) and tagged with
    ``exp5_span_map_method = method``.
    """
    mapped = dict(obj)
    mapped["indices_to_explain"] = new_span
    mapped["sink_span"] = new_span
    mapped["thinking_span"] = [0, new_span[0] - 1] if new_span[0] > 0 else None
    meta = mapped.get("metadata")
    meta = dict(meta) if isinstance(meta, dict) else {}
    meta["exp5_span_map_method"] = method
    mapped["metadata"] = meta
    return mapped


def _map_via_answer_fallback(
    obj: Dict[str, Any],
    target: str,
    new_tokenizer,
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Re-span ``obj`` from its answer text; return ``(mapped, None)`` or ``(None, reason)``."""
    new_span = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
    if new_span is None:
        return None, "fallback_answer_failed"
    if not _validate_span_with_eos(new_tokenizer, target, new_span):
        return None, "fallback_answer_span_invalid"
    return _span_mapped_obj(obj, new_span, "answer_text"), None


def _map_one_obj(
    obj: Dict[str, Any],
    *,
    old_tokenizer,
    new_tokenizer,
    allow_fallback_answer: bool,
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Remap one cached example's spans from ``old_tokenizer`` to ``new_tokenizer``.

    The stored token span is converted to a character span in ``target`` via
    the OLD tokenizer's offsets, then back to NEW token indices by overlap.
    When any alignment step fails and ``allow_fallback_answer`` is set, the span
    is recomputed from the answer text instead. ``metadata.exp5_span_map_method``
    records which path produced the span.

    Returns:
        ``(mapped_obj, None)`` on success, or ``(None, reason)`` where ``reason``
        is a short error code used for the per-dataset error histogram.
    """
    target = obj.get("target")
    if not isinstance(target, str) or not target:
        return None, "missing_target"
    old_span = _pick_old_span(obj)
    if old_span is None:
        return None, "missing_old_span"

    # 1) Old token span -> char span in target.
    old_enc = old_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
    char_span = _offsets_to_char_span(old_enc.get("offset_mapping"), old_span)
    if char_span is None:
        if not allow_fallback_answer:
            return None, "old_span_to_char_failed"
        return _map_via_answer_fallback(obj, target, new_tokenizer)

    # 2) Char span -> new token span.
    new_enc = new_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
    new_span = _char_span_to_token_span(new_enc.get("offset_mapping"), char_span)
    if new_span is None:
        if not allow_fallback_answer:
            return None, "char_to_new_span_failed"
        return _map_via_answer_fallback(obj, target, new_tokenizer)

    # 3) The aligned span must fit inside the NEW generation tokens (target + EOS).
    if not _validate_span_with_eos(new_tokenizer, target, new_span):
        if not allow_fallback_answer:
            return None, "new_span_invalid"
        return _map_via_answer_fallback(obj, target, new_tokenizer)

    return _span_mapped_obj(obj, new_span, "token_span_char_align"), None
def _read_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
with path.open("r", encoding="utf-8") as f:
for line_no, line in enumerate(f, start=1):
if not line.strip():
continue
try:
obj = json.loads(line)
except json.JSONDecodeError as exc: # pragma: no cover
raise RuntimeError(f"Invalid JSON at {path}:{line_no}: {exc}") from exc
if not isinstance(obj, dict):
raise RuntimeError(f"Expected JSON object per line at {path}:{line_no}.")
yield obj
def _write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> int:
path.parent.mkdir(parents=True, exist_ok=True)
count = 0
with path.open("w", encoding="utf-8") as f:
for obj in rows:
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
count += 1
return count
def _default_old_tokenizer() -> str:
# Repo defaults used in exp2 README examples for span extraction.
return "/opt/share/models/Qwen/Qwen3-8B"
def _default_new_tokenizer() -> str:
return "/opt/share/models/meta-llama/Llama-3.1-8B-Instruct"
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    ap = argparse.ArgumentParser("Map exp2 cache token spans from an old tokenizer to a new tokenizer.")
    ap.add_argument(
        "--in_jsonl",
        type=str,
        nargs="+",
        required=True,
        help="One or more exp2 cached JSONL files (comma-separated also accepted).",
    )
    ap.add_argument(
        "--out_dir",
        type=str,
        default="exp/exp5/data",
        help="Output directory for mapped JSONL files.",
    )
    ap.add_argument(
        "--old_tokenizer_model",
        type=str,
        default=_default_old_tokenizer(),
        help="Tokenizer used to produce the original token spans (default: Qwen3-8B local path).",
    )
    ap.add_argument(
        "--new_tokenizer_model",
        type=str,
        default=_default_new_tokenizer(),
        help="Tokenizer to map spans into (default: Llama-3.1-8B-Instruct local path).",
    )
    ap.add_argument("--strict", action="store_true", help="Fail on the first example that cannot be mapped.")
    ap.add_argument(
        "--allow_fallback_answer",
        action="store_true",
        help=(
            "If span alignment fails, try to recompute spans by locating metadata.boxed_answer in target "
            "(useful when caches were not built with the assumed old tokenizer)."
        ),
    )
    ap.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite output files if they already exist.",
    )
    return ap


def _map_file(
    in_path: Path,
    out_path: Path,
    *,
    old_tokenizer,
    new_tokenizer,
    strict: bool,
    allow_fallback_answer: bool,
) -> Dict[str, Any]:
    """Map every example in ``in_path`` and write the successfully mapped ones to ``out_path``.

    Returns:
        The per-dataset summary record (counts plus an error-reason histogram).

    Raises:
        SystemExit: in ``strict`` mode, on the first example that cannot be
            mapped. Nothing is written for this file in that case.
    """
    total = 0
    dropped = 0
    errors: Dict[str, int] = {}
    mapped_rows: List[Dict[str, Any]] = []
    for obj in _read_jsonl(in_path):
        total += 1
        mapped, err = _map_one_obj(
            obj,
            old_tokenizer=old_tokenizer,
            new_tokenizer=new_tokenizer,
            allow_fallback_answer=allow_fallback_answer,
        )
        if err is not None or mapped is None:
            reason = err or "unknown_error"
            errors[reason] = errors.get(reason, 0) + 1
            if strict:
                raise SystemExit(f"Failed to map {in_path} example #{total}: {err}")
            dropped += 1
            continue
        mapped_rows.append(mapped)
    mapped_ok = len(mapped_rows)
    # All rows are materialized before writing, so an output that resolves to
    # the input file itself (with --overwrite) is read fully before truncation.
    written = _write_jsonl(out_path, mapped_rows)
    if written != mapped_ok:  # pragma: no cover
        raise SystemExit(f"Internal error: written={written} != mapped_ok={mapped_ok}")
    return {
        "in_jsonl": str(in_path),
        "out_jsonl": str(out_path),
        "total": int(total),
        "mapped_ok": int(mapped_ok),
        "dropped": int(dropped),
        "errors": errors,
    }


def main() -> None:
    """CLI entry point: remap spans in each input JSONL and write it under ``--out_dir``.

    One JSON record per dataset is printed as it finishes, followed by an
    indented summary of all datasets.
    """
    args = _build_arg_parser().parse_args()
    in_paths = [Path(p) for p in _split_args(args.in_jsonl)]
    out_dir = Path(args.out_dir)

    # Validate every input/output pair before loading tokenizers or writing
    # anything, so a bad path late in the list cannot abort the run after
    # earlier outputs were already written.
    out_paths: List[Path] = []
    claimed: Dict[Path, Path] = {}
    for in_path in in_paths:
        if not in_path.exists():
            raise SystemExit(f"Missing input JSONL: {in_path}")
        out_path = out_dir / in_path.name
        # Outputs are keyed by basename only; two inputs with the same name
        # would otherwise silently overwrite each other's output.
        if out_path in claimed:
            raise SystemExit(
                f"Inputs {claimed[out_path]} and {in_path} would both be written to {out_path}"
            )
        if out_path.exists() and not bool(args.overwrite):
            raise SystemExit(f"Refusing to overwrite existing output: {out_path} (use --overwrite)")
        claimed[out_path] = in_path
        out_paths.append(out_path)

    old_tok = _load_tokenizer(str(args.old_tokenizer_model))
    new_tok = _load_tokenizer(str(args.new_tokenizer_model))
    for flag, tok in (("--old_tokenizer_model", old_tok), ("--new_tokenizer_model", new_tok)):
        # Span alignment needs return_offsets_mapping, which only fast tokenizers support.
        if not getattr(tok, "is_fast", True):
            raise SystemExit(f"{flag} must load a fast tokenizer (offset_mapping is required).")
    # exp2 convention: ensure a pad token exists for downstream perturbation.
    if new_tok.pad_token is None and new_tok.eos_token is not None:
        new_tok.pad_token = new_tok.eos_token

    summary: Dict[str, Any] = {
        "old_tokenizer_model": str(args.old_tokenizer_model),
        "new_tokenizer_model": str(args.new_tokenizer_model),
        "datasets": [],
    }
    for in_path, out_path in zip(in_paths, out_paths):
        record = _map_file(
            in_path,
            out_path,
            old_tokenizer=old_tok,
            new_tokenizer=new_tok,
            strict=bool(args.strict),
            allow_fallback_answer=bool(args.allow_fallback_answer),
        )
        summary["datasets"].append(record)
        print(json.dumps(record, ensure_ascii=False))
    # Final human-readable (indented) summary of all datasets.
    print(json.dumps(summary, ensure_ascii=False, indent=2))
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|