PrashanthB461 commited on
Commit
f2bfa69
·
verified ·
1 Parent(s): 4c61ad0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -261
app.py CHANGED
@@ -43,25 +43,24 @@ class BYTETracker:
43
  def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
- self.match_thresh = match_thresh
47
  self.frame_rate = frame_rate
48
  self.next_id = 1
49
  self.tracks = {}
50
  self.worker_history = {}
51
  self.last_positions = {}
52
  self.recently_removed = {} # Store recently removed tracks for re-identification
53
- self.helmet_status = {} # Track helmet status for each worker
54
 
55
  def update(self, dets, scores, cls):
56
  tracks = []
57
  current_time = time.time()
58
-
59
  # Prune stale tracks
60
  stale_ids = []
61
  for track_id, track_info in self.tracks.items():
62
  if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
63
  stale_ids.append(track_id)
64
-
65
  for track_id in stale_ids:
66
  # Store recently removed tracks for re-identification (for 1 second)
67
  self.recently_removed[track_id] = {
@@ -86,22 +85,22 @@ class BYTETracker:
86
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
87
  if score < self.track_thresh:
88
  continue
89
-
90
  x, y, w, h = det
91
  matched = False
92
  best_iou = 0
93
  best_track_id = None
94
-
95
  # Try to match with active tracks
96
  for track_id, track_info in self.tracks.items():
97
  tx, ty, tw, th = track_info['bbox']
98
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
99
-
100
  if iou > self.match_thresh and iou > best_iou:
101
  best_iou = iou
102
  best_track_id = track_id
103
  matched = True
104
-
105
  if matched:
106
  self.tracks[best_track_id].update({
107
  'bbox': [x, y, w, h],
@@ -109,18 +108,11 @@ class BYTETracker:
109
  'cls': cl,
110
  'last_seen': current_time
111
  })
112
-
113
- # Update helmet status if this is a helmet detection
114
- if cl == 0: # Helmet violation class
115
- # Higher confidence for helmet violations
116
- if score > 0.45: # Increased threshold for helmet violations
117
- self.helmet_status[best_track_id] = True
118
-
119
  if best_track_id not in self.worker_history:
120
  self.worker_history[best_track_id] = []
121
  self.worker_history[best_track_id].append([x, y])
122
  self.last_positions[best_track_id] = [x, y]
123
-
124
  tracks.append({
125
  'id': best_track_id,
126
  'bbox': [x, y, w, h],
@@ -140,13 +132,6 @@ class BYTETracker:
140
  }
141
  self.worker_history[track_id] = [[x, y]]
142
  self.last_positions[track_id] = [x, y]
143
-
144
- # Update helmet status if this is a helmet detection
145
- if cl == 0: # Helmet violation class
146
- # Higher confidence for helmet violations
147
- if score > 0.45: # Increased threshold for helmet violations
148
- self.helmet_status[track_id] = True
149
-
150
  tracks.append({
151
  'id': track_id,
152
  'bbox': [x, y, w, h],
@@ -156,7 +141,7 @@ class BYTETracker:
156
  reidentified = True
157
  del self.recently_removed[track_id]
158
  break
159
-
160
  if not reidentified:
161
  # Check if it matches an existing worker by position
162
  same_worker = False
@@ -168,13 +153,6 @@ class BYTETracker:
168
  'cls': cl,
169
  'last_seen': current_time
170
  }
171
-
172
- # Update helmet status if this is a helmet detection
173
- if cl == 0: # Helmet violation class
174
- # Higher confidence for helmet violations
175
- if score > 0.45: # Increased threshold for helmet violations
176
- self.helmet_status[worker_id] = True
177
-
178
  tracks.append({
179
  'id': worker_id,
180
  'bbox': [x, y, w, h],
@@ -183,7 +161,7 @@ class BYTETracker:
183
  })
184
  same_worker = True
185
  break
186
-
187
  if not same_worker:
188
  self.tracks[self.next_id] = {
189
  'bbox': [x, y, w, h],
@@ -193,13 +171,6 @@ class BYTETracker:
193
  }
194
  self.worker_history[self.next_id] = [[x, y]]
195
  self.last_positions[self.next_id] = [x, y]
196
-
197
- # Update helmet status if this is a helmet detection
198
- if cl == 0: # Helmet violation class
199
- # Higher confidence for helmet violations
200
- if score > 0.45: # Increased threshold for helmet violations
201
- self.helmet_status[self.next_id] = True
202
-
203
  tracks.append({
204
  'id': self.next_id,
205
  'bbox': [x, y, w, h],
@@ -207,7 +178,7 @@ class BYTETracker:
207
  'cls': cl
208
  })
209
  self.next_id += 1
210
-
211
  return tracks
212
 
213
  def _calculate_iou(self, box1, box2):
@@ -224,18 +195,13 @@ class BYTETracker:
224
  box2_area = w2 * h2
225
  iou = intersection_area / (box1_area + box2_area - intersection_area)
226
  return iou
227
-
228
- def _is_same_worker(self, pos1, pos2, threshold=150):
229
  x1, y1 = pos1
230
  x2, y2 = pos2
231
  distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
232
  return distance < threshold
233
 
234
- # Function to validate if a helmet violation is consistent across frames
235
- def validate_helmet_violation(self, worker_id, current_confidence):
236
- # If we have consistent high confidence or multiple detections, it's a valid violation
237
- return worker_id in self.helmet_status and self.helmet_status[worker_id]
238
-
239
  # ========================== # Optimized Configuration # ==========================
240
  CONFIG = {
241
  "MODEL_PATH": "yolov8_safety.pt",
@@ -269,26 +235,25 @@ CONFIG = {
269
  },
270
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
271
  "CONFIDENCE_THRESHOLDS": {
272
- "no_helmet": 0.45, # Increased threshold for helmet violations
273
  "no_harness": 0.25,
274
  "unsafe_posture": 0.25,
275
  "unsafe_zone": 0.25,
276
  "improper_tool_use": 0.25
277
  },
278
- "MIN_VIOLATION_FRAMES": 2, # Increased to require multiple frames for confirmation
279
  "VIOLATION_COOLDOWN": 30.0,
280
- "WORKER_TRACKING_DURATION": 10.0,
281
  "MAX_PROCESSING_TIME": 60,
282
- "FRAME_SKIP": 2, # Increased frame skip for faster processing
283
- "BATCH_SIZE": 8, # Increased batch size for better GPU utilization
284
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
285
- "TRACK_BUFFER": 150,
286
  "TRACK_THRESH": 0.3,
287
- "MATCH_THRESH": 0.5,
288
  "SNAPSHOT_QUALITY": 95,
289
- "MAX_WORKER_DISTANCE": 150,
290
- "TARGET_RESOLUTION": (384, 384),
291
- "HELMET_VALIDATION_FRAMES": 3 # Number of frames to validate helmet violations
292
  }
293
 
294
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -305,7 +270,7 @@ def load_model():
305
  if not os.path.isfile(model_path):
306
  logger.info(f"Downloading fallback model: {model_path}")
307
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
308
-
309
  model = YOLO(model_path).to(device)
310
  if device.type == "cuda":
311
  model.model.half()
@@ -320,23 +285,13 @@ model = load_model()
320
  # ========================== # Helper Functions # ==========================
321
  def preprocess_frame(frame):
322
  target_res = CONFIG["TARGET_RESOLUTION"]
323
- # Enhanced preprocessing for better helmet detection
324
  frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
325
- # Increase contrast to better differentiate helmets from other head coverings
326
- frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=20) # Increased contrast
327
-
328
- # Additional preprocessing to enhance head/helmet features
329
- # Apply slight sharpening to make edges more distinct
330
- kernel = np.array([[-1,-1,-1],
331
- [-1, 9,-1],
332
- [-1,-1,-1]])
333
- frame = cv2.filter2D(frame, -1, kernel)
334
-
335
  return frame
336
 
337
  def draw_detections(frame, detections):
338
  result_frame = frame.copy()
339
-
340
  for det in detections:
341
  label = det.get("violation", "Unknown")
342
  confidence = det.get("confidence", 0.0)
@@ -347,22 +302,19 @@ def draw_detections(frame, detections):
347
  y1 = int(y - h/2)
348
  x2 = int(x + w/2)
349
  y2 = int(y + h/2)
350
-
351
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
352
-
353
- # Make no_helmet violations more prominent
354
- line_thickness = 4 if label == "no_helmet" else 3
355
-
356
- cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, line_thickness)
357
-
358
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
359
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
360
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
361
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
362
-
363
  conf_text = f"Conf: {confidence:.2f}"
364
  cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
365
-
366
  return result_frame
367
 
368
  def calculate_safety_score(violations):
@@ -373,21 +325,21 @@ def calculate_safety_score(violations):
373
  "unsafe_zone": 35,
374
  "improper_tool_use": 25
375
  }
376
-
377
  worker_violations = {}
378
  for v in violations:
379
  worker_id = v.get("worker_id", "Unknown")
380
  violation_type = v.get("violation", "Unknown")
381
-
382
  if worker_id not in worker_violations:
383
  worker_violations[worker_id] = set()
384
  worker_violations[worker_id].add(violation_type)
385
-
386
  total_penalty = 0
387
  for worker_violations_set in worker_violations.values():
388
  worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
389
  total_penalty += worker_penalty
390
-
391
  score = max(0, 100 - total_penalty)
392
  return score
393
 
@@ -397,14 +349,14 @@ def generate_violation_pdf(violations, score, output_dir):
397
  pdf_path = os.path.join(output_dir, pdf_filename)
398
  pdf_file = BytesIO()
399
  c = canvas.Canvas(pdf_file, pagesize=letter)
400
-
401
  c.setFont("Helvetica-Bold", 16)
402
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
403
-
404
  c.setFont("Helvetica", 12)
405
  c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
406
  c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
407
-
408
  c.setFont("Helvetica-Bold", 14)
409
  c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
410
 
@@ -412,21 +364,21 @@ def generate_violation_pdf(violations, score, output_dir):
412
  c.setFont("Helvetica-Bold", 12)
413
  c.drawString(1 * inch, y_position, "Summary:")
414
  y_position -= 0.3 * inch
415
-
416
  worker_violations = {}
417
  for v in violations:
418
  worker_id = v.get("worker_id", "Unknown")
419
  if worker_id not in worker_violations:
420
  worker_violations[worker_id] = []
421
  worker_violations[worker_id].append(v)
422
-
423
  c.setFont("Helvetica", 10)
424
  summary_data = {
425
  "Total Workers with Violations": len(worker_violations),
426
  "Total Violations Found": len(violations),
427
  "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
428
  }
429
-
430
  for key, value in summary_data.items():
431
  c.drawString(1 * inch, y_position, f"{key}: {value}")
432
  y_position -= 0.25 * inch
@@ -435,21 +387,21 @@ def generate_violation_pdf(violations, score, output_dir):
435
  c.setFont("Helvetica-Bold", 12)
436
  c.drawString(1 * inch, y_position, "Violations by Worker:")
437
  y_position -= 0.3 * inch
438
-
439
  c.setFont("Helvetica", 10)
440
  for worker_id, worker_vios in worker_violations.items():
441
  c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
442
  y_position -= 0.2 * inch
443
-
444
  for v in worker_vios:
445
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
446
  time_str = f"{v.get('timestamp', 0.0):.2f}s"
447
  conf_str = f"{v.get('confidence', 0.0):.2f}"
448
-
449
  violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
450
  c.drawString(1.2 * inch, y_position, violation_text)
451
  y_position -= 0.2 * inch
452
-
453
  if y_position < 1 * inch:
454
  c.showPage()
455
  c.setFont("Helvetica", 10)
@@ -460,7 +412,7 @@ def generate_violation_pdf(violations, score, output_dir):
460
 
461
  with open(pdf_path, "wb") as f:
462
  f.write(pdf_file.getvalue())
463
-
464
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
465
  logger.info(f"PDF generated: {public_url}")
466
  return pdf_path, public_url, pdf_file
@@ -484,7 +436,7 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
484
  if not pdf_file:
485
  logger.error("No PDF file provided for upload")
486
  return ""
487
-
488
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
489
  content_version_data = {
490
  "Title": f"Safety_Violation_Report_{int(time.time())}",
@@ -494,11 +446,11 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
494
  }
495
  content_version = sf.ContentVersion.create(content_version_data)
496
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
497
-
498
  if not result['records']:
499
  logger.error("Failed to retrieve ContentVersion")
500
  return ""
501
-
502
  file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
503
  logger.info(f"PDF uploaded to Salesforce: {file_url}")
504
  return file_url
@@ -509,19 +461,19 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
509
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
510
  try:
511
  sf = connect_to_salesforce()
512
-
513
  violations_text = ""
514
  for v in violations:
515
  display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
516
  worker_id = v.get('worker_id', 'Unknown')
517
  timestamp = v.get('timestamp', 0.0)
518
  confidence = v.get('confidence', 0.0)
519
-
520
  violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
521
-
522
  if not violations_text:
523
  violations_text = "No violations detected."
524
-
525
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
526
 
527
  record_data = {
@@ -531,9 +483,9 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
531
  "Status__c": "Pending",
532
  "PDF_Report_URL__c": pdf_url
533
  }
534
-
535
  logger.info(f"Creating Salesforce record with data: {record_data}")
536
-
537
  try:
538
  record = sf.Safety_Video_Report__c.create(record_data)
539
  logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
@@ -541,7 +493,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
541
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
542
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
543
  logger.warning(f"Fell back to Account record: {record['id']}")
544
-
545
  record_id = record["id"]
546
 
547
  if pdf_file:
@@ -570,107 +522,30 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
570
  def verify_and_open_video(video_path):
571
  if not os.path.exists(video_path):
572
  raise FileNotFoundError(f"Temporary video file not found: {video_path}")
573
-
574
  file_size = os.path.getsize(video_path)
575
  if file_size == 0:
576
  raise ValueError(f"Temporary video file is empty: {video_path}")
577
-
578
  with open(video_path, "rb") as f:
579
  f.read(1)
580
-
581
  cap = cv2.VideoCapture(video_path)
582
  if not cap.isOpened():
583
  raise ValueError("Could not open video file. Ensure the video format is supported (e.g., MP4) and FFmpeg is installed.")
584
-
585
  return cap
586
 
587
- # Helper for helmet validation
588
- def validate_helmet_detection(frame, bbox, confidence_threshold=0.45):
589
- """
590
- Additional validation for helmet detection to reduce false positives.
591
- This function performs additional checks on the region to confirm it's a true helmet violation.
592
- """
593
- x, y, w, h = bbox
594
- x1 = int(max(0, x - w/2))
595
- y1 = int(max(0, y - h/2))
596
- x2 = int(min(frame.shape[1], x + w/2))
597
- y2 = int(min(frame.shape[0], y + h/2))
598
-
599
- # Extract head region
600
- head_region = frame[y1:y2, x1:x2]
601
- if head_region.size == 0:
602
- return False
603
-
604
- # Check if this is truly a helmet violation by analyzing the region
605
- # 1. Check color distribution - helmets often have more uniform color
606
- hsv = cv2.cvtColor(head_region, cv2.COLOR_BGR2HSV)
607
-
608
- # Check for typical helmet colors (many construction helmets are yellow, white, orange, blue)
609
- # This helps differentiate from cloth head coverings
610
- yellow_lower = np.array([20, 100, 100])
611
- yellow_upper = np.array([30, 255, 255])
612
- yellow_mask = cv2.inRange(hsv, yellow_lower, yellow_upper)
613
-
614
- white_lower = np.array([0, 0, 200])
615
- white_upper = np.array([180, 30, 255])
616
- white_mask = cv2.inRange(hsv, white_lower, white_upper)
617
-
618
- orange_lower = np.array([5, 100, 100])
619
- orange_upper = np.array([15, 255, 255])
620
- orange_mask = cv2.inRange(hsv, orange_lower, orange_upper)
621
-
622
- blue_lower = np.array([100, 100, 100])
623
- blue_upper = np.array([130, 255, 255])
624
- blue_mask = cv2.inRange(hsv, blue_lower, blue_upper)
625
-
626
- helmet_mask = cv2.bitwise_or(yellow_mask, white_mask)
627
- helmet_mask = cv2.bitwise_or(helmet_mask, orange_mask)
628
- helmet_mask = cv2.bitwise_or(helmet_mask, blue_mask)
629
-
630
- # If there's a significant amount of helmet-colored pixels, this might be a helmet
631
- helmet_percentage = np.sum(helmet_mask > 0) / (head_region.shape[0] * head_region.shape[1])
632
-
633
- # If the region has a significant amount of helmet-like colors, it's probably a helmet
634
- # so we should NOT flag it as a violation (return False)
635
- if helmet_percentage > 0.25:
636
- return False
637
-
638
- # Check texture uniformity - helmets have more uniform texture compared to head coverings
639
- gray = cv2.cvtColor(head_region, cv2.COLOR_BGR2GRAY)
640
- texture_score = np.std(gray)
641
-
642
- # If texture is very uniform (low standard deviation), it might be a helmet or bare head
643
- # Very uniform texture (like a hard helmet) would have low texture_score
644
- if texture_score < 15: # Low texture suggests uniform surface like a helmet
645
- return False
646
-
647
- # Additional check for cloth-like textures
648
- edges = cv2.Canny(gray, 50, 150)
649
- edge_density = np.sum(edges > 0) / (head_region.shape[0] * head_region.shape[1])
650
-
651
- # If there are many edges (cloth wrinkles), this might be a kurchief
652
- if edge_density > 0.15:
653
- # This is likely a cloth head covering, not a helmet violation
654
- # But also not a proper helmet, so we should still detect as violation
655
- return True
656
-
657
- # If confidence is very high, trust the model
658
- if confidence_threshold >= 0.6:
659
- return True
660
-
661
- # Default to the original detection
662
- return True
663
-
664
  def process_video(video_data, temp_dir):
665
  video_path = None
666
  output_dir = os.path.join(temp_dir, "output")
667
  os.makedirs(output_dir, exist_ok=True)
668
  os.environ['YOLO_CONFIG_DIR'] = temp_dir
669
-
670
  try:
671
  if not video_data:
672
  raise ValueError("Empty video data provided.")
673
-
674
  logger.info(f"Received video data size: {len(video_data)} bytes")
675
  if len(video_data) == 0:
676
  raise ValueError("Video data is empty.")
@@ -711,8 +586,7 @@ def process_video(video_data, temp_dir):
711
  worker_id_mapping = {}
712
  unique_violations = {}
713
  violation_frames = {}
714
- # Track helmet detections across frames for each worker
715
- helmet_detections = {}
716
  start_time = time.time()
717
  frame_skip = CONFIG["FRAME_SKIP"]
718
  processed_frames = 0
@@ -722,30 +596,25 @@ def process_video(video_data, temp_dir):
722
  while processed_frames < total_frames:
723
  batch_frames = []
724
  batch_indices = []
725
- batch_originals = [] # Store original frames for helmet validation
726
-
727
  for _ in range(CONFIG["BATCH_SIZE"]):
728
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
729
  if frame_idx >= total_frames:
730
  break
731
-
732
  ret, frame = cap.read()
733
  if not ret:
734
  logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
735
  break
736
-
737
- # Store original frame for validation
738
- original_frame = frame.copy()
739
-
740
  frame = preprocess_frame(frame)
741
-
742
  for _ in range(frame_skip - 1):
743
  if not cap.grab():
744
  break
745
-
746
  batch_frames.append(frame)
747
  batch_indices.append(frame_idx)
748
- batch_originals.append(original_frame)
749
  processed_frames += 1
750
 
751
  if not batch_frames:
@@ -776,34 +645,22 @@ def process_video(video_data, temp_dir):
776
  yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
777
  last_yield_time = current_time
778
 
779
- for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
780
  current_time = frame_idx / fps
781
-
782
  boxes = result.boxes
783
  track_inputs = []
784
-
785
  for box in boxes:
786
  cls = int(box.cls)
787
  conf = float(box.conf)
788
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
789
-
790
  if label is None:
791
  continue
792
-
793
- # Enhanced confidence threshold handling, especially for helmet detection
794
- if label == "no_helmet":
795
- if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.45):
796
- continue
797
-
798
- # Additional validation for helmet detection
799
- bbox = box.xywh.cpu().numpy()[0]
800
- if not validate_helmet_detection(original_frame, bbox, conf):
801
- logger.info(f"Frame {frame_idx}: Helmet false positive filtered at {conf:.2f} confidence")
802
- continue
803
- else:
804
- # Use regular thresholds for other violations
805
- if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
806
- continue
807
 
808
  bbox = box.xywh.cpu().numpy()[0]
809
  track_inputs.append({
@@ -814,7 +671,7 @@ def process_video(video_data, temp_dir):
814
 
815
  if not track_inputs:
816
  continue
817
-
818
  tracked_objects = tracker.update(
819
  np.array([t["bbox"] for t in track_inputs]),
820
  np.array([t["conf"] for t in track_inputs]),
@@ -827,52 +684,31 @@ def process_video(video_data, temp_dir):
827
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
828
  conf = obj['score']
829
  bbox = obj['bbox']
830
-
831
  if label is None:
832
  continue
833
-
834
  if tracker_id not in worker_id_mapping:
835
  worker_id_mapping[tracker_id] = worker_counter
836
  worker_counter += 1
837
-
838
  worker_id = worker_id_mapping[tracker_id]
839
-
840
- # Special handling for helmet violations to ensure consistency
841
- if label == "no_helmet":
842
- # Track helmet violations for this worker
843
- if worker_id not in helmet_detections:
844
- helmet_detections[worker_id] = []
845
-
846
- # Store this detection with frame index and confidence
847
- helmet_detections[worker_id].append({
848
- "frame_idx": frame_idx,
849
- "confidence": conf,
850
- "bbox": bbox
851
- })
852
-
853
- # Only record a helmet violation if we have multiple consistent detections
854
- if len(helmet_detections[worker_id]) >= CONFIG["HELMET_VALIDATION_FRAMES"]:
855
- # Calculate average confidence
856
- avg_conf = sum(d["confidence"] for d in helmet_detections[worker_id]) / len(helmet_detections[worker_id])
857
-
858
- # If confidence is consistently high across multiple frames, record the violation
859
- if avg_conf >= CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
860
- violation_key = (worker_id, label)
861
- if violation_key not in unique_violations:
862
- unique_violations[violation_key] = current_time
863
- violation_frames[violation_key] = frame_idx
864
- logger.info(f"Frame {frame_idx}: Valid helmet violation for worker {worker_id} with avg conf {avg_conf:.2f}")
865
- else:
866
- # Regular handling for other violations
867
- violation_key = (worker_id, label)
868
- if violation_key not in unique_violations:
869
- unique_violations[violation_key] = current_time
870
- violation_frames[violation_key] = frame_idx
871
 
872
  cap.release()
873
  processing_time = time.time() - start_time
874
  logger.info(f"Processing complete in {processing_time:.2f}s")
875
  logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
 
876
 
877
  violations = []
878
  for (worker_id, label), detection_time in unique_violations.items():
@@ -955,12 +791,34 @@ def process_video(video_data, temp_dir):
955
 
956
  score = calculate_safety_score(violations)
957
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
958
-
959
  record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
960
 
961
- violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
962
- violation_table += "|-----------|-----------|----------|------------|\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
963
 
 
 
 
 
 
 
 
 
964
  for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
965
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
966
  worker_id = v.get("worker_id", "Unknown")
@@ -1006,14 +864,14 @@ def gradio_interface(video_file):
1006
  try:
1007
  if not video_file:
1008
  return "No file uploaded.", "", "No file uploaded.", "", ""
1009
-
1010
  temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
1011
  logger.info(f"Created temporary directory for video processing: {temp_dir}")
1012
 
1013
  with open(video_file, "rb") as f:
1014
  video_data = f.read()
1015
  logger.info(f"Read Gradio video file: {video_file}, size: {len(video_data)} bytes")
1016
-
1017
  if len(video_data) == 0:
1018
  return "Uploaded video file is empty.", "", "", "", ""
1019
 
@@ -1028,7 +886,7 @@ def gradio_interface(video_file):
1028
 
1029
  for status, score, snapshots_text, record_id, details_url in process_video(video_data, temp_dir):
1030
  yield status, score, snapshots_text, record_id, details_url
1031
-
1032
  except Exception as e:
1033
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
1034
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
@@ -1039,7 +897,7 @@ def gradio_interface(video_file):
1039
  logger.info(f"Cleaned up local temporary video file: {local_video_path}")
1040
  except Exception as e:
1041
  logger.error(f"Failed to clean up local temporary video file {local_video_path}: {e}")
1042
-
1043
  if temp_dir and os.path.exists(temp_dir):
1044
  shutil.rmtree(temp_dir, ignore_errors=True)
1045
  logger.info(f"Cleaned up temporary directory: {temp_dir}")
 
43
  def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
+ self.match_thresh = match_thresh # Increased to 0.5 for better matching
47
  self.frame_rate = frame_rate
48
  self.next_id = 1
49
  self.tracks = {}
50
  self.worker_history = {}
51
  self.last_positions = {}
52
  self.recently_removed = {} # Store recently removed tracks for re-identification
 
53
 
54
  def update(self, dets, scores, cls):
55
  tracks = []
56
  current_time = time.time()
57
+
58
  # Prune stale tracks
59
  stale_ids = []
60
  for track_id, track_info in self.tracks.items():
61
  if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
62
  stale_ids.append(track_id)
63
+
64
  for track_id in stale_ids:
65
  # Store recently removed tracks for re-identification (for 1 second)
66
  self.recently_removed[track_id] = {
 
85
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
86
  if score < self.track_thresh:
87
  continue
88
+
89
  x, y, w, h = det
90
  matched = False
91
  best_iou = 0
92
  best_track_id = None
93
+
94
  # Try to match with active tracks
95
  for track_id, track_info in self.tracks.items():
96
  tx, ty, tw, th = track_info['bbox']
97
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
98
+
99
  if iou > self.match_thresh and iou > best_iou:
100
  best_iou = iou
101
  best_track_id = track_id
102
  matched = True
103
+
104
  if matched:
105
  self.tracks[best_track_id].update({
106
  'bbox': [x, y, w, h],
 
108
  'cls': cl,
109
  'last_seen': current_time
110
  })
 
 
 
 
 
 
 
111
  if best_track_id not in self.worker_history:
112
  self.worker_history[best_track_id] = []
113
  self.worker_history[best_track_id].append([x, y])
114
  self.last_positions[best_track_id] = [x, y]
115
+
116
  tracks.append({
117
  'id': best_track_id,
118
  'bbox': [x, y, w, h],
 
132
  }
133
  self.worker_history[track_id] = [[x, y]]
134
  self.last_positions[track_id] = [x, y]
 
 
 
 
 
 
 
135
  tracks.append({
136
  'id': track_id,
137
  'bbox': [x, y, w, h],
 
141
  reidentified = True
142
  del self.recently_removed[track_id]
143
  break
144
+
145
  if not reidentified:
146
  # Check if it matches an existing worker by position
147
  same_worker = False
 
153
  'cls': cl,
154
  'last_seen': current_time
155
  }
 
 
 
 
 
 
 
156
  tracks.append({
157
  'id': worker_id,
158
  'bbox': [x, y, w, h],
 
161
  })
162
  same_worker = True
163
  break
164
+
165
  if not same_worker:
166
  self.tracks[self.next_id] = {
167
  'bbox': [x, y, w, h],
 
171
  }
172
  self.worker_history[self.next_id] = [[x, y]]
173
  self.last_positions[self.next_id] = [x, y]
 
 
 
 
 
 
 
174
  tracks.append({
175
  'id': self.next_id,
176
  'bbox': [x, y, w, h],
 
178
  'cls': cl
179
  })
180
  self.next_id += 1
181
+
182
  return tracks
183
 
184
  def _calculate_iou(self, box1, box2):
 
195
  box2_area = w2 * h2
196
  iou = intersection_area / (box1_area + box2_area - intersection_area)
197
  return iou
198
+
199
+ def _is_same_worker(self, pos1, pos2, threshold=150): # Increased threshold to 150
200
  x1, y1 = pos1
201
  x2, y2 = pos2
202
  distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
203
  return distance < threshold
204
 
 
 
 
 
 
205
  # ========================== # Optimized Configuration # ==========================
206
  CONFIG = {
207
  "MODEL_PATH": "yolov8_safety.pt",
 
235
  },
236
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
237
  "CONFIDENCE_THRESHOLDS": {
238
+ "no_helmet": 0.4,
239
  "no_harness": 0.25,
240
  "unsafe_posture": 0.25,
241
  "unsafe_zone": 0.25,
242
  "improper_tool_use": 0.25
243
  },
244
+ "MIN_VIOLATION_FRAMES": 1,
245
  "VIOLATION_COOLDOWN": 30.0,
246
+ "WORKER_TRACKING_DURATION": 10.0, # Reverted to 5.0 seconds
247
  "MAX_PROCESSING_TIME": 60,
248
+ "FRAME_SKIP": 1,
249
+ "BATCH_SIZE": 4,
250
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
251
+ "TRACK_BUFFER": 150, # 5.0 seconds at 30 fps
252
  "TRACK_THRESH": 0.3,
253
+ "MATCH_THRESH": 0.5, # Increased to 0.5
254
  "SNAPSHOT_QUALITY": 95,
255
+ "MAX_WORKER_DISTANCE": 150, # Increased to match _is_same_worker threshold
256
+ "TARGET_RESOLUTION": (384, 384)
 
257
  }
258
 
259
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
270
  if not os.path.isfile(model_path):
271
  logger.info(f"Downloading fallback model: {model_path}")
272
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
273
+
274
  model = YOLO(model_path).to(device)
275
  if device.type == "cuda":
276
  model.model.half()
 
285
  # ========================== # Helper Functions # ==========================
286
  def preprocess_frame(frame):
287
  target_res = CONFIG["TARGET_RESOLUTION"]
 
288
  frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
289
+ frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
 
 
 
 
 
 
 
 
 
290
  return frame
291
 
292
  def draw_detections(frame, detections):
293
  result_frame = frame.copy()
294
+
295
  for det in detections:
296
  label = det.get("violation", "Unknown")
297
  confidence = det.get("confidence", 0.0)
 
302
  y1 = int(y - h/2)
303
  x2 = int(x + w/2)
304
  y2 = int(y + h/2)
305
+
306
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
307
+
308
+ cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
309
+
 
 
 
310
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
311
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
312
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
313
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
314
+
315
  conf_text = f"Conf: {confidence:.2f}"
316
  cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
317
+
318
  return result_frame
319
 
320
  def calculate_safety_score(violations):
 
325
  "unsafe_zone": 35,
326
  "improper_tool_use": 25
327
  }
328
+
329
  worker_violations = {}
330
  for v in violations:
331
  worker_id = v.get("worker_id", "Unknown")
332
  violation_type = v.get("violation", "Unknown")
333
+
334
  if worker_id not in worker_violations:
335
  worker_violations[worker_id] = set()
336
  worker_violations[worker_id].add(violation_type)
337
+
338
  total_penalty = 0
339
  for worker_violations_set in worker_violations.values():
340
  worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
341
  total_penalty += worker_penalty
342
+
343
  score = max(0, 100 - total_penalty)
344
  return score
345
 
 
349
  pdf_path = os.path.join(output_dir, pdf_filename)
350
  pdf_file = BytesIO()
351
  c = canvas.Canvas(pdf_file, pagesize=letter)
352
+
353
  c.setFont("Helvetica-Bold", 16)
354
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
355
+
356
  c.setFont("Helvetica", 12)
357
  c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
358
  c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
359
+
360
  c.setFont("Helvetica-Bold", 14)
361
  c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
362
 
 
364
  c.setFont("Helvetica-Bold", 12)
365
  c.drawString(1 * inch, y_position, "Summary:")
366
  y_position -= 0.3 * inch
367
+
368
  worker_violations = {}
369
  for v in violations:
370
  worker_id = v.get("worker_id", "Unknown")
371
  if worker_id not in worker_violations:
372
  worker_violations[worker_id] = []
373
  worker_violations[worker_id].append(v)
374
+
375
  c.setFont("Helvetica", 10)
376
  summary_data = {
377
  "Total Workers with Violations": len(worker_violations),
378
  "Total Violations Found": len(violations),
379
  "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
380
  }
381
+
382
  for key, value in summary_data.items():
383
  c.drawString(1 * inch, y_position, f"{key}: {value}")
384
  y_position -= 0.25 * inch
 
387
  c.setFont("Helvetica-Bold", 12)
388
  c.drawString(1 * inch, y_position, "Violations by Worker:")
389
  y_position -= 0.3 * inch
390
+
391
  c.setFont("Helvetica", 10)
392
  for worker_id, worker_vios in worker_violations.items():
393
  c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
394
  y_position -= 0.2 * inch
395
+
396
  for v in worker_vios:
397
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
398
  time_str = f"{v.get('timestamp', 0.0):.2f}s"
399
  conf_str = f"{v.get('confidence', 0.0):.2f}"
400
+
401
  violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
402
  c.drawString(1.2 * inch, y_position, violation_text)
403
  y_position -= 0.2 * inch
404
+
405
  if y_position < 1 * inch:
406
  c.showPage()
407
  c.setFont("Helvetica", 10)
 
412
 
413
  with open(pdf_path, "wb") as f:
414
  f.write(pdf_file.getvalue())
415
+
416
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
417
  logger.info(f"PDF generated: {public_url}")
418
  return pdf_path, public_url, pdf_file
 
436
  if not pdf_file:
437
  logger.error("No PDF file provided for upload")
438
  return ""
439
+
440
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
441
  content_version_data = {
442
  "Title": f"Safety_Violation_Report_{int(time.time())}",
 
446
  }
447
  content_version = sf.ContentVersion.create(content_version_data)
448
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
449
+
450
  if not result['records']:
451
  logger.error("Failed to retrieve ContentVersion")
452
  return ""
453
+
454
  file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
455
  logger.info(f"PDF uploaded to Salesforce: {file_url}")
456
  return file_url
 
461
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
462
  try:
463
  sf = connect_to_salesforce()
464
+
465
  violations_text = ""
466
  for v in violations:
467
  display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
468
  worker_id = v.get('worker_id', 'Unknown')
469
  timestamp = v.get('timestamp', 0.0)
470
  confidence = v.get('confidence', 0.0)
471
+
472
  violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
473
+
474
  if not violations_text:
475
  violations_text = "No violations detected."
476
+
477
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
478
 
479
  record_data = {
 
483
  "Status__c": "Pending",
484
  "PDF_Report_URL__c": pdf_url
485
  }
486
+
487
  logger.info(f"Creating Salesforce record with data: {record_data}")
488
+
489
  try:
490
  record = sf.Safety_Video_Report__c.create(record_data)
491
  logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
 
493
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
494
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
495
  logger.warning(f"Fell back to Account record: {record['id']}")
496
+
497
  record_id = record["id"]
498
 
499
  if pdf_file:
 
522
  def verify_and_open_video(video_path):
523
  if not os.path.exists(video_path):
524
  raise FileNotFoundError(f"Temporary video file not found: {video_path}")
525
+
526
  file_size = os.path.getsize(video_path)
527
  if file_size == 0:
528
  raise ValueError(f"Temporary video file is empty: {video_path}")
529
+
530
  with open(video_path, "rb") as f:
531
  f.read(1)
532
+
533
  cap = cv2.VideoCapture(video_path)
534
  if not cap.isOpened():
535
  raise ValueError("Could not open video file. Ensure the video format is supported (e.g., MP4) and FFmpeg is installed.")
536
+
537
  return cap
538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  def process_video(video_data, temp_dir):
540
  video_path = None
541
  output_dir = os.path.join(temp_dir, "output")
542
  os.makedirs(output_dir, exist_ok=True)
543
  os.environ['YOLO_CONFIG_DIR'] = temp_dir
544
+
545
  try:
546
  if not video_data:
547
  raise ValueError("Empty video data provided.")
548
+
549
  logger.info(f"Received video data size: {len(video_data)} bytes")
550
  if len(video_data) == 0:
551
  raise ValueError("Video data is empty.")
 
586
  worker_id_mapping = {}
587
  unique_violations = {}
588
  violation_frames = {}
589
+ worker_violation_count = {} # Track violation count per worker
 
590
  start_time = time.time()
591
  frame_skip = CONFIG["FRAME_SKIP"]
592
  processed_frames = 0
 
596
  while processed_frames < total_frames:
597
  batch_frames = []
598
  batch_indices = []
599
+
 
600
  for _ in range(CONFIG["BATCH_SIZE"]):
601
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
602
  if frame_idx >= total_frames:
603
  break
604
+
605
  ret, frame = cap.read()
606
  if not ret:
607
  logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
608
  break
609
+
 
 
 
610
  frame = preprocess_frame(frame)
611
+
612
  for _ in range(frame_skip - 1):
613
  if not cap.grab():
614
  break
615
+
616
  batch_frames.append(frame)
617
  batch_indices.append(frame_idx)
 
618
  processed_frames += 1
619
 
620
  if not batch_frames:
 
645
  yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
646
  last_yield_time = current_time
647
 
648
+ for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
649
  current_time = frame_idx / fps
650
+
651
  boxes = result.boxes
652
  track_inputs = []
653
+
654
  for box in boxes:
655
  cls = int(box.cls)
656
  conf = float(box.conf)
657
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
658
+
659
  if label is None:
660
  continue
661
+
662
+ if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
663
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
664
 
665
  bbox = box.xywh.cpu().numpy()[0]
666
  track_inputs.append({
 
671
 
672
  if not track_inputs:
673
  continue
674
+
675
  tracked_objects = tracker.update(
676
  np.array([t["bbox"] for t in track_inputs]),
677
  np.array([t["conf"] for t in track_inputs]),
 
684
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
685
  conf = obj['score']
686
  bbox = obj['bbox']
687
+
688
  if label is None:
689
  continue
690
+
691
  if tracker_id not in worker_id_mapping:
692
  worker_id_mapping[tracker_id] = worker_counter
693
  worker_counter += 1
694
+
695
  worker_id = worker_id_mapping[tracker_id]
696
+
697
+ violation_key = (worker_id, label)
698
+
699
+ if violation_key not in unique_violations:
700
+ unique_violations[violation_key] = current_time
701
+ violation_frames[violation_key] = frame_idx
702
+ # Update violation count for this worker
703
+ if worker_id not in worker_violation_count:
704
+ worker_violation_count[worker_id] = 0
705
+ worker_violation_count[worker_id] += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
 
707
  cap.release()
708
  processing_time = time.time() - start_time
709
  logger.info(f"Processing complete in {processing_time:.2f}s")
710
  logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
711
+ logger.info(f"Violations per worker: {worker_violation_count}")
712
 
713
  violations = []
714
  for (worker_id, label), detection_time in unique_violations.items():
 
791
 
792
  score = calculate_safety_score(violations)
793
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
794
+
795
  record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
796
 
797
+ # Generate summary of workers and their violations
798
+ worker_summary = {}
799
+ for v in violations:
800
+ worker_id = v["worker_id"]
801
+ if worker_id not in worker_summary:
802
+ worker_summary[worker_id] = {
803
+ "count": 0,
804
+ "violations": set()
805
+ }
806
+ worker_summary[worker_id]["count"] += 1
807
+ worker_summary[worker_id]["violations"].add(v["violation"])
808
+
809
+ # Create violation table with worker summary
810
+ violation_table = "## Worker Safety Violation Summary\n\n"
811
+ violation_table += "| Worker ID | Total Violations | Violation Types |\n"
812
+ violation_table += "|-----------|------------------|-----------------|\n"
813
 
814
+ for worker_id, info in worker_summary.items():
815
+ violation_types = ", ".join([CONFIG["DISPLAY_NAMES"].get(v, v) for v in info["violations"]])
816
+ violation_table += f"| {worker_id} | {info['count']} | {violation_types} |\n"
817
+
818
+ violation_table += "\n## Detailed Violation Log\n\n"
819
+ violation_table += "| Violation | Worker ID | Time (s) | Confidence |\n"
820
+ violation_table += "|-----------|-----------|----------|------------|\n"
821
+
822
  for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
823
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
824
  worker_id = v.get("worker_id", "Unknown")
 
864
  try:
865
  if not video_file:
866
  return "No file uploaded.", "", "No file uploaded.", "", ""
867
+
868
  temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
869
  logger.info(f"Created temporary directory for video processing: {temp_dir}")
870
 
871
  with open(video_file, "rb") as f:
872
  video_data = f.read()
873
  logger.info(f"Read Gradio video file: {video_file}, size: {len(video_data)} bytes")
874
+
875
  if len(video_data) == 0:
876
  return "Uploaded video file is empty.", "", "", "", ""
877
 
 
886
 
887
  for status, score, snapshots_text, record_id, details_url in process_video(video_data, temp_dir):
888
  yield status, score, snapshots_text, record_id, details_url
889
+
890
  except Exception as e:
891
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
892
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
 
897
  logger.info(f"Cleaned up local temporary video file: {local_video_path}")
898
  except Exception as e:
899
  logger.error(f"Failed to clean up local temporary video file {local_video_path}: {e}")
900
+
901
  if temp_dir and os.path.exists(temp_dir):
902
  shutil.rmtree(temp_dir, ignore_errors=True)
903
  logger.info(f"Cleaned up temporary directory: {temp_dir}")