File size: 1,419 Bytes
855ba03
1208a8f
 
 
 
855ba03
 
 
 
 
 
 
 
 
 
 
 
 
1208a8f
855ba03
1208a8f
 
855ba03
1208a8f
 
855ba03
1208a8f
855ba03
1208a8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855ba03
 
 
 
 
1208a8f
 
855ba03
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os

# Suppress the Ultralytics "config dir not writable" warning by pointing the
# library at a local, writable directory.
# Ultralytics resolves its settings directory from YOLO_CONFIG_DIR while the
# package is being imported, so the variable must be set (and the directory
# created) BEFORE `ultralytics` is imported below; assigning it afterwards has
# no effect on the already-imported package.
os.environ["YOLO_CONFIG_DIR"] = "UltralyticsConfig"
os.makedirs("UltralyticsConfig", exist_ok=True)

import gradio as gr  # noqa: E402  (deliberately imported after env setup)
from ultralytics import YOLO  # noqa: E402
from PIL import Image  # noqa: E402

# Path to YOLO model weights (relative to the process working directory)
model_path = "yolo_training/sickle_cls_model/weights/best.pt"

# Fail fast with an actionable message instead of letting the model loader
# raise a less specific error for a missing weights file.
if not os.path.exists(model_path):
    raise FileNotFoundError(
        f"Model file not found at {model_path}. "
        "Make sure best.pt is uploaded to this path in your repo."
    )

# Load YOLO model once at import time so every request reuses it
model = YOLO(model_path)

# Class names
# NOTE(review): this order must match the model's own index -> name mapping
# (`model.names`); verify against the trained weights before indexing with it.
classes = ['sickle', 'non_sickle', 'AIN']

# Prediction function
def predict_image(img):
    """Classify one image and return the top-1 label with its confidence.

    Args:
        img: Image to classify, as delivered by the Gradio image input
            (a PIL image, since the input is declared with ``type="pil"``),
            or ``None`` when the user submits without providing an image.

    Returns:
        A ``(label, confidence)`` tuple: the class name that the model
        associates with its highest-scoring class index, and that class's
        score as a plain Python float.

    Raises:
        gr.Error: If ``img`` is ``None``; Gradio shows the message to the
            user instead of a generic failure.
    """
    # Gradio passes None when "Submit" is pressed with an empty image slot;
    # report that clearly rather than crashing inside the model call.
    if img is None:
        raise gr.Error("Please upload an image before submitting.")

    # The model is called with a list of images and returns one result per
    # image; only a single image is submitted here.
    result = model([img])[0]
    probs = result.probs

    # Take the label from the names mapping carried by the result rather than
    # from a hand-maintained list: the index -> name order is fixed by the
    # trained weights, so a separately written list can silently mislabel.
    label = result.names[probs.top1]
    confidence = probs.top1conf.item()

    return label, confidence

# Build the Gradio UI: one PIL image input, and a label plus a number output
# that receive the two values returned by predict_image.
_image_input = gr.Image(type="pil")
_result_outputs = [gr.Label(), gr.Number()]
_description = (
    "Upload a blood smear image to classify it as "
    "Sickle Cell (sickle), Non-Sickle (non_sickle), "
    "or Artifact/Impurities/Noise (AIN)."
)

interface = gr.Interface(
    fn=predict_image,
    inputs=_image_input,
    outputs=_result_outputs,
    title="Sickle Cell Classification",
    description=_description,
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()