wenbopan committed
Commit 55b60a8 · 1 Parent(s): 12ad26e

Sync FlashTrace package from GitHub

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +6 -35
  2. .gitignore +33 -0
  3. .python-version +1 -0
  4. .vscode/launch.json +25 -0
  5. LICENSE +21 -0
  6. MANIFEST.in +5 -0
  7. README.md +293 -0
  8. attribution_datasets.py +265 -0
  9. docs/superpowers/plans/2026-05-03-flashtrace-public-package.md +1605 -0
  10. docs/superpowers/specs/2026-05-03-flashtrace-public-package-design.md +231 -0
  11. dump_exp2_hop_vh.py +412 -0
  12. evaluations/attribution_recovery.py +490 -0
  13. evaluations/attribution_recovery.sh +18 -0
  14. evaluations/faithfulness.py +491 -0
  15. evaluations/faithfulness.sh +80 -0
  16. example.ipynb +0 -0
  17. examples/quickstart.py +44 -0
  18. exp/case_study/README.md +152 -0
  19. exp/case_study/analysis.py +74 -0
  20. exp/case_study/faithfulness_trace.py +183 -0
  21. exp/case_study/run_ifr_case.py +1225 -0
  22. exp/case_study/run_mas_case.py +805 -0
  23. exp/case_study/viz.py +647 -0
  24. exp/exp1/README.md +46 -0
  25. exp/exp1/run_time_curve.py +757 -0
  26. exp/exp2/DATASETS.md +231 -0
  27. exp/exp2/README.md +106 -0
  28. exp/exp2/dataset_utils.py +386 -0
  29. exp/exp2/map_math_mine_to_exp2_cache.py +584 -0
  30. exp/exp2/migrate_indices_to_explain_token_span.py +129 -0
  31. exp/exp2/out.log +102 -0
  32. exp/exp2/run_exp.py +1296 -0
  33. exp/exp2/sample_and_filter.py +363 -0
  34. exp/exp3/README.md +50 -0
  35. exp/exp3/extract_segment_weights.py +250 -0
  36. exp/exp3/part_weights.py +228 -0
  37. exp/exp3/run_exp.py +430 -0
  38. exp/exp3/sample_and_filter.py +628 -0
  39. exp/exp4/README.md +85 -0
  40. exp/exp4/run_exp.py +487 -0
  41. exp/exp5/README.md +119 -0
  42. exp/exp5/map_exp2_cache_token_spans.py +407 -0
  43. exp/proc/README.md +98 -0
  44. exp/proc/map_exp2_traces_to_proc.py +411 -0
  45. exp/proc_1/README.md +72 -0
  46. exp/proc_1/map_exp2_traces_to_proc_1.py +338 -0
  47. flashtrace/__init__.py +7 -0
  48. flashtrace/attribution.py +0 -0
  49. flashtrace/baselines/__init__.py +5 -0
  50. flashtrace/baselines/attnlrp.py +12 -0
.gitattributes CHANGED
@@ -1,35 +1,6 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Keep language stats focused on the public Python package.
2
+ *.ipynb linguist-vendored
3
+ *.html linguist-generated
4
+ exp/** linguist-vendored
5
+ evaluations/** linguist-vendored
6
+ docs/** linguist-documentation
 
.gitignore ADDED
@@ -0,0 +1,33 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ # Local data
13
+ data/
14
+
15
+ # dev
16
+ AGENTS.md
17
+ readme_dev.md
18
+ .superpowers/
19
+ contribute/ruler/
20
+ repos/.DS_Store
21
+ repomix-output.xml
22
+
23
+ # FlashTrace generated artifacts
24
+ trace.json
25
+ trace.html
26
+ *.trace.json
27
+ *.trace.html
28
+ exp/**/output/
29
+ exp/**/out/
30
+ exp/**/out-*/
31
+ *.npz
32
+ .DS_Store
33
+ repomix-output.xml
.python-version ADDED
@@ -0,0 +1 @@
1
+ 3.13
.vscode/launch.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "version": "0.2.0",
3
+ "configurations": [
4
+ {
5
+ "name": "Faithfulness Eval",
6
+ "type": "debugpy",
7
+ "request": "launch",
8
+ "module": "evaluations.faithfulness",
9
+ "args": [
10
+ "--model",
11
+ "qwen-4B",
12
+ "--cuda_num",
13
+ "1",
14
+ "--num_examples",
15
+ "500",
16
+ "--attr_func",
17
+ "IG",
18
+ "--dataset",
19
+ "facts"
20
+ ],
21
+ "console": "integratedTerminal",
22
+ "justMyCode": true
23
+ }
24
+ ]
25
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Wenbo Pan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MANIFEST.in ADDED
@@ -0,0 +1,5 @@
1
+ include README.md
2
+ include LICENSE
3
+ include pyproject.toml
4
+ include examples/*.py
5
+ include tests/*.py
README.md ADDED
@@ -0,0 +1,293 @@
1
+ <p align="center">
2
+ <img src="https://raw.githubusercontent.com/wbopan/flashtrace/master/docs/assets/flashtrace-logo.png" alt="FlashTrace logo" width="160">
3
+ </p>
4
+
5
+ <h1 align="center">FlashTrace</h1>
6
+
7
+ <p align="center">
8
+ <em>Fast token attribution for reasoning language models.</em>
9
+ </p>
10
+
11
+ <p align="center">
12
+ <a href="https://pypi.org/project/flashtrace/"><img alt="PyPI" src="https://img.shields.io/pypi/v/flashtrace.svg?style=flat-square&logo=pypi&logoColor=white&label=PyPI"></a>
13
+ <a href="https://pypi.org/project/flashtrace/"><img alt="Python" src="https://img.shields.io/pypi/pyversions/flashtrace.svg?style=flat-square&logo=python&logoColor=white"></a>
14
+ <a href="https://github.com/wbopan/flashtrace/blob/master/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT-blue.svg?style=flat-square"></a>
15
+ <a href="https://pytorch.org/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-2.5%2B-EE4C2C.svg?style=flat-square&logo=pytorch&logoColor=white"></a>
16
+ <a href="https://arxiv.org/abs/2602.01914"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2602.01914-B31B1B.svg?style=flat-square&logo=arxiv&logoColor=white"></a>
17
+ </p>
18
+
19
+ FlashTrace traces generated answers back to the prompt tokens that shaped them. Use it from Python or the command line, export JSON traces, and render standalone HTML heatmaps for inspection and sharing.
20
+
21
+ <p align="center">
22
+ <a href="https://arxiv.org/abs/2602.01914">📄 Paper</a>
23
+ &nbsp;·&nbsp;
24
+ <a href="#quickstart">🚀 Quickstart</a>
25
+ &nbsp;·&nbsp;
26
+ <a href="#command-line">💻 CLI</a>
27
+ &nbsp;·&nbsp;
28
+ <a href="#citation">📝 Citation</a>
29
+ </p>
30
+
31
+ ## Why FlashTrace
32
+
33
+ Reasoning models produce long generated chains, final answers, and intermediate spans that deserve targeted inspection. FlashTrace gives researchers a package-first workflow for tracing a selected generated span back to its supporting prompt tokens.
34
+
35
+ You get:
36
+
37
+ - top-k prompt tokens ranked by attribution score
38
+ - JSON traces for downstream analysis
39
+ - standalone HTML token heatmaps
40
+ - optional per-hop attribution panels
41
+ - inclusive generation-token span controls for answer and reasoning segments
42
+
43
+ ## Install
44
+
45
+ From PyPI:
46
+
47
+ ```bash
48
+ pip install flashtrace
49
+ ```
50
+
51
+ From a local checkout:
52
+
53
+ ```bash
54
+ pip install -e .
55
+ ```
56
+
57
+ For development:
58
+
59
+ ```bash
60
+ pip install -e ".[dev]"
61
+ ```
62
+
63
+ FlashTrace uses PyTorch, Transformers, Accelerate, NumPy, and tqdm. A CUDA-capable GPU is recommended for public-scale Hugging Face models.
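+
+ A quick way to confirm the installation resolved correctly (assuming only the public exports listed in the Python API section below):
+
+ ```python
+ # Minimal post-install sanity check: import the public API surface.
+ import flashtrace
+
+ print(flashtrace.FlashTrace, flashtrace.TraceResult, flashtrace.load_model_and_tokenizer)
+ ```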
64
+
65
+ ## Quickstart
66
+
67
+ ```python
68
+ from flashtrace import FlashTrace, load_model_and_tokenizer
69
+
70
+ prompt = """Context: Paris is the capital of France.
71
+ Question: What is the capital of France?"""
72
+ target = "Paris"
73
+
74
+ model, tokenizer = load_model_and_tokenizer("Qwen/Qwen3-8B", device_map="auto")
75
+ tracer = FlashTrace(model, tokenizer, chunk_tokens=128, sink_chunk_tokens=32)
76
+
77
+ trace = tracer.trace(
78
+ prompt=prompt,
79
+ target=target,
80
+ output_span=(0, 0),
81
+ hops=1,
82
+ )
83
+
84
+ print(trace.topk_inputs(10))
85
+ trace.to_json("trace.json")
86
+ trace.to_html("trace.html")
87
+ ```
88
+
89
+ `trace.topk_inputs(10)` returns `TokenScore` objects aligned to prompt-token indices:
90
+
91
+ ```text
92
+ rank index token score
93
+ 1 2 Paris 0.184
94
+ 2 7 capital 0.131
95
+ 3 10 France 0.119
96
+ ```
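+
+ Each `TokenScore` carries the `index`, `token`, and `score` shown in the columns above, so the ranking can be consumed directly in Python rather than read from the printed table:
+
+ ```python
+ # Sketch: iterate the ranked TokenScore entries programmatically.
+ for item in trace.topk_inputs(3):
+     print(item.index, repr(item.token), round(item.score, 3))
+ ```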
97
+
98
+ `trace.html` is a standalone heatmap that highlights prompt tokens by final attribution score and includes trace metadata for the selected generated span.
99
+
100
+ `FlashTrace(..., use_chat_template=True)` formats prompts with the tokenizer chat template for chat-tuned models.
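+
+ A minimal sketch of the chat-template path, reusing the quickstart objects (whether a given chat model needs further generation settings is model-dependent):
+
+ ```python
+ # Format the prompt with the tokenizer's chat template before tracing.
+ tracer = FlashTrace(model, tokenizer, use_chat_template=True)
+ trace = tracer.trace(prompt=prompt, target=target, output_span=(0, 0))
+ print(trace.topk_inputs(5))
+ ```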
101
+
102
+ ## Command Line
103
+
104
+ Create prompt and target files:
105
+
106
+ ```bash
107
+ printf "Context: Paris is the capital of France.\nQuestion: What is the capital of France?\n" > prompt.txt
108
+ printf "Paris" > target.txt
109
+ ```
110
+
111
+ Run a trace:
112
+
113
+ ```bash
114
+ flashtrace trace \
115
+ --model Qwen/Qwen3-8B \
116
+ --prompt prompt.txt \
117
+ --target target.txt \
118
+ --output-span 0:0 \
119
+ --hops 1 \
120
+ --html trace.html \
121
+ --json trace.json
122
+ ```
123
+
124
+ The command prints a compact top-k table and writes the requested artifacts.
125
+
126
+ Useful flags:
127
+
128
+ - `--model`: Hugging Face model id or local model path
129
+ - `--prompt`: UTF-8 prompt text file
130
+ - `--target`: UTF-8 target text file
131
+ - `--output-span`: inclusive `START:END` indices over generated tokens
132
+ - `--reasoning-span`: inclusive `START:END` indices for a reasoning segment
133
+ - `--method`: `flashtrace`, `ifr-span`, or `ifr-matrix`
134
+ - `--recompute-attention`: lower-memory attention recomputation path
135
+ - `--use-chat-template`: format prompts with the tokenizer chat template
136
+ - `--device-map`: Transformers device map, default `auto`
137
+ - `--dtype`: `auto`, `float16`, `bfloat16`, or `float32`
138
+
139
+ ## Token Spans
140
+
141
+ `output_span` and `reasoning_span` use inclusive generation-token indices. The first generated token has index `0`.
142
+
143
+ Use an initial trace to inspect tokenization:
144
+
145
+ ```python
146
+ for index, token in enumerate(trace.generation_tokens):
147
+     print(index, repr(token))
148
+ ```
149
+
150
+ Then choose spans:
151
+
152
+ ```python
153
+ trace = tracer.trace(
154
+ prompt=prompt,
155
+ target=target,
156
+ reasoning_span=(0, 79),
157
+ output_span=(80, 85),
158
+ hops=1,
159
+ )
160
+ ```
161
+
162
+ Scores are aligned to `trace.prompt_tokens`. `trace.per_hop_scores` stores the same prompt-token alignment for each hop.
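+
+ A short sketch of that alignment, assuming the spans chosen above:
+
+ ```python
+ # Final scores line up one-to-one with prompt tokens...
+ for token, score in zip(trace.prompt_tokens, trace.scores):
+     print(f"{score:8.4f} {token!r}")
+
+ # ...and each per-hop panel uses the same prompt-token indexing.
+ for hop, hop_scores in enumerate(trace.per_hop_scores):
+     assert len(hop_scores) == len(trace.prompt_tokens)
+ ```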
163
+
164
+ ## Interpreting Results
165
+
166
+ High-scoring prompt tokens are the tokens FlashTrace attributes most strongly to the selected generated span. For answer inspection, use `output_span` around the final answer tokens. For chain-of-thought or reasoning inspection, use `reasoning_span` around the generated reasoning segment.
167
+
168
+ Recommended workflow:
169
+
170
+ 1. Run a trace with your prompt and target.
171
+ 2. Inspect `trace.generation_tokens`.
172
+ 3. Select the answer or reasoning span.
173
+ 4. Export `trace.html`.
174
+ 5. Compare top-k tokens with the source prompt and any expected evidence.
175
+
176
+ ## Supported Models
177
+
178
+ FlashTrace targets Llama/Qwen-style decoder-only Hugging Face causal LMs with the following structure (a rough compatibility check is sketched after this list):
179
+
180
+ - `model.layers`
181
+ - Q/K/V/O attention projections
182
+ - RMSNorm or LayerNorm
183
+ - RoPE metadata
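+
+ The sketch below is an illustration only; it assumes the standard `model.model.layers[i].self_attn.{q,k,v,o}_proj` layout of the Transformers Llama/Qwen implementations and is not an official compatibility API:
+
+ ```python
+ # Rough structural probe for a loaded causal LM.
+ layer = model.model.layers[0]
+ for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
+     assert hasattr(layer.self_attn, name), f"missing attention projection: {name}"
+ ```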
184
+
185
+ Validated model families for the first public release:
186
+
187
+ - Qwen2
188
+ - Qwen3
189
+ - Llama
190
+
191
+ ## Python API
192
+
193
+ The public package exports:
194
+
195
+ ```python
196
+ from flashtrace import FlashTrace, TraceResult, load_model_and_tokenizer
197
+ ```
198
+
199
+ `FlashTrace.trace(...)` accepts the following keyword arguments (an illustrative call follows the list):
200
+
201
+ - `prompt: str`
202
+ - `target: str | None`
203
+ - `output_span: tuple[int, int] | None`
204
+ - `reasoning_span: tuple[int, int] | None`
205
+ - `hops: int`
206
+ - `method: "flashtrace" | "ifr-span" | "ifr-matrix"`
207
+ - `renorm_threshold: float | None`
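+
+ For illustration, a call that exercises most of these arguments (the span and hop values are placeholders):
+
+ ```python
+ trace = tracer.trace(
+     prompt=prompt,
+     target=target,
+     output_span=(80, 85),
+     reasoning_span=(0, 79),
+     hops=2,
+     method="ifr-span",
+ )
+ ```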
208
+
209
+ `TraceResult` includes:
210
+
211
+ - `prompt_tokens`
212
+ - `generation_tokens`
213
+ - `scores`
214
+ - `per_hop_scores`
215
+ - `thinking_ratios`
216
+ - `output_span`
217
+ - `reasoning_span`
218
+ - `method`
219
+ - `metadata`
220
+
221
+ Export helpers:
222
+
223
+ ```python
224
+ trace.topk_inputs(20)
225
+ trace.to_dict()
226
+ trace.to_json("trace.json")
227
+ trace.to_html("trace.html")
228
+ ```
229
+
230
+ ## Examples
231
+
232
+ ```bash
233
+ python examples/quickstart.py --help
234
+ python examples/quickstart.py \
235
+ --model Qwen/Qwen3-8B \
236
+ --prompt "Context: Paris is the capital of France. Question: What is the capital of France?" \
237
+ --target "Paris" \
238
+ --output-span 0:0 \
239
+ --html trace.html
240
+ ```
241
+
242
+ Heavy model examples are intended for GPU environments. CPU smoke tests use tiny randomly initialized models.
243
+
244
+ ## Repository Map
245
+
246
+ - `flashtrace/`: reusable Python package
247
+ - `examples/`: public quickstarts
248
+ - `tests/`: CPU smoke tests
249
+ - `exp/`: paper experiments and research artifacts
250
+ - `docs/superpowers/`: design and implementation planning documents
251
+
252
+ ## Research Experiments
253
+
254
+ The `exp/` directory contains the paper-era experiment runners, case studies, and saved artifacts. The public package API lives in `flashtrace/`; experiment scripts keep compatibility imports during the package migration.
255
+
256
+ ## Troubleshooting
257
+
258
+ **CUDA memory**
259
+
260
+ Use smaller models, lower precision, `device_map="auto"`, shorter prompts, or `--recompute-attention`.
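+
+ One possible lower-memory configuration from Python (chunk sizes are illustrative; the CLI equivalent of the last option is `--recompute-attention`):
+
+ ```python
+ model, tokenizer = load_model_and_tokenizer("Qwen/Qwen3-8B", device_map="auto")
+ tracer = FlashTrace(
+     model,
+     tokenizer,
+     chunk_tokens=64,           # smaller chunks trade speed for memory
+     sink_chunk_tokens=16,
+     recompute_attention=True,  # recompute attention instead of storing it
+ )
+ ```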
261
+
262
+ **Span selection**
263
+
264
+ Print `trace.generation_tokens` and select inclusive generated-token indices. Tokenization can split visible words into multiple model tokens.
265
+
266
+ **Deterministic generation**
267
+
268
+ Pass a `target` file for attribution against a known output. Leave `--target` out when you want the CLI to generate with deterministic defaults.
269
+
270
+ **Tokenizer alignment**
271
+
272
+ Inspect `trace.prompt_tokens` and `trace.generation_tokens` when scores appear shifted from visible text. Attribution scores follow tokenizer-level alignment.
273
+
274
+ **HTML export**
275
+
276
+ `trace.to_html("trace.html")` writes a standalone file that can be opened locally or shared as an artifact.
277
+
278
+ ## Paper
279
+
280
+ FlashTrace implements the method described in [Towards Long-Horizon Interpretability: Efficient and Faithful Multi-Token Attribution for Reasoning LLMs](https://arxiv.org/abs/2602.01914).
281
+
282
+ ## Citation
283
+
284
+ ```bibtex
285
+ @misc{pan2026flashtrace,
286
+ title={Towards Long-Horizon Interpretability: Efficient and Faithful Multi-Token Attribution for Reasoning LLMs},
287
+ author={Pan, Wenbo and Liu, Zhichao and Wang, Xianlong and Yu, Haining and Jia, Xiaohua},
288
+ year={2026},
289
+ eprint={2602.01914},
290
+ archivePrefix={arXiv},
291
+ primaryClass={cs.LG}
292
+ }
293
+ ```
attribution_datasets.py ADDED
@@ -0,0 +1,265 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence
7
+
8
+ # Import sentence splitter from shared utils; fallback when unavailable
9
+ try:
10
+ from shared_utils import create_sentences, create_sentences_fallback, nlp
11
+ except Exception:
12
+ from shared_utils import create_sentences_fallback as create_sentences
13
+ nlp = None
14
+
15
+
16
+ @dataclass
17
+ class AttributionExample:
18
+ prompt: str
19
+ target: Optional[str] = None
20
+ indices_to_explain: Optional[List[int]] = None
21
+ attr_mask_indices: Optional[List[int]] = None
22
+ metadata: Dict[str, Any] = field(default_factory=dict)
23
+
24
+
25
+ class AttributionDataset(Iterable[AttributionExample]):
26
+ """Base iterable for attribution-ready datasets."""
27
+
28
+ name: str = "dataset"
29
+
30
+ def __init__(self) -> None:
31
+ self.examples: List[AttributionExample] = []
32
+
33
+ def __iter__(self) -> Iterator[AttributionExample]:
34
+ return iter(self.examples)
35
+
36
+ def __len__(self) -> int: # pragma: no cover - trivial
37
+ return len(self.examples)
38
+
39
+ def __getitem__(self, item): # pragma: no cover - convenience
40
+ return self.examples[item]
41
+
42
+
43
+ def _add_dummy_facts_to_prompt(text_sentences: Sequence[str]) -> List[str]:
44
+ """
45
+ Reproduces the original behaviour of interleaving dummy sentences with the
46
+ provided text segments so attribution heads can be masked easily.
47
+ """
48
+ result: List[str] = []
49
+ for sentence in text_sentences:
50
+ result.append(sentence)
51
+ result.append(" Unrelated Sentence.")
52
+ return result
53
+
54
+
55
+ class MathAttributionDataset(AttributionDataset):
56
+ """Dataset wrapper for synthetic math problems with dummy context facts."""
57
+
58
+ name = "math"
59
+
60
+ def __init__(self, path: str | Path, tokenizer: Any) -> None:
61
+ super().__init__()
62
+ data_path = Path(path)
63
+ with data_path.open("r", encoding="utf-8") as f:
64
+ raw_examples = json.load(f)
65
+
66
+ for entry in raw_examples:
67
+ question_text = entry["question"]
68
+ sentences = create_sentences(question_text, tokenizer)
69
+ if not sentences:
70
+ continue
71
+
72
+ context_sentences = sentences[:-1]
73
+ question_sentence = sentences[-1]
74
+ if question_sentence.startswith(" "):
75
+ question_sentence = question_sentence[1:]
76
+
77
+ context_with_dummy = _add_dummy_facts_to_prompt(context_sentences)
78
+ question_with_dummy = _add_dummy_facts_to_prompt([question_sentence])
79
+
80
+ prompt = "".join(context_with_dummy) + "\n" + "".join(question_with_dummy)
81
+ total_sentences = len(context_with_dummy) + len(question_with_dummy)
82
+ attr_mask_indices = list(range(0, total_sentences, 2))
83
+
84
+ self.examples.append(
85
+ AttributionExample(
86
+ prompt=prompt,
87
+ target=None,
88
+ indices_to_explain=[-2],
89
+ attr_mask_indices=attr_mask_indices,
90
+ metadata={"raw_question": question_text},
91
+ )
92
+ )
93
+
94
+
95
+ class FactsAttributionDataset(AttributionDataset):
96
+ """Dataset wrapper for curated factual prompts with explicit gold attributions."""
97
+
98
+ name = "facts"
99
+
100
+ def __init__(self, path: str | Path) -> None:
101
+ super().__init__()
102
+ data_path = Path(path)
103
+ with data_path.open("r", encoding="utf-8") as f:
104
+ raw_examples = json.load(f)
105
+
106
+ for entry in raw_examples:
107
+ metadata = {
108
+ key: value
109
+ for key, value in entry.items()
110
+ if key not in {"prompt", "target", "indices_to_explain", "attr_mask_indices"}
111
+ }
112
+ self.examples.append(
113
+ AttributionExample(
114
+ prompt=entry["prompt"],
115
+ target=entry.get("target"),
116
+ indices_to_explain=entry.get("indices_to_explain"),
117
+ attr_mask_indices=entry.get("attr_mask_indices"),
118
+ metadata=metadata,
119
+ )
120
+ )
121
+
122
+
123
+ class MoreHopQAAttributionDataset(AttributionDataset):
124
+ """Dataset wrapper for multi-hop QA prompts without explicit gold attribution."""
125
+
126
+ name = "morehopqa"
127
+
128
+ def __init__(self, path: str | Path) -> None:
129
+ super().__init__()
130
+ data_path = Path(path)
131
+ with data_path.open("r", encoding="utf-8") as f:
132
+ raw_examples = json.load(f)
133
+
134
+ for entry in raw_examples:
135
+ context_chunks = ["".join(item[1]) for item in entry.get("context", [])]
136
+ context = " ".join(context_chunks)
137
+ prompt = context + "\n" + entry["question"]
138
+
139
+ self.examples.append(
140
+ AttributionExample(
141
+ prompt=prompt,
142
+ target=None,
143
+ indices_to_explain=[-2],
144
+ attr_mask_indices=None,
145
+ metadata={
146
+ "answer": entry.get("answer"),
147
+ "id": entry.get("_id"),
148
+ "original_context": entry.get("context"),
149
+ },
150
+ )
151
+ )
152
+
153
+
154
+ # added
155
+ class RulerAttributionDataset(AttributionDataset):
156
+ """Dataset wrapper for raw RULER JSONL files with needle spans.
157
+
158
+ Expects a JSONL file produced by repos/RULER (with added `needle_spans`).
159
+ Each line must contain at least: `input`, `answer_prefix`, `outputs`, and
160
+ optionally `needle_spans` with character spans relative to `input`.
161
+
162
+ Mapping logic:
163
+ - prompt = input + answer_prefix
164
+ - target = answer_prefix (+ optional space) + ", ".join(outputs)
165
+ - sentence indices computed over " " + prompt (leading space to match evaluator)
166
+ - each span is shifted by +1 to account for that leading space
167
+ - attr_mask_indices = union of all sentences covered by any span
168
+ - indices_to_explain = [0] when target is present
169
+ """
170
+
171
+ name = "ruler"
172
+
173
+ def __init__(self, path: str | Path) -> None:
174
+ super().__init__()
175
+ data_path = Path(path)
176
+ if not data_path.exists():
177
+ raise FileNotFoundError(f"RULER file not found: {data_path}")
178
+
179
+ # Use shared nlp pipeline; fallback to a naive splitter if unavailable
180
+ if nlp is not None:
181
+ def _sentence_bounds(text: str) -> List[tuple[int, int]]:
182
+ doc = nlp(text)
183
+ return [(s.start_char, s.end_char) for s in doc.sents]
184
+ else:
185
+ def _sentence_bounds(text: str) -> List[tuple[int, int]]:
186
+ # Naive fallback: split on newlines, produce contiguous ranges
187
+ bounds: List[tuple[int, int]] = []
188
+ start = 0
189
+ parts = text.split("\n")
190
+ for idx, part in enumerate(parts):
191
+ end = start + len(part)
192
+ if end > start:
193
+ bounds.append((start, end))
194
+ start = end + 1
195
+ if not bounds:
196
+ bounds = [(0, len(text))]
197
+ return bounds
198
+
199
+ def _map_spans(bounds: Sequence[tuple[int, int]], spans: Sequence[tuple[int, int]]) -> List[int]:
200
+ indices: set[int] = set()
201
+ for start, end in spans:
202
+ matched = False
203
+ for i, (bs, be) in enumerate(bounds):
204
+ if start >= bs and end <= be:
205
+ indices.add(i)
206
+ matched = True
207
+ break
208
+ if not matched:
209
+ # fallback: include all sentences with any overlap
210
+ for i, (bs, be) in enumerate(bounds):
211
+ if not (end <= bs or start >= be):
212
+ indices.add(i)
213
+ return sorted(indices)
214
+
215
+ def _read_jsonl(fp: Path) -> Iterator[Dict[str, Any]]:
216
+ with fp.open("r", encoding="utf-8") as f:
217
+ for line in f:
218
+ line = line.strip()
219
+ if line:
220
+ yield json.loads(line)
221
+
222
+ for entry in _read_jsonl(data_path):
223
+ input_text: str = entry.get("input", "")
224
+ answer_prefix: str = entry.get("answer_prefix", "")
225
+ outputs = entry.get("outputs", []) or []
226
+
227
+ # Build prompt/target
228
+ prompt = input_text + answer_prefix
229
+ if outputs:
230
+ sep = " " if answer_prefix and not answer_prefix.endswith((" ", "\n", "\t")) else ""
231
+ target = answer_prefix + sep + ", ".join(outputs)
232
+ else:
233
+ target = answer_prefix
234
+
235
+ # Sentence bounds over leading-space prompt to match evaluator
236
+ prompt_for_seg = " " + prompt
237
+ bounds = _sentence_bounds(prompt_for_seg)
238
+
239
+ # Collect spans and shift by +1 for the leading space
240
+ spans_raw = []
241
+ for item in entry.get("needle_spans", []) or []:
242
+ span = item.get("span")
243
+ if isinstance(span, list) and len(span) == 2:
244
+ spans_raw.append((int(span[0]) + 1, int(span[1]) + 1))
245
+
246
+ attr_indices = _map_spans(bounds, spans_raw) if spans_raw else None
247
+
248
+ self.examples.append(
249
+ AttributionExample(
250
+ prompt=prompt,
251
+ target=target or None,
252
+ indices_to_explain=[0] if target else None,
253
+ attr_mask_indices=attr_indices,
254
+ metadata={
255
+ "dataset": "ruler",
256
+ "length": entry.get("length"),
257
+ "length_w_model_temp": entry.get("length_w_model_temp"),
258
+ "outputs": outputs,
259
+ "answer_prefix": answer_prefix,
260
+ "token_position_answer": entry.get("token_position_answer"),
261
+ "needle_spans": entry.get("needle_spans"),
262
+ "prompt_sentence_count": len(bounds),
263
+ },
264
+ )
265
+ )
docs/superpowers/plans/2026-05-03-flashtrace-public-package.md ADDED
@@ -0,0 +1,1605 @@
1
+ # FlashTrace Public Package Implementation Plan
2
+
3
+ > **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
4
+
5
+ **Goal:** Build an installable `flashtrace` package with a stable Python API, CLI tracing command, JSON export, HTML heatmap export, README quickstart, and CPU smoke tests.
6
+
7
+ **Architecture:** Create a package-first structure while preserving temporary root compatibility wrappers for existing experiment scripts. Move the IFR implementation and attribution engines into `flashtrace/`, wrap them with `FlashTrace` and `TraceResult`, then expose a CLI and public examples.
8
+
9
+ **Tech Stack:** Python 3.10+, PyTorch, Transformers, Accelerate, NumPy, tqdm, argparse, pytest.
10
+
11
+ ---
12
+
13
+ ## File Structure
14
+
15
+ Create or modify these files:
16
+
17
+ - Create: `flashtrace/__init__.py` for public exports.
18
+ - Create: `flashtrace/core.py` from `ifr_core.py`.
19
+ - Create: `flashtrace/shared_utils.py` from `shared_utils.py`.
20
+ - Create: `flashtrace/lrp_rules.py` from `lrp_rules.py`.
21
+ - Create: `flashtrace/lrp_patches.py` from `lrp_patches.py`.
22
+ - Create: `flashtrace/attribution.py` from `llm_attr.py`.
23
+ - Create: `flashtrace/improved.py` from `ft_ifr_improve.py`.
24
+ - Create: `flashtrace/result.py` for `TokenScore` and `TraceResult`.
25
+ - Create: `flashtrace/viz.py` for standalone HTML token heatmaps.
26
+ - Create: `flashtrace/tracer.py` for the `FlashTrace` facade.
27
+ - Create: `flashtrace/model_io.py` for Hugging Face loading helpers.
28
+ - Create: `flashtrace/cli.py` for `flashtrace trace`.
29
+ - Create: `flashtrace/baselines/__init__.py`.
30
+ - Create: `flashtrace/baselines/attnlrp.py`.
31
+ - Modify: `ifr_core.py`, `shared_utils.py`, `lrp_rules.py`, `lrp_patches.py`, `llm_attr.py`, `ft_ifr_improve.py` into root compatibility wrappers.
32
+ - Modify: `pyproject.toml` package metadata and console script.
33
+ - Modify: `.gitignore` generated artifact rules.
34
+ - Create: `README.md`.
35
+ - Create: `LICENSE`.
36
+ - Create: `examples/quickstart.py`.
37
+ - Create: `tests/helpers.py`.
38
+ - Create: `tests/test_imports.py`.
39
+ - Create: `tests/test_core_recompute.py`.
40
+ - Create: `tests/test_result.py`.
41
+ - Create: `tests/test_tracer.py`.
42
+ - Create: `tests/test_cli.py`.
43
+ - Delete: `model_generation.py`.
44
+
45
+ ## Task 1: Package Metadata And Skeleton
46
+
47
+ **Files:**
48
+ - Modify: `pyproject.toml`
49
+ - Create: `flashtrace/__init__.py`
50
+ - Create: `flashtrace/tracer.py`
51
+ - Create: `flashtrace/result.py`
52
+ - Create: `flashtrace/model_io.py`
53
+ - Create: `flashtrace/cli.py`
54
+ - Create: `flashtrace/baselines/__init__.py`
55
+ - Create: `flashtrace/baselines/attnlrp.py`
56
+ - Test: `tests/test_imports.py`
57
+
58
+ - [ ] **Step 1: Write the failing public import test**
59
+
60
+ Create `tests/test_imports.py`:
61
+
62
+ ```python
63
+ def test_public_imports():
64
+ import flashtrace
65
+
66
+ assert flashtrace.FlashTrace.__name__ == "FlashTrace"
67
+ assert flashtrace.TraceResult.__name__ == "TraceResult"
68
+ assert callable(flashtrace.load_model_and_tokenizer)
69
+ ```
70
+
71
+ - [ ] **Step 2: Run the import test and see the expected failure**
72
+
73
+ Run:
74
+
75
+ ```bash
76
+ uv run pytest tests/test_imports.py -q
77
+ ```
78
+
79
+ Expected: pytest reports an import failure for `flashtrace`.
80
+
81
+ - [ ] **Step 3: Create package directories**
82
+
83
+ Run:
84
+
85
+ ```bash
86
+ mkdir -p flashtrace/baselines tests
87
+ ```
88
+
89
+ - [ ] **Step 4: Add minimal public package files**
90
+
91
+ Create `flashtrace/tracer.py`:
92
+
93
+ ```python
94
+ class FlashTrace:
95
+ """Public facade for FlashTrace attribution."""
96
+
97
+ def __init__(self, model, tokenizer, **kwargs):
98
+ self.model = model
99
+ self.tokenizer = tokenizer
100
+ self.options = dict(kwargs)
101
+ ```
102
+
103
+ Create `flashtrace/result.py`:
104
+
105
+ ```python
106
+ from __future__ import annotations
107
+
108
+ from dataclasses import dataclass
109
+
110
+
111
+ @dataclass(frozen=True)
112
+ class TraceResult:
113
+ """Public attribution result returned by FlashTrace."""
114
+
115
+ prompt_tokens: list[str]
116
+ generation_tokens: list[str]
117
+ scores: list[float]
118
+ ```
119
+
120
+ Create `flashtrace/model_io.py`:
121
+
122
+ ```python
123
+ def load_model_and_tokenizer(*args, **kwargs):
124
+ """Load a Hugging Face causal LM and tokenizer."""
125
+
126
+ raise RuntimeError("load_model_and_tokenizer will be implemented in the model IO task.")
127
+ ```
128
+
129
+ Create `flashtrace/cli.py`:
130
+
131
+ ```python
132
+ def main(argv=None):
133
+ """FlashTrace command-line entrypoint."""
134
+
135
+ raise RuntimeError("CLI will be implemented in the CLI task.")
136
+ ```
137
+
138
+ Create `flashtrace/baselines/__init__.py`:
139
+
140
+ ```python
141
+ """Baseline attribution methods for FlashTrace."""
142
+ ```
143
+
144
+ Create `flashtrace/baselines/attnlrp.py`:
145
+
146
+ ```python
147
+ """AttnLRP baseline exports."""
148
+ ```
149
+
150
+ Create `flashtrace/__init__.py`:
151
+
152
+ ```python
153
+ """FlashTrace: efficient multi-token attribution for reasoning LLMs."""
154
+
155
+ from .model_io import load_model_and_tokenizer
156
+ from .result import TraceResult
157
+ from .tracer import FlashTrace
158
+
159
+ __all__ = ["FlashTrace", "TraceResult", "load_model_and_tokenizer"]
160
+ ```
161
+
162
+ - [ ] **Step 5: Update package metadata**
163
+
164
+ Replace `pyproject.toml` with:
165
+
166
+ ```toml
167
+ [project]
168
+ name = "flashtrace"
169
+ version = "0.1.0"
170
+ description = "Efficient multi-token attribution for reasoning language models."
171
+ readme = "README.md"
172
+ requires-python = ">=3.10"
173
+ dependencies = [
174
+ "accelerate>=1.11.0",
175
+ "matplotlib>=3.6",
176
+ "networkx>=3.3",
177
+ "numpy>=2.0",
178
+ "seaborn>=0.11",
179
+ "spacy>=3.8",
180
+ "torch>=2.5",
181
+ "tqdm>=4.67",
182
+ "transformers>=4.53",
183
+ "wordfreq>=3.1.1",
184
+ ]
185
+
186
+ [project.optional-dependencies]
187
+ baselines = [
188
+ "bert-score>=0.3.13",
189
+ "evaluate>=0.4.6",
190
+ "sentence-transformers>=4.1.0",
191
+ ]
192
+ eval = [
193
+ "datasets>=2.21",
194
+ "evaluate>=0.4.6",
195
+ ]
196
+ dev = [
197
+ "pytest>=8.0",
198
+ ]
199
+
200
+ [project.scripts]
201
+ flashtrace = "flashtrace.cli:main"
202
+
203
+ [tool.setuptools.packages.find]
204
+ include = ["flashtrace*"]
205
+ ```
206
+
207
+ - [ ] **Step 6: Run the import test**
208
+
209
+ Run:
210
+
211
+ ```bash
212
+ uv run pytest tests/test_imports.py -q
213
+ ```
214
+
215
+ Expected: `1 passed`.
216
+
217
+ - [ ] **Step 7: Commit**
218
+
219
+ Run:
220
+
221
+ ```bash
222
+ git add pyproject.toml flashtrace tests/test_imports.py
223
+ git commit -m "feat: add flashtrace package skeleton"
224
+ ```
225
+
226
+ ## Task 2: Core IFR Migration
227
+
228
+ **Files:**
229
+ - Create: `flashtrace/core.py`
230
+ - Create: `flashtrace/shared_utils.py`
231
+ - Modify: `ifr_core.py`
232
+ - Modify: `shared_utils.py`
233
+ - Create: `tests/helpers.py`
234
+ - Create: `tests/test_core_recompute.py`
235
+
236
+ - [ ] **Step 1: Add the tiny-model test helper**
237
+
238
+ Create `tests/helpers.py`:
239
+
240
+ ```python
241
+ from __future__ import annotations
242
+
243
+ from tokenizers import Tokenizer, models, pre_tokenizers
244
+ from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
245
+
246
+
247
+ def make_tiny_qwen2_model_and_tokenizer(
248
+ *,
249
+ n_layers: int = 3,
250
+ d_model: int = 48,
251
+ n_heads: int = 4,
252
+ n_kv_heads: int = 2,
253
+ max_pos: int = 128,
254
+ ):
255
+ config = AutoConfig.for_model(
256
+ "qwen2",
257
+ vocab_size=500,
258
+ hidden_size=d_model,
259
+ intermediate_size=d_model * 2,
260
+ num_hidden_layers=n_layers,
261
+ num_attention_heads=n_heads,
262
+ num_key_value_heads=n_kv_heads,
263
+ max_position_embeddings=max_pos,
264
+ use_sliding_window=False,
265
+ attn_implementation="eager",
266
+ )
267
+ model = AutoModelForCausalLM.from_config(config, attn_implementation="eager")
268
+ model.eval()
269
+
270
+ backend = Tokenizer(models.WordLevel(vocab={f"t{i}": i for i in range(500)}, unk_token="t0"))
271
+ backend.pre_tokenizer = pre_tokenizers.Whitespace()
272
+ tokenizer = PreTrainedTokenizerFast(tokenizer_object=backend, eos_token="t1", pad_token="t2")
273
+ tokenizer.chat_template = "{% for m in messages %}{{ m['content'] }}{% endfor %}"
274
+ return model, tokenizer
275
+ ```
276
+
277
+ - [ ] **Step 2: Write the failing core import smoke test**
278
+
279
+ Create `tests/test_core_recompute.py`:
280
+
281
+ ```python
282
+ import torch
283
+
284
+ from flashtrace import core
285
+ from tests.helpers import make_tiny_qwen2_model_and_tokenizer
286
+
287
+
288
+ def test_core_metadata_and_weight_pack():
289
+ model, _ = make_tiny_qwen2_model_and_tokenizer()
290
+
291
+ metadata = core.extract_model_metadata(model)
292
+ weight_pack = core.build_weight_pack(metadata, next(model.parameters()).dtype)
293
+
294
+ assert metadata.n_layers == 3
295
+ assert metadata.n_heads_q == 4
296
+ assert metadata.n_kv_heads == 2
297
+ assert len(weight_pack) == 3
298
+ assert torch.is_tensor(weight_pack[0]["v_w"])
299
+ ```
300
+
301
+ - [ ] **Step 3: Run the core smoke test and see the expected failure**
302
+
303
+ Run:
304
+
305
+ ```bash
306
+ uv run pytest tests/test_core_recompute.py::test_core_metadata_and_weight_pack -q
307
+ ```
308
+
309
+ Expected: pytest reports missing `flashtrace.core`.
310
+
311
+ - [ ] **Step 4: Copy the IFR core into the package**
312
+
313
+ Run:
314
+
315
+ ```bash
316
+ cp ifr_core.py flashtrace/core.py
317
+ ```
318
+
319
+ - [ ] **Step 5: Copy shared utilities into the package**
320
+
321
+ Run:
322
+
323
+ ```bash
324
+ cp shared_utils.py flashtrace/shared_utils.py
325
+ ```
326
+
327
+ - [ ] **Step 6: Replace root `ifr_core.py` with a compatibility wrapper**
328
+
329
+ Replace `ifr_core.py` with:
330
+
331
+ ```python
332
+ """Compatibility wrapper for package-era imports."""
333
+
334
+ from flashtrace.core import * # noqa: F401,F403
335
+ ```
336
+
337
+ - [ ] **Step 7: Replace root `shared_utils.py` with a compatibility wrapper**
338
+
339
+ Replace `shared_utils.py` with:
340
+
341
+ ```python
342
+ """Compatibility wrapper for package-era imports."""
343
+
344
+ from flashtrace.shared_utils import * # noqa: F401,F403
345
+ ```
346
+
347
+ - [ ] **Step 8: Run the core smoke test**
348
+
349
+ Run:
350
+
351
+ ```bash
352
+ uv run pytest tests/test_core_recompute.py::test_core_metadata_and_weight_pack -q
353
+ ```
354
+
355
+ Expected: `1 passed`.
356
+
357
+ - [ ] **Step 9: Commit**
358
+
359
+ Run:
360
+
361
+ ```bash
362
+ git add flashtrace/core.py flashtrace/shared_utils.py ifr_core.py shared_utils.py tests/helpers.py tests/test_core_recompute.py
363
+ git commit -m "feat: move IFR core into package"
364
+ ```
365
+
366
+ ## Task 3: Attribution Engine Migration
367
+
368
+ **Files:**
369
+ - Create: `flashtrace/lrp_rules.py`
370
+ - Create: `flashtrace/lrp_patches.py`
371
+ - Create: `flashtrace/attribution.py`
372
+ - Create: `flashtrace/improved.py`
373
+ - Modify: `lrp_rules.py`
374
+ - Modify: `lrp_patches.py`
375
+ - Modify: `llm_attr.py`
376
+ - Modify: `ft_ifr_improve.py`
377
+ - Modify: `flashtrace/baselines/attnlrp.py`
378
+ - Test: `tests/test_core_recompute.py`
379
+
380
+ - [ ] **Step 1: Extend the recompute test with package attribution paths**
381
+
382
+ Append to `tests/test_core_recompute.py`:
383
+
384
+ ```python
385
+ from flashtrace.attribution import LLMIFRAttribution
386
+
387
+
388
+ def test_package_attribution_recompute_matches_stored_attention():
389
+ model, tokenizer = make_tiny_qwen2_model_and_tokenizer(n_layers=2, d_model=32, n_heads=4, n_kv_heads=2)
390
+ prompt = "t10 t20 t30 t40"
391
+ target = "t60 t70"
392
+
393
+ stored = LLMIFRAttribution(model, tokenizer, recompute_attention=False).calculate_ifr_span(prompt, target)
394
+ recomputed = LLMIFRAttribution(model, tokenizer, recompute_attention=True).calculate_ifr_span(prompt, target)
395
+
396
+ diff = (stored.attribution_matrix - recomputed.attribution_matrix).abs().max().item()
397
+ assert diff < 1e-5
398
+ ```
399
+
400
+ - [ ] **Step 2: Run the package attribution test and see the expected failure**
401
+
402
+ Run:
403
+
404
+ ```bash
405
+ uv run pytest tests/test_core_recompute.py::test_package_attribution_recompute_matches_stored_attention -q
406
+ ```
407
+
408
+ Expected: pytest reports missing `flashtrace.attribution`.
409
+
410
+ - [ ] **Step 3: Copy LRP helpers and attribution engines into the package**
411
+
412
+ Run:
413
+
414
+ ```bash
415
+ cp lrp_rules.py flashtrace/lrp_rules.py
416
+ cp lrp_patches.py flashtrace/lrp_patches.py
417
+ cp llm_attr.py flashtrace/attribution.py
418
+ cp ft_ifr_improve.py flashtrace/improved.py
419
+ ```
420
+
421
+ - [ ] **Step 4: Update imports in `flashtrace/attribution.py`**
422
+
423
+ Edit package-local imports to this form:
424
+
425
+ ```python
426
+ from .core import (
427
+ IFRParameters,
428
+ ModelMetadata,
429
+ attach_hooks,
430
+ build_weight_pack,
431
+ compute_ifr_for_all_positions,
432
+ compute_ifr_sentence_aggregate,
433
+ compute_multi_hop_ifr,
434
+ extract_model_metadata,
435
+ )
436
+ from .shared_utils import (
437
+ DEFAULT_GENERATE_KWARGS,
438
+ DEFAULT_PROMPT_TEMPLATE,
439
+ create_sentences,
440
+ create_sentence_masks,
441
+ )
442
+ from .lrp_patches import lrp_context, detect_model_type
443
+ ```
444
+
445
+ - [ ] **Step 5: Update imports in `flashtrace/lrp_patches.py`**
446
+
447
+ Edit the LRP helper import to:
448
+
449
+ ```python
450
+ from .lrp_rules import stop_gradient, divide_gradient, identity_rule_implicit
451
+ ```
452
+
453
+ - [ ] **Step 6: Update imports in `flashtrace/improved.py`**
454
+
455
+ Edit the top-level package imports to:
456
+
457
+ ```python
458
+ from . import attribution as llm_attr
459
+ from .core import IFRAggregate, MultiHopIFRResult, compute_ifr_sentence_aggregate
460
+ ```
461
+
462
+ - [ ] **Step 7: Replace root compatibility modules**
463
+
464
+ Replace `lrp_rules.py` with:
465
+
466
+ ```python
467
+ """Compatibility wrapper for package-era imports."""
468
+
469
+ from flashtrace.lrp_rules import * # noqa: F401,F403
470
+ ```
471
+
472
+ Replace `lrp_patches.py` with:
473
+
474
+ ```python
475
+ """Compatibility wrapper for package-era imports."""
476
+
477
+ from flashtrace.lrp_patches import * # noqa: F401,F403
478
+ ```
479
+
480
+ Replace `llm_attr.py` with:
481
+
482
+ ```python
483
+ """Compatibility wrapper for package-era imports."""
484
+
485
+ from flashtrace.attribution import * # noqa: F401,F403
486
+ ```
487
+
488
+ Replace `ft_ifr_improve.py` with:
489
+
490
+ ```python
491
+ """Compatibility wrapper for package-era imports."""
492
+
493
+ from flashtrace.improved import * # noqa: F401,F403
494
+ ```
495
+
496
+ - [ ] **Step 8: Export the AttnLRP baseline**
497
+
498
+ Replace `flashtrace/baselines/attnlrp.py` with:
499
+
500
+ ```python
501
+ """AttnLRP baseline API."""
502
+
503
+ from flashtrace.attribution import AttnLRPSpanAggregate, LLMLRPAttribution, MultiHopAttnLRPResult
504
+ from flashtrace.lrp_patches import detect_model_type, lrp_context
505
+
506
+ __all__ = [
507
+ "AttnLRPSpanAggregate",
508
+ "LLMLRPAttribution",
509
+ "MultiHopAttnLRPResult",
510
+ "detect_model_type",
511
+ "lrp_context",
512
+ ]
513
+ ```
514
+
515
+ Replace `flashtrace/baselines/__init__.py` with:
516
+
517
+ ```python
518
+ """Baseline attribution methods for FlashTrace."""
519
+
520
+ from .attnlrp import LLMLRPAttribution
521
+
522
+ __all__ = ["LLMLRPAttribution"]
523
+ ```
524
+
525
+ - [ ] **Step 9: Run attribution migration tests**
526
+
527
+ Run:
528
+
529
+ ```bash
530
+ uv run pytest tests/test_core_recompute.py -q
531
+ ```
532
+
533
+ Expected: all tests in the file pass.
534
+
535
+ - [ ] **Step 10: Run a root compatibility import check**
536
+
537
+ Run:
538
+
539
+ ```bash
540
+ uv run python -c "import ifr_core, llm_attr, ft_ifr_improve; print(llm_attr.LLMIFRAttribution.__name__)"
541
+ ```
542
+
543
+ Expected: prints `LLMIFRAttribution`.
544
+
545
+ - [ ] **Step 11: Commit**
546
+
547
+ Run:
548
+
549
+ ```bash
550
+ git add flashtrace lrp_rules.py lrp_patches.py llm_attr.py ft_ifr_improve.py tests/test_core_recompute.py
551
+ git commit -m "feat: move attribution engines into package"
552
+ ```
553
+
554
+ ## Task 4: TraceResult And HTML Heatmap
555
+
556
+ **Files:**
557
+ - Modify: `flashtrace/result.py`
558
+ - Create: `flashtrace/viz.py`
559
+ - Create: `tests/test_result.py`
560
+
561
+ - [ ] **Step 1: Write result object tests**
562
+
563
+ Create `tests/test_result.py`:
564
+
565
+ ```python
566
+ import json
567
+
568
+ from flashtrace.result import TokenScore, TraceResult
569
+
570
+
571
+ def make_result():
572
+ return TraceResult(
573
+ prompt_tokens=[" alpha", " beta", " gamma"],
574
+ generation_tokens=[" answer"],
575
+ scores=[0.2, 0.7, 0.1],
576
+ per_hop_scores=[[0.1, 0.4, 0.0], [0.1, 0.3, 0.1]],
577
+ thinking_ratios=[0.5, 0.2],
578
+ output_span=(0, 0),
579
+ reasoning_span=(0, 0),
580
+ method="flashtrace",
581
+ metadata={"model": "tiny"},
582
+ )
583
+
584
+
585
+ def test_topk_inputs_sorted():
586
+ result = make_result()
587
+
588
+ top = result.topk_inputs(2)
589
+
590
+ assert top == [
591
+ TokenScore(index=1, token=" beta", score=0.7),
592
+ TokenScore(index=0, token=" alpha", score=0.2),
593
+ ]
594
+
595
+
596
+ def test_to_dict_is_json_serializable():
597
+ result = make_result()
598
+
599
+ payload = result.to_dict()
600
+
601
+ assert payload["method"] == "flashtrace"
602
+ assert payload["top_inputs"][0]["token"] == " beta"
603
+ json.dumps(payload)
604
+
605
+
606
+ def test_to_dict_sanitizes_tensor_metadata():
607
+ import torch
608
+
609
+ result = TraceResult(
610
+ prompt_tokens=[" alpha"],
611
+ generation_tokens=[" answer"],
612
+ scores=[1.0],
613
+ metadata={"tensor": torch.tensor([1.0, 2.0]), "object": object()},
614
+ )
615
+
616
+ payload = result.to_dict()
617
+
618
+ assert payload["metadata"]["tensor"] == [1.0, 2.0]
619
+ assert isinstance(payload["metadata"]["object"], str)
620
+ json.dumps(payload)
621
+
622
+
623
+ def test_json_and_html_export(tmp_path):
624
+ result = make_result()
625
+ json_path = tmp_path / "trace.json"
626
+ html_path = tmp_path / "trace.html"
627
+
628
+ result.to_json(json_path)
629
+ result.to_html(html_path)
630
+
631
+ assert json_path.read_text(encoding="utf-8").startswith("{")
632
+ html = html_path.read_text(encoding="utf-8")
633
+ assert "<html" in html
634
+ assert " beta" in html
635
+ ```
636
+
637
+ - [ ] **Step 2: Run result tests and see the expected failure**
638
+
639
+ Run:
640
+
641
+ ```bash
642
+ uv run pytest tests/test_result.py -q
643
+ ```
644
+
645
+ Expected: pytest reports missing `TokenScore` or missing methods.
646
+
647
+ - [ ] **Step 3: Implement `TraceResult`**
648
+
649
+ Replace `flashtrace/result.py` with:
650
+
651
+ ```python
652
+ from __future__ import annotations
653
+
654
+ import json
655
+ from dataclasses import asdict, dataclass, field, is_dataclass
656
+ from pathlib import Path
657
+ from typing import Any
658
+
659
+
660
+ @dataclass(frozen=True)
661
+ class TokenScore:
662
+ index: int
663
+ token: str
664
+ score: float
665
+
666
+
667
+ @dataclass(frozen=True)
668
+ class TraceResult:
669
+ """Public attribution result returned by FlashTrace."""
670
+
671
+ prompt_tokens: list[str]
672
+ generation_tokens: list[str]
673
+ scores: list[float]
674
+ per_hop_scores: list[list[float]] = field(default_factory=list)
675
+ thinking_ratios: list[float] = field(default_factory=list)
676
+ output_span: tuple[int, int] | None = None
677
+ reasoning_span: tuple[int, int] | None = None
678
+ method: str = "flashtrace"
679
+ metadata: dict[str, Any] = field(default_factory=dict)
680
+
681
+ def topk_inputs(self, k: int = 20) -> list[TokenScore]:
682
+ limit = max(0, int(k))
683
+ items = [
684
+ TokenScore(index=i, token=tok, score=float(score))
685
+ for i, (tok, score) in enumerate(zip(self.prompt_tokens, self.scores))
686
+ ]
687
+ items.sort(key=lambda item: item.score, reverse=True)
688
+ return items[:limit]
689
+
690
+ def to_dict(self) -> dict[str, Any]:
691
+ return {
692
+ "method": self.method,
693
+ "prompt_tokens": list(self.prompt_tokens),
694
+ "generation_tokens": list(self.generation_tokens),
695
+ "scores": [float(x) for x in self.scores],
696
+ "per_hop_scores": [[float(x) for x in row] for row in self.per_hop_scores],
697
+ "thinking_ratios": [float(x) for x in self.thinking_ratios],
698
+ "output_span": list(self.output_span) if self.output_span is not None else None,
699
+ "reasoning_span": list(self.reasoning_span) if self.reasoning_span is not None else None,
700
+ "top_inputs": [asdict(item) for item in self.topk_inputs()],
701
+ "metadata": _jsonable(self.metadata),
702
+ }
703
+
704
+ def to_json(self, path: str | Path) -> None:
705
+ target = Path(path)
706
+ target.write_text(json.dumps(self.to_dict(), indent=2, ensure_ascii=False), encoding="utf-8")
707
+
708
+ def to_html(self, path: str | Path) -> None:
709
+ from .viz import render_trace_html
710
+
711
+ target = Path(path)
712
+ target.write_text(render_trace_html(self), encoding="utf-8")
713
+
714
+
715
+ def _jsonable(value: Any) -> Any:
716
+ if value is None or isinstance(value, (str, int, float, bool)):
717
+ return value
718
+ if hasattr(value, "detach") and hasattr(value, "cpu"):
719
+ try:
720
+ return value.detach().cpu().tolist()
721
+ except Exception:
722
+ return repr(value)
723
+ if is_dataclass(value):
724
+ return _jsonable(asdict(value))
725
+ if isinstance(value, dict):
726
+ return {str(k): _jsonable(v) for k, v in value.items()}
727
+ if isinstance(value, (list, tuple)):
728
+ return [_jsonable(v) for v in value]
729
+ return repr(value)
730
+ ```
731
+
732
+ - [ ] **Step 4: Implement the standalone HTML renderer**
733
+
734
+ Create `flashtrace/viz.py`:
735
+
736
+ ```python
737
+ from __future__ import annotations
738
+
739
+ from html import escape
740
+ from typing import TYPE_CHECKING
741
+
742
+ if TYPE_CHECKING:
743
+ from .result import TraceResult
744
+
745
+
746
+ def _score_color(score: float, max_score: float) -> str:
747
+ if max_score <= 0.0:
748
+ return "rgba(245,245,245,0.75)"
749
+ ratio = min(1.0, abs(float(score)) / (max_score + 1e-12))
750
+ red = 255
751
+ green = int(246 - 105 * ratio)
752
+ blue = int(226 - 170 * ratio)
753
+ alpha = 0.22 + 0.58 * ratio
754
+ return f"rgba({red},{green},{blue},{alpha:.3f})"
755
+
756
+
757
+ def _render_token_row(tokens: list[str], scores: list[float]) -> str:
758
+ max_score = max((abs(float(x)) for x in scores), default=0.0)
759
+ spans = []
760
+ for index, token in enumerate(tokens):
761
+ score = float(scores[index]) if index < len(scores) else 0.0
762
+ color = _score_color(score, max_score)
763
+ spans.append(
764
+ "<span class='tok' "
765
+ f"title='idx={index} score={score:.6f}' "
766
+ f"style='background:{color}'>{escape(token)}</span>"
767
+ )
768
+ return "".join(spans)
769
+
770
+
771
+ def render_trace_html(result: "TraceResult") -> str:
772
+ top_rows = "\n".join(
773
+ f"<tr><td>{item.index}</td><td><code>{escape(item.token)}</code></td><td>{item.score:.6f}</td></tr>"
774
+ for item in result.topk_inputs(20)
775
+ )
776
+ hop_sections = []
777
+ for hop_index, hop_scores in enumerate(result.per_hop_scores):
778
+ hop_sections.append(
779
+ f"<section><h2>Hop {hop_index}</h2><div class='tokens'>{_render_token_row(result.prompt_tokens, hop_scores)}</div></section>"
780
+ )
781
+ hop_html = "\n".join(hop_sections)
782
+ metadata = escape(str(result.metadata))
783
+ return f"""<!doctype html>
784
+ <html lang="en">
785
+ <head>
786
+ <meta charset="utf-8">
787
+ <title>FlashTrace</title>
788
+ <style>
789
+ body {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; margin: 32px; color: #151515; }}
790
+ h1, h2 {{ margin: 0 0 12px; }}
791
+ section {{ margin: 24px 0; }}
792
+ .tokens {{ line-height: 2.2; font-family: ui-monospace, SFMono-Regular, Menlo, monospace; }}
793
+ .tok {{ display: inline-block; margin: 2px; padding: 2px 4px; border-radius: 4px; white-space: pre-wrap; }}
794
+ table {{ border-collapse: collapse; margin-top: 12px; }}
795
+ td, th {{ border-bottom: 1px solid #ddd; padding: 6px 10px; text-align: left; }}
796
+ .meta {{ color: #555; font-size: 13px; }}
797
+ </style>
798
+ </head>
799
+ <body>
800
+ <h1>FlashTrace</h1>
801
+ <p class="meta">method={escape(result.method)} output_span={escape(str(result.output_span))} reasoning_span={escape(str(result.reasoning_span))}</p>
802
+ <section>
803
+ <h2>Prompt Attribution</h2>
804
+ <div class="tokens">{_render_token_row(result.prompt_tokens, result.scores)}</div>
805
+ </section>
806
+ {hop_html}
807
+ <section>
808
+ <h2>Top Input Tokens</h2>
809
+ <table><thead><tr><th>Index</th><th>Token</th><th>Score</th></tr></thead><tbody>{top_rows}</tbody></table>
810
+ </section>
811
+ <section>
812
+ <h2>Metadata</h2>
813
+ <pre>{metadata}</pre>
814
+ </section>
815
+ </body>
816
+ </html>
817
+ """
818
+ ```
819
+
820
+ - [ ] **Step 5: Run result tests**
821
+
822
+ Run:
823
+
824
+ ```bash
825
+ uv run pytest tests/test_result.py -q
826
+ ```
827
+
828
+ Expected: `4 passed`.
829
+
830
+ - [ ] **Step 6: Commit**
831
+
832
+ Run:
833
+
834
+ ```bash
835
+ git add flashtrace/result.py flashtrace/viz.py tests/test_result.py
836
+ git commit -m "feat: add trace result exports"
837
+ ```
838
+
839
+ ## Task 5: FlashTrace Facade
840
+
841
+ **Files:**
842
+ - Modify: `flashtrace/tracer.py`
843
+ - Modify: `flashtrace/__init__.py`
844
+ - Create: `tests/test_tracer.py`
845
+
846
+ - [ ] **Step 1: Write tracer API tests**
847
+
848
+ Create `tests/test_tracer.py`:
849
+
850
+ ```python
851
+ from flashtrace import FlashTrace, TraceResult
852
+ from tests.helpers import make_tiny_qwen2_model_and_tokenizer
853
+
854
+
855
+ def test_flashtrace_trace_returns_public_result():
856
+ model, tokenizer = make_tiny_qwen2_model_and_tokenizer(n_layers=2, d_model=32, n_heads=4, n_kv_heads=2)
857
+ tracer = FlashTrace(model, tokenizer, chunk_tokens=16, sink_chunk_tokens=4, recompute_attention=True)
858
+
859
+ result = tracer.trace(
860
+ prompt="t10 t20 t30 t40",
861
+ target="t60 t70 t80",
862
+ output_span=(1, 2),
863
+ reasoning_span=(0, 1),
864
+ hops=1,
865
+ )
866
+
867
+ assert isinstance(result, TraceResult)
868
+ assert result.method == "flashtrace"
869
+ assert len(result.prompt_tokens) > 0
870
+ assert len(result.scores) == len(result.prompt_tokens)
871
+ assert result.output_span == (1, 2)
872
+ assert result.reasoning_span == (0, 1)
873
+
874
+
875
+ def test_ifr_span_method_returns_public_result():
876
+ model, tokenizer = make_tiny_qwen2_model_and_tokenizer(n_layers=2, d_model=32, n_heads=4, n_kv_heads=2)
877
+ tracer = FlashTrace(model, tokenizer, chunk_tokens=16, sink_chunk_tokens=4, recompute_attention=True)
878
+
879
+ result = tracer.trace(
880
+ prompt="t10 t20 t30 t40",
881
+ target="t60 t70",
882
+ output_span=(0, 1),
883
+ method="ifr-span",
884
+ )
885
+
886
+ assert result.method == "ifr-span"
887
+ assert len(result.scores) == len(result.prompt_tokens)
888
+ ```
889
+
890
+ - [ ] **Step 2: Run tracer tests and see the expected failure**
891
+
892
+ Run:
893
+
894
+ ```bash
895
+ uv run pytest tests/test_tracer.py -q
896
+ ```
897
+
898
+ Expected: pytest reports missing `trace`.
899
+
900
+ - [ ] **Step 3: Implement result adaptation helpers and facade**
901
+
902
+ Replace `flashtrace/tracer.py` with:
903
+
904
+ ```python
905
+ from __future__ import annotations
906
+
907
+ from typing import Any, Literal
908
+
909
+ import torch
910
+
911
+ from .attribution import LLMIFRAttribution, LLMAttributionResult
912
+ from .improved import LLMIFRAttributionBoth
913
+ from .result import TraceResult
914
+
915
+ TraceMethod = Literal["flashtrace", "ifr-span", "ifr-matrix"]
916
+
917
+
918
+ def _to_float_list(values: Any) -> list[float]:
919
+ if torch.is_tensor(values):
920
+ values = values.detach().cpu().to(dtype=torch.float32).tolist()
921
+ return [float(x) for x in (values or [])]
922
+
923
+
924
+ class FlashTrace:
925
+ """Public facade for FlashTrace attribution."""
926
+
927
+ def __init__(
928
+ self,
929
+ model,
930
+ tokenizer,
931
+ *,
932
+ chunk_tokens: int = 128,
933
+ sink_chunk_tokens: int = 32,
934
+ recompute_attention: bool = False,
935
+ generate_kwargs: dict[str, Any] | None = None,
936
+ ) -> None:
937
+ self.model = model
938
+ self.tokenizer = tokenizer
939
+ self.chunk_tokens = int(chunk_tokens)
940
+ self.sink_chunk_tokens = int(sink_chunk_tokens)
941
+ self.recompute_attention = bool(recompute_attention)
942
+ self.generate_kwargs = generate_kwargs
943
+
944
+ def trace(
945
+ self,
946
+ *,
947
+ prompt: str,
948
+ target: str | None = None,
949
+ output_span: tuple[int, int] | None = None,
950
+ reasoning_span: tuple[int, int] | None = None,
951
+ hops: int = 1,
952
+ method: TraceMethod = "flashtrace",
953
+ renorm_threshold: float | None = None,
954
+ ) -> TraceResult:
955
+ if method == "flashtrace":
956
+ engine = LLMIFRAttributionBoth(
957
+ self.model,
958
+ self.tokenizer,
959
+ generate_kwargs=self.generate_kwargs,
960
+ chunk_tokens=self.chunk_tokens,
961
+ sink_chunk_tokens=self.sink_chunk_tokens,
962
+ recompute_attention=self.recompute_attention,
963
+ )
964
+ raw = engine.calculate_ifr_multi_hop_both(
965
+ prompt,
966
+ target=target,
967
+ sink_span=output_span,
968
+ thinking_span=reasoning_span,
969
+ n_hops=int(hops),
970
+ renorm_threshold=renorm_threshold,
971
+ )
972
+ elif method == "ifr-span":
973
+ engine = LLMIFRAttribution(
974
+ self.model,
975
+ self.tokenizer,
976
+ generate_kwargs=self.generate_kwargs,
977
+ chunk_tokens=self.chunk_tokens,
978
+ sink_chunk_tokens=self.sink_chunk_tokens,
979
+ recompute_attention=self.recompute_attention,
980
+ )
981
+ raw = engine.calculate_ifr_span(
982
+ prompt,
983
+ target=target,
984
+ span=output_span,
985
+ renorm_threshold=renorm_threshold,
986
+ )
987
+ elif method == "ifr-matrix":
988
+ engine = LLMIFRAttribution(
989
+ self.model,
990
+ self.tokenizer,
991
+ generate_kwargs=self.generate_kwargs,
992
+ chunk_tokens=self.chunk_tokens,
993
+ sink_chunk_tokens=self.sink_chunk_tokens,
994
+ recompute_attention=self.recompute_attention,
995
+ )
996
+ raw = engine.calculate_ifr_for_all_positions_output_only(
997
+ prompt,
998
+ target=target,
999
+ sink_span=output_span,
1000
+ renorm_threshold=renorm_threshold,
1001
+ )
1002
+ else:
1003
+ raise ValueError(f"Unsupported method: {method}")
1004
+
1005
+ return self._build_result(raw, method=method, output_span=output_span, reasoning_span=reasoning_span)
1006
+
1007
+ def _build_result(
1008
+ self,
1009
+ raw: LLMAttributionResult,
1010
+ *,
1011
+ method: str,
1012
+ output_span: tuple[int, int] | None,
1013
+ reasoning_span: tuple[int, int] | None,
1014
+ ) -> TraceResult:
1015
+ prompt_tokens = list(raw.prompt_tokens)
1016
+ generation_tokens = list(raw.generation_tokens)
1017
+ prompt_len = len(prompt_tokens)
1018
+ metadata = dict(raw.metadata or {})
1019
+ if "method" not in metadata:
1020
+ metadata["method"] = method
1021
+
1022
+ ifr_meta = metadata.get("ifr") if isinstance(metadata.get("ifr"), dict) else {}
1023
+ observation = ifr_meta.get("observation_projected") if isinstance(ifr_meta, dict) else None
1024
+ per_hop_projected = ifr_meta.get("per_hop_projected") if isinstance(ifr_meta, dict) else None
1025
+
1026
+ if isinstance(observation, dict) and "sum" in observation:
1027
+ vector = _to_float_list(observation["sum"])
1028
+ scores = vector[:prompt_len]
1029
+ else:
1030
+ matrix = torch.nan_to_num(raw.attribution_matrix.detach().cpu().to(dtype=torch.float32), nan=0.0)
1031
+ if output_span is not None:
1032
+ start, end = output_span
1033
+ selected = matrix[int(start) : int(end) + 1, :prompt_len]
1034
+ else:
1035
+ selected = matrix[:, :prompt_len]
1036
+ scores = selected.mean(dim=0).tolist() if selected.numel() else [0.0 for _ in prompt_tokens]
1037
+
1038
+ per_hop_scores: list[list[float]] = []
1039
+ if per_hop_projected:
1040
+ for hop_vector in per_hop_projected:
1041
+ per_hop_scores.append(_to_float_list(hop_vector)[:prompt_len])
1042
+
1043
+ ratios = ifr_meta.get("thinking_ratios", []) if isinstance(ifr_meta, dict) else []
1044
+ return TraceResult(
1045
+ prompt_tokens=prompt_tokens,
1046
+ generation_tokens=generation_tokens,
1047
+ scores=[float(x) for x in scores],
1048
+ per_hop_scores=per_hop_scores,
1049
+ thinking_ratios=_to_float_list(ratios),
1050
+ output_span=output_span,
1051
+ reasoning_span=reasoning_span,
1052
+ method=method,
1053
+ metadata=metadata,
1054
+ )
1055
+ ```
1056
+
1057
+ - [ ] **Step 4: Confirm public exports**
1058
+
1059
+ Keep `flashtrace/__init__.py` as:
1060
+
1061
+ ```python
1062
+ """FlashTrace: efficient multi-token attribution for reasoning LLMs."""
1063
+
1064
+ from .model_io import load_model_and_tokenizer
1065
+ from .result import TokenScore, TraceResult
1066
+ from .tracer import FlashTrace
1067
+
1068
+ __all__ = ["FlashTrace", "TraceResult", "TokenScore", "load_model_and_tokenizer"]
1069
+ ```
1070
+
1071
+ - [ ] **Step 5: Run tracer tests**
1072
+
1073
+ Run:
1074
+
1075
+ ```bash
1076
+ uv run pytest tests/test_tracer.py -q
1077
+ ```
1078
+
1079
+ Expected: `2 passed`.
1080
+
1081
+ - [ ] **Step 6: Run package tests created so far**
1082
+
1083
+ Run:
1084
+
1085
+ ```bash
1086
+ uv run pytest tests/test_imports.py tests/test_core_recompute.py tests/test_result.py tests/test_tracer.py -q
1087
+ ```
1088
+
1089
+ Expected: all selected tests pass.
1090
+
1091
+ - [ ] **Step 7: Commit**
1092
+
1093
+ Run:
1094
+
1095
+ ```bash
1096
+ git add flashtrace/tracer.py flashtrace/__init__.py tests/test_tracer.py
1097
+ git commit -m "feat: add FlashTrace public facade"
1098
+ ```
1099
+
1100
+ ## Task 6: Model IO And CLI
1101
+
1102
+ **Files:**
1103
+ - Modify: `flashtrace/model_io.py`
1104
+ - Modify: `flashtrace/cli.py`
1105
+ - Create: `tests/test_cli.py`
1106
+
1107
+ - [ ] **Step 1: Write CLI tests**
1108
+
1109
+ Create `tests/test_cli.py`:
1110
+
1111
+ ```python
1112
+ import pytest
1113
+
1114
+ from flashtrace.cli import main, parse_span
1115
+
1116
+
1117
+ def test_parse_span():
1118
+ assert parse_span("3:8") == (3, 8)
1119
+ assert parse_span(None) is None
1120
+
1121
+
1122
+ @pytest.mark.parametrize("value", ["3", "8:3", "a:b"])
1123
+ def test_parse_span_rejects_invalid_values(value):
1124
+ with pytest.raises(ValueError):
1125
+ parse_span(value)
1126
+
1127
+
1128
+ def test_cli_help_exits_successfully(capsys):
1129
+ with pytest.raises(SystemExit) as exc:
1130
+ main(["--help"])
1131
+
1132
+ assert exc.value.code == 0
1133
+ assert "trace" in capsys.readouterr().out
1134
+ ```
1135
+
1136
+ - [ ] **Step 2: Run CLI tests and see the expected failure**
1137
+
1138
+ Run:
1139
+
1140
+ ```bash
1141
+ uv run pytest tests/test_cli.py -q
1142
+ ```
1143
+
1144
+ Expected: pytest reports missing `parse_span`.
1145
+
1146
+ - [ ] **Step 3: Implement model loading**
1147
+
1148
+ Replace `flashtrace/model_io.py` with:
1149
+
1150
+ ```python
1151
+ from __future__ import annotations
1152
+
1153
+ from typing import Any
1154
+
1155
+ import torch
1156
+ from transformers import AutoModelForCausalLM, AutoTokenizer
1157
+
1158
+
1159
+ def _resolve_dtype(dtype: str | torch.dtype = "auto") -> str | torch.dtype:
1160
+ if isinstance(dtype, torch.dtype):
1161
+ return dtype
1162
+ value = str(dtype).lower()
1163
+ if value == "auto":
1164
+ return "auto"
1165
+ mapping = {
1166
+ "float16": torch.float16,
1167
+ "fp16": torch.float16,
1168
+ "bfloat16": torch.bfloat16,
1169
+ "bf16": torch.bfloat16,
1170
+ "float32": torch.float32,
1171
+ "fp32": torch.float32,
1172
+ }
1173
+ if value not in mapping:
1174
+ raise ValueError(f"Unsupported dtype: {dtype}")
1175
+ return mapping[value]
1176
+
1177
+
1178
+ def load_model_and_tokenizer(
1179
+ model_name_or_path: str,
1180
+ *,
1181
+ device_map: str | dict[str, Any] | None = "auto",
1182
+ dtype: str | torch.dtype = "auto",
1183
+ trust_remote_code: bool = True,
1184
+ **model_kwargs: Any,
1185
+ ):
1186
+ """Load a Hugging Face causal LM and matching tokenizer."""
1187
+
1188
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
1189
+ model = AutoModelForCausalLM.from_pretrained(
1190
+ model_name_or_path,
1191
+ torch_dtype=_resolve_dtype(dtype),
1192
+ device_map=device_map,
1193
+ trust_remote_code=trust_remote_code,
1194
+ **model_kwargs,
1195
+ )
1196
+ model.eval()
1197
+ if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
1198
+ tokenizer.pad_token = tokenizer.eos_token
1199
+ return model, tokenizer
1200
+ ```
1201
+
1202
+ - [ ] **Step 4: Implement CLI**
1203
+
1204
+ Replace `flashtrace/cli.py` with:
1205
+
1206
+ ```python
1207
+ from __future__ import annotations
1208
+
1209
+ import argparse
1210
+ from pathlib import Path
1211
+ from typing import Sequence
1212
+
1213
+ from .model_io import load_model_and_tokenizer
1214
+ from .tracer import FlashTrace
1215
+
1216
+
1217
+ def parse_span(value: str | None) -> tuple[int, int] | None:
1218
+ if value is None:
1219
+ return None
1220
+ parts = str(value).split(":")
1221
+ if len(parts) != 2:
1222
+ raise ValueError("Span must use START:END format.")
1223
+ try:
1224
+ start = int(parts[0])
1225
+ end = int(parts[1])
1226
+ except ValueError as exc:
1227
+ raise ValueError("Span bounds must be integers.") from exc
1228
+ if start < 0 or end < start:
1229
+ raise ValueError("Span must satisfy 0 <= START <= END.")
1230
+ return start, end
1231
+
1232
+
1233
+ def build_parser() -> argparse.ArgumentParser:
1234
+ parser = argparse.ArgumentParser(prog="flashtrace", description="Trace language model outputs with FlashTrace.")
1235
+ sub = parser.add_subparsers(dest="command")
1236
+
1237
+ trace = sub.add_parser("trace", help="Run attribution for a prompt and target.")
1238
+ trace.add_argument("--model", required=True, help="Hugging Face model id or local path.")
1239
+ trace.add_argument("--prompt", required=True, help="UTF-8 text file containing the prompt.")
1240
+ trace.add_argument("--target", help="UTF-8 text file containing the target response.")
1241
+ trace.add_argument("--output-span", help="Inclusive generation-token span START:END.")
1242
+ trace.add_argument("--reasoning-span", help="Inclusive generation-token span START:END.")
1243
+ trace.add_argument("--hops", type=int, default=1)
1244
+ trace.add_argument("--method", default="flashtrace", choices=["flashtrace", "ifr-span", "ifr-matrix"])
1245
+ trace.add_argument("--html", help="Write standalone HTML heatmap.")
1246
+ trace.add_argument("--json", help="Write JSON trace.")
1247
+ trace.add_argument("--device-map", default="auto")
1248
+ trace.add_argument("--dtype", default="auto", choices=["auto", "float16", "bfloat16", "float32"])
1249
+ trace.add_argument("--chunk-tokens", type=int, default=128)
1250
+ trace.add_argument("--sink-chunk-tokens", type=int, default=32)
1251
+ trace.add_argument("--recompute-attention", action="store_true")
1252
+ return parser
1253
+
1254
+
1255
+ def _read_text(path: str | None) -> str | None:
1256
+ if path is None:
1257
+ return None
1258
+ return Path(path).read_text(encoding="utf-8")
1259
+
1260
+
1261
+ def _run_trace(args: argparse.Namespace) -> int:
1262
+ model, tokenizer = load_model_and_tokenizer(args.model, device_map=args.device_map, dtype=args.dtype)
1263
+ tracer = FlashTrace(
1264
+ model,
1265
+ tokenizer,
1266
+ chunk_tokens=args.chunk_tokens,
1267
+ sink_chunk_tokens=args.sink_chunk_tokens,
1268
+ recompute_attention=args.recompute_attention,
1269
+ )
1270
+ result = tracer.trace(
1271
+ prompt=_read_text(args.prompt) or "",
1272
+ target=_read_text(args.target),
1273
+ output_span=parse_span(args.output_span),
1274
+ reasoning_span=parse_span(args.reasoning_span),
1275
+ hops=args.hops,
1276
+ method=args.method,
1277
+ )
1278
+ for item in result.topk_inputs(20):
1279
+ print(f"{item.index}\t{item.score:.6f}\t{item.token!r}")
1280
+ if args.json:
1281
+ result.to_json(args.json)
1282
+ if args.html:
1283
+ result.to_html(args.html)
1284
+ return 0
1285
+
1286
+
1287
+ def main(argv: Sequence[str] | None = None) -> int:
1288
+ parser = build_parser()
1289
+ args = parser.parse_args(argv)
1290
+ if args.command == "trace":
1291
+ return _run_trace(args)
1292
+ parser.print_help()
1293
+ return 0
1294
+ ```
1295
+
1296
+ - [ ] **Step 5: Run CLI tests**
1297
+
1298
+ Run:
1299
+
1300
+ ```bash
1301
+ uv run pytest tests/test_cli.py -q
1302
+ ```
1303
+
1304
+ Expected: all CLI tests pass.
1305
+
1306
+ - [ ] **Step 6: Verify console script metadata**
1307
+
1308
+ Run:
1309
+
1310
+ ```bash
1311
+ uv run flashtrace --help
1312
+ ```
1313
+
1314
+ Expected: help text includes `trace`.
1315
+
1316
+ - [ ] **Step 7: Commit**
1317
+
1318
+ Run:
1319
+
1320
+ ```bash
1321
+ git add flashtrace/model_io.py flashtrace/cli.py tests/test_cli.py
1322
+ git commit -m "feat: add model loader and CLI"
1323
+ ```
1324
+
1325
+ ## Task 7: README, Example, License, And Release Hygiene
1326
+
1327
+ **Files:**
1328
+ - Create: `README.md`
1329
+ - Create: `LICENSE`
1330
+ - Create: `examples/quickstart.py`
1331
+ - Modify: `.gitignore`
1332
+ - Delete: `model_generation.py`
1333
+
1334
+ - [ ] **Step 1: Create the quickstart example**
1335
+
1336
+ Create `examples/quickstart.py`:
1337
+
1338
+ ```python
1339
+ from __future__ import annotations
1340
+
1341
+ import argparse
1342
+
1343
+ from flashtrace import FlashTrace, load_model_and_tokenizer
1344
+
1345
+
1346
+ def build_parser() -> argparse.ArgumentParser:
1347
+ parser = argparse.ArgumentParser(description="FlashTrace quickstart example.")
1348
+ parser.add_argument("--model", required=True, help="Hugging Face model id or local model path.")
1349
+ parser.add_argument("--prompt", required=True, help="Prompt text.")
1350
+ parser.add_argument("--target", help="Target response text.")
1351
+ parser.add_argument("--output-span", default=None, help="Inclusive generation-token span START:END.")
1352
+ parser.add_argument("--reasoning-span", default=None, help="Inclusive generation-token span START:END.")
1353
+ parser.add_argument("--html", default="trace.html", help="Output HTML path.")
1354
+ return parser
1355
+
1356
+
1357
+ def parse_span(value: str | None) -> tuple[int, int] | None:
1358
+ from flashtrace.cli import parse_span as parse_cli_span
1359
+
1360
+ return parse_cli_span(value)
1361
+
1362
+
1363
+ def main() -> int:
1364
+ args = build_parser().parse_args()
1365
+ model, tokenizer = load_model_and_tokenizer(args.model)
1366
+ tracer = FlashTrace(model, tokenizer)
1367
+ trace = tracer.trace(
1368
+ prompt=args.prompt,
1369
+ target=args.target,
1370
+ output_span=parse_span(args.output_span),
1371
+ reasoning_span=parse_span(args.reasoning_span),
1372
+ )
1373
+ for item in trace.topk_inputs(10):
1374
+ print(f"{item.index}\t{item.score:.6f}\t{item.token!r}")
1375
+ trace.to_html(args.html)
1376
+ print(f"wrote {args.html}")
1377
+ return 0
1378
+
1379
+
1380
+ if __name__ == "__main__":
1381
+ raise SystemExit(main())
1382
+ ```
1383
+
1384
+ - [ ] **Step 2: Add README**
1385
+
1386
+ Create `README.md`:
1387
+
1388
+ ````markdown
1389
+ # FlashTrace
1390
+
1391
+ FlashTrace is an efficient multi-token attribution toolkit for reasoning language models. It implements the method described in [Towards Long-Horizon Interpretability: Efficient and Faithful Multi-Token Attribution for Reasoning LLMs](https://arxiv.org/abs/2602.01914).
1392
+
1393
+ ## Install
1394
+
1395
+ ```bash
1396
+ pip install -e .
1397
+ ```
1398
+
1399
+ ## Python Quickstart
1400
+
1401
+ ```python
1402
+ from flashtrace import FlashTrace, load_model_and_tokenizer
1403
+
1404
+ model, tokenizer = load_model_and_tokenizer("Qwen/Qwen3-8B")
1405
+ tracer = FlashTrace(model, tokenizer)
1406
+
1407
+ trace = tracer.trace(
1408
+ prompt="Context: Paris is the capital of France.\nQuestion: What is the capital of France?",
1409
+ target="Paris",
1410
+ output_span=(0, 0),
1411
+ hops=1,
1412
+ )
1413
+
1414
+ print(trace.topk_inputs(10))
1415
+ trace.to_html("trace.html")
1416
+ trace.to_json("trace.json")
1417
+ ```
1418
+
1419
+ ## CLI Quickstart
1420
+
1421
+ ```bash
1422
+ flashtrace trace \
1423
+ --model Qwen/Qwen3-8B \
1424
+ --prompt prompt.txt \
1425
+ --target target.txt \
1426
+ --output-span 0:0 \
1427
+ --hops 1 \
1428
+ --html trace.html \
1429
+ --json trace.json
1430
+ ```
1431
+
1432
+ ## Token Spans
1433
+
1434
+ `output_span` and `reasoning_span` use inclusive generation-token indices. Inspect `trace.generation_tokens` after an initial run to choose spans for a target answer or reasoning segment.
1435
+
1436
+ ## Supported Models
1437
+
1438
+ The package targets Llama/Qwen-style decoder-only Hugging Face causal LMs with standard Q/K/V/O projections, RMSNorm or LayerNorm, and RoPE metadata. Qwen2, Qwen3, and Llama are the first validated model families.
1439
+
1440
+ ## Repository Map
1441
+
1442
+ - `flashtrace/`: reusable package
1443
+ - `examples/`: public examples
1444
+ - `tests/`: CPU smoke tests
1445
+ - `exp/`: paper experiments and artifacts
1446
+
1447
+ ## Citation
1448
+
1449
+ ```bibtex
1450
+ @misc{pan2026flashtrace,
1451
+ title={Towards Long-Horizon Interpretability: Efficient and Faithful Multi-Token Attribution for Reasoning LLMs},
1452
+ author={Pan, Wenbo and Liu, Zhichao and Wang, Xianlong and Yu, Haining and Jia, Xiaohua},
1453
+ year={2026},
1454
+ eprint={2602.01914},
1455
+ archivePrefix={arXiv},
1456
+ primaryClass={cs.LG}
1457
+ }
1458
+ ```
1459
+ ````
1460
+
1461
+ - [ ] **Step 3: Add MIT license**
1462
+
1463
+ Create `LICENSE`:
1464
+
1465
+ ```text
1466
+ MIT License
1467
+
1468
+ Copyright (c) 2026 Wenbo Pan
1469
+
1470
+ Permission is hereby granted, free of charge, to any person obtaining a copy
1471
+ of this software and associated documentation files (the "Software"), to deal
1472
+ in the Software without restriction, including without limitation the rights
1473
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
1474
+ copies of the Software, and to permit persons to whom the Software is
1475
+ furnished to do so, subject to the following conditions:
1476
+
1477
+ The above copyright notice and this permission notice shall be included in all
1478
+ copies or substantial portions of the Software.
1479
+
1480
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1481
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1482
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
1483
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1484
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
1485
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1486
+ SOFTWARE.
1487
+ ```
1488
+
1489
+ - [ ] **Step 4: Update generated artifact ignore rules**
1490
+
1491
+ Append to `.gitignore`:
1492
+
1493
+ ```gitignore
1494
+
1495
+ # FlashTrace generated artifacts
1496
+ trace.json
1497
+ trace.html
1498
+ *.trace.json
1499
+ *.trace.html
1500
+ exp/**/output/
1501
+ exp/**/out/
1502
+ exp/**/out-*/
1503
+ *.npz
1504
+ ```
1505
+
1506
+ - [ ] **Step 5: Remove the template artifact**
1507
+
1508
+ Run:
1509
+
1510
+ ```bash
1511
+ git rm model_generation.py
1512
+ ```
1513
+
1514
+ - [ ] **Step 6: Verify quickstart help**
1515
+
1516
+ Run:
1517
+
1518
+ ```bash
1519
+ uv run python examples/quickstart.py --help
1520
+ ```
1521
+
1522
+ Expected: help text includes `FlashTrace quickstart example`.
1523
+
1524
+ - [ ] **Step 7: Commit**
1525
+
1526
+ Run:
1527
+
1528
+ ```bash
1529
+ git add README.md LICENSE examples/quickstart.py .gitignore
1530
+ git commit -m "docs: add public quickstart and release hygiene"
1531
+ ```
1532
+
1533
+ ## Task 8: Final Verification And Package Audit
1534
+
1535
+ **Files:**
1536
+ - Modify: any files needed to fix verification failures.
1537
+
1538
+ - [ ] **Step 1: Run the full CPU test suite**
1539
+
1540
+ Run:
1541
+
1542
+ ```bash
1543
+ uv run pytest tests -q
1544
+ ```
1545
+
1546
+ Expected: all tests pass.
1547
+
1548
+ - [ ] **Step 2: Verify editable install import**
1549
+
1550
+ Run:
1551
+
1552
+ ```bash
1553
+ uv run python -c "import flashtrace; print(flashtrace.FlashTrace.__name__)"
1554
+ ```
1555
+
1556
+ Expected: prints `FlashTrace`.
1557
+
1558
+ - [ ] **Step 3: Verify CLI help**
1559
+
1560
+ Run:
1561
+
1562
+ ```bash
1563
+ uv run flashtrace --help
1564
+ uv run flashtrace trace --help
1565
+ ```
1566
+
1567
+ Expected: both commands print help text.
1568
+
1569
+ - [ ] **Step 4: Verify root compatibility imports**
1570
+
1571
+ Run:
1572
+
1573
+ ```bash
1574
+ uv run python -c "import ifr_core, llm_attr, ft_ifr_improve; print(ifr_core.compute_multi_hop_ifr.__name__)"
1575
+ ```
1576
+
1577
+ Expected: prints `compute_multi_hop_ifr`.
1578
+
1579
+ - [ ] **Step 5: Inspect package file list**
1580
+
1581
+ Run:
1582
+
1583
+ ```bash
1584
+ git status --short
1585
+ find flashtrace -maxdepth 3 -type f | sort
1586
+ ```
1587
+
1588
+ Expected: package files match the design spec and only intended changes appear.
1589
+
1590
+ - [ ] **Step 6: Commit final fixes**
1591
+
1592
+ Run after any verification fixes:
1593
+
1594
+ ```bash
1595
+ git add .
1596
+ git commit -m "test: verify public package smoke tests"
1597
+ ```
1598
+
1599
+ If verification passes with a clean tree after prior commits, record the passing commands in the final implementation response.
1600
+
1601
+ ## Self-Review Checklist
1602
+
1603
+ - Spec coverage: package layout, public API, result export, CLI, visualization, packaging, compatibility, tests, README, and release hygiene each have at least one task.
1604
+ - Type consistency: `FlashTrace.trace`, `TraceResult`, `TokenScore`, `load_model_and_tokenizer`, and CLI span parsing use the same names across tests and implementation steps.
1605
+ - Test path: every implementation task starts with a failing test or a verification command, then ends with a passing command and commit.
docs/superpowers/specs/2026-05-03-flashtrace-public-package-design.md ADDED
@@ -0,0 +1,231 @@
1
+ # FlashTrace Public Package Design
2
+
3
+ ## Goal
4
+
5
+ Turn the current FlashTrace research repository into an installable, documented Python package that researchers can use from Python or the command line to trace LLM outputs, export JSON traces, and render HTML token heatmaps.
6
+
7
+ ## Release Scope
8
+
9
+ This first public release ships four user-facing capabilities:
10
+
11
+ - A stable Python API centered on `FlashTrace`.
12
+ - A `flashtrace trace` CLI for prompt/target files and Hugging Face model ids or local model paths.
13
+ - A `TraceResult` object with top-k, JSON, and HTML export helpers.
14
+ - A README quickstart that demonstrates Python, CLI, and heatmap workflows.
15
+
16
+ Paper experiment runners and saved experiment artifacts remain in `exp/` as research assets. Their full reproducibility cleanup belongs to a later phase.
17
+
18
+ ## Repository Shape
19
+
20
+ The reusable package lives under `flashtrace/`, examples under `examples/`, tests under `tests/`, and paper experiments under `exp/`.
21
+
22
+ ```text
23
+ flashtrace/
24
+ __init__.py
25
+ tracer.py
26
+ result.py
27
+ core.py
28
+ model_io.py
29
+ viz.py
30
+ cli.py
31
+ baselines/
32
+ __init__.py
33
+ attnlrp.py
34
+ examples/
35
+ quickstart.py
36
+ tests/
37
+ test_core_recompute.py
38
+ test_tracer.py
39
+ test_result.py
40
+ test_cli.py
41
+ exp/
42
+ exp1/
43
+ exp2/
44
+ case_study/
45
+ ```
46
+
47
+ Existing root modules are migrated gradually. During migration, compatibility wrappers remain at the root for experiment scripts that still import `llm_attr`, `ifr_core`, or `ft_ifr_improve`.
48
+
49
+ ## Core Implementation Mapping
50
+
51
+ `flashtrace.core` contains the IFR tensor implementation from `ifr_core.py`:
52
+
53
+ - `extract_model_metadata`
54
+ - `build_weight_pack`
55
+ - `attach_hooks`
56
+ - `recompute_layer_attention`
57
+ - `compute_ifr_sentence_aggregate`
58
+ - `compute_multi_hop_ifr`
59
+ - `compute_ifr_for_all_positions`
60
+
61
+ `flashtrace.tracer` wraps the current high-level attribution classes:
62
+
63
+ - Default `method="flashtrace"` uses the current `LLMIFRAttributionBoth.calculate_ifr_multi_hop_both` behavior.
64
+ - `method="ifr-span"` uses `LLMIFRAttribution.calculate_ifr_span`.
65
+ - `method="ifr-matrix"` uses `LLMIFRAttribution.calculate_ifr_for_all_positions_output_only`.
66
+
67
+ `flashtrace.baselines.attnlrp` contains the AttnLRP patching and recursive baseline code from `lrp_rules.py`, `lrp_patches.py`, and `LLMLRPAttribution`.
68
+
69
+ ## Public Python API
70
+
71
+ The package exports `FlashTrace`, `TraceResult`, and `load_model_and_tokenizer`.
72
+
73
+ ```python
74
+ from flashtrace import FlashTrace, load_model_and_tokenizer
75
+
76
+ model, tokenizer = load_model_and_tokenizer("Qwen/Qwen3-8B", device_map="auto")
77
+ tracer = FlashTrace(model, tokenizer, chunk_tokens=128, sink_chunk_tokens=32)
78
+
79
+ trace = tracer.trace(
80
+ prompt=prompt,
81
+ target=target,
82
+ output_span=(80, 85),
83
+ reasoning_span=(0, 79),
84
+ hops=1,
85
+ )
86
+
87
+ print(trace.topk_inputs(20))
88
+ trace.to_json("trace.json")
89
+ trace.to_html("trace.html")
90
+ ```
91
+
92
+ `FlashTrace.trace(...)` accepts:
93
+
94
+ - `prompt: str`
95
+ - `target: str | None`
96
+ - `output_span: tuple[int, int] | None`
97
+ - `reasoning_span: tuple[int, int] | None`
98
+ - `hops: int`
99
+ - `method: Literal["flashtrace", "ifr-span", "ifr-matrix"]`
100
+ - `renorm_threshold: float | None`
101
+
102
+ Generation-token spans are inclusive and use the tokenizer alignment already produced by the attribution path. The README explains this convention and shows how to inspect `trace.generation_tokens`.
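+
+ A minimal sketch of that span-selection workflow, reusing the `tracer`, `prompt`, and `target` names from the example above (the indices in the comments are illustrative, not real tokenizer output):
+
+ ```python
+ # First pass: trace without spans so the result exposes every generation token.
+ probe = tracer.trace(prompt=prompt, target=target)
+ for index, token in enumerate(probe.generation_tokens):
+     print(index, repr(token))  # read off the inclusive answer range, e.g. (80, 85)
+
+ # Second pass: attribute only the chosen inclusive answer span.
+ trace = tracer.trace(prompt=prompt, target=target, output_span=(80, 85))
+ ```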
103
+
104
+ ## TraceResult
105
+
106
+ `TraceResult` is a small dataclass that hides the older `LLMAttributionResult` shape from public users.
107
+
108
+ Fields:
109
+
110
+ - `prompt_tokens: list[str]`
111
+ - `generation_tokens: list[str]`
112
+ - `scores: list[float]`
113
+ - `per_hop_scores: list[list[float]]`
114
+ - `thinking_ratios: list[float]`
115
+ - `output_span: tuple[int, int] | None`
116
+ - `reasoning_span: tuple[int, int] | None`
117
+ - `method: str`
118
+ - `metadata: dict[str, Any]`
119
+
120
+ Methods:
121
+
122
+ - `topk_inputs(k: int = 20) -> list[TokenScore]`
123
+ - `to_dict() -> dict[str, Any]`
124
+ - `to_json(path: str | Path) -> None`
125
+ - `to_html(path: str | Path) -> None`
126
+
127
+ `TokenScore` contains `index`, `token`, and `score`. Scores are aligned to `prompt_tokens`.
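+
+ A minimal sketch of these shapes (illustrative only: the concrete ranking rule, field defaults, and serialization details belong to `flashtrace/result.py` in the implementation plan):
+
+ ```python
+ import json
+ from dataclasses import asdict, dataclass, field
+ from pathlib import Path
+ from typing import Any
+
+
+ @dataclass
+ class TokenScore:
+     index: int
+     token: str
+     score: float
+
+
+ @dataclass
+ class TraceResult:
+     prompt_tokens: list[str]
+     generation_tokens: list[str]
+     scores: list[float]
+     per_hop_scores: list[list[float]] = field(default_factory=list)
+     thinking_ratios: list[float] = field(default_factory=list)
+     output_span: tuple[int, int] | None = None
+     reasoning_span: tuple[int, int] | None = None
+     method: str = "flashtrace"
+     metadata: dict[str, Any] = field(default_factory=dict)
+
+     def topk_inputs(self, k: int = 20) -> list[TokenScore]:
+         # Assumption: rank prompt tokens by absolute score, largest first.
+         order = sorted(range(len(self.scores)), key=lambda i: abs(self.scores[i]), reverse=True)
+         return [TokenScore(i, self.prompt_tokens[i], self.scores[i]) for i in order[:k]]
+
+     def to_dict(self) -> dict[str, Any]:
+         return asdict(self)
+
+     def to_json(self, path: str | Path) -> None:
+         Path(path).write_text(json.dumps(self.to_dict(), ensure_ascii=False, indent=2), encoding="utf-8")
+ ```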
128
+
129
+ ## CLI
130
+
131
+ The package exposes one console script:
132
+
133
+ ```bash
134
+ flashtrace trace \
135
+ --model Qwen/Qwen3-8B \
136
+ --prompt prompt.txt \
137
+ --target target.txt \
138
+ --output-span 80:85 \
139
+ --reasoning-span 0:79 \
140
+ --hops 1 \
141
+ --html trace.html \
142
+ --json trace.json
143
+ ```
144
+
145
+ CLI behavior:
146
+
147
+ - `--model` accepts a Hugging Face id or local path.
148
+ - `--prompt` and `--target` read UTF-8 text files.
149
+ - `--target` is optional; the model generates with deterministic defaults when this flag is absent.
150
+ - `--output-span` and `--reasoning-span` parse inclusive `START:END` generation-token spans.
151
+ - `--method` defaults to `flashtrace`.
152
+ - `--recompute-attention` enables lower-memory attention recomputation.
153
+ - `--device-map` defaults to `auto`.
154
+ - `--dtype` accepts `auto`, `float16`, `bfloat16`, or `float32`.
155
+
156
+ The command prints a compact top-k table to stdout and writes requested artifacts.
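+
+ The span flags follow the `parse_span` contract from the implementation plan; a small illustration:
+
+ ```python
+ from flashtrace.cli import parse_span
+
+ assert parse_span("80:85") == (80, 85)  # inclusive generation-token indices
+ assert parse_span(None) is None         # both span flags are optional
+ ```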
157
+
158
+ ## Visualization
159
+
160
+ `flashtrace.viz` adapts the token heatmap renderer from `exp/case_study/viz.py`.
161
+
162
+ The public heatmap focuses on:
163
+
164
+ - prompt tokens colored by final attribution score,
165
+ - optional per-hop panels,
166
+ - output and reasoning span summary,
167
+ - model/method metadata.
168
+
169
+ The renderer returns a standalone HTML string and writes standalone HTML files through `TraceResult.to_html`.
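+
+ The intended wiring is a thin layer over that renderer. A sketch, assuming `render_trace_html` is exported from `flashtrace.viz` as in the implementation plan and reusing the `trace` object from the API example above:
+
+ ```python
+ from pathlib import Path
+
+ from flashtrace.viz import render_trace_html
+
+ html = render_trace_html(trace)                        # standalone HTML string
+ Path("trace.html").write_text(html, encoding="utf-8")  # equivalent to trace.to_html("trace.html")
+ ```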
170
+
171
+ ## Packaging
172
+
173
+ `pyproject.toml` becomes package metadata for `flashtrace`:
174
+
175
+ - `name = "flashtrace"`
176
+ - a `requires-python` range that matches the Python versions supported by current PyTorch and Transformers,
177
+ - console script `flashtrace = "flashtrace.cli:main"`,
178
+ - core dependencies: `torch`, `transformers`, `accelerate`, `numpy`, `tqdm`,
179
+ - optional extras: `viz`, `eval`, `dev`, `baselines`.
180
+
181
+ The root README includes:
182
+
183
+ - project tagline,
184
+ - paper link and citation,
185
+ - install instructions,
186
+ - Python quickstart,
187
+ - CLI quickstart,
188
+ - supported model family notes,
189
+ - output interpretation,
190
+ - experiment directory map,
191
+ - troubleshooting for GPU memory and tokenizer spans.
192
+
193
+ ## Compatibility
194
+
195
+ The release supports Llama/Qwen-style decoder-only Hugging Face causal LMs with `model.layers`, Q/K/V/O projections, RMSNorm/LayerNorm, and RoPE metadata. The README names Qwen2, Qwen3, and Llama as validated families.
196
+
197
+ Existing experiment scripts continue to run through temporary root-level compatibility modules while package imports are introduced. A later cleanup can remove the compatibility layer after `exp/` imports are migrated.
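+
+ One possible shape for such a root-level shim, using `ifr_core` as the example (illustrative; the shipped wrapper may differ):
+
+ ```python
+ # ifr_core.py -- root-level compatibility shim so legacy `import ifr_core` keeps working.
+ # Re-exports the package implementation listed under "Core Implementation Mapping".
+ from flashtrace.core import (  # noqa: F401
+     attach_hooks,
+     build_weight_pack,
+     compute_ifr_for_all_positions,
+     compute_ifr_sentence_aggregate,
+     compute_multi_hop_ifr,
+     extract_model_metadata,
+     recompute_layer_attention,
+ )
+ ```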
198
+
199
+ ## Testing
200
+
201
+ Tests use a tiny randomly initialized Qwen2 model on CPU, following the existing `test_recompute.py` approach.
202
+
203
+ Required coverage:
204
+
205
+ - stored-attention and recomputed-attention paths return close values on the tiny model,
206
+ - `FlashTrace.trace(...)` returns a `TraceResult`,
207
+ - `TraceResult.topk_inputs(...)` sorts and truncates correctly,
208
+ - `TraceResult.to_dict()` is JSON serializable,
209
+ - `TraceResult.to_html()` writes standalone HTML containing token spans,
210
+ - `flashtrace trace --help` exits successfully.
211
+
212
+ Heavy GPU model tests remain manual examples.
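+
+ A sketch of the kind of CPU-only result check meant here (illustrative; it assumes `TraceResult` can be constructed directly with the fields listed above):
+
+ ```python
+ import json
+
+ from flashtrace import TraceResult
+
+
+ def test_topk_inputs_sorts_and_truncates():
+     result = TraceResult(
+         prompt_tokens=["a", "b", "c"],
+         generation_tokens=["x"],
+         scores=[0.1, 0.9, 0.5],
+         per_hop_scores=[],
+         thinking_ratios=[],
+         output_span=None,
+         reasoning_span=None,
+         method="flashtrace",
+         metadata={},
+     )
+     top = result.topk_inputs(2)
+     assert len(top) == 2 and top[0].token == "b"  # highest-score token first
+     assert json.dumps(result.to_dict())           # to_dict() stays JSON serializable
+ ```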
213
+
214
+ ## Release Hygiene
215
+
216
+ The release cleanup updates `.gitignore` to cover generated traces, experiment outputs, checkpoints, caches, and HTML/JSON artifacts created by examples.
217
+
218
+ Tracked historical experiment outputs stay untouched during the first package migration. A later artifact cleanup can move them to release assets or remove them with a dedicated confirmation step.
219
+
220
+ `model_generation.py` is a template artifact and is removed or moved outside the package path during implementation.
221
+
222
+ ## Success Criteria
223
+
224
+ The release work is complete when:
225
+
226
+ - `pip install -e .` exposes `flashtrace`,
227
+ - `python examples/quickstart.py --help` works,
228
+ - `flashtrace trace --help` works,
229
+ - package smoke tests pass on CPU,
230
+ - README quickstart matches the implemented API,
231
+ - existing experiment entrypoints either run with compatibility imports or document their package-era invocation.
dump_exp2_hop_vh.py ADDED
@@ -0,0 +1,412 @@
1
+ #!/usr/bin/env python3
2
+ """One-off: add per-hop IFR vectors (vh) into an existing exp2 trace .npz.
3
+
4
+ This is useful when the original exp2 run saved sample-level traces but did not
5
+ include per-hop vectors for some multi-hop IFR variants (e.g. ifr_multi_hop_both).
6
+
7
+ Defaults are written to match the reference commands in `exp/exp2/README.md`.
8
+
9
+ Example (matches the path in the question):
10
+
11
+ python dump_exp2_hop_vh.py \
12
+ --trace_npz exp/exp2/output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/ifr_multi_hop_both_n1_mfaithfulness_gen_95ex/ex_000026.npz \
13
+ --dataset exp/exp2/data/morehopqa.jsonl \
14
+ --attr_func ifr_multi_hop_both \
15
+ --model qwen-8B \
16
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
17
+ --cuda 2,3,4,5,6,7 \
18
+ --n_hops 1 \
19
+ --chunk_tokens 128 \
20
+ --sink_chunk_tokens 32 \
21
+ --inplace
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import argparse
27
+ import hashlib
28
+ import json
29
+ import os
30
+ import re
31
+ import sys
32
+ from dataclasses import dataclass
33
+ from pathlib import Path
34
+ from typing import Any, Optional
35
+
36
+
37
+ def _early_set_cuda_visible_devices() -> None:
38
+ parser = argparse.ArgumentParser(add_help=False)
39
+ parser.add_argument("--cuda", type=str, default=None)
40
+ args, _ = parser.parse_known_args(sys.argv[1:])
41
+ if args.cuda and "," in str(args.cuda):
42
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)
43
+
44
+
45
+ _early_set_cuda_visible_devices()
46
+
47
+ import numpy as np
48
+ import torch
49
+ from transformers import AutoModelForCausalLM, AutoTokenizer
50
+
51
+ import ft_ifr_improve
52
+ import llm_attr
53
+ from exp.exp2 import dataset_utils as ds_utils
54
+
55
+
56
+ def _sha1_text(text: str) -> str:
57
+ return hashlib.sha1(text.encode("utf-8")).hexdigest()
58
+
59
+
60
+ def _resolve_device(cuda: Optional[str], cuda_num: int) -> str:
61
+ """Mirror exp/exp2/run_exp.py device selection policy."""
62
+ if cuda is not None and "," in cuda:
63
+ # _early_set_cuda_visible_devices already applied.
64
+ return "auto"
65
+ if cuda is not None and str(cuda).strip():
66
+ return f"cuda:{cuda}" if torch.cuda.is_available() else "cpu"
67
+ return f"cuda:{int(cuda_num)}" if torch.cuda.is_available() else "cpu"
68
+
69
+
70
+ def _load_model(model_name: str, device: str):
71
+ """Mirror exp/exp2/run_exp.py model loading knobs."""
72
+ model = AutoModelForCausalLM.from_pretrained(
73
+ model_name,
74
+ device_map="auto" if device == "auto" else {"": int(device.split(":")[1])} if device.startswith("cuda:") else None,
75
+ torch_dtype=torch.float16,
76
+ attn_implementation="eager",
77
+ )
78
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
79
+ tokenizer.pad_token = tokenizer.eos_token
80
+ model.eval()
81
+ return model, tokenizer
82
+
83
+
84
+ @dataclass(frozen=True)
85
+ class ManifestRecord:
86
+ example_idx: int
87
+ prompt_sha1: str
88
+ target_sha1: Optional[str]
89
+
90
+
91
+ def _load_manifest_record(manifest_path: Path, *, example_idx: int) -> Optional[ManifestRecord]:
92
+ if not manifest_path.exists():
93
+ return None
94
+ with manifest_path.open("r", encoding="utf-8") as f:
95
+ for line in f:
96
+ if not line.strip():
97
+ continue
98
+ obj = json.loads(line)
99
+ if int(obj.get("example_idx", -1)) != int(example_idx):
100
+ continue
101
+ return ManifestRecord(
102
+ example_idx=int(example_idx),
103
+ prompt_sha1=str(obj.get("prompt_sha1") or ""),
104
+ target_sha1=str(obj["target_sha1"]) if obj.get("target_sha1") is not None else None,
105
+ )
106
+ return None
107
+
108
+
109
+ def _parse_example_idx_from_npz_name(path: Path) -> Optional[int]:
110
+ m = re.match(r"^ex_(\d+)$", path.stem)
111
+ if not m:
112
+ return None
113
+ try:
114
+ return int(m.group(1))
115
+ except Exception:
116
+ return None
117
+
118
+
119
+ def _pick_example(
120
+ examples: list[ds_utils.CachedExample],
121
+ *,
122
+ example_idx: int,
123
+ record: Optional[ManifestRecord],
124
+ ) -> ds_utils.CachedExample:
125
+ if record is not None and record.prompt_sha1:
126
+ matches: list[ds_utils.CachedExample] = []
127
+ for ex in examples:
128
+ if _sha1_text(ex.prompt) != record.prompt_sha1:
129
+ continue
130
+ if record.target_sha1 is None:
131
+ if ex.target is None:
132
+ matches.append(ex)
133
+ else:
134
+ if ex.target is not None and _sha1_text(ex.target) == record.target_sha1:
135
+ matches.append(ex)
136
+ if len(matches) == 1:
137
+ return matches[0]
138
+ if len(matches) > 1:
139
+ raise SystemExit(
140
+ f"Manifest sha1 matched multiple dataset entries ({len(matches)}). "
141
+ "Please pass --example_idx to select by index or use a smaller dataset cache."
142
+ )
143
+ raise SystemExit(
144
+ "Failed to locate the trace example in the provided dataset by sha1. "
145
+ "Ensure --dataset points to the same cached JSONL used to produce the trace."
146
+ )
147
+
148
+ if not (0 <= int(example_idx) < len(examples)):
149
+ raise SystemExit(f"example_idx out of range: {example_idx} not in [0, {len(examples)}).")
150
+ return examples[int(example_idx)]
151
+
152
+
153
+ def _extract_vh(attr: Any) -> np.ndarray:
154
+ ifr = (getattr(attr, "metadata", None) or {}).get("ifr") or {}
155
+ per_hop = ifr.get("per_hop_projected") or []
156
+ if not per_hop:
157
+ raise RuntimeError("Attribution result missing metadata['ifr']['per_hop_projected']; cannot build vh.")
158
+ stacked = torch.stack([torch.as_tensor(v, dtype=torch.float32).reshape(-1) for v in per_hop], dim=0)
159
+ return stacked.detach().cpu().numpy().astype(np.float32, copy=False)
160
+
161
+
162
+ def _run_ifr_attr(
163
+ attr_func: str,
164
+ *,
165
+ model: Any,
166
+ tokenizer: Any,
167
+ prompt: str,
168
+ target: str,
169
+ sink_span: Optional[tuple[int, int]],
170
+ thinking_span: Optional[tuple[int, int]],
171
+ n_hops: int,
172
+ chunk_tokens: int,
173
+ sink_chunk_tokens: int,
174
+ ) -> Any:
175
+ if attr_func == "ifr_multi_hop":
176
+ attributor = llm_attr.LLMIFRAttribution(
177
+ model,
178
+ tokenizer,
179
+ chunk_tokens=chunk_tokens,
180
+ sink_chunk_tokens=sink_chunk_tokens,
181
+ )
182
+ return attributor.calculate_ifr_multi_hop(
183
+ prompt,
184
+ target=target,
185
+ sink_span=sink_span,
186
+ thinking_span=thinking_span,
187
+ n_hops=int(n_hops),
188
+ )
189
+ if attr_func == "ifr_in_all_gen":
190
+ attributor = ft_ifr_improve.LLMIFRAttributionInAllGen(
191
+ model,
192
+ tokenizer,
193
+ chunk_tokens=chunk_tokens,
194
+ sink_chunk_tokens=sink_chunk_tokens,
195
+ )
196
+ return attributor.calculate_ifr_in_all_gen(
197
+ prompt,
198
+ target=target,
199
+ sink_span=sink_span,
200
+ thinking_span=thinking_span,
201
+ n_hops=int(n_hops),
202
+ )
203
+ if attr_func == "ifr_multi_hop_stop_words":
204
+ attributor = ft_ifr_improve.LLMIFRAttributionImproved(
205
+ model,
206
+ tokenizer,
207
+ chunk_tokens=chunk_tokens,
208
+ sink_chunk_tokens=sink_chunk_tokens,
209
+ )
210
+ return attributor.calculate_ifr_multi_hop_stop_words(
211
+ prompt,
212
+ target=target,
213
+ sink_span=sink_span,
214
+ thinking_span=thinking_span,
215
+ n_hops=int(n_hops),
216
+ )
217
+ if attr_func == "ifr_multi_hop_both":
218
+ attributor = ft_ifr_improve.LLMIFRAttributionBoth(
219
+ model,
220
+ tokenizer,
221
+ chunk_tokens=chunk_tokens,
222
+ sink_chunk_tokens=sink_chunk_tokens,
223
+ )
224
+ return attributor.calculate_ifr_multi_hop_both(
225
+ prompt,
226
+ target=target,
227
+ sink_span=sink_span,
228
+ thinking_span=thinking_span,
229
+ n_hops=int(n_hops),
230
+ )
231
+ if attr_func == "ifr_multi_hop_split_hop":
232
+ attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(
233
+ model,
234
+ tokenizer,
235
+ chunk_tokens=chunk_tokens,
236
+ sink_chunk_tokens=sink_chunk_tokens,
237
+ )
238
+ return attributor.calculate_ifr_multi_hop_split_hop(
239
+ prompt,
240
+ target=target,
241
+ sink_span=sink_span,
242
+ thinking_span=thinking_span,
243
+ n_hops=int(n_hops),
244
+ )
245
+ raise SystemExit(
246
+ f"Unsupported --attr_func '{attr_func}'. "
247
+ "Supported (vh-capable IFR variants): "
248
+ "ifr_multi_hop, ifr_in_all_gen, ifr_multi_hop_stop_words, ifr_multi_hop_both, ifr_multi_hop_split_hop."
249
+ )
250
+
251
+
252
+ def _save_npz(
253
+ out_path: Path,
254
+ *,
255
+ payload: dict[str, np.ndarray],
256
+ inplace_src: Optional[Path] = None,
257
+ backup: bool = True,
258
+ overwrite_backup: bool = False,
259
+ ) -> None:
260
+ out_path.parent.mkdir(parents=True, exist_ok=True)
261
+ if inplace_src is not None:
262
+ if backup and inplace_src.exists():
263
+ backup_path = inplace_src.with_name(inplace_src.name + ".bak")
264
+ if overwrite_backup and backup_path.exists():
265
+ backup_path.unlink()
266
+ if not backup_path.exists():
267
+ backup_path.write_bytes(inplace_src.read_bytes())
268
+
269
+ # NOTE: numpy.savez* appends ".npz" if the filename does not already end with ".npz".
270
+ # So we must ensure our temporary path ends with ".npz", otherwise we'd write
271
+ # "<name>.tmp.npz" but later try to os.replace("<name>.tmp", ...).
272
+ tmp_path = out_path.with_name(out_path.stem + ".tmp.npz")
273
+ if tmp_path.exists():
274
+ tmp_path.unlink()
275
+ np.savez_compressed(tmp_path, **payload)
276
+ os.replace(tmp_path, out_path)
277
+ return
278
+
279
+ if out_path.exists():
280
+ raise SystemExit(f"Refusing to overwrite existing file: {out_path} (use --inplace).")
281
+ np.savez_compressed(out_path, **payload)
282
+
283
+
284
+ def main() -> None:
285
+ parser = argparse.ArgumentParser("One-off exp2 trace patcher: add per-hop vh vectors.")
286
+ parser.add_argument(
287
+ "--trace_npz",
288
+ type=str,
289
+ default=(
290
+ "exp/exp2/output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/"
291
+ "ifr_multi_hop_both_n1_mfaithfulness_gen_95ex/ex_000026.npz"
292
+ ),
293
+ help="Path to the existing exp2 trace npz (ex_*.npz).",
294
+ )
295
+ parser.add_argument(
296
+ "--dataset",
297
+ type=str,
298
+ default="exp/exp2/data/morehopqa.jsonl",
299
+ help="Path to the exp2 cached dataset JSONL used to produce the trace.",
300
+ )
301
+ parser.add_argument(
302
+ "--attr_func",
303
+ type=str,
304
+ default="ifr_multi_hop_both",
305
+ help="Attribution method to rerun (vh-capable IFR variants only).",
306
+ )
307
+ parser.add_argument("--example_idx", type=int, default=None, help="Override example_idx (0-based).")
308
+ parser.add_argument("--sample", type=int, default=None, help="If the original run used --sample, set it here.")
309
+ parser.add_argument("--seed", type=int, default=42, help="Seed for --sample shuffling (must match original).")
310
+
311
+ parser.add_argument("--model", type=str, default="qwen-8B", help="HF repo id (used when --model_path not set).")
312
+ parser.add_argument(
313
+ "--model_path",
314
+ type=str,
315
+ default="/opt/share/models/Qwen/Qwen3-8B/",
316
+ help="Local model path; overrides --model for loading (matches exp2 README examples).",
317
+ )
318
+ parser.add_argument(
319
+ "--cuda",
320
+ type=str,
321
+ default="2,3,4,5,6,7",
322
+ help="CUDA selection (same semantics as exp2): '0' or '0,1,2'.",
323
+ )
324
+ parser.add_argument("--cuda_num", type=int, default=0, help="Single-device index when --cuda not set.")
325
+
326
+ parser.add_argument("--chunk_tokens", type=int, default=128)
327
+ parser.add_argument("--sink_chunk_tokens", type=int, default=32)
328
+ parser.add_argument("--n_hops", type=int, default=1)
329
+
330
+ parser.add_argument(
331
+ "--inplace",
332
+ action="store_true",
333
+ help="Overwrite the trace npz in place (recommended so manifest.jsonl stays valid).",
334
+ )
335
+ parser.add_argument("--no_backup", action="store_true", help="Disable .bak creation when using --inplace.")
336
+ parser.add_argument(
337
+ "--overwrite_backup",
338
+ action="store_true",
339
+ help="Allow replacing an existing .bak when using --inplace.",
340
+ )
341
+ args = parser.parse_args()
342
+
343
+ trace_npz = Path(args.trace_npz)
344
+ if not trace_npz.exists():
345
+ raise SystemExit(f"Missing trace npz: {trace_npz}")
346
+
347
+ example_idx = args.example_idx
348
+ if example_idx is None:
349
+ example_idx = _parse_example_idx_from_npz_name(trace_npz)
350
+ if example_idx is None:
351
+ raise SystemExit("Failed to infer --example_idx from trace filename; please pass --example_idx explicitly.")
352
+
353
+ manifest_path = trace_npz.with_name("manifest.jsonl")
354
+ record = _load_manifest_record(manifest_path, example_idx=int(example_idx))
355
+
356
+ dataset_path = Path(args.dataset)
357
+ if not dataset_path.exists():
358
+ raise SystemExit(f"Missing cached dataset JSONL: {dataset_path}")
359
+ examples = ds_utils.load_cached(dataset_path, sample=args.sample, seed=args.seed)
360
+ ex = _pick_example(examples, example_idx=int(example_idx), record=record)
361
+
362
+ if ex.target is None:
363
+ raise SystemExit("Cached dataset example has target=None; this script requires cached targets (CoT+answer).")
364
+ prompt = ex.prompt
365
+ target = ex.target
366
+
367
+ sink_span = tuple(ex.sink_span) if ex.sink_span else None
368
+ thinking_span = tuple(ex.thinking_span) if ex.thinking_span else None
369
+
370
+ model_name = str(args.model_path or args.model).strip()
371
+ if not model_name:
372
+ raise SystemExit("Please set --model or --model_path.")
373
+ device = _resolve_device(args.cuda, args.cuda_num)
374
+ model, tokenizer = _load_model(model_name, device)
375
+
376
+ attr = _run_ifr_attr(
377
+ str(args.attr_func),
378
+ model=model,
379
+ tokenizer=tokenizer,
380
+ prompt=prompt,
381
+ target=target,
382
+ sink_span=sink_span,
383
+ thinking_span=thinking_span,
384
+ n_hops=int(args.n_hops),
385
+ chunk_tokens=int(args.chunk_tokens),
386
+ sink_chunk_tokens=int(args.sink_chunk_tokens),
387
+ )
388
+ vh = _extract_vh(attr)
389
+
390
+ with np.load(trace_npz, allow_pickle=False) as old:
391
+ payload = {k: old[k] for k in old.files}
392
+ payload["vh"] = vh
393
+
394
+ if args.inplace:
395
+ out_path = trace_npz
396
+ else:
397
+ out_path = trace_npz.with_name(trace_npz.stem + "_with_vh.npz")
398
+
399
+ _save_npz(
400
+ out_path,
401
+ payload=payload,
402
+ inplace_src=trace_npz if args.inplace else None,
403
+ backup=not bool(args.no_backup),
404
+ overwrite_backup=bool(args.overwrite_backup),
405
+ )
406
+
407
+ print(f"Saved vh -> {out_path}")
408
+ print(f"vh shape: {vh.shape} (n_hops+1, prompt_len+gen_len)")
409
+
410
+
411
+ if __name__ == "__main__":
412
+ main()
evaluations/attribution_recovery.py ADDED
@@ -0,0 +1,490 @@
1
+ import os
2
+ import sys
3
+
4
+ # Ensure project root is importable regardless of CWD
5
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
+
7
+ import argparse
8
+ import csv
9
+ import json
10
+ import math
11
+ import random
12
+ import time
13
+ from itertools import islice
14
+ from pathlib import Path
15
+ from typing import List, Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+ from tqdm import tqdm
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer, utils
21
+
22
+ import llm_attr
23
+ import llm_attr_eval
24
+ from exp.exp2 import dataset_utils as ds_utils
25
+
26
+
27
+ utils.logging.set_verbosity_error()
28
+
29
+
30
+ def _first_json_obj(path: Path) -> dict:
31
+ with path.open("r", encoding="utf-8") as f:
32
+ for line in f:
33
+ line = line.strip()
34
+ if line:
35
+ return json.loads(line)
36
+ return {}
37
+
38
+
39
+ def _load_ruler_examples(args) -> Tuple[str, List[ds_utils.CachedExample]]:
40
+ ds_arg = args.dataset
41
+ cache_dir = Path(args.data_root)
42
+
43
+ # 1) If dataset points to an existing file, detect cache vs raw RULER.
44
+ p = Path(ds_arg)
45
+ if p.exists():
46
+ obj = _first_json_obj(p)
47
+ if "prompt" in obj:
48
+ return p.stem, ds_utils.load_cached(p, sample=args.sample, seed=args.seed)
49
+ if "input" in obj and "needle_spans" in obj:
50
+ return p.stem, ds_utils.load_ruler(p, sample=args.sample, seed=args.seed)
51
+ raise SystemExit(
52
+ f"Unsupported JSONL schema for recovery_ruler: {p}. "
53
+ "Expected either exp2 cache (has 'prompt') or raw RULER JSONL (has 'input'+'needle_spans')."
54
+ )
55
+
56
+ # 2) Prefer exp2 cache under --data_root by dataset name.
57
+ cached = cache_dir / f"{ds_arg}.jsonl"
58
+ if cached.exists():
59
+ return ds_arg, ds_utils.load_cached(cached, sample=args.sample, seed=args.seed)
60
+
61
+ # 3) Fall back to raw RULER resolution by name.
62
+ resolved = ds_utils.dataset_from_name(ds_arg)
63
+ if resolved is None:
64
+ raise SystemExit(f"Could not resolve RULER dataset name '{ds_arg}'.")
65
+ return ds_arg, ds_utils.load_ruler(resolved, sample=args.sample, seed=args.seed)
66
+
67
+
68
+ def _resolve_indices_to_explain_token_span(
69
+ attr_result: llm_attr.LLMAttributionResult, indices_to_explain: list[int] | None
70
+ ) -> list[int]:
71
+ if (
72
+ isinstance(indices_to_explain, list)
73
+ and len(indices_to_explain) == 2
74
+ and all(isinstance(x, int) and x >= 0 for x in indices_to_explain)
75
+ and indices_to_explain[0] <= indices_to_explain[1]
76
+ ):
77
+ return indices_to_explain
78
+
79
+ gen_len = int(attr_result.attribution_matrix.shape[0])
80
+ if gen_len <= 0:
81
+ return [0, 0]
82
+
83
+ # Default: explain the full generation excluding the appended EOS token.
84
+ end_tok = max(0, gen_len - 2)
85
+ return [0, end_tok]
86
+
87
+
88
+ def run_attribution(
89
+ testing_dict, example: ds_utils.CachedExample, batch_size: int, target: Optional[str]
90
+ ) -> tuple[List[torch.Tensor], dict | None]:
91
+ model = testing_dict["model"]
92
+ tokenizer = testing_dict["tokenizer"]
93
+ attr_func = testing_dict["attr_func"]
94
+
95
+ if "IG" in attr_func:
96
+ llm_attributor = llm_attr.LLMGradientAttribtion(model, tokenizer)
97
+ attr = llm_attributor.calculate_IG_per_generation(
98
+ example.prompt,
99
+ 20,
100
+ tokenizer.eos_token_id,
101
+ batch_size=batch_size,
102
+ target=target,
103
+ )
104
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
105
+ return list(attr.get_all_token_attrs(token_span)), None
106
+
107
+ if "perturbation" in attr_func:
108
+ llm_attributor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
109
+ if attr_func == "perturbation_all":
110
+ attr = llm_attributor.calculate_feature_ablation_sentences(
111
+ example.prompt, baseline=tokenizer.eos_token_id, measure="log_loss", target=target
112
+ )
113
+ elif attr_func == "perturbation_CLP":
114
+ attr = llm_attributor.calculate_feature_ablation_sentences(
115
+ example.prompt, baseline=tokenizer.eos_token_id, measure="KL", target=target
116
+ )
117
+ elif attr_func == "perturbation_REAGENT":
118
+ attr = llm_attributor.calculate_feature_ablation_sentences_mlm(example.prompt, target=target)
119
+ else:
120
+ raise ValueError(f"Unsupported perturbation attr_func {attr_func}")
121
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
122
+ return list(attr.get_all_token_attrs(token_span)), None
123
+
124
+ if "attention" in attr_func:
125
+ llm_attributor = llm_attr.LLMAttentionAttribution(model, tokenizer)
126
+ llm_attributor_ig = llm_attr.LLMGradientAttribtion(model, tokenizer)
127
+ attr = llm_attributor.calculate_attention_attribution(example.prompt, target=target)
128
+ if attr_func == "attention_I_G":
129
+ attr_b = llm_attributor_ig.calculate_IG_per_generation(
130
+ example.prompt, 20, tokenizer.eos_token_id, batch_size=batch_size, target=target
131
+ )
132
+ attr.attribution_matrix = attr.attribution_matrix * attr_b.attribution_matrix
133
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
134
+ return list(attr.get_all_token_attrs(token_span)), None
135
+
136
+ if attr_func == "ifr_all_positions":
137
+ llm_attributor = llm_attr.LLMIFRAttribution(model, tokenizer)
138
+ attr = llm_attributor.calculate_ifr_for_all_positions(example.prompt, target=target)
139
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
140
+ return list(attr.get_all_token_attrs(token_span)), None
141
+
142
+ if attr_func == "ifr_all_positions_output_only":
143
+ llm_attributor = llm_attr.LLMIFRAttribution(model, tokenizer)
144
+ sink_span = tuple(example.sink_span) if example.sink_span else None
145
+ attr = llm_attributor.calculate_ifr_for_all_positions_output_only(
146
+ example.prompt,
147
+ target=target,
148
+ sink_span=sink_span,
149
+ )
150
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
151
+ return list(attr.get_all_token_attrs(token_span)), None
152
+
153
+ if attr_func == "ifr_span":
154
+ llm_attributor = llm_attr.LLMIFRAttribution(model, tokenizer)
155
+ span = example.sink_span if example.sink_span else None
156
+ attr = llm_attributor.calculate_ifr_span(example.prompt, target=target, span=tuple(span) if span else None)
157
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
158
+ return list(attr.get_all_token_attrs(token_span)), None
159
+
160
+ if attr_func == "ifr_multi_hop":
161
+ llm_attributor = llm_attr.LLMIFRAttribution(model, tokenizer)
162
+ attr = llm_attributor.calculate_ifr_multi_hop(
163
+ example.prompt,
164
+ target=target,
165
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
166
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
167
+ n_hops=testing_dict.get("n_hops", 1),
168
+ )
169
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
170
+ return list(attr.get_all_token_attrs(token_span)), None
171
+
172
+ if attr_func == "ifr_in_all_gen":
173
+ import ft_ifr_improve
174
+
175
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionInAllGen(model, tokenizer)
176
+ attr = llm_attributor.calculate_ifr_in_all_gen(
177
+ example.prompt,
178
+ target=target,
179
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
180
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
181
+ n_hops=testing_dict.get("n_hops", 1),
182
+ )
183
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
184
+ return list(attr.get_all_token_attrs(token_span)), None
185
+
186
+ if attr_func == "ifr_multi_hop_stop_words":
187
+ import ft_ifr_improve
188
+
189
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionImproved(model, tokenizer)
190
+ attr = llm_attributor.calculate_ifr_multi_hop_stop_words(
191
+ example.prompt,
192
+ target=target,
193
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
194
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
195
+ n_hops=testing_dict.get("n_hops", 1),
196
+ )
197
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
198
+ extra = {
199
+ "keep_prompt_token_indices": ft_ifr_improve.keep_token_indices(list(attr.prompt_tokens)),
200
+ }
201
+ return list(attr.get_all_token_attrs(token_span)), extra
202
+
203
+ if attr_func == "ifr_multi_hop_both":
204
+ import ft_ifr_improve
205
+
206
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionBoth(model, tokenizer)
207
+ attr = llm_attributor.calculate_ifr_multi_hop_both(
208
+ example.prompt,
209
+ target=target,
210
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
211
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
212
+ n_hops=testing_dict.get("n_hops", 1),
213
+ )
214
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
215
+ extra = {
216
+ "keep_prompt_token_indices": ft_ifr_improve.keep_token_indices(list(attr.prompt_tokens)),
217
+ }
218
+ return list(attr.get_all_token_attrs(token_span)), extra
219
+
220
+ if attr_func == "ifr_multi_hop_split_hop":
221
+ import ft_ifr_improve
222
+
223
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(model, tokenizer)
224
+ attr = llm_attributor.calculate_ifr_multi_hop_split_hop(
225
+ example.prompt,
226
+ target=target,
227
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
228
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
229
+ n_hops=testing_dict.get("n_hops", 1),
230
+ )
231
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
232
+ return list(attr.get_all_token_attrs(token_span)), None
233
+
234
+ if attr_func == "basic":
235
+ llm_attributor = llm_attr.LLMBasicAttribution(model, tokenizer)
236
+ attr = llm_attributor.calculate_basic_attribution(example.prompt, target=target)
237
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
238
+ return list(attr.get_all_token_attrs(token_span)), None
239
+
240
+ if attr_func == "attnlrp":
241
+ llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
242
+ sink_span = getattr(example, "sink_span", None)
243
+ thinking_span = getattr(example, "thinking_span", None)
244
+ attr = llm_attributor.calculate_attnlrp_ft_hop0(
245
+ example.prompt,
246
+ target=target,
247
+ sink_span=tuple(sink_span) if sink_span else None,
248
+ thinking_span=tuple(thinking_span) if thinking_span else None,
249
+ )
250
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
251
+ return list(attr.get_all_token_attrs(token_span)), None
252
+
253
+ if attr_func == "attnlrp_aggregated":
254
+ llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
255
+ attr = llm_attributor.calculate_attnlrp_aggregated(example.prompt, target=target)
256
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
257
+ return list(attr.get_all_token_attrs(token_span)), None
258
+
259
+ if attr_func == "attnlrp_aggregated_multi_hop":
260
+ llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
261
+ attr = llm_attributor.calculate_attnlrp_aggregated_multi_hop(
262
+ example.prompt,
263
+ target=target,
264
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
265
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
266
+ n_hops=testing_dict.get("n_hops", 1),
267
+ )
268
+ token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
269
+ return list(attr.get_all_token_attrs(token_span)), None
270
+
271
+ raise ValueError(f"Unsupported attribution function '{attr_func}'.")
272
+
273
+
274
+ def evaluate_dataset_recovery_ruler(testing_dict, dataset_name: str, examples: List[ds_utils.CachedExample]) -> Tuple[np.ndarray, np.ndarray, float, int, int]:
275
+ tokenizer = testing_dict["tokenizer"]
276
+ llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(testing_dict["model"], tokenizer)
277
+
278
+ results: List[np.ndarray] = []
279
+ durations: List[float] = []
280
+ skipped = 0
281
+
282
+ num_examples = testing_dict["num_examples"]
283
+ total = min(len(examples), num_examples)
284
+ iterator = islice(examples, total)
285
+
286
+ description = f"Recovery@10pct {testing_dict['model_name']} {dataset_name} {testing_dict['attr_func']}"
287
+ for ex in tqdm(iterator, desc=description, total=total):
288
+ needle_spans = (ex.metadata or {}).get("needle_spans")
289
+ if not isinstance(needle_spans, list) or not needle_spans:
290
+ raise SystemExit(
291
+ "recovery_ruler only supports RULER examples with metadata.needle_spans; "
292
+ f"dataset={dataset_name} has missing/empty needle_spans."
293
+ )
294
+
295
+ gold_prompt = ds_utils.ruler_gold_prompt_token_indices(ex, tokenizer)
296
+ if not gold_prompt:
297
+ skipped += 1
298
+ continue
299
+
300
+ # Batch size is set based on the max_input_len (same policy as faithfulness).
301
+ target = ex.target
302
+ if target is None:
303
+ generation, full_output = llm_evaluator.response(ex.prompt)
304
+ target = generation
305
+ response_len = len(tokenizer(full_output).input_ids)
306
+ else:
307
+ response_len = len(tokenizer(llm_evaluator.format_prompt(" " + ex.prompt) + target).input_ids)
308
+ batch_size = max(1, math.floor((testing_dict["max_input_len"] - 100) / max(1, response_len)))
309
+
310
+ sample_start = time.perf_counter()
311
+ attr_list, extra = run_attribution(testing_dict, ex, batch_size, target)
312
+ durations.append(time.perf_counter() - sample_start)
313
+
314
+ seq_attr = attr_list[0]
315
+ prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0]) # cols=(P+G), rows=G
316
+ if prompt_len <= 0:
317
+ skipped += 1
318
+ continue
319
+
320
+ if testing_dict["attr_func"] in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and extra is not None:
321
+ import ft_ifr_improve
322
+
323
+ keep_prompt_token_indices = extra.get("keep_prompt_token_indices") or []
324
+ gold_filtered = [idx for idx in gold_prompt if int(idx) in set(int(x) for x in keep_prompt_token_indices)]
325
+ if not gold_filtered:
326
+ skipped += 1
327
+ continue
328
+ scores = [
329
+ ft_ifr_improve.evaluate_attr_recovery_skip_tokens(
330
+ attr[:, :prompt_len],
331
+ keep_prompt_token_indices=keep_prompt_token_indices,
332
+ gold_prompt_token_indices=gold_prompt,
333
+ top_fraction=0.1,
334
+ )
335
+ for attr in attr_list
336
+ ]
337
+ else:
338
+ scores = [
339
+ llm_evaluator.evaluate_attr_recovery(
340
+ attr,
341
+ prompt_len=prompt_len,
342
+ gold_prompt_token_indices=gold_prompt,
343
+ top_fraction=0.1,
344
+ )
345
+ for attr in attr_list
346
+ ]
347
+ results.append(np.asarray(scores, dtype=np.float64))
348
+
349
+ scores = np.stack(results, axis=0) if results else np.zeros((0, 3), dtype=np.float64)
350
+ used = int(scores.shape[0])
351
+ mean = scores.mean(0) if used else np.full((3,), np.nan, dtype=np.float64)
352
+ std = scores.std(0) if used else np.full((3,), np.nan, dtype=np.float64)
353
+ avg_time = float(np.mean(durations)) if durations else 0.0
354
+ return mean, std, avg_time, used, int(skipped)
355
+
356
+
357
+ def load_model(model_name: str, device: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
358
+ seed = 42
359
+ random.seed(seed)
360
+ np.random.seed(seed)
361
+ torch.manual_seed(seed)
362
+ torch.cuda.manual_seed(seed)
363
+ torch.cuda.manual_seed_all(seed)
364
+
365
+ if device == "auto":
366
+ model = AutoModelForCausalLM.from_pretrained(
367
+ model_name,
368
+ device_map="auto",
369
+ attn_implementation="eager",
370
+ torch_dtype=torch.float16,
371
+ )
372
+ elif isinstance(device, str) and device.startswith("cuda:"):
373
+ try:
374
+ gpu_idx = int(device.split(":")[1])
375
+ except Exception:
376
+ gpu_idx = 0
377
+ model = AutoModelForCausalLM.from_pretrained(
378
+ model_name,
379
+ device_map={"": gpu_idx},
380
+ attn_implementation="eager",
381
+ torch_dtype=torch.float16,
382
+ )
383
+ else:
384
+ model = AutoModelForCausalLM.from_pretrained(
385
+ model_name,
386
+ attn_implementation="eager",
387
+ torch_dtype=torch.float16,
388
+ )
389
+ model.eval()
390
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
391
+ tokenizer.pad_token = tokenizer.eos_token
392
+ return model, tokenizer
393
+
394
+
395
+ def main(args) -> None:
396
+ if args.cuda is not None and isinstance(args.cuda, str) and "," in args.cuda:
397
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
398
+ device = "auto"
399
+ elif args.cuda is not None and isinstance(args.cuda, str) and args.cuda.strip() != "":
400
+ try:
401
+ idx = int(args.cuda)
402
+ except Exception:
403
+ idx = 0
404
+ device = f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
405
+ else:
406
+ device = f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"
407
+
408
+ if args.model == "llama-1B":
409
+ model_name = "meta-llama/Llama-3.2-1B-Instruct"
410
+ max_input_len = 5500
411
+ elif args.model == "llama-3B":
412
+ model_name = "meta-llama/Llama-3.2-3B-Instruct"
413
+ max_input_len = 4800
414
+ elif args.model == "llama-8B":
415
+ model_name = "meta-llama/Llama-3.1-8B-Instruct"
416
+ max_input_len = 3500
417
+ elif args.model == "qwen-1.7B":
418
+ model_name = "Qwen/Qwen3-1.7B"
419
+ max_input_len = 5500
420
+ elif args.model == "qwen-4B":
421
+ model_name = "Qwen/Qwen3-4B-Instruct-2507"
422
+ max_input_len = 3500
423
+ elif args.model == "qwen-8B":
424
+ model_name = "Qwen/Qwen3-8B"
425
+ max_input_len = 3000
426
+ elif args.model == "qwen-32B":
427
+ model_name = "Qwen/Qwen3-32B"
428
+ max_input_len = 1500
429
+ elif args.model == "gemma-12B":
430
+ model_name = "gemma/gemma-3-12b-it"
431
+ max_input_len = 1500
432
+ elif args.model == "gemma-27B":
433
+ model_name = "gemma/gemma-3-27b-it"
434
+ max_input_len = 2000
435
+ else:
436
+ model_name = args.model_path if args.model_path is not None else args.model
437
+ max_input_len = 2000
438
+
439
+ model, tokenizer = load_model(model_name if args.model_path is None else args.model_path, device)
440
+
441
+ dataset_name, examples = _load_ruler_examples(args)
442
+
443
+ testing_dict = {
444
+ "model": model,
445
+ "model_name": args.model,
446
+ "tokenizer": tokenizer,
447
+ "dataset_name": dataset_name,
448
+ "attr_func": args.attr_func,
449
+ "num_examples": args.num_examples,
450
+ "max_input_len": max_input_len,
451
+ "n_hops": args.n_hops,
452
+ }
453
+
454
+ mean, std, avg_time, used, skipped = evaluate_dataset_recovery_ruler(testing_dict, dataset_name, examples)
455
+
456
+ out_dir = Path("./test_results") / "attribution_recovery" / dataset_name / args.model
457
+ out_dir.mkdir(parents=True, exist_ok=True)
458
+ file_name = f"{args.attr_func}_{args.num_examples}_examples.csv"
459
+ with open(out_dir / file_name, "w", newline="") as f:
460
+ writer = csv.writer(f)
461
+ writer.writerow(["Method", "Recovery@10pct"])
462
+ writer.writerow(["Seq Attr Recovery Mean", mean[0]])
463
+ writer.writerow(["Row Attr Recovery Mean", mean[1]])
464
+ writer.writerow(["Recursive Attr Recovery Mean", mean[2]])
465
+ writer.writerow(["Seq Attr Recovery Std", std[0]])
466
+ writer.writerow(["Row Attr Recovery Std", std[1]])
467
+ writer.writerow(["Recursive Attr Recovery Std", std[2]])
468
+ writer.writerow(["Examples Used", used])
469
+ writer.writerow(["Examples Skipped", skipped])
470
+ writer.writerow(["Avg Sample Time (s)", avg_time])
471
+
472
+ print(f"[{dataset_name}] {args.attr_func} -> {out_dir/file_name} (used={used} skipped={skipped} avg {avg_time:.2f}s)")
473
+
474
+
475
+ if __name__ == "__main__":
476
+ parser = argparse.ArgumentParser("RULER-only token-level attribution recovery evaluation (Recall@10pct).")
477
+ parser.add_argument("--num_examples", type=int, default=100, help="How many examples to evaluate.")
478
+ parser.add_argument("--sample", type=int, default=None, help="Optional subsample before num_examples.")
479
+ parser.add_argument("--seed", type=int, default=42)
480
+ parser.add_argument("--model", type=str, default="qwen-8B")
481
+ parser.add_argument("--model_path", type=str, default=None, help="Optional local model path to load.")
482
+ parser.add_argument("--attr_func", type=str, default="ifr_multi_hop")
483
+ parser.add_argument("--cuda_num", type=int, default=0)
484
+ parser.add_argument("--cuda", type=str, default=None)
485
+ parser.add_argument("--dataset", type=str, required=True, help="RULER dataset name or JSONL path (raw or exp2 cache).")
486
+ parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Cache directory to search by dataset name.")
487
+ parser.add_argument("--n_hops", type=int, default=3)
488
+
489
+ args, _ = parser.parse_known_args()
490
+ main(args)
evaluations/attribution_recovery.sh ADDED
@@ -0,0 +1,18 @@
1
+ # RULER-only token-level recovery (Recall@10pct) examples.
2
+ #
3
+ # Dataset can be:
4
+ # - a RULER name (hotpotqa_long / niah_* / vt_*) resolved under data/ruler_multihop/<len>/.../validation.jsonl
5
+ # - a raw RULER JSONL path
6
+ # - an exp2 cache JSONL path (must contain metadata.needle_spans)
7
+
8
+ # Example: evaluate on exp2 cache
9
+ # CUDA_VISIBLE_DEVICES=0 python3 evaluations/attribution_recovery.py \
10
+ # --model qwen-8B --model_path /opt/share/models/Qwen/Qwen3-8B/ \
11
+ # --cuda 0 --num_examples 50 --attr_func ifr_multi_hop \
12
+ # --dataset exp/exp2/data/hotpotqa.jsonl
13
+
14
+ # Example: evaluate on raw RULER JSONL
15
+ # CUDA_VISIBLE_DEVICES=0 python3 evaluations/attribution_recovery.py \
16
+ # --model qwen-8B --model_path /opt/share/models/Qwen/Qwen3-8B/ \
17
+ # --cuda 0 --num_examples 50 --attr_func ifr_multi_hop \
18
+ # --dataset data/ruler_multihop/4096/hotpotqa_long/validation.jsonl
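+
+ # Example (sketch): evaluate by RULER dataset name. This assumes the name resolves to a
+ # validation.jsonl as described above; adjust the name and layout to your local data.
+ # CUDA_VISIBLE_DEVICES=0 python3 evaluations/attribution_recovery.py \
+ #   --model qwen-8B --model_path /opt/share/models/Qwen/Qwen3-8B/ \
+ #   --cuda 0 --num_examples 50 --attr_func ifr_multi_hop \
+ #   --dataset hotpotqa_long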
evaluations/faithfulness.py ADDED
@@ -0,0 +1,491 @@
1
+ import os
2
+ import sys
3
+ # Ensure project root is importable regardless of CWD
4
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5
+
6
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
7
+ import torch
8
+ import numpy as np
9
+ from transformers import utils
10
+ import math
11
+ from tqdm import tqdm
12
+ import random
13
+ import argparse
14
+ import csv
15
+ from itertools import islice
16
+ from typing import Tuple
17
+ from huggingface_hub import login
18
+
19
+ from attribution_datasets import (
20
+ AttributionDataset,
21
+ FactsAttributionDataset,
22
+ MathAttributionDataset,
23
+ MoreHopQAAttributionDataset,
24
+ )
25
+
26
+ utils.logging.set_verbosity_error() # Suppress standard warnings
27
+
28
+ import llm_attr
29
+ import llm_attr_eval
30
+
31
+
32
+ def _resolve_indices_to_explain_token_span(
33
+ attr_result: llm_attr.LLMAttributionResult, indices_to_explain: list[int] | None
34
+ ) -> list[int]:
35
+ if (
36
+ isinstance(indices_to_explain, list)
37
+ and len(indices_to_explain) == 2
38
+ and all(isinstance(x, int) and x >= 0 for x in indices_to_explain)
39
+ and indices_to_explain[0] <= indices_to_explain[1]
40
+ ):
41
+ return indices_to_explain
42
+
43
+ gen_len = int(attr_result.attribution_matrix.shape[0])
44
+ if gen_len <= 0:
45
+ return [0, 0]
46
+
47
+ # Default: explain the full generation excluding the appended EOS token.
48
+ end_tok = max(0, gen_len - 2)
49
+ return [0, end_tok]
50
+
51
+
52
+ def run_attribution(testing_dict, prompt, batch_size, indices_to_explain = [1], target = None) -> tuple[list[torch.Tensor], dict | None]:
53
+ model = testing_dict["model"]
54
+ tokenizer = testing_dict["tokenizer"]
55
+
56
+ # Now we create an attribution for the full response
57
+ if "IG" in testing_dict["attr_func"]:
58
+ llm_attributor = llm_attr.LLMGradientAttribtion(model, tokenizer)
59
+
60
+ if testing_dict["attr_func"] == "IG":
61
+ attr = llm_attributor.calculate_IG_per_generation(prompt, 20, tokenizer.eos_token_id, batch_size = batch_size, target = target)
62
+
63
+ token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
64
+ attributions = list(attr.get_all_token_attrs(token_span))
65
+
66
+ elif "perturbation" in testing_dict["attr_func"]:
67
+ llm_attributor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
68
+
69
+ if testing_dict["attr_func"] == "perturbation_all":
70
+ attr = llm_attributor.calculate_feature_ablation_sentences(prompt, baseline = tokenizer.eos_token_id, measure="log_loss", target = target)
71
+ elif testing_dict["attr_func"] == "perturbation_CLP":
72
+ attr = llm_attributor.calculate_feature_ablation_sentences(prompt, baseline = tokenizer.eos_token_id, measure="KL", target = target)
73
+ elif testing_dict["attr_func"] == "perturbation_REAGENT":
74
+ attr = llm_attributor.calculate_feature_ablation_sentences_mlm(prompt, target = target)
75
+
76
+ token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
77
+ attributions = list(attr.get_all_token_attrs(token_span))
78
+
79
+ elif "attention" in testing_dict["attr_func"]:
80
+ llm_attributor = llm_attr.LLMAttentionAttribution(model, tokenizer)
81
+ llm_attributor_ig = llm_attr.LLMGradientAttribtion(model, tokenizer)
82
+
83
+ if testing_dict["attr_func"] == "attention_I_G":
84
+ attr = llm_attributor.calculate_attention_attribution(prompt, target = target)
85
+ attr_b = llm_attributor_ig.calculate_IG_per_generation(prompt, 20, tokenizer.eos_token_id, batch_size = batch_size, target = target)
86
+ attr.attribution_matrix = attr.attribution_matrix * attr_b.attribution_matrix
87
+
88
+ token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
89
+ attributions = list(attr.get_all_token_attrs(token_span))
90
+
91
+ elif "ifr" in testing_dict["attr_func"].lower():
92
+ llm_attributor = llm_attr.LLMIFRAttribution(model, tokenizer)
93
+ attr_func = testing_dict["attr_func"].lower()
94
+ renorm_threshold = testing_dict.get("renorm_threshold")
95
+
96
+ if attr_func == "ifr_all_positions":
97
+ attr = llm_attributor.calculate_ifr_for_all_positions(prompt, target=target, renorm_threshold=renorm_threshold)
98
+ elif attr_func == "ifr_all_positions_output_only":
99
+ attr = llm_attributor.calculate_ifr_for_all_positions_output_only(
100
+ prompt,
101
+ target=target,
102
+ sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
103
+ renorm_threshold=renorm_threshold,
104
+ )
105
+ elif attr_func == "ifr_span":
106
+ span = testing_dict.get("sink_span")
107
+ attr = llm_attributor.calculate_ifr_span(
108
+ prompt,
109
+ target=target,
110
+ span=tuple(span) if span is not None else None,
111
+ renorm_threshold=renorm_threshold,
112
+ )
113
+ elif attr_func == "ifr_multi_hop":
114
+ attr = llm_attributor.calculate_ifr_multi_hop(
115
+ prompt,
116
+ target=target,
117
+ sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
118
+ thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
119
+ n_hops=testing_dict.get("n_hops", 1),
120
+ renorm_threshold=renorm_threshold,
121
+ observation_mask=testing_dict.get("observation_mask"),
122
+ )
123
+ elif attr_func == "ifr_in_all_gen":
124
+ import ft_ifr_improve
125
+
126
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionInAllGen(model, tokenizer)
127
+ attr = llm_attributor.calculate_ifr_in_all_gen(
128
+ prompt,
129
+ target=target,
130
+ sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
131
+ thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
132
+ n_hops=testing_dict.get("n_hops", 1),
133
+ renorm_threshold=renorm_threshold,
134
+ observation_mask=testing_dict.get("observation_mask"),
135
+ )
136
+ elif attr_func == "ifr_multi_hop_stop_words":
137
+ import ft_ifr_improve
138
+
139
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionImproved(model, tokenizer)
140
+ attr = llm_attributor.calculate_ifr_multi_hop_stop_words(
141
+ prompt,
142
+ target=target,
143
+ sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
144
+ thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
145
+ n_hops=testing_dict.get("n_hops", 1),
146
+ renorm_threshold=renorm_threshold,
147
+ observation_mask=testing_dict.get("observation_mask"),
148
+ )
149
+ elif attr_func == "ifr_multi_hop_both":
150
+ import ft_ifr_improve
151
+
152
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionBoth(model, tokenizer)
153
+ attr = llm_attributor.calculate_ifr_multi_hop_both(
154
+ prompt,
155
+ target=target,
156
+ sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
157
+ thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
158
+ n_hops=testing_dict.get("n_hops", 1),
159
+ renorm_threshold=renorm_threshold,
160
+ observation_mask=testing_dict.get("observation_mask"),
161
+ )
162
+ elif attr_func == "ifr_multi_hop_split_hop":
163
+ import ft_ifr_improve
164
+
165
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(model, tokenizer)
166
+ attr = llm_attributor.calculate_ifr_multi_hop_split_hop(
167
+ prompt,
168
+ target=target,
169
+ sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
170
+ thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
171
+ n_hops=testing_dict.get("n_hops", 1),
172
+ renorm_threshold=renorm_threshold,
173
+ observation_mask=testing_dict.get("observation_mask"),
174
+ )
175
+ else:
176
+ raise ValueError(f"Unsupported IFR attribution function '{testing_dict['attr_func']}'.")
177
+
178
+ token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
179
+ attributions = list(attr.get_all_token_attrs(token_span))
180
+
181
+ elif "basic" in testing_dict["attr_func"]:
182
+ llm_attributor = llm_attr.LLMBasicAttribution(model, tokenizer)
183
+ attr = llm_attributor.calculate_basic_attribution(prompt, target = target)
184
+ token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
185
+ attributions = list(attr.get_all_token_attrs(token_span))
186
+
187
+ elif testing_dict["attr_func"] == "attnlrp":
188
+ llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
189
+ attr = llm_attributor.calculate_attnlrp_ft_hop0(prompt, target=target)
190
+ token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
191
+ attributions = list(attr.get_all_token_attrs(token_span))
192
+
193
+ elif testing_dict["attr_func"] == "attnlrp_aggregated":
194
+ llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
195
+ attr = llm_attributor.calculate_attnlrp_aggregated(prompt, target=target)
196
+ token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
197
+ attributions = list(attr.get_all_token_attrs(token_span))
198
+
199
+ elif testing_dict["attr_func"] == "attnlrp_aggregated_multi_hop":
200
+ llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
201
+ attr = llm_attributor.calculate_attnlrp_aggregated_multi_hop(
202
+ prompt,
203
+ target=target,
204
+ sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
205
+ thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
206
+ n_hops=testing_dict.get("n_hops", 1),
207
+ )
208
+ token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
209
+ attributions = list(attr.get_all_token_attrs(token_span))
210
+
211
+ else:
212
+ raise ValueError(f"Unsupported attribution function '{testing_dict['attr_func']}'.")
213
+
214
+ extra = None
215
+ if testing_dict["attr_func"].lower() in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both"):
216
+ import ft_ifr_improve
217
+
218
+ extra = {
219
+ "keep_prompt_token_indices": ft_ifr_improve.keep_token_indices(list(attr.prompt_tokens)),
220
+ "user_prompt_indices": list(getattr(llm_attributor, "user_prompt_indices", []) or []),
221
+ }
222
+
223
+ return attributions, extra
224
+
225
+ def faithfulness_test(testing_dict, llm_evaluator, prompt, indices_to_explain, target = None) -> np.ndarray[float]:
226
+ tokenizer = testing_dict["tokenizer"]
227
+ faithfulness_k = int(testing_dict.get("faithfulness_k", 20))
228
+
229
+ scores = []
230
+
231
+ # batch size is set based on the max_input_len in main(). Currently set to fully fill a 196GB GPU.
232
+ if target is None:
233
+ generation, full_output = llm_evaluator.response(prompt)
234
+ batch_size = math.floor((testing_dict["max_input_len"] - 100) / len(tokenizer(full_output).input_ids))
235
+ else:
236
+ generation = target
237
+ batch_size = math.floor(
238
+ (testing_dict["max_input_len"] - 100)
239
+ / len(tokenizer(llm_evaluator.format_prompt(" " + prompt) + generation).input_ids)
240
+ )
241
+
242
+ # We run an attribution on the input
243
+ # A list of attribution tensors will be returned and scored individually.
244
+ attr_list, extra = run_attribution(testing_dict, prompt, batch_size, indices_to_explain = indices_to_explain, target = target)
245
+
246
+ seq_attr = attr_list[0]
247
+ prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0]) # cols=(P+G), rows=G
248
+
249
+ for i in range(len(attr_list)):
250
+ attr = attr_list[i][:, :prompt_len]
251
+ if testing_dict["attr_func"].lower() in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and extra is not None:
252
+ import ft_ifr_improve
253
+
254
+ scores.append(
255
+ ft_ifr_improve.faithfulness_test_skip_tokens(
256
+ llm_evaluator,
257
+ attr,
258
+ prompt,
259
+ generation,
260
+ keep_prompt_token_indices=extra.get("keep_prompt_token_indices") or [],
261
+ user_prompt_indices=extra.get("user_prompt_indices"),
262
+ k=faithfulness_k,
263
+ )
264
+ )
265
+ else:
266
+ scores.append(llm_evaluator.faithfulness_test(attr, prompt, generation, k=faithfulness_k)) # [3 scores]
267
+
268
+ return np.array(scores)
269
+
270
+ def clean_trailing_space(text) -> str:
271
+     if text and text[-1] == ' ':
272
+ return text[:-1]
273
+ else:
274
+ return text
275
+
276
+ def evaluate_attribution(testing_dict) -> None:
277
+ model = testing_dict["model"]
278
+ tokenizer = testing_dict["tokenizer"]
279
+
280
+ llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(model, tokenizer)
281
+
282
+ scores = []
283
+
284
+ description = "Faithfulness " + testing_dict["model_name"] + " " + testing_dict["dataset_name"] + " " + testing_dict["attr_func"]
285
+
286
+ dataset: AttributionDataset = testing_dict["dataset"]
287
+ num_examples = testing_dict["num_examples"]
288
+ total = min(len(dataset), num_examples) if hasattr(dataset, "__len__") else num_examples
289
+ example_iterator = islice(dataset, num_examples)
290
+
291
+ for example in tqdm(example_iterator, desc=description, total=total):
292
+ indices_to_explain = example.indices_to_explain if example.indices_to_explain is not None else [-2]
293
+ scores.append(
294
+ faithfulness_test(
295
+ testing_dict,
296
+ llm_evaluator,
297
+ example.prompt,
298
+ indices_to_explain=indices_to_explain,
299
+ target=example.target,
300
+ )
301
+ )
302
+
303
+ scores = np.array(scores) # [num_examples, num_attrs, 3 scores]
304
+ scores_mean = scores.mean(0) # [num_attrs, 3 scores]
305
+     scores_var = scores.std(0) # standard deviation, [num_attrs, 3 scores]
306
+
307
+ # make the test folder if it doesn't exist
308
+ folder = "./test_results/faithfulness/" + testing_dict["dataset_name"] + "/" + testing_dict["model_name"] + "/"
309
+ if not os.path.exists(folder):
310
+ os.makedirs(folder)
311
+
312
+ # save all data
313
+ file_name = testing_dict["attr_func"] + "_" + str(testing_dict["num_examples"]) + "_examples"
314
+ with open(folder + file_name + ".csv", 'w') as f:
315
+ write = csv.writer(f)
316
+
317
+ write.writerow(["Method", "RISE", "MAS", "RISE + AP"])
318
+
319
+ write.writerow(["Seq Attr Scores Mean"] + scores_mean[0].tolist())
320
+ write.writerow(["Row Attr Scores Mean"] + scores_mean[1].tolist())
321
+ write.writerow(["Recursive Attr Scores Mean"] + scores_mean[2].tolist())
322
+
323
+     write.writerow(["Seq Attr Scores Std"] + scores_var[0].tolist())
324
+     write.writerow(["Row Attr Scores Std"] + scores_var[1].tolist())
325
+     write.writerow(["Recursive Attr Scores Std"] + scores_var[2].tolist())
326
+
327
+ return
328
+
329
+ def load_model(model_name, device) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
330
+ seed = 42
331
+ random.seed(seed)
332
+ np.random.seed(seed)
333
+ torch.manual_seed(seed)
334
+ torch.cuda.manual_seed(seed)
335
+ torch.cuda.manual_seed_all(seed) # if multi-GPU
336
+
337
+ # Respect three modes:
338
+ # - device == 'auto' -> multi-GPU sharding across all visible devices
339
+ # - device startswith('cuda:IDX') -> place entire model on a single GPU IDX (relative to visible devices)
340
+ # - device == 'cpu' -> CPU
341
+ if device == "auto":
342
+ model = AutoModelForCausalLM.from_pretrained(
343
+ model_name,
344
+ device_map="auto",
345
+ attn_implementation="eager",
346
+ torch_dtype=torch.float16,
347
+ )
348
+ elif isinstance(device, str) and device.startswith("cuda:"):
349
+ try:
350
+ gpu_idx = int(device.split(":")[1])
351
+ except Exception:
352
+ gpu_idx = 0
353
+ model = AutoModelForCausalLM.from_pretrained(
354
+ model_name,
355
+ device_map={"": gpu_idx},
356
+ attn_implementation="eager",
357
+ torch_dtype=torch.float16,
358
+ )
359
+ else:
360
+ model = AutoModelForCausalLM.from_pretrained(
361
+ model_name,
362
+ attn_implementation="eager",
363
+ torch_dtype=torch.float16,
364
+ )
365
+ model.eval()
366
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
367
+
368
+ # Needed for LLaMA tokenizer
369
+ tokenizer.pad_token = tokenizer.eos_token
370
+
371
+ return model, tokenizer
372
+
373
+ def main(args) -> None:
374
+ # login(token = "")
375
+
376
+ # Device selection policy (mirrors attribution_recovery):
377
+ # - If --cuda is a comma-separated list (e.g. "0,1"), set visibility to that list and shard with device_map='auto'.
378
+ # - If --cuda is a single index (e.g. "0"), do NOT override CUDA_VISIBLE_DEVICES; place model on cuda:{index}.
379
+ # - Else (no --cuda), use --cuda_num as single-device index relative to current visibility.
380
+ if args.cuda is not None and isinstance(args.cuda, str) and "," in args.cuda:
381
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
382
+ device = "auto"
383
+ elif args.cuda is not None and isinstance(args.cuda, str) and args.cuda.strip() != "":
384
+ try:
385
+ idx = int(args.cuda)
386
+ except Exception:
387
+ idx = 0
388
+ device = f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
389
+ else:
390
+ device = f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"
391
+
392
+ # set up model
393
+ if args.model == "llama-1B":
394
+ model_name = "meta-llama/Llama-3.2-1B-Instruct"
395
+ max_input_len = 5500
396
+ elif args.model == "llama-3B":
397
+ model_name = "meta-llama/Llama-3.2-3B-Instruct"
398
+ max_input_len = 4800
399
+ elif args.model == "llama-8B":
400
+ model_name = "meta-llama/Llama-3.1-8B-Instruct"
401
+ max_input_len = 3500
402
+ elif args.model == "qwen-1.7B":
403
+ model_name = "Qwen/Qwen3-1.7B"
404
+ max_input_len = 5500
405
+ elif args.model == "qwen-4B":
406
+ model_name = "Qwen/Qwen3-4B-Instruct-2507"
407
+ max_input_len = 3500
408
+ elif args.model == "qwen-8B":
409
+ model_name = "Qwen/Qwen3-8B"
410
+ max_input_len = 3000
411
+ elif args.model == "qwen-32B":
412
+ model_name = "Qwen/Qwen3-32B"
413
+ max_input_len = 1500
414
+ elif args.model == "gemma-12B":
415
+ model_name = "gemma/gemma-3-12b-it"
416
+ max_input_len = 1500
417
+ elif args.model == "gemma-27B":
418
+ model_name = "gemma/gemma-3-27b-it"
419
+ max_input_len = 2000
420
+ else:
421
+ model_name = args.model_path if args.model_path is not None else args.model
422
+ max_input_len = 2000
423
+
424
+ model, tokenizer = load_model(model_name if args.model_path is None else args.model_path, device)
425
+
426
+ dataset_registry = {
427
+ "math": lambda: MathAttributionDataset("./data/math_mine.json", tokenizer),
428
+ "facts": lambda: FactsAttributionDataset("./data/10000_facts_9_choose_3.json"),
429
+ "morehopqa": lambda: MoreHopQAAttributionDataset("./data/with_human_verification.json"),
430
+ }
431
+ dataset_loader = dataset_registry.get(args.dataset)
432
+ if dataset_loader is None:
433
+ print("You have not specified an acceptable dataset. Exiting.")
434
+ exit()
435
+ dataset = dataset_loader()
436
+
437
+ testing_dict = {
438
+ "model" : model,
439
+ "model_name": args.model,
440
+ "tokenizer" : tokenizer,
441
+ "dataset" : dataset,
442
+ "dataset_name" : args.dataset,
443
+ "max_input_len": max_input_len,
444
+ "attr_func": args.attr_func,
445
+ "num_examples": args.num_examples,
446
+ "device": device,
447
+ "faithfulness_k": args.faithfulness_k,
448
+ }
449
+
450
+ # call the test function
451
+ evaluate_attribution(testing_dict)
452
+
453
+ return
454
+
455
+ if __name__ == "__main__":
456
+ parser = argparse.ArgumentParser('')
457
+ parser.add_argument('--num_examples',
458
+ type = int, default = 100,
459
+ help='How many dataset examples to test with.')
460
+ parser.add_argument('--model',
461
+ type = str,
462
+ default = "llama",
463
+                         help='Model alias (e.g. llama-1B/3B/8B, qwen-1.7B/4B/8B/32B, gemma-12B/27B) or a raw model id/path.')
464
+ parser.add_argument('--model_path',
465
+ type=str, default=None,
466
+ help='Optional local model path to load (overrides model repo id only).')
467
+ parser.add_argument('--attr_func',
468
+ type = str,
469
+ default = "IG",
470
+                         help="Attribution method to use: IG, attention_I_G, perturbation_all, perturbation_CLP, perturbation_REAGENT, basic, attnlrp, attnlrp_aggregated, attnlrp_aggregated_multi_hop, or an ifr_* variant (see run_attribution).")
473
+ parser.add_argument('--cuda_num',
474
+ type=int, default = 0,
475
+ help='The number of the GPU you want to use.')
476
+ parser.add_argument('--cuda',
477
+ type=str, default=None,
478
+ help='GPU selection: use comma-separated ids for multi-GPU sharding (e.g. "0,1"); use a single index for one GPU relative to current CUDA_VISIBLE_DEVICES (e.g. "0").')
479
+ parser.add_argument('--dataset',
480
+ type = str, default = "math",
481
+ help = 'The dataset to evaluate on: math, facts, or morehopqa')
482
+ parser.add_argument(
483
+ "--faithfulness_k",
484
+ type=int,
485
+ default=20,
486
+ help="Total perturbation steps k for MAS/RISE (each step perturbs ~1/k of prompt tokens).",
487
+ )
488
+
489
+ args, unparsed = parser.parse_known_args()
490
+
491
+ main(args)
evaluations/faithfulness.sh ADDED
@@ -0,0 +1,80 @@
1
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func IG --dataset facts
2
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset facts
3
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func IG --dataset facts
4
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset facts
5
+
6
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset facts
7
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset facts
8
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset facts
9
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset facts
10
+
11
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset facts
12
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset facts
13
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset facts
14
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset facts
15
+
16
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset facts
17
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset facts
18
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset facts
19
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset facts
20
+
21
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset facts
22
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset facts
23
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset facts
24
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset facts
25
+
26
+
27
+
28
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func IG --dataset math
29
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset math
30
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func IG --dataset math
31
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset math
32
+
33
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset math
34
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset math
35
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset math
36
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset math
37
+
38
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset math
39
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset math
40
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset math
41
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset math
42
+
43
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset math
44
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset math
45
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset math
46
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset math
47
+
48
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset math
49
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset math
50
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset math
51
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset math
52
+
53
+
54
+
55
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func IG --dataset morehopqa
56
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset morehopqa
57
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func IG --dataset morehopqa
58
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset morehopqa
59
+
60
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset morehopqa
61
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset morehopqa
62
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset morehopqa
63
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset morehopqa
64
+
65
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset morehopqa
66
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset morehopqa
67
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset morehopqa
68
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset morehopqa
69
+
70
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset morehopqa
71
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset morehopqa
72
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset morehopqa
73
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset morehopqa
74
+
75
+ # python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset morehopqa
76
+ # python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset morehopqa
77
+ # python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset morehopqa
78
+ # python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset morehopqa
79
+
80
+ CUDA_VISIBLE_DEVICES=4,6 python3 evaluations/faithfulness.py --model qwen-8B --model_path /opt/share/models/Qwen/Qwen3-8B/ --cuda '0,1' --num_examples 50 --attr_func IG --dataset math
example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/quickstart.py ADDED
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+
5
+ from flashtrace import FlashTrace, load_model_and_tokenizer
6
+
7
+
8
+ def build_parser() -> argparse.ArgumentParser:
9
+ parser = argparse.ArgumentParser(description="FlashTrace quickstart example.")
10
+ parser.add_argument("--model", required=True, help="Hugging Face model id or local model path.")
11
+ parser.add_argument("--prompt", required=True, help="Prompt text.")
12
+ parser.add_argument("--target", help="Target response text.")
13
+ parser.add_argument("--output-span", default=None, help="Inclusive generation-token span START:END.")
14
+ parser.add_argument("--reasoning-span", default=None, help="Inclusive generation-token span START:END.")
15
+ parser.add_argument("--html", default="trace.html", help="Output HTML path.")
16
+ parser.add_argument("--use-chat-template", action="store_true", help="Format prompts with the tokenizer chat template.")
17
+ return parser
18
+
19
+
20
+ def parse_span(value: str | None) -> tuple[int, int] | None:
21
+ from flashtrace.cli import parse_span as parse_cli_span
22
+
23
+ return parse_cli_span(value)
24
+
25
+
26
+ def main() -> int:
27
+ args = build_parser().parse_args()
28
+ model, tokenizer = load_model_and_tokenizer(args.model)
29
+ tracer = FlashTrace(model, tokenizer, use_chat_template=args.use_chat_template)
30
+ trace = tracer.trace(
31
+ prompt=args.prompt,
32
+ target=args.target,
33
+ output_span=parse_span(args.output_span),
34
+ reasoning_span=parse_span(args.reasoning_span),
35
+ )
36
+ for item in trace.topk_inputs(10):
37
+ print(f"{item.index}\t{item.score:.6f}\t{item.token!r}")
38
+ trace.to_html(args.html)
39
+ print(f"wrote {args.html}")
40
+ return 0
41
+
42
+
43
+ if __name__ == "__main__":
44
+ raise SystemExit(main())
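+
+ # Example invocation (hypothetical model id, prompt, and span; the flags are defined in build_parser above):
+ #   python examples/quickstart.py --model Qwen/Qwen3-8B \
+ #       --prompt "Who wrote Hamlet?" --target "Shakespeare wrote Hamlet." \
+ #       --output-span 0:4 --html trace.html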
exp/case_study/README.md ADDED
@@ -0,0 +1,152 @@
1
+ # FT multi-hop case study & standard IFR visualization (exp/case_study)
+
+ This directory provides a lightweight single-example IFR visualization pipeline that does not touch the core evaluation code.
+
+ ## Features
+ - Reads a single example (default `exp/exp2/data/morehopqa.jsonl`, index 0).
+ - Supports several modes:
+   - `ft`: the multi-hop FT attribution currently in use (internally calls `LLMIFRAttribution.calculate_ifr_multi_hop`).
+   - `ifr`: standard IFR (single hop); by default computes **aggregated IFR** over the specified sink span (only one panel is shown).
+   - `ifr_all_positions_output_only`: computes the token-level IFR matrix only for output tokens inside `sink_span`, then derives the Row / Recursive (CAGE) panels from that matrix.
+   - `attnlrp`: AttnLRP hop0 (reuses the FT-AttnLRP span-aggregate logic, equivalent to `LLMLRPAttribution.calculate_attnlrp_multi_hop(n_hops=0)`; visualizes `raw_attributions[0].token_importance_total`).
+   - `ft_attnlrp`: FT-AttnLRP (strictly reuses `LLMLRPAttribution.calculate_attnlrp_aggregated_multi_hop`, consistent with `exp/exp2/`; directly visualizes each hop's `token_importance_total`).
+ - Renders two views:
+   - **Full token-level (pre-truncation)**: heatmap over the complete chat-templated sequence (template + user prompt + generation).
+   - **Prompt-only token-level**: heatmap over the user prompt tokens only (generation tokens excluded).
+ - Heatmaps are colored by `|score|` (sign is ignored); within each panel, the full and prompt views are each normalized independently by the p99.5 of `|score|` (see the sketch after this list).
+ - Outputs JSON (full numeric values) and HTML (per-hop heatmaps).
+ - Additionally provides a MAS (faithfulness / token perturbation) visualization: runs a token-level perturbation evaluation for the chosen attribution method and renders a perturbation-impact heatmap plus the MAS scores.
19
+
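+
+ A minimal sketch of the per-panel color normalization described above (the actual rendering lives in `viz.py`; the function name below is only illustrative):
+
+ ```python
+ import numpy as np
+
+ def heatmap_intensity(scores) -> np.ndarray:
+     """Color by |score|, normalized by this panel's own p99.5 of |score|."""
+     mag = np.abs(np.asarray(scores, dtype=np.float64))
+     if mag.size == 0:
+         return mag
+     denom = max(float(np.percentile(mag, 99.5)), 1e-12)  # guard against an all-zero panel
+     return np.clip(mag / denom, 0.0, 1.0)
+ ```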
20
+ ## Quick start
21
+ ```bash
22
+ # Adjust model/model_path to your local model
23
+ # Multi-hop FT (default); also: ft_split_hop, ft_improve
24
+ python exp/case_study/run_ifr_case.py \
25
+ --mode ft_split_hop \
26
+ --dataset exp/exp2/data/morehopqa.jsonl \
27
+ --index 0 \
28
+ --model qwen-8B \
29
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
30
+ --cuda 0 \
31
+ --n_hops 3
32
+
33
+ # Standard IFR (single hop; a sink span can be specified)
34
+ python exp/case_study/run_ifr_case.py \
35
+ --mode ifr \
36
+ --dataset exp/exp2/data/morehopqa.jsonl \
37
+ --index 0 \
38
+ --model qwen-8B \
39
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
40
+ --cuda 0 \
41
+ --sink_span 0 0
42
+
43
+ # IFR output-only: compute the IFR matrix only over the output range and produce the Row/Recursive (CAGE) panels
44
+ python exp/case_study/run_ifr_case.py \
45
+ --mode ifr_all_positions_output_only \
46
+ --dataset exp/exp2/data/short-morehopqa.jsonl \
47
+ --index 0 \
48
+ --model qwen-8B \
49
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
50
+ --cuda 0
51
+
52
+ # AttnLRP hop0 (reuses the FT-AttnLRP span-aggregate; visualizes the hop0 raw vector)
53
+ python exp/case_study/run_ifr_case.py \
54
+ --mode attnlrp \
55
+ --dataset exp/exp2/data/morehopqa.jsonl \
56
+ --index 0 \
57
+ --model qwen-8B \
58
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
59
+ --cuda 0 \
60
+ --sink_span 0 20
61
+
62
+ # FT-AttnLRP (multi-hop recursive AttnLRP)
63
+ python exp/case_study/run_ifr_case.py \
64
+ --mode ft_attnlrp \
65
+ --dataset exp/exp2/data/morehopqa.jsonl \
66
+ --index 0 \
67
+ --model qwen-8B \
68
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
69
+ --cuda 0,2,3,4,5,7 \
70
+ --n_hops 3 \
71
+ --attnlrp_neg_handling abs \
72
+ --attnlrp_norm_mode norm
73
+ ```
74
+
75
+ Outputs are written to `exp/case_study/out/`; the file-name prefix depends on the mode, e.g.:
76
+ - `ft_case_<dataset>_idx<idx>.json/html`
77
+ - `ifr_case_<dataset>_idx<idx>.json/html`
78
+ - `ifr_output_only_case_<dataset>_idx<idx>.json/html`
79
+ - `attnlrp_case_<dataset>_idx<idx>.json/html`
80
+ - `ft_attnlrp_case_<dataset>_idx<idx>.json/html`
81
+
82
+ ## MAS(Faithfulness / Token Perturbation)可视化
83
+
84
+ > 说明:这里的 MAS 与项目 `llm_attr_eval.LLMAttributionEvaluator.faithfulness_test()` 保持一致:
85
+ > 1) 先对样本跑指定方法的归因,并取 token-level attribution(Seq / Row / Recursive)。
86
+ > 2) 按 prompt token 的重要性排序,逐步将 token id 替换为 `tokenizer.pad_token_id`(token 级扰动)。
87
+ > 3) 用 `sum log p(generation + EOS | prompt)` 得到分数曲线,计算 RISE / MAS / RISE+AP。
88
+ > 4) 可视化时用“每一步扰动带来的边际 logprob 变化”作为 token 分数,渲染为 token spans 的“扰动影响热力图”。
89
+
90
+ ```bash
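+
+ A minimal sketch of steps 2) to 4) above (helper names follow `exp/case_study/faithfulness_trace.py`; the real implementation additionally maps user-prompt token indices and computes the RISE / MAS / RISE+AP aggregates):
+
+ ```python
+ import torch
+
+ @torch.inference_mode()
+ def mas_marginal_deltas(evaluator, prompt_ids, generation_ids, sorted_prompt_positions, k=20):
+     """Progressively pad the highest-attributed prompt tokens and record the marginal
+     change in sum log p(generation + EOS | prompt) after each perturbation step."""
+     ids = prompt_ids.clone()
+     curve = [evaluator.compute_logprob_response_given_prompt(ids, generation_ids).sum().item()]
+     chunk = max(1, len(sorted_prompt_positions) // k)
+     for start in range(0, len(sorted_prompt_positions), chunk):
+         cols = list(sorted_prompt_positions[start:start + chunk])
+         ids[0, cols] = evaluator.tokenizer.pad_token_id  # token-level perturbation, cumulative across steps
+         curve.append(evaluator.compute_logprob_response_given_prompt(ids, generation_ids).sum().item())
+     # Marginal log-prob change per step; these deltas drive the perturbation-impact heatmap.
+     return [b - a for a, b in zip(curve[:-1], curve[1:])]
+ ```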
91
+ # FT-IFR(ifr_multi_hop;默认 --method ft)
92
+ python exp/case_study/run_mas_case.py \
93
+ --dataset exp/exp2/data/short-morehopqa.jsonl \
94
+ --index 0 \
95
+ --model qwen-8B \
96
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
97
+ --cuda 0 \
98
+ --method ft \
99
+ --n_hops 3
100
+ ```
101
+
102
+ Common method choices (aligned with the mode names in `run_ifr_case.py`):
103
+ ```bash
104
+ # IFR (requires sink_span; cached dataset fields are preferred by default)
105
+ python exp/case_study/run_mas_case.py --method ifr --sink_span 0 20 ...
106
+
107
+ # IFR output-only (computes the token-level IFR matrix only for output tokens inside sink_span)
108
+ python exp/case_study/run_mas_case.py --method ifr_all_positions_output_only --sink_span 0 20 ...
109
+
110
+ # FT-IFR (ifr_multi_hop)
111
+ python exp/case_study/run_mas_case.py --method ft --n_hops 1 --sink_span 0 20 --thinking_span 0 20 ...
112
+
113
+ # AttnLRP hop0 (reuses FT-AttnLRP hop0; indices_to_explain/sink_span are still needed to extract Seq/Row/Rec)
114
+ python exp/case_study/run_mas_case.py --method attnlrp --sink_span 0 20 ...
115
+
116
+ # FT-AttnLRP (attnlrp_aggregated_multi_hop)
117
+ python exp/case_study/run_mas_case.py --method ft_attnlrp --n_hops 1 --sink_span 0 20 --thinking_span 0 20 ...
118
+ ```
119
+
120
+ Outputs are written to `exp/case_study/out/` with the file-name prefix:
121
+ - `mas_case_<method>_<dataset>_idx<idx>.json/html`
122
+
123
+ By default the HTML contains 3 attribution-view panels (Seq / Row / Recursive); each panel has 2 rows of token-level heatmaps:
+ - **Method attribution (token weights)**: the method's token attribution weights (used for ordering/density).
+ - **Attribution-guided MAS marginal (path deltas)**: the marginal impact of progressively replacing tokens in attribution order (this is the perturbation path actually used in the evaluation).
126
+
127
+ ## Viewing the HTML in a browser
+ 1) Run one of the commands above to generate the `.html` file (the terminal prints a line like `wrote exp/case_study/out/...html`).
+
+ 2) From the repository root, start a static file server (any port, e.g. 8888):
131
+ ```bash
132
+ python -m http.server 8888 --directory exp/case_study/out
133
+ ```
134
+
135
+ 3) Open it in a browser (note: `http://`, not `https://`):
+ - Local machine: `http://127.0.0.1:8888/<your-html-file>`
+ - Remote machine (port forwarding recommended): run `ssh -L 8888:127.0.0.1:8888 <user>@<server>` locally, then open `http://127.0.0.1:8888/<your-html-file>` in your local browser.
+
+ If the `http.server` log shows many `400 Bad request version` entries with garbled bytes, a client is usually hitting the HTTP port over HTTPS; make sure the address bar says `http://...`.
140
+
141
+ ## Optional arguments
+ - `--sink_span a b` / `--thinking_span a b`: override the generation-side sink/thinking sentence spans (cached fields are used by default).
+ - `--attnlrp_neg_handling drop|abs`: per-hop handling of negative values in FT-AttnLRP (drop = clamp to >= 0, abs = take absolute values).
+ - `--attnlrp_norm_mode norm|no_norm`: FT-AttnLRP normalization and hop-ratio switches (norm = global + thinking normalization with the ratio enabled; no_norm = all three disabled).
+ - `--chunk_tokens` / `--sink_chunk_tokens`: IFR chunking parameters.
+ - `--output_dir`: change the output directory.
147
+
148
+ ## Files
+ - `run_ifr_case.py`: CLI entry point and output writer (supports the `ft`/`ifr`/`ifr_all_positions_output_only`/`attnlrp`/`ft_attnlrp` modes).
+ - `run_mas_case.py`: MAS (faithfulness / token perturbation) visualization entry point and output writer (supports `ifr`/`ifr_all_positions_output_only`/`ft`/`attnlrp`/`ft_attnlrp`).
+ - `analysis.py`: per-hop sanitization and packaging (token-level).
+ - `viz.py`: HTML rendering and heatmaps.
exp/case_study/analysis.py ADDED
@@ -0,0 +1,74 @@
1
+ """Helpers for IFR case studies (hop-wise aggregation + sanitization).
2
+
3
+ All utilities stay local to exp/case_study to avoid touching core eval code.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Any, Dict, Iterable, List, Optional, Sequence
9
+
10
+ import torch
11
+
12
+
13
+ def vector_stats(vec: torch.Tensor) -> Dict[str, float]:
14
+ if vec.numel() == 0:
15
+ return {"min": 0.0, "max": 0.0, "abs_max": 0.0, "mean": 0.0, "sum": 0.0}
16
+ v = vec.detach().to(dtype=torch.float32)
17
+ return {
18
+ "min": float(v.min().item()),
19
+ "max": float(v.max().item()),
20
+ "abs_max": float(v.abs().max().item()),
21
+ "mean": float(v.mean().item()),
22
+ "sum": float(v.sum().item()),
23
+ }
24
+
25
+
26
+ def tensor_to_list(x: Any) -> Any:
27
+ if torch.is_tensor(x):
28
+ return x.detach().cpu().tolist()
29
+ if isinstance(x, list):
30
+ return [tensor_to_list(v) for v in x]
31
+ if isinstance(x, dict):
32
+ return {k: tensor_to_list(v) for k, v in x.items()}
33
+ return x
34
+
35
+
36
+ def sanitize_ifr_meta(meta: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
37
+ """Drop bulky raw objects and convert tensors to Python lists for JSON."""
38
+
39
+ if meta is None:
40
+ return None
41
+
42
+ cleaned: Dict[str, Any] = {}
43
+ for key, value in meta.items():
44
+ if key == "raw":
45
+ continue
46
+ cleaned[key] = tensor_to_list(value)
47
+ return cleaned
48
+
49
+
50
+ def package_token_hops(
51
+ hop_vectors: Iterable[Sequence[float]],
52
+ ) -> List[Dict[str, Any]]:
53
+ """Package per-hop token vectors without sentence aggregation.
54
+
55
+ hop_vectors are assumed to already match the experiment's configured
56
+ postprocessing (e.g., FT-AttnLRP neg_handling/norm_mode).
57
+ """
58
+
59
+ packaged: List[Dict[str, Any]] = []
60
+ for hop_idx, vec in enumerate(hop_vectors):
61
+ vec_tensor = torch.nan_to_num(torch.as_tensor(vec, dtype=torch.float32), nan=0.0)
62
+ token_scores = vec_tensor.tolist()
63
+ token_max = float(vec_tensor.abs().max().item()) if vec_tensor.numel() > 0 else 0.0
64
+ total = float(vec_tensor.sum().item())
65
+ packaged.append(
66
+ {
67
+ "hop": hop_idx,
68
+ "token_scores": token_scores,
69
+ "token_score_max": token_max,
70
+ "token_stats": vector_stats(vec_tensor),
71
+ "total_mass": total,
72
+ }
73
+ )
74
+ return packaged
exp/case_study/faithfulness_trace.py ADDED
@@ -0,0 +1,183 @@
1
+ """Faithfulness (MAS/RISE) trace utilities for exp/case_study.
2
+
3
+ This module is intentionally aligned with `llm_attr_eval.LLMAttributionEvaluator.faithfulness_test`,
4
+ but additionally returns the full trace arrays needed for visualization and supports providing
5
+ `user_prompt_indices` to avoid fragile subsequence matching.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Any, Dict, Optional, Sequence, List
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ import llm_attr_eval
16
+
17
+
18
+ def _auc(arr: np.ndarray) -> float:
19
+ return float((arr.sum() - arr[0] / 2 - arr[-1] / 2) / max(1, (arr.shape[0] - 1)))
20
+
21
+
22
+ @torch.inference_mode()
23
+ def mas_trace(
24
+ llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
25
+ *,
26
+ attribution: torch.Tensor,
27
+ prompt: str,
28
+ generation: str,
29
+ user_prompt_indices: Optional[Sequence[int]] = None,
30
+ k: int = 20,
31
+ ) -> Dict[str, Any]:
32
+ """Return a token-level faithfulness trace (RISE/MAS/RISE+AP) plus per-token deltas.
33
+
34
+ attribution: [R, P] token attribution on prompt-side tokens only.
35
+ prompt: raw prompt string.
36
+ generation: target generation string; scored as generation + eos (if defined).
37
+ user_prompt_indices: optional absolute positions of each prompt token inside formatted prompt ids.
38
+ k: number of perturbation steps; each step perturbs ~1/k of prompt tokens.
39
+ """
40
+
41
+ if attribution.ndim != 2:
42
+ raise ValueError("Expected 2D prompt-side attribution matrix [R, P].")
43
+
44
+ pad_token_id = llm_evaluator._ensure_pad_token_id()
45
+
46
+ user_prompt = " " + prompt
47
+ formatted_prompt = llm_evaluator.format_prompt(user_prompt)
48
+ formatted_ids = llm_evaluator.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).input_ids
49
+
50
+ prompt_ids = formatted_ids.to(llm_evaluator.device)
51
+ prompt_ids_perturbed = prompt_ids.clone()
52
+
53
+ eos = llm_evaluator.tokenizer.eos_token or ""
54
+ generation_ids = llm_evaluator.tokenizer(
55
+ generation + eos,
56
+ return_tensors="pt",
57
+ add_special_tokens=False,
58
+ ).input_ids.to(llm_evaluator.device)
59
+
60
+ attr_cpu = attribution.detach().cpu()
61
+ w = attr_cpu.sum(0)
62
+ sorted_attr_indices = torch.argsort(w, descending=True)
63
+ attr_sum = float(w.sum().item())
64
+
65
+ P = int(w.numel())
66
+
67
+ prompt_positions: List[int]
68
+ if user_prompt_indices is not None:
69
+ prompt_positions = [int(x) for x in user_prompt_indices]
70
+ if len(prompt_positions) != P:
71
+ raise ValueError(
72
+ "user_prompt_indices length does not match prompt-side attribution length: "
73
+ f"indices P={len(prompt_positions)}, attr P={P}."
74
+ )
75
+ if P and max(prompt_positions) >= int(prompt_ids_perturbed.shape[1]):
76
+ raise ValueError("user_prompt_indices contains an out-of-bounds index for formatted prompt ids.")
77
+ else:
78
+ user_ids = llm_evaluator.tokenizer(user_prompt, return_tensors="pt", add_special_tokens=False).input_ids
79
+ user_start = llm_evaluator._find_subsequence_start(formatted_ids[0], user_ids[0])
80
+ if user_start is None:
81
+ raise RuntimeError("Failed to locate user prompt token span inside formatted chat prompt.")
82
+ if int(user_ids.shape[1]) != P:
83
+ raise ValueError(
84
+ "Prompt-side attribution length does not match tokenized user prompt length: "
85
+ f"attr P={P}, user_prompt P={int(user_ids.shape[1])}."
86
+ )
87
+ prompt_positions = [int(user_start) + j for j in range(P)]
88
+
89
+ if P > 0:
90
+ steps = int(k) if k is not None else 0
91
+ if steps <= 0:
92
+ steps = 1
93
+ steps = min(steps, P)
94
+ else:
95
+ steps = 0
96
+
97
+ scores = np.zeros(steps + 1, dtype=np.float64)
98
+ density = np.zeros(steps + 1, dtype=np.float64)
99
+
100
+ scores[0] = (
101
+ llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
102
+ )
103
+ density[0] = 1.0
104
+
105
+ if P == 0:
106
+ return {
107
+ "num_tokens": 0,
108
+ "sorted_attr_indices": [],
109
+ "scores_raw": scores.tolist(),
110
+ "density": density.tolist(),
111
+ "normalized_model_response": [1.0],
112
+ "alignment_penalty": [0.0],
113
+ "corrected_scores": [1.0],
114
+ "token_deltas_raw": [],
115
+ "attr_weights": [],
116
+ "metrics": {"RISE": 0.0, "MAS": 0.0, "RISE+AP": 0.0},
117
+ }
118
+
119
+ if attr_sum <= 0:
120
+ density = np.linspace(1.0, 0.0, steps + 1)
121
+
122
+ per_token_delta = np.zeros(P, dtype=np.float64)
123
+
124
+ base = P // steps
125
+ remainder = P % steps
126
+ start = 0
127
+ for step in range(steps):
128
+ size = base + (1 if step < remainder else 0)
129
+ group = sorted_attr_indices[start : start + size]
130
+ start += size
131
+
132
+ for idx_t in group:
133
+ idx = int(idx_t.item())
134
+ abs_pos = int(prompt_positions[idx])
135
+ prompt_ids_perturbed[0, abs_pos] = pad_token_id
136
+ scores[step + 1] = (
137
+ llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
138
+ )
139
+ if attr_sum > 0:
140
+ dec = float(w.index_select(0, group).sum().item()) / attr_sum
141
+ density[step + 1] = density[step] - dec
142
+
143
+ delta = scores[step] - scores[step + 1]
144
+ for idx_t in group:
145
+ idx = int(idx_t.item())
146
+ per_token_delta[idx] = delta
147
+
148
+ min_normalized_pred = 1.0
149
+ normalized_model_response = scores.copy()
150
+ for i in range(len(scores)):
151
+ normalized_pred = (normalized_model_response[i] - scores[-1]) / (abs(scores[0] - scores[-1]))
152
+ normalized_pred = np.clip(normalized_pred, 0.0, 1.0)
153
+ min_normalized_pred = min(min_normalized_pred, normalized_pred)
154
+ normalized_model_response[i] = min_normalized_pred
155
+
156
+ alignment_penalty = np.abs(normalized_model_response - density)
157
+ corrected_scores = normalized_model_response + alignment_penalty
158
+ corrected_scores = corrected_scores.clip(0.0, 1.0)
159
+ corrected_scores = (corrected_scores - np.min(corrected_scores)) / (np.max(corrected_scores) - np.min(corrected_scores))
160
+ if np.isnan(corrected_scores).any():
161
+ corrected_scores = np.linspace(1.0, 0.0, len(scores))
162
+
163
+ rise = _auc(normalized_model_response)
164
+ mas = _auc(corrected_scores)
165
+ rise_ap = _auc(normalized_model_response + alignment_penalty)
166
+
167
+ if attr_sum > 0:
168
+ attr_weights = (w.numpy() / (attr_sum + 1e-12)).astype(np.float64)
169
+ else:
170
+ attr_weights = np.zeros(P, dtype=np.float64)
171
+
172
+ return {
173
+ "num_tokens": P,
174
+ "sorted_attr_indices": [int(i.item()) for i in sorted_attr_indices],
175
+ "scores_raw": scores.tolist(),
176
+ "density": density.tolist(),
177
+ "normalized_model_response": normalized_model_response.tolist(),
178
+ "alignment_penalty": alignment_penalty.tolist(),
179
+ "corrected_scores": corrected_scores.tolist(),
180
+ "token_deltas_raw": per_token_delta.tolist(),
181
+ "attr_weights": attr_weights.tolist(),
182
+ "metrics": {"RISE": rise, "MAS": mas, "RISE+AP": rise_ap},
183
+ }
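The RISE/MAS/RISE+AP numbers come out of the post-processing at the end of ``mas_trace``. A self-contained sketch of that math on made-up score/density arrays (no model or evaluator needed; ``np.minimum.accumulate`` stands in for the running-min loop above):

import numpy as np

def auc(arr: np.ndarray) -> float:
    # Trapezoidal mean over unit-spaced points, same as _auc above.
    return float((arr.sum() - arr[0] / 2 - arr[-1] / 2) / max(1, (arr.shape[0] - 1)))

scores = np.array([-2.0, -3.5, -6.0, -9.0])   # generation log-prob after each perturbation step
density = np.array([1.0, 0.55, 0.2, 0.0])     # remaining attribution mass after each step

norm = np.clip((scores - scores[-1]) / abs(scores[0] - scores[-1]), 0.0, 1.0)
norm = np.minimum.accumulate(norm)            # enforce a monotone non-increasing response
penalty = np.abs(norm - density)
corrected = np.clip(norm + penalty, 0.0, 1.0)
corrected = (corrected - corrected.min()) / (corrected.max() - corrected.min())

print({"RISE": auc(norm), "MAS": auc(corrected), "RISE+AP": auc(norm + penalty)})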
exp/case_study/run_ifr_case.py ADDED
@@ -0,0 +1,1225 @@
1
+ #!/usr/bin/env python3
2
+ """Case study runner for FlashTrace and attribution baselines.
3
+
4
+ Modes supported (all emit JSON + HTML under ``exp/case_study/out``):
5
+
6
+ - ``ft``: FlashTrace (current project implementation; multi-hop IFR)
+ - ``ft_improve``: FlashTrace with stop-token soft deletion (experimental multi-hop IFR)
+ - ``ft_split_hop``: FlashTrace with split-hop IFR over the segmented thinking span (experimental)
7
+ - ``ifr_in_all_gen``: Experimental multi-hop IFR variant (hops over CoT+output; scheme B, aligns with exp/exp2)
8
+ - ``ifr``: IFR span-aggregate visualization (single hop; one panel)
9
+ - ``ifr_all_positions``: IFR full matrix + CAGE (Row/Recursive panels)
10
+ - ``ifr_all_positions_output_only``: IFR output-only token matrix + CAGE (Row/Recursive panels)
11
+ - ``attnlrp``: AttnLRP hop0 (reuse FT-AttnLRP span-aggregate; visualize raw hop0 vector)
12
+ - ``ft_attnlrp``: FT-AttnLRP (multi-hop aggregated AttnLRP; matches exp/exp2)
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ import sys
21
+ import types
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
24
+
25
+ # Avoid torchvision dependency when importing transformers (Longformer).
26
+ os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
27
+ os.environ.setdefault("DISABLE_TRANSFORMERS_IMAGE_TRANSFORMS", "1")
28
+
29
+ def _early_set_cuda_visible_devices() -> None:
30
+ """Set CUDA_VISIBLE_DEVICES before importing torch/transformers.
31
+
32
+ Note: CUDA device indices are re-mapped inside the process after applying the mask.
33
+ """
34
+
35
+ parser = argparse.ArgumentParser(add_help=False)
36
+ parser.add_argument("--cuda", type=str, default=None)
37
+ args, _ = parser.parse_known_args(sys.argv[1:])
38
+ cuda = args.cuda.strip() if isinstance(args.cuda, str) else ""
39
+ if cuda and "," in cuda:
40
+ os.environ["CUDA_VISIBLE_DEVICES"] = cuda
41
+
42
+
43
+ if __name__ == "__main__":
44
+ _early_set_cuda_visible_devices()
45
+
46
+ import torch
47
+
48
+ REPO_ROOT = Path(__file__).resolve().parents[2]
49
+ if str(REPO_ROOT) not in sys.path:
50
+ sys.path.insert(0, str(REPO_ROOT))
51
+
52
+
53
+ def _stub_torchvision() -> None:
54
+ """Provide minimal torchvision stubs so Longformer imports succeed without the real package."""
55
+
56
+ if "torchvision" in sys.modules:
57
+ return
58
+
59
+ from importlib.machinery import ModuleSpec
60
+
61
+ def _mk(name: str) -> types.ModuleType:
62
+ mod = types.ModuleType(name)
63
+ mod.__spec__ = ModuleSpec(name, loader=None)
64
+ return mod
65
+
66
+ tv = _mk("torchvision")
67
+ tv.__dict__["__path__"] = []
68
+ submods = ["transforms", "_meta_registrations", "datasets", "io", "models", "ops", "utils"]
69
+ for name in submods:
70
+ mod = _mk(f"torchvision.{name}")
71
+ sys.modules[f"torchvision.{name}"] = mod
72
+ setattr(tv, name, mod)
73
+
74
+ class _InterpolationMode:
75
+ NEAREST = 0
76
+ NEAREST_EXACT = 0
77
+ BILINEAR = 1
78
+ BICUBIC = 2
79
+ LANCZOS = 3
80
+ BOX = 4
81
+ HAMMING = 5
82
+
83
+ sys.modules["torchvision.transforms"].InterpolationMode = _InterpolationMode
84
+ sys.modules["torchvision.transforms"].__all__ = ["InterpolationMode"]
85
+
86
+ # ops + misc stub for timm/transformers imports
87
+ ops_mod = sys.modules.get("torchvision.ops") or _mk("torchvision.ops")
88
+ sys.modules["torchvision.ops"] = ops_mod
89
+ setattr(tv, "ops", ops_mod)
90
+ misc_mod = _mk("torchvision.ops.misc")
91
+ sys.modules["torchvision.ops.misc"] = misc_mod
92
+ setattr(ops_mod, "misc", misc_mod)
93
+
94
+ class _FrozenBatchNorm2d:
95
+ def __init__(self, *args, **kwargs):
96
+ pass
97
+
98
+ misc_mod.FrozenBatchNorm2d = _FrozenBatchNorm2d
99
+
100
+ sys.modules["torchvision"] = tv
101
+
102
+
103
+ _stub_torchvision()
104
+
105
+
106
+ def _stub_timm() -> None:
107
+ """Provide minimal timm stubs to avoid optional vision deps."""
108
+
109
+ if "timm" in sys.modules:
110
+ return
111
+
112
+ from importlib.machinery import ModuleSpec
113
+
114
+ def _mk(name: str) -> types.ModuleType:
115
+ mod = types.ModuleType(name)
116
+ mod.__spec__ = ModuleSpec(name, loader=None)
117
+ return mod
118
+
119
+ timm = _mk("timm")
120
+ timm.__dict__["__path__"] = []
121
+ sys.modules["timm"] = timm
122
+
123
+ data_mod = _mk("timm.data")
124
+ sys.modules["timm.data"] = data_mod
125
+ timm.data = data_mod
126
+
127
+ class _ImageNetInfo:
128
+ pass
129
+
130
+ def _infer_imagenet_subset(*args, **kwargs):
131
+ return None
132
+
133
+ data_mod.ImageNetInfo = _ImageNetInfo
134
+ data_mod.infer_imagenet_subset = _infer_imagenet_subset
135
+
136
+ layers_mod = _mk("timm.layers")
137
+ sys.modules["timm.layers"] = layers_mod
138
+ timm.layers = layers_mod
139
+
140
+ create_norm_mod = _mk("timm.layers.create_norm")
141
+ sys.modules["timm.layers.create_norm"] = create_norm_mod
142
+ layers_mod.create_norm = create_norm_mod
143
+
144
+ def _get_norm_layer(*args, **kwargs):
145
+ return None
146
+
147
+ create_norm_mod.get_norm_layer = _get_norm_layer
148
+
149
+ classifier_mod = _mk("timm.layers.classifier")
150
+ sys.modules["timm.layers.classifier"] = classifier_mod
151
+ layers_mod.classifier = classifier_mod
152
+
153
+
154
+ _stub_timm()
155
+
156
+ import transformers
157
+
158
+ # Provide light stubs if Longformer classes are unavailable; IFR case study does not use them.
159
+ if not hasattr(transformers, "LongformerTokenizer"):
160
+ class _DummyLongformerTokenizer:
161
+ def __init__(self, *args, **kwargs):
162
+ raise ImportError("LongformerTokenizer stubbed; install full transformers+torchvision if needed.")
163
+ transformers.LongformerTokenizer = _DummyLongformerTokenizer
164
+
165
+ if not hasattr(transformers, "LongformerForMaskedLM"):
166
+ class _DummyLongformerForMaskedLM:
167
+ def __init__(self, *args, **kwargs):
168
+ raise ImportError("LongformerForMaskedLM stubbed; install full transformers+torchvision if needed.")
169
+ transformers.LongformerForMaskedLM = _DummyLongformerForMaskedLM
170
+
171
+ if hasattr(transformers, "__all__"):
172
+ for _name in ["LongformerTokenizer", "LongformerForMaskedLM"]:
173
+ if _name not in transformers.__all__:
174
+ transformers.__all__.append(_name)
175
+
176
+ # Gemma3n stubs (transformers may attempt to import even if unused)
177
+ if "transformers.models.gemma3n.configuration_gemma3n" not in sys.modules:
178
+ from importlib.machinery import ModuleSpec
179
+
180
+ gemma_pkg = types.ModuleType("transformers.models.gemma3n")
181
+ gemma_pkg.__spec__ = ModuleSpec("transformers.models.gemma3n", loader=None, is_package=True)
182
+ sys.modules["transformers.models.gemma3n"] = gemma_pkg
183
+
184
+ gemma_conf = types.ModuleType("transformers.models.gemma3n.configuration_gemma3n")
185
+ gemma_conf.__spec__ = ModuleSpec("transformers.models.gemma3n.configuration_gemma3n", loader=None)
186
+
187
+ class Gemma3nConfig:
188
+ def __init__(self, *args, **kwargs):
189
+ self.model_type = "gemma3n"
190
+
191
+ class Gemma3nTextConfig(Gemma3nConfig):
192
+ pass
193
+
194
+ gemma_conf.Gemma3nConfig = Gemma3nConfig
195
+ gemma_conf.Gemma3nTextConfig = Gemma3nTextConfig
196
+ gemma_conf.__all__ = ["Gemma3nConfig", "Gemma3nTextConfig"]
197
+ sys.modules["transformers.models.gemma3n.configuration_gemma3n"] = gemma_conf
198
+ setattr(gemma_pkg, "configuration_gemma3n", gemma_conf)
199
+
200
+ if hasattr(transformers, "__all__"):
201
+ for _nm in ["Gemma3nConfig", "Gemma3nTextConfig"]:
202
+ if _nm not in transformers.__all__:
203
+ transformers.__all__.append(_nm)
204
+
205
+ import llm_attr
206
+ from exp.exp2 import dataset_utils as ds_utils
207
+ from evaluations.attribution_recovery import load_model
208
+
209
+ from exp.case_study import analysis, viz
210
+
211
+
212
+ def resolve_device(cuda: Optional[str], cuda_num: int) -> str:
213
+ if cuda and isinstance(cuda, str) and "," in cuda:
214
+ os.environ["CUDA_VISIBLE_DEVICES"] = cuda
215
+ return "auto"
216
+ if cuda and isinstance(cuda, str) and cuda.strip():
217
+ try:
218
+ idx = int(cuda)
219
+ except Exception:
220
+ idx = 0
221
+ return f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
222
+ return f"cuda:{cuda_num}" if torch.cuda.is_available() else "cpu"
223
+
224
+
225
+ def load_example(dataset: str, index: int, data_root: Path) -> Tuple[ds_utils.CachedExample, str]:
226
+ """Load a single example from a cache path or dataset name."""
227
+
228
+ ds_path = Path(dataset)
229
+ if ds_path.exists():
230
+ examples = ds_utils.read_cached_jsonl(ds_path)
231
+ dataset_name = ds_path.name
232
+ else:
233
+ loader = ds_utils.DatasetLoader(data_root=data_root)
234
+ examples = loader.load(dataset)
235
+ dataset_name = dataset
236
+
237
+ if not examples:
238
+ raise ValueError(f"No examples found for dataset={dataset}")
239
+
240
+ if index < 0:
241
+ index = len(examples) + index
242
+ if not (0 <= index < len(examples)):
243
+ raise IndexError(f"index {index} out of range for dataset with {len(examples)} examples")
244
+
245
+ return examples[index], dataset_name
246
+
247
+
248
+ def parse_args() -> argparse.Namespace:
249
+ parser = argparse.ArgumentParser("IFR multi-hop case study")
250
+ parser.add_argument("--dataset", type=str, default="exp/exp2/data/morehopqa.jsonl", help="Dataset name or JSONL path.")
251
+ parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Cache root for dataset names.")
252
+ parser.add_argument("--index", type=int, default=0, help="Sample index (supports negative for reverse).")
253
+ parser.add_argument(
254
+ "--mode",
255
+ type=str,
256
+ choices=[
257
+ "ft",
258
+ "ft_improve",
259
+ "ft_split_hop",
260
+ "ifr_in_all_gen",
261
+ "ifr",
262
+ "ifr_all_positions",
263
+ "ifr_all_positions_output_only",
264
+ "attnlrp",
265
+ "ft_attnlrp",
266
+ ],
267
+ default="ft",
268
+ help=(
269
+ "ft = FlashTrace (multi-hop IFR); ifr = standard IFR span-aggregate; "
270
+ "ifr_in_all_gen = multi-hop IFR over CoT+output (scheme B; exp2-aligned); "
271
+ "ifr_all_positions = full IFR matrix + CAGE row/rec; "
272
+ "ft_improve = FlashTrace (multi-hop IFR, stop-token soft deletion); "
273
+ "ft_split_hop = FlashTrace (split-hop IFR over segmented thinking span); "
274
+ "ifr_all_positions_output_only = output-only IFR matrix + CAGE row/rec; "
275
+ "attnlrp = AttnLRP hop0 (FT-AttnLRP span-aggregate); "
276
+ "ft_attnlrp = FT-AttnLRP (multi-hop aggregated; exp2)."
277
+ ),
278
+ )
279
+ parser.add_argument("--model", type=str, default="qwen-8B", help="HF repo id (ignored if --model_path set).")
280
+ parser.add_argument("--model_path", type=str, default=None, help="Local model path to override --model.")
281
+ parser.add_argument("--cuda", type=str, default=None, help="CUDA spec (e.g., '0' or '0,1').")
282
+ parser.add_argument("--cuda_num", type=int, default=0, help="Fallback GPU index when --cuda unset.")
283
+ parser.add_argument("--n_hops", type=int, default=1, help="Number of hops for IFR multi-hop.")
284
+ parser.add_argument("--sink_span", type=int, nargs=2, default=None, help="Optional sink span over generation tokens.")
285
+ parser.add_argument("--thinking_span", type=int, nargs=2, default=None, help="Optional thinking span over generation tokens.")
286
+ parser.add_argument(
287
+ "--attnlrp_neg_handling",
288
+ type=str,
289
+ choices=["drop", "abs"],
290
+ default="drop",
291
+ help="FT-AttnLRP: how to handle negative values after each hop (drop=clamp>=0, abs=absolute value).",
292
+ )
293
+ parser.add_argument(
294
+ "--attnlrp_norm_mode",
295
+ type=str,
296
+ choices=["norm", "no_norm"],
297
+ default="norm",
298
+ help="FT-AttnLRP: norm enables per-hop global+thinking normalization + ratios; no_norm disables all three.",
299
+ )
300
+ parser.add_argument("--chunk_tokens", type=int, default=128, help="IFR chunk size.")
301
+ parser.add_argument("--sink_chunk_tokens", type=int, default=32, help="IFR sink chunk size.")
302
+ parser.add_argument("--output_dir", type=str, default="exp/case_study/out", help="Where to write HTML/JSON artifacts.")
303
+ return parser.parse_args()
304
+
305
+
306
+ def run_ft_multihop(
307
+ example: ds_utils.CachedExample,
308
+ model: Any,
309
+ tokenizer: Any,
310
+ *,
311
+ n_hops: int,
312
+ sink_span: Optional[Sequence[int]],
313
+ thinking_span: Optional[Sequence[int]],
314
+ chunk_tokens: int,
315
+ sink_chunk_tokens: int,
316
+ ) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
317
+ """Execute FT (current multi-hop IFR) attribution for the selected example."""
318
+
319
+ attr = llm_attr.LLMIFRAttribution(
320
+ model,
321
+ tokenizer,
322
+ chunk_tokens=chunk_tokens,
323
+ sink_chunk_tokens=sink_chunk_tokens,
324
+ )
325
+
326
+ sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None
327
+ thinking = (
328
+ tuple(thinking_span)
329
+ if thinking_span is not None
330
+ else tuple(example.thinking_span) if example.thinking_span else None
331
+ )
332
+
333
+ result = attr.calculate_ifr_multi_hop(
334
+ example.prompt,
335
+ target=example.target,
336
+ sink_span=sink,
337
+ thinking_span=thinking,
338
+ n_hops=n_hops,
339
+ )
340
+ debug_info: Dict[str, Any] = {
341
+ "full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []),
342
+ "generation_tokens": list(getattr(attr, "generation_tokens", []) or []),
343
+ "user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []),
344
+ "chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []),
345
+ "prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None,
346
+ "generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None,
347
+ }
348
+
349
+ raw_vectors = []
350
+ if result.metadata and "ifr" in result.metadata:
351
+ raw_ifr = result.metadata["ifr"].get("raw")
352
+ if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"):
353
+ try:
354
+ raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions]
355
+ except Exception:
356
+ raw_vectors = []
357
+ debug_info["raw_hop_vectors"] = raw_vectors
358
+
359
+ return result, sink, thinking, debug_info
360
+
361
+
362
+ def run_ft_multihop_improve(
363
+ example: ds_utils.CachedExample,
364
+ model: Any,
365
+ tokenizer: Any,
366
+ *,
367
+ n_hops: int,
368
+ sink_span: Optional[Sequence[int]],
369
+ thinking_span: Optional[Sequence[int]],
370
+ chunk_tokens: int,
371
+ sink_chunk_tokens: int,
372
+ ) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
373
+ """Execute experimental FT (multi-hop IFR) with stop-token soft deletion."""
374
+
375
+ import ft_ifr_improve
376
+
377
+ attr = ft_ifr_improve.LLMIFRAttributionImproved(
378
+ model,
379
+ tokenizer,
380
+ chunk_tokens=chunk_tokens,
381
+ sink_chunk_tokens=sink_chunk_tokens,
382
+ )
383
+
384
+ sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None
385
+ thinking = (
386
+ tuple(thinking_span)
387
+ if thinking_span is not None
388
+ else tuple(example.thinking_span) if example.thinking_span else None
389
+ )
390
+
391
+ result = attr.calculate_ifr_multi_hop_stop_words(
392
+ example.prompt,
393
+ target=example.target,
394
+ sink_span=sink,
395
+ thinking_span=thinking,
396
+ n_hops=n_hops,
397
+ )
398
+
399
+ debug_info: Dict[str, Any] = {
400
+ "full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []),
401
+ "generation_tokens": list(getattr(attr, "generation_tokens", []) or []),
402
+ "user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []),
403
+ "chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []),
404
+ "prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None,
405
+ "generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None,
406
+ }
407
+
408
+ raw_vectors = []
409
+ if result.metadata and "ifr" in result.metadata:
410
+ raw_ifr = result.metadata["ifr"].get("raw")
411
+ if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"):
412
+ try:
413
+ raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions]
414
+ except Exception:
415
+ raw_vectors = []
416
+ debug_info["raw_hop_vectors"] = raw_vectors
417
+
418
+ return result, sink, thinking, debug_info
419
+
420
+
421
+ def run_ft_multihop_split_hop(
422
+ example: ds_utils.CachedExample,
423
+ model: Any,
424
+ tokenizer: Any,
425
+ *,
426
+ n_hops: int,
427
+ sink_span: Optional[Sequence[int]],
428
+ thinking_span: Optional[Sequence[int]],
429
+ chunk_tokens: int,
430
+ sink_chunk_tokens: int,
431
+ ) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
432
+ """Execute experimental FT (split-hop IFR over segmented thinking span)."""
433
+
434
+ import ft_ifr_improve
435
+
436
+ attr = ft_ifr_improve.LLMIFRAttributionSplitHop(
437
+ model,
438
+ tokenizer,
439
+ chunk_tokens=chunk_tokens,
440
+ sink_chunk_tokens=sink_chunk_tokens,
441
+ )
442
+
443
+ sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None
444
+ thinking = (
445
+ tuple(thinking_span)
446
+ if thinking_span is not None
447
+ else tuple(example.thinking_span) if example.thinking_span else None
448
+ )
449
+
450
+ result = attr.calculate_ifr_multi_hop_split_hop(
451
+ example.prompt,
452
+ target=example.target,
453
+ sink_span=sink,
454
+ thinking_span=thinking,
455
+ n_hops=int(n_hops),
456
+ )
457
+
458
+ debug_info: Dict[str, Any] = {
459
+ "full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []),
460
+ "generation_tokens": list(getattr(attr, "generation_tokens", []) or []),
461
+ "user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []),
462
+ "chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []),
463
+ "prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None,
464
+ "generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None,
465
+ }
466
+
467
+ raw_vectors = []
468
+ if result.metadata and "ifr" in result.metadata:
469
+ raw_ifr = result.metadata["ifr"].get("raw")
470
+ if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"):
471
+ try:
472
+ raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions]
473
+ except Exception:
474
+ raw_vectors = []
475
+ debug_info["raw_hop_vectors"] = raw_vectors
476
+
477
+ return result, sink, thinking, debug_info
478
+
479
+
480
+ def run_ifr_in_all_gen(
481
+ example: ds_utils.CachedExample,
482
+ model: Any,
483
+ tokenizer: Any,
484
+ *,
485
+ n_hops: int,
486
+ sink_span: Optional[Sequence[int]],
487
+ thinking_span: Optional[Sequence[int]],
488
+ chunk_tokens: int,
489
+ sink_chunk_tokens: int,
490
+ ) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
491
+ """Execute experimental IFR variant: multi-hop over all generation (CoT + output)."""
492
+
493
+ import ft_ifr_improve
494
+
495
+ attr = ft_ifr_improve.LLMIFRAttributionInAllGen(
496
+ model,
497
+ tokenizer,
498
+ chunk_tokens=chunk_tokens,
499
+ sink_chunk_tokens=sink_chunk_tokens,
500
+ )
501
+
502
+ sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None
503
+ thinking = (
504
+ tuple(thinking_span)
505
+ if thinking_span is not None
506
+ else tuple(example.thinking_span) if example.thinking_span else None
507
+ )
508
+
509
+ result = attr.calculate_ifr_in_all_gen(
510
+ example.prompt,
511
+ target=example.target,
512
+ sink_span=sink,
513
+ thinking_span=thinking,
514
+ n_hops=int(n_hops),
515
+ )
516
+
517
+ debug_info: Dict[str, Any] = {
518
+ "full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []),
519
+ "generation_tokens": list(getattr(attr, "generation_tokens", []) or []),
520
+ "user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []),
521
+ "chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []),
522
+ "prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None,
523
+ "generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None,
524
+ }
525
+
526
+ raw_vectors = []
527
+ if result.metadata and "ifr" in result.metadata:
528
+ raw_ifr = result.metadata["ifr"].get("raw")
529
+ if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"):
530
+ try:
531
+ raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions]
532
+ except Exception:
533
+ raw_vectors = []
534
+ debug_info["raw_hop_vectors"] = raw_vectors
535
+
536
+ return result, sink, thinking, debug_info
537
+
538
+
539
+ def make_output_stem(dataset_name: str, index: int, mode: str) -> str:
540
+ safe_name = dataset_name.replace("/", "_").replace(" ", "_")
541
+ prefix = {
542
+ "ft": "ft_case_",
543
+ "ft_improve": "ft_improve_case_",
544
+ "ifr": "ifr_case_",
545
+ "ifr_all_positions": "ifr_all_positions_case_",
546
+ "ifr_all_positions_output_only": "ifr_output_only_case_",
547
+ "attnlrp": "attnlrp_case_",
548
+ "ft_attnlrp": "ft_attnlrp_case_",
549
+ }.get(mode, f"{mode}_case_")
550
+ return f"{prefix}{safe_name}_idx{index}"
551
+
552
+
553
+ def _decode_token_ids(tokenizer: Any, ids: Sequence[int]) -> List[str]:
554
+ """Decode each token id into a readable text piece (keeps special tokens)."""
555
+
556
+ pieces: List[str] = []
557
+ for tok_id in ids:
558
+ try:
559
+ pieces.append(
560
+ tokenizer.decode([int(tok_id)], skip_special_tokens=False, clean_up_tokenization_spaces=False)
561
+ )
562
+ except Exception:
563
+ pieces.append(str(tok_id))
564
+ return pieces
565
+
566
+
567
+ def build_raw_tokens_from_ids(tokenizer: Any, prompt_ids: Optional[Sequence[int]], generation_ids: Optional[Sequence[int]]) -> List[str]:
568
+ if not prompt_ids:
569
+ prompt_ids = []
570
+ if not generation_ids:
571
+ generation_ids = []
572
+ return _decode_token_ids(tokenizer, prompt_ids) + _decode_token_ids(tokenizer, generation_ids)
573
+
574
+
575
+ def build_trimmed_roles(tokens: Sequence[str], segments: Dict[str, Any]) -> List[str]:
576
+ """Assign role labels for trimmed tokens (prompt + generation)."""
577
+
578
+ roles = ["prompt" for _ in range(len(tokens))]
579
+ prompt_len_tokens = segments.get("prompt_len", 0)
580
+ for idx in range(prompt_len_tokens, len(tokens)):
581
+ roles[idx] = "gen"
582
+ thinking_span = segments.get("thinking_span")
583
+ sink_span = segments.get("sink_span")
584
+ if thinking_span is not None:
585
+ start = prompt_len_tokens + int(thinking_span[0])
586
+ end = prompt_len_tokens + int(thinking_span[1])
587
+ for i in range(start, min(len(tokens), end + 1)):
588
+ roles[i] = "think"
589
+ if sink_span is not None:
590
+ start = prompt_len_tokens + int(sink_span[0])
591
+ end = prompt_len_tokens + int(sink_span[1])
592
+ for i in range(start, min(len(tokens), end + 1)):
593
+ roles[i] = "output"
594
+ return roles
595
+
596
+
597
+ def build_raw_roles(
598
+ tokens: Sequence[str],
599
+ prompt_len_full: int,
600
+ user_indices: Sequence[int],
601
+ template_indices: Sequence[int],
602
+ thinking_span_abs: Optional[Sequence[int]],
603
+ sink_span_abs: Optional[Sequence[int]],
604
+ ) -> List[str]:
605
+ """Assign role labels for raw tokens (template + user + generation)."""
606
+
607
+ roles = ["template" for _ in range(len(tokens))]
608
+ user_set = set(int(i) for i in user_indices)
609
+ tmpl_set = set(int(i) for i in template_indices)
610
+
611
+ for i in range(min(len(tokens), prompt_len_full)):
612
+ if i in user_set:
613
+ roles[i] = "user"
614
+ elif i in tmpl_set:
615
+ roles[i] = "template"
616
+ else:
617
+ roles[i] = "prompt"
618
+
619
+ for i in range(prompt_len_full, len(tokens)):
620
+ roles[i] = "gen"
621
+
622
+ if thinking_span_abs is not None:
623
+ start, end = int(thinking_span_abs[0]), int(thinking_span_abs[1])
624
+ for i in range(start, min(len(tokens), end + 1)):
625
+ roles[i] = "think"
626
+
627
+ if sink_span_abs is not None:
628
+ start, end = int(sink_span_abs[0]), int(sink_span_abs[1])
629
+ for i in range(start, min(len(tokens), end + 1)):
630
+ roles[i] = "output"
631
+
632
+ return roles
633
+
634
+
635
+ def extract_prompt_only_vectors(hop_vectors: Sequence[torch.Tensor], prompt_len: int) -> List[torch.Tensor]:
636
+ """Slice hop vectors down to user-prompt tokens only (no generation tokens)."""
637
+
638
+ if prompt_len < 0:
639
+ raise ValueError("prompt_len must be >= 0.")
640
+
641
+ out: List[torch.Tensor] = []
642
+ for vec in hop_vectors:
643
+ v = torch.as_tensor(vec, dtype=torch.float32).detach().cpu()
644
+ if int(v.numel()) < int(prompt_len):
645
+ raise ValueError(f"Hop vector too short for prompt-only slice: len={int(v.numel())} prompt_len={int(prompt_len)}.")
646
+ out.append(v[:prompt_len])
647
+ return out
648
+
649
+
650
+ def _lift_trimmed_to_full(
651
+ trimmed: torch.Tensor,
652
+ *,
653
+ prompt_len_full: int,
654
+ gen_len: int,
655
+ user_prompt_indices: Sequence[int],
656
+ ) -> torch.Tensor:
657
+ """Lift a trimmed (user prompt + generation) vector into full token space with zeros for chat-template tokens."""
658
+
659
+ t = torch.as_tensor(trimmed, dtype=torch.float32).detach().cpu()
660
+ user_len = len(user_prompt_indices)
661
+ expected = int(user_len + gen_len)
662
+ if int(t.numel()) != expected:
663
+ raise ValueError(f"Trimmed vector length mismatch: got {int(t.numel())}, expected {expected}.")
664
+
665
+ total_len = int(prompt_len_full + gen_len)
666
+ full = torch.zeros((total_len,), dtype=torch.float32)
667
+ for j, abs_pos in enumerate(user_prompt_indices):
668
+ full[int(abs_pos)] = t[j]
669
+ full[int(prompt_len_full) : int(prompt_len_full + gen_len)] = t[user_len:]
670
+ return full
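# Illustration (made-up numbers, not part of the module): with a 5-token chat-formatted
# prompt whose user tokens sit at absolute positions [2, 3] and 2 generation tokens,
# a trimmed vector [u0, u1, g0, g1] is lifted so template positions stay zero:
#   _lift_trimmed_to_full(torch.tensor([0.4, 0.6, 0.1, 0.9]),
#                         prompt_len_full=5, gen_len=2, user_prompt_indices=[2, 3])
#   -> tensor([0.0, 0.0, 0.4, 0.6, 0.0, 0.1, 0.9])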
671
+
672
+
673
+ def _postprocess_attnlrp_full_vector(
674
+ raw_full: torch.Tensor,
675
+ *,
676
+ prompt_len_full: int,
677
+ gen_len: int,
678
+ user_prompt_indices: Sequence[int],
679
+ neg_handling: str,
680
+ norm_mode: str,
681
+ ) -> torch.Tensor:
682
+ """Mirror FT-AttnLRP hop postprocessing while preserving stripped-token normalization.
683
+
684
+ The underlying AttnLRP implementation postprocesses the *stripped* vector (user prompt + generation):
685
+ - NaN->0, then neg_handling ('drop' or 'abs')
686
+ - if norm_mode=='norm': normalize by sum over stripped tokens
687
+
688
+ For the pre-trim full view (chat template + user prompt + generation), we apply the same non-negativity transform
689
+ to the full vector and normalize using *only the stripped indices*, so overlapping token scores
690
+ match the trimmed vectors used by the evaluation/case-study hop outputs.
691
+ """
692
+
693
+ v = torch.as_tensor(raw_full, dtype=torch.float32).detach().cpu()
694
+ v = torch.nan_to_num(v, nan=0.0)
695
+
696
+ if neg_handling == "drop":
697
+ v = v.clamp(min=0.0)
698
+ elif neg_handling == "abs":
699
+ v = v.abs()
700
+ else:
701
+ raise ValueError(f"Unsupported neg_handling={neg_handling!r} (expected 'drop' or 'abs').")
702
+
703
+ ratio_enabled = norm_mode == "norm"
704
+ if not ratio_enabled:
705
+ return v
706
+
707
+ keep = list(int(i) for i in user_prompt_indices) + list(range(int(prompt_len_full), int(prompt_len_full + gen_len)))
708
+ if not keep:
709
+ return torch.zeros_like(v)
710
+
711
+ keep_idx = torch.as_tensor(keep, dtype=torch.long)
712
+ denom = float(v.index_select(0, keep_idx).sum().item())
713
+ if denom <= 0.0:
714
+ return torch.zeros_like(v)
715
+ return v / (denom + 1e-12)
716
+
717
+
718
+ def main() -> None:
719
+ args = parse_args()
720
+ device = resolve_device(args.cuda, args.cuda_num)
721
+ if torch.cuda.is_available():
722
+ visible = os.environ.get("CUDA_VISIBLE_DEVICES")
723
+ print(f"[info] CUDA_VISIBLE_DEVICES={visible!r} torch.cuda.device_count()={torch.cuda.device_count()} device={device}")
724
+
725
+ model_name = args.model_path if args.model_path is not None else args.model
726
+ # Align with exp/exp2: always use the shared fp16 loader.
727
+ model, tokenizer = load_model(model_name, device)
728
+
729
+ example, ds_name = load_example(args.dataset, args.index, Path(args.data_root))
730
+ mode = args.mode
731
+
732
+ sink_span: Optional[Tuple[int, int]] = None
733
+ thinking_span: Optional[Tuple[int, int]] = None
734
+ thinking_ratios: Optional[Sequence[float]] = None
735
+
736
+ prompt_tokens_trimmed: List[str] = []
737
+ generation_tokens_trimmed: List[str] = []
738
+ hop_vectors_trimmed: List[torch.Tensor] = []
739
+ hop_vectors_raw: List[torch.Tensor] = []
740
+ prompt_len_full: Optional[int] = None
741
+ user_prompt_indices: List[int] = []
742
+ chat_prompt_indices: List[int] = []
743
+ method_meta: Dict[str, Any] = {}
744
+ raw_prompt_ids: Optional[List[int]] = None
745
+ raw_generation_ids: Optional[List[int]] = None
746
+ attnlrp_raw_attributions: Optional[List[Any]] = None
747
+
748
+ if mode in ("ft", "ft_improve", "ft_split_hop", "ifr_in_all_gen"):
749
+ if mode == "ft":
750
+ attr_result, sink_span, thinking_span, debug_info = run_ft_multihop(
751
+ example,
752
+ model,
753
+ tokenizer,
754
+ n_hops=args.n_hops,
755
+ sink_span=args.sink_span,
756
+ thinking_span=args.thinking_span,
757
+ chunk_tokens=args.chunk_tokens,
758
+ sink_chunk_tokens=args.sink_chunk_tokens,
759
+ )
760
+ elif mode == "ft_improve":
761
+ attr_result, sink_span, thinking_span, debug_info = run_ft_multihop_improve(
762
+ example,
763
+ model,
764
+ tokenizer,
765
+ n_hops=args.n_hops,
766
+ sink_span=args.sink_span,
767
+ thinking_span=args.thinking_span,
768
+ chunk_tokens=args.chunk_tokens,
769
+ sink_chunk_tokens=args.sink_chunk_tokens,
770
+ )
771
+ elif mode == "ft_split_hop":
772
+ attr_result, sink_span, thinking_span, debug_info = run_ft_multihop_split_hop(
773
+ example,
774
+ model,
775
+ tokenizer,
776
+ n_hops=args.n_hops,
777
+ sink_span=args.sink_span,
778
+ thinking_span=args.thinking_span,
779
+ chunk_tokens=args.chunk_tokens,
780
+ sink_chunk_tokens=args.sink_chunk_tokens,
781
+ )
782
+ elif mode == "ifr_in_all_gen":
783
+ attr_result, sink_span, thinking_span, debug_info = run_ifr_in_all_gen(
784
+ example,
785
+ model,
786
+ tokenizer,
787
+ n_hops=args.n_hops,
788
+ sink_span=args.sink_span,
789
+ thinking_span=args.thinking_span,
790
+ chunk_tokens=args.chunk_tokens,
791
+ sink_chunk_tokens=args.sink_chunk_tokens,
792
+ )
793
+ else:
794
+ raise ValueError(f"Unsupported mode={mode}")
795
+ ifr_meta = (attr_result.metadata or {}).get("ifr") or {}
796
+ hop_vectors_trimmed = list(ifr_meta.get("per_hop_projected") or [])
797
+ if not hop_vectors_trimmed:
798
+ raise RuntimeError(f"No per-hop vectors found for {mode} mode.")
799
+
800
+ prompt_tokens_trimmed = list(attr_result.prompt_tokens)
801
+ generation_tokens_trimmed = list(attr_result.generation_tokens)
802
+ thinking_ratios = ifr_meta.get("thinking_ratios")
803
+
804
+ raw_prompt_ids = debug_info.get("prompt_ids")
805
+ if isinstance(raw_prompt_ids, list) and raw_prompt_ids and isinstance(raw_prompt_ids[0], list):
806
+ raw_prompt_ids = raw_prompt_ids[0]
807
+ raw_generation_ids = debug_info.get("generation_ids")
808
+ if isinstance(raw_generation_ids, list) and raw_generation_ids and isinstance(raw_generation_ids[0], list):
809
+ raw_generation_ids = raw_generation_ids[0]
810
+
811
+ user_prompt_indices = list(debug_info.get("user_prompt_indices") or [])
812
+ chat_prompt_indices = list(debug_info.get("chat_prompt_indices") or [])
813
+ prompt_len_full = len(raw_prompt_ids) if isinstance(raw_prompt_ids, list) else None
814
+
815
+ raw_vectors = debug_info.get("raw_hop_vectors") or []
816
+ hop_vectors_raw = [vec.detach().cpu() if hasattr(vec, "detach") else torch.as_tensor(vec) for vec in raw_vectors]
817
+ method_meta = {"ifr": analysis.sanitize_ifr_meta(ifr_meta)}
818
+
819
+ elif mode == "ifr":
820
+ # Standard IFR (single-hop span aggregate), with pre/post trim views.
821
+ attr = llm_attr.LLMIFRAttribution(
822
+ model,
823
+ tokenizer,
824
+ chunk_tokens=args.chunk_tokens,
825
+ sink_chunk_tokens=args.sink_chunk_tokens,
826
+ )
827
+ sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
828
+ thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span
829
+
830
+ if sink_span is None:
831
+ raise ValueError("sink_span is required for IFR mode (use dataset sink_span or pass --sink_span).")
832
+ span_result = attr.calculate_ifr_span(
833
+ example.prompt,
834
+ target=example.target,
835
+ span=tuple(sink_span),
836
+ )
837
+ span_meta = span_result.metadata.get("ifr") if span_result.metadata else None
838
+ aggregate = span_meta.get("aggregate") if isinstance(span_meta, dict) else None
839
+ if aggregate is None or not hasattr(aggregate, "token_importance_total"):
840
+ raise RuntimeError("IFR span aggregate missing from metadata; cannot render pre-trim view.")
841
+
842
+ raw_vector = aggregate.token_importance_total.detach().cpu()
843
+ trimmed_vector = attr._project_vector(raw_vector)
844
+ hop_vectors_raw = [raw_vector]
845
+ hop_vectors_trimmed = [trimmed_vector]
846
+
847
+ prompt_tokens_trimmed = list(attr.user_prompt_tokens)
848
+ generation_tokens_trimmed = list(attr.generation_tokens)
849
+
850
+ raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0]
851
+ raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0]
852
+ user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or [])
853
+ chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or [])
854
+ prompt_len_full = len(raw_prompt_ids)
855
+
856
+ sink_abs = (prompt_len_full + sink_span[0], prompt_len_full + sink_span[1])
857
+ think_abs = (prompt_len_full + thinking_span[0], prompt_len_full + thinking_span[1]) if thinking_span else None
858
+
859
+ meta = {
860
+ "type": "span_aggregate",
861
+ "ifr_view": "aggregate",
862
+ "sink_span_generation": sink_span,
863
+ "sink_span_absolute": sink_abs,
864
+ "thinking_span_generation": thinking_span,
865
+ "thinking_span_absolute": think_abs,
866
+ }
867
+ method_meta = {"ifr": analysis.tensor_to_list(meta)}
868
+
869
+ elif mode == "ifr_all_positions_output_only":
870
+ # IFR all-positions (output-only) + token-level CAGE (row/recursive) derived from the matrix.
871
+ attr = llm_attr.LLMIFRAttribution(
872
+ model,
873
+ tokenizer,
874
+ chunk_tokens=args.chunk_tokens,
875
+ sink_chunk_tokens=args.sink_chunk_tokens,
876
+ )
877
+ sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
878
+ thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span
879
+
880
+ if sink_span is None:
881
+ raise ValueError(
882
+ "sink_span is required for ifr_all_positions_output_only mode "
883
+ "(use dataset sink_span or pass --sink_span)."
884
+ )
885
+
886
+ attr_result = attr.calculate_ifr_for_all_positions_output_only(
887
+ example.prompt,
888
+ target=example.target,
889
+ sink_span=tuple(sink_span),
890
+ )
891
+
892
+ indices_to_explain = list(sink_span)
893
+ _, row_attr, rec_attr = attr_result.get_all_token_attrs(indices_to_explain)
894
+ row_vec = row_attr.squeeze(0).detach().cpu()
895
+ rec_vec = rec_attr.squeeze(0).detach().cpu()
896
+
897
+ hop_vectors_trimmed = [row_vec, rec_vec]
898
+
899
+ prompt_tokens_trimmed = list(attr.user_prompt_tokens)
900
+ generation_tokens_trimmed = list(attr.generation_tokens)
901
+
902
+ raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0]
903
+ raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0]
904
+ user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or [])
905
+ chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or [])
906
+ prompt_len_full = len(raw_prompt_ids)
907
+
908
+ gen_len = len(raw_generation_ids or [])
909
+ hop_vectors_raw = [
910
+ _lift_trimmed_to_full(
911
+ v,
912
+ prompt_len_full=int(prompt_len_full or 0),
913
+ gen_len=gen_len,
914
+ user_prompt_indices=user_prompt_indices,
915
+ )
916
+ for v in hop_vectors_trimmed
917
+ ]
918
+
919
+ ifr_meta = dict((attr_result.metadata or {}).get("ifr") or {})
920
+ ifr_meta["ifr_view"] = "all_positions_output_only (row+rec)"
921
+ ifr_meta["panel_titles"] = ["Row attribution", "Recursive attribution (CAGE)"]
922
+ ifr_meta["indices_to_explain"] = indices_to_explain
923
+ method_meta = {"ifr": analysis.tensor_to_list(ifr_meta)}
924
+
925
+ elif mode == "ifr_all_positions":
926
+ # IFR all-positions (full generation) + token-level CAGE (row/recursive) derived from the matrix.
927
+ attr = llm_attr.LLMIFRAttribution(
928
+ model,
929
+ tokenizer,
930
+ chunk_tokens=args.chunk_tokens,
931
+ sink_chunk_tokens=args.sink_chunk_tokens,
932
+ )
933
+ sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
934
+ thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span
935
+
936
+ if sink_span is None:
937
+ raise ValueError(
938
+ "sink_span is required for ifr_all_positions mode (use dataset sink_span or pass --sink_span)."
939
+ )
940
+
941
+ attr_result = attr.calculate_ifr_for_all_positions(
942
+ example.prompt,
943
+ target=example.target,
944
+ )
945
+
946
+ indices_to_explain = list(sink_span)
947
+ _, row_attr, rec_attr = attr_result.get_all_token_attrs(indices_to_explain)
948
+ row_vec = row_attr.squeeze(0).detach().cpu()
949
+ rec_vec = rec_attr.squeeze(0).detach().cpu()
950
+
951
+ hop_vectors_trimmed = [row_vec, rec_vec]
952
+
953
+ prompt_tokens_trimmed = list(attr.user_prompt_tokens)
954
+ generation_tokens_trimmed = list(attr.generation_tokens)
955
+
956
+ raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0]
957
+ raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0]
958
+ user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or [])
959
+ chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or [])
960
+ prompt_len_full = len(raw_prompt_ids)
961
+
962
+ gen_len = len(raw_generation_ids or [])
963
+ hop_vectors_raw = [
964
+ _lift_trimmed_to_full(
965
+ v,
966
+ prompt_len_full=int(prompt_len_full or 0),
967
+ gen_len=gen_len,
968
+ user_prompt_indices=user_prompt_indices,
969
+ )
970
+ for v in hop_vectors_trimmed
971
+ ]
972
+
973
+ ifr_meta = dict((attr_result.metadata or {}).get("ifr") or {})
974
+ ifr_meta["ifr_view"] = "all_positions (row+rec)"
975
+ ifr_meta["panel_titles"] = ["Row attribution", "Recursive attribution (CAGE)"]
976
+ ifr_meta["indices_to_explain"] = indices_to_explain
977
+ method_meta = {"ifr": analysis.tensor_to_list(ifr_meta)}
978
+
979
+ elif mode in ("attnlrp", "ft_attnlrp"):
980
+ # Reuse the shared LLMLRPAttribution implementations (root-level).
981
+ attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
982
+
983
+ sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
984
+ thinking_span = (
985
+ tuple(args.thinking_span)
986
+ if args.thinking_span is not None
987
+ else tuple(example.thinking_span) if example.thinking_span else sink_span
988
+ )
989
+
990
+ if mode == "attnlrp":
991
+ # Case-study AttnLRP: reuse FT-AttnLRP logic but take hop0 (the first span-aggregate)
992
+ # for a full, signed attribution vector (no observation masking).
993
+ attr_result = attributor.calculate_attnlrp_ft_hop0(
994
+ example.prompt,
995
+ target=example.target,
996
+ sink_span=sink_span,
997
+ thinking_span=thinking_span,
998
+ neg_handling=args.attnlrp_neg_handling,
999
+ norm_mode=args.attnlrp_norm_mode,
1000
+ )
1001
+ meta = attr_result.metadata or {}
1002
+ multi_hop = meta.get("multi_hop_result")
1003
+ raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
1004
+ attnlrp_raw_attributions = list(raw_attributions)
1005
+ base_attr = raw_attributions[0] if raw_attributions else None
1006
+ if base_attr is None or not hasattr(base_attr, "token_importance_total"):
1007
+ raise RuntimeError("AttnLRP hop0 missing from multi-hop result.")
1008
+
1009
+ hop0_vec = torch.as_tensor(getattr(base_attr, "token_importance_total"), dtype=torch.float32).detach().cpu()
1010
+ if hop0_vec.numel() <= 0:
1011
+ raise RuntimeError("Empty generation for AttnLRP case study.")
1012
+
1013
+ # Use the actual sink span applied by hop0 (defaults to full generation when unset).
1014
+ sink_span = tuple(getattr(base_attr, "sink_range"))
1015
+ if thinking_span is None:
1016
+ thinking_span = sink_span
1017
+
1018
+ hop_vectors_trimmed = [hop0_vec]
1019
+ thinking_ratios = list(getattr(multi_hop, "thinking_ratios", []) or [])
1020
+
1021
+ method_meta = {
1022
+ "attnlrp": {
1023
+ "type": "calculate_attnlrp_multi_hop(n_hops=0) hop0 raw_attributions[0]",
1024
+ "sink_span_generation": sink_span,
1025
+ "thinking_span_generation": thinking_span,
1026
+ "thinking_ratios": thinking_ratios,
1027
+ "neg_handling": args.attnlrp_neg_handling,
1028
+ "norm_mode": args.attnlrp_norm_mode,
1029
+ "ratio_enabled": args.attnlrp_norm_mode == "norm",
1030
+ }
1031
+ }
1032
+ else:
1033
+ # exp2 ft_attnlrp: multi-hop aggregated AttnLRP (metadata contains per-hop vectors).
1034
+ attr_result = attributor.calculate_attnlrp_aggregated_multi_hop(
1035
+ example.prompt,
1036
+ target=example.target,
1037
+ sink_span=sink_span,
1038
+ thinking_span=thinking_span,
1039
+ n_hops=int(args.n_hops),
1040
+ neg_handling=args.attnlrp_neg_handling,
1041
+ norm_mode=args.attnlrp_norm_mode,
1042
+ )
1043
+ meta = attr_result.metadata or {}
1044
+ multi_hop = meta.get("multi_hop_result")
1045
+ if multi_hop is None:
1046
+ raise RuntimeError("FT-AttnLRP case study missing metadata.multi_hop_result.")
1047
+
1048
+ raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
1049
+ attnlrp_raw_attributions = list(raw_attributions)
1050
+ hop_vectors_trimmed = [
1051
+ torch.as_tensor(getattr(hop, "token_importance_total"), dtype=torch.float32).detach().cpu()
1052
+ for hop in raw_attributions
1053
+ ]
1054
+ thinking_ratios = list(getattr(multi_hop, "thinking_ratios", []) or [])
1055
+
1056
+ method_meta = {
1057
+ "attnlrp": {
1058
+ "type": "calculate_attnlrp_aggregated_multi_hop (exp2 ft_attnlrp)",
1059
+ "n_hops": int(args.n_hops),
1060
+ "sink_span_generation": sink_span,
1061
+ "thinking_span_generation": thinking_span,
1062
+ "thinking_ratios": thinking_ratios,
1063
+ "neg_handling": args.attnlrp_neg_handling,
1064
+ "norm_mode": args.attnlrp_norm_mode,
1065
+ "ratio_enabled": args.attnlrp_norm_mode == "norm",
1066
+ }
1067
+ }
1068
+
1069
+ prompt_tokens_trimmed = list(attributor.user_prompt_tokens)
1070
+ generation_tokens_trimmed = list(attributor.generation_tokens)
1071
+
1072
+ raw_prompt_ids = attributor.prompt_ids.detach().cpu().tolist()[0]
1073
+ raw_generation_ids = attributor.generation_ids.detach().cpu().tolist()[0]
1074
+ user_prompt_indices = list(getattr(attributor, "user_prompt_indices", []) or [])
1075
+ chat_prompt_indices = list(getattr(attributor, "chat_prompt_indices", []) or [])
1076
+ prompt_len_full = len(raw_prompt_ids)
1077
+
1078
+ else:
1079
+ raise ValueError(f"Unsupported mode={mode}")
1080
+
1081
+ if not hop_vectors_trimmed:
1082
+ raise RuntimeError("No hop vectors to visualize.")
1083
+
1084
+ raw_tokens = build_raw_tokens_from_ids(tokenizer, raw_prompt_ids, raw_generation_ids)
1085
+
1086
+ sink_span_abs = None
1087
+ thinking_span_abs = None
1088
+ if prompt_len_full is not None and sink_span is not None:
1089
+ sink_span_abs = (prompt_len_full + sink_span[0], prompt_len_full + sink_span[1])
1090
+ if prompt_len_full is not None and thinking_span is not None:
1091
+ thinking_span_abs = (prompt_len_full + thinking_span[0], prompt_len_full + thinking_span[1])
1092
+ prompt_len_full_safe = int(prompt_len_full or 0)
1093
+ roles_raw = build_raw_roles(
1094
+ raw_tokens,
1095
+ prompt_len_full_safe,
1096
+ user_prompt_indices,
1097
+ chat_prompt_indices,
1098
+ thinking_span_abs,
1099
+ sink_span_abs,
1100
+ )
1101
+
1102
+ prompt_tokens_only = list(prompt_tokens_trimmed)
1103
+ prompt_only_vectors = extract_prompt_only_vectors(hop_vectors_trimmed, len(prompt_tokens_only))
1104
+
1105
+ # Ensure every method has a pre-trim full vector per panel.
1106
+ if not hop_vectors_raw:
1107
+ if mode in ("attnlrp", "ft_attnlrp") and attnlrp_raw_attributions is not None:
1108
+ gen_len = len(raw_generation_ids or [])
1109
+ expected = int((prompt_len_full_safe + gen_len) if prompt_len_full is not None else 0)
1110
+ full_vectors: List[torch.Tensor] = []
1111
+ for hop in attnlrp_raw_attributions:
1112
+ meta = getattr(hop, "metadata", None) or {}
1113
+ raw_full = meta.get("token_importance_total_with_chat_template")
1114
+ if raw_full is None:
1115
+ full_vectors = []
1116
+ break
1117
+ v = _postprocess_attnlrp_full_vector(
1118
+ torch.as_tensor(raw_full, dtype=torch.float32),
1119
+ prompt_len_full=prompt_len_full_safe,
1120
+ gen_len=gen_len,
1121
+ user_prompt_indices=user_prompt_indices,
1122
+ neg_handling=args.attnlrp_neg_handling,
1123
+ norm_mode=args.attnlrp_norm_mode,
1124
+ )
1125
+ if expected and int(v.numel()) != expected:
1126
+ raise RuntimeError(
1127
+ "AttnLRP full-vector length mismatch for pre-trim view: "
1128
+ f"got {int(v.numel())}, expected {expected}."
1129
+ )
1130
+ full_vectors.append(v)
1131
+ hop_vectors_raw = full_vectors
1132
+
1133
+ if not hop_vectors_raw and prompt_len_full is not None:
1134
+ # Fallback: lift trimmed vectors back to full token space with zeros for template tokens.
1135
+ gen_len = len(raw_generation_ids or [])
1136
+ hop_vectors_raw = [
1137
+ _lift_trimmed_to_full(
1138
+ v,
1139
+ prompt_len_full=prompt_len_full_safe,
1140
+ gen_len=gen_len,
1141
+ user_prompt_indices=user_prompt_indices,
1142
+ )
1143
+ for v in hop_vectors_trimmed
1144
+ ]
1145
+
1146
+ if not hop_vectors_raw:
1147
+ raise RuntimeError("Missing pre-trim vectors; cannot render required full-sequence heatmap.")
1148
+
1149
+ # Lightweight debug stats to catch silent all-zero / NaN cases.
1150
+ hop_stats_raw = [analysis.vector_stats(torch.nan_to_num(v.detach().cpu(), nan=0.0)) for v in hop_vectors_raw]
1151
+ hop_stats_prompt = [analysis.vector_stats(torch.nan_to_num(v.detach().cpu(), nan=0.0)) for v in prompt_only_vectors]
1152
+ for i in range(max(len(hop_stats_raw), len(hop_stats_prompt))):
1153
+ raw_abs = hop_stats_raw[i]["abs_max"] if i < len(hop_stats_raw) else None
1154
+ prompt_abs = hop_stats_prompt[i]["abs_max"] if i < len(hop_stats_prompt) else None
1155
+ print(f"[stats] panel {i}: raw_abs_max={raw_abs} prompt_abs_max={prompt_abs}")
1156
+
1157
+ hop_token_raw = analysis.package_token_hops(hop_vectors_raw)
1158
+ hop_token_prompt = analysis.package_token_hops(prompt_only_vectors)
1159
+
1160
+ case_meta: Dict[str, Any] = {
1161
+ "dataset": ds_name,
1162
+ "index": args.index,
1163
+ "sink_span": sink_span,
1164
+ "thinking_span": thinking_span,
1165
+ "n_hops": args.n_hops,
1166
+ "thinking_ratios": thinking_ratios,
1167
+ "mode": mode,
1168
+ "ifr_view": method_meta.get("ifr", {}).get("ifr_view") if isinstance(method_meta.get("ifr"), dict) else None,
1169
+ "panel_titles": method_meta.get("ifr", {}).get("panel_titles") if isinstance(method_meta.get("ifr"), dict) else None,
1170
+ "attnlrp_neg_handling": args.attnlrp_neg_handling if mode in ("attnlrp", "ft_attnlrp") else None,
1171
+ "attnlrp_norm_mode": args.attnlrp_norm_mode if mode in ("attnlrp", "ft_attnlrp") else None,
1172
+ "attnlrp_ratio_enabled": (args.attnlrp_norm_mode == "norm") if mode in ("attnlrp", "ft_attnlrp") else None,
1173
+ "vector_stats_raw": hop_stats_raw,
1174
+ "vector_stats_prompt": hop_stats_prompt,
1175
+ }
1176
+
1177
+ generation_text = "".join(generation_tokens_trimmed) if generation_tokens_trimmed else ""
1178
+ prompt_text = example.prompt
1179
+ record = {
1180
+ "meta": case_meta,
1181
+ "prompt": prompt_text,
1182
+ "target": example.target,
1183
+ "generation": generation_text,
1184
+ "full_all_tokens": raw_tokens,
1185
+ "raw_token_roles": roles_raw,
1186
+ "prompt_tokens": prompt_tokens_only,
1187
+ "prompt_token_roles": ["user" for _ in range(len(prompt_tokens_only))],
1188
+ "token_hops_raw": hop_token_raw,
1189
+ "token_hops_prompt": hop_token_prompt,
1190
+ "ifr_meta": method_meta.get("ifr"),
1191
+ "attnlrp_meta": method_meta.get("attnlrp"),
1192
+ }
1193
+
1194
+ out_dir = Path(args.output_dir)
1195
+ out_dir.mkdir(parents=True, exist_ok=True)
1196
+ stem = make_output_stem(ds_name, args.index, mode)
1197
+ json_path = out_dir / f"{stem}.json"
1198
+ html_path = out_dir / f"{stem}.html"
1199
+
1200
+ with json_path.open("w", encoding="utf-8") as f:
1201
+ json.dump(record, f, ensure_ascii=False, indent=2)
1202
+
1203
+ html = viz.render_case_html(
1204
+ case_meta,
1205
+ token_view_raw={
1206
+ "label": "Pre-trim token-level heatmap (full sequence with chat template)",
1207
+ "tokens": raw_tokens,
1208
+ "roles": roles_raw,
1209
+ "hops": hop_token_raw,
1210
+ },
1211
+ token_view_prompt={
1212
+ "label": "Prompt-only token-level heatmap (user prompt only)",
1213
+ "tokens": prompt_tokens_only,
1214
+ "roles": ["user" for _ in range(len(prompt_tokens_only))],
1215
+ "hops": hop_token_prompt,
1216
+ },
1217
+ )
1218
+ html_path.write_text(html, encoding="utf-8")
1219
+
1220
+ print(f"[done] wrote {json_path}")
1221
+ print(f"[done] wrote {html_path}")
1222
+
1223
+
1224
+ if __name__ == "__main__":
1225
+ main()
exp/case_study/run_mas_case.py ADDED
@@ -0,0 +1,805 @@
1
+ #!/usr/bin/env python3
2
+ """MAS case study: visualize token-perturbation faithfulness for attribution methods.
3
+
4
+ This script matches the faithfulness evaluation logic implemented in:
5
+ - evaluations/faithfulness.py
6
+ - llm_attr_eval.LLMAttributionEvaluator.faithfulness_test()
7
+
8
+ For a single example and a selected attribution method, we:
9
+ 1) Compute token-level attributions (Seq / Row / Recursive) over prompt tokens.
10
+ 2) Rank prompt tokens by attribution mass.
11
+ 3) Iteratively perturb the prompt by replacing one token at a time with PAD tokens.
12
+ 4) Score the model as sum log p(generation + EOS | prompt) under the chat template.
13
+ 5) Compute RISE / MAS / RISE+AP (AUCs) and visualize the perturbation impact as token heatmaps.
14
+
15
+ Outputs JSON + HTML to exp/case_study/out/.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import json
22
+ import os
23
+ import sys
24
+ import types
25
+ from importlib.machinery import ModuleSpec
26
+ from pathlib import Path
27
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
28
+
29
+ import numpy as np
30
+
31
+
32
+ def _early_set_cuda_visible_devices() -> None:
33
+ """Set CUDA_VISIBLE_DEVICES before importing torch/transformers.
34
+
35
+ Note: CUDA device indices are re-mapped inside the process after applying the mask.
36
+ """
37
+
38
+ parser = argparse.ArgumentParser(add_help=False)
39
+ parser.add_argument("--cuda", type=str, default=None)
40
+ args, _ = parser.parse_known_args(sys.argv[1:])
41
+ cuda = args.cuda.strip() if isinstance(args.cuda, str) else ""
42
+ if cuda and "," in cuda:
43
+ os.environ["CUDA_VISIBLE_DEVICES"] = cuda
44
+
45
+
46
+ if __name__ == "__main__":
47
+ _early_set_cuda_visible_devices()
48
+
49
+ import torch
50
+
51
+ REPO_ROOT = Path(__file__).resolve().parents[2]
52
+ if str(REPO_ROOT) not in sys.path:
53
+ sys.path.insert(0, str(REPO_ROOT))
54
+
55
+ # Avoid optional vision deps when importing transformers.
56
+ os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
57
+ os.environ.setdefault("DISABLE_TRANSFORMERS_IMAGE_TRANSFORMS", "1")
58
+
59
+
60
+ def _stub_torchvision() -> None:
61
+ """Provide minimal torchvision stubs so transformers imports succeed without torchvision."""
62
+
63
+ if "torchvision" in sys.modules:
64
+ return
65
+
66
+ def _mk(name: str) -> types.ModuleType:
67
+ mod = types.ModuleType(name)
68
+ mod.__spec__ = ModuleSpec(name, loader=None)
69
+ return mod
70
+
71
+ tv = _mk("torchvision")
72
+ tv.__dict__["__path__"] = []
73
+ submods = ["transforms", "_meta_registrations", "datasets", "io", "models", "ops", "utils"]
74
+ for name in submods:
75
+ mod = _mk(f"torchvision.{name}")
76
+ sys.modules[f"torchvision.{name}"] = mod
77
+ setattr(tv, name, mod)
78
+
79
+ class _InterpolationMode:
80
+ NEAREST = 0
81
+ NEAREST_EXACT = 0
82
+ BILINEAR = 1
83
+ BICUBIC = 2
84
+ LANCZOS = 3
85
+ BOX = 4
86
+ HAMMING = 5
87
+
88
+ sys.modules["torchvision.transforms"].InterpolationMode = _InterpolationMode
89
+ sys.modules["torchvision.transforms"].__all__ = ["InterpolationMode"]
90
+
91
+ ops_mod = sys.modules.get("torchvision.ops") or _mk("torchvision.ops")
92
+ sys.modules["torchvision.ops"] = ops_mod
93
+ setattr(tv, "ops", ops_mod)
94
+ misc_mod = _mk("torchvision.ops.misc")
95
+ sys.modules["torchvision.ops.misc"] = misc_mod
96
+ setattr(ops_mod, "misc", misc_mod)
97
+
98
+ class _FrozenBatchNorm2d:
99
+ def __init__(self, *args, **kwargs):
100
+ pass
101
+
102
+ misc_mod.FrozenBatchNorm2d = _FrozenBatchNorm2d
103
+ sys.modules["torchvision"] = tv
104
+
105
+
106
+ def _stub_timm() -> None:
107
+ """Provide minimal timm stubs to avoid optional vision deps."""
108
+
109
+ if "timm" in sys.modules:
110
+ return
111
+
112
+ def _mk(name: str) -> types.ModuleType:
113
+ mod = types.ModuleType(name)
114
+ mod.__spec__ = ModuleSpec(name, loader=None)
115
+ return mod
116
+
117
+ timm = _mk("timm")
118
+ timm.__dict__["__path__"] = []
119
+ sys.modules["timm"] = timm
120
+
121
+ data_mod = _mk("timm.data")
122
+ sys.modules["timm.data"] = data_mod
123
+ timm.data = data_mod
124
+
125
+ class _ImageNetInfo:
126
+ pass
127
+
128
+ def _infer_imagenet_subset(*args, **kwargs):
129
+ return None
130
+
131
+ data_mod.ImageNetInfo = _ImageNetInfo
132
+ data_mod.infer_imagenet_subset = _infer_imagenet_subset
133
+
134
+ layers_mod = _mk("timm.layers")
135
+ sys.modules["timm.layers"] = layers_mod
136
+ timm.layers = layers_mod
137
+
138
+ create_norm_mod = _mk("timm.layers.create_norm")
139
+ sys.modules["timm.layers.create_norm"] = create_norm_mod
140
+ layers_mod.create_norm = create_norm_mod
141
+
142
+ def _get_norm_layer(*args, **kwargs):
143
+ return None
144
+
145
+ create_norm_mod.get_norm_layer = _get_norm_layer
146
+
147
+ classifier_mod = _mk("timm.layers.classifier")
148
+ sys.modules["timm.layers.classifier"] = classifier_mod
149
+ layers_mod.classifier = classifier_mod
150
+
151
+
152
+ def _stub_gemma3n() -> None:
153
+ """Stub Gemma3n config module if transformers tries to import it."""
154
+
155
+ if "transformers.models.gemma3n.configuration_gemma3n" in sys.modules:
156
+ return
157
+
158
+ gemma_pkg = types.ModuleType("transformers.models.gemma3n")
159
+ gemma_pkg.__spec__ = ModuleSpec("transformers.models.gemma3n", loader=None, is_package=True)
160
+ sys.modules["transformers.models.gemma3n"] = gemma_pkg
161
+
162
+ gemma_conf = types.ModuleType("transformers.models.gemma3n.configuration_gemma3n")
163
+ gemma_conf.__spec__ = ModuleSpec("transformers.models.gemma3n.configuration_gemma3n", loader=None)
164
+
165
+ class Gemma3nConfig:
166
+ def __init__(self, *args, **kwargs):
167
+ self.model_type = "gemma3n"
168
+
169
+ class Gemma3nTextConfig(Gemma3nConfig):
170
+ pass
171
+
172
+ gemma_conf.Gemma3nConfig = Gemma3nConfig
173
+ gemma_conf.Gemma3nTextConfig = Gemma3nTextConfig
174
+ gemma_conf.__all__ = ["Gemma3nConfig", "Gemma3nTextConfig"]
175
+ sys.modules["transformers.models.gemma3n.configuration_gemma3n"] = gemma_conf
176
+ setattr(gemma_pkg, "configuration_gemma3n", gemma_conf)
177
+
178
+
179
+ _stub_torchvision()
180
+ _stub_timm()
181
+ _stub_gemma3n()
182
+
183
+ import transformers # noqa: E402
184
+
185
+ # Provide light stubs if Longformer classes are unavailable; we don't use them here.
186
+ if not hasattr(transformers, "LongformerTokenizer"):
187
+ class _DummyLongformerTokenizer:
188
+ def __init__(self, *args, **kwargs):
189
+ raise ImportError("LongformerTokenizer stubbed; install full transformers if needed.")
190
+ transformers.LongformerTokenizer = _DummyLongformerTokenizer
191
+ if not hasattr(transformers, "LongformerForMaskedLM"):
192
+ class _DummyLongformerForMaskedLM:
193
+ def __init__(self, *args, **kwargs):
194
+ raise ImportError("LongformerForMaskedLM stubbed; install full transformers if needed.")
195
+ transformers.LongformerForMaskedLM = _DummyLongformerForMaskedLM
196
+
197
+ from exp.case_study import viz # noqa: E402
198
+ from exp.exp2 import dataset_utils as ds_utils # noqa: E402
199
+ from shared_utils import DEFAULT_PROMPT_TEMPLATE # noqa: E402
200
+
201
+ import llm_attr # noqa: E402
202
+ from evaluations.attribution_recovery import load_model # noqa: E402
203
+
204
+
205
+ def resolve_device(cuda: Optional[str], cuda_num: int) -> str:
206
+ if cuda and isinstance(cuda, str) and "," in cuda:
207
+ os.environ["CUDA_VISIBLE_DEVICES"] = cuda
208
+ return "auto"
209
+ if cuda and isinstance(cuda, str) and cuda.strip():
210
+ try:
211
+ idx = int(cuda)
212
+ except Exception:
213
+ idx = 0
214
+ return f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
215
+ return f"cuda:{cuda_num}" if torch.cuda.is_available() else "cpu"
216
+
217
+
218
+ def load_example(dataset: str, index: int, data_root: Path) -> Tuple[ds_utils.CachedExample, str]:
219
+ ds_path = Path(dataset)
220
+ if ds_path.exists():
221
+ examples = ds_utils.read_cached_jsonl(ds_path)
222
+ dataset_name = ds_path.name
223
+ else:
224
+ loader = ds_utils.DatasetLoader(data_root=data_root)
225
+ examples = loader.load(dataset)
226
+ dataset_name = dataset
227
+
228
+ if not examples:
229
+ raise ValueError(f"No examples found for dataset={dataset}")
230
+
231
+ if index < 0:
232
+ index = len(examples) + index
233
+ if not (0 <= index < len(examples)):
234
+ raise IndexError(f"index {index} out of range for dataset with {len(examples)} examples")
235
+
236
+ return examples[index], dataset_name
237
+
238
+
239
+ def make_output_stem(dataset_name: str, index: int, method: str) -> str:
240
+ safe_name = dataset_name.replace("/", "_").replace(" ", "_")
241
+ return f"mas_case_{method}_{safe_name}_idx{index}"
242
+
243
+
244
+ def format_prompt(tokenizer: Any, prompt: str) -> str:
245
+ modified_prompt = DEFAULT_PROMPT_TEMPLATE.format(context=prompt, query="")
246
+ formatted_prompt = [{"role": "user", "content": modified_prompt}]
247
+ return tokenizer.apply_chat_template(
248
+ formatted_prompt,
249
+ tokenize=False,
250
+ add_generation_prompt=True,
251
+ enable_thinking=False,
252
+ )
253
+
254
+
255
+ @torch.inference_mode()
256
+ def compute_logprob_response_given_prompt(model: Any, prompt_ids: torch.Tensor, response_ids: torch.Tensor) -> torch.Tensor:
257
+ """Compute log-probabilities of response_ids given prompt_ids.
258
+
259
+ Shapes:
260
+ prompt_ids: [B, N]
261
+ response_ids: [B, M]
262
+ returns: [B, M]
263
+ """
264
+ input_ids = torch.cat([prompt_ids, response_ids], dim=1)
265
+ attention_mask = torch.ones_like(input_ids)
266
+ logits = model(input_ids=input_ids, attention_mask=attention_mask).logits # [B, N+M, V]
267
+ log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
268
+
269
+ response_start = int(prompt_ids.shape[1])
270
+ logits_for_response = log_probs[:, response_start - 1 : -1, :]  # [B, M, V]; log-probs at position i predict token i+1, hence the -1 shift
271
+ gathered = logits_for_response.gather(2, response_ids.unsqueeze(-1))
272
+ return gathered.squeeze(-1)
273
+
274
+
275
+ @torch.inference_mode()
276
+ def score_prompt_ids_with_generation(model: Any, *, prompt_ids: torch.Tensor, generation_ids: torch.Tensor) -> float:
277
+ return float(compute_logprob_response_given_prompt(model, prompt_ids, generation_ids).sum().detach().cpu().item())
278
+
279
+
280
+ @torch.inference_mode()
281
+ def _ensure_pad_token_id(tokenizer: Any) -> int:
282
+ if tokenizer.pad_token_id is None:
283
+ if tokenizer.eos_token_id is None:
284
+ raise RuntimeError("tokenizer has neither pad_token_id nor eos_token_id; cannot define baseline token.")
285
+ tokenizer.pad_token = tokenizer.eos_token
286
+ return int(tokenizer.pad_token_id)
287
+
288
+
289
+ def _find_subsequence_start(haystack: torch.Tensor, needle: torch.Tensor) -> Optional[int]:
290
+ if haystack.ndim != 1 or needle.ndim != 1:
291
+ raise ValueError("Expected 1D tensors for subsequence matching.")
292
+ if needle.numel() == 0:
293
+ return 0
294
+ hay_len = int(haystack.numel())
295
+ needle_len = int(needle.numel())
296
+ if needle_len > hay_len:
297
+ return None
298
+ for i in range(hay_len - needle_len + 1):
299
+ if torch.equal(haystack[i : i + needle_len], needle):
300
+ return i
301
+ return None
302
+
303
+
304
+ def decode_text_into_tokens(tokenizer: Any, text: str) -> List[str]:
305
+ encoding = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
306
+ offsets = list(encoding["offset_mapping"])
307
+ tokens: List[str] = []
308
+ for start, end in offsets:
309
+ tokens.append(text[start:end])
310
+ return tokens
311
+
312
+
313
+ def auc(arr: np.ndarray) -> float:
314
+ return float((arr.sum() - arr[0] / 2 - arr[-1] / 2) / (arr.shape[0] - 1))
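# A worked sketch of auc() above: it is the trapezoid rule with unit x-spacing,
# rescaled so the x-axis spans [0, 1] (values here are made up for illustration):
#   curve = np.array([1.0, 0.8, 0.5, 0.0])
#   # trapezoid area = (1.0 + 0.8)/2 + (0.8 + 0.5)/2 + (0.5 + 0.0)/2 = 1.8
#   # auc(curve)     = (2.3 - 1.0/2 - 0.0/2) / 3 = 0.6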
315
+
316
+
317
+ def mas_trace(
318
+ model: Any,
319
+ tokenizer: Any,
320
+ *,
321
+ attribution: torch.Tensor,
322
+ prompt: str,
323
+ generation: str,
324
+ user_prompt_indices: Optional[Sequence[int]] = None,
325
+ keep_prompt_token_indices: Optional[Sequence[int]] = None,
326
+ k: int = 20,
327
+ ) -> Dict[str, Any]:
328
+ """Return a token-level faithfulness trace (RISE/MAS/RISE+AP) plus per-token deltas."""
329
+
330
+ pad_token_id = _ensure_pad_token_id(tokenizer)
331
+
332
+ user_prompt = " " + prompt
333
+ formatted = format_prompt(tokenizer, user_prompt)
334
+ formatted_ids = tokenizer(formatted, return_tensors="pt", add_special_tokens=False).input_ids
335
+ user_ids = tokenizer(user_prompt, return_tensors="pt", add_special_tokens=False).input_ids
336
+
337
+ prompt_ids = formatted_ids.to(model.device)
338
+ prompt_ids_perturbed = prompt_ids.clone()
339
+ gen_ids = tokenizer(
340
+ generation + (tokenizer.eos_token or ""),
341
+ return_tensors="pt",
342
+ add_special_tokens=False,
343
+ ).input_ids.to(model.device)
344
+
345
+ attr_cpu = attribution.detach().cpu()
346
+ w = attr_cpu.sum(0)
347
+ P = int(w.numel())
348
+
349
+ if keep_prompt_token_indices is None:
350
+ keep = list(range(P))
351
+ else:
352
+ keep = []
353
+ seen: set[int] = set()
354
+ for raw in keep_prompt_token_indices:
355
+ try:
356
+ idx = int(raw)
357
+ except Exception:
358
+ continue
359
+ if 0 <= idx < P and idx not in seen:
360
+ keep.append(idx)
361
+ seen.add(idx)
362
+ keep.sort()
363
+
364
+ K = len(keep)
365
+ if K:
366
+ w_keep = w.index_select(0, torch.as_tensor(keep, dtype=torch.long))
367
+ sorted_local = torch.argsort(w_keep, descending=True)
368
+ sorted_attr_indices = torch.as_tensor([keep[int(i.item())] for i in sorted_local], dtype=torch.long)
369
+ attr_sum = float(w_keep.sum().item())
370
+ else:
371
+ sorted_attr_indices = torch.zeros((0,), dtype=torch.long)
372
+ attr_sum = 0.0
373
+
374
+ if int(user_ids.shape[1]) != P:
375
+ raise ValueError(
376
+ "Prompt-side attribution length does not match tokenized user prompt length: "
377
+ f"attr P={P}, user_prompt P={int(user_ids.shape[1])}."
378
+ )
379
+
380
+ prompt_positions: List[int]
381
+ if user_prompt_indices is not None:
382
+ prompt_positions = [int(x) for x in user_prompt_indices]
383
+ if len(prompt_positions) != P:
384
+ raise ValueError(
385
+ "user_prompt_indices length does not match prompt-side attribution length: "
386
+ f"indices P={len(prompt_positions)}, attr P={P}."
387
+ )
388
+ if P and max(prompt_positions) >= int(prompt_ids_perturbed.shape[1]):
389
+ raise ValueError("user_prompt_indices contains an out-of-bounds index for formatted prompt ids.")
390
+ else:
391
+ user_start = _find_subsequence_start(formatted_ids[0], user_ids[0])
392
+ if user_start is None:
393
+ raise RuntimeError("Failed to locate user prompt token span inside formatted chat prompt.")
394
+ prompt_positions = [int(user_start) + j for j in range(P)]
395
+
396
+ if K > 0:
397
+ steps = int(k) if k is not None else 0
398
+ if steps <= 0:
399
+ steps = 1
400
+ steps = min(steps, K)
401
+ else:
402
+ steps = 0
403
+
404
+ scores = np.zeros(steps + 1, dtype=np.float64)
405
+ density = np.zeros(steps + 1, dtype=np.float64)
406
+
407
+ scores[0] = score_prompt_ids_with_generation(model, prompt_ids=prompt_ids_perturbed, generation_ids=gen_ids)
408
+ density[0] = 1.0
409
+
410
+ if K == 0:
411
+ return {
412
+ "num_tokens": P,
413
+ "sorted_attr_indices": [],
414
+ "scores_raw": scores.tolist(),
415
+ "density": density.tolist(),
416
+ "normalized_model_response": scores.tolist(),
417
+ "alignment_penalty": np.zeros_like(scores).tolist(),
418
+ "corrected_scores": scores.tolist(),
419
+ "token_deltas_raw": np.zeros(P, dtype=np.float64).tolist(),
420
+ "attr_weights": np.zeros(P, dtype=np.float64).tolist(),
421
+ "metrics": {"RISE": 0.0, "MAS": 0.0, "RISE+AP": 0.0},
422
+ }
423
+
424
+ if attr_sum <= 0:
425
+ density = np.linspace(1.0, 0.0, steps + 1)
426
+
427
+ per_token_delta = np.zeros(P, dtype=np.float64)
428
+
429
+ base = K // steps
430
+ remainder = K % steps
431
+ start = 0
432
+ for step in range(steps):
433
+ size = base + (1 if step < remainder else 0)
434
+ group = sorted_attr_indices[start : start + size]
435
+ start += size
436
+
437
+ for idx_t in group:
438
+ idx = int(idx_t.item())
439
+ abs_pos = int(prompt_positions[idx])
440
+ prompt_ids_perturbed[0, abs_pos] = pad_token_id
441
+
442
+ scores[step + 1] = score_prompt_ids_with_generation(model, prompt_ids=prompt_ids_perturbed, generation_ids=gen_ids)
443
+ if attr_sum > 0:
444
+ dec = float(w.index_select(0, group).sum().item()) / attr_sum
445
+ density[step + 1] = density[step] - dec
446
+
447
+ delta = scores[step] - scores[step + 1]
448
+ for idx_t in group:
449
+ idx = int(idx_t.item())
450
+ per_token_delta[idx] = delta
451
+
452
+ min_normalized_pred = 1.0
453
+ normalized_model_response = scores.copy()
454
+ for i in range(len(scores)):
455
+ normalized_pred = (normalized_model_response[i] - scores[-1]) / (abs(scores[0] - scores[-1]))
456
+ normalized_pred = np.clip(normalized_pred, 0.0, 1.0)
457
+ min_normalized_pred = min(min_normalized_pred, float(normalized_pred))
458
+ normalized_model_response[i] = min_normalized_pred
459
+
460
+ alignment_penalty = np.abs(normalized_model_response - density)
461
+ corrected_scores = normalized_model_response + alignment_penalty
462
+ corrected_scores = corrected_scores.clip(0, 1)
463
+ corrected_scores = (corrected_scores - np.min(corrected_scores)) / (np.max(corrected_scores) - np.min(corrected_scores))
464
+ if np.isnan(corrected_scores).any():
465
+ corrected_scores = np.linspace(1, 0, len(scores))
466
+
467
+ rise = auc(normalized_model_response)
468
+ mas = auc(corrected_scores)
469
+ rise_ap = auc(normalized_model_response + alignment_penalty)
470
+
471
+ if attr_sum > 0:
472
+ attr_weights = np.zeros(P, dtype=np.float64)
473
+ for idx in keep:
474
+ attr_weights[idx] = float(w[idx].item()) / (attr_sum + 1e-12)
475
+ else:
476
+ attr_weights = np.zeros(P, dtype=np.float64)
477
+
478
+ return {
479
+ "num_tokens": P,
480
+ "sorted_attr_indices": [int(i.item()) for i in sorted_attr_indices],
481
+ "scores_raw": scores.tolist(),
482
+ "density": density.tolist(),
483
+ "normalized_model_response": normalized_model_response.tolist(),
484
+ "alignment_penalty": alignment_penalty.tolist(),
485
+ "corrected_scores": corrected_scores.tolist(),
486
+ "token_deltas_raw": per_token_delta.tolist(),
487
+ "attr_weights": attr_weights.tolist(),
488
+ "metrics": {"RISE": rise, "MAS": mas, "RISE+AP": rise_ap},
489
+ }
490
+
491
+
492
+ def compute_method_attribution(
493
+ method: str,
494
+ example: ds_utils.CachedExample,
495
+ model: Any,
496
+ tokenizer: Any,
497
+ *,
498
+ n_hops: int,
499
+ sink_span: Optional[Tuple[int, int]],
500
+ thinking_span: Optional[Tuple[int, int]],
501
+ chunk_tokens: int,
502
+ sink_chunk_tokens: int,
503
+ attnlrp_neg_handling: str,
504
+ attnlrp_norm_mode: str,
505
+ ) -> Tuple[str, Any, llm_attr.LLMAttributionResult]:
506
+ prompt = example.prompt
507
+ target = example.target
508
+
509
+ if method == "ifr":
510
+ if sink_span is None:
511
+ raise ValueError("IFR requires sink_span (use dataset sink_span or pass --sink_span).")
512
+ attributor = llm_attr.LLMIFRAttribution(model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens)
513
+ result = attributor.calculate_ifr_span(prompt, target=target, span=sink_span)
514
+ return "IFR (ifr_span)", attributor, result
515
+
516
+ if method == "ifr_all_positions_output_only":
517
+ if sink_span is None:
518
+ raise ValueError(
519
+ "ifr_all_positions_output_only requires sink_span (use dataset sink_span or pass --sink_span)."
520
+ )
521
+ attributor = llm_attr.LLMIFRAttribution(model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens)
522
+ result = attributor.calculate_ifr_for_all_positions_output_only(
523
+ prompt,
524
+ target=target,
525
+ sink_span=sink_span,
526
+ )
527
+ return "IFR (ifr_all_positions_output_only)", attributor, result
528
+
529
+ if method in ("ft", "ft_ifr"):
530
+ attributor = llm_attr.LLMIFRAttribution(model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens)
531
+ result = attributor.calculate_ifr_multi_hop(
532
+ prompt,
533
+ target=target,
534
+ sink_span=sink_span,
535
+ thinking_span=thinking_span,
536
+ n_hops=int(n_hops),
537
+ )
538
+ return "FT-IFR (ifr_multi_hop)", attributor, result
539
+
540
+ if method in ("ft_improve", "ft_ifr_improve"):
541
+ import ft_ifr_improve
542
+
543
+ attributor = ft_ifr_improve.LLMIFRAttributionImproved(
544
+ model,
545
+ tokenizer,
546
+ chunk_tokens=chunk_tokens,
547
+ sink_chunk_tokens=sink_chunk_tokens,
548
+ )
549
+ result = attributor.calculate_ifr_multi_hop_stop_words(
550
+ prompt,
551
+ target=target,
552
+ sink_span=sink_span,
553
+ thinking_span=thinking_span,
554
+ n_hops=int(n_hops),
555
+ )
556
+ return "FT-IFR (ifr_multi_hop_stop_words)", attributor, result
557
+
558
+ if method == "ft_split_hop":
559
+ import ft_ifr_improve
560
+
561
+ attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(
562
+ model,
563
+ tokenizer,
564
+ chunk_tokens=chunk_tokens,
565
+ sink_chunk_tokens=sink_chunk_tokens,
566
+ )
567
+ result = attributor.calculate_ifr_multi_hop_split_hop(
568
+ prompt,
569
+ target=target,
570
+ sink_span=sink_span,
571
+ thinking_span=thinking_span,
572
+ n_hops=int(n_hops),
573
+ )
574
+ return "FT-IFR (ifr_multi_hop_split_hop)", attributor, result
575
+
576
+ if method == "attnlrp":
577
+ attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
578
+ result = attributor.calculate_attnlrp_ft_hop0(
579
+ prompt,
580
+ target=target,
581
+ sink_span=sink_span,
582
+ thinking_span=thinking_span,
583
+ neg_handling=attnlrp_neg_handling,
584
+ norm_mode=attnlrp_norm_mode,
585
+ )
586
+ return "AttnLRP (ft_attnlrp hop0)", attributor, result
587
+
588
+ if method == "ft_attnlrp":
589
+ attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
590
+ result = attributor.calculate_attnlrp_aggregated_multi_hop(
591
+ prompt,
592
+ target=target,
593
+ sink_span=sink_span,
594
+ thinking_span=thinking_span,
595
+ n_hops=int(n_hops),
596
+ neg_handling=attnlrp_neg_handling,
597
+ norm_mode=attnlrp_norm_mode,
598
+ )
599
+ return "FT-AttnLRP (attnlrp_aggregated_multi_hop)", attributor, result
600
+
601
+ raise ValueError(f"Unsupported method={method!r}")
602
+
603
+
604
+ def parse_args() -> argparse.Namespace:
605
+ parser = argparse.ArgumentParser("MAS case study (faithfulness perturbation visualization)")
606
+ parser.add_argument("--dataset", type=str, default="exp/exp2/data/morehopqa.jsonl", help="Dataset name or JSONL path.")
607
+ parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Cache root for dataset names.")
608
+ parser.add_argument("--index", type=int, default=0, help="Sample index (supports negative for reverse).")
609
+ parser.add_argument(
610
+ "--method",
611
+ type=str,
612
+ choices=[
613
+ "ifr",
614
+ "ifr_all_positions_output_only",
615
+ "ft",
616
+ "ft_ifr",
617
+ "ft_improve",
618
+ "ft_ifr_improve",
619
+ "ft_split_hop",
620
+ "attnlrp",
621
+ "ft_attnlrp",
622
+ ],
623
+ default="ft",
624
+ )
625
+ parser.add_argument("--model", type=str, default="qwen-8B", help="HF repo id (ignored if --model_path set).")
626
+ parser.add_argument("--model_path", type=str, default=None, help="Local model path to override --model.")
627
+ parser.add_argument("--cuda", type=str, default=None, help="CUDA spec (e.g., '0' or '0,1').")
628
+ parser.add_argument("--cuda_num", type=int, default=0, help="Fallback GPU index when --cuda unset.")
629
+ parser.add_argument("--n_hops", type=int, default=1, help="Number of hops for multi-hop methods.")
630
+ parser.add_argument("--sink_span", type=int, nargs=2, default=None, help="Optional sink span over generation tokens.")
631
+ parser.add_argument("--thinking_span", type=int, nargs=2, default=None, help="Optional thinking span over generation tokens.")
632
+ parser.add_argument("--chunk_tokens", type=int, default=128, help="IFR chunk size.")
633
+ parser.add_argument("--sink_chunk_tokens", type=int, default=32, help="IFR sink chunk size.")
634
+ parser.add_argument(
635
+ "--attnlrp_neg_handling",
636
+ type=str,
637
+ choices=["drop", "abs"],
638
+ default="drop",
639
+ help="FT-AttnLRP: how to handle negative values after each hop (drop=clamp>=0, abs=absolute value).",
640
+ )
641
+ parser.add_argument(
642
+ "--attnlrp_norm_mode",
643
+ type=str,
644
+ choices=["norm", "no_norm"],
645
+ default="norm",
646
+ help="FT-AttnLRP: norm enables per-hop global+thinking normalization + ratios; no_norm disables all three.",
647
+ )
648
+ parser.add_argument("--output_dir", type=str, default="exp/case_study/out", help="Where to write HTML/JSON artifacts.")
649
+ return parser.parse_args()
650
+
651
+
652
+ def main() -> None:
653
+ args = parse_args()
654
+ device = resolve_device(args.cuda, args.cuda_num)
655
+ if torch.cuda.is_available():
656
+ visible = os.environ.get("CUDA_VISIBLE_DEVICES")
657
+ print(f"[info] CUDA_VISIBLE_DEVICES={visible!r} torch.cuda.device_count()={torch.cuda.device_count()} device={device}")
658
+
659
+ if args.method == "ft_ifr":
660
+ method_key = "ft"
661
+ elif args.method == "ft_ifr_improve":
662
+ method_key = "ft_improve"
663
+ else:
664
+ method_key = args.method
665
+
666
+ model_name = args.model_path if args.model_path is not None else args.model
667
+ model, tokenizer = load_model(model_name, device)
668
+
669
+ example, ds_name = load_example(args.dataset, args.index, Path(args.data_root))
670
+
671
+ sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
672
+ thinking_span = (
673
+ tuple(args.thinking_span)
674
+ if args.thinking_span is not None
675
+ else tuple(example.thinking_span) if example.thinking_span else None
676
+ )
677
+
678
+ method_label, attributor, attr_result = compute_method_attribution(
679
+ method_key,
680
+ example,
681
+ model,
682
+ tokenizer,
683
+ n_hops=args.n_hops,
684
+ sink_span=sink_span,
685
+ thinking_span=thinking_span,
686
+ chunk_tokens=args.chunk_tokens,
687
+ sink_chunk_tokens=args.sink_chunk_tokens,
688
+ attnlrp_neg_handling=args.attnlrp_neg_handling,
689
+ attnlrp_norm_mode=args.attnlrp_norm_mode,
690
+ )
691
+
692
+ indices_to_explain = example.indices_to_explain or example.sink_span
693
+ if not (isinstance(indices_to_explain, list) and len(indices_to_explain) == 2):
694
+ raise ValueError("MAS case study requires token-span indices_to_explain=[start_tok,end_tok] (e.g. sink_span).")
695
+ seq_attr, row_attr, rec_attr = attr_result.get_all_token_attrs(indices_to_explain)
696
+
697
+ prompt_tokens = decode_text_into_tokens(tokenizer, " " + example.prompt)
698
+ generation_text = example.target if example.target is not None else (getattr(attributor, "generation", None) or "")
699
+
700
+ variant_specs = [
701
+ ("seq", "Seq attribution", seq_attr),
702
+ ("row", "Row attribution", row_attr),
703
+ ("recursive", "Recursive attribution", rec_attr),
704
+ ]
705
+
706
+ formatted = format_prompt(tokenizer, " " + example.prompt)
707
+ prompt_ids = tokenizer(formatted, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
708
+ gen_ids = tokenizer(
709
+ generation_text + (tokenizer.eos_token or ""),
710
+ return_tensors="pt",
711
+ add_special_tokens=False,
712
+ ).input_ids.to(model.device)
713
+ base_score = score_prompt_ids_with_generation(model, prompt_ids=prompt_ids, generation_ids=gen_ids)
714
+
715
+ panels_raw: List[Dict[str, Any]] = []
716
+ panels_display: List[Dict[str, Any]] = []
717
+
718
+ for variant_key, variant_label, variant_attr in variant_specs:
719
+ prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0]) # cols=(P+G), rows=G
720
+ attr_prompt = variant_attr[:, :prompt_len]
721
+ keep_prompt_token_indices = None
722
+ if method_key == "ft_improve":
723
+ import ft_ifr_improve
724
+
725
+ keep_prompt_token_indices = ft_ifr_improve.keep_token_indices(list(getattr(attributor, "user_prompt_tokens", []) or []))
726
+ trace = mas_trace(
727
+ model,
728
+ tokenizer,
729
+ attribution=attr_prompt.to(device="cpu"),
730
+ prompt=example.prompt,
731
+ generation=generation_text,
732
+ user_prompt_indices=getattr(attributor, "user_prompt_indices", None),
733
+ keep_prompt_token_indices=keep_prompt_token_indices,
734
+ )
735
+ trace["variant"] = variant_key
736
+ trace["variant_label"] = variant_label
737
+
738
+ panel_raw = {
739
+ "variant": variant_key,
740
+ "variant_label": variant_label,
741
+ "metrics": trace.get("metrics"),
742
+ "sorted_attr_indices": trace.get("sorted_attr_indices"),
743
+ "attr_weights": trace.get("attr_weights"),
744
+ "token_deltas_raw": trace.get("token_deltas_raw"),
745
+ "mas_trace": trace,
746
+ }
747
+ panels_raw.append(panel_raw)
748
+
749
+ panel_display = {
750
+ "variant": variant_key,
751
+ "variant_label": variant_label,
752
+ "metrics": trace.get("metrics"),
753
+ "sorted_attr_indices": trace.get("sorted_attr_indices"),
754
+ "attr_weights": trace.get("attr_weights"),
755
+ "token_deltas_raw": trace.get("token_deltas_raw"),
756
+ }
757
+ panels_display.append(panel_display)
758
+
759
+ case_meta: Dict[str, Any] = {
760
+ "dataset": ds_name,
761
+ "index": args.index,
762
+ "mode": "mas",
763
+ "attr_method": method_key,
764
+ "attr_method_label": method_label,
765
+ "sink_span": sink_span,
766
+ "thinking_span": thinking_span,
767
+ "n_hops": int(args.n_hops),
768
+ "attnlrp_neg_handling": args.attnlrp_neg_handling if method_key in ("attnlrp", "ft_attnlrp") else None,
769
+ "attnlrp_norm_mode": args.attnlrp_norm_mode if method_key in ("attnlrp", "ft_attnlrp") else None,
770
+ "attnlrp_ratio_enabled": (args.attnlrp_norm_mode == "norm") if method_key in ("attnlrp", "ft_attnlrp") else None,
771
+ "base_score": float(base_score),
772
+ }
773
+
774
+ record = {
775
+ "meta": case_meta,
776
+ "prompt": example.prompt,
777
+ "target": example.target,
778
+ "generation": generation_text,
779
+ "prompt_tokens": prompt_tokens,
780
+ "panels": panels_raw,
781
+ }
782
+
783
+ out_dir = Path(args.output_dir)
784
+ out_dir.mkdir(parents=True, exist_ok=True)
785
+ stem = make_output_stem(ds_name, args.index, method_key)
786
+ json_path = out_dir / f"{stem}.json"
787
+ html_path = out_dir / f"{stem}.html"
788
+
789
+ with json_path.open("w", encoding="utf-8") as f:
790
+ json.dump(record, f, ensure_ascii=False, indent=2)
791
+
792
+ html = viz.render_mas_token_html(
793
+ case_meta,
794
+ prompt_tokens=prompt_tokens,
795
+ panels=panels_display,
796
+ generation=generation_text,
797
+ )
798
+ html_path.write_text(html, encoding="utf-8")
799
+
800
+ print(f"[done] wrote {json_path}")
801
+ print(f"[done] wrote {html_path}")
802
+
803
+
804
+ if __name__ == "__main__":
805
+ main()
exp/case_study/viz.py ADDED
@@ -0,0 +1,647 @@
1
+ """HTML helpers for visualizing hop-wise IFR/AttnLRP attributions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from typing import Any, Dict, List, Optional, Sequence
7
+
8
+ from html import escape
9
+
10
+
11
+ TOKEN_SCALE_QUANTILE = 0.995
12
+
13
+
14
+ def _robust_abs_max(scores: Sequence[float], *, quantile: float = TOKEN_SCALE_QUANTILE) -> float:
15
+ """Return a robust abs max to avoid a single outlier washing out the colormap.
16
+
17
+ Uses a high quantile (default: p99.5) over |scores|. Top outliers saturate.
18
+ """
19
+
20
+ abs_vals: List[float] = []
21
+ for x in scores:
22
+ try:
23
+ v = float(x)
24
+ except Exception:
25
+ continue
26
+ if math.isnan(v):
27
+ continue
28
+ abs_vals.append(abs(v))
29
+
30
+ if not abs_vals:
31
+ return 0.0
32
+
33
+ abs_vals.sort()
34
+ q = float(quantile)
35
+ if q < 0.0:
36
+ q = 0.0
37
+ if q > 1.0:
38
+ q = 1.0
39
+ idx = int(q * (len(abs_vals) - 1))
40
+ return float(abs_vals[idx])
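# A small usage sketch: with the default quantile (0.995) a single extreme outlier
# does not dictate the scale once the list is long enough; it merely saturates:
#   _robust_abs_max([0.1] * 999 + [50.0])  # -> 0.1 (index 994 of the sorted |values|)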
41
+
42
+
43
+ def _color_for_score(score: float, max_score: float) -> str:
44
+ if max_score <= 0:
45
+ return "background-color: rgba(245,245,245,0.7);"
46
+ ratio = min(1.0, score / (max_score + 1e-12))
47
+ r = 255
48
+ g = int(235 - 90 * ratio)
49
+ b = int(220 - 160 * ratio)
50
+ alpha = 0.25 + 0.55 * ratio
51
+ return f"background-color: rgba({r}, {g}, {b}, {alpha});"
52
+
53
+
54
+ def _render_sentence_list(title: str, sentences: Sequence[str], scores: Sequence[float], max_score: float) -> str:
55
+ rows: List[str] = []
56
+ for sent, sc in zip(sentences, scores):
57
+ style = _color_for_score(abs(float(sc)), max_score)
58
+ rows.append(
59
+ f'<div class="sent-row" style="{style}"><span class="score">{sc:.4f}</span>'
60
+ f'<span class="text">{escape(sent)}</span></div>'
61
+ )
62
+ return f"""
63
+ <div class="sent-block">
64
+ <div class="sent-title">{escape(title)}</div>
65
+ {''.join(rows)}
66
+ </div>
67
+ """
68
+
69
+
70
+ def _render_tokens(
71
+ tokens: Sequence[str],
72
+ scores: Sequence[float],
73
+ max_score: float,
74
+ roles: Sequence[str],
75
+ ) -> str:
76
+ spans: List[str] = []
77
+ if max_score <= 0:
78
+ max_score = 1e-8
79
+ for idx, tok in enumerate(tokens):
80
+ score = float(scores[idx]) if idx < len(scores) else 0.0
81
+ style = _color_for_score(abs(score), max_score)
82
+ role = roles[idx] if idx < len(roles) else "gen"
83
+ safe_tok = escape(tok)
84
+ spans.append(
85
+ f'<span class="tok {role}" title="idx={idx}, score={score:.6f}" style="{style}">{safe_tok}</span>'
86
+ )
87
+ return "".join(spans)
88
+
89
+
90
+ def _render_top_table(top_items: List[Dict[str, Any]]) -> str:
91
+ if not top_items:
92
+ return "<div class='top-table'><em>No attribution mass.</em></div>"
93
+
94
+ header = "<div class='top-row top-header'><span>Rank</span><span>Idx</span><span>Score</span><span>Sentence</span></div>"
95
+ body_rows = []
96
+ for rank, item in enumerate(top_items, start=1):
97
+ body_rows.append(
98
+ f"<div class='top-row'><span>{rank}</span><span>{item['idx']}</span>"
99
+ f"<span>{item['score']:.4f}</span><span>{escape(item['sentence'])}</span></div>"
100
+ )
101
+ return f"<div class='top-table'>{header}{''.join(body_rows)}</div>"
102
+
103
+
104
+ def render_case_html(
105
+ case_meta: Dict[str, Any],
106
+ *,
107
+ token_view_raw: Dict[str, Any],
108
+ token_view_prompt: Dict[str, Any],
109
+ context: Optional[Dict[str, Any]] = None,
110
+ hops_sent: Optional[Sequence[Dict[str, Any]]] = None,
111
+ ) -> str:
112
+ has_sentence_view = bool(context) and bool(hops_sent)
113
+ prompt_len = len((context or {}).get("prompt_sentences") or []) if has_sentence_view else 0
114
+ gen_len = len((context or {}).get("generation_sentences") or []) if has_sentence_view else 0
115
+
116
+ prompt_max = 0.0
117
+ gen_max = 0.0
118
+ if has_sentence_view:
119
+ prompt_max = max(
120
+ (
121
+ max(h["sentence_scores_raw"][:prompt_len])
122
+ for h in (hops_sent or [])
123
+ if h.get("sentence_scores_raw") and h["sentence_scores_raw"][:prompt_len]
124
+ ),
125
+ default=0.0,
126
+ )
127
+ gen_max = max(
128
+ (
129
+ max(h["sentence_scores_raw"][prompt_len:])
130
+ for h in (hops_sent or [])
131
+ if h.get("sentence_scores_raw") and h["sentence_scores_raw"][prompt_len:]
132
+ ),
133
+ default=0.0,
134
+ )
135
+
136
+ raw_hops = token_view_raw.get("hops", []) or []
137
+ prompt_hops = token_view_prompt.get("hops", []) or []
138
+ if len(raw_hops) != len(prompt_hops):
139
+ raise ValueError(
140
+ "token_view_raw and token_view_prompt must have the same number of panels: "
141
+ f"raw={len(raw_hops)} prompt={len(prompt_hops)}"
142
+ )
143
+
144
+ hop_sections: List[str] = []
145
+ hop_count = len(prompt_hops)
146
+ mode = case_meta.get("mode", "ft")
147
+ ifr_view = case_meta.get("ifr_view", "aggregate")
148
+ sink_span = case_meta.get("sink_span")
149
+ panel_titles = case_meta.get("panel_titles")
150
+
151
+ def _panel_title(panel_idx: int) -> str:
152
+ if isinstance(panel_titles, list) and panel_idx < len(panel_titles):
153
+ try:
154
+ title = panel_titles[panel_idx]
155
+ except Exception:
156
+ title = None
157
+ if title is not None:
158
+ return str(title)
159
+ if mode in ("ft", "ft_improve", "ft_split_hop", "ifr_in_all_gen", "ft_attnlrp"):
160
+ return f"Hop {panel_idx}"
161
+ if mode == "ifr_all_positions_output_only":
162
+ return f"IFR output-only panel {panel_idx}"
163
+ if mode == "ifr_all_positions":
164
+ return f"IFR all-positions panel {panel_idx}"
165
+ if mode == "attnlrp":
166
+ return "AttnLRP (sink-span aggregate)"
167
+ return "IFR (sink-span aggregate)"
168
+
169
+ for hop_idx in range(hop_count):
170
+ raw_entry = raw_hops[hop_idx]
171
+ raw_scores = raw_entry.get("token_scores") or []
172
+ raw_mass = float(raw_entry.get("total_mass", 0.0))
173
+ raw_scale = _robust_abs_max(raw_scores)
174
+ if raw_scale <= 0:
175
+ raw_scale = float(raw_entry.get("token_score_max") or 0.0)
176
+ if raw_scale <= 0:
177
+ raw_scale = 1e-8
178
+
179
+ prompt_entry = prompt_hops[hop_idx]
180
+ prompt_scores = prompt_entry.get("token_scores") or []
181
+ prompt_mass = float(prompt_entry.get("total_mass", 0.0))
182
+ prompt_scale = _robust_abs_max(prompt_scores)
183
+ if prompt_scale <= 0:
184
+ prompt_scale = float(prompt_entry.get("token_score_max") or 0.0)
185
+ if prompt_scale <= 0:
186
+ prompt_scale = 1e-8
187
+
188
+ tok_raw_html = f"""
189
+ <div class="tokens-block">
190
+ <div class="tokens-title">{escape(token_view_raw.get("label", "Pre-trim token-level heatmap (full)"))}</div>
191
+ <div class="tokens-row">
192
+ {_render_tokens(token_view_raw.get("tokens", []), raw_scores, raw_scale, token_view_raw.get("roles", []))}
193
+ </div>
194
+ </div>
195
+ """
196
+
197
+ tok_prompt_html = f"""
198
+ <div class="tokens-block">
199
+ <div class="tokens-title">{escape(token_view_prompt.get("label", "Prompt-only token-level heatmap"))}</div>
200
+ <div class="tokens-row">
201
+ {_render_tokens(token_view_prompt.get("tokens", []), prompt_scores, prompt_scale, token_view_prompt.get("roles", []))}
202
+ </div>
203
+ </div>
204
+ """
205
+
206
+ sentence_html = ""
207
+ top_html = ""
208
+ if has_sentence_view and hop_idx < len(hops_sent or []):
209
+ hop = (hops_sent or [])[hop_idx]
210
+ raw_scores = hop.get("sentence_scores_raw") or []
211
+ prompt_scores = raw_scores[:prompt_len]
212
+ gen_scores = raw_scores[prompt_len:]
213
+ # Sentence view is not used by the current case-study runner; keep the path for completeness.
214
+ sentence_html = f"""
215
+ <div class="columns">
216
+ {_render_sentence_list('Prompt sentences', (context or {}).get('prompt_sentences') or [], prompt_scores, prompt_max)}
217
+ {_render_sentence_list('Generation sentences', (context or {}).get('generation_sentences') or [], gen_scores, gen_max)}
218
+ </div>
219
+ """
220
+ top_html = f"""
221
+ <div class="top-wrap">
222
+ <div class="section-label">Top sentences (all)</div>
223
+ {_render_top_table(hop.get('top_sentences') or [])}
224
+ </div>
225
+ """
226
+
227
+ hop_sections.append(
228
+ f"""
229
+ <div class="hop">
230
+ <div class="hop-header">
231
+ <div class="hop-title">{escape(_panel_title(hop_idx))}</div>
232
+ <div class="hop-meta">
233
+ raw mass: {raw_mass:.6f} | raw scale(p{int(TOKEN_SCALE_QUANTILE*1000)/10:.1f} abs): {raw_scale:.6g}
234
+ &nbsp;|&nbsp;
235
+ prompt mass: {prompt_mass:.6f} | prompt scale(p{int(TOKEN_SCALE_QUANTILE*1000)/10:.1f} abs): {prompt_scale:.6g}
236
+ </div>
237
+ </div>
238
+ {tok_raw_html}
239
+ {tok_prompt_html}
240
+ {sentence_html}
241
+ {top_html}
242
+ </div>
243
+ """
244
+ )
245
+
246
+ thinking_ratios = case_meta.get("thinking_ratios") or []
247
+ ratios_str = ", ".join(f"{r:.4f}" for r in thinking_ratios) if thinking_ratios else "N/A"
248
+
249
+ if mode == "ft":
250
+ mode_label = "FT Multi-hop (IFR)"
251
+ elif mode == "ifr_in_all_gen":
252
+ mode_label = "IFR In-all-gen (multi-hop)"
253
+ elif mode == "ifr":
254
+ mode_label = "IFR Standard"
255
+ elif mode == "ifr_all_positions":
256
+ mode_label = "IFR All-positions"
257
+ elif mode == "ifr_all_positions_output_only":
258
+ mode_label = "IFR Output-only (all positions)"
259
+ elif mode == "attnlrp":
260
+ mode_label = "AttnLRP"
261
+ elif mode == "ft_attnlrp":
262
+ mode_label = "FT Multi-hop (AttnLRP)"
263
+ else:
264
+ mode_label = str(mode)
265
+
266
+ if mode in ("ft", "ifr_in_all_gen", "ft_attnlrp"):
267
+ view_key = "Recursive hops"
268
+ view_val = case_meta.get("n_hops")
269
+ elif mode in ("ifr", "ifr_all_positions", "ifr_all_positions_output_only"):
270
+ view_key = "IFR view"
271
+ view_val = ifr_view
272
+ elif mode == "attnlrp":
273
+ view_key = "AttnLRP view"
274
+ view_val = "ft_hop0_span_aggregate"
275
+ else:
276
+ view_key = "View"
277
+ view_val = "N/A"
278
+
279
+ scale_row = f"<div>Token scale: per-panel per-view p{int(TOKEN_SCALE_QUANTILE*1000)/10:.1f}(|score|)</div>"
280
+ neg_handling = case_meta.get("attnlrp_neg_handling")
281
+ norm_mode = case_meta.get("attnlrp_norm_mode")
282
+ ratio_enabled = case_meta.get("attnlrp_ratio_enabled")
283
+ attn_rows = []
284
+ if neg_handling:
285
+ attn_rows.append(f"<div>FT-AttnLRP neg_handling: {escape(str(neg_handling))}</div>")
286
+ if norm_mode:
287
+ attn_rows.append(f"<div>FT-AttnLRP norm_mode: {escape(str(norm_mode))}</div>")
288
+ if ratio_enabled is not None:
289
+ attn_rows.append(f"<div>FT-AttnLRP ratio_enabled: {escape(str(bool(ratio_enabled)))}</div>")
290
+
291
+ header = f"""
292
+ <div class="header">
293
+ <div>
294
+ <div class="title">{escape(mode_label)} Case Study</div>
295
+ <div class="subtitle">Dataset: {escape(str(case_meta.get('dataset')))} | index: {case_meta.get('index')}</div>
296
+ </div>
297
+ <div class="meta">
298
+ <div>Sink span (gen idx): {escape(str(case_meta.get('sink_span')))}</div>
299
+ <div>Thinking span (gen idx): {escape(str(case_meta.get('thinking_span')))}</div>
300
+ <div>Panels: {hop_count}</div>
301
+ <div>{escape(str(view_key))}: {escape(str(view_val))}</div>
302
+ {scale_row}
303
+ {''.join(attn_rows)}
304
+ <div>Thinking ratios: {ratios_str}</div>
305
+ </div>
306
+ </div>
307
+ """
308
+
309
+ style = """
310
+ <style>
311
+ body { font-family: "Inter", "Helvetica Neue", Arial, sans-serif; margin: 0; padding: 24px; background: #fcfcff; color: #1f2933; }
312
+ .title { font-size: 24px; font-weight: 700; }
313
+ .subtitle { font-size: 14px; color: #566; margin-top: 4px; }
314
+ .header { display: flex; justify-content: space-between; align-items: flex-start; gap: 16px; padding-bottom: 16px; border-bottom: 1px solid #e5e8ee; }
315
+ .meta { font-size: 13px; color: #334; line-height: 1.6; }
316
+ .hop { margin-top: 20px; padding: 16px; border: 1px solid #e5e8ee; border-radius: 10px; background: #fff; box-shadow: 0 2px 6px rgba(0,0,0,0.04); }
317
+ .hop-header { display: flex; justify-content: space-between; align-items: center; }
318
+ .hop-title { font-weight: 600; font-size: 16px; }
319
+ .hop-meta { font-size: 12px; color: #556; }
320
+ .tokens-block { margin-top: 12px; border: 1px solid #eef1f6; border-radius: 8px; padding: 10px; background: #f9fbff; }
321
+ .tokens-title { font-size: 13px; font-weight: 600; margin-bottom: 8px; color: #263; }
322
+ .tokens-row { font-family: "SFMono-Regular", Consolas, monospace; font-size: 12px; line-height: 1.8; word-break: break-word; }
323
+ .tok { display: inline; padding: 2px 1px; margin: 0 0px; border-radius: 3px; }
324
+ .tok.prompt { border-bottom: 1px dashed #6b8fb8; }
325
+ .tok.user { border-bottom: 1px dashed #4f72c7; }
326
+ .tok.template { border-bottom: 1px dashed #9aa9c0; }
327
+ .tok.think { border-bottom: 1px dashed #8ba86b; }
328
+ .tok.output { border-bottom: 1px dashed #c78a6e; }
329
+ .tok.gen { border-bottom: 1px dashed #999; }
330
+ .tok:hover { outline: 1px solid #8899aa; }
331
+ .columns { display: grid; grid-template-columns: repeat(auto-fit, minmax(260px, 1fr)); gap: 12px; margin-top: 12px; }
332
+ .sent-block { padding: 8px; border: 1px solid #eef1f6; border-radius: 8px; background: #f9fbff; }
333
+ .sent-title { font-weight: 600; font-size: 13px; margin-bottom: 6px; color: #263; }
334
+ .sent-row { padding: 6px 8px; border-radius: 6px; margin-bottom: 6px; display: flex; gap: 8px; align-items: flex-start; }
335
+ .sent-row:last-child { margin-bottom: 0; }
336
+ .sent-row .score { font-family: "SFMono-Regular", Consolas, monospace; font-size: 12px; color: #233; min-width: 60px; }
337
+ .sent-row .text { flex: 1; font-size: 13px; }
338
+ .top-wrap { margin-top: 10px; }
339
+ .section-label { font-size: 13px; font-weight: 600; margin-bottom: 6px; color: #263; }
340
+ .top-table { border: 1px solid #eef1f6; border-radius: 8px; background: #fff; }
341
+ .top-row { display: grid; grid-template-columns: 50px 50px 80px 1fr; padding: 6px 8px; gap: 8px; font-size: 12px; }
342
+ .top-header { background: #f3f6fb; font-weight: 700; color: #223; }
343
+ .top-row:nth-child(odd):not(.top-header) { background: #fbfdff; }
344
+ </style>
345
+ """
346
+
347
+ title = f"{mode_label} Case Study"
348
+ html = f"""<!DOCTYPE html>
349
+ <html>
350
+ <head>
351
+ <meta charset="utf-8" />
352
+ <title>{escape(title)}</title>
353
+ {style}
354
+ </head>
355
+ <body>
356
+ {header}
357
+ {''.join(hop_sections)}
358
+ </body>
359
+ </html>"""
360
+ return html
361
+
362
+
363
+ def _render_sentence_spans(title: str, sentences: Sequence[str], scores: Sequence[float]) -> str:
364
+ max_abs = max((abs(float(x)) for x in scores), default=0.0)
365
+ spans: List[str] = []
366
+ for idx, sentence in enumerate(sentences):
367
+ score = float(scores[idx]) if idx < len(scores) else 0.0
368
+ style = _color_for_score(abs(score), max_abs)
369
+ spans.append(
370
+ f'<span class="sent-span" title="idx={idx}, score={score:.6f}" style="{style}">{escape(sentence)}</span>'
371
+ )
372
+ return f"""
373
+ <div class="sentmap">
374
+ <div class="sentmap-title">{escape(title)}</div>
375
+ <div class="sentmap-text">{''.join(spans)}</div>
376
+ </div>
377
+ """
378
+
379
+
380
+ def _render_token_spans(title: str, tokens: Sequence[str], scores: Sequence[float]) -> str:
381
+ max_abs = max((abs(float(x)) for x in scores), default=0.0)
382
+ spans: List[str] = []
383
+ for idx, tok in enumerate(tokens):
384
+ score = float(scores[idx]) if idx < len(scores) else 0.0
385
+ style = _color_for_score(abs(score), max_abs)
386
+ spans.append(
387
+ f'<span class="tok-span" title="idx={idx}, score={score:.6f}" style="{style}">{escape(tok)}</span>'
388
+ )
389
+ return f"""
390
+ <div class="tokmap">
391
+ <div class="tokmap-title">{escape(title)}</div>
392
+ <div class="tokmap-text">{''.join(spans)}</div>
393
+ </div>
394
+ """
395
+
396
+
397
+ def render_mas_sentence_html(
398
+ case_meta: Dict[str, Any],
399
+ *,
400
+ prompt_sentences: Sequence[str],
401
+ panels: Sequence[Dict[str, Any]],
402
+ generation: Optional[str] = None,
403
+ ) -> str:
404
+ """Render MAS sentence-level diagnostics (attribution / pure ablation / guided marginal)."""
405
+
406
+ method_label = case_meta.get("attr_method_label") or case_meta.get("attr_method") or "Unknown method"
407
+ title = f"MAS Sentence Study ({method_label})"
408
+
409
+ neg_handling = case_meta.get("attnlrp_neg_handling")
410
+ norm_mode = case_meta.get("attnlrp_norm_mode")
411
+ ratio_enabled = case_meta.get("attnlrp_ratio_enabled")
412
+ attn_rows = []
413
+ if neg_handling:
414
+ attn_rows.append(f"<div>FT-AttnLRP neg_handling: {escape(str(neg_handling))}</div>")
415
+ if norm_mode:
416
+ attn_rows.append(f"<div>FT-AttnLRP norm_mode: {escape(str(norm_mode))}</div>")
417
+ if ratio_enabled is not None:
418
+ attn_rows.append(f"<div>FT-AttnLRP ratio_enabled: {escape(str(bool(ratio_enabled)))}</div>")
419
+
420
+ base_score = case_meta.get("base_score")
421
+ base_score_row = f"<div>Base score: {float(base_score):.6f}</div>" if isinstance(base_score, (int, float)) else ""
422
+
423
+ gen_block = ""
424
+ if isinstance(generation, str) and generation:
425
+ gen_block = f"""
426
+ <div class="text-block">
427
+ <div class="text-title">Generation (scored)</div>
428
+ <div class="text-body">{escape(generation)}</div>
429
+ </div>
430
+ """
431
+
432
+ header = f"""
433
+ <div class="header">
434
+ <div>
435
+ <div class="title">{escape(title)}</div>
436
+ <div class="subtitle">Dataset: {escape(str(case_meta.get('dataset')))} | index: {case_meta.get('index')}</div>
437
+ </div>
438
+ <div class="meta">
439
+ <div>Attribution method: {escape(str(case_meta.get('attr_method')))}</div>
440
+ <div>Sink span (gen idx): {escape(str(case_meta.get('sink_span')))}</div>
441
+ <div>Thinking span (gen idx): {escape(str(case_meta.get('thinking_span')))}</div>
442
+ <div>Panels: {len(panels)}</div>
443
+ {''.join(attn_rows)}
444
+ {base_score_row}
445
+ </div>
446
+ </div>
447
+ """
448
+
449
+ panel_sections: List[str] = []
450
+ for panel in panels:
451
+ label = panel.get("variant_label") or panel.get("panel_label") or panel.get("variant") or "Panel"
452
+ metrics = panel.get("metrics") or {}
453
+ metrics_str = " | ".join(
454
+ f"{k}: {float(metrics[k]):.4f}" if isinstance(metrics.get(k), (int, float)) else f"{k}: {metrics.get(k)}"
455
+ for k in ("RISE", "MAS", "RISE+AP")
456
+ if k in metrics
457
+ )
458
+
459
+ attr_weights = panel.get("attr_weights") or []
460
+ pure_deltas = panel.get("pure_sentence_deltas_raw") or []
461
+ guided_deltas = panel.get("guided_sentence_deltas_raw") or panel.get("sentence_deltas_raw") or []
462
+ rank_order = panel.get("sorted_attr_indices") or []
463
+ rank_str = ", ".join(str(int(x)) for x in rank_order) if rank_order else "N/A"
464
+
465
+ panel_sections.append(
466
+ f"""
467
+ <div class="panel">
468
+ <div class="panel-header">
469
+ <div class="panel-title">{escape(str(label))}</div>
470
+ <div class="panel-meta">{escape(metrics_str)}</div>
471
+ </div>
472
+
473
+ {_render_sentence_spans("Method attribution (sentence weights)", prompt_sentences, attr_weights)}
474
+ {_render_sentence_spans("Pure sentence ablation (base − score)", prompt_sentences, pure_deltas)}
475
+ {_render_sentence_spans("Attribution-guided MAS marginal (path deltas)", prompt_sentences, guided_deltas)}
476
+
477
+ <div class="panel-foot">Rank order: {escape(rank_str)}</div>
478
+ </div>
479
+ """
480
+ )
481
+
482
+ style = """
483
+ <style>
484
+ body { font-family: "Inter", "Helvetica Neue", Arial, sans-serif; margin: 0; padding: 24px; background: #fcfcff; color: #1f2933; }
485
+ .title { font-size: 24px; font-weight: 700; }
486
+ .subtitle { font-size: 14px; color: #566; margin-top: 4px; }
487
+ .header { display: flex; justify-content: space-between; align-items: flex-start; gap: 16px; padding-bottom: 16px; border-bottom: 1px solid #e5e8ee; }
488
+ .meta { font-size: 13px; color: #334; line-height: 1.6; }
489
+
490
+ .text-block { margin-top: 16px; border: 1px solid #eef1f6; border-radius: 10px; padding: 12px; background: #fff; }
491
+ .text-title { font-size: 13px; font-weight: 700; color: #263; margin-bottom: 8px; }
492
+ .text-body { font-size: 13px; line-height: 1.7; white-space: pre-wrap; word-break: break-word; }
493
+
494
+ .panel { margin-top: 18px; padding: 16px; border: 1px solid #e5e8ee; border-radius: 10px; background: #fff; box-shadow: 0 2px 6px rgba(0,0,0,0.04); }
495
+ .panel-header { display: flex; justify-content: space-between; align-items: center; }
496
+ .panel-title { font-weight: 600; font-size: 16px; }
497
+ .panel-meta { font-size: 12px; color: #556; }
498
+ .panel-foot { margin-top: 8px; font-size: 12px; color: #556; }
499
+
500
+ .sentmap { margin-top: 12px; border: 1px solid #eef1f6; border-radius: 8px; padding: 10px; background: #f9fbff; }
501
+ .sentmap-title { font-size: 13px; font-weight: 600; margin-bottom: 8px; color: #263; }
502
+ .sentmap-text { font-size: 13px; line-height: 1.8; white-space: pre-wrap; word-break: break-word; }
503
+ .sent-span { display: inline; padding: 2px 2px; margin: 0 0px; border-radius: 4px; }
504
+ .sent-span:hover { outline: 1px solid #8899aa; }
505
+ </style>
506
+ """
507
+
508
+ html = f"""<!DOCTYPE html>
509
+ <html>
510
+ <head>
511
+ <meta charset="utf-8" />
512
+ <title>{escape(title)}</title>
513
+ {style}
514
+ </head>
515
+ <body>
516
+ {header}
517
+ {gen_block}
518
+ {''.join(panel_sections)}
519
+ </body>
520
+ </html>"""
521
+ return html
522
+
523
+
524
+ def render_mas_token_html(
525
+ case_meta: Dict[str, Any],
526
+ *,
527
+ prompt_tokens: Sequence[str],
528
+ panels: Sequence[Dict[str, Any]],
529
+ generation: Optional[str] = None,
530
+ ) -> str:
531
+ """Render MAS token-level diagnostics (attribution weights + guided marginal deltas)."""
532
+
533
+ method_label = case_meta.get("attr_method_label") or case_meta.get("attr_method") or "Unknown method"
534
+ title = f"MAS Token Study ({method_label})"
535
+
536
+ neg_handling = case_meta.get("attnlrp_neg_handling")
537
+ norm_mode = case_meta.get("attnlrp_norm_mode")
538
+ ratio_enabled = case_meta.get("attnlrp_ratio_enabled")
539
+ attn_rows = []
540
+ if neg_handling:
541
+ attn_rows.append(f"<div>FT-AttnLRP neg_handling: {escape(str(neg_handling))}</div>")
542
+ if norm_mode:
543
+ attn_rows.append(f"<div>FT-AttnLRP norm_mode: {escape(str(norm_mode))}</div>")
544
+ if ratio_enabled is not None:
545
+ attn_rows.append(f"<div>FT-AttnLRP ratio_enabled: {escape(str(bool(ratio_enabled)))}</div>")
546
+
547
+ base_score = case_meta.get("base_score")
548
+ base_score_row = f"<div>Base score: {float(base_score):.6f}</div>" if isinstance(base_score, (int, float)) else ""
549
+
550
+ gen_block = ""
551
+ if isinstance(generation, str) and generation:
552
+ gen_block = f"""
553
+ <div class="text-block">
554
+ <div class="text-title">Generation (scored)</div>
555
+ <div class="text-body">{escape(generation)}</div>
556
+ </div>
557
+ """
558
+
559
+ header = f"""
560
+ <div class="header">
561
+ <div>
562
+ <div class="title">{escape(title)}</div>
563
+ <div class="subtitle">Dataset: {escape(str(case_meta.get('dataset')))} | index: {case_meta.get('index')}</div>
564
+ </div>
565
+ <div class="meta">
566
+ <div>Attribution method: {escape(str(case_meta.get('attr_method')))}</div>
567
+ <div>Sink span (gen idx): {escape(str(case_meta.get('sink_span')))}</div>
568
+ <div>Thinking span (gen idx): {escape(str(case_meta.get('thinking_span')))}</div>
569
+ <div>Prompt tokens: {len(prompt_tokens)}</div>
570
+ <div>Panels: {len(panels)}</div>
571
+ {''.join(attn_rows)}
572
+ {base_score_row}
573
+ </div>
574
+ </div>
575
+ """
576
+
577
+ panel_sections: List[str] = []
578
+ for panel in panels:
579
+ label = panel.get("variant_label") or panel.get("panel_label") or panel.get("variant") or "Panel"
580
+ metrics = panel.get("metrics") or {}
581
+ metrics_str = " | ".join(
582
+ f"{k}: {float(metrics[k]):.4f}" if isinstance(metrics.get(k), (int, float)) else f"{k}: {metrics.get(k)}"
583
+ for k in ("RISE", "MAS", "RISE+AP")
584
+ if k in metrics
585
+ )
586
+
587
+ attr_weights = panel.get("attr_weights") or []
588
+ guided_deltas = panel.get("token_deltas_raw") or []
589
+ rank_order = panel.get("sorted_attr_indices") or []
590
+ rank_str = ", ".join(str(int(x)) for x in rank_order) if rank_order else "N/A"
591
+
592
+ panel_sections.append(
593
+ f"""
594
+ <div class="panel">
595
+ <div class="panel-header">
596
+ <div class="panel-title">{escape(str(label))}</div>
597
+ <div class="panel-meta">{escape(metrics_str)}</div>
598
+ </div>
599
+
600
+ {_render_token_spans("Method attribution (token weights)", prompt_tokens, attr_weights)}
601
+ {_render_token_spans("Attribution-guided MAS marginal (path deltas)", prompt_tokens, guided_deltas)}
602
+
603
+ <div class="panel-foot">Rank order: {escape(rank_str)}</div>
604
+ </div>
605
+ """
606
+ )
607
+
608
+ style = """
609
+ <style>
610
+ body { font-family: "Inter", "Helvetica Neue", Arial, sans-serif; margin: 0; padding: 24px; background: #fcfcff; color: #1f2933; }
611
+ .title { font-size: 24px; font-weight: 700; }
612
+ .subtitle { font-size: 14px; color: #566; margin-top: 4px; }
613
+ .header { display: flex; justify-content: space-between; align-items: flex-start; gap: 16px; padding-bottom: 16px; border-bottom: 1px solid #e5e8ee; }
614
+ .meta { font-size: 13px; color: #334; line-height: 1.6; }
615
+
616
+ .text-block { margin-top: 16px; border: 1px solid #eef1f6; border-radius: 10px; padding: 12px; background: #fff; }
617
+ .text-title { font-size: 13px; font-weight: 700; color: #263; margin-bottom: 8px; }
618
+ .text-body { font-size: 13px; line-height: 1.7; white-space: pre-wrap; word-break: break-word; }
619
+
620
+ .panel { margin-top: 18px; padding: 16px; border: 1px solid #e5e8ee; border-radius: 10px; background: #fff; box-shadow: 0 2px 6px rgba(0,0,0,0.04); }
621
+ .panel-header { display: flex; justify-content: space-between; align-items: center; }
622
+ .panel-title { font-weight: 600; font-size: 16px; }
623
+ .panel-meta { font-size: 12px; color: #556; }
624
+ .panel-foot { margin-top: 8px; font-size: 12px; color: #556; }
625
+
626
+ .tokmap { margin-top: 12px; border: 1px solid #eef1f6; border-radius: 8px; padding: 10px; background: #f9fbff; }
627
+ .tokmap-title { font-size: 13px; font-weight: 600; margin-bottom: 8px; color: #263; }
628
+ .tokmap-text { font-size: 13px; line-height: 1.8; white-space: pre-wrap; word-break: break-word; }
629
+ .tok-span { display: inline; padding: 1px 1px; margin: 0 0px; border-radius: 3px; }
630
+ .tok-span:hover { outline: 1px solid #8899aa; }
631
+ </style>
632
+ """
633
+
634
+ html = f"""<!DOCTYPE html>
635
+ <html>
636
+ <head>
637
+ <meta charset="utf-8" />
638
+ <title>{escape(title)}</title>
639
+ {style}
640
+ </head>
641
+ <body>
642
+ {header}
643
+ {gen_block}
644
+ {''.join(panel_sections)}
645
+ </body>
646
+ </html>"""
647
+ return html
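+
+ # Illustrative usage sketch (not invoked anywhere in the repo; the panel / case_meta keys
+ # mirror the fields consumed by render_mas_token_html above, and all values are placeholders):
+ #
+ #   html = render_mas_token_html(
+ #       {"attr_method": "ifr_multi_hop", "dataset": "morehopqa", "index": 0,
+ #        "sink_span": [120, 131], "thinking_span": [0, 119], "base_score": 0.42},
+ #       prompt_tokens=prompt_tokens,             # list[str], one entry per prompt token
+ #       panels=[{"variant_label": "FlashTrace",
+ #                "metrics": {"RISE": 0.61, "MAS": 0.18},
+ #                "attr_weights": attr_weights,    # one float per prompt token
+ #                "token_deltas_raw": deltas,      # attribution-guided MAS marginal deltas
+ #                "sorted_attr_indices": order}],
+ #       generation=generated_text,
+ #   )
+ #   open("mas_token_case.html", "w", encoding="utf-8").write(html)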
exp/exp1/README.md ADDED
@@ -0,0 +1,46 @@
1
+ # FlashTrace long-context timing experiment (exp1)
2
+
3
+ Self-contained script: `exp/exp1/run_time_curve.py`
4
+ Purpose: on a single RULER example, measure each attribution method's wall-clock time and peak GPU memory at different context lengths, feeding the linear-growth table in the paper.
5
+
6
+ ## Method coverage
7
+ - `IG` (20 steps)
8
+ - `attention_I_G` (attention * IG)
9
+ - `attnlrp` (single-backward-pass LRP variant)
10
+ - `perturbation_all` (log-loss ablation)
11
+ - `perturbation_CLP` (KL variant)
12
+ - `perturbation_REAGENT` (MLM replacement; LED/4096-token cap, may fail beyond that)
13
+ - `ifr_all_positions` (IFR one-by-one baseline, `sink_chunk_tokens=1` fixed)
14
+ - `ifr_multi_hop` (FlashTrace, multi-hop + chunk support)
15
+ - `ifr_multi_hop_both` (FT-IFR both: stop_words + in_all_gen, multi-hop + chunk support)
16
+
17
+ ## Example run
18
+ ```bash
19
+ # Defaults: input lengths 1024,4096,8192; output lengths 32,256,512; 3 repeats per cell
20
+ python exp/exp1/run_time_curve.py \
21
+ --model qwen-8B \
22
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
23
+ --cuda 2,3,4,5,6,7 \
24
+ --attr_funcs perturbation_all,perturbation_REAGENT,ifr_all_positions,perturbation_CLP,ifr_multi_hop,ifr_multi_hop_both,attnlrp \
25
+ --input_lengths 10 \
26
+ --output_lengths 2000,5000,10000 \
27
+ --repeats 1 \
28
+ --chunk_tokens 128 \
29
+ --sink_chunk_tokens 32 \
30
+ --catch_oom \
31
+ --ruler_file data/ruler_multihop/8192/vt_h10_c1/validation.jsonl
32
+ ```
33
+
34
+ Outputs:
35
+ - `exp/exp1/out/time_curve_runs.jsonl`: one raw record per run (attr, target input/output/total, actual lengths, time, peak_mem, status).
36
+ - `exp/exp1/out/time_curve_summary.csv`: mean/std aggregated by method + target input/output (total = input + output is also written). A post-processing sketch is given at the end of this README.
37
+
38
+ ## Notes
39
+ - `--input_lengths` controls the prompt (user prompt) length and `--output_lengths` controls the output (sink) length; each cell's total = input + output.
40
+ - Compatibility: `--total_lengths/--lengths` (deprecated) are still accepted and denote the prompt + output total length; the prompt length is derived from the difference.
41
+ - `--target_text` is tiled repeatedly to reach the target output length; it only controls length, its semantics do not matter.
42
+ - `--catch_oom/--no-catch-oom` chooses between recording an OOM as a status and continuing, or raising and aborting.
43
+ - Multi-GPU: `--cuda 0,1` sets `CUDA_VISIBLE_DEVICES` before the script starts and loads the model sharded with `device_map=balanced`; pass `--cuda 0` for a single GPU.
44
+ - Exceeding the model context (`config.max_position_embeddings`) is marked `skipped_model_ctx` (checked against the actual formatted prompt + output(+eos) token count fed to the model).
45
+ - `perturbation_REAGENT`'s Longformer only supports 4096 tokens; longer inputs may return OOM or runtime_error.
46
+ - IFR multi-hop exposes `--chunk_tokens/--sink_chunk_tokens` to force chunking on very long contexts; memory drops while time rises slightly. The `ifr_all_positions` branch fixes `sink_chunk_tokens=1`.
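+
+ ## Re-aggregating the raw runs (sketch)
+ A minimal sketch, not shipped with the script, for recomputing per-cell means from `time_curve_runs.jsonl` when the CSV summary is not enough; the field names follow the records written by `run_time_curve.py`:
+ ```python
+ import json
+ from collections import defaultdict
+ from pathlib import Path
+
+ rows = [json.loads(line) for line in Path("exp/exp1/out/time_curve_runs.jsonl").read_text().splitlines() if line.strip()]
+ ok_rows = [r for r in rows if r["status"] == "ok"]  # skip oom / skipped cells
+ by_cell = defaultdict(list)
+ for r in ok_rows:
+     by_cell[(r["attr_func"], r["target_input_tokens"], r["target_output_tokens"])].append(r["time_sec"])
+ for (attr, inp, out), times in sorted(by_cell.items()):
+     print(f"{attr:>24s}  in={inp:<6d} out={out:<6d} mean_time={sum(times) / len(times):.2f}s  n={len(times)}")
+ ```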
exp/exp1/run_time_curve.py ADDED
@@ -0,0 +1,757 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Measure wall-clock time and GPU memory for attribution methods across
4
+ different context lengths using a single synthetic RULER-style example.
5
+
6
+ This script stays self-contained under exp/exp1 and reuses the attribution
7
+ implementations in the repo (IG, perturbation, attention*IG, IFR/FlashTrace).
8
+ The goal is to populate the time-vs-length table; correctness of the task
9
+ content is not important, only matching token lengths and running 3 repeats.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import json
16
+ import math
17
+ import os
18
+ import random
19
+ import sys
20
+ import time
21
+ from collections import defaultdict
22
+ from pathlib import Path
23
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
24
+
25
+ import numpy as np
26
+
27
+
28
+ def _early_set_cuda_visible_devices() -> None:
29
+ """Parse --cuda early to set CUDA_VISIBLE_DEVICES before torch import."""
30
+ parser = argparse.ArgumentParser(add_help=False)
31
+ parser.add_argument("--cuda", type=str, default=None)
32
+ args, _ = parser.parse_known_args(sys.argv[1:])
33
+ if args.cuda and "," in args.cuda:
34
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
35
+
36
+
37
+ _early_set_cuda_visible_devices()
38
+
39
+ import torch
40
+ from transformers import AutoModelForCausalLM, AutoTokenizer
41
+
42
+ REPO_ROOT = Path(__file__).resolve().parents[2]
43
+ if str(REPO_ROOT) not in sys.path:
44
+ sys.path.insert(0, str(REPO_ROOT))
45
+
46
+ import llm_attr
47
+
48
+ DEFAULT_INPUT_LENGTHS = [1024, 4096, 8192]
49
+ DEFAULT_OUTPUT_LENGTHS = [32, 256, 512]
50
+ DEFAULT_ATTRS = [
51
+ "IG",
52
+ "perturbation_all",
53
+ "attention_I_G",
54
+ "perturbation_REAGENT",
55
+ "ifr_all_positions",
56
+ "perturbation_CLP",
57
+ "ifr_multi_hop",
58
+ "attnlrp",
59
+ ]
60
+ DEFAULT_RULER_FILE = REPO_ROOT / "data" / "ruler_multihop" / "8192" / "vt_h10_c1" / "validation.jsonl"
61
+
62
+
63
+ def parse_args() -> argparse.Namespace:
64
+ parser = argparse.ArgumentParser("FlashTrace time/memory curve.")
65
+ parser.add_argument("--model", type=str, required=True, help="Model name or HF repo id.")
66
+ parser.add_argument("--model_path", type=str, default=None, help="Optional local model path.")
67
+ parser.add_argument("--cuda", type=str, default=None, help='CUDA devices, e.g. "0,1" or "0".')
68
+ parser.add_argument("--cuda_num", type=int, default=0, help="Single GPU index if --cuda is not set.")
69
+ parser.add_argument(
70
+ "--attr_funcs",
71
+ type=str,
72
+ default=",".join(DEFAULT_ATTRS),
73
+ help="Comma-separated attribution methods.",
74
+ )
75
+
76
+ length_group = parser.add_mutually_exclusive_group()
77
+ parser.add_argument(
78
+ "--output_lengths",
79
+ type=str,
80
+ default=",".join(str(x) for x in DEFAULT_OUTPUT_LENGTHS),
81
+ help="Comma-separated target output token lengths (sink/output segment).",
82
+ )
83
+ length_group.add_argument(
84
+ "--input_lengths",
85
+ type=str,
86
+ default=",".join(str(x) for x in DEFAULT_INPUT_LENGTHS),
87
+ help="Comma-separated target input/prompt token lengths (user prompt only; excludes chat template).",
88
+ )
89
+ length_group.add_argument(
90
+ "--total_lengths",
91
+ "--lengths",
92
+ dest="total_lengths",
93
+ type=str,
94
+ default=None,
95
+ help="Deprecated. Target total token lengths (prompt + output). Use --input_lengths instead.",
96
+ )
97
+ parser.add_argument("--repeats", type=int, default=3, help="Number of runs per cell.")
98
+ parser.add_argument("--output_dir", type=str, default="exp/exp1/out", help="Output directory.")
99
+ parser.add_argument(
100
+ "--ruler_file",
101
+ type=str,
102
+ default=str(DEFAULT_RULER_FILE),
103
+ help="RULER jsonl file providing a long base passage.",
104
+ )
105
+ parser.add_argument(
106
+ "--chunk_tokens",
107
+ type=int,
108
+ default=128,
109
+ help="IFR chunk_tokens override when context is long.",
110
+ )
111
+ parser.add_argument(
112
+ "--sink_chunk_tokens",
113
+ type=int,
114
+ default=32,
115
+ help="IFR sink_chunk_tokens override when context is long.",
116
+ )
117
+ parser.add_argument(
118
+ "--catch_oom",
119
+ action=argparse.BooleanOptionalAction,
120
+ default=True,
121
+ help="If true, treat CUDA OOM as status=oom and continue; if false, let OOM raise.",
122
+ )
123
+ parser.add_argument(
124
+ "--target_text",
125
+ type=str,
126
+ default=" The answer is 42.",
127
+ help="Base text to tile when constructing outputs of a given length.",
128
+ )
129
+ return parser.parse_args()
130
+
131
+
132
+ def parse_csv_ints(value: str) -> List[int]:
133
+ return [int(x) for x in value.split(",") if x.strip()]
134
+
135
+
136
+ def resolve_device(cuda: Optional[str], cuda_num: int) -> str:
137
+ if cuda is not None and "," in cuda:
138
+ os.environ["CUDA_VISIBLE_DEVICES"] = cuda
139
+ return "auto"
140
+ if cuda is not None and cuda.strip():
141
+ try:
142
+ idx = int(cuda)
143
+ except Exception:
144
+ idx = 0
145
+ return f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
146
+ return f"cuda:{cuda_num}" if torch.cuda.is_available() else "cpu"
147
+
148
+
149
+ def load_ruler_base(path: Path, fallback: str) -> str:
150
+ if not path.exists():
151
+ return fallback
152
+ with path.open() as f:
153
+ for line in f:
154
+ try:
155
+ record = json.loads(line)
156
+ if "input" in record:
157
+ return record["input"]
158
+ except json.JSONDecodeError:
159
+ continue
160
+ return fallback
161
+
162
+
163
+ def build_prompt_to_length(tokenizer, base_text: str, target_tokens: int) -> Tuple[str, int]:
164
+ """
165
+ Build a prompt whose tokenized length (without special tokens) is ~target_tokens.
166
+ If base_text is shorter, we repeat it; if longer, we truncate.
167
+ """
168
+ if target_tokens <= 0:
169
+ return "", 0
170
+
171
+ base_ids = tokenizer(base_text, add_special_tokens=False).input_ids
172
+ if not base_ids:
173
+ base_ids = [tokenizer.eos_token_id]
174
+
175
+ tiled: List[int] = []
176
+ while len(tiled) < target_tokens:
177
+ tiled.extend(base_ids)
178
+ tiled = tiled[:target_tokens]
179
+ prompt = tokenizer.decode(tiled, clean_up_tokenization_spaces=False)
180
+ return prompt, len(tiled)
181
+
182
+
183
+ def build_output_to_length(tokenizer, base_text: str, target_tokens: int) -> Tuple[str, int]:
184
+ """
185
+ Build a target/output string of ~target_tokens using a base snippet.
186
+ """
187
+ if target_tokens <= 0:
188
+ return "", 0
189
+
190
+ base_ids = tokenizer(base_text, add_special_tokens=False).input_ids
191
+ if not base_ids:
192
+ base_ids = [tokenizer.eos_token_id]
193
+
194
+ tiled: List[int] = []
195
+ while len(tiled) < target_tokens:
196
+ tiled.extend(base_ids)
197
+ tiled = tiled[:target_tokens]
198
+ text = tokenizer.decode(tiled, clean_up_tokenization_spaces=False)
199
+ return text, len(tiled)
200
+
201
+
202
+ def build_formatted_prompt(tokenizer, prompt: str) -> str:
203
+ user_prompt = " " + prompt
204
+ modified_prompt = llm_attr.DEFAULT_PROMPT_TEMPLATE.format(context=user_prompt, query="")
205
+ formatted_prompt = [{"role": "user", "content": modified_prompt}]
206
+ return tokenizer.apply_chat_template(
207
+ formatted_prompt,
208
+ tokenize=False,
209
+ add_generation_prompt=True,
210
+ enable_thinking=False,
211
+ )
212
+
213
+
214
+ def estimate_model_lengths(tokenizer, prompt: str, target: str) -> Dict[str, int]:
215
+ user_prompt = " " + prompt
216
+ formatted_prompt = build_formatted_prompt(tokenizer, prompt)
217
+
218
+ user_prompt_len = len(tokenizer(user_prompt, add_special_tokens=False).input_ids)
219
+ formatted_prompt_len = len(tokenizer(formatted_prompt, add_special_tokens=False).input_ids)
220
+ generation_len = len(tokenizer(target + tokenizer.eos_token, add_special_tokens=False).input_ids)
221
+
222
+ return {
223
+ "user_prompt_tokens": user_prompt_len,
224
+ "formatted_prompt_tokens": formatted_prompt_len,
225
+ "generation_tokens": generation_len,
226
+ "total_tokens": formatted_prompt_len + generation_len,
227
+ }
228
+
229
+
230
+ def exceeds_model_ctx(tokenizer, prompt: str, target: str, max_ctx: Optional[int]) -> bool:
231
+ if max_ctx is None:
232
+ return False
233
+ return estimate_model_lengths(tokenizer, prompt, target)["total_tokens"] > max_ctx
234
+
235
+
236
+ def load_model_balanced(model_name: str, device: str):
237
+ """Load model with an explicit balanced device_map when multi-GPU is requested."""
238
+ if device == "auto":
239
+ model = AutoModelForCausalLM.from_pretrained(
240
+ model_name,
241
+ device_map="balanced",
242
+ torch_dtype=torch.float16,
243
+ attn_implementation="eager",
244
+ )
245
+ elif isinstance(device, str) and device.startswith("cuda:"):
246
+ try:
247
+ gpu_idx = int(device.split(":")[1])
248
+ except Exception:
249
+ gpu_idx = 0
250
+ model = AutoModelForCausalLM.from_pretrained(
251
+ model_name,
252
+ device_map={"": gpu_idx},
253
+ torch_dtype=torch.float16,
254
+ attn_implementation="eager",
255
+ )
256
+ else:
257
+ model = AutoModelForCausalLM.from_pretrained(
258
+ model_name,
259
+ torch_dtype=torch.float16,
260
+ attn_implementation="eager",
261
+ )
262
+
263
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
264
+ tokenizer.pad_token = tokenizer.eos_token
265
+ model.eval()
266
+ return model, tokenizer
267
+
268
+
269
+ def collect_device_indices(device_str: str, model: Any) -> List[int]:
270
+ """
271
+ Infer the CUDA device indices that should be tracked for memory stats.
272
+ Prefers the model's device map; otherwise falls back to all visible devices
273
+ or the single requested device.
274
+ """
275
+ if not torch.cuda.is_available():
276
+ return []
277
+
278
+ devices: set[int] = set()
279
+ device_map = getattr(model, "hf_device_map", None)
280
+ if isinstance(device_map, dict):
281
+ for dev in device_map.values():
282
+ if dev is None:
283
+ continue
284
+ idx: Optional[int] = None
285
+ if isinstance(dev, torch.device):
286
+ idx = dev.index if dev.index is not None else (0 if dev.type == "cuda" else None)
287
+ elif isinstance(dev, str):
288
+ try:
289
+ d = torch.device(dev)
290
+ idx = d.index if d.index is not None else (0 if d.type == "cuda" else None)
291
+ except Exception:
292
+ idx = None
293
+ elif isinstance(dev, int):
294
+ idx = dev
295
+ if idx is not None:
296
+ devices.add(idx)
297
+
298
+ if not devices:
299
+ if device_str == "auto":
300
+ devices.update(range(torch.cuda.device_count()))
301
+ elif isinstance(device_str, str) and device_str.startswith("cuda:"):
302
+ try:
303
+ devices.add(int(device_str.split(":")[1]))
304
+ except Exception:
305
+ pass
306
+ else:
307
+ devices.update(range(torch.cuda.device_count()))
308
+
309
+ return sorted(devices)
310
+
311
+
312
+ def maybe_reset_cuda(device_indices: List[int]) -> None:
313
+ if not torch.cuda.is_available() or not device_indices:
314
+ return
315
+ for idx in device_indices:
316
+ try:
317
+ torch.cuda.reset_peak_memory_stats(device=idx)
318
+ except Exception:
319
+ pass
320
+ try:
321
+ torch.cuda.empty_cache()
322
+ except Exception:
323
+ pass
324
+
325
+
326
+ def measure(
327
+ method_fn,
328
+ device_indices: List[int],
329
+ *,
330
+ catch_oom: bool,
331
+ ) -> Tuple[str, Optional[float], Optional[float], Optional[float], Dict[int, Dict[str, float]]]:
332
+ status = "ok"
333
+ wall: Optional[float] = None
334
+ mem_alloc: Optional[float] = None
335
+ mem_reserved: Optional[float] = None
336
+ mem_by_device: Dict[int, Dict[str, float]] = {}
337
+ try:
338
+ if torch.cuda.is_available() and device_indices:
339
+ for idx in device_indices:
340
+ torch.cuda.synchronize(device=idx)
341
+ t0 = time.time()
342
+ method_fn()
343
+ if torch.cuda.is_available() and device_indices:
344
+ for idx in device_indices:
345
+ torch.cuda.synchronize(device=idx)
346
+ wall = time.time() - t0
347
+ except RuntimeError as e:
348
+ if "out of memory" in str(e).lower():
349
+ status = "oom"
350
+ if not catch_oom:
351
+ raise
352
+ else:
353
+ status = f"runtime_error: {e}"
354
+ if not catch_oom:
355
+ raise
356
+ except Exception as e:
357
+ status = f"error: {e}"
358
+ if not catch_oom:
359
+ raise
360
+ finally:
361
+ if torch.cuda.is_available() and device_indices:
362
+ try:
363
+ total_alloc = 0.0
364
+ total_reserved = 0.0
365
+ for idx in device_indices:
366
+ alloc_bytes = torch.cuda.max_memory_allocated(device=idx)
367
+ reserved_bytes = torch.cuda.max_memory_reserved(device=idx)
368
+ total_alloc += alloc_bytes
369
+ total_reserved += reserved_bytes
370
+ mem_by_device[idx] = {
371
+ "allocated_gb": alloc_bytes / 1e9,
372
+ "reserved_gb": reserved_bytes / 1e9,
373
+ }
374
+ mem_alloc = total_alloc / 1e9
375
+ mem_reserved = total_reserved / 1e9
376
+ except Exception:
377
+ pass
378
+ return status, wall, mem_alloc, mem_reserved, mem_by_device
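+ # Example (illustrative): `measure` times any zero-argument callable, e.g. one of the
+ # closures built by make_attr_runner below:
+ #   status, wall, alloc_gb, reserved_gb, per_dev = measure(runner, device_indices=[0], catch_oom=True)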
379
+
380
+
381
+ def make_attr_runner(
382
+ attr_func: str,
383
+ model: Any,
384
+ tokenizer: Any,
385
+ chunk_tokens: int,
386
+ sink_chunk_tokens: int,
387
+ batch_size: int,
388
+ prompt: str,
389
+ target: str,
390
+ ):
391
+ lf = attr_func.lower()
392
+ if lf == "ig":
393
+ llm_attributor = llm_attr.LLMGradientAttribtion(model, tokenizer)
394
+
395
+ def fn():
396
+ return llm_attributor.calculate_IG_per_generation(
397
+ prompt, steps=20, baseline=tokenizer.eos_token_id, batch_size=batch_size, target=target
398
+ )
399
+
400
+ return fn
401
+
402
+ if lf == "attention_i_g":
403
+ llm_attn = llm_attr.LLMAttentionAttribution(model, tokenizer)
404
+ llm_ig = llm_attr.LLMGradientAttribtion(model, tokenizer)
405
+
406
+ def fn():
407
+ attn = llm_attn.calculate_attention_attribution(prompt, target=target)
408
+ ig = llm_ig.calculate_IG_per_generation(
409
+ prompt, steps=20, baseline=tokenizer.eos_token_id, batch_size=batch_size, target=target
410
+ )
411
+ attn.attribution_matrix = attn.attribution_matrix * ig.attribution_matrix
412
+ return attn
413
+
414
+ return fn
415
+
416
+ if lf == "perturbation_all":
417
+ llm_attrtor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
418
+
419
+ def fn():
420
+ return llm_attrtor.calculate_feature_ablation_sentences(
421
+ prompt, baseline=tokenizer.eos_token_id, measure="log_loss", target=target
422
+ )
423
+
424
+ return fn
425
+
426
+ if lf == "perturbation_clp":
427
+ llm_attrtor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
428
+
429
+ def fn():
430
+ return llm_attrtor.calculate_feature_ablation_sentences(
431
+ prompt, baseline=tokenizer.eos_token_id, measure="KL", target=target
432
+ )
433
+
434
+ return fn
435
+
436
+ if lf == "perturbation_reagent":
437
+ llm_attrtor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
438
+
439
+ def fn():
440
+ return llm_attrtor.calculate_feature_ablation_sentences_mlm(prompt, target=target)
441
+
442
+ return fn
443
+
444
+ if lf == "ifr_all_positions":
445
+ llm_attrtor = llm_attr.LLMIFRAttribution(
446
+ model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=1
447
+ )
448
+
449
+ def fn():
450
+ return llm_attrtor.calculate_ifr_for_all_positions(prompt, target=target)
451
+
452
+ return fn
453
+
454
+ if lf == "ifr_multi_hop":
455
+ llm_attrtor = llm_attr.LLMIFRAttribution(
456
+ model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens
457
+ )
458
+
459
+ def fn():
460
+ return llm_attrtor.calculate_ifr_multi_hop(prompt, target=target)
461
+
462
+ return fn
463
+
464
+ if lf == "ifr_multi_hop_both":
465
+ import ft_ifr_improve
466
+
467
+ llm_attrtor = ft_ifr_improve.LLMIFRAttributionBoth(
468
+ model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens
469
+ )
470
+
471
+ def fn():
472
+ return llm_attrtor.calculate_ifr_multi_hop_both(prompt, target=target)
473
+
474
+ return fn
475
+
476
+ if lf == "attnlrp":
477
+ llm_attrtor = llm_attr.LLMLRPAttribution(model, tokenizer)
478
+
479
+ def fn():
480
+ return llm_attrtor.calculate_attnlrp(prompt, target=target)
481
+
482
+ return fn
483
+
484
+ raise ValueError(f"Unsupported attr_func {attr_func}")
485
+
486
+
487
+ def compute_batch_size(sequence_length: int, max_input_len: int) -> int:
488
+ denom = int(sequence_length)
489
+ return max(1, math.floor((max_input_len - 100) / max(1, denom)))
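+ # Worked example (illustrative): compute_batch_size(4096, 40960) == 9,
+ # i.e. floor((40960 - 100) / 4096), clamped to at least 1.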
490
+
491
+
492
+ def aggregate_results(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
493
+ grouped: Dict[Tuple[str, int, int], Dict[str, List[float]]] = defaultdict(lambda: {"time": [], "mem": []})
494
+ statuses: Dict[Tuple[str, int, int], List[str]] = defaultdict(list)
495
+ for row in rows:
496
+ key = (row["attr_func"], row["target_input_tokens"], row["target_output_tokens"])
497
+ statuses[key].append(row["status"])
498
+ if row.get("time_sec") is not None:
499
+ grouped[key]["time"].append(row["time_sec"])
500
+ if row.get("peak_mem_gb") is not None:
501
+ grouped[key]["mem"].append(row["peak_mem_gb"])
502
+
503
+ summary = []
504
+ for key, vals in grouped.items():
505
+ attr_func, input_tokens, output_tokens = key
506
+ total_tokens = input_tokens + output_tokens
507
+ times = vals["time"]
508
+ mems = vals["mem"]
509
+ summary.append(
510
+ {
511
+ "attr_func": attr_func,
512
+ "target_input_tokens": input_tokens,
513
+ "target_total_tokens": total_tokens,
514
+ "target_output_tokens": output_tokens,
515
+ "time_mean": np.mean(times) if times else None,
516
+ "time_std": np.std(times) if times else None,
517
+ "mem_mean": np.mean(mems) if mems else None,
518
+ "mem_std": np.std(mems) if mems else None,
519
+ "statuses": statuses[key],
520
+ }
521
+ )
522
+ return summary
523
+
524
+
525
+ def append_jsonl_row(f, row: Dict[str, Any]) -> None:
526
+ f.write(json.dumps(row) + "\n")
527
+ f.flush()
528
+ try:
529
+ os.fsync(f.fileno())
530
+ except OSError:
531
+ pass
532
+
533
+
534
+ def write_summary_csv(rows: List[Dict[str, Any]], out_dir: Path) -> Path:
535
+ summary = aggregate_results(rows)
536
+ summary_path = out_dir / "time_curve_summary.csv"
537
+ tmp_path = out_dir / "time_curve_summary.csv.tmp"
538
+
539
+ with tmp_path.open("w") as f:
540
+ f.write(
541
+ "attr_func,target_input_tokens,target_output_tokens,target_total_tokens,time_mean,time_std,peak_mem_mean,peak_mem_std,statuses\n"
542
+ )
543
+ for row in summary:
544
+ f.write(
545
+ "{},{},{},{},{},{},{},{},{}\n".format(
546
+ row["attr_func"],
547
+ row["target_input_tokens"],
548
+ row["target_output_tokens"],
549
+ row["target_total_tokens"],
550
+ "" if row["time_mean"] is None else f"{row['time_mean']:.4f}",
551
+ "" if row["time_std"] is None else f"{row['time_std']:.4f}",
552
+ "" if row["mem_mean"] is None else f"{row['mem_mean']:.4f}",
553
+ "" if row["mem_std"] is None else f"{row['mem_std']:.4f}",
554
+ "|".join(row["statuses"]),
555
+ )
556
+ )
557
+ f.flush()
558
+ try:
559
+ os.fsync(f.fileno())
560
+ except OSError:
561
+ pass
562
+
563
+ tmp_path.replace(summary_path)
564
+ return summary_path
565
+
566
+
567
+ def main() -> None:
568
+ args = parse_args()
569
+ device = resolve_device(args.cuda, args.cuda_num)
570
+ attr_funcs = [a.strip() for a in args.attr_funcs.split(",") if a.strip()]
571
+ target_output_lengths = parse_csv_ints(args.output_lengths)
572
+ out_dir = Path(args.output_dir)
573
+ out_dir.mkdir(parents=True, exist_ok=True)
574
+
575
+ random.seed(42)
576
+ np.random.seed(42)
577
+ torch.manual_seed(42)
578
+
579
+ model_name = args.model if args.model_path is None else args.model_path
580
+ model, tokenizer = load_model_balanced(model_name, device)
581
+ device_indices = collect_device_indices(device, model)
582
+ max_ctx = getattr(getattr(model, "config", None), "max_position_embeddings", None)
583
+
584
+ base_text = load_ruler_base(Path(args.ruler_file), fallback="RULER fallback text. ")
585
+ target_base = args.target_text
586
+ all_rows: List[Dict[str, Any]] = []
587
+ runner = None
588
+ raised: Optional[BaseException] = None
589
+ jsonl_f = None
590
+ jsonl_path = out_dir / "time_curve_runs.jsonl"
591
+ summary_path = out_dir / "time_curve_summary.csv"
592
+
593
+ def record_row(row: Dict[str, Any]) -> None:
594
+ all_rows.append(row)
595
+ if jsonl_f is not None:
596
+ append_jsonl_row(jsonl_f, row)
597
+ write_summary_csv(all_rows, out_dir)
598
+
599
+ using_deprecated_total = args.total_lengths is not None
600
+ if using_deprecated_total:
601
+ target_total_lengths = parse_csv_ints(args.total_lengths)
602
+ length_grid: List[Tuple[int, int, int]] = []
603
+ for total_tokens in target_total_lengths:
604
+ for output_tokens in target_output_lengths:
605
+ length_grid.append((total_tokens - output_tokens, output_tokens, total_tokens))
606
+ else:
607
+ target_input_lengths = parse_csv_ints(args.input_lengths)
608
+ length_grid = []
609
+ for input_tokens in target_input_lengths:
610
+ for output_tokens in target_output_lengths:
611
+ length_grid.append((input_tokens, output_tokens, input_tokens + output_tokens))
612
+
613
+ try:
614
+ jsonl_f = jsonl_path.open("w")
615
+ write_summary_csv([], out_dir)
616
+
617
+ for input_tokens, output_tokens, total_tokens in length_grid:
618
+ if input_tokens <= 0:
619
+ for attr in attr_funcs:
620
+ for rep in range(args.repeats):
621
+ record_row(
622
+ {
623
+ "attr_func": attr,
624
+ "target_input_tokens": input_tokens,
625
+ "target_output_tokens": output_tokens,
626
+ "target_total_tokens": total_tokens,
627
+ "actual_input_tokens": None,
628
+ "actual_output_tokens": None,
629
+ "actual_total_tokens_raw": None,
630
+ "actual_user_prompt_tokens": None,
631
+ "actual_formatted_prompt_tokens": None,
632
+ "actual_generation_tokens": None,
633
+ "actual_total_tokens": None,
634
+ "status": "skipped_nonpositive_input",
635
+ "time_sec": None,
636
+ "peak_mem_gb": None,
637
+ "peak_mem_reserved_gb": None,
638
+ "repeat": rep,
639
+ "used_deprecated_total_lengths": using_deprecated_total,
640
+ }
641
+ )
642
+ continue
643
+
644
+ prompt, actual_input_len = build_prompt_to_length(tokenizer, base_text, input_tokens)
645
+ target, actual_output_len = build_output_to_length(tokenizer, target_base, output_tokens)
646
+ actual_total_tokens_raw = len(tokenizer(prompt + target, add_special_tokens=False).input_ids)
647
+ model_lens = estimate_model_lengths(tokenizer, prompt, target)
648
+
649
+ if max_ctx is not None and model_lens["total_tokens"] > max_ctx:
650
+ for attr in attr_funcs:
651
+ for rep in range(args.repeats):
652
+ record_row(
653
+ {
654
+ "attr_func": attr,
655
+ "target_input_tokens": input_tokens,
656
+ "target_output_tokens": output_tokens,
657
+ "target_total_tokens": total_tokens,
658
+ "actual_input_tokens": actual_input_len,
659
+ "actual_output_tokens": actual_output_len,
660
+ "actual_total_tokens_raw": actual_total_tokens_raw,
661
+ "actual_user_prompt_tokens": model_lens["user_prompt_tokens"],
662
+ "actual_formatted_prompt_tokens": model_lens["formatted_prompt_tokens"],
663
+ "actual_generation_tokens": model_lens["generation_tokens"],
664
+ "actual_total_tokens": model_lens["total_tokens"],
665
+ "status": "skipped_model_ctx",
666
+ "time_sec": None,
667
+ "peak_mem_gb": None,
668
+ "peak_mem_reserved_gb": None,
669
+ "repeat": rep,
670
+ "used_deprecated_total_lengths": using_deprecated_total,
671
+ }
672
+ )
673
+ continue
674
+
675
+ batch_size = compute_batch_size(model_lens["total_tokens"], max_input_len=max_ctx or 200000)
676
+
677
+ for attr in attr_funcs:
678
+ for rep in range(args.repeats):
679
+ runner = None
680
+ maybe_reset_cuda(device_indices)
681
+ try:
682
+ runner = make_attr_runner(
683
+ attr,
684
+ model=model,
685
+ tokenizer=tokenizer,
686
+ chunk_tokens=args.chunk_tokens,
687
+ sink_chunk_tokens=args.sink_chunk_tokens,
688
+ batch_size=batch_size,
689
+ prompt=prompt,
690
+ target=target,
691
+ )
692
+ except RuntimeError as e:
693
+ if "out of memory" in str(e).lower():
694
+ status = "oom"
695
+ if not args.catch_oom:
696
+ raise
697
+ else:
698
+ status = f"init_runtime_error: {e}"
699
+ if not args.catch_oom:
700
+ raise
701
+ wall = None
702
+ mem_alloc = None
703
+ mem_reserved = None
704
+ mem_by_device = {}
705
+ except Exception as e:
706
+ status = f"init_error: {e}"
707
+ if not args.catch_oom:
708
+ raise
709
+ wall = None
710
+ mem_alloc = None
711
+ mem_reserved = None
712
+ mem_by_device = {}
713
+ else:
714
+ status, wall, mem_alloc, mem_reserved, mem_by_device = measure(
715
+ runner, device_indices=device_indices, catch_oom=args.catch_oom
716
+ )
717
+ finally:
718
+ runner = None
719
+
720
+ record_row(
721
+ {
722
+ "attr_func": attr,
723
+ "target_input_tokens": input_tokens,
724
+ "target_output_tokens": output_tokens,
725
+ "target_total_tokens": total_tokens,
726
+ "actual_input_tokens": actual_input_len,
727
+ "actual_output_tokens": actual_output_len,
728
+ "actual_total_tokens_raw": actual_total_tokens_raw,
729
+ "actual_user_prompt_tokens": model_lens["user_prompt_tokens"],
730
+ "actual_formatted_prompt_tokens": model_lens["formatted_prompt_tokens"],
731
+ "actual_generation_tokens": model_lens["generation_tokens"],
732
+ "actual_total_tokens": model_lens["total_tokens"],
733
+ "status": status,
734
+ "time_sec": wall,
735
+ "peak_mem_gb": mem_reserved if mem_reserved is not None else mem_alloc,
736
+ "peak_mem_reserved_gb": mem_reserved,
737
+ "peak_mem_by_device_gb": mem_by_device if mem_by_device else None,
738
+ "repeat": rep,
739
+ "used_deprecated_total_lengths": using_deprecated_total,
740
+ }
741
+ )
742
+ except BaseException as e:
743
+ raised = e
744
+ finally:
745
+ runner = None
746
+ if jsonl_f is not None:
747
+ jsonl_f.close()
748
+ write_summary_csv(all_rows, out_dir)
749
+ print(f"Wrote per-run records to {jsonl_path}")
750
+ print(f"Wrote summary to {summary_path}")
751
+
752
+ if raised is not None:
753
+ raise raised
754
+
755
+
756
+ if __name__ == "__main__":
757
+ main()
exp/exp2/DATASETS.md ADDED
@@ -0,0 +1,231 @@
1
+ # exp/exp2 数据集与样本流说明
2
+
3
+ 本文件说明 Experiment 2 中支持的数据集、样本结构,以及在「采样阶段」与「归因阶段」的处理方式。
4
+
5
+ ## 支持的数据集
6
+ - `morehopqa`(`data/with_human_verification.json`)
7
+ - RULER 系列 JSONL:`hotpotqa_long`、`niah_*`、`vt_*`(自动在 `data/ruler_multihop/<len>/.../validation.jsonl` 搜索),或直接传入任意 RULER JSONL 路径
8
+ - 其余数据集(如 math)被显式跳过
9
+ - 归因阶段同样优先使用缓存文件 `exp/exp2/data/<name>.jsonl`,否则按上述规则解析;传入存在的 JSONL 路径也会按 RULER 结构加载
10
+
11
+ ### 共同的样本字段定义
12
+ ```json
13
+ {
14
+ "prompt": "<上下文+问题>",
15
+ "target": "<答案或生成>",
16
+ "indices_to_explain": [start_tok, end_tok] | null, // token-level:需要解释的 generation token span(闭区间)
17
+ "attr_mask_indices": [...], // legacy:覆盖率金标句子索引(当前 exp2 不再使用),可能为 null
18
+ "sink_span": [start, end] | null, // 生成 token 中的答案片段
19
+ "thinking_span": [start, end] | null, // 生成 token 中的 CoT 片段
20
+ "metadata": { ... } // 数据集特定元信息
21
+ }
22
+ ```
23
+ - **`CachedExample`**:`dataset_utils.py` 统一的内存态结构,字段与上述 JSON 完全一致,用于采样阶段(加载原始数据)与归因阶段(加载缓存或原始)。
24
+ - **缓存行(JSONL)**:`sample_and_filter.py` 写入的每行 JSON,与 `CachedExample` 字段一一对应。
25
+ - **采样阶段处理流(通用)**:
26
+ 1. 加载原始数据集样本(`prompt`/`indices_to_explain` 等保持一致)。
27
+ 2. 按模板调用生成模型,要求「思考文本 + 末尾 \\box{} 答案」。
28
+ 3. 若生成不符合「思考 + 单个 \\box{} 且无尾巴」的格式,直接丢弃该样本。
29
+ 4. 提取思考片段与 `\\box{}` 内文本,仅用 `\\box{}` 内文调用判定模型。
30
+ 5. 判定为 True 时,重新拼接「思考片段 + 去除 box 包裹的答案文本」作为 `target`,并据此记录 `sink_span`/`thinking_span`。
31
+ 6. 写入缓存:只保留 `reference_answer`、`judge_response`(可选 `boxed_answer`),不再存储 `candidate_answer`。
32
+
33
+ ### Generation splitting and span parsing
34
+ - `split_boxed_generation` (`dataset_utils.py`) validates the format: the generation must be "non-empty thinking text + a single trailing \\box{}" with no characters after the box; otherwise the sample is skipped.
35
+ - `target` is rebuilt as "thinking segment + newline + final answer text (without the box)".
36
+ - `attach_spans_from_answer` uses the tokenizer's offset mapping to project the final answer's character span inside `target` onto token indices, yielding `sink_span`; `thinking_span` is the closed interval from the start up to the token right before `sink_span`. Both are token-level spans, matching the calling convention of the downstream multi-hop IFR (see the sketch after this list).
37
+ - `indices_to_explain` is uniformly set to `sink_span` when the sample cache is written (the generation token span of the boxed content inside `target`).
38
+
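+ A minimal sketch of this offset-mapping step (illustrative only; the tokenizer name is an assumption, any fast tokenizer with `return_offsets_mapping` works):
+ ```python
+ from transformers import AutoTokenizer
+
+ tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
+ target = "The capital of France is Paris, so the answer is Paris."
+ answer = "Paris"
+ start = target.rfind(answer); end = start + len(answer)            # last occurrence of the answer
+ enc = tok(target, add_special_tokens=False, return_offsets_mapping=True)
+ sink = [i for i, (s, e) in enumerate(enc["offset_mapping"]) if s < end and e > start]
+ sink_span = [min(sink), max(sink)]                                  # closed token interval of the answer
+ thinking_span = [0, max(0, sink_span[0] - 1)]                       # everything before the answer
+ print(sink_span, thinking_span)
+ ```
+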
39
+ ---
40
+
41
+ ## MoreHopQA
42
+ - **原始样本结构(`MoreHopQAAttributionDataset` → `CachedExample`)**
43
+ ```json
44
+ {
45
+ "prompt": "<context 拼接>\\n<question>",
46
+ "target": null,
47
+ "indices_to_explain": null,
48
+ "attr_mask_indices": null,
49
+ "sink_span": null,
50
+ "thinking_span": null,
51
+ "metadata": {
52
+ "answer": "<gold answer>",
53
+ "_id": "<example id>",
54
+ "original_context": <原始上下文结构>
55
+ }
56
+ }
57
+ ```
58
+ - 加载时机:`DatasetLoader.load_raw("morehopqa")` 在采样阶段、归因阶段(无缓存时)都会产出 `CachedExample`。
59
+ - 说明:exp2 的 token-level row/rec 需要 `target` + 可定位的答案 token span;建议先跑 `sample_and_filter.py` 产出缓存后再做归因评估。
60
+
61
+ - **采样阶段(生成 & 过滤后写缓存)**
62
+ ```json
63
+ {
64
+ "prompt": "<同上>",
65
+ "target": "<生成的 CoT + 最终答案文本(已去掉 box 包裹)>",
66
+ "indices_to_explain": [start_tok, end_tok],
67
+ "attr_mask_indices": null,
68
+ "sink_span": [start_tok, end_tok] | null,
69
+ "thinking_span": [start_tok, end_tok] | null,
70
+ "metadata": {
71
+ "answer": "<gold answer>",
72
+ "_id": "<example id>",
73
+ "original_context": <原始上下文结构>,
74
+ "reference_answer": "<gold answer>",
75
+ "judge_response": "<True/False 文本>",
76
+ "boxed_answer": "<可选,boxed 解析结果>"
77
+ }
78
+ }
79
+ ```
80
+ - `sink_span`/`thinking_span`:仅在成功解析 `\\box{}` 时填充;`target` 为「思考 + 最终答案文本」的裁剪版。
81
+ - 写入:`exp/exp2/data/morehopqa.jsonl`。
82
+
83
+ - **归因阶段(加载缓存优先)**
84
+ - 加载:`run_exp.py` 优先 `load_cached`(JSONL → `CachedExample`),否则回退原始结构并在线生成 `target`。
85
+ - 使用:忠实度(token-level RISE/MAS)直接用缓存的 `target`;`ifr_multi_hop` 在有 `sink_span`/`thinking_span` 时限定答案/CoT,否则视整个生成为 sink。
86
+
87
+ ---
88
+
89
+ ## RULER 热点问答(`hotpotqa_long`)
90
+ - **原始样本结构(`RulerAttributionDataset` → `CachedExample`)**
91
+ ```json
92
+ {
93
+ "prompt": "<input> + <answer_prefix>",
94
+ "target": "<answer_prefix + sep + ', '.join(outputs)>",
95
+ "indices_to_explain": [0],
96
+ "attr_mask_indices": [<句子索引>...] | null,
97
+ "sink_span": null,
98
+ "thinking_span": null,
99
+ "metadata": {
100
+ "dataset": "ruler",
101
+ "length": <int>,
102
+ "length_w_model_temp": <any>,
103
+ "outputs": [...],
104
+ "answer_prefix": "<str>",
105
+ "token_position_answer": <any>,
106
+ "needle_spans": [
107
+ {
108
+ "title": "<str>",
109
+ "doc_index": <int>,
110
+ "document_number": <int>,
111
+ "sentence_index": <int>,
112
+ "sentence": "<str>",
113
+ "context_span": [start, end],
114
+ "span": [start, end],
115
+ "snippet": "<str>"
116
+ },
117
+ ...
118
+ ],
119
+ "prompt_sentence_count": <int>,
120
+ "reference_answer": "<在 loader 中补充,来自 outputs 或 target>"
121
+ }
122
+ }
123
+ ```
124
+ - 加载时机:`DatasetLoader.load_raw("hotpotqa_long")` 在采样阶段、归因阶段(无缓存时)都会产出 `CachedExample`。
125
+
126
+ - **采样阶段(生成 & 过滤后写缓存)**
127
+ ```json
128
+ {
129
+ "prompt": "<同上>",
130
+ "target": "<生成的 CoT + 最终答案文本(已去掉 box 包裹)>",
131
+ "indices_to_explain": [-2],
132
+ "attr_mask_indices": [<句子索引>...] | null,
133
+ "sink_span": [start_tok, end_tok] | null,
134
+ "thinking_span": [start_tok, end_tok] | null,
135
+ "metadata": {
136
+ "dataset": "ruler",
137
+ "length": <int>,
138
+ "length_w_model_temp": <any>,
139
+ "outputs": [...],
140
+ "answer_prefix": "<str>",
141
+ "token_position_answer": <any>,
142
+ "needle_spans": [...],
143
+ "prompt_sentence_count": <int>,
144
+ "reference_answer": "<outputs 拼接或 target>",
145
+ "judge_response": "<True/False 文本>",
146
+ "boxed_answer": "<可选>"
147
+ }
148
+ }
149
+ ```
150
+ - `attr_mask_indices` 保留原值;`indices_to_explain` 统一为末句 `[-2]`(最后一个非 EOS 生成句);`sink_span`/`thinking_span` 仅在成功解析 `\\box{}` 时填充;`target` 为「思考 + 最终答案文本」的裁剪版。
151
+ - 写入:`exp/exp2/data/hotpotqa_long.jsonl`。
152
+
153
+ - **归因阶段(加载缓存优先)**
154
+ - 加载:优先 `load_cached`(JSONL → `CachedExample`),否则回退原始解析。
155
+ - 使用:覆盖率使用 `attr_mask_indices`;忠实度与 `ifr_multi_hop` 利用缓存的 `sink_span`/`thinking_span` 定位答案/CoT,若缺失则视整个生成为 sink。
156
+
157
+ ---
158
+
159
+ ## RULER NIAH / Variable Tracking(`niah_*`, `vt_*`)
160
+ - **原始样本结构(同 RULER 通用)**
161
+ ```json
162
+ {
163
+ "prompt": "<input> + <answer_prefix>",
164
+ "target": "<answer_prefix + sep + ', '.join(outputs)>",
165
+ "indices_to_explain": [0],
166
+ "attr_mask_indices": [<句子索引>...] | null,
167
+ "sink_span": null,
168
+ "thinking_span": null,
169
+ "metadata": {
170
+ "dataset": "ruler",
171
+ "length": <int>,
172
+ "length_w_model_temp": <any>,
173
+ "outputs": [...],
174
+ "answer_prefix": "<str>",
175
+ "token_position_answer": <any>,
176
+ "needle_spans": [...],
177
+ "prompt_sentence_count": <int>,
178
+ "reference_answer": "<在 loader 中补充>"
179
+ }
180
+ }
181
+ ```
182
+ - 加载时机:`DatasetLoader.load_raw("<niah_* 或 vt_*>")` 在采样阶段、归因阶段(无缓存时)使用。
183
+
184
+ - **采样阶段(生成 & 过滤后写缓存)**
185
+ ```json
186
+ {
187
+ "prompt": "<同上>",
188
+ "target": "<思考 + 最终答案文本(无 box),无其他尾巴>",
189
+ "indices_to_explain": [start_tok, end_tok],
190
+ "attr_mask_indices": [<句子索引>...] | null,
191
+ "sink_span": [start_tok, end_tok] | null,
192
+ "thinking_span": [start_tok, end_tok] | null,
193
+ "metadata": {
194
+ "dataset": "ruler",
195
+ "length": <int>,
196
+ "length_w_model_temp": <any>,
197
+ "outputs": [...],
198
+ "answer_prefix": "<str>",
199
+ "token_position_answer": <any>,
200
+ "needle_spans": [...],
201
+ "prompt_sentence_count": <int>,
202
+ "reference_answer": "<outputs 拼接或 target>",
203
+ "judge_response": "<True/False 文本>",
204
+ "boxed_answer": "<可选>"
205
+ }
206
+ }
207
+ ```
208
+ - 生成/判定流程与 `hotpotqa_long` 相同;`target` 是裁剪后的「思考 + 最终答案文本」。
209
+ - 写入:`exp/exp2/data/<dataset>.jsonl`(例如 `niah_mq_q2.jsonl`, `vt_h6_c1.jsonl`)。
210
+
211
+ - **归因阶段(加载缓存优先)**
212
+ - 与 `hotpotqa_long` 相同:优先缓存,否则原始;恢复率(`recovery_ruler`)使用 `metadata.needle_spans`(映射到 prompt tokens);多跳 IFR 在有 `sink_span`/`thinking_span` 时作用于答案/CoT。
213
+
214
+ ---
215
+
216
+ ## `indices_to_explain` 约定
217
+ - token-level:`indices_to_explain = [start_tok, end_tok]`(闭区间),坐标系为 `tokenizer(target, add_special_tokens=False)` 的 generation token indices。
218
+ - exp2 推荐:`indices_to_explain == sink_span`,即 boxed 内文(最终答案)在 `target` 中对应的 token span。
219
+
220
+ ---
221
+
222
+ ## 自定义 RULER JSONL 路径
223
+ - 若 `--dataset` 传入存在的 JSONL 路径,`dataset_from_name` 按 RULER 文件解析,字段与流程同 RULER 系列。
224
+ - 采样、归因阶段行为与上文 RULER 描述一致,只是文件名由显式路径决定。
225
+
226
+ ---
227
+
228
+ ## 归因阶段加载优先级与效果
229
+ - `run_exp.py` 加载顺序:`exp/exp2/data/<name>.jsonl` 缓存 > 显式给定的 JSONL 路径 > 原始解析(MoreHopQA 或 RULER)
230
+ - 恢复率 (`mode=recovery_ruler`) 仅支持 RULER(要求 `metadata.needle_spans`),否则拒绝
231
+ - 忠实度 (`mode=faithfulness_gen`) 使用生成文本;`ifr_multi_hop` 在有 `sink_span`/`thinking_span` 时才对答案/CoT 做多跳,否则退化为整段生成
exp/exp2/README.md ADDED
@@ -0,0 +1,106 @@
1
+ # FlashTrace Experiment 2 (faithfulness under multi-step reasoning)
2
+
3
+ This directory provides the tooling for the "11 datasets × 9 methods × 3 metrics" experiment, **skipping AT2** and **skipping math**. The workflow has two steps: first sample and filter high-quality CoT + boxed generations, then run attribution evaluation on the filtered results.
4
+
5
+ Supported datasets: MoreHopQA, HotpotQA (RULER hotpotqa_long), RULER niah (niah_*), RULER variable tracking (vt_*). RULER paths are discovered automatically under `data/ruler_multihop/<len>/.../validation.jsonl`.
6
+
7
+ Main files:
8
+ - `sample_and_filter.py`: sampling + judge agreement, writes to `exp/exp2/data/`
9
+ - `run_exp.py`: attribution evaluation, writes to `exp/exp2/output/`
10
+ - `dataset_utils.py`: data loading and answer-span parsing
11
+
12
+ Datasets supported by the sampling script
13
+ - `morehopqa` (local `data/with_human_verification.json`)
14
+ - `hotpotqa_long` (auto-discovered at `data/ruler_multihop/<len>/hotpotqa_long/validation.jsonl`)
15
+ - `niah_*` (RULER niah variants, auto-discovered as above)
16
+ - `vt_*` (RULER variable tracking variants, auto-discovered as above)
17
+ - a RULER JSONL path passed directly (treated as the dataset name); other dataset types are not supported
18
+
19
+ Attribution evaluation support
20
+ - Datasets: prefer the `exp/exp2/data/<name>.jsonl` cache; if absent, load with the same parsing rules as sampling; math is explicitly rejected.
21
+ - Metrics:
22
+ - `faithfulness_gen` (generation side): runs on any loaded example (other than math).
23
+ - `recovery_ruler` (recovery, RULER only): Recall@10% (ranking is done over prompt tokens only, gold comes from `needle_spans`); see the sketch below this list.
24
+ - Methods (`--attr_funcs`): `IG`, `perturbation_all`, `perturbation_CLP`, `perturbation_REAGENT`, `attention` (fuses IG internally), `ifr_all_positions`, `ifr_multi_hop`, `attnlrp`, `ft_attnlrp`, `basic`. AT2 is not provided.
25
+
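+ A minimal sketch of the Recall@10% computation as described above (the exact tie-breaking in `run_exp.py` may differ):
+ ```python
+ import numpy as np
+
+ def recall_at_10pct(attr_weights, gold_token_indices):
+     """attr_weights: one score per prompt token; gold_token_indices: tokens covered by needle_spans."""
+     k = max(1, int(0.1 * len(attr_weights)))
+     top_k = set(np.argsort(-np.asarray(attr_weights))[:k].tolist())
+     gold = set(int(i) for i in gold_token_indices)
+     return len(gold & top_k) / max(1, len(gold))
+ ```
+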
26
+ ---
27
+
28
+ ## Data sampling
29
+
30
+ Implementation
31
+ - Unified data loading: `DatasetLoader` reads MoreHopQA / HotpotQA / RULER niah / RULER vt; a custom RULER JSONL can also be passed directly.
32
+ - Generation model: `qwen3-235b-a22b-2507` (English system prompt), asked to "think briefly first, then wrap the final answer in `\box{}` and append nothing after it"; the user prompt is the original question, with no extra template.
33
+ - Judge model: `deepseek-v3-1-terminus` (English system prompt), outputs only True/False on whether the `\box{}` content matches the reference answer.
34
+ - Filtering: keep only samples that are "thinking + trailing boxed answer" and judged True; `target` is rebuilt from the extracted thinking segment plus the **final answer with the box wrapper removed**, together with token-level `sink_span`/`thinking_span`, `reference_answer`, and `judge_response` (`candidate_answer` is no longer stored); `indices_to_explain` is always written as `sink_span` (the generation token span of the boxed content inside `target`, [start_tok, end_tok]).
35
+ - Sampling walks the source examples in order and drops a sample as soon as the judge rejects it; it stops early once `--max_examples` accepted samples are collected (fewer if the source runs out), and tqdm shows attempted and accepted counts separately.
36
+
37
+ Usage
38
+ ```bash
39
+ export FLASHTRACE_API_KEY=sk-yaojia-get-ccfa # or OPENAI_API_KEY
40
+
41
+ # Example: keep at most 100 samples judged True (shown here on the MoreHopQA file; pass hotpotqa_long etc. for RULER)
42
+ python exp/exp2/sample_and_filter.py \
43
+ --dataset data/with_human_verification.json \
44
+ --max_examples 100 \
45
+ --api_key sk-yaojia-get-ccfa \
46
+ --tokenizer_model /opt/share/models/Qwen/Qwen3-8B > exp/exp2/out.log
47
+ ```
48
+ Common arguments:
49
+ - `--dataset`: morehopqa | hotpotqa_long | niah_* | vt_* (or a JSONL path directly)
50
+ - `--max_examples`: number of accepted samples to keep; sampling stops once this is reached (fewer if the source is smaller)
51
+ - `--tokenizer_model`: tokenizer used for span detection (defaults to the generation model)
52
+ - `--api_base`/`--api_key`: endpoint and key (default local http://localhost:4000/v1)
53
+ - `--request_interval` / `--judge_interval`: throttling between generation / judge calls (default 1s)
54
+ - `--rate_limit_delay`: seconds to wait on HTTP 429 (default 5s); sleeps automatically before retrying
55
+ Output: `exp/exp2/data/<dataset>.jsonl`
56
+
57
+ ---
58
+
59
+ ## Attribution evaluation
60
+
61
+ Implementation
62
+ - Input: prefer `exp/exp2/data/<dataset>.jsonl` (the filtered cache); if it does not exist, fall back to parsing the raw data.
63
+ - Methods: faithfulness (token-level RISE/MAS) follows the logic of `evaluations/faithfulness.py` (AT2 not implemented); math is rejected automatically.
64
+ - Multi-hop FlashTrace: when the cache carries `sink_span`/`thinking_span` they drive the multi-hop IFR, otherwise the whole answer sentence is used as the sink.
65
+ - One run can evaluate several metrics: `--mode` accepts multiple values and comma-separated lists (e.g. `--mode faithfulness_gen,recovery_ruler` or `--mode faithfulness_gen, recovery_ruler`), running attribution only once on the same batch of samples.
66
+ - Optional per-sample traces: `--save_hop_traces` saves attribution vectors and per-sample metrics for **all methods and all samples** under `exp/exp2/output/traces/...`; for multi-hop methods the per-hop token-level vector `V_h` (a single `vh`, i.e. the vector actually propagated across hops) is saved as well, and the manifest records settings such as `attnlrp_neg_handling/attnlrp_norm_mode`. A small inspection sketch follows this list.
67
+ - Known compatibility issue: some tokenizers merge tokens at chat-template boundaries, so locating the user prompt via a token-id subsequence on the evaluation side can fail; exp2 now reuses the `user_prompt_indices` computed during attribution for perturbation localisation.
68
+ - Batch-size estimate: keeps the original script's conservative estimate `(max_input_len-100)/len(tokenizer(format_prompt(prompt)+target))` (at least 1). `max_input_len` comes from a built-in mapping keyed on the `--model` string; unmatched names, or passing only `--model_path`, default to 2000. To get the mapped value while loading from a local path, also pass the corresponding `--model` name.
69
+ - Timing: attribution (recovery/faithfulness) is timed per sample; an `Avg Sample Time (s)` row is appended to the CSV and the average is printed to the console.
70
+ - Output: `exp/exp2/output/faithfulness/...`, `exp/exp2/output/recovery/...`, and (optionally) `exp/exp2/output/traces/...`, organised by dataset and model.
71
+
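+ A minimal sketch for peeking at saved traces (the array names inside each `ex_*.npz` depend on the method and are not documented here, so this simply lists whatever is stored):
+ ```python
+ import numpy as np
+ from pathlib import Path
+
+ dataset, model, run_tag = "morehopqa", "qwen-8B", "<run_tag>"  # fill in a real run directory
+ run_dir = Path("exp/exp2/output/traces") / dataset / model / run_tag
+ for npz_path in sorted(run_dir.glob("ex_*.npz"))[:3]:
+     with np.load(npz_path) as data:
+         print(npz_path.name, {key: data[key].shape for key in data.files})
+ ```
+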
72
+ Usage
73
+ ```bash
74
+ # Generation-side RISE/MAS faithfulness; available methods: perturbation_all_fast,perturbation_CLP_fast,perturbation_REAGENT_fast,ifr_multi_hop_stop_words,ifr_multi_hop_both,ifr_multi_hop_split_hop,ft_attnlrp,ifr_multi_hop,attnlrp,ifr_all_positions,perturbation_all,perturbation_REAGENT,perturbation_CLP,IG,attention
75
+ python exp/exp2/run_exp.py \
76
+ --datasets exp/exp2/data/math.jsonl \
77
+ --attr_funcs IG,attention \
78
+ --model qwen-8B \
79
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
80
+ --cuda 2,3,4,5,6,7 \
81
+ --num_examples 100 \
82
+ --mode faithfulness_gen \
83
+ --n_hops 1 \
84
+ --save_hop_traces \
85
+ && python exp/exp2/run_exp.py \
86
+ --datasets exp/exp2/data/morehopqa.jsonl \
87
+ --attr_funcs IG,attention \
88
+ --model qwen-8B \
89
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
90
+ --cuda 2,3,4,5,6,7 \
91
+ --num_examples 100 \
92
+ --mode faithfulness_gen \
93
+ --n_hops 1 \
94
+ --save_hop_traces
95
+
96
+ # --attnlrp_neg_handling drop \
97
+ # --attnlrp_norm_mode norm
98
+ ```
99
+ Common arguments:
100
+ - `--datasets`: comma-separated dataset names; if `exp/exp2/data/<name>.jsonl` already exists it is used directly.
101
+ - `--attr_funcs`: comma-separated methods (no AT2); `ifr_multi_hop` and `ft_attnlrp` support multi-hop (controlled by `--n_hops`).
102
+ - `--attnlrp_neg_handling`: per-hop handling of negative values in FT-AttnLRP (`drop`/`abs`).
103
+ - `--attnlrp_norm_mode`: FT-AttnLRP normalisation and hop-ratio switch (`norm`/`no_norm`).
104
+ - `--data_root`/`--output_root`: cache and result directories (default `exp/exp2/data` / `exp/exp2/output`).
105
+ - `--mode`: `faithfulness_gen`, `recovery_ruler`; multiple values / comma-separated lists are allowed (one attribution pass yields several metrics); `--num_examples` controls how many examples are evaluated. math is rejected.
106
+ - `--save_hop_traces`: save per-sample traces to `exp/exp2/output/traces/<dataset>/<model>/<run_tag>/` (one `ex_*.npz` per sample + `manifest.jsonl`).
exp/exp2/dataset_utils.py ADDED
@@ -0,0 +1,386 @@
1
+ """Dataset helpers for Experiment 2 (CoT / multi-hop faithfulness).
2
+
3
+ Named dataset_utils to avoid collision with the HF `datasets` package.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import random
10
+ import re
11
+ from dataclasses import dataclass
12
+ from pathlib import Path
13
+ from typing import Any, Dict, Iterable, List, Optional
14
+
15
+ from attribution_datasets import (
16
+ AttributionExample,
17
+ MoreHopQAAttributionDataset,
18
+ RulerAttributionDataset,
19
+ )
20
+
21
+
22
+ @dataclass
23
+ class CachedExample:
24
+ prompt: str
25
+ target: Optional[str]
26
+ indices_to_explain: Optional[List[int]]
27
+ attr_mask_indices: Optional[List[int]]
28
+ sink_span: Optional[List[int]]
29
+ thinking_span: Optional[List[int]]
30
+ metadata: Dict[str, Any]
31
+
32
+
33
+ def read_cached_jsonl(path: Path) -> List[CachedExample]:
34
+ examples: List[CachedExample] = []
35
+ with path.open("r", encoding="utf-8") as f:
36
+ for line in f:
37
+ if not line.strip():
38
+ continue
39
+ obj = json.loads(line)
40
+ examples.append(
41
+ CachedExample(
42
+ prompt=obj["prompt"],
43
+ target=obj.get("target"),
44
+ indices_to_explain=obj.get("indices_to_explain"),
45
+ attr_mask_indices=obj.get("attr_mask_indices"),
46
+ sink_span=obj.get("sink_span"),
47
+ thinking_span=obj.get("thinking_span"),
48
+ metadata=obj.get("metadata", {}),
49
+ )
50
+ )
51
+ return examples
52
+
53
+
54
+ def load_cached(path: Path, sample: Optional[int] = None, seed: int = 42) -> List[CachedExample]:
55
+ ex = read_cached_jsonl(path)
56
+ if sample is not None and sample < len(ex):
57
+ random.Random(seed).shuffle(ex)
58
+ ex = ex[:sample]
59
+ return ex
60
+
61
+
62
+ def load_ruler(path: Path, sample: Optional[int] = None, seed: int = 42) -> List[CachedExample]:
63
+ ds = RulerAttributionDataset(path)
64
+ examples: List[CachedExample] = []
65
+ ex_iter: Iterable[AttributionExample] = ds
66
+ if sample is not None and sample < len(ds):
67
+ ex_iter = list(ds)
68
+ random.Random(seed).shuffle(ex_iter)
69
+ ex_iter = ex_iter[:sample]
70
+ for ex in ex_iter:
71
+ examples.append(
72
+ CachedExample(
73
+ prompt=ex.prompt,
74
+ target=ex.target,
75
+ indices_to_explain=ex.indices_to_explain,
76
+ attr_mask_indices=ex.attr_mask_indices,
77
+ sink_span=None,
78
+ thinking_span=None,
79
+ metadata=ex.metadata,
80
+ )
81
+ )
82
+ return examples
83
+
84
+
85
+ def load_morehopqa(
86
+ path: str | Path = "./data/with_human_verification.json", sample: Optional[int] = None, seed: int = 42
87
+ ) -> List[CachedExample]:
88
+ ds = MoreHopQAAttributionDataset(path)
89
+ ex_iter: Iterable[AttributionExample] = ds
90
+ if sample is not None and sample < len(ds):
91
+ ex_iter = list(ds)
92
+ random.Random(seed).shuffle(ex_iter)
93
+ ex_iter = ex_iter[:sample]
94
+ examples: List[CachedExample] = []
95
+ for ex in ex_iter:
96
+ examples.append(
97
+ CachedExample(
98
+ prompt=ex.prompt,
99
+ target=None,
100
+ indices_to_explain=ex.indices_to_explain,
101
+ attr_mask_indices=ex.attr_mask_indices,
102
+ sink_span=None,
103
+ thinking_span=None,
104
+ metadata=ex.metadata,
105
+ )
106
+ )
107
+ return examples
108
+
109
+
110
+ def auto_find_ruler(task: str) -> Optional[Path]:
111
+ length_dirs = ["4096", "8192", "16384", "32768", "65536", "131072"]
112
+ base = Path("data/ruler_multihop")
113
+ for ld in length_dirs:
114
+ cand = base / ld / task / "validation.jsonl"
115
+ if cand.exists():
116
+ return cand
117
+ return None
118
+
119
+
120
+ def dataset_from_name(name: str) -> Optional[Path]:
121
+ if name == "hotpotqa_long":
122
+ return auto_find_ruler("hotpotqa_long")
123
+ if name.startswith("vt_"):
124
+ return auto_find_ruler(name)
125
+ if name.startswith("niah"):
126
+ return auto_find_ruler(name)
127
+ p = Path(name)
128
+ if p.exists():
129
+ return p
130
+ return None
131
+
132
+
133
+ _BOX_PATTERN = re.compile(r"\\box(?:ed)?\s*[\{{](.*?)[\}}]", flags=re.DOTALL)
134
+
135
+
136
+ def _find_box_span(text: str) -> Optional[tuple[int, int, str]]:
137
+ """Return (start_char, end_char, answer_text) for the last \\boxed block."""
138
+ matches = list(_BOX_PATTERN.finditer(text))
139
+ if not matches:
140
+ return None
141
+ m = matches[-1]
142
+ return m.start(0), m.end(0), m.group(1).strip()
143
+
144
+
145
+ def extract_boxed_answer(text: str) -> Optional[str]:
146
+ """Extract the answer string inside the last \\boxed{} block."""
147
+ match = _find_box_span(text)
148
+ return match[2] if match else None
149
+
150
+
151
+ def _find_answer_span(text: str, answer: str) -> Optional[tuple[int, int]]:
152
+ """Return (start_char, end_char) for the last occurrence of `answer` in text."""
153
+ if not answer or not text:
154
+ return None
155
+ start = text.rfind(answer)
156
+ if start == -1:
157
+ return None
158
+ return start, start + len(answer)
159
+
160
+
161
+ def split_boxed_generation(text: str) -> Optional[tuple[str, str, str]]:
162
+ """Return (thinking_text, boxed_segment, boxed_answer) if format matches."""
163
+ if not text:
164
+ return None
165
+ match = _find_box_span(text)
166
+ if not match:
167
+ return None
168
+
169
+ start_char, end_char, boxed_inner = match
170
+ boxed_segment = text[start_char:end_char].strip()
171
+ thinking_text = text[:start_char].strip()
172
+ trailing = text[end_char:].strip()
173
+
174
+ if not boxed_inner or not boxed_segment:
175
+ return None
176
+ if trailing:
177
+ return None
178
+ if not thinking_text:
179
+ return None
180
+
181
+ return thinking_text, boxed_segment, boxed_inner
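+ # Worked examples (illustrative):
+ #   split_boxed_generation("First find the year, then add 2. \\box{1947}")
+ #   -> ("First find the year, then add 2.", "\\box{1947}", "1947")
+ #   split_boxed_generation("\\box{1947}") -> None            (no thinking text before the box)
+ #   split_boxed_generation("... \\box{1947} Done.") -> None  (trailing text after the box)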
182
+
183
+
184
+ def attach_spans_from_answer(
185
+ example: CachedExample, tokenizer, answer_text: Optional[str] = None
186
+ ) -> CachedExample:
187
+ """Attach sink/thinking spans by locating the (plain) answer in `target`.
188
+
189
+ `answer_text` should be the extracted boxed answer; falls back to metadata or
190
+ parsing the target when omitted. Works even when the target no longer keeps
191
+ the \\box{} wrapper.
192
+ """
193
+ tgt = example.target or ""
194
+ answer = (answer_text or "").strip()
195
+ if not answer:
196
+ answer = (example.metadata.get("boxed_answer") or extract_boxed_answer(tgt) or "").strip()
197
+
198
+ metadata = dict(example.metadata)
199
+ if answer:
200
+ metadata.setdefault("boxed_answer", answer)
201
+
202
+ if tokenizer is None or not tgt or not answer:
203
+ return CachedExample(
204
+ prompt=example.prompt,
205
+ target=example.target,
206
+ indices_to_explain=example.indices_to_explain,
207
+ attr_mask_indices=example.attr_mask_indices,
208
+ sink_span=example.sink_span,
209
+ thinking_span=example.thinking_span,
210
+ metadata=metadata,
211
+ )
212
+
213
+ span = _find_answer_span(tgt, answer)
214
+ if span is None:
215
+ return CachedExample(
216
+ prompt=example.prompt,
217
+ target=example.target,
218
+ indices_to_explain=example.indices_to_explain,
219
+ attr_mask_indices=example.attr_mask_indices,
220
+ sink_span=example.sink_span,
221
+ thinking_span=example.thinking_span,
222
+ metadata=metadata,
223
+ )
224
+
225
+ span_start_char, span_end_char = span
226
+ gen_ids = tokenizer(tgt, add_special_tokens=False, return_offsets_mapping=True)
227
+ sink_tokens: List[int] = []
228
+ for idx, (s, e) in enumerate(gen_ids["offset_mapping"]):
229
+ # include tokens that overlap the answer span
230
+ if s < span_end_char and e > span_start_char:
231
+ sink_tokens.append(idx)
232
+ if not sink_tokens:
233
+ return CachedExample(
234
+ prompt=example.prompt,
235
+ target=example.target,
236
+ indices_to_explain=example.indices_to_explain,
237
+ attr_mask_indices=example.attr_mask_indices,
238
+ sink_span=example.sink_span,
239
+ thinking_span=example.thinking_span,
240
+ metadata=metadata,
241
+ )
242
+
243
+ sink_span = [min(sink_tokens), max(sink_tokens)]
244
+ thinking_end = max(0, sink_span[0] - 1)
245
+ thinking_span = [0, thinking_end] if thinking_end >= 0 else sink_span
246
+
247
+ return CachedExample(
248
+ prompt=example.prompt,
249
+ target=example.target,
250
+ indices_to_explain=example.indices_to_explain,
251
+ attr_mask_indices=example.attr_mask_indices,
252
+ sink_span=example.sink_span or sink_span,
253
+ thinking_span=example.thinking_span or thinking_span,
254
+ metadata=metadata,
255
+ )
256
+
257
+
258
+ def attach_spans_from_boxed(example: CachedExample, tokenizer) -> CachedExample:
259
+ """Backward-compatible wrapper that first looks for \\box{} then falls back to answer text."""
260
+ tgt = example.target
261
+ match = _find_box_span(tgt) if tgt else None
262
+ boxed_answer = match[2] if match else None
263
+ return attach_spans_from_answer(example, tokenizer, boxed_answer)
264
+
265
+
266
+ def ruler_gold_prompt_token_indices(example: CachedExample, tokenizer) -> List[int]:
267
+ """Return token indices (prompt-side) that overlap RULER `needle_spans` in metadata.
268
+
269
+ The returned indices are with respect to `tokenizer(" " + example.prompt, add_special_tokens=False)`,
270
+ matching the attribution pipeline's leading-space convention.
271
+ """
272
+ needle_spans = (example.metadata or {}).get("needle_spans") or []
273
+ if not isinstance(needle_spans, list) or not needle_spans:
274
+ return []
275
+
276
+ prompt_text = " " + (example.prompt or "")
277
+ enc = tokenizer(prompt_text, add_special_tokens=False, return_offsets_mapping=True)
278
+ offsets = enc.get("offset_mapping")
279
+ if offsets is None:
280
+ raise ValueError("Tokenizer does not provide offset_mapping; cannot map needle_spans to tokens.")
281
+
282
+ spans: List[tuple[int, int]] = []
283
+ for item in needle_spans:
284
+ if not isinstance(item, dict):
285
+ continue
286
+ raw = item.get("span")
287
+ if not (isinstance(raw, list) and len(raw) == 2):
288
+ continue
289
+ try:
290
+ start = int(raw[0]) + 1 # shift for leading space in prompt_text
291
+ end = int(raw[1]) + 1
292
+ except Exception:
293
+ continue
294
+ if end > start:
295
+ spans.append((start, end))
296
+
297
+ if not spans:
298
+ return []
299
+
300
+ gold: set[int] = set()
301
+ for tok_idx, off in enumerate(offsets):
302
+ if off is None:
303
+ continue
304
+ try:
305
+ s, e = int(off[0]), int(off[1])
306
+ except Exception:
307
+ continue
308
+ if e <= s:
309
+ continue
310
+ for span_start, span_end in spans:
311
+ if s < span_end and e > span_start:
312
+ gold.add(tok_idx)
313
+ break
314
+
315
+ return sorted(gold)
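The `+1` shift and the overlap test used above can be illustrated with made-up offsets (in practice the offsets come from the fast tokenizer's `offset_mapping` over `" " + prompt`):

```python
# Made-up per-token (start, end) character offsets over " " + prompt.
offsets = [(0, 4), (4, 10), (10, 15), (15, 22)]
needle_span = (3, 11)                            # raw char span from metadata["needle_spans"]
span = (needle_span[0] + 1, needle_span[1] + 1)  # shift for the leading space

gold = [i for i, (s, e) in enumerate(offsets) if s < span[1] and e > span[0]]
assert gold == [1, 2]                            # tokens overlapping chars [4, 12)
```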
316
+
317
+
318
+ class DatasetLoader:
319
+ """Thin loader that resolves and samples datasets for exp2."""
320
+
321
+ def __init__(self, seed: int = 42, data_root: Path | str = Path("exp/exp2/data")) -> None:
322
+ self.seed = seed
323
+ self.data_root = Path(data_root)
324
+
325
+ def _sample(self, items: List[CachedExample], sample: Optional[int]) -> List[CachedExample]:
326
+ if sample is not None and sample < len(items):
327
+ rnd = random.Random(self.seed)
328
+ rnd.shuffle(items)
329
+ items = items[:sample]
330
+ return items
331
+
332
+ def _cached_path(self, name: str) -> Optional[Path]:
333
+ path = self.data_root / f"{name}.jsonl"
334
+ return path if path.exists() else None
335
+
336
+ def load(self, name: str, sample: Optional[int] = None) -> List[CachedExample]:
337
+ # 1) Prefer prepared cache under exp/exp2/data
338
+ cached_path = self._cached_path(name)
339
+ if cached_path:
340
+ return self._sample(load_cached(cached_path), sample)
341
+
342
+ return self.load_raw(name, sample=sample)
343
+
344
+ def load_raw(self, name: str, sample: Optional[int] = None) -> List[CachedExample]:
345
+ def _looks_like_json_array(path: Path) -> bool:
346
+ try:
347
+ with path.open("r", encoding="utf-8") as f:
348
+ while True:
349
+ ch = f.read(1)
350
+ if not ch:
351
+ return False
352
+ if ch.isspace():
353
+ continue
354
+ return ch == "["
355
+ except OSError:
356
+ return False
357
+
358
+ # MoreHopQA
359
+ if name == "morehopqa":
360
+ ex = load_morehopqa()
361
+ for item in ex:
362
+ if "answer" in item.metadata:
363
+ item.metadata.setdefault("reference_answer", item.metadata["answer"])
364
+ return self._sample(ex, sample)
365
+
366
+ # Allow passing the raw MoreHopQA JSON path directly.
367
+ p = Path(name)
368
+ if p.exists() and _looks_like_json_array(p):
369
+ ex = load_morehopqa(p)
370
+ for item in ex:
371
+ if "answer" in item.metadata:
372
+ item.metadata.setdefault("reference_answer", item.metadata["answer"])
373
+ return self._sample(ex, sample)
374
+
375
+ # RULER / HotpotQA / niah / vt (all go through RulerAttributionDataset)
376
+ resolved = dataset_from_name(name)
377
+ if resolved is None:
378
+ raise FileNotFoundError(f"Could not resolve dataset {name}")
379
+ ex = load_ruler(resolved)
380
+ for item in ex:
381
+ outputs = item.metadata.get("outputs") or []
382
+ if outputs:
383
+ item.metadata.setdefault("reference_answer", ", ".join(outputs))
384
+ if item.target and "reference_answer" not in item.metadata:
385
+ item.metadata["reference_answer"] = item.target
386
+ return self._sample(ex, sample)
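A hedged usage sketch of the loader (the dataset name, sample size, and tokenizer are assumptions; any HuggingFace fast tokenizer that provides offset mappings should work):

```python
from transformers import AutoTokenizer

from exp.exp2.dataset_utils import DatasetLoader, ruler_gold_prompt_token_indices

loader = DatasetLoader(seed=42, data_root="exp/exp2/data")
examples = loader.load("hotpotqa_long", sample=8)  # prefers exp/exp2/data/hotpotqa_long.jsonl if present

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # assumed tokenizer
for ex in examples:
    gold = ruler_gold_prompt_token_indices(ex, tok)  # prompt-side token indices of the needles
    print(len(gold), ex.metadata.get("reference_answer"))
```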
exp/exp2/map_math_mine_to_exp2_cache.py ADDED
@@ -0,0 +1,584 @@
1
+ #!/usr/bin/env python3
2
+ """Prepare data/math_mine.json into an exp2 cached JSONL dataset.
3
+
4
+ This script supports two modes:
5
+
6
+ - map (offline): convert GSM8K-style math examples:
7
+
8
+ {"question": "...", "answer": "... #### 18"}
9
+
10
+ into exp2's cached JSONL format (one JSON object per line).
11
+
12
+ - resample (online): resample targets like exp/exp2/sample_and_filter.py:
13
+ call a chat completion API to generate "<thinking> + final \\box{} answer",
14
+ judge the boxed answer against the reference answer extracted from the raw
15
+ GSM8K-style entry, and write only judge=True samples.
16
+
17
+ In both modes, exp2 expects token-level spans (NOT character spans):
18
+
19
+ - indices_to_explain: [start_tok, end_tok] (generation-token indices, closed interval)
20
+ - sink_span/thinking_span: token spans over tokenizer(target, add_special_tokens=False)
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import argparse
26
+ import json
27
+ import os
28
+ import sys
29
+ import time
30
+ import urllib.error
31
+ import urllib.request
32
+ from dataclasses import asdict
33
+ from pathlib import Path
34
+ from typing import Any, Dict, List, Optional, Tuple
35
+
36
+ from transformers import AutoTokenizer
37
+ from tqdm import tqdm
38
+
39
+ REPO_ROOT = Path(__file__).resolve().parents[2]
40
+ if str(REPO_ROOT) not in sys.path:
41
+ sys.path.insert(0, str(REPO_ROOT))
42
+
43
+ from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer, split_boxed_generation # noqa: E402
44
+
45
+
46
+ class RateLimitError(RuntimeError):
47
+ """Raised when API returns 429; carries a suggested wait time."""
48
+
49
+ def __init__(self, wait_seconds: float, detail: str) -> None:
50
+ super().__init__(detail)
51
+ self.wait_seconds = wait_seconds
52
+
53
+
54
+ GEN_SYSTEM_PROMPT = (
55
+ "You are a reasoning assistant. "
56
+ "Before answering, engage in a chain of thought. "
57
+ "Process this freely and naturally without using specific headers or strict formatting. "
58
+ "When you reach the conclusion, wrap the entire final sentence containing the answer inside \\box{}. "
59
+ "Ensure the box wraps the **sentence** that naturally delivers the answer. DO NOT rewrite the answer word for the box separately."
60
+ )
61
+
62
+ JUDGE_SYSTEM_PROMPT = (
63
+ "You verify whether the model's boxed answer matches the reference answer. "
64
+ "Reply strictly with True or False and nothing else."
65
+ )
66
+
67
+
68
+ def call_chat_api(
69
+ api_base: str,
70
+ api_key: str,
71
+ model: str,
72
+ messages: List[Dict[str, str]],
73
+ *,
74
+ timeout: int,
75
+ max_tokens: int,
76
+ temperature: float,
77
+ cache_ttl: int,
78
+ cache_namespace: Optional[str],
79
+ rate_limit_delay: Optional[float] = None,
80
+ ) -> str:
81
+ """Minimal OpenAI-compatible chat.completions client (no external deps)."""
82
+ url = api_base.rstrip("/") + "/chat/completions"
83
+ payload: Dict[str, Any] = {
84
+ "model": model,
85
+ "messages": messages,
86
+ "max_tokens": max_tokens,
87
+ "temperature": temperature,
88
+ }
89
+ if cache_ttl > 0:
90
+ cache_obj: Dict[str, Any] = {"ttl": cache_ttl}
91
+ if cache_namespace:
92
+ cache_obj["namespace"] = cache_namespace
93
+ payload["cache"] = cache_obj
94
+
95
+ data = json.dumps(payload).encode("utf-8")
96
+ headers = {"Content-Type": "application/json"}
97
+ if api_key:
98
+ headers["Authorization"] = f"Bearer {api_key}"
99
+
100
+ req = urllib.request.Request(url, data=data, headers=headers, method="POST")
101
+ opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
102
+ try:
103
+ with opener.open(req, timeout=timeout) as resp:
104
+ resp_bytes = resp.read()
105
+ except urllib.error.HTTPError as e:
106
+ detail = e.read().decode("utf-8", errors="ignore") if hasattr(e, "read") else ""
107
+ if e.code == 429:
108
+ retry_after = None
109
+ if hasattr(e, "headers") and e.headers:
110
+ retry_after_header = e.headers.get("Retry-After")
111
+ if retry_after_header:
112
+ try:
113
+ retry_after = float(retry_after_header)
114
+ except ValueError:
115
+ retry_after = None
116
+ wait = retry_after or rate_limit_delay or 5.0
117
+ raise RateLimitError(wait, f"API HTTP 429: {detail}") from e
118
+ raise RuntimeError(f"API HTTP error {e.code}: {detail}") from e
119
+ except urllib.error.URLError as e:
120
+ raise RuntimeError(f"API request failed: {e}") from e
121
+
122
+ try:
123
+ response = json.loads(resp_bytes.decode("utf-8"))
124
+ except json.JSONDecodeError as e:
125
+ raise RuntimeError(f"Failed to decode API response: {resp_bytes!r}") from e
126
+
127
+ choices = response.get("choices", [])
128
+ if not choices:
129
+ raise RuntimeError(f"Empty choices from API: {response}")
130
+ content = choices[0].get("message", {}).get("content", "")
131
+ if not content:
132
+ raise RuntimeError(f"Empty content from API: {response}")
133
+ return content.strip()
134
+
135
+
136
+ def build_gen_messages(prompt: str) -> List[Dict[str, str]]:
137
+ return [
138
+ {"role": "system", "content": GEN_SYSTEM_PROMPT},
139
+ {"role": "user", "content": prompt},
140
+ ]
141
+
142
+
143
+ def build_judge_messages(reference_answer: str, candidate_answer: str) -> List[Dict[str, str]]:
144
+ user = (
145
+ "Decide if the model's boxed answer matches the reference answer.\n"
146
+ f"Reference answer: {reference_answer}\n"
147
+ f"Model boxed answer (only the content inside \\box{{}}): {candidate_answer}\n"
148
+ "Output only True if they are semantically consistent; otherwise output False."
149
+ )
150
+ return [
151
+ {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
152
+ {"role": "user", "content": user},
153
+ ]
154
+
155
+
156
+ def parse_bool(text: str) -> bool:
157
+ first = text.strip().splitlines()[0].strip().lower()
158
+ if first in {"true", "yes"}:
159
+ return True
160
+ if first in {"false", "no"}:
161
+ return False
162
+ # fallback: check substring
163
+ if "true" in first and "false" not in first:
164
+ return True
165
+ if "false" in first:
166
+ return False
167
+ raise ValueError(f"Cannot parse boolean from: {text!r}")
168
+
169
+
170
+ def _load_tokenizer(tokenizer_model: str):
171
+ tok_path = Path(tokenizer_model)
172
+ if tok_path.exists():
173
+ tokenizer = AutoTokenizer.from_pretrained(tok_path.as_posix(), local_files_only=True)
174
+ else:
175
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
176
+ if tokenizer.pad_token is None and tokenizer.eos_token is not None:
177
+ tokenizer.pad_token = tokenizer.eos_token
178
+ return tokenizer
179
+
180
+
181
+ def _split_gsm8k_answer(answer: str) -> Optional[Tuple[str, str]]:
182
+ """Return (thinking_text, final_answer) parsed from GSM8K `answer`."""
183
+ text = (answer or "").strip()
184
+ if not text:
185
+ return None
186
+ if "####" not in text:
187
+ return None
188
+ thinking, final = text.rsplit("####", 1)
189
+ thinking = thinking.strip()
190
+ final = final.strip()
191
+ if not final:
192
+ return None
193
+ return thinking, final
194
+
195
+
196
+ def _is_token_span(span: Any) -> bool:
197
+ return isinstance(span, list) and len(span) == 2 and all(isinstance(x, int) for x in span)
198
+
199
+
200
+ def _build_cached_example(
201
+ *,
202
+ question: str,
203
+ answer: str,
204
+ tokenizer,
205
+ example_idx: int,
206
+ source_path: str,
207
+ ) -> Optional[CachedExample]:
208
+ parsed = _split_gsm8k_answer(answer)
209
+ if parsed is None:
210
+ return None
211
+ thinking_text, final_answer = parsed
212
+
213
+ prompt = question.strip()
214
+ target = f"{thinking_text}\n{final_answer}" if thinking_text else final_answer
215
+
216
+ example = CachedExample(
217
+ prompt=prompt,
218
+ target=target,
219
+ indices_to_explain=None,
220
+ attr_mask_indices=None,
221
+ sink_span=None,
222
+ thinking_span=None,
223
+ metadata={
224
+ "dataset": "math_mine",
225
+ "source_path": source_path,
226
+ "example_idx": int(example_idx),
227
+ "raw_question": question,
228
+ "raw_answer": answer,
229
+ "reference_answer": final_answer,
230
+ "boxed_answer": final_answer,
231
+ },
232
+ )
233
+ example = attach_spans_from_answer(example, tokenizer, final_answer)
234
+ if not _is_token_span(example.sink_span):
235
+ return None
236
+
237
+ # exp2 requires token-level indices_to_explain=[start_tok,end_tok] (closed interval).
238
+ indices_to_explain = list(example.sink_span)
239
+ thinking_span = example.thinking_span
240
+ if thinking_span is not None and _is_token_span(thinking_span) and indices_to_explain[0] == 0:
241
+ # No room for "thinking" tokens; avoid overlapping spans.
242
+ thinking_span = None
243
+
244
+ return CachedExample(
245
+ prompt=example.prompt,
246
+ target=example.target,
247
+ indices_to_explain=indices_to_explain,
248
+ attr_mask_indices=example.attr_mask_indices,
249
+ sink_span=indices_to_explain,
250
+ thinking_span=thinking_span,
251
+ metadata=example.metadata,
252
+ )
253
+
254
+
255
+ def _build_resampled_example(
256
+ *,
257
+ question: str,
258
+ raw_answer: str,
259
+ reference_answer: str,
260
+ generation: str,
261
+ tokenizer,
262
+ example_idx: int,
263
+ source_path: str,
264
+ judge_response: str,
265
+ generator_model: str,
266
+ judge_model: str,
267
+ ) -> Optional[CachedExample]:
268
+ parsed = split_boxed_generation(generation)
269
+ if not parsed:
270
+ return None
271
+
272
+ thinking_text, _boxed_segment, boxed_answer = parsed
273
+ target_text = f"{thinking_text}\n{boxed_answer}" if thinking_text else boxed_answer
274
+
275
+ example = CachedExample(
276
+ prompt=question.strip(),
277
+ target=target_text,
278
+ indices_to_explain=None,
279
+ attr_mask_indices=None,
280
+ sink_span=None,
281
+ thinking_span=None,
282
+ metadata={
283
+ "dataset": "math_mine",
284
+ "source_path": source_path,
285
+ "example_idx": int(example_idx),
286
+ "raw_question": question,
287
+ "raw_answer": raw_answer,
288
+ "reference_answer": reference_answer,
289
+ "judge_response": judge_response,
290
+ "generator_model": generator_model,
291
+ "judge_model": judge_model,
292
+ },
293
+ )
294
+ example = attach_spans_from_answer(example, tokenizer, boxed_answer)
295
+ if not _is_token_span(example.sink_span):
296
+ return None
297
+
298
+ indices_to_explain = list(example.sink_span)
299
+ return CachedExample(
300
+ prompt=example.prompt,
301
+ target=example.target,
302
+ indices_to_explain=indices_to_explain,
303
+ attr_mask_indices=example.attr_mask_indices,
304
+ sink_span=indices_to_explain,
305
+ thinking_span=example.thinking_span,
306
+ metadata=example.metadata,
307
+ )
308
+
309
+
310
+ def _write_jsonl(path: Path, *, examples) -> int:
311
+ path.parent.mkdir(parents=True, exist_ok=True)
312
+ count = 0
313
+ with path.open("w", encoding="utf-8") as f:
314
+ for ex in examples:
315
+ f.write(json.dumps(asdict(ex), ensure_ascii=False) + "\n")
316
+ count += 1
317
+ return count
318
+
319
+
320
+ def main() -> None:
321
+ ap = argparse.ArgumentParser("Prepare data/math_mine.json for exp2 cached JSONL.")
322
+ ap.add_argument("--in_json", type=str, default="data/math_mine.json")
323
+ ap.add_argument("--out_jsonl", type=str, default="exp/exp2/data/math.jsonl")
324
+ ap.add_argument(
325
+ "--tokenizer_model",
326
+ type=str,
327
+ required=True,
328
+ help="Tokenizer name or local path; must match the tokenizer used in exp2 attribution.",
329
+ )
330
+ ap.add_argument(
331
+ "--mode",
332
+ type=str,
333
+ choices=["map", "resample"],
334
+ default="map",
335
+ help="map=offline mapping from GSM8K answers; resample=generate+judge like exp/exp2/sample_and_filter.py.",
336
+ )
337
+
338
+ # Resample (online) options (kept compatible with exp/exp2/sample_and_filter.py).
339
+ ap.add_argument("--max_examples", type=int, default=100, help="Number of judge=True examples to keep (resample mode).")
340
+ ap.add_argument("--seed", type=int, default=42, help="Shuffle seed (only used with --shuffle).")
341
+ ap.add_argument("--shuffle", action="store_true", help="Shuffle examples before attempting (resample mode).")
342
+ ap.add_argument("--api_base", type=str, default="http://localhost:4000/v1", help="Chat API base URL.")
343
+ ap.add_argument("--api_key", type=str, default=None, help="API key; defaults to FLASHTRACE_API_KEY/OPENAI_API_KEY.")
344
+ ap.add_argument("--generator_model", type=str, default="qwen3-235b-a22b-2507")
345
+ ap.add_argument("--judge_model", type=str, default="deepseek-v3-1-terminus")
346
+ ap.add_argument("--api_timeout", type=int, default=300)
347
+ ap.add_argument("--api_max_tokens", type=int, default=8192)
348
+ ap.add_argument("--api_temperature", type=float, default=0.0)
349
+ ap.add_argument("--api_cache_ttl", type=int, default=600)
350
+ ap.add_argument("--api_cache_namespace", type=str, default="flashtrace-exp2")
351
+ ap.add_argument("--retry_delay", type=float, default=2.0)
352
+ ap.add_argument("--retries", type=int, default=2, help="Additional retries on API failure.")
353
+ ap.add_argument("--request_interval", type=float, default=1.0, help="Sleep seconds between generation calls.")
354
+ ap.add_argument("--judge_interval", type=float, default=1.0, help="Sleep seconds between judge calls.")
355
+ ap.add_argument("--rate_limit_delay", type=float, default=5.0, help="Seconds to wait on HTTP 429 before retrying.")
356
+ args = ap.parse_args()
357
+
358
+ in_path = Path(args.in_json)
359
+ out_path = Path(args.out_jsonl)
360
+ tokenizer = _load_tokenizer(args.tokenizer_model)
361
+
362
+ raw = json.loads(in_path.read_text(encoding="utf-8"))
363
+ if not isinstance(raw, list):
364
+ raise SystemExit(f"Expected a JSON array in {in_path}, got {type(raw).__name__}.")
365
+
366
+ source_total = len(raw)
367
+ total = 0
368
+ kept = 0
369
+ skipped_empty_q = 0
370
+ skipped_empty_a = 0
371
+ skipped_parse = 0
372
+ skipped_span = 0
373
+
374
+ examples = []
375
+ if args.mode == "map":
376
+ attempted = None
377
+ skipped_format = None
378
+ judged_false = None
379
+ for idx, item in enumerate(raw):
380
+ total += 1
381
+ if not isinstance(item, dict):
382
+ skipped_parse += 1
383
+ continue
384
+
385
+ question = str(item.get("question") or "")
386
+ answer = str(item.get("answer") or "")
387
+ if not question.strip():
388
+ skipped_empty_q += 1
389
+ continue
390
+ if not answer.strip():
391
+ skipped_empty_a += 1
392
+ continue
393
+
394
+ ex = _build_cached_example(
395
+ question=question,
396
+ answer=answer,
397
+ tokenizer=tokenizer,
398
+ example_idx=idx,
399
+ source_path=str(in_path),
400
+ )
401
+ if ex is None:
402
+ # distinguish parse-vs-span failure
403
+ parsed = _split_gsm8k_answer(answer)
404
+ if parsed is None:
405
+ skipped_parse += 1
406
+ else:
407
+ skipped_span += 1
408
+ continue
409
+
410
+ examples.append(ex)
411
+ kept += 1
412
+ else:
413
+ api_key = args.api_key or os.environ.get("FLASHTRACE_API_KEY") or os.environ.get("OPENAI_API_KEY")
414
+ if not api_key:
415
+ raise SystemExit("resample mode requires --api_key or FLASHTRACE_API_KEY/OPENAI_API_KEY.")
416
+
417
+ attempted = 0
418
+ skipped_format = 0
419
+ judged_false = 0
420
+
421
+ indices = list(range(len(raw)))
422
+ if bool(args.shuffle):
423
+ import random
424
+
425
+ rnd = random.Random(int(args.seed))
426
+ rnd.shuffle(indices)
427
+
428
+ kept_bar = tqdm(total=int(args.max_examples), desc="Kept (judge=True)", position=1, leave=False)
429
+ for loop_idx in tqdm(indices, total=len(indices), desc="Resampling"):
430
+ if kept >= int(args.max_examples):
431
+ break
432
+
433
+ total += 1
434
+ item = raw[loop_idx]
435
+ if not isinstance(item, dict):
436
+ skipped_parse += 1
437
+ continue
438
+
439
+ question = str(item.get("question") or "")
440
+ answer = str(item.get("answer") or "")
441
+ if not question.strip():
442
+ skipped_empty_q += 1
443
+ continue
444
+ if not answer.strip():
445
+ skipped_empty_a += 1
446
+ continue
447
+
448
+ parsed = _split_gsm8k_answer(answer)
449
+ if parsed is None:
450
+ skipped_parse += 1
451
+ continue
452
+ _ref_thinking, reference_answer = parsed
453
+
454
+ attempted += 1
455
+ gen_messages = build_gen_messages(question.strip())
456
+
457
+ # Step 1: generation
458
+ for attempt in range(int(args.retries) + 1):
459
+ try:
460
+ generation = call_chat_api(
461
+ str(args.api_base),
462
+ str(api_key),
463
+ str(args.generator_model),
464
+ gen_messages,
465
+ timeout=int(args.api_timeout),
466
+ max_tokens=int(args.api_max_tokens),
467
+ temperature=float(args.api_temperature),
468
+ cache_ttl=int(args.api_cache_ttl),
469
+ cache_namespace=str(args.api_cache_namespace) if args.api_cache_namespace else None,
470
+ rate_limit_delay=float(args.rate_limit_delay) if args.rate_limit_delay is not None else None,
471
+ )
472
+ break
473
+ except RateLimitError as e:
474
+ if attempt >= int(args.retries):
475
+ raise
476
+ time.sleep(float(e.wait_seconds))
477
+ except Exception: # noqa: BLE001
478
+ if attempt >= int(args.retries):
479
+ raise
480
+ time.sleep(float(args.retry_delay))
481
+ if float(args.request_interval) > 0:
482
+ time.sleep(float(args.request_interval))
483
+
484
+ parsed_gen = split_boxed_generation(generation)
485
+ if not parsed_gen:
486
+ skipped_format += 1
487
+ print(f"[attempt={attempted}] skipped=format")
488
+ continue
489
+
490
+ thinking_text, _boxed_segment, boxed_answer = parsed_gen
491
+ judge_messages = build_judge_messages(reference_answer, boxed_answer)
492
+
493
+ ok = False
494
+ judge_resp = ""
495
+ for attempt in range(int(args.retries) + 1):
496
+ try:
497
+ judge_resp = call_chat_api(
498
+ str(args.api_base),
499
+ str(api_key),
500
+ str(args.judge_model),
501
+ judge_messages,
502
+ timeout=int(args.api_timeout),
503
+ max_tokens=64,
504
+ temperature=0.0,
505
+ cache_ttl=int(args.api_cache_ttl),
506
+ cache_namespace=str(args.api_cache_namespace) if args.api_cache_namespace else None,
507
+ rate_limit_delay=float(args.rate_limit_delay) if args.rate_limit_delay is not None else None,
508
+ )
509
+ ok = parse_bool(judge_resp)
510
+ break
511
+ except RateLimitError as e:
512
+ if attempt >= int(args.retries):
513
+ raise
514
+ time.sleep(float(e.wait_seconds))
515
+ except Exception: # noqa: BLE001
516
+ if attempt >= int(args.retries):
517
+ raise
518
+ time.sleep(float(args.retry_delay))
519
+ if float(args.judge_interval) > 0:
520
+ time.sleep(float(args.judge_interval))
521
+
522
+ if not ok:
523
+ judged_false += 1
524
+ print(f"[attempt={attempted}] judge=filtered")
525
+ continue
526
+
527
+ ex = _build_resampled_example(
528
+ question=question,
529
+ raw_answer=answer,
530
+ reference_answer=reference_answer,
531
+ generation=generation,
532
+ tokenizer=tokenizer,
533
+ example_idx=int(loop_idx),
534
+ source_path=str(in_path),
535
+ judge_response=judge_resp,
536
+ generator_model=str(args.generator_model),
537
+ judge_model=str(args.judge_model),
538
+ )
539
+ if ex is None:
540
+ skipped_span += 1
541
+ print(f"[attempt={attempted}] skipped=span")
542
+ continue
543
+
544
+ examples.append(ex)
545
+ kept += 1
546
+ kept_bar.update(1)
547
+ print(f"[attempt={attempted}] judge=kept")
548
+
549
+ kept_bar.close()
550
+
551
+ written = _write_jsonl(out_path, examples=examples)
552
+ if written != kept:
553
+ raise SystemExit(f"Internal error: written={written} != kept={kept}")
554
+
555
+ print(
556
+ json.dumps(
557
+ {
558
+ "in_json": str(in_path),
559
+ "out_jsonl": str(out_path),
560
+ "tokenizer_model": args.tokenizer_model,
561
+ "mode": str(args.mode),
562
+ "source_total": int(source_total),
563
+ "visited": total,
564
+ "kept": kept,
565
+ "skipped_empty_question": skipped_empty_q,
566
+ "skipped_empty_answer": skipped_empty_a,
567
+ "skipped_parse": skipped_parse,
568
+ "skipped_span": skipped_span,
569
+ "attempted": attempted,
570
+ "skipped_format": skipped_format,
571
+ "judged_false": judged_false,
572
+ "max_examples": int(args.max_examples) if str(args.mode) == "resample" else None,
573
+ "api_base": str(args.api_base) if str(args.mode) == "resample" else None,
574
+ "generator_model": str(args.generator_model) if str(args.mode) == "resample" else None,
575
+ "judge_model": str(args.judge_model) if str(args.mode) == "resample" else None,
576
+ },
577
+ ensure_ascii=False,
578
+ indent=2,
579
+ )
580
+ )
581
+
582
+
583
+ if __name__ == "__main__":
584
+ main()
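An example invocation and the shape of one output record, both hypothetical (the tokenizer name and all field values are made up; `metadata` is abridged):

```python
# python exp/exp2/map_math_mine_to_exp2_cache.py \
#   --in_json data/math_mine.json \
#   --out_jsonl exp/exp2/data/math.jsonl \
#   --tokenizer_model Qwen/Qwen2.5-7B-Instruct \
#   --mode map
#
# Each JSONL line is a CachedExample serialized with dataclasses.asdict().
record = {
    "prompt": "Natalia sold clips to 48 of her friends ...",
    "target": "Natalia sold 48 clips in April ...\n72",
    "indices_to_explain": [57, 57],   # closed token interval, copied from sink_span
    "attr_mask_indices": None,
    "sink_span": [57, 57],            # tokens of the final answer under tokenizer(target)
    "thinking_span": [0, 56],         # everything before the answer
    "metadata": {"dataset": "math_mine", "reference_answer": "72", "boxed_answer": "72"},
}
```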
exp/exp2/migrate_indices_to_explain_token_span.py ADDED
@@ -0,0 +1,129 @@
1
+ #!/usr/bin/env python3
2
+ """Migrate exp2 cached JSONL to token-span `indices_to_explain`.
3
+
4
+ This converts legacy caches that used sentence indices (e.g. `[-2]`) into the
5
+ token-span format:
6
+
7
+ indices_to_explain = [start_tok, end_tok]
8
+
9
+ Where the span points to the boxed-inner (final answer) token span in `target`
10
+ under `tokenizer(target, add_special_tokens=False)`.
11
+
12
+ Rule:
13
+ 1) If `sink_span` exists and looks valid -> copy it to `indices_to_explain`
14
+ 2) Else try to recompute spans from `target` + `metadata.boxed_answer` using
15
+ `exp/exp2/dataset_utils.attach_spans_from_answer`
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import json
22
+ import sys
23
+ from pathlib import Path
24
+ from typing import Any, Dict, Optional
25
+
26
+ from transformers import AutoTokenizer
27
+
28
+
29
+ def _ensure_repo_root_on_path() -> None:
30
+ repo_root = Path(__file__).resolve().parents[2]
31
+ if str(repo_root) not in sys.path:
32
+ sys.path.insert(0, str(repo_root))
33
+
34
+
35
+ def _is_token_span(span: Any) -> bool:
36
+ return isinstance(span, list) and len(span) == 2 and all(isinstance(x, int) for x in span)
37
+
38
+
39
+ def _load_tokenizer(tokenizer_model: str):
40
+ tok_path = Path(tokenizer_model)
41
+ if tok_path.exists():
42
+ return AutoTokenizer.from_pretrained(tok_path.as_posix(), local_files_only=True)
43
+ return AutoTokenizer.from_pretrained(tokenizer_model)
44
+
45
+
46
+ def _migrate_obj(obj: Dict[str, Any], tokenizer) -> tuple[Dict[str, Any], bool]:
47
+ sink_span = obj.get("sink_span")
48
+ if _is_token_span(sink_span):
49
+ obj["indices_to_explain"] = sink_span
50
+ return obj, True
51
+
52
+ _ensure_repo_root_on_path()
53
+ from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer # noqa: E402
54
+
55
+ example = CachedExample(
56
+ prompt=obj.get("prompt") or "",
57
+ target=obj.get("target"),
58
+ indices_to_explain=obj.get("indices_to_explain"),
59
+ attr_mask_indices=obj.get("attr_mask_indices"),
60
+ sink_span=obj.get("sink_span"),
61
+ thinking_span=obj.get("thinking_span"),
62
+ metadata=obj.get("metadata") or {},
63
+ )
64
+ answer_text = (example.metadata.get("boxed_answer") or "").strip() or None
65
+ migrated = attach_spans_from_answer(example, tokenizer, answer_text)
66
+ if not _is_token_span(migrated.sink_span):
67
+ return obj, False
68
+
69
+ obj["sink_span"] = migrated.sink_span
70
+ obj["thinking_span"] = migrated.thinking_span
71
+ obj["indices_to_explain"] = migrated.sink_span
72
+ obj["metadata"] = migrated.metadata
73
+ return obj, True
74
+
75
+
76
+ def main() -> None:
77
+ ap = argparse.ArgumentParser()
78
+ ap.add_argument("--in_jsonl", type=str, required=True)
79
+ ap.add_argument("--out_jsonl", type=str, required=True)
80
+ ap.add_argument("--tokenizer_model", type=str, required=True)
81
+ ap.add_argument("--strict", action="store_true", help="Fail on any line that cannot be migrated.")
82
+ args = ap.parse_args()
83
+
84
+ tokenizer = _load_tokenizer(args.tokenizer_model)
85
+
86
+ in_path = Path(args.in_jsonl)
87
+ out_path = Path(args.out_jsonl)
88
+
89
+ try:
90
+ same_path = in_path.resolve() == out_path.resolve()
91
+ except FileNotFoundError:
92
+ same_path = False
93
+
94
+ tmp_out_path = out_path
95
+ if same_path:
96
+ tmp_out_path = out_path.with_name(out_path.name + ".tmp")
97
+ if tmp_out_path.exists():
98
+ tmp_out_path.unlink()
99
+
100
+ tmp_out_path.parent.mkdir(parents=True, exist_ok=True)
101
+
102
+ total = 0
103
+ migrated_ok = 0
104
+ bad = 0
105
+
106
+ with in_path.open("r", encoding="utf-8") as fin, tmp_out_path.open("w", encoding="utf-8") as fout:
107
+ for line_no, line in enumerate(fin, start=1):
108
+ if not line.strip():
109
+ continue
110
+ total += 1
111
+ obj: Dict[str, Any] = json.loads(line)
112
+ new_obj, ok = _migrate_obj(obj, tokenizer)
113
+ if ok:
114
+ migrated_ok += 1
115
+ else:
116
+ bad += 1
117
+ if args.strict:
118
+ raise RuntimeError(f"cannot migrate line {line_no}: cannot resolve sink_span token span")
119
+ fout.write(json.dumps(new_obj, ensure_ascii=False) + "\n")
120
+
121
+ if same_path:
122
+ tmp_out_path.replace(out_path)
123
+ print(f"[done] total={total} migrated_ok={migrated_ok} bad={bad} wrote={out_path} (in-place)")
124
+ else:
125
+ print(f"[done] total={total} migrated_ok={migrated_ok} bad={bad} wrote={out_path}")
126
+
127
+
128
+ if __name__ == "__main__":
129
+ main()
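A hypothetical before/after for one migrated cache line (the token indices are made up):

```python
legacy = {
    "target": "step 1 ... step 2 ...\nParis",
    "indices_to_explain": [-2],       # legacy sentence index
    "sink_span": None,
    "thinking_span": None,
    "metadata": {"boxed_answer": "Paris"},
}
# With no valid sink_span, the script recomputes spans from metadata["boxed_answer"]
# via attach_spans_from_answer and copies the answer token span into indices_to_explain:
migrated = {
    **legacy,
    "sink_span": [41, 41],
    "thinking_span": [0, 40],
    "indices_to_explain": [41, 41],
}
```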
exp/exp2/out.log ADDED
@@ -0,0 +1,102 @@
1
+ [1/500] judge=kept
2
+ [2/500] judge=kept
3
+ [3/500] judge=kept
4
+ [4/500] judge=kept
5
+ [5/500] judge=kept
6
+ [6/500] judge=kept
7
+ [7/500] judge=kept
8
+ [8/500] judge=kept
9
+ [9/500] judge=kept
10
+ [10/500] judge=kept
11
+ [11/500] judge=kept
12
+ [12/500] judge=kept
13
+ [13/500] judge=kept
14
+ [14/500] judge=kept
15
+ [15/500] judge=kept
16
+ [16/500] judge=kept
17
+ [17/500] judge=kept
18
+ [18/500] judge=kept
19
+ [19/500] judge=kept
20
+ [20/500] judge=kept
21
+ [21/500] judge=kept
22
+ [22/500] judge=kept
23
+ [23/500] judge=kept
24
+ [24/500] judge=kept
25
+ [25/500] judge=kept
26
+ [26/500] judge=kept
27
+ [27/500] judge=kept
28
+ [28/500] judge=kept
29
+ [29/500] judge=kept
30
+ [30/500] judge=kept
31
+ [31/500] judge=kept
32
+ [32/500] judge=kept
33
+ [33/500] judge=kept
34
+ [34/500] judge=kept
35
+ [35/500] judge=kept
36
+ [36/500] judge=kept
37
+ [37/500] judge=kept
38
+ [38/500] judge=kept
39
+ [39/500] judge=kept
40
+ [40/500] judge=kept
41
+ [41/500] judge=kept
42
+ [42/500] judge=kept
43
+ [43/500] judge=kept
44
+ [44/500] judge=kept
45
+ [45/500] judge=kept
46
+ [46/500] judge=kept
47
+ [47/500] judge=kept
48
+ [48/500] judge=kept
49
+ [49/500] judge=kept
50
+ [50/500] judge=kept
51
+ [51/500] judge=kept
52
+ [52/500] judge=kept
53
+ [53/500] judge=kept
54
+ [54/500] judge=kept
55
+ [55/500] judge=kept
56
+ [56/500] judge=kept
57
+ [57/500] judge=kept
58
+ [58/500] judge=kept
59
+ [59/500] judge=kept
60
+ [60/500] judge=kept
61
+ [61/500] judge=kept
62
+ [62/500] judge=kept
63
+ [63/500] skipped=format
64
+ [64/500] judge=kept
65
+ [65/500] judge=kept
66
+ [66/500] judge=kept
67
+ [67/500] judge=kept
68
+ [68/500] judge=kept
69
+ [69/500] judge=kept
70
+ [70/500] judge=kept
71
+ [71/500] judge=kept
72
+ [72/500] judge=kept
73
+ [73/500] judge=kept
74
+ [74/500] judge=kept
75
+ [75/500] judge=kept
76
+ [76/500] judge=kept
77
+ [77/500] judge=kept
78
+ [78/500] judge=kept
79
+ [79/500] judge=kept
80
+ [80/500] judge=kept
81
+ [81/500] judge=kept
82
+ [82/500] judge=kept
83
+ [83/500] judge=kept
84
+ [84/500] judge=kept
85
+ [85/500] judge=kept
86
+ [86/500] judge=kept
87
+ [87/500] judge=kept
88
+ [88/500] judge=kept
89
+ [89/500] judge=kept
90
+ [90/500] judge=kept
91
+ [91/500] judge=kept
92
+ [92/500] judge=kept
93
+ [93/500] judge=kept
94
+ [94/500] judge=kept
95
+ [95/500] judge=kept
96
+ [96/500] judge=kept
97
+ [97/500] judge=kept
98
+ [98/500] judge=kept
99
+ [99/500] judge=kept
100
+ [100/500] judge=kept
101
+ [101/500] judge=kept
102
+ Kept 100 / target 100 (attempted 101 / 500) -> exp/exp2/data/data/ruler_multihop/1024/vt_h10_c1/validation.jsonl.jsonl
exp/exp2/run_exp.py ADDED
@@ -0,0 +1,1296 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Experiment 2 runner: token-level faithfulness (generation perturbation).
4
+
5
+ AT2 is omitted.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import hashlib
12
+ import json
13
+ import os
14
+ import sys
15
+ from itertools import islice
16
+ import math
17
+ import time
18
+ from pathlib import Path
19
+ from typing import Any, Dict, List, Optional, Tuple
20
+
21
+ # Early CUDA mask handling: set CUDA_VISIBLE_DEVICES before importing torch.
22
+ def _early_set_cuda_visible_devices():
23
+ parser = argparse.ArgumentParser(add_help=False)
24
+ parser.add_argument("--cuda", type=str, default=None)
25
+ # parse_known_args keeps the full argv for later parsing by the main parser
26
+ args, _ = parser.parse_known_args(sys.argv[1:])
27
+ if args.cuda and "," in args.cuda:
28
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
29
+
30
+
31
+ _early_set_cuda_visible_devices()
32
+
33
+ import numpy as np
34
+ import torch
35
+ from transformers import AutoModelForCausalLM, AutoTokenizer, utils
36
+
37
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
38
+
39
+ from pathlib import Path
40
+
41
+ # ensure repo root on path
42
+ REPO_ROOT = Path(__file__).resolve().parents[2]
43
+ if str(REPO_ROOT) not in sys.path:
44
+ sys.path.insert(0, str(REPO_ROOT))
45
+
46
+ import llm_attr
47
+ import llm_attr_eval
48
+ from attribution_datasets import AttributionExample
49
+ from exp.exp2 import dataset_utils as ds_utils
50
+
51
+ utils.logging.set_verbosity_error()
52
+
53
+
54
+ def _sha1_text(text: str) -> str:
55
+ return hashlib.sha1(text.encode("utf-8")).hexdigest()
56
+
57
+
58
+ def _infer_attnlrp_spans_from_hops(
59
+ raw_attributions: Any,
60
+ *,
61
+ gen_len: int,
62
+ ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
63
+ if not raw_attributions:
64
+ return (0, max(0, gen_len - 1)), (0, max(0, gen_len - 1))
65
+ sink_span = tuple(int(x) for x in raw_attributions[0].sink_range)
66
+ if len(raw_attributions) >= 2:
67
+ thinking_span = tuple(int(x) for x in raw_attributions[1].sink_range)
68
+ else:
69
+ thinking_span = sink_span
70
+ return sink_span, thinking_span
71
+
72
+
73
+ def _build_hop_trace_payload(
74
+ attr_func: str,
75
+ attr: Any,
76
+ *,
77
+ indices_to_explain: List[int],
78
+ ) -> Optional[Dict[str, np.ndarray]]:
79
+ """Extract per-hop vectors (postprocessed) and minimal span metadata."""
80
+ prompt_len = int(len(getattr(attr, "prompt_tokens", []) or []))
81
+ gen_len = int(len(getattr(attr, "generation_tokens", []) or []))
82
+ total_len = prompt_len + gen_len
83
+ if total_len <= 0:
84
+ return None
85
+
86
+ hop_vectors: List[torch.Tensor] = []
87
+ sink_span_gen: Optional[Tuple[int, int]] = None
88
+ thinking_span_gen: Optional[Tuple[int, int]] = None
89
+ attnlrp_neg_handling: str = ""
90
+ attnlrp_norm_mode: str = ""
91
+ attnlrp_ratio_enabled: int = -1
92
+
93
+ # IFR multi-hop variants expose projected hop vectors via metadata["ifr"]["per_hop_projected"].
94
+ ifr_meta = (getattr(attr, "metadata", None) or {}).get("ifr") or {}
95
+ ifr_per_hop = ifr_meta.get("per_hop_projected") or []
96
+
97
+ if ifr_per_hop:
98
+ hop_vectors = [torch.as_tensor(v, dtype=torch.float32) for v in ifr_per_hop]
99
+ sink_span_gen = ifr_meta.get("sink_span_generation")
100
+ thinking_span_gen = ifr_meta.get("thinking_span_generation")
101
+ if sink_span_gen is not None:
102
+ sink_span_gen = tuple(int(x) for x in sink_span_gen)
103
+ if thinking_span_gen is not None:
104
+ thinking_span_gen = tuple(int(x) for x in thinking_span_gen)
105
+
106
+ elif attr_func in ("ft_attnlrp", "attnlrp_aggregated_multi_hop"):
107
+ meta = getattr(attr, "metadata", None) or {}
108
+ attnlrp_neg_handling = str(meta.get("neg_handling") or "")
109
+ attnlrp_norm_mode = str(meta.get("norm_mode") or "")
110
+ if meta.get("ratio_enabled") is not None:
111
+ attnlrp_ratio_enabled = int(bool(meta.get("ratio_enabled")))
112
+ multi_hop = meta.get("multi_hop_result")
113
+ if multi_hop is None:
114
+ return None
115
+ raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
116
+ if not raw_attributions:
117
+ return None
118
+ hop_vectors = [
119
+ torch.as_tensor(getattr(hop, "token_importance_total"), dtype=torch.float32)
120
+ for hop in raw_attributions
121
+ ]
122
+ sink_span_gen, thinking_span_gen = _infer_attnlrp_spans_from_hops(raw_attributions, gen_len=gen_len)
123
+ sink_override = meta.get("sink_span")
124
+ thinking_override = meta.get("thinking_span")
125
+ if sink_override is not None:
126
+ sink_span_gen = tuple(int(x) for x in sink_override)
127
+ if thinking_override is not None:
128
+ thinking_span_gen = tuple(int(x) for x in thinking_override)
129
+
130
+ else:
131
+ return None
132
+
133
+ if sink_span_gen is None:
134
+ sink_span_gen = (0, max(0, gen_len - 1))
135
+ if thinking_span_gen is None:
136
+ thinking_span_gen = sink_span_gen
137
+
138
+ stacked = torch.stack([v.reshape(-1) for v in hop_vectors], dim=0)
139
+ if stacked.shape[1] != total_len:
140
+ raise ValueError(
141
+ f"Hop vector length mismatch for {attr_func}: expected T={total_len}, got {stacked.shape[1]}."
142
+ )
143
+
144
+ return {
145
+ "vh": stacked.detach().cpu().numpy().astype(np.float32, copy=False),
146
+ "prompt_len": np.asarray(prompt_len, dtype=np.int64),
147
+ "gen_len": np.asarray(gen_len, dtype=np.int64),
148
+ "sink_span_gen": np.asarray(sink_span_gen, dtype=np.int64),
149
+ "thinking_span_gen": np.asarray(thinking_span_gen, dtype=np.int64),
150
+ "indices_to_explain_gen": np.asarray(indices_to_explain, dtype=np.int64),
151
+ "attnlrp_neg_handling": np.asarray(attnlrp_neg_handling, dtype="U16"),
152
+ "attnlrp_norm_mode": np.asarray(attnlrp_norm_mode, dtype="U16"),
153
+ "attnlrp_ratio_enabled": np.asarray(attnlrp_ratio_enabled, dtype=np.int64),
154
+ }
155
+
156
+
157
+ def _write_hop_trace(
158
+ trace_dir: Path,
159
+ *,
160
+ example_idx: int,
161
+ attr_func: str,
162
+ prompt: str,
163
+ target: Optional[str],
164
+ payload: Dict[str, np.ndarray],
165
+ manifest_handle,
166
+ ) -> None:
167
+ trace_dir.mkdir(parents=True, exist_ok=True)
168
+ npz_name = f"ex_{example_idx:06d}.npz"
169
+ npz_path = trace_dir / npz_name
170
+ np.savez_compressed(npz_path, **payload)
171
+
172
+ record = {
173
+ "example_idx": int(example_idx),
174
+ "attr_func": attr_func,
175
+ "file": npz_name,
176
+ "prompt_sha1": _sha1_text(prompt),
177
+ "target_sha1": _sha1_text(target) if target is not None else None,
178
+ "prompt_len": int(payload["prompt_len"].item()),
179
+ "gen_len": int(payload["gen_len"].item()),
180
+ "n_hops_plus_one": int(payload["vh"].shape[0]),
181
+ "total_len": int(payload["vh"].shape[1]),
182
+ "sink_span_gen": payload["sink_span_gen"].tolist(),
183
+ "thinking_span_gen": payload["thinking_span_gen"].tolist(),
184
+ "indices_to_explain_gen": payload["indices_to_explain_gen"].tolist(),
185
+ "attnlrp_neg_handling": str(payload["attnlrp_neg_handling"].item()),
186
+ "attnlrp_norm_mode": str(payload["attnlrp_norm_mode"].item()),
187
+ "attnlrp_ratio_enabled": int(payload["attnlrp_ratio_enabled"].item()),
188
+ }
189
+ manifest_handle.write(json.dumps(record, ensure_ascii=False) + "\n")
190
+ manifest_handle.flush()
191
+
192
+
193
+ def _parse_modes(mode_args: Any) -> List[str]:
194
+ """Parse --mode which may be provided as multiple args and/or comma-separated."""
195
+ if mode_args is None:
196
+ raw_parts: List[str] = []
197
+ elif isinstance(mode_args, str):
198
+ raw_parts = [mode_args]
199
+ else:
200
+ raw_parts = [str(x) for x in mode_args]
201
+
202
+ modes: List[str] = []
203
+ for chunk in raw_parts:
204
+ for part in str(chunk).split(","):
205
+ m = part.strip()
206
+ if m:
207
+ modes.append(m)
208
+
209
+ # Default to faithfulness_gen for backward compatibility.
210
+ if not modes:
211
+ modes = ["faithfulness_gen"]
212
+
213
+ allowed = {"faithfulness_gen", "recovery_ruler"}
214
+ seen: set[str] = set()
215
+ unique: List[str] = []
216
+ for m in modes:
217
+ if m not in seen:
218
+ unique.append(m)
219
+ seen.add(m)
220
+
221
+ unknown = [m for m in unique if m not in allowed]
222
+ if unknown:
223
+ raise SystemExit(f"Unsupported --mode value(s): {unknown}. Allowed: {sorted(allowed)}.")
224
+
225
+ return unique
226
+
227
+
228
+ def _trace_run_tag(
229
+ testing_dict: Dict[str, Any],
230
+ *,
231
+ modes: List[str],
232
+ total: int,
233
+ ) -> str:
234
+ attr_func = str(testing_dict.get("attr_func") or "attr")
235
+ parts = [attr_func]
236
+
237
+ if attr_func in (
238
+ "ifr_multi_hop",
239
+ "ifr_in_all_gen",
240
+ "ifr_multi_hop_stop_words",
241
+ "ifr_multi_hop_both",
242
+ "ifr_multi_hop_split_hop",
243
+ "ft_attnlrp",
244
+ "attnlrp_aggregated_multi_hop",
245
+ ):
246
+ parts.append(f"n{int(testing_dict.get('n_hops', 0))}")
247
+
248
+ if attr_func in ("attnlrp", "ft_attnlrp", "attnlrp_aggregated_multi_hop"):
249
+ parts.append(f"neg{str(testing_dict.get('attnlrp_neg_handling', ''))}")
250
+ parts.append(f"norm{str(testing_dict.get('attnlrp_norm_mode', ''))}")
251
+
252
+ if modes:
253
+ parts.append("m" + "+".join(modes))
254
+
255
+ parts.append(f"{int(total)}ex")
256
+ return "_".join(parts)
257
+
258
+
259
+ def _token_importance_vector(attr: torch.Tensor) -> np.ndarray:
260
+ """Return token importance vector w = sum_rows(attr) in shape [P+G]."""
261
+ w = torch.nan_to_num(attr.sum(0).to(dtype=torch.float32), nan=0.0).clamp(min=0.0)
262
+ return w.detach().cpu().numpy().astype(np.float32, copy=False)
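A tiny torch illustration (made-up numbers) of the reduction performed here:

```python
import torch

# Sum the [G, P+G] attribution matrix over its rows (sink positions),
# then zero NaNs and clip negatives; note that a NaN anywhere in a column
# zeroes that column's total because it propagates through the sum.
attr = torch.tensor([[0.2, -0.5, float("nan")],
                     [0.3,  0.4, 0.1]])
w = torch.nan_to_num(attr.sum(0), nan=0.0).clamp(min=0.0)
assert torch.allclose(w, torch.tensor([0.5, 0.0, 0.0]))
```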
263
+
264
+
265
+ def _build_sample_trace_payload(
266
+ example: ds_utils.CachedExample,
267
+ *,
268
+ attr_list: List[torch.Tensor],
269
+ prompt_len: int,
270
+ user_prompt_indices: Optional[List[int]],
271
+ keep_prompt_token_indices: Optional[List[int]],
272
+ gold_prompt_token_indices: Optional[List[int]],
273
+ hop_payload: Optional[Dict[str, np.ndarray]],
274
+ faithfulness_scores: Optional[np.ndarray],
275
+ recovery_scores: Optional[np.ndarray],
276
+ time_attr_s: Optional[float],
277
+ time_faith_s: Optional[float],
278
+ time_recovery_s: Optional[float],
279
+ ) -> Dict[str, np.ndarray]:
280
+ seq_attr, row_attr, rec_attr = attr_list
281
+ gen_len = int(seq_attr.shape[0])
282
+
283
+ v_seq_all = _token_importance_vector(seq_attr)
284
+ v_row_all = _token_importance_vector(row_attr)
285
+ v_rec_all = _token_importance_vector(rec_attr)
286
+
287
+ payload: Dict[str, np.ndarray] = {
288
+ "v_seq_all": v_seq_all,
289
+ "v_row_all": v_row_all,
290
+ "v_rec_all": v_rec_all,
291
+ "v_seq_prompt": v_seq_all[:prompt_len],
292
+ "v_row_prompt": v_row_all[:prompt_len],
293
+ "v_rec_prompt": v_rec_all[:prompt_len],
294
+ "prompt_len": np.asarray(int(prompt_len), dtype=np.int64),
295
+ "gen_len": np.asarray(int(gen_len), dtype=np.int64),
296
+ "indices_to_explain_gen": np.asarray(list(example.indices_to_explain or []), dtype=np.int64),
297
+ }
298
+
299
+ if example.sink_span is not None:
300
+ payload["sink_span_gen"] = np.asarray(list(example.sink_span), dtype=np.int64)
301
+ if example.thinking_span is not None:
302
+ payload["thinking_span_gen"] = np.asarray(list(example.thinking_span), dtype=np.int64)
303
+
304
+ if user_prompt_indices is not None:
305
+ payload["user_prompt_indices"] = np.asarray(list(user_prompt_indices), dtype=np.int64)
306
+ if keep_prompt_token_indices is not None:
307
+ payload["keep_prompt_token_indices"] = np.asarray(list(keep_prompt_token_indices), dtype=np.int64)
308
+ if gold_prompt_token_indices is not None:
309
+ payload["gold_prompt_token_indices"] = np.asarray(list(gold_prompt_token_indices), dtype=np.int64)
310
+
311
+ if faithfulness_scores is not None:
312
+ payload["faithfulness_scores"] = np.asarray(faithfulness_scores, dtype=np.float64)
313
+ if recovery_scores is not None:
314
+ payload["recovery_scores"] = np.asarray(recovery_scores, dtype=np.float64)
315
+
316
+ if time_attr_s is not None:
317
+ payload["time_attr_s"] = np.asarray(float(time_attr_s), dtype=np.float64)
318
+ if time_faith_s is not None:
319
+ payload["time_faith_s"] = np.asarray(float(time_faith_s), dtype=np.float64)
320
+ if time_recovery_s is not None:
321
+ payload["time_recovery_s"] = np.asarray(float(time_recovery_s), dtype=np.float64)
322
+
323
+ if hop_payload is not None:
324
+ for k, v in hop_payload.items():
325
+ if k in payload:
326
+ continue
327
+ payload[k] = v
328
+
329
+ return payload
330
+
331
+
332
+ def _write_sample_trace(
333
+ trace_dir: Path,
334
+ *,
335
+ example_idx: int,
336
+ attr_func: str,
337
+ prompt: str,
338
+ target: Optional[str],
339
+ payload: Dict[str, np.ndarray],
340
+ manifest_handle,
341
+ recovery_skipped_reason: Optional[str],
342
+ ) -> None:
343
+ trace_dir.mkdir(parents=True, exist_ok=True)
344
+ npz_name = f"ex_{example_idx:06d}.npz"
345
+ npz_path = trace_dir / npz_name
346
+ np.savez_compressed(npz_path, **payload)
347
+
348
+ prompt_len = int(np.asarray(payload.get("prompt_len", 0)).item())
349
+ gen_len = int(np.asarray(payload.get("gen_len", 0)).item())
350
+ record: Dict[str, Any] = {
351
+ "example_idx": int(example_idx),
352
+ "attr_func": attr_func,
353
+ "file": npz_name,
354
+ "prompt_sha1": _sha1_text(prompt),
355
+ "target_sha1": _sha1_text(target) if target is not None else None,
356
+ "prompt_len": prompt_len,
357
+ "gen_len": gen_len,
358
+ "indices_to_explain_gen": payload.get("indices_to_explain_gen").tolist()
359
+ if payload.get("indices_to_explain_gen") is not None
360
+ else None,
361
+ "sink_span_gen": payload.get("sink_span_gen").tolist() if payload.get("sink_span_gen") is not None else None,
362
+ "thinking_span_gen": payload.get("thinking_span_gen").tolist()
363
+ if payload.get("thinking_span_gen") is not None
364
+ else None,
365
+ "faithfulness_scores": payload.get("faithfulness_scores").tolist()
366
+ if payload.get("faithfulness_scores") is not None
367
+ else None,
368
+ "recovery_scores": payload.get("recovery_scores").tolist() if payload.get("recovery_scores") is not None else None,
369
+ "recovery_skipped_reason": recovery_skipped_reason,
370
+ "time_attr_s": float(np.asarray(payload.get("time_attr_s")).item()) if payload.get("time_attr_s") is not None else None,
371
+ "time_faith_s": float(np.asarray(payload.get("time_faith_s")).item()) if payload.get("time_faith_s") is not None else None,
372
+ "time_recovery_s": float(np.asarray(payload.get("time_recovery_s")).item())
373
+ if payload.get("time_recovery_s") is not None
374
+ else None,
375
+ }
376
+
377
+ # Derived, sample-level bookkeeping (token lengths and per-sample MAS/RISE).
378
+ record["input_len"] = int(prompt_len)
379
+
380
+ sink_span = record.get("sink_span_gen")
381
+ if isinstance(sink_span, list) and len(sink_span) == 2:
382
+ try:
383
+ start = int(sink_span[0])
384
+ end = int(sink_span[1])
385
+ record["output_len"] = (end - start + 1) if end >= start else None
386
+ except Exception:
387
+ record["output_len"] = None
388
+ else:
389
+ record["output_len"] = None
390
+
391
+ thinking_span = record.get("thinking_span_gen")
392
+ if isinstance(thinking_span, list) and len(thinking_span) == 2:
393
+ try:
394
+ start = int(thinking_span[0])
395
+ end = int(thinking_span[1])
396
+ record["cot_len"] = (end - start + 1) if end >= start else None
397
+ except Exception:
398
+ record["cot_len"] = None
399
+ else:
400
+ record["cot_len"] = None
401
+
402
+ record["rise_seq"] = None
403
+ record["mas_seq"] = None
404
+ record["rise_row"] = None
405
+ record["mas_row"] = None
406
+ record["rise_rec"] = None
407
+ record["mas_rec"] = None
408
+ faith = record.get("faithfulness_scores")
409
+ if isinstance(faith, list) and len(faith) == 3:
410
+ try:
411
+ record["rise_seq"] = float(faith[0][0])
412
+ record["mas_seq"] = float(faith[0][1])
413
+ record["rise_row"] = float(faith[1][0])
414
+ record["mas_row"] = float(faith[1][1])
415
+ record["rise_rec"] = float(faith[2][0])
416
+ record["mas_rec"] = float(faith[2][1])
417
+ except Exception:
418
+ pass
419
+
420
+ if payload.get("vh") is not None:
421
+ vh = payload["vh"]
422
+ record["n_hops_plus_one"] = int(vh.shape[0])
423
+ record["total_len"] = int(vh.shape[1])
424
+ record["attnlrp_neg_handling"] = str(payload.get("attnlrp_neg_handling").item()) if payload.get("attnlrp_neg_handling") is not None else ""
425
+ record["attnlrp_norm_mode"] = str(payload.get("attnlrp_norm_mode").item()) if payload.get("attnlrp_norm_mode") is not None else ""
426
+ record["attnlrp_ratio_enabled"] = int(payload.get("attnlrp_ratio_enabled").item()) if payload.get("attnlrp_ratio_enabled") is not None else -1
427
+
428
+ manifest_handle.write(json.dumps(record, ensure_ascii=False) + "\n")
429
+ manifest_handle.flush()
430
+
431
+
432
+ def _compute_faithfulness_scores(
433
+ testing_dict: Dict[str, Any],
434
+ *,
435
+ attr_list: List[torch.Tensor],
436
+ prompt_len: int,
437
+ prompt: str,
438
+ generation: str,
439
+ llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
440
+ user_prompt_indices: Optional[List[int]],
441
+ keep_prompt_token_indices: Optional[List[int]],
442
+ ) -> np.ndarray:
443
+ attr_func = str(testing_dict.get("attr_func") or "")
444
+ results: List[Tuple[float, float, float]] = []
445
+ for attr in attr_list:
446
+ attr_prompt = attr[:, :prompt_len]
447
+ if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and keep_prompt_token_indices is not None:
448
+ import ft_ifr_improve
449
+
450
+ scores = ft_ifr_improve.faithfulness_test_skip_tokens(
451
+ llm_evaluator,
452
+ attr_prompt,
453
+ prompt,
454
+ generation,
455
+ keep_prompt_token_indices=keep_prompt_token_indices,
456
+ user_prompt_indices=user_prompt_indices,
457
+ )
458
+ elif user_prompt_indices is not None:
459
+ scores = _faithfulness_test_with_user_prompt_indices(
460
+ llm_evaluator,
461
+ attr_prompt,
462
+ prompt,
463
+ generation,
464
+ user_prompt_indices=user_prompt_indices,
465
+ )
466
+ else:
467
+ scores = llm_evaluator.faithfulness_test(attr_prompt, prompt, generation)
468
+ results.append(scores)
469
+ return np.asarray(results, dtype=np.float64)
470
+
471
+
472
+ def _compute_recovery_scores(
473
+ testing_dict: Dict[str, Any],
474
+ *,
475
+ attr_list: List[torch.Tensor],
476
+ prompt_len: int,
477
+ gold_prompt_token_indices: List[int],
478
+ llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
479
+ keep_prompt_token_indices: Optional[List[int]],
480
+ ) -> Tuple[Optional[np.ndarray], Optional[str]]:
481
+ attr_func = str(testing_dict.get("attr_func") or "")
482
+
483
+ if prompt_len <= 0:
484
+ return None, "empty_prompt_len"
485
+
486
+ gold_prompt = [int(x) for x in (gold_prompt_token_indices or [])]
487
+ if not gold_prompt:
488
+ return None, "empty_gold_prompt"
489
+
490
+ if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and keep_prompt_token_indices is not None:
491
+ import ft_ifr_improve
492
+
493
+ keep_set = {int(x) for x in keep_prompt_token_indices}
494
+ gold_filtered = [idx for idx in gold_prompt if int(idx) in keep_set]
495
+ if not gold_filtered:
496
+ return None, "empty_gold_after_keep_filter"
497
+
498
+ scores = [
499
+ ft_ifr_improve.evaluate_attr_recovery_skip_tokens(
500
+ attr[:, :prompt_len],
501
+ keep_prompt_token_indices=keep_prompt_token_indices,
502
+ gold_prompt_token_indices=gold_prompt,
503
+ top_fraction=0.1,
504
+ )
505
+ for attr in attr_list
506
+ ]
507
+ else:
508
+ scores = [
509
+ llm_evaluator.evaluate_attr_recovery(
510
+ attr,
511
+ prompt_len=prompt_len,
512
+ gold_prompt_token_indices=gold_prompt,
513
+ top_fraction=0.1,
514
+ )
515
+ for attr in attr_list
516
+ ]
517
+
518
+ return np.asarray(scores, dtype=np.float64), None
519
+
520
+
521
+ def evaluate_dataset_multi(
522
+ args,
523
+ dataset_name: str,
524
+ examples: List[ds_utils.CachedExample],
525
+ testing_dict: Dict[str, Any],
526
+ *,
527
+ modes: List[str],
528
+ ) -> Dict[str, Any]:
529
+ tokenizer = testing_dict["tokenizer"]
530
+ llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(testing_dict["model"], tokenizer)
531
+
532
+ want_faith = "faithfulness_gen" in modes
533
+ want_recovery = "recovery_ruler" in modes
534
+
535
+ faith_results: List[np.ndarray] = []
536
+ faith_durations: List[float] = []
537
+
538
+ recovery_results: List[np.ndarray] = []
539
+ recovery_attr_durations: List[float] = []
540
+ recovery_skipped = 0
541
+
542
+ total = min(len(examples), args.num_examples)
543
+ iterator = islice(examples, total)
544
+
545
+ save_traces = bool(getattr(args, "save_hop_traces", False))
546
+ manifest_handle = None
547
+ trace_dir: Optional[Path] = None
548
+ if save_traces:
549
+ model_tag = str(testing_dict.get("model_tag", "model"))
550
+ run_tag = _trace_run_tag(testing_dict, modes=modes, total=total)
551
+ trace_dir = Path(args.output_root) / "traces" / dataset_name / model_tag / run_tag
552
+ trace_dir.mkdir(parents=True, exist_ok=True)
553
+ manifest_handle = open(trace_dir / "manifest.jsonl", "w", encoding="utf-8")
554
+
555
+ try:
556
+ for example_idx, ex in enumerate(iterator):
557
+ if want_recovery:
558
+ needle_spans = (ex.metadata or {}).get("needle_spans")
559
+ if not isinstance(needle_spans, list) or not needle_spans:
560
+ raise SystemExit(
561
+ "recovery_ruler requires RULER samples with metadata.needle_spans; "
562
+ f"dataset={dataset_name} has missing/empty needle_spans."
563
+ )
564
+ if ex.target is None:
565
+ raise SystemExit(
566
+ "recovery_ruler requires cached targets (CoT+answer) so row/rec attribution is well-defined. "
567
+ f"dataset={dataset_name} has target=None; run exp/exp2/sample_and_filter.py first."
568
+ )
569
+
570
+ # Determine generation/target once.
571
+ target = ex.target
572
+ if target is None:
573
+ generation, full_output = llm_evaluator.response(ex.prompt)
574
+ target = generation
575
+ response_len = len(tokenizer(full_output).input_ids)
576
+ else:
577
+ response_len = len(tokenizer(llm_evaluator.format_prompt(" " + ex.prompt) + target).input_ids)
578
+
579
+ testing_dict["batch_size"] = max(1, math.floor((testing_dict["max_input_len"] - 100) / max(1, response_len)))
580
+
581
+ gold_prompt: Optional[List[int]] = None
582
+ if want_recovery:
583
+ gold_prompt = ds_utils.ruler_gold_prompt_token_indices(ex, tokenizer)
584
+
585
+ if want_recovery and not want_faith and not save_traces:
586
+ # Preserve recovery-only fast path when not saving traces: skip samples with empty gold.
587
+ if not gold_prompt:
588
+ recovery_skipped += 1
589
+ continue
590
+
591
+ time_attr_s = None
592
+ time_faith_s = None
593
+ time_recovery_s = None
594
+
595
+ t0 = time.perf_counter()
596
+ attr_list, hop_payload, user_prompt_indices, keep_prompt_token_indices = run_attribution(testing_dict, ex, target)
597
+ time_attr_s = time.perf_counter() - t0
598
+
599
+ seq_attr = attr_list[0]
600
+ prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0]) # cols=(P+G), rows=G
601
+
602
+ if want_recovery and gold_prompt:
603
+ recovery_attr_durations.append(float(time_attr_s))
604
+
605
+ faith_scores = None
606
+ if want_faith:
607
+ t1 = time.perf_counter()
608
+ faith_scores = _compute_faithfulness_scores(
609
+ testing_dict,
610
+ attr_list=attr_list,
611
+ prompt_len=prompt_len,
612
+ prompt=ex.prompt,
613
+ generation=target,
614
+ llm_evaluator=llm_evaluator,
615
+ user_prompt_indices=user_prompt_indices,
616
+ keep_prompt_token_indices=keep_prompt_token_indices,
617
+ )
618
+ time_faith_s = time.perf_counter() - t1
619
+ faith_results.append(faith_scores)
620
+ faith_durations.append(float(time_attr_s))
621
+
622
+ recovery_scores = None
623
+ recovery_skip_reason = None
624
+ if want_recovery:
625
+ if not gold_prompt:
626
+ recovery_skip_reason = "empty_gold_prompt"
627
+ recovery_skipped += 1
628
+ else:
629
+ t2 = time.perf_counter()
630
+ recovery_scores, recovery_skip_reason = _compute_recovery_scores(
631
+ testing_dict,
632
+ attr_list=attr_list,
633
+ prompt_len=prompt_len,
634
+ gold_prompt_token_indices=gold_prompt,
635
+ llm_evaluator=llm_evaluator,
636
+ keep_prompt_token_indices=keep_prompt_token_indices,
637
+ )
638
+ time_recovery_s = time.perf_counter() - t2
639
+ if recovery_scores is None:
640
+ recovery_skipped += 1
641
+ else:
642
+ recovery_results.append(recovery_scores)
643
+
644
+ if manifest_handle is not None and trace_dir is not None:
645
+ try:
646
+ payload = _build_sample_trace_payload(
647
+ ex,
648
+ attr_list=attr_list,
649
+ prompt_len=prompt_len,
650
+ user_prompt_indices=user_prompt_indices,
651
+ keep_prompt_token_indices=keep_prompt_token_indices,
652
+ gold_prompt_token_indices=gold_prompt,
653
+ hop_payload=hop_payload,
654
+ faithfulness_scores=faith_scores,
655
+ recovery_scores=recovery_scores,
656
+ time_attr_s=time_attr_s,
657
+ time_faith_s=time_faith_s,
658
+ time_recovery_s=time_recovery_s,
659
+ )
660
+ _write_sample_trace(
661
+ trace_dir,
662
+ example_idx=example_idx,
663
+ attr_func=str(testing_dict.get("attr_func") or ""),
664
+ prompt=ex.prompt,
665
+ target=target,
666
+ payload=payload,
667
+ manifest_handle=manifest_handle,
668
+ recovery_skipped_reason=recovery_skip_reason,
669
+ )
670
+ except Exception as exc:
671
+ print(f"[warn] sample trace save failed for {testing_dict.get('attr_func')} ex={example_idx}: {exc}")
672
+ finally:
673
+ if manifest_handle is not None:
674
+ try:
675
+ manifest_handle.close()
676
+ except Exception:
677
+ pass
678
+
679
+ out: Dict[str, Any] = {}
680
+ if want_faith:
681
+ if not faith_results:
682
+ out["faithfulness"] = None
683
+ else:
684
+ scores = np.stack(faith_results, axis=0) # [N, 3, 3]
685
+ out["faithfulness"] = {
686
+ "mean": scores.mean(0),
687
+ "std": scores.std(0),
688
+ "avg_time": float(np.mean(faith_durations)) if faith_durations else 0.0,
689
+ }
690
+ if want_recovery:
691
+ if not recovery_results:
692
+ out["recovery"] = None
693
+ else:
694
+ scores = np.stack(recovery_results, axis=0) # [N, 3]
695
+ out["recovery"] = {
696
+ "mean": scores.mean(0),
697
+ "std": scores.std(0),
698
+ "avg_time": float(np.mean(recovery_attr_durations)) if recovery_attr_durations else 0.0,
699
+ "used": int(scores.shape[0]),
700
+ "skipped": int(recovery_skipped),
701
+ }
702
+
703
+ return out
704
+
705
+
706
+ def _faithfulness_test_with_user_prompt_indices(
707
+ llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
708
+ attribution: torch.Tensor,
709
+ prompt: str,
710
+ generation: str,
711
+ *,
712
+ user_prompt_indices: List[int],
713
+ k: int = 20,  # number of MAS/RISE deletion steps per sample
714
+ ) -> Tuple[float, float, float]:
715
+ """Token-level MAS/RISE faithfulness via guided deletion in k perturbation steps using provided prompt indices.
716
+
717
+ This mirrors llm_attr_eval.LLMAttributionEvaluator.faithfulness_test, but avoids
718
+ locating the user prompt span via token-id subsequence matching (which may fail
719
+ for some tokenizers due to non-compositional BPE merges at template boundaries).
720
+ """
721
+
722
+ def auc(arr: np.ndarray) -> float:
723
+ return (arr.sum() - arr[0] / 2 - arr[-1] / 2) / max(1, (arr.shape[0] - 1))
724
+
725
+ pad_token_id = llm_evaluator._ensure_pad_token_id()
726
+
727
+ user_prompt = " " + prompt
728
+ formatted_prompt = llm_evaluator.format_prompt(user_prompt)
729
+ formatted_ids = llm_evaluator.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).input_ids
730
+
731
+ prompt_ids = formatted_ids.to(llm_evaluator.device)
732
+ prompt_ids_perturbed = prompt_ids.clone()
733
+ generation_ids = llm_evaluator.tokenizer(
734
+ generation + llm_evaluator.tokenizer.eos_token,
735
+ return_tensors="pt",
736
+ add_special_tokens=False,
737
+ ).input_ids.to(llm_evaluator.device)
738
+
739
+ attr_cpu = attribution.detach().cpu()
740
+ w = attr_cpu.sum(0)
741
+ sorted_attr_indices = torch.argsort(w, descending=True)
742
+ attr_sum = float(w.sum().item())
743
+
744
+ P = int(w.numel())
745
+ if len(user_prompt_indices) != P:
746
+ raise ValueError(
747
+ "user_prompt_indices length does not match prompt-side attribution length: "
748
+ f"indices P={len(user_prompt_indices)}, attr P={P}."
749
+ )
750
+ if P == 0:
751
+ return 0.0, 0.0, 0.0
752
+
753
+ if max(user_prompt_indices) >= int(prompt_ids_perturbed.shape[1]):
754
+ raise ValueError("user_prompt_indices contains an out-of-bounds index for formatted prompt ids.")
755
+
756
+ if P > 0:
757
+ steps = int(k) if k is not None else 0
758
+ if steps <= 0:
759
+ steps = 1
760
+ steps = min(steps, P)
761
+ else:
762
+ steps = 0
763
+
764
+ scores = np.zeros(steps + 1, dtype=np.float64)
765
+ density = np.zeros(steps + 1, dtype=np.float64)
766
+
767
+ scores[0] = (
768
+ llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
769
+ )
770
+ density[0] = 1.0
771
+
772
+ if attr_sum <= 0:
773
+ density = np.linspace(1.0, 0.0, steps + 1)
774
+
775
+ base = P // steps
776
+ remainder = P % steps
777
+ start = 0
778
+ for step in range(steps):
779
+ size = base + (1 if step < remainder else 0)
780
+ group = sorted_attr_indices[start : start + size]
781
+ start += size
782
+
783
+ for idx in group:
784
+ j = int(idx.item())
785
+ abs_pos = int(user_prompt_indices[j])
786
+ prompt_ids_perturbed[0, abs_pos] = pad_token_id
787
+ scores[step + 1] = (
788
+ llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
789
+ )
790
+ if attr_sum > 0:
791
+ dec = float(w.index_select(0, group).sum().item()) / attr_sum
792
+ density[step + 1] = density[step] - dec
793
+
794
+ min_normalized_pred = 1.0
795
+ normalized_model_response = scores.copy()
796
+ for i in range(len(scores)):
797
+ normalized_pred = (normalized_model_response[i] - scores[-1]) / (abs(scores[0] - scores[-1]))
798
+ normalized_pred = np.clip(normalized_pred, 0.0, 1.0)
799
+ min_normalized_pred = min(min_normalized_pred, normalized_pred)
800
+ normalized_model_response[i] = min_normalized_pred
801
+
802
+ alignment_penalty = np.abs(normalized_model_response - density)
803
+ corrected_scores = normalized_model_response + alignment_penalty
804
+ corrected_scores = corrected_scores.clip(0.0, 1.0)
805
+ corrected_scores = (corrected_scores - np.min(corrected_scores)) / (np.max(corrected_scores) - np.min(corrected_scores))
806
+
807
+ if np.isnan(corrected_scores).any():
808
+ corrected_scores = np.linspace(1.0, 0.0, len(scores))
809
+
810
+ return auc(normalized_model_response), auc(corrected_scores), auc(normalized_model_response + alignment_penalty)
811
+
812
+
813
+ def load_model(model_name: str, device: str):
814
+ model = AutoModelForCausalLM.from_pretrained(
815
+ model_name,
816
+ device_map="auto" if device == "auto" else {"": int(device.split(":")[1])} if device.startswith("cuda:") else None,
817
+ torch_dtype=torch.float16,
818
+ attn_implementation="eager",
819
+ )
820
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
821
+ tokenizer.pad_token = tokenizer.eos_token
822
+ model.eval()
823
+ return model, tokenizer
824
+
825
+
826
+ def resolve_device(args) -> str:
827
+ if args.cuda is not None and "," in args.cuda:
828
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
829
+ return "auto"
830
+ if args.cuda is not None and args.cuda.strip():
831
+ return f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
832
+ return f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"
833
+
834
+
835
+ def run_attribution(
836
+ testing_dict, example: ds_utils.CachedExample, target: Optional[str]
837
+ ) -> Tuple[List[torch.Tensor], Optional[Dict[str, np.ndarray]], Optional[List[int]], Optional[List[int]]]:
838
+ model = testing_dict["model"]
839
+ tokenizer = testing_dict["tokenizer"]
840
+ attr_func = testing_dict["attr_func"]
841
+
842
+ indices_to_explain = example.indices_to_explain
843
+ if not (isinstance(indices_to_explain, list) and len(indices_to_explain) == 2):
844
+ raise ValueError(
845
+ "exp2 requires token-span indices_to_explain=[start_tok,end_tok]. "
846
+ "Please re-sample or run exp/exp2/migrate_indices_to_explain_token_span.py on your cache."
847
+ )
848
+
849
+ llm_attributor = None
850
+ if "IG" in attr_func:
851
+ llm_attributor = llm_attr.LLMGradientAttribtion(model, tokenizer)
852
+ attr = llm_attributor.calculate_IG_per_generation(
853
+ example.prompt,
854
+ 20,
855
+ tokenizer.eos_token_id,
856
+ batch_size=testing_dict["batch_size"],
857
+ target=target,
858
+ )
859
+ elif "perturbation" in attr_func:
860
+ if attr_func in ("perturbation_all_fast", "perturbation_CLP_fast", "perturbation_REAGENT_fast"):
861
+ import perturbation_fast
862
+
863
+ llm_attributor = perturbation_fast.LLMPerturbationFastAttribution(model, tokenizer)
864
+ if attr_func == "perturbation_all_fast":
865
+ attr = llm_attributor.calculate_feature_ablation_segments(
866
+ example.prompt,
867
+ baseline=tokenizer.eos_token_id,
868
+ measure="log_loss",
869
+ target=target,
870
+ source_k=20,
871
+ )
872
+ elif attr_func == "perturbation_CLP_fast":
873
+ attr = llm_attributor.calculate_feature_ablation_segments(
874
+ example.prompt,
875
+ baseline=tokenizer.eos_token_id,
876
+ measure="KL",
877
+ target=target,
878
+ source_k=20,
879
+ )
880
+ else:
881
+ attr = llm_attributor.calculate_feature_ablation_segments_mlm(
882
+ example.prompt,
883
+ target=target,
884
+ source_k=20,
885
+ )
886
+ else:
887
+ llm_attributor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
888
+ if attr_func == "perturbation_all":
889
+ attr = llm_attributor.calculate_feature_ablation_sentences(
890
+ example.prompt, baseline=tokenizer.eos_token_id, measure="log_loss", target=target
891
+ )
892
+ elif attr_func == "perturbation_CLP":
893
+ attr = llm_attributor.calculate_feature_ablation_sentences(
894
+ example.prompt, baseline=tokenizer.eos_token_id, measure="KL", target=target
895
+ )
896
+ elif attr_func == "perturbation_REAGENT":
897
+ attr = llm_attributor.calculate_feature_ablation_sentences_mlm(example.prompt, target=target)
898
+ else:
899
+ raise ValueError(f"Unsupported perturbation attr_func {attr_func}")
900
+ elif "attention" in attr_func:
901
+ llm_attributor = llm_attr.LLMAttentionAttribution(model, tokenizer)
902
+ llm_attributor_ig = llm_attr.LLMGradientAttribtion(model, tokenizer)
903
+ attr = llm_attributor.calculate_attention_attribution(example.prompt, target=target)
904
+ attr_b = llm_attributor_ig.calculate_IG_per_generation(
905
+ example.prompt, 20, tokenizer.eos_token_id, batch_size=testing_dict["batch_size"], target=target
906
+ )
907
+ attr.attribution_matrix = attr.attribution_matrix * attr_b.attribution_matrix
908
+ elif attr_func == "ifr_all_positions":
909
+ llm_attributor = llm_attr.LLMIFRAttribution(
910
+ model,
911
+ tokenizer,
912
+ chunk_tokens=testing_dict["chunk_tokens"],
913
+ sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
914
+ )
915
+ attr = llm_attributor.calculate_ifr_for_all_positions(example.prompt, target=target)
916
+ elif attr_func == "ifr_all_positions_output_only":
917
+ llm_attributor = llm_attr.LLMIFRAttribution(
918
+ model,
919
+ tokenizer,
920
+ chunk_tokens=testing_dict["chunk_tokens"],
921
+ sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
922
+ )
923
+ sink_span = tuple(example.sink_span) if example.sink_span else tuple(indices_to_explain)
924
+ attr = llm_attributor.calculate_ifr_for_all_positions_output_only(
925
+ example.prompt,
926
+ target=target,
927
+ sink_span=sink_span,
928
+ )
929
+ elif attr_func == "ifr_multi_hop":
930
+ llm_attributor = llm_attr.LLMIFRAttribution(
931
+ model,
932
+ tokenizer,
933
+ chunk_tokens=testing_dict["chunk_tokens"],
934
+ sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
935
+ )
936
+ attr = llm_attributor.calculate_ifr_multi_hop(
937
+ example.prompt,
938
+ target=target,
939
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
940
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
941
+ n_hops=testing_dict["n_hops"],
942
+ )
943
+ elif attr_func == "ifr_in_all_gen":
944
+ import ft_ifr_improve
945
+
946
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionInAllGen(
947
+ model,
948
+ tokenizer,
949
+ chunk_tokens=testing_dict["chunk_tokens"],
950
+ sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
951
+ )
952
+ attr = llm_attributor.calculate_ifr_in_all_gen(
953
+ example.prompt,
954
+ target=target,
955
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
956
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
957
+ n_hops=testing_dict["n_hops"],
958
+ )
959
+ elif attr_func == "ifr_multi_hop_stop_words":
960
+ import ft_ifr_improve
961
+
962
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionImproved(
963
+ model,
964
+ tokenizer,
965
+ chunk_tokens=testing_dict["chunk_tokens"],
966
+ sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
967
+ )
968
+ attr = llm_attributor.calculate_ifr_multi_hop_stop_words(
969
+ example.prompt,
970
+ target=target,
971
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
972
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
973
+ n_hops=testing_dict["n_hops"],
974
+ )
975
+ elif attr_func == "ifr_multi_hop_both":
976
+ import ft_ifr_improve
977
+
978
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionBoth(
979
+ model,
980
+ tokenizer,
981
+ chunk_tokens=testing_dict["chunk_tokens"],
982
+ sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
983
+ )
984
+ attr = llm_attributor.calculate_ifr_multi_hop_both(
985
+ example.prompt,
986
+ target=target,
987
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
988
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
989
+ n_hops=testing_dict["n_hops"],
990
+ )
991
+ elif attr_func == "ifr_multi_hop_split_hop":
992
+ import ft_ifr_improve
993
+
994
+ llm_attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(
995
+ model,
996
+ tokenizer,
997
+ chunk_tokens=testing_dict["chunk_tokens"],
998
+ sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
999
+ )
1000
+ attr = llm_attributor.calculate_ifr_multi_hop_split_hop(
1001
+ example.prompt,
1002
+ target=target,
1003
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
1004
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
1005
+ n_hops=testing_dict["n_hops"],
1006
+ )
1007
+ elif attr_func == "attnlrp":
1008
+ llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
1009
+ attr = llm_attributor.calculate_attnlrp_ft_hop0(
1010
+ example.prompt,
1011
+ target=target,
1012
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
1013
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
1014
+ neg_handling=str(testing_dict.get("attnlrp_neg_handling", "drop")),
1015
+ norm_mode=str(testing_dict.get("attnlrp_norm_mode", "norm")),
1016
+ )
1017
+ elif attr_func in ("ft_attnlrp", "attnlrp_aggregated_multi_hop"):
1018
+ llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
1019
+ attr = llm_attributor.calculate_attnlrp_aggregated_multi_hop(
1020
+ example.prompt,
1021
+ target=target,
1022
+ sink_span=tuple(example.sink_span) if example.sink_span else None,
1023
+ thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
1024
+ n_hops=testing_dict["n_hops"],
1025
+ neg_handling=str(testing_dict.get("attnlrp_neg_handling", "drop")),
1026
+ norm_mode=str(testing_dict.get("attnlrp_norm_mode", "norm")),
1027
+ )
1028
+ elif attr_func == "basic":
1029
+ llm_attributor = llm_attr.LLMBasicAttribution(model, tokenizer)
1030
+ attr = llm_attributor.calculate_basic_attribution(example.prompt, target=target)
1031
+ else:
1032
+ raise ValueError(f"Unsupported attr_func {attr_func}")
1033
+
1034
+ seq_attr, row_attr, rec_attr = attr.get_all_token_attrs(indices_to_explain)
1035
+ hop_payload = None
1036
+ if bool(testing_dict.get("save_hop_traces", False)):
1037
+ try:
1038
+ hop_payload = _build_hop_trace_payload(attr_func, attr, indices_to_explain=indices_to_explain)
1039
+ except Exception as exc:
1040
+ print(f"[warn] hop trace extraction failed for {attr_func}: {exc}")
1041
+ hop_payload = None
1042
+
1043
+ user_prompt_indices = getattr(llm_attributor, "user_prompt_indices", None)
1044
+ if isinstance(user_prompt_indices, list):
1045
+ user_prompt_indices = [int(x) for x in user_prompt_indices]
1046
+ else:
1047
+ user_prompt_indices = None
1048
+
1049
+ keep_prompt_token_indices = None
1050
+ if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both"):
1051
+ try:
1052
+ import ft_ifr_improve
1053
+
1054
+ keep_prompt_token_indices = ft_ifr_improve.keep_token_indices(list(attr.prompt_tokens))
1055
+ except Exception:
1056
+ keep_prompt_token_indices = None
1057
+
1058
+ return [seq_attr, row_attr, rec_attr], hop_payload, user_prompt_indices, keep_prompt_token_indices
1059
+
1060
+
1061
+ def faithfulness_generation(
1062
+ testing_dict, example: ds_utils.CachedExample, target: str, llm_evaluator
1063
+ ) -> Tuple[np.ndarray, Optional[Dict[str, np.ndarray]]]:
1064
+ prompt = example.prompt
1065
+ generation = target
1066
+
1067
+ attr_func = str(testing_dict.get("attr_func") or "")
1068
+ attr_list, hop_payload, user_prompt_indices, keep_prompt_token_indices = run_attribution(
1069
+ testing_dict, example, target
1070
+ )
1071
+ seq_attr = attr_list[0]
1072
+ prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0]) # cols=(P+G), rows=G
1073
+
1074
+ results = []
1075
+ for attr in attr_list:
1076
+ # Only use prompt-side attribution, matching evaluations/faithfulness.py
1077
+ attr_prompt = attr[:, :prompt_len]
1078
+ if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and keep_prompt_token_indices is not None:
1079
+ import ft_ifr_improve
1080
+
1081
+ scores = ft_ifr_improve.faithfulness_test_skip_tokens(
1082
+ llm_evaluator,
1083
+ attr_prompt,
1084
+ prompt,
1085
+ generation,
1086
+ keep_prompt_token_indices=keep_prompt_token_indices,
1087
+ user_prompt_indices=user_prompt_indices,
1088
+ )
1089
+ elif user_prompt_indices is not None:
1090
+ scores = _faithfulness_test_with_user_prompt_indices(
1091
+ llm_evaluator,
1092
+ attr_prompt,
1093
+ prompt,
1094
+ generation,
1095
+ user_prompt_indices=user_prompt_indices,
1096
+ )
1097
+ else:
1098
+ scores = llm_evaluator.faithfulness_test(attr_prompt, prompt, generation)
1099
+ results.append(scores)
1100
+
1101
+ return np.array(results), hop_payload
1102
+
1103
+
1104
+ def evaluate_dataset(args, dataset_name: str, examples: List[ds_utils.CachedExample], testing_dict):
1105
+ out = evaluate_dataset_multi(args, dataset_name, examples, testing_dict, modes=["faithfulness_gen"])
1106
+ faith = out.get("faithfulness")
1107
+ if not faith:
1108
+ return None
1109
+ return faith["mean"], faith["std"], faith["avg_time"]
1110
+
1111
+
1112
+ def evaluate_dataset_recovery_ruler(args, dataset_name: str, examples: List[ds_utils.CachedExample], testing_dict):
1113
+ out = evaluate_dataset_multi(args, dataset_name, examples, testing_dict, modes=["recovery_ruler"])
1114
+ rec = out.get("recovery")
1115
+ if not rec:
1116
+ return None
1117
+ return rec["mean"], rec["std"], rec["avg_time"], rec["used"], rec["skipped"]
1118
+
1119
+
1120
+ def main():
1121
+ parser = argparse.ArgumentParser("Experiment 2 runner (math skipped, AT2 skipped).")
1122
+ parser.add_argument("--datasets", type=str, required=True, help="Comma-separated names or paths.")
1123
+ parser.add_argument("--attr_funcs", type=str, required=True, help="Comma-separated attr funcs (no AT2).")
1124
+ parser.add_argument("--model", type=str, default=None, help="HF repo id (required unless --model_path set).")
1125
+ parser.add_argument("--model_path", type=str, default=None, help="Local path; overrides --model for loading.")
1126
+ parser.add_argument("--cuda", type=str, default=None)
1127
+ parser.add_argument("--cuda_num", type=int, default=0)
1128
+ parser.add_argument("--num_examples", type=int, default=100)
1129
+ parser.add_argument(
1130
+ "--mode",
1131
+ type=str,
1132
+ nargs="+",
1133
+ default=["faithfulness_gen"],
1134
+ help=(
1135
+ "One or more of: faithfulness_gen, recovery_ruler. "
1136
+ "Accepts comma-separated values, e.g. '--mode faithfulness_gen,recovery_ruler' "
1137
+ "or '--mode faithfulness_gen, recovery_ruler'."
1138
+ ),
1139
+ )
1140
+ parser.add_argument("--sample", type=int, default=None, help="Optional subsample before num_examples.")
1141
+ parser.add_argument("--seed", type=int, default=42)
1142
+ parser.add_argument("--chunk_tokens", type=int, default=128)
1143
+ parser.add_argument("--sink_chunk_tokens", type=int, default=32)
1144
+ parser.add_argument("--n_hops", type=int, default=3)
1145
+ parser.add_argument(
1146
+ "--attnlrp_neg_handling",
1147
+ type=str,
1148
+ choices=["drop", "abs"],
1149
+ default="drop",
1150
+ help="FT-AttnLRP: how to handle negative values after each hop (drop=clamp>=0, abs=absolute value).",
1151
+ )
1152
+ parser.add_argument(
1153
+ "--attnlrp_norm_mode",
1154
+ type=str,
1155
+ choices=["norm", "no_norm"],
1156
+ default="norm",
1157
+ help="FT-AttnLRP: norm enables per-hop global+thinking normalization + ratios; no_norm disables all three.",
1158
+ )
1159
+ parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Filtered dataset cache directory.")
1160
+ parser.add_argument("--output_root", type=str, default="exp/exp2/output", help="Directory to store evaluation outputs.")
1161
+ parser.add_argument(
1162
+ "--save_hop_traces",
1163
+ action="store_true",
1164
+ help=(
1165
+ "Save per-sample trace artifacts (attribution vectors + per-sample metrics) under output_root/traces/. "
1166
+ "For multi-hop methods, also saves per-hop token vectors (vh)."
1167
+ ),
1168
+ )
1169
+ args = parser.parse_args()
1170
+ modes = _parse_modes(args.mode)
1171
+
1172
+ if args.model_path:
1173
+ model_name = args.model_path
1174
+ elif args.model:
1175
+ model_name = args.model
1176
+ else:
1177
+ raise SystemExit("Please set --model or --model_path.")
1178
+ model_tag = args.model if args.model else Path(args.model_path).name
1179
+
1180
+ datasets = [d.strip() for d in args.datasets.split(",") if d.strip()]
1181
+ attr_funcs = [a.strip() for a in args.attr_funcs.split(",") if a.strip()]
1182
+
1183
+ device = resolve_device(args)
1184
+ model, tokenizer = load_model(model_name, device)
1185
+
1186
+ max_input_len = {
1187
+ "llama-1B": 5500,
1188
+ "llama-3B": 4800,
1189
+ "llama-8B": 3500,
1190
+ "qwen-1.7B": 5500,
1191
+ "qwen-4B": 3500,
1192
+ "qwen-8B": 5000,
1193
+ "qwen-32B": 1500,
1194
+ "gemma-12B": 1500,
1195
+ "gemma-27B": 2000,
1196
+ }.get(args.model, 2000)
1197
+
1198
+ for ds_name in datasets:
1199
+ if "recovery_ruler" in modes and ds_name == "morehopqa":
1200
+ raise SystemExit("recovery_ruler only supports RULER datasets (with needle_spans), not morehopqa.")
1201
+ if "recovery_ruler" in modes and ds_name.startswith("math"):
1202
+ raise SystemExit("recovery_ruler only supports RULER datasets (with needle_spans), not math.")
1203
+
1204
+ # Resolve dataset (prefer prepared cache under data_root)
1205
+ cached_path = Path(args.data_root) / f"{ds_name}.jsonl"
1206
+ if cached_path.exists():
1207
+ examples = ds_utils.load_cached(cached_path, sample=args.sample, seed=args.seed)
1208
+ else:
1209
+ # allow direct cached path or raw loader
1210
+ p = Path(ds_name)
1211
+ if p.exists():
1212
+ examples = ds_utils.load_cached(p, sample=args.sample, seed=args.seed)
1213
+ else:
1214
+ hint = "please run exp/exp2/sample_and_filter.py first (or pass an explicit cached JSONL path)."
1215
+ if ds_name.startswith("math"):
1216
+ hint = "please run exp/exp2/map_math_mine_to_exp2_cache.py first (or pass an explicit cached JSONL path)."
1217
+ raise SystemExit(f"Missing exp2 cache for '{ds_name}'. Expected {cached_path}; {hint}")
1218
+
1219
+ for attr_func in attr_funcs:
1220
+ if attr_func.lower() == "at2":
1221
+ print("Skipping AT2 as requested.")
1222
+ continue
1223
+
1224
+ testing_dict: Dict[str, Any] = {
1225
+ "model": model,
1226
+ "model_tag": model_tag,
1227
+ "tokenizer": tokenizer,
1228
+ "attr_func": attr_func,
1229
+ "max_input_len": max_input_len,
1230
+ "chunk_tokens": args.chunk_tokens,
1231
+ "sink_chunk_tokens": args.sink_chunk_tokens,
1232
+ "n_hops": args.n_hops,
1233
+ "attnlrp_neg_handling": args.attnlrp_neg_handling,
1234
+ "attnlrp_norm_mode": args.attnlrp_norm_mode,
1235
+ "device": device,
1236
+ "batch_size": 1,
1237
+ "save_hop_traces": bool(args.save_hop_traces),
1238
+ }
1239
+ result = evaluate_dataset_multi(args, ds_name, examples, testing_dict, modes=modes)
1240
+
1241
+ if "faithfulness_gen" in modes:
1242
+ faith = result.get("faithfulness")
1243
+ if not faith:
1244
+ print(f"No faithfulness results for {ds_name} with {attr_func}.")
1245
+ else:
1246
+ mean = faith["mean"]
1247
+ std = faith["std"]
1248
+ avg_time = float(faith["avg_time"])
1249
+
1250
+ out_dir = Path(args.output_root) / "faithfulness" / ds_name / model_tag
1251
+ out_dir.mkdir(parents=True, exist_ok=True)
1252
+ filename = f"{attr_func}_{args.num_examples}_examples.csv"
1253
+ with open(out_dir / filename, "w") as f:
1254
+ f.write("Method,RISE,MAS,RISE+AP\n")
1255
+ f.write(",".join(["Seq Attr Scores Mean"] + [str(x) for x in mean[0].tolist()]) + "\n")
1256
+ f.write(",".join(["Row Attr Scores Mean"] + [str(x) for x in mean[1].tolist()]) + "\n")
1257
+ f.write(",".join(["Recursive Attr Scores Mean"] + [str(x) for x in mean[2].tolist()]) + "\n")
1258
+ f.write(",".join(["Seq Attr Scores Var"] + [str(x) for x in std[0].tolist()]) + "\n")
1259
+ f.write(",".join(["Row Attr Scores Var"] + [str(x) for x in std[1].tolist()]) + "\n")
1260
+ f.write(",".join(["Recursive Attr Scores Var"] + [str(x) for x in std[2].tolist()]) + "\n")
1261
+ f.write(f"Avg Sample Time (s),{avg_time}\n")
1262
+ print(f"[{ds_name}] {attr_func} -> {out_dir/filename} (avg sample time: {avg_time:.2f}s)")
1263
+
1264
+ if "recovery_ruler" in modes:
1265
+ rec = result.get("recovery")
1266
+ if not rec:
1267
+ print(f"No recovery results for {ds_name} with {attr_func}.")
1268
+ else:
1269
+ mean = rec["mean"]
1270
+ std = rec["std"]
1271
+ avg_time = float(rec["avg_time"])
1272
+ used = int(rec["used"])
1273
+ skipped = int(rec["skipped"])
1274
+
1275
+ out_dir = Path(args.output_root) / "recovery" / ds_name / model_tag
1276
+ out_dir.mkdir(parents=True, exist_ok=True)
1277
+ filename = f"{attr_func}_{args.num_examples}_examples.csv"
1278
+ with open(out_dir / filename, "w") as f:
1279
+ f.write("Method,Recovery@10%\n")
1280
+ f.write(f"Seq Attr Recovery Mean,{mean[0]}\n")
1281
+ f.write(f"Row Attr Recovery Mean,{mean[1]}\n")
1282
+ f.write(f"Recursive Attr Recovery Mean,{mean[2]}\n")
1283
+ f.write(f"Seq Attr Recovery Std,{std[0]}\n")
1284
+ f.write(f"Row Attr Recovery Std,{std[1]}\n")
1285
+ f.write(f"Recursive Attr Recovery Std,{std[2]}\n")
1286
+ f.write(f"Examples Used,{used}\n")
1287
+ f.write(f"Examples Skipped,{skipped}\n")
1288
+ f.write(f"Avg Sample Time (s),{avg_time}\n")
1289
+ print(
1290
+ f"[{ds_name}] {attr_func} -> {out_dir/filename} "
1291
+ f"(used={used} skipped={skipped} avg sample time: {avg_time:.2f}s)"
1292
+ )
1293
+
1294
+
1295
+ if __name__ == "__main__":
1296
+ main()
exp/exp2/sample_and_filter.py ADDED
@@ -0,0 +1,363 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Dataset sampler for Experiment 2.
4
+
5
+ Steps:
6
+ - Load a dataset item (MoreHopQA / HotpotQA / RULER niah / RULER vt).
7
+ - Call the generation model (qwen3-235b-a22b-2507) with a system prompt that
8
+ asks for brief reasoning and a final answer wrapped in \\box{}.
9
+ - Enforce the output format: keep only generations that look like
10
+ "<reasoning text> + final \\box{} answer" with nothing after the box.
11
+ - Call the judge model (deepseek-v3-1-terminus) to check whether the boxed
12
+ answer matches the dataset reference answer; keep only judged True samples.
13
+ - Rebuild `target` as "<reasoning>\\n<answer text (no box)>" and store filtered
14
+ samples to exp/exp2/data/<dataset>.jsonl (or a custom path) with inferred spans.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import json
21
+ import os
22
+ import sys
23
+ import time
24
+ import urllib.error
25
+ import urllib.request
26
+ from pathlib import Path
27
+ from typing import Any, Dict, Iterable, List, Optional
28
+
29
+ from transformers import AutoTokenizer
30
+ from tqdm import tqdm
31
+
32
+ REPO_ROOT = Path(__file__).resolve().parents[2]
33
+ if str(REPO_ROOT) not in sys.path:
34
+ sys.path.insert(0, str(REPO_ROOT))
35
+
36
+ from exp.exp2.dataset_utils import (
37
+ CachedExample,
38
+ DatasetLoader,
39
+ attach_spans_from_answer,
40
+ split_boxed_generation,
41
+ )
42
+
43
+
44
+ class RateLimitError(RuntimeError):
45
+ """Raised when API returns 429; carries a suggested wait time."""
46
+
47
+ def __init__(self, wait_seconds: float, detail: str) -> None:
48
+ super().__init__(detail)
49
+ self.wait_seconds = wait_seconds
50
+
51
+ # GEN_SYSTEM_PROMPT = (
52
+ # "You are a careful reasoning assistant. "
53
+ # "Before answering, engage in an extremely detailed and exhaustive chain of thought. **No fewer than 2k tokens.** "
54
+ # "Do not skip any logical steps, even if they seem obvious. "
55
+ # "Process this freely and naturally without using specific headers or strict formatting. "
56
+ # "When you reach the conclusion, wrap the entire final sentence containing the answer inside \\box{}. "
57
+ # "Ensure the box wraps the **sentence** that naturally delivers the answer. DO NOT rewrite the answer word for the box separately."
58
+ # )
59
+
60
+ GEN_SYSTEM_PROMPT = (
61
+ "You are a reasoning assistant. "
62
+ "Before answering, engage in an chain of thought. "
63
+ "Process this freely and naturally without using specific headers or strict formatting. "
64
+ "When you reach the conclusion, wrap the entire final sentence containing the answer inside \\box{}. "
65
+ "Ensure the box wraps the **sentence** that naturally delivers the answer. DO NOT rewrite the answer word for the box separately."
66
+ )
67
+
68
+ JUDGE_SYSTEM_PROMPT = (
69
+ "You verify whether the model's boxed answer matches the reference answer. "
70
+ "Reply strictly with True or False and nothing else."
71
+ )
72
+
73
+
74
+ def call_chat_api(
75
+ api_base: str,
76
+ api_key: str,
77
+ model: str,
78
+ messages: List[Dict[str, str]],
79
+ *,
80
+ timeout: int,
81
+ max_tokens: int,
82
+ temperature: float,
83
+ cache_ttl: int,
84
+ cache_namespace: Optional[str],
85
+ rate_limit_delay: Optional[float] = None,
86
+ ) -> str:
87
+ url = api_base.rstrip("/") + "/chat/completions"
88
+ payload: Dict[str, Any] = {
89
+ "model": model,
90
+ "messages": messages,
91
+ "max_tokens": max_tokens,
92
+ "temperature": temperature,
93
+ }
94
+ if cache_ttl > 0:
95
+ cache_obj: Dict[str, Any] = {"ttl": cache_ttl}
96
+ if cache_namespace:
97
+ cache_obj["namespace"] = cache_namespace
98
+ payload["cache"] = cache_obj
99
+
100
+ data = json.dumps(payload).encode("utf-8")
101
+ headers = {"Content-Type": "application/json"}
102
+ if api_key:
103
+ headers["Authorization"] = f"Bearer {api_key}"
104
+
105
+ req = urllib.request.Request(url, data=data, headers=headers, method="POST")
106
+ opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
107
+ try:
108
+ with opener.open(req, timeout=timeout) as resp:
109
+ resp_bytes = resp.read()
110
+ except urllib.error.HTTPError as e:
111
+ detail = e.read().decode("utf-8", errors="ignore") if hasattr(e, "read") else ""
112
+ if e.code == 429:
113
+ retry_after = None
114
+ if hasattr(e, "headers") and e.headers:
115
+ retry_after_header = e.headers.get("Retry-After")
116
+ if retry_after_header:
117
+ try:
118
+ retry_after = float(retry_after_header)
119
+ except ValueError:
120
+ retry_after = None
121
+ wait = retry_after or rate_limit_delay or 5.0
122
+ raise RateLimitError(wait, f"API HTTP 429: {detail}") from e
123
+ raise RuntimeError(f"API HTTP error {e.code}: {detail}") from e
124
+ except urllib.error.URLError as e:
125
+ raise RuntimeError(f"API request failed: {e}") from e
126
+
127
+ try:
128
+ response = json.loads(resp_bytes.decode("utf-8"))
129
+ except json.JSONDecodeError as e:
130
+ raise RuntimeError(f"Failed to decode API response: {resp_bytes!r}") from e
131
+
132
+ choices = response.get("choices", [])
133
+ if not choices:
134
+ raise RuntimeError(f"Empty choices from API: {response}")
135
+ content = choices[0].get("message", {}).get("content", "")
136
+ if not content:
137
+ raise RuntimeError(f"Empty content from API: {response}")
138
+ return content.strip()
139
+
140
+
141
+ def build_gen_messages(prompt: str) -> List[Dict[str, str]]:
142
+ return [
143
+ {"role": "system", "content": GEN_SYSTEM_PROMPT},
144
+ {"role": "user", "content": prompt},
145
+ ]
146
+
147
+
148
+ def build_judge_messages(reference_answer: str, candidate_answer: str) -> List[Dict[str, str]]:
149
+ user = (
150
+ "Decide if the model's boxed answer matches the reference answer.\n"
151
+ f"Reference answer: {reference_answer}\n"
152
+ f"Model boxed answer (only the content inside \\box{{}}): {candidate_answer}\n"
153
+ "Output only True if they are semantically consistent; otherwise output False."
154
+ )
155
+ return [
156
+ {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
157
+ {"role": "user", "content": user},
158
+ ]
159
+
160
+
161
+ def parse_bool(text: str) -> bool:
162
+ first = (text.strip().splitlines() or [""])[0].strip().lower()
163
+ if first in {"true", "yes"}:
164
+ return True
165
+ if first in {"false", "no"}:
166
+ return False
167
+ # fallback: check substring
168
+ if "true" in first and "false" not in first:
169
+ return True
170
+ if "false" in first:
171
+ return False
172
+ raise ValueError(f"Cannot parse boolean from: {text!r}")
173
+
174
+
175
+ def write_cache(out_path: Path, examples: Iterable[CachedExample]) -> int:
176
+ out_path.parent.mkdir(parents=True, exist_ok=True)
177
+ count = 0
178
+ with out_path.open("w", encoding="utf-8") as f:
179
+ for ex in examples:
180
+ obj: Dict[str, Any] = {
181
+ "prompt": ex.prompt,
182
+ "target": ex.target,
183
+ "indices_to_explain": ex.indices_to_explain,
184
+ "attr_mask_indices": ex.attr_mask_indices,
185
+ "sink_span": ex.sink_span,
186
+ "thinking_span": ex.thinking_span,
187
+ "metadata": ex.metadata,
188
+ }
189
+ f.write(json.dumps(obj, ensure_ascii=False) + "\n")
190
+ count += 1
191
+ return count
192
+
193
+
194
+ def main():
195
+ parser = argparse.ArgumentParser("Sample and filter dataset examples for exp2.")
196
+ parser.add_argument(
197
+ "--dataset",
198
+ type=str,
199
+ required=True,
200
+ help="morehopqa | hotpotqa_long | niah_* | vt_* | <morehopqa_json_path> | <ruler_jsonl_path>",
201
+ )
202
+ parser.add_argument("--max_examples", type=int, default=100, help="Number of raw examples to sample before filtering.")
203
+ parser.add_argument("--seed", type=int, default=42)
204
+ parser.add_argument("--api_base", type=str, default="http://localhost:4000/v1", help="Chat API base URL.")
205
+ parser.add_argument("--api_key", type=str, default=None, help="API key; defaults to FLASHTRACE_API_KEY/OPENAI_API_KEY.")
206
+ parser.add_argument("--generator_model", type=str, default="qwen3-235b-a22b-2507")
207
+ parser.add_argument("--judge_model", type=str, default="deepseek-v3-1-terminus")
208
+ parser.add_argument("--api_timeout", type=int, default=300)
209
+ parser.add_argument("--api_max_tokens", type=int, default=8192)
210
+ parser.add_argument("--api_temperature", type=float, default=0.0)
211
+ parser.add_argument("--api_cache_ttl", type=int, default=600)
212
+ parser.add_argument("--api_cache_namespace", type=str, default="flashtrace-exp2")
213
+ parser.add_argument("--retry_delay", type=float, default=2.0)
214
+ parser.add_argument("--retries", type=int, default=2, help="Additional retries on API failure.")
215
+ parser.add_argument("--request_interval", type=float, default=1.0, help="Sleep seconds between generation calls.")
216
+ parser.add_argument("--judge_interval", type=float, default=1.0, help="Sleep seconds between judge calls.")
217
+ parser.add_argument("--tokenizer_model", type=str, default=None, help="Tokenizer path for span extraction (default: generator model).")
218
+ parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Output directory for filtered caches.")
219
+ parser.add_argument("--out", type=str, default=None, help="Optional explicit output path (JSONL).")
220
+ parser.add_argument("--rate_limit_delay", type=float, default=5.0, help="Seconds to wait on HTTP 429 before retrying.")
221
+ args = parser.parse_args()
222
+
223
+ api_key = args.api_key or os.environ.get("FLASHTRACE_API_KEY") or os.environ.get("OPENAI_API_KEY")
224
+ if not api_key:
225
+ raise SystemExit("Set --api_key or FLASHTRACE_API_KEY/OPENAI_API_KEY for API access.")
226
+
227
+ loader = DatasetLoader(seed=args.seed, data_root=args.data_root)
228
+ # Load full dataset; we will stop early once enough kept examples are collected.
229
+ raw_examples = loader.load_raw(args.dataset, sample=None)
230
+ if not raw_examples:
231
+ raise SystemExit("No examples loaded.")
232
+
233
+ tok_name = args.tokenizer_model or args.generator_model
234
+ tok_path = Path(tok_name)
235
+ if tok_path.exists():
236
+ tokenizer = AutoTokenizer.from_pretrained(tok_path.as_posix(), local_files_only=True)
237
+ else:
238
+ tokenizer = AutoTokenizer.from_pretrained(tok_name)
239
+ tokenizer.pad_token = tokenizer.eos_token
240
+
241
+ kept: List[CachedExample] = []
242
+ total = len(raw_examples)
243
+ kept_bar = tqdm(total=args.max_examples, desc="Kept (judge=True)", position=1, leave=False)
244
+ attempted = 0
245
+
246
+ for idx, ex in enumerate(tqdm(raw_examples, total=total, desc="Sampling"), 1):
247
+ if len(kept) >= args.max_examples:
248
+ break
249
+ reference_answer = ex.metadata.get("reference_answer") or ex.target or ""
250
+ gen_messages = build_gen_messages(ex.prompt)
251
+ attempted = idx
252
+
253
+ # Step 1: generation
254
+ for attempt in range(args.retries + 1):
255
+ try:
256
+ generation = call_chat_api(
257
+ args.api_base,
258
+ api_key,
259
+ args.generator_model,
260
+ gen_messages,
261
+ timeout=args.api_timeout,
262
+ max_tokens=args.api_max_tokens,
263
+ temperature=args.api_temperature,
264
+ cache_ttl=args.api_cache_ttl,
265
+ cache_namespace=args.api_cache_namespace,
266
+ rate_limit_delay=args.rate_limit_delay,
267
+ )
268
+ break
269
+ except RateLimitError as e:
270
+ if attempt >= args.retries:
271
+ raise
272
+ time.sleep(e.wait_seconds)
273
+ except Exception: # noqa: BLE001
274
+ if attempt >= args.retries:
275
+ raise
276
+ time.sleep(args.retry_delay)
277
+ if args.request_interval > 0:
278
+ time.sleep(args.request_interval)
279
+
280
+ parsed = split_boxed_generation(generation)
281
+ if not parsed:
282
+ print(f"[{idx}/{total}] skipped=format")
283
+ continue
284
+
285
+ thinking_text, boxed_segment, boxed_answer = parsed
286
+ target_text = f"{thinking_text}\n{boxed_answer}" if thinking_text else boxed_answer
287
+ judge_messages = build_judge_messages(reference_answer, boxed_answer)
288
+
289
+ ok = False
290
+ judge_resp = ""
291
+ for attempt in range(args.retries + 1):
292
+ try:
293
+ judge_resp = call_chat_api(
294
+ args.api_base,
295
+ api_key,
296
+ args.judge_model,
297
+ judge_messages,
298
+ timeout=args.api_timeout,
299
+ max_tokens=64,
300
+ temperature=0.0,
301
+ cache_ttl=args.api_cache_ttl,
302
+ cache_namespace=args.api_cache_namespace,
303
+ rate_limit_delay=args.rate_limit_delay,
304
+ )
305
+ ok = parse_bool(judge_resp)
306
+ break
307
+ except RateLimitError as e:
308
+ if attempt >= args.retries:
309
+ raise
310
+ time.sleep(e.wait_seconds)
311
+ except Exception: # noqa: BLE001
312
+ if attempt >= args.retries:
313
+ raise
314
+ time.sleep(args.retry_delay)
315
+ if args.judge_interval > 0:
316
+ time.sleep(args.judge_interval)
317
+
318
+ status = "kept" if ok else "filtered"
319
+ print(f"[{idx}/{total}] judge={status}")
320
+ if not ok:
321
+ continue
322
+
323
+ new_meta = dict(ex.metadata)
324
+ new_meta["reference_answer"] = reference_answer
325
+ new_meta["judge_response"] = judge_resp
326
+
327
+ new_ex = CachedExample(
328
+ prompt=ex.prompt,
329
+ target=target_text,
330
+ indices_to_explain=None,
331
+ attr_mask_indices=ex.attr_mask_indices,
332
+ sink_span=None,
333
+ thinking_span=None,
334
+ metadata=new_meta,
335
+ )
336
+ new_ex = attach_spans_from_answer(new_ex, tokenizer, boxed_answer)
337
+ if not (isinstance(new_ex.sink_span, list) and len(new_ex.sink_span) == 2):
338
+ print(f"[{idx}/{total}] skipped=span")
339
+ continue
340
+
341
+ # Token-level indices_to_explain: boxed-inner answer token span in target (closed interval).
342
+ new_ex = CachedExample(
343
+ prompt=new_ex.prompt,
344
+ target=new_ex.target,
345
+ indices_to_explain=new_ex.sink_span,
346
+ attr_mask_indices=new_ex.attr_mask_indices,
347
+ sink_span=new_ex.sink_span,
348
+ thinking_span=new_ex.thinking_span,
349
+ metadata=new_ex.metadata,
350
+ )
351
+ kept.append(new_ex)
352
+ kept_bar.update(1)
353
+
354
+ kept_bar.close()
355
+
356
+ out_path = Path(args.out) if args.out else Path(args.data_root) / f"{args.dataset}.jsonl"
357
+ written = write_cache(out_path, kept)
358
+ attempted_total = attempted or 0
359
+ print(f"Kept {written} / target {args.max_examples} (attempted {attempted_total} / {total}) -> {out_path}")
360
+
361
+
362
+ if __name__ == "__main__":
363
+ main()
exp/exp3/README.md ADDED
@@ -0,0 +1,50 @@
1
+ # FlashTrace Experiment 3: Long vs. Short CoT Comparison (case study)
2
+
3
+ This directory provides a minimal, reproducible "long vs. short CoT" experiment:
4
+ - From RULER `niah_mq_q2 (1024)`, filter two subsets:
5
+ - short-CoT: short reasoning + a final `\box{}` answer
6
+ - long-CoT: long reasoning + a final `\box{}` answer
7
+ - Run only `attnlrp` (hop0) and compute only token-level `recovery@10%` (gold comes from `needle_spans`).
8
+ - Write traces (npz + manifest) to `exp/exp3/output/`, following the trace layout of `exp/exp2/run_exp.py`.
9
+
10
+ ## 1) Sampling and filtering (generation + judge)
11
+
12
+ Reads by default:
13
+ `data/ruler_multihop/1024/niah_mq_q2/validation.jsonl`
14
+
15
+ Requires an OpenAI-compatible chat API (default `http://localhost:4000/v1`) and an API key.
16
+
17
+ ```bash
18
+ export FLASHTRACE_API_KEY=... # or OPENAI_API_KEY
19
+
20
+ python exp/exp3/sample_and_filter.py \
21
+ --tokenizer_model /opt/share/models/Qwen/Qwen3-8B/ \
22
+ --min_long_thinking_tokens 512 \
23
+ --max_short_thinking_tokens 256
24
+ ```
25
+
26
+ Outputs (defaults):
27
+ - `exp/exp3/data/niah_mq_q2_short_cot.jsonl`
28
+ - `exp/exp3/data/niah_mq_q2_long_cot.jsonl`
29
+
30
+ Notes:
31
+ - One example is sampled per subset by default; use `--max_short` / `--max_long` to set the counts separately (`--max_pairs` is a compatibility alias for both).
32
+
33
+ ## 2) Attribution and recovery (AttnLRP hop0)
34
+
35
+ ```bash
36
+ python exp/exp3/run_exp.py \
37
+ --model qwen-8B \
38
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
39
+ --cuda 3,4,5,7
40
+ ```
41
+
42
+ Outputs:
43
+ - recovery CSV: `exp/exp3/output/recovery/<dataset>/<model>/attnlrp_1_examples.csv`
44
+ - traces: `exp/exp3/output/traces/<dataset>/<model>/<run_tag>/ex_*.npz` + `manifest.jsonl`
45
+ - summary JSON: `exp/exp3/output/recovery/summary_<model>.json`
46
+
47
+ Common options (see the combined example below):
48
+ - `--top_fraction`: top fraction used for recovery (default 0.1)
49
+ - `--attnlrp_neg_handling drop|abs`
50
+ - `--attnlrp_norm_mode norm|no_norm`
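+
+ For example, these options can be combined with the run command above (paths as in the earlier snippet; adjust to your local setup):
+
+ ```bash
+ python exp/exp3/run_exp.py \
+ --model qwen-8B \
+ --model_path /opt/share/models/Qwen/Qwen3-8B/ \
+ --top_fraction 0.1 \
+ --attnlrp_neg_handling abs \
+ --attnlrp_norm_mode no_norm
+ ```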
exp/exp3/extract_segment_weights.py ADDED
@@ -0,0 +1,250 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract CoT/output segment attribution weights from exp3 trace artifacts.
4
+
5
+ Background
6
+ ----------
7
+ exp/exp3/run_exp.py saves per-sample trace npz files that contain token-level
8
+ importance vectors over the FULL (prompt + generation) token sequence:
9
+ - v_seq_all: sum over rows of seq attribution matrix (shape [P+G])
10
+ - v_row_all: row attribution vector for indices_to_explain (shape [P+G])
11
+ - v_rec_all: recursive attribution vector for indices_to_explain (shape [P+G])
12
+
13
+ For exp3 cached samples, we also have generation-token spans:
14
+ - thinking_span_gen: CoT span [start,end] in generation-token coordinates
15
+ - sink_span_gen: output span [start,end] in generation-token coordinates
16
+
17
+ This script slices v_*_all into:
18
+ - cot: tokens in thinking_span_gen (offset by prompt_len)
19
+ - output: tokens in sink_span_gen (offset by prompt_len)
20
+
21
+ and reports segment sums/fractions (and optionally writes a JSON summary).
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import argparse
27
+ import json
28
+ from dataclasses import dataclass
29
+ from pathlib import Path
30
+ from typing import Any, Dict, List, Optional, Tuple
31
+
32
+ import numpy as np
33
+
34
+
35
+ @dataclass(frozen=True)
36
+ class TracePaths:
37
+ dataset: str
38
+ model_tag: str
39
+ run_tag: str
40
+ npz_path: Path
41
+
42
+
43
+ def _pick_latest_subdir(path: Path) -> Optional[Path]:
44
+ if not path.exists():
45
+ return None
46
+ subs = [p for p in path.iterdir() if p.is_dir()]
47
+ if not subs:
48
+ return None
49
+ subs.sort(key=lambda p: p.stat().st_mtime, reverse=True)
50
+ return subs[0]
51
+
52
+
53
+ def _resolve_trace_paths(
54
+ *,
55
+ output_root: Path,
56
+ dataset: str,
57
+ model_tag: Optional[str],
58
+ run_tag: Optional[str],
59
+ example_idx: int,
60
+ ) -> TracePaths:
61
+ base = output_root / "traces" / dataset
62
+ if not base.exists():
63
+ raise FileNotFoundError(f"Trace dataset dir not found: {base}")
64
+
65
+ if model_tag is None:
66
+ model_dirs = [p for p in base.iterdir() if p.is_dir()]
67
+ if not model_dirs:
68
+ raise FileNotFoundError(f"No model subdir under: {base}")
69
+ if len(model_dirs) != 1:
70
+ raise SystemExit(f"Multiple model dirs under {base}; pass --model_tag. Found: {[p.name for p in model_dirs]}")
71
+ model_dir = model_dirs[0]
72
+ model_tag = model_dir.name
73
+ else:
74
+ model_dir = base / model_tag
75
+ if not model_dir.exists():
76
+ raise FileNotFoundError(f"Trace model dir not found: {model_dir}")
77
+
78
+ if run_tag is None:
79
+ run_dir = _pick_latest_subdir(model_dir)
80
+ if run_dir is None:
81
+ raise FileNotFoundError(f"No run subdir under: {model_dir}")
82
+ run_tag = run_dir.name
83
+ else:
84
+ run_dir = model_dir / run_tag
85
+ if not run_dir.exists():
86
+ raise FileNotFoundError(f"Trace run dir not found: {run_dir}")
87
+
88
+ npz_name = f"ex_{int(example_idx):06d}.npz"
89
+ npz_path = run_dir / npz_name
90
+ if not npz_path.exists():
91
+ raise FileNotFoundError(f"Trace npz not found: {npz_path}")
92
+
93
+ return TracePaths(dataset=dataset, model_tag=model_tag, run_tag=run_tag, npz_path=npz_path)
94
+
95
+
96
+ def _as_span(arr: Any) -> Optional[Tuple[int, int]]:
97
+ if arr is None:
98
+ return None
99
+ try:
100
+ a = np.asarray(arr).reshape(-1).tolist()
101
+ except Exception:
102
+ return None
103
+ if len(a) != 2:
104
+ return None
105
+ try:
106
+ start = int(a[0])
107
+ end = int(a[1])
108
+ except Exception:
109
+ return None
110
+ if start < 0 or end < start:
111
+ return None
112
+ return start, end
113
+
114
+
115
+ def _segment_stats(v: np.ndarray, start: int, end: int) -> Dict[str, float]:
116
+ if end < start:
117
+ return {"sum": 0.0, "mean": 0.0, "max": 0.0}
118
+ seg = v[start : end + 1]
119
+ if seg.size == 0:
120
+ return {"sum": 0.0, "mean": 0.0, "max": 0.0}
121
+ return {
122
+ "sum": float(seg.sum()),
123
+ "mean": float(seg.mean()),
124
+ "max": float(seg.max()),
125
+ }
126
+
127
+
128
+ def _slice_segment(v: np.ndarray, start: int, end: int) -> List[float]:
129
+ if end < start:
130
+ return []
131
+ seg = v[start : end + 1]
132
+ return [float(x) for x in seg.tolist()]
133
+
134
+
135
+ def extract_one(npz_path: Path) -> Dict[str, Any]:
136
+ d = np.load(npz_path)
137
+ required = ["prompt_len", "gen_len", "v_seq_all", "v_row_all", "v_rec_all"]
138
+ for k in required:
139
+ if k not in d:
140
+ raise KeyError(f"Missing key in trace npz {npz_path}: {k}")
141
+
142
+ prompt_len = int(np.asarray(d["prompt_len"]).item())
143
+ gen_len = int(np.asarray(d["gen_len"]).item())
144
+ total_len = prompt_len + gen_len
145
+
146
+ v_seq_all = np.asarray(d["v_seq_all"], dtype=np.float64).reshape(-1)
147
+ v_row_all = np.asarray(d["v_row_all"], dtype=np.float64).reshape(-1)
148
+ v_rec_all = np.asarray(d["v_rec_all"], dtype=np.float64).reshape(-1)
149
+ for name, v in [("v_seq_all", v_seq_all), ("v_row_all", v_row_all), ("v_rec_all", v_rec_all)]:
150
+ if int(v.size) != int(total_len):
151
+ raise ValueError(f"{name} length mismatch: expected {total_len}, got {int(v.size)}")
152
+
153
+ sink_span_gen = _as_span(d.get("sink_span_gen"))
154
+ thinking_span_gen = _as_span(d.get("thinking_span_gen"))
155
+ if sink_span_gen is None:
156
+ raise KeyError("Trace missing sink_span_gen; cannot define output span.")
157
+ if thinking_span_gen is None:
158
+ # Best-effort: infer thinking span as [0, sink_start-1].
159
+ sink_start, _ = sink_span_gen
160
+ thinking_span_gen = (0, max(0, sink_start - 1))
161
+
162
+ think_start_g, think_end_g = thinking_span_gen
163
+ sink_start_g, sink_end_g = sink_span_gen
164
+
165
+ cot_start = prompt_len + think_start_g
166
+ cot_end = min(prompt_len + think_end_g, total_len - 1)
167
+ out_start = prompt_len + sink_start_g
168
+ out_end = min(prompt_len + sink_end_g, total_len - 1)
169
+
170
+ def pack(v: np.ndarray) -> Dict[str, Any]:
171
+ total = float(v.sum())
172
+ cot = _segment_stats(v, cot_start, cot_end)
173
+ out = _segment_stats(v, out_start, out_end)
174
+ denom = cot["sum"] + out["sum"]
175
+ return {
176
+ "total_sum": total,
177
+ "cot": {
178
+ "start_abs": int(cot_start),
179
+ "end_abs": int(cot_end),
180
+ "len": int(max(0, cot_end - cot_start + 1)),
181
+ **cot,
182
+ "fraction_of_total": float(cot["sum"] / total) if total > 0 else float("nan"),
183
+ "fraction_of_cot_plus_output": float(cot["sum"] / denom) if denom > 0 else float("nan"),
184
+ },
185
+ "output": {
186
+ "start_abs": int(out_start),
187
+ "end_abs": int(out_end),
188
+ "len": int(max(0, out_end - out_start + 1)),
189
+ **out,
190
+ "fraction_of_total": float(out["sum"] / total) if total > 0 else float("nan"),
191
+ "fraction_of_cot_plus_output": float(out["sum"] / denom) if denom > 0 else float("nan"),
192
+ },
193
+ "cot_weights": _slice_segment(v, cot_start, cot_end),
194
+ "output_weights": _slice_segment(v, out_start, out_end),
195
+ }
196
+
197
+ return {
198
+ "prompt_len": int(prompt_len),
199
+ "gen_len": int(gen_len),
200
+ "total_len": int(total_len),
201
+ "thinking_span_gen": [int(think_start_g), int(think_end_g)],
202
+ "sink_span_gen": [int(sink_start_g), int(sink_end_g)],
203
+ "seq": pack(v_seq_all),
204
+ "row": pack(v_row_all),
205
+ "rec": pack(v_rec_all),
206
+ }
207
+
208
+
209
+ def main() -> None:
210
+ parser = argparse.ArgumentParser("Extract CoT/output weights from exp3 traces.")
211
+ parser.add_argument("--output_root", type=str, default="exp/exp3/output")
212
+ parser.add_argument("--dataset_tag", type=str, default="niah_mq_q2")
213
+ parser.add_argument("--model_tag", type=str, default=None, help="If omitted, auto-detect when unique.")
214
+ parser.add_argument("--run_tag", type=str, default=None, help="If omitted, picks the latest run subdir.")
215
+ parser.add_argument("--example_idx", type=int, default=0)
216
+ parser.add_argument("--out", type=str, default=None, help="Optional JSON output path.")
217
+ args = parser.parse_args()
218
+
219
+ output_root = Path(args.output_root)
220
+ datasets = [f"{args.dataset_tag}_short_cot", f"{args.dataset_tag}_long_cot"]
221
+
222
+ results: List[Dict[str, Any]] = []
223
+ for ds_name in datasets:
224
+ paths = _resolve_trace_paths(
225
+ output_root=output_root,
226
+ dataset=ds_name,
227
+ model_tag=args.model_tag,
228
+ run_tag=args.run_tag,
229
+ example_idx=args.example_idx,
230
+ )
231
+ out = extract_one(paths.npz_path)
232
+ out["dataset"] = paths.dataset
233
+ out["model_tag"] = paths.model_tag
234
+ out["run_tag"] = paths.run_tag
235
+ out["npz_path"] = str(paths.npz_path)
236
+ results.append(out)
237
+
238
+ text = json.dumps(results, ensure_ascii=False, indent=2)
239
+ if args.out:
240
+ out_path = Path(args.out)
241
+ out_path.parent.mkdir(parents=True, exist_ok=True)
242
+ out_path.write_text(text + "\n", encoding="utf-8")
243
+ print(f"Wrote -> {out_path}")
244
+ else:
245
+ print(text)
246
+
247
+
248
+ if __name__ == "__main__":
249
+ main()
250
+
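`extract_segment_weights.py` above emits a JSON list with one entry per dataset; each entry carries the `seq`/`row`/`rec` blocks built by `pack()`, each with `cot` and `output` segment summaries. A minimal reading sketch (illustrative only, not part of the committed files; the `--out` path below is hypothetical):

```python
import json
from pathlib import Path

# Hypothetical path: whatever was passed to --out when running extract_segment_weights.py.
report_path = Path("exp/exp3/output/segment_weights.json")
entries = json.loads(report_path.read_text(encoding="utf-8"))

for entry in entries:
    print(entry["dataset"], entry["model_tag"], entry["run_tag"])
    for key in ("seq", "row", "rec"):
        block = entry[key]
        # fraction_of_cot_plus_output is NaN when both segment sums are zero.
        print(
            f"  {key}: cot share={block['cot']['fraction_of_cot_plus_output']:.3f}, "
            f"output share={block['output']['fraction_of_cot_plus_output']:.3f}"
        )
```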
exp/exp3/part_weights.py ADDED
@@ -0,0 +1,228 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Compute attribution mass on (input, cot, output) segments from exp3 trace npz files.
4
+
5
+ Definitions (token-level, aligned with exp2/exp3 runners):
6
+ - input : prompt-side tokens (user prompt), indices [0, prompt_len)
7
+ - cot : generation tokens in thinking span, indices [prompt_len + t0, prompt_len + t1]
8
+ - output : generation tokens in sink span (answer), indices [prompt_len + s0, prompt_len + s1]
9
+
10
+ The trace stores token-importance vectors:
11
+ - v_seq_all, v_row_all, v_rec_all (length = prompt_len + gen_len)
12
+
13
+ This script sums those vectors over each segment and reports both absolute sums
14
+ and fractions of the total sum.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import json
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+ from typing import Dict, Iterable, List, Optional, Tuple
24
+
25
+ import numpy as np
26
+
27
+
28
+ @dataclass(frozen=True)
29
+ class TraceRun:
30
+ dataset: str
31
+ model: str
32
+ run_dir: Path
33
+
34
+
35
+ def _pick_single_subdir(parent: Path) -> Path:
36
+ subdirs = [p for p in parent.iterdir() if p.is_dir()]
37
+ if not subdirs:
38
+ raise FileNotFoundError(f"No subdirectories found under {parent}")
39
+ if len(subdirs) == 1:
40
+ return subdirs[0]
41
+ subdirs.sort(key=lambda p: p.stat().st_mtime, reverse=True)
42
+ return subdirs[0]
43
+
44
+
45
+ def _resolve_run(
46
+ trace_root: Path,
47
+ *,
48
+ dataset: str,
49
+ model: Optional[str],
50
+ run_tag: Optional[str],
51
+ ) -> TraceRun:
52
+ ds_dir = trace_root / dataset
53
+ if not ds_dir.exists():
54
+ raise FileNotFoundError(f"Dataset trace directory not found: {ds_dir}")
55
+
56
+ if model is None:
57
+ model_dir = _pick_single_subdir(ds_dir)
58
+ else:
59
+ model_dir = ds_dir / model
60
+ if not model_dir.exists():
61
+ raise FileNotFoundError(f"Model trace directory not found: {model_dir}")
62
+
63
+ if run_tag is None:
64
+ run_dir = _pick_single_subdir(model_dir)
65
+ else:
66
+ run_dir = model_dir / run_tag
67
+ if not run_dir.exists():
68
+ raise FileNotFoundError(f"Run directory not found: {run_dir}")
69
+
70
+ return TraceRun(dataset=dataset, model=model_dir.name, run_dir=run_dir)
71
+
72
+
73
+ def _iter_manifest(run_dir: Path) -> Iterable[dict]:
74
+ manifest = run_dir / "manifest.jsonl"
75
+ if not manifest.exists():
76
+ raise FileNotFoundError(f"Missing manifest: {manifest}")
77
+ with manifest.open("r", encoding="utf-8") as f:
78
+ for line in f:
79
+ line = line.strip()
80
+ if line:
81
+ yield json.loads(line)
82
+
83
+
84
+ def _as_span(arr: np.ndarray, *, name: str) -> Tuple[int, int]:
85
+ if arr is None:
86
+ raise ValueError(f"Missing {name} in trace npz.")
87
+ a = np.asarray(arr).reshape(-1)
88
+ if a.size != 2:
89
+ raise ValueError(f"Expected {name} to have 2 ints, got shape {a.shape}.")
90
+ return int(a[0]), int(a[1])
91
+
92
+
93
+ def _segment_sums(
94
+ v: np.ndarray,
95
+ *,
96
+ prompt_len: int,
97
+ gen_len: int,
98
+ thinking_span_gen: Optional[Tuple[int, int]],
99
+ sink_span_gen: Optional[Tuple[int, int]],
100
+ ) -> Dict[str, float]:
101
+ total_len = int(prompt_len) + int(gen_len)
102
+ if int(v.shape[0]) != total_len:
103
+ raise ValueError(f"Vector length mismatch: len(v)={int(v.shape[0])} vs prompt_len+gen_len={total_len}.")
104
+
105
+ v = np.asarray(v, dtype=np.float64).reshape(-1)
106
+ prompt_len = int(prompt_len)
107
+ gen_len = int(gen_len)
108
+
109
+ # Default: no cot/output when spans missing (should not happen in exp3).
110
+ think_start, think_end = (0, -1) if thinking_span_gen is None else thinking_span_gen
111
+ sink_start, sink_end = (0, -1) if sink_span_gen is None else sink_span_gen
112
+
113
+ # Clamp spans into [0, gen_len-1].
114
+ def _clamp_span(a: int, b: int) -> Tuple[int, int]:
115
+ a = max(0, min(int(a), gen_len - 1))
116
+ b = max(0, min(int(b), gen_len - 1))
117
+ if b < a:
118
+ return 0, -1
119
+ return a, b
120
+
121
+ think_start, think_end = _clamp_span(think_start, think_end)
122
+ sink_start, sink_end = _clamp_span(sink_start, sink_end)
123
+
124
+ mask = np.zeros((total_len,), dtype=bool)
125
+ # input = all prompt tokens
126
+ input_slice = slice(0, prompt_len)
127
+ mask[input_slice] = True
128
+
129
+ cot_slice = slice(prompt_len + think_start, prompt_len + think_end + 1) if think_end >= think_start else slice(0, 0)
130
+ output_slice = slice(prompt_len + sink_start, prompt_len + sink_end + 1) if sink_end >= sink_start else slice(0, 0)
131
+ mask[cot_slice] = True
132
+ mask[output_slice] = True
133
+
134
+ input_sum = float(v[input_slice].sum())
135
+ cot_sum = float(v[cot_slice].sum()) if think_end >= think_start else 0.0
136
+ output_sum = float(v[output_slice].sum()) if sink_end >= sink_start else 0.0
137
+ other_sum = float(v[~mask].sum())
138
+ total_sum = float(v.sum())
139
+
140
+ return {
141
+ "total": total_sum,
142
+ "input": input_sum,
143
+ "cot": cot_sum,
144
+ "output": output_sum,
145
+ "other": other_sum,
146
+ }
147
+
148
+
149
+ def _with_fracs(sums: Dict[str, float]) -> Dict[str, float]:
150
+ total = float(sums.get("total") or 0.0)
151
+ if total <= 0.0:
152
+ return {**sums, "input_frac": float("nan"), "cot_frac": float("nan"), "output_frac": float("nan"), "other_frac": float("nan")}
153
+ return {
154
+ **sums,
155
+ "input_frac": float(sums["input"]) / total,
156
+ "cot_frac": float(sums["cot"]) / total,
157
+ "output_frac": float(sums["output"]) / total,
158
+ "other_frac": float(sums["other"]) / total,
159
+ }
160
+
161
+
162
+ def _analyze_npz(npz_path: Path) -> Dict[str, dict]:
163
+ d = np.load(npz_path)
164
+ prompt_len = int(np.asarray(d["prompt_len"]).item())
165
+ gen_len = int(np.asarray(d["gen_len"]).item())
166
+ thinking_span_gen = _as_span(d["thinking_span_gen"], name="thinking_span_gen") if "thinking_span_gen" in d.files else None
167
+ sink_span_gen = _as_span(d["sink_span_gen"], name="sink_span_gen") if "sink_span_gen" in d.files else None
168
+
169
+ out: Dict[str, dict] = {"prompt_len": prompt_len, "gen_len": gen_len}
170
+ for key in ("v_seq_all", "v_row_all", "v_rec_all"):
171
+ if key not in d.files:
172
+ raise ValueError(f"Missing {key} in trace npz: {npz_path}")
173
+ sums = _segment_sums(
174
+ d[key],
175
+ prompt_len=prompt_len,
176
+ gen_len=gen_len,
177
+ thinking_span_gen=thinking_span_gen,
178
+ sink_span_gen=sink_span_gen,
179
+ )
180
+ out[key] = _with_fracs(sums)
181
+ out["thinking_span_gen"] = list(thinking_span_gen) if thinking_span_gen is not None else None
182
+ out["sink_span_gen"] = list(sink_span_gen) if sink_span_gen is not None else None
183
+ return out
184
+
185
+
186
+ def main() -> None:
187
+ parser = argparse.ArgumentParser("Summarize input/cot/output attribution mass from exp3 traces.")
188
+ parser.add_argument("--trace_root", type=str, default="exp/exp3/output/traces")
189
+ parser.add_argument("--dataset_tag", type=str, default="niah_mq_q2", help="Base tag; expands to <tag>_short_cot and <tag>_long_cot.")
190
+ parser.add_argument("--datasets", type=str, default=None, help="Comma-separated dataset names (overrides --dataset_tag expansion).")
191
+ parser.add_argument("--model", type=str, default=None, help="Model directory name under traces (default: auto if single).")
192
+ parser.add_argument("--run_tag", type=str, default=None, help="Run tag directory (default: auto pick newest/single).")
193
+ args = parser.parse_args()
194
+
195
+ trace_root = Path(args.trace_root)
196
+ if not trace_root.exists():
197
+ raise SystemExit(f"trace_root not found: {trace_root}")
198
+
199
+ if args.datasets:
200
+ datasets = [x.strip() for x in str(args.datasets).split(",") if x.strip()]
201
+ else:
202
+ datasets = [f"{args.dataset_tag}_short_cot", f"{args.dataset_tag}_long_cot"]
203
+
204
+ for ds in datasets:
205
+ run = _resolve_run(trace_root, dataset=ds, model=args.model, run_tag=args.run_tag)
206
+ records = list(_iter_manifest(run.run_dir))
207
+ if not records:
208
+ raise SystemExit(f"Empty manifest: {run.run_dir/'manifest.jsonl'}")
209
+ for rec in records:
210
+ npz_path = run.run_dir / str(rec["file"])
211
+ analysis = _analyze_npz(npz_path)
212
+ print(
213
+ json.dumps(
214
+ {
215
+ "dataset": run.dataset,
216
+ "model": run.model,
217
+ "run_dir": str(run.run_dir),
218
+ "example_idx": int(rec.get("example_idx", -1)),
219
+ **analysis,
220
+ },
221
+ ensure_ascii=False,
222
+ )
223
+ )
224
+
225
+
226
+ if __name__ == "__main__":
227
+ main()
228
+
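`part_weights.py` above prints one JSON object per example to stdout. A minimal aggregation sketch (illustrative only, not part of the committed files), assuming the stdout was redirected to a hypothetical `part_weights.jsonl`:

```python
import json
from pathlib import Path

import numpy as np

# Hypothetical file: stdout of part_weights.py redirected via `... > part_weights.jsonl`.
rows = [
    json.loads(line)
    for line in Path("part_weights.jsonl").read_text(encoding="utf-8").splitlines()
    if line.strip()
]

by_dataset = {}
for row in rows:
    by_dataset.setdefault(row["dataset"], []).append(float(row["v_rec_all"]["cot_frac"]))

for dataset, fracs in sorted(by_dataset.items()):
    print(f"{dataset}: mean cot_frac={np.nanmean(fracs):.3f} over {len(fracs)} examples")
```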
exp/exp3/run_exp.py ADDED
@@ -0,0 +1,430 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Experiment 3 runner: long-vs-short CoT case study (AttnLRP hop0, Recovery@10%).
4
+
5
+ This runner is intentionally minimal:
6
+ - Only reads the two sample caches produced by exp/exp3/sample_and_filter.py:
7
+ <dataset_tag>_short_cot.jsonl
8
+ <dataset_tag>_long_cot.jsonl
9
+ - Only runs attribution method: attnlrp (hop0 path, aligned with exp2).
10
+ - Only computes token-level recovery (Recall@10%) using RULER needle_spans.
11
+ - Always saves per-sample trace artifacts under exp/exp3/output/traces/.
12
+
13
+ All outputs are written under exp/exp3/output/ (configurable via --output_root).
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import hashlib
20
+ import json
21
+ import os
22
+ import sys
23
+ import time
24
+ from itertools import islice
25
+ from pathlib import Path
26
+ from typing import Any, Dict, List, Optional, Tuple
27
+
28
+
29
+ def _early_set_cuda_visible_devices() -> None:
30
+ parser = argparse.ArgumentParser(add_help=False)
31
+ parser.add_argument("--cuda", type=str, default=None)
32
+ args, _ = parser.parse_known_args(sys.argv[1:])
33
+ if args.cuda and "," in args.cuda:
34
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
35
+
36
+
37
+ _early_set_cuda_visible_devices()
38
+
39
+ import numpy as np
40
+ import torch
41
+ from transformers import AutoModelForCausalLM, AutoTokenizer, utils
42
+
43
+ REPO_ROOT = Path(__file__).resolve().parents[2]
44
+ if str(REPO_ROOT) not in sys.path:
45
+ sys.path.insert(0, str(REPO_ROOT))
46
+
47
+ import llm_attr
48
+ import llm_attr_eval
49
+ from exp.exp2 import dataset_utils as ds_utils
50
+
51
+ utils.logging.set_verbosity_error()
52
+
53
+
54
+ def _sha1_text(text: str) -> str:
55
+ return hashlib.sha1(text.encode("utf-8")).hexdigest()
56
+
57
+
58
+ def _token_importance_vector(attr: torch.Tensor) -> np.ndarray:
59
+ w = torch.nan_to_num(attr.sum(0).to(dtype=torch.float32), nan=0.0).clamp(min=0.0)
60
+ return w.detach().cpu().numpy().astype(np.float32, copy=False)
61
+
62
+
63
+ def _trace_run_tag(*, neg_handling: str, norm_mode: str, total: int) -> str:
64
+ return f"attnlrp_neg{neg_handling}_norm{norm_mode}_recovery_{int(total)}ex"
65
+
66
+
67
+ def _build_sample_trace_payload(
68
+ example: ds_utils.CachedExample,
69
+ *,
70
+ seq_attr: torch.Tensor,
71
+ row_attr: torch.Tensor,
72
+ rec_attr: torch.Tensor,
73
+ prompt_len: int,
74
+ user_prompt_indices: Optional[List[int]],
75
+ gold_prompt_token_indices: Optional[List[int]],
76
+ recovery_scores: Optional[np.ndarray],
77
+ time_attr_s: Optional[float],
78
+ time_recovery_s: Optional[float],
79
+ ) -> Dict[str, np.ndarray]:
80
+ gen_len = int(seq_attr.shape[0])
81
+
82
+ v_seq_all = _token_importance_vector(seq_attr)
83
+ v_row_all = _token_importance_vector(row_attr)
84
+ v_rec_all = _token_importance_vector(rec_attr)
85
+
86
+ payload: Dict[str, np.ndarray] = {
87
+ "v_seq_all": v_seq_all,
88
+ "v_row_all": v_row_all,
89
+ "v_rec_all": v_rec_all,
90
+ "v_seq_prompt": v_seq_all[:prompt_len],
91
+ "v_row_prompt": v_row_all[:prompt_len],
92
+ "v_rec_prompt": v_rec_all[:prompt_len],
93
+ "prompt_len": np.asarray(int(prompt_len), dtype=np.int64),
94
+ "gen_len": np.asarray(int(gen_len), dtype=np.int64),
95
+ "indices_to_explain_gen": np.asarray(list(example.indices_to_explain or []), dtype=np.int64),
96
+ }
97
+
98
+ if example.sink_span is not None:
99
+ payload["sink_span_gen"] = np.asarray(list(example.sink_span), dtype=np.int64)
100
+ if example.thinking_span is not None:
101
+ payload["thinking_span_gen"] = np.asarray(list(example.thinking_span), dtype=np.int64)
102
+
103
+ if user_prompt_indices is not None:
104
+ payload["user_prompt_indices"] = np.asarray(list(user_prompt_indices), dtype=np.int64)
105
+ if gold_prompt_token_indices is not None:
106
+ payload["gold_prompt_token_indices"] = np.asarray(list(gold_prompt_token_indices), dtype=np.int64)
107
+
108
+ if recovery_scores is not None:
109
+ payload["recovery_scores"] = np.asarray(recovery_scores, dtype=np.float64)
110
+
111
+ if time_attr_s is not None:
112
+ payload["time_attr_s"] = np.asarray(float(time_attr_s), dtype=np.float64)
113
+ if time_recovery_s is not None:
114
+ payload["time_recovery_s"] = np.asarray(float(time_recovery_s), dtype=np.float64)
115
+
116
+ return payload
117
+
118
+
119
+ def _write_sample_trace(
120
+ trace_dir: Path,
121
+ *,
122
+ example_idx: int,
123
+ prompt: str,
124
+ target: str,
125
+ payload: Dict[str, np.ndarray],
126
+ manifest_handle,
127
+ neg_handling: str,
128
+ norm_mode: str,
129
+ recovery_skipped_reason: Optional[str],
130
+ ) -> None:
131
+ trace_dir.mkdir(parents=True, exist_ok=True)
132
+ npz_name = f"ex_{example_idx:06d}.npz"
133
+ npz_path = trace_dir / npz_name
134
+ np.savez_compressed(npz_path, **payload)
135
+
136
+ prompt_len = int(np.asarray(payload.get("prompt_len", 0)).item())
137
+ gen_len = int(np.asarray(payload.get("gen_len", 0)).item())
138
+ record: Dict[str, Any] = {
139
+ "example_idx": int(example_idx),
140
+ "attr_func": "attnlrp",
141
+ "file": npz_name,
142
+ "prompt_sha1": _sha1_text(prompt),
143
+ "target_sha1": _sha1_text(target),
144
+ "prompt_len": prompt_len,
145
+ "gen_len": gen_len,
146
+ "indices_to_explain_gen": payload.get("indices_to_explain_gen").tolist()
147
+ if payload.get("indices_to_explain_gen") is not None
148
+ else None,
149
+ "sink_span_gen": payload.get("sink_span_gen").tolist() if payload.get("sink_span_gen") is not None else None,
150
+ "thinking_span_gen": payload.get("thinking_span_gen").tolist()
151
+ if payload.get("thinking_span_gen") is not None
152
+ else None,
153
+ "gold_prompt_token_indices": payload.get("gold_prompt_token_indices").tolist()
154
+ if payload.get("gold_prompt_token_indices") is not None
155
+ else None,
156
+ "recovery_scores": payload.get("recovery_scores").tolist() if payload.get("recovery_scores") is not None else None,
157
+ "recovery_skipped_reason": recovery_skipped_reason,
158
+ "time_attr_s": float(np.asarray(payload.get("time_attr_s")).item()) if payload.get("time_attr_s") is not None else None,
159
+ "time_recovery_s": float(np.asarray(payload.get("time_recovery_s")).item())
160
+ if payload.get("time_recovery_s") is not None
161
+ else None,
162
+ "attnlrp_neg_handling": str(neg_handling),
163
+ "attnlrp_norm_mode": str(norm_mode),
164
+ }
165
+ manifest_handle.write(json.dumps(record, ensure_ascii=False) + "\n")
166
+ manifest_handle.flush()
167
+
168
+
169
+ def resolve_device(args) -> str:
170
+ if args.cuda is not None and "," in args.cuda:
171
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
172
+ return "auto"
173
+ if args.cuda is not None and str(args.cuda).strip():
174
+ return f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
175
+ return f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"
176
+
177
+
178
+ def load_model(model_name: str, device: str):
179
+ model = AutoModelForCausalLM.from_pretrained(
180
+ model_name,
181
+ device_map="auto" if device == "auto" else {"": int(device.split(":")[1])} if device.startswith("cuda:") else None,
182
+ torch_dtype=torch.float16,
183
+ attn_implementation="eager",
184
+ )
185
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
186
+ tokenizer.pad_token = tokenizer.eos_token
187
+ model.eval()
188
+ return model, tokenizer
189
+
190
+
191
+ def _evaluate_one_dataset(
192
+ *,
193
+ dataset_name: str,
194
+ examples: List[ds_utils.CachedExample],
195
+ model,
196
+ tokenizer,
197
+ output_root: Path,
198
+ model_tag: str,
199
+ neg_handling: str,
200
+ norm_mode: str,
201
+ top_fraction: float,
202
+ num_examples: int,
203
+ ) -> Tuple[np.ndarray, np.ndarray, float, int, int]:
204
+ llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(model, tokenizer)
205
+
206
+ results: List[np.ndarray] = []
207
+ durations: List[float] = []
208
+ skipped = 0
209
+
210
+ total = min(len(examples), int(num_examples))
211
+ iterator = islice(examples, total)
212
+
213
+ run_tag = _trace_run_tag(neg_handling=neg_handling, norm_mode=norm_mode, total=total)
214
+ trace_dir = output_root / "traces" / dataset_name / model_tag / run_tag
215
+ trace_dir.mkdir(parents=True, exist_ok=True)
216
+ manifest_handle = open(trace_dir / "manifest.jsonl", "w", encoding="utf-8")
217
+
218
+ try:
219
+ for example_idx, ex in enumerate(iterator):
220
+ time_recovery_s: Optional[float] = None
221
+ recovery_scores: Optional[np.ndarray] = None
222
+
223
+ needle_spans = (ex.metadata or {}).get("needle_spans")
224
+ if not isinstance(needle_spans, list) or not needle_spans:
225
+ raise SystemExit(
226
+ "exp3 recovery requires RULER samples with metadata.needle_spans; "
227
+ f"dataset={dataset_name} has missing/empty needle_spans."
228
+ )
229
+ if ex.target is None:
230
+ raise SystemExit(
231
+ "exp3 recovery requires cached targets (CoT+answer) so row/rec attribution is well-defined. "
232
+ f"dataset={dataset_name} has target=None; run exp/exp3/sample_and_filter.py first."
233
+ )
234
+ if not (isinstance(ex.indices_to_explain, list) and len(ex.indices_to_explain) == 2):
235
+ raise SystemExit(
236
+ "exp3 expects indices_to_explain=[start_tok,end_tok] in generation-token coordinates; "
237
+ f"dataset={dataset_name} has indices_to_explain={ex.indices_to_explain!r}; "
238
+ "run exp/exp3/sample_and_filter.py first."
239
+ )
240
+
241
+ gold_prompt = ds_utils.ruler_gold_prompt_token_indices(ex, tokenizer)
242
+ recovery_skip_reason: Optional[str] = None
243
+
244
+ sample_start = time.perf_counter()
245
+ llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
246
+ attr_result = llm_attributor.calculate_attnlrp_ft_hop0(
247
+ ex.prompt,
248
+ target=ex.target,
249
+ sink_span=tuple(ex.sink_span) if ex.sink_span else None,
250
+ thinking_span=tuple(ex.thinking_span) if ex.thinking_span else None,
251
+ neg_handling=str(neg_handling),
252
+ norm_mode=str(norm_mode),
253
+ )
254
+ seq_attr, row_attr, rec_attr = attr_result.get_all_token_attrs(list(ex.indices_to_explain))
255
+ time_attr_s = time.perf_counter() - sample_start
256
+ durations.append(float(time_attr_s))
257
+
258
+ prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0])
259
+ if prompt_len <= 0:
260
+ recovery_skip_reason = "empty_prompt_len"
261
+ elif not gold_prompt:
262
+ recovery_skip_reason = "empty_gold_prompt"
263
+ else:
264
+ t2 = time.perf_counter()
265
+ recovery_scores = np.asarray(
266
+ [
267
+ llm_evaluator.evaluate_attr_recovery(
268
+ a,
269
+ prompt_len=prompt_len,
270
+ gold_prompt_token_indices=gold_prompt,
271
+ top_fraction=top_fraction,
272
+ )
273
+ for a in (seq_attr, row_attr, rec_attr)
274
+ ],
275
+ dtype=np.float64,
276
+ )
277
+ time_recovery_s = time.perf_counter() - t2
278
+ if np.isnan(recovery_scores).any():
279
+ recovery_scores = None
280
+ recovery_skip_reason = "nan_recovery"
281
+
282
+ if recovery_scores is None and recovery_skip_reason is not None:
283
+ skipped += 1
284
+ elif recovery_scores is not None:
285
+ results.append(recovery_scores)
286
+
287
+ payload = _build_sample_trace_payload(
288
+ ex,
289
+ seq_attr=seq_attr,
290
+ row_attr=row_attr,
291
+ rec_attr=rec_attr,
292
+ prompt_len=prompt_len,
293
+ user_prompt_indices=getattr(llm_attributor, "user_prompt_indices", None),
294
+ gold_prompt_token_indices=gold_prompt,
295
+ recovery_scores=recovery_scores,
296
+ time_attr_s=time_attr_s,
297
+ time_recovery_s=time_recovery_s,
298
+ )
299
+ _write_sample_trace(
300
+ trace_dir,
301
+ example_idx=example_idx,
302
+ prompt=ex.prompt,
303
+ target=str(ex.target),
304
+ payload=payload,
305
+ manifest_handle=manifest_handle,
306
+ neg_handling=str(neg_handling),
307
+ norm_mode=str(norm_mode),
308
+ recovery_skipped_reason=recovery_skip_reason,
309
+ )
310
+ finally:
311
+ try:
312
+ manifest_handle.close()
313
+ except Exception:
314
+ pass
315
+
316
+ scores = np.stack(results, axis=0) if results else np.zeros((0, 3), dtype=np.float64)
317
+ used = int(scores.shape[0])
318
+ mean = scores.mean(0) if used else np.full((3,), np.nan, dtype=np.float64)
319
+ std = scores.std(0) if used else np.full((3,), np.nan, dtype=np.float64)
320
+ avg_time = float(np.mean(durations)) if durations else 0.0
321
+ return mean, std, avg_time, used, int(skipped)
322
+
323
+
324
+ def main() -> None:
325
+ parser = argparse.ArgumentParser("Experiment 3 runner (attnlrp hop0, recovery only).")
326
+ parser.add_argument("--dataset_tag", type=str, default="niah_mq_q2", help="Base tag for exp3 caches.")
327
+ parser.add_argument("--data_root", type=str, default="exp/exp3/data")
328
+ parser.add_argument("--output_root", type=str, default="exp/exp3/output")
329
+ parser.add_argument("--num_examples", type=int, default=1, help="How many examples to evaluate per dataset (default 1).")
330
+ parser.add_argument("--seed", type=int, default=42)
331
+ parser.add_argument("--model", type=str, default=None, help="HF repo id (required unless --model_path set).")
332
+ parser.add_argument("--model_path", type=str, default=None, help="Local path; overrides --model for loading.")
333
+ parser.add_argument("--cuda_num", type=int, default=0)
334
+ parser.add_argument("--cuda", type=str, default=None)
335
+ parser.add_argument("--top_fraction", type=float, default=0.1, help="Top fraction of prompt tokens used for recovery.")
336
+ parser.add_argument(
337
+ "--attnlrp_neg_handling",
338
+ type=str,
339
+ choices=["drop", "abs"],
340
+ default="drop",
341
+ help="AttnLRP hop0: how to handle negative values (drop=clamp>=0, abs=absolute value).",
342
+ )
343
+ parser.add_argument(
344
+ "--attnlrp_norm_mode",
345
+ type=str,
346
+ choices=["norm", "no_norm"],
347
+ default="norm",
348
+ help="AttnLRP hop0: norm enables internal normalization; no_norm disables it.",
349
+ )
350
+ args = parser.parse_args()
351
+
352
+ if args.model_path:
353
+ model_name = args.model_path
354
+ elif args.model:
355
+ model_name = args.model
356
+ else:
357
+ raise SystemExit("Please set --model or --model_path.")
358
+ model_tag = args.model if args.model else Path(args.model_path).name
359
+
360
+ device = resolve_device(args)
361
+ model, tokenizer = load_model(model_name, device)
362
+
363
+ data_root = Path(args.data_root)
364
+ output_root = Path(args.output_root)
365
+ output_root.mkdir(parents=True, exist_ok=True)
366
+
367
+ short_name = f"{args.dataset_tag}_short_cot"
368
+ long_name = f"{args.dataset_tag}_long_cot"
369
+ dataset_names = [short_name, long_name]
370
+
371
+ summary_rows: List[Dict[str, Any]] = []
372
+
373
+ for ds_name in dataset_names:
374
+ cache_path = data_root / f"{ds_name}.jsonl"
375
+ if not cache_path.exists():
376
+ raise SystemExit(f"Missing exp3 cache: {cache_path}. Run exp/exp3/sample_and_filter.py first.")
377
+ examples = ds_utils.load_cached(cache_path, sample=None, seed=args.seed)
378
+
379
+ mean, std, avg_time, used, skipped = _evaluate_one_dataset(
380
+ dataset_name=ds_name,
381
+ examples=examples,
382
+ model=model,
383
+ tokenizer=tokenizer,
384
+ output_root=output_root,
385
+ model_tag=model_tag,
386
+ neg_handling=args.attnlrp_neg_handling,
387
+ norm_mode=args.attnlrp_norm_mode,
388
+ top_fraction=float(args.top_fraction),
389
+ num_examples=int(args.num_examples),
390
+ )
391
+
392
+ out_dir = output_root / "recovery" / ds_name / model_tag
393
+ out_dir.mkdir(parents=True, exist_ok=True)
394
+ filename = f"attnlrp_{int(args.num_examples)}_examples.csv"
395
+ with (out_dir / filename).open("w", encoding="utf-8") as f:
396
+ f.write("Method,Recovery@10%\n")
397
+ f.write(f"Seq Attr Recovery Mean,{mean[0]}\n")
398
+ f.write(f"Row Attr Recovery Mean,{mean[1]}\n")
399
+ f.write(f"Recursive Attr Recovery Mean,{mean[2]}\n")
400
+ f.write(f"Seq Attr Recovery Std,{std[0]}\n")
401
+ f.write(f"Row Attr Recovery Std,{std[1]}\n")
402
+ f.write(f"Recursive Attr Recovery Std,{std[2]}\n")
403
+ f.write(f"Examples Used,{used}\n")
404
+ f.write(f"Examples Skipped,{skipped}\n")
405
+ f.write(f"Avg Sample Time (s),{avg_time}\n")
406
+
407
+ print(f"[{ds_name}] attnlrp -> {out_dir/filename} (used={used} skipped={skipped} avg {avg_time:.2f}s)")
408
+ summary_rows.append(
409
+ {
410
+ "dataset": ds_name,
411
+ "model": model_tag,
412
+ "neg_handling": args.attnlrp_neg_handling,
413
+ "norm_mode": args.attnlrp_norm_mode,
414
+ "seq_recovery@10%": float(mean[0]) if used else float("nan"),
415
+ "row_recovery@10%": float(mean[1]) if used else float("nan"),
416
+ "rec_recovery@10%": float(mean[2]) if used else float("nan"),
417
+ "used": int(used),
418
+ "skipped": int(skipped),
419
+ }
420
+ )
421
+
422
+ # Lightweight combined summary for quick comparison.
423
+ summary_path = output_root / "recovery" / f"summary_{model_tag}.json"
424
+ summary_path.parent.mkdir(parents=True, exist_ok=True)
425
+ summary_path.write_text(json.dumps(summary_rows, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
426
+ print(f"Wrote summary -> {summary_path}")
427
+
428
+
429
+ if __name__ == "__main__":
430
+ main()
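Each example evaluated by the runner above leaves an `ex_XXXXXX.npz` trace plus a `manifest.jsonl` record under `exp/exp3/output/traces/<dataset>/<model_tag>/<run_tag>/`. A minimal inspection sketch (illustrative only, not part of the committed files; the concrete path is a hypothetical example):

```python
from pathlib import Path

import numpy as np

# Hypothetical trace path; substitute your dataset, model tag, and run tag.
npz_path = Path(
    "exp/exp3/output/traces/niah_mq_q2_long_cot/qwen-8B/"
    "attnlrp_negdrop_normnorm_recovery_1ex/ex_000000.npz"
)
trace = np.load(npz_path)

prompt_len = int(trace["prompt_len"])
gen_len = int(trace["gen_len"])
print(f"prompt_len={prompt_len} gen_len={gen_len} keys={sorted(trace.files)}")

# Each importance vector covers prompt + generation tokens.
for key in ("v_seq_all", "v_row_all", "v_rec_all"):
    v = np.asarray(trace[key], dtype=np.float64)
    assert v.shape[0] == prompt_len + gen_len
    prompt_share = float(v[:prompt_len].sum() / max(float(v.sum()), 1e-12))
    print(f"{key}: prompt mass share = {prompt_share:.3f}")
```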
exp/exp3/sample_and_filter.py ADDED
@@ -0,0 +1,628 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Experiment 3 sampler: long-vs-short CoT case study (RULER niah_mq_q2, 1024).
4
+
5
+ This script scans the raw RULER JSONL and independently collects short-CoT and
+ long-CoT generations (not necessarily for the same prompt). Each kept generation must:
+   - follow the strict format: "<thinking text> + final \\box{...} answer" with
+     nothing after the box
+   - pass a judge model verifying the boxed answer matches the reference answer
+   - satisfy the length constraints (short <= max_short_thinking_tokens,
+     long >= min_long_thinking_tokens)
12
+
13
+ It writes two exp2-compatible cache JSONLs to exp/exp3/data/:
14
+ - <dataset_tag>_short_cot.jsonl
15
+ - <dataset_tag>_long_cot.jsonl
16
+
17
+ Each JSONL line matches exp/exp2/dataset_utils.CachedExample schema and keeps
18
+ RULER metadata. The output caches are intended to be consumed by exp/exp3/run_exp.py.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import argparse
24
+ import hashlib
25
+ import json
26
+ import os
27
+ import sys
28
+ import time
29
+ import urllib.error
30
+ import urllib.request
31
+ from dataclasses import dataclass
32
+ from pathlib import Path
33
+ from typing import Any, Dict, Iterable, List, Optional
34
+
35
+ from tqdm import tqdm
36
+ from transformers import AutoTokenizer
37
+
38
+ REPO_ROOT = Path(__file__).resolve().parents[2]
39
+ if str(REPO_ROOT) not in sys.path:
40
+ sys.path.insert(0, str(REPO_ROOT))
41
+
42
+ from exp.exp2 import dataset_utils as ds_utils
43
+ from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer, split_boxed_generation
44
+
45
+
46
+ class RateLimitError(RuntimeError):
47
+ """Raised when API returns 429; carries a suggested wait time."""
48
+
49
+ def __init__(self, wait_seconds: float, detail: str) -> None:
50
+ super().__init__(detail)
51
+ self.wait_seconds = wait_seconds
52
+
53
+
54
+ SHORT_COT_SYSTEM_PROMPT = (
55
+ "You are a reasoning assistant. "
56
+ "Before answering, engage in a brief chain of thought. "
57
+ "Process this freely and naturally without using specific headers or strict formatting. "
58
+ "When you reach the conclusion, wrap the entire final sentence containing the answer inside \\box{}. "
59
+ "Ensure the box wraps the **sentence** that naturally delivers the answer. "
60
+ "Do not add anything after the box."
61
+ )
62
+
63
+ LONG_COT_SYSTEM_PROMPT = (
64
+ "You are a careful reasoning assistant. "
65
+ "Before answering, engage in an extremely detailed and exhaustive chain of thought. "
66
+ "Do not skip any logical steps, even if they seem obvious. "
67
+ "Process this freely and naturally without using specific headers or strict formatting. "
68
+ "When you reach the conclusion, wrap the entire final sentence containing the answer inside \\box{}. "
69
+ "Ensure the box wraps the **sentence** that naturally delivers the answer. "
70
+ "Do not add anything after the box."
71
+ )
72
+
73
+ JUDGE_SYSTEM_PROMPT = (
74
+ "You verify whether the model's boxed answer matches the reference answer. "
75
+ "Reply strictly with True or False and nothing else."
76
+ )
77
+
78
+
79
+ def _sha1_text(text: str) -> str:
80
+ return hashlib.sha1(text.encode("utf-8")).hexdigest()
81
+
82
+
83
+ def call_chat_api(
84
+ api_base: str,
85
+ api_key: str,
86
+ model: str,
87
+ messages: List[Dict[str, str]],
88
+ *,
89
+ timeout: int,
90
+ max_tokens: int,
91
+ temperature: float,
92
+ cache_ttl: int,
93
+ cache_namespace: Optional[str],
94
+ rate_limit_delay: Optional[float] = None,
95
+ ) -> str:
96
+ url = api_base.rstrip("/") + "/chat/completions"
97
+ payload: Dict[str, Any] = {
98
+ "model": model,
99
+ "messages": messages,
100
+ "max_tokens": max_tokens,
101
+ "temperature": temperature,
102
+ }
103
+ if cache_ttl > 0:
104
+ cache_obj: Dict[str, Any] = {"ttl": cache_ttl}
105
+ if cache_namespace:
106
+ cache_obj["namespace"] = cache_namespace
107
+ payload["cache"] = cache_obj
108
+
109
+ data = json.dumps(payload).encode("utf-8")
110
+ headers = {"Content-Type": "application/json"}
111
+ if api_key:
112
+ headers["Authorization"] = f"Bearer {api_key}"
113
+
114
+ req = urllib.request.Request(url, data=data, headers=headers, method="POST")
115
+ opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
116
+ try:
117
+ with opener.open(req, timeout=timeout) as resp:
118
+ resp_bytes = resp.read()
119
+ except urllib.error.HTTPError as e:
120
+ detail = e.read().decode("utf-8", errors="ignore") if hasattr(e, "read") else ""
121
+ if e.code == 429:
122
+ retry_after = None
123
+ if hasattr(e, "headers") and e.headers:
124
+ retry_after_header = e.headers.get("Retry-After")
125
+ if retry_after_header:
126
+ try:
127
+ retry_after = float(retry_after_header)
128
+ except ValueError:
129
+ retry_after = None
130
+ wait = retry_after or rate_limit_delay or 5.0
131
+ raise RateLimitError(wait, f"API HTTP 429: {detail}") from e
132
+ raise RuntimeError(f"API HTTP error {e.code}: {detail}") from e
133
+ except urllib.error.URLError as e:
134
+ raise RuntimeError(f"API request failed: {e}") from e
135
+
136
+ try:
137
+ response = json.loads(resp_bytes.decode("utf-8"))
138
+ except json.JSONDecodeError as e:
139
+ raise RuntimeError(f"Failed to decode API response: {resp_bytes!r}") from e
140
+
141
+ choices = response.get("choices", [])
142
+ if not choices:
143
+ raise RuntimeError(f"Empty choices from API: {response}")
144
+ content = choices[0].get("message", {}).get("content", "")
145
+ if not content:
146
+ raise RuntimeError(f"Empty content from API: {response}")
147
+ return content.strip()
148
+
149
+
150
+ def _call_with_retries(
151
+ *,
152
+ api_base: str,
153
+ api_key: str,
154
+ model: str,
155
+ messages: List[Dict[str, str]],
156
+ timeout: int,
157
+ max_tokens: int,
158
+ temperature: float,
159
+ cache_ttl: int,
160
+ cache_namespace: Optional[str],
161
+ rate_limit_delay: float,
162
+ retries: int,
163
+ retry_delay: float,
164
+ ) -> str:
165
+ for attempt in range(retries + 1):
166
+ try:
167
+ return call_chat_api(
168
+ api_base,
169
+ api_key,
170
+ model,
171
+ messages,
172
+ timeout=timeout,
173
+ max_tokens=max_tokens,
174
+ temperature=temperature,
175
+ cache_ttl=cache_ttl,
176
+ cache_namespace=cache_namespace,
177
+ rate_limit_delay=rate_limit_delay,
178
+ )
179
+ except RateLimitError as e:
180
+ if attempt >= retries:
181
+ raise
182
+ time.sleep(e.wait_seconds)
183
+ except Exception: # noqa: BLE001
184
+ if attempt >= retries:
185
+ raise
186
+ time.sleep(retry_delay)
187
+ raise RuntimeError("Unreachable")
188
+
189
+
190
+ def build_gen_messages(prompt: str, system_prompt: str) -> List[Dict[str, str]]:
191
+ return [
192
+ {"role": "system", "content": system_prompt},
193
+ {"role": "user", "content": prompt},
194
+ ]
195
+
196
+
197
+ def build_judge_messages(reference_answer: str, candidate_answer: str) -> List[Dict[str, str]]:
198
+ user = (
199
+ "Decide if the model's boxed answer matches the reference answer.\n"
200
+ f"Reference answer: {reference_answer}\n"
201
+ f"Model boxed answer (only the content inside \\box{{}}): {candidate_answer}\n"
202
+ "Output only True if they are semantically consistent; otherwise output False."
203
+ )
204
+ return [
205
+ {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
206
+ {"role": "user", "content": user},
207
+ ]
208
+
209
+
210
+ def parse_bool(text: str) -> bool:
211
+ first = text.strip().splitlines()[0].strip().lower()
212
+ if first in {"true", "yes"}:
213
+ return True
214
+ if first in {"false", "no"}:
215
+ return False
216
+ if "true" in first and "false" not in first:
217
+ return True
218
+ if "false" in first:
219
+ return False
220
+ raise ValueError(f"Cannot parse boolean from: {text!r}")
221
+
222
+
223
+ def write_cache(out_path: Path, examples: Iterable[CachedExample]) -> int:
224
+ out_path.parent.mkdir(parents=True, exist_ok=True)
225
+ count = 0
226
+ with out_path.open("w", encoding="utf-8") as f:
227
+ for ex in examples:
228
+ obj: Dict[str, Any] = {
229
+ "prompt": ex.prompt,
230
+ "target": ex.target,
231
+ "indices_to_explain": ex.indices_to_explain,
232
+ "attr_mask_indices": ex.attr_mask_indices,
233
+ "sink_span": ex.sink_span,
234
+ "thinking_span": ex.thinking_span,
235
+ "metadata": ex.metadata,
236
+ }
237
+ f.write(json.dumps(obj, ensure_ascii=False) + "\n")
238
+ count += 1
239
+ return count
240
+
241
+
242
+ @dataclass(frozen=True)
243
+ class AcceptedGeneration:
244
+ thinking_text: str
245
+ boxed_answer: str
246
+ target_text: str
247
+ thinking_tokens: int
248
+ generation_text: str
249
+ judge_response: str
250
+
251
+
252
+ def _infer_reference_answer(example: CachedExample) -> str:
253
+ meta = example.metadata or {}
254
+ ref = str(meta.get("reference_answer") or "").strip()
255
+ if ref:
256
+ return ref
257
+ outputs = meta.get("outputs") or []
258
+ if isinstance(outputs, list) and outputs:
259
+ return ", ".join(str(x) for x in outputs)
260
+ tgt = str(example.target or "").strip()
261
+ return tgt
262
+
263
+
264
+ def _infer_dataset_tag(dataset_path: Path) -> str:
265
+ if dataset_path.name.endswith(".jsonl") and dataset_path.name != "validation.jsonl":
266
+ return dataset_path.stem
267
+ if dataset_path.name == "validation.jsonl":
268
+ return dataset_path.parent.name
269
+ return dataset_path.stem
270
+
271
+
272
+ def _count_tokens(tokenizer, text: str) -> int:
273
+ return int(len(tokenizer(text, add_special_tokens=False).input_ids))
274
+
275
+
276
+ def _generate_one_style(
277
+ *,
278
+ prompt: str,
279
+ reference_answer: str,
280
+ tokenizer,
281
+ style: str,
282
+ system_prompt: str,
283
+ api_base: str,
284
+ api_key: str,
285
+ generator_model: str,
286
+ judge_model: str,
287
+ timeout: int,
288
+ max_tokens: int,
289
+ temperature: float,
290
+ cache_ttl: int,
291
+ cache_namespace: Optional[str],
292
+ rate_limit_delay: float,
293
+ retries: int,
294
+ retry_delay: float,
295
+ request_interval: float,
296
+ judge_interval: float,
297
+ min_long_thinking_tokens: int,
298
+ max_short_thinking_tokens: int,
299
+ ) -> Optional[AcceptedGeneration]:
300
+ gen_messages = build_gen_messages(prompt, system_prompt)
301
+ generation = _call_with_retries(
302
+ api_base=api_base,
303
+ api_key=api_key,
304
+ model=generator_model,
305
+ messages=gen_messages,
306
+ timeout=timeout,
307
+ max_tokens=max_tokens,
308
+ temperature=temperature,
309
+ cache_ttl=cache_ttl,
310
+ cache_namespace=cache_namespace,
311
+ rate_limit_delay=rate_limit_delay,
312
+ retries=retries,
313
+ retry_delay=retry_delay,
314
+ )
315
+ if request_interval > 0:
316
+ time.sleep(request_interval)
317
+
318
+ parsed = split_boxed_generation(generation)
319
+ if not parsed:
320
+ return None
321
+ thinking_text, _boxed_segment, boxed_answer = parsed
322
+ thinking_tokens = _count_tokens(tokenizer, thinking_text)
323
+
324
+ if style == "short":
325
+ if max_short_thinking_tokens > 0 and thinking_tokens > max_short_thinking_tokens:
326
+ return None
327
+ elif style == "long":
328
+ if min_long_thinking_tokens > 0 and thinking_tokens < min_long_thinking_tokens:
329
+ return None
330
+ else:
331
+ raise ValueError(f"Unsupported style: {style}")
332
+
333
+ judge_messages = build_judge_messages(reference_answer, boxed_answer)
334
+ judge_resp = _call_with_retries(
335
+ api_base=api_base,
336
+ api_key=api_key,
337
+ model=judge_model,
338
+ messages=judge_messages,
339
+ timeout=timeout,
340
+ max_tokens=64,
341
+ temperature=0.0,
342
+ cache_ttl=cache_ttl,
343
+ cache_namespace=cache_namespace,
344
+ rate_limit_delay=rate_limit_delay,
345
+ retries=retries,
346
+ retry_delay=retry_delay,
347
+ )
348
+ if judge_interval > 0:
349
+ time.sleep(judge_interval)
350
+ ok = parse_bool(judge_resp)
351
+ if not ok:
352
+ return None
353
+
354
+ target_text = f"{thinking_text}\n{boxed_answer}" if thinking_text else boxed_answer
355
+ return AcceptedGeneration(
356
+ thinking_text=thinking_text,
357
+ boxed_answer=boxed_answer,
358
+ target_text=target_text,
359
+ thinking_tokens=thinking_tokens,
360
+ generation_text=generation,
361
+ judge_response=judge_resp,
362
+ )
363
+
364
+
365
+ def main() -> None:
366
+ parser = argparse.ArgumentParser("Sample short-CoT and long-CoT cases for exp3 (independently).")
367
+ parser.add_argument(
368
+ "--dataset_path",
369
+ type=str,
370
+ default="data/ruler_multihop/1024/niah_mq_q2/validation.jsonl",
371
+ help="Raw RULER JSONL path (default: niah_mq_q2 1024 validation).",
372
+ )
373
+ parser.add_argument("--dataset_tag", type=str, default=None, help="Output tag; default inferred from dataset_path.")
374
+ parser.add_argument(
375
+ "--max_pairs",
376
+ type=int,
377
+ default=1,
378
+ help="Deprecated alias for --max_short and --max_long (kept for convenience).",
379
+ )
380
+ parser.add_argument("--max_short", type=int, default=None, help="How many short-CoT samples to keep (default: --max_pairs).")
381
+ parser.add_argument("--max_long", type=int, default=None, help="How many long-CoT samples to keep (default: --max_pairs).")
382
+ parser.add_argument("--max_raw_examples", type=int, default=None, help="Optional cap on raw examples to try.")
383
+ parser.add_argument("--seed", type=int, default=42)
384
+ parser.add_argument("--api_base", type=str, default="http://localhost:4000/v1", help="Chat API base URL.")
385
+ parser.add_argument("--api_key", type=str, default=None, help="API key; defaults to FLASHTRACE_API_KEY/OPENAI_API_KEY.")
386
+ parser.add_argument("--generator_model", type=str, default="qwen3-235b-a22b-2507")
387
+ parser.add_argument("--judge_model", type=str, default="deepseek-v3-1-terminus")
388
+ parser.add_argument("--api_timeout", type=int, default=300)
389
+ parser.add_argument("--api_temperature", type=float, default=0.0)
390
+ parser.add_argument("--api_cache_ttl", type=int, default=600)
391
+ parser.add_argument("--api_cache_namespace", type=str, default="flashtrace-exp3")
392
+ parser.add_argument("--retry_delay", type=float, default=2.0)
393
+ parser.add_argument("--retries", type=int, default=2, help="Additional retries on API failure.")
394
+ parser.add_argument("--request_interval", type=float, default=1.0, help="Sleep seconds between generation calls.")
395
+ parser.add_argument("--judge_interval", type=float, default=1.0, help="Sleep seconds between judge calls.")
396
+ parser.add_argument("--rate_limit_delay", type=float, default=5.0, help="Seconds to wait on HTTP 429 before retrying.")
397
+ parser.add_argument(
398
+ "--api_max_tokens_short",
399
+ type=int,
400
+ default=2048,
401
+ help="Max tokens for the short-CoT generation call.",
402
+ )
403
+ parser.add_argument(
404
+ "--api_max_tokens_long",
405
+ type=int,
406
+ default=8192,
407
+ help="Max tokens for the long-CoT generation call.",
408
+ )
409
+ parser.add_argument(
410
+ "--min_long_thinking_tokens",
411
+ type=int,
412
+ default=512,
413
+ help="Minimum tokenizer tokens required in the long-CoT thinking segment.",
414
+ )
415
+ parser.add_argument(
416
+ "--max_short_thinking_tokens",
417
+ type=int,
418
+ default=256,
419
+ help="Maximum tokenizer tokens allowed in the short-CoT thinking segment.",
420
+ )
421
+ parser.add_argument(
422
+ "--tokenizer_model",
423
+ type=str,
424
+ default=None,
425
+ help="Tokenizer path for span extraction & length constraints (default: generator_model).",
426
+ )
427
+ parser.add_argument("--data_root", type=str, default="exp/exp3/data", help="Output directory for exp3 caches.")
428
+ parser.add_argument("--out_short", type=str, default=None, help="Optional explicit output path (short JSONL).")
429
+ parser.add_argument("--out_long", type=str, default=None, help="Optional explicit output path (long JSONL).")
430
+ args = parser.parse_args()
431
+
432
+ api_key = args.api_key or os.environ.get("FLASHTRACE_API_KEY") or os.environ.get("OPENAI_API_KEY")
433
+ if not api_key:
434
+ raise SystemExit("Set --api_key or FLASHTRACE_API_KEY/OPENAI_API_KEY for API access.")
435
+
436
+ dataset_path = Path(args.dataset_path)
437
+ if not dataset_path.exists():
438
+ raise SystemExit(f"Dataset file not found: {dataset_path}")
439
+ dataset_tag = str(args.dataset_tag or _infer_dataset_tag(dataset_path))
440
+
441
+ tok_name = args.tokenizer_model or args.generator_model
442
+ tok_path = Path(tok_name)
443
+ if tok_path.exists():
444
+ tokenizer = AutoTokenizer.from_pretrained(tok_path.as_posix(), local_files_only=True)
445
+ else:
446
+ tokenizer = AutoTokenizer.from_pretrained(tok_name)
447
+ tokenizer.pad_token = tokenizer.eos_token
448
+
449
+ raw_examples = ds_utils.load_ruler(dataset_path, sample=None, seed=args.seed)
450
+ if not raw_examples:
451
+ raise SystemExit("No examples loaded from the RULER JSONL.")
452
+
453
+ max_short = int(args.max_short) if args.max_short is not None else int(args.max_pairs)
454
+ max_long = int(args.max_long) if args.max_long is not None else int(args.max_pairs)
455
+ if max_short < 0 or max_long < 0:
456
+ raise SystemExit("--max_short/--max_long must be >= 0.")
457
+
458
+ kept_short: List[CachedExample] = []
459
+ kept_long: List[CachedExample] = []
460
+
461
+ total = len(raw_examples)
462
+ attempted = 0
463
+
464
+ for idx, ex in enumerate(tqdm(raw_examples, total=total, desc="Scanning raw RULER"), 1):
465
+ attempted = idx
466
+ if args.max_raw_examples is not None and idx > int(args.max_raw_examples):
467
+ break
468
+ if len(kept_short) >= max_short and len(kept_long) >= max_long:
469
+ break
470
+
471
+ reference_answer = _infer_reference_answer(ex)
472
+ prompt = ex.prompt
473
+
474
+ sample_id = _sha1_text(prompt)
475
+ base_meta = dict(ex.metadata or {})
476
+ base_meta["reference_answer"] = reference_answer
477
+ base_meta["sample_id"] = sample_id
478
+ base_meta["pair_id"] = sample_id # backward-compatible name (may not be paired)
479
+ base_meta["source_dataset_path"] = str(dataset_path)
480
+ base_meta["prompt_sha1"] = sample_id
481
+
482
+ if len(kept_short) < max_short:
483
+ short_gen = _generate_one_style(
484
+ prompt=prompt,
485
+ reference_answer=reference_answer,
486
+ tokenizer=tokenizer,
487
+ style="short",
488
+ system_prompt=SHORT_COT_SYSTEM_PROMPT,
489
+ api_base=args.api_base,
490
+ api_key=api_key,
491
+ generator_model=args.generator_model,
492
+ judge_model=args.judge_model,
493
+ timeout=args.api_timeout,
494
+ max_tokens=args.api_max_tokens_short,
495
+ temperature=args.api_temperature,
496
+ cache_ttl=args.api_cache_ttl,
497
+ cache_namespace=args.api_cache_namespace,
498
+ rate_limit_delay=args.rate_limit_delay,
499
+ retries=args.retries,
500
+ retry_delay=args.retry_delay,
501
+ request_interval=args.request_interval,
502
+ judge_interval=args.judge_interval,
503
+ min_long_thinking_tokens=args.min_long_thinking_tokens,
504
+ max_short_thinking_tokens=args.max_short_thinking_tokens,
505
+ )
506
+ if short_gen is not None:
507
+ short_meta = dict(base_meta)
508
+ short_meta.update(
509
+ {
510
+ "cot_style": "short",
511
+ "generator_model": args.generator_model,
512
+ "judge_model": args.judge_model,
513
+ "judge_response": short_gen.judge_response,
514
+ "boxed_answer": short_gen.boxed_answer,
515
+ "thinking_tokens": int(short_gen.thinking_tokens),
516
+ }
517
+ )
518
+ short_ex = CachedExample(
519
+ prompt=prompt,
520
+ target=short_gen.target_text,
521
+ indices_to_explain=None,
522
+ attr_mask_indices=ex.attr_mask_indices,
523
+ sink_span=None,
524
+ thinking_span=None,
525
+ metadata=short_meta,
526
+ )
527
+ short_ex = attach_spans_from_answer(short_ex, tokenizer, short_gen.boxed_answer)
528
+ if isinstance(short_ex.sink_span, list) and len(short_ex.sink_span) == 2:
529
+ short_ex = CachedExample(
530
+ prompt=short_ex.prompt,
531
+ target=short_ex.target,
532
+ indices_to_explain=short_ex.sink_span,
533
+ attr_mask_indices=short_ex.attr_mask_indices,
534
+ sink_span=short_ex.sink_span,
535
+ thinking_span=short_ex.thinking_span,
536
+ metadata=short_ex.metadata,
537
+ )
538
+ kept_short.append(short_ex)
539
+ print(
540
+ f"[kept short] raw_idx={idx}/{total} thinking_tokens={short_gen.thinking_tokens} "
541
+ f"sample_id={sample_id[:8]} kept={len(kept_short)}/{max_short}"
542
+ )
543
+
544
+ if len(kept_long) < max_long:
545
+ long_gen = _generate_one_style(
546
+ prompt=prompt,
547
+ reference_answer=reference_answer,
548
+ tokenizer=tokenizer,
549
+ style="long",
550
+ system_prompt=LONG_COT_SYSTEM_PROMPT,
551
+ api_base=args.api_base,
552
+ api_key=api_key,
553
+ generator_model=args.generator_model,
554
+ judge_model=args.judge_model,
555
+ timeout=args.api_timeout,
556
+ max_tokens=args.api_max_tokens_long,
557
+ temperature=args.api_temperature,
558
+ cache_ttl=args.api_cache_ttl,
559
+ cache_namespace=args.api_cache_namespace,
560
+ rate_limit_delay=args.rate_limit_delay,
561
+ retries=args.retries,
562
+ retry_delay=args.retry_delay,
563
+ request_interval=args.request_interval,
564
+ judge_interval=args.judge_interval,
565
+ min_long_thinking_tokens=args.min_long_thinking_tokens,
566
+ max_short_thinking_tokens=args.max_short_thinking_tokens,
567
+ )
568
+ if long_gen is not None:
569
+ long_meta = dict(base_meta)
570
+ long_meta.update(
571
+ {
572
+ "cot_style": "long",
573
+ "generator_model": args.generator_model,
574
+ "judge_model": args.judge_model,
575
+ "judge_response": long_gen.judge_response,
576
+ "boxed_answer": long_gen.boxed_answer,
577
+ "thinking_tokens": int(long_gen.thinking_tokens),
578
+ }
579
+ )
580
+ long_ex = CachedExample(
581
+ prompt=prompt,
582
+ target=long_gen.target_text,
583
+ indices_to_explain=None,
584
+ attr_mask_indices=ex.attr_mask_indices,
585
+ sink_span=None,
586
+ thinking_span=None,
587
+ metadata=long_meta,
588
+ )
589
+ long_ex = attach_spans_from_answer(long_ex, tokenizer, long_gen.boxed_answer)
590
+ if isinstance(long_ex.sink_span, list) and len(long_ex.sink_span) == 2:
591
+ long_ex = CachedExample(
592
+ prompt=long_ex.prompt,
593
+ target=long_ex.target,
594
+ indices_to_explain=long_ex.sink_span,
595
+ attr_mask_indices=long_ex.attr_mask_indices,
596
+ sink_span=long_ex.sink_span,
597
+ thinking_span=long_ex.thinking_span,
598
+ metadata=long_ex.metadata,
599
+ )
600
+ kept_long.append(long_ex)
601
+ print(
602
+ f"[kept long] raw_idx={idx}/{total} thinking_tokens={long_gen.thinking_tokens} "
603
+ f"sample_id={sample_id[:8]} kept={len(kept_long)}/{max_long}"
604
+ )
605
+
606
+ data_root = Path(args.data_root)
607
+ out_short = Path(args.out_short) if args.out_short else data_root / f"{dataset_tag}_short_cot.jsonl"
608
+ out_long = Path(args.out_long) if args.out_long else data_root / f"{dataset_tag}_long_cot.jsonl"
609
+
610
+ n_short = write_cache(out_short, kept_short)
611
+ n_long = write_cache(out_long, kept_long)
612
+ print(
613
+ f"Wrote short={n_short} -> {out_short}\n"
614
+ f"Wrote long ={n_long} -> {out_long}\n"
615
+ f"Attempted {attempted} / {total}"
616
+ )
617
+
618
+ missing: List[str] = []
619
+ if len(kept_short) < max_short:
620
+ missing.append(f"short({len(kept_short)}/{max_short})")
621
+ if len(kept_long) < max_long:
622
+ missing.append(f"long({len(kept_long)}/{max_long})")
623
+ if missing:
624
+ raise SystemExit(f"Could not find enough samples: {', '.join(missing)} (attempted {attempted} / {total}).")
625
+
626
+
627
+ if __name__ == "__main__":
628
+ main()
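The caches written by `write_cache` above are plain JSONL, one `CachedExample` per line. A minimal sanity-check sketch (illustrative only, not part of the committed files), assuming the default `--data_root` and `--dataset_tag`:

```python
import json
from pathlib import Path

# Default output locations from sample_and_filter.py (niah_mq_q2 tag).
for name in ("niah_mq_q2_short_cot.jsonl", "niah_mq_q2_long_cot.jsonl"):
    path = Path("exp/exp3/data") / name
    for line in path.read_text(encoding="utf-8").splitlines():
        ex = json.loads(line)
        meta = ex.get("metadata") or {}
        print(
            f"{name}: cot_style={meta.get('cot_style')} "
            f"thinking_tokens={meta.get('thinking_tokens')} "
            f"sink_span={ex.get('sink_span')} thinking_span={ex.get('thinking_span')}"
        )
```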
exp/exp4/README.md ADDED
@@ -0,0 +1,85 @@
+ # FlashTrace Experiment 4 (Aider attribution faithfulness / row-only)
+
+ This directory provides tooling for token-level attribution faithfulness evaluation on the Aider dataset. It reports **only the row-side RISE/MAS scores** and does not save per-sample traces.
+
+ Evaluation scope (fixed):
+ - Dataset: `exp/exp4/data/aider.jsonl`
+ - Methods:
+   - `ifr_all_positions`
+   - `ifr_multi_hop_both` (FlashTrace)
+ - Metrics: `RISE`, `MAS` (row attribution only)
+
+ Main files:
+ - `run_exp.py`: attribution + faithfulness evaluation; outputs go to `exp/exp4/output/`
+
+ ---
+
+ ## Data format
+
+ Each line of `exp/exp4/data/aider.jsonl` is one JSON object describing one sample:
+ - `input`: the prompt (used directly as the user prompt content)
+ - `output`: the target (used directly as the model's generated text; the script internally appends an EOS for scoring)
+ - `length`: a field shipped with the data (the script does not rely on it; it is only passed through to metadata)
+
+ Note: an Aider `output` has the following shape (a minimal loading sketch follows this section):
+ 1) first line: `xxx.py`
+ 2) second line: the opening fence ```
+ 3) middle lines: the code
+ 4) last line: the closing fence ```
+
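A minimal loading sketch for this data format (illustrative only, not part of the shipped scripts):

```python
import json
from pathlib import Path

rows = [
    json.loads(line)
    for line in Path("exp/exp4/data/aider.jsonl").read_text(encoding="utf-8").splitlines()
    if line.strip()
]
sample = rows[0]
prompt, target = sample["input"], sample["output"]
print(len(rows), "examples; pass-through field:", sample.get("length"))
print(target.splitlines()[0])  # first line of the output: the file name, e.g. something.py
```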
+ ---
+
+ ## Attribution and sink selection
+
+ For every sample the script uses `input` as the `prompt` and `output` as the `target` (no re-generation), and selects different sinks on top of the attribution result (`indices_to_explain=[start_tok,end_tok]`, always token spans over `tokenizer(target, add_special_tokens=False)`; the appended EOS is excluded).
+
+ ### `ifr_all_positions` (reports two sinks)
+
+ - `last_line`: take the last line in `output` before the closing fence that is **non-empty and not a ``` fence**, and map that line's character span to a token span; if it cannot be parsed, fall back to `full_output`.
+ - `last_token`: the last token of `last_line` (single-point span `[end,end]`).
+
+ Note: the `ifr_all_positions` attribution matrix is computed only once per sample; row attribution is then taken for each of the two sinks and faithfulness is computed separately.
+
+ ### `ifr_multi_hop_both` (FlashTrace, reports one sink)
+
+ - `full_output`: use the entire `output` as the sink (token span `[0, n_tok-1]`).
+ - The perturbation side of the faithfulness evaluation follows the exp2 protocol: prompt-side stop tokens are skipped (as determined by the stop-token configuration in `ft_ifr_improve.py`).
+
+ ---
+
+ ## Metric output (row-only)
+
+ The output CSV contains only aggregate `RISE`/`MAS` statistics for the row attribution (a short reading sketch follows this section):
+ - `Method,Sink,Row_RISE_Mean,Row_RISE_Std,Row_MAS_Mean,Row_MAS_Std,Used,Skipped,Avg_Sample_Time_s`
+
+ Output path:
+ - `exp/exp4/output/faithfulness/aider/<model_tag>/row_only_<N>_examples.csv`
+
+ Here `<model_tag>` is taken from `--model` when provided, otherwise from the directory name of `--model_path`.
+
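A minimal sketch for reading that CSV back (illustrative only; the model tag and example count in the path are hypothetical):

```python
import csv
from pathlib import Path

# Hypothetical concrete values: model_tag "qwen-8B", 100 evaluated examples.
csv_path = Path("exp/exp4/output/faithfulness/aider/qwen-8B/row_only_100_examples.csv")
with csv_path.open(newline="", encoding="utf-8") as f:
    for row in csv.DictReader(f):
        print(row["Method"], row["Sink"], "RISE:", row["Row_RISE_Mean"], "MAS:", row["Row_MAS_Mean"])
```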
+ ---
+
+ ## Usage
+
+ Run from the repo root so that relative paths resolve correctly:
+
+ ```bash
+ python exp/exp4/run_exp.py \
+     --data_path exp/exp4/data/aider.jsonl \
+     --output_root exp/exp4/output \
+     --model qwen-8B \
+     --model_path /opt/share/models/Qwen/Qwen3-8B/ \
+     --cuda 2,3,4,5,6,7 \
+     --num_examples 100 \
+     --n_hops 1 \
+     --k 20
+ ```
+
+ Common arguments:
+ - `--model_path` / `--model`: local model path or HF repo id (at least one is required)
+ - `--tokenizer_path`: optional; defaults to reusing the model path/id
+ - `--cuda`: accepts `0` (single GPU) or `0,1,2` (multi-GPU; `CUDA_VISIBLE_DEVICES` is set internally and `device_map=auto` is used)
+ - `--num_examples`: evaluate the first N examples (in file order; `--seed` is reserved, no random sampling yet)
+ - `--n_hops`: number of hops for FlashTrace (`ifr_multi_hop_both`)
+ - `--k`: number of perturbation steps for MAS/RISE
+ - `--chunk_tokens` / `--sink_chunk_tokens`: chunking parameters for the IFR computation (usually keep the defaults)
exp/exp4/run_exp.py ADDED
@@ -0,0 +1,487 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Experiment 4 runner: Aider token-level attribution faithfulness.
4
+
5
+ Evaluates only:
6
+ - IFR: ifr_all_positions
7
+ - sink = last meaningful code line (excluding fences)
8
+ - sink = last token of that code line
9
+ - FlashTrace: ifr_multi_hop_both
10
+ - sink = full output (excluding appended EOS)
11
+
12
+ Outputs only row-level faithfulness scores (RISE, MAS). No sample-level traces.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ import sys
21
+ import time
22
+ from dataclasses import dataclass
23
+ from itertools import islice
24
+ from pathlib import Path
25
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
26
+
27
+
28
+ def _early_set_cuda_visible_devices() -> None:
29
+ parser = argparse.ArgumentParser(add_help=False)
30
+ parser.add_argument("--cuda", type=str, default=None)
31
+ args, _ = parser.parse_known_args(sys.argv[1:])
32
+ if args.cuda and "," in args.cuda:
33
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
34
+
35
+
36
+ _early_set_cuda_visible_devices()
37
+
38
+ import numpy as np
39
+ import torch
40
+ from transformers import AutoModelForCausalLM, AutoTokenizer, utils
41
+
42
+ # Ensure repo root on path for `import llm_attr`, `import ft_ifr_improve`, etc.
43
+ REPO_ROOT = Path(__file__).resolve().parents[2]
44
+ if str(REPO_ROOT) not in sys.path:
45
+ sys.path.insert(0, str(REPO_ROOT))
46
+
47
+ import ft_ifr_improve
48
+ import llm_attr
49
+ import llm_attr_eval
50
+
51
+ utils.logging.set_verbosity_error()
52
+
53
+
54
+ @dataclass(frozen=True)
55
+ class AiderExample:
56
+ prompt: str
57
+ target: str
58
+ metadata: Dict[str, Any]
59
+
60
+
61
+ def _read_jsonl(path: Path) -> List[Dict[str, Any]]:
62
+ rows: List[Dict[str, Any]] = []
63
+ with path.open("r", encoding="utf-8") as f:
64
+ for line in f:
65
+ if not line.strip():
66
+ continue
67
+ rows.append(json.loads(line))
68
+ return rows
69
+
70
+
71
+ def load_aider(path: Path) -> List[AiderExample]:
72
+ rows = _read_jsonl(path)
73
+ examples: List[AiderExample] = []
74
+ for row in rows:
75
+ prompt = str(row.get("input") or "")
76
+ target = str(row.get("output") or "")
77
+ examples.append(AiderExample(prompt=prompt, target=target, metadata={"length": row.get("length")}))
78
+ return examples
79
+
80
+
81
+ def _token_span_full_output(tokenizer, target: str) -> List[int]:
82
+ ids = tokenizer(target, add_special_tokens=False).input_ids
83
+ if not ids:
84
+ return [0, 0]
85
+ return [0, int(len(ids) - 1)]
86
+
87
+
88
+ def _last_meaningful_code_line_char_span(target: str) -> Optional[Tuple[int, int]]:
89
+ lines = target.splitlines(keepends=True)
90
+ pos = 0
91
+ spans: List[Tuple[int, int, str]] = []
92
+ for line in lines:
93
+ start = pos
94
+ pos += len(line)
95
+ spans.append((start, pos, line))
96
+
97
+ for start, end, line in reversed(spans):
98
+ stripped = line.strip()
99
+ if not stripped:
100
+ continue
101
+ if stripped.startswith("```"):
102
+ continue
103
+ if start == 0 and stripped.endswith(".py"):
104
+ return None
105
+
106
+ line_no_nl = line.rstrip("\r\n")
107
+ end_no_nl = start + len(line_no_nl)
108
+ if end_no_nl <= start:
109
+ continue
110
+ return start, end_no_nl
111
+
112
+ return None
113
+
114
+
115
+ def _char_span_to_token_span(tokenizer, text: str, span: Tuple[int, int]) -> Optional[List[int]]:
116
+ start_char, end_char = int(span[0]), int(span[1])
117
+ if end_char <= start_char:
118
+ return None
119
+
120
+ enc = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
121
+ offsets = enc.get("offset_mapping")
122
+ if offsets is None:
123
+ raise ValueError("Tokenizer does not provide offset_mapping; cannot map char spans to tokens.")
124
+
125
+ tok_indices: List[int] = []
126
+ for idx, off in enumerate(offsets):
127
+ if off is None:
128
+ continue
129
+ s, e = int(off[0]), int(off[1])
130
+ if s < end_char and e > start_char:
131
+ tok_indices.append(int(idx))
132
+ if not tok_indices:
133
+ return None
134
+ return [min(tok_indices), max(tok_indices)]
135
+
136
+
137
+ def _last_meaningful_code_line_token_span(tokenizer, target: str) -> List[int]:
138
+ full_span = _token_span_full_output(tokenizer, target)
139
+ span_chars = _last_meaningful_code_line_char_span(target)
140
+ if span_chars is None:
141
+ return full_span
142
+
143
+ span_toks = _char_span_to_token_span(tokenizer, target, span_chars)
144
+ if span_toks is None:
145
+ return full_span
146
+
147
+ span_toks[0] = max(int(span_toks[0]), int(full_span[0]))
148
+ span_toks[1] = min(int(span_toks[1]), int(full_span[1]))
149
+ if span_toks[1] < span_toks[0]:
150
+ return full_span
151
+ return span_toks
152
+
153
+
154
+ def _last_token_span(token_span: Sequence[int]) -> List[int]:
155
+ if not (isinstance(token_span, Sequence) and len(token_span) == 2):
156
+ return [0, 0]
157
+ end = int(token_span[1])
158
+ return [end, end]
159
+
160
+
161
+ def resolve_device(args) -> str:
162
+ if args.cuda is not None and "," in args.cuda:
163
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
164
+ return "auto"
165
+ if args.cuda is not None and args.cuda.strip():
166
+ return f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
167
+ return f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"
168
+
169
+
170
+ def load_model_and_tokenizer(args) -> tuple[Any, Any]:
171
+ model_id = args.model_path or args.model
172
+ if not model_id:
173
+ raise SystemExit("Provide --model_path (local) or --model (HF repo id).")
174
+
175
+ tokenizer_id = args.tokenizer_path or model_id
176
+ device = resolve_device(args)
177
+
178
+ model = AutoModelForCausalLM.from_pretrained(
179
+ model_id,
180
+ device_map="auto" if device == "auto" else {"": int(device.split(":")[1])} if device.startswith("cuda:") else None,
181
+ torch_dtype=torch.float16,
182
+ attn_implementation="eager",
183
+ )
184
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
185
+ if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
186
+ tokenizer.pad_token = tokenizer.eos_token
187
+ model.eval()
188
+ return model, tokenizer
189
+
190
+
191
+ def _faithfulness_test_with_user_prompt_indices(
192
+ llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
193
+ attribution: torch.Tensor,
194
+ prompt: str,
195
+ generation: str,
196
+ *,
197
+ user_prompt_indices: List[int],
198
+ k: int = 20,
199
+ ) -> Tuple[float, float, float]:
200
+ def auc(arr: np.ndarray) -> float:
201
+ return (arr.sum() - arr[0] / 2 - arr[-1] / 2) / max(1, (arr.shape[0] - 1))
202
+
203
+ pad_token_id = llm_evaluator._ensure_pad_token_id()
204
+
205
+ user_prompt = " " + prompt
206
+ formatted_prompt = llm_evaluator.format_prompt(user_prompt)
207
+ formatted_ids = llm_evaluator.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).input_ids
208
+
209
+ prompt_ids = formatted_ids.to(llm_evaluator.device)
210
+ prompt_ids_perturbed = prompt_ids.clone()
211
+ generation_ids = llm_evaluator.tokenizer(
212
+ generation + llm_evaluator.tokenizer.eos_token,
213
+ return_tensors="pt",
214
+ add_special_tokens=False,
215
+ ).input_ids.to(llm_evaluator.device)
216
+
217
+ attr_cpu = attribution.detach().cpu()
218
+ w = attr_cpu.sum(0)
219
+ sorted_attr_indices = torch.argsort(w, descending=True)
220
+ attr_sum = float(w.sum().item())
221
+
222
+ P = int(w.numel())
223
+ if len(user_prompt_indices) != P:
224
+ raise ValueError(
225
+ "user_prompt_indices length does not match prompt-side attribution length: "
226
+ f"indices P={len(user_prompt_indices)}, attr P={P}."
227
+ )
228
+ if P == 0:
229
+ return 0.0, 0.0, 0.0
230
+
231
+ if max(user_prompt_indices) >= int(prompt_ids_perturbed.shape[1]):
232
+ raise ValueError("user_prompt_indices contains an out-of-bounds index for formatted prompt ids.")
233
+
234
+ steps = int(k) if k is not None else 0
235
+ if steps <= 0:
236
+ steps = 1
237
+ steps = min(steps, P)
238
+
239
+ scores = np.zeros(steps + 1, dtype=np.float64)
240
+ density = np.zeros(steps + 1, dtype=np.float64)
241
+
242
+ scores[0] = (
243
+ llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
244
+ )
245
+ density[0] = 1.0
246
+
247
+ if attr_sum <= 0:
248
+ density = np.linspace(1.0, 0.0, steps + 1)
249
+
250
+ base = P // steps
251
+ remainder = P % steps
252
+ start = 0
253
+ for step in range(steps):
254
+ size = base + (1 if step < remainder else 0)
255
+ group = sorted_attr_indices[start : start + size]
256
+ start += size
257
+
258
+ for idx in group:
259
+ j = int(idx.item())
260
+ abs_pos = int(user_prompt_indices[j])
261
+ prompt_ids_perturbed[0, abs_pos] = pad_token_id
262
+ scores[step + 1] = (
263
+ llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
264
+ )
265
+ if attr_sum > 0:
266
+ dec = float(w.index_select(0, group).sum().item()) / attr_sum
267
+ density[step + 1] = density[step] - dec
268
+
269
+ min_normalized_pred = 1.0
270
+ normalized_model_response = scores.copy()
271
+ for i in range(len(scores)):
272
+ normalized_pred = (normalized_model_response[i] - scores[-1]) / (abs(scores[0] - scores[-1]))
273
+ normalized_pred = np.clip(normalized_pred, 0.0, 1.0)
274
+ min_normalized_pred = min(min_normalized_pred, normalized_pred)
275
+ normalized_model_response[i] = min_normalized_pred
276
+
277
+ alignment_penalty = np.abs(normalized_model_response - density)
278
+ corrected_scores = normalized_model_response + alignment_penalty
279
+ corrected_scores = corrected_scores.clip(0.0, 1.0)
280
+ corrected_scores = (corrected_scores - np.min(corrected_scores)) / (np.max(corrected_scores) - np.min(corrected_scores))
281
+
282
+ if np.isnan(corrected_scores).any():
283
+ corrected_scores = np.linspace(1.0, 0.0, len(scores))
284
+
285
+ return auc(normalized_model_response), auc(corrected_scores), auc(normalized_model_response + alignment_penalty)
286
+
287
+
288
+ def _row_faithfulness_scores(
289
+ *,
290
+ llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
291
+ attribution_prompt: torch.Tensor,
292
+ prompt: str,
293
+ generation: str,
294
+ user_prompt_indices: Optional[List[int]],
295
+ keep_prompt_token_indices: Optional[Sequence[int]] = None,
296
+ k: int = 20,
297
+ ) -> Tuple[float, float]:
298
+ if keep_prompt_token_indices is not None:
299
+ rise, mas, _ = ft_ifr_improve.faithfulness_test_skip_tokens(
300
+ llm_evaluator,
301
+ attribution_prompt,
302
+ prompt,
303
+ generation,
304
+ keep_prompt_token_indices=keep_prompt_token_indices,
305
+ user_prompt_indices=user_prompt_indices,
306
+ k=int(k),
307
+ )
308
+ return float(rise), float(mas)
309
+ if user_prompt_indices is not None:
310
+ rise, mas, _ = _faithfulness_test_with_user_prompt_indices(
311
+ llm_evaluator,
312
+ attribution_prompt,
313
+ prompt,
314
+ generation,
315
+ user_prompt_indices=user_prompt_indices,
316
+ k=int(k),
317
+ )
318
+ return float(rise), float(mas)
319
+
320
+ rise, mas, _ = llm_evaluator.faithfulness_test(attribution_prompt, prompt, generation, k=int(k))
321
+ return float(rise), float(mas)
322
+
323
+
324
+ def _model_tag(args) -> str:
325
+ if args.model:
326
+ return str(args.model)
327
+ if args.model_path:
328
+ return Path(args.model_path).name
329
+ return "model"
330
+
331
+
332
+ def main() -> None:
333
+ parser = argparse.ArgumentParser("Experiment 4 runner: aider faithfulness (row-only).")
334
+ parser.add_argument("--data_path", type=str, default="exp/exp4/data/aider.jsonl")
335
+ parser.add_argument("--output_root", type=str, default="exp/exp4/output")
336
+ parser.add_argument("--model", type=str, default=None, help="HF repo id (required unless --model_path set).")
337
+ parser.add_argument("--model_path", type=str, default=None, help="Local path; overrides --model for loading.")
338
+ parser.add_argument("--tokenizer_path", type=str, default=None, help="Optional tokenizer path/id (defaults to model).")
339
+ parser.add_argument("--cuda", type=str, default=None)
340
+ parser.add_argument("--cuda_num", type=int, default=0)
341
+ parser.add_argument("--num_examples", type=int, default=100)
342
+ parser.add_argument("--seed", type=int, default=42, help="Reserved for future use; exp4 runs in file order.")
343
+ parser.add_argument("--chunk_tokens", type=int, default=128)
344
+ parser.add_argument("--sink_chunk_tokens", type=int, default=32)
345
+ parser.add_argument("--n_hops", type=int, default=3)
346
+ parser.add_argument("--k", type=int, default=20, help="Perturbation steps for MAS/RISE.")
347
+ args = parser.parse_args()
348
+
349
+ data_path = Path(args.data_path)
350
+ if not data_path.exists():
351
+ raise SystemExit(f"Missing Aider JSONL: {data_path}")
352
+
353
+ model, tokenizer = load_model_and_tokenizer(args)
354
+ llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(model, tokenizer)
355
+
356
+ examples = load_aider(data_path)
357
+ total = min(len(examples), int(args.num_examples))
358
+ iterator = islice(examples, total)
359
+
360
+ ifr = llm_attr.LLMIFRAttribution(
361
+ model,
362
+ tokenizer,
363
+ chunk_tokens=int(args.chunk_tokens),
364
+ sink_chunk_tokens=int(args.sink_chunk_tokens),
365
+ )
366
+ flashtrace = ft_ifr_improve.LLMIFRAttributionBoth(
367
+ model,
368
+ tokenizer,
369
+ chunk_tokens=int(args.chunk_tokens),
370
+ sink_chunk_tokens=int(args.sink_chunk_tokens),
371
+ )
372
+
373
+ results: Dict[Tuple[str, str], List[Tuple[float, float]]] = {
374
+ ("ifr_all_positions", "last_line"): [],
375
+ ("ifr_all_positions", "last_token"): [],
376
+ ("ifr_multi_hop_both", "full_output"): [],
377
+ }
378
+ skipped: Dict[Tuple[str, str], int] = {k: 0 for k in results}
379
+ sample_times: Dict[Tuple[str, str], List[float]] = {k: [] for k in results}
380
+
381
+ for example_idx, ex in enumerate(iterator):
382
+ prompt = ex.prompt
383
+ target = ex.target
384
+
385
+ full_span = _token_span_full_output(tokenizer, target)
386
+ last_line_span = _last_meaningful_code_line_token_span(tokenizer, target)
387
+ last_token_span = _last_token_span(last_line_span)
388
+
389
+ attr_all = None
390
+ attr_all_time_s = 0.0
391
+ user_prompt_indices_all: Optional[List[int]] = None
392
+ prompt_len_all = 0
393
+ try:
394
+ t_attr = time.perf_counter()
395
+ attr_all = ifr.calculate_ifr_for_all_positions(prompt, target=target)
396
+ attr_all_time_s = float(time.perf_counter() - t_attr)
397
+ user_prompt_indices_all = list(getattr(ifr, "user_prompt_indices", []) or [])
398
+ prompt_len_all = int(len(attr_all.prompt_tokens))
399
+ except Exception as exc:
400
+ skipped[("ifr_all_positions", "last_line")] += 1
401
+ skipped[("ifr_all_positions", "last_token")] += 1
402
+ print(f"[warn] ifr_all_positions attribution failed ex={example_idx}: {exc}")
403
+
404
+ if attr_all is not None and user_prompt_indices_all is not None and prompt_len_all >= 0:
405
+ for sink_name, span in (("last_line", last_line_span), ("last_token", last_token_span)):
406
+ key = ("ifr_all_positions", sink_name)
407
+ try:
408
+ t_faith = time.perf_counter()
409
+ row = attr_all.get_all_token_attrs(list(span))[1]
410
+ rise, mas = _row_faithfulness_scores(
411
+ llm_evaluator=llm_evaluator,
412
+ attribution_prompt=row[:, :prompt_len_all],
413
+ prompt=prompt,
414
+ generation=target,
415
+ user_prompt_indices=user_prompt_indices_all,
416
+ k=int(args.k),
417
+ )
418
+ faith_time_s = float(time.perf_counter() - t_faith)
419
+ results[key].append((rise, mas))
420
+ sample_times[key].append(attr_all_time_s + faith_time_s)
421
+ except Exception as exc:
422
+ skipped[key] += 1
423
+ print(f"[warn] ifr_all_positions {sink_name} failed ex={example_idx}: {exc}")
424
+
425
+ try:
426
+ t_attr = time.perf_counter()
427
+ attr_ft = flashtrace.calculate_ifr_multi_hop_both(
428
+ prompt,
429
+ target=target,
430
+ sink_span=None,
431
+ thinking_span=None,
432
+ n_hops=int(args.n_hops),
433
+ )
434
+ attr_ft_time_s = float(time.perf_counter() - t_attr)
435
+ user_prompt_indices_ft = list(getattr(flashtrace, "user_prompt_indices", []) or [])
436
+ prompt_len_ft = int(len(attr_ft.prompt_tokens))
437
+ keep_prompt_token_indices = ft_ifr_improve.keep_token_indices(list(attr_ft.prompt_tokens))
438
+
439
+ t_faith = time.perf_counter()
440
+ row_full = attr_ft.get_all_token_attrs(full_span)[1]
441
+ rise, mas = _row_faithfulness_scores(
442
+ llm_evaluator=llm_evaluator,
443
+ attribution_prompt=row_full[:, :prompt_len_ft],
444
+ prompt=prompt,
445
+ generation=target,
446
+ user_prompt_indices=user_prompt_indices_ft,
447
+ keep_prompt_token_indices=keep_prompt_token_indices,
448
+ k=int(args.k),
449
+ )
450
+ faith_time_s = float(time.perf_counter() - t_faith)
451
+ results[("ifr_multi_hop_both", "full_output")].append((rise, mas))
452
+ sample_times[("ifr_multi_hop_both", "full_output")].append(attr_ft_time_s + faith_time_s)
453
+ except Exception as exc:
454
+ skipped[("ifr_multi_hop_both", "full_output")] += 1
455
+ print(f"[warn] ifr_multi_hop_both failed ex={example_idx}: {exc}")
456
+
457
+ model_tag = _model_tag(args)
458
+ out_dir = Path(args.output_root) / "faithfulness" / "aider" / model_tag
459
+ out_dir.mkdir(parents=True, exist_ok=True)
460
+ out_path = out_dir / f"row_only_{total}_examples.csv"
461
+
462
+ with out_path.open("w", encoding="utf-8") as f:
463
+ f.write("Method,Sink,Row_RISE_Mean,Row_RISE_Std,Row_MAS_Mean,Row_MAS_Std,Used,Skipped,Avg_Sample_Time_s\n")
464
+ for (method, sink), vals in results.items():
465
+ arr = np.asarray(vals, dtype=np.float64)
466
+ used = int(arr.shape[0])
467
+ if used == 0:
468
+ rise_mean = float("nan")
469
+ rise_std = float("nan")
470
+ mas_mean = float("nan")
471
+ mas_std = float("nan")
472
+ else:
473
+ rise_mean = float(arr[:, 0].mean())
474
+ rise_std = float(arr[:, 0].std())
475
+ mas_mean = float(arr[:, 1].mean())
476
+ mas_std = float(arr[:, 1].std())
477
+ times = sample_times.get((method, sink)) or []
478
+ avg_time = float(np.mean(times)) if times else 0.0
479
+ f.write(
480
+ f"{method},{sink},{rise_mean},{rise_std},{mas_mean},{mas_std},{used},{int(skipped[(method, sink)])},{avg_time}\n"
481
+ )
482
+
483
+ print(f"[done] wrote {out_path}")
484
+
485
+
486
+ if __name__ == "__main__":
487
+ main()
exp/exp5/README.md ADDED
@@ -0,0 +1,119 @@
+ # FlashTrace Experiment 5: cross-model (Qwen → Llama) token-span mapping
+
+ ## Background: why the mapping is needed
+
+ Attribution and evaluation in `exp/exp2/run_exp.py` are strictly **token-level** and rely on token-span fields stored in the cached data:
+
+ - `indices_to_explain = [start_tok, end_tok]` (generation token indices, closed interval)
+ - `sink_span` / `thinking_span` (also generation token spans)
+
+ These spans were computed with a specific tokenizer (usually the `Qwen3-8B` tokenizer) when the caches were built (`exp/exp2/sample_and_filter.py`, `exp/exp2/map_math_mine_to_exp2_cache.py`) and written out as-is.
+
+ When you switch to a new model (e.g. `Llama-3.1-8B-Instruct`), the **tokenizer differs**, so the tokenization length and boundaries of `target` change. The old spans are then frequently out of range under the new tokenizer, and exp2 fails right away during attribution (`IndexError: end_tok out of range`).
+
+ ## Solution: the exp5 mapping script
+
+ `exp/exp5/map_exp2_cache_token_spans.py` maps the old token spans in the exp2 caches from the old tokenizer (default `Qwen3-8B`) to the new tokenizer (default `Llama-3.1-8B-Instruct`) and writes the result to:
+
+ `exp/exp5/data/<same dataset name>.jsonl`
+
+ Mapping strategy (default; a sketch follows below):
+ 1) Tokenize `target` with the old tokenizer using `return_offsets_mapping=True`
+ 2) Convert the old token span into a character span over `target`
+ 3) Tokenize the same `target` with the new tokenizer and map the character span back to a new token span
+
+ For edge cases (a cache that was not produced by the expected old tokenizer), enable `--allow_fallback_answer` to re-locate the span under the new tokenizer from `metadata.boxed_answer` (or `reference_answer`) as a fallback.
+
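+ A condensed sketch of the default strategy above (the real logic, including validation and fallbacks, lives in `map_exp2_cache_token_spans.py`); the tokenizer paths, the target text, and the old span are illustrative assumptions:
+
+ ```python
+ from transformers import AutoTokenizer
+
+ old_tok = AutoTokenizer.from_pretrained("/opt/share/models/Qwen/Qwen3-8B")
+ new_tok = AutoTokenizer.from_pretrained("/opt/share/models/meta-llama/Llama-3.1-8B-Instruct")
+
+ target = "Step 1: compute 6 * 7 = 42.\nStep 2: report it.\nThe answer is 42."
+ old_span = [3, 6]  # hypothetical [start_tok, end_tok] from the exp2 cache
+
+ # 1)+2) old token span -> character span over `target` (via old offsets)
+ old_off = old_tok(target, add_special_tokens=False, return_offsets_mapping=True)["offset_mapping"]
+ chunk = [off for off in old_off[old_span[0] : old_span[1] + 1] if off[1] > off[0]]
+ char_start, char_end = min(s for s, _ in chunk), max(e for _, e in chunk)
+
+ # 3) character span -> new token span, by overlap with the new offsets
+ new_off = new_tok(target, add_special_tokens=False, return_offsets_mapping=True)["offset_mapping"]
+ hit = [i for i, (s, e) in enumerate(new_off) if e > s and s < char_end and e > char_start]
+ new_span = [min(hit), max(hit)]
+ print(old_span, "->", new_span)
+ ```
+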
+ ---
+
+ ## Step 1: map the exp2 dataset caches into exp5/data
+
+ Using the repo venv is recommended:
+
+ ```bash
+ .venv/bin/python exp/exp5/map_exp2_cache_token_spans.py \
+     --in_jsonl exp/exp2/data/niah_mq_q2.jsonl \
+     --out_dir exp/exp5/data \
+     --old_tokenizer_model /opt/share/models/Qwen/Qwen3-8B \
+     --new_tokenizer_model /opt/share/models/meta-llama/Llama-3.1-8B-Instruct
+ ```
+
+ Map several datasets at once (example: RULER + math):
+
+ ```bash
+ .venv/bin/python exp/exp5/map_exp2_cache_token_spans.py \
+     --in_jsonl exp/exp2/data/niah_mq_q2.jsonl exp/exp2/data/math.jsonl \
+     --out_dir exp/exp5/data \
+     --old_tokenizer_model /opt/share/models/Qwen/Qwen3-8B \
+     --new_tokenizer_model /opt/share/models/meta-llama/Llama-3.1-8B-Instruct
+ ```
+
+ Add `--overwrite` if the output files already exist.
+
+ Default behaviour: samples that cannot be mapped are **dropped** and reported in the output statistics. For strict consistency add `--strict` (exit on the first failing sample). If you suspect the original cache was not built with `--old_tokenizer_model`, add `--allow_fallback_answer` to enable the fallback that locates the span via `metadata.boxed_answer`.
+
+ ---
+
+ ## Step 2: run the Llama attribution evaluation with exp2 directly (data and outputs both point to exp5)
+
+ Key points:
+ - **Data input**: use `--data_root exp/exp5/data` so exp2 reads the mapped caches
+ - **Result output**: use `--output_root exp/exp5/output` (avoid writing into `exp/exp2/output`)
+ - **Do not add** `--save_hop_traces` (avoid writing traces)
+
+ ### RULER (supports recovery + faithfulness)
+
+ ```bash
+ CUDA_VISIBLE_DEVICES=0 .venv/bin/python exp/exp2/run_exp.py \
+     --datasets niah_mq_q2 \
+     --data_root exp/exp5/data \
+     --output_root exp/exp5/output \
+     --attr_funcs ifr_all_positions,attnlrp,ifr_multi_hop_both \
+     --model_path /opt/share/models/meta-llama/Llama-3.1-8B-Instruct \
+     --cuda 0 \
+     --num_examples 100 \
+     --mode faithfulness_gen,recovery_ruler
+ ```
+
+ ### math (faithfulness only; recovery is explicitly rejected by exp2)
+
+ ```bash
+ CUDA_VISIBLE_DEVICES=0 .venv/bin/python exp/exp2/run_exp.py \
+     --datasets math \
+     --data_root exp/exp5/data \
+     --output_root exp/exp5/output \
+     --attr_funcs ifr_all_positions,attnlrp,ifr_multi_hop_both \
+     --model_path /opt/share/models/meta-llama/Llama-3.1-8B-Instruct \
+     --cuda 0 \
+     --num_examples 100 \
+     --mode faithfulness_gen
+ ```
+
+ ## Will this pollute the exp2 folder?
+
+ - **`exp/exp2/data/` is not touched**: the exp2 caches are not modified; outputs go to `exp/exp5/data/`.
+ - **Without `--save_hop_traces`, no traces are written.**
+ - Note, however, that `exp/exp2/run_exp.py` itself **always writes CSV metric files** to `--output_root` (that is how the code behaves; exp5 does not modify exp2). To keep the exp2 folder free of new files, point `--output_root` at `exp/exp5/output` (or another directory).
+
+ ```bash
+ python exp/exp2/run_exp.py \
+     --datasets niah_mq_q2 \
+     --data_root exp/exp5/data \
+     --output_root exp/exp5/output \
+     --attr_funcs ifr_all_positions,attnlrp,ifr_multi_hop_both \
+     --model_path /opt/share/models/meta-llama/Llama-3.1-8B-Instruct \
+     --cuda 2,3,4,5,6,7 \
+     --num_examples 100 \
+     --mode faithfulness_gen \
+     --n_hops 1 \
+ && python exp/exp2/run_exp.py \
+     --datasets math \
+     --data_root exp/exp5/data \
+     --output_root exp/exp5/output \
+     --attr_funcs ifr_all_positions,attnlrp,ifr_multi_hop_both \
+     --model_path /opt/share/models/meta-llama/Llama-3.1-8B-Instruct \
+     --cuda 2,3,4,5,6,7 \
+     --num_examples 100 \
+     --mode faithfulness_gen \
+     --n_hops 1
+ ```
exp/exp5/map_exp2_cache_token_spans.py ADDED
@@ -0,0 +1,407 @@
1
+ #!/usr/bin/env python3
2
+ """Map exp2 cached JSONL token spans across tokenizers (Qwen -> Llama).
3
+
4
+ Background
5
+ ----------
6
+ `exp/exp2/run_exp.py` expects cached datasets to provide token-level generation spans:
7
+
8
+ - indices_to_explain: [start_tok, end_tok] (generation-token indices; closed interval)
9
+ - sink_span / thinking_span: same tokenizer convention as indices_to_explain
10
+
11
+ These spans are computed under a specific tokenizer (often Qwen3-8B). When switching
12
+ to a different model/tokenizer (e.g., Llama-3.1-8B-Instruct), the stored spans can
13
+ become out-of-range and crash exp2 attribution (IndexError in token-span checks).
14
+
15
+ This script remaps spans by:
16
+ 1) Tokenizing `target` with the OLD tokenizer to obtain offset_mapping
17
+ 2) Converting the OLD token span into a character span in `target`
18
+ 3) Tokenizing `target` with the NEW tokenizer and mapping the character span back
19
+ into NEW token indices
20
+
21
+ Outputs are written under `exp/exp5/data/` by default, keeping `exp/exp2/` untouched.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import argparse
27
+ import json
28
+ import sys
29
+ from pathlib import Path
30
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
31
+
32
+ from transformers import AutoTokenizer
33
+
34
+
35
+ REPO_ROOT = Path(__file__).resolve().parents[2]
36
+ if str(REPO_ROOT) not in sys.path:
37
+ sys.path.insert(0, str(REPO_ROOT))
38
+
39
+
40
+ def _split_args(values: Iterable[str]) -> List[str]:
41
+ out: List[str] = []
42
+ for v in values:
43
+ for part in str(v).split(","):
44
+ part = part.strip()
45
+ if part:
46
+ out.append(part)
47
+ return out
48
+
49
+
50
+ def _load_tokenizer(tokenizer_model: str):
51
+ path = Path(tokenizer_model)
52
+ if path.exists():
53
+ return AutoTokenizer.from_pretrained(path.as_posix(), local_files_only=True)
54
+ # May require network access; keep as fallback for environments that allow it.
55
+ return AutoTokenizer.from_pretrained(tokenizer_model)
56
+
57
+
58
+ def _is_token_span(span: Any) -> bool:
59
+ return (
60
+ isinstance(span, list)
61
+ and len(span) == 2
62
+ and all(isinstance(x, int) for x in span)
63
+ and span[0] >= 0
64
+ and span[1] >= span[0]
65
+ )
66
+
67
+
68
+ def _pick_old_span(obj: Dict[str, Any]) -> Optional[List[int]]:
69
+ span = obj.get("indices_to_explain")
70
+ if _is_token_span(span):
71
+ return list(span)
72
+ span = obj.get("sink_span")
73
+ if _is_token_span(span):
74
+ return list(span)
75
+ return None
76
+
77
+
78
+ def _offsets_to_char_span(offsets: Any, token_span: List[int]) -> Optional[Tuple[int, int]]:
79
+ """Convert a token span [start,end] to a character span [char_start,char_end) using offsets."""
80
+ if offsets is None:
81
+ return None
82
+ if not isinstance(offsets, list):
83
+ return None
84
+ start_tok, end_tok = token_span
85
+ if end_tok >= len(offsets):
86
+ return None
87
+
88
+ char_starts: List[int] = []
89
+ char_ends: List[int] = []
90
+ for idx in range(start_tok, end_tok + 1):
91
+ off = offsets[idx]
92
+ if off is None:
93
+ continue
94
+ if not (isinstance(off, (list, tuple)) and len(off) == 2):
95
+ continue
96
+ try:
97
+ s, e = int(off[0]), int(off[1])
98
+ except Exception:
99
+ continue
100
+ if e <= s:
101
+ continue
102
+ char_starts.append(s)
103
+ char_ends.append(e)
104
+
105
+ if not char_starts or not char_ends:
106
+ return None
107
+ return min(char_starts), max(char_ends)
108
+
109
+
110
+ def _char_span_to_token_span(offsets: Any, char_span: Tuple[int, int]) -> Optional[List[int]]:
111
+ """Convert a character span [char_start,char_end) to a token span [start,end] by overlap."""
112
+ if offsets is None:
113
+ return None
114
+ if not isinstance(offsets, list):
115
+ return None
116
+ char_start, char_end = int(char_span[0]), int(char_span[1])
117
+ if char_end <= char_start:
118
+ return None
119
+
120
+ hit: List[int] = []
121
+ for tok_idx, off in enumerate(offsets):
122
+ if off is None:
123
+ continue
124
+ if not (isinstance(off, (list, tuple)) and len(off) == 2):
125
+ continue
126
+ try:
127
+ s, e = int(off[0]), int(off[1])
128
+ except Exception:
129
+ continue
130
+ if e <= s:
131
+ continue
132
+ if s < char_end and e > char_start:
133
+ hit.append(int(tok_idx))
134
+
135
+ if not hit:
136
+ return None
137
+ return [min(hit), max(hit)]
138
+
139
+
140
+ def _validate_span_with_eos(tokenizer, target: str, token_span: List[int]) -> bool:
141
+ eos = tokenizer.eos_token or ""
142
+ gen_ids = tokenizer(target + eos, add_special_tokens=False).input_ids
143
+ gen_len = int(len(gen_ids))
144
+ return 0 <= token_span[0] <= token_span[1] < gen_len
145
+
146
+
147
+ def _guess_answer_text(obj: Dict[str, Any]) -> Optional[str]:
148
+ meta = obj.get("metadata") or {}
149
+ if isinstance(meta, dict):
150
+ boxed = (meta.get("boxed_answer") or "").strip()
151
+ if boxed:
152
+ return boxed
153
+ ref = (meta.get("reference_answer") or "").strip()
154
+ if ref:
155
+ return ref
156
+ tgt = obj.get("target")
157
+ if isinstance(tgt, str) and tgt.strip():
158
+ # Common exp2 cache convention: last line is the final answer.
159
+ last_line = tgt.strip().splitlines()[-1].strip()
160
+ return last_line or None
161
+ return None
162
+
163
+
164
+ def _fallback_map_via_answer_text(
165
+ obj: Dict[str, Any],
166
+ *,
167
+ new_tokenizer,
168
+ ) -> Optional[List[int]]:
169
+ tgt = obj.get("target")
170
+ if not isinstance(tgt, str) or not tgt:
171
+ return None
172
+
173
+ from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer # lazy import
174
+
175
+ answer_text = _guess_answer_text(obj)
176
+ ex = CachedExample(
177
+ prompt=str(obj.get("prompt") or ""),
178
+ target=tgt,
179
+ indices_to_explain=None,
180
+ attr_mask_indices=obj.get("attr_mask_indices"),
181
+ sink_span=None,
182
+ thinking_span=None,
183
+ metadata=obj.get("metadata") or {},
184
+ )
185
+ out = attach_spans_from_answer(ex, new_tokenizer, answer_text)
186
+ if out.sink_span is None:
187
+ return None
188
+ if not _is_token_span(out.sink_span):
189
+ return None
190
+ return list(out.sink_span)
191
+
192
+
193
+ def _map_one_obj(
194
+ obj: Dict[str, Any],
195
+ *,
196
+ old_tokenizer,
197
+ new_tokenizer,
198
+ allow_fallback_answer: bool,
199
+ ) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
200
+ target = obj.get("target")
201
+ if not isinstance(target, str) or not target:
202
+ return None, "missing_target"
203
+
204
+ old_span = _pick_old_span(obj)
205
+ if old_span is None:
206
+ return None, "missing_old_span"
207
+
208
+ # 1) Old token span -> char span in target.
209
+ old_enc = old_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
210
+ old_offsets = old_enc.get("offset_mapping")
211
+ char_span = _offsets_to_char_span(old_offsets, old_span)
212
+ if char_span is None:
213
+ if not allow_fallback_answer:
214
+ return None, "old_span_to_char_failed"
215
+ new_span = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
216
+ if new_span is None:
217
+ return None, "fallback_answer_failed"
218
+ if not _validate_span_with_eos(new_tokenizer, target, new_span):
219
+ return None, "fallback_answer_span_invalid"
220
+ mapped = dict(obj)
221
+ mapped["indices_to_explain"] = new_span
222
+ mapped["sink_span"] = new_span
223
+ mapped["thinking_span"] = [0, new_span[0] - 1] if new_span[0] > 0 else None
224
+ meta = mapped.get("metadata")
225
+ if not isinstance(meta, dict):
226
+ meta = {}
227
+ meta = dict(meta)
228
+ meta["exp5_span_map_method"] = "answer_text"
229
+ mapped["metadata"] = meta
230
+ return mapped, None
231
+
232
+ # 2) Char span -> new token span.
233
+ new_enc = new_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
234
+ new_offsets = new_enc.get("offset_mapping")
235
+ new_span = _char_span_to_token_span(new_offsets, char_span)
236
+ if new_span is None:
237
+ if not allow_fallback_answer:
238
+ return None, "char_to_new_span_failed"
239
+ new_span = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
240
+ if new_span is None:
241
+ return None, "fallback_answer_failed"
242
+
243
+ if not _validate_span_with_eos(new_tokenizer, target, new_span):
244
+ if not allow_fallback_answer:
245
+ return None, "new_span_invalid"
246
+ fb = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
247
+ if fb is None or not _validate_span_with_eos(new_tokenizer, target, fb):
248
+ return None, "fallback_answer_span_invalid"
249
+ new_span = fb
250
+
251
+ mapped = dict(obj)
252
+ mapped["indices_to_explain"] = new_span
253
+ mapped["sink_span"] = new_span
254
+ mapped["thinking_span"] = [0, new_span[0] - 1] if new_span[0] > 0 else None
255
+
256
+ meta = mapped.get("metadata")
257
+ if not isinstance(meta, dict):
258
+ meta = {}
259
+ meta = dict(meta)
260
+ meta["exp5_span_map_method"] = "token_span_char_align"
261
+ mapped["metadata"] = meta
262
+ return mapped, None
263
+
264
+
265
+ def _read_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
266
+ with path.open("r", encoding="utf-8") as f:
267
+ for line_no, line in enumerate(f, start=1):
268
+ if not line.strip():
269
+ continue
270
+ try:
271
+ obj = json.loads(line)
272
+ except json.JSONDecodeError as exc: # pragma: no cover
273
+ raise RuntimeError(f"Invalid JSON at {path}:{line_no}: {exc}") from exc
274
+ if not isinstance(obj, dict):
275
+ raise RuntimeError(f"Expected JSON object per line at {path}:{line_no}.")
276
+ yield obj
277
+
278
+
279
+ def _write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> int:
280
+ path.parent.mkdir(parents=True, exist_ok=True)
281
+ count = 0
282
+ with path.open("w", encoding="utf-8") as f:
283
+ for obj in rows:
284
+ f.write(json.dumps(obj, ensure_ascii=False) + "\n")
285
+ count += 1
286
+ return count
287
+
288
+
289
+ def _default_old_tokenizer() -> str:
290
+ # Repo defaults used in exp2 README examples for span extraction.
291
+ return "/opt/share/models/Qwen/Qwen3-8B"
292
+
293
+
294
+ def _default_new_tokenizer() -> str:
295
+ return "/opt/share/models/meta-llama/Llama-3.1-8B-Instruct"
296
+
297
+
298
+ def main() -> None:
299
+ ap = argparse.ArgumentParser("Map exp2 cache token spans from an old tokenizer to a new tokenizer.")
300
+ ap.add_argument(
301
+ "--in_jsonl",
302
+ type=str,
303
+ nargs="+",
304
+ required=True,
305
+ help="One or more exp2 cached JSONL files (comma-separated also accepted).",
306
+ )
307
+ ap.add_argument(
308
+ "--out_dir",
309
+ type=str,
310
+ default="exp/exp5/data",
311
+ help="Output directory for mapped JSONL files.",
312
+ )
313
+ ap.add_argument(
314
+ "--old_tokenizer_model",
315
+ type=str,
316
+ default=_default_old_tokenizer(),
317
+ help="Tokenizer used to produce the original token spans (default: Qwen3-8B local path).",
318
+ )
319
+ ap.add_argument(
320
+ "--new_tokenizer_model",
321
+ type=str,
322
+ default=_default_new_tokenizer(),
323
+ help="Tokenizer to map spans into (default: Llama-3.1-8B-Instruct local path).",
324
+ )
325
+ ap.add_argument("--strict", action="store_true", help="Fail on the first example that cannot be mapped.")
326
+ ap.add_argument(
327
+ "--allow_fallback_answer",
328
+ action="store_true",
329
+ help=(
330
+ "If span alignment fails, try to recompute spans by locating metadata.boxed_answer in target "
331
+ "(useful when caches were not built with the assumed old tokenizer)."
332
+ ),
333
+ )
334
+ ap.add_argument(
335
+ "--overwrite",
336
+ action="store_true",
337
+ help="Overwrite output files if they already exist.",
338
+ )
339
+ args = ap.parse_args()
340
+
341
+ in_paths = [Path(p) for p in _split_args(args.in_jsonl)]
342
+ out_dir = Path(args.out_dir)
343
+
344
+ old_tok = _load_tokenizer(str(args.old_tokenizer_model))
345
+ new_tok = _load_tokenizer(str(args.new_tokenizer_model))
346
+
347
+ # exp2 convention: ensure a pad token exists for downstream perturbation.
348
+ if new_tok.pad_token is None and new_tok.eos_token is not None:
349
+ new_tok.pad_token = new_tok.eos_token
350
+
351
+ summary: Dict[str, Any] = {
352
+ "old_tokenizer_model": str(args.old_tokenizer_model),
353
+ "new_tokenizer_model": str(args.new_tokenizer_model),
354
+ "datasets": [],
355
+ }
356
+
357
+ for in_path in in_paths:
358
+ if not in_path.exists():
359
+ raise SystemExit(f"Missing input JSONL: {in_path}")
360
+ out_path = out_dir / in_path.name
361
+ if out_path.exists() and not bool(args.overwrite):
362
+ raise SystemExit(f"Refusing to overwrite existing output: {out_path} (use --overwrite)")
363
+
364
+ total = 0
365
+ mapped_ok = 0
366
+ dropped = 0
367
+ errors: Dict[str, int] = {}
368
+
369
+ mapped_rows: List[Dict[str, Any]] = []
370
+ for obj in _read_jsonl(in_path):
371
+ total += 1
372
+ mapped, err = _map_one_obj(
373
+ obj,
374
+ old_tokenizer=old_tok,
375
+ new_tokenizer=new_tok,
376
+ allow_fallback_answer=bool(args.allow_fallback_answer),
377
+ )
378
+ if err is not None or mapped is None:
379
+ errors[err or "unknown_error"] = errors.get(err or "unknown_error", 0) + 1
380
+ if bool(args.strict):
381
+ raise SystemExit(f"Failed to map {in_path} example #{total}: {err}")
382
+ dropped += 1
383
+ continue
384
+ mapped_ok += 1
385
+ mapped_rows.append(mapped)
386
+
387
+ written = _write_jsonl(out_path, mapped_rows)
388
+ if written != mapped_ok: # pragma: no cover
389
+ raise SystemExit(f"Internal error: written={written} != mapped_ok={mapped_ok}")
390
+
391
+ record = {
392
+ "in_jsonl": str(in_path),
393
+ "out_jsonl": str(out_path),
394
+ "total": int(total),
395
+ "mapped_ok": int(mapped_ok),
396
+ "dropped": int(dropped),
397
+ "errors": errors,
398
+ }
399
+ summary["datasets"].append(record)
400
+ print(json.dumps(record, ensure_ascii=False))
401
+
402
+ # Human-readable compact summary at end.
403
+ print(json.dumps(summary, ensure_ascii=False, indent=2))
404
+
405
+
406
+ if __name__ == "__main__":
407
+ main()
exp/proc/README.md ADDED
@@ -0,0 +1,98 @@
+ # exp/proc (exp2 trace mapping / external export)
+
+ This directory provides tooling that turns the trace output of `exp/exp2/run_exp.py --save_hop_traces` into compact per-sample `.npz` files intended for collaborators.
+
+ Main file:
+ - `exp/proc/map_exp2_traces_to_proc.py`: reads an exp2 trace run folder (`manifest.jsonl` + `ex_*.npz`) and writes the compact format to `exp/proc/output/`.
+
+ ---
+
+ ## Input requirements
+
+ You need to provide (or let the script infer):
+ - `--trace_dir`: an exp2 trace run folder, e.g.
+   - `exp/exp2/output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/ifr_all_positions_mfaithfulness_gen_95ex/`
+ - `--dataset_jsonl`: the exp2 cached dataset that belongs to that trace run (must contain `prompt` + `target`), e.g.
+   - `exp/exp2/data/morehopqa.jsonl`
+ - `--tokenizer_model`: the same tokenizer used for exp2 attribution (local path or model name), e.g.
+   - `/opt/share/models/Qwen/Qwen3-8B/`
+
+ Notes (see the alignment sketch below):
+ - The script replicates exp2's token alignment exactly (leading space before the prompt; generation tokenized as `target + eos_token`, decoded, then re-tokenized with offset slicing), so the tokenizer must match the one used for exp2 attribution; otherwise the script errors out with a length mismatch.
+ - Samples are matched against `--dataset_jsonl` via the `prompt_sha1/target_sha1` fields in `manifest.jsonl`, so `--dataset_jsonl` must be the exact cache used by that trace run.
+
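+ A minimal sketch of the token-alignment rule described above (mirroring `_tokenize_for_exp2_alignment` in the script); the tokenizer path and texts are assumptions:
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tok = AutoTokenizer.from_pretrained("/opt/share/models/Qwen/Qwen3-8B/")
+ prompt, target = "What is 6 * 7?", "6 * 7 = 42.\nThe answer is 42."
+
+ # Prompt side: exp2 prepends a single space before tokenizing.
+ prompt_ids = tok(" " + prompt, add_special_tokens=False).input_ids
+
+ # Generation side: tokenize target + EOS, decode, then re-tokenize with offsets
+ # so the decoded text can be sliced into per-token pieces.
+ gen_ids = tok(target + tok.eos_token, add_special_tokens=False).input_ids
+ gen_text = tok.decode(gen_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
+ enc = tok(gen_text, add_special_tokens=False, return_offsets_mapping=True)
+ gen_pieces = [gen_text[s:e] for s, e in enc["offset_mapping"]]
+
+ # These are the lengths the proc script checks against prompt_len/gen_len in the trace npz.
+ print(len(prompt_ids), len(gen_pieces))
+ ```
+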
+ ---
+
+ ## Output location and naming
+
+ Default output location:
+ - `exp/proc/output/<same path as trace_dir after traces/>/`
+
+ For example, for the input
+ - `.../output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/<run_tag>/`
+
+ the default output is
+ - `exp/proc/output/exp/exp2/data/morehopqa.jsonl/qwen-8B/<run_tag>/`
+
+ You can also set the output directory explicitly with `--out_dir`.
+
+ The output directory contains one file per sample: `ex_000000.npz`, `ex_000001.npz`, …
+
+ ---
+
+ ## Output `.npz` fields (compact; only what is needed)
+
+ Each output sample `.npz` contains **only** the following keys (a loading sketch follows the list):
+ - `attr`: `float32[L]`, row attribution vector; chat template and EOS removed, covering only the valid `input+cot+output` tokens.
+ - `hop`: `float32[H, L]` (optional, FT-IFR style methods only), per-hop vectors; EOS removed as well, aligned to the same length as `attr`.
+ - `tok`: `U[L]`, token text pieces strictly aligned with `attr/hop` (also without chat template and EOS).
+ - `span_in`: `int64[2]`, inclusive range of the input inside the vectors.
+ - `span_cot`: `int64[2]`, inclusive range of the CoT inside the vectors (`[-1, -1]` when there is no CoT).
+ - `span_out`: `int64[2]`, inclusive range of the output inside the vectors.
+ - `rise`: `float64`, row RISE (faithfulness).
+ - `mas`: `float64`, row MAS (faithfulness).
+ - `recovery`: `float64`, row Recovery@10% (NaN when recovery is unavailable).
+
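+ A minimal sketch for consuming one exported sample (the path is a placeholder for a real `ex_*.npz` under the output directory):
+
+ ```python
+ import numpy as np
+
+ with np.load("exp/proc/output/.../ex_000000.npz", allow_pickle=False) as f:
+     attr, tok = f["attr"], f["tok"]                    # float32[L], U[L], aligned
+     s_in, s_cot, s_out = f["span_in"], f["span_cot"], f["span_out"]
+     print("RISE", float(f["rise"]), "MAS", float(f["mas"]), "Recovery", float(f["recovery"]))
+
+     # Top-5 most attributed input tokens.
+     top = np.argsort(attr[int(s_in[0]) : int(s_in[1]) + 1])[::-1][:5] + int(s_in[0])
+     print([str(tok[i]) for i in top])
+
+     if "hop" in f.files:                               # per-hop vectors (FT-IFR methods only)
+         print("hop shape:", f["hop"].shape)
+ ```
+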
+ ---
+
+ ## Usage examples
+
+ Most common form (explicitly pass the dataset and tokenizer):
+ ```bash
+ python exp/proc/map_exp2_traces_to_proc.py \
+     --trace_dir exp/exp2/output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/ifr_all_positions_mfaithfulness_gen_95ex \
+     --dataset_jsonl exp/exp2/data/morehopqa.jsonl \
+     --tokenizer_model /opt/share/models/Qwen/Qwen3-8B/
+ ```
+
+ Explicit output directory (instead of the default mirrored path):
+ ```bash
+ python exp/proc/map_exp2_traces_to_proc.py \
+     --trace_dir exp/exp2/output/traces/exp/exp2/data/math.jsonl/qwen-8B/ifr_multi_hop_both_n1_mfaithfulness_gen_100ex/ \
+     --dataset_jsonl exp/exp2/data/math.jsonl \
+     --tokenizer_model /opt/share/models/Qwen/Qwen3-8B/ \
+     --out_dir exp/proc/output/math_ifr_multi_hop_both
+ ```
+
+ Debugging: process only the first 5 samples and allow overwriting existing outputs:
+ ```bash
+ python exp/proc/map_exp2_traces_to_proc.py \
+     --trace_dir ... \
+     --dataset_jsonl ... \
+     --tokenizer_model ... \
+     --limit 5 \
+     --overwrite
+ ```
+
+ ---
+
+ ## FAQ
+
+ - Error "Prompt/Generation token length mismatch"
+   - Almost always a tokenizer mismatch; make sure `--tokenizer_model` is exactly the tokenizer used for exp2 attribution (simplest: reuse the same `--model_path`).
+ - Error "Failed to match manifest sha1 to dataset_jsonl"
+   - `--dataset_jsonl` is not the cache used by that trace run, or the cache has no `target`.
+ - FT-IFR methods missing `hop` in the output
+   - For `ifr_multi_hop_stop_words/ifr_multi_hop_both/ifr_multi_hop_split_hop/ifr_in_all_gen`, the exp2 trace must contain `vh`; if the trace is old, re-run exp2 (with `--save_hop_traces`).
+   - If you really need the output anyway, add `--allow_missing_ft_hops` (not recommended).
+
exp/proc/map_exp2_traces_to_proc.py ADDED
@@ -0,0 +1,411 @@
1
+ #!/usr/bin/env python3
2
+ """Map exp2 trace artifacts into a collaborator-friendly per-sample NPZ format.
3
+
4
+ Input: an exp2 trace run directory produced by `exp/exp2/run_exp.py --save_hop_traces`,
5
+ e.g.:
6
+
7
+ exp/exp2/output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/ifr_all_positions_mfaithfulness_gen_95ex/
8
+
9
+ This directory contains:
10
+ - manifest.jsonl (one JSON object per sample)
11
+ - ex_*.npz (per-sample vectors and scores)
12
+
13
+ Output: per-sample NPZ files under `exp/proc/output/` (or a user-provided output path),
14
+ each containing only:
15
+ - attr: row attribution vector over [input + CoT + output] tokens, with chat template and EOS removed
16
+ - hop: per-hop vectors (FT-IFR only), aligned to attr (optional)
17
+ - tok: tokenized text pieces aligned to attr/hop (no chat template, no EOS)
18
+ - span_in/span_cot/span_out: inclusive ranges for input/CoT/output in the above vectors
19
+ - rise/mas: row faithfulness scores (RISE, MAS)
20
+ - recovery: row Recovery@10% score (NaN when unavailable)
21
+
22
+ This script is intentionally self-contained under exp/proc/ and does not modify exp2.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import argparse
28
+ import hashlib
29
+ import json
30
+ from dataclasses import dataclass
31
+ from pathlib import Path
32
+ from typing import Dict, List, Optional, Tuple
33
+
34
+ import numpy as np
35
+ from transformers import AutoTokenizer
36
+
37
+
38
+ FT_IFR_ATTR_FUNCS: set[str] = {
39
+ "ifr_in_all_gen",
40
+ "ifr_multi_hop_stop_words",
41
+ "ifr_multi_hop_both",
42
+ "ifr_multi_hop_split_hop",
43
+ }
44
+
45
+
46
+ def _sha1_text(text: str) -> str:
47
+ return hashlib.sha1(text.encode("utf-8")).hexdigest()
48
+
49
+
50
+ def _load_tokenizer(tokenizer_model: str):
51
+ tok_path = Path(tokenizer_model)
52
+ if tok_path.exists():
53
+ tokenizer = AutoTokenizer.from_pretrained(tok_path.as_posix(), local_files_only=True)
54
+ else:
55
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
56
+ if tokenizer.eos_token is None:
57
+ raise SystemExit("Tokenizer is missing eos_token; cannot match exp2 generation tokenization.")
58
+ if tokenizer.pad_token is None and tokenizer.eos_token is not None:
59
+ tokenizer.pad_token = tokenizer.eos_token
60
+ return tokenizer
61
+
62
+
63
+ def _decode_text_into_tokens(tokenizer, text: str) -> List[str]:
64
+ """Mirror llm_attr.LLMAttribution.decode_text_into_tokens (offset-slice tokens)."""
65
+ enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
66
+ ids = enc.get("input_ids")
67
+ offsets = enc.get("offset_mapping")
68
+ if ids is None or offsets is None:
69
+ raise ValueError("Tokenizer must provide input_ids and offset_mapping for exact exp2 token alignment.")
70
+ if len(ids) != len(offsets):
71
+ raise ValueError("Tokenizer returned mismatched input_ids vs offset_mapping lengths.")
72
+ tokens: List[str] = []
73
+ for start, end in offsets:
74
+ tokens.append(text[int(start) : int(end)])
75
+ return tokens
76
+
77
+
78
+ @dataclass(frozen=True)
79
+ class DatasetEntry:
80
+ prompt: str
81
+ target: str
82
+
83
+
84
+ def _index_dataset_by_sha1(dataset_jsonl: Path) -> Dict[Tuple[str, str], DatasetEntry]:
85
+ """Build (prompt_sha1, target_sha1) -> (prompt, target) for cache lookup."""
86
+ index: Dict[Tuple[str, str], DatasetEntry] = {}
87
+ collisions: Dict[Tuple[str, str], int] = {}
88
+
89
+ with dataset_jsonl.open("r", encoding="utf-8") as f:
90
+ for line_num, line in enumerate(f, start=1):
91
+ if not line.strip():
92
+ continue
93
+ obj = json.loads(line)
94
+ prompt = str(obj.get("prompt") or "")
95
+ target = obj.get("target")
96
+ if target is None:
97
+ # exp2 trace matching requires cached targets.
98
+ continue
99
+ target = str(target)
100
+
101
+ key = (_sha1_text(prompt), _sha1_text(target))
102
+ if key in index:
103
+ collisions[key] = collisions.get(key, 1) + 1
104
+ continue
105
+ index[key] = DatasetEntry(prompt=prompt, target=target)
106
+
107
+ if collisions:
108
+ raise SystemExit(
109
+ "Dataset cache contains duplicate (prompt,target) pairs; cannot uniquely match by sha1. "
110
+ f"Example collision count={next(iter(collisions.values()))}. "
111
+ f"dataset_jsonl={dataset_jsonl}"
112
+ )
113
+
114
+ if not index:
115
+ raise SystemExit(
116
+ "No usable (prompt,target) pairs found in dataset cache. "
117
+ "Ensure you pass the exp2 cached JSONL used for attribution (with target filled)."
118
+ )
119
+
120
+ return index
121
+
122
+
123
+ def _infer_trace_suffix(trace_dir: Path) -> Optional[Path]:
124
+ parts = list(trace_dir.parts)
125
+ if "traces" not in parts:
126
+ return None
127
+ idx = parts.index("traces")
128
+ suffix_parts = parts[idx + 1 :]
129
+ if not suffix_parts:
130
+ return None
131
+ return Path(*suffix_parts)
132
+
133
+
134
+ def _parse_manifest(manifest_path: Path) -> List[dict]:
135
+ records: List[dict] = []
136
+ with manifest_path.open("r", encoding="utf-8") as f:
137
+ for line in f:
138
+ if not line.strip():
139
+ continue
140
+ records.append(json.loads(line))
141
+ if not records:
142
+ raise SystemExit(f"Empty manifest.jsonl: {manifest_path}")
143
+ return records
144
+
145
+
146
+ def _read_span(npz: np.lib.npyio.NpzFile, key: str) -> Optional[Tuple[int, int]]:
147
+ if key not in npz.files:
148
+ return None
149
+ arr = npz[key]
150
+ if arr.shape != (2,):
151
+ raise ValueError(f"Expected {key} to have shape (2,), got {arr.shape}.")
152
+ return int(arr[0]), int(arr[1])
153
+
154
+
155
+ def _span_or_empty(span: Optional[Tuple[int, int]]) -> Tuple[int, int]:
156
+ if span is None:
157
+ return -1, -1
158
+ return int(span[0]), int(span[1])
159
+
160
+
161
+ def _tokenize_for_exp2_alignment(
162
+ tokenizer,
163
+ *,
164
+ prompt: str,
165
+ target: str,
166
+ expected_prompt_len: int,
167
+ expected_gen_len: int,
168
+ ) -> List[str]:
169
+ prompt_text = " " + (prompt or "")
170
+ prompt_tokens = _decode_text_into_tokens(tokenizer, prompt_text)
171
+ if len(prompt_tokens) != int(expected_prompt_len):
172
+ raise ValueError(f"Prompt token length mismatch: expected {expected_prompt_len}, got {len(prompt_tokens)}.")
173
+
174
+ gen_ids = tokenizer(target + tokenizer.eos_token, add_special_tokens=False).input_ids
175
+ gen_text = tokenizer.decode(gen_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
176
+ gen_tokens = _decode_text_into_tokens(tokenizer, gen_text)
177
+ if len(gen_tokens) != int(expected_gen_len):
178
+ raise ValueError(f"Generation token length mismatch: expected {expected_gen_len}, got {len(gen_tokens)}.")
179
+
180
+ gen_tokens_no_eos = gen_tokens[:-1] if gen_tokens else []
181
+ return prompt_tokens + gen_tokens_no_eos
182
+
183
+
184
+ def _clamp_span(span: Optional[Tuple[int, int]], *, max_index: int) -> Optional[Tuple[int, int]]:
185
+ if span is None:
186
+ return None
187
+ start, end = int(span[0]), int(span[1])
188
+ if max_index < 0:
189
+ return None
190
+ if end < 0 or start > max_index:
191
+ return None
192
+ start = max(0, start)
193
+ end = min(max_index, end)
194
+ if end < start:
195
+ return None
196
+ return start, end
197
+
198
+
199
+ def _proc_one(
200
+ *,
201
+ trace_npz_path: Path,
202
+ record: dict,
203
+ dataset_index: Dict[Tuple[str, str], DatasetEntry],
204
+ tokenizer,
205
+ out_path: Path,
206
+ overwrite: bool,
207
+ allow_missing_ft_hops: bool,
208
+ ) -> None:
209
+ prompt_sha1 = str(record.get("prompt_sha1") or "")
210
+ target_sha1 = str(record.get("target_sha1") or "")
211
+ if not prompt_sha1 or not target_sha1:
212
+ raise ValueError("manifest record missing prompt_sha1/target_sha1; cannot match dataset.")
213
+
214
+ entry = dataset_index.get((prompt_sha1, target_sha1))
215
+ if entry is None:
216
+ raise ValueError(
217
+ "Failed to match manifest sha1 to dataset_jsonl. "
218
+ "Ensure --dataset_jsonl points to the exact cached JSONL used for this trace run."
219
+ )
220
+
221
+ if out_path.exists() and not overwrite:
222
+ raise FileExistsError(f"Refusing to overwrite existing file: {out_path} (use --overwrite).")
223
+ out_path.parent.mkdir(parents=True, exist_ok=True)
224
+
225
+ with np.load(trace_npz_path, allow_pickle=False) as f:
226
+ prompt_len = int(np.asarray(f.get("prompt_len")).item())
227
+ gen_len = int(np.asarray(f.get("gen_len")).item())
228
+ total_len = prompt_len + gen_len
229
+ gen_no_eos = max(0, gen_len - 1)
230
+ L = prompt_len + gen_no_eos
231
+
232
+ v_row_all = f.get("v_row_all")
233
+ if v_row_all is None:
234
+ raise ValueError("Missing v_row_all in trace npz; cannot build row attribution vector.")
235
+ v_row_all = np.asarray(v_row_all, dtype=np.float32)
236
+ if v_row_all.ndim != 1 or int(v_row_all.shape[0]) != int(total_len):
237
+ raise ValueError(f"v_row_all shape mismatch: expected ({total_len},), got {tuple(v_row_all.shape)}.")
238
+ attr = v_row_all[:L]
239
+
240
+ indices_to_explain = _read_span(f, "indices_to_explain_gen")
241
+ sink_span_gen = _read_span(f, "sink_span_gen") or indices_to_explain
242
+ if sink_span_gen is None:
243
+ raise ValueError("Missing sink_span_gen/indices_to_explain_gen; cannot define output span.")
244
+ thinking_span_gen = _read_span(f, "thinking_span_gen")
245
+ if thinking_span_gen is None:
246
+ sink_start = int(sink_span_gen[0])
247
+ think_end = sink_start - 1
248
+ thinking_span_gen = (0, think_end) if think_end >= 0 else None
249
+
250
+ sink_span_gen = _clamp_span(sink_span_gen, max_index=gen_no_eos - 1)
251
+ thinking_span_gen = _clamp_span(thinking_span_gen, max_index=gen_no_eos - 1)
252
+
253
+ span_in = (0, prompt_len - 1) if prompt_len > 0 else (-1, -1)
254
+ span_cot = (
255
+ (prompt_len + thinking_span_gen[0], prompt_len + thinking_span_gen[1])
256
+ if thinking_span_gen is not None
257
+ else (-1, -1)
258
+ )
259
+ span_out = (
260
+ (prompt_len + sink_span_gen[0], prompt_len + sink_span_gen[1]) if sink_span_gen is not None else (-1, -1)
261
+ )
262
+
263
+ tokens = _tokenize_for_exp2_alignment(
264
+ tokenizer,
265
+ prompt=entry.prompt,
266
+ target=entry.target,
267
+ expected_prompt_len=prompt_len,
268
+ expected_gen_len=gen_len,
269
+ )
270
+ if len(tokens) != int(L):
271
+ raise ValueError(f"Token length mismatch after EOS drop: expected {L}, got {len(tokens)}.")
272
+
273
+ # Scores: row = index 1.
274
+ rise = float("nan")
275
+ mas = float("nan")
276
+ faith = f.get("faithfulness_scores")
277
+ if faith is not None:
278
+ faith = np.asarray(faith, dtype=np.float64)
279
+ if faith.shape != (3, 3):
280
+ raise ValueError(f"faithfulness_scores shape mismatch: expected (3,3), got {tuple(faith.shape)}.")
281
+ rise = float(faith[1, 0])
282
+ mas = float(faith[1, 1])
283
+
284
+ recovery = float("nan")
285
+ rec = f.get("recovery_scores")
286
+ if rec is not None:
287
+ rec = np.asarray(rec, dtype=np.float64)
288
+ if rec.shape != (3,):
289
+ raise ValueError(f"recovery_scores shape mismatch: expected (3,), got {tuple(rec.shape)}.")
290
+ recovery = float(rec[1])
291
+
292
+ out_payload = {
293
+ "attr": np.asarray(attr, dtype=np.float32),
294
+ "tok": np.asarray(tokens, dtype=np.str_),
295
+ "span_in": np.asarray(span_in, dtype=np.int64),
296
+ "span_cot": np.asarray(span_cot, dtype=np.int64),
297
+ "span_out": np.asarray(span_out, dtype=np.int64),
298
+ "rise": np.asarray(rise, dtype=np.float64),
299
+ "mas": np.asarray(mas, dtype=np.float64),
300
+ "recovery": np.asarray(recovery, dtype=np.float64),
301
+ }
302
+
303
+ attr_func = str(record.get("attr_func") or "")
304
+ want_hops = attr_func in FT_IFR_ATTR_FUNCS
305
+ if want_hops:
306
+ vh = f.get("vh")
307
+ if vh is None:
308
+ if not allow_missing_ft_hops:
309
+ raise ValueError(
310
+ f"FT-IFR method '{attr_func}' requires per-hop vectors but trace npz is missing 'vh'. "
311
+ "Re-run exp2 with --save_hop_traces using the updated code."
312
+ )
313
+ else:
314
+ vh = np.asarray(vh, dtype=np.float32)
315
+ if vh.ndim != 2 or int(vh.shape[1]) != int(total_len):
316
+ raise ValueError(
317
+ f"vh shape mismatch: expected (H,{total_len}), got {tuple(vh.shape)} for {trace_npz_path}."
318
+ )
319
+ out_payload["hop"] = vh[:, :L]
320
+
321
+ np.savez_compressed(out_path, **out_payload)
322
+
323
+
324
+ def main() -> None:
325
+ ap = argparse.ArgumentParser("Map exp2 trace folder -> exp/proc/output per-sample npz files.")
326
+ ap.add_argument("--trace_dir", type=str, required=True, help="Path to an exp2 trace run directory (contains manifest.jsonl).")
327
+ ap.add_argument("--dataset_jsonl", type=str, default=None, help="Path to the exp2 cached dataset JSONL used for this trace.")
328
+ ap.add_argument(
329
+ "--tokenizer_model",
330
+ type=str,
331
+ required=True,
332
+ help="Tokenizer model name or local path (must match exp2 attribution tokenizer).",
333
+ )
334
+ ap.add_argument("--out_root", type=str, default="exp/proc/output", help="Root directory for proc outputs.")
335
+ ap.add_argument("--out_dir", type=str, default=None, help="Optional explicit output directory (overrides --out_root).")
336
+ ap.add_argument("--overwrite", action="store_true", help="Overwrite existing output files if present.")
337
+ ap.add_argument("--limit", type=int, default=None, help="Optional limit on number of samples to process (debug).")
338
+ ap.add_argument(
339
+ "--allow_missing_ft_hops",
340
+ action="store_true",
341
+ help="Allow producing FT-IFR outputs even when per-hop vectors (vh) are missing (not recommended).",
342
+ )
343
+ args = ap.parse_args()
344
+
345
+ trace_dir = Path(args.trace_dir)
346
+ if not trace_dir.exists() or not trace_dir.is_dir():
347
+ raise SystemExit(f"Missing trace_dir: {trace_dir}")
348
+ manifest_path = trace_dir / "manifest.jsonl"
349
+ if not manifest_path.exists():
350
+ raise SystemExit(f"Missing manifest.jsonl: {manifest_path}")
351
+
352
+ dataset_jsonl: Optional[Path] = Path(args.dataset_jsonl) if args.dataset_jsonl else None
353
+ if dataset_jsonl is None:
354
+ suffix = _infer_trace_suffix(trace_dir)
355
+ if suffix is not None and len(suffix.parts) >= 3:
356
+ # suffix = <dataset_name...>/<model_tag>/<run_tag>
357
+ inferred_dataset = Path(*suffix.parts[:-2])
358
+ if inferred_dataset.exists() and inferred_dataset.is_file():
359
+ dataset_jsonl = inferred_dataset
360
+ if dataset_jsonl is None:
361
+ raise SystemExit("Please pass --dataset_jsonl (could not infer it from --trace_dir).")
362
+ if not dataset_jsonl.exists():
363
+ raise SystemExit(f"Missing --dataset_jsonl: {dataset_jsonl}")
364
+
365
+ tokenizer = _load_tokenizer(str(args.tokenizer_model))
366
+ dataset_index = _index_dataset_by_sha1(dataset_jsonl)
367
+ records = _parse_manifest(manifest_path)
368
+
369
+ if args.out_dir:
370
+ out_dir = Path(args.out_dir)
371
+ else:
372
+ suffix = _infer_trace_suffix(trace_dir)
373
+ out_dir = Path(args.out_root) / suffix if suffix is not None else Path(args.out_root) / trace_dir.name
374
+ out_dir.mkdir(parents=True, exist_ok=True)
375
+
376
+ total = len(records)
377
+ limit = args.limit
378
+ if limit is not None:
379
+ if limit <= 0:
380
+ raise SystemExit("--limit must be a positive integer.")
381
+ total = min(total, int(limit))
382
+
383
+ processed = 0
384
+ for record in records[:total]:
385
+ file_name = str(record.get("file") or "")
386
+ if not file_name:
387
+ raise SystemExit("manifest record missing 'file' field.")
388
+ trace_npz_path = trace_dir / file_name
389
+ if not trace_npz_path.exists():
390
+ raise SystemExit(f"Missing trace npz referenced by manifest: {trace_npz_path}")
391
+
392
+ out_path = out_dir / file_name
393
+ try:
394
+ _proc_one(
395
+ trace_npz_path=trace_npz_path,
396
+ record=record,
397
+ dataset_index=dataset_index,
398
+ tokenizer=tokenizer,
399
+ out_path=out_path,
400
+ overwrite=bool(args.overwrite),
401
+ allow_missing_ft_hops=bool(args.allow_missing_ft_hops),
402
+ )
403
+ except Exception as exc:
404
+ raise SystemExit(f"Failed processing {trace_npz_path}: {exc}") from exc
405
+ processed += 1
406
+
407
+ print(f"Wrote {processed} proc samples -> {out_dir}")
408
+
409
+
410
+ if __name__ == "__main__":
411
+ main()
exp/proc_1/README.md ADDED
@@ -0,0 +1,72 @@
1
+ # exp/proc_1 (exp2 trace mapping / export for collaborators, v1)
2
+
3
+ This directory provides a tool (v1) that converts the trace results produced by `exp/exp2/run_exp.py --save_hop_traces` into compact per-sample `.npz` files intended for collaborators.
4
+
5
+ Differences from `exp/proc/`:
6
+ - Drops `tok` (the per-token text pieces).
7
+ - Adds `length` (token counts of the three segments): `[in, cot, out]`, guaranteed to stay aligned with `span_in/span_cot/span_out`.
8
+ - The `hop` field follows a "default policy": `hop` is written only when the trace sample contains `vh`; otherwise it is omitted without raising an error.
9
+ - Can process every run directory under `exp/exp2/output/traces/` (all dataset-method combinations) in one pass.
10
+
11
+ ---
12
+
13
+ ## Input layout (exp2 traces)
14
+
15
+ An `exp2` trace run directory looks like:
16
+ - `exp/exp2/output/traces/<dataset>/<model>/<run_tag>/`
17
+
18
+ Each run directory contains:
19
+ - `manifest.jsonl` (one sample record per line, including `file=ex_*.npz`)
20
+ - `ex_*.npz` (one npz per sample)
21
+
22
+ ---
23
+
24
+ ## Output location and naming
25
+
26
+ Outputs default to:
27
+ - `exp/proc_1/output/<the path suffix of trace_dir after traces/>/`
28
+
29
+ For example, given the input:
30
+ - `.../output/traces/exp/exp2/data/math.jsonl/qwen-8B/<run_tag>/`
31
+
32
+ the default output is:
33
+ - `exp/proc_1/output/exp/exp2/data/math.jsonl/qwen-8B/<run_tag>/`
34
+
35
+ ---
36
+
37
+ ## Output `.npz` fields
38
+
39
+ Each output sample `.npz` contains only the following keys (a loading sketch follows the list):
40
+ - `attr`: `float32[L]`, the row attribution vector; covers the valid `input+cot+output` tokens (the trailing EOS of the generation is removed).
41
+ - `hop`: `float32[H, L]` (optional), written when the trace npz contains `vh` (EOS removed as well, aligned to the same length as `attr`).
42
+ - `span_in`: `int64[2]`, the inclusive range of the input within the vectors.
43
+ - `span_cot`: `int64[2]`, the inclusive range of the CoT within the vectors (`[-1, -1]` when there is no CoT).
44
+ - `span_out`: `int64[2]`, the inclusive range of the output within the vectors.
45
+ - `length`: `int64[3]`, ordered `[in, cot, out]`; each entry corresponds strictly to the matching `span_*` (inclusive-range length `end-start+1`, empty spans have length 0).
46
+ - `rise`: `float64`, the row RISE (faithfulness) score.
47
+ - `mas`: `float64`, the row MAS (faithfulness) score.
48
+ - `recovery`: `float64`, the row Recovery@10% score (NaN when no recovery score is available).
49
+
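+ To make the field layout concrete, here is a minimal loading sketch. It is illustrative only: the file path is a placeholder, and it assumes `numpy` is installed.
+
+ ```python
+ import numpy as np
+
+ # Placeholder path: point this at one exported sample file.
+ with np.load("exp/proc_1/output/.../ex_000.npz") as z:
+     attr = z["attr"]                                   # float32[L]
+     span_in, span_cot, span_out = z["span_in"], z["span_cot"], z["span_out"]
+     length = z["length"]                               # int64[3] = [in, cot, out]
+
+     # Spans are inclusive: a non-empty span [s, e] has length e - s + 1; [-1, -1] means empty.
+     for span, n in zip((span_in, span_cot, span_out), length):
+         expected = 0 if span[0] < 0 or span[1] < span[0] else int(span[1] - span[0] + 1)
+         assert int(n) == expected
+
+     # Slice the CoT portion of the attribution vector (empty when there is no CoT).
+     cot_attr = attr[span_cot[0] : span_cot[1] + 1] if length[1] > 0 else attr[:0]
+
+     # `hop` is optional: present only when the trace carried per-hop vectors (vh).
+     hop = z["hop"] if "hop" in z.files else None       # float32[H, L] or None
+ ```
+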
50
+ ---
51
+
52
+ ## Usage examples
53
+
54
+ Process all runs under traces (recommended):
55
+ ```bash
56
+ python exp/proc_1/map_exp2_traces_to_proc_1.py \
57
+ --traces_root exp/exp2/output/traces
58
+ ```
59
+
60
+ Process a single run directory only:
61
+ ```bash
62
+ python exp/proc_1/map_exp2_traces_to_proc_1.py \
63
+ --trace_dir exp/exp2/output/traces/exp/exp2/data/math.jsonl/qwen-8B/ifr_multi_hop_both_n1_mfaithfulness_gen_100ex
64
+ ```
65
+
66
+ Debug: process only the first 5 samples of each run and allow overwriting outputs:
67
+ ```bash
68
+ python exp/proc_1/map_exp2_traces_to_proc_1.py \
69
+ --traces_root exp/exp2/output/traces \
70
+ --limit 5 \
71
+ --overwrite
72
+ ```
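+
+ As a quick sanity check after a run, the snippet below counts how many exported samples include the optional `hop` array. It is illustrative only; the output root shown is the default and may need adjusting.
+
+ ```python
+ from pathlib import Path
+ import numpy as np
+
+ files = sorted(Path("exp/proc_1/output").rglob("ex_*.npz"))
+ with_hop = 0
+ for p in files:
+     with np.load(p) as z:
+         with_hop += int("hop" in z.files)
+ print(f"{with_hop}/{len(files)} samples include per-hop vectors")
+ ```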
exp/proc_1/map_exp2_traces_to_proc_1.py ADDED
@@ -0,0 +1,338 @@
1
+ #!/usr/bin/env python3
2
+ """Map exp2 trace artifacts into a collaborator-friendly per-sample NPZ format (proc_1).
3
+
4
+ This is a lightweight variant of `exp/proc/map_exp2_traces_to_proc.py`:
5
+ - Removes `tok` (per-token text pieces).
6
+ - Adds `length` with three components [in, cot, out], aligned to span_in/span_cot/span_out.
7
+ - Saves `hop` only when the trace sample contains `vh` (default strategy).
8
+ - Can process a single exp2 trace run directory or all run directories under a traces root.
9
+
10
+ Input: an exp2 trace run directory produced by `exp/exp2/run_exp.py --save_hop_traces`, e.g.:
11
+
12
+ exp/exp2/output/traces/exp/exp2/data/math.jsonl/qwen-8B/ifr_multi_hop_both_n1_mfaithfulness_gen_100ex/
13
+
14
+ This directory contains:
15
+ - manifest.jsonl (one JSON object per sample)
16
+ - ex_*.npz (per-sample vectors and scores)
17
+
18
+ Output: per-sample NPZ files under `exp/proc_1/output/` (or a user-provided output path),
19
+ each containing only:
20
+ - attr: row attribution vector over [input + CoT + output] tokens, with EOS removed
21
+ - hop: per-hop vectors (optional; only if `vh` exists in the trace npz), aligned to attr
22
+ - span_in/span_cot/span_out: inclusive ranges for input/CoT/output in the above vectors
23
+ - length: int64[3] = [in, cot, out], derived strictly from spans
24
+ - rise/mas: row faithfulness scores (RISE, MAS)
25
+ - recovery: row Recovery@10% score (NaN when unavailable)
26
+
27
+ This script is intentionally self-contained under exp/proc_1/ and does not modify exp2.
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ import argparse
33
+ import json
34
+ from dataclasses import dataclass
35
+ from pathlib import Path
36
+ from typing import Iterable, List, Optional, Tuple
37
+
38
+ import numpy as np
39
+
40
+
41
+ def _infer_trace_suffix(trace_dir: Path) -> Optional[Path]:
42
+ parts = list(trace_dir.parts)
43
+ if "traces" not in parts:
44
+ return None
45
+ idx = parts.index("traces")
46
+ suffix_parts = parts[idx + 1 :]
47
+ if not suffix_parts:
48
+ return None
49
+ return Path(*suffix_parts)
50
+
51
+
52
+ def _iter_run_dirs(traces_root: Path) -> List[Path]:
53
+ runs = {p.parent for p in traces_root.rglob("manifest.jsonl") if p.is_file()}
54
+ return sorted(runs)
55
+
56
+
57
+ def _parse_manifest(manifest_path: Path) -> List[dict]:
58
+ records: List[dict] = []
59
+ with manifest_path.open("r", encoding="utf-8") as f:
60
+ for line in f:
61
+ if not line.strip():
62
+ continue
63
+ records.append(json.loads(line))
64
+ return records
65
+
66
+
67
+ def _read_span(npz: np.lib.npyio.NpzFile, key: str) -> Optional[Tuple[int, int]]:
68
+ if key not in npz.files:
69
+ return None
70
+ arr = npz[key]
71
+ if arr.shape != (2,):
72
+ raise ValueError(f"Expected {key} to have shape (2,), got {arr.shape}.")
73
+ return int(arr[0]), int(arr[1])
74
+
75
+
76
+ def _clamp_span(span: Optional[Tuple[int, int]], *, max_index: int) -> Optional[Tuple[int, int]]:
77
+ if span is None:
78
+ return None
79
+ start, end = int(span[0]), int(span[1])
80
+ if max_index < 0:
81
+ return None
82
+ if end < 0 or start > max_index:
83
+ return None
84
+ start = max(0, start)
85
+ end = min(max_index, end)
86
+ if end < start:
87
+ return None
88
+ return start, end
89
+
90
+
91
+ def _span_len(span: Tuple[int, int]) -> int:
92
+ start, end = int(span[0]), int(span[1])
93
+ if start < 0 or end < 0 or end < start:
94
+ return 0
95
+ return int(end - start + 1)
96
+
97
+
98
+ @dataclass(frozen=True)
99
+ class ProcOneResult:
100
+ wrote: bool
101
+ has_hop: bool
102
+
103
+
104
+ def _proc_one(
105
+ *,
106
+ trace_npz_path: Path,
107
+ record: dict,
108
+ out_path: Path,
109
+ overwrite: bool,
110
+ ) -> ProcOneResult:
111
+ if out_path.exists() and not overwrite:
112
+ raise FileExistsError(f"Refusing to overwrite existing file: {out_path} (use --overwrite).")
113
+ out_path.parent.mkdir(parents=True, exist_ok=True)
114
+
115
+ with np.load(trace_npz_path, allow_pickle=False) as f:
116
+ prompt_len = int(np.asarray(f.get("prompt_len")).item())
117
+ gen_len = int(np.asarray(f.get("gen_len")).item())
118
+ total_len = prompt_len + gen_len
119
+ gen_no_eos = max(0, gen_len - 1)
120
+ L = prompt_len + gen_no_eos
121
+
122
+ v_row_all = f.get("v_row_all")
123
+ if v_row_all is None:
124
+ raise ValueError("Missing v_row_all in trace npz; cannot build row attribution vector.")
125
+ v_row_all = np.asarray(v_row_all, dtype=np.float32)
126
+ if v_row_all.ndim != 1 or int(v_row_all.shape[0]) != int(total_len):
127
+ raise ValueError(f"v_row_all shape mismatch: expected ({total_len},), got {tuple(v_row_all.shape)}.")
128
+ attr = v_row_all[:L]
129
+
130
+ indices_to_explain = _read_span(f, "indices_to_explain_gen")
131
+ sink_span_gen = _read_span(f, "sink_span_gen") or indices_to_explain
132
+ if sink_span_gen is None:
133
+ raise ValueError("Missing sink_span_gen/indices_to_explain_gen; cannot define output span.")
134
+ thinking_span_gen = _read_span(f, "thinking_span_gen")
135
+ if thinking_span_gen is None:
136
+ sink_start = int(sink_span_gen[0])
137
+ think_end = sink_start - 1
138
+ thinking_span_gen = (0, think_end) if think_end >= 0 else None
139
+
140
+ sink_span_gen = _clamp_span(sink_span_gen, max_index=gen_no_eos - 1)
141
+ thinking_span_gen = _clamp_span(thinking_span_gen, max_index=gen_no_eos - 1)
142
+
143
+ # Build full-sequence spans: the input span covers the prompt, and the
+ # generation-relative CoT/output spans (already clamped to the EOS-stripped
+ # generation above) are shifted by prompt_len.
+ span_in = (0, prompt_len - 1) if prompt_len > 0 else (-1, -1)
144
+ span_cot = (
145
+ (prompt_len + thinking_span_gen[0], prompt_len + thinking_span_gen[1])
146
+ if thinking_span_gen is not None
147
+ else (-1, -1)
148
+ )
149
+ span_out = (
150
+ (prompt_len + sink_span_gen[0], prompt_len + sink_span_gen[1]) if sink_span_gen is not None else (-1, -1)
151
+ )
152
+
153
+ length = np.asarray([_span_len(span_in), _span_len(span_cot), _span_len(span_out)], dtype=np.int64)
154
+
155
+ rise = float("nan")
156
+ mas = float("nan")
157
+ faith = f.get("faithfulness_scores")
158
+ if faith is not None:
159
+ faith = np.asarray(faith, dtype=np.float64)
160
+ if faith.shape != (3, 3):
161
+ raise ValueError(f"faithfulness_scores shape mismatch: expected (3,3), got {tuple(faith.shape)}.")
162
+ rise = float(faith[1, 0])
163
+ mas = float(faith[1, 1])
164
+
165
+ recovery = float("nan")
166
+ rec = f.get("recovery_scores")
167
+ if rec is not None:
168
+ rec = np.asarray(rec, dtype=np.float64)
169
+ if rec.shape != (3,):
170
+ raise ValueError(f"recovery_scores shape mismatch: expected (3,), got {tuple(rec.shape)}.")
171
+ recovery = float(rec[1])
172
+
173
+ out_payload = {
174
+ "attr": np.asarray(attr, dtype=np.float32),
175
+ "span_in": np.asarray(span_in, dtype=np.int64),
176
+ "span_cot": np.asarray(span_cot, dtype=np.int64),
177
+ "span_out": np.asarray(span_out, dtype=np.int64),
178
+ "length": np.asarray(length, dtype=np.int64),
179
+ "rise": np.asarray(rise, dtype=np.float64),
180
+ "mas": np.asarray(mas, dtype=np.float64),
181
+ "recovery": np.asarray(recovery, dtype=np.float64),
182
+ }
183
+
184
+ has_hop = False
185
+ vh = f.get("vh")
186
+ if vh is not None:
187
+ vh = np.asarray(vh, dtype=np.float32)
188
+ if vh.ndim != 2 or int(vh.shape[1]) != int(total_len):
189
+ raise ValueError(f"vh shape mismatch: expected (H,{total_len}), got {tuple(vh.shape)} for {trace_npz_path}.")
190
+ out_payload["hop"] = vh[:, :L]
191
+ has_hop = True
192
+
193
+ np.savez_compressed(out_path, **out_payload)
194
+ _ = record
195
+ return ProcOneResult(wrote=True, has_hop=has_hop)
196
+
197
+
198
+ def _resolve_out_dir_for_trace_dir(*, trace_dir: Path, out_root: Path, out_dir: Optional[Path]) -> Path:
199
+ if out_dir is not None:
200
+ return out_dir
201
+ suffix = _infer_trace_suffix(trace_dir)
202
+ return (out_root / suffix) if suffix is not None else (out_root / trace_dir.name)
203
+
204
+
205
+ def _process_trace_dir(
206
+ *,
207
+ trace_dir: Path,
208
+ out_root: Path,
209
+ out_dir: Optional[Path],
210
+ overwrite: bool,
211
+ limit: Optional[int],
212
+ skip_empty_manifest: bool,
213
+ ) -> Tuple[int, int]:
214
+ manifest_path = trace_dir / "manifest.jsonl"
215
+ if not manifest_path.exists():
216
+ raise SystemExit(f"Missing manifest.jsonl: {manifest_path}")
217
+
218
+ records = _parse_manifest(manifest_path)
219
+ if not records:
220
+ if skip_empty_manifest:
221
+ print(f"[skip] empty manifest: {manifest_path}")
222
+ return 0, 0
223
+ raise SystemExit(f"Empty manifest.jsonl: {manifest_path}")
224
+
225
+ total = len(records)
226
+ if limit is not None:
227
+ if limit <= 0:
228
+ raise SystemExit("--limit must be a positive integer.")
229
+ total = min(total, int(limit))
230
+
231
+ resolved_out_dir = _resolve_out_dir_for_trace_dir(trace_dir=trace_dir, out_root=out_root, out_dir=out_dir)
232
+ resolved_out_dir.mkdir(parents=True, exist_ok=True)
233
+
234
+ wrote = 0
235
+ wrote_with_hop = 0
236
+ for record in records[:total]:
237
+ file_name = str(record.get("file") or "")
238
+ if not file_name:
239
+ raise SystemExit("manifest record missing 'file' field.")
240
+ trace_npz_path = trace_dir / file_name
241
+ if not trace_npz_path.exists():
242
+ raise SystemExit(f"Missing trace npz referenced by manifest: {trace_npz_path}")
243
+
244
+ out_path = resolved_out_dir / file_name
245
+ try:
246
+ res = _proc_one(trace_npz_path=trace_npz_path, record=record, out_path=out_path, overwrite=overwrite)
247
+ except Exception as exc:
248
+ raise SystemExit(f"Failed processing {trace_npz_path}: {exc}") from exc
249
+ wrote += int(res.wrote)
250
+ wrote_with_hop += int(res.has_hop)
251
+
252
+ print(f"[ok] wrote {wrote} samples ({wrote_with_hop} with hop) -> {resolved_out_dir}")
253
+ return wrote, wrote_with_hop
254
+
255
+
256
+ def main() -> None:
257
+ ap = argparse.ArgumentParser("Map exp2 trace folder(s) -> exp/proc_1/output per-sample npz files.")
258
+ ap.add_argument(
259
+ "--trace_dir",
260
+ type=str,
261
+ default=None,
262
+ help="Path to a single exp2 trace run directory (contains manifest.jsonl).",
263
+ )
264
+ ap.add_argument(
265
+ "--traces_root",
266
+ type=str,
267
+ default=None,
268
+ help="Path to traces root; processes all run dirs under it (each with a manifest.jsonl).",
269
+ )
270
+ ap.add_argument("--out_root", type=str, default="exp/proc_1/output", help="Root directory for proc_1 outputs.")
271
+ ap.add_argument(
272
+ "--out_dir",
273
+ type=str,
274
+ default=None,
275
+ help="Optional explicit output directory (only valid with --trace_dir; overrides --out_root).",
276
+ )
277
+ ap.add_argument("--overwrite", action="store_true", help="Overwrite existing output files if present.")
278
+ ap.add_argument("--limit", type=int, default=None, help="Optional limit on number of samples per run (debug).")
279
+ ap.add_argument(
280
+ "--fail_on_empty_manifest",
281
+ action="store_true",
282
+ help="Fail (instead of skipping) when encountering an empty manifest.jsonl.",
283
+ )
284
+ args = ap.parse_args()
285
+
286
+ trace_dir = Path(args.trace_dir) if args.trace_dir else None
287
+ traces_root = Path(args.traces_root) if args.traces_root else None
288
+ if (trace_dir is None) == (traces_root is None):
289
+ raise SystemExit("Please pass exactly one of --trace_dir or --traces_root.")
290
+
291
+ out_root = Path(args.out_root)
292
+ out_dir = Path(args.out_dir) if args.out_dir else None
293
+ if out_dir is not None and trace_dir is None:
294
+ raise SystemExit("--out_dir is only valid with --trace_dir (for --traces_root use --out_root).")
295
+
296
+ skip_empty_manifest = not bool(args.fail_on_empty_manifest)
297
+
298
+ if trace_dir is not None:
299
+ if not trace_dir.exists() or not trace_dir.is_dir():
300
+ raise SystemExit(f"Missing trace_dir: {trace_dir}")
301
+ _process_trace_dir(
302
+ trace_dir=trace_dir,
303
+ out_root=out_root,
304
+ out_dir=out_dir,
305
+ overwrite=bool(args.overwrite),
306
+ limit=args.limit,
307
+ skip_empty_manifest=skip_empty_manifest,
308
+ )
309
+ return
310
+
311
+ assert traces_root is not None
312
+ if not traces_root.exists() or not traces_root.is_dir():
313
+ raise SystemExit(f"Missing traces_root: {traces_root}")
314
+
315
+ run_dirs = _iter_run_dirs(traces_root)
316
+ if not run_dirs:
317
+ raise SystemExit(f"No run directories found under traces_root={traces_root} (expected manifest.jsonl).")
318
+
319
+ total_written = 0
320
+ total_with_hop = 0
321
+ for run_dir in run_dirs:
322
+ wrote, wrote_with_hop = _process_trace_dir(
323
+ trace_dir=run_dir,
324
+ out_root=out_root,
325
+ out_dir=None,
326
+ overwrite=bool(args.overwrite),
327
+ limit=args.limit,
328
+ skip_empty_manifest=skip_empty_manifest,
329
+ )
330
+ total_written += wrote
331
+ total_with_hop += wrote_with_hop
332
+
333
+ print(f"[done] total wrote {total_written} samples ({total_with_hop} with hop) under out_root={out_root}")
334
+
335
+
336
+ if __name__ == "__main__":
337
+ main()
338
+
flashtrace/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """FlashTrace: efficient multi-token attribution for reasoning LLMs."""
2
+
3
+ from .model_io import load_model_and_tokenizer
4
+ from .result import TokenScore, TraceResult
5
+ from .tracer import FlashTrace
6
+
7
+ __all__ = ["FlashTrace", "TraceResult", "TokenScore", "load_model_and_tokenizer"]
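+
+ # NOTE (illustrative, not the documented API): the exports above suggest a usage
+ # shape like the following; the call signatures below are assumptions -- see
+ # examples/quickstart.py in this repository for the authoritative entry point.
+ #
+ #     from flashtrace import FlashTrace, load_model_and_tokenizer
+ #     model, tokenizer = load_model_and_tokenizer("<model-name-or-path>")  # assumed signature
+ #     tracer = FlashTrace(model, tokenizer)                                # assumed signature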
flashtrace/attribution.py ADDED
The diff for this file is too large to render. See raw diff
 
flashtrace/baselines/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """Baseline attribution methods for FlashTrace."""
2
+
3
+ from .attnlrp import LLMLRPAttribution
4
+
5
+ __all__ = ["LLMLRPAttribution"]
flashtrace/baselines/attnlrp.py ADDED
@@ -0,0 +1,12 @@
1
+ """AttnLRP baseline API."""
2
+
3
+ from flashtrace.attribution import AttnLRPSpanAggregate, LLMLRPAttribution, MultiHopAttnLRPResult
4
+ from flashtrace.lrp_patches import detect_model_type, lrp_context
5
+
6
+ __all__ = [
7
+ "AttnLRPSpanAggregate",
8
+ "LLMLRPAttribution",
9
+ "MultiHopAttnLRPResult",
10
+ "detect_model_type",
11
+ "lrp_context",
12
+ ]