Spaces:
Paused
Paused
File size: 882 Bytes
8125804 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | import torch
import torch.nn.functional as F
import math
def bits_per_byte(ce_loss: float) -> float:
    """Rescale a cross-entropy loss from nats to bits.

    The value is divided by ln(2). The result is "bits per byte" only
    when ``ce_loss`` is itself averaged per byte; this function does not
    do any per-token to per-byte normalization.
    """
    nats_per_bit = math.log(2)
    return ce_loss / nats_per_bit
def branch_diversity(branch_logits: list[torch.Tensor]) -> float:
    """Return the mean pairwise cosine distance between branch distributions.

    Each tensor in ``branch_logits`` is turned into probabilities with a
    softmax over its last dimension and flattened to 2-D (rows x last-dim).
    For every unordered pair of branches, the row-wise cosine similarity is
    averaged over rows and ``1 - similarity`` is taken as the pair distance.
    The result is the average of those distances over all pairs.

    Args:
        branch_logits: One logits tensor per branch. The flattened tensors
            of any two branches must be compatible for
            ``F.cosine_similarity`` along the last dimension.

    Returns:
        The average pairwise cosine distance as a Python float, or ``0.0``
        when fewer than two branches are given (no pairs to compare).
    """
    n_branches = len(branch_logits)
    if n_branches < 2:
        return 0.0

    # Softmax/reshape depend only on the branch, not on the pair, so compute
    # them once per branch instead of once per pair.
    probs = [
        F.softmax(logits, dim=-1).reshape(-1, logits.size(-1))
        for logits in branch_logits
    ]

    total = 0.0
    for i, pi in enumerate(probs):
        # Only pair with later branches so each unordered pair is seen once.
        for pj in probs[i + 1:]:
            cos_sim = F.cosine_similarity(pi, pj, dim=-1).mean().item()
            total += 1.0 - cos_sim

    # Number of unordered pairs; always >= 1 here because n_branches >= 2.
    n_pairs = n_branches * (n_branches - 1) // 2
    return total / n_pairs
|