File size: 5,133 Bytes
f37be5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from __future__ import annotations

from dataclasses import asdict, dataclass
from math import log2
from statistics import mean, median

import torch


@dataclass(frozen=True)
class TokenStructureStats:
    """Immutable summary statistics for a flat stream of token ids.

    Instances are produced by ``compute_token_structure_stats``; the field
    comments below describe how that function fills each value.
    """

    # Number of tokens in the flattened input.
    token_count: int
    # Vocabulary size supplied by the caller, or ``max id + 1`` when omitted.
    vocab_size: int
    # Number of distinct token ids that occur at least once.
    unique_tokens: int
    # ``unique_tokens / token_count``.
    unique_token_ratio: float
    # Base-2 entropy of the empirical token frequency distribution.
    unigram_entropy_bits: float
    # Share of all tokens taken by the eight most frequent ids.
    top_8_mass: float
    # Fraction of positions (the first position is not scored) whose token
    # also appears among the preceding 4 / 8 / 16 positions.
    repeat_within_4: float
    repeat_within_8: float
    repeat_within_16: float
    # Mean / median gap between a token and its previous occurrence;
    # both are 0.0 when no token ever repeats.
    mean_repeat_distance: float
    median_repeat_distance: float
    # Entropy (bits) of the next-token distribution given the current token,
    # averaged over current tokens weighted by how often each occurs.
    conditional_entropy_bits: float
    # Largest next-token probability per current token, averaged with the
    # same occurrence weighting.
    mean_next_token_peak_prob: float


@dataclass(frozen=True)
class SequenceStructureStats:
    """Immutable duplication statistics for a batch of whole sequences.

    Instances are produced by ``compute_sequence_structure_stats`` from a
    2-D ``[N, T]`` tensor; the field comments describe how it fills them.
    """

    # Number of rows (N) in the input.
    num_sequences: int
    # Length of each row (T), taken from the tensor's second dimension.
    seq_len: int
    # Number of distinct rows.
    unique_sequences: int
    # ``unique_sequences / num_sequences``.
    sequence_uniqueness_ratio: float
    # Share of rows equal to the single most frequent row.
    top_sequence_mass: float


def _entropy_from_counts(counts: torch.Tensor) -> float:
    probs = counts.float() / counts.sum().clamp_min(1)
    probs = probs[probs > 0]
    if probs.numel() == 0:
        return 0.0
    return float((-(probs * probs.log2())).sum().item())


def _repeat_stats(token_ids: torch.Tensor, window: int) -> float:
    ids = token_ids.tolist()
    hits = 0
    total = 0
    for idx, token in enumerate(ids):
        if idx == 0:
            continue
        total += 1
        start = max(0, idx - window)
        if token in ids[start:idx]:
            hits += 1
    return hits / max(total, 1)


def _repeat_distances(token_ids: torch.Tensor) -> tuple[float, float]:
    last_seen: dict[int, int] = {}
    distances: list[int] = []
    for idx, token in enumerate(token_ids.tolist()):
        if token in last_seen:
            distances.append(idx - last_seen[token])
        last_seen[token] = idx
    if not distances:
        return 0.0, 0.0
    return float(mean(distances)), float(median(distances))


def _conditional_stats(token_ids: torch.Tensor, vocab_size: int) -> tuple[float, float]:
    if token_ids.numel() < 2:
        return 0.0, 0.0
    pair_counts = torch.zeros((vocab_size, vocab_size), dtype=torch.long)
    src = token_ids[:-1].long()
    dst = token_ids[1:].long()
    for a, b in zip(src.tolist(), dst.tolist()):
        pair_counts[a, b] += 1

    totals = pair_counts.sum(dim=1)
    entropies: list[float] = []
    peaks: list[float] = []
    weights: list[float] = []
    for row, total in zip(pair_counts, totals):
        total_int = int(total.item())
        if total_int <= 0:
            continue
        probs = row.float() / total_int
        nz = probs[probs > 0]
        entropies.append(float((-(nz * nz.log2())).sum().item()))
        peaks.append(float(probs.max().item()))
        weights.append(total_int)
    if not weights:
        return 0.0, 0.0
    total_weight = sum(weights)
    weighted_entropy = sum(e * w for e, w in zip(entropies, weights)) / total_weight
    weighted_peak = sum(p * w for p, w in zip(peaks, weights)) / total_weight
    return weighted_entropy, weighted_peak


def compute_token_structure_stats(token_ids: torch.Tensor, vocab_size: int | None = None) -> TokenStructureStats:
    """Compute frequency, repetition and next-token statistics for tokens.

    The input is detached, flattened, cast to ``long`` and moved to the CPU
    before any statistic is computed.

    Args:
        token_ids: Tensor of token ids of any shape.
        vocab_size: Vocabulary size; when ``None`` it is taken to be the
            largest id present plus one.

    Returns:
        A populated ``TokenStructureStats``.

    Raises:
        ValueError: If ``token_ids`` contains no elements.
    """
    ids = token_ids.detach().flatten().to(torch.long).cpu()
    n_tokens = int(ids.numel())
    if n_tokens == 0:
        raise ValueError("token_ids must be non-empty")
    if vocab_size is None:
        vocab_size = int(ids.max().item()) + 1

    # Unigram view: per-id histogram and everything derived from it.
    histogram = torch.bincount(ids, minlength=vocab_size)
    distinct = int((histogram > 0).sum().item())
    ranked = histogram.sort(descending=True).values
    head_mass = float(ranked[:8].sum().item() / n_tokens)

    # Sequential view: repetition gaps and next-token predictability.
    mean_gap, median_gap = _repeat_distances(ids)
    cond_entropy, peak_prob = _conditional_stats(ids, vocab_size)
    repeat_rate = {span: _repeat_stats(ids, span) for span in (4, 8, 16)}

    return TokenStructureStats(
        token_count=n_tokens,
        vocab_size=int(vocab_size),
        unique_tokens=distinct,
        # n_tokens is at least 1 here, so the division is always safe.
        unique_token_ratio=distinct / n_tokens,
        unigram_entropy_bits=_entropy_from_counts(histogram),
        top_8_mass=head_mass,
        repeat_within_4=repeat_rate[4],
        repeat_within_8=repeat_rate[8],
        repeat_within_16=repeat_rate[16],
        mean_repeat_distance=mean_gap,
        median_repeat_distance=median_gap,
        conditional_entropy_bits=cond_entropy,
        mean_next_token_peak_prob=peak_prob,
    )


def compute_sequence_structure_stats(sequences: torch.Tensor) -> SequenceStructureStats:
    """Measure how many rows of a batch of sequences are duplicates.

    Each row is converted to a tuple of Python ints and rows are compared
    for exact equality.

    Args:
        sequences: 2-D tensor of shape ``[N, T]`` with ``N >= 1``.

    Returns:
        A populated ``SequenceStructureStats``.

    Raises:
        ValueError: If ``sequences`` is not 2-D, or if it has no rows.
    """
    if sequences.ndim != 2:
        raise ValueError("sequences must have shape [N, T]")
    # Fail with a clear message instead of letting max() choke on an empty
    # count table further down.
    if sequences.size(0) == 0:
        raise ValueError("sequences must be non-empty")

    rows = [tuple(int(v) for v in row) for row in sequences.cpu().tolist()]
    counts: dict[tuple[int, ...], int] = {}
    for row in rows:
        counts[row] = counts.get(row, 0) + 1

    num_rows = len(rows)
    return SequenceStructureStats(
        num_sequences=num_rows,
        seq_len=int(sequences.size(1)),
        unique_sequences=len(counts),
        sequence_uniqueness_ratio=len(counts) / num_rows,
        top_sequence_mass=max(counts.values()) / num_rows,
    )


def stats_to_dict(stats: TokenStructureStats | SequenceStructureStats) -> dict[str, float | int]:
    """Convert a stats dataclass instance into a plain field-name dict.

    Args:
        stats: The stats dataclass instance to convert.

    Returns:
        A new dictionary mapping each field name to its value.
    """
    as_mapping = asdict(stats)
    return as_mapping