File size: 5,187 Bytes
a6bc892
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import numpy as np
from scipy.optimize import curve_fit
from scipy import stats
import torch
import torch.nn.functional as F
import torch.nn as nn

# Small epsilon (name is a typo for "eps") added under square roots in
# Multi_Fidelity_Loss to keep sqrt and its gradient finite at 0.
esp = 1e-8
def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    """4-parameter logistic used to map raw scores onto the label scale.

    bayta1/bayta2 are the two asymptotes, bayta3 the midpoint, and
    |bayta4| the slope scale (abs keeps the optimizer from flipping sign).
    """
    denom = 1 + np.exp(-(X - bayta3) / np.abs(bayta4))
    return bayta2 + (bayta1 - bayta2) / denom


def fit_function(y_label, y_output):
    """Fit the 4-parameter logistic from outputs to labels, return mapped outputs.

    Initial guess: asymptotes from the label extrema, midpoint at the mean
    prediction, slope 0.5. maxfev is raised so curve_fit rarely gives up.
    """
    init = [np.max(y_label), np.min(y_label), np.mean(y_output), 0.5]
    params, _ = curve_fit(logistic_func, y_output, y_label,
                          p0=init, maxfev=10000)
    return logistic_func(y_output, *params)


def performance_fit(y_label, y_output):
    """Return (PLCC, SRCC, KRCC, RMSE) after logistic-fitting the outputs.

    PLCC and RMSE use the logistic-mapped predictions; SRCC and KRCC are
    rank correlations, invariant to the monotone mapping, so they use the
    raw outputs.
    """
    mapped = fit_function(y_label, y_output)
    plcc = stats.pearsonr(mapped, y_label)[0]
    srcc = stats.spearmanr(y_output, y_label)[0]
    krcc = stats.kendalltau(y_output, y_label)[0]
    rmse = np.sqrt(np.mean((mapped - y_label) ** 2))
    return plcc, srcc, krcc, rmse


def performance_no_fit(y_label, y_output):
    """Return (PLCC, SRCC, KRCC, RMSE) on the raw outputs, without logistic fitting.

    Fixes two defects in the original:
      * ``stats.stats.kendalltau`` relied on the deprecated/removed
        ``scipy.stats.stats`` submodule — use ``stats.kendalltau`` like the
        sibling ``performance_fit``.
      * RMSE was computed as ``y_label - y_label`` (always 0) — it must
        compare the prediction against the label.
    """
    PLCC = stats.pearsonr(y_output, y_label)[0]
    SRCC = stats.spearmanr(y_output, y_label)[0]
    KRCC = stats.kendalltau(y_output, y_label)[0]
    RMSE = np.sqrt(((y_output - y_label) ** 2).mean())

    return PLCC, SRCC, KRCC, RMSE


class L1RankLoss(torch.nn.Module):
    """L1 loss plus a pairwise rank loss.

    Keyword options (all via **kwargs):
        l1_w: weight on the L1 term (default 1).
        rank_w: weight on the rank term (default 1).
        hard_thred: only pairs whose label gap lies in (0, hard_thred)
            contribute to the rank term (default 1).
        use_margin: if True, the label gap itself is used as the margin
            (default False).
    """

    def __init__(self, **kwargs):
        super(L1RankLoss, self).__init__()
        self.l1_w = kwargs.get("l1_w", 1)
        self.rank_w = kwargs.get("rank_w", 1)
        self.hard_thred = kwargs.get("hard_thred", 1)
        self.use_margin = kwargs.get("use_margin", False)

    def forward(self, preds, gts):
        preds = preds.view(-1)
        gts = gts.view(-1)
        loss_l1 = F.l1_loss(preds, gts) * self.l1_w

        # Pairwise differences via broadcasting: entry [i, j] compares
        # sample j against sample i (equivalent to the repeat/transpose form).
        gap = gts.unsqueeze(0) - gts.unsqueeze(1)
        delta = preds.unsqueeze(0) - preds.unsqueeze(1)
        direction = torch.sign(gap)
        abs_gap = torch.abs(gap)
        # "Hard" pairs: distinct labels closer than the threshold.
        hard = (abs_gap < self.hard_thred) & (abs_gap > 0)
        if self.use_margin:
            pair_loss = hard * torch.relu(abs_gap - direction * delta)
        else:
            pair_loss = hard * torch.relu(-direction * delta)
        loss_rank = pair_loss.sum() / (hard.sum() + 1e-08)
        return loss_l1 + loss_rank * self.rank_w


def plcc_loss(y, y_pred):
    """Differentiable PLCC-style loss.

    Both signals are z-normalized (population std, epsilon-guarded), then
    the loss averages two quarter-MSE terms: one between the normalized
    signals and one with the prediction rescaled by their correlation.
    """
    std_p, mean_p = torch.std_mean(y_pred, unbiased=False)
    z_pred = (y_pred - mean_p) / (std_p + 1e-8)
    std_t, mean_t = torch.std_mean(y, unbiased=False)
    z_true = (y - mean_t) / (std_t + 1e-8)
    direct = F.mse_loss(z_pred, z_true) / 4
    corr = torch.mean(z_pred * z_true)
    scaled = F.mse_loss(corr * z_pred, z_true) / 4
    return ((direct + scaled) / 2).float()


def rank_loss(y, y_pred):
    """Mean absolute deviation, normalized by n*(n-1) and by 1 + the max error.

    diff * sign(diff) is |diff|; the relu is kept for parity with the
    original formulation. Requires at least two elements (divides by n-1).
    """
    diff = y_pred - y
    per_elem = F.relu(diff * torch.sign(diff))
    damping = 1 + torch.max(per_elem)
    count = y_pred.shape[0]
    result = torch.sum(per_elem) / count / (count - 1) / damping
    return result.float()


def plcc_rank_loss(y_label, y_output):
    """PLCC loss plus 0.3x the rank loss."""
    return plcc_loss(y_label, y_output) + 0.3 * rank_loss(y_label, y_output)


def plcc_l1_loss(y_label, y_output):
    """PLCC loss plus a small (0.0025x) L1 penalty."""
    return plcc_loss(y_label, y_output) + 0.0025 * F.l1_loss(y_label, y_output)


class Multi_Fidelity_Loss(torch.nn.Module):
    """Pairwise fidelity loss averaged over the columns of multi-column labels.

    For each label column, every unordered pair (i, j), i < j, of batch items
    is compared: the ground-truth preference is g = 0.5 * (sign(y_i - y_j) + 1)
    and the predicted preference is p = Phi(pred_i - pred_j), where Phi is the
    standard normal CDF (computed via erf). The per-pair fidelity loss is
    1 - (sqrt(p*g) + sqrt((1-p)*(1-g))), and the result is the mean over all
    pairs and columns.

    Changes vs. the original: the large block of commented-out dead code was
    removed, and the reliance on the typo'd module-level global ``esp`` was
    replaced by the class constant ``EPS`` with the same value, making the
    class self-contained. Behavior is unchanged.
    """

    # Small epsilon under the square roots so sqrt and its gradient stay
    # finite when p*g (or (1-p)*(1-g)) is exactly 0.
    EPS = 1e-8

    def __init__(self):
        super(Multi_Fidelity_Loss, self).__init__()

    def forward(self, y_pred, y):
        # Pairwise comparison needs at least two items in the batch.
        assert y.size(0) > 1
        # sqrt(2) on the right device: Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
        sqrt2 = torch.sqrt(torch.Tensor([2.])).to(y_pred.device)
        loss = 0
        for i in range(y.size(1)):
            p_col = y_pred[:, i].unsqueeze(1)
            g_col = y[:, i].unsqueeze(1)

            pred_diff = p_col - p_col.t()
            gt_diff = g_col - g_col.t()
            # Upper triangle (offset=1) selects each pair exactly once.
            iu = torch.triu_indices(pred_diff.size(0), pred_diff.size(0), offset=1)
            p = pred_diff[iu[0], iu[1]]
            g = gt_diff[iu[0], iu[1]]
            g = 0.5 * (torch.sign(g) + 1)
            p = 0.5 * (1 + torch.erf(p / sqrt2))
            g = g.view(-1, 1)
            p = p.view(-1, 1)
            loss_i = 1 - (torch.sqrt(p * g + self.EPS)
                          + torch.sqrt((1 - p) * (1 - g) + self.EPS))
            loss = loss + loss_i
        loss = loss / y.size(1)

        return torch.mean(loss)