import torch
import torch.nn as nn


class WeightedMSELoss(nn.Module):
    """Mean squared error that up-weights elements with high target values.

    Each element's squared error is multiplied by a weight chosen from its
    *target* value. Targets below every threshold get weight 1.0. Otherwise
    the weight of the highest threshold the target reaches is used. The
    comparison is inclusive: ``target >= threshold``.

    With the default tiers this gives:

    * ``target < 6.0``         -> weight 1.0
    * ``6.0 <= target < 7.0``  -> weight 2.0
    * ``7.0 <= target < 8.0``  -> weight 5.0
    * ``target >= 8.0``        -> weight 10.0

    The original comments describe these targets as pKd values, with higher
    tiers meaning stronger binding.

    Args:
        thresholds: Lower bounds (inclusive) of each weighted tier. They may
            be given in any order.
        weights: Weight applied to each tier. ``weights[i]`` pairs with
            ``thresholds[i]``.

    Raises:
        ValueError: If ``thresholds`` and ``weights`` differ in length.
    """

    def __init__(self, thresholds=(6.0, 7.0, 8.0), weights=(2.0, 5.0, 10.0)):
        super().__init__()
        thresholds = tuple(float(t) for t in thresholds)
        weights = tuple(float(w) for w in weights)
        if len(thresholds) != len(weights):
            raise ValueError(
                "thresholds and weights must have the same length, got "
                f"{len(thresholds)} and {len(weights)}"
            )
        # Sort ascending by threshold. In forward(), later (higher) tiers then
        # overwrite earlier ones, so each element ends up with the weight of
        # the highest threshold it reaches.
        self._tiers = tuple(sorted(zip(thresholds, weights)))

    def forward(self, prediction, target):
        """Return the mean of the per-element weighted squared errors.

        Args:
            prediction: Predicted values.
            target: Ground-truth values. The weights are derived from these.
                NOTE(review): assumes ``prediction`` and ``target`` have the
                same shape. Otherwise broadcasting in ``prediction - target``
                may silently produce a larger tensor. Confirm at call sites.

        Returns:
            A scalar tensor holding the weighted mean squared error.
        """
        squared_errors = (prediction - target) ** 2
        # The weights are a fresh tensor, so they do not carry gradients.
        weights = torch.ones_like(target)
        for threshold, weight in self._tiers:
            weights[target >= threshold] = weight
        return torch.mean(squared_errors * weights)