| |
| """ |
| Extract CoT/output segment attribution weights from exp3 trace artifacts. |
| |
| Background |
| ---------- |
| exp/exp3/run_exp.py saves per-sample trace npz files that contain token-level |
| importance vectors over the FULL (prompt + generation) token sequence: |
| - v_seq_all: sum over rows of seq attribution matrix (shape [P+G]) |
| - v_row_all: row attribution vector for indices_to_explain (shape [P+G]) |
| - v_rec_all: recursive attribution vector for indices_to_explain (shape [P+G]) |
| |
| For exp3 cached samples, we also have generation-token spans: |
| - thinking_span_gen: CoT span [start,end] in generation-token coordinates |
| - sink_span_gen: output span [start,end] in generation-token coordinates |
| |
| This script slices v_*_all into: |
| - cot: tokens in thinking_span_gen (offset by prompt_len) |
| - output: tokens in sink_span_gen (offset by prompt_len) |
| |
| and reports segment sums/fractions (and optionally writes a JSON summary). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import numpy as np |
|
|
|
|
| @dataclass(frozen=True) |
| class TracePaths: |
| dataset: str |
| model_tag: str |
| run_tag: str |
| npz_path: Path |
|
|
|
|
| def _pick_latest_subdir(path: Path) -> Optional[Path]: |
| if not path.exists(): |
| return None |
| subs = [p for p in path.iterdir() if p.is_dir()] |
| if not subs: |
| return None |
| subs.sort(key=lambda p: p.stat().st_mtime, reverse=True) |
| return subs[0] |
|
|
|
|
| def _resolve_trace_paths( |
| *, |
| output_root: Path, |
| dataset: str, |
| model_tag: Optional[str], |
| run_tag: Optional[str], |
| example_idx: int, |
| ) -> TracePaths: |
| base = output_root / "traces" / dataset |
| if not base.exists(): |
| raise FileNotFoundError(f"Trace dataset dir not found: {base}") |
|
|
| if model_tag is None: |
| model_dirs = [p for p in base.iterdir() if p.is_dir()] |
| if not model_dirs: |
| raise FileNotFoundError(f"No model subdir under: {base}") |
| if len(model_dirs) != 1: |
| raise SystemExit(f"Multiple model dirs under {base}; pass --model_tag. Found: {[p.name for p in model_dirs]}") |
| model_dir = model_dirs[0] |
| model_tag = model_dir.name |
| else: |
| model_dir = base / model_tag |
| if not model_dir.exists(): |
| raise FileNotFoundError(f"Trace model dir not found: {model_dir}") |
|
|
| if run_tag is None: |
| run_dir = _pick_latest_subdir(model_dir) |
| if run_dir is None: |
| raise FileNotFoundError(f"No run subdir under: {model_dir}") |
| run_tag = run_dir.name |
| else: |
| run_dir = model_dir / run_tag |
| if not run_dir.exists(): |
| raise FileNotFoundError(f"Trace run dir not found: {run_dir}") |
|
|
| npz_name = f"ex_{int(example_idx):06d}.npz" |
| npz_path = run_dir / npz_name |
| if not npz_path.exists(): |
| raise FileNotFoundError(f"Trace npz not found: {npz_path}") |
|
|
| return TracePaths(dataset=dataset, model_tag=model_tag, run_tag=run_tag, npz_path=npz_path) |
|
|
|
|
| def _as_span(arr: Any) -> Optional[Tuple[int, int]]: |
| if arr is None: |
| return None |
| try: |
| a = np.asarray(arr).reshape(-1).tolist() |
| except Exception: |
| return None |
| if len(a) != 2: |
| return None |
| try: |
| start = int(a[0]) |
| end = int(a[1]) |
| except Exception: |
| return None |
| if start < 0 or end < start: |
| return None |
| return start, end |
|
|
|
|
| def _segment_stats(v: np.ndarray, start: int, end: int) -> Dict[str, float]: |
| if end < start: |
| return {"sum": 0.0, "mean": 0.0, "max": 0.0} |
| seg = v[start : end + 1] |
| if seg.size == 0: |
| return {"sum": 0.0, "mean": 0.0, "max": 0.0} |
| return { |
| "sum": float(seg.sum()), |
| "mean": float(seg.mean()), |
| "max": float(seg.max()), |
| } |
|
|
|
|
| def _slice_segment(v: np.ndarray, start: int, end: int) -> List[float]: |
| if end < start: |
| return [] |
| seg = v[start : end + 1] |
| return [float(x) for x in seg.tolist()] |
|
|
|
|
| def extract_one(npz_path: Path) -> Dict[str, Any]: |
| d = np.load(npz_path) |
| required = ["prompt_len", "gen_len", "v_seq_all", "v_row_all", "v_rec_all"] |
| for k in required: |
| if k not in d: |
| raise KeyError(f"Missing key in trace npz {npz_path}: {k}") |
|
|
| prompt_len = int(np.asarray(d["prompt_len"]).item()) |
| gen_len = int(np.asarray(d["gen_len"]).item()) |
| total_len = prompt_len + gen_len |
|
|
| v_seq_all = np.asarray(d["v_seq_all"], dtype=np.float64).reshape(-1) |
| v_row_all = np.asarray(d["v_row_all"], dtype=np.float64).reshape(-1) |
| v_rec_all = np.asarray(d["v_rec_all"], dtype=np.float64).reshape(-1) |
| for name, v in [("v_seq_all", v_seq_all), ("v_row_all", v_row_all), ("v_rec_all", v_rec_all)]: |
| if int(v.size) != int(total_len): |
| raise ValueError(f"{name} length mismatch: expected {total_len}, got {int(v.size)}") |
|
|
| sink_span_gen = _as_span(d.get("sink_span_gen")) |
| thinking_span_gen = _as_span(d.get("thinking_span_gen")) |
| if sink_span_gen is None: |
| raise KeyError("Trace missing sink_span_gen; cannot define output span.") |
| if thinking_span_gen is None: |
| |
| sink_start, _ = sink_span_gen |
| thinking_span_gen = (0, max(0, sink_start - 1)) |
|
|
| think_start_g, think_end_g = thinking_span_gen |
| sink_start_g, sink_end_g = sink_span_gen |
|
|
| cot_start = prompt_len + think_start_g |
| cot_end = min(prompt_len + think_end_g, total_len - 1) |
| out_start = prompt_len + sink_start_g |
| out_end = min(prompt_len + sink_end_g, total_len - 1) |
|
|
| def pack(v: np.ndarray) -> Dict[str, Any]: |
| total = float(v.sum()) |
| cot = _segment_stats(v, cot_start, cot_end) |
| out = _segment_stats(v, out_start, out_end) |
| denom = cot["sum"] + out["sum"] |
| return { |
| "total_sum": total, |
| "cot": { |
| "start_abs": int(cot_start), |
| "end_abs": int(cot_end), |
| "len": int(max(0, cot_end - cot_start + 1)), |
| **cot, |
| "fraction_of_total": float(cot["sum"] / total) if total > 0 else float("nan"), |
| "fraction_of_cot_plus_output": float(cot["sum"] / denom) if denom > 0 else float("nan"), |
| }, |
| "output": { |
| "start_abs": int(out_start), |
| "end_abs": int(out_end), |
| "len": int(max(0, out_end - out_start + 1)), |
| **out, |
| "fraction_of_total": float(out["sum"] / total) if total > 0 else float("nan"), |
| "fraction_of_cot_plus_output": float(out["sum"] / denom) if denom > 0 else float("nan"), |
| }, |
| "cot_weights": _slice_segment(v, cot_start, cot_end), |
| "output_weights": _slice_segment(v, out_start, out_end), |
| } |
|
|
| return { |
| "prompt_len": int(prompt_len), |
| "gen_len": int(gen_len), |
| "total_len": int(total_len), |
| "thinking_span_gen": [int(think_start_g), int(think_end_g)], |
| "sink_span_gen": [int(sink_start_g), int(sink_end_g)], |
| "seq": pack(v_seq_all), |
| "row": pack(v_row_all), |
| "rec": pack(v_rec_all), |
| } |
|
|
|
|
| def main() -> None: |
| parser = argparse.ArgumentParser("Extract CoT/output weights from exp3 traces.") |
| parser.add_argument("--output_root", type=str, default="exp/exp3/output") |
| parser.add_argument("--dataset_tag", type=str, default="niah_mq_q2") |
| parser.add_argument("--model_tag", type=str, default=None, help="If omitted, auto-detect when unique.") |
| parser.add_argument("--run_tag", type=str, default=None, help="If omitted, picks the latest run subdir.") |
| parser.add_argument("--example_idx", type=int, default=0) |
| parser.add_argument("--out", type=str, default=None, help="Optional JSON output path.") |
| args = parser.parse_args() |
|
|
| output_root = Path(args.output_root) |
| datasets = [f"{args.dataset_tag}_short_cot", f"{args.dataset_tag}_long_cot"] |
|
|
| results: List[Dict[str, Any]] = [] |
| for ds_name in datasets: |
| paths = _resolve_trace_paths( |
| output_root=output_root, |
| dataset=ds_name, |
| model_tag=args.model_tag, |
| run_tag=args.run_tag, |
| example_idx=args.example_idx, |
| ) |
| out = extract_one(paths.npz_path) |
| out["dataset"] = paths.dataset |
| out["model_tag"] = paths.model_tag |
| out["run_tag"] = paths.run_tag |
| out["npz_path"] = str(paths.npz_path) |
| results.append(out) |
|
|
| text = json.dumps(results, ensure_ascii=False, indent=2) |
| if args.out: |
| out_path = Path(args.out) |
| out_path.parent.mkdir(parents=True, exist_ok=True) |
| out_path.write_text(text + "\n", encoding="utf-8") |
| print(f"Wrote -> {out_path}") |
| else: |
| print(text) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|
|
|