yezdata commited on
Commit
7377a60
·
1 Parent(s): df5b9a8

swap log for log2 (shannon binary entropy)

Browse files
Files changed (1) hide show
  1. main.py +1 -1
main.py CHANGED
@@ -18,7 +18,7 @@ model.eval()
18
 
19
  def compute_binary_entropy(p: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
20
  p = torch.clamp(p, min=eps, max=1.0 - eps)
21
- return -(p * torch.log(p) + (1.0 - p) * torch.log(1.0 - p))
22
 
23
 
24
  def compute_uncertainty(probs_samples: torch.Tensor, mean_probs: torch.Tensor) -> dict:
 
18
 
19
  def compute_binary_entropy(p: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
20
  p = torch.clamp(p, min=eps, max=1.0 - eps)
21
+ return -(p * torch.log2(p) + (1.0 - p) * torch.log2(1.0 - p))
22
 
23
 
24
  def compute_uncertainty(probs_samples: torch.Tensor, mean_probs: torch.Tensor) -> dict: