Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,16 +2,17 @@ import os
|
|
| 2 |
import cv2
|
| 3 |
import gradio as gr
|
| 4 |
import torch
|
|
|
|
| 5 |
from ultralytics import YOLO
|
| 6 |
import time
|
| 7 |
-
import logging
|
| 8 |
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
| 9 |
|
| 10 |
# ==========================
|
| 11 |
# Optimized Configuration
|
| 12 |
# ==========================
|
| 13 |
CONFIG = {
|
| 14 |
-
"MODEL_PATH": "yolov8_safety.pt",
|
| 15 |
"OUTPUT_DIR": "static/output",
|
| 16 |
"VIOLATION_LABELS": {
|
| 17 |
0: "no_helmet",
|
|
@@ -20,11 +21,18 @@ CONFIG = {
|
|
| 20 |
3: "unsafe_zone",
|
| 21 |
4: "improper_tool_use"
|
| 22 |
},
|
| 23 |
-
"
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
}
|
| 29 |
|
| 30 |
# Setup logging
|
|
@@ -38,9 +46,27 @@ logger.info(f"Using device: {device}")
|
|
| 38 |
# Load model
|
| 39 |
model = YOLO(CONFIG["MODEL_PATH"]).to(device)
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def process_frame(frame, frame_count, fps):
|
| 42 |
-
"""Process a single frame for violations"""
|
| 43 |
-
results = model(frame, device=device, verbose=False)
|
| 44 |
detections = []
|
| 45 |
|
| 46 |
for result in results:
|
|
@@ -51,21 +77,22 @@ def process_frame(frame, frame_count, fps):
|
|
| 51 |
detections.append({
|
| 52 |
"frame": frame_count,
|
| 53 |
"violation": CONFIG["VIOLATION_LABELS"][cls],
|
| 54 |
-
"confidence": conf,
|
| 55 |
"bounding_box": box.xywh.cpu().numpy()[0],
|
| 56 |
"timestamp": frame_count / fps
|
| 57 |
})
|
| 58 |
return detections
|
| 59 |
|
| 60 |
def process_video(video_path):
|
| 61 |
-
"""Optimized video processing with parallel
|
| 62 |
start_time = time.time()
|
| 63 |
cap = cv2.VideoCapture(video_path)
|
| 64 |
fps = cap.get(cv2.CAP_PROP_FPS) or 30
|
| 65 |
violations = []
|
|
|
|
| 66 |
frame_count = 0
|
| 67 |
|
| 68 |
-
with ThreadPoolExecutor() as executor:
|
| 69 |
futures = []
|
| 70 |
while cap.isOpened():
|
| 71 |
ret, frame = cap.read()
|
|
@@ -73,15 +100,34 @@ def process_video(video_path):
|
|
| 73 |
break
|
| 74 |
|
| 75 |
if frame_count % CONFIG["FRAME_SKIP"] == 0:
|
| 76 |
-
futures.append(executor.submit(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
frame_count += 1
|
| 79 |
if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
|
| 80 |
logger.info("Processing time limit reached")
|
| 81 |
break
|
| 82 |
|
|
|
|
| 83 |
for future in futures:
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
cap.release()
|
| 87 |
|
|
@@ -96,48 +142,94 @@ def process_video(video_path):
|
|
| 96 |
if violation_counts.get((v["violation"], int(v["timestamp"])), 0) >= CONFIG["MIN_VIOLATION_FRAMES"]
|
| 97 |
]
|
| 98 |
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
def analyze_video(video_file):
|
| 102 |
-
"""
|
| 103 |
if not video_file:
|
| 104 |
return "No video uploaded", "", "", ""
|
| 105 |
|
| 106 |
try:
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
# Generate simple output (removed PDF generation for speed)
|
| 111 |
-
violation_table = (
|
| 112 |
-
"| Violation Type | Timestamp (s) | Confidence |\n"
|
| 113 |
-
"|----------------|---------------|------------|\n" +
|
| 114 |
-
"\n".join(
|
| 115 |
-
f"| {v['violation']:<14} | {v['timestamp']:.1f} | {v['confidence']:.2f} |"
|
| 116 |
-
for v in violations
|
| 117 |
-
) if violations else "No violations detected."
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
return (
|
| 121 |
-
violation_table,
|
| 122 |
-
f"Processing Time: {processing_time:.1f}s",
|
| 123 |
-
f"Violations Found: {len(violations)}",
|
| 124 |
-
f"Analysis Completed in {time.time()-start_time:.1f}s"
|
| 125 |
-
)
|
| 126 |
except Exception as e:
|
|
|
|
| 127 |
return f"Error: {str(e)}", "", "", ""
|
| 128 |
|
| 129 |
-
#
|
| 130 |
interface = gr.Interface(
|
| 131 |
fn=analyze_video,
|
| 132 |
inputs=gr.Video(label="Upload Site Video"),
|
| 133 |
outputs=[
|
| 134 |
gr.Markdown("## Detected Violations"),
|
| 135 |
-
gr.Textbox(label="
|
| 136 |
-
gr.
|
| 137 |
-
gr.Textbox(label="
|
| 138 |
],
|
| 139 |
-
title="
|
| 140 |
-
description="Optimized
|
|
|
|
| 141 |
)
|
| 142 |
|
| 143 |
if __name__ == "__main__":
|
|
|
|
| 2 |
import cv2
|
| 3 |
import gradio as gr
|
| 4 |
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
from ultralytics import YOLO
|
| 7 |
import time
|
|
|
|
| 8 |
from concurrent.futures import ThreadPoolExecutor
|
| 9 |
+
import logging
|
| 10 |
|
| 11 |
# ==========================
|
| 12 |
# Optimized Configuration
|
| 13 |
# ==========================
|
| 14 |
CONFIG = {
|
| 15 |
+
"MODEL_PATH": "yolov8_safety.pt", # Your trained model
|
| 16 |
"OUTPUT_DIR": "static/output",
|
| 17 |
"VIOLATION_LABELS": {
|
| 18 |
0: "no_helmet",
|
|
|
|
| 21 |
3: "unsafe_zone",
|
| 22 |
4: "improper_tool_use"
|
| 23 |
},
|
| 24 |
+
"CLASS_COLORS": {
|
| 25 |
+
"no_helmet": (0, 0, 255), # Red
|
| 26 |
+
"no_harness": (0, 165, 255), # Orange
|
| 27 |
+
"unsafe_posture": (0, 255, 0), # Green
|
| 28 |
+
"unsafe_zone": (255, 0, 0), # Blue
|
| 29 |
+
"improper_tool_use": (255, 255, 0) # Yellow
|
| 30 |
+
},
|
| 31 |
+
"FRAME_SKIP": 8, # Process every 8th frame
|
| 32 |
+
"MAX_PROCESSING_TIME": 45, # Max processing time (seconds)
|
| 33 |
+
"CONFIDENCE_THRESHOLD": 0.35, # Balanced threshold
|
| 34 |
+
"MIN_VIOLATION_FRAMES": 2, # Reduced from 3 to 2
|
| 35 |
+
"GPU_ACCELERATION": True # Enable GPU if available
|
| 36 |
}
|
| 37 |
|
| 38 |
# Setup logging
|
|
|
|
| 46 |
# Load model
|
| 47 |
model = YOLO(CONFIG["MODEL_PATH"]).to(device)
|
| 48 |
|
| 49 |
+
def draw_detections(frame, detections):
    """Draw bounding boxes with labels on frame.

    Each detection dict must carry "violation", "confidence" and a
    center-based "bounding_box" (x, y, w, h).  The frame is annotated
    in place and also returned for convenience.
    """
    for det in detections:
        tag = det["violation"]
        score = det["confidence"]
        cx, cy, bw, bh = det["bounding_box"]
        # Convert center-based xywh to the two corner points cv2 expects.
        left, top = int(cx - bw / 2), int(cy - bh / 2)
        right, bottom = int(cx + bw / 2), int(cy + bh / 2)

        # Fall back to red when the class has no configured colour.
        colour = CONFIG["CLASS_COLORS"].get(tag, (0, 0, 255))
        cv2.rectangle(frame, (left, top), (right, bottom), colour, 2)
        cv2.putText(
            frame,
            f"{tag}: {score:.2f}",
            (left, top - 10),  # label sits just above the box
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            colour,
            2,
        )
    return frame
|
| 66 |
+
|
| 67 |
def process_frame(frame, frame_count, fps):
|
| 68 |
+
"""Process a single frame for violations (optimized)"""
|
| 69 |
+
results = model(frame, device=device, verbose=False)
|
| 70 |
detections = []
|
| 71 |
|
| 72 |
for result in results:
|
|
|
|
| 77 |
detections.append({
|
| 78 |
"frame": frame_count,
|
| 79 |
"violation": CONFIG["VIOLATION_LABELS"][cls],
|
| 80 |
+
"confidence": round(conf, 2),
|
| 81 |
"bounding_box": box.xywh.cpu().numpy()[0],
|
| 82 |
"timestamp": frame_count / fps
|
| 83 |
})
|
| 84 |
return detections
|
| 85 |
|
| 86 |
def process_video(video_path):
|
| 87 |
+
"""Optimized video processing with parallel execution"""
|
| 88 |
start_time = time.time()
|
| 89 |
cap = cv2.VideoCapture(video_path)
|
| 90 |
fps = cap.get(cv2.CAP_PROP_FPS) or 30
|
| 91 |
violations = []
|
| 92 |
+
snapshots = {}
|
| 93 |
frame_count = 0
|
| 94 |
|
| 95 |
+
with ThreadPoolExecutor(max_workers=4) as executor:
|
| 96 |
futures = []
|
| 97 |
while cap.isOpened():
|
| 98 |
ret, frame = cap.read()
|
|
|
|
| 100 |
break
|
| 101 |
|
| 102 |
if frame_count % CONFIG["FRAME_SKIP"] == 0:
|
| 103 |
+
futures.append(executor.submit(
|
| 104 |
+
process_frame,
|
| 105 |
+
frame.copy(),
|
| 106 |
+
frame_count,
|
| 107 |
+
fps
|
| 108 |
+
))
|
| 109 |
|
| 110 |
frame_count += 1
|
| 111 |
if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
|
| 112 |
logger.info("Processing time limit reached")
|
| 113 |
break
|
| 114 |
|
| 115 |
+
# Process results as they complete
|
| 116 |
for future in futures:
|
| 117 |
+
frame_detections = future.result()
|
| 118 |
+
violations.extend(frame_detections)
|
| 119 |
+
|
| 120 |
+
# Capture first occurrence of each violation type
|
| 121 |
+
for det in frame_detections:
|
| 122 |
+
if det["violation"] not in snapshots:
|
| 123 |
+
snapshots[det["violation"]] = {
|
| 124 |
+
"frame": det["frame"],
|
| 125 |
+
"timestamp": det["timestamp"],
|
| 126 |
+
"image": draw_detections(
|
| 127 |
+
cv2.cvtColor(cap.read()[1], cv2.COLOR_BGR2RGB),
|
| 128 |
+
[det]
|
| 129 |
+
)
|
| 130 |
+
}
|
| 131 |
|
| 132 |
cap.release()
|
| 133 |
|
|
|
|
| 142 |
if violation_counts.get((v["violation"], int(v["timestamp"])), 0) >= CONFIG["MIN_VIOLATION_FRAMES"]
|
| 143 |
]
|
| 144 |
|
| 145 |
+
# Prepare snapshot outputs
|
| 146 |
+
snapshot_outputs = []
|
| 147 |
+
for violation_type, data in snapshots.items():
|
| 148 |
+
snapshot_path = os.path.join(
|
| 149 |
+
CONFIG["OUTPUT_DIR"],
|
| 150 |
+
f"{violation_type}_{data['frame']}.jpg"
|
| 151 |
+
)
|
| 152 |
+
cv2.imwrite(snapshot_path, data["image"])
|
| 153 |
+
snapshot_outputs.append({
|
| 154 |
+
"violation": violation_type,
|
| 155 |
+
"frame": data["frame"],
|
| 156 |
+
"timestamp": data["timestamp"],
|
| 157 |
+
"path": snapshot_path
|
| 158 |
+
})
|
| 159 |
+
|
| 160 |
+
return {
|
| 161 |
+
"violations": filtered_violations,
|
| 162 |
+
"snapshots": snapshot_outputs,
|
| 163 |
+
"processing_time": time.time() - start_time
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
def calculate_safety_score(violations):
    """Calculate safety score (0-100).

    Each distinct violation type is penalised once, no matter how many
    times it occurs; unknown types contribute nothing.  The score is
    clamped so it never drops below zero.
    """
    # Per-type penalty weights.
    penalties = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25,
    }
    distinct_types = {entry["violation"] for entry in violations}
    deduction = sum(penalties.get(name, 0) for name in distinct_types)
    return max(100 - deduction, 0)
|
| 178 |
+
|
| 179 |
+
def format_output(result):
    """Format results for Gradio output.

    Takes the dict produced by process_video and returns the 4-tuple
    (violation table markdown, safety-score text, snapshot markdown,
    processing-time text) matching the interface's output components.
    """
    # Violation table (markdown).
    rows = result["violations"]
    if rows:
        violation_table = (
            "| Violation Type | Timestamp (s) | Confidence |\n"
            "|----------------|---------------|------------|\n"
            + "\n".join(
                f"| {row['violation']:<14} | {row['timestamp']:.1f} | {row['confidence']:.2f} |"
                for row in rows
            )
        )
    else:
        violation_table = "No violations detected."

    # Snapshot list.
    # NOTE(review): the original appended an empty literal after each entry,
    # which looks like a stripped markdown image link — confirm against the
    # deployed app whether an image reference belongs here.
    if result["snapshots"]:
        snapshots_md = "\n".join(
            f"**{snap['violation']}** at {snap['timestamp']:.1f}s: "
            for snap in result["snapshots"]
        )
    else:
        snapshots_md = "No snapshots available."

    # Safety score derived from the (deduplicated) violation types.
    score = calculate_safety_score(result["violations"])

    return (
        violation_table,
        f"Safety Score: {score}%",
        snapshots_md,
        f"Processed in {result['processing_time']:.1f}s",
    )
|
| 207 |
|
| 208 |
def analyze_video(video_file):
    """Gradio interface function.

    Runs the full pipeline on the uploaded video and returns a 4-tuple
    of strings for the four output components; on any failure the first
    element carries the error message and the rest are blank.
    """
    # Guard: nothing was uploaded.
    if not video_file:
        return "No video uploaded", "", "", ""

    try:
        report = process_video(video_file)
        return format_output(report)
    except Exception as err:
        # Broad catch is deliberate: this is the top-level UI boundary,
        # so every failure is logged and surfaced instead of crashing.
        logger.error(f"Error: {str(err)}")
        return f"Error: {str(err)}", "", "", ""
|
| 219 |
|
| 220 |
+
# Gradio Interface — wires analyze_video's 4-tuple return to four output
# components (table, score, snapshots, timing).
interface = gr.Interface(
    fn=analyze_video,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown("## Detected Violations"),
        gr.Textbox(label="Safety Score"),
        gr.Markdown("## Violation Snapshots"),
        gr.Textbox(label="Processing Info"),
    ],
    title="AI Safety Compliance Analyzer",
    description="Optimized for fast detection of safety violations",
    # NOTE(review): allow_flagging is deprecated/removed in recent Gradio
    # releases — confirm the pinned Gradio version still accepts it.
    allow_flagging="never",
)
|
| 234 |
|
| 235 |
if __name__ == "__main__":
|