File size: 2,336 Bytes
55b60a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Helpers for IFR case studies (hop-wise aggregation + sanitization).

All utilities stay local to exp/case_study to avoid touching core eval code.
"""

from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional, Sequence

import torch


def vector_stats(vec: torch.Tensor) -> Dict[str, float]:
    """Summarize a tensor with a few scalar statistics.

    Returns a dict with keys ``min``, ``max``, ``abs_max``, ``mean`` and
    ``sum`` (in that order), each a Python float. The tensor is detached and
    cast to float32 before reducing, so non-float inputs (e.g. integer
    tensors) can still produce a mean. An empty tensor yields 0.0 for every
    statistic instead of raising.
    """
    keys = ("min", "max", "abs_max", "mean", "sum")
    if not vec.numel():
        # Reductions such as min/max raise on empty tensors; report zeros.
        return dict.fromkeys(keys, 0.0)

    values = vec.detach().to(dtype=torch.float32)
    reducers = (
        values.min,
        values.max,
        values.abs().max,
        values.mean,
        values.sum,
    )
    return {name: float(reduce().item()) for name, reduce in zip(keys, reducers)}


def tensor_to_list(x: Any) -> Any:
    """Recursively convert tensors inside ``x`` into plain Python values.

    Tensors are detached, moved to CPU and converted with ``.tolist()``
    (nested lists, or a Python scalar for a 0-d tensor). Lists, tuples and
    dicts are walked recursively: lists stay lists, tuples stay tuples, and
    dict keys are kept as-is while values are converted. Any other value is
    returned unchanged.
    """
    if torch.is_tensor(x):
        return x.detach().cpu().tolist()
    if isinstance(x, list):
        return [tensor_to_list(v) for v in x]
    if isinstance(x, tuple):
        # Tuples can carry tensors too; without this branch they would pass
        # through unconverted and break JSON serialization downstream.
        return tuple(tensor_to_list(v) for v in x)
    if isinstance(x, dict):
        return {k: tensor_to_list(v) for k, v in x.items()}
    return x


def sanitize_ifr_meta(meta: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Build a JSON-friendly copy of an IFR metadata dict.

    The ``"raw"`` entry is omitted entirely. Every other value is passed
    through ``tensor_to_list`` so tensors become Python lists. ``None`` is
    passed straight back. A new dict is returned; ``meta`` itself is not
    mutated.
    """
    if meta is None:
        return None

    return {
        key: tensor_to_list(val)
        for key, val in meta.items()
        if key != "raw"
    }


def package_token_hops(
    hop_vectors: Iterable[Sequence[float]],
) -> List[Dict[str, Any]]:
    """Package per-hop token vectors without sentence aggregation.

    hop_vectors are assumed to already match the experiment's configured
    postprocessing (e.g., FT-AttnLRP neg_handling/norm_mode).

    Each hop vector is converted to a float32 tensor with NaNs replaced by
    0.0. Note that ``torch.nan_to_num`` also replaces +/-inf with the
    largest/smallest finite float32 values. One record is returned per hop,
    in input order, with keys:

    - ``hop``: zero-based hop index.
    - ``token_scores``: the cleaned scores as a Python list.
    - ``token_score_max``: largest absolute score (0.0 for an empty hop).
    - ``token_stats``: the ``vector_stats`` summary of the cleaned scores.
    - ``total_mass``: sum of the cleaned scores (0.0 for an empty hop).
    """
    packaged: List[Dict[str, Any]] = []
    for hop_idx, vec in enumerate(hop_vectors):
        # Only Python values leave this function, so there is no point in
        # recording nan_to_num in an autograd graph when vec requires grad.
        scores = torch.as_tensor(vec, dtype=torch.float32).detach()
        scores = torch.nan_to_num(scores, nan=0.0)
        stats = vector_stats(scores)
        packaged.append(
            {
                "hop": hop_idx,
                "token_scores": scores.tolist(),
                # These are exactly stats["abs_max"] / stats["sum"] (including
                # the 0.0 fallback for empty hops); reusing them avoids
                # duplicate reductions and keeps the fields in agreement.
                "token_score_max": stats["abs_max"],
                "token_stats": stats,
                "total_mass": stats["sum"],
            }
        )
    return packaged