RingMo-SAM / models /iou_loss.py
AI-Cyber's picture
Upload 123 files
8d7921b
import torch
import torch.nn as nn
import torch.nn.functional as F
###################################################################
# ########################## iou loss #############################
###################################################################
class IOU(torch.nn.Module):
    """Soft IoU (Jaccard) loss computed on raw logits.

    ``pred`` is passed through a sigmoid, then intersection and union are
    summed over dimensions 2 and 3 of the inputs. The loss is ``1 - IoU``
    averaged over every remaining element.
    """

    def __init__(self):
        super().__init__()

    def _iou(self, pred, target):
        """Return the mean ``1 - IoU`` between ``sigmoid(pred)`` and ``target``.

        Args:
            pred: Logits; must have at least four dimensions because the
                reduction runs over dims 2 and 3.
            target: Tensor multiplied and added element-wise with the
                sigmoid of ``pred`` (so it must broadcast against it).

        Returns:
            A scalar tensor holding the averaged loss.
        """
        pred = torch.sigmoid(pred)
        inter = (pred * target).sum(dim=(2, 3))
        union = (pred + target).sum(dim=(2, 3)) - inter

        # When the target is empty and the sigmoid saturates to exactly 0,
        # union is 0 and inter / union would be 0 / 0 = NaN, which then
        # contaminates the mean over the whole batch. Such a slice is scored
        # as IoU = 1 (loss 0): nothing was predicted and nothing was expected.
        has_union = union > 0
        # Substitute a harmless denominator *before* dividing so that the
        # unselected branch of the final where() cannot leak NaN gradients.
        safe_union = torch.where(has_union, union, torch.ones_like(union))
        ratio = torch.where(has_union, inter / safe_union, torch.ones_like(union))

        return (1 - ratio).mean()

    def forward(self, pred, target):
        """Compute the IoU loss; see ``_iou`` for the argument contract."""
        return self._iou(pred, target)