File size: 3,912 Bytes
40bf7bd
 
e04e491
40bf7bd
 
f3726fe
130b590
dd7321a
40bf7bd
e04e491
 
 
a75980a
e04e491
 
457903c
 
 
 
a75980a
457903c
 
40bf7bd
 
 
 
 
 
 
457903c
 
 
e04e491
 
457903c
e04e491
457903c
40bf7bd
 
457903c
 
 
a75980a
 
 
 
 
 
e04e491
 
a75980a
 
e04e491
457903c
e04e491
 
457903c
 
 
e04e491
 
 
 
 
40bf7bd
 
457903c
e04e491
 
 
 
40bf7bd
e04e491
 
 
 
40bf7bd
e04e491
 
 
 
a75980a
 
 
 
 
 
 
e04e491
 
 
 
457903c
 
40bf7bd
e04e491
457903c
e04e491
 
457903c
 
 
e04e491
40bf7bd
 
 
 
 
 
e04e491
40bf7bd
 
 
e04e491
457903c
 
 
e04e491
40bf7bd
e04e491
 
40bf7bd
 
 
e04e491
 
40bf7bd
e04e491
40bf7bd
 
e04e491
40bf7bd
457903c
e04e491
 
 
457903c
40bf7bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import cv2
import gradio as gr
import torch
import numpy as np
from salesforce import get_salesforce_connection
from simple_salesforce import Salesforce
# Obtain a Salesforce client at import time via the project helper
# `salesforce.get_salesforce_connection`.
# NOTE(review): `sf` is not referenced anywhere else in this module's visible
# code, so this import-time call (presumably a network login) adds a hard
# dependency on Salesforce being reachable without being used — confirm it is
# still needed here.
sf = get_salesforce_connection()

# Ultralytics is mandatory: print an install hint, then re-raise the original
# ImportError so startup still fails loudly.
try:
    from ultralytics import YOLO
except ImportError:
    print("❌ Ultralytics not installed. Run: pip install ultralytics")
    raise

# ==========================
# Configuration
# ==========================
DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
# Nano model used when the custom weights are missing.
FALLBACK_MODEL = "yolov8n.pt"
# The SAFETY_MODEL_PATH environment variable overrides the default location.
MODEL_PATH = os.environ.get("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)

# Model class id -> violation name; ids are the positions in this list.
VIOLATION_LABELS = dict(enumerate([
    "no_helmet",
    "no_harness",
    "unsafe_posture",
    "unsafe_zone",
]))

# ==========================
# Device Setup
# ==========================
# Pick the inference device: CUDA when PyTorch reports it available, else CPU.
# Any error while probing falls back to CPU.
try:
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"✅ Using device: {device}")
except Exception as e:
    print(f"⚠️ Error setting device: {e}")
    device = torch.device("cpu")

# ==========================
# Load Model
# ==========================
# Choose the weights file: the configured MODEL_PATH if it exists on disk,
# otherwise FALLBACK_MODEL.
# NOTE(review): the fallback checkpoint's class ids presumably do not line up
# with VIOLATION_LABELS (a stock yolov8n.pt is a generic COCO model, where
# class 0 is "person"), so its detections would be reported under violation
# names — confirm before relying on fallback results.
if not os.path.isfile(MODEL_PATH):
    print(f"⚠️ Model file '{MODEL_PATH}' not found. Falling back to: {FALLBACK_MODEL}")
    selected_model = FALLBACK_MODEL
else:
    selected_model = MODEL_PATH
    print(f"✅ Found model at: {selected_model}")

# A load failure is reported and re-raised, aborting module import.
try:
    model = YOLO(selected_model)
    print(f"✅ Model loaded: {selected_model}")
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    raise

# ==========================
# Video Processing
# ==========================
def process_video(video_path):
    """Run the module-level YOLO ``model`` on every frame of a video.

    Every box the model returns is recorded as a violation. Class ids are
    mapped to names through ``VIOLATION_LABELS``; unknown ids become
    ``"class_<id>"``.

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.

    Returns:
        A ``(violations, score)`` tuple on success:

        - ``violations`` is a list of dicts with keys ``frame`` (0-based frame
          index), ``violation`` (label), ``confidence`` (float, 2 dp) and
          ``bounding_box`` (the box's ``xywh`` values as floats, 2 dp).
        - ``score`` is ``calculate_safety_score(violations)``.

        On any error, returns ``([], "Error: <message>")`` instead; errors are
        printed rather than raised.
    """
    video = None
    try:
        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file.")

        violations = []
        frame_count = 0
        while True:
            ret, frame = video.read()
            if not ret:
                break

            results = model(frame, device=device)
            for result in results:
                for box in result.boxes:
                    cls = int(box.cls)
                    # .tolist() gives plain Python floats (and works for CUDA
                    # tensors without an explicit .cpu()). Rounding a numpy
                    # float32 instead would leave float32 noise once widened
                    # for JSON output (e.g. 123.45999908...).
                    xywh = box.xywh[0].tolist()
                    violations.append({
                        "frame": frame_count,
                        "violation": VIOLATION_LABELS.get(cls, f"class_{cls}"),
                        "confidence": round(float(box.conf), 2),
                        "bounding_box": [round(x, 2) for x in xywh],
                    })

            frame_count += 1

        score = calculate_safety_score(violations)
        return violations, score

    except Exception as e:
        print(f"❌ Error processing video: {e}")
        return [], f"Error: {e}"
    finally:
        # Release the capture on every path, including the "could not open"
        # error and any failure during inference; previously it leaked there.
        if video is not None:
            video.release()

# ==========================
# Safety Score
# ==========================
def calculate_safety_score(violations):
    """Return an integer compliance score between 0 and 100.

    The score starts at 100. A fixed penalty is subtracted for every entry in
    *violations*, looked up by the entry's ``"violation"`` label; unrecognised
    labels cost nothing. Each entry counts separately, so repeated detections
    of the same violation are penalised repeatedly. The result is clamped so
    it never drops below 0.
    """
    penalty_for = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 25,
    }
    total_penalty = sum(penalty_for.get(entry["violation"], 0) for entry in violations)
    return max(100 - total_penalty, 0)

# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: analyse an uploaded video and format both outputs.

    Args:
        video_file: The value supplied by the ``gr.Video`` input, presumably a
            local file path; falsy when nothing was uploaded.

    Returns:
        ``(violations, score_text)``, matching the interface's JSON and Textbox
        outputs. ``violations`` is always a list. ``score_text`` is either
        ``"Safety Score: N%"``, an upload prompt, or the error string returned
        by ``process_video``.
    """
    if not video_file:
        # Keep the JSON panel a list and show the prompt in the text box.
        return [], "Please upload a video file."

    violations, score = process_video(video_file)
    if isinstance(score, str):
        # process_video signals failure with an "Error: ..." string; show it
        # verbatim instead of rendering "Safety Score: Error: ...%".
        return violations, score
    return violations, f"Safety Score: {score}%"

# Single-page Gradio app: one video upload in, two outputs out. The order of
# `outputs` must match the order of the 2-tuple returned by gradio_interface.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.JSON(label="Detected Safety Violations"),  # first value returned by gradio_interface
        gr.Textbox(label="Compliance Score")  # second value returned by gradio_interface
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture)."
)

# Start the Gradio server only when executed as a script. Importing the module
# still runs the top-level setup above (device selection, model load), but
# does not launch the server.
if __name__ == "__main__":
    print("🚀 Launching Safety Analyzer App...")
    interface.launch()