| | |
| | import os.path as osp |
| | from typing import List, Union |
| |
|
| | from mmdet.registry import DATASETS |
| | from .base_video_dataset import BaseVideoDataset |
| |
|
| |
|
@DATASETS.register_module()
class MOTChallengeDataset(BaseVideoDataset):
    """Dataset for MOTChallenge.

    Args:
        visibility_thr (float, optional): The minimum visibility
            for the objects during training. Defaults to -1.
    """

    # Category names exposed as dataset meta information.
    METAINFO = {
        'classes':
        ('pedestrian', 'person_on_vehicle', 'car', 'bicycle', 'motorbike',
         'non_mot_vehicle', 'static_person', 'distractor', 'occluder',
         'occluder_on_ground', 'occluder_full', 'reflection', 'crowd')
    }

    def __init__(self, visibility_thr: float = -1, *args, **kwargs):
        # Assigned before ``super().__init__``; presumably because the base
        # constructor may already call ``parse_data_info`` (which reads this
        # attribute) -- confirm against ``BaseVideoDataset``.
        self.visibility_thr = visibility_thr
        super().__init__(*args, **kwargs)

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format. The difference between this
        function and the one in ``BaseVideoDataset`` is that the parsing here
        adds ``visibility`` and ``mot_conf``.

        Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``. Must contain ``raw_img_info`` (a dict with at
                least ``file_name``, ``width`` and ``height``) and
                ``raw_ann_info`` (a list of per-object annotation dicts).

        Returns:
            Union[dict, List[dict]]: Parsed annotation. A copy of the image
            info with ``img_path`` and ``instances`` added.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']
        data_info = dict(img_info)

        img_prefix = self.data_prefix.get('img_path', None)
        if img_prefix is not None:
            img_path = osp.join(img_prefix, img_info['file_name'])
        else:
            img_path = img_info['file_name']
        data_info['img_path'] = img_path

        instances = []
        for ann in ann_info:
            instance = self._parse_instance(ann, img_info)
            if instance is not None:
                instances.append(instance)

        # In training mode an image without any surviving instance is
        # treated as an error.
        if not self.test_mode:
            assert len(instances) > 0, f'No valid instances found in ' \
                f'image {data_info["img_path"]}!'
        data_info['instances'] = instances
        return data_info

    def _parse_instance(self, ann: dict,
                        img_info: dict) -> Union[dict, None]:
        """Filter one raw annotation and convert it into an instance dict.

        Args:
            ann (dict): A single raw annotation. Reads ``bbox`` (as
                ``[x, y, w, h]``), ``area``, ``category_id``,
                ``instance_id``, ``mot_conf`` and ``visibility``, and
                optionally ``ignore`` and ``iscrowd``.
            img_info (dict): Image info providing ``width`` and ``height``.

        Returns:
            dict or None: The parsed instance, or ``None`` when the
            annotation is filtered out.
        """
        # Visibility filtering is only applied outside test mode.
        if not self.test_mode and ann['visibility'] < self.visibility_thr:
            return None
        if ann.get('ignore', False):
            return None

        x1, y1, w, h = ann['bbox']
        # Drop boxes whose overlap with the image has zero area.
        inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
        inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
        if inter_w * inter_h == 0:
            return None
        # Drop degenerate boxes: non-positive area or a side below 1 pixel.
        if ann['area'] <= 0 or w < 1 or h < 1:
            return None
        if ann['category_id'] not in self.cat_ids:
            return None

        return {
            # Crowd annotations are kept, marked with ignore_flag=1.
            'ignore_flag': 1 if ann.get('iscrowd', False) else 0,
            # Convert ``[x, y, w, h]`` into ``[x1, y1, x2, y2]``.
            'bbox': [x1, y1, x1 + w, y1 + h],
            'bbox_label': self.cat2label[ann['category_id']],
            'instance_id': ann['instance_id'],
            'category_id': ann['category_id'],
            'mot_conf': ann['mot_conf'],
            'visibility': ann['visibility'],
        }
| |
|