Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import logging | |
| import cv2 | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import asyncio | |
| import tempfile | |
| import shutil | |
| from ultralytics import YOLO | |
| from tracker import BYTETracker | |
| from utils import ( | |
| preprocess_frame, draw_detections, calculate_safety_score, | |
| generate_violation_pdf, clean_output_directory | |
| ) | |
| from config import CONFIG, check_ffmpeg | |
# Setup logging: JSON-shaped records at INFO level and above, written to stdout.
# NOTE: %(message)s is substituted verbatim (it is not JSON-escaped), so a
# message containing double quotes or newlines will not parse as strict JSON.
logging.basicConfig(
    format='{"time": "%(asctime)s", "level": "%(levelname)s", "message": "%(message)s"}',
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Probed once at import time via config.check_ffmpeg().
FFMPEG_AVAILABLE = check_ffmpeg()

# Run inference on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
# Load YOLO model
def load_model():
    """Load the YOLO detector, preferring the custom checkpoint.

    Uses ``CONFIG["MODEL_PATH"]`` when that file exists. Otherwise it uses
    ``CONFIG["FALLBACK_MODEL"]``, first downloading the ``yolov8n.pt`` release
    asset to that path if the file is missing. The model is moved to the
    module-level ``device``. On CUDA, its underlying network is cast to half
    precision.

    Side effect: ensures a ``static`` directory exists in the working directory.

    Returns:
        The loaded ``ultralytics.YOLO`` instance.
    """
    os.makedirs("static", exist_ok=True)

    custom_path = CONFIG["MODEL_PATH"]
    fallback_path = CONFIG["FALLBACK_MODEL"]

    if os.path.isfile(custom_path):
        logger.info(f"Using custom model: {custom_path}")
        chosen_path = custom_path
    else:
        logger.warning(f"Custom model {custom_path} not found. Using fallback.")
        if not os.path.isfile(fallback_path):
            logger.info(f"Downloading fallback model: {fallback_path}")
            # NOTE(review): this always fetches yolov8n.pt, whatever file name
            # FALLBACK_MODEL is configured with — confirm the two stay in sync.
            torch.hub.download_url_to_file(
                'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt',
                fallback_path,
            )
        chosen_path = fallback_path

    yolo = YOLO(chosen_path).to(device)
    if device.type == "cuda":
        yolo.model.half()
    logger.info(f"Model loaded: {chosen_path}, classes: {yolo.names}")
    return yolo


model = load_model()
def _display_name(label):
    """Return the human-readable name for a violation label, or the raw label if unmapped."""
    return CONFIG["DISPLAY_NAMES"].get(label, label)


def _detect_violations(video_path, progress):
    """Run batched YOLO detection plus BYTETracker over a video file.

    Frames are read in batches of ``CONFIG["BATCH_SIZE"]``. After each
    analysed frame, ``CONFIG["FRAME_SKIP"] - 1`` frames are skipped. A
    violation is recorded only the first time a given (worker, label) pair
    is seen.

    Args:
        video_path: Path of a video file readable by OpenCV.
        progress: Callable invoked as ``progress(fraction, desc=...)``.

    Returns:
        List of dicts with keys ``worker_id``, ``violation``, ``timestamp``
        (seconds), ``confidence`` and ``frame_idx``.

    Raises:
        ValueError: If OpenCV cannot open the file.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError("Could not open video file.")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        logger.info(f"Video: {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
        tracker = BYTETracker(
            track_thresh=CONFIG["TRACK_THRESH"],
            track_buffer=CONFIG["TRACK_BUFFER"],
            match_thresh=CONFIG["MATCH_THRESH"],
            frame_rate=fps,
        )
        labels = CONFIG["VIOLATION_LABELS"]
        thresholds = CONFIG["CONFIDENCE_THRESHOLDS"]
        frame_skip = CONFIG["FRAME_SKIP"]

        violations = []
        worker_id_mapping = {}  # tracker id -> sequential worker number starting at 1
        seen = set()  # (worker_id, label) pairs already reported
        processed_frames = 0
        while processed_frames < total_frames:
            batch_frames = []
            batch_indices = []
            for _ in range(CONFIG["BATCH_SIZE"]):
                frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                if frame_idx >= total_frames:
                    break
                ret, frame = cap.read()
                if not ret:
                    logger.warning(f"Failed to read frame {frame_idx}")
                    break
                batch_frames.append(preprocess_frame(frame))
                batch_indices.append(frame_idx)
                # Skipped frames count as processed; grab() advances without a full decode.
                processed_frames += frame_skip
                for _ in range(frame_skip - 1):
                    cap.grab()
            if not batch_frames:
                break

            # NOTE(review): assumes preprocess_frame returns equally sized HxWxC frames
            # with 0-255 values; they are stacked into an NCHW tensor scaled to [0, 1].
            batch_tensor = torch.from_numpy(np.array(batch_frames)).permute(0, 3, 1, 2).float() / 255.0
            batch_tensor = batch_tensor.to(device)
            if device.type == "cuda":
                batch_tensor = batch_tensor.half()
            results = model(batch_tensor, device=device, conf=0.1, verbose=False)

            # processed_frames overshoots total_frames on the last batch; clamp so
            # the progress bar never exceeds 100%.
            shown = min(processed_frames, total_frames)
            progress(shown / total_frames, desc=f"Processing {shown}/{total_frames} frames")

            for result, frame_idx in zip(results, batch_indices):
                current_time = frame_idx / fps
                track_inputs = []
                for box in result.boxes:
                    cls_id = int(box.cls)
                    label = labels.get(cls_id)
                    if not label:
                        continue
                    conf = float(box.conf)
                    if conf < thresholds.get(label, 0.25):
                        continue
                    track_inputs.append({"bbox": box.xywh.cpu().numpy()[0], "conf": conf, "cls": cls_id})
                if not track_inputs:
                    continue

                tracked_objects = tracker.update(
                    np.array([t["bbox"] for t in track_inputs]),
                    np.array([t["conf"] for t in track_inputs]),
                    np.array([t["cls"] for t in track_inputs]),
                    current_time,
                )
                logger.info(f"Frame {frame_idx}: Detected {len(tracked_objects)} workers")
                for obj in tracked_objects:
                    label = labels.get(int(obj['cls']))
                    if not label:
                        continue
                    # len() is evaluated before insertion, giving 1, 2, 3, ... in first-seen order.
                    worker_id = worker_id_mapping.setdefault(obj['id'], len(worker_id_mapping) + 1)
                    key = (worker_id, label)
                    if key in seen:
                        continue
                    seen.add(key)
                    violations.append({
                        "worker_id": worker_id,
                        "violation": label,
                        "timestamp": current_time,
                        "confidence": round(obj['score'], 2),
                        "frame_idx": frame_idx,
                    })
        return violations
    finally:
        cap.release()


def _capture_snapshots(video_path, violations, output_dir):
    """Write one annotated JPEG per violation into ``output_dir`` and return their metadata.

    Violations whose frame cannot be re-read are skipped.
    """
    snapshots = []
    cap = cv2.VideoCapture(video_path)
    try:
        for violation in violations:
            cap.set(cv2.CAP_PROP_POS_FRAMES, violation["frame_idx"])
            ret, frame = cap.read()
            if not ret:
                continue
            frame = preprocess_frame(frame)
            # NOTE(review): violations never carry a "bounding_box" key, so the
            # placeholder [0, 0, 0, 0] is always drawn; the tracker output's box
            # key is not visible here — TODO record it during detection.
            snapshot_frame = draw_detections(frame, [{
                "worker_id": violation["worker_id"],
                "violation": violation["violation"],
                "confidence": violation["confidence"],
                "bounding_box": violation.get("bounding_box", [0, 0, 0, 0]),
                "timestamp": violation["timestamp"],
            }])
            snapshot_filename = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
            snapshot_path = os.path.join(output_dir, snapshot_filename)
            cv2.imwrite(snapshot_path, snapshot_frame, [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
            snapshots.append({
                "violation": violation["violation"],
                "worker_id": violation["worker_id"],
                "timestamp": violation["timestamp"],
                "snapshot_path": snapshot_path,
                "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
                "confidence": violation["confidence"],
            })
    finally:
        cap.release()
    return snapshots


def _format_violation_table(violations):
    """Render violations as a Markdown table sorted by worker then time."""
    rows = [
        "| Violation | Worker ID | Time (s) | Confidence |",
        "|-----------|-----------|----------|------------|",
    ]
    rows.extend(
        f"| {_display_name(v['violation'])} | {v['worker_id']} | {v['timestamp']:.2f} | {v['confidence']:.2f} |"
        for v in sorted(violations, key=lambda x: (x["worker_id"], x["timestamp"]))
    )
    return "\n".join(rows) + "\n"


def _format_snapshots(snapshots):
    """Render each snapshot as a Markdown heading followed by its image."""
    parts = []
    for s in snapshots:
        name = _display_name(s['violation'])
        parts.append(
            f"### {name} - Worker {s['worker_id']} at {s['timestamp']:.2f}s\n\n"
            f"![{name}]({s['snapshot_url']})\n\n"
        )
    return "".join(parts) or "No snapshots captured."


async def process_video(video_data, temp_dir, progress=gr.Progress()):
    """Analyse raw video bytes for safety violations and yield a report.

    The bytes are written to a temporary ``.mp4`` inside ``temp_dir``. That
    file is removed when processing finishes. Snapshots and the PDF report
    are written to ``static/output``, which is cleaned first.

    Args:
        video_data: Encoded video file contents.
        temp_dir: Directory in which to create the temporary video file.
        progress: Gradio progress tracker.

    Yields:
        Exactly one 4-tuple ``(violation_table_md, score_text, snapshots_md,
        pdf_path)``. On any failure, yields ``("Error: <msg>", "", "", None)``;
        ``None`` clears the ``gr.File`` output.
    """
    output_dir = os.path.join("static", "output")
    os.makedirs(output_dir, exist_ok=True)
    clean_output_directory()  # Clean old files
    video_path = None
    try:
        if not video_data:
            raise ValueError("Empty video data provided.")
        logger.info(f"Received video data size: {len(video_data)} bytes")
        with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
            temp_file.write(video_data)
            video_path = temp_file.name
        logger.info(f"Video saved to: {video_path}")

        violations = _detect_violations(video_path, progress)
        snapshots = _capture_snapshots(video_path, violations, output_dir)

        score = calculate_safety_score(violations)
        pdf_path = await generate_violation_pdf(violations, score, output_dir)
        yield (
            _format_violation_table(violations),
            f"Safety Score: {score}%",
            _format_snapshots(snapshots),
            pdf_path,  # Return the PDF path for download
        )
    except Exception as e:
        logger.exception(f"Error processing video: {str(e)}")
        yield f"Error: {str(e)}", "", "", None
    finally:
        if video_path and os.path.exists(video_path):
            os.remove(video_path)
        if device.type == "cuda":
            torch.cuda.empty_cache()
async def gradio_interface(video_file=None, stream_url=None):
    """Gradio handler: analyse an uploaded video or a time-boxed stream capture.

    An uploaded file takes precedence over a stream URL. For a stream, frames
    are recorded into ``stream.mp4`` in a private temp directory. Recording
    stops after ``CONFIG["MAX_PROCESSING_TIME"]`` seconds or when the stream
    stops delivering frames, whichever comes first. The recording is then
    analysed with ``process_video``. The temp directory is always removed.

    Args:
        video_file: Path to the uploaded video, or ``None``.
        stream_url: RTSP/HTTP stream URL, or ``None``/empty.

    Yields:
        4-tuples ``(violations_md, score_text, snapshots_md, pdf_path)``. On
        failure, the first element carries the message and ``pdf_path`` is
        ``None``, so the File output is cleared.
    """
    if isinstance(stream_url, str):
        stream_url = stream_url.strip()
    # Input validation: this is an async generator, so the problem is yielded
    # and a bare return ends the generator.
    if not video_file and not stream_url:
        yield "Please upload a video or provide a stream URL.", "", "", None
        return
    temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
    try:
        if video_file:
            with open(video_file, "rb") as f:
                video_data = f.read()
        else:
            cap = cv2.VideoCapture(stream_url)
            if not cap.isOpened():
                cap.release()
                yield "Failed to open stream.", "", "", None
                return
            temp_file = os.path.join(temp_dir, "stream.mp4")
            writer = None
            loop = asyncio.get_running_loop()
            deadline = loop.time() + CONFIG["MAX_PROCESSING_TIME"]
            try:
                while loop.time() < deadline:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    if writer is None:
                        # Size the writer from the decoded frame. Stream properties may
                        # be unreported (0), and VideoWriter may silently skip frames
                        # whose size differs from the declared one.
                        frame_height, frame_width = frame.shape[:2]
                        # NOTE(review): the output is always tagged 30 FPS, regardless
                        # of the stream's real rate.
                        writer = cv2.VideoWriter(
                            temp_file, cv2.VideoWriter_fourcc(*'mp4v'), 30, (frame_width, frame_height)
                        )
                    writer.write(frame)
            finally:
                cap.release()
                if writer is not None:
                    writer.release()
            if writer is None:
                # The stream opened but produced no frames; there is nothing to analyse.
                yield "No frames could be read from the stream.", "", "", None
                return
            with open(temp_file, "rb") as f:
                video_data = f.read()
        async for result in process_video(video_data, temp_dir):
            yield result
    except Exception as e:
        logger.exception(f"Error in gradio_interface: {str(e)}")
        yield f"Error: {str(e)}", "", "", None
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
# Gradio Interface: two optional inputs (an uploaded file or a stream URL) feed
# gradio_interface, whose 4-tuples fill the four output components in order.
interface = gr.Interface(
    fn=gradio_interface,
    title="Enhanced Worksite Safety Analyzer",
    description="Upload a video or provide a stream URL to detect safety violations.",
    inputs=[
        gr.Video(label="Upload Site Video"),
        gr.Textbox(
            label="Or Enter RTSP/HTTP Stream URL",
            placeholder="rtsp://example.com/stream",
        ),
    ],
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.File(label="Download Violation Report PDF"),
    ],
    allow_flagging="never",
)


if __name__ == "__main__":
    logger.info("Launching Safety Analyzer...")
    interface.launch()