import copy
from typing import Sequence

import torch
from mmengine.structures import InstanceData, PixelData
from torch import nn
from torch.nn import functional as F

from mmdet.evaluation.functional import INSTANCE_OFFSET
from mmdet.registry import MODELS
from .utils import (is_lower_torch_version, retry_if_cuda_oom,
                    sem_seg_postprocess)


@MODELS.register_module()
class XDecoderUnifiedhead(nn.Module):
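    """Unified decode head of `X-Decoder
    <https://arxiv.org/abs/2212.11270>`_.

    A single head covering semantic/instance/panoptic segmentation,
    referring segmentation, image captioning, referring captioning and
    image-text retrieval, selected via ``task``.

    Args:
        in_channels (int): Number of channels of the features fed to the
            pixel decoder.
        pixel_decoder (dict): Config of the pixel decoder.
        transformer_decoder (dict): Config of the transformer decoder.
        task (str): One of 'semseg', 'instance', 'panoptic', 'ref-seg',
            'ref-caption', 'caption' or 'retrieval'. Defaults to 'semseg'.
        test_cfg (dict, optional): Test-time config such as ``mask_thr``,
            ``max_per_img`` and ``ignore_index``. Defaults to None.
    """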

    def __init__(self,
in_channels: int,
pixel_decoder: nn.Module,
transformer_decoder: nn.Module,
task: str = 'semseg',
test_cfg=None):
super().__init__()
self.task = task
self.test_cfg = test_cfg
pixel_decoder_ = copy.deepcopy(pixel_decoder)
pixel_decoder_.update(in_channels=in_channels)
self.pixel_decoder = MODELS.build(pixel_decoder_)
transformer_decoder_ = copy.deepcopy(transformer_decoder)
transformer_decoder_.update(task=task)
self.predictor = MODELS.build(transformer_decoder_)
self.return_inter_mask = False
if self.task == 'ref-caption':
# ref-caption = ref-seg + caption,
# so we need to return the intermediate mask
self.return_inter_mask = True
self._all_text_prompts = None
self._extra = None
        # TODO: Very tricky, only needed for the retrieval task
self._force_not_use_cache = False

    def pre_process(self, batch_data_samples, device):
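        """Prepare the task-specific ``extra`` inputs for the predictor.

        Text prompts are parsed from the data samples and embedded by the
        language encoder; the result is cached because the prompts rarely
        change across images. For captioning, only the start-of-text token
        is prepared.
        """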
extra = {}
if self.task != 'caption':
            # every task except captioning consumes text prompts
all_text_prompts = []
num_thing_class = 0
for data_samples in batch_data_samples:
if isinstance(data_samples.text, str):
text = data_samples.text.split('.')
elif isinstance(data_samples.text, Sequence):
text = data_samples.text
else:
raise TypeError(
                        'Type of data_sample.text must be sequence or str')
text = list(filter(lambda x: len(x) > 0, text))
all_text_prompts.append(text)
num_thing_class = len(text)
# for panoptic
if 'stuff_text' in data_samples:
if isinstance(data_samples.stuff_text, str):
text = data_samples.stuff_text.split('.')
elif isinstance(data_samples.stuff_text, Sequence):
text = data_samples.stuff_text
else:
                        raise TypeError('Type of data_sample.stuff_text '
'must be sequence or str')
text = list(filter(lambda x: len(x) > 0, text))
all_text_prompts[-1].extend(text)
# TODO: support batch
all_text_prompts = all_text_prompts[0]
            # cache the text-prompt embeddings: only recompute when the
            # prompts change (or when caching is explicitly disabled)
            if all_text_prompts != self._all_text_prompts \
                    or self._force_not_use_cache:
                self._all_text_prompts = all_text_prompts
if self.task in ['semseg', 'instance', 'panoptic']:
self.predictor.lang_encoder.get_mean_embeds(
all_text_prompts + ['background'])
elif self.task == 'ref-seg':
token_info = self.predictor.lang_encoder.get_text_embeds(
all_text_prompts, norm=False)
token_emb = token_info['token_emb']
tokens = token_info['tokens']
query_emb = token_emb[tokens['attention_mask'].bool()]
extra['grounding_tokens'] = query_emb[:, None]
extra['class_emb'] = token_info['class_emb']
elif self.task == 'retrieval':
token_info = self.predictor.lang_encoder.get_text_embeds(
all_text_prompts, norm=True)
extra['class_emb'] = token_info['class_emb']
self._extra = extra
return extra, all_text_prompts, num_thing_class
else:
return self._extra, all_text_prompts, num_thing_class
else:
if not hasattr(self, 'start_token'):
self.start_token = self.predictor.lang_encoder. \
get_sot_token(device=device)
extra['start_token'] = self.start_token
return extra, None, None

    def predict(self, features, batch_data_samples):
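        """Decode pixel features and produce task-specific predictions."""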
# multi scale feature
mask_features, multi_scale_features = self.pixel_decoder(features)
# pre process
extra, all_text_prompts, num_thing_class = self.pre_process(
batch_data_samples, mask_features.device)
# transformer decoder forward
predictions = self.predictor(
multi_scale_features, mask_features, extra=extra)
# post process
return self.post_process(predictions, batch_data_samples,
all_text_prompts, num_thing_class)

    def post_process(self, predictions, batch_data_samples, all_text_prompts,
num_thing_class):
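        """Write the decoder outputs back into ``batch_data_samples``.

        Masks are resized back to the original image shape and converted
        into the result format expected by each task.
        """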
batch_img_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
batch_input_shape = batch_data_samples[0].metainfo['batch_input_shape']
if self.task == 'caption':
for text, data_samples in zip(predictions['pred_caption'],
batch_data_samples):
data_samples.pred_caption = text
if 'pred_instances' in batch_data_samples[0]:
for img_metas, data_samples in zip(batch_img_metas,
batch_data_samples):
original_caption = data_samples.text.split('.')
text_prompts = list(
filter(lambda x: len(x) > 0, original_caption))
height = img_metas['ori_shape'][0]
width = img_metas['ori_shape'][1]
image_size = img_metas['grounding_img_shape'][:2]
                    mask_pred_result = \
                        data_samples.pred_instances.masks.float()
                    mask_cls_result = \
                        data_samples.pred_instances.scores.float()
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
mask_pred_result, image_size, height, width)
pred_instances = retry_if_cuda_oom(
self._instance_inference)(mask_cls_result,
mask_pred_result,
text_prompts)
data_samples.pred_instances = pred_instances
elif self.task in ['semseg', 'instance', 'panoptic']:
mask_pred_results = predictions['pred_masks']
mask_cls_results = predictions['pred_logits']
            # older torch lacks the ``antialias`` argument of F.interpolate
            if is_lower_torch_version():
mask_pred_results = F.interpolate(
mask_pred_results,
size=(batch_input_shape[-2], batch_input_shape[-1]),
mode='bicubic',
align_corners=False)
else:
mask_pred_results = F.interpolate(
mask_pred_results,
size=(batch_input_shape[-2], batch_input_shape[-1]),
mode='bicubic',
align_corners=False,
antialias=True)
            # post-process each image in the batch
for mask_cls_result, \
mask_pred_result, \
img_metas, \
data_samples in zip(
mask_cls_results,
mask_pred_results,
batch_img_metas,
batch_data_samples):
height = img_metas['ori_shape'][0]
width = img_metas['ori_shape'][1]
image_size = img_metas['img_shape'][:2]
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
mask_pred_result, image_size, height, width)
mask_cls_result = mask_cls_result.to(mask_pred_result)
if self.task == 'semseg':
pred_sem_seg = retry_if_cuda_oom(self._semantic_inference)(
mask_cls_result, mask_pred_result, all_text_prompts)
data_samples.pred_sem_seg = pred_sem_seg
elif self.task == 'instance':
pred_instances = retry_if_cuda_oom(
self._instance_inference)(mask_cls_result,
mask_pred_result,
all_text_prompts)
data_samples.pred_instances = pred_instances
elif self.task == 'panoptic':
pred_panoptic_seg = retry_if_cuda_oom(
self._panoptic_inference)(mask_cls_result,
mask_pred_result,
all_text_prompts,
num_thing_class)
data_samples.pred_panoptic_seg = pred_panoptic_seg
elif self.task == 'ref-seg':
mask_pred_results = predictions['pred_masks']
mask_cls_results = predictions['pred_logits']
results_ = zip(mask_pred_results, mask_cls_results,
batch_img_metas, batch_data_samples)
for mask_pred_result, mask_cls_result, \
img_metas, data_samples in results_:
if is_lower_torch_version():
mask_pred_result = F.interpolate(
mask_pred_result[None],
size=(batch_input_shape[-2], batch_input_shape[-1]),
mode='bicubic',
align_corners=False)[0]
else:
mask_pred_result = F.interpolate(
mask_pred_result[None],
size=(batch_input_shape[-2], batch_input_shape[-1]),
mode='bicubic',
align_corners=False,
antialias=True)[0]
if self.return_inter_mask:
mask = mask_pred_result > 0
pred_instances = InstanceData()
pred_instances.masks = mask
pred_instances.scores = mask_cls_result
data_samples.pred_instances = pred_instances
continue
height = img_metas['ori_shape'][0]
width = img_metas['ori_shape'][1]
image_size = img_metas['img_shape'][:2]
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
mask_pred_result, image_size, height, width)
pred_instances = retry_if_cuda_oom(self._instance_inference)(
mask_cls_result, mask_pred_result, all_text_prompts)
data_samples.pred_instances = pred_instances
elif self.task == 'retrieval':
batch_data_samples[0].pred_score = predictions['pred_logits']
return batch_data_samples

    def _instance_inference(self, mask_cls, mask_pred, text_prompts):
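        """Convert mask logits into scored, labelled instance masks."""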
num_class = len(text_prompts)
if self.task in ['ref-seg', 'caption']:
scores = F.softmax(mask_cls, dim=-1)
scores_per_image = scores.max(dim=-1)[0]
labels_per_image = torch.arange(num_class)
else:
scores = F.softmax(mask_cls, dim=-1)[:, :-1]
labels = torch.arange(
num_class,
device=scores.device).unsqueeze(0).repeat(scores.shape[0],
1).flatten(0, 1)
scores_per_image, topk_indices = scores.flatten(0, 1).topk(
self.test_cfg.get('max_per_img', 100), sorted=False)
labels_per_image = labels[topk_indices]
topk_indices = (topk_indices // num_class)
mask_pred = mask_pred[topk_indices]
result = InstanceData()
mask_pred = mask_pred.sigmoid()
result.masks = (mask_pred > self.test_cfg.mask_thr).float()
# calculate average mask prob
mask_scores_per_image = (mask_pred.flatten(1) *
result.masks.flatten(1)).sum(1) / (
result.masks.flatten(1).sum(1) + 1e-6)
result.scores = scores_per_image * mask_scores_per_image
result.labels = labels_per_image
result.label_names = [
text_prompts[label] for label in labels_per_image
]
result.bboxes = result.scores.new_zeros(len(result.scores), 4)
return result

    def _semantic_inference(self, mask_cls, mask_pred, text_prompts):
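        """Combine query masks and class probabilities into a semantic map."""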
mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
mask_pred = mask_pred.sigmoid()
sem_seg = torch.einsum('qc,qhw->chw', mask_cls, mask_pred)
if sem_seg.shape[0] == 1:
# 0 is foreground, ignore_index is background
sem_seg = (sem_seg.squeeze(0) <= self.test_cfg.mask_thr).int()
sem_seg[sem_seg == 1] = self.test_cfg.get('ignore_index', 255)
else:
            # 0 is foreground, ignore_index is background
            # when enabled, pixels where no class passes ``mask_thr``
            # are assigned the ignore index (treated as background)
            if self.test_cfg.use_thr_for_mc:
foreground_flag = sem_seg > self.test_cfg.mask_thr
sem_seg = sem_seg.max(0)[1]
sem_seg[foreground_flag.sum(0) == 0] = self.test_cfg.get(
'ignore_index', 255)
else:
sem_seg = sem_seg.max(0)[1]
pred_sem_seg = PixelData(
sem_seg=sem_seg[None],
metainfo={
'label_names': text_prompts,
'ignore_index': self.test_cfg.get('ignore_index', 255)
})
return pred_sem_seg

    def _panoptic_inference(self, mask_cls, mask_pred, all_text_prompts,
num_thing_class):
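        """Fuse thing and stuff masks into a single panoptic map."""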
scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
mask_pred = mask_pred.sigmoid()
keep = labels.ne(len(all_text_prompts)) & (
scores > self.test_cfg.mask_thr)
cur_scores = scores[keep]
cur_classes = labels[keep]
cur_masks = mask_pred[keep]
cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
h, w = cur_masks.shape[-2:]
panoptic_seg = torch.full((h, w),
self.test_cfg.get('ignore_index', 255),
dtype=torch.int32,
device=cur_masks.device)
instance_id = 1
if cur_masks.shape[0] > 0:
cur_mask_ids = cur_prob_masks.argmax(0)
for k in range(cur_classes.shape[0]):
pred_class = cur_classes[k].item()
isthing = int(pred_class) < num_thing_class
mask_area = (cur_mask_ids == k).sum().item()
original_area = (cur_masks[k] >= 0.5).sum().item()
mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
                if (mask_area > 0 and original_area > 0
                        and mask.sum().item() > 0):
if mask_area / original_area < self.test_cfg.overlap_thr:
continue
# merge stuff regions
if not isthing:
panoptic_seg[mask] = int(pred_class)
else:
panoptic_seg[mask] = int(
pred_class) + instance_id * INSTANCE_OFFSET
instance_id += 1
panoptic_seg = PixelData(
sem_seg=panoptic_seg[None],
metainfo={
'label_names': all_text_prompts,
'ignore_index': self.test_cfg.get('ignore_index', 255)
})
return panoptic_seg
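

# Minimal usage sketch. Assumption: the decoder configs and feature tensors
# below are hypothetical placeholders for illustration; the real X-Decoder
# configs live in the MMDetection project configs.
#
#   from mmengine.config import ConfigDict
#   head = XDecoderUnifiedhead(
#       in_channels=512,
#       pixel_decoder=dict(type='MyPixelDecoder'),
#       transformer_decoder=dict(type='MyTransformerDecoder'),
#       task='semseg',
#       test_cfg=ConfigDict(mask_thr=0.5, use_thr_for_mc=False))
#   batch_data_samples = head.predict(backbone_features, batch_data_samples)
#
# Each data sample must carry ``text`` (a '.'-separated string or a sequence
# of class names) and, for the panoptic task, ``stuff_text``.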