PrashanthB461 commited on
Commit
1c6b58d
·
verified ·
1 Parent(s): 15e98c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -233
app.py CHANGED
@@ -19,6 +19,7 @@ from retrying import retry
19
  import uuid
20
  from multiprocessing import Pool, cpu_count
21
  from functools import partial
 
22
  from collections import defaultdict
23
 
24
  # ========================== # Configuration and Setup # ==========================
@@ -28,134 +29,174 @@ os.makedirs('/tmp/Ultralytics', exist_ok=True)
28
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
29
  logger = logging.getLogger(__name__)
30
 
31
- # ========================== # Enhanced BYTETracker Implementation # ==========================
32
- class EnhancedBYTETracker:
 
 
 
33
  def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
34
  self.track_thresh = track_thresh
35
  self.track_buffer = track_buffer
36
  self.match_thresh = match_thresh
37
  self.frame_rate = frame_rate
38
  self.next_id = 1
39
- self.tracks = {} # Store active tracks
 
 
40
  self.violation_history = defaultdict(dict) # Track violations per worker
41
- self.last_positions = {} # Last known positions of workers
42
- self.violation_cooldown = {} # Cooldown periods for violations
 
 
 
 
 
 
 
 
 
43
 
44
- def update(self, dets, scores, cls, frame_idx):
45
- tracks = []
46
- current_time = frame_idx / self.frame_rate
 
 
47
 
48
- # Update existing tracks with new detections
49
- for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
50
- if score < self.track_thresh:
51
- continue
52
-
53
- x, y, w, h = det
54
- matched = False
55
- best_iou = 0
56
- best_track_id = None
57
 
58
- # Try to match with existing tracks
59
- for track_id, track_info in self.tracks.items():
60
- if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
61
- continue
62
-
63
- tx, ty, tw, th = track_info['bbox']
64
- iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
65
-
66
- if iou > self.match_thresh and iou > best_iou:
67
- best_iou = iou
68
- best_track_id = track_id
69
- matched = True
70
 
71
- if matched:
72
- # Update existing track
73
- self.tracks[best_track_id].update({
74
- 'bbox': [x, y, w, h],
75
- 'score': score,
76
- 'cls': cl,
77
- 'last_seen': current_time
78
- })
79
-
80
- # Update position
81
- self.last_positions[best_track_id] = [x, y]
82
 
83
- tracks.append({
84
- 'id': best_track_id,
85
- 'bbox': [x, y, w, h],
86
- 'score': score,
87
- 'cls': cl
88
- })
89
- else:
90
- # Create new track
91
- self.tracks[self.next_id] = {
92
- 'bbox': [x, y, w, h],
93
- 'score': score,
94
- 'cls': cl,
95
- 'last_seen': current_time
96
  }
97
- self.last_positions[self.next_id] = [x, y]
98
- tracks.append({
99
- 'id': self.next_id,
100
- 'bbox': [x, y, w, h],
101
- 'score': score,
102
- 'cls': cl
103
- })
104
- self.next_id += 1
 
 
 
 
 
 
 
105
 
106
  # Clean up old tracks
107
- stale_ids = [tid for tid, track in self.tracks.items()
108
- if current_time - track['last_seen'] > self.track_buffer / self.frame_rate]
109
 
110
- for track_id in stale_ids:
111
- del self.tracks[track_id]
112
- if track_id in self.last_positions:
113
- del self.last_positions[track_id]
114
-
115
- return tracks
116
-
117
- def _calculate_iou(self, box1, box2):
118
- """Calculate IOU between two boxes"""
119
- x1, y1, w1, h1 = box1
120
- x2, y2, w2, h2 = box2
121
 
122
- # Calculate intersection coordinates
123
- x_left = max(x1 - w1/2, x2 - w2/2)
124
- y_top = max(y1 - h1/2, y2 - h2/2)
125
- x_right = min(x1 + w1/2, x2 + w2/2)
126
- y_bottom = min(y1 + h1/2, y2 + h2/2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- if x_right < x_left or y_bottom < y_top:
129
- return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- intersection_area = (x_right - x_left) * (y_bottom - y_top)
 
132
 
133
- box1_area = w1 * h1
134
- box2_area = w2 * h2
 
 
 
 
135
 
136
- iou = intersection_area / (box1_area + box2_area - intersection_area)
137
- return iou
138
-
139
- def is_new_violation(self, worker_id, violation_type, bbox, current_time):
140
- """Check if this is a new violation that should be reported"""
141
- # Check if this worker has had this violation type before
142
- if violation_type in self.violation_history[worker_id]:
143
- last_time, last_bbox = self.violation_history[worker_id][violation_type]
144
 
145
- # For no_helmet, we only report once per worker
146
- if violation_type == "no_helmet":
147
- return False
148
-
149
- # For other violations, check if it's in a new location or after cooldown
150
- if current_time - last_time < CONFIG["VIOLATION_COOLDOWN"]:
151
- # Check if it's in the same area (using IOU)
152
- iou = self._calculate_iou(bbox, last_bbox)
153
- if iou > 0.3: # If overlap > 30%, consider it the same violation
154
- return False
155
-
156
- # If we get here, it's a new violation
157
- self.violation_history[worker_id][violation_type] = (current_time, bbox)
158
- return True
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  # ========================== # Optimized Configuration # ==========================
161
  CONFIG = {
@@ -197,18 +238,12 @@ CONFIG = {
197
  "unsafe_zone": 0.3,
198
  "improper_tool_use": 0.3
199
  },
200
- "MIN_VIOLATION_FRAMES": 5,
201
- "VIOLATION_COOLDOWN": 30.0, # 30 seconds cooldown for same violation type in same area
202
- "WORKER_TRACKING_DURATION": 5.0,
203
- "MAX_PROCESSING_TIME": 60,
204
  "FRAME_SKIP": 2,
205
  "BATCH_SIZE": 16,
206
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
207
- "TRACK_BUFFER": 30,
208
- "TRACK_THRESH": 0.3,
209
- "MATCH_THRESH": 0.7,
210
  "SNAPSHOT_QUALITY": 95,
211
- "MAX_WORKER_DISTANCE": 100
212
  }
213
 
214
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -283,22 +318,13 @@ def calculate_safety_score(violations):
283
  "improper_tool_use": 25
284
  }
285
 
286
- # Count unique violation types per worker
287
- worker_violations = {}
288
  for v in violations:
289
- worker_id = v.get("worker_id", "Unknown")
290
  violation_type = v.get("violation", "Unknown")
291
-
292
- if worker_id not in worker_violations:
293
- worker_violations[worker_id] = set()
294
- worker_violations[worker_id].add(violation_type)
295
-
296
- # Calculate total penalty
297
- total_penalty = 0
298
- for worker_violations_set in worker_violations.values():
299
- worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
300
- total_penalty += worker_penalty
301
 
 
302
  score = max(0, 100 - total_penalty)
303
  return score
304
 
@@ -329,18 +355,10 @@ def generate_violation_pdf(violations, score):
329
  c.drawString(1 * inch, y_position, "Summary:")
330
  y_position -= 0.3 * inch
331
 
332
- # Group violations by worker
333
- worker_violations = {}
334
- for v in violations:
335
- worker_id = v.get("worker_id", "Unknown")
336
- if worker_id not in worker_violations:
337
- worker_violations[worker_id] = []
338
- worker_violations[worker_id].append(v)
339
-
340
  c.setFont("Helvetica", 10)
341
  summary_data = {
342
- "Total Workers with Violations": len(worker_violations),
343
  "Total Violations Found": len(violations),
 
344
  "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
345
  }
346
 
@@ -348,30 +366,27 @@ def generate_violation_pdf(violations, score):
348
  c.drawString(1 * inch, y_position, f"{key}: {value}")
349
  y_position -= 0.25 * inch
350
 
351
- # Detailed Violations by Worker
352
  y_position -= 0.5 * inch
353
  c.setFont("Helvetica-Bold", 12)
354
- c.drawString(1 * inch, y_position, "Violations by Worker:")
355
  y_position -= 0.3 * inch
356
 
357
  c.setFont("Helvetica", 10)
358
- for worker_id, worker_vios in worker_violations.items():
359
- c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
 
 
 
 
 
 
360
  y_position -= 0.2 * inch
361
 
362
- for v in worker_vios:
363
- display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
364
- time_str = f"{v.get('timestamp', 0.0):.2f}s"
365
- conf_str = f"{v.get('confidence', 0.0):.2f}"
366
-
367
- violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
368
- c.drawString(1.2 * inch, y_position, violation_text)
369
- y_position -= 0.2 * inch
370
-
371
- if y_position < 1 * inch:
372
- c.showPage()
373
- c.setFont("Helvetica", 10)
374
- y_position = 10 * inch
375
 
376
  c.save()
377
  pdf_file.seek(0)
@@ -507,18 +522,12 @@ def process_video(video_data):
507
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
508
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
509
 
510
- tracker = EnhancedBYTETracker(
511
- track_thresh=CONFIG["TRACK_THRESH"],
512
- track_buffer=CONFIG["TRACK_BUFFER"],
513
- match_thresh=CONFIG["MATCH_THRESH"],
514
- frame_rate=fps
515
- )
516
-
517
- violations = []
518
  snapshots = []
519
  start_time = time.time()
520
  frame_skip = CONFIG["FRAME_SKIP"]
521
  processed_frames = 0
 
522
 
523
  while processed_frames < total_frames:
524
  batch_frames = []
@@ -543,6 +552,7 @@ def process_video(video_data):
543
  batch_frames.append(frame)
544
  batch_indices.append(frame_idx)
545
  processed_frames += 1
 
546
 
547
  if not batch_frames:
548
  break
@@ -560,7 +570,7 @@ def process_video(video_data):
560
  start_time = time.time()
561
 
562
  boxes = result.boxes
563
- track_inputs = []
564
 
565
  for box in boxes:
566
  cls = int(box.cls)
@@ -574,78 +584,54 @@ def process_video(video_data):
574
  continue
575
 
576
  bbox = box.xywh.cpu().numpy()[0]
577
- track_inputs.append({
578
  "bbox": bbox,
579
- "conf": conf,
580
- "cls": cls
581
  })
582
 
583
- if not track_inputs:
584
  continue
585
 
586
- tracked_objects = tracker.update(
587
- np.array([t["bbox"] for t in track_inputs]),
588
- np.array([t["conf"] for t in track_inputs]),
589
- np.array([t["cls"] for t in track_inputs]),
590
- frame_idx
591
- )
592
-
593
- # Process tracked objects for violations
594
- for obj in tracked_objects:
595
- worker_id = obj['id']
596
- label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
597
- conf = obj['score']
598
- bbox = obj['bbox']
599
 
600
- if label is None:
601
- continue
602
-
603
- # Check if this is a new violation that should be reported
604
- if tracker.is_new_violation(worker_id, label, bbox, current_time):
605
- # This is a new violation that should be reported
606
- detection = {
607
- "worker_id": worker_id,
608
- "violation": label,
609
- "confidence": round(conf, 2),
610
- "bounding_box": bbox,
611
- "timestamp": current_time
612
- }
613
- violations.append(detection)
614
-
615
- # Take snapshot for the new violation
616
- snapshot_frame = batch_frames[i].copy()
617
- snapshot_frame = draw_detections(snapshot_frame, [detection])
618
-
619
- # Add timestamp to snapshot
620
- cv2.putText(
621
- snapshot_frame,
622
- f"Time: {current_time:.2f}s",
623
- (10, 30),
624
- cv2.FONT_HERSHEY_SIMPLEX,
625
- 0.7,
626
- (255, 255, 255),
627
- 2
628
- )
629
-
630
- # Save snapshot with high quality
631
- snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
632
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
633
-
634
- cv2.imwrite(
635
- snapshot_path,
636
- snapshot_frame,
637
- [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
638
- )
639
-
640
- snapshots.append({
641
- "violation": label,
642
- "worker_id": worker_id,
643
- "timestamp": current_time,
644
- "snapshot_path": snapshot_path,
645
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
646
- })
647
-
648
- logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
649
 
650
  cap.release()
651
  if os.path.exists(video_path):
@@ -654,6 +640,16 @@ def process_video(video_data):
654
  processing_time = time.time() - start_time
655
  logger.info(f"Processing complete in {processing_time:.2f}s")
656
 
 
 
 
 
 
 
 
 
 
 
657
  if not violations:
658
  logger.info("No violations detected after processing")
659
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
@@ -669,16 +665,15 @@ def process_video(video_data):
669
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
670
 
671
  # Format violations table for display
672
- violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
673
- violation_table += "|-----------|-----------|----------|------------|\n"
674
 
675
- for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
676
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
677
  worker_id = v.get("worker_id", "Unknown")
678
  timestamp = v.get("timestamp", 0.0)
679
- confidence = v.get("confidence", 0.0)
680
 
681
- violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
682
 
683
  # Format snapshots for display
684
  snapshots_text = ""
 
19
  import uuid
20
  from multiprocessing import Pool, cpu_count
21
  from functools import partial
22
+ import face_recognition
23
  from collections import defaultdict
24
 
25
  # ========================== # Configuration and Setup # ==========================
 
29
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
30
  logger = logging.getLogger(__name__)
31
 
32
+ # Suppress warnings
33
+ warnings.filterwarnings("ignore")
34
+
35
+ # ========================== # Enhanced Tracker Implementation # ==========================
36
+ class SafetyTracker:
37
  def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
38
  self.track_thresh = track_thresh
39
  self.track_buffer = track_buffer
40
  self.match_thresh = match_thresh
41
  self.frame_rate = frame_rate
42
  self.next_id = 1
43
+
44
+ # Trackers for different purposes
45
+ self.worker_tracks = {} # Active worker tracks
46
  self.violation_history = defaultdict(dict) # Track violations per worker
47
+ self.face_encodings = {} # Store face encodings for helmet violations
48
+ self.position_history = defaultdict(list) # Track positions for non-helmet violations
49
+
50
+ # Cooldown periods (in seconds)
51
+ self.VIOLATION_COOLDOWNS = {
52
+ "no_helmet": 30.0,
53
+ "no_harness": 20.0,
54
+ "unsafe_posture": 15.0,
55
+ "unsafe_zone": 10.0,
56
+ "improper_tool_use": 15.0
57
+ }
58
 
59
+ def update(self, detections, frame):
60
+ """Update tracks with new detections and check for violations"""
61
+ current_time = time.time()
62
+ active_violations = []
63
+ new_violations = []
64
 
65
+ for det in detections:
66
+ bbox = det['bbox']
67
+ label = det['violation']
68
+ confidence = det['confidence']
 
 
 
 
 
69
 
70
+ # For helmet violations, use face recognition
71
+ if label == "no_helmet":
72
+ worker_id = self._match_by_face(bbox, frame)
73
+ else:
74
+ # For other violations, use position tracking
75
+ worker_id = self._match_by_position(bbox, label)
 
 
 
 
 
 
76
 
77
+ if worker_id is None:
78
+ worker_id = self.next_id
79
+ self.next_id += 1
 
 
 
 
 
 
 
 
80
 
81
+ # Check if this is a new violation for this worker
82
+ if self._is_new_violation(worker_id, label, current_time):
83
+ # Record the violation
84
+ violation = {
85
+ 'worker_id': worker_id,
86
+ 'violation': label,
87
+ 'confidence': confidence,
88
+ 'bbox': bbox,
89
+ 'timestamp': current_time
 
 
 
 
90
  }
91
+ new_violations.append(violation)
92
+
93
+ # Update violation history
94
+ self.violation_history[worker_id][label] = current_time
95
+
96
+ # For helmet violations, store face encoding
97
+ if label == "no_helmet":
98
+ self._store_face_encoding(worker_id, bbox, frame)
99
+
100
+ # Keep track of active workers
101
+ self.worker_tracks[worker_id] = {
102
+ 'bbox': bbox,
103
+ 'last_seen': current_time,
104
+ 'label': label
105
+ }
106
 
107
  # Clean up old tracks
108
+ self._cleanup_tracks(current_time)
 
109
 
110
+ return new_violations
111
+
112
+ def _match_by_face(self, bbox, frame):
113
+ """Match detection by face recognition (for helmet violations)"""
114
+ x, y, w, h = bbox
115
+ face_region = frame[max(0, int(y-h/2)):int(y+h/2), max(0, int(x-w/2)):int(x+w/2)]
 
 
 
 
 
116
 
117
+ if face_region.size == 0:
118
+ return None
119
+
120
+ try:
121
+ # Get face encodings from current detection
122
+ face_locations = face_recognition.face_locations(face_region)
123
+ if not face_locations:
124
+ return None
125
+
126
+ current_encoding = face_recognition.face_encodings(face_region, face_locations)[0]
127
+
128
+ # Compare with known faces
129
+ for worker_id, encodings in self.face_encodings.items():
130
+ matches = face_recognition.compare_faces(encodings, current_encoding, tolerance=0.6)
131
+ if any(matches):
132
+ return worker_id
133
+
134
+ except Exception as e:
135
+ logger.warning(f"Face recognition error: {e}")
136
+
137
+ return None
138
+
139
+ def _match_by_position(self, bbox, label):
140
+ """Match detection by position (for non-helmet violations)"""
141
+ x, y, w, h = bbox
142
+ current_pos = (x, y)
143
 
144
+ for worker_id, positions in self.position_history.items():
145
+ if label not in self.violation_history[worker_id]:
146
+ continue
147
+
148
+ # Check if current position is near any previous positions for this worker
149
+ for pos in positions:
150
+ distance = np.sqrt((current_pos[0]-pos[0])**2 + (current_pos[1]-pos[1])**2)
151
+ if distance < 100: # Within 100 pixels
152
+ return worker_id
153
+
154
+ return None
155
+
156
+ def _is_new_violation(self, worker_id, label, current_time):
157
+ """Check if this is a new violation for this worker"""
158
+ if label not in self.violation_history[worker_id]:
159
+ return True
160
 
161
+ last_detection = self.violation_history[worker_id][label]
162
+ cooldown = self.VIOLATION_COOLDOWNS.get(label, 10.0)
163
 
164
+ return (current_time - last_detection) > cooldown
165
+
166
+ def _store_face_encoding(self, worker_id, bbox, frame):
167
+ """Store face encoding for a worker"""
168
+ x, y, w, h = bbox
169
+ face_region = frame[max(0, int(y-h/2)):int(y+h/2), max(0, int(x-w/2)):int(x+w/2)]
170
 
171
+ if face_region.size == 0:
172
+ return
 
 
 
 
 
 
173
 
174
+ try:
175
+ face_locations = face_recognition.face_locations(face_region)
176
+ if face_locations:
177
+ encoding = face_recognition.face_encodings(face_region, face_locations)[0]
178
+ if worker_id not in self.face_encodings:
179
+ self.face_encodings[worker_id] = []
180
+ self.face_encodings[worker_id].append(encoding)
181
+ except Exception as e:
182
+ logger.warning(f"Error storing face encoding: {e}")
183
+
184
+ def _cleanup_tracks(self, current_time):
185
+ """Clean up old tracks and face encodings"""
186
+ # Remove inactive workers
187
+ inactive_ids = [
188
+ worker_id for worker_id, track in self.worker_tracks.items()
189
+ if (current_time - track['last_seen']) > (self.track_buffer / self.frame_rate)
190
+ ]
191
+
192
+ for worker_id in inactive_ids:
193
+ self.worker_tracks.pop(worker_id, None)
194
+ self.position_history.pop(worker_id, None)
195
+
196
+ # Keep face encodings for a longer period (for helmet violations)
197
+ if (current_time - max(self.violation_history[worker_id].values(), default=0)) > 300: # 5 minutes
198
+ self.face_encodings.pop(worker_id, None)
199
+ self.violation_history.pop(worker_id, None)
200
 
201
  # ========================== # Optimized Configuration # ==========================
202
  CONFIG = {
 
238
  "unsafe_zone": 0.3,
239
  "improper_tool_use": 0.3
240
  },
241
+ "MIN_VIOLATION_FRAMES": 1,
 
 
 
242
  "FRAME_SKIP": 2,
243
  "BATCH_SIZE": 16,
244
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
 
 
 
245
  "SNAPSHOT_QUALITY": 95,
246
+ "FACE_RECOGNITION_INTERVAL": 5 # Process face recognition every 5 frames
247
  }
248
 
249
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
318
  "improper_tool_use": 25
319
  }
320
 
321
+ # Count unique violation types
322
+ unique_violations = set()
323
  for v in violations:
 
324
  violation_type = v.get("violation", "Unknown")
325
+ unique_violations.add(violation_type)
 
 
 
 
 
 
 
 
 
326
 
327
+ total_penalty = sum(penalties.get(v, 0) for v in unique_violations)
328
  score = max(0, 100 - total_penalty)
329
  return score
330
 
 
355
  c.drawString(1 * inch, y_position, "Summary:")
356
  y_position -= 0.3 * inch
357
 
 
 
 
 
 
 
 
 
358
  c.setFont("Helvetica", 10)
359
  summary_data = {
 
360
  "Total Violations Found": len(violations),
361
+ "Unique Violation Types": len(set(v['violation'] for v in violations)),
362
  "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
363
  }
364
 
 
366
  c.drawString(1 * inch, y_position, f"{key}: {value}")
367
  y_position -= 0.25 * inch
368
 
369
+ # Detailed Violations
370
  y_position -= 0.5 * inch
371
  c.setFont("Helvetica-Bold", 12)
372
+ c.drawString(1 * inch, y_position, "Violation Details:")
373
  y_position -= 0.3 * inch
374
 
375
  c.setFont("Helvetica", 10)
376
+ for v in violations:
377
+ display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
378
+ worker_id = v.get("worker_id", "Unknown")
379
+ time_str = f"{v.get('timestamp', 0.0):.2f}s"
380
+ conf_str = f"{v.get('confidence', 0.0):.2f}"
381
+
382
+ violation_text = f"- {display_name} by Worker {worker_id} at {time_str} (Confidence: {conf_str})"
383
+ c.drawString(1.2 * inch, y_position, violation_text)
384
  y_position -= 0.2 * inch
385
 
386
+ if y_position < 1 * inch:
387
+ c.showPage()
388
+ c.setFont("Helvetica", 10)
389
+ y_position = 10 * inch
 
 
 
 
 
 
 
 
 
390
 
391
  c.save()
392
  pdf_file.seek(0)
 
522
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
523
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
524
 
525
+ tracker = SafetyTracker(frame_rate=fps)
 
 
 
 
 
 
 
526
  snapshots = []
527
  start_time = time.time()
528
  frame_skip = CONFIG["FRAME_SKIP"]
529
  processed_frames = 0
530
+ frame_counter = 0
531
 
532
  while processed_frames < total_frames:
533
  batch_frames = []
 
552
  batch_frames.append(frame)
553
  batch_indices.append(frame_idx)
554
  processed_frames += 1
555
+ frame_counter += 1
556
 
557
  if not batch_frames:
558
  break
 
570
  start_time = time.time()
571
 
572
  boxes = result.boxes
573
+ detections = []
574
 
575
  for box in boxes:
576
  cls = int(box.cls)
 
584
  continue
585
 
586
  bbox = box.xywh.cpu().numpy()[0]
587
+ detections.append({
588
  "bbox": bbox,
589
+ "violation": label,
590
+ "confidence": conf
591
  })
592
 
593
+ if not detections:
594
  continue
595
 
596
+ # Update tracker with new detections
597
+ new_violations = tracker.update(detections, batch_frames[i])
598
+
599
+ # Process new violations
600
+ for violation in new_violations:
601
+ # Take snapshot for the new violation
602
+ snapshot_frame = batch_frames[i].copy()
603
+ snapshot_frame = draw_detections(snapshot_frame, [violation])
 
 
 
 
 
604
 
605
+ # Add timestamp to snapshot
606
+ cv2.putText(
607
+ snapshot_frame,
608
+ f"Time: {violation['timestamp']:.2f}s",
609
+ (10, 30),
610
+ cv2.FONT_HERSHEY_SIMPLEX,
611
+ 0.7,
612
+ (255, 255, 255),
613
+ 2
614
+ )
615
+
616
+ # Save snapshot with high quality
617
+ snapshot_filename = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
618
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
619
+
620
+ cv2.imwrite(
621
+ snapshot_path,
622
+ snapshot_frame,
623
+ [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
624
+ )
625
+
626
+ snapshots.append({
627
+ "violation": violation['violation'],
628
+ "worker_id": violation['worker_id'],
629
+ "timestamp": violation['timestamp'],
630
+ "snapshot_path": snapshot_path,
631
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
632
+ })
633
+
634
+ logger.info(f"Captured snapshot for {violation['violation']} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
 
636
  cap.release()
637
  if os.path.exists(video_path):
 
640
  processing_time = time.time() - start_time
641
  logger.info(f"Processing complete in {processing_time:.2f}s")
642
 
643
+ # Get all unique violations from tracker
644
+ violations = []
645
+ for worker_id, worker_violations in tracker.violation_history.items():
646
+ for label, detection_time in worker_violations.items():
647
+ violations.append({
648
+ "worker_id": worker_id,
649
+ "violation": label,
650
+ "timestamp": detection_time
651
+ })
652
+
653
  if not violations:
654
  logger.info("No violations detected after processing")
655
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
 
665
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
666
 
667
  # Format violations table for display
668
+ violation_table = "| Violation | Worker ID | Time (s) |\n"
669
+ violation_table += "|-----------|-----------|----------|\n"
670
 
671
+ for v in sorted(violations, key=lambda x: x.get("timestamp", 0.0)):
672
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
673
  worker_id = v.get("worker_id", "Unknown")
674
  timestamp = v.get("timestamp", 0.0)
 
675
 
676
+ violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} |\n"
677
 
678
  # Format snapshots for display
679
  snapshots_text = ""