"""Gradio web app that classifies blood-smear images with a YOLO model.

Loads trained Ultralytics YOLO classification weights and exposes a
single-image upload interface reporting the top-1 class and its confidence.
"""

import os

# Ultralytics resolves its user config directory once, when the package is
# first imported, so YOLO_CONFIG_DIR must be set *before* importing it.
# Assigning it after the import has no effect and the config-dir write
# warning still appears.
os.environ["YOLO_CONFIG_DIR"] = "UltralyticsConfig"
os.makedirs("UltralyticsConfig", exist_ok=True)

import gradio as gr  # noqa: E402  (must follow the env-var setup above)
from PIL import Image  # noqa: E402,F401  (kept; Gradio supplies PIL images)
from ultralytics import YOLO  # noqa: E402

# Path to YOLO model weights (relative to the current working directory).
model_path = "yolo_training/sickle_cls_model/weights/best.pt"

# Fail fast with an actionable message if the weights are missing.
if not os.path.exists(model_path):
    raise FileNotFoundError(
        f"Model file not found at {model_path}. "
        "Make sure best.pt is uploaded to this path in your repo."
    )

# Load YOLO model
model = YOLO(model_path)

# Class names in index order, taken from the checkpoint itself. A hard-coded
# list can silently disagree with the mapping saved at training time
# (classification datasets order class folders with sorted(), which for this
# project would presumably be AIN, non_sickle, sickle -- verify if needed).
classes = [model.names[i] for i in sorted(model.names)]


def predict_image(img):
    """Classify a single blood-smear image.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        tuple[str, float]: the top-1 class name and its confidence score.

    Raises:
        gr.Error: if no image was provided, or the loaded model produced no
            classification probabilities. Gradio shows the message in the UI.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    # YOLO accepts a list of sources and returns one Results object per image.
    results = model([img], verbose=False)
    result = results[0]

    probs = result.probs
    if probs is None:
        # ``probs`` is only populated for classification models.
        raise gr.Error("The loaded model is not a classification model.")

    top1_idx = int(probs.top1)
    top1_conf = probs.top1conf.item()
    # Look the label up in the result's own index -> name mapping so the
    # label always matches the model that produced the index.
    top1_class = result.names[top1_idx]
    return top1_class, top1_conf


# Create Gradio interface
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(label="Predicted class"), gr.Number(label="Confidence")],
    title="Sickle Cell Classification",
    description=(
        "Upload a blood smear image to classify it as "
        "Sickle Cell (sickle), Non-Sickle (non_sickle), "
        "or Artifact/Impurities/Noise (AIN)."
    ),
)

# Launch interface
if __name__ == "__main__":
    interface.launch()