|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
class CocoFilter():
    """ Filters a COCO instances json file down to a chosen set of categories.

    Usage: ``CocoFilter().main(args)`` where ``args`` provides ``input_json``,
    ``output_json`` and ``categories`` (a list of category names).

    The output json:
      * keeps 'info' and 'licenses' unchanged,
      * keeps only the requested categories, renumbered consecutively from 1
        in the order they appear in the input,
      * keeps only annotations of those categories (with remapped
        'category_id'),
      * keeps only images that have at least one kept annotation, in the
        order they appear in the input.
    """

    @staticmethod
    def _quit_early(exit_code=0):
        """ Print a notice and terminate the program.

        Raises SystemExit directly instead of calling the ``quit()`` builtin,
        which is installed by the ``site`` module for interactive sessions and
        is not guaranteed to exist (e.g. under ``python -S``).

        :param exit_code: process exit status (0 = user chose to stop,
            non-zero = error).
        """
        print('Quitting early.')
        raise SystemExit(exit_code)

    @staticmethod
    def _confirm(prompt):
        """ Ask a yes/no question on stdin; return True for 'y' or 'yes'. """
        answer = input(prompt).strip().lower()
        return answer in ('y', 'yes')

    def _process_info(self):
        """ Copy the input 'info' section unchanged. """
        self.info = self.coco['info']

    def _process_licenses(self):
        """ Copy the input 'licenses' section unchanged. """
        self.licenses = self.coco['licenses']

    def _process_categories(self):
        """ Index the input categories.

        Builds:
          self.categories       -- {category id: category dict}; for a
                                   duplicated id the first occurrence wins
          self.super_categories -- {supercategory name: set of category ids}
          self.category_set     -- set of (non-duplicate) category names
        """
        self.categories = dict()
        self.super_categories = dict()
        self.category_set = set()

        for category in self.coco['categories']:
            cat_id = category['id']
            super_category = category['supercategory']

            if cat_id not in self.categories:
                self.categories[cat_id] = category
                self.category_set.add(category['name'])
            else:
                print(f'ERROR: Skipping duplicate category id: {category}')

            self.super_categories.setdefault(super_category, set()).add(cat_id)

    def _process_images(self):
        """ Index the input images as {image id: image dict}.

        For a duplicated id the first occurrence wins; insertion order follows
        the input file.
        """
        self.images = dict()
        for image in self.coco['images']:
            image_id = image['id']
            if image_id not in self.images:
                self.images[image_id] = image
            else:
                print(f'ERROR: Skipping duplicate image id: {image}')

    def _process_segmentations(self):
        """ Group the input annotations by image as {image id: [annotation, ...]}. """
        self.segmentations = dict()
        for segmentation in self.coco['annotations']:
            self.segmentations.setdefault(segmentation['image_id'], []).append(segmentation)

    def _filter_categories(self):
        """ Select the categories named in self.filter_categories.

        Builds:
          self.new_category_map -- {original category id: new id}; new ids are
                                   assigned consecutively from 1 in the order
                                   the categories appear in the input
          self.new_categories   -- copies of the kept category dicts with
                                   their 'id' replaced by the new id

        If some requested names are not present in the input, the user is
        asked whether to continue; declining exits the program.
        """
        wanted = set(self.filter_categories)

        missing_categories = wanted - self.category_set
        if missing_categories:
            print(f'Did not find categories: {missing_categories}')
            if not self._confirm('Continue? (y/n) '):
                self._quit_early()

        self.new_category_map = dict()
        new_id = 1
        for key, item in self.categories.items():
            if item['name'] in wanted:
                self.new_category_map[key] = new_id
                new_id += 1

        self.new_categories = []
        for original_cat_id, new_id in self.new_category_map.items():
            # Copy so the indexed input category is not mutated.
            new_category = dict(self.categories[original_cat_id])
            new_category['id'] = new_id
            self.new_categories.append(new_category)

    def _filter_annotations(self):
        """ Keep annotations whose category was selected, remapping category ids.

        Builds:
          self.new_segmentations -- copies of the kept annotations with
                                    'category_id' replaced by the new id
          self.new_image_ids     -- ids of images with at least one kept
                                    annotation
        """
        self.new_segmentations = []
        self.new_image_ids = set()
        for image_id, segmentation_list in self.segmentations.items():
            for segmentation in segmentation_list:
                original_seg_cat = segmentation['category_id']
                if original_seg_cat in self.new_category_map:
                    new_segmentation = dict(segmentation)
                    new_segmentation['category_id'] = self.new_category_map[original_seg_cat]
                    self.new_segmentations.append(new_segmentation)
                    self.new_image_ids.add(image_id)

    def _filter_images(self):
        """ Collect the images referenced by the kept annotations.

        Images are emitted in the order they appear in the input json rather
        than in set-iteration order, so the output is stable and diffable.
        Annotation image ids that have no matching image entry are reported.
        """
        self.new_images = [image for image_id, image in self.images.items()
                           if image_id in self.new_image_ids]

        unknown_ids = self.new_image_ids - self.images.keys()
        if unknown_ids:
            print(f'ERROR: Annotations reference unknown image ids: {unknown_ids}')

    def main(self, args):
        """ Run the filter end to end: load, index, filter and save.

        :param args: object with attributes ``input_json`` (path to a COCO
            json file), ``output_json`` (path to write) and ``categories``
            (list of category names to keep), e.g. an argparse Namespace.

        Exits via SystemExit with status 1 if the input file does not exist,
        and with status 0 if the user declines an interactive prompt.
        """
        self.input_json_path = Path(args.input_json)
        self.output_json_path = Path(args.output_json)
        self.filter_categories = args.categories

        if not self.input_json_path.exists():
            print('Input json path not found.')
            self._quit_early(1)

        if self.output_json_path.exists():
            if not self._confirm('Output path already exists. Overwrite? (y/n) '):
                self._quit_early()

        print('Loading json file...')
        # Read as UTF-8 explicitly rather than relying on the platform default
        # encoding, which breaks non-ASCII names on e.g. Windows.
        with open(self.input_json_path, encoding='utf-8') as json_file:
            self.coco = json.load(json_file)

        print('Processing input json...')
        self._process_info()
        self._process_licenses()
        self._process_categories()
        self._process_images()
        self._process_segmentations()

        print('Filtering...')
        self._filter_categories()
        self._filter_annotations()
        self._filter_images()

        new_master_json = {
            'info': self.info,
            'licenses': self.licenses,
            'images': self.new_images,
            'annotations': self.new_segmentations,
            'categories': self.new_categories
        }

        print('Saving new json file...')
        with open(self.output_json_path, 'w', encoding='utf-8') as output_file:
            json.dump(new_master_json, output_file)

        print('Filtered json saved.')
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Filter COCO JSON: "
                                     "Filters a COCO Instances JSON file to only include specified categories. "
                                     "This includes images, and annotations. Does not modify 'info' or 'licenses'.")

    # All three options are mandatory: without them CocoFilter.main fails with
    # an opaque TypeError (Path(None) / set(None)) instead of a usage message.
    parser.add_argument("-i", "--input_json", dest="input_json", required=True,
                        help="path to a json file in coco format")
    parser.add_argument("-o", "--output_json", dest="output_json", required=True,
                        help="path to save the output json")
    parser.add_argument("-c", "--categories", nargs='+', dest="categories", required=True,
                        help="List of category names separated by spaces, e.g. -c person dog bicycle")

    args = parser.parse_args()

    cf = CocoFilter()
    cf.main(args)
|
|
|