Spaces:
Sleeping
Sleeping
File size: 5,133 Bytes
f37be5a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | from __future__ import annotations
from dataclasses import asdict, dataclass
from math import log2
from statistics import mean, median
import torch
@dataclass(frozen=True)
class TokenStructureStats:
    """Immutable summary of the statistical structure of a flattened token stream.

    Instances are built by ``compute_token_structure_stats``. Entropy fields are
    in bits; ratio / mass / repeat fields are fractions of a count.
    """

    # Number of tokens analysed (after flattening the input tensor).
    token_count: int
    # Size of the id space used for counting (explicit, or max id + 1 when inferred).
    vocab_size: int
    # Number of distinct token ids that actually occur.
    unique_tokens: int
    # unique_tokens / token_count.
    unique_token_ratio: float
    # Shannon entropy (bits) of the unigram token-frequency distribution.
    unigram_entropy_bits: float
    # Fraction of all tokens accounted for by the 8 most frequent ids.
    top_8_mass: float
    # Fraction of positions (excluding the first) whose token also occurs
    # among the preceding 4 / 8 / 16 tokens.
    repeat_within_4: float
    repeat_within_8: float
    repeat_within_16: float
    # Mean / median gap between consecutive occurrences of the same token;
    # 0.0 when no token repeats.
    mean_repeat_distance: float
    median_repeat_distance: float
    # Bigram conditional entropy of the next token given the current one (bits),
    # weighted by how often each token is followed by another token.
    conditional_entropy_bits: float
    # Same weighting applied to the highest next-token probability per token.
    mean_next_token_peak_prob: float
@dataclass(frozen=True)
class SequenceStructureStats:
    """Immutable summary of how often whole rows repeat in an ``[N, T]`` batch.

    Instances are built by ``compute_sequence_structure_stats``.
    """

    # Number of sequences (rows, N) in the batch.
    num_sequences: int
    # Length shared by every sequence (columns, T).
    seq_len: int
    # Number of distinct rows.
    unique_sequences: int
    # unique_sequences / num_sequences.
    sequence_uniqueness_ratio: float
    # Fraction of rows identical to the single most frequent row.
    top_sequence_mass: float
def _entropy_from_counts(counts: torch.Tensor) -> float:
probs = counts.float() / counts.sum().clamp_min(1)
probs = probs[probs > 0]
if probs.numel() == 0:
return 0.0
return float((-(probs * probs.log2())).sum().item())
def _repeat_stats(token_ids: torch.Tensor, window: int) -> float:
ids = token_ids.tolist()
hits = 0
total = 0
for idx, token in enumerate(ids):
if idx == 0:
continue
total += 1
start = max(0, idx - window)
if token in ids[start:idx]:
hits += 1
return hits / max(total, 1)
def _repeat_distances(token_ids: torch.Tensor) -> tuple[float, float]:
last_seen: dict[int, int] = {}
distances: list[int] = []
for idx, token in enumerate(token_ids.tolist()):
if token in last_seen:
distances.append(idx - last_seen[token])
last_seen[token] = idx
if not distances:
return 0.0, 0.0
return float(mean(distances)), float(median(distances))
def _conditional_stats(token_ids: torch.Tensor, vocab_size: int) -> tuple[float, float]:
if token_ids.numel() < 2:
return 0.0, 0.0
pair_counts = torch.zeros((vocab_size, vocab_size), dtype=torch.long)
src = token_ids[:-1].long()
dst = token_ids[1:].long()
for a, b in zip(src.tolist(), dst.tolist()):
pair_counts[a, b] += 1
totals = pair_counts.sum(dim=1)
entropies: list[float] = []
peaks: list[float] = []
weights: list[float] = []
for row, total in zip(pair_counts, totals):
total_int = int(total.item())
if total_int <= 0:
continue
probs = row.float() / total_int
nz = probs[probs > 0]
entropies.append(float((-(nz * nz.log2())).sum().item()))
peaks.append(float(probs.max().item()))
weights.append(total_int)
if not weights:
return 0.0, 0.0
total_weight = sum(weights)
weighted_entropy = sum(e * w for e, w in zip(entropies, weights)) / total_weight
weighted_peak = sum(p * w for p, w in zip(peaks, weights)) / total_weight
return weighted_entropy, weighted_peak
def compute_token_structure_stats(token_ids: torch.Tensor, vocab_size: int | None = None) -> TokenStructureStats:
    """Summarise unigram, repetition and bigram structure of a token stream.

    Args:
        token_ids: Tensor of token ids of any shape. It is detached, flattened,
            cast to ``torch.long`` and moved to CPU before analysis.
        vocab_size: Size of the id space. Defaults to ``max(token_ids) + 1``.

    Returns:
        A ``TokenStructureStats`` describing the flattened stream.

    Raises:
        ValueError: If ``token_ids`` is empty, contains a negative id, or
            contains an id ``>= vocab_size`` when ``vocab_size`` is given.
    """
    flat = token_ids.detach().flatten().to(torch.long).cpu()
    token_count = int(flat.numel())
    if token_count == 0:
        raise ValueError("token_ids must be non-empty")

    min_id = int(flat.min().item())
    max_id = int(flat.max().item())
    # torch.bincount rejects negative values; fail with a clear message instead.
    if min_id < 0:
        raise ValueError(f"token_ids must be non-negative, got minimum id {min_id}")
    if vocab_size is None:
        vocab_size = max_id + 1
    elif max_id >= vocab_size:
        # bincount would silently grow past vocab_size, leaving the reported
        # vocab_size inconsistent with the ids actually counted.
        raise ValueError(
            f"vocab_size={vocab_size} is too small for maximum token id {max_id}"
        )

    counts = torch.bincount(flat, minlength=vocab_size)
    unique_tokens = int((counts > 0).sum().item())
    # Only the 8 largest counts are needed; topk avoids sorting the whole vocabulary.
    top_counts = counts.topk(min(8, counts.numel())).values
    top_8_mass = int(top_counts.sum().item()) / token_count

    mean_repeat_distance, median_repeat_distance = _repeat_distances(flat)
    conditional_entropy_bits, mean_next_token_peak_prob = _conditional_stats(flat, vocab_size)

    return TokenStructureStats(
        token_count=token_count,
        vocab_size=int(vocab_size),
        unique_tokens=unique_tokens,
        unique_token_ratio=unique_tokens / token_count,
        unigram_entropy_bits=_entropy_from_counts(counts),
        top_8_mass=top_8_mass,
        repeat_within_4=_repeat_stats(flat, 4),
        repeat_within_8=_repeat_stats(flat, 8),
        repeat_within_16=_repeat_stats(flat, 16),
        mean_repeat_distance=mean_repeat_distance,
        median_repeat_distance=median_repeat_distance,
        conditional_entropy_bits=conditional_entropy_bits,
        mean_next_token_peak_prob=mean_next_token_peak_prob,
    )
def compute_sequence_structure_stats(sequences: torch.Tensor) -> SequenceStructureStats:
    """Summarise how often whole sequences repeat within a batch.

    Args:
        sequences: 2-D tensor of shape ``[N, T]``; each row is one sequence.
            Values are converted with ``int`` before rows are compared.

    Returns:
        A ``SequenceStructureStats``. For an empty batch (``N == 0``) the
        uniqueness ratio and top-sequence mass are both 0.0.

    Raises:
        ValueError: If ``sequences`` is not 2-D.
    """
    from collections import Counter

    if sequences.ndim != 2:
        raise ValueError("sequences must have shape [N, T]")
    # One bulk tolist() instead of converting each row tensor separately.
    rows = [tuple(int(v) for v in row) for row in sequences.detach().cpu().tolist()]
    counts = Counter(rows)
    num_sequences = len(rows)
    denominator = max(num_sequences, 1)
    # max() over an empty Counter would raise; an empty batch has no mass.
    top_count = max(counts.values(), default=0)
    return SequenceStructureStats(
        num_sequences=num_sequences,
        seq_len=int(sequences.size(1)),
        unique_sequences=len(counts),
        sequence_uniqueness_ratio=len(counts) / denominator,
        top_sequence_mass=top_count / denominator,
    )
def stats_to_dict(stats: TokenStructureStats | SequenceStructureStats) -> dict[str, float | int]:
    """Convert a stats dataclass into a plain ``{field_name: value}`` dictionary.

    The input instance is not modified; a new dict is returned.
    """
    flattened = asdict(stats)
    return flattened
|