PrashanthB461 commited on
Commit
a3d6280
·
verified ·
1 Parent(s): 23719a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +221 -492
app.py CHANGED
@@ -19,7 +19,6 @@ from retrying import retry
19
  import uuid
20
  from multiprocessing import Pool, cpu_count
21
  from functools import partial
22
- import face_recognition
23
  from collections import defaultdict
24
 
25
  # ========================== # Configuration and Setup # ==========================
@@ -28,11 +27,9 @@ os.makedirs('/tmp/Ultralytics', exist_ok=True)
28
 
29
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
30
  logger = logging.getLogger(__name__)
31
-
32
- # Suppress warnings
33
  warnings.filterwarnings("ignore")
34
 
35
- # ========================== # Enhanced Tracker Implementation # ==========================
36
  class SafetyTracker:
37
  def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
38
  self.track_thresh = track_thresh
@@ -41,13 +38,12 @@ class SafetyTracker:
41
  self.frame_rate = frame_rate
42
  self.next_id = 1
43
 
44
- # Trackers for different purposes
45
- self.worker_tracks = {} # Active worker tracks
46
- self.violation_history = defaultdict(dict) # Track violations per worker
47
- self.face_encodings = {} # Store face encodings for helmet violations
48
- self.position_history = defaultdict(list) # Track positions for non-helmet violations
49
 
50
- # Cooldown periods (in seconds)
51
  self.VIOLATION_COOLDOWNS = {
52
  "no_helmet": 30.0,
53
  "no_harness": 20.0,
@@ -56,10 +52,9 @@ class SafetyTracker:
56
  "improper_tool_use": 15.0
57
  }
58
 
59
- def update(self, detections, frame):
60
- """Update tracks with new detections and check for violations"""
61
  current_time = time.time()
62
- active_violations = []
63
  new_violations = []
64
 
65
  for det in detections:
@@ -67,20 +62,15 @@ class SafetyTracker:
67
  label = det['violation']
68
  confidence = det['confidence']
69
 
70
- # For helmet violations, use face recognition
71
- if label == "no_helmet":
72
- worker_id = self._match_by_face(bbox, frame)
73
- else:
74
- # For other violations, use position tracking
75
- worker_id = self._match_by_position(bbox, label)
76
 
77
  if worker_id is None:
78
  worker_id = self.next_id
79
  self.next_id += 1
80
 
81
- # Check if this is a new violation for this worker
82
  if self._is_new_violation(worker_id, label, current_time):
83
- # Record the violation
84
  violation = {
85
  'worker_id': worker_id,
86
  'violation': label,
@@ -89,116 +79,64 @@ class SafetyTracker:
89
  'timestamp': current_time
90
  }
91
  new_violations.append(violation)
92
-
93
- # Update violation history
94
  self.violation_history[worker_id][label] = current_time
95
-
96
- # For helmet violations, store face encoding
97
- if label == "no_helmet":
98
- self._store_face_encoding(worker_id, bbox, frame)
99
 
100
- # Keep track of active workers
 
 
 
 
101
  self.worker_tracks[worker_id] = {
102
  'bbox': bbox,
103
  'last_seen': current_time,
104
  'label': label
105
  }
106
 
107
- # Clean up old tracks
108
  self._cleanup_tracks(current_time)
109
 
110
  return new_violations
111
 
112
- def _match_by_face(self, bbox, frame):
113
- """Match detection by face recognition (for helmet violations)"""
114
- x, y, w, h = bbox
115
- face_region = frame[max(0, int(y-h/2)):int(y+h/2), max(0, int(x-w/2)):int(x+w/2)]
116
-
117
- if face_region.size == 0:
118
- return None
119
-
120
- try:
121
- # Get face encodings from current detection
122
- face_locations = face_recognition.face_locations(face_region)
123
- if not face_locations:
124
- return None
125
-
126
- current_encoding = face_recognition.face_encodings(face_region, face_locations)[0]
127
-
128
- # Compare with known faces
129
- for worker_id, encodings in self.face_encodings.items():
130
- matches = face_recognition.compare_faces(encodings, current_encoding, tolerance=0.6)
131
- if any(matches):
132
- return worker_id
133
-
134
- except Exception as e:
135
- logger.warning(f"Face recognition error: {e}")
136
-
137
- return None
138
-
139
  def _match_by_position(self, bbox, label):
140
- """Match detection by position (for non-helmet violations)"""
141
  x, y, w, h = bbox
142
  current_pos = (x, y)
143
 
144
  for worker_id, positions in self.position_history.items():
 
145
  if label not in self.violation_history[worker_id]:
146
  continue
147
 
148
- # Check if current position is near any previous positions for this worker
149
- for pos in positions:
150
  distance = np.sqrt((current_pos[0]-pos[0])**2 + (current_pos[1]-pos[1])**2)
151
  if distance < 100: # Within 100 pixels
152
  return worker_id
153
-
154
  return None
155
 
156
  def _is_new_violation(self, worker_id, label, current_time):
157
- """Check if this is a new violation for this worker"""
158
  if label not in self.violation_history[worker_id]:
159
  return True
160
 
161
- last_detection = self.violation_history[worker_id][label]
162
  cooldown = self.VIOLATION_COOLDOWNS.get(label, 10.0)
163
-
164
- return (current_time - last_detection) > cooldown
165
-
166
- def _store_face_encoding(self, worker_id, bbox, frame):
167
- """Store face encoding for a worker"""
168
- x, y, w, h = bbox
169
- face_region = frame[max(0, int(y-h/2)):int(y+h/2), max(0, int(x-w/2)):int(x+w/2)]
170
-
171
- if face_region.size == 0:
172
- return
173
-
174
- try:
175
- face_locations = face_recognition.face_locations(face_region)
176
- if face_locations:
177
- encoding = face_recognition.face_encodings(face_region, face_locations)[0]
178
- if worker_id not in self.face_encodings:
179
- self.face_encodings[worker_id] = []
180
- self.face_encodings[worker_id].append(encoding)
181
- except Exception as e:
182
- logger.warning(f"Error storing face encoding: {e}")
183
 
184
  def _cleanup_tracks(self, current_time):
185
- """Clean up old tracks and face encodings"""
186
- # Remove inactive workers
187
  inactive_ids = [
188
- worker_id for worker_id, track in self.worker_tracks.items()
189
  if (current_time - track['last_seen']) > (self.track_buffer / self.frame_rate)
190
  ]
191
-
192
- for worker_id in inactive_ids:
193
- self.worker_tracks.pop(worker_id, None)
194
- self.position_history.pop(worker_id, None)
195
-
196
- # Keep face encodings for a longer period (for helmet violations)
197
- if (current_time - max(self.violation_history[worker_id].values(), default=0)) > 300: # 5 minutes
198
- self.face_encodings.pop(worker_id, None)
199
- self.violation_history.pop(worker_id, None)
200
-
201
- # ========================== # Optimized Configuration # ==========================
202
  CONFIG = {
203
  "MODEL_PATH": "yolov8_safety.pt",
204
  "FALLBACK_MODEL": "yolov8n.pt",
@@ -211,11 +149,11 @@ CONFIG = {
211
  4: "improper_tool_use"
212
  },
213
  "CLASS_COLORS": {
214
- "no_helmet": (0, 0, 255), # Red
215
- "no_harness": (0, 165, 255), # Orange
216
- "unsafe_posture": (0, 255, 0), # Green
217
- "unsafe_zone": (255, 0, 0), # Blue
218
- "improper_tool_use": (255, 255, 0) # Cyan
219
  },
220
  "DISPLAY_NAMES": {
221
  "no_helmet": "No Helmet Violation",
@@ -238,502 +176,293 @@ CONFIG = {
238
  "unsafe_zone": 0.3,
239
  "improper_tool_use": 0.3
240
  },
241
- "MIN_VIOLATION_FRAMES": 1,
242
  "FRAME_SKIP": 2,
243
- "BATCH_SIZE": 16,
244
- "PARALLEL_WORKERS": max(1, cpu_count() - 1),
245
- "SNAPSHOT_QUALITY": 95,
246
- "FACE_RECOGNITION_INTERVAL": 5 # Process face recognition every 5 frames
247
  }
248
 
249
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
250
  logger.info(f"Using device: {device}")
251
 
 
252
  def load_model():
253
  try:
254
- if os.path.isfile(CONFIG["MODEL_PATH"]):
255
- model_path = CONFIG["MODEL_PATH"]
256
- logger.info(f"Model loaded: {model_path}")
257
- else:
258
- model_path = CONFIG["FALLBACK_MODEL"]
259
- logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
260
- if not os.path.isfile(model_path):
261
- logger.info(f"Downloading fallback model: {model_path}")
262
- torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
263
-
264
- model = YOLO(model_path).to(device)
265
- logger.info(f"Model classes: {model.names}")
266
- return model
267
  except Exception as e:
268
- logger.error(f"Failed to load model: {e}")
269
  raise
270
 
271
  model = load_model()
272
 
273
- # ========================== # Helper Functions # ==========================
274
  def preprocess_frame(frame):
275
- """Apply basic preprocessing to enhance detection"""
276
- frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
277
- return frame
278
 
279
  def draw_detections(frame, detections):
280
- """Draw bounding boxes and labels on detection frame with improved visibility"""
281
- result_frame = frame.copy()
282
-
283
  for det in detections:
284
- label = det.get("violation", "Unknown")
285
- confidence = det.get("confidence", 0.0)
286
- x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
287
- worker_id = det.get("worker_id", "Unknown")
288
-
289
- x1 = int(x - w/2)
290
- y1 = int(y - h/2)
291
- x2 = int(x + w/2)
292
- y2 = int(y + h/2)
293
-
294
- color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
295
-
296
- # Draw thicker rectangle with border
297
- cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
298
 
299
- # Add black background behind text
300
- display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
301
- text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
302
- cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
303
- cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
304
-
305
- # Add confidence score
306
- conf_text = f"Conf: {confidence:.2f}"
307
- cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
308
-
309
- return result_frame
310
 
311
  def calculate_safety_score(violations):
312
- """Calculate safety score based on detected violations"""
313
- penalties = {
314
- "no_helmet": 25,
315
- "no_harness": 30,
316
- "unsafe_posture": 20,
317
- "unsafe_zone": 35,
318
  "improper_tool_use": 25
319
  }
320
-
321
- # Count unique violation types
322
- unique_violations = set()
323
- for v in violations:
324
- violation_type = v.get("violation", "Unknown")
325
- unique_violations.add(violation_type)
326
-
327
- total_penalty = sum(penalties.get(v, 0) for v in unique_violations)
328
- score = max(0, 100 - total_penalty)
329
- return score
330
 
 
331
  def generate_violation_pdf(violations, score):
332
- """Generate a PDF report for the detected violations"""
333
  try:
334
- pdf_filename = f"violations_{int(time.time())}.pdf"
335
- pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
336
- pdf_file = BytesIO()
337
- c = canvas.Canvas(pdf_file, pagesize=letter)
338
 
339
- # Title
340
  c.setFont("Helvetica-Bold", 16)
341
- c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
342
-
343
- # Basic Information
344
  c.setFont("Helvetica", 12)
345
- c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
346
- c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
347
 
348
- # Safety Score
 
349
  c.setFont("Helvetica-Bold", 14)
350
- c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
351
-
352
- # Violation Summary
353
- y_position = 8.2 * inch
354
- c.setFont("Helvetica-Bold", 12)
355
- c.drawString(1 * inch, y_position, "Summary:")
356
- y_position -= 0.3 * inch
357
-
358
- c.setFont("Helvetica", 10)
359
- summary_data = {
360
- "Total Violations Found": len(violations),
361
- "Unique Violation Types": len(set(v['violation'] for v in violations)),
362
- "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
363
- }
364
-
365
- for key, value in summary_data.items():
366
- c.drawString(1 * inch, y_position, f"{key}: {value}")
367
- y_position -= 0.25 * inch
368
-
369
- # Detailed Violations
370
- y_position -= 0.5 * inch
371
- c.setFont("Helvetica-Bold", 12)
372
- c.drawString(1 * inch, y_position, "Violation Details:")
373
- y_position -= 0.3 * inch
374
 
375
  c.setFont("Helvetica", 10)
376
  for v in violations:
377
- display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
378
- worker_id = v.get("worker_id", "Unknown")
379
- time_str = f"{v.get('timestamp', 0.0):.2f}s"
380
- conf_str = f"{v.get('confidence', 0.0):.2f}"
381
-
382
- violation_text = f"- {display_name} by Worker {worker_id} at {time_str} (Confidence: {conf_str})"
383
- c.drawString(1.2 * inch, y_position, violation_text)
384
- y_position -= 0.2 * inch
385
-
386
- if y_position < 1 * inch:
387
  c.showPage()
388
- c.setFont("Helvetica", 10)
389
- y_position = 10 * inch
390
-
391
  c.save()
392
- pdf_file.seek(0)
393
-
394
- # Save PDF file
 
395
  with open(pdf_path, "wb") as f:
396
- f.write(pdf_file.getvalue())
397
 
398
- public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
399
- logger.info(f"PDF generated: {public_url}")
400
- return pdf_path, public_url, pdf_file
401
  except Exception as e:
402
- logger.error(f"Error generating PDF: {e}")
403
- return "", "", None
404
 
405
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
406
  def connect_to_salesforce():
407
- """Connect to Salesforce with retry logic"""
408
  try:
409
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
410
- logger.info("Connected to Salesforce")
411
- sf.describe()
412
  return sf
413
  except Exception as e:
414
  logger.error(f"Salesforce connection failed: {e}")
415
  raise
416
 
417
- def upload_pdf_to_salesforce(sf, pdf_file, report_id):
418
- """Upload PDF report to Salesforce"""
419
- try:
420
- if not pdf_file:
421
- logger.error("No PDF file provided for upload")
422
- return ""
423
-
424
- encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
425
- content_version_data = {
426
- "Title": f"Safety_Violation_Report_{int(time.time())}",
427
- "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
428
- "VersionData": encoded_pdf,
429
- "FirstPublishLocationId": report_id
430
- }
431
- content_version = sf.ContentVersion.create(content_version_data)
432
- result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
433
-
434
- if not result['records']:
435
- logger.error("Failed to retrieve ContentVersion")
436
- return ""
437
-
438
- file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
439
- logger.info(f"PDF uploaded to Salesforce: {file_url}")
440
- return file_url
441
- except Exception as e:
442
- logger.error(f"Error uploading PDF to Salesforce: {e}")
443
- return ""
444
-
445
- def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
446
- """Push violation report to Salesforce"""
447
  try:
448
  sf = connect_to_salesforce()
449
 
450
- # Format violations for Salesforce
451
- violations_text = ""
452
- for v in violations:
453
- display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
454
- worker_id = v.get('worker_id', 'Unknown')
455
- timestamp = v.get('timestamp', 0.0)
456
- confidence = v.get('confidence', 0.0)
457
-
458
- violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
459
-
460
- if not violations_text:
461
- violations_text = "No violations detected."
462
-
463
- pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
464
-
465
  record_data = {
466
  "Compliance_Score__c": score,
467
  "Violations_Found__c": len(violations),
468
- "Violations_Details__c": violations_text,
469
- "Status__c": "Pending",
470
- "PDF_Report_URL__c": pdf_url
 
 
471
  }
472
 
473
- logger.info(f"Creating Salesforce record with data: {record_data}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
 
475
- try:
476
- record = sf.Safety_Video_Report__c.create(record_data)
477
- logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
478
- except Exception as e:
479
- logger.error(f"Failed to create Safety_Video_Report__c: {e}")
480
- record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
481
- logger.warning(f"Fell back to Account record: {record['id']}")
482
-
483
- record_id = record["id"]
484
-
485
- if pdf_file:
486
- uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
487
- if uploaded_url:
488
- try:
489
- sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
490
- logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
491
- except Exception as e:
492
- logger.error(f"Failed to update Safety_Video_Report__c: {e}")
493
- sf.Account.update(record_id, {"Description": uploaded_url})
494
- logger.info(f"Updated Account record {record_id} with PDF URL")
495
- pdf_url = uploaded_url
496
-
497
  return record_id, pdf_url
498
  except Exception as e:
499
- logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
500
  return None, ""
501
 
 
502
  def process_video(video_data):
503
- """Process video to detect safety violations with enhanced tracking"""
504
  try:
505
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
506
- logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
507
-
508
- video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
509
  with open(video_path, "wb") as f:
510
  f.write(video_data)
511
- logger.info(f"Video saved: {video_path}")
512
-
513
  cap = cv2.VideoCapture(video_path)
514
  if not cap.isOpened():
515
- os.remove(video_path)
516
- raise ValueError("Could not open video file")
517
-
518
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
519
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
520
- duration = total_frames / fps
521
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
522
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
523
- logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
524
-
525
  tracker = SafetyTracker(frame_rate=fps)
526
  snapshots = []
527
- start_time = time.time()
528
- frame_skip = CONFIG["FRAME_SKIP"]
529
- processed_frames = 0
530
- frame_counter = 0
531
-
532
- while processed_frames < total_frames:
533
- batch_frames = []
534
- batch_indices = []
535
-
536
- for _ in range(CONFIG["BATCH_SIZE"]):
537
- frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
538
- if frame_idx >= total_frames:
539
- break
540
-
541
- ret, frame = cap.read()
542
- if not ret:
543
- break
544
-
545
- frame = preprocess_frame(frame)
546
-
547
- # Skip frames if needed
548
- for _ in range(frame_skip - 1):
549
- if not cap.grab():
550
- break
551
-
552
- batch_frames.append(frame)
553
- batch_indices.append(frame_idx)
554
- processed_frames += 1
555
- frame_counter += 1
556
-
557
- if not batch_frames:
558
  break
559
-
560
- # Process batch with YOLO model
561
- results = model(batch_frames, device=device, conf=0.1, verbose=False)
562
-
563
- for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
564
- current_time = frame_idx / fps
565
 
566
- # Update progress every second
567
- if time.time() - start_time > 1.0:
568
- progress = (processed_frames / total_frames) * 100
569
- yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames})", "", "", "", ""
570
- start_time = time.time()
571
-
572
- boxes = result.boxes
573
- detections = []
574
 
575
- for box in boxes:
576
- cls = int(box.cls)
577
- conf = float(box.conf)
578
- label = CONFIG["VIOLATION_LABELS"].get(cls, None)
579
-
580
- if label is None:
581
- continue
582
-
583
- if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
584
- continue
585
-
586
- bbox = box.xywh.cpu().numpy()[0]
587
  detections.append({
588
- "bbox": bbox,
589
- "violation": label,
590
- "confidence": conf
591
  })
592
-
593
- if not detections:
594
- continue
595
-
596
- # Update tracker with new detections
597
- new_violations = tracker.update(detections, batch_frames[i])
598
-
599
- # Process new violations
600
- for violation in new_violations:
601
- # Take snapshot for the new violation
602
- snapshot_frame = batch_frames[i].copy()
603
- snapshot_frame = draw_detections(snapshot_frame, [violation])
604
-
605
- # Add timestamp to snapshot
606
- cv2.putText(
607
- snapshot_frame,
608
- f"Time: {violation['timestamp']:.2f}s",
609
- (10, 30),
610
- cv2.FONT_HERSHEY_SIMPLEX,
611
- 0.7,
612
- (255, 255, 255),
613
- 2
614
- )
615
-
616
- # Save snapshot with high quality
617
- snapshot_filename = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
618
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
619
-
620
- cv2.imwrite(
621
- snapshot_path,
622
- snapshot_frame,
623
- [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
624
- )
625
-
626
- snapshots.append({
627
- "violation": violation['violation'],
628
- "worker_id": violation['worker_id'],
629
- "timestamp": violation['timestamp'],
630
- "snapshot_path": snapshot_path,
631
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
632
- })
633
-
634
- logger.info(f"Captured snapshot for {violation['violation']} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
635
-
636
- cap.release()
637
- if os.path.exists(video_path):
638
- os.remove(video_path)
639
 
640
- processing_time = time.time() - start_time
641
- logger.info(f"Processing complete in {processing_time:.2f}s")
642
-
643
- # Get all unique violations from tracker
644
- violations = []
645
- for worker_id, worker_violations in tracker.violation_history.items():
646
- for label, detection_time in worker_violations.items():
647
- violations.append({
648
- "worker_id": worker_id,
649
- "violation": label,
650
- "timestamp": detection_time
 
 
 
 
 
651
  })
652
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
  if not violations:
654
- logger.info("No violations detected after processing")
655
- yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
656
  return
657
-
658
- # Calculate safety score
659
  score = calculate_safety_score(violations)
 
 
660
 
661
- # Generate PDF report
662
- pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
 
 
 
 
663
 
664
- # Push report to Salesforce
665
- report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
666
-
667
- # Format violations table for display
668
- violation_table = "| Violation | Worker ID | Time (s) |\n"
669
- violation_table += "|-----------|-----------|----------|\n"
670
 
671
- for v in sorted(violations, key=lambda x: x.get("timestamp", 0.0)):
672
- display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
673
- worker_id = v.get("worker_id", "Unknown")
674
- timestamp = v.get("timestamp", 0.0)
675
-
676
- violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} |\n"
677
-
678
- # Format snapshots for display
679
- snapshots_text = ""
680
- for s in snapshots:
681
- display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
682
- worker_id = s.get("worker_id", "Unknown")
683
- timestamp = s.get("timestamp", 0.0)
684
-
685
- snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
686
- snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
687
-
688
- if not snapshots_text:
689
- snapshots_text = "No snapshots captured."
690
-
691
  yield (
692
- violation_table,
693
  f"Safety Score: {score}%",
694
- snapshots_text,
695
- f"Salesforce Record ID: {report_id or 'N/A'}",
696
- final_pdf_url or "N/A"
697
  )
698
-
699
  except Exception as e:
700
- logger.error(f"Error processing video: {e}", exc_info=True)
701
  if 'video_path' in locals() and os.path.exists(video_path):
702
  os.remove(video_path)
703
- yield f"Error processing video: {e}", "", "", "", ""
704
 
 
705
  def gradio_interface(video_file):
706
- """Gradio interface for the video processing"""
707
  if not video_file:
708
- return "No file uploaded.", "", "No file uploaded.", "", ""
709
 
710
  try:
711
  with open(video_file, "rb") as f:
712
  video_data = f.read()
713
-
714
- for status, score, snapshots_text, record_id, details_url in process_video(video_data):
715
- yield status, score, snapshots_text, record_id, details_url
716
 
717
  except Exception as e:
718
- logger.error(f"Error in Gradio interface: {e}", exc_info=True)
719
- yield f"Error: {str(e)}", "", "Error in processing.", "", ""
720
 
721
- # ========================== # Gradio Interface # ==========================
722
  interface = gr.Interface(
723
  fn=gradio_interface,
724
- inputs=gr.Video(label="Upload Site Video"),
725
  outputs=[
726
- gr.Markdown(label="Detected Safety Violations"),
727
- gr.Textbox(label="Compliance Score"),
728
- gr.Markdown(label="Snapshots"),
729
- gr.Textbox(label="Salesforce Record ID"),
730
- gr.Textbox(label="Violation Details URL")
731
  ],
732
- title="Worksite Safety Violation Analyzer",
733
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
734
- allow_flagging="never"
735
  )
736
 
737
  if __name__ == "__main__":
738
- logger.info("Launching Enhanced Safety Analyzer App...")
739
  interface.launch()
 
19
  import uuid
20
  from multiprocessing import Pool, cpu_count
21
  from functools import partial
 
22
  from collections import defaultdict
23
 
24
  # ========================== # Configuration and Setup # ==========================
 
27
 
28
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
29
  logger = logging.getLogger(__name__)
 
 
30
  warnings.filterwarnings("ignore")
31
 
32
+ # ========================== # Position-Based Tracker Implementation # ==========================
33
  class SafetyTracker:
34
  def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
35
  self.track_thresh = track_thresh
 
38
  self.frame_rate = frame_rate
39
  self.next_id = 1
40
 
41
+ # Tracking stores
42
+ self.worker_tracks = {} # Active tracks
43
+ self.violation_history = defaultdict(dict) # {worker_id: {violation_type: last_detection_time}}
44
+ self.position_history = defaultdict(list) # {worker_id: [positions]}
 
45
 
46
+ # Violation cooldowns (seconds)
47
  self.VIOLATION_COOLDOWNS = {
48
  "no_helmet": 30.0,
49
  "no_harness": 20.0,
 
52
  "improper_tool_use": 15.0
53
  }
54
 
55
+ def update(self, detections):
56
+ """Update tracks with new detections using position-based matching"""
57
  current_time = time.time()
 
58
  new_violations = []
59
 
60
  for det in detections:
 
62
  label = det['violation']
63
  confidence = det['confidence']
64
 
65
+ # Match by position
66
+ worker_id = self._match_by_position(bbox, label)
 
 
 
 
67
 
68
  if worker_id is None:
69
  worker_id = self.next_id
70
  self.next_id += 1
71
 
72
+ # Check if new violation
73
  if self._is_new_violation(worker_id, label, current_time):
 
74
  violation = {
75
  'worker_id': worker_id,
76
  'violation': label,
 
79
  'timestamp': current_time
80
  }
81
  new_violations.append(violation)
 
 
82
  self.violation_history[worker_id][label] = current_time
 
 
 
 
83
 
84
+ # Update position history
85
+ x, y, w, h = bbox
86
+ self.position_history[worker_id].append((x, y))
87
+
88
+ # Update active tracks
89
  self.worker_tracks[worker_id] = {
90
  'bbox': bbox,
91
  'last_seen': current_time,
92
  'label': label
93
  }
94
 
95
+ # Cleanup old tracks
96
  self._cleanup_tracks(current_time)
97
 
98
  return new_violations
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def _match_by_position(self, bbox, label):
101
+ """Match detection to existing tracks using position"""
102
  x, y, w, h = bbox
103
  current_pos = (x, y)
104
 
105
  for worker_id, positions in self.position_history.items():
106
+ # Only match if worker has had this violation type before
107
  if label not in self.violation_history[worker_id]:
108
  continue
109
 
110
+ # Check distance to historical positions
111
+ for pos in positions[-5:]: # Check last 5 positions
112
  distance = np.sqrt((current_pos[0]-pos[0])**2 + (current_pos[1]-pos[1])**2)
113
  if distance < 100: # Within 100 pixels
114
  return worker_id
 
115
  return None
116
 
117
  def _is_new_violation(self, worker_id, label, current_time):
118
+ """Check if violation is new based on cooldown"""
119
  if label not in self.violation_history[worker_id]:
120
  return True
121
 
122
+ last_time = self.violation_history[worker_id][label]
123
  cooldown = self.VIOLATION_COOLDOWNS.get(label, 10.0)
124
+ return (current_time - last_time) > cooldown
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  def _cleanup_tracks(self, current_time):
127
+ """Remove inactive tracks"""
 
128
  inactive_ids = [
129
+ id for id, track in self.worker_tracks.items()
130
  if (current_time - track['last_seen']) > (self.track_buffer / self.frame_rate)
131
  ]
132
+ for id in inactive_ids:
133
+ self.worker_tracks.pop(id, None)
134
+ self.position_history.pop(id, None)
135
+ # Keep violation history for longer
136
+ if (current_time - max(self.violation_history[id].values(), default=0)) > 300:
137
+ self.violation_history.pop(id, None)
138
+
139
+ # ========================== # App Configuration # ==========================
 
 
 
140
  CONFIG = {
141
  "MODEL_PATH": "yolov8_safety.pt",
142
  "FALLBACK_MODEL": "yolov8n.pt",
 
149
  4: "improper_tool_use"
150
  },
151
  "CLASS_COLORS": {
152
+ "no_helmet": (0, 0, 255),
153
+ "no_harness": (0, 165, 255),
154
+ "unsafe_posture": (0, 255, 0),
155
+ "unsafe_zone": (255, 0, 0),
156
+ "improper_tool_use": (255, 255, 0)
157
  },
158
  "DISPLAY_NAMES": {
159
  "no_helmet": "No Helmet Violation",
 
176
  "unsafe_zone": 0.3,
177
  "improper_tool_use": 0.3
178
  },
 
179
  "FRAME_SKIP": 2,
180
+ "BATCH_SIZE": 8, # Reduced for stability
181
+ "SNAPSHOT_QUALITY": 90
 
 
182
  }
183
 
184
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
185
  logger.info(f"Using device: {device}")
186
 
187
+ # ========================== # Core Functions # ==========================
188
  def load_model():
189
  try:
190
+ model_path = CONFIG["MODEL_PATH"] if os.path.exists(CONFIG["MODEL_PATH"]) else CONFIG["FALLBACK_MODEL"]
191
+ logger.info(f"Loading model: {model_path}")
192
+ if not os.path.exists(model_path):
193
+ logger.info("Downloading fallback model...")
194
+ torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
195
+ return YOLO(model_path).to(device)
 
 
 
 
 
 
 
196
  except Exception as e:
197
+ logger.error(f"Model loading failed: {e}")
198
  raise
199
 
200
  model = load_model()
201
 
 
202
  def preprocess_frame(frame):
203
+ """Basic image enhancement"""
204
+ return cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
 
205
 
206
  def draw_detections(frame, detections):
207
+ """Draw bounding boxes with labels"""
208
+ result = frame.copy()
 
209
  for det in detections:
210
+ x, y, w, h = det['bbox']
211
+ x1, y1 = int(x-w/2), int(y-h/2)
212
+ x2, y2 = int(x+w/2), int(y+h/2)
213
+ label = CONFIG["DISPLAY_NAMES"].get(det['violation'], det['violation'])
214
+ color = CONFIG["CLASS_COLORS"].get(det['violation'], (0,0,255))
 
 
 
 
 
 
 
 
 
215
 
216
+ cv2.rectangle(result, (x1, y1), (x2, y2), color, 3)
217
+ cv2.putText(result, f"{label} (Worker {det['worker_id']})",
218
+ (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)
219
+ return result
 
 
 
 
 
 
 
220
 
221
  def calculate_safety_score(violations):
222
+ penalty_map = {
223
+ "no_helmet": 25, "no_harness": 30,
224
+ "unsafe_posture": 20, "unsafe_zone": 35,
 
 
 
225
  "improper_tool_use": 25
226
  }
227
+ unique_violations = {v['violation'] for v in violations}
228
+ return max(0, 100 - sum(penalty_map.get(v,0) for v in unique_violations))
 
 
 
 
 
 
 
 
229
 
230
+ # ========================== # Reporting Functions # ==========================
231
  def generate_violation_pdf(violations, score):
 
232
  try:
233
+ pdf_buffer = BytesIO()
234
+ c = canvas.Canvas(pdf_buffer, pagesize=letter)
 
 
235
 
236
+ # Header
237
  c.setFont("Helvetica-Bold", 16)
238
+ c.drawString(1*inch, 10*inch, "Safety Violation Report")
 
 
239
  c.setFont("Helvetica", 12)
240
+ c.drawString(1*inch, 9.5*inch, f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}")
241
+ c.drawString(1*inch, 9*inch, f"Safety Score: {score}%")
242
 
243
+ # Violations list
244
+ y = 8.5*inch
245
  c.setFont("Helvetica-Bold", 14)
246
+ c.drawString(1*inch, y, "Violations Detected:")
247
+ y -= 0.3*inch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
  c.setFont("Helvetica", 10)
250
  for v in violations:
251
+ text = f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']]} at {v['timestamp']:.1f}s"
252
+ c.drawString(1.2*inch, y, text)
253
+ y -= 0.2*inch
254
+ if y < 1*inch:
 
 
 
 
 
 
255
  c.showPage()
256
+ y = 10*inch
257
+
 
258
  c.save()
259
+ pdf_buffer.seek(0)
260
+
261
+ # Save to file
262
+ pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], f"report_{int(time.time())}.pdf")
263
  with open(pdf_path, "wb") as f:
264
+ f.write(pdf_buffer.getvalue())
265
 
266
+ return pdf_path, f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}", pdf_buffer
 
 
267
  except Exception as e:
268
+ logger.error(f"PDF generation failed: {e}")
269
+ return None, None, None
270
 
271
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
272
  def connect_to_salesforce():
 
273
  try:
274
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
275
+ logger.info("Salesforce connection established")
 
276
  return sf
277
  except Exception as e:
278
  logger.error(f"Salesforce connection failed: {e}")
279
  raise
280
 
281
+ def push_report_to_salesforce(violations, score, pdf_path, pdf_buffer):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  try:
283
  sf = connect_to_salesforce()
284
 
285
+ # Create record
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  record_data = {
287
  "Compliance_Score__c": score,
288
  "Violations_Found__c": len(violations),
289
+ "Violations_Details__c": "\n".join(
290
+ f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']]}"
291
+ for v in violations
292
+ ),
293
+ "Status__c": "New"
294
  }
295
 
296
+ record = sf.Safety_Video_Report__c.create(record_data)
297
+ record_id = record['id']
298
+ logger.info(f"Created Salesforce record: {record_id}")
299
+
300
+ # Upload PDF if available
301
+ pdf_url = ""
302
+ if pdf_buffer:
303
+ encoded = base64.b64encode(pdf_buffer.getvalue()).decode()
304
+ content_version = sf.ContentVersion.create({
305
+ "Title": f"Safety_Report_{record_id}",
306
+ "PathOnClient": "report.pdf",
307
+ "VersionData": encoded,
308
+ "FirstPublishLocationId": record_id
309
+ })
310
+ pdf_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
311
+ logger.info(f"PDF uploaded: {pdf_url}")
312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  return record_id, pdf_url
314
  except Exception as e:
315
+ logger.error(f"Salesforce upload failed: {e}")
316
  return None, ""
317
 
318
+ # ========================== # Video Processing # ==========================
319
  def process_video(video_data):
 
320
  try:
321
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
322
+
323
+ # Save video
324
+ video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"input_{int(time.time())}.mp4")
325
  with open(video_path, "wb") as f:
326
  f.write(video_data)
327
+
 
328
  cap = cv2.VideoCapture(video_path)
329
  if not cap.isOpened():
330
+ raise ValueError("Failed to open video")
331
+
 
 
332
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
333
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
 
 
 
334
  tracker = SafetyTracker(frame_rate=fps)
335
  snapshots = []
336
+
337
+ frame_count = 0
338
+ while True:
339
+ ret, frame = cap.read()
340
+ if not ret:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  break
 
 
 
 
 
 
342
 
343
+ if frame_count % CONFIG["FRAME_SKIP"] != 0:
344
+ frame_count += 1
345
+ continue
 
 
 
 
 
346
 
347
+ # Process frame
348
+ frame = preprocess_frame(frame)
349
+ results = model(frame, verbose=False)[0]
350
+
351
+ # Get detections
352
+ detections = []
353
+ for box in results.boxes:
354
+ cls = int(box.cls)
355
+ label = CONFIG["VIOLATION_LABELS"].get(cls)
356
+ if label and box.conf > CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.3):
 
 
357
  detections.append({
358
+ 'bbox': box.xywh[0].cpu().numpy(),
359
+ 'violation': label,
360
+ 'confidence': float(box.conf)
361
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
 
363
+ # Update tracker
364
+ new_violations = tracker.update(detections)
365
+
366
+ # Capture snapshots for new violations
367
+ for violation in new_violations:
368
+ snapshot = draw_detections(frame.copy(), [violation])
369
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
370
+ img_path = os.path.join(
371
+ CONFIG["OUTPUT_DIR"],
372
+ f"violation_{violation['worker_id']}_{violation['violation']}_{timestamp}.jpg"
373
+ )
374
+ cv2.imwrite(img_path, snapshot, [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
375
+ snapshots.append({
376
+ 'path': img_path,
377
+ 'url': f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(img_path)}",
378
+ 'violation': violation
379
  })
380
+
381
+ # Update progress
382
+ if frame_count % 10 == 0:
383
+ progress = min(100, frame_count / total_frames * 100)
384
+ yield f"Processing... {progress:.1f}%", "", "", "", ""
385
+
386
+ frame_count += 1
387
+
388
+ cap.release()
389
+ os.remove(video_path)
390
+
391
+ # Generate report
392
+ violations = [
393
+ {
394
+ 'worker_id': worker_id,
395
+ 'violation': violation_type,
396
+ 'timestamp': detection_time
397
+ }
398
+ for worker_id, violations in tracker.violation_history.items()
399
+ for violation_type, detection_time in violations.items()
400
+ ]
401
+
402
  if not violations:
403
+ yield "No violations found", "Safety Score: 100%", "No snapshots", "N/A", "N/A"
 
404
  return
405
+
 
406
  score = calculate_safety_score(violations)
407
+ pdf_path, pdf_url, pdf_buffer = generate_violation_pdf(violations, score)
408
+ record_id, salesforce_url = push_report_to_salesforce(violations, score, pdf_path, pdf_buffer)
409
 
410
+ # Format output
411
+ violations_table = "| Violation | Worker ID | Time |\n|-----------|-----------|------|\n"
412
+ violations_table += "\n".join(
413
+ f"| {CONFIG['DISPLAY_NAMES'][v['violation']]} | {v['worker_id']} | {v['timestamp']:.1f}s |"
414
+ for v in violations
415
+ )
416
 
417
+ snapshots_md = "\n\n".join(
418
+ f"### {CONFIG['DISPLAY_NAMES'][s['violation']['violation']]} (Worker {s['violation']['worker_id']})\n"
419
+ f"![Snapshot]({s['url']})"
420
+ for s in snapshots
421
+ )
 
422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  yield (
424
+ violations_table,
425
  f"Safety Score: {score}%",
426
+ snapshots_md or "No snapshots",
427
+ f"Salesforce ID: {record_id or 'N/A'}",
428
+ salesforce_url or pdf_url or "N/A"
429
  )
430
+
431
  except Exception as e:
432
+ logger.error(f"Processing failed: {e}")
433
  if 'video_path' in locals() and os.path.exists(video_path):
434
  os.remove(video_path)
435
+ yield f"Error: {str(e)}", "", "", "", ""
436
 
437
+ # ========================== # Gradio Interface # ==========================
438
  def gradio_interface(video_file):
 
439
  if not video_file:
440
+ return "Upload a video file", "", "", "", ""
441
 
442
  try:
443
  with open(video_file, "rb") as f:
444
  video_data = f.read()
445
+
446
+ for output in process_video(video_data):
447
+ yield output
448
 
449
  except Exception as e:
450
+ logger.error(f"Interface error: {e}")
451
+ yield f"Error: {str(e)}", "", "", "", ""
452
 
 
453
  interface = gr.Interface(
454
  fn=gradio_interface,
455
+ inputs=gr.Video(label="Upload Safety Video"),
456
  outputs=[
457
+ gr.Markdown("## Detected Violations"),
458
+ gr.Textbox(label="Safety Score"),
459
+ gr.Markdown("## Evidence Snapshots"),
460
+ gr.Textbox(label="Salesforce Record"),
461
+ gr.Textbox(label="Report URL")
462
  ],
463
+ title="AI Safety Compliance Analyzer",
464
+ description="Detects PPE and safety violations in worksite videos"
 
465
  )
466
 
467
  if __name__ == "__main__":
 
468
  interface.launch()