#!/usr/bin/env python3
# Viewer metadata carried with this copy: file size 8,878 bytes; id 55b60a8.
"""
Extract CoT/output segment attribution weights from exp3 trace artifacts.

Background
----------
exp/exp3/run_exp.py saves per-sample trace npz files that contain token-level
importance vectors over the FULL (prompt + generation) token sequence:
  - v_seq_all: sum over rows of seq attribution matrix (shape [P+G])
  - v_row_all: row attribution vector for indices_to_explain (shape [P+G])
  - v_rec_all: recursive attribution vector for indices_to_explain (shape [P+G])

For exp3 cached samples, we also have generation-token spans:
  - thinking_span_gen: CoT span [start, end] (both ends inclusive) in generation-token coordinates
  - sink_span_gen: output span [start, end] (both ends inclusive) in generation-token coordinates

This script slices v_*_all into:
  - cot: tokens in thinking_span_gen (offset by prompt_len)
  - output: tokens in sink_span_gen (offset by prompt_len)

and reports segment sums/fractions (and optionally writes a JSON summary).
"""

from __future__ import annotations

import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np


@dataclass(frozen=True)
class TracePaths:
    """Immutable record describing where one trace ``.npz`` file was found.

    Instances are built by ``_resolve_trace_paths`` after every path component
    has been checked for existence.
    """

    # Name of the dataset directory under ``<output_root>/traces``.
    dataset: str
    # Name of the model directory that was given or auto-detected.
    model_tag: str
    # Name of the run directory that was given or picked as the latest one.
    run_tag: str
    # Full path to the per-example ``ex_<idx>.npz`` trace file.
    npz_path: Path

def _pick_latest_subdir(path: Path) -> Optional[Path]:
    """Return the most recently modified immediate subdirectory of ``path``.

    Recency is judged by each directory's ``st_mtime``. ``None`` is returned
    when ``path`` does not exist or holds no subdirectories.
    """
    if path.exists():
        children = (entry for entry in path.iterdir() if entry.is_dir())
        # ``max`` keeps the first entry among equal mtimes, and ``default``
        # covers the "no subdirectories" case without a separate emptiness test.
        return max(children, key=lambda entry: entry.stat().st_mtime, default=None)
    return None


def _resolve_trace_paths(
    *,
    output_root: Path,
    dataset: str,
    model_tag: Optional[str],
    run_tag: Optional[str],
    example_idx: int,
) -> TracePaths:
    """Locate ``<output_root>/traces/<dataset>/<model>/<run>/ex_<idx>.npz``.

    Args:
        output_root: Directory that contains the ``traces`` folder.
        dataset: Name of the dataset directory under ``traces``.
        model_tag: Model directory name. When ``None`` the single model
            directory present is used.
        run_tag: Run directory name. When ``None`` the most recently modified
            run directory is used.
        example_idx: Example index, formatted as a zero-padded 6-digit number
            in the file name.

    Returns:
        A ``TracePaths`` carrying the resolved tags and the npz path.

    Raises:
        FileNotFoundError: If any directory or the npz file is missing.
        SystemExit: If ``model_tag`` is ``None`` and several model directories
            exist, so the choice would be ambiguous.
    """
    dataset_dir = output_root / "traces" / dataset
    if not dataset_dir.exists():
        raise FileNotFoundError(f"Trace dataset dir not found: {dataset_dir}")

    # Model directory: explicit tag wins, otherwise it must be unambiguous.
    if model_tag is not None:
        model_dir = dataset_dir / model_tag
        if not model_dir.exists():
            raise FileNotFoundError(f"Trace model dir not found: {model_dir}")
    else:
        candidates = [entry for entry in dataset_dir.iterdir() if entry.is_dir()]
        if not candidates:
            raise FileNotFoundError(f"No model subdir under: {dataset_dir}")
        if len(candidates) != 1:
            found = [entry.name for entry in candidates]
            raise SystemExit(f"Multiple model dirs under {dataset_dir}; pass --model_tag. Found: {found}")
        (model_dir,) = candidates
        model_tag = model_dir.name

    # Run directory: explicit tag wins, otherwise fall back to the newest one.
    if run_tag is not None:
        run_dir = model_dir / run_tag
        if not run_dir.exists():
            raise FileNotFoundError(f"Trace run dir not found: {run_dir}")
    else:
        run_dir = _pick_latest_subdir(model_dir)
        if run_dir is None:
            raise FileNotFoundError(f"No run subdir under: {model_dir}")
        run_tag = run_dir.name

    npz_path = run_dir / f"ex_{int(example_idx):06d}.npz"
    if not npz_path.exists():
        raise FileNotFoundError(f"Trace npz not found: {npz_path}")

    return TracePaths(dataset=dataset, model_tag=model_tag, run_tag=run_tag, npz_path=npz_path)


def _as_span(arr: Any) -> Optional[Tuple[int, int]]:
    """Coerce ``arr`` into a ``(start, end)`` pair of ints, or return ``None``.

    The input is flattened first, so any array-like holding exactly two values
    is accepted. ``None`` is returned when the input is ``None``, does not
    hold exactly two values, cannot be converted to ints, has a negative
    start, or has ``end < start``.
    """
    if arr is None:
        return None
    try:
        flat = np.asarray(arr).reshape(-1).tolist()
        if len(flat) != 2:
            return None
        start, end = int(flat[0]), int(flat[1])
    except Exception:
        # Deliberately best-effort: any malformed span is treated as absent.
        return None
    if 0 <= start <= end:
        return start, end
    return None


def _segment_stats(v: np.ndarray, start: int, end: int) -> Dict[str, float]:
    """Return ``sum``, ``mean`` and ``max`` of ``v[start : end + 1]``.

    ``end`` is inclusive. When ``end < start`` or the slice is empty (for
    example because ``start`` lies past the end of ``v``), all three
    statistics are reported as ``0.0``.
    """
    stat_names = ("sum", "mean", "max")
    zeros = dict.fromkeys(stat_names, 0.0)
    if start > end:
        return zeros
    window = v[start : end + 1]
    if not window.size:
        return zeros
    # Each statistic is the ndarray method of the same name.
    return {name: float(getattr(window, name)()) for name in stat_names}


def _slice_segment(v: np.ndarray, start: int, end: int) -> List[float]:
    """Return ``v[start : end + 1]`` as a list of Python floats.

    ``end`` is inclusive; an empty list is returned when ``end < start``.
    """
    if start <= end:
        return list(map(float, v[start : end + 1].tolist()))
    return []


def extract_one(npz_path: Path) -> Dict[str, Any]:
    """Summarise CoT and output attribution weights for one trace ``.npz`` file.

    The archive must contain ``prompt_len``, ``gen_len`` and the three
    importance vectors ``v_seq_all``, ``v_row_all`` and ``v_rec_all``, each of
    length ``prompt_len + gen_len``. ``sink_span_gen`` is required as well;
    ``thinking_span_gen`` is optional and, when absent or malformed, is
    inferred as every generation token before the sink span.

    Args:
        npz_path: Path to the trace archive.

    Returns:
        A dict with the lengths, the two generation-coordinate spans, and one
        entry per vector (``"seq"``, ``"row"``, ``"rec"``). Each entry holds
        the vector's total sum, per-segment statistics and fractions for
        ``"cot"`` and ``"output"``, and the raw per-token weights of both
        segments. Fractions are NaN when their denominator is not positive.
        An inferred thinking span that is empty (sink starts at generation
        token 0) is reported as ``[0, -1]`` with length 0.

    Raises:
        KeyError: If a required key is missing or ``sink_span_gen`` is absent
            or not a valid span.
        ValueError: If a vector's length differs from ``prompt_len + gen_len``.
    """
    vector_keys = ("v_seq_all", "v_row_all", "v_rec_all")
    required = ("prompt_len", "gen_len") + vector_keys

    # The object returned by np.load keeps the archive open until it is closed,
    # so everything that reads from it happens inside this ``with`` block.
    with np.load(npz_path) as d:
        for k in required:
            if k not in d:
                raise KeyError(f"Missing key in trace npz {npz_path}: {k}")

        prompt_len = int(np.asarray(d["prompt_len"]).item())
        gen_len = int(np.asarray(d["gen_len"]).item())
        total_len = prompt_len + gen_len

        vectors = {k: np.asarray(d[k], dtype=np.float64).reshape(-1) for k in vector_keys}
        for name, v in vectors.items():
            if int(v.size) != int(total_len):
                raise ValueError(f"{name} length mismatch: expected {total_len}, got {int(v.size)}")

        # Membership test instead of ``d.get``: it only relies on the same
        # ``in`` support already used for the required keys above.
        sink_raw = d["sink_span_gen"] if "sink_span_gen" in d else None
        thinking_raw = d["thinking_span_gen"] if "thinking_span_gen" in d else None

    sink_span_gen = _as_span(sink_raw)
    thinking_span_gen = _as_span(thinking_raw)
    if sink_span_gen is None:
        raise KeyError("Trace missing sink_span_gen; cannot define output span.")
    if thinking_span_gen is None:
        # Best-effort: infer thinking span as [0, sink_start-1]. When the sink
        # starts at token 0 this is the empty span [0, -1]; clamping the end to
        # 0 would wrongly count the first output token as CoT.
        thinking_span_gen = (0, sink_span_gen[0] - 1)

    think_start_g, think_end_g = thinking_span_gen
    sink_start_g, sink_end_g = sink_span_gen

    # Shift generation-coordinate spans into full-sequence coordinates and
    # clamp the inclusive ends to the last valid index.
    cot_start = prompt_len + think_start_g
    cot_end = min(prompt_len + think_end_g, total_len - 1)
    out_start = prompt_len + sink_start_g
    out_end = min(prompt_len + sink_end_g, total_len - 1)

    def pack(v: np.ndarray) -> Dict[str, Any]:
        """Build the per-vector summary for the CoT and output segments."""
        total = float(v.sum())
        cot = _segment_stats(v, cot_start, cot_end)
        out = _segment_stats(v, out_start, out_end)
        denom = cot["sum"] + out["sum"]

        def describe(stats: Dict[str, float], start: int, end: int) -> Dict[str, Any]:
            """Attach absolute bounds, length and fractions to segment stats."""
            return {
                "start_abs": int(start),
                "end_abs": int(end),
                "len": int(max(0, end - start + 1)),
                **stats,
                "fraction_of_total": float(stats["sum"] / total) if total > 0 else float("nan"),
                "fraction_of_cot_plus_output": float(stats["sum"] / denom) if denom > 0 else float("nan"),
            }

        return {
            "total_sum": total,
            "cot": describe(cot, cot_start, cot_end),
            "output": describe(out, out_start, out_end),
            "cot_weights": _slice_segment(v, cot_start, cot_end),
            "output_weights": _slice_segment(v, out_start, out_end),
        }

    return {
        "prompt_len": int(prompt_len),
        "gen_len": int(gen_len),
        "total_len": int(total_len),
        "thinking_span_gen": [int(think_start_g), int(think_end_g)],
        "sink_span_gen": [int(sink_start_g), int(sink_end_g)],
        "seq": pack(vectors["v_seq_all"]),
        "row": pack(vectors["v_row_all"]),
        "rec": pack(vectors["v_rec_all"]),
    }


def main() -> None:
    """Command-line entry point.

    Resolves the trace file of one example for both the ``<dataset_tag>_short_cot``
    and ``<dataset_tag>_long_cot`` datasets, extracts their segment summaries,
    and either prints the combined JSON or writes it to ``--out``.
    """
    cli = argparse.ArgumentParser("Extract CoT/output weights from exp3 traces.")
    cli.add_argument("--output_root", type=str, default="exp/exp3/output")
    cli.add_argument("--dataset_tag", type=str, default="niah_mq_q2")
    cli.add_argument("--model_tag", type=str, default=None, help="If omitted, auto-detect when unique.")
    cli.add_argument("--run_tag", type=str, default=None, help="If omitted, picks the latest run subdir.")
    cli.add_argument("--example_idx", type=int, default=0)
    cli.add_argument("--out", type=str, default=None, help="Optional JSON output path.")
    opts = cli.parse_args()

    root = Path(opts.output_root)
    summaries: List[Dict[str, Any]] = []
    for suffix in ("short_cot", "long_cot"):
        located = _resolve_trace_paths(
            output_root=root,
            dataset=f"{opts.dataset_tag}_{suffix}",
            model_tag=opts.model_tag,
            run_tag=opts.run_tag,
            example_idx=opts.example_idx,
        )
        record = extract_one(located.npz_path)
        # Tag each summary with where it came from.
        record.update(
            dataset=located.dataset,
            model_tag=located.model_tag,
            run_tag=located.run_tag,
            npz_path=str(located.npz_path),
        )
        summaries.append(record)

    payload = json.dumps(summaries, ensure_ascii=False, indent=2)
    if not opts.out:
        print(payload)
        return

    target = Path(opts.out)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(payload + "\n", encoding="utf-8")
    print(f"Wrote -> {target}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()