Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -224,7 +224,7 @@ class BYTETracker:
|
|
| 224 |
|
| 225 |
# ========================== # Optimized Configuration # ==========================
|
| 226 |
CONFIG = {
|
| 227 |
-
"MODEL_NAME": "facebook/detr-resnet-50",
|
| 228 |
"VIOLATION_LABELS": {
|
| 229 |
"no_helmet": "No Helmet",
|
| 230 |
"no_harness": "No Harness",
|
|
@@ -263,25 +263,26 @@ CONFIG = {
|
|
| 263 |
"MIN_VIOLATION_FRAMES": 2,
|
| 264 |
"VIOLATION_COOLDOWN": 30.0,
|
| 265 |
"WORKER_TRACKING_DURATION": 10.0,
|
| 266 |
-
"MAX_PROCESSING_TIME": 60,
|
| 267 |
-
"FRAME_SKIP":
|
| 268 |
-
"BATCH_SIZE":
|
| 269 |
"PARALLEL_WORKERS": max(1, cpu_count() - 1),
|
| 270 |
"TRACK_BUFFER": 150,
|
| 271 |
"TRACK_THRESH": 0.3,
|
| 272 |
"MATCH_THRESH": 0.5,
|
| 273 |
"SNAPSHOT_QUALITY": 95,
|
| 274 |
"MAX_WORKER_DISTANCE": 150,
|
| 275 |
-
"TARGET_RESOLUTION": (
|
| 276 |
"HELMET_VALIDATION_FRAMES": 3
|
| 277 |
}
|
| 278 |
|
| 279 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 280 |
logger.info(f"Using device: {device}")
|
|
|
|
|
|
|
| 281 |
|
| 282 |
def load_model():
|
| 283 |
try:
|
| 284 |
-
# Check for timm dependency
|
| 285 |
import timm
|
| 286 |
logger.info("timm library is available.")
|
| 287 |
except ImportError as e:
|
|
@@ -312,7 +313,6 @@ def preprocess_frame(frame):
|
|
| 312 |
return frame
|
| 313 |
|
| 314 |
def is_unsafe_posture(box, frame_shape):
|
| 315 |
-
"""Placeholder for unsafe posture detection. Replace with pose estimation (e.g., MediaPipe)."""
|
| 316 |
x1, y1, x2, y2 = box
|
| 317 |
height = y2 - y1
|
| 318 |
width = x2 - x1
|
|
@@ -320,14 +320,12 @@ def is_unsafe_posture(box, frame_shape):
|
|
| 320 |
return aspect_ratio > 2.0
|
| 321 |
|
| 322 |
def is_improper_tool_use(person_box, tool_box):
|
| 323 |
-
"""Placeholder for improper tool use. Fine-tune DETR for specific tools."""
|
| 324 |
person_center = ((person_box[0] + person_box[2]) / 2, (person_box[1] + person_box[3]) / 2)
|
| 325 |
tool_center = ((tool_box[0] + tool_box[2]) / 2, (tool_box[1] + tool_box[3]) / 2)
|
| 326 |
dist = distance.euclidean(person_center, tool_center)
|
| 327 |
return dist > 100
|
| 328 |
|
| 329 |
def is_unsafe_zone(person_box, frame_shape):
|
| 330 |
-
"""Check if person is in restricted area (e.g., top-left quadrant)."""
|
| 331 |
px, py, pw, ph = person_box
|
| 332 |
frame_h, frame_w = frame_shape
|
| 333 |
person_center = (px + pw / 2, py + ph / 2)
|
|
@@ -508,7 +506,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
|
|
| 508 |
sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL": uploaded_url})
|
| 509 |
logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
|
| 510 |
except Exception as e:
|
| 511 |
-
logger.error(f"Failed to update
|
| 512 |
sf.Account.update(record_id, {"Description": uploaded_url})
|
| 513 |
logger.info(f"Updated account record {record_id} with PDF URL")
|
| 514 |
pdf_url = uploaded_url
|
|
@@ -617,6 +615,7 @@ def process_video(video_data, temp_dir):
|
|
| 617 |
unique_violations = {}
|
| 618 |
violation_frames = {}
|
| 619 |
helmet_detections = {}
|
|
|
|
| 620 |
start_time = time.time()
|
| 621 |
frame_skip = CONFIG["FRAME_SKIP"]
|
| 622 |
processed_frames = 0
|
|
@@ -634,7 +633,7 @@ def process_video(video_data, temp_dir):
|
|
| 634 |
break
|
| 635 |
ret, frame = cap.read()
|
| 636 |
if not ret:
|
| 637 |
-
logger.warning(f"Failed to read frame {frame_idx}. Skipping
|
| 638 |
break
|
| 639 |
original_frame = frame.copy()
|
| 640 |
frame = preprocess_frame(frame)
|
|
@@ -644,12 +643,18 @@ def process_video(video_data, temp_dir):
|
|
| 644 |
batch_frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
|
| 645 |
batch_indices.append(frame_idx)
|
| 646 |
batch_originals.append(original_frame)
|
| 647 |
-
processed_frames += frame_skip
|
| 648 |
|
| 649 |
if not batch_frames:
|
| 650 |
logger.info("No more frames to process.")
|
| 651 |
break
|
| 652 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 653 |
try:
|
| 654 |
inputs = processor(images=batch_frames, return_tensors="pt").to(device)
|
| 655 |
if device.type == "cuda":
|
|
@@ -671,7 +676,7 @@ def process_video(video_data, temp_dir):
|
|
| 671 |
progress = (processed_frames / total_frames) * 100
|
| 672 |
elapsed_time = current_time - start_time
|
| 673 |
fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
|
| 674 |
-
yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", "", ""
|
| 675 |
last_yield_time = current_time
|
| 676 |
|
| 677 |
for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
|
|
@@ -680,6 +685,8 @@ def process_video(video_data, temp_dir):
|
|
| 680 |
person_boxes = []
|
| 681 |
tool_boxes = []
|
| 682 |
|
|
|
|
|
|
|
| 683 |
for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
|
| 684 |
label_name = model.config.id2label[label.item()]
|
| 685 |
conf = float(score)
|
|
@@ -689,24 +696,27 @@ def process_video(video_data, temp_dir):
|
|
| 689 |
bbox_xywh = [x + w/2, y + h/2, w, h]
|
| 690 |
|
| 691 |
if label_name in ["no_helmet", "no_harness"] and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label_name, 0.25):
|
| 692 |
-
if label_name == "no_helmet" and not
|
| 693 |
-
logger.info(f"Frame {frame_idx}:
|
| 694 |
continue
|
| 695 |
track_inputs.append({"bbox": bbox_xywh, "conf": conf, "cls": label_name})
|
|
|
|
| 696 |
elif label_name == "person":
|
| 697 |
person_boxes.append(bbox_xywh)
|
| 698 |
-
elif label_name in ["hammer", "wrench"]:
|
| 699 |
tool_boxes.append(bbox_xywh)
|
| 700 |
|
| 701 |
-
# Handle Unsafe violations
|
| 702 |
for pbox in person_boxes:
|
| 703 |
if is_unsafe_posture(pbox, original_frame.shape[:2]):
|
| 704 |
track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_posture"})
|
|
|
|
| 705 |
if is_unsafe_zone(pbox, original_frame.shape[:2]):
|
| 706 |
track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_zone"})
|
|
|
|
| 707 |
for tbox in tool_boxes:
|
| 708 |
if is_improper_tool_use(pbox, tbox):
|
| 709 |
track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "improper_tool_use"})
|
|
|
|
| 710 |
|
| 711 |
if not track_inputs:
|
| 712 |
continue
|
|
@@ -722,6 +732,7 @@ def process_video(video_data, temp_dir):
|
|
| 722 |
tracker_id = obj['id']
|
| 723 |
label = obj['cls']
|
| 724 |
conf = obj['score']
|
|
|
|
| 725 |
|
| 726 |
if label not in CONFIG["VIOLATION_LABELS"]:
|
| 727 |
continue
|
|
@@ -761,100 +772,104 @@ def process_video(video_data, temp_dir):
|
|
| 761 |
|
| 762 |
violations = []
|
| 763 |
for (worker_id, label), detection_time in unique_violations.items():
|
|
|
|
|
|
|
| 764 |
violations.append({
|
| 765 |
"worker_id": worker_id,
|
| 766 |
"violation": label,
|
| 767 |
"timestamp": detection_time,
|
| 768 |
-
"confidence":
|
| 769 |
"frame_idx": violation_frames[(worker_id, label)]
|
| 770 |
})
|
| 771 |
|
| 772 |
if not violations:
|
| 773 |
logger.info("No violations detected after processing")
|
| 774 |
-
yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A", ""
|
| 775 |
return
|
| 776 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
snapshots = []
|
| 778 |
cap = cv2.VideoCapture(video_path)
|
| 779 |
for violation in violations:
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
|
|
|
| 786 |
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
|
| 811 |
-
"
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
(
|
| 819 |
-
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
snapshots.append({
|
| 832 |
-
"violation": label_name,
|
| 833 |
-
"worker_id": violation["worker_id"],
|
| 834 |
-
"timestamp": violation["timestamp"],
|
| 835 |
-
"snapshot_path": snapshot_path,
|
| 836 |
-
"snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
|
| 837 |
-
"confidence": violation["confidence"]
|
| 838 |
-
})
|
| 839 |
-
logger.info(f"Captured snapshot for {label_name} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
|
| 840 |
-
break
|
| 841 |
|
| 842 |
cap.release()
|
| 843 |
|
| 844 |
score = calculate_safety_score(violations)
|
| 845 |
-
pdf_path, pdf_url, pdf_file =
|
| 846 |
-
|
| 847 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 848 |
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
worker_id = v.get("worker_id", "Unknown")
|
| 855 |
-
timestamp = v.get("timestamp", 0.0)
|
| 856 |
-
confidence = v.get("confidence", 0.0)
|
| 857 |
-
violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
|
| 858 |
|
| 859 |
snapshots_text = ""
|
| 860 |
for s in snapshots:
|
|
@@ -873,12 +888,12 @@ def process_video(video_data, temp_dir):
|
|
| 873 |
snapshots_text,
|
| 874 |
f"Salesforce Record ID: {record_id}",
|
| 875 |
final_pdf_url,
|
| 876 |
-
""
|
| 877 |
)
|
| 878 |
|
| 879 |
except Exception as e:
|
| 880 |
logger.error(f"Error processing video: {str(e)}", exc_info=True)
|
| 881 |
-
yield f"Error processing video: {str(e)}", "", "", "", "", ""
|
| 882 |
finally:
|
| 883 |
if video_path and os.path.exists(video_path):
|
| 884 |
try:
|
|
@@ -915,12 +930,12 @@ def gradio_interface(video_file):
|
|
| 915 |
if not FFMPEG_AVAILABLE:
|
| 916 |
return "FFmpeg is not available in the environment. Please install FFmpeg to process videos.", "", "", "", "", ""
|
| 917 |
|
| 918 |
-
for status, score, snapshots_text, record_id, details_url,
|
| 919 |
-
yield status, score, snapshots_text, record_id, details_url,
|
| 920 |
|
| 921 |
except Exception as e:
|
| 922 |
logger.error(f"Error in Gradio interface: {e}", exc_info=True)
|
| 923 |
-
yield f"Error: {str(e)}", "", "Error in processing.", "", "",
|
| 924 |
finally:
|
| 925 |
if local_video_path and os.path.exists(local_video_path):
|
| 926 |
try:
|
|
@@ -945,7 +960,7 @@ interface = gr.Interface(
|
|
| 945 |
gr.Markdown(label="Snapshots"),
|
| 946 |
gr.Textbox(label="Salesforce Record ID"),
|
| 947 |
gr.Textbox(label="Violation Details URL"),
|
| 948 |
-
gr.Textbox(label="
|
| 949 |
],
|
| 950 |
title="Worksite Safety Violation Analyzer",
|
| 951 |
description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
|
|
|
|
| 224 |
|
| 225 |
# ========================== # Optimized Configuration # ==========================
|
| 226 |
CONFIG = {
|
| 227 |
+
"MODEL_NAME": "facebook/detr-resnet-50",
|
| 228 |
"VIOLATION_LABELS": {
|
| 229 |
"no_helmet": "No Helmet",
|
| 230 |
"no_harness": "No Harness",
|
|
|
|
| 263 |
"MIN_VIOLATION_FRAMES": 2,
|
| 264 |
"VIOLATION_COOLDOWN": 30.0,
|
| 265 |
"WORKER_TRACKING_DURATION": 10.0,
|
| 266 |
+
"MAX_PROCESSING_TIME": 60, # Reduced for early termination
|
| 267 |
+
"FRAME_SKIP": 4, # Increased to reduce frames processed
|
| 268 |
+
"BATCH_SIZE": 4, # Reduced for CPU efficiency
|
| 269 |
"PARALLEL_WORKERS": max(1, cpu_count() - 1),
|
| 270 |
"TRACK_BUFFER": 150,
|
| 271 |
"TRACK_THRESH": 0.3,
|
| 272 |
"MATCH_THRESH": 0.5,
|
| 273 |
"SNAPSHOT_QUALITY": 95,
|
| 274 |
"MAX_WORKER_DISTANCE": 150,
|
| 275 |
+
"TARGET_RESOLUTION": (320, 320), # Reduced for faster inference
|
| 276 |
"HELMET_VALIDATION_FRAMES": 3
|
| 277 |
}
|
| 278 |
|
| 279 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 280 |
logger.info(f"Using device: {device}")
|
| 281 |
+
if device.type == "cpu":
|
| 282 |
+
logger.warning("Running on CPU, which may lead to slower processing. Consider using a GPU for better performance.")
|
| 283 |
|
| 284 |
def load_model():
|
| 285 |
try:
|
|
|
|
| 286 |
import timm
|
| 287 |
logger.info("timm library is available.")
|
| 288 |
except ImportError as e:
|
|
|
|
| 313 |
return frame
|
| 314 |
|
| 315 |
def is_unsafe_posture(box, frame_shape):
|
|
|
|
| 316 |
x1, y1, x2, y2 = box
|
| 317 |
height = y2 - y1
|
| 318 |
width = x2 - x1
|
|
|
|
| 320 |
return aspect_ratio > 2.0
|
| 321 |
|
| 322 |
def is_improper_tool_use(person_box, tool_box):
|
|
|
|
| 323 |
person_center = ((person_box[0] + person_box[2]) / 2, (person_box[1] + person_box[3]) / 2)
|
| 324 |
tool_center = ((tool_box[0] + tool_box[2]) / 2, (tool_box[1] + tool_box[3]) / 2)
|
| 325 |
dist = distance.euclidean(person_center, tool_center)
|
| 326 |
return dist > 100
|
| 327 |
|
| 328 |
def is_unsafe_zone(person_box, frame_shape):
|
|
|
|
| 329 |
px, py, pw, ph = person_box
|
| 330 |
frame_h, frame_w = frame_shape
|
| 331 |
person_center = (px + pw / 2, py + ph / 2)
|
|
|
|
| 506 |
sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL": uploaded_url})
|
| 507 |
logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
|
| 508 |
except Exception as e:
|
| 509 |
+
logger.error(f"Failed to update Safety_Violation_Report__c: {e}")
|
| 510 |
sf.Account.update(record_id, {"Description": uploaded_url})
|
| 511 |
logger.info(f"Updated account record {record_id} with PDF URL")
|
| 512 |
pdf_url = uploaded_url
|
|
|
|
| 615 |
unique_violations = {}
|
| 616 |
violation_frames = {}
|
| 617 |
helmet_detections = {}
|
| 618 |
+
frame_detections = {} # Store detections for snapshot reuse
|
| 619 |
start_time = time.time()
|
| 620 |
frame_skip = CONFIG["FRAME_SKIP"]
|
| 621 |
processed_frames = 0
|
|
|
|
| 633 |
break
|
| 634 |
ret, frame = cap.read()
|
| 635 |
if not ret:
|
| 636 |
+
logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
|
| 637 |
break
|
| 638 |
original_frame = frame.copy()
|
| 639 |
frame = preprocess_frame(frame)
|
|
|
|
| 643 |
batch_frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
|
| 644 |
batch_indices.append(frame_idx)
|
| 645 |
batch_originals.append(original_frame)
|
| 646 |
+
processed_frames += frame_skip
|
| 647 |
|
| 648 |
if not batch_frames:
|
| 649 |
logger.info("No more frames to process.")
|
| 650 |
break
|
| 651 |
|
| 652 |
+
# Check for timeout
|
| 653 |
+
elapsed_time = time.time() - start_time
|
| 654 |
+
if elapsed_time > CONFIG["MAX_PROCESSING_TIME"]:
|
| 655 |
+
logger.warning(f"Processing exceeded time limit of {CONFIG['MAX_PROCESSING_TIME']}s. Terminating early.")
|
| 656 |
+
break
|
| 657 |
+
|
| 658 |
try:
|
| 659 |
inputs = processor(images=batch_frames, return_tensors="pt").to(device)
|
| 660 |
if device.type == "cuda":
|
|
|
|
| 676 |
progress = (processed_frames / total_frames) * 100
|
| 677 |
elapsed_time = current_time - start_time
|
| 678 |
fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
|
| 679 |
+
yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", "", f"Elapsed: {elapsed_time:.1f}s"
|
| 680 |
last_yield_time = current_time
|
| 681 |
|
| 682 |
for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
|
|
|
|
| 685 |
person_boxes = []
|
| 686 |
tool_boxes = []
|
| 687 |
|
| 688 |
+
frame_detections[frame_idx] = [] # Store detections for this frame
|
| 689 |
+
|
| 690 |
for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
|
| 691 |
label_name = model.config.id2label[label.item()]
|
| 692 |
conf = float(score)
|
|
|
|
| 696 |
bbox_xywh = [x + w/2, y + h/2, w, h]
|
| 697 |
|
| 698 |
if label_name in ["no_helmet", "no_harness"] and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label_name, 0.25):
|
| 699 |
+
if label_name == "no_helmet" and not validate_helmet_detection(original_frame, bbox_xywh, conf):
|
| 700 |
+
logger.info(f"Frame {frame_idx}: Helmet false positive filtered at {conf:.2f} confidence")
|
| 701 |
continue
|
| 702 |
track_inputs.append({"bbox": bbox_xywh, "conf": conf, "cls": label_name})
|
| 703 |
+
frame_detections[frame_idx].append({"label": label_name, "conf": conf, "bbox": bbox_xywh})
|
| 704 |
elif label_name == "person":
|
| 705 |
person_boxes.append(bbox_xywh)
|
| 706 |
+
elif label_name in ["hammer", "wrench"]:
|
| 707 |
tool_boxes.append(bbox_xywh)
|
| 708 |
|
|
|
|
| 709 |
for pbox in person_boxes:
|
| 710 |
if is_unsafe_posture(pbox, original_frame.shape[:2]):
|
| 711 |
track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_posture"})
|
| 712 |
+
frame_detections[frame_idx].append({"label": "unsafe_posture", "conf": 0.9, "bbox": pbox})
|
| 713 |
if is_unsafe_zone(pbox, original_frame.shape[:2]):
|
| 714 |
track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_zone"})
|
| 715 |
+
frame_detections[frame_idx].append({"label": "unsafe_zone", "conf": 0.9, "bbox": pbox})
|
| 716 |
for tbox in tool_boxes:
|
| 717 |
if is_improper_tool_use(pbox, tbox):
|
| 718 |
track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "improper_tool_use"})
|
| 719 |
+
frame_detections[frame_idx].append({"label": "improper_tool_use", "conf": 0.9, "bbox": pbox})
|
| 720 |
|
| 721 |
if not track_inputs:
|
| 722 |
continue
|
|
|
|
| 732 |
tracker_id = obj['id']
|
| 733 |
label = obj['cls']
|
| 734 |
conf = obj['score']
|
| 735 |
+
bbox = obj['bbox']
|
| 736 |
|
| 737 |
if label not in CONFIG["VIOLATION_LABELS"]:
|
| 738 |
continue
|
|
|
|
| 772 |
|
| 773 |
violations = []
|
| 774 |
for (worker_id, label), detection_time in unique_violations.items():
|
| 775 |
+
frame_idx = violation_frames[(worker_id, label)]
|
| 776 |
+
conf = next((d["conf"] for d in frame_detections.get(frame_idx, []) if d["label"] == label), 0.0)
|
| 777 |
violations.append({
|
| 778 |
"worker_id": worker_id,
|
| 779 |
"violation": label,
|
| 780 |
"timestamp": detection_time,
|
| 781 |
+
"confidence": conf,
|
| 782 |
"frame_idx": violation_frames[(worker_id, label)]
|
| 783 |
})
|
| 784 |
|
| 785 |
if not violations:
|
| 786 |
logger.info("No violations detected after processing")
|
| 787 |
+
yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A", f"Completed in {processing_time:.1f}s"
|
| 788 |
return
|
| 789 |
|
| 790 |
+
# Generate violation table early for intermediate output
|
| 791 |
+
violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
|
| 792 |
+
violation_table += "|-----------|-----------|----------|------------|\n"
|
| 793 |
+
for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
|
| 794 |
+
display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
|
| 795 |
+
worker_id = v.get("worker_id", "Unknown")
|
| 796 |
+
timestamp = v.get("timestamp", 0.0)
|
| 797 |
+
confidence = v.get("confidence", 0.0)
|
| 798 |
+
violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
|
| 799 |
+
yield violation_table, "", "", "", "", f"Violations detected in {processing_time:.1f}s"
|
| 800 |
+
|
| 801 |
snapshots = []
|
| 802 |
cap = cv2.VideoCapture(video_path)
|
| 803 |
for violation in violations:
|
| 804 |
+
try:
|
| 805 |
+
frame_idx = violation["frame_idx"]
|
| 806 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
|
| 807 |
+
ret, frame = cap.read()
|
| 808 |
+
if not ret:
|
| 809 |
+
logger.warning(f"Failed to read frame {frame_idx} for snapshot.")
|
| 810 |
+
continue
|
| 811 |
|
| 812 |
+
frame = preprocess_frame(frame)
|
| 813 |
+
# Reuse detections instead of re-running inference
|
| 814 |
+
detections = frame_detections.get(frame_idx, [])
|
| 815 |
+
for det in detections:
|
| 816 |
+
if det["label"] == violation["violation"]:
|
| 817 |
+
violation["confidence"] = round(det["conf"], 2)
|
| 818 |
+
detection = {
|
| 819 |
+
"worker_id": violation["worker_id"],
|
| 820 |
+
"violation": det["label"],
|
| 821 |
+
"confidence": violation["confidence"],
|
| 822 |
+
"bounding_box": det["bbox"],
|
| 823 |
+
"timestamp": violation["timestamp"]
|
| 824 |
+
}
|
| 825 |
+
snapshot_frame = frame.copy()
|
| 826 |
+
snapshot_frame = draw_detections(snapshot_frame, [detection])
|
| 827 |
+
cv2.putText(
|
| 828 |
+
snapshot_frame,
|
| 829 |
+
f"Time: {violation['timestamp']:.2f}s",
|
| 830 |
+
(10, 30),
|
| 831 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 832 |
+
0.7,
|
| 833 |
+
(255, 255, 255),
|
| 834 |
+
2
|
| 835 |
+
)
|
| 836 |
+
snapshot_filename = f"violation_{det['label']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
|
| 837 |
+
snapshot_path = os.path.join(output_dir, snapshot_filename)
|
| 838 |
+
cv2.imwrite(
|
| 839 |
+
snapshot_path,
|
| 840 |
+
snapshot_frame,
|
| 841 |
+
[cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
|
| 842 |
+
)
|
| 843 |
+
snapshots.append({
|
| 844 |
+
"violation": det["label"],
|
| 845 |
+
"worker_id": violation["worker_id"],
|
| 846 |
+
"timestamp": violation["timestamp"],
|
| 847 |
+
"snapshot_path": snapshot_path,
|
| 848 |
+
"snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
|
| 849 |
+
"confidence": violation["confidence"]
|
| 850 |
+
})
|
| 851 |
+
logger.info(f"Captured snapshot for {det['label']} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
|
| 852 |
+
break
|
| 853 |
+
except Exception as e:
|
| 854 |
+
logger.error(f"Error generating snapshot for violation: {e}")
|
| 855 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 856 |
|
| 857 |
cap.release()
|
| 858 |
|
| 859 |
score = calculate_safety_score(violations)
|
| 860 |
+
pdf_path, pdf_url, pdf_file = "", "", None
|
| 861 |
+
try:
|
| 862 |
+
pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
|
| 863 |
+
except Exception as e:
|
| 864 |
+
logger.error(f"PDF generation failed: {e}")
|
| 865 |
+
yield violation_table, f"Safety Score: {score}%", "Failed to generate snapshots due to PDF error.", "N/A", "N/A", f"Completed in {processing_time:.1f}s\nError: {str(e)}"
|
| 866 |
+
return
|
| 867 |
|
| 868 |
+
record_id, final_pdf_url = "N/A", "Salesforce integration failed."
|
| 869 |
+
try:
|
| 870 |
+
record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
|
| 871 |
+
except Exception as e:
|
| 872 |
+
logger.error(f"Salesforce integration failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 873 |
|
| 874 |
snapshots_text = ""
|
| 875 |
for s in snapshots:
|
|
|
|
| 888 |
snapshots_text,
|
| 889 |
f"Salesforce Record ID: {record_id}",
|
| 890 |
final_pdf_url,
|
| 891 |
+
f"Completed in {processing_time:.1f}s"
|
| 892 |
)
|
| 893 |
|
| 894 |
except Exception as e:
|
| 895 |
logger.error(f"Error processing video: {str(e)}", exc_info=True)
|
| 896 |
+
yield f"Error processing video: {str(e)}", "", "", "", "", f"Failed after {time.time() - start_time:.1f}s"
|
| 897 |
finally:
|
| 898 |
if video_path and os.path.exists(video_path):
|
| 899 |
try:
|
|
|
|
| 930 |
if not FFMPEG_AVAILABLE:
|
| 931 |
return "FFmpeg is not available in the environment. Please install FFmpeg to process videos.", "", "", "", "", ""
|
| 932 |
|
| 933 |
+
for status, score, snapshots_text, record_id, details_url, log in process_video(video_data, temp_dir):
|
| 934 |
+
yield status, score, snapshots_text, record_id, details_url, log
|
| 935 |
|
| 936 |
except Exception as e:
|
| 937 |
logger.error(f"Error in Gradio interface: {e}", exc_info=True)
|
| 938 |
+
yield f"Error: {str(e)}", "", "Error in processing.", "", "", str(e)
|
| 939 |
finally:
|
| 940 |
if local_video_path and os.path.exists(local_video_path):
|
| 941 |
try:
|
|
|
|
| 960 |
gr.Markdown(label="Snapshots"),
|
| 961 |
gr.Textbox(label="Salesforce Record ID"),
|
| 962 |
gr.Textbox(label="Violation Details URL"),
|
| 963 |
+
gr.Textbox(label="Processing Log")
|
| 964 |
],
|
| 965 |
title="Worksite Safety Violation Analyzer",
|
| 966 |
description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
|