Teeradej Sawettraporn commited on
Commit
d396400
·
verified ·
1 Parent(s): 52bdd67

Upload 3 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ TF-object_detection/TF_training_client.ipynb filter=lfs diff=lfs merge=lfs -text
TF-object_detection/TF_training_client.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1cd5167770945299a7080bca49e5919f4114d3d9d52117c6bdd153479fc0752e
3
+ size 28363915
TF-object_detection/filter.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
+ class CocoFilter():
5
+ """ Filters the COCO dataset
6
+ """
7
+ def _process_info(self):
8
+ self.info = self.coco['info']
9
+
10
+ def _process_licenses(self):
11
+ self.licenses = self.coco['licenses']
12
+
13
+ def _process_categories(self):
14
+ self.categories = dict()
15
+ self.super_categories = dict()
16
+ self.category_set = set()
17
+
18
+ for category in self.coco['categories']:
19
+ cat_id = category['id']
20
+ super_category = category['supercategory']
21
+
22
+ # Add category to categories dict
23
+ if cat_id not in self.categories:
24
+ self.categories[cat_id] = category
25
+ self.category_set.add(category['name'])
26
+ else:
27
+ print(f'ERROR: Skipping duplicate category id: {category}')
28
+
29
+ # Add category id to the super_categories dict
30
+ if super_category not in self.super_categories:
31
+ self.super_categories[super_category] = {cat_id}
32
+ else:
33
+ self.super_categories[super_category] |= {cat_id} # e.g. {1, 2, 3} |= {4} => {1, 2, 3, 4}
34
+
35
+ def _process_images(self):
36
+ self.images = dict()
37
+ for image in self.coco['images']:
38
+ image_id = image['id']
39
+ if image_id not in self.images:
40
+ self.images[image_id] = image
41
+ else:
42
+ print(f'ERROR: Skipping duplicate image id: {image}')
43
+
44
+ def _process_segmentations(self):
45
+ self.segmentations = dict()
46
+ for segmentation in self.coco['annotations']:
47
+ image_id = segmentation['image_id']
48
+ if image_id not in self.segmentations:
49
+ self.segmentations[image_id] = []
50
+ self.segmentations[image_id].append(segmentation)
51
+
52
+ def _filter_categories(self):
53
+ """ Find category ids matching args
54
+ Create mapping from original category id to new category id
55
+ Create new collection of categories
56
+ """
57
+ missing_categories = set(self.filter_categories) - self.category_set
58
+ if len(missing_categories) > 0:
59
+ print(f'Did not find categories: {missing_categories}')
60
+ should_continue = input('Continue? (y/n) ').lower()
61
+ if should_continue != 'y' and should_continue != 'yes':
62
+ print('Quitting early.')
63
+ quit()
64
+
65
+ self.new_category_map = dict()
66
+ new_id = 1
67
+ for key, item in self.categories.items():
68
+ if item['name'] in self.filter_categories:
69
+ self.new_category_map[key] = new_id
70
+ new_id += 1
71
+
72
+ self.new_categories = []
73
+ for original_cat_id, new_id in self.new_category_map.items():
74
+ new_category = dict(self.categories[original_cat_id])
75
+ new_category['id'] = new_id
76
+ self.new_categories.append(new_category)
77
+
78
+ def _filter_annotations(self):
79
+ """ Create new collection of annotations matching category ids
80
+ Keep track of image ids matching annotations
81
+ """
82
+ self.new_segmentations = []
83
+ self.new_image_ids = set()
84
+ for image_id, segmentation_list in self.segmentations.items():
85
+ for segmentation in segmentation_list:
86
+ original_seg_cat = segmentation['category_id']
87
+ if original_seg_cat in self.new_category_map.keys():
88
+ new_segmentation = dict(segmentation)
89
+ new_segmentation['category_id'] = self.new_category_map[original_seg_cat]
90
+ self.new_segmentations.append(new_segmentation)
91
+ self.new_image_ids.add(image_id)
92
+
93
+ def _filter_images(self):
94
+ """ Create new collection of images
95
+ """
96
+ self.new_images = []
97
+ for image_id in self.new_image_ids:
98
+ self.new_images.append(self.images[image_id])
99
+
100
+ def main(self, args):
101
+ # Open json
102
+ self.input_json_path = Path(args.input_json)
103
+ self.output_json_path = Path(args.output_json)
104
+ self.filter_categories = args.categories
105
+
106
+ # Verify input path exists
107
+ if not self.input_json_path.exists():
108
+ print('Input json path not found.')
109
+ print('Quitting early.')
110
+ quit()
111
+
112
+ # Verify output path does not already exist
113
+ if self.output_json_path.exists():
114
+ should_continue = input('Output path already exists. Overwrite? (y/n) ').lower()
115
+ if should_continue != 'y' and should_continue != 'yes':
116
+ print('Quitting early.')
117
+ quit()
118
+
119
+ # Load the json
120
+ print('Loading json file...')
121
+ with open(self.input_json_path) as json_file:
122
+ self.coco = json.load(json_file)
123
+
124
+ # Process the json
125
+ print('Processing input json...')
126
+ self._process_info()
127
+ self._process_licenses()
128
+ self._process_categories()
129
+ self._process_images()
130
+ self._process_segmentations()
131
+
132
+ # Filter to specific categories
133
+ print('Filtering...')
134
+ self._filter_categories()
135
+ self._filter_annotations()
136
+ self._filter_images()
137
+
138
+ # Build new JSON
139
+ new_master_json = {
140
+ 'info': self.info,
141
+ 'licenses': self.licenses,
142
+ 'images': self.new_images,
143
+ 'annotations': self.new_segmentations,
144
+ 'categories': self.new_categories
145
+ }
146
+
147
+ # Write the JSON to a file
148
+ print('Saving new json file...')
149
+ with open(self.output_json_path, 'w+') as output_file:
150
+ json.dump(new_master_json, output_file)
151
+
152
+ print('Filtered json saved.')
153
+
154
+ if __name__ == "__main__":
155
+ import argparse
156
+
157
+ parser = argparse.ArgumentParser(description="Filter COCO JSON: "
158
+ "Filters a COCO Instances JSON file to only include specified categories. "
159
+ "This includes images, and annotations. Does not modify 'info' or 'licenses'.")
160
+
161
+ parser.add_argument("-i", "--input_json", dest="input_json",
162
+ help="path to a json file in coco format")
163
+ parser.add_argument("-o", "--output_json", dest="output_json",
164
+ help="path to save the output json")
165
+ parser.add_argument("-c", "--categories", nargs='+', dest="categories",
166
+ help="List of category names separated by spaces, e.g. -c person dog bicycle")
167
+
168
+ args = parser.parse_args()
169
+
170
+ cf = CocoFilter()
171
+ cf.main(args)
TF-object_detection/readme.txt.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Reference:
2
+
3
+ Based training model from: https://www.youtube.com/watch?v=XZ7FYAMCc4M
4
+ COCO filter: https://github.com/immersive-limit/coco-manager