selection-entropy-llm / selection_entropy.py
mioulin's picture
Rename selection-entropy.py to selection_entropy.py
49d901d verified
"""
Selection Entropy for LLMs
--------------------------
Original context: neural state-space divergence in computational neuroscience.
Adapted here: token-level uncertainty quantification in LLM outputs.
Key idea: Selection Entropy measures not just how spread a distribution is
(Shannon entropy), but HOW that probability mass is distributed relative
to a reference — with a non-linear sensitivity to "contested" choices
where a few alternatives compete closely.
SE(p) = -Σ_i p_i * log(p_i / (p_i + α * Σ_{j≠i} p_j * w_ij))
where w_ij = exp(-|rank_i - rank_j|) decays with rank distance, so closely
ranked alternatives count as stronger competition than far-tail noise, and
α scales how strongly that competition contributes.
"""
import numpy as np
from dataclasses import dataclass
from typing import List
@dataclass
class TokenEntropy:
    """Uncertainty metrics recorded for a single generated token."""

    # Text of the generated token.
    token: str
    # Identifier of the generated token; used to index the step's distribution.
    token_id: int
    # Probability assigned to the generated token at its step.
    probability: float
    # Shannon entropy of the step's distribution.
    shannon_entropy: float
    # Selection Entropy of the step's distribution.
    selection_entropy: float
    # Highest-probability candidates, each shaped like {"token": ..., "prob": ...}.
    top_k_alternatives: List[dict]
def softmax(logits: np.ndarray) -> np.ndarray:
    """Turn raw logits into a probability distribution.

    The largest logit is subtracted before exponentiating so that ``np.exp``
    cannot overflow; the shift cancels out in the final ratio. The caller's
    array is left untouched.
    """
    shifted = logits - np.max(logits)
    weights = np.exp(shifted)
    return weights / np.sum(weights)
def shannon_entropy(probs: np.ndarray) -> float:
    """Return the Shannon entropy H(p) = -Σ p_i log(p_i), in nats.

    Parameters
    ----------
    probs : array-like of probabilities. Non-positive entries are ignored,
        following the convention 0 * log(0) = 0.

    Returns
    -------
    float : entropy in nats; exactly 0.0 for a one-hot or empty input.
    """
    probs = np.asarray(probs, dtype=float)
    positive = probs[probs > 0]
    # Zeros are already filtered out, so no epsilon is needed inside the log.
    # Adding one would bias the result and make a one-hot distribution come
    # out slightly negative.
    entropy = -np.sum(positive * np.log(positive))
    # Adding 0.0 turns a possible -0.0 into 0.0.
    return float(entropy) + 0.0
def _competitor_mass(probs: np.ndarray) -> np.ndarray:
    """Return Σ_{j≠i} p_j * exp(-|i - j|) for every position i of *probs*."""
    n = probs.size
    decay = np.exp(-1.0)
    from_left = np.zeros(n)
    from_right = np.zeros(n)
    # Moving one position further away multiplies every contribution by e^-1,
    # so each side can be accumulated in a single sweep instead of an
    # all-pairs sum.
    for i in range(1, n):
        from_left[i] = decay * (from_left[i - 1] + probs[i - 1])
    for i in range(n - 2, -1, -1):
        from_right[i] = decay * (from_right[i + 1] + probs[i + 1])
    return from_left + from_right


def selection_entropy(probs: np.ndarray, alpha: float = 0.5, top_k: int = 50) -> float:
    """
    Selection Entropy — adapted from neural divergence metric.

    SE(p) = -Σ_i p_i * log(p_i / (p_i + alpha * Σ_{j≠i} p_j * w_ij)),
    with w_ij = exp(-|rank_i - rank_j|) and ranks taken in descending order
    of probability.

    Core intuition: a model is more "uncertain" when its top choice competes
    closely with ranked alternatives, not just when the distribution is
    diffuse. SE is sensitive to the *structure* of competition at the top.

    Parameters
    ----------
    probs : 1-D array-like of token probabilities (full vocab or top-k), in
        any order; it is ranked internally.
    alpha : competition sensitivity. With alpha = 0 every ratio is 1 and the
        result is exactly 0; larger values weight competitors more heavily.
    top_k : number of top tokens to consider for competition. When the input
        is longer than this, only the top_k most probable entries are kept
        and renormalised to sum to 1. Shorter inputs are used as given.

    Returns
    -------
    float : Selection Entropy value in nats. 0.0 for an empty input or when
        the retained probability mass is zero.
    """
    probs = np.asarray(probs, dtype=float)
    if probs.size == 0:
        return 0.0

    # Always rank by probability: the competition weights are defined on
    # ranks, so positions in the array must follow descending probability
    # even when no truncation is needed.
    order = np.argsort(probs)[::-1]
    truncated = probs.size > top_k
    if truncated:
        order = order[:top_k]
    probs = probs[order]

    if truncated:
        total = probs.sum()
        if total <= 0.0:
            # Nothing left to renormalise; avoid dividing by zero.
            return 0.0
        probs = probs / total

    denominator = probs + alpha * _competitor_mass(probs)
    # Skip negligible tokens and degenerate denominators.
    contributing = (probs >= 1e-12) & (denominator > 1e-12)
    # The ratio is strictly positive for every contributing token, so no
    # epsilon is needed inside the log. Adding one would push the result
    # below zero whenever the ratio is exactly 1.
    ratio = probs[contributing] / denominator[contributing]
    se = -np.sum(probs[contributing] * np.log(ratio))
    # Adding 0.0 turns a possible -0.0 into 0.0.
    return float(se) + 0.0
def compute_token_entropies(
    logits_sequence: np.ndarray,
    tokens: List[str],
    token_ids: List[int],
    vocab_tokens: List[str],
    alpha: float = 0.5,
    top_k_display: int = 5,
) -> List[TokenEntropy]:
    """
    Build one TokenEntropy record per generation step.

    The three per-step inputs are walked in lockstep, so iteration stops at
    the shortest of them.

    Parameters
    ----------
    logits_sequence : shape (seq_len, vocab_size) — raw logits at each step
    tokens : generated token strings, one per step
    token_ids : generated token IDs, one per step
    vocab_tokens : vocabulary token strings, indexed by token ID
    alpha : competition sensitivity forwarded to selection_entropy
    top_k_display : how many of the most probable tokens to list per step

    Returns
    -------
    list of TokenEntropy, in step order
    """
    records = []
    for step_logits, generated, generated_id in zip(logits_sequence, tokens, token_ids):
        distribution = softmax(step_logits)

        # Most probable vocabulary entries first, cut to the display size.
        ranked = np.argsort(distribution)[::-1]
        alternatives = []
        for vocab_idx in ranked[:top_k_display]:
            alternatives.append(
                {"token": vocab_tokens[vocab_idx], "prob": float(distribution[vocab_idx])}
            )

        record = TokenEntropy(
            token=generated,
            token_id=generated_id,
            probability=float(distribution[generated_id]),
            shannon_entropy=shannon_entropy(distribution),
            selection_entropy=selection_entropy(distribution, alpha=alpha),
            top_k_alternatives=alternatives,
        )
        records.append(record)
    return records
def normalise_entropies(token_entropies: List[TokenEntropy], metric: str = "selection") -> List[float]:
    """Scale the chosen entropy metric by its maximum, for visualisation.

    ``metric == "selection"`` reads ``selection_entropy``; any other value
    reads ``shannon_entropy``. An empty input gives an empty list, and a
    maximum below 1e-12 gives all zeros rather than dividing by it.
    """
    attribute = "selection_entropy" if metric == "selection" else "shannon_entropy"
    values = [getattr(entry, attribute) for entry in token_entropies]
    if not values:
        return []

    peak = max(values)
    if peak < 1e-12:
        return [0.0 for _ in values]
    return [value / peak for value in values]