File size: 563 Bytes
13188b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import torch.nn as nn


class WeightedMSELoss(nn.Module):
    """Mean squared error that penalizes errors on high-valued targets more.

    Every element starts with weight 1. For each ``(threshold, weight)`` tier,
    elements whose target is greater than or equal to ``threshold`` take that
    tier's weight; when several tiers match, the one with the highest
    threshold wins. The result is the mean of ``weight * squared_error`` over
    all elements (not normalized by the sum of the weights).

    The default tiers follow the original inline notes, which describe the
    targets as pKd values: >= 6 ("good binding") x2, >= 7 ("great binding")
    x5, >= 8 ("super binding") x10.

    Args:
        tiers: Iterable of ``(threshold, weight)`` pairs. Order does not
            matter; the pairs are sorted by threshold internally. An empty
            iterable yields a plain, unweighted MSE.
    """

    DEFAULT_TIERS = ((6.0, 2.0), (7.0, 5.0), (8.0, 10.0))

    def __init__(self, tiers=DEFAULT_TIERS):
        super().__init__()
        # Ascending order lets forward() overwrite lower tiers with higher
        # ones, so the highest matching threshold decides the final weight.
        self.tiers = tuple(
            sorted((float(threshold), float(weight)) for threshold, weight in tiers)
        )

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the scalar weighted mean of the squared errors.

        Args:
            prediction: Predicted values.
            target: Ground-truth values; also selects the per-element weight.

        Returns:
            A 0-dimensional tensor holding ``mean(weights * (prediction - target) ** 2)``.
        """
        squared_errors = (prediction - target) ** 2

        # The weights are built from ``target`` alone, so they act as
        # constants with respect to ``prediction`` during backpropagation.
        weights = torch.ones_like(target)
        for threshold, weight in self.tiers:
            weights[target >= threshold] = weight  # inclusive lower bound

        return torch.mean(squared_errors * weights)