Upload inference.py
Browse files- inference.py +8 -14
inference.py
CHANGED
|
@@ -74,9 +74,8 @@ def load_model() -> None:
|
|
| 74 |
"""
|
| 75 |
global animal_model, MODEL_PATH
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
if not MODEL_PATH.exists():
|
| 81 |
raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
|
| 82 |
|
|
@@ -181,8 +180,8 @@ def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
|
|
| 181 |
crop: Cropped and preprocessed PIL Image
|
| 182 |
|
| 183 |
Returns:
|
| 184 |
-
List of (
|
| 185 |
-
Example: [("
|
| 186 |
|
| 187 |
Raises:
|
| 188 |
RuntimeError: If model not loaded or inference fails
|
|
@@ -202,15 +201,12 @@ def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
|
|
| 202 |
# Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
|
| 203 |
probs = results[0].probs.data.tolist()
|
| 204 |
|
| 205 |
-
# Build list of (
|
| 206 |
-
#
|
| 207 |
classifications = []
|
| 208 |
for idx, class_name in names_dict.items():
|
| 209 |
-
# YOLOv8 uses 0-based indexing, but we need 1-based for compatibility
|
| 210 |
-
# with the taxonomy.csv and expected JSON output
|
| 211 |
-
class_id_str = str(idx + 1) # Convert 0-indexed to 1-indexed
|
| 212 |
confidence = probs[idx]
|
| 213 |
-
classifications.append((
|
| 214 |
|
| 215 |
# Sort by confidence descending (already sorted by YOLOv8, but ensure it)
|
| 216 |
classifications.sort(key=lambda x: x[1], reverse=True)
|
|
@@ -238,9 +234,7 @@ def get_class_names() -> dict[str, str]:
|
|
| 238 |
"""
|
| 239 |
global animal_model, MODEL_DIR
|
| 240 |
|
| 241 |
-
|
| 242 |
-
raise RuntimeError("MODEL_DIR not set - must be injected by framework")
|
| 243 |
-
|
| 244 |
# YOLOv8 models have class names built-in
|
| 245 |
# We'll use those directly since they should match taxonomy.csv
|
| 246 |
if animal_model is None:
|
|
|
|
| 74 |
"""
|
| 75 |
global animal_model, MODEL_PATH
|
| 76 |
|
| 77 |
+
# MODEL_PATH is injected by framework before this function is called
|
| 78 |
+
# Check that the path exists (framework guarantees it's not None)
|
|
|
|
| 79 |
if not MODEL_PATH.exists():
|
| 80 |
raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
|
| 81 |
|
|
|
|
| 180 |
crop: Cropped and preprocessed PIL Image
|
| 181 |
|
| 182 |
Returns:
|
| 183 |
+
List of (class_name, confidence) tuples for ALL classes, sorted by confidence.
|
| 184 |
+
Example: [("giraffe", 0.99985), ("cattle", 0.00003), ...]
|
| 185 |
|
| 186 |
Raises:
|
| 187 |
RuntimeError: If model not loaded or inference fails
|
|
|
|
| 201 |
# Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
|
| 202 |
probs = results[0].probs.data.tolist()
|
| 203 |
|
| 204 |
+
# Build list of (class_name, confidence) tuples
|
| 205 |
+
# YOLOv8 names_dict is already in the correct format: {0: "porcupine", 1: "elephant", ...}
|
| 206 |
classifications = []
|
| 207 |
for idx, class_name in names_dict.items():
|
|
|
|
|
|
|
|
|
|
| 208 |
confidence = probs[idx]
|
| 209 |
+
classifications.append((class_name, confidence))
|
| 210 |
|
| 211 |
# Sort by confidence descending (already sorted by YOLOv8, but ensure it)
|
| 212 |
classifications.sort(key=lambda x: x[1], reverse=True)
|
|
|
|
| 234 |
"""
|
| 235 |
global animal_model, MODEL_DIR
|
| 236 |
|
| 237 |
+
# MODEL_DIR is injected by framework (guaranteed not None)
|
|
|
|
|
|
|
| 238 |
# YOLOv8 models have class names built-in
|
| 239 |
# We'll use those directly since they should match taxonomy.csv
|
| 240 |
if animal_model is None:
|