PolarisFTL commited on
Commit
e0cf1b5
·
verified ·
1 Parent(s): 4cae3ea

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +67 -0
  2. predict.py +39 -0
  3. requirements.txt +9 -0
  4. yolo.py +448 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from yolo import YOLO
3
+ import gradio as gr
4
+ import os
5
+
6
# Initialize YOLO model once at module import; shared by all Gradio callbacks below.
yolo = YOLO()
8
+
9
def detect_objects(image, crop=False, count=True):
    """Run the shared YOLO detector on a PIL image.

    Args:
        image: input PIL image.
        crop:  forwarded to ``YOLO.detect_image`` (save cropped detections).
        count: forwarded to ``YOLO.detect_image`` (print per-class counts).

    Returns:
        The annotated PIL image produced by the detector.
    """
    return yolo.detect_image(image, crop=crop, count=count)
12
+
13
def save_image(image, filename):
    """Save *image* under the ``img_out`` directory and return the saved path.

    Args:
        image: a PIL image (or any object exposing a compatible ``save``).
        filename: bare file name (no directory components).

    Returns:
        The path ``img_out/<filename>`` the image was written to.
    """
    out_dir = "img_out"
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists()/os.makedirs() pair.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, filename)
    image.save(out_path, quality=95, subsampling=0)
    return out_path
18
+
19
# Gradio callback for single-image prediction.
def predict(image):
    """Detect objects in one image and return the path of the saved result."""
    annotated = detect_objects(image)
    return save_image(annotated, "output.png")
24
+
25
# Gradio callback for whole-directory prediction.
def dir_predict(dir_origin_path):
    """Run detection on every supported image in *dir_origin_path*.

    Args:
        dir_origin_path: directory to scan (non-recursively) for images.

    Returns:
        List of output file paths written under ``img_out`` (one PNG per
        processed input image).
    """
    image_exts = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm',
                  '.ppm', '.tif', '.tiff')
    output_images = []
    # sorted() makes the output order deterministic across filesystems.
    for img_name in sorted(os.listdir(dir_origin_path)):
        if not img_name.lower().endswith(image_exts):
            continue
        image = Image.open(os.path.join(dir_origin_path, img_name))
        r_image = detect_objects(image)
        # Always save as PNG: the original str.replace(".jpg", ".png") left
        # every other accepted extension (.jpeg, .tif, ...) unconverted.
        out_name = os.path.splitext(img_name)[0] + ".png"
        output_images.append(save_image(r_image, out_name))
    return output_images
37
+
38
# Gradio interface components.
# NOTE: the gr.inputs / gr.outputs namespaces are deprecated and were removed
# in Gradio 3/4; components now live directly on the gr namespace. This file
# already uses gr.TabbedInterface (a 3.x+ API), so the modern component API
# is the consistent choice.
image_input = gr.Image(type="pil", label="Input Image")
# type="filepath" replaces the removed gr.outputs.Image(type="file").
image_output = gr.Image(type="filepath", label="Output Image")

# Single-image prediction interface.
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=image_output,
    title="YOLO Object Detection",
    description="Upload an image to detect objects using YOLO model.",
)

# Directory prediction interface.
dir_input = gr.Textbox(label="Directory Path")
dir_output = gr.Textbox(label="Output Paths")

iface_dir = gr.Interface(
    fn=dir_predict,
    inputs=dir_input,
    outputs=dir_output,
    title="YOLO Object Detection for Directory",
    description="Provide a directory path to detect objects in all images within the directory.",
)

# Combine both interfaces into one tabbed app and launch it.
app = gr.TabbedInterface([iface, iface_dir], ["Single Image Prediction", "Directory Prediction"])
app.launch()
predict.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from yolo import YOLO
3
+
4
if __name__ == "__main__":
    # 'predict': interactive single-image loop; 'dir_predict': batch a folder.
    mode = 'predict'
    crop = False               # save cropped detections (see YOLO.detect_image)
    count = True               # print per-class detection counts
    dir_origin_path = "img/vs"
    dir_save_path = "img_out"

    yolo = YOLO()

    if mode == "predict":
        while True:
            img = input('Input image filename:')
            try:
                image = Image.open(img)
            except OSError:
                # Missing path or unreadable image data. The original bare
                # "except:" also swallowed KeyboardInterrupt/SystemExit,
                # making the loop impossible to interrupt cleanly.
                print('Open Error! Try again!')
                continue
            else:
                r_image = yolo.detect_image(image, crop=crop, count=count)
                r_image.show()

    elif mode == "dir_predict":
        import os
        from tqdm import tqdm

        # Create the output directory once, not per image.
        os.makedirs(dir_save_path, exist_ok=True)
        for img_name in tqdm(os.listdir(dir_origin_path)):
            if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg',
                                          '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
                image_path = os.path.join(dir_origin_path, img_name)
                image = Image.open(image_path)
                r_image = yolo.detect_image(image)
                # Save every result as PNG; the original replace(".jpg", ".png")
                # left other accepted extensions unconverted.
                out_name = os.path.splitext(img_name)[0] + ".png"
                r_image.save(os.path.join(dir_save_path, out_name),
                             quality=95, subsampling=0)
    else:
        raise AssertionError("Please specify the correct mode: 'predict', 'dir_predict'.")
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ scipy==1.2.1
2
+ numpy==1.17.0
3
+ matplotlib==3.1.2
4
+ opencv_python==4.1.2.30
5
+ torch==1.2.0
6
+ torchvision==0.4.0
7
+ tqdm==4.60.0
8
+ Pillow==8.2.0
9
+ h5py==2.10.0
yolo.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import colorsys
2
+ import os
3
+ import time
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import cv2
9
+ from PIL import ImageDraw, ImageFont, Image
10
+
11
+ from nets.yolo import YoloBody
12
+ from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input,
13
+ resize_image, show_config)
14
+ from utils.utils_bbox import DecodeBox, DecodeBoxNP
15
+
16
+ '''
17
+ 训练自己的数据集必看注释!
18
+ '''
19
class YOLO(object):
    """PyTorch YOLO detector wrapper: loads weights, decodes boxes, draws results."""

    # Default configuration. Constructor kwargs override these entries (and are
    # written back into this shared class-level dict -- see __init__).
    _defaults = {
        "model_path"      : 'model_data/rtts.pth',
        "classes_path"    : 'model_data/rtts_classes.txt',
        "anchors_path"    : 'model_data/yolo_anchors.txt',
        "anchors_mask"    : [[3,4,5], [1,2,3]],
        "backbone"        : 'tiny',
        "phi"             : 0,
        "input_shape"     : [416, 416],
        "confidence"      : 0.5,
        "nms_iou"         : 0.3,
        "letterbox_image" : False,
        "cuda"            : True,
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value for configuration key *n*.

        NOTE(review): returns an error *string* (instead of raising) for an
        unknown key, so callers cannot distinguish it from a real value.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        """Build the detector; any keyword argument overrides a _defaults entry."""
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
            # NOTE(review): this mutates the shared class-level dict, so
            # per-instance overrides leak into every later instance -- confirm
            # this is intended (show_config below relies on it).
            self._defaults[name] = value

        # Class names and anchor boxes come from the configured text files.
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.anchors, self.num_anchors = get_anchors(self.anchors_path)
        self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)

        # One HSV-evenly-spaced RGB color per class for drawing boxes.
        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
        self.generate()

        show_config(**self._defaults)

    def generate(self, onnx=False):
        """Instantiate the network and load weights from ``self.model_path``.

        When *onnx* is True the model stays on CPU in eval mode (no
        DataParallel wrapper), as needed for ONNX export.
        """
        self.net = YoloBody(self.anchors_mask, self.num_classes, self.phi, self.backbone)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, anchors, and classes loaded.'.format(self.model_path))
        if not onnx:
            if self.cuda:
                self.net = nn.DataParallel(self.net)
                self.net = self.net.cuda()

    def detect_image(self, image, crop = False, count = False):
        """Detect objects in a PIL image and return the image with boxes drawn.

        Args:
            image: input PIL image (converted to RGB below).
            crop:  if True, additionally save each detection to ``img_crop/``.
            count: if True, print per-class detection counts.

        Returns:
            The annotated PIL image; returned without drawing when nothing
            is detected.
        """
        image_shape = np.array(np.shape(image)[0:2])
        # Convert to RGB so non-RGB inputs do not break the network.
        image = cvtColor(image)
        # Resize (optionally letterboxed) to the network input size.
        image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
        # HWC -> CHW, add a batch dimension, normalize via preprocess_input.
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            # Forward pass, decode raw head outputs to boxes, then NMS.
            outputs = self.net(images)
            outputs = self.bbox_util.decode_box(outputs)
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                        image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

        # No detections: return the RGB-converted input untouched.
        if results[0] is None:
            return image

        top_label = np.array(results[0][:, 6], dtype = 'int32')
        top_conf = results[0][:, 4] * results[0][:, 5]  # objectness * class score
        top_boxes = results[0][:, :4]
        # Font size and outline thickness scale with the image dimensions.
        font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
        thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))
        if count:
            # Print how many boxes of each class survived NMS.
            print("top_label:", top_label)
            classes_nums = np.zeros([self.num_classes])
            for i in range(self.num_classes):
                num = np.sum(top_label == i)
                if num > 0:
                    print(self.class_names[i], " : ", num)
                classes_nums[i] = num
            print("classes_nums:", classes_nums)
        if crop:
            # Save every detection as a separate cropped PNG under img_crop/.
            for i, c in list(enumerate(top_label)):
                top, left, bottom, right = top_boxes[i]
                top = max(0, np.floor(top).astype('int32'))
                left = max(0, np.floor(left).astype('int32'))
                bottom = min(image.size[1], np.floor(bottom).astype('int32'))
                right = min(image.size[0], np.floor(right).astype('int32'))

                dir_save_path = "img_crop"
                if not os.path.exists(dir_save_path):
                    os.makedirs(dir_save_path)
                crop_image = image.crop([left, top, right, bottom])
                crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
                print("save crop_" + str(i) + ".png to " + dir_save_path)
        # Draw one labelled rectangle per detection.
        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box = top_boxes[i]
            score = top_conf[i]

            top, left, bottom, right = box

            # Clamp the box to the image bounds.
            top = max(0, np.floor(top).astype('int32'))
            left = max(0, np.floor(left).astype('int32'))
            bottom = min(image.size[1], np.floor(bottom).astype('int32'))
            right = min(image.size[0], np.floor(right).astype('int32'))

            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = draw.textsize(label, font)
            label = label.encode('utf-8')
            print(label, top, left, bottom, right)

            # Place the label above the box when it fits, else just inside it.
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # Emulate a thick outline by drawing nested 1px rectangles.
            # NOTE(review): this inner loop reuses ``i`` from the enumerate
            # loop above; harmless here, but easy to trip over.
            for i in range(thickness):
                draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
            draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
            del draw

        return image

    def get_FPS(self, image, test_interval):
        """Return the average seconds per inference over *test_interval* runs."""
        image_shape = np.array(np.shape(image)[0:2])
        image = cvtColor(image)
        image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            # One untimed pass before the timed loop (warm-up).
            outputs = self.net(images)
            outputs = self.bbox_util.decode_box(outputs)
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                        image_shape, self.letterbox_image, conf_thres=self.confidence, nms_thres=self.nms_iou)

        t1 = time.time()
        for _ in range(test_interval):
            with torch.no_grad():
                outputs = self.net(images)
                outputs = self.bbox_util.decode_box(outputs)
                results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                            image_shape, self.letterbox_image, conf_thres=self.confidence, nms_thres=self.nms_iou)

        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time

    def detect_heatmap(self, image, heatmap_save_path):
        """Save (and show) a jet-colormap overlay of per-location objectness."""
        import cv2
        import matplotlib.pyplot as plt
        def sigmoid(x):
            y = 1.0 / (1.0 + np.exp(-x))
            return y
        image = cvtColor(image)
        image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            outputs = self.net(images)
        plt.clf()
        plt.imshow(image, alpha=1)
        plt.axis('off')
        mask = np.zeros((image.size[1], image.size[0]))
        for sub_output in outputs:
            sub_output = sub_output.cpu().numpy()
            b, c, h, w = np.shape(sub_output)
            # (b, 3*(5+cls), h, w) -> (h, w, anchors, 5+cls) for this scale.
            sub_output = np.transpose(np.reshape(sub_output, [b, 3, -1, h, w]), [0, 3, 4, 1, 2])[0]
            # Max objectness over anchors, resized to the original image size;
            # keep the per-pixel maximum across all output scales.
            score = np.max(sigmoid(sub_output[..., 4]), -1)
            score = cv2.resize(score, (image.size[0], image.size[1]))
            normed_score = (score * 255).astype('uint8')
            mask = np.maximum(mask, normed_score)

        plt.imshow(mask, alpha=0.5, interpolation='nearest', cmap="jet")

        plt.axis('off')
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.savefig(heatmap_save_path, dpi=200, bbox_inches='tight', pad_inches = -0.1)
        print("Save to the " + heatmap_save_path)
        plt.show()

    def convert_to_onnx(self, simplify, model_path):
        """Export the network to ONNX at *model_path*; optionally simplify it."""
        import onnx
        self.generate(onnx=True)
        im = torch.zeros(1, 3, *self.input_shape).to('cpu')
        input_layer_names = ["images"]
        output_layer_names = ["output"]

        print(f'Starting export with onnx {onnx.__version__}.')
        torch.onnx.export(self.net,
                          im,
                          f = model_path,
                          verbose = False,
                          opset_version = 12,
                          training = torch.onnx.TrainingMode.EVAL,
                          do_constant_folding = True,
                          input_names = input_layer_names,
                          output_names = output_layer_names,
                          dynamic_axes = None)

        # Round-trip the exported graph through the checker before use.
        model_onnx = onnx.load(model_path)
        onnx.checker.check_model(model_onnx)
        if simplify:
            import onnxsim
            print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
            model_onnx, check = onnxsim.simplify(
                model_onnx,
                dynamic_input_shape=False,
                input_shapes=None)
            assert check, 'assert check failed'
            onnx.save(model_onnx, model_path)

        print('Onnx model save as {}'.format(model_path))

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Write detections for *image_id* in mAP-evaluation text format."""
        f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
        image_shape = np.array(np.shape(image)[0:2])
        image = cvtColor(image)
        image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            outputs = self.net(images)
            outputs = self.bbox_util.decode_box(outputs)
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                        image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

        # NOTE(review): the file handle is not closed on this early return.
        if results[0] is None:
            return

        top_label = np.array(results[0][:, 6], dtype = 'int32')
        top_conf = results[0][:, 4] * results[0][:, 5]
        top_boxes = results[0][:, :4]

        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box = top_boxes[i]
            score = str(top_conf[i])

            top, left, bottom, right = box
            # Skip classes the evaluation set does not include.
            if predicted_class not in class_names:
                continue

            # score[:6] truncates the stringified confidence to 6 characters.
            f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))

        f.close()
        return
279
+
280
class YOLO_ONNX(object):
    """ONNX-Runtime variant of the YOLO detector (NumPy decoding, no torch)."""

    # Default configuration; constructor kwargs override these entries.
    _defaults = {
        "onnx_path"       : 'model_data/models.onnx',
        "classes_path"    : 'model_data/rtts_classes.txt',
        "anchors_path"    : 'model_data/yolo_anchors.txt',
        "anchors_mask"    : [[3, 4, 5], [1, 2, 3]],
        "input_shape"     : [416, 416],
        "confidence"      : 0.5,
        "nms_iou"         : 0.3,
        "letterbox_image" : True
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default for key *n*, or an explanatory string if unknown."""
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        """Create the ONNX session and load classes, anchors and draw colors."""
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
            # NOTE(review): mutates the shared class-level dict (same pattern
            # as the torch YOLO class above) -- overrides leak across instances.
            self._defaults[name] = value

        import onnxruntime
        self.onnx_session = onnxruntime.InferenceSession(self.onnx_path)
        self.input_name = self.get_input_name()
        self.output_name = self.get_output_name()

        self.class_names, self.num_classes = self.get_classes(self.classes_path)
        self.anchors, self.num_anchors = self.get_anchors(self.anchors_path)
        self.bbox_util = DecodeBoxNP(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)

        # One HSV-evenly-spaced RGB color per class for drawing boxes.
        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))

        show_config(**self._defaults)

    def get_classes(self, classes_path):
        """Read one class name per line; return (names, count)."""
        with open(classes_path, encoding='utf-8') as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names, len(class_names)

    def get_anchors(self, anchors_path):
        """Load comma-separated anchors from a file; return ((n, 2) array, n)."""
        with open(anchors_path, encoding='utf-8') as f:
            anchors = f.readline()
        anchors = [float(x) for x in anchors.split(',')]
        anchors = np.array(anchors).reshape(-1, 2)
        return anchors, len(anchors)

    def get_input_name(self):
        """Return the list of input tensor names of the ONNX session."""
        return [node.name for node in self.onnx_session.get_inputs()]

    def get_output_name(self):
        """Return the list of output tensor names of the ONNX session."""
        return [node.name for node in self.onnx_session.get_outputs()]

    def get_input_feed(self, image_tensor):
        """Map every session input name to *image_tensor* (single-input model)."""
        return {name: image_tensor for name in self.input_name}

    def resize_image(self, image, size, letterbox_image, mode='PIL'):
        """Resize *image* to *size*, optionally letterboxing with gray padding.

        Args:
            image: a PIL image (mode='PIL') or array-like (any other mode).
            size:  (w, h) on the PIL path, (h, w) on the NumPy/cv2 path
                   (matching how each branch indexes it).
            letterbox_image: keep aspect ratio and pad with (128, 128, 128).
            mode: 'PIL' for PIL resizing, anything else for cv2.

        Returns:
            The resized (and possibly padded) image, PIL or NumPy depending
            on the branch taken.
        """
        if mode == 'PIL':
            iw, ih = image.size
            w, h = size

            if letterbox_image:
                # Scale to fit inside the target, then center on a gray canvas.
                scale = min(w / iw, h / ih)
                nw = int(iw * scale)
                nh = int(ih * scale)

                image = image.resize((nw, nh), Image.BICUBIC)
                new_image = Image.new('RGB', size, (128, 128, 128))
                new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
            else:
                new_image = image.resize((w, h), Image.BICUBIC)
        else:
            image = np.array(image)
            if letterbox_image:
                shape = np.shape(image)[:2]  # (h, w)
                if isinstance(size, int):
                    size = (size, size)

                r = min(size[0] / shape[0], size[1] / shape[1])

                new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
                dw, dh = size[1] - new_unpad[0], size[0] - new_unpad[1]

                # Split the padding evenly between the two sides.
                dw /= 2
                dh /= 2

                image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
                top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
                left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
                # BUGFIX: the original computed the padding but never applied
                # it nor assigned new_image, raising UnboundLocalError on
                # return. Apply the gray border here.
                new_image = cv2.copyMakeBorder(image, top, bottom, left, right,
                                               cv2.BORDER_CONSTANT, value=(128, 128, 128))
            else:
                # BUGFIX: the original referenced undefined names ``w``/``h``
                # in this branch; index size as (h, w) like the branch above.
                new_image = cv2.resize(image, (size[1], size[0]))

        return new_image

    def detect_image(self, image):
        """Detect objects via the ONNX session and return the annotated image."""
        image_shape = np.array(np.shape(image)[0:2])
        image = cvtColor(image)

        # Letterbox-resize to the network input, then HWC -> NCHW float32.
        image_data = self.resize_image(image, self.input_shape, True)
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        input_feed = self.get_input_feed(image_data)
        outputs = self.onnx_session.run(output_names=self.output_name, input_feed=input_feed)

        # Reshape each flat session output back to (1, anchors*(5+cls), h, w)
        # for its feature-map scale before NumPy box decoding.
        feature_map_shape = [[int(j / (2 ** (i + 4))) for j in self.input_shape] for i in range(len(self.anchors_mask))][::-1]
        for i in range(len(self.anchors_mask)):
            outputs[i] = np.reshape(outputs[i], (1, len(self.anchors_mask[i]) * (5 + self.num_classes), feature_map_shape[i][0], feature_map_shape[i][1]))

        outputs = self.bbox_util.decode_box(outputs)
        results = self.bbox_util.non_max_suppression(np.concatenate(outputs, 1), self.num_classes, self.input_shape,
                    image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

        # No detections: return the RGB-converted input untouched.
        if results[0] is None:
            return image

        top_label = np.array(results[0][:, 6], dtype = 'int32')
        top_conf = results[0][:, 4] * results[0][:, 5]  # objectness * class score
        top_boxes = results[0][:, :4]

        # Font size and outline thickness scale with the image dimensions.
        font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
        thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))

        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box = top_boxes[i]
            score = top_conf[i]

            top, left, bottom, right = box

            # Clamp the box to the image bounds.
            top = max(0, np.floor(top).astype('int32'))
            left = max(0, np.floor(left).astype('int32'))
            bottom = min(image.size[1], np.floor(bottom).astype('int32'))
            right = min(image.size[0], np.floor(right).astype('int32'))

            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = draw.textsize(label, font)
            label = label.encode('utf-8')
            print(label, top, left, bottom, right)

            # Place the label above the box when it fits, else just inside it.
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # Emulate a thick outline with nested 1px rectangles (t renamed
            # from i to stop shadowing the enumerate index).
            for t in range(thickness):
                draw.rectangle([left + t, top + t, right - t, bottom - t], outline=self.colors[c])
            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
            draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font)
            del draw

        return image