Image Classification
English
TTA
ReservoirTTA / utils /losses.py
GuillaumeVray
Uploading files
02ba886
import torch
import torch.nn as nn
class Entropy(nn.Module):
def __init__(self):
super(Entropy, self).__init__()
def __call__(self, logits):
return -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
def softmax_mean_entropy(x: torch.Tensor) -> torch.Tensor:
"""Mean entropy of softmax distribution from logits."""
x = x.softmax(1).mean(0)
return -(x * torch.log(x)).sum()
class SymmetricCrossEntropy(nn.Module):
def __init__(self, alpha=0.5):
super(SymmetricCrossEntropy, self).__init__()
self.alpha = alpha
def __call__(self, x, x_ema):
return -(1-self.alpha) * (x_ema.softmax(1) * x.log_softmax(1)).sum(1) - self.alpha * (x.softmax(1) * x_ema.log_softmax(1)).sum(1)
class AugCrossEntropy(nn.Module):
def __init__(self, alpha=0.5):
super(AugCrossEntropy, self).__init__()
self.alpha = alpha
def __call__(self, x, x_aug, x_ema):
return -(1-self.alpha) * (x.softmax(1) * x_ema.log_softmax(1)).sum(1) \
- self.alpha * (x_aug.softmax(1) * x_ema.log_softmax(1)).sum(1)
class SoftLikelihoodRatio(nn.Module):
def __init__(self, clip=0.99, eps=1e-5):
super(SoftLikelihoodRatio, self).__init__()
self.eps = eps
self.clip = clip
def __call__(self, logits):
probs = logits.softmax(1)
probs = torch.clamp(probs, min=0.0, max=self.clip)
return - (probs * torch.log((probs / (torch.ones_like(probs) - probs)) + self.eps)).sum(1)
class GeneralizedCrossEntropy(nn.Module):
""" Paper: https://arxiv.org/abs/1805.07836 """
def __init__(self, q=0.8):
super(GeneralizedCrossEntropy, self).__init__()
self.q = q
def __call__(self, logits, targets=None):
probs = logits.softmax(1)
if targets is None:
targets = probs.argmax(dim=1)
probs_with_correct_idx = probs.index_select(-1, targets).diag()
return (1.0 - probs_with_correct_idx ** self.q) / self.q