PrashanthB461 commited on
Commit
28ba0f6
·
verified ·
1 Parent(s): f7c1bff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +474 -158
app.py CHANGED
@@ -38,96 +38,155 @@ def check_ffmpeg():
38
 
39
  FFMPEG_AVAILABLE = check_ffmpeg()
40
 
41
- # ========================== # Improved ByteTrack Implementation # ==========================
42
  class BYTETracker:
43
- def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
- self.match_thresh = match_thresh # Increased matching threshold
47
  self.frame_rate = frame_rate
48
- self.next_id = 1
49
  self.tracks = {}
 
50
  self.last_positions = {}
51
- self.worker_appearance = {} # Track worker appearance patterns
52
 
53
  def update(self, dets, scores, cls):
54
  tracks = []
55
  current_time = time.time()
56
-
57
  # Prune stale tracks
58
- stale_ids = [tid for tid, track in self.tracks.items()
59
- if current_time - track['last_seen'] > self.track_buffer / self.frame_rate]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- for tid in stale_ids:
62
- del self.tracks[tid]
63
- if tid in self.last_positions:
64
- del self.last_positions[tid]
65
- if tid in self.worker_appearance:
66
- del self.worker_appearance[tid]
67
-
68
- for det, score, cl in zip(dets, scores, cls):
69
  if score < self.track_thresh:
70
  continue
71
-
72
  x, y, w, h = det
73
  matched = False
74
-
75
- # Find best match among active tracks
76
- best_match = None
77
  best_iou = 0
78
- for tid, track in self.tracks.items():
79
- tx, ty, tw, th = track['bbox']
 
 
 
80
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
81
-
82
- # Additional check for similar appearance
83
- if tid in self.worker_appearance:
84
- appearance_similarity = self._appearance_similarity([x,y,w,h], self.worker_appearance[tid])
85
- iou = (iou + appearance_similarity) / 2 # Combine spatial and appearance similarity
86
-
87
  if iou > self.match_thresh and iou > best_iou:
88
  best_iou = iou
89
- best_match = tid
90
-
91
- if best_match is not None:
92
- # Update existing track
93
- self.tracks[best_match].update({
94
  'bbox': [x, y, w, h],
95
  'score': score,
96
  'cls': cl,
97
  'last_seen': current_time
98
  })
99
- self.last_positions[best_match] = [x, y]
100
- self.worker_appearance[best_match] = [x, y, w, h] # Update appearance
 
101
  tracks.append({
102
- 'id': best_match,
103
  'bbox': [x, y, w, h],
104
  'score': score,
105
  'cls': cl
106
  })
107
  else:
108
- # Check if this might be an existing worker based on movement pattern
109
- existing_worker = self._find_existing_worker([x, y, w, h])
110
- if existing_worker is not None:
111
- tid = existing_worker
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  else:
113
- tid = self.next_id
114
- self.next_id += 1
115
-
116
- self.tracks[tid] = {
117
- 'bbox': [x, y, w, h],
118
- 'score': score,
119
- 'cls': cl,
120
- 'last_seen': current_time
121
- }
122
- self.last_positions[tid] = [x, y]
123
- self.worker_appearance[tid] = [x, y, w, h]
124
- tracks.append({
125
- 'id': tid,
126
- 'bbox': [x, y, w, h],
127
- 'score': score,
128
- 'cls': cl
129
- })
130
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  return tracks
132
 
133
  def _calculate_iou(self, box1, box2):
@@ -142,23 +201,13 @@ class BYTETracker:
142
  intersection_area = (x_right - x_left) * (y_bottom - y_top)
143
  box1_area = w1 * h1
144
  box2_area = w2 * h2
145
- return intersection_area / (box1_area + box2_area - intersection_area)
146
-
147
- def _appearance_similarity(self, box1, box2):
148
- # Simple size similarity (can be enhanced with more sophisticated features)
149
- _, _, w1, h1 = box1
150
- _, _, w2, h2 = box2
151
- size_similarity = 1 - abs(w1*h1 - w2*h2) / max(w1*h1, w2*h2)
152
- return size_similarity
153
-
154
- def _find_existing_worker(self, box):
155
- x, y, w, h = box
156
- for tid, last_pos in self.last_positions.items():
157
- lx, ly = last_pos
158
- distance = np.sqrt((x - lx)**2 + (y - ly)**2)
159
- if distance < 50: # If very close to last known position
160
- return tid
161
- return None
162
 
163
  # ========================== # Optimized Configuration # ==========================
164
  CONFIG = {
@@ -203,15 +252,15 @@ CONFIG = {
203
  "VIOLATION_COOLDOWN": 30.0,
204
  "WORKER_TRACKING_DURATION": 5.0,
205
  "MAX_PROCESSING_TIME": 60,
206
- "FRAME_SKIP": 2, # Balanced processing speed and accuracy
207
- "BATCH_SIZE": 4,
208
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
209
- "TRACK_BUFFER": 30,
210
  "TRACK_THRESH": 0.3,
211
- "MATCH_THRESH": 0.7, # Increased for more strict matching
212
- "SNAPSHOT_QUALITY": 90,
213
- "MAX_WORKER_DISTANCE": 50, # Reduced for more precise tracking
214
- "TARGET_RESOLUTION": (384, 384) # Balanced resolution
215
  }
216
 
217
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -228,7 +277,7 @@ def load_model():
228
  if not os.path.isfile(model_path):
229
  logger.info(f"Downloading fallback model: {model_path}")
230
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
231
-
232
  model = YOLO(model_path).to(device)
233
  if device.type == "cuda":
234
  model.model.half()
@@ -240,34 +289,299 @@ def load_model():
240
 
241
  model = load_model()
242
 
243
- # [Rest of your helper functions (preprocess_frame, draw_detections, calculate_safety_score,
244
- # generate_violation_pdf, connect_to_salesforce, push_report_to_salesforce, upload_pdf_to_salesforce)
245
- # remain exactly the same as in your original code]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
- # ========================== # Improved Video Processing # ==========================
248
  def process_video(video_data, temp_dir):
249
  video_path = None
250
  output_dir = os.path.join(temp_dir, "output")
251
  os.makedirs(output_dir, exist_ok=True)
252
-
 
253
  try:
 
 
 
 
 
 
 
254
  with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
255
  temp_file.write(video_data)
 
256
  video_path = temp_file.name
 
257
 
258
- cap = cv2.VideoCapture(video_path)
259
- if not cap.isOpened():
260
- raise ValueError("Could not open video file")
 
 
 
 
 
 
261
 
262
- fps = cap.get(cv2.CAP_PROP_FPS) or 30
263
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
264
  duration = total_frames / fps
265
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
266
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
267
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
268
 
269
  if total_frames <= 0:
270
- raise ValueError("Video has no frames")
271
 
272
  tracker = BYTETracker(
273
  track_thresh=CONFIG["TRACK_THRESH"],
@@ -279,6 +593,7 @@ def process_video(video_data, temp_dir):
279
  worker_id_mapping = {}
280
  unique_violations = {}
281
  violation_frames = {}
 
282
  start_time = time.time()
283
  frame_skip = CONFIG["FRAME_SKIP"]
284
  processed_frames = 0
@@ -288,23 +603,23 @@ def process_video(video_data, temp_dir):
288
  while processed_frames < total_frames:
289
  batch_frames = []
290
  batch_indices = []
291
-
292
  for _ in range(CONFIG["BATCH_SIZE"]):
293
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
294
  if frame_idx >= total_frames:
295
  break
296
-
297
  ret, frame = cap.read()
298
  if not ret:
299
  logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
300
  break
301
-
302
  frame = preprocess_frame(frame)
303
-
304
  for _ in range(frame_skip - 1):
305
  if not cap.grab():
306
  break
307
-
308
  batch_frames.append(frame)
309
  batch_indices.append(frame_idx)
310
  processed_frames += 1
@@ -320,7 +635,8 @@ def process_video(video_data, temp_dir):
320
  if device.type == "cuda":
321
  batch_frames_tensor = batch_frames_tensor.half()
322
 
323
- results = model(batch_frames_tensor, device=device, conf=0.1, verbose=False)
 
324
  except Exception as e:
325
  logger.error(f"Model inference failed: {e}")
326
  raise ValueError(f"Failed to process video frames with YOLO model: {str(e)}")
@@ -330,27 +646,32 @@ def process_video(video_data, temp_dir):
330
  torch.cuda.empty_cache()
331
 
332
  current_time = time.time()
333
- if current_time - last_yield_time > 0.5: # Update progress every 0.5s
334
  progress = (processed_frames / total_frames) * 100
335
  elapsed_time = current_time - start_time
336
  fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
337
  yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
338
  last_yield_time = current_time
339
 
 
 
 
 
 
340
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
341
  current_time = frame_idx / fps
342
-
343
  boxes = result.boxes
344
  track_inputs = []
345
-
346
  for box in boxes:
347
  cls = int(box.cls)
348
  conf = float(box.conf)
349
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
350
-
351
  if label is None:
352
  continue
353
-
354
  if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
355
  continue
356
 
@@ -363,7 +684,7 @@ def process_video(video_data, temp_dir):
363
 
364
  if not track_inputs:
365
  continue
366
-
367
  tracked_objects = tracker.update(
368
  np.array([t["bbox"] for t in track_inputs]),
369
  np.array([t["conf"] for t in track_inputs]),
@@ -376,32 +697,30 @@ def process_video(video_data, temp_dir):
376
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
377
  conf = obj['score']
378
  bbox = obj['bbox']
379
-
380
  if label is None:
381
  continue
382
-
383
- # More conservative worker ID assignment
384
  if tracker_id not in worker_id_mapping:
385
- # Check if this is likely the same worker as before
386
- if len(worker_id_mapping) > 0: # If we already have a worker
387
- existing_worker_id = next(iter(worker_id_mapping.values()))
388
- worker_id_mapping[tracker_id] = existing_worker_id
389
- else:
390
- worker_id_mapping[tracker_id] = worker_counter
391
- worker_counter += 1
392
-
393
  worker_id = worker_id_mapping[tracker_id]
394
-
395
  violation_key = (worker_id, label)
396
-
397
  if violation_key not in unique_violations:
398
  unique_violations[violation_key] = current_time
399
  violation_frames[violation_key] = frame_idx
 
 
 
400
 
401
  cap.release()
402
  processing_time = time.time() - start_time
403
  logger.info(f"Processing complete in {processing_time:.2f}s")
404
  logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
 
405
 
406
  violations = []
407
  for (worker_id, label), detection_time in unique_violations.items():
@@ -418,20 +737,14 @@ def process_video(video_data, temp_dir):
418
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
419
  return
420
 
421
- # Generate snapshots (only for the first worker)
422
  snapshots = []
423
  cap = cv2.VideoCapture(video_path)
424
- worker_ids = set(v["worker_id"] for v in violations)
425
-
426
- # Only capture snapshots for the first worker (assuming single worker)
427
- first_worker_id = min(worker_ids) if worker_ids else 1
428
- worker_violations = [v for v in violations if v["worker_id"] == first_worker_id][:5] # Limit to 5 violations
429
-
430
- for violation in worker_violations:
431
  frame_idx = violation["frame_idx"]
432
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
433
  ret, frame = cap.read()
434
  if not ret:
 
435
  continue
436
 
437
  frame = preprocess_frame(frame)
@@ -446,13 +759,13 @@ def process_video(video_data, temp_dir):
446
  for box in boxes:
447
  cls = int(box.cls)
448
  conf = float(box.conf)
449
- box_label = CONFIG["VIOLATION_LABELS"].get(cls, None)
450
- if box_label == violation["violation"]:
451
  violation["confidence"] = round(conf, 2)
452
  bbox = box.xywh.cpu().numpy()[0]
453
  detection = {
454
  "worker_id": violation["worker_id"],
455
- "violation": box_label,
456
  "confidence": violation["confidence"],
457
  "bounding_box": bbox,
458
  "timestamp": violation["timestamp"]
@@ -468,7 +781,7 @@ def process_video(video_data, temp_dir):
468
  (255, 255, 255),
469
  2
470
  )
471
- snapshot_filename = f"violation_{box_label}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
472
  snapshot_path = os.path.join(output_dir, snapshot_filename)
473
  cv2.imwrite(
474
  snapshot_path,
@@ -476,55 +789,59 @@ def process_video(video_data, temp_dir):
476
  [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
477
  )
478
  snapshots.append({
479
- "violation": box_label,
480
  "worker_id": violation["worker_id"],
481
  "timestamp": violation["timestamp"],
482
  "snapshot_path": snapshot_path,
483
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
484
  "confidence": violation["confidence"]
485
  })
 
486
  break
487
 
488
  cap.release()
489
 
490
  score = calculate_safety_score(violations)
491
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
492
-
493
  record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
494
 
495
- # Generate output
496
- violation_table = "## Safety Violation Report\n"
497
-
498
- # Worker summary
499
  worker_summary = {}
500
  for v in violations:
501
- if v["worker_id"] not in worker_summary:
502
- worker_summary[v["worker_id"]] = {"count": 0, "types": set()}
503
- worker_summary[v["worker_id"]]["count"] += 1
504
- worker_summary[v["worker_id"]]["types"].add(v["violation"])
 
 
 
 
505
 
 
506
  violation_table += "| Worker ID | Total Violations | Violation Types |\n"
507
  violation_table += "|-----------|------------------|-----------------|\n"
 
508
  for worker_id, info in worker_summary.items():
509
- types = ", ".join([CONFIG["DISPLAY_NAMES"].get(t, t) for t in info["types"]])
510
- violation_table += f"| {worker_id} | {info['count']} | {types} |\n"
511
-
512
- violation_table += "\n## Detailed Violations\n"
513
- violation_table += "| Violation | Time (s) | Confidence |\n"
514
- violation_table += "|-----------|----------|------------|\n"
515
-
516
- for v in sorted(violations, key=lambda x: x.get("timestamp", 0.0)):
517
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
 
518
  timestamp = v.get("timestamp", 0.0)
519
  confidence = v.get("confidence", 0.0)
520
- violation_table += f"| {display_name} | {timestamp:.2f} | {confidence:.2f} |\n"
521
 
522
  snapshots_text = ""
523
  for s in snapshots:
524
  display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
525
  worker_id = s.get("worker_id", "Unknown")
526
  timestamp = s.get("timestamp", 0.0)
527
- snapshots_text += f"### {display_name} at {timestamp:.2f}s\n\n"
528
  snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
529
 
530
  if not snapshots_text:
@@ -535,37 +852,36 @@ def process_video(video_data, temp_dir):
535
  f"Safety Score: {score}%",
536
  snapshots_text,
537
  f"Salesforce Record ID: {record_id}",
538
- final_pdf_url if final_pdf_url else pdf_url
539
  )
540
 
541
  except Exception as e:
542
  logger.error(f"Error processing video: {str(e)}", exc_info=True)
543
- yield f"Error: {str(e)}", "", "", "", ""
544
  finally:
545
  if video_path and os.path.exists(video_path):
546
  try:
547
  os.remove(video_path)
 
548
  except Exception as e:
549
  logger.error(f"Failed to clean up temporary video file {video_path}: {e}")
550
  if device.type == "cuda":
551
  torch.cuda.empty_cache()
552
 
553
- # [Rest of your code (gradio_interface function and interface setup) remains the same]
554
-
555
  def gradio_interface(video_file):
556
  temp_dir = None
557
  local_video_path = None
558
  try:
559
  if not video_file:
560
  return "No file uploaded.", "", "No file uploaded.", "", ""
561
-
562
  temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
563
  logger.info(f"Created temporary directory for video processing: {temp_dir}")
564
 
565
  with open(video_file, "rb") as f:
566
  video_data = f.read()
567
  logger.info(f"Read Gradio video file: {video_file}, size: {len(video_data)} bytes")
568
-
569
  if len(video_data) == 0:
570
  return "Uploaded video file is empty.", "", "", "", ""
571
 
@@ -580,7 +896,7 @@ def gradio_interface(video_file):
580
 
581
  for status, score, snapshots_text, record_id, details_url in process_video(video_data, temp_dir):
582
  yield status, score, snapshots_text, record_id, details_url
583
-
584
  except Exception as e:
585
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
586
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
@@ -591,7 +907,7 @@ def gradio_interface(video_file):
591
  logger.info(f"Cleaned up local temporary video file: {local_video_path}")
592
  except Exception as e:
593
  logger.error(f"Failed to clean up local temporary video file {local_video_path}: {e}")
594
-
595
  if temp_dir and os.path.exists(temp_dir):
596
  shutil.rmtree(temp_dir, ignore_errors=True)
597
  logger.info(f"Cleaned up temporary directory: {temp_dir}")
 
38
 
39
  FFMPEG_AVAILABLE = check_ffmpeg()
40
 
41
+ # ========================== # Optimized BYTETracker Implementation # ==========================
42
  class BYTETracker:
43
+ def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.6, frame_rate=30):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
+ self.match_thresh = match_thresh # Increased for stricter matching
47
  self.frame_rate = frame_rate
48
+ self.next_id = 1 # Start IDs from 1
49
  self.tracks = {}
50
+ self.worker_history = {}
51
  self.last_positions = {}
52
+ self.recently_removed = {}
53
 
54
  def update(self, dets, scores, cls):
55
  tracks = []
56
  current_time = time.time()
57
+
58
  # Prune stale tracks
59
+ stale_ids = []
60
+ for track_id, track_info in self.tracks.items():
61
+ if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
62
+ stale_ids.append(track_id)
63
+
64
+ for track_id in stale_ids:
65
+ self.recently_removed[track_id] = {
66
+ 'bbox': self.tracks[track_id]['bbox'],
67
+ 'last_seen': current_time,
68
+ 'last_position': self.last_positions.get(track_id, [0, 0])
69
+ }
70
+ del self.tracks[track_id]
71
+ if track_id in self.worker_history:
72
+ del self.worker_history[track_id]
73
+ if track_id in self.last_positions:
74
+ del self.last_positions[track_id]
75
+
76
+ # Clean up recently_removed tracks older than 0.5 seconds
77
+ to_remove = []
78
+ for track_id, info in self.recently_removed.items():
79
+ if current_time - info['last_seen'] > 0.5:
80
+ to_remove.append(track_id)
81
+ for track_id in to_remove:
82
+ del self.recently_removed[track_id]
83
+
84
+ # Precompute bounding box centers for efficiency
85
+ det_centers = [(det[0], det[1]) for det in dets]
86
 
87
+ for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
 
 
 
 
 
 
 
88
  if score < self.track_thresh:
89
  continue
90
+
91
  x, y, w, h = det
92
  matched = False
 
 
 
93
  best_iou = 0
94
+ best_track_id = None
95
+
96
+ # Try to match with active tracks
97
+ for track_id, track_info in self.tracks.items():
98
+ tx, ty, tw, th = track_info['bbox']
99
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
100
+
 
 
 
 
 
101
  if iou > self.match_thresh and iou > best_iou:
102
  best_iou = iou
103
+ best_track_id = track_id
104
+ matched = True
105
+
106
+ if matched:
107
+ self.tracks[best_track_id].update({
108
  'bbox': [x, y, w, h],
109
  'score': score,
110
  'cls': cl,
111
  'last_seen': current_time
112
  })
113
+ self.worker_history[best_track_id].append([x, y])
114
+ self.last_positions[best_track_id] = [x, y]
115
+
116
  tracks.append({
117
+ 'id': best_track_id,
118
  'bbox': [x, y, w, h],
119
  'score': score,
120
  'cls': cl
121
  })
122
  else:
123
+ # Try to re-identify with recently removed tracks
124
+ reidentified = False
125
+ min_distance = float('inf')
126
+ best_removed_id = None
127
+
128
+ for track_id, info in self.recently_removed.items():
129
+ distance = self._calculate_distance([x, y], info['last_position'])
130
+ if distance < CONFIG["MAX_WORKER_DISTANCE"] and distance < min_distance:
131
+ min_distance = distance
132
+ best_removed_id = track_id
133
+ reidentified = True
134
+
135
+ if reidentified:
136
+ self.tracks[best_removed_id] = {
137
+ 'bbox': [x, y, w, h],
138
+ 'score': score,
139
+ 'cls': cl,
140
+ 'last_seen': current_time
141
+ }
142
+ self.worker_history[best_removed_id] = self.worker_history.get(best_removed_id, []) + [[x, y]]
143
+ self.last_positions[best_removed_id] = [x, y]
144
+ tracks.append({
145
+ 'id': best_removed_id,
146
+ 'bbox': [x, y, w, h],
147
+ 'score': score,
148
+ 'cls': cl
149
+ })
150
+ del self.recently_removed[best_removed_id]
151
  else:
152
+ # Only create new ID if no existing worker is close
153
+ same_worker = False
154
+ for track_id, last_pos in self.last_positions.items():
155
+ if self._calculate_distance([x, y], last_pos) < CONFIG["MAX_WORKER_DISTANCE"]:
156
+ self.tracks[track_id] = {
157
+ 'bbox': [x, y, w, h],
158
+ 'score': score,
159
+ 'cls': cl,
160
+ 'last_seen': current_time
161
+ }
162
+ self.worker_history[track_id].append([x, y])
163
+ self.last_positions[track_id] = [x, y]
164
+ tracks.append({
165
+ 'id': track_id,
166
+ 'bbox': [x, y, w, h],
167
+ 'score': score,
168
+ 'cls': cl
169
+ })
170
+ same_worker = True
171
+ break
172
+
173
+ if not same_worker:
174
+ self.tracks[self.next_id] = {
175
+ 'bbox': [x, y, w, h],
176
+ 'score': score,
177
+ 'cls': cl,
178
+ 'last_seen': current_time
179
+ }
180
+ self.worker_history[self.next_id] = [[x, y]]
181
+ self.last_positions[self.next_id] = [x, y]
182
+ tracks.append({
183
+ 'id': self.next_id,
184
+ 'bbox': [x, y, w, h],
185
+ 'score': score,
186
+ 'cls': cl
187
+ })
188
+ self.next_id += 1
189
+
190
  return tracks
191
 
192
  def _calculate_iou(self, box1, box2):
 
201
  intersection_area = (x_right - x_left) * (y_bottom - y_top)
202
  box1_area = w1 * h1
203
  box2_area = w2 * h2
204
+ iou = intersection_area / (box1_area + box2_area - intersection_area)
205
+ return iou
206
+
207
+ def _calculate_distance(self, pos1, pos2):
208
+ x1, y1 = pos1
209
+ x2, y2 = pos2
210
+ return np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
 
 
 
 
 
 
 
 
 
 
211
 
212
  # ========================== # Optimized Configuration # ==========================
213
  CONFIG = {
 
252
  "VIOLATION_COOLDOWN": 30.0,
253
  "WORKER_TRACKING_DURATION": 5.0,
254
  "MAX_PROCESSING_TIME": 60,
255
+ "FRAME_SKIP": 1,
256
+ "BATCH_SIZE": 8, # Increased for better GPU utilization
257
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
258
+ "TRACK_BUFFER": 150,
259
  "TRACK_THRESH": 0.3,
260
+ "MATCH_THRESH": 0.6, # Increased for stricter matching
261
+ "SNAPSHOT_QUALITY": 95,
262
+ "MAX_WORKER_DISTANCE": 150,
263
+ "TARGET_RESOLUTION": (320, 320) # Reduced for faster processing
264
  }
265
 
266
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
277
  if not os.path.isfile(model_path):
278
  logger.info(f"Downloading fallback model: {model_path}")
279
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
280
+
281
  model = YOLO(model_path).to(device)
282
  if device.type == "cuda":
283
  model.model.half()
 
289
 
290
  model = load_model()
291
 
292
+ # ========================== # Helper Functions # ==========================
293
+ def preprocess_frame(frame):
294
+ target_res = CONFIG["TARGET_RESOLUTION"]
295
+ frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_AREA) # Faster interpolation
296
+ frame = cv2.convertScaleAbs(frame, alpha=1.1, beta=10) # Reduced contrast adjustment
297
+ return frame
298
+
299
+ def draw_detections(frame, detections):
300
+ result_frame = frame.copy()
301
+
302
+ for det in detections:
303
+ label = det.get("violation", "Unknown")
304
+ confidence = det.get("confidence", 0.0)
305
+ x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
306
+ worker_id = det.get("worker_id", "Unknown")
307
+
308
+ x1 = int(x - w/2)
309
+ y1 = int(y - h/2)
310
+ x2 = int(x + w/2)
311
+ y2 = int(y + h/2)
312
+
313
+ color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
314
+
315
+ cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 2)
316
+
317
+ display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
318
+ text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
319
+ cv2.rectangle(result_frame, (x1, y1-text_size[1]-5), (x1+text_size[0]+5, y1), (0, 0, 0), -1)
320
+ cv2.putText(result_frame, display_text, (x1+3, y1-3), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
321
+
322
+ conf_text = f"Conf: {confidence:.2f}"
323
+ cv2.putText(result_frame, conf_text, (x1+3, y2+15), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
324
+
325
+ return result_frame
326
+
327
+ def calculate_safety_score(violations):
328
+ penalties = {
329
+ "no_helmet": 25,
330
+ "no_harness": 30,
331
+ "unsafe_posture": 20,
332
+ "unsafe_zone": 35,
333
+ "improper_tool_use": 25
334
+ }
335
+
336
+ worker_violations = {}
337
+ for v in violations:
338
+ worker_id = v.get("worker_id", "Unknown")
339
+ violation_type = v.get("violation", "Unknown")
340
+
341
+ if worker_id not in worker_violations:
342
+ worker_violations[worker_id] = set()
343
+ worker_violations[worker_id].add(violation_type)
344
+
345
+ total_penalty = 0
346
+ for worker_violations_set in worker_violations.values():
347
+ worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
348
+ total_penalty += worker_penalty
349
+
350
+ score = max(0, 100 - total_penalty)
351
+ return score
352
+
353
+ def generate_violation_pdf(violations, score, output_dir):
354
+ try:
355
+ pdf_filename = f"violations_{int(time.time())}.pdf"
356
+ pdf_path = os.path.join(output_dir, pdf_filename)
357
+ pdf_file = BytesIO()
358
+ c = canvas.Canvas(pdf_file, pagesize=letter)
359
+
360
+ c.setFont("Helvetica-Bold", 16)
361
+ c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
362
+
363
+ c.setFont("Helvetica", 12)
364
+ c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
365
+ c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
366
+
367
+ c.setFont("Helvetica-Bold", 14)
368
+ c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
369
+
370
+ y_position = 8.2 * inch
371
+ c.setFont("Helvetica-Bold", 12)
372
+ c.drawString(1 * inch, y_position, "Summary:")
373
+ y_position -= 0.3 * inch
374
+
375
+ worker_violations = {}
376
+ for v in violations:
377
+ worker_id = v.get("worker_id", "Unknown")
378
+ if worker_id not in worker_violations:
379
+ worker_violations[worker_id] = []
380
+ worker_violations[worker_id].append(v)
381
+
382
+ c.setFont("Helvetica", 10)
383
+ summary_data = {
384
+ "Total Workers with Violations": len(worker_violations),
385
+ "Total Violations Found": len(violations),
386
+ "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
387
+ }
388
+
389
+ for key, value in summary_data.items():
390
+ c.drawString(1 * inch, y_position, f"{key}: {value}")
391
+ y_position -= 0.25 * inch
392
+
393
+ y_position -= 0.5 * inch
394
+ c.setFont("Helvetica-Bold", 12)
395
+ c.drawString(1 * inch, y_position, "Violations by Worker:")
396
+ y_position -= 0.3 * inch
397
+
398
+ c.setFont("Helvetica", 10)
399
+ for worker_id, worker_vios in worker_violations.items():
400
+ c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
401
+ y_position -= 0.2 * inch
402
+
403
+ for v in worker_vios:
404
+ display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
405
+ time_str = f"{v.get('timestamp', 0.0):.2f}s"
406
+ conf_str = f"{v.get('confidence', 0.0):.2f}"
407
+
408
+ violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
409
+ c.drawString(1.2 * inch, y_position, violation_text)
410
+ y_position -= 0.2 * inch
411
+
412
+ if y_position < 1 * inch:
413
+ c.showPage()
414
+ c.setFont("Helvetica", 10)
415
+ y_position = 10 * inch
416
+
417
+ c.save()
418
+ pdf_file.seek(0)
419
+
420
+ with open(pdf_path, "wb") as f:
421
+ f.write(pdf_file.getvalue())
422
+
423
+ public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
424
+ logger.info(f"PDF generated: {public_url}")
425
+ return pdf_path, public_url, pdf_file
426
+ except Exception as e:
427
+ logger.error(f"Error generating PDF: {e}")
428
+ return "", "", None
429
+
430
+ @retry(stop_max_attempt_number=3, wait_fixed=2000)
431
+ def connect_to_salesforce():
432
+ try:
433
+ sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
434
+ logger.info("Connected to Salesforce")
435
+ sf.describe()
436
+ return sf
437
+ except Exception as e:
438
+ logger.error(f"Salesforce connection failed: {e}")
439
+ raise
440
+
441
+ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
442
+ try:
443
+ if not pdf_file:
444
+ logger.error("No PDF file provided for upload")
445
+ return ""
446
+
447
+ encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
448
+ content_version_data = {
449
+ "Title": f"Safety_Violation_Report_{int(time.time())}",
450
+ "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
451
+ "VersionData": encoded_pdf,
452
+ "FirstPublishLocationId": report_id
453
+ }
454
+ content_version = sf.ContentVersion.create(content_version_data)
455
+ result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
456
+
457
+ if not result['records']:
458
+ logger.error("Failed to retrieve ContentVersion")
459
+ return ""
460
+
461
+ file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
462
+ logger.info(f"PDF uploaded to Salesforce: {file_url}")
463
+ return file_url
464
+ except Exception as e:
465
+ logger.error(f"Error uploading PDF to Salesforce: {e}")
466
+ return ""
467
+
468
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Persist one analysis run to Salesforce.

    Creates a Safety_Video_Report__c record summarising *violations* and
    *score*; if that custom object is unavailable, falls back to a plain
    Account record so the run still leaves a trace. When *pdf_file* is
    provided, the PDF is attached via upload_pdf_to_salesforce and the
    record is updated with the resulting download URL.

    Returns:
        (record_id, pdf_url) on success, or
        ("N/A", "Salesforce integration failed.") when the whole push fails.
    """
    try:
        sf = connect_to_salesforce()

        # One human-readable line per violation for the long-text field.
        summary_lines = []
        for entry in violations:
            display_name = CONFIG['DISPLAY_NAMES'].get(entry.get('violation', 'Unknown'), 'Unknown')
            worker_id = entry.get('worker_id', 'Unknown')
            timestamp = entry.get('timestamp', 0.0)
            confidence = entry.get('confidence', 0.0)
            summary_lines.append(
                f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
            )
        violations_text = "".join(summary_lines) if summary_lines else "No violations detected."

        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""

        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url
        }

        logger.info(f"Creating Salesforce record with data: {record_data}")

        # Prefer the custom report object; fall back to a bare Account record
        # so the run remains traceable if the custom object is missing.
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")

        record_id = record["id"]

        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                # Mirror the create-time fallback: try the custom object
                # first, then stash the URL on the Account's Description.
                try:
                    sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                    logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                except Exception as e:
                    logger.error(f"Failed to update Safety_Video_Report__c: {e}")
                    sf.Account.update(record_id, {"Description": uploaded_url})
                    logger.info(f"Updated Account record {record_id} with PDF URL")
                pdf_url = uploaded_url

        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}")
        return "N/A", "Salesforce integration failed."
522
+
523
@tenacity.retry(
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_fixed(1),
    retry=tenacity.retry_if_exception_type((IOError, OSError)),
    before_sleep=lambda retry_state: logger.info(f"Retrying file access (attempt {retry_state.attempt_number}/3)...")
)
def verify_and_open_video(video_path):
    """Validate that *video_path* is a readable, non-empty file and return an
    opened cv2.VideoCapture for it.

    Retries up to 3 times (1s apart) on IOError/OSError via tenacity, which
    covers transient filesystem races on freshly written temp files.

    Raises:
        FileNotFoundError: the file does not exist.
        ValueError: the file is empty or OpenCV cannot decode it.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Temporary video file not found: {video_path}")

    if os.path.getsize(video_path) == 0:
        raise ValueError(f"Temporary video file is empty: {video_path}")

    # Touch the file so a latent IOError/OSError surfaces here, where the
    # retry decorator can still catch it.
    with open(video_path, "rb") as handle:
        handle.read(1)

    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise ValueError("Could not open video file. Ensure the video format is supported (e.g., MP4) and FFmpeg is installed.")

    return capture
545
 
 
546
# process_video(video_data, temp_dir) — generator pipeline: writes the uploaded
# bytes to a temp .mp4, opens it with OpenCV (retry via verify_and_open_video),
# runs batched YOLO inference with BYTETracker ID association, snapshots each
# unique (worker, violation) pair, builds markdown summaries, generates a PDF
# and pushes a report record to Salesforce. Progress and the final result are
# yielded as 5-tuples (status, score_text, snapshots_markdown, record_text,
# details_url).
# NOTE(review): this region is a diff-style rendering (interleaved original
# line numbers and '+' markers) and several hunks are elided; comments below
# describe only the visible code.
 def process_video(video_data, temp_dir):
547
 video_path = None
548
 output_dir = os.path.join(temp_dir, "output")
549
 os.makedirs(output_dir, exist_ok=True)
550
# Points Ultralytics' config dir at the per-run temp dir — presumably to avoid
# writes outside the sandbox; TODO confirm (inferred from the env var name).
+ os.environ['YOLO_CONFIG_DIR'] = temp_dir
551
+
552
 try:
553
+ if not video_data:
554
+ raise ValueError("Empty video data provided.")
555
+
556
+ logger.info(f"Received video data size: {len(video_data)} bytes")
557
+ if len(video_data) == 0:
558
+ raise ValueError("Video data is empty.")
559
+
560
# Persist the uploaded bytes so OpenCV can open them from a real file path.
 with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
561
 temp_file.write(video_data)
562
+ temp_file.flush()
563
 video_path = temp_file.name
564
+ logger.info(f"Video saved to temporary file: {video_path}")
565

566
+ if not os.path.exists(video_path):
567
+ raise FileNotFoundError(f"Temporary video file not found: {video_path}")
568
+ file_size = os.path.getsize(video_path)
569
+ if file_size == 0:
570
+ raise ValueError(f"Temporary video file is empty: {video_path}")
571
+ logger.info(f"Temporary video file size: {file_size} bytes")
572
+
573
# verify_and_open_video retries transient IOError/OSError before giving up.
+ cap = verify_and_open_video(video_path)
574
+ logger.info(f"Successfully opened video file: {video_path}")
575


576
 total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
577
# Some containers report FPS as 0; fall back to 30 to avoid division by zero.
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30
578
 duration = total_frames / fps
579
 width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
580
 height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
581
 logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
582

583
 if total_frames <= 0:
584
+ raise ValueError("Video has no frames.")
585

586
 tracker = BYTETracker(
587
 track_thresh=CONFIG["TRACK_THRESH"],

593
 worker_id_mapping = {}
594
 unique_violations = {}
595
 violation_frames = {}
596
+ worker_violation_count = {}
597
 start_time = time.time()
598
 frame_skip = CONFIG["FRAME_SKIP"]
599
 processed_frames = 0

603
# Main loop: frames are read in batches of CONFIG["BATCH_SIZE"], skipping
# frame_skip-1 frames between reads via cap.grab().
 while processed_frames < total_frames:
604
 batch_frames = []
605
 batch_indices = []
606
+
607
 for _ in range(CONFIG["BATCH_SIZE"]):
608
 frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
609
 if frame_idx >= total_frames:
610
 break
611
+
612
 ret, frame = cap.read()
613
 if not ret:
614
 logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
615
 break
616
+
617
 frame = preprocess_frame(frame)
618
+
619
# grab() advances without full decode — cheap way to drop skipped frames.
 for _ in range(frame_skip - 1):
620
 if not cap.grab():
621
 break
622
+
623
 batch_frames.append(frame)
624
 batch_indices.append(frame_idx)
625
 processed_frames += 1

635
 if device.type == "cuda":
636
 batch_frames_tensor = batch_frames_tensor.half()
637

638
+ with torch.no_grad(): # Disable gradient computation
639
+ results = model(batch_frames_tensor, device=device, conf=0.1, verbose=False)
640
 except Exception as e:
641
 logger.error(f"Model inference failed: {e}")
642
 raise ValueError(f"Failed to process video frames with YOLO model: {str(e)}")

646
 torch.cuda.empty_cache()
647

648
 current_time = time.time()
649
# Throttle progress updates to at most one yield every 100 ms.
+ if current_time - last_yield_time > 0.1:
650
 progress = (processed_frames / total_frames) * 100
651
 elapsed_time = current_time - start_time
652
 fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
653
 yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
654
 last_yield_time = current_time
655

656
+ # Early stopping if enough violations are detected
657
+ if len(unique_violations) >= 10 and processed_frames > total_frames * 0.5:
658
+ logger.info("Early stopping: Sufficient violations detected.")
659
+ break
660
+
661
 for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
662
 current_time = frame_idx / fps
663
+
664
 boxes = result.boxes
665
 track_inputs = []
666
+
667
# Filter raw YOLO boxes by per-label confidence thresholds before tracking.
 for box in boxes:
668
 cls = int(box.cls)
669
 conf = float(box.conf)
670
 label = CONFIG["VIOLATION_LABELS"].get(cls, None)
671
+
672
 if label is None:
673
 continue
674
+
675
 if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
676
 continue
677


684

685
 if not track_inputs:
686
 continue
687
+
688
 tracked_objects = tracker.update(
689
 np.array([t["bbox"] for t in track_inputs]),
690
 np.array([t["conf"] for t in track_inputs]),

697
 label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
698
 conf = obj['score']
699
 bbox = obj['bbox']
700
+
701
 if label is None:
702
 continue
703
+

704
# Map internal tracker IDs to stable sequential worker numbers for reporting.
 if tracker_id not in worker_id_mapping:
705
+ worker_id_mapping[tracker_id] = worker_counter
706
+ worker_counter += 1
707
+




708
 worker_id = worker_id_mapping[tracker_id]
709
+
710
 violation_key = (worker_id, label)
711
+
712
# Only the first occurrence of each (worker, violation-type) pair is recorded.
 if violation_key not in unique_violations:
713
 unique_violations[violation_key] = current_time
714
 violation_frames[violation_key] = frame_idx
715
+ if worker_id not in worker_violation_count:
716
+ worker_violation_count[worker_id] = 0
717
+ worker_violation_count[worker_id] += 1
718

719
 cap.release()
720
 processing_time = time.time() - start_time
721
 logger.info(f"Processing complete in {processing_time:.2f}s")
722
 logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
723
+ logger.info(f"Violations per worker: {worker_violation_count}")
724

725
 violations = []
726
 for (worker_id, label), detection_time in unique_violations.items():

737
 yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
738
 return
739


740
# Snapshot pass: re-open the video and seek back to each violation's frame.
 snapshots = []
741
 cap = cv2.VideoCapture(video_path)
742
+ for violation in violations:






743
 frame_idx = violation["frame_idx"]
744
 cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
745
 ret, frame = cap.read()
746
 if not ret:
747
+ logger.warning(f"Failed to read frame {frame_idx} for snapshot.")
748
 continue
749

750
 frame = preprocess_frame(frame)

759
 for box in boxes:
760
 cls = int(box.cls)
761
 conf = float(box.conf)
762
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
763
+ if label == violation["violation"]:
764
 violation["confidence"] = round(conf, 2)
765
 bbox = box.xywh.cpu().numpy()[0]
766
 detection = {
767
 "worker_id": violation["worker_id"],
768
+ "violation": label,
769
 "confidence": violation["confidence"],
770
 "bounding_box": bbox,
771
 "timestamp": violation["timestamp"]

781
 (255, 255, 255),
782
 2
783
 )
784
+ snapshot_filename = f"violation_{label}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
785
 snapshot_path = os.path.join(output_dir, snapshot_filename)
786
 cv2.imwrite(
787
 snapshot_path,

789
 [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
790
 )
791
 snapshots.append({
792
+ "violation": label,
793
 "worker_id": violation["worker_id"],
794
 "timestamp": violation["timestamp"],
795
 "snapshot_path": snapshot_path,
796
 "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
797
 "confidence": violation["confidence"]
798
 })
799
# NOTE(review): violation['factor'] is never assigned anywhere in the visible
# code — every other read of this dict uses 'timestamp' — so this log line
# would raise KeyError whenever a snapshot is captured. Almost certainly a
# typo for violation['timestamp']; confirm and fix.
+ logger.info(f"Captured snapshot for {label} violation by worker {violation['worker_id']} at {violation['factor']:.2f}s")
800
 break
801

802
 cap.release()
803

804
# Reporting: compute score, render the PDF, then push a record to Salesforce.
 score = calculate_safety_score(violations)
805
 pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
806
+
807
 record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
808




809
# Aggregate violations per worker for the markdown summary table.
 worker_summary = {}
810
 for v in violations:
811
+ worker_id = v["worker_id"]
812
+ if worker_id not in worker_summary:
813
+ worker_summary[worker_id] = {
814
+ "count": 0,
815
+ "violations": set()
816
+ }
817
+ worker_summary[worker_id]["count"] += 1
818
+ worker_summary[worker_id]["violations"].add(v["violation"])
819

820
+ violation_table = "## Worker Safety Violation Summary\n\n"
821
 violation_table += "| Worker ID | Total Violations | Violation Types |\n"
822
 violation_table += "|-----------|------------------|-----------------|\n"
823
+
824
 for worker_id, info in worker_summary.items():
825
+ violation_types = ", ".join([CONFIG["DISPLAY_NAMES"].get(v, v) for v in info["violations"]])
826
+ violation_table += f"| {worker_id} | {info['count']} | {violation_types} |\n"
827
+
828
+ violation_table += "\n## Detailed Violation Log\n\n"
829
+ violation_table += "| Violation | Worker ID | Time (s) | Confidence |\n"
830
+ violation_table += "|-----------|-----------|----------|------------|\n"
831
+
832
+ for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
833
 display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
834
+ worker_id = v.get("worker_id", "Unknown")
835
 timestamp = v.get("timestamp", 0.0)
836
 confidence = v.get("confidence", 0.0)
837
+ violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
838

839
 snapshots_text = ""
840
 for s in snapshots:
841
 display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
842
 worker_id = s.get("worker_id", "Unknown")
843
 timestamp = s.get("timestamp", 0.0)
844
+ snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
845
 snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
846

847
 if not snapshots_text:

852
 f"Safety Score: {score}%",
853
 snapshots_text,
854
 f"Salesforce Record ID: {record_id}",
855
+ final_pdf_url
856
 )
857

858
 except Exception as e:
859
 logger.error(f"Error processing video: {str(e)}", exc_info=True)
860
+ yield f"Error processing video: {str(e)}", "", "", "", ""
861
# Cleanup always runs: remove the temp video and release the CUDA cache.
 finally:
862
 if video_path and os.path.exists(video_path):
863
 try:
864
 os.remove(video_path)
865
+ logger.info(f"Cleaned up temporary video file: {video_path}")
866
 except Exception as e:
867
 logger.error(f"Failed to clean up temporary video file {video_path}: {e}")
868
 if device.type == "cuda":
869
 torch.cuda.empty_cache()
870
 
 
 
871
  def gradio_interface(video_file):
872
  temp_dir = None
873
  local_video_path = None
874
  try:
875
  if not video_file:
876
  return "No file uploaded.", "", "No file uploaded.", "", ""
877
+
878
  temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
879
  logger.info(f"Created temporary directory for video processing: {temp_dir}")
880
 
881
  with open(video_file, "rb") as f:
882
  video_data = f.read()
883
  logger.info(f"Read Gradio video file: {video_file}, size: {len(video_data)} bytes")
884
+
885
  if len(video_data) == 0:
886
  return "Uploaded video file is empty.", "", "", "", ""
887
 
 
896
 
897
  for status, score, snapshots_text, record_id, details_url in process_video(video_data, temp_dir):
898
  yield status, score, snapshots_text, record_id, details_url
899
+
900
  except Exception as e:
901
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
902
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
 
907
  logger.info(f"Cleaned up local temporary video file: {local_video_path}")
908
  except Exception as e:
909
  logger.error(f"Failed to clean up local temporary video file {local_video_path}: {e}")
910
+
911
  if temp_dir and os.path.exists(temp_dir):
912
  shutil.rmtree(temp_dir, ignore_errors=True)
913
  logger.info(f"Cleaned up temporary directory: {temp_dir}")