| | |
| | import copy |
| | import os.path as osp |
| | from collections import defaultdict |
| | from typing import Any, Dict, List |
| |
|
| | import numpy as np |
| | from mmengine.dataset import BaseDataset |
| | from mmengine.utils import check_file_exist |
| |
|
| | from mmdet.registry import DATASETS |
| |
|
| |
|
@DATASETS.register_module()
class ReIDDataset(BaseDataset):
    """Dataset for ReID.

    Args:
        triplet_sampler (dict, optional): Keyword arguments forwarded to
            :meth:`triplet_sampling` when sampling for hard mining triplet
            loss. Defaults to None, in which case ``prepare_data`` feeds a
            single sample to the pipeline. Supported keys:

            - num_ids (int): The number of person ids.
            - ins_per_id (int): The number of images for each person.
    """

    def __init__(self, triplet_sampler: dict = None, *args, **kwargs):
        self.triplet_sampler = triplet_sampler
        super().__init__(*args, **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``.

        Every non-blank line must hold exactly two whitespace-separated
        fields: the image file name and the integer person id. Blank lines
        are skipped; a line with any other number of fields raises
        ``ValueError``.

        Returns:
            list[dict]: A list of annotation. Each item has the keys
            ``img_prefix``, ``img_path`` and ``gt_label``.
        """
        assert isinstance(self.ann_file, str)
        check_file_exist(self.ann_file)
        with open(self.ann_file) as f:
            # ``split()`` without arguments tolerates tabs, repeated spaces
            # and trailing whitespace; blank lines are dropped so that they
            # do not break the two-field unpacking below.
            samples = [line.split() for line in f if line.strip()]

        data_list = []
        img_root = self.data_prefix['img_path']
        for filename, gt_label in samples:
            info = dict(img_prefix=self.data_prefix)
            if img_root is not None:
                info['img_path'] = osp.join(img_root, filename)
            else:
                info['img_path'] = filename
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_list.append(info)
        # Build the person id -> sample indices lookup used by the sampler.
        self._parse_ann_info(data_list)
        return data_list

    def _parse_ann_info(self, data_list: List[dict]):
        """Parse person id annotations.

        Populates ``self.index_dic`` (person id -> int64 array of indices
        into ``data_list``) and ``self.pids`` (int64 array of all person
        ids, in order of first appearance).

        Args:
            data_list (list[dict]): Annotations holding a ``gt_label`` entry.
        """
        index_tmp_dic = defaultdict(list)
        for idx, info in enumerate(data_list):
            index_tmp_dic[int(info['gt_label'])].append(idx)

        self.index_dic = dict()
        for pid, idxs in index_tmp_dic.items():
            self.index_dic[pid] = np.asarray(idxs, dtype=np.int64)
        self.pids = np.asarray(list(self.index_dic.keys()), dtype=np.int64)

    def prepare_data(self, idx: int) -> Any:
        """Get data processed by ``self.pipeline``.

        When ``self.triplet_sampler`` is set, the sample at ``idx`` only
        provides the anchor person id and a whole triplet batch is passed
        to the pipeline instead of the single sample.

        Args:
            idx (int): The index of ``data_info``

        Returns:
            Any: Depends on ``self.pipeline``
        """
        data_info = self.get_data_info(idx)
        if self.triplet_sampler is not None:
            img_info = self.triplet_sampling(data_info['gt_label'],
                                             **self.triplet_sampler)
            data_info = copy.deepcopy(img_info)
        else:
            data_info = copy.deepcopy(data_info)
        return self.pipeline(data_info)

    def triplet_sampling(self,
                         pos_pid,
                         num_ids: int = 8,
                         ins_per_id: int = 4) -> Dict:
        """Triplet sampler for hard mining triplet loss.

        First, for the anchor ``pos_pid``, randomly sample ``ins_per_id``
        images with the same person id (with replacement). Then, randomly
        sample ``num_ids - 1`` distinct negative person ids. Finally,
        randomly sample ``ins_per_id`` images for each negative id (with
        replacement).

        Args:
            pos_pid (ndarray): The person id of the anchor.
            num_ids (int): The number of person ids.
            ins_per_id (int): The number of images for each person.

        Returns:
            Dict: Annotation information of num_ids X ins_per_id images.
            Each key of a single annotation maps to the list of values
            collected over all sampled images.
        """
        assert len(self.pids) >= num_ids, \
            'The number of person ids in the training set must ' \
            'be greater than or equal to the number of person ids in the ' \
            'sample.'

        anchor_pid = int(pos_pid)
        pos_idxs = self.index_dic[anchor_pid]
        idxs_list = []
        # Positive samples: images sharing the anchor's person id.
        idxs_list.extend(pos_idxs[np.random.choice(
            pos_idxs.shape[0], ins_per_id, replace=True)])

        # Draw negatives from the actual person ids, not from their
        # positions inside ``self.pids``; the two only coincide when the
        # ids happen to be exactly 0..N-1.
        candidate_pids = [
            int(pid) for pid in self.pids if int(pid) != anchor_pid
        ]
        neg_pids = np.random.choice(
            candidate_pids, num_ids - 1, replace=False)

        for neg_pid in neg_pids:
            # ``index_dic`` is keyed by plain ``int``.
            neg_idxs = self.index_dic[int(neg_pid)]
            idxs_list.extend(neg_idxs[np.random.choice(
                neg_idxs.shape[0], ins_per_id, replace=True)])

        triplet_img_infos = [
            copy.deepcopy(self.get_data_info(idx)) for idx in idxs_list
        ]

        # Collate the per-image dicts into one dict of lists.
        out = dict()
        for key in triplet_img_infos[0].keys():
            out[key] = [_info[key] for _info in triplet_img_infos]
        return out
| |
|