Addax-Data-Science commited on
Commit
f8950ea
·
verified ·
1 Parent(s): 343862b

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +8 -14
inference.py CHANGED
@@ -74,9 +74,8 @@ def load_model() -> None:
74
  """
75
  global animal_model, MODEL_PATH
76
 
77
- if MODEL_PATH is None:
78
- raise RuntimeError("MODEL_PATH not set - must be injected by framework")
79
-
80
  if not MODEL_PATH.exists():
81
  raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
82
 
@@ -181,8 +180,8 @@ def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
181
  crop: Cropped and preprocessed PIL Image
182
 
183
  Returns:
184
- List of (class_id, confidence) tuples for ALL classes, sorted by confidence.
185
- Example: [("14", 0.99985), ("7", 0.00003), ...]
186
 
187
  Raises:
188
  RuntimeError: If model not loaded or inference fails
@@ -202,15 +201,12 @@ def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
202
  # Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
203
  probs = results[0].probs.data.tolist()
204
 
205
- # Build list of (class_id, confidence) tuples
206
- # Class IDs are 0-indexed in YOLOv8 but we output as strings
207
  classifications = []
208
  for idx, class_name in names_dict.items():
209
- # YOLOv8 uses 0-based indexing, but we need 1-based for compatibility
210
- # with the taxonomy.csv and expected JSON output
211
- class_id_str = str(idx + 1) # Convert 0-indexed to 1-indexed
212
  confidence = probs[idx]
213
- classifications.append((class_id_str, confidence))
214
 
215
  # Sort by confidence descending (already sorted by YOLOv8, but ensure it)
216
  classifications.sort(key=lambda x: x[1], reverse=True)
@@ -238,9 +234,7 @@ def get_class_names() -> dict[str, str]:
238
  """
239
  global animal_model, MODEL_DIR
240
 
241
- if MODEL_DIR is None:
242
- raise RuntimeError("MODEL_DIR not set - must be injected by framework")
243
-
244
  # YOLOv8 models have class names built-in
245
  # We'll use those directly since they should match taxonomy.csv
246
  if animal_model is None:
 
74
  """
75
  global animal_model, MODEL_PATH
76
 
77
+ # MODEL_PATH is injected by framework before this function is called
78
+ # Check that the path exists (framework guarantees it's not None)
 
79
  if not MODEL_PATH.exists():
80
  raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
81
 
 
180
  crop: Cropped and preprocessed PIL Image
181
 
182
  Returns:
183
+ List of (class_name, confidence) tuples for ALL classes, sorted by confidence.
184
+ Example: [("giraffe", 0.99985), ("cattle", 0.00003), ...]
185
 
186
  Raises:
187
  RuntimeError: If model not loaded or inference fails
 
201
  # Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
202
  probs = results[0].probs.data.tolist()
203
 
204
+ # Build list of (class_name, confidence) tuples
205
+ # YOLOv8 names_dict is already in the correct format: {0: "porcupine", 1: "elephant", ...}
206
  classifications = []
207
  for idx, class_name in names_dict.items():
 
 
 
208
  confidence = probs[idx]
209
+ classifications.append((class_name, confidence))
210
 
211
  # Sort by confidence descending (already sorted by YOLOv8, but ensure it)
212
  classifications.sort(key=lambda x: x[1], reverse=True)
 
234
  """
235
  global animal_model, MODEL_DIR
236
 
237
+ # MODEL_DIR is injected by framework (guaranteed not None)
 
 
238
  # YOLOv8 models have class names built-in
239
  # We'll use those directly since they should match taxonomy.csv
240
  if animal_model is None: