File size: 882 Bytes
8125804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import torch.nn.functional as F
import math


def bits_per_byte(ce_loss: float) -> float:
    """Convert cross-entropy loss (nats) to bits per byte."""
    return ce_loss / math.log(2)


def branch_diversity(branch_logits: list[torch.Tensor]) -> float:
    """Average cosine distance between branch probability distributions."""
    if len(branch_logits) < 2:
        return 0.0
    total = 0.0
    n_pairs = 0
    for i in range(len(branch_logits)):
        for j in range(i + 1, len(branch_logits)):
            pi = F.softmax(branch_logits[i], dim=-1).reshape(-1, branch_logits[i].size(-1))
            pj = F.softmax(branch_logits[j], dim=-1).reshape(-1, branch_logits[j].size(-1))
            cos_sim = F.cosine_similarity(pi, pj, dim=-1).mean().item()
            total += 1.0 - cos_sim
            n_pairs += 1
    return total / n_pairs if n_pairs > 0 else 0.0