| | import os.path as osp |
| | import xml.etree.ElementTree as ET |
| |
|
| | import mmcv |
| | import numpy as np |
| | from PIL import Image |
| |
|
| | from .builder import DATASETS |
| | from .custom import CustomDataset |
| |
|
| |
|
@DATASETS.register_module()
class XMLDataset(CustomDataset):
    """XML dataset for detection.

    Images and their per-image XML annotation files are looked up under
    ``img_prefix`` in the ``img_subdir`` and ``ann_subdir`` subdirectories
    respectively.

    Args:
        min_size (int | float, optional): The minimum size of bounding
            boxes in the images. If the width or height of a bounding box
            is less than ``min_size``, it is added to the ignored field.
        img_subdir (str): Subdirectory of ``img_prefix`` that contains the
            images. Default: 'JPEGImages'.
        ann_subdir (str): Subdirectory of ``img_prefix`` that contains the
            XML annotation files. Default: 'Annotations'.
        **kwargs: Forwarded unchanged to ``CustomDataset.__init__``.
    """

    def __init__(self,
                 min_size=None,
                 img_subdir='JPEGImages',
                 ann_subdir='Annotations',
                 **kwargs):
        assert self.CLASSES or kwargs.get(
            'classes', None), 'CLASSES in `XMLDataset` can not be None.'
        # These must exist before the base initializer runs, because it may
        # call ``load_annotations`` / ``_filter_imgs``, which read them.
        self.img_subdir = img_subdir
        self.ann_subdir = ann_subdir
        super(XMLDataset, self).__init__(**kwargs)
        self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
        self.min_size = min_size

    def _parse_xml_root(self, img_id):
        """Parse the annotation XML of one image.

        Args:
            img_id (str): Image identifier (file name without extension).

        Returns:
            xml.etree.ElementTree.Element: Root element of the XML tree.
        """
        xml_path = osp.join(self.img_prefix, self.ann_subdir,
                            f'{img_id}.xml')
        return ET.parse(xml_path).getroot()

    def load_annotations(self, ann_file):
        """Load annotation from XML style ann_file.

        Args:
            ann_file (str): Path of a text file listing one image id per
                line.

        Returns:
            list[dict]: One dict per image with the keys ``id``,
                ``filename``, ``width`` and ``height``.
        """
        data_infos = []
        img_ids = mmcv.list_from_file(ann_file)
        for img_id in img_ids:
            # Keep the forward slash so the stored relative filename is the
            # same on every platform.
            filename = f'{self.img_subdir}/{img_id}.jpg'
            root = self._parse_xml_root(img_id)
            size = root.find('size')
            if size is not None:
                width = int(size.find('width').text)
                height = int(size.find('height').text)
            else:
                # No <size> element: fall back to reading the image itself.
                # The context manager closes the file handle right away,
                # which matters when iterating over many images.
                img_path = osp.join(self.img_prefix, self.img_subdir,
                                    f'{img_id}.jpg')
                with Image.open(img_path) as img:
                    width, height = img.size
            data_infos.append(
                dict(id=img_id, filename=filename, width=width,
                     height=height))

        return data_infos

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without annotation.

        Args:
            min_size (int): Images whose shorter side is below this value
                are dropped. Default: 32.

        Returns:
            list[int]: Indices into ``self.data_infos`` of the kept images.
        """
        valid_inds = []
        for i, img_info in enumerate(self.data_infos):
            if min(img_info['width'], img_info['height']) < min_size:
                continue
            if not self.filter_empty_gt:
                valid_inds.append(i)
                continue
            # Keep the image only if at least one annotated object belongs
            # to a known class; ``any`` stops at the first match.
            root = self._parse_xml_root(img_info['id'])
            if any(
                    obj.find('name').text in self.CLASSES
                    for obj in root.findall('object')):
                valid_inds.append(i)
        return valid_inds

    def get_ann_info(self, idx):
        """Get annotation from XML file by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index with the keys
                ``bboxes`` (float32, shape (n, 4)), ``labels`` (int64,
                shape (n, )), ``bboxes_ignore`` (float32, shape (k, 4))
                and ``labels_ignore`` (int64, shape (k, )).
        """
        root = self._parse_xml_root(self.data_infos[idx]['id'])
        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for obj in root.findall('object'):
            name = obj.find('name').text
            if name not in self.CLASSES:
                continue
            label = self.cat2label[name]
            difficult = obj.find('difficult')
            difficult = 0 if difficult is None else int(difficult.text)
            bnd_box = obj.find('bndbox')
            # Coordinates may be written as floats in the XML, so go
            # through ``float`` before truncating to ``int``.
            bbox = [
                int(float(bnd_box.find('xmin').text)),
                int(float(bnd_box.find('ymin').text)),
                int(float(bnd_box.find('xmax').text)),
                int(float(bnd_box.find('ymax').text))
            ]
            ignore = False
            if self.min_size:
                # Size-based ignoring is a training-time option only.
                assert not self.test_mode
                w = bbox[2] - bbox[0]
                h = bbox[3] - bbox[1]
                if w < self.min_size or h < self.min_size:
                    ignore = True
            if difficult or ignore:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
        # Every coordinate is shifted by -1 below (presumably converting
        # 1-based pixel indices in the XML to 0-based ones).
        if not bboxes:
            bboxes = np.zeros((0, 4))
            labels = np.zeros((0, ))
        else:
            bboxes = np.array(bboxes, ndmin=2) - 1
            labels = np.array(labels)
        if not bboxes_ignore:
            bboxes_ignore = np.zeros((0, 4))
            labels_ignore = np.zeros((0, ))
        else:
            bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
            labels_ignore = np.array(labels_ignore)
        ann = dict(
            bboxes=bboxes.astype(np.float32),
            labels=labels.astype(np.int64),
            bboxes_ignore=bboxes_ignore.astype(np.float32),
            labels_ignore=labels_ignore.astype(np.int64))
        return ann

    def get_cat_ids(self, idx):
        """Get category ids in XML file by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index, one
                entry per annotated object of a known class (duplicates
                are kept).
        """
        root = self._parse_xml_root(self.data_infos[idx]['id'])
        names = (obj.find('name').text for obj in root.findall('object'))
        return [self.cat2label[name] for name in names
                if name in self.CLASSES]
| |
|