"""Quality-assessment metrics (PLCC / SRCC / KRCC / RMSE) and training losses."""

import numpy as np
from scipy.optimize import curve_fit
from scipy import stats
import torch
import torch.nn.functional as F
import torch.nn as nn

# Small constant added inside the square roots of Multi_Fidelity_Loss so that
# sqrt is never evaluated at exactly 0 (whose derivative is unbounded).
# The name ``esp`` is kept for any external users of this module.
esp = 1e-8


def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    """Four-parameter logistic mapping.

    Returns ``bayta2 + (bayta1 - bayta2) / (1 + exp(-(X - bayta3) / |bayta4|))``:
    a sigmoid that tends to ``bayta2`` as X -> -inf and to ``bayta1`` as
    X -> +inf, centred at ``bayta3`` with width ``|bayta4|``.
    """
    logisticPart = 1 + np.exp(-(X - bayta3) / np.abs(bayta4))
    yhat = bayta2 + (bayta1 - bayta2) / logisticPart
    return yhat


def fit_function(y_label, y_output):
    """Fit ``logistic_func`` mapping ``y_output`` onto ``y_label``.

    Args:
        y_label: ground-truth scores (array-like).
        y_output: predicted scores (array-like, same length as ``y_label``).

    Returns:
        ``y_output`` passed through the fitted logistic curve (numpy array).

    Raises:
        RuntimeError: propagated from ``scipy.optimize.curve_fit`` when the
            fit does not converge within ``maxfev`` evaluations.
    """
    y_label = np.asarray(y_label, dtype=float)
    y_output = np.asarray(y_output, dtype=float)
    # Initial guess: curve spans the label range, centred on the mean prediction.
    beta = [np.max(y_label), np.min(y_label), np.mean(y_output), 0.5]
    popt, _ = curve_fit(logistic_func, y_output, y_label, p0=beta, maxfev=10000)
    return logistic_func(y_output, *popt)


def performance_fit(y_label, y_output):
    """Return (PLCC, SRCC, KRCC, RMSE) after a logistic fit of the predictions.

    PLCC and RMSE use the logistic-mapped predictions. SRCC and KRCC use the
    raw predictions; the logistic map is monotonic, so ranks are unaffected.
    """
    y_label = np.asarray(y_label, dtype=float)
    y_output = np.asarray(y_output, dtype=float)
    y_output_logistic = fit_function(y_label, y_output)
    PLCC = stats.pearsonr(y_output_logistic, y_label)[0]
    SRCC = stats.spearmanr(y_output, y_label)[0]
    KRCC = stats.kendalltau(y_output, y_label)[0]
    RMSE = np.sqrt(((y_output_logistic - y_label) ** 2).mean())
    return PLCC, SRCC, KRCC, RMSE


def performance_no_fit(y_label, y_output):
    """Return (PLCC, SRCC, KRCC, RMSE) computed directly on the raw predictions."""
    y_label = np.asarray(y_label, dtype=float)
    y_output = np.asarray(y_output, dtype=float)
    PLCC = stats.pearsonr(y_output, y_label)[0]
    SRCC = stats.spearmanr(y_output, y_label)[0]
    # Public ``stats.kendalltau``; the ``scipy.stats.stats`` namespace is deprecated.
    KRCC = stats.kendalltau(y_output, y_label)[0]
    # Previously computed (y_label - y_label), which is always 0.
    RMSE = np.sqrt(((y_output - y_label) ** 2).mean())
    return PLCC, SRCC, KRCC, RMSE


class L1RankLoss(torch.nn.Module):
    """
    L1 loss + Rank loss

    Keyword args:
        l1_w: weight of the L1 term (default 1).
        rank_w: weight of the rank term (default 1).
        hard_thred: only pairs whose label difference lies strictly between
            0 and ``hard_thred`` contribute to the rank term (default 1).
        use_margin: if True, the hinge uses the label gap as a margin
            (default False).
    """

    def __init__(self, **kwargs):
        super(L1RankLoss, self).__init__()
        self.l1_w = kwargs.get("l1_w", 1)
        self.rank_w = kwargs.get("rank_w", 1)
        self.hard_thred = kwargs.get("hard_thred", 1)
        self.use_margin = kwargs.get("use_margin", False)

    def forward(self, preds, gts):
        """Return ``l1_w * L1(preds, gts) + rank_w * mean pairwise hinge``."""
        preds = preds.view(-1)
        gts = gts.view(-1)
        # l1 loss
        l1_loss = F.l1_loss(preds, gts) * self.l1_w
        # simple rank: entry [i, j] compares element j against element i
        n = len(preds)
        preds = preds.unsqueeze(0).repeat(n, 1)
        preds_t = preds.t()
        img_label = gts.unsqueeze(0).repeat(n, 1)
        img_label_t = img_label.t()
        masks = torch.sign(img_label - img_label_t)
        label_gap = torch.abs(img_label - img_label_t)
        # Only "hard" pairs: distinct labels closer than hard_thred.
        masks_hard = (label_gap < self.hard_thred) & (label_gap > 0)
        if self.use_margin:
            rank_term = masks_hard * torch.relu(label_gap - masks * (preds - preds_t))
        else:
            rank_term = masks_hard * torch.relu(-masks * (preds - preds_t))
        # Mean over the selected pairs; epsilon avoids 0/0 when none qualify.
        rank_term = rank_term.sum() / (masks_hard.sum() + 1e-08)
        loss_total = l1_loss + rank_term * self.rank_w
        return loss_total


def plcc_loss(y, y_pred):
    """Correlation-based loss between labels ``y`` and predictions ``y_pred``.

    Both inputs are standardised (population std, +1e-8 for stability).
    The result averages two terms:
      * ``mse(z_pred, z) / 4``;
      * ``mse(rho * z_pred, z) / 4``, where ``rho = mean(z_pred * z)`` is the
        Pearson correlation of the inputs (up to the epsilon).
    Returned as a float32 scalar tensor.
    """
    sigma_hat, m_hat = torch.std_mean(y_pred, unbiased=False)
    y_pred = (y_pred - m_hat) / (sigma_hat + 1e-8)
    sigma, m = torch.std_mean(y, unbiased=False)
    y = (y - m) / (sigma + 1e-8)
    loss0 = F.mse_loss(y_pred, y) / 4
    rho = torch.mean(y_pred * y)
    loss1 = F.mse_loss(rho * y_pred, y) / 4
    return ((loss0 + loss1) / 2).float()


def rank_loss(y, y_pred):
    """Pairwise hinge ranking loss between labels ``y`` and predictions ``y_pred``.

    Both inputs are flattened to column vectors of n elements. For every
    ordered pair (i, j) the penalty ``relu((p_i - p_j) * sign(y_j - y_i))`` is
    positive only when the predicted order of i and j contradicts the label
    order. The summed penalty is averaged over the ``n * (n - 1)`` ordered
    pairs and divided by ``1 + max(penalty)``.

    With fewer than two samples there are no pairs, and a zero loss is
    returned instead of dividing by zero.
    """
    y_pred = y_pred.reshape(-1, 1)
    y = y.reshape(-1, 1)
    n = y_pred.shape[0]
    if n < 2:
        # Zero that stays attached to y_pred's graph (if any).
        return (y_pred.sum() * 0.0).float()
    # [i, j] = (p_i - p_j) * sign(y_j - y_i): positive iff the pair is discordant.
    ranking_loss = F.relu((y_pred - y_pred.t()) * torch.sign(y.t() - y))
    scale = 1 + torch.max(ranking_loss)
    return (torch.sum(ranking_loss) / n / (n - 1) / scale).float()


def plcc_rank_loss(y_label, y_output):
    """Return ``plcc_loss + 0.3 * rank_loss``."""
    plcc = plcc_loss(y_label, y_output)
    rank = rank_loss(y_label, y_output)
    return plcc + rank * 0.3


def plcc_l1_loss(y_label, y_output):
    """Return ``plcc_loss + 0.0025 * L1(y_label, y_output)``."""
    plcc = plcc_loss(y_label, y_output)
    l1_loss = F.l1_loss(y_label, y_output)
    return plcc + 0.0025 * l1_loss


class Multi_Fidelity_Loss(torch.nn.Module):
    """Pairwise fidelity loss averaged over the columns of 2-D score tensors."""

    def __init__(self):
        super(Multi_Fidelity_Loss, self).__init__()

    def forward(self, y_pred, y):
        """Compute the mean fidelity loss.

        ``y_pred`` and ``y`` are indexed as ``[:, col]``, so they must be 2-D
        with matching column counts. For each column and each pair a < b:
          * g = 1 if y_a > y_b, 0.5 if tied, 0 otherwise;
          * p = Phi(y_pred_a - y_pred_b), the standard normal CDF;
          * loss = 1 - sqrt(p * g) - sqrt((1 - p) * (1 - g)) (each with +esp).
        The losses are summed over columns, divided by the column count, and
        averaged over pairs.
        """
        assert y.size(0) > 1
        constant = torch.sqrt(torch.Tensor([2.])).to(y_pred.device)
        loss = 0
        for col in range(y.size(1)):
            p_col = y_pred[:, col].unsqueeze(1)
            g_col = y[:, col].unsqueeze(1)
            # [a, b] = value_a - value_b for every pair of samples.
            preds = p_col - p_col.t()
            gts = g_col - g_col.t()
            # Strict upper triangle: each unordered pair once, no self-pairs.
            triu_indices = torch.triu_indices(preds.size(0), preds.size(0), offset=1)
            p = preds[triu_indices[0], triu_indices[1]]
            g = gts[triu_indices[0], triu_indices[1]]
            g = 0.5 * (torch.sign(g) + 1)
            p = 0.5 * (1 + torch.erf(p / constant))
            g = g.view(-1, 1)
            p = p.view(-1, 1)
            loss_col = 1 - (torch.sqrt(p * g + esp) + torch.sqrt((1 - p) * (1 - g) + esp))
            loss = loss + loss_col
        loss = loss / y.size(1)
        return torch.mean(loss)