"""Gradio demo that classifies the orientation plane of an MRI image.

A YOLO classification checkpoint (``best.pt``) predicts one of the short
class labels ``"ax"``, ``"co"`` or ``"sa"``; the app maps that label to a
readable plane name and reports the top-1 confidence.
"""

from typing import Optional

import gradio as gr
import torch
from PIL import Image
from ultralytics import YOLO

# ------------------------
# Load the YOLOv11 classification model
# ------------------------
MODEL_PATH = "best.pt"
# Loaded once at import time so every request reuses the same weights.
model = YOLO(MODEL_PATH)

# Mapping from the model's short class labels to full orientation names.
# NOTE: the three standard MRI imaging planes are axial, coronal and
# sagittal, so "co" is taken to mean coronal — confirm against the labels
# used when training best.pt.
ORIENTATION_MAP = {
    "ax": "axial",
    "co": "coronal",
    "sa": "sagittal",
}


# ------------------------
# Prediction function
# ------------------------
def predict_orientation(image: Optional[Image.Image]) -> str:
    """Predict the MRI orientation plane of *image*.

    Args:
        image: Image from the Gradio ``Image`` component, or ``None`` when
            the user submits without uploading anything.

    Returns:
        ``"Orientation: <plane> | Confidence: <p>"``, where ``<p>`` is the
        top-1 class probability rounded to two decimals. When *image* is
        ``None``, a prompt asking the user to upload an image is returned.
    """
    # Gradio passes None when nothing was uploaded. Ultralytics treats
    # source=None as "use the bundled sample assets", which would yield a
    # meaningless prediction, so short-circuit instead.
    if image is None:
        return "Please upload an image."

    # Run inference on CPU with a 224-pixel input size.
    results = model.predict(source=image, device="cpu", imgsz=224, verbose=False)

    # Class probabilities for the single input image; pick the most likely.
    probs = results[0].probs.data.cpu()
    pred = torch.argmax(probs)

    # Short class label stored in the checkpoint (e.g. "ax", "co", "sa").
    original_class = model.names[pred.item()]

    # Fall back to the raw label if the checkpoint has an unmapped class.
    orientation = ORIENTATION_MAP.get(original_class, original_class)

    # Top-1 probability, rounded for display.
    confidence = round(probs[pred].item(), 2)

    return f"Orientation: {orientation} | Confidence: {confidence}"


# ------------------------
# Gradio Interface
# ------------------------
iface = gr.Interface(
    fn=predict_orientation,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="MRI Orientation Predictor",
    description="Upload your image and the model outputs prediction and confidence.",
)

# Only start the web server when run as a script, so importing this module
# (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    iface.launch()