# File size: 1,997 Bytes
# 5b11294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import numpy as np
from ultralytics import YOLO
from PIL import Image
import torch
import yaml

# -----------------------------
# CONFIGURATION
# -----------------------------
MODEL_PATH: str = "best.pt"    # trained YOLO weights passed to YOLO()
DATA_YAML: str = "data.yaml"   # dataset config; its "names" entry is read below
IMG_PATH: str = "chroms.jpeg"  # image to classify
IMG_SIZE: int = 512            # square side (px) for resizing and inference
USE_GPU: bool = True           # whether to use CUDA when it is available

# -----------------------------
# GPU/CPU
# -----------------------------
# When CUDA is requested and present, cap this process at 50% of GPU 0's
# memory.
# There is deliberately no CPU branch: a bare `torch.device('cpu')` only
# constructs a device object and discards it, so it selects nothing.
# NOTE: this block only limits memory; the device the model actually runs on
# is whatever is (or is not) passed to model.predict().
if USE_GPU and torch.cuda.is_available():
    torch.cuda.set_per_process_memory_fraction(0.5, device=0)

# -----------------------------
# Load class names
# -----------------------------
# Decode the dataset config as UTF-8 explicitly: the default text encoding is
# locale-dependent, so non-ASCII class names could otherwise mis-decode or
# raise UnicodeDecodeError on some platforms. safe_load avoids constructing
# arbitrary Python objects from the file.
with open(DATA_YAML, "r", encoding="utf-8") as f:
    data = yaml.safe_load(f)

# Species names, looked up below as class_names[species_idx].
# NOTE(review): Ultralytics data.yaml may store `names` either as a list or
# as an {index: name} dict; both support integer indexing here — confirm.
class_names = data["names"]  # alphabetical list
num_classes = len(class_names)

# -----------------------------
# Load YOLO Model
# -----------------------------
model = YOLO(MODEL_PATH)
print(f"Loaded model: {MODEL_PATH}")

# -----------------------------
# Build numeric folder → class index mapping
# -----------------------------
# model.names maps each YOLO output index to the label the model was trained
# with. Those labels are expected to be numeric strings giving the species'
# position in class_names; int() raises ValueError for any other label.
yolo_names = model.names
mapping = dict(zip(yolo_names.keys(), map(int, yolo_names.values())))

# -----------------------------
# Load & preprocess image
# -----------------------------
# Open through a context manager so the file handle is always released.
# Pillow leaves multi-frame files (e.g. GIF/TIFF) open after loading.
# convert() returns a new in-memory image, so `img` stays usable after the
# source file is closed.
with Image.open(IMG_PATH) as src:
    img = src.convert('RGB')

# Squash to an IMG_SIZE x IMG_SIZE square (aspect ratio is not preserved).
# NOTE(review): model.predict() also resizes to imgsz; confirm this
# pre-squash matches how the training images were prepared.
img = img.resize((IMG_SIZE, IMG_SIZE))

# -----------------------------
# Inference
# -----------------------------
# Pass the device explicitly. When `device` is omitted, Ultralytics uses CUDA
# whenever it is available, which would silently ignore USE_GPU = False.
device = 0 if USE_GPU and torch.cuda.is_available() else "cpu"
results = model.predict(
    source=img, imgsz=IMG_SIZE, device=device, verbose=False
)

# `probs` is None for non-classification (e.g. detection) models. Fail with a
# clear message instead of an AttributeError on None.
if results[0].probs is None:
    raise RuntimeError(
        f"{MODEL_PATH} returned no class probabilities; "
        "a YOLO classification model is required"
    )
# One probability per class, indexed by YOLO output index.
probs = results[0].probs.data.cpu().numpy()

# -----------------------------
# TOP-5 OUTPUT WITH REAL NAMES
# -----------------------------
# argsort is ascending: take the last five indices and reverse them so the
# most likely class comes first. Slicing also copes with fewer than 5 classes.
top5_yolo_idx = np.argsort(probs)[-5:][::-1]
top5_probs = probs[top5_yolo_idx] * 100  # as percentages

print("\n🎣 Top-5 Predictions:")
for rank, (yolo_idx, pct) in enumerate(zip(top5_yolo_idx, top5_probs), 1):
    species_idx = mapping[int(yolo_idx)]  # YOLO output index -> species index
    species_name = class_names[species_idx]
    # Keep the name and the score visually separate, e.g. "1. <name>: 93.21%".
    print(f"{rank}. {species_name}: {pct:.2f}%")