flashtrace / exp /exp2 /run_exp.py
wenbopan's picture
Sync FlashTrace package from GitHub
55b60a8
#!/usr/bin/env python3
"""
Experiment 2 runner: token-level faithfulness (generation perturbation).
AT2 is omitted.
"""
from __future__ import annotations
import argparse
import hashlib
import json
import os
import sys
from itertools import islice
import math
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# Early CUDA mask handling: set CUDA_VISIBLE_DEVICES before importing torch.
def _early_set_cuda_visible_devices():
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--cuda", type=str, default=None)
# parse_known_args keeps the full argv for later parsing by the main parser
args, _ = parser.parse_known_args(sys.argv[1:])
if args.cuda and "," in args.cuda:
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
_early_set_cuda_visible_devices()
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, utils
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pathlib import Path
# ensure repo root on path
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
import llm_attr
import llm_attr_eval
from attribution_datasets import AttributionExample
from exp.exp2 import dataset_utils as ds_utils
utils.logging.set_verbosity_error()
def _sha1_text(text: str) -> str:
return hashlib.sha1(text.encode("utf-8")).hexdigest()
def _infer_attnlrp_spans_from_hops(
raw_attributions: Any,
*,
gen_len: int,
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
if not raw_attributions:
return (0, max(0, gen_len - 1)), (0, max(0, gen_len - 1))
sink_span = tuple(int(x) for x in raw_attributions[0].sink_range)
if len(raw_attributions) >= 2:
thinking_span = tuple(int(x) for x in raw_attributions[1].sink_range)
else:
thinking_span = sink_span
return sink_span, thinking_span
def _build_hop_trace_payload(
attr_func: str,
attr: Any,
*,
indices_to_explain: List[int],
) -> Optional[Dict[str, np.ndarray]]:
"""Extract per-hop vectors (postprocessed) and minimal span metadata."""
prompt_len = int(len(getattr(attr, "prompt_tokens", []) or []))
gen_len = int(len(getattr(attr, "generation_tokens", []) or []))
total_len = prompt_len + gen_len
if total_len <= 0:
return None
hop_vectors: List[torch.Tensor] = []
sink_span_gen: Optional[Tuple[int, int]] = None
thinking_span_gen: Optional[Tuple[int, int]] = None
attnlrp_neg_handling: str = ""
attnlrp_norm_mode: str = ""
attnlrp_ratio_enabled: int = -1
# IFR multi-hop variants expose projected hop vectors via metadata["ifr"]["per_hop_projected"].
ifr_meta = (getattr(attr, "metadata", None) or {}).get("ifr") or {}
ifr_per_hop = ifr_meta.get("per_hop_projected") or []
if ifr_per_hop:
hop_vectors = [torch.as_tensor(v, dtype=torch.float32) for v in ifr_per_hop]
sink_span_gen = ifr_meta.get("sink_span_generation")
thinking_span_gen = ifr_meta.get("thinking_span_generation")
if sink_span_gen is not None:
sink_span_gen = tuple(int(x) for x in sink_span_gen)
if thinking_span_gen is not None:
thinking_span_gen = tuple(int(x) for x in thinking_span_gen)
elif attr_func in ("ft_attnlrp", "attnlrp_aggregated_multi_hop"):
meta = getattr(attr, "metadata", None) or {}
attnlrp_neg_handling = str(meta.get("neg_handling") or "")
attnlrp_norm_mode = str(meta.get("norm_mode") or "")
if meta.get("ratio_enabled") is not None:
attnlrp_ratio_enabled = int(bool(meta.get("ratio_enabled")))
multi_hop = meta.get("multi_hop_result")
if multi_hop is None:
return None
raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
if not raw_attributions:
return None
hop_vectors = [
torch.as_tensor(getattr(hop, "token_importance_total"), dtype=torch.float32)
for hop in raw_attributions
]
sink_span_gen, thinking_span_gen = _infer_attnlrp_spans_from_hops(raw_attributions, gen_len=gen_len)
sink_override = meta.get("sink_span")
thinking_override = meta.get("thinking_span")
if sink_override is not None:
sink_span_gen = tuple(int(x) for x in sink_override)
if thinking_override is not None:
thinking_span_gen = tuple(int(x) for x in thinking_override)
else:
return None
if sink_span_gen is None:
sink_span_gen = (0, max(0, gen_len - 1))
if thinking_span_gen is None:
thinking_span_gen = sink_span_gen
stacked = torch.stack([v.reshape(-1) for v in hop_vectors], dim=0)
if stacked.shape[1] != total_len:
raise ValueError(
f"Hop vector length mismatch for {attr_func}: expected T={total_len}, got {stacked.shape[1]}."
)
return {
"vh": stacked.detach().cpu().numpy().astype(np.float32, copy=False),
"prompt_len": np.asarray(prompt_len, dtype=np.int64),
"gen_len": np.asarray(gen_len, dtype=np.int64),
"sink_span_gen": np.asarray(sink_span_gen, dtype=np.int64),
"thinking_span_gen": np.asarray(thinking_span_gen, dtype=np.int64),
"indices_to_explain_gen": np.asarray(indices_to_explain, dtype=np.int64),
"attnlrp_neg_handling": np.asarray(attnlrp_neg_handling, dtype="U16"),
"attnlrp_norm_mode": np.asarray(attnlrp_norm_mode, dtype="U16"),
"attnlrp_ratio_enabled": np.asarray(attnlrp_ratio_enabled, dtype=np.int64),
}
def _write_hop_trace(
    trace_dir: Path,
    *,
    example_idx: int,
    attr_func: str,
    prompt: str,
    target: Optional[str],
    payload: Dict[str, np.ndarray],
    manifest_handle,
) -> None:
    """Save one hop-trace ``.npz`` and append its JSON summary line to the manifest.

    The archive is written to ``trace_dir / "ex_<example_idx:06d>.npz"`` (the
    directory is created if needed). The manifest line holds SHA-1 digests of
    the prompt and target rather than the texts themselves, and the handle is
    flushed after the write.
    """
    trace_dir.mkdir(parents=True, exist_ok=True)
    file_name = f"ex_{example_idx:06d}.npz"
    np.savez_compressed(trace_dir / file_name, **payload)

    target_digest = _sha1_text(target) if target is not None else None
    record = {
        "example_idx": int(example_idx),
        "attr_func": attr_func,
        "file": file_name,
        "prompt_sha1": _sha1_text(prompt),
        "target_sha1": target_digest,
        "prompt_len": int(payload["prompt_len"].item()),
        "gen_len": int(payload["gen_len"].item()),
        # Rows and columns of the stacked hop matrix.
        "n_hops_plus_one": int(payload["vh"].shape[0]),
        "total_len": int(payload["vh"].shape[1]),
        "sink_span_gen": payload["sink_span_gen"].tolist(),
        "thinking_span_gen": payload["thinking_span_gen"].tolist(),
        "indices_to_explain_gen": payload["indices_to_explain_gen"].tolist(),
        "attnlrp_neg_handling": str(payload["attnlrp_neg_handling"].item()),
        "attnlrp_norm_mode": str(payload["attnlrp_norm_mode"].item()),
        "attnlrp_ratio_enabled": int(payload["attnlrp_ratio_enabled"].item()),
    }
    line = json.dumps(record, ensure_ascii=False)
    manifest_handle.write(line + "\n")
    manifest_handle.flush()
def _parse_modes(mode_args: Any) -> List[str]:
"""Parse --mode which may be provided as multiple args and/or comma-separated."""
if mode_args is None:
raw_parts: List[str] = []
elif isinstance(mode_args, str):
raw_parts = [mode_args]
else:
raw_parts = [str(x) for x in mode_args]
modes: List[str] = []
for chunk in raw_parts:
for part in str(chunk).split(","):
m = part.strip()
if m:
modes.append(m)
# Default to faithfulness_gen for backward compatibility.
if not modes:
modes = ["faithfulness_gen"]
allowed = {"faithfulness_gen", "recovery_ruler"}
seen: set[str] = set()
unique: List[str] = []
for m in modes:
if m not in seen:
unique.append(m)
seen.add(m)
unknown = [m for m in unique if m not in allowed]
if unknown:
raise SystemExit(f"Unsupported --mode value(s): {unknown}. Allowed: {sorted(allowed)}.")
return unique
def _trace_run_tag(
testing_dict: Dict[str, Any],
*,
modes: List[str],
total: int,
) -> str:
attr_func = str(testing_dict.get("attr_func") or "attr")
parts = [attr_func]
if attr_func in (
"ifr_multi_hop",
"ifr_in_all_gen",
"ifr_multi_hop_stop_words",
"ifr_multi_hop_both",
"ifr_multi_hop_split_hop",
"ft_attnlrp",
"attnlrp_aggregated_multi_hop",
):
parts.append(f"n{int(testing_dict.get('n_hops', 0))}")
if attr_func in ("attnlrp", "ft_attnlrp", "attnlrp_aggregated_multi_hop"):
parts.append(f"neg{str(testing_dict.get('attnlrp_neg_handling', ''))}")
parts.append(f"norm{str(testing_dict.get('attnlrp_norm_mode', ''))}")
if modes:
parts.append("m" + "+".join(modes))
parts.append(f"{int(total)}ex")
return "_".join(parts)
def _token_importance_vector(attr: torch.Tensor) -> np.ndarray:
"""Return token importance vector w = sum_rows(attr) in shape [P+G]."""
w = torch.nan_to_num(attr.sum(0).to(dtype=torch.float32), nan=0.0).clamp(min=0.0)
return w.detach().cpu().numpy().astype(np.float32, copy=False)
def _build_sample_trace_payload(
    example: ds_utils.CachedExample,
    *,
    attr_list: List[torch.Tensor],
    prompt_len: int,
    user_prompt_indices: Optional[List[int]],
    keep_prompt_token_indices: Optional[List[int]],
    gold_prompt_token_indices: Optional[List[int]],
    hop_payload: Optional[Dict[str, np.ndarray]],
    faithfulness_scores: Optional[np.ndarray],
    recovery_scores: Optional[np.ndarray],
    time_attr_s: Optional[float],
    time_faith_s: Optional[float],
    time_recovery_s: Optional[float],
) -> Dict[str, np.ndarray]:
    """Assemble every per-sample array that goes into the trace archive.

    ``attr_list`` must unpack into exactly three matrices (seq, row, rec). For
    each one the token-importance vector is stored twice: in full and cut to
    its first ``prompt_len`` entries. Optional inputs are added only when they
    are not ``None``. Entries of ``hop_payload`` are merged last and never
    overwrite a key that is already present.
    """
    seq_attr, row_attr, rec_attr = attr_list
    n_gen = int(seq_attr.shape[0])

    seq_vec = _token_importance_vector(seq_attr)
    row_vec = _token_importance_vector(row_attr)
    rec_vec = _token_importance_vector(rec_attr)

    payload: Dict[str, np.ndarray] = {
        "v_seq_all": seq_vec,
        "v_row_all": row_vec,
        "v_rec_all": rec_vec,
        "v_seq_prompt": seq_vec[:prompt_len],
        "v_row_prompt": row_vec[:prompt_len],
        "v_rec_prompt": rec_vec[:prompt_len],
        "prompt_len": np.asarray(int(prompt_len), dtype=np.int64),
        "gen_len": np.asarray(int(n_gen), dtype=np.int64),
        "indices_to_explain_gen": np.asarray(list(example.indices_to_explain or []), dtype=np.int64),
    }

    # Optional integer sequences, in the order they are stored.
    optional_int_fields = (
        ("sink_span_gen", example.sink_span),
        ("thinking_span_gen", example.thinking_span),
        ("user_prompt_indices", user_prompt_indices),
        ("keep_prompt_token_indices", keep_prompt_token_indices),
        ("gold_prompt_token_indices", gold_prompt_token_indices),
    )
    for key, values in optional_int_fields:
        if values is not None:
            payload[key] = np.asarray(list(values), dtype=np.int64)

    # Optional score arrays.
    for key, scores in (
        ("faithfulness_scores", faithfulness_scores),
        ("recovery_scores", recovery_scores),
    ):
        if scores is not None:
            payload[key] = np.asarray(scores, dtype=np.float64)

    # Optional wall-clock timings, stored as 0-d float64 arrays.
    for key, seconds in (
        ("time_attr_s", time_attr_s),
        ("time_faith_s", time_faith_s),
        ("time_recovery_s", time_recovery_s),
    ):
        if seconds is not None:
            payload[key] = np.asarray(float(seconds), dtype=np.float64)

    if hop_payload is not None:
        for key, value in hop_payload.items():
            # Keys computed above take precedence over hop-payload duplicates.
            payload.setdefault(key, value)
    return payload
def _write_sample_trace(
    trace_dir: Path,
    *,
    example_idx: int,
    attr_func: str,
    prompt: str,
    target: Optional[str],
    payload: Dict[str, np.ndarray],
    manifest_handle,
    recovery_skipped_reason: Optional[str],
) -> None:
    """Save one sample's ``.npz`` archive and append its JSON line to the manifest.

    The archive goes to ``trace_dir / "ex_<example_idx:06d>.npz"``. The
    manifest record repeats the small payload entries as plain JSON values and
    adds derived bookkeeping: ``input_len``, ``output_len`` and ``cot_len``
    (inclusive span lengths), plus ``rise_*`` / ``mas_*`` taken from columns 0
    and 1 of the three faithfulness rows. Prompt and target are recorded as
    SHA-1 digests only. The handle is flushed after the write.
    """

    def _as_list(key):
        entry = payload.get(key)
        return entry.tolist() if entry is not None else None

    def _as_seconds(key):
        entry = payload.get(key)
        return float(np.asarray(entry).item()) if entry is not None else None

    def _scalar(key, cast, fallback):
        entry = payload.get(key)
        return cast(entry.item()) if entry is not None else fallback

    def _inclusive_len(span):
        # Length of an inclusive [start, end] pair; None for anything else.
        if not (isinstance(span, list) and len(span) == 2):
            return None
        try:
            start, end = int(span[0]), int(span[1])
        except Exception:
            return None
        return (end - start + 1) if end >= start else None

    trace_dir.mkdir(parents=True, exist_ok=True)
    file_name = f"ex_{example_idx:06d}.npz"
    np.savez_compressed(trace_dir / file_name, **payload)

    n_prompt = int(np.asarray(payload.get("prompt_len", 0)).item())
    n_gen = int(np.asarray(payload.get("gen_len", 0)).item())

    record: Dict[str, Any] = {
        "example_idx": int(example_idx),
        "attr_func": attr_func,
        "file": file_name,
        "prompt_sha1": _sha1_text(prompt),
        "target_sha1": _sha1_text(target) if target is not None else None,
        "prompt_len": n_prompt,
        "gen_len": n_gen,
        "indices_to_explain_gen": _as_list("indices_to_explain_gen"),
        "sink_span_gen": _as_list("sink_span_gen"),
        "thinking_span_gen": _as_list("thinking_span_gen"),
        "faithfulness_scores": _as_list("faithfulness_scores"),
        "recovery_scores": _as_list("recovery_scores"),
        "recovery_skipped_reason": recovery_skipped_reason,
        "time_attr_s": _as_seconds("time_attr_s"),
        "time_faith_s": _as_seconds("time_faith_s"),
        "time_recovery_s": _as_seconds("time_recovery_s"),
    }

    # Derived, sample-level bookkeeping (token lengths and per-sample MAS/RISE).
    record["input_len"] = int(n_prompt)
    record["output_len"] = _inclusive_len(record.get("sink_span_gen"))
    record["cot_len"] = _inclusive_len(record.get("thinking_span_gen"))

    metric_keys = (
        ("rise_seq", "mas_seq"),
        ("rise_row", "mas_row"),
        ("rise_rec", "mas_rec"),
    )
    for rise_key, mas_key in metric_keys:
        record[rise_key] = None
        record[mas_key] = None

    faith = record.get("faithfulness_scores")
    if isinstance(faith, list) and len(faith) == 3:
        try:
            for (rise_key, mas_key), row in zip(metric_keys, faith):
                record[rise_key] = float(row[0])
                record[mas_key] = float(row[1])
        except Exception:
            # Best effort: metrics assigned before the failure are kept, the rest stay None.
            pass

    hop_matrix = payload.get("vh")
    if hop_matrix is not None:
        record["n_hops_plus_one"] = int(hop_matrix.shape[0])
        record["total_len"] = int(hop_matrix.shape[1])

    record["attnlrp_neg_handling"] = _scalar("attnlrp_neg_handling", str, "")
    record["attnlrp_norm_mode"] = _scalar("attnlrp_norm_mode", str, "")
    record["attnlrp_ratio_enabled"] = _scalar("attnlrp_ratio_enabled", int, -1)

    manifest_handle.write(json.dumps(record, ensure_ascii=False) + "\n")
    manifest_handle.flush()
def _compute_faithfulness_scores(
testing_dict: Dict[str, Any],
*,
attr_list: List[torch.Tensor],
prompt_len: int,
prompt: str,
generation: str,
llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
user_prompt_indices: Optional[List[int]],
keep_prompt_token_indices: Optional[List[int]],
) -> np.ndarray:
attr_func = str(testing_dict.get("attr_func") or "")
results: List[Tuple[float, float, float]] = []
for attr in attr_list:
attr_prompt = attr[:, :prompt_len]
if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and keep_prompt_token_indices is not None:
import ft_ifr_improve
scores = ft_ifr_improve.faithfulness_test_skip_tokens(
llm_evaluator,
attr_prompt,
prompt,
generation,
keep_prompt_token_indices=keep_prompt_token_indices,
user_prompt_indices=user_prompt_indices,
)
elif user_prompt_indices is not None:
scores = _faithfulness_test_with_user_prompt_indices(
llm_evaluator,
attr_prompt,
prompt,
generation,
user_prompt_indices=user_prompt_indices,
)
else:
scores = llm_evaluator.faithfulness_test(attr_prompt, prompt, generation)
results.append(scores)
return np.asarray(results, dtype=np.float64)
def _compute_recovery_scores(
testing_dict: Dict[str, Any],
*,
attr_list: List[torch.Tensor],
prompt_len: int,
gold_prompt_token_indices: List[int],
llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
keep_prompt_token_indices: Optional[List[int]],
) -> Tuple[Optional[np.ndarray], Optional[str]]:
attr_func = str(testing_dict.get("attr_func") or "")
if prompt_len <= 0:
return None, "empty_prompt_len"
gold_prompt = [int(x) for x in (gold_prompt_token_indices or [])]
if not gold_prompt:
return None, "empty_gold_prompt"
if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and keep_prompt_token_indices is not None:
import ft_ifr_improve
keep_set = {int(x) for x in keep_prompt_token_indices}
gold_filtered = [idx for idx in gold_prompt if int(idx) in keep_set]
if not gold_filtered:
return None, "empty_gold_after_keep_filter"
scores = [
ft_ifr_improve.evaluate_attr_recovery_skip_tokens(
attr[:, :prompt_len],
keep_prompt_token_indices=keep_prompt_token_indices,
gold_prompt_token_indices=gold_prompt,
top_fraction=0.1,
)
for attr in attr_list
]
else:
scores = [
llm_evaluator.evaluate_attr_recovery(
attr,
prompt_len=prompt_len,
gold_prompt_token_indices=gold_prompt,
top_fraction=0.1,
)
for attr in attr_list
]
return np.asarray(scores, dtype=np.float64), None
def evaluate_dataset_multi(
    args,
    dataset_name: str,
    examples: List[ds_utils.CachedExample],
    testing_dict: Dict[str, Any],
    *,
    modes: List[str],
) -> Dict[str, Any]:
    """Attribute each example once and score it under every requested mode.

    At most ``args.num_examples`` examples are processed. For each one the
    function sets ``testing_dict["batch_size"]`` from the tokenised response
    length, calls ``run_attribution`` once, then runs the faithfulness test
    (``"faithfulness_gen"`` in ``modes``) and/or the recovery test
    (``"recovery_ruler"`` in ``modes``) on the same attribution matrices. When
    ``args.save_hop_traces`` is truthy, a per-sample archive and a
    ``manifest.jsonl`` line are written under
    ``<output_root>/traces/<dataset>/<model_tag>/<run_tag>``; a failed trace
    write is only reported as a warning.

    Returns a dict with a ``"faithfulness"`` and/or ``"recovery"`` entry. Each
    is ``None`` when no sample was scored, otherwise a dict with ``mean``,
    ``std`` and ``avg_time`` (mean attribution time); the recovery entry also
    has ``used`` and ``skipped`` counts.

    Raises:
        SystemExit: in recovery mode, when an example has no
            ``metadata["needle_spans"]`` list or no cached target.
    """
    tokenizer = testing_dict["tokenizer"]
    llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(testing_dict["model"], tokenizer)
    run_faith = "faithfulness_gen" in modes
    run_recovery = "recovery_ruler" in modes

    faith_per_example: List[np.ndarray] = []
    faith_attr_times: List[float] = []
    recovery_per_example: List[np.ndarray] = []
    recovery_attr_times: List[float] = []
    n_recovery_skipped = 0

    n_total = min(len(examples), args.num_examples)
    selected = islice(examples, n_total)

    write_traces = bool(getattr(args, "save_hop_traces", False))
    manifest = None
    trace_root: Optional[Path] = None
    if write_traces:
        model_tag = str(testing_dict.get("model_tag", "model"))
        run_tag = _trace_run_tag(testing_dict, modes=modes, total=n_total)
        trace_root = Path(args.output_root) / "traces" / dataset_name / model_tag / run_tag
        trace_root.mkdir(parents=True, exist_ok=True)
        manifest = open(trace_root / "manifest.jsonl", "w", encoding="utf-8")

    try:
        for example_idx, ex in enumerate(selected):
            if run_recovery:
                spans = (ex.metadata or {}).get("needle_spans")
                if not (isinstance(spans, list) and spans):
                    raise SystemExit(
                        "recovery_ruler requires RULER samples with metadata.needle_spans; "
                        f"dataset={dataset_name} has missing/empty needle_spans."
                    )
                if ex.target is None:
                    raise SystemExit(
                        "recovery_ruler requires cached targets (CoT+answer) so row/rec attribution is well-defined. "
                        f"dataset={dataset_name} has target=None; run exp/exp2/sample_and_filter.py first."
                    )

            # Determine generation/target once.
            target = ex.target
            if target is None:
                generation, full_output = llm_evaluator.response(ex.prompt)
                target = generation
                response_len = len(tokenizer(full_output).input_ids)
            else:
                full_text = llm_evaluator.format_prompt(" " + ex.prompt) + target
                response_len = len(tokenizer(full_text).input_ids)

            # Scale the batch size down as the tokenised sequence gets longer.
            token_budget = testing_dict["max_input_len"] - 100
            testing_dict["batch_size"] = max(1, math.floor(token_budget / max(1, response_len)))

            gold_prompt: Optional[List[int]] = None
            if run_recovery:
                gold_prompt = ds_utils.ruler_gold_prompt_token_indices(ex, tokenizer)
                recovery_only = not run_faith and not write_traces
                # Preserve recovery-only fast path when not saving traces: skip samples with empty gold.
                if recovery_only and not gold_prompt:
                    n_recovery_skipped += 1
                    continue

            time_attr_s = None
            time_faith_s = None
            time_recovery_s = None

            attr_started = time.perf_counter()
            attr_list, hop_payload, user_prompt_indices, keep_prompt_token_indices = run_attribution(
                testing_dict, ex, target
            )
            time_attr_s = time.perf_counter() - attr_started

            seq_attr = attr_list[0]
            prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0])  # cols=(P+G), rows=G

            if run_recovery and gold_prompt:
                recovery_attr_times.append(float(time_attr_s))

            faith_scores = None
            if run_faith:
                faith_started = time.perf_counter()
                faith_scores = _compute_faithfulness_scores(
                    testing_dict,
                    attr_list=attr_list,
                    prompt_len=prompt_len,
                    prompt=ex.prompt,
                    generation=target,
                    llm_evaluator=llm_evaluator,
                    user_prompt_indices=user_prompt_indices,
                    keep_prompt_token_indices=keep_prompt_token_indices,
                )
                time_faith_s = time.perf_counter() - faith_started
                faith_per_example.append(faith_scores)
                # The reported time is the attribution time, not the test time.
                faith_attr_times.append(float(time_attr_s))

            recovery_scores = None
            recovery_skip_reason = None
            if run_recovery:
                if gold_prompt:
                    recovery_started = time.perf_counter()
                    recovery_scores, recovery_skip_reason = _compute_recovery_scores(
                        testing_dict,
                        attr_list=attr_list,
                        prompt_len=prompt_len,
                        gold_prompt_token_indices=gold_prompt,
                        llm_evaluator=llm_evaluator,
                        keep_prompt_token_indices=keep_prompt_token_indices,
                    )
                    time_recovery_s = time.perf_counter() - recovery_started
                    if recovery_scores is not None:
                        recovery_per_example.append(recovery_scores)
                    else:
                        n_recovery_skipped += 1
                else:
                    recovery_skip_reason = "empty_gold_prompt"
                    n_recovery_skipped += 1

            if manifest is not None and trace_root is not None:
                try:
                    sample_payload = _build_sample_trace_payload(
                        ex,
                        attr_list=attr_list,
                        prompt_len=prompt_len,
                        user_prompt_indices=user_prompt_indices,
                        keep_prompt_token_indices=keep_prompt_token_indices,
                        gold_prompt_token_indices=gold_prompt,
                        hop_payload=hop_payload,
                        faithfulness_scores=faith_scores,
                        recovery_scores=recovery_scores,
                        time_attr_s=time_attr_s,
                        time_faith_s=time_faith_s,
                        time_recovery_s=time_recovery_s,
                    )
                    _write_sample_trace(
                        trace_root,
                        example_idx=example_idx,
                        attr_func=str(testing_dict.get("attr_func") or ""),
                        prompt=ex.prompt,
                        target=target,
                        payload=sample_payload,
                        manifest_handle=manifest,
                        recovery_skipped_reason=recovery_skip_reason,
                    )
                except Exception as exc:
                    # Trace output is best effort; the evaluation itself carries on.
                    print(f"[warn] sample trace save failed for {testing_dict.get('attr_func')} ex={example_idx}: {exc}")
    finally:
        if manifest is not None:
            try:
                manifest.close()
            except Exception:
                pass

    summary: Dict[str, Any] = {}
    if run_faith:
        if faith_per_example:
            stacked = np.stack(faith_per_example, axis=0)  # [N, 3, 3]
            summary["faithfulness"] = {
                "mean": stacked.mean(0),
                "std": stacked.std(0),
                "avg_time": float(np.mean(faith_attr_times)) if faith_attr_times else 0.0,
            }
        else:
            summary["faithfulness"] = None
    if run_recovery:
        if recovery_per_example:
            stacked = np.stack(recovery_per_example, axis=0)  # [N, 3]
            summary["recovery"] = {
                "mean": stacked.mean(0),
                "std": stacked.std(0),
                "avg_time": float(np.mean(recovery_attr_times)) if recovery_attr_times else 0.0,
                "used": int(stacked.shape[0]),
                "skipped": int(n_recovery_skipped),
            }
        else:
            summary["recovery"] = None
    return summary
def _faithfulness_test_with_user_prompt_indices(
llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
attribution: torch.Tensor,
prompt: str,
generation: str,
*,
user_prompt_indices: List[int],
k: int = 20, ### control the MAS steps per sample
) -> Tuple[float, float, float]:
"""Token-level MAS/RISE faithfulness via guided deletion in k perturbation steps using provided prompt indices.
This mirrors llm_attr_eval.LLMAttributionEvaluator.faithfulness_test, but avoids
locating the user prompt span via token-id subsequence matching (which may fail
for some tokenizers due to non-compositional BPE merges at template boundaries).
"""
def auc(arr: np.ndarray) -> float:
return (arr.sum() - arr[0] / 2 - arr[-1] / 2) / max(1, (arr.shape[0] - 1))
pad_token_id = llm_evaluator._ensure_pad_token_id()
user_prompt = " " + prompt
formatted_prompt = llm_evaluator.format_prompt(user_prompt)
formatted_ids = llm_evaluator.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).input_ids
prompt_ids = formatted_ids.to(llm_evaluator.device)
prompt_ids_perturbed = prompt_ids.clone()
generation_ids = llm_evaluator.tokenizer(
generation + llm_evaluator.tokenizer.eos_token,
return_tensors="pt",
add_special_tokens=False,
).input_ids.to(llm_evaluator.device)
attr_cpu = attribution.detach().cpu()
w = attr_cpu.sum(0)
sorted_attr_indices = torch.argsort(w, descending=True)
attr_sum = float(w.sum().item())
P = int(w.numel())
if len(user_prompt_indices) != P:
raise ValueError(
"user_prompt_indices length does not match prompt-side attribution length: "
f"indices P={len(user_prompt_indices)}, attr P={P}."
)
if P == 0:
return 0.0, 0.0, 0.0
if max(user_prompt_indices) >= int(prompt_ids_perturbed.shape[1]):
raise ValueError("user_prompt_indices contains an out-of-bounds index for formatted prompt ids.")
if P > 0:
steps = int(k) if k is not None else 0
if steps <= 0:
steps = 1
steps = min(steps, P)
else:
steps = 0
scores = np.zeros(steps + 1, dtype=np.float64)
density = np.zeros(steps + 1, dtype=np.float64)
scores[0] = (
llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
)
density[0] = 1.0
if attr_sum <= 0:
density = np.linspace(1.0, 0.0, steps + 1)
base = P // steps
remainder = P % steps
start = 0
for step in range(steps):
size = base + (1 if step < remainder else 0)
group = sorted_attr_indices[start : start + size]
start += size
for idx in group:
j = int(idx.item())
abs_pos = int(user_prompt_indices[j])
prompt_ids_perturbed[0, abs_pos] = pad_token_id
scores[step + 1] = (
llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
)
if attr_sum > 0:
dec = float(w.index_select(0, group).sum().item()) / attr_sum
density[step + 1] = density[step] - dec
min_normalized_pred = 1.0
normalized_model_response = scores.copy()
for i in range(len(scores)):
normalized_pred = (normalized_model_response[i] - scores[-1]) / (abs(scores[0] - scores[-1]))
normalized_pred = np.clip(normalized_pred, 0.0, 1.0)
min_normalized_pred = min(min_normalized_pred, normalized_pred)
normalized_model_response[i] = min_normalized_pred
alignment_penalty = np.abs(normalized_model_response - density)
corrected_scores = normalized_model_response + alignment_penalty
corrected_scores = corrected_scores.clip(0.0, 1.0)
corrected_scores = (corrected_scores - np.min(corrected_scores)) / (np.max(corrected_scores) - np.min(corrected_scores))
if np.isnan(corrected_scores).any():
corrected_scores = np.linspace(1.0, 0.0, len(scores))
return auc(normalized_model_response), auc(corrected_scores), auc(normalized_model_response + alignment_penalty)
def load_model(model_name: str, device: str):
    """Load a causal LM in float16 with eager attention, plus its tokenizer.

    ``device`` selects the ``device_map`` passed to ``from_pretrained``:
    ``"auto"`` is forwarded unchanged, ``"cuda:<n>"`` pins the whole model to
    device index ``n``, and anything else passes ``None``. The tokenizer's pad
    token is set to its EOS token and the model is switched to eval mode.

    Returns:
        ``(model, tokenizer)``.
    """
    if device == "auto":
        placement = "auto"
    elif device.startswith("cuda:"):
        # The empty-string key maps the entire model to one device index.
        placement = {"": int(device.split(":")[1])}
    else:
        placement = None

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=placement,
        torch_dtype=torch.float16,
        attn_implementation="eager",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model.eval()
    return model, tokenizer
def resolve_device(args) -> str:
    """Turn the ``--cuda`` / ``--cuda_num`` arguments into a device string.

    * A comma-separated ``args.cuda`` is exported as ``CUDA_VISIBLE_DEVICES``
      and yields ``"auto"``.
    * Any other non-blank ``args.cuda`` yields ``"cuda:<args.cuda>"``.
    * Otherwise the result is ``"cuda:<args.cuda_num>"``.

    The last two cases return ``"cpu"`` when CUDA is unavailable.
    """
    cuda_arg = args.cuda
    if cuda_arg is not None:
        if "," in cuda_arg:
            os.environ["CUDA_VISIBLE_DEVICES"] = cuda_arg
            return "auto"
        if cuda_arg.strip():
            return f"cuda:{cuda_arg}" if torch.cuda.is_available() else "cpu"
    return f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"
def run_attribution(
testing_dict, example: ds_utils.CachedExample, target: Optional[str]
) -> Tuple[List[torch.Tensor], Optional[Dict[str, np.ndarray]], Optional[List[int]]]:
model = testing_dict["model"]
tokenizer = testing_dict["tokenizer"]
attr_func = testing_dict["attr_func"]
indices_to_explain = example.indices_to_explain
if not (isinstance(indices_to_explain, list) and len(indices_to_explain) == 2):
raise ValueError(
"exp2 requires token-span indices_to_explain=[start_tok,end_tok]. "
"Please re-sample or run exp/exp2/migrate_indices_to_explain_token_span.py on your cache."
)
llm_attributor = None
if "IG" in attr_func:
llm_attributor = llm_attr.LLMGradientAttribtion(model, tokenizer)
attr = llm_attributor.calculate_IG_per_generation(
example.prompt,
20,
tokenizer.eos_token_id,
batch_size=testing_dict["batch_size"],
target=target,
)
elif "perturbation" in attr_func:
if attr_func in ("perturbation_all_fast", "perturbation_CLP_fast", "perturbation_REAGENT_fast"):
import perturbation_fast
llm_attributor = perturbation_fast.LLMPerturbationFastAttribution(model, tokenizer)
if attr_func == "perturbation_all_fast":
attr = llm_attributor.calculate_feature_ablation_segments(
example.prompt,
baseline=tokenizer.eos_token_id,
measure="log_loss",
target=target,
source_k=20,
)
elif attr_func == "perturbation_CLP_fast":
attr = llm_attributor.calculate_feature_ablation_segments(
example.prompt,
baseline=tokenizer.eos_token_id,
measure="KL",
target=target,
source_k=20,
)
else:
attr = llm_attributor.calculate_feature_ablation_segments_mlm(
example.prompt,
target=target,
source_k=20,
)
else:
llm_attributor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
if attr_func == "perturbation_all":
attr = llm_attributor.calculate_feature_ablation_sentences(
example.prompt, baseline=tokenizer.eos_token_id, measure="log_loss", target=target
)
elif attr_func == "perturbation_CLP":
attr = llm_attributor.calculate_feature_ablation_sentences(
example.prompt, baseline=tokenizer.eos_token_id, measure="KL", target=target
)
elif attr_func == "perturbation_REAGENT":
attr = llm_attributor.calculate_feature_ablation_sentences_mlm(example.prompt, target=target)
else:
raise ValueError(f"Unsupported perturbation attr_func {attr_func}")
elif "attention" in attr_func:
llm_attributor = llm_attr.LLMAttentionAttribution(model, tokenizer)
llm_attributor_ig = llm_attr.LLMGradientAttribtion(model, tokenizer)
attr = llm_attributor.calculate_attention_attribution(example.prompt, target=target)
attr_b = llm_attributor_ig.calculate_IG_per_generation(
example.prompt, 20, tokenizer.eos_token_id, batch_size=testing_dict["batch_size"], target=target
)
attr.attribution_matrix = attr.attribution_matrix * attr_b.attribution_matrix
elif attr_func == "ifr_all_positions":
llm_attributor = llm_attr.LLMIFRAttribution(
model,
tokenizer,
chunk_tokens=testing_dict["chunk_tokens"],
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
)
attr = llm_attributor.calculate_ifr_for_all_positions(example.prompt, target=target)
elif attr_func == "ifr_all_positions_output_only":
llm_attributor = llm_attr.LLMIFRAttribution(
model,
tokenizer,
chunk_tokens=testing_dict["chunk_tokens"],
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
)
sink_span = tuple(example.sink_span) if example.sink_span else tuple(indices_to_explain)
attr = llm_attributor.calculate_ifr_for_all_positions_output_only(
example.prompt,
target=target,
sink_span=sink_span,
)
elif attr_func == "ifr_multi_hop":
llm_attributor = llm_attr.LLMIFRAttribution(
model,
tokenizer,
chunk_tokens=testing_dict["chunk_tokens"],
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
)
attr = llm_attributor.calculate_ifr_multi_hop(
example.prompt,
target=target,
sink_span=tuple(example.sink_span) if example.sink_span else None,
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
n_hops=testing_dict["n_hops"],
)
elif attr_func == "ifr_in_all_gen":
import ft_ifr_improve
llm_attributor = ft_ifr_improve.LLMIFRAttributionInAllGen(
model,
tokenizer,
chunk_tokens=testing_dict["chunk_tokens"],
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
)
attr = llm_attributor.calculate_ifr_in_all_gen(
example.prompt,
target=target,
sink_span=tuple(example.sink_span) if example.sink_span else None,
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
n_hops=testing_dict["n_hops"],
)
elif attr_func == "ifr_multi_hop_stop_words":
import ft_ifr_improve
llm_attributor = ft_ifr_improve.LLMIFRAttributionImproved(
model,
tokenizer,
chunk_tokens=testing_dict["chunk_tokens"],
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
)
attr = llm_attributor.calculate_ifr_multi_hop_stop_words(
example.prompt,
target=target,
sink_span=tuple(example.sink_span) if example.sink_span else None,
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
n_hops=testing_dict["n_hops"],
)
elif attr_func == "ifr_multi_hop_both":
import ft_ifr_improve
llm_attributor = ft_ifr_improve.LLMIFRAttributionBoth(
model,
tokenizer,
chunk_tokens=testing_dict["chunk_tokens"],
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
)
attr = llm_attributor.calculate_ifr_multi_hop_both(
example.prompt,
target=target,
sink_span=tuple(example.sink_span) if example.sink_span else None,
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
n_hops=testing_dict["n_hops"],
)
elif attr_func == "ifr_multi_hop_split_hop":
import ft_ifr_improve
llm_attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(
model,
tokenizer,
chunk_tokens=testing_dict["chunk_tokens"],
sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
)
attr = llm_attributor.calculate_ifr_multi_hop_split_hop(
example.prompt,
target=target,
sink_span=tuple(example.sink_span) if example.sink_span else None,
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
n_hops=testing_dict["n_hops"],
)
elif attr_func == "attnlrp":
llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
attr = llm_attributor.calculate_attnlrp_ft_hop0(
example.prompt,
target=target,
sink_span=tuple(example.sink_span) if example.sink_span else None,
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
neg_handling=str(testing_dict.get("attnlrp_neg_handling", "drop")),
norm_mode=str(testing_dict.get("attnlrp_norm_mode", "norm")),
)
elif attr_func in ("ft_attnlrp", "attnlrp_aggregated_multi_hop"):
llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
attr = llm_attributor.calculate_attnlrp_aggregated_multi_hop(
example.prompt,
target=target,
sink_span=tuple(example.sink_span) if example.sink_span else None,
thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
n_hops=testing_dict["n_hops"],
neg_handling=str(testing_dict.get("attnlrp_neg_handling", "drop")),
norm_mode=str(testing_dict.get("attnlrp_norm_mode", "norm")),
)
elif attr_func == "basic":
llm_attributor = llm_attr.LLMBasicAttribution(model, tokenizer)
attr = llm_attributor.calculate_basic_attribution(example.prompt, target=target)
else:
raise ValueError(f"Unsupported attr_func {attr_func}")
seq_attr, row_attr, rec_attr = attr.get_all_token_attrs(indices_to_explain)
hop_payload = None
if bool(testing_dict.get("save_hop_traces", False)):
try:
hop_payload = _build_hop_trace_payload(attr_func, attr, indices_to_explain=indices_to_explain)
except Exception as exc:
print(f"[warn] hop trace extraction failed for {attr_func}: {exc}")
hop_payload = None
user_prompt_indices = getattr(llm_attributor, "user_prompt_indices", None)
if isinstance(user_prompt_indices, list):
user_prompt_indices = [int(x) for x in user_prompt_indices]
else:
user_prompt_indices = None
keep_prompt_token_indices = None
if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both"):
try:
import ft_ifr_improve
keep_prompt_token_indices = ft_ifr_improve.keep_token_indices(list(attr.prompt_tokens))
except Exception:
keep_prompt_token_indices = None
return [seq_attr, row_attr, rec_attr], hop_payload, user_prompt_indices, keep_prompt_token_indices
def faithfulness_generation(
    testing_dict, example: ds_utils.CachedExample, target: str, llm_evaluator
) -> Tuple[np.ndarray, Optional[Dict[str, np.ndarray]]]:
    """Run attribution for one example and score each attribution matrix.

    The configured attribution method (``testing_dict["attr_func"]``) is run
    through ``run_attribution``; every matrix it returns is cut down to its
    prompt-side columns and handed to one of three faithfulness tests:

    * the stop-word-aware test from ``ft_ifr_improve`` when the method is
      ``ifr_multi_hop_stop_words`` / ``ifr_multi_hop_both`` and keep-indices
      were produced,
    * the user-prompt-indices variant when the attributor exposed such indices,
    * otherwise ``llm_evaluator.faithfulness_test``.

    Returns:
        A tuple ``(scores, hop_payload)`` where ``scores`` stacks the per-matrix
        results with ``np.array`` and ``hop_payload`` is passed through
        unchanged from ``run_attribution``.
    """
    prompt_text = example.prompt
    method = str(testing_dict.get("attr_func") or "")
    matrices, hop_payload, user_prompt_indices, keep_prompt_token_indices = run_attribution(
        testing_dict, example, target
    )

    # The first matrix fixes the prompt length: columns minus rows
    # (original note: cols=(P+G), rows=G).
    reference = matrices[0]
    prompt_len = int(reference.shape[1] - reference.shape[0])

    # Loop-invariant: whether the skip-token faithfulness test applies.
    use_skip_token_test = (
        keep_prompt_token_indices is not None
        and method in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both")
    )

    def _score(matrix):
        # Only use prompt-side attribution, matching evaluations/faithfulness.py
        prompt_side = matrix[:, :prompt_len]
        if use_skip_token_test:
            # Imported lazily so other methods never need this module.
            import ft_ifr_improve

            return ft_ifr_improve.faithfulness_test_skip_tokens(
                llm_evaluator,
                prompt_side,
                prompt_text,
                target,
                keep_prompt_token_indices=keep_prompt_token_indices,
                user_prompt_indices=user_prompt_indices,
            )
        if user_prompt_indices is not None:
            return _faithfulness_test_with_user_prompt_indices(
                llm_evaluator,
                prompt_side,
                prompt_text,
                target,
                user_prompt_indices=user_prompt_indices,
            )
        return llm_evaluator.faithfulness_test(prompt_side, prompt_text, target)

    return np.array([_score(matrix) for matrix in matrices]), hop_payload
def evaluate_dataset(args, dataset_name: str, examples: List[ds_utils.CachedExample], testing_dict):
    """Evaluate one dataset in ``faithfulness_gen`` mode only.

    Delegates to ``evaluate_dataset_multi`` and unpacks its ``"faithfulness"``
    entry.

    Returns:
        ``(mean, std, avg_time)`` taken from that entry, or ``None`` when the
        entry is missing or empty.
    """
    summary = evaluate_dataset_multi(
        args, dataset_name, examples, testing_dict, modes=["faithfulness_gen"]
    ).get("faithfulness")
    if summary:
        return tuple(summary[key] for key in ("mean", "std", "avg_time"))
    return None
def evaluate_dataset_recovery_ruler(args, dataset_name: str, examples: List[ds_utils.CachedExample], testing_dict):
    """Evaluate one dataset in ``recovery_ruler`` mode only.

    Delegates to ``evaluate_dataset_multi`` and unpacks its ``"recovery"``
    entry.

    Returns:
        ``(mean, std, avg_time, used, skipped)`` taken from that entry, or
        ``None`` when the entry is missing or empty.
    """
    summary = evaluate_dataset_multi(
        args, dataset_name, examples, testing_dict, modes=["recovery_ruler"]
    ).get("recovery")
    if summary:
        return tuple(summary[key] for key in ("mean", "std", "avg_time", "used", "skipped"))
    return None
def _check_recovery_ruler_supported(ds_name: str) -> None:
    """Exit if ``ds_name`` cannot be used with the ``recovery_ruler`` mode."""
    if ds_name == "morehopqa":
        raise SystemExit("recovery_ruler only supports RULER datasets (with needle_spans), not morehopqa.")
    if ds_name.startswith("math"):
        raise SystemExit("recovery_ruler only supports RULER datasets (with needle_spans), not math.")


def _resolve_exp2_cache_path(data_root: str, ds_name: str) -> Path:
    """Locate the cached JSONL for ``ds_name`` or exit with a remediation hint.

    The prepared cache ``<data_root>/<ds_name>.jsonl`` is preferred; otherwise
    ``ds_name`` itself is accepted when it is an existing path.
    """
    cached_path = Path(data_root) / f"{ds_name}.jsonl"
    if cached_path.exists():
        return cached_path
    direct_path = Path(ds_name)
    if direct_path.exists():
        return direct_path
    hint = "please run exp/exp2/sample_and_filter.py first (or pass an explicit cached JSONL path)."
    if ds_name.startswith("math"):
        hint = "please run exp/exp2/map_math_mine_to_exp2_cache.py first (or pass an explicit cached JSONL path)."
    raise SystemExit(f"Missing exp2 cache for '{ds_name}'. Expected {cached_path}; {hint}")


def _write_faithfulness_summary_csv(csv_path: Path, mean, std, avg_time: float) -> None:
    """Write the faithfulness summary (three mean rows, three spread rows, timing)."""
    labels = ("Seq", "Row", "Recursive")
    lines = ["Method,RISE,MAS,RISE+AP"]
    # NOTE(review): the second group of rows is labelled "Var" but is filled
    # from the "std" entry; the label is kept so existing CSV consumers keep
    # working -- confirm which statistic is intended.
    for suffix, stats in (("Mean", mean), ("Var", std)):
        for idx, label in enumerate(labels):
            values = [str(x) for x in stats[idx].tolist()]
            lines.append(",".join([f"{label} Attr Scores {suffix}"] + values))
    lines.append(f"Avg Sample Time (s),{avg_time}")
    # Rows are built before the file is opened so a formatting error cannot
    # leave a truncated CSV behind.
    with open(csv_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


def _write_recovery_summary_csv(
    csv_path: Path, mean, std, used: int, skipped: int, avg_time: float
) -> None:
    """Write the recovery summary (mean/std rows, example counts, timing)."""
    labels = ("Seq", "Row", "Recursive")
    lines = ["Method,Recovery@10%"]
    lines.extend(f"{label} Attr Recovery Mean,{mean[idx]}" for idx, label in enumerate(labels))
    lines.extend(f"{label} Attr Recovery Std,{std[idx]}" for idx, label in enumerate(labels))
    lines.append(f"Examples Used,{used}")
    lines.append(f"Examples Skipped,{skipped}")
    lines.append(f"Avg Sample Time (s),{avg_time}")
    with open(csv_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


def main():
    """Command-line entry point for the Experiment 2 runner.

    Parses the CLI, validates every requested dataset (mode compatibility and
    cache presence) *before* the model is loaded, then evaluates each
    (dataset, attribution method) pair via ``evaluate_dataset_multi`` and
    writes one summary CSV per requested mode under
    ``<output_root>/<faithfulness|recovery>/<dataset>/<model_tag>/``.

    Raises:
        SystemExit: if neither ``--model`` nor ``--model_path`` is given, if
            ``recovery_ruler`` is requested for an unsupported dataset, or if
            a dataset cache cannot be found.
    """
    # The text is passed as ``description``; given positionally it would be
    # taken as ``prog`` and show up as the program name in the usage line.
    parser = argparse.ArgumentParser(description="Experiment 2 runner (math skipped, AT2 skipped).")
    parser.add_argument("--datasets", type=str, required=True, help="Comma-separated names or paths.")
    parser.add_argument("--attr_funcs", type=str, required=True, help="Comma-separated attr funcs (no AT2).")
    parser.add_argument("--model", type=str, default=None, help="HF repo id (required unless --model_path set).")
    parser.add_argument("--model_path", type=str, default=None, help="Local path; overrides --model for loading.")
    parser.add_argument("--cuda", type=str, default=None)
    parser.add_argument("--cuda_num", type=int, default=0)
    parser.add_argument("--num_examples", type=int, default=100)
    parser.add_argument(
        "--mode",
        type=str,
        nargs="+",
        default=["faithfulness_gen"],
        help=(
            "One or more of: faithfulness_gen, recovery_ruler. "
            "Accepts comma-separated values, e.g. '--mode faithfulness_gen,recovery_ruler' "
            "or '--mode faithfulness_gen, recovery_ruler'."
        ),
    )
    parser.add_argument("--sample", type=int, default=None, help="Optional subsample before num_examples.")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--chunk_tokens", type=int, default=128)
    parser.add_argument("--sink_chunk_tokens", type=int, default=32)
    parser.add_argument("--n_hops", type=int, default=3)
    parser.add_argument(
        "--attnlrp_neg_handling",
        type=str,
        choices=["drop", "abs"],
        default="drop",
        help="FT-AttnLRP: how to handle negative values after each hop (drop=clamp>=0, abs=absolute value).",
    )
    parser.add_argument(
        "--attnlrp_norm_mode",
        type=str,
        choices=["norm", "no_norm"],
        default="norm",
        help="FT-AttnLRP: norm enables per-hop global+thinking normalization + ratios; no_norm disables all three.",
    )
    parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Filtered dataset cache directory.")
    parser.add_argument("--output_root", type=str, default="exp/exp2/output", help="Directory to store evaluation outputs.")
    parser.add_argument(
        "--save_hop_traces",
        action="store_true",
        help=(
            "Save per-sample trace artifacts (attribution vectors + per-sample metrics) under output_root/traces/. "
            "For multi-hop methods, also saves per-hop token vectors (vh)."
        ),
    )
    args = parser.parse_args()
    modes = _parse_modes(args.mode)

    # --model_path wins for loading; --model (when given) still names the outputs.
    if args.model_path:
        model_name = args.model_path
    elif args.model:
        model_name = args.model
    else:
        raise SystemExit("Please set --model or --model_path.")
    model_tag = args.model if args.model else Path(args.model_path).name

    datasets = [d.strip() for d in args.datasets.split(",") if d.strip()]
    attr_funcs = [a.strip() for a in args.attr_funcs.split(",") if a.strip()]

    # Fail fast: check every dataset before the (expensive) model load so a
    # bad argument does not surface only after earlier datasets have run.
    resolved_datasets = []
    for ds_name in datasets:
        if "recovery_ruler" in modes:
            _check_recovery_ruler_supported(ds_name)
        resolved_datasets.append((ds_name, _resolve_exp2_cache_path(args.data_root, ds_name)))

    device = resolve_device(args)
    model, tokenizer = load_model(model_name, device)

    # Per-model input-length limits, looked up by the --model value; any other
    # value (including runs that only pass --model_path) falls back to 2000.
    max_input_len = {
        "llama-1B": 5500,
        "llama-3B": 4800,
        "llama-8B": 3500,
        "qwen-1.7B": 5500,
        "qwen-4B": 3500,
        "qwen-8B": 5000,
        "qwen-32B": 1500,
        "gemma-12B": 1500,
        "gemma-27B": 2000,
    }.get(args.model, 2000)

    for ds_name, ds_path in resolved_datasets:
        examples = ds_utils.load_cached(ds_path, sample=args.sample, seed=args.seed)

        for attr_func in attr_funcs:
            if attr_func.lower() == "at2":
                print("Skipping AT2 as requested.")
                continue

            testing_dict: Dict[str, Any] = {
                "model": model,
                "model_tag": model_tag,
                "tokenizer": tokenizer,
                "attr_func": attr_func,
                "max_input_len": max_input_len,
                "chunk_tokens": args.chunk_tokens,
                "sink_chunk_tokens": args.sink_chunk_tokens,
                "n_hops": args.n_hops,
                "attnlrp_neg_handling": args.attnlrp_neg_handling,
                "attnlrp_norm_mode": args.attnlrp_norm_mode,
                "device": device,
                "batch_size": 1,
                "save_hop_traces": bool(args.save_hop_traces),
            }
            result = evaluate_dataset_multi(args, ds_name, examples, testing_dict, modes=modes)
            filename = f"{attr_func}_{args.num_examples}_examples.csv"

            if "faithfulness_gen" in modes:
                faith = result.get("faithfulness")
                if not faith:
                    print(f"No faithfulness results for {ds_name} with {attr_func}.")
                else:
                    avg_time = float(faith["avg_time"])
                    out_dir = Path(args.output_root) / "faithfulness" / ds_name / model_tag
                    out_dir.mkdir(parents=True, exist_ok=True)
                    _write_faithfulness_summary_csv(
                        out_dir / filename, faith["mean"], faith["std"], avg_time
                    )
                    print(f"[{ds_name}] {attr_func} -> {out_dir/filename} (avg sample time: {avg_time:.2f}s)")

            if "recovery_ruler" in modes:
                rec = result.get("recovery")
                if not rec:
                    print(f"No recovery results for {ds_name} with {attr_func}.")
                else:
                    avg_time = float(rec["avg_time"])
                    used = int(rec["used"])
                    skipped = int(rec["skipped"])
                    out_dir = Path(args.output_root) / "recovery" / ds_name / model_tag
                    out_dir.mkdir(parents=True, exist_ok=True)
                    _write_recovery_summary_csv(
                        out_dir / filename, rec["mean"], rec["std"], used, skipped, avg_time
                    )
                    print(
                        f"[{ds_name}] {attr_func} -> {out_dir/filename} "
                        f"(used={used} skipped={skipped} avg sample time: {avg_time:.2f}s)"
                    )
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()