PrashanthB461 commited on
Commit
6f75206
·
verified ·
1 Parent(s): 3dcbfd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -19
app.py CHANGED
@@ -43,13 +43,14 @@ class BYTETracker:
43
  def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
- self.match_thresh = match_thresh # Increased to 0.5 for better matching
47
  self.frame_rate = frame_rate
48
  self.next_id = 1
49
  self.tracks = {}
50
  self.worker_history = {}
51
  self.last_positions = {}
52
  self.recently_removed = {} # Store recently removed tracks for re-identification
 
53
 
54
  def update(self, dets, scores, cls):
55
  tracks = []
@@ -108,6 +109,13 @@ class BYTETracker:
108
  'cls': cl,
109
  'last_seen': current_time
110
  })
 
 
 
 
 
 
 
111
  if best_track_id not in self.worker_history:
112
  self.worker_history[best_track_id] = []
113
  self.worker_history[best_track_id].append([x, y])
@@ -132,6 +140,13 @@ class BYTETracker:
132
  }
133
  self.worker_history[track_id] = [[x, y]]
134
  self.last_positions[track_id] = [x, y]
 
 
 
 
 
 
 
135
  tracks.append({
136
  'id': track_id,
137
  'bbox': [x, y, w, h],
@@ -153,6 +168,13 @@ class BYTETracker:
153
  'cls': cl,
154
  'last_seen': current_time
155
  }
 
 
 
 
 
 
 
156
  tracks.append({
157
  'id': worker_id,
158
  'bbox': [x, y, w, h],
@@ -171,6 +193,13 @@ class BYTETracker:
171
  }
172
  self.worker_history[self.next_id] = [[x, y]]
173
  self.last_positions[self.next_id] = [x, y]
 
 
 
 
 
 
 
174
  tracks.append({
175
  'id': self.next_id,
176
  'bbox': [x, y, w, h],
@@ -196,12 +225,17 @@ class BYTETracker:
196
  iou = intersection_area / (box1_area + box2_area - intersection_area)
197
  return iou
198
 
199
- def _is_same_worker(self, pos1, pos2, threshold=150): # Increased threshold to 150
200
  x1, y1 = pos1
201
  x2, y2 = pos2
202
  distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
203
  return distance < threshold
204
 
 
 
 
 
 
205
  # ========================== # Optimized Configuration # ==========================
206
  CONFIG = {
207
  "MODEL_PATH": "yolov8_safety.pt",
@@ -235,25 +269,26 @@ CONFIG = {
235
  },
236
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
237
  "CONFIDENCE_THRESHOLDS": {
238
- "no_helmet": 0.4,
239
  "no_harness": 0.25,
240
  "unsafe_posture": 0.25,
241
  "unsafe_zone": 0.25,
242
  "improper_tool_use": 0.25
243
  },
244
- "MIN_VIOLATION_FRAMES": 1,
245
  "VIOLATION_COOLDOWN": 30.0,
246
- "WORKER_TRACKING_DURATION": 10.0, # Reverted to 5.0 seconds
247
  "MAX_PROCESSING_TIME": 60,
248
  "FRAME_SKIP": 1,
249
  "BATCH_SIZE": 4,
250
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
251
- "TRACK_BUFFER": 150, # 5.0 seconds at 30 fps
252
  "TRACK_THRESH": 0.3,
253
- "MATCH_THRESH": 0.5, # Increased to 0.5
254
  "SNAPSHOT_QUALITY": 95,
255
- "MAX_WORKER_DISTANCE": 150, # Increased to match _is_same_worker threshold
256
- "TARGET_RESOLUTION": (384, 384)
 
257
  }
258
 
259
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -285,8 +320,18 @@ model = load_model()
285
  # ========================== # Helper Functions # ==========================
286
  def preprocess_frame(frame):
287
  target_res = CONFIG["TARGET_RESOLUTION"]
 
288
  frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
289
- frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
 
 
 
 
 
 
 
 
 
290
  return frame
291
 
292
  def draw_detections(frame, detections):
@@ -305,7 +350,10 @@ def draw_detections(frame, detections):
305
 
306
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
307
 
308
- cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
 
 
 
309
 
310
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
311
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
@@ -536,6 +584,83 @@ def verify_and_open_video(video_path):
536
 
537
  return cap
538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  def process_video(video_data, temp_dir):
540
  video_path = None
541
  output_dir = os.path.join(temp_dir, "output")
@@ -586,6 +711,8 @@ def process_video(video_data, temp_dir):
586
  worker_id_mapping = {}
587
  unique_violations = {}
588
  violation_frames = {}
 
 
589
  start_time = time.time()
590
  frame_skip = CONFIG["FRAME_SKIP"]
591
  processed_frames = 0
@@ -595,6 +722,7 @@ def process_video(video_data, temp_dir):
595
  while processed_frames < total_frames:
596
  batch_frames = []
597
  batch_indices = []
 
598
 
599
  for _ in range(CONFIG["BATCH_SIZE"]):
600
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
@@ -606,6 +734,9 @@ def process_video(video_data, temp_dir):
606
  logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
607
  break
608
 
 
 
 
609
  frame = preprocess_frame(frame)
610
 
611
  for _ in range(frame_skip - 1):
@@ -614,6 +745,7 @@ def process_video(video_data, temp_dir):
614
 
615
  batch_frames.append(frame)
616
  batch_indices.append(frame_idx)
 
617
  processed_frames += 1
618
 
619
  if not batch_frames:
@@ -644,7 +776,7 @@ def process_video(video_data, temp_dir):
644
  yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
645
  last_yield_time = current_time
646
 
647
- for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
648
  current_time = frame_idx / fps
649
 
650
  boxes = result.boxes
@@ -658,8 +790,20 @@ def process_video(video_data, temp_dir):
658
  if label is None:
659
  continue
660
 
661
- if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
662
- continue
 
 
 
 
 
 
 
 
 
 
 
 
663
 
664
  bbox = box.xywh.cpu().numpy()[0]
665
  track_inputs.append({
@@ -693,11 +837,37 @@ def process_video(video_data, temp_dir):
693
 
694
  worker_id = worker_id_mapping[tracker_id]
695
 
696
- violation_key = (worker_id, label)
697
-
698
- if violation_key not in unique_violations:
699
- unique_violations[violation_key] = current_time
700
- violation_frames[violation_key] = frame_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
 
702
  cap.release()
703
  processing_time = time.time() - start_time
 
43
  def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
+ self.match_thresh = match_thresh
47
  self.frame_rate = frame_rate
48
  self.next_id = 1
49
  self.tracks = {}
50
  self.worker_history = {}
51
  self.last_positions = {}
52
  self.recently_removed = {} # Store recently removed tracks for re-identification
53
+ self.helmet_status = {} # Track helmet status for each worker
54
 
55
  def update(self, dets, scores, cls):
56
  tracks = []
 
109
  'cls': cl,
110
  'last_seen': current_time
111
  })
112
+
113
+ # Update helmet status if this is a helmet detection
114
+ if cl == 0: # Helmet violation class
115
+ # Higher confidence for helmet violations
116
+ if score > 0.45: # Increased threshold for helmet violations
117
+ self.helmet_status[best_track_id] = True
118
+
119
  if best_track_id not in self.worker_history:
120
  self.worker_history[best_track_id] = []
121
  self.worker_history[best_track_id].append([x, y])
 
140
  }
141
  self.worker_history[track_id] = [[x, y]]
142
  self.last_positions[track_id] = [x, y]
143
+
144
+ # Update helmet status if this is a helmet detection
145
+ if cl == 0: # Helmet violation class
146
+ # Higher confidence for helmet violations
147
+ if score > 0.45: # Increased threshold for helmet violations
148
+ self.helmet_status[track_id] = True
149
+
150
  tracks.append({
151
  'id': track_id,
152
  'bbox': [x, y, w, h],
 
168
  'cls': cl,
169
  'last_seen': current_time
170
  }
171
+
172
+ # Update helmet status if this is a helmet detection
173
+ if cl == 0: # Helmet violation class
174
+ # Higher confidence for helmet violations
175
+ if score > 0.45: # Increased threshold for helmet violations
176
+ self.helmet_status[worker_id] = True
177
+
178
  tracks.append({
179
  'id': worker_id,
180
  'bbox': [x, y, w, h],
 
193
  }
194
  self.worker_history[self.next_id] = [[x, y]]
195
  self.last_positions[self.next_id] = [x, y]
196
+
197
+ # Update helmet status if this is a helmet detection
198
+ if cl == 0: # Helmet violation class
199
+ # Higher confidence for helmet violations
200
+ if score > 0.45: # Increased threshold for helmet violations
201
+ self.helmet_status[self.next_id] = True
202
+
203
  tracks.append({
204
  'id': self.next_id,
205
  'bbox': [x, y, w, h],
 
225
  iou = intersection_area / (box1_area + box2_area - intersection_area)
226
  return iou
227
 
228
+ def _is_same_worker(self, pos1, pos2, threshold=150):
229
  x1, y1 = pos1
230
  x2, y2 = pos2
231
  distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
232
  return distance < threshold
233
 
234
+ # Function to validate if a helmet violation is consistent across frames
235
+ def validate_helmet_violation(self, worker_id, current_confidence):
236
+ # If we have consistent high confidence or multiple detections, it's a valid violation
237
+ return worker_id in self.helmet_status and self.helmet_status[worker_id]
238
+
239
  # ========================== # Optimized Configuration # ==========================
240
  CONFIG = {
241
  "MODEL_PATH": "yolov8_safety.pt",
 
269
  },
270
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
271
  "CONFIDENCE_THRESHOLDS": {
272
+ "no_helmet": 0.45, # Increased threshold for helmet violations
273
  "no_harness": 0.25,
274
  "unsafe_posture": 0.25,
275
  "unsafe_zone": 0.25,
276
  "improper_tool_use": 0.25
277
  },
278
+ "MIN_VIOLATION_FRAMES": 2, # Increased to require multiple frames for confirmation
279
  "VIOLATION_COOLDOWN": 30.0,
280
+ "WORKER_TRACKING_DURATION": 10.0,
281
  "MAX_PROCESSING_TIME": 60,
282
  "FRAME_SKIP": 1,
283
  "BATCH_SIZE": 4,
284
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
285
+ "TRACK_BUFFER": 150,
286
  "TRACK_THRESH": 0.3,
287
+ "MATCH_THRESH": 0.5,
288
  "SNAPSHOT_QUALITY": 95,
289
+ "MAX_WORKER_DISTANCE": 150,
290
+ "TARGET_RESOLUTION": (384, 384),
291
+ "HELMET_VALIDATION_FRAMES": 3 # Number of frames to validate helmet violations
292
  }
293
 
294
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
320
  # ========================== # Helper Functions # ==========================
321
  def preprocess_frame(frame):
322
  target_res = CONFIG["TARGET_RESOLUTION"]
323
+ # Enhanced preprocessing for better helmet detection
324
  frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
325
+ # Increase contrast to better differentiate helmets from other head coverings
326
+ frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=20) # Increased contrast
327
+
328
+ # Additional preprocessing to enhance head/helmet features
329
+ # Apply slight sharpening to make edges more distinct
330
+ kernel = np.array([[-1,-1,-1],
331
+ [-1, 9,-1],
332
+ [-1,-1,-1]])
333
+ frame = cv2.filter2D(frame, -1, kernel)
334
+
335
  return frame
336
 
337
  def draw_detections(frame, detections):
 
350
 
351
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
352
 
353
+ # Make no_helmet violations more prominent
354
+ line_thickness = 4 if label == "no_helmet" else 3
355
+
356
+ cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, line_thickness)
357
 
358
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
359
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
 
584
 
585
  return cap
586
 
587
+ # Helper for helmet validation
588
+ def validate_helmet_detection(frame, bbox, confidence_threshold=0.45):
589
+ """
590
+ Additional validation for helmet detection to reduce false positives.
591
+ This function performs additional checks on the region to confirm it's a true helmet violation.
592
+ """
593
+ x, y, w, h = bbox
594
+ x1 = int(max(0, x - w/2))
595
+ y1 = int(max(0, y - h/2))
596
+ x2 = int(min(frame.shape[1], x + w/2))
597
+ y2 = int(min(frame.shape[0], y + h/2))
598
+
599
+ # Extract head region
600
+ head_region = frame[y1:y2, x1:x2]
601
+ if head_region.size == 0:
602
+ return False
603
+
604
+ # Check if this is truly a helmet violation by analyzing the region
605
+ # 1. Check color distribution - helmets often have more uniform color
606
+ hsv = cv2.cvtColor(head_region, cv2.COLOR_BGR2HSV)
607
+
608
+ # Check for typical helmet colors (many construction helmets are yellow, white, orange, blue)
609
+ # This helps differentiate from cloth head coverings
610
+ yellow_lower = np.array([20, 100, 100])
611
+ yellow_upper = np.array([30, 255, 255])
612
+ yellow_mask = cv2.inRange(hsv, yellow_lower, yellow_upper)
613
+
614
+ white_lower = np.array([0, 0, 200])
615
+ white_upper = np.array([180, 30, 255])
616
+ white_mask = cv2.inRange(hsv, white_lower, white_upper)
617
+
618
+ orange_lower = np.array([5, 100, 100])
619
+ orange_upper = np.array([15, 255, 255])
620
+ orange_mask = cv2.inRange(hsv, orange_lower, orange_upper)
621
+
622
+ blue_lower = np.array([100, 100, 100])
623
+ blue_upper = np.array([130, 255, 255])
624
+ blue_mask = cv2.inRange(hsv, blue_lower, blue_upper)
625
+
626
+ helmet_mask = cv2.bitwise_or(yellow_mask, white_mask)
627
+ helmet_mask = cv2.bitwise_or(helmet_mask, orange_mask)
628
+ helmet_mask = cv2.bitwise_or(helmet_mask, blue_mask)
629
+
630
+ # If there's a significant amount of helmet-colored pixels, this might be a helmet
631
+ helmet_percentage = np.sum(helmet_mask > 0) / (head_region.shape[0] * head_region.shape[1])
632
+
633
+ # If the region has a significant amount of helmet-like colors, it's probably a helmet
634
+ # so we should NOT flag it as a violation (return False)
635
+ if helmet_percentage > 0.25:
636
+ return False
637
+
638
+ # Check texture uniformity - helmets have more uniform texture compared to head coverings
639
+ gray = cv2.cvtColor(head_region, cv2.COLOR_BGR2GRAY)
640
+ texture_score = np.std(gray)
641
+
642
+ # If texture is very uniform (low standard deviation), it might be a helmet or bare head
643
+ # Very uniform texture (like a hard helmet) would have low texture_score
644
+ if texture_score < 15: # Low texture suggests uniform surface like a helmet
645
+ return False
646
+
647
+ # Additional check for cloth-like textures
648
+ edges = cv2.Canny(gray, 50, 150)
649
+ edge_density = np.sum(edges > 0) / (head_region.shape[0] * head_region.shape[1])
650
+
651
+ # If there are many edges (cloth wrinkles), this might be a kurchief
652
+ if edge_density > 0.15:
653
+ # This is likely a cloth head covering, not a helmet violation
654
+ # But also not a proper helmet, so we should still detect as violation
655
+ return True
656
+
657
+ # If confidence is very high, trust the model
658
+ if confidence_threshold >= 0.6:
659
+ return True
660
+
661
+ # Default to the original detection
662
+ return True
663
+
664
  def process_video(video_data, temp_dir):
665
  video_path = None
666
  output_dir = os.path.join(temp_dir, "output")
 
711
  worker_id_mapping = {}
712
  unique_violations = {}
713
  violation_frames = {}
714
+ # Track helmet detections across frames for each worker
715
+ helmet_detections = {}
716
  start_time = time.time()
717
  frame_skip = CONFIG["FRAME_SKIP"]
718
  processed_frames = 0
 
722
  while processed_frames < total_frames:
723
  batch_frames = []
724
  batch_indices = []
725
+ batch_originals = [] # Store original frames for helmet validation
726
 
727
  for _ in range(CONFIG["BATCH_SIZE"]):
728
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
 
734
  logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
735
  break
736
 
737
+ # Store original frame for validation
738
+ original_frame = frame.copy()
739
+
740
  frame = preprocess_frame(frame)
741
 
742
  for _ in range(frame_skip - 1):
 
745
 
746
  batch_frames.append(frame)
747
  batch_indices.append(frame_idx)
748
+ batch_originals.append(original_frame)
749
  processed_frames += 1
750
 
751
  if not batch_frames:
 
776
  yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
777
  last_yield_time = current_time
778
 
779
+ for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
780
  current_time = frame_idx / fps
781
 
782
  boxes = result.boxes
 
790
  if label is None:
791
  continue
792
 
793
+ # Enhanced confidence threshold handling, especially for helmet detection
794
+ if label == "no_helmet":
795
+ if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.45):
796
+ continue
797
+
798
+ # Additional validation for helmet detection
799
+ bbox = box.xywh.cpu().numpy()[0]
800
+ if not validate_helmet_detection(original_frame, bbox, conf):
801
+ logger.info(f"Frame {frame_idx}: Helmet false positive filtered at {conf:.2f} confidence")
802
+ continue
803
+ else:
804
+ # Use regular thresholds for other violations
805
+ if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
806
+ continue
807
 
808
  bbox = box.xywh.cpu().numpy()[0]
809
  track_inputs.append({
 
837
 
838
  worker_id = worker_id_mapping[tracker_id]
839
 
840
+ # Special handling for helmet violations to ensure consistency
841
+ if label == "no_helmet":
842
+ # Track helmet violations for this worker
843
+ if worker_id not in helmet_detections:
844
+ helmet_detections[worker_id] = []
845
+
846
+ # Store this detection with frame index and confidence
847
+ helmet_detections[worker_id].append({
848
+ "frame_idx": frame_idx,
849
+ "confidence": conf,
850
+ "bbox": bbox
851
+ })
852
+
853
+ # Only record a helmet violation if we have multiple consistent detections
854
+ if len(helmet_detections[worker_id]) >= CONFIG["HELMET_VALIDATION_FRAMES"]:
855
+ # Calculate average confidence
856
+ avg_conf = sum(d["confidence"] for d in helmet_detections[worker_id]) / len(helmet_detections[worker_id])
857
+
858
+ # If confidence is consistently high across multiple frames, record the violation
859
+ if avg_conf >= CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
860
+ violation_key = (worker_id, label)
861
+ if violation_key not in unique_violations:
862
+ unique_violations[violation_key] = current_time
863
+ violation_frames[violation_key] = frame_idx
864
+ logger.info(f"Frame {frame_idx}: Valid helmet violation for worker {worker_id} with avg conf {avg_conf:.2f}")
865
+ else:
866
+ # Regular handling for other violations
867
+ violation_key = (worker_id, label)
868
+ if violation_key not in unique_violations:
869
+ unique_violations[violation_key] = current_time
870
+ violation_frames[violation_key] = frame_idx
871
 
872
  cap.release()
873
  processing_time = time.time() - start_time