Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Any | |
| import einops | |
| import mmengine | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from lightning.pytorch.utilities import grad_norm | |
| from mmengine.structures import InstanceData | |
| from mmpl.registry import MODELS | |
| from mmseg.utils import SampleList | |
| from ..builder import build_backbone, build_loss, build_neck, build_head | |
| from .base_pler import BasePLer | |
| from mmpl.structures import ClsDataSample | |
| from .base import BaseClassifier | |
| import lightning.pytorch as pl | |
| import torch.nn.functional as F | |
| class SegPLer(BasePLer): | |
| def __init__(self, | |
| sam=None, | |
| sam_checkpoint='', | |
| points_per_side=None, | |
| sam_prompt_generator=None, | |
| only_img_encoder=False, | |
| only_decoder=False, | |
| global_prompt=None, | |
| need_train_names=None, | |
| head=None, | |
| with_clip=False, | |
| train_head=False, | |
| threshold=0.5, | |
| ignore_index=255, | |
| train_cfg=None, | |
| test_cfg=None, | |
| *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self.save_hyperparameters() | |
| self.need_train_names = need_train_names | |
| self.ignore_index = ignore_index | |
| self.threshold = threshold | |
| self.only_img_encoder = only_img_encoder | |
| self.only_decoder = only_decoder | |
| self.global_prompt = global_prompt | |
| self.train_head = train_head | |
| if sam is not None: | |
| if self.only_img_encoder: | |
| self.sam = sam_model_registry[sam](sam_checkpoint).image_encoder | |
| elif self.only_decoder: | |
| self.prompt_encoder = sam_model_registry[sam](sam_checkpoint).prompt_encoder | |
| self.mask_decoder = sam_model_registry[sam](sam_checkpoint).mask_decoder | |
| else: | |
| sam = sam_model_registry[sam](sam_checkpoint, train_head=train_head) | |
| self.img_encoder = sam.image_encoder | |
| self.prompt_encoder = sam.prompt_encoder | |
| self.mask_decoder = sam.mask_decoder | |
| self.prompt_encoder_no_mask_embed = sam.prompt_encoder.no_mask_embed | |
| if points_per_side is not None: | |
| self.point_grids = build_all_layer_point_grids( | |
| points_per_side, 0, 1) | |
| if sam_prompt_generator is not None: | |
| self.sam_prompt_generator = MODELS.build(sam_prompt_generator) | |
| if head is not None: | |
| self.head = MODELS.build(head) | |
| self.with_clip = with_clip | |
| if global_prompt is not None: | |
| if with_clip: | |
| self.logits_prompt = nn.Sequential( | |
| nn.Linear(1, 8), | |
| nn.ReLU(), | |
| nn.Linear(8, 16) | |
| ) | |
| self.global_prompt = nn.Sequential( | |
| nn.Conv2d(768+16, 256, kernel_size=3, padding=1), | |
| nn.ReLU(), | |
| nn.Conv2d(256, 256, kernel_size=3, padding=1), | |
| nn.ReLU(), | |
| nn.Conv2d(256, 1, kernel_size=3, padding=1), | |
| ) | |
| else: | |
| self.global_prompt = nn.Sequential( | |
| nn.Conv2d(256, 128, kernel_size=3, padding=1), | |
| nn.ReLU(), | |
| nn.Conv2d(128, 1, kernel_size=3, padding=1), | |
| ) | |
| def setup(self, stage: str) -> None: | |
| if self.need_train_names is not None: | |
| self._set_grad(self.need_train_names, noneed_train_names=[]) | |
| def configure_sharded_model(self) -> None: | |
| if self.trainer.strategy.__class__.__name__ == 'FSDPStrategy': | |
| from torch.distributed.fsdp.wrap import wrap | |
| self.sam_prompt_generator = wrap(self.sam_prompt_generator) | |
| self.img_encoder = wrap(self.img_encoder) | |
| self.prompt_encoder_no_mask_embed = wrap(self.prompt_encoder_no_mask_embed) | |
| self.mask_decoder = wrap(self.mask_decoder) | |
| self.prompt_encoder = wrap(self.prompt_encoder) | |
| from torch.distributed.fsdp import CPUOffload | |
| # from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy | |
| # import functools | |
| # strategy = dict( | |
| # type='FSDPStrategy', | |
| # cpu_offload=CPUOffload(offload_params=True), | |
| # auto_wrap_policy=functools.partial( | |
| # size_based_auto_wrap_policy, min_num_params=int(1e8) | |
| # ) | |
| # | |
| # ) | |
| else: | |
| super().configure_sharded_model() | |
| def configure_optimizers(self): | |
| if self.trainer.strategy.__class__.__name__ == 'DeepSpeedStrategy': | |
| import deepspeed | |
| # optimizer = deepspeed.runtime. | |
| optimizer = deepspeed.ops.adam.FusedAdam(self.sam_prompt_generator.parameters(), lr=1e-4) | |
| # optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(self.sam_prompt_generator.parameters(), lr=1e-4) | |
| # optimizer = torch.optim.Adam(self.sam_prompt_generator.parameters(), lr=1e-4) | |
| lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5) | |
| return [optimizer], [lr_scheduler] | |
| else: | |
| return super().configure_optimizers() | |
| def init_weights(self): | |
| import ipdb; ipdb.set_trace() | |
| pass | |
| # def on_fit_start(self) -> None: | |
| # if hasattr(self, 'train_evaluator'): | |
| # self.train_evaluator = self.train_evaluator.to(self.device) | |
| # if hasattr(self, 'val_evaluator'): | |
| # self.val_evaluator = self.val_evaluator.to(self.device) | |
| def train(self, mode=True): | |
| if self.need_train_names is not None: | |
| return self._set_train_module(mode, self.need_train_names) | |
| else: | |
| super().train(mode) | |
| return self | |
| def validation_step(self, batch, batch_idx): | |
| seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0) | |
| if self.only_img_encoder: | |
| masks_pred = self.forward_only_img_encoder(batch) | |
| masks_pred = F.interpolate(masks_pred, size=seg_label.shape[-2:], mode='bilinear', | |
| align_corners=True) | |
| seg_logits = masks_pred > 0 | |
| elif self.only_decoder: | |
| cls_logits, masks, n_iou_preds = self.forward_sam_prompt_generator(batch) # 1x100x2, 1x100x1x256x256, 1x100x1 | |
| masks = masks.squeeze(2) | |
| masks = F.interpolate(masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True) | |
| # cls_logits[..., 1:2] = cls_logits[..., 1:2] * n_iou_preds | |
| seg_logits = self.post_process(cls_logits.detach(), masks.detach()) | |
| seg_logits = seg_logits > self.threshold | |
| else: | |
| cls_logits, pred_masks, n_iou_preds = self.forward_sam_prompt_generator_all( | |
| batch) # 1x100x2, 1x100x1x256x256, 1x100x1 | |
| pred_masks = pred_masks.squeeze(2) | |
| pred_masks = F.interpolate(pred_masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True) | |
| # cls_logits[..., 1:2] = cls_logits[..., 1:2] * n_iou_preds | |
| seg_logits = self.post_process(cls_logits.detach(), pred_masks.detach()) | |
| seg_logits = seg_logits > self.threshold | |
| # import ipdb; ipdb.set_trace() | |
| self.val_evaluator.update(seg_logits, seg_label) | |
| def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any): | |
| cls_logits, n_img_masks = self.forward(batch) | |
| seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0) | |
| seg_label = seg_label.squeeze(1) | |
| masks = F.interpolate(n_img_masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True) | |
| masks = masks.squeeze(1) > 0 | |
| self.evaluator.update(masks, seg_label) | |
| def _seg_data_to_instance_data(self, batch_data_samples: SampleList): | |
| """Perform forward propagation to convert paradigm from MMSegmentation | |
| to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called | |
| normally. Specifically, ``batch_gt_instances`` would be added. | |
| Args: | |
| batch_data_samples (List[:obj:`SegDataSample`]): The Data | |
| Samples. It usually includes information such as | |
| `gt_sem_seg`. | |
| Returns: | |
| tuple[Tensor]: A tuple contains two lists. | |
| - batch_gt_instances (list[:obj:`InstanceData`]): Batch of | |
| gt_instance. It usually includes ``labels``, each is | |
| unique ground truth label id of images, with | |
| shape (num_gt, ) and ``masks``, each is ground truth | |
| masks of each instances of a image, shape (num_gt, h, w). | |
| - batch_img_metas (list[dict]): List of image meta information. | |
| """ | |
| batch_img_metas = [] | |
| batch_gt_instances = [] | |
| for data_sample in batch_data_samples: | |
| batch_img_metas.append(data_sample.metainfo) | |
| gt_masks = data_sample.instances_data.long() | |
| gt_labels = data_sample.instances_label.long() | |
| instance_data = InstanceData(labels=gt_labels, masks=gt_masks) | |
| batch_gt_instances.append(instance_data) | |
| return batch_gt_instances, batch_img_metas | |
    def training_step(self, batch, batch_idx):
        """Run one optimization step.

        Dispatches on the configured mode (image-encoder-only, decoder-only,
        or full prompt-generator pipeline), updates the train evaluator with
        thresholded predictions, and returns the parsed loss dict for
        Lightning's logging ('loss' key carries the backprop target).
        """
        if self.only_img_encoder:
            # Global-prompt head on (precomputed) image embeddings.
            masks_pred = self.forward_only_img_encoder(batch)
            seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0)
            masks_pred = F.interpolate(masks_pred, size=seg_label.shape[-2:], mode='bilinear', align_corners=True)
            losses = self.head.loss(masks_pred, seg_label)
            masks_pred_result = masks_pred > 0
            self.train_evaluator.update(masks_pred_result.detach(), seg_label.detach())
        elif self.only_decoder:
            # shapes: 1x100x2, 1x100x1x256x256, 1x100x1
            cls_logits, masks, n_iou_preds = self.forward_sam_prompt_generator(batch)
            masks = masks.squeeze(2)
            seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0)
            masks = F.interpolate(masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True)
            # Metrics use detached copies so evaluation never feeds gradients.
            seg_logits = self.post_process(cls_logits.clone().detach(), masks.clone().detach())
            seg_logits = seg_logits > self.threshold
            self.train_evaluator.update(seg_logits, seg_label)
            batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
                batch['data_samples'])
            losses = self.head.loss(cls_logits, masks, batch_gt_instances, batch_img_metas)
        else:
            # Full pipeline: raw images -> SAM encoder -> prompt generator.
            cls_logits, pred_masks, n_iou_preds = self.forward_sam_prompt_generator_all(
                batch)  # 1x100x2, 1x100x1x256x256, 1x100x1
            pred_masks = pred_masks.squeeze(2)
            if torch.isinf(pred_masks).any() or torch.isnan(pred_masks).any():
                # Skip numerically-broken batches; a zero loss with
                # requires_grad keeps Lightning's loop alive.
                print('!!!!!!!!!!!!!!!!!!!!loss is nan or inf!!!!!!!!!!!!!!!!!!')
                return torch.tensor(0.0, requires_grad=True, device=self.device)
            seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0)
            pred_masks = F.interpolate(pred_masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True)
            seg_logits = self.post_process(cls_logits.clone().detach(), pred_masks.clone().detach())
            seg_logits = seg_logits > self.threshold
            self.train_evaluator.update(seg_logits, seg_label)
            batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
                batch['data_samples'])
            losses = self.head.loss(cls_logits, pred_masks, batch_gt_instances, batch_img_metas)
        parsed_losses, log_vars = self.parse_losses(losses)
        log_vars = {f'train_{k}': v for k, v in log_vars.items()}
        log_vars['loss'] = parsed_losses
        self.log_dict(log_vars, prog_bar=True)
        return log_vars
    def on_before_optimizer_step(self, optimizer) -> None:
        # Log gradient norms of the prompt generator (the trained module)
        # right before each optimizer step.
        self.log_grad(module=self.sam_prompt_generator)
| def post_process(self, mask_cls_results, mask_pred_results): | |
| cls_score = F.softmax(mask_cls_results, dim=-1)[..., 1:2] | |
| mask_pred = mask_pred_results.sigmoid() | |
| seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred) | |
| return seg_logits | |
| def forward_only_img_encoder(self, batch, *args: Any, **kwargs: Any) -> Any: | |
| if self.with_clip: | |
| clip_dense_embs = torch.stack([x.clip_dense_embs for x in batch['data_samples']], dim=0) | |
| logits_per_images = torch.stack([x.logits_per_image for x in batch['data_samples']], dim=0) | |
| logits_per_images = self.logits_prompt(logits_per_images) # Bx576x16 | |
| clip_dense_embs = torch.cat([clip_dense_embs, logits_per_images], dim=-1) | |
| clip_dense_embs = rearrange(clip_dense_embs, 'b (h w) c -> b c h w', h=int(clip_dense_embs.shape[1]**0.5)) | |
| masks_pred = self.global_prompt(clip_dense_embs) | |
| else: | |
| image_embeddings = torch.stack([x.image_embeddings for x in batch['data_samples']], dim=0) | |
| masks_pred = self.global_prompt(image_embeddings) | |
| return masks_pred | |
    def forward_sam_prompt_generator(self, batch, *args: Any, **kwargs: Any) -> Any:
        """Decoder-only forward pass on precomputed image embeddings.

        Reads per-sample ``image_embeddings`` and backbone ``inner_states``
        from the data samples, generates prompt embeddings with the prompt
        generator, and decodes one mask per prompt with the SAM mask decoder.

        Returns:
            tuple: ``(cls_logits, n_img_masks, n_iou_preds)`` — per-query
            class logits, stacked single-channel low-res masks, and IoU
            predictions.
        """
        inner_states = [x.inner_states for x in batch['data_samples']]
        image_embeddings = torch.stack([x.image_embeddings for x in batch['data_samples']], dim=0)
        # Re-batch the per-sample inner states: one stacked tensor per level.
        inner_states_tmp = []
        for idx in range(len(inner_states[0])):
            inner_states_tmp.append(torch.stack([x[idx] for x in inner_states], dim=0).to(image_embeddings.device))
        point_embs, cls_logits = self.sam_prompt_generator(inner_states_tmp)
        # If point grids were configured, encode grid point prompts instead of
        # the generated embeddings.
        if hasattr(self, 'point_grids'):
            # NOTE(review): `img` is undefined on this path and `self.sam` is
            # only set in only_img_encoder mode (where it is the image
            # encoder) — this branch raises NameError/AttributeError if
            # `points_per_side` was configured. Confirm the intended image
            # tensor and prompt-encoder source before enabling point prompts.
            points_scale = np.array(img.shape[-2:], dtype=np.float32).reshape(1, -1)  # (1, 2) = (h, w)
            points_for_image = self.point_grids[0] * points_scale
            in_points = torch.as_tensor(points_for_image, device=img.device)
            in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
            in_points = rearrange(in_points, 'n c -> n () c')
            in_labels = rearrange(in_labels, 'n -> n ()')
            points = (in_points, in_labels)
            sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(
                points=points,
                boxes=None,
                masks=None,
            )  # 1024x2x256; 1024x256x64x64
        else:
            # Point embeddings shaped B T N C; dense embeddings are the
            # learned "no mask" embedding broadcast to the embedding grid.
            sparse_embeddings = point_embs
            dense_embeddings = self.prompt_encoder.no_mask_embed.weight.view(1, 1, -1, 1, 1).expand(
                sparse_embeddings.shape[0], sparse_embeddings.shape[1], -1,
                self.prompt_encoder.image_embedding_size[0], self.prompt_encoder.image_embedding_size[1]
            )
        n_img_masks = []
        n_iou_preds = []
        # NOTE(review): n_class_aware_probs is never appended to or returned;
        # class_aware_prob below is effectively discarded.
        n_class_aware_probs = []
        for curr_img_embedding, cur_s_emb, cur_d_emb in zip(image_embeddings, sparse_embeddings, dense_embeddings):
            # Decode each image independently; keep only the first mask
            # channel of the decoder's multi-mask output.
            lr_masks, iou_pred, class_aware_prob = self.mask_decoder(
                image_embeddings=curr_img_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=cur_s_emb,
                dense_prompt_embeddings=cur_d_emb
            )
            mask_slice = slice(0, 1)
            masks = lr_masks[:, mask_slice, :, :]
            iou_pred = iou_pred[:, mask_slice]
            class_aware_prob = class_aware_prob[:, mask_slice]
            n_img_masks.append(masks)
            n_iou_preds.append(iou_pred)
        n_img_masks = torch.stack(n_img_masks, dim=0)
        n_iou_preds = torch.stack(n_iou_preds, dim=0)
        return cls_logits, n_img_masks, n_iou_preds
| def forward_sam_prompt_generator_all(self, batch, *args: Any, **kwargs: Any) -> Any: | |
| x = torch.stack(batch['inputs'], dim=0) | |
| # if self.local_rank == 0: | |
| # import pdb; pdb.set_trace() | |
| # self.trainer.strategy.barrier() | |
| x = x[:, [2, 1, 0], :, :] # BGR -> RGB | |
| x = (x - self.img_encoder.pixel_mean) / self.img_encoder.pixel_std | |
| with torch.no_grad(): | |
| image_embeddings, inner_states = self.img_encoder(x) | |
| point_embs, cls_logits = self.sam_prompt_generator(inner_states) | |
| # if has points prompt, then get points embeddings | |
| if hasattr(self, 'point_grids'): | |
| points_scale = np.array(img.shape[-2:], dtype=np.float32).reshape(1, -1) # 2, | |
| points_for_image = self.point_grids[0] * points_scale | |
| in_points = torch.as_tensor(points_for_image, device=img.device) | |
| in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) | |
| in_points = rearrange(in_points, 'n c -> n () c') | |
| in_labels = rearrange(in_labels, 'n -> n ()') | |
| points = (in_points, in_labels) | |
| sparse_embeddings, dense_embeddings = self.sam.prompt_encoder( | |
| points=points, | |
| boxes=None, | |
| masks=None, | |
| ) # 1024x2x256; 1024x256x64x64 | |
| else: | |
| # ponits_embeddings B T N C | |
| sparse_embeddings = point_embs | |
| dense_embeddings = self.prompt_encoder_no_mask_embed(torch.tensor([0], device=self.device)).view(1, 1, -1, 1, 1).expand( | |
| sparse_embeddings.shape[0], sparse_embeddings.shape[1], -1, | |
| image_embeddings.shape[-2], image_embeddings.shape[-1] | |
| ) | |
| n_img_masks = [] | |
| n_iou_preds = [] | |
| n_class_aware_probs = [] | |
| for curr_img_embedding, cur_s_emb, cur_d_emb in zip(image_embeddings, sparse_embeddings, dense_embeddings): | |
| lr_masks, iou_pred, class_aware_prob = self.mask_decoder( | |
| image_embeddings=curr_img_embedding.unsqueeze(0), | |
| image_pe=self.prompt_encoder.get_dense_pe(), | |
| sparse_prompt_embeddings=cur_s_emb, | |
| dense_prompt_embeddings=cur_d_emb | |
| ) | |
| if self.train_head: | |
| masks = lr_masks | |
| iou_pred = iou_pred | |
| else: | |
| mask_slice = slice(0, 1) | |
| masks = lr_masks[:, mask_slice, :, :] | |
| iou_pred = iou_pred[:, mask_slice] | |
| n_img_masks.append(masks) | |
| n_iou_preds.append(iou_pred) | |
| n_img_masks = torch.stack(n_img_masks, dim=0) | |
| n_iou_preds = torch.stack(n_iou_preds, dim=0) | |
| return cls_logits, n_img_masks, n_iou_preds | |
| def vis_inter_states(self, batch, masks, *args: Any, **kwargs: Any): | |
| folder = 'results/tmp' | |
| import cv2 | |
| cv2.imwrite(os.path.join(folder, f'img.png'), batch['inputs'][0].permute((1, 2, 0)).detach().cpu().numpy()) | |
| cv2.imwrite(os.path.join(folder, f'label_mask.png'), seg_label[0][0].detach().cpu().numpy() * 255) | |
| masks = masks > 0 | |
| for idx, mask_pred in enumerate(masks[0]): | |
| cv2.imwrite(os.path.join(folder, f'pred_mask_{idx}.png'), mask_pred[0].detach().cpu().numpy() * 255) | |
| import ipdb; ipdb.set_trace() | |