# AI_Safety_Demo5 / app.py
import os
import sys
import logging
import cv2
import gradio as gr
import torch
import numpy as np
import asyncio
import tempfile
import shutil
from ultralytics import YOLO
from tracker import BYTETracker
from utils import (
preprocess_frame, draw_detections, calculate_safety_score,
generate_violation_pdf, clean_output_directory
)
from config import CONFIG, check_ffmpeg
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='{"time": "%(asctime)s", "level": "%(levelname)s", "message": "%(message)s"}',
handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
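# Probe for ffmpeg availability once at startup and pick the inference device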
FFMPEG_AVAILABLE = check_ffmpeg()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Load YOLO model
def load_model():
static_dir = "static"
os.makedirs(static_dir, exist_ok=True)
model_path = CONFIG["MODEL_PATH"]
fallback_model = CONFIG["FALLBACK_MODEL"]
if not os.path.isfile(model_path):
logger.warning(f"Custom model {model_path} not found. Using fallback.")
if not os.path.isfile(fallback_model):
logger.info(f"Downloading fallback model: {fallback_model}")
torch.hub.download_url_to_file(
'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt',
fallback_model
)
else:
logger.info(f"Using custom model: {model_path}")
model_path = model_path if os.path.isfile(model_path) else fallback_model
model = YOLO(model_path).to(device)
if device.type == "cuda":
model.model.half()
logger.info(f"Model loaded: {model_path}, classes: {model.names}")
return model
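# Load the model once at import time so every request reuses the same instance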
model = load_model()
async def process_video(video_data, temp_dir, progress=gr.Progress()):
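    """Detect and track safety violations in a video.

    Yields a (violation_table_md, score_text, snapshots_md, pdf_path) tuple
    matching the Gradio outputs; errors are yielded in the same shape.
    """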
output_dir = os.path.join("static", "output")
os.makedirs(output_dir, exist_ok=True)
clean_output_directory() # Clean old files
video_path = None
try:
if not video_data:
raise ValueError("Empty video data provided.")
logger.info(f"Received video data size: {len(video_data)} bytes")
with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
temp_file.write(video_data)
video_path = temp_file.name
logger.info(f"Video saved to: {video_path}")
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise ValueError("Could not open video file.")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS) or 30
width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
logger.info(f"Video: {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
tracker = BYTETracker(
track_thresh=CONFIG["TRACK_THRESH"],
track_buffer=CONFIG["TRACK_BUFFER"],
match_thresh=CONFIG["MATCH_THRESH"],
frame_rate=fps
)
violations = []
snapshots = []
worker_id_mapping = {}
unique_violations = {}
violation_frames = {}
worker_counter = 1
processed_frames = 0
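        # Sample frames in batches: decode one frame, then grab (without decoding) FRAME_SKIP - 1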
while processed_frames < total_frames:
batch_frames = []
batch_indices = []
for _ in range(CONFIG["BATCH_SIZE"]):
frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
if frame_idx >= total_frames:
break
ret, frame = cap.read()
if not ret:
logger.warning(f"Failed to read frame {frame_idx}")
break
frame = preprocess_frame(frame)
batch_frames.append(frame)
batch_indices.append(frame_idx)
processed_frames += CONFIG["FRAME_SKIP"]
for _ in range(CONFIG["FRAME_SKIP"] - 1):
cap.grab()
if not batch_frames:
break
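            # Stack the batch into a BCHW float tensor scaled to [0, 1]; half precision on GPU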
batch_frames_np = np.array(batch_frames)
batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
batch_frames_tensor = batch_frames_tensor.to(device)
if device.type == "cuda":
batch_frames_tensor = batch_frames_tensor.half()
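            # Run batched inference with a low base confidence; per-class thresholds are applied below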
results = model(batch_frames_tensor, device=device, conf=0.1, verbose=False)
progress(processed_frames / total_frames, desc=f"Processing {processed_frames}/{total_frames} frames")
            for result, frame_idx in zip(results, batch_indices):
current_time = frame_idx / fps
boxes = result.boxes
                track_inputs = [
                    {"bbox": box.xywh.cpu().numpy()[0], "conf": float(box.conf), "cls": int(box.cls)}
                    for box in boxes
                    if CONFIG["VIOLATION_LABELS"].get(int(box.cls))
                    and float(box.conf) >= CONFIG["CONFIDENCE_THRESHOLDS"].get(
                        CONFIG["VIOLATION_LABELS"][int(box.cls)], 0.25
                    )
                ]
if track_inputs:
tracked_objects = tracker.update(
np.array([t["bbox"] for t in track_inputs]),
np.array([t["conf"] for t in track_inputs]),
np.array([t["cls"] for t in track_inputs]),
current_time
)
logger.info(f"Frame {frame_idx}: Detected {len(tracked_objects)} workers")
for obj in tracked_objects:
label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']))
if not label:
continue
tracker_id = obj['id']
if tracker_id not in worker_id_mapping:
worker_id_mapping[tracker_id] = worker_counter
worker_counter += 1
worker_id = worker_id_mapping[tracker_id]
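                        # Deduplicate: record each (worker, violation type) pair only once per video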
violation_key = (worker_id, label)
if violation_key not in unique_violations:
unique_violations[violation_key] = current_time
violation_frames[violation_key] = frame_idx
                            violations.append({
                                "worker_id": worker_id,
                                "violation": label,
                                "timestamp": current_time,
                                "confidence": round(obj['score'], 2),
                                # "bbox" key assumed from the tracker's output dict; falls back to a zero box
                                "bounding_box": obj.get("bbox", [0, 0, 0, 0]),
                                "frame_idx": frame_idx
                            })
cap.release()
        # Second pass: re-open the video and save an annotated snapshot for each violation
cap = cv2.VideoCapture(video_path)
for violation in violations:
frame_idx = violation["frame_idx"]
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
ret, frame = cap.read()
if not ret:
continue
frame = preprocess_frame(frame)
snapshot_frame = draw_detections(frame, [{
"worker_id": violation["worker_id"],
"violation": violation["violation"],
"confidence": violation["confidence"],
"bounding_box": violation.get("bounding_box", [0, 0, 0, 0]),
"timestamp": violation["timestamp"]
}])
snapshot_filename = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
snapshot_path = os.path.join(output_dir, snapshot_filename)
cv2.imwrite(snapshot_path, snapshot_frame, [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
snapshots.append({
"violation": violation["violation"],
"worker_id": violation["worker_id"],
"timestamp": violation["timestamp"],
"snapshot_path": snapshot_path,
"snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
"confidence": violation["confidence"]
})
cap.release()
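        # Build report artifacts: safety score, PDF, and a markdown table of violations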
score = calculate_safety_score(violations)
pdf_path = await generate_violation_pdf(violations, score, output_dir)
violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n|-----------|-----------|----------|------------|\n"
for v in sorted(violations, key=lambda x: (x["worker_id"], x["timestamp"])):
violation_table += f"| {CONFIG['DISPLAY_NAMES'][v['violation']]} | {v['worker_id']} | {v['timestamp']:.2f} | {v['confidence']:.2f} |\n"
snapshots_text = "".join(
f"### {CONFIG['DISPLAY_NAMES'][s['violation']]} - Worker {s['worker_id']} at {s['timestamp']:.2f}s\n\n![Violation]({s['snapshot_url']})\n\n"
for s in snapshots
) or "No snapshots captured."
yield (
violation_table,
f"Safety Score: {score}%",
snapshots_text,
pdf_path # Return the PDF path for download
)
except Exception as e:
logger.error(f"Error processing video: {str(e)}")
yield f"Error: {str(e)}", "", "", ""
finally:
if video_path and os.path.exists(video_path):
os.remove(video_path)
if device.type == "cuda":
torch.cuda.empty_cache()
async def gradio_interface(video_file=None, stream_url=None):
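    """Gradio entry point: process an uploaded video file or a live stream URL.

    Stream input is recorded to a temporary MP4 for up to MAX_PROCESSING_TIME
    seconds and then processed the same way as an uploaded file.
    """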
# Input validation: yield error message instead of returning
if not video_file and not stream_url:
yield "Please upload a video or provide a stream URL.", "", "", ""
return # Use bare return to exit the generator
temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
try:
if video_file:
with open(video_file, "rb") as f:
video_data = f.read()
async for result in process_video(video_data, temp_dir):
yield result
elif stream_url:
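            # Record the stream to a temporary MP4, then process it like an upload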
cap = cv2.VideoCapture(stream_url)
if not cap.isOpened():
yield "Failed to open stream.", "", "", ""
return
temp_file = os.path.join(temp_dir, "stream.mp4")
writer = None
            start_time = asyncio.get_running_loop().time()
            while asyncio.get_running_loop().time() - start_time < CONFIG["MAX_PROCESSING_TIME"]:
                ret, frame = cap.read()
                if not ret:
                    break
                await asyncio.sleep(0)  # yield control so the event loop stays responsive
                if writer is None:
                    frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
                    writer = cv2.VideoWriter(temp_file, cv2.VideoWriter_fourcc(*'mp4v'), 30, frame_size)
                writer.write(frame)
            cap.release()
            if writer is None:
                yield "No frames received from the stream.", "", "", None
                return
            writer.release()
            with open(temp_file, "rb") as f:
                video_data = f.read()
async for result in process_video(video_data, temp_dir):
yield result
except Exception as e:
logger.error(f"Error in gradio_interface: {str(e)}")
yield f"Error: {str(e)}", "", "", ""
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
# Gradio Interface
interface = gr.Interface(
fn=gradio_interface,
inputs=[
gr.Video(label="Upload Site Video"),
gr.Textbox(label="Or Enter RTSP/HTTP Stream URL", placeholder="rtsp://example.com/stream")
],
outputs=[
gr.Markdown(label="Detected Safety Violations"),
gr.Textbox(label="Compliance Score"),
gr.Markdown(label="Snapshots"),
gr.File(label="Download Violation Report PDF")
],
title="Enhanced Worksite Safety Analyzer",
description="Upload a video or provide a stream URL to detect safety violations.",
allow_flagging="never"
)
if __name__ == "__main__":
logger.info("Launching Safety Analyzer...")
interface.launch()