Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| import math | |
| def bits_per_byte(ce_loss: float) -> float: | |
| """Convert cross-entropy loss (nats) to bits per byte.""" | |
| return ce_loss / math.log(2) | |
| def branch_diversity(branch_logits: list[torch.Tensor]) -> float: | |
| """Average cosine distance between branch probability distributions.""" | |
| if len(branch_logits) < 2: | |
| return 0.0 | |
| total = 0.0 | |
| n_pairs = 0 | |
| for i in range(len(branch_logits)): | |
| for j in range(i + 1, len(branch_logits)): | |
| pi = F.softmax(branch_logits[i], dim=-1).reshape(-1, branch_logits[i].size(-1)) | |
| pj = F.softmax(branch_logits[j], dim=-1).reshape(-1, branch_logits[j].size(-1)) | |
| cos_sim = F.cosine_similarity(pi, pj, dim=-1).mean().item() | |
| total += 1.0 - cos_sim | |
| n_pairs += 1 | |
| return total / n_pairs if n_pairs > 0 else 0.0 | |