PrashanthB461 committed on
Commit
a409c73
·
verified ·
1 Parent(s): 4c61ad0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -266
app.py CHANGED
@@ -7,7 +7,8 @@ import cv2
7
  import gradio as gr
8
  import torch
9
  import numpy as np
10
- from ultralytics import YOLO
 
11
  import time
12
  from simple_salesforce import Salesforce
13
  from reportlab.lib.pagesizes import letter
@@ -22,6 +23,7 @@ from functools import partial
22
  import tempfile
23
  import shutil
24
  import tenacity
 
25
 
26
  # ========================== # Configuration and Setup # ==========================
27
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -38,7 +40,7 @@ def check_ffmpeg():
38
 
39
  FFMPEG_AVAILABLE = check_ffmpeg()
40
 
41
- # ========================== # ByteTrack Implementation # ==========================
42
  class BYTETracker:
43
  def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30):
44
  self.track_thresh = track_thresh
@@ -49,37 +51,30 @@ class BYTETracker:
49
  self.tracks = {}
50
  self.worker_history = {}
51
  self.last_positions = {}
52
- self.recently_removed = {} # Store recently removed tracks for re-identification
53
- self.helmet_status = {} # Track helmet status for each worker
 
54
 
55
  def update(self, dets, scores, cls):
56
  tracks = []
57
  current_time = time.time()
58
 
59
  # Prune stale tracks
60
- stale_ids = []
61
- for track_id, track_info in self.tracks.items():
62
- if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
63
- stale_ids.append(track_id)
64
-
65
  for track_id in stale_ids:
66
- # Store recently removed tracks for re-identification (for 1 second)
67
  self.recently_removed[track_id] = {
68
  'bbox': self.tracks[track_id]['bbox'],
69
  'last_seen': current_time,
70
  'last_position': self.last_positions.get(track_id, [0, 0])
71
  }
72
  del self.tracks[track_id]
73
- if track_id in self.worker_history:
74
- del self.worker_history[track_id]
75
- if track_id in self.last_positions:
76
- del self.last_positions[track_id]
77
 
78
  # Clean up recently_removed tracks older than 1 second
79
- to_remove = []
80
- for track_id, info in self.recently_removed.items():
81
- if current_time - info['last_seen'] > 1.0:
82
- to_remove.append(track_id)
83
  for track_id in to_remove:
84
  del self.recently_removed[track_id]
85
 
@@ -92,7 +87,6 @@ class BYTETracker:
92
  best_iou = 0
93
  best_track_id = None
94
 
95
- # Try to match with active tracks
96
  for track_id, track_info in self.tracks.items():
97
  tx, ty, tw, th = track_info['bbox']
98
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
@@ -110,15 +104,12 @@ class BYTETracker:
110
  'last_seen': current_time
111
  })
112
 
113
- # Update helmet status if this is a helmet detection
114
- if cl == 0: # Helmet violation class
115
- # Higher confidence for helmet violations
116
- if score > 0.45: # Increased threshold for helmet violations
117
- self.helmet_status[best_track_id] = True
118
 
119
- if best_track_id not in self.worker_history:
120
- self.worker_history[best_track_id] = []
121
- self.worker_history[best_track_id].append([x, y])
122
  self.last_positions[best_track_id] = [x, y]
123
 
124
  tracks.append({
@@ -128,10 +119,9 @@ class BYTETracker:
128
  'cls': cl
129
  })
130
  else:
131
- # Try to re-identify with recently removed tracks
132
  reidentified = False
133
- for track_id, info in self.recently_removed.items():
134
- if self._is_same_worker([x, y], info['last_position'], threshold=150): # Increased threshold
135
  self.tracks[track_id] = {
136
  'bbox': [x, y, w, h],
137
  'score': score,
@@ -141,11 +131,10 @@ class BYTETracker:
141
  self.worker_history[track_id] = [[x, y]]
142
  self.last_positions[track_id] = [x, y]
143
 
144
- # Update helmet status if this is a helmet detection
145
- if cl == 0: # Helmet violation class
146
- # Higher confidence for helmet violations
147
- if score > 0.45: # Increased threshold for helmet violations
148
- self.helmet_status[track_id] = True
149
 
150
  tracks.append({
151
  'id': track_id,
@@ -158,10 +147,9 @@ class BYTETracker:
158
  break
159
 
160
  if not reidentified:
161
- # Check if it matches an existing worker by position
162
  same_worker = False
163
  for worker_id, last_pos in self.last_positions.items():
164
- if self._is_same_worker([x, y], last_pos, threshold=150): # Increased threshold
165
  self.tracks[worker_id] = {
166
  'bbox': [x, y, w, h],
167
  'score': score,
@@ -169,11 +157,10 @@ class BYTETracker:
169
  'last_seen': current_time
170
  }
171
 
172
- # Update helmet status if this is a helmet detection
173
- if cl == 0: # Helmet violation class
174
- # Higher confidence for helmet violations
175
- if score > 0.45: # Increased threshold for helmet violations
176
- self.helmet_status[worker_id] = True
177
 
178
  tracks.append({
179
  'id': worker_id,
@@ -194,11 +181,10 @@ class BYTETracker:
194
  self.worker_history[self.next_id] = [[x, y]]
195
  self.last_positions[self.next_id] = [x, y]
196
 
197
- # Update helmet status if this is a helmet detection
198
- if cl == 0: # Helmet violation class
199
- # Higher confidence for helmet violations
200
- if score > 0.45: # Increased threshold for helmet violations
201
- self.helmet_status[self.next_id] = True
202
 
203
  tracks.append({
204
  'id': self.next_id,
@@ -228,24 +214,23 @@ class BYTETracker:
228
  def _is_same_worker(self, pos1, pos2, threshold=150):
229
  x1, y1 = pos1
230
  x2, y2 = pos2
231
- distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
232
- return distance < threshold
233
 
234
- # Function to validate if a helmet violation is consistent across frames
235
  def validate_helmet_violation(self, worker_id, current_confidence):
236
- # If we have consistent high confidence or multiple detections, it's a valid violation
237
  return worker_id in self.helmet_status and self.helmet_status[worker_id]
238
 
 
 
 
239
  # ========================== # Optimized Configuration # ==========================
240
  CONFIG = {
241
- "MODEL_PATH": "yolov8_safety.pt",
242
- "FALLBACK_MODEL": "yolov8n.pt",
243
  "VIOLATION_LABELS": {
244
- 0: "no_helmet",
245
- 1: "no_harness",
246
- 2: "unsafe_posture",
247
- 3: "unsafe_zone",
248
- 4: "improper_tool_use"
249
  },
250
  "CLASS_COLORS": {
251
  "no_helmet": (0, 0, 255),
@@ -269,18 +254,18 @@ CONFIG = {
269
  },
270
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
271
  "CONFIDENCE_THRESHOLDS": {
272
- "no_helmet": 0.45, # Increased threshold for helmet violations
273
  "no_harness": 0.25,
274
  "unsafe_posture": 0.25,
275
  "unsafe_zone": 0.25,
276
  "improper_tool_use": 0.25
277
  },
278
- "MIN_VIOLATION_FRAMES": 2, # Increased to require multiple frames for confirmation
279
  "VIOLATION_COOLDOWN": 30.0,
280
  "WORKER_TRACKING_DURATION": 10.0,
281
  "MAX_PROCESSING_TIME": 60,
282
- "FRAME_SKIP": 2, # Increased frame skip for faster processing
283
- "BATCH_SIZE": 8, # Increased batch size for better GPU utilization
284
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
285
  "TRACK_BUFFER": 150,
286
  "TRACK_THRESH": 0.3,
@@ -288,7 +273,7 @@ CONFIG = {
288
  "SNAPSHOT_QUALITY": 95,
289
  "MAX_WORKER_DISTANCE": 150,
290
  "TARGET_RESOLUTION": (384, 384),
291
- "HELMET_VALIDATION_FRAMES": 3 # Number of frames to validate helmet violations
292
  }
293
 
294
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -296,73 +281,72 @@ logger.info(f"Using device: {device}")
296
 
297
  def load_model():
298
  try:
299
- if os.path.isfile(CONFIG["MODEL_PATH"]):
300
- model_path = CONFIG["MODEL_PATH"]
301
- logger.info(f"Model loaded: {model_path}")
302
- else:
303
- model_path = CONFIG["FALLBACK_MODEL"]
304
- logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
305
- if not os.path.isfile(model_path):
306
- logger.info(f"Downloading fallback model: {model_path}")
307
- torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
308
-
309
- model = YOLO(model_path).to(device)
310
  if device.type == "cuda":
311
- model.model.half()
312
- logger.info(f"Model classes: {model.names}")
313
- return model
 
314
  except Exception as e:
315
  logger.error(f"Failed to load model: {e}")
316
  raise
317
 
318
- model = load_model()
319
 
320
  # ========================== # Helper Functions # ==========================
321
  def preprocess_frame(frame):
322
  target_res = CONFIG["TARGET_RESOLUTION"]
323
- # Enhanced preprocessing for better helmet detection
324
  frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
325
- # Increase contrast to better differentiate helmets from other head coverings
326
- frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=20) # Increased contrast
327
-
328
- # Additional preprocessing to enhance head/helmet features
329
- # Apply slight sharpening to make edges more distinct
330
- kernel = np.array([[-1,-1,-1],
331
- [-1, 9,-1],
332
- [-1,-1,-1]])
333
  frame = cv2.filter2D(frame, -1, kernel)
334
-
335
  return frame
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  def draw_detections(frame, detections):
338
  result_frame = frame.copy()
339
-
340
  for det in detections:
341
  label = det.get("violation", "Unknown")
342
  confidence = det.get("confidence", 0.0)
343
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
344
  worker_id = det.get("worker_id", "Unknown")
345
-
346
  x1 = int(x - w/2)
347
  y1 = int(y - h/2)
348
  x2 = int(x + w/2)
349
  y2 = int(y + h/2)
350
-
351
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
352
-
353
- # Make no_helmet violations more prominent
354
  line_thickness = 4 if label == "no_helmet" else 3
355
-
356
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, line_thickness)
357
-
358
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
359
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
360
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
361
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
362
-
363
  conf_text = f"Conf: {confidence:.2f}"
364
  cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
365
-
366
  return result_frame
367
 
368
  def calculate_safety_score(violations):
@@ -373,23 +357,15 @@ def calculate_safety_score(violations):
373
  "unsafe_zone": 35,
374
  "improper_tool_use": 25
375
  }
376
-
377
  worker_violations = {}
378
  for v in violations:
379
  worker_id = v.get("worker_id", "Unknown")
380
  violation_type = v.get("violation", "Unknown")
381
-
382
  if worker_id not in worker_violations:
383
  worker_violations[worker_id] = set()
384
  worker_violations[worker_id].add(violation_type)
385
-
386
- total_penalty = 0
387
- for worker_violations_set in worker_violations.values():
388
- worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
389
- total_penalty += worker_penalty
390
-
391
- score = max(0, 100 - total_penalty)
392
- return score
393
 
394
  def generate_violation_pdf(violations, score, output_dir):
395
  try:
@@ -397,70 +373,55 @@ def generate_violation_pdf(violations, score, output_dir):
397
  pdf_path = os.path.join(output_dir, pdf_filename)
398
  pdf_file = BytesIO()
399
  c = canvas.Canvas(pdf_file, pagesize=letter)
400
-
401
  c.setFont("Helvetica-Bold", 16)
402
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
403
-
404
  c.setFont("Helvetica", 12)
405
  c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
406
  c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
407
-
408
  c.setFont("Helvetica-Bold", 14)
409
  c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
410
-
411
  y_position = 8.2 * inch
412
  c.setFont("Helvetica-Bold", 12)
413
  c.drawString(1 * inch, y_position, "Summary:")
414
  y_position -= 0.3 * inch
415
-
416
  worker_violations = {}
417
  for v in violations:
418
  worker_id = v.get("worker_id", "Unknown")
419
  if worker_id not in worker_violations:
420
  worker_violations[worker_id] = []
421
  worker_violations[worker_id].append(v)
422
-
423
  c.setFont("Helvetica", 10)
424
  summary_data = {
425
  "Total Workers with Violations": len(worker_violations),
426
  "Total Violations Found": len(violations),
427
  "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
428
  }
429
-
430
  for key, value in summary_data.items():
431
  c.drawString(1 * inch, y_position, f"{key}: {value}")
432
  y_position -= 0.25 * inch
433
-
434
  y_position -= 0.5 * inch
435
  c.setFont("Helvetica-Bold", 12)
436
  c.drawString(1 * inch, y_position, "Violations by Worker:")
437
  y_position -= 0.3 * inch
438
-
439
  c.setFont("Helvetica", 10)
440
  for worker_id, worker_vios in worker_violations.items():
441
  c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
442
  y_position -= 0.2 * inch
443
-
444
  for v in worker_vios:
445
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
446
  time_str = f"{v.get('timestamp', 0.0):.2f}s"
447
  conf_str = f"{v.get('confidence', 0.0):.2f}"
448
-
449
  violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
450
  c.drawString(1.2 * inch, y_position, violation_text)
451
  y_position -= 0.2 * inch
452
-
453
  if y_position < 1 * inch:
454
  c.showPage()
455
  c.setFont("Helvetica", 10)
456
  y_position = 10 * inch
457
-
458
  c.save()
459
  pdf_file.seek(0)
460
-
461
  with open(pdf_path, "wb") as f:
462
  f.write(pdf_file.getvalue())
463
-
464
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
465
  logger.info(f"PDF generated: {public_url}")
466
  return pdf_path, public_url, pdf_file
@@ -484,7 +445,6 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
484
  if not pdf_file:
485
  logger.error("No PDF file provided for upload")
486
  return ""
487
-
488
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
489
  content_version_data = {
490
  "Title": f"Safety_Violation_Report_{int(time.time())}",
@@ -494,11 +454,9 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
494
  }
495
  content_version = sf.ContentVersion.create(content_version_data)
496
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
497
-
498
  if not result['records']:
499
  logger.error("Failed to retrieve ContentVersion")
500
  return ""
501
-
502
  file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
503
  logger.info(f"PDF uploaded to Salesforce: {file_url}")
504
  return file_url
@@ -509,21 +467,16 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
509
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
510
  try:
511
  sf = connect_to_salesforce()
512
-
513
  violations_text = ""
514
  for v in violations:
515
  display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
516
  worker_id = v.get('worker_id', 'Unknown')
517
  timestamp = v.get('timestamp', 0.0)
518
  confidence = v.get('confidence', 0.0)
519
-
520
  violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
521
-
522
  if not violations_text:
523
  violations_text = "No violations detected."
524
-
525
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
526
-
527
  record_data = {
528
  "Compliance_Score__c": score,
529
  "Violations_Found__c": len(violations),
@@ -531,9 +484,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
531
  "Status__c": "Pending",
532
  "PDF_Report_URL__c": pdf_url
533
  }
534
-
535
  logger.info(f"Creating Salesforce record with data: {record_data}")
536
-
537
  try:
538
  record = sf.Safety_Video_Report__c.create(record_data)
539
  logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
@@ -541,9 +492,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
541
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
542
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
543
  logger.warning(f"Fell back to Account record: {record['id']}")
544
-
545
  record_id = record["id"]
546
-
547
  if pdf_file:
548
  uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
549
  if uploaded_url:
@@ -555,7 +504,6 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
555
  sf.Account.update(record_id, {"Description": uploaded_url})
556
  logger.info(f"Updated Account record {record_id} with PDF URL")
557
  pdf_url = uploaded_url
558
-
559
  return record_id, pdf_url
560
  except Exception as e:
561
  logger.error(f"Salesforce record creation failed: {e}")
@@ -570,102 +518,60 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
570
  def verify_and_open_video(video_path):
571
  if not os.path.exists(video_path):
572
  raise FileNotFoundError(f"Temporary video file not found: {video_path}")
573
-
574
  file_size = os.path.getsize(video_path)
575
  if file_size == 0:
576
  raise ValueError(f"Temporary video file is empty: {video_path}")
577
-
578
  with open(video_path, "rb") as f:
579
  f.read(1)
580
-
581
  cap = cv2.VideoCapture(video_path)
582
  if not cap.isOpened():
583
  raise ValueError("Could not open video file. Ensure the video format is supported (e.g., MP4) and FFmpeg is installed.")
584
-
585
  return cap
586
 
587
- # Helper for helmet validation
588
  def validate_helmet_detection(frame, bbox, confidence_threshold=0.45):
589
- """
590
- Additional validation for helmet detection to reduce false positives.
591
- This function performs additional checks on the region to confirm it's a true helmet violation.
592
- """
593
  x, y, w, h = bbox
594
  x1 = int(max(0, x - w/2))
595
  y1 = int(max(0, y - h/2))
596
  x2 = int(min(frame.shape[1], x + w/2))
597
  y2 = int(min(frame.shape[0], y + h/2))
598
-
599
- # Extract head region
600
  head_region = frame[y1:y2, x1:x2]
601
  if head_region.size == 0:
602
  return False
603
-
604
- # Check if this is truly a helmet violation by analyzing the region
605
- # 1. Check color distribution - helmets often have more uniform color
606
  hsv = cv2.cvtColor(head_region, cv2.COLOR_BGR2HSV)
607
-
608
- # Check for typical helmet colors (many construction helmets are yellow, white, orange, blue)
609
- # This helps differentiate from cloth head coverings
610
  yellow_lower = np.array([20, 100, 100])
611
  yellow_upper = np.array([30, 255, 255])
612
  yellow_mask = cv2.inRange(hsv, yellow_lower, yellow_upper)
613
-
614
  white_lower = np.array([0, 0, 200])
615
  white_upper = np.array([180, 30, 255])
616
  white_mask = cv2.inRange(hsv, white_lower, white_upper)
617
-
618
  orange_lower = np.array([5, 100, 100])
619
  orange_upper = np.array([15, 255, 255])
620
  orange_mask = cv2.inRange(hsv, orange_lower, orange_upper)
621
-
622
  blue_lower = np.array([100, 100, 100])
623
  blue_upper = np.array([130, 255, 255])
624
  blue_mask = cv2.inRange(hsv, blue_lower, blue_upper)
625
-
626
  helmet_mask = cv2.bitwise_or(yellow_mask, white_mask)
627
  helmet_mask = cv2.bitwise_or(helmet_mask, orange_mask)
628
  helmet_mask = cv2.bitwise_or(helmet_mask, blue_mask)
629
-
630
- # If there's a significant amount of helmet-colored pixels, this might be a helmet
631
  helmet_percentage = np.sum(helmet_mask > 0) / (head_region.shape[0] * head_region.shape[1])
632
-
633
- # If the region has a significant amount of helmet-like colors, it's probably a helmet
634
- # so we should NOT flag it as a violation (return False)
635
  if helmet_percentage > 0.25:
636
  return False
637
-
638
- # Check texture uniformity - helmets have more uniform texture compared to head coverings
639
  gray = cv2.cvtColor(head_region, cv2.COLOR_BGR2GRAY)
640
  texture_score = np.std(gray)
641
-
642
- # If texture is very uniform (low standard deviation), it might be a helmet or bare head
643
- # Very uniform texture (like a hard helmet) would have low texture_score
644
- if texture_score < 15: # Low texture suggests uniform surface like a helmet
645
  return False
646
-
647
- # Additional check for cloth-like textures
648
  edges = cv2.Canny(gray, 50, 150)
649
  edge_density = np.sum(edges > 0) / (head_region.shape[0] * head_region.shape[1])
650
-
651
- # If there are many edges (cloth wrinkles), this might be a kerchief
652
  if edge_density > 0.15:
653
- # This is likely a cloth head covering, not a helmet violation
654
- # But also not a proper helmet, so we should still detect as violation
655
  return True
656
-
657
- # If confidence is very high, trust the model
658
  if confidence_threshold >= 0.6:
659
  return True
660
-
661
- # Default to the original detection
662
  return True
663
 
664
  def process_video(video_data, temp_dir):
665
  video_path = None
666
  output_dir = os.path.join(temp_dir, "output")
667
  os.makedirs(output_dir, exist_ok=True)
668
- os.environ['YOLO_CONFIG_DIR'] = temp_dir
669
 
670
  try:
671
  if not video_data:
@@ -681,16 +587,7 @@ def process_video(video_data, temp_dir):
681
  video_path = temp_file.name
682
  logger.info(f"Video saved to temporary file: {video_path}")
683
 
684
- if not os.path.exists(video_path):
685
- raise FileNotFoundError(f"Temporary video file not found: {video_path}")
686
- file_size = os.path.getsize(video_path)
687
- if file_size == 0:
688
- raise ValueError(f"Temporary video file is empty: {video_path}")
689
- logger.info(f"Temporary video file size: {file_size} bytes")
690
-
691
  cap = verify_and_open_video(video_path)
692
- logger.info(f"Successfully opened video file: {video_path}")
693
-
694
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
695
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
696
  duration = total_frames / fps
@@ -711,8 +608,7 @@ def process_video(video_data, temp_dir):
711
  worker_id_mapping = {}
712
  unique_violations = {}
713
  violation_frames = {}
714
- # Track helmet detections across frames for each worker
715
- helmet_detections = {}
716
  start_time = time.time()
717
  frame_skip = CONFIG["FRAME_SKIP"]
718
  processed_frames = 0
@@ -722,28 +618,22 @@ def process_video(video_data, temp_dir):
722
  while processed_frames < total_frames:
723
  batch_frames = []
724
  batch_indices = []
725
- batch_originals = [] # Store original frames for helmet validation
726
 
727
  for _ in range(CONFIG["BATCH_SIZE"]):
728
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
729
  if frame_idx >= total_frames:
730
  break
731
-
732
  ret, frame = cap.read()
733
  if not ret:
734
  logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
735
  break
736
-
737
- # Store original frame for validation
738
  original_frame = frame.copy()
739
-
740
  frame = preprocess_frame(frame)
741
-
742
  for _ in range(frame_skip - 1):
743
  if not cap.grab():
744
  break
745
-
746
- batch_frames.append(frame)
747
  batch_indices.append(frame_idx)
748
  batch_originals.append(original_frame)
749
  processed_frames += 1
@@ -753,16 +643,16 @@ def process_video(video_data, temp_dir):
753
  break
754
 
755
  try:
756
- batch_frames_np = np.array(batch_frames)
757
- batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
758
- batch_frames_tensor = batch_frames_tensor.to(device)
759
  if device.type == "cuda":
760
- batch_frames_tensor = batch_frames_tensor.half()
761
-
762
- results = model(batch_frames_tensor, device=device, conf=0.1, verbose=False)
 
 
763
  except Exception as e:
764
  logger.error(f"Model inference failed: {e}")
765
- raise ValueError(f"Failed to process video frames with YOLO model: {str(e)}")
766
  finally:
767
  batch_frames = []
768
  if device.type == "cuda":
@@ -778,39 +668,37 @@ def process_video(video_data, temp_dir):
778
 
779
  for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
780
  current_time = frame_idx / fps
781
-
782
- boxes = result.boxes
783
  track_inputs = []
784
-
785
- for box in boxes:
786
- cls = int(box.cls)
787
- conf = float(box.conf)
788
- label = CONFIG["VIOLATION_LABELS"].get(cls, None)
789
-
790
- if label is None:
791
- continue
792
-
793
- # Enhanced confidence threshold handling, especially for helmet detection
794
- if label == "no_helmet":
795
- if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.45):
796
- continue
797
-
798
- # Additional validation for helmet detection
799
- bbox = box.xywh.cpu().numpy()[0]
800
- if not validate_helmet_detection(original_frame, bbox, conf):
801
  logger.info(f"Frame {frame_idx}: Helmet false positive filtered at {conf:.2f} confidence")
802
  continue
803
- else:
804
- # Use regular thresholds for other violations
805
- if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
806
- continue
807
-
808
- bbox = box.xywh.cpu().numpy()[0]
809
- track_inputs.append({
810
- "bbox": bbox,
811
- "conf": conf,
812
- "cls": cls
813
- })
 
 
 
 
814
 
815
  if not track_inputs:
816
  continue
@@ -824,11 +712,11 @@ def process_video(video_data, temp_dir):
824
 
825
  for obj in tracked_objects:
826
  tracker_id = obj['id']
827
- label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
828
  conf = obj['score']
829
  bbox = obj['bbox']
830
 
831
- if label is None:
832
  continue
833
 
834
  if tracker_id not in worker_id_mapping:
@@ -837,25 +725,16 @@ def process_video(video_data, temp_dir):
837
 
838
  worker_id = worker_id_mapping[tracker_id]
839
 
840
- # Special handling for helmet violations to ensure consistency
841
  if label == "no_helmet":
842
- # Track helmet violations for this worker
843
  if worker_id not in helmet_detections:
844
  helmet_detections[worker_id] = []
845
-
846
- # Store this detection with frame index and confidence
847
  helmet_detections[worker_id].append({
848
  "frame_idx": frame_idx,
849
  "confidence": conf,
850
  "bbox": bbox
851
  })
852
-
853
- # Only record a helmet violation if we have multiple consistent detections
854
  if len(helmet_detections[worker_id]) >= CONFIG["HELMET_VALIDATION_FRAMES"]:
855
- # Calculate average confidence
856
  avg_conf = sum(d["confidence"] for d in helmet_detections[worker_id]) / len(helmet_detections[worker_id])
857
-
858
- # If confidence is consistently high across multiple frames, record the violation
859
  if avg_conf >= CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
860
  violation_key = (worker_id, label)
861
  if violation_key not in unique_violations:
@@ -863,7 +742,6 @@ def process_video(video_data, temp_dir):
863
  violation_frames[violation_key] = frame_idx
864
  logger.info(f"Frame {frame_idx}: Valid helmet violation for worker {worker_id} with avg conf {avg_conf:.2f}")
865
  else:
866
- # Regular handling for other violations
867
  violation_key = (worker_id, label)
868
  if violation_key not in unique_violations:
869
  unique_violations[violation_key] = current_time
@@ -900,26 +778,29 @@ def process_video(video_data, temp_dir):
900
  continue
901
 
902
  frame = preprocess_frame(frame)
903
- frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
904
- frame_tensor = frame_tensor.unsqueeze(0).to(device)
905
  if device.type == "cuda":
906
- frame_tensor = frame_tensor.half()
907
-
908
- result = model(frame_tensor, device=device, conf=0.1, verbose=False)[0]
909
- boxes = result.boxes
910
-
911
- for box in boxes:
912
- cls = int(box.cls)
913
- conf = float(box.conf)
914
- label = CONFIG["VIOLATION_LABELS"].get(cls, None)
915
- if label == violation["violation"]:
 
 
 
 
916
  violation["confidence"] = round(conf, 2)
917
- bbox = box.xywh.cpu().numpy()[0]
918
  detection = {
919
  "worker_id": violation["worker_id"],
920
- "violation": label,
921
  "confidence": violation["confidence"],
922
- "bounding_box": bbox,
923
  "timestamp": violation["timestamp"]
924
  }
925
  snapshot_frame = frame.copy()
@@ -933,7 +814,7 @@ def process_video(video_data, temp_dir):
933
  (255, 255, 255),
934
  2
935
  )
936
- snapshot_filename = f"violation_{label}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
937
  snapshot_path = os.path.join(output_dir, snapshot_filename)
938
  cv2.imwrite(
939
  snapshot_path,
@@ -941,14 +822,14 @@ def process_video(video_data, temp_dir):
941
  [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
942
  )
943
  snapshots.append({
944
- "violation": label,
945
  "worker_id": violation["worker_id"],
946
  "timestamp": violation["timestamp"],
947
  "snapshot_path": snapshot_path,
948
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
949
  "confidence": violation["confidence"]
950
  })
951
- logger.info(f"Captured snapshot for {label} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
952
  break
953
 
954
  cap.release()
@@ -1007,7 +888,7 @@ def gradio_interface(video_file):
1007
  if not video_file:
1008
  return "No file uploaded.", "", "No file uploaded.", "", ""
1009
 
1010
- temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
1011
  logger.info(f"Created temporary directory for video processing: {temp_dir}")
1012
 
1013
  with open(video_file, "rb") as f:
@@ -1063,5 +944,5 @@ interface = gr.Interface(
1063
  )
1064
 
1065
  if __name__ == "__main__":
1066
- logger.info("Launching Enhanced Safety Analyzer App...")
1067
  interface.launch()
 
7
  import gradio as gr
8
  import torch
9
  import numpy as np
10
+ from transformers import DetrImageProcessor, DetrForObjectDetection
11
+ from PIL import Image
12
  import time
13
  from simple_salesforce import Salesforce
14
  from reportlab.lib.pagesizes import letter
 
23
  import tempfile
24
  import shutil
25
  import tenacity
26
+ from scipy.spatial import distance
27
 
28
  # ========================== # Configuration and Setup # ==========================
29
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
40
 
41
  FFMPEG_AVAILABLE = check_ffmpeg()
42
 
43
+ # ========================== # BYTETracker Implementation # ==========================
44
  class BYTETracker:
45
  def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30):
46
  self.track_thresh = track_thresh
 
51
  self.tracks = {}
52
  self.worker_history = {}
53
  self.last_positions = {}
54
+ self.recently_removed = {}
55
+ self.helmet_status = {}
56
+ self.harness_status = {}
57
 
58
  def update(self, dets, scores, cls):
59
  tracks = []
60
  current_time = time.time()
61
 
62
  # Prune stale tracks
63
+ stale_ids = [track_id for track_id, track_info in self.tracks.items()
64
+ if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate]
 
 
 
65
  for track_id in stale_ids:
 
66
  self.recently_removed[track_id] = {
67
  'bbox': self.tracks[track_id]['bbox'],
68
  'last_seen': current_time,
69
  'last_position': self.last_positions.get(track_id, [0, 0])
70
  }
71
  del self.tracks[track_id]
72
+ self.worker_history.pop(track_id, None)
73
+ self.last_positions.pop(track_id, None)
 
 
74
 
75
  # Clean up recently_removed tracks older than 1 second
76
+ to_remove = [track_id for track_id, info in self.recently_removed.items()
77
+ if current_time - info['last_seen'] > 1.0]
 
 
78
  for track_id in to_remove:
79
  del self.recently_removed[track_id]
80
 
 
87
  best_iou = 0
88
  best_track_id = None
89
 
 
90
  for track_id, track_info in self.tracks.items():
91
  tx, ty, tw, th = track_info['bbox']
92
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
 
104
  'last_seen': current_time
105
  })
106
 
107
+ if cl == "no_helmet" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
108
+ self.helmet_status[best_track_id] = True
109
+ elif cl == "no_harness" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_harness"]:
110
+ self.harness_status[best_track_id] = True
 
111
 
112
+ self.worker_history[best_track_id] = self.worker_history.get(best_track_id, []) + [[x, y]]
 
 
113
  self.last_positions[best_track_id] = [x, y]
114
 
115
  tracks.append({
 
119
  'cls': cl
120
  })
121
  else:
 
122
  reidentified = False
123
+ for track_id, info in list(self.recently_removed.items()):
124
+ if self._is_same_worker([x, y], info['last_position'], threshold=CONFIG["MAX_WORKER_DISTANCE"]):
125
  self.tracks[track_id] = {
126
  'bbox': [x, y, w, h],
127
  'score': score,
 
131
  self.worker_history[track_id] = [[x, y]]
132
  self.last_positions[track_id] = [x, y]
133
 
134
+ if cl == "no_helmet" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
135
+ self.helmet_status[track_id] = True
136
+ elif cl == "no_harness" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_harness"]:
137
+ self.harness_status[track_id] = True
 
138
 
139
  tracks.append({
140
  'id': track_id,
 
147
  break
148
 
149
  if not reidentified:
 
150
  same_worker = False
151
  for worker_id, last_pos in self.last_positions.items():
152
+ if self._is_same_worker([x, y], last_pos, threshold=CONFIG["MAX_WORKER_DISTANCE"]):
153
  self.tracks[worker_id] = {
154
  'bbox': [x, y, w, h],
155
  'score': score,
 
157
  'last_seen': current_time
158
  }
159
 
160
+ if cl == "no_helmet" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
161
+ self.helmet_status[worker_id] = True
162
+ elif cl == "no_harness" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_harness"]:
163
+ self.harness_status[worker_id] = True
 
164
 
165
  tracks.append({
166
  'id': worker_id,
 
181
  self.worker_history[self.next_id] = [[x, y]]
182
  self.last_positions[self.next_id] = [x, y]
183
 
184
+ if cl == "no_helmet" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
185
+ self.helmet_status[self.next_id] = True
186
+ elif cl == "no_harness" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_harness"]:
187
+ self.harness_status[self.next_id] = True
 
188
 
189
  tracks.append({
190
  'id': self.next_id,
 
214
  def _is_same_worker(self, pos1, pos2, threshold=150):
215
  x1, y1 = pos1
216
  x2, y2 = pos2
217
+ return np.sqrt((x1 - x2)**2 + (y1 - y2)**2) < threshold
 
218
 
 
219
  def validate_helmet_violation(self, worker_id, current_confidence):
 
220
  return worker_id in self.helmet_status and self.helmet_status[worker_id]
221
 
222
def validate_harness_violation(self, worker_id, current_confidence):
    """Look up the stored harness flag for ``worker_id``.

    Returns ``False`` when the worker has no entry in
    ``self.harness_status``; otherwise returns the stored value.
    ``current_confidence`` is accepted but not used.
    """
    return self.harness_status.get(worker_id, False)
224
+
225
  # ========================== # Optimized Configuration # ==========================
226
  CONFIG = {
227
+ "MODEL_NAME": "facebook/detr-resnet-50", # Fine-tune with your dataset, e.g., "your-username/detr-resnet-50-finetuned-safety"
 
228
  "VIOLATION_LABELS": {
229
+ "no_helmet": "No Helmet",
230
+ "no_harness": "No Harness",
231
+ "unsafe_posture": "Unsafe Posture",
232
+ "unsafe_zone": "Unsafe Zone",
233
+ "improper_tool_use": "Improper Tool Use"
234
  },
235
  "CLASS_COLORS": {
236
  "no_helmet": (0, 0, 255),
 
254
  },
255
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
256
  "CONFIDENCE_THRESHOLDS": {
257
+ "no_helmet": 0.45,
258
  "no_harness": 0.25,
259
  "unsafe_posture": 0.25,
260
  "unsafe_zone": 0.25,
261
  "improper_tool_use": 0.25
262
  },
263
+ "MIN_VIOLATION_FRAMES": 2,
264
  "VIOLATION_COOLDOWN": 30.0,
265
  "WORKER_TRACKING_DURATION": 10.0,
266
  "MAX_PROCESSING_TIME": 60,
267
+ "FRAME_SKIP": 2,
268
+ "BATCH_SIZE": 8,
269
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
270
  "TRACK_BUFFER": 150,
271
  "TRACK_THRESH": 0.3,
 
273
  "SNAPSHOT_QUALITY": 95,
274
  "MAX_WORKER_DISTANCE": 150,
275
  "TARGET_RESOLUTION": (384, 384),
276
+ "HELMET_VALIDATION_FRAMES": 3
277
  }
278
 
279
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
281
 
282
  def load_model():
283
  try:
284
+ processor = DetrImageProcessor.from_pretrained(CONFIG["MODEL_NAME"])
285
+ model = DetrForObjectDetection.from_pretrained(CONFIG["MODEL_NAME"]).to(device)
 
 
 
 
 
 
 
 
 
286
  if device.type == "cuda":
287
+ model = model.half()
288
+ logger.info(f"Loaded DETR model: {CONFIG['MODEL_NAME']}")
289
+ logger.info(f"Model classes: {model.config.id2label}")
290
+ return processor, model
291
  except Exception as e:
292
  logger.error(f"Failed to load model: {e}")
293
  raise
294
 
295
+ processor, model = load_model()
296
 
297
  # ========================== # Helper Functions # ==========================
298
  def preprocess_frame(frame):
299
  target_res = CONFIG["TARGET_RESOLUTION"]
 
300
  frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
301
+ frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=20)
302
+ kernel = np.array([[-1,-1,-1], [-1, 9,-1], [-1,-1,-1]])
 
 
 
 
 
 
303
  frame = cv2.filter2D(frame, -1, kernel)
 
304
  return frame
305
 
306
def is_unsafe_posture(box, frame_shape):
    """Placeholder unsafe-posture heuristic based on the box aspect ratio.

    Replace with real pose estimation (e.g., MediaPipe).

    Args:
        box: Person box as ``(center_x, center_y, width, height)``. This is
            the format process_video stores for person boxes and the one
            draw_detections and validate_helmet_detection also use. The box
            was previously unpacked as corner coordinates, which produced a
            meaningless ratio for these inputs.
        frame_shape: Unused; kept so existing call sites keep working.

    Returns:
        True when ``height / width`` exceeds 2.0.
    """
    _, _, width, height = box
    # max(..., 1) guards against zero or sub-pixel widths.
    aspect_ratio = height / max(width, 1)
    # Placeholder rule: flag tall, narrow boxes.
    return bool(aspect_ratio > 2.0)
313
+
314
def is_improper_tool_use(person_box, tool_box, max_distance=100):
    """Placeholder improper-tool-use heuristic. Fine-tune DETR for specific tools.

    Args:
        person_box: Person box as ``(center_x, center_y, width, height)``.
        tool_box: Tool box in the same center-based format.
        max_distance: Distance in pixels beyond which the tool is considered
            too far from the person. Defaults to the previous hard-coded 100.

    Returns:
        True when the two box centers are more than ``max_distance`` apart.

    Both boxes arrive in center format from process_video, so the first two
    entries already are the centers. Averaging entries 0/2 and 1/3 (corner
    style) mixed positions with sizes and gave wrong distances.
    """
    person_center = (person_box[0], person_box[1])
    tool_center = (tool_box[0], tool_box[1])
    dist = distance.euclidean(person_center, tool_center)
    return bool(dist > max_distance)  # Tool too far from person
320
+
321
def is_unsafe_zone(person_box, frame_shape, zone=(0, 0, 0.5, 0.5)):
    """Check whether a person's box center lies inside a restricted area.

    Args:
        person_box: Person box as ``(center_x, center_y, width, height)``.
            process_video passes boxes in this center-based format, so the
            first two entries are used directly as the center. Adding half
            the size again shifted the point toward the bottom-right.
        frame_shape: ``(height, width)`` of the frame the box refers to.
        zone: Restricted rectangle as fractions of the frame,
            ``(left, top, right, bottom)``. Defaults to the top-left quadrant.

    Returns:
        True when the center is strictly inside the zone.

    NOTE(review): process_video appears to pass the original frame's shape
    while boxes are predicted on the resized frame; confirm both use the
    same coordinate space.
    """
    center_x, center_y = person_box[0], person_box[1]
    frame_h, frame_w = frame_shape
    left, top, right, bottom = zone
    inside_x = left * frame_w < center_x < right * frame_w
    inside_y = top * frame_h < center_y < bottom * frame_h
    return bool(inside_x and inside_y)
329
+
330
  def draw_detections(frame, detections):
331
  result_frame = frame.copy()
 
332
  for det in detections:
333
  label = det.get("violation", "Unknown")
334
  confidence = det.get("confidence", 0.0)
335
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
336
  worker_id = det.get("worker_id", "Unknown")
 
337
  x1 = int(x - w/2)
338
  y1 = int(y - h/2)
339
  x2 = int(x + w/2)
340
  y2 = int(y + h/2)
 
341
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
 
 
342
  line_thickness = 4 if label == "no_helmet" else 3
 
343
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, line_thickness)
 
344
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
345
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
346
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
347
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
 
348
  conf_text = f"Conf: {confidence:.2f}"
349
  cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
 
350
  return result_frame
351
 
352
  def calculate_safety_score(violations):
 
357
  "unsafe_zone": 35,
358
  "improper_tool_use": 25
359
  }
 
360
  worker_violations = {}
361
  for v in violations:
362
  worker_id = v.get("worker_id", "Unknown")
363
  violation_type = v.get("violation", "Unknown")
 
364
  if worker_id not in worker_violations:
365
  worker_violations[worker_id] = set()
366
  worker_violations[worker_id].add(violation_type)
367
+ total_penalty = sum(sum(penalties.get(v, 0) for v in worker_violations[wid]) for wid in worker_violations)
368
+ return max(0, 100 - total_penalty)
 
 
 
 
 
 
369
 
370
  def generate_violation_pdf(violations, score, output_dir):
371
  try:
 
373
  pdf_path = os.path.join(output_dir, pdf_filename)
374
  pdf_file = BytesIO()
375
  c = canvas.Canvas(pdf_file, pagesize=letter)
 
376
  c.setFont("Helvetica-Bold", 16)
377
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
 
378
  c.setFont("Helvetica", 12)
379
  c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
380
  c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
 
381
  c.setFont("Helvetica-Bold", 14)
382
  c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
 
383
  y_position = 8.2 * inch
384
  c.setFont("Helvetica-Bold", 12)
385
  c.drawString(1 * inch, y_position, "Summary:")
386
  y_position -= 0.3 * inch
 
387
  worker_violations = {}
388
  for v in violations:
389
  worker_id = v.get("worker_id", "Unknown")
390
  if worker_id not in worker_violations:
391
  worker_violations[worker_id] = []
392
  worker_violations[worker_id].append(v)
 
393
  c.setFont("Helvetica", 10)
394
  summary_data = {
395
  "Total Workers with Violations": len(worker_violations),
396
  "Total Violations Found": len(violations),
397
  "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
398
  }
 
399
  for key, value in summary_data.items():
400
  c.drawString(1 * inch, y_position, f"{key}: {value}")
401
  y_position -= 0.25 * inch
 
402
  y_position -= 0.5 * inch
403
  c.setFont("Helvetica-Bold", 12)
404
  c.drawString(1 * inch, y_position, "Violations by Worker:")
405
  y_position -= 0.3 * inch
 
406
  c.setFont("Helvetica", 10)
407
  for worker_id, worker_vios in worker_violations.items():
408
  c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
409
  y_position -= 0.2 * inch
 
410
  for v in worker_vios:
411
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
412
  time_str = f"{v.get('timestamp', 0.0):.2f}s"
413
  conf_str = f"{v.get('confidence', 0.0):.2f}"
 
414
  violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
415
  c.drawString(1.2 * inch, y_position, violation_text)
416
  y_position -= 0.2 * inch
 
417
  if y_position < 1 * inch:
418
  c.showPage()
419
  c.setFont("Helvetica", 10)
420
  y_position = 10 * inch
 
421
  c.save()
422
  pdf_file.seek(0)
 
423
  with open(pdf_path, "wb") as f:
424
  f.write(pdf_file.getvalue())
 
425
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
426
  logger.info(f"PDF generated: {public_url}")
427
  return pdf_path, public_url, pdf_file
 
445
  if not pdf_file:
446
  logger.error("No PDF file provided for upload")
447
  return ""
 
448
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
449
  content_version_data = {
450
  "Title": f"Safety_Violation_Report_{int(time.time())}",
 
454
  }
455
  content_version = sf.ContentVersion.create(content_version_data)
456
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
 
457
  if not result['records']:
458
  logger.error("Failed to retrieve ContentVersion")
459
  return ""
 
460
  file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
461
  logger.info(f"PDF uploaded to Salesforce: {file_url}")
462
  return file_url
 
467
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
468
  try:
469
  sf = connect_to_salesforce()
 
470
  violations_text = ""
471
  for v in violations:
472
  display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
473
  worker_id = v.get('worker_id', 'Unknown')
474
  timestamp = v.get('timestamp', 0.0)
475
  confidence = v.get('confidence', 0.0)
 
476
  violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
 
477
  if not violations_text:
478
  violations_text = "No violations detected."
 
479
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
 
480
  record_data = {
481
  "Compliance_Score__c": score,
482
  "Violations_Found__c": len(violations),
 
484
  "Status__c": "Pending",
485
  "PDF_Report_URL__c": pdf_url
486
  }
 
487
  logger.info(f"Creating Salesforce record with data: {record_data}")
 
488
  try:
489
  record = sf.Safety_Video_Report__c.create(record_data)
490
  logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
 
492
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
493
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
494
  logger.warning(f"Fell back to Account record: {record['id']}")
 
495
  record_id = record["id"]
 
496
  if pdf_file:
497
  uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
498
  if uploaded_url:
 
504
  sf.Account.update(record_id, {"Description": uploaded_url})
505
  logger.info(f"Updated Account record {record_id} with PDF URL")
506
  pdf_url = uploaded_url
 
507
  return record_id, pdf_url
508
  except Exception as e:
509
  logger.error(f"Salesforce record creation failed: {e}")
 
518
  def verify_and_open_video(video_path):
519
  if not os.path.exists(video_path):
520
  raise FileNotFoundError(f"Temporary video file not found: {video_path}")
 
521
  file_size = os.path.getsize(video_path)
522
  if file_size == 0:
523
  raise ValueError(f"Temporary video file is empty: {video_path}")
 
524
  with open(video_path, "rb") as f:
525
  f.read(1)
 
526
  cap = cv2.VideoCapture(video_path)
527
  if not cap.isOpened():
528
  raise ValueError("Could not open video file. Ensure the video format is supported (e.g., MP4) and FFmpeg is installed.")
 
529
  return cap
530
 
 
531
  def validate_helmet_detection(frame, bbox, confidence_threshold=0.45):
 
 
 
 
532
  x, y, w, h = bbox
533
  x1 = int(max(0, x - w/2))
534
  y1 = int(max(0, y - h/2))
535
  x2 = int(min(frame.shape[1], x + w/2))
536
  y2 = int(min(frame.shape[0], y + h/2))
 
 
537
  head_region = frame[y1:y2, x1:x2]
538
  if head_region.size == 0:
539
  return False
 
 
 
540
  hsv = cv2.cvtColor(head_region, cv2.COLOR_BGR2HSV)
 
 
 
541
  yellow_lower = np.array([20, 100, 100])
542
  yellow_upper = np.array([30, 255, 255])
543
  yellow_mask = cv2.inRange(hsv, yellow_lower, yellow_upper)
 
544
  white_lower = np.array([0, 0, 200])
545
  white_upper = np.array([180, 30, 255])
546
  white_mask = cv2.inRange(hsv, white_lower, white_upper)
 
547
  orange_lower = np.array([5, 100, 100])
548
  orange_upper = np.array([15, 255, 255])
549
  orange_mask = cv2.inRange(hsv, orange_lower, orange_upper)
 
550
  blue_lower = np.array([100, 100, 100])
551
  blue_upper = np.array([130, 255, 255])
552
  blue_mask = cv2.inRange(hsv, blue_lower, blue_upper)
 
553
  helmet_mask = cv2.bitwise_or(yellow_mask, white_mask)
554
  helmet_mask = cv2.bitwise_or(helmet_mask, orange_mask)
555
  helmet_mask = cv2.bitwise_or(helmet_mask, blue_mask)
 
 
556
  helmet_percentage = np.sum(helmet_mask > 0) / (head_region.shape[0] * head_region.shape[1])
 
 
 
557
  if helmet_percentage > 0.25:
558
  return False
 
 
559
  gray = cv2.cvtColor(head_region, cv2.COLOR_BGR2GRAY)
560
  texture_score = np.std(gray)
561
+ if texture_score < 15:
 
 
 
562
  return False
 
 
563
  edges = cv2.Canny(gray, 50, 150)
564
  edge_density = np.sum(edges > 0) / (head_region.shape[0] * head_region.shape[1])
 
 
565
  if edge_density > 0.15:
 
 
566
  return True
 
 
567
  if confidence_threshold >= 0.6:
568
  return True
 
 
569
  return True
570
 
571
  def process_video(video_data, temp_dir):
572
  video_path = None
573
  output_dir = os.path.join(temp_dir, "output")
574
  os.makedirs(output_dir, exist_ok=True)
 
575
 
576
  try:
577
  if not video_data:
 
587
  video_path = temp_file.name
588
  logger.info(f"Video saved to temporary file: {video_path}")
589
 
 
 
 
 
 
 
 
590
  cap = verify_and_open_video(video_path)
 
 
591
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
592
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
593
  duration = total_frames / fps
 
608
  worker_id_mapping = {}
609
  unique_violations = {}
610
  violation_frames = {}
611
+ helmet_detections = {}
 
612
  start_time = time.time()
613
  frame_skip = CONFIG["FRAME_SKIP"]
614
  processed_frames = 0
 
618
  while processed_frames < total_frames:
619
  batch_frames = []
620
  batch_indices = []
621
+ batch_originals = []
622
 
623
  for _ in range(CONFIG["BATCH_SIZE"]):
624
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
625
  if frame_idx >= total_frames:
626
  break
 
627
  ret, frame = cap.read()
628
  if not ret:
629
  logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
630
  break
 
 
631
  original_frame = frame.copy()
 
632
  frame = preprocess_frame(frame)
 
633
  for _ in range(frame_skip - 1):
634
  if not cap.grab():
635
  break
636
+ batch_frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
 
637
  batch_indices.append(frame_idx)
638
  batch_originals.append(original_frame)
639
  processed_frames += 1
 
643
  break
644
 
645
  try:
646
+ inputs = processor(images=batch_frames, return_tensors="pt").to(device)
 
 
647
  if device.type == "cuda":
648
+ inputs = {k: v.half() for k, v in inputs.items()}
649
+ with torch.no_grad():
650
+ outputs = model(**inputs)
651
+ target_sizes = torch.tensor([frame.size[::-1] for frame in batch_frames]).to(device)
652
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.1)
653
  except Exception as e:
654
  logger.error(f"Model inference failed: {e}")
655
+ raise ValueError(f"Failed to process video frames with DETR model: {str(e)}")
656
  finally:
657
  batch_frames = []
658
  if device.type == "cuda":
 
668
 
669
  for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
670
  current_time = frame_idx / fps
 
 
671
  track_inputs = []
672
+ person_boxes = []
673
+ tool_boxes = []
674
+
675
+ for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
676
+ label_name = model.config.id2label[label.item()]
677
+ conf = float(score)
678
+ bbox = box.cpu().numpy()
679
+ x, y, x2, y2 = bbox
680
+ w, h = x2 - x, y2 - y
681
+ bbox_xywh = [x + w/2, y + h/2, w, h]
682
+
683
+ if label_name in ["no_helmet", "no_harness"] and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label_name, 0.25):
684
+ if label_name == "no_helmet" and not validate_helmet_detection(original_frame, bbox_xywh, conf):
 
 
 
 
685
  logger.info(f"Frame {frame_idx}: Helmet false positive filtered at {conf:.2f} confidence")
686
  continue
687
+ track_inputs.append({"bbox": bbox_xywh, "conf": conf, "cls": label_name})
688
+ elif label_name == "person":
689
+ person_boxes.append(bbox_xywh)
690
+ elif label_name in ["hammer", "wrench"]: # Example tools; update with your dataset
691
+ tool_boxes.append(bbox_xywh)
692
+
693
+ # Handle Unsafe Posture, Unsafe Zone, Improper Tool Use
694
+ for pbox in person_boxes:
695
+ if is_unsafe_posture(pbox, original_frame.shape[:2]):
696
+ track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_posture"})
697
+ if is_unsafe_zone(pbox, original_frame.shape[:2]):
698
+ track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_zone"})
699
+ for tbox in tool_boxes:
700
+ if is_improper_tool_use(pbox, tbox):
701
+ track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "improper_tool_use"})
702
 
703
  if not track_inputs:
704
  continue
 
712
 
713
  for obj in tracked_objects:
714
  tracker_id = obj['id']
715
+ label = obj['cls']
716
  conf = obj['score']
717
  bbox = obj['bbox']
718
 
719
+ if label not in CONFIG["VIOLATION_LABELS"]:
720
  continue
721
 
722
  if tracker_id not in worker_id_mapping:
 
725
 
726
  worker_id = worker_id_mapping[tracker_id]
727
 
 
728
  if label == "no_helmet":
 
729
  if worker_id not in helmet_detections:
730
  helmet_detections[worker_id] = []
 
 
731
  helmet_detections[worker_id].append({
732
  "frame_idx": frame_idx,
733
  "confidence": conf,
734
  "bbox": bbox
735
  })
 
 
736
  if len(helmet_detections[worker_id]) >= CONFIG["HELMET_VALIDATION_FRAMES"]:
 
737
  avg_conf = sum(d["confidence"] for d in helmet_detections[worker_id]) / len(helmet_detections[worker_id])
 
 
738
  if avg_conf >= CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
739
  violation_key = (worker_id, label)
740
  if violation_key not in unique_violations:
 
742
  violation_frames[violation_key] = frame_idx
743
  logger.info(f"Frame {frame_idx}: Valid helmet violation for worker {worker_id} with avg conf {avg_conf:.2f}")
744
  else:
 
745
  violation_key = (worker_id, label)
746
  if violation_key not in unique_violations:
747
  unique_violations[violation_key] = current_time
 
778
  continue
779
 
780
  frame = preprocess_frame(frame)
781
+ frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
782
+ inputs = processor(images=frame_pil, return_tensors="pt").to(device)
783
  if device.type == "cuda":
784
+ inputs = {k: v.half() for k, v in inputs.items()}
785
+ with torch.no_grad():
786
+ outputs = model(**inputs)
787
+ target_sizes = torch.tensor([frame_pil.size[::-1]]).to(device)
788
+ result = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.1)[0]
789
+
790
+ for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
791
+ label_name = model.config.id2label[label.item()]
792
+ conf = float(score)
793
+ bbox = box.cpu().numpy()
794
+ x, y, x2, y2 = bbox
795
+ w, h = x2 - x, y2 - y
796
+ bbox_xywh = [x + w/2, y + h/2, w, h]
797
+ if label_name == violation["violation"]:
798
  violation["confidence"] = round(conf, 2)
 
799
  detection = {
800
  "worker_id": violation["worker_id"],
801
+ "violation": label_name,
802
  "confidence": violation["confidence"],
803
+ "bounding_box": bbox_xywh,
804
  "timestamp": violation["timestamp"]
805
  }
806
  snapshot_frame = frame.copy()
 
814
  (255, 255, 255),
815
  2
816
  )
817
+ snapshot_filename = f"violation_{label_name}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
818
  snapshot_path = os.path.join(output_dir, snapshot_filename)
819
  cv2.imwrite(
820
  snapshot_path,
 
822
  [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
823
  )
824
  snapshots.append({
825
+ "violation": label_name,
826
  "worker_id": violation["worker_id"],
827
  "timestamp": violation["timestamp"],
828
  "snapshot_path": snapshot_path,
829
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
830
  "confidence": violation["confidence"]
831
  })
832
+ logger.info(f"Captured snapshot for {label_name} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
833
  break
834
 
835
  cap.release()
 
888
  if not video_file:
889
  return "No file uploaded.", "", "No file uploaded.", "", ""
890
 
891
+ temp_dir = tempfile.mkdtemp(prefix="DETR_")
892
  logger.info(f"Created temporary directory for video processing: {temp_dir}")
893
 
894
  with open(video_file, "rb") as f:
 
944
  )
945
 
946
  if __name__ == "__main__":
947
+ logger.info("Launching Enhanced Safety Analyzer App with DETR...")
948
  interface.launch()