import torch
import torch.nn as nn
from utils import intersection_over_union


class Loss_Yolo(nn.Module):
    """YOLO-style detection loss, summed over a batch of S x S grid predictions.

    Per grid cell, ``predictions`` holds ``C`` class scores followed by ``B``
    boxes of 5 values each (confidence, then 4 box coordinates). ``target``
    holds ``C`` class scores, one object indicator / confidence channel and
    4 box coordinates. In every cell, the predicted box with the highest IoU
    against the target box is the "responsible" one used for the coordinate
    and object-confidence terms.

    Args:
        S: grid size (cells per side).
        B: number of boxes predicted per cell.
        C: number of classes.
        lambda_coord: weight of the box-coordinate term.
        lambda_noobj: weight of the no-object confidence term.
    """

    def __init__(self, S=7, B=2, C=20, lambda_coord=5, lambda_noobj=0.5):
        super().__init__()
        # Every term is a summed (not averaged) squared error.
        self.mse = nn.MSELoss(reduction="sum")
        self.S = S
        self.B = B
        self.C = C
        self.lambda_noobj = lambda_noobj
        self.lambda_coord = lambda_coord

    def forward(self, predictions, target):
        """Return the scalar loss of ``predictions`` against ``target``.

        Args:
            predictions: tensor reshapeable to (N, S, S, C + 5 * B).
            target: tensor of shape (N, S, S, >= C + 5). Channel C is the
                object indicator (also used as the confidence target) and
                channels C+1 .. C+4 are the box coordinates.
                NOTE(review): coordinates are presumably (x, y, w, h), since
                box channels 2:4 are square-rooted below — confirm against
                the dataset encoder and ``intersection_over_union``.

        Returns:
            Scalar tensor: weighted sum of the coordinate, object,
            no-object and class terms.
        """
        C, B, S = self.C, self.B, self.S
        predictions = predictions.reshape(-1, S, S, C + B * 5)
        # The B predicted boxes as (N, S, S, B, 5): [confidence, 4 coords].
        pred_boxes = predictions[..., C:C + B * 5].reshape(-1, S, S, B, 5)

        target_box = target[..., C + 1:C + 5]
        # The target confidence channel doubles as the "cell has an object" mask.
        target_conf = target[..., C:C + 1]
        exists_box = target_conf

        # IoU of each predicted box with the target box, stacked on a new dim 0.
        ious = torch.stack(
            [
                intersection_over_union(pred_boxes[..., b, 1:5], target_box)
                for b in range(B)
            ],
            dim=0,
        )
        # Index (0 .. B-1) of the responsible box in each cell.
        _, bestbox = torch.max(ious, dim=0)

        # Gather the responsible box's 5 values -> (N, S, S, 5). Gathering
        # (rather than blending with 0/1 weights) keeps NaN/inf in the other
        # boxes from leaking into the loss through 0 * inf.
        index = bestbox.reshape(-1, S, S, 1, 1).expand(-1, -1, -1, -1, 5)
        responsible = torch.gather(pred_boxes, 3, index).squeeze(3)

        # ---- Box coordinates (first two lines of the loss in the paper) ----
        box_predictions = exists_box * responsible[..., 1:5]
        box_targets = exists_box * target_box
        # Compare square roots of channels 2:4. sign/abs let raw predictions be
        # negative; the epsilon is added AFTER abs so sqrt's argument is always
        # >= 1e-6 and its gradient stays finite.
        box_predictions = torch.cat(
            (
                box_predictions[..., :2],
                torch.sign(box_predictions[..., 2:4])
                * torch.sqrt(torch.abs(box_predictions[..., 2:4]) + 1e-6),
            ),
            dim=-1,
        )
        box_targets = torch.cat(
            (box_targets[..., :2], torch.sqrt(box_targets[..., 2:4])),
            dim=-1,
        )
        box_loss = self.mse(
            torch.flatten(box_predictions, end_dim=-2),
            torch.flatten(box_targets, end_dim=-2),
        )

        # ---- Confidence of the responsible box, in cells with an object ----
        object_loss = self.mse(
            torch.flatten(exists_box * responsible[..., 0:1]),
            torch.flatten(exists_box * target_conf),
        )

        # ---- Confidence of every predicted box, in cells without an object ----
        no_obj = 1 - exists_box
        no_object_loss = sum(
            self.mse(
                torch.flatten(no_obj * pred_boxes[..., b, 0:1], start_dim=1),
                torch.flatten(no_obj * target_conf, start_dim=1),
            )
            for b in range(B)
        )

        # ---- Class scores, in cells with an object ----
        class_loss = self.mse(
            torch.flatten(exists_box * predictions[..., :C], end_dim=-2),
            torch.flatten(exists_box * target[..., :C], end_dim=-2),
        )

        return (
            self.lambda_coord * box_loss
            + object_loss
            + self.lambda_noobj * no_object_loss
            + class_loss
        )