flashtrace / exp /case_study /analysis.py
wenbopan's picture
Sync FlashTrace package from GitHub
55b60a8
"""Helpers for IFR case studies (hop-wise aggregation + sanitization).
All utilities stay local to exp/case_study to avoid touching core eval code.
"""
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional, Sequence
import torch
def vector_stats(vec: torch.Tensor) -> Dict[str, float]:
    """Summarise a tensor as plain-float min / max / abs_max / mean / sum.

    The tensor is detached and cast to float32 before reducing; every
    reduction runs over all elements. An empty tensor short-circuits to
    0.0 for each statistic.
    """
    stat_names = ("min", "max", "abs_max", "mean", "sum")
    if vec.numel() == 0:
        return dict.fromkeys(stat_names, 0.0)

    values = vec.detach().to(dtype=torch.float32)
    # Keep the reductions in the same order as stat_names.
    reductions = (
        values.min(),
        values.max(),
        values.abs().max(),
        values.mean(),
        values.sum(),
    )
    return {
        name: float(reduced.item())
        for name, reduced in zip(stat_names, reductions)
    }
def tensor_to_list(x: Any) -> Any:
    """Recursively replace tensors inside ``x`` with plain Python values.

    Tensors are detached, moved to CPU and converted with ``tolist()``.
    Lists, tuples and dicts are rebuilt with their elements/values
    converted; dict keys are left untouched. Any other object is returned
    unchanged.

    Tuples are rebuilt as tuples (namedtuples keep their type) so that a
    tensor nested inside a tuple is converted too, while a tuple without
    tensors still compares equal to its input.
    """
    if torch.is_tensor(x):
        return x.detach().cpu().tolist()
    if isinstance(x, list):
        return [tensor_to_list(v) for v in x]
    if isinstance(x, tuple):
        converted = [tensor_to_list(v) for v in x]
        # namedtuple constructors take positional fields, not one iterable.
        if hasattr(x, "_fields"):
            return type(x)(*converted)
        return tuple(converted)
    if isinstance(x, dict):
        return {k: tensor_to_list(v) for k, v in x.items()}
    return x
def sanitize_ifr_meta(meta: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Return a JSON-friendly copy of ``meta`` without its ``"raw"`` entry.

    ``None`` is passed through unchanged. Otherwise a new dict is built in
    which the ``"raw"`` key is dropped and every remaining value is run
    through ``tensor_to_list``; the input mapping itself is not modified.
    """
    if meta is None:
        return None
    return {
        name: tensor_to_list(payload)
        for name, payload in meta.items()
        if name != "raw"
    }
def package_token_hops(
    hop_vectors: Iterable[Sequence[float]],
) -> List[Dict[str, Any]]:
    """Bundle each hop's token-score vector into a serialisable record.

    No sentence-level aggregation happens here. The incoming vectors are
    assumed to already reflect the experiment's configured postprocessing
    (e.g., FT-AttnLRP neg_handling/norm_mode).

    Each record holds the hop index (enumeration order), the scores as a
    Python list, the largest absolute score (0.0 for an empty vector), the
    ``vector_stats`` summary, and the sum of all scores.
    """

    def _pack(index: int, raw: Sequence[float]) -> Dict[str, Any]:
        # Cast to float32 first, then replace NaNs with 0.0 so they do not
        # propagate into the reductions below.
        scores = torch.as_tensor(raw, dtype=torch.float32)
        scores = torch.nan_to_num(scores, nan=0.0)

        if scores.numel() > 0:
            peak = float(scores.abs().max().item())
        else:
            # Skip the max reduction entirely for an empty vector.
            peak = 0.0

        return {
            "hop": index,
            "token_scores": scores.tolist(),
            "token_score_max": peak,
            "token_stats": vector_stats(scores),
            "total_mass": float(scores.sum().item()),
        }

    return [_pack(index, raw) for index, raw in enumerate(hop_vectors)]