# NEMOtools / app/Inference_YoLOv11.py
# AndrewKof — 🚀 Update UI with LFS for images and models (commit 5b11294)
import os
import numpy as np
from ultralytics import YOLO
from PIL import Image
import torch
import yaml
# -----------------------------
# CONFIGURATION
# -----------------------------
# Input files
MODEL_PATH = "best.pt"     # trained Ultralytics weights
DATA_YAML = "data.yaml"    # dataset config; its "names" entry supplies class names
IMG_PATH = "chroms.jpeg"   # image to classify

# Runtime settings
IMG_SIZE = 512             # side length (pixels) of the square model input
USE_GPU = True             # use CUDA when torch reports it is available
# -----------------------------
# GPU/CPU
# -----------------------------
if USE_GPU and torch.cuda.is_available():
    # Limit this process's CUDA caching allocator to 50% of GPU 0's memory
    # (presumably so other work can share the GPU) — adjust if needed.
    torch.cuda.set_per_process_memory_fraction(0.5, device=0)
# No CPU branch is needed: constructing ``torch.device('cpu')`` and discarding
# the result has no effect on where computation runs. For Ultralytics, the
# execution device is governed by the ``device`` argument of ``predict``.
# -----------------------------
# Load class names
# -----------------------------
# Read the dataset YAML with an explicit encoding so non-ASCII class names
# decode the same way on every platform; safe_load avoids arbitrary object
# construction from the YAML file.
with open(DATA_YAML, "r", encoding="utf-8") as f:
    data = yaml.safe_load(f)
# Indexed by integer species index. The original comment calls this an
# alphabetical list — NOTE(review): confirm against data.yaml.
class_names = data["names"]
num_classes = len(class_names)
# -----------------------------
# Load YOLO Model
# -----------------------------
# Build the Ultralytics model from the weights file and report which one was used.
model = YOLO(MODEL_PATH)
print("Loaded model: {}".format(MODEL_PATH))
# -----------------------------
# Build numeric folder → class index mapping
# -----------------------------
# model.names pairs each YOLO output index with a class-name string. Those
# strings are assumed to be numeric (presumably the training folder names);
# each is converted to an int species index. int() raises ValueError for
# any non-numeric name.
yolo_names = model.names
mapping = dict(zip(yolo_names.keys(), map(int, yolo_names.values())))
# -----------------------------
# Load & preprocess image
# -----------------------------
# Open inside a context manager so the file handle is released right away;
# convert() loads the pixels and returns a new image, so ``img`` remains
# valid after the source file is closed.
with Image.open(IMG_PATH) as src:
    img = src.convert('RGB')
# Resize to a square IMG_SIZE x IMG_SIZE input. This does not preserve the
# aspect ratio — NOTE(review): confirm it matches the training preprocessing.
img = img.resize((IMG_SIZE, IMG_SIZE))
# -----------------------------
# Inference
# -----------------------------
# Pass the device explicitly so USE_GPU is honored: when no device is given,
# Ultralytics selects CUDA automatically whenever it is available.
results = model.predict(
    source=img,
    imgsz=IMG_SIZE,
    device=0 if (USE_GPU and torch.cuda.is_available()) else "cpu",
    verbose=False,
)
# Only classification models populate ``probs``; fail with a clear message
# instead of an AttributeError on None.
if results[0].probs is None:
    raise RuntimeError(
        f"{MODEL_PATH} did not return class probabilities; "
        "a YOLO classification model is required."
    )
# One score per YOLO output index (Ultralytics classification output),
# moved to host memory as a NumPy array.
probs = results[0].probs.data.cpu().numpy()
# -----------------------------
# TOP-5 OUTPUT WITH REAL NAMES
# -----------------------------
# Indices of the (up to) five highest scores, highest first: argsort is
# ascending, so take the last five and reverse them.
top5_yolo_idx = np.argsort(probs)[-5:][::-1]
top5_probs = probs[top5_yolo_idx] * 100  # scale scores to percent
print("\n🎣 Top-5 Predictions:")
for rank, (yolo_idx, pct) in enumerate(zip(top5_yolo_idx, top5_probs), 1):
    # YOLO's output index differs from the species index, so translate it first.
    species_idx = mapping[int(yolo_idx)]
    species_name = class_names[species_idx]
    print(f"{rank}. {species_name} — {pct:.2f}%")