Spaces:
Sleeping
Sleeping
File size: 3,912 Bytes
40bf7bd e04e491 40bf7bd f3726fe 130b590 dd7321a 40bf7bd e04e491 a75980a e04e491 457903c a75980a 457903c 40bf7bd 457903c e04e491 457903c e04e491 457903c 40bf7bd 457903c a75980a e04e491 a75980a e04e491 457903c e04e491 457903c e04e491 40bf7bd 457903c e04e491 40bf7bd e04e491 40bf7bd e04e491 a75980a e04e491 457903c 40bf7bd e04e491 457903c e04e491 457903c e04e491 40bf7bd e04e491 40bf7bd e04e491 457903c e04e491 40bf7bd e04e491 40bf7bd e04e491 40bf7bd e04e491 40bf7bd e04e491 40bf7bd 457903c e04e491 457903c 40bf7bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
import cv2
import gradio as gr
import torch
import numpy as np
from salesforce import get_salesforce_connection
from simple_salesforce import Salesforce
# NOTE(review): `sf` is not referenced anywhere in the visible code, yet the
# helper is invoked at import time (presumably opening a Salesforce session)
# — confirm it is needed here.
sf = get_salesforce_connection()

try:
    from ultralytics import YOLO
except ImportError:
    # Tell the operator which dependency is missing, then fail loudly.
    print("❌ Ultralytics not installed. Run: pip install ultralytics")
    raise
# ==========================
# Configuration
# ==========================
DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
# Stock YOLOv8 nano weights, used when the custom checkpoint is absent.
FALLBACK_MODEL = "yolov8n.pt"
# The SAFETY_MODEL_PATH environment variable overrides the bundled path.
MODEL_PATH = os.environ.get("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)

# Detector class index -> violation label name.
# NOTE(review): these indices presumably match only the custom checkpoint;
# the fallback weights likely use a different class set — confirm.
VIOLATION_LABELS = dict(
    enumerate(("no_helmet", "no_harness", "unsafe_posture", "unsafe_zone"))
)
# ==========================
# Device Setup
# ==========================
def _pick_device():
    """Return the CUDA device when available, otherwise the CPU, printing the choice."""
    try:
        has_cuda = torch.cuda.is_available()
        chosen = torch.device("cuda") if has_cuda else torch.device("cpu")
        print(f"✅ Using device: {chosen}")
        return chosen
    except Exception as exc:
        # Any failure while probing degrades to CPU instead of aborting start-up.
        print(f"⚠️ Error setting device: {exc}")
        return torch.device("cpu")


device = _pick_device()
# ==========================
# Load Model
# ==========================
def _resolve_model_source():
    """Return MODEL_PATH if that file exists on disk, otherwise FALLBACK_MODEL."""
    if not os.path.isfile(MODEL_PATH):
        print(f"⚠️ Model file '{MODEL_PATH}' not found. Falling back to: {FALLBACK_MODEL}")
        return FALLBACK_MODEL
    print(f"✅ Found model at: {MODEL_PATH}")
    return MODEL_PATH


selected_model = _resolve_model_source()

# NOTE(review): VIOLATION_LABELS presumably describes the custom checkpoint's
# classes; if the fallback weights are loaded, their class indices likely mean
# something else and detections would be mislabeled — confirm.
try:
    model = YOLO(selected_model)
    print(f"✅ Model loaded: {selected_model}")
except Exception as exc:
    # Loading is mandatory for the app; report the failure and re-raise.
    print(f"❌ Failed to load model: {exc}")
    raise
# ==========================
# Video Processing
# ==========================
def process_video(video_path):
    """Run the detector on every frame of a video and score the detections.

    Args:
        video_path: Source passed straight to ``cv2.VideoCapture`` (normally
            the path of the uploaded file).

    Returns:
        ``(violations, score)``. ``violations`` is a list of dicts with keys
        ``frame`` (0-based frame index), ``violation`` (label from
        ``VIOLATION_LABELS`` or ``"class_<n>"``), ``confidence`` (rounded to
        2 places) and ``bounding_box`` (the four ``box.xywh`` values, rounded
        to 2 places). ``score`` is ``calculate_safety_score(violations)``.
        On any error the problem is printed and ``([], "Error: <message>")``
        is returned instead.
    """
    video = None
    try:
        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file.")
        frame_count = 0
        violations = []
        while True:
            ret, frame = video.read()
            if not ret:
                break
            results = model(frame, device=device)
            for result in results:
                for box in result.boxes:
                    cls = int(box.cls)
                    conf = float(box.conf)
                    xywh = box.xywh.cpu().numpy()[0]
                    label = VIOLATION_LABELS.get(cls, f"class_{cls}")
                    violations.append({
                        "frame": frame_count,
                        "violation": label,
                        "confidence": round(conf, 2),
                        # Convert numpy scalars to plain floats so the record
                        # is ordinary JSON data and rounds cleanly.
                        "bounding_box": [round(float(x), 2) for x in xywh],
                    })
            frame_count += 1
        score = calculate_safety_score(violations)
        return violations, score
    except Exception as e:
        print(f"❌ Error processing video: {e}")
        return [], f"Error: {e}"
    finally:
        # Release the capture on every path, including a failure mid-stream,
        # which previously leaked the open video handle.
        if video is not None:
            video.release()
# ==========================
# Safety Score
# ==========================
def calculate_safety_score(violations):
    """Return a compliance score in the range 0..100 for a list of violation records.

    The score starts at 100. Each record's ``"violation"`` label subtracts a
    fixed penalty, unrecognised labels cost nothing, and the result is
    floored at 0.

    NOTE(review): ``process_video`` emits one record per detected box per
    frame, so a single persistent violation is penalised on every frame —
    confirm this is the intended scoring.
    """
    penalty_for = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 25,
    }
    deductions = sum(penalty_for.get(record["violation"], 0) for record in violations)
    return max(100 - deductions, 0)
# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: analyse an uploaded video and format the outputs.

    Args:
        video_file: Value supplied by the ``gr.Video`` input; falsy when
            nothing was uploaded.

    Returns:
        A ``(violations, score_text)`` pair for the JSON and Textbox outputs.
        With no upload, returns a prompt message and an empty string.
    """
    if not video_file:
        return "Please upload a video file.", ""
    violations, score = process_video(video_file)
    if isinstance(score, str):
        # process_video reports failures as an "Error: ..." string in the
        # score slot; show it as-is rather than as "Safety Score: Error: ...%".
        return violations, score
    return violations, f"Safety Score: {score}%"
# Components are built in the same order as before: input first, then outputs.
_video_input = gr.Video(label="Upload Site Video")
_result_outputs = [
    gr.JSON(label="Detected Safety Violations"),
    gr.Textbox(label="Compliance Score"),
]

# Single-page UI: one video in; detected violations and a score text out.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_video_input,
    outputs=_result_outputs,
    title="Worksite Safety Violation Analyzer",
    description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture).",
)
def main():
    """Announce start-up and launch the Gradio app."""
    print("🚀 Launching Safety Analyzer App...")
    interface.launch()


if __name__ == "__main__":
    main()
|