leonelgv commited on
Commit
179346e
·
verified ·
1 Parent(s): 0571425

Add Main inference script for pollinator classification

Browse files
Files changed (1) hide show
  1. pollinator_classifier.py +128 -0
pollinator_classifier.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 🔬 Clasificador de Insectos Polinizadores - Versión de Producción
4
+ Precisión alcanzada: 92.07%
5
+ Modelo: YOLOv8 Nano
6
+ """
7
+
8
+ from ultralytics import YOLO
9
+ import sys
10
+ import os
11
+ from pathlib import Path
12
+
13
+ class PollinatorClassifier:
14
+ def __init__(self, model_path="pollinator_results/nano_quick/weights/best.pt"):
15
+ """Inicializar el clasificador"""
16
+ try:
17
+ self.model = YOLO(model_path)
18
+ self.classes = [
19
+ 'Acmaeodera Flavomarginata', 'Acromyrmex Octospinosus',
20
+ 'Adelpha Basiloides', 'Adelpha Iphicleola', 'Aedes Aegypti',
21
+ 'Agrius Cingulata', 'Anaea Aidea', 'Anartia fatima',
22
+ 'Anartia jatrophae', 'Anoplolepis Gracilipes'
23
+ ]
24
+ print("🔬 Clasificador de Insectos Polinizadores v1.0")
25
+ print(f"✅ Modelo cargado con 92.07% de precisión")
26
+ print(f"🏷️ {len(self.classes)} clases disponibles")
27
+
28
+ except Exception as e:
29
+ print(f"❌ Error cargando modelo: {e}")
30
+ sys.exit(1)
31
+
32
+ def classify(self, image_path):
33
+ """Clasificar una imagen de insecto"""
34
+
35
+ if not os.path.exists(image_path):
36
+ print(f"❌ Imagen no encontrada: {image_path}")
37
+ return None
38
+
39
+ # Predicción
40
+ results = self.model(image_path, verbose=False)
41
+ probs = results[0].probs
42
+
43
+ # Obtener predicción principal
44
+ top_class_idx = probs.top1
45
+ confidence = probs.top1conf.item() * 100
46
+ predicted_class = self.classes[top_class_idx]
47
+
48
+ print(f"\n🔍 Imagen: {os.path.basename(image_path)}")
49
+ print(f"🎯 Predicción: {predicted_class}")
50
+ print(f"📊 Confianza: {confidence:.1f}%")
51
+
52
+ # Top 3 predicciones
53
+ print(f"\n📋 Top 3 predicciones:")
54
+ for i in range(min(3, len(probs.top5))):
55
+ idx = probs.top5[i]
56
+ conf = probs.top5conf[i].item() * 100
57
+ class_name = self.classes[idx]
58
+ emoji = "🥇" if i == 0 else "🥈" if i == 1 else "🥉"
59
+ print(f" {emoji} {class_name}: {conf:.1f}%")
60
+
61
+ return predicted_class, confidence
62
+
63
+ def classify_batch(self, folder_path):
64
+ """Clasificar múltiples imágenes en una carpeta"""
65
+
66
+ folder = Path(folder_path)
67
+ if not folder.exists():
68
+ print(f"❌ Carpeta no encontrada: {folder_path}")
69
+ return
70
+
71
+ # Buscar imágenes
72
+ image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
73
+ images = []
74
+ for ext in image_extensions:
75
+ images.extend(list(folder.glob(ext)))
76
+
77
+ if not images:
78
+ print("❌ No se encontraron imágenes")
79
+ return
80
+
81
+ print(f"🔍 Clasificando {len(images)} imágenes...")
82
+ print("-" * 60)
83
+
84
+ results = []
85
+ for img_path in images:
86
+ pred_class, confidence = self.classify(str(img_path))
87
+ if pred_class:
88
+ results.append({
89
+ 'imagen': img_path.name,
90
+ 'prediccion': pred_class,
91
+ 'confianza': confidence
92
+ })
93
+
94
+ return results
95
+
96
+ def main():
97
+ """Función principal"""
98
+ classifier = PollinatorClassifier()
99
+
100
+ if len(sys.argv) < 2:
101
+ # Modo interactivo
102
+ print("\n🎯 MODO INTERACTIVO")
103
+ print("Opciones:")
104
+ print("1. Clasificar una imagen")
105
+ print("2. Clasificar carpeta de imágenes")
106
+
107
+ choice = input("\nSelecciona opción (1 o 2): ")
108
+
109
+ if choice == "1":
110
+ image_path = input("Ruta de la imagen: ")
111
+ classifier.classify(image_path)
112
+ elif choice == "2":
113
+ folder_path = input("Ruta de la carpeta: ")
114
+ classifier.classify_batch(folder_path)
115
+ else:
116
+ print("Opción inválida")
117
+ else:
118
+ # Modo comando
119
+ path = sys.argv[1]
120
+ if os.path.isfile(path):
121
+ classifier.classify(path)
122
+ elif os.path.isdir(path):
123
+ classifier.classify_batch(path)
124
+ else:
125
+ print(f"❌ Ruta inválida: {path}")
126
+
127
+ if __name__ == "__main__":
128
+ main()