Spaces:
Sleeping
Sleeping
File size: 11,876 Bytes
08fbbb7 d047708 dc66d57 b0446d7 a87de62 71de8e6 9f5a84e 71de8e6 059a4f4 71de8e6 c92dc21 71de8e6 d047708 9f5a84e a87de62 059a4f4 ebfa2db 059a4f4 ebfa2db 059a4f4 ebfa2db 059a4f4 ebfa2db 059a4f4 a87de62 71de8e6 059a4f4 54ff363 059a4f4 71de8e6 a87de62 9f5a84e 068739f 54ff363 068739f 71de8e6 068739f 059a4f4 7ecb9d9 d047708 71de8e6 9f5a84e d047708 71de8e6 02b45d6 af9bf78 15feb42 71de8e6 7ecb9d9 d047708 7ecb9d9 c6381a2 7ecb9d9 c6381a2 7ecb9d9 71de8e6 7ecb9d9 220ca2f c6381a2 71de8e6 7ecb9d9 a87de62 7ecb9d9 71de8e6 9f5a84e 71de8e6 220ca2f e56bcc5 c6381a2 71de8e6 ebfa2db 71de8e6 059a4f4 71de8e6 02b45d6 a3d6280 7ecb9d9 059a4f4 220ca2f 71de8e6 059a4f4 71de8e6 059a4f4 71de8e6 220ca2f 7ecb9d9 059a4f4 7ecb9d9 71de8e6 d047708 71de8e6 7ecb9d9 12dad16 7ecb9d9 a87de62 c6381a2 059a4f4 12dad16 7ecb9d9 a87de62 71de8e6 059a4f4 9f5a84e 068739f 71de8e6 02b45d6 a87de62 71de8e6 ffe0631 059a4f4 ffe0631 059a4f4 71de8e6 7ecb9d9 71de8e6 ebfa2db 71de8e6 059a4f4 71de8e6 ebfa2db 71de8e6 ffe0631 54ff363 71de8e6 7ecb9d9 059a4f4 c6381a2 71de8e6 c6381a2 059a4f4 c6381a2 71de8e6 059a4f4 c6381a2 ba9ee16 71de8e6 c6381a2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 
253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 | import os
import sys
import logging
import cv2
import gradio as gr
import torch
import numpy as np
import asyncio
import tempfile
import shutil
from ultralytics import YOLO
from tracker import BYTETracker
from utils import (
preprocess_frame, draw_detections, calculate_safety_score,
generate_violation_pdf, clean_output_directory
)
from config import CONFIG, check_ffmpeg
# Setup logging
import json


class JsonLogFormatter(logging.Formatter):
    """Format each log record as one single-line JSON object.

    Emits the same "time", "level" and "message" keys as the previous
    ``'{"time": "%(asctime)s", ...}'`` format string, but builds the line with
    ``json.dumps`` so messages containing quotes, backslashes (e.g. Windows
    paths) or newlines still produce valid JSON. Exception tracebacks (from
    ``logger.exception`` or ``exc_info=True``) are stored under "exc_info"
    instead of being appended after the JSON object.
    """

    def format(self, record):
        payload = {
            # formatTime() with no datefmt matches the %(asctime)s rendering.
            "time": self.formatTime(record),
            "level": record.levelname,
            "message": record.getMessage(),
        }
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        return json.dumps(payload)


_log_handler = logging.StreamHandler(sys.stdout)
_log_handler.setFormatter(JsonLogFormatter())
# basicConfig keeps a formatter that is already set on a supplied handler.
logging.basicConfig(level=logging.INFO, handlers=[_log_handler])
logger = logging.getLogger(__name__)
# Result of config.check_ffmpeg(); NOTE(review): not read anywhere in this file's visible code.
FFMPEG_AVAILABLE = check_ffmpeg()

# Run on the default CUDA device when PyTorch reports one is available, else on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")
# Load YOLO model
def load_model():
    """Create the YOLO detector used by the whole app.

    Makes sure the ``static`` directory exists, then loads the weights at
    ``CONFIG["MODEL_PATH"]`` when that file exists. Otherwise it logs a warning,
    downloads the yolov8n weights to ``CONFIG["FALLBACK_MODEL"]`` (only when that
    file is not already present) and loads those. The model is moved to the
    module-level ``device``; on CUDA its underlying network is cast to half
    precision.

    Returns:
        The loaded ``ultralytics.YOLO`` instance.
    """
    os.makedirs("static", exist_ok=True)

    custom_path = CONFIG["MODEL_PATH"]
    fallback_path = CONFIG["FALLBACK_MODEL"]

    if os.path.isfile(custom_path):
        logger.info(f"Using custom model: {custom_path}")
        chosen_path = custom_path
    else:
        logger.warning(f"Custom model {custom_path} not found. Using fallback.")
        if not os.path.isfile(fallback_path):
            logger.info(f"Downloading fallback model: {fallback_path}")
            torch.hub.download_url_to_file(
                'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt',
                fallback_path,
            )
        chosen_path = fallback_path

    detector = YOLO(chosen_path).to(device)
    if device.type == "cuda":
        detector.model.half()
    logger.info(f"Model loaded: {chosen_path}, classes: {detector.names}")
    return detector


model = load_model()
def _display_name(label):
    """Return the human-readable name for a violation label, or the label itself if unmapped."""
    return CONFIG["DISPLAY_NAMES"].get(label, label)


def _capture_snapshots(video_path, violations, output_dir):
    """Save one annotated JPEG per violation and return their metadata dicts.

    Re-opens ``video_path``, seeks to each violation's ``frame_idx``, runs
    ``preprocess_frame`` + ``draw_detections`` on it and writes the image into
    ``output_dir``. Violations whose frame cannot be read are skipped.
    """
    snapshots = []
    cap = cv2.VideoCapture(video_path)
    try:
        for violation in violations:
            cap.set(cv2.CAP_PROP_POS_FRAMES, violation["frame_idx"])
            ret, frame = cap.read()
            if not ret:
                continue
            frame = preprocess_frame(frame)
            snapshot_frame = draw_detections(frame, [{
                "worker_id": violation["worker_id"],
                "violation": violation["violation"],
                "confidence": violation["confidence"],
                # NOTE(review): violations are recorded without a "bounding_box",
                # so the [0, 0, 0, 0] placeholder is always what gets drawn.
                "bounding_box": violation.get("bounding_box", [0, 0, 0, 0]),
                "timestamp": violation["timestamp"],
            }])
            snapshot_filename = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
            snapshot_path = os.path.join(output_dir, snapshot_filename)
            cv2.imwrite(snapshot_path, snapshot_frame, [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
            snapshots.append({
                "violation": violation["violation"],
                "worker_id": violation["worker_id"],
                "timestamp": violation["timestamp"],
                "snapshot_path": snapshot_path,
                "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
                "confidence": violation["confidence"],
            })
    finally:
        cap.release()
    return snapshots


def _format_report(violations, snapshots):
    """Build the Markdown violation table and the snapshot headings text."""
    rows = [
        "| Violation | Worker ID | Time (s) | Confidence |",
        "|-----------|-----------|----------|------------|",
    ]
    for v in sorted(violations, key=lambda x: (x["worker_id"], x["timestamp"])):
        rows.append(
            f"| {_display_name(v['violation'])} | {v['worker_id']} | {v['timestamp']:.2f} | {v['confidence']:.2f} |"
        )
    violation_table = "\n".join(rows) + "\n"
    # NOTE(review): snapshot_url is computed for every snapshot but not embedded
    # here, so this Markdown shows headings only — confirm that is intended.
    snapshots_text = "".join(
        f"### {_display_name(s['violation'])} - Worker {s['worker_id']} at {s['timestamp']:.2f}s\n\n\n\n"
        for s in snapshots
    ) or "No snapshots captured."
    return violation_table, snapshots_text


async def process_video(video_data, temp_dir, progress=gr.Progress()):
    """Detect safety violations in an encoded video and yield a report.

    Writes ``video_data`` to a temporary ``.mp4`` inside ``temp_dir``. It then runs
    the YOLO model on every ``CONFIG["FRAME_SKIP"]``-th frame, in batches of
    ``CONFIG["BATCH_SIZE"]``, and feeds detections whose class is in
    ``CONFIG["VIOLATION_LABELS"]`` (and above that label's confidence threshold)
    to BYTETracker. Only the first occurrence of each (worker, label) pair is
    recorded. Tracker IDs are renumbered 1, 2, ... in first-seen order. Finally it
    saves snapshots, computes the safety score and generates the PDF report.

    Args:
        video_data: raw bytes of the video file.
        temp_dir: directory for the temporary video copy (deleted afterwards).
        progress: Gradio progress tracker.

    Yields:
        One 4-tuple: (violation Markdown table, "Safety Score: N%",
        snapshot Markdown, PDF path returned by ``generate_violation_pdf``).
        On failure: ("Error: ...", "", "", None).
    """
    output_dir = os.path.join("static", "output")
    os.makedirs(output_dir, exist_ok=True)
    clean_output_directory()  # Clean old files
    video_path = None
    cap = None
    try:
        if not video_data:
            raise ValueError("Empty video data provided.")
        logger.info(f"Received video data size: {len(video_data)} bytes")
        with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
            temp_file.write(video_data)
            video_path = temp_file.name
        logger.info(f"Video saved to: {video_path}")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Could not open video file.")
        # Some containers report 0 (or a negative value) here; in that case we
        # read until the decoder stops returning frames.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        logger.info(f"Video: {total_frames} frames, {fps:.1f} FPS, {width}x{height}")

        tracker = BYTETracker(
            track_thresh=CONFIG["TRACK_THRESH"],
            track_buffer=CONFIG["TRACK_BUFFER"],
            match_thresh=CONFIG["MATCH_THRESH"],
            frame_rate=fps
        )
        frame_skip = CONFIG["FRAME_SKIP"]
        batch_size = CONFIG["BATCH_SIZE"]
        violation_labels = CONFIG["VIOLATION_LABELS"]
        confidence_thresholds = CONFIG["CONFIDENCE_THRESHOLDS"]

        violations = []
        worker_id_mapping = {}     # tracker id -> sequential worker id (1-based)
        seen_violations = set()    # (worker_id, label) pairs already reported
        processed_frames = 0

        while True:
            batch_frames = []
            batch_indices = []
            for _ in range(batch_size):
                frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                if total_frames > 0 and frame_idx >= total_frames:
                    break
                ret, frame = cap.read()
                if not ret:
                    if total_frames > 0:
                        logger.warning(f"Failed to read frame {frame_idx}")
                    break
                batch_frames.append(preprocess_frame(frame))
                batch_indices.append(frame_idx)
                processed_frames += frame_skip
                for _ in range(frame_skip - 1):
                    cap.grab()
            if not batch_frames:
                break

            # NOTE(review): ultralytics treats tensor input as BCHW, RGB, float in
            # [0, 1], with H/W divisible by the model stride — confirm that
            # preprocess_frame guarantees channel order and size.
            batch_tensor = torch.from_numpy(np.stack(batch_frames)).permute(0, 3, 1, 2).float() / 255.0
            batch_tensor = batch_tensor.to(device)
            if device.type == "cuda":
                batch_tensor = batch_tensor.half()
            results = model(batch_tensor, device=device, conf=0.1, verbose=False)

            if total_frames > 0:
                # processed_frames advances by FRAME_SKIP per read frame, so it can
                # overshoot total_frames on the last batch; clamp for the UI.
                shown = min(processed_frames, total_frames)
                progress(shown / total_frames, desc=f"Processing {shown}/{total_frames} frames")

            for result, frame_idx in zip(results, batch_indices):
                current_time = frame_idx / fps
                track_inputs = []
                for box in result.boxes:
                    cls_id = int(box.cls)
                    label = violation_labels.get(cls_id)
                    if not label:
                        continue
                    conf = float(box.conf)
                    if conf < confidence_thresholds.get(label, 0.25):
                        continue
                    track_inputs.append((box.xywh.cpu().numpy()[0], conf, cls_id))
                if not track_inputs:
                    continue

                bboxes, confs, classes = zip(*track_inputs)
                tracked_objects = tracker.update(
                    np.array(bboxes),
                    np.array(confs),
                    np.array(classes),
                    current_time
                )
                logger.info(f"Frame {frame_idx}: Detected {len(tracked_objects)} workers")
                for obj in tracked_objects:
                    label = violation_labels.get(int(obj['cls']))
                    if not label:
                        continue
                    worker_id = worker_id_mapping.setdefault(obj['id'], len(worker_id_mapping) + 1)
                    violation_key = (worker_id, label)
                    if violation_key in seen_violations:
                        continue
                    seen_violations.add(violation_key)
                    violations.append({
                        "worker_id": worker_id,
                        "violation": label,
                        "timestamp": current_time,
                        "confidence": round(obj['score'], 2),
                        "frame_idx": frame_idx
                    })

            # Inference above is synchronous; give the event loop a turn per batch.
            await asyncio.sleep(0)

        cap.release()
        cap = None

        snapshots = _capture_snapshots(video_path, violations, output_dir)
        score = calculate_safety_score(violations)
        pdf_path = await generate_violation_pdf(violations, score, output_dir)
        violation_table, snapshots_text = _format_report(violations, snapshots)
        yield (
            violation_table,
            f"Safety Score: {score}%",
            snapshots_text,
            pdf_path  # Return the PDF path for download
        )
    except Exception as e:
        logger.exception(f"Error processing video: {str(e)}")
        # None (not "") is the empty value for the gr.File output.
        yield f"Error: {str(e)}", "", "", None
    finally:
        if cap is not None:
            cap.release()
        if video_path and os.path.exists(video_path):
            os.remove(video_path)
        if device.type == "cuda":
            torch.cuda.empty_cache()
async def gradio_interface(video_file=None, stream_url=None):
    """Gradio handler: analyse an uploaded video file or a short recording of a stream.

    If ``video_file`` (a file path from ``gr.Video``) is given it is analysed
    directly. Otherwise frames are read from ``stream_url`` for at most
    ``CONFIG["MAX_PROCESSING_TIME"]`` seconds, recorded to a temporary mp4 and
    analysed. All temporary files live in a per-call temp directory that is
    removed afterwards.

    Yields:
        The 4-tuples produced by ``process_video``, or a single
        (message, "", "", None) tuple for validation/stream/processing errors.
    """
    # Treat a whitespace-only URL as "no URL".
    stream_url = (stream_url or "").strip()
    if not video_file and not stream_url:
        yield "Please upload a video or provide a stream URL.", "", "", None
        return  # Use bare return to exit the generator

    temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
    cap = None
    writer = None
    try:
        if video_file:
            with open(video_file, "rb") as f:
                video_data = f.read()
        else:
            cap = cv2.VideoCapture(stream_url)
            if not cap.isOpened():
                yield "Failed to open stream.", "", "", None
                return

            # Record at the stream's reported rate so process_video's
            # frame_idx / fps timestamps match wall-clock time. Fall back to 30
            # when the backend reports 0, NaN or an implausible value.
            stream_fps = cap.get(cv2.CAP_PROP_FPS)
            if not (0 < stream_fps <= 120):
                stream_fps = 30

            temp_file = os.path.join(temp_dir, "stream.mp4")
            loop = asyncio.get_running_loop()
            deadline = loop.time() + CONFIG["MAX_PROCESSING_TIME"]
            while loop.time() < deadline:
                ret, frame = cap.read()
                if not ret:
                    break
                if writer is None:
                    # Size the writer from the actual frame; CAP_PROP_FRAME_WIDTH/HEIGHT
                    # may be 0 for streams, and VideoWriter drops mismatched frames.
                    frame_height, frame_width = frame.shape[:2]
                    writer = cv2.VideoWriter(
                        temp_file, cv2.VideoWriter_fourcc(*'mp4v'), stream_fps, (frame_width, frame_height)
                    )
                writer.write(frame)
                await asyncio.sleep(0)  # let other tasks run between blocking reads
            cap.release()
            cap = None

            if writer is None:
                yield "No frames received from stream.", "", "", None
                return
            writer.release()
            writer = None
            with open(temp_file, "rb") as f:
                video_data = f.read()

        async for result in process_video(video_data, temp_dir):
            yield result
    except Exception as e:
        logger.exception(f"Error in gradio_interface: {str(e)}")
        yield f"Error: {str(e)}", "", "", None
    finally:
        if writer is not None:
            writer.release()
        if cap is not None:
            cap.release()
        shutil.rmtree(temp_dir, ignore_errors=True)
# Gradio Interface: a video upload or a stream URL goes in; the violation table,
# score text, snapshot Markdown and PDF report come out, in the same order as the
# 4-tuples yielded by gradio_interface.
interface = gr.Interface(
    fn=gradio_interface,
    title="Enhanced Worksite Safety Analyzer",
    description="Upload a video or provide a stream URL to detect safety violations.",
    inputs=[
        gr.Video(label="Upload Site Video"),
        gr.Textbox(
            label="Or Enter RTSP/HTTP Stream URL",
            placeholder="rtsp://example.com/stream",
        ),
    ],
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.File(label="Download Violation Report PDF"),
    ],
    allow_flagging="never",
)


def _launch():
    """Log startup and serve the Gradio app."""
    logger.info("Launching Safety Analyzer...")
    interface.launch()


if __name__ == "__main__":
    _launch()