Addax-Data-Science commited on
Commit
d7437ff
·
verified ·
1 Parent(s): f8950ea

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +15 -15
inference.py CHANGED
@@ -195,14 +195,15 @@ def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
195
  # Run YOLOv8 classification (verbose=False suppresses progress bar)
196
  results = animal_model(crop, verbose=False)
197
 
198
- # Extract class names dict: {0: "porcupine", 1: "elephant", ...}
 
199
  names_dict = results[0].names
200
 
201
  # Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
202
  probs = results[0].probs.data.tolist()
203
 
204
  # Build list of (class_name, confidence) tuples
205
- # YOLOv8 names_dict is already in the correct format: {0: "porcupine", 1: "elephant", ...}
206
  classifications = []
207
  for idx, class_name in names_dict.items():
208
  confidence = probs[idx]
@@ -219,32 +220,31 @@ def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
219
 
220
  def get_class_names() -> dict[str, str]:
221
  """
222
- Get mapping of class IDs to species names from taxonomy.csv.
223
 
224
- Reads taxonomy.csv from the model directory and extracts the model_class
225
- (common name) for each species.
 
 
 
226
 
227
  Returns:
228
  Dict mapping class ID (1-indexed string) to common name
229
- Example: {"1": "porcupine", "2": "elephant", ...}
230
 
231
  Raises:
232
- FileNotFoundError: If taxonomy.csv not found
233
- RuntimeError: If parsing fails
234
  """
235
- global animal_model, MODEL_DIR
236
 
237
- # MODEL_DIR is injected by framework (guaranteed not None)
238
- # YOLOv8 models have class names built-in
239
- # We'll use those directly since they should match taxonomy.csv
240
  if animal_model is None:
241
  raise RuntimeError("Model not loaded - call load_model() first")
242
 
243
  try:
244
- # YOLOv8 names dict: {0: "porcupine", 1: "elephant", ...}
245
  yolo_names = animal_model.names
246
 
247
- # Convert to 1-indexed dict for compatibility
248
  class_names = {}
249
  for idx, name in yolo_names.items():
250
  class_id_str = str(idx + 1) # 1-indexed
@@ -253,4 +253,4 @@ def get_class_names() -> dict[str, str]:
253
  return class_names
254
 
255
  except Exception as e:
256
- raise RuntimeError(f"Failed to extract class names: {e}") from e
 
195
  # Run YOLOv8 classification (verbose=False suppresses progress bar)
196
  results = animal_model(crop, verbose=False)
197
 
198
+ # Extract class names dict (YOLOv8 uses alphabetical order)
199
+ # Example: {0: "aardwolf", 1: "african wild cat", ..., 13: "giraffe", ...}
200
  names_dict = results[0].names
201
 
202
  # Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
203
  probs = results[0].probs.data.tolist()
204
 
205
  # Build list of (class_name, confidence) tuples
206
+ # Return YOLOv8's class names (which will be mapped to taxonomy IDs later)
207
  classifications = []
208
  for idx, class_name in names_dict.items():
209
  confidence = probs[idx]
 
220
 
221
  def get_class_names() -> dict[str, str]:
222
  """
223
+ Get mapping of class IDs to species names from YOLOv8 model.
224
 
225
+ YOLOv8 stores class names in alphabetical order internally. This function
226
+ extracts those names and creates a 1-indexed mapping for the JSON format.
227
+
228
+ NOTE: taxonomy.csv is NOT used here - it's only for UI taxonomy tree display.
229
+ The class IDs here are YOLOv8's alphabetical indices (0-based) + 1.
230
 
231
  Returns:
232
  Dict mapping class ID (1-indexed string) to common name
233
+ Example: {"1": "aardwolf", "2": "african wild cat", ..., "14": "giraffe", ...}
234
 
235
  Raises:
236
+ RuntimeError: If model not loaded
 
237
  """
238
+ global animal_model
239
 
 
 
 
240
  if animal_model is None:
241
  raise RuntimeError("Model not loaded - call load_model() first")
242
 
243
  try:
244
+ # YOLOv8 names dict (alphabetical order): {0: "aardwolf", 1: "african wild cat", ...}
245
  yolo_names = animal_model.names
246
 
247
+ # Convert to 1-indexed dict for JSON compatibility
248
  class_names = {}
249
  for idx, name in yolo_names.items():
250
  class_id_str = str(idx + 1) # 1-indexed
 
253
  return class_names
254
 
255
  except Exception as e:
256
+ raise RuntimeError(f"Failed to extract class names from model: {e}") from e