Upload inference.py
Browse files- inference.py +15 -15
inference.py
CHANGED
|
@@ -195,14 +195,15 @@ def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
|
|
| 195 |
# Run YOLOv8 classification (verbose=False suppresses progress bar)
|
| 196 |
results = animal_model(crop, verbose=False)
|
| 197 |
|
| 198 |
-
# Extract class names dict
|
|
|
|
| 199 |
names_dict = results[0].names
|
| 200 |
|
| 201 |
# Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
|
| 202 |
probs = results[0].probs.data.tolist()
|
| 203 |
|
| 204 |
# Build list of (class_name, confidence) tuples
|
| 205 |
-
# YOLOv8
|
| 206 |
classifications = []
|
| 207 |
for idx, class_name in names_dict.items():
|
| 208 |
confidence = probs[idx]
|
|
@@ -219,32 +220,31 @@ def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
|
|
| 219 |
|
| 220 |
def get_class_names() -> dict[str, str]:
|
| 221 |
"""
|
| 222 |
-
Get mapping of class IDs to species names from
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
Returns:
|
| 228 |
Dict mapping class ID (1-indexed string) to common name
|
| 229 |
-
Example: {"1": "
|
| 230 |
|
| 231 |
Raises:
|
| 232 |
-
|
| 233 |
-
RuntimeError: If parsing fails
|
| 234 |
"""
|
| 235 |
-
global animal_model
|
| 236 |
|
| 237 |
-
# MODEL_DIR is injected by framework (guaranteed not None)
|
| 238 |
-
# YOLOv8 models have class names built-in
|
| 239 |
-
# We'll use those directly since they should match taxonomy.csv
|
| 240 |
if animal_model is None:
|
| 241 |
raise RuntimeError("Model not loaded - call load_model() first")
|
| 242 |
|
| 243 |
try:
|
| 244 |
-
# YOLOv8 names dict: {0: "
|
| 245 |
yolo_names = animal_model.names
|
| 246 |
|
| 247 |
-
# Convert to 1-indexed dict for compatibility
|
| 248 |
class_names = {}
|
| 249 |
for idx, name in yolo_names.items():
|
| 250 |
class_id_str = str(idx + 1) # 1-indexed
|
|
@@ -253,4 +253,4 @@ def get_class_names() -> dict[str, str]:
|
|
| 253 |
return class_names
|
| 254 |
|
| 255 |
except Exception as e:
|
| 256 |
-
raise RuntimeError(f"Failed to extract class names: {e}") from e
|
|
|
|
| 195 |
# Run YOLOv8 classification (verbose=False suppresses progress bar)
|
| 196 |
results = animal_model(crop, verbose=False)
|
| 197 |
|
| 198 |
+
# Extract class names dict (YOLOv8 uses alphabetical order)
|
| 199 |
+
# Example: {0: "aardwolf", 1: "african wild cat", ..., 13: "giraffe", ...}
|
| 200 |
names_dict = results[0].names
|
| 201 |
|
| 202 |
# Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
|
| 203 |
probs = results[0].probs.data.tolist()
|
| 204 |
|
| 205 |
# Build list of (class_name, confidence) tuples
|
| 206 |
+
# Return YOLOv8's class names (which will be mapped to taxonomy IDs later)
|
| 207 |
classifications = []
|
| 208 |
for idx, class_name in names_dict.items():
|
| 209 |
confidence = probs[idx]
|
|
|
|
| 220 |
|
| 221 |
def get_class_names() -> dict[str, str]:
|
| 222 |
"""
|
| 223 |
+
Get mapping of class IDs to species names from YOLOv8 model.
|
| 224 |
|
| 225 |
+
YOLOv8 stores class names in alphabetical order internally. This function
|
| 226 |
+
extracts those names and creates a 1-indexed mapping for the JSON format.
|
| 227 |
+
|
| 228 |
+
NOTE: taxonomy.csv is NOT used here - it's only for UI taxonomy tree display.
|
| 229 |
+
The class IDs here are YOLOv8's alphabetical indices (0-based) + 1.
|
| 230 |
|
| 231 |
Returns:
|
| 232 |
Dict mapping class ID (1-indexed string) to common name
|
| 233 |
+
Example: {"1": "aardwolf", "2": "african wild cat", ..., "14": "giraffe", ...}
|
| 234 |
|
| 235 |
Raises:
|
| 236 |
+
RuntimeError: If model not loaded
|
|
|
|
| 237 |
"""
|
| 238 |
+
global animal_model
|
| 239 |
|
|
|
|
|
|
|
|
|
|
| 240 |
if animal_model is None:
|
| 241 |
raise RuntimeError("Model not loaded - call load_model() first")
|
| 242 |
|
| 243 |
try:
|
| 244 |
+
# YOLOv8 names dict (alphabetical order): {0: "aardwolf", 1: "african wild cat", ...}
|
| 245 |
yolo_names = animal_model.names
|
| 246 |
|
| 247 |
+
# Convert to 1-indexed dict for JSON compatibility
|
| 248 |
class_names = {}
|
| 249 |
for idx, name in yolo_names.items():
|
| 250 |
class_id_str = str(idx + 1) # 1-indexed
|
|
|
|
| 253 |
return class_names
|
| 254 |
|
| 255 |
except Exception as e:
|
| 256 |
+
raise RuntimeError(f"Failed to extract class names from model: {e}") from e
|