Spaces:
Sleeping
Sleeping
| import cv2, os | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from models.common import DetectMultiBackend | |
| from utils.augmentations import letterbox | |
| from utils.general import non_max_suppression, scale_boxes | |
| from utils.plots import Annotator, colors, save_one_box | |
| from utils.torch_utils import select_device | |
# --- Pre-processing / post-processing settings used by inference() ---
img_size = 640        # target size handed to letterbox()
stride = 32           # stride handed to letterbox()
auto = True           # `auto` flag handed to letterbox()
max_det = 1000        # detection cap handed to non_max_suppression()
classes = None
agnostic_nms = False  # forwarded to non_max_suppression()
line_thickness = 3    # box line width handed to Annotator

# --- Model: built once at import time and shared by every request ---
weights = "weights/best.pt"
data = "data/custom_data.yaml"
dnn = False
device = select_device('cpu')
model = DetectMultiBackend(
    weights,
    device=device,
    dnn=dnn,
    data=data,
    fp16=False,
)
def inference(image, iou_thres=0.5, conf_thres=0.5):
    """Run the detector on one image and return an annotated copy.

    Args:
        image: HWC uint8 array as delivered by the ``gr.Image`` input, which
            Gradio documents as RGB channel order.
        iou_thres: IoU threshold forwarded to non-max suppression.
        conf_thres: Confidence threshold forwarded to non-max suppression.

    Returns:
        A copy of ``image`` with one 'buffalo'-labelled box drawn per
        detection that survives non-max suppression. The input array itself
        is left unmodified.

    Raises:
        gr.Error: If no image was supplied (``image is None``).
    """
    if image is None:
        # Gradio passes None when the user submits without an image; fail
        # with a readable message instead of crashing inside letterbox().
        raise gr.Error("Please provide an input image.")

    # Padded resize to the network input size.
    im = letterbox(image, img_size, stride=stride, auto=auto)[0]
    # HWC -> CHW. The channel axis is deliberately NOT reversed: the usual
    # `[::-1]` step converts cv2's BGR to RGB, but the Gradio input is
    # already RGB, so reversing would feed the model BGR data.
    im = np.ascontiguousarray(im.transpose((2, 0, 1)))
    im = torch.from_numpy(im).to(model.device)
    im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if im.ndim == 3:
        im = im[None]  # add the batch dimension

    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        pred = model(im, augment=False, visualize=False)
        pred = pred[0][1] if isinstance(pred[0], list) else pred[0]
        pred = non_max_suppression(
            pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det
        )

    # A single image was submitted, so the batch holds exactly one result.
    det = pred[0]

    # Draw on a copy so the caller's array is not mutated.
    annotator = Annotator(
        image.copy(), line_width=line_thickness, example=str({0: 'buffalo'})
    )
    if len(det):
        # Rescale boxes from the letterboxed size back to the original image.
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], image.shape).round()
        for *xyxy, conf, cls in reversed(det):
            # NOTE(review): colors(c, True) presumably yields a BGR tuple
            # while the canvas here is RGB - confirm the intended box color.
            annotator.box_label(xyxy, 'buffalo', color=colors(int(cls), True))

    return annotator.result()
# Text shown at the top of the demo page.
title = "YOLO V9 trained on Custom Dataset"
description = "Gradio interface to show yoloV9 object detection."

# One example row (image path only) per entry in the examples/ directory.
examples = [[f'examples/{name}'] for name in os.listdir("examples")]

# Input components are listed in the positional order of inference()'s
# parameters: image, iou_thres, conf_thres.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(height=640, width=640, label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="IOU Value"),
        gr.Slider(0, 1, value=0.5, label="Threshold Value"),
    ],
    outputs=[
        gr.Image(label="YoloV9 Output", height=640, width=640),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()