Spaces:
Runtime error
Runtime error
| import torch | |
| from mmengine.structures import InstanceData | |
| from typing import List, Any | |
| from mmpl.registry import MODELS | |
| from mmseg.utils import SampleList | |
| from .base_pler import BasePLer | |
| import torch.nn.functional as F | |
| from modules.sam import sam_model_registry | |
class SegSAMPLer(BasePLer):
    """Segmentation module pairing a SAM backbone with a panoptic head.

    The backbone is looked up in ``sam_model_registry``; the optional neck,
    the panoptic head and the panoptic fusion head are built through the
    ``MODELS`` registry. Training, validation and prediction steps all run
    the backbone's ``image_encoder`` first and then delegate to the heads.

    Args:
        backbone: Mapping with a ``'type'`` key naming an entry of
            ``sam_model_registry``; every other item is forwarded to that
            entry as a keyword argument.
        sam_neck: Optional config passed to ``MODELS.build``. When given,
            encoder features go through the neck before reaching the head;
            when omitted, the head receives the backbone as an extra
            positional argument instead.
        panoptic_head: Config of the panoptic head. ``train_cfg`` and
            ``test_cfg`` are injected into a deep copy before building.
            Despite the ``None`` default it is used unconditionally, so it
            must be provided.
        panoptic_fusion_head: Config of the fusion head. ``test_cfg`` is
            injected into a deep copy before building. Must be provided,
            for the same reason as ``panoptic_head``.
        need_train_names: Optional value handed to the base-class helpers
            ``_set_grad`` and ``_set_train_module``.
            NOTE(review): presumably the names of the sub-modules that
            should stay trainable -- confirm against ``BasePLer``.
        train_cfg: Training config forwarded to the panoptic head.
        test_cfg: Testing config forwarded to both heads.
        *args: Forwarded to ``BasePLer.__init__``.
        **kwargs: Forwarded to ``BasePLer.__init__``.
    """

    def __init__(self,
                 backbone,
                 sam_neck=None,
                 panoptic_head=None,
                 panoptic_fusion_head=None,
                 need_train_names=None,
                 train_cfg=None,
                 test_cfg=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.need_train_names = need_train_names

        # Pop 'type' from a shallow copy so the caller's config object is
        # left untouched (the original popped from the argument itself).
        backbone_cfg = backbone.copy()
        backbone_type = backbone_cfg.pop('type')
        self.backbone = sam_model_registry[backbone_type](**backbone_cfg)

        # The attribute is only created when a neck is configured; the
        # step methods use ``hasattr(self, 'sam_neck')`` to pick a branch.
        if sam_neck is not None:
            self.sam_neck = MODELS.build(sam_neck)

        panoptic_head_ = panoptic_head.deepcopy()
        panoptic_head_.update(train_cfg=train_cfg)
        panoptic_head_.update(test_cfg=test_cfg)
        self.panoptic_head = MODELS.build(panoptic_head_)

        panoptic_fusion_head_ = panoptic_fusion_head.deepcopy()
        panoptic_fusion_head_.update(test_cfg=test_cfg)
        self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_)

        # Mirror the class counts exposed by the built head.
        self.num_things_classes = self.panoptic_head.num_things_classes
        self.num_stuff_classes = self.panoptic_head.num_stuff_classes
        self.num_classes = self.panoptic_head.num_classes

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def setup(self, stage: str) -> None:
        """Run the base-class setup, then apply ``need_train_names``.

        Args:
            stage: Stage identifier, forwarded unchanged to the base class.
        """
        super().setup(stage)
        if self.need_train_names is not None:
            self._set_grad(self.need_train_names, noneed_train_names=[])

    def init_weights(self):
        """Do nothing; this class performs no weight initialisation here.

        The original body contained a leftover ``ipdb.set_trace()``
        breakpoint, which has been removed so that calling this method can
        no longer halt a run or fail on a missing ``ipdb`` package.
        """
        pass

    def train(self, mode=True):
        """Switch between training and evaluation mode.

        When ``need_train_names`` is set, the switch is delegated to the
        base-class helper ``_set_train_module`` and its result is returned;
        otherwise the regular ``train`` of the parent class is applied.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation.

        Returns:
            The result of ``_set_train_module`` when ``need_train_names``
            is set, otherwise ``self``.
        """
        if self.need_train_names is not None:
            return self._set_train_module(mode, self.need_train_names)
        super().train(mode)
        return self

    def extract_feat(self, batch_inputs):
        """Run the backbone's image encoder on a batch of inputs.

        Args:
            batch_inputs: Pre-processed inputs accepted by
                ``self.backbone.image_encoder``.

        Returns:
            tuple: ``(feat, inter_features)`` exactly as produced by the
            image encoder.
        """
        feat, inter_features = self.backbone.image_encoder(batch_inputs)
        return feat, inter_features

    def _predict_batch(self, batch):
        """Run the inference pipeline shared by validation and prediction.

        Pre-processes ``batch`` in non-training mode, extracts features,
        obtains class and mask predictions from the panoptic head, fuses
        them with ``rescale=True`` and attaches the fused results to the
        data samples.

        Args:
            batch: Raw batch as delivered by the dataloader.

        Returns:
            The data samples of the batch with predictions attached.
        """
        data = self.data_preprocessor(batch, False)
        batch_inputs = data['inputs']
        batch_data_samples = data['data_samples']

        feats = self.extract_feat(batch_inputs)
        if hasattr(self, 'sam_neck'):
            feats = self.sam_neck(feats)
            mask_cls_results, mask_pred_results = self.panoptic_head.predict(
                feats, batch_data_samples)
        else:
            # Without a neck the head also receives the backbone itself.
            mask_cls_results, mask_pred_results = self.panoptic_head.predict(
                feats, batch_data_samples, self.backbone)

        results_list = self.panoptic_fusion_head.predict(
            mask_cls_results,
            mask_pred_results,
            batch_data_samples,
            rescale=True)
        return self.add_pred_to_datasample(batch_data_samples, results_list)

    def validation_step(self, batch, batch_idx):
        """Predict on one validation batch and update the evaluator.

        Args:
            batch: Raw batch as delivered by the dataloader.
            batch_idx: Index of the batch (unused).
        """
        results = self._predict_batch(batch)
        self.val_evaluator.update(batch, results)

    def training_step(self, batch, batch_idx):
        """Compute and log the losses for one training batch.

        Args:
            batch: Raw batch as delivered by the dataloader.
            batch_idx: Index of the batch (unused).

        Returns:
            dict: The parsed log variables, each key prefixed with
            ``'train_'``, plus the total loss under the key ``'loss'``.
        """
        data = self.data_preprocessor(batch, True)
        batch_inputs = data['inputs']
        batch_data_samples = data['data_samples']

        x = self.extract_feat(batch_inputs)
        if hasattr(self, 'sam_neck'):
            x = self.sam_neck(x)
            losses = self.panoptic_head.loss(x, batch_data_samples)
        else:
            # Without a neck the head also receives the backbone itself.
            losses = self.panoptic_head.loss(
                x, batch_data_samples, self.backbone)

        parsed_losses, log_vars = self.parse_losses(losses)
        log_vars = {f'train_{k}': v for k, v in log_vars.items()}
        log_vars['loss'] = parsed_losses
        self.log_dict(log_vars, prog_bar=True)
        return log_vars

    def on_before_optimizer_step(self, optimizer) -> None:
        """Log the panoptic head's gradients before the optimizer step.

        Args:
            optimizer: The optimizer about to step (unused).
        """
        self.log_grad(module=self.panoptic_head)

    def add_pred_to_datasample(self, data_samples: SampleList,
                               results_list: List[dict]) -> SampleList:
        """Add predictions to `DetDataSample`.

        Args:
            data_samples (list[:obj:`DetDataSample`], optional): A batch of
                data samples that contain annotations and predictions.
            results_list (List[dict]): Instance segmentation, semantic
                segmentation and panoptic segmentation results.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contains
            'pred_instances' and `pred_panoptic_seg`. And the
            ``pred_instances`` usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).

            And the ``pred_panoptic_seg`` contains the following key

            - sem_seg (Tensor): panoptic segmentation mask, has a
              shape (1, h, w).

        Raises:
            AssertionError: If any entry of ``results_list`` contains the
                key ``'sem_results'``.
        """
        for data_sample, pred_results in zip(data_samples, results_list):
            if 'pan_results' in pred_results:
                data_sample.pred_panoptic_seg = pred_results['pan_results']

            if 'ins_results' in pred_results:
                data_sample.pred_instances = pred_results['ins_results']

            assert 'sem_results' not in pred_results, 'segmantic ' \
                'segmentation results are not supported yet.'

        return data_samples

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Predict on one batch and return the annotated data samples.

        Args:
            batch: Raw batch as delivered by the dataloader.
            batch_idx: Index of the batch (unused).
            dataloader_idx: Index of the dataloader (unused).

        Returns:
            The data samples of the batch with predictions attached.
        """
        return self._predict_batch(batch)