| |
| """ |
| Experiment 3 runner: long-vs-short CoT case study (AttnLRP hop0, Recovery@10%). |
| |
| This runner is intentionally minimal: |
| - Only reads two cached samples produced by exp/exp3/sample_and_filter.py: |
| <dataset_tag>_short_cot.jsonl |
| <dataset_tag>_long_cot.jsonl |
| - Only runs attribution method: attnlrp (hop0 path, aligned with exp2). |
| - Only computes token-level recovery (Recall@10%) using RULER needle_spans. |
| - Always saves per-sample trace artifacts under exp/exp3/output/traces/. |
| |
| All outputs are written under exp/exp3/output/ (configurable via --output_root). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| import json |
| import os |
| import sys |
| import time |
| from itertools import islice |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
|
|
| def _early_set_cuda_visible_devices() -> None: |
| parser = argparse.ArgumentParser(add_help=False) |
| parser.add_argument("--cuda", type=str, default=None) |
| args, _ = parser.parse_known_args(sys.argv[1:]) |
| if args.cuda and "," in args.cuda: |
| os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda |
|
|
|
|
| _early_set_cuda_visible_devices() |
|
|
| import numpy as np |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, utils |
|
|
| REPO_ROOT = Path(__file__).resolve().parents[2] |
| if str(REPO_ROOT) not in sys.path: |
| sys.path.insert(0, str(REPO_ROOT)) |
|
|
| import llm_attr |
| import llm_attr_eval |
| from exp.exp2 import dataset_utils as ds_utils |
|
|
| utils.logging.set_verbosity_error() |
|
|
|
|
| def _sha1_text(text: str) -> str: |
| return hashlib.sha1(text.encode("utf-8")).hexdigest() |
|
|
|
|
| def _token_importance_vector(attr: torch.Tensor) -> np.ndarray: |
| w = torch.nan_to_num(attr.sum(0).to(dtype=torch.float32), nan=0.0).clamp(min=0.0) |
| return w.detach().cpu().numpy().astype(np.float32, copy=False) |
|
|
|
|
| def _trace_run_tag(*, neg_handling: str, norm_mode: str, total: int) -> str: |
| return f"attnlrp_neg{neg_handling}_norm{norm_mode}_recovery_{int(total)}ex" |
|
|
|
|
| def _build_sample_trace_payload( |
| example: ds_utils.CachedExample, |
| *, |
| seq_attr: torch.Tensor, |
| row_attr: torch.Tensor, |
| rec_attr: torch.Tensor, |
| prompt_len: int, |
| user_prompt_indices: Optional[List[int]], |
| gold_prompt_token_indices: Optional[List[int]], |
| recovery_scores: Optional[np.ndarray], |
| time_attr_s: Optional[float], |
| time_recovery_s: Optional[float], |
| ) -> Dict[str, np.ndarray]: |
| gen_len = int(seq_attr.shape[0]) |
|
|
| v_seq_all = _token_importance_vector(seq_attr) |
| v_row_all = _token_importance_vector(row_attr) |
| v_rec_all = _token_importance_vector(rec_attr) |
|
|
| payload: Dict[str, np.ndarray] = { |
| "v_seq_all": v_seq_all, |
| "v_row_all": v_row_all, |
| "v_rec_all": v_rec_all, |
| "v_seq_prompt": v_seq_all[:prompt_len], |
| "v_row_prompt": v_row_all[:prompt_len], |
| "v_rec_prompt": v_rec_all[:prompt_len], |
| "prompt_len": np.asarray(int(prompt_len), dtype=np.int64), |
| "gen_len": np.asarray(int(gen_len), dtype=np.int64), |
| "indices_to_explain_gen": np.asarray(list(example.indices_to_explain or []), dtype=np.int64), |
| } |
|
|
| if example.sink_span is not None: |
| payload["sink_span_gen"] = np.asarray(list(example.sink_span), dtype=np.int64) |
| if example.thinking_span is not None: |
| payload["thinking_span_gen"] = np.asarray(list(example.thinking_span), dtype=np.int64) |
|
|
| if user_prompt_indices is not None: |
| payload["user_prompt_indices"] = np.asarray(list(user_prompt_indices), dtype=np.int64) |
| if gold_prompt_token_indices is not None: |
| payload["gold_prompt_token_indices"] = np.asarray(list(gold_prompt_token_indices), dtype=np.int64) |
|
|
| if recovery_scores is not None: |
| payload["recovery_scores"] = np.asarray(recovery_scores, dtype=np.float64) |
|
|
| if time_attr_s is not None: |
| payload["time_attr_s"] = np.asarray(float(time_attr_s), dtype=np.float64) |
| if time_recovery_s is not None: |
| payload["time_recovery_s"] = np.asarray(float(time_recovery_s), dtype=np.float64) |
|
|
| return payload |
|
|
|
|
| def _write_sample_trace( |
| trace_dir: Path, |
| *, |
| example_idx: int, |
| prompt: str, |
| target: str, |
| payload: Dict[str, np.ndarray], |
| manifest_handle, |
| neg_handling: str, |
| norm_mode: str, |
| recovery_skipped_reason: Optional[str], |
| ) -> None: |
| trace_dir.mkdir(parents=True, exist_ok=True) |
| npz_name = f"ex_{example_idx:06d}.npz" |
| npz_path = trace_dir / npz_name |
| np.savez_compressed(npz_path, **payload) |
|
|
| prompt_len = int(np.asarray(payload.get("prompt_len", 0)).item()) |
| gen_len = int(np.asarray(payload.get("gen_len", 0)).item()) |
| record: Dict[str, Any] = { |
| "example_idx": int(example_idx), |
| "attr_func": "attnlrp", |
| "file": npz_name, |
| "prompt_sha1": _sha1_text(prompt), |
| "target_sha1": _sha1_text(target), |
| "prompt_len": prompt_len, |
| "gen_len": gen_len, |
| "indices_to_explain_gen": payload.get("indices_to_explain_gen").tolist() |
| if payload.get("indices_to_explain_gen") is not None |
| else None, |
| "sink_span_gen": payload.get("sink_span_gen").tolist() if payload.get("sink_span_gen") is not None else None, |
| "thinking_span_gen": payload.get("thinking_span_gen").tolist() |
| if payload.get("thinking_span_gen") is not None |
| else None, |
| "gold_prompt_token_indices": payload.get("gold_prompt_token_indices").tolist() |
| if payload.get("gold_prompt_token_indices") is not None |
| else None, |
| "recovery_scores": payload.get("recovery_scores").tolist() if payload.get("recovery_scores") is not None else None, |
| "recovery_skipped_reason": recovery_skipped_reason, |
| "time_attr_s": float(np.asarray(payload.get("time_attr_s")).item()) if payload.get("time_attr_s") is not None else None, |
| "time_recovery_s": float(np.asarray(payload.get("time_recovery_s")).item()) |
| if payload.get("time_recovery_s") is not None |
| else None, |
| "attnlrp_neg_handling": str(neg_handling), |
| "attnlrp_norm_mode": str(norm_mode), |
| } |
| manifest_handle.write(json.dumps(record, ensure_ascii=False) + "\n") |
| manifest_handle.flush() |
|
|
|
|
| def resolve_device(args) -> str: |
| if args.cuda is not None and "," in args.cuda: |
| os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda |
| return "auto" |
| if args.cuda is not None and str(args.cuda).strip(): |
| return f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu" |
| return f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu" |
|
|
|
|
| def load_model(model_name: str, device: str): |
| model = AutoModelForCausalLM.from_pretrained( |
| model_name, |
| device_map="auto" if device == "auto" else {"": int(device.split(":")[1])} if device.startswith("cuda:") else None, |
| torch_dtype=torch.float16, |
| attn_implementation="eager", |
| ) |
| tokenizer = AutoTokenizer.from_pretrained(model_name) |
| tokenizer.pad_token = tokenizer.eos_token |
| model.eval() |
| return model, tokenizer |
|
|
|
|
| def _evaluate_one_dataset( |
| *, |
| dataset_name: str, |
| examples: List[ds_utils.CachedExample], |
| model, |
| tokenizer, |
| output_root: Path, |
| model_tag: str, |
| neg_handling: str, |
| norm_mode: str, |
| top_fraction: float, |
| num_examples: int, |
| ) -> Tuple[np.ndarray, np.ndarray, float, int, int]: |
| llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(model, tokenizer) |
|
|
| results: List[np.ndarray] = [] |
| durations: List[float] = [] |
| skipped = 0 |
|
|
| total = min(len(examples), int(num_examples)) |
| iterator = islice(examples, total) |
|
|
| run_tag = _trace_run_tag(neg_handling=neg_handling, norm_mode=norm_mode, total=total) |
| trace_dir = output_root / "traces" / dataset_name / model_tag / run_tag |
| trace_dir.mkdir(parents=True, exist_ok=True) |
| manifest_handle = open(trace_dir / "manifest.jsonl", "w", encoding="utf-8") |
|
|
| try: |
| for example_idx, ex in enumerate(iterator): |
| time_recovery_s: Optional[float] = None |
| recovery_scores: Optional[np.ndarray] = None |
|
|
| needle_spans = (ex.metadata or {}).get("needle_spans") |
| if not isinstance(needle_spans, list) or not needle_spans: |
| raise SystemExit( |
| "exp3 recovery requires RULER samples with metadata.needle_spans; " |
| f"dataset={dataset_name} has missing/empty needle_spans." |
| ) |
| if ex.target is None: |
| raise SystemExit( |
| "exp3 recovery requires cached targets (CoT+answer) so row/rec attribution is well-defined. " |
| f"dataset={dataset_name} has target=None; run exp/exp3/sample_and_filter.py first." |
| ) |
| if not (isinstance(ex.indices_to_explain, list) and len(ex.indices_to_explain) == 2): |
| raise SystemExit( |
| "exp3 expects indices_to_explain=[start_tok,end_tok] in generation-token coordinates; " |
| f"dataset={dataset_name} has indices_to_explain={ex.indices_to_explain!r}; " |
| "run exp/exp3/sample_and_filter.py first." |
| ) |
|
|
| gold_prompt = ds_utils.ruler_gold_prompt_token_indices(ex, tokenizer) |
| recovery_skip_reason: Optional[str] = None |
|
|
| sample_start = time.perf_counter() |
| llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer) |
| attr_result = llm_attributor.calculate_attnlrp_ft_hop0( |
| ex.prompt, |
| target=ex.target, |
| sink_span=tuple(ex.sink_span) if ex.sink_span else None, |
| thinking_span=tuple(ex.thinking_span) if ex.thinking_span else None, |
| neg_handling=str(neg_handling), |
| norm_mode=str(norm_mode), |
| ) |
| seq_attr, row_attr, rec_attr = attr_result.get_all_token_attrs(list(ex.indices_to_explain)) |
| time_attr_s = time.perf_counter() - sample_start |
| durations.append(float(time_attr_s)) |
|
|
| prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0]) |
| if prompt_len <= 0: |
| recovery_skip_reason = "empty_prompt_len" |
| elif not gold_prompt: |
| recovery_skip_reason = "empty_gold_prompt" |
| else: |
| t2 = time.perf_counter() |
| recovery_scores = np.asarray( |
| [ |
| llm_evaluator.evaluate_attr_recovery( |
| a, |
| prompt_len=prompt_len, |
| gold_prompt_token_indices=gold_prompt, |
| top_fraction=top_fraction, |
| ) |
| for a in (seq_attr, row_attr, rec_attr) |
| ], |
| dtype=np.float64, |
| ) |
| time_recovery_s = time.perf_counter() - t2 |
| if np.isnan(recovery_scores).any(): |
| recovery_scores = None |
| recovery_skip_reason = "nan_recovery" |
|
|
| if recovery_scores is None and recovery_skip_reason is not None: |
| skipped += 1 |
| elif recovery_scores is not None: |
| results.append(recovery_scores) |
|
|
| payload = _build_sample_trace_payload( |
| ex, |
| seq_attr=seq_attr, |
| row_attr=row_attr, |
| rec_attr=rec_attr, |
| prompt_len=prompt_len, |
| user_prompt_indices=getattr(llm_attributor, "user_prompt_indices", None), |
| gold_prompt_token_indices=gold_prompt, |
| recovery_scores=recovery_scores, |
| time_attr_s=time_attr_s, |
| time_recovery_s=time_recovery_s, |
| ) |
| _write_sample_trace( |
| trace_dir, |
| example_idx=example_idx, |
| prompt=ex.prompt, |
| target=str(ex.target), |
| payload=payload, |
| manifest_handle=manifest_handle, |
| neg_handling=str(neg_handling), |
| norm_mode=str(norm_mode), |
| recovery_skipped_reason=recovery_skip_reason, |
| ) |
| finally: |
| try: |
| manifest_handle.close() |
| except Exception: |
| pass |
|
|
| scores = np.stack(results, axis=0) if results else np.zeros((0, 3), dtype=np.float64) |
| used = int(scores.shape[0]) |
| mean = scores.mean(0) if used else np.full((3,), np.nan, dtype=np.float64) |
| std = scores.std(0) if used else np.full((3,), np.nan, dtype=np.float64) |
| avg_time = float(np.mean(durations)) if durations else 0.0 |
| return mean, std, avg_time, used, int(skipped) |
|
|
|
|
| def main() -> None: |
| parser = argparse.ArgumentParser("Experiment 3 runner (attnlrp hop0, recovery only).") |
| parser.add_argument("--dataset_tag", type=str, default="niah_mq_q2", help="Base tag for exp3 caches.") |
| parser.add_argument("--data_root", type=str, default="exp/exp3/data") |
| parser.add_argument("--output_root", type=str, default="exp/exp3/output") |
| parser.add_argument("--num_examples", type=int, default=1, help="How many examples to evaluate per dataset (default 1).") |
| parser.add_argument("--seed", type=int, default=42) |
| parser.add_argument("--model", type=str, default=None, help="HF repo id (required unless --model_path set).") |
| parser.add_argument("--model_path", type=str, default=None, help="Local path; overrides --model for loading.") |
| parser.add_argument("--cuda_num", type=int, default=0) |
| parser.add_argument("--cuda", type=str, default=None) |
| parser.add_argument("--top_fraction", type=float, default=0.1, help="Top fraction of prompt tokens used for recovery.") |
| parser.add_argument( |
| "--attnlrp_neg_handling", |
| type=str, |
| choices=["drop", "abs"], |
| default="drop", |
| help="AttnLRP hop0: how to handle negative values (drop=clamp>=0, abs=absolute value).", |
| ) |
| parser.add_argument( |
| "--attnlrp_norm_mode", |
| type=str, |
| choices=["norm", "no_norm"], |
| default="norm", |
| help="AttnLRP hop0: norm enables internal normalization; no_norm disables it.", |
| ) |
| args = parser.parse_args() |
|
|
| if args.model_path: |
| model_name = args.model_path |
| elif args.model: |
| model_name = args.model |
| else: |
| raise SystemExit("Please set --model or --model_path.") |
| model_tag = args.model if args.model else Path(args.model_path).name |
|
|
| device = resolve_device(args) |
| model, tokenizer = load_model(model_name, device) |
|
|
| data_root = Path(args.data_root) |
| output_root = Path(args.output_root) |
| output_root.mkdir(parents=True, exist_ok=True) |
|
|
| short_name = f"{args.dataset_tag}_short_cot" |
| long_name = f"{args.dataset_tag}_long_cot" |
| dataset_names = [short_name, long_name] |
|
|
| summary_rows: List[Dict[str, Any]] = [] |
|
|
| for ds_name in dataset_names: |
| cache_path = data_root / f"{ds_name}.jsonl" |
| if not cache_path.exists(): |
| raise SystemExit(f"Missing exp3 cache: {cache_path}. Run exp/exp3/sample_and_filter.py first.") |
| examples = ds_utils.load_cached(cache_path, sample=None, seed=args.seed) |
|
|
| mean, std, avg_time, used, skipped = _evaluate_one_dataset( |
| dataset_name=ds_name, |
| examples=examples, |
| model=model, |
| tokenizer=tokenizer, |
| output_root=output_root, |
| model_tag=model_tag, |
| neg_handling=args.attnlrp_neg_handling, |
| norm_mode=args.attnlrp_norm_mode, |
| top_fraction=float(args.top_fraction), |
| num_examples=int(args.num_examples), |
| ) |
|
|
| out_dir = output_root / "recovery" / ds_name / model_tag |
| out_dir.mkdir(parents=True, exist_ok=True) |
| filename = f"attnlrp_{int(args.num_examples)}_examples.csv" |
| with (out_dir / filename).open("w", encoding="utf-8") as f: |
| f.write("Method,Recovery@10%\n") |
| f.write(f"Seq Attr Recovery Mean,{mean[0]}\n") |
| f.write(f"Row Attr Recovery Mean,{mean[1]}\n") |
| f.write(f"Recursive Attr Recovery Mean,{mean[2]}\n") |
| f.write(f"Seq Attr Recovery Std,{std[0]}\n") |
| f.write(f"Row Attr Recovery Std,{std[1]}\n") |
| f.write(f"Recursive Attr Recovery Std,{std[2]}\n") |
| f.write(f"Examples Used,{used}\n") |
| f.write(f"Examples Skipped,{skipped}\n") |
| f.write(f"Avg Sample Time (s),{avg_time}\n") |
|
|
| print(f"[{ds_name}] attnlrp -> {out_dir/filename} (used={used} skipped={skipped} avg {avg_time:.2f}s)") |
| summary_rows.append( |
| { |
| "dataset": ds_name, |
| "model": model_tag, |
| "neg_handling": args.attnlrp_neg_handling, |
| "norm_mode": args.attnlrp_norm_mode, |
| "seq_recovery@10%": float(mean[0]) if used else float("nan"), |
| "row_recovery@10%": float(mean[1]) if used else float("nan"), |
| "rec_recovery@10%": float(mean[2]) if used else float("nan"), |
| "used": int(used), |
| "skipped": int(skipped), |
| } |
| ) |
|
|
| |
| summary_path = output_root / "recovery" / f"summary_{model_tag}.json" |
| summary_path.parent.mkdir(parents=True, exist_ok=True) |
| summary_path.write_text(json.dumps(summary_rows, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") |
| print(f"Wrote summary -> {summary_path}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|