#!/usr/bin/env python3 """ Compute attribution mass on (input, cot, output) segments from exp3 trace npz files. Definitions (token-level, aligned with exp2/exp3 runners): - input : prompt-side tokens (user prompt), indices [0, prompt_len) - cot : generation tokens in thinking span, indices [prompt_len + t0, prompt_len + t1] - output : generation tokens in sink span (answer), indices [prompt_len + s0, prompt_len + s1] The trace stores token-importance vectors: - v_seq_all, v_row_all, v_rec_all (length = prompt_len + gen_len) This script sums those vectors over each segment and reports both absolute sums and fractions of the total sum. """ from __future__ import annotations import argparse import json from dataclasses import dataclass from pathlib import Path from typing import Dict, Iterable, List, Optional, Tuple import numpy as np @dataclass(frozen=True) class TraceRun: dataset: str model: str run_dir: Path def _pick_single_subdir(parent: Path) -> Path: subdirs = [p for p in parent.iterdir() if p.is_dir()] if not subdirs: raise FileNotFoundError(f"No subdirectories found under {parent}") if len(subdirs) == 1: return subdirs[0] subdirs.sort(key=lambda p: p.stat().st_mtime, reverse=True) return subdirs[0] def _resolve_run( trace_root: Path, *, dataset: str, model: Optional[str], run_tag: Optional[str], ) -> TraceRun: ds_dir = trace_root / dataset if not ds_dir.exists(): raise FileNotFoundError(f"Dataset trace directory not found: {ds_dir}") if model is None: model_dir = _pick_single_subdir(ds_dir) else: model_dir = ds_dir / model if not model_dir.exists(): raise FileNotFoundError(f"Model trace directory not found: {model_dir}") if run_tag is None: run_dir = _pick_single_subdir(model_dir) else: run_dir = model_dir / run_tag if not run_dir.exists(): raise FileNotFoundError(f"Run directory not found: {run_dir}") return TraceRun(dataset=dataset, model=model_dir.name, run_dir=run_dir) def _iter_manifest(run_dir: Path) -> Iterable[dict]: manifest = run_dir / "manifest.jsonl" if not manifest.exists(): raise FileNotFoundError(f"Missing manifest: {manifest}") with manifest.open("r", encoding="utf-8") as f: for line in f: line = line.strip() if line: yield json.loads(line) def _as_span(arr: np.ndarray, *, name: str) -> Tuple[int, int]: if arr is None: raise ValueError(f"Missing {name} in trace npz.") a = np.asarray(arr).reshape(-1) if a.size != 2: raise ValueError(f"Expected {name} to have 2 ints, got shape {a.shape}.") return int(a[0]), int(a[1]) def _segment_sums( v: np.ndarray, *, prompt_len: int, gen_len: int, thinking_span_gen: Optional[Tuple[int, int]], sink_span_gen: Optional[Tuple[int, int]], ) -> Dict[str, float]: total_len = int(prompt_len) + int(gen_len) if int(v.shape[0]) != total_len: raise ValueError(f"Vector length mismatch: len(v)={int(v.shape[0])} vs prompt_len+gen_len={total_len}.") v = np.asarray(v, dtype=np.float64).reshape(-1) prompt_len = int(prompt_len) gen_len = int(gen_len) # Default: no cot/output when spans missing (should not happen in exp3). think_start, think_end = (0, -1) if thinking_span_gen is None else thinking_span_gen sink_start, sink_end = (0, -1) if sink_span_gen is None else sink_span_gen # Clamp spans into [0, gen_len-1]. def _clamp_span(a: int, b: int) -> Tuple[int, int]: a = max(0, min(int(a), gen_len - 1)) b = max(0, min(int(b), gen_len - 1)) if b < a: return 0, -1 return a, b think_start, think_end = _clamp_span(think_start, think_end) sink_start, sink_end = _clamp_span(sink_start, sink_end) mask = np.zeros((total_len,), dtype=bool) # input = all prompt tokens input_slice = slice(0, prompt_len) mask[input_slice] = True cot_slice = slice(prompt_len + think_start, prompt_len + think_end + 1) if think_end >= think_start else slice(0, 0) output_slice = slice(prompt_len + sink_start, prompt_len + sink_end + 1) if sink_end >= sink_start else slice(0, 0) mask[cot_slice] = True mask[output_slice] = True input_sum = float(v[input_slice].sum()) cot_sum = float(v[cot_slice].sum()) if think_end >= think_start else 0.0 output_sum = float(v[output_slice].sum()) if sink_end >= sink_start else 0.0 other_sum = float(v[~mask].sum()) total_sum = float(v.sum()) return { "total": total_sum, "input": input_sum, "cot": cot_sum, "output": output_sum, "other": other_sum, } def _with_fracs(sums: Dict[str, float]) -> Dict[str, float]: total = float(sums.get("total") or 0.0) if total <= 0.0: return {**sums, "input_frac": float("nan"), "cot_frac": float("nan"), "output_frac": float("nan"), "other_frac": float("nan")} return { **sums, "input_frac": float(sums["input"]) / total, "cot_frac": float(sums["cot"]) / total, "output_frac": float(sums["output"]) / total, "other_frac": float(sums["other"]) / total, } def _analyze_npz(npz_path: Path) -> Dict[str, dict]: d = np.load(npz_path) prompt_len = int(np.asarray(d["prompt_len"]).item()) gen_len = int(np.asarray(d["gen_len"]).item()) thinking_span_gen = _as_span(d["thinking_span_gen"], name="thinking_span_gen") if "thinking_span_gen" in d.files else None sink_span_gen = _as_span(d["sink_span_gen"], name="sink_span_gen") if "sink_span_gen" in d.files else None out: Dict[str, dict] = {"prompt_len": prompt_len, "gen_len": gen_len} for key in ("v_seq_all", "v_row_all", "v_rec_all"): if key not in d.files: raise ValueError(f"Missing {key} in trace npz: {npz_path}") sums = _segment_sums( d[key], prompt_len=prompt_len, gen_len=gen_len, thinking_span_gen=thinking_span_gen, sink_span_gen=sink_span_gen, ) out[key] = _with_fracs(sums) out["thinking_span_gen"] = list(thinking_span_gen) if thinking_span_gen is not None else None out["sink_span_gen"] = list(sink_span_gen) if sink_span_gen is not None else None return out def main() -> None: parser = argparse.ArgumentParser("Summarize input/cot/output attribution mass from exp3 traces.") parser.add_argument("--trace_root", type=str, default="exp/exp3/output/traces") parser.add_argument("--dataset_tag", type=str, default="niah_mq_q2", help="Base tag; expands to _short_cot and _long_cot.") parser.add_argument("--datasets", type=str, default=None, help="Comma-separated dataset names (overrides --dataset_tag expansion).") parser.add_argument("--model", type=str, default=None, help="Model directory name under traces (default: auto if single).") parser.add_argument("--run_tag", type=str, default=None, help="Run tag directory (default: auto pick newest/single).") args = parser.parse_args() trace_root = Path(args.trace_root) if not trace_root.exists(): raise SystemExit(f"trace_root not found: {trace_root}") if args.datasets: datasets = [x.strip() for x in str(args.datasets).split(",") if x.strip()] else: datasets = [f"{args.dataset_tag}_short_cot", f"{args.dataset_tag}_long_cot"] for ds in datasets: run = _resolve_run(trace_root, dataset=ds, model=args.model, run_tag=args.run_tag) records = list(_iter_manifest(run.run_dir)) if not records: raise SystemExit(f"Empty manifest: {run.run_dir/'manifest.jsonl'}") for rec in records: npz_path = run.run_dir / str(rec["file"]) analysis = _analyze_npz(npz_path) print( json.dumps( { "dataset": run.dataset, "model": run.model, "run_dir": str(run.run_dir), "example_idx": int(rec.get("example_idx", -1)), **analysis, }, ensure_ascii=False, ) ) if __name__ == "__main__": main()