Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| ################################################################### | |
| # ########################## iou loss ############################# | |
| ################################################################### | |
class IOU(torch.nn.Module):
    """Soft IoU (Jaccard) loss computed on sigmoid-activated logits.

    For each (sample, channel) pair the soft intersection and union are
    summed over every dimension after the first two, and the loss is
    ``1 - intersection / union`` averaged over all pairs.
    """

    def __init__(self):
        super().__init__()

    def _iou(self, pred, target):
        """Return the mean soft-IoU loss between ``pred`` logits and ``target``.

        Args:
            pred: Raw logits of shape ``(N, C, *spatial)`` with at least one
                spatial dimension; ``torch.sigmoid`` is applied here.
            target: Tensor broadcastable to ``pred``. Presumably a mask with
                values in ``[0, 1]`` -- TODO confirm with callers.

        Returns:
            A 0-dim tensor: ``1 - IoU`` averaged over batch and channels.
            A (sample, channel) pair whose union is exactly zero (empty
            prediction and empty target) counts as a perfect match (IoU 1)
            instead of producing ``0 / 0 = NaN``.

        Raises:
            ValueError: if ``pred`` has fewer than 3 dimensions.
        """
        if pred.dim() < 3:
            # Without at least one spatial dimension there is nothing to
            # reduce over per (sample, channel) pair.
            raise ValueError(
                "expected pred of shape (N, C, *spatial), "
                f"got {tuple(pred.shape)}"
            )
        spatial_dims = tuple(range(2, pred.dim()))
        pred = torch.sigmoid(pred)
        inter = (pred * target).sum(dim=spatial_dims)
        union = (pred + target).sum(dim=spatial_dims) - inter
        # Guard 0/0: use a dummy denominator where the union is zero, then
        # override those entries with IoU = 1. Masking the denominator first
        # (not just selecting afterwards) keeps the backward pass NaN-free.
        empty = union == 0
        iou = inter / union.masked_fill(empty, 1.0)
        iou = iou.masked_fill(empty, 1.0)
        return (1 - iou).mean()

    def forward(self, pred, target):
        """Compute the soft-IoU loss of ``pred`` logits against ``target``."""
        return self._iou(pred, target)