Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import onnxruntime as rt | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
# Height and width (in pixels) that every input image is resized to before
# it is handed to the network.
H, W = 224, 224

# Label names; a detection's label is looked up here using the index of the
# highest class score in the model output.
classes = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]

# Inference runs through ONNX Runtime's CPU execution provider only.
providers = ['CPUExecutionProvider']

# Shared inference session, created once when the module is loaded.
m = rt.InferenceSession("./model/yolo_efficient.onnx", providers=providers)
def nms(final_boxes, scores, IOU_threshold=0):
    """Greedy non-maximum suppression over labelled bounding boxes.

    Boxes are visited from the highest score to the lowest. Each visited box
    is kept, and every remaining (lower-scored) box whose overlap with it is
    strictly greater than ``IOU_threshold`` is discarded.

    NOTE: despite the parameter name, the overlap measure is
    ``intersection_area / area_of_the_lower_scored_box``, not a true
    intersection-over-union. Labels are ignored, so boxes of different
    classes can suppress each other.

    Args:
        final_boxes: Sequence of rows ``[x_min, y_min, x_max, y_max, label]``.
            The last column is dropped before the coordinates are converted
            to integers (fractional parts are truncated).
        scores: One confidence score per row of ``final_boxes``.
        IOU_threshold: Overlap ratio above which a lower-scored box is
            suppressed.

    Returns:
        A numpy array holding the kept rows of ``final_boxes`` ordered by
        descending score. An empty input yields an empty array.
    """
    final_boxes = np.array(final_boxes)
    if final_boxes.size == 0:
        # Nothing to suppress; indexing columns of an empty array would raise.
        return final_boxes
    scores = np.array(scores)

    # Mixing numbers with a label makes numpy store every cell as a string,
    # so convert the coordinate columns back to integers explicitly.
    coords = final_boxes[..., :-1].astype(float).astype(int)
    x1 = coords[:, 0]
    y1 = coords[:, 1]
    x2 = coords[:, 2]
    y2 = coords[:, 3]

    # Inclusive pixel extents (+1), matching the intersection computed below.
    # Clamping each side to at least one pixel keeps the area positive, so
    # degenerate boxes can never cause a division by zero.
    area = np.maximum(x2 - x1 + 1, 1) * np.maximum(y2 - y1 + 1, 1)

    order = np.argsort(scores)  # ascending: the best candidate is last
    keep = []
    while order.size > 0:
        best = order[-1]
        keep.append(best)
        rest = order[:-1]

        inter_w = np.maximum(
            0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]) + 1
        )
        inter_h = np.maximum(
            0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]) + 1
        )
        overlap = (inter_w * inter_h) / area[rest]

        # Only candidates that do not overlap the kept box too much survive.
        order = rest[overlap <= IOU_threshold]

    return final_boxes[keep]
def detect_obj(input_image, obj_threshold, bb_threshold):
    """Run the detector on an image and draw the surviving boxes on it.

    The image is resized to ``W`` x ``H``, passed through the ONNX session
    ``m``, and the model output is decoded into labelled boxes which are
    filtered with ``nms`` before being drawn.

    As used below, each grid cell of the model output holds two box slots
    (confidence at index 0 and 5, followed by four box values each) and the
    class scores from index 10 onwards.

    Args:
        input_image: Image to analyse (anything ``np.array`` can convert,
            e.g. a PIL image). ``None`` is returned unchanged.
        obj_threshold: A box slot is decoded only when its confidence is
            strictly greater than this value.
        bb_threshold: Overlap threshold forwarded to ``nms``.

    Returns:
        A PIL RGB image of size ``W`` x ``H`` with boxes and labels drawn.
        When nothing is detected, or when processing fails, ``input_image``
        is returned unchanged.
    """
    import logging

    if input_image is None:
        return input_image

    try:
        # cv2.resize expects dsize as (width, height).
        image = cv2.resize(np.array(input_image), (W, H))
        canvas = image  # boxes are drawn on the resized image
        batch = np.expand_dims(image.astype(np.float32), axis=0)

        output = m.run(['reshape'], {"input": batch})
        output = np.squeeze(output, axis=0)

        # Cells where at least one of the two box slots is confident enough.
        # Using a single mask visits each cell once, so no box is decoded
        # twice when both slots pass the threshold.
        confident = (output[..., 0] >= obj_threshold) | (output[..., 5] >= obj_threshold)
        object_positions = np.stack(np.where(confident), axis=-1)

        final_boxes = []
        final_scores = []
        for pos in object_positions:
            cell = output[pos[0], pos[1], pos[2]]
            label = classes[int(np.argmax(cell[10:]))]
            for slot in range(2):
                score = cell[slot * 5]
                if score <= obj_threshold:
                    continue
                box = np.array(cell[slot * 5 + 1:slot * 5 + 5], dtype=float)
                # Centre = (cell index + offset inside the cell) * 32.
                # NOTE(review): 32 is presumably the cell size in pixels
                # (224 / 7) -- confirm against the model's grid size.
                x_centre = (float(pos[1]) + box[0]) * 32
                y_centre = (float(pos[2]) + box[1]) * 32
                x_width, y_height = abs(W * box[2]), abs(H * box[3])
                # Negative coordinates are clamped to the image origin.
                x_min = max(0, int(x_centre - (x_width / 2)))
                y_min = max(0, int(y_centre - (y_height / 2)))
                x_max = max(0, int(x_centre + (x_width / 2)))
                y_max = max(0, int(y_centre + (y_height / 2)))
                final_boxes.append([x_min, y_min, x_max, y_max, label])
                final_scores.append(score)

        if not final_boxes:
            # Nothing detected: hand the caller's image back untouched.
            return input_image

        nms_output = nms(np.array(final_boxes), final_scores, bb_threshold)
        for box in nms_output:
            left, top = int(box[0]), int(box[1])
            right, bottom = int(box[2]), int(box[3])
            cv2.rectangle(canvas, (left, top), (right, bottom), (255, 0, 0))
            cv2.putText(
                canvas,
                str(box[-1]),
                (left, top + 15),
                cv2.FONT_HERSHEY_PLAIN, 1, (255, 0, 0), 1
            )
        return Image.fromarray(np.uint8(canvas)).convert('RGB')
    except Exception:
        # Best effort: keep the UI responsive, but record what went wrong
        # instead of hiding it.
        logging.getLogger(__name__).exception("Object detection failed")
        return input_image
# Gradio UI: an input/output image pair, two threshold sliders and a button
# that runs detect_obj on the uploaded image.
with gr.Blocks(title="YOLOS Object Detection - ClassCat", css=".gradio-container {background:lightyellow;}") as demo:
    gr.HTML('<h1>Yolo Object Detection</h1>')
    # Built from `classes` so the listed names always match the labels drawn.
    gr.HTML(f"<h4>supported objects are [{','.join(classes)}]</h4>")
    gr.HTML("<br>")
    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image", type="pil")
    gr.HTML("<br>")
    gr.HTML("<h4>object centre detection threshold means the object centre will be considered a new object if it's value is above threshold</h4>")
    gr.HTML("<p>less means more objects</p>")
    gr.HTML("<h4>bounding box threshold is IOU value threshold. If intersection/union area of two bounding boxes are greater than threshold value the one box will be suppressed</h4>")
    gr.HTML("<p>more means more bounding boxes</p>")
    gr.HTML("<br>")
    obj_threshold = gr.Slider(0, 1.0, value=0.2, label=' object centre detection threshold')
    gr.HTML("<br>")
    bb_threshold = gr.Slider(0, 1.0, value=0.3, label=' bounding box draw threshold')
    gr.HTML("<br>")
    send_btn = gr.Button("Detect")
    gr.HTML("<br>")
    gr.Examples(['./samples/out_1.jpg'], inputs=input_image)
    # Event listeners have to be registered inside the Blocks context.
    send_btn.click(fn=detect_obj, inputs=[input_image, obj_threshold, bb_threshold], outputs=[output_image])

demo.launch(debug=True)