#!/usr/bin/env python3
"""Case study runner for FlashTrace and attribution baselines.

Runs one attribution method on a single dataset example (``--dataset`` /
``--index``) and writes JSON + HTML artifacts to ``--output_dir``
(default ``exp/case_study/out``).

Modes supported (select with ``--mode``):
- ``ft``: FlashTrace (current project implementation; multi-hop IFR)
- ``ft_improve``: Experimental FlashTrace variant (multi-hop IFR with stop-token soft deletion)
- ``ft_split_hop``: Experimental FlashTrace variant (split-hop IFR over a segmented thinking span)
- ``ifr_in_all_gen``: Experimental multi-hop IFR variant (hops over CoT+output; scheme B, aligns with exp/exp2)
- ``ifr``: IFR span-aggregate visualization (single hop; one panel)
- ``ifr_all_positions``: IFR full matrix + CAGE (Row/Recursive panels)
- ``ifr_all_positions_output_only``: IFR output-only token matrix + CAGE (Row/Recursive panels)
- ``attnlrp``: AttnLRP hop0 (reuses FT-AttnLRP span-aggregate; visualizes the raw hop0 vector)
- ``ft_attnlrp``: FT-AttnLRP (multi-hop aggregated AttnLRP; matches exp/exp2)

Import order matters in this module: the transformers environment flags and
the optional ``CUDA_VISIBLE_DEVICES`` mask are set before ``torch`` is
imported, and the torchvision/timm stubs are installed before
``transformers`` is imported.
"""

from __future__ import annotations

import argparse
import json
import os
import sys
import types
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

# Avoid torchvision dependency when importing transformers (Longformer).
# setdefault: a value the caller already exported in the environment wins.
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
os.environ.setdefault("DISABLE_TRANSFORMERS_IMAGE_TRANSFORMS", "1")


def _early_set_cuda_visible_devices() -> None:
    """Set CUDA_VISIBLE_DEVICES before importing torch/transformers.

    Pre-parses only ``--cuda`` from ``sys.argv``. Other flags are ignored via
    ``parse_known_args``, and ``-h`` is not intercepted because
    ``add_help=False``. The mask is applied only for multi-device specs
    containing a comma (e.g. ``"0,1"``). A single index such as ``"0"``
    leaves the environment untouched and is turned into an explicit
    ``cuda:<idx>`` device later by ``resolve_device``.

    Note: CUDA device indices are re-mapped inside the process after applying the mask.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--cuda", type=str, default=None)
    args, _ = parser.parse_known_args(sys.argv[1:])
    cuda = args.cuda.strip() if isinstance(args.cuda, str) else ""
    if cuda and "," in cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda


# Only when executed as a script, so importing this module never rewrites the
# process environment.
if __name__ == "__main__":
    _early_set_cuda_visible_devices()

# Imported only after the optional CUDA mask above has been applied.
import torch

# Two levels above this file's directory. Assumes the script lives at
# <repo>/exp/case_study/ (TODO confirm if the file is moved). Prepending it lets
# the repo-level packages imported below (llm_attr, exp, evaluations) resolve
# without installation.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))


def _stub_torchvision() -> None:
    """Provide minimal torchvision stubs so Longformer imports succeed without the real package.

    Registers empty placeholder modules in ``sys.modules`` for
    ``torchvision``, each submodule listed in ``submods``, and
    ``torchvision.ops.misc``. The only names populated are:

    - ``torchvision.transforms.InterpolationMode``: integer placeholders.
    - ``torchvision.ops.misc.FrozenBatchNorm2d``: a no-op class.

    Does nothing if ``torchvision`` is already in ``sys.modules``. Only
    ``sys.modules`` is checked, so an installed but not-yet-imported
    torchvision is shadowed by these stubs.
    """
    if "torchvision" in sys.modules:
        return
    from importlib.machinery import ModuleSpec

    def _mk(name: str) -> types.ModuleType:
        # A non-None __spec__ keeps importlib.util.find_spec() from raising
        # ValueError when it finds the stub in sys.modules.
        mod = types.ModuleType(name)
        mod.__spec__ = ModuleSpec(name, loader=None)
        return mod

    tv = _mk("torchvision")
    # Mark the stub as a package. With an empty __path__, only submodules
    # registered in sys.modules below are importable.
    tv.__dict__["__path__"] = []
    submods = ["transforms", "_meta_registrations", "datasets", "io", "models", "ops", "utils"]
    for name in submods:
        mod = _mk(f"torchvision.{name}")
        sys.modules[f"torchvision.{name}"] = mod
        setattr(tv, name, mod)

    # Placeholder constants only (NEAREST and NEAREST_EXACT share a value).
    # Presumably importers just need the attributes to exist; confirm no caller
    # relies on real torchvision values.
    class _InterpolationMode:
        NEAREST = 0
        NEAREST_EXACT = 0
        BILINEAR = 1
        BICUBIC = 2
        LANCZOS = 3
        BOX = 4
        HAMMING = 5

    sys.modules["torchvision.transforms"].InterpolationMode = _InterpolationMode
    sys.modules["torchvision.transforms"].__all__ = ["InterpolationMode"]
    # ops + misc stub for timm/transformers imports
    # "ops" was already registered by the loop above, so this normally reuses
    # that stub; the `or _mk(...)` fallback is purely defensive.
    ops_mod = sys.modules.get("torchvision.ops") or _mk("torchvision.ops")
    sys.modules["torchvision.ops"] = ops_mod
    setattr(tv, "ops", ops_mod)
    misc_mod = _mk("torchvision.ops.misc")
    sys.modules["torchvision.ops.misc"] = misc_mod
    setattr(ops_mod, "misc", misc_mod)

    class _FrozenBatchNorm2d:
        # Accepts and ignores any constructor arguments.
        def __init__(self, *args, **kwargs) -> None:
            pass

    misc_mod.FrozenBatchNorm2d = _FrozenBatchNorm2d
    # The top-level stub is registered last, after its submodules are populated.
    sys.modules["torchvision"] = tv


# Must run before `import transformers` below.
_stub_torchvision()


def _stub_timm() -> None:
    """Provide minimal timm stubs to avoid optional vision deps.

    Registers placeholder modules ``timm``, ``timm.data``, ``timm.layers``,
    ``timm.layers.create_norm`` and ``timm.layers.classifier`` in
    ``sys.modules``. The only names populated are:

    - ``timm.data.ImageNetInfo``: an empty class.
    - ``timm.data.infer_imagenet_subset``: returns ``None``.
    - ``timm.layers.create_norm.get_norm_layer``: returns ``None``.

    Presumably these are exactly the names transformers' timm-backed modules
    import (confirm if transformers is upgraded). Does nothing if ``timm`` is
    already in ``sys.modules``.
    """
    if "timm" in sys.modules:
        return
    from importlib.machinery import ModuleSpec

    def _mk(name: str) -> types.ModuleType:
        # Same helper as in _stub_torchvision: give the stub a non-None __spec__.
        mod = types.ModuleType(name)
        mod.__spec__ = ModuleSpec(name, loader=None)
        return mod

    timm = _mk("timm")
    # Mark as a package; only the submodules registered below are importable.
    timm.__dict__["__path__"] = []
    sys.modules["timm"] = timm
    data_mod = _mk("timm.data")
    sys.modules["timm.data"] = data_mod
    timm.data = data_mod

    class _ImageNetInfo:
        pass

    def _infer_imagenet_subset(*args, **kwargs) -> None:
        return None

    data_mod.ImageNetInfo = _ImageNetInfo
    data_mod.infer_imagenet_subset = _infer_imagenet_subset
    layers_mod = _mk("timm.layers")
    sys.modules["timm.layers"] = layers_mod
    timm.layers = layers_mod
    create_norm_mod = _mk("timm.layers.create_norm")
    sys.modules["timm.layers.create_norm"] = create_norm_mod
    layers_mod.create_norm = create_norm_mod

    def _get_norm_layer(*args, **kwargs) -> None:
        return None

    create_norm_mod.get_norm_layer = _get_norm_layer
    # Registered with no attributes; only its importability is provided.
    classifier_mod = _mk("timm.layers.classifier")
    sys.modules["timm.layers.classifier"] = classifier_mod
    layers_mod.classifier = classifier_mod


# Must run before `import transformers` below.
_stub_timm()

# Imported only after the torchvision and timm stubs are installed.
import transformers

# Provide light stubs if Longformer classes are unavailable; IFR case study does not use them.
if not hasattr(transformers, "LongformerTokenizer"): class _DummyLongformerTokenizer: def __init__(self, *args, **kwargs): raise ImportError("LongformerTokenizer stubbed; install full transformers+torchvision if needed.") transformers.LongformerTokenizer = _DummyLongformerTokenizer if not hasattr(transformers, "LongformerForMaskedLM"): class _DummyLongformerForMaskedLM: def __init__(self, *args, **kwargs): raise ImportError("LongformerForMaskedLM stubbed; install full transformers+torchvision if needed.") transformers.LongformerForMaskedLM = _DummyLongformerForMaskedLM if hasattr(transformers, "__all__"): for _name in ["LongformerTokenizer", "LongformerForMaskedLM"]: if _name not in transformers.__all__: transformers.__all__.append(_name) # Gemma3n stubs (transformers may attempt to import even if unused) if "transformers.models.gemma3n.configuration_gemma3n" not in sys.modules: from importlib.machinery import ModuleSpec gemma_pkg = types.ModuleType("transformers.models.gemma3n") gemma_pkg.__spec__ = ModuleSpec("transformers.models.gemma3n", loader=None, is_package=True) sys.modules["transformers.models.gemma3n"] = gemma_pkg gemma_conf = types.ModuleType("transformers.models.gemma3n.configuration_gemma3n") gemma_conf.__spec__ = ModuleSpec("transformers.models.gemma3n.configuration_gemma3n", loader=None) class Gemma3nConfig: def __init__(self, *args, **kwargs): self.model_type = "gemma3n" class Gemma3nTextConfig(Gemma3nConfig): pass gemma_conf.Gemma3nConfig = Gemma3nConfig gemma_conf.Gemma3nTextConfig = Gemma3nTextConfig gemma_conf.__all__ = ["Gemma3nConfig", "Gemma3nTextConfig"] sys.modules["transformers.models.gemma3n.configuration_gemma3n"] = gemma_conf setattr(gemma_pkg, "configuration_gemma3n", gemma_conf) if hasattr(transformers, "__all__"): for _nm in ["Gemma3nConfig", "Gemma3nTextConfig"]: if _nm not in transformers.__all__: transformers.__all__.append(_nm) import llm_attr from exp.exp2 import dataset_utils as ds_utils from evaluations.attribution_recovery 
import load_model from exp.case_study import analysis, viz def resolve_device(cuda: Optional[str], cuda_num: int) -> str: if cuda and isinstance(cuda, str) and "," in cuda: os.environ["CUDA_VISIBLE_DEVICES"] = cuda return "auto" if cuda and isinstance(cuda, str) and cuda.strip(): try: idx = int(cuda) except Exception: idx = 0 return f"cuda:{idx}" if torch.cuda.is_available() else "cpu" return f"cuda:{cuda_num}" if torch.cuda.is_available() else "cpu" def load_example(dataset: str, index: int, data_root: Path) -> Tuple[ds_utils.CachedExample, str]: """Load a single example from a cache path or dataset name.""" ds_path = Path(dataset) if ds_path.exists(): examples = ds_utils.read_cached_jsonl(ds_path) dataset_name = ds_path.name else: loader = ds_utils.DatasetLoader(data_root=data_root) examples = loader.load(dataset) dataset_name = dataset if not examples: raise ValueError(f"No examples found for dataset={dataset}") if index < 0: index = len(examples) + index if not (0 <= index < len(examples)): raise IndexError(f"index {index} out of range for dataset with {len(examples)} examples") return examples[index], dataset_name def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser("IFR multi-hop case study") parser.add_argument("--dataset", type=str, default="exp/exp2/data/morehopqa.jsonl", help="Dataset name or JSONL path.") parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Cache root for dataset names.") parser.add_argument("--index", type=int, default=0, help="Sample index (supports negative for reverse).") parser.add_argument( "--mode", type=str, choices=[ "ft", "ft_improve", "ft_split_hop", "ifr_in_all_gen", "ifr", "ifr_all_positions", "ifr_all_positions_output_only", "attnlrp", "ft_attnlrp", ], default="ft", help=( "ft = FlashTrace (multi-hop IFR); ifr = standard IFR span-aggregate; " "ifr_in_all_gen = multi-hop IFR over CoT+output (scheme B; exp2-aligned); " "ifr_all_positions = full IFR matrix + CAGE row/rec; " 
"ft_improve = FlashTrace (multi-hop IFR, stop-token soft deletion); " "ft_split_hop = FlashTrace (split-hop IFR over segmented thinking span); " "ifr_all_positions_output_only = output-only IFR matrix + CAGE row/rec; " "attnlrp = AttnLRP hop0 (FT-AttnLRP span-aggregate); " "ft_attnlrp = FT-AttnLRP (multi-hop aggregated; exp2)." ), ) parser.add_argument("--model", type=str, default="qwen-8B", help="HF repo id (ignored if --model_path set).") parser.add_argument("--model_path", type=str, default=None, help="Local model path to override --model.") parser.add_argument("--cuda", type=str, default=None, help="CUDA spec (e.g., '0' or '0,1').") parser.add_argument("--cuda_num", type=int, default=0, help="Fallback GPU index when --cuda unset.") parser.add_argument("--n_hops", type=int, default=1, help="Number of hops for IFR multi-hop.") parser.add_argument("--sink_span", type=int, nargs=2, default=None, help="Optional sink span over generation tokens.") parser.add_argument("--thinking_span", type=int, nargs=2, default=None, help="Optional thinking span over generation tokens.") parser.add_argument( "--attnlrp_neg_handling", type=str, choices=["drop", "abs"], default="drop", help="FT-AttnLRP: how to handle negative values after each hop (drop=clamp>=0, abs=absolute value).", ) parser.add_argument( "--attnlrp_norm_mode", type=str, choices=["norm", "no_norm"], default="norm", help="FT-AttnLRP: norm enables per-hop global+thinking normalization + ratios; no_norm disables all three.", ) parser.add_argument("--chunk_tokens", type=int, default=128, help="IFR chunk size.") parser.add_argument("--sink_chunk_tokens", type=int, default=32, help="IFR sink chunk size.") parser.add_argument("--output_dir", type=str, default="exp/case_study/out", help="Where to write HTML/JSON artifacts.") return parser.parse_args() def run_ft_multihop( example: ds_utils.CachedExample, model: Any, tokenizer: Any, *, n_hops: int, sink_span: Optional[Sequence[int]], thinking_span: Optional[Sequence[int]], 
chunk_tokens: int, sink_chunk_tokens: int, ) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]: """Execute FT (current multi-hop IFR) attribution for the selected example.""" attr = llm_attr.LLMIFRAttribution( model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens, ) sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None thinking = ( tuple(thinking_span) if thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else None ) result = attr.calculate_ifr_multi_hop( example.prompt, target=example.target, sink_span=sink, thinking_span=thinking, n_hops=n_hops, ) debug_info: Dict[str, Any] = { "full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []), "generation_tokens": list(getattr(attr, "generation_tokens", []) or []), "user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []), "chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []), "prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None, "generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None, } raw_vectors = [] if result.metadata and "ifr" in result.metadata: raw_ifr = result.metadata["ifr"].get("raw") if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"): try: raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions] except Exception: raw_vectors = [] debug_info["raw_hop_vectors"] = raw_vectors return result, sink, thinking, debug_info def run_ft_multihop_improve( example: ds_utils.CachedExample, model: Any, tokenizer: Any, *, n_hops: int, sink_span: Optional[Sequence[int]], thinking_span: Optional[Sequence[int]], chunk_tokens: int, sink_chunk_tokens: int, ) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], 
Dict[str, Any]]: """Execute experimental FT (multi-hop IFR) with stop-token soft deletion.""" import ft_ifr_improve attr = ft_ifr_improve.LLMIFRAttributionImproved( model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens, ) sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None thinking = ( tuple(thinking_span) if thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else None ) result = attr.calculate_ifr_multi_hop_stop_words( example.prompt, target=example.target, sink_span=sink, thinking_span=thinking, n_hops=n_hops, ) debug_info: Dict[str, Any] = { "full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []), "generation_tokens": list(getattr(attr, "generation_tokens", []) or []), "user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []), "chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []), "prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None, "generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None, } raw_vectors = [] if result.metadata and "ifr" in result.metadata: raw_ifr = result.metadata["ifr"].get("raw") if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"): try: raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions] except Exception: raw_vectors = [] debug_info["raw_hop_vectors"] = raw_vectors return result, sink, thinking, debug_info def run_ft_multihop_split_hop( example: ds_utils.CachedExample, model: Any, tokenizer: Any, *, n_hops: int, sink_span: Optional[Sequence[int]], thinking_span: Optional[Sequence[int]], chunk_tokens: int, sink_chunk_tokens: int, ) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]: """Execute experimental FT (split-hop IFR over 
segmented thinking span).""" import ft_ifr_improve attr = ft_ifr_improve.LLMIFRAttributionSplitHop( model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens, ) sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None thinking = ( tuple(thinking_span) if thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else None ) result = attr.calculate_ifr_multi_hop_split_hop( example.prompt, target=example.target, sink_span=sink, thinking_span=thinking, n_hops=int(n_hops), ) debug_info: Dict[str, Any] = { "full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []), "generation_tokens": list(getattr(attr, "generation_tokens", []) or []), "user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []), "chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []), "prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None, "generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None, } raw_vectors = [] if result.metadata and "ifr" in result.metadata: raw_ifr = result.metadata["ifr"].get("raw") if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"): try: raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions] except Exception: raw_vectors = [] debug_info["raw_hop_vectors"] = raw_vectors return result, sink, thinking, debug_info def run_ifr_in_all_gen( example: ds_utils.CachedExample, model: Any, tokenizer: Any, *, n_hops: int, sink_span: Optional[Sequence[int]], thinking_span: Optional[Sequence[int]], chunk_tokens: int, sink_chunk_tokens: int, ) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]: """Execute experimental IFR variant: multi-hop over all generation (CoT + output).""" import ft_ifr_improve attr = 
ft_ifr_improve.LLMIFRAttributionInAllGen( model, tokenizer, chunk_tokens=chunk_tokens, sink_chunk_tokens=sink_chunk_tokens, ) sink = tuple(sink_span) if sink_span is not None else tuple(example.sink_span) if example.sink_span else None thinking = ( tuple(thinking_span) if thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else None ) result = attr.calculate_ifr_in_all_gen( example.prompt, target=example.target, sink_span=sink, thinking_span=thinking, n_hops=int(n_hops), ) debug_info: Dict[str, Any] = { "full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []), "generation_tokens": list(getattr(attr, "generation_tokens", []) or []), "user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []), "chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []), "prompt_ids": getattr(attr, "prompt_ids", None).detach().cpu().tolist() if getattr(attr, "prompt_ids", None) is not None else None, "generation_ids": getattr(attr, "generation_ids", None).detach().cpu().tolist() if getattr(attr, "generation_ids", None) is not None else None, } raw_vectors = [] if result.metadata and "ifr" in result.metadata: raw_ifr = result.metadata["ifr"].get("raw") if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"): try: raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions] except Exception: raw_vectors = [] debug_info["raw_hop_vectors"] = raw_vectors return result, sink, thinking, debug_info def make_output_stem(dataset_name: str, index: int, mode: str) -> str: safe_name = dataset_name.replace("/", "_").replace(" ", "_") prefix = { "ft": "ft_case_", "ft_improve": "ft_improve_case_", "ifr": "ifr_case_", "ifr_all_positions": "ifr_all_positions_case_", "ifr_all_positions_output_only": "ifr_output_only_case_", "attnlrp": "attnlrp_case_", "ft_attnlrp": "ft_attnlrp_case_", }.get(mode, f"{mode}_case_") return f"{prefix}{safe_name}_idx{index}" def _decode_token_ids(tokenizer: Any, 
ids: Sequence[int]) -> List[str]: """Decode each token id into a readable text piece (keeps special tokens).""" pieces: List[str] = [] for tok_id in ids: try: pieces.append( tokenizer.decode([int(tok_id)], skip_special_tokens=False, clean_up_tokenization_spaces=False) ) except Exception: pieces.append(str(tok_id)) return pieces def build_raw_tokens_from_ids(tokenizer: Any, prompt_ids: Optional[Sequence[int]], generation_ids: Optional[Sequence[int]]) -> List[str]: if not prompt_ids: prompt_ids = [] if not generation_ids: generation_ids = [] return _decode_token_ids(tokenizer, prompt_ids) + _decode_token_ids(tokenizer, generation_ids) def build_trimmed_roles(tokens: Sequence[str], segments: Dict[str, Any]) -> List[str]: """Assign role labels for trimmed tokens (prompt + generation).""" roles = ["prompt" for _ in range(len(tokens))] prompt_len_tokens = segments.get("prompt_len", 0) for idx in range(prompt_len_tokens, len(tokens)): roles[idx] = "gen" thinking_span = segments.get("thinking_span") sink_span = segments.get("sink_span") if thinking_span is not None: start = prompt_len_tokens + int(thinking_span[0]) end = prompt_len_tokens + int(thinking_span[1]) for i in range(start, min(len(tokens), end + 1)): roles[i] = "think" if sink_span is not None: start = prompt_len_tokens + int(sink_span[0]) end = prompt_len_tokens + int(sink_span[1]) for i in range(start, min(len(tokens), end + 1)): roles[i] = "output" return roles def build_raw_roles( tokens: Sequence[str], prompt_len_full: int, user_indices: Sequence[int], template_indices: Sequence[int], thinking_span_abs: Optional[Sequence[int]], sink_span_abs: Optional[Sequence[int]], ) -> List[str]: """Assign role labels for raw tokens (template + user + generation).""" roles = ["template" for _ in range(len(tokens))] user_set = set(int(i) for i in user_indices) tmpl_set = set(int(i) for i in template_indices) for i in range(min(len(tokens), prompt_len_full)): if i in user_set: roles[i] = "user" elif i in tmpl_set: roles[i] 
= "template" else: roles[i] = "prompt" for i in range(prompt_len_full, len(tokens)): roles[i] = "gen" if thinking_span_abs is not None: start, end = int(thinking_span_abs[0]), int(thinking_span_abs[1]) for i in range(start, min(len(tokens), end + 1)): roles[i] = "think" if sink_span_abs is not None: start, end = int(sink_span_abs[0]), int(sink_span_abs[1]) for i in range(start, min(len(tokens), end + 1)): roles[i] = "output" return roles def extract_prompt_only_vectors(hop_vectors: Sequence[torch.Tensor], prompt_len: int) -> List[torch.Tensor]: """Slice hop vectors down to user-prompt tokens only (no generation tokens).""" if prompt_len < 0: raise ValueError("prompt_len must be >= 0.") out: List[torch.Tensor] = [] for vec in hop_vectors: v = torch.as_tensor(vec, dtype=torch.float32).detach().cpu() if int(v.numel()) < int(prompt_len): raise ValueError(f"Hop vector too short for prompt-only slice: len={int(v.numel())} prompt_len={int(prompt_len)}.") out.append(v[:prompt_len]) return out def _lift_trimmed_to_full( trimmed: torch.Tensor, *, prompt_len_full: int, gen_len: int, user_prompt_indices: Sequence[int], ) -> torch.Tensor: """Lift a trimmed (user prompt + generation) vector into full token space with zeros for chat-template tokens.""" t = torch.as_tensor(trimmed, dtype=torch.float32).detach().cpu() user_len = len(user_prompt_indices) expected = int(user_len + gen_len) if int(t.numel()) != expected: raise ValueError(f"Trimmed vector length mismatch: got {int(t.numel())}, expected {expected}.") total_len = int(prompt_len_full + gen_len) full = torch.zeros((total_len,), dtype=torch.float32) for j, abs_pos in enumerate(user_prompt_indices): full[int(abs_pos)] = t[j] full[int(prompt_len_full) : int(prompt_len_full + gen_len)] = t[user_len:] return full def _postprocess_attnlrp_full_vector( raw_full: torch.Tensor, *, prompt_len_full: int, gen_len: int, user_prompt_indices: Sequence[int], neg_handling: str, norm_mode: str, ) -> torch.Tensor: """Mirror FT-AttnLRP hop 
postprocessing while preserving stripped-token normalization. The underlying AttnLRP implementation postprocesses the *stripped* vector (user prompt + generation): - NaN->0, then neg_handling ('drop' or 'abs') - if norm_mode=='norm': normalize by sum over stripped tokens For the pre-trim full view (chat template + generation), we apply the same non-negativity transform to the full vector and normalize using *only the stripped indices*, so overlapping token scores match the trimmed vectors used by the evaluation/case-study hop outputs. """ v = torch.as_tensor(raw_full, dtype=torch.float32).detach().cpu() v = torch.nan_to_num(v, nan=0.0) if neg_handling == "drop": v = v.clamp(min=0.0) elif neg_handling == "abs": v = v.abs() else: raise ValueError(f"Unsupported neg_handling={neg_handling!r} (expected 'drop' or 'abs').") ratio_enabled = norm_mode == "norm" if not ratio_enabled: return v keep = list(int(i) for i in user_prompt_indices) + list(range(int(prompt_len_full), int(prompt_len_full + gen_len))) if not keep: return torch.zeros_like(v) keep_idx = torch.as_tensor(keep, dtype=torch.long) denom = float(v.index_select(0, keep_idx).sum().item()) if denom <= 0.0: return torch.zeros_like(v) return v / (denom + 1e-12) def main() -> None: args = parse_args() device = resolve_device(args.cuda, args.cuda_num) if torch.cuda.is_available(): visible = os.environ.get("CUDA_VISIBLE_DEVICES") print(f"[info] CUDA_VISIBLE_DEVICES={visible!r} torch.cuda.device_count()={torch.cuda.device_count()} device={device}") model_name = args.model_path if args.model_path is not None else args.model # Align with exp/exp2: always use the shared fp16 loader. 
model, tokenizer = load_model(model_name, device) example, ds_name = load_example(args.dataset, args.index, Path(args.data_root)) mode = args.mode sink_span: Optional[Tuple[int, int]] = None thinking_span: Optional[Tuple[int, int]] = None thinking_ratios: Optional[Sequence[float]] = None prompt_tokens_trimmed: List[str] = [] generation_tokens_trimmed: List[str] = [] hop_vectors_trimmed: List[torch.Tensor] = [] hop_vectors_raw: List[torch.Tensor] = [] prompt_len_full: Optional[int] = None user_prompt_indices: List[int] = [] chat_prompt_indices: List[int] = [] method_meta: Dict[str, Any] = {} raw_prompt_ids: Optional[List[int]] = None raw_generation_ids: Optional[List[int]] = None attnlrp_raw_attributions: Optional[List[Any]] = None if mode in ("ft", "ft_improve", "ft_split_hop", "ifr_in_all_gen"): if mode == "ft": attr_result, sink_span, thinking_span, debug_info = run_ft_multihop( example, model, tokenizer, n_hops=args.n_hops, sink_span=args.sink_span, thinking_span=args.thinking_span, chunk_tokens=args.chunk_tokens, sink_chunk_tokens=args.sink_chunk_tokens, ) elif mode == "ft_improve": attr_result, sink_span, thinking_span, debug_info = run_ft_multihop_improve( example, model, tokenizer, n_hops=args.n_hops, sink_span=args.sink_span, thinking_span=args.thinking_span, chunk_tokens=args.chunk_tokens, sink_chunk_tokens=args.sink_chunk_tokens, ) elif mode == "ft_split_hop": attr_result, sink_span, thinking_span, debug_info = run_ft_multihop_split_hop( example, model, tokenizer, n_hops=args.n_hops, sink_span=args.sink_span, thinking_span=args.thinking_span, chunk_tokens=args.chunk_tokens, sink_chunk_tokens=args.sink_chunk_tokens, ) elif mode == "ifr_in_all_gen": attr_result, sink_span, thinking_span, debug_info = run_ifr_in_all_gen( example, model, tokenizer, n_hops=args.n_hops, sink_span=args.sink_span, thinking_span=args.thinking_span, chunk_tokens=args.chunk_tokens, sink_chunk_tokens=args.sink_chunk_tokens, ) else: raise ValueError(f"Unsupported mode={mode}") 
ifr_meta = (attr_result.metadata or {}).get("ifr") or {} hop_vectors_trimmed = list(ifr_meta.get("per_hop_projected") or []) if not hop_vectors_trimmed: raise RuntimeError(f"No per-hop vectors found for {mode} mode.") prompt_tokens_trimmed = list(attr_result.prompt_tokens) generation_tokens_trimmed = list(attr_result.generation_tokens) thinking_ratios = ifr_meta.get("thinking_ratios") raw_prompt_ids = debug_info.get("prompt_ids") if isinstance(raw_prompt_ids, list) and raw_prompt_ids and isinstance(raw_prompt_ids[0], list): raw_prompt_ids = raw_prompt_ids[0] raw_generation_ids = debug_info.get("generation_ids") if isinstance(raw_generation_ids, list) and raw_generation_ids and isinstance(raw_generation_ids[0], list): raw_generation_ids = raw_generation_ids[0] user_prompt_indices = list(debug_info.get("user_prompt_indices") or []) chat_prompt_indices = list(debug_info.get("chat_prompt_indices") or []) prompt_len_full = len(raw_prompt_ids) if isinstance(raw_prompt_ids, list) else None raw_vectors = debug_info.get("raw_hop_vectors") or [] hop_vectors_raw = [vec.detach().cpu() if hasattr(vec, "detach") else torch.as_tensor(vec) for vec in raw_vectors] method_meta = {"ifr": analysis.sanitize_ifr_meta(ifr_meta)} elif mode == "ifr": # Standard IFR (single-hop span aggregate), with pre/post trim views. 
attr = llm_attr.LLMIFRAttribution( model, tokenizer, chunk_tokens=args.chunk_tokens, sink_chunk_tokens=args.sink_chunk_tokens, ) sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span if sink_span is None: raise ValueError("sink_span is required for IFR mode (use dataset sink_span or pass --sink_span).") span_result = attr.calculate_ifr_span( example.prompt, target=example.target, span=tuple(sink_span), ) span_meta = span_result.metadata.get("ifr") if span_result.metadata else None aggregate = span_meta.get("aggregate") if isinstance(span_meta, dict) else None if aggregate is None or not hasattr(aggregate, "token_importance_total"): raise RuntimeError("IFR span aggregate missing from metadata; cannot render pre-trim view.") raw_vector = aggregate.token_importance_total.detach().cpu() trimmed_vector = attr._project_vector(raw_vector) hop_vectors_raw = [raw_vector] hop_vectors_trimmed = [trimmed_vector] prompt_tokens_trimmed = list(attr.user_prompt_tokens) generation_tokens_trimmed = list(attr.generation_tokens) raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0] raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0] user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or []) chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or []) prompt_len_full = len(raw_prompt_ids) sink_abs = (prompt_len_full + sink_span[0], prompt_len_full + sink_span[1]) think_abs = (prompt_len_full + thinking_span[0], prompt_len_full + thinking_span[1]) if thinking_span else None meta = { "type": "span_aggregate", "ifr_view": "aggregate", "sink_span_generation": sink_span, "sink_span_absolute": sink_abs, "thinking_span_generation": thinking_span, "thinking_span_absolute": think_abs, } method_meta = {"ifr": 
analysis.tensor_to_list(meta)} elif mode == "ifr_all_positions_output_only": # IFR all-positions (output-only) + token-level CAGE (row/recursive) derived from the matrix. attr = llm_attr.LLMIFRAttribution( model, tokenizer, chunk_tokens=args.chunk_tokens, sink_chunk_tokens=args.sink_chunk_tokens, ) sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span if sink_span is None: raise ValueError( "sink_span is required for ifr_all_positions_output_only mode " "(use dataset sink_span or pass --sink_span)." ) attr_result = attr.calculate_ifr_for_all_positions_output_only( example.prompt, target=example.target, sink_span=tuple(sink_span), ) indices_to_explain = list(sink_span) _, row_attr, rec_attr = attr_result.get_all_token_attrs(indices_to_explain) row_vec = row_attr.squeeze(0).detach().cpu() rec_vec = rec_attr.squeeze(0).detach().cpu() hop_vectors_trimmed = [row_vec, rec_vec] prompt_tokens_trimmed = list(attr.user_prompt_tokens) generation_tokens_trimmed = list(attr.generation_tokens) raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0] raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0] user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or []) chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or []) prompt_len_full = len(raw_prompt_ids) gen_len = len(raw_generation_ids or []) hop_vectors_raw = [ _lift_trimmed_to_full( v, prompt_len_full=int(prompt_len_full or 0), gen_len=gen_len, user_prompt_indices=user_prompt_indices, ) for v in hop_vectors_trimmed ] ifr_meta = dict((attr_result.metadata or {}).get("ifr") or {}) ifr_meta["ifr_view"] = "all_positions_output_only (row+rec)" ifr_meta["panel_titles"] = ["Row attribution", "Recursive attribution (CAGE)"] ifr_meta["indices_to_explain"] = 
indices_to_explain method_meta = {"ifr": analysis.tensor_to_list(ifr_meta)} elif mode == "ifr_all_positions": # IFR all-positions (full generation) + token-level CAGE (row/recursive) derived from the matrix. attr = llm_attr.LLMIFRAttribution( model, tokenizer, chunk_tokens=args.chunk_tokens, sink_chunk_tokens=args.sink_chunk_tokens, ) sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None thinking_span = tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span if sink_span is None: raise ValueError( "sink_span is required for ifr_all_positions mode (use dataset sink_span or pass --sink_span)." ) attr_result = attr.calculate_ifr_for_all_positions( example.prompt, target=example.target, ) indices_to_explain = list(sink_span) _, row_attr, rec_attr = attr_result.get_all_token_attrs(indices_to_explain) row_vec = row_attr.squeeze(0).detach().cpu() rec_vec = rec_attr.squeeze(0).detach().cpu() hop_vectors_trimmed = [row_vec, rec_vec] prompt_tokens_trimmed = list(attr.user_prompt_tokens) generation_tokens_trimmed = list(attr.generation_tokens) raw_prompt_ids = attr.prompt_ids.detach().cpu().tolist()[0] raw_generation_ids = attr.generation_ids.detach().cpu().tolist()[0] user_prompt_indices = list(getattr(attr, "user_prompt_indices", []) or []) chat_prompt_indices = list(getattr(attr, "chat_prompt_indices", []) or []) prompt_len_full = len(raw_prompt_ids) gen_len = len(raw_generation_ids or []) hop_vectors_raw = [ _lift_trimmed_to_full( v, prompt_len_full=int(prompt_len_full or 0), gen_len=gen_len, user_prompt_indices=user_prompt_indices, ) for v in hop_vectors_trimmed ] ifr_meta = dict((attr_result.metadata or {}).get("ifr") or {}) ifr_meta["ifr_view"] = "all_positions (row+rec)" ifr_meta["panel_titles"] = ["Row attribution", "Recursive attribution (CAGE)"] ifr_meta["indices_to_explain"] = indices_to_explain method_meta = {"ifr": 
analysis.tensor_to_list(ifr_meta)} elif mode in ("attnlrp", "ft_attnlrp"): # Reuse the shared LLMLRPAttribution implementations (root-level). attributor = llm_attr.LLMLRPAttribution(model, tokenizer) sink_span = tuple(args.sink_span) if args.sink_span is not None else tuple(example.sink_span) if example.sink_span else None thinking_span = ( tuple(args.thinking_span) if args.thinking_span is not None else tuple(example.thinking_span) if example.thinking_span else sink_span ) if mode == "attnlrp": # Case-study AttnLRP: reuse FT-AttnLRP logic but take hop0 (the first span-aggregate) # for a full, signed attribution vector (no observation masking). attr_result = attributor.calculate_attnlrp_ft_hop0( example.prompt, target=example.target, sink_span=sink_span, thinking_span=thinking_span, neg_handling=args.attnlrp_neg_handling, norm_mode=args.attnlrp_norm_mode, ) meta = attr_result.metadata or {} multi_hop = meta.get("multi_hop_result") raw_attributions = getattr(multi_hop, "raw_attributions", None) or [] attnlrp_raw_attributions = list(raw_attributions) base_attr = raw_attributions[0] if raw_attributions else None if base_attr is None or not hasattr(base_attr, "token_importance_total"): raise RuntimeError("AttnLRP hop0 missing from multi-hop result.") hop0_vec = torch.as_tensor(getattr(base_attr, "token_importance_total"), dtype=torch.float32).detach().cpu() if hop0_vec.numel() <= 0: raise RuntimeError("Empty generation for AttnLRP case study.") # Use the actual sink span applied by hop0 (defaults to full generation when unset). 
sink_span = tuple(getattr(base_attr, "sink_range")) if thinking_span is None: thinking_span = sink_span hop_vectors_trimmed = [hop0_vec] thinking_ratios = list(getattr(multi_hop, "thinking_ratios", []) or []) method_meta = { "attnlrp": { "type": "calculate_attnlrp_multi_hop(n_hops=0) hop0 raw_attributions[0]", "sink_span_generation": sink_span, "thinking_span_generation": thinking_span, "thinking_ratios": thinking_ratios, "neg_handling": args.attnlrp_neg_handling, "norm_mode": args.attnlrp_norm_mode, "ratio_enabled": args.attnlrp_norm_mode == "norm", } } else: # exp2 ft_attnlrp: multi-hop aggregated AttnLRP (metadata contains per-hop vectors). attr_result = attributor.calculate_attnlrp_aggregated_multi_hop( example.prompt, target=example.target, sink_span=sink_span, thinking_span=thinking_span, n_hops=int(args.n_hops), neg_handling=args.attnlrp_neg_handling, norm_mode=args.attnlrp_norm_mode, ) meta = attr_result.metadata or {} multi_hop = meta.get("multi_hop_result") if multi_hop is None: raise RuntimeError("FT-AttnLRP case study missing metadata.multi_hop_result.") raw_attributions = getattr(multi_hop, "raw_attributions", None) or [] attnlrp_raw_attributions = list(raw_attributions) hop_vectors_trimmed = [ torch.as_tensor(getattr(hop, "token_importance_total"), dtype=torch.float32).detach().cpu() for hop in raw_attributions ] thinking_ratios = list(getattr(multi_hop, "thinking_ratios", []) or []) method_meta = { "attnlrp": { "type": "calculate_attnlrp_aggregated_multi_hop (exp2 ft_attnlrp)", "n_hops": int(args.n_hops), "sink_span_generation": sink_span, "thinking_span_generation": thinking_span, "thinking_ratios": thinking_ratios, "neg_handling": args.attnlrp_neg_handling, "norm_mode": args.attnlrp_norm_mode, "ratio_enabled": args.attnlrp_norm_mode == "norm", } } prompt_tokens_trimmed = list(attributor.user_prompt_tokens) generation_tokens_trimmed = list(attributor.generation_tokens) raw_prompt_ids = attributor.prompt_ids.detach().cpu().tolist()[0] 
raw_generation_ids = attributor.generation_ids.detach().cpu().tolist()[0] user_prompt_indices = list(getattr(attributor, "user_prompt_indices", []) or []) chat_prompt_indices = list(getattr(attributor, "chat_prompt_indices", []) or []) prompt_len_full = len(raw_prompt_ids) else: raise ValueError(f"Unsupported mode={mode}") if not hop_vectors_trimmed: raise RuntimeError("No hop vectors to visualize.") raw_tokens = build_raw_tokens_from_ids(tokenizer, raw_prompt_ids, raw_generation_ids) sink_span_abs = None thinking_span_abs = None if prompt_len_full is not None and sink_span is not None: sink_span_abs = (prompt_len_full + sink_span[0], prompt_len_full + sink_span[1]) if prompt_len_full is not None and thinking_span is not None: thinking_span_abs = (prompt_len_full + thinking_span[0], prompt_len_full + thinking_span[1]) prompt_len_full_safe = int(prompt_len_full or 0) roles_raw = build_raw_roles( raw_tokens, prompt_len_full_safe, user_prompt_indices, chat_prompt_indices, thinking_span_abs, sink_span_abs, ) prompt_tokens_only = list(prompt_tokens_trimmed) prompt_only_vectors = extract_prompt_only_vectors(hop_vectors_trimmed, len(prompt_tokens_only)) # Ensure every method has a pre-trim full vector per panel. 
if not hop_vectors_raw: if mode in ("attnlrp", "ft_attnlrp") and attnlrp_raw_attributions is not None: gen_len = len(raw_generation_ids or []) expected = int((prompt_len_full_safe + gen_len) if prompt_len_full is not None else 0) full_vectors: List[torch.Tensor] = [] for hop in attnlrp_raw_attributions: meta = getattr(hop, "metadata", None) or {} raw_full = meta.get("token_importance_total_with_chat_template") if raw_full is None: full_vectors = [] break v = _postprocess_attnlrp_full_vector( torch.as_tensor(raw_full, dtype=torch.float32), prompt_len_full=prompt_len_full_safe, gen_len=gen_len, user_prompt_indices=user_prompt_indices, neg_handling=args.attnlrp_neg_handling, norm_mode=args.attnlrp_norm_mode, ) if expected and int(v.numel()) != expected: raise RuntimeError( "AttnLRP full-vector length mismatch for pre-trim view: " f"got {int(v.numel())}, expected {expected}." ) full_vectors.append(v) hop_vectors_raw = full_vectors if not hop_vectors_raw and prompt_len_full is not None: # Fallback: lift trimmed vectors back to full token space with zeros for template tokens. gen_len = len(raw_generation_ids or []) hop_vectors_raw = [ _lift_trimmed_to_full( v, prompt_len_full=prompt_len_full_safe, gen_len=gen_len, user_prompt_indices=user_prompt_indices, ) for v in hop_vectors_trimmed ] if not hop_vectors_raw: raise RuntimeError("Missing pre-trim vectors; cannot render required full-sequence heatmap.") # Lightweight debug stats to catch silent all-zero / NaN cases. 
hop_stats_raw = [analysis.vector_stats(torch.nan_to_num(v.detach().cpu(), nan=0.0)) for v in hop_vectors_raw] hop_stats_prompt = [analysis.vector_stats(torch.nan_to_num(v.detach().cpu(), nan=0.0)) for v in prompt_only_vectors] for i in range(max(len(hop_stats_raw), len(hop_stats_prompt))): raw_abs = hop_stats_raw[i]["abs_max"] if i < len(hop_stats_raw) else None prompt_abs = hop_stats_prompt[i]["abs_max"] if i < len(hop_stats_prompt) else None print(f"[stats] panel {i}: raw_abs_max={raw_abs} prompt_abs_max={prompt_abs}") hop_token_raw = analysis.package_token_hops(hop_vectors_raw) hop_token_prompt = analysis.package_token_hops(prompt_only_vectors) case_meta: Dict[str, Any] = { "dataset": ds_name, "index": args.index, "sink_span": sink_span, "thinking_span": thinking_span, "n_hops": args.n_hops, "thinking_ratios": thinking_ratios, "mode": mode, "ifr_view": method_meta.get("ifr", {}).get("ifr_view") if isinstance(method_meta.get("ifr"), dict) else None, "panel_titles": method_meta.get("ifr", {}).get("panel_titles") if isinstance(method_meta.get("ifr"), dict) else None, "attnlrp_neg_handling": args.attnlrp_neg_handling if mode in ("attnlrp", "ft_attnlrp") else None, "attnlrp_norm_mode": args.attnlrp_norm_mode if mode in ("attnlrp", "ft_attnlrp") else None, "attnlrp_ratio_enabled": (args.attnlrp_norm_mode == "norm") if mode in ("attnlrp", "ft_attnlrp") else None, "vector_stats_raw": hop_stats_raw, "vector_stats_prompt": hop_stats_prompt, } generation_text = "".join(generation_tokens_trimmed) if generation_tokens_trimmed else "" prompt_text = example.prompt record = { "meta": case_meta, "prompt": prompt_text, "target": example.target, "generation": generation_text, "full_all_tokens": raw_tokens, "raw_token_roles": roles_raw, "prompt_tokens": prompt_tokens_only, "prompt_token_roles": ["user" for _ in range(len(prompt_tokens_only))], "token_hops_raw": hop_token_raw, "token_hops_prompt": hop_token_prompt, "ifr_meta": method_meta.get("ifr"), "attnlrp_meta": 
method_meta.get("attnlrp"), } out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) stem = make_output_stem(ds_name, args.index, mode) json_path = out_dir / f"{stem}.json" html_path = out_dir / f"{stem}.html" with json_path.open("w", encoding="utf-8") as f: json.dump(record, f, ensure_ascii=False, indent=2) html = viz.render_case_html( case_meta, token_view_raw={ "label": "Pre-trim token-level heatmap (full sequence with chat template)", "tokens": raw_tokens, "roles": roles_raw, "hops": hop_token_raw, }, token_view_prompt={ "label": "Prompt-only token-level heatmap (user prompt only)", "tokens": prompt_tokens_only, "roles": ["user" for _ in range(len(prompt_tokens_only))], "hops": hop_token_prompt, }, ) html_path.write_text(html, encoding="utf-8") print(f"[done] wrote {json_path}") print(f"[done] wrote {html_path}") if __name__ == "__main__": main()