PrashanthB461 commited on
Commit
357b766
·
verified ·
1 Parent(s): 44fcb88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +259 -223
app.py CHANGED
@@ -19,7 +19,6 @@ from retrying import retry
19
  import uuid
20
  from multiprocessing import Pool, cpu_count
21
  from functools import partial
22
- from collections import defaultdict
23
 
24
  # ========================== # Configuration and Setup # ==========================
25
  os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
@@ -28,174 +27,148 @@ os.makedirs('/tmp/Ultralytics', exist_ok=True)
28
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
29
  logger = logging.getLogger(__name__)
30
 
31
- # Suppress warnings
32
- warnings.filterwarnings("ignore")
33
-
34
- # ========================== # Enhanced Tracker Implementation # ==========================
35
- class SafetyTracker:
36
  def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
37
  self.track_thresh = track_thresh
38
  self.track_buffer = track_buffer
39
  self.match_thresh = match_thresh
40
  self.frame_rate = frame_rate
41
  self.next_id = 1
42
-
43
- # Trackers for different purposes
44
- self.worker_tracks = {} # Active worker tracks
45
- self.violation_history = defaultdict(dict) # Track violations per worker
46
- self.face_encodings = {} # Store face encodings for helmet violations
47
- self.position_history = defaultdict(list) # Track positions for non-helmet violations
48
-
49
- # Cooldown periods (in seconds)
50
- self.VIOLATION_COOLDOWNS = {
51
- "no_helmet": 30.0,
52
- "no_harness": 20.0,
53
- "unsafe_posture": 15.0,
54
- "unsafe_zone": 10.0,
55
- "improper_tool_use": 15.0
56
- }
57
 
58
- def update(self, detections, frame):
59
- """Update tracks with new detections and check for violations"""
60
  current_time = time.time()
61
- active_violations = []
62
- new_violations = []
63
 
64
- for det in detections:
65
- bbox = det['bbox']
66
- label = det['violation']
67
- confidence = det['confidence']
 
 
 
 
 
68
 
69
- # For helmet violations, use face recognition
70
- if label == "no_helmet":
71
- worker_id = self._match_by_face(bbox, frame)
72
- else:
73
- # For other violations, use position tracking
74
- worker_id = self._match_by_position(bbox, label)
 
 
 
 
 
 
75
 
76
- if worker_id is None:
77
- worker_id = self.next_id
78
- self.next_id += 1
 
 
 
 
 
79
 
80
- # Check if this is a new violation for this worker
81
- if self._is_new_violation(worker_id, label, current_time):
82
- # Record the violation
83
- violation = {
84
- 'worker_id': worker_id,
85
- 'violation': label,
86
- 'confidence': confidence,
87
- 'bbox': bbox,
88
- 'timestamp': current_time
89
- }
90
- new_violations.append(violation)
91
 
92
- # Update violation history
93
- self.violation_history[worker_id][label] = current_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- # For helmet violations, store face encoding
96
- if label == "no_helmet":
97
- self._store_face_encoding(worker_id, bbox, frame)
98
-
99
- # Keep track of active workers
100
- self.worker_tracks[worker_id] = {
101
- 'bbox': bbox,
102
- 'last_seen': current_time,
103
- 'label': label
104
- }
 
 
 
 
 
 
105
 
106
  # Clean up old tracks
107
- self._cleanup_tracks(current_time)
108
-
109
- return new_violations
110
-
111
- def _match_by_face(self, bbox, frame):
112
- """Match detection by face recognition (for helmet violations)"""
113
- x, y, w, h = bbox
114
- face_region = frame[max(0, int(y-h/2)):int(y+h/2), max(0, int(x-w/2)):int(x+w/2)]
115
-
116
- if face_region.size == 0:
117
- return None
118
-
119
- try:
120
- # Get face encodings from current detection
121
- face_locations = face_recognition.face_locations(face_region)
122
- if not face_locations:
123
- return None
124
-
125
- current_encoding = face_recognition.face_encodings(face_region, face_locations)[0]
126
-
127
- # Compare with known faces
128
- for worker_id, encodings in self.face_encodings.items():
129
- matches = face_recognition.compare_faces(encodings, current_encoding, tolerance=0.6)
130
- if any(matches):
131
- return worker_id
132
-
133
- except Exception as e:
134
- logger.warning(f"Face recognition error: {e}")
135
-
136
- return None
137
-
138
- def _match_by_position(self, bbox, label):
139
- """Match detection by position (for non-helmet violations)"""
140
- x, y, w, h = bbox
141
- current_pos = (x, y)
142
 
143
- for worker_id, positions in self.position_history.items():
144
- if label not in self.violation_history[worker_id]:
145
- continue
146
-
147
- # Check if current position is near any previous positions for this worker
148
- for pos in positions:
149
- distance = np.sqrt((current_pos[0]-pos[0])**2 + (current_pos[1]-pos[1])**2)
150
- if distance < 100: # Within 100 pixels
151
- return worker_id
152
-
153
- return None
154
-
155
- def _is_new_violation(self, worker_id, label, current_time):
156
- """Check if this is a new violation for this worker"""
157
- if label not in self.violation_history[worker_id]:
158
- return True
159
 
160
- last_detection = self.violation_history[worker_id][label]
161
- cooldown = self.VIOLATION_COOLDOWNS.get(label, 10.0)
 
 
 
 
162
 
163
- return (current_time - last_detection) > cooldown
164
-
165
- def _store_face_encoding(self, worker_id, bbox, frame):
166
- """Store face encoding for a worker"""
167
- x, y, w, h = bbox
168
- face_region = frame[max(0, int(y-h/2)):int(y+h/2), max(0, int(x-w/2)):int(x+w/2)]
169
 
170
- if face_region.size == 0:
171
- return
172
 
173
- try:
174
- face_locations = face_recognition.face_locations(face_region)
175
- if face_locations:
176
- encoding = face_recognition.face_encodings(face_region, face_locations)[0]
177
- if worker_id not in self.face_encodings:
178
- self.face_encodings[worker_id] = []
179
- self.face_encodings[worker_id].append(encoding)
180
- except Exception as e:
181
- logger.warning(f"Error storing face encoding: {e}")
182
-
183
- def _cleanup_tracks(self, current_time):
184
- """Clean up old tracks and face encodings"""
185
- # Remove inactive workers
186
- inactive_ids = [
187
- worker_id for worker_id, track in self.worker_tracks.items()
188
- if (current_time - track['last_seen']) > (self.track_buffer / self.frame_rate)
189
- ]
190
 
191
- for worker_id in inactive_ids:
192
- self.worker_tracks.pop(worker_id, None)
193
- self.position_history.pop(worker_id, None)
194
-
195
- # Keep face encodings for a longer period (for helmet violations)
196
- if (current_time - max(self.violation_history[worker_id].values(), default=0)) > 300: # 5 minutes
197
- self.face_encodings.pop(worker_id, None)
198
- self.violation_history.pop(worker_id, None)
 
 
 
 
199
 
200
  # ========================== # Optimized Configuration # ==========================
201
  CONFIG = {
@@ -238,11 +211,17 @@ CONFIG = {
238
  "improper_tool_use": 0.3
239
  },
240
  "MIN_VIOLATION_FRAMES": 1,
241
- "FRAME_SKIP": 2,
 
 
 
242
  "BATCH_SIZE": 16,
243
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
244
- "SNAPSHOT_QUALITY": 95,
245
- "FACE_RECOGNITION_INTERVAL": 5 # Process face recognition every 5 frames
 
 
 
246
  }
247
 
248
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -317,13 +296,22 @@ def calculate_safety_score(violations):
317
  "improper_tool_use": 25
318
  }
319
 
320
- # Count unique violation types
321
- unique_violations = set()
322
  for v in violations:
 
323
  violation_type = v.get("violation", "Unknown")
324
- unique_violations.add(violation_type)
 
 
 
 
 
 
 
 
 
325
 
326
- total_penalty = sum(penalties.get(v, 0) for v in unique_violations)
327
  score = max(0, 100 - total_penalty)
328
  return score
329
 
@@ -354,10 +342,18 @@ def generate_violation_pdf(violations, score):
354
  c.drawString(1 * inch, y_position, "Summary:")
355
  y_position -= 0.3 * inch
356
 
 
 
 
 
 
 
 
 
357
  c.setFont("Helvetica", 10)
358
  summary_data = {
 
359
  "Total Violations Found": len(violations),
360
- "Unique Violation Types": len(set(v['violation'] for v in violations)),
361
  "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
362
  }
363
 
@@ -365,27 +361,30 @@ def generate_violation_pdf(violations, score):
365
  c.drawString(1 * inch, y_position, f"{key}: {value}")
366
  y_position -= 0.25 * inch
367
 
368
- # Detailed Violations
369
  y_position -= 0.5 * inch
370
  c.setFont("Helvetica-Bold", 12)
371
- c.drawString(1 * inch, y_position, "Violation Details:")
372
  y_position -= 0.3 * inch
373
 
374
  c.setFont("Helvetica", 10)
375
- for v in violations:
376
- display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
377
- worker_id = v.get("worker_id", "Unknown")
378
- time_str = f"{v.get('timestamp', 0.0):.2f}s"
379
- conf_str = f"{v.get('confidence', 0.0):.2f}"
380
-
381
- violation_text = f"- {display_name} by Worker {worker_id} at {time_str} (Confidence: {conf_str})"
382
- c.drawString(1.2 * inch, y_position, violation_text)
383
  y_position -= 0.2 * inch
384
 
385
- if y_position < 1 * inch:
386
- c.showPage()
387
- c.setFont("Helvetica", 10)
388
- y_position = 10 * inch
 
 
 
 
 
 
 
 
 
389
 
390
  c.save()
391
  pdf_file.seek(0)
@@ -499,7 +498,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
499
  return None, ""
500
 
501
  def process_video(video_data):
502
- """Process video to detect safety violations with enhanced tracking"""
503
  try:
504
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
505
  logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
@@ -521,12 +520,19 @@ def process_video(video_data):
521
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
522
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
523
 
524
- tracker = SafetyTracker(frame_rate=fps)
 
 
 
 
 
 
 
 
525
  snapshots = []
526
  start_time = time.time()
527
  frame_skip = CONFIG["FRAME_SKIP"]
528
  processed_frames = 0
529
- frame_counter = 0
530
 
531
  while processed_frames < total_frames:
532
  batch_frames = []
@@ -551,7 +557,6 @@ def process_video(video_data):
551
  batch_frames.append(frame)
552
  batch_indices.append(frame_idx)
553
  processed_frames += 1
554
- frame_counter += 1
555
 
556
  if not batch_frames:
557
  break
@@ -569,7 +574,7 @@ def process_video(video_data):
569
  start_time = time.time()
570
 
571
  boxes = result.boxes
572
- detections = []
573
 
574
  for box in boxes:
575
  cls = int(box.cls)
@@ -583,54 +588,83 @@ def process_video(video_data):
583
  continue
584
 
585
  bbox = box.xywh.cpu().numpy()[0]
586
- detections.append({
587
  "bbox": bbox,
588
- "violation": label,
589
- "confidence": conf
590
  })
591
 
592
- if not detections:
593
  continue
594
 
595
- # Update tracker with new detections
596
- new_violations = tracker.update(detections, batch_frames[i])
597
-
598
- # Process new violations
599
- for violation in new_violations:
600
- # Take snapshot for the new violation
601
- snapshot_frame = batch_frames[i].copy()
602
- snapshot_frame = draw_detections(snapshot_frame, [violation])
603
-
604
- # Add timestamp to snapshot
605
- cv2.putText(
606
- snapshot_frame,
607
- f"Time: {violation['timestamp']:.2f}s",
608
- (10, 30),
609
- cv2.FONT_HERSHEY_SIMPLEX,
610
- 0.7,
611
- (255, 255, 255),
612
- 2
613
- )
614
-
615
- # Save snapshot with high quality
616
- snapshot_filename = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
617
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
618
 
619
- cv2.imwrite(
620
- snapshot_path,
621
- snapshot_frame,
622
- [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
623
- )
624
-
625
- snapshots.append({
626
- "violation": violation['violation'],
627
- "worker_id": violation['worker_id'],
628
- "timestamp": violation['timestamp'],
629
- "snapshot_path": snapshot_path,
630
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
631
- })
632
 
633
- logger.info(f"Captured snapshot for {violation['violation']} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
 
635
  cap.release()
636
  if os.path.exists(video_path):
@@ -639,15 +673,16 @@ def process_video(video_data):
639
  processing_time = time.time() - start_time
640
  logger.info(f"Processing complete in {processing_time:.2f}s")
641
 
642
- # Get all unique violations from tracker
643
  violations = []
644
- for worker_id, worker_violations in tracker.violation_history.items():
645
  for label, detection_time in worker_violations.items():
646
- violations.append({
647
  "worker_id": worker_id,
648
  "violation": label,
649
  "timestamp": detection_time
650
- })
 
651
 
652
  if not violations:
653
  logger.info("No violations detected after processing")
@@ -664,15 +699,16 @@ def process_video(video_data):
664
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
665
 
666
  # Format violations table for display
667
- violation_table = "| Violation | Worker ID | Time (s) |\n"
668
- violation_table += "|-----------|-----------|----------|\n"
669
 
670
- for v in sorted(violations, key=lambda x: x.get("timestamp", 0.0)):
671
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
672
  worker_id = v.get("worker_id", "Unknown")
673
  timestamp = v.get("timestamp", 0.0)
 
674
 
675
- violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} |\n"
676
 
677
  # Format snapshots for display
678
  snapshots_text = ""
 
19
  import uuid
20
  from multiprocessing import Pool, cpu_count
21
  from functools import partial
 
22
 
23
  # ========================== # Configuration and Setup # ==========================
24
  os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
 
27
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
28
  logger = logging.getLogger(__name__)
29
 
30
+ # ========================== # ByteTrack Implementation # ==========================
31
+ class BYTETracker:
 
 
 
32
  def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
33
  self.track_thresh = track_thresh
34
  self.track_buffer = track_buffer
35
  self.match_thresh = match_thresh
36
  self.frame_rate = frame_rate
37
  self.next_id = 1
38
+ self.tracks = {} # Store active tracks
39
+ self.worker_history = {} # Track worker positions over time
40
+ self.last_positions = {} # Last known positions of workers
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ def update(self, dets, scores, cls):
43
+ tracks = []
44
  current_time = time.time()
 
 
45
 
46
+ # Update existing tracks with new detections
47
+ for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
48
+ if score < self.track_thresh:
49
+ continue
50
+
51
+ x, y, w, h = det
52
+ matched = False
53
+ best_iou = 0
54
+ best_track_id = None
55
 
56
+ # Try to match with existing tracks
57
+ for track_id, track_info in self.tracks.items():
58
+ if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
59
+ continue
60
+
61
+ tx, ty, tw, th = track_info['bbox']
62
+ iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
63
+
64
+ if iou > self.match_thresh and iou > best_iou:
65
+ best_iou = iou
66
+ best_track_id = track_id
67
+ matched = True
68
 
69
+ if matched:
70
+ # Update existing track
71
+ self.tracks[best_track_id].update({
72
+ 'bbox': [x, y, w, h],
73
+ 'score': score,
74
+ 'cls': cl,
75
+ 'last_seen': current_time
76
+ })
77
 
78
+ # Update position history
79
+ if best_track_id not in self.worker_history:
80
+ self.worker_history[best_track_id] = []
81
+ self.worker_history[best_track_id].append([x, y])
82
+ self.last_positions[best_track_id] = [x, y]
 
 
 
 
 
 
83
 
84
+ tracks.append({
85
+ 'id': best_track_id,
86
+ 'bbox': [x, y, w, h],
87
+ 'score': score,
88
+ 'cls': cl
89
+ })
90
+ else:
91
+ # Create new track
92
+ # Check if this detection might be the same worker from a different angle
93
+ same_worker = False
94
+ for worker_id, last_pos in self.last_positions.items():
95
+ if self._is_same_worker([x, y], last_pos):
96
+ self.tracks[worker_id] = {
97
+ 'bbox': [x, y, w, h],
98
+ 'score': score,
99
+ 'cls': cl,
100
+ 'last_seen': current_time
101
+ }
102
+ tracks.append({
103
+ 'id': worker_id,
104
+ 'bbox': [x, y, w, h],
105
+ 'score': score,
106
+ 'cls': cl
107
+ })
108
+ same_worker = True
109
+ break
110
 
111
+ if not same_worker:
112
+ self.tracks[self.next_id] = {
113
+ 'bbox': [x, y, w, h],
114
+ 'score': score,
115
+ 'cls': cl,
116
+ 'last_seen': current_time
117
+ }
118
+ self.worker_history[self.next_id] = [[x, y]]
119
+ self.last_positions[self.next_id] = [x, y]
120
+ tracks.append({
121
+ 'id': self.next_id,
122
+ 'bbox': [x, y, w, h],
123
+ 'score': score,
124
+ 'cls': cl
125
+ })
126
+ self.next_id += 1
127
 
128
  # Clean up old tracks
129
+ current_time = time.time()
130
+ stale_ids = []
131
+ for track_id, track_info in self.tracks.items():
132
+ if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
133
+ stale_ids.append(track_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
+ for track_id in stale_ids:
136
+ del self.tracks[track_id]
137
+ if track_id in self.worker_history:
138
+ del self.worker_history[track_id]
139
+ if track_id in self.last_positions:
140
+ del self.last_positions[track_id]
 
 
 
 
 
 
 
 
 
 
141
 
142
+ return tracks
143
+
144
+ def _calculate_iou(self, box1, box2):
145
+ """Calculate IOU between two boxes"""
146
+ x1, y1, w1, h1 = box1
147
+ x2, y2, w2, h2 = box2
148
 
149
+ # Calculate intersection coordinates
150
+ x_left = max(x1 - w1/2, x2 - w2/2)
151
+ y_top = max(y1 - h1/2, y2 - h2/2)
152
+ x_right = min(x1 + w1/2, x2 + w2/2)
153
+ y_bottom = min(y1 + h1/2, y2 + h2/2)
 
154
 
155
+ if x_right < x_left or y_bottom < y_top:
156
+ return 0.0
157
 
158
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ box1_area = w1 * h1
161
+ box2_area = w2 * h2
162
+
163
+ iou = intersection_area / (box1_area + box2_area - intersection_area)
164
+ return iou
165
+
166
+ def _is_same_worker(self, pos1, pos2, threshold=100):
167
+ """Check if two positions likely belong to the same worker"""
168
+ x1, y1 = pos1
169
+ x2, y2 = pos2
170
+ distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
171
+ return distance < threshold
172
 
173
  # ========================== # Optimized Configuration # ==========================
174
  CONFIG = {
 
211
  "improper_tool_use": 0.3
212
  },
213
  "MIN_VIOLATION_FRAMES": 1,
214
+ "VIOLATION_COOLDOWN": 30.0, # Increased cooldown period
215
+ "WORKER_TRACKING_DURATION": 5.0,
216
+ "MAX_PROCESSING_TIME": 60,
217
+ "FRAME_SKIP": 2, # Skip more frames for faster processing
218
  "BATCH_SIZE": 16,
219
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
220
+ "TRACK_BUFFER": 30,
221
+ "TRACK_THRESH": 0.3,
222
+ "MATCH_THRESH": 0.7,
223
+ "SNAPSHOT_QUALITY": 95, # Higher quality for better visibility
224
+ "MAX_WORKER_DISTANCE": 100 # Maximum pixel distance to consider same worker
225
  }
226
 
227
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
296
  "improper_tool_use": 25
297
  }
298
 
299
+ # Count unique violation types per worker
300
+ worker_violations = {}
301
  for v in violations:
302
+ worker_id = v.get("worker_id", "Unknown")
303
  violation_type = v.get("violation", "Unknown")
304
+
305
+ if worker_id not in worker_violations:
306
+ worker_violations[worker_id] = set()
307
+ worker_violations[worker_id].add(violation_type)
308
+
309
+ # Calculate total penalty
310
+ total_penalty = 0
311
+ for worker_violations_set in worker_violations.values():
312
+ worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
313
+ total_penalty += worker_penalty
314
 
 
315
  score = max(0, 100 - total_penalty)
316
  return score
317
 
 
342
  c.drawString(1 * inch, y_position, "Summary:")
343
  y_position -= 0.3 * inch
344
 
345
+ # Group violations by worker
346
+ worker_violations = {}
347
+ for v in violations:
348
+ worker_id = v.get("worker_id", "Unknown")
349
+ if worker_id not in worker_violations:
350
+ worker_violations[worker_id] = []
351
+ worker_violations[worker_id].append(v)
352
+
353
  c.setFont("Helvetica", 10)
354
  summary_data = {
355
+ "Total Workers with Violations": len(worker_violations),
356
  "Total Violations Found": len(violations),
 
357
  "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
358
  }
359
 
 
361
  c.drawString(1 * inch, y_position, f"{key}: {value}")
362
  y_position -= 0.25 * inch
363
 
364
+ # Detailed Violations by Worker
365
  y_position -= 0.5 * inch
366
  c.setFont("Helvetica-Bold", 12)
367
+ c.drawString(1 * inch, y_position, "Violations by Worker:")
368
  y_position -= 0.3 * inch
369
 
370
  c.setFont("Helvetica", 10)
371
+ for worker_id, worker_vios in worker_violations.items():
372
+ c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
 
 
 
 
 
 
373
  y_position -= 0.2 * inch
374
 
375
+ for v in worker_vios:
376
+ display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
377
+ time_str = f"{v.get('timestamp', 0.0):.2f}s"
378
+ conf_str = f"{v.get('confidence', 0.0):.2f}"
379
+
380
+ violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
381
+ c.drawString(1.2 * inch, y_position, violation_text)
382
+ y_position -= 0.2 * inch
383
+
384
+ if y_position < 1 * inch:
385
+ c.showPage()
386
+ c.setFont("Helvetica", 10)
387
+ y_position = 10 * inch
388
 
389
  c.save()
390
  pdf_file.seek(0)
 
498
  return None, ""
499
 
500
  def process_video(video_data):
501
+ """Process video to detect safety violations"""
502
  try:
503
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
504
  logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
 
520
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
521
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
522
 
523
+ tracker = BYTETracker(
524
+ track_thresh=CONFIG["TRACK_THRESH"],
525
+ track_buffer=CONFIG["TRACK_BUFFER"],
526
+ match_thresh=CONFIG["MATCH_THRESH"],
527
+ frame_rate=fps
528
+ )
529
+
530
+ # Track unique violations by worker ID
531
+ unique_violations = {} # {worker_id: {violation_type: first_detection_time}}
532
  snapshots = []
533
  start_time = time.time()
534
  frame_skip = CONFIG["FRAME_SKIP"]
535
  processed_frames = 0
 
536
 
537
  while processed_frames < total_frames:
538
  batch_frames = []
 
557
  batch_frames.append(frame)
558
  batch_indices.append(frame_idx)
559
  processed_frames += 1
 
560
 
561
  if not batch_frames:
562
  break
 
574
  start_time = time.time()
575
 
576
  boxes = result.boxes
577
+ track_inputs = []
578
 
579
  for box in boxes:
580
  cls = int(box.cls)
 
588
  continue
589
 
590
  bbox = box.xywh.cpu().numpy()[0]
591
+ track_inputs.append({
592
  "bbox": bbox,
593
+ "conf": conf,
594
+ "cls": cls
595
  })
596
 
597
+ if not track_inputs:
598
  continue
599
 
600
+ tracked_objects = tracker.update(
601
+ np.array([t["bbox"] for t in track_inputs]),
602
+ np.array([t["conf"] for t in track_inputs]),
603
+ np.array([t["cls"] for t in track_inputs])
604
+ )
605
+
606
+ # Process tracked objects for violations
607
+ for obj in tracked_objects:
608
+ worker_id = obj['id']
609
+ label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
610
+ conf = obj['score']
611
+ bbox = obj['bbox']
 
 
 
 
 
 
 
 
 
 
 
612
 
613
+ if label is None:
614
+ continue
615
+
616
+ # Initialize worker if not seen before
617
+ if worker_id not in unique_violations:
618
+ unique_violations[worker_id] = {}
 
 
 
 
 
 
 
619
 
620
+ # Check if this violation type has been recorded for this worker
621
+ if label not in unique_violations[worker_id]:
622
+ # This is a new violation type for this worker
623
+ unique_violations[worker_id][label] = current_time
624
+
625
+ # Create detection object
626
+ detection = {
627
+ "worker_id": worker_id,
628
+ "violation": label,
629
+ "confidence": round(conf, 2),
630
+ "bounding_box": bbox,
631
+ "timestamp": current_time
632
+ }
633
+
634
+ # Take snapshot for the new violation
635
+ snapshot_frame = batch_frames[i].copy()
636
+ snapshot_frame = draw_detections(snapshot_frame, [detection])
637
+
638
+ # Add timestamp to snapshot
639
+ cv2.putText(
640
+ snapshot_frame,
641
+ f"Time: {current_time:.2f}s",
642
+ (10, 30),
643
+ cv2.FONT_HERSHEY_SIMPLEX,
644
+ 0.7,
645
+ (255, 255, 255),
646
+ 2
647
+ )
648
+
649
+ # Save snapshot with high quality
650
+ snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
651
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
652
+
653
+ cv2.imwrite(
654
+ snapshot_path,
655
+ snapshot_frame,
656
+ [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
657
+ )
658
+
659
+ snapshots.append({
660
+ "violation": label,
661
+ "worker_id": worker_id,
662
+ "timestamp": current_time,
663
+ "snapshot_path": snapshot_path,
664
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
665
+ })
666
+
667
+ logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
668
 
669
  cap.release()
670
  if os.path.exists(video_path):
 
673
  processing_time = time.time() - start_time
674
  logger.info(f"Processing complete in {processing_time:.2f}s")
675
 
676
+ # Convert tracked violations to final violation list
677
  violations = []
678
+ for worker_id, worker_violations in unique_violations.items():
679
  for label, detection_time in worker_violations.items():
680
+ violation = {
681
  "worker_id": worker_id,
682
  "violation": label,
683
  "timestamp": detection_time
684
+ }
685
+ violations.append(violation)
686
 
687
  if not violations:
688
  logger.info("No violations detected after processing")
 
699
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
700
 
701
  # Format violations table for display
702
+ violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
703
+ violation_table += "|-----------|-----------|----------|------------|\n"
704
 
705
+ for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
706
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
707
  worker_id = v.get("worker_id", "Unknown")
708
  timestamp = v.get("timestamp", 0.0)
709
+ confidence = v.get("confidence", 0.0)
710
 
711
+ violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
712
 
713
  # Format snapshots for display
714
  snapshots_text = ""