Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| import tensorflow as tf | |
| from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions | |
| from tensorflow.keras.preprocessing import image | |
| from ultralytics import YOLO | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| import os | |
| from torchvision import transforms | |
# Map from the classifier's binary label index to a human-readable name
# (index 1 corresponds to a score at or above the decision threshold).
classes = {0: "Defective", 1: "Good"}
model_path = "ResNet50_Classification.h5"  # Trained ResNet50 classification model
best_yolo_model = "best.pt"  # Trained YOLOv8 detection model
# compile=False: the model is only used for inference (predict), so there is no
# need to restore its optimizer/loss configuration, which may fail to
# deserialize when the Keras version differs from the one used for training.
classification_model = tf.keras.models.load_model(model_path, compile=False)
detection_model = YOLO(best_yolo_model, task="detect")
def preprocess_image(pilimg, target_size=(224, 224)):
    """Turn a PIL image into a single-image batch for the ResNet50 classifier.

    Args:
        pilimg: Input PIL image, in any mode PIL can convert to RGB.
        target_size: (width, height) passed to ``PIL.Image.resize``; defaults
            to the 224x224 input size used by ResNet50.

    Returns:
        A float array of shape (1, height, width, 3) with the default
        channels-last Keras data format.
    """
    # Force 3 channels first: RGBA / grayscale inputs would otherwise produce a
    # 4- or 1-channel array that the 3-channel classifier cannot accept.
    img = pilimg.convert("RGB").resize(target_size)
    img_array = image.img_to_array(img)
    # NOTE(review): `preprocess_input` is imported but not applied here, so the
    # model receives raw 0-255 pixel values; presumably it was trained that way
    # (or normalizes internally) — confirm against the training pipeline.
    return np.expand_dims(img_array, axis=0)  # Add batch dimension
def classify_image(pilimg, threshold=0.5):
    """Classify a cap image as "Good" or "Defective" with the ResNet50 model.

    Args:
        pilimg: PIL image to classify.
        threshold: Score at or above which the image is labelled "Good";
            defaults to 0.5, the cut-off previously hard-coded here.

    Returns:
        The label from ``classes`` ("Good" or "Defective").
    """
    img_array = preprocess_image(pilimg)  # Preprocess the input image
    # [0][0]: first (only) image in the batch, first output unit. This assumes
    # the model ends in a single score unit where values near 1 mean "Good".
    # verbose=0 suppresses the Keras progress bar printed on every request.
    classify_result = classification_model.predict(img_array, verbose=0)[0][0]
    print(">>> Result : ", classify_result)
    # Boolean -> 0/1 index into `classes` (1: "Good", 0: "Defective").
    predicted_class = classes[int(classify_result >= threshold)]
    print(">>> predicted_class : ", predicted_class)
    return predicted_class
def detect_defect(img, conf=0.4, iou=0.5):
    """Run the YOLO detector on *img* and return its raw prediction results.

    Args:
        img: Image passed straight to ``detection_model.predict``; the caller
            in this file passes the PIL image received from Gradio.
        conf: Minimum confidence for a detection to be kept (default 0.4, the
            value previously hard-coded here).
        iou: IoU threshold for non-maximum suppression (default 0.5, the value
            previously hard-coded here).

    Returns:
        Whatever ``detection_model.predict`` returns; ``process_image`` reads
        element ``[0]`` as the result for the single input image.
    """
    return detection_model.predict(img, conf=conf, iou=iou)
# Bold TrueType font for the status banner. This is a system path and may not
# exist on every host, so _load_font falls back to PIL's built-in font.
_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
_FONT_SIZE = 30


def _load_font(size=_FONT_SIZE):
    """Return the banner font, or PIL's default font if the file is missing."""
    try:
        return ImageFont.truetype(_FONT_PATH, size)
    except OSError:
        # A missing font must not fail the whole request; the label is still
        # drawn, just in the smaller built-in font.
        return ImageFont.load_default()


def _draw_label(pil_img, text, position, color):
    """Draw *text* onto *pil_img* (in place) at *position* in *color*."""
    ImageDraw.Draw(pil_img).text(position, text, fill=color, font=_load_font())


def _summary_html(message, color):
    """Wrap *message* in the large bold coloured span shown in the Markdown output."""
    return (
        f"<span style='font-size:30px; font-weight:bold; color:{color}'>"
        f"{message}</span>"
    )


def process_image(pilimg):
    """Classify a cap image and, if it is classified Defective, run detection.

    Args:
        pilimg: PIL image from the Gradio input, or None when the user
            submits without uploading an image.

    Returns:
        A tuple ``(annotated PIL image or None, HTML summary string)`` matching
        the interface's image and Markdown outputs.
    """
    if pilimg is None:
        return None, _summary_html("Please upload an image of a cap.", "blue")

    # Perform classification first; run detection only for Defective caps.
    if classify_image(pilimg) == "Good":
        out_pilimg = pilimg.convert("RGB")
        _draw_label(out_pilimg, "Good", (250, 10), "green")
        return out_pilimg, _summary_html(
            "No defect is detected, the cap is GOOD!", "green"
        )

    result = detect_defect(pilimg)[0]
    # plot() returns a BGR array with the boxes drawn; reversing the last axis
    # gives the RGB order PIL expects.
    out_pilimg = Image.fromarray(result.plot()[..., ::-1])
    _draw_label(out_pilimg, "Defective", (300, 10), "red")

    if len(result.boxes.data) > 0:
        summary = _summary_html("Defect is detected, the cap is BAD!", "red")
    else:
        summary = _summary_html(
            "The cap is classified as Defective but the defect cannot be detected!",
            "blue",
        )
    return out_pilimg, summary
# Page heading shown above the Gradio UI.
title = "Detect the status of the cap: DEFECTIVE or GOOD"

# One PIL image in; an annotated PIL image plus a Markdown/HTML summary out,
# both produced by process_image. fn, inputs and outputs are passed positionally.
interface = gr.Interface(
    process_image,
    gr.Image(type="pil", label="Input Image"),
    [
        gr.Image(type="pil", label="Classification/Detection result"),
        gr.Markdown(label="Classification/Detection Summary"),
    ],
    title=title,
)

# Start the web server; share=True additionally requests a public share link.
interface.launch(share=True)