PrashanthB461 committed on
Commit
55fb95e
·
verified ·
1 Parent(s): 727b3f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -264
app.py CHANGED
@@ -25,7 +25,7 @@ import tenacity
25
 
26
  # ========================== # Configuration and Setup # ==========================
27
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
28
- logger = logging.getLogger(__name__)
29
 
30
  def check_ffmpeg():
31
  try:
@@ -38,9 +38,9 @@ def check_ffmpeg():
38
 
39
  FFMPEG_AVAILABLE = check_ffmpeg()
40
 
41
- # ========================== # Improved ByteTrack Implementation # ==========================
42
  class BYTETracker:
43
- def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
  self.match_thresh = match_thresh
@@ -49,68 +49,70 @@ class BYTETracker:
49
  self.tracks = {}
50
  self.worker_history = {}
51
  self.last_positions = {}
52
- self.recently_removed = {}
53
- self.worker_centroids = {} # Store average positions for each worker
54
- self.violation_types = {} # Track violation types per worker
 
 
55
 
56
  def update(self, dets, scores, cls):
57
  tracks = []
58
  current_time = time.time()
59
-
60
  # Prune stale tracks
61
  stale_ids = []
62
  for track_id, track_info in self.tracks.items():
63
  if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
64
  stale_ids.append(track_id)
65
-
66
  for track_id in stale_ids:
67
- # Store recently removed tracks for re-identification (for 1 second)
68
  self.recently_removed[track_id] = {
69
  'bbox': self.tracks[track_id]['bbox'],
70
  'last_seen': current_time,
71
  'last_position': self.last_positions.get(track_id, [0, 0]),
72
- 'violation_types': self.violation_types.get(track_id, set())
73
  }
74
  del self.tracks[track_id]
75
  if track_id in self.worker_history:
76
  del self.worker_history[track_id]
77
  if track_id in self.last_positions:
78
  del self.last_positions[track_id]
79
- # Keep the centroid and violation types for re-identification
80
- # Don't delete from self.worker_centroids or self.violation_types
81
 
82
- # Clean up recently_removed tracks older than 1 second
83
  to_remove = []
84
  for track_id, info in self.recently_removed.items():
85
- if current_time - info['last_seen'] > 1.0:
86
  to_remove.append(track_id)
87
  for track_id in to_remove:
88
  del self.recently_removed[track_id]
89
 
 
 
 
90
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
91
  if score < self.track_thresh:
92
  continue
93
-
94
  x, y, w, h = det
95
  matched = False
96
  best_iou = 0
97
  best_track_id = None
98
 
99
- # Get current violation type
100
- violation_type = CONFIG["VIOLATION_LABELS"].get(int(cl), "unknown")
101
-
102
  # Try to match with active tracks
103
  for track_id, track_info in self.tracks.items():
104
  tx, ty, tw, th = track_info['bbox']
105
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
106
 
107
- # Check if this is the same worker based on position and size
108
  if iou > self.match_thresh and iou > best_iou:
109
  best_iou = iou
110
  best_track_id = track_id
111
  matched = True
112
-
113
  if matched:
 
114
  self.tracks[best_track_id].update({
115
  'bbox': [x, y, w, h],
116
  'score': score,
@@ -118,114 +120,123 @@ class BYTETracker:
118
  'last_seen': current_time
119
  })
120
 
121
- # Update position history
 
 
122
  if best_track_id not in self.worker_history:
123
  self.worker_history[best_track_id] = []
124
- self.worker_history[best_track_id].append([x, y])
125
- self.last_positions[best_track_id] = [x, y]
126
 
127
- # Update worker centroid with exponential moving average
128
- if best_track_id not in self.worker_centroids:
129
- self.worker_centroids[best_track_id] = [x, y]
130
- else:
131
- self.worker_centroids[best_track_id][0] = 0.7 * self.worker_centroids[best_track_id][0] + 0.3 * x
132
- self.worker_centroids[best_track_id][1] = 0.7 * self.worker_centroids[best_track_id][1] + 0.3 * y
 
133
 
134
- # Update violation types for this worker
135
- if best_track_id not in self.violation_types:
136
- self.violation_types[best_track_id] = set()
137
- self.violation_types[best_track_id].add(violation_type)
138
 
139
- tracks.append({
140
  'id': best_track_id,
141
  'bbox': [x, y, w, h],
142
  'score': score,
143
  'cls': cl
144
- })
145
  else:
146
- # Try to match with any known worker based on position
147
- matched_worker = False
148
- best_distance = float('inf')
149
- best_worker_id = None
150
-
151
- # First check active tracks
152
- for worker_id, centroid in self.worker_centroids.items():
153
- if worker_id in self.tracks: # Only consider active tracks
154
- distance = self._calculate_distance([x, y], centroid)
155
- if distance < CONFIG["MAX_WORKER_DISTANCE"] and distance < best_distance:
156
- best_distance = distance
157
- best_worker_id = worker_id
158
- matched_worker = True
159
-
160
- # If no match in active tracks, try recently removed tracks
161
- if not matched_worker:
162
- for track_id, info in self.recently_removed.items():
163
- if track_id in self.worker_centroids:
164
- distance = self._calculate_distance([x, y], self.worker_centroids[track_id])
165
- if distance < CONFIG["MAX_WORKER_DISTANCE"] and distance < best_distance:
166
- best_distance = distance
167
- best_worker_id = track_id
168
- matched_worker = True
 
 
 
 
 
 
 
169
 
170
- if matched_worker:
171
- # Reuse the existing worker ID
172
- self.tracks[best_worker_id] = {
173
- 'bbox': [x, y, w, h],
174
- 'score': score,
175
- 'cls': cl,
176
- 'last_seen': current_time
177
- }
178
-
179
- if best_worker_id not in self.worker_history:
180
- self.worker_history[best_worker_id] = []
181
- self.worker_history[best_worker_id].append([x, y])
182
- self.last_positions[best_worker_id] = [x, y]
183
-
184
- # Update centroid
185
- if best_worker_id not in self.worker_centroids:
186
- self.worker_centroids[best_worker_id] = [x, y]
187
- else:
188
- self.worker_centroids[best_worker_id][0] = 0.7 * self.worker_centroids[best_worker_id][0] + 0.3 * x
189
- self.worker_centroids[best_worker_id][1] = 0.7 * self.worker_centroids[best_worker_id][1] + 0.3 * y
190
-
191
- # Update violation types
192
- if best_worker_id not in self.violation_types:
193
- self.violation_types[best_worker_id] = set()
194
- self.violation_types[best_worker_id].add(violation_type)
195
-
196
- # If it was in recently_removed, remove it from there
197
- if best_worker_id in self.recently_removed:
198
- del self.recently_removed[best_worker_id]
199
-
200
- tracks.append({
201
- 'id': best_worker_id,
202
- 'bbox': [x, y, w, h],
203
- 'score': score,
204
- 'cls': cl
205
- })
206
- else:
207
- # Create a new worker ID
208
- new_id = self.next_id
209
- self.tracks[new_id] = {
210
- 'bbox': [x, y, w, h],
211
- 'score': score,
212
- 'cls': cl,
213
- 'last_seen': current_time
214
- }
215
- self.worker_history[new_id] = [[x, y]]
216
- self.last_positions[new_id] = [x, y]
217
- self.worker_centroids[new_id] = [x, y]
218
- self.violation_types[new_id] = {violation_type}
219
 
220
- tracks.append({
221
- 'id': new_id,
222
- 'bbox': [x, y, w, h],
223
- 'score': score,
224
- 'cls': cl
225
- })
226
- self.next_id += 1
227
-
228
- return tracks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  def _calculate_iou(self, box1, box2):
231
  x1, y1, w1, h1 = box1
@@ -241,14 +252,26 @@ class BYTETracker:
241
  box2_area = w2 * h2
242
  iou = intersection_area / (box1_area + box2_area - intersection_area)
243
  return iou
244
-
245
- def _calculate_distance(self, pos1, pos2):
246
  x1, y1 = pos1
247
  x2, y2 = pos2
248
- return np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
249
-
250
- def _is_same_worker(self, pos1, pos2, threshold=150):
251
- return self._calculate_distance(pos1, pos2) < threshold
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
  # ========================== # Optimized Configuration # ==========================
254
  CONFIG = {
@@ -293,14 +316,14 @@ CONFIG = {
293
  "VIOLATION_COOLDOWN": 30.0,
294
  "WORKER_TRACKING_DURATION": 10.0,
295
  "MAX_PROCESSING_TIME": 60,
296
- "FRAME_SKIP": 2, # Increased to improve performance
297
- "BATCH_SIZE": 20, # Increased for better throughput
298
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
299
- "TRACK_BUFFER": 150,
300
  "TRACK_THRESH": 0.3,
301
- "MATCH_THRESH": 0.5,
302
  "SNAPSHOT_QUALITY": 95,
303
- "MAX_WORKER_DISTANCE": 150,
304
  "TARGET_RESOLUTION": (384, 384)
305
  }
306
 
@@ -318,7 +341,7 @@ def load_model():
318
  if not os.path.isfile(model_path):
319
  logger.info(f"Downloading fallback model: {model_path}")
320
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
321
-
322
  model = YOLO(model_path).to(device)
323
  if device.type == "cuda":
324
  model.model.half()
@@ -339,7 +362,7 @@ def preprocess_frame(frame):
339
 
340
  def draw_detections(frame, detections):
341
  result_frame = frame.copy()
342
-
343
  for det in detections:
344
  label = det.get("violation", "Unknown")
345
  confidence = det.get("confidence", 0.0)
@@ -350,19 +373,19 @@ def draw_detections(frame, detections):
350
  y1 = int(y - h/2)
351
  x2 = int(x + w/2)
352
  y2 = int(y + h/2)
353
-
354
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
355
-
356
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
357
-
358
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
359
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
360
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
361
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
362
-
363
  conf_text = f"Conf: {confidence:.2f}"
364
  cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
365
-
366
  return result_frame
367
 
368
  def calculate_safety_score(violations):
@@ -548,7 +571,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
548
  uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
549
  if uploaded_url:
550
  try:
551
- sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
552
  logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
553
  except Exception as e:
554
  logger.error(f"Failed to update Safety_Video_Report__c: {e}")
@@ -589,11 +612,11 @@ def process_video(video_data, temp_dir):
589
  output_dir = os.path.join(temp_dir, "output")
590
  os.makedirs(output_dir, exist_ok=True)
591
  os.environ['YOLO_CONFIG_DIR'] = temp_dir
592
-
593
  try:
594
  if not video_data:
595
  raise ValueError("Empty video data provided.")
596
-
597
  logger.info(f"Received video data size: {len(video_data)} bytes")
598
  if len(video_data) == 0:
599
  raise ValueError("Video data is empty.")
@@ -628,30 +651,34 @@ def process_video(video_data, temp_dir):
628
  track_thresh=CONFIG["TRACK_THRESH"],
629
  track_buffer=CONFIG["TRACK_BUFFER"],
630
  match_thresh=CONFIG["MATCH_THRESH"],
631
- frame_rate=fps
 
632
  )
633
 
634
  unique_violations = {}
635
  violation_frames = {}
636
- worker_violation_count = {}
637
  start_time = time.time()
638
  frame_skip = CONFIG["FRAME_SKIP"]
639
  processed_frames = 0
640
  last_yield_time = start_time
641
 
642
- # Pre-allocate memory for batch processing
643
- batch_size = CONFIG["BATCH_SIZE"]
644
- batch_frames = []
645
- batch_indices = []
646
 
647
- # Process frames in batches for better performance
648
  while processed_frames < total_frames:
649
- # Clear previous batch
650
  batch_frames = []
651
  batch_indices = []
 
652
 
653
- # Fill the batch
654
- for _ in range(batch_size):
 
 
 
 
655
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
656
  if frame_idx >= total_frames:
657
  break
@@ -661,58 +688,45 @@ def process_video(video_data, temp_dir):
661
  logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
662
  break
663
 
664
- # Preprocess frame (resize and enhance)
665
  frame = preprocess_frame(frame)
666
-
667
- # Skip frames for performance
668
- for _ in range(frame_skip - 1):
669
- if not cap.grab():
670
- break
671
 
672
  batch_frames.append(frame)
673
  batch_indices.append(frame_idx)
 
674
  processed_frames += 1
675
 
676
  if not batch_frames:
677
  logger.info("No more frames to process.")
678
  break
679
 
680
- # Update progress
681
- current_time = time.time()
682
- if current_time - last_yield_time > 0.1:
683
- progress = (processed_frames / total_frames) * 100
684
- elapsed_time = current_time - start_time
685
- fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
686
- yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
687
- last_yield_time = current_time
688
-
689
  try:
690
- # Convert batch to tensor for efficient processing
691
  batch_frames_np = np.array(batch_frames)
692
  batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
693
  batch_frames_tensor = batch_frames_tensor.to(device)
694
  if device.type == "cuda":
695
  batch_frames_tensor = batch_frames_tensor.half()
696
 
697
- # Run inference on batch
698
  results = model(batch_frames_tensor, device=device, conf=0.1, verbose=False)
699
  except Exception as e:
700
  logger.error(f"Model inference failed: {e}")
701
  raise ValueError(f"Failed to process video frames with YOLO model: {str(e)}")
702
  finally:
703
- # Clear memory
704
- batch_frames = []
705
  if device.type == "cuda":
706
  torch.cuda.empty_cache()
707
 
708
- # Process results for each frame in the batch
709
- for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
710
- current_time = frame_idx / fps
711
-
 
 
 
 
 
712
  boxes = result.boxes
713
  track_inputs = []
714
 
715
- # Prepare detection inputs for tracker
716
  for box in boxes:
717
  cls = int(box.cls)
718
  conf = float(box.conf)
@@ -733,87 +747,48 @@ def process_video(video_data, temp_dir):
733
 
734
  if not track_inputs:
735
  continue
736
-
737
- # Update tracker with new detections
738
  tracked_objects = tracker.update(
739
  np.array([t["bbox"] for t in track_inputs]),
740
  np.array([t["conf"] for t in track_inputs]),
741
  np.array([t["cls"] for t in track_inputs])
742
  )
743
 
744
- # Process tracked objects
745
  for obj in tracked_objects:
746
  tracker_id = obj['id']
 
 
 
 
 
 
747
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
748
  conf = obj['score']
749
- bbox = obj['bbox']
750
 
751
  if label is None:
752
  continue
753
 
754
- worker_id = tracker_id
755
- violation_key = (worker_id, label)
756
 
757
- # Record unique violations
758
- if violation_key not in unique_violations:
759
- unique_violations[violation_key] = current_time
760
  violation_frames[violation_key] = frame_idx
761
-
762
- # Update violation count for this worker
763
- if worker_id not in worker_violation_count:
764
- worker_violation_count[worker_id] = 0
765
- worker_violation_count[worker_id] += 1
766
 
767
  cap.release()
768
  processing_time = time.time() - start_time
769
  logger.info(f"Processing complete in {processing_time:.2f}s")
770
- logger.info(f"Total unique workers detected: {len(tracker.worker_centroids)}")
771
- logger.info(f"Violations per worker: {worker_violation_count}")
772
-
773
- # Consolidate workers based on spatial proximity
774
- consolidated_workers = {}
775
- processed_workers = set()
776
 
777
- # Sort worker IDs to ensure deterministic consolidation
778
- worker_ids = sorted(tracker.worker_centroids.keys())
779
-
780
- for i, worker_id in enumerate(worker_ids):
781
- if worker_id in processed_workers:
782
- continue
783
-
784
- processed_workers.add(worker_id)
785
- consolidated_workers[worker_id] = [worker_id]
786
-
787
- for j, other_id in enumerate(worker_ids):
788
- if i == j or other_id in processed_workers:
789
- continue
790
-
791
- # Check if workers are close enough to be considered the same person
792
- if worker_id in tracker.worker_centroids and other_id in tracker.worker_centroids:
793
- distance = tracker._calculate_distance(
794
- tracker.worker_centroids[worker_id],
795
- tracker.worker_centroids[other_id]
796
- )
797
-
798
- if distance < CONFIG["MAX_WORKER_DISTANCE"] * 0.8: # More strict for consolidation
799
- consolidated_workers[worker_id].append(other_id)
800
- processed_workers.add(other_id)
801
-
802
- # Create a mapping from old worker IDs to new consolidated IDs
803
- worker_id_mapping = {}
804
- for new_id, old_ids in enumerate(consolidated_workers.values(), 1):
805
- for old_id in old_ids:
806
- worker_id_mapping[old_id] = new_id
807
-
808
- # Update violations with consolidated worker IDs
809
  violations = []
810
  for (worker_id, label), detection_time in unique_violations.items():
811
- new_worker_id = worker_id_mapping.get(worker_id, worker_id)
812
  violations.append({
813
- "worker_id": new_worker_id,
814
  "violation": label,
815
  "timestamp": detection_time,
816
- "confidence": 0.0,
817
  "frame_idx": violation_frames[(worker_id, label)]
818
  })
819
 
@@ -822,7 +797,6 @@ def process_video(video_data, temp_dir):
822
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
823
  return
824
 
825
- # Generate snapshots for each violation
826
  snapshots = []
827
  cap = cv2.VideoCapture(video_path)
828
  for violation in violations:
@@ -867,7 +841,7 @@ def process_video(video_data, temp_dir):
867
  (255, 255, 255),
868
  2
869
  )
870
- snapshot_filename = f"violation_{label}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
871
  snapshot_path = os.path.join(output_dir, snapshot_filename)
872
  cv2.imwrite(
873
  snapshot_path,
@@ -889,40 +863,28 @@ def process_video(video_data, temp_dir):
889
 
890
  score = calculate_safety_score(violations)
891
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
892
-
893
  record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
894
 
895
- # Generate summary of workers and their violations
896
- worker_summary = {}
897
  for v in violations:
898
- worker_id = v["worker_id"]
899
- if worker_id not in worker_summary:
900
- worker_summary[worker_id] = {
901
- "count": 0,
902
- "violations": set()
903
- }
904
- worker_summary[worker_id]["count"] += 1
905
- worker_summary[worker_id]["violations"].add(v["violation"])
906
-
907
- # Create violation table with worker summary
908
- violation_table = "## Worker Safety Violation Summary\n\n"
909
- violation_table += "| Worker ID | Total Violations | Violation Types |\n"
910
- violation_table += "|-----------|------------------|-----------------|\n"
911
-
912
- for worker_id, info in worker_summary.items():
913
- violation_types = ", ".join([CONFIG["DISPLAY_NAMES"].get(v, v) for v in info["violations"]])
914
- violation_table += f"| {worker_id} | {info['count']} | {violation_types} |\n"
915
 
916
- violation_table += "\n## Detailed Violation Log\n\n"
917
- violation_table += "| Violation | Worker ID | Time (s) | Confidence |\n"
918
  violation_table += "|-----------|-----------|----------|------------|\n"
919
-
920
- for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
921
- display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
922
- worker_id = v.get("worker_id", "Unknown")
923
- timestamp = v.get("timestamp", 0.0)
924
- confidence = v.get("confidence", 0.0)
925
- violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
 
 
926
 
927
  snapshots_text = ""
928
  for s in snapshots:
@@ -937,7 +899,7 @@ def process_video(video_data, temp_dir):
937
 
938
  yield (
939
  violation_table,
940
- f"Safety Score: {score}%",
941
  snapshots_text,
942
  f"Salesforce Record ID: {record_id}",
943
  final_pdf_url
@@ -962,14 +924,14 @@ def gradio_interface(video_file):
962
  try:
963
  if not video_file:
964
  return "No file uploaded.", "", "No file uploaded.", "", ""
965
-
966
  temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
967
  logger.info(f"Created temporary directory for video processing: {temp_dir}")
968
 
969
  with open(video_file, "rb") as f:
970
  video_data = f.read()
971
  logger.info(f"Read Gradio video file: {video_file}, size: {len(video_data)} bytes")
972
-
973
  if len(video_data) == 0:
974
  return "Uploaded video file is empty.", "", "", "", ""
975
 
@@ -984,7 +946,7 @@ def gradio_interface(video_file):
984
 
985
  for status, score, snapshots_text, record_id, details_url in process_video(video_data, temp_dir):
986
  yield status, score, snapshots_text, record_id, details_url
987
-
988
  except Exception as e:
989
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
990
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
@@ -995,7 +957,7 @@ def gradio_interface(video_file):
995
  logger.info(f"Cleaned up local temporary video file: {local_video_path}")
996
  except Exception as e:
997
  logger.error(f"Failed to clean up local temporary video file {local_video_path}: {e}")
998
-
999
  if temp_dir and os.path.exists(temp_dir):
1000
  shutil.rmtree(temp_dir, ignore_errors=True)
1001
  logger.info(f"Cleaned up temporary directory: {temp_dir}")
@@ -1014,10 +976,10 @@ interface = gr.Interface(
1014
  gr.Textbox(label="Violation Details URL")
1015
  ],
1016
  title="Worksite Safety Violation Analyzer",
1017
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
1018
  allow_flagging="never"
1019
  )
1020
 
1021
- if __name__ == "__main__":
1022
  logger.info("Launching Enhanced Safety Analyzer App...")
1023
  interface.launch()
 
25
 
26
  # ========================== # Configuration and Setup # ==========================
27
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
28
+ logger = logging.getLogger(__name__)
29
 
30
  def check_ffmpeg():
31
  try:
 
38
 
39
  FFMPEG_AVAILABLE = check_ffmpeg()
40
 
41
+ # ========================== # ByteTrack Implementation # ==========================
42
  class BYTETracker:
43
+ def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.3, frame_rate=30, max_distance=100):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
  self.match_thresh = match_thresh
 
49
  self.tracks = {}
50
  self.worker_history = {}
51
  self.last_positions = {}
52
+ self.recently_removed = {} # Store recently removed tracks for re-identification
53
+ self.track_attributes = {} # Store additional attributes like appearance features
54
+ self.active_workers = set() # Track currently active workers
55
+ self.worker_violation_history = {} # Track violations per worker
56
+ self.max_worker_distance = max_distance
57
 
58
  def update(self, dets, scores, cls):
59
  tracks = []
60
  current_time = time.time()
61
+
62
  # Prune stale tracks
63
  stale_ids = []
64
  for track_id, track_info in self.tracks.items():
65
  if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
66
  stale_ids.append(track_id)
67
+
68
  for track_id in stale_ids:
69
+ # Store recently removed tracks for re-identification (for 2 seconds)
70
  self.recently_removed[track_id] = {
71
  'bbox': self.tracks[track_id]['bbox'],
72
  'last_seen': current_time,
73
  'last_position': self.last_positions.get(track_id, [0, 0]),
74
+ 'appearance': self.track_attributes.get(track_id, {}).get('appearance', None)
75
  }
76
  del self.tracks[track_id]
77
  if track_id in self.worker_history:
78
  del self.worker_history[track_id]
79
  if track_id in self.last_positions:
80
  del self.last_positions[track_id]
81
+ if track_id in self.active_workers:
82
+ self.active_workers.remove(track_id)
83
 
84
+ # Clean up recently_removed tracks older than 2 seconds
85
  to_remove = []
86
  for track_id, info in self.recently_removed.items():
87
+ if current_time - info['last_seen'] > 2.0:
88
  to_remove.append(track_id)
89
  for track_id in to_remove:
90
  del self.recently_removed[track_id]
91
 
92
+ # Process new detections
93
+ active_tracks = {}
94
+
95
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
96
  if score < self.track_thresh:
97
  continue
98
+
99
  x, y, w, h = det
100
  matched = False
101
  best_iou = 0
102
  best_track_id = None
103
 
 
 
 
104
  # Try to match with active tracks
105
  for track_id, track_info in self.tracks.items():
106
  tx, ty, tw, th = track_info['bbox']
107
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
108
 
 
109
  if iou > self.match_thresh and iou > best_iou:
110
  best_iou = iou
111
  best_track_id = track_id
112
  matched = True
113
+
114
  if matched:
115
+ # Update existing track
116
  self.tracks[best_track_id].update({
117
  'bbox': [x, y, w, h],
118
  'score': score,
 
120
  'last_seen': current_time
121
  })
122
 
123
+ if 'appearance' not in self.track_attributes.get(best_track_id, {}):
124
+ self.track_attributes[best_track_id] = {'appearance': self._extract_appearance_features([x, y, w, h])}
125
+
126
  if best_track_id not in self.worker_history:
127
  self.worker_history[best_track_id] = []
 
 
128
 
129
+ self.worker_history[best_track_id].append({'pos': [x, y], 'time': current_time})
130
+
131
+ if len(self.worker_history[best_track_id]) > 30:
132
+ self.worker_history[best_track_id] = self.worker_history[best_track_id][-30:]
133
+
134
+ self.last_positions[best_track_id] = [x, y]
135
+ self.active_workers.add(best_track_id)
136
 
137
+ if cl is not None:
138
+ if best_track_id not in self.worker_violation_history:
139
+ self.worker_violation_history[best_track_id] = set()
140
+ self.worker_violation_history[best_track_id].add(int(cl))
141
 
142
+ active_tracks[best_track_id] = {
143
  'id': best_track_id,
144
  'bbox': [x, y, w, h],
145
  'score': score,
146
  'cls': cl
147
+ }
148
  else:
149
+ # Try to re-identify with recently removed tracks
150
+ reidentified = False
151
+ for track_id, info in self.recently_removed.items():
152
+ if self._is_same_worker([x, y], info['last_position']):
153
+ self.tracks[track_id] = {
154
+ 'bbox': [x, y, w, h],
155
+ 'score': score,
156
+ 'cls': cl,
157
+ 'last_seen': current_time
158
+ }
159
+ if track_id not in self.worker_history:
160
+ self.worker_history[track_id] = []
161
+ self.worker_history[track_id].append({'pos': [x, y], 'time': current_time})
162
+ self.last_positions[track_id] = [x, y]
163
+ self.active_workers.add(track_id)
164
+
165
+ if cl is not None:
166
+ if track_id not in self.worker_violation_history:
167
+ self.worker_violation_history[track_id] = set()
168
+ self.worker_violation_history[track_id].add(int(cl))
169
+
170
+ active_tracks[track_id] = {
171
+ 'id': track_id,
172
+ 'bbox': [x, y, w, h],
173
+ 'score': score,
174
+ 'cls': cl
175
+ }
176
+ reidentified = True
177
+ del self.recently_removed[track_id]
178
+ break
179
 
180
+ if not reidentified:
181
+ # Try to match with last positions of existing tracks via distance
182
+ same_worker = False
183
+ for worker_id, last_pos in self.last_positions.items():
184
+ if self._is_same_worker([x, y], last_pos):
185
+ self.tracks[worker_id] = {
186
+ 'bbox': [x, y, w, h],
187
+ 'score': score,
188
+ 'cls': cl,
189
+ 'last_seen': current_time
190
+ }
191
+
192
+ if worker_id not in self.worker_history:
193
+ self.worker_history[worker_id] = []
194
+ self.worker_history[worker_id].append({'pos': [x, y], 'time': current_time})
195
+ self.last_positions[worker_id] = [x, y]
196
+ self.active_workers.add(worker_id)
197
+
198
+ if cl is not None:
199
+ if worker_id not in self.worker_violation_history:
200
+ self.worker_violation_history[worker_id] = set()
201
+ self.worker_violation_history[worker_id].add(int(cl))
202
+
203
+ active_tracks[worker_id] = {
204
+ 'id': worker_id,
205
+ 'bbox': [x, y, w, h],
206
+ 'score': score,
207
+ 'cls': cl
208
+ }
209
+ same_worker = True
210
+ break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
+ if not same_worker:
213
+ # Register a new track
214
+ new_id = self.next_id
215
+ self.tracks[new_id] = {
216
+ 'bbox': [x, y, w, h],
217
+ 'score': score,
218
+ 'cls': cl,
219
+ 'last_seen': current_time
220
+ }
221
+ self.track_attributes[new_id] = {'appearance': self._extract_appearance_features([x, y, w, h])}
222
+ self.worker_history[new_id] = [{'pos': [x, y], 'time': current_time}]
223
+ self.last_positions[new_id] = [x, y]
224
+ self.active_workers.add(new_id)
225
+
226
+ if cl is not None:
227
+ if new_id not in self.worker_violation_history:
228
+ self.worker_violation_history[new_id] = set()
229
+ self.worker_violation_history[new_id].add(int(cl))
230
+
231
+ active_tracks[new_id] = {
232
+ 'id': new_id,
233
+ 'bbox': [x, y, w, h],
234
+ 'score': score,
235
+ 'cls': cl
236
+ }
237
+ self.next_id += 1
238
+
239
+ return list(active_tracks.values())
240
 
241
  def _calculate_iou(self, box1, box2):
242
  x1, y1, w1, h1 = box1
 
252
  box2_area = w2 * h2
253
  iou = intersection_area / (box1_area + box2_area - intersection_area)
254
  return iou
255
+
256
+ def _is_same_worker(self, pos1, pos2):
257
  x1, y1 = pos1
258
  x2, y2 = pos2
259
+ distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
260
+ return distance < self.max_worker_distance
261
+
262
+ def _extract_appearance_features(self, bbox):
263
+ """Simple appearance feature extraction (placeholder)"""
264
+ _, _, w, h = bbox
265
+ return [w, h, w/h]
266
+
267
+ def get_active_worker_count(self):
268
+ return len(self.active_workers)
269
+
270
+ def get_worker_violation_types(self, worker_id):
271
+ return self.worker_violation_history.get(worker_id, set())
272
+
273
+ def get_all_workers(self):
274
+ return set(list(self.tracks.keys()) + list(self.recently_removed.keys()))
275
 
276
  # ========================== # Optimized Configuration # ==========================
277
  CONFIG = {
 
316
  "VIOLATION_COOLDOWN": 30.0,
317
  "WORKER_TRACKING_DURATION": 10.0,
318
  "MAX_PROCESSING_TIME": 60,
319
+ "FRAME_SKIP": 1,
320
+ "BATCH_SIZE": 15,
321
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
322
+ "TRACK_BUFFER": 150, # 5.0 seconds at 30 fps
323
  "TRACK_THRESH": 0.3,
324
+ "MATCH_THRESH": 0.3,
325
  "SNAPSHOT_QUALITY": 95,
326
+ "MAX_WORKER_DISTANCE": 100,
327
  "TARGET_RESOLUTION": (384, 384)
328
  }
329
 
 
341
  if not os.path.isfile(model_path):
342
  logger.info(f"Downloading fallback model: {model_path}")
343
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
344
+
345
  model = YOLO(model_path).to(device)
346
  if device.type == "cuda":
347
  model.model.half()
 
362
 
363
  def draw_detections(frame, detections):
364
  result_frame = frame.copy()
365
+
366
  for det in detections:
367
  label = det.get("violation", "Unknown")
368
  confidence = det.get("confidence", 0.0)
 
373
  y1 = int(y - h/2)
374
  x2 = int(x + w/2)
375
  y2 = int(y + h/2)
376
+
377
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
378
+
379
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
380
+
381
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
382
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
383
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
384
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
385
+
386
  conf_text = f"Conf: {confidence:.2f}"
387
  cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
388
+
389
  return result_frame
390
 
391
  def calculate_safety_score(violations):
 
571
  uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
572
  if uploaded_url:
573
  try:
574
+ sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
575
  logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
576
  except Exception as e:
577
  logger.error(f"Failed to update Safety_Video_Report__c: {e}")
 
612
  output_dir = os.path.join(temp_dir, "output")
613
  os.makedirs(output_dir, exist_ok=True)
614
  os.environ['YOLO_CONFIG_DIR'] = temp_dir
615
+
616
  try:
617
  if not video_data:
618
  raise ValueError("Empty video data provided.")
619
+
620
  logger.info(f"Received video data size: {len(video_data)} bytes")
621
  if len(video_data) == 0:
622
  raise ValueError("Video data is empty.")
 
651
  track_thresh=CONFIG["TRACK_THRESH"],
652
  track_buffer=CONFIG["TRACK_BUFFER"],
653
  match_thresh=CONFIG["MATCH_THRESH"],
654
+ frame_rate=fps,
655
+ max_distance=CONFIG["MAX_WORKER_DISTANCE"]
656
  )
657
 
658
  unique_violations = {}
659
  violation_frames = {}
660
+ violation_confidences = {}
661
  start_time = time.time()
662
  frame_skip = CONFIG["FRAME_SKIP"]
663
  processed_frames = 0
664
  last_yield_time = start_time
665
 
666
+ logger.info("First pass: Worker detection and tracking")
667
+ all_workers = set()
668
+ worker_first_seen = {}
669
+ worker_last_seen = {}
670
 
 
671
  while processed_frames < total_frames:
 
672
  batch_frames = []
673
  batch_indices = []
674
+ batch_timestamps = []
675
 
676
+ for _ in range(CONFIG["BATCH_SIZE"]):
677
+ # Skip frames BEFORE reading to speed up
678
+ for _ in range(frame_skip - 1):
679
+ if not cap.grab():
680
+ break
681
+
682
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
683
  if frame_idx >= total_frames:
684
  break
 
688
  logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
689
  break
690
 
 
691
  frame = preprocess_frame(frame)
692
+ timestamp = frame_idx / fps
 
 
 
 
693
 
694
  batch_frames.append(frame)
695
  batch_indices.append(frame_idx)
696
+ batch_timestamps.append(timestamp)
697
  processed_frames += 1
698
 
699
  if not batch_frames:
700
  logger.info("No more frames to process.")
701
  break
702
 
 
 
 
 
 
 
 
 
 
703
  try:
 
704
  batch_frames_np = np.array(batch_frames)
705
  batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
706
  batch_frames_tensor = batch_frames_tensor.to(device)
707
  if device.type == "cuda":
708
  batch_frames_tensor = batch_frames_tensor.half()
709
 
 
710
  results = model(batch_frames_tensor, device=device, conf=0.1, verbose=False)
711
  except Exception as e:
712
  logger.error(f"Model inference failed: {e}")
713
  raise ValueError(f"Failed to process video frames with YOLO model: {str(e)}")
714
  finally:
 
 
715
  if device.type == "cuda":
716
  torch.cuda.empty_cache()
717
 
718
+ current_time = time.time()
719
+ if current_time - last_yield_time > 0.1:
720
+ progress = (processed_frames / total_frames) * 100
721
+ elapsed_time = current_time - start_time
722
+ fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
723
+ yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
724
+ last_yield_time = current_time
725
+
726
+ for i, (result, frame_idx, timestamp) in enumerate(zip(results, batch_indices, batch_timestamps)):
727
  boxes = result.boxes
728
  track_inputs = []
729
 
 
730
  for box in boxes:
731
  cls = int(box.cls)
732
  conf = float(box.conf)
 
747
 
748
  if not track_inputs:
749
  continue
750
+
 
751
  tracked_objects = tracker.update(
752
  np.array([t["bbox"] for t in track_inputs]),
753
  np.array([t["conf"] for t in track_inputs]),
754
  np.array([t["cls"] for t in track_inputs])
755
  )
756
 
 
757
  for obj in tracked_objects:
758
  tracker_id = obj['id']
759
+ all_workers.add(tracker_id)
760
+
761
+ if tracker_id not in worker_first_seen:
762
+ worker_first_seen[tracker_id] = timestamp
763
+ worker_last_seen[tracker_id] = timestamp
764
+
765
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
766
  conf = obj['score']
 
767
 
768
  if label is None:
769
  continue
770
 
771
+ violation_key = (tracker_id, label)
 
772
 
773
+ if violation_key not in unique_violations or conf > violation_confidences.get(violation_key, 0.0):
774
+ unique_violations[violation_key] = timestamp
 
775
  violation_frames[violation_key] = frame_idx
776
+ violation_confidences[violation_key] = conf
 
 
 
 
777
 
778
  cap.release()
779
  processing_time = time.time() - start_time
780
  logger.info(f"Processing complete in {processing_time:.2f}s")
 
 
 
 
 
 
781
 
782
+ total_workers = len(all_workers)
783
+ logger.info(f"Total unique workers detected: {total_workers}")
784
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
785
  violations = []
786
  for (worker_id, label), detection_time in unique_violations.items():
 
787
  violations.append({
788
+ "worker_id": worker_id,
789
  "violation": label,
790
  "timestamp": detection_time,
791
+ "confidence": violation_confidences.get((worker_id, label), 0.0),
792
  "frame_idx": violation_frames[(worker_id, label)]
793
  })
794
 
 
797
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
798
  return
799
 
 
800
  snapshots = []
801
  cap = cv2.VideoCapture(video_path)
802
  for violation in violations:
 
841
  (255, 255, 255),
842
  2
843
  )
844
+ snapshot_filename = f"violation_{label}_worker_{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
845
  snapshot_path = os.path.join(output_dir, snapshot_filename)
846
  cv2.imwrite(
847
  snapshot_path,
 
863
 
864
  score = calculate_safety_score(violations)
865
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
866
+
867
  record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
868
 
869
+ worker_violations = {}
 
870
  for v in violations:
871
+ worker_id = v.get("worker_id", "Unknown")
872
+ if worker_id not in worker_violations:
873
+ worker_violations[worker_id] = []
874
+ worker_violations[worker_id].append(v)
 
 
 
 
 
 
 
 
 
 
 
 
 
875
 
876
+ violation_table = f"## Total Workers Detected: {total_workers}\n\n"
877
+ violation_table += "| Worker ID | Violation | Time (s) | Confidence |\n"
878
  violation_table += "|-----------|-----------|----------|------------|\n"
879
+
880
+ for worker_id, vios in sorted(worker_violations.items()):
881
+ vios.sort(key=lambda x: x.get("violation", ""))
882
+
883
+ for v in vios:
884
+ display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
885
+ timestamp = v.get("timestamp", 0.0)
886
+ confidence = v.get("confidence", 0.0)
887
+ violation_table += f"| {worker_id} | {display_name} | {timestamp:.2f} | {confidence:.2f} |\n"
888
 
889
  snapshots_text = ""
890
  for s in snapshots:
 
899
 
900
  yield (
901
  violation_table,
902
+ f"Safety Score: {score}% (Based on {total_workers} workers)",
903
  snapshots_text,
904
  f"Salesforce Record ID: {record_id}",
905
  final_pdf_url
 
924
  try:
925
  if not video_file:
926
  return "No file uploaded.", "", "No file uploaded.", "", ""
927
+
928
  temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
929
  logger.info(f"Created temporary directory for video processing: {temp_dir}")
930
 
931
  with open(video_file, "rb") as f:
932
  video_data = f.read()
933
  logger.info(f"Read Gradio video file: {video_file}, size: {len(video_data)} bytes")
934
+
935
  if len(video_data) == 0:
936
  return "Uploaded video file is empty.", "", "", "", ""
937
 
 
946
 
947
  for status, score, snapshots_text, record_id, details_url in process_video(video_data, temp_dir):
948
  yield status, score, snapshots_text, record_id, details_url
949
+
950
  except Exception as e:
951
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
952
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
 
957
  logger.info(f"Cleaned up local temporary video file: {local_video_path}")
958
  except Exception as e:
959
  logger.error(f"Failed to clean up local temporary video file {local_video_path}: {e}")
960
+
961
  if temp_dir and os.path.exists(temp_dir):
962
  shutil.rmtree(temp_dir, ignore_errors=True)
963
  logger.info(f"Cleaned up temporary directory: {temp_dir}")
 
976
  gr.Textbox(label="Violation Details URL")
977
  ],
978
  title="Worksite Safety Violation Analyzer",
979
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). The system tracks individual workers and their specific violations.",
980
  allow_flagging="never"
981
  )
982
 
983
+ if __name__ == "__main__":
984
  logger.info("Launching Enhanced Safety Analyzer App...")
985
  interface.launch()