"""Helpers for IFR case studies (hop-wise aggregation + sanitization). All utilities stay local to exp/case_study to avoid touching core eval code. """ from __future__ import annotations from typing import Any, Dict, Iterable, List, Optional, Sequence import torch def vector_stats(vec: torch.Tensor) -> Dict[str, float]: if vec.numel() == 0: return {"min": 0.0, "max": 0.0, "abs_max": 0.0, "mean": 0.0, "sum": 0.0} v = vec.detach().to(dtype=torch.float32) return { "min": float(v.min().item()), "max": float(v.max().item()), "abs_max": float(v.abs().max().item()), "mean": float(v.mean().item()), "sum": float(v.sum().item()), } def tensor_to_list(x: Any) -> Any: if torch.is_tensor(x): return x.detach().cpu().tolist() if isinstance(x, list): return [tensor_to_list(v) for v in x] if isinstance(x, dict): return {k: tensor_to_list(v) for k, v in x.items()} return x def sanitize_ifr_meta(meta: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: """Drop bulky raw objects and convert tensors to Python lists for JSON.""" if meta is None: return None cleaned: Dict[str, Any] = {} for key, value in meta.items(): if key == "raw": continue cleaned[key] = tensor_to_list(value) return cleaned def package_token_hops( hop_vectors: Iterable[Sequence[float]], ) -> List[Dict[str, Any]]: """Package per-hop token vectors without sentence aggregation. hop_vectors are assumed to already match the experiment's configured postprocessing (e.g., FT-AttnLRP neg_handling/norm_mode). """ packaged: List[Dict[str, Any]] = [] for hop_idx, vec in enumerate(hop_vectors): vec_tensor = torch.nan_to_num(torch.as_tensor(vec, dtype=torch.float32), nan=0.0) token_scores = vec_tensor.tolist() token_max = float(vec_tensor.abs().max().item()) if vec_tensor.numel() > 0 else 0.0 total = float(vec_tensor.sum().item()) packaged.append( { "hop": hop_idx, "token_scores": token_scores, "token_score_max": token_max, "token_stats": vector_stats(vec_tensor), "total_mass": total, } ) return packaged