PrashanthB461 committed on
Commit
01caa2c
·
verified ·
1 Parent(s): 18c5b9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -97
app.py CHANGED
@@ -224,7 +224,7 @@ class BYTETracker:
224
 
225
  # ========================== # Optimized Configuration # ==========================
226
  CONFIG = {
227
- "MODEL_NAME": "facebook/detr-resnet-50", # Fine-tune with your dataset, e.g., "your-username/detr-resnet-50-finetuned-safety"
228
  "VIOLATION_LABELS": {
229
  "no_helmet": "No Helmet",
230
  "no_harness": "No Harness",
@@ -263,25 +263,26 @@ CONFIG = {
263
  "MIN_VIOLATION_FRAMES": 2,
264
  "VIOLATION_COOLDOWN": 30.0,
265
  "WORKER_TRACKING_DURATION": 10.0,
266
- "MAX_PROCESSING_TIME": 60,
267
- "FRAME_SKIP": 2,
268
- "BATCH_SIZE": 8,
269
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
270
  "TRACK_BUFFER": 150,
271
  "TRACK_THRESH": 0.3,
272
  "MATCH_THRESH": 0.5,
273
  "SNAPSHOT_QUALITY": 95,
274
  "MAX_WORKER_DISTANCE": 150,
275
- "TARGET_RESOLUTION": (384, 384),
276
  "HELMET_VALIDATION_FRAMES": 3
277
  }
278
 
279
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
280
  logger.info(f"Using device: {device}")
 
 
281
 
282
  def load_model():
283
  try:
284
- # Check for timm dependency
285
  import timm
286
  logger.info("timm library is available.")
287
  except ImportError as e:
@@ -312,7 +313,6 @@ def preprocess_frame(frame):
312
  return frame
313
 
314
  def is_unsafe_posture(box, frame_shape):
315
- """Placeholder for unsafe posture detection. Replace with pose estimation (e.g., MediaPipe)."""
316
  x1, y1, x2, y2 = box
317
  height = y2 - y1
318
  width = x2 - x1
@@ -320,14 +320,12 @@ def is_unsafe_posture(box, frame_shape):
320
  return aspect_ratio > 2.0
321
 
322
def is_improper_tool_use(person_box, tool_box):
    """Heuristic check for improper tool use based on person-to-tool distance.

    Both boxes are (x1, y1, x2, y2) corner coordinates. The tool is flagged
    as improperly used when the Euclidean distance between the two box
    centers exceeds 100 pixels.
    """
    def _center(box):
        # Midpoint of an (x1, y1, x2, y2) box.
        x1, y1, x2, y2 = box
        return ((x1 + x2) / 2, (y1 + y2) / 2)

    return distance.euclidean(_center(person_box), _center(tool_box)) > 100
328
 
329
  def is_unsafe_zone(person_box, frame_shape):
330
- """Check if person is in restricted area (e.g., top-left quadrant)."""
331
  px, py, pw, ph = person_box
332
  frame_h, frame_w = frame_shape
333
  person_center = (px + pw / 2, py + ph / 2)
@@ -508,7 +506,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
508
  sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL": uploaded_url})
509
  logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
510
  except Exception as e:
511
- logger.error(f"Failed to update Safety_Video_Report__c: {e}")
512
  sf.Account.update(record_id, {"Description": uploaded_url})
513
  logger.info(f"Updated account record {record_id} with PDF URL")
514
  pdf_url = uploaded_url
@@ -617,6 +615,7 @@ def process_video(video_data, temp_dir):
617
  unique_violations = {}
618
  violation_frames = {}
619
  helmet_detections = {}
 
620
  start_time = time.time()
621
  frame_skip = CONFIG["FRAME_SKIP"]
622
  processed_frames = 0
@@ -634,7 +633,7 @@ def process_video(video_data, temp_dir):
634
  break
635
  ret, frame = cap.read()
636
  if not ret:
637
- logger.warning(f"Failed to read frame {frame_idx}. Skipping...")
638
  break
639
  original_frame = frame.copy()
640
  frame = preprocess_frame(frame)
@@ -644,12 +643,18 @@ def process_video(video_data, temp_dir):
644
  batch_frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
645
  batch_indices.append(frame_idx)
646
  batch_originals.append(original_frame)
647
- processed_frames += frame_skip + 1
648
 
649
  if not batch_frames:
650
  logger.info("No more frames to process.")
651
  break
652
 
 
 
 
 
 
 
653
  try:
654
  inputs = processor(images=batch_frames, return_tensors="pt").to(device)
655
  if device.type == "cuda":
@@ -671,7 +676,7 @@ def process_video(video_data, temp_dir):
671
  progress = (processed_frames / total_frames) * 100
672
  elapsed_time = current_time - start_time
673
  fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
674
- yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", "", ""
675
  last_yield_time = current_time
676
 
677
  for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
@@ -680,6 +685,8 @@ def process_video(video_data, temp_dir):
680
  person_boxes = []
681
  tool_boxes = []
682
 
 
 
683
  for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
684
  label_name = model.config.id2label[label.item()]
685
  conf = float(score)
@@ -689,24 +696,27 @@ def process_video(video_data, temp_dir):
689
  bbox_xywh = [x + w/2, y + h/2, w, h]
690
 
691
  if label_name in ["no_helmet", "no_harness"] and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label_name, 0.25):
692
- if label_name == "no_helmet" and not validate_helmet(original_frame, bbox_xywh, conf):
693
- logger.info(f"Frame {frame_idx}: Height false positive violation filtered out at {conf:.2f} confidence")
694
  continue
695
  track_inputs.append({"bbox": bbox_xywh, "conf": conf, "cls": label_name})
 
696
  elif label_name == "person":
697
  person_boxes.append(bbox_xywh)
698
- elif label_name in ["hammer", "wrench"]: # Example tools; update with your dataset
699
  tool_boxes.append(bbox_xywh)
700
 
701
- # Handle Unsafe violations
702
  for pbox in person_boxes:
703
  if is_unsafe_posture(pbox, original_frame.shape[:2]):
704
  track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_posture"})
 
705
  if is_unsafe_zone(pbox, original_frame.shape[:2]):
706
  track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_zone"})
 
707
  for tbox in tool_boxes:
708
  if is_improper_tool_use(pbox, tbox):
709
  track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "improper_tool_use"})
 
710
 
711
  if not track_inputs:
712
  continue
@@ -722,6 +732,7 @@ def process_video(video_data, temp_dir):
722
  tracker_id = obj['id']
723
  label = obj['cls']
724
  conf = obj['score']
 
725
 
726
  if label not in CONFIG["VIOLATION_LABELS"]:
727
  continue
@@ -761,100 +772,104 @@ def process_video(video_data, temp_dir):
761
 
762
  violations = []
763
  for (worker_id, label), detection_time in unique_violations.items():
 
 
764
  violations.append({
765
  "worker_id": worker_id,
766
  "violation": label,
767
  "timestamp": detection_time,
768
- "confidence": 0.0,
769
  "frame_idx": violation_frames[(worker_id, label)]
770
  })
771
 
772
  if not violations:
773
  logger.info("No violations detected after processing")
774
- yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A", ""
775
  return
776
 
 
 
 
 
 
 
 
 
 
 
 
777
  snapshots = []
778
  cap = cv2.VideoCapture(video_path)
779
  for violation in violations:
780
- frame_idx = violation["frame_idx"]
781
- cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
782
- ret, frame = cap.read()
783
- if not ret:
784
- logger.warning(f"Failed to read frame {frame_idx} for snapshot.")
785
- continue
 
786
 
787
- frame = preprocess_frame(frame)
788
- frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
789
- inputs = processor(images=frame_pil, return_tensors="pt").to(device)
790
- if device.type == "cuda":
791
- inputs = {k: v.half() for k, v in inputs.items()}
792
- with torch.no_grad():
793
- outputs = model(**inputs)
794
- target_sizes = torch.tensor([frame_pil.size[::-1]]).to(device)
795
- result = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.1)[0]
796
-
797
- for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
798
- label_name = model.config.id2label[label.item()]
799
- conf = float(score)
800
- bbox = box.cpu().numpy()
801
- x, y, x2, y2 = bbox
802
- w, h = x2 - x, y2 - y
803
- bbox_xywh = [x + w/2, y + h/2, w, h]
804
- if label_name == violation["violation"]:
805
- violation["confidence"] = round(conf, 2)
806
- detection = {
807
- "worker_id": violation["worker_id"],
808
- "violation": label_name,
809
- "confidence": violation["confidence"],
810
- "bounding_box": bbox_xywh,
811
- "timestamp": violation["timestamp"]
812
- }
813
- snapshot_frame = frame.copy()
814
- snapshot_frame = draw_detections(snapshot_frame, [detection])
815
- cv2.putText(
816
- snapshot_frame,
817
- f"Time: {violation['timestamp']:.2f}s",
818
- (10, 30),
819
- cv2.FONT_HERSHEY_SIMPLEX,
820
- 0.7,
821
- (255, 255, 255),
822
- 2
823
- )
824
- snapshot_filename = f"violation_{label_name}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
825
- snapshot_path = os.path.join(output_dir, snapshot_filename)
826
- cv2.imwrite(
827
- snapshot_path,
828
- snapshot_frame,
829
- [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
830
- )
831
- snapshots.append({
832
- "violation": label_name,
833
- "worker_id": violation["worker_id"],
834
- "timestamp": violation["timestamp"],
835
- "snapshot_path": snapshot_path,
836
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
837
- "confidence": violation["confidence"]
838
- })
839
- logger.info(f"Captured snapshot for {label_name} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
840
- break
841
 
842
  cap.release()
843
 
844
  score = calculate_safety_score(violations)
845
- pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
846
-
847
- record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
 
 
 
 
848
 
849
- violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
850
- violation_table += "|-----------|-----------|----------|------------|\n"
851
-
852
- for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
853
- display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
854
- worker_id = v.get("worker_id", "Unknown")
855
- timestamp = v.get("timestamp", 0.0)
856
- confidence = v.get("confidence", 0.0)
857
- violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
858
 
859
  snapshots_text = ""
860
  for s in snapshots:
@@ -873,12 +888,12 @@ def process_video(video_data, temp_dir):
873
  snapshots_text,
874
  f"Salesforce Record ID: {record_id}",
875
  final_pdf_url,
876
- ""
877
  )
878
 
879
  except Exception as e:
880
  logger.error(f"Error processing video: {str(e)}", exc_info=True)
881
- yield f"Error processing video: {str(e)}", "", "", "", "", ""
882
  finally:
883
  if video_path and os.path.exists(video_path):
884
  try:
@@ -915,12 +930,12 @@ def gradio_interface(video_file):
915
  if not FFMPEG_AVAILABLE:
916
  return "FFmpeg is not available in the environment. Please install FFmpeg to process videos.", "", "", "", "", ""
917
 
918
- for status, score, snapshots_text, record_id, details_url, _ in process_video(video_data, temp_dir):
919
- yield status, score, snapshots_text, record_id, details_url, ""
920
 
921
  except Exception as e:
922
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
923
- yield f"Error: {str(e)}", "", "Error in processing.", "", "", ""
924
  finally:
925
  if local_video_path and os.path.exists(local_video_path):
926
  try:
@@ -945,7 +960,7 @@ interface = gr.Interface(
945
  gr.Markdown(label="Snapshots"),
946
  gr.Textbox(label="Salesforce Record ID"),
947
  gr.Textbox(label="Violation Details URL"),
948
- gr.Textbox(label="Error Log", visible=False)
949
  ],
950
  title="Worksite Safety Violation Analyzer",
951
  description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
 
224
 
225
  # ========================== # Optimized Configuration # ==========================
226
  CONFIG = {
227
+ "MODEL_NAME": "facebook/detr-resnet-50",
228
  "VIOLATION_LABELS": {
229
  "no_helmet": "No Helmet",
230
  "no_harness": "No Harness",
 
263
  "MIN_VIOLATION_FRAMES": 2,
264
  "VIOLATION_COOLDOWN": 30.0,
265
  "WORKER_TRACKING_DURATION": 10.0,
266
+ "MAX_PROCESSING_TIME": 60, # Reduced for early termination
267
+ "FRAME_SKIP": 4, # Increased to reduce frames processed
268
+ "BATCH_SIZE": 4, # Reduced for CPU efficiency
269
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
270
  "TRACK_BUFFER": 150,
271
  "TRACK_THRESH": 0.3,
272
  "MATCH_THRESH": 0.5,
273
  "SNAPSHOT_QUALITY": 95,
274
  "MAX_WORKER_DISTANCE": 150,
275
+ "TARGET_RESOLUTION": (320, 320), # Reduced for faster inference
276
  "HELMET_VALIDATION_FRAMES": 3
277
  }
278
 
279
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
280
  logger.info(f"Using device: {device}")
281
+ if device.type == "cpu":
282
+ logger.warning("Running on CPU, which may lead to slower processing. Consider using a GPU for better performance.")
283
 
284
  def load_model():
285
  try:
 
286
  import timm
287
  logger.info("timm library is available.")
288
  except ImportError as e:
 
313
  return frame
314
 
315
  def is_unsafe_posture(box, frame_shape):
 
316
  x1, y1, x2, y2 = box
317
  height = y2 - y1
318
  width = x2 - x1
 
320
  return aspect_ratio > 2.0
321
 
322
def is_improper_tool_use(person_box, tool_box):
    """Return True when the tool's box center lies more than 100 px from
    the person's box center (boxes given as x1, y1, x2, y2)."""
    person_cx = (person_box[0] + person_box[2]) / 2
    person_cy = (person_box[1] + person_box[3]) / 2
    tool_cx = (tool_box[0] + tool_box[2]) / 2
    tool_cy = (tool_box[1] + tool_box[3]) / 2
    # Center-to-center separation in pixels; threshold mirrors original logic.
    separation = distance.euclidean((person_cx, person_cy), (tool_cx, tool_cy))
    return separation > 100
327
 
328
  def is_unsafe_zone(person_box, frame_shape):
 
329
  px, py, pw, ph = person_box
330
  frame_h, frame_w = frame_shape
331
  person_center = (px + pw / 2, py + ph / 2)
 
506
  sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL": uploaded_url})
507
  logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
508
  except Exception as e:
509
+ logger.error(f"Failed to update Safety_Violation_Report__c: {e}")
510
  sf.Account.update(record_id, {"Description": uploaded_url})
511
  logger.info(f"Updated account record {record_id} with PDF URL")
512
  pdf_url = uploaded_url
 
615
  unique_violations = {}
616
  violation_frames = {}
617
  helmet_detections = {}
618
+ frame_detections = {} # Store detections for snapshot reuse
619
  start_time = time.time()
620
  frame_skip = CONFIG["FRAME_SKIP"]
621
  processed_frames = 0
 
633
  break
634
  ret, frame = cap.read()
635
  if not ret:
636
+ logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
637
  break
638
  original_frame = frame.copy()
639
  frame = preprocess_frame(frame)
 
643
  batch_frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
644
  batch_indices.append(frame_idx)
645
  batch_originals.append(original_frame)
646
+ processed_frames += frame_skip
647
 
648
  if not batch_frames:
649
  logger.info("No more frames to process.")
650
  break
651
 
652
+ # Check for timeout
653
+ elapsed_time = time.time() - start_time
654
+ if elapsed_time > CONFIG["MAX_PROCESSING_TIME"]:
655
+ logger.warning(f"Processing exceeded time limit of {CONFIG['MAX_PROCESSING_TIME']}s. Terminating early.")
656
+ break
657
+
658
  try:
659
  inputs = processor(images=batch_frames, return_tensors="pt").to(device)
660
  if device.type == "cuda":
 
676
  progress = (processed_frames / total_frames) * 100
677
  elapsed_time = current_time - start_time
678
  fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
679
+ yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", "", f"Elapsed: {elapsed_time:.1f}s"
680
  last_yield_time = current_time
681
 
682
  for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
 
685
  person_boxes = []
686
  tool_boxes = []
687
 
688
+ frame_detections[frame_idx] = [] # Store detections for this frame
689
+
690
  for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
691
  label_name = model.config.id2label[label.item()]
692
  conf = float(score)
 
696
  bbox_xywh = [x + w/2, y + h/2, w, h]
697
 
698
  if label_name in ["no_helmet", "no_harness"] and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label_name, 0.25):
699
+ if label_name == "no_helmet" and not validate_helmet_detection(original_frame, bbox_xywh, conf):
700
+ logger.info(f"Frame {frame_idx}: Helmet false positive filtered at {conf:.2f} confidence")
701
  continue
702
  track_inputs.append({"bbox": bbox_xywh, "conf": conf, "cls": label_name})
703
+ frame_detections[frame_idx].append({"label": label_name, "conf": conf, "bbox": bbox_xywh})
704
  elif label_name == "person":
705
  person_boxes.append(bbox_xywh)
706
+ elif label_name in ["hammer", "wrench"]:
707
  tool_boxes.append(bbox_xywh)
708
 
 
709
  for pbox in person_boxes:
710
  if is_unsafe_posture(pbox, original_frame.shape[:2]):
711
  track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_posture"})
712
+ frame_detections[frame_idx].append({"label": "unsafe_posture", "conf": 0.9, "bbox": pbox})
713
  if is_unsafe_zone(pbox, original_frame.shape[:2]):
714
  track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_zone"})
715
+ frame_detections[frame_idx].append({"label": "unsafe_zone", "conf": 0.9, "bbox": pbox})
716
  for tbox in tool_boxes:
717
  if is_improper_tool_use(pbox, tbox):
718
  track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "improper_tool_use"})
719
+ frame_detections[frame_idx].append({"label": "improper_tool_use", "conf": 0.9, "bbox": pbox})
720
 
721
  if not track_inputs:
722
  continue
 
732
  tracker_id = obj['id']
733
  label = obj['cls']
734
  conf = obj['score']
735
+ bbox = obj['bbox']
736
 
737
  if label not in CONFIG["VIOLATION_LABELS"]:
738
  continue
 
772
 
773
  violations = []
774
  for (worker_id, label), detection_time in unique_violations.items():
775
+ frame_idx = violation_frames[(worker_id, label)]
776
+ conf = next((d["conf"] for d in frame_detections.get(frame_idx, []) if d["label"] == label), 0.0)
777
  violations.append({
778
  "worker_id": worker_id,
779
  "violation": label,
780
  "timestamp": detection_time,
781
+ "confidence": conf,
782
  "frame_idx": violation_frames[(worker_id, label)]
783
  })
784
 
785
  if not violations:
786
  logger.info("No violations detected after processing")
787
+ yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A", f"Completed in {processing_time:.1f}s"
788
  return
789
 
790
+ # Generate violation table early for intermediate output
791
+ violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
792
+ violation_table += "|-----------|-----------|----------|------------|\n"
793
+ for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
794
+ display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
795
+ worker_id = v.get("worker_id", "Unknown")
796
+ timestamp = v.get("timestamp", 0.0)
797
+ confidence = v.get("confidence", 0.0)
798
+ violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
799
+ yield violation_table, "", "", "", "", f"Violations detected in {processing_time:.1f}s"
800
+
801
  snapshots = []
802
  cap = cv2.VideoCapture(video_path)
803
  for violation in violations:
804
+ try:
805
+ frame_idx = violation["frame_idx"]
806
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
807
+ ret, frame = cap.read()
808
+ if not ret:
809
+ logger.warning(f"Failed to read frame {frame_idx} for snapshot.")
810
+ continue
811
 
812
+ frame = preprocess_frame(frame)
813
+ # Reuse detections instead of re-running inference
814
+ detections = frame_detections.get(frame_idx, [])
815
+ for det in detections:
816
+ if det["label"] == violation["violation"]:
817
+ violation["confidence"] = round(det["conf"], 2)
818
+ detection = {
819
+ "worker_id": violation["worker_id"],
820
+ "violation": det["label"],
821
+ "confidence": violation["confidence"],
822
+ "bounding_box": det["bbox"],
823
+ "timestamp": violation["timestamp"]
824
+ }
825
+ snapshot_frame = frame.copy()
826
+ snapshot_frame = draw_detections(snapshot_frame, [detection])
827
+ cv2.putText(
828
+ snapshot_frame,
829
+ f"Time: {violation['timestamp']:.2f}s",
830
+ (10, 30),
831
+ cv2.FONT_HERSHEY_SIMPLEX,
832
+ 0.7,
833
+ (255, 255, 255),
834
+ 2
835
+ )
836
+ snapshot_filename = f"violation_{det['label']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
837
+ snapshot_path = os.path.join(output_dir, snapshot_filename)
838
+ cv2.imwrite(
839
+ snapshot_path,
840
+ snapshot_frame,
841
+ [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
842
+ )
843
+ snapshots.append({
844
+ "violation": det["label"],
845
+ "worker_id": violation["worker_id"],
846
+ "timestamp": violation["timestamp"],
847
+ "snapshot_path": snapshot_path,
848
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
849
+ "confidence": violation["confidence"]
850
+ })
851
+ logger.info(f"Captured snapshot for {det['label']} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
852
+ break
853
+ except Exception as e:
854
+ logger.error(f"Error generating snapshot for violation: {e}")
855
+ continue
 
 
 
 
 
 
 
 
 
 
856
 
857
  cap.release()
858
 
859
  score = calculate_safety_score(violations)
860
+ pdf_path, pdf_url, pdf_file = "", "", None
861
+ try:
862
+ pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
863
+ except Exception as e:
864
+ logger.error(f"PDF generation failed: {e}")
865
+ yield violation_table, f"Safety Score: {score}%", "Failed to generate snapshots due to PDF error.", "N/A", "N/A", f"Completed in {processing_time:.1f}s\nError: {str(e)}"
866
+ return
867
 
868
+ record_id, final_pdf_url = "N/A", "Salesforce integration failed."
869
+ try:
870
+ record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
871
+ except Exception as e:
872
+ logger.error(f"Salesforce integration failed: {e}")
 
 
 
 
873
 
874
  snapshots_text = ""
875
  for s in snapshots:
 
888
  snapshots_text,
889
  f"Salesforce Record ID: {record_id}",
890
  final_pdf_url,
891
+ f"Completed in {processing_time:.1f}s"
892
  )
893
 
894
  except Exception as e:
895
  logger.error(f"Error processing video: {str(e)}", exc_info=True)
896
+ yield f"Error processing video: {str(e)}", "", "", "", "", f"Failed after {time.time() - start_time:.1f}s"
897
  finally:
898
  if video_path and os.path.exists(video_path):
899
  try:
 
930
  if not FFMPEG_AVAILABLE:
931
  return "FFmpeg is not available in the environment. Please install FFmpeg to process videos.", "", "", "", "", ""
932
 
933
+ for status, score, snapshots_text, record_id, details_url, log in process_video(video_data, temp_dir):
934
+ yield status, score, snapshots_text, record_id, details_url, log
935
 
936
  except Exception as e:
937
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
938
+ yield f"Error: {str(e)}", "", "Error in processing.", "", "", str(e)
939
  finally:
940
  if local_video_path and os.path.exists(local_video_path):
941
  try:
 
960
  gr.Markdown(label="Snapshots"),
961
  gr.Textbox(label="Salesforce Record ID"),
962
  gr.Textbox(label="Violation Details URL"),
963
+ gr.Textbox(label="Processing Log")
964
  ],
965
  title="Worksite Safety Violation Analyzer",
966
  description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",