Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| This file contains the dice-loss computation for SEG (segmentation) | |
| predictions. | |
| """ | |
| import torch | |
class SEGLossComputation:
    """
    Dice loss between SEG predictions and ground-truth maps built from
    per-image targets.
    """

    def __init__(self, cfg):
        # Added only to the dice denominator: keeps it non-zero, so an
        # all-zero (masked) prediction and target give a loss of exactly 1.
        self.eps = 1e-6
        # Config node; prepare_targets / project_masks_on_image read
        # MODEL.SEG.SHRINK_RATIO and MODEL.SEG.IGNORE_DIFFICULT from it.
        self.cfg = cfg

    def __call__(self, preds, targets):
        """
        Arguments:
            preds (Tensor): predictions with at least 4 dims; dims 2 and 3 are
                the (height, width) at which targets are built. Presumably
                (N, C, H, W) -- TODO confirm with the caller.
            targets (list): one target object per image, each exposing
                ``get_field("masks")`` and ``get_field("labels")``.

        Returns:
            seg_loss (Tensor): 0-dim dice loss pooled over the whole batch.
        """
        height, width = preds.shape[2], preds.shape[3]
        gt_maps, valid_masks = self.prepare_targets(targets, (height, width))
        # Targets come from numpy (torch.from_numpy), i.e. on the CPU; move
        # them to preds' device as float32 before combining them with preds.
        gt_maps = gt_maps.to(device=preds.device, dtype=torch.float32)
        valid_masks = valid_masks.to(device=preds.device, dtype=torch.float32)
        return self.dice_loss(preds, gt_maps, valid_masks)

    def dice_loss(self, pred, gt, m):
        """
        Return ``1 - 2 * sum(pred*gt*m) / (sum(pred*m) + sum(gt*m) + eps)``.

        Each sum reduces over every element of the (broadcast) operands, so the
        loss is pooled over the whole batch, not averaged per image. Positions
        where ``m`` is 0 contribute nothing.
        """
        overlap = (pred * gt * m).sum()
        denominator = (pred * m).sum() + (gt * m).sum() + self.eps
        return 1 - 2.0 * overlap / denominator

    def project_masks_on_image(self, mask_polygons, labels, shrink_ratio, image_size):
        """
        Build one image's segmentation map and training mask.

        Delegates to ``mask_polygons.convert_seg_map``, which must return two
        numpy arrays (they are wrapped with ``torch.from_numpy``, sharing
        memory with the arrays).
        """
        ignore_difficult = self.cfg.MODEL.SEG.IGNORE_DIFFICULT
        seg_map, training_mask = mask_polygons.convert_seg_map(
            labels, shrink_ratio, image_size, ignore_difficult
        )
        return torch.from_numpy(seg_map), torch.from_numpy(training_mask)

    def prepare_targets(self, targets, image_size):
        """
        Build batched ground truth for ``targets``.

        Returns:
            (Tensor, Tensor): the per-image segmentation maps and training
            masks, each stacked along a new leading batch dimension.
            ``targets`` must be non-empty (``torch.stack`` rejects an empty
            list).
        """
        per_image = [
            self.project_masks_on_image(
                target.get_field("masks"),
                target.get_field("labels"),
                self.cfg.MODEL.SEG.SHRINK_RATIO,
                image_size,
            )
            for target in targets
        ]
        seg_maps = [seg_map for seg_map, _ in per_image]
        train_masks = [train_mask for _, train_mask in per_image]
        return torch.stack(seg_maps), torch.stack(train_masks)
def make_seg_loss_evaluator(cfg):
    """Build and return a ``SEGLossComputation`` configured with ``cfg``."""
    return SEGLossComputation(cfg)