kaburia committed on
Commit
a7c5bc0
·
1 Parent(s): e2110e7

Adjust confidence thresholds: lower detection threshold (0.2 → 0.1), add a 0.5 classification confidence cutoff, and add a geofencing toggle

Browse files
Files changed (1) hide show
  1. app.py +71 -38
app.py CHANGED
@@ -1,7 +1,5 @@
1
- # import os
2
  import json
3
  import io
4
- # import cv2
5
  import numpy as np
6
  import torch
7
  from fastapi import FastAPI, File, UploadFile, HTTPException
@@ -9,27 +7,34 @@ from PIL import Image
9
  from ultralytics import YOLO
10
  from torchvision import transforms
11
 
12
- # Import your utils (ensure geofencing_utils.py is in the same folder)
13
  from geofencing_utils import geofence_animal_classification
14
 
15
  app = FastAPI()
16
 
17
- # --- CONFIGURATION ---
18
- IMG_SIZE_DET = 1280
19
- IMG_SIZE_CLF = 480
20
- CONF_THRESHOLD = 0.2
21
- MD_CLASSES = {0: 'animal', 1: 'person', 2: 'vehicle'}
 
 
 
 
 
 
 
 
22
 
23
  # Global Resources (Loaded on Startup)
24
  RESOURCES = {}
25
 
26
-
27
  @app.on_event("startup")
28
  async def load_models():
29
  """Loads models into memory once when the Space starts"""
30
  print("Loading models...")
31
 
32
- # PATHS (Upload your models to the Space root folder)
33
  model_det_path = "species-space/MDV6-yolov9-e-1280.pt"
34
  model_clf_path = "species-space/always_crop_99710272_22x8_v12_epoch_00148.pt"
35
  labels_path = "species-space/always_crop_99710272_22x8_v12_epoch_00148.labels.txt"
@@ -44,7 +49,7 @@ async def load_models():
44
 
45
  # Load Transforms
46
  RESOURCES['transform'] = transforms.Compose([
47
- transforms.Resize((IMG_SIZE_CLF, IMG_SIZE_CLF)),
48
  transforms.ToTensor()
49
  ])
50
 
@@ -53,11 +58,16 @@ async def load_models():
53
  lines = [line.strip() for line in f.readlines()]
54
  RESOURCES['labels_list'] = lines
55
  RESOURCES['taxonomy_map'] = {i: line for i, line in enumerate(lines)}
 
 
 
 
 
 
 
 
56
 
57
- with open(geofence_path, 'r') as f:
58
- RESOURCES['geofence_map'] = json.load(f)
59
-
60
- print("Models loaded successfully.")
61
 
62
  @app.post("/predict")
63
  async def predict(
@@ -69,14 +79,13 @@ async def predict(
69
  try:
70
  contents = await file.read()
71
  image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
72
- image_np = np.array(image_pil)
73
- # Convert RGB to BGR for OpenCV/YOLO if needed, though Ultralytics handles PIL
74
  except Exception as e:
75
  raise HTTPException(status_code=400, detail="Invalid image file")
76
 
77
  # 2. Detect
78
  detector = RESOURCES['detector']
79
- results = detector(image_pil, imgsz=IMG_SIZE_DET, conf=CONF_THRESHOLD, verbose=False)[0]
 
80
 
81
  detections = []
82
 
@@ -87,7 +96,6 @@ async def predict(
87
  classifier = RESOURCES['classifier']
88
  transform = RESOURCES['transform']
89
 
90
- # Buffers for batch classification
91
  crops_tensors = []
92
  crop_indices = [] # Index in 'detections' list
93
 
@@ -95,7 +103,7 @@ async def predict(
95
  cls_id = int(box.cls[0].item())
96
  conf = float(box.conf[0].item())
97
  xyxy = box.xyxy[0].tolist()
98
- label_det = MD_CLASSES.get(cls_id, 'unknown')
99
 
100
  det_obj = {
101
  "id": i + 1,
@@ -107,8 +115,6 @@ async def predict(
107
  if cls_id == 0: # Animal
108
  # Crop
109
  crop = image_pil.crop((int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])))
110
-
111
- # Transform & Permute (CHW -> HWC)
112
  t_crop = transform(crop).permute(1, 2, 0)
113
 
114
  crops_tensors.append(t_crop)
@@ -124,23 +130,34 @@ async def predict(
124
 
125
  # 4. Classify Animals
126
  if crops_tensors:
127
- # Stack
128
- batch = torch.stack(crops_tensors) # No .to(DEVICE) needed, default is CPU
129
 
130
  with torch.no_grad():
131
  logits = classifier(batch)
132
  probs = torch.softmax(logits, dim=1)
 
133
  top_conf, top_idx = torch.topk(probs, k=5, dim=1)
134
 
135
  for k in range(len(batch)):
136
  det_idx = crop_indices[k]
137
 
138
- # Prepare for Geofencing
139
  labels_list = RESOURCES['labels_list']
140
- k_labels = [labels_list[top_idx[k][r].item()] for r in range(5) if top_idx[k][r] < len(labels_list)]
141
- k_scores = [top_conf[k][r].item() for r in range(5) if top_idx[k][r] < len(labels_list)]
142
 
143
- if k_labels:
 
 
 
 
 
 
 
 
 
 
 
144
  final_lbl, final_scr, src = geofence_animal_classification(
145
  labels=k_labels,
146
  scores=k_scores,
@@ -148,21 +165,37 @@ async def predict(
148
  geofence_map=RESOURCES['geofence_map'],
149
  taxonomy_map=RESOURCES['taxonomy_map']
150
  )
151
-
152
- try: species_clean = final_lbl.split(';')[-1]
153
- except: species_clean = final_lbl
154
-
155
- detections[det_idx].update({
156
- "species": final_lbl,
157
- "species_common": species_clean,
158
- "species_conf": float(final_scr),
159
- "source": src
160
- })
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
  # 5. Return JSON
163
  return {
164
  "filename": file.filename,
165
  "sensor_id": sensor_id,
 
 
 
 
166
  "total_detections": len(detections),
167
  "detections": detections
168
  }
 
 
1
  import json
2
  import io
 
3
  import numpy as np
4
  import torch
5
  from fastapi import FastAPI, File, UploadFile, HTTPException
 
7
  from ultralytics import YOLO
8
  from torchvision import transforms
9
 
10
+ # Import your utils
11
  from geofencing_utils import geofence_animal_classification
12
 
13
  app = FastAPI()
14
 
15
+ # --- 1. GLOBAL CONFIGURATION ---
16
+ class GlobalConfig:
17
+ # TOGGLES
18
+ USE_GEOFENCING = False # Set to True to enable location-based filtering
19
+
20
+ # THRESHOLDS
21
+ DET_CONF_THRESH = 0.1 # Minimum confidence to detect an object (MegaDetector)
22
+ CLF_CONF_THRESH = 0.5 # Minimum confidence to accept a species classification
23
+
24
+ # CONSTANTS
25
+ IMG_SIZE_DET = 1280
26
+ IMG_SIZE_CLF = 480
27
+ MD_CLASSES = {0: 'animal', 1: 'person', 2: 'vehicle'}
28
 
29
  # Global Resources (Loaded on Startup)
30
  RESOURCES = {}
31
 
 
32
  @app.on_event("startup")
33
  async def load_models():
34
  """Loads models into memory once when the Space starts"""
35
  print("Loading models...")
36
 
37
+ # PATHS (Ensure these match your Hugging Face Space file structure)
38
  model_det_path = "species-space/MDV6-yolov9-e-1280.pt"
39
  model_clf_path = "species-space/always_crop_99710272_22x8_v12_epoch_00148.pt"
40
  labels_path = "species-space/always_crop_99710272_22x8_v12_epoch_00148.labels.txt"
 
49
 
50
  # Load Transforms
51
  RESOURCES['transform'] = transforms.Compose([
52
+ transforms.Resize((GlobalConfig.IMG_SIZE_CLF, GlobalConfig.IMG_SIZE_CLF)),
53
  transforms.ToTensor()
54
  ])
55
 
 
58
  lines = [line.strip() for line in f.readlines()]
59
  RESOURCES['labels_list'] = lines
60
  RESOURCES['taxonomy_map'] = {i: line for i, line in enumerate(lines)}
61
+
62
+ # We load Geofence map even if disabled, in case you want to toggle it on later via code
63
+ try:
64
+ with open(geofence_path, 'r') as f:
65
+ RESOURCES['geofence_map'] = json.load(f)
66
+ except Exception as e:
67
+ print(f"Warning: Geofence file not found or invalid: {e}")
68
+ RESOURCES['geofence_map'] = None
69
 
70
+ print(f"Models loaded. Geofencing is set to: {GlobalConfig.USE_GEOFENCING}")
 
 
 
71
 
72
  @app.post("/predict")
73
  async def predict(
 
79
  try:
80
  contents = await file.read()
81
  image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
 
 
82
  except Exception as e:
83
  raise HTTPException(status_code=400, detail="Invalid image file")
84
 
85
  # 2. Detect
86
  detector = RESOURCES['detector']
87
+ # USE GLOBAL CONFIG FOR DETECTOR THRESHOLD
88
+ results = detector(image_pil, imgsz=GlobalConfig.IMG_SIZE_DET, conf=GlobalConfig.DET_CONF_THRESH, verbose=False)[0]
89
 
90
  detections = []
91
 
 
96
  classifier = RESOURCES['classifier']
97
  transform = RESOURCES['transform']
98
 
 
99
  crops_tensors = []
100
  crop_indices = [] # Index in 'detections' list
101
 
 
103
  cls_id = int(box.cls[0].item())
104
  conf = float(box.conf[0].item())
105
  xyxy = box.xyxy[0].tolist()
106
+ label_det = GlobalConfig.MD_CLASSES.get(cls_id, 'unknown')
107
 
108
  det_obj = {
109
  "id": i + 1,
 
115
  if cls_id == 0: # Animal
116
  # Crop
117
  crop = image_pil.crop((int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])))
 
 
118
  t_crop = transform(crop).permute(1, 2, 0)
119
 
120
  crops_tensors.append(t_crop)
 
130
 
131
  # 4. Classify Animals
132
  if crops_tensors:
133
+ batch = torch.stack(crops_tensors)
 
134
 
135
  with torch.no_grad():
136
  logits = classifier(batch)
137
  probs = torch.softmax(logits, dim=1)
138
+ # Get Top 5 for Geofencing logic
139
  top_conf, top_idx = torch.topk(probs, k=5, dim=1)
140
 
141
  for k in range(len(batch)):
142
  det_idx = crop_indices[k]
143
 
144
+ # Prepare data
145
  labels_list = RESOURCES['labels_list']
146
+ k_labels = []
147
+ k_scores = []
148
 
149
+ for r in range(5):
150
+ idx = top_idx[k][r].item()
151
+ if idx < len(labels_list):
152
+ k_labels.append(labels_list[idx])
153
+ k_scores.append(top_conf[k][r].item())
154
+
155
+ if not k_labels:
156
+ continue
157
+
158
+ # --- LOGIC SWITCH BASED ON GLOBAL CONFIG ---
159
+ if GlobalConfig.USE_GEOFENCING and RESOURCES['geofence_map'] is not None:
160
+ # A. USE GEOFENCING
161
  final_lbl, final_scr, src = geofence_animal_classification(
162
  labels=k_labels,
163
  scores=k_scores,
 
165
  geofence_map=RESOURCES['geofence_map'],
166
  taxonomy_map=RESOURCES['taxonomy_map']
167
  )
168
+ else:
169
+ # B. RAW MODEL OUTPUT (Top 1)
170
+ final_lbl = k_labels[0]
171
+ final_scr = k_scores[0]
172
+ src = "raw_model"
173
+
174
+ # --- APPLY CLASSIFIER THRESHOLD ---
175
+ # If the score is too low, mark it as Unidentified/Unknown
176
+ if float(final_scr) < GlobalConfig.CLF_CONF_THRESH:
177
+ final_lbl = "Unidentified"
178
+ src = "low_confidence"
179
+
180
+ # Clean Label (remove scientific name if present)
181
+ try: species_clean = final_lbl.split(';')[-1].strip()
182
+ except: species_clean = final_lbl
183
+
184
+ detections[det_idx].update({
185
+ "species": final_lbl,
186
+ "species_common": species_clean,
187
+ "species_conf": float(final_scr),
188
+ "source": src
189
+ })
190
 
191
  # 5. Return JSON
192
  return {
193
  "filename": file.filename,
194
  "sensor_id": sensor_id,
195
+ "config_used": {
196
+ "geofencing": GlobalConfig.USE_GEOFENCING,
197
+ "det_threshold": GlobalConfig.DET_CONF_THRESH
198
+ },
199
  "total_detections": len(detections),
200
  "detections": detections
201
  }