# MRI_orientation / app.py
# Author: shimaa22
# Last commit: "Update app.py" (fa01726, verified)
import gradio as gr
from PIL import Image
from ultralytics import YOLO
import torch
# ------------------------
# Classification model: weights are loaded once, when this module is executed
# ------------------------
MODEL_PATH = "best.pt"  # relative path to the trained YOLO classification weights
model = YOLO(model=MODEL_PATH)
# Mapping from model's short labels to full orientation names
ORIENTATION_MAP = {
"ax": "axial",
"co": "coaxial",
"sa": "sagittal"
}
# ------------------------
# Prediction function
# ------------------------
def predict_orientation(image: "Image.Image | None"):
    """Classify the orientation of an uploaded MRI image.

    Args:
        image: The uploaded image as a PIL image (the interface uses
            ``gr.Image(type="pil")``), or ``None`` when the user submits
            without uploading anything.

    Returns:
        A display string ``"Orientation: <name> | Confidence: <p>"``.
        ``<name>`` is the full plane name from ``ORIENTATION_MAP``, falling
        back to the raw model label when the label is not in the map.
        ``<p>`` is the top-1 class probability rounded to two decimals.
        If no image was provided, a short prompt asking for an upload is
        returned instead.
    """
    # Guard against an empty submission: Ultralytics does not reject
    # source=None, it substitutes its default sample source, which would make
    # the UI report a prediction for an image the user never uploaded.
    if image is None:
        return "Please upload an image."

    # Single-image inference on CPU at 224 px.
    results = model.predict(source=image, device="cpu", imgsz=224, verbose=False)
    probs = results[0].probs  # classification probabilities for the one input

    # Top-1 class index and its probability (equivalent to argmax over probs.data).
    top_index = probs.top1
    original_class = model.names[top_index]  # short label, expected "ax", "co" or "sa"

    # Map to the full orientation name; unknown labels are shown as-is.
    orientation = ORIENTATION_MAP.get(original_class, original_class)
    confidence = round(float(probs.top1conf), 2)

    return f"Orientation: {orientation} | Confidence: {confidence}"
# ------------------------
# Gradio Interface
# ------------------------
iface = gr.Interface(
    fn=predict_orientation,
    # type="pil" makes Gradio hand predict_orientation a PIL image
    # (or None when the input is left empty).
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="MRI Orientation Predictor",
    description="Upload your image and the model outputs prediction and confidence.",
)

# Start the web server only when this file is run as a script
# (e.g. `python app.py`). Importing the module, for example to reuse or test
# predict_orientation, no longer launches a server as a side effect.
if __name__ == "__main__":
    iface.launch()