Spaces:
Paused
Paused
| import gradio as gr | |
| import PIL.Image as Image | |
| import spaces | |
| import super_gradients | |
| from tools.tools import py_cpu_nms,get_sub_image,filter_small_fp | |
| import cv2 | |
| import numpy as np | |
| import os | |
| from classifiers.MixMatch.mixmatch_classification import mixmatch_classifier_inference | |
def inference_mega_image_yolonas(img, conf_threshold, iou_threshold, height):
    """Run tiled YOLO-NAS detection over a large image and draw the kept boxes.

    The image is split into overlapping sub-images with ``get_sub_image``,
    the detector is run on every sub-image, and each box is shifted back into
    full-image coordinates. The surviving boxes are de-duplicated with
    ``py_cpu_nms`` and drawn onto the image with ``draw_image``.

    Args:
        img: Image to analyse. The Gradio interface in this file supplies a
            numpy array. It is annotated in place when boxes are kept.
        conf_threshold: Detections whose confidence is not strictly greater
            than this value are discarded.
        iou_threshold: Threshold forwarded to ``py_cpu_nms``.
        height: Suffix inserted into the checkpoint file name
            (``ckpt_best{height}.pth``); the UI offers "15m", "30m", "60m"
            and "90m".

    Returns:
        A ``(image, boxes)`` tuple. ``image`` is ``img`` (annotated when any
        box was kept). ``boxes`` is a list of ``[x1, y1, x2, y2, confidence]``
        rows sorted by descending confidence, or an empty list when no
        detection passed the confidence threshold.
    """
    # Local import: this file does not import torch at module level.
    import torch

    # NOTE(review): the checkpoint is re-loaded on every call; consider
    # caching the model per height if request latency matters.
    model_dir = './checkpoint/yolonas/height_varient/ckpt_best{}.pth'.format(height)
    model = super_gradients.training.models.get(
        'yolo_nas_m', num_classes=1, checkpoint_path=model_dir
    )
    # Only move the model to the GPU when one exists, so a CPU-only host
    # degrades to slow inference instead of raising on .cuda().
    if torch.cuda.is_available():
        model = model.cuda()

    mega_image = img
    sub_image_list, coor_list = get_sub_image(mega_image, overlap=0.2, ratio=1)

    bbox_list = []
    for index, sub_image in enumerate(sub_image_list):
        # coor_list[index] is used as (y offset, x offset) of the tile inside
        # the full image -- presumably its top-left corner; verify against
        # get_sub_image.
        offset_y = coor_list[index][0]
        offset_x = coor_list[index][1]

        prediction = next(iter(model.predict(sub_image))).prediction
        for confidence, bbox in zip(prediction.confidence, prediction.bboxes_xyxy):
            if confidence > conf_threshold:
                bbox_list.append([
                    int(offset_x + bbox[0]),
                    int(offset_y + bbox[1]),
                    int(offset_x + bbox[2]),
                    int(offset_y + bbox[3]),
                    confidence,
                ])

    if not bbox_list:
        return mega_image, []

    boxes = np.asarray(bbox_list)
    keep = py_cpu_nms(boxes, iou_threshold)
    # Highest-confidence detections first (column 4 holds the confidence).
    selected_bbox = sorted(boxes[keep], key=lambda row: row[4], reverse=True)
    mega_image = draw_image(mega_image, selected_bbox)
    return mega_image, selected_bbox
def draw_image(img, bboxes):
    """Annotate ``img`` in place with one rectangle and a 'bird' tag per box.

    Args:
        img: Image passed straight to the OpenCV drawing calls.
        bboxes: Iterable of boxes; the first four entries of each box are
            converted to int and used as the two rectangle corners
            (entries 0-1 and 2-3).

    Returns:
        The same ``img`` object that was passed in.
    """
    for detection in bboxes:
        top_left = (int(detection[0]), int(detection[1]))
        bottom_right = (int(detection[2]), int(detection[3]))
        cv2.rectangle(img, top_left, bottom_right, (0, 255, 0), 3)
        # The label is anchored at the first corner of the rectangle.
        cv2.putText(
            img,
            'bird',
            top_left,
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 255),
            1,
        )
    return img
def predict_image(img, conf_threshold, iou_threshold, height):
    """Gradio callback: detect birds, then classify the detections.

    Runs ``inference_mega_image_yolonas`` on ``img`` and hands the image it
    returns, together with the kept boxes, to
    ``mixmatch_classifier_inference``.

    Returns:
        ``(annotated_image, classifier_output)``; the interface in this file
        renders both values as images.
    """
    classifier_checkpoint = './checkpoint/classifier/mixmatch/model_best.pth.tar'
    annotated_image, detections = inference_mega_image_yolonas(
        img, conf_threshold, iou_threshold, height
    )
    # NOTE(review): the classifier receives the image returned by the
    # detector, i.e. the one boxes were drawn on when detections exist.
    classifier_output = mixmatch_classifier_inference(
        classifier_checkpoint, annotated_image, detections
    )
    return annotated_image, classifier_output
# Input widgets, in the positional order expected by predict_image:
# image, confidence threshold, IoU threshold, capture height.
_input_components = [
    gr.Image(type="numpy", label="Upload Image"),
    gr.Slider(minimum=0, maximum=1, value=0.7, label="Confidence threshold"),
    gr.Slider(minimum=0, maximum=1, value=0.3, label="IoU threshold"),
    gr.Radio(
        ["15m", "30m", "60m", "90m"],
        value="15m",
        label="Height",
        info="The image taken height",
    ),
]

# Output widgets, matching the two values returned by predict_image.
_output_components = [
    gr.Image(type="numpy", label="Result"),
    gr.Image(type="numpy", label="BarChart"),
]

iface = gr.Interface(
    fn=predict_image,
    inputs=_input_components,
    outputs=_output_components,
    title="Waterfowl detection with YOLONAS",
    description="Upload images for Waterfowl object detection.",
)

# Start the web UI and request a public share link.
iface.launch(share=True)