Sync FlashTrace package from GitHub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +6 -35
- .gitignore +33 -0
- .python-version +1 -0
- .vscode/launch.json +25 -0
- LICENSE +21 -0
- MANIFEST.in +5 -0
- README.md +293 -0
- attribution_datasets.py +265 -0
- docs/superpowers/plans/2026-05-03-flashtrace-public-package.md +1605 -0
- docs/superpowers/specs/2026-05-03-flashtrace-public-package-design.md +231 -0
- dump_exp2_hop_vh.py +412 -0
- evaluations/attribution_recovery.py +490 -0
- evaluations/attribution_recovery.sh +18 -0
- evaluations/faithfulness.py +491 -0
- evaluations/faithfulness.sh +80 -0
- example.ipynb +0 -0
- examples/quickstart.py +44 -0
- exp/case_study/README.md +152 -0
- exp/case_study/analysis.py +74 -0
- exp/case_study/faithfulness_trace.py +183 -0
- exp/case_study/run_ifr_case.py +1225 -0
- exp/case_study/run_mas_case.py +805 -0
- exp/case_study/viz.py +647 -0
- exp/exp1/README.md +46 -0
- exp/exp1/run_time_curve.py +757 -0
- exp/exp2/DATASETS.md +231 -0
- exp/exp2/README.md +106 -0
- exp/exp2/dataset_utils.py +386 -0
- exp/exp2/map_math_mine_to_exp2_cache.py +584 -0
- exp/exp2/migrate_indices_to_explain_token_span.py +129 -0
- exp/exp2/out.log +102 -0
- exp/exp2/run_exp.py +1296 -0
- exp/exp2/sample_and_filter.py +363 -0
- exp/exp3/README.md +50 -0
- exp/exp3/extract_segment_weights.py +250 -0
- exp/exp3/part_weights.py +228 -0
- exp/exp3/run_exp.py +430 -0
- exp/exp3/sample_and_filter.py +628 -0
- exp/exp4/README.md +85 -0
- exp/exp4/run_exp.py +487 -0
- exp/exp5/README.md +119 -0
- exp/exp5/map_exp2_cache_token_spans.py +407 -0
- exp/proc/README.md +98 -0
- exp/proc/map_exp2_traces_to_proc.py +411 -0
- exp/proc_1/README.md +72 -0
- exp/proc_1/map_exp2_traces_to_proc_1.py +338 -0
- flashtrace/__init__.py +7 -0
- flashtrace/attribution.py +0 -0
- flashtrace/baselines/__init__.py +5 -0
- flashtrace/baselines/attnlrp.py +12 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,6 @@
|
|
| 1 |
-
|
| 2 |
-
*.
|
| 3 |
-
*.
|
| 4 |
-
*
|
| 5 |
-
*
|
| 6 |
-
*
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
# Keep language stats focused on the public Python package.
|
| 2 |
+
*.ipynb linguist-vendored
|
| 3 |
+
*.html linguist-generated
|
| 4 |
+
exp/** linguist-vendored
|
| 5 |
+
evaluations/** linguist-vendored
|
| 6 |
+
docs/** linguist-documentation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
| 11 |
+
|
| 12 |
+
# Local data
|
| 13 |
+
data/
|
| 14 |
+
|
| 15 |
+
# dev
|
| 16 |
+
AGENTS.md
|
| 17 |
+
readme_dev.md
|
| 18 |
+
.superpowers/
|
| 19 |
+
contribute/ruler/
|
| 20 |
+
repos/.DS_Store
|
| 21 |
+
repomix-output.xml
|
| 22 |
+
|
| 23 |
+
# FlashTrace generated artifacts
|
| 24 |
+
trace.json
|
| 25 |
+
trace.html
|
| 26 |
+
*.trace.json
|
| 27 |
+
*.trace.html
|
| 28 |
+
exp/**/output/
|
| 29 |
+
exp/**/out/
|
| 30 |
+
exp/**/out-*/
|
| 31 |
+
*.npz
|
| 32 |
+
.DS_Store
|
| 33 |
+
repomix-output.xml
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.13
|
.vscode/launch.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "0.2.0",
|
| 3 |
+
"configurations": [
|
| 4 |
+
{
|
| 5 |
+
"name": "Faithfulness Eval",
|
| 6 |
+
"type": "debugpy",
|
| 7 |
+
"request": "launch",
|
| 8 |
+
"module": "evaluations.faithfulness",
|
| 9 |
+
"args": [
|
| 10 |
+
"--model",
|
| 11 |
+
"qwen-4B",
|
| 12 |
+
"--cuda_num",
|
| 13 |
+
"1",
|
| 14 |
+
"--num_examples",
|
| 15 |
+
"500",
|
| 16 |
+
"--attr_func",
|
| 17 |
+
"IG",
|
| 18 |
+
"--dataset",
|
| 19 |
+
"facts"
|
| 20 |
+
],
|
| 21 |
+
"console": "integratedTerminal",
|
| 22 |
+
"justMyCode": true
|
| 23 |
+
}
|
| 24 |
+
]
|
| 25 |
+
}
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Wenbo Pan
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
MANIFEST.in
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
include README.md
|
| 2 |
+
include LICENSE
|
| 3 |
+
include pyproject.toml
|
| 4 |
+
include examples/*.py
|
| 5 |
+
include tests/*.py
|
README.md
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<img src="https://raw.githubusercontent.com/wbopan/flashtrace/master/docs/assets/flashtrace-logo.png" alt="FlashTrace logo" width="160">
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
<h1 align="center">FlashTrace</h1>
|
| 6 |
+
|
| 7 |
+
<p align="center">
|
| 8 |
+
<em>Fast token attribution for reasoning language models.</em>
|
| 9 |
+
</p>
|
| 10 |
+
|
| 11 |
+
<p align="center">
|
| 12 |
+
<a href="https://pypi.org/project/flashtrace/"><img alt="PyPI" src="https://img.shields.io/pypi/v/flashtrace.svg?style=flat-square&logo=pypi&logoColor=white&label=PyPI"></a>
|
| 13 |
+
<a href="https://pypi.org/project/flashtrace/"><img alt="Python" src="https://img.shields.io/pypi/pyversions/flashtrace.svg?style=flat-square&logo=python&logoColor=white"></a>
|
| 14 |
+
<a href="https://github.com/wbopan/flashtrace/blob/master/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT-blue.svg?style=flat-square"></a>
|
| 15 |
+
<a href="https://pytorch.org/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-2.5%2B-EE4C2C.svg?style=flat-square&logo=pytorch&logoColor=white"></a>
|
| 16 |
+
<a href="https://arxiv.org/abs/2602.01914"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2602.01914-B31B1B.svg?style=flat-square&logo=arxiv&logoColor=white"></a>
|
| 17 |
+
</p>
|
| 18 |
+
|
| 19 |
+
FlashTrace traces generated answers back to the prompt tokens that shaped them. Use it from Python or the command line, export JSON traces, and render standalone HTML heatmaps for inspection and sharing.
|
| 20 |
+
|
| 21 |
+
<p align="center">
|
| 22 |
+
<a href="https://arxiv.org/abs/2602.01914">📄 Paper</a>
|
| 23 |
+
·
|
| 24 |
+
<a href="#quickstart">🚀 Quickstart</a>
|
| 25 |
+
·
|
| 26 |
+
<a href="#command-line">💻 CLI</a>
|
| 27 |
+
·
|
| 28 |
+
<a href="#citation">📝 Citation</a>
|
| 29 |
+
</p>
|
| 30 |
+
|
| 31 |
+
## Why FlashTrace
|
| 32 |
+
|
| 33 |
+
Reasoning models produce long generated chains, final answers, and intermediate spans that deserve targeted inspection. FlashTrace gives researchers a package-first workflow for tracing a selected generated span back to its supporting prompt tokens.
|
| 34 |
+
|
| 35 |
+
You get:
|
| 36 |
+
|
| 37 |
+
- top-k prompt tokens ranked by attribution score
|
| 38 |
+
- JSON traces for downstream analysis
|
| 39 |
+
- standalone HTML token heatmaps
|
| 40 |
+
- optional per-hop attribution panels
|
| 41 |
+
- inclusive generation-token span controls for answer and reasoning segments
|
| 42 |
+
|
| 43 |
+
## Install
|
| 44 |
+
|
| 45 |
+
From PyPI:
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
pip install flashtrace
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
From a local checkout:
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
pip install -e .
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
For development:
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
pip install -e ".[dev]"
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
FlashTrace uses PyTorch, Transformers, Accelerate, NumPy, and tqdm. A CUDA-capable GPU is recommended for public-scale Hugging Face models.
|
| 64 |
+
|
| 65 |
+
## Quickstart
|
| 66 |
+
|
| 67 |
+
```python
|
| 68 |
+
from flashtrace import FlashTrace, load_model_and_tokenizer
|
| 69 |
+
|
| 70 |
+
prompt = """Context: Paris is the capital of France.
|
| 71 |
+
Question: What is the capital of France?"""
|
| 72 |
+
target = "Paris"
|
| 73 |
+
|
| 74 |
+
model, tokenizer = load_model_and_tokenizer("Qwen/Qwen3-8B", device_map="auto")
|
| 75 |
+
tracer = FlashTrace(model, tokenizer, chunk_tokens=128, sink_chunk_tokens=32)
|
| 76 |
+
|
| 77 |
+
trace = tracer.trace(
|
| 78 |
+
prompt=prompt,
|
| 79 |
+
target=target,
|
| 80 |
+
output_span=(0, 0),
|
| 81 |
+
hops=1,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
print(trace.topk_inputs(10))
|
| 85 |
+
trace.to_json("trace.json")
|
| 86 |
+
trace.to_html("trace.html")
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
`trace.topk_inputs(10)` returns `TokenScore` objects aligned to prompt-token indices:
|
| 90 |
+
|
| 91 |
+
```text
|
| 92 |
+
rank index token score
|
| 93 |
+
1 2 Paris 0.184
|
| 94 |
+
2 7 capital 0.131
|
| 95 |
+
3 10 France 0.119
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
`trace.html` is a standalone heatmap that highlights prompt tokens by final attribution score and includes trace metadata for the selected generated span.
|
| 99 |
+
|
| 100 |
+
`FlashTrace(..., use_chat_template=True)` formats prompts with the tokenizer chat template for chat-tuned models.
|
| 101 |
+
|
| 102 |
+
## Command Line
|
| 103 |
+
|
| 104 |
+
Create prompt and target files:
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
printf "Context: Paris is the capital of France.\nQuestion: What is the capital of France?\n" > prompt.txt
|
| 108 |
+
printf "Paris" > target.txt
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
Run a trace:
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
flashtrace trace \
|
| 115 |
+
--model Qwen/Qwen3-8B \
|
| 116 |
+
--prompt prompt.txt \
|
| 117 |
+
--target target.txt \
|
| 118 |
+
--output-span 0:0 \
|
| 119 |
+
--hops 1 \
|
| 120 |
+
--html trace.html \
|
| 121 |
+
--json trace.json
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
The command prints a compact top-k table and writes the requested artifacts.
|
| 125 |
+
|
| 126 |
+
Useful flags:
|
| 127 |
+
|
| 128 |
+
- `--model`: Hugging Face model id or local model path
|
| 129 |
+
- `--prompt`: UTF-8 prompt text file
|
| 130 |
+
- `--target`: UTF-8 target text file
|
| 131 |
+
- `--output-span`: inclusive `START:END` indices over generated tokens
|
| 132 |
+
- `--reasoning-span`: inclusive `START:END` indices for a reasoning segment
|
| 133 |
+
- `--method`: `flashtrace`, `ifr-span`, or `ifr-matrix`
|
| 134 |
+
- `--recompute-attention`: lower-memory attention recomputation path
|
| 135 |
+
- `--use-chat-template`: format prompts with the tokenizer chat template
|
| 136 |
+
- `--device-map`: Transformers device map, default `auto`
|
| 137 |
+
- `--dtype`: `auto`, `float16`, `bfloat16`, or `float32`
|
| 138 |
+
|
| 139 |
+
## Token Spans
|
| 140 |
+
|
| 141 |
+
`output_span` and `reasoning_span` use inclusive generation-token indices. The first generated token has index `0`.
|
| 142 |
+
|
| 143 |
+
Use an initial trace to inspect tokenization:
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
for index, token in enumerate(trace.generation_tokens):
|
| 147 |
+
print(index, repr(token))
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
Then choose spans:
|
| 151 |
+
|
| 152 |
+
```python
|
| 153 |
+
trace = tracer.trace(
|
| 154 |
+
prompt=prompt,
|
| 155 |
+
target=target,
|
| 156 |
+
reasoning_span=(0, 79),
|
| 157 |
+
output_span=(80, 85),
|
| 158 |
+
hops=1,
|
| 159 |
+
)
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
Scores are aligned to `trace.prompt_tokens`. `trace.per_hop_scores` stores the same prompt-token alignment for each hop.
|
| 163 |
+
|
| 164 |
+
## Interpreting Results
|
| 165 |
+
|
| 166 |
+
High-scoring prompt tokens are the tokens FlashTrace attributes most strongly to the selected generated span. For answer inspection, use `output_span` around the final answer tokens. For chain-of-thought or reasoning inspection, use `reasoning_span` around the generated reasoning segment.
|
| 167 |
+
|
| 168 |
+
Recommended workflow:
|
| 169 |
+
|
| 170 |
+
1. Run a trace with your prompt and target.
|
| 171 |
+
2. Inspect `trace.generation_tokens`.
|
| 172 |
+
3. Select the answer or reasoning span.
|
| 173 |
+
4. Export `trace.html`.
|
| 174 |
+
5. Compare top-k tokens with the source prompt and any expected evidence.
|
| 175 |
+
|
| 176 |
+
## Supported Models
|
| 177 |
+
|
| 178 |
+
FlashTrace targets Llama/Qwen-style decoder-only Hugging Face causal LMs with:
|
| 179 |
+
|
| 180 |
+
- `model.layers`
|
| 181 |
+
- Q/K/V/O attention projections
|
| 182 |
+
- RMSNorm or LayerNorm
|
| 183 |
+
- RoPE metadata
|
| 184 |
+
|
| 185 |
+
Validated model families for the first public release:
|
| 186 |
+
|
| 187 |
+
- Qwen2
|
| 188 |
+
- Qwen3
|
| 189 |
+
- Llama
|
| 190 |
+
|
| 191 |
+
## Python API
|
| 192 |
+
|
| 193 |
+
The public package exports:
|
| 194 |
+
|
| 195 |
+
```python
|
| 196 |
+
from flashtrace import FlashTrace, TraceResult, load_model_and_tokenizer
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
`FlashTrace.trace(...)` accepts:
|
| 200 |
+
|
| 201 |
+
- `prompt: str`
|
| 202 |
+
- `target: str | None`
|
| 203 |
+
- `output_span: tuple[int, int] | None`
|
| 204 |
+
- `reasoning_span: tuple[int, int] | None`
|
| 205 |
+
- `hops: int`
|
| 206 |
+
- `method: "flashtrace" | "ifr-span" | "ifr-matrix"`
|
| 207 |
+
- `renorm_threshold: float | None`
|
| 208 |
+
|
| 209 |
+
`TraceResult` includes:
|
| 210 |
+
|
| 211 |
+
- `prompt_tokens`
|
| 212 |
+
- `generation_tokens`
|
| 213 |
+
- `scores`
|
| 214 |
+
- `per_hop_scores`
|
| 215 |
+
- `thinking_ratios`
|
| 216 |
+
- `output_span`
|
| 217 |
+
- `reasoning_span`
|
| 218 |
+
- `method`
|
| 219 |
+
- `metadata`
|
| 220 |
+
|
| 221 |
+
Export helpers:
|
| 222 |
+
|
| 223 |
+
```python
|
| 224 |
+
trace.topk_inputs(20)
|
| 225 |
+
trace.to_dict()
|
| 226 |
+
trace.to_json("trace.json")
|
| 227 |
+
trace.to_html("trace.html")
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
## Examples
|
| 231 |
+
|
| 232 |
+
```bash
|
| 233 |
+
python examples/quickstart.py --help
|
| 234 |
+
python examples/quickstart.py \
|
| 235 |
+
--model Qwen/Qwen3-8B \
|
| 236 |
+
--prompt "Context: Paris is the capital of France. Question: What is the capital of France?" \
|
| 237 |
+
--target "Paris" \
|
| 238 |
+
--output-span 0:0 \
|
| 239 |
+
--html trace.html
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
Heavy model examples are intended for GPU environments. CPU smoke tests use tiny randomly initialized models.
|
| 243 |
+
|
| 244 |
+
## Repository Map
|
| 245 |
+
|
| 246 |
+
- `flashtrace/`: reusable Python package
|
| 247 |
+
- `examples/`: public quickstarts
|
| 248 |
+
- `tests/`: CPU smoke tests
|
| 249 |
+
- `exp/`: paper experiments and research artifacts
|
| 250 |
+
- `docs/superpowers/`: design and implementation planning documents
|
| 251 |
+
|
| 252 |
+
## Research Experiments
|
| 253 |
+
|
| 254 |
+
The `exp/` directory contains the paper-era experiment runners, case studies, and saved artifacts. The public package API lives in `flashtrace/`; experiment scripts keep compatibility imports during the package migration.
|
| 255 |
+
|
| 256 |
+
## Troubleshooting
|
| 257 |
+
|
| 258 |
+
**CUDA memory**
|
| 259 |
+
|
| 260 |
+
Use smaller models, lower precision, `device_map="auto"`, shorter prompts, or `--recompute-attention`.
|
| 261 |
+
|
| 262 |
+
**Span selection**
|
| 263 |
+
|
| 264 |
+
Print `trace.generation_tokens` and select inclusive generated-token indices. Tokenization can split visible words into multiple model tokens.
|
| 265 |
+
|
| 266 |
+
**Deterministic generation**
|
| 267 |
+
|
| 268 |
+
Pass a `target` file for attribution against a known output. Leave `--target` out when you want the CLI to generate with deterministic defaults.
|
| 269 |
+
|
| 270 |
+
**Tokenizer alignment**
|
| 271 |
+
|
| 272 |
+
Inspect `trace.prompt_tokens` and `trace.generation_tokens` when scores appear shifted from visible text. Attribution scores follow tokenizer-level alignment.
|
| 273 |
+
|
| 274 |
+
**HTML export**
|
| 275 |
+
|
| 276 |
+
`trace.to_html("trace.html")` writes a standalone file that can be opened locally or shared as an artifact.
|
| 277 |
+
|
| 278 |
+
## Paper
|
| 279 |
+
|
| 280 |
+
FlashTrace implements the method described in [Towards Long-Horizon Interpretability: Efficient and Faithful Multi-Token Attribution for Reasoning LLMs](https://arxiv.org/abs/2602.01914).
|
| 281 |
+
|
| 282 |
+
## Citation
|
| 283 |
+
|
| 284 |
+
```bibtex
|
| 285 |
+
@misc{pan2026flashtrace,
|
| 286 |
+
title={Towards Long-Horizon Interpretability: Efficient and Faithful Multi-Token Attribution for Reasoning LLMs},
|
| 287 |
+
author={Pan, Wenbo and Liu, Zhichao and Wang, Xianlong and Yu, Haining and Jia, Xiaohua},
|
| 288 |
+
year={2026},
|
| 289 |
+
eprint={2602.01914},
|
| 290 |
+
archivePrefix={arXiv},
|
| 291 |
+
primaryClass={cs.LG}
|
| 292 |
+
}
|
| 293 |
+
```
|
attribution_datasets.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence
|
| 7 |
+
|
| 8 |
+
# Import sentence splitter from shared utils; fallback when unavailable
|
| 9 |
+
try:
|
| 10 |
+
from shared_utils import create_sentences, create_sentences_fallback, nlp
|
| 11 |
+
except Exception:
|
| 12 |
+
from shared_utils import create_sentences_fallback as create_sentences
|
| 13 |
+
nlp = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class AttributionExample:
|
| 18 |
+
prompt: str
|
| 19 |
+
target: Optional[str] = None
|
| 20 |
+
indices_to_explain: Optional[List[int]] = None
|
| 21 |
+
attr_mask_indices: Optional[List[int]] = None
|
| 22 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AttributionDataset(Iterable[AttributionExample]):
|
| 26 |
+
"""Base iterable for attribution-ready datasets."""
|
| 27 |
+
|
| 28 |
+
name: str = "dataset"
|
| 29 |
+
|
| 30 |
+
def __init__(self) -> None:
|
| 31 |
+
self.examples: List[AttributionExample] = []
|
| 32 |
+
|
| 33 |
+
def __iter__(self) -> Iterator[AttributionExample]:
|
| 34 |
+
return iter(self.examples)
|
| 35 |
+
|
| 36 |
+
def __len__(self) -> int: # pragma: no cover - trivial
|
| 37 |
+
return len(self.examples)
|
| 38 |
+
|
| 39 |
+
def __getitem__(self, item): # pragma: no cover - convenience
|
| 40 |
+
return self.examples[item]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _add_dummy_facts_to_prompt(text_sentences: Sequence[str]) -> List[str]:
|
| 44 |
+
"""
|
| 45 |
+
Reproduces the original behaviour of interleaving dummy sentences with the
|
| 46 |
+
provided text segments so attribution heads can be masked easily.
|
| 47 |
+
"""
|
| 48 |
+
result: List[str] = []
|
| 49 |
+
for sentence in text_sentences:
|
| 50 |
+
result.append(sentence)
|
| 51 |
+
result.append(" Unrelated Sentence.")
|
| 52 |
+
return result
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class MathAttributionDataset(AttributionDataset):
|
| 56 |
+
"""Dataset wrapper for synthetic math problems with dummy context facts."""
|
| 57 |
+
|
| 58 |
+
name = "math"
|
| 59 |
+
|
| 60 |
+
def __init__(self, path: str | Path, tokenizer: Any) -> None:
|
| 61 |
+
super().__init__()
|
| 62 |
+
data_path = Path(path)
|
| 63 |
+
with data_path.open("r", encoding="utf-8") as f:
|
| 64 |
+
raw_examples = json.load(f)
|
| 65 |
+
|
| 66 |
+
for entry in raw_examples:
|
| 67 |
+
question_text = entry["question"]
|
| 68 |
+
sentences = create_sentences(question_text, tokenizer)
|
| 69 |
+
if not sentences:
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
context_sentences = sentences[:-1]
|
| 73 |
+
question_sentence = sentences[-1]
|
| 74 |
+
if question_sentence.startswith(" "):
|
| 75 |
+
question_sentence = question_sentence[1:]
|
| 76 |
+
|
| 77 |
+
context_with_dummy = _add_dummy_facts_to_prompt(context_sentences)
|
| 78 |
+
question_with_dummy = _add_dummy_facts_to_prompt([question_sentence])
|
| 79 |
+
|
| 80 |
+
prompt = "".join(context_with_dummy) + "\n" + "".join(question_with_dummy)
|
| 81 |
+
total_sentences = len(context_with_dummy) + len(question_with_dummy)
|
| 82 |
+
attr_mask_indices = list(range(0, total_sentences, 2))
|
| 83 |
+
|
| 84 |
+
self.examples.append(
|
| 85 |
+
AttributionExample(
|
| 86 |
+
prompt=prompt,
|
| 87 |
+
target=None,
|
| 88 |
+
indices_to_explain=[-2],
|
| 89 |
+
attr_mask_indices=attr_mask_indices,
|
| 90 |
+
metadata={"raw_question": question_text},
|
| 91 |
+
)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class FactsAttributionDataset(AttributionDataset):
|
| 96 |
+
"""Dataset wrapper for curated factual prompts with explicit gold attributions."""
|
| 97 |
+
|
| 98 |
+
name = "facts"
|
| 99 |
+
|
| 100 |
+
def __init__(self, path: str | Path) -> None:
|
| 101 |
+
super().__init__()
|
| 102 |
+
data_path = Path(path)
|
| 103 |
+
with data_path.open("r", encoding="utf-8") as f:
|
| 104 |
+
raw_examples = json.load(f)
|
| 105 |
+
|
| 106 |
+
for entry in raw_examples:
|
| 107 |
+
metadata = {
|
| 108 |
+
key: value
|
| 109 |
+
for key, value in entry.items()
|
| 110 |
+
if key not in {"prompt", "target", "indices_to_explain", "attr_mask_indices"}
|
| 111 |
+
}
|
| 112 |
+
self.examples.append(
|
| 113 |
+
AttributionExample(
|
| 114 |
+
prompt=entry["prompt"],
|
| 115 |
+
target=entry.get("target"),
|
| 116 |
+
indices_to_explain=entry.get("indices_to_explain"),
|
| 117 |
+
attr_mask_indices=entry.get("attr_mask_indices"),
|
| 118 |
+
metadata=metadata,
|
| 119 |
+
)
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class MoreHopQAAttributionDataset(AttributionDataset):
|
| 124 |
+
"""Dataset wrapper for multi-hop QA prompts without explicit gold attribution."""
|
| 125 |
+
|
| 126 |
+
name = "morehopqa"
|
| 127 |
+
|
| 128 |
+
def __init__(self, path: str | Path) -> None:
|
| 129 |
+
super().__init__()
|
| 130 |
+
data_path = Path(path)
|
| 131 |
+
with data_path.open("r", encoding="utf-8") as f:
|
| 132 |
+
raw_examples = json.load(f)
|
| 133 |
+
|
| 134 |
+
for entry in raw_examples:
|
| 135 |
+
context_chunks = ["".join(item[1]) for item in entry.get("context", [])]
|
| 136 |
+
context = " ".join(context_chunks)
|
| 137 |
+
prompt = context + "\n" + entry["question"]
|
| 138 |
+
|
| 139 |
+
self.examples.append(
|
| 140 |
+
AttributionExample(
|
| 141 |
+
prompt=prompt,
|
| 142 |
+
target=None,
|
| 143 |
+
indices_to_explain=[-2],
|
| 144 |
+
attr_mask_indices=None,
|
| 145 |
+
metadata={
|
| 146 |
+
"answer": entry.get("answer"),
|
| 147 |
+
"id": entry.get("_id"),
|
| 148 |
+
"original_context": entry.get("context"),
|
| 149 |
+
},
|
| 150 |
+
)
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# added
|
| 155 |
+
class RulerAttributionDataset(AttributionDataset):
|
| 156 |
+
"""Dataset wrapper for raw RULER JSONL files with needle spans.
|
| 157 |
+
|
| 158 |
+
Expects a JSONL file produced by repos/RULER (with added `needle_spans`).
|
| 159 |
+
Each line must contain at least: `input`, `answer_prefix`, `outputs`, and
|
| 160 |
+
optionally `needle_spans` with character spans relative to `input`.
|
| 161 |
+
|
| 162 |
+
Mapping logic:
|
| 163 |
+
- prompt = input + answer_prefix
|
| 164 |
+
- target = answer_prefix (+ optional space) + ", ".join(outputs)
|
| 165 |
+
- sentence indices computed over " " + prompt (leading space to match evaluator)
|
| 166 |
+
- each span is shifted by +1 to account for that leading space
|
| 167 |
+
- attr_mask_indices = union of all sentences covered by any span
|
| 168 |
+
- indices_to_explain = [0] when target is present
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
name = "ruler"
|
| 172 |
+
|
| 173 |
+
def __init__(self, path: str | Path) -> None:
|
| 174 |
+
super().__init__()
|
| 175 |
+
data_path = Path(path)
|
| 176 |
+
if not data_path.exists():
|
| 177 |
+
raise FileNotFoundError(f"RULER file not found: {data_path}")
|
| 178 |
+
|
| 179 |
+
# Use shared nlp pipeline; fallback to a naive splitter if unavailable
|
| 180 |
+
if nlp is not None:
|
| 181 |
+
def _sentence_bounds(text: str) -> List[tuple[int, int]]:
|
| 182 |
+
doc = nlp(text)
|
| 183 |
+
return [(s.start_char, s.end_char) for s in doc.sents]
|
| 184 |
+
else:
|
| 185 |
+
def _sentence_bounds(text: str) -> List[tuple[int, int]]:
|
| 186 |
+
# Naive fallback: split on newlines, produce contiguous ranges
|
| 187 |
+
bounds: List[tuple[int, int]] = []
|
| 188 |
+
start = 0
|
| 189 |
+
parts = text.split("\n")
|
| 190 |
+
for idx, part in enumerate(parts):
|
| 191 |
+
end = start + len(part)
|
| 192 |
+
if end > start:
|
| 193 |
+
bounds.append((start, end))
|
| 194 |
+
start = end + 1
|
| 195 |
+
if not bounds:
|
| 196 |
+
bounds = [(0, len(text))]
|
| 197 |
+
return bounds
|
| 198 |
+
|
| 199 |
+
def _map_spans(bounds: Sequence[tuple[int, int]], spans: Sequence[tuple[int, int]]) -> List[int]:
|
| 200 |
+
indices: set[int] = set()
|
| 201 |
+
for start, end in spans:
|
| 202 |
+
matched = False
|
| 203 |
+
for i, (bs, be) in enumerate(bounds):
|
| 204 |
+
if start >= bs and end <= be:
|
| 205 |
+
indices.add(i)
|
| 206 |
+
matched = True
|
| 207 |
+
break
|
| 208 |
+
if not matched:
|
| 209 |
+
# fallback: include all sentences with any overlap
|
| 210 |
+
for i, (bs, be) in enumerate(bounds):
|
| 211 |
+
if not (end <= bs or start >= be):
|
| 212 |
+
indices.add(i)
|
| 213 |
+
return sorted(indices)
|
| 214 |
+
|
| 215 |
+
def _read_jsonl(fp: Path) -> Iterator[Dict[str, Any]]:
|
| 216 |
+
with fp.open("r", encoding="utf-8") as f:
|
| 217 |
+
for line in f:
|
| 218 |
+
line = line.strip()
|
| 219 |
+
if line:
|
| 220 |
+
yield json.loads(line)
|
| 221 |
+
|
| 222 |
+
for entry in _read_jsonl(data_path):
|
| 223 |
+
input_text: str = entry.get("input", "")
|
| 224 |
+
answer_prefix: str = entry.get("answer_prefix", "")
|
| 225 |
+
outputs = entry.get("outputs", []) or []
|
| 226 |
+
|
| 227 |
+
# Build prompt/target
|
| 228 |
+
prompt = input_text + answer_prefix
|
| 229 |
+
if outputs:
|
| 230 |
+
sep = " " if answer_prefix and not answer_prefix.endswith((" ", "\n", "\t")) else ""
|
| 231 |
+
target = answer_prefix + sep + ", ".join(outputs)
|
| 232 |
+
else:
|
| 233 |
+
target = answer_prefix
|
| 234 |
+
|
| 235 |
+
# Sentence bounds over leading-space prompt to match evaluator
|
| 236 |
+
prompt_for_seg = " " + prompt
|
| 237 |
+
bounds = _sentence_bounds(prompt_for_seg)
|
| 238 |
+
|
| 239 |
+
# Collect spans and shift by +1 for the leading space
|
| 240 |
+
spans_raw = []
|
| 241 |
+
for item in entry.get("needle_spans", []) or []:
|
| 242 |
+
span = item.get("span")
|
| 243 |
+
if isinstance(span, list) and len(span) == 2:
|
| 244 |
+
spans_raw.append((int(span[0]) + 1, int(span[1]) + 1))
|
| 245 |
+
|
| 246 |
+
attr_indices = _map_spans(bounds, spans_raw) if spans_raw else None
|
| 247 |
+
|
| 248 |
+
self.examples.append(
|
| 249 |
+
AttributionExample(
|
| 250 |
+
prompt=prompt,
|
| 251 |
+
target=target or None,
|
| 252 |
+
indices_to_explain=[0] if target else None,
|
| 253 |
+
attr_mask_indices=attr_indices,
|
| 254 |
+
metadata={
|
| 255 |
+
"dataset": "ruler",
|
| 256 |
+
"length": entry.get("length"),
|
| 257 |
+
"length_w_model_temp": entry.get("length_w_model_temp"),
|
| 258 |
+
"outputs": outputs,
|
| 259 |
+
"answer_prefix": answer_prefix,
|
| 260 |
+
"token_position_answer": entry.get("token_position_answer"),
|
| 261 |
+
"needle_spans": entry.get("needle_spans"),
|
| 262 |
+
"prompt_sentence_count": len(bounds),
|
| 263 |
+
},
|
| 264 |
+
)
|
| 265 |
+
)
|
docs/superpowers/plans/2026-05-03-flashtrace-public-package.md
ADDED
|
@@ -0,0 +1,1605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FlashTrace Public Package Implementation Plan
|
| 2 |
+
|
| 3 |
+
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
| 4 |
+
|
| 5 |
+
**Goal:** Build an installable `flashtrace` package with a stable Python API, CLI tracing command, JSON export, HTML heatmap export, README quickstart, and CPU smoke tests.
|
| 6 |
+
|
| 7 |
+
**Architecture:** Create a package-first structure while preserving temporary root compatibility wrappers for existing experiment scripts. Move the IFR implementation and attribution engines into `flashtrace/`, wrap them with `FlashTrace` and `TraceResult`, then expose a CLI and public examples.
|
| 8 |
+
|
| 9 |
+
**Tech Stack:** Python 3.10+, PyTorch, Transformers, Accelerate, NumPy, tqdm, argparse, pytest.
|
| 10 |
+
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
## File Structure
|
| 14 |
+
|
| 15 |
+
Create or modify these files:
|
| 16 |
+
|
| 17 |
+
- Create: `flashtrace/__init__.py` for public exports.
|
| 18 |
+
- Create: `flashtrace/core.py` from `ifr_core.py`.
|
| 19 |
+
- Create: `flashtrace/shared_utils.py` from `shared_utils.py`.
|
| 20 |
+
- Create: `flashtrace/lrp_rules.py` from `lrp_rules.py`.
|
| 21 |
+
- Create: `flashtrace/lrp_patches.py` from `lrp_patches.py`.
|
| 22 |
+
- Create: `flashtrace/attribution.py` from `llm_attr.py`.
|
| 23 |
+
- Create: `flashtrace/improved.py` from `ft_ifr_improve.py`.
|
| 24 |
+
- Create: `flashtrace/result.py` for `TokenScore` and `TraceResult`.
|
| 25 |
+
- Create: `flashtrace/viz.py` for standalone HTML token heatmaps.
|
| 26 |
+
- Create: `flashtrace/tracer.py` for the `FlashTrace` facade.
|
| 27 |
+
- Create: `flashtrace/model_io.py` for Hugging Face loading helpers.
|
| 28 |
+
- Create: `flashtrace/cli.py` for `flashtrace trace`.
|
| 29 |
+
- Create: `flashtrace/baselines/__init__.py`.
|
| 30 |
+
- Create: `flashtrace/baselines/attnlrp.py`.
|
| 31 |
+
- Modify: `ifr_core.py`, `shared_utils.py`, `lrp_rules.py`, `lrp_patches.py`, `llm_attr.py`, `ft_ifr_improve.py` into root compatibility wrappers.
|
| 32 |
+
- Modify: `pyproject.toml` package metadata and console script.
|
| 33 |
+
- Modify: `.gitignore` generated artifact rules.
|
| 34 |
+
- Create: `README.md`.
|
| 35 |
+
- Create: `LICENSE`.
|
| 36 |
+
- Create: `examples/quickstart.py`.
|
| 37 |
+
- Create: `tests/helpers.py`.
|
| 38 |
+
- Create: `tests/test_imports.py`.
|
| 39 |
+
- Create: `tests/test_core_recompute.py`.
|
| 40 |
+
- Create: `tests/test_result.py`.
|
| 41 |
+
- Create: `tests/test_tracer.py`.
|
| 42 |
+
- Create: `tests/test_cli.py`.
|
| 43 |
+
- Delete: `model_generation.py`.
|
| 44 |
+
|
| 45 |
+
## Task 1: Package Metadata And Skeleton
|
| 46 |
+
|
| 47 |
+
**Files:**
|
| 48 |
+
- Modify: `pyproject.toml`
|
| 49 |
+
- Create: `flashtrace/__init__.py`
|
| 50 |
+
- Create: `flashtrace/tracer.py`
|
| 51 |
+
- Create: `flashtrace/result.py`
|
| 52 |
+
- Create: `flashtrace/model_io.py`
|
| 53 |
+
- Create: `flashtrace/cli.py`
|
| 54 |
+
- Create: `flashtrace/baselines/__init__.py`
|
| 55 |
+
- Create: `flashtrace/baselines/attnlrp.py`
|
| 56 |
+
- Test: `tests/test_imports.py`
|
| 57 |
+
|
| 58 |
+
- [ ] **Step 1: Write the failing public import test**
|
| 59 |
+
|
| 60 |
+
Create `tests/test_imports.py`:
|
| 61 |
+
|
| 62 |
+
```python
|
| 63 |
+
def test_public_imports():
|
| 64 |
+
import flashtrace
|
| 65 |
+
|
| 66 |
+
assert flashtrace.FlashTrace.__name__ == "FlashTrace"
|
| 67 |
+
assert flashtrace.TraceResult.__name__ == "TraceResult"
|
| 68 |
+
assert callable(flashtrace.load_model_and_tokenizer)
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
- [ ] **Step 2: Run the import test and see the expected failure**
|
| 72 |
+
|
| 73 |
+
Run:
|
| 74 |
+
|
| 75 |
+
```bash
|
| 76 |
+
uv run pytest tests/test_imports.py -q
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
Expected: pytest reports an import failure for `flashtrace`.
|
| 80 |
+
|
| 81 |
+
- [ ] **Step 3: Create package directories**
|
| 82 |
+
|
| 83 |
+
Run:
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
mkdir -p flashtrace/baselines tests
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
- [ ] **Step 4: Add minimal public package files**
|
| 90 |
+
|
| 91 |
+
Create `flashtrace/tracer.py`:
|
| 92 |
+
|
| 93 |
+
```python
|
| 94 |
+
class FlashTrace:
|
| 95 |
+
"""Public facade for FlashTrace attribution."""
|
| 96 |
+
|
| 97 |
+
def __init__(self, model, tokenizer, **kwargs):
|
| 98 |
+
self.model = model
|
| 99 |
+
self.tokenizer = tokenizer
|
| 100 |
+
self.options = dict(kwargs)
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
Create `flashtrace/result.py`:
|
| 104 |
+
|
| 105 |
+
```python
|
| 106 |
+
from __future__ import annotations
|
| 107 |
+
|
| 108 |
+
from dataclasses import dataclass
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@dataclass(frozen=True)
|
| 112 |
+
class TraceResult:
|
| 113 |
+
"""Public attribution result returned by FlashTrace."""
|
| 114 |
+
|
| 115 |
+
prompt_tokens: list[str]
|
| 116 |
+
generation_tokens: list[str]
|
| 117 |
+
scores: list[float]
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
Create `flashtrace/model_io.py`:
|
| 121 |
+
|
| 122 |
+
```python
|
| 123 |
+
def load_model_and_tokenizer(*args, **kwargs):
|
| 124 |
+
"""Load a Hugging Face causal LM and tokenizer."""
|
| 125 |
+
|
| 126 |
+
raise RuntimeError("load_model_and_tokenizer will be implemented in the model IO task.")
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
Create `flashtrace/cli.py`:
|
| 130 |
+
|
| 131 |
+
```python
|
| 132 |
+
def main(argv=None):
|
| 133 |
+
"""FlashTrace command-line entrypoint."""
|
| 134 |
+
|
| 135 |
+
raise RuntimeError("CLI will be implemented in the CLI task.")
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
Create `flashtrace/baselines/__init__.py`:
|
| 139 |
+
|
| 140 |
+
```python
|
| 141 |
+
"""Baseline attribution methods for FlashTrace."""
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
Create `flashtrace/baselines/attnlrp.py`:
|
| 145 |
+
|
| 146 |
+
```python
|
| 147 |
+
"""AttnLRP baseline exports."""
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
Create `flashtrace/__init__.py`:
|
| 151 |
+
|
| 152 |
+
```python
|
| 153 |
+
"""FlashTrace: efficient multi-token attribution for reasoning LLMs."""
|
| 154 |
+
|
| 155 |
+
from .model_io import load_model_and_tokenizer
|
| 156 |
+
from .result import TraceResult
|
| 157 |
+
from .tracer import FlashTrace
|
| 158 |
+
|
| 159 |
+
__all__ = ["FlashTrace", "TraceResult", "load_model_and_tokenizer"]
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
- [ ] **Step 5: Update package metadata**
|
| 163 |
+
|
| 164 |
+
Replace `pyproject.toml` with:
|
| 165 |
+
|
| 166 |
+
```toml
|
| 167 |
+
[project]
|
| 168 |
+
name = "flashtrace"
|
| 169 |
+
version = "0.1.0"
|
| 170 |
+
description = "Efficient multi-token attribution for reasoning language models."
|
| 171 |
+
readme = "README.md"
|
| 172 |
+
requires-python = ">=3.10"
|
| 173 |
+
dependencies = [
|
| 174 |
+
"accelerate>=1.11.0",
|
| 175 |
+
"matplotlib>=3.6",
|
| 176 |
+
"networkx>=3.3",
|
| 177 |
+
"numpy>=2.0",
|
| 178 |
+
"seaborn>=0.11",
|
| 179 |
+
"spacy>=3.8",
|
| 180 |
+
"torch>=2.5",
|
| 181 |
+
"tqdm>=4.67",
|
| 182 |
+
"transformers>=4.53",
|
| 183 |
+
"wordfreq>=3.1.1",
|
| 184 |
+
]
|
| 185 |
+
|
| 186 |
+
[project.optional-dependencies]
|
| 187 |
+
baselines = [
|
| 188 |
+
"bert-score>=0.3.13",
|
| 189 |
+
"evaluate>=0.4.6",
|
| 190 |
+
"sentence-transformers>=4.1.0",
|
| 191 |
+
]
|
| 192 |
+
eval = [
|
| 193 |
+
"datasets>=2.21",
|
| 194 |
+
"evaluate>=0.4.6",
|
| 195 |
+
]
|
| 196 |
+
dev = [
|
| 197 |
+
"pytest>=8.0",
|
| 198 |
+
]
|
| 199 |
+
|
| 200 |
+
[project.scripts]
|
| 201 |
+
flashtrace = "flashtrace.cli:main"
|
| 202 |
+
|
| 203 |
+
[tool.setuptools.packages.find]
|
| 204 |
+
include = ["flashtrace*"]
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
- [ ] **Step 6: Run the import test**
|
| 208 |
+
|
| 209 |
+
Run:
|
| 210 |
+
|
| 211 |
+
```bash
|
| 212 |
+
uv run pytest tests/test_imports.py -q
|
| 213 |
+
```
|
| 214 |
+
|
| 215 |
+
Expected: `1 passed`.
|
| 216 |
+
|
| 217 |
+
- [ ] **Step 7: Commit**
|
| 218 |
+
|
| 219 |
+
Run:
|
| 220 |
+
|
| 221 |
+
```bash
|
| 222 |
+
git add pyproject.toml flashtrace tests/test_imports.py
|
| 223 |
+
git commit -m "feat: add flashtrace package skeleton"
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
## Task 2: Core IFR Migration
|
| 227 |
+
|
| 228 |
+
**Files:**
|
| 229 |
+
- Create: `flashtrace/core.py`
|
| 230 |
+
- Create: `flashtrace/shared_utils.py`
|
| 231 |
+
- Modify: `ifr_core.py`
|
| 232 |
+
- Modify: `shared_utils.py`
|
| 233 |
+
- Create: `tests/helpers.py`
|
| 234 |
+
- Create: `tests/test_core_recompute.py`
|
| 235 |
+
|
| 236 |
+
- [ ] **Step 1: Add the tiny-model test helper**
|
| 237 |
+
|
| 238 |
+
Create `tests/helpers.py`:
|
| 239 |
+
|
| 240 |
+
```python
|
| 241 |
+
from __future__ import annotations
|
| 242 |
+
|
| 243 |
+
from tokenizers import Tokenizer, models, pre_tokenizers
|
| 244 |
+
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def make_tiny_qwen2_model_and_tokenizer(
|
| 248 |
+
*,
|
| 249 |
+
n_layers: int = 3,
|
| 250 |
+
d_model: int = 48,
|
| 251 |
+
n_heads: int = 4,
|
| 252 |
+
n_kv_heads: int = 2,
|
| 253 |
+
max_pos: int = 128,
|
| 254 |
+
):
|
| 255 |
+
config = AutoConfig.for_model(
|
| 256 |
+
"qwen2",
|
| 257 |
+
vocab_size=500,
|
| 258 |
+
hidden_size=d_model,
|
| 259 |
+
intermediate_size=d_model * 2,
|
| 260 |
+
num_hidden_layers=n_layers,
|
| 261 |
+
num_attention_heads=n_heads,
|
| 262 |
+
num_key_value_heads=n_kv_heads,
|
| 263 |
+
max_position_embeddings=max_pos,
|
| 264 |
+
use_sliding_window=False,
|
| 265 |
+
attn_implementation="eager",
|
| 266 |
+
)
|
| 267 |
+
model = AutoModelForCausalLM.from_config(config, attn_implementation="eager")
|
| 268 |
+
model.eval()
|
| 269 |
+
|
| 270 |
+
backend = Tokenizer(models.WordLevel(vocab={f"t{i}": i for i in range(500)}, unk_token="t0"))
|
| 271 |
+
backend.pre_tokenizer = pre_tokenizers.Whitespace()
|
| 272 |
+
tokenizer = PreTrainedTokenizerFast(tokenizer_object=backend, eos_token="t1", pad_token="t2")
|
| 273 |
+
tokenizer.chat_template = "{% for m in messages %}{{ m['content'] }}{% endfor %}"
|
| 274 |
+
return model, tokenizer
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
- [ ] **Step 2: Write the failing core import smoke test**
|
| 278 |
+
|
| 279 |
+
Create `tests/test_core_recompute.py`:
|
| 280 |
+
|
| 281 |
+
```python
|
| 282 |
+
import torch
|
| 283 |
+
|
| 284 |
+
from flashtrace import core
|
| 285 |
+
from tests.helpers import make_tiny_qwen2_model_and_tokenizer
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def test_core_metadata_and_weight_pack():
|
| 289 |
+
model, _ = make_tiny_qwen2_model_and_tokenizer()
|
| 290 |
+
|
| 291 |
+
metadata = core.extract_model_metadata(model)
|
| 292 |
+
weight_pack = core.build_weight_pack(metadata, next(model.parameters()).dtype)
|
| 293 |
+
|
| 294 |
+
assert metadata.n_layers == 3
|
| 295 |
+
assert metadata.n_heads_q == 4
|
| 296 |
+
assert metadata.n_kv_heads == 2
|
| 297 |
+
assert len(weight_pack) == 3
|
| 298 |
+
assert torch.is_tensor(weight_pack[0]["v_w"])
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
- [ ] **Step 3: Run the core smoke test and see the expected failure**
|
| 302 |
+
|
| 303 |
+
Run:
|
| 304 |
+
|
| 305 |
+
```bash
|
| 306 |
+
uv run pytest tests/test_core_recompute.py::test_core_metadata_and_weight_pack -q
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
Expected: pytest reports missing `flashtrace.core`.
|
| 310 |
+
|
| 311 |
+
- [ ] **Step 4: Copy the IFR core into the package**
|
| 312 |
+
|
| 313 |
+
Run:
|
| 314 |
+
|
| 315 |
+
```bash
|
| 316 |
+
cp ifr_core.py flashtrace/core.py
|
| 317 |
+
```
|
| 318 |
+
|
| 319 |
+
- [ ] **Step 5: Copy shared utilities into the package**
|
| 320 |
+
|
| 321 |
+
Run:
|
| 322 |
+
|
| 323 |
+
```bash
|
| 324 |
+
cp shared_utils.py flashtrace/shared_utils.py
|
| 325 |
+
```
|
| 326 |
+
|
| 327 |
+
- [ ] **Step 6: Replace root `ifr_core.py` with a compatibility wrapper**
|
| 328 |
+
|
| 329 |
+
Replace `ifr_core.py` with:
|
| 330 |
+
|
| 331 |
+
```python
|
| 332 |
+
"""Compatibility wrapper for package-era imports."""
|
| 333 |
+
|
| 334 |
+
from flashtrace.core import * # noqa: F401,F403
|
| 335 |
+
```
|
| 336 |
+
|
| 337 |
+
- [ ] **Step 7: Replace root `shared_utils.py` with a compatibility wrapper**
|
| 338 |
+
|
| 339 |
+
Replace `shared_utils.py` with:
|
| 340 |
+
|
| 341 |
+
```python
|
| 342 |
+
"""Compatibility wrapper for package-era imports."""
|
| 343 |
+
|
| 344 |
+
from flashtrace.shared_utils import * # noqa: F401,F403
|
| 345 |
+
```
|
| 346 |
+
|
| 347 |
+
- [ ] **Step 8: Run the core smoke test**
|
| 348 |
+
|
| 349 |
+
Run:
|
| 350 |
+
|
| 351 |
+
```bash
|
| 352 |
+
uv run pytest tests/test_core_recompute.py::test_core_metadata_and_weight_pack -q
|
| 353 |
+
```
|
| 354 |
+
|
| 355 |
+
Expected: `1 passed`.
|
| 356 |
+
|
| 357 |
+
- [ ] **Step 9: Commit**
|
| 358 |
+
|
| 359 |
+
Run:
|
| 360 |
+
|
| 361 |
+
```bash
|
| 362 |
+
git add flashtrace/core.py flashtrace/shared_utils.py ifr_core.py shared_utils.py tests/helpers.py tests/test_core_recompute.py
|
| 363 |
+
git commit -m "feat: move IFR core into package"
|
| 364 |
+
```
|
| 365 |
+
|
| 366 |
+
## Task 3: Attribution Engine Migration
|
| 367 |
+
|
| 368 |
+
**Files:**
|
| 369 |
+
- Create: `flashtrace/lrp_rules.py`
|
| 370 |
+
- Create: `flashtrace/lrp_patches.py`
|
| 371 |
+
- Create: `flashtrace/attribution.py`
|
| 372 |
+
- Create: `flashtrace/improved.py`
|
| 373 |
+
- Modify: `lrp_rules.py`
|
| 374 |
+
- Modify: `lrp_patches.py`
|
| 375 |
+
- Modify: `llm_attr.py`
|
| 376 |
+
- Modify: `ft_ifr_improve.py`
|
| 377 |
+
- Modify: `flashtrace/baselines/attnlrp.py`
|
| 378 |
+
- Test: `tests/test_core_recompute.py`
|
| 379 |
+
|
| 380 |
+
- [ ] **Step 1: Extend the recompute test with package attribution paths**
|
| 381 |
+
|
| 382 |
+
Append to `tests/test_core_recompute.py`:
|
| 383 |
+
|
| 384 |
+
```python
|
| 385 |
+
from flashtrace.attribution import LLMIFRAttribution
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def test_package_attribution_recompute_matches_stored_attention():
|
| 389 |
+
model, tokenizer = make_tiny_qwen2_model_and_tokenizer(n_layers=2, d_model=32, n_heads=4, n_kv_heads=2)
|
| 390 |
+
prompt = "t10 t20 t30 t40"
|
| 391 |
+
target = "t60 t70"
|
| 392 |
+
|
| 393 |
+
stored = LLMIFRAttribution(model, tokenizer, recompute_attention=False).calculate_ifr_span(prompt, target)
|
| 394 |
+
recomputed = LLMIFRAttribution(model, tokenizer, recompute_attention=True).calculate_ifr_span(prompt, target)
|
| 395 |
+
|
| 396 |
+
diff = (stored.attribution_matrix - recomputed.attribution_matrix).abs().max().item()
|
| 397 |
+
assert diff < 1e-5
|
| 398 |
+
```
|
| 399 |
+
|
| 400 |
+
- [ ] **Step 2: Run the package attribution test and see the expected failure**
|
| 401 |
+
|
| 402 |
+
Run:
|
| 403 |
+
|
| 404 |
+
```bash
|
| 405 |
+
uv run pytest tests/test_core_recompute.py::test_package_attribution_recompute_matches_stored_attention -q
|
| 406 |
+
```
|
| 407 |
+
|
| 408 |
+
Expected: pytest reports missing `flashtrace.attribution`.
|
| 409 |
+
|
| 410 |
+
- [ ] **Step 3: Copy LRP helpers and attribution engines into the package**
|
| 411 |
+
|
| 412 |
+
Run:
|
| 413 |
+
|
| 414 |
+
```bash
|
| 415 |
+
cp lrp_rules.py flashtrace/lrp_rules.py
|
| 416 |
+
cp lrp_patches.py flashtrace/lrp_patches.py
|
| 417 |
+
cp llm_attr.py flashtrace/attribution.py
|
| 418 |
+
cp ft_ifr_improve.py flashtrace/improved.py
|
| 419 |
+
```
|
| 420 |
+
|
| 421 |
+
- [ ] **Step 4: Update imports in `flashtrace/attribution.py`**
|
| 422 |
+
|
| 423 |
+
Edit package-local imports to this form:
|
| 424 |
+
|
| 425 |
+
```python
|
| 426 |
+
from .core import (
|
| 427 |
+
IFRParameters,
|
| 428 |
+
ModelMetadata,
|
| 429 |
+
attach_hooks,
|
| 430 |
+
build_weight_pack,
|
| 431 |
+
compute_ifr_for_all_positions,
|
| 432 |
+
compute_ifr_sentence_aggregate,
|
| 433 |
+
compute_multi_hop_ifr,
|
| 434 |
+
extract_model_metadata,
|
| 435 |
+
)
|
| 436 |
+
from .shared_utils import (
|
| 437 |
+
DEFAULT_GENERATE_KWARGS,
|
| 438 |
+
DEFAULT_PROMPT_TEMPLATE,
|
| 439 |
+
create_sentences,
|
| 440 |
+
create_sentence_masks,
|
| 441 |
+
)
|
| 442 |
+
from .lrp_patches import lrp_context, detect_model_type
|
| 443 |
+
```
|
| 444 |
+
|
| 445 |
+
- [ ] **Step 5: Update imports in `flashtrace/lrp_patches.py`**
|
| 446 |
+
|
| 447 |
+
Edit the LRP helper import to:
|
| 448 |
+
|
| 449 |
+
```python
|
| 450 |
+
from .lrp_rules import stop_gradient, divide_gradient, identity_rule_implicit
|
| 451 |
+
```
|
| 452 |
+
|
| 453 |
+
- [ ] **Step 6: Update imports in `flashtrace/improved.py`**
|
| 454 |
+
|
| 455 |
+
Edit the top-level package imports to:
|
| 456 |
+
|
| 457 |
+
```python
|
| 458 |
+
from . import attribution as llm_attr
|
| 459 |
+
from .core import IFRAggregate, MultiHopIFRResult, compute_ifr_sentence_aggregate
|
| 460 |
+
```
|
| 461 |
+
|
| 462 |
+
- [ ] **Step 7: Replace root compatibility modules**
|
| 463 |
+
|
| 464 |
+
Replace `lrp_rules.py` with:
|
| 465 |
+
|
| 466 |
+
```python
|
| 467 |
+
"""Compatibility wrapper for package-era imports."""
|
| 468 |
+
|
| 469 |
+
from flashtrace.lrp_rules import * # noqa: F401,F403
|
| 470 |
+
```
|
| 471 |
+
|
| 472 |
+
Replace `lrp_patches.py` with:
|
| 473 |
+
|
| 474 |
+
```python
|
| 475 |
+
"""Compatibility wrapper for package-era imports."""
|
| 476 |
+
|
| 477 |
+
from flashtrace.lrp_patches import * # noqa: F401,F403
|
| 478 |
+
```
|
| 479 |
+
|
| 480 |
+
Replace `llm_attr.py` with:
|
| 481 |
+
|
| 482 |
+
```python
|
| 483 |
+
"""Compatibility wrapper for package-era imports."""
|
| 484 |
+
|
| 485 |
+
from flashtrace.attribution import * # noqa: F401,F403
|
| 486 |
+
```
|
| 487 |
+
|
| 488 |
+
Replace `ft_ifr_improve.py` with:
|
| 489 |
+
|
| 490 |
+
```python
|
| 491 |
+
"""Compatibility wrapper for package-era imports."""
|
| 492 |
+
|
| 493 |
+
from flashtrace.improved import * # noqa: F401,F403
|
| 494 |
+
```
|
| 495 |
+
|
| 496 |
+
- [ ] **Step 8: Export the AttnLRP baseline**
|
| 497 |
+
|
| 498 |
+
Replace `flashtrace/baselines/attnlrp.py` with:
|
| 499 |
+
|
| 500 |
+
```python
|
| 501 |
+
"""AttnLRP baseline API."""
|
| 502 |
+
|
| 503 |
+
from flashtrace.attribution import AttnLRPSpanAggregate, LLMLRPAttribution, MultiHopAttnLRPResult
|
| 504 |
+
from flashtrace.lrp_patches import detect_model_type, lrp_context
|
| 505 |
+
|
| 506 |
+
__all__ = [
|
| 507 |
+
"AttnLRPSpanAggregate",
|
| 508 |
+
"LLMLRPAttribution",
|
| 509 |
+
"MultiHopAttnLRPResult",
|
| 510 |
+
"detect_model_type",
|
| 511 |
+
"lrp_context",
|
| 512 |
+
]
|
| 513 |
+
```
|
| 514 |
+
|
| 515 |
+
Replace `flashtrace/baselines/__init__.py` with:
|
| 516 |
+
|
| 517 |
+
```python
|
| 518 |
+
"""Baseline attribution methods for FlashTrace."""
|
| 519 |
+
|
| 520 |
+
from .attnlrp import LLMLRPAttribution
|
| 521 |
+
|
| 522 |
+
__all__ = ["LLMLRPAttribution"]
|
| 523 |
+
```
|
| 524 |
+
|
| 525 |
+
- [ ] **Step 9: Run attribution migration tests**
|
| 526 |
+
|
| 527 |
+
Run:
|
| 528 |
+
|
| 529 |
+
```bash
|
| 530 |
+
uv run pytest tests/test_core_recompute.py -q
|
| 531 |
+
```
|
| 532 |
+
|
| 533 |
+
Expected: all tests in the file pass.
|
| 534 |
+
|
| 535 |
+
- [ ] **Step 10: Run a root compatibility import check**
|
| 536 |
+
|
| 537 |
+
Run:
|
| 538 |
+
|
| 539 |
+
```bash
|
| 540 |
+
uv run python -c "import ifr_core, llm_attr, ft_ifr_improve; print(llm_attr.LLMIFRAttribution.__name__)"
|
| 541 |
+
```
|
| 542 |
+
|
| 543 |
+
Expected: prints `LLMIFRAttribution`.
|
| 544 |
+
|
| 545 |
+
- [ ] **Step 11: Commit**
|
| 546 |
+
|
| 547 |
+
Run:
|
| 548 |
+
|
| 549 |
+
```bash
|
| 550 |
+
git add flashtrace lrp_rules.py lrp_patches.py llm_attr.py ft_ifr_improve.py tests/test_core_recompute.py
|
| 551 |
+
git commit -m "feat: move attribution engines into package"
|
| 552 |
+
```
|
| 553 |
+
|
| 554 |
+
## Task 4: TraceResult And HTML Heatmap
|
| 555 |
+
|
| 556 |
+
**Files:**
|
| 557 |
+
- Modify: `flashtrace/result.py`
|
| 558 |
+
- Create: `flashtrace/viz.py`
|
| 559 |
+
- Create: `tests/test_result.py`
|
| 560 |
+
|
| 561 |
+
- [ ] **Step 1: Write result object tests**
|
| 562 |
+
|
| 563 |
+
Create `tests/test_result.py`:
|
| 564 |
+
|
| 565 |
+
```python
|
| 566 |
+
import json
|
| 567 |
+
|
| 568 |
+
from flashtrace.result import TokenScore, TraceResult
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def make_result():
|
| 572 |
+
return TraceResult(
|
| 573 |
+
prompt_tokens=[" alpha", " beta", " gamma"],
|
| 574 |
+
generation_tokens=[" answer"],
|
| 575 |
+
scores=[0.2, 0.7, 0.1],
|
| 576 |
+
per_hop_scores=[[0.1, 0.4, 0.0], [0.1, 0.3, 0.1]],
|
| 577 |
+
thinking_ratios=[0.5, 0.2],
|
| 578 |
+
output_span=(0, 0),
|
| 579 |
+
reasoning_span=(0, 0),
|
| 580 |
+
method="flashtrace",
|
| 581 |
+
metadata={"model": "tiny"},
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def test_topk_inputs_sorted():
|
| 586 |
+
result = make_result()
|
| 587 |
+
|
| 588 |
+
top = result.topk_inputs(2)
|
| 589 |
+
|
| 590 |
+
assert top == [
|
| 591 |
+
TokenScore(index=1, token=" beta", score=0.7),
|
| 592 |
+
TokenScore(index=0, token=" alpha", score=0.2),
|
| 593 |
+
]
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def test_to_dict_is_json_serializable():
|
| 597 |
+
result = make_result()
|
| 598 |
+
|
| 599 |
+
payload = result.to_dict()
|
| 600 |
+
|
| 601 |
+
assert payload["method"] == "flashtrace"
|
| 602 |
+
assert payload["top_inputs"][0]["token"] == " beta"
|
| 603 |
+
json.dumps(payload)
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def test_to_dict_sanitizes_tensor_metadata():
|
| 607 |
+
import torch
|
| 608 |
+
|
| 609 |
+
result = TraceResult(
|
| 610 |
+
prompt_tokens=[" alpha"],
|
| 611 |
+
generation_tokens=[" answer"],
|
| 612 |
+
scores=[1.0],
|
| 613 |
+
metadata={"tensor": torch.tensor([1.0, 2.0]), "object": object()},
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
payload = result.to_dict()
|
| 617 |
+
|
| 618 |
+
assert payload["metadata"]["tensor"] == [1.0, 2.0]
|
| 619 |
+
assert isinstance(payload["metadata"]["object"], str)
|
| 620 |
+
json.dumps(payload)
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def test_json_and_html_export(tmp_path):
|
| 624 |
+
result = make_result()
|
| 625 |
+
json_path = tmp_path / "trace.json"
|
| 626 |
+
html_path = tmp_path / "trace.html"
|
| 627 |
+
|
| 628 |
+
result.to_json(json_path)
|
| 629 |
+
result.to_html(html_path)
|
| 630 |
+
|
| 631 |
+
assert json_path.read_text(encoding="utf-8").startswith("{")
|
| 632 |
+
html = html_path.read_text(encoding="utf-8")
|
| 633 |
+
assert "<html" in html
|
| 634 |
+
assert " beta" in html
|
| 635 |
+
```
|
| 636 |
+
|
| 637 |
+
- [ ] **Step 2: Run result tests and see the expected failure**
|
| 638 |
+
|
| 639 |
+
Run:
|
| 640 |
+
|
| 641 |
+
```bash
|
| 642 |
+
uv run pytest tests/test_result.py -q
|
| 643 |
+
```
|
| 644 |
+
|
| 645 |
+
Expected: pytest reports missing `TokenScore` or missing methods.
|
| 646 |
+
|
| 647 |
+
- [ ] **Step 3: Implement `TraceResult`**
|
| 648 |
+
|
| 649 |
+
Replace `flashtrace/result.py` with:
|
| 650 |
+
|
| 651 |
+
```python
|
| 652 |
+
from __future__ import annotations
|
| 653 |
+
|
| 654 |
+
import json
|
| 655 |
+
from dataclasses import asdict, dataclass, field, is_dataclass
|
| 656 |
+
from pathlib import Path
|
| 657 |
+
from typing import Any
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
@dataclass(frozen=True)
|
| 661 |
+
class TokenScore:
|
| 662 |
+
index: int
|
| 663 |
+
token: str
|
| 664 |
+
score: float
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
@dataclass(frozen=True)
|
| 668 |
+
class TraceResult:
|
| 669 |
+
"""Public attribution result returned by FlashTrace."""
|
| 670 |
+
|
| 671 |
+
prompt_tokens: list[str]
|
| 672 |
+
generation_tokens: list[str]
|
| 673 |
+
scores: list[float]
|
| 674 |
+
per_hop_scores: list[list[float]] = field(default_factory=list)
|
| 675 |
+
thinking_ratios: list[float] = field(default_factory=list)
|
| 676 |
+
output_span: tuple[int, int] | None = None
|
| 677 |
+
reasoning_span: tuple[int, int] | None = None
|
| 678 |
+
method: str = "flashtrace"
|
| 679 |
+
metadata: dict[str, Any] = field(default_factory=dict)
|
| 680 |
+
|
| 681 |
+
def topk_inputs(self, k: int = 20) -> list[TokenScore]:
|
| 682 |
+
limit = max(0, int(k))
|
| 683 |
+
items = [
|
| 684 |
+
TokenScore(index=i, token=tok, score=float(score))
|
| 685 |
+
for i, (tok, score) in enumerate(zip(self.prompt_tokens, self.scores))
|
| 686 |
+
]
|
| 687 |
+
items.sort(key=lambda item: item.score, reverse=True)
|
| 688 |
+
return items[:limit]
|
| 689 |
+
|
| 690 |
+
def to_dict(self) -> dict[str, Any]:
|
| 691 |
+
return {
|
| 692 |
+
"method": self.method,
|
| 693 |
+
"prompt_tokens": list(self.prompt_tokens),
|
| 694 |
+
"generation_tokens": list(self.generation_tokens),
|
| 695 |
+
"scores": [float(x) for x in self.scores],
|
| 696 |
+
"per_hop_scores": [[float(x) for x in row] for row in self.per_hop_scores],
|
| 697 |
+
"thinking_ratios": [float(x) for x in self.thinking_ratios],
|
| 698 |
+
"output_span": list(self.output_span) if self.output_span is not None else None,
|
| 699 |
+
"reasoning_span": list(self.reasoning_span) if self.reasoning_span is not None else None,
|
| 700 |
+
"top_inputs": [asdict(item) for item in self.topk_inputs()],
|
| 701 |
+
"metadata": _jsonable(self.metadata),
|
| 702 |
+
}
|
| 703 |
+
|
| 704 |
+
def to_json(self, path: str | Path) -> None:
|
| 705 |
+
target = Path(path)
|
| 706 |
+
target.write_text(json.dumps(self.to_dict(), indent=2, ensure_ascii=False), encoding="utf-8")
|
| 707 |
+
|
| 708 |
+
def to_html(self, path: str | Path) -> None:
|
| 709 |
+
from .viz import render_trace_html
|
| 710 |
+
|
| 711 |
+
target = Path(path)
|
| 712 |
+
target.write_text(render_trace_html(self), encoding="utf-8")
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
def _jsonable(value: Any) -> Any:
|
| 716 |
+
if value is None or isinstance(value, (str, int, float, bool)):
|
| 717 |
+
return value
|
| 718 |
+
if hasattr(value, "detach") and hasattr(value, "cpu"):
|
| 719 |
+
try:
|
| 720 |
+
return value.detach().cpu().tolist()
|
| 721 |
+
except Exception:
|
| 722 |
+
return repr(value)
|
| 723 |
+
if is_dataclass(value):
|
| 724 |
+
return _jsonable(asdict(value))
|
| 725 |
+
if isinstance(value, dict):
|
| 726 |
+
return {str(k): _jsonable(v) for k, v in value.items()}
|
| 727 |
+
if isinstance(value, (list, tuple)):
|
| 728 |
+
return [_jsonable(v) for v in value]
|
| 729 |
+
return repr(value)
|
| 730 |
+
```
|
| 731 |
+
|
| 732 |
+
- [ ] **Step 4: Implement the standalone HTML renderer**
|
| 733 |
+
|
| 734 |
+
Create `flashtrace/viz.py`:
|
| 735 |
+
|
| 736 |
+
```python
|
| 737 |
+
from __future__ import annotations
|
| 738 |
+
|
| 739 |
+
from html import escape
|
| 740 |
+
from typing import TYPE_CHECKING
|
| 741 |
+
|
| 742 |
+
if TYPE_CHECKING:
|
| 743 |
+
from .result import TraceResult
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
def _score_color(score: float, max_score: float) -> str:
|
| 747 |
+
if max_score <= 0.0:
|
| 748 |
+
return "rgba(245,245,245,0.75)"
|
| 749 |
+
ratio = min(1.0, abs(float(score)) / (max_score + 1e-12))
|
| 750 |
+
red = 255
|
| 751 |
+
green = int(246 - 105 * ratio)
|
| 752 |
+
blue = int(226 - 170 * ratio)
|
| 753 |
+
alpha = 0.22 + 0.58 * ratio
|
| 754 |
+
return f"rgba({red},{green},{blue},{alpha:.3f})"
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
def _render_token_row(tokens: list[str], scores: list[float]) -> str:
|
| 758 |
+
max_score = max((abs(float(x)) for x in scores), default=0.0)
|
| 759 |
+
spans = []
|
| 760 |
+
for index, token in enumerate(tokens):
|
| 761 |
+
score = float(scores[index]) if index < len(scores) else 0.0
|
| 762 |
+
color = _score_color(score, max_score)
|
| 763 |
+
spans.append(
|
| 764 |
+
"<span class='tok' "
|
| 765 |
+
f"title='idx={index} score={score:.6f}' "
|
| 766 |
+
f"style='background:{color}'>{escape(token)}</span>"
|
| 767 |
+
)
|
| 768 |
+
return "".join(spans)
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
def render_trace_html(result: "TraceResult") -> str:
|
| 772 |
+
top_rows = "\n".join(
|
| 773 |
+
f"<tr><td>{item.index}</td><td><code>{escape(item.token)}</code></td><td>{item.score:.6f}</td></tr>"
|
| 774 |
+
for item in result.topk_inputs(20)
|
| 775 |
+
)
|
| 776 |
+
hop_sections = []
|
| 777 |
+
for hop_index, hop_scores in enumerate(result.per_hop_scores):
|
| 778 |
+
hop_sections.append(
|
| 779 |
+
f"<section><h2>Hop {hop_index}</h2><div class='tokens'>{_render_token_row(result.prompt_tokens, hop_scores)}</div></section>"
|
| 780 |
+
)
|
| 781 |
+
hop_html = "\n".join(hop_sections)
|
| 782 |
+
metadata = escape(str(result.metadata))
|
| 783 |
+
return f"""<!doctype html>
|
| 784 |
+
<html lang="en">
|
| 785 |
+
<head>
|
| 786 |
+
<meta charset="utf-8">
|
| 787 |
+
<title>FlashTrace</title>
|
| 788 |
+
<style>
|
| 789 |
+
body {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; margin: 32px; color: #151515; }}
|
| 790 |
+
h1, h2 {{ margin: 0 0 12px; }}
|
| 791 |
+
section {{ margin: 24px 0; }}
|
| 792 |
+
.tokens {{ line-height: 2.2; font-family: ui-monospace, SFMono-Regular, Menlo, monospace; }}
|
| 793 |
+
.tok {{ display: inline-block; margin: 2px; padding: 2px 4px; border-radius: 4px; white-space: pre-wrap; }}
|
| 794 |
+
table {{ border-collapse: collapse; margin-top: 12px; }}
|
| 795 |
+
td, th {{ border-bottom: 1px solid #ddd; padding: 6px 10px; text-align: left; }}
|
| 796 |
+
.meta {{ color: #555; font-size: 13px; }}
|
| 797 |
+
</style>
|
| 798 |
+
</head>
|
| 799 |
+
<body>
|
| 800 |
+
<h1>FlashTrace</h1>
|
| 801 |
+
<p class="meta">method={escape(result.method)} output_span={escape(str(result.output_span))} reasoning_span={escape(str(result.reasoning_span))}</p>
|
| 802 |
+
<section>
|
| 803 |
+
<h2>Prompt Attribution</h2>
|
| 804 |
+
<div class="tokens">{_render_token_row(result.prompt_tokens, result.scores)}</div>
|
| 805 |
+
</section>
|
| 806 |
+
{hop_html}
|
| 807 |
+
<section>
|
| 808 |
+
<h2>Top Input Tokens</h2>
|
| 809 |
+
<table><thead><tr><th>Index</th><th>Token</th><th>Score</th></tr></thead><tbody>{top_rows}</tbody></table>
|
| 810 |
+
</section>
|
| 811 |
+
<section>
|
| 812 |
+
<h2>Metadata</h2>
|
| 813 |
+
<pre>{metadata}</pre>
|
| 814 |
+
</section>
|
| 815 |
+
</body>
|
| 816 |
+
</html>
|
| 817 |
+
"""
|
| 818 |
+
```
|
| 819 |
+
|
| 820 |
+
- [ ] **Step 5: Run result tests**
|
| 821 |
+
|
| 822 |
+
Run:
|
| 823 |
+
|
| 824 |
+
```bash
|
| 825 |
+
uv run pytest tests/test_result.py -q
|
| 826 |
+
```
|
| 827 |
+
|
| 828 |
+
Expected: `4 passed`.
|
| 829 |
+
|
| 830 |
+
- [ ] **Step 6: Commit**
|
| 831 |
+
|
| 832 |
+
Run:
|
| 833 |
+
|
| 834 |
+
```bash
|
| 835 |
+
git add flashtrace/result.py flashtrace/viz.py tests/test_result.py
|
| 836 |
+
git commit -m "feat: add trace result exports"
|
| 837 |
+
```
|
| 838 |
+
|
| 839 |
+
## Task 5: FlashTrace Facade
|
| 840 |
+
|
| 841 |
+
**Files:**
|
| 842 |
+
- Modify: `flashtrace/tracer.py`
|
| 843 |
+
- Modify: `flashtrace/__init__.py`
|
| 844 |
+
- Create: `tests/test_tracer.py`
|
| 845 |
+
|
| 846 |
+
- [ ] **Step 1: Write tracer API tests**
|
| 847 |
+
|
| 848 |
+
Create `tests/test_tracer.py`:
|
| 849 |
+
|
| 850 |
+
```python
|
| 851 |
+
from flashtrace import FlashTrace, TraceResult
|
| 852 |
+
from tests.helpers import make_tiny_qwen2_model_and_tokenizer
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
def test_flashtrace_trace_returns_public_result():
|
| 856 |
+
model, tokenizer = make_tiny_qwen2_model_and_tokenizer(n_layers=2, d_model=32, n_heads=4, n_kv_heads=2)
|
| 857 |
+
tracer = FlashTrace(model, tokenizer, chunk_tokens=16, sink_chunk_tokens=4, recompute_attention=True)
|
| 858 |
+
|
| 859 |
+
result = tracer.trace(
|
| 860 |
+
prompt="t10 t20 t30 t40",
|
| 861 |
+
target="t60 t70 t80",
|
| 862 |
+
output_span=(1, 2),
|
| 863 |
+
reasoning_span=(0, 1),
|
| 864 |
+
hops=1,
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
assert isinstance(result, TraceResult)
|
| 868 |
+
assert result.method == "flashtrace"
|
| 869 |
+
assert len(result.prompt_tokens) > 0
|
| 870 |
+
assert len(result.scores) == len(result.prompt_tokens)
|
| 871 |
+
assert result.output_span == (1, 2)
|
| 872 |
+
assert result.reasoning_span == (0, 1)
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
def test_ifr_span_method_returns_public_result():
|
| 876 |
+
model, tokenizer = make_tiny_qwen2_model_and_tokenizer(n_layers=2, d_model=32, n_heads=4, n_kv_heads=2)
|
| 877 |
+
tracer = FlashTrace(model, tokenizer, chunk_tokens=16, sink_chunk_tokens=4, recompute_attention=True)
|
| 878 |
+
|
| 879 |
+
result = tracer.trace(
|
| 880 |
+
prompt="t10 t20 t30 t40",
|
| 881 |
+
target="t60 t70",
|
| 882 |
+
output_span=(0, 1),
|
| 883 |
+
method="ifr-span",
|
| 884 |
+
)
|
| 885 |
+
|
| 886 |
+
assert result.method == "ifr-span"
|
| 887 |
+
assert len(result.scores) == len(result.prompt_tokens)
|
| 888 |
+
```
|
| 889 |
+
|
| 890 |
+
- [ ] **Step 2: Run tracer tests and see the expected failure**
|
| 891 |
+
|
| 892 |
+
Run:
|
| 893 |
+
|
| 894 |
+
```bash
|
| 895 |
+
uv run pytest tests/test_tracer.py -q
|
| 896 |
+
```
|
| 897 |
+
|
| 898 |
+
Expected: pytest reports missing `trace`.
|
| 899 |
+
|
| 900 |
+
- [ ] **Step 3: Implement result adaptation helpers and facade**
|
| 901 |
+
|
| 902 |
+
Replace `flashtrace/tracer.py` with:
|
| 903 |
+
|
| 904 |
+
```python
|
| 905 |
+
from __future__ import annotations
|
| 906 |
+
|
| 907 |
+
from typing import Any, Literal
|
| 908 |
+
|
| 909 |
+
import torch
|
| 910 |
+
|
| 911 |
+
from .attribution import LLMIFRAttribution, LLMAttributionResult
|
| 912 |
+
from .improved import LLMIFRAttributionBoth
|
| 913 |
+
from .result import TraceResult
|
| 914 |
+
|
| 915 |
+
TraceMethod = Literal["flashtrace", "ifr-span", "ifr-matrix"]
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
def _to_float_list(values: Any) -> list[float]:
|
| 919 |
+
if torch.is_tensor(values):
|
| 920 |
+
values = values.detach().cpu().to(dtype=torch.float32).tolist()
|
| 921 |
+
return [float(x) for x in (values or [])]
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
class FlashTrace:
|
| 925 |
+
"""Public facade for FlashTrace attribution."""
|
| 926 |
+
|
| 927 |
+
def __init__(
|
| 928 |
+
self,
|
| 929 |
+
model,
|
| 930 |
+
tokenizer,
|
| 931 |
+
*,
|
| 932 |
+
chunk_tokens: int = 128,
|
| 933 |
+
sink_chunk_tokens: int = 32,
|
| 934 |
+
recompute_attention: bool = False,
|
| 935 |
+
generate_kwargs: dict[str, Any] | None = None,
|
| 936 |
+
) -> None:
|
| 937 |
+
self.model = model
|
| 938 |
+
self.tokenizer = tokenizer
|
| 939 |
+
self.chunk_tokens = int(chunk_tokens)
|
| 940 |
+
self.sink_chunk_tokens = int(sink_chunk_tokens)
|
| 941 |
+
self.recompute_attention = bool(recompute_attention)
|
| 942 |
+
self.generate_kwargs = generate_kwargs
|
| 943 |
+
|
| 944 |
+
def trace(
|
| 945 |
+
self,
|
| 946 |
+
*,
|
| 947 |
+
prompt: str,
|
| 948 |
+
target: str | None = None,
|
| 949 |
+
output_span: tuple[int, int] | None = None,
|
| 950 |
+
reasoning_span: tuple[int, int] | None = None,
|
| 951 |
+
hops: int = 1,
|
| 952 |
+
method: TraceMethod = "flashtrace",
|
| 953 |
+
renorm_threshold: float | None = None,
|
| 954 |
+
) -> TraceResult:
|
| 955 |
+
if method == "flashtrace":
|
| 956 |
+
engine = LLMIFRAttributionBoth(
|
| 957 |
+
self.model,
|
| 958 |
+
self.tokenizer,
|
| 959 |
+
generate_kwargs=self.generate_kwargs,
|
| 960 |
+
chunk_tokens=self.chunk_tokens,
|
| 961 |
+
sink_chunk_tokens=self.sink_chunk_tokens,
|
| 962 |
+
recompute_attention=self.recompute_attention,
|
| 963 |
+
)
|
| 964 |
+
raw = engine.calculate_ifr_multi_hop_both(
|
| 965 |
+
prompt,
|
| 966 |
+
target=target,
|
| 967 |
+
sink_span=output_span,
|
| 968 |
+
thinking_span=reasoning_span,
|
| 969 |
+
n_hops=int(hops),
|
| 970 |
+
renorm_threshold=renorm_threshold,
|
| 971 |
+
)
|
| 972 |
+
elif method == "ifr-span":
|
| 973 |
+
engine = LLMIFRAttribution(
|
| 974 |
+
self.model,
|
| 975 |
+
self.tokenizer,
|
| 976 |
+
generate_kwargs=self.generate_kwargs,
|
| 977 |
+
chunk_tokens=self.chunk_tokens,
|
| 978 |
+
sink_chunk_tokens=self.sink_chunk_tokens,
|
| 979 |
+
recompute_attention=self.recompute_attention,
|
| 980 |
+
)
|
| 981 |
+
raw = engine.calculate_ifr_span(
|
| 982 |
+
prompt,
|
| 983 |
+
target=target,
|
| 984 |
+
span=output_span,
|
| 985 |
+
renorm_threshold=renorm_threshold,
|
| 986 |
+
)
|
| 987 |
+
elif method == "ifr-matrix":
|
| 988 |
+
engine = LLMIFRAttribution(
|
| 989 |
+
self.model,
|
| 990 |
+
self.tokenizer,
|
| 991 |
+
generate_kwargs=self.generate_kwargs,
|
| 992 |
+
chunk_tokens=self.chunk_tokens,
|
| 993 |
+
sink_chunk_tokens=self.sink_chunk_tokens,
|
| 994 |
+
recompute_attention=self.recompute_attention,
|
| 995 |
+
)
|
| 996 |
+
raw = engine.calculate_ifr_for_all_positions_output_only(
|
| 997 |
+
prompt,
|
| 998 |
+
target=target,
|
| 999 |
+
sink_span=output_span,
|
| 1000 |
+
renorm_threshold=renorm_threshold,
|
| 1001 |
+
)
|
| 1002 |
+
else:
|
| 1003 |
+
raise ValueError(f"Unsupported method: {method}")
|
| 1004 |
+
|
| 1005 |
+
return self._build_result(raw, method=method, output_span=output_span, reasoning_span=reasoning_span)
|
| 1006 |
+
|
| 1007 |
+
def _build_result(
|
| 1008 |
+
self,
|
| 1009 |
+
raw: LLMAttributionResult,
|
| 1010 |
+
*,
|
| 1011 |
+
method: str,
|
| 1012 |
+
output_span: tuple[int, int] | None,
|
| 1013 |
+
reasoning_span: tuple[int, int] | None,
|
| 1014 |
+
) -> TraceResult:
|
| 1015 |
+
prompt_tokens = list(raw.prompt_tokens)
|
| 1016 |
+
generation_tokens = list(raw.generation_tokens)
|
| 1017 |
+
prompt_len = len(prompt_tokens)
|
| 1018 |
+
metadata = dict(raw.metadata or {})
|
| 1019 |
+
if "method" not in metadata:
|
| 1020 |
+
metadata["method"] = method
|
| 1021 |
+
|
| 1022 |
+
ifr_meta = metadata.get("ifr") if isinstance(metadata.get("ifr"), dict) else {}
|
| 1023 |
+
observation = ifr_meta.get("observation_projected") if isinstance(ifr_meta, dict) else None
|
| 1024 |
+
per_hop_projected = ifr_meta.get("per_hop_projected") if isinstance(ifr_meta, dict) else None
|
| 1025 |
+
|
| 1026 |
+
if isinstance(observation, dict) and "sum" in observation:
|
| 1027 |
+
vector = _to_float_list(observation["sum"])
|
| 1028 |
+
scores = vector[:prompt_len]
|
| 1029 |
+
else:
|
| 1030 |
+
matrix = torch.nan_to_num(raw.attribution_matrix.detach().cpu().to(dtype=torch.float32), nan=0.0)
|
| 1031 |
+
if output_span is not None:
|
| 1032 |
+
start, end = output_span
|
| 1033 |
+
selected = matrix[int(start) : int(end) + 1, :prompt_len]
|
| 1034 |
+
else:
|
| 1035 |
+
selected = matrix[:, :prompt_len]
|
| 1036 |
+
scores = selected.mean(dim=0).tolist() if selected.numel() else [0.0 for _ in prompt_tokens]
|
| 1037 |
+
|
| 1038 |
+
per_hop_scores: list[list[float]] = []
|
| 1039 |
+
if per_hop_projected:
|
| 1040 |
+
for hop_vector in per_hop_projected:
|
| 1041 |
+
per_hop_scores.append(_to_float_list(hop_vector)[:prompt_len])
|
| 1042 |
+
|
| 1043 |
+
ratios = ifr_meta.get("thinking_ratios", []) if isinstance(ifr_meta, dict) else []
|
| 1044 |
+
return TraceResult(
|
| 1045 |
+
prompt_tokens=prompt_tokens,
|
| 1046 |
+
generation_tokens=generation_tokens,
|
| 1047 |
+
scores=[float(x) for x in scores],
|
| 1048 |
+
per_hop_scores=per_hop_scores,
|
| 1049 |
+
thinking_ratios=_to_float_list(ratios),
|
| 1050 |
+
output_span=output_span,
|
| 1051 |
+
reasoning_span=reasoning_span,
|
| 1052 |
+
method=method,
|
| 1053 |
+
metadata=metadata,
|
| 1054 |
+
)
|
| 1055 |
+
```
|
| 1056 |
+
|
| 1057 |
+
- [ ] **Step 4: Confirm public exports**
|
| 1058 |
+
|
| 1059 |
+
Keep `flashtrace/__init__.py` as:
|
| 1060 |
+
|
| 1061 |
+
```python
|
| 1062 |
+
"""FlashTrace: efficient multi-token attribution for reasoning LLMs."""
|
| 1063 |
+
|
| 1064 |
+
from .model_io import load_model_and_tokenizer
|
| 1065 |
+
from .result import TokenScore, TraceResult
|
| 1066 |
+
from .tracer import FlashTrace
|
| 1067 |
+
|
| 1068 |
+
__all__ = ["FlashTrace", "TraceResult", "TokenScore", "load_model_and_tokenizer"]
|
| 1069 |
+
```
|
| 1070 |
+
|
| 1071 |
+
- [ ] **Step 5: Run tracer tests**
|
| 1072 |
+
|
| 1073 |
+
Run:
|
| 1074 |
+
|
| 1075 |
+
```bash
|
| 1076 |
+
uv run pytest tests/test_tracer.py -q
|
| 1077 |
+
```
|
| 1078 |
+
|
| 1079 |
+
Expected: `2 passed`.
|
| 1080 |
+
|
| 1081 |
+
- [ ] **Step 6: Run package tests created so far**
|
| 1082 |
+
|
| 1083 |
+
Run:
|
| 1084 |
+
|
| 1085 |
+
```bash
|
| 1086 |
+
uv run pytest tests/test_imports.py tests/test_core_recompute.py tests/test_result.py tests/test_tracer.py -q
|
| 1087 |
+
```
|
| 1088 |
+
|
| 1089 |
+
Expected: all selected tests pass.
|
| 1090 |
+
|
| 1091 |
+
- [ ] **Step 7: Commit**
|
| 1092 |
+
|
| 1093 |
+
Run:
|
| 1094 |
+
|
| 1095 |
+
```bash
|
| 1096 |
+
git add flashtrace/tracer.py flashtrace/__init__.py tests/test_tracer.py
|
| 1097 |
+
git commit -m "feat: add FlashTrace public facade"
|
| 1098 |
+
```
|
| 1099 |
+
|
| 1100 |
+
## Task 6: Model IO And CLI
|
| 1101 |
+
|
| 1102 |
+
**Files:**
|
| 1103 |
+
- Modify: `flashtrace/model_io.py`
|
| 1104 |
+
- Modify: `flashtrace/cli.py`
|
| 1105 |
+
- Create: `tests/test_cli.py`
|
| 1106 |
+
|
| 1107 |
+
- [ ] **Step 1: Write CLI tests**
|
| 1108 |
+
|
| 1109 |
+
Create `tests/test_cli.py`:
|
| 1110 |
+
|
| 1111 |
+
```python
|
| 1112 |
+
import pytest
|
| 1113 |
+
|
| 1114 |
+
from flashtrace.cli import main, parse_span
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
def test_parse_span():
|
| 1118 |
+
assert parse_span("3:8") == (3, 8)
|
| 1119 |
+
assert parse_span(None) is None
|
| 1120 |
+
|
| 1121 |
+
|
| 1122 |
+
@pytest.mark.parametrize("value", ["3", "8:3", "a:b"])
|
| 1123 |
+
def test_parse_span_rejects_invalid_values(value):
|
| 1124 |
+
with pytest.raises(ValueError):
|
| 1125 |
+
parse_span(value)
|
| 1126 |
+
|
| 1127 |
+
|
| 1128 |
+
def test_cli_help_exits_successfully(capsys):
|
| 1129 |
+
with pytest.raises(SystemExit) as exc:
|
| 1130 |
+
main(["--help"])
|
| 1131 |
+
|
| 1132 |
+
assert exc.value.code == 0
|
| 1133 |
+
assert "trace" in capsys.readouterr().out
|
| 1134 |
+
```
|
| 1135 |
+
|
| 1136 |
+
- [ ] **Step 2: Run CLI tests and see the expected failure**
|
| 1137 |
+
|
| 1138 |
+
Run:
|
| 1139 |
+
|
| 1140 |
+
```bash
|
| 1141 |
+
uv run pytest tests/test_cli.py -q
|
| 1142 |
+
```
|
| 1143 |
+
|
| 1144 |
+
Expected: pytest reports missing `parse_span`.
|
| 1145 |
+
|
| 1146 |
+
- [ ] **Step 3: Implement model loading**
|
| 1147 |
+
|
| 1148 |
+
Replace `flashtrace/model_io.py` with:
|
| 1149 |
+
|
| 1150 |
+
```python
|
| 1151 |
+
from __future__ import annotations
|
| 1152 |
+
|
| 1153 |
+
from typing import Any
|
| 1154 |
+
|
| 1155 |
+
import torch
|
| 1156 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 1157 |
+
|
| 1158 |
+
|
| 1159 |
+
def _resolve_dtype(dtype: str | torch.dtype = "auto") -> str | torch.dtype:
|
| 1160 |
+
if isinstance(dtype, torch.dtype):
|
| 1161 |
+
return dtype
|
| 1162 |
+
value = str(dtype).lower()
|
| 1163 |
+
if value == "auto":
|
| 1164 |
+
return "auto"
|
| 1165 |
+
mapping = {
|
| 1166 |
+
"float16": torch.float16,
|
| 1167 |
+
"fp16": torch.float16,
|
| 1168 |
+
"bfloat16": torch.bfloat16,
|
| 1169 |
+
"bf16": torch.bfloat16,
|
| 1170 |
+
"float32": torch.float32,
|
| 1171 |
+
"fp32": torch.float32,
|
| 1172 |
+
}
|
| 1173 |
+
if value not in mapping:
|
| 1174 |
+
raise ValueError(f"Unsupported dtype: {dtype}")
|
| 1175 |
+
return mapping[value]
|
| 1176 |
+
|
| 1177 |
+
|
| 1178 |
+
def load_model_and_tokenizer(
|
| 1179 |
+
model_name_or_path: str,
|
| 1180 |
+
*,
|
| 1181 |
+
device_map: str | dict[str, Any] | None = "auto",
|
| 1182 |
+
dtype: str | torch.dtype = "auto",
|
| 1183 |
+
trust_remote_code: bool = True,
|
| 1184 |
+
**model_kwargs: Any,
|
| 1185 |
+
):
|
| 1186 |
+
"""Load a Hugging Face causal LM and matching tokenizer."""
|
| 1187 |
+
|
| 1188 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
|
| 1189 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 1190 |
+
model_name_or_path,
|
| 1191 |
+
torch_dtype=_resolve_dtype(dtype),
|
| 1192 |
+
device_map=device_map,
|
| 1193 |
+
trust_remote_code=trust_remote_code,
|
| 1194 |
+
**model_kwargs,
|
| 1195 |
+
)
|
| 1196 |
+
model.eval()
|
| 1197 |
+
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
| 1198 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 1199 |
+
return model, tokenizer
|
| 1200 |
+
```
|
| 1201 |
+
|
| 1202 |
+
- [ ] **Step 4: Implement CLI**
|
| 1203 |
+
|
| 1204 |
+
Replace `flashtrace/cli.py` with:
|
| 1205 |
+
|
| 1206 |
+
```python
|
| 1207 |
+
from __future__ import annotations
|
| 1208 |
+
|
| 1209 |
+
import argparse
|
| 1210 |
+
from pathlib import Path
|
| 1211 |
+
from typing import Sequence
|
| 1212 |
+
|
| 1213 |
+
from .model_io import load_model_and_tokenizer
|
| 1214 |
+
from .tracer import FlashTrace
|
| 1215 |
+
|
| 1216 |
+
|
| 1217 |
+
def parse_span(value: str | None) -> tuple[int, int] | None:
|
| 1218 |
+
if value is None:
|
| 1219 |
+
return None
|
| 1220 |
+
parts = str(value).split(":")
|
| 1221 |
+
if len(parts) != 2:
|
| 1222 |
+
raise ValueError("Span must use START:END format.")
|
| 1223 |
+
try:
|
| 1224 |
+
start = int(parts[0])
|
| 1225 |
+
end = int(parts[1])
|
| 1226 |
+
except ValueError as exc:
|
| 1227 |
+
raise ValueError("Span bounds must be integers.") from exc
|
| 1228 |
+
if start < 0 or end < start:
|
| 1229 |
+
raise ValueError("Span must satisfy 0 <= START <= END.")
|
| 1230 |
+
return start, end
|
| 1231 |
+
|
| 1232 |
+
|
| 1233 |
+
def build_parser() -> argparse.ArgumentParser:
|
| 1234 |
+
parser = argparse.ArgumentParser(prog="flashtrace", description="Trace language model outputs with FlashTrace.")
|
| 1235 |
+
sub = parser.add_subparsers(dest="command")
|
| 1236 |
+
|
| 1237 |
+
trace = sub.add_parser("trace", help="Run attribution for a prompt and target.")
|
| 1238 |
+
trace.add_argument("--model", required=True, help="Hugging Face model id or local path.")
|
| 1239 |
+
trace.add_argument("--prompt", required=True, help="UTF-8 text file containing the prompt.")
|
| 1240 |
+
trace.add_argument("--target", help="UTF-8 text file containing the target response.")
|
| 1241 |
+
trace.add_argument("--output-span", help="Inclusive generation-token span START:END.")
|
| 1242 |
+
trace.add_argument("--reasoning-span", help="Inclusive generation-token span START:END.")
|
| 1243 |
+
trace.add_argument("--hops", type=int, default=1)
|
| 1244 |
+
trace.add_argument("--method", default="flashtrace", choices=["flashtrace", "ifr-span", "ifr-matrix"])
|
| 1245 |
+
trace.add_argument("--html", help="Write standalone HTML heatmap.")
|
| 1246 |
+
trace.add_argument("--json", help="Write JSON trace.")
|
| 1247 |
+
trace.add_argument("--device-map", default="auto")
|
| 1248 |
+
trace.add_argument("--dtype", default="auto", choices=["auto", "float16", "bfloat16", "float32"])
|
| 1249 |
+
trace.add_argument("--chunk-tokens", type=int, default=128)
|
| 1250 |
+
trace.add_argument("--sink-chunk-tokens", type=int, default=32)
|
| 1251 |
+
trace.add_argument("--recompute-attention", action="store_true")
|
| 1252 |
+
return parser
|
| 1253 |
+
|
| 1254 |
+
|
| 1255 |
+
def _read_text(path: str | None) -> str | None:
|
| 1256 |
+
if path is None:
|
| 1257 |
+
return None
|
| 1258 |
+
return Path(path).read_text(encoding="utf-8")
|
| 1259 |
+
|
| 1260 |
+
|
| 1261 |
+
def _run_trace(args: argparse.Namespace) -> int:
|
| 1262 |
+
model, tokenizer = load_model_and_tokenizer(args.model, device_map=args.device_map, dtype=args.dtype)
|
| 1263 |
+
tracer = FlashTrace(
|
| 1264 |
+
model,
|
| 1265 |
+
tokenizer,
|
| 1266 |
+
chunk_tokens=args.chunk_tokens,
|
| 1267 |
+
sink_chunk_tokens=args.sink_chunk_tokens,
|
| 1268 |
+
recompute_attention=args.recompute_attention,
|
| 1269 |
+
)
|
| 1270 |
+
result = tracer.trace(
|
| 1271 |
+
prompt=_read_text(args.prompt) or "",
|
| 1272 |
+
target=_read_text(args.target),
|
| 1273 |
+
output_span=parse_span(args.output_span),
|
| 1274 |
+
reasoning_span=parse_span(args.reasoning_span),
|
| 1275 |
+
hops=args.hops,
|
| 1276 |
+
method=args.method,
|
| 1277 |
+
)
|
| 1278 |
+
for item in result.topk_inputs(20):
|
| 1279 |
+
print(f"{item.index}\t{item.score:.6f}\t{item.token!r}")
|
| 1280 |
+
if args.json:
|
| 1281 |
+
result.to_json(args.json)
|
| 1282 |
+
if args.html:
|
| 1283 |
+
result.to_html(args.html)
|
| 1284 |
+
return 0
|
| 1285 |
+
|
| 1286 |
+
|
| 1287 |
+
def main(argv: Sequence[str] | None = None) -> int:
|
| 1288 |
+
parser = build_parser()
|
| 1289 |
+
args = parser.parse_args(argv)
|
| 1290 |
+
if args.command == "trace":
|
| 1291 |
+
return _run_trace(args)
|
| 1292 |
+
parser.print_help()
|
| 1293 |
+
return 0
|
| 1294 |
+
```
|
| 1295 |
+
|
| 1296 |
+
- [ ] **Step 5: Run CLI tests**
|
| 1297 |
+
|
| 1298 |
+
Run:
|
| 1299 |
+
|
| 1300 |
+
```bash
|
| 1301 |
+
uv run pytest tests/test_cli.py -q
|
| 1302 |
+
```
|
| 1303 |
+
|
| 1304 |
+
Expected: all CLI tests pass.
|
| 1305 |
+
|
| 1306 |
+
- [ ] **Step 6: Verify console script metadata**
|
| 1307 |
+
|
| 1308 |
+
Run:
|
| 1309 |
+
|
| 1310 |
+
```bash
|
| 1311 |
+
uv run flashtrace --help
|
| 1312 |
+
```
|
| 1313 |
+
|
| 1314 |
+
Expected: help text includes `trace`.
|
| 1315 |
+
|
| 1316 |
+
- [ ] **Step 7: Commit**
|
| 1317 |
+
|
| 1318 |
+
Run:
|
| 1319 |
+
|
| 1320 |
+
```bash
|
| 1321 |
+
git add flashtrace/model_io.py flashtrace/cli.py tests/test_cli.py
|
| 1322 |
+
git commit -m "feat: add model loader and CLI"
|
| 1323 |
+
```
|
| 1324 |
+
|
| 1325 |
+
## Task 7: README, Example, License, And Release Hygiene
|
| 1326 |
+
|
| 1327 |
+
**Files:**
|
| 1328 |
+
- Create: `README.md`
|
| 1329 |
+
- Create: `LICENSE`
|
| 1330 |
+
- Create: `examples/quickstart.py`
|
| 1331 |
+
- Modify: `.gitignore`
|
| 1332 |
+
- Delete: `model_generation.py`
|
| 1333 |
+
|
| 1334 |
+
- [ ] **Step 1: Create the quickstart example**
|
| 1335 |
+
|
| 1336 |
+
Create `examples/quickstart.py`:
|
| 1337 |
+
|
| 1338 |
+
```python
|
| 1339 |
+
from __future__ import annotations
|
| 1340 |
+
|
| 1341 |
+
import argparse
|
| 1342 |
+
|
| 1343 |
+
from flashtrace import FlashTrace, load_model_and_tokenizer
|
| 1344 |
+
|
| 1345 |
+
|
| 1346 |
+
def build_parser() -> argparse.ArgumentParser:
|
| 1347 |
+
parser = argparse.ArgumentParser(description="FlashTrace quickstart example.")
|
| 1348 |
+
parser.add_argument("--model", required=True, help="Hugging Face model id or local model path.")
|
| 1349 |
+
parser.add_argument("--prompt", required=True, help="Prompt text.")
|
| 1350 |
+
parser.add_argument("--target", help="Target response text.")
|
| 1351 |
+
parser.add_argument("--output-span", default=None, help="Inclusive generation-token span START:END.")
|
| 1352 |
+
parser.add_argument("--reasoning-span", default=None, help="Inclusive generation-token span START:END.")
|
| 1353 |
+
parser.add_argument("--html", default="trace.html", help="Output HTML path.")
|
| 1354 |
+
return parser
|
| 1355 |
+
|
| 1356 |
+
|
| 1357 |
+
def parse_span(value: str | None) -> tuple[int, int] | None:
|
| 1358 |
+
from flashtrace.cli import parse_span as parse_cli_span
|
| 1359 |
+
|
| 1360 |
+
return parse_cli_span(value)
|
| 1361 |
+
|
| 1362 |
+
|
| 1363 |
+
def main() -> int:
|
| 1364 |
+
args = build_parser().parse_args()
|
| 1365 |
+
model, tokenizer = load_model_and_tokenizer(args.model)
|
| 1366 |
+
tracer = FlashTrace(model, tokenizer)
|
| 1367 |
+
trace = tracer.trace(
|
| 1368 |
+
prompt=args.prompt,
|
| 1369 |
+
target=args.target,
|
| 1370 |
+
output_span=parse_span(args.output_span),
|
| 1371 |
+
reasoning_span=parse_span(args.reasoning_span),
|
| 1372 |
+
)
|
| 1373 |
+
for item in trace.topk_inputs(10):
|
| 1374 |
+
print(f"{item.index}\t{item.score:.6f}\t{item.token!r}")
|
| 1375 |
+
trace.to_html(args.html)
|
| 1376 |
+
print(f"wrote {args.html}")
|
| 1377 |
+
return 0
|
| 1378 |
+
|
| 1379 |
+
|
| 1380 |
+
if __name__ == "__main__":
|
| 1381 |
+
raise SystemExit(main())
|
| 1382 |
+
```
|
| 1383 |
+
|
| 1384 |
+
- [ ] **Step 2: Add README**
|
| 1385 |
+
|
| 1386 |
+
Create `README.md`:
|
| 1387 |
+
|
| 1388 |
+
```markdown
|
| 1389 |
+
# FlashTrace
|
| 1390 |
+
|
| 1391 |
+
FlashTrace is an efficient multi-token attribution toolkit for reasoning language models. It implements the method described in [Towards Long-Horizon Interpretability: Efficient and Faithful Multi-Token Attribution for Reasoning LLMs](https://arxiv.org/abs/2602.01914).
|
| 1392 |
+
|
| 1393 |
+
## Install
|
| 1394 |
+
|
| 1395 |
+
```bash
|
| 1396 |
+
pip install -e .
|
| 1397 |
+
```
|
| 1398 |
+
|
| 1399 |
+
## Python Quickstart
|
| 1400 |
+
|
| 1401 |
+
```python
|
| 1402 |
+
from flashtrace import FlashTrace, load_model_and_tokenizer
|
| 1403 |
+
|
| 1404 |
+
model, tokenizer = load_model_and_tokenizer("Qwen/Qwen3-8B")
|
| 1405 |
+
tracer = FlashTrace(model, tokenizer)
|
| 1406 |
+
|
| 1407 |
+
trace = tracer.trace(
|
| 1408 |
+
prompt="Context: Paris is the capital of France.\nQuestion: What is the capital of France?",
|
| 1409 |
+
target="Paris",
|
| 1410 |
+
output_span=(0, 0),
|
| 1411 |
+
hops=1,
|
| 1412 |
+
)
|
| 1413 |
+
|
| 1414 |
+
print(trace.topk_inputs(10))
|
| 1415 |
+
trace.to_html("trace.html")
|
| 1416 |
+
trace.to_json("trace.json")
|
| 1417 |
+
```
|
| 1418 |
+
|
| 1419 |
+
## CLI Quickstart
|
| 1420 |
+
|
| 1421 |
+
```bash
|
| 1422 |
+
flashtrace trace \
|
| 1423 |
+
--model Qwen/Qwen3-8B \
|
| 1424 |
+
--prompt prompt.txt \
|
| 1425 |
+
--target target.txt \
|
| 1426 |
+
--output-span 0:0 \
|
| 1427 |
+
--hops 1 \
|
| 1428 |
+
--html trace.html \
|
| 1429 |
+
--json trace.json
|
| 1430 |
+
```
|
| 1431 |
+
|
| 1432 |
+
## Token Spans
|
| 1433 |
+
|
| 1434 |
+
`output_span` and `reasoning_span` use inclusive generation-token indices. Inspect `trace.generation_tokens` after an initial run to choose spans for a target answer or reasoning segment.
|
| 1435 |
+
|
| 1436 |
+
## Supported Models
|
| 1437 |
+
|
| 1438 |
+
The package targets Llama/Qwen-style decoder-only Hugging Face causal LMs with standard Q/K/V/O projections, RMSNorm or LayerNorm, and RoPE metadata. Qwen2, Qwen3, and Llama are the first validated model families.
|
| 1439 |
+
|
| 1440 |
+
## Repository Map
|
| 1441 |
+
|
| 1442 |
+
- `flashtrace/`: reusable package
|
| 1443 |
+
- `examples/`: public examples
|
| 1444 |
+
- `tests/`: CPU smoke tests
|
| 1445 |
+
- `exp/`: paper experiments and artifacts
|
| 1446 |
+
|
| 1447 |
+
## Citation
|
| 1448 |
+
|
| 1449 |
+
```bibtex
|
| 1450 |
+
@misc{pan2026flashtrace,
|
| 1451 |
+
title={Towards Long-Horizon Interpretability: Efficient and Faithful Multi-Token Attribution for Reasoning LLMs},
|
| 1452 |
+
author={Pan, Wenbo and Liu, Zhichao and Wang, Xianlong and Yu, Haining and Jia, Xiaohua},
|
| 1453 |
+
year={2026},
|
| 1454 |
+
eprint={2602.01914},
|
| 1455 |
+
archivePrefix={arXiv},
|
| 1456 |
+
primaryClass={cs.LG}
|
| 1457 |
+
}
|
| 1458 |
+
```
|
| 1459 |
+
```
|
| 1460 |
+
|
| 1461 |
+
- [ ] **Step 3: Add MIT license**
|
| 1462 |
+
|
| 1463 |
+
Create `LICENSE`:
|
| 1464 |
+
|
| 1465 |
+
```text
|
| 1466 |
+
MIT License
|
| 1467 |
+
|
| 1468 |
+
Copyright (c) 2026 Wenbo Pan
|
| 1469 |
+
|
| 1470 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 1471 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 1472 |
+
in the Software without restriction, including without limitation the rights
|
| 1473 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 1474 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 1475 |
+
furnished to do so, subject to the following conditions:
|
| 1476 |
+
|
| 1477 |
+
The above copyright notice and this permission notice shall be included in all
|
| 1478 |
+
copies or substantial portions of the Software.
|
| 1479 |
+
|
| 1480 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 1481 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 1482 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 1483 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 1484 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 1485 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 1486 |
+
SOFTWARE.
|
| 1487 |
+
```
|
| 1488 |
+
|
| 1489 |
+
- [ ] **Step 4: Update generated artifact ignore rules**
|
| 1490 |
+
|
| 1491 |
+
Append to `.gitignore`:
|
| 1492 |
+
|
| 1493 |
+
```gitignore
|
| 1494 |
+
|
| 1495 |
+
# FlashTrace generated artifacts
|
| 1496 |
+
trace.json
|
| 1497 |
+
trace.html
|
| 1498 |
+
*.trace.json
|
| 1499 |
+
*.trace.html
|
| 1500 |
+
exp/**/output/
|
| 1501 |
+
exp/**/out/
|
| 1502 |
+
exp/**/out-*/
|
| 1503 |
+
*.npz
|
| 1504 |
+
```
|
| 1505 |
+
|
| 1506 |
+
- [ ] **Step 5: Remove the template artifact**
|
| 1507 |
+
|
| 1508 |
+
Run:
|
| 1509 |
+
|
| 1510 |
+
```bash
|
| 1511 |
+
git rm model_generation.py
|
| 1512 |
+
```
|
| 1513 |
+
|
| 1514 |
+
- [ ] **Step 6: Verify quickstart help**
|
| 1515 |
+
|
| 1516 |
+
Run:
|
| 1517 |
+
|
| 1518 |
+
```bash
|
| 1519 |
+
uv run python examples/quickstart.py --help
|
| 1520 |
+
```
|
| 1521 |
+
|
| 1522 |
+
Expected: help text includes `FlashTrace quickstart example`.
|
| 1523 |
+
|
| 1524 |
+
- [ ] **Step 7: Commit**
|
| 1525 |
+
|
| 1526 |
+
Run:
|
| 1527 |
+
|
| 1528 |
+
```bash
|
| 1529 |
+
git add README.md LICENSE examples/quickstart.py .gitignore
|
| 1530 |
+
git commit -m "docs: add public quickstart and release hygiene"
|
| 1531 |
+
```
|
| 1532 |
+
|
| 1533 |
+
## Task 8: Final Verification And Package Audit
|
| 1534 |
+
|
| 1535 |
+
**Files:**
|
| 1536 |
+
- Modify: any files needed to fix verification failures.
|
| 1537 |
+
|
| 1538 |
+
- [ ] **Step 1: Run the full CPU test suite**
|
| 1539 |
+
|
| 1540 |
+
Run:
|
| 1541 |
+
|
| 1542 |
+
```bash
|
| 1543 |
+
uv run pytest tests -q
|
| 1544 |
+
```
|
| 1545 |
+
|
| 1546 |
+
Expected: all tests pass.
|
| 1547 |
+
|
| 1548 |
+
- [ ] **Step 2: Verify editable install import**
|
| 1549 |
+
|
| 1550 |
+
Run:
|
| 1551 |
+
|
| 1552 |
+
```bash
|
| 1553 |
+
uv run python -c "import flashtrace; print(flashtrace.FlashTrace.__name__)"
|
| 1554 |
+
```
|
| 1555 |
+
|
| 1556 |
+
Expected: prints `FlashTrace`.
|
| 1557 |
+
|
| 1558 |
+
- [ ] **Step 3: Verify CLI help**
|
| 1559 |
+
|
| 1560 |
+
Run:
|
| 1561 |
+
|
| 1562 |
+
```bash
|
| 1563 |
+
uv run flashtrace --help
|
| 1564 |
+
uv run flashtrace trace --help
|
| 1565 |
+
```
|
| 1566 |
+
|
| 1567 |
+
Expected: both commands print help text.
|
| 1568 |
+
|
| 1569 |
+
- [ ] **Step 4: Verify root compatibility imports**
|
| 1570 |
+
|
| 1571 |
+
Run:
|
| 1572 |
+
|
| 1573 |
+
```bash
|
| 1574 |
+
uv run python -c "import ifr_core, llm_attr, ft_ifr_improve; print(ifr_core.compute_multi_hop_ifr.__name__)"
|
| 1575 |
+
```
|
| 1576 |
+
|
| 1577 |
+
Expected: prints `compute_multi_hop_ifr`.
|
| 1578 |
+
|
| 1579 |
+
- [ ] **Step 5: Inspect package file list**
|
| 1580 |
+
|
| 1581 |
+
Run:
|
| 1582 |
+
|
| 1583 |
+
```bash
|
| 1584 |
+
git status --short
|
| 1585 |
+
find flashtrace -maxdepth 3 -type f | sort
|
| 1586 |
+
```
|
| 1587 |
+
|
| 1588 |
+
Expected: package files match the design spec and only intended changes appear.
|
| 1589 |
+
|
| 1590 |
+
- [ ] **Step 6: Commit final fixes**
|
| 1591 |
+
|
| 1592 |
+
Run after any verification fixes:
|
| 1593 |
+
|
| 1594 |
+
```bash
|
| 1595 |
+
git add .
|
| 1596 |
+
git commit -m "test: verify public package smoke tests"
|
| 1597 |
+
```
|
| 1598 |
+
|
| 1599 |
+
If verification passes with a clean tree after prior commits, record the passing commands in the final implementation response.
|
| 1600 |
+
|
| 1601 |
+
## Self-Review Checklist
|
| 1602 |
+
|
| 1603 |
+
- Spec coverage: package layout, public API, result export, CLI, visualization, packaging, compatibility, tests, README, and release hygiene each have at least one task.
|
| 1604 |
+
- Type consistency: `FlashTrace.trace`, `TraceResult`, `TokenScore`, `load_model_and_tokenizer`, and CLI span parsing use the same names across tests and implementation steps.
|
| 1605 |
+
- Test path: every implementation task starts with a failing test or a verification command, then ends with a passing command and commit.
|
docs/superpowers/specs/2026-05-03-flashtrace-public-package-design.md
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FlashTrace Public Package Design
|
| 2 |
+
|
| 3 |
+
## Goal
|
| 4 |
+
|
| 5 |
+
Turn the current FlashTrace research repository into an installable, documented Python package that researchers can use from Python or the command line to trace LLM outputs, export JSON traces, and render HTML token heatmaps.
|
| 6 |
+
|
| 7 |
+
## Release Scope
|
| 8 |
+
|
| 9 |
+
This first public release ships four user-facing capabilities:
|
| 10 |
+
|
| 11 |
+
- A stable Python API centered on `FlashTrace`.
|
| 12 |
+
- A `flashtrace trace` CLI for prompt/target files and Hugging Face model ids or local model paths.
|
| 13 |
+
- A `TraceResult` object with top-k, JSON, and HTML export helpers.
|
| 14 |
+
- A README quickstart that demonstrates Python, CLI, and heatmap workflows.
|
| 15 |
+
|
| 16 |
+
Paper experiment runners and saved experiment artifacts remain in `exp/` as research assets. Their full reproducibility cleanup belongs to a later phase.
|
| 17 |
+
|
| 18 |
+
## Repository Shape
|
| 19 |
+
|
| 20 |
+
The reusable package lives under `flashtrace/`, examples under `examples/`, tests under `tests/`, and paper experiments under `exp/`.
|
| 21 |
+
|
| 22 |
+
```text
|
| 23 |
+
flashtrace/
|
| 24 |
+
__init__.py
|
| 25 |
+
tracer.py
|
| 26 |
+
result.py
|
| 27 |
+
core.py
|
| 28 |
+
model_io.py
|
| 29 |
+
viz.py
|
| 30 |
+
cli.py
|
| 31 |
+
baselines/
|
| 32 |
+
__init__.py
|
| 33 |
+
attnlrp.py
|
| 34 |
+
examples/
|
| 35 |
+
quickstart.py
|
| 36 |
+
tests/
|
| 37 |
+
test_core_recompute.py
|
| 38 |
+
test_tracer.py
|
| 39 |
+
test_result.py
|
| 40 |
+
test_cli.py
|
| 41 |
+
exp/
|
| 42 |
+
exp1/
|
| 43 |
+
exp2/
|
| 44 |
+
case_study/
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
Existing root modules are migrated gradually. During migration, compatibility wrappers remain at the root for experiment scripts that still import `llm_attr`, `ifr_core`, or `ft_ifr_improve`.
|
| 48 |
+
|
| 49 |
+
## Core Implementation Mapping
|
| 50 |
+
|
| 51 |
+
`flashtrace.core` contains the IFR tensor implementation from `ifr_core.py`:
|
| 52 |
+
|
| 53 |
+
- `extract_model_metadata`
|
| 54 |
+
- `build_weight_pack`
|
| 55 |
+
- `attach_hooks`
|
| 56 |
+
- `recompute_layer_attention`
|
| 57 |
+
- `compute_ifr_sentence_aggregate`
|
| 58 |
+
- `compute_multi_hop_ifr`
|
| 59 |
+
- `compute_ifr_for_all_positions`
|
| 60 |
+
|
| 61 |
+
`flashtrace.tracer` wraps the current high-level attribution classes:
|
| 62 |
+
|
| 63 |
+
- Default `method="flashtrace"` uses the current `LLMIFRAttributionBoth.calculate_ifr_multi_hop_both` behavior.
|
| 64 |
+
- `method="ifr-span"` uses `LLMIFRAttribution.calculate_ifr_span`.
|
| 65 |
+
- `method="ifr-matrix"` uses `LLMIFRAttribution.calculate_ifr_for_all_positions_output_only`.
|
| 66 |
+
|
| 67 |
+
`flashtrace.baselines.attnlrp` contains the AttnLRP patching and recursive baseline code from `lrp_rules.py`, `lrp_patches.py`, and `LLMLRPAttribution`.
|
| 68 |
+
|
| 69 |
+
## Public Python API
|
| 70 |
+
|
| 71 |
+
The package exports `FlashTrace`, `TraceResult`, and `load_model_and_tokenizer`.
|
| 72 |
+
|
| 73 |
+
```python
|
| 74 |
+
from flashtrace import FlashTrace, load_model_and_tokenizer
|
| 75 |
+
|
| 76 |
+
model, tokenizer = load_model_and_tokenizer("Qwen/Qwen3-8B", device_map="auto")
|
| 77 |
+
tracer = FlashTrace(model, tokenizer, chunk_tokens=128, sink_chunk_tokens=32)
|
| 78 |
+
|
| 79 |
+
trace = tracer.trace(
|
| 80 |
+
prompt=prompt,
|
| 81 |
+
target=target,
|
| 82 |
+
output_span=(80, 85),
|
| 83 |
+
reasoning_span=(0, 79),
|
| 84 |
+
hops=1,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
print(trace.topk_inputs(20))
|
| 88 |
+
trace.to_json("trace.json")
|
| 89 |
+
trace.to_html("trace.html")
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
`FlashTrace.trace(...)` accepts:
|
| 93 |
+
|
| 94 |
+
- `prompt: str`
|
| 95 |
+
- `target: str | None`
|
| 96 |
+
- `output_span: tuple[int, int] | None`
|
| 97 |
+
- `reasoning_span: tuple[int, int] | None`
|
| 98 |
+
- `hops: int`
|
| 99 |
+
- `method: Literal["flashtrace", "ifr-span", "ifr-matrix"]`
|
| 100 |
+
- `renorm_threshold: float | None`
|
| 101 |
+
|
| 102 |
+
Generation-token spans are inclusive and use the tokenizer alignment already produced by the attribution path. The README explains this convention and shows how to inspect `trace.generation_tokens`.
|
| 103 |
+
|
| 104 |
+
## TraceResult
|
| 105 |
+
|
| 106 |
+
`TraceResult` is a small dataclass that hides the older `LLMAttributionResult` shape from public users.
|
| 107 |
+
|
| 108 |
+
Fields:
|
| 109 |
+
|
| 110 |
+
- `prompt_tokens: list[str]`
|
| 111 |
+
- `generation_tokens: list[str]`
|
| 112 |
+
- `scores: list[float]`
|
| 113 |
+
- `per_hop_scores: list[list[float]]`
|
| 114 |
+
- `thinking_ratios: list[float]`
|
| 115 |
+
- `output_span: tuple[int, int] | None`
|
| 116 |
+
- `reasoning_span: tuple[int, int] | None`
|
| 117 |
+
- `method: str`
|
| 118 |
+
- `metadata: dict[str, Any]`
|
| 119 |
+
|
| 120 |
+
Methods:
|
| 121 |
+
|
| 122 |
+
- `topk_inputs(k: int = 20) -> list[TokenScore]`
|
| 123 |
+
- `to_dict() -> dict[str, Any]`
|
| 124 |
+
- `to_json(path: str | Path) -> None`
|
| 125 |
+
- `to_html(path: str | Path) -> None`
|
| 126 |
+
|
| 127 |
+
`TokenScore` contains `index`, `token`, and `score`. Scores are aligned to `prompt_tokens`.
|
| 128 |
+
|
| 129 |
+
## CLI
|
| 130 |
+
|
| 131 |
+
The package exposes one console script:
|
| 132 |
+
|
| 133 |
+
```bash
|
| 134 |
+
flashtrace trace \
|
| 135 |
+
--model Qwen/Qwen3-8B \
|
| 136 |
+
--prompt prompt.txt \
|
| 137 |
+
--target target.txt \
|
| 138 |
+
--output-span 80:85 \
|
| 139 |
+
--reasoning-span 0:79 \
|
| 140 |
+
--hops 1 \
|
| 141 |
+
--html trace.html \
|
| 142 |
+
--json trace.json
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
CLI behavior:
|
| 146 |
+
|
| 147 |
+
- `--model` accepts a Hugging Face id or local path.
|
| 148 |
+
- `--prompt` and `--target` read UTF-8 text files.
|
| 149 |
+
- `--target` is optional; the model generates with deterministic defaults when this flag is absent.
|
| 150 |
+
- `--output-span` and `--reasoning-span` parse inclusive `START:END` generation-token spans.
|
| 151 |
+
- `--method` defaults to `flashtrace`.
|
| 152 |
+
- `--recompute-attention` enables lower-memory attention recomputation.
|
| 153 |
+
- `--device-map` defaults to `auto`.
|
| 154 |
+
- `--dtype` accepts `auto`, `float16`, `bfloat16`, or `float32`.
|
| 155 |
+
|
| 156 |
+
The command prints a compact top-k table to stdout and writes requested artifacts.
|
| 157 |
+
|
| 158 |
+
## Visualization
|
| 159 |
+
|
| 160 |
+
`flashtrace.viz` adapts the token heatmap renderer from `exp/case_study/viz.py`.
|
| 161 |
+
|
| 162 |
+
The public heatmap focuses on:
|
| 163 |
+
|
| 164 |
+
- prompt tokens colored by final attribution score,
|
| 165 |
+
- optional per-hop panels,
|
| 166 |
+
- output and reasoning span summary,
|
| 167 |
+
- model/method metadata.
|
| 168 |
+
|
| 169 |
+
The renderer returns a standalone HTML string and writes standalone HTML files through `TraceResult.to_html`.
|
| 170 |
+
|
| 171 |
+
## Packaging
|
| 172 |
+
|
| 173 |
+
`pyproject.toml` becomes package metadata for `flashtrace`:
|
| 174 |
+
|
| 175 |
+
- `name = "flashtrace"`
|
| 176 |
+
- realistic `requires-python` support for current PyTorch and Transformers use,
|
| 177 |
+
- console script `flashtrace = "flashtrace.cli:main"`,
|
| 178 |
+
- core dependencies: `torch`, `transformers`, `accelerate`, `numpy`, `tqdm`,
|
| 179 |
+
- optional extras: `viz`, `eval`, `dev`, `baselines`.
|
| 180 |
+
|
| 181 |
+
The root README includes:
|
| 182 |
+
|
| 183 |
+
- project tagline,
|
| 184 |
+
- paper link and citation,
|
| 185 |
+
- install instructions,
|
| 186 |
+
- Python quickstart,
|
| 187 |
+
- CLI quickstart,
|
| 188 |
+
- supported model family notes,
|
| 189 |
+
- output interpretation,
|
| 190 |
+
- experiment directory map,
|
| 191 |
+
- troubleshooting for GPU memory and tokenizer spans.
|
| 192 |
+
|
| 193 |
+
## Compatibility
|
| 194 |
+
|
| 195 |
+
The release supports Llama/Qwen-style decoder-only Hugging Face causal LMs with `model.layers`, Q/K/V/O projections, RMSNorm/LayerNorm, and RoPE metadata. The README names Qwen2, Qwen3, and Llama as validated families.
|
| 196 |
+
|
| 197 |
+
Existing experiment scripts continue to run through temporary root-level compatibility modules while package imports are introduced. A later cleanup can remove the compatibility layer after `exp/` imports are migrated.
|
| 198 |
+
|
| 199 |
+
## Testing
|
| 200 |
+
|
| 201 |
+
Tests use a tiny randomly initialized Qwen2 model on CPU, following the existing `test_recompute.py` approach.
|
| 202 |
+
|
| 203 |
+
Required coverage:
|
| 204 |
+
|
| 205 |
+
- stored-attention and recomputed-attention paths return close values on the tiny model,
|
| 206 |
+
- `FlashTrace.trace(...)` returns a `TraceResult`,
|
| 207 |
+
- `TraceResult.topk_inputs(...)` sorts and truncates correctly,
|
| 208 |
+
- `TraceResult.to_dict()` is JSON serializable,
|
| 209 |
+
- `TraceResult.to_html()` writes standalone HTML containing token spans,
|
| 210 |
+
- `flashtrace trace --help` exits successfully.
|
| 211 |
+
|
| 212 |
+
Heavy GPU model tests remain manual examples.
|
| 213 |
+
|
| 214 |
+
## Release Hygiene
|
| 215 |
+
|
| 216 |
+
The release cleanup updates `.gitignore` to cover generated traces, experiment outputs, checkpoints, caches, and HTML/JSON artifacts created by examples.
|
| 217 |
+
|
| 218 |
+
Tracked historical experiment outputs stay untouched during the first package migration. A later artifact cleanup can move them to release assets or remove them with a dedicated confirmation step.
|
| 219 |
+
|
| 220 |
+
`model_generation.py` is a template artifact and is removed or moved outside the package path during implementation.
|
| 221 |
+
|
| 222 |
+
## Success Criteria
|
| 223 |
+
|
| 224 |
+
The release work is complete when:
|
| 225 |
+
|
| 226 |
+
- `pip install -e .` exposes `flashtrace`,
|
| 227 |
+
- `python examples/quickstart.py --help` works,
|
| 228 |
+
- `flashtrace trace --help` works,
|
| 229 |
+
- package smoke tests pass on CPU,
|
| 230 |
+
- README quickstart matches the implemented API,
|
| 231 |
+
- existing experiment entrypoints either run with compatibility imports or document their package-era invocation.
|
dump_exp2_hop_vh.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""One-off: add per-hop IFR vectors (vh) into an existing exp2 trace .npz.
|
| 3 |
+
|
| 4 |
+
This is useful when the original exp2 run saved sample-level traces but did not
|
| 5 |
+
include per-hop vectors for some multi-hop IFR variants (e.g. ifr_multi_hop_both).
|
| 6 |
+
|
| 7 |
+
Defaults are written to match the reference commands in `exp/exp2/README.md`.
|
| 8 |
+
|
| 9 |
+
Example (matches the path in the question):
|
| 10 |
+
|
| 11 |
+
python dump_exp2_hop_vh.py \
|
| 12 |
+
--trace_npz exp/exp2/output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/ifr_multi_hop_both_n1_mfaithfulness_gen_95ex/ex_000026.npz \
|
| 13 |
+
--dataset exp/exp2/data/morehopqa.jsonl \
|
| 14 |
+
--attr_func ifr_multi_hop_both \
|
| 15 |
+
--model qwen-8B \
|
| 16 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 17 |
+
--cuda 2,3,4,5,6,7 \
|
| 18 |
+
--n_hops 1 \
|
| 19 |
+
--chunk_tokens 128 \
|
| 20 |
+
--sink_chunk_tokens 32 \
|
| 21 |
+
--inplace
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import hashlib
|
| 28 |
+
import json
|
| 29 |
+
import os
|
| 30 |
+
import re
|
| 31 |
+
import sys
|
| 32 |
+
from dataclasses import dataclass
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
from typing import Any, Optional
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _early_set_cuda_visible_devices() -> None:
|
| 38 |
+
parser = argparse.ArgumentParser(add_help=False)
|
| 39 |
+
parser.add_argument("--cuda", type=str, default=None)
|
| 40 |
+
args, _ = parser.parse_known_args(sys.argv[1:])
|
| 41 |
+
if args.cuda and "," in str(args.cuda):
|
| 42 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
_early_set_cuda_visible_devices()
|
| 46 |
+
|
| 47 |
+
import numpy as np
|
| 48 |
+
import torch
|
| 49 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 50 |
+
|
| 51 |
+
import ft_ifr_improve
|
| 52 |
+
import llm_attr
|
| 53 |
+
from exp.exp2 import dataset_utils as ds_utils
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _sha1_text(text: str) -> str:
|
| 57 |
+
return hashlib.sha1(text.encode("utf-8")).hexdigest()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _resolve_device(cuda: Optional[str], cuda_num: int) -> str:
|
| 61 |
+
"""Mirror exp/exp2/run_exp.py device selection policy."""
|
| 62 |
+
if cuda is not None and "," in cuda:
|
| 63 |
+
# _early_set_cuda_visible_devices already applied.
|
| 64 |
+
return "auto"
|
| 65 |
+
if cuda is not None and str(cuda).strip():
|
| 66 |
+
return f"cuda:{cuda}" if torch.cuda.is_available() else "cpu"
|
| 67 |
+
return f"cuda:{int(cuda_num)}" if torch.cuda.is_available() else "cpu"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _load_model(model_name: str, device: str):
|
| 71 |
+
"""Mirror exp/exp2/run_exp.py model loading knobs."""
|
| 72 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 73 |
+
model_name,
|
| 74 |
+
device_map="auto" if device == "auto" else {"": int(device.split(":")[1])} if device.startswith("cuda:") else None,
|
| 75 |
+
torch_dtype=torch.float16,
|
| 76 |
+
attn_implementation="eager",
|
| 77 |
+
)
|
| 78 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 79 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 80 |
+
model.eval()
|
| 81 |
+
return model, tokenizer
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@dataclass(frozen=True)
|
| 85 |
+
class ManifestRecord:
|
| 86 |
+
example_idx: int
|
| 87 |
+
prompt_sha1: str
|
| 88 |
+
target_sha1: Optional[str]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _load_manifest_record(manifest_path: Path, *, example_idx: int) -> Optional[ManifestRecord]:
|
| 92 |
+
if not manifest_path.exists():
|
| 93 |
+
return None
|
| 94 |
+
with manifest_path.open("r", encoding="utf-8") as f:
|
| 95 |
+
for line in f:
|
| 96 |
+
if not line.strip():
|
| 97 |
+
continue
|
| 98 |
+
obj = json.loads(line)
|
| 99 |
+
if int(obj.get("example_idx", -1)) != int(example_idx):
|
| 100 |
+
continue
|
| 101 |
+
return ManifestRecord(
|
| 102 |
+
example_idx=int(example_idx),
|
| 103 |
+
prompt_sha1=str(obj.get("prompt_sha1") or ""),
|
| 104 |
+
target_sha1=str(obj["target_sha1"]) if obj.get("target_sha1") is not None else None,
|
| 105 |
+
)
|
| 106 |
+
return None
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _parse_example_idx_from_npz_name(path: Path) -> Optional[int]:
|
| 110 |
+
m = re.match(r"^ex_(\d+)$", path.stem)
|
| 111 |
+
if not m:
|
| 112 |
+
return None
|
| 113 |
+
try:
|
| 114 |
+
return int(m.group(1))
|
| 115 |
+
except Exception:
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _pick_example(
|
| 120 |
+
examples: list[ds_utils.CachedExample],
|
| 121 |
+
*,
|
| 122 |
+
example_idx: int,
|
| 123 |
+
record: Optional[ManifestRecord],
|
| 124 |
+
) -> ds_utils.CachedExample:
|
| 125 |
+
if record is not None and record.prompt_sha1:
|
| 126 |
+
matches: list[ds_utils.CachedExample] = []
|
| 127 |
+
for ex in examples:
|
| 128 |
+
if _sha1_text(ex.prompt) != record.prompt_sha1:
|
| 129 |
+
continue
|
| 130 |
+
if record.target_sha1 is None:
|
| 131 |
+
if ex.target is None:
|
| 132 |
+
matches.append(ex)
|
| 133 |
+
else:
|
| 134 |
+
if ex.target is not None and _sha1_text(ex.target) == record.target_sha1:
|
| 135 |
+
matches.append(ex)
|
| 136 |
+
if len(matches) == 1:
|
| 137 |
+
return matches[0]
|
| 138 |
+
if len(matches) > 1:
|
| 139 |
+
raise SystemExit(
|
| 140 |
+
f"Manifest sha1 matched multiple dataset entries ({len(matches)}). "
|
| 141 |
+
"Please pass --example_idx to select by index or use a smaller dataset cache."
|
| 142 |
+
)
|
| 143 |
+
raise SystemExit(
|
| 144 |
+
"Failed to locate the trace example in the provided dataset by sha1. "
|
| 145 |
+
"Ensure --dataset points to the same cached JSONL used to produce the trace."
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
if not (0 <= int(example_idx) < len(examples)):
|
| 149 |
+
raise SystemExit(f"example_idx out of range: {example_idx} not in [0, {len(examples)}).")
|
| 150 |
+
return examples[int(example_idx)]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _extract_vh(attr: Any) -> np.ndarray:
|
| 154 |
+
ifr = (getattr(attr, "metadata", None) or {}).get("ifr") or {}
|
| 155 |
+
per_hop = ifr.get("per_hop_projected") or []
|
| 156 |
+
if not per_hop:
|
| 157 |
+
raise RuntimeError("Attribution result missing metadata['ifr']['per_hop_projected']; cannot build vh.")
|
| 158 |
+
stacked = torch.stack([torch.as_tensor(v, dtype=torch.float32).reshape(-1) for v in per_hop], dim=0)
|
| 159 |
+
return stacked.detach().cpu().numpy().astype(np.float32, copy=False)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _run_ifr_attr(
|
| 163 |
+
attr_func: str,
|
| 164 |
+
*,
|
| 165 |
+
model: Any,
|
| 166 |
+
tokenizer: Any,
|
| 167 |
+
prompt: str,
|
| 168 |
+
target: str,
|
| 169 |
+
sink_span: Optional[tuple[int, int]],
|
| 170 |
+
thinking_span: Optional[tuple[int, int]],
|
| 171 |
+
n_hops: int,
|
| 172 |
+
chunk_tokens: int,
|
| 173 |
+
sink_chunk_tokens: int,
|
| 174 |
+
) -> Any:
|
| 175 |
+
if attr_func == "ifr_multi_hop":
|
| 176 |
+
attributor = llm_attr.LLMIFRAttribution(
|
| 177 |
+
model,
|
| 178 |
+
tokenizer,
|
| 179 |
+
chunk_tokens=chunk_tokens,
|
| 180 |
+
sink_chunk_tokens=sink_chunk_tokens,
|
| 181 |
+
)
|
| 182 |
+
return attributor.calculate_ifr_multi_hop(
|
| 183 |
+
prompt,
|
| 184 |
+
target=target,
|
| 185 |
+
sink_span=sink_span,
|
| 186 |
+
thinking_span=thinking_span,
|
| 187 |
+
n_hops=int(n_hops),
|
| 188 |
+
)
|
| 189 |
+
if attr_func == "ifr_in_all_gen":
|
| 190 |
+
attributor = ft_ifr_improve.LLMIFRAttributionInAllGen(
|
| 191 |
+
model,
|
| 192 |
+
tokenizer,
|
| 193 |
+
chunk_tokens=chunk_tokens,
|
| 194 |
+
sink_chunk_tokens=sink_chunk_tokens,
|
| 195 |
+
)
|
| 196 |
+
return attributor.calculate_ifr_in_all_gen(
|
| 197 |
+
prompt,
|
| 198 |
+
target=target,
|
| 199 |
+
sink_span=sink_span,
|
| 200 |
+
thinking_span=thinking_span,
|
| 201 |
+
n_hops=int(n_hops),
|
| 202 |
+
)
|
| 203 |
+
if attr_func == "ifr_multi_hop_stop_words":
|
| 204 |
+
attributor = ft_ifr_improve.LLMIFRAttributionImproved(
|
| 205 |
+
model,
|
| 206 |
+
tokenizer,
|
| 207 |
+
chunk_tokens=chunk_tokens,
|
| 208 |
+
sink_chunk_tokens=sink_chunk_tokens,
|
| 209 |
+
)
|
| 210 |
+
return attributor.calculate_ifr_multi_hop_stop_words(
|
| 211 |
+
prompt,
|
| 212 |
+
target=target,
|
| 213 |
+
sink_span=sink_span,
|
| 214 |
+
thinking_span=thinking_span,
|
| 215 |
+
n_hops=int(n_hops),
|
| 216 |
+
)
|
| 217 |
+
if attr_func == "ifr_multi_hop_both":
|
| 218 |
+
attributor = ft_ifr_improve.LLMIFRAttributionBoth(
|
| 219 |
+
model,
|
| 220 |
+
tokenizer,
|
| 221 |
+
chunk_tokens=chunk_tokens,
|
| 222 |
+
sink_chunk_tokens=sink_chunk_tokens,
|
| 223 |
+
)
|
| 224 |
+
return attributor.calculate_ifr_multi_hop_both(
|
| 225 |
+
prompt,
|
| 226 |
+
target=target,
|
| 227 |
+
sink_span=sink_span,
|
| 228 |
+
thinking_span=thinking_span,
|
| 229 |
+
n_hops=int(n_hops),
|
| 230 |
+
)
|
| 231 |
+
if attr_func == "ifr_multi_hop_split_hop":
|
| 232 |
+
attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(
|
| 233 |
+
model,
|
| 234 |
+
tokenizer,
|
| 235 |
+
chunk_tokens=chunk_tokens,
|
| 236 |
+
sink_chunk_tokens=sink_chunk_tokens,
|
| 237 |
+
)
|
| 238 |
+
return attributor.calculate_ifr_multi_hop_split_hop(
|
| 239 |
+
prompt,
|
| 240 |
+
target=target,
|
| 241 |
+
sink_span=sink_span,
|
| 242 |
+
thinking_span=thinking_span,
|
| 243 |
+
n_hops=int(n_hops),
|
| 244 |
+
)
|
| 245 |
+
raise SystemExit(
|
| 246 |
+
f"Unsupported --attr_func '{attr_func}'. "
|
| 247 |
+
"Supported (vh-capable IFR variants): "
|
| 248 |
+
"ifr_multi_hop, ifr_in_all_gen, ifr_multi_hop_stop_words, ifr_multi_hop_both, ifr_multi_hop_split_hop."
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def _save_npz(
|
| 253 |
+
out_path: Path,
|
| 254 |
+
*,
|
| 255 |
+
payload: dict[str, np.ndarray],
|
| 256 |
+
inplace_src: Optional[Path] = None,
|
| 257 |
+
backup: bool = True,
|
| 258 |
+
overwrite_backup: bool = False,
|
| 259 |
+
) -> None:
|
| 260 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 261 |
+
if inplace_src is not None:
|
| 262 |
+
if backup and inplace_src.exists():
|
| 263 |
+
backup_path = inplace_src.with_name(inplace_src.name + ".bak")
|
| 264 |
+
if overwrite_backup and backup_path.exists():
|
| 265 |
+
backup_path.unlink()
|
| 266 |
+
if not backup_path.exists():
|
| 267 |
+
backup_path.write_bytes(inplace_src.read_bytes())
|
| 268 |
+
|
| 269 |
+
# NOTE: numpy.savez* appends ".npz" if the filename does not already end with ".npz".
|
| 270 |
+
# So we must ensure our temporary path ends with ".npz", otherwise we'd write
|
| 271 |
+
# "<name>.tmp.npz" but later try to os.replace("<name>.tmp", ...).
|
| 272 |
+
tmp_path = out_path.with_name(out_path.stem + ".tmp.npz")
|
| 273 |
+
if tmp_path.exists():
|
| 274 |
+
tmp_path.unlink()
|
| 275 |
+
np.savez_compressed(tmp_path, **payload)
|
| 276 |
+
os.replace(tmp_path, out_path)
|
| 277 |
+
return
|
| 278 |
+
|
| 279 |
+
if out_path.exists():
|
| 280 |
+
raise SystemExit(f"Refusing to overwrite existing file: {out_path} (use --inplace).")
|
| 281 |
+
np.savez_compressed(out_path, **payload)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def main() -> None:
|
| 285 |
+
parser = argparse.ArgumentParser("One-off exp2 trace patcher: add per-hop vh vectors.")
|
| 286 |
+
parser.add_argument(
|
| 287 |
+
"--trace_npz",
|
| 288 |
+
type=str,
|
| 289 |
+
default=(
|
| 290 |
+
"exp/exp2/output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/"
|
| 291 |
+
"ifr_multi_hop_both_n1_mfaithfulness_gen_95ex/ex_000026.npz"
|
| 292 |
+
),
|
| 293 |
+
help="Path to the existing exp2 trace npz (ex_*.npz).",
|
| 294 |
+
)
|
| 295 |
+
parser.add_argument(
|
| 296 |
+
"--dataset",
|
| 297 |
+
type=str,
|
| 298 |
+
default="exp/exp2/data/morehopqa.jsonl",
|
| 299 |
+
help="Path to the exp2 cached dataset JSONL used to produce the trace.",
|
| 300 |
+
)
|
| 301 |
+
parser.add_argument(
|
| 302 |
+
"--attr_func",
|
| 303 |
+
type=str,
|
| 304 |
+
default="ifr_multi_hop_both",
|
| 305 |
+
help="Attribution method to rerun (vh-capable IFR variants only).",
|
| 306 |
+
)
|
| 307 |
+
parser.add_argument("--example_idx", type=int, default=None, help="Override example_idx (0-based).")
|
| 308 |
+
parser.add_argument("--sample", type=int, default=None, help="If the original run used --sample, set it here.")
|
| 309 |
+
parser.add_argument("--seed", type=int, default=42, help="Seed for --sample shuffling (must match original).")
|
| 310 |
+
|
| 311 |
+
parser.add_argument("--model", type=str, default="qwen-8B", help="HF repo id (used when --model_path not set).")
|
| 312 |
+
parser.add_argument(
|
| 313 |
+
"--model_path",
|
| 314 |
+
type=str,
|
| 315 |
+
default="/opt/share/models/Qwen/Qwen3-8B/",
|
| 316 |
+
help="Local model path; overrides --model for loading (matches exp2 README examples).",
|
| 317 |
+
)
|
| 318 |
+
parser.add_argument(
|
| 319 |
+
"--cuda",
|
| 320 |
+
type=str,
|
| 321 |
+
default="2,3,4,5,6,7",
|
| 322 |
+
help="CUDA selection (same semantics as exp2): '0' or '0,1,2'.",
|
| 323 |
+
)
|
| 324 |
+
parser.add_argument("--cuda_num", type=int, default=0, help="Single-device index when --cuda not set.")
|
| 325 |
+
|
| 326 |
+
parser.add_argument("--chunk_tokens", type=int, default=128)
|
| 327 |
+
parser.add_argument("--sink_chunk_tokens", type=int, default=32)
|
| 328 |
+
parser.add_argument("--n_hops", type=int, default=1)
|
| 329 |
+
|
| 330 |
+
parser.add_argument(
|
| 331 |
+
"--inplace",
|
| 332 |
+
action="store_true",
|
| 333 |
+
help="Overwrite the trace npz in place (recommended so manifest.jsonl stays valid).",
|
| 334 |
+
)
|
| 335 |
+
parser.add_argument("--no_backup", action="store_true", help="Disable .bak creation when using --inplace.")
|
| 336 |
+
parser.add_argument(
|
| 337 |
+
"--overwrite_backup",
|
| 338 |
+
action="store_true",
|
| 339 |
+
help="Allow replacing an existing .bak when using --inplace.",
|
| 340 |
+
)
|
| 341 |
+
args = parser.parse_args()
|
| 342 |
+
|
| 343 |
+
trace_npz = Path(args.trace_npz)
|
| 344 |
+
if not trace_npz.exists():
|
| 345 |
+
raise SystemExit(f"Missing trace npz: {trace_npz}")
|
| 346 |
+
|
| 347 |
+
example_idx = args.example_idx
|
| 348 |
+
if example_idx is None:
|
| 349 |
+
example_idx = _parse_example_idx_from_npz_name(trace_npz)
|
| 350 |
+
if example_idx is None:
|
| 351 |
+
raise SystemExit("Failed to infer --example_idx from trace filename; please pass --example_idx explicitly.")
|
| 352 |
+
|
| 353 |
+
manifest_path = trace_npz.with_name("manifest.jsonl")
|
| 354 |
+
record = _load_manifest_record(manifest_path, example_idx=int(example_idx))
|
| 355 |
+
|
| 356 |
+
dataset_path = Path(args.dataset)
|
| 357 |
+
if not dataset_path.exists():
|
| 358 |
+
raise SystemExit(f"Missing cached dataset JSONL: {dataset_path}")
|
| 359 |
+
examples = ds_utils.load_cached(dataset_path, sample=args.sample, seed=args.seed)
|
| 360 |
+
ex = _pick_example(examples, example_idx=int(example_idx), record=record)
|
| 361 |
+
|
| 362 |
+
if ex.target is None:
|
| 363 |
+
raise SystemExit("Cached dataset example has target=None; this script requires cached targets (CoT+answer).")
|
| 364 |
+
prompt = ex.prompt
|
| 365 |
+
target = ex.target
|
| 366 |
+
|
| 367 |
+
sink_span = tuple(ex.sink_span) if ex.sink_span else None
|
| 368 |
+
thinking_span = tuple(ex.thinking_span) if ex.thinking_span else None
|
| 369 |
+
|
| 370 |
+
model_name = str(args.model_path or args.model).strip()
|
| 371 |
+
if not model_name:
|
| 372 |
+
raise SystemExit("Please set --model or --model_path.")
|
| 373 |
+
device = _resolve_device(args.cuda, args.cuda_num)
|
| 374 |
+
model, tokenizer = _load_model(model_name, device)
|
| 375 |
+
|
| 376 |
+
attr = _run_ifr_attr(
|
| 377 |
+
str(args.attr_func),
|
| 378 |
+
model=model,
|
| 379 |
+
tokenizer=tokenizer,
|
| 380 |
+
prompt=prompt,
|
| 381 |
+
target=target,
|
| 382 |
+
sink_span=sink_span,
|
| 383 |
+
thinking_span=thinking_span,
|
| 384 |
+
n_hops=int(args.n_hops),
|
| 385 |
+
chunk_tokens=int(args.chunk_tokens),
|
| 386 |
+
sink_chunk_tokens=int(args.sink_chunk_tokens),
|
| 387 |
+
)
|
| 388 |
+
vh = _extract_vh(attr)
|
| 389 |
+
|
| 390 |
+
with np.load(trace_npz, allow_pickle=False) as old:
|
| 391 |
+
payload = {k: old[k] for k in old.files}
|
| 392 |
+
payload["vh"] = vh
|
| 393 |
+
|
| 394 |
+
if args.inplace:
|
| 395 |
+
out_path = trace_npz
|
| 396 |
+
else:
|
| 397 |
+
out_path = trace_npz.with_name(trace_npz.stem + "_with_vh.npz")
|
| 398 |
+
|
| 399 |
+
_save_npz(
|
| 400 |
+
out_path,
|
| 401 |
+
payload=payload,
|
| 402 |
+
inplace_src=trace_npz if args.inplace else None,
|
| 403 |
+
backup=not bool(args.no_backup),
|
| 404 |
+
overwrite_backup=bool(args.overwrite_backup),
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
print(f"Saved vh -> {out_path}")
|
| 408 |
+
print(f"vh shape: {vh.shape} (n_hops+1, prompt_len+gen_len)")
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
if __name__ == "__main__":
|
| 412 |
+
main()
|
evaluations/attribution_recovery.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
# Ensure project root is importable regardless of CWD
|
| 5 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import csv
|
| 9 |
+
import json
|
| 10 |
+
import math
|
| 11 |
+
import random
|
| 12 |
+
import time
|
| 13 |
+
from itertools import islice
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, utils
|
| 21 |
+
|
| 22 |
+
import llm_attr
|
| 23 |
+
import llm_attr_eval
|
| 24 |
+
from exp.exp2 import dataset_utils as ds_utils
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
utils.logging.set_verbosity_error()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _first_json_obj(path: Path) -> dict:
|
| 31 |
+
with path.open("r", encoding="utf-8") as f:
|
| 32 |
+
for line in f:
|
| 33 |
+
line = line.strip()
|
| 34 |
+
if line:
|
| 35 |
+
return json.loads(line)
|
| 36 |
+
return {}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _load_ruler_examples(args) -> Tuple[str, List[ds_utils.CachedExample]]:
|
| 40 |
+
ds_arg = args.dataset
|
| 41 |
+
cache_dir = Path(args.data_root)
|
| 42 |
+
|
| 43 |
+
# 1) If dataset points to an existing file, detect cache vs raw RULER.
|
| 44 |
+
p = Path(ds_arg)
|
| 45 |
+
if p.exists():
|
| 46 |
+
obj = _first_json_obj(p)
|
| 47 |
+
if "prompt" in obj:
|
| 48 |
+
return p.stem, ds_utils.load_cached(p, sample=args.sample, seed=args.seed)
|
| 49 |
+
if "input" in obj and "needle_spans" in obj:
|
| 50 |
+
return p.stem, ds_utils.load_ruler(p, sample=args.sample, seed=args.seed)
|
| 51 |
+
raise SystemExit(
|
| 52 |
+
f"Unsupported JSONL schema for recovery_ruler: {p}. "
|
| 53 |
+
"Expected either exp2 cache (has 'prompt') or raw RULER JSONL (has 'input'+'needle_spans')."
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# 2) Prefer exp2 cache under --data_root by dataset name.
|
| 57 |
+
cached = cache_dir / f"{ds_arg}.jsonl"
|
| 58 |
+
if cached.exists():
|
| 59 |
+
return ds_arg, ds_utils.load_cached(cached, sample=args.sample, seed=args.seed)
|
| 60 |
+
|
| 61 |
+
# 3) Fall back to raw RULER resolution by name.
|
| 62 |
+
resolved = ds_utils.dataset_from_name(ds_arg)
|
| 63 |
+
if resolved is None:
|
| 64 |
+
raise SystemExit(f"Could not resolve RULER dataset name '{ds_arg}'.")
|
| 65 |
+
return ds_arg, ds_utils.load_ruler(resolved, sample=args.sample, seed=args.seed)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _resolve_indices_to_explain_token_span(
|
| 69 |
+
attr_result: llm_attr.LLMAttributionResult, indices_to_explain: list[int] | None
|
| 70 |
+
) -> list[int]:
|
| 71 |
+
if (
|
| 72 |
+
isinstance(indices_to_explain, list)
|
| 73 |
+
and len(indices_to_explain) == 2
|
| 74 |
+
and all(isinstance(x, int) and x >= 0 for x in indices_to_explain)
|
| 75 |
+
and indices_to_explain[0] <= indices_to_explain[1]
|
| 76 |
+
):
|
| 77 |
+
return indices_to_explain
|
| 78 |
+
|
| 79 |
+
gen_len = int(attr_result.attribution_matrix.shape[0])
|
| 80 |
+
if gen_len <= 0:
|
| 81 |
+
return [0, 0]
|
| 82 |
+
|
| 83 |
+
# Default: explain the full generation excluding the appended EOS token.
|
| 84 |
+
end_tok = max(0, gen_len - 2)
|
| 85 |
+
return [0, end_tok]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def run_attribution(
|
| 89 |
+
testing_dict, example: ds_utils.CachedExample, batch_size: int, target: Optional[str]
|
| 90 |
+
) -> tuple[List[torch.Tensor], dict | None]:
|
| 91 |
+
model = testing_dict["model"]
|
| 92 |
+
tokenizer = testing_dict["tokenizer"]
|
| 93 |
+
attr_func = testing_dict["attr_func"]
|
| 94 |
+
|
| 95 |
+
if "IG" in attr_func:
|
| 96 |
+
llm_attributor = llm_attr.LLMGradientAttribtion(model, tokenizer)
|
| 97 |
+
attr = llm_attributor.calculate_IG_per_generation(
|
| 98 |
+
example.prompt,
|
| 99 |
+
20,
|
| 100 |
+
tokenizer.eos_token_id,
|
| 101 |
+
batch_size=batch_size,
|
| 102 |
+
target=target,
|
| 103 |
+
)
|
| 104 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 105 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 106 |
+
|
| 107 |
+
if "perturbation" in attr_func:
|
| 108 |
+
llm_attributor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
|
| 109 |
+
if attr_func == "perturbation_all":
|
| 110 |
+
attr = llm_attributor.calculate_feature_ablation_sentences(
|
| 111 |
+
example.prompt, baseline=tokenizer.eos_token_id, measure="log_loss", target=target
|
| 112 |
+
)
|
| 113 |
+
elif attr_func == "perturbation_CLP":
|
| 114 |
+
attr = llm_attributor.calculate_feature_ablation_sentences(
|
| 115 |
+
example.prompt, baseline=tokenizer.eos_token_id, measure="KL", target=target
|
| 116 |
+
)
|
| 117 |
+
elif attr_func == "perturbation_REAGENT":
|
| 118 |
+
attr = llm_attributor.calculate_feature_ablation_sentences_mlm(example.prompt, target=target)
|
| 119 |
+
else:
|
| 120 |
+
raise ValueError(f"Unsupported perturbation attr_func {attr_func}")
|
| 121 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 122 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 123 |
+
|
| 124 |
+
if "attention" in attr_func:
|
| 125 |
+
llm_attributor = llm_attr.LLMAttentionAttribution(model, tokenizer)
|
| 126 |
+
llm_attributor_ig = llm_attr.LLMGradientAttribtion(model, tokenizer)
|
| 127 |
+
attr = llm_attributor.calculate_attention_attribution(example.prompt, target=target)
|
| 128 |
+
if attr_func == "attention_I_G":
|
| 129 |
+
attr_b = llm_attributor_ig.calculate_IG_per_generation(
|
| 130 |
+
example.prompt, 20, tokenizer.eos_token_id, batch_size=batch_size, target=target
|
| 131 |
+
)
|
| 132 |
+
attr.attribution_matrix = attr.attribution_matrix * attr_b.attribution_matrix
|
| 133 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 134 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 135 |
+
|
| 136 |
+
if attr_func == "ifr_all_positions":
|
| 137 |
+
llm_attributor = llm_attr.LLMIFRAttribution(model, tokenizer)
|
| 138 |
+
attr = llm_attributor.calculate_ifr_for_all_positions(example.prompt, target=target)
|
| 139 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 140 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 141 |
+
|
| 142 |
+
if attr_func == "ifr_all_positions_output_only":
|
| 143 |
+
llm_attributor = llm_attr.LLMIFRAttribution(model, tokenizer)
|
| 144 |
+
sink_span = tuple(example.sink_span) if example.sink_span else None
|
| 145 |
+
attr = llm_attributor.calculate_ifr_for_all_positions_output_only(
|
| 146 |
+
example.prompt,
|
| 147 |
+
target=target,
|
| 148 |
+
sink_span=sink_span,
|
| 149 |
+
)
|
| 150 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 151 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 152 |
+
|
| 153 |
+
if attr_func == "ifr_span":
|
| 154 |
+
llm_attributor = llm_attr.LLMIFRAttribution(model, tokenizer)
|
| 155 |
+
span = example.sink_span if example.sink_span else None
|
| 156 |
+
attr = llm_attributor.calculate_ifr_span(example.prompt, target=target, span=tuple(span) if span else None)
|
| 157 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 158 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 159 |
+
|
| 160 |
+
if attr_func == "ifr_multi_hop":
|
| 161 |
+
llm_attributor = llm_attr.LLMIFRAttribution(model, tokenizer)
|
| 162 |
+
attr = llm_attributor.calculate_ifr_multi_hop(
|
| 163 |
+
example.prompt,
|
| 164 |
+
target=target,
|
| 165 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 166 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 167 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 168 |
+
)
|
| 169 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 170 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 171 |
+
|
| 172 |
+
if attr_func == "ifr_in_all_gen":
|
| 173 |
+
import ft_ifr_improve
|
| 174 |
+
|
| 175 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionInAllGen(model, tokenizer)
|
| 176 |
+
attr = llm_attributor.calculate_ifr_in_all_gen(
|
| 177 |
+
example.prompt,
|
| 178 |
+
target=target,
|
| 179 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 180 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 181 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 182 |
+
)
|
| 183 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 184 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 185 |
+
|
| 186 |
+
if attr_func == "ifr_multi_hop_stop_words":
|
| 187 |
+
import ft_ifr_improve
|
| 188 |
+
|
| 189 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionImproved(model, tokenizer)
|
| 190 |
+
attr = llm_attributor.calculate_ifr_multi_hop_stop_words(
|
| 191 |
+
example.prompt,
|
| 192 |
+
target=target,
|
| 193 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 194 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 195 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 196 |
+
)
|
| 197 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 198 |
+
extra = {
|
| 199 |
+
"keep_prompt_token_indices": ft_ifr_improve.keep_token_indices(list(attr.prompt_tokens)),
|
| 200 |
+
}
|
| 201 |
+
return list(attr.get_all_token_attrs(token_span)), extra
|
| 202 |
+
|
| 203 |
+
if attr_func == "ifr_multi_hop_both":
|
| 204 |
+
import ft_ifr_improve
|
| 205 |
+
|
| 206 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionBoth(model, tokenizer)
|
| 207 |
+
attr = llm_attributor.calculate_ifr_multi_hop_both(
|
| 208 |
+
example.prompt,
|
| 209 |
+
target=target,
|
| 210 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 211 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 212 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 213 |
+
)
|
| 214 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 215 |
+
extra = {
|
| 216 |
+
"keep_prompt_token_indices": ft_ifr_improve.keep_token_indices(list(attr.prompt_tokens)),
|
| 217 |
+
}
|
| 218 |
+
return list(attr.get_all_token_attrs(token_span)), extra
|
| 219 |
+
|
| 220 |
+
if attr_func == "ifr_multi_hop_split_hop":
|
| 221 |
+
import ft_ifr_improve
|
| 222 |
+
|
| 223 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(model, tokenizer)
|
| 224 |
+
attr = llm_attributor.calculate_ifr_multi_hop_split_hop(
|
| 225 |
+
example.prompt,
|
| 226 |
+
target=target,
|
| 227 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 228 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 229 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 230 |
+
)
|
| 231 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 232 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 233 |
+
|
| 234 |
+
if attr_func == "basic":
|
| 235 |
+
llm_attributor = llm_attr.LLMBasicAttribution(model, tokenizer)
|
| 236 |
+
attr = llm_attributor.calculate_basic_attribution(example.prompt, target=target)
|
| 237 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 238 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 239 |
+
|
| 240 |
+
if attr_func == "attnlrp":
|
| 241 |
+
llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 242 |
+
sink_span = getattr(example, "sink_span", None)
|
| 243 |
+
thinking_span = getattr(example, "thinking_span", None)
|
| 244 |
+
attr = llm_attributor.calculate_attnlrp_ft_hop0(
|
| 245 |
+
example.prompt,
|
| 246 |
+
target=target,
|
| 247 |
+
sink_span=tuple(sink_span) if sink_span else None,
|
| 248 |
+
thinking_span=tuple(thinking_span) if thinking_span else None,
|
| 249 |
+
)
|
| 250 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 251 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 252 |
+
|
| 253 |
+
if attr_func == "attnlrp_aggregated":
|
| 254 |
+
llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 255 |
+
attr = llm_attributor.calculate_attnlrp_aggregated(example.prompt, target=target)
|
| 256 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 257 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 258 |
+
|
| 259 |
+
if attr_func == "attnlrp_aggregated_multi_hop":
|
| 260 |
+
llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 261 |
+
attr = llm_attributor.calculate_attnlrp_aggregated_multi_hop(
|
| 262 |
+
example.prompt,
|
| 263 |
+
target=target,
|
| 264 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 265 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 266 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 267 |
+
)
|
| 268 |
+
token_span = _resolve_indices_to_explain_token_span(attr, example.indices_to_explain)
|
| 269 |
+
return list(attr.get_all_token_attrs(token_span)), None
|
| 270 |
+
|
| 271 |
+
raise ValueError(f"Unsupported attribution function '{attr_func}'.")
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def evaluate_dataset_recovery_ruler(testing_dict, dataset_name: str, examples: List[ds_utils.CachedExample]) -> Tuple[np.ndarray, np.ndarray, float, int, int]:
|
| 275 |
+
tokenizer = testing_dict["tokenizer"]
|
| 276 |
+
llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(testing_dict["model"], tokenizer)
|
| 277 |
+
|
| 278 |
+
results: List[np.ndarray] = []
|
| 279 |
+
durations: List[float] = []
|
| 280 |
+
skipped = 0
|
| 281 |
+
|
| 282 |
+
num_examples = testing_dict["num_examples"]
|
| 283 |
+
total = min(len(examples), num_examples)
|
| 284 |
+
iterator = islice(examples, total)
|
| 285 |
+
|
| 286 |
+
description = f"Recovery@10pct {testing_dict['model_name']} {dataset_name} {testing_dict['attr_func']}"
|
| 287 |
+
for ex in tqdm(iterator, desc=description, total=total):
|
| 288 |
+
needle_spans = (ex.metadata or {}).get("needle_spans")
|
| 289 |
+
if not isinstance(needle_spans, list) or not needle_spans:
|
| 290 |
+
raise SystemExit(
|
| 291 |
+
"recovery_ruler only supports RULER examples with metadata.needle_spans; "
|
| 292 |
+
f"dataset={dataset_name} has missing/empty needle_spans."
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
gold_prompt = ds_utils.ruler_gold_prompt_token_indices(ex, tokenizer)
|
| 296 |
+
if not gold_prompt:
|
| 297 |
+
skipped += 1
|
| 298 |
+
continue
|
| 299 |
+
|
| 300 |
+
# Batch size is set based on the max_input_len (same policy as faithfulness).
|
| 301 |
+
target = ex.target
|
| 302 |
+
if target is None:
|
| 303 |
+
generation, full_output = llm_evaluator.response(ex.prompt)
|
| 304 |
+
target = generation
|
| 305 |
+
response_len = len(tokenizer(full_output).input_ids)
|
| 306 |
+
else:
|
| 307 |
+
response_len = len(tokenizer(llm_evaluator.format_prompt(" " + ex.prompt) + target).input_ids)
|
| 308 |
+
batch_size = max(1, math.floor((testing_dict["max_input_len"] - 100) / max(1, response_len)))
|
| 309 |
+
|
| 310 |
+
sample_start = time.perf_counter()
|
| 311 |
+
attr_list, extra = run_attribution(testing_dict, ex, batch_size, target)
|
| 312 |
+
durations.append(time.perf_counter() - sample_start)
|
| 313 |
+
|
| 314 |
+
seq_attr = attr_list[0]
|
| 315 |
+
prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0]) # cols=(P+G), rows=G
|
| 316 |
+
if prompt_len <= 0:
|
| 317 |
+
skipped += 1
|
| 318 |
+
continue
|
| 319 |
+
|
| 320 |
+
if testing_dict["attr_func"] in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and extra is not None:
|
| 321 |
+
import ft_ifr_improve
|
| 322 |
+
|
| 323 |
+
keep_prompt_token_indices = extra.get("keep_prompt_token_indices") or []
|
| 324 |
+
gold_filtered = [idx for idx in gold_prompt if int(idx) in set(int(x) for x in keep_prompt_token_indices)]
|
| 325 |
+
if not gold_filtered:
|
| 326 |
+
skipped += 1
|
| 327 |
+
continue
|
| 328 |
+
scores = [
|
| 329 |
+
ft_ifr_improve.evaluate_attr_recovery_skip_tokens(
|
| 330 |
+
attr[:, :prompt_len],
|
| 331 |
+
keep_prompt_token_indices=keep_prompt_token_indices,
|
| 332 |
+
gold_prompt_token_indices=gold_prompt,
|
| 333 |
+
top_fraction=0.1,
|
| 334 |
+
)
|
| 335 |
+
for attr in attr_list
|
| 336 |
+
]
|
| 337 |
+
else:
|
| 338 |
+
scores = [
|
| 339 |
+
llm_evaluator.evaluate_attr_recovery(
|
| 340 |
+
attr,
|
| 341 |
+
prompt_len=prompt_len,
|
| 342 |
+
gold_prompt_token_indices=gold_prompt,
|
| 343 |
+
top_fraction=0.1,
|
| 344 |
+
)
|
| 345 |
+
for attr in attr_list
|
| 346 |
+
]
|
| 347 |
+
results.append(np.asarray(scores, dtype=np.float64))
|
| 348 |
+
|
| 349 |
+
scores = np.stack(results, axis=0) if results else np.zeros((0, 3), dtype=np.float64)
|
| 350 |
+
used = int(scores.shape[0])
|
| 351 |
+
mean = scores.mean(0) if used else np.full((3,), np.nan, dtype=np.float64)
|
| 352 |
+
std = scores.std(0) if used else np.full((3,), np.nan, dtype=np.float64)
|
| 353 |
+
avg_time = float(np.mean(durations)) if durations else 0.0
|
| 354 |
+
return mean, std, avg_time, used, int(skipped)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def load_model(model_name: str, device: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
|
| 358 |
+
seed = 42
|
| 359 |
+
random.seed(seed)
|
| 360 |
+
np.random.seed(seed)
|
| 361 |
+
torch.manual_seed(seed)
|
| 362 |
+
torch.cuda.manual_seed(seed)
|
| 363 |
+
torch.cuda.manual_seed_all(seed)
|
| 364 |
+
|
| 365 |
+
if device == "auto":
|
| 366 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 367 |
+
model_name,
|
| 368 |
+
device_map="auto",
|
| 369 |
+
attn_implementation="eager",
|
| 370 |
+
torch_dtype=torch.float16,
|
| 371 |
+
)
|
| 372 |
+
elif isinstance(device, str) and device.startswith("cuda:"):
|
| 373 |
+
try:
|
| 374 |
+
gpu_idx = int(device.split(":")[1])
|
| 375 |
+
except Exception:
|
| 376 |
+
gpu_idx = 0
|
| 377 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 378 |
+
model_name,
|
| 379 |
+
device_map={"": gpu_idx},
|
| 380 |
+
attn_implementation="eager",
|
| 381 |
+
torch_dtype=torch.float16,
|
| 382 |
+
)
|
| 383 |
+
else:
|
| 384 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 385 |
+
model_name,
|
| 386 |
+
attn_implementation="eager",
|
| 387 |
+
torch_dtype=torch.float16,
|
| 388 |
+
)
|
| 389 |
+
model.eval()
|
| 390 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 391 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 392 |
+
return model, tokenizer
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def main(args) -> None:
|
| 396 |
+
if args.cuda is not None and isinstance(args.cuda, str) and "," in args.cuda:
|
| 397 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
| 398 |
+
device = "auto"
|
| 399 |
+
elif args.cuda is not None and isinstance(args.cuda, str) and args.cuda.strip() != "":
|
| 400 |
+
try:
|
| 401 |
+
idx = int(args.cuda)
|
| 402 |
+
except Exception:
|
| 403 |
+
idx = 0
|
| 404 |
+
device = f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
|
| 405 |
+
else:
|
| 406 |
+
device = f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"
|
| 407 |
+
|
| 408 |
+
if args.model == "llama-1B":
|
| 409 |
+
model_name = "meta-llama/Llama-3.2-1B-Instruct"
|
| 410 |
+
max_input_len = 5500
|
| 411 |
+
elif args.model == "llama-3B":
|
| 412 |
+
model_name = "meta-llama/Llama-3.2-3B-Instruct"
|
| 413 |
+
max_input_len = 4800
|
| 414 |
+
elif args.model == "llama-8B":
|
| 415 |
+
model_name = "meta-llama/Llama-3.1-8B-Instruct"
|
| 416 |
+
max_input_len = 3500
|
| 417 |
+
elif args.model == "qwen-1.7B":
|
| 418 |
+
model_name = "Qwen/Qwen3-1.7B"
|
| 419 |
+
max_input_len = 5500
|
| 420 |
+
elif args.model == "qwen-4B":
|
| 421 |
+
model_name = "Qwen/Qwen3-4B-Instruct-2507"
|
| 422 |
+
max_input_len = 3500
|
| 423 |
+
elif args.model == "qwen-8B":
|
| 424 |
+
model_name = "Qwen/Qwen3-8B"
|
| 425 |
+
max_input_len = 3000
|
| 426 |
+
elif args.model == "qwen-32B":
|
| 427 |
+
model_name = "Qwen/Qwen3-32B"
|
| 428 |
+
max_input_len = 1500
|
| 429 |
+
elif args.model == "gemma-12B":
|
| 430 |
+
model_name = "gemma/gemma-3-12b-it"
|
| 431 |
+
max_input_len = 1500
|
| 432 |
+
elif args.model == "gemma-27B":
|
| 433 |
+
model_name = "gemma/gemma-3-27b-it"
|
| 434 |
+
max_input_len = 2000
|
| 435 |
+
else:
|
| 436 |
+
model_name = args.model_path if args.model_path is not None else args.model
|
| 437 |
+
max_input_len = 2000
|
| 438 |
+
|
| 439 |
+
model, tokenizer = load_model(model_name if args.model_path is None else args.model_path, device)
|
| 440 |
+
|
| 441 |
+
dataset_name, examples = _load_ruler_examples(args)
|
| 442 |
+
|
| 443 |
+
testing_dict = {
|
| 444 |
+
"model": model,
|
| 445 |
+
"model_name": args.model,
|
| 446 |
+
"tokenizer": tokenizer,
|
| 447 |
+
"dataset_name": dataset_name,
|
| 448 |
+
"attr_func": args.attr_func,
|
| 449 |
+
"num_examples": args.num_examples,
|
| 450 |
+
"max_input_len": max_input_len,
|
| 451 |
+
"n_hops": args.n_hops,
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
mean, std, avg_time, used, skipped = evaluate_dataset_recovery_ruler(testing_dict, dataset_name, examples)
|
| 455 |
+
|
| 456 |
+
out_dir = Path("./test_results") / "attribution_recovery" / dataset_name / args.model
|
| 457 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 458 |
+
file_name = f"{args.attr_func}_{args.num_examples}_examples.csv"
|
| 459 |
+
with open(out_dir / file_name, "w", newline="") as f:
|
| 460 |
+
writer = csv.writer(f)
|
| 461 |
+
writer.writerow(["Method", "Recovery@10pct"])
|
| 462 |
+
writer.writerow(["Seq Attr Recovery Mean", mean[0]])
|
| 463 |
+
writer.writerow(["Row Attr Recovery Mean", mean[1]])
|
| 464 |
+
writer.writerow(["Recursive Attr Recovery Mean", mean[2]])
|
| 465 |
+
writer.writerow(["Seq Attr Recovery Std", std[0]])
|
| 466 |
+
writer.writerow(["Row Attr Recovery Std", std[1]])
|
| 467 |
+
writer.writerow(["Recursive Attr Recovery Std", std[2]])
|
| 468 |
+
writer.writerow(["Examples Used", used])
|
| 469 |
+
writer.writerow(["Examples Skipped", skipped])
|
| 470 |
+
writer.writerow(["Avg Sample Time (s)", avg_time])
|
| 471 |
+
|
| 472 |
+
print(f"[{dataset_name}] {args.attr_func} -> {out_dir/file_name} (used={used} skipped={skipped} avg {avg_time:.2f}s)")
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
if __name__ == "__main__":
|
| 476 |
+
parser = argparse.ArgumentParser("RULER-only token-level attribution recovery evaluation (Recall@10pct).")
|
| 477 |
+
parser.add_argument("--num_examples", type=int, default=100, help="How many examples to evaluate.")
|
| 478 |
+
parser.add_argument("--sample", type=int, default=None, help="Optional subsample before num_examples.")
|
| 479 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 480 |
+
parser.add_argument("--model", type=str, default="qwen-8B")
|
| 481 |
+
parser.add_argument("--model_path", type=str, default=None, help="Optional local model path to load.")
|
| 482 |
+
parser.add_argument("--attr_func", type=str, default="ifr_multi_hop")
|
| 483 |
+
parser.add_argument("--cuda_num", type=int, default=0)
|
| 484 |
+
parser.add_argument("--cuda", type=str, default=None)
|
| 485 |
+
parser.add_argument("--dataset", type=str, required=True, help="RULER dataset name or JSONL path (raw or exp2 cache).")
|
| 486 |
+
parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Cache directory to search by dataset name.")
|
| 487 |
+
parser.add_argument("--n_hops", type=int, default=3)
|
| 488 |
+
|
| 489 |
+
args, _ = parser.parse_known_args()
|
| 490 |
+
main(args)
|
evaluations/attribution_recovery.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RULER-only token-level recovery (Recall@10pct) examples.
|
| 2 |
+
#
|
| 3 |
+
# Dataset can be:
|
| 4 |
+
# - a RULER name (hotpotqa_long / niah_* / vt_*) resolved under data/ruler_multihop/<len>/.../validation.jsonl
|
| 5 |
+
# - a raw RULER JSONL path
|
| 6 |
+
# - an exp2 cache JSONL path (must contain metadata.needle_spans)
|
| 7 |
+
|
| 8 |
+
# Example: evaluate on exp2 cache
|
| 9 |
+
# CUDA_VISIBLE_DEVICES=0 python3 evaluations/attribution_recovery.py \
|
| 10 |
+
# --model qwen-8B --model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 11 |
+
# --cuda 0 --num_examples 50 --attr_func ifr_multi_hop \
|
| 12 |
+
# --dataset exp/exp2/data/hotpotqa.jsonl
|
| 13 |
+
|
| 14 |
+
# Example: evaluate on raw RULER JSONL
|
| 15 |
+
# CUDA_VISIBLE_DEVICES=0 python3 evaluations/attribution_recovery.py \
|
| 16 |
+
# --model qwen-8B --model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 17 |
+
# --cuda 0 --num_examples 50 --attr_func ifr_multi_hop \
|
| 18 |
+
# --dataset data/ruler_multihop/4096/hotpotqa_long/validation.jsonl
|
evaluations/faithfulness.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
# Ensure project root is importable regardless of CWD
|
| 4 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 5 |
+
|
| 6 |
+
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from transformers import utils
|
| 10 |
+
import math
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import random
|
| 13 |
+
import argparse
|
| 14 |
+
import csv
|
| 15 |
+
from itertools import islice
|
| 16 |
+
from typing import Tuple
|
| 17 |
+
from huggingface_hub import login
|
| 18 |
+
|
| 19 |
+
from attribution_datasets import (
|
| 20 |
+
AttributionDataset,
|
| 21 |
+
FactsAttributionDataset,
|
| 22 |
+
MathAttributionDataset,
|
| 23 |
+
MoreHopQAAttributionDataset,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
utils.logging.set_verbosity_error() # Suppress standard warnings
|
| 27 |
+
|
| 28 |
+
import llm_attr
|
| 29 |
+
import llm_attr_eval
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _resolve_indices_to_explain_token_span(
|
| 33 |
+
attr_result: llm_attr.LLMAttributionResult, indices_to_explain: list[int] | None
|
| 34 |
+
) -> list[int]:
|
| 35 |
+
if (
|
| 36 |
+
isinstance(indices_to_explain, list)
|
| 37 |
+
and len(indices_to_explain) == 2
|
| 38 |
+
and all(isinstance(x, int) and x >= 0 for x in indices_to_explain)
|
| 39 |
+
and indices_to_explain[0] <= indices_to_explain[1]
|
| 40 |
+
):
|
| 41 |
+
return indices_to_explain
|
| 42 |
+
|
| 43 |
+
gen_len = int(attr_result.attribution_matrix.shape[0])
|
| 44 |
+
if gen_len <= 0:
|
| 45 |
+
return [0, 0]
|
| 46 |
+
|
| 47 |
+
# Default: explain the full generation excluding the appended EOS token.
|
| 48 |
+
end_tok = max(0, gen_len - 2)
|
| 49 |
+
return [0, end_tok]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def run_attribution(testing_dict, prompt, batch_size, indices_to_explain = [1], target = None) -> tuple[list[torch.Tensor], dict | None]:
|
| 53 |
+
model = testing_dict["model"]
|
| 54 |
+
tokenizer = testing_dict["tokenizer"]
|
| 55 |
+
|
| 56 |
+
# Now we create an attribution for the full response
|
| 57 |
+
if "IG" in testing_dict["attr_func"]:
|
| 58 |
+
llm_attributor = llm_attr.LLMGradientAttribtion(model, tokenizer)
|
| 59 |
+
|
| 60 |
+
if testing_dict["attr_func"] == "IG":
|
| 61 |
+
attr = llm_attributor.calculate_IG_per_generation(prompt, 20, tokenizer.eos_token_id, batch_size = batch_size, target = target)
|
| 62 |
+
|
| 63 |
+
token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
|
| 64 |
+
attributions = list(attr.get_all_token_attrs(token_span))
|
| 65 |
+
|
| 66 |
+
elif "perturbation" in testing_dict["attr_func"]:
|
| 67 |
+
llm_attributor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
|
| 68 |
+
|
| 69 |
+
if testing_dict["attr_func"] == "perturbation_all":
|
| 70 |
+
attr = llm_attributor.calculate_feature_ablation_sentences(prompt, baseline = tokenizer.eos_token_id, measure="log_loss", target = target)
|
| 71 |
+
elif testing_dict["attr_func"] == "perturbation_CLP":
|
| 72 |
+
attr = llm_attributor.calculate_feature_ablation_sentences(prompt, baseline = tokenizer.eos_token_id, measure="KL", target = target)
|
| 73 |
+
elif testing_dict["attr_func"] == "perturbation_REAGENT":
|
| 74 |
+
attr = llm_attributor.calculate_feature_ablation_sentences_mlm(prompt, target = target)
|
| 75 |
+
|
| 76 |
+
token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
|
| 77 |
+
attributions = list(attr.get_all_token_attrs(token_span))
|
| 78 |
+
|
| 79 |
+
elif "attention" in testing_dict["attr_func"]:
|
| 80 |
+
llm_attributor = llm_attr.LLMAttentionAttribution(model, tokenizer)
|
| 81 |
+
llm_attributor_ig = llm_attr.LLMGradientAttribtion(model, tokenizer)
|
| 82 |
+
|
| 83 |
+
if testing_dict["attr_func"] == "attention_I_G":
|
| 84 |
+
attr = llm_attributor.calculate_attention_attribution(prompt, target = target)
|
| 85 |
+
attr_b = llm_attributor_ig.calculate_IG_per_generation(prompt, 20, tokenizer.eos_token_id, batch_size = batch_size, target = target)
|
| 86 |
+
attr.attribution_matrix = attr.attribution_matrix * attr_b.attribution_matrix
|
| 87 |
+
|
| 88 |
+
token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
|
| 89 |
+
attributions = list(attr.get_all_token_attrs(token_span))
|
| 90 |
+
|
| 91 |
+
elif "ifr" in testing_dict["attr_func"].lower():
|
| 92 |
+
llm_attributor = llm_attr.LLMIFRAttribution(model, tokenizer)
|
| 93 |
+
attr_func = testing_dict["attr_func"].lower()
|
| 94 |
+
renorm_threshold = testing_dict.get("renorm_threshold")
|
| 95 |
+
|
| 96 |
+
if attr_func == "ifr_all_positions":
|
| 97 |
+
attr = llm_attributor.calculate_ifr_for_all_positions(prompt, target=target, renorm_threshold=renorm_threshold)
|
| 98 |
+
elif attr_func == "ifr_all_positions_output_only":
|
| 99 |
+
attr = llm_attributor.calculate_ifr_for_all_positions_output_only(
|
| 100 |
+
prompt,
|
| 101 |
+
target=target,
|
| 102 |
+
sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
|
| 103 |
+
renorm_threshold=renorm_threshold,
|
| 104 |
+
)
|
| 105 |
+
elif attr_func == "ifr_span":
|
| 106 |
+
span = testing_dict.get("sink_span")
|
| 107 |
+
attr = llm_attributor.calculate_ifr_span(
|
| 108 |
+
prompt,
|
| 109 |
+
target=target,
|
| 110 |
+
span=tuple(span) if span is not None else None,
|
| 111 |
+
renorm_threshold=renorm_threshold,
|
| 112 |
+
)
|
| 113 |
+
elif attr_func == "ifr_multi_hop":
|
| 114 |
+
attr = llm_attributor.calculate_ifr_multi_hop(
|
| 115 |
+
prompt,
|
| 116 |
+
target=target,
|
| 117 |
+
sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
|
| 118 |
+
thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
|
| 119 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 120 |
+
renorm_threshold=renorm_threshold,
|
| 121 |
+
observation_mask=testing_dict.get("observation_mask"),
|
| 122 |
+
)
|
| 123 |
+
elif attr_func == "ifr_in_all_gen":
|
| 124 |
+
import ft_ifr_improve
|
| 125 |
+
|
| 126 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionInAllGen(model, tokenizer)
|
| 127 |
+
attr = llm_attributor.calculate_ifr_in_all_gen(
|
| 128 |
+
prompt,
|
| 129 |
+
target=target,
|
| 130 |
+
sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
|
| 131 |
+
thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
|
| 132 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 133 |
+
renorm_threshold=renorm_threshold,
|
| 134 |
+
observation_mask=testing_dict.get("observation_mask"),
|
| 135 |
+
)
|
| 136 |
+
elif attr_func == "ifr_multi_hop_stop_words":
|
| 137 |
+
import ft_ifr_improve
|
| 138 |
+
|
| 139 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionImproved(model, tokenizer)
|
| 140 |
+
attr = llm_attributor.calculate_ifr_multi_hop_stop_words(
|
| 141 |
+
prompt,
|
| 142 |
+
target=target,
|
| 143 |
+
sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
|
| 144 |
+
thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
|
| 145 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 146 |
+
renorm_threshold=renorm_threshold,
|
| 147 |
+
observation_mask=testing_dict.get("observation_mask"),
|
| 148 |
+
)
|
| 149 |
+
elif attr_func == "ifr_multi_hop_both":
|
| 150 |
+
import ft_ifr_improve
|
| 151 |
+
|
| 152 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionBoth(model, tokenizer)
|
| 153 |
+
attr = llm_attributor.calculate_ifr_multi_hop_both(
|
| 154 |
+
prompt,
|
| 155 |
+
target=target,
|
| 156 |
+
sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
|
| 157 |
+
thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
|
| 158 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 159 |
+
renorm_threshold=renorm_threshold,
|
| 160 |
+
observation_mask=testing_dict.get("observation_mask"),
|
| 161 |
+
)
|
| 162 |
+
elif attr_func == "ifr_multi_hop_split_hop":
|
| 163 |
+
import ft_ifr_improve
|
| 164 |
+
|
| 165 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(model, tokenizer)
|
| 166 |
+
attr = llm_attributor.calculate_ifr_multi_hop_split_hop(
|
| 167 |
+
prompt,
|
| 168 |
+
target=target,
|
| 169 |
+
sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
|
| 170 |
+
thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
|
| 171 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 172 |
+
renorm_threshold=renorm_threshold,
|
| 173 |
+
observation_mask=testing_dict.get("observation_mask"),
|
| 174 |
+
)
|
| 175 |
+
else:
|
| 176 |
+
raise ValueError(f"Unsupported IFR attribution function '{testing_dict['attr_func']}'.")
|
| 177 |
+
|
| 178 |
+
token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
|
| 179 |
+
attributions = list(attr.get_all_token_attrs(token_span))
|
| 180 |
+
|
| 181 |
+
elif "basic" in testing_dict["attr_func"]:
|
| 182 |
+
llm_attributor = llm_attr.LLMBasicAttribution(model, tokenizer)
|
| 183 |
+
attr = llm_attributor.calculate_basic_attribution(prompt, target = target)
|
| 184 |
+
token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
|
| 185 |
+
attributions = list(attr.get_all_token_attrs(token_span))
|
| 186 |
+
|
| 187 |
+
elif testing_dict["attr_func"] == "attnlrp":
|
| 188 |
+
llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 189 |
+
attr = llm_attributor.calculate_attnlrp_ft_hop0(prompt, target=target)
|
| 190 |
+
token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
|
| 191 |
+
attributions = list(attr.get_all_token_attrs(token_span))
|
| 192 |
+
|
| 193 |
+
elif testing_dict["attr_func"] == "attnlrp_aggregated":
|
| 194 |
+
llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 195 |
+
attr = llm_attributor.calculate_attnlrp_aggregated(prompt, target=target)
|
| 196 |
+
token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
|
| 197 |
+
attributions = list(attr.get_all_token_attrs(token_span))
|
| 198 |
+
|
| 199 |
+
elif testing_dict["attr_func"] == "attnlrp_aggregated_multi_hop":
|
| 200 |
+
llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 201 |
+
attr = llm_attributor.calculate_attnlrp_aggregated_multi_hop(
|
| 202 |
+
prompt,
|
| 203 |
+
target=target,
|
| 204 |
+
sink_span=tuple(testing_dict.get("sink_span")) if testing_dict.get("sink_span") is not None else None,
|
| 205 |
+
thinking_span=tuple(testing_dict.get("thinking_span")) if testing_dict.get("thinking_span") is not None else None,
|
| 206 |
+
n_hops=testing_dict.get("n_hops", 1),
|
| 207 |
+
)
|
| 208 |
+
token_span = _resolve_indices_to_explain_token_span(attr, indices_to_explain)
|
| 209 |
+
attributions = list(attr.get_all_token_attrs(token_span))
|
| 210 |
+
|
| 211 |
+
else:
|
| 212 |
+
raise ValueError(f"Unsupported attribution function '{testing_dict['attr_func']}'.")
|
| 213 |
+
|
| 214 |
+
extra = None
|
| 215 |
+
if testing_dict["attr_func"].lower() in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both"):
|
| 216 |
+
import ft_ifr_improve
|
| 217 |
+
|
| 218 |
+
extra = {
|
| 219 |
+
"keep_prompt_token_indices": ft_ifr_improve.keep_token_indices(list(attr.prompt_tokens)),
|
| 220 |
+
"user_prompt_indices": list(getattr(llm_attributor, "user_prompt_indices", []) or []),
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
return attributions, extra
|
| 224 |
+
|
| 225 |
+
def faithfulness_test(testing_dict, llm_evaluator, prompt, indices_to_explain, target = None) -> np.ndarray[float]:
|
| 226 |
+
tokenizer = testing_dict["tokenizer"]
|
| 227 |
+
faithfulness_k = int(testing_dict.get("faithfulness_k", 20))
|
| 228 |
+
|
| 229 |
+
scores = []
|
| 230 |
+
|
| 231 |
+
# batch size is set based on the max_input_len in main(). Currently set to fully fill a 196GB GPU.
|
| 232 |
+
if target is None:
|
| 233 |
+
generation, full_output = llm_evaluator.response(prompt)
|
| 234 |
+
batch_size = math.floor((testing_dict["max_input_len"] - 100) / len(tokenizer(full_output).input_ids))
|
| 235 |
+
else:
|
| 236 |
+
generation = target
|
| 237 |
+
batch_size = math.floor(
|
| 238 |
+
(testing_dict["max_input_len"] - 100)
|
| 239 |
+
/ len(tokenizer(llm_evaluator.format_prompt(" " + prompt) + generation).input_ids)
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# We run an attribution on the input
|
| 243 |
+
# A list of attribution tensors will be returned and scored individually.
|
| 244 |
+
attr_list, extra = run_attribution(testing_dict, prompt, batch_size, indices_to_explain = indices_to_explain, target = target)
|
| 245 |
+
|
| 246 |
+
seq_attr = attr_list[0]
|
| 247 |
+
prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0]) # cols=(P+G), rows=G
|
| 248 |
+
|
| 249 |
+
for i in range(len(attr_list)):
|
| 250 |
+
attr = attr_list[i][:, :prompt_len]
|
| 251 |
+
if testing_dict["attr_func"].lower() in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and extra is not None:
|
| 252 |
+
import ft_ifr_improve
|
| 253 |
+
|
| 254 |
+
scores.append(
|
| 255 |
+
ft_ifr_improve.faithfulness_test_skip_tokens(
|
| 256 |
+
llm_evaluator,
|
| 257 |
+
attr,
|
| 258 |
+
prompt,
|
| 259 |
+
generation,
|
| 260 |
+
keep_prompt_token_indices=extra.get("keep_prompt_token_indices") or [],
|
| 261 |
+
user_prompt_indices=extra.get("user_prompt_indices"),
|
| 262 |
+
k=faithfulness_k,
|
| 263 |
+
)
|
| 264 |
+
)
|
| 265 |
+
else:
|
| 266 |
+
scores.append(llm_evaluator.faithfulness_test(attr, prompt, generation, k=faithfulness_k)) # [3 scores]
|
| 267 |
+
|
| 268 |
+
return np.array(scores)
|
| 269 |
+
|
| 270 |
+
def clean_trailing_space(text) -> str:
|
| 271 |
+
if text[-1] == ' ':
|
| 272 |
+
return text[:-1]
|
| 273 |
+
else:
|
| 274 |
+
return text
|
| 275 |
+
|
| 276 |
+
def evaluate_attribution(testing_dict) -> None:
|
| 277 |
+
model = testing_dict["model"]
|
| 278 |
+
tokenizer = testing_dict["tokenizer"]
|
| 279 |
+
|
| 280 |
+
llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(model, tokenizer)
|
| 281 |
+
|
| 282 |
+
scores = []
|
| 283 |
+
|
| 284 |
+
description = "Faithfulness " + testing_dict["model_name"] + " " + testing_dict["dataset_name"] + " " + testing_dict["attr_func"]
|
| 285 |
+
|
| 286 |
+
dataset: AttributionDataset = testing_dict["dataset"]
|
| 287 |
+
num_examples = testing_dict["num_examples"]
|
| 288 |
+
total = min(len(dataset), num_examples) if hasattr(dataset, "__len__") else num_examples
|
| 289 |
+
example_iterator = islice(dataset, num_examples)
|
| 290 |
+
|
| 291 |
+
for example in tqdm(example_iterator, desc=description, total=total):
|
| 292 |
+
indices_to_explain = example.indices_to_explain if example.indices_to_explain is not None else [-2]
|
| 293 |
+
scores.append(
|
| 294 |
+
faithfulness_test(
|
| 295 |
+
testing_dict,
|
| 296 |
+
llm_evaluator,
|
| 297 |
+
example.prompt,
|
| 298 |
+
indices_to_explain=indices_to_explain,
|
| 299 |
+
target=example.target,
|
| 300 |
+
)
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
scores = np.array(scores) # [num_examples, num_attrs, 3 scores]
|
| 304 |
+
scores_mean = scores.mean(0) # [num_attrs, 3 scores]
|
| 305 |
+
scores_var = scores.std(0) # [num_attrs, 3 scores]
|
| 306 |
+
|
| 307 |
+
# make the test folder if it doesn't exist
|
| 308 |
+
folder = "./test_results/faithfulness/" + testing_dict["dataset_name"] + "/" + testing_dict["model_name"] + "/"
|
| 309 |
+
if not os.path.exists(folder):
|
| 310 |
+
os.makedirs(folder)
|
| 311 |
+
|
| 312 |
+
# save all data
|
| 313 |
+
file_name = testing_dict["attr_func"] + "_" + str(testing_dict["num_examples"]) + "_examples"
|
| 314 |
+
with open(folder + file_name + ".csv", 'w') as f:
|
| 315 |
+
write = csv.writer(f)
|
| 316 |
+
|
| 317 |
+
write.writerow(["Method", "RISE", "MAS", "RISE + AP"])
|
| 318 |
+
|
| 319 |
+
write.writerow(["Seq Attr Scores Mean"] + scores_mean[0].tolist())
|
| 320 |
+
write.writerow(["Row Attr Scores Mean"] + scores_mean[1].tolist())
|
| 321 |
+
write.writerow(["Recursive Attr Scores Mean"] + scores_mean[2].tolist())
|
| 322 |
+
|
| 323 |
+
write.writerow(["Seq Attr Scores Var"] + scores_var[0].tolist())
|
| 324 |
+
write.writerow(["Row Attr Scores Var"] + scores_var[1].tolist())
|
| 325 |
+
write.writerow(["Recursive Attr Scores Var"] + scores_var[2].tolist())
|
| 326 |
+
|
| 327 |
+
return
|
| 328 |
+
|
| 329 |
+
def load_model(model_name, device) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
|
| 330 |
+
seed = 42
|
| 331 |
+
random.seed(seed)
|
| 332 |
+
np.random.seed(seed)
|
| 333 |
+
torch.manual_seed(seed)
|
| 334 |
+
torch.cuda.manual_seed(seed)
|
| 335 |
+
torch.cuda.manual_seed_all(seed) # if multi-GPU
|
| 336 |
+
|
| 337 |
+
# Respect three modes:
|
| 338 |
+
# - device == 'auto' -> multi-GPU sharding across all visible devices
|
| 339 |
+
# - device startswith('cuda:IDX') -> place entire model on a single GPU IDX (relative to visible devices)
|
| 340 |
+
# - device == 'cpu' -> CPU
|
| 341 |
+
if device == "auto":
|
| 342 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 343 |
+
model_name,
|
| 344 |
+
device_map="auto",
|
| 345 |
+
attn_implementation="eager",
|
| 346 |
+
torch_dtype=torch.float16,
|
| 347 |
+
)
|
| 348 |
+
elif isinstance(device, str) and device.startswith("cuda:"):
|
| 349 |
+
try:
|
| 350 |
+
gpu_idx = int(device.split(":")[1])
|
| 351 |
+
except Exception:
|
| 352 |
+
gpu_idx = 0
|
| 353 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 354 |
+
model_name,
|
| 355 |
+
device_map={"": gpu_idx},
|
| 356 |
+
attn_implementation="eager",
|
| 357 |
+
torch_dtype=torch.float16,
|
| 358 |
+
)
|
| 359 |
+
else:
|
| 360 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 361 |
+
model_name,
|
| 362 |
+
attn_implementation="eager",
|
| 363 |
+
torch_dtype=torch.float16,
|
| 364 |
+
)
|
| 365 |
+
model.eval()
|
| 366 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 367 |
+
|
| 368 |
+
# Needed for LLaMA tokenizer
|
| 369 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 370 |
+
|
| 371 |
+
return model, tokenizer
|
| 372 |
+
|
| 373 |
+
def main(args) -> None:
|
| 374 |
+
# login(token = "")
|
| 375 |
+
|
| 376 |
+
# Device selection policy (mirrors attribution_recovery):
|
| 377 |
+
# - If --cuda is a comma-separated list (e.g. "0,1"), set visibility to that list and shard with device_map='auto'.
|
| 378 |
+
# - If --cuda is a single index (e.g. "0"), do NOT override CUDA_VISIBLE_DEVICES; place model on cuda:{index}.
|
| 379 |
+
# - Else (no --cuda), use --cuda_num as single-device index relative to current visibility.
|
| 380 |
+
if args.cuda is not None and isinstance(args.cuda, str) and "," in args.cuda:
|
| 381 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
| 382 |
+
device = "auto"
|
| 383 |
+
elif args.cuda is not None and isinstance(args.cuda, str) and args.cuda.strip() != "":
|
| 384 |
+
try:
|
| 385 |
+
idx = int(args.cuda)
|
| 386 |
+
except Exception:
|
| 387 |
+
idx = 0
|
| 388 |
+
device = f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
|
| 389 |
+
else:
|
| 390 |
+
device = f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"
|
| 391 |
+
|
| 392 |
+
# set up model
|
| 393 |
+
if args.model == "llama-1B":
|
| 394 |
+
model_name = "meta-llama/Llama-3.2-1B-Instruct"
|
| 395 |
+
max_input_len = 5500
|
| 396 |
+
elif args.model == "llama-3B":
|
| 397 |
+
model_name = "meta-llama/Llama-3.2-3B-Instruct"
|
| 398 |
+
max_input_len = 4800
|
| 399 |
+
elif args.model == "llama-8B":
|
| 400 |
+
model_name = "meta-llama/Llama-3.1-8B-Instruct"
|
| 401 |
+
max_input_len = 3500
|
| 402 |
+
elif args.model == "qwen-1.7B":
|
| 403 |
+
model_name = "Qwen/Qwen3-1.7B"
|
| 404 |
+
max_input_len = 5500
|
| 405 |
+
elif args.model == "qwen-4B":
|
| 406 |
+
model_name = "Qwen/Qwen3-4B-Instruct-2507"
|
| 407 |
+
max_input_len = 3500
|
| 408 |
+
elif args.model == "qwen-8B":
|
| 409 |
+
model_name = "Qwen/Qwen3-8B"
|
| 410 |
+
max_input_len = 3000
|
| 411 |
+
elif args.model == "qwen-32B":
|
| 412 |
+
model_name = "Qwen/Qwen3-32B"
|
| 413 |
+
max_input_len = 1500
|
| 414 |
+
elif args.model == "gemma-12B":
|
| 415 |
+
model_name = "gemma/gemma-3-12b-it"
|
| 416 |
+
max_input_len = 1500
|
| 417 |
+
elif args.model == "gemma-27B":
|
| 418 |
+
model_name = "gemma/gemma-3-27b-it"
|
| 419 |
+
max_input_len = 2000
|
| 420 |
+
else:
|
| 421 |
+
model_name = args.model_path if args.model_path is not None else args.model
|
| 422 |
+
max_input_len = 2000
|
| 423 |
+
|
| 424 |
+
model, tokenizer = load_model(model_name if args.model_path is None else args.model_path, device)
|
| 425 |
+
|
| 426 |
+
dataset_registry = {
|
| 427 |
+
"math": lambda: MathAttributionDataset("./data/math_mine.json", tokenizer),
|
| 428 |
+
"facts": lambda: FactsAttributionDataset("./data/10000_facts_9_choose_3.json"),
|
| 429 |
+
"morehopqa": lambda: MoreHopQAAttributionDataset("./data/with_human_verification.json"),
|
| 430 |
+
}
|
| 431 |
+
dataset_loader = dataset_registry.get(args.dataset)
|
| 432 |
+
if dataset_loader is None:
|
| 433 |
+
print("You have not specified an acceptable dataset. Exiting.")
|
| 434 |
+
exit()
|
| 435 |
+
dataset = dataset_loader()
|
| 436 |
+
|
| 437 |
+
testing_dict = {
|
| 438 |
+
"model" : model,
|
| 439 |
+
"model_name": args.model,
|
| 440 |
+
"tokenizer" : tokenizer,
|
| 441 |
+
"dataset" : dataset,
|
| 442 |
+
"dataset_name" : args.dataset,
|
| 443 |
+
"max_input_len": max_input_len,
|
| 444 |
+
"attr_func": args.attr_func,
|
| 445 |
+
"num_examples": args.num_examples,
|
| 446 |
+
"device": device,
|
| 447 |
+
"faithfulness_k": args.faithfulness_k,
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
# call the test function
|
| 451 |
+
evaluate_attribution(testing_dict)
|
| 452 |
+
|
| 453 |
+
return
|
| 454 |
+
|
| 455 |
+
if __name__ == "__main__":
|
| 456 |
+
parser = argparse.ArgumentParser('')
|
| 457 |
+
parser.add_argument('--num_examples',
|
| 458 |
+
type = int, default = 100,
|
| 459 |
+
help='How many dataset examples to test with.')
|
| 460 |
+
parser.add_argument('--model',
|
| 461 |
+
type = str,
|
| 462 |
+
default = "llama",
|
| 463 |
+
help='Model to use: llama or qwen')
|
| 464 |
+
parser.add_argument('--model_path',
|
| 465 |
+
type=str, default=None,
|
| 466 |
+
help='Optional local model path to load (overrides model repo id only).')
|
| 467 |
+
parser.add_argument('--attr_func',
|
| 468 |
+
type = str,
|
| 469 |
+
default = "IG",
|
| 470 |
+
help="attr to use: \
|
| 471 |
+
grad, IG, IG_captum, contextcite, attention, rollout, perturbation \
|
| 472 |
+
")
|
| 473 |
+
parser.add_argument('--cuda_num',
|
| 474 |
+
type=int, default = 0,
|
| 475 |
+
help='The number of the GPU you want to use.')
|
| 476 |
+
parser.add_argument('--cuda',
|
| 477 |
+
type=str, default=None,
|
| 478 |
+
help='GPU selection: use comma-separated ids for multi-GPU sharding (e.g. "0,1"); use a single index for one GPU relative to current CUDA_VISIBLE_DEVICES (e.g. "0").')
|
| 479 |
+
parser.add_argument('--dataset',
|
| 480 |
+
type = str, default = "math",
|
| 481 |
+
help = 'The dataset to evaluate on: math, facts, or morehopqa')
|
| 482 |
+
parser.add_argument(
|
| 483 |
+
"--faithfulness_k",
|
| 484 |
+
type=int,
|
| 485 |
+
default=20,
|
| 486 |
+
help="Total perturbation steps k for MAS/RISE (each step perturbs ~1/k of prompt tokens).",
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
args, unparsed = parser.parse_known_args()
|
| 490 |
+
|
| 491 |
+
main(args)
|
evaluations/faithfulness.sh
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func IG --dataset facts
|
| 2 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset facts
|
| 3 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func IG --dataset facts
|
| 4 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset facts
|
| 5 |
+
|
| 6 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset facts
|
| 7 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset facts
|
| 8 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset facts
|
| 9 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset facts
|
| 10 |
+
|
| 11 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset facts
|
| 12 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset facts
|
| 13 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset facts
|
| 14 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset facts
|
| 15 |
+
|
| 16 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset facts
|
| 17 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset facts
|
| 18 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset facts
|
| 19 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset facts
|
| 20 |
+
|
| 21 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset facts
|
| 22 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset facts
|
| 23 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset facts
|
| 24 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset facts
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func IG --dataset math
|
| 29 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset math
|
| 30 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func IG --dataset math
|
| 31 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset math
|
| 32 |
+
|
| 33 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset math
|
| 34 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset math
|
| 35 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset math
|
| 36 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset math
|
| 37 |
+
|
| 38 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset math
|
| 39 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset math
|
| 40 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset math
|
| 41 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset math
|
| 42 |
+
|
| 43 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset math
|
| 44 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset math
|
| 45 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset math
|
| 46 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset math
|
| 47 |
+
|
| 48 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset math
|
| 49 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset math
|
| 50 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset math
|
| 51 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset math
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func IG --dataset morehopqa
|
| 56 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset morehopqa
|
| 57 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func IG --dataset morehopqa
|
| 58 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func IG --dataset morehopqa
|
| 59 |
+
|
| 60 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset morehopqa
|
| 61 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset morehopqa
|
| 62 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset morehopqa
|
| 63 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func attention_I_G --dataset morehopqa
|
| 64 |
+
|
| 65 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset morehopqa
|
| 66 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset morehopqa
|
| 67 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset morehopqa
|
| 68 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_CLP --dataset morehopqa
|
| 69 |
+
|
| 70 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset morehopqa
|
| 71 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset morehopqa
|
| 72 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset morehopqa
|
| 73 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_REAGENT --dataset morehopqa
|
| 74 |
+
|
| 75 |
+
# python3 faithfulness.py --model llama-3B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset morehopqa
|
| 76 |
+
# python3 faithfulness.py --model llama-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset morehopqa
|
| 77 |
+
# python3 faithfulness.py --model qwen-4B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset morehopqa
|
| 78 |
+
# python3 faithfulness.py --model qwen-8B --cuda_num 0 --num_examples 500 --attr_func perturbation_all --dataset morehopqa
|
| 79 |
+
|
| 80 |
+
CUDA_VISIBLE_DEVICES=4,6 python3 evaluations/faithfulness.py --model qwen-8B --model_path /opt/share/models/Qwen/Qwen3-8B/ --cuda '0,1' --num_examples 50 --attr_func IG --dataset math
|
example.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/quickstart.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
from flashtrace import FlashTrace, load_model_and_tokenizer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def build_parser() -> argparse.ArgumentParser:
|
| 9 |
+
parser = argparse.ArgumentParser(description="FlashTrace quickstart example.")
|
| 10 |
+
parser.add_argument("--model", required=True, help="Hugging Face model id or local model path.")
|
| 11 |
+
parser.add_argument("--prompt", required=True, help="Prompt text.")
|
| 12 |
+
parser.add_argument("--target", help="Target response text.")
|
| 13 |
+
parser.add_argument("--output-span", default=None, help="Inclusive generation-token span START:END.")
|
| 14 |
+
parser.add_argument("--reasoning-span", default=None, help="Inclusive generation-token span START:END.")
|
| 15 |
+
parser.add_argument("--html", default="trace.html", help="Output HTML path.")
|
| 16 |
+
parser.add_argument("--use-chat-template", action="store_true", help="Format prompts with the tokenizer chat template.")
|
| 17 |
+
return parser
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def parse_span(value: str | None) -> tuple[int, int] | None:
|
| 21 |
+
from flashtrace.cli import parse_span as parse_cli_span
|
| 22 |
+
|
| 23 |
+
return parse_cli_span(value)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main() -> int:
|
| 27 |
+
args = build_parser().parse_args()
|
| 28 |
+
model, tokenizer = load_model_and_tokenizer(args.model)
|
| 29 |
+
tracer = FlashTrace(model, tokenizer, use_chat_template=args.use_chat_template)
|
| 30 |
+
trace = tracer.trace(
|
| 31 |
+
prompt=args.prompt,
|
| 32 |
+
target=args.target,
|
| 33 |
+
output_span=parse_span(args.output_span),
|
| 34 |
+
reasoning_span=parse_span(args.reasoning_span),
|
| 35 |
+
)
|
| 36 |
+
for item in trace.topk_inputs(10):
|
| 37 |
+
print(f"{item.index}\t{item.score:.6f}\t{item.token!r}")
|
| 38 |
+
trace.to_html(args.html)
|
| 39 |
+
print(f"wrote {args.html}")
|
| 40 |
+
return 0
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
raise SystemExit(main())
|
exp/case_study/README.md
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FT 多跳案例分析 & IFR 标准可视化(exp/case_study)
|
| 2 |
+
|
| 3 |
+
此目录提供一个轻量的单样本 IFR 可视化流程,不改动核心评测代码。
|
| 4 |
+
|
| 5 |
+
## 功能
|
| 6 |
+
- 读取单个样本(默认 `exp/exp2/data/morehopqa.jsonl`,索引 0)。
|
| 7 |
+
- 支持多种模式:
|
| 8 |
+
- `ft`:当前使用的多跳 FT 归因(内部调用 `LLMIFRAttribution.calculate_ifr_multi_hop`)。
|
| 9 |
+
- `ifr`:标准 IFR(单 hop),默认对指定 sink span 做**聚合 IFR**(只显示 1 个面板)。
|
| 10 |
+
- `ifr_all_positions_output_only`:只对 `sink_span` 范围内的 output tokens 计算 IFR token-level 矩阵,并基于该矩阵得到 Row / Recursive(CAGE)两张面板。
|
| 11 |
+
- `attnlrp`:AttnLRP hop0(复用 FT-AttnLRP 的 span-aggregate 逻辑,等价于 `LLMLRPAttribution.calculate_attnlrp_multi_hop(n_hops=0)`,并可视化 `raw_attributions[0].token_importance_total`)。
|
| 12 |
+
- `ft_attnlrp`:FT-AttnLRP(严格复用 `LLMLRPAttribution.calculate_attnlrp_aggregated_multi_hop`,与 `exp/exp2/` 保持一致;直接可视化每 hop 的 `token_importance_total`)。
|
| 13 |
+
- 可视化两个视图:
|
| 14 |
+
- **裁剪前 token 级(full)**:带 chat template 的完整序列热力图(template + user prompt + generation)。
|
| 15 |
+
- **Prompt-only token 级**:只显示 user prompt tokens 的热力图(不包含 generation tokens)。
|
| 16 |
+
- 热力图按 `|score|` 上色(不区分正负);每个面板的 full/prompt 两张图各自用 p99.5(`|score|`) 独立归一化颜色深度。
|
| 17 |
+
- 输出 JSON(完整数值)和 HTML(逐跳热力图)。
|
| 18 |
+
- 额外提供 MAS(faithfulness / token perturbation)可视化:对指定归因方法做 token 级扰动评估,并渲染扰动影响热力图 + MAS 分数。
|
| 19 |
+
|
| 20 |
+
## 快速开始
|
| 21 |
+
```bash
|
| 22 |
+
# 根据本地模型修改 model/model_path
|
| 23 |
+
# 多跳 FT(默认) ft_split_hop,ft_improve
|
| 24 |
+
python exp/case_study/run_ifr_case.py \
|
| 25 |
+
--mode ft_split_hop \
|
| 26 |
+
--dataset exp/exp2/data/morehopqa.jsonl \
|
| 27 |
+
--index 0 \
|
| 28 |
+
--model qwen-8B \
|
| 29 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 30 |
+
--cuda 0 \
|
| 31 |
+
--n_hops 3
|
| 32 |
+
|
| 33 |
+
# 标准 IFR(单 hop,可指定 sink span)
|
| 34 |
+
python exp/case_study/run_ifr_case.py \
|
| 35 |
+
--mode ifr \
|
| 36 |
+
--dataset exp/exp2/data/morehopqa.jsonl \
|
| 37 |
+
--index 0 \
|
| 38 |
+
--model qwen-8B \
|
| 39 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 40 |
+
--cuda 0 \
|
| 41 |
+
--sink_span 0 0
|
| 42 |
+
|
| 43 |
+
# IFR output-only:只在 output 范围计算 IFR 矩阵,并生成 Row/Recursive(CAGE)两面板
|
| 44 |
+
python exp/case_study/run_ifr_case.py \
|
| 45 |
+
--mode ifr_all_positions_output_only \
|
| 46 |
+
--dataset exp/exp2/data/short-morehopqa.jsonl \
|
| 47 |
+
--index 0 \
|
| 48 |
+
--model qwen-8B \
|
| 49 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 50 |
+
--cuda 0
|
| 51 |
+
|
| 52 |
+
# AttnLRP hop0(复用 FT-AttnLRP span-aggregate;可视化 hop0 raw 向量)
|
| 53 |
+
python exp/case_study/run_ifr_case.py \
|
| 54 |
+
--mode attnlrp \
|
| 55 |
+
--dataset exp/exp2/data/morehopqa.jsonl \
|
| 56 |
+
--index 0 \
|
| 57 |
+
--model qwen-8B \
|
| 58 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 59 |
+
--cuda 0 \
|
| 60 |
+
--sink_span 0 20
|
| 61 |
+
|
| 62 |
+
# FT-attnLRP(多跳递归 AttnLRP)
|
| 63 |
+
python exp/case_study/run_ifr_case.py \
|
| 64 |
+
--mode ft_attnlrp \
|
| 65 |
+
--dataset exp/exp2/data/morehopqa.jsonl \
|
| 66 |
+
--index 0 \
|
| 67 |
+
--model qwen-8B \
|
| 68 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 69 |
+
--cuda 0,2,3,4,5,7 \
|
| 70 |
+
--n_hops 3 \
|
| 71 |
+
--attnlrp_neg_handling abs \
|
| 72 |
+
--attnlrp_norm_mode norm
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
产物位于 `exp/case_study/out/`,文件名前缀根据模式变化,例如:
|
| 76 |
+
- `ft_case_<dataset>_idx<idx>.json/html`
|
| 77 |
+
- `ifr_case_<dataset>_idx<idx>.json/html`
|
| 78 |
+
- `ifr_output_only_case_<dataset>_idx<idx>.json/html`
|
| 79 |
+
- `attnlrp_case_<dataset>_idx<idx>.json/html`
|
| 80 |
+
- `ft_attnlrp_case_<dataset>_idx<idx>.json/html`
|
| 81 |
+
|
| 82 |
+
## MAS(Faithfulness / Token Perturbation)可视化
|
| 83 |
+
|
| 84 |
+
> 说明:这里的 MAS 与项目 `llm_attr_eval.LLMAttributionEvaluator.faithfulness_test()` 保持一致:
|
| 85 |
+
> 1) 先对样本跑指定方法的归因,并取 token-level attribution(Seq / Row / Recursive)。
|
| 86 |
+
> 2) 按 prompt token 的重要性排序,逐步将 token id 替换为 `tokenizer.pad_token_id`(token 级扰动)。
|
| 87 |
+
> 3) 用 `sum log p(generation + EOS | prompt)` 得到分数曲线,计算 RISE / MAS / RISE+AP。
|
| 88 |
+
> 4) 可视化时用“每一步扰动带来的边际 logprob 变化”作为 token 分数,渲染为 token spans 的“扰动影响热力图”。
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
# FT-IFR(ifr_multi_hop;默认 --method ft)
|
| 92 |
+
python exp/case_study/run_mas_case.py \
|
| 93 |
+
--dataset exp/exp2/data/short-morehopqa.jsonl \
|
| 94 |
+
--index 0 \
|
| 95 |
+
--model qwen-8B \
|
| 96 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 97 |
+
--cuda 0 \
|
| 98 |
+
--method ft \
|
| 99 |
+
--n_hops 3
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
常用方法选择(与 `run_ifr_case.py` 的模式名对齐):
|
| 103 |
+
```bash
|
| 104 |
+
# IFR(需要 sink_span;默认会优先使用数据集缓存字段)
|
| 105 |
+
python exp/case_study/run_mas_case.py --method ifr --sink_span 0 20 ...
|
| 106 |
+
|
| 107 |
+
# IFR output-only(仅对 sink_span 内的 output token 计算 IFR token-level matrix)
|
| 108 |
+
python exp/case_study/run_mas_case.py --method ifr_all_positions_output_only --sink_span 0 20 ...
|
| 109 |
+
|
| 110 |
+
# FT-IFR(ifr_multi_hop)
|
| 111 |
+
python exp/case_study/run_mas_case.py --method ft --n_hops 1 --sink_span 0 20 --thinking_span 0 20 ...
|
| 112 |
+
|
| 113 |
+
# AttnLRP hop0(复用 FT-AttnLRP hop0;仍然需要 indices_to_explain/sink_span 来取 Seq/Row/Rec)
|
| 114 |
+
python exp/case_study/run_mas_case.py --method attnlrp --sink_span 0 20 ...
|
| 115 |
+
|
| 116 |
+
# FT-AttnLRP(attnlrp_aggregated_multi_hop)
|
| 117 |
+
python exp/case_study/run_mas_case.py --method ft_attnlrp --n_hops 1 --sink_span 0 20 --thinking_span 0 20 ...
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
产物位于 `exp/case_study/out/`,文件名前缀为:
|
| 121 |
+
- `mas_case_<method>_<dataset>_idx<idx>.json/html`
|
| 122 |
+
|
| 123 |
+
HTML 默认包含 3 个 attribution 视角面板(Seq / Row / Recursive),每个面板里有 2 行 token 级热力图:
|
| 124 |
+
- **Method attribution(token weights)**:该方法的 token 归因权重(用于排序/密度)。
|
| 125 |
+
- **Attribution-guided MAS marginal(path deltas)**:按归因排序逐步替换的边际影响(这就是评测中实际使用的扰动路径)。
|
| 126 |
+
|
| 127 |
+
## 在浏览器中查看 HTML
|
| 128 |
+
1) 先运行上面的命令生成 `.html`(终端会打印形如 `wrote exp/case_study/out/...html`)。
|
| 129 |
+
|
| 130 |
+
2) 在仓库根目录启动一个静态文件服务(任选一个端口,例如 8888):
|
| 131 |
+
```bash
|
| 132 |
+
python -m http.server 8888 --directory exp/case_study/out
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
3) 用浏览器打开(注意是 `http://`,不是 `https://`):
|
| 136 |
+
- 本机:`http://127.0.0.1:8888/<你的html文件名>`
|
| 137 |
+
- 远程机器(推荐端口转发):在本地执行 `ssh -L 8888:127.0.0.1:8888 <user>@<server>`,然后在本地浏览器打开 `http://127.0.0.1:8888/<你的html文件名>`
|
| 138 |
+
|
| 139 |
+
如果你在 `http.server` 日志里看到大量 `400 Bad request version` 且伴随乱码,通常是有客户端用 HTTPS 去连了 HTTP 端口;请确认浏览器地址栏是 `http://...`。
|
| 140 |
+
|
| 141 |
+
## 可选参数
|
| 142 |
+
- `--sink_span a b` / `--thinking_span a b`:覆盖生成侧的 sink/thinking 句子 span(默认使用缓存字段)。
|
| 143 |
+
- `--attnlrp_neg_handling drop|abs`:FT-AttnLRP 每跳负值处理(drop=clamp>=0,abs=取绝对值)。
|
| 144 |
+
- `--attnlrp_norm_mode norm|no_norm`:FT-AttnLRP 正则化与 hop ratio 开关(norm=全局+thinking 归一化并启用 ratio;no_norm=三者都禁用)。
|
| 145 |
+
- `--chunk_tokens` / `--sink_chunk_tokens`:IFR 分块参数。
|
| 146 |
+
- `--output_dir`:修改输出目录。
|
| 147 |
+
|
| 148 |
+
## 文件说明
|
| 149 |
+
- `run_ifr_case.py`:命令行入口与落盘(支持 `ft`/`ifr`/`ifr_all_positions_output_only`/`attnlrp`/`ft_attnlrp` 模式)。
|
| 150 |
+
- `run_mas_case.py`:MAS(faithfulness / token perturbation)可视化入口与落盘(支持 `ifr`/`ifr_all_positions_output_only`/`ft`/`attnlrp`/`ft_attnlrp`)。
|
| 151 |
+
- `analysis.py`:逐跳清洗与封装(token-level)。
|
| 152 |
+
- `viz.py`:HTML 渲染与热力图。
|
exp/case_study/analysis.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Helpers for IFR case studies (hop-wise aggregation + sanitization).
|
| 2 |
+
|
| 3 |
+
All utilities stay local to exp/case_study to avoid touching core eval code.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import Any, Dict, Iterable, List, Optional, Sequence
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def vector_stats(vec: torch.Tensor) -> Dict[str, float]:
|
| 14 |
+
if vec.numel() == 0:
|
| 15 |
+
return {"min": 0.0, "max": 0.0, "abs_max": 0.0, "mean": 0.0, "sum": 0.0}
|
| 16 |
+
v = vec.detach().to(dtype=torch.float32)
|
| 17 |
+
return {
|
| 18 |
+
"min": float(v.min().item()),
|
| 19 |
+
"max": float(v.max().item()),
|
| 20 |
+
"abs_max": float(v.abs().max().item()),
|
| 21 |
+
"mean": float(v.mean().item()),
|
| 22 |
+
"sum": float(v.sum().item()),
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def tensor_to_list(x: Any) -> Any:
|
| 27 |
+
if torch.is_tensor(x):
|
| 28 |
+
return x.detach().cpu().tolist()
|
| 29 |
+
if isinstance(x, list):
|
| 30 |
+
return [tensor_to_list(v) for v in x]
|
| 31 |
+
if isinstance(x, dict):
|
| 32 |
+
return {k: tensor_to_list(v) for k, v in x.items()}
|
| 33 |
+
return x
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def sanitize_ifr_meta(meta: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
| 37 |
+
"""Drop bulky raw objects and convert tensors to Python lists for JSON."""
|
| 38 |
+
|
| 39 |
+
if meta is None:
|
| 40 |
+
return None
|
| 41 |
+
|
| 42 |
+
cleaned: Dict[str, Any] = {}
|
| 43 |
+
for key, value in meta.items():
|
| 44 |
+
if key == "raw":
|
| 45 |
+
continue
|
| 46 |
+
cleaned[key] = tensor_to_list(value)
|
| 47 |
+
return cleaned
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def package_token_hops(
|
| 51 |
+
hop_vectors: Iterable[Sequence[float]],
|
| 52 |
+
) -> List[Dict[str, Any]]:
|
| 53 |
+
"""Package per-hop token vectors without sentence aggregation.
|
| 54 |
+
|
| 55 |
+
hop_vectors are assumed to already match the experiment's configured
|
| 56 |
+
postprocessing (e.g., FT-AttnLRP neg_handling/norm_mode).
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
packaged: List[Dict[str, Any]] = []
|
| 60 |
+
for hop_idx, vec in enumerate(hop_vectors):
|
| 61 |
+
vec_tensor = torch.nan_to_num(torch.as_tensor(vec, dtype=torch.float32), nan=0.0)
|
| 62 |
+
token_scores = vec_tensor.tolist()
|
| 63 |
+
token_max = float(vec_tensor.abs().max().item()) if vec_tensor.numel() > 0 else 0.0
|
| 64 |
+
total = float(vec_tensor.sum().item())
|
| 65 |
+
packaged.append(
|
| 66 |
+
{
|
| 67 |
+
"hop": hop_idx,
|
| 68 |
+
"token_scores": token_scores,
|
| 69 |
+
"token_score_max": token_max,
|
| 70 |
+
"token_stats": vector_stats(vec_tensor),
|
| 71 |
+
"total_mass": total,
|
| 72 |
+
}
|
| 73 |
+
)
|
| 74 |
+
return packaged
|
exp/case_study/faithfulness_trace.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Faithfulness (MAS/RISE) trace utilities for exp/case_study.
|
| 2 |
+
|
| 3 |
+
This module is intentionally aligned with `llm_attr_eval.LLMAttributionEvaluator.faithfulness_test`,
|
| 4 |
+
but additionally returns the full trace arrays needed for visualization and supports providing
|
| 5 |
+
`user_prompt_indices` to avoid fragile subsequence matching.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Any, Dict, Optional, Sequence, List
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
import llm_attr_eval
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _auc(arr: np.ndarray) -> float:
|
| 19 |
+
return float((arr.sum() - arr[0] / 2 - arr[-1] / 2) / max(1, (arr.shape[0] - 1)))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@torch.inference_mode()
|
| 23 |
+
def mas_trace(
|
| 24 |
+
llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
|
| 25 |
+
*,
|
| 26 |
+
attribution: torch.Tensor,
|
| 27 |
+
prompt: str,
|
| 28 |
+
generation: str,
|
| 29 |
+
user_prompt_indices: Optional[Sequence[int]] = None,
|
| 30 |
+
k: int = 20,
|
| 31 |
+
) -> Dict[str, Any]:
|
| 32 |
+
"""Return a token-level faithfulness trace (RISE/MAS/RISE+AP) plus per-token deltas.
|
| 33 |
+
|
| 34 |
+
attribution: [R, P] token attribution on prompt-side tokens only.
|
| 35 |
+
prompt: raw prompt string.
|
| 36 |
+
generation: target generation string; scored as generation + eos (if defined).
|
| 37 |
+
user_prompt_indices: optional absolute positions of each prompt token inside formatted prompt ids.
|
| 38 |
+
k: number of perturbation steps; each step perturbs ~1/k of prompt tokens.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
if attribution.ndim != 2:
|
| 42 |
+
raise ValueError("Expected 2D prompt-side attribution matrix [R, P].")
|
| 43 |
+
|
| 44 |
+
pad_token_id = llm_evaluator._ensure_pad_token_id()
|
| 45 |
+
|
| 46 |
+
user_prompt = " " + prompt
|
| 47 |
+
formatted_prompt = llm_evaluator.format_prompt(user_prompt)
|
| 48 |
+
formatted_ids = llm_evaluator.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).input_ids
|
| 49 |
+
|
| 50 |
+
prompt_ids = formatted_ids.to(llm_evaluator.device)
|
| 51 |
+
prompt_ids_perturbed = prompt_ids.clone()
|
| 52 |
+
|
| 53 |
+
eos = llm_evaluator.tokenizer.eos_token or ""
|
| 54 |
+
generation_ids = llm_evaluator.tokenizer(
|
| 55 |
+
generation + eos,
|
| 56 |
+
return_tensors="pt",
|
| 57 |
+
add_special_tokens=False,
|
| 58 |
+
).input_ids.to(llm_evaluator.device)
|
| 59 |
+
|
| 60 |
+
attr_cpu = attribution.detach().cpu()
|
| 61 |
+
w = attr_cpu.sum(0)
|
| 62 |
+
sorted_attr_indices = torch.argsort(w, descending=True)
|
| 63 |
+
attr_sum = float(w.sum().item())
|
| 64 |
+
|
| 65 |
+
P = int(w.numel())
|
| 66 |
+
|
| 67 |
+
prompt_positions: List[int]
|
| 68 |
+
if user_prompt_indices is not None:
|
| 69 |
+
prompt_positions = [int(x) for x in user_prompt_indices]
|
| 70 |
+
if len(prompt_positions) != P:
|
| 71 |
+
raise ValueError(
|
| 72 |
+
"user_prompt_indices length does not match prompt-side attribution length: "
|
| 73 |
+
f"indices P={len(prompt_positions)}, attr P={P}."
|
| 74 |
+
)
|
| 75 |
+
if P and max(prompt_positions) >= int(prompt_ids_perturbed.shape[1]):
|
| 76 |
+
raise ValueError("user_prompt_indices contains an out-of-bounds index for formatted prompt ids.")
|
| 77 |
+
else:
|
| 78 |
+
user_ids = llm_evaluator.tokenizer(user_prompt, return_tensors="pt", add_special_tokens=False).input_ids
|
| 79 |
+
user_start = llm_evaluator._find_subsequence_start(formatted_ids[0], user_ids[0])
|
| 80 |
+
if user_start is None:
|
| 81 |
+
raise RuntimeError("Failed to locate user prompt token span inside formatted chat prompt.")
|
| 82 |
+
if int(user_ids.shape[1]) != P:
|
| 83 |
+
raise ValueError(
|
| 84 |
+
"Prompt-side attribution length does not match tokenized user prompt length: "
|
| 85 |
+
f"attr P={P}, user_prompt P={int(user_ids.shape[1])}."
|
| 86 |
+
)
|
| 87 |
+
prompt_positions = [int(user_start) + j for j in range(P)]
|
| 88 |
+
|
| 89 |
+
if P > 0:
|
| 90 |
+
steps = int(k) if k is not None else 0
|
| 91 |
+
if steps <= 0:
|
| 92 |
+
steps = 1
|
| 93 |
+
steps = min(steps, P)
|
| 94 |
+
else:
|
| 95 |
+
steps = 0
|
| 96 |
+
|
| 97 |
+
scores = np.zeros(steps + 1, dtype=np.float64)
|
| 98 |
+
density = np.zeros(steps + 1, dtype=np.float64)
|
| 99 |
+
|
| 100 |
+
scores[0] = (
|
| 101 |
+
llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
|
| 102 |
+
)
|
| 103 |
+
density[0] = 1.0
|
| 104 |
+
|
| 105 |
+
if P == 0:
|
| 106 |
+
return {
|
| 107 |
+
"num_tokens": 0,
|
| 108 |
+
"sorted_attr_indices": [],
|
| 109 |
+
"scores_raw": scores.tolist(),
|
| 110 |
+
"density": density.tolist(),
|
| 111 |
+
"normalized_model_response": [1.0],
|
| 112 |
+
"alignment_penalty": [0.0],
|
| 113 |
+
"corrected_scores": [1.0],
|
| 114 |
+
"token_deltas_raw": [],
|
| 115 |
+
"attr_weights": [],
|
| 116 |
+
"metrics": {"RISE": 0.0, "MAS": 0.0, "RISE+AP": 0.0},
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
if attr_sum <= 0:
|
| 120 |
+
density = np.linspace(1.0, 0.0, steps + 1)
|
| 121 |
+
|
| 122 |
+
per_token_delta = np.zeros(P, dtype=np.float64)
|
| 123 |
+
|
| 124 |
+
base = P // steps
|
| 125 |
+
remainder = P % steps
|
| 126 |
+
start = 0
|
| 127 |
+
for step in range(steps):
|
| 128 |
+
size = base + (1 if step < remainder else 0)
|
| 129 |
+
group = sorted_attr_indices[start : start + size]
|
| 130 |
+
start += size
|
| 131 |
+
|
| 132 |
+
for idx_t in group:
|
| 133 |
+
idx = int(idx_t.item())
|
| 134 |
+
abs_pos = int(prompt_positions[idx])
|
| 135 |
+
prompt_ids_perturbed[0, abs_pos] = pad_token_id
|
| 136 |
+
scores[step + 1] = (
|
| 137 |
+
llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
|
| 138 |
+
)
|
| 139 |
+
if attr_sum > 0:
|
| 140 |
+
dec = float(w.index_select(0, group).sum().item()) / attr_sum
|
| 141 |
+
density[step + 1] = density[step] - dec
|
| 142 |
+
|
| 143 |
+
delta = scores[step] - scores[step + 1]
|
| 144 |
+
for idx_t in group:
|
| 145 |
+
idx = int(idx_t.item())
|
| 146 |
+
per_token_delta[idx] = delta
|
| 147 |
+
|
| 148 |
+
min_normalized_pred = 1.0
|
| 149 |
+
normalized_model_response = scores.copy()
|
| 150 |
+
for i in range(len(scores)):
|
| 151 |
+
normalized_pred = (normalized_model_response[i] - scores[-1]) / (abs(scores[0] - scores[-1]))
|
| 152 |
+
normalized_pred = np.clip(normalized_pred, 0.0, 1.0)
|
| 153 |
+
min_normalized_pred = min(min_normalized_pred, normalized_pred)
|
| 154 |
+
normalized_model_response[i] = min_normalized_pred
|
| 155 |
+
|
| 156 |
+
alignment_penalty = np.abs(normalized_model_response - density)
|
| 157 |
+
corrected_scores = normalized_model_response + alignment_penalty
|
| 158 |
+
corrected_scores = corrected_scores.clip(0.0, 1.0)
|
| 159 |
+
corrected_scores = (corrected_scores - np.min(corrected_scores)) / (np.max(corrected_scores) - np.min(corrected_scores))
|
| 160 |
+
if np.isnan(corrected_scores).any():
|
| 161 |
+
corrected_scores = np.linspace(1.0, 0.0, len(scores))
|
| 162 |
+
|
| 163 |
+
rise = _auc(normalized_model_response)
|
| 164 |
+
mas = _auc(corrected_scores)
|
| 165 |
+
rise_ap = _auc(normalized_model_response + alignment_penalty)
|
| 166 |
+
|
| 167 |
+
if attr_sum > 0:
|
| 168 |
+
attr_weights = (w.numpy() / (attr_sum + 1e-12)).astype(np.float64)
|
| 169 |
+
else:
|
| 170 |
+
attr_weights = np.zeros(P, dtype=np.float64)
|
| 171 |
+
|
| 172 |
+
return {
|
| 173 |
+
"num_tokens": P,
|
| 174 |
+
"sorted_attr_indices": [int(i.item()) for i in sorted_attr_indices],
|
| 175 |
+
"scores_raw": scores.tolist(),
|
| 176 |
+
"density": density.tolist(),
|
| 177 |
+
"normalized_model_response": normalized_model_response.tolist(),
|
| 178 |
+
"alignment_penalty": alignment_penalty.tolist(),
|
| 179 |
+
"corrected_scores": corrected_scores.tolist(),
|
| 180 |
+
"token_deltas_raw": per_token_delta.tolist(),
|
| 181 |
+
"attr_weights": attr_weights.tolist(),
|
| 182 |
+
"metrics": {"RISE": rise, "MAS": mas, "RISE+AP": rise_ap},
|
| 183 |
+
}
|
exp/case_study/run_ifr_case.py
ADDED
|
@@ -0,0 +1,1225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Case study runner for FlashTrace and attribution baselines.
|
| 3 |
+
|
| 4 |
+
Modes supported (all emit JSON + HTML under ``exp/case_study/out``):
|
| 5 |
+
|
| 6 |
+
- ``ft``: FlashTrace (current project implementation; multi-hop IFR)
|
| 7 |
+
- ``ifr_in_all_gen``: Experimental multi-hop IFR variant (hops over CoT+output; scheme B, aligns with exp/exp2)
|
| 8 |
+
- ``ifr``: IFR span-aggregate visualization (single hop; one panel)
|
| 9 |
+
- ``ifr_all_positions``: IFR full matrix + CAGE (Row/Recursive panels)
|
| 10 |
+
- ``ifr_all_positions_output_only``: IFR output-only token matrix + CAGE (Row/Recursive panels)
|
| 11 |
+
- ``attnlrp``: AttnLRP hop0 (reuse FT-AttnLRP span-aggregate; visualize raw hop0 vector)
|
| 12 |
+
- ``ft_attnlrp``: FT-AttnLRP (multi-hop aggregated AttnLRP; matches exp/exp2)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import types
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 24 |
+
|
| 25 |
+
# Avoid torchvision dependency when importing transformers (Longformer).
|
| 26 |
+
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
|
| 27 |
+
os.environ.setdefault("DISABLE_TRANSFORMERS_IMAGE_TRANSFORMS", "1")
|
| 28 |
+
|
| 29 |
+
def _early_set_cuda_visible_devices() -> None:
|
| 30 |
+
"""Set CUDA_VISIBLE_DEVICES before importing torch/transformers.
|
| 31 |
+
|
| 32 |
+
Note: CUDA device indices are re-mapped inside the process after applying the mask.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
parser = argparse.ArgumentParser(add_help=False)
|
| 36 |
+
parser.add_argument("--cuda", type=str, default=None)
|
| 37 |
+
args, _ = parser.parse_known_args(sys.argv[1:])
|
| 38 |
+
cuda = args.cuda.strip() if isinstance(args.cuda, str) else ""
|
| 39 |
+
if cuda and "," in cuda:
|
| 40 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = cuda
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
_early_set_cuda_visible_devices()
|
| 45 |
+
|
| 46 |
+
import torch
|
| 47 |
+
|
| 48 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 49 |
+
if str(REPO_ROOT) not in sys.path:
|
| 50 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _stub_torchvision() -> None:
|
| 54 |
+
"""Provide minimal torchvision stubs so Longformer imports succeed without the real package."""
|
| 55 |
+
|
| 56 |
+
if "torchvision" in sys.modules:
|
| 57 |
+
return
|
| 58 |
+
|
| 59 |
+
from importlib.machinery import ModuleSpec
|
| 60 |
+
|
| 61 |
+
def _mk(name: str) -> types.ModuleType:
|
| 62 |
+
mod = types.ModuleType(name)
|
| 63 |
+
mod.__spec__ = ModuleSpec(name, loader=None)
|
| 64 |
+
return mod
|
| 65 |
+
|
| 66 |
+
tv = _mk("torchvision")
|
| 67 |
+
tv.__dict__["__path__"] = []
|
| 68 |
+
submods = ["transforms", "_meta_registrations", "datasets", "io", "models", "ops", "utils"]
|
| 69 |
+
for name in submods:
|
| 70 |
+
mod = _mk(f"torchvision.{name}")
|
| 71 |
+
sys.modules[f"torchvision.{name}"] = mod
|
| 72 |
+
setattr(tv, name, mod)
|
| 73 |
+
|
| 74 |
+
class _InterpolationMode:
|
| 75 |
+
NEAREST = 0
|
| 76 |
+
NEAREST_EXACT = 0
|
| 77 |
+
BILINEAR = 1
|
| 78 |
+
BICUBIC = 2
|
| 79 |
+
LANCZOS = 3
|
| 80 |
+
BOX = 4
|
| 81 |
+
HAMMING = 5
|
| 82 |
+
|
| 83 |
+
sys.modules["torchvision.transforms"].InterpolationMode = _InterpolationMode
|
| 84 |
+
sys.modules["torchvision.transforms"].__all__ = ["InterpolationMode"]
|
| 85 |
+
|
| 86 |
+
# ops + misc stub for timm/transformers imports
|
| 87 |
+
ops_mod = sys.modules.get("torchvision.ops") or _mk("torchvision.ops")
|
| 88 |
+
sys.modules["torchvision.ops"] = ops_mod
|
| 89 |
+
setattr(tv, "ops", ops_mod)
|
| 90 |
+
misc_mod = _mk("torchvision.ops.misc")
|
| 91 |
+
sys.modules["torchvision.ops.misc"] = misc_mod
|
| 92 |
+
setattr(ops_mod, "misc", misc_mod)
|
| 93 |
+
|
| 94 |
+
class _FrozenBatchNorm2d:
|
| 95 |
+
def __init__(self, *args, **kwargs):
|
| 96 |
+
pass
|
| 97 |
+
|
| 98 |
+
misc_mod.FrozenBatchNorm2d = _FrozenBatchNorm2d
|
| 99 |
+
|
| 100 |
+
sys.modules["torchvision"] = tv
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
_stub_torchvision()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _stub_timm() -> None:
|
| 107 |
+
"""Provide minimal timm stubs to avoid optional vision deps."""
|
| 108 |
+
|
| 109 |
+
if "timm" in sys.modules:
|
| 110 |
+
return
|
| 111 |
+
|
| 112 |
+
from importlib.machinery import ModuleSpec
|
| 113 |
+
|
| 114 |
+
def _mk(name: str) -> types.ModuleType:
|
| 115 |
+
mod = types.ModuleType(name)
|
| 116 |
+
mod.__spec__ = ModuleSpec(name, loader=None)
|
| 117 |
+
return mod
|
| 118 |
+
|
| 119 |
+
timm = _mk("timm")
|
| 120 |
+
timm.__dict__["__path__"] = []
|
| 121 |
+
sys.modules["timm"] = timm
|
| 122 |
+
|
| 123 |
+
data_mod = _mk("timm.data")
|
| 124 |
+
sys.modules["timm.data"] = data_mod
|
| 125 |
+
timm.data = data_mod
|
| 126 |
+
|
| 127 |
+
class _ImageNetInfo:
|
| 128 |
+
pass
|
| 129 |
+
|
| 130 |
+
def _infer_imagenet_subset(*args, **kwargs):
|
| 131 |
+
return None
|
| 132 |
+
|
| 133 |
+
data_mod.ImageNetInfo = _ImageNetInfo
|
| 134 |
+
data_mod.infer_imagenet_subset = _infer_imagenet_subset
|
| 135 |
+
|
| 136 |
+
layers_mod = _mk("timm.layers")
|
| 137 |
+
sys.modules["timm.layers"] = layers_mod
|
| 138 |
+
timm.layers = layers_mod
|
| 139 |
+
|
| 140 |
+
create_norm_mod = _mk("timm.layers.create_norm")
|
| 141 |
+
sys.modules["timm.layers.create_norm"] = create_norm_mod
|
| 142 |
+
layers_mod.create_norm = create_norm_mod
|
| 143 |
+
|
| 144 |
+
def _get_norm_layer(*args, **kwargs):
|
| 145 |
+
return None
|
| 146 |
+
|
| 147 |
+
create_norm_mod.get_norm_layer = _get_norm_layer
|
| 148 |
+
|
| 149 |
+
classifier_mod = _mk("timm.layers.classifier")
|
| 150 |
+
sys.modules["timm.layers.classifier"] = classifier_mod
|
| 151 |
+
layers_mod.classifier = classifier_mod
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
_stub_timm()
|
| 155 |
+
|
| 156 |
+
import transformers
|
| 157 |
+
|
| 158 |
+
# Provide light stubs if Longformer classes are unavailable; IFR case study does not use them.
|
| 159 |
+
if not hasattr(transformers, "LongformerTokenizer"):
|
| 160 |
+
class _DummyLongformerTokenizer:
|
| 161 |
+
def __init__(self, *args, **kwargs):
|
| 162 |
+
raise ImportError("LongformerTokenizer stubbed; install full transformers+torchvision if needed.")
|
| 163 |
+
transformers.LongformerTokenizer = _DummyLongformerTokenizer
|
| 164 |
+
|
| 165 |
+
if not hasattr(transformers, "LongformerForMaskedLM"):
|
| 166 |
+
class _DummyLongformerForMaskedLM:
|
| 167 |
+
def __init__(self, *args, **kwargs):
|
| 168 |
+
raise ImportError("LongformerForMaskedLM stubbed; install full transformers+torchvision if needed.")
|
| 169 |
+
transformers.LongformerForMaskedLM = _DummyLongformerForMaskedLM
|
| 170 |
+
|
| 171 |
+
if hasattr(transformers, "__all__"):
|
| 172 |
+
for _name in ["LongformerTokenizer", "LongformerForMaskedLM"]:
|
| 173 |
+
if _name not in transformers.__all__:
|
| 174 |
+
transformers.__all__.append(_name)
|
| 175 |
+
|
| 176 |
+
# Gemma3n stubs (transformers may attempt to import even if unused)
|
| 177 |
+
if "transformers.models.gemma3n.configuration_gemma3n" not in sys.modules:
|
| 178 |
+
from importlib.machinery import ModuleSpec
|
| 179 |
+
|
| 180 |
+
gemma_pkg = types.ModuleType("transformers.models.gemma3n")
|
| 181 |
+
gemma_pkg.__spec__ = ModuleSpec("transformers.models.gemma3n", loader=None, is_package=True)
|
| 182 |
+
sys.modules["transformers.models.gemma3n"] = gemma_pkg
|
| 183 |
+
|
| 184 |
+
gemma_conf = types.ModuleType("transformers.models.gemma3n.configuration_gemma3n")
|
| 185 |
+
gemma_conf.__spec__ = ModuleSpec("transformers.models.gemma3n.configuration_gemma3n", loader=None)
|
| 186 |
+
|
| 187 |
+
class Gemma3nConfig:
|
| 188 |
+
def __init__(self, *args, **kwargs):
|
| 189 |
+
self.model_type = "gemma3n"
|
| 190 |
+
|
| 191 |
+
class Gemma3nTextConfig(Gemma3nConfig):
|
| 192 |
+
pass
|
| 193 |
+
|
| 194 |
+
gemma_conf.Gemma3nConfig = Gemma3nConfig
|
| 195 |
+
gemma_conf.Gemma3nTextConfig = Gemma3nTextConfig
|
| 196 |
+
gemma_conf.__all__ = ["Gemma3nConfig", "Gemma3nTextConfig"]
|
| 197 |
+
sys.modules["transformers.models.gemma3n.configuration_gemma3n"] = gemma_conf
|
| 198 |
+
setattr(gemma_pkg, "configuration_gemma3n", gemma_conf)
|
| 199 |
+
|
| 200 |
+
if hasattr(transformers, "__all__"):
|
| 201 |
+
for _nm in ["Gemma3nConfig", "Gemma3nTextConfig"]:
|
| 202 |
+
if _nm not in transformers.__all__:
|
| 203 |
+
transformers.__all__.append(_nm)
|
| 204 |
+
|
| 205 |
+
import llm_attr
|
| 206 |
+
from exp.exp2 import dataset_utils as ds_utils
|
| 207 |
+
from evaluations.attribution_recovery import load_model
|
| 208 |
+
|
| 209 |
+
from exp.case_study import analysis, viz
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def resolve_device(cuda: Optional[str], cuda_num: int) -> str:
|
| 213 |
+
if cuda and isinstance(cuda, str) and "," in cuda:
|
| 214 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = cuda
|
| 215 |
+
return "auto"
|
| 216 |
+
if cuda and isinstance(cuda, str) and cuda.strip():
|
| 217 |
+
try:
|
| 218 |
+
idx = int(cuda)
|
| 219 |
+
except Exception:
|
| 220 |
+
idx = 0
|
| 221 |
+
return f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
|
| 222 |
+
return f"cuda:{cuda_num}" if torch.cuda.is_available() else "cpu"
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def load_example(dataset: str, index: int, data_root: Path) -> Tuple[ds_utils.CachedExample, str]:
|
| 226 |
+
"""Load a single example from a cache path or dataset name."""
|
| 227 |
+
|
| 228 |
+
ds_path = Path(dataset)
|
| 229 |
+
if ds_path.exists():
|
| 230 |
+
examples = ds_utils.read_cached_jsonl(ds_path)
|
| 231 |
+
dataset_name = ds_path.name
|
| 232 |
+
else:
|
| 233 |
+
loader = ds_utils.DatasetLoader(data_root=data_root)
|
| 234 |
+
examples = loader.load(dataset)
|
| 235 |
+
dataset_name = dataset
|
| 236 |
+
|
| 237 |
+
if not examples:
|
| 238 |
+
raise ValueError(f"No examples found for dataset={dataset}")
|
| 239 |
+
|
| 240 |
+
if index < 0:
|
| 241 |
+
index = len(examples) + index
|
| 242 |
+
if not (0 <= index < len(examples)):
|
| 243 |
+
raise IndexError(f"index {index} out of range for dataset with {len(examples)} examples")
|
| 244 |
+
|
| 245 |
+
return examples[index], dataset_name
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def parse_args() -> argparse.Namespace:
|
| 249 |
+
parser = argparse.ArgumentParser("IFR multi-hop case study")
|
| 250 |
+
parser.add_argument("--dataset", type=str, default="exp/exp2/data/morehopqa.jsonl", help="Dataset name or JSONL path.")
|
| 251 |
+
parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Cache root for dataset names.")
|
| 252 |
+
parser.add_argument("--index", type=int, default=0, help="Sample index (supports negative for reverse).")
|
| 253 |
+
parser.add_argument(
|
| 254 |
+
"--mode",
|
| 255 |
+
type=str,
|
| 256 |
+
choices=[
|
| 257 |
+
"ft",
|
| 258 |
+
"ft_improve",
|
| 259 |
+
"ft_split_hop",
|
| 260 |
+
"ifr_in_all_gen",
|
| 261 |
+
"ifr",
|
| 262 |
+
"ifr_all_positions",
|
| 263 |
+
"ifr_all_positions_output_only",
|
| 264 |
+
"attnlrp",
|
| 265 |
+
"ft_attnlrp",
|
| 266 |
+
],
|
| 267 |
+
default="ft",
|
| 268 |
+
help=(
|
| 269 |
+
"ft = FlashTrace (multi-hop IFR); ifr = standard IFR span-aggregate; "
|
| 270 |
+
"ifr_in_all_gen = multi-hop IFR over CoT+output (scheme B; exp2-aligned); "
|
| 271 |
+
"ifr_all_positions = full IFR matrix + CAGE row/rec; "
|
| 272 |
+
"ft_improve = FlashTrace (multi-hop IFR, stop-token soft deletion); "
|
| 273 |
+
"ft_split_hop = FlashTrace (split-hop IFR over segmented thinking span); "
|
| 274 |
+
"ifr_all_positions_output_only = output-only IFR matrix + CAGE row/rec; "
|
| 275 |
+
"attnlrp = AttnLRP hop0 (FT-AttnLRP span-aggregate); "
|
| 276 |
+
"ft_attnlrp = FT-AttnLRP (multi-hop aggregated; exp2)."
|
| 277 |
+
),
|
| 278 |
+
)
|
| 279 |
+
parser.add_argument("--model", type=str, default="qwen-8B", help="HF repo id (ignored if --model_path set).")
|
| 280 |
+
parser.add_argument("--model_path", type=str, default=None, help="Local model path to override --model.")
|
| 281 |
+
parser.add_argument("--cuda", type=str, default=None, help="CUDA spec (e.g., '0' or '0,1').")
|
| 282 |
+
parser.add_argument("--cuda_num", type=int, default=0, help="Fallback GPU index when --cuda unset.")
|
| 283 |
+
parser.add_argument("--n_hops", type=int, default=1, help="Number of hops for IFR multi-hop.")
|
| 284 |
+
parser.add_argument("--sink_span", type=int, nargs=2, default=None, help="Optional sink span over generation tokens.")
|
| 285 |
+
parser.add_argument("--thinking_span", type=int, nargs=2, default=None, help="Optional thinking span over generation tokens.")
|
| 286 |
+
parser.add_argument(
|
| 287 |
+
"--attnlrp_neg_handling",
|
| 288 |
+
type=str,
|
| 289 |
+
choices=["drop", "abs"],
|
| 290 |
+
default="drop",
|
| 291 |
+
help="FT-AttnLRP: how to handle negative values after each hop (drop=clamp>=0, abs=absolute value).",
|
| 292 |
+
)
|
| 293 |
+
parser.add_argument(
|
| 294 |
+
"--attnlrp_norm_mode",
|
| 295 |
+
type=str,
|
| 296 |
+
choices=["norm", "no_norm"],
|
| 297 |
+
default="norm",
|
| 298 |
+
help="FT-AttnLRP: norm enables per-hop global+thinking normalization + ratios; no_norm disables all three.",
|
| 299 |
+
)
|
| 300 |
+
parser.add_argument("--chunk_tokens", type=int, default=128, help="IFR chunk size.")
|
| 301 |
+
parser.add_argument("--sink_chunk_tokens", type=int, default=32, help="IFR sink chunk size.")
|
| 302 |
+
parser.add_argument("--output_dir", type=str, default="exp/case_study/out", help="Where to write HTML/JSON artifacts.")
|
| 303 |
+
return parser.parse_args()
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def run_ft_multihop(
|
| 307 |
+
example: ds_utils.CachedExample,
|
| 308 |
+
model: Any,
|
| 309 |
+
tokenizer: Any,
|
| 310 |
+
*,
|
| 311 |
+
n_hops: int,
|
| 312 |
+
sink_span: Optional[Sequence[int]],
|
| 313 |
+
thinking_span: Optional[Sequence[int]],
|
| 314 |
+
chunk_tokens: int,
|
| 315 |
+
sink_chunk_tokens: int,
|
| 316 |
+
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
|
| 317 |
+
"""Execute FT (current multi-hop IFR) attribution for the selected example."""
|
| 318 |
+
|
| 319 |
+
attr = llm_attr.LLMIFRAttribution(
|
| 320 |
+
model,
|
| 321 |
+
tokenizer,
|
| 322 |
+
chunk_tokens=chunk_tokens,
|
| 323 |
+
sink_chunk_tokens=sink_chunk_tokens,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None
|
| 327 |
+
thinking = (
|
| 328 |
+
tuple(thinking_span)
|
| 329 |
+
if thinking_span is not None
|
| 330 |
+
else tuple(example.thinking_span) if example.thinking_span else None
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
result = attr.calculate_ifr_multi_hop(
|
| 334 |
+
example.prompt,
|
| 335 |
+
target=example.target,
|
| 336 |
+
sink_span=sink,
|
| 337 |
+
thinking_span=thinking,
|
| 338 |
+
n_hops=n_hops,
|
| 339 |
+
)
|
| 340 |
+
debug_info: Dict[str, Any] = {
|
| 341 |
+
"full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []),
|
| 342 |
+
"generation_tokens": list(getattr(attr, "generation_tokens", []) or []),
|
| 343 |
+
"user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []),
|
| 344 |
+
"chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []),
|
| 345 |
+
"prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None,
|
| 346 |
+
"generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None,
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
raw_vectors = []
|
| 350 |
+
if result.metadata and "ifr" in result.metadata:
|
| 351 |
+
raw_ifr = result.metadata["ifr"].get("raw")
|
| 352 |
+
if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"):
|
| 353 |
+
try:
|
| 354 |
+
raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions]
|
| 355 |
+
except Exception:
|
| 356 |
+
raw_vectors = []
|
| 357 |
+
debug_info["raw_hop_vectors"] = raw_vectors
|
| 358 |
+
|
| 359 |
+
return result, sink, thinking, debug_info
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def run_ft_multihop_improve(
|
| 363 |
+
example: ds_utils.CachedExample,
|
| 364 |
+
model: Any,
|
| 365 |
+
tokenizer: Any,
|
| 366 |
+
*,
|
| 367 |
+
n_hops: int,
|
| 368 |
+
sink_span: Optional[Sequence[int]],
|
| 369 |
+
thinking_span: Optional[Sequence[int]],
|
| 370 |
+
chunk_tokens: int,
|
| 371 |
+
sink_chunk_tokens: int,
|
| 372 |
+
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
|
| 373 |
+
"""Execute experimental FT (multi-hop IFR) with stop-token soft deletion."""
|
| 374 |
+
|
| 375 |
+
import ft_ifr_improve
|
| 376 |
+
|
| 377 |
+
attr = ft_ifr_improve.LLMIFRAttributionImproved(
|
| 378 |
+
model,
|
| 379 |
+
tokenizer,
|
| 380 |
+
chunk_tokens=chunk_tokens,
|
| 381 |
+
sink_chunk_tokens=sink_chunk_tokens,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None
|
| 385 |
+
thinking = (
|
| 386 |
+
tuple(thinking_span)
|
| 387 |
+
if thinking_span is not None
|
| 388 |
+
else tuple(example.thinking_span) if example.thinking_span else None
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
result = attr.calculate_ifr_multi_hop_stop_words(
|
| 392 |
+
example.prompt,
|
| 393 |
+
target=example.target,
|
| 394 |
+
sink_span=sink,
|
| 395 |
+
thinking_span=thinking,
|
| 396 |
+
n_hops=n_hops,
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
debug_info: Dict[str, Any] = {
|
| 400 |
+
"full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []),
|
| 401 |
+
"generation_tokens": list(getattr(attr, "generation_tokens", []) or []),
|
| 402 |
+
"user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []),
|
| 403 |
+
"chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []),
|
| 404 |
+
"prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None,
|
| 405 |
+
"generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None,
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
raw_vectors = []
|
| 409 |
+
if result.metadata and "ifr" in result.metadata:
|
| 410 |
+
raw_ifr = result.metadata["ifr"].get("raw")
|
| 411 |
+
if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"):
|
| 412 |
+
try:
|
| 413 |
+
raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions]
|
| 414 |
+
except Exception:
|
| 415 |
+
raw_vectors = []
|
| 416 |
+
debug_info["raw_hop_vectors"] = raw_vectors
|
| 417 |
+
|
| 418 |
+
return result, sink, thinking, debug_info
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def run_ft_multihop_split_hop(
|
| 422 |
+
example: ds_utils.CachedExample,
|
| 423 |
+
model: Any,
|
| 424 |
+
tokenizer: Any,
|
| 425 |
+
*,
|
| 426 |
+
n_hops: int,
|
| 427 |
+
sink_span: Optional[Sequence[int]],
|
| 428 |
+
thinking_span: Optional[Sequence[int]],
|
| 429 |
+
chunk_tokens: int,
|
| 430 |
+
sink_chunk_tokens: int,
|
| 431 |
+
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
|
| 432 |
+
"""Execute experimental FT (split-hop IFR over segmented thinking span)."""
|
| 433 |
+
|
| 434 |
+
import ft_ifr_improve
|
| 435 |
+
|
| 436 |
+
attr = ft_ifr_improve.LLMIFRAttributionSplitHop(
|
| 437 |
+
model,
|
| 438 |
+
tokenizer,
|
| 439 |
+
chunk_tokens=chunk_tokens,
|
| 440 |
+
sink_chunk_tokens=sink_chunk_tokens,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None
|
| 444 |
+
thinking = (
|
| 445 |
+
tuple(thinking_span)
|
| 446 |
+
if thinking_span is not None
|
| 447 |
+
else tuple(example.thinking_span) if example.thinking_span else None
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
result = attr.calculate_ifr_multi_hop_split_hop(
|
| 451 |
+
example.prompt,
|
| 452 |
+
target=example.target,
|
| 453 |
+
sink_span=sink,
|
| 454 |
+
thinking_span=thinking,
|
| 455 |
+
n_hops=int(n_hops),
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
debug_info: Dict[str, Any] = {
|
| 459 |
+
"full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []),
|
| 460 |
+
"generation_tokens": list(getattr(attr, "generation_tokens", []) or []),
|
| 461 |
+
"user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []),
|
| 462 |
+
"chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []),
|
| 463 |
+
"prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None,
|
| 464 |
+
"generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None,
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
raw_vectors = []
|
| 468 |
+
if result.metadata and "ifr" in result.metadata:
|
| 469 |
+
raw_ifr = result.metadata["ifr"].get("raw")
|
| 470 |
+
if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"):
|
| 471 |
+
try:
|
| 472 |
+
raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions]
|
| 473 |
+
except Exception:
|
| 474 |
+
raw_vectors = []
|
| 475 |
+
debug_info["raw_hop_vectors"] = raw_vectors
|
| 476 |
+
|
| 477 |
+
return result, sink, thinking, debug_info
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def run_ifr_in_all_gen(
|
| 481 |
+
example: ds_utils.CachedExample,
|
| 482 |
+
model: Any,
|
| 483 |
+
tokenizer: Any,
|
| 484 |
+
*,
|
| 485 |
+
n_hops: int,
|
| 486 |
+
sink_span: Optional[Sequence[int]],
|
| 487 |
+
thinking_span: Optional[Sequence[int]],
|
| 488 |
+
chunk_tokens: int,
|
| 489 |
+
sink_chunk_tokens: int,
|
| 490 |
+
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
|
| 491 |
+
"""Execute experimental IFR variant: multi-hop over all generation (CoT + output)."""
|
| 492 |
+
|
| 493 |
+
import ft_ifr_improve
|
| 494 |
+
|
| 495 |
+
attr = ft_ifr_improve.LLMIFRAttributionInAllGen(
|
| 496 |
+
model,
|
| 497 |
+
tokenizer,
|
| 498 |
+
chunk_tokens=chunk_tokens,
|
| 499 |
+
sink_chunk_tokens=sink_chunk_tokens,
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None
|
| 503 |
+
thinking = (
|
| 504 |
+
tuple(thinking_span)
|
| 505 |
+
if thinking_span is not None
|
| 506 |
+
else tuple(example.thinking_span) if example.thinking_span else None
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
result = attr.calculate_ifr_in_all_gen(
|
| 510 |
+
example.prompt,
|
| 511 |
+
target=example.target,
|
| 512 |
+
sink_span=sink,
|
| 513 |
+
thinking_span=thinking,
|
| 514 |
+
n_hops=int(n_hops),
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
debug_info: Dict[str, Any] = {
|
| 518 |
+
"full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []),
|
| 519 |
+
"generation_tokens": list(getattr(attr, "generation_tokens", []) or []),
|
| 520 |
+
"user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []),
|
| 521 |
+
"chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []),
|
| 522 |
+
"prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None,
|
| 523 |
+
"generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None,
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
raw_vectors = []
|
| 527 |
+
if result.metadata and "ifr" in result.metadata:
|
| 528 |
+
raw_ifr = result.metadata["ifr"].get("raw")
|
| 529 |
+
if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"):
|
| 530 |
+
try:
|
| 531 |
+
raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions]
|
| 532 |
+
except Exception:
|
| 533 |
+
raw_vectors = []
|
| 534 |
+
debug_info["raw_hop_vectors"] = raw_vectors
|
| 535 |
+
|
| 536 |
+
return result, sink, thinking, debug_info
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def make_output_stem(dataset_name: str, index: int, mode: str) -> str:
|
| 540 |
+
safe_name = dataset_name.replace("/", "_").replace(" ", "_")
|
| 541 |
+
prefix = {
|
| 542 |
+
"ft": "ft_case_",
|
| 543 |
+
"ft_improve": "ft_improve_case_",
|
| 544 |
+
"ifr": "ifr_case_",
|
| 545 |
+
"ifr_all_positions": "ifr_all_positions_case_",
|
| 546 |
+
"ifr_all_positions_output_only": "ifr_output_only_case_",
|
| 547 |
+
"attnlrp": "attnlrp_case_",
|
| 548 |
+
"ft_attnlrp": "ft_attnlrp_case_",
|
| 549 |
+
}.get(mode, f"{mode}_case_")
|
| 550 |
+
return f"{prefix}{safe_name}_idx{index}"
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def _decode_token_ids(tokenizer: Any, ids: Sequence[int]) -> List[str]:
|
| 554 |
+
"""Decode each token id into a readable text piece (keeps special tokens)."""
|
| 555 |
+
|
| 556 |
+
pieces: List[str] = []
|
| 557 |
+
for tok_id in ids:
|
| 558 |
+
try:
|
| 559 |
+
pieces.append(
|
| 560 |
+
tokenizer.decode([int(tok_id)], skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
| 561 |
+
)
|
| 562 |
+
except Exception:
|
| 563 |
+
pieces.append(str(tok_id))
|
| 564 |
+
return pieces
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def build_raw_tokens_from_ids(tokenizer: Any, prompt_ids: Optional[Sequence[int]], generation_ids: Optional[Sequence[int]]) -> List[str]:
|
| 568 |
+
if not prompt_ids:
|
| 569 |
+
prompt_ids = []
|
| 570 |
+
if not generation_ids:
|
| 571 |
+
generation_ids = []
|
| 572 |
+
return _decode_token_ids(tokenizer, prompt_ids) + _decode_token_ids(tokenizer, generation_ids)
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def build_trimmed_roles(tokens: Sequence[str], segments: Dict[str, Any]) -> List[str]:
|
| 576 |
+
"""Assign role labels for trimmed tokens (prompt + generation)."""
|
| 577 |
+
|
| 578 |
+
roles = ["prompt" for _ in range(len(tokens))]
|
| 579 |
+
prompt_len_tokens = segments.get("prompt_len", 0)
|
| 580 |
+
for idx in range(prompt_len_tokens, len(tokens)):
|
| 581 |
+
roles[idx] = "gen"
|
| 582 |
+
thinking_span = segments.get("thinking_span")
|
| 583 |
+
sink_span = segments.get("sink_span")
|
| 584 |
+
if thinking_span is not None:
|
| 585 |
+
start = prompt_len_tokens + int(thinking_span[0])
|
| 586 |
+
end = prompt_len_tokens + int(thinking_span[1])
|
| 587 |
+
for i in range(start, min(len(tokens), end + 1)):
|
| 588 |
+
roles[i] = "think"
|
| 589 |
+
if sink_span is not None:
|
| 590 |
+
start = prompt_len_tokens + int(sink_span[0])
|
| 591 |
+
end = prompt_len_tokens + int(sink_span[1])
|
| 592 |
+
for i in range(start, min(len(tokens), end + 1)):
|
| 593 |
+
roles[i] = "output"
|
| 594 |
+
return roles
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def build_raw_roles(
|
| 598 |
+
tokens: Sequence[str],
|
| 599 |
+
prompt_len_full: int,
|
| 600 |
+
user_indices: Sequence[int],
|
| 601 |
+
template_indices: Sequence[int],
|
| 602 |
+
thinking_span_abs: Optional[Sequence[int]],
|
| 603 |
+
sink_span_abs: Optional[Sequence[int]],
|
| 604 |
+
) -> List[str]:
|
| 605 |
+
"""Assign role labels for raw tokens (template + user + generation)."""
|
| 606 |
+
|
| 607 |
+
roles = ["template" for _ in range(len(tokens))]
|
| 608 |
+
user_set = set(int(i) for i in user_indices)
|
| 609 |
+
tmpl_set = set(int(i) for i in template_indices)
|
| 610 |
+
|
| 611 |
+
for i in range(min(len(tokens), prompt_len_full)):
|
| 612 |
+
if i in user_set:
|
| 613 |
+
roles[i] = "user"
|
| 614 |
+
elif i in tmpl_set:
|
| 615 |
+
roles[i] = "template"
|
| 616 |
+
else:
|
| 617 |
+
roles[i] = "prompt"
|
| 618 |
+
|
| 619 |
+
for i in range(prompt_len_full, len(tokens)):
|
| 620 |
+
roles[i] = "gen"
|
| 621 |
+
|
| 622 |
+
if thinking_span_abs is not None:
|
| 623 |
+
start, end = int(thinking_span_abs[0]), int(thinking_span_abs[1])
|
| 624 |
+
for i in range(start, min(len(tokens), end + 1)):
|
| 625 |
+
roles[i] = "think"
|
| 626 |
+
|
| 627 |
+
if sink_span_abs is not None:
|
| 628 |
+
start, end = int(sink_span_abs[0]), int(sink_span_abs[1])
|
| 629 |
+
for i in range(start, min(len(tokens), end + 1)):
|
| 630 |
+
roles[i] = "output"
|
| 631 |
+
|
| 632 |
+
return roles
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def extract_prompt_only_vectors(hop_vectors: Sequence[torch.Tensor], prompt_len: int) -> List[torch.Tensor]:
|
| 636 |
+
"""Slice hop vectors down to user-prompt tokens only (no generation tokens)."""
|
| 637 |
+
|
| 638 |
+
if prompt_len < 0:
|
| 639 |
+
raise ValueError("prompt_len must be >= 0.")
|
| 640 |
+
|
| 641 |
+
out: List[torch.Tensor] = []
|
| 642 |
+
for vec in hop_vectors:
|
| 643 |
+
v = torch.as_tensor(vec, dtype=torch.float32).detach().cpu()
|
| 644 |
+
if int(v.numel()) < int(prompt_len):
|
| 645 |
+
raise ValueError(f"Hop vector too short for prompt-only slice: len={int(v.numel())} prompt_len={int(prompt_len)}.")
|
| 646 |
+
out.append(v[:prompt_len])
|
| 647 |
+
return out
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def _lift_trimmed_to_full(
|
| 651 |
+
trimmed: torch.Tensor,
|
| 652 |
+
*,
|
| 653 |
+
prompt_len_full: int,
|
| 654 |
+
gen_len: int,
|
| 655 |
+
user_prompt_indices: Sequence[int],
|
| 656 |
+
) -> torch.Tensor:
|
| 657 |
+
"""Lift a trimmed (user prompt + generation) vector into full token space with zeros for chat-template tokens."""
|
| 658 |
+
|
| 659 |
+
t = torch.as_tensor(trimmed, dtype=torch.float32).detach().cpu()
|
| 660 |
+
user_len = len(user_prompt_indices)
|
| 661 |
+
expected = int(user_len + gen_len)
|
| 662 |
+
if int(t.numel()) != expected:
|
| 663 |
+
raise ValueError(f"Trimmed vector length mismatch: got {int(t.numel())}, expected {expected}.")
|
| 664 |
+
|
| 665 |
+
total_len = int(prompt_len_full + gen_len)
|
| 666 |
+
full = torch.zeros((total_len,), dtype=torch.float32)
|
| 667 |
+
for j, abs_pos in enumerate(user_prompt_indices):
|
| 668 |
+
full[int(abs_pos)] = t[j]
|
| 669 |
+
full[int(prompt_len_full) : int(prompt_len_full + gen_len)] = t[user_len:]
|
| 670 |
+
return full
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def _postprocess_attnlrp_full_vector(
|
| 674 |
+
raw_full: torch.Tensor,
|
| 675 |
+
*,
|
| 676 |
+
prompt_len_full: int,
|
| 677 |
+
gen_len: int,
|
| 678 |
+
user_prompt_indices: Sequence[int],
|
| 679 |
+
neg_handling: str,
|
| 680 |
+
norm_mode: str,
|
| 681 |
+
) -> torch.Tensor:
|
| 682 |
+
"""Mirror FT-AttnLRP hop postprocessing while preserving stripped-token normalization.
|
| 683 |
+
|
| 684 |
+
The underlying AttnLRP implementation postprocesses the *stripped* vector (user prompt + generation):
|
| 685 |
+
- NaN->0, then neg_handling ('drop' or 'abs')
|
| 686 |
+
- if norm_mode=='norm': normalize by sum over stripped tokens
|
| 687 |
+
|
| 688 |
+
For the pre-trim full view (chat template + generation), we apply the same non-negativity transform
|
| 689 |
+
to the full vector and normalize using *only the stripped indices*, so overlapping token scores
|
| 690 |
+
match the trimmed vectors used by the evaluation/case-study hop outputs.
|
| 691 |
+
"""
|
| 692 |
+
|
| 693 |
+
v = torch.as_tensor(raw_full, dtype=torch.float32).detach().cpu()
|
| 694 |
+
v = torch.nan_to_num(v, nan=0.0)
|
| 695 |
+
|
| 696 |
+
if neg_handling == "drop":
|
| 697 |
+
v = v.clamp(min=0.0)
|
| 698 |
+
elif neg_handling == "abs":
|
| 699 |
+
v = v.abs()
|
| 700 |
+
else:
|
| 701 |
+
raise ValueError(f"Unsupported neg_handling={neg_handling!r} (expected 'drop' or 'abs').")
|
| 702 |
+
|
| 703 |
+
ratio_enabled = norm_mode == "norm"
|
| 704 |
+
if not ratio_enabled:
|
| 705 |
+
return v
|
| 706 |
+
|
| 707 |
+
keep = list(int(i) for i in user_prompt_indices) + list(range(int(prompt_len_full), int(prompt_len_full + gen_len)))
|
| 708 |
+
if not keep:
|
| 709 |
+
return torch.zeros_like(v)
|
| 710 |
+
|
| 711 |
+
keep_idx = torch.as_tensor(keep, dtype=torch.long)
|
| 712 |
+
denom = float(v.index_select(0, keep_idx).sum().item())
|
| 713 |
+
if denom <= 0.0:
|
| 714 |
+
return torch.zeros_like(v)
|
| 715 |
+
return v / (denom + 1e-12)
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def main() -> None:
|
| 719 |
+
args = parse_args()
|
| 720 |
+
device = resolve_device(args.cuda, args.cuda_num)
|
| 721 |
+
if torch.cuda.is_available():
|
| 722 |
+
visible = os.environ.get("CUDA_VISIBLE_DEVICES")
|
| 723 |
+
print(f"[info] CUDA_VISIBLE_DEVICES={visible!r} torch.cuda.device_count()={torch.cuda.device_count()} device={device}")
|
| 724 |
+
|
| 725 |
+
model_name = args.model_path if args.model_path is not None else args.model
|
| 726 |
+
# Align with exp/exp2: always use the shared fp16 loader.
|
| 727 |
+
model, tokenizer = load_model(model_name, device)
|
| 728 |
+
|
| 729 |
+
example, ds_name = load_example(args.dataset, args.index, Path(args.data_root))
|
| 730 |
+
mode = args.mode
|
| 731 |
+
|
| 732 |
+
sink_span: Optional[Tuple[int, int]] = None
|
| 733 |
+
thinking_span: Optional[Tuple[int, int]] = None
|
| 734 |
+
thinking_ratios: Optional[Sequence[float]] = None
|
| 735 |
+
|
| 736 |
+
prompt_tokens_trimmed: List[str] = []
|
| 737 |
+
generation_tokens_trimmed: List[str] = []
|
| 738 |
+
hop_vectors_trimmed: List[torch.Tensor] = []
|
| 739 |
+
hop_vectors_raw: List[torch.Tensor] = []
|
| 740 |
+
prompt_len_full: Optional[int] = None
|
| 741 |
+
user_prompt_indices: List[int] = []
|
| 742 |
+
chat_prompt_indices: List[int] = []
|
| 743 |
+
method_meta: Dict[str, Any] = {}
|
| 744 |
+
raw_prompt_ids: Optional[List[int]] = None
|
| 745 |
+
raw_generation_ids: Optional[List[int]] = None
|
| 746 |
+
attnlrp_raw_attributions: Optional[List[Any]] = None
|
| 747 |
+
|
| 748 |
+
if mode in ("ft", "ft_improve", "ft_split_hop", "ifr_in_all_gen"):
|
| 749 |
+
if mode == "ft":
|
| 750 |
+
attr_result, sink_span, thinking_span, debug_info = run_ft_multihop(
|
| 751 |
+
example,
|
| 752 |
+
model,
|
| 753 |
+
tokenizer,
|
| 754 |
+
n_hops=args.n_hops,
|
| 755 |
+
sink_span=args.sink_span,
|
| 756 |
+
thinking_span=args.thinking_span,
|
| 757 |
+
chunk_tokens=args.chunk_tokens,
|
| 758 |
+
sink_chunk_tokens=args.sink_chunk_tokens,
|
| 759 |
+
)
|
| 760 |
+
elif mode == "ft_improve":
|
| 761 |
+
attr_result, sink_span, thinking_span, debug_info = run_ft_multihop_improve(
|
| 762 |
+
example,
|
| 763 |
+
model,
|
| 764 |
+
tokenizer,
|
| 765 |
+
n_hops=args.n_hops,
|
| 766 |
+
sink_span=args.sink_span,
|
| 767 |
+
thinking_span=args.thinking_span,
|
| 768 |
+
chunk_tokens=args.chunk_tokens,
|
| 769 |
+
sink_chunk_tokens=args.sink_chunk_tokens,
|
| 770 |
+
)
|
| 771 |
+
elif mode == "ft_split_hop":
|
| 772 |
+
attr_result, sink_span, thinking_span, debug_info = run_ft_multihop_split_hop(
|
| 773 |
+
example,
|
| 774 |
+
model,
|
| 775 |
+
tokenizer,
|
| 776 |
+
n_hops=args.n_hops,
|
| 777 |
+
sink_span=args.sink_span,
|
| 778 |
+
thinking_span=args.thinking_span,
|
| 779 |
+
chunk_tokens=args.chunk_tokens,
|
| 780 |
+
sink_chunk_tokens=args.sink_chunk_tokens,
|
| 781 |
+
)
|
| 782 |
+
elif mode == "ifr_in_all_gen":
|
| 783 |
+
attr_result, sink_span, thinking_span, debug_info = run_ifr_in_all_gen(
|
| 784 |
+
example,
|
| 785 |
+
model,
|
| 786 |
+
tokenizer,
|
| 787 |
+
n_hops=args.n_hops,
|
| 788 |
+
sink_span=args.sink_span,
|
| 789 |
+
thinking_span=args.thinking_span,
|
| 790 |
+
chunk_tokens=args.chunk_tokens,
|
| 791 |
+
sink_chunk_tokens=args.sink_chunk_tokens,
|
| 792 |
+
)
|
| 793 |
+
else:
|
| 794 |
+
raise ValueError(f"Unsupported mode={mode}")
|
| 795 |
+
ifr_meta = (attr_result.metadata or {}).get("ifr") or {}
|
| 796 |
+
hop_vectors_trimmed = list(ifr_meta.get("per_hop_projected") or [])
|
| 797 |
+
if not hop_vectors_trimmed:
|
| 798 |
+
raise RuntimeError(f"No per-hop vectors found for {mode} mode.")
|
| 799 |
+
|
| 800 |
+
prompt_tokens_trimmed = list(attr_result.prompt_tokens)
|
| 801 |
+
generation_tokens_trimmed = list(attr_result.generation_tokens)
|
| 802 |
+
thinking_ratios = ifr_meta.get("thinking_ratios")
|
| 803 |
+
|
| 804 |
+
raw_prompt_ids = debug_info.get("prompt_ids")
|
| 805 |
+
if isinstance(raw_prompt_ids, list) and raw_prompt_ids and isinstance(raw_prompt_ids[0], list):
|
| 806 |
+
raw_prompt_ids = raw_prompt_ids[0]
|
| 807 |
+
raw_generation_ids = debug_info.get("generation_ids")
|
| 808 |
+
if isinstance(raw_generation_ids, list) and raw_generation_ids and isinstance(raw_generation_ids[0], list):
|
| 809 |
+
raw_generation_ids = raw_generation_ids[0]
|
| 810 |
+
|
| 811 |
+
user_prompt_indices = list(debug_info.get("user_prompt_indices") or [])
|
| 812 |
+
chat_prompt_indices = list(debug_info.get("chat_prompt_indices") or [])
|
| 813 |
+
prompt_len_full = len(raw_prompt_ids) if isinstance(raw_prompt_ids, list) else None
|
| 814 |
+
|
| 815 |
+
raw_vectors = debug_info.get("raw_hop_vectors") or []
|
| 816 |
+
hop_vectors_raw = [vec.detach().cpu() if hasattr(vec, "detach") else torch.as_tensor(vec) for vec in raw_vectors]
|
| 817 |
+
method_meta = {"ifr": analysis.sanitize_ifr_meta(ifr_meta)}
|
| 818 |
+
|
| 819 |
+
elif mode == "ifr":
|
| 820 |
+
# Standard IFR (single-hop span aggregate), with pre/post trim views.
|
| 821 |
+
attr = llm_attr.LLMIFRAttribution(
|
| 822 |
+
model,
|
| 823 |
+
tokenizer,
|
| 824 |
+
chunk_tokens=args.chunk_tokens,
|
| 825 |
+
sink_chunk_tokens=args.sink_chunk_tokens,
|
| 826 |
+
)
|
| 827 |
+
sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
|
| 828 |
+
thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span
|
| 829 |
+
|
| 830 |
+
if sink_span is None:
|
| 831 |
+
raise ValueError("sink_span is required for IFR mode (use dataset sink_span or pass --sink_span).")
|
| 832 |
+
span_result = attr.calculate_ifr_span(
|
| 833 |
+
example.prompt,
|
| 834 |
+
target=example.target,
|
| 835 |
+
span=tuple(sink_span),
|
| 836 |
+
)
|
| 837 |
+
span_meta = span_result.metadata.get("ifr") if span_result.metadata else None
|
| 838 |
+
aggregate = span_meta.get("aggregate") if isinstance(span_meta, dict) else None
|
| 839 |
+
if aggregate is None or not hasattr(aggregate, "token_importance_total"):
|
| 840 |
+
raise RuntimeError("IFR span aggregate missing from metadata; cannot render pre-trim view.")
|
| 841 |
+
|
| 842 |
+
raw_vector = aggregate.token_importance_total.detach().cpu()
|
| 843 |
+
trimmed_vector = attr._project_vector(raw_vector)
|
| 844 |
+
hop_vectors_raw = [raw_vector]
|
| 845 |
+
hop_vectors_trimmed = [trimmed_vector]
|
| 846 |
+
|
| 847 |
+
prompt_tokens_trimmed = list(attr.user_prompt_tokens)
|
| 848 |
+
generation_tokens_trimmed = list(attr.generation_tokens)
|
| 849 |
+
|
| 850 |
+
raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0]
|
| 851 |
+
raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0]
|
| 852 |
+
user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or [])
|
| 853 |
+
chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or [])
|
| 854 |
+
prompt_len_full = len(raw_prompt_ids)
|
| 855 |
+
|
| 856 |
+
sink_abs = (prompt_len_full + sink_span[0], prompt_len_full + sink_span[1])
|
| 857 |
+
think_abs = (prompt_len_full + thinking_span[0], prompt_len_full + thinking_span[1]) if thinking_span else None
|
| 858 |
+
|
| 859 |
+
meta = {
|
| 860 |
+
"type": "span_aggregate",
|
| 861 |
+
"ifr_view": "aggregate",
|
| 862 |
+
"sink_span_generation": sink_span,
|
| 863 |
+
"sink_span_absolute": sink_abs,
|
| 864 |
+
"thinking_span_generation": thinking_span,
|
| 865 |
+
"thinking_span_absolute": think_abs,
|
| 866 |
+
}
|
| 867 |
+
method_meta = {"ifr": analysis.tensor_to_list(meta)}
|
| 868 |
+
|
| 869 |
+
elif mode == "ifr_all_positions_output_only":
|
| 870 |
+
# IFR all-positions (output-only) + token-level CAGE (row/recursive) derived from the matrix.
|
| 871 |
+
attr = llm_attr.LLMIFRAttribution(
|
| 872 |
+
model,
|
| 873 |
+
tokenizer,
|
| 874 |
+
chunk_tokens=args.chunk_tokens,
|
| 875 |
+
sink_chunk_tokens=args.sink_chunk_tokens,
|
| 876 |
+
)
|
| 877 |
+
sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
|
| 878 |
+
thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span
|
| 879 |
+
|
| 880 |
+
if sink_span is None:
|
| 881 |
+
raise ValueError(
|
| 882 |
+
"sink_span is required for ifr_all_positions_output_only mode "
|
| 883 |
+
"(use dataset sink_span or pass --sink_span)."
|
| 884 |
+
)
|
| 885 |
+
|
| 886 |
+
attr_result = attr.calculate_ifr_for_all_positions_output_only(
|
| 887 |
+
example.prompt,
|
| 888 |
+
target=example.target,
|
| 889 |
+
sink_span=tuple(sink_span),
|
| 890 |
+
)
|
| 891 |
+
|
| 892 |
+
indices_to_explain = list(sink_span)
|
| 893 |
+
_, row_attr, rec_attr = attr_result.get_all_token_attrs(indices_to_explain)
|
| 894 |
+
row_vec = row_attr.squeeze(0).detach().cpu()
|
| 895 |
+
rec_vec = rec_attr.squeeze(0).detach().cpu()
|
| 896 |
+
|
| 897 |
+
hop_vectors_trimmed = [row_vec, rec_vec]
|
| 898 |
+
|
| 899 |
+
prompt_tokens_trimmed = list(attr.user_prompt_tokens)
|
| 900 |
+
generation_tokens_trimmed = list(attr.generation_tokens)
|
| 901 |
+
|
| 902 |
+
raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0]
|
| 903 |
+
raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0]
|
| 904 |
+
user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or [])
|
| 905 |
+
chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or [])
|
| 906 |
+
prompt_len_full = len(raw_prompt_ids)
|
| 907 |
+
|
| 908 |
+
gen_len = len(raw_generation_ids or [])
|
| 909 |
+
hop_vectors_raw = [
|
| 910 |
+
_lift_trimmed_to_full(
|
| 911 |
+
v,
|
| 912 |
+
prompt_len_full=int(prompt_len_full or 0),
|
| 913 |
+
gen_len=gen_len,
|
| 914 |
+
user_prompt_indices=user_prompt_indices,
|
| 915 |
+
)
|
| 916 |
+
for v in hop_vectors_trimmed
|
| 917 |
+
]
|
| 918 |
+
|
| 919 |
+
ifr_meta = dict((attr_result.metadata or {}).get("ifr") or {})
|
| 920 |
+
ifr_meta["ifr_view"] = "all_positions_output_only (row+rec)"
|
| 921 |
+
ifr_meta["panel_titles"] = ["Row attribution", "Recursive attribution (CAGE)"]
|
| 922 |
+
ifr_meta["indices_to_explain"] = indices_to_explain
|
| 923 |
+
method_meta = {"ifr": analysis.tensor_to_list(ifr_meta)}
|
| 924 |
+
|
| 925 |
+
elif mode == "ifr_all_positions":
|
| 926 |
+
# IFR all-positions (full generation) + token-level CAGE (row/recursive) derived from the matrix.
|
| 927 |
+
attr = llm_attr.LLMIFRAttribution(
|
| 928 |
+
model,
|
| 929 |
+
tokenizer,
|
| 930 |
+
chunk_tokens=args.chunk_tokens,
|
| 931 |
+
sink_chunk_tokens=args.sink_chunk_tokens,
|
| 932 |
+
)
|
| 933 |
+
sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
|
| 934 |
+
thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span
|
| 935 |
+
|
| 936 |
+
if sink_span is None:
|
| 937 |
+
raise ValueError(
|
| 938 |
+
"sink_span is required for ifr_all_positions mode (use dataset sink_span or pass --sink_span)."
|
| 939 |
+
)
|
| 940 |
+
|
| 941 |
+
attr_result = attr.calculate_ifr_for_all_positions(
|
| 942 |
+
example.prompt,
|
| 943 |
+
target=example.target,
|
| 944 |
+
)
|
| 945 |
+
|
| 946 |
+
indices_to_explain = list(sink_span)
|
| 947 |
+
_, row_attr, rec_attr = attr_result.get_all_token_attrs(indices_to_explain)
|
| 948 |
+
row_vec = row_attr.squeeze(0).detach().cpu()
|
| 949 |
+
rec_vec = rec_attr.squeeze(0).detach().cpu()
|
| 950 |
+
|
| 951 |
+
hop_vectors_trimmed = [row_vec, rec_vec]
|
| 952 |
+
|
| 953 |
+
prompt_tokens_trimmed = list(attr.user_prompt_tokens)
|
| 954 |
+
generation_tokens_trimmed = list(attr.generation_tokens)
|
| 955 |
+
|
| 956 |
+
raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0]
|
| 957 |
+
raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0]
|
| 958 |
+
user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or [])
|
| 959 |
+
chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or [])
|
| 960 |
+
prompt_len_full = len(raw_prompt_ids)
|
| 961 |
+
|
| 962 |
+
gen_len = len(raw_generation_ids or [])
|
| 963 |
+
hop_vectors_raw = [
|
| 964 |
+
_lift_trimmed_to_full(
|
| 965 |
+
v,
|
| 966 |
+
prompt_len_full=int(prompt_len_full or 0),
|
| 967 |
+
gen_len=gen_len,
|
| 968 |
+
user_prompt_indices=user_prompt_indices,
|
| 969 |
+
)
|
| 970 |
+
for v in hop_vectors_trimmed
|
| 971 |
+
]
|
| 972 |
+
|
| 973 |
+
ifr_meta = dict((attr_result.metadata or {}).get("ifr") or {})
|
| 974 |
+
ifr_meta["ifr_view"] = "all_positions (row+rec)"
|
| 975 |
+
ifr_meta["panel_titles"] = ["Row attribution", "Recursive attribution (CAGE)"]
|
| 976 |
+
ifr_meta["indices_to_explain"] = indices_to_explain
|
| 977 |
+
method_meta = {"ifr": analysis.tensor_to_list(ifr_meta)}
|
| 978 |
+
|
| 979 |
+
elif mode in ("attnlrp", "ft_attnlrp"):
|
| 980 |
+
# Reuse the shared LLMLRPAttribution implementations (root-level).
|
| 981 |
+
attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 982 |
+
|
| 983 |
+
sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
|
| 984 |
+
thinking_span = (
|
| 985 |
+
tuple(args.thinking_span)
|
| 986 |
+
if args.thinking_span is not None
|
| 987 |
+
else tuple(example.thinking_span) if example.thinking_span else sink_span
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
if mode == "attnlrp":
|
| 991 |
+
# Case-study AttnLRP: reuse FT-AttnLRP logic but take hop0 (the first span-aggregate)
|
| 992 |
+
# for a full, signed attribution vector (no observation masking).
|
| 993 |
+
attr_result = attributor.calculate_attnlrp_ft_hop0(
|
| 994 |
+
example.prompt,
|
| 995 |
+
target=example.target,
|
| 996 |
+
sink_span=sink_span,
|
| 997 |
+
thinking_span=thinking_span,
|
| 998 |
+
neg_handling=args.attnlrp_neg_handling,
|
| 999 |
+
norm_mode=args.attnlrp_norm_mode,
|
| 1000 |
+
)
|
| 1001 |
+
meta = attr_result.metadata or {}
|
| 1002 |
+
multi_hop = meta.get("multi_hop_result")
|
| 1003 |
+
raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
|
| 1004 |
+
attnlrp_raw_attributions = list(raw_attributions)
|
| 1005 |
+
base_attr = raw_attributions[0] if raw_attributions else None
|
| 1006 |
+
if base_attr is None or not hasattr(base_attr, "token_importance_total"):
|
| 1007 |
+
raise RuntimeError("AttnLRP hop0 missing from multi-hop result.")
|
| 1008 |
+
|
| 1009 |
+
hop0_vec = torch.as_tensor(getattr(base_attr, "token_importance_total"), dtype=torch.float32).detach().cpu()
|
| 1010 |
+
if hop0_vec.numel() <= 0:
|
| 1011 |
+
raise RuntimeError("Empty generation for AttnLRP case study.")
|
| 1012 |
+
|
| 1013 |
+
# Use the actual sink span applied by hop0 (defaults to full generation when unset).
|
| 1014 |
+
sink_span = tuple(getattr(base_attr, "sink_range"))
|
| 1015 |
+
if thinking_span is None:
|
| 1016 |
+
thinking_span = sink_span
|
| 1017 |
+
|
| 1018 |
+
hop_vectors_trimmed = [hop0_vec]
|
| 1019 |
+
thinking_ratios = list(getattr(multi_hop, "thinking_ratios", []) or [])
|
| 1020 |
+
|
| 1021 |
+
method_meta = {
|
| 1022 |
+
"attnlrp": {
|
| 1023 |
+
"type": "calculate_attnlrp_multi_hop(n_hops=0) hop0 raw_attributions[0]",
|
| 1024 |
+
"sink_span_generation": sink_span,
|
| 1025 |
+
"thinking_span_generation": thinking_span,
|
| 1026 |
+
"thinking_ratios": thinking_ratios,
|
| 1027 |
+
"neg_handling": args.attnlrp_neg_handling,
|
| 1028 |
+
"norm_mode": args.attnlrp_norm_mode,
|
| 1029 |
+
"ratio_enabled": args.attnlrp_norm_mode == "norm",
|
| 1030 |
+
}
|
| 1031 |
+
}
|
| 1032 |
+
else:
|
| 1033 |
+
# exp2 ft_attnlrp: multi-hop aggregated AttnLRP (metadata contains per-hop vectors).
|
| 1034 |
+
attr_result = attributor.calculate_attnlrp_aggregated_multi_hop(
|
| 1035 |
+
example.prompt,
|
| 1036 |
+
target=example.target,
|
| 1037 |
+
sink_span=sink_span,
|
| 1038 |
+
thinking_span=thinking_span,
|
| 1039 |
+
n_hops=int(args.n_hops),
|
| 1040 |
+
neg_handling=args.attnlrp_neg_handling,
|
| 1041 |
+
norm_mode=args.attnlrp_norm_mode,
|
| 1042 |
+
)
|
| 1043 |
+
meta = attr_result.metadata or {}
|
| 1044 |
+
multi_hop = meta.get("multi_hop_result")
|
| 1045 |
+
if multi_hop is None:
|
| 1046 |
+
raise RuntimeError("FT-AttnLRP case study missing metadata.multi_hop_result.")
|
| 1047 |
+
|
| 1048 |
+
raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
|
| 1049 |
+
attnlrp_raw_attributions = list(raw_attributions)
|
| 1050 |
+
hop_vectors_trimmed = [
|
| 1051 |
+
torch.as_tensor(getattr(hop, "token_importance_total"), dtype=torch.float32).detach().cpu()
|
| 1052 |
+
for hop in raw_attributions
|
| 1053 |
+
]
|
| 1054 |
+
thinking_ratios = list(getattr(multi_hop, "thinking_ratios", []) or [])
|
| 1055 |
+
|
| 1056 |
+
method_meta = {
|
| 1057 |
+
"attnlrp": {
|
| 1058 |
+
"type": "calculate_attnlrp_aggregated_multi_hop (exp2 ft_attnlrp)",
|
| 1059 |
+
"n_hops": int(args.n_hops),
|
| 1060 |
+
"sink_span_generation": sink_span,
|
| 1061 |
+
"thinking_span_generation": thinking_span,
|
| 1062 |
+
"thinking_ratios": thinking_ratios,
|
| 1063 |
+
"neg_handling": args.attnlrp_neg_handling,
|
| 1064 |
+
"norm_mode": args.attnlrp_norm_mode,
|
| 1065 |
+
"ratio_enabled": args.attnlrp_norm_mode == "norm",
|
| 1066 |
+
}
|
| 1067 |
+
}
|
| 1068 |
+
|
| 1069 |
+
prompt_tokens_trimmed = list(attributor.user_prompt_tokens)
|
| 1070 |
+
generation_tokens_trimmed = list(attributor.generation_tokens)
|
| 1071 |
+
|
| 1072 |
+
raw_prompt_ids = attributor.prompt_ids.detach().cpu().tolist()[0]
|
| 1073 |
+
raw_generation_ids = attributor.generation_ids.detach().cpu().tolist()[0]
|
| 1074 |
+
user_prompt_indices = list(getattr(attributor, "user_prompt_indices", []) or [])
|
| 1075 |
+
chat_prompt_indices = list(getattr(attributor, "chat_prompt_indices", []) or [])
|
| 1076 |
+
prompt_len_full = len(raw_prompt_ids)
|
| 1077 |
+
|
| 1078 |
+
else:
|
| 1079 |
+
raise ValueError(f"Unsupported mode={mode}")
|
| 1080 |
+
|
| 1081 |
+
if not hop_vectors_trimmed:
|
| 1082 |
+
raise RuntimeError("No hop vectors to visualize.")
|
| 1083 |
+
|
| 1084 |
+
raw_tokens = build_raw_tokens_from_ids(tokenizer, raw_prompt_ids, raw_generation_ids)
|
| 1085 |
+
|
| 1086 |
+
sink_span_abs = None
|
| 1087 |
+
thinking_span_abs = None
|
| 1088 |
+
if prompt_len_full is not None and sink_span is not None:
|
| 1089 |
+
sink_span_abs = (prompt_len_full + sink_span[0], prompt_len_full + sink_span[1])
|
| 1090 |
+
if prompt_len_full is not None and thinking_span is not None:
|
| 1091 |
+
thinking_span_abs = (prompt_len_full + thinking_span[0], prompt_len_full + thinking_span[1])
|
| 1092 |
+
prompt_len_full_safe = int(prompt_len_full or 0)
|
| 1093 |
+
roles_raw = build_raw_roles(
|
| 1094 |
+
raw_tokens,
|
| 1095 |
+
prompt_len_full_safe,
|
| 1096 |
+
user_prompt_indices,
|
| 1097 |
+
chat_prompt_indices,
|
| 1098 |
+
thinking_span_abs,
|
| 1099 |
+
sink_span_abs,
|
| 1100 |
+
)
|
| 1101 |
+
|
| 1102 |
+
prompt_tokens_only = list(prompt_tokens_trimmed)
|
| 1103 |
+
prompt_only_vectors = extract_prompt_only_vectors(hop_vectors_trimmed, len(prompt_tokens_only))
|
| 1104 |
+
|
| 1105 |
+
# Ensure every method has a pre-trim full vector per panel.
|
| 1106 |
+
if not hop_vectors_raw:
|
| 1107 |
+
if mode in ("attnlrp", "ft_attnlrp") and attnlrp_raw_attributions is not None:
|
| 1108 |
+
gen_len = len(raw_generation_ids or [])
|
| 1109 |
+
expected = int((prompt_len_full_safe + gen_len) if prompt_len_full is not None else 0)
|
| 1110 |
+
full_vectors: List[torch.Tensor] = []
|
| 1111 |
+
for hop in attnlrp_raw_attributions:
|
| 1112 |
+
meta = getattr(hop, "metadata", None) or {}
|
| 1113 |
+
raw_full = meta.get("token_importance_total_with_chat_template")
|
| 1114 |
+
if raw_full is None:
|
| 1115 |
+
full_vectors = []
|
| 1116 |
+
break
|
| 1117 |
+
v = _postprocess_attnlrp_full_vector(
|
| 1118 |
+
torch.as_tensor(raw_full, dtype=torch.float32),
|
| 1119 |
+
prompt_len_full=prompt_len_full_safe,
|
| 1120 |
+
gen_len=gen_len,
|
| 1121 |
+
user_prompt_indices=user_prompt_indices,
|
| 1122 |
+
neg_handling=args.attnlrp_neg_handling,
|
| 1123 |
+
norm_mode=args.attnlrp_norm_mode,
|
| 1124 |
+
)
|
| 1125 |
+
if expected and int(v.numel()) != expected:
|
| 1126 |
+
raise RuntimeError(
|
| 1127 |
+
"AttnLRP full-vector length mismatch for pre-trim view: "
|
| 1128 |
+
f"got {int(v.numel())}, expected {expected}."
|
| 1129 |
+
)
|
| 1130 |
+
full_vectors.append(v)
|
| 1131 |
+
hop_vectors_raw = full_vectors
|
| 1132 |
+
|
| 1133 |
+
if not hop_vectors_raw and prompt_len_full is not None:
|
| 1134 |
+
# Fallback: lift trimmed vectors back to full token space with zeros for template tokens.
|
| 1135 |
+
gen_len = len(raw_generation_ids or [])
|
| 1136 |
+
hop_vectors_raw = [
|
| 1137 |
+
_lift_trimmed_to_full(
|
| 1138 |
+
v,
|
| 1139 |
+
prompt_len_full=prompt_len_full_safe,
|
| 1140 |
+
gen_len=gen_len,
|
| 1141 |
+
user_prompt_indices=user_prompt_indices,
|
| 1142 |
+
)
|
| 1143 |
+
for v in hop_vectors_trimmed
|
| 1144 |
+
]
|
| 1145 |
+
|
| 1146 |
+
if not hop_vectors_raw:
|
| 1147 |
+
raise RuntimeError("Missing pre-trim vectors; cannot render required full-sequence heatmap.")
|
| 1148 |
+
|
| 1149 |
+
# Lightweight debug stats to catch silent all-zero / NaN cases.
|
| 1150 |
+
hop_stats_raw = [analysis.vector_stats(torch.nan_to_num(v.detach().cpu(), nan=0.0)) for v in hop_vectors_raw]
|
| 1151 |
+
hop_stats_prompt = [analysis.vector_stats(torch.nan_to_num(v.detach().cpu(), nan=0.0)) for v in prompt_only_vectors]
|
| 1152 |
+
for i in range(max(len(hop_stats_raw), len(hop_stats_prompt))):
|
| 1153 |
+
raw_abs = hop_stats_raw[i]["abs_max"] if i < len(hop_stats_raw) else None
|
| 1154 |
+
prompt_abs = hop_stats_prompt[i]["abs_max"] if i < len(hop_stats_prompt) else None
|
| 1155 |
+
print(f"[stats] panel {i}: raw_abs_max={raw_abs} prompt_abs_max={prompt_abs}")
|
| 1156 |
+
|
| 1157 |
+
hop_token_raw = analysis.package_token_hops(hop_vectors_raw)
|
| 1158 |
+
hop_token_prompt = analysis.package_token_hops(prompt_only_vectors)
|
| 1159 |
+
|
| 1160 |
+
case_meta: Dict[str, Any] = {
|
| 1161 |
+
"dataset": ds_name,
|
| 1162 |
+
"index": args.index,
|
| 1163 |
+
"sink_span": sink_span,
|
| 1164 |
+
"thinking_span": thinking_span,
|
| 1165 |
+
"n_hops": args.n_hops,
|
| 1166 |
+
"thinking_ratios": thinking_ratios,
|
| 1167 |
+
"mode": mode,
|
| 1168 |
+
"ifr_view": method_meta.get("ifr", {}).get("ifr_view") if isinstance(method_meta.get("ifr"), dict) else None,
|
| 1169 |
+
"panel_titles": method_meta.get("ifr", {}).get("panel_titles") if isinstance(method_meta.get("ifr"), dict) else None,
|
| 1170 |
+
"attnlrp_neg_handling": args.attnlrp_neg_handling if mode in ("attnlrp", "ft_attnlrp") else None,
|
| 1171 |
+
"attnlrp_norm_mode": args.attnlrp_norm_mode if mode in ("attnlrp", "ft_attnlrp") else None,
|
| 1172 |
+
"attnlrp_ratio_enabled": (args.attnlrp_norm_mode == "norm") if mode in ("attnlrp", "ft_attnlrp") else None,
|
| 1173 |
+
"vector_stats_raw": hop_stats_raw,
|
| 1174 |
+
"vector_stats_prompt": hop_stats_prompt,
|
| 1175 |
+
}
|
| 1176 |
+
|
| 1177 |
+
generation_text = "".join(generation_tokens_trimmed) if generation_tokens_trimmed else ""
|
| 1178 |
+
prompt_text = example.prompt
|
| 1179 |
+
record = {
|
| 1180 |
+
"meta": case_meta,
|
| 1181 |
+
"prompt": prompt_text,
|
| 1182 |
+
"target": example.target,
|
| 1183 |
+
"generation": generation_text,
|
| 1184 |
+
"full_all_tokens": raw_tokens,
|
| 1185 |
+
"raw_token_roles": roles_raw,
|
| 1186 |
+
"prompt_tokens": prompt_tokens_only,
|
| 1187 |
+
"prompt_token_roles": ["user" for _ in range(len(prompt_tokens_only))],
|
| 1188 |
+
"token_hops_raw": hop_token_raw,
|
| 1189 |
+
"token_hops_prompt": hop_token_prompt,
|
| 1190 |
+
"ifr_meta": method_meta.get("ifr"),
|
| 1191 |
+
"attnlrp_meta": method_meta.get("attnlrp"),
|
| 1192 |
+
}
|
| 1193 |
+
|
| 1194 |
+
out_dir = Path(args.output_dir)
|
| 1195 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 1196 |
+
stem = make_output_stem(ds_name, args.index, mode)
|
| 1197 |
+
json_path = out_dir / f"{stem}.json"
|
| 1198 |
+
html_path = out_dir / f"{stem}.html"
|
| 1199 |
+
|
| 1200 |
+
with json_path.open("w", encoding="utf-8") as f:
|
| 1201 |
+
json.dump(record, f, ensure_ascii=False, indent=2)
|
| 1202 |
+
|
| 1203 |
+
html = viz.render_case_html(
|
| 1204 |
+
case_meta,
|
| 1205 |
+
token_view_raw={
|
| 1206 |
+
"label": "Pre-trim token-level heatmap (full sequence with chat template)",
|
| 1207 |
+
"tokens": raw_tokens,
|
| 1208 |
+
"roles": roles_raw,
|
| 1209 |
+
"hops": hop_token_raw,
|
| 1210 |
+
},
|
| 1211 |
+
token_view_prompt={
|
| 1212 |
+
"label": "Prompt-only token-level heatmap (user prompt only)",
|
| 1213 |
+
"tokens": prompt_tokens_only,
|
| 1214 |
+
"roles": ["user" for _ in range(len(prompt_tokens_only))],
|
| 1215 |
+
"hops": hop_token_prompt,
|
| 1216 |
+
},
|
| 1217 |
+
)
|
| 1218 |
+
html_path.write_text(html, encoding="utf-8")
|
| 1219 |
+
|
| 1220 |
+
print(f"[done] wrote {json_path}")
|
| 1221 |
+
print(f"[done] wrote {html_path}")
|
| 1222 |
+
|
| 1223 |
+
|
| 1224 |
+
if __name__ == "__main__":
|
| 1225 |
+
main()
|
exp/case_study/run_mas_case.py
ADDED
|
@@ -0,0 +1,805 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""MAS case study: visualize token-perturbation faithfulness for attribution methods.
|
| 3 |
+
|
| 4 |
+
This script matches the faithfulness evaluation logic implemented in:
|
| 5 |
+
- evaluations/faithfulness.py
|
| 6 |
+
- llm_attr_eval.LLMAttributionEvaluator.faithfulness_test()
|
| 7 |
+
|
| 8 |
+
For a single example and a selected attribution method, we:
|
| 9 |
+
1) Compute token-level attributions (Seq / Row / Recursive) over prompt tokens.
|
| 10 |
+
2) Rank prompt tokens by attribution mass.
|
| 11 |
+
3) Iteratively perturb the prompt by replacing one token at a time with PAD tokens.
|
| 12 |
+
4) Score the model as sum log p(generation + EOS | prompt) under the chat template.
|
| 13 |
+
5) Compute RISE / MAS / RISE+AP (AUCs) and visualize the perturbation impact as token heatmaps.
|
| 14 |
+
|
| 15 |
+
Outputs JSON + HTML to exp/case_study/out/.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
import types
|
| 25 |
+
from importlib.machinery import ModuleSpec
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _early_set_cuda_visible_devices() -> None:
|
| 33 |
+
"""Set CUDA_VISIBLE_DEVICES before importing torch/transformers.
|
| 34 |
+
|
| 35 |
+
Note: CUDA device indices are re-mapped inside the process after applying the mask.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
parser = argparse.ArgumentParser(add_help=False)
|
| 39 |
+
parser.add_argument("--cuda", type=str, default=None)
|
| 40 |
+
args, _ = parser.parse_known_args(sys.argv[1:])
|
| 41 |
+
cuda = args.cuda.strip() if isinstance(args.cuda, str) else ""
|
| 42 |
+
if cuda and "," in cuda:
|
| 43 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = cuda
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if __name__ == "__main__":
|
| 47 |
+
_early_set_cuda_visible_devices()
|
| 48 |
+
|
| 49 |
+
import torch
|
| 50 |
+
|
| 51 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 52 |
+
if str(REPO_ROOT) not in sys.path:
|
| 53 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 54 |
+
|
| 55 |
+
# Avoid optional vision deps when importing transformers.
|
| 56 |
+
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
|
| 57 |
+
os.environ.setdefault("DISABLE_TRANSFORMERS_IMAGE_TRANSFORMS", "1")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _stub_torchvision() -> None:
|
| 61 |
+
"""Provide minimal torchvision stubs so transformers imports succeed without torchvision."""
|
| 62 |
+
|
| 63 |
+
if "torchvision" in sys.modules:
|
| 64 |
+
return
|
| 65 |
+
|
| 66 |
+
def _mk(name: str) -> types.ModuleType:
|
| 67 |
+
mod = types.ModuleType(name)
|
| 68 |
+
mod.__spec__ = ModuleSpec(name, loader=None)
|
| 69 |
+
return mod
|
| 70 |
+
|
| 71 |
+
tv = _mk("torchvision")
|
| 72 |
+
tv.__dict__["__path__"] = []
|
| 73 |
+
submods = ["transforms", "_meta_registrations", "datasets", "io", "models", "ops", "utils"]
|
| 74 |
+
for name in submods:
|
| 75 |
+
mod = _mk(f"torchvision.{name}")
|
| 76 |
+
sys.modules[f"torchvision.{name}"] = mod
|
| 77 |
+
setattr(tv, name, mod)
|
| 78 |
+
|
| 79 |
+
class _InterpolationMode:
|
| 80 |
+
NEAREST = 0
|
| 81 |
+
NEAREST_EXACT = 0
|
| 82 |
+
BILINEAR = 1
|
| 83 |
+
BICUBIC = 2
|
| 84 |
+
LANCZOS = 3
|
| 85 |
+
BOX = 4
|
| 86 |
+
HAMMING = 5
|
| 87 |
+
|
| 88 |
+
sys.modules["torchvision.transforms"].InterpolationMode = _InterpolationMode
|
| 89 |
+
sys.modules["torchvision.transforms"].__all__ = ["InterpolationMode"]
|
| 90 |
+
|
| 91 |
+
ops_mod = sys.modules.get("torchvision.ops") or _mk("torchvision.ops")
|
| 92 |
+
sys.modules["torchvision.ops"] = ops_mod
|
| 93 |
+
setattr(tv, "ops", ops_mod)
|
| 94 |
+
misc_mod = _mk("torchvision.ops.misc")
|
| 95 |
+
sys.modules["torchvision.ops.misc"] = misc_mod
|
| 96 |
+
setattr(ops_mod, "misc", misc_mod)
|
| 97 |
+
|
| 98 |
+
class _FrozenBatchNorm2d:
|
| 99 |
+
def __init__(self, *args, **kwargs):
|
| 100 |
+
pass
|
| 101 |
+
|
| 102 |
+
misc_mod.FrozenBatchNorm2d = _FrozenBatchNorm2d
|
| 103 |
+
sys.modules["torchvision"] = tv
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _stub_timm() -> None:
|
| 107 |
+
"""Provide minimal timm stubs to avoid optional vision deps."""
|
| 108 |
+
|
| 109 |
+
if "timm" in sys.modules:
|
| 110 |
+
return
|
| 111 |
+
|
| 112 |
+
def _mk(name: str) -> types.ModuleType:
|
| 113 |
+
mod = types.ModuleType(name)
|
| 114 |
+
mod.__spec__ = ModuleSpec(name, loader=None)
|
| 115 |
+
return mod
|
| 116 |
+
|
| 117 |
+
timm = _mk("timm")
|
| 118 |
+
timm.__dict__["__path__"] = []
|
| 119 |
+
sys.modules["timm"] = timm
|
| 120 |
+
|
| 121 |
+
data_mod = _mk("timm.data")
|
| 122 |
+
sys.modules["timm.data"] = data_mod
|
| 123 |
+
timm.data = data_mod
|
| 124 |
+
|
| 125 |
+
class _ImageNetInfo:
|
| 126 |
+
pass
|
| 127 |
+
|
| 128 |
+
def _infer_imagenet_subset(*args, **kwargs):
|
| 129 |
+
return None
|
| 130 |
+
|
| 131 |
+
data_mod.ImageNetInfo = _ImageNetInfo
|
| 132 |
+
data_mod.infer_imagenet_subset = _infer_imagenet_subset
|
| 133 |
+
|
| 134 |
+
layers_mod = _mk("timm.layers")
|
| 135 |
+
sys.modules["timm.layers"] = layers_mod
|
| 136 |
+
timm.layers = layers_mod
|
| 137 |
+
|
| 138 |
+
create_norm_mod = _mk("timm.layers.create_norm")
|
| 139 |
+
sys.modules["timm.layers.create_norm"] = create_norm_mod
|
| 140 |
+
layers_mod.create_norm = create_norm_mod
|
| 141 |
+
|
| 142 |
+
def _get_norm_layer(*args, **kwargs):
|
| 143 |
+
return None
|
| 144 |
+
|
| 145 |
+
create_norm_mod.get_norm_layer = _get_norm_layer
|
| 146 |
+
|
| 147 |
+
classifier_mod = _mk("timm.layers.classifier")
|
| 148 |
+
sys.modules["timm.layers.classifier"] = classifier_mod
|
| 149 |
+
layers_mod.classifier = classifier_mod
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _stub_gemma3n() -> None:
|
| 153 |
+
"""Stub Gemma3n config module if transformers tries to import it."""
|
| 154 |
+
|
| 155 |
+
if "transformers.models.gemma3n.configuration_gemma3n" in sys.modules:
|
| 156 |
+
return
|
| 157 |
+
|
| 158 |
+
gemma_pkg = types.ModuleType("transformers.models.gemma3n")
|
| 159 |
+
gemma_pkg.__spec__ = ModuleSpec("transformers.models.gemma3n", loader=None, is_package=True)
|
| 160 |
+
sys.modules["transformers.models.gemma3n"] = gemma_pkg
|
| 161 |
+
|
| 162 |
+
gemma_conf = types.ModuleType("transformers.models.gemma3n.configuration_gemma3n")
|
| 163 |
+
gemma_conf.__spec__ = ModuleSpec("transformers.models.gemma3n.configuration_gemma3n", loader=None)
|
| 164 |
+
|
| 165 |
+
class Gemma3nConfig:
|
| 166 |
+
def __init__(self, *args, **kwargs):
|
| 167 |
+
self.model_type = "gemma3n"
|
| 168 |
+
|
| 169 |
+
class Gemma3nTextConfig(Gemma3nConfig):
|
| 170 |
+
pass
|
| 171 |
+
|
| 172 |
+
gemma_conf.Gemma3nConfig = Gemma3nConfig
|
| 173 |
+
gemma_conf.Gemma3nTextConfig = Gemma3nTextConfig
|
| 174 |
+
gemma_conf.__all__ = ["Gemma3nConfig", "Gemma3nTextConfig"]
|
| 175 |
+
sys.modules["transformers.models.gemma3n.configuration_gemma3n"] = gemma_conf
|
| 176 |
+
setattr(gemma_pkg, "configuration_gemma3n", gemma_conf)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
_stub_torchvision()
|
| 180 |
+
_stub_timm()
|
| 181 |
+
_stub_gemma3n()
|
| 182 |
+
|
| 183 |
+
import transformers # noqa: E402
|
| 184 |
+
|
| 185 |
+
# Provide light stubs if Longformer classes are unavailable; we don't use them here.
|
| 186 |
+
if not hasattr(transformers, "LongformerTokenizer"):
|
| 187 |
+
class _DummyLongformerTokenizer:
|
| 188 |
+
def __init__(self, *args, **kwargs):
|
| 189 |
+
raise ImportError("LongformerTokenizer stubbed; install full transformers if needed.")
|
| 190 |
+
transformers.LongformerTokenizer = _DummyLongformerTokenizer
|
| 191 |
+
if not hasattr(transformers, "LongformerForMaskedLM"):
|
| 192 |
+
class _DummyLongformerForMaskedLM:
|
| 193 |
+
def __init__(self, *args, **kwargs):
|
| 194 |
+
raise ImportError("LongformerForMaskedLM stubbed; install full transformers if needed.")
|
| 195 |
+
transformers.LongformerForMaskedLM = _DummyLongformerForMaskedLM
|
| 196 |
+
|
| 197 |
+
from exp.case_study import viz # noqa: E402
|
| 198 |
+
from exp.exp2 import dataset_utils as ds_utils # noqa: E402
|
| 199 |
+
from shared_utils import DEFAULT_PROMPT_TEMPLATE # noqa: E402
|
| 200 |
+
|
| 201 |
+
import llm_attr # noqa: E402
|
| 202 |
+
from evaluations.attribution_recovery import load_model # noqa: E402
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def resolve_device(cuda: Optional[str], cuda_num: int) -> str:
|
| 206 |
+
if cuda and isinstance(cuda, str) and "," in cuda:
|
| 207 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = cuda
|
| 208 |
+
return "auto"
|
| 209 |
+
if cuda and isinstance(cuda, str) and cuda.strip():
|
| 210 |
+
try:
|
| 211 |
+
idx = int(cuda)
|
| 212 |
+
except Exception:
|
| 213 |
+
idx = 0
|
| 214 |
+
return f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
|
| 215 |
+
return f"cuda:{cuda_num}" if torch.cuda.is_available() else "cpu"
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def load_example(dataset: str, index: int, data_root: Path) -> Tuple[ds_utils.CachedExample, str]:
|
| 219 |
+
ds_path = Path(dataset)
|
| 220 |
+
if ds_path.exists():
|
| 221 |
+
examples = ds_utils.read_cached_jsonl(ds_path)
|
| 222 |
+
dataset_name = ds_path.name
|
| 223 |
+
else:
|
| 224 |
+
loader = ds_utils.DatasetLoader(data_root=data_root)
|
| 225 |
+
examples = loader.load(dataset)
|
| 226 |
+
dataset_name = dataset
|
| 227 |
+
|
| 228 |
+
if not examples:
|
| 229 |
+
raise ValueError(f"No examples found for dataset={dataset}")
|
| 230 |
+
|
| 231 |
+
if index < 0:
|
| 232 |
+
index = len(examples) + index
|
| 233 |
+
if not (0 <= index < len(examples)):
|
| 234 |
+
raise IndexError(f"index {index} out of range for dataset with {len(examples)} examples")
|
| 235 |
+
|
| 236 |
+
return examples[index], dataset_name
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def make_output_stem(dataset_name: str, index: int, method: str) -> str:
|
| 240 |
+
safe_name = dataset_name.replace("/", "_").replace(" ", "_")
|
| 241 |
+
return f"mas_case_{method}_{safe_name}_idx{index}"
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def format_prompt(tokenizer: Any, prompt: str) -> str:
|
| 245 |
+
modified_prompt = DEFAULT_PROMPT_TEMPLATE.format(context=prompt, query="")
|
| 246 |
+
formatted_prompt = [{"role": "user", "content": modified_prompt}]
|
| 247 |
+
return tokenizer.apply_chat_template(
|
| 248 |
+
formatted_prompt,
|
| 249 |
+
tokenize=False,
|
| 250 |
+
add_generation_prompt=True,
|
| 251 |
+
enable_thinking=False,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
@torch.inference_mode()
|
| 256 |
+
def compute_logprob_response_given_prompt(model: Any, prompt_ids: torch.Tensor, response_ids: torch.Tensor) -> torch.Tensor:
|
| 257 |
+
"""Compute log-probabilities of response_ids given prompt_ids.
|
| 258 |
+
|
| 259 |
+
Shapes:
|
| 260 |
+
prompt_ids: [B, N]
|
| 261 |
+
response_ids: [B, M]
|
| 262 |
+
returns: [B, M]
|
| 263 |
+
"""
|
| 264 |
+
input_ids = torch.cat([prompt_ids, response_ids], dim=1)
|
| 265 |
+
attention_mask = torch.ones_like(input_ids)
|
| 266 |
+
logits = model(input_ids=input_ids, attention_mask=attention_mask).logits # [B, N+M, V]
|
| 267 |
+
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
|
| 268 |
+
|
| 269 |
+
response_start = int(prompt_ids.shape[1])
|
| 270 |
+
logits_for_response = log_probs[:, response_start - 1 : -1, :] # [B, M, V]
|
| 271 |
+
gathered = logits_for_response.gather(2, response_ids.unsqueeze(-1))
|
| 272 |
+
return gathered.squeeze(-1)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
@torch.inference_mode()
|
| 276 |
+
def score_prompt_ids_with_generation(model: Any, *, prompt_ids: torch.Tensor, generation_ids: torch.Tensor) -> float:
|
| 277 |
+
return float(compute_logprob_response_given_prompt(model, prompt_ids, generation_ids).sum().detach().cpu().item())
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
@torch.inference_mode()
|
| 281 |
+
def _ensure_pad_token_id(tokenizer: Any) -> int:
|
| 282 |
+
if tokenizer.pad_token_id is None:
|
| 283 |
+
if tokenizer.eos_token_id is None:
|
| 284 |
+
raise RuntimeError("tokenizer has neither pad_token_id nor eos_token_id; cannot define baseline token.")
|
| 285 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 286 |
+
return int(tokenizer.pad_token_id)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def _find_subsequence_start(haystack: torch.Tensor, needle: torch.Tensor) -> Optional[int]:
|
| 290 |
+
if haystack.ndim != 1 or needle.ndim != 1:
|
| 291 |
+
raise ValueError("Expected 1D tensors for subsequence matching.")
|
| 292 |
+
if needle.numel() == 0:
|
| 293 |
+
return 0
|
| 294 |
+
hay_len = int(haystack.numel())
|
| 295 |
+
needle_len = int(needle.numel())
|
| 296 |
+
if needle_len > hay_len:
|
| 297 |
+
return None
|
| 298 |
+
for i in range(hay_len - needle_len + 1):
|
| 299 |
+
if torch.equal(haystack[i : i + needle_len], needle):
|
| 300 |
+
return i
|
| 301 |
+
return None
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def decode_text_into_tokens(tokenizer: Any, text: str) -> List[str]:
|
| 305 |
+
encoding = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
|
| 306 |
+
offsets = list(encoding["offset_mapping"])
|
| 307 |
+
tokens: List[str] = []
|
| 308 |
+
for start, end in offsets:
|
| 309 |
+
tokens.append(text[start:end])
|
| 310 |
+
return tokens
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def auc(arr: np.ndarray) -> float:
|
| 314 |
+
return float((arr.sum() - arr[0] / 2 - arr[-1] / 2) / (arr.shape[0] - 1))
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def mas_trace(
|
| 318 |
+
model: Any,
|
| 319 |
+
tokenizer: Any,
|
| 320 |
+
*,
|
| 321 |
+
attribution: torch.Tensor,
|
| 322 |
+
prompt: str,
|
| 323 |
+
generation: str,
|
| 324 |
+
user_prompt_indices: Optional[Sequence[int]] = None,
|
| 325 |
+
keep_prompt_token_indices: Optional[Sequence[int]] = None,
|
| 326 |
+
k: int = 20,
|
| 327 |
+
) -> Dict[str, Any]:
|
| 328 |
+
"""Return a token-level faithfulness trace (RISE/MAS/RISE+AP) plus per-token deltas."""
|
| 329 |
+
|
| 330 |
+
pad_token_id = _ensure_pad_token_id(tokenizer)
|
| 331 |
+
|
| 332 |
+
user_prompt = " " + prompt
|
| 333 |
+
formatted = format_prompt(tokenizer, user_prompt)
|
| 334 |
+
formatted_ids = tokenizer(formatted, return_tensors="pt", add_special_tokens=False).input_ids
|
| 335 |
+
user_ids = tokenizer(user_prompt, return_tensors="pt", add_special_tokens=False).input_ids
|
| 336 |
+
|
| 337 |
+
prompt_ids = formatted_ids.to(model.device)
|
| 338 |
+
prompt_ids_perturbed = prompt_ids.clone()
|
| 339 |
+
gen_ids = tokenizer(
|
| 340 |
+
generation + (tokenizer.eos_token or ""),
|
| 341 |
+
return_tensors="pt",
|
| 342 |
+
add_special_tokens=False,
|
| 343 |
+
).input_ids.to(model.device)
|
| 344 |
+
|
| 345 |
+
attr_cpu = attribution.detach().cpu()
|
| 346 |
+
w = attr_cpu.sum(0)
|
| 347 |
+
P = int(w.numel())
|
| 348 |
+
|
| 349 |
+
if keep_prompt_token_indices is None:
|
| 350 |
+
keep = list(range(P))
|
| 351 |
+
else:
|
| 352 |
+
keep = []
|
| 353 |
+
seen: set[int] = set()
|
| 354 |
+
for raw in keep_prompt_token_indices:
|
| 355 |
+
try:
|
| 356 |
+
idx = int(raw)
|
| 357 |
+
except Exception:
|
| 358 |
+
continue
|
| 359 |
+
if 0 <= idx < P and idx not in seen:
|
| 360 |
+
keep.append(idx)
|
| 361 |
+
seen.add(idx)
|
| 362 |
+
keep.sort()
|
| 363 |
+
|
| 364 |
+
K = len(keep)
|
| 365 |
+
if K:
|
| 366 |
+
w_keep = w.index_select(0, torch.as_tensor(keep, dtype=torch.long))
|
| 367 |
+
sorted_local = torch.argsort(w_keep, descending=True)
|
| 368 |
+
sorted_attr_indices = torch.as_tensor([keep[int(i.item())] for i in sorted_local], dtype=torch.long)
|
| 369 |
+
attr_sum = float(w_keep.sum().item())
|
| 370 |
+
else:
|
| 371 |
+
sorted_attr_indices = torch.zeros((0,), dtype=torch.long)
|
| 372 |
+
attr_sum = 0.0
|
| 373 |
+
|
| 374 |
+
if int(user_ids.shape[1]) != P:
|
| 375 |
+
raise ValueError(
|
| 376 |
+
"Prompt-side attribution length does not match tokenized user prompt length: "
|
| 377 |
+
f"attr P={P}, user_prompt P={int(user_ids.shape[1])}."
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
prompt_positions: List[int]
|
| 381 |
+
if user_prompt_indices is not None:
|
| 382 |
+
prompt_positions = [int(x) for x in user_prompt_indices]
|
| 383 |
+
if len(prompt_positions) != P:
|
| 384 |
+
raise ValueError(
|
| 385 |
+
"user_prompt_indices length does not match prompt-side attribution length: "
|
| 386 |
+
f"indices P={len(prompt_positions)}, attr P={P}."
|
| 387 |
+
)
|
| 388 |
+
if P and max(prompt_positions) >= int(prompt_ids_perturbed.shape[1]):
|
| 389 |
+
raise ValueError("user_prompt_indices contains an out-of-bounds index for formatted prompt ids.")
|
| 390 |
+
else:
|
| 391 |
+
user_start = _find_subsequence_start(formatted_ids[0], user_ids[0])
|
| 392 |
+
if user_start is None:
|
| 393 |
+
raise RuntimeError("Failed to locate user prompt token span inside formatted chat prompt.")
|
| 394 |
+
prompt_positions = [int(user_start) + j for j in range(P)]
|
| 395 |
+
|
| 396 |
+
if K > 0:
|
| 397 |
+
steps = int(k) if k is not None else 0
|
| 398 |
+
if steps <= 0:
|
| 399 |
+
steps = 1
|
| 400 |
+
steps = min(steps, K)
|
| 401 |
+
else:
|
| 402 |
+
steps = 0
|
| 403 |
+
|
| 404 |
+
scores = np.zeros(steps + 1, dtype=np.float64)
|
| 405 |
+
density = np.zeros(steps + 1, dtype=np.float64)
|
| 406 |
+
|
| 407 |
+
scores[0] = score_prompt_ids_with_generation(model, prompt_ids=prompt_ids_perturbed, generation_ids=gen_ids)
|
| 408 |
+
density[0] = 1.0
|
| 409 |
+
|
| 410 |
+
if K == 0:
|
| 411 |
+
return {
|
| 412 |
+
"num_tokens": P,
|
| 413 |
+
"sorted_attr_indices": [],
|
| 414 |
+
"scores_raw": scores.tolist(),
|
| 415 |
+
"density": density.tolist(),
|
| 416 |
+
"normalized_model_response": scores.tolist(),
|
| 417 |
+
"alignment_penalty": np.zeros_like(scores).tolist(),
|
| 418 |
+
"corrected_scores": scores.tolist(),
|
| 419 |
+
"token_deltas_raw": np.zeros(P, dtype=np.float64).tolist(),
|
| 420 |
+
"attr_weights": np.zeros(P, dtype=np.float64).tolist(),
|
| 421 |
+
"metrics": {"RISE": 0.0, "MAS": 0.0, "RISE+AP": 0.0},
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
if attr_sum <= 0:
|
| 425 |
+
density = np.linspace(1.0, 0.0, steps + 1)
|
| 426 |
+
|
| 427 |
+
per_token_delta = np.zeros(P, dtype=np.float64)
|
| 428 |
+
|
| 429 |
+
base = K // steps
|
| 430 |
+
remainder = K % steps
|
| 431 |
+
start = 0
|
| 432 |
+
for step in range(steps):
|
| 433 |
+
size = base + (1 if step < remainder else 0)
|
| 434 |
+
group = sorted_attr_indices[start : start + size]
|
| 435 |
+
start += size
|
| 436 |
+
|
| 437 |
+
for idx_t in group:
|
| 438 |
+
idx = int(idx_t.item())
|
| 439 |
+
abs_pos = int(prompt_positions[idx])
|
| 440 |
+
prompt_ids_perturbed[0, abs_pos] = pad_token_id
|
| 441 |
+
|
| 442 |
+
scores[step + 1] = score_prompt_ids_with_generation(model, prompt_ids=prompt_ids_perturbed, generation_ids=gen_ids)
|
| 443 |
+
if attr_sum > 0:
|
| 444 |
+
dec = float(w.index_select(0, group).sum().item()) / attr_sum
|
| 445 |
+
density[step + 1] = density[step] - dec
|
| 446 |
+
|
| 447 |
+
delta = scores[step] - scores[step + 1]
|
| 448 |
+
for idx_t in group:
|
| 449 |
+
idx = int(idx_t.item())
|
| 450 |
+
per_token_delta[idx] = delta
|
| 451 |
+
|
| 452 |
+
min_normalized_pred = 1.0
|
| 453 |
+
normalized_model_response = scores.copy()
|
| 454 |
+
for i in range(len(scores)):
|
| 455 |
+
normalized_pred = (normalized_model_response[i] - scores[-1]) / (abs(scores[0] - scores[-1]))
|
| 456 |
+
normalized_pred = np.clip(normalized_pred, 0.0, 1.0)
|
| 457 |
+
min_normalized_pred = min(min_normalized_pred, float(normalized_pred))
|
| 458 |
+
normalized_model_response[i] = min_normalized_pred
|
| 459 |
+
|
| 460 |
+
alignment_penalty = np.abs(normalized_model_response - density)
|
| 461 |
+
corrected_scores = normalized_model_response + alignment_penalty
|
| 462 |
+
corrected_scores = corrected_scores.clip(0, 1)
|
| 463 |
+
corrected_scores = (corrected_scores - np.min(corrected_scores)) / (np.max(corrected_scores) - np.min(corrected_scores))
|
| 464 |
+
if np.isnan(corrected_scores).any():
|
| 465 |
+
corrected_scores = np.linspace(1, 0, len(scores))
|
| 466 |
+
|
| 467 |
+
rise = auc(normalized_model_response)
|
| 468 |
+
mas = auc(corrected_scores)
|
| 469 |
+
rise_ap = auc(normalized_model_response + alignment_penalty)
|
| 470 |
+
|
| 471 |
+
if attr_sum > 0:
|
| 472 |
+
attr_weights = np.zeros(P, dtype=np.float64)
|
| 473 |
+
for idx in keep:
|
| 474 |
+
attr_weights[idx] = float(w[idx].item()) / (attr_sum + 1e-12)
|
| 475 |
+
else:
|
| 476 |
+
attr_weights = np.zeros(P, dtype=np.float64)
|
| 477 |
+
|
| 478 |
+
return {
|
| 479 |
+
"num_tokens": P,
|
| 480 |
+
"sorted_attr_indices": [int(i.item()) for i in sorted_attr_indices],
|
| 481 |
+
"scores_raw": scores.tolist(),
|
| 482 |
+
"density": density.tolist(),
|
| 483 |
+
"normalized_model_response": normalized_model_response.tolist(),
|
| 484 |
+
"alignment_penalty": alignment_penalty.tolist(),
|
| 485 |
+
"corrected_scores": corrected_scores.tolist(),
|
| 486 |
+
"token_deltas_raw": per_token_delta.tolist(),
|
| 487 |
+
"attr_weights": attr_weights.tolist(),
|
| 488 |
+
"metrics": {"RISE": rise, "MAS": mas, "RISE+AP": rise_ap},
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def compute_method_attribution(
|
| 493 |
+
method: str,
|
| 494 |
+
example: ds_utils.CachedExample,
|
| 495 |
+
model: Any,
|
| 496 |
+
tokenizer: Any,
|
| 497 |
+
*,
|
| 498 |
+
n_hops: int,
|
| 499 |
+
sink_span: Optional[Tuple[int, int]],
|
| 500 |
+
thinking_span: Optional[Tuple[int, int]],
|
| 501 |
+
chunk_tokens: int,
|
| 502 |
+
sink_chunk_tokens: int,
|
| 503 |
+
attnlrp_neg_handling: str,
|
| 504 |
+
attnlrp_norm_mode: str,
|
| 505 |
+
) -> Tuple[str, Any, llm_attr.LLMAttributionResult]:
|
| 506 |
+
prompt = example.prompt
|
| 507 |
+
target = example.target
|
| 508 |
+
|
| 509 |
+
if method == "ifr":
|
| 510 |
+
if sink_span is None:
|
| 511 |
+
raise ValueError("IFR requires sink_span (use dataset sink_span or pass --sink_span).")
|
| 512 |
+
attributor = llm_attr.LLMIFRAttribution(model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens)
|
| 513 |
+
result = attributor.calculate_ifr_span(prompt, target=target, span=sink_span)
|
| 514 |
+
return "IFR (ifr_span)", attributor, result
|
| 515 |
+
|
| 516 |
+
if method == "ifr_all_positions_output_only":
|
| 517 |
+
if sink_span is None:
|
| 518 |
+
raise ValueError(
|
| 519 |
+
"ifr_all_positions_output_only requires sink_span (use dataset sink_span or pass --sink_span)."
|
| 520 |
+
)
|
| 521 |
+
attributor = llm_attr.LLMIFRAttribution(model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens)
|
| 522 |
+
result = attributor.calculate_ifr_for_all_positions_output_only(
|
| 523 |
+
prompt,
|
| 524 |
+
target=target,
|
| 525 |
+
sink_span=sink_span,
|
| 526 |
+
)
|
| 527 |
+
return "IFR (ifr_all_positions_output_only)", attributor, result
|
| 528 |
+
|
| 529 |
+
if method in ("ft", "ft_ifr"):
|
| 530 |
+
attributor = llm_attr.LLMIFRAttribution(model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens)
|
| 531 |
+
result = attributor.calculate_ifr_multi_hop(
|
| 532 |
+
prompt,
|
| 533 |
+
target=target,
|
| 534 |
+
sink_span=sink_span,
|
| 535 |
+
thinking_span=thinking_span,
|
| 536 |
+
n_hops=int(n_hops),
|
| 537 |
+
)
|
| 538 |
+
return "FT-IFR (ifr_multi_hop)", attributor, result
|
| 539 |
+
|
| 540 |
+
if method in ("ft_improve", "ft_ifr_improve"):
|
| 541 |
+
import ft_ifr_improve
|
| 542 |
+
|
| 543 |
+
attributor = ft_ifr_improve.LLMIFRAttributionImproved(
|
| 544 |
+
model,
|
| 545 |
+
tokenizer,
|
| 546 |
+
chunk_tokens=chunk_tokens,
|
| 547 |
+
sink_chunk_tokens=sink_chunk_tokens,
|
| 548 |
+
)
|
| 549 |
+
result = attributor.calculate_ifr_multi_hop_stop_words(
|
| 550 |
+
prompt,
|
| 551 |
+
target=target,
|
| 552 |
+
sink_span=sink_span,
|
| 553 |
+
thinking_span=thinking_span,
|
| 554 |
+
n_hops=int(n_hops),
|
| 555 |
+
)
|
| 556 |
+
return "FT-IFR (ifr_multi_hop_stop_words)", attributor, result
|
| 557 |
+
|
| 558 |
+
if method == "ft_split_hop":
|
| 559 |
+
import ft_ifr_improve
|
| 560 |
+
|
| 561 |
+
attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(
|
| 562 |
+
model,
|
| 563 |
+
tokenizer,
|
| 564 |
+
chunk_tokens=chunk_tokens,
|
| 565 |
+
sink_chunk_tokens=sink_chunk_tokens,
|
| 566 |
+
)
|
| 567 |
+
result = attributor.calculate_ifr_multi_hop_split_hop(
|
| 568 |
+
prompt,
|
| 569 |
+
target=target,
|
| 570 |
+
sink_span=sink_span,
|
| 571 |
+
thinking_span=thinking_span,
|
| 572 |
+
n_hops=int(n_hops),
|
| 573 |
+
)
|
| 574 |
+
return "FT-IFR (ifr_multi_hop_split_hop)", attributor, result
|
| 575 |
+
|
| 576 |
+
if method == "attnlrp":
|
| 577 |
+
attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 578 |
+
result = attributor.calculate_attnlrp_ft_hop0(
|
| 579 |
+
prompt,
|
| 580 |
+
target=target,
|
| 581 |
+
sink_span=sink_span,
|
| 582 |
+
thinking_span=thinking_span,
|
| 583 |
+
neg_handling=attnlrp_neg_handling,
|
| 584 |
+
norm_mode=attnlrp_norm_mode,
|
| 585 |
+
)
|
| 586 |
+
return "AttnLRP (ft_attnlrp hop0)", attributor, result
|
| 587 |
+
|
| 588 |
+
if method == "ft_attnlrp":
|
| 589 |
+
attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 590 |
+
result = attributor.calculate_attnlrp_aggregated_multi_hop(
|
| 591 |
+
prompt,
|
| 592 |
+
target=target,
|
| 593 |
+
sink_span=sink_span,
|
| 594 |
+
thinking_span=thinking_span,
|
| 595 |
+
n_hops=int(n_hops),
|
| 596 |
+
neg_handling=attnlrp_neg_handling,
|
| 597 |
+
norm_mode=attnlrp_norm_mode,
|
| 598 |
+
)
|
| 599 |
+
return "FT-AttnLRP (attnlrp_aggregated_multi_hop)", attributor, result
|
| 600 |
+
|
| 601 |
+
raise ValueError(f"Unsupported method={method!r}")
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def parse_args() -> argparse.Namespace:
|
| 605 |
+
parser = argparse.ArgumentParser("MAS case study (faithfulness perturbation visualization)")
|
| 606 |
+
parser.add_argument("--dataset", type=str, default="exp/exp2/data/morehopqa.jsonl", help="Dataset name or JSONL path.")
|
| 607 |
+
parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Cache root for dataset names.")
|
| 608 |
+
parser.add_argument("--index", type=int, default=0, help="Sample index (supports negative for reverse).")
|
| 609 |
+
parser.add_argument(
|
| 610 |
+
"--method",
|
| 611 |
+
type=str,
|
| 612 |
+
choices=[
|
| 613 |
+
"ifr",
|
| 614 |
+
"ifr_all_positions_output_only",
|
| 615 |
+
"ft",
|
| 616 |
+
"ft_ifr",
|
| 617 |
+
"ft_improve",
|
| 618 |
+
"ft_ifr_improve",
|
| 619 |
+
"ft_split_hop",
|
| 620 |
+
"attnlrp",
|
| 621 |
+
"ft_attnlrp",
|
| 622 |
+
],
|
| 623 |
+
default="ft",
|
| 624 |
+
)
|
| 625 |
+
parser.add_argument("--model", type=str, default="qwen-8B", help="HF repo id (ignored if --model_path set).")
|
| 626 |
+
parser.add_argument("--model_path", type=str, default=None, help="Local model path to override --model.")
|
| 627 |
+
parser.add_argument("--cuda", type=str, default=None, help="CUDA spec (e.g., '0' or '0,1').")
|
| 628 |
+
parser.add_argument("--cuda_num", type=int, default=0, help="Fallback GPU index when --cuda unset.")
|
| 629 |
+
parser.add_argument("--n_hops", type=int, default=1, help="Number of hops for multi-hop methods.")
|
| 630 |
+
parser.add_argument("--sink_span", type=int, nargs=2, default=None, help="Optional sink span over generation tokens.")
|
| 631 |
+
parser.add_argument("--thinking_span", type=int, nargs=2, default=None, help="Optional thinking span over generation tokens.")
|
| 632 |
+
parser.add_argument("--chunk_tokens", type=int, default=128, help="IFR chunk size.")
|
| 633 |
+
parser.add_argument("--sink_chunk_tokens", type=int, default=32, help="IFR sink chunk size.")
|
| 634 |
+
parser.add_argument(
|
| 635 |
+
"--attnlrp_neg_handling",
|
| 636 |
+
type=str,
|
| 637 |
+
choices=["drop", "abs"],
|
| 638 |
+
default="drop",
|
| 639 |
+
help="FT-AttnLRP: how to handle negative values after each hop (drop=clamp>=0, abs=absolute value).",
|
| 640 |
+
)
|
| 641 |
+
parser.add_argument(
|
| 642 |
+
"--attnlrp_norm_mode",
|
| 643 |
+
type=str,
|
| 644 |
+
choices=["norm", "no_norm"],
|
| 645 |
+
default="norm",
|
| 646 |
+
help="FT-AttnLRP: norm enables per-hop global+thinking normalization + ratios; no_norm disables all three.",
|
| 647 |
+
)
|
| 648 |
+
parser.add_argument("--output_dir", type=str, default="exp/case_study/out", help="Where to write HTML/JSON artifacts.")
|
| 649 |
+
return parser.parse_args()
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def main() -> None:
|
| 653 |
+
args = parse_args()
|
| 654 |
+
device = resolve_device(args.cuda, args.cuda_num)
|
| 655 |
+
if torch.cuda.is_available():
|
| 656 |
+
visible = os.environ.get("CUDA_VISIBLE_DEVICES")
|
| 657 |
+
print(f"[info] CUDA_VISIBLE_DEVICES={visible!r} torch.cuda.device_count()={torch.cuda.device_count()} device={device}")
|
| 658 |
+
|
| 659 |
+
if args.method == "ft_ifr":
|
| 660 |
+
method_key = "ft"
|
| 661 |
+
elif args.method == "ft_ifr_improve":
|
| 662 |
+
method_key = "ft_improve"
|
| 663 |
+
else:
|
| 664 |
+
method_key = args.method
|
| 665 |
+
|
| 666 |
+
model_name = args.model_path if args.model_path is not None else args.model
|
| 667 |
+
model, tokenizer = load_model(model_name, device)
|
| 668 |
+
|
| 669 |
+
example, ds_name = load_example(args.dataset, args.index, Path(args.data_root))
|
| 670 |
+
|
| 671 |
+
sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
|
| 672 |
+
thinking_span = (
|
| 673 |
+
tuple(args.thinking_span)
|
| 674 |
+
if args.thinking_span is not None
|
| 675 |
+
else tuple(example.thinking_span) if example.thinking_span else None
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
method_label, attributor, attr_result = compute_method_attribution(
|
| 679 |
+
method_key,
|
| 680 |
+
example,
|
| 681 |
+
model,
|
| 682 |
+
tokenizer,
|
| 683 |
+
n_hops=args.n_hops,
|
| 684 |
+
sink_span=sink_span,
|
| 685 |
+
thinking_span=thinking_span,
|
| 686 |
+
chunk_tokens=args.chunk_tokens,
|
| 687 |
+
sink_chunk_tokens=args.sink_chunk_tokens,
|
| 688 |
+
attnlrp_neg_handling=args.attnlrp_neg_handling,
|
| 689 |
+
attnlrp_norm_mode=args.attnlrp_norm_mode,
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
indices_to_explain = example.indices_to_explain or example.sink_span
|
| 693 |
+
if not (isinstance(indices_to_explain, list) and len(indices_to_explain) == 2):
|
| 694 |
+
raise ValueError("MAS case study requires token-span indices_to_explain=[start_tok,end_tok] (e.g. sink_span).")
|
| 695 |
+
seq_attr, row_attr, rec_attr = attr_result.get_all_token_attrs(indices_to_explain)
|
| 696 |
+
|
| 697 |
+
prompt_tokens = decode_text_into_tokens(tokenizer, " " + example.prompt)
|
| 698 |
+
generation_text = example.target if example.target is not None else (getattr(attributor, "generation", None) or "")
|
| 699 |
+
|
| 700 |
+
variant_specs = [
|
| 701 |
+
("seq", "Seq attribution", seq_attr),
|
| 702 |
+
("row", "Row attribution", row_attr),
|
| 703 |
+
("recursive", "Recursive attribution", rec_attr),
|
| 704 |
+
]
|
| 705 |
+
|
| 706 |
+
formatted = format_prompt(tokenizer, " " + example.prompt)
|
| 707 |
+
prompt_ids = tokenizer(formatted, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
|
| 708 |
+
gen_ids = tokenizer(
|
| 709 |
+
generation_text + (tokenizer.eos_token or ""),
|
| 710 |
+
return_tensors="pt",
|
| 711 |
+
add_special_tokens=False,
|
| 712 |
+
).input_ids.to(model.device)
|
| 713 |
+
base_score = score_prompt_ids_with_generation(model, prompt_ids=prompt_ids, generation_ids=gen_ids)
|
| 714 |
+
|
| 715 |
+
panels_raw: List[Dict[str, Any]] = []
|
| 716 |
+
panels_display: List[Dict[str, Any]] = []
|
| 717 |
+
|
| 718 |
+
for variant_key, variant_label, variant_attr in variant_specs:
|
| 719 |
+
prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0]) # cols=(P+G), rows=G
|
| 720 |
+
attr_prompt = variant_attr[:, :prompt_len]
|
| 721 |
+
keep_prompt_token_indices = None
|
| 722 |
+
if method_key == "ft_improve":
|
| 723 |
+
import ft_ifr_improve
|
| 724 |
+
|
| 725 |
+
keep_prompt_token_indices = ft_ifr_improve.keep_token_indices(list(getattr(attributor, "user_prompt_tokens", []) or []))
|
| 726 |
+
trace = mas_trace(
|
| 727 |
+
model,
|
| 728 |
+
tokenizer,
|
| 729 |
+
attribution=attr_prompt.to(device="cpu"),
|
| 730 |
+
prompt=example.prompt,
|
| 731 |
+
generation=generation_text,
|
| 732 |
+
user_prompt_indices=getattr(attributor, "user_prompt_indices", None),
|
| 733 |
+
keep_prompt_token_indices=keep_prompt_token_indices,
|
| 734 |
+
)
|
| 735 |
+
trace["variant"] = variant_key
|
| 736 |
+
trace["variant_label"] = variant_label
|
| 737 |
+
|
| 738 |
+
panel_raw = {
|
| 739 |
+
"variant": variant_key,
|
| 740 |
+
"variant_label": variant_label,
|
| 741 |
+
"metrics": trace.get("metrics"),
|
| 742 |
+
"sorted_attr_indices": trace.get("sorted_attr_indices"),
|
| 743 |
+
"attr_weights": trace.get("attr_weights"),
|
| 744 |
+
"token_deltas_raw": trace.get("token_deltas_raw"),
|
| 745 |
+
"mas_trace": trace,
|
| 746 |
+
}
|
| 747 |
+
panels_raw.append(panel_raw)
|
| 748 |
+
|
| 749 |
+
panel_display = {
|
| 750 |
+
"variant": variant_key,
|
| 751 |
+
"variant_label": variant_label,
|
| 752 |
+
"metrics": trace.get("metrics"),
|
| 753 |
+
"sorted_attr_indices": trace.get("sorted_attr_indices"),
|
| 754 |
+
"attr_weights": trace.get("attr_weights"),
|
| 755 |
+
"token_deltas_raw": trace.get("token_deltas_raw"),
|
| 756 |
+
}
|
| 757 |
+
panels_display.append(panel_display)
|
| 758 |
+
|
| 759 |
+
case_meta: Dict[str, Any] = {
|
| 760 |
+
"dataset": ds_name,
|
| 761 |
+
"index": args.index,
|
| 762 |
+
"mode": "mas",
|
| 763 |
+
"attr_method": method_key,
|
| 764 |
+
"attr_method_label": method_label,
|
| 765 |
+
"sink_span": sink_span,
|
| 766 |
+
"thinking_span": thinking_span,
|
| 767 |
+
"n_hops": int(args.n_hops),
|
| 768 |
+
"attnlrp_neg_handling": args.attnlrp_neg_handling if method_key in ("attnlrp", "ft_attnlrp") else None,
|
| 769 |
+
"attnlrp_norm_mode": args.attnlrp_norm_mode if method_key in ("attnlrp", "ft_attnlrp") else None,
|
| 770 |
+
"attnlrp_ratio_enabled": (args.attnlrp_norm_mode == "norm") if method_key in ("attnlrp", "ft_attnlrp") else None,
|
| 771 |
+
"base_score": float(base_score),
|
| 772 |
+
}
|
| 773 |
+
|
| 774 |
+
record = {
|
| 775 |
+
"meta": case_meta,
|
| 776 |
+
"prompt": example.prompt,
|
| 777 |
+
"target": example.target,
|
| 778 |
+
"generation": generation_text,
|
| 779 |
+
"prompt_tokens": prompt_tokens,
|
| 780 |
+
"panels": panels_raw,
|
| 781 |
+
}
|
| 782 |
+
|
| 783 |
+
out_dir = Path(args.output_dir)
|
| 784 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 785 |
+
stem = make_output_stem(ds_name, args.index, method_key)
|
| 786 |
+
json_path = out_dir / f"{stem}.json"
|
| 787 |
+
html_path = out_dir / f"{stem}.html"
|
| 788 |
+
|
| 789 |
+
with json_path.open("w", encoding="utf-8") as f:
|
| 790 |
+
json.dump(record, f, ensure_ascii=False, indent=2)
|
| 791 |
+
|
| 792 |
+
html = viz.render_mas_token_html(
|
| 793 |
+
case_meta,
|
| 794 |
+
prompt_tokens=prompt_tokens,
|
| 795 |
+
panels=panels_display,
|
| 796 |
+
generation=generation_text,
|
| 797 |
+
)
|
| 798 |
+
html_path.write_text(html, encoding="utf-8")
|
| 799 |
+
|
| 800 |
+
print(f"[done] wrote {json_path}")
|
| 801 |
+
print(f"[done] wrote {html_path}")
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
if __name__ == "__main__":
|
| 805 |
+
main()
|
exp/case_study/viz.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HTML helpers for visualizing hop-wise IFR/AttnLRP attributions."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import Any, Dict, List, Optional, Sequence
|
| 7 |
+
|
| 8 |
+
from html import escape
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
TOKEN_SCALE_QUANTILE = 0.995
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _robust_abs_max(scores: Sequence[float], *, quantile: float = TOKEN_SCALE_QUANTILE) -> float:
|
| 15 |
+
"""Return a robust abs max to avoid a single outlier washing out the colormap.
|
| 16 |
+
|
| 17 |
+
Uses a high quantile (default: p99.5) over |scores|. Top outliers saturate.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
abs_vals: List[float] = []
|
| 21 |
+
for x in scores:
|
| 22 |
+
try:
|
| 23 |
+
v = float(x)
|
| 24 |
+
except Exception:
|
| 25 |
+
continue
|
| 26 |
+
if math.isnan(v):
|
| 27 |
+
continue
|
| 28 |
+
abs_vals.append(abs(v))
|
| 29 |
+
|
| 30 |
+
if not abs_vals:
|
| 31 |
+
return 0.0
|
| 32 |
+
|
| 33 |
+
abs_vals.sort()
|
| 34 |
+
q = float(quantile)
|
| 35 |
+
if q < 0.0:
|
| 36 |
+
q = 0.0
|
| 37 |
+
if q > 1.0:
|
| 38 |
+
q = 1.0
|
| 39 |
+
idx = int(q * (len(abs_vals) - 1))
|
| 40 |
+
return float(abs_vals[idx])
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _color_for_score(score: float, max_score: float) -> str:
|
| 44 |
+
if max_score <= 0:
|
| 45 |
+
return "background-color: rgba(245,245,245,0.7);"
|
| 46 |
+
ratio = min(1.0, score / (max_score + 1e-12))
|
| 47 |
+
r = 255
|
| 48 |
+
g = int(235 - 90 * ratio)
|
| 49 |
+
b = int(220 - 160 * ratio)
|
| 50 |
+
alpha = 0.25 + 0.55 * ratio
|
| 51 |
+
return f"background-color: rgba({r}, {g}, {b}, {alpha});"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _render_sentence_list(title: str, sentences: Sequence[str], scores: Sequence[float], max_score: float) -> str:
|
| 55 |
+
rows: List[str] = []
|
| 56 |
+
for sent, sc in zip(sentences, scores):
|
| 57 |
+
style = _color_for_score(abs(float(sc)), max_score)
|
| 58 |
+
rows.append(
|
| 59 |
+
f'<div class="sent-row" style="{style}"><span class="score">{sc:.4f}</span>'
|
| 60 |
+
f'<span class="text">{escape(sent)}</span></div>'
|
| 61 |
+
)
|
| 62 |
+
return f"""
|
| 63 |
+
<div class="sent-block">
|
| 64 |
+
<div class="sent-title">{escape(title)}</div>
|
| 65 |
+
{''.join(rows)}
|
| 66 |
+
</div>
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _render_tokens(
|
| 71 |
+
tokens: Sequence[str],
|
| 72 |
+
scores: Sequence[float],
|
| 73 |
+
max_score: float,
|
| 74 |
+
roles: Sequence[str],
|
| 75 |
+
) -> str:
|
| 76 |
+
spans: List[str] = []
|
| 77 |
+
if max_score <= 0:
|
| 78 |
+
max_score = 1e-8
|
| 79 |
+
for idx, tok in enumerate(tokens):
|
| 80 |
+
score = float(scores[idx]) if idx < len(scores) else 0.0
|
| 81 |
+
style = _color_for_score(abs(score), max_score)
|
| 82 |
+
role = roles[idx] if idx < len(roles) else "gen"
|
| 83 |
+
safe_tok = escape(tok)
|
| 84 |
+
spans.append(
|
| 85 |
+
f'<span class="tok {role}" title="idx={idx}, score={score:.6f}" style="{style}">{safe_tok}</span>'
|
| 86 |
+
)
|
| 87 |
+
return "".join(spans)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _render_top_table(top_items: List[Dict[str, Any]]) -> str:
|
| 91 |
+
if not top_items:
|
| 92 |
+
return "<div class='top-table'><em>No attribution mass.</em></div>"
|
| 93 |
+
|
| 94 |
+
header = "<div class='top-row top-header'><span>Rank</span><span>Idx</span><span>Score</span><span>Sentence</span></div>"
|
| 95 |
+
body_rows = []
|
| 96 |
+
for rank, item in enumerate(top_items, start=1):
|
| 97 |
+
body_rows.append(
|
| 98 |
+
f"<div class='top-row'><span>{rank}</span><span>{item['idx']}</span>"
|
| 99 |
+
f"<span>{item['score']:.4f}</span><span>{escape(item['sentence'])}</span></div>"
|
| 100 |
+
)
|
| 101 |
+
return f"<div class='top-table'>{header}{''.join(body_rows)}</div>"
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def render_case_html(
|
| 105 |
+
case_meta: Dict[str, Any],
|
| 106 |
+
*,
|
| 107 |
+
token_view_raw: Dict[str, Any],
|
| 108 |
+
token_view_prompt: Dict[str, Any],
|
| 109 |
+
context: Optional[Dict[str, Any]] = None,
|
| 110 |
+
hops_sent: Optional[Sequence[Dict[str, Any]]] = None,
|
| 111 |
+
) -> str:
|
| 112 |
+
has_sentence_view = bool(context) and bool(hops_sent)
|
| 113 |
+
prompt_len = len((context or {}).get("prompt_sentences") or []) if has_sentence_view else 0
|
| 114 |
+
gen_len = len((context or {}).get("generation_sentences") or []) if has_sentence_view else 0
|
| 115 |
+
|
| 116 |
+
prompt_max = 0.0
|
| 117 |
+
gen_max = 0.0
|
| 118 |
+
if has_sentence_view:
|
| 119 |
+
prompt_max = max(
|
| 120 |
+
(
|
| 121 |
+
max(h["sentence_scores_raw"][:prompt_len])
|
| 122 |
+
for h in (hops_sent or [])
|
| 123 |
+
if h.get("sentence_scores_raw") and h["sentence_scores_raw"][:prompt_len]
|
| 124 |
+
),
|
| 125 |
+
default=0.0,
|
| 126 |
+
)
|
| 127 |
+
gen_max = max(
|
| 128 |
+
(
|
| 129 |
+
max(h["sentence_scores_raw"][prompt_len:])
|
| 130 |
+
for h in (hops_sent or [])
|
| 131 |
+
if h.get("sentence_scores_raw") and h["sentence_scores_raw"][prompt_len:]
|
| 132 |
+
),
|
| 133 |
+
default=0.0,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
raw_hops = token_view_raw.get("hops", []) or []
|
| 137 |
+
prompt_hops = token_view_prompt.get("hops", []) or []
|
| 138 |
+
if len(raw_hops) != len(prompt_hops):
|
| 139 |
+
raise ValueError(
|
| 140 |
+
"token_view_raw and token_view_prompt must have the same number of panels: "
|
| 141 |
+
f"raw={len(raw_hops)} prompt={len(prompt_hops)}"
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
hop_sections: List[str] = []
|
| 145 |
+
hop_count = len(prompt_hops)
|
| 146 |
+
mode = case_meta.get("mode", "ft")
|
| 147 |
+
ifr_view = case_meta.get("ifr_view", "aggregate")
|
| 148 |
+
sink_span = case_meta.get("sink_span")
|
| 149 |
+
panel_titles = case_meta.get("panel_titles")
|
| 150 |
+
|
| 151 |
+
def _panel_title(panel_idx: int) -> str:
|
| 152 |
+
if isinstance(panel_titles, list) and panel_idx < len(panel_titles):
|
| 153 |
+
try:
|
| 154 |
+
title = panel_titles[panel_idx]
|
| 155 |
+
except Exception:
|
| 156 |
+
title = None
|
| 157 |
+
if title is not None:
|
| 158 |
+
return str(title)
|
| 159 |
+
if mode in ("ft", "ft_improve", "ft_split_hop", "ifr_in_all_gen", "ft_attnlrp"):
|
| 160 |
+
return f"Hop {panel_idx}"
|
| 161 |
+
if mode == "ifr_all_positions_output_only":
|
| 162 |
+
return f"IFR output-only panel {panel_idx}"
|
| 163 |
+
if mode == "ifr_all_positions":
|
| 164 |
+
return f"IFR all-positions panel {panel_idx}"
|
| 165 |
+
if mode == "attnlrp":
|
| 166 |
+
return "AttnLRP (sink-span aggregate)"
|
| 167 |
+
return "IFR (sink-span aggregate)"
|
| 168 |
+
|
| 169 |
+
for hop_idx in range(hop_count):
|
| 170 |
+
raw_entry = raw_hops[hop_idx]
|
| 171 |
+
raw_scores = raw_entry.get("token_scores") or []
|
| 172 |
+
raw_mass = float(raw_entry.get("total_mass", 0.0))
|
| 173 |
+
raw_scale = _robust_abs_max(raw_scores)
|
| 174 |
+
if raw_scale <= 0:
|
| 175 |
+
raw_scale = float(raw_entry.get("token_score_max") or 0.0)
|
| 176 |
+
if raw_scale <= 0:
|
| 177 |
+
raw_scale = 1e-8
|
| 178 |
+
|
| 179 |
+
prompt_entry = prompt_hops[hop_idx]
|
| 180 |
+
prompt_scores = prompt_entry.get("token_scores") or []
|
| 181 |
+
prompt_mass = float(prompt_entry.get("total_mass", 0.0))
|
| 182 |
+
prompt_scale = _robust_abs_max(prompt_scores)
|
| 183 |
+
if prompt_scale <= 0:
|
| 184 |
+
prompt_scale = float(prompt_entry.get("token_score_max") or 0.0)
|
| 185 |
+
if prompt_scale <= 0:
|
| 186 |
+
prompt_scale = 1e-8
|
| 187 |
+
|
| 188 |
+
tok_raw_html = f"""
|
| 189 |
+
<div class="tokens-block">
|
| 190 |
+
<div class="tokens-title">{escape(token_view_raw.get("label", "Pre-trim token-level heatmap (full)"))}</div>
|
| 191 |
+
<div class="tokens-row">
|
| 192 |
+
{_render_tokens(token_view_raw.get("tokens", []), raw_scores, raw_scale, token_view_raw.get("roles", []))}
|
| 193 |
+
</div>
|
| 194 |
+
</div>
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
tok_prompt_html = f"""
|
| 198 |
+
<div class="tokens-block">
|
| 199 |
+
<div class="tokens-title">{escape(token_view_prompt.get("label", "Prompt-only token-level heatmap"))}</div>
|
| 200 |
+
<div class="tokens-row">
|
| 201 |
+
{_render_tokens(token_view_prompt.get("tokens", []), prompt_scores, prompt_scale, token_view_prompt.get("roles", []))}
|
| 202 |
+
</div>
|
| 203 |
+
</div>
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
sentence_html = ""
|
| 207 |
+
top_html = ""
|
| 208 |
+
if has_sentence_view and hop_idx < len(hops_sent or []):
|
| 209 |
+
hop = (hops_sent or [])[hop_idx]
|
| 210 |
+
raw_scores = hop.get("sentence_scores_raw") or []
|
| 211 |
+
prompt_scores = raw_scores[:prompt_len]
|
| 212 |
+
gen_scores = raw_scores[prompt_len:]
|
| 213 |
+
# Sentence view is not used by the current case-study runner; keep the path for completeness.
|
| 214 |
+
sentence_html = f"""
|
| 215 |
+
<div class="columns">
|
| 216 |
+
{_render_sentence_list('Prompt sentences', (context or {}).get('prompt_sentences') or [], prompt_scores, prompt_max)}
|
| 217 |
+
{_render_sentence_list('Generation sentences', (context or {}).get('generation_sentences') or [], gen_scores, gen_max)}
|
| 218 |
+
</div>
|
| 219 |
+
"""
|
| 220 |
+
top_html = f"""
|
| 221 |
+
<div class="top-wrap">
|
| 222 |
+
<div class="section-label">Top sentences (all)</div>
|
| 223 |
+
{_render_top_table(hop.get('top_sentences') or [])}
|
| 224 |
+
</div>
|
| 225 |
+
"""
|
| 226 |
+
|
| 227 |
+
hop_sections.append(
|
| 228 |
+
f"""
|
| 229 |
+
<div class="hop">
|
| 230 |
+
<div class="hop-header">
|
| 231 |
+
<div class="hop-title">{escape(_panel_title(hop_idx))}</div>
|
| 232 |
+
<div class="hop-meta">
|
| 233 |
+
raw mass: {raw_mass:.6f} | raw scale(p{int(TOKEN_SCALE_QUANTILE*1000)/10:.1f} abs): {raw_scale:.6g}
|
| 234 |
+
|
|
| 235 |
+
prompt mass: {prompt_mass:.6f} | prompt scale(p{int(TOKEN_SCALE_QUANTILE*1000)/10:.1f} abs): {prompt_scale:.6g}
|
| 236 |
+
</div>
|
| 237 |
+
</div>
|
| 238 |
+
{tok_raw_html}
|
| 239 |
+
{tok_prompt_html}
|
| 240 |
+
{sentence_html}
|
| 241 |
+
{top_html}
|
| 242 |
+
</div>
|
| 243 |
+
"""
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
thinking_ratios = case_meta.get("thinking_ratios") or []
|
| 247 |
+
ratios_str = ", ".join(f"{r:.4f}" for r in thinking_ratios) if thinking_ratios else "N/A"
|
| 248 |
+
|
| 249 |
+
if mode == "ft":
|
| 250 |
+
mode_label = "FT Multi-hop (IFR)"
|
| 251 |
+
elif mode == "ifr_in_all_gen":
|
| 252 |
+
mode_label = "IFR In-all-gen (multi-hop)"
|
| 253 |
+
elif mode == "ifr":
|
| 254 |
+
mode_label = "IFR Standard"
|
| 255 |
+
elif mode == "ifr_all_positions":
|
| 256 |
+
mode_label = "IFR All-positions"
|
| 257 |
+
elif mode == "ifr_all_positions_output_only":
|
| 258 |
+
mode_label = "IFR Output-only (all positions)"
|
| 259 |
+
elif mode == "attnlrp":
|
| 260 |
+
mode_label = "AttnLRP"
|
| 261 |
+
elif mode == "ft_attnlrp":
|
| 262 |
+
mode_label = "FT Multi-hop (AttnLRP)"
|
| 263 |
+
else:
|
| 264 |
+
mode_label = str(mode)
|
| 265 |
+
|
| 266 |
+
if mode in ("ft", "ifr_in_all_gen", "ft_attnlrp"):
|
| 267 |
+
view_key = "Recursive hops"
|
| 268 |
+
view_val = case_meta.get("n_hops")
|
| 269 |
+
elif mode in ("ifr", "ifr_all_positions", "ifr_all_positions_output_only"):
|
| 270 |
+
view_key = "IFR view"
|
| 271 |
+
view_val = ifr_view
|
| 272 |
+
elif mode == "attnlrp":
|
| 273 |
+
view_key = "AttnLRP view"
|
| 274 |
+
view_val = "ft_hop0_span_aggregate"
|
| 275 |
+
else:
|
| 276 |
+
view_key = "View"
|
| 277 |
+
view_val = "N/A"
|
| 278 |
+
|
| 279 |
+
scale_row = f"<div>Token scale: per-panel per-view p{int(TOKEN_SCALE_QUANTILE*1000)/10:.1f}(|score|)</div>"
|
| 280 |
+
neg_handling = case_meta.get("attnlrp_neg_handling")
|
| 281 |
+
norm_mode = case_meta.get("attnlrp_norm_mode")
|
| 282 |
+
ratio_enabled = case_meta.get("attnlrp_ratio_enabled")
|
| 283 |
+
attn_rows = []
|
| 284 |
+
if neg_handling:
|
| 285 |
+
attn_rows.append(f"<div>FT-AttnLRP neg_handling: {escape(str(neg_handling))}</div>")
|
| 286 |
+
if norm_mode:
|
| 287 |
+
attn_rows.append(f"<div>FT-AttnLRP norm_mode: {escape(str(norm_mode))}</div>")
|
| 288 |
+
if ratio_enabled is not None:
|
| 289 |
+
attn_rows.append(f"<div>FT-AttnLRP ratio_enabled: {escape(str(bool(ratio_enabled)))}</div>")
|
| 290 |
+
|
| 291 |
+
header = f"""
|
| 292 |
+
<div class="header">
|
| 293 |
+
<div>
|
| 294 |
+
<div class="title">{escape(mode_label)} Case Study</div>
|
| 295 |
+
<div class="subtitle">Dataset: {escape(str(case_meta.get('dataset')))} | index: {case_meta.get('index')}</div>
|
| 296 |
+
</div>
|
| 297 |
+
<div class="meta">
|
| 298 |
+
<div>Sink span (gen idx): {escape(str(case_meta.get('sink_span')))}</div>
|
| 299 |
+
<div>Thinking span (gen idx): {escape(str(case_meta.get('thinking_span')))}</div>
|
| 300 |
+
<div>Panels: {hop_count}</div>
|
| 301 |
+
<div>{escape(str(view_key))}: {escape(str(view_val))}</div>
|
| 302 |
+
{scale_row}
|
| 303 |
+
{''.join(attn_rows)}
|
| 304 |
+
<div>Thinking ratios: {ratios_str}</div>
|
| 305 |
+
</div>
|
| 306 |
+
</div>
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
style = """
|
| 310 |
+
<style>
|
| 311 |
+
body { font-family: "Inter", "Helvetica Neue", Arial, sans-serif; margin: 0; padding: 24px; background: #fcfcff; color: #1f2933; }
|
| 312 |
+
.title { font-size: 24px; font-weight: 700; }
|
| 313 |
+
.subtitle { font-size: 14px; color: #566; margin-top: 4px; }
|
| 314 |
+
.header { display: flex; justify-content: space-between; align-items: flex-start; gap: 16px; padding-bottom: 16px; border-bottom: 1px solid #e5e8ee; }
|
| 315 |
+
.meta { font-size: 13px; color: #334; line-height: 1.6; }
|
| 316 |
+
.hop { margin-top: 20px; padding: 16px; border: 1px solid #e5e8ee; border-radius: 10px; background: #fff; box-shadow: 0 2px 6px rgba(0,0,0,0.04); }
|
| 317 |
+
.hop-header { display: flex; justify-content: space-between; align-items: center; }
|
| 318 |
+
.hop-title { font-weight: 600; font-size: 16px; }
|
| 319 |
+
.hop-meta { font-size: 12px; color: #556; }
|
| 320 |
+
.tokens-block { margin-top: 12px; border: 1px solid #eef1f6; border-radius: 8px; padding: 10px; background: #f9fbff; }
|
| 321 |
+
.tokens-title { font-size: 13px; font-weight: 600; margin-bottom: 8px; color: #263; }
|
| 322 |
+
.tokens-row { font-family: "SFMono-Regular", Consolas, monospace; font-size: 12px; line-height: 1.8; word-break: break-word; }
|
| 323 |
+
.tok { display: inline; padding: 2px 1px; margin: 0 0px; border-radius: 3px; }
|
| 324 |
+
.tok.prompt { border-bottom: 1px dashed #6b8fb8; }
|
| 325 |
+
.tok.user { border-bottom: 1px dashed #4f72c7; }
|
| 326 |
+
.tok.template { border-bottom: 1px dashed #9aa9c0; }
|
| 327 |
+
.tok.think { border-bottom: 1px dashed #8ba86b; }
|
| 328 |
+
.tok.output { border-bottom: 1px dashed #c78a6e; }
|
| 329 |
+
.tok.gen { border-bottom: 1px dashed #999; }
|
| 330 |
+
.tok:hover { outline: 1px solid #8899aa; }
|
| 331 |
+
.columns { display: grid; grid-template-columns: repeat(auto-fit, minmax(260px, 1fr)); gap: 12px; margin-top: 12px; }
|
| 332 |
+
.sent-block { padding: 8px; border: 1px solid #eef1f6; border-radius: 8px; background: #f9fbff; }
|
| 333 |
+
.sent-title { font-weight: 600; font-size: 13px; margin-bottom: 6px; color: #263; }
|
| 334 |
+
.sent-row { padding: 6px 8px; border-radius: 6px; margin-bottom: 6px; display: flex; gap: 8px; align-items: flex-start; }
|
| 335 |
+
.sent-row:last-child { margin-bottom: 0; }
|
| 336 |
+
.sent-row .score { font-family: "SFMono-Regular", Consolas, monospace; font-size: 12px; color: #233; min-width: 60px; }
|
| 337 |
+
.sent-row .text { flex: 1; font-size: 13px; }
|
| 338 |
+
.top-wrap { margin-top: 10px; }
|
| 339 |
+
.section-label { font-size: 13px; font-weight: 600; margin-bottom: 6px; color: #263; }
|
| 340 |
+
.top-table { border: 1px solid #eef1f6; border-radius: 8px; background: #fff; }
|
| 341 |
+
.top-row { display: grid; grid-template-columns: 50px 50px 80px 1fr; padding: 6px 8px; gap: 8px; font-size: 12px; }
|
| 342 |
+
.top-header { background: #f3f6fb; font-weight: 700; color: #223; }
|
| 343 |
+
.top-row:nth-child(odd):not(.top-header) { background: #fbfdff; }
|
| 344 |
+
</style>
|
| 345 |
+
"""
|
| 346 |
+
|
| 347 |
+
title = f"{mode_label} Case Study"
|
| 348 |
+
html = f"""<!DOCTYPE html>
|
| 349 |
+
<html>
|
| 350 |
+
<head>
|
| 351 |
+
<meta charset="utf-8" />
|
| 352 |
+
<title>{escape(title)}</title>
|
| 353 |
+
{style}
|
| 354 |
+
</head>
|
| 355 |
+
<body>
|
| 356 |
+
{header}
|
| 357 |
+
{''.join(hop_sections)}
|
| 358 |
+
</body>
|
| 359 |
+
</html>"""
|
| 360 |
+
return html
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def _render_sentence_spans(title: str, sentences: Sequence[str], scores: Sequence[float]) -> str:
|
| 364 |
+
max_abs = max((abs(float(x)) for x in scores), default=0.0)
|
| 365 |
+
spans: List[str] = []
|
| 366 |
+
for idx, sentence in enumerate(sentences):
|
| 367 |
+
score = float(scores[idx]) if idx < len(scores) else 0.0
|
| 368 |
+
style = _color_for_score(abs(score), max_abs)
|
| 369 |
+
spans.append(
|
| 370 |
+
f'<span class="sent-span" title="idx={idx}, score={score:.6f}" style="{style}">{escape(sentence)}</span>'
|
| 371 |
+
)
|
| 372 |
+
return f"""
|
| 373 |
+
<div class="sentmap">
|
| 374 |
+
<div class="sentmap-title">{escape(title)}</div>
|
| 375 |
+
<div class="sentmap-text">{''.join(spans)}</div>
|
| 376 |
+
</div>
|
| 377 |
+
"""
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def _render_token_spans(title: str, tokens: Sequence[str], scores: Sequence[float]) -> str:
|
| 381 |
+
max_abs = max((abs(float(x)) for x in scores), default=0.0)
|
| 382 |
+
spans: List[str] = []
|
| 383 |
+
for idx, tok in enumerate(tokens):
|
| 384 |
+
score = float(scores[idx]) if idx < len(scores) else 0.0
|
| 385 |
+
style = _color_for_score(abs(score), max_abs)
|
| 386 |
+
spans.append(
|
| 387 |
+
f'<span class="tok-span" title="idx={idx}, score={score:.6f}" style="{style}">{escape(tok)}</span>'
|
| 388 |
+
)
|
| 389 |
+
return f"""
|
| 390 |
+
<div class="tokmap">
|
| 391 |
+
<div class="tokmap-title">{escape(title)}</div>
|
| 392 |
+
<div class="tokmap-text">{''.join(spans)}</div>
|
| 393 |
+
</div>
|
| 394 |
+
"""
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def render_mas_sentence_html(
|
| 398 |
+
case_meta: Dict[str, Any],
|
| 399 |
+
*,
|
| 400 |
+
prompt_sentences: Sequence[str],
|
| 401 |
+
panels: Sequence[Dict[str, Any]],
|
| 402 |
+
generation: Optional[str] = None,
|
| 403 |
+
) -> str:
|
| 404 |
+
"""Render MAS sentence-level diagnostics (attribution / pure ablation / guided marginal)."""
|
| 405 |
+
|
| 406 |
+
method_label = case_meta.get("attr_method_label") or case_meta.get("attr_method") or "Unknown method"
|
| 407 |
+
title = f"MAS Sentence Study ({method_label})"
|
| 408 |
+
|
| 409 |
+
neg_handling = case_meta.get("attnlrp_neg_handling")
|
| 410 |
+
norm_mode = case_meta.get("attnlrp_norm_mode")
|
| 411 |
+
ratio_enabled = case_meta.get("attnlrp_ratio_enabled")
|
| 412 |
+
attn_rows = []
|
| 413 |
+
if neg_handling:
|
| 414 |
+
attn_rows.append(f"<div>FT-AttnLRP neg_handling: {escape(str(neg_handling))}</div>")
|
| 415 |
+
if norm_mode:
|
| 416 |
+
attn_rows.append(f"<div>FT-AttnLRP norm_mode: {escape(str(norm_mode))}</div>")
|
| 417 |
+
if ratio_enabled is not None:
|
| 418 |
+
attn_rows.append(f"<div>FT-AttnLRP ratio_enabled: {escape(str(bool(ratio_enabled)))}</div>")
|
| 419 |
+
|
| 420 |
+
base_score = case_meta.get("base_score")
|
| 421 |
+
base_score_row = f"<div>Base score: {float(base_score):.6f}</div>" if isinstance(base_score, (int, float)) else ""
|
| 422 |
+
|
| 423 |
+
gen_block = ""
|
| 424 |
+
if isinstance(generation, str) and generation:
|
| 425 |
+
gen_block = f"""
|
| 426 |
+
<div class="text-block">
|
| 427 |
+
<div class="text-title">Generation (scored)</div>
|
| 428 |
+
<div class="text-body">{escape(generation)}</div>
|
| 429 |
+
</div>
|
| 430 |
+
"""
|
| 431 |
+
|
| 432 |
+
header = f"""
|
| 433 |
+
<div class="header">
|
| 434 |
+
<div>
|
| 435 |
+
<div class="title">{escape(title)}</div>
|
| 436 |
+
<div class="subtitle">Dataset: {escape(str(case_meta.get('dataset')))} | index: {case_meta.get('index')}</div>
|
| 437 |
+
</div>
|
| 438 |
+
<div class="meta">
|
| 439 |
+
<div>Attribution method: {escape(str(case_meta.get('attr_method')))}</div>
|
| 440 |
+
<div>Sink span (gen idx): {escape(str(case_meta.get('sink_span')))}</div>
|
| 441 |
+
<div>Thinking span (gen idx): {escape(str(case_meta.get('thinking_span')))}</div>
|
| 442 |
+
<div>Panels: {len(panels)}</div>
|
| 443 |
+
{''.join(attn_rows)}
|
| 444 |
+
{base_score_row}
|
| 445 |
+
</div>
|
| 446 |
+
</div>
|
| 447 |
+
"""
|
| 448 |
+
|
| 449 |
+
panel_sections: List[str] = []
|
| 450 |
+
for panel in panels:
|
| 451 |
+
label = panel.get("variant_label") or panel.get("panel_label") or panel.get("variant") or "Panel"
|
| 452 |
+
metrics = panel.get("metrics") or {}
|
| 453 |
+
metrics_str = " | ".join(
|
| 454 |
+
f"{k}: {float(metrics[k]):.4f}" if isinstance(metrics.get(k), (int, float)) else f"{k}: {metrics.get(k)}"
|
| 455 |
+
for k in ("RISE", "MAS", "RISE+AP")
|
| 456 |
+
if k in metrics
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
attr_weights = panel.get("attr_weights") or []
|
| 460 |
+
pure_deltas = panel.get("pure_sentence_deltas_raw") or []
|
| 461 |
+
guided_deltas = panel.get("guided_sentence_deltas_raw") or panel.get("sentence_deltas_raw") or []
|
| 462 |
+
rank_order = panel.get("sorted_attr_indices") or []
|
| 463 |
+
rank_str = ", ".join(str(int(x)) for x in rank_order) if rank_order else "N/A"
|
| 464 |
+
|
| 465 |
+
panel_sections.append(
|
| 466 |
+
f"""
|
| 467 |
+
<div class="panel">
|
| 468 |
+
<div class="panel-header">
|
| 469 |
+
<div class="panel-title">{escape(str(label))}</div>
|
| 470 |
+
<div class="panel-meta">{escape(metrics_str)}</div>
|
| 471 |
+
</div>
|
| 472 |
+
|
| 473 |
+
{_render_sentence_spans("Method attribution (sentence weights)", prompt_sentences, attr_weights)}
|
| 474 |
+
{_render_sentence_spans("Pure sentence ablation (base − score)", prompt_sentences, pure_deltas)}
|
| 475 |
+
{_render_sentence_spans("Attribution-guided MAS marginal (path deltas)", prompt_sentences, guided_deltas)}
|
| 476 |
+
|
| 477 |
+
<div class="panel-foot">Rank order: {escape(rank_str)}</div>
|
| 478 |
+
</div>
|
| 479 |
+
"""
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
style = """
|
| 483 |
+
<style>
|
| 484 |
+
body { font-family: "Inter", "Helvetica Neue", Arial, sans-serif; margin: 0; padding: 24px; background: #fcfcff; color: #1f2933; }
|
| 485 |
+
.title { font-size: 24px; font-weight: 700; }
|
| 486 |
+
.subtitle { font-size: 14px; color: #566; margin-top: 4px; }
|
| 487 |
+
.header { display: flex; justify-content: space-between; align-items: flex-start; gap: 16px; padding-bottom: 16px; border-bottom: 1px solid #e5e8ee; }
|
| 488 |
+
.meta { font-size: 13px; color: #334; line-height: 1.6; }
|
| 489 |
+
|
| 490 |
+
.text-block { margin-top: 16px; border: 1px solid #eef1f6; border-radius: 10px; padding: 12px; background: #fff; }
|
| 491 |
+
.text-title { font-size: 13px; font-weight: 700; color: #263; margin-bottom: 8px; }
|
| 492 |
+
.text-body { font-size: 13px; line-height: 1.7; white-space: pre-wrap; word-break: break-word; }
|
| 493 |
+
|
| 494 |
+
.panel { margin-top: 18px; padding: 16px; border: 1px solid #e5e8ee; border-radius: 10px; background: #fff; box-shadow: 0 2px 6px rgba(0,0,0,0.04); }
|
| 495 |
+
.panel-header { display: flex; justify-content: space-between; align-items: center; }
|
| 496 |
+
.panel-title { font-weight: 600; font-size: 16px; }
|
| 497 |
+
.panel-meta { font-size: 12px; color: #556; }
|
| 498 |
+
.panel-foot { margin-top: 8px; font-size: 12px; color: #556; }
|
| 499 |
+
|
| 500 |
+
.sentmap { margin-top: 12px; border: 1px solid #eef1f6; border-radius: 8px; padding: 10px; background: #f9fbff; }
|
| 501 |
+
.sentmap-title { font-size: 13px; font-weight: 600; margin-bottom: 8px; color: #263; }
|
| 502 |
+
.sentmap-text { font-size: 13px; line-height: 1.8; white-space: pre-wrap; word-break: break-word; }
|
| 503 |
+
.sent-span { display: inline; padding: 2px 2px; margin: 0 0px; border-radius: 4px; }
|
| 504 |
+
.sent-span:hover { outline: 1px solid #8899aa; }
|
| 505 |
+
</style>
|
| 506 |
+
"""
|
| 507 |
+
|
| 508 |
+
html = f"""<!DOCTYPE html>
|
| 509 |
+
<html>
|
| 510 |
+
<head>
|
| 511 |
+
<meta charset="utf-8" />
|
| 512 |
+
<title>{escape(title)}</title>
|
| 513 |
+
{style}
|
| 514 |
+
</head>
|
| 515 |
+
<body>
|
| 516 |
+
{header}
|
| 517 |
+
{gen_block}
|
| 518 |
+
{''.join(panel_sections)}
|
| 519 |
+
</body>
|
| 520 |
+
</html>"""
|
| 521 |
+
return html
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def render_mas_token_html(
|
| 525 |
+
case_meta: Dict[str, Any],
|
| 526 |
+
*,
|
| 527 |
+
prompt_tokens: Sequence[str],
|
| 528 |
+
panels: Sequence[Dict[str, Any]],
|
| 529 |
+
generation: Optional[str] = None,
|
| 530 |
+
) -> str:
|
| 531 |
+
"""Render MAS token-level diagnostics (attribution weights + guided marginal deltas)."""
|
| 532 |
+
|
| 533 |
+
method_label = case_meta.get("attr_method_label") or case_meta.get("attr_method") or "Unknown method"
|
| 534 |
+
title = f"MAS Token Study ({method_label})"
|
| 535 |
+
|
| 536 |
+
neg_handling = case_meta.get("attnlrp_neg_handling")
|
| 537 |
+
norm_mode = case_meta.get("attnlrp_norm_mode")
|
| 538 |
+
ratio_enabled = case_meta.get("attnlrp_ratio_enabled")
|
| 539 |
+
attn_rows = []
|
| 540 |
+
if neg_handling:
|
| 541 |
+
attn_rows.append(f"<div>FT-AttnLRP neg_handling: {escape(str(neg_handling))}</div>")
|
| 542 |
+
if norm_mode:
|
| 543 |
+
attn_rows.append(f"<div>FT-AttnLRP norm_mode: {escape(str(norm_mode))}</div>")
|
| 544 |
+
if ratio_enabled is not None:
|
| 545 |
+
attn_rows.append(f"<div>FT-AttnLRP ratio_enabled: {escape(str(bool(ratio_enabled)))}</div>")
|
| 546 |
+
|
| 547 |
+
base_score = case_meta.get("base_score")
|
| 548 |
+
base_score_row = f"<div>Base score: {float(base_score):.6f}</div>" if isinstance(base_score, (int, float)) else ""
|
| 549 |
+
|
| 550 |
+
gen_block = ""
|
| 551 |
+
if isinstance(generation, str) and generation:
|
| 552 |
+
gen_block = f"""
|
| 553 |
+
<div class="text-block">
|
| 554 |
+
<div class="text-title">Generation (scored)</div>
|
| 555 |
+
<div class="text-body">{escape(generation)}</div>
|
| 556 |
+
</div>
|
| 557 |
+
"""
|
| 558 |
+
|
| 559 |
+
header = f"""
|
| 560 |
+
<div class="header">
|
| 561 |
+
<div>
|
| 562 |
+
<div class="title">{escape(title)}</div>
|
| 563 |
+
<div class="subtitle">Dataset: {escape(str(case_meta.get('dataset')))} | index: {case_meta.get('index')}</div>
|
| 564 |
+
</div>
|
| 565 |
+
<div class="meta">
|
| 566 |
+
<div>Attribution method: {escape(str(case_meta.get('attr_method')))}</div>
|
| 567 |
+
<div>Sink span (gen idx): {escape(str(case_meta.get('sink_span')))}</div>
|
| 568 |
+
<div>Thinking span (gen idx): {escape(str(case_meta.get('thinking_span')))}</div>
|
| 569 |
+
<div>Prompt tokens: {len(prompt_tokens)}</div>
|
| 570 |
+
<div>Panels: {len(panels)}</div>
|
| 571 |
+
{''.join(attn_rows)}
|
| 572 |
+
{base_score_row}
|
| 573 |
+
</div>
|
| 574 |
+
</div>
|
| 575 |
+
"""
|
| 576 |
+
|
| 577 |
+
panel_sections: List[str] = []
|
| 578 |
+
for panel in panels:
|
| 579 |
+
label = panel.get("variant_label") or panel.get("panel_label") or panel.get("variant") or "Panel"
|
| 580 |
+
metrics = panel.get("metrics") or {}
|
| 581 |
+
metrics_str = " | ".join(
|
| 582 |
+
f"{k}: {float(metrics[k]):.4f}" if isinstance(metrics.get(k), (int, float)) else f"{k}: {metrics.get(k)}"
|
| 583 |
+
for k in ("RISE", "MAS", "RISE+AP")
|
| 584 |
+
if k in metrics
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
attr_weights = panel.get("attr_weights") or []
|
| 588 |
+
guided_deltas = panel.get("token_deltas_raw") or []
|
| 589 |
+
rank_order = panel.get("sorted_attr_indices") or []
|
| 590 |
+
rank_str = ", ".join(str(int(x)) for x in rank_order) if rank_order else "N/A"
|
| 591 |
+
|
| 592 |
+
panel_sections.append(
|
| 593 |
+
f"""
|
| 594 |
+
<div class="panel">
|
| 595 |
+
<div class="panel-header">
|
| 596 |
+
<div class="panel-title">{escape(str(label))}</div>
|
| 597 |
+
<div class="panel-meta">{escape(metrics_str)}</div>
|
| 598 |
+
</div>
|
| 599 |
+
|
| 600 |
+
{_render_token_spans("Method attribution (token weights)", prompt_tokens, attr_weights)}
|
| 601 |
+
{_render_token_spans("Attribution-guided MAS marginal (path deltas)", prompt_tokens, guided_deltas)}
|
| 602 |
+
|
| 603 |
+
<div class="panel-foot">Rank order: {escape(rank_str)}</div>
|
| 604 |
+
</div>
|
| 605 |
+
"""
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
style = """
|
| 609 |
+
<style>
|
| 610 |
+
body { font-family: "Inter", "Helvetica Neue", Arial, sans-serif; margin: 0; padding: 24px; background: #fcfcff; color: #1f2933; }
|
| 611 |
+
.title { font-size: 24px; font-weight: 700; }
|
| 612 |
+
.subtitle { font-size: 14px; color: #566; margin-top: 4px; }
|
| 613 |
+
.header { display: flex; justify-content: space-between; align-items: flex-start; gap: 16px; padding-bottom: 16px; border-bottom: 1px solid #e5e8ee; }
|
| 614 |
+
.meta { font-size: 13px; color: #334; line-height: 1.6; }
|
| 615 |
+
|
| 616 |
+
.text-block { margin-top: 16px; border: 1px solid #eef1f6; border-radius: 10px; padding: 12px; background: #fff; }
|
| 617 |
+
.text-title { font-size: 13px; font-weight: 700; color: #263; margin-bottom: 8px; }
|
| 618 |
+
.text-body { font-size: 13px; line-height: 1.7; white-space: pre-wrap; word-break: break-word; }
|
| 619 |
+
|
| 620 |
+
.panel { margin-top: 18px; padding: 16px; border: 1px solid #e5e8ee; border-radius: 10px; background: #fff; box-shadow: 0 2px 6px rgba(0,0,0,0.04); }
|
| 621 |
+
.panel-header { display: flex; justify-content: space-between; align-items: center; }
|
| 622 |
+
.panel-title { font-weight: 600; font-size: 16px; }
|
| 623 |
+
.panel-meta { font-size: 12px; color: #556; }
|
| 624 |
+
.panel-foot { margin-top: 8px; font-size: 12px; color: #556; }
|
| 625 |
+
|
| 626 |
+
.tokmap { margin-top: 12px; border: 1px solid #eef1f6; border-radius: 8px; padding: 10px; background: #f9fbff; }
|
| 627 |
+
.tokmap-title { font-size: 13px; font-weight: 600; margin-bottom: 8px; color: #263; }
|
| 628 |
+
.tokmap-text { font-size: 13px; line-height: 1.8; white-space: pre-wrap; word-break: break-word; }
|
| 629 |
+
.tok-span { display: inline; padding: 1px 1px; margin: 0 0px; border-radius: 3px; }
|
| 630 |
+
.tok-span:hover { outline: 1px solid #8899aa; }
|
| 631 |
+
</style>
|
| 632 |
+
"""
|
| 633 |
+
|
| 634 |
+
html = f"""<!DOCTYPE html>
|
| 635 |
+
<html>
|
| 636 |
+
<head>
|
| 637 |
+
<meta charset="utf-8" />
|
| 638 |
+
<title>{escape(title)}</title>
|
| 639 |
+
{style}
|
| 640 |
+
</head>
|
| 641 |
+
<body>
|
| 642 |
+
{header}
|
| 643 |
+
{gen_block}
|
| 644 |
+
{''.join(panel_sections)}
|
| 645 |
+
</body>
|
| 646 |
+
</html>"""
|
| 647 |
+
return html
|
exp/exp1/README.md
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FlashTrace 长上下文耗时实验(exp1)
|
| 2 |
+
|
| 3 |
+
自包含脚本:`exp/exp1/run_time_curve.py`
|
| 4 |
+
用途:在单个 RULER 样本上,测量不同上下文长度下各归因方法的 wall-clock 时间与 GPU 峰值显存,供论文中的线性增长表格使用。
|
| 5 |
+
|
| 6 |
+
## 方法覆盖
|
| 7 |
+
- `IG`(20 步)
|
| 8 |
+
- `attention_I_G`(注意力 * IG)
|
| 9 |
+
- `attnlrp`(单次反传的 LRP 版本)
|
| 10 |
+
- `perturbation_all`(log-loss ablation)
|
| 11 |
+
- `perturbation_CLP`(KL 版)
|
| 12 |
+
- `perturbation_REAGENT`(MLM 替换,LED/4096 上限,超过则可能失败)
|
| 13 |
+
- `ifr_all_positions`(IFR one-by-one baseline,`sink_chunk_tokens=1` 固定)
|
| 14 |
+
- `ifr_multi_hop`(FlashTrace,多跳+chunk 支持)
|
| 15 |
+
- `ifr_multi_hop_both`(FT-IFR both:stop_words + in_all_gen,多跳+chunk 支持)
|
| 16 |
+
|
| 17 |
+
## 运行示例
|
| 18 |
+
```bash
|
| 19 |
+
# 默认 input 长度 1024,4096,8192,output 长度 32,256,512;每格 3 次
|
| 20 |
+
python exp/exp1/run_time_curve.py \
|
| 21 |
+
--model qwen-8B \
|
| 22 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 23 |
+
--cuda 2,3,4,5,6,7 \
|
| 24 |
+
--attr_funcs perturbation_all,perturbation_REAGENT,ifr_all_positions,perturbation_CLP,ifr_multi_hop,ifr_multi_hop_both,attnlrp \
|
| 25 |
+
--input_lengths 10 \
|
| 26 |
+
--output_lengths 2000,5000,10000 \
|
| 27 |
+
--repeats 1 \
|
| 28 |
+
--chunk_tokens 128 \
|
| 29 |
+
--sink_chunk_tokens 32 \
|
| 30 |
+
--catch_oom \
|
| 31 |
+
--ruler_file data/ruler_multihop/8192/vt_h10_c1/validation.jsonl
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
输出:
|
| 35 |
+
- `exp/exp1/out/time_curve_runs.jsonl`:每次运行的原始记录(attr、目标 input/output/total、实际长度、time、peak_mem、status)。
|
| 36 |
+
- `exp/exp1/out/time_curve_summary.csv`:按方法 + 目标 input/output 汇总的均值/方差(同时写出 total=input+output)。
|
| 37 |
+
|
| 38 |
+
## 注意事项
|
| 39 |
+
- `--input_lengths` 控制 prompt(user prompt)长度,`--output_lengths` 控制 output(sink)长度;每个格子的 total = input + output。
|
| 40 |
+
- 兼容:仍支持 `--total_lengths/--lengths`(deprecated),表示 prompt+output 总长度;prompt 长度按两者差值生成。
|
| 41 |
+
- `--target_text` 作为基底被重复拼接以满足目标 output 长度,仅用于控制长度,不在乎语义。
|
| 42 |
+
- `--catch_oom/--no-catch-oom` 用于选择是把 OOM 记为 status 继续,还是直接抛错中止。
|
| 43 |
+
- 多卡:`--cuda 0,1` 会在脚本启动前设置 `CUDA_VISIBLE_DEVICES` 并用 `device_map=balanced` 分片加载;单卡指定 `--cuda 0`。
|
| 44 |
+
- 超出模型上下文 (`config.max_position_embeddings`) 会标记 `skipped_model_ctx`(按实际喂给模型的 formatted prompt + output(+eos) token 数检查)。
|
| 45 |
+
- `perturbation_REAGENT` 的 Longformer 仅支持 4096 tokens,超过可能返回 OOM 或 runtime_error。
|
| 46 |
+
- IFR multi-hop 提供 `--chunk_tokens/--sink_chunk_tokens` 以在超长上下文上强制分块,显存会下降但时间略升;`ifr_all_positions` 分支固定 `sink_chunk_tokens=1`。
|
exp/exp1/run_time_curve.py
ADDED
|
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Measure wall-clock time and GPU memory for attribution methods across
|
| 4 |
+
different context lengths using a single synthetic RULER-style example.
|
| 5 |
+
|
| 6 |
+
This script stays self-contained under exp/exp1 and reuses the attribution
|
| 7 |
+
implementations in the repo (IG, perturbation, attention*IG, IFR/FlashTrace).
|
| 8 |
+
The goal is to populate the time-vs-length table; correctness of the task
|
| 9 |
+
content is not important, only matching token lengths and running 3 repeats.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import math
|
| 17 |
+
import os
|
| 18 |
+
import random
|
| 19 |
+
import sys
|
| 20 |
+
import time
|
| 21 |
+
from collections import defaultdict
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _early_set_cuda_visible_devices() -> None:
|
| 29 |
+
"""Parse --cuda early to set CUDA_VISIBLE_DEVICES before torch import."""
|
| 30 |
+
parser = argparse.ArgumentParser(add_help=False)
|
| 31 |
+
parser.add_argument("--cuda", type=str, default=None)
|
| 32 |
+
args, _ = parser.parse_known_args(sys.argv[1:])
|
| 33 |
+
if args.cuda and "," in args.cuda:
|
| 34 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
_early_set_cuda_visible_devices()
|
| 38 |
+
|
| 39 |
+
import torch
|
| 40 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 41 |
+
|
| 42 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 43 |
+
if str(REPO_ROOT) not in sys.path:
|
| 44 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 45 |
+
|
| 46 |
+
import llm_attr
|
| 47 |
+
|
| 48 |
+
DEFAULT_INPUT_LENGTHS = [1024, 4096, 8192]
|
| 49 |
+
DEFAULT_OUTPUT_LENGTHS = [32, 256, 512]
|
| 50 |
+
DEFAULT_ATTRS = [
|
| 51 |
+
"IG",
|
| 52 |
+
"perturbation_all",
|
| 53 |
+
"attention_I_G",
|
| 54 |
+
"perturbation_REAGENT",
|
| 55 |
+
"ifr_all_positions",
|
| 56 |
+
"perturbation_CLP",
|
| 57 |
+
"ifr_multi_hop",
|
| 58 |
+
"attnlrp",
|
| 59 |
+
]
|
| 60 |
+
DEFAULT_RULER_FILE = REPO_ROOT / "data" / "ruler_multihop" / "8192" / "vt_h10_c1" / "validation.jsonl"
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def parse_args() -> argparse.Namespace:
|
| 64 |
+
parser = argparse.ArgumentParser("FlashTrace time/memory curve.")
|
| 65 |
+
parser.add_argument("--model", type=str, required=True, help="Model name or HF repo id.")
|
| 66 |
+
parser.add_argument("--model_path", type=str, default=None, help="Optional local model path.")
|
| 67 |
+
parser.add_argument("--cuda", type=str, default=None, help='CUDA devices, e.g. "0,1" or "0".')
|
| 68 |
+
parser.add_argument("--cuda_num", type=int, default=0, help="Single GPU index if --cuda is not set.")
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--attr_funcs",
|
| 71 |
+
type=str,
|
| 72 |
+
default=",".join(DEFAULT_ATTRS),
|
| 73 |
+
help="Comma-separated attribution methods.",
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
length_group = parser.add_mutually_exclusive_group()
|
| 77 |
+
parser.add_argument(
|
| 78 |
+
"--output_lengths",
|
| 79 |
+
type=str,
|
| 80 |
+
default=",".join(str(x) for x in DEFAULT_OUTPUT_LENGTHS),
|
| 81 |
+
help="Comma-separated target output token lengths (sink/output segment).",
|
| 82 |
+
)
|
| 83 |
+
length_group.add_argument(
|
| 84 |
+
"--input_lengths",
|
| 85 |
+
type=str,
|
| 86 |
+
default=",".join(str(x) for x in DEFAULT_INPUT_LENGTHS),
|
| 87 |
+
help="Comma-separated target input/prompt token lengths (user prompt only; excludes chat template).",
|
| 88 |
+
)
|
| 89 |
+
length_group.add_argument(
|
| 90 |
+
"--total_lengths",
|
| 91 |
+
"--lengths",
|
| 92 |
+
dest="total_lengths",
|
| 93 |
+
type=str,
|
| 94 |
+
default=None,
|
| 95 |
+
help="Deprecated. Target total token lengths (prompt + output). Use --input_lengths instead.",
|
| 96 |
+
)
|
| 97 |
+
parser.add_argument("--repeats", type=int, default=3, help="Number of runs per cell.")
|
| 98 |
+
parser.add_argument("--output_dir", type=str, default="exp/exp1/out", help="Output directory.")
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--ruler_file",
|
| 101 |
+
type=str,
|
| 102 |
+
default=str(DEFAULT_RULER_FILE),
|
| 103 |
+
help="RULER jsonl file providing a long base passage.",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--chunk_tokens",
|
| 107 |
+
type=int,
|
| 108 |
+
default=128,
|
| 109 |
+
help="IFR chunk_tokens override when context is long.",
|
| 110 |
+
)
|
| 111 |
+
parser.add_argument(
|
| 112 |
+
"--sink_chunk_tokens",
|
| 113 |
+
type=int,
|
| 114 |
+
default=32,
|
| 115 |
+
help="IFR sink_chunk_tokens override when context is long.",
|
| 116 |
+
)
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"--catch_oom",
|
| 119 |
+
action=argparse.BooleanOptionalAction,
|
| 120 |
+
default=True,
|
| 121 |
+
help="If true, treat CUDA OOM as status=oom and continue; if false, let OOM raise.",
|
| 122 |
+
)
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--target_text",
|
| 125 |
+
type=str,
|
| 126 |
+
default=" The answer is 42.",
|
| 127 |
+
help="Base text to tile when constructing outputs of a given length.",
|
| 128 |
+
)
|
| 129 |
+
return parser.parse_args()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def parse_csv_ints(value: str) -> List[int]:
|
| 133 |
+
return [int(x) for x in value.split(",") if x.strip()]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def resolve_device(cuda: Optional[str], cuda_num: int) -> str:
|
| 137 |
+
if cuda is not None and "," in cuda:
|
| 138 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = cuda
|
| 139 |
+
return "auto"
|
| 140 |
+
if cuda is not None and cuda.strip():
|
| 141 |
+
try:
|
| 142 |
+
idx = int(cuda)
|
| 143 |
+
except Exception:
|
| 144 |
+
idx = 0
|
| 145 |
+
return f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
|
| 146 |
+
return f"cuda:{cuda_num}" if torch.cuda.is_available() else "cpu"
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def load_ruler_base(path: Path, fallback: str) -> str:
|
| 150 |
+
if not path.exists():
|
| 151 |
+
return fallback
|
| 152 |
+
with path.open() as f:
|
| 153 |
+
for line in f:
|
| 154 |
+
try:
|
| 155 |
+
record = json.loads(line)
|
| 156 |
+
if "input" in record:
|
| 157 |
+
return record["input"]
|
| 158 |
+
except json.JSONDecodeError:
|
| 159 |
+
continue
|
| 160 |
+
return fallback
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def build_prompt_to_length(tokenizer, base_text: str, target_tokens: int) -> Tuple[str, int]:
|
| 164 |
+
"""
|
| 165 |
+
Build a prompt whose tokenized length (without special tokens) is ~target_tokens.
|
| 166 |
+
If base_text is shorter, we repeat it; if longer, we truncate.
|
| 167 |
+
"""
|
| 168 |
+
if target_tokens <= 0:
|
| 169 |
+
return "", 0
|
| 170 |
+
|
| 171 |
+
base_ids = tokenizer(base_text, add_special_tokens=False).input_ids
|
| 172 |
+
if not base_ids:
|
| 173 |
+
base_ids = [tokenizer.eos_token_id]
|
| 174 |
+
|
| 175 |
+
tiled: List[int] = []
|
| 176 |
+
while len(tiled) < target_tokens:
|
| 177 |
+
tiled.extend(base_ids)
|
| 178 |
+
tiled = tiled[:target_tokens]
|
| 179 |
+
prompt = tokenizer.decode(tiled, clean_up_tokenization_spaces=False)
|
| 180 |
+
return prompt, len(tiled)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def build_output_to_length(tokenizer, base_text: str, target_tokens: int) -> Tuple[str, int]:
|
| 184 |
+
"""
|
| 185 |
+
Build a target/output string of ~target_tokens using a base snippet.
|
| 186 |
+
"""
|
| 187 |
+
if target_tokens <= 0:
|
| 188 |
+
return "", 0
|
| 189 |
+
|
| 190 |
+
base_ids = tokenizer(base_text, add_special_tokens=False).input_ids
|
| 191 |
+
if not base_ids:
|
| 192 |
+
base_ids = [tokenizer.eos_token_id]
|
| 193 |
+
|
| 194 |
+
tiled: List[int] = []
|
| 195 |
+
while len(tiled) < target_tokens:
|
| 196 |
+
tiled.extend(base_ids)
|
| 197 |
+
tiled = tiled[:target_tokens]
|
| 198 |
+
text = tokenizer.decode(tiled, clean_up_tokenization_spaces=False)
|
| 199 |
+
return text, len(tiled)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def build_formatted_prompt(tokenizer, prompt: str) -> str:
|
| 203 |
+
user_prompt = " " + prompt
|
| 204 |
+
modified_prompt = llm_attr.DEFAULT_PROMPT_TEMPLATE.format(context=user_prompt, query="")
|
| 205 |
+
formatted_prompt = [{"role": "user", "content": modified_prompt}]
|
| 206 |
+
return tokenizer.apply_chat_template(
|
| 207 |
+
formatted_prompt,
|
| 208 |
+
tokenize=False,
|
| 209 |
+
add_generation_prompt=True,
|
| 210 |
+
enable_thinking=False,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def estimate_model_lengths(tokenizer, prompt: str, target: str) -> Dict[str, int]:
|
| 215 |
+
user_prompt = " " + prompt
|
| 216 |
+
formatted_prompt = build_formatted_prompt(tokenizer, prompt)
|
| 217 |
+
|
| 218 |
+
user_prompt_len = len(tokenizer(user_prompt, add_special_tokens=False).input_ids)
|
| 219 |
+
formatted_prompt_len = len(tokenizer(formatted_prompt, add_special_tokens=False).input_ids)
|
| 220 |
+
generation_len = len(tokenizer(target + tokenizer.eos_token, add_special_tokens=False).input_ids)
|
| 221 |
+
|
| 222 |
+
return {
|
| 223 |
+
"user_prompt_tokens": user_prompt_len,
|
| 224 |
+
"formatted_prompt_tokens": formatted_prompt_len,
|
| 225 |
+
"generation_tokens": generation_len,
|
| 226 |
+
"total_tokens": formatted_prompt_len + generation_len,
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def exceeds_model_ctx(tokenizer, prompt: str, target: str, max_ctx: Optional[int]) -> bool:
|
| 231 |
+
if max_ctx is None:
|
| 232 |
+
return False
|
| 233 |
+
return estimate_model_lengths(tokenizer, prompt, target)["total_tokens"] > max_ctx
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def load_model_balanced(model_name: str, device: str):
|
| 237 |
+
"""Load model with an explicit balanced device_map when multi-GPU is requested."""
|
| 238 |
+
if device == "auto":
|
| 239 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 240 |
+
model_name,
|
| 241 |
+
device_map="balanced",
|
| 242 |
+
torch_dtype=torch.float16,
|
| 243 |
+
attn_implementation="eager",
|
| 244 |
+
)
|
| 245 |
+
elif isinstance(device, str) and device.startswith("cuda:"):
|
| 246 |
+
try:
|
| 247 |
+
gpu_idx = int(device.split(":")[1])
|
| 248 |
+
except Exception:
|
| 249 |
+
gpu_idx = 0
|
| 250 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 251 |
+
model_name,
|
| 252 |
+
device_map={"": gpu_idx},
|
| 253 |
+
torch_dtype=torch.float16,
|
| 254 |
+
attn_implementation="eager",
|
| 255 |
+
)
|
| 256 |
+
else:
|
| 257 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 258 |
+
model_name,
|
| 259 |
+
torch_dtype=torch.float16,
|
| 260 |
+
attn_implementation="eager",
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 264 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 265 |
+
model.eval()
|
| 266 |
+
return model, tokenizer
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def collect_device_indices(device_str: str, model: Any) -> List[int]:
|
| 270 |
+
"""
|
| 271 |
+
Infer the CUDA device indices that should be tracked for memory stats.
|
| 272 |
+
Prefers the model's device map; otherwise falls back to all visible devices
|
| 273 |
+
or the single requested device.
|
| 274 |
+
"""
|
| 275 |
+
if not torch.cuda.is_available():
|
| 276 |
+
return []
|
| 277 |
+
|
| 278 |
+
devices: set[int] = set()
|
| 279 |
+
device_map = getattr(model, "hf_device_map", None)
|
| 280 |
+
if isinstance(device_map, dict):
|
| 281 |
+
for dev in device_map.values():
|
| 282 |
+
if dev is None:
|
| 283 |
+
continue
|
| 284 |
+
idx: Optional[int] = None
|
| 285 |
+
if isinstance(dev, torch.device):
|
| 286 |
+
idx = dev.index if dev.index is not None else (0 if dev.type == "cuda" else None)
|
| 287 |
+
elif isinstance(dev, str):
|
| 288 |
+
try:
|
| 289 |
+
d = torch.device(dev)
|
| 290 |
+
idx = d.index if d.index is not None else (0 if d.type == "cuda" else None)
|
| 291 |
+
except Exception:
|
| 292 |
+
idx = None
|
| 293 |
+
elif isinstance(dev, int):
|
| 294 |
+
idx = dev
|
| 295 |
+
if idx is not None:
|
| 296 |
+
devices.add(idx)
|
| 297 |
+
|
| 298 |
+
if not devices:
|
| 299 |
+
if device_str == "auto":
|
| 300 |
+
devices.update(range(torch.cuda.device_count()))
|
| 301 |
+
elif isinstance(device_str, str) and device_str.startswith("cuda:"):
|
| 302 |
+
try:
|
| 303 |
+
devices.add(int(device_str.split(":")[1]))
|
| 304 |
+
except Exception:
|
| 305 |
+
pass
|
| 306 |
+
else:
|
| 307 |
+
devices.update(range(torch.cuda.device_count()))
|
| 308 |
+
|
| 309 |
+
return sorted(devices)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def maybe_reset_cuda(device_indices: List[int]) -> None:
|
| 313 |
+
if not torch.cuda.is_available() or not device_indices:
|
| 314 |
+
return
|
| 315 |
+
for idx in device_indices:
|
| 316 |
+
try:
|
| 317 |
+
torch.cuda.reset_peak_memory_stats(device=idx)
|
| 318 |
+
except Exception:
|
| 319 |
+
pass
|
| 320 |
+
try:
|
| 321 |
+
torch.cuda.empty_cache()
|
| 322 |
+
except Exception:
|
| 323 |
+
pass
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def measure(
|
| 327 |
+
method_fn,
|
| 328 |
+
device_indices: List[int],
|
| 329 |
+
*,
|
| 330 |
+
catch_oom: bool,
|
| 331 |
+
) -> Tuple[str, Optional[float], Optional[float], Optional[float], Dict[int, Dict[str, float]]]:
|
| 332 |
+
status = "ok"
|
| 333 |
+
wall: Optional[float] = None
|
| 334 |
+
mem_alloc: Optional[float] = None
|
| 335 |
+
mem_reserved: Optional[float] = None
|
| 336 |
+
mem_by_device: Dict[int, Dict[str, float]] = {}
|
| 337 |
+
try:
|
| 338 |
+
if torch.cuda.is_available() and device_indices:
|
| 339 |
+
for idx in device_indices:
|
| 340 |
+
torch.cuda.synchronize(device=idx)
|
| 341 |
+
t0 = time.time()
|
| 342 |
+
method_fn()
|
| 343 |
+
if torch.cuda.is_available() and device_indices:
|
| 344 |
+
for idx in device_indices:
|
| 345 |
+
torch.cuda.synchronize(device=idx)
|
| 346 |
+
wall = time.time() - t0
|
| 347 |
+
except RuntimeError as e:
|
| 348 |
+
if "out of memory" in str(e).lower():
|
| 349 |
+
status = "oom"
|
| 350 |
+
if not catch_oom:
|
| 351 |
+
raise
|
| 352 |
+
else:
|
| 353 |
+
status = f"runtime_error: {e}"
|
| 354 |
+
if not catch_oom:
|
| 355 |
+
raise
|
| 356 |
+
except Exception as e:
|
| 357 |
+
status = f"error: {e}"
|
| 358 |
+
if not catch_oom:
|
| 359 |
+
raise
|
| 360 |
+
finally:
|
| 361 |
+
if torch.cuda.is_available() and device_indices:
|
| 362 |
+
try:
|
| 363 |
+
total_alloc = 0.0
|
| 364 |
+
total_reserved = 0.0
|
| 365 |
+
for idx in device_indices:
|
| 366 |
+
alloc_bytes = torch.cuda.max_memory_allocated(device=idx)
|
| 367 |
+
reserved_bytes = torch.cuda.max_memory_reserved(device=idx)
|
| 368 |
+
total_alloc += alloc_bytes
|
| 369 |
+
total_reserved += reserved_bytes
|
| 370 |
+
mem_by_device[idx] = {
|
| 371 |
+
"allocated_gb": alloc_bytes / 1e9,
|
| 372 |
+
"reserved_gb": reserved_bytes / 1e9,
|
| 373 |
+
}
|
| 374 |
+
mem_alloc = total_alloc / 1e9
|
| 375 |
+
mem_reserved = total_reserved / 1e9
|
| 376 |
+
except Exception:
|
| 377 |
+
pass
|
| 378 |
+
return status, wall, mem_alloc, mem_reserved, mem_by_device
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def make_attr_runner(
|
| 382 |
+
attr_func: str,
|
| 383 |
+
model: Any,
|
| 384 |
+
tokenizer: Any,
|
| 385 |
+
chunk_tokens: int,
|
| 386 |
+
sink_chunk_tokens: int,
|
| 387 |
+
batch_size: int,
|
| 388 |
+
prompt: str,
|
| 389 |
+
target: str,
|
| 390 |
+
):
|
| 391 |
+
lf = attr_func.lower()
|
| 392 |
+
if lf == "ig":
|
| 393 |
+
llm_attributor = llm_attr.LLMGradientAttribtion(model, tokenizer)
|
| 394 |
+
|
| 395 |
+
def fn():
|
| 396 |
+
return llm_attributor.calculate_IG_per_generation(
|
| 397 |
+
prompt, steps=20, baseline=tokenizer.eos_token_id, batch_size=batch_size, target=target
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
return fn
|
| 401 |
+
|
| 402 |
+
if lf == "attention_i_g":
|
| 403 |
+
llm_attn = llm_attr.LLMAttentionAttribution(model, tokenizer)
|
| 404 |
+
llm_ig = llm_attr.LLMGradientAttribtion(model, tokenizer)
|
| 405 |
+
|
| 406 |
+
def fn():
|
| 407 |
+
attn = llm_attn.calculate_attention_attribution(prompt, target=target)
|
| 408 |
+
ig = llm_ig.calculate_IG_per_generation(
|
| 409 |
+
prompt, steps=20, baseline=tokenizer.eos_token_id, batch_size=batch_size, target=target
|
| 410 |
+
)
|
| 411 |
+
attn.attribution_matrix = attn.attribution_matrix * ig.attribution_matrix
|
| 412 |
+
return attn
|
| 413 |
+
|
| 414 |
+
return fn
|
| 415 |
+
|
| 416 |
+
if lf == "perturbation_all":
|
| 417 |
+
llm_attrtor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
|
| 418 |
+
|
| 419 |
+
def fn():
|
| 420 |
+
return llm_attrtor.calculate_feature_ablation_sentences(
|
| 421 |
+
prompt, baseline=tokenizer.eos_token_id, measure="log_loss", target=target
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
return fn
|
| 425 |
+
|
| 426 |
+
if lf == "perturbation_clp":
|
| 427 |
+
llm_attrtor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
|
| 428 |
+
|
| 429 |
+
def fn():
|
| 430 |
+
return llm_attrtor.calculate_feature_ablation_sentences(
|
| 431 |
+
prompt, baseline=tokenizer.eos_token_id, measure="KL", target=target
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
return fn
|
| 435 |
+
|
| 436 |
+
if lf == "perturbation_reagent":
|
| 437 |
+
llm_attrtor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
|
| 438 |
+
|
| 439 |
+
def fn():
|
| 440 |
+
return llm_attrtor.calculate_feature_ablation_sentences_mlm(prompt, target=target)
|
| 441 |
+
|
| 442 |
+
return fn
|
| 443 |
+
|
| 444 |
+
if lf == "ifr_all_positions":
|
| 445 |
+
llm_attrtor = llm_attr.LLMIFRAttribution(
|
| 446 |
+
model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=1
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
def fn():
|
| 450 |
+
return llm_attrtor.calculate_ifr_for_all_positions(prompt, target=target)
|
| 451 |
+
|
| 452 |
+
return fn
|
| 453 |
+
|
| 454 |
+
if lf == "ifr_multi_hop":
|
| 455 |
+
llm_attrtor = llm_attr.LLMIFRAttribution(
|
| 456 |
+
model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
def fn():
|
| 460 |
+
return llm_attrtor.calculate_ifr_multi_hop(prompt, target=target)
|
| 461 |
+
|
| 462 |
+
return fn
|
| 463 |
+
|
| 464 |
+
if lf == "ifr_multi_hop_both":
|
| 465 |
+
import ft_ifr_improve
|
| 466 |
+
|
| 467 |
+
llm_attrtor = ft_ifr_improve.LLMIFRAttributionBoth(
|
| 468 |
+
model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
def fn():
|
| 472 |
+
return llm_attrtor.calculate_ifr_multi_hop_both(prompt, target=target)
|
| 473 |
+
|
| 474 |
+
return fn
|
| 475 |
+
|
| 476 |
+
if lf == "attnlrp":
|
| 477 |
+
llm_attrtor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 478 |
+
|
| 479 |
+
def fn():
|
| 480 |
+
return llm_attrtor.calculate_attnlrp(prompt, target=target)
|
| 481 |
+
|
| 482 |
+
return fn
|
| 483 |
+
|
| 484 |
+
raise ValueError(f"Unsupported attr_func {attr_func}")
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def compute_batch_size(sequence_length: int, max_input_len: int) -> int:
|
| 488 |
+
denom = int(sequence_length)
|
| 489 |
+
return max(1, math.floor((max_input_len - 100) / max(1, denom)))
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def aggregate_results(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 493 |
+
grouped: Dict[Tuple[str, int, int], Dict[str, List[float]]] = defaultdict(lambda: {"time": [], "mem": []})
|
| 494 |
+
statuses: Dict[Tuple[str, int, int], List[str]] = defaultdict(list)
|
| 495 |
+
for row in rows:
|
| 496 |
+
key = (row["attr_func"], row["target_input_tokens"], row["target_output_tokens"])
|
| 497 |
+
statuses[key].append(row["status"])
|
| 498 |
+
if row.get("time_sec") is not None:
|
| 499 |
+
grouped[key]["time"].append(row["time_sec"])
|
| 500 |
+
if row.get("peak_mem_gb") is not None:
|
| 501 |
+
grouped[key]["mem"].append(row["peak_mem_gb"])
|
| 502 |
+
|
| 503 |
+
summary = []
|
| 504 |
+
for key, vals in grouped.items():
|
| 505 |
+
attr_func, input_tokens, output_tokens = key
|
| 506 |
+
total_tokens = input_tokens + output_tokens
|
| 507 |
+
times = vals["time"]
|
| 508 |
+
mems = vals["mem"]
|
| 509 |
+
summary.append(
|
| 510 |
+
{
|
| 511 |
+
"attr_func": attr_func,
|
| 512 |
+
"target_input_tokens": input_tokens,
|
| 513 |
+
"target_total_tokens": total_tokens,
|
| 514 |
+
"target_output_tokens": output_tokens,
|
| 515 |
+
"time_mean": np.mean(times) if times else None,
|
| 516 |
+
"time_std": np.std(times) if times else None,
|
| 517 |
+
"mem_mean": np.mean(mems) if mems else None,
|
| 518 |
+
"mem_std": np.std(mems) if mems else None,
|
| 519 |
+
"statuses": statuses[key],
|
| 520 |
+
}
|
| 521 |
+
)
|
| 522 |
+
return summary
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def append_jsonl_row(f, row: Dict[str, Any]) -> None:
|
| 526 |
+
f.write(json.dumps(row) + "\n")
|
| 527 |
+
f.flush()
|
| 528 |
+
try:
|
| 529 |
+
os.fsync(f.fileno())
|
| 530 |
+
except OSError:
|
| 531 |
+
pass
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def write_summary_csv(rows: List[Dict[str, Any]], out_dir: Path) -> Path:
|
| 535 |
+
summary = aggregate_results(rows)
|
| 536 |
+
summary_path = out_dir / "time_curve_summary.csv"
|
| 537 |
+
tmp_path = out_dir / "time_curve_summary.csv.tmp"
|
| 538 |
+
|
| 539 |
+
with tmp_path.open("w") as f:
|
| 540 |
+
f.write(
|
| 541 |
+
"attr_func,target_input_tokens,target_output_tokens,target_total_tokens,time_mean,time_std,peak_mem_mean,peak_mem_std,statuses\n"
|
| 542 |
+
)
|
| 543 |
+
for row in summary:
|
| 544 |
+
f.write(
|
| 545 |
+
"{},{},{},{},{},{},{},{},{}\n".format(
|
| 546 |
+
row["attr_func"],
|
| 547 |
+
row["target_input_tokens"],
|
| 548 |
+
row["target_output_tokens"],
|
| 549 |
+
row["target_total_tokens"],
|
| 550 |
+
"" if row["time_mean"] is None else f"{row['time_mean']:.4f}",
|
| 551 |
+
"" if row["time_std"] is None else f"{row['time_std']:.4f}",
|
| 552 |
+
"" if row["mem_mean"] is None else f"{row['mem_mean']:.4f}",
|
| 553 |
+
"" if row["mem_std"] is None else f"{row['mem_std']:.4f}",
|
| 554 |
+
"|".join(row["statuses"]),
|
| 555 |
+
)
|
| 556 |
+
)
|
| 557 |
+
f.flush()
|
| 558 |
+
try:
|
| 559 |
+
os.fsync(f.fileno())
|
| 560 |
+
except OSError:
|
| 561 |
+
pass
|
| 562 |
+
|
| 563 |
+
tmp_path.replace(summary_path)
|
| 564 |
+
return summary_path
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def main() -> None:
|
| 568 |
+
args = parse_args()
|
| 569 |
+
device = resolve_device(args.cuda, args.cuda_num)
|
| 570 |
+
attr_funcs = [a.strip() for a in args.attr_funcs.split(",") if a.strip()]
|
| 571 |
+
target_output_lengths = parse_csv_ints(args.output_lengths)
|
| 572 |
+
out_dir = Path(args.output_dir)
|
| 573 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 574 |
+
|
| 575 |
+
random.seed(42)
|
| 576 |
+
np.random.seed(42)
|
| 577 |
+
torch.manual_seed(42)
|
| 578 |
+
|
| 579 |
+
model_name = args.model if args.model_path is None else args.model_path
|
| 580 |
+
model, tokenizer = load_model_balanced(model_name, device)
|
| 581 |
+
device_indices = collect_device_indices(device, model)
|
| 582 |
+
max_ctx = getattr(getattr(model, "config", None), "max_position_embeddings", None)
|
| 583 |
+
|
| 584 |
+
base_text = load_ruler_base(Path(args.ruler_file), fallback="RULER fallback text. ")
|
| 585 |
+
target_base = args.target_text
|
| 586 |
+
all_rows: List[Dict[str, Any]] = []
|
| 587 |
+
runner = None
|
| 588 |
+
raised: Optional[BaseException] = None
|
| 589 |
+
jsonl_f = None
|
| 590 |
+
jsonl_path = out_dir / "time_curve_runs.jsonl"
|
| 591 |
+
summary_path = out_dir / "time_curve_summary.csv"
|
| 592 |
+
|
| 593 |
+
def record_row(row: Dict[str, Any]) -> None:
|
| 594 |
+
all_rows.append(row)
|
| 595 |
+
if jsonl_f is not None:
|
| 596 |
+
append_jsonl_row(jsonl_f, row)
|
| 597 |
+
write_summary_csv(all_rows, out_dir)
|
| 598 |
+
|
| 599 |
+
using_deprecated_total = args.total_lengths is not None
|
| 600 |
+
if using_deprecated_total:
|
| 601 |
+
target_total_lengths = parse_csv_ints(args.total_lengths)
|
| 602 |
+
length_grid: List[Tuple[int, int, int]] = []
|
| 603 |
+
for total_tokens in target_total_lengths:
|
| 604 |
+
for output_tokens in target_output_lengths:
|
| 605 |
+
length_grid.append((total_tokens - output_tokens, output_tokens, total_tokens))
|
| 606 |
+
else:
|
| 607 |
+
target_input_lengths = parse_csv_ints(args.input_lengths)
|
| 608 |
+
length_grid = []
|
| 609 |
+
for input_tokens in target_input_lengths:
|
| 610 |
+
for output_tokens in target_output_lengths:
|
| 611 |
+
length_grid.append((input_tokens, output_tokens, input_tokens + output_tokens))
|
| 612 |
+
|
| 613 |
+
try:
|
| 614 |
+
jsonl_f = jsonl_path.open("w")
|
| 615 |
+
write_summary_csv([], out_dir)
|
| 616 |
+
|
| 617 |
+
for input_tokens, output_tokens, total_tokens in length_grid:
|
| 618 |
+
if input_tokens <= 0:
|
| 619 |
+
for attr in attr_funcs:
|
| 620 |
+
for rep in range(args.repeats):
|
| 621 |
+
record_row(
|
| 622 |
+
{
|
| 623 |
+
"attr_func": attr,
|
| 624 |
+
"target_input_tokens": input_tokens,
|
| 625 |
+
"target_output_tokens": output_tokens,
|
| 626 |
+
"target_total_tokens": total_tokens,
|
| 627 |
+
"actual_input_tokens": None,
|
| 628 |
+
"actual_output_tokens": None,
|
| 629 |
+
"actual_total_tokens_raw": None,
|
| 630 |
+
"actual_user_prompt_tokens": None,
|
| 631 |
+
"actual_formatted_prompt_tokens": None,
|
| 632 |
+
"actual_generation_tokens": None,
|
| 633 |
+
"actual_total_tokens": None,
|
| 634 |
+
"status": "skipped_nonpositive_input",
|
| 635 |
+
"time_sec": None,
|
| 636 |
+
"peak_mem_gb": None,
|
| 637 |
+
"peak_mem_reserved_gb": None,
|
| 638 |
+
"repeat": rep,
|
| 639 |
+
"used_deprecated_total_lengths": using_deprecated_total,
|
| 640 |
+
}
|
| 641 |
+
)
|
| 642 |
+
continue
|
| 643 |
+
|
| 644 |
+
prompt, actual_input_len = build_prompt_to_length(tokenizer, base_text, input_tokens)
|
| 645 |
+
target, actual_output_len = build_output_to_length(tokenizer, target_base, output_tokens)
|
| 646 |
+
actual_total_tokens_raw = len(tokenizer(prompt + target, add_special_tokens=False).input_ids)
|
| 647 |
+
model_lens = estimate_model_lengths(tokenizer, prompt, target)
|
| 648 |
+
|
| 649 |
+
if max_ctx is not None and model_lens["total_tokens"] > max_ctx:
|
| 650 |
+
for attr in attr_funcs:
|
| 651 |
+
for rep in range(args.repeats):
|
| 652 |
+
record_row(
|
| 653 |
+
{
|
| 654 |
+
"attr_func": attr,
|
| 655 |
+
"target_input_tokens": input_tokens,
|
| 656 |
+
"target_output_tokens": output_tokens,
|
| 657 |
+
"target_total_tokens": total_tokens,
|
| 658 |
+
"actual_input_tokens": actual_input_len,
|
| 659 |
+
"actual_output_tokens": actual_output_len,
|
| 660 |
+
"actual_total_tokens_raw": actual_total_tokens_raw,
|
| 661 |
+
"actual_user_prompt_tokens": model_lens["user_prompt_tokens"],
|
| 662 |
+
"actual_formatted_prompt_tokens": model_lens["formatted_prompt_tokens"],
|
| 663 |
+
"actual_generation_tokens": model_lens["generation_tokens"],
|
| 664 |
+
"actual_total_tokens": model_lens["total_tokens"],
|
| 665 |
+
"status": "skipped_model_ctx",
|
| 666 |
+
"time_sec": None,
|
| 667 |
+
"peak_mem_gb": None,
|
| 668 |
+
"peak_mem_reserved_gb": None,
|
| 669 |
+
"repeat": rep,
|
| 670 |
+
"used_deprecated_total_lengths": using_deprecated_total,
|
| 671 |
+
}
|
| 672 |
+
)
|
| 673 |
+
continue
|
| 674 |
+
|
| 675 |
+
batch_size = compute_batch_size(model_lens["total_tokens"], max_input_len=max_ctx or 200000)
|
| 676 |
+
|
| 677 |
+
for attr in attr_funcs:
|
| 678 |
+
for rep in range(args.repeats):
|
| 679 |
+
runner = None
|
| 680 |
+
maybe_reset_cuda(device_indices)
|
| 681 |
+
try:
|
| 682 |
+
runner = make_attr_runner(
|
| 683 |
+
attr,
|
| 684 |
+
model=model,
|
| 685 |
+
tokenizer=tokenizer,
|
| 686 |
+
chunk_tokens=args.chunk_tokens,
|
| 687 |
+
sink_chunk_tokens=args.sink_chunk_tokens,
|
| 688 |
+
batch_size=batch_size,
|
| 689 |
+
prompt=prompt,
|
| 690 |
+
target=target,
|
| 691 |
+
)
|
| 692 |
+
except RuntimeError as e:
|
| 693 |
+
if "out of memory" in str(e).lower():
|
| 694 |
+
status = "oom"
|
| 695 |
+
if not args.catch_oom:
|
| 696 |
+
raise
|
| 697 |
+
else:
|
| 698 |
+
status = f"init_runtime_error: {e}"
|
| 699 |
+
if not args.catch_oom:
|
| 700 |
+
raise
|
| 701 |
+
wall = None
|
| 702 |
+
mem_alloc = None
|
| 703 |
+
mem_reserved = None
|
| 704 |
+
mem_by_device = {}
|
| 705 |
+
except Exception as e:
|
| 706 |
+
status = f"init_error: {e}"
|
| 707 |
+
if not args.catch_oom:
|
| 708 |
+
raise
|
| 709 |
+
wall = None
|
| 710 |
+
mem_alloc = None
|
| 711 |
+
mem_reserved = None
|
| 712 |
+
mem_by_device = {}
|
| 713 |
+
else:
|
| 714 |
+
status, wall, mem_alloc, mem_reserved, mem_by_device = measure(
|
| 715 |
+
runner, device_indices=device_indices, catch_oom=args.catch_oom
|
| 716 |
+
)
|
| 717 |
+
finally:
|
| 718 |
+
runner = None
|
| 719 |
+
|
| 720 |
+
record_row(
|
| 721 |
+
{
|
| 722 |
+
"attr_func": attr,
|
| 723 |
+
"target_input_tokens": input_tokens,
|
| 724 |
+
"target_output_tokens": output_tokens,
|
| 725 |
+
"target_total_tokens": total_tokens,
|
| 726 |
+
"actual_input_tokens": actual_input_len,
|
| 727 |
+
"actual_output_tokens": actual_output_len,
|
| 728 |
+
"actual_total_tokens_raw": actual_total_tokens_raw,
|
| 729 |
+
"actual_user_prompt_tokens": model_lens["user_prompt_tokens"],
|
| 730 |
+
"actual_formatted_prompt_tokens": model_lens["formatted_prompt_tokens"],
|
| 731 |
+
"actual_generation_tokens": model_lens["generation_tokens"],
|
| 732 |
+
"actual_total_tokens": model_lens["total_tokens"],
|
| 733 |
+
"status": status,
|
| 734 |
+
"time_sec": wall,
|
| 735 |
+
"peak_mem_gb": mem_reserved if mem_reserved is not None else mem_alloc,
|
| 736 |
+
"peak_mem_reserved_gb": mem_reserved,
|
| 737 |
+
"peak_mem_by_device_gb": mem_by_device if mem_by_device else None,
|
| 738 |
+
"repeat": rep,
|
| 739 |
+
"used_deprecated_total_lengths": using_deprecated_total,
|
| 740 |
+
}
|
| 741 |
+
)
|
| 742 |
+
except BaseException as e:
|
| 743 |
+
raised = e
|
| 744 |
+
finally:
|
| 745 |
+
runner = None
|
| 746 |
+
if jsonl_f is not None:
|
| 747 |
+
jsonl_f.close()
|
| 748 |
+
write_summary_csv(all_rows, out_dir)
|
| 749 |
+
print(f"Wrote per-run records to {jsonl_path}")
|
| 750 |
+
print(f"Wrote summary to {summary_path}")
|
| 751 |
+
|
| 752 |
+
if raised is not None:
|
| 753 |
+
raise raised
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
if __name__ == "__main__":
|
| 757 |
+
main()
|
exp/exp2/DATASETS.md
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# exp/exp2 数据集与样本流说明
|
| 2 |
+
|
| 3 |
+
本文件说明 Experiment 2 中支持的数据集、样本结构,以及在「采样阶段」与「归因阶段」的处理方式。
|
| 4 |
+
|
| 5 |
+
## 支持的数据集
|
| 6 |
+
- `morehopqa`(`data/with_human_verification.json`)
|
| 7 |
+
- RULER 系列 JSONL:`hotpotqa_long`、`niah_*`、`vt_*`(自动在 `data/ruler_multihop/<len>/.../validation.jsonl` 搜索),或直接传入任意 RULER JSONL 路径
|
| 8 |
+
- 其余数据集(如 math)被显式跳过
|
| 9 |
+
- 归因阶段同样优先使用缓存文件 `exp/exp2/data/<name>.jsonl`,否则按上述规则解析;传入存在的 JSONL 路径也会按 RULER 结构加载
|
| 10 |
+
|
| 11 |
+
### 共同的样本字段定义
|
| 12 |
+
```json
|
| 13 |
+
{
|
| 14 |
+
"prompt": "<上下文+问题>",
|
| 15 |
+
"target": "<答案或生成>",
|
| 16 |
+
"indices_to_explain": [start_tok, end_tok] | null, // token-level:需要解释的 generation token span(闭区间)
|
| 17 |
+
"attr_mask_indices": [...], // legacy:覆盖率金标句子索引(当前 exp2 不再使用),可能为 null
|
| 18 |
+
"sink_span": [start, end] | null, // 生成 token 中的答案片段
|
| 19 |
+
"thinking_span": [start, end] | null, // 生成 token 中的 CoT 片段
|
| 20 |
+
"metadata": { ... } // 数据集特定元信息
|
| 21 |
+
}
|
| 22 |
+
```
|
| 23 |
+
- **`CachedExample`**:`dataset_utils.py` 统一的内存态结构,字段与上述 JSON 完全一致,用于采样阶段(加载原始数据)与归因阶段(加载缓存或原始)。
|
| 24 |
+
- **缓存行(JSONL)**:`sample_and_filter.py` 写入的每行 JSON,与 `CachedExample` 字段一一对应。
|
| 25 |
+
- **采样阶段处理流(通用)**:
|
| 26 |
+
1. 加载原始数据集样本(`prompt`/`indices_to_explain` 等保持一致)。
|
| 27 |
+
2. 按模板调用生成模型,要求「思考文本 + 末尾 \\box{} 答案」。
|
| 28 |
+
3. 若生成不符合「思考 + 单个 \\box{} 且无尾巴」的格式,直接丢弃该样本。
|
| 29 |
+
4. 提取思考片段与 `\\box{}` 内文本,仅用 `\\box{}` 内文调用判定模型。
|
| 30 |
+
5. 判定为 True 时,重新拼接「思考片段 + 去除 box 包裹的答案文本」作为 `target`,并据此记录 `sink_span`/`thinking_span`。
|
| 31 |
+
6. 写入缓存:只保留 `reference_answer`、`judge_response`(可选 `boxed_answer`),不再存储 `candidate_answer`。
|
| 32 |
+
|
| 33 |
+
### 生成切分与 span 解析
|
| 34 |
+
- `split_boxed_generation`(`dataset_utils.py`)校验格式:必须是「非空思考文本 + 单个末尾 \\box{}」且箱体之后无其他字符,否则直接跳过。
|
| 35 |
+
- `target` 由「思考片段 + 换行 + 最终答案文本(无 box)」重组。
|
| 36 |
+
- `attach_spans_from_answer` 使用 tokenizer 的 offset mapping 将最终答案在 `target` 中的字符区间映射到 token 级索引,得到 `sink_span`;`thinking_span` 取从开头到 `sink_span` 前一 token 的闭区间。两者均为 token 级 span,满足后续多跳 IFR 的调用约定。
|
| 37 |
+
- `indices_to_explain` 在采样写缓存时统一设置为 `sink_span`(boxed 内文在 `target` 中对应的 generation token span)。
|
| 38 |
+
|
| 39 |
+
---
|
| 40 |
+
|
| 41 |
+
## MoreHopQA
|
| 42 |
+
- **原始样本结构(`MoreHopQAAttributionDataset` → `CachedExample`)**
|
| 43 |
+
```json
|
| 44 |
+
{
|
| 45 |
+
"prompt": "<context 拼接>\\n<question>",
|
| 46 |
+
"target": null,
|
| 47 |
+
"indices_to_explain": null,
|
| 48 |
+
"attr_mask_indices": null,
|
| 49 |
+
"sink_span": null,
|
| 50 |
+
"thinking_span": null,
|
| 51 |
+
"metadata": {
|
| 52 |
+
"answer": "<gold answer>",
|
| 53 |
+
"_id": "<example id>",
|
| 54 |
+
"original_context": <原始上下文结构>
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
```
|
| 58 |
+
- 加载时机:`DatasetLoader.load_raw("morehopqa")` 在采样阶段、归因阶段(无缓存时)都会产出 `CachedExample`。
|
| 59 |
+
- 说明:exp2 的 token-level row/rec 需要 `target` + 可定位的答案 token span;建议先跑 `sample_and_filter.py` 产出缓存后再做归因评估。
|
| 60 |
+
|
| 61 |
+
- **采样阶段(生成 & 过滤后写缓存)**
|
| 62 |
+
```json
|
| 63 |
+
{
|
| 64 |
+
"prompt": "<同上>",
|
| 65 |
+
"target": "<生成的 CoT + 最终答案文本(已去掉 box 包裹)>",
|
| 66 |
+
"indices_to_explain": [start_tok, end_tok],
|
| 67 |
+
"attr_mask_indices": null,
|
| 68 |
+
"sink_span": [start_tok, end_tok] | null,
|
| 69 |
+
"thinking_span": [start_tok, end_tok] | null,
|
| 70 |
+
"metadata": {
|
| 71 |
+
"answer": "<gold answer>",
|
| 72 |
+
"_id": "<example id>",
|
| 73 |
+
"original_context": <原始上下文结构>,
|
| 74 |
+
"reference_answer": "<gold answer>",
|
| 75 |
+
"judge_response": "<True/False 文本>",
|
| 76 |
+
"boxed_answer": "<可选,boxed 解析结果>"
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
```
|
| 80 |
+
- `sink_span`/`thinking_span`:仅在成功解析 `\\box{}` 时填充;`target` 为「思考 + 最终答案文本」的裁剪版。
|
| 81 |
+
- 写入:`exp/exp2/data/morehopqa.jsonl`。
|
| 82 |
+
|
| 83 |
+
- **归因阶段(加载缓存优先)**
|
| 84 |
+
- 加载:`run_exp.py` 优先 `load_cached`(JSONL → `CachedExample`),否则回退原始结构并在线生成 `target`。
|
| 85 |
+
- 使用:忠实度(token-level RISE/MAS)直接用缓存的 `target`;`ifr_multi_hop` 在有 `sink_span`/`thinking_span` 时限定答案/CoT,否则视整个生成为 sink。
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
## RULER 热点问答(`hotpotqa_long`)
|
| 90 |
+
- **原始样本结构(`RulerAttributionDataset` → `CachedExample`)**
|
| 91 |
+
```json
|
| 92 |
+
{
|
| 93 |
+
"prompt": "<input> + <answer_prefix>",
|
| 94 |
+
"target": "<answer_prefix + sep + ', '.join(outputs)>",
|
| 95 |
+
"indices_to_explain": [0],
|
| 96 |
+
"attr_mask_indices": [<句子索引>...] | null,
|
| 97 |
+
"sink_span": null,
|
| 98 |
+
"thinking_span": null,
|
| 99 |
+
"metadata": {
|
| 100 |
+
"dataset": "ruler",
|
| 101 |
+
"length": <int>,
|
| 102 |
+
"length_w_model_temp": <any>,
|
| 103 |
+
"outputs": [...],
|
| 104 |
+
"answer_prefix": "<str>",
|
| 105 |
+
"token_position_answer": <any>,
|
| 106 |
+
"needle_spans": [
|
| 107 |
+
{
|
| 108 |
+
"title": "<str>",
|
| 109 |
+
"doc_index": <int>,
|
| 110 |
+
"document_number": <int>,
|
| 111 |
+
"sentence_index": <int>,
|
| 112 |
+
"sentence": "<str>",
|
| 113 |
+
"context_span": [start, end],
|
| 114 |
+
"span": [start, end],
|
| 115 |
+
"snippet": "<str>"
|
| 116 |
+
},
|
| 117 |
+
...
|
| 118 |
+
],
|
| 119 |
+
"prompt_sentence_count": <int>,
|
| 120 |
+
"reference_answer": "<在 loader 中补充,来自 outputs 或 target>"
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
```
|
| 124 |
+
- 加载时机:`DatasetLoader.load_raw("hotpotqa_long")` 在采样阶段、归因阶段(无缓存时)都会产出 `CachedExample`。
|
| 125 |
+
|
| 126 |
+
- **采样阶段(生成 & 过滤后写缓存)**
|
| 127 |
+
```json
|
| 128 |
+
{
|
| 129 |
+
"prompt": "<同上>",
|
| 130 |
+
"target": "<生成的 CoT + 最终答案文本(已去掉 box 包裹)>",
|
| 131 |
+
"indices_to_explain": [-2],
|
| 132 |
+
"attr_mask_indices": [<句子索引>...] | null,
|
| 133 |
+
"sink_span": [start_tok, end_tok] | null,
|
| 134 |
+
"thinking_span": [start_tok, end_tok] | null,
|
| 135 |
+
"metadata": {
|
| 136 |
+
"dataset": "ruler",
|
| 137 |
+
"length": <int>,
|
| 138 |
+
"length_w_model_temp": <any>,
|
| 139 |
+
"outputs": [...],
|
| 140 |
+
"answer_prefix": "<str>",
|
| 141 |
+
"token_position_answer": <any>,
|
| 142 |
+
"needle_spans": [...],
|
| 143 |
+
"prompt_sentence_count": <int>,
|
| 144 |
+
"reference_answer": "<outputs 拼接或 target>",
|
| 145 |
+
"judge_response": "<True/False 文本>",
|
| 146 |
+
"boxed_answer": "<可选>"
|
| 147 |
+
}
|
| 148 |
+
}
|
| 149 |
+
```
|
| 150 |
+
- `attr_mask_indices` 保留原值;`indices_to_explain` 统一为末句 `[-2]`(最后一个非 EOS 生成句);`sink_span`/`thinking_span` 仅在成功解析 `\\box{}` 时填充;`target` 为「思考 + 最终答案文本」的裁剪版。
|
| 151 |
+
- 写入:`exp/exp2/data/hotpotqa_long.jsonl`。
|
| 152 |
+
|
| 153 |
+
- **归因阶段(加载缓存优先)**
|
| 154 |
+
- 加载:优先 `load_cached`(JSONL → `CachedExample`),否则回退原始解析。
|
| 155 |
+
- 使用:覆盖率使用 `attr_mask_indices`;忠实度与 `ifr_multi_hop` 利用缓存的 `sink_span`/`thinking_span` 定位答案/CoT,若缺失则视整个生成为 sink。
|
| 156 |
+
|
| 157 |
+
---
|
| 158 |
+
|
| 159 |
+
## RULER NIAH / Variable Tracking(`niah_*`, `vt_*`)
|
| 160 |
+
- **原始样本结构(同 RULER 通用)**
|
| 161 |
+
```json
|
| 162 |
+
{
|
| 163 |
+
"prompt": "<input> + <answer_prefix>",
|
| 164 |
+
"target": "<answer_prefix + sep + ', '.join(outputs)>",
|
| 165 |
+
"indices_to_explain": [0],
|
| 166 |
+
"attr_mask_indices": [<句子索引>...] | null,
|
| 167 |
+
"sink_span": null,
|
| 168 |
+
"thinking_span": null,
|
| 169 |
+
"metadata": {
|
| 170 |
+
"dataset": "ruler",
|
| 171 |
+
"length": <int>,
|
| 172 |
+
"length_w_model_temp": <any>,
|
| 173 |
+
"outputs": [...],
|
| 174 |
+
"answer_prefix": "<str>",
|
| 175 |
+
"token_position_answer": <any>,
|
| 176 |
+
"needle_spans": [...],
|
| 177 |
+
"prompt_sentence_count": <int>,
|
| 178 |
+
"reference_answer": "<在 loader 中补充>"
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
```
|
| 182 |
+
- 加载时机:`DatasetLoader.load_raw("<niah_* 或 vt_*>")` 在采样阶段、归因阶段(无缓存时)使用。
|
| 183 |
+
|
| 184 |
+
- **采样阶段(生成 & 过滤后写缓存)**
|
| 185 |
+
```json
|
| 186 |
+
{
|
| 187 |
+
"prompt": "<同上>",
|
| 188 |
+
"target": "<思考 + 最终答案文本(无 box),无其他尾巴>",
|
| 189 |
+
"indices_to_explain": [start_tok, end_tok],
|
| 190 |
+
"attr_mask_indices": [<句子索引>...] | null,
|
| 191 |
+
"sink_span": [start_tok, end_tok] | null,
|
| 192 |
+
"thinking_span": [start_tok, end_tok] | null,
|
| 193 |
+
"metadata": {
|
| 194 |
+
"dataset": "ruler",
|
| 195 |
+
"length": <int>,
|
| 196 |
+
"length_w_model_temp": <any>,
|
| 197 |
+
"outputs": [...],
|
| 198 |
+
"answer_prefix": "<str>",
|
| 199 |
+
"token_position_answer": <any>,
|
| 200 |
+
"needle_spans": [...],
|
| 201 |
+
"prompt_sentence_count": <int>,
|
| 202 |
+
"reference_answer": "<outputs 拼接或 target>",
|
| 203 |
+
"judge_response": "<True/False 文本>",
|
| 204 |
+
"boxed_answer": "<可选>"
|
| 205 |
+
}
|
| 206 |
+
}
|
| 207 |
+
```
|
| 208 |
+
- 生成/判定流程与 `hotpotqa_long` 相同;`target` 是裁剪后的「思考 + 最终答案文本」。
|
| 209 |
+
- 写入:`exp/exp2/data/<dataset>.jsonl`(例如 `niah_mq_q2.jsonl`, `vt_h6_c1.jsonl`)。
|
| 210 |
+
|
| 211 |
+
- **归因阶段(加载缓存优先)**
|
| 212 |
+
- 与 `hotpotqa_long` 相同:优先缓存,否则原始;恢复率(`recovery_ruler`)使用 `metadata.needle_spans`(映射到 prompt tokens);多跳 IFR 在有 `sink_span`/`thinking_span` 时作用于答案/CoT。
|
| 213 |
+
|
| 214 |
+
---
|
| 215 |
+
|
| 216 |
+
## `indices_to_explain` 约定
|
| 217 |
+
- token-level:`indices_to_explain = [start_tok, end_tok]`(闭区间),坐标系为 `tokenizer(target, add_special_tokens=False)` 的 generation token indices。
|
| 218 |
+
- exp2 推荐:`indices_to_explain == sink_span`,即 boxed 内文(最终答案)在 `target` 中对应的 token span。
|
| 219 |
+
|
| 220 |
+
---
|
| 221 |
+
|
| 222 |
+
## 自定义 RULER JSONL 路径
|
| 223 |
+
- 若 `--dataset` 传入存在的 JSONL 路径,`dataset_from_name` 按 RULER 文件解析,字段与流程同 RULER 系列。
|
| 224 |
+
- 采样、归因阶段行为与上文 RULER 描述一致,只是文件名由显式路径决定。
|
| 225 |
+
|
| 226 |
+
---
|
| 227 |
+
|
| 228 |
+
## 归因阶段加载优先级与效果
|
| 229 |
+
- `run_exp.py` 加载顺序:`exp/exp2/data/<name>.jsonl` 缓存 > 显式给定的 JSONL 路径 > 原始解析(MoreHopQA 或 RULER)
|
| 230 |
+
- 恢复率 (`mode=recovery_ruler`) 仅支持 RULER(要求 `metadata.needle_spans`),否则拒绝
|
| 231 |
+
- 忠实度 (`mode=faithfulness_gen`) 使用生成文本;`ifr_multi_hop` 在有 `sink_span`/`thinking_span` 时才对答案/CoT 做多跳,否则退化为整段生成
|
exp/exp2/README.md
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FlashTrace 实验 2(多步推理下的忠实度)
|
| 2 |
+
|
| 3 |
+
本目录提供「11 数据集 × 9 方法 × 3 指标」的实验工具,**跳过 AT2**,**跳过 math**。流程分为两步:先采样并过滤高质量 CoT+boxed 生成,再对过滤结果做归因评估。
|
| 4 |
+
|
| 5 |
+
支持数据集:MoreHopQA、HotpotQA(RULER hotpotqa_long)、RULER niah(niah_*)、RULER variable tracking(vt_*)。RULER 路径自动在 `data/ruler_multihop/<len>/.../validation.jsonl` 中搜索。
|
| 6 |
+
|
| 7 |
+
主要文件:
|
| 8 |
+
- `sample_and_filter.py`:采样 + 判定一致性,输出到 `exp/exp2/data/`
|
| 9 |
+
- `run_exp.py`:归因测试,输出到 `exp/exp2/output/`
|
| 10 |
+
- `dataset_utils.py`:数据加载、答案 span 解析
|
| 11 |
+
|
| 12 |
+
采样脚本支持的数据集
|
| 13 |
+
- `morehopqa`(本地 `data/with_human_verification.json`)
|
| 14 |
+
- `hotpotqa_long`(自动在 `data/ruler_multihop/<len>/hotpotqa_long/validation.jsonl` 搜索)
|
| 15 |
+
- `niah_*`(RULER niah 变体,自动搜索同上)
|
| 16 |
+
- `vt_*`(RULER variable tracking 变体,自动搜索同上)
|
| 17 |
+
- 直接传 RULER JSONL 路径(作为数据集名处理),其余类型不支持
|
| 18 |
+
|
| 19 |
+
归因测试支持
|
| 20 |
+
- 数据集:优先使用 `exp/exp2/data/<name>.jsonl` 缓存,若无则按采样同样的解析规则加载;math 显式拒绝。
|
| 21 |
+
- 指标:
|
| 22 |
+
- `faithfulness_gen`(生成侧):可运行在任何已加载样本(math 以外)。
|
| 23 |
+
- `recovery_ruler`(恢复率,仅 RULER):Recall@10%(排名只在 prompt tokens 上进行,gold 来自 `needle_spans`)。
|
| 24 |
+
- 方法(`--attr_funcs`):`IG`、`perturbation_all`、`perturbation_CLP`、`perturbation_REAGENT`、`attention`(内部融合 IG)、`ifr_all_positions`、`ifr_multi_hop`、`attnlrp`、`ft_attnlrp`、`basic`。AT2 未提供。
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## 数据采样
|
| 29 |
+
|
| 30 |
+
实现逻辑
|
| 31 |
+
- 统一数据加载:`DatasetLoader` 读取 MoreHopQA / HotpotQA / RULER niah / RULER vt;可直接传自定义 RULER JSONL。
|
| 32 |
+
- 生成模型:`qwen3-235b-a22b-2507`(英文 system prompt),要求「先简要思考,再用 `\box{}` 包裹最终答案且末尾不追加内容」;user prompt 为原题,无额外模板。
|
| 33 |
+
- 判定模型:`deepseek-v3-1-terminus`(英文 system prompt),只输出 True/False 判断 `\box{}` 内文与参考答案是否一致。
|
| 34 |
+
- 过滤:仅保留「思考 + 末尾 boxed 答案」且判定为 True 的样本;`target` 用提取的思考片段与 **去掉 box 包裹的最终答案** 重组,附带 token 级 `sink_span`/`thinking_span`、`reference_answer`、`judge_response`(不再存 `candidate_answer`),`indices_to_explain` 统一写为 `sink_span`(boxed 内文在 `target` 的 generation token span,[start_tok, end_tok])。
|
| 35 |
+
- 采样会按原始顺序依次尝试样本,判定失败立即跳过;累计到 `--max_examples` 条成功样本即提前停止(若源数据不足则更少),tqdm 会分别显示尝试与成功计数。
|
| 36 |
+
|
| 37 |
+
使用说明
|
| 38 |
+
```bash
|
| 39 |
+
export FLASHTRACE_API_KEY=sk-yaojia-get-ccfa # 或 OPENAI_API_KEY
|
| 40 |
+
|
| 41 |
+
# 示例:采样 hotpotqa_long,保留最多 100 条判定为 True 的样本
|
| 42 |
+
python exp/exp2/sample_and_filter.py \
|
| 43 |
+
--dataset data/with_human_verification.json \
|
| 44 |
+
--max_examples 100 \
|
| 45 |
+
--api_key sk-yaojia-get-ccfa \
|
| 46 |
+
--tokenizer_model /opt/share/models/Qwen/Qwen3-8B > exp/exp2/out.log
|
| 47 |
+
```
|
| 48 |
+
常用参数:
|
| 49 |
+
- `--dataset`:morehopqa | hotpotqa_long | niah_* | vt_*(或直接 JSONL 路径)
|
| 50 |
+
- `--max_examples`:希望保留的成功样本数;达到后即停止(若源数据不足则更少)
|
| 51 |
+
- `--tokenizer_model`:用于 span 检测的 tokenizer(默认复用生成模型)
|
| 52 |
+
- `--api_base`/`--api_key`:接口地址与密钥(默认本地 http://localhost:4000/v1)
|
| 53 |
+
- `--request_interval` / `--judge_interval`:生成/判定间隔节流(默认 1s)
|
| 54 |
+
- `--rate_limit_delay`:遇到 HTTP 429 时的等待秒数(默认 5s);会在重试前自动 sleep
|
| 55 |
+
输出:`exp/exp2/data/<dataset>.jsonl`
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
## 归因测试
|
| 60 |
+
|
| 61 |
+
实现逻辑
|
| 62 |
+
- 输入:优先读取 `exp/exp2/data/<dataset>.jsonl`(过滤缓存);若不存在则回退到原始数据解析。
|
| 63 |
+
- 方法:忠实度(token-level RISE/MAS)对齐 `evaluations/faithfulness.py` 的逻辑(AT2 未实现),math 自动拒绝。
|
| 64 |
+
- 多跳 FlashTrace:若缓存含 `sink_span`/`thinking_span` 则用于 multi-hop IFR,否则默认整句答案为 sink。
|
| 65 |
+
- 一次运行可同时评测多个指标:`--mode` 支持多值与逗号分隔(如 `--mode faithfulness_gen,recovery_ruler` 或 `--mode faithfulness_gen, recovery_ruler`),对同一批样本只做一次归因。
|
| 66 |
+
- 可选保存样本级 trace:加 `--save_hop_traces` 会为**所有方法、所有样本**保存归因向量与逐样本指标到 `exp/exp2/output/traces/...`;对 multi-hop 方法还会额外保存每跳的 token-level 向量 `V_h`(单一 `vh`,即实际参与多跳传播的向量),并在 manifest 中记录 `attnlrp_neg_handling/attnlrp_norm_mode` 等设置。
|
| 67 |
+
- 已知兼容性:部分 tokenizer 在 chat template 边界会出现 token 合并,导致评测侧用 token-id 子序列定位 user prompt 失败;exp2 已改为直接复用归因阶段算出的 `user_prompt_indices` 做扰动定位。
|
| 68 |
+
- 批大小估算:沿用原脚本 `(max_input_len-100)/len(tokenizer(format_prompt(prompt)+target))` 的保守估计(至少 1)。`max_input_len` 由代码内置映射表基于 `--model` 字符串决定,未命中或仅传 `--model_path` 时默认 2000;如需映射值而又用本地路径,请同时传入对应的 `--model` 名称。
|
| 69 |
+
- 计时:对每个样本的归因计算(recovery/faithfulness)分别计时,最终在 CSV 末尾追加 `Avg Sample Time (s)` 并在控制台打印平均耗时。
|
| 70 |
+
- 输出:`exp/exp2/output/faithfulness/...`、`exp/exp2/output/recovery/...`,以及(可选)`exp/exp2/output/traces/...`,按数据集和模型分目录。
|
| 71 |
+
|
| 72 |
+
使用说明
|
| 73 |
+
```bash
|
| 74 |
+
# 生成侧 RISE/MAS 忠实度 perturbation_all_fast,perturbation_CLP_fast,perturbation_REAGENT_fast,ifr_multi_hop_stop_words,ifr_multi_hop_both,ifr_multi_hop_split_hop,ft_attnlrp,ifr_multi_hop,attnlrp,ifr_all_positions,perturbation_all,perturbation_REAGENT,perturbation_CLP,IG,attention
|
| 75 |
+
python exp/exp2/run_exp.py \
|
| 76 |
+
--datasets exp/exp2/data/math.jsonl \
|
| 77 |
+
--attr_funcs IG,attention \
|
| 78 |
+
--model qwen-8B \
|
| 79 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 80 |
+
--cuda 2,3,4,5,6,7 \
|
| 81 |
+
--num_examples 100 \
|
| 82 |
+
--mode faithfulness_gen \
|
| 83 |
+
--n_hops 1 \
|
| 84 |
+
--save_hop_traces \
|
| 85 |
+
&& python exp/exp2/run_exp.py \
|
| 86 |
+
--datasets exp/exp2/data/morehopqa.jsonl \
|
| 87 |
+
--attr_funcs IG,attention \
|
| 88 |
+
--model qwen-8B \
|
| 89 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 90 |
+
--cuda 2,3,4,5,6,7 \
|
| 91 |
+
--num_examples 100 \
|
| 92 |
+
--mode faithfulness_gen \
|
| 93 |
+
--n_hops 1 \
|
| 94 |
+
--save_hop_traces
|
| 95 |
+
|
| 96 |
+
# --attnlrp_neg_handling drop \
|
| 97 |
+
# --attnlrp_norm_mode norm
|
| 98 |
+
```
|
| 99 |
+
常用参数:
|
| 100 |
+
- `--datasets`:逗号分隔数据集名;若已存在 `exp/exp2/data/<name>.jsonl` 则直接使用。
|
| 101 |
+
- `--attr_funcs`:逗号分隔方法(无 AT2);`ifr_multi_hop` 与 `ft_attnlrp` 支持多跳(由 `--n_hops` 控制)。
|
| 102 |
+
- `--attnlrp_neg_handling`:FT-AttnLRP 每跳负值处理(`drop`/`abs`)。
|
| 103 |
+
- `--attnlrp_norm_mode`:FT-AttnLRP 正则化与 hop ratio 开关(`norm`/`no_norm`)。
|
| 104 |
+
- `--data_root`/`--output_root`:缓存与结果目录(默认 `exp/exp2/data` / `exp/exp2/output`)。
|
| 105 |
+
- `--mode`:`faithfulness_gen`、`recovery_ruler`,可多值/逗号分隔(一次归因同时输出多个指标);`--num_examples` 控制评测条数。math 会被拒绝。***
|
| 106 |
+
- `--save_hop_traces`:保存样本级 trace 到 `exp/exp2/output/traces/<dataset>/<model>/<run_tag>/`(每样本 `ex_*.npz` + `manifest.jsonl`)。
|
exp/exp2/dataset_utils.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset helpers for Experiment 2 (CoT / multi-hop faithfulness).
|
| 2 |
+
|
| 3 |
+
Named dataset_utils to avoid collision with the HF `datasets` package.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import random
|
| 10 |
+
import re
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any, Dict, Iterable, List, Optional
|
| 14 |
+
|
| 15 |
+
from attribution_datasets import (
|
| 16 |
+
AttributionExample,
|
| 17 |
+
MoreHopQAAttributionDataset,
|
| 18 |
+
RulerAttributionDataset,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class CachedExample:
|
| 24 |
+
prompt: str
|
| 25 |
+
target: Optional[str]
|
| 26 |
+
indices_to_explain: Optional[List[int]]
|
| 27 |
+
attr_mask_indices: Optional[List[int]]
|
| 28 |
+
sink_span: Optional[List[int]]
|
| 29 |
+
thinking_span: Optional[List[int]]
|
| 30 |
+
metadata: Dict[str, Any]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def read_cached_jsonl(path: Path) -> List[CachedExample]:
|
| 34 |
+
examples: List[CachedExample] = []
|
| 35 |
+
with path.open("r", encoding="utf-8") as f:
|
| 36 |
+
for line in f:
|
| 37 |
+
if not line.strip():
|
| 38 |
+
continue
|
| 39 |
+
obj = json.loads(line)
|
| 40 |
+
examples.append(
|
| 41 |
+
CachedExample(
|
| 42 |
+
prompt=obj["prompt"],
|
| 43 |
+
target=obj.get("target"),
|
| 44 |
+
indices_to_explain=obj.get("indices_to_explain"),
|
| 45 |
+
attr_mask_indices=obj.get("attr_mask_indices"),
|
| 46 |
+
sink_span=obj.get("sink_span"),
|
| 47 |
+
thinking_span=obj.get("thinking_span"),
|
| 48 |
+
metadata=obj.get("metadata", {}),
|
| 49 |
+
)
|
| 50 |
+
)
|
| 51 |
+
return examples
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_cached(path: Path, sample: Optional[int] = None, seed: int = 42) -> List[CachedExample]:
|
| 55 |
+
ex = read_cached_jsonl(path)
|
| 56 |
+
if sample is not None and sample < len(ex):
|
| 57 |
+
random.Random(seed).shuffle(ex)
|
| 58 |
+
ex = ex[:sample]
|
| 59 |
+
return ex
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def load_ruler(path: Path, sample: Optional[int] = None, seed: int = 42) -> List[CachedExample]:
|
| 63 |
+
ds = RulerAttributionDataset(path)
|
| 64 |
+
examples: List[CachedExample] = []
|
| 65 |
+
ex_iter: Iterable[AttributionExample] = ds
|
| 66 |
+
if sample is not None and sample < len(ds):
|
| 67 |
+
ex_iter = list(ds)
|
| 68 |
+
random.Random(seed).shuffle(ex_iter)
|
| 69 |
+
ex_iter = ex_iter[:sample]
|
| 70 |
+
for ex in ex_iter:
|
| 71 |
+
examples.append(
|
| 72 |
+
CachedExample(
|
| 73 |
+
prompt=ex.prompt,
|
| 74 |
+
target=ex.target,
|
| 75 |
+
indices_to_explain=ex.indices_to_explain,
|
| 76 |
+
attr_mask_indices=ex.attr_mask_indices,
|
| 77 |
+
sink_span=None,
|
| 78 |
+
thinking_span=None,
|
| 79 |
+
metadata=ex.metadata,
|
| 80 |
+
)
|
| 81 |
+
)
|
| 82 |
+
return examples
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def load_morehopqa(
|
| 86 |
+
path: str | Path = "./data/with_human_verification.json", sample: Optional[int] = None, seed: int = 42
|
| 87 |
+
) -> List[CachedExample]:
|
| 88 |
+
ds = MoreHopQAAttributionDataset(path)
|
| 89 |
+
ex_iter: Iterable[AttributionExample] = ds
|
| 90 |
+
if sample is not None and sample < len(ds):
|
| 91 |
+
ex_iter = list(ds)
|
| 92 |
+
random.Random(seed).shuffle(ex_iter)
|
| 93 |
+
ex_iter = ex_iter[:sample]
|
| 94 |
+
examples: List[CachedExample] = []
|
| 95 |
+
for ex in ex_iter:
|
| 96 |
+
examples.append(
|
| 97 |
+
CachedExample(
|
| 98 |
+
prompt=ex.prompt,
|
| 99 |
+
target=None,
|
| 100 |
+
indices_to_explain=ex.indices_to_explain,
|
| 101 |
+
attr_mask_indices=ex.attr_mask_indices,
|
| 102 |
+
sink_span=None,
|
| 103 |
+
thinking_span=None,
|
| 104 |
+
metadata=ex.metadata,
|
| 105 |
+
)
|
| 106 |
+
)
|
| 107 |
+
return examples
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def auto_find_ruler(task: str) -> Optional[Path]:
|
| 111 |
+
length_dirs = ["4096", "8192", "16384", "32768", "65536", "131072"]
|
| 112 |
+
base = Path("data/ruler_multihop")
|
| 113 |
+
for ld in length_dirs:
|
| 114 |
+
cand = base / ld / task / "validation.jsonl"
|
| 115 |
+
if cand.exists():
|
| 116 |
+
return cand
|
| 117 |
+
return None
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def dataset_from_name(name: str) -> Optional[Path]:
|
| 121 |
+
if name == "hotpotqa_long":
|
| 122 |
+
return auto_find_ruler("hotpotqa_long")
|
| 123 |
+
if name.startswith("vt_"):
|
| 124 |
+
return auto_find_ruler(name)
|
| 125 |
+
if name.startswith("niah"):
|
| 126 |
+
return auto_find_ruler(name)
|
| 127 |
+
p = Path(name)
|
| 128 |
+
if p.exists():
|
| 129 |
+
return p
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
_BOX_PATTERN = re.compile(r"\\box(?:ed)?\s*[\{{](.*?)[\}}]", flags=re.DOTALL)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _find_box_span(text: str) -> Optional[tuple[int, int, str]]:
|
| 137 |
+
"""Return (start_char, end_char, answer_text) for the last \\boxed block."""
|
| 138 |
+
matches = list(_BOX_PATTERN.finditer(text))
|
| 139 |
+
if not matches:
|
| 140 |
+
return None
|
| 141 |
+
m = matches[-1]
|
| 142 |
+
return m.start(0), m.end(0), m.group(1).strip()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def extract_boxed_answer(text: str) -> Optional[str]:
|
| 146 |
+
"""Extract the answer string inside the last \\boxed{} block."""
|
| 147 |
+
match = _find_box_span(text)
|
| 148 |
+
return match[2] if match else None
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _find_answer_span(text: str, answer: str) -> Optional[tuple[int, int]]:
|
| 152 |
+
"""Return (start_char, end_char) for the last occurrence of `answer` in text."""
|
| 153 |
+
if not answer or not text:
|
| 154 |
+
return None
|
| 155 |
+
start = text.rfind(answer)
|
| 156 |
+
if start == -1:
|
| 157 |
+
return None
|
| 158 |
+
return start, start + len(answer)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def split_boxed_generation(text: str) -> Optional[tuple[str, str, str]]:
|
| 162 |
+
"""Return (thinking_text, boxed_segment, boxed_answer) if format matches."""
|
| 163 |
+
if not text:
|
| 164 |
+
return None
|
| 165 |
+
match = _find_box_span(text)
|
| 166 |
+
if not match:
|
| 167 |
+
return None
|
| 168 |
+
|
| 169 |
+
start_char, end_char, boxed_inner = match
|
| 170 |
+
boxed_segment = text[start_char:end_char].strip()
|
| 171 |
+
thinking_text = text[:start_char].strip()
|
| 172 |
+
trailing = text[end_char:].strip()
|
| 173 |
+
|
| 174 |
+
if not boxed_inner or not boxed_segment:
|
| 175 |
+
return None
|
| 176 |
+
if trailing:
|
| 177 |
+
return None
|
| 178 |
+
if not thinking_text:
|
| 179 |
+
return None
|
| 180 |
+
|
| 181 |
+
return thinking_text, boxed_segment, boxed_inner
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def attach_spans_from_answer(
|
| 185 |
+
example: CachedExample, tokenizer, answer_text: Optional[str] = None
|
| 186 |
+
) -> CachedExample:
|
| 187 |
+
"""Attach sink/thinking spans by locating the (plain) answer in `target`.
|
| 188 |
+
|
| 189 |
+
`answer_text` should be the extracted boxed answer; falls back to metadata or
|
| 190 |
+
parsing the target when omitted. Works even when the target no longer keeps
|
| 191 |
+
the \\box{} wrapper.
|
| 192 |
+
"""
|
| 193 |
+
tgt = example.target or ""
|
| 194 |
+
answer = (answer_text or "").strip()
|
| 195 |
+
if not answer:
|
| 196 |
+
answer = (example.metadata.get("boxed_answer") or extract_boxed_answer(tgt) or "").strip()
|
| 197 |
+
|
| 198 |
+
metadata = dict(example.metadata)
|
| 199 |
+
if answer:
|
| 200 |
+
metadata.setdefault("boxed_answer", answer)
|
| 201 |
+
|
| 202 |
+
if tokenizer is None or not tgt or not answer:
|
| 203 |
+
return CachedExample(
|
| 204 |
+
prompt=example.prompt,
|
| 205 |
+
target=example.target,
|
| 206 |
+
indices_to_explain=example.indices_to_explain,
|
| 207 |
+
attr_mask_indices=example.attr_mask_indices,
|
| 208 |
+
sink_span=example.sink_span,
|
| 209 |
+
thinking_span=example.thinking_span,
|
| 210 |
+
metadata=metadata,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
span = _find_answer_span(tgt, answer)
|
| 214 |
+
if span is None:
|
| 215 |
+
return CachedExample(
|
| 216 |
+
prompt=example.prompt,
|
| 217 |
+
target=example.target,
|
| 218 |
+
indices_to_explain=example.indices_to_explain,
|
| 219 |
+
attr_mask_indices=example.attr_mask_indices,
|
| 220 |
+
sink_span=example.sink_span,
|
| 221 |
+
thinking_span=example.thinking_span,
|
| 222 |
+
metadata=metadata,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
span_start_char, span_end_char = span
|
| 226 |
+
gen_ids = tokenizer(tgt, add_special_tokens=False, return_offsets_mapping=True)
|
| 227 |
+
sink_tokens: List[int] = []
|
| 228 |
+
for idx, (s, e) in enumerate(gen_ids["offset_mapping"]):
|
| 229 |
+
# include tokens that overlap the answer span
|
| 230 |
+
if s < span_end_char and e > span_start_char:
|
| 231 |
+
sink_tokens.append(idx)
|
| 232 |
+
if not sink_tokens:
|
| 233 |
+
return CachedExample(
|
| 234 |
+
prompt=example.prompt,
|
| 235 |
+
target=example.target,
|
| 236 |
+
indices_to_explain=example.indices_to_explain,
|
| 237 |
+
attr_mask_indices=example.attr_mask_indices,
|
| 238 |
+
sink_span=example.sink_span,
|
| 239 |
+
thinking_span=example.thinking_span,
|
| 240 |
+
metadata=metadata,
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
sink_span = [min(sink_tokens), max(sink_tokens)]
|
| 244 |
+
thinking_end = max(0, sink_span[0] - 1)
|
| 245 |
+
thinking_span = [0, thinking_end] if thinking_end >= 0 else sink_span
|
| 246 |
+
|
| 247 |
+
return CachedExample(
|
| 248 |
+
prompt=example.prompt,
|
| 249 |
+
target=example.target,
|
| 250 |
+
indices_to_explain=example.indices_to_explain,
|
| 251 |
+
attr_mask_indices=example.attr_mask_indices,
|
| 252 |
+
sink_span=example.sink_span or sink_span,
|
| 253 |
+
thinking_span=example.thinking_span or thinking_span,
|
| 254 |
+
metadata=metadata,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def attach_spans_from_boxed(example: CachedExample, tokenizer) -> CachedExample:
|
| 259 |
+
"""Backward-compatible wrapper that first looks for \\box{} then falls back to answer text."""
|
| 260 |
+
tgt = example.target
|
| 261 |
+
match = _find_box_span(tgt) if tgt else None
|
| 262 |
+
boxed_answer = match[2] if match else None
|
| 263 |
+
return attach_spans_from_answer(example, tokenizer, boxed_answer)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def ruler_gold_prompt_token_indices(example: CachedExample, tokenizer) -> List[int]:
|
| 267 |
+
"""Return token indices (prompt-side) that overlap RULER `needle_spans` in metadata.
|
| 268 |
+
|
| 269 |
+
The returned indices are with respect to `tokenizer(" " + example.prompt, add_special_tokens=False)`,
|
| 270 |
+
matching the attribution pipeline's leading-space convention.
|
| 271 |
+
"""
|
| 272 |
+
needle_spans = (example.metadata or {}).get("needle_spans") or []
|
| 273 |
+
if not isinstance(needle_spans, list) or not needle_spans:
|
| 274 |
+
return []
|
| 275 |
+
|
| 276 |
+
prompt_text = " " + (example.prompt or "")
|
| 277 |
+
enc = tokenizer(prompt_text, add_special_tokens=False, return_offsets_mapping=True)
|
| 278 |
+
offsets = enc.get("offset_mapping")
|
| 279 |
+
if offsets is None:
|
| 280 |
+
raise ValueError("Tokenizer does not provide offset_mapping; cannot map needle_spans to tokens.")
|
| 281 |
+
|
| 282 |
+
spans: List[tuple[int, int]] = []
|
| 283 |
+
for item in needle_spans:
|
| 284 |
+
if not isinstance(item, dict):
|
| 285 |
+
continue
|
| 286 |
+
raw = item.get("span")
|
| 287 |
+
if not (isinstance(raw, list) and len(raw) == 2):
|
| 288 |
+
continue
|
| 289 |
+
try:
|
| 290 |
+
start = int(raw[0]) + 1 # shift for leading space in prompt_text
|
| 291 |
+
end = int(raw[1]) + 1
|
| 292 |
+
except Exception:
|
| 293 |
+
continue
|
| 294 |
+
if end > start:
|
| 295 |
+
spans.append((start, end))
|
| 296 |
+
|
| 297 |
+
if not spans:
|
| 298 |
+
return []
|
| 299 |
+
|
| 300 |
+
gold: set[int] = set()
|
| 301 |
+
for tok_idx, off in enumerate(offsets):
|
| 302 |
+
if off is None:
|
| 303 |
+
continue
|
| 304 |
+
try:
|
| 305 |
+
s, e = int(off[0]), int(off[1])
|
| 306 |
+
except Exception:
|
| 307 |
+
continue
|
| 308 |
+
if e <= s:
|
| 309 |
+
continue
|
| 310 |
+
for span_start, span_end in spans:
|
| 311 |
+
if s < span_end and e > span_start:
|
| 312 |
+
gold.add(tok_idx)
|
| 313 |
+
break
|
| 314 |
+
|
| 315 |
+
return sorted(gold)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class DatasetLoader:
|
| 319 |
+
"""Thin loader that resolves and samples datasets for exp2."""
|
| 320 |
+
|
| 321 |
+
def __init__(self, seed: int = 42, data_root: Path | str = Path("exp/exp2/data")) -> None:
|
| 322 |
+
self.seed = seed
|
| 323 |
+
self.data_root = Path(data_root)
|
| 324 |
+
|
| 325 |
+
def _sample(self, items: List[CachedExample], sample: Optional[int]) -> List[CachedExample]:
|
| 326 |
+
if sample is not None and sample < len(items):
|
| 327 |
+
rnd = random.Random(self.seed)
|
| 328 |
+
rnd.shuffle(items)
|
| 329 |
+
items = items[:sample]
|
| 330 |
+
return items
|
| 331 |
+
|
| 332 |
+
def _cached_path(self, name: str) -> Optional[Path]:
|
| 333 |
+
path = self.data_root / f"{name}.jsonl"
|
| 334 |
+
return path if path.exists() else None
|
| 335 |
+
|
| 336 |
+
def load(self, name: str, sample: Optional[int] = None) -> List[CachedExample]:
|
| 337 |
+
# 1) Prefer prepared cache under exp/exp2/data
|
| 338 |
+
cached_path = self._cached_path(name)
|
| 339 |
+
if cached_path:
|
| 340 |
+
return self._sample(load_cached(cached_path), sample)
|
| 341 |
+
|
| 342 |
+
return self.load_raw(name, sample=sample)
|
| 343 |
+
|
| 344 |
+
def load_raw(self, name: str, sample: Optional[int] = None) -> List[CachedExample]:
|
| 345 |
+
def _looks_like_json_array(path: Path) -> bool:
|
| 346 |
+
try:
|
| 347 |
+
with path.open("r", encoding="utf-8") as f:
|
| 348 |
+
while True:
|
| 349 |
+
ch = f.read(1)
|
| 350 |
+
if not ch:
|
| 351 |
+
return False
|
| 352 |
+
if ch.isspace():
|
| 353 |
+
continue
|
| 354 |
+
return ch == "["
|
| 355 |
+
except OSError:
|
| 356 |
+
return False
|
| 357 |
+
|
| 358 |
+
# MoreHopQA
|
| 359 |
+
if name == "morehopqa":
|
| 360 |
+
ex = load_morehopqa()
|
| 361 |
+
for item in ex:
|
| 362 |
+
if "answer" in item.metadata:
|
| 363 |
+
item.metadata.setdefault("reference_answer", item.metadata["answer"])
|
| 364 |
+
return self._sample(ex, sample)
|
| 365 |
+
|
| 366 |
+
# Allow passing the raw MoreHopQA JSON path directly.
|
| 367 |
+
p = Path(name)
|
| 368 |
+
if p.exists() and _looks_like_json_array(p):
|
| 369 |
+
ex = load_morehopqa(p)
|
| 370 |
+
for item in ex:
|
| 371 |
+
if "answer" in item.metadata:
|
| 372 |
+
item.metadata.setdefault("reference_answer", item.metadata["answer"])
|
| 373 |
+
return self._sample(ex, sample)
|
| 374 |
+
|
| 375 |
+
# RULER / HotpotQA / niah / vt (all go through RulerAttributionDataset)
|
| 376 |
+
resolved = dataset_from_name(name)
|
| 377 |
+
if resolved is None:
|
| 378 |
+
raise FileNotFoundError(f"Could not resolve dataset {name}")
|
| 379 |
+
ex = load_ruler(resolved)
|
| 380 |
+
for item in ex:
|
| 381 |
+
outputs = item.metadata.get("outputs") or []
|
| 382 |
+
if outputs:
|
| 383 |
+
item.metadata.setdefault("reference_answer", ", ".join(outputs))
|
| 384 |
+
if item.target and "reference_answer" not in item.metadata:
|
| 385 |
+
item.metadata["reference_answer"] = item.target
|
| 386 |
+
return self._sample(ex, sample)
|
exp/exp2/map_math_mine_to_exp2_cache.py
ADDED
|
@@ -0,0 +1,584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Prepare data/math_mine.json into an exp2 cached JSONL dataset.
|
| 3 |
+
|
| 4 |
+
This script supports two modes:
|
| 5 |
+
|
| 6 |
+
- map (offline): convert GSM8K-style math examples:
|
| 7 |
+
|
| 8 |
+
{"question": "...", "answer": "... #### 18"}
|
| 9 |
+
|
| 10 |
+
into exp2's cached JSONL format (one JSON object per line).
|
| 11 |
+
|
| 12 |
+
- resample (online): resample targets like exp/exp2/sample_and_filter.py:
|
| 13 |
+
call a chat completion API to generate "<thinking> + final \\box{} answer",
|
| 14 |
+
judge the boxed answer against the reference answer extracted from the raw
|
| 15 |
+
GSM8K-style entry, and write only judge=True samples.
|
| 16 |
+
|
| 17 |
+
In both modes, exp2 expects token-level spans (NOT character spans):
|
| 18 |
+
|
| 19 |
+
- indices_to_explain: [start_tok, end_tok] (generation-token indices, closed interval)
|
| 20 |
+
- sink_span/thinking_span: token spans over tokenizer(target, add_special_tokens=False)
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import json
|
| 27 |
+
import os
|
| 28 |
+
import sys
|
| 29 |
+
import time
|
| 30 |
+
import urllib.error
|
| 31 |
+
import urllib.request
|
| 32 |
+
from dataclasses import asdict
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 35 |
+
|
| 36 |
+
from transformers import AutoTokenizer
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
|
| 39 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 40 |
+
if str(REPO_ROOT) not in sys.path:
|
| 41 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 42 |
+
|
| 43 |
+
from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer, split_boxed_generation # noqa: E402
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class RateLimitError(RuntimeError):
|
| 47 |
+
"""Raised when API returns 429; carries a suggested wait time."""
|
| 48 |
+
|
| 49 |
+
def __init__(self, wait_seconds: float, detail: str) -> None:
|
| 50 |
+
super().__init__(detail)
|
| 51 |
+
self.wait_seconds = wait_seconds
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
GEN_SYSTEM_PROMPT = (
|
| 55 |
+
"You are a reasoning assistant. "
|
| 56 |
+
"Before answering, engage in an chain of thought. "
|
| 57 |
+
"Process this freely and naturally without using specific headers or strict formatting. "
|
| 58 |
+
"When you reach the conclusion, wrap the entire final sentence containing the answer inside \\box{}. "
|
| 59 |
+
"Ensure the box wraps the **sentence** that naturally delivers the answer. DO NOT rewrite the answer word for the box separately."
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
JUDGE_SYSTEM_PROMPT = (
|
| 63 |
+
"You verify whether the model's boxed answer matches the reference answer. "
|
| 64 |
+
"Reply strictly with True or False and nothing else."
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def call_chat_api(
|
| 69 |
+
api_base: str,
|
| 70 |
+
api_key: str,
|
| 71 |
+
model: str,
|
| 72 |
+
messages: List[Dict[str, str]],
|
| 73 |
+
*,
|
| 74 |
+
timeout: int,
|
| 75 |
+
max_tokens: int,
|
| 76 |
+
temperature: float,
|
| 77 |
+
cache_ttl: int,
|
| 78 |
+
cache_namespace: Optional[str],
|
| 79 |
+
rate_limit_delay: Optional[float] = None,
|
| 80 |
+
) -> str:
|
| 81 |
+
"""Minimal OpenAI-compatible chat.completions client (no external deps)."""
|
| 82 |
+
url = api_base.rstrip("/") + "/chat/completions"
|
| 83 |
+
payload: Dict[str, Any] = {
|
| 84 |
+
"model": model,
|
| 85 |
+
"messages": messages,
|
| 86 |
+
"max_tokens": max_tokens,
|
| 87 |
+
"temperature": temperature,
|
| 88 |
+
}
|
| 89 |
+
if cache_ttl > 0:
|
| 90 |
+
cache_obj: Dict[str, Any] = {"ttl": cache_ttl}
|
| 91 |
+
if cache_namespace:
|
| 92 |
+
cache_obj["namespace"] = cache_namespace
|
| 93 |
+
payload["cache"] = cache_obj
|
| 94 |
+
|
| 95 |
+
data = json.dumps(payload).encode("utf-8")
|
| 96 |
+
headers = {"Content-Type": "application/json"}
|
| 97 |
+
if api_key:
|
| 98 |
+
headers["Authorization"] = f"Bearer {api_key}"
|
| 99 |
+
|
| 100 |
+
req = urllib.request.Request(url, data=data, headers=headers, method="POST")
|
| 101 |
+
opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
|
| 102 |
+
try:
|
| 103 |
+
with opener.open(req, timeout=timeout) as resp:
|
| 104 |
+
resp_bytes = resp.read()
|
| 105 |
+
except urllib.error.HTTPError as e:
|
| 106 |
+
detail = e.read().decode("utf-8", errors="ignore") if hasattr(e, "read") else ""
|
| 107 |
+
if e.code == 429:
|
| 108 |
+
retry_after = None
|
| 109 |
+
if hasattr(e, "headers") and e.headers:
|
| 110 |
+
retry_after_header = e.headers.get("Retry-After")
|
| 111 |
+
if retry_after_header:
|
| 112 |
+
try:
|
| 113 |
+
retry_after = float(retry_after_header)
|
| 114 |
+
except ValueError:
|
| 115 |
+
retry_after = None
|
| 116 |
+
wait = retry_after or rate_limit_delay or 5.0
|
| 117 |
+
raise RateLimitError(wait, f"API HTTP 429: {detail}") from e
|
| 118 |
+
raise RuntimeError(f"API HTTP error {e.code}: {detail}") from e
|
| 119 |
+
except urllib.error.URLError as e:
|
| 120 |
+
raise RuntimeError(f"API request failed: {e}") from e
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
response = json.loads(resp_bytes.decode("utf-8"))
|
| 124 |
+
except json.JSONDecodeError as e:
|
| 125 |
+
raise RuntimeError(f"Failed to decode API response: {resp_bytes!r}") from e
|
| 126 |
+
|
| 127 |
+
choices = response.get("choices", [])
|
| 128 |
+
if not choices:
|
| 129 |
+
raise RuntimeError(f"Empty choices from API: {response}")
|
| 130 |
+
content = choices[0].get("message", {}).get("content", "")
|
| 131 |
+
if not content:
|
| 132 |
+
raise RuntimeError(f"Empty content from API: {response}")
|
| 133 |
+
return content.strip()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def build_gen_messages(prompt: str) -> List[Dict[str, str]]:
|
| 137 |
+
return [
|
| 138 |
+
{"role": "system", "content": GEN_SYSTEM_PROMPT},
|
| 139 |
+
{"role": "user", "content": prompt},
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def build_judge_messages(reference_answer: str, candidate_answer: str) -> List[Dict[str, str]]:
|
| 144 |
+
user = (
|
| 145 |
+
"Decide if the model's boxed answer matches the reference answer.\n"
|
| 146 |
+
f"Reference answer: {reference_answer}\n"
|
| 147 |
+
f"Model boxed answer (only the content inside \\box{{}}): {candidate_answer}\n"
|
| 148 |
+
"Output only True if they are semantically consistent; otherwise output False."
|
| 149 |
+
)
|
| 150 |
+
return [
|
| 151 |
+
{"role": "system", "content": JUDGE_SYSTEM_PROMPT},
|
| 152 |
+
{"role": "user", "content": user},
|
| 153 |
+
]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def parse_bool(text: str) -> bool:
|
| 157 |
+
first = text.strip().splitlines()[0].strip().lower()
|
| 158 |
+
if first in {"true", "yes"}:
|
| 159 |
+
return True
|
| 160 |
+
if first in {"false", "no"}:
|
| 161 |
+
return False
|
| 162 |
+
# fallback: check substring
|
| 163 |
+
if "true" in first and "false" not in first:
|
| 164 |
+
return True
|
| 165 |
+
if "false" in first:
|
| 166 |
+
return False
|
| 167 |
+
raise ValueError(f"Cannot parse boolean from: {text!r}")
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _load_tokenizer(tokenizer_model: str):
|
| 171 |
+
tok_path = Path(tokenizer_model)
|
| 172 |
+
if tok_path.exists():
|
| 173 |
+
tokenizer = AutoTokenizer.from_pretrained(tok_path.as_posix(), local_files_only=True)
|
| 174 |
+
else:
|
| 175 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
|
| 176 |
+
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 177 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 178 |
+
return tokenizer
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _split_gsm8k_answer(answer: str) -> Optional[Tuple[str, str]]:
|
| 182 |
+
"""Return (thinking_text, final_answer) parsed from GSM8K `answer`."""
|
| 183 |
+
text = (answer or "").strip()
|
| 184 |
+
if not text:
|
| 185 |
+
return None
|
| 186 |
+
if "####" not in text:
|
| 187 |
+
return None
|
| 188 |
+
thinking, final = text.rsplit("####", 1)
|
| 189 |
+
thinking = thinking.strip()
|
| 190 |
+
final = final.strip()
|
| 191 |
+
if not final:
|
| 192 |
+
return None
|
| 193 |
+
return thinking, final
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _is_token_span(span: Any) -> bool:
|
| 197 |
+
return isinstance(span, list) and len(span) == 2 and all(isinstance(x, int) for x in span)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _build_cached_example(
|
| 201 |
+
*,
|
| 202 |
+
question: str,
|
| 203 |
+
answer: str,
|
| 204 |
+
tokenizer,
|
| 205 |
+
example_idx: int,
|
| 206 |
+
source_path: str,
|
| 207 |
+
) -> Optional[CachedExample]:
|
| 208 |
+
parsed = _split_gsm8k_answer(answer)
|
| 209 |
+
if parsed is None:
|
| 210 |
+
return None
|
| 211 |
+
thinking_text, final_answer = parsed
|
| 212 |
+
|
| 213 |
+
prompt = question.strip()
|
| 214 |
+
target = f"{thinking_text}\n{final_answer}" if thinking_text else final_answer
|
| 215 |
+
|
| 216 |
+
example = CachedExample(
|
| 217 |
+
prompt=prompt,
|
| 218 |
+
target=target,
|
| 219 |
+
indices_to_explain=None,
|
| 220 |
+
attr_mask_indices=None,
|
| 221 |
+
sink_span=None,
|
| 222 |
+
thinking_span=None,
|
| 223 |
+
metadata={
|
| 224 |
+
"dataset": "math_mine",
|
| 225 |
+
"source_path": source_path,
|
| 226 |
+
"example_idx": int(example_idx),
|
| 227 |
+
"raw_question": question,
|
| 228 |
+
"raw_answer": answer,
|
| 229 |
+
"reference_answer": final_answer,
|
| 230 |
+
"boxed_answer": final_answer,
|
| 231 |
+
},
|
| 232 |
+
)
|
| 233 |
+
example = attach_spans_from_answer(example, tokenizer, final_answer)
|
| 234 |
+
if not _is_token_span(example.sink_span):
|
| 235 |
+
return None
|
| 236 |
+
|
| 237 |
+
# exp2 requires token-level indices_to_explain=[start_tok,end_tok] (closed interval).
|
| 238 |
+
indices_to_explain = list(example.sink_span)
|
| 239 |
+
thinking_span = example.thinking_span
|
| 240 |
+
if thinking_span is not None and _is_token_span(thinking_span) and indices_to_explain[0] == 0:
|
| 241 |
+
# No room for "thinking" tokens; avoid overlapping spans.
|
| 242 |
+
thinking_span = None
|
| 243 |
+
|
| 244 |
+
return CachedExample(
|
| 245 |
+
prompt=example.prompt,
|
| 246 |
+
target=example.target,
|
| 247 |
+
indices_to_explain=indices_to_explain,
|
| 248 |
+
attr_mask_indices=example.attr_mask_indices,
|
| 249 |
+
sink_span=indices_to_explain,
|
| 250 |
+
thinking_span=thinking_span,
|
| 251 |
+
metadata=example.metadata,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _build_resampled_example(
|
| 256 |
+
*,
|
| 257 |
+
question: str,
|
| 258 |
+
raw_answer: str,
|
| 259 |
+
reference_answer: str,
|
| 260 |
+
generation: str,
|
| 261 |
+
tokenizer,
|
| 262 |
+
example_idx: int,
|
| 263 |
+
source_path: str,
|
| 264 |
+
judge_response: str,
|
| 265 |
+
generator_model: str,
|
| 266 |
+
judge_model: str,
|
| 267 |
+
) -> Optional[CachedExample]:
|
| 268 |
+
parsed = split_boxed_generation(generation)
|
| 269 |
+
if not parsed:
|
| 270 |
+
return None
|
| 271 |
+
|
| 272 |
+
thinking_text, _boxed_segment, boxed_answer = parsed
|
| 273 |
+
target_text = f"{thinking_text}\n{boxed_answer}" if thinking_text else boxed_answer
|
| 274 |
+
|
| 275 |
+
example = CachedExample(
|
| 276 |
+
prompt=question.strip(),
|
| 277 |
+
target=target_text,
|
| 278 |
+
indices_to_explain=None,
|
| 279 |
+
attr_mask_indices=None,
|
| 280 |
+
sink_span=None,
|
| 281 |
+
thinking_span=None,
|
| 282 |
+
metadata={
|
| 283 |
+
"dataset": "math_mine",
|
| 284 |
+
"source_path": source_path,
|
| 285 |
+
"example_idx": int(example_idx),
|
| 286 |
+
"raw_question": question,
|
| 287 |
+
"raw_answer": raw_answer,
|
| 288 |
+
"reference_answer": reference_answer,
|
| 289 |
+
"judge_response": judge_response,
|
| 290 |
+
"generator_model": generator_model,
|
| 291 |
+
"judge_model": judge_model,
|
| 292 |
+
},
|
| 293 |
+
)
|
| 294 |
+
example = attach_spans_from_answer(example, tokenizer, boxed_answer)
|
| 295 |
+
if not _is_token_span(example.sink_span):
|
| 296 |
+
return None
|
| 297 |
+
|
| 298 |
+
indices_to_explain = list(example.sink_span)
|
| 299 |
+
return CachedExample(
|
| 300 |
+
prompt=example.prompt,
|
| 301 |
+
target=example.target,
|
| 302 |
+
indices_to_explain=indices_to_explain,
|
| 303 |
+
attr_mask_indices=example.attr_mask_indices,
|
| 304 |
+
sink_span=indices_to_explain,
|
| 305 |
+
thinking_span=example.thinking_span,
|
| 306 |
+
metadata=example.metadata,
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _write_jsonl(path: Path, *, examples) -> int:
|
| 311 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 312 |
+
count = 0
|
| 313 |
+
with path.open("w", encoding="utf-8") as f:
|
| 314 |
+
for ex in examples:
|
| 315 |
+
f.write(json.dumps(asdict(ex), ensure_ascii=False) + "\n")
|
| 316 |
+
count += 1
|
| 317 |
+
return count
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def main() -> None:
|
| 321 |
+
ap = argparse.ArgumentParser("Prepare data/math_mine.json for exp2 cached JSONL.")
|
| 322 |
+
ap.add_argument("--in_json", type=str, default="data/math_mine.json")
|
| 323 |
+
ap.add_argument("--out_jsonl", type=str, default="exp/exp2/data/math.jsonl")
|
| 324 |
+
ap.add_argument(
|
| 325 |
+
"--tokenizer_model",
|
| 326 |
+
type=str,
|
| 327 |
+
required=True,
|
| 328 |
+
help="Tokenizer name or local path; must match the tokenizer used in exp2 attribution.",
|
| 329 |
+
)
|
| 330 |
+
ap.add_argument(
|
| 331 |
+
"--mode",
|
| 332 |
+
type=str,
|
| 333 |
+
choices=["map", "resample"],
|
| 334 |
+
default="map",
|
| 335 |
+
help="map=offline mapping from GSM8K answers; resample=generate+judge like exp/exp2/sample_and_filter.py.",
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# Resample (online) options (kept compatible with exp/exp2/sample_and_filter.py).
|
| 339 |
+
ap.add_argument("--max_examples", type=int, default=100, help="Number of judge=True examples to keep (resample mode).")
|
| 340 |
+
ap.add_argument("--seed", type=int, default=42, help="Shuffle seed (only used with --shuffle).")
|
| 341 |
+
ap.add_argument("--shuffle", action="store_true", help="Shuffle examples before attempting (resample mode).")
|
| 342 |
+
ap.add_argument("--api_base", type=str, default="http://localhost:4000/v1", help="Chat API base URL.")
|
| 343 |
+
ap.add_argument("--api_key", type=str, default=None, help="API key; defaults to FLASHTRACE_API_KEY/OPENAI_API_KEY.")
|
| 344 |
+
ap.add_argument("--generator_model", type=str, default="qwen3-235b-a22b-2507")
|
| 345 |
+
ap.add_argument("--judge_model", type=str, default="deepseek-v3-1-terminus")
|
| 346 |
+
ap.add_argument("--api_timeout", type=int, default=300)
|
| 347 |
+
ap.add_argument("--api_max_tokens", type=int, default=8192)
|
| 348 |
+
ap.add_argument("--api_temperature", type=float, default=0.0)
|
| 349 |
+
ap.add_argument("--api_cache_ttl", type=int, default=600)
|
| 350 |
+
ap.add_argument("--api_cache_namespace", type=str, default="flashtrace-exp2")
|
| 351 |
+
ap.add_argument("--retry_delay", type=float, default=2.0)
|
| 352 |
+
ap.add_argument("--retries", type=int, default=2, help="Additional retries on API failure.")
|
| 353 |
+
ap.add_argument("--request_interval", type=float, default=1.0, help="Sleep seconds between generation calls.")
|
| 354 |
+
ap.add_argument("--judge_interval", type=float, default=1.0, help="Sleep seconds between judge calls.")
|
| 355 |
+
ap.add_argument("--rate_limit_delay", type=float, default=5.0, help="Seconds to wait on HTTP 429 before retrying.")
|
| 356 |
+
args = ap.parse_args()
|
| 357 |
+
|
| 358 |
+
in_path = Path(args.in_json)
|
| 359 |
+
out_path = Path(args.out_jsonl)
|
| 360 |
+
tokenizer = _load_tokenizer(args.tokenizer_model)
|
| 361 |
+
|
| 362 |
+
raw = json.loads(in_path.read_text(encoding="utf-8"))
|
| 363 |
+
if not isinstance(raw, list):
|
| 364 |
+
raise SystemExit(f"Expected a JSON array in {in_path}, got {type(raw).__name__}.")
|
| 365 |
+
|
| 366 |
+
source_total = len(raw)
|
| 367 |
+
total = 0
|
| 368 |
+
kept = 0
|
| 369 |
+
skipped_empty_q = 0
|
| 370 |
+
skipped_empty_a = 0
|
| 371 |
+
skipped_parse = 0
|
| 372 |
+
skipped_span = 0
|
| 373 |
+
|
| 374 |
+
examples = []
|
| 375 |
+
if args.mode == "map":
|
| 376 |
+
attempted = None
|
| 377 |
+
skipped_format = None
|
| 378 |
+
judged_false = None
|
| 379 |
+
for idx, item in enumerate(raw):
|
| 380 |
+
total += 1
|
| 381 |
+
if not isinstance(item, dict):
|
| 382 |
+
skipped_parse += 1
|
| 383 |
+
continue
|
| 384 |
+
|
| 385 |
+
question = str(item.get("question") or "")
|
| 386 |
+
answer = str(item.get("answer") or "")
|
| 387 |
+
if not question.strip():
|
| 388 |
+
skipped_empty_q += 1
|
| 389 |
+
continue
|
| 390 |
+
if not answer.strip():
|
| 391 |
+
skipped_empty_a += 1
|
| 392 |
+
continue
|
| 393 |
+
|
| 394 |
+
ex = _build_cached_example(
|
| 395 |
+
question=question,
|
| 396 |
+
answer=answer,
|
| 397 |
+
tokenizer=tokenizer,
|
| 398 |
+
example_idx=idx,
|
| 399 |
+
source_path=str(in_path),
|
| 400 |
+
)
|
| 401 |
+
if ex is None:
|
| 402 |
+
# distinguish parse-vs-span failure
|
| 403 |
+
parsed = _split_gsm8k_answer(answer)
|
| 404 |
+
if parsed is None:
|
| 405 |
+
skipped_parse += 1
|
| 406 |
+
else:
|
| 407 |
+
skipped_span += 1
|
| 408 |
+
continue
|
| 409 |
+
|
| 410 |
+
examples.append(ex)
|
| 411 |
+
kept += 1
|
| 412 |
+
else:
|
| 413 |
+
api_key = args.api_key or os.environ.get("FLASHTRACE_API_KEY") or os.environ.get("OPENAI_API_KEY")
|
| 414 |
+
if not api_key:
|
| 415 |
+
raise SystemExit("resample mode requires --api_key or FLASHTRACE_API_KEY/OPENAI_API_KEY.")
|
| 416 |
+
|
| 417 |
+
attempted = 0
|
| 418 |
+
skipped_format = 0
|
| 419 |
+
judged_false = 0
|
| 420 |
+
|
| 421 |
+
indices = list(range(len(raw)))
|
| 422 |
+
if bool(args.shuffle):
|
| 423 |
+
import random
|
| 424 |
+
|
| 425 |
+
rnd = random.Random(int(args.seed))
|
| 426 |
+
rnd.shuffle(indices)
|
| 427 |
+
|
| 428 |
+
kept_bar = tqdm(total=int(args.max_examples), desc="Kept (judge=True)", position=1, leave=False)
|
| 429 |
+
for loop_idx in tqdm(indices, total=len(indices), desc="Resampling"):
|
| 430 |
+
if kept >= int(args.max_examples):
|
| 431 |
+
break
|
| 432 |
+
|
| 433 |
+
total += 1
|
| 434 |
+
item = raw[loop_idx]
|
| 435 |
+
if not isinstance(item, dict):
|
| 436 |
+
skipped_parse += 1
|
| 437 |
+
continue
|
| 438 |
+
|
| 439 |
+
question = str(item.get("question") or "")
|
| 440 |
+
answer = str(item.get("answer") or "")
|
| 441 |
+
if not question.strip():
|
| 442 |
+
skipped_empty_q += 1
|
| 443 |
+
continue
|
| 444 |
+
if not answer.strip():
|
| 445 |
+
skipped_empty_a += 1
|
| 446 |
+
continue
|
| 447 |
+
|
| 448 |
+
parsed = _split_gsm8k_answer(answer)
|
| 449 |
+
if parsed is None:
|
| 450 |
+
skipped_parse += 1
|
| 451 |
+
continue
|
| 452 |
+
_ref_thinking, reference_answer = parsed
|
| 453 |
+
|
| 454 |
+
attempted += 1
|
| 455 |
+
gen_messages = build_gen_messages(question.strip())
|
| 456 |
+
|
| 457 |
+
# Step 1: generation
|
| 458 |
+
for attempt in range(int(args.retries) + 1):
|
| 459 |
+
try:
|
| 460 |
+
generation = call_chat_api(
|
| 461 |
+
str(args.api_base),
|
| 462 |
+
str(api_key),
|
| 463 |
+
str(args.generator_model),
|
| 464 |
+
gen_messages,
|
| 465 |
+
timeout=int(args.api_timeout),
|
| 466 |
+
max_tokens=int(args.api_max_tokens),
|
| 467 |
+
temperature=float(args.api_temperature),
|
| 468 |
+
cache_ttl=int(args.api_cache_ttl),
|
| 469 |
+
cache_namespace=str(args.api_cache_namespace) if args.api_cache_namespace else None,
|
| 470 |
+
rate_limit_delay=float(args.rate_limit_delay) if args.rate_limit_delay is not None else None,
|
| 471 |
+
)
|
| 472 |
+
break
|
| 473 |
+
except RateLimitError as e:
|
| 474 |
+
if attempt >= int(args.retries):
|
| 475 |
+
raise
|
| 476 |
+
time.sleep(float(e.wait_seconds))
|
| 477 |
+
except Exception: # noqa: BLE001
|
| 478 |
+
if attempt >= int(args.retries):
|
| 479 |
+
raise
|
| 480 |
+
time.sleep(float(args.retry_delay))
|
| 481 |
+
if float(args.request_interval) > 0:
|
| 482 |
+
time.sleep(float(args.request_interval))
|
| 483 |
+
|
| 484 |
+
parsed_gen = split_boxed_generation(generation)
|
| 485 |
+
if not parsed_gen:
|
| 486 |
+
skipped_format += 1
|
| 487 |
+
print(f"[attempt={attempted}] skipped=format")
|
| 488 |
+
continue
|
| 489 |
+
|
| 490 |
+
thinking_text, _boxed_segment, boxed_answer = parsed_gen
|
| 491 |
+
judge_messages = build_judge_messages(reference_answer, boxed_answer)
|
| 492 |
+
|
| 493 |
+
ok = False
|
| 494 |
+
judge_resp = ""
|
| 495 |
+
for attempt in range(int(args.retries) + 1):
|
| 496 |
+
try:
|
| 497 |
+
judge_resp = call_chat_api(
|
| 498 |
+
str(args.api_base),
|
| 499 |
+
str(api_key),
|
| 500 |
+
str(args.judge_model),
|
| 501 |
+
judge_messages,
|
| 502 |
+
timeout=int(args.api_timeout),
|
| 503 |
+
max_tokens=64,
|
| 504 |
+
temperature=0.0,
|
| 505 |
+
cache_ttl=int(args.api_cache_ttl),
|
| 506 |
+
cache_namespace=str(args.api_cache_namespace) if args.api_cache_namespace else None,
|
| 507 |
+
rate_limit_delay=float(args.rate_limit_delay) if args.rate_limit_delay is not None else None,
|
| 508 |
+
)
|
| 509 |
+
ok = parse_bool(judge_resp)
|
| 510 |
+
break
|
| 511 |
+
except RateLimitError as e:
|
| 512 |
+
if attempt >= int(args.retries):
|
| 513 |
+
raise
|
| 514 |
+
time.sleep(float(e.wait_seconds))
|
| 515 |
+
except Exception: # noqa: BLE001
|
| 516 |
+
if attempt >= int(args.retries):
|
| 517 |
+
raise
|
| 518 |
+
time.sleep(float(args.retry_delay))
|
| 519 |
+
if float(args.judge_interval) > 0:
|
| 520 |
+
time.sleep(float(args.judge_interval))
|
| 521 |
+
|
| 522 |
+
if not ok:
|
| 523 |
+
judged_false += 1
|
| 524 |
+
print(f"[attempt={attempted}] judge=filtered")
|
| 525 |
+
continue
|
| 526 |
+
|
| 527 |
+
ex = _build_resampled_example(
|
| 528 |
+
question=question,
|
| 529 |
+
raw_answer=answer,
|
| 530 |
+
reference_answer=reference_answer,
|
| 531 |
+
generation=generation,
|
| 532 |
+
tokenizer=tokenizer,
|
| 533 |
+
example_idx=int(loop_idx),
|
| 534 |
+
source_path=str(in_path),
|
| 535 |
+
judge_response=judge_resp,
|
| 536 |
+
generator_model=str(args.generator_model),
|
| 537 |
+
judge_model=str(args.judge_model),
|
| 538 |
+
)
|
| 539 |
+
if ex is None:
|
| 540 |
+
skipped_span += 1
|
| 541 |
+
print(f"[attempt={attempted}] skipped=span")
|
| 542 |
+
continue
|
| 543 |
+
|
| 544 |
+
examples.append(ex)
|
| 545 |
+
kept += 1
|
| 546 |
+
kept_bar.update(1)
|
| 547 |
+
print(f"[attempt={attempted}] judge=kept")
|
| 548 |
+
|
| 549 |
+
kept_bar.close()
|
| 550 |
+
|
| 551 |
+
written = _write_jsonl(out_path, examples=examples)
|
| 552 |
+
if written != kept:
|
| 553 |
+
raise SystemExit(f"Internal error: written={written} != kept={kept}")
|
| 554 |
+
|
| 555 |
+
print(
|
| 556 |
+
json.dumps(
|
| 557 |
+
{
|
| 558 |
+
"in_json": str(in_path),
|
| 559 |
+
"out_jsonl": str(out_path),
|
| 560 |
+
"tokenizer_model": args.tokenizer_model,
|
| 561 |
+
"mode": str(args.mode),
|
| 562 |
+
"source_total": int(source_total),
|
| 563 |
+
"visited": total,
|
| 564 |
+
"kept": kept,
|
| 565 |
+
"skipped_empty_question": skipped_empty_q,
|
| 566 |
+
"skipped_empty_answer": skipped_empty_a,
|
| 567 |
+
"skipped_parse": skipped_parse,
|
| 568 |
+
"skipped_span": skipped_span,
|
| 569 |
+
"attempted": attempted,
|
| 570 |
+
"skipped_format": skipped_format,
|
| 571 |
+
"judged_false": judged_false,
|
| 572 |
+
"max_examples": int(args.max_examples) if str(args.mode) == "resample" else None,
|
| 573 |
+
"api_base": str(args.api_base) if str(args.mode) == "resample" else None,
|
| 574 |
+
"generator_model": str(args.generator_model) if str(args.mode) == "resample" else None,
|
| 575 |
+
"judge_model": str(args.judge_model) if str(args.mode) == "resample" else None,
|
| 576 |
+
},
|
| 577 |
+
ensure_ascii=False,
|
| 578 |
+
indent=2,
|
| 579 |
+
)
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
if __name__ == "__main__":
|
| 584 |
+
main()
|
exp/exp2/migrate_indices_to_explain_token_span.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Migrate exp2 cached JSONL to token-span `indices_to_explain`.
|
| 3 |
+
|
| 4 |
+
This converts legacy caches that used sentence indices (e.g. `[-2]`) into the
|
| 5 |
+
token-span format:
|
| 6 |
+
|
| 7 |
+
indices_to_explain = [start_tok, end_tok]
|
| 8 |
+
|
| 9 |
+
Where the span points to the boxed-inner (final answer) token span in `target`
|
| 10 |
+
under `tokenizer(target, add_special_tokens=False)`.
|
| 11 |
+
|
| 12 |
+
Rule:
|
| 13 |
+
1) If `sink_span` exists and looks valid -> copy it to `indices_to_explain`
|
| 14 |
+
2) Else try to recompute spans from `target` + `metadata.boxed_answer` using
|
| 15 |
+
`exp/exp2/dataset_utils.attach_spans_from_answer`
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import json
|
| 22 |
+
import sys
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Any, Dict, Optional
|
| 25 |
+
|
| 26 |
+
from transformers import AutoTokenizer
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _ensure_repo_root_on_path() -> None:
|
| 30 |
+
repo_root = Path(__file__).resolve().parents[2]
|
| 31 |
+
if str(repo_root) not in sys.path:
|
| 32 |
+
sys.path.insert(0, str(repo_root))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _is_token_span(span: Any) -> bool:
|
| 36 |
+
return isinstance(span, list) and len(span) == 2 and all(isinstance(x, int) for x in span)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _load_tokenizer(tokenizer_model: str):
|
| 40 |
+
tok_path = Path(tokenizer_model)
|
| 41 |
+
if tok_path.exists():
|
| 42 |
+
return AutoTokenizer.from_pretrained(tok_path.as_posix(), local_files_only=True)
|
| 43 |
+
return AutoTokenizer.from_pretrained(tokenizer_model)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _migrate_obj(obj: Dict[str, Any], tokenizer) -> tuple[Dict[str, Any], bool]:
|
| 47 |
+
sink_span = obj.get("sink_span")
|
| 48 |
+
if _is_token_span(sink_span):
|
| 49 |
+
obj["indices_to_explain"] = sink_span
|
| 50 |
+
return obj, True
|
| 51 |
+
|
| 52 |
+
_ensure_repo_root_on_path()
|
| 53 |
+
from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer # noqa: E402
|
| 54 |
+
|
| 55 |
+
example = CachedExample(
|
| 56 |
+
prompt=obj.get("prompt") or "",
|
| 57 |
+
target=obj.get("target"),
|
| 58 |
+
indices_to_explain=obj.get("indices_to_explain"),
|
| 59 |
+
attr_mask_indices=obj.get("attr_mask_indices"),
|
| 60 |
+
sink_span=obj.get("sink_span"),
|
| 61 |
+
thinking_span=obj.get("thinking_span"),
|
| 62 |
+
metadata=obj.get("metadata") or {},
|
| 63 |
+
)
|
| 64 |
+
answer_text = (example.metadata.get("boxed_answer") or "").strip() or None
|
| 65 |
+
migrated = attach_spans_from_answer(example, tokenizer, answer_text)
|
| 66 |
+
if not _is_token_span(migrated.sink_span):
|
| 67 |
+
return obj, False
|
| 68 |
+
|
| 69 |
+
obj["sink_span"] = migrated.sink_span
|
| 70 |
+
obj["thinking_span"] = migrated.thinking_span
|
| 71 |
+
obj["indices_to_explain"] = migrated.sink_span
|
| 72 |
+
obj["metadata"] = migrated.metadata
|
| 73 |
+
return obj, True
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main() -> None:
|
| 77 |
+
ap = argparse.ArgumentParser()
|
| 78 |
+
ap.add_argument("--in_jsonl", type=str, required=True)
|
| 79 |
+
ap.add_argument("--out_jsonl", type=str, required=True)
|
| 80 |
+
ap.add_argument("--tokenizer_model", type=str, required=True)
|
| 81 |
+
ap.add_argument("--strict", action="store_true", help="Fail on any line that cannot be migrated.")
|
| 82 |
+
args = ap.parse_args()
|
| 83 |
+
|
| 84 |
+
tokenizer = _load_tokenizer(args.tokenizer_model)
|
| 85 |
+
|
| 86 |
+
in_path = Path(args.in_jsonl)
|
| 87 |
+
out_path = Path(args.out_jsonl)
|
| 88 |
+
|
| 89 |
+
try:
|
| 90 |
+
same_path = in_path.resolve() == out_path.resolve()
|
| 91 |
+
except FileNotFoundError:
|
| 92 |
+
same_path = False
|
| 93 |
+
|
| 94 |
+
tmp_out_path = out_path
|
| 95 |
+
if same_path:
|
| 96 |
+
tmp_out_path = out_path.with_name(out_path.name + ".tmp")
|
| 97 |
+
if tmp_out_path.exists():
|
| 98 |
+
tmp_out_path.unlink()
|
| 99 |
+
|
| 100 |
+
tmp_out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 101 |
+
|
| 102 |
+
total = 0
|
| 103 |
+
migrated_ok = 0
|
| 104 |
+
bad = 0
|
| 105 |
+
|
| 106 |
+
with in_path.open("r", encoding="utf-8") as fin, tmp_out_path.open("w", encoding="utf-8") as fout:
|
| 107 |
+
for line_no, line in enumerate(fin, start=1):
|
| 108 |
+
if not line.strip():
|
| 109 |
+
continue
|
| 110 |
+
total += 1
|
| 111 |
+
obj: Dict[str, Any] = json.loads(line)
|
| 112 |
+
new_obj, ok = _migrate_obj(obj, tokenizer)
|
| 113 |
+
if ok:
|
| 114 |
+
migrated_ok += 1
|
| 115 |
+
else:
|
| 116 |
+
bad += 1
|
| 117 |
+
if args.strict:
|
| 118 |
+
raise RuntimeError(f"cannot migrate line {line_no}: cannot resolve sink_span token span")
|
| 119 |
+
fout.write(json.dumps(new_obj, ensure_ascii=False) + "\n")
|
| 120 |
+
|
| 121 |
+
if same_path:
|
| 122 |
+
tmp_out_path.replace(out_path)
|
| 123 |
+
print(f"[done] total={total} migrated_ok={migrated_ok} bad={bad} wrote={out_path} (in-place)")
|
| 124 |
+
else:
|
| 125 |
+
print(f"[done] total={total} migrated_ok={migrated_ok} bad={bad} wrote={out_path}")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
|
| 129 |
+
main()
|
exp/exp2/out.log
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[1/500] judge=kept
|
| 2 |
+
[2/500] judge=kept
|
| 3 |
+
[3/500] judge=kept
|
| 4 |
+
[4/500] judge=kept
|
| 5 |
+
[5/500] judge=kept
|
| 6 |
+
[6/500] judge=kept
|
| 7 |
+
[7/500] judge=kept
|
| 8 |
+
[8/500] judge=kept
|
| 9 |
+
[9/500] judge=kept
|
| 10 |
+
[10/500] judge=kept
|
| 11 |
+
[11/500] judge=kept
|
| 12 |
+
[12/500] judge=kept
|
| 13 |
+
[13/500] judge=kept
|
| 14 |
+
[14/500] judge=kept
|
| 15 |
+
[15/500] judge=kept
|
| 16 |
+
[16/500] judge=kept
|
| 17 |
+
[17/500] judge=kept
|
| 18 |
+
[18/500] judge=kept
|
| 19 |
+
[19/500] judge=kept
|
| 20 |
+
[20/500] judge=kept
|
| 21 |
+
[21/500] judge=kept
|
| 22 |
+
[22/500] judge=kept
|
| 23 |
+
[23/500] judge=kept
|
| 24 |
+
[24/500] judge=kept
|
| 25 |
+
[25/500] judge=kept
|
| 26 |
+
[26/500] judge=kept
|
| 27 |
+
[27/500] judge=kept
|
| 28 |
+
[28/500] judge=kept
|
| 29 |
+
[29/500] judge=kept
|
| 30 |
+
[30/500] judge=kept
|
| 31 |
+
[31/500] judge=kept
|
| 32 |
+
[32/500] judge=kept
|
| 33 |
+
[33/500] judge=kept
|
| 34 |
+
[34/500] judge=kept
|
| 35 |
+
[35/500] judge=kept
|
| 36 |
+
[36/500] judge=kept
|
| 37 |
+
[37/500] judge=kept
|
| 38 |
+
[38/500] judge=kept
|
| 39 |
+
[39/500] judge=kept
|
| 40 |
+
[40/500] judge=kept
|
| 41 |
+
[41/500] judge=kept
|
| 42 |
+
[42/500] judge=kept
|
| 43 |
+
[43/500] judge=kept
|
| 44 |
+
[44/500] judge=kept
|
| 45 |
+
[45/500] judge=kept
|
| 46 |
+
[46/500] judge=kept
|
| 47 |
+
[47/500] judge=kept
|
| 48 |
+
[48/500] judge=kept
|
| 49 |
+
[49/500] judge=kept
|
| 50 |
+
[50/500] judge=kept
|
| 51 |
+
[51/500] judge=kept
|
| 52 |
+
[52/500] judge=kept
|
| 53 |
+
[53/500] judge=kept
|
| 54 |
+
[54/500] judge=kept
|
| 55 |
+
[55/500] judge=kept
|
| 56 |
+
[56/500] judge=kept
|
| 57 |
+
[57/500] judge=kept
|
| 58 |
+
[58/500] judge=kept
|
| 59 |
+
[59/500] judge=kept
|
| 60 |
+
[60/500] judge=kept
|
| 61 |
+
[61/500] judge=kept
|
| 62 |
+
[62/500] judge=kept
|
| 63 |
+
[63/500] skipped=format
|
| 64 |
+
[64/500] judge=kept
|
| 65 |
+
[65/500] judge=kept
|
| 66 |
+
[66/500] judge=kept
|
| 67 |
+
[67/500] judge=kept
|
| 68 |
+
[68/500] judge=kept
|
| 69 |
+
[69/500] judge=kept
|
| 70 |
+
[70/500] judge=kept
|
| 71 |
+
[71/500] judge=kept
|
| 72 |
+
[72/500] judge=kept
|
| 73 |
+
[73/500] judge=kept
|
| 74 |
+
[74/500] judge=kept
|
| 75 |
+
[75/500] judge=kept
|
| 76 |
+
[76/500] judge=kept
|
| 77 |
+
[77/500] judge=kept
|
| 78 |
+
[78/500] judge=kept
|
| 79 |
+
[79/500] judge=kept
|
| 80 |
+
[80/500] judge=kept
|
| 81 |
+
[81/500] judge=kept
|
| 82 |
+
[82/500] judge=kept
|
| 83 |
+
[83/500] judge=kept
|
| 84 |
+
[84/500] judge=kept
|
| 85 |
+
[85/500] judge=kept
|
| 86 |
+
[86/500] judge=kept
|
| 87 |
+
[87/500] judge=kept
|
| 88 |
+
[88/500] judge=kept
|
| 89 |
+
[89/500] judge=kept
|
| 90 |
+
[90/500] judge=kept
|
| 91 |
+
[91/500] judge=kept
|
| 92 |
+
[92/500] judge=kept
|
| 93 |
+
[93/500] judge=kept
|
| 94 |
+
[94/500] judge=kept
|
| 95 |
+
[95/500] judge=kept
|
| 96 |
+
[96/500] judge=kept
|
| 97 |
+
[97/500] judge=kept
|
| 98 |
+
[98/500] judge=kept
|
| 99 |
+
[99/500] judge=kept
|
| 100 |
+
[100/500] judge=kept
|
| 101 |
+
[101/500] judge=kept
|
| 102 |
+
Kept 100 / target 100 (attempted 101 / 500) -> exp/exp2/data/data/ruler_multihop/1024/vt_h10_c1/validation.jsonl.jsonl
|
exp/exp2/run_exp.py
ADDED
|
@@ -0,0 +1,1296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Experiment 2 runner: token-level faithfulness (generation perturbation).
|
| 4 |
+
|
| 5 |
+
AT2 is omitted.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import hashlib
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from itertools import islice
|
| 16 |
+
import math
|
| 17 |
+
import time
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
# Early CUDA mask handling: set CUDA_VISIBLE_DEVICES before importing torch.
|
| 22 |
+
def _early_set_cuda_visible_devices():
|
| 23 |
+
parser = argparse.ArgumentParser(add_help=False)
|
| 24 |
+
parser.add_argument("--cuda", type=str, default=None)
|
| 25 |
+
# parse_known_args keeps the full argv for later parsing by the main parser
|
| 26 |
+
args, _ = parser.parse_known_args(sys.argv[1:])
|
| 27 |
+
if args.cuda and "," in args.cuda:
|
| 28 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
_early_set_cuda_visible_devices()
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
import torch
|
| 35 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, utils
|
| 36 |
+
|
| 37 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 38 |
+
|
| 39 |
+
from pathlib import Path
|
| 40 |
+
|
| 41 |
+
# ensure repo root on path
|
| 42 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 43 |
+
if str(REPO_ROOT) not in sys.path:
|
| 44 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 45 |
+
|
| 46 |
+
import llm_attr
|
| 47 |
+
import llm_attr_eval
|
| 48 |
+
from attribution_datasets import AttributionExample
|
| 49 |
+
from exp.exp2 import dataset_utils as ds_utils
|
| 50 |
+
|
| 51 |
+
utils.logging.set_verbosity_error()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _sha1_text(text: str) -> str:
|
| 55 |
+
return hashlib.sha1(text.encode("utf-8")).hexdigest()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _infer_attnlrp_spans_from_hops(
|
| 59 |
+
raw_attributions: Any,
|
| 60 |
+
*,
|
| 61 |
+
gen_len: int,
|
| 62 |
+
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
|
| 63 |
+
if not raw_attributions:
|
| 64 |
+
return (0, max(0, gen_len - 1)), (0, max(0, gen_len - 1))
|
| 65 |
+
sink_span = tuple(int(x) for x in raw_attributions[0].sink_range)
|
| 66 |
+
if len(raw_attributions) >= 2:
|
| 67 |
+
thinking_span = tuple(int(x) for x in raw_attributions[1].sink_range)
|
| 68 |
+
else:
|
| 69 |
+
thinking_span = sink_span
|
| 70 |
+
return sink_span, thinking_span
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _build_hop_trace_payload(
|
| 74 |
+
attr_func: str,
|
| 75 |
+
attr: Any,
|
| 76 |
+
*,
|
| 77 |
+
indices_to_explain: List[int],
|
| 78 |
+
) -> Optional[Dict[str, np.ndarray]]:
|
| 79 |
+
"""Extract per-hop vectors (postprocessed) and minimal span metadata."""
|
| 80 |
+
prompt_len = int(len(getattr(attr, "prompt_tokens", []) or []))
|
| 81 |
+
gen_len = int(len(getattr(attr, "generation_tokens", []) or []))
|
| 82 |
+
total_len = prompt_len + gen_len
|
| 83 |
+
if total_len <= 0:
|
| 84 |
+
return None
|
| 85 |
+
|
| 86 |
+
hop_vectors: List[torch.Tensor] = []
|
| 87 |
+
sink_span_gen: Optional[Tuple[int, int]] = None
|
| 88 |
+
thinking_span_gen: Optional[Tuple[int, int]] = None
|
| 89 |
+
attnlrp_neg_handling: str = ""
|
| 90 |
+
attnlrp_norm_mode: str = ""
|
| 91 |
+
attnlrp_ratio_enabled: int = -1
|
| 92 |
+
|
| 93 |
+
# IFR multi-hop variants expose projected hop vectors via metadata["ifr"]["per_hop_projected"].
|
| 94 |
+
ifr_meta = (getattr(attr, "metadata", None) or {}).get("ifr") or {}
|
| 95 |
+
ifr_per_hop = ifr_meta.get("per_hop_projected") or []
|
| 96 |
+
|
| 97 |
+
if ifr_per_hop:
|
| 98 |
+
hop_vectors = [torch.as_tensor(v, dtype=torch.float32) for v in ifr_per_hop]
|
| 99 |
+
sink_span_gen = ifr_meta.get("sink_span_generation")
|
| 100 |
+
thinking_span_gen = ifr_meta.get("thinking_span_generation")
|
| 101 |
+
if sink_span_gen is not None:
|
| 102 |
+
sink_span_gen = tuple(int(x) for x in sink_span_gen)
|
| 103 |
+
if thinking_span_gen is not None:
|
| 104 |
+
thinking_span_gen = tuple(int(x) for x in thinking_span_gen)
|
| 105 |
+
|
| 106 |
+
elif attr_func in ("ft_attnlrp", "attnlrp_aggregated_multi_hop"):
|
| 107 |
+
meta = getattr(attr, "metadata", None) or {}
|
| 108 |
+
attnlrp_neg_handling = str(meta.get("neg_handling") or "")
|
| 109 |
+
attnlrp_norm_mode = str(meta.get("norm_mode") or "")
|
| 110 |
+
if meta.get("ratio_enabled") is not None:
|
| 111 |
+
attnlrp_ratio_enabled = int(bool(meta.get("ratio_enabled")))
|
| 112 |
+
multi_hop = meta.get("multi_hop_result")
|
| 113 |
+
if multi_hop is None:
|
| 114 |
+
return None
|
| 115 |
+
raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
|
| 116 |
+
if not raw_attributions:
|
| 117 |
+
return None
|
| 118 |
+
hop_vectors = [
|
| 119 |
+
torch.as_tensor(getattr(hop, "token_importance_total"), dtype=torch.float32)
|
| 120 |
+
for hop in raw_attributions
|
| 121 |
+
]
|
| 122 |
+
sink_span_gen, thinking_span_gen = _infer_attnlrp_spans_from_hops(raw_attributions, gen_len=gen_len)
|
| 123 |
+
sink_override = meta.get("sink_span")
|
| 124 |
+
thinking_override = meta.get("thinking_span")
|
| 125 |
+
if sink_override is not None:
|
| 126 |
+
sink_span_gen = tuple(int(x) for x in sink_override)
|
| 127 |
+
if thinking_override is not None:
|
| 128 |
+
thinking_span_gen = tuple(int(x) for x in thinking_override)
|
| 129 |
+
|
| 130 |
+
else:
|
| 131 |
+
return None
|
| 132 |
+
|
| 133 |
+
if sink_span_gen is None:
|
| 134 |
+
sink_span_gen = (0, max(0, gen_len - 1))
|
| 135 |
+
if thinking_span_gen is None:
|
| 136 |
+
thinking_span_gen = sink_span_gen
|
| 137 |
+
|
| 138 |
+
stacked = torch.stack([v.reshape(-1) for v in hop_vectors], dim=0)
|
| 139 |
+
if stacked.shape[1] != total_len:
|
| 140 |
+
raise ValueError(
|
| 141 |
+
f"Hop vector length mismatch for {attr_func}: expected T={total_len}, got {stacked.shape[1]}."
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
return {
|
| 145 |
+
"vh": stacked.detach().cpu().numpy().astype(np.float32, copy=False),
|
| 146 |
+
"prompt_len": np.asarray(prompt_len, dtype=np.int64),
|
| 147 |
+
"gen_len": np.asarray(gen_len, dtype=np.int64),
|
| 148 |
+
"sink_span_gen": np.asarray(sink_span_gen, dtype=np.int64),
|
| 149 |
+
"thinking_span_gen": np.asarray(thinking_span_gen, dtype=np.int64),
|
| 150 |
+
"indices_to_explain_gen": np.asarray(indices_to_explain, dtype=np.int64),
|
| 151 |
+
"attnlrp_neg_handling": np.asarray(attnlrp_neg_handling, dtype="U16"),
|
| 152 |
+
"attnlrp_norm_mode": np.asarray(attnlrp_norm_mode, dtype="U16"),
|
| 153 |
+
"attnlrp_ratio_enabled": np.asarray(attnlrp_ratio_enabled, dtype=np.int64),
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _write_hop_trace(
|
| 158 |
+
trace_dir: Path,
|
| 159 |
+
*,
|
| 160 |
+
example_idx: int,
|
| 161 |
+
attr_func: str,
|
| 162 |
+
prompt: str,
|
| 163 |
+
target: Optional[str],
|
| 164 |
+
payload: Dict[str, np.ndarray],
|
| 165 |
+
manifest_handle,
|
| 166 |
+
) -> None:
|
| 167 |
+
trace_dir.mkdir(parents=True, exist_ok=True)
|
| 168 |
+
npz_name = f"ex_{example_idx:06d}.npz"
|
| 169 |
+
npz_path = trace_dir / npz_name
|
| 170 |
+
np.savez_compressed(npz_path, **payload)
|
| 171 |
+
|
| 172 |
+
record = {
|
| 173 |
+
"example_idx": int(example_idx),
|
| 174 |
+
"attr_func": attr_func,
|
| 175 |
+
"file": npz_name,
|
| 176 |
+
"prompt_sha1": _sha1_text(prompt),
|
| 177 |
+
"target_sha1": _sha1_text(target) if target is not None else None,
|
| 178 |
+
"prompt_len": int(payload["prompt_len"].item()),
|
| 179 |
+
"gen_len": int(payload["gen_len"].item()),
|
| 180 |
+
"n_hops_plus_one": int(payload["vh"].shape[0]),
|
| 181 |
+
"total_len": int(payload["vh"].shape[1]),
|
| 182 |
+
"sink_span_gen": payload["sink_span_gen"].tolist(),
|
| 183 |
+
"thinking_span_gen": payload["thinking_span_gen"].tolist(),
|
| 184 |
+
"indices_to_explain_gen": payload["indices_to_explain_gen"].tolist(),
|
| 185 |
+
"attnlrp_neg_handling": str(payload["attnlrp_neg_handling"].item()),
|
| 186 |
+
"attnlrp_norm_mode": str(payload["attnlrp_norm_mode"].item()),
|
| 187 |
+
"attnlrp_ratio_enabled": int(payload["attnlrp_ratio_enabled"].item()),
|
| 188 |
+
}
|
| 189 |
+
manifest_handle.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 190 |
+
manifest_handle.flush()
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _parse_modes(mode_args: Any) -> List[str]:
|
| 194 |
+
"""Parse --mode which may be provided as multiple args and/or comma-separated."""
|
| 195 |
+
if mode_args is None:
|
| 196 |
+
raw_parts: List[str] = []
|
| 197 |
+
elif isinstance(mode_args, str):
|
| 198 |
+
raw_parts = [mode_args]
|
| 199 |
+
else:
|
| 200 |
+
raw_parts = [str(x) for x in mode_args]
|
| 201 |
+
|
| 202 |
+
modes: List[str] = []
|
| 203 |
+
for chunk in raw_parts:
|
| 204 |
+
for part in str(chunk).split(","):
|
| 205 |
+
m = part.strip()
|
| 206 |
+
if m:
|
| 207 |
+
modes.append(m)
|
| 208 |
+
|
| 209 |
+
# Default to faithfulness_gen for backward compatibility.
|
| 210 |
+
if not modes:
|
| 211 |
+
modes = ["faithfulness_gen"]
|
| 212 |
+
|
| 213 |
+
allowed = {"faithfulness_gen", "recovery_ruler"}
|
| 214 |
+
seen: set[str] = set()
|
| 215 |
+
unique: List[str] = []
|
| 216 |
+
for m in modes:
|
| 217 |
+
if m not in seen:
|
| 218 |
+
unique.append(m)
|
| 219 |
+
seen.add(m)
|
| 220 |
+
|
| 221 |
+
unknown = [m for m in unique if m not in allowed]
|
| 222 |
+
if unknown:
|
| 223 |
+
raise SystemExit(f"Unsupported --mode value(s): {unknown}. Allowed: {sorted(allowed)}.")
|
| 224 |
+
|
| 225 |
+
return unique
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _trace_run_tag(
|
| 229 |
+
testing_dict: Dict[str, Any],
|
| 230 |
+
*,
|
| 231 |
+
modes: List[str],
|
| 232 |
+
total: int,
|
| 233 |
+
) -> str:
|
| 234 |
+
attr_func = str(testing_dict.get("attr_func") or "attr")
|
| 235 |
+
parts = [attr_func]
|
| 236 |
+
|
| 237 |
+
if attr_func in (
|
| 238 |
+
"ifr_multi_hop",
|
| 239 |
+
"ifr_in_all_gen",
|
| 240 |
+
"ifr_multi_hop_stop_words",
|
| 241 |
+
"ifr_multi_hop_both",
|
| 242 |
+
"ifr_multi_hop_split_hop",
|
| 243 |
+
"ft_attnlrp",
|
| 244 |
+
"attnlrp_aggregated_multi_hop",
|
| 245 |
+
):
|
| 246 |
+
parts.append(f"n{int(testing_dict.get('n_hops', 0))}")
|
| 247 |
+
|
| 248 |
+
if attr_func in ("attnlrp", "ft_attnlrp", "attnlrp_aggregated_multi_hop"):
|
| 249 |
+
parts.append(f"neg{str(testing_dict.get('attnlrp_neg_handling', ''))}")
|
| 250 |
+
parts.append(f"norm{str(testing_dict.get('attnlrp_norm_mode', ''))}")
|
| 251 |
+
|
| 252 |
+
if modes:
|
| 253 |
+
parts.append("m" + "+".join(modes))
|
| 254 |
+
|
| 255 |
+
parts.append(f"{int(total)}ex")
|
| 256 |
+
return "_".join(parts)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _token_importance_vector(attr: torch.Tensor) -> np.ndarray:
|
| 260 |
+
"""Return token importance vector w = sum_rows(attr) in shape [P+G]."""
|
| 261 |
+
w = torch.nan_to_num(attr.sum(0).to(dtype=torch.float32), nan=0.0).clamp(min=0.0)
|
| 262 |
+
return w.detach().cpu().numpy().astype(np.float32, copy=False)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _build_sample_trace_payload(
|
| 266 |
+
example: ds_utils.CachedExample,
|
| 267 |
+
*,
|
| 268 |
+
attr_list: List[torch.Tensor],
|
| 269 |
+
prompt_len: int,
|
| 270 |
+
user_prompt_indices: Optional[List[int]],
|
| 271 |
+
keep_prompt_token_indices: Optional[List[int]],
|
| 272 |
+
gold_prompt_token_indices: Optional[List[int]],
|
| 273 |
+
hop_payload: Optional[Dict[str, np.ndarray]],
|
| 274 |
+
faithfulness_scores: Optional[np.ndarray],
|
| 275 |
+
recovery_scores: Optional[np.ndarray],
|
| 276 |
+
time_attr_s: Optional[float],
|
| 277 |
+
time_faith_s: Optional[float],
|
| 278 |
+
time_recovery_s: Optional[float],
|
| 279 |
+
) -> Dict[str, np.ndarray]:
|
| 280 |
+
seq_attr, row_attr, rec_attr = attr_list
|
| 281 |
+
gen_len = int(seq_attr.shape[0])
|
| 282 |
+
|
| 283 |
+
v_seq_all = _token_importance_vector(seq_attr)
|
| 284 |
+
v_row_all = _token_importance_vector(row_attr)
|
| 285 |
+
v_rec_all = _token_importance_vector(rec_attr)
|
| 286 |
+
|
| 287 |
+
payload: Dict[str, np.ndarray] = {
|
| 288 |
+
"v_seq_all": v_seq_all,
|
| 289 |
+
"v_row_all": v_row_all,
|
| 290 |
+
"v_rec_all": v_rec_all,
|
| 291 |
+
"v_seq_prompt": v_seq_all[:prompt_len],
|
| 292 |
+
"v_row_prompt": v_row_all[:prompt_len],
|
| 293 |
+
"v_rec_prompt": v_rec_all[:prompt_len],
|
| 294 |
+
"prompt_len": np.asarray(int(prompt_len), dtype=np.int64),
|
| 295 |
+
"gen_len": np.asarray(int(gen_len), dtype=np.int64),
|
| 296 |
+
"indices_to_explain_gen": np.asarray(list(example.indices_to_explain or []), dtype=np.int64),
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
if example.sink_span is not None:
|
| 300 |
+
payload["sink_span_gen"] = np.asarray(list(example.sink_span), dtype=np.int64)
|
| 301 |
+
if example.thinking_span is not None:
|
| 302 |
+
payload["thinking_span_gen"] = np.asarray(list(example.thinking_span), dtype=np.int64)
|
| 303 |
+
|
| 304 |
+
if user_prompt_indices is not None:
|
| 305 |
+
payload["user_prompt_indices"] = np.asarray(list(user_prompt_indices), dtype=np.int64)
|
| 306 |
+
if keep_prompt_token_indices is not None:
|
| 307 |
+
payload["keep_prompt_token_indices"] = np.asarray(list(keep_prompt_token_indices), dtype=np.int64)
|
| 308 |
+
if gold_prompt_token_indices is not None:
|
| 309 |
+
payload["gold_prompt_token_indices"] = np.asarray(list(gold_prompt_token_indices), dtype=np.int64)
|
| 310 |
+
|
| 311 |
+
if faithfulness_scores is not None:
|
| 312 |
+
payload["faithfulness_scores"] = np.asarray(faithfulness_scores, dtype=np.float64)
|
| 313 |
+
if recovery_scores is not None:
|
| 314 |
+
payload["recovery_scores"] = np.asarray(recovery_scores, dtype=np.float64)
|
| 315 |
+
|
| 316 |
+
if time_attr_s is not None:
|
| 317 |
+
payload["time_attr_s"] = np.asarray(float(time_attr_s), dtype=np.float64)
|
| 318 |
+
if time_faith_s is not None:
|
| 319 |
+
payload["time_faith_s"] = np.asarray(float(time_faith_s), dtype=np.float64)
|
| 320 |
+
if time_recovery_s is not None:
|
| 321 |
+
payload["time_recovery_s"] = np.asarray(float(time_recovery_s), dtype=np.float64)
|
| 322 |
+
|
| 323 |
+
if hop_payload is not None:
|
| 324 |
+
for k, v in hop_payload.items():
|
| 325 |
+
if k in payload:
|
| 326 |
+
continue
|
| 327 |
+
payload[k] = v
|
| 328 |
+
|
| 329 |
+
return payload
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def _write_sample_trace(
|
| 333 |
+
trace_dir: Path,
|
| 334 |
+
*,
|
| 335 |
+
example_idx: int,
|
| 336 |
+
attr_func: str,
|
| 337 |
+
prompt: str,
|
| 338 |
+
target: Optional[str],
|
| 339 |
+
payload: Dict[str, np.ndarray],
|
| 340 |
+
manifest_handle,
|
| 341 |
+
recovery_skipped_reason: Optional[str],
|
| 342 |
+
) -> None:
|
| 343 |
+
trace_dir.mkdir(parents=True, exist_ok=True)
|
| 344 |
+
npz_name = f"ex_{example_idx:06d}.npz"
|
| 345 |
+
npz_path = trace_dir / npz_name
|
| 346 |
+
np.savez_compressed(npz_path, **payload)
|
| 347 |
+
|
| 348 |
+
prompt_len = int(np.asarray(payload.get("prompt_len", 0)).item())
|
| 349 |
+
gen_len = int(np.asarray(payload.get("gen_len", 0)).item())
|
| 350 |
+
record: Dict[str, Any] = {
|
| 351 |
+
"example_idx": int(example_idx),
|
| 352 |
+
"attr_func": attr_func,
|
| 353 |
+
"file": npz_name,
|
| 354 |
+
"prompt_sha1": _sha1_text(prompt),
|
| 355 |
+
"target_sha1": _sha1_text(target) if target is not None else None,
|
| 356 |
+
"prompt_len": prompt_len,
|
| 357 |
+
"gen_len": gen_len,
|
| 358 |
+
"indices_to_explain_gen": payload.get("indices_to_explain_gen").tolist()
|
| 359 |
+
if payload.get("indices_to_explain_gen") is not None
|
| 360 |
+
else None,
|
| 361 |
+
"sink_span_gen": payload.get("sink_span_gen").tolist() if payload.get("sink_span_gen") is not None else None,
|
| 362 |
+
"thinking_span_gen": payload.get("thinking_span_gen").tolist()
|
| 363 |
+
if payload.get("thinking_span_gen") is not None
|
| 364 |
+
else None,
|
| 365 |
+
"faithfulness_scores": payload.get("faithfulness_scores").tolist()
|
| 366 |
+
if payload.get("faithfulness_scores") is not None
|
| 367 |
+
else None,
|
| 368 |
+
"recovery_scores": payload.get("recovery_scores").tolist() if payload.get("recovery_scores") is not None else None,
|
| 369 |
+
"recovery_skipped_reason": recovery_skipped_reason,
|
| 370 |
+
"time_attr_s": float(np.asarray(payload.get("time_attr_s")).item()) if payload.get("time_attr_s") is not None else None,
|
| 371 |
+
"time_faith_s": float(np.asarray(payload.get("time_faith_s")).item()) if payload.get("time_faith_s") is not None else None,
|
| 372 |
+
"time_recovery_s": float(np.asarray(payload.get("time_recovery_s")).item())
|
| 373 |
+
if payload.get("time_recovery_s") is not None
|
| 374 |
+
else None,
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
# Derived, sample-level bookkeeping (token lengths and per-sample MAS/RISE).
|
| 378 |
+
record["input_len"] = int(prompt_len)
|
| 379 |
+
|
| 380 |
+
sink_span = record.get("sink_span_gen")
|
| 381 |
+
if isinstance(sink_span, list) and len(sink_span) == 2:
|
| 382 |
+
try:
|
| 383 |
+
start = int(sink_span[0])
|
| 384 |
+
end = int(sink_span[1])
|
| 385 |
+
record["output_len"] = (end - start + 1) if end >= start else None
|
| 386 |
+
except Exception:
|
| 387 |
+
record["output_len"] = None
|
| 388 |
+
else:
|
| 389 |
+
record["output_len"] = None
|
| 390 |
+
|
| 391 |
+
thinking_span = record.get("thinking_span_gen")
|
| 392 |
+
if isinstance(thinking_span, list) and len(thinking_span) == 2:
|
| 393 |
+
try:
|
| 394 |
+
start = int(thinking_span[0])
|
| 395 |
+
end = int(thinking_span[1])
|
| 396 |
+
record["cot_len"] = (end - start + 1) if end >= start else None
|
| 397 |
+
except Exception:
|
| 398 |
+
record["cot_len"] = None
|
| 399 |
+
else:
|
| 400 |
+
record["cot_len"] = None
|
| 401 |
+
|
| 402 |
+
record["rise_seq"] = None
|
| 403 |
+
record["mas_seq"] = None
|
| 404 |
+
record["rise_row"] = None
|
| 405 |
+
record["mas_row"] = None
|
| 406 |
+
record["rise_rec"] = None
|
| 407 |
+
record["mas_rec"] = None
|
| 408 |
+
faith = record.get("faithfulness_scores")
|
| 409 |
+
if isinstance(faith, list) and len(faith) == 3:
|
| 410 |
+
try:
|
| 411 |
+
record["rise_seq"] = float(faith[0][0])
|
| 412 |
+
record["mas_seq"] = float(faith[0][1])
|
| 413 |
+
record["rise_row"] = float(faith[1][0])
|
| 414 |
+
record["mas_row"] = float(faith[1][1])
|
| 415 |
+
record["rise_rec"] = float(faith[2][0])
|
| 416 |
+
record["mas_rec"] = float(faith[2][1])
|
| 417 |
+
except Exception:
|
| 418 |
+
pass
|
| 419 |
+
|
| 420 |
+
if payload.get("vh") is not None:
|
| 421 |
+
vh = payload["vh"]
|
| 422 |
+
record["n_hops_plus_one"] = int(vh.shape[0])
|
| 423 |
+
record["total_len"] = int(vh.shape[1])
|
| 424 |
+
record["attnlrp_neg_handling"] = str(payload.get("attnlrp_neg_handling").item()) if payload.get("attnlrp_neg_handling") is not None else ""
|
| 425 |
+
record["attnlrp_norm_mode"] = str(payload.get("attnlrp_norm_mode").item()) if payload.get("attnlrp_norm_mode") is not None else ""
|
| 426 |
+
record["attnlrp_ratio_enabled"] = int(payload.get("attnlrp_ratio_enabled").item()) if payload.get("attnlrp_ratio_enabled") is not None else -1
|
| 427 |
+
|
| 428 |
+
manifest_handle.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 429 |
+
manifest_handle.flush()
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _compute_faithfulness_scores(
|
| 433 |
+
testing_dict: Dict[str, Any],
|
| 434 |
+
*,
|
| 435 |
+
attr_list: List[torch.Tensor],
|
| 436 |
+
prompt_len: int,
|
| 437 |
+
prompt: str,
|
| 438 |
+
generation: str,
|
| 439 |
+
llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
|
| 440 |
+
user_prompt_indices: Optional[List[int]],
|
| 441 |
+
keep_prompt_token_indices: Optional[List[int]],
|
| 442 |
+
) -> np.ndarray:
|
| 443 |
+
attr_func = str(testing_dict.get("attr_func") or "")
|
| 444 |
+
results: List[Tuple[float, float, float]] = []
|
| 445 |
+
for attr in attr_list:
|
| 446 |
+
attr_prompt = attr[:, :prompt_len]
|
| 447 |
+
if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and keep_prompt_token_indices is not None:
|
| 448 |
+
import ft_ifr_improve
|
| 449 |
+
|
| 450 |
+
scores = ft_ifr_improve.faithfulness_test_skip_tokens(
|
| 451 |
+
llm_evaluator,
|
| 452 |
+
attr_prompt,
|
| 453 |
+
prompt,
|
| 454 |
+
generation,
|
| 455 |
+
keep_prompt_token_indices=keep_prompt_token_indices,
|
| 456 |
+
user_prompt_indices=user_prompt_indices,
|
| 457 |
+
)
|
| 458 |
+
elif user_prompt_indices is not None:
|
| 459 |
+
scores = _faithfulness_test_with_user_prompt_indices(
|
| 460 |
+
llm_evaluator,
|
| 461 |
+
attr_prompt,
|
| 462 |
+
prompt,
|
| 463 |
+
generation,
|
| 464 |
+
user_prompt_indices=user_prompt_indices,
|
| 465 |
+
)
|
| 466 |
+
else:
|
| 467 |
+
scores = llm_evaluator.faithfulness_test(attr_prompt, prompt, generation)
|
| 468 |
+
results.append(scores)
|
| 469 |
+
return np.asarray(results, dtype=np.float64)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def _compute_recovery_scores(
|
| 473 |
+
testing_dict: Dict[str, Any],
|
| 474 |
+
*,
|
| 475 |
+
attr_list: List[torch.Tensor],
|
| 476 |
+
prompt_len: int,
|
| 477 |
+
gold_prompt_token_indices: List[int],
|
| 478 |
+
llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
|
| 479 |
+
keep_prompt_token_indices: Optional[List[int]],
|
| 480 |
+
) -> Tuple[Optional[np.ndarray], Optional[str]]:
|
| 481 |
+
attr_func = str(testing_dict.get("attr_func") or "")
|
| 482 |
+
|
| 483 |
+
if prompt_len <= 0:
|
| 484 |
+
return None, "empty_prompt_len"
|
| 485 |
+
|
| 486 |
+
gold_prompt = [int(x) for x in (gold_prompt_token_indices or [])]
|
| 487 |
+
if not gold_prompt:
|
| 488 |
+
return None, "empty_gold_prompt"
|
| 489 |
+
|
| 490 |
+
if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and keep_prompt_token_indices is not None:
|
| 491 |
+
import ft_ifr_improve
|
| 492 |
+
|
| 493 |
+
keep_set = {int(x) for x in keep_prompt_token_indices}
|
| 494 |
+
gold_filtered = [idx for idx in gold_prompt if int(idx) in keep_set]
|
| 495 |
+
if not gold_filtered:
|
| 496 |
+
return None, "empty_gold_after_keep_filter"
|
| 497 |
+
|
| 498 |
+
scores = [
|
| 499 |
+
ft_ifr_improve.evaluate_attr_recovery_skip_tokens(
|
| 500 |
+
attr[:, :prompt_len],
|
| 501 |
+
keep_prompt_token_indices=keep_prompt_token_indices,
|
| 502 |
+
gold_prompt_token_indices=gold_prompt,
|
| 503 |
+
top_fraction=0.1,
|
| 504 |
+
)
|
| 505 |
+
for attr in attr_list
|
| 506 |
+
]
|
| 507 |
+
else:
|
| 508 |
+
scores = [
|
| 509 |
+
llm_evaluator.evaluate_attr_recovery(
|
| 510 |
+
attr,
|
| 511 |
+
prompt_len=prompt_len,
|
| 512 |
+
gold_prompt_token_indices=gold_prompt,
|
| 513 |
+
top_fraction=0.1,
|
| 514 |
+
)
|
| 515 |
+
for attr in attr_list
|
| 516 |
+
]
|
| 517 |
+
|
| 518 |
+
return np.asarray(scores, dtype=np.float64), None
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def evaluate_dataset_multi(
|
| 522 |
+
args,
|
| 523 |
+
dataset_name: str,
|
| 524 |
+
examples: List[ds_utils.CachedExample],
|
| 525 |
+
testing_dict: Dict[str, Any],
|
| 526 |
+
*,
|
| 527 |
+
modes: List[str],
|
| 528 |
+
) -> Dict[str, Any]:
|
| 529 |
+
tokenizer = testing_dict["tokenizer"]
|
| 530 |
+
llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(testing_dict["model"], tokenizer)
|
| 531 |
+
|
| 532 |
+
want_faith = "faithfulness_gen" in modes
|
| 533 |
+
want_recovery = "recovery_ruler" in modes
|
| 534 |
+
|
| 535 |
+
faith_results: List[np.ndarray] = []
|
| 536 |
+
faith_durations: List[float] = []
|
| 537 |
+
|
| 538 |
+
recovery_results: List[np.ndarray] = []
|
| 539 |
+
recovery_attr_durations: List[float] = []
|
| 540 |
+
recovery_skipped = 0
|
| 541 |
+
|
| 542 |
+
total = min(len(examples), args.num_examples)
|
| 543 |
+
iterator = islice(examples, total)
|
| 544 |
+
|
| 545 |
+
save_traces = bool(getattr(args, "save_hop_traces", False))
|
| 546 |
+
manifest_handle = None
|
| 547 |
+
trace_dir: Optional[Path] = None
|
| 548 |
+
if save_traces:
|
| 549 |
+
model_tag = str(testing_dict.get("model_tag", "model"))
|
| 550 |
+
run_tag = _trace_run_tag(testing_dict, modes=modes, total=total)
|
| 551 |
+
trace_dir = Path(args.output_root) / "traces" / dataset_name / model_tag / run_tag
|
| 552 |
+
trace_dir.mkdir(parents=True, exist_ok=True)
|
| 553 |
+
manifest_handle = open(trace_dir / "manifest.jsonl", "w", encoding="utf-8")
|
| 554 |
+
|
| 555 |
+
try:
|
| 556 |
+
for example_idx, ex in enumerate(iterator):
|
| 557 |
+
if want_recovery:
|
| 558 |
+
needle_spans = (ex.metadata or {}).get("needle_spans")
|
| 559 |
+
if not isinstance(needle_spans, list) or not needle_spans:
|
| 560 |
+
raise SystemExit(
|
| 561 |
+
"recovery_ruler requires RULER samples with metadata.needle_spans; "
|
| 562 |
+
f"dataset={dataset_name} has missing/empty needle_spans."
|
| 563 |
+
)
|
| 564 |
+
if ex.target is None:
|
| 565 |
+
raise SystemExit(
|
| 566 |
+
"recovery_ruler requires cached targets (CoT+answer) so row/rec attribution is well-defined. "
|
| 567 |
+
f"dataset={dataset_name} has target=None; run exp/exp2/sample_and_filter.py first."
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
# Determine generation/target once.
|
| 571 |
+
target = ex.target
|
| 572 |
+
if target is None:
|
| 573 |
+
generation, full_output = llm_evaluator.response(ex.prompt)
|
| 574 |
+
target = generation
|
| 575 |
+
response_len = len(tokenizer(full_output).input_ids)
|
| 576 |
+
else:
|
| 577 |
+
response_len = len(tokenizer(llm_evaluator.format_prompt(" " + ex.prompt) + target).input_ids)
|
| 578 |
+
|
| 579 |
+
testing_dict["batch_size"] = max(1, math.floor((testing_dict["max_input_len"] - 100) / max(1, response_len)))
|
| 580 |
+
|
| 581 |
+
gold_prompt: Optional[List[int]] = None
|
| 582 |
+
if want_recovery:
|
| 583 |
+
gold_prompt = ds_utils.ruler_gold_prompt_token_indices(ex, tokenizer)
|
| 584 |
+
|
| 585 |
+
if want_recovery and not want_faith and not save_traces:
|
| 586 |
+
# Preserve recovery-only fast path when not saving traces: skip samples with empty gold.
|
| 587 |
+
if not gold_prompt:
|
| 588 |
+
recovery_skipped += 1
|
| 589 |
+
continue
|
| 590 |
+
|
| 591 |
+
time_attr_s = None
|
| 592 |
+
time_faith_s = None
|
| 593 |
+
time_recovery_s = None
|
| 594 |
+
|
| 595 |
+
t0 = time.perf_counter()
|
| 596 |
+
attr_list, hop_payload, user_prompt_indices, keep_prompt_token_indices = run_attribution(testing_dict, ex, target)
|
| 597 |
+
time_attr_s = time.perf_counter() - t0
|
| 598 |
+
|
| 599 |
+
seq_attr = attr_list[0]
|
| 600 |
+
prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0]) # cols=(P+G), rows=G
|
| 601 |
+
|
| 602 |
+
if want_recovery and gold_prompt:
|
| 603 |
+
recovery_attr_durations.append(float(time_attr_s))
|
| 604 |
+
|
| 605 |
+
faith_scores = None
|
| 606 |
+
if want_faith:
|
| 607 |
+
t1 = time.perf_counter()
|
| 608 |
+
faith_scores = _compute_faithfulness_scores(
|
| 609 |
+
testing_dict,
|
| 610 |
+
attr_list=attr_list,
|
| 611 |
+
prompt_len=prompt_len,
|
| 612 |
+
prompt=ex.prompt,
|
| 613 |
+
generation=target,
|
| 614 |
+
llm_evaluator=llm_evaluator,
|
| 615 |
+
user_prompt_indices=user_prompt_indices,
|
| 616 |
+
keep_prompt_token_indices=keep_prompt_token_indices,
|
| 617 |
+
)
|
| 618 |
+
time_faith_s = time.perf_counter() - t1
|
| 619 |
+
faith_results.append(faith_scores)
|
| 620 |
+
faith_durations.append(float(time_attr_s))
|
| 621 |
+
|
| 622 |
+
recovery_scores = None
|
| 623 |
+
recovery_skip_reason = None
|
| 624 |
+
if want_recovery:
|
| 625 |
+
if not gold_prompt:
|
| 626 |
+
recovery_skip_reason = "empty_gold_prompt"
|
| 627 |
+
recovery_skipped += 1
|
| 628 |
+
else:
|
| 629 |
+
t2 = time.perf_counter()
|
| 630 |
+
recovery_scores, recovery_skip_reason = _compute_recovery_scores(
|
| 631 |
+
testing_dict,
|
| 632 |
+
attr_list=attr_list,
|
| 633 |
+
prompt_len=prompt_len,
|
| 634 |
+
gold_prompt_token_indices=gold_prompt,
|
| 635 |
+
llm_evaluator=llm_evaluator,
|
| 636 |
+
keep_prompt_token_indices=keep_prompt_token_indices,
|
| 637 |
+
)
|
| 638 |
+
time_recovery_s = time.perf_counter() - t2
|
| 639 |
+
if recovery_scores is None:
|
| 640 |
+
recovery_skipped += 1
|
| 641 |
+
else:
|
| 642 |
+
recovery_results.append(recovery_scores)
|
| 643 |
+
|
| 644 |
+
if manifest_handle is not None and trace_dir is not None:
|
| 645 |
+
try:
|
| 646 |
+
payload = _build_sample_trace_payload(
|
| 647 |
+
ex,
|
| 648 |
+
attr_list=attr_list,
|
| 649 |
+
prompt_len=prompt_len,
|
| 650 |
+
user_prompt_indices=user_prompt_indices,
|
| 651 |
+
keep_prompt_token_indices=keep_prompt_token_indices,
|
| 652 |
+
gold_prompt_token_indices=gold_prompt,
|
| 653 |
+
hop_payload=hop_payload,
|
| 654 |
+
faithfulness_scores=faith_scores,
|
| 655 |
+
recovery_scores=recovery_scores,
|
| 656 |
+
time_attr_s=time_attr_s,
|
| 657 |
+
time_faith_s=time_faith_s,
|
| 658 |
+
time_recovery_s=time_recovery_s,
|
| 659 |
+
)
|
| 660 |
+
_write_sample_trace(
|
| 661 |
+
trace_dir,
|
| 662 |
+
example_idx=example_idx,
|
| 663 |
+
attr_func=str(testing_dict.get("attr_func") or ""),
|
| 664 |
+
prompt=ex.prompt,
|
| 665 |
+
target=target,
|
| 666 |
+
payload=payload,
|
| 667 |
+
manifest_handle=manifest_handle,
|
| 668 |
+
recovery_skipped_reason=recovery_skip_reason,
|
| 669 |
+
)
|
| 670 |
+
except Exception as exc:
|
| 671 |
+
print(f"[warn] sample trace save failed for {testing_dict.get('attr_func')} ex={example_idx}: {exc}")
|
| 672 |
+
finally:
|
| 673 |
+
if manifest_handle is not None:
|
| 674 |
+
try:
|
| 675 |
+
manifest_handle.close()
|
| 676 |
+
except Exception:
|
| 677 |
+
pass
|
| 678 |
+
|
| 679 |
+
out: Dict[str, Any] = {}
|
| 680 |
+
if want_faith:
|
| 681 |
+
if not faith_results:
|
| 682 |
+
out["faithfulness"] = None
|
| 683 |
+
else:
|
| 684 |
+
scores = np.stack(faith_results, axis=0) # [N, 3, 3]
|
| 685 |
+
out["faithfulness"] = {
|
| 686 |
+
"mean": scores.mean(0),
|
| 687 |
+
"std": scores.std(0),
|
| 688 |
+
"avg_time": float(np.mean(faith_durations)) if faith_durations else 0.0,
|
| 689 |
+
}
|
| 690 |
+
if want_recovery:
|
| 691 |
+
if not recovery_results:
|
| 692 |
+
out["recovery"] = None
|
| 693 |
+
else:
|
| 694 |
+
scores = np.stack(recovery_results, axis=0) # [N, 3]
|
| 695 |
+
out["recovery"] = {
|
| 696 |
+
"mean": scores.mean(0),
|
| 697 |
+
"std": scores.std(0),
|
| 698 |
+
"avg_time": float(np.mean(recovery_attr_durations)) if recovery_attr_durations else 0.0,
|
| 699 |
+
"used": int(scores.shape[0]),
|
| 700 |
+
"skipped": int(recovery_skipped),
|
| 701 |
+
}
|
| 702 |
+
|
| 703 |
+
return out
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def _faithfulness_test_with_user_prompt_indices(
|
| 707 |
+
llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
|
| 708 |
+
attribution: torch.Tensor,
|
| 709 |
+
prompt: str,
|
| 710 |
+
generation: str,
|
| 711 |
+
*,
|
| 712 |
+
user_prompt_indices: List[int],
|
| 713 |
+
k: int = 20, ### control the MAS steps per sample
|
| 714 |
+
) -> Tuple[float, float, float]:
|
| 715 |
+
"""Token-level MAS/RISE faithfulness via guided deletion in k perturbation steps using provided prompt indices.
|
| 716 |
+
|
| 717 |
+
This mirrors llm_attr_eval.LLMAttributionEvaluator.faithfulness_test, but avoids
|
| 718 |
+
locating the user prompt span via token-id subsequence matching (which may fail
|
| 719 |
+
for some tokenizers due to non-compositional BPE merges at template boundaries).
|
| 720 |
+
"""
|
| 721 |
+
|
| 722 |
+
def auc(arr: np.ndarray) -> float:
|
| 723 |
+
return (arr.sum() - arr[0] / 2 - arr[-1] / 2) / max(1, (arr.shape[0] - 1))
|
| 724 |
+
|
| 725 |
+
pad_token_id = llm_evaluator._ensure_pad_token_id()
|
| 726 |
+
|
| 727 |
+
user_prompt = " " + prompt
|
| 728 |
+
formatted_prompt = llm_evaluator.format_prompt(user_prompt)
|
| 729 |
+
formatted_ids = llm_evaluator.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).input_ids
|
| 730 |
+
|
| 731 |
+
prompt_ids = formatted_ids.to(llm_evaluator.device)
|
| 732 |
+
prompt_ids_perturbed = prompt_ids.clone()
|
| 733 |
+
generation_ids = llm_evaluator.tokenizer(
|
| 734 |
+
generation + llm_evaluator.tokenizer.eos_token,
|
| 735 |
+
return_tensors="pt",
|
| 736 |
+
add_special_tokens=False,
|
| 737 |
+
).input_ids.to(llm_evaluator.device)
|
| 738 |
+
|
| 739 |
+
attr_cpu = attribution.detach().cpu()
|
| 740 |
+
w = attr_cpu.sum(0)
|
| 741 |
+
sorted_attr_indices = torch.argsort(w, descending=True)
|
| 742 |
+
attr_sum = float(w.sum().item())
|
| 743 |
+
|
| 744 |
+
P = int(w.numel())
|
| 745 |
+
if len(user_prompt_indices) != P:
|
| 746 |
+
raise ValueError(
|
| 747 |
+
"user_prompt_indices length does not match prompt-side attribution length: "
|
| 748 |
+
f"indices P={len(user_prompt_indices)}, attr P={P}."
|
| 749 |
+
)
|
| 750 |
+
if P == 0:
|
| 751 |
+
return 0.0, 0.0, 0.0
|
| 752 |
+
|
| 753 |
+
if max(user_prompt_indices) >= int(prompt_ids_perturbed.shape[1]):
|
| 754 |
+
raise ValueError("user_prompt_indices contains an out-of-bounds index for formatted prompt ids.")
|
| 755 |
+
|
| 756 |
+
if P > 0:
|
| 757 |
+
steps = int(k) if k is not None else 0
|
| 758 |
+
if steps <= 0:
|
| 759 |
+
steps = 1
|
| 760 |
+
steps = min(steps, P)
|
| 761 |
+
else:
|
| 762 |
+
steps = 0
|
| 763 |
+
|
| 764 |
+
scores = np.zeros(steps + 1, dtype=np.float64)
|
| 765 |
+
density = np.zeros(steps + 1, dtype=np.float64)
|
| 766 |
+
|
| 767 |
+
scores[0] = (
|
| 768 |
+
llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
|
| 769 |
+
)
|
| 770 |
+
density[0] = 1.0
|
| 771 |
+
|
| 772 |
+
if attr_sum <= 0:
|
| 773 |
+
density = np.linspace(1.0, 0.0, steps + 1)
|
| 774 |
+
|
| 775 |
+
base = P // steps
|
| 776 |
+
remainder = P % steps
|
| 777 |
+
start = 0
|
| 778 |
+
for step in range(steps):
|
| 779 |
+
size = base + (1 if step < remainder else 0)
|
| 780 |
+
group = sorted_attr_indices[start : start + size]
|
| 781 |
+
start += size
|
| 782 |
+
|
| 783 |
+
for idx in group:
|
| 784 |
+
j = int(idx.item())
|
| 785 |
+
abs_pos = int(user_prompt_indices[j])
|
| 786 |
+
prompt_ids_perturbed[0, abs_pos] = pad_token_id
|
| 787 |
+
scores[step + 1] = (
|
| 788 |
+
llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
|
| 789 |
+
)
|
| 790 |
+
if attr_sum > 0:
|
| 791 |
+
dec = float(w.index_select(0, group).sum().item()) / attr_sum
|
| 792 |
+
density[step + 1] = density[step] - dec
|
| 793 |
+
|
| 794 |
+
min_normalized_pred = 1.0
|
| 795 |
+
normalized_model_response = scores.copy()
|
| 796 |
+
for i in range(len(scores)):
|
| 797 |
+
normalized_pred = (normalized_model_response[i] - scores[-1]) / (abs(scores[0] - scores[-1]))
|
| 798 |
+
normalized_pred = np.clip(normalized_pred, 0.0, 1.0)
|
| 799 |
+
min_normalized_pred = min(min_normalized_pred, normalized_pred)
|
| 800 |
+
normalized_model_response[i] = min_normalized_pred
|
| 801 |
+
|
| 802 |
+
alignment_penalty = np.abs(normalized_model_response - density)
|
| 803 |
+
corrected_scores = normalized_model_response + alignment_penalty
|
| 804 |
+
corrected_scores = corrected_scores.clip(0.0, 1.0)
|
| 805 |
+
corrected_scores = (corrected_scores - np.min(corrected_scores)) / (np.max(corrected_scores) - np.min(corrected_scores))
|
| 806 |
+
|
| 807 |
+
if np.isnan(corrected_scores).any():
|
| 808 |
+
corrected_scores = np.linspace(1.0, 0.0, len(scores))
|
| 809 |
+
|
| 810 |
+
return auc(normalized_model_response), auc(corrected_scores), auc(normalized_model_response + alignment_penalty)
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
def load_model(model_name: str, device: str):
|
| 814 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 815 |
+
model_name,
|
| 816 |
+
device_map="auto" if device == "auto" else {"": int(device.split(":")[1])} if device.startswith("cuda:") else None,
|
| 817 |
+
torch_dtype=torch.float16,
|
| 818 |
+
attn_implementation="eager",
|
| 819 |
+
)
|
| 820 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 821 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 822 |
+
model.eval()
|
| 823 |
+
return model, tokenizer
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
def resolve_device(args) -> str:
|
| 827 |
+
if args.cuda is not None and "," in args.cuda:
|
| 828 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
| 829 |
+
return "auto"
|
| 830 |
+
if args.cuda is not None and args.cuda.strip():
|
| 831 |
+
return f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
|
| 832 |
+
return f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
def run_attribution(
|
| 836 |
+
testing_dict, example: ds_utils.CachedExample, target: Optional[str]
|
| 837 |
+
) -> Tuple[List[torch.Tensor], Optional[Dict[str, np.ndarray]], Optional[List[int]]]:
|
| 838 |
+
model = testing_dict["model"]
|
| 839 |
+
tokenizer = testing_dict["tokenizer"]
|
| 840 |
+
attr_func = testing_dict["attr_func"]
|
| 841 |
+
|
| 842 |
+
indices_to_explain = example.indices_to_explain
|
| 843 |
+
if not (isinstance(indices_to_explain, list) and len(indices_to_explain) == 2):
|
| 844 |
+
raise ValueError(
|
| 845 |
+
"exp2 requires token-span indices_to_explain=[start_tok,end_tok]. "
|
| 846 |
+
"Please re-sample or run exp/exp2/migrate_indices_to_explain_token_span.py on your cache."
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
llm_attributor = None
|
| 850 |
+
if "IG" in attr_func:
|
| 851 |
+
llm_attributor = llm_attr.LLMGradientAttribtion(model, tokenizer)
|
| 852 |
+
attr = llm_attributor.calculate_IG_per_generation(
|
| 853 |
+
example.prompt,
|
| 854 |
+
20,
|
| 855 |
+
tokenizer.eos_token_id,
|
| 856 |
+
batch_size=testing_dict["batch_size"],
|
| 857 |
+
target=target,
|
| 858 |
+
)
|
| 859 |
+
elif "perturbation" in attr_func:
|
| 860 |
+
if attr_func in ("perturbation_all_fast", "perturbation_CLP_fast", "perturbation_REAGENT_fast"):
|
| 861 |
+
import perturbation_fast
|
| 862 |
+
|
| 863 |
+
llm_attributor = perturbation_fast.LLMPerturbationFastAttribution(model, tokenizer)
|
| 864 |
+
if attr_func == "perturbation_all_fast":
|
| 865 |
+
attr = llm_attributor.calculate_feature_ablation_segments(
|
| 866 |
+
example.prompt,
|
| 867 |
+
baseline=tokenizer.eos_token_id,
|
| 868 |
+
measure="log_loss",
|
| 869 |
+
target=target,
|
| 870 |
+
source_k=20,
|
| 871 |
+
)
|
| 872 |
+
elif attr_func == "perturbation_CLP_fast":
|
| 873 |
+
attr = llm_attributor.calculate_feature_ablation_segments(
|
| 874 |
+
example.prompt,
|
| 875 |
+
baseline=tokenizer.eos_token_id,
|
| 876 |
+
measure="KL",
|
| 877 |
+
target=target,
|
| 878 |
+
source_k=20,
|
| 879 |
+
)
|
| 880 |
+
else:
|
| 881 |
+
attr = llm_attributor.calculate_feature_ablation_segments_mlm(
|
| 882 |
+
example.prompt,
|
| 883 |
+
target=target,
|
| 884 |
+
source_k=20,
|
| 885 |
+
)
|
| 886 |
+
else:
|
| 887 |
+
llm_attributor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
|
| 888 |
+
if attr_func == "perturbation_all":
|
| 889 |
+
attr = llm_attributor.calculate_feature_ablation_sentences(
|
| 890 |
+
example.prompt, baseline=tokenizer.eos_token_id, measure="log_loss", target=target
|
| 891 |
+
)
|
| 892 |
+
elif attr_func == "perturbation_CLP":
|
| 893 |
+
attr = llm_attributor.calculate_feature_ablation_sentences(
|
| 894 |
+
example.prompt, baseline=tokenizer.eos_token_id, measure="KL", target=target
|
| 895 |
+
)
|
| 896 |
+
elif attr_func == "perturbation_REAGENT":
|
| 897 |
+
attr = llm_attributor.calculate_feature_ablation_sentences_mlm(example.prompt, target=target)
|
| 898 |
+
else:
|
| 899 |
+
raise ValueError(f"Unsupported perturbation attr_func {attr_func}")
|
| 900 |
+
elif "attention" in attr_func:
|
| 901 |
+
llm_attributor = llm_attr.LLMAttentionAttribution(model, tokenizer)
|
| 902 |
+
llm_attributor_ig = llm_attr.LLMGradientAttribtion(model, tokenizer)
|
| 903 |
+
attr = llm_attributor.calculate_attention_attribution(example.prompt, target=target)
|
| 904 |
+
attr_b = llm_attributor_ig.calculate_IG_per_generation(
|
| 905 |
+
example.prompt, 20, tokenizer.eos_token_id, batch_size=testing_dict["batch_size"], target=target
|
| 906 |
+
)
|
| 907 |
+
attr.attribution_matrix = attr.attribution_matrix * attr_b.attribution_matrix
|
| 908 |
+
elif attr_func == "ifr_all_positions":
|
| 909 |
+
llm_attributor = llm_attr.LLMIFRAttribution(
|
| 910 |
+
model,
|
| 911 |
+
tokenizer,
|
| 912 |
+
chunk_tokens=testing_dict["chunk_tokens"],
|
| 913 |
+
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
|
| 914 |
+
)
|
| 915 |
+
attr = llm_attributor.calculate_ifr_for_all_positions(example.prompt, target=target)
|
| 916 |
+
elif attr_func == "ifr_all_positions_output_only":
|
| 917 |
+
llm_attributor = llm_attr.LLMIFRAttribution(
|
| 918 |
+
model,
|
| 919 |
+
tokenizer,
|
| 920 |
+
chunk_tokens=testing_dict["chunk_tokens"],
|
| 921 |
+
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
|
| 922 |
+
)
|
| 923 |
+
sink_span = tuple(example.sink_span) if example.sink_span else tuple(indices_to_explain)
|
| 924 |
+
attr = llm_attributor.calculate_ifr_for_all_positions_output_only(
|
| 925 |
+
example.prompt,
|
| 926 |
+
target=target,
|
| 927 |
+
sink_span=sink_span,
|
| 928 |
+
)
|
| 929 |
+
elif attr_func == "ifr_multi_hop":
|
| 930 |
+
llm_attributor = llm_attr.LLMIFRAttribution(
|
| 931 |
+
model,
|
| 932 |
+
tokenizer,
|
| 933 |
+
chunk_tokens=testing_dict["chunk_tokens"],
|
| 934 |
+
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
|
| 935 |
+
)
|
| 936 |
+
attr = llm_attributor.calculate_ifr_multi_hop(
|
| 937 |
+
example.prompt,
|
| 938 |
+
target=target,
|
| 939 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 940 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 941 |
+
n_hops=testing_dict["n_hops"],
|
| 942 |
+
)
|
| 943 |
+
elif attr_func == "ifr_in_all_gen":
|
| 944 |
+
import ft_ifr_improve
|
| 945 |
+
|
| 946 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionInAllGen(
|
| 947 |
+
model,
|
| 948 |
+
tokenizer,
|
| 949 |
+
chunk_tokens=testing_dict["chunk_tokens"],
|
| 950 |
+
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
|
| 951 |
+
)
|
| 952 |
+
attr = llm_attributor.calculate_ifr_in_all_gen(
|
| 953 |
+
example.prompt,
|
| 954 |
+
target=target,
|
| 955 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 956 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 957 |
+
n_hops=testing_dict["n_hops"],
|
| 958 |
+
)
|
| 959 |
+
elif attr_func == "ifr_multi_hop_stop_words":
|
| 960 |
+
import ft_ifr_improve
|
| 961 |
+
|
| 962 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionImproved(
|
| 963 |
+
model,
|
| 964 |
+
tokenizer,
|
| 965 |
+
chunk_tokens=testing_dict["chunk_tokens"],
|
| 966 |
+
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
|
| 967 |
+
)
|
| 968 |
+
attr = llm_attributor.calculate_ifr_multi_hop_stop_words(
|
| 969 |
+
example.prompt,
|
| 970 |
+
target=target,
|
| 971 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 972 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 973 |
+
n_hops=testing_dict["n_hops"],
|
| 974 |
+
)
|
| 975 |
+
elif attr_func == "ifr_multi_hop_both":
|
| 976 |
+
import ft_ifr_improve
|
| 977 |
+
|
| 978 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionBoth(
|
| 979 |
+
model,
|
| 980 |
+
tokenizer,
|
| 981 |
+
chunk_tokens=testing_dict["chunk_tokens"],
|
| 982 |
+
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
|
| 983 |
+
)
|
| 984 |
+
attr = llm_attributor.calculate_ifr_multi_hop_both(
|
| 985 |
+
example.prompt,
|
| 986 |
+
target=target,
|
| 987 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 988 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 989 |
+
n_hops=testing_dict["n_hops"],
|
| 990 |
+
)
|
| 991 |
+
elif attr_func == "ifr_multi_hop_split_hop":
|
| 992 |
+
import ft_ifr_improve
|
| 993 |
+
|
| 994 |
+
llm_attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(
|
| 995 |
+
model,
|
| 996 |
+
tokenizer,
|
| 997 |
+
chunk_tokens=testing_dict["chunk_tokens"],
|
| 998 |
+
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
|
| 999 |
+
)
|
| 1000 |
+
attr = llm_attributor.calculate_ifr_multi_hop_split_hop(
|
| 1001 |
+
example.prompt,
|
| 1002 |
+
target=target,
|
| 1003 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 1004 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 1005 |
+
n_hops=testing_dict["n_hops"],
|
| 1006 |
+
)
|
| 1007 |
+
elif attr_func == "attnlrp":
|
| 1008 |
+
llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 1009 |
+
attr = llm_attributor.calculate_attnlrp_ft_hop0(
|
| 1010 |
+
example.prompt,
|
| 1011 |
+
target=target,
|
| 1012 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 1013 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 1014 |
+
neg_handling=str(testing_dict.get("attnlrp_neg_handling", "drop")),
|
| 1015 |
+
norm_mode=str(testing_dict.get("attnlrp_norm_mode", "norm")),
|
| 1016 |
+
)
|
| 1017 |
+
elif attr_func in ("ft_attnlrp", "attnlrp_aggregated_multi_hop"):
|
| 1018 |
+
llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 1019 |
+
attr = llm_attributor.calculate_attnlrp_aggregated_multi_hop(
|
| 1020 |
+
example.prompt,
|
| 1021 |
+
target=target,
|
| 1022 |
+
sink_span=tuple(example.sink_span) if example.sink_span else None,
|
| 1023 |
+
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
|
| 1024 |
+
n_hops=testing_dict["n_hops"],
|
| 1025 |
+
neg_handling=str(testing_dict.get("attnlrp_neg_handling", "drop")),
|
| 1026 |
+
norm_mode=str(testing_dict.get("attnlrp_norm_mode", "norm")),
|
| 1027 |
+
)
|
| 1028 |
+
elif attr_func == "basic":
|
| 1029 |
+
llm_attributor = llm_attr.LLMBasicAttribution(model, tokenizer)
|
| 1030 |
+
attr = llm_attributor.calculate_basic_attribution(example.prompt, target=target)
|
| 1031 |
+
else:
|
| 1032 |
+
raise ValueError(f"Unsupported attr_func {attr_func}")
|
| 1033 |
+
|
| 1034 |
+
seq_attr, row_attr, rec_attr = attr.get_all_token_attrs(indices_to_explain)
|
| 1035 |
+
hop_payload = None
|
| 1036 |
+
if bool(testing_dict.get("save_hop_traces", False)):
|
| 1037 |
+
try:
|
| 1038 |
+
hop_payload = _build_hop_trace_payload(attr_func, attr, indices_to_explain=indices_to_explain)
|
| 1039 |
+
except Exception as exc:
|
| 1040 |
+
print(f"[warn] hop trace extraction failed for {attr_func}: {exc}")
|
| 1041 |
+
hop_payload = None
|
| 1042 |
+
|
| 1043 |
+
user_prompt_indices = getattr(llm_attributor, "user_prompt_indices", None)
|
| 1044 |
+
if isinstance(user_prompt_indices, list):
|
| 1045 |
+
user_prompt_indices = [int(x) for x in user_prompt_indices]
|
| 1046 |
+
else:
|
| 1047 |
+
user_prompt_indices = None
|
| 1048 |
+
|
| 1049 |
+
keep_prompt_token_indices = None
|
| 1050 |
+
if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both"):
|
| 1051 |
+
try:
|
| 1052 |
+
import ft_ifr_improve
|
| 1053 |
+
|
| 1054 |
+
keep_prompt_token_indices = ft_ifr_improve.keep_token_indices(list(attr.prompt_tokens))
|
| 1055 |
+
except Exception:
|
| 1056 |
+
keep_prompt_token_indices = None
|
| 1057 |
+
|
| 1058 |
+
return [seq_attr, row_attr, rec_attr], hop_payload, user_prompt_indices, keep_prompt_token_indices
|
| 1059 |
+
|
| 1060 |
+
|
| 1061 |
+
def faithfulness_generation(
|
| 1062 |
+
testing_dict, example: ds_utils.CachedExample, target: str, llm_evaluator
|
| 1063 |
+
) -> Tuple[np.ndarray, Optional[Dict[str, np.ndarray]]]:
|
| 1064 |
+
prompt = example.prompt
|
| 1065 |
+
generation = target
|
| 1066 |
+
|
| 1067 |
+
attr_func = str(testing_dict.get("attr_func") or "")
|
| 1068 |
+
attr_list, hop_payload, user_prompt_indices, keep_prompt_token_indices = run_attribution(
|
| 1069 |
+
testing_dict, example, target
|
| 1070 |
+
)
|
| 1071 |
+
seq_attr = attr_list[0]
|
| 1072 |
+
prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0]) # cols=(P+G), rows=G
|
| 1073 |
+
|
| 1074 |
+
results = []
|
| 1075 |
+
for attr in attr_list:
|
| 1076 |
+
# Only use prompt-side attribution, matching evaluations/faithfulness.py
|
| 1077 |
+
attr_prompt = attr[:, :prompt_len]
|
| 1078 |
+
if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and keep_prompt_token_indices is not None:
|
| 1079 |
+
import ft_ifr_improve
|
| 1080 |
+
|
| 1081 |
+
scores = ft_ifr_improve.faithfulness_test_skip_tokens(
|
| 1082 |
+
llm_evaluator,
|
| 1083 |
+
attr_prompt,
|
| 1084 |
+
prompt,
|
| 1085 |
+
generation,
|
| 1086 |
+
keep_prompt_token_indices=keep_prompt_token_indices,
|
| 1087 |
+
user_prompt_indices=user_prompt_indices,
|
| 1088 |
+
)
|
| 1089 |
+
elif user_prompt_indices is not None:
|
| 1090 |
+
scores = _faithfulness_test_with_user_prompt_indices(
|
| 1091 |
+
llm_evaluator,
|
| 1092 |
+
attr_prompt,
|
| 1093 |
+
prompt,
|
| 1094 |
+
generation,
|
| 1095 |
+
user_prompt_indices=user_prompt_indices,
|
| 1096 |
+
)
|
| 1097 |
+
else:
|
| 1098 |
+
scores = llm_evaluator.faithfulness_test(attr_prompt, prompt, generation)
|
| 1099 |
+
results.append(scores)
|
| 1100 |
+
|
| 1101 |
+
return np.array(results), hop_payload
|
| 1102 |
+
|
| 1103 |
+
|
| 1104 |
+
def evaluate_dataset(args, dataset_name: str, examples: List[ds_utils.CachedExample], testing_dict):
|
| 1105 |
+
out = evaluate_dataset_multi(args, dataset_name, examples, testing_dict, modes=["faithfulness_gen"])
|
| 1106 |
+
faith = out.get("faithfulness")
|
| 1107 |
+
if not faith:
|
| 1108 |
+
return None
|
| 1109 |
+
return faith["mean"], faith["std"], faith["avg_time"]
|
| 1110 |
+
|
| 1111 |
+
|
| 1112 |
+
def evaluate_dataset_recovery_ruler(args, dataset_name: str, examples: List[ds_utils.CachedExample], testing_dict):
|
| 1113 |
+
out = evaluate_dataset_multi(args, dataset_name, examples, testing_dict, modes=["recovery_ruler"])
|
| 1114 |
+
rec = out.get("recovery")
|
| 1115 |
+
if not rec:
|
| 1116 |
+
return None
|
| 1117 |
+
return rec["mean"], rec["std"], rec["avg_time"], rec["used"], rec["skipped"]
|
| 1118 |
+
|
| 1119 |
+
|
| 1120 |
+
def main():
|
| 1121 |
+
parser = argparse.ArgumentParser("Experiment 2 runner (math skipped, AT2 skipped).")
|
| 1122 |
+
parser.add_argument("--datasets", type=str, required=True, help="Comma-separated names or paths.")
|
| 1123 |
+
parser.add_argument("--attr_funcs", type=str, required=True, help="Comma-separated attr funcs (no AT2).")
|
| 1124 |
+
parser.add_argument("--model", type=str, default=None, help="HF repo id (required unless --model_path set).")
|
| 1125 |
+
parser.add_argument("--model_path", type=str, default=None, help="Local path; overrides --model for loading.")
|
| 1126 |
+
parser.add_argument("--cuda", type=str, default=None)
|
| 1127 |
+
parser.add_argument("--cuda_num", type=int, default=0)
|
| 1128 |
+
parser.add_argument("--num_examples", type=int, default=100)
|
| 1129 |
+
parser.add_argument(
|
| 1130 |
+
"--mode",
|
| 1131 |
+
type=str,
|
| 1132 |
+
nargs="+",
|
| 1133 |
+
default=["faithfulness_gen"],
|
| 1134 |
+
help=(
|
| 1135 |
+
"One or more of: faithfulness_gen, recovery_ruler. "
|
| 1136 |
+
"Accepts comma-separated values, e.g. '--mode faithfulness_gen,recovery_ruler' "
|
| 1137 |
+
"or '--mode faithfulness_gen, recovery_ruler'."
|
| 1138 |
+
),
|
| 1139 |
+
)
|
| 1140 |
+
parser.add_argument("--sample", type=int, default=None, help="Optional subsample before num_examples.")
|
| 1141 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 1142 |
+
parser.add_argument("--chunk_tokens", type=int, default=128)
|
| 1143 |
+
parser.add_argument("--sink_chunk_tokens", type=int, default=32)
|
| 1144 |
+
parser.add_argument("--n_hops", type=int, default=3)
|
| 1145 |
+
parser.add_argument(
|
| 1146 |
+
"--attnlrp_neg_handling",
|
| 1147 |
+
type=str,
|
| 1148 |
+
choices=["drop", "abs"],
|
| 1149 |
+
default="drop",
|
| 1150 |
+
help="FT-AttnLRP: how to handle negative values after each hop (drop=clamp>=0, abs=absolute value).",
|
| 1151 |
+
)
|
| 1152 |
+
parser.add_argument(
|
| 1153 |
+
"--attnlrp_norm_mode",
|
| 1154 |
+
type=str,
|
| 1155 |
+
choices=["norm", "no_norm"],
|
| 1156 |
+
default="norm",
|
| 1157 |
+
help="FT-AttnLRP: norm enables per-hop global+thinking normalization + ratios; no_norm disables all three.",
|
| 1158 |
+
)
|
| 1159 |
+
parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Filtered dataset cache directory.")
|
| 1160 |
+
parser.add_argument("--output_root", type=str, default="exp/exp2/output", help="Directory to store evaluation outputs.")
|
| 1161 |
+
parser.add_argument(
|
| 1162 |
+
"--save_hop_traces",
|
| 1163 |
+
action="store_true",
|
| 1164 |
+
help=(
|
| 1165 |
+
"Save per-sample trace artifacts (attribution vectors + per-sample metrics) under output_root/traces/. "
|
| 1166 |
+
"For multi-hop methods, also saves per-hop token vectors (vh)."
|
| 1167 |
+
),
|
| 1168 |
+
)
|
| 1169 |
+
args = parser.parse_args()
|
| 1170 |
+
modes = _parse_modes(args.mode)
|
| 1171 |
+
|
| 1172 |
+
if args.model_path:
|
| 1173 |
+
model_name = args.model_path
|
| 1174 |
+
elif args.model:
|
| 1175 |
+
model_name = args.model
|
| 1176 |
+
else:
|
| 1177 |
+
raise SystemExit("Please set --model or --model_path.")
|
| 1178 |
+
model_tag = args.model if args.model else Path(args.model_path).name
|
| 1179 |
+
|
| 1180 |
+
datasets = [d.strip() for d in args.datasets.split(",") if d.strip()]
|
| 1181 |
+
attr_funcs = [a.strip() for a in args.attr_funcs.split(",") if a.strip()]
|
| 1182 |
+
|
| 1183 |
+
device = resolve_device(args)
|
| 1184 |
+
model, tokenizer = load_model(model_name, device)
|
| 1185 |
+
|
| 1186 |
+
max_input_len = {
|
| 1187 |
+
"llama-1B": 5500,
|
| 1188 |
+
"llama-3B": 4800,
|
| 1189 |
+
"llama-8B": 3500,
|
| 1190 |
+
"qwen-1.7B": 5500,
|
| 1191 |
+
"qwen-4B": 3500,
|
| 1192 |
+
"qwen-8B": 5000,
|
| 1193 |
+
"qwen-32B": 1500,
|
| 1194 |
+
"gemma-12B": 1500,
|
| 1195 |
+
"gemma-27B": 2000,
|
| 1196 |
+
}.get(args.model, 2000)
|
| 1197 |
+
|
| 1198 |
+
for ds_name in datasets:
|
| 1199 |
+
if "recovery_ruler" in modes and ds_name == "morehopqa":
|
| 1200 |
+
raise SystemExit("recovery_ruler only supports RULER datasets (with needle_spans), not morehopqa.")
|
| 1201 |
+
if "recovery_ruler" in modes and ds_name.startswith("math"):
|
| 1202 |
+
raise SystemExit("recovery_ruler only supports RULER datasets (with needle_spans), not math.")
|
| 1203 |
+
|
| 1204 |
+
# Resolve dataset (prefer prepared cache under data_root)
|
| 1205 |
+
cached_path = Path(args.data_root) / f"{ds_name}.jsonl"
|
| 1206 |
+
if cached_path.exists():
|
| 1207 |
+
examples = ds_utils.load_cached(cached_path, sample=args.sample, seed=args.seed)
|
| 1208 |
+
else:
|
| 1209 |
+
# allow direct cached path or raw loader
|
| 1210 |
+
p = Path(ds_name)
|
| 1211 |
+
if p.exists():
|
| 1212 |
+
examples = ds_utils.load_cached(p, sample=args.sample, seed=args.seed)
|
| 1213 |
+
else:
|
| 1214 |
+
hint = "please run exp/exp2/sample_and_filter.py first (or pass an explicit cached JSONL path)."
|
| 1215 |
+
if ds_name.startswith("math"):
|
| 1216 |
+
hint = "please run exp/exp2/map_math_mine_to_exp2_cache.py first (or pass an explicit cached JSONL path)."
|
| 1217 |
+
raise SystemExit(f"Missing exp2 cache for '{ds_name}'. Expected {cached_path}; {hint}")
|
| 1218 |
+
|
| 1219 |
+
for attr_func in attr_funcs:
|
| 1220 |
+
if attr_func.lower() == "at2":
|
| 1221 |
+
print("Skipping AT2 as requested.")
|
| 1222 |
+
continue
|
| 1223 |
+
|
| 1224 |
+
testing_dict: Dict[str, any] = {
|
| 1225 |
+
"model": model,
|
| 1226 |
+
"model_tag": model_tag,
|
| 1227 |
+
"tokenizer": tokenizer,
|
| 1228 |
+
"attr_func": attr_func,
|
| 1229 |
+
"max_input_len": max_input_len,
|
| 1230 |
+
"chunk_tokens": args.chunk_tokens,
|
| 1231 |
+
"sink_chunk_tokens": args.sink_chunk_tokens,
|
| 1232 |
+
"n_hops": args.n_hops,
|
| 1233 |
+
"attnlrp_neg_handling": args.attnlrp_neg_handling,
|
| 1234 |
+
"attnlrp_norm_mode": args.attnlrp_norm_mode,
|
| 1235 |
+
"device": device,
|
| 1236 |
+
"batch_size": 1,
|
| 1237 |
+
"save_hop_traces": bool(args.save_hop_traces),
|
| 1238 |
+
}
|
| 1239 |
+
result = evaluate_dataset_multi(args, ds_name, examples, testing_dict, modes=modes)
|
| 1240 |
+
|
| 1241 |
+
if "faithfulness_gen" in modes:
|
| 1242 |
+
faith = result.get("faithfulness")
|
| 1243 |
+
if not faith:
|
| 1244 |
+
print(f"No faithfulness results for {ds_name} with {attr_func}.")
|
| 1245 |
+
else:
|
| 1246 |
+
mean = faith["mean"]
|
| 1247 |
+
std = faith["std"]
|
| 1248 |
+
avg_time = float(faith["avg_time"])
|
| 1249 |
+
|
| 1250 |
+
out_dir = Path(args.output_root) / "faithfulness" / ds_name / model_tag
|
| 1251 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 1252 |
+
filename = f"{attr_func}_{args.num_examples}_examples.csv"
|
| 1253 |
+
with open(out_dir / filename, "w") as f:
|
| 1254 |
+
f.write("Method,RISE,MAS,RISE+AP\n")
|
| 1255 |
+
f.write(",".join(["Seq Attr Scores Mean"] + [str(x) for x in mean[0].tolist()]) + "\n")
|
| 1256 |
+
f.write(",".join(["Row Attr Scores Mean"] + [str(x) for x in mean[1].tolist()]) + "\n")
|
| 1257 |
+
f.write(",".join(["Recursive Attr Scores Mean"] + [str(x) for x in mean[2].tolist()]) + "\n")
|
| 1258 |
+
f.write(",".join(["Seq Attr Scores Var"] + [str(x) for x in std[0].tolist()]) + "\n")
|
| 1259 |
+
f.write(",".join(["Row Attr Scores Var"] + [str(x) for x in std[1].tolist()]) + "\n")
|
| 1260 |
+
f.write(",".join(["Recursive Attr Scores Var"] + [str(x) for x in std[2].tolist()]) + "\n")
|
| 1261 |
+
f.write(f"Avg Sample Time (s),{avg_time}\n")
|
| 1262 |
+
print(f"[{ds_name}] {attr_func} -> {out_dir/filename} (avg sample time: {avg_time:.2f}s)")
|
| 1263 |
+
|
| 1264 |
+
if "recovery_ruler" in modes:
|
| 1265 |
+
rec = result.get("recovery")
|
| 1266 |
+
if not rec:
|
| 1267 |
+
print(f"No recovery results for {ds_name} with {attr_func}.")
|
| 1268 |
+
else:
|
| 1269 |
+
mean = rec["mean"]
|
| 1270 |
+
std = rec["std"]
|
| 1271 |
+
avg_time = float(rec["avg_time"])
|
| 1272 |
+
used = int(rec["used"])
|
| 1273 |
+
skipped = int(rec["skipped"])
|
| 1274 |
+
|
| 1275 |
+
out_dir = Path(args.output_root) / "recovery" / ds_name / model_tag
|
| 1276 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 1277 |
+
filename = f"{attr_func}_{args.num_examples}_examples.csv"
|
| 1278 |
+
with open(out_dir / filename, "w") as f:
|
| 1279 |
+
f.write("Method,Recovery@10%\n")
|
| 1280 |
+
f.write(f"Seq Attr Recovery Mean,{mean[0]}\n")
|
| 1281 |
+
f.write(f"Row Attr Recovery Mean,{mean[1]}\n")
|
| 1282 |
+
f.write(f"Recursive Attr Recovery Mean,{mean[2]}\n")
|
| 1283 |
+
f.write(f"Seq Attr Recovery Std,{std[0]}\n")
|
| 1284 |
+
f.write(f"Row Attr Recovery Std,{std[1]}\n")
|
| 1285 |
+
f.write(f"Recursive Attr Recovery Std,{std[2]}\n")
|
| 1286 |
+
f.write(f"Examples Used,{used}\n")
|
| 1287 |
+
f.write(f"Examples Skipped,{skipped}\n")
|
| 1288 |
+
f.write(f"Avg Sample Time (s),{avg_time}\n")
|
| 1289 |
+
print(
|
| 1290 |
+
f"[{ds_name}] {attr_func} -> {out_dir/filename} "
|
| 1291 |
+
f"(used={used} skipped={skipped} avg sample time: {avg_time:.2f}s)"
|
| 1292 |
+
)
|
| 1293 |
+
|
| 1294 |
+
|
| 1295 |
+
if __name__ == "__main__":
|
| 1296 |
+
main()
|
exp/exp2/sample_and_filter.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Dataset sampler for Experiment 2.
|
| 4 |
+
|
| 5 |
+
Steps:
|
| 6 |
+
- Load a dataset item (MoreHopQA / HotpotQA / RULER niah / RULER vt).
|
| 7 |
+
- Call the generation model (qwen3-235b-a22b-2507) with a system prompt that
|
| 8 |
+
asks for brief reasoning and a final answer wrapped in \\box{}.
|
| 9 |
+
- Enforce the output format: keep only generations that look like
|
| 10 |
+
"<reasoning text> + final \\box{} answer" with nothing after the box.
|
| 11 |
+
- Call the judge model (deepseek-v3-1-terminus) to check whether the boxed
|
| 12 |
+
answer matches the dataset reference answer; keep only judged True samples.
|
| 13 |
+
- Rebuild `target` as "<reasoning>\\n<answer text (no box)>" and store filtered
|
| 14 |
+
samples to exp/exp2/data/<dataset>.jsonl (or a custom path) with inferred spans.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import time
|
| 24 |
+
import urllib.error
|
| 25 |
+
import urllib.request
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Any, Dict, Iterable, List, Optional
|
| 28 |
+
|
| 29 |
+
from transformers import AutoTokenizer
|
| 30 |
+
from tqdm import tqdm
|
| 31 |
+
|
| 32 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 33 |
+
if str(REPO_ROOT) not in sys.path:
|
| 34 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 35 |
+
|
| 36 |
+
from exp.exp2.dataset_utils import (
|
| 37 |
+
CachedExample,
|
| 38 |
+
DatasetLoader,
|
| 39 |
+
attach_spans_from_answer,
|
| 40 |
+
split_boxed_generation,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class RateLimitError(RuntimeError):
|
| 45 |
+
"""Raised when API returns 429; carries a suggested wait time."""
|
| 46 |
+
|
| 47 |
+
def __init__(self, wait_seconds: float, detail: str) -> None:
|
| 48 |
+
super().__init__(detail)
|
| 49 |
+
self.wait_seconds = wait_seconds
|
| 50 |
+
|
| 51 |
+
# GEN_SYSTEM_PROMPT = (
|
| 52 |
+
# "You are a careful reasoning assistant. "
|
| 53 |
+
# "Before answering, engage in an extremely detailed and exhaustive chain of thought. **No fewer than 2k tokens.** "
|
| 54 |
+
# "Do not skip any logical steps, even if they seem obvious. "
|
| 55 |
+
# "Process this freely and naturally without using specific headers or strict formatting. "
|
| 56 |
+
# "When you reach the conclusion, wrap the entire final sentence containing the answer inside \\box{}. "
|
| 57 |
+
# "Ensure the box wraps the **sentence** that naturally delivers the answer. DO NOT rewrite the answer word for the box separately."
|
| 58 |
+
# )
|
| 59 |
+
|
| 60 |
+
GEN_SYSTEM_PROMPT = (
|
| 61 |
+
"You are a reasoning assistant. "
|
| 62 |
+
"Before answering, engage in an chain of thought. "
|
| 63 |
+
"Process this freely and naturally without using specific headers or strict formatting. "
|
| 64 |
+
"When you reach the conclusion, wrap the entire final sentence containing the answer inside \\box{}. "
|
| 65 |
+
"Ensure the box wraps the **sentence** that naturally delivers the answer. DO NOT rewrite the answer word for the box separately."
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
JUDGE_SYSTEM_PROMPT = (
|
| 69 |
+
"You verify whether the model's boxed answer matches the reference answer. "
|
| 70 |
+
"Reply strictly with True or False and nothing else."
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def call_chat_api(
|
| 75 |
+
api_base: str,
|
| 76 |
+
api_key: str,
|
| 77 |
+
model: str,
|
| 78 |
+
messages: List[Dict[str, str]],
|
| 79 |
+
*,
|
| 80 |
+
timeout: int,
|
| 81 |
+
max_tokens: int,
|
| 82 |
+
temperature: float,
|
| 83 |
+
cache_ttl: int,
|
| 84 |
+
cache_namespace: Optional[str],
|
| 85 |
+
rate_limit_delay: Optional[float] = None,
|
| 86 |
+
) -> str:
|
| 87 |
+
url = api_base.rstrip("/") + "/chat/completions"
|
| 88 |
+
payload: Dict[str, Any] = {
|
| 89 |
+
"model": model,
|
| 90 |
+
"messages": messages,
|
| 91 |
+
"max_tokens": max_tokens,
|
| 92 |
+
"temperature": temperature,
|
| 93 |
+
}
|
| 94 |
+
if cache_ttl > 0:
|
| 95 |
+
cache_obj: Dict[str, Any] = {"ttl": cache_ttl}
|
| 96 |
+
if cache_namespace:
|
| 97 |
+
cache_obj["namespace"] = cache_namespace
|
| 98 |
+
payload["cache"] = cache_obj
|
| 99 |
+
|
| 100 |
+
data = json.dumps(payload).encode("utf-8")
|
| 101 |
+
headers = {"Content-Type": "application/json"}
|
| 102 |
+
if api_key:
|
| 103 |
+
headers["Authorization"] = f"Bearer {api_key}"
|
| 104 |
+
|
| 105 |
+
req = urllib.request.Request(url, data=data, headers=headers, method="POST")
|
| 106 |
+
opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
|
| 107 |
+
try:
|
| 108 |
+
with opener.open(req, timeout=timeout) as resp:
|
| 109 |
+
resp_bytes = resp.read()
|
| 110 |
+
except urllib.error.HTTPError as e:
|
| 111 |
+
detail = e.read().decode("utf-8", errors="ignore") if hasattr(e, "read") else ""
|
| 112 |
+
if e.code == 429:
|
| 113 |
+
retry_after = None
|
| 114 |
+
if hasattr(e, "headers") and e.headers:
|
| 115 |
+
retry_after_header = e.headers.get("Retry-After")
|
| 116 |
+
if retry_after_header:
|
| 117 |
+
try:
|
| 118 |
+
retry_after = float(retry_after_header)
|
| 119 |
+
except ValueError:
|
| 120 |
+
retry_after = None
|
| 121 |
+
wait = retry_after or rate_limit_delay or 5.0
|
| 122 |
+
raise RateLimitError(wait, f"API HTTP 429: {detail}") from e
|
| 123 |
+
raise RuntimeError(f"API HTTP error {e.code}: {detail}") from e
|
| 124 |
+
except urllib.error.URLError as e:
|
| 125 |
+
raise RuntimeError(f"API request failed: {e}") from e
|
| 126 |
+
|
| 127 |
+
try:
|
| 128 |
+
response = json.loads(resp_bytes.decode("utf-8"))
|
| 129 |
+
except json.JSONDecodeError as e:
|
| 130 |
+
raise RuntimeError(f"Failed to decode API response: {resp_bytes!r}") from e
|
| 131 |
+
|
| 132 |
+
choices = response.get("choices", [])
|
| 133 |
+
if not choices:
|
| 134 |
+
raise RuntimeError(f"Empty choices from API: {response}")
|
| 135 |
+
content = choices[0].get("message", {}).get("content", "")
|
| 136 |
+
if not content:
|
| 137 |
+
raise RuntimeError(f"Empty content from API: {response}")
|
| 138 |
+
return content.strip()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def build_gen_messages(prompt: str) -> List[Dict[str, str]]:
|
| 142 |
+
return [
|
| 143 |
+
{"role": "system", "content": GEN_SYSTEM_PROMPT},
|
| 144 |
+
{"role": "user", "content": prompt},
|
| 145 |
+
]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def build_judge_messages(reference_answer: str, candidate_answer: str) -> List[Dict[str, str]]:
|
| 149 |
+
user = (
|
| 150 |
+
"Decide if the model's boxed answer matches the reference answer.\n"
|
| 151 |
+
f"Reference answer: {reference_answer}\n"
|
| 152 |
+
f"Model boxed answer (only the content inside \\box{{}}): {candidate_answer}\n"
|
| 153 |
+
"Output only True if they are semantically consistent; otherwise output False."
|
| 154 |
+
)
|
| 155 |
+
return [
|
| 156 |
+
{"role": "system", "content": JUDGE_SYSTEM_PROMPT},
|
| 157 |
+
{"role": "user", "content": user},
|
| 158 |
+
]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def parse_bool(text: str) -> bool:
|
| 162 |
+
first = text.strip().splitlines()[0].strip().lower()
|
| 163 |
+
if first in {"true", "yes"}:
|
| 164 |
+
return True
|
| 165 |
+
if first in {"false", "no"}:
|
| 166 |
+
return False
|
| 167 |
+
# fallback: check substring
|
| 168 |
+
if "true" in first and "false" not in first:
|
| 169 |
+
return True
|
| 170 |
+
if "false" in first:
|
| 171 |
+
return False
|
| 172 |
+
raise ValueError(f"Cannot parse boolean from: {text!r}")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def write_cache(out_path: Path, examples: Iterable[CachedExample]) -> int:
|
| 176 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 177 |
+
count = 0
|
| 178 |
+
with out_path.open("w", encoding="utf-8") as f:
|
| 179 |
+
for ex in examples:
|
| 180 |
+
obj: Dict[str, Any] = {
|
| 181 |
+
"prompt": ex.prompt,
|
| 182 |
+
"target": ex.target,
|
| 183 |
+
"indices_to_explain": ex.indices_to_explain,
|
| 184 |
+
"attr_mask_indices": ex.attr_mask_indices,
|
| 185 |
+
"sink_span": ex.sink_span,
|
| 186 |
+
"thinking_span": ex.thinking_span,
|
| 187 |
+
"metadata": ex.metadata,
|
| 188 |
+
}
|
| 189 |
+
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
| 190 |
+
count += 1
|
| 191 |
+
return count
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def main():
|
| 195 |
+
parser = argparse.ArgumentParser("Sample and filter dataset examples for exp2.")
|
| 196 |
+
parser.add_argument(
|
| 197 |
+
"--dataset",
|
| 198 |
+
type=str,
|
| 199 |
+
required=True,
|
| 200 |
+
help="morehopqa | hotpotqa_long | niah_* | vt_* | <morehopqa_json_path> | <ruler_jsonl_path>",
|
| 201 |
+
)
|
| 202 |
+
parser.add_argument("--max_examples", type=int, default=100, help="Number of raw examples to sample before filtering.")
|
| 203 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 204 |
+
parser.add_argument("--api_base", type=str, default="http://localhost:4000/v1", help="Chat API base URL.")
|
| 205 |
+
parser.add_argument("--api_key", type=str, default=None, help="API key; defaults to FLASHTRACE_API_KEY/OPENAI_API_KEY.")
|
| 206 |
+
parser.add_argument("--generator_model", type=str, default="qwen3-235b-a22b-2507")
|
| 207 |
+
parser.add_argument("--judge_model", type=str, default="deepseek-v3-1-terminus")
|
| 208 |
+
parser.add_argument("--api_timeout", type=int, default=300)
|
| 209 |
+
parser.add_argument("--api_max_tokens", type=int, default=8192)
|
| 210 |
+
parser.add_argument("--api_temperature", type=float, default=0.0)
|
| 211 |
+
parser.add_argument("--api_cache_ttl", type=int, default=600)
|
| 212 |
+
parser.add_argument("--api_cache_namespace", type=str, default="flashtrace-exp2")
|
| 213 |
+
parser.add_argument("--retry_delay", type=float, default=2.0)
|
| 214 |
+
parser.add_argument("--retries", type=int, default=2, help="Additional retries on API failure.")
|
| 215 |
+
parser.add_argument("--request_interval", type=float, default=1.0, help="Sleep seconds between generation calls.")
|
| 216 |
+
parser.add_argument("--judge_interval", type=float, default=1.0, help="Sleep seconds between judge calls.")
|
| 217 |
+
parser.add_argument("--tokenizer_model", type=str, default=None, help="Tokenizer path for span extraction (default: generator model).")
|
| 218 |
+
parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Output directory for filtered caches.")
|
| 219 |
+
parser.add_argument("--out", type=str, default=None, help="Optional explicit output path (JSONL).")
|
| 220 |
+
parser.add_argument("--rate_limit_delay", type=float, default=5.0, help="Seconds to wait on HTTP 429 before retrying.")
|
| 221 |
+
args = parser.parse_args()
|
| 222 |
+
|
| 223 |
+
api_key = args.api_key or os.environ.get("FLASHTRACE_API_KEY") or os.environ.get("OPENAI_API_KEY")
|
| 224 |
+
if not api_key:
|
| 225 |
+
raise SystemExit("Set --api_key or FLASHTRACE_API_KEY/OPENAI_API_KEY for API access.")
|
| 226 |
+
|
| 227 |
+
loader = DatasetLoader(seed=args.seed, data_root=args.data_root)
|
| 228 |
+
# Load full dataset; we will stop early once enough kept examples are collected.
|
| 229 |
+
raw_examples = loader.load_raw(args.dataset, sample=None)
|
| 230 |
+
if not raw_examples:
|
| 231 |
+
raise SystemExit("No examples loaded.")
|
| 232 |
+
|
| 233 |
+
tok_name = args.tokenizer_model or args.generator_model
|
| 234 |
+
tok_path = Path(tok_name)
|
| 235 |
+
if tok_path.exists():
|
| 236 |
+
tokenizer = AutoTokenizer.from_pretrained(tok_path.as_posix(), local_files_only=True)
|
| 237 |
+
else:
|
| 238 |
+
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
| 239 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 240 |
+
|
| 241 |
+
kept: List[CachedExample] = []
|
| 242 |
+
total = len(raw_examples)
|
| 243 |
+
kept_bar = tqdm(total=args.max_examples, desc="Kept (judge=True)", position=1, leave=False)
|
| 244 |
+
attempted = 0
|
| 245 |
+
|
| 246 |
+
for idx, ex in enumerate(tqdm(raw_examples, total=total, desc="Sampling"), 1):
|
| 247 |
+
if len(kept) >= args.max_examples:
|
| 248 |
+
break
|
| 249 |
+
reference_answer = ex.metadata.get("reference_answer") or ex.target or ""
|
| 250 |
+
gen_messages = build_gen_messages(ex.prompt)
|
| 251 |
+
attempted = idx
|
| 252 |
+
|
| 253 |
+
# Step 1: generation
|
| 254 |
+
for attempt in range(args.retries + 1):
|
| 255 |
+
try:
|
| 256 |
+
generation = call_chat_api(
|
| 257 |
+
args.api_base,
|
| 258 |
+
api_key,
|
| 259 |
+
args.generator_model,
|
| 260 |
+
gen_messages,
|
| 261 |
+
timeout=args.api_timeout,
|
| 262 |
+
max_tokens=args.api_max_tokens,
|
| 263 |
+
temperature=args.api_temperature,
|
| 264 |
+
cache_ttl=args.api_cache_ttl,
|
| 265 |
+
cache_namespace=args.api_cache_namespace,
|
| 266 |
+
rate_limit_delay=args.rate_limit_delay,
|
| 267 |
+
)
|
| 268 |
+
break
|
| 269 |
+
except RateLimitError as e:
|
| 270 |
+
if attempt >= args.retries:
|
| 271 |
+
raise
|
| 272 |
+
time.sleep(e.wait_seconds)
|
| 273 |
+
except Exception: # noqa: BLE001
|
| 274 |
+
if attempt >= args.retries:
|
| 275 |
+
raise
|
| 276 |
+
time.sleep(args.retry_delay)
|
| 277 |
+
if args.request_interval > 0:
|
| 278 |
+
time.sleep(args.request_interval)
|
| 279 |
+
|
| 280 |
+
parsed = split_boxed_generation(generation)
|
| 281 |
+
if not parsed:
|
| 282 |
+
print(f"[{idx}/{total}] skipped=format")
|
| 283 |
+
continue
|
| 284 |
+
|
| 285 |
+
thinking_text, boxed_segment, boxed_answer = parsed
|
| 286 |
+
target_text = f"{thinking_text}\n{boxed_answer}" if thinking_text else boxed_answer
|
| 287 |
+
judge_messages = build_judge_messages(reference_answer, boxed_answer)
|
| 288 |
+
|
| 289 |
+
ok = False
|
| 290 |
+
judge_resp = ""
|
| 291 |
+
for attempt in range(args.retries + 1):
|
| 292 |
+
try:
|
| 293 |
+
judge_resp = call_chat_api(
|
| 294 |
+
args.api_base,
|
| 295 |
+
api_key,
|
| 296 |
+
args.judge_model,
|
| 297 |
+
judge_messages,
|
| 298 |
+
timeout=args.api_timeout,
|
| 299 |
+
max_tokens=64,
|
| 300 |
+
temperature=0.0,
|
| 301 |
+
cache_ttl=args.api_cache_ttl,
|
| 302 |
+
cache_namespace=args.api_cache_namespace,
|
| 303 |
+
rate_limit_delay=args.rate_limit_delay,
|
| 304 |
+
)
|
| 305 |
+
ok = parse_bool(judge_resp)
|
| 306 |
+
break
|
| 307 |
+
except RateLimitError as e:
|
| 308 |
+
if attempt >= args.retries:
|
| 309 |
+
raise
|
| 310 |
+
time.sleep(e.wait_seconds)
|
| 311 |
+
except Exception: # noqa: BLE001
|
| 312 |
+
if attempt >= args.retries:
|
| 313 |
+
raise
|
| 314 |
+
time.sleep(args.retry_delay)
|
| 315 |
+
if args.judge_interval > 0:
|
| 316 |
+
time.sleep(args.judge_interval)
|
| 317 |
+
|
| 318 |
+
status = "kept" if ok else "filtered"
|
| 319 |
+
print(f"[{idx}/{total}] judge={status}")
|
| 320 |
+
if not ok:
|
| 321 |
+
continue
|
| 322 |
+
|
| 323 |
+
new_meta = dict(ex.metadata)
|
| 324 |
+
new_meta["reference_answer"] = reference_answer
|
| 325 |
+
new_meta["judge_response"] = judge_resp
|
| 326 |
+
|
| 327 |
+
new_ex = CachedExample(
|
| 328 |
+
prompt=ex.prompt,
|
| 329 |
+
target=target_text,
|
| 330 |
+
indices_to_explain=None,
|
| 331 |
+
attr_mask_indices=ex.attr_mask_indices,
|
| 332 |
+
sink_span=None,
|
| 333 |
+
thinking_span=None,
|
| 334 |
+
metadata=new_meta,
|
| 335 |
+
)
|
| 336 |
+
new_ex = attach_spans_from_answer(new_ex, tokenizer, boxed_answer)
|
| 337 |
+
if not (isinstance(new_ex.sink_span, list) and len(new_ex.sink_span) == 2):
|
| 338 |
+
print(f"[{idx}/{total}] skipped=span")
|
| 339 |
+
continue
|
| 340 |
+
|
| 341 |
+
# Token-level indices_to_explain: boxed-inner answer token span in target (closed interval).
|
| 342 |
+
new_ex = CachedExample(
|
| 343 |
+
prompt=new_ex.prompt,
|
| 344 |
+
target=new_ex.target,
|
| 345 |
+
indices_to_explain=new_ex.sink_span,
|
| 346 |
+
attr_mask_indices=new_ex.attr_mask_indices,
|
| 347 |
+
sink_span=new_ex.sink_span,
|
| 348 |
+
thinking_span=new_ex.thinking_span,
|
| 349 |
+
metadata=new_ex.metadata,
|
| 350 |
+
)
|
| 351 |
+
kept.append(new_ex)
|
| 352 |
+
kept_bar.update(1)
|
| 353 |
+
|
| 354 |
+
kept_bar.close()
|
| 355 |
+
|
| 356 |
+
out_path = Path(args.out) if args.out else Path(args.data_root) / f"{args.dataset}.jsonl"
|
| 357 |
+
written = write_cache(out_path, kept)
|
| 358 |
+
attempted_total = attempted or 0
|
| 359 |
+
print(f"Kept {written} / target {args.max_examples} (attempted {attempted_total} / {total}) -> {out_path}")
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
if __name__ == "__main__":
|
| 363 |
+
main()
|
exp/exp3/README.md
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FlashTrace 实验 3:长/短 CoT 对比(case study)
|
| 2 |
+
|
| 3 |
+
本目录提供一个「长/短 CoT」的最小可复现实验:
|
| 4 |
+
- 从 RULER `niah_mq_q2 (1024)` 中分别筛出:
|
| 5 |
+
- short-CoT:短推理 + `\box{}` 最终答案
|
| 6 |
+
- long-CoT:长推理 + `\box{}` 最终答案
|
| 7 |
+
- 只跑 `attnlrp`(hop0)并只计算 token-level `recovery@10%`(gold 来自 `needle_spans`)。
|
| 8 |
+
- 落盘 trace(npz + manifest)到 `exp/exp3/output/`,格式对齐 `exp/exp2/run_exp.py` 的 trace 习惯。
|
| 9 |
+
|
| 10 |
+
## 1) 采样与过滤(生成 + judge)
|
| 11 |
+
|
| 12 |
+
默认读取:
|
| 13 |
+
`data/ruler_multihop/1024/niah_mq_q2/validation.jsonl`
|
| 14 |
+
|
| 15 |
+
需要一个 OpenAI-compatible 的 chat API(默认 `http://localhost:4000/v1`)以及 API key。
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
export FLASHTRACE_API_KEY=... # 或 OPENAI_API_KEY
|
| 19 |
+
|
| 20 |
+
python exp/exp3/sample_and_filter.py \
|
| 21 |
+
--tokenizer_model /opt/share/models/Qwen/Qwen3-8B/ \
|
| 22 |
+
--min_long_thinking_tokens 512 \
|
| 23 |
+
--max_short_thinking_tokens 256
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
输出(默认):
|
| 27 |
+
- `exp/exp3/data/niah_mq_q2_short_cot.jsonl`
|
| 28 |
+
- `exp/exp3/data/niah_mq_q2_long_cot.jsonl`
|
| 29 |
+
|
| 30 |
+
说明:
|
| 31 |
+
- 默认各采 1 条;可用 `--max_short` / `--max_long` 分别指定数量(`--max_pairs` 是两者的兼容别名)。
|
| 32 |
+
|
| 33 |
+
## 2) 归因与 recovery(AttnLRP hop0)
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
python exp/exp3/run_exp.py \
|
| 37 |
+
--model qwen-8B \
|
| 38 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 39 |
+
--cuda 3,4,5,7
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
输出:
|
| 43 |
+
- recovery CSV:`exp/exp3/output/recovery/<dataset>/<model>/attnlrp_1_examples.csv`
|
| 44 |
+
- trace:`exp/exp3/output/traces/<dataset>/<model>/<run_tag>/ex_*.npz` + `manifest.jsonl`
|
| 45 |
+
- 汇总 JSON:`exp/exp3/output/recovery/summary_<model>.json`
|
| 46 |
+
|
| 47 |
+
常用参数:
|
| 48 |
+
- `--top_fraction`:recovery 的 top fraction(默认 0.1)
|
| 49 |
+
- `--attnlrp_neg_handling drop|abs`
|
| 50 |
+
- `--attnlrp_norm_mode norm|no_norm`
|
exp/exp3/extract_segment_weights.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Extract CoT/output segment attribution weights from exp3 trace artifacts.
|
| 4 |
+
|
| 5 |
+
Background
|
| 6 |
+
----------
|
| 7 |
+
exp/exp3/run_exp.py saves per-sample trace npz files that contain token-level
|
| 8 |
+
importance vectors over the FULL (prompt + generation) token sequence:
|
| 9 |
+
- v_seq_all: sum over rows of seq attribution matrix (shape [P+G])
|
| 10 |
+
- v_row_all: row attribution vector for indices_to_explain (shape [P+G])
|
| 11 |
+
- v_rec_all: recursive attribution vector for indices_to_explain (shape [P+G])
|
| 12 |
+
|
| 13 |
+
For exp3 cached samples, we also have generation-token spans:
|
| 14 |
+
- thinking_span_gen: CoT span [start,end] in generation-token coordinates
|
| 15 |
+
- sink_span_gen: output span [start,end] in generation-token coordinates
|
| 16 |
+
|
| 17 |
+
This script slices v_*_all into:
|
| 18 |
+
- cot: tokens in thinking_span_gen (offset by prompt_len)
|
| 19 |
+
- output: tokens in sink_span_gen (offset by prompt_len)
|
| 20 |
+
|
| 21 |
+
and reports segment sums/fractions (and optionally writes a JSON summary).
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import json
|
| 28 |
+
from dataclasses import dataclass
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 31 |
+
|
| 32 |
+
import numpy as np
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass(frozen=True)
|
| 36 |
+
class TracePaths:
|
| 37 |
+
dataset: str
|
| 38 |
+
model_tag: str
|
| 39 |
+
run_tag: str
|
| 40 |
+
npz_path: Path
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _pick_latest_subdir(path: Path) -> Optional[Path]:
|
| 44 |
+
if not path.exists():
|
| 45 |
+
return None
|
| 46 |
+
subs = [p for p in path.iterdir() if p.is_dir()]
|
| 47 |
+
if not subs:
|
| 48 |
+
return None
|
| 49 |
+
subs.sort(key=lambda p: p.stat().st_mtime, reverse=True)
|
| 50 |
+
return subs[0]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _resolve_trace_paths(
|
| 54 |
+
*,
|
| 55 |
+
output_root: Path,
|
| 56 |
+
dataset: str,
|
| 57 |
+
model_tag: Optional[str],
|
| 58 |
+
run_tag: Optional[str],
|
| 59 |
+
example_idx: int,
|
| 60 |
+
) -> TracePaths:
|
| 61 |
+
base = output_root / "traces" / dataset
|
| 62 |
+
if not base.exists():
|
| 63 |
+
raise FileNotFoundError(f"Trace dataset dir not found: {base}")
|
| 64 |
+
|
| 65 |
+
if model_tag is None:
|
| 66 |
+
model_dirs = [p for p in base.iterdir() if p.is_dir()]
|
| 67 |
+
if not model_dirs:
|
| 68 |
+
raise FileNotFoundError(f"No model subdir under: {base}")
|
| 69 |
+
if len(model_dirs) != 1:
|
| 70 |
+
raise SystemExit(f"Multiple model dirs under {base}; pass --model_tag. Found: {[p.name for p in model_dirs]}")
|
| 71 |
+
model_dir = model_dirs[0]
|
| 72 |
+
model_tag = model_dir.name
|
| 73 |
+
else:
|
| 74 |
+
model_dir = base / model_tag
|
| 75 |
+
if not model_dir.exists():
|
| 76 |
+
raise FileNotFoundError(f"Trace model dir not found: {model_dir}")
|
| 77 |
+
|
| 78 |
+
if run_tag is None:
|
| 79 |
+
run_dir = _pick_latest_subdir(model_dir)
|
| 80 |
+
if run_dir is None:
|
| 81 |
+
raise FileNotFoundError(f"No run subdir under: {model_dir}")
|
| 82 |
+
run_tag = run_dir.name
|
| 83 |
+
else:
|
| 84 |
+
run_dir = model_dir / run_tag
|
| 85 |
+
if not run_dir.exists():
|
| 86 |
+
raise FileNotFoundError(f"Trace run dir not found: {run_dir}")
|
| 87 |
+
|
| 88 |
+
npz_name = f"ex_{int(example_idx):06d}.npz"
|
| 89 |
+
npz_path = run_dir / npz_name
|
| 90 |
+
if not npz_path.exists():
|
| 91 |
+
raise FileNotFoundError(f"Trace npz not found: {npz_path}")
|
| 92 |
+
|
| 93 |
+
return TracePaths(dataset=dataset, model_tag=model_tag, run_tag=run_tag, npz_path=npz_path)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _as_span(arr: Any) -> Optional[Tuple[int, int]]:
|
| 97 |
+
if arr is None:
|
| 98 |
+
return None
|
| 99 |
+
try:
|
| 100 |
+
a = np.asarray(arr).reshape(-1).tolist()
|
| 101 |
+
except Exception:
|
| 102 |
+
return None
|
| 103 |
+
if len(a) != 2:
|
| 104 |
+
return None
|
| 105 |
+
try:
|
| 106 |
+
start = int(a[0])
|
| 107 |
+
end = int(a[1])
|
| 108 |
+
except Exception:
|
| 109 |
+
return None
|
| 110 |
+
if start < 0 or end < start:
|
| 111 |
+
return None
|
| 112 |
+
return start, end
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _segment_stats(v: np.ndarray, start: int, end: int) -> Dict[str, float]:
|
| 116 |
+
if end < start:
|
| 117 |
+
return {"sum": 0.0, "mean": 0.0, "max": 0.0}
|
| 118 |
+
seg = v[start : end + 1]
|
| 119 |
+
if seg.size == 0:
|
| 120 |
+
return {"sum": 0.0, "mean": 0.0, "max": 0.0}
|
| 121 |
+
return {
|
| 122 |
+
"sum": float(seg.sum()),
|
| 123 |
+
"mean": float(seg.mean()),
|
| 124 |
+
"max": float(seg.max()),
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _slice_segment(v: np.ndarray, start: int, end: int) -> List[float]:
|
| 129 |
+
if end < start:
|
| 130 |
+
return []
|
| 131 |
+
seg = v[start : end + 1]
|
| 132 |
+
return [float(x) for x in seg.tolist()]
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def extract_one(npz_path: Path) -> Dict[str, Any]:
|
| 136 |
+
d = np.load(npz_path)
|
| 137 |
+
required = ["prompt_len", "gen_len", "v_seq_all", "v_row_all", "v_rec_all"]
|
| 138 |
+
for k in required:
|
| 139 |
+
if k not in d:
|
| 140 |
+
raise KeyError(f"Missing key in trace npz {npz_path}: {k}")
|
| 141 |
+
|
| 142 |
+
prompt_len = int(np.asarray(d["prompt_len"]).item())
|
| 143 |
+
gen_len = int(np.asarray(d["gen_len"]).item())
|
| 144 |
+
total_len = prompt_len + gen_len
|
| 145 |
+
|
| 146 |
+
v_seq_all = np.asarray(d["v_seq_all"], dtype=np.float64).reshape(-1)
|
| 147 |
+
v_row_all = np.asarray(d["v_row_all"], dtype=np.float64).reshape(-1)
|
| 148 |
+
v_rec_all = np.asarray(d["v_rec_all"], dtype=np.float64).reshape(-1)
|
| 149 |
+
for name, v in [("v_seq_all", v_seq_all), ("v_row_all", v_row_all), ("v_rec_all", v_rec_all)]:
|
| 150 |
+
if int(v.size) != int(total_len):
|
| 151 |
+
raise ValueError(f"{name} length mismatch: expected {total_len}, got {int(v.size)}")
|
| 152 |
+
|
| 153 |
+
sink_span_gen = _as_span(d.get("sink_span_gen"))
|
| 154 |
+
thinking_span_gen = _as_span(d.get("thinking_span_gen"))
|
| 155 |
+
if sink_span_gen is None:
|
| 156 |
+
raise KeyError("Trace missing sink_span_gen; cannot define output span.")
|
| 157 |
+
if thinking_span_gen is None:
|
| 158 |
+
# Best-effort: infer thinking span as [0, sink_start-1].
|
| 159 |
+
sink_start, _ = sink_span_gen
|
| 160 |
+
thinking_span_gen = (0, max(0, sink_start - 1))
|
| 161 |
+
|
| 162 |
+
think_start_g, think_end_g = thinking_span_gen
|
| 163 |
+
sink_start_g, sink_end_g = sink_span_gen
|
| 164 |
+
|
| 165 |
+
cot_start = prompt_len + think_start_g
|
| 166 |
+
cot_end = min(prompt_len + think_end_g, total_len - 1)
|
| 167 |
+
out_start = prompt_len + sink_start_g
|
| 168 |
+
out_end = min(prompt_len + sink_end_g, total_len - 1)
|
| 169 |
+
|
| 170 |
+
def pack(v: np.ndarray) -> Dict[str, Any]:
|
| 171 |
+
total = float(v.sum())
|
| 172 |
+
cot = _segment_stats(v, cot_start, cot_end)
|
| 173 |
+
out = _segment_stats(v, out_start, out_end)
|
| 174 |
+
denom = cot["sum"] + out["sum"]
|
| 175 |
+
return {
|
| 176 |
+
"total_sum": total,
|
| 177 |
+
"cot": {
|
| 178 |
+
"start_abs": int(cot_start),
|
| 179 |
+
"end_abs": int(cot_end),
|
| 180 |
+
"len": int(max(0, cot_end - cot_start + 1)),
|
| 181 |
+
**cot,
|
| 182 |
+
"fraction_of_total": float(cot["sum"] / total) if total > 0 else float("nan"),
|
| 183 |
+
"fraction_of_cot_plus_output": float(cot["sum"] / denom) if denom > 0 else float("nan"),
|
| 184 |
+
},
|
| 185 |
+
"output": {
|
| 186 |
+
"start_abs": int(out_start),
|
| 187 |
+
"end_abs": int(out_end),
|
| 188 |
+
"len": int(max(0, out_end - out_start + 1)),
|
| 189 |
+
**out,
|
| 190 |
+
"fraction_of_total": float(out["sum"] / total) if total > 0 else float("nan"),
|
| 191 |
+
"fraction_of_cot_plus_output": float(out["sum"] / denom) if denom > 0 else float("nan"),
|
| 192 |
+
},
|
| 193 |
+
"cot_weights": _slice_segment(v, cot_start, cot_end),
|
| 194 |
+
"output_weights": _slice_segment(v, out_start, out_end),
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
return {
|
| 198 |
+
"prompt_len": int(prompt_len),
|
| 199 |
+
"gen_len": int(gen_len),
|
| 200 |
+
"total_len": int(total_len),
|
| 201 |
+
"thinking_span_gen": [int(think_start_g), int(think_end_g)],
|
| 202 |
+
"sink_span_gen": [int(sink_start_g), int(sink_end_g)],
|
| 203 |
+
"seq": pack(v_seq_all),
|
| 204 |
+
"row": pack(v_row_all),
|
| 205 |
+
"rec": pack(v_rec_all),
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def main() -> None:
|
| 210 |
+
parser = argparse.ArgumentParser("Extract CoT/output weights from exp3 traces.")
|
| 211 |
+
parser.add_argument("--output_root", type=str, default="exp/exp3/output")
|
| 212 |
+
parser.add_argument("--dataset_tag", type=str, default="niah_mq_q2")
|
| 213 |
+
parser.add_argument("--model_tag", type=str, default=None, help="If omitted, auto-detect when unique.")
|
| 214 |
+
parser.add_argument("--run_tag", type=str, default=None, help="If omitted, picks the latest run subdir.")
|
| 215 |
+
parser.add_argument("--example_idx", type=int, default=0)
|
| 216 |
+
parser.add_argument("--out", type=str, default=None, help="Optional JSON output path.")
|
| 217 |
+
args = parser.parse_args()
|
| 218 |
+
|
| 219 |
+
output_root = Path(args.output_root)
|
| 220 |
+
datasets = [f"{args.dataset_tag}_short_cot", f"{args.dataset_tag}_long_cot"]
|
| 221 |
+
|
| 222 |
+
results: List[Dict[str, Any]] = []
|
| 223 |
+
for ds_name in datasets:
|
| 224 |
+
paths = _resolve_trace_paths(
|
| 225 |
+
output_root=output_root,
|
| 226 |
+
dataset=ds_name,
|
| 227 |
+
model_tag=args.model_tag,
|
| 228 |
+
run_tag=args.run_tag,
|
| 229 |
+
example_idx=args.example_idx,
|
| 230 |
+
)
|
| 231 |
+
out = extract_one(paths.npz_path)
|
| 232 |
+
out["dataset"] = paths.dataset
|
| 233 |
+
out["model_tag"] = paths.model_tag
|
| 234 |
+
out["run_tag"] = paths.run_tag
|
| 235 |
+
out["npz_path"] = str(paths.npz_path)
|
| 236 |
+
results.append(out)
|
| 237 |
+
|
| 238 |
+
text = json.dumps(results, ensure_ascii=False, indent=2)
|
| 239 |
+
if args.out:
|
| 240 |
+
out_path = Path(args.out)
|
| 241 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 242 |
+
out_path.write_text(text + "\n", encoding="utf-8")
|
| 243 |
+
print(f"Wrote -> {out_path}")
|
| 244 |
+
else:
|
| 245 |
+
print(text)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
if __name__ == "__main__":
|
| 249 |
+
main()
|
| 250 |
+
|
exp/exp3/part_weights.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Compute attribution mass on (input, cot, output) segments from exp3 trace npz files.
|
| 4 |
+
|
| 5 |
+
Definitions (token-level, aligned with exp2/exp3 runners):
|
| 6 |
+
- input : prompt-side tokens (user prompt), indices [0, prompt_len)
|
| 7 |
+
- cot : generation tokens in thinking span, indices [prompt_len + t0, prompt_len + t1]
|
| 8 |
+
- output : generation tokens in sink span (answer), indices [prompt_len + s0, prompt_len + s1]
|
| 9 |
+
|
| 10 |
+
The trace stores token-importance vectors:
|
| 11 |
+
- v_seq_all, v_row_all, v_rec_all (length = prompt_len + gen_len)
|
| 12 |
+
|
| 13 |
+
This script sums those vectors over each segment and reports both absolute sums
|
| 14 |
+
and fractions of the total sum.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Dict, Iterable, List, Optional, Tuple
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass(frozen=True)
|
| 29 |
+
class TraceRun:
|
| 30 |
+
dataset: str
|
| 31 |
+
model: str
|
| 32 |
+
run_dir: Path
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _pick_single_subdir(parent: Path) -> Path:
|
| 36 |
+
subdirs = [p for p in parent.iterdir() if p.is_dir()]
|
| 37 |
+
if not subdirs:
|
| 38 |
+
raise FileNotFoundError(f"No subdirectories found under {parent}")
|
| 39 |
+
if len(subdirs) == 1:
|
| 40 |
+
return subdirs[0]
|
| 41 |
+
subdirs.sort(key=lambda p: p.stat().st_mtime, reverse=True)
|
| 42 |
+
return subdirs[0]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _resolve_run(
|
| 46 |
+
trace_root: Path,
|
| 47 |
+
*,
|
| 48 |
+
dataset: str,
|
| 49 |
+
model: Optional[str],
|
| 50 |
+
run_tag: Optional[str],
|
| 51 |
+
) -> TraceRun:
|
| 52 |
+
ds_dir = trace_root / dataset
|
| 53 |
+
if not ds_dir.exists():
|
| 54 |
+
raise FileNotFoundError(f"Dataset trace directory not found: {ds_dir}")
|
| 55 |
+
|
| 56 |
+
if model is None:
|
| 57 |
+
model_dir = _pick_single_subdir(ds_dir)
|
| 58 |
+
else:
|
| 59 |
+
model_dir = ds_dir / model
|
| 60 |
+
if not model_dir.exists():
|
| 61 |
+
raise FileNotFoundError(f"Model trace directory not found: {model_dir}")
|
| 62 |
+
|
| 63 |
+
if run_tag is None:
|
| 64 |
+
run_dir = _pick_single_subdir(model_dir)
|
| 65 |
+
else:
|
| 66 |
+
run_dir = model_dir / run_tag
|
| 67 |
+
if not run_dir.exists():
|
| 68 |
+
raise FileNotFoundError(f"Run directory not found: {run_dir}")
|
| 69 |
+
|
| 70 |
+
return TraceRun(dataset=dataset, model=model_dir.name, run_dir=run_dir)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _iter_manifest(run_dir: Path) -> Iterable[dict]:
|
| 74 |
+
manifest = run_dir / "manifest.jsonl"
|
| 75 |
+
if not manifest.exists():
|
| 76 |
+
raise FileNotFoundError(f"Missing manifest: {manifest}")
|
| 77 |
+
with manifest.open("r", encoding="utf-8") as f:
|
| 78 |
+
for line in f:
|
| 79 |
+
line = line.strip()
|
| 80 |
+
if line:
|
| 81 |
+
yield json.loads(line)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _as_span(arr: np.ndarray, *, name: str) -> Tuple[int, int]:
|
| 85 |
+
if arr is None:
|
| 86 |
+
raise ValueError(f"Missing {name} in trace npz.")
|
| 87 |
+
a = np.asarray(arr).reshape(-1)
|
| 88 |
+
if a.size != 2:
|
| 89 |
+
raise ValueError(f"Expected {name} to have 2 ints, got shape {a.shape}.")
|
| 90 |
+
return int(a[0]), int(a[1])
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _segment_sums(
|
| 94 |
+
v: np.ndarray,
|
| 95 |
+
*,
|
| 96 |
+
prompt_len: int,
|
| 97 |
+
gen_len: int,
|
| 98 |
+
thinking_span_gen: Optional[Tuple[int, int]],
|
| 99 |
+
sink_span_gen: Optional[Tuple[int, int]],
|
| 100 |
+
) -> Dict[str, float]:
|
| 101 |
+
total_len = int(prompt_len) + int(gen_len)
|
| 102 |
+
if int(v.shape[0]) != total_len:
|
| 103 |
+
raise ValueError(f"Vector length mismatch: len(v)={int(v.shape[0])} vs prompt_len+gen_len={total_len}.")
|
| 104 |
+
|
| 105 |
+
v = np.asarray(v, dtype=np.float64).reshape(-1)
|
| 106 |
+
prompt_len = int(prompt_len)
|
| 107 |
+
gen_len = int(gen_len)
|
| 108 |
+
|
| 109 |
+
# Default: no cot/output when spans missing (should not happen in exp3).
|
| 110 |
+
think_start, think_end = (0, -1) if thinking_span_gen is None else thinking_span_gen
|
| 111 |
+
sink_start, sink_end = (0, -1) if sink_span_gen is None else sink_span_gen
|
| 112 |
+
|
| 113 |
+
# Clamp spans into [0, gen_len-1].
|
| 114 |
+
def _clamp_span(a: int, b: int) -> Tuple[int, int]:
|
| 115 |
+
a = max(0, min(int(a), gen_len - 1))
|
| 116 |
+
b = max(0, min(int(b), gen_len - 1))
|
| 117 |
+
if b < a:
|
| 118 |
+
return 0, -1
|
| 119 |
+
return a, b
|
| 120 |
+
|
| 121 |
+
think_start, think_end = _clamp_span(think_start, think_end)
|
| 122 |
+
sink_start, sink_end = _clamp_span(sink_start, sink_end)
|
| 123 |
+
|
| 124 |
+
mask = np.zeros((total_len,), dtype=bool)
|
| 125 |
+
# input = all prompt tokens
|
| 126 |
+
input_slice = slice(0, prompt_len)
|
| 127 |
+
mask[input_slice] = True
|
| 128 |
+
|
| 129 |
+
cot_slice = slice(prompt_len + think_start, prompt_len + think_end + 1) if think_end >= think_start else slice(0, 0)
|
| 130 |
+
output_slice = slice(prompt_len + sink_start, prompt_len + sink_end + 1) if sink_end >= sink_start else slice(0, 0)
|
| 131 |
+
mask[cot_slice] = True
|
| 132 |
+
mask[output_slice] = True
|
| 133 |
+
|
| 134 |
+
input_sum = float(v[input_slice].sum())
|
| 135 |
+
cot_sum = float(v[cot_slice].sum()) if think_end >= think_start else 0.0
|
| 136 |
+
output_sum = float(v[output_slice].sum()) if sink_end >= sink_start else 0.0
|
| 137 |
+
other_sum = float(v[~mask].sum())
|
| 138 |
+
total_sum = float(v.sum())
|
| 139 |
+
|
| 140 |
+
return {
|
| 141 |
+
"total": total_sum,
|
| 142 |
+
"input": input_sum,
|
| 143 |
+
"cot": cot_sum,
|
| 144 |
+
"output": output_sum,
|
| 145 |
+
"other": other_sum,
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _with_fracs(sums: Dict[str, float]) -> Dict[str, float]:
|
| 150 |
+
total = float(sums.get("total") or 0.0)
|
| 151 |
+
if total <= 0.0:
|
| 152 |
+
return {**sums, "input_frac": float("nan"), "cot_frac": float("nan"), "output_frac": float("nan"), "other_frac": float("nan")}
|
| 153 |
+
return {
|
| 154 |
+
**sums,
|
| 155 |
+
"input_frac": float(sums["input"]) / total,
|
| 156 |
+
"cot_frac": float(sums["cot"]) / total,
|
| 157 |
+
"output_frac": float(sums["output"]) / total,
|
| 158 |
+
"other_frac": float(sums["other"]) / total,
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _analyze_npz(npz_path: Path) -> Dict[str, dict]:
|
| 163 |
+
d = np.load(npz_path)
|
| 164 |
+
prompt_len = int(np.asarray(d["prompt_len"]).item())
|
| 165 |
+
gen_len = int(np.asarray(d["gen_len"]).item())
|
| 166 |
+
thinking_span_gen = _as_span(d["thinking_span_gen"], name="thinking_span_gen") if "thinking_span_gen" in d.files else None
|
| 167 |
+
sink_span_gen = _as_span(d["sink_span_gen"], name="sink_span_gen") if "sink_span_gen" in d.files else None
|
| 168 |
+
|
| 169 |
+
out: Dict[str, dict] = {"prompt_len": prompt_len, "gen_len": gen_len}
|
| 170 |
+
for key in ("v_seq_all", "v_row_all", "v_rec_all"):
|
| 171 |
+
if key not in d.files:
|
| 172 |
+
raise ValueError(f"Missing {key} in trace npz: {npz_path}")
|
| 173 |
+
sums = _segment_sums(
|
| 174 |
+
d[key],
|
| 175 |
+
prompt_len=prompt_len,
|
| 176 |
+
gen_len=gen_len,
|
| 177 |
+
thinking_span_gen=thinking_span_gen,
|
| 178 |
+
sink_span_gen=sink_span_gen,
|
| 179 |
+
)
|
| 180 |
+
out[key] = _with_fracs(sums)
|
| 181 |
+
out["thinking_span_gen"] = list(thinking_span_gen) if thinking_span_gen is not None else None
|
| 182 |
+
out["sink_span_gen"] = list(sink_span_gen) if sink_span_gen is not None else None
|
| 183 |
+
return out
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def main() -> None:
|
| 187 |
+
parser = argparse.ArgumentParser("Summarize input/cot/output attribution mass from exp3 traces.")
|
| 188 |
+
parser.add_argument("--trace_root", type=str, default="exp/exp3/output/traces")
|
| 189 |
+
parser.add_argument("--dataset_tag", type=str, default="niah_mq_q2", help="Base tag; expands to <tag>_short_cot and <tag>_long_cot.")
|
| 190 |
+
parser.add_argument("--datasets", type=str, default=None, help="Comma-separated dataset names (overrides --dataset_tag expansion).")
|
| 191 |
+
parser.add_argument("--model", type=str, default=None, help="Model directory name under traces (default: auto if single).")
|
| 192 |
+
parser.add_argument("--run_tag", type=str, default=None, help="Run tag directory (default: auto pick newest/single).")
|
| 193 |
+
args = parser.parse_args()
|
| 194 |
+
|
| 195 |
+
trace_root = Path(args.trace_root)
|
| 196 |
+
if not trace_root.exists():
|
| 197 |
+
raise SystemExit(f"trace_root not found: {trace_root}")
|
| 198 |
+
|
| 199 |
+
if args.datasets:
|
| 200 |
+
datasets = [x.strip() for x in str(args.datasets).split(",") if x.strip()]
|
| 201 |
+
else:
|
| 202 |
+
datasets = [f"{args.dataset_tag}_short_cot", f"{args.dataset_tag}_long_cot"]
|
| 203 |
+
|
| 204 |
+
for ds in datasets:
|
| 205 |
+
run = _resolve_run(trace_root, dataset=ds, model=args.model, run_tag=args.run_tag)
|
| 206 |
+
records = list(_iter_manifest(run.run_dir))
|
| 207 |
+
if not records:
|
| 208 |
+
raise SystemExit(f"Empty manifest: {run.run_dir/'manifest.jsonl'}")
|
| 209 |
+
for rec in records:
|
| 210 |
+
npz_path = run.run_dir / str(rec["file"])
|
| 211 |
+
analysis = _analyze_npz(npz_path)
|
| 212 |
+
print(
|
| 213 |
+
json.dumps(
|
| 214 |
+
{
|
| 215 |
+
"dataset": run.dataset,
|
| 216 |
+
"model": run.model,
|
| 217 |
+
"run_dir": str(run.run_dir),
|
| 218 |
+
"example_idx": int(rec.get("example_idx", -1)),
|
| 219 |
+
**analysis,
|
| 220 |
+
},
|
| 221 |
+
ensure_ascii=False,
|
| 222 |
+
)
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
if __name__ == "__main__":
|
| 227 |
+
main()
|
| 228 |
+
|
exp/exp3/run_exp.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Experiment 3 runner: long-vs-short CoT case study (AttnLRP hop0, Recovery@10%).
|
| 4 |
+
|
| 5 |
+
This runner is intentionally minimal:
|
| 6 |
+
- Only reads two cached samples produced by exp/exp3/sample_and_filter.py:
|
| 7 |
+
<dataset_tag>_short_cot.jsonl
|
| 8 |
+
<dataset_tag>_long_cot.jsonl
|
| 9 |
+
- Only runs attribution method: attnlrp (hop0 path, aligned with exp2).
|
| 10 |
+
- Only computes token-level recovery (Recall@10%) using RULER needle_spans.
|
| 11 |
+
- Always saves per-sample trace artifacts under exp/exp3/output/traces/.
|
| 12 |
+
|
| 13 |
+
All outputs are written under exp/exp3/output/ (configurable via --output_root).
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import hashlib
|
| 20 |
+
import json
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import time
|
| 24 |
+
from itertools import islice
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _early_set_cuda_visible_devices() -> None:
|
| 30 |
+
parser = argparse.ArgumentParser(add_help=False)
|
| 31 |
+
parser.add_argument("--cuda", type=str, default=None)
|
| 32 |
+
args, _ = parser.parse_known_args(sys.argv[1:])
|
| 33 |
+
if args.cuda and "," in args.cuda:
|
| 34 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
_early_set_cuda_visible_devices()
|
| 38 |
+
|
| 39 |
+
import numpy as np
|
| 40 |
+
import torch
|
| 41 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, utils
|
| 42 |
+
|
| 43 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 44 |
+
if str(REPO_ROOT) not in sys.path:
|
| 45 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 46 |
+
|
| 47 |
+
import llm_attr
|
| 48 |
+
import llm_attr_eval
|
| 49 |
+
from exp.exp2 import dataset_utils as ds_utils
|
| 50 |
+
|
| 51 |
+
utils.logging.set_verbosity_error()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _sha1_text(text: str) -> str:
|
| 55 |
+
return hashlib.sha1(text.encode("utf-8")).hexdigest()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _token_importance_vector(attr: torch.Tensor) -> np.ndarray:
|
| 59 |
+
w = torch.nan_to_num(attr.sum(0).to(dtype=torch.float32), nan=0.0).clamp(min=0.0)
|
| 60 |
+
return w.detach().cpu().numpy().astype(np.float32, copy=False)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _trace_run_tag(*, neg_handling: str, norm_mode: str, total: int) -> str:
|
| 64 |
+
return f"attnlrp_neg{neg_handling}_norm{norm_mode}_recovery_{int(total)}ex"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _build_sample_trace_payload(
|
| 68 |
+
example: ds_utils.CachedExample,
|
| 69 |
+
*,
|
| 70 |
+
seq_attr: torch.Tensor,
|
| 71 |
+
row_attr: torch.Tensor,
|
| 72 |
+
rec_attr: torch.Tensor,
|
| 73 |
+
prompt_len: int,
|
| 74 |
+
user_prompt_indices: Optional[List[int]],
|
| 75 |
+
gold_prompt_token_indices: Optional[List[int]],
|
| 76 |
+
recovery_scores: Optional[np.ndarray],
|
| 77 |
+
time_attr_s: Optional[float],
|
| 78 |
+
time_recovery_s: Optional[float],
|
| 79 |
+
) -> Dict[str, np.ndarray]:
|
| 80 |
+
gen_len = int(seq_attr.shape[0])
|
| 81 |
+
|
| 82 |
+
v_seq_all = _token_importance_vector(seq_attr)
|
| 83 |
+
v_row_all = _token_importance_vector(row_attr)
|
| 84 |
+
v_rec_all = _token_importance_vector(rec_attr)
|
| 85 |
+
|
| 86 |
+
payload: Dict[str, np.ndarray] = {
|
| 87 |
+
"v_seq_all": v_seq_all,
|
| 88 |
+
"v_row_all": v_row_all,
|
| 89 |
+
"v_rec_all": v_rec_all,
|
| 90 |
+
"v_seq_prompt": v_seq_all[:prompt_len],
|
| 91 |
+
"v_row_prompt": v_row_all[:prompt_len],
|
| 92 |
+
"v_rec_prompt": v_rec_all[:prompt_len],
|
| 93 |
+
"prompt_len": np.asarray(int(prompt_len), dtype=np.int64),
|
| 94 |
+
"gen_len": np.asarray(int(gen_len), dtype=np.int64),
|
| 95 |
+
"indices_to_explain_gen": np.asarray(list(example.indices_to_explain or []), dtype=np.int64),
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
if example.sink_span is not None:
|
| 99 |
+
payload["sink_span_gen"] = np.asarray(list(example.sink_span), dtype=np.int64)
|
| 100 |
+
if example.thinking_span is not None:
|
| 101 |
+
payload["thinking_span_gen"] = np.asarray(list(example.thinking_span), dtype=np.int64)
|
| 102 |
+
|
| 103 |
+
if user_prompt_indices is not None:
|
| 104 |
+
payload["user_prompt_indices"] = np.asarray(list(user_prompt_indices), dtype=np.int64)
|
| 105 |
+
if gold_prompt_token_indices is not None:
|
| 106 |
+
payload["gold_prompt_token_indices"] = np.asarray(list(gold_prompt_token_indices), dtype=np.int64)
|
| 107 |
+
|
| 108 |
+
if recovery_scores is not None:
|
| 109 |
+
payload["recovery_scores"] = np.asarray(recovery_scores, dtype=np.float64)
|
| 110 |
+
|
| 111 |
+
if time_attr_s is not None:
|
| 112 |
+
payload["time_attr_s"] = np.asarray(float(time_attr_s), dtype=np.float64)
|
| 113 |
+
if time_recovery_s is not None:
|
| 114 |
+
payload["time_recovery_s"] = np.asarray(float(time_recovery_s), dtype=np.float64)
|
| 115 |
+
|
| 116 |
+
return payload
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _write_sample_trace(
|
| 120 |
+
trace_dir: Path,
|
| 121 |
+
*,
|
| 122 |
+
example_idx: int,
|
| 123 |
+
prompt: str,
|
| 124 |
+
target: str,
|
| 125 |
+
payload: Dict[str, np.ndarray],
|
| 126 |
+
manifest_handle,
|
| 127 |
+
neg_handling: str,
|
| 128 |
+
norm_mode: str,
|
| 129 |
+
recovery_skipped_reason: Optional[str],
|
| 130 |
+
) -> None:
|
| 131 |
+
trace_dir.mkdir(parents=True, exist_ok=True)
|
| 132 |
+
npz_name = f"ex_{example_idx:06d}.npz"
|
| 133 |
+
npz_path = trace_dir / npz_name
|
| 134 |
+
np.savez_compressed(npz_path, **payload)
|
| 135 |
+
|
| 136 |
+
prompt_len = int(np.asarray(payload.get("prompt_len", 0)).item())
|
| 137 |
+
gen_len = int(np.asarray(payload.get("gen_len", 0)).item())
|
| 138 |
+
record: Dict[str, Any] = {
|
| 139 |
+
"example_idx": int(example_idx),
|
| 140 |
+
"attr_func": "attnlrp",
|
| 141 |
+
"file": npz_name,
|
| 142 |
+
"prompt_sha1": _sha1_text(prompt),
|
| 143 |
+
"target_sha1": _sha1_text(target),
|
| 144 |
+
"prompt_len": prompt_len,
|
| 145 |
+
"gen_len": gen_len,
|
| 146 |
+
"indices_to_explain_gen": payload.get("indices_to_explain_gen").tolist()
|
| 147 |
+
if payload.get("indices_to_explain_gen") is not None
|
| 148 |
+
else None,
|
| 149 |
+
"sink_span_gen": payload.get("sink_span_gen").tolist() if payload.get("sink_span_gen") is not None else None,
|
| 150 |
+
"thinking_span_gen": payload.get("thinking_span_gen").tolist()
|
| 151 |
+
if payload.get("thinking_span_gen") is not None
|
| 152 |
+
else None,
|
| 153 |
+
"gold_prompt_token_indices": payload.get("gold_prompt_token_indices").tolist()
|
| 154 |
+
if payload.get("gold_prompt_token_indices") is not None
|
| 155 |
+
else None,
|
| 156 |
+
"recovery_scores": payload.get("recovery_scores").tolist() if payload.get("recovery_scores") is not None else None,
|
| 157 |
+
"recovery_skipped_reason": recovery_skipped_reason,
|
| 158 |
+
"time_attr_s": float(np.asarray(payload.get("time_attr_s")).item()) if payload.get("time_attr_s") is not None else None,
|
| 159 |
+
"time_recovery_s": float(np.asarray(payload.get("time_recovery_s")).item())
|
| 160 |
+
if payload.get("time_recovery_s") is not None
|
| 161 |
+
else None,
|
| 162 |
+
"attnlrp_neg_handling": str(neg_handling),
|
| 163 |
+
"attnlrp_norm_mode": str(norm_mode),
|
| 164 |
+
}
|
| 165 |
+
manifest_handle.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 166 |
+
manifest_handle.flush()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def resolve_device(args) -> str:
|
| 170 |
+
if args.cuda is not None and "," in args.cuda:
|
| 171 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
| 172 |
+
return "auto"
|
| 173 |
+
if args.cuda is not None and str(args.cuda).strip():
|
| 174 |
+
return f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
|
| 175 |
+
return f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def load_model(model_name: str, device: str):
|
| 179 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 180 |
+
model_name,
|
| 181 |
+
device_map="auto" if device == "auto" else {"": int(device.split(":")[1])} if device.startswith("cuda:") else None,
|
| 182 |
+
torch_dtype=torch.float16,
|
| 183 |
+
attn_implementation="eager",
|
| 184 |
+
)
|
| 185 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 186 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 187 |
+
model.eval()
|
| 188 |
+
return model, tokenizer
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _evaluate_one_dataset(
|
| 192 |
+
*,
|
| 193 |
+
dataset_name: str,
|
| 194 |
+
examples: List[ds_utils.CachedExample],
|
| 195 |
+
model,
|
| 196 |
+
tokenizer,
|
| 197 |
+
output_root: Path,
|
| 198 |
+
model_tag: str,
|
| 199 |
+
neg_handling: str,
|
| 200 |
+
norm_mode: str,
|
| 201 |
+
top_fraction: float,
|
| 202 |
+
num_examples: int,
|
| 203 |
+
) -> Tuple[np.ndarray, np.ndarray, float, int, int]:
|
| 204 |
+
llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(model, tokenizer)
|
| 205 |
+
|
| 206 |
+
results: List[np.ndarray] = []
|
| 207 |
+
durations: List[float] = []
|
| 208 |
+
skipped = 0
|
| 209 |
+
|
| 210 |
+
total = min(len(examples), int(num_examples))
|
| 211 |
+
iterator = islice(examples, total)
|
| 212 |
+
|
| 213 |
+
run_tag = _trace_run_tag(neg_handling=neg_handling, norm_mode=norm_mode, total=total)
|
| 214 |
+
trace_dir = output_root / "traces" / dataset_name / model_tag / run_tag
|
| 215 |
+
trace_dir.mkdir(parents=True, exist_ok=True)
|
| 216 |
+
manifest_handle = open(trace_dir / "manifest.jsonl", "w", encoding="utf-8")
|
| 217 |
+
|
| 218 |
+
try:
|
| 219 |
+
for example_idx, ex in enumerate(iterator):
|
| 220 |
+
time_recovery_s: Optional[float] = None
|
| 221 |
+
recovery_scores: Optional[np.ndarray] = None
|
| 222 |
+
|
| 223 |
+
needle_spans = (ex.metadata or {}).get("needle_spans")
|
| 224 |
+
if not isinstance(needle_spans, list) or not needle_spans:
|
| 225 |
+
raise SystemExit(
|
| 226 |
+
"exp3 recovery requires RULER samples with metadata.needle_spans; "
|
| 227 |
+
f"dataset={dataset_name} has missing/empty needle_spans."
|
| 228 |
+
)
|
| 229 |
+
if ex.target is None:
|
| 230 |
+
raise SystemExit(
|
| 231 |
+
"exp3 recovery requires cached targets (CoT+answer) so row/rec attribution is well-defined. "
|
| 232 |
+
f"dataset={dataset_name} has target=None; run exp/exp3/sample_and_filter.py first."
|
| 233 |
+
)
|
| 234 |
+
if not (isinstance(ex.indices_to_explain, list) and len(ex.indices_to_explain) == 2):
|
| 235 |
+
raise SystemExit(
|
| 236 |
+
"exp3 expects indices_to_explain=[start_tok,end_tok] in generation-token coordinates; "
|
| 237 |
+
f"dataset={dataset_name} has indices_to_explain={ex.indices_to_explain!r}; "
|
| 238 |
+
"run exp/exp3/sample_and_filter.py first."
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
gold_prompt = ds_utils.ruler_gold_prompt_token_indices(ex, tokenizer)
|
| 242 |
+
recovery_skip_reason: Optional[str] = None
|
| 243 |
+
|
| 244 |
+
sample_start = time.perf_counter()
|
| 245 |
+
llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
|
| 246 |
+
attr_result = llm_attributor.calculate_attnlrp_ft_hop0(
|
| 247 |
+
ex.prompt,
|
| 248 |
+
target=ex.target,
|
| 249 |
+
sink_span=tuple(ex.sink_span) if ex.sink_span else None,
|
| 250 |
+
thinking_span=tuple(ex.thinking_span) if ex.thinking_span else None,
|
| 251 |
+
neg_handling=str(neg_handling),
|
| 252 |
+
norm_mode=str(norm_mode),
|
| 253 |
+
)
|
| 254 |
+
seq_attr, row_attr, rec_attr = attr_result.get_all_token_attrs(list(ex.indices_to_explain))
|
| 255 |
+
time_attr_s = time.perf_counter() - sample_start
|
| 256 |
+
durations.append(float(time_attr_s))
|
| 257 |
+
|
| 258 |
+
prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0])
|
| 259 |
+
if prompt_len <= 0:
|
| 260 |
+
recovery_skip_reason = "empty_prompt_len"
|
| 261 |
+
elif not gold_prompt:
|
| 262 |
+
recovery_skip_reason = "empty_gold_prompt"
|
| 263 |
+
else:
|
| 264 |
+
t2 = time.perf_counter()
|
| 265 |
+
recovery_scores = np.asarray(
|
| 266 |
+
[
|
| 267 |
+
llm_evaluator.evaluate_attr_recovery(
|
| 268 |
+
a,
|
| 269 |
+
prompt_len=prompt_len,
|
| 270 |
+
gold_prompt_token_indices=gold_prompt,
|
| 271 |
+
top_fraction=top_fraction,
|
| 272 |
+
)
|
| 273 |
+
for a in (seq_attr, row_attr, rec_attr)
|
| 274 |
+
],
|
| 275 |
+
dtype=np.float64,
|
| 276 |
+
)
|
| 277 |
+
time_recovery_s = time.perf_counter() - t2
|
| 278 |
+
if np.isnan(recovery_scores).any():
|
| 279 |
+
recovery_scores = None
|
| 280 |
+
recovery_skip_reason = "nan_recovery"
|
| 281 |
+
|
| 282 |
+
if recovery_scores is None and recovery_skip_reason is not None:
|
| 283 |
+
skipped += 1
|
| 284 |
+
elif recovery_scores is not None:
|
| 285 |
+
results.append(recovery_scores)
|
| 286 |
+
|
| 287 |
+
payload = _build_sample_trace_payload(
|
| 288 |
+
ex,
|
| 289 |
+
seq_attr=seq_attr,
|
| 290 |
+
row_attr=row_attr,
|
| 291 |
+
rec_attr=rec_attr,
|
| 292 |
+
prompt_len=prompt_len,
|
| 293 |
+
user_prompt_indices=getattr(llm_attributor, "user_prompt_indices", None),
|
| 294 |
+
gold_prompt_token_indices=gold_prompt,
|
| 295 |
+
recovery_scores=recovery_scores,
|
| 296 |
+
time_attr_s=time_attr_s,
|
| 297 |
+
time_recovery_s=time_recovery_s,
|
| 298 |
+
)
|
| 299 |
+
_write_sample_trace(
|
| 300 |
+
trace_dir,
|
| 301 |
+
example_idx=example_idx,
|
| 302 |
+
prompt=ex.prompt,
|
| 303 |
+
target=str(ex.target),
|
| 304 |
+
payload=payload,
|
| 305 |
+
manifest_handle=manifest_handle,
|
| 306 |
+
neg_handling=str(neg_handling),
|
| 307 |
+
norm_mode=str(norm_mode),
|
| 308 |
+
recovery_skipped_reason=recovery_skip_reason,
|
| 309 |
+
)
|
| 310 |
+
finally:
|
| 311 |
+
try:
|
| 312 |
+
manifest_handle.close()
|
| 313 |
+
except Exception:
|
| 314 |
+
pass
|
| 315 |
+
|
| 316 |
+
scores = np.stack(results, axis=0) if results else np.zeros((0, 3), dtype=np.float64)
|
| 317 |
+
used = int(scores.shape[0])
|
| 318 |
+
mean = scores.mean(0) if used else np.full((3,), np.nan, dtype=np.float64)
|
| 319 |
+
std = scores.std(0) if used else np.full((3,), np.nan, dtype=np.float64)
|
| 320 |
+
avg_time = float(np.mean(durations)) if durations else 0.0
|
| 321 |
+
return mean, std, avg_time, used, int(skipped)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def main() -> None:
|
| 325 |
+
parser = argparse.ArgumentParser("Experiment 3 runner (attnlrp hop0, recovery only).")
|
| 326 |
+
parser.add_argument("--dataset_tag", type=str, default="niah_mq_q2", help="Base tag for exp3 caches.")
|
| 327 |
+
parser.add_argument("--data_root", type=str, default="exp/exp3/data")
|
| 328 |
+
parser.add_argument("--output_root", type=str, default="exp/exp3/output")
|
| 329 |
+
parser.add_argument("--num_examples", type=int, default=1, help="How many examples to evaluate per dataset (default 1).")
|
| 330 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 331 |
+
parser.add_argument("--model", type=str, default=None, help="HF repo id (required unless --model_path set).")
|
| 332 |
+
parser.add_argument("--model_path", type=str, default=None, help="Local path; overrides --model for loading.")
|
| 333 |
+
parser.add_argument("--cuda_num", type=int, default=0)
|
| 334 |
+
parser.add_argument("--cuda", type=str, default=None)
|
| 335 |
+
parser.add_argument("--top_fraction", type=float, default=0.1, help="Top fraction of prompt tokens used for recovery.")
|
| 336 |
+
parser.add_argument(
|
| 337 |
+
"--attnlrp_neg_handling",
|
| 338 |
+
type=str,
|
| 339 |
+
choices=["drop", "abs"],
|
| 340 |
+
default="drop",
|
| 341 |
+
help="AttnLRP hop0: how to handle negative values (drop=clamp>=0, abs=absolute value).",
|
| 342 |
+
)
|
| 343 |
+
parser.add_argument(
|
| 344 |
+
"--attnlrp_norm_mode",
|
| 345 |
+
type=str,
|
| 346 |
+
choices=["norm", "no_norm"],
|
| 347 |
+
default="norm",
|
| 348 |
+
help="AttnLRP hop0: norm enables internal normalization; no_norm disables it.",
|
| 349 |
+
)
|
| 350 |
+
args = parser.parse_args()
|
| 351 |
+
|
| 352 |
+
if args.model_path:
|
| 353 |
+
model_name = args.model_path
|
| 354 |
+
elif args.model:
|
| 355 |
+
model_name = args.model
|
| 356 |
+
else:
|
| 357 |
+
raise SystemExit("Please set --model or --model_path.")
|
| 358 |
+
model_tag = args.model if args.model else Path(args.model_path).name
|
| 359 |
+
|
| 360 |
+
device = resolve_device(args)
|
| 361 |
+
model, tokenizer = load_model(model_name, device)
|
| 362 |
+
|
| 363 |
+
data_root = Path(args.data_root)
|
| 364 |
+
output_root = Path(args.output_root)
|
| 365 |
+
output_root.mkdir(parents=True, exist_ok=True)
|
| 366 |
+
|
| 367 |
+
short_name = f"{args.dataset_tag}_short_cot"
|
| 368 |
+
long_name = f"{args.dataset_tag}_long_cot"
|
| 369 |
+
dataset_names = [short_name, long_name]
|
| 370 |
+
|
| 371 |
+
summary_rows: List[Dict[str, Any]] = []
|
| 372 |
+
|
| 373 |
+
for ds_name in dataset_names:
|
| 374 |
+
cache_path = data_root / f"{ds_name}.jsonl"
|
| 375 |
+
if not cache_path.exists():
|
| 376 |
+
raise SystemExit(f"Missing exp3 cache: {cache_path}. Run exp/exp3/sample_and_filter.py first.")
|
| 377 |
+
examples = ds_utils.load_cached(cache_path, sample=None, seed=args.seed)
|
| 378 |
+
|
| 379 |
+
mean, std, avg_time, used, skipped = _evaluate_one_dataset(
|
| 380 |
+
dataset_name=ds_name,
|
| 381 |
+
examples=examples,
|
| 382 |
+
model=model,
|
| 383 |
+
tokenizer=tokenizer,
|
| 384 |
+
output_root=output_root,
|
| 385 |
+
model_tag=model_tag,
|
| 386 |
+
neg_handling=args.attnlrp_neg_handling,
|
| 387 |
+
norm_mode=args.attnlrp_norm_mode,
|
| 388 |
+
top_fraction=float(args.top_fraction),
|
| 389 |
+
num_examples=int(args.num_examples),
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
out_dir = output_root / "recovery" / ds_name / model_tag
|
| 393 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 394 |
+
filename = f"attnlrp_{int(args.num_examples)}_examples.csv"
|
| 395 |
+
with (out_dir / filename).open("w", encoding="utf-8") as f:
|
| 396 |
+
f.write("Method,Recovery@10%\n")
|
| 397 |
+
f.write(f"Seq Attr Recovery Mean,{mean[0]}\n")
|
| 398 |
+
f.write(f"Row Attr Recovery Mean,{mean[1]}\n")
|
| 399 |
+
f.write(f"Recursive Attr Recovery Mean,{mean[2]}\n")
|
| 400 |
+
f.write(f"Seq Attr Recovery Std,{std[0]}\n")
|
| 401 |
+
f.write(f"Row Attr Recovery Std,{std[1]}\n")
|
| 402 |
+
f.write(f"Recursive Attr Recovery Std,{std[2]}\n")
|
| 403 |
+
f.write(f"Examples Used,{used}\n")
|
| 404 |
+
f.write(f"Examples Skipped,{skipped}\n")
|
| 405 |
+
f.write(f"Avg Sample Time (s),{avg_time}\n")
|
| 406 |
+
|
| 407 |
+
print(f"[{ds_name}] attnlrp -> {out_dir/filename} (used={used} skipped={skipped} avg {avg_time:.2f}s)")
|
| 408 |
+
summary_rows.append(
|
| 409 |
+
{
|
| 410 |
+
"dataset": ds_name,
|
| 411 |
+
"model": model_tag,
|
| 412 |
+
"neg_handling": args.attnlrp_neg_handling,
|
| 413 |
+
"norm_mode": args.attnlrp_norm_mode,
|
| 414 |
+
"seq_recovery@10%": float(mean[0]) if used else float("nan"),
|
| 415 |
+
"row_recovery@10%": float(mean[1]) if used else float("nan"),
|
| 416 |
+
"rec_recovery@10%": float(mean[2]) if used else float("nan"),
|
| 417 |
+
"used": int(used),
|
| 418 |
+
"skipped": int(skipped),
|
| 419 |
+
}
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
# Lightweight combined summary for quick comparison.
|
| 423 |
+
summary_path = output_root / "recovery" / f"summary_{model_tag}.json"
|
| 424 |
+
summary_path.parent.mkdir(parents=True, exist_ok=True)
|
| 425 |
+
summary_path.write_text(json.dumps(summary_rows, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
| 426 |
+
print(f"Wrote summary -> {summary_path}")
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
if __name__ == "__main__":
|
| 430 |
+
main()
|
exp/exp3/sample_and_filter.py
ADDED
|
@@ -0,0 +1,628 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Experiment 3 sampler: long-vs-short CoT case study (RULER niah_mq_q2, 1024).
|
| 4 |
+
|
| 5 |
+
This script searches the raw RULER JSONL for a *single* prompt where BOTH:
|
| 6 |
+
- a short-CoT generation and a long-CoT generation
|
| 7 |
+
- follow the strict format: "<thinking text> + final \\box{...} answer" with
|
| 8 |
+
nothing after the box
|
| 9 |
+
- pass a judge model verifying the boxed answer matches the reference answer
|
| 10 |
+
- satisfy length constraints (short <= max_short_thinking_tokens,
|
| 11 |
+
long >= min_long_thinking_tokens)
|
| 12 |
+
|
| 13 |
+
It writes two exp2-compatible cache JSONLs to exp/exp3/data/:
|
| 14 |
+
- <dataset_tag>_short_cot.jsonl
|
| 15 |
+
- <dataset_tag>_long_cot.jsonl
|
| 16 |
+
|
| 17 |
+
Each JSONL line matches exp/exp2/dataset_utils.CachedExample schema and keeps
|
| 18 |
+
RULER metadata. The output caches are intended to be consumed by exp/exp3/run_exp.py.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import hashlib
|
| 25 |
+
import json
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
import time
|
| 29 |
+
import urllib.error
|
| 30 |
+
import urllib.request
|
| 31 |
+
from dataclasses import dataclass
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from typing import Any, Dict, Iterable, List, Optional
|
| 34 |
+
|
| 35 |
+
from tqdm import tqdm
|
| 36 |
+
from transformers import AutoTokenizer
|
| 37 |
+
|
| 38 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 39 |
+
if str(REPO_ROOT) not in sys.path:
|
| 40 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 41 |
+
|
| 42 |
+
from exp.exp2 import dataset_utils as ds_utils
|
| 43 |
+
from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer, split_boxed_generation
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class RateLimitError(RuntimeError):
|
| 47 |
+
"""Raised when API returns 429; carries a suggested wait time."""
|
| 48 |
+
|
| 49 |
+
def __init__(self, wait_seconds: float, detail: str) -> None:
|
| 50 |
+
super().__init__(detail)
|
| 51 |
+
self.wait_seconds = wait_seconds
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
SHORT_COT_SYSTEM_PROMPT = (
|
| 55 |
+
"You are a reasoning assistant. "
|
| 56 |
+
"Before answering, engage in a brief chain of thought. "
|
| 57 |
+
"Process this freely and naturally without using specific headers or strict formatting. "
|
| 58 |
+
"When you reach the conclusion, wrap the entire final sentence containing the answer inside \\box{}. "
|
| 59 |
+
"Ensure the box wraps the **sentence** that naturally delivers the answer. "
|
| 60 |
+
"Do not add anything after the box."
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
LONG_COT_SYSTEM_PROMPT = (
|
| 64 |
+
"You are a careful reasoning assistant. "
|
| 65 |
+
"Before answering, engage in an extremely detailed and exhaustive chain of thought. "
|
| 66 |
+
"Do not skip any logical steps, even if they seem obvious. "
|
| 67 |
+
"Process this freely and naturally without using specific headers or strict formatting. "
|
| 68 |
+
"When you reach the conclusion, wrap the entire final sentence containing the answer inside \\box{}. "
|
| 69 |
+
"Ensure the box wraps the **sentence** that naturally delivers the answer. "
|
| 70 |
+
"Do not add anything after the box."
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
JUDGE_SYSTEM_PROMPT = (
|
| 74 |
+
"You verify whether the model's boxed answer matches the reference answer. "
|
| 75 |
+
"Reply strictly with True or False and nothing else."
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _sha1_text(text: str) -> str:
|
| 80 |
+
return hashlib.sha1(text.encode("utf-8")).hexdigest()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def call_chat_api(
|
| 84 |
+
api_base: str,
|
| 85 |
+
api_key: str,
|
| 86 |
+
model: str,
|
| 87 |
+
messages: List[Dict[str, str]],
|
| 88 |
+
*,
|
| 89 |
+
timeout: int,
|
| 90 |
+
max_tokens: int,
|
| 91 |
+
temperature: float,
|
| 92 |
+
cache_ttl: int,
|
| 93 |
+
cache_namespace: Optional[str],
|
| 94 |
+
rate_limit_delay: Optional[float] = None,
|
| 95 |
+
) -> str:
|
| 96 |
+
url = api_base.rstrip("/") + "/chat/completions"
|
| 97 |
+
payload: Dict[str, Any] = {
|
| 98 |
+
"model": model,
|
| 99 |
+
"messages": messages,
|
| 100 |
+
"max_tokens": max_tokens,
|
| 101 |
+
"temperature": temperature,
|
| 102 |
+
}
|
| 103 |
+
if cache_ttl > 0:
|
| 104 |
+
cache_obj: Dict[str, Any] = {"ttl": cache_ttl}
|
| 105 |
+
if cache_namespace:
|
| 106 |
+
cache_obj["namespace"] = cache_namespace
|
| 107 |
+
payload["cache"] = cache_obj
|
| 108 |
+
|
| 109 |
+
data = json.dumps(payload).encode("utf-8")
|
| 110 |
+
headers = {"Content-Type": "application/json"}
|
| 111 |
+
if api_key:
|
| 112 |
+
headers["Authorization"] = f"Bearer {api_key}"
|
| 113 |
+
|
| 114 |
+
req = urllib.request.Request(url, data=data, headers=headers, method="POST")
|
| 115 |
+
opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
|
| 116 |
+
try:
|
| 117 |
+
with opener.open(req, timeout=timeout) as resp:
|
| 118 |
+
resp_bytes = resp.read()
|
| 119 |
+
except urllib.error.HTTPError as e:
|
| 120 |
+
detail = e.read().decode("utf-8", errors="ignore") if hasattr(e, "read") else ""
|
| 121 |
+
if e.code == 429:
|
| 122 |
+
retry_after = None
|
| 123 |
+
if hasattr(e, "headers") and e.headers:
|
| 124 |
+
retry_after_header = e.headers.get("Retry-After")
|
| 125 |
+
if retry_after_header:
|
| 126 |
+
try:
|
| 127 |
+
retry_after = float(retry_after_header)
|
| 128 |
+
except ValueError:
|
| 129 |
+
retry_after = None
|
| 130 |
+
wait = retry_after or rate_limit_delay or 5.0
|
| 131 |
+
raise RateLimitError(wait, f"API HTTP 429: {detail}") from e
|
| 132 |
+
raise RuntimeError(f"API HTTP error {e.code}: {detail}") from e
|
| 133 |
+
except urllib.error.URLError as e:
|
| 134 |
+
raise RuntimeError(f"API request failed: {e}") from e
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
response = json.loads(resp_bytes.decode("utf-8"))
|
| 138 |
+
except json.JSONDecodeError as e:
|
| 139 |
+
raise RuntimeError(f"Failed to decode API response: {resp_bytes!r}") from e
|
| 140 |
+
|
| 141 |
+
choices = response.get("choices", [])
|
| 142 |
+
if not choices:
|
| 143 |
+
raise RuntimeError(f"Empty choices from API: {response}")
|
| 144 |
+
content = choices[0].get("message", {}).get("content", "")
|
| 145 |
+
if not content:
|
| 146 |
+
raise RuntimeError(f"Empty content from API: {response}")
|
| 147 |
+
return content.strip()
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _call_with_retries(
|
| 151 |
+
*,
|
| 152 |
+
api_base: str,
|
| 153 |
+
api_key: str,
|
| 154 |
+
model: str,
|
| 155 |
+
messages: List[Dict[str, str]],
|
| 156 |
+
timeout: int,
|
| 157 |
+
max_tokens: int,
|
| 158 |
+
temperature: float,
|
| 159 |
+
cache_ttl: int,
|
| 160 |
+
cache_namespace: Optional[str],
|
| 161 |
+
rate_limit_delay: float,
|
| 162 |
+
retries: int,
|
| 163 |
+
retry_delay: float,
|
| 164 |
+
) -> str:
|
| 165 |
+
for attempt in range(retries + 1):
|
| 166 |
+
try:
|
| 167 |
+
return call_chat_api(
|
| 168 |
+
api_base,
|
| 169 |
+
api_key,
|
| 170 |
+
model,
|
| 171 |
+
messages,
|
| 172 |
+
timeout=timeout,
|
| 173 |
+
max_tokens=max_tokens,
|
| 174 |
+
temperature=temperature,
|
| 175 |
+
cache_ttl=cache_ttl,
|
| 176 |
+
cache_namespace=cache_namespace,
|
| 177 |
+
rate_limit_delay=rate_limit_delay,
|
| 178 |
+
)
|
| 179 |
+
except RateLimitError as e:
|
| 180 |
+
if attempt >= retries:
|
| 181 |
+
raise
|
| 182 |
+
time.sleep(e.wait_seconds)
|
| 183 |
+
except Exception: # noqa: BLE001
|
| 184 |
+
if attempt >= retries:
|
| 185 |
+
raise
|
| 186 |
+
time.sleep(retry_delay)
|
| 187 |
+
raise RuntimeError("Unreachable")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def build_gen_messages(prompt: str, system_prompt: str) -> List[Dict[str, str]]:
|
| 191 |
+
return [
|
| 192 |
+
{"role": "system", "content": system_prompt},
|
| 193 |
+
{"role": "user", "content": prompt},
|
| 194 |
+
]
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def build_judge_messages(reference_answer: str, candidate_answer: str) -> List[Dict[str, str]]:
|
| 198 |
+
user = (
|
| 199 |
+
"Decide if the model's boxed answer matches the reference answer.\n"
|
| 200 |
+
f"Reference answer: {reference_answer}\n"
|
| 201 |
+
f"Model boxed answer (only the content inside \\box{{}}): {candidate_answer}\n"
|
| 202 |
+
"Output only True if they are semantically consistent; otherwise output False."
|
| 203 |
+
)
|
| 204 |
+
return [
|
| 205 |
+
{"role": "system", "content": JUDGE_SYSTEM_PROMPT},
|
| 206 |
+
{"role": "user", "content": user},
|
| 207 |
+
]
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def parse_bool(text: str) -> bool:
|
| 211 |
+
first = text.strip().splitlines()[0].strip().lower()
|
| 212 |
+
if first in {"true", "yes"}:
|
| 213 |
+
return True
|
| 214 |
+
if first in {"false", "no"}:
|
| 215 |
+
return False
|
| 216 |
+
if "true" in first and "false" not in first:
|
| 217 |
+
return True
|
| 218 |
+
if "false" in first:
|
| 219 |
+
return False
|
| 220 |
+
raise ValueError(f"Cannot parse boolean from: {text!r}")
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def write_cache(out_path: Path, examples: Iterable[CachedExample]) -> int:
|
| 224 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 225 |
+
count = 0
|
| 226 |
+
with out_path.open("w", encoding="utf-8") as f:
|
| 227 |
+
for ex in examples:
|
| 228 |
+
obj: Dict[str, Any] = {
|
| 229 |
+
"prompt": ex.prompt,
|
| 230 |
+
"target": ex.target,
|
| 231 |
+
"indices_to_explain": ex.indices_to_explain,
|
| 232 |
+
"attr_mask_indices": ex.attr_mask_indices,
|
| 233 |
+
"sink_span": ex.sink_span,
|
| 234 |
+
"thinking_span": ex.thinking_span,
|
| 235 |
+
"metadata": ex.metadata,
|
| 236 |
+
}
|
| 237 |
+
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
| 238 |
+
count += 1
|
| 239 |
+
return count
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@dataclass(frozen=True)
|
| 243 |
+
class AcceptedGeneration:
|
| 244 |
+
thinking_text: str
|
| 245 |
+
boxed_answer: str
|
| 246 |
+
target_text: str
|
| 247 |
+
thinking_tokens: int
|
| 248 |
+
generation_text: str
|
| 249 |
+
judge_response: str
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def _infer_reference_answer(example: CachedExample) -> str:
|
| 253 |
+
meta = example.metadata or {}
|
| 254 |
+
ref = str(meta.get("reference_answer") or "").strip()
|
| 255 |
+
if ref:
|
| 256 |
+
return ref
|
| 257 |
+
outputs = meta.get("outputs") or []
|
| 258 |
+
if isinstance(outputs, list) and outputs:
|
| 259 |
+
return ", ".join(str(x) for x in outputs)
|
| 260 |
+
tgt = str(example.target or "").strip()
|
| 261 |
+
return tgt
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def _infer_dataset_tag(dataset_path: Path) -> str:
|
| 265 |
+
if dataset_path.name.endswith(".jsonl") and dataset_path.name != "validation.jsonl":
|
| 266 |
+
return dataset_path.stem
|
| 267 |
+
if dataset_path.name == "validation.jsonl":
|
| 268 |
+
return dataset_path.parent.name
|
| 269 |
+
return dataset_path.stem
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def _count_tokens(tokenizer, text: str) -> int:
|
| 273 |
+
return int(len(tokenizer(text, add_special_tokens=False).input_ids))
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def _generate_one_style(
|
| 277 |
+
*,
|
| 278 |
+
prompt: str,
|
| 279 |
+
reference_answer: str,
|
| 280 |
+
tokenizer,
|
| 281 |
+
style: str,
|
| 282 |
+
system_prompt: str,
|
| 283 |
+
api_base: str,
|
| 284 |
+
api_key: str,
|
| 285 |
+
generator_model: str,
|
| 286 |
+
judge_model: str,
|
| 287 |
+
timeout: int,
|
| 288 |
+
max_tokens: int,
|
| 289 |
+
temperature: float,
|
| 290 |
+
cache_ttl: int,
|
| 291 |
+
cache_namespace: Optional[str],
|
| 292 |
+
rate_limit_delay: float,
|
| 293 |
+
retries: int,
|
| 294 |
+
retry_delay: float,
|
| 295 |
+
request_interval: float,
|
| 296 |
+
judge_interval: float,
|
| 297 |
+
min_long_thinking_tokens: int,
|
| 298 |
+
max_short_thinking_tokens: int,
|
| 299 |
+
) -> Optional[AcceptedGeneration]:
|
| 300 |
+
gen_messages = build_gen_messages(prompt, system_prompt)
|
| 301 |
+
generation = _call_with_retries(
|
| 302 |
+
api_base=api_base,
|
| 303 |
+
api_key=api_key,
|
| 304 |
+
model=generator_model,
|
| 305 |
+
messages=gen_messages,
|
| 306 |
+
timeout=timeout,
|
| 307 |
+
max_tokens=max_tokens,
|
| 308 |
+
temperature=temperature,
|
| 309 |
+
cache_ttl=cache_ttl,
|
| 310 |
+
cache_namespace=cache_namespace,
|
| 311 |
+
rate_limit_delay=rate_limit_delay,
|
| 312 |
+
retries=retries,
|
| 313 |
+
retry_delay=retry_delay,
|
| 314 |
+
)
|
| 315 |
+
if request_interval > 0:
|
| 316 |
+
time.sleep(request_interval)
|
| 317 |
+
|
| 318 |
+
parsed = split_boxed_generation(generation)
|
| 319 |
+
if not parsed:
|
| 320 |
+
return None
|
| 321 |
+
thinking_text, _boxed_segment, boxed_answer = parsed
|
| 322 |
+
thinking_tokens = _count_tokens(tokenizer, thinking_text)
|
| 323 |
+
|
| 324 |
+
if style == "short":
|
| 325 |
+
if max_short_thinking_tokens > 0 and thinking_tokens > max_short_thinking_tokens:
|
| 326 |
+
return None
|
| 327 |
+
elif style == "long":
|
| 328 |
+
if min_long_thinking_tokens > 0 and thinking_tokens < min_long_thinking_tokens:
|
| 329 |
+
return None
|
| 330 |
+
else:
|
| 331 |
+
raise ValueError(f"Unsupported style: {style}")
|
| 332 |
+
|
| 333 |
+
judge_messages = build_judge_messages(reference_answer, boxed_answer)
|
| 334 |
+
judge_resp = _call_with_retries(
|
| 335 |
+
api_base=api_base,
|
| 336 |
+
api_key=api_key,
|
| 337 |
+
model=judge_model,
|
| 338 |
+
messages=judge_messages,
|
| 339 |
+
timeout=timeout,
|
| 340 |
+
max_tokens=64,
|
| 341 |
+
temperature=0.0,
|
| 342 |
+
cache_ttl=cache_ttl,
|
| 343 |
+
cache_namespace=cache_namespace,
|
| 344 |
+
rate_limit_delay=rate_limit_delay,
|
| 345 |
+
retries=retries,
|
| 346 |
+
retry_delay=retry_delay,
|
| 347 |
+
)
|
| 348 |
+
if judge_interval > 0:
|
| 349 |
+
time.sleep(judge_interval)
|
| 350 |
+
ok = parse_bool(judge_resp)
|
| 351 |
+
if not ok:
|
| 352 |
+
return None
|
| 353 |
+
|
| 354 |
+
target_text = f"{thinking_text}\n{boxed_answer}" if thinking_text else boxed_answer
|
| 355 |
+
return AcceptedGeneration(
|
| 356 |
+
thinking_text=thinking_text,
|
| 357 |
+
boxed_answer=boxed_answer,
|
| 358 |
+
target_text=target_text,
|
| 359 |
+
thinking_tokens=thinking_tokens,
|
| 360 |
+
generation_text=generation,
|
| 361 |
+
judge_response=judge_resp,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def main() -> None:
|
| 366 |
+
parser = argparse.ArgumentParser("Sample short-CoT and long-CoT cases for exp3 (independently).")
|
| 367 |
+
parser.add_argument(
|
| 368 |
+
"--dataset_path",
|
| 369 |
+
type=str,
|
| 370 |
+
default="data/ruler_multihop/1024/niah_mq_q2/validation.jsonl",
|
| 371 |
+
help="Raw RULER JSONL path (default: niah_mq_q2 1024 validation).",
|
| 372 |
+
)
|
| 373 |
+
parser.add_argument("--dataset_tag", type=str, default=None, help="Output tag; default inferred from dataset_path.")
|
| 374 |
+
parser.add_argument(
|
| 375 |
+
"--max_pairs",
|
| 376 |
+
type=int,
|
| 377 |
+
default=1,
|
| 378 |
+
help="Deprecated alias for --max_short and --max_long (kept for convenience).",
|
| 379 |
+
)
|
| 380 |
+
parser.add_argument("--max_short", type=int, default=None, help="How many short-CoT samples to keep (default: --max_pairs).")
|
| 381 |
+
parser.add_argument("--max_long", type=int, default=None, help="How many long-CoT samples to keep (default: --max_pairs).")
|
| 382 |
+
parser.add_argument("--max_raw_examples", type=int, default=None, help="Optional cap on raw examples to try.")
|
| 383 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 384 |
+
parser.add_argument("--api_base", type=str, default="http://localhost:4000/v1", help="Chat API base URL.")
|
| 385 |
+
parser.add_argument("--api_key", type=str, default=None, help="API key; defaults to FLASHTRACE_API_KEY/OPENAI_API_KEY.")
|
| 386 |
+
parser.add_argument("--generator_model", type=str, default="qwen3-235b-a22b-2507")
|
| 387 |
+
parser.add_argument("--judge_model", type=str, default="deepseek-v3-1-terminus")
|
| 388 |
+
parser.add_argument("--api_timeout", type=int, default=300)
|
| 389 |
+
parser.add_argument("--api_temperature", type=float, default=0.0)
|
| 390 |
+
parser.add_argument("--api_cache_ttl", type=int, default=600)
|
| 391 |
+
parser.add_argument("--api_cache_namespace", type=str, default="flashtrace-exp3")
|
| 392 |
+
parser.add_argument("--retry_delay", type=float, default=2.0)
|
| 393 |
+
parser.add_argument("--retries", type=int, default=2, help="Additional retries on API failure.")
|
| 394 |
+
parser.add_argument("--request_interval", type=float, default=1.0, help="Sleep seconds between generation calls.")
|
| 395 |
+
parser.add_argument("--judge_interval", type=float, default=1.0, help="Sleep seconds between judge calls.")
|
| 396 |
+
parser.add_argument("--rate_limit_delay", type=float, default=5.0, help="Seconds to wait on HTTP 429 before retrying.")
|
| 397 |
+
parser.add_argument(
|
| 398 |
+
"--api_max_tokens_short",
|
| 399 |
+
type=int,
|
| 400 |
+
default=2048,
|
| 401 |
+
help="Max tokens for the short-CoT generation call.",
|
| 402 |
+
)
|
| 403 |
+
parser.add_argument(
|
| 404 |
+
"--api_max_tokens_long",
|
| 405 |
+
type=int,
|
| 406 |
+
default=8192,
|
| 407 |
+
help="Max tokens for the long-CoT generation call.",
|
| 408 |
+
)
|
| 409 |
+
parser.add_argument(
|
| 410 |
+
"--min_long_thinking_tokens",
|
| 411 |
+
type=int,
|
| 412 |
+
default=512,
|
| 413 |
+
help="Minimum tokenizer tokens required in the long-CoT thinking segment.",
|
| 414 |
+
)
|
| 415 |
+
parser.add_argument(
|
| 416 |
+
"--max_short_thinking_tokens",
|
| 417 |
+
type=int,
|
| 418 |
+
default=256,
|
| 419 |
+
help="Maximum tokenizer tokens allowed in the short-CoT thinking segment.",
|
| 420 |
+
)
|
| 421 |
+
parser.add_argument(
|
| 422 |
+
"--tokenizer_model",
|
| 423 |
+
type=str,
|
| 424 |
+
default=None,
|
| 425 |
+
help="Tokenizer path for span extraction & length constraints (default: generator_model).",
|
| 426 |
+
)
|
| 427 |
+
parser.add_argument("--data_root", type=str, default="exp/exp3/data", help="Output directory for exp3 caches.")
|
| 428 |
+
parser.add_argument("--out_short", type=str, default=None, help="Optional explicit output path (short JSONL).")
|
| 429 |
+
parser.add_argument("--out_long", type=str, default=None, help="Optional explicit output path (long JSONL).")
|
| 430 |
+
args = parser.parse_args()
|
| 431 |
+
|
| 432 |
+
api_key = args.api_key or os.environ.get("FLASHTRACE_API_KEY") or os.environ.get("OPENAI_API_KEY")
|
| 433 |
+
if not api_key:
|
| 434 |
+
raise SystemExit("Set --api_key or FLASHTRACE_API_KEY/OPENAI_API_KEY for API access.")
|
| 435 |
+
|
| 436 |
+
dataset_path = Path(args.dataset_path)
|
| 437 |
+
if not dataset_path.exists():
|
| 438 |
+
raise SystemExit(f"Dataset file not found: {dataset_path}")
|
| 439 |
+
dataset_tag = str(args.dataset_tag or _infer_dataset_tag(dataset_path))
|
| 440 |
+
|
| 441 |
+
tok_name = args.tokenizer_model or args.generator_model
|
| 442 |
+
tok_path = Path(tok_name)
|
| 443 |
+
if tok_path.exists():
|
| 444 |
+
tokenizer = AutoTokenizer.from_pretrained(tok_path.as_posix(), local_files_only=True)
|
| 445 |
+
else:
|
| 446 |
+
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
| 447 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 448 |
+
|
| 449 |
+
raw_examples = ds_utils.load_ruler(dataset_path, sample=None, seed=args.seed)
|
| 450 |
+
if not raw_examples:
|
| 451 |
+
raise SystemExit("No examples loaded from the RULER JSONL.")
|
| 452 |
+
|
| 453 |
+
max_short = int(args.max_short) if args.max_short is not None else int(args.max_pairs)
|
| 454 |
+
max_long = int(args.max_long) if args.max_long is not None else int(args.max_pairs)
|
| 455 |
+
if max_short < 0 or max_long < 0:
|
| 456 |
+
raise SystemExit("--max_short/--max_long must be >= 0.")
|
| 457 |
+
|
| 458 |
+
kept_short: List[CachedExample] = []
|
| 459 |
+
kept_long: List[CachedExample] = []
|
| 460 |
+
|
| 461 |
+
total = len(raw_examples)
|
| 462 |
+
attempted = 0
|
| 463 |
+
|
| 464 |
+
for idx, ex in enumerate(tqdm(raw_examples, total=total, desc="Scanning raw RULER"), 1):
|
| 465 |
+
attempted = idx
|
| 466 |
+
if args.max_raw_examples is not None and idx > int(args.max_raw_examples):
|
| 467 |
+
break
|
| 468 |
+
if len(kept_short) >= max_short and len(kept_long) >= max_long:
|
| 469 |
+
break
|
| 470 |
+
|
| 471 |
+
reference_answer = _infer_reference_answer(ex)
|
| 472 |
+
prompt = ex.prompt
|
| 473 |
+
|
| 474 |
+
sample_id = _sha1_text(prompt)
|
| 475 |
+
base_meta = dict(ex.metadata or {})
|
| 476 |
+
base_meta["reference_answer"] = reference_answer
|
| 477 |
+
base_meta["sample_id"] = sample_id
|
| 478 |
+
base_meta["pair_id"] = sample_id # backward-compatible name (may not be paired)
|
| 479 |
+
base_meta["source_dataset_path"] = str(dataset_path)
|
| 480 |
+
base_meta["prompt_sha1"] = sample_id
|
| 481 |
+
|
| 482 |
+
if len(kept_short) < max_short:
|
| 483 |
+
short_gen = _generate_one_style(
|
| 484 |
+
prompt=prompt,
|
| 485 |
+
reference_answer=reference_answer,
|
| 486 |
+
tokenizer=tokenizer,
|
| 487 |
+
style="short",
|
| 488 |
+
system_prompt=SHORT_COT_SYSTEM_PROMPT,
|
| 489 |
+
api_base=args.api_base,
|
| 490 |
+
api_key=api_key,
|
| 491 |
+
generator_model=args.generator_model,
|
| 492 |
+
judge_model=args.judge_model,
|
| 493 |
+
timeout=args.api_timeout,
|
| 494 |
+
max_tokens=args.api_max_tokens_short,
|
| 495 |
+
temperature=args.api_temperature,
|
| 496 |
+
cache_ttl=args.api_cache_ttl,
|
| 497 |
+
cache_namespace=args.api_cache_namespace,
|
| 498 |
+
rate_limit_delay=args.rate_limit_delay,
|
| 499 |
+
retries=args.retries,
|
| 500 |
+
retry_delay=args.retry_delay,
|
| 501 |
+
request_interval=args.request_interval,
|
| 502 |
+
judge_interval=args.judge_interval,
|
| 503 |
+
min_long_thinking_tokens=args.min_long_thinking_tokens,
|
| 504 |
+
max_short_thinking_tokens=args.max_short_thinking_tokens,
|
| 505 |
+
)
|
| 506 |
+
if short_gen is not None:
|
| 507 |
+
short_meta = dict(base_meta)
|
| 508 |
+
short_meta.update(
|
| 509 |
+
{
|
| 510 |
+
"cot_style": "short",
|
| 511 |
+
"generator_model": args.generator_model,
|
| 512 |
+
"judge_model": args.judge_model,
|
| 513 |
+
"judge_response": short_gen.judge_response,
|
| 514 |
+
"boxed_answer": short_gen.boxed_answer,
|
| 515 |
+
"thinking_tokens": int(short_gen.thinking_tokens),
|
| 516 |
+
}
|
| 517 |
+
)
|
| 518 |
+
short_ex = CachedExample(
|
| 519 |
+
prompt=prompt,
|
| 520 |
+
target=short_gen.target_text,
|
| 521 |
+
indices_to_explain=None,
|
| 522 |
+
attr_mask_indices=ex.attr_mask_indices,
|
| 523 |
+
sink_span=None,
|
| 524 |
+
thinking_span=None,
|
| 525 |
+
metadata=short_meta,
|
| 526 |
+
)
|
| 527 |
+
short_ex = attach_spans_from_answer(short_ex, tokenizer, short_gen.boxed_answer)
|
| 528 |
+
if isinstance(short_ex.sink_span, list) and len(short_ex.sink_span) == 2:
|
| 529 |
+
short_ex = CachedExample(
|
| 530 |
+
prompt=short_ex.prompt,
|
| 531 |
+
target=short_ex.target,
|
| 532 |
+
indices_to_explain=short_ex.sink_span,
|
| 533 |
+
attr_mask_indices=short_ex.attr_mask_indices,
|
| 534 |
+
sink_span=short_ex.sink_span,
|
| 535 |
+
thinking_span=short_ex.thinking_span,
|
| 536 |
+
metadata=short_ex.metadata,
|
| 537 |
+
)
|
| 538 |
+
kept_short.append(short_ex)
|
| 539 |
+
print(
|
| 540 |
+
f"[kept short] raw_idx={idx}/{total} thinking_tokens={short_gen.thinking_tokens} "
|
| 541 |
+
f"sample_id={sample_id[:8]} kept={len(kept_short)}/{max_short}"
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
if len(kept_long) < max_long:
|
| 545 |
+
long_gen = _generate_one_style(
|
| 546 |
+
prompt=prompt,
|
| 547 |
+
reference_answer=reference_answer,
|
| 548 |
+
tokenizer=tokenizer,
|
| 549 |
+
style="long",
|
| 550 |
+
system_prompt=LONG_COT_SYSTEM_PROMPT,
|
| 551 |
+
api_base=args.api_base,
|
| 552 |
+
api_key=api_key,
|
| 553 |
+
generator_model=args.generator_model,
|
| 554 |
+
judge_model=args.judge_model,
|
| 555 |
+
timeout=args.api_timeout,
|
| 556 |
+
max_tokens=args.api_max_tokens_long,
|
| 557 |
+
temperature=args.api_temperature,
|
| 558 |
+
cache_ttl=args.api_cache_ttl,
|
| 559 |
+
cache_namespace=args.api_cache_namespace,
|
| 560 |
+
rate_limit_delay=args.rate_limit_delay,
|
| 561 |
+
retries=args.retries,
|
| 562 |
+
retry_delay=args.retry_delay,
|
| 563 |
+
request_interval=args.request_interval,
|
| 564 |
+
judge_interval=args.judge_interval,
|
| 565 |
+
min_long_thinking_tokens=args.min_long_thinking_tokens,
|
| 566 |
+
max_short_thinking_tokens=args.max_short_thinking_tokens,
|
| 567 |
+
)
|
| 568 |
+
if long_gen is not None:
|
| 569 |
+
long_meta = dict(base_meta)
|
| 570 |
+
long_meta.update(
|
| 571 |
+
{
|
| 572 |
+
"cot_style": "long",
|
| 573 |
+
"generator_model": args.generator_model,
|
| 574 |
+
"judge_model": args.judge_model,
|
| 575 |
+
"judge_response": long_gen.judge_response,
|
| 576 |
+
"boxed_answer": long_gen.boxed_answer,
|
| 577 |
+
"thinking_tokens": int(long_gen.thinking_tokens),
|
| 578 |
+
}
|
| 579 |
+
)
|
| 580 |
+
long_ex = CachedExample(
|
| 581 |
+
prompt=prompt,
|
| 582 |
+
target=long_gen.target_text,
|
| 583 |
+
indices_to_explain=None,
|
| 584 |
+
attr_mask_indices=ex.attr_mask_indices,
|
| 585 |
+
sink_span=None,
|
| 586 |
+
thinking_span=None,
|
| 587 |
+
metadata=long_meta,
|
| 588 |
+
)
|
| 589 |
+
long_ex = attach_spans_from_answer(long_ex, tokenizer, long_gen.boxed_answer)
|
| 590 |
+
if isinstance(long_ex.sink_span, list) and len(long_ex.sink_span) == 2:
|
| 591 |
+
long_ex = CachedExample(
|
| 592 |
+
prompt=long_ex.prompt,
|
| 593 |
+
target=long_ex.target,
|
| 594 |
+
indices_to_explain=long_ex.sink_span,
|
| 595 |
+
attr_mask_indices=long_ex.attr_mask_indices,
|
| 596 |
+
sink_span=long_ex.sink_span,
|
| 597 |
+
thinking_span=long_ex.thinking_span,
|
| 598 |
+
metadata=long_ex.metadata,
|
| 599 |
+
)
|
| 600 |
+
kept_long.append(long_ex)
|
| 601 |
+
print(
|
| 602 |
+
f"[kept long] raw_idx={idx}/{total} thinking_tokens={long_gen.thinking_tokens} "
|
| 603 |
+
f"sample_id={sample_id[:8]} kept={len(kept_long)}/{max_long}"
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
data_root = Path(args.data_root)
|
| 607 |
+
out_short = Path(args.out_short) if args.out_short else data_root / f"{dataset_tag}_short_cot.jsonl"
|
| 608 |
+
out_long = Path(args.out_long) if args.out_long else data_root / f"{dataset_tag}_long_cot.jsonl"
|
| 609 |
+
|
| 610 |
+
n_short = write_cache(out_short, kept_short)
|
| 611 |
+
n_long = write_cache(out_long, kept_long)
|
| 612 |
+
print(
|
| 613 |
+
f"Wrote short={n_short} -> {out_short}\n"
|
| 614 |
+
f"Wrote long ={n_long} -> {out_long}\n"
|
| 615 |
+
f"Attempted {attempted} / {total}"
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
missing: List[str] = []
|
| 619 |
+
if len(kept_short) < max_short:
|
| 620 |
+
missing.append(f"short({len(kept_short)}/{max_short})")
|
| 621 |
+
if len(kept_long) < max_long:
|
| 622 |
+
missing.append(f"long({len(kept_long)}/{max_long})")
|
| 623 |
+
if missing:
|
| 624 |
+
raise SystemExit(f"Could not find enough samples: {', '.join(missing)} (attempted {attempted} / {total}).")
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
if __name__ == "__main__":
|
| 628 |
+
main()
|
exp/exp4/README.md
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FlashTrace 实验 4(Aider 归因忠实度 / row-only)
|
| 2 |
+
|
| 3 |
+
本目录提供 Aider 数据集上的 token-level 归因忠实度评测工具,**只输出 row 部分的 RISE/MAS**,不保存样本级 trace。
|
| 4 |
+
|
| 5 |
+
评测范围(固定):
|
| 6 |
+
- 数据集:`exp/exp4/data/aider.jsonl`
|
| 7 |
+
- 方法:
|
| 8 |
+
- `ifr_all_positions`
|
| 9 |
+
- `ifr_multi_hop_both`(FlashTrace)
|
| 10 |
+
- 指标:`RISE`、`MAS`(row attribution only)
|
| 11 |
+
|
| 12 |
+
主要文件:
|
| 13 |
+
- `run_exp.py`:归因 + 忠实度评测,输出到 `exp/exp4/output/`
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## 数据格式
|
| 18 |
+
|
| 19 |
+
`exp/exp4/data/aider.jsonl` 每行一个 JSON,对应一个样本:
|
| 20 |
+
- `input`:prompt(直接作为 user prompt 内容)
|
| 21 |
+
- `output`:target(直接作为模型生成文本;脚本会内部追加 EOS 做打分)
|
| 22 |
+
- `length`:数据自带字段(脚本不依赖,仅透传到 metadata)
|
| 23 |
+
|
| 24 |
+
说明:Aider 的 `output` 形如:
|
| 25 |
+
1) 第一行 `xxx.py`
|
| 26 |
+
2) 第二行 opening fence ```
|
| 27 |
+
3) 中间为代码
|
| 28 |
+
4) 最后一行为 closing fence ```
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
## 归因与 sink 选择
|
| 33 |
+
|
| 34 |
+
脚本对每个样本都将 `input` 作为 `prompt`,将 `output` 作为 `target`(不做重新生成),并在归因结果上选择不同的 sink(`indices_to_explain=[start_tok,end_tok]`,均基于 `tokenizer(target, add_special_tokens=False)` 的 token span;不含 EOS)。
|
| 35 |
+
|
| 36 |
+
### `ifr_all_positions`(输出两个 sink)
|
| 37 |
+
|
| 38 |
+
- `last_line`:取 `output` 中 **closing fence 之前最后一个“非空且非 ```”行**,并将该行的字符 span 映射到 token span;若无法解析则回退为 `full_output`。
|
| 39 |
+
- `last_token`:取 `last_line` 的最后一个 token(单点 span `[end,end]`)。
|
| 40 |
+
|
| 41 |
+
注意:脚本会对同一个样本只计算一次 `ifr_all_positions` 的归因矩阵,然后分别在两个 sink 上取 row attribution 并计算忠实度。
|
| 42 |
+
|
| 43 |
+
### `ifr_multi_hop_both`(FlashTrace,只输出一个 sink)
|
| 44 |
+
|
| 45 |
+
- `full_output`:用完整 `output` 作为 sink(token span `[0, n_tok-1]`)。
|
| 46 |
+
- 忠实度扰动侧会沿用 exp2 的协议:对 prompt-side 会跳过 stop tokens(由 `ft_ifr_improve.py` 的 stop-token 配置决定)。
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## 指标输出(row-only)
|
| 51 |
+
|
| 52 |
+
输出 CSV 仅包含 row attribution 的 `RISE/MAS` 聚合统计:
|
| 53 |
+
- `Method,Sink,Row_RISE_Mean,Row_RISE_Std,Row_MAS_Mean,Row_MAS_Std,Used,Skipped,Avg_Sample_Time_s`
|
| 54 |
+
|
| 55 |
+
输出路径:
|
| 56 |
+
- `exp/exp4/output/faithfulness/aider/<model_tag>/row_only_<N>_examples.csv`
|
| 57 |
+
|
| 58 |
+
其中 `<model_tag>` 优先取 `--model`,否则取 `--model_path` 的目录名。
|
| 59 |
+
|
| 60 |
+
---
|
| 61 |
+
|
| 62 |
+
## 使用说明
|
| 63 |
+
|
| 64 |
+
推荐从 repo root 运行(保证相对路径可用):
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
python exp/exp4/run_exp.py \
|
| 68 |
+
--data_path exp/exp4/data/aider.jsonl \
|
| 69 |
+
--output_root exp/exp4/output \
|
| 70 |
+
--model qwen-8B \
|
| 71 |
+
--model_path /opt/share/models/Qwen/Qwen3-8B/ \
|
| 72 |
+
--cuda 2,3,4,5,6,7 \
|
| 73 |
+
--num_examples 100 \
|
| 74 |
+
--n_hops 1 \
|
| 75 |
+
--k 20
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
常用参数:
|
| 79 |
+
- `--model_path` / `--model`:本地模型路径或 HF repo id(至少提供其一)
|
| 80 |
+
- `--tokenizer_path`:可选;不提供则默认复用模型路径/id
|
| 81 |
+
- `--cuda`:支持 `0`(单卡)或 `0,1,2`(多卡,内部会设置 `CUDA_VISIBLE_DEVICES` 并用 `device_map=auto`)
|
| 82 |
+
- `--num_examples`:评测前 N 条(按文件顺序;`--seed` 预留,当前不做随机抽样)
|
| 83 |
+
- `--n_hops`:FlashTrace(`ifr_multi_hop_both`)的 hop 数
|
| 84 |
+
- `--k`:MAS/RISE 的扰动步数
|
| 85 |
+
- `--chunk_tokens` / `--sink_chunk_tokens`:IFR 计算的 chunk 参数(一般保持默认)
|
exp/exp4/run_exp.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Experiment 4 runner: Aider token-level attribution faithfulness.
|
| 4 |
+
|
| 5 |
+
Evaluates only:
|
| 6 |
+
- IFR: ifr_all_positions
|
| 7 |
+
- sink = last meaningful code line (excluding fences)
|
| 8 |
+
- sink = last token of that code line
|
| 9 |
+
- FlashTrace: ifr_multi_hop_both
|
| 10 |
+
- sink = full output (excluding appended EOS)
|
| 11 |
+
|
| 12 |
+
Outputs only row-level faithfulness scores (RISE, MAS). No sample-level traces.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import time
|
| 22 |
+
from dataclasses import dataclass
|
| 23 |
+
from itertools import islice
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _early_set_cuda_visible_devices() -> None:
|
| 29 |
+
parser = argparse.ArgumentParser(add_help=False)
|
| 30 |
+
parser.add_argument("--cuda", type=str, default=None)
|
| 31 |
+
args, _ = parser.parse_known_args(sys.argv[1:])
|
| 32 |
+
if args.cuda and "," in args.cuda:
|
| 33 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
_early_set_cuda_visible_devices()
|
| 37 |
+
|
| 38 |
+
import numpy as np
|
| 39 |
+
import torch
|
| 40 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, utils
|
| 41 |
+
|
| 42 |
+
# Ensure repo root on path for `import llm_attr`, `import ft_ifr_improve`, etc.
|
| 43 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 44 |
+
if str(REPO_ROOT) not in sys.path:
|
| 45 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 46 |
+
|
| 47 |
+
import ft_ifr_improve
|
| 48 |
+
import llm_attr
|
| 49 |
+
import llm_attr_eval
|
| 50 |
+
|
| 51 |
+
utils.logging.set_verbosity_error()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass(frozen=True)
|
| 55 |
+
class AiderExample:
|
| 56 |
+
prompt: str
|
| 57 |
+
target: str
|
| 58 |
+
metadata: Dict[str, Any]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _read_jsonl(path: Path) -> List[Dict[str, Any]]:
|
| 62 |
+
rows: List[Dict[str, Any]] = []
|
| 63 |
+
with path.open("r", encoding="utf-8") as f:
|
| 64 |
+
for line in f:
|
| 65 |
+
if not line.strip():
|
| 66 |
+
continue
|
| 67 |
+
rows.append(json.loads(line))
|
| 68 |
+
return rows
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def load_aider(path: Path) -> List[AiderExample]:
|
| 72 |
+
rows = _read_jsonl(path)
|
| 73 |
+
examples: List[AiderExample] = []
|
| 74 |
+
for row in rows:
|
| 75 |
+
prompt = str(row.get("input") or "")
|
| 76 |
+
target = str(row.get("output") or "")
|
| 77 |
+
examples.append(AiderExample(prompt=prompt, target=target, metadata={"length": row.get("length")}))
|
| 78 |
+
return examples
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _token_span_full_output(tokenizer, target: str) -> List[int]:
|
| 82 |
+
ids = tokenizer(target, add_special_tokens=False).input_ids
|
| 83 |
+
if not ids:
|
| 84 |
+
return [0, 0]
|
| 85 |
+
return [0, int(len(ids) - 1)]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _last_meaningful_code_line_char_span(target: str) -> Optional[Tuple[int, int]]:
|
| 89 |
+
lines = target.splitlines(keepends=True)
|
| 90 |
+
pos = 0
|
| 91 |
+
spans: List[Tuple[int, int, str]] = []
|
| 92 |
+
for line in lines:
|
| 93 |
+
start = pos
|
| 94 |
+
pos += len(line)
|
| 95 |
+
spans.append((start, pos, line))
|
| 96 |
+
|
| 97 |
+
for start, end, line in reversed(spans):
|
| 98 |
+
stripped = line.strip()
|
| 99 |
+
if not stripped:
|
| 100 |
+
continue
|
| 101 |
+
if stripped.startswith("```"):
|
| 102 |
+
continue
|
| 103 |
+
if start == 0 and stripped.endswith(".py"):
|
| 104 |
+
return None
|
| 105 |
+
|
| 106 |
+
line_no_nl = line.rstrip("\r\n")
|
| 107 |
+
end_no_nl = start + len(line_no_nl)
|
| 108 |
+
if end_no_nl <= start:
|
| 109 |
+
continue
|
| 110 |
+
return start, end_no_nl
|
| 111 |
+
|
| 112 |
+
return None
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _char_span_to_token_span(tokenizer, text: str, span: Tuple[int, int]) -> Optional[List[int]]:
|
| 116 |
+
start_char, end_char = int(span[0]), int(span[1])
|
| 117 |
+
if end_char <= start_char:
|
| 118 |
+
return None
|
| 119 |
+
|
| 120 |
+
enc = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
|
| 121 |
+
offsets = enc.get("offset_mapping")
|
| 122 |
+
if offsets is None:
|
| 123 |
+
raise ValueError("Tokenizer does not provide offset_mapping; cannot map char spans to tokens.")
|
| 124 |
+
|
| 125 |
+
tok_indices: List[int] = []
|
| 126 |
+
for idx, off in enumerate(offsets):
|
| 127 |
+
if off is None:
|
| 128 |
+
continue
|
| 129 |
+
s, e = int(off[0]), int(off[1])
|
| 130 |
+
if s < end_char and e > start_char:
|
| 131 |
+
tok_indices.append(int(idx))
|
| 132 |
+
if not tok_indices:
|
| 133 |
+
return None
|
| 134 |
+
return [min(tok_indices), max(tok_indices)]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _last_meaningful_code_line_token_span(tokenizer, target: str) -> List[int]:
|
| 138 |
+
full_span = _token_span_full_output(tokenizer, target)
|
| 139 |
+
span_chars = _last_meaningful_code_line_char_span(target)
|
| 140 |
+
if span_chars is None:
|
| 141 |
+
return full_span
|
| 142 |
+
|
| 143 |
+
span_toks = _char_span_to_token_span(tokenizer, target, span_chars)
|
| 144 |
+
if span_toks is None:
|
| 145 |
+
return full_span
|
| 146 |
+
|
| 147 |
+
span_toks[0] = max(int(span_toks[0]), int(full_span[0]))
|
| 148 |
+
span_toks[1] = min(int(span_toks[1]), int(full_span[1]))
|
| 149 |
+
if span_toks[1] < span_toks[0]:
|
| 150 |
+
return full_span
|
| 151 |
+
return span_toks
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _last_token_span(token_span: Sequence[int]) -> List[int]:
|
| 155 |
+
if not (isinstance(token_span, Sequence) and len(token_span) == 2):
|
| 156 |
+
return [0, 0]
|
| 157 |
+
end = int(token_span[1])
|
| 158 |
+
return [end, end]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def resolve_device(args) -> str:
|
| 162 |
+
if args.cuda is not None and "," in args.cuda:
|
| 163 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
| 164 |
+
return "auto"
|
| 165 |
+
if args.cuda is not None and args.cuda.strip():
|
| 166 |
+
return f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
|
| 167 |
+
return f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def load_model_and_tokenizer(args) -> tuple[Any, Any]:
|
| 171 |
+
model_id = args.model_path or args.model
|
| 172 |
+
if not model_id:
|
| 173 |
+
raise SystemExit("Provide --model_path (local) or --model (HF repo id).")
|
| 174 |
+
|
| 175 |
+
tokenizer_id = args.tokenizer_path or model_id
|
| 176 |
+
device = resolve_device(args)
|
| 177 |
+
|
| 178 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 179 |
+
model_id,
|
| 180 |
+
device_map="auto" if device == "auto" else {"": int(device.split(":")[1])} if device.startswith("cuda:") else None,
|
| 181 |
+
torch_dtype=torch.float16,
|
| 182 |
+
attn_implementation="eager",
|
| 183 |
+
)
|
| 184 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
| 185 |
+
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
| 186 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 187 |
+
model.eval()
|
| 188 |
+
return model, tokenizer
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _faithfulness_test_with_user_prompt_indices(
|
| 192 |
+
llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
|
| 193 |
+
attribution: torch.Tensor,
|
| 194 |
+
prompt: str,
|
| 195 |
+
generation: str,
|
| 196 |
+
*,
|
| 197 |
+
user_prompt_indices: List[int],
|
| 198 |
+
k: int = 20,
|
| 199 |
+
) -> Tuple[float, float, float]:
|
| 200 |
+
def auc(arr: np.ndarray) -> float:
|
| 201 |
+
return (arr.sum() - arr[0] / 2 - arr[-1] / 2) / max(1, (arr.shape[0] - 1))
|
| 202 |
+
|
| 203 |
+
pad_token_id = llm_evaluator._ensure_pad_token_id()
|
| 204 |
+
|
| 205 |
+
user_prompt = " " + prompt
|
| 206 |
+
formatted_prompt = llm_evaluator.format_prompt(user_prompt)
|
| 207 |
+
formatted_ids = llm_evaluator.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).input_ids
|
| 208 |
+
|
| 209 |
+
prompt_ids = formatted_ids.to(llm_evaluator.device)
|
| 210 |
+
prompt_ids_perturbed = prompt_ids.clone()
|
| 211 |
+
generation_ids = llm_evaluator.tokenizer(
|
| 212 |
+
generation + llm_evaluator.tokenizer.eos_token,
|
| 213 |
+
return_tensors="pt",
|
| 214 |
+
add_special_tokens=False,
|
| 215 |
+
).input_ids.to(llm_evaluator.device)
|
| 216 |
+
|
| 217 |
+
attr_cpu = attribution.detach().cpu()
|
| 218 |
+
w = attr_cpu.sum(0)
|
| 219 |
+
sorted_attr_indices = torch.argsort(w, descending=True)
|
| 220 |
+
attr_sum = float(w.sum().item())
|
| 221 |
+
|
| 222 |
+
P = int(w.numel())
|
| 223 |
+
if len(user_prompt_indices) != P:
|
| 224 |
+
raise ValueError(
|
| 225 |
+
"user_prompt_indices length does not match prompt-side attribution length: "
|
| 226 |
+
f"indices P={len(user_prompt_indices)}, attr P={P}."
|
| 227 |
+
)
|
| 228 |
+
if P == 0:
|
| 229 |
+
return 0.0, 0.0, 0.0
|
| 230 |
+
|
| 231 |
+
if max(user_prompt_indices) >= int(prompt_ids_perturbed.shape[1]):
|
| 232 |
+
raise ValueError("user_prompt_indices contains an out-of-bounds index for formatted prompt ids.")
|
| 233 |
+
|
| 234 |
+
steps = int(k) if k is not None else 0
|
| 235 |
+
if steps <= 0:
|
| 236 |
+
steps = 1
|
| 237 |
+
steps = min(steps, P)
|
| 238 |
+
|
| 239 |
+
scores = np.zeros(steps + 1, dtype=np.float64)
|
| 240 |
+
density = np.zeros(steps + 1, dtype=np.float64)
|
| 241 |
+
|
| 242 |
+
scores[0] = (
|
| 243 |
+
llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
|
| 244 |
+
)
|
| 245 |
+
density[0] = 1.0
|
| 246 |
+
|
| 247 |
+
if attr_sum <= 0:
|
| 248 |
+
density = np.linspace(1.0, 0.0, steps + 1)
|
| 249 |
+
|
| 250 |
+
base = P // steps
|
| 251 |
+
remainder = P % steps
|
| 252 |
+
start = 0
|
| 253 |
+
for step in range(steps):
|
| 254 |
+
size = base + (1 if step < remainder else 0)
|
| 255 |
+
group = sorted_attr_indices[start : start + size]
|
| 256 |
+
start += size
|
| 257 |
+
|
| 258 |
+
for idx in group:
|
| 259 |
+
j = int(idx.item())
|
| 260 |
+
abs_pos = int(user_prompt_indices[j])
|
| 261 |
+
prompt_ids_perturbed[0, abs_pos] = pad_token_id
|
| 262 |
+
scores[step + 1] = (
|
| 263 |
+
llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
|
| 264 |
+
)
|
| 265 |
+
if attr_sum > 0:
|
| 266 |
+
dec = float(w.index_select(0, group).sum().item()) / attr_sum
|
| 267 |
+
density[step + 1] = density[step] - dec
|
| 268 |
+
|
| 269 |
+
min_normalized_pred = 1.0
|
| 270 |
+
normalized_model_response = scores.copy()
|
| 271 |
+
for i in range(len(scores)):
|
| 272 |
+
normalized_pred = (normalized_model_response[i] - scores[-1]) / (abs(scores[0] - scores[-1]))
|
| 273 |
+
normalized_pred = np.clip(normalized_pred, 0.0, 1.0)
|
| 274 |
+
min_normalized_pred = min(min_normalized_pred, normalized_pred)
|
| 275 |
+
normalized_model_response[i] = min_normalized_pred
|
| 276 |
+
|
| 277 |
+
alignment_penalty = np.abs(normalized_model_response - density)
|
| 278 |
+
corrected_scores = normalized_model_response + alignment_penalty
|
| 279 |
+
corrected_scores = corrected_scores.clip(0.0, 1.0)
|
| 280 |
+
corrected_scores = (corrected_scores - np.min(corrected_scores)) / (np.max(corrected_scores) - np.min(corrected_scores))
|
| 281 |
+
|
| 282 |
+
if np.isnan(corrected_scores).any():
|
| 283 |
+
corrected_scores = np.linspace(1.0, 0.0, len(scores))
|
| 284 |
+
|
| 285 |
+
return auc(normalized_model_response), auc(corrected_scores), auc(normalized_model_response + alignment_penalty)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _row_faithfulness_scores(
|
| 289 |
+
*,
|
| 290 |
+
llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
|
| 291 |
+
attribution_prompt: torch.Tensor,
|
| 292 |
+
prompt: str,
|
| 293 |
+
generation: str,
|
| 294 |
+
user_prompt_indices: Optional[List[int]],
|
| 295 |
+
keep_prompt_token_indices: Optional[Sequence[int]] = None,
|
| 296 |
+
k: int = 20,
|
| 297 |
+
) -> Tuple[float, float]:
|
| 298 |
+
if keep_prompt_token_indices is not None:
|
| 299 |
+
rise, mas, _ = ft_ifr_improve.faithfulness_test_skip_tokens(
|
| 300 |
+
llm_evaluator,
|
| 301 |
+
attribution_prompt,
|
| 302 |
+
prompt,
|
| 303 |
+
generation,
|
| 304 |
+
keep_prompt_token_indices=keep_prompt_token_indices,
|
| 305 |
+
user_prompt_indices=user_prompt_indices,
|
| 306 |
+
k=int(k),
|
| 307 |
+
)
|
| 308 |
+
return float(rise), float(mas)
|
| 309 |
+
if user_prompt_indices is not None:
|
| 310 |
+
rise, mas, _ = _faithfulness_test_with_user_prompt_indices(
|
| 311 |
+
llm_evaluator,
|
| 312 |
+
attribution_prompt,
|
| 313 |
+
prompt,
|
| 314 |
+
generation,
|
| 315 |
+
user_prompt_indices=user_prompt_indices,
|
| 316 |
+
k=int(k),
|
| 317 |
+
)
|
| 318 |
+
return float(rise), float(mas)
|
| 319 |
+
|
| 320 |
+
rise, mas, _ = llm_evaluator.faithfulness_test(attribution_prompt, prompt, generation, k=int(k))
|
| 321 |
+
return float(rise), float(mas)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _model_tag(args) -> str:
|
| 325 |
+
if args.model:
|
| 326 |
+
return str(args.model)
|
| 327 |
+
if args.model_path:
|
| 328 |
+
return Path(args.model_path).name
|
| 329 |
+
return "model"
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def main() -> None:
|
| 333 |
+
parser = argparse.ArgumentParser("Experiment 4 runner: aider faithfulness (row-only).")
|
| 334 |
+
parser.add_argument("--data_path", type=str, default="exp/exp4/data/aider.jsonl")
|
| 335 |
+
parser.add_argument("--output_root", type=str, default="exp/exp4/output")
|
| 336 |
+
parser.add_argument("--model", type=str, default=None, help="HF repo id (required unless --model_path set).")
|
| 337 |
+
parser.add_argument("--model_path", type=str, default=None, help="Local path; overrides --model for loading.")
|
| 338 |
+
parser.add_argument("--tokenizer_path", type=str, default=None, help="Optional tokenizer path/id (defaults to model).")
|
| 339 |
+
parser.add_argument("--cuda", type=str, default=None)
|
| 340 |
+
parser.add_argument("--cuda_num", type=int, default=0)
|
| 341 |
+
parser.add_argument("--num_examples", type=int, default=100)
|
| 342 |
+
parser.add_argument("--seed", type=int, default=42, help="Reserved for future use; exp4 runs in file order.")
|
| 343 |
+
parser.add_argument("--chunk_tokens", type=int, default=128)
|
| 344 |
+
parser.add_argument("--sink_chunk_tokens", type=int, default=32)
|
| 345 |
+
parser.add_argument("--n_hops", type=int, default=3)
|
| 346 |
+
parser.add_argument("--k", type=int, default=20, help="Perturbation steps for MAS/RISE.")
|
| 347 |
+
args = parser.parse_args()
|
| 348 |
+
|
| 349 |
+
data_path = Path(args.data_path)
|
| 350 |
+
if not data_path.exists():
|
| 351 |
+
raise SystemExit(f"Missing Aider JSONL: {data_path}")
|
| 352 |
+
|
| 353 |
+
model, tokenizer = load_model_and_tokenizer(args)
|
| 354 |
+
llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(model, tokenizer)
|
| 355 |
+
|
| 356 |
+
examples = load_aider(data_path)
|
| 357 |
+
total = min(len(examples), int(args.num_examples))
|
| 358 |
+
iterator = islice(examples, total)
|
| 359 |
+
|
| 360 |
+
ifr = llm_attr.LLMIFRAttribution(
|
| 361 |
+
model,
|
| 362 |
+
tokenizer,
|
| 363 |
+
chunk_tokens=int(args.chunk_tokens),
|
| 364 |
+
sink_chunk_tokens=int(args.sink_chunk_tokens),
|
| 365 |
+
)
|
| 366 |
+
flashtrace = ft_ifr_improve.LLMIFRAttributionBoth(
|
| 367 |
+
model,
|
| 368 |
+
tokenizer,
|
| 369 |
+
chunk_tokens=int(args.chunk_tokens),
|
| 370 |
+
sink_chunk_tokens=int(args.sink_chunk_tokens),
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
results: Dict[Tuple[str, str], List[Tuple[float, float]]] = {
|
| 374 |
+
("ifr_all_positions", "last_line"): [],
|
| 375 |
+
("ifr_all_positions", "last_token"): [],
|
| 376 |
+
("ifr_multi_hop_both", "full_output"): [],
|
| 377 |
+
}
|
| 378 |
+
skipped: Dict[Tuple[str, str], int] = {k: 0 for k in results}
|
| 379 |
+
sample_times: Dict[Tuple[str, str], List[float]] = {k: [] for k in results}
|
| 380 |
+
|
| 381 |
+
for example_idx, ex in enumerate(iterator):
|
| 382 |
+
prompt = ex.prompt
|
| 383 |
+
target = ex.target
|
| 384 |
+
|
| 385 |
+
full_span = _token_span_full_output(tokenizer, target)
|
| 386 |
+
last_line_span = _last_meaningful_code_line_token_span(tokenizer, target)
|
| 387 |
+
last_token_span = _last_token_span(last_line_span)
|
| 388 |
+
|
| 389 |
+
attr_all = None
|
| 390 |
+
attr_all_time_s = 0.0
|
| 391 |
+
user_prompt_indices_all: Optional[List[int]] = None
|
| 392 |
+
prompt_len_all = 0
|
| 393 |
+
try:
|
| 394 |
+
t_attr = time.perf_counter()
|
| 395 |
+
attr_all = ifr.calculate_ifr_for_all_positions(prompt, target=target)
|
| 396 |
+
attr_all_time_s = float(time.perf_counter() - t_attr)
|
| 397 |
+
user_prompt_indices_all = list(getattr(ifr, "user_prompt_indices", []) or [])
|
| 398 |
+
prompt_len_all = int(len(attr_all.prompt_tokens))
|
| 399 |
+
except Exception as exc:
|
| 400 |
+
skipped[("ifr_all_positions", "last_line")] += 1
|
| 401 |
+
skipped[("ifr_all_positions", "last_token")] += 1
|
| 402 |
+
print(f"[warn] ifr_all_positions attribution failed ex={example_idx}: {exc}")
|
| 403 |
+
|
| 404 |
+
if attr_all is not None and user_prompt_indices_all is not None and prompt_len_all >= 0:
|
| 405 |
+
for sink_name, span in (("last_line", last_line_span), ("last_token", last_token_span)):
|
| 406 |
+
key = ("ifr_all_positions", sink_name)
|
| 407 |
+
try:
|
| 408 |
+
t_faith = time.perf_counter()
|
| 409 |
+
row = attr_all.get_all_token_attrs(list(span))[1]
|
| 410 |
+
rise, mas = _row_faithfulness_scores(
|
| 411 |
+
llm_evaluator=llm_evaluator,
|
| 412 |
+
attribution_prompt=row[:, :prompt_len_all],
|
| 413 |
+
prompt=prompt,
|
| 414 |
+
generation=target,
|
| 415 |
+
user_prompt_indices=user_prompt_indices_all,
|
| 416 |
+
k=int(args.k),
|
| 417 |
+
)
|
| 418 |
+
faith_time_s = float(time.perf_counter() - t_faith)
|
| 419 |
+
results[key].append((rise, mas))
|
| 420 |
+
sample_times[key].append(attr_all_time_s + faith_time_s)
|
| 421 |
+
except Exception as exc:
|
| 422 |
+
skipped[key] += 1
|
| 423 |
+
print(f"[warn] ifr_all_positions {sink_name} failed ex={example_idx}: {exc}")
|
| 424 |
+
|
| 425 |
+
try:
|
| 426 |
+
t_attr = time.perf_counter()
|
| 427 |
+
attr_ft = flashtrace.calculate_ifr_multi_hop_both(
|
| 428 |
+
prompt,
|
| 429 |
+
target=target,
|
| 430 |
+
sink_span=None,
|
| 431 |
+
thinking_span=None,
|
| 432 |
+
n_hops=int(args.n_hops),
|
| 433 |
+
)
|
| 434 |
+
attr_ft_time_s = float(time.perf_counter() - t_attr)
|
| 435 |
+
user_prompt_indices_ft = list(getattr(flashtrace, "user_prompt_indices", []) or [])
|
| 436 |
+
prompt_len_ft = int(len(attr_ft.prompt_tokens))
|
| 437 |
+
keep_prompt_token_indices = ft_ifr_improve.keep_token_indices(list(attr_ft.prompt_tokens))
|
| 438 |
+
|
| 439 |
+
t_faith = time.perf_counter()
|
| 440 |
+
row_full = attr_ft.get_all_token_attrs(full_span)[1]
|
| 441 |
+
rise, mas = _row_faithfulness_scores(
|
| 442 |
+
llm_evaluator=llm_evaluator,
|
| 443 |
+
attribution_prompt=row_full[:, :prompt_len_ft],
|
| 444 |
+
prompt=prompt,
|
| 445 |
+
generation=target,
|
| 446 |
+
user_prompt_indices=user_prompt_indices_ft,
|
| 447 |
+
keep_prompt_token_indices=keep_prompt_token_indices,
|
| 448 |
+
k=int(args.k),
|
| 449 |
+
)
|
| 450 |
+
faith_time_s = float(time.perf_counter() - t_faith)
|
| 451 |
+
results[("ifr_multi_hop_both", "full_output")].append((rise, mas))
|
| 452 |
+
sample_times[("ifr_multi_hop_both", "full_output")].append(attr_ft_time_s + faith_time_s)
|
| 453 |
+
except Exception as exc:
|
| 454 |
+
skipped[("ifr_multi_hop_both", "full_output")] += 1
|
| 455 |
+
print(f"[warn] ifr_multi_hop_both failed ex={example_idx}: {exc}")
|
| 456 |
+
|
| 457 |
+
model_tag = _model_tag(args)
|
| 458 |
+
out_dir = Path(args.output_root) / "faithfulness" / "aider" / model_tag
|
| 459 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 460 |
+
out_path = out_dir / f"row_only_{total}_examples.csv"
|
| 461 |
+
|
| 462 |
+
with out_path.open("w", encoding="utf-8") as f:
|
| 463 |
+
f.write("Method,Sink,Row_RISE_Mean,Row_RISE_Std,Row_MAS_Mean,Row_MAS_Std,Used,Skipped,Avg_Sample_Time_s\n")
|
| 464 |
+
for (method, sink), vals in results.items():
|
| 465 |
+
arr = np.asarray(vals, dtype=np.float64)
|
| 466 |
+
used = int(arr.shape[0])
|
| 467 |
+
if used == 0:
|
| 468 |
+
rise_mean = float("nan")
|
| 469 |
+
rise_std = float("nan")
|
| 470 |
+
mas_mean = float("nan")
|
| 471 |
+
mas_std = float("nan")
|
| 472 |
+
else:
|
| 473 |
+
rise_mean = float(arr[:, 0].mean())
|
| 474 |
+
rise_std = float(arr[:, 0].std())
|
| 475 |
+
mas_mean = float(arr[:, 1].mean())
|
| 476 |
+
mas_std = float(arr[:, 1].std())
|
| 477 |
+
times = sample_times.get((method, sink)) or []
|
| 478 |
+
avg_time = float(np.mean(times)) if times else 0.0
|
| 479 |
+
f.write(
|
| 480 |
+
f"{method},{sink},{rise_mean},{rise_std},{mas_mean},{mas_std},{used},{int(skipped[(method, sink)])},{avg_time}\n"
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
print(f"[done] wrote {out_path}")
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
if __name__ == "__main__":
|
| 487 |
+
main()
|
exp/exp5/README.md
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FlashTrace 实验 5:跨模型(Qwen → Llama)token-span 映射
|
| 2 |
+
|
| 3 |
+
## 背景:为什么需要映射
|
| 4 |
+
|
| 5 |
+
`exp/exp2/run_exp.py` 的归因与评估是严格 **token-level** 的,并且依赖缓存数据中的 token-span 字段:
|
| 6 |
+
|
| 7 |
+
- `indices_to_explain = [start_tok, end_tok]`(generation token indices,闭区间)
|
| 8 |
+
- `sink_span` / `thinking_span`(同样是 generation token spans)
|
| 9 |
+
|
| 10 |
+
这些 span 在生成缓存(`exp/exp2/sample_and_filter.py`、`exp/exp2/map_math_mine_to_exp2_cache.py`)时是用某个 tokenizer 计算并写死的(通常是 `Qwen3-8B` 的 tokenizer)。
|
| 11 |
+
|
| 12 |
+
当你切换到新模型(例如 `Llama-3.1-8B-Instruct`)时,**tokenizer 不同**,`target` 的 tokenization 长度/边界会变化,导致旧的 span 在新 tokenizer 下经常越界,从而让 exp2 在归因阶段直接报错(`IndexError: end_tok out of range`)。
|
| 13 |
+
|
| 14 |
+
## 解决方案:exp5 映射脚本
|
| 15 |
+
|
| 16 |
+
`exp/exp5/map_exp2_cache_token_spans.py` 将 exp2 缓存里的旧 token-span 从旧 tokenizer(默认 `Qwen3-8B`)映射到新 tokenizer(默认 `Llama-3.1-8B-Instruct`),并输出到:
|
| 17 |
+
|
| 18 |
+
`exp/exp5/data/<同名数据集>.jsonl`
|
| 19 |
+
|
| 20 |
+
映射策略(默认):
|
| 21 |
+
1) 用旧 tokenizer 对 `target` 做 `return_offsets_mapping=True`
|
| 22 |
+
2) 把旧的 token-span 转成 `target` 的字符区间
|
| 23 |
+
3) 用新 tokenizer 对同一个 `target` 做 offsets,再把字符区间映射回新的 token-span
|
| 24 |
+
|
| 25 |
+
如遇极端情况(缓存并非由预期旧 tokenizer 产生),可启用 `--allow_fallback_answer`,用 `metadata.boxed_answer`(或 `reference_answer`)在新 tokenizer 下重新定位 span 作为兜底。
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## Step 1:把 exp2 数据集缓存映射到 exp5/data
|
| 30 |
+
|
| 31 |
+
推荐使用仓库的 venv:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
.venv/bin/python exp/exp5/map_exp2_cache_token_spans.py \
|
| 35 |
+
--in_jsonl exp/exp2/data/niah_mq_q2.jsonl \
|
| 36 |
+
--out_dir exp/exp5/data \
|
| 37 |
+
--old_tokenizer_model /opt/share/models/Qwen/Qwen3-8B \
|
| 38 |
+
--new_tokenizer_model /opt/share/models/meta-llama/Llama-3.1-8B-Instruct
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
一次映射多个数据集(示例:RULER + math):
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
.venv/bin/python exp/exp5/map_exp2_cache_token_spans.py \
|
| 45 |
+
--in_jsonl exp/exp2/data/niah_mq_q2.jsonl exp/exp2/data/math.jsonl \
|
| 46 |
+
--out_dir exp/exp5/data \
|
| 47 |
+
--old_tokenizer_model /opt/share/models/Qwen/Qwen3-8B \
|
| 48 |
+
--new_tokenizer_model /opt/share/models/meta-llama/Llama-3.1-8B-Instruct
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
如果输出文件已存在,加 `--overwrite`。
|
| 52 |
+
|
| 53 |
+
默认行为:若某条样本无法映射,脚本会将其 **drop** 并在输出统计中报告;如需严格一致性请加 `--strict`(遇到首个失败样本直接退出)。如怀疑原缓存并非由 `--old_tokenizer_model` 产生,可加 `--allow_fallback_answer` 启用基于 `metadata.boxed_answer` 的兜底定位。
|
| 54 |
+
|
| 55 |
+
---
|
| 56 |
+
|
| 57 |
+
## Step 2:用 exp2 直接跑 Llama 归因评测(但数据/输出都指向 exp5)
|
| 58 |
+
|
| 59 |
+
关键点:
|
| 60 |
+
- **数据读取**:用 `--data_root exp/exp5/data`(让 exp2 读取映射后的缓存)
|
| 61 |
+
- **结果输出**:用 `--output_root exp/exp5/output`(避免写入 `exp/exp2/output`)
|
| 62 |
+
- **不要加** `--save_hop_traces`(避免写 trace)
|
| 63 |
+
|
| 64 |
+
### RULER(可跑 recovery + faithfulness)
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
CUDA_VISIBLE_DEVICES=0 .venv/bin/python exp/exp2/run_exp.py \
|
| 68 |
+
--datasets niah_mq_q2 \
|
| 69 |
+
--data_root exp/exp5/data \
|
| 70 |
+
--output_root exp/exp5/output \
|
| 71 |
+
--attr_funcs ifr_all_positions,attnlrp,ifr_multi_hop_both \
|
| 72 |
+
--model_path /opt/share/models/meta-llama/Llama-3.1-8B-Instruct \
|
| 73 |
+
--cuda 0 \
|
| 74 |
+
--num_examples 100 \
|
| 75 |
+
--mode faithfulness_gen,recovery_ruler
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### math(只能跑 faithfulness;recovery 会被 exp2 显式拒绝)
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
CUDA_VISIBLE_DEVICES=0 .venv/bin/python exp/exp2/run_exp.py \
|
| 82 |
+
--datasets math \
|
| 83 |
+
--data_root exp/exp5/data \
|
| 84 |
+
--output_root exp/exp5/output \
|
| 85 |
+
--attr_funcs ifr_all_positions,attnlrp,ifr_multi_hop_both \
|
| 86 |
+
--model_path /opt/share/models/meta-llama/Llama-3.1-8B-Instruct \
|
| 87 |
+
--cuda 0 \
|
| 88 |
+
--num_examples 100 \
|
| 89 |
+
--mode faithfulness_gen
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
## 关于“是否会污染 exp2 文件夹”
|
| 93 |
+
|
| 94 |
+
- **不会污染 `exp/exp2/data/`**:我们不改 exp2 的缓存,而是输出到 `exp/exp5/data/`。
|
| 95 |
+
- **不加 `--save_hop_traces` 不会写 trace**。
|
| 96 |
+
- 但注意:`exp/exp2/run_exp.py` 本身**一定会写 CSV 指标文件**到 `--output_root`(代码行为如此,exp5 不改 exp2),所以要做到“exp2 文件夹不新增文件”,请把 `--output_root` 指向 `exp/exp5/output`(或其它目录)。
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
python exp/exp2/run_exp.py \
|
| 100 |
+
--datasets niah_mq_q2 \
|
| 101 |
+
--data_root exp/exp5/data \
|
| 102 |
+
--output_root exp/exp5/output \
|
| 103 |
+
--attr_funcs ifr_all_positions,attnlrp,ifr_multi_hop_both \
|
| 104 |
+
--model_path /opt/share/models/meta-llama/Llama-3.1-8B-Instruct \
|
| 105 |
+
--cuda 2,3,4,5,6,7 \
|
| 106 |
+
--num_examples 100 \
|
| 107 |
+
--mode faithfulness_gen \
|
| 108 |
+
--n_hops 1
|
| 109 |
+
&& python exp/exp2/run_exp.py \
|
| 110 |
+
--datasets math \
|
| 111 |
+
--data_root exp/exp5/data \
|
| 112 |
+
--output_root exp/exp5/output \
|
| 113 |
+
--attr_funcs ifr_all_positions,attnlrp,ifr_multi_hop_both \
|
| 114 |
+
--model_path /opt/share/models/meta-llama/Llama-3.1-8B-Instruct \
|
| 115 |
+
--cuda 2,3,4,5,6,7 \
|
| 116 |
+
--num_examples 100 \
|
| 117 |
+
--mode faithfulness_gen \
|
| 118 |
+
--n_hops 1
|
| 119 |
+
```
|
exp/exp5/map_exp2_cache_token_spans.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Map exp2 cached JSONL token spans across tokenizers (Qwen -> Llama).
|
| 3 |
+
|
| 4 |
+
Background
|
| 5 |
+
----------
|
| 6 |
+
`exp/exp2/run_exp.py` expects cached datasets to provide token-level generation spans:
|
| 7 |
+
|
| 8 |
+
- indices_to_explain: [start_tok, end_tok] (generation-token indices; closed interval)
|
| 9 |
+
- sink_span / thinking_span: same tokenizer convention as indices_to_explain
|
| 10 |
+
|
| 11 |
+
These spans are computed under a specific tokenizer (often Qwen3-8B). When switching
|
| 12 |
+
to a different model/tokenizer (e.g., Llama-3.1-8B-Instruct), the stored spans can
|
| 13 |
+
become out-of-range and crash exp2 attribution (IndexError in token-span checks).
|
| 14 |
+
|
| 15 |
+
This script remaps spans by:
|
| 16 |
+
1) Tokenizing `target` with the OLD tokenizer to obtain offset_mapping
|
| 17 |
+
2) Converting the OLD token span into a character span in `target`
|
| 18 |
+
3) Tokenizing `target` with the NEW tokenizer and mapping the character span back
|
| 19 |
+
into NEW token indices
|
| 20 |
+
|
| 21 |
+
Outputs are written under `exp/exp5/data/` by default, keeping `exp/exp2/` untouched.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import json
|
| 28 |
+
import sys
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
| 31 |
+
|
| 32 |
+
from transformers import AutoTokenizer
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 36 |
+
if str(REPO_ROOT) not in sys.path:
|
| 37 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _split_args(values: Iterable[str]) -> List[str]:
|
| 41 |
+
out: List[str] = []
|
| 42 |
+
for v in values:
|
| 43 |
+
for part in str(v).split(","):
|
| 44 |
+
part = part.strip()
|
| 45 |
+
if part:
|
| 46 |
+
out.append(part)
|
| 47 |
+
return out
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _load_tokenizer(tokenizer_model: str):
|
| 51 |
+
path = Path(tokenizer_model)
|
| 52 |
+
if path.exists():
|
| 53 |
+
return AutoTokenizer.from_pretrained(path.as_posix(), local_files_only=True)
|
| 54 |
+
# May require network access; keep as fallback for environments that allow it.
|
| 55 |
+
return AutoTokenizer.from_pretrained(tokenizer_model)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _is_token_span(span: Any) -> bool:
|
| 59 |
+
return (
|
| 60 |
+
isinstance(span, list)
|
| 61 |
+
and len(span) == 2
|
| 62 |
+
and all(isinstance(x, int) for x in span)
|
| 63 |
+
and span[0] >= 0
|
| 64 |
+
and span[1] >= span[0]
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _pick_old_span(obj: Dict[str, Any]) -> Optional[List[int]]:
|
| 69 |
+
span = obj.get("indices_to_explain")
|
| 70 |
+
if _is_token_span(span):
|
| 71 |
+
return list(span)
|
| 72 |
+
span = obj.get("sink_span")
|
| 73 |
+
if _is_token_span(span):
|
| 74 |
+
return list(span)
|
| 75 |
+
return None
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _offsets_to_char_span(offsets: Any, token_span: List[int]) -> Optional[Tuple[int, int]]:
|
| 79 |
+
"""Convert a token span [start,end] to a character span [char_start,char_end) using offsets."""
|
| 80 |
+
if offsets is None:
|
| 81 |
+
return None
|
| 82 |
+
if not isinstance(offsets, list):
|
| 83 |
+
return None
|
| 84 |
+
start_tok, end_tok = token_span
|
| 85 |
+
if end_tok >= len(offsets):
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
char_starts: List[int] = []
|
| 89 |
+
char_ends: List[int] = []
|
| 90 |
+
for idx in range(start_tok, end_tok + 1):
|
| 91 |
+
off = offsets[idx]
|
| 92 |
+
if off is None:
|
| 93 |
+
continue
|
| 94 |
+
if not (isinstance(off, (list, tuple)) and len(off) == 2):
|
| 95 |
+
continue
|
| 96 |
+
try:
|
| 97 |
+
s, e = int(off[0]), int(off[1])
|
| 98 |
+
except Exception:
|
| 99 |
+
continue
|
| 100 |
+
if e <= s:
|
| 101 |
+
continue
|
| 102 |
+
char_starts.append(s)
|
| 103 |
+
char_ends.append(e)
|
| 104 |
+
|
| 105 |
+
if not char_starts or not char_ends:
|
| 106 |
+
return None
|
| 107 |
+
return min(char_starts), max(char_ends)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _char_span_to_token_span(offsets: Any, char_span: Tuple[int, int]) -> Optional[List[int]]:
|
| 111 |
+
"""Convert a character span [char_start,char_end) to a token span [start,end] by overlap."""
|
| 112 |
+
if offsets is None:
|
| 113 |
+
return None
|
| 114 |
+
if not isinstance(offsets, list):
|
| 115 |
+
return None
|
| 116 |
+
char_start, char_end = int(char_span[0]), int(char_span[1])
|
| 117 |
+
if char_end <= char_start:
|
| 118 |
+
return None
|
| 119 |
+
|
| 120 |
+
hit: List[int] = []
|
| 121 |
+
for tok_idx, off in enumerate(offsets):
|
| 122 |
+
if off is None:
|
| 123 |
+
continue
|
| 124 |
+
if not (isinstance(off, (list, tuple)) and len(off) == 2):
|
| 125 |
+
continue
|
| 126 |
+
try:
|
| 127 |
+
s, e = int(off[0]), int(off[1])
|
| 128 |
+
except Exception:
|
| 129 |
+
continue
|
| 130 |
+
if e <= s:
|
| 131 |
+
continue
|
| 132 |
+
if s < char_end and e > char_start:
|
| 133 |
+
hit.append(int(tok_idx))
|
| 134 |
+
|
| 135 |
+
if not hit:
|
| 136 |
+
return None
|
| 137 |
+
return [min(hit), max(hit)]
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _validate_span_with_eos(tokenizer, target: str, token_span: List[int]) -> bool:
|
| 141 |
+
eos = tokenizer.eos_token or ""
|
| 142 |
+
gen_ids = tokenizer(target + eos, add_special_tokens=False).input_ids
|
| 143 |
+
gen_len = int(len(gen_ids))
|
| 144 |
+
return 0 <= token_span[0] <= token_span[1] < gen_len
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _guess_answer_text(obj: Dict[str, Any]) -> Optional[str]:
|
| 148 |
+
meta = obj.get("metadata") or {}
|
| 149 |
+
if isinstance(meta, dict):
|
| 150 |
+
boxed = (meta.get("boxed_answer") or "").strip()
|
| 151 |
+
if boxed:
|
| 152 |
+
return boxed
|
| 153 |
+
ref = (meta.get("reference_answer") or "").strip()
|
| 154 |
+
if ref:
|
| 155 |
+
return ref
|
| 156 |
+
tgt = obj.get("target")
|
| 157 |
+
if isinstance(tgt, str) and tgt.strip():
|
| 158 |
+
# Common exp2 cache convention: last line is the final answer.
|
| 159 |
+
last_line = tgt.strip().splitlines()[-1].strip()
|
| 160 |
+
return last_line or None
|
| 161 |
+
return None
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _fallback_map_via_answer_text(
|
| 165 |
+
obj: Dict[str, Any],
|
| 166 |
+
*,
|
| 167 |
+
new_tokenizer,
|
| 168 |
+
) -> Optional[List[int]]:
|
| 169 |
+
tgt = obj.get("target")
|
| 170 |
+
if not isinstance(tgt, str) or not tgt:
|
| 171 |
+
return None
|
| 172 |
+
|
| 173 |
+
from exp.exp2.dataset_utils import CachedExample, attach_spans_from_answer # lazy import
|
| 174 |
+
|
| 175 |
+
answer_text = _guess_answer_text(obj)
|
| 176 |
+
ex = CachedExample(
|
| 177 |
+
prompt=str(obj.get("prompt") or ""),
|
| 178 |
+
target=tgt,
|
| 179 |
+
indices_to_explain=None,
|
| 180 |
+
attr_mask_indices=obj.get("attr_mask_indices"),
|
| 181 |
+
sink_span=None,
|
| 182 |
+
thinking_span=None,
|
| 183 |
+
metadata=obj.get("metadata") or {},
|
| 184 |
+
)
|
| 185 |
+
out = attach_spans_from_answer(ex, new_tokenizer, answer_text)
|
| 186 |
+
if out.sink_span is None:
|
| 187 |
+
return None
|
| 188 |
+
if not _is_token_span(out.sink_span):
|
| 189 |
+
return None
|
| 190 |
+
return list(out.sink_span)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _map_one_obj(
|
| 194 |
+
obj: Dict[str, Any],
|
| 195 |
+
*,
|
| 196 |
+
old_tokenizer,
|
| 197 |
+
new_tokenizer,
|
| 198 |
+
allow_fallback_answer: bool,
|
| 199 |
+
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
|
| 200 |
+
target = obj.get("target")
|
| 201 |
+
if not isinstance(target, str) or not target:
|
| 202 |
+
return None, "missing_target"
|
| 203 |
+
|
| 204 |
+
old_span = _pick_old_span(obj)
|
| 205 |
+
if old_span is None:
|
| 206 |
+
return None, "missing_old_span"
|
| 207 |
+
|
| 208 |
+
# 1) Old token span -> char span in target.
|
| 209 |
+
old_enc = old_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
|
| 210 |
+
old_offsets = old_enc.get("offset_mapping")
|
| 211 |
+
char_span = _offsets_to_char_span(old_offsets, old_span)
|
| 212 |
+
if char_span is None:
|
| 213 |
+
if not allow_fallback_answer:
|
| 214 |
+
return None, "old_span_to_char_failed"
|
| 215 |
+
new_span = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
|
| 216 |
+
if new_span is None:
|
| 217 |
+
return None, "fallback_answer_failed"
|
| 218 |
+
if not _validate_span_with_eos(new_tokenizer, target, new_span):
|
| 219 |
+
return None, "fallback_answer_span_invalid"
|
| 220 |
+
mapped = dict(obj)
|
| 221 |
+
mapped["indices_to_explain"] = new_span
|
| 222 |
+
mapped["sink_span"] = new_span
|
| 223 |
+
mapped["thinking_span"] = [0, new_span[0] - 1] if new_span[0] > 0 else None
|
| 224 |
+
meta = mapped.get("metadata")
|
| 225 |
+
if not isinstance(meta, dict):
|
| 226 |
+
meta = {}
|
| 227 |
+
meta = dict(meta)
|
| 228 |
+
meta["exp5_span_map_method"] = "answer_text"
|
| 229 |
+
mapped["metadata"] = meta
|
| 230 |
+
return mapped, None
|
| 231 |
+
|
| 232 |
+
# 2) Char span -> new token span.
|
| 233 |
+
new_enc = new_tokenizer(target, add_special_tokens=False, return_offsets_mapping=True)
|
| 234 |
+
new_offsets = new_enc.get("offset_mapping")
|
| 235 |
+
new_span = _char_span_to_token_span(new_offsets, char_span)
|
| 236 |
+
if new_span is None:
|
| 237 |
+
if not allow_fallback_answer:
|
| 238 |
+
return None, "char_to_new_span_failed"
|
| 239 |
+
new_span = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
|
| 240 |
+
if new_span is None:
|
| 241 |
+
return None, "fallback_answer_failed"
|
| 242 |
+
|
| 243 |
+
if not _validate_span_with_eos(new_tokenizer, target, new_span):
|
| 244 |
+
if not allow_fallback_answer:
|
| 245 |
+
return None, "new_span_invalid"
|
| 246 |
+
fb = _fallback_map_via_answer_text(obj, new_tokenizer=new_tokenizer)
|
| 247 |
+
if fb is None or not _validate_span_with_eos(new_tokenizer, target, fb):
|
| 248 |
+
return None, "fallback_answer_span_invalid"
|
| 249 |
+
new_span = fb
|
| 250 |
+
|
| 251 |
+
mapped = dict(obj)
|
| 252 |
+
mapped["indices_to_explain"] = new_span
|
| 253 |
+
mapped["sink_span"] = new_span
|
| 254 |
+
mapped["thinking_span"] = [0, new_span[0] - 1] if new_span[0] > 0 else None
|
| 255 |
+
|
| 256 |
+
meta = mapped.get("metadata")
|
| 257 |
+
if not isinstance(meta, dict):
|
| 258 |
+
meta = {}
|
| 259 |
+
meta = dict(meta)
|
| 260 |
+
meta["exp5_span_map_method"] = "token_span_char_align"
|
| 261 |
+
mapped["metadata"] = meta
|
| 262 |
+
return mapped, None
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _read_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
|
| 266 |
+
with path.open("r", encoding="utf-8") as f:
|
| 267 |
+
for line_no, line in enumerate(f, start=1):
|
| 268 |
+
if not line.strip():
|
| 269 |
+
continue
|
| 270 |
+
try:
|
| 271 |
+
obj = json.loads(line)
|
| 272 |
+
except json.JSONDecodeError as exc: # pragma: no cover
|
| 273 |
+
raise RuntimeError(f"Invalid JSON at {path}:{line_no}: {exc}") from exc
|
| 274 |
+
if not isinstance(obj, dict):
|
| 275 |
+
raise RuntimeError(f"Expected JSON object per line at {path}:{line_no}.")
|
| 276 |
+
yield obj
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def _write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> int:
|
| 280 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 281 |
+
count = 0
|
| 282 |
+
with path.open("w", encoding="utf-8") as f:
|
| 283 |
+
for obj in rows:
|
| 284 |
+
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
| 285 |
+
count += 1
|
| 286 |
+
return count
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def _default_old_tokenizer() -> str:
|
| 290 |
+
# Repo defaults used in exp2 README examples for span extraction.
|
| 291 |
+
return "/opt/share/models/Qwen/Qwen3-8B"
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def _default_new_tokenizer() -> str:
|
| 295 |
+
return "/opt/share/models/meta-llama/Llama-3.1-8B-Instruct"
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def main() -> None:
|
| 299 |
+
ap = argparse.ArgumentParser("Map exp2 cache token spans from an old tokenizer to a new tokenizer.")
|
| 300 |
+
ap.add_argument(
|
| 301 |
+
"--in_jsonl",
|
| 302 |
+
type=str,
|
| 303 |
+
nargs="+",
|
| 304 |
+
required=True,
|
| 305 |
+
help="One or more exp2 cached JSONL files (comma-separated also accepted).",
|
| 306 |
+
)
|
| 307 |
+
ap.add_argument(
|
| 308 |
+
"--out_dir",
|
| 309 |
+
type=str,
|
| 310 |
+
default="exp/exp5/data",
|
| 311 |
+
help="Output directory for mapped JSONL files.",
|
| 312 |
+
)
|
| 313 |
+
ap.add_argument(
|
| 314 |
+
"--old_tokenizer_model",
|
| 315 |
+
type=str,
|
| 316 |
+
default=_default_old_tokenizer(),
|
| 317 |
+
help="Tokenizer used to produce the original token spans (default: Qwen3-8B local path).",
|
| 318 |
+
)
|
| 319 |
+
ap.add_argument(
|
| 320 |
+
"--new_tokenizer_model",
|
| 321 |
+
type=str,
|
| 322 |
+
default=_default_new_tokenizer(),
|
| 323 |
+
help="Tokenizer to map spans into (default: Llama-3.1-8B-Instruct local path).",
|
| 324 |
+
)
|
| 325 |
+
ap.add_argument("--strict", action="store_true", help="Fail on the first example that cannot be mapped.")
|
| 326 |
+
ap.add_argument(
|
| 327 |
+
"--allow_fallback_answer",
|
| 328 |
+
action="store_true",
|
| 329 |
+
help=(
|
| 330 |
+
"If span alignment fails, try to recompute spans by locating metadata.boxed_answer in target "
|
| 331 |
+
"(useful when caches were not built with the assumed old tokenizer)."
|
| 332 |
+
),
|
| 333 |
+
)
|
| 334 |
+
ap.add_argument(
|
| 335 |
+
"--overwrite",
|
| 336 |
+
action="store_true",
|
| 337 |
+
help="Overwrite output files if they already exist.",
|
| 338 |
+
)
|
| 339 |
+
args = ap.parse_args()
|
| 340 |
+
|
| 341 |
+
in_paths = [Path(p) for p in _split_args(args.in_jsonl)]
|
| 342 |
+
out_dir = Path(args.out_dir)
|
| 343 |
+
|
| 344 |
+
old_tok = _load_tokenizer(str(args.old_tokenizer_model))
|
| 345 |
+
new_tok = _load_tokenizer(str(args.new_tokenizer_model))
|
| 346 |
+
|
| 347 |
+
# exp2 convention: ensure a pad token exists for downstream perturbation.
|
| 348 |
+
if new_tok.pad_token is None and new_tok.eos_token is not None:
|
| 349 |
+
new_tok.pad_token = new_tok.eos_token
|
| 350 |
+
|
| 351 |
+
summary: Dict[str, Any] = {
|
| 352 |
+
"old_tokenizer_model": str(args.old_tokenizer_model),
|
| 353 |
+
"new_tokenizer_model": str(args.new_tokenizer_model),
|
| 354 |
+
"datasets": [],
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
for in_path in in_paths:
|
| 358 |
+
if not in_path.exists():
|
| 359 |
+
raise SystemExit(f"Missing input JSONL: {in_path}")
|
| 360 |
+
out_path = out_dir / in_path.name
|
| 361 |
+
if out_path.exists() and not bool(args.overwrite):
|
| 362 |
+
raise SystemExit(f"Refusing to overwrite existing output: {out_path} (use --overwrite)")
|
| 363 |
+
|
| 364 |
+
total = 0
|
| 365 |
+
mapped_ok = 0
|
| 366 |
+
dropped = 0
|
| 367 |
+
errors: Dict[str, int] = {}
|
| 368 |
+
|
| 369 |
+
mapped_rows: List[Dict[str, Any]] = []
|
| 370 |
+
for obj in _read_jsonl(in_path):
|
| 371 |
+
total += 1
|
| 372 |
+
mapped, err = _map_one_obj(
|
| 373 |
+
obj,
|
| 374 |
+
old_tokenizer=old_tok,
|
| 375 |
+
new_tokenizer=new_tok,
|
| 376 |
+
allow_fallback_answer=bool(args.allow_fallback_answer),
|
| 377 |
+
)
|
| 378 |
+
if err is not None or mapped is None:
|
| 379 |
+
errors[err or "unknown_error"] = errors.get(err or "unknown_error", 0) + 1
|
| 380 |
+
if bool(args.strict):
|
| 381 |
+
raise SystemExit(f"Failed to map {in_path} example #{total}: {err}")
|
| 382 |
+
dropped += 1
|
| 383 |
+
continue
|
| 384 |
+
mapped_ok += 1
|
| 385 |
+
mapped_rows.append(mapped)
|
| 386 |
+
|
| 387 |
+
written = _write_jsonl(out_path, mapped_rows)
|
| 388 |
+
if written != mapped_ok: # pragma: no cover
|
| 389 |
+
raise SystemExit(f"Internal error: written={written} != mapped_ok={mapped_ok}")
|
| 390 |
+
|
| 391 |
+
record = {
|
| 392 |
+
"in_jsonl": str(in_path),
|
| 393 |
+
"out_jsonl": str(out_path),
|
| 394 |
+
"total": int(total),
|
| 395 |
+
"mapped_ok": int(mapped_ok),
|
| 396 |
+
"dropped": int(dropped),
|
| 397 |
+
"errors": errors,
|
| 398 |
+
}
|
| 399 |
+
summary["datasets"].append(record)
|
| 400 |
+
print(json.dumps(record, ensure_ascii=False))
|
| 401 |
+
|
| 402 |
+
# Human-readable compact summary at end.
|
| 403 |
+
print(json.dumps(summary, ensure_ascii=False, indent=2))
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
if __name__ == "__main__":
|
| 407 |
+
main()
|
exp/proc/README.md
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# exp/proc(exp2 trace 映射/对外导出)
|
| 2 |
+
|
| 3 |
+
本目录提供把 `exp/exp2/run_exp.py --save_hop_traces` 产出的 trace 结果,整理成“给合作者使用”的精简样本级 `.npz` 的工具。
|
| 4 |
+
|
| 5 |
+
主要文件:
|
| 6 |
+
- `exp/proc/map_exp2_traces_to_proc.py`:读取 exp2 的 trace run 文件夹(`manifest.jsonl` + `ex_*.npz`),输出精简格式到 `exp/proc/output/`。
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## 输入要求
|
| 11 |
+
|
| 12 |
+
你需要提供(或可自动推断):
|
| 13 |
+
- `--trace_dir`:exp2 的 trace run 文件夹,例如:
|
| 14 |
+
- `exp/exp2/output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/ifr_all_positions_mfaithfulness_gen_95ex/`
|
| 15 |
+
- `--dataset_jsonl`:与该 trace run 对应的 exp2 缓存数据集(必须包含 `prompt` + `target`),例如:
|
| 16 |
+
- `exp/exp2/data/morehopqa.jsonl`
|
| 17 |
+
- `--tokenizer_model`:与 exp2 归因时一致的 tokenizer(本地路径或模型名),例如:
|
| 18 |
+
- `/opt/share/models/Qwen/Qwen3-8B/`
|
| 19 |
+
|
| 20 |
+
注意:
|
| 21 |
+
- 本脚本会严格复刻 exp2 的 token 对齐逻辑(prompt 前导空格、generation 用 `target + eos_token` 再 decode + offset 切片),因此 tokenizer 必须与 exp2 归因一致,否则会直接报错(长度对不上)。
|
| 22 |
+
- 样本匹配使用 `manifest.jsonl` 中的 `prompt_sha1/target_sha1` 对齐 `--dataset_jsonl`;所以 `--dataset_jsonl` 必须是当次 trace run 使用的那份缓存。
|
| 23 |
+
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
## 输出位置与命名
|
| 27 |
+
|
| 28 |
+
默认输出到:
|
| 29 |
+
- `exp/proc/output/<trace_dir 在 traces/ 之后的同构路径>/`
|
| 30 |
+
|
| 31 |
+
例如输入:
|
| 32 |
+
- `.../output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/<run_tag>/`
|
| 33 |
+
|
| 34 |
+
默认输出:
|
| 35 |
+
- `exp/proc/output/exp/exp2/data/morehopqa.jsonl/qwen-8B/<run_tag>/`
|
| 36 |
+
|
| 37 |
+
你也可以用 `--out_dir` 显式指定输出目录。
|
| 38 |
+
|
| 39 |
+
输出目录内每个样本一个文件:`ex_000000.npz`、`ex_000001.npz` …
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
## 输出 `.npz` 字段(精简且仅包含必要信息)
|
| 44 |
+
|
| 45 |
+
每个输出样本 `.npz` **仅包含**下列键:
|
| 46 |
+
- `attr`:`float32[L]`,row 归因向量;已去掉 chat template,且去掉 EOS,仅覆盖 `input+cot+output` 的有效 token。
|
| 47 |
+
- `hop`:`float32[H, L]`(可选,仅 FT-IFR 类方法),逐 hop 的向量;同样已去掉 EOS,并与 `attr` 等长对齐。
|
| 48 |
+
- `tok`:`U[L]`,与 `attr/hop` 严格对齐的 token 文本片段序列(同样不含 chat template 与 EOS)。
|
| 49 |
+
- `span_in`:`int64[2]`,input 在向量中的闭区间范围。
|
| 50 |
+
- `span_cot`:`int64[2]`,cot 在向量中的闭区间范围(无 cot 时为 `[-1, -1]`)。
|
| 51 |
+
- `span_out`:`int64[2]`,output 在向量中的闭区间范围。
|
| 52 |
+
- `rise`:`float64`,row 的 RISE(faithfulness)。
|
| 53 |
+
- `mas`:`float64`,row 的 MAS(faithfulness)。
|
| 54 |
+
- `recovery`:`float64`,row 的 Recovery@10%(没有 recovery 时为 NaN)。
|
| 55 |
+
|
| 56 |
+
---
|
| 57 |
+
|
| 58 |
+
## 用法示例
|
| 59 |
+
|
| 60 |
+
最常用(建议显式传入 dataset 与 tokenizer):
|
| 61 |
+
```bash
|
| 62 |
+
python exp/proc/map_exp2_traces_to_proc.py \
|
| 63 |
+
--trace_dir exp/exp2/output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/ifr_all_positions_mfaithfulness_gen_95ex \
|
| 64 |
+
--dataset_jsonl exp/exp2/data/morehopqa.jsonl \
|
| 65 |
+
--tokenizer_model /opt/share/models/Qwen/Qwen3-8B/
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
显式指定输出目录(避免默认同构路径):
|
| 69 |
+
```bash
|
| 70 |
+
python exp/proc/map_exp2_traces_to_proc.py \
|
| 71 |
+
--trace_dir exp/exp2/output/traces/exp/exp2/data/math.jsonl/qwen-8B/ifr_multi_hop_both_n1_mfaithfulness_gen_100ex/ \
|
| 72 |
+
--dataset_jsonl exp/exp2/data/math.jsonl \
|
| 73 |
+
--tokenizer_model /opt/share/models/Qwen/Qwen3-8B/ \
|
| 74 |
+
--out_dir exp/proc/output/math_ifr_multi_hop_both
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
调试:只处理前 5 条、允许覆盖输出文件:
|
| 78 |
+
```bash
|
| 79 |
+
python exp/proc/map_exp2_traces_to_proc.py \
|
| 80 |
+
--trace_dir ... \
|
| 81 |
+
--dataset_jsonl ... \
|
| 82 |
+
--tokenizer_model ... \
|
| 83 |
+
--limit 5 \
|
| 84 |
+
--overwrite
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
## 常见问题
|
| 90 |
+
|
| 91 |
+
- 报错 “Prompt/Generation token length mismatch”
|
| 92 |
+
- 几乎总是 tokenizer 不一致;请确认 `--tokenizer_model` 与 exp2 归因时使用的 tokenizer 完全一致(建议直接用同一个 `--model_path`)。
|
| 93 |
+
- 报错 “Failed to match manifest sha1 to dataset_jsonl”
|
| 94 |
+
- `--dataset_jsonl` 不是当次 trace run 使用的缓存,或缓存里没有 `target`。
|
| 95 |
+
- FT-IFR 方法输出缺 `hop`
|
| 96 |
+
- 对 `ifr_multi_hop_stop_words/ifr_multi_hop_both/ifr_multi_hop_split_hop/ifr_in_all_gen`,exp2 trace 必须包含 `vh`;若 trace 较旧请重新跑 exp2(带 `--save_hop_traces`)。
|
| 97 |
+
- 如确有需要可加 `--allow_missing_ft_hops` 强行输出(不推荐)。
|
| 98 |
+
|
exp/proc/map_exp2_traces_to_proc.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Map exp2 trace artifacts into a collaborator-friendly per-sample NPZ format.
|
| 3 |
+
|
| 4 |
+
Input: an exp2 trace run directory produced by `exp/exp2/run_exp.py --save_hop_traces`,
|
| 5 |
+
e.g.:
|
| 6 |
+
|
| 7 |
+
exp/exp2/output/traces/exp/exp2/data/morehopqa.jsonl/qwen-8B/ifr_all_positions_mfaithfulness_gen_95ex/
|
| 8 |
+
|
| 9 |
+
This directory contains:
|
| 10 |
+
- manifest.jsonl (one JSON object per sample)
|
| 11 |
+
- ex_*.npz (per-sample vectors and scores)
|
| 12 |
+
|
| 13 |
+
Output: per-sample NPZ files under `exp/proc/output/` (or a user-provided output path),
|
| 14 |
+
each containing only:
|
| 15 |
+
- attr: row attribution vector over [input + CoT + output] tokens, with chat template and EOS removed
|
| 16 |
+
- hop: per-hop vectors (FT-IFR only), aligned to attr (optional)
|
| 17 |
+
- tok: tokenized text pieces aligned to attr/hop (no chat template, no EOS)
|
| 18 |
+
- span_in/span_cot/span_out: inclusive ranges for input/CoT/output in the above vectors
|
| 19 |
+
- rise/mas: row faithfulness scores (RISE, MAS)
|
| 20 |
+
- recovery: row Recovery@10% score (NaN when unavailable)
|
| 21 |
+
|
| 22 |
+
This script is intentionally self-contained under exp/proc/ and does not modify exp2.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import hashlib
|
| 29 |
+
import json
|
| 30 |
+
from dataclasses import dataclass
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Dict, List, Optional, Tuple
|
| 33 |
+
|
| 34 |
+
import numpy as np
|
| 35 |
+
from transformers import AutoTokenizer
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
FT_IFR_ATTR_FUNCS: set[str] = {
|
| 39 |
+
"ifr_in_all_gen",
|
| 40 |
+
"ifr_multi_hop_stop_words",
|
| 41 |
+
"ifr_multi_hop_both",
|
| 42 |
+
"ifr_multi_hop_split_hop",
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _sha1_text(text: str) -> str:
|
| 47 |
+
return hashlib.sha1(text.encode("utf-8")).hexdigest()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _load_tokenizer(tokenizer_model: str):
|
| 51 |
+
tok_path = Path(tokenizer_model)
|
| 52 |
+
if tok_path.exists():
|
| 53 |
+
tokenizer = AutoTokenizer.from_pretrained(tok_path.as_posix(), local_files_only=True)
|
| 54 |
+
else:
|
| 55 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
|
| 56 |
+
if tokenizer.eos_token is None:
|
| 57 |
+
raise SystemExit("Tokenizer is missing eos_token; cannot match exp2 generation tokenization.")
|
| 58 |
+
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 59 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 60 |
+
return tokenizer
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _decode_text_into_tokens(tokenizer, text: str) -> List[str]:
|
| 64 |
+
"""Mirror llm_attr.LLMAttribution.decode_text_into_tokens (offset-slice tokens)."""
|
| 65 |
+
enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
|
| 66 |
+
ids = enc.get("input_ids")
|
| 67 |
+
offsets = enc.get("offset_mapping")
|
| 68 |
+
if ids is None or offsets is None:
|
| 69 |
+
raise ValueError("Tokenizer must provide input_ids and offset_mapping for exact exp2 token alignment.")
|
| 70 |
+
if len(ids) != len(offsets):
|
| 71 |
+
raise ValueError("Tokenizer returned mismatched input_ids vs offset_mapping lengths.")
|
| 72 |
+
tokens: List[str] = []
|
| 73 |
+
for start, end in offsets:
|
| 74 |
+
tokens.append(text[int(start) : int(end)])
|
| 75 |
+
return tokens
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dataclass(frozen=True)
|
| 79 |
+
class DatasetEntry:
|
| 80 |
+
prompt: str
|
| 81 |
+
target: str
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _index_dataset_by_sha1(dataset_jsonl: Path) -> Dict[Tuple[str, str], DatasetEntry]:
|
| 85 |
+
"""Build (prompt_sha1, target_sha1) -> (prompt, target) for cache lookup."""
|
| 86 |
+
index: Dict[Tuple[str, str], DatasetEntry] = {}
|
| 87 |
+
collisions: Dict[Tuple[str, str], int] = {}
|
| 88 |
+
|
| 89 |
+
with dataset_jsonl.open("r", encoding="utf-8") as f:
|
| 90 |
+
for line_num, line in enumerate(f, start=1):
|
| 91 |
+
if not line.strip():
|
| 92 |
+
continue
|
| 93 |
+
obj = json.loads(line)
|
| 94 |
+
prompt = str(obj.get("prompt") or "")
|
| 95 |
+
target = obj.get("target")
|
| 96 |
+
if target is None:
|
| 97 |
+
# exp2 trace matching requires cached targets.
|
| 98 |
+
continue
|
| 99 |
+
target = str(target)
|
| 100 |
+
|
| 101 |
+
key = (_sha1_text(prompt), _sha1_text(target))
|
| 102 |
+
if key in index:
|
| 103 |
+
collisions[key] = collisions.get(key, 1) + 1
|
| 104 |
+
continue
|
| 105 |
+
index[key] = DatasetEntry(prompt=prompt, target=target)
|
| 106 |
+
|
| 107 |
+
if collisions:
|
| 108 |
+
raise SystemExit(
|
| 109 |
+
"Dataset cache contains duplicate (prompt,target) pairs; cannot uniquely match by sha1. "
|
| 110 |
+
f"Example collision count={next(iter(collisions.values()))}. "
|
| 111 |
+
f"dataset_jsonl={dataset_jsonl}"
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
if not index:
|
| 115 |
+
raise SystemExit(
|
| 116 |
+
"No usable (prompt,target) pairs found in dataset cache. "
|
| 117 |
+
"Ensure you pass the exp2 cached JSONL used for attribution (with target filled)."
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
return index
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _infer_trace_suffix(trace_dir: Path) -> Optional[Path]:
|
| 124 |
+
parts = list(trace_dir.parts)
|
| 125 |
+
if "traces" not in parts:
|
| 126 |
+
return None
|
| 127 |
+
idx = parts.index("traces")
|
| 128 |
+
suffix_parts = parts[idx + 1 :]
|
| 129 |
+
if not suffix_parts:
|
| 130 |
+
return None
|
| 131 |
+
return Path(*suffix_parts)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _parse_manifest(manifest_path: Path) -> List[dict]:
|
| 135 |
+
records: List[dict] = []
|
| 136 |
+
with manifest_path.open("r", encoding="utf-8") as f:
|
| 137 |
+
for line in f:
|
| 138 |
+
if not line.strip():
|
| 139 |
+
continue
|
| 140 |
+
records.append(json.loads(line))
|
| 141 |
+
if not records:
|
| 142 |
+
raise SystemExit(f"Empty manifest.jsonl: {manifest_path}")
|
| 143 |
+
return records
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _read_span(npz: np.lib.npyio.NpzFile, key: str) -> Optional[Tuple[int, int]]:
|
| 147 |
+
if key not in npz.files:
|
| 148 |
+
return None
|
| 149 |
+
arr = npz[key]
|
| 150 |
+
if arr.shape != (2,):
|
| 151 |
+
raise ValueError(f"Expected {key} to have shape (2,), got {arr.shape}.")
|
| 152 |
+
return int(arr[0]), int(arr[1])
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _span_or_empty(span: Optional[Tuple[int, int]]) -> Tuple[int, int]:
|
| 156 |
+
if span is None:
|
| 157 |
+
return -1, -1
|
| 158 |
+
return int(span[0]), int(span[1])
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _tokenize_for_exp2_alignment(
|
| 162 |
+
tokenizer,
|
| 163 |
+
*,
|
| 164 |
+
prompt: str,
|
| 165 |
+
target: str,
|
| 166 |
+
expected_prompt_len: int,
|
| 167 |
+
expected_gen_len: int,
|
| 168 |
+
) -> List[str]:
|
| 169 |
+
prompt_text = " " + (prompt or "")
|
| 170 |
+
prompt_tokens = _decode_text_into_tokens(tokenizer, prompt_text)
|
| 171 |
+
if len(prompt_tokens) != int(expected_prompt_len):
|
| 172 |
+
raise ValueError(f"Prompt token length mismatch: expected {expected_prompt_len}, got {len(prompt_tokens)}.")
|
| 173 |
+
|
| 174 |
+
gen_ids = tokenizer(target + tokenizer.eos_token, add_special_tokens=False).input_ids
|
| 175 |
+
gen_text = tokenizer.decode(gen_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
|
| 176 |
+
gen_tokens = _decode_text_into_tokens(tokenizer, gen_text)
|
| 177 |
+
if len(gen_tokens) != int(expected_gen_len):
|
| 178 |
+
raise ValueError(f"Generation token length mismatch: expected {expected_gen_len}, got {len(gen_tokens)}.")
|
| 179 |
+
|
| 180 |
+
gen_tokens_no_eos = gen_tokens[:-1] if gen_tokens else []
|
| 181 |
+
return prompt_tokens + gen_tokens_no_eos
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _clamp_span(span: Optional[Tuple[int, int]], *, max_index: int) -> Optional[Tuple[int, int]]:
|
| 185 |
+
if span is None:
|
| 186 |
+
return None
|
| 187 |
+
start, end = int(span[0]), int(span[1])
|
| 188 |
+
if max_index < 0:
|
| 189 |
+
return None
|
| 190 |
+
if end < 0 or start > max_index:
|
| 191 |
+
return None
|
| 192 |
+
start = max(0, start)
|
| 193 |
+
end = min(max_index, end)
|
| 194 |
+
if end < start:
|
| 195 |
+
return None
|
| 196 |
+
return start, end
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _proc_one(
|
| 200 |
+
*,
|
| 201 |
+
trace_npz_path: Path,
|
| 202 |
+
record: dict,
|
| 203 |
+
dataset_index: Dict[Tuple[str, str], DatasetEntry],
|
| 204 |
+
tokenizer,
|
| 205 |
+
out_path: Path,
|
| 206 |
+
overwrite: bool,
|
| 207 |
+
allow_missing_ft_hops: bool,
|
| 208 |
+
) -> None:
|
| 209 |
+
prompt_sha1 = str(record.get("prompt_sha1") or "")
|
| 210 |
+
target_sha1 = str(record.get("target_sha1") or "")
|
| 211 |
+
if not prompt_sha1 or not target_sha1:
|
| 212 |
+
raise ValueError("manifest record missing prompt_sha1/target_sha1; cannot match dataset.")
|
| 213 |
+
|
| 214 |
+
entry = dataset_index.get((prompt_sha1, target_sha1))
|
| 215 |
+
if entry is None:
|
| 216 |
+
raise ValueError(
|
| 217 |
+
"Failed to match manifest sha1 to dataset_jsonl. "
|
| 218 |
+
"Ensure --dataset_jsonl points to the exact cached JSONL used for this trace run."
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
if out_path.exists() and not overwrite:
|
| 222 |
+
raise FileExistsError(f"Refusing to overwrite existing file: {out_path} (use --overwrite).")
|
| 223 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 224 |
+
|
| 225 |
+
with np.load(trace_npz_path, allow_pickle=False) as f:
|
| 226 |
+
prompt_len = int(np.asarray(f.get("prompt_len")).item())
|
| 227 |
+
gen_len = int(np.asarray(f.get("gen_len")).item())
|
| 228 |
+
total_len = prompt_len + gen_len
|
| 229 |
+
gen_no_eos = max(0, gen_len - 1)
|
| 230 |
+
L = prompt_len + gen_no_eos
|
| 231 |
+
|
| 232 |
+
v_row_all = f.get("v_row_all")
|
| 233 |
+
if v_row_all is None:
|
| 234 |
+
raise ValueError("Missing v_row_all in trace npz; cannot build row attribution vector.")
|
| 235 |
+
v_row_all = np.asarray(v_row_all, dtype=np.float32)
|
| 236 |
+
if v_row_all.ndim != 1 or int(v_row_all.shape[0]) != int(total_len):
|
| 237 |
+
raise ValueError(f"v_row_all shape mismatch: expected ({total_len},), got {tuple(v_row_all.shape)}.")
|
| 238 |
+
attr = v_row_all[:L]
|
| 239 |
+
|
| 240 |
+
indices_to_explain = _read_span(f, "indices_to_explain_gen")
|
| 241 |
+
sink_span_gen = _read_span(f, "sink_span_gen") or indices_to_explain
|
| 242 |
+
if sink_span_gen is None:
|
| 243 |
+
raise ValueError("Missing sink_span_gen/indices_to_explain_gen; cannot define output span.")
|
| 244 |
+
thinking_span_gen = _read_span(f, "thinking_span_gen")
|
| 245 |
+
if thinking_span_gen is None:
|
| 246 |
+
sink_start = int(sink_span_gen[0])
|
| 247 |
+
think_end = sink_start - 1
|
| 248 |
+
thinking_span_gen = (0, think_end) if think_end >= 0 else None
|
| 249 |
+
|
| 250 |
+
sink_span_gen = _clamp_span(sink_span_gen, max_index=gen_no_eos - 1)
|
| 251 |
+
thinking_span_gen = _clamp_span(thinking_span_gen, max_index=gen_no_eos - 1)
|
| 252 |
+
|
| 253 |
+
span_in = (0, prompt_len - 1) if prompt_len > 0 else (-1, -1)
|
| 254 |
+
span_cot = (
|
| 255 |
+
(prompt_len + thinking_span_gen[0], prompt_len + thinking_span_gen[1])
|
| 256 |
+
if thinking_span_gen is not None
|
| 257 |
+
else (-1, -1)
|
| 258 |
+
)
|
| 259 |
+
span_out = (
|
| 260 |
+
(prompt_len + sink_span_gen[0], prompt_len + sink_span_gen[1]) if sink_span_gen is not None else (-1, -1)
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
tokens = _tokenize_for_exp2_alignment(
|
| 264 |
+
tokenizer,
|
| 265 |
+
prompt=entry.prompt,
|
| 266 |
+
target=entry.target,
|
| 267 |
+
expected_prompt_len=prompt_len,
|
| 268 |
+
expected_gen_len=gen_len,
|
| 269 |
+
)
|
| 270 |
+
if len(tokens) != int(L):
|
| 271 |
+
raise ValueError(f"Token length mismatch after EOS drop: expected {L}, got {len(tokens)}.")
|
| 272 |
+
|
| 273 |
+
# Scores: row = index 1.
|
| 274 |
+
rise = float("nan")
|
| 275 |
+
mas = float("nan")
|
| 276 |
+
faith = f.get("faithfulness_scores")
|
| 277 |
+
if faith is not None:
|
| 278 |
+
faith = np.asarray(faith, dtype=np.float64)
|
| 279 |
+
if faith.shape != (3, 3):
|
| 280 |
+
raise ValueError(f"faithfulness_scores shape mismatch: expected (3,3), got {tuple(faith.shape)}.")
|
| 281 |
+
rise = float(faith[1, 0])
|
| 282 |
+
mas = float(faith[1, 1])
|
| 283 |
+
|
| 284 |
+
recovery = float("nan")
|
| 285 |
+
rec = f.get("recovery_scores")
|
| 286 |
+
if rec is not None:
|
| 287 |
+
rec = np.asarray(rec, dtype=np.float64)
|
| 288 |
+
if rec.shape != (3,):
|
| 289 |
+
raise ValueError(f"recovery_scores shape mismatch: expected (3,), got {tuple(rec.shape)}.")
|
| 290 |
+
recovery = float(rec[1])
|
| 291 |
+
|
| 292 |
+
out_payload = {
|
| 293 |
+
"attr": np.asarray(attr, dtype=np.float32),
|
| 294 |
+
"tok": np.asarray(tokens, dtype=np.str_),
|
| 295 |
+
"span_in": np.asarray(span_in, dtype=np.int64),
|
| 296 |
+
"span_cot": np.asarray(span_cot, dtype=np.int64),
|
| 297 |
+
"span_out": np.asarray(span_out, dtype=np.int64),
|
| 298 |
+
"rise": np.asarray(rise, dtype=np.float64),
|
| 299 |
+
"mas": np.asarray(mas, dtype=np.float64),
|
| 300 |
+
"recovery": np.asarray(recovery, dtype=np.float64),
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
attr_func = str(record.get("attr_func") or "")
|
| 304 |
+
want_hops = attr_func in FT_IFR_ATTR_FUNCS
|
| 305 |
+
if want_hops:
|
| 306 |
+
vh = f.get("vh")
|
| 307 |
+
if vh is None:
|
| 308 |
+
if not allow_missing_ft_hops:
|
| 309 |
+
raise ValueError(
|
| 310 |
+
f"FT-IFR method '{attr_func}' requires per-hop vectors but trace npz is missing 'vh'. "
|
| 311 |
+
"Re-run exp2 with --save_hop_traces using the updated code."
|
| 312 |
+
)
|
| 313 |
+
else:
|
| 314 |
+
vh = np.asarray(vh, dtype=np.float32)
|
| 315 |
+
if vh.ndim != 2 or int(vh.shape[1]) != int(total_len):
|
| 316 |
+
raise ValueError(
|
| 317 |
+
f"vh shape mismatch: expected (H,{total_len}), got {tuple(vh.shape)} for {trace_npz_path}."
|
| 318 |
+
)
|
| 319 |
+
out_payload["hop"] = vh[:, :L]
|
| 320 |
+
|
| 321 |
+
np.savez_compressed(out_path, **out_payload)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def main() -> None:
|
| 325 |
+
ap = argparse.ArgumentParser("Map exp2 trace folder -> exp/proc/output per-sample npz files.")
|
| 326 |
+
ap.add_argument("--trace_dir", type=str, required=True, help="Path to an exp2 trace run directory (contains manifest.jsonl).")
|
| 327 |
+
ap.add_argument("--dataset_jsonl", type=str, default=None, help="Path to the exp2 cached dataset JSONL used for this trace.")
|
| 328 |
+
ap.add_argument(
|
| 329 |
+
"--tokenizer_model",
|
| 330 |
+
type=str,
|
| 331 |
+
required=True,
|
| 332 |
+
help="Tokenizer model name or local path (must match exp2 attribution tokenizer).",
|
| 333 |
+
)
|
| 334 |
+
ap.add_argument("--out_root", type=str, default="exp/proc/output", help="Root directory for proc outputs.")
|
| 335 |
+
ap.add_argument("--out_dir", type=str, default=None, help="Optional explicit output directory (overrides --out_root).")
|
| 336 |
+
ap.add_argument("--overwrite", action="store_true", help="Overwrite existing output files if present.")
|
| 337 |
+
ap.add_argument("--limit", type=int, default=None, help="Optional limit on number of samples to process (debug).")
|
| 338 |
+
ap.add_argument(
|
| 339 |
+
"--allow_missing_ft_hops",
|
| 340 |
+
action="store_true",
|
| 341 |
+
help="Allow producing FT-IFR outputs even when per-hop vectors (vh) are missing (not recommended).",
|
| 342 |
+
)
|
| 343 |
+
args = ap.parse_args()
|
| 344 |
+
|
| 345 |
+
trace_dir = Path(args.trace_dir)
|
| 346 |
+
if not trace_dir.exists() or not trace_dir.is_dir():
|
| 347 |
+
raise SystemExit(f"Missing trace_dir: {trace_dir}")
|
| 348 |
+
manifest_path = trace_dir / "manifest.jsonl"
|
| 349 |
+
if not manifest_path.exists():
|
| 350 |
+
raise SystemExit(f"Missing manifest.jsonl: {manifest_path}")
|
| 351 |
+
|
| 352 |
+
dataset_jsonl: Optional[Path] = Path(args.dataset_jsonl) if args.dataset_jsonl else None
|
| 353 |
+
if dataset_jsonl is None:
|
| 354 |
+
suffix = _infer_trace_suffix(trace_dir)
|
| 355 |
+
if suffix is not None and len(suffix.parts) >= 3:
|
| 356 |
+
# suffix = <dataset_name...>/<model_tag>/<run_tag>
|
| 357 |
+
inferred_dataset = Path(*suffix.parts[:-2])
|
| 358 |
+
if inferred_dataset.exists() and inferred_dataset.is_file():
|
| 359 |
+
dataset_jsonl = inferred_dataset
|
| 360 |
+
if dataset_jsonl is None:
|
| 361 |
+
raise SystemExit("Please pass --dataset_jsonl (could not infer it from --trace_dir).")
|
| 362 |
+
if not dataset_jsonl.exists():
|
| 363 |
+
raise SystemExit(f"Missing --dataset_jsonl: {dataset_jsonl}")
|
| 364 |
+
|
| 365 |
+
tokenizer = _load_tokenizer(str(args.tokenizer_model))
|
| 366 |
+
dataset_index = _index_dataset_by_sha1(dataset_jsonl)
|
| 367 |
+
records = _parse_manifest(manifest_path)
|
| 368 |
+
|
| 369 |
+
if args.out_dir:
|
| 370 |
+
out_dir = Path(args.out_dir)
|
| 371 |
+
else:
|
| 372 |
+
suffix = _infer_trace_suffix(trace_dir)
|
| 373 |
+
out_dir = Path(args.out_root) / suffix if suffix is not None else Path(args.out_root) / trace_dir.name
|
| 374 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 375 |
+
|
| 376 |
+
total = len(records)
|
| 377 |
+
limit = args.limit
|
| 378 |
+
if limit is not None:
|
| 379 |
+
if limit <= 0:
|
| 380 |
+
raise SystemExit("--limit must be a positive integer.")
|
| 381 |
+
total = min(total, int(limit))
|
| 382 |
+
|
| 383 |
+
processed = 0
|
| 384 |
+
for record in records[:total]:
|
| 385 |
+
file_name = str(record.get("file") or "")
|
| 386 |
+
if not file_name:
|
| 387 |
+
raise SystemExit("manifest record missing 'file' field.")
|
| 388 |
+
trace_npz_path = trace_dir / file_name
|
| 389 |
+
if not trace_npz_path.exists():
|
| 390 |
+
raise SystemExit(f"Missing trace npz referenced by manifest: {trace_npz_path}")
|
| 391 |
+
|
| 392 |
+
out_path = out_dir / file_name
|
| 393 |
+
try:
|
| 394 |
+
_proc_one(
|
| 395 |
+
trace_npz_path=trace_npz_path,
|
| 396 |
+
record=record,
|
| 397 |
+
dataset_index=dataset_index,
|
| 398 |
+
tokenizer=tokenizer,
|
| 399 |
+
out_path=out_path,
|
| 400 |
+
overwrite=bool(args.overwrite),
|
| 401 |
+
allow_missing_ft_hops=bool(args.allow_missing_ft_hops),
|
| 402 |
+
)
|
| 403 |
+
except Exception as exc:
|
| 404 |
+
raise SystemExit(f"Failed processing {trace_npz_path}: {exc}") from exc
|
| 405 |
+
processed += 1
|
| 406 |
+
|
| 407 |
+
print(f"Wrote {processed} proc samples -> {out_dir}")
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
if __name__ == "__main__":
|
| 411 |
+
main()
|
exp/proc_1/README.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# exp/proc_1(exp2 trace 映射/对外导出 v1)
|
| 2 |
+
|
| 3 |
+
本目录提供把 `exp/exp2/run_exp.py --save_hop_traces` 产出的 trace 结果,整理成“给合作者使用”的精简样本级 `.npz` 的工具(v1)。
|
| 4 |
+
|
| 5 |
+
与 `exp/proc/` 的区别:
|
| 6 |
+
- 去掉 `tok`(逐 token 文本片段)。
|
| 7 |
+
- 新增 `length`(三段 token 长度):`[in, cot, out]`,并保证与 `span_in/span_cot/span_out` 对齐。
|
| 8 |
+
- `hop` 字段采用“默认策略”:当 trace 样本中存在 `vh` 时才输出 `hop`;否则不输出且不报错。
|
| 9 |
+
- 支持一次性处理 `exp/exp2/output/traces/` 下所有 run 目录(所有数据集-方法组合)。
|
| 10 |
+
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
## 输入结构(exp2 traces)
|
| 14 |
+
|
| 15 |
+
`exp2` 的 trace run 目录形如:
|
| 16 |
+
- `exp/exp2/output/traces/<dataset>/<model>/<run_tag>/`
|
| 17 |
+
|
| 18 |
+
每个 run 目录包含:
|
| 19 |
+
- `manifest.jsonl`(每行一个样本记录,包含 `file=ex_*.npz`)
|
| 20 |
+
- `ex_*.npz`(每样本一个 npz)
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## 输出位置与命名
|
| 25 |
+
|
| 26 |
+
默认输出到:
|
| 27 |
+
- `exp/proc_1/output/<trace_dir 在 traces/ 之后的同构路径>/`
|
| 28 |
+
|
| 29 |
+
例如输入:
|
| 30 |
+
- `.../output/traces/exp/exp2/data/math.jsonl/qwen-8B/<run_tag>/`
|
| 31 |
+
|
| 32 |
+
默认输出:
|
| 33 |
+
- `exp/proc_1/output/exp/exp2/data/math.jsonl/qwen-8B/<run_tag>/`
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
## 输出 `.npz` 字段
|
| 38 |
+
|
| 39 |
+
每个输出样本 `.npz` 仅包含下列键:
|
| 40 |
+
- `attr`:`float32[L]`,row 归因向量;覆盖 `input+cot+output` 的有效 token(移除 generation 末尾 EOS)。
|
| 41 |
+
- `hop`:`float32[H, L]`(可选),当 trace npz 中存在 `vh` 时输出(同样移除 EOS,并与 `attr` 等长对齐)。
|
| 42 |
+
- `span_in`:`int64[2]`,input 在向量中的闭区间范围。
|
| 43 |
+
- `span_cot`:`int64[2]`,cot 在向量中的闭区间范围(无 cot 时为 `[-1, -1]`)。
|
| 44 |
+
- `span_out`:`int64[2]`,output 在向量中的闭区间范围。
|
| 45 |
+
- `length`:`int64[3]`,顺序为 `[in, cot, out]`,长度与 `span_*` 严格对应(闭区间长度 `end-start+1`,空 span 长度为 0)。
|
| 46 |
+
- `rise`:`float64`,row 的 RISE(faithfulness)。
|
| 47 |
+
- `mas`:`float64`,row 的 MAS(faithfulness)。
|
| 48 |
+
- `recovery`:`float64`,row 的 Recovery@10%(没有 recovery 时为 NaN)。
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## 用法示例
|
| 53 |
+
|
| 54 |
+
处理 traces 下所有 run(推荐):
|
| 55 |
+
```bash
|
| 56 |
+
python exp/proc_1/map_exp2_traces_to_proc_1.py \
|
| 57 |
+
--traces_root exp/exp2/output/traces
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
只处理某一个 run 目录:
|
| 61 |
+
```bash
|
| 62 |
+
python exp/proc_1/map_exp2_traces_to_proc_1.py \
|
| 63 |
+
--trace_dir exp/exp2/output/traces/exp/exp2/data/math.jsonl/qwen-8B/ifr_multi_hop_both_n1_mfaithfulness_gen_100ex
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
调试:每个 run 只处理前 5 条、允许覆盖输出:
|
| 67 |
+
```bash
|
| 68 |
+
python exp/proc_1/map_exp2_traces_to_proc_1.py \
|
| 69 |
+
--traces_root exp/exp2/output/traces \
|
| 70 |
+
--limit 5 \
|
| 71 |
+
--overwrite
|
| 72 |
+
```
|
exp/proc_1/map_exp2_traces_to_proc_1.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Map exp2 trace artifacts into a collaborator-friendly per-sample NPZ format (proc_1).
|
| 3 |
+
|
| 4 |
+
This is a lightweight variant of `exp/proc/map_exp2_traces_to_proc.py`:
|
| 5 |
+
- Removes `tok` (per-token text pieces).
|
| 6 |
+
- Adds `length` with three components [in, cot, out], aligned to span_in/span_cot/span_out.
|
| 7 |
+
- Saves `hop` only when the trace sample contains `vh` (default strategy).
|
| 8 |
+
- Can process a single exp2 trace run directory or all run directories under a traces root.
|
| 9 |
+
|
| 10 |
+
Input: an exp2 trace run directory produced by `exp/exp2/run_exp.py --save_hop_traces`, e.g.:
|
| 11 |
+
|
| 12 |
+
exp/exp2/output/traces/exp/exp2/data/math.jsonl/qwen-8B/ifr_multi_hop_both_n1_mfaithfulness_gen_100ex/
|
| 13 |
+
|
| 14 |
+
This directory contains:
|
| 15 |
+
- manifest.jsonl (one JSON object per sample)
|
| 16 |
+
- ex_*.npz (per-sample vectors and scores)
|
| 17 |
+
|
| 18 |
+
Output: per-sample NPZ files under `exp/proc_1/output/` (or a user-provided output path),
|
| 19 |
+
each containing only:
|
| 20 |
+
- attr: row attribution vector over [input + CoT + output] tokens, with EOS removed
|
| 21 |
+
- hop: per-hop vectors (optional; only if `vh` exists in the trace npz), aligned to attr
|
| 22 |
+
- span_in/span_cot/span_out: inclusive ranges for input/CoT/output in the above vectors
|
| 23 |
+
- length: int64[3] = [in, cot, out], derived strictly from spans
|
| 24 |
+
- rise/mas: row faithfulness scores (RISE, MAS)
|
| 25 |
+
- recovery: row Recovery@10% score (NaN when unavailable)
|
| 26 |
+
|
| 27 |
+
This script is intentionally self-contained under exp/proc_1/ and does not modify exp2.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
from __future__ import annotations
|
| 31 |
+
|
| 32 |
+
import argparse
|
| 33 |
+
import json
|
| 34 |
+
from dataclasses import dataclass
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
from typing import Iterable, List, Optional, Tuple
|
| 37 |
+
|
| 38 |
+
import numpy as np
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _infer_trace_suffix(trace_dir: Path) -> Optional[Path]:
|
| 42 |
+
parts = list(trace_dir.parts)
|
| 43 |
+
if "traces" not in parts:
|
| 44 |
+
return None
|
| 45 |
+
idx = parts.index("traces")
|
| 46 |
+
suffix_parts = parts[idx + 1 :]
|
| 47 |
+
if not suffix_parts:
|
| 48 |
+
return None
|
| 49 |
+
return Path(*suffix_parts)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _iter_run_dirs(traces_root: Path) -> List[Path]:
|
| 53 |
+
runs = {p.parent for p in traces_root.rglob("manifest.jsonl") if p.is_file()}
|
| 54 |
+
return sorted(runs)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _parse_manifest(manifest_path: Path) -> List[dict]:
|
| 58 |
+
records: List[dict] = []
|
| 59 |
+
with manifest_path.open("r", encoding="utf-8") as f:
|
| 60 |
+
for line in f:
|
| 61 |
+
if not line.strip():
|
| 62 |
+
continue
|
| 63 |
+
records.append(json.loads(line))
|
| 64 |
+
return records
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _read_span(npz: np.lib.npyio.NpzFile, key: str) -> Optional[Tuple[int, int]]:
|
| 68 |
+
if key not in npz.files:
|
| 69 |
+
return None
|
| 70 |
+
arr = npz[key]
|
| 71 |
+
if arr.shape != (2,):
|
| 72 |
+
raise ValueError(f"Expected {key} to have shape (2,), got {arr.shape}.")
|
| 73 |
+
return int(arr[0]), int(arr[1])
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _clamp_span(span: Optional[Tuple[int, int]], *, max_index: int) -> Optional[Tuple[int, int]]:
|
| 77 |
+
if span is None:
|
| 78 |
+
return None
|
| 79 |
+
start, end = int(span[0]), int(span[1])
|
| 80 |
+
if max_index < 0:
|
| 81 |
+
return None
|
| 82 |
+
if end < 0 or start > max_index:
|
| 83 |
+
return None
|
| 84 |
+
start = max(0, start)
|
| 85 |
+
end = min(max_index, end)
|
| 86 |
+
if end < start:
|
| 87 |
+
return None
|
| 88 |
+
return start, end
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _span_len(span: Tuple[int, int]) -> int:
|
| 92 |
+
start, end = int(span[0]), int(span[1])
|
| 93 |
+
if start < 0 or end < 0 or end < start:
|
| 94 |
+
return 0
|
| 95 |
+
return int(end - start + 1)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dataclass(frozen=True)
|
| 99 |
+
class ProcOneResult:
|
| 100 |
+
wrote: bool
|
| 101 |
+
has_hop: bool
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _proc_one(
|
| 105 |
+
*,
|
| 106 |
+
trace_npz_path: Path,
|
| 107 |
+
record: dict,
|
| 108 |
+
out_path: Path,
|
| 109 |
+
overwrite: bool,
|
| 110 |
+
) -> ProcOneResult:
|
| 111 |
+
if out_path.exists() and not overwrite:
|
| 112 |
+
raise FileExistsError(f"Refusing to overwrite existing file: {out_path} (use --overwrite).")
|
| 113 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 114 |
+
|
| 115 |
+
with np.load(trace_npz_path, allow_pickle=False) as f:
|
| 116 |
+
prompt_len = int(np.asarray(f.get("prompt_len")).item())
|
| 117 |
+
gen_len = int(np.asarray(f.get("gen_len")).item())
|
| 118 |
+
total_len = prompt_len + gen_len
|
| 119 |
+
gen_no_eos = max(0, gen_len - 1)
|
| 120 |
+
L = prompt_len + gen_no_eos
|
| 121 |
+
|
| 122 |
+
v_row_all = f.get("v_row_all")
|
| 123 |
+
if v_row_all is None:
|
| 124 |
+
raise ValueError("Missing v_row_all in trace npz; cannot build row attribution vector.")
|
| 125 |
+
v_row_all = np.asarray(v_row_all, dtype=np.float32)
|
| 126 |
+
if v_row_all.ndim != 1 or int(v_row_all.shape[0]) != int(total_len):
|
| 127 |
+
raise ValueError(f"v_row_all shape mismatch: expected ({total_len},), got {tuple(v_row_all.shape)}.")
|
| 128 |
+
attr = v_row_all[:L]
|
| 129 |
+
|
| 130 |
+
indices_to_explain = _read_span(f, "indices_to_explain_gen")
|
| 131 |
+
sink_span_gen = _read_span(f, "sink_span_gen") or indices_to_explain
|
| 132 |
+
if sink_span_gen is None:
|
| 133 |
+
raise ValueError("Missing sink_span_gen/indices_to_explain_gen; cannot define output span.")
|
| 134 |
+
thinking_span_gen = _read_span(f, "thinking_span_gen")
|
| 135 |
+
if thinking_span_gen is None:
|
| 136 |
+
sink_start = int(sink_span_gen[0])
|
| 137 |
+
think_end = sink_start - 1
|
| 138 |
+
thinking_span_gen = (0, think_end) if think_end >= 0 else None
|
| 139 |
+
|
| 140 |
+
sink_span_gen = _clamp_span(sink_span_gen, max_index=gen_no_eos - 1)
|
| 141 |
+
thinking_span_gen = _clamp_span(thinking_span_gen, max_index=gen_no_eos - 1)
|
| 142 |
+
|
| 143 |
+
span_in = (0, prompt_len - 1) if prompt_len > 0 else (-1, -1)
|
| 144 |
+
span_cot = (
|
| 145 |
+
(prompt_len + thinking_span_gen[0], prompt_len + thinking_span_gen[1])
|
| 146 |
+
if thinking_span_gen is not None
|
| 147 |
+
else (-1, -1)
|
| 148 |
+
)
|
| 149 |
+
span_out = (
|
| 150 |
+
(prompt_len + sink_span_gen[0], prompt_len + sink_span_gen[1]) if sink_span_gen is not None else (-1, -1)
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
length = np.asarray([_span_len(span_in), _span_len(span_cot), _span_len(span_out)], dtype=np.int64)
|
| 154 |
+
|
| 155 |
+
rise = float("nan")
|
| 156 |
+
mas = float("nan")
|
| 157 |
+
faith = f.get("faithfulness_scores")
|
| 158 |
+
if faith is not None:
|
| 159 |
+
faith = np.asarray(faith, dtype=np.float64)
|
| 160 |
+
if faith.shape != (3, 3):
|
| 161 |
+
raise ValueError(f"faithfulness_scores shape mismatch: expected (3,3), got {tuple(faith.shape)}.")
|
| 162 |
+
rise = float(faith[1, 0])
|
| 163 |
+
mas = float(faith[1, 1])
|
| 164 |
+
|
| 165 |
+
recovery = float("nan")
|
| 166 |
+
rec = f.get("recovery_scores")
|
| 167 |
+
if rec is not None:
|
| 168 |
+
rec = np.asarray(rec, dtype=np.float64)
|
| 169 |
+
if rec.shape != (3,):
|
| 170 |
+
raise ValueError(f"recovery_scores shape mismatch: expected (3,), got {tuple(rec.shape)}.")
|
| 171 |
+
recovery = float(rec[1])
|
| 172 |
+
|
| 173 |
+
out_payload = {
|
| 174 |
+
"attr": np.asarray(attr, dtype=np.float32),
|
| 175 |
+
"span_in": np.asarray(span_in, dtype=np.int64),
|
| 176 |
+
"span_cot": np.asarray(span_cot, dtype=np.int64),
|
| 177 |
+
"span_out": np.asarray(span_out, dtype=np.int64),
|
| 178 |
+
"length": np.asarray(length, dtype=np.int64),
|
| 179 |
+
"rise": np.asarray(rise, dtype=np.float64),
|
| 180 |
+
"mas": np.asarray(mas, dtype=np.float64),
|
| 181 |
+
"recovery": np.asarray(recovery, dtype=np.float64),
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
has_hop = False
|
| 185 |
+
vh = f.get("vh")
|
| 186 |
+
if vh is not None:
|
| 187 |
+
vh = np.asarray(vh, dtype=np.float32)
|
| 188 |
+
if vh.ndim != 2 or int(vh.shape[1]) != int(total_len):
|
| 189 |
+
raise ValueError(f"vh shape mismatch: expected (H,{total_len}), got {tuple(vh.shape)} for {trace_npz_path}.")
|
| 190 |
+
out_payload["hop"] = vh[:, :L]
|
| 191 |
+
has_hop = True
|
| 192 |
+
|
| 193 |
+
np.savez_compressed(out_path, **out_payload)
|
| 194 |
+
_ = record
|
| 195 |
+
return ProcOneResult(wrote=True, has_hop=has_hop)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _resolve_out_dir_for_trace_dir(*, trace_dir: Path, out_root: Path, out_dir: Optional[Path]) -> Path:
|
| 199 |
+
if out_dir is not None:
|
| 200 |
+
return out_dir
|
| 201 |
+
suffix = _infer_trace_suffix(trace_dir)
|
| 202 |
+
return (out_root / suffix) if suffix is not None else (out_root / trace_dir.name)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _process_trace_dir(
|
| 206 |
+
*,
|
| 207 |
+
trace_dir: Path,
|
| 208 |
+
out_root: Path,
|
| 209 |
+
out_dir: Optional[Path],
|
| 210 |
+
overwrite: bool,
|
| 211 |
+
limit: Optional[int],
|
| 212 |
+
skip_empty_manifest: bool,
|
| 213 |
+
) -> Tuple[int, int]:
|
| 214 |
+
manifest_path = trace_dir / "manifest.jsonl"
|
| 215 |
+
if not manifest_path.exists():
|
| 216 |
+
raise SystemExit(f"Missing manifest.jsonl: {manifest_path}")
|
| 217 |
+
|
| 218 |
+
records = _parse_manifest(manifest_path)
|
| 219 |
+
if not records:
|
| 220 |
+
if skip_empty_manifest:
|
| 221 |
+
print(f"[skip] empty manifest: {manifest_path}")
|
| 222 |
+
return 0, 0
|
| 223 |
+
raise SystemExit(f"Empty manifest.jsonl: {manifest_path}")
|
| 224 |
+
|
| 225 |
+
total = len(records)
|
| 226 |
+
if limit is not None:
|
| 227 |
+
if limit <= 0:
|
| 228 |
+
raise SystemExit("--limit must be a positive integer.")
|
| 229 |
+
total = min(total, int(limit))
|
| 230 |
+
|
| 231 |
+
resolved_out_dir = _resolve_out_dir_for_trace_dir(trace_dir=trace_dir, out_root=out_root, out_dir=out_dir)
|
| 232 |
+
resolved_out_dir.mkdir(parents=True, exist_ok=True)
|
| 233 |
+
|
| 234 |
+
wrote = 0
|
| 235 |
+
wrote_with_hop = 0
|
| 236 |
+
for record in records[:total]:
|
| 237 |
+
file_name = str(record.get("file") or "")
|
| 238 |
+
if not file_name:
|
| 239 |
+
raise SystemExit("manifest record missing 'file' field.")
|
| 240 |
+
trace_npz_path = trace_dir / file_name
|
| 241 |
+
if not trace_npz_path.exists():
|
| 242 |
+
raise SystemExit(f"Missing trace npz referenced by manifest: {trace_npz_path}")
|
| 243 |
+
|
| 244 |
+
out_path = resolved_out_dir / file_name
|
| 245 |
+
try:
|
| 246 |
+
res = _proc_one(trace_npz_path=trace_npz_path, record=record, out_path=out_path, overwrite=overwrite)
|
| 247 |
+
except Exception as exc:
|
| 248 |
+
raise SystemExit(f"Failed processing {trace_npz_path}: {exc}") from exc
|
| 249 |
+
wrote += int(res.wrote)
|
| 250 |
+
wrote_with_hop += int(res.has_hop)
|
| 251 |
+
|
| 252 |
+
print(f"[ok] wrote {wrote} samples ({wrote_with_hop} with hop) -> {resolved_out_dir}")
|
| 253 |
+
return wrote, wrote_with_hop
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def main() -> None:
|
| 257 |
+
ap = argparse.ArgumentParser("Map exp2 trace folder(s) -> exp/proc_1/output per-sample npz files.")
|
| 258 |
+
ap.add_argument(
|
| 259 |
+
"--trace_dir",
|
| 260 |
+
type=str,
|
| 261 |
+
default=None,
|
| 262 |
+
help="Path to a single exp2 trace run directory (contains manifest.jsonl).",
|
| 263 |
+
)
|
| 264 |
+
ap.add_argument(
|
| 265 |
+
"--traces_root",
|
| 266 |
+
type=str,
|
| 267 |
+
default=None,
|
| 268 |
+
help="Path to traces root; processes all run dirs under it (each with a manifest.jsonl).",
|
| 269 |
+
)
|
| 270 |
+
ap.add_argument("--out_root", type=str, default="exp/proc_1/output", help="Root directory for proc_1 outputs.")
|
| 271 |
+
ap.add_argument(
|
| 272 |
+
"--out_dir",
|
| 273 |
+
type=str,
|
| 274 |
+
default=None,
|
| 275 |
+
help="Optional explicit output directory (only valid with --trace_dir; overrides --out_root).",
|
| 276 |
+
)
|
| 277 |
+
ap.add_argument("--overwrite", action="store_true", help="Overwrite existing output files if present.")
|
| 278 |
+
ap.add_argument("--limit", type=int, default=None, help="Optional limit on number of samples per run (debug).")
|
| 279 |
+
ap.add_argument(
|
| 280 |
+
"--fail_on_empty_manifest",
|
| 281 |
+
action="store_true",
|
| 282 |
+
help="Fail (instead of skipping) when encountering an empty manifest.jsonl.",
|
| 283 |
+
)
|
| 284 |
+
args = ap.parse_args()
|
| 285 |
+
|
| 286 |
+
trace_dir = Path(args.trace_dir) if args.trace_dir else None
|
| 287 |
+
traces_root = Path(args.traces_root) if args.traces_root else None
|
| 288 |
+
if (trace_dir is None) == (traces_root is None):
|
| 289 |
+
raise SystemExit("Please pass exactly one of --trace_dir or --traces_root.")
|
| 290 |
+
|
| 291 |
+
out_root = Path(args.out_root)
|
| 292 |
+
out_dir = Path(args.out_dir) if args.out_dir else None
|
| 293 |
+
if out_dir is not None and trace_dir is None:
|
| 294 |
+
raise SystemExit("--out_dir is only valid with --trace_dir (for --traces_root use --out_root).")
|
| 295 |
+
|
| 296 |
+
skip_empty_manifest = not bool(args.fail_on_empty_manifest)
|
| 297 |
+
|
| 298 |
+
if trace_dir is not None:
|
| 299 |
+
if not trace_dir.exists() or not trace_dir.is_dir():
|
| 300 |
+
raise SystemExit(f"Missing trace_dir: {trace_dir}")
|
| 301 |
+
_process_trace_dir(
|
| 302 |
+
trace_dir=trace_dir,
|
| 303 |
+
out_root=out_root,
|
| 304 |
+
out_dir=out_dir,
|
| 305 |
+
overwrite=bool(args.overwrite),
|
| 306 |
+
limit=args.limit,
|
| 307 |
+
skip_empty_manifest=skip_empty_manifest,
|
| 308 |
+
)
|
| 309 |
+
return
|
| 310 |
+
|
| 311 |
+
assert traces_root is not None
|
| 312 |
+
if not traces_root.exists() or not traces_root.is_dir():
|
| 313 |
+
raise SystemExit(f"Missing traces_root: {traces_root}")
|
| 314 |
+
|
| 315 |
+
run_dirs = _iter_run_dirs(traces_root)
|
| 316 |
+
if not run_dirs:
|
| 317 |
+
raise SystemExit(f"No run directories found under traces_root={traces_root} (expected manifest.jsonl).")
|
| 318 |
+
|
| 319 |
+
total_written = 0
|
| 320 |
+
total_with_hop = 0
|
| 321 |
+
for run_dir in run_dirs:
|
| 322 |
+
wrote, wrote_with_hop = _process_trace_dir(
|
| 323 |
+
trace_dir=run_dir,
|
| 324 |
+
out_root=out_root,
|
| 325 |
+
out_dir=None,
|
| 326 |
+
overwrite=bool(args.overwrite),
|
| 327 |
+
limit=args.limit,
|
| 328 |
+
skip_empty_manifest=skip_empty_manifest,
|
| 329 |
+
)
|
| 330 |
+
total_written += wrote
|
| 331 |
+
total_with_hop += wrote_with_hop
|
| 332 |
+
|
| 333 |
+
print(f"[done] total wrote {total_written} samples ({total_with_hop} with hop) under out_root={out_root}")
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
if __name__ == "__main__":
|
| 337 |
+
main()
|
| 338 |
+
|
flashtrace/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FlashTrace: efficient multi-token attribution for reasoning LLMs."""
|
| 2 |
+
|
| 3 |
+
from .model_io import load_model_and_tokenizer
|
| 4 |
+
from .result import TokenScore, TraceResult
|
| 5 |
+
from .tracer import FlashTrace
|
| 6 |
+
|
| 7 |
+
__all__ = ["FlashTrace", "TraceResult", "TokenScore", "load_model_and_tokenizer"]
|
flashtrace/attribution.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
flashtrace/baselines/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Baseline attribution methods for FlashTrace."""
|
| 2 |
+
|
| 3 |
+
from .attnlrp import LLMLRPAttribution
|
| 4 |
+
|
| 5 |
+
__all__ = ["LLMLRPAttribution"]
|
flashtrace/baselines/attnlrp.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AttnLRP baseline API."""
|
| 2 |
+
|
| 3 |
+
from flashtrace.attribution import AttnLRPSpanAggregate, LLMLRPAttribution, MultiHopAttnLRPResult
|
| 4 |
+
from flashtrace.lrp_patches import detect_model_type, lrp_context
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"AttnLRPSpanAggregate",
|
| 8 |
+
"LLMLRPAttribution",
|
| 9 |
+
"MultiHopAttnLRPResult",
|
| 10 |
+
"detect_model_type",
|
| 11 |
+
"lrp_context",
|
| 12 |
+
]
|