VietCat committed on
Commit
81a35ad
·
1 Parent(s): ec14871

Add confidence threshold slider and update class labels

Browse files

- Add confidence_threshold slider to UI (0.01-0.9 range)
- Update detect() function to accept dynamic confidence threshold
- Change from 43 classes to 29 aggregated traffic sign categories
- Improve NMS IOU threshold from 0.45 to 0.55
- Update config default confidence to 0.30

Files changed (3) hide show
  1. app.py +15 -4
  2. config.yaml +30 -44
  3. model.py +14 -7
app.py CHANGED
@@ -8,10 +8,11 @@ import io
8
  # Load the detector
9
  detector = TrafficSignDetector('config.yaml')
10
 
11
- def detect_traffic_signs(image):
12
  """
13
  Process the uploaded image and return the image with detected signs.
14
  :param image: PIL Image or numpy array
 
15
  :return: tuple of (detected image, preprocessed image)
16
  """
17
  # Redirect stdout to capture all logs
@@ -25,8 +26,8 @@ def detect_traffic_signs(image):
25
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
26
  print(f"Converted to BGR, shape: {image.shape}")
27
 
28
- # Perform detection (returns tuple of (detected_image, preprocessed_image))
29
- result_image, preprocessed_image = detector.detect(image)
30
 
31
  # Convert back to RGB for Gradio
32
  result_image = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
@@ -46,10 +47,20 @@ with gr.Blocks(title="Traffic Sign Detector") as demo:
46
  with gr.Row():
47
  preprocessed_image = gr.Image(label="Preprocessed Image (640x640, Letterboxed)")
48
 
 
 
 
 
 
 
 
 
 
 
49
  detect_btn = gr.Button("Detect Traffic Signs")
50
  detect_btn.click(
51
  fn=detect_traffic_signs,
52
- inputs=input_image,
53
  outputs=[output_image, preprocessed_image],
54
  queue=True # Enable queue to ensure logs are shown
55
  )
 
8
  # Load the detector
9
  detector = TrafficSignDetector('config.yaml')
10
 
11
+ def detect_traffic_signs(image, confidence_threshold):
12
  """
13
  Process the uploaded image and return the image with detected signs.
14
  :param image: PIL Image or numpy array
15
+ :param confidence_threshold: confidence threshold from slider
16
  :return: tuple of (detected image, preprocessed image)
17
  """
18
  # Redirect stdout to capture all logs
 
26
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
27
  print(f"Converted to BGR, shape: {image.shape}")
28
 
29
+ # Perform detection with the slider's confidence threshold
30
+ result_image, preprocessed_image = detector.detect(image, confidence_threshold=confidence_threshold)
31
 
32
  # Convert back to RGB for Gradio
33
  result_image = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
 
47
  with gr.Row():
48
  preprocessed_image = gr.Image(label="Preprocessed Image (640x640, Letterboxed)")
49
 
50
+ with gr.Row():
51
+ confidence_threshold = gr.Slider(
52
+ minimum=0.01,
53
+ maximum=0.9,
54
+ value=0.30,
55
+ step=0.01,
56
+ label="Confidence Threshold",
57
+ info="Lower values show more detections (less confident). Adjust to find optimal balance."
58
+ )
59
+
60
  detect_btn = gr.Button("Detect Traffic Signs")
61
  detect_btn.click(
62
  fn=detect_traffic_signs,
63
+ inputs=[input_image, confidence_threshold],
64
  outputs=[output_image, preprocessed_image],
65
  queue=True # Enable queue to ensure logs are shown
66
  )
config.yaml CHANGED
@@ -1,6 +1,6 @@
1
  model:
2
  path: 'VietCat/GTSRB-Model/models/GTSRB.pt' # Path to the YOLO model on Hugging Face Hub (will be downloaded automatically)
3
- confidence_threshold: 0.001 # Minimum confidence for detections (very low for testing)
4
 
5
  inference:
6
  box_color: (128, 0, 128) # Purple color for bounding boxes (BGR format)
@@ -8,46 +8,32 @@ inference:
8
  thickness: 2 # Thickness of bounding box lines
9
 
10
  classes:
11
- - 'Speed limit 20'
12
- - 'Speed limit 30'
13
- - 'Speed limit 50'
14
- - 'Speed limit 60'
15
- - 'Speed limit 70'
16
- - 'Speed limit 80'
17
- - 'Restriction ends 80'
18
- - 'Speed limit 100'
19
- - 'Speed limit 120'
20
- - 'No overtaking'
21
- - 'No overtaking trucks'
22
- - 'Priority at next intersection'
23
- - 'Priority road'
24
- - 'Give way'
25
- - 'Stop'
26
- - 'No traffic both ways'
27
- - 'No trucks'
28
- - 'No entry'
29
- - 'Danger'
30
- - 'Bend left'
31
- - 'Bend right'
32
- - 'Bend'
33
- - 'Uneven road'
34
- - 'Slippery road'
35
- - 'Road narrows'
36
- - 'Construction'
37
- - 'Traffic signal'
38
- - 'Pedestrian crossing'
39
- - 'School crossing'
40
- - 'Cycles crossing'
41
- - 'Snow'
42
- - 'Animals'
43
- - 'Restriction ends'
44
- - 'Go right'
45
- - 'Go left'
46
- - 'Go straight'
47
- - 'Go right or straight'
48
- - 'Go left or straight'
49
- - 'Keep right'
50
- - 'Keep left'
51
- - 'Roundabout'
52
- - 'Restriction ends overtaking'
53
- - 'Restriction ends overtaking trucks'
 
1
  model:
2
  path: 'VietCat/GTSRB-Model/models/GTSRB.pt' # Path to the YOLO model on Hugging Face Hub (will be downloaded automatically)
3
+ confidence_threshold: 0.30 # Minimum confidence for detections (0.3 filters most false positives)
4
 
5
  inference:
6
  box_color: (128, 0, 128) # Purple color for bounding boxes (BGR format)
 
8
  thickness: 2 # Thickness of bounding box lines
9
 
10
  classes:
11
+ - 'one_way_prohibition'
12
+ - 'no_parking'
13
+ - 'no_stopping_and_parking'
14
+ - 'no_turn_left'
15
+ - 'no_turn_right'
16
+ - 'no_u_turn'
17
+ - 'no_u_and_left_turn'
18
+ - 'no_u_and_right_turn'
19
+ - 'no_motorbike_entry_turning'
20
+ - 'no_car_entry_turning'
21
+ - 'no_truck_entry_turning'
22
+ - 'other_prohibition'
23
+ - 'indication'
24
+ - 'direction'
25
+ - 'speed_limit'
26
+ - 'weight_limit'
27
+ - 'height_limit'
28
+ - 'pedestrian_crossing'
29
+ - 'intersection_danger'
30
+ - 'road_danger'
31
+ - 'pedestrian_danger'
32
+ - 'construction_danger'
33
+ - 'slow_warning'
34
+ - 'other_warning'
35
+ - 'vehicle_permission_lane'
36
+ - 'vehicle_and_speed_permission_lane'
37
+ - 'overpass_route'
38
+ - 'no_more_prohibition'
39
+ - 'other'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.py CHANGED
@@ -149,12 +149,18 @@ class TrafficSignDetector:
149
  print(f"Image format: {image.dtype}, Min: {image.min()}, Max: {image.max()}, Mean: {image.mean():.1f}")
150
  return image
151
 
152
- def detect(self, image):
153
  """
154
  Perform inference on the image and draw bounding boxes.
155
  :param image: numpy array of the image
 
156
  :return: tuple of (image with drawn bounding boxes, preprocessed image for visualization)
157
  """
 
 
 
 
 
158
  print(f"\n{'='*80}")
159
  print(f"DETECTION PIPELINE START")
160
  print(f"{'='*80}")
@@ -191,11 +197,11 @@ class TrafficSignDetector:
191
  # Use iou_threshold for NMS (Non-Maximum Suppression) to remove overlapping boxes
192
  print(f"\n[STEP 4] MODEL INFERENCE")
193
  print(f" - Input shape to model: {image.shape}")
194
- print(f" - Confidence threshold: {self.conf_threshold}")
195
- print(f" - IOU threshold: 0.45")
196
 
197
  # Run with conf=0.0 to get raw predictions (before filtering)
198
- results_raw = self.model(image, conf=0.0, imgsz=640, iou=0.45)
199
  raw_box_count = len(results_raw[0].boxes) if results_raw else 0
200
  print(f" - Raw detections (conf=0.0): {raw_box_count}")
201
 
@@ -208,8 +214,8 @@ class TrafficSignDetector:
208
  print(f" - Confidences > 0.0001: {sum(1 for c in all_raw_confs if c > 0.0001)}")
209
 
210
  # Now run with actual threshold
211
- results = self.model(image, conf=self.conf_threshold, imgsz=640, iou=0.45)
212
- print(f" - Filtered detections (conf={self.conf_threshold}): {len(results)}")
213
 
214
  # Get original dimensions for coordinate transformation
215
  orig_h, orig_w = original_image.shape[:2]
@@ -249,7 +255,7 @@ class TrafficSignDetector:
249
  print(f"Detected: {self.classes[cls]} with conf {conf:.4f} at ({x1},{y1})-({x2},{y2})")
250
 
251
  # Only draw if confidence meets threshold
252
- if conf >= self.conf_threshold:
253
  # Draw bounding box on original image
254
  cv2.rectangle(original_image, (x1, y1), (x2, y2), self.box_color, self.thickness)
255
 
@@ -279,6 +285,7 @@ class TrafficSignDetector:
279
  print(f" 2) Use augmentation during training")
280
  print(f" 3) Check training/validation accuracy was good")
281
  print(f" 4) Ensure training data matches inference image types")
 
282
 
283
  if scale < 0.5:
284
  print(f"\n ⚠️ SCALING ISSUE:")
 
149
  print(f"Image format: {image.dtype}, Min: {image.min()}, Max: {image.max()}, Mean: {image.mean():.1f}")
150
  return image
151
 
152
+ def detect(self, image, confidence_threshold=None):
153
  """
154
  Perform inference on the image and draw bounding boxes.
155
  :param image: numpy array of the image
156
+ :param confidence_threshold: optional override for confidence threshold
157
  :return: tuple of (image with drawn bounding boxes, preprocessed image for visualization)
158
  """
159
+ # Use provided threshold or fall back to config value
160
+ if confidence_threshold is None:
161
+ confidence_threshold = self.conf_threshold
162
+ else:
163
+ confidence_threshold = float(confidence_threshold)
164
  print(f"\n{'='*80}")
165
  print(f"DETECTION PIPELINE START")
166
  print(f"{'='*80}")
 
197
  # Use iou_threshold for NMS (Non-Maximum Suppression) to remove overlapping boxes
198
  print(f"\n[STEP 4] MODEL INFERENCE")
199
  print(f" - Input shape to model: {image.shape}")
200
+ print(f" - Confidence threshold: {confidence_threshold}")
201
+ print(f" - IOU threshold: 0.55")
202
 
203
  # Run with conf=0.0 to get raw predictions (before filtering)
204
+ results_raw = self.model(image, conf=0.0, imgsz=640, iou=0.55)
205
  raw_box_count = len(results_raw[0].boxes) if results_raw else 0
206
  print(f" - Raw detections (conf=0.0): {raw_box_count}")
207
 
 
214
  print(f" - Confidences > 0.0001: {sum(1 for c in all_raw_confs if c > 0.0001)}")
215
 
216
  # Now run with actual threshold
217
+ results = self.model(image, conf=confidence_threshold, imgsz=640, iou=0.55)
218
+ print(f" - Filtered detections (conf={confidence_threshold}): {len(results)}")
219
 
220
  # Get original dimensions for coordinate transformation
221
  orig_h, orig_w = original_image.shape[:2]
 
255
  print(f"Detected: {self.classes[cls]} with conf {conf:.4f} at ({x1},{y1})-({x2},{y2})")
256
 
257
  # Only draw if confidence meets threshold
258
+ if conf >= confidence_threshold:
259
  # Draw bounding box on original image
260
  cv2.rectangle(original_image, (x1, y1), (x2, y2), self.box_color, self.thickness)
261
 
 
285
  print(f" 2) Use augmentation during training")
286
  print(f" 3) Check training/validation accuracy was good")
287
  print(f" 4) Ensure training data matches inference image types")
288
+ print(f" - Try lowering the confidence threshold slider to see detections")
289
 
290
  if scale < 0.5:
291
  print(f"\n ⚠️ SCALING ISSUE:")