Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| from ultralytics import YOLO | |
| from PIL import Image, ImageDraw, ImageFont | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| import os | |
| from torchvision import transforms | |
# Class index -> label name. predict() treats a top-1 name equal to
# classes[0] as a defect and anything else as good.
classes = dict(enumerate(("Defective", "Good")))

# Local weights file used by load_model_local(); resolved relative to the
# current working directory.
model_path = "./best.pt"
def load_model_local():
    """Create the classifier from the local weights file at ``model_path``.

    Returns:
        A ``YOLO`` model instance constructed with ``task="classify"``.
    """
    return YOLO(model_path, task="classify")
def load_model(repo_id, model_subdir="best_int8_openvino_model"):
    """Download a Hugging Face Hub repo and load a classifier stored inside it.

    Args:
        repo_id: Hub repository id passed to ``snapshot_download``.
        model_subdir: Path, relative to the downloaded snapshot, of the model
            to load. Defaults to the ``best_int8_openvino_model`` export.

    Returns:
        A ``YOLO`` model instance constructed with ``task="classify"``.

    Raises:
        FileNotFoundError: If ``model_subdir`` does not exist in the snapshot.
    """
    download_dir = snapshot_download(repo_id)
    print(download_dir)
    path = os.path.join(download_dir, model_subdir)
    print(path)
    # Fail with an explicit message rather than handing YOLO a path that does
    # not exist (NOTE(review): YOLO may otherwise try to resolve it as a
    # model name and report a less obvious error).
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"Model {model_subdir!r} not found in snapshot of {repo_id!r}: {path}"
        )
    return YOLO(path, task="classify")
# Built once at import time instead of on every request: resize to 224x224,
# then convert to a float tensor with values in [0, 1].
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"


def _load_label_font(size=30):
    """Return the bold DejaVu font, or Pillow's built-in font if it is unavailable."""
    try:
        return ImageFont.truetype(_FONT_PATH, size)
    except OSError:
        # The DejaVu font only exists on some Linux images; degrade to the
        # default bitmap font instead of failing the whole request.
        return ImageFont.load_default()


def predict(pilimg):
    """Classify a cap image and return a copy annotated with the predicted class.

    Args:
        pilimg: Input ``PIL.Image`` from the ``gr.Image(type="pil")`` component;
            may be ``None`` when the form is submitted without an image.

    Returns:
        An RGB copy of ``pilimg`` with the predicted class name drawn at
        (300, 10): red when it equals ``classes[0]``, green otherwise.

    Raises:
        gr.Error: If no image was provided.
    """
    if pilimg is None:
        raise gr.Error("Please upload an image first.")

    # Convert first so RGBA / grayscale / palette uploads still produce a
    # 3-channel tensor. convert() returns a new image, so the caller's object
    # is never drawn on.
    annotated_image = pilimg.convert("RGB")

    # Add the batch dimension explicitly: Ultralytics expects BCHW tensors.
    source = _PREPROCESS(annotated_image).unsqueeze(0)
    result = detection_model.predict(source)

    probs = result[0].probs
    label = probs.top1  # index of the most likely class
    classified_type = detection_model.names[label]  # map index -> class name
    print(">>> Class : ", classified_type)
    confidence = probs.top1conf  # confidence of the top class
    print(">>> Confidence : ", confidence)

    is_defective = classified_type == classes[0]
    draw = ImageDraw.Draw(annotated_image)
    draw.text(
        (300, 10),
        classified_type,
        fill="red" if is_defective else "green",
        font=_load_label_font(30),
    )
    if is_defective:
        gr.Warning("Defect detected, BAD!.")
    else:
        gr.Info("No defect detected,GOOD!")
    return annotated_image
# Load the weights once at import time; predict() reads this module-level name.
detection_model = load_model_local()

title = "Detect the status of the cap, DEFECTIVE or GOOD"
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Classification result"),
    title=title,
)

if __name__ == "__main__":
    # Start the web server only when run as a script, so importing this
    # module (e.g. to reuse predict) does not block on launch().
    # share=True additionally requests a public Gradio share link.
    interface.launch(share=True)