flashtrace / exp /case_study /run_ifr_case.py
wenbopan's picture
Sync FlashTrace package from GitHub
55b60a8
#!/usr/bin/env python3
"""Case study runner for FlashTrace and attribution baselines.

Modes supported (all emit JSON + HTML artifacts under ``--output_dir``, default ``exp/case_study/out``):
- ``ft``: FlashTrace (current project implementation; multi-hop IFR)
- ``ft_improve``: Experimental FlashTrace variant (multi-hop IFR with stop-token soft deletion)
- ``ft_split_hop``: Experimental FlashTrace variant (split-hop IFR over a segmented thinking span)
- ``ifr_in_all_gen``: Experimental multi-hop IFR variant (hops over CoT+output; scheme B, aligns with exp/exp2)
- ``ifr``: IFR span-aggregate visualization (single hop; one panel)
- ``ifr_all_positions``: IFR full matrix + CAGE (Row/Recursive panels)
- ``ifr_all_positions_output_only``: IFR output-only token matrix + CAGE (Row/Recursive panels)
- ``attnlrp``: AttnLRP hop0 (reuse FT-AttnLRP span-aggregate; visualize raw hop0 vector)
- ``ft_attnlrp``: FT-AttnLRP (multi-hop aggregated AttnLRP; matches exp/exp2)
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import types
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
# Avoid a hard torchvision dependency when transformers is imported (Longformer path).
# setdefault keeps any value the caller already exported in the environment.
for _env_flag in ("TRANSFORMERS_NO_TORCHVISION", "DISABLE_TRANSFORMERS_IMAGE_TRANSFORMS"):
    os.environ.setdefault(_env_flag, "1")
del _env_flag
def _early_set_cuda_visible_devices() -> None:
"""Set CUDA_VISIBLE_DEVICES before importing torch/transformers.
Note: CUDA device indices are re-mapped inside the process after applying the mask.
"""
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--cuda", type=str, default=None)
args, _ = parser.parse_known_args(sys.argv[1:])
cuda = args.cuda.strip() if isinstance(args.cuda, str) else ""
if cuda and "," in cuda:
os.environ["CUDA_VISIBLE_DEVICES"] = cuda
# When run as a script, apply the --cuda device mask first; this must stay above
# `import torch` so the mask is in the environment before torch is loaded.
if __name__ == "__main__":
    _early_set_cuda_visible_devices()
import torch

# Make the repository root (two levels above exp/case_study/) importable so the
# project imports below (llm_attr, exp.*, evaluations.*) resolve when the script is
# launched directly.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
def _stub_torchvision() -> None:
"""Provide minimal torchvision stubs so Longformer imports succeed without the real package."""
if "torchvision" in sys.modules:
return
from importlib.machinery import ModuleSpec
def _mk(name: str) -> types.ModuleType:
mod = types.ModuleType(name)
mod.__spec__ = ModuleSpec(name, loader=None)
return mod
tv = _mk("torchvision")
tv.__dict__["__path__"] = []
submods = ["transforms", "_meta_registrations", "datasets", "io", "models", "ops", "utils"]
for name in submods:
mod = _mk(f"torchvision.{name}")
sys.modules[f"torchvision.{name}"] = mod
setattr(tv, name, mod)
class _InterpolationMode:
NEAREST = 0
NEAREST_EXACT = 0
BILINEAR = 1
BICUBIC = 2
LANCZOS = 3
BOX = 4
HAMMING = 5
sys.modules["torchvision.transforms"].InterpolationMode = _InterpolationMode
sys.modules["torchvision.transforms"].__all__ = ["InterpolationMode"]
# ops + misc stub for timm/transformers imports
ops_mod = sys.modules.get("torchvision.ops") or _mk("torchvision.ops")
sys.modules["torchvision.ops"] = ops_mod
setattr(tv, "ops", ops_mod)
misc_mod = _mk("torchvision.ops.misc")
sys.modules["torchvision.ops.misc"] = misc_mod
setattr(ops_mod, "misc", misc_mod)
class _FrozenBatchNorm2d:
def __init__(self, *args, **kwargs):
pass
misc_mod.FrozenBatchNorm2d = _FrozenBatchNorm2d
sys.modules["torchvision"] = tv
_stub_torchvision()
def _stub_timm() -> None:
"""Provide minimal timm stubs to avoid optional vision deps."""
if "timm" in sys.modules:
return
from importlib.machinery import ModuleSpec
def _mk(name: str) -> types.ModuleType:
mod = types.ModuleType(name)
mod.__spec__ = ModuleSpec(name, loader=None)
return mod
timm = _mk("timm")
timm.__dict__["__path__"] = []
sys.modules["timm"] = timm
data_mod = _mk("timm.data")
sys.modules["timm.data"] = data_mod
timm.data = data_mod
class _ImageNetInfo:
pass
def _infer_imagenet_subset(*args, **kwargs):
return None
data_mod.ImageNetInfo = _ImageNetInfo
data_mod.infer_imagenet_subset = _infer_imagenet_subset
layers_mod = _mk("timm.layers")
sys.modules["timm.layers"] = layers_mod
timm.layers = layers_mod
create_norm_mod = _mk("timm.layers.create_norm")
sys.modules["timm.layers.create_norm"] = create_norm_mod
layers_mod.create_norm = create_norm_mod
def _get_norm_layer(*args, **kwargs):
return None
create_norm_mod.get_norm_layer = _get_norm_layer
classifier_mod = _mk("timm.layers.classifier")
sys.modules["timm.layers.classifier"] = classifier_mod
layers_mod.classifier = classifier_mod
_stub_timm()
# transformers is imported only after _stub_torchvision() / _stub_timm() have run.
import transformers

# Provide light stubs if Longformer classes are unavailable; IFR case study does not use them.
# The placeholders raise ImportError only if something actually instantiates them.
if not hasattr(transformers, "LongformerTokenizer"):

    class _DummyLongformerTokenizer:
        def __init__(self, *args, **kwargs):
            raise ImportError("LongformerTokenizer stubbed; install full transformers+torchvision if needed.")

    transformers.LongformerTokenizer = _DummyLongformerTokenizer
if not hasattr(transformers, "LongformerForMaskedLM"):

    class _DummyLongformerForMaskedLM:
        def __init__(self, *args, **kwargs):
            raise ImportError("LongformerForMaskedLM stubbed; install full transformers+torchvision if needed.")

    transformers.LongformerForMaskedLM = _DummyLongformerForMaskedLM
# Also list the Longformer names (real or stubbed) in transformers.__all__ when it exists.
if hasattr(transformers, "__all__"):
    for _name in ["LongformerTokenizer", "LongformerForMaskedLM"]:
        if _name not in transformers.__all__:
            transformers.__all__.append(_name)
# Gemma3n stubs (transformers may attempt to import even if unused)
# Registers a placeholder package `transformers.models.gemma3n` whose `configuration_gemma3n`
# submodule exposes minimal Gemma3nConfig / Gemma3nTextConfig classes.
# NOTE(review): the guard only checks sys.modules, so the placeholders are installed whenever the
# real configuration module has not been imported yet -- even if the installed transformers ships
# gemma3n. Confirm this cannot shadow a real Gemma3n config/model on the code paths used here.
if "transformers.models.gemma3n.configuration_gemma3n" not in sys.modules:
    from importlib.machinery import ModuleSpec

    gemma_pkg = types.ModuleType("transformers.models.gemma3n")
    # is_package=True gives the spec submodule_search_locations, i.e. a package-shaped placeholder.
    gemma_pkg.__spec__ = ModuleSpec("transformers.models.gemma3n", loader=None, is_package=True)
    sys.modules["transformers.models.gemma3n"] = gemma_pkg
    gemma_conf = types.ModuleType("transformers.models.gemma3n.configuration_gemma3n")
    gemma_conf.__spec__ = ModuleSpec("transformers.models.gemma3n.configuration_gemma3n", loader=None)

    class Gemma3nConfig:
        def __init__(self, *args, **kwargs):
            self.model_type = "gemma3n"

    class Gemma3nTextConfig(Gemma3nConfig):
        pass

    gemma_conf.Gemma3nConfig = Gemma3nConfig
    gemma_conf.Gemma3nTextConfig = Gemma3nTextConfig
    gemma_conf.__all__ = ["Gemma3nConfig", "Gemma3nTextConfig"]
    sys.modules["transformers.models.gemma3n.configuration_gemma3n"] = gemma_conf
    setattr(gemma_pkg, "configuration_gemma3n", gemma_conf)
    # Advertise the stubbed config names in transformers.__all__ as well.
    if hasattr(transformers, "__all__"):
        for _nm in ["Gemma3nConfig", "Gemma3nTextConfig"]:
            if _nm not in transformers.__all__:
                transformers.__all__.append(_nm)
import llm_attr
from exp.exp2 import dataset_utils as ds_utils
from evaluations.attribution_recovery import load_model
from exp.case_study import analysis, viz
def resolve_device(cuda: Optional[str], cuda_num: int) -> str:
    """Translate the ``--cuda`` / ``--cuda_num`` options into a device string.

    - ``cuda`` containing a comma (e.g. ``"0,1"``): export it verbatim as
      ``CUDA_VISIBLE_DEVICES`` and return ``"auto"``.
    - any other non-blank ``cuda`` string: parse it as a GPU index (unparseable text
      falls back to index 0).
    - otherwise: use ``cuda_num``.

    The single-device cases return ``"cuda:<idx>"``, or ``"cpu"`` when CUDA is unavailable.
    """
    spec = cuda if isinstance(cuda, str) else None
    if spec and "," in spec:
        os.environ["CUDA_VISIBLE_DEVICES"] = spec
        return "auto"
    if spec and spec.strip():
        try:
            gpu_index = int(spec)
        except ValueError:  # int() on a str only raises ValueError
            gpu_index = 0
    else:
        gpu_index = cuda_num
    return f"cuda:{gpu_index}" if torch.cuda.is_available() else "cpu"
def load_example(dataset: str, index: int, data_root: Path) -> Tuple[ds_utils.CachedExample, str]:
    """Load a single example from a cache path or dataset name.

    Args:
        dataset: Path to an existing cached JSONL file (read via ``ds_utils.read_cached_jsonl``)
            or a dataset name resolved by ``ds_utils.DatasetLoader`` under ``data_root``.
        index: Example position; negative values count from the end (``-1`` is the last one).
        data_root: Cache root, used only when ``dataset`` is not an existing path.

    Returns:
        ``(example, dataset_name)``, where ``dataset_name`` is the file name for a path or
        ``dataset`` itself for a named dataset.

    Raises:
        ValueError: if no examples are found.
        IndexError: if ``index`` is outside ``[-len, len)``; the message reports the index
            exactly as the caller passed it.
    """
    ds_path = Path(dataset)
    if ds_path.exists():
        examples = ds_utils.read_cached_jsonl(ds_path)
        dataset_name = ds_path.name
    else:
        loader = ds_utils.DatasetLoader(data_root=data_root)
        examples = loader.load(dataset)
        dataset_name = dataset
    if not examples:
        raise ValueError(f"No examples found for dataset={dataset}")
    n_examples = len(examples)
    # Resolve negative indices into a separate variable so the error message below shows the
    # caller's index (previously it showed the already-shifted value, e.g. -7 for -10).
    resolved = index + n_examples if index < 0 else index
    if not (0 <= resolved < n_examples):
        raise IndexError(f"index {index} out of range for dataset with {n_examples} examples")
    return examples[resolved], dataset_name
def parse_args() -> argparse.Namespace:
    """Parse the case-study command line from ``sys.argv``.

    The title is passed as ``description``; the original code passed it positionally,
    which made it argparse's ``prog`` and produced usage lines such as
    ``usage: IFR multi-hop case study ...``. ``prog`` is now derived from the script name.
    """
    parser = argparse.ArgumentParser(description="IFR multi-hop case study")
    parser.add_argument("--dataset", type=str, default="exp/exp2/data/morehopqa.jsonl", help="Dataset name or JSONL path.")
    parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Cache root for dataset names.")
    parser.add_argument("--index", type=int, default=0, help="Sample index (supports negative for reverse).")
    parser.add_argument(
        "--mode",
        type=str,
        choices=[
            "ft",
            "ft_improve",
            "ft_split_hop",
            "ifr_in_all_gen",
            "ifr",
            "ifr_all_positions",
            "ifr_all_positions_output_only",
            "attnlrp",
            "ft_attnlrp",
        ],
        default="ft",
        help=(
            "ft = FlashTrace (multi-hop IFR); ifr = standard IFR span-aggregate; "
            "ifr_in_all_gen = multi-hop IFR over CoT+output (scheme B; exp2-aligned); "
            "ifr_all_positions = full IFR matrix + CAGE row/rec; "
            "ft_improve = FlashTrace (multi-hop IFR, stop-token soft deletion); "
            "ft_split_hop = FlashTrace (split-hop IFR over segmented thinking span); "
            "ifr_all_positions_output_only = output-only IFR matrix + CAGE row/rec; "
            "attnlrp = AttnLRP hop0 (FT-AttnLRP span-aggregate); "
            "ft_attnlrp = FT-AttnLRP (multi-hop aggregated; exp2)."
        ),
    )
    parser.add_argument("--model", type=str, default="qwen-8B", help="HF repo id (ignored if --model_path set).")
    parser.add_argument("--model_path", type=str, default=None, help="Local model path to override --model.")
    parser.add_argument("--cuda", type=str, default=None, help="CUDA spec (e.g., '0' or '0,1').")
    parser.add_argument("--cuda_num", type=int, default=0, help="Fallback GPU index when --cuda unset.")
    parser.add_argument("--n_hops", type=int, default=1, help="Number of hops for IFR multi-hop.")
    parser.add_argument("--sink_span", type=int, nargs=2, default=None, help="Optional sink span over generation tokens.")
    parser.add_argument("--thinking_span", type=int, nargs=2, default=None, help="Optional thinking span over generation tokens.")
    parser.add_argument(
        "--attnlrp_neg_handling",
        type=str,
        choices=["drop", "abs"],
        default="drop",
        help="FT-AttnLRP: how to handle negative values after each hop (drop=clamp>=0, abs=absolute value).",
    )
    parser.add_argument(
        "--attnlrp_norm_mode",
        type=str,
        choices=["norm", "no_norm"],
        default="norm",
        help="FT-AttnLRP: norm enables per-hop global+thinking normalization + ratios; no_norm disables all three.",
    )
    parser.add_argument("--chunk_tokens", type=int, default=128, help="IFR chunk size.")
    parser.add_argument("--sink_chunk_tokens", type=int, default=32, help="IFR sink chunk size.")
    parser.add_argument("--output_dir", type=str, default="exp/case_study/out", help="Where to write HTML/JSON artifacts.")
    return parser.parse_args()
def _resolve_span(
    explicit: Optional[Sequence[int]],
    fallback: Optional[Sequence[int]],
) -> Optional[Tuple[int, ...]]:
    """Return ``explicit`` as a tuple if given, else a non-empty ``fallback`` as a tuple, else ``None``."""
    if explicit is not None:
        return tuple(explicit)
    if fallback:
        return tuple(fallback)
    return None


def _ids_to_list(ids: Any) -> Any:
    """Move an optional id tensor to CPU and convert it to (nested) Python lists; ``None`` passes through."""
    return ids.detach().cpu().tolist() if ids is not None else None


def _collect_debug_info(attr: Any, result: Any) -> Dict[str, Any]:
    """Snapshot token/id bookkeeping from ``attr`` plus raw per-hop IFR vectors from ``result``.

    ``raw_hop_vectors`` is filled from ``result.metadata["ifr"]["raw"].raw_attributions``
    (each record's ``token_importance_total`` moved to CPU). It stays an empty list when that
    structure is absent or cannot be read.
    """
    debug_info: Dict[str, Any] = {
        "full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []),
        "generation_tokens": list(getattr(attr, "generation_tokens", []) or []),
        "user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []),
        "chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []),
        "prompt_ids": _ids_to_list(getattr(attr, "prompt_ids", None)),
        "generation_ids": _ids_to_list(getattr(attr, "generation_ids", None)),
    }
    raw_vectors: List[Any] = []
    if result.metadata and "ifr" in result.metadata:
        raw_ifr = result.metadata["ifr"].get("raw")
        if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"):
            try:
                raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions]
            except Exception:
                # Deliberately best-effort: this is a debug payload, so a malformed raw record
                # should not abort the case study.
                raw_vectors = []
    debug_info["raw_hop_vectors"] = raw_vectors
    return debug_info


def _run_ifr_variant(
    attr: Any,
    calculate: Any,
    example: ds_utils.CachedExample,
    *,
    n_hops: int,
    sink_span: Optional[Sequence[int]],
    thinking_span: Optional[Sequence[int]],
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
    """Resolve spans, call one multi-hop IFR entry point (``calculate``) and gather debug info.

    Explicit spans override the example's own ``sink_span`` / ``thinking_span``.
    Returns ``(result, sink, thinking, debug_info)``, where ``sink`` / ``thinking`` are the spans used.
    """
    sink = _resolve_span(sink_span, example.sink_span)
    thinking = _resolve_span(thinking_span, example.thinking_span)
    result = calculate(
        example.prompt,
        target=example.target,
        sink_span=sink,
        thinking_span=thinking,
        n_hops=int(n_hops),
    )
    return result, sink, thinking, _collect_debug_info(attr, result)


def run_ft_multihop(
    example: ds_utils.CachedExample,
    model: Any,
    tokenizer: Any,
    *,
    n_hops: int,
    sink_span: Optional[Sequence[int]],
    thinking_span: Optional[Sequence[int]],
    chunk_tokens: int,
    sink_chunk_tokens: int,
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
    """Execute FT (current multi-hop IFR) attribution for the selected example.

    Returns ``(result, sink_span_used, thinking_span_used, debug_info)``.
    """
    attr = llm_attr.LLMIFRAttribution(
        model,
        tokenizer,
        chunk_tokens=chunk_tokens,
        sink_chunk_tokens=sink_chunk_tokens,
    )
    return _run_ifr_variant(
        attr,
        attr.calculate_ifr_multi_hop,
        example,
        n_hops=n_hops,
        sink_span=sink_span,
        thinking_span=thinking_span,
    )


def run_ft_multihop_improve(
    example: ds_utils.CachedExample,
    model: Any,
    tokenizer: Any,
    *,
    n_hops: int,
    sink_span: Optional[Sequence[int]],
    thinking_span: Optional[Sequence[int]],
    chunk_tokens: int,
    sink_chunk_tokens: int,
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
    """Execute experimental FT (multi-hop IFR) with stop-token soft deletion.

    Returns ``(result, sink_span_used, thinking_span_used, debug_info)``.
    """
    import ft_ifr_improve  # experimental module, imported only when this mode runs

    attr = ft_ifr_improve.LLMIFRAttributionImproved(
        model,
        tokenizer,
        chunk_tokens=chunk_tokens,
        sink_chunk_tokens=sink_chunk_tokens,
    )
    return _run_ifr_variant(
        attr,
        attr.calculate_ifr_multi_hop_stop_words,
        example,
        n_hops=n_hops,
        sink_span=sink_span,
        thinking_span=thinking_span,
    )


def run_ft_multihop_split_hop(
    example: ds_utils.CachedExample,
    model: Any,
    tokenizer: Any,
    *,
    n_hops: int,
    sink_span: Optional[Sequence[int]],
    thinking_span: Optional[Sequence[int]],
    chunk_tokens: int,
    sink_chunk_tokens: int,
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
    """Execute experimental FT (split-hop IFR over segmented thinking span).

    Returns ``(result, sink_span_used, thinking_span_used, debug_info)``.
    """
    import ft_ifr_improve  # experimental module, imported only when this mode runs

    attr = ft_ifr_improve.LLMIFRAttributionSplitHop(
        model,
        tokenizer,
        chunk_tokens=chunk_tokens,
        sink_chunk_tokens=sink_chunk_tokens,
    )
    return _run_ifr_variant(
        attr,
        attr.calculate_ifr_multi_hop_split_hop,
        example,
        n_hops=n_hops,
        sink_span=sink_span,
        thinking_span=thinking_span,
    )


def run_ifr_in_all_gen(
    example: ds_utils.CachedExample,
    model: Any,
    tokenizer: Any,
    *,
    n_hops: int,
    sink_span: Optional[Sequence[int]],
    thinking_span: Optional[Sequence[int]],
    chunk_tokens: int,
    sink_chunk_tokens: int,
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
    """Execute experimental IFR variant: multi-hop over all generation (CoT + output).

    Returns ``(result, sink_span_used, thinking_span_used, debug_info)``.
    """
    import ft_ifr_improve  # experimental module, imported only when this mode runs

    attr = ft_ifr_improve.LLMIFRAttributionInAllGen(
        model,
        tokenizer,
        chunk_tokens=chunk_tokens,
        sink_chunk_tokens=sink_chunk_tokens,
    )
    return _run_ifr_variant(
        attr,
        attr.calculate_ifr_in_all_gen,
        example,
        n_hops=n_hops,
        sink_span=sink_span,
        thinking_span=thinking_span,
    )
def make_output_stem(dataset_name: str, index: int, mode: str) -> str:
    """Build the artifact file stem ``<prefix><dataset>_idx<index>``.

    The prefix is ``"<mode>_case_"`` for every mode except ``ifr_all_positions_output_only``,
    which uses the shorter ``"ifr_output_only_case_"``. Slashes and spaces in the dataset
    name are replaced by underscores.
    """
    sanitized = dataset_name
    for unsafe in ("/", " "):
        sanitized = sanitized.replace(unsafe, "_")
    if mode == "ifr_all_positions_output_only":
        prefix = "ifr_output_only_case_"
    else:
        prefix = f"{mode}_case_"
    return f"{prefix}{sanitized}_idx{index}"
def _decode_token_ids(tokenizer: Any, ids: Sequence[int]) -> List[str]:
"""Decode each token id into a readable text piece (keeps special tokens)."""
pieces: List[str] = []
for tok_id in ids:
try:
pieces.append(
tokenizer.decode([int(tok_id)], skip_special_tokens=False, clean_up_tokenization_spaces=False)
)
except Exception:
pieces.append(str(tok_id))
return pieces
def build_raw_tokens_from_ids(tokenizer: Any, prompt_ids: Optional[Sequence[int]], generation_ids: Optional[Sequence[int]]) -> List[str]:
if not prompt_ids:
prompt_ids = []
if not generation_ids:
generation_ids = []
return _decode_token_ids(tokenizer, prompt_ids) + _decode_token_ids(tokenizer, generation_ids)
def build_trimmed_roles(tokens: Sequence[str], segments: Dict[str, Any]) -> List[str]:
    """Label each trimmed token (user prompt followed by generation) with a role.

    Positions before ``segments["prompt_len"]`` (default 0) are ``"prompt"`` and the rest are
    ``"gen"``. Then ``thinking_span`` and afterwards ``sink_span`` (inclusive ``[start, end]``
    pairs, relative to the generation start) relabel their positions as ``"think"`` and
    ``"output"``. The sink wins where the two overlap, and spans are clipped at ``len(tokens)``.
    """
    total = len(tokens)
    offset = segments.get("prompt_len", 0)
    roles: List[str] = ["prompt"] * total
    for pos in range(offset, total):
        roles[pos] = "gen"

    def _paint(span: Any, label: str) -> None:
        if span is None:
            return
        first = offset + int(span[0])
        stop = min(total, offset + int(span[1]) + 1)
        for pos in range(first, stop):
            roles[pos] = label

    _paint(segments.get("thinking_span"), "think")
    _paint(segments.get("sink_span"), "output")
    return roles
def build_raw_roles(
    tokens: Sequence[str],
    prompt_len_full: int,
    user_indices: Sequence[int],
    template_indices: Sequence[int],
    thinking_span_abs: Optional[Sequence[int]],
    sink_span_abs: Optional[Sequence[int]],
) -> List[str]:
    """Label raw (pre-trim) tokens: chat template, user prompt, and generation.

    Within the first ``prompt_len_full`` positions, a position listed in ``user_indices`` is
    ``"user"`` (this takes precedence), one listed in ``template_indices`` is ``"template"``,
    and anything else is ``"prompt"``. Later positions are ``"gen"``. The absolute inclusive
    spans ``thinking_span_abs`` and then ``sink_span_abs`` relabel their positions as
    ``"think"`` / ``"output"``, clipped at ``len(tokens)``.
    """
    n_tokens = len(tokens)
    roles: List[str] = ["template"] * n_tokens
    user_positions = {int(i) for i in user_indices}
    template_positions = {int(i) for i in template_indices}
    for pos in range(min(n_tokens, prompt_len_full)):
        if pos in user_positions:
            roles[pos] = "user"
        elif pos not in template_positions:
            roles[pos] = "prompt"
        # Template positions keep the initial "template" label.
    for pos in range(prompt_len_full, n_tokens):
        roles[pos] = "gen"
    for span, label in ((thinking_span_abs, "think"), (sink_span_abs, "output")):
        if span is None:
            continue
        lo, hi = int(span[0]), int(span[1])
        for pos in range(lo, min(n_tokens, hi + 1)):
            roles[pos] = label
    return roles
def extract_prompt_only_vectors(hop_vectors: Sequence[torch.Tensor], prompt_len: int) -> List[torch.Tensor]:
    """Keep only the leading ``prompt_len`` entries of every hop vector.

    Intended to drop generation positions and keep the user-prompt part. Each vector is
    converted to a float32 CPU tensor before slicing.

    Raises:
        ValueError: if ``prompt_len`` is negative or a vector has fewer than ``prompt_len`` elements.
    """
    if prompt_len < 0:
        raise ValueError("prompt_len must be >= 0.")
    limit = int(prompt_len)
    sliced: List[torch.Tensor] = []
    for raw in hop_vectors:
        as_float = torch.as_tensor(raw, dtype=torch.float32).detach().cpu()
        length = int(as_float.numel())
        if length < limit:
            raise ValueError(f"Hop vector too short for prompt-only slice: len={length} prompt_len={limit}.")
        sliced.append(as_float[:prompt_len])
    return sliced
def _lift_trimmed_to_full(
trimmed: torch.Tensor,
*,
prompt_len_full: int,
gen_len: int,
user_prompt_indices: Sequence[int],
) -> torch.Tensor:
"""Lift a trimmed (user prompt + generation) vector into full token space with zeros for chat-template tokens."""
t = torch.as_tensor(trimmed, dtype=torch.float32).detach().cpu()
user_len = len(user_prompt_indices)
expected = int(user_len + gen_len)
if int(t.numel()) != expected:
raise ValueError(f"Trimmed vector length mismatch: got {int(t.numel())}, expected {expected}.")
total_len = int(prompt_len_full + gen_len)
full = torch.zeros((total_len,), dtype=torch.float32)
for j, abs_pos in enumerate(user_prompt_indices):
full[int(abs_pos)] = t[j]
full[int(prompt_len_full) : int(prompt_len_full + gen_len)] = t[user_len:]
return full
def _postprocess_attnlrp_full_vector(
raw_full: torch.Tensor,
*,
prompt_len_full: int,
gen_len: int,
user_prompt_indices: Sequence[int],
neg_handling: str,
norm_mode: str,
) -> torch.Tensor:
"""Mirror FT-AttnLRP hop postprocessing while preserving stripped-token normalization.
The underlying AttnLRP implementation postprocesses the *stripped* vector (user prompt + generation):
- NaN->0, then neg_handling ('drop' or 'abs')
- if norm_mode=='norm': normalize by sum over stripped tokens
For the pre-trim full view (chat template + generation), we apply the same non-negativity transform
to the full vector and normalize using *only the stripped indices*, so overlapping token scores
match the trimmed vectors used by the evaluation/case-study hop outputs.
"""
v = torch.as_tensor(raw_full, dtype=torch.float32).detach().cpu()
v = torch.nan_to_num(v, nan=0.0)
if neg_handling == "drop":
v = v.clamp(min=0.0)
elif neg_handling == "abs":
v = v.abs()
else:
raise ValueError(f"Unsupported neg_handling={neg_handling!r} (expected 'drop' or 'abs').")
ratio_enabled = norm_mode == "norm"
if not ratio_enabled:
return v
keep = list(int(i) for i in user_prompt_indices) + list(range(int(prompt_len_full), int(prompt_len_full + gen_len)))
if not keep:
return torch.zeros_like(v)
keep_idx = torch.as_tensor(keep, dtype=torch.long)
denom = float(v.index_select(0, keep_idx).sum().item())
if denom <= 0.0:
return torch.zeros_like(v)
return v / (denom + 1e-12)
def main() -> None:
args = parse_args()
device = resolve_device(args.cuda, args.cuda_num)
if torch.cuda.is_available():
visible = os.environ.get("CUDA_VISIBLE_DEVICES")
print(f"[info] CUDA_VISIBLE_DEVICES={visible!r} torch.cuda.device_count()={torch.cuda.device_count()} device={device}")
model_name = args.model_path if args.model_path is not None else args.model
# Align with exp/exp2: always use the shared fp16 loader.
model, tokenizer = load_model(model_name, device)
example, ds_name = load_example(args.dataset, args.index, Path(args.data_root))
mode = args.mode
sink_span: Optional[Tuple[int, int]] = None
thinking_span: Optional[Tuple[int, int]] = None
thinking_ratios: Optional[Sequence[float]] = None
prompt_tokens_trimmed: List[str] = []
generation_tokens_trimmed: List[str] = []
hop_vectors_trimmed: List[torch.Tensor] = []
hop_vectors_raw: List[torch.Tensor] = []
prompt_len_full: Optional[int] = None
user_prompt_indices: List[int] = []
chat_prompt_indices: List[int] = []
method_meta: Dict[str, Any] = {}
raw_prompt_ids: Optional[List[int]] = None
raw_generation_ids: Optional[List[int]] = None
attnlrp_raw_attributions: Optional[List[Any]] = None
if mode in ("ft", "ft_improve", "ft_split_hop", "ifr_in_all_gen"):
if mode == "ft":
attr_result, sink_span, thinking_span, debug_info = run_ft_multihop(
example,
model,
tokenizer,
n_hops=args.n_hops,
sink_span=args.sink_span,
thinking_span=args.thinking_span,
chunk_tokens=args.chunk_tokens,
sink_chunk_tokens=args.sink_chunk_tokens,
)
elif mode == "ft_improve":
attr_result, sink_span, thinking_span, debug_info = run_ft_multihop_improve(
example,
model,
tokenizer,
n_hops=args.n_hops,
sink_span=args.sink_span,
thinking_span=args.thinking_span,
chunk_tokens=args.chunk_tokens,
sink_chunk_tokens=args.sink_chunk_tokens,
)
elif mode == "ft_split_hop":
attr_result, sink_span, thinking_span, debug_info = run_ft_multihop_split_hop(
example,
model,
tokenizer,
n_hops=args.n_hops,
sink_span=args.sink_span,
thinking_span=args.thinking_span,
chunk_tokens=args.chunk_tokens,
sink_chunk_tokens=args.sink_chunk_tokens,
)
elif mode == "ifr_in_all_gen":
attr_result, sink_span, thinking_span, debug_info = run_ifr_in_all_gen(
example,
model,
tokenizer,
n_hops=args.n_hops,
sink_span=args.sink_span,
thinking_span=args.thinking_span,
chunk_tokens=args.chunk_tokens,
sink_chunk_tokens=args.sink_chunk_tokens,
)
else:
raise ValueError(f"Unsupported mode={mode}")
ifr_meta = (attr_result.metadata or {}).get("ifr") or {}
hop_vectors_trimmed = list(ifr_meta.get("per_hop_projected") or [])
if not hop_vectors_trimmed:
raise RuntimeError(f"No per-hop vectors found for {mode} mode.")
prompt_tokens_trimmed = list(attr_result.prompt_tokens)
generation_tokens_trimmed = list(attr_result.generation_tokens)
thinking_ratios = ifr_meta.get("thinking_ratios")
raw_prompt_ids = debug_info.get("prompt_ids")
if isinstance(raw_prompt_ids, list) and raw_prompt_ids and isinstance(raw_prompt_ids[0], list):
raw_prompt_ids = raw_prompt_ids[0]
raw_generation_ids = debug_info.get("generation_ids")
if isinstance(raw_generation_ids, list) and raw_generation_ids and isinstance(raw_generation_ids[0], list):
raw_generation_ids = raw_generation_ids[0]
user_prompt_indices = list(debug_info.get("user_prompt_indices") or [])
chat_prompt_indices = list(debug_info.get("chat_prompt_indices") or [])
prompt_len_full = len(raw_prompt_ids) if isinstance(raw_prompt_ids, list) else None
raw_vectors = debug_info.get("raw_hop_vectors") or []
hop_vectors_raw = [vec.detach().cpu() if hasattr(vec, "detach") else torch.as_tensor(vec) for vec in raw_vectors]
method_meta = {"ifr": analysis.sanitize_ifr_meta(ifr_meta)}
elif mode == "ifr":
# Standard IFR (single-hop span aggregate), with pre/post trim views.
attr = llm_attr.LLMIFRAttribution(
model,
tokenizer,
chunk_tokens=args.chunk_tokens,
sink_chunk_tokens=args.sink_chunk_tokens,
)
sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span
if sink_span is None:
raise ValueError("sink_span is required for IFR mode (use dataset sink_span or pass --sink_span).")
span_result = attr.calculate_ifr_span(
example.prompt,
target=example.target,
span=tuple(sink_span),
)
span_meta = span_result.metadata.get("ifr") if span_result.metadata else None
aggregate = span_meta.get("aggregate") if isinstance(span_meta, dict) else None
if aggregate is None or not hasattr(aggregate, "token_importance_total"):
raise RuntimeError("IFR span aggregate missing from metadata; cannot render pre-trim view.")
raw_vector = aggregate.token_importance_total.detach().cpu()
trimmed_vector = attr._project_vector(raw_vector)
hop_vectors_raw = [raw_vector]
hop_vectors_trimmed = [trimmed_vector]
prompt_tokens_trimmed = list(attr.user_prompt_tokens)
generation_tokens_trimmed = list(attr.generation_tokens)
raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0]
raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0]
user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or [])
chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or [])
prompt_len_full = len(raw_prompt_ids)
sink_abs = (prompt_len_full + sink_span[0], prompt_len_full + sink_span[1])
think_abs = (prompt_len_full + thinking_span[0], prompt_len_full + thinking_span[1]) if thinking_span else None
meta = {
"type": "span_aggregate",
"ifr_view": "aggregate",
"sink_span_generation": sink_span,
"sink_span_absolute": sink_abs,
"thinking_span_generation": thinking_span,
"thinking_span_absolute": think_abs,
}
method_meta = {"ifr": analysis.tensor_to_list(meta)}
elif mode == "ifr_all_positions_output_only":
# IFR all-positions (output-only) + token-level CAGE (row/recursive) derived from the matrix.
attr = llm_attr.LLMIFRAttribution(
model,
tokenizer,
chunk_tokens=args.chunk_tokens,
sink_chunk_tokens=args.sink_chunk_tokens,
)
sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span
if sink_span is None:
raise ValueError(
"sink_span is required for ifr_all_positions_output_only mode "
"(use dataset sink_span or pass --sink_span)."
)
attr_result = attr.calculate_ifr_for_all_positions_output_only(
example.prompt,
target=example.target,
sink_span=tuple(sink_span),
)
indices_to_explain = list(sink_span)
_, row_attr, rec_attr = attr_result.get_all_token_attrs(indices_to_explain)
row_vec = row_attr.squeeze(0).detach().cpu()
rec_vec = rec_attr.squeeze(0).detach().cpu()
hop_vectors_trimmed = [row_vec, rec_vec]
prompt_tokens_trimmed = list(attr.user_prompt_tokens)
generation_tokens_trimmed = list(attr.generation_tokens)
raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0]
raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0]
user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or [])
chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or [])
prompt_len_full = len(raw_prompt_ids)
gen_len = len(raw_generation_ids or [])
hop_vectors_raw = [
_lift_trimmed_to_full(
v,
prompt_len_full=int(prompt_len_full or 0),
gen_len=gen_len,
user_prompt_indices=user_prompt_indices,
)
for v in hop_vectors_trimmed
]
ifr_meta = dict((attr_result.metadata or {}).get("ifr") or {})
ifr_meta["ifr_view"] = "all_positions_output_only (row+rec)"
ifr_meta["panel_titles"] = ["Row attribution", "Recursive attribution (CAGE)"]
ifr_meta["indices_to_explain"] = indices_to_explain
method_meta = {"ifr": analysis.tensor_to_list(ifr_meta)}
elif mode == "ifr_all_positions":
# IFR all-positions (full generation) + token-level CAGE (row/recursive) derived from the matrix.
attr = llm_attr.LLMIFRAttribution(
model,
tokenizer,
chunk_tokens=args.chunk_tokens,
sink_chunk_tokens=args.sink_chunk_tokens,
)
sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span
if sink_span is None:
raise ValueError(
"sink_span is required for ifr_all_positions mode (use dataset sink_span or pass --sink_span)."
)
attr_result = attr.calculate_ifr_for_all_positions(
example.prompt,
target=example.target,
)
indices_to_explain = list(sink_span)
_, row_attr, rec_attr = attr_result.get_all_token_attrs(indices_to_explain)
row_vec = row_attr.squeeze(0).detach().cpu()
rec_vec = rec_attr.squeeze(0).detach().cpu()
hop_vectors_trimmed = [row_vec, rec_vec]
prompt_tokens_trimmed = list(attr.user_prompt_tokens)
generation_tokens_trimmed = list(attr.generation_tokens)
raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0]
raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0]
user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or [])
chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or [])
prompt_len_full = len(raw_prompt_ids)
gen_len = len(raw_generation_ids or [])
hop_vectors_raw = [
_lift_trimmed_to_full(
v,
prompt_len_full=int(prompt_len_full or 0),
gen_len=gen_len,
user_prompt_indices=user_prompt_indices,
)
for v in hop_vectors_trimmed
]
ifr_meta = dict((attr_result.metadata or {}).get("ifr") or {})
ifr_meta["ifr_view"] = "all_positions (row+rec)"
ifr_meta["panel_titles"] = ["Row attribution", "Recursive attribution (CAGE)"]
ifr_meta["indices_to_explain"] = indices_to_explain
method_meta = {"ifr": analysis.tensor_to_list(ifr_meta)}
elif mode in ("attnlrp", "ft_attnlrp"):
# Reuse the shared LLMLRPAttribution implementations (root-level).
attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None
thinking_span = (
tuple(args.thinking_span)
if args.thinking_span is not None
else tuple(example.thinking_span) if example.thinking_span else sink_span
)
if mode == "attnlrp":
# Case-study AttnLRP: reuse FT-AttnLRP logic but take hop0 (the first span-aggregate)
# for a full, signed attribution vector (no observation masking).
attr_result = attributor.calculate_attnlrp_ft_hop0(
example.prompt,
target=example.target,
sink_span=sink_span,
thinking_span=thinking_span,
neg_handling=args.attnlrp_neg_handling,
norm_mode=args.attnlrp_norm_mode,
)
meta = attr_result.metadata or {}
multi_hop = meta.get("multi_hop_result")
raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
attnlrp_raw_attributions = list(raw_attributions)
base_attr = raw_attributions[0] if raw_attributions else None
if base_attr is None or not hasattr(base_attr, "token_importance_total"):
raise RuntimeError("AttnLRP hop0 missing from multi-hop result.")
hop0_vec = torch.as_tensor(getattr(base_attr, "token_importance_total"), dtype=torch.float32).detach().cpu()
if hop0_vec.numel() <= 0:
raise RuntimeError("Empty generation for AttnLRP case study.")
# Use the actual sink span applied by hop0 (defaults to full generation when unset).
sink_span = tuple(getattr(base_attr, "sink_range"))
if thinking_span is None:
thinking_span = sink_span
hop_vectors_trimmed = [hop0_vec]
thinking_ratios = list(getattr(multi_hop, "thinking_ratios", []) or [])
method_meta = {
"attnlrp": {
"type": "calculate_attnlrp_multi_hop(n_hops=0) hop0 raw_attributions[0]",
"sink_span_generation": sink_span,
"thinking_span_generation": thinking_span,
"thinking_ratios": thinking_ratios,
"neg_handling": args.attnlrp_neg_handling,
"norm_mode": args.attnlrp_norm_mode,
"ratio_enabled": args.attnlrp_norm_mode == "norm",
}
}
else:
# exp2 ft_attnlrp: multi-hop aggregated AttnLRP (metadata contains per-hop vectors).
attr_result = attributor.calculate_attnlrp_aggregated_multi_hop(
example.prompt,
target=example.target,
sink_span=sink_span,
thinking_span=thinking_span,
n_hops=int(args.n_hops),
neg_handling=args.attnlrp_neg_handling,
norm_mode=args.attnlrp_norm_mode,
)
meta = attr_result.metadata or {}
multi_hop = meta.get("multi_hop_result")
if multi_hop is None:
raise RuntimeError("FT-AttnLRP case study missing metadata.multi_hop_result.")
raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
attnlrp_raw_attributions = list(raw_attributions)
hop_vectors_trimmed = [
torch.as_tensor(getattr(hop, "token_importance_total"), dtype=torch.float32).detach().cpu()
for hop in raw_attributions
]
thinking_ratios = list(getattr(multi_hop, "thinking_ratios", []) or [])
method_meta = {
"attnlrp": {
"type": "calculate_attnlrp_aggregated_multi_hop (exp2 ft_attnlrp)",
"n_hops": int(args.n_hops),
"sink_span_generation": sink_span,
"thinking_span_generation": thinking_span,
"thinking_ratios": thinking_ratios,
"neg_handling": args.attnlrp_neg_handling,
"norm_mode": args.attnlrp_norm_mode,
"ratio_enabled": args.attnlrp_norm_mode == "norm",
}
}
prompt_tokens_trimmed = list(attributor.user_prompt_tokens)
generation_tokens_trimmed = list(attributor.generation_tokens)
raw_prompt_ids = attributor.prompt_ids.detach().cpu().tolist()[0]
raw_generation_ids = attributor.generation_ids.detach().cpu().tolist()[0]
user_prompt_indices = list(getattr(attributor, "user_prompt_indices", []) or [])
chat_prompt_indices = list(getattr(attributor, "chat_prompt_indices", []) or [])
prompt_len_full = len(raw_prompt_ids)
else:
raise ValueError(f"Unsupported mode={mode}")
if not hop_vectors_trimmed:
raise RuntimeError("No hop vectors to visualize.")
raw_tokens = build_raw_tokens_from_ids(tokenizer, raw_prompt_ids, raw_generation_ids)
sink_span_abs = None
thinking_span_abs = None
if prompt_len_full is not None and sink_span is not None:
sink_span_abs = (prompt_len_full + sink_span[0], prompt_len_full + sink_span[1])
if prompt_len_full is not None and thinking_span is not None:
thinking_span_abs = (prompt_len_full + thinking_span[0], prompt_len_full + thinking_span[1])
prompt_len_full_safe = int(prompt_len_full or 0)
roles_raw = build_raw_roles(
raw_tokens,
prompt_len_full_safe,
user_prompt_indices,
chat_prompt_indices,
thinking_span_abs,
sink_span_abs,
)
prompt_tokens_only = list(prompt_tokens_trimmed)
prompt_only_vectors = extract_prompt_only_vectors(hop_vectors_trimmed, len(prompt_tokens_only))
# Ensure every method has a pre-trim full vector per panel.
if not hop_vectors_raw:
if mode in ("attnlrp", "ft_attnlrp") and attnlrp_raw_attributions is not None:
gen_len = len(raw_generation_ids or [])
expected = int((prompt_len_full_safe + gen_len) if prompt_len_full is not None else 0)
full_vectors: List[torch.Tensor] = []
for hop in attnlrp_raw_attributions:
meta = getattr(hop, "metadata", None) or {}
raw_full = meta.get("token_importance_total_with_chat_template")
if raw_full is None:
full_vectors = []
break
v = _postprocess_attnlrp_full_vector(
torch.as_tensor(raw_full, dtype=torch.float32),
prompt_len_full=prompt_len_full_safe,
gen_len=gen_len,
user_prompt_indices=user_prompt_indices,
neg_handling=args.attnlrp_neg_handling,
norm_mode=args.attnlrp_norm_mode,
)
if expected and int(v.numel()) != expected:
raise RuntimeError(
"AttnLRP full-vector length mismatch for pre-trim view: "
f"got {int(v.numel())}, expected {expected}."
)
full_vectors.append(v)
hop_vectors_raw = full_vectors
if not hop_vectors_raw and prompt_len_full is not None:
# Fallback: lift trimmed vectors back to full token space with zeros for template tokens.
gen_len = len(raw_generation_ids or [])
hop_vectors_raw = [
_lift_trimmed_to_full(
v,
prompt_len_full=prompt_len_full_safe,
gen_len=gen_len,
user_prompt_indices=user_prompt_indices,
)
for v in hop_vectors_trimmed
]
if not hop_vectors_raw:
raise RuntimeError("Missing pre-trim vectors; cannot render required full-sequence heatmap.")
# Lightweight debug stats to catch silent all-zero / NaN cases.
hop_stats_raw = [analysis.vector_stats(torch.nan_to_num(v.detach().cpu(), nan=0.0)) for v in hop_vectors_raw]
hop_stats_prompt = [analysis.vector_stats(torch.nan_to_num(v.detach().cpu(), nan=0.0)) for v in prompt_only_vectors]
for i in range(max(len(hop_stats_raw), len(hop_stats_prompt))):
raw_abs = hop_stats_raw[i]["abs_max"] if i < len(hop_stats_raw) else None
prompt_abs = hop_stats_prompt[i]["abs_max"] if i < len(hop_stats_prompt) else None
print(f"[stats] panel {i}: raw_abs_max={raw_abs} prompt_abs_max={prompt_abs}")
hop_token_raw = analysis.package_token_hops(hop_vectors_raw)
hop_token_prompt = analysis.package_token_hops(prompt_only_vectors)
case_meta: Dict[str, Any] = {
"dataset": ds_name,
"index": args.index,
"sink_span": sink_span,
"thinking_span": thinking_span,
"n_hops": args.n_hops,
"thinking_ratios": thinking_ratios,
"mode": mode,
"ifr_view": method_meta.get("ifr", {}).get("ifr_view") if isinstance(method_meta.get("ifr"), dict) else None,
"panel_titles": method_meta.get("ifr", {}).get("panel_titles") if isinstance(method_meta.get("ifr"), dict) else None,
"attnlrp_neg_handling": args.attnlrp_neg_handling if mode in ("attnlrp", "ft_attnlrp") else None,
"attnlrp_norm_mode": args.attnlrp_norm_mode if mode in ("attnlrp", "ft_attnlrp") else None,
"attnlrp_ratio_enabled": (args.attnlrp_norm_mode == "norm") if mode in ("attnlrp", "ft_attnlrp") else None,
"vector_stats_raw": hop_stats_raw,
"vector_stats_prompt": hop_stats_prompt,
}
generation_text = "".join(generation_tokens_trimmed) if generation_tokens_trimmed else ""
prompt_text = example.prompt
record = {
"meta": case_meta,
"prompt": prompt_text,
"target": example.target,
"generation": generation_text,
"full_all_tokens": raw_tokens,
"raw_token_roles": roles_raw,
"prompt_tokens": prompt_tokens_only,
"prompt_token_roles": ["user" for _ in range(len(prompt_tokens_only))],
"token_hops_raw": hop_token_raw,
"token_hops_prompt": hop_token_prompt,
"ifr_meta": method_meta.get("ifr"),
"attnlrp_meta": method_meta.get("attnlrp"),
}
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
stem = make_output_stem(ds_name, args.index, mode)
json_path = out_dir / f"{stem}.json"
html_path = out_dir / f"{stem}.html"
with json_path.open("w", encoding="utf-8") as f:
json.dump(record, f, ensure_ascii=False, indent=2)
html = viz.render_case_html(
case_meta,
token_view_raw={
"label": "Pre-trim token-level heatmap (full sequence with chat template)",
"tokens": raw_tokens,
"roles": roles_raw,
"hops": hop_token_raw,
},
token_view_prompt={
"label": "Prompt-only token-level heatmap (user prompt only)",
"tokens": prompt_tokens_only,
"roles": ["user" for _ in range(len(prompt_tokens_only))],
"hops": hop_token_prompt,
},
)
html_path.write_text(html, encoding="utf-8")
print(f"[done] wrote {json_path}")
print(f"[done] wrote {html_path}")
if __name__ == "__main__":
    # Script entry point: run the case-study pipeline only when executed directly,
    # not when this module is imported.
    main()