| |
| from typing import List, Union |
|
|
| import torch |
| from mmengine import ConfigDict |
| from mmengine.structures import InstanceData |
| from scipy.optimize import linear_sum_assignment |
| from torch.cuda.amp import autocast |
|
|
| from mmseg.registry import TASK_UTILS |
| from .base_assigner import BaseAssigner |
|
|
|
|
@TASK_UTILS.register_module()
class HungarianAssigner(BaseAssigner):
    """Computes one-to-one matching between prediction masks and ground truth.

    This class uses bipartite matching to compute an assignment between the
    prediction masks and the ground truth. Every configured match cost is
    evaluated, the resulting cost matrices are summed, and the Hungarian
    algorithm (``scipy.optimize.linear_sum_assignment``) picks the matching
    with the minimum total cost. Only the matched index pairs are returned;
    predictions left unmatched are meant to be treated as background by the
    caller.

    Args:
        match_costs (ConfigDict|List[ConfigDict]): Match cost configs. A
            single config is accepted and handled like a one-element list.
            Each config is built through ``TASK_UTILS.build``.
    """

    def __init__(
        self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
                                 ConfigDict]
    ) -> None:
        # Normalise a lone config into a list so the build step below only
        # ever iterates over a sequence of configs.
        if isinstance(match_costs, dict):
            match_costs = [match_costs]
        elif isinstance(match_costs, list):
            # An empty list would leave nothing to stack in ``assign``.
            assert len(match_costs) > 0, \
                'match_costs must not be a empty list.'

        # Callable cost objects, invoked in ``assign`` with the keyword
        # arguments ``pred_instances`` and ``gt_instances``.
        self.match_costs = [
            TASK_UTILS.build(match_cost) for match_cost in match_costs
        ]

    def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
               **kwargs):
        """Computes one-to-one matching based on the summed match costs.

        This method assigns query predictions to ground truths. It first
        evaluates every match cost on the given instances, sums the cost
        matrices, and then uses the Hungarian algorithm to find the
        assignment with the minimum total cost.

        Args:
            pred_instances (InstanceData): Instances of model
                predictions. It includes "masks", with shape
                (n, h, w) or (n, l), and "cls", with shape (n, num_classes+1)
            gt_instances (InstanceData): Ground truth of instance
                annotations. It includes "labels", with shape (k, ),
                and "masks", with shape (k, h, w) or (k, l).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            tuple: A pair ``(matched_query_inds, matched_label_inds)``.

            - matched_query_inds (Tensor): Row indices of the cost matrix
              chosen by the solver, i.e. the indexes of matched queries
              (assuming each cost is laid out as (n, k) -- confirm against
              the match cost implementations).
            - matched_label_inds (Tensor): The corresponding column
              indices, i.e. the indexes of matched labels.

            Both tensors are placed on the device of the summed cost matrix.
        """
        # Evaluate and reduce the costs with autocast explicitly disabled so
        # that mixed-precision casting is not applied to the cost matrices.
        with autocast(enabled=False):
            cost_list = [
                match_cost(
                    pred_instances=pred_instances, gt_instances=gt_instances)
                for match_cost in self.match_costs
            ]
            cost = torch.stack(cost_list).sum(dim=0)

        # Remember where the costs lived so the results can be moved back.
        device = cost.device

        # The scipy solver runs on the host and needs no gradients.
        cost = cost.detach().cpu()

        # NOTE: the former ``linear_sum_assignment is None`` guard was
        # removed. The function is imported unconditionally at module level,
        # so a missing scipy already fails at import time and the name can
        # never be ``None`` here.
        matched_query_inds, matched_label_inds = linear_sum_assignment(cost)
        matched_query_inds = torch.from_numpy(matched_query_inds).to(device)
        matched_label_inds = torch.from_numpy(matched_label_inds).to(device)

        return matched_query_inds, matched_label_inds
|
|