# Source: sickelcellcdcd / app.py (Hugging Face Space by rjaditya)
# Last commit: "Update app.py" (855ba03, verified)
import os

# Ultralytics resolves its user-config/settings directory (and emits the
# "not writeable" warning) while the `ultralytics` package is being
# imported, so YOLO_CONFIG_DIR must already be in the environment -- and
# the directory must exist -- BEFORE that import. Setting it afterwards
# has no effect on the warning it is meant to suppress.
_YOLO_CONFIG_DIR = "UltralyticsConfig"
os.environ["YOLO_CONFIG_DIR"] = _YOLO_CONFIG_DIR
os.makedirs(_YOLO_CONFIG_DIR, exist_ok=True)

import gradio as gr  # noqa: E402
from PIL import Image  # noqa: E402,F401
from ultralytics import YOLO  # noqa: E402  (must follow the env setup above)

# Path to the trained YOLO classification weights, relative to the
# working directory the app is launched from.
model_path = "yolo_training/sickle_cls_model/weights/best.pt"

# Fail fast with an actionable message if the weights were not deployed.
# isfile() (rather than exists()) also rejects a directory at this path.
if not os.path.isfile(model_path):
    raise FileNotFoundError(
        f"Model file not found at {model_path}. "
        "Make sure best.pt is uploaded to this path in your repo."
    )

# Load the model once at module import so every prediction reuses it.
model = YOLO(model_path)

# Class names in the model's own index order. They are read from the
# weights instead of being hard-coded: YOLO assigns class indices from the
# training data (folder datasets are indexed in sorted directory-name
# order), so a hand-written list such as ['sickle', 'non_sickle', 'AIN']
# can silently map predictions to the wrong label.
classes = [model.names[i] for i in range(len(model.names))]
# Prediction function
def predict_image(img):
    """Classify a single blood-smear image with the YOLO classifier.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input; Gradio
            passes ``None`` when the form is submitted without an upload.

    Returns:
        tuple[str, float]: the top-1 class name and its confidence.

    Raises:
        gr.Error: if no image was provided, or if the loaded model did not
            produce classification probabilities.
    """
    if img is None:
        # Show a readable message in the UI instead of a traceback from
        # deep inside the YOLO preprocessing code.
        raise gr.Error("Please upload an image first.")

    # YOLO accepts a list of sources and returns one result per source.
    result = model([img])[0]

    probs = result.probs
    if probs is None:
        # Only classification models populate `probs`.
        raise gr.Error("The loaded model is not a classification model.")

    top1_idx = int(probs.top1)
    top1_conf = float(probs.top1conf.item())
    # Use the index->name mapping stored with the model rather than a
    # hand-written list: YOLO assigns indices from the training data, so a
    # hard-coded order can silently mislabel predictions.
    top1_class = result.names[top1_idx]
    return top1_class, top1_conf
# Build the Gradio UI: one uploaded image in, (class label, confidence) out.
# Explicit labels replace Gradio's generic "output 0" / "output 1" captions.
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Blood smear image"),
    outputs=[
        gr.Label(label="Predicted class"),
        gr.Number(label="Confidence"),
    ],
    title="Sickle Cell Classification",
    description=(
        "Upload a blood smear image to classify it as "
        "Sickle Cell (sickle), Non-Sickle (non_sickle), "
        "or Artifact/Impurities/Noise (AIN)."
    ),
)
def _main():
    """Start the Gradio web app defined by ``interface``."""
    interface.launch()


# Launch only when run as a script, not when this module is imported.
if __name__ == "__main__":
    _main()