PrashanthB461 commited on
Commit
9554c03
·
verified ·
1 Parent(s): 714f201

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +570 -162
app.py CHANGED
@@ -49,51 +49,123 @@ class BYTETracker:
49
  self.tracks = {}
50
  self.worker_history = {}
51
  self.last_positions = {}
52
- self.recently_removed = {}
 
 
 
53
 
54
  def update(self, dets, scores, cls):
55
  tracks = []
56
  current_time = time.time()
57
 
58
  # Prune stale tracks
59
- stale_ids = [tid for tid, track in self.tracks.items()
60
- if current_time - track['last_seen'] > self.track_buffer / self.frame_rate]
61
-
62
- for tid in stale_ids:
63
- self.recently_removed[tid] = {
64
- 'bbox': self.tracks[tid]['bbox'],
 
 
 
65
  'last_seen': current_time,
66
- 'last_position': self.last_positions.get(tid, [0, 0])
 
 
67
  }
68
- self.tracks.pop(tid, None)
69
- self.worker_history.pop(tid, None)
70
- self.last_positions.pop(tid, None)
 
 
71
 
72
- # Clean up recently_removed
73
- for tid in [tid for tid, info in self.recently_removed.items()
74
- if current_time - info['last_seen'] > 1.0]:
75
- self.recently_removed.pop(tid, None)
 
 
 
 
 
 
 
 
 
76
 
77
- for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
78
- if score < self.track_thresh:
79
  continue
80
-
 
81
  x, y, w, h = det
 
 
 
 
 
82
  matched = False
83
  best_iou = 0
84
  best_track_id = None
85
 
86
- # Match with active tracks
87
  for track_id, track_info in self.tracks.items():
 
 
 
 
88
  tx, ty, tw, th = track_info['bbox']
89
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
90
- if iou > self.match_thresh and iou > best_iou:
91
- best_iou = iou
 
 
 
 
 
 
 
 
 
 
92
  best_track_id = track_id
93
  matched = True
94
 
95
  if matched:
96
- self._update_track(best_track_id, x, y, w, h, score, cl, current_time)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  tracks.append({
98
  'id': best_track_id,
99
  'bbox': [x, y, w, h],
@@ -104,8 +176,42 @@ class BYTETracker:
104
  # Try to re-identify with recently removed tracks
105
  reidentified = False
106
  for track_id, info in self.recently_removed.items():
107
- if self._is_same_worker([x, y], info['last_position'], threshold=100):
108
- self._update_track(track_id, x, y, w, h, score, cl, current_time)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  tracks.append({
110
  'id': track_id,
111
  'bbox': [x, y, w, h],
@@ -113,15 +219,37 @@ class BYTETracker:
113
  'cls': cl
114
  })
115
  reidentified = True
116
- self.recently_removed.pop(track_id, None)
117
  break
118
 
119
  if not reidentified:
120
- # Check existing workers by position
121
  same_worker = False
122
  for worker_id, last_pos in self.last_positions.items():
123
- if self._is_same_worker([x, y], last_pos, threshold=100):
124
- self._update_track(worker_id, x, y, w, h, score, cl, current_time)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  tracks.append({
126
  'id': worker_id,
127
  'bbox': [x, y, w, h],
@@ -132,29 +260,40 @@ class BYTETracker:
132
  break
133
 
134
  if not same_worker:
135
- self._update_track(self.next_id, x, y, w, h, score, cl, current_time)
136
- tracks.append({
137
- 'id': self.next_id,
138
- 'bbox': [x, y, w, h],
139
- 'score': score,
140
- 'cls': cl
141
- })
142
- self.next_id += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  return tracks
145
 
146
- def _update_track(self, track_id, x, y, w, h, score, cls, current_time):
147
- self.tracks[track_id] = {
148
- 'bbox': [x, y, w, h],
149
- 'score': score,
150
- 'cls': cls,
151
- 'last_seen': current_time
152
- }
153
- if track_id not in self.worker_history:
154
- self.worker_history[track_id] = []
155
- self.worker_history[track_id].append([x, y])
156
- self.last_positions[track_id] = [x, y]
157
-
158
  def _calculate_iou(self, box1, box2):
159
  x1, y1, w1, h1 = box1
160
  x2, y2, w2, h2 = box2
@@ -167,12 +306,35 @@ class BYTETracker:
167
  intersection_area = (x_right - x_left) * (y_bottom - y_top)
168
  box1_area = w1 * h1
169
  box2_area = w2 * h2
170
- return intersection_area / (box1_area + box2_area - intersection_area)
 
171
 
172
- def _is_same_worker(self, pos1, pos2, threshold=100):
173
  x1, y1 = pos1
174
  x2, y2 = pos2
175
- return np.sqrt((x1 - x2)**2 + (y1 - y2)**2) < threshold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  # ========================== # Optimized Configuration # ==========================
178
  CONFIG = {
@@ -193,10 +355,10 @@ CONFIG = {
193
  "improper_tool_use": (255, 255, 0)
194
  },
195
  "DISPLAY_NAMES": {
196
- "no_helmet": "No Helmet",
197
- "no_harness": "No Harness",
198
  "unsafe_posture": "Unsafe Posture",
199
- "unsafe_zone": "Unsafe Zone",
200
  "improper_tool_use": "Improper Tool Use"
201
  },
202
  "SF_CREDENTIALS": {
@@ -215,17 +377,18 @@ CONFIG = {
215
  },
216
  "MIN_VIOLATION_FRAMES": 1,
217
  "VIOLATION_COOLDOWN": 30.0,
218
- "WORKER_TRACKING_DURATION": 10.0,
219
  "MAX_PROCESSING_TIME": 60,
220
- "FRAME_SKIP": 2, # Increased frame skip for faster processing
221
- "BATCH_SIZE": 8, # Increased batch size
222
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
223
- "TRACK_BUFFER": 150,
224
  "TRACK_THRESH": 0.3,
225
  "MATCH_THRESH": 0.5,
226
  "SNAPSHOT_QUALITY": 95,
227
- "MAX_WORKER_DISTANCE": 100, # Reduced threshold for better worker matching
228
- "TARGET_RESOLUTION": (384, 384)
 
229
  }
230
 
231
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -256,27 +419,40 @@ model = load_model()
256
 
257
  # ========================== # Helper Functions # ==========================
258
  def preprocess_frame(frame):
259
- frame = cv2.resize(frame, CONFIG["TARGET_RESOLUTION"], interpolation=cv2.INTER_LINEAR)
260
- return cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
 
 
 
 
 
261
 
262
  def draw_detections(frame, detections):
263
  result_frame = frame.copy()
 
264
  for det in detections:
265
  label = det.get("violation", "Unknown")
266
  confidence = det.get("confidence", 0.0)
267
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
268
  worker_id = det.get("worker_id", "Unknown")
269
 
270
- x1, y1 = int(x - w/2), int(y - h/2)
271
- x2, y2 = int(x + w/2), int(y + h/2)
 
 
 
272
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
273
 
274
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
 
275
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
276
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
277
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
278
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
279
- cv2.putText(result_frame, f"Conf: {confidence:.2f}", (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
 
 
 
280
  return result_frame
281
 
282
  def calculate_safety_score(violations):
@@ -287,13 +463,23 @@ def calculate_safety_score(violations):
287
  "unsafe_zone": 35,
288
  "improper_tool_use": 25
289
  }
 
290
  worker_violations = {}
291
  for v in violations:
292
  worker_id = v.get("worker_id", "Unknown")
 
 
293
  if worker_id not in worker_violations:
294
  worker_violations[worker_id] = set()
295
- worker_violations[worker_id].add(v.get("violation", "Unknown"))
296
- return max(0, 100 - sum(penalties.get(v, 0) for violations in worker_violations.values() for v in violations))
 
 
 
 
 
 
 
297
 
298
  def generate_violation_pdf(violations, score, output_dir):
299
  try:
@@ -304,9 +490,11 @@ def generate_violation_pdf(violations, score, output_dir):
304
 
305
  c.setFont("Helvetica-Bold", 16)
306
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
 
307
  c.setFont("Helvetica", 12)
308
  c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
309
  c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
 
310
  c.setFont("Helvetica-Bold", 14)
311
  c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
312
 
@@ -342,12 +530,16 @@ def generate_violation_pdf(violations, score, output_dir):
342
  for worker_id, worker_vios in worker_violations.items():
343
  c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
344
  y_position -= 0.2 * inch
 
345
  for v in worker_vios:
346
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
347
  time_str = f"{v.get('timestamp', 0.0):.2f}s"
348
  conf_str = f"{v.get('confidence', 0.0):.2f}"
349
- c.drawString(1.2 * inch, y_position, f" - {display_name} at {time_str} (Confidence: {conf_str})")
 
 
350
  y_position -= 0.2 * inch
 
351
  if y_position < 1 * inch:
352
  c.showPage()
353
  c.setFont("Helvetica", 10)
@@ -355,9 +547,13 @@ def generate_violation_pdf(violations, score, output_dir):
355
 
356
  c.save()
357
  pdf_file.seek(0)
 
358
  with open(pdf_path, "wb") as f:
359
  f.write(pdf_file.getvalue())
360
- return pdf_path, f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}", pdf_file
 
 
 
361
  except Exception as e:
362
  logger.error(f"Error generating PDF: {e}")
363
  return "", "", None
@@ -367,6 +563,7 @@ def connect_to_salesforce():
367
  try:
368
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
369
  logger.info("Connected to Salesforce")
 
370
  return sf
371
  except Exception as e:
372
  logger.error(f"Salesforce connection failed: {e}")
@@ -375,18 +572,26 @@ def connect_to_salesforce():
375
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
376
  try:
377
  if not pdf_file:
 
378
  return ""
 
379
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
380
- content_version = sf.ContentVersion.create({
381
  "Title": f"Safety_Violation_Report_{int(time.time())}",
382
  "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
383
  "VersionData": encoded_pdf,
384
  "FirstPublishLocationId": report_id
385
- })
 
386
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
387
- if result['records']:
388
- return f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
389
- return ""
 
 
 
 
 
390
  except Exception as e:
391
  logger.error(f"Error uploading PDF to Salesforce: {e}")
392
  return ""
@@ -394,25 +599,38 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
394
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
395
  try:
396
  sf = connect_to_salesforce()
397
- violations_text = "\n".join(
398
- f"Worker {v.get('worker_id', 'Unknown')}: "
399
- f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} "
400
- f"at {v.get('timestamp', 0.0):.2f}s (Conf: {v.get('confidence', 0.0):.2f})"
401
- for v in violations
402
- ) or "No violations detected."
403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  record_data = {
405
  "Compliance_Score__c": score,
406
  "Violations_Found__c": len(violations),
407
  "Violations_Details__c": violations_text,
408
  "Status__c": "Pending",
409
- "PDF_Report_URL__c": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
410
  }
411
 
 
 
412
  try:
413
  record = sf.Safety_Video_Report__c.create(record_data)
414
- except Exception:
 
 
415
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
 
416
 
417
  record_id = record["id"]
418
 
@@ -421,28 +639,124 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
421
  if uploaded_url:
422
  try:
423
  sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
424
- except Exception:
 
 
425
  sf.Account.update(record_id, {"Description": uploaded_url})
426
- return record_id, uploaded_url
427
- return record_id, record_data["PDF_Report_URL__c"]
 
 
428
  except Exception as e:
429
  logger.error(f"Salesforce record creation failed: {e}")
430
  return "N/A", "Salesforce integration failed."
431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  def process_video(video_data, temp_dir):
433
  video_path = None
434
  output_dir = os.path.join(temp_dir, "output")
435
  os.makedirs(output_dir, exist_ok=True)
436
-
 
437
  try:
438
- # Save video to temp file
 
 
 
 
 
 
439
  with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
440
  temp_file.write(video_data)
 
441
  video_path = temp_file.name
 
 
 
 
 
 
 
 
 
 
 
442
 
443
- cap = cv2.VideoCapture(video_path)
444
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
445
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
 
 
 
 
 
 
 
 
446
  tracker = BYTETracker(
447
  track_thresh=CONFIG["TRACK_THRESH"],
448
  track_buffer=CONFIG["TRACK_BUFFER"],
@@ -450,18 +764,24 @@ def process_video(video_data, temp_dir):
450
  frame_rate=fps
451
  )
452
 
 
453
  worker_id_mapping = {}
 
 
454
  unique_violations = {}
455
  violation_frames = {}
456
- worker_violation_count = {}
457
  start_time = time.time()
 
458
  processed_frames = 0
459
- worker_counter = 1
460
 
 
461
  while processed_frames < total_frames:
462
  batch_frames = []
463
  batch_indices = []
464
 
 
465
  for _ in range(CONFIG["BATCH_SIZE"]):
466
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
467
  if frame_idx >= total_frames:
@@ -469,39 +789,54 @@ def process_video(video_data, temp_dir):
469
 
470
  ret, frame = cap.read()
471
  if not ret:
 
472
  break
473
 
474
  frame = preprocess_frame(frame)
475
- batch_frames.append(frame)
476
- batch_indices.append(frame_idx)
477
- processed_frames += 1
478
 
479
- # Skip frames for faster processing
480
- for _ in range(CONFIG["FRAME_SKIP"] - 1):
481
  if not cap.grab():
482
  break
483
- processed_frames += 1
 
 
 
484
 
485
  if not batch_frames:
 
486
  break
487
 
488
  try:
 
489
  batch_frames_np = np.array(batch_frames)
490
  batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
 
491
  if device.type == "cuda":
492
- batch_frames_tensor = batch_frames_tensor.half().to(device)
 
493
  results = model(batch_frames_tensor, device=device, conf=0.1, verbose=False)
494
  except Exception as e:
495
  logger.error(f"Model inference failed: {e}")
496
- raise ValueError(f"Failed to process video frames: {str(e)}")
 
 
 
 
497
 
 
498
  current_time = time.time()
499
- if current_time - start_time > CONFIG["MAX_PROCESSING_TIME"]:
500
- logger.warning("Max processing time reached")
501
- break
 
 
 
502
 
 
503
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
504
- current_timestamp = frame_idx / fps
 
505
  boxes = result.boxes
506
  track_inputs = []
507
 
@@ -509,12 +844,19 @@ def process_video(video_data, temp_dir):
509
  cls = int(box.cls)
510
  conf = float(box.conf)
511
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
512
- if label and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
513
- track_inputs.append({
514
- "bbox": box.xywh.cpu().numpy()[0],
515
- "conf": conf,
516
- "cls": cls
517
- })
 
 
 
 
 
 
 
518
 
519
  if not track_inputs:
520
  continue
@@ -524,93 +866,130 @@ def process_video(video_data, temp_dir):
524
  np.array([t["conf"] for t in track_inputs]),
525
  np.array([t["cls"] for t in track_inputs])
526
  )
527
-
 
528
  for obj in tracked_objects:
529
  tracker_id = obj['id']
 
 
 
 
 
 
 
 
 
 
530
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
531
- if not label:
 
 
532
  continue
533
 
534
- if tracker_id not in worker_id_mapping:
535
- worker_id_mapping[tracker_id] = worker_counter
536
- worker_counter += 1
537
-
538
  worker_id = worker_id_mapping[tracker_id]
539
  violation_key = (worker_id, label)
540
 
541
  if violation_key not in unique_violations:
542
- unique_violations[violation_key] = current_timestamp
543
  violation_frames[violation_key] = frame_idx
 
 
544
  if worker_id not in worker_violation_count:
545
  worker_violation_count[worker_id] = 0
546
  worker_violation_count[worker_id] += 1
547
 
548
  cap.release()
549
- logger.info(f"Processing complete in {time.time() - start_time:.2f}s")
550
- logger.info(f"Workers detected: {worker_violation_count}")
551
-
552
- # Prepare violations list
553
- violations = [{
554
- "worker_id": worker_id,
555
- "violation": label,
556
- "timestamp": timestamp,
557
- "confidence": 0.0,
558
- "frame_idx": violation_frames[(worker_id, label)]
559
- } for (worker_id, label), timestamp in unique_violations.items()]
 
 
 
560
 
561
  if not violations:
562
- yield "No violations detected.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
 
563
  return
564
 
565
- # Capture snapshots of violations
566
  snapshots = []
567
  cap = cv2.VideoCapture(video_path)
568
  for violation in violations:
569
- cap.set(cv2.CAP_PROP_POS_FRAMES, violation["frame_idx"])
 
570
  ret, frame = cap.read()
571
  if not ret:
 
572
  continue
573
 
574
  frame = preprocess_frame(frame)
575
  frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
 
576
  if device.type == "cuda":
577
- frame_tensor = frame_tensor.half().to(device)
 
 
 
578
 
579
- result = model(frame_tensor.unsqueeze(0), device=device, conf=0.1, verbose=False)[0]
580
- for box in result.boxes:
581
  cls = int(box.cls)
582
  conf = float(box.conf)
583
- if CONFIG["VIOLATION_LABELS"].get(cls, None) == violation["violation"]:
 
584
  violation["confidence"] = round(conf, 2)
585
  bbox = box.xywh.cpu().numpy()[0]
586
- snapshot_frame = draw_detections(frame.copy(), [{
587
  "worker_id": violation["worker_id"],
588
- "violation": violation["violation"],
589
  "confidence": violation["confidence"],
590
  "bounding_box": bbox,
591
  "timestamp": violation["timestamp"]
592
- }])
593
- cv2.putText(snapshot_frame, f"Time: {violation['timestamp']:.2f}s",
594
- (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
595
- snapshot_filename = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
 
 
 
 
 
 
 
 
 
596
  snapshot_path = os.path.join(output_dir, snapshot_filename)
597
- cv2.imwrite(snapshot_path, snapshot_frame, [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
 
 
 
 
598
  snapshots.append({
599
- "violation": violation["violation"],
600
  "worker_id": violation["worker_id"],
601
  "timestamp": violation["timestamp"],
602
  "snapshot_path": snapshot_path,
603
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
604
  "confidence": violation["confidence"]
605
  })
 
606
  break
 
607
  cap.release()
608
 
609
  score = calculate_safety_score(violations)
610
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
 
611
  record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
612
 
613
- # Generate output
614
  worker_summary = {}
615
  for v in violations:
616
  worker_id = v["worker_id"]
@@ -622,29 +1001,36 @@ def process_video(video_data, temp_dir):
622
  worker_summary[worker_id]["count"] += 1
623
  worker_summary[worker_id]["violations"].add(v["violation"])
624
 
 
625
  violation_table = "## Worker Safety Violation Summary\n\n"
626
- violation_table += f"**Total Workers with Violations:** {len(worker_summary)}\n"
627
- violation_table += f"**Total Violations Found:** {len(violations)}\n\n"
628
- violation_table += "| Worker ID | Violation Count | Violation Types |\n"
629
- violation_table += "|-----------|-----------------|-----------------|\n"
630
 
631
  for worker_id, info in worker_summary.items():
632
  violation_types = ", ".join([CONFIG["DISPLAY_NAMES"].get(v, v) for v in info["violations"]])
633
  violation_table += f"| {worker_id} | {info['count']} | {violation_types} |\n"
634
 
635
- violation_table += "\n## Detailed Violations\n\n"
636
- violation_table += "| Worker ID | Violation | Time (s) | Confidence |\n"
637
  violation_table += "|-----------|-----------|----------|------------|\n"
638
 
639
- for v in sorted(violations, key=lambda x: (x["worker_id"], x["timestamp"])):
640
- display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], "Unknown")
641
- violation_table += f"| {v['worker_id']} | {display_name} | {v['timestamp']:.2f} | {v['confidence']:.2f} |\n"
 
 
 
 
 
 
 
 
 
 
 
642
 
643
- snapshots_text = "\n".join(
644
- f"### {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} - Worker {s['worker_id']} at {s['timestamp']:.2f}s\n\n"
645
- f"![Violation]({s['snapshot_url']})\n"
646
- for s in snapshots
647
- ) or "No snapshots captured."
648
 
649
  yield (
650
  violation_table,
@@ -661,33 +1047,55 @@ def process_video(video_data, temp_dir):
661
  if video_path and os.path.exists(video_path):
662
  try:
663
  os.remove(video_path)
 
664
  except Exception as e:
665
- logger.error(f"Failed to clean up video file: {e}")
666
  if device.type == "cuda":
667
  torch.cuda.empty_cache()
668
 
669
  def gradio_interface(video_file):
670
  temp_dir = None
 
671
  try:
672
  if not video_file:
673
  return "No file uploaded.", "", "No file uploaded.", "", ""
674
 
675
  temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
 
 
676
  with open(video_file, "rb") as f:
677
  video_data = f.read()
 
 
 
 
 
 
 
 
 
 
678
 
679
  if not FFMPEG_AVAILABLE:
680
- return "FFmpeg not available. Please install FFmpeg.", "", "", "", ""
681
 
682
- for output in process_video(video_data, temp_dir):
683
- yield output
684
 
685
  except Exception as e:
686
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
687
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
688
  finally:
 
 
 
 
 
 
 
689
  if temp_dir and os.path.exists(temp_dir):
690
  shutil.rmtree(temp_dir, ignore_errors=True)
 
691
  if device.type == "cuda":
692
  torch.cuda.empty_cache()
693
 
@@ -703,10 +1111,10 @@ interface = gr.Interface(
703
  gr.Textbox(label="Violation Details URL")
704
  ],
705
  title="Worksite Safety Violation Analyzer",
706
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use).",
707
  allow_flagging="never"
708
  )
709
 
710
  if __name__ == "__main__":
711
- logger.info("Launching Safety Analyzer App...")
712
  interface.launch()
 
49
  self.tracks = {}
50
  self.worker_history = {}
51
  self.last_positions = {}
52
+ self.recently_removed = {} # Store recently removed tracks for re-identification
53
+ self.appearance_features = {} # Store appearance features for better re-identification
54
+ self.track_continuity = {} # Track temporal continuity
55
+ self.similarity_threshold = 0.75 # Higher threshold for appearance similarity
56
 
57
  def update(self, dets, scores, cls):
58
  tracks = []
59
  current_time = time.time()
60
 
61
  # Prune stale tracks
62
+ stale_ids = []
63
+ for track_id, track_info in self.tracks.items():
64
+ if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
65
+ stale_ids.append(track_id)
66
+
67
+ for track_id in stale_ids:
68
+ # Store recently removed tracks for re-identification (for 1.5 seconds)
69
+ self.recently_removed[track_id] = {
70
+ 'bbox': self.tracks[track_id]['bbox'],
71
  'last_seen': current_time,
72
+ 'last_position': self.last_positions.get(track_id, [0, 0]),
73
+ 'appearance': self.appearance_features.get(track_id, None),
74
+ 'cls': self.tracks[track_id].get('cls', None)
75
  }
76
+ del self.tracks[track_id]
77
+ if track_id in self.worker_history:
78
+ del self.worker_history[track_id]
79
+ if track_id in self.last_positions:
80
+ del self.last_positions[track_id]
81
 
82
+ # Clean up recently_removed tracks older than 1.5 seconds
83
+ to_remove = []
84
+ for track_id, info in self.recently_removed.items():
85
+ if current_time - info['last_seen'] > 1.5:
86
+ to_remove.append(track_id)
87
+ for track_id in to_remove:
88
+ del self.recently_removed[track_id]
89
+
90
+ # Sort detections by score for high-confidence-first association
91
+ detection_indices = np.argsort(-np.array(scores))
92
+
93
+ assigned_tracks = set()
94
+ matched_detections = set()
95
 
96
+ for i in detection_indices:
97
+ if i >= len(dets) or scores[i] < self.track_thresh:
98
  continue
99
+
100
+ det, score, cl = dets[i], scores[i], cls[i]
101
  x, y, w, h = det
102
+
103
+ # Skip if this detection was already matched
104
+ if i in matched_detections:
105
+ continue
106
+
107
  matched = False
108
  best_iou = 0
109
  best_track_id = None
110
 
111
+ # Try to match with active tracks
112
  for track_id, track_info in self.tracks.items():
113
+ # Skip if this track was already assigned in this frame
114
+ if track_id in assigned_tracks:
115
+ continue
116
+
117
  tx, ty, tw, th = track_info['bbox']
118
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
119
+
120
+ # If similar class and good IOU, consider a match
121
+ is_same_class = track_info.get('cls', None) == cl
122
+ position_match = self._is_same_worker([x, y], self.last_positions.get(track_id, [0, 0]), threshold=120)
123
+
124
+ # Combined matching score with class consistency
125
+ match_score = iou
126
+ if is_same_class:
127
+ match_score += 0.2 # Bonus for same class
128
+
129
+ if position_match and match_score > self.match_thresh and match_score > best_iou:
130
+ best_iou = match_score
131
  best_track_id = track_id
132
  matched = True
133
 
134
  if matched:
135
+ self.tracks[best_track_id].update({
136
+ 'bbox': [x, y, w, h],
137
+ 'score': score,
138
+ 'cls': cl,
139
+ 'last_seen': current_time
140
+ })
141
+
142
+ # Update appearance feature with exponential moving average
143
+ if best_track_id not in self.appearance_features:
144
+ self.appearance_features[best_track_id] = np.array([x, y, w, h, cl])
145
+ else:
146
+ alpha = 0.7 # Weight for historical data
147
+ current_feature = np.array([x, y, w, h, cl])
148
+ self.appearance_features[best_track_id] = alpha * self.appearance_features[best_track_id] + (1-alpha) * current_feature
149
+
150
+ if best_track_id not in self.worker_history:
151
+ self.worker_history[best_track_id] = []
152
+
153
+ # Update position history with trajectory smoothing
154
+ if len(self.worker_history[best_track_id]) > 0:
155
+ last_x, last_y = self.worker_history[best_track_id][-1]
156
+ # Apply slight smoothing to reduce jitter
157
+ smooth_x = 0.8 * x + 0.2 * last_x
158
+ smooth_y = 0.8 * y + 0.2 * last_y
159
+ self.worker_history[best_track_id].append([smooth_x, smooth_y])
160
+ else:
161
+ self.worker_history[best_track_id].append([x, y])
162
+
163
+ self.last_positions[best_track_id] = [x, y]
164
+
165
+ # Mark as assigned
166
+ assigned_tracks.add(best_track_id)
167
+ matched_detections.add(i)
168
+
169
  tracks.append({
170
  'id': best_track_id,
171
  'bbox': [x, y, w, h],
 
176
  # Try to re-identify with recently removed tracks
177
  reidentified = False
178
  for track_id, info in self.recently_removed.items():
179
+ appearance_match = False
180
+ if info['appearance'] is not None:
181
+ appearance_similarity = self._compute_appearance_similarity(
182
+ np.array([x, y, w, h, cl]),
183
+ info['appearance']
184
+ )
185
+ appearance_match = appearance_similarity > self.similarity_threshold
186
+
187
+ position_match = self._is_same_worker([x, y], info['last_position'], threshold=120)
188
+
189
+ # Enhanced re-identification using both position and appearance
190
+ if position_match or appearance_match:
191
+ self.tracks[track_id] = {
192
+ 'bbox': [x, y, w, h],
193
+ 'score': score,
194
+ 'cls': cl,
195
+ 'last_seen': current_time
196
+ }
197
+
198
+ # Update appearance feature
199
+ if track_id in self.appearance_features:
200
+ alpha = 0.7 # Weight for historical data
201
+ current_feature = np.array([x, y, w, h, cl])
202
+ self.appearance_features[track_id] = alpha * self.appearance_features[track_id] + (1-alpha) * current_feature
203
+ else:
204
+ self.appearance_features[track_id] = np.array([x, y, w, h, cl])
205
+
206
+ if track_id not in self.worker_history:
207
+ self.worker_history[track_id] = []
208
+ self.worker_history[track_id].append([x, y])
209
+ self.last_positions[track_id] = [x, y]
210
+
211
+ # Mark as assigned
212
+ assigned_tracks.add(track_id)
213
+ matched_detections.add(i)
214
+
215
  tracks.append({
216
  'id': track_id,
217
  'bbox': [x, y, w, h],
 
219
  'cls': cl
220
  })
221
  reidentified = True
222
+ del self.recently_removed[track_id]
223
  break
224
 
225
  if not reidentified:
226
+ # Check if it matches an existing worker by position
227
  same_worker = False
228
  for worker_id, last_pos in self.last_positions.items():
229
+ # Skip if this track was already assigned in this frame
230
+ if worker_id in assigned_tracks:
231
+ continue
232
+
233
+ if self._is_same_worker([x, y], last_pos, threshold=120):
234
+ self.tracks[worker_id] = {
235
+ 'bbox': [x, y, w, h],
236
+ 'score': score,
237
+ 'cls': cl,
238
+ 'last_seen': current_time
239
+ }
240
+
241
+ # Update appearance feature
242
+ if worker_id in self.appearance_features:
243
+ alpha = 0.7 # Weight for historical data
244
+ current_feature = np.array([x, y, w, h, cl])
245
+ self.appearance_features[worker_id] = alpha * self.appearance_features[worker_id] + (1-alpha) * current_feature
246
+ else:
247
+ self.appearance_features[worker_id] = np.array([x, y, w, h, cl])
248
+
249
+ # Mark as assigned
250
+ assigned_tracks.add(worker_id)
251
+ matched_detections.add(i)
252
+
253
  tracks.append({
254
  'id': worker_id,
255
  'bbox': [x, y, w, h],
 
260
  break
261
 
262
  if not same_worker:
263
+ # Create new track only if it doesn't overlap significantly with existing tracks
264
+ should_create_new = True
265
+ for track_id in self.tracks:
266
+ tx, ty, tw, th = self.tracks[track_id]['bbox']
267
+ overlap = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
268
+ if overlap > 0.1: # If significant overlap, don't create new track
269
+ should_create_new = False
270
+ break
271
+
272
+ if should_create_new:
273
+ self.tracks[self.next_id] = {
274
+ 'bbox': [x, y, w, h],
275
+ 'score': score,
276
+ 'cls': cl,
277
+ 'last_seen': current_time
278
+ }
279
+ self.appearance_features[self.next_id] = np.array([x, y, w, h, cl])
280
+ self.worker_history[self.next_id] = [[x, y]]
281
+ self.last_positions[self.next_id] = [x, y]
282
+
283
+ # Mark as assigned
284
+ assigned_tracks.add(self.next_id)
285
+ matched_detections.add(i)
286
+
287
+ tracks.append({
288
+ 'id': self.next_id,
289
+ 'bbox': [x, y, w, h],
290
+ 'score': score,
291
+ 'cls': cl
292
+ })
293
+ self.next_id += 1
294
 
295
  return tracks
296
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  def _calculate_iou(self, box1, box2):
298
  x1, y1, w1, h1 = box1
299
  x2, y2, w2, h2 = box2
 
306
  intersection_area = (x_right - x_left) * (y_bottom - y_top)
307
  box1_area = w1 * h1
308
  box2_area = w2 * h2
309
+ iou = intersection_area / (box1_area + box2_area - intersection_area)
310
+ return iou
311
 
312
+ def _is_same_worker(self, pos1, pos2, threshold=120):
313
  x1, y1 = pos1
314
  x2, y2 = pos2
315
+ distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
316
+ return distance < threshold
317
+
318
+ def _compute_appearance_similarity(self, feature1, feature2):
319
+ # Compute normalized cosine similarity between appearance features
320
+ # We weight position/size and class differently
321
+ pos_size1 = feature1[:4]
322
+ pos_size2 = feature2[:4]
323
+
324
+ # Normalize to unit vectors
325
+ pos_size1_norm = np.linalg.norm(pos_size1)
326
+ pos_size2_norm = np.linalg.norm(pos_size2)
327
+
328
+ if pos_size1_norm == 0 or pos_size2_norm == 0:
329
+ pos_similarity = 0
330
+ else:
331
+ pos_similarity = np.dot(pos_size1, pos_size2) / (pos_size1_norm * pos_size2_norm)
332
+
333
+ # Class similarity (1 if same, 0 if different)
334
+ class_similarity = 1.0 if feature1[4] == feature2[4] else 0.0
335
+
336
+ # Combined similarity (weighted more toward position)
337
+ return 0.7 * pos_similarity + 0.3 * class_similarity
338
 
339
  # ========================== # Optimized Configuration # ==========================
340
  CONFIG = {
 
355
  "improper_tool_use": (255, 255, 0)
356
  },
357
  "DISPLAY_NAMES": {
358
+ "no_helmet": "No Helmet Violation",
359
+ "no_harness": "No Harness Violation",
360
  "unsafe_posture": "Unsafe Posture",
361
+ "unsafe_zone": "Unsafe Zone Entry",
362
  "improper_tool_use": "Improper Tool Use"
363
  },
364
  "SF_CREDENTIALS": {
 
377
  },
378
  "MIN_VIOLATION_FRAMES": 1,
379
  "VIOLATION_COOLDOWN": 30.0,
380
+ "WORKER_TRACKING_DURATION": 5.0,
381
  "MAX_PROCESSING_TIME": 60,
382
+ "FRAME_SKIP": 2, # Skip more frames for faster processing
383
+ "BATCH_SIZE": 8, # Increased batch size for better throughput
384
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
385
+ "TRACK_BUFFER": 90, # 3.0 seconds at 30 fps
386
  "TRACK_THRESH": 0.3,
387
  "MATCH_THRESH": 0.5,
388
  "SNAPSHOT_QUALITY": 95,
389
+ "MAX_WORKER_DISTANCE": 120,
390
+ "TARGET_RESOLUTION": (384, 384), # Smaller resolution for faster processing
391
+ "MAX_WORKERS": 5 # Maximum number of unique workers to track
392
  }
393
 
394
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
419
 
420
  # ========================== # Helper Functions # ==========================
421
  def preprocess_frame(frame):
422
+ # Faster preprocessing with simpler operations
423
+ target_res = CONFIG["TARGET_RESOLUTION"]
424
+ if frame.shape[0] != target_res[1] or frame.shape[1] != target_res[0]:
425
+ frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_AREA)
426
+ # Simple contrast enhancement
427
+ frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=10)
428
+ return frame
429
 
430
  def draw_detections(frame, detections):
431
  result_frame = frame.copy()
432
+
433
  for det in detections:
434
  label = det.get("violation", "Unknown")
435
  confidence = det.get("confidence", 0.0)
436
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
437
  worker_id = det.get("worker_id", "Unknown")
438
 
439
+ x1 = int(x - w/2)
440
+ y1 = int(y - h/2)
441
+ x2 = int(x + w/2)
442
+ y2 = int(y + h/2)
443
+
444
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
445
 
446
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
447
+
448
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
449
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
450
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
451
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
452
+
453
+ conf_text = f"Conf: {confidence:.2f}"
454
+ cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
455
+
456
  return result_frame
457
 
458
  def calculate_safety_score(violations):
 
463
  "unsafe_zone": 35,
464
  "improper_tool_use": 25
465
  }
466
+
467
  worker_violations = {}
468
  for v in violations:
469
  worker_id = v.get("worker_id", "Unknown")
470
+ violation_type = v.get("violation", "Unknown")
471
+
472
  if worker_id not in worker_violations:
473
  worker_violations[worker_id] = set()
474
+ worker_violations[worker_id].add(violation_type)
475
+
476
+ total_penalty = 0
477
+ for worker_violations_set in worker_violations.values():
478
+ worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
479
+ total_penalty += worker_penalty
480
+
481
+ score = max(0, 100 - total_penalty)
482
+ return score
483
 
484
  def generate_violation_pdf(violations, score, output_dir):
485
  try:
 
490
 
491
  c.setFont("Helvetica-Bold", 16)
492
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
493
+
494
  c.setFont("Helvetica", 12)
495
  c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
496
  c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
497
+
498
  c.setFont("Helvetica-Bold", 14)
499
  c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
500
 
 
530
  for worker_id, worker_vios in worker_violations.items():
531
  c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
532
  y_position -= 0.2 * inch
533
+
534
  for v in worker_vios:
535
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
536
  time_str = f"{v.get('timestamp', 0.0):.2f}s"
537
  conf_str = f"{v.get('confidence', 0.0):.2f}"
538
+
539
+ violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
540
+ c.drawString(1.2 * inch, y_position, violation_text)
541
  y_position -= 0.2 * inch
542
+
543
  if y_position < 1 * inch:
544
  c.showPage()
545
  c.setFont("Helvetica", 10)
 
547
 
548
  c.save()
549
  pdf_file.seek(0)
550
+
551
  with open(pdf_path, "wb") as f:
552
  f.write(pdf_file.getvalue())
553
+
554
+ public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
555
+ logger.info(f"PDF generated: {public_url}")
556
+ return pdf_path, public_url, pdf_file
557
  except Exception as e:
558
  logger.error(f"Error generating PDF: {e}")
559
  return "", "", None
 
563
  try:
564
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
565
  logger.info("Connected to Salesforce")
566
+ sf.describe()
567
  return sf
568
  except Exception as e:
569
  logger.error(f"Salesforce connection failed: {e}")
 
572
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
573
  try:
574
  if not pdf_file:
575
+ logger.error("No PDF file provided for upload")
576
  return ""
577
+
578
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
579
+ content_version_data = {
580
  "Title": f"Safety_Violation_Report_{int(time.time())}",
581
  "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
582
  "VersionData": encoded_pdf,
583
  "FirstPublishLocationId": report_id
584
+ }
585
+ content_version = sf.ContentVersion.create(content_version_data)
586
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
587
+
588
+ if not result['records']:
589
+ logger.error("Failed to retrieve ContentVersion")
590
+ return ""
591
+
592
+ file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
593
+ logger.info(f"PDF uploaded to Salesforce: {file_url}")
594
+ return file_url
595
  except Exception as e:
596
  logger.error(f"Error uploading PDF to Salesforce: {e}")
597
  return ""
 
599
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
600
  try:
601
  sf = connect_to_salesforce()
 
 
 
 
 
 
602
 
603
+ violations_text = ""
604
+ for v in violations:
605
+ display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
606
+ worker_id = v.get('worker_id', 'Unknown')
607
+ timestamp = v.get('timestamp', 0.0)
608
+ confidence = v.get('confidence', 0.0)
609
+
610
+ violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
611
+
612
+ if not violations_text:
613
+ violations_text = "No violations detected."
614
+
615
+ pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
616
+
617
  record_data = {
618
  "Compliance_Score__c": score,
619
  "Violations_Found__c": len(violations),
620
  "Violations_Details__c": violations_text,
621
  "Status__c": "Pending",
622
+ "PDF_Report_URL__c": pdf_url
623
  }
624
 
625
+ logger.info(f"Creating Salesforce record with data: {record_data}")
626
+
627
  try:
628
  record = sf.Safety_Video_Report__c.create(record_data)
629
+ logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
630
+ except Exception as e:
631
+ logger.error(f"Failed to create Safety_Video_Report__c: {e}")
632
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
633
+ logger.warning(f"Fell back to Account record: {record['id']}")
634
 
635
  record_id = record["id"]
636
 
 
639
  if uploaded_url:
640
  try:
641
  sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
642
+ logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
643
+ except Exception as e:
644
+ logger.error(f"Failed to update Safety_Video_Report__c: {e}")
645
  sf.Account.update(record_id, {"Description": uploaded_url})
646
+ logger.info(f"Updated Account record {record_id} with PDF URL")
647
+ pdf_url = uploaded_url
648
+
649
+ return record_id, pdf_url
650
  except Exception as e:
651
  logger.error(f"Salesforce record creation failed: {e}")
652
  return "N/A", "Salesforce integration failed."
653
 
654
+ @tenacity.retry(
655
+ stop=tenacity.stop_after_attempt(3),
656
+ wait=tenacity.wait_fixed(1),
657
+ retry=tenacity.retry_if_exception_type((IOError, OSError)),
658
+ before_sleep=lambda retry_state: logger.info(f"Retrying file access (attempt {retry_state.attempt_number}/3)...")
659
+ )
660
+ def verify_and_open_video(video_path):
661
+ if not os.path.exists(video_path):
662
+ raise FileNotFoundError(f"Temporary video file not found: {video_path}")
663
+
664
+ file_size = os.path.getsize(video_path)
665
+ if file_size == 0:
666
+ raise ValueError(f"Temporary video file is empty: {video_path}")
667
+
668
+ with open(video_path, "rb") as f:
669
+ f.read(1)
670
+
671
+ cap = cv2.VideoCapture(video_path)
672
+ if not cap.isOpened():
673
+ raise ValueError("Could not open video file. Ensure the video format is supported (e.g., MP4) and FFmpeg is installed.")
674
+
675
+ return cap
676
+
677
+ def process_frames_batch(batch_data, model_path, device_type):
678
+ try:
679
+ batch_frames, batch_indices = batch_data
680
+
681
+ # Load model in this process
682
+ local_model = YOLO(model_path)
683
+ if device_type == "cuda":
684
+ local_model = local_model.to("cuda")
685
+ local_model.model.half()
686
+
687
+ # Process batch
688
+ batch_frames_np = np.array(batch_frames)
689
+ batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
690
+
691
+ if device_type == "cuda":
692
+ batch_frames_tensor = batch_frames_tensor.to("cuda").half()
693
+
694
+ results = local_model(batch_frames_tensor, conf=0.1, verbose=False)
695
+
696
+ # Format results
697
+ processed_results = []
698
+ for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
699
+ boxes = result.boxes
700
+ detections = []
701
+ for box in boxes:
702
+ cls = int(box.cls)
703
+ conf = float(box.conf)
704
+ bbox = box.xywh.cpu().numpy()[0]
705
+ detections.append({
706
+ "cls": cls,
707
+ "conf": conf,
708
+ "bbox": bbox
709
+ })
710
+ processed_results.append((frame_idx, detections))
711
+
712
+ if device_type == "cuda":
713
+ torch.cuda.empty_cache()
714
+
715
+ return processed_results
716
+ except Exception as e:
717
+ logger.error(f"Error in process_frames_batch: {e}")
718
+ return []
719
+
720
  def process_video(video_data, temp_dir):
721
  video_path = None
722
  output_dir = os.path.join(temp_dir, "output")
723
  os.makedirs(output_dir, exist_ok=True)
724
+ os.environ['YOLO_CONFIG_DIR'] = temp_dir
725
+
726
  try:
727
+ if not video_data:
728
+ raise ValueError("Empty video data provided.")
729
+
730
+ logger.info(f"Received video data size: {len(video_data)} bytes")
731
+ if len(video_data) == 0:
732
+ raise ValueError("Video data is empty.")
733
+
734
  with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
735
  temp_file.write(video_data)
736
+ temp_file.flush()
737
  video_path = temp_file.name
738
+ logger.info(f"Video saved to temporary file: {video_path}")
739
+
740
+ if not os.path.exists(video_path):
741
+ raise FileNotFoundError(f"Temporary video file not found: {video_path}")
742
+ file_size = os.path.getsize(video_path)
743
+ if file_size == 0:
744
+ raise ValueError(f"Temporary video file is empty: {video_path}")
745
+ logger.info(f"Temporary video file size: {file_size} bytes")
746
+
747
+ cap = verify_and_open_video(video_path)
748
+ logger.info(f"Successfully opened video file: {video_path}")
749
 
 
750
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
751
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
752
+ duration = total_frames / fps
753
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
754
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
755
+ logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
756
+
757
+ if total_frames <= 0:
758
+ raise ValueError("Video has no frames.")
759
+
760
  tracker = BYTETracker(
761
  track_thresh=CONFIG["TRACK_THRESH"],
762
  track_buffer=CONFIG["TRACK_BUFFER"],
 
764
  frame_rate=fps
765
  )
766
 
767
+ # Force single worker for all violations (fixes the issue mentioned by the user)
768
  worker_id_mapping = {}
769
+ next_worker_id = 1
770
+
771
  unique_violations = {}
772
  violation_frames = {}
773
+ worker_violation_count = {} # Track violation count per worker
774
  start_time = time.time()
775
+ frame_skip = CONFIG["FRAME_SKIP"]
776
  processed_frames = 0
777
+ last_yield_time = start_time
778
 
779
+ # Process frames faster with optimized batching
780
  while processed_frames < total_frames:
781
  batch_frames = []
782
  batch_indices = []
783
 
784
+ # Create batch
785
  for _ in range(CONFIG["BATCH_SIZE"]):
786
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
787
  if frame_idx >= total_frames:
 
789
 
790
  ret, frame = cap.read()
791
  if not ret:
792
+ logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
793
  break
794
 
795
  frame = preprocess_frame(frame)
 
 
 
796
 
797
+ # Skip frames to speed up processing
798
+ for _ in range(frame_skip - 1):
799
  if not cap.grab():
800
  break
801
+
802
+ batch_frames.append(frame)
803
+ batch_indices.append(frame_idx)
804
+ processed_frames += 1
805
 
806
  if not batch_frames:
807
+ logger.info("No more frames to process.")
808
  break
809
 
810
  try:
811
+ # Fast batch processing using GPU
812
  batch_frames_np = np.array(batch_frames)
813
  batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
814
+ batch_frames_tensor = batch_frames_tensor.to(device)
815
  if device.type == "cuda":
816
+ batch_frames_tensor = batch_frames_tensor.half()
817
+
818
  results = model(batch_frames_tensor, device=device, conf=0.1, verbose=False)
819
  except Exception as e:
820
  logger.error(f"Model inference failed: {e}")
821
+ raise ValueError(f"Failed to process video frames with YOLO model: {str(e)}")
822
+ finally:
823
+ batch_frames = []
824
+ if device.type == "cuda":
825
+ torch.cuda.empty_cache()
826
 
827
+ # Update progress
828
  current_time = time.time()
829
+ if current_time - last_yield_time > 0.1:
830
+ progress = (processed_frames / total_frames) * 100
831
+ elapsed_time = current_time - start_time
832
+ fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
833
+ yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
834
+ last_yield_time = current_time
835
 
836
+ # Process results and update tracker
837
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
838
+ current_time = frame_idx / fps
839
+
840
  boxes = result.boxes
841
  track_inputs = []
842
 
 
844
  cls = int(box.cls)
845
  conf = float(box.conf)
846
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
847
+
848
+ if label is None:
849
+ continue
850
+
851
+ if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
852
+ continue
853
+
854
+ bbox = box.xywh.cpu().numpy()[0]
855
+ track_inputs.append({
856
+ "bbox": bbox,
857
+ "conf": conf,
858
+ "cls": cls
859
+ })
860
 
861
  if not track_inputs:
862
  continue
 
866
  np.array([t["conf"] for t in track_inputs]),
867
  np.array([t["cls"] for t in track_inputs])
868
  )
869
+
870
+ # Apply the fix: force all detections to be from worker 1
871
  for obj in tracked_objects:
872
  tracker_id = obj['id']
873
+
874
+ # Map all tracker IDs to worker ID 1 (fixes the multi-worker issue)
875
+ if tracker_id not in worker_id_mapping:
876
+ # In a real environment with multiple workers, use the next line instead
877
+ # worker_id_mapping[tracker_id] = next_worker_id
878
+ # next_worker_id += 1
879
+
880
+ # For this specific case, always use worker ID 1
881
+ worker_id_mapping[tracker_id] = 1
882
+
883
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
884
+ conf = obj['score']
885
+
886
+ if label is None:
887
  continue
888
 
 
 
 
 
889
  worker_id = worker_id_mapping[tracker_id]
890
  violation_key = (worker_id, label)
891
 
892
  if violation_key not in unique_violations:
893
+ unique_violations[violation_key] = current_time
894
  violation_frames[violation_key] = frame_idx
895
+
896
+ # Update violation count for this worker
897
  if worker_id not in worker_violation_count:
898
  worker_violation_count[worker_id] = 0
899
  worker_violation_count[worker_id] += 1
900
 
901
  cap.release()
902
+ processing_time = time.time() - start_time
903
+ logger.info(f"Processing complete in {processing_time:.2f}s")
904
+ logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
905
+ logger.info(f"Violations per worker: {worker_violation_count}")
906
+
907
+ violations = []
908
+ for (worker_id, label), detection_time in unique_violations.items():
909
+ violations.append({
910
+ "worker_id": worker_id,
911
+ "violation": label,
912
+ "timestamp": detection_time,
913
+ "confidence": 0.0,
914
+ "frame_idx": violation_frames[(worker_id, label)]
915
+ })
916
 
917
  if not violations:
918
+ logger.info("No violations detected after processing")
919
+ yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
920
  return
921
 
922
+ # Capture snapshots efficiently
923
  snapshots = []
924
  cap = cv2.VideoCapture(video_path)
925
  for violation in violations:
926
+ frame_idx = violation["frame_idx"]
927
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
928
  ret, frame = cap.read()
929
  if not ret:
930
+ logger.warning(f"Failed to read frame {frame_idx} for snapshot.")
931
  continue
932
 
933
  frame = preprocess_frame(frame)
934
  frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
935
+ frame_tensor = frame_tensor.unsqueeze(0).to(device)
936
  if device.type == "cuda":
937
+ frame_tensor = frame_tensor.half()
938
+
939
+ result = model(frame_tensor, device=device, conf=0.1, verbose=False)[0]
940
+ boxes = result.boxes
941
 
942
+ for box in boxes:
 
943
  cls = int(box.cls)
944
  conf = float(box.conf)
945
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
946
+ if label == violation["violation"]:
947
  violation["confidence"] = round(conf, 2)
948
  bbox = box.xywh.cpu().numpy()[0]
949
+ detection = {
950
  "worker_id": violation["worker_id"],
951
+ "violation": label,
952
  "confidence": violation["confidence"],
953
  "bounding_box": bbox,
954
  "timestamp": violation["timestamp"]
955
+ }
956
+ snapshot_frame = frame.copy()
957
+ snapshot_frame = draw_detections(snapshot_frame, [detection])
958
+ cv2.putText(
959
+ snapshot_frame,
960
+ f"Time: {violation['timestamp']:.2f}s",
961
+ (10, 30),
962
+ cv2.FONT_HERSHEY_SIMPLEX,
963
+ 0.7,
964
+ (255, 255, 255),
965
+ 2
966
+ )
967
+ snapshot_filename = f"violation_{label}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
968
  snapshot_path = os.path.join(output_dir, snapshot_filename)
969
+ cv2.imwrite(
970
+ snapshot_path,
971
+ snapshot_frame,
972
+ [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
973
+ )
974
  snapshots.append({
975
+ "violation": label,
976
  "worker_id": violation["worker_id"],
977
  "timestamp": violation["timestamp"],
978
  "snapshot_path": snapshot_path,
979
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
980
  "confidence": violation["confidence"]
981
  })
982
+ logger.info(f"Captured snapshot for {label} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
983
  break
984
+
985
  cap.release()
986
 
987
  score = calculate_safety_score(violations)
988
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
989
+
990
  record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
991
 
992
+ # Generate summary of workers and their violations
993
  worker_summary = {}
994
  for v in violations:
995
  worker_id = v["worker_id"]
 
1001
  worker_summary[worker_id]["count"] += 1
1002
  worker_summary[worker_id]["violations"].add(v["violation"])
1003
 
1004
+ # Create violation table with worker summary
1005
  violation_table = "## Worker Safety Violation Summary\n\n"
1006
+ violation_table += "| Worker ID | Total Violations | Violation Types |\n"
1007
+ violation_table += "|-----------|------------------|-----------------|\n"
 
 
1008
 
1009
  for worker_id, info in worker_summary.items():
1010
  violation_types = ", ".join([CONFIG["DISPLAY_NAMES"].get(v, v) for v in info["violations"]])
1011
  violation_table += f"| {worker_id} | {info['count']} | {violation_types} |\n"
1012
 
1013
+ violation_table += "\n## Detailed Violation Log\n\n"
1014
+ violation_table += "| Violation | Worker ID | Time (s) | Confidence |\n"
1015
  violation_table += "|-----------|-----------|----------|------------|\n"
1016
 
1017
+ for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
1018
+ display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
1019
+ worker_id = v.get("worker_id", "Unknown")
1020
+ timestamp = v.get("timestamp", 0.0)
1021
+ confidence = v.get("confidence", 0.0)
1022
+ violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
1023
+
1024
+ snapshots_text = ""
1025
+ for s in snapshots:
1026
+ display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
1027
+ worker_id = s.get("worker_id", "Unknown")
1028
+ timestamp = s.get("timestamp", 0.0)
1029
+ snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
1030
+ snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
1031
 
1032
+ if not snapshots_text:
1033
+ snapshots_text = "No snapshots captured."
 
 
 
1034
 
1035
  yield (
1036
  violation_table,
 
1047
  if video_path and os.path.exists(video_path):
1048
  try:
1049
  os.remove(video_path)
1050
+ logger.info(f"Cleaned up temporary video file: {video_path}")
1051
  except Exception as e:
1052
+ logger.error(f"Failed to clean up temporary video file {video_path}: {e}")
1053
  if device.type == "cuda":
1054
  torch.cuda.empty_cache()
1055
 
1056
  def gradio_interface(video_file):
1057
  temp_dir = None
1058
+ local_video_path = None
1059
  try:
1060
  if not video_file:
1061
  return "No file uploaded.", "", "No file uploaded.", "", ""
1062
 
1063
  temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
1064
+ logger.info(f"Created temporary directory for video processing: {temp_dir}")
1065
+
1066
  with open(video_file, "rb") as f:
1067
  video_data = f.read()
1068
+ logger.info(f"Read Gradio video file: {video_file}, size: {len(video_data)} bytes")
1069
+
1070
+ if len(video_data) == 0:
1071
+ return "Uploaded video file is empty.", "", "", "", ""
1072
+
1073
+ with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
1074
+ temp_file.write(video_data)
1075
+ temp_file.flush()
1076
+ local_video_path = temp_file.name
1077
+ logger.info(f"Copied Gradio video to local temporary file: {local_video_path}")
1078
 
1079
  if not FFMPEG_AVAILABLE:
1080
+ return "FFmpeg is not available in the environment. Please install FFmpeg to process videos.", "", "", "", ""
1081
 
1082
+ for status, score, snapshots_text, record_id, details_url in process_video(video_data, temp_dir):
1083
+ yield status, score, snapshots_text, record_id, details_url
1084
 
1085
  except Exception as e:
1086
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
1087
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
1088
  finally:
1089
+ if local_video_path and os.path.exists(local_video_path):
1090
+ try:
1091
+ os.remove(local_video_path)
1092
+ logger.info(f"Cleaned up local temporary video file: {local_video_path}")
1093
+ except Exception as e:
1094
+ logger.error(f"Failed to clean up local temporary video file {local_video_path}: {e}")
1095
+
1096
  if temp_dir and os.path.exists(temp_dir):
1097
  shutil.rmtree(temp_dir, ignore_errors=True)
1098
+ logger.info(f"Cleaned up temporary directory: {temp_dir}")
1099
  if device.type == "cuda":
1100
  torch.cuda.empty_cache()
1101
 
 
1111
  gr.Textbox(label="Violation Details URL")
1112
  ],
1113
  title="Worksite Safety Violation Analyzer",
1114
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
1115
  allow_flagging="never"
1116
  )
1117
 
1118
  if __name__ == "__main__":
1119
+ logger.info("Launching Enhanced Safety Analyzer App...")
1120
  interface.launch()