abpt / src /utils /metrics.py
Search
feat: add src/ module for script imports
8125804
raw
history blame contribute delete
882 Bytes
import torch
import torch.nn.functional as F
import math
def bits_per_byte(ce_loss: float) -> float:
"""Convert cross-entropy loss (nats) to bits per byte."""
return ce_loss / math.log(2)
def branch_diversity(branch_logits: list[torch.Tensor]) -> float:
"""Average cosine distance between branch probability distributions."""
if len(branch_logits) < 2:
return 0.0
total = 0.0
n_pairs = 0
for i in range(len(branch_logits)):
for j in range(i + 1, len(branch_logits)):
pi = F.softmax(branch_logits[i], dim=-1).reshape(-1, branch_logits[i].size(-1))
pj = F.softmax(branch_logits[j], dim=-1).reshape(-1, branch_logits[j].size(-1))
cos_sim = F.cosine_similarity(pi, pj, dim=-1).mean().item()
total += 1.0 - cos_sim
n_pairs += 1
return total / n_pairs if n_pairs > 0 else 0.0