PrashanthB461 commited on
Commit
714f201
·
verified ·
1 Parent(s): 6c07dac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -359
app.py CHANGED
@@ -43,44 +43,36 @@ class BYTETracker:
43
  def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
- self.match_thresh = match_thresh # Increased to 0.5 for better matching
47
  self.frame_rate = frame_rate
48
  self.next_id = 1
49
  self.tracks = {}
50
  self.worker_history = {}
51
  self.last_positions = {}
52
- self.recently_removed = {} # Store recently removed tracks for re-identification
53
 
54
  def update(self, dets, scores, cls):
55
  tracks = []
56
  current_time = time.time()
57
 
58
  # Prune stale tracks
59
- stale_ids = []
60
- for track_id, track_info in self.tracks.items():
61
- if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
62
- stale_ids.append(track_id)
63
-
64
- for track_id in stale_ids:
65
- # Store recently removed tracks for re-identification (for 1 second)
66
- self.recently_removed[track_id] = {
67
- 'bbox': self.tracks[track_id]['bbox'],
68
  'last_seen': current_time,
69
- 'last_position': self.last_positions.get(track_id, [0, 0])
70
  }
71
- del self.tracks[track_id]
72
- if track_id in self.worker_history:
73
- del self.worker_history[track_id]
74
- if track_id in self.last_positions:
75
- del self.last_positions[track_id]
76
-
77
- # Clean up recently_removed tracks older than 1 second
78
- to_remove = []
79
- for track_id, info in self.recently_removed.items():
80
- if current_time - info['last_seen'] > 1.0:
81
- to_remove.append(track_id)
82
- for track_id in to_remove:
83
- del self.recently_removed[track_id]
84
 
85
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
86
  if score < self.track_thresh:
@@ -91,28 +83,17 @@ class BYTETracker:
91
  best_iou = 0
92
  best_track_id = None
93
 
94
- # Try to match with active tracks
95
  for track_id, track_info in self.tracks.items():
96
  tx, ty, tw, th = track_info['bbox']
97
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
98
-
99
  if iou > self.match_thresh and iou > best_iou:
100
  best_iou = iou
101
  best_track_id = track_id
102
  matched = True
103
 
104
  if matched:
105
- self.tracks[best_track_id].update({
106
- 'bbox': [x, y, w, h],
107
- 'score': score,
108
- 'cls': cl,
109
- 'last_seen': current_time
110
- })
111
- if best_track_id not in self.worker_history:
112
- self.worker_history[best_track_id] = []
113
- self.worker_history[best_track_id].append([x, y])
114
- self.last_positions[best_track_id] = [x, y]
115
-
116
  tracks.append({
117
  'id': best_track_id,
118
  'bbox': [x, y, w, h],
@@ -123,15 +104,8 @@ class BYTETracker:
123
  # Try to re-identify with recently removed tracks
124
  reidentified = False
125
  for track_id, info in self.recently_removed.items():
126
- if self._is_same_worker([x, y], info['last_position'], threshold=150): # Increased threshold
127
- self.tracks[track_id] = {
128
- 'bbox': [x, y, w, h],
129
- 'score': score,
130
- 'cls': cl,
131
- 'last_seen': current_time
132
- }
133
- self.worker_history[track_id] = [[x, y]]
134
- self.last_positions[track_id] = [x, y]
135
  tracks.append({
136
  'id': track_id,
137
  'bbox': [x, y, w, h],
@@ -139,20 +113,15 @@ class BYTETracker:
139
  'cls': cl
140
  })
141
  reidentified = True
142
- del self.recently_removed[track_id]
143
  break
144
 
145
  if not reidentified:
146
- # Check if it matches an existing worker by position
147
  same_worker = False
148
  for worker_id, last_pos in self.last_positions.items():
149
- if self._is_same_worker([x, y], last_pos, threshold=150): # Increased threshold
150
- self.tracks[worker_id] = {
151
- 'bbox': [x, y, w, h],
152
- 'score': score,
153
- 'cls': cl,
154
- 'last_seen': current_time
155
- }
156
  tracks.append({
157
  'id': worker_id,
158
  'bbox': [x, y, w, h],
@@ -163,14 +132,7 @@ class BYTETracker:
163
  break
164
 
165
  if not same_worker:
166
- self.tracks[self.next_id] = {
167
- 'bbox': [x, y, w, h],
168
- 'score': score,
169
- 'cls': cl,
170
- 'last_seen': current_time
171
- }
172
- self.worker_history[self.next_id] = [[x, y]]
173
- self.last_positions[self.next_id] = [x, y]
174
  tracks.append({
175
  'id': self.next_id,
176
  'bbox': [x, y, w, h],
@@ -181,6 +143,18 @@ class BYTETracker:
181
 
182
  return tracks
183
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  def _calculate_iou(self, box1, box2):
185
  x1, y1, w1, h1 = box1
186
  x2, y2, w2, h2 = box2
@@ -193,14 +167,12 @@ class BYTETracker:
193
  intersection_area = (x_right - x_left) * (y_bottom - y_top)
194
  box1_area = w1 * h1
195
  box2_area = w2 * h2
196
- iou = intersection_area / (box1_area + box2_area - intersection_area)
197
- return iou
198
 
199
- def _is_same_worker(self, pos1, pos2, threshold=150): # Increased threshold to 150
200
  x1, y1 = pos1
201
  x2, y2 = pos2
202
- distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
203
- return distance < threshold
204
 
205
  # ========================== # Optimized Configuration # ==========================
206
  CONFIG = {
@@ -221,10 +193,10 @@ CONFIG = {
221
  "improper_tool_use": (255, 255, 0)
222
  },
223
  "DISPLAY_NAMES": {
224
- "no_helmet": "No Helmet Violation",
225
- "no_harness": "No Harness Violation",
226
  "unsafe_posture": "Unsafe Posture",
227
- "unsafe_zone": "Unsafe Zone Entry",
228
  "improper_tool_use": "Improper Tool Use"
229
  },
230
  "SF_CREDENTIALS": {
@@ -243,16 +215,16 @@ CONFIG = {
243
  },
244
  "MIN_VIOLATION_FRAMES": 1,
245
  "VIOLATION_COOLDOWN": 30.0,
246
- "WORKER_TRACKING_DURATION": 10.0, # Reverted to 5.0 seconds
247
  "MAX_PROCESSING_TIME": 60,
248
- "FRAME_SKIP": 1,
249
- "BATCH_SIZE": 4,
250
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
251
- "TRACK_BUFFER": 150, # 5.0 seconds at 30 fps
252
  "TRACK_THRESH": 0.3,
253
- "MATCH_THRESH": 0.5, # Increased to 0.5
254
  "SNAPSHOT_QUALITY": 95,
255
- "MAX_WORKER_DISTANCE": 150, # Increased to match _is_same_worker threshold
256
  "TARGET_RESOLUTION": (384, 384)
257
  }
258
 
@@ -284,37 +256,27 @@ model = load_model()
284
 
285
  # ========================== # Helper Functions # ==========================
286
  def preprocess_frame(frame):
287
- target_res = CONFIG["TARGET_RESOLUTION"]
288
- frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
289
- frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
290
- return frame
291
 
292
  def draw_detections(frame, detections):
293
  result_frame = frame.copy()
294
-
295
  for det in detections:
296
  label = det.get("violation", "Unknown")
297
  confidence = det.get("confidence", 0.0)
298
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
299
  worker_id = det.get("worker_id", "Unknown")
300
 
301
- x1 = int(x - w/2)
302
- y1 = int(y - h/2)
303
- x2 = int(x + w/2)
304
- y2 = int(y + h/2)
305
-
306
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
307
 
308
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
309
-
310
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
311
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
312
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
313
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
314
-
315
- conf_text = f"Conf: {confidence:.2f}"
316
- cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
317
-
318
  return result_frame
319
 
320
  def calculate_safety_score(violations):
@@ -325,23 +287,13 @@ def calculate_safety_score(violations):
325
  "unsafe_zone": 35,
326
  "improper_tool_use": 25
327
  }
328
-
329
  worker_violations = {}
330
  for v in violations:
331
  worker_id = v.get("worker_id", "Unknown")
332
- violation_type = v.get("violation", "Unknown")
333
-
334
  if worker_id not in worker_violations:
335
  worker_violations[worker_id] = set()
336
- worker_violations[worker_id].add(violation_type)
337
-
338
- total_penalty = 0
339
- for worker_violations_set in worker_violations.values():
340
- worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
341
- total_penalty += worker_penalty
342
-
343
- score = max(0, 100 - total_penalty)
344
- return score
345
 
346
  def generate_violation_pdf(violations, score, output_dir):
347
  try:
@@ -352,11 +304,9 @@ def generate_violation_pdf(violations, score, output_dir):
352
 
353
  c.setFont("Helvetica-Bold", 16)
354
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
355
-
356
  c.setFont("Helvetica", 12)
357
  c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
358
  c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
359
-
360
  c.setFont("Helvetica-Bold", 14)
361
  c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
362
 
@@ -392,16 +342,12 @@ def generate_violation_pdf(violations, score, output_dir):
392
  for worker_id, worker_vios in worker_violations.items():
393
  c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
394
  y_position -= 0.2 * inch
395
-
396
  for v in worker_vios:
397
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
398
  time_str = f"{v.get('timestamp', 0.0):.2f}s"
399
  conf_str = f"{v.get('confidence', 0.0):.2f}"
400
-
401
- violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
402
- c.drawString(1.2 * inch, y_position, violation_text)
403
  y_position -= 0.2 * inch
404
-
405
  if y_position < 1 * inch:
406
  c.showPage()
407
  c.setFont("Helvetica", 10)
@@ -409,13 +355,9 @@ def generate_violation_pdf(violations, score, output_dir):
409
 
410
  c.save()
411
  pdf_file.seek(0)
412
-
413
  with open(pdf_path, "wb") as f:
414
  f.write(pdf_file.getvalue())
415
-
416
- public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
417
- logger.info(f"PDF generated: {public_url}")
418
- return pdf_path, public_url, pdf_file
419
  except Exception as e:
420
  logger.error(f"Error generating PDF: {e}")
421
  return "", "", None
@@ -425,7 +367,6 @@ def connect_to_salesforce():
425
  try:
426
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
427
  logger.info("Connected to Salesforce")
428
- sf.describe()
429
  return sf
430
  except Exception as e:
431
  logger.error(f"Salesforce connection failed: {e}")
@@ -434,26 +375,18 @@ def connect_to_salesforce():
434
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
435
  try:
436
  if not pdf_file:
437
- logger.error("No PDF file provided for upload")
438
  return ""
439
-
440
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
441
- content_version_data = {
442
  "Title": f"Safety_Violation_Report_{int(time.time())}",
443
  "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
444
  "VersionData": encoded_pdf,
445
  "FirstPublishLocationId": report_id
446
- }
447
- content_version = sf.ContentVersion.create(content_version_data)
448
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
449
-
450
- if not result['records']:
451
- logger.error("Failed to retrieve ContentVersion")
452
- return ""
453
-
454
- file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
455
- logger.info(f"PDF uploaded to Salesforce: {file_url}")
456
- return file_url
457
  except Exception as e:
458
  logger.error(f"Error uploading PDF to Salesforce: {e}")
459
  return ""
@@ -461,38 +394,25 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
461
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
462
  try:
463
  sf = connect_to_salesforce()
 
 
 
 
 
 
464
 
465
- violations_text = ""
466
- for v in violations:
467
- display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
468
- worker_id = v.get('worker_id', 'Unknown')
469
- timestamp = v.get('timestamp', 0.0)
470
- confidence = v.get('confidence', 0.0)
471
-
472
- violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
473
-
474
- if not violations_text:
475
- violations_text = "No violations detected."
476
-
477
- pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
478
-
479
  record_data = {
480
  "Compliance_Score__c": score,
481
  "Violations_Found__c": len(violations),
482
  "Violations_Details__c": violations_text,
483
  "Status__c": "Pending",
484
- "PDF_Report_URL__c": pdf_url
485
  }
486
 
487
- logger.info(f"Creating Salesforce record with data: {record_data}")
488
-
489
  try:
490
  record = sf.Safety_Video_Report__c.create(record_data)
491
- logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
492
- except Exception as e:
493
- logger.error(f"Failed to create Safety_Video_Report__c: {e}")
494
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
495
- logger.warning(f"Fell back to Account record: {record['id']}")
496
 
497
  record_id = record["id"]
498
 
@@ -501,81 +421,28 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
501
  if uploaded_url:
502
  try:
503
  sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
504
- logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
505
- except Exception as e:
506
- logger.error(f"Failed to update Safety_Video_Report__c: {e}")
507
  sf.Account.update(record_id, {"Description": uploaded_url})
508
- logger.info(f"Updated Account record {record_id} with PDF URL")
509
- pdf_url = uploaded_url
510
-
511
- return record_id, pdf_url
512
  except Exception as e:
513
  logger.error(f"Salesforce record creation failed: {e}")
514
  return "N/A", "Salesforce integration failed."
515
 
516
- @tenacity.retry(
517
- stop=tenacity.stop_after_attempt(3),
518
- wait=tenacity.wait_fixed(1),
519
- retry=tenacity.retry_if_exception_type((IOError, OSError)),
520
- before_sleep=lambda retry_state: logger.info(f"Retrying file access (attempt {retry_state.attempt_number}/3)...")
521
- )
522
- def verify_and_open_video(video_path):
523
- if not os.path.exists(video_path):
524
- raise FileNotFoundError(f"Temporary video file not found: {video_path}")
525
-
526
- file_size = os.path.getsize(video_path)
527
- if file_size == 0:
528
- raise ValueError(f"Temporary video file is empty: {video_path}")
529
-
530
- with open(video_path, "rb") as f:
531
- f.read(1)
532
-
533
- cap = cv2.VideoCapture(video_path)
534
- if not cap.isOpened():
535
- raise ValueError("Could not open video file. Ensure the video format is supported (e.g., MP4) and FFmpeg is installed.")
536
-
537
- return cap
538
-
539
  def process_video(video_data, temp_dir):
540
  video_path = None
541
  output_dir = os.path.join(temp_dir, "output")
542
  os.makedirs(output_dir, exist_ok=True)
543
- os.environ['YOLO_CONFIG_DIR'] = temp_dir
544
-
545
  try:
546
- if not video_data:
547
- raise ValueError("Empty video data provided.")
548
-
549
- logger.info(f"Received video data size: {len(video_data)} bytes")
550
- if len(video_data) == 0:
551
- raise ValueError("Video data is empty.")
552
-
553
  with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
554
  temp_file.write(video_data)
555
- temp_file.flush()
556
  video_path = temp_file.name
557
- logger.info(f"Video saved to temporary file: {video_path}")
558
-
559
- if not os.path.exists(video_path):
560
- raise FileNotFoundError(f"Temporary video file not found: {video_path}")
561
- file_size = os.path.getsize(video_path)
562
- if file_size == 0:
563
- raise ValueError(f"Temporary video file is empty: {video_path}")
564
- logger.info(f"Temporary video file size: {file_size} bytes")
565
-
566
- cap = verify_and_open_video(video_path)
567
- logger.info(f"Successfully opened video file: {video_path}")
568
 
 
569
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
570
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
571
- duration = total_frames / fps
572
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
573
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
574
- logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
575
-
576
- if total_frames <= 0:
577
- raise ValueError("Video has no frames.")
578
-
579
  tracker = BYTETracker(
580
  track_thresh=CONFIG["TRACK_THRESH"],
581
  track_buffer=CONFIG["TRACK_BUFFER"],
@@ -586,11 +453,9 @@ def process_video(video_data, temp_dir):
586
  worker_id_mapping = {}
587
  unique_violations = {}
588
  violation_frames = {}
589
- worker_violation_count = {} # Track violation count per worker
590
  start_time = time.time()
591
- frame_skip = CONFIG["FRAME_SKIP"]
592
  processed_frames = 0
593
- last_yield_time = start_time
594
  worker_counter = 1
595
 
596
  while processed_frames < total_frames:
@@ -604,50 +469,39 @@ def process_video(video_data, temp_dir):
604
 
605
  ret, frame = cap.read()
606
  if not ret:
607
- logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
608
  break
609
 
610
  frame = preprocess_frame(frame)
611
-
612
- for _ in range(frame_skip - 1):
613
- if not cap.grab():
614
- break
615
-
616
  batch_frames.append(frame)
617
  batch_indices.append(frame_idx)
618
  processed_frames += 1
 
 
 
 
 
 
619
 
620
  if not batch_frames:
621
- logger.info("No more frames to process.")
622
  break
623
 
624
  try:
625
  batch_frames_np = np.array(batch_frames)
626
  batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
627
- batch_frames_tensor = batch_frames_tensor.to(device)
628
  if device.type == "cuda":
629
- batch_frames_tensor = batch_frames_tensor.half()
630
-
631
  results = model(batch_frames_tensor, device=device, conf=0.1, verbose=False)
632
  except Exception as e:
633
  logger.error(f"Model inference failed: {e}")
634
- raise ValueError(f"Failed to process video frames with YOLO model: {str(e)}")
635
- finally:
636
- batch_frames = []
637
- if device.type == "cuda":
638
- torch.cuda.empty_cache()
639
 
640
  current_time = time.time()
641
- if current_time - last_yield_time > 0.1:
642
- progress = (processed_frames / total_frames) * 100
643
- elapsed_time = current_time - start_time
644
- fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
645
- yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
646
- last_yield_time = current_time
647
 
648
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
649
- current_time = frame_idx / fps
650
-
651
  boxes = result.boxes
652
  track_inputs = []
653
 
@@ -655,19 +509,12 @@ def process_video(video_data, temp_dir):
655
  cls = int(box.cls)
656
  conf = float(box.conf)
657
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
658
-
659
- if label is None:
660
- continue
661
-
662
- if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
663
- continue
664
-
665
- bbox = box.xywh.cpu().numpy()[0]
666
- track_inputs.append({
667
- "bbox": bbox,
668
- "conf": conf,
669
- "cls": cls
670
- })
671
 
672
  if not track_inputs:
673
  continue
@@ -677,15 +524,11 @@ def process_video(video_data, temp_dir):
677
  np.array([t["conf"] for t in track_inputs]),
678
  np.array([t["cls"] for t in track_inputs])
679
  )
680
- logger.info(f"Frame {frame_idx}: Detected {len(tracked_objects)} workers")
681
 
682
  for obj in tracked_objects:
683
  tracker_id = obj['id']
684
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
685
- conf = obj['score']
686
- bbox = obj['bbox']
687
-
688
- if label is None:
689
  continue
690
 
691
  if tracker_id not in worker_id_mapping:
@@ -693,108 +536,81 @@ def process_video(video_data, temp_dir):
693
  worker_counter += 1
694
 
695
  worker_id = worker_id_mapping[tracker_id]
696
-
697
  violation_key = (worker_id, label)
698
 
699
  if violation_key not in unique_violations:
700
- unique_violations[violation_key] = current_time
701
  violation_frames[violation_key] = frame_idx
702
- # Update violation count for this worker
703
  if worker_id not in worker_violation_count:
704
  worker_violation_count[worker_id] = 0
705
  worker_violation_count[worker_id] += 1
706
 
707
  cap.release()
708
- processing_time = time.time() - start_time
709
- logger.info(f"Processing complete in {processing_time:.2f}s")
710
- logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
711
- logger.info(f"Violations per worker: {worker_violation_count}")
712
-
713
- violations = []
714
- for (worker_id, label), detection_time in unique_violations.items():
715
- violations.append({
716
- "worker_id": worker_id,
717
- "violation": label,
718
- "timestamp": detection_time,
719
- "confidence": 0.0,
720
- "frame_idx": violation_frames[(worker_id, label)]
721
- })
722
 
723
  if not violations:
724
- logger.info("No violations detected after processing")
725
- yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
726
  return
727
 
 
728
  snapshots = []
729
  cap = cv2.VideoCapture(video_path)
730
  for violation in violations:
731
- frame_idx = violation["frame_idx"]
732
- cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
733
  ret, frame = cap.read()
734
  if not ret:
735
- logger.warning(f"Failed to read frame {frame_idx} for snapshot.")
736
  continue
737
 
738
  frame = preprocess_frame(frame)
739
  frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
740
- frame_tensor = frame_tensor.unsqueeze(0).to(device)
741
  if device.type == "cuda":
742
- frame_tensor = frame_tensor.half()
743
-
744
- result = model(frame_tensor, device=device, conf=0.1, verbose=False)[0]
745
- boxes = result.boxes
746
 
747
- for box in boxes:
 
748
  cls = int(box.cls)
749
  conf = float(box.conf)
750
- label = CONFIG["VIOLATION_LABELS"].get(cls, None)
751
- if label == violation["violation"]:
752
  violation["confidence"] = round(conf, 2)
753
  bbox = box.xywh.cpu().numpy()[0]
754
- detection = {
755
  "worker_id": violation["worker_id"],
756
- "violation": label,
757
  "confidence": violation["confidence"],
758
  "bounding_box": bbox,
759
  "timestamp": violation["timestamp"]
760
- }
761
- snapshot_frame = frame.copy()
762
- snapshot_frame = draw_detections(snapshot_frame, [detection])
763
- cv2.putText(
764
- snapshot_frame,
765
- f"Time: {violation['timestamp']:.2f}s",
766
- (10, 30),
767
- cv2.FONT_HERSHEY_SIMPLEX,
768
- 0.7,
769
- (255, 255, 255),
770
- 2
771
- )
772
- snapshot_filename = f"violation_{label}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
773
  snapshot_path = os.path.join(output_dir, snapshot_filename)
774
- cv2.imwrite(
775
- snapshot_path,
776
- snapshot_frame,
777
- [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
778
- )
779
  snapshots.append({
780
- "violation": label,
781
  "worker_id": violation["worker_id"],
782
  "timestamp": violation["timestamp"],
783
  "snapshot_path": snapshot_path,
784
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
785
  "confidence": violation["confidence"]
786
  })
787
- logger.info(f"Captured snapshot for {label} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
788
  break
789
-
790
  cap.release()
791
 
792
  score = calculate_safety_score(violations)
793
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
794
-
795
  record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
796
 
797
- # Generate summary of workers and their violations
798
  worker_summary = {}
799
  for v in violations:
800
  worker_id = v["worker_id"]
@@ -806,36 +622,29 @@ def process_video(video_data, temp_dir):
806
  worker_summary[worker_id]["count"] += 1
807
  worker_summary[worker_id]["violations"].add(v["violation"])
808
 
809
- # Create violation table with worker summary
810
  violation_table = "## Worker Safety Violation Summary\n\n"
811
- violation_table += "| Worker ID | Total Violations | Violation Types |\n"
812
- violation_table += "|-----------|------------------|-----------------|\n"
 
 
813
 
814
  for worker_id, info in worker_summary.items():
815
  violation_types = ", ".join([CONFIG["DISPLAY_NAMES"].get(v, v) for v in info["violations"]])
816
  violation_table += f"| {worker_id} | {info['count']} | {violation_types} |\n"
817
 
818
- violation_table += "\n## Detailed Violation Log\n\n"
819
- violation_table += "| Violation | Worker ID | Time (s) | Confidence |\n"
820
  violation_table += "|-----------|-----------|----------|------------|\n"
821
 
822
- for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
823
- display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
824
- worker_id = v.get("worker_id", "Unknown")
825
- timestamp = v.get("timestamp", 0.0)
826
- confidence = v.get("confidence", 0.0)
827
- violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
828
-
829
- snapshots_text = ""
830
- for s in snapshots:
831
- display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
832
- worker_id = s.get("worker_id", "Unknown")
833
- timestamp = s.get("timestamp", 0.0)
834
- snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
835
- snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
836
 
837
- if not snapshots_text:
838
- snapshots_text = "No snapshots captured."
 
 
 
839
 
840
  yield (
841
  violation_table,
@@ -852,55 +661,33 @@ def process_video(video_data, temp_dir):
852
  if video_path and os.path.exists(video_path):
853
  try:
854
  os.remove(video_path)
855
- logger.info(f"Cleaned up temporary video file: {video_path}")
856
  except Exception as e:
857
- logger.error(f"Failed to clean up temporary video file {video_path}: {e}")
858
  if device.type == "cuda":
859
  torch.cuda.empty_cache()
860
 
861
  def gradio_interface(video_file):
862
  temp_dir = None
863
- local_video_path = None
864
  try:
865
  if not video_file:
866
  return "No file uploaded.", "", "No file uploaded.", "", ""
867
 
868
  temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
869
- logger.info(f"Created temporary directory for video processing: {temp_dir}")
870
-
871
  with open(video_file, "rb") as f:
872
  video_data = f.read()
873
- logger.info(f"Read Gradio video file: {video_file}, size: {len(video_data)} bytes")
874
-
875
- if len(video_data) == 0:
876
- return "Uploaded video file is empty.", "", "", "", ""
877
-
878
- with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
879
- temp_file.write(video_data)
880
- temp_file.flush()
881
- local_video_path = temp_file.name
882
- logger.info(f"Copied Gradio video to local temporary file: {local_video_path}")
883
 
884
  if not FFMPEG_AVAILABLE:
885
- return "FFmpeg is not available in the environment. Please install FFmpeg to process videos.", "", "", "", ""
886
 
887
- for status, score, snapshots_text, record_id, details_url in process_video(video_data, temp_dir):
888
- yield status, score, snapshots_text, record_id, details_url
889
 
890
  except Exception as e:
891
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
892
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
893
  finally:
894
- if local_video_path and os.path.exists(local_video_path):
895
- try:
896
- os.remove(local_video_path)
897
- logger.info(f"Cleaned up local temporary video file: {local_video_path}")
898
- except Exception as e:
899
- logger.error(f"Failed to clean up local temporary video file {local_video_path}: {e}")
900
-
901
  if temp_dir and os.path.exists(temp_dir):
902
  shutil.rmtree(temp_dir, ignore_errors=True)
903
- logger.info(f"Cleaned up temporary directory: {temp_dir}")
904
  if device.type == "cuda":
905
  torch.cuda.empty_cache()
906
 
@@ -916,10 +703,10 @@ interface = gr.Interface(
916
  gr.Textbox(label="Violation Details URL")
917
  ],
918
  title="Worksite Safety Violation Analyzer",
919
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
920
  allow_flagging="never"
921
  )
922
 
923
  if __name__ == "__main__":
924
- logger.info("Launching Enhanced Safety Analyzer App...")
925
  interface.launch()
 
43
  def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
+ self.match_thresh = match_thresh
47
  self.frame_rate = frame_rate
48
  self.next_id = 1
49
  self.tracks = {}
50
  self.worker_history = {}
51
  self.last_positions = {}
52
+ self.recently_removed = {}
53
 
54
  def update(self, dets, scores, cls):
55
  tracks = []
56
  current_time = time.time()
57
 
58
  # Prune stale tracks
59
+ stale_ids = [tid for tid, track in self.tracks.items()
60
+ if current_time - track['last_seen'] > self.track_buffer / self.frame_rate]
61
+
62
+ for tid in stale_ids:
63
+ self.recently_removed[tid] = {
64
+ 'bbox': self.tracks[tid]['bbox'],
 
 
 
65
  'last_seen': current_time,
66
+ 'last_position': self.last_positions.get(tid, [0, 0])
67
  }
68
+ self.tracks.pop(tid, None)
69
+ self.worker_history.pop(tid, None)
70
+ self.last_positions.pop(tid, None)
71
+
72
+ # Clean up recently_removed
73
+ for tid in [tid for tid, info in self.recently_removed.items()
74
+ if current_time - info['last_seen'] > 1.0]:
75
+ self.recently_removed.pop(tid, None)
 
 
 
 
 
76
 
77
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
78
  if score < self.track_thresh:
 
83
  best_iou = 0
84
  best_track_id = None
85
 
86
+ # Match with active tracks
87
  for track_id, track_info in self.tracks.items():
88
  tx, ty, tw, th = track_info['bbox']
89
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
 
90
  if iou > self.match_thresh and iou > best_iou:
91
  best_iou = iou
92
  best_track_id = track_id
93
  matched = True
94
 
95
  if matched:
96
+ self._update_track(best_track_id, x, y, w, h, score, cl, current_time)
 
 
 
 
 
 
 
 
 
 
97
  tracks.append({
98
  'id': best_track_id,
99
  'bbox': [x, y, w, h],
 
104
  # Try to re-identify with recently removed tracks
105
  reidentified = False
106
  for track_id, info in self.recently_removed.items():
107
+ if self._is_same_worker([x, y], info['last_position'], threshold=100):
108
+ self._update_track(track_id, x, y, w, h, score, cl, current_time)
 
 
 
 
 
 
 
109
  tracks.append({
110
  'id': track_id,
111
  'bbox': [x, y, w, h],
 
113
  'cls': cl
114
  })
115
  reidentified = True
116
+ self.recently_removed.pop(track_id, None)
117
  break
118
 
119
  if not reidentified:
120
+ # Check existing workers by position
121
  same_worker = False
122
  for worker_id, last_pos in self.last_positions.items():
123
+ if self._is_same_worker([x, y], last_pos, threshold=100):
124
+ self._update_track(worker_id, x, y, w, h, score, cl, current_time)
 
 
 
 
 
125
  tracks.append({
126
  'id': worker_id,
127
  'bbox': [x, y, w, h],
 
132
  break
133
 
134
  if not same_worker:
135
+ self._update_track(self.next_id, x, y, w, h, score, cl, current_time)
 
 
 
 
 
 
 
136
  tracks.append({
137
  'id': self.next_id,
138
  'bbox': [x, y, w, h],
 
143
 
144
  return tracks
145
 
146
+ def _update_track(self, track_id, x, y, w, h, score, cls, current_time):
147
+ self.tracks[track_id] = {
148
+ 'bbox': [x, y, w, h],
149
+ 'score': score,
150
+ 'cls': cls,
151
+ 'last_seen': current_time
152
+ }
153
+ if track_id not in self.worker_history:
154
+ self.worker_history[track_id] = []
155
+ self.worker_history[track_id].append([x, y])
156
+ self.last_positions[track_id] = [x, y]
157
+
158
  def _calculate_iou(self, box1, box2):
159
  x1, y1, w1, h1 = box1
160
  x2, y2, w2, h2 = box2
 
167
  intersection_area = (x_right - x_left) * (y_bottom - y_top)
168
  box1_area = w1 * h1
169
  box2_area = w2 * h2
170
+ return intersection_area / (box1_area + box2_area - intersection_area)
 
171
 
172
+ def _is_same_worker(self, pos1, pos2, threshold=100):
173
  x1, y1 = pos1
174
  x2, y2 = pos2
175
+ return np.sqrt((x1 - x2)**2 + (y1 - y2)**2) < threshold
 
176
 
177
  # ========================== # Optimized Configuration # ==========================
178
  CONFIG = {
 
193
  "improper_tool_use": (255, 255, 0)
194
  },
195
  "DISPLAY_NAMES": {
196
+ "no_helmet": "No Helmet",
197
+ "no_harness": "No Harness",
198
  "unsafe_posture": "Unsafe Posture",
199
+ "unsafe_zone": "Unsafe Zone",
200
  "improper_tool_use": "Improper Tool Use"
201
  },
202
  "SF_CREDENTIALS": {
 
215
  },
216
  "MIN_VIOLATION_FRAMES": 1,
217
  "VIOLATION_COOLDOWN": 30.0,
218
+ "WORKER_TRACKING_DURATION": 10.0,
219
  "MAX_PROCESSING_TIME": 60,
220
+ "FRAME_SKIP": 2, # Increased frame skip for faster processing
221
+ "BATCH_SIZE": 8, # Increased batch size
222
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
223
+ "TRACK_BUFFER": 150,
224
  "TRACK_THRESH": 0.3,
225
+ "MATCH_THRESH": 0.5,
226
  "SNAPSHOT_QUALITY": 95,
227
+ "MAX_WORKER_DISTANCE": 100, # Reduced threshold for better worker matching
228
  "TARGET_RESOLUTION": (384, 384)
229
  }
230
 
 
256
 
257
  # ========================== # Helper Functions # ==========================
258
  def preprocess_frame(frame):
259
+ frame = cv2.resize(frame, CONFIG["TARGET_RESOLUTION"], interpolation=cv2.INTER_LINEAR)
260
+ return cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
 
 
261
 
262
  def draw_detections(frame, detections):
263
  result_frame = frame.copy()
 
264
  for det in detections:
265
  label = det.get("violation", "Unknown")
266
  confidence = det.get("confidence", 0.0)
267
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
268
  worker_id = det.get("worker_id", "Unknown")
269
 
270
+ x1, y1 = int(x - w/2), int(y - h/2)
271
+ x2, y2 = int(x + w/2), int(y + h/2)
 
 
 
272
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
273
 
274
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
 
275
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
276
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
277
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
278
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
279
+ cv2.putText(result_frame, f"Conf: {confidence:.2f}", (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
 
 
 
280
  return result_frame
281
 
282
  def calculate_safety_score(violations):
 
287
  "unsafe_zone": 35,
288
  "improper_tool_use": 25
289
  }
 
290
  worker_violations = {}
291
  for v in violations:
292
  worker_id = v.get("worker_id", "Unknown")
 
 
293
  if worker_id not in worker_violations:
294
  worker_violations[worker_id] = set()
295
+ worker_violations[worker_id].add(v.get("violation", "Unknown"))
296
+ return max(0, 100 - sum(penalties.get(v, 0) for violations in worker_violations.values() for v in violations))
 
 
 
 
 
 
 
297
 
298
  def generate_violation_pdf(violations, score, output_dir):
299
  try:
 
304
 
305
  c.setFont("Helvetica-Bold", 16)
306
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
 
307
  c.setFont("Helvetica", 12)
308
  c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
309
  c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
 
310
  c.setFont("Helvetica-Bold", 14)
311
  c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
312
 
 
342
  for worker_id, worker_vios in worker_violations.items():
343
  c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
344
  y_position -= 0.2 * inch
 
345
  for v in worker_vios:
346
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
347
  time_str = f"{v.get('timestamp', 0.0):.2f}s"
348
  conf_str = f"{v.get('confidence', 0.0):.2f}"
349
+ c.drawString(1.2 * inch, y_position, f" - {display_name} at {time_str} (Confidence: {conf_str})")
 
 
350
  y_position -= 0.2 * inch
 
351
  if y_position < 1 * inch:
352
  c.showPage()
353
  c.setFont("Helvetica", 10)
 
355
 
356
  c.save()
357
  pdf_file.seek(0)
 
358
  with open(pdf_path, "wb") as f:
359
  f.write(pdf_file.getvalue())
360
+ return pdf_path, f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}", pdf_file
 
 
 
361
  except Exception as e:
362
  logger.error(f"Error generating PDF: {e}")
363
  return "", "", None
 
367
  try:
368
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
369
  logger.info("Connected to Salesforce")
 
370
  return sf
371
  except Exception as e:
372
  logger.error(f"Salesforce connection failed: {e}")
 
375
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
376
  try:
377
  if not pdf_file:
 
378
  return ""
 
379
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
380
+ content_version = sf.ContentVersion.create({
381
  "Title": f"Safety_Violation_Report_{int(time.time())}",
382
  "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
383
  "VersionData": encoded_pdf,
384
  "FirstPublishLocationId": report_id
385
+ })
 
386
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
387
+ if result['records']:
388
+ return f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
389
+ return ""
 
 
 
 
 
390
  except Exception as e:
391
  logger.error(f"Error uploading PDF to Salesforce: {e}")
392
  return ""
 
394
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
395
  try:
396
  sf = connect_to_salesforce()
397
+ violations_text = "\n".join(
398
+ f"Worker {v.get('worker_id', 'Unknown')}: "
399
+ f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} "
400
+ f"at {v.get('timestamp', 0.0):.2f}s (Conf: {v.get('confidence', 0.0):.2f})"
401
+ for v in violations
402
+ ) or "No violations detected."
403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  record_data = {
405
  "Compliance_Score__c": score,
406
  "Violations_Found__c": len(violations),
407
  "Violations_Details__c": violations_text,
408
  "Status__c": "Pending",
409
+ "PDF_Report_URL__c": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
410
  }
411
 
 
 
412
  try:
413
  record = sf.Safety_Video_Report__c.create(record_data)
414
+ except Exception:
 
 
415
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
 
416
 
417
  record_id = record["id"]
418
 
 
421
  if uploaded_url:
422
  try:
423
  sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
424
+ except Exception:
 
 
425
  sf.Account.update(record_id, {"Description": uploaded_url})
426
+ return record_id, uploaded_url
427
+ return record_id, record_data["PDF_Report_URL__c"]
 
 
428
  except Exception as e:
429
  logger.error(f"Salesforce record creation failed: {e}")
430
  return "N/A", "Salesforce integration failed."
431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  def process_video(video_data, temp_dir):
433
  video_path = None
434
  output_dir = os.path.join(temp_dir, "output")
435
  os.makedirs(output_dir, exist_ok=True)
436
+
 
437
  try:
438
+ # Save video to temp file
 
 
 
 
 
 
439
  with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
440
  temp_file.write(video_data)
 
441
  video_path = temp_file.name
 
 
 
 
 
 
 
 
 
 
 
442
 
443
+ cap = cv2.VideoCapture(video_path)
444
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
445
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
 
 
 
 
 
 
 
 
446
  tracker = BYTETracker(
447
  track_thresh=CONFIG["TRACK_THRESH"],
448
  track_buffer=CONFIG["TRACK_BUFFER"],
 
453
  worker_id_mapping = {}
454
  unique_violations = {}
455
  violation_frames = {}
456
+ worker_violation_count = {}
457
  start_time = time.time()
 
458
  processed_frames = 0
 
459
  worker_counter = 1
460
 
461
  while processed_frames < total_frames:
 
469
 
470
  ret, frame = cap.read()
471
  if not ret:
 
472
  break
473
 
474
  frame = preprocess_frame(frame)
 
 
 
 
 
475
  batch_frames.append(frame)
476
  batch_indices.append(frame_idx)
477
  processed_frames += 1
478
+
479
+ # Skip frames for faster processing
480
+ for _ in range(CONFIG["FRAME_SKIP"] - 1):
481
+ if not cap.grab():
482
+ break
483
+ processed_frames += 1
484
 
485
  if not batch_frames:
 
486
  break
487
 
488
  try:
489
  batch_frames_np = np.array(batch_frames)
490
  batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
 
491
  if device.type == "cuda":
492
+ batch_frames_tensor = batch_frames_tensor.half().to(device)
 
493
  results = model(batch_frames_tensor, device=device, conf=0.1, verbose=False)
494
  except Exception as e:
495
  logger.error(f"Model inference failed: {e}")
496
+ raise ValueError(f"Failed to process video frames: {str(e)}")
 
 
 
 
497
 
498
  current_time = time.time()
499
+ if current_time - start_time > CONFIG["MAX_PROCESSING_TIME"]:
500
+ logger.warning("Max processing time reached")
501
+ break
 
 
 
502
 
503
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
504
+ current_timestamp = frame_idx / fps
 
505
  boxes = result.boxes
506
  track_inputs = []
507
 
 
509
  cls = int(box.cls)
510
  conf = float(box.conf)
511
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
512
+ if label and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
513
+ track_inputs.append({
514
+ "bbox": box.xywh.cpu().numpy()[0],
515
+ "conf": conf,
516
+ "cls": cls
517
+ })
 
 
 
 
 
 
 
518
 
519
  if not track_inputs:
520
  continue
 
524
  np.array([t["conf"] for t in track_inputs]),
525
  np.array([t["cls"] for t in track_inputs])
526
  )
 
527
 
528
  for obj in tracked_objects:
529
  tracker_id = obj['id']
530
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
531
+ if not label:
 
 
 
532
  continue
533
 
534
  if tracker_id not in worker_id_mapping:
 
536
  worker_counter += 1
537
 
538
  worker_id = worker_id_mapping[tracker_id]
 
539
  violation_key = (worker_id, label)
540
 
541
  if violation_key not in unique_violations:
542
+ unique_violations[violation_key] = current_timestamp
543
  violation_frames[violation_key] = frame_idx
 
544
  if worker_id not in worker_violation_count:
545
  worker_violation_count[worker_id] = 0
546
  worker_violation_count[worker_id] += 1
547
 
548
  cap.release()
549
+ logger.info(f"Processing complete in {time.time() - start_time:.2f}s")
550
+ logger.info(f"Workers detected: {worker_violation_count}")
551
+
552
+ # Prepare violations list
553
+ violations = [{
554
+ "worker_id": worker_id,
555
+ "violation": label,
556
+ "timestamp": timestamp,
557
+ "confidence": 0.0,
558
+ "frame_idx": violation_frames[(worker_id, label)]
559
+ } for (worker_id, label), timestamp in unique_violations.items()]
 
 
 
560
 
561
  if not violations:
562
+ yield "No violations detected.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
 
563
  return
564
 
565
+ # Capture snapshots of violations
566
  snapshots = []
567
  cap = cv2.VideoCapture(video_path)
568
  for violation in violations:
569
+ cap.set(cv2.CAP_PROP_POS_FRAMES, violation["frame_idx"])
 
570
  ret, frame = cap.read()
571
  if not ret:
 
572
  continue
573
 
574
  frame = preprocess_frame(frame)
575
  frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
 
576
  if device.type == "cuda":
577
+ frame_tensor = frame_tensor.half().to(device)
 
 
 
578
 
579
+ result = model(frame_tensor.unsqueeze(0), device=device, conf=0.1, verbose=False)[0]
580
+ for box in result.boxes:
581
  cls = int(box.cls)
582
  conf = float(box.conf)
583
+ if CONFIG["VIOLATION_LABELS"].get(cls, None) == violation["violation"]:
 
584
  violation["confidence"] = round(conf, 2)
585
  bbox = box.xywh.cpu().numpy()[0]
586
+ snapshot_frame = draw_detections(frame.copy(), [{
587
  "worker_id": violation["worker_id"],
588
+ "violation": violation["violation"],
589
  "confidence": violation["confidence"],
590
  "bounding_box": bbox,
591
  "timestamp": violation["timestamp"]
592
+ }])
593
+ cv2.putText(snapshot_frame, f"Time: {violation['timestamp']:.2f}s",
594
+ (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
595
+ snapshot_filename = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
 
 
 
 
 
 
 
 
 
596
  snapshot_path = os.path.join(output_dir, snapshot_filename)
597
+ cv2.imwrite(snapshot_path, snapshot_frame, [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
 
 
 
 
598
  snapshots.append({
599
+ "violation": violation["violation"],
600
  "worker_id": violation["worker_id"],
601
  "timestamp": violation["timestamp"],
602
  "snapshot_path": snapshot_path,
603
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
604
  "confidence": violation["confidence"]
605
  })
 
606
  break
 
607
  cap.release()
608
 
609
  score = calculate_safety_score(violations)
610
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
 
611
  record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
612
 
613
+ # Generate output
614
  worker_summary = {}
615
  for v in violations:
616
  worker_id = v["worker_id"]
 
622
  worker_summary[worker_id]["count"] += 1
623
  worker_summary[worker_id]["violations"].add(v["violation"])
624
 
 
625
  violation_table = "## Worker Safety Violation Summary\n\n"
626
+ violation_table += f"**Total Workers with Violations:** {len(worker_summary)}\n"
627
+ violation_table += f"**Total Violations Found:** {len(violations)}\n\n"
628
+ violation_table += "| Worker ID | Violation Count | Violation Types |\n"
629
+ violation_table += "|-----------|-----------------|-----------------|\n"
630
 
631
  for worker_id, info in worker_summary.items():
632
  violation_types = ", ".join([CONFIG["DISPLAY_NAMES"].get(v, v) for v in info["violations"]])
633
  violation_table += f"| {worker_id} | {info['count']} | {violation_types} |\n"
634
 
635
+ violation_table += "\n## Detailed Violations\n\n"
636
+ violation_table += "| Worker ID | Violation | Time (s) | Confidence |\n"
637
  violation_table += "|-----------|-----------|----------|------------|\n"
638
 
639
+ for v in sorted(violations, key=lambda x: (x["worker_id"], x["timestamp"])):
640
+ display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], "Unknown")
641
+ violation_table += f"| {v['worker_id']} | {display_name} | {v['timestamp']:.2f} | {v['confidence']:.2f} |\n"
 
 
 
 
 
 
 
 
 
 
 
642
 
643
+ snapshots_text = "\n".join(
644
+ f"### {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} - Worker {s['worker_id']} at {s['timestamp']:.2f}s\n\n"
645
+ f"![Violation]({s['snapshot_url']})\n"
646
+ for s in snapshots
647
+ ) or "No snapshots captured."
648
 
649
  yield (
650
  violation_table,
 
661
  if video_path and os.path.exists(video_path):
662
  try:
663
  os.remove(video_path)
 
664
  except Exception as e:
665
+ logger.error(f"Failed to clean up video file: {e}")
666
  if device.type == "cuda":
667
  torch.cuda.empty_cache()
668
 
669
  def gradio_interface(video_file):
670
  temp_dir = None
 
671
  try:
672
  if not video_file:
673
  return "No file uploaded.", "", "No file uploaded.", "", ""
674
 
675
  temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
 
 
676
  with open(video_file, "rb") as f:
677
  video_data = f.read()
 
 
 
 
 
 
 
 
 
 
678
 
679
  if not FFMPEG_AVAILABLE:
680
+ return "FFmpeg not available. Please install FFmpeg.", "", "", "", ""
681
 
682
+ for output in process_video(video_data, temp_dir):
683
+ yield output
684
 
685
  except Exception as e:
686
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
687
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
688
  finally:
 
 
 
 
 
 
 
689
  if temp_dir and os.path.exists(temp_dir):
690
  shutil.rmtree(temp_dir, ignore_errors=True)
 
691
  if device.type == "cuda":
692
  torch.cuda.empty_cache()
693
 
 
703
  gr.Textbox(label="Violation Details URL")
704
  ],
705
  title="Worksite Safety Violation Analyzer",
706
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use).",
707
  allow_flagging="never"
708
  )
709
 
710
  if __name__ == "__main__":
711
+ logger.info("Launching Safety Analyzer App...")
712
  interface.launch()