# DAminoMuta / loss.py
# Uploaded via huggingface_hub by user auralray (commit acbef3a, verified).
import torch
from torch import nn
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
from math import cos, pi, sin
import math
import numpy as np
from scipy.special import lambertw
def mixup_criterion(criterion, pred, y_a, y_b, lam, pow=2):
y = lam ** pow * y_a + (1 - lam) ** pow * y_b
return criterion(pred, y)
def mixup_data(v, q, a):
    """Return mixed inputs, pairs of targets, and lambda (no organ constraint).

    Draws lam ~ Beta(1, 1) (i.e. uniform on [0, 1]) and mixes each sample in
    the batch with a randomly permuted partner.

    Args:
        v: first input batch, shape (batch, ...).
        q: second input batch, shape (batch, ...).
        a: targets, indexed along dim 0 by the same permutation.

    Returns:
        (mixed_v, mixed_q, a_1, a_2, lam) where a_1 is the original targets
        and a_2 the permuted ones.
    """
    lam = np.random.beta(1, 1)
    batch_size = v.shape[0]
    # Fix: draw the permutation on the inputs' device so CUDA tensors are
    # indexed with a same-device index tensor (the original built it on CPU).
    index = torch.randperm(batch_size, device=v.device)
    mixed_v = lam * v + (1 - lam) * v[index, :]
    mixed_q = lam * q + (1 - lam) * q[index, :]
    a_1, a_2 = a, a[index]
    return mixed_v, mixed_q, a_1, a_2, lam
def linear(epoch, nepoch):
    """Linear pacing factor: decays from 1 at epoch 0 to 0 at epoch == nepoch."""
    progress = epoch / nepoch
    return 1 - progress
def convex(epoch, nepoch):
    """Convex pacing factor variant.

    NOTE(review): `2 - nepoch` is negative whenever nepoch > 2, so this
    returns negative, increasing-magnitude values — it looks like a typo
    (perhaps for `2 * nepoch` or a cosine form); confirm against the original
    curriculum-pacing reference before relying on it. Unused by
    `unbiased_curriculum_loss`, which only selects linear/concave.
    """
    denominator = 2 - nepoch
    return epoch / denominator
def concave(epoch, nepoch):
    """Sine-shaped pacing factor: 1 - sin(t * pi/2) for t = epoch / nepoch;
    decays from 1 at epoch 0 to 0 at epoch == nepoch."""
    t = epoch / nepoch
    return 1 - sin(t * (pi / 2))
def composite(epoch, nepoch):
    """Cosine pacing factor: 0.5 * cos(t * pi) + 0.5 for t = epoch / nepoch;
    decays smoothly from 1 at epoch 0 to 0 at epoch == nepoch."""
    t = epoch / nepoch
    return 0.5 + 0.5 * cos(t * pi)
class LogCoshLoss(nn.Module):
    """Log-cosh loss plus MSE between predictions and targets.

    Fix: the naive torch.log(torch.cosh(x)) overflows to inf for |x| >~ 90
    in float32; here log(cosh(x)) is computed with the numerically stable
    identity |x| + softplus(-2|x|) - log(2). (The original also added a
    misplaced 1e-12 epsilon to the *argument* of cosh, which did nothing.)
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_t, y_prime_t):
        """Return mean(log(cosh(y_t - y_prime_t))) + mse(y_t, y_prime_t)."""
        diff = y_t - y_prime_t
        abs_diff = torch.abs(diff)
        # log(cosh(x)) = |x| + log(1 + exp(-2|x|)) - log(2), stable for all x
        log_cosh = abs_diff + F.softplus(-2.0 * abs_diff) - math.log(2.0)
        return torch.mean(log_cosh) + F.mse_loss(y_t, y_prime_t)
class WeightedMSELoss(nn.Module):
    """Mean squared error with optional per-element weights."""

    def __init__(self):
        super().__init__()

    def forward(self, y, y_t, weights=None):
        """Return mean of weights * (y - y_t)**2; uniform when weights is None."""
        squared_error = (y - y_t) ** 2
        if weights is None:
            return torch.mean(squared_error)
        weighted = squared_error * weights.expand_as(squared_error)
        return torch.mean(weighted)
class MLCE(nn.Module):
    """Multi-label categorical cross-entropy (circle-loss style).

    Positive labels are pushed to high scores and negatives to low scores via
    two logsumexp terms, each padded with a zero logit as the threshold.
    """

    def __init__(self):
        super(MLCE, self).__init__()

    def _mlcce(self, y_pred, y_true):
        # Flip the sign of the logits at positive positions (y_true == 1)
        scores = (1 - 2 * y_true) * y_pred
        # Mask positives out of the negative term and vice versa with -1e12
        neg_scores = scores - y_true * 1e12
        pos_scores = scores - (1 - y_true) * 1e12
        # A zero logit appended to each term acts as the decision threshold
        pad = torch.zeros_like(scores[..., :1])
        neg_term = torch.logsumexp(torch.cat([neg_scores, pad], dim=-1), dim=-1)
        pos_term = torch.logsumexp(torch.cat([pos_scores, pad], dim=-1), dim=-1)
        return torch.mean(neg_term + pos_term)

    def __call__(self, y_pred, y_true):
        # NOTE: overrides __call__ directly (bypasses nn.Module hooks),
        # matching the module's original design.
        return self._mlcce(y_pred, y_true)
class SuperLoss(nn.Module):
    """SuperLoss wrapper around element-wise MSE.

    Each element's loss is re-weighted by a confidence sigma obtained in
    closed form from the Lambert W function, then the weighted sum is
    normalized by a fixed batch size.
    """

    def __init__(self, C=10, lam=1, batch_size=256):
        super(SuperLoss, self).__init__()
        self.tau = math.log(C)  # expected loss level of a "average" sample
        self.lam = lam  # set to 1 for CIFAR10 and 0.25 for CIFAR100
        self.batch_size = batch_size

    def forward(self, logits, targets):
        """Return the confidence-weighted MSE, summed and divided by batch_size."""
        per_elem = F.mse_loss(logits, targets, reduction='none')
        sigma = self.sigma(per_elem.detach())
        weighted = (per_elem - self.tau) * sigma + self.lam * (torch.log(sigma) ** 2)
        return weighted.sum() / self.batch_size

    def sigma(self, l_i):
        """Closed-form optimal confidence exp(-W(y)) via scipy's lambertw."""
        # Clamp the Lambert-W argument at -2/e, the domain boundary of W_0
        floor = torch.ones_like(l_i) * (-2 / math.exp(1.))
        y = 0.5 * torch.max(floor, (l_i - self.tau) / self.lam)
        # lambertw runs on CPU numpy; take the real branch and restore device
        conf = np.exp(-lambertw(y.cpu().numpy()))
        conf = conf.real.astype(np.float32)
        return torch.from_numpy(conf).to(l_i.device)
def unbiased_curriculum_loss(out, data, args, epoch, epochs, scheduler='linear'):
    """Curriculum-weighted MSE over a batch.

    Per-sample MSE is normalized by the target magnitude (clamped to >= 1);
    samples whose normalized loss is more than one std above the batch mean
    are treated as "hard" and down-weighted by a pacing factor that decays
    over training.

    Args:
        out: predictions, indexed along dim 0.
        data: targets; data[idx] must be scalar (``.item()`` is called on it).
        args: unused; kept for caller compatibility (the original read
            args.epochs here — fixed to use the explicit ``epochs`` argument).
        epoch: current epoch number.
        epochs: total number of epochs, drives the pacing schedule.
        scheduler: 'linear' selects the linear schedule; any other value
            selects concave (matching the original dispatch).

    Returns:
        Scalar tensor: the curriculum-weighted total loss.
    """
    # Renamed from `scheduler` to avoid shadowing the string parameter
    sched_fn = linear if scheduler == 'linear' else concave
    # Difficulty measurement: per-sample loss normalized by target magnitude
    per_sample = []
    relative = []
    for idx in range(out.shape[0]):
        scale = max(1, abs(data[idx].item()))
        sample_loss = F.mse_loss(out[idx], data[idx])
        per_sample.append(sample_loss)
        relative.append(sample_loss.item() / scale)
    mean_rel, std_rel = np.mean(relative), np.std(relative)
    threshold = mean_rel + std_rel
    # Re-weight: hard samples decay with the schedule, easy ones count fully
    total_loss = 0
    for i, sample_loss in enumerate(per_sample):
        if relative[i] > threshold:
            # Fix: use the `epochs` parameter (previously dead — the original
            # read args.epochs, crashing when args lacks that attribute)
            total_loss += sched_fn(epoch, epochs) * sample_loss
        else:
            total_loss += sample_loss
    return total_loss
class BMCLoss(_Loss):
    """Balanced MSE (BMC) loss with a learnable noise scale.

    Regression is cast as a contrastive problem: pairwise negative squared
    distances between predictions and targets form a [batch, batch] logit
    matrix, and cross-entropy pulls each prediction toward its own target.
    """

    def __init__(self, init_noise_sigma=1.0):
        super(BMCLoss, self).__init__()
        # Learnable noise standard deviation (sigma); variance = sigma ** 2
        self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma))

    def bmc_loss(self, pred, target, noise_var):
        """Compute the Balanced MSE Loss (BMC) between `pred` and the ground truth `targets`.
        Args:
            pred: A float tensor of size [batch, 1].
            target: A float tensor of size [batch, 1].
            noise_var: A float number or tensor.
        Returns:
            loss: A float tensor. Balanced MSE Loss.
        """
        # Promote 1-D inputs to column vectors so broadcasting yields
        # a [batch, batch] pairwise matrix below
        if pred.dim() == 1:
            pred = pred.unsqueeze(1)
        if target.dim() == 1:
            target = target.unsqueeze(1)
        # Pairwise negative squared distances scaled by the noise variance
        pairwise_logits = - (pred - target.T).pow(2) / (2 * noise_var)
        labels = torch.arange(pred.shape[0], device=pred.device)
        # Contrastive-like objective: each row should match its own index
        ce = F.cross_entropy(pairwise_logits, labels)
        # Restore the loss scale; detach so the scale factor itself does not
        # backprop into the learnable noise parameter
        return ce * (2 * noise_var).detach()

    def forward(self, pred, target):
        noise_var = self.noise_sigma ** 2
        return self.bmc_loss(pred, target, noise_var)