PrashanthB461 commited on
Commit
3677731
·
verified ·
1 Parent(s): a71c85c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -62
app.py CHANGED
@@ -38,6 +38,7 @@ class BYTETracker:
38
  self.tracks = {} # Store active tracks
39
  self.worker_history = {} # Track worker positions over time
40
  self.last_positions = {} # Last known positions of workers
 
41
 
42
  def update(self, dets, scores, cls):
43
  tracks = []
@@ -117,6 +118,7 @@ class BYTETracker:
117
  }
118
  self.worker_history[self.next_id] = [[x, y]]
119
  self.last_positions[self.next_id] = [x, y]
 
120
  tracks.append({
121
  'id': self.next_id,
122
  'bbox': [x, y, w, h],
@@ -138,9 +140,21 @@ class BYTETracker:
138
  del self.worker_history[track_id]
139
  if track_id in self.last_positions:
140
  del self.last_positions[track_id]
 
 
141
 
142
  return tracks
143
 
 
 
 
 
 
 
 
 
 
 
144
  def _calculate_iou(self, box1, box2):
145
  """Calculate IOU between two boxes"""
146
  x1, y1, w1, h1 = box1
@@ -211,17 +225,17 @@ CONFIG = {
211
  "improper_tool_use": 0.3
212
  },
213
  "MIN_VIOLATION_FRAMES": 1,
214
- "VIOLATION_COOLDOWN": 30.0,
215
  "WORKER_TRACKING_DURATION": 5.0,
216
  "MAX_PROCESSING_TIME": 60,
217
- "FRAME_SKIP": 2,
218
  "BATCH_SIZE": 16,
219
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
220
  "TRACK_BUFFER": 30,
221
  "TRACK_THRESH": 0.3,
222
  "MATCH_THRESH": 0.7,
223
- "SNAPSHOT_QUALITY": 95,
224
- "MAX_WORKER_DISTANCE": 100
225
  }
226
 
227
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -498,7 +512,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
498
  return None, ""
499
 
500
  def process_video(video_data):
501
- """Process video to detect safety violations with de-duplication"""
502
  try:
503
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
504
  logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
@@ -528,7 +542,7 @@ def process_video(video_data):
528
  )
529
 
530
  # Track unique violations by worker ID
531
- unique_violations = {} # {worker_id: {violation_type: {"timestamp": float, "confidence": float, "bbox": list}}}
532
  snapshots = []
533
  start_time = time.time()
534
  frame_skip = CONFIG["FRAME_SKIP"]
@@ -613,62 +627,61 @@ def process_video(video_data):
613
  if label is None:
614
  continue
615
 
 
 
 
 
616
  # Initialize worker if not seen before
617
  if worker_id not in unique_violations:
618
  unique_violations[worker_id] = {}
619
 
620
- # Check if this violation type has been recorded for this worker
621
- if label not in unique_violations[worker_id]:
622
- # This is a new violation type for this worker
623
- unique_violations[worker_id][label] = {
624
- "timestamp": current_time,
625
- "confidence": round(conf, 2),
626
- "bbox": bbox
627
- }
628
-
629
- # Create detection object
630
- detection = {
631
- "worker_id": worker_id,
632
- "violation": label,
633
- "confidence": round(conf, 2),
634
- "bounding_box": bbox,
635
- "timestamp": current_time
636
- }
637
-
638
- # Take snapshot for the new violation
639
- snapshot_frame = batch_frames[i].copy()
640
- snapshot_frame = draw_detections(snapshot_frame, [detection])
641
-
642
- # Add timestamp to snapshot
643
- cv2.putText(
644
- snapshot_frame,
645
- f"Time: {current_time:.2f}s",
646
- (10, 30),
647
- cv2.FONT_HERSHEY_SIMPLEX,
648
- 0.7,
649
- (255, 255, 255),
650
- 2
651
- )
652
-
653
- # Save snapshot with high quality
654
- snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
655
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
656
-
657
- cv2.imwrite(
658
- snapshot_path,
659
- snapshot_frame,
660
- [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
661
- )
662
-
663
- snapshots.append({
664
- "violation": label,
665
- "worker_id": worker_id,
666
- "timestamp": current_time,
667
- "snapshot_path": snapshot_path,
668
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
669
- })
670
-
671
- logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
672
 
673
  cap.release()
674
  if os.path.exists(video_path):
@@ -680,13 +693,11 @@ def process_video(video_data):
680
  # Convert tracked violations to final violation list
681
  violations = []
682
  for worker_id, worker_violations in unique_violations.items():
683
- for label, info in worker_violations.items():
684
  violation = {
685
  "worker_id": worker_id,
686
  "violation": label,
687
- "timestamp": info["timestamp"],
688
- "confidence": info["confidence"],
689
- "bounding_box": info["bbox"]
690
  }
691
  violations.append(violation)
692
 
 
38
  self.tracks = {} # Store active tracks
39
  self.worker_history = {} # Track worker positions over time
40
  self.last_positions = {} # Last known positions of workers
41
+ self.violation_history = {} # Track violations per worker: {worker_id: set(violation_types)}
42
 
43
  def update(self, dets, scores, cls):
44
  tracks = []
 
118
  }
119
  self.worker_history[self.next_id] = [[x, y]]
120
  self.last_positions[self.next_id] = [x, y]
121
+ self.violation_history[self.next_id] = set() # Initialize violation set for new worker
122
  tracks.append({
123
  'id': self.next_id,
124
  'bbox': [x, y, w, h],
 
140
  del self.worker_history[track_id]
141
  if track_id in self.last_positions:
142
  del self.last_positions[track_id]
143
+ if track_id in self.violation_history:
144
+ del self.violation_history[track_id]
145
 
146
  return tracks
147
 
148
+ def has_violation(self, worker_id, violation_type):
149
+ """Check if this worker already has this violation type recorded"""
150
+ return worker_id in self.violation_history and violation_type in self.violation_history[worker_id]
151
+
152
+ def record_violation(self, worker_id, violation_type):
153
+ """Record that this worker has this violation type"""
154
+ if worker_id not in self.violation_history:
155
+ self.violation_history[worker_id] = set()
156
+ self.violation_history[worker_id].add(violation_type)
157
+
158
  def _calculate_iou(self, box1, box2):
159
  """Calculate IOU between two boxes"""
160
  x1, y1, w1, h1 = box1
 
225
  "improper_tool_use": 0.3
226
  },
227
  "MIN_VIOLATION_FRAMES": 1,
228
+ "VIOLATION_COOLDOWN": 30.0, # Increased cooldown period
229
  "WORKER_TRACKING_DURATION": 5.0,
230
  "MAX_PROCESSING_TIME": 60,
231
+ "FRAME_SKIP": 2, # Skip more frames for faster processing
232
  "BATCH_SIZE": 16,
233
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
234
  "TRACK_BUFFER": 30,
235
  "TRACK_THRESH": 0.3,
236
  "MATCH_THRESH": 0.7,
237
+ "SNAPSHOT_QUALITY": 95, # Higher quality for better visibility
238
+ "MAX_WORKER_DISTANCE": 100 # Maximum pixel distance to consider same worker
239
  }
240
 
241
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
512
  return None, ""
513
 
514
  def process_video(video_data):
515
+ """Process video to detect safety violations"""
516
  try:
517
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
518
  logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
 
542
  )
543
 
544
  # Track unique violations by worker ID
545
+ unique_violations = {} # {worker_id: {violation_type: first_detection_time}}
546
  snapshots = []
547
  start_time = time.time()
548
  frame_skip = CONFIG["FRAME_SKIP"]
 
627
  if label is None:
628
  continue
629
 
630
+ # Skip if this worker already has this violation recorded
631
+ if tracker.has_violation(worker_id, label):
632
+ continue
633
+
634
  # Initialize worker if not seen before
635
  if worker_id not in unique_violations:
636
  unique_violations[worker_id] = {}
637
 
638
+ # Record this violation for this worker
639
+ tracker.record_violation(worker_id, label)
640
+ unique_violations[worker_id][label] = current_time
641
+
642
+ # Create detection object
643
+ detection = {
644
+ "worker_id": worker_id,
645
+ "violation": label,
646
+ "confidence": round(conf, 2),
647
+ "bounding_box": bbox,
648
+ "timestamp": current_time
649
+ }
650
+
651
+ # Take snapshot for the new violation
652
+ snapshot_frame = batch_frames[i].copy()
653
+ snapshot_frame = draw_detections(snapshot_frame, [detection])
654
+
655
+ # Add timestamp to snapshot
656
+ cv2.putText(
657
+ snapshot_frame,
658
+ f"Time: {current_time:.2f}s",
659
+ (10, 30),
660
+ cv2.FONT_HERSHEY_SIMPLEX,
661
+ 0.7,
662
+ (255, 255, 255),
663
+ 2
664
+ )
665
+
666
+ # Save snapshot with high quality
667
+ snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
668
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
669
+
670
+ cv2.imwrite(
671
+ snapshot_path,
672
+ snapshot_frame,
673
+ [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
674
+ )
675
+
676
+ snapshots.append({
677
+ "violation": label,
678
+ "worker_id": worker_id,
679
+ "timestamp": current_time,
680
+ "snapshot_path": snapshot_path,
681
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
682
+ })
683
+
684
+ logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
 
 
 
 
 
685
 
686
  cap.release()
687
  if os.path.exists(video_path):
 
693
  # Convert tracked violations to final violation list
694
  violations = []
695
  for worker_id, worker_violations in unique_violations.items():
696
+ for label, detection_time in worker_violations.items():
697
  violation = {
698
  "worker_id": worker_id,
699
  "violation": label,
700
+ "timestamp": detection_time
 
 
701
  }
702
  violations.append(violation)
703