# SemSegSAMPLer: SAM-backbone semantic-segmentation module (PyTorch Lightning).
import copy
from typing import List

import torch
import torch.nn.functional as F
from torch import Tensor

from mmengine.structures import InstanceData, PixelData
from mmseg.models.utils import resize
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList, OptSampleList

from mmpl.registry import MODELS
from modules.sam import sam_model_registry
from .base_pler import BasePLer
class SemSegSAMPLer(BasePLer):
    """Semantic segmentation Lightning module built on a SAM image encoder.

    Features are extracted by running an adapter head (``adaphead``) over the
    SAM backbone's ``image_encoder``; a decode head then predicts per-pixel
    class logits which are resized/un-padded back to the original image shape.
    """

    def __init__(self,
                 backbone,
                 adaphead=None,
                 decode_head=None,
                 need_train_names=None,
                 align_corners=False,
                 train_cfg=None,
                 test_cfg=None,
                 *args, **kwargs):
        """
        Args:
            backbone (dict): SAM backbone config. ``type`` selects the entry
                in ``sam_model_registry``; the remaining keys are forwarded to
                that constructor. The optional ``delete_submodel`` list names
                sub-modules to drop from the built model (e.g. an unused
                prompt encoder) to save memory.
            adaphead (dict, optional): Adapter-head config built via MODELS.
            decode_head (dict): Decode-head config built via MODELS;
                ``train_cfg``/``test_cfg`` are injected before building.
            need_train_names (list[str], optional): If given, only these
                sub-modules are left trainable (see ``setup`` / ``train``).
            align_corners (bool): ``align_corners`` used when resizing logits.
            train_cfg (dict, optional): Training config for the decode head.
            test_cfg (dict, optional): Testing config for the decode head.
        """
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.need_train_names = need_train_names
        self.align_corners = align_corners

        # Deep-copy before popping so the caller's config (and the
        # hyperparameters captured by save_hyperparameters) stay intact.
        backbone = copy.deepcopy(backbone)
        backbone_type = backbone.pop('type')
        delete_submodel = backbone.pop('delete_submodel', [])
        self.backbone = sam_model_registry[backbone_type](**backbone)
        for submodel in delete_submodel:
            delattr(self.backbone, submodel)

        if adaphead is not None:
            self.adaphead = MODELS.build(adaphead)

        # copy.deepcopy works for plain dicts as well as mmengine ConfigDicts
        # (the previous ``decode_head.deepcopy()`` required a ConfigDict).
        decode_head_ = copy.deepcopy(decode_head)
        decode_head_.update(train_cfg=train_cfg)
        decode_head_.update(test_cfg=test_cfg)
        self.decode_head = MODELS.build(decode_head_)
        self.num_classes = self.decode_head.num_classes

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def setup(self, stage: str) -> None:
        # Freeze everything except the sub-modules named in need_train_names.
        if self.need_train_names is not None:
            self._set_grad(self.need_train_names, noneed_train_names=[])

    def init_weights(self):
        # Weights come from the pretrained SAM checkpoint and the MODELS
        # builders; nothing extra to initialise here. (A leftover
        # ``ipdb.set_trace()`` debugger trap was removed from this method.)
        pass

    def train(self, mode=True):
        # Restrict train-mode switching to the trainable sub-modules when
        # need_train_names is set; otherwise behave like a normal nn.Module.
        if self.need_train_names is not None:
            return self._set_train_module(mode, self.need_train_names)
        super().train(mode)
        return self

    def extract_feat(self, batch_inputs):
        """Run the adapter head over the SAM image encoder.

        Returns:
            tuple(Tensor, Tensor): The two feature outputs ``(x0, x1)``
            produced by ``self.adaphead``.
        """
        x0, x1 = self.adaphead(batch_inputs, self.backbone.image_encoder)
        return x0, x1

    def validation_step(self, batch, batch_idx):
        """Predict on one batch and feed (pred, target) pairs to the evaluator."""
        data = self.data_preprocessor(batch, False)
        batch_inputs = data['inputs']
        batch_data_samples = data['data_samples']

        if batch_data_samples is not None:
            batch_img_metas = [
                data_sample.metainfo for data_sample in batch_data_samples
            ]
        else:
            # No metadata available: assume un-padded, un-resized inputs.
            batch_img_metas = [
                dict(
                    ori_shape=batch_inputs.shape[2:],
                    img_shape=batch_inputs.shape[2:],
                    pad_shape=batch_inputs.shape[2:],
                    padding_size=[0, 0, 0, 0])
            ] * batch_inputs.shape[0]

        x = self.extract_feat(batch_inputs)
        seg_logits = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
        results = self.postprocess_result(seg_logits, batch_data_samples)

        preds = []
        targets = []
        for data_sample in results:
            pred_label = data_sample.pred_sem_seg.data.squeeze()
            # Match dtype/device of the prediction for the evaluator.
            label = data_sample.gt_sem_seg.data.squeeze().to(pred_label)
            preds.append(pred_label)
            targets.append(label)
        preds = torch.stack(preds, dim=0)
        targets = torch.stack(targets, dim=0)
        self.val_evaluator.update(preds, targets)

    def training_step(self, batch, batch_idx):
        """Compute decode-head losses for one batch and log them."""
        data = self.data_preprocessor(batch, True)
        batch_inputs = data['inputs']
        batch_data_samples = data['data_samples']

        x = self.extract_feat(batch_inputs)
        losses = self.decode_head.loss(x, batch_data_samples)

        parsed_losses, log_vars = self.parse_losses(losses)
        log_vars = {f'train_{k}': v for k, v in log_vars.items()}
        log_vars['loss'] = parsed_losses
        self.log_dict(log_vars, prog_bar=True)
        # Lightning reads the 'loss' key from the returned dict for backprop.
        return log_vars

    def on_before_optimizer_step(self, optimizer) -> None:
        # Log gradient statistics of the adapter head (the trained part).
        self.log_grad(module=self.adaphead)

    def postprocess_result(self,
                           seg_logits: Tensor,
                           data_samples: OptSampleList = None) -> SampleList:
        """Convert raw logits to per-sample ``SegDataSample`` results.

        Removes padding, undoes test-time flips, resizes each sample's logits
        back to its original shape, and attaches ``seg_logits`` and
        ``pred_sem_seg`` to each data sample.

        Args:
            seg_logits (Tensor): Batched logits of shape (B, C, H, W).
            data_samples (list[:obj:`SegDataSample`], optional): Per-image
                samples carrying metainfo (padding, flip, ori_shape). When
                None, fresh samples are created and no un-padding/resizing
                is done.

        Returns:
            list[:obj:`SegDataSample`]: Each sample gains:
                - ``seg_logits`` (PixelData): logits before argmax/threshold.
                - ``pred_sem_seg`` (PixelData): predicted label map.
        """
        batch_size, C, H, W = seg_logits.shape

        if data_samples is None:
            data_samples = [SegDataSample() for _ in range(batch_size)]
            only_prediction = True
        else:
            only_prediction = False

        for i in range(batch_size):
            if not only_prediction:
                img_meta = data_samples[i].metainfo
                # Remove the padding added by the test-time pipeline.
                if 'img_padding_size' not in img_meta:
                    padding_size = img_meta.get('padding_size', [0] * 4)
                else:
                    padding_size = img_meta['img_padding_size']
                padding_left, padding_right, padding_top, padding_bottom =\
                    padding_size
                # i_seg_logits shape is (1, C, H', W') after removing padding.
                i_seg_logits = seg_logits[i:i + 1, :,
                                          padding_top:H - padding_bottom,
                                          padding_left:W - padding_right]

                flip = img_meta.get('flip', None)
                if flip:
                    flip_direction = img_meta.get('flip_direction', None)
                    assert flip_direction in ['horizontal', 'vertical']
                    if flip_direction == 'horizontal':
                        i_seg_logits = i_seg_logits.flip(dims=(3, ))
                    else:
                        i_seg_logits = i_seg_logits.flip(dims=(2, ))

                # Resize back to the original image shape.
                i_seg_logits = resize(
                    i_seg_logits,
                    size=img_meta['ori_shape'],
                    mode='bilinear',
                    align_corners=self.align_corners,
                    warning=False).squeeze(0)
            else:
                i_seg_logits = seg_logits[i]

            if C > 1:
                # Multi-class: label = argmax over the channel dimension.
                i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
            else:
                # Binary: threshold the single-channel logits.
                i_seg_pred = (i_seg_logits >
                              self.decode_head.threshold).to(i_seg_logits)
            data_samples[i].set_data({
                'seg_logits':
                PixelData(**{'data': i_seg_logits}),
                'pred_sem_seg':
                PixelData(**{'data': i_seg_pred})
            })

        return data_samples