File size: 6,483 Bytes
d396400 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import json
from pathlib import Path
class CocoFilter():
    """Filter a COCO-format JSON file down to a chosen set of category names.

    ``main`` loads the input file, indexes its sections, keeps only the
    categories named in ``args.categories`` (renumbering them from 1), keeps
    the annotations that use those categories and the images those
    annotations point at, and writes the result to the output path.
    The 'info' and 'licenses' sections are copied through unchanged.
    """

    @staticmethod
    def _confirm(prompt):
        """Ask a yes/no question on stdin.

        Returns True only when the (case-insensitive, whitespace-stripped)
        answer is 'y' or 'yes'.
        """
        answer = input(prompt).strip().lower()
        return answer in ('y', 'yes')

    def _process_info(self):
        """Store the 'info' section; an absent section becomes an empty dict."""
        self.info = self.coco.get('info', {})

    def _process_licenses(self):
        """Store the 'licenses' section; an absent section becomes an empty list."""
        self.licenses = self.coco.get('licenses', [])

    def _process_categories(self):
        """Index categories by id and group category ids by supercategory.

        Sets ``self.categories`` (id -> category dict), ``self.category_set``
        (all category names) and ``self.super_categories``
        (supercategory -> set of category ids). A repeated category id is
        reported and the first occurrence is kept.
        """
        self.categories = {}
        self.super_categories = {}
        self.category_set = set()

        for category in self.coco['categories']:
            cat_id = category['id']

            if cat_id not in self.categories:
                self.categories[cat_id] = category
                self.category_set.add(category['name'])
            else:
                print(f'ERROR: Skipping duplicate category id: {category}')

            # Tolerate categories without a 'supercategory' key: they are
            # simply left out of the grouping instead of raising KeyError.
            super_category = category.get('supercategory')
            if super_category is not None:
                self.super_categories.setdefault(super_category, set()).add(cat_id)

    def _process_images(self):
        """Index images by id in ``self.images``; the first of any duplicate id wins."""
        self.images = {}
        for image in self.coco['images']:
            image_id = image['id']
            if image_id not in self.images:
                self.images[image_id] = image
            else:
                print(f'ERROR: Skipping duplicate image id: {image}')

    def _process_segmentations(self):
        """Group annotations by their 'image_id' in ``self.segmentations``."""
        self.segmentations = {}
        for segmentation in self.coco['annotations']:
            self.segmentations.setdefault(segmentation['image_id'], []).append(segmentation)

    def _filter_categories(self):
        """Select the requested categories and renumber them from 1.

        Sets ``self.new_category_map`` (original category id -> new id) and
        ``self.new_categories`` (copies of the kept category dicts carrying
        the new id). If any requested name is not present in the file, the
        user is asked whether to continue; declining raises SystemExit.
        """
        # Build the set once so each membership test below is O(1).
        wanted = set(self.filter_categories)

        missing_categories = wanted - self.category_set
        if missing_categories:
            print(f'Did not find categories: {missing_categories}')
            if not self._confirm('Continue? (y/n) '):
                print('Quitting early.')
                raise SystemExit()

        kept_ids = [
            cat_id
            for cat_id, category in self.categories.items()
            if category['name'] in wanted
        ]
        self.new_category_map = {
            cat_id: new_id for new_id, cat_id in enumerate(kept_ids, start=1)
        }

        # Copy each category so the originals in self.categories keep their ids.
        self.new_categories = [
            {**self.categories[cat_id], 'id': new_id}
            for cat_id, new_id in self.new_category_map.items()
        ]

    def _filter_annotations(self):
        """Keep annotations whose category survived the filter.

        Sets ``self.new_segmentations`` (copies of the kept annotations with
        'category_id' rewritten to the new id) and ``self.new_image_ids``
        (ids of the images those annotations reference).
        """
        self.new_segmentations = []
        self.new_image_ids = set()

        for image_id, segmentation_list in self.segmentations.items():
            for segmentation in segmentation_list:
                # New ids start at 1, so None unambiguously means "not kept".
                new_cat_id = self.new_category_map.get(segmentation['category_id'])
                if new_cat_id is None:
                    continue
                self.new_segmentations.append({**segmentation, 'category_id': new_cat_id})
                self.new_image_ids.add(image_id)

    def _filter_images(self):
        """Collect the images referenced by the kept annotations.

        Sets ``self.new_images``, in the same order the images appear in the
        input. Image ids that annotations reference but that have no entry
        in ``self.images`` are reported and skipped rather than raising
        KeyError.
        """
        unknown_ids = self.new_image_ids.difference(self.images)
        if unknown_ids:
            print(f'WARNING: Annotations reference unknown image ids: {unknown_ids}')

        # Walk self.images (insertion-ordered) instead of the id set so the
        # output order follows the input file.
        self.new_images = [
            image
            for image_id, image in self.images.items()
            if image_id in self.new_image_ids
        ]

    def main(self, args):
        """Run the full filter pipeline.

        Args:
            args: object with ``input_json`` (path of the COCO file to read),
                ``output_json`` (path to write) and ``categories`` (iterable
                of category names to keep).

        Raises:
            SystemExit: with status 1 when the input path does not exist;
                with the default status when the user declines a prompt.
        """
        self.input_json_path = Path(args.input_json)
        self.output_json_path = Path(args.output_json)
        self.filter_categories = args.categories

        # Verify input path exists
        if not self.input_json_path.exists():
            print('Input json path not found.')
            print('Quitting early.')
            # Non-zero status so calling scripts can detect the failure.
            raise SystemExit(1)

        # Verify output path does not already exist
        if self.output_json_path.exists():
            if not self._confirm('Output path already exists. Overwrite? (y/n) '):
                print('Quitting early.')
                raise SystemExit()

        # Load the json (explicit UTF-8: do not depend on the locale default)
        print('Loading json file...')
        with open(self.input_json_path, encoding='utf-8') as json_file:
            self.coco = json.load(json_file)

        # Process the json
        print('Processing input json...')
        self._process_info()
        self._process_licenses()
        self._process_categories()
        self._process_images()
        self._process_segmentations()

        # Filter to specific categories
        print('Filtering...')
        self._filter_categories()
        self._filter_annotations()
        self._filter_images()

        # Build new JSON
        new_master_json = {
            'info': self.info,
            'licenses': self.licenses,
            'images': self.new_images,
            'annotations': self.new_segmentations,
            'categories': self.new_categories,
        }

        # Write the JSON to a file
        print('Saving new json file...')
        with open(self.output_json_path, 'w', encoding='utf-8') as output_file:
            json.dump(new_master_json, output_file)

        print('Filtered json saved.')
# Command-line entry point: parse the paths and category names, then run the filter.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Filter COCO JSON: "
    "Filters a COCO Instances JSON file to only include specified categories. "
    "This includes images, and annotations. Does not modify 'info' or 'licenses'.")

    # All three options are mandatory: CocoFilter.main builds Path objects from
    # the two paths and a set from the category list, so a missing option would
    # otherwise surface as a TypeError instead of a usage message.
    parser.add_argument("-i", "--input_json", dest="input_json", required=True,
                        help="path to a json file in coco format")
    parser.add_argument("-o", "--output_json", dest="output_json", required=True,
                        help="path to save the output json")
    parser.add_argument("-c", "--categories", nargs='+', dest="categories", required=True,
                        help="List of category names separated by spaces, e.g. -c person dog bicycle")

    args = parser.parse_args()

    cf = CocoFilter()
    cf.main(args)
|