Mahiruoshi commited on
Commit
76135a7
·
verified ·
1 Parent(s): adddaea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -11
app.py CHANGED
@@ -7,7 +7,7 @@ import zipfile
7
  import shutil
8
  import json
9
  from datetime import datetime
10
- from predict_task1 import Predictor
11
  from id_mapping import mapping
12
  from show_stitched import *
13
  import cv2
@@ -48,7 +48,7 @@ os.makedirs(IMAGES_FOLDER, exist_ok=True)
48
  # Initialize class mapping file
49
  if not os.path.exists(CLASS_MAPPING_FILE):
50
  # Create initial class mapping from id_mapping.py
51
- reverse_mapping = {str(v): k for k, v in mapping.items() if v != -1 and k is not None}
52
  with open(CLASS_MAPPING_FILE, 'w', encoding='utf-8') as f:
53
  json.dump(reverse_mapping, f, indent=2, ensure_ascii=False)
54
 
@@ -70,7 +70,7 @@ def get_compatible_class_id(yolo_class_name):
70
  "circle": 40, "Bullseye": -1, "bullseye": -1
71
  }
72
 
73
- return yolo_to_id_mapping.get(yolo_class_name, mapping.get(yolo_class_name, -1))
74
 
75
  def load_class_mapping():
76
  """Load class mapping from JSON file"""
@@ -290,6 +290,50 @@ def process_file(file_path, direction, task_type, filename):
290
 
291
  # Perform prediction
292
  class_name, results, detection_id = predictor.predict_id(file_path, task_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  # Use compatible mapping function
294
  class_id = str(get_compatible_class_id(class_name))
295
 
@@ -591,7 +635,7 @@ HTML_TEMPLATE = """
591
  <label for="task_type">Task Type:</label>
592
  <select id="task_type" name="task_type" required>
593
  <option value="TASK_1">Task 1</option>
594
- <option value="TASK_2">Task 2</option>
595
  </select>
596
  </div>
597
  <div class="form-group">
@@ -763,7 +807,7 @@ HTML_TEMPLATE = """
763
  `;
764
 
765
  if (numObstacles && numObstacles !== '0') {
766
- if (predictedId === '-1') {
767
  successMessage += `<br><span style="color: #ffc107;">⚠️ Obstacle ${numObstacles} not saved - No valid detection</span>`;
768
  } else {
769
  successMessage += `<br><span style="color: #28a745;">✓ Saved as Obstacle ${numObstacles}</span>`;
@@ -1104,7 +1148,7 @@ OBSTACLES_HTML_TEMPLATE = """
1104
  if (existingCount === 0) {
1105
  statusEl.style.backgroundColor = '#fff3cd';
1106
  statusEl.style.color = '#856404';
1107
- statusEl.textContent = 'No obstacles detected yet. Upload images with NUM_OBSTACLES parameter (1-8) and valid detections (ID ≠ -1).';
1108
  } else {
1109
  statusEl.style.backgroundColor = '#d4edda';
1110
  statusEl.style.color = '#155724';
@@ -1164,7 +1208,7 @@ def image_predict():
1164
 
1165
  # Get parameters from both old and new format
1166
  direction = request.form.get('direction', 'north')
1167
- task_type = request.form.get('task_type', 'TASK_1')
1168
  num_obstacles = request.form.get('NUM_OBSTACLES', '0') # Support NUM_OBSTACLES parameter
1169
 
1170
  # Try to parse filename format: "<timestamp>_<obstacle_id>_<signal>.jpeg"
@@ -1200,14 +1244,14 @@ def image_predict():
1200
  class_id, detection_result = process_file(file_path, signal, task_type, filename)
1201
 
1202
  # Handle NUM_OBSTACLES parameter for obstacle display
1203
- # Only save if detection is valid (image_id != '-1')
1204
  if (num_obstacles and num_obstacles.isdigit() and 1 <= int(num_obstacles) <= 8
1205
- and detection_result and detection_result.get('image_id') != '-1'):
1206
  save_obstacle_image(detection_result, int(num_obstacles))
1207
  print(f"Obstacle {num_obstacles} saved with valid detection (ID: {detection_result.get('image_id')})")
1208
  elif (num_obstacles and num_obstacles.isdigit() and 1 <= int(num_obstacles) <= 8
1209
- and detection_result and detection_result.get('image_id') == '-1'):
1210
- print(f"Obstacle {num_obstacles} NOT saved - invalid detection (ID: -1)")
1211
  elif num_obstacles and num_obstacles.isdigit() and 1 <= int(num_obstacles) <= 8:
1212
  print(f"Obstacle {num_obstacles} NOT saved - no detection result")
1213
 
 
7
  import shutil
8
  import json
9
  from datetime import datetime
10
+ from predict_task2 import Predictor
11
  from id_mapping import mapping
12
  from show_stitched import *
13
  import cv2
 
48
  # Initialize class mapping file
49
  if not os.path.exists(CLASS_MAPPING_FILE):
50
  # Create initial class mapping from id_mapping.py
51
+ reverse_mapping = {str(v): k for k, v in mapping.items() if v not in [-999] and k is not None}
52
  with open(CLASS_MAPPING_FILE, 'w', encoding='utf-8') as f:
53
  json.dump(reverse_mapping, f, indent=2, ensure_ascii=False)
54
 
 
70
  "circle": 40, "Bullseye": -1, "bullseye": -1
71
  }
72
 
73
+ return yolo_to_id_mapping.get(yolo_class_name, mapping.get(yolo_class_name, -999))
74
 
75
  def load_class_mapping():
76
  """Load class mapping from JSON file"""
 
290
 
291
  # Perform prediction
292
  class_name, results, detection_id = predictor.predict_id(file_path, task_type)
293
+
294
+ # For TASK_2, apply priority-based selection AFTER getting all detections
295
+ if task_type == "TASK_2" and results and results[0].boxes is not None and len(results[0].boxes) > 0:
296
+ priority_classes = ['right', 'left', 'up', 'down']
297
+ detections_list = []
298
+
299
+ boxes = results[0].boxes
300
+ for i in range(len(boxes)):
301
+ detected_class = results[0].names[int(boxes.cls[i])]
302
+ confidence = float(boxes.conf[i])
303
+ yolo_class_id = int(boxes.cls[i])
304
+
305
+ print(f"[APP.PY] Detection {i}: {detected_class} (confidence: {confidence:.2f}, class_id: {yolo_class_id})")
306
+
307
+ # Assign priority based on class name
308
+ if detected_class in priority_classes:
309
+ priority = len(priority_classes) - priority_classes.index(detected_class)
310
+ elif detected_class == 'Bullseye' or detected_class == 'bullseye':
311
+ priority = -10 # Lowest priority for bullseye
312
+ else:
313
+ priority = -1 # Lower priority for other symbols
314
+
315
+ detections_list.append({
316
+ 'index': i,
317
+ 'class_name': detected_class,
318
+ 'confidence': confidence,
319
+ 'priority': priority,
320
+ 'yolo_class_id': yolo_class_id
321
+ })
322
+
323
+ if detections_list:
324
+ # Sort by priority (descending), then by confidence (descending)
325
+ detections_list.sort(key=lambda x: (x['priority'], x['confidence']), reverse=True)
326
+
327
+ print(f"\n[APP.PY] Sorted detections:")
328
+ for det in detections_list:
329
+ print(f" - {det['class_name']}: priority={det['priority']}, confidence={det['confidence']:.2f}")
330
+
331
+ # Override with highest priority detection
332
+ selected = detections_list[0]
333
+ class_name = selected['class_name']
334
+ detection_id = selected['index']
335
+ print(f"\n[APP.PY] ✓ Final selection: {class_name} (priority: {selected['priority']}, confidence: {selected['confidence']:.2f})")
336
+
337
  # Use compatible mapping function
338
  class_id = str(get_compatible_class_id(class_name))
339
 
 
635
  <label for="task_type">Task Type:</label>
636
  <select id="task_type" name="task_type" required>
637
  <option value="TASK_1">Task 1</option>
638
+ <option value="TASK_2" selected>Task 2</option>
639
  </select>
640
  </div>
641
  <div class="form-group">
 
807
  `;
808
 
809
  if (numObstacles && numObstacles !== '0') {
810
+ if (predictedId === '-999') {
811
  successMessage += `<br><span style="color: #ffc107;">⚠️ Obstacle ${numObstacles} not saved - No valid detection</span>`;
812
  } else {
813
  successMessage += `<br><span style="color: #28a745;">✓ Saved as Obstacle ${numObstacles}</span>`;
 
1148
  if (existingCount === 0) {
1149
  statusEl.style.backgroundColor = '#fff3cd';
1150
  statusEl.style.color = '#856404';
1151
+ statusEl.textContent = 'No obstacles detected yet. Upload images with NUM_OBSTACLES parameter (1-8) and valid detections (ID ≠ -999).';
1152
  } else {
1153
  statusEl.style.backgroundColor = '#d4edda';
1154
  statusEl.style.color = '#155724';
 
1208
 
1209
  # Get parameters from both old and new format
1210
  direction = request.form.get('direction', 'north')
1211
+ task_type = request.form.get('task_type', 'TASK_2')
1212
  num_obstacles = request.form.get('NUM_OBSTACLES', '0') # Support NUM_OBSTACLES parameter
1213
 
1214
  # Try to parse filename format: "<timestamp>_<obstacle_id>_<signal>.jpeg"
 
1244
  class_id, detection_result = process_file(file_path, signal, task_type, filename)
1245
 
1246
  # Handle NUM_OBSTACLES parameter for obstacle display
1247
+ # Only save if detection is valid (image_id != '-999')
1248
  if (num_obstacles and num_obstacles.isdigit() and 1 <= int(num_obstacles) <= 8
1249
+ and detection_result and detection_result.get('image_id') != '-999'):
1250
  save_obstacle_image(detection_result, int(num_obstacles))
1251
  print(f"Obstacle {num_obstacles} saved with valid detection (ID: {detection_result.get('image_id')})")
1252
  elif (num_obstacles and num_obstacles.isdigit() and 1 <= int(num_obstacles) <= 8
1253
+ and detection_result and detection_result.get('image_id') == '-999'):
1254
+ print(f"Obstacle {num_obstacles} NOT saved - invalid detection (ID: -999)")
1255
  elif num_obstacles and num_obstacles.isdigit() and 1 <= int(num_obstacles) <= 8:
1256
  print(f"Obstacle {num_obstacles} NOT saved - no detection result")
1257