Spaces:
Build error
Build error
| import cv2 | |
| import gradio as gr | |
| from detection import PearDetectionModel | |
| from classification import predict | |
# Detection backend used by the webcam tab: the weights path and the list of
# class names are passed to PearDetectionModel as one config mapping.
config = dict(
    model_path="./weights/best.pt",
    classes=[
        "burn_bbox",
        "defected_pear",
        "defected_pear_bbox",
        "normal_pear",
        "normal_pear_bbox",
    ],
)
model = PearDetectionModel(config)
def classify(image):
    """Run the classifier on a PIL image received from Gradio.

    ``predict`` takes a file path, so the image is first written to disk.
    A private temporary directory is used for each call so that concurrent
    requests cannot overwrite each other's file, and the file is removed
    once the prediction has been computed.

    Args:
        image (PIL.Image): Uploaded image, or ``None`` when no image was
            supplied.

    Returns:
        The model prediction as returned by ``predict``, or ``None`` when
        no image was supplied.
    """
    import os
    import tempfile

    # Nothing to classify: hand back an empty result instead of failing on
    # ``None.save``.
    if image is None:
        return None

    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = os.path.join(tmp_dir, "temp_image.jpg")
        image.save(image_path)
        return predict(image_path)
def detect(img):
    """Annotate one frame with detection boxes, scores and a class label.

    The frame is passed to ``model.inference``; each returned box is drawn
    as a green rectangle with its confidence printed at the box's top-left
    corner, and a frame-level class label is written near the top-left of
    the image. Drawing happens in place on ``img``.

    Args:
        img: Frame to annotate (the UI supplies a numpy image), or ``None``
            when no frame is available.

    Returns:
        The same ``img`` object with the annotations drawn on it, or
        ``None`` when ``img`` is ``None``.
    """
    # No frame to process: skip inference rather than passing None on.
    if img is None:
        return img

    cls, xyxy, conf = model.inference(img)
    green = (0, 255, 0)

    # ``score`` is a separate name so the ``conf`` sequence is not shadowed
    # by the loop variable.
    for box, score in zip(xyxy, conf):
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        cv2.rectangle(img, top_left, bottom_right, green, 2)
        cv2.putText(
            img,
            f"{score:.2f}",
            top_left,
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            green,
            2,
        )

    # NOTE(review): this assumes ``cls`` is a single value where 0 means a
    # normal pear; confirm against PearDetectionModel.inference, since a
    # per-box sequence would not compare meaningfully with 0.
    label = "Class: Normal Pear" if cls == 0 else "Class: Abnormal Pear"
    cv2.putText(
        img,
        label,
        (0, 50),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        green,
        2,
    )
    return img
# Custom CSS for the detection tab: .my-group caps the webcam widget size and
# .my-column centres its content. Each rule ends with "}" only; a ";" after
# the closing brace is not valid CSS.
css = """.my-group {max-width: 500px !important; max-height: 500px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
with gr.Blocks(css=css) as demo:
    demo.title = "Pear Playground"

    # Page heading shown above both tabs.
    gr.Markdown("## This is a demo for Pear Playground by AISeed.")

    # Tab 1: upload an image and show the classifier's label output.
    with gr.Tab(label="Classification"):
        upload_input = gr.Image(type="pil", label="Upload an image")
        label_output = gr.Label(num_top_classes=9)
        gr.Interface(
            fn=classify,
            inputs=upload_input,
            outputs=label_output,
            examples=["examples/1.jpg", "examples/2.jpg"],
            title="비정상 과수 분류기",
            description="경량 모델 ResNet101e 을 활용하여 비정상배 분류",
        )

    # Tab 2: webcam frames are streamed through detect() and the annotated
    # frame is written back into the same image component.
    with gr.Tab(label="Detection"):
        with gr.Column(elem_classes=["my-column"]):
            with gr.Group(elem_classes=["my-group"]):
                input_img = gr.Image(
                    sources=["webcam"],
                    type="numpy",
                    streaming=True,
                )
                input_img.stream(
                    fn=detect,
                    inputs=[input_img],
                    outputs=[input_img],
                    time_limit=30,
                    stream_every=0.1,
                )
if __name__ == "__main__":
    # Start the app only when run as a script; ./examples is passed as an
    # allowed path for the launch call.
    served_dirs = ["./examples"]
    demo.launch(allowed_paths=served_dirs)