Spaces:
Runtime error
Runtime error
| import os | |
| import shutil | |
| import sys | |
| import warnings | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from contextlib import asynccontextmanager | |
| import torch | |
| from ultralytics import YOLO | |
| import cv2 | |
| import requests | |
| import json | |
| import time | |
| import numpy as np | |
| from pathlib import Path | |
| from datetime import datetime | |
| import logging | |
| import pandas as pd | |
# --- Initial Configuration ---
# Silence every warning the process would otherwise emit.
warnings.filterwarnings("ignore")

# Directory exported to ultralytics through the YOLO_CONFIG_DIR environment variable.
# NOTE(review): `from ultralytics import YOLO` runs earlier in this file, and
# ultralytics appears to read YOLO_CONFIG_DIR at import time, so setting it here
# may come too late to take effect — confirm, and consider moving this setup
# above that import.
yolo_config_dir = "/tmp/Ultralytics"


def _init_yolo_config_dir(config_dir):
    """Recreate *config_dir* from scratch and export it as YOLO_CONFIG_DIR.

    Any existing directory is deleted first and the fresh one is made
    world-writable. ``sys._yolo_config_initialized`` is then set so the
    setup is not repeated. Failures are printed and re-raised, which aborts
    the import of this module.
    """
    try:
        if os.path.exists(config_dir):
            shutil.rmtree(config_dir)
        os.makedirs(config_dir, exist_ok=True)
        os.chmod(config_dir, 0o777)  # ensure the directory is writable
        os.environ["YOLO_CONFIG_DIR"] = config_dir
        sys._yolo_config_initialized = True
        print(f"YOLO_CONFIG_DIR initialized to: {config_dir}")
    except Exception as exc:
        print(f"Failed to set up YOLO_CONFIG_DIR: {exc}")
        raise


# The marker lives on the interpreter-wide `sys` module, so the directory is
# wiped and recreated at most once per process, even if this module is
# imported again.
if not hasattr(sys, '_yolo_config_initialized'):
    _init_yolo_config_dir(yolo_config_dir)
# --- Logging Configuration ---
# Root logging: INFO and above, timestamped, written to a stream handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
)

# Module logger used throughout this file.
logger = logging.getLogger(__name__)

# Minimum levels for noisy third-party libraries.
_THIRD_PARTY_LOG_LEVELS = {
    "ultralytics": logging.ERROR,
    "PIL": logging.WARNING,
    "matplotlib": logging.WARNING,
    "urllib3": logging.WARNING,
}
for _lib_name, _lib_level in _THIRD_PARTY_LOG_LEVELS.items():
    logging.getLogger(_lib_name).setLevel(_lib_level)
del _lib_name, _lib_level

# Record which config directory ended up in the environment.
_config_dir_env = os.getenv('YOLO_CONFIG_DIR')
logger.info(f"YOLO_CONFIG_DIR set to: {_config_dir_env}")
# --- Environment Variables ---
# Each setting can be overridden from the environment. The fallbacks for the
# Salesforce URL/token and the Hugging Face token are placeholders, not
# working values.
RTSP_URL = os.environ.get("RTSP_URL", "rtsp://localhost:8554/stream")
SALESFORCE_URL = os.environ.get("SALESFORCE_URL", "https://your_salesforce_instance_url")
SALESFORCE_TOKEN = os.environ.get("SALESFORCE_TOKEN", "your_salesforce_access_token")
HUGGINGFACE_API_URL = os.environ.get(
    "HUGGINGFACE_API_URL",
    "https://api-inference.huggingface.co/models/PrashanthB461/SafetyViolationAI1",
)
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", "your_huggingface_api_token")

# --- Global Model Instance ---
# Set by the FastAPI lifespan handler at startup and reset to None at shutdown;
# endpoints check it before running inference.
yolo_model = None
# --- Model Class ---
class YOLOv8Model:
    """Wrapper around an ultralytics YOLO detector.

    ``predict`` returns detections as a pandas DataFrame with one row per
    box. The columns are ``xmin, ymin, xmax, ymax, confidence, class, name``,
    the layout YOLOv5's ``results.pandas().xyxy[0]`` produced, which the
    rest of this service was written against.
    """

    def __init__(self, model_path='/app/yolov8n.pt'):
        """Load YOLO weights from *model_path*.

        Anything the loader prints to stdout is discarded.

        Raises:
            FileNotFoundError: if *model_path* does not exist.
            Exception: anything ``YOLO(model_path)`` raises, re-raised after logging.
        """
        from contextlib import redirect_stdout

        try:
            logger.info("Checking for model weights at %s", model_path)
            if not os.path.exists(model_path):
                logger.error("Model weights file %s not found", model_path)
                raise FileNotFoundError(f"Model weights file {model_path} not found")
            logger.info("Initializing YOLOv8 model")
            # redirect_stdout restores sys.stdout and the with-block closes the
            # devnull handle, even if loading fails.
            with open(os.devnull, 'w') as devnull, redirect_stdout(devnull):
                self.model = YOLO(model_path)
            logger.info("YOLOv8 model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load YOLOv8 model: {e}")
            raise

    def predict(self, image):
        """Run detection on a single image.

        Args:
            image: anything the underlying ultralytics model accepts as a
                single source (e.g. a numpy image array).

        Returns:
            pandas.DataFrame with columns ``xmin, ymin, xmax, ymax`` (box
            corners), ``confidence``, ``class`` (integer class id) and
            ``name`` (class label). It has zero rows when nothing is detected.

        Raises:
            Exception: any inference error, re-raised after logging.
        """
        try:
            # Calling an ultralytics model returns a list of Results, one per
            # input image; the previous `.pandas()` call on that list always
            # raised AttributeError.
            result = self.model(image)[0]
            boxes = result.boxes
            xyxy = boxes.xyxy.cpu().numpy()
            class_ids = [int(c) for c in boxes.cls.cpu().numpy()]
            return pd.DataFrame({
                "xmin": xyxy[:, 0],
                "ymin": xyxy[:, 1],
                "xmax": xyxy[:, 2],
                "ymax": xyxy[:, 3],
                "confidence": boxes.conf.cpu().numpy(),
                "class": class_ids,
                "name": [result.names[c] for c in class_ids],
            })
        except Exception as e:
            logger.error(f"Prediction error: {e}")
            raise
# --- Frame Processing Functions ---
def preprocess_frame(frame):
    """Return *frame* converted from BGR to RGB and resized to 640x640.

    Any OpenCV error is logged and re-raised.

    NOTE(review): ultralytics documents numpy inputs as BGR and performs its
    own aspect-preserving letterbox resize. This RGB conversion and square
    stretch may therefore degrade detections — confirm before relying on it.
    """
    try:
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return cv2.resize(rgb_frame, (640, 640))
    except Exception as exc:
        logger.error(f"Frame preprocessing error: {exc}")
        raise
def capture_rtsp_frames(rtsp_url: str):
    """Yield ``(frame, timestamp)`` pairs read from an RTSP stream.

    *timestamp* is ``datetime.utcnow().isoformat()`` taken right after each
    successful read. Iteration stops at the first failed read.

    Raises:
        ValueError: if the stream cannot be opened.
        Exception: any capture error, re-raised after logging.
    """
    try:
        logger.info(f"Attempting to connect to RTSP stream: {rtsp_url}")
        cap = cv2.VideoCapture(rtsp_url)
        # Release the capture on every exit path: failure to open, read
        # errors, normal end of stream, and the consumer abandoning the
        # generator early. Previously release only ran after a failed read.
        try:
            if not cap.isOpened():
                logger.error(f"Failed to open RTSP stream: {rtsp_url}")
                raise ValueError("RTSP stream not accessible")
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    logger.warning("Failed to read frame from RTSP stream")
                    break
                timestamp = datetime.utcnow().isoformat()
                yield frame, timestamp
        finally:
            cap.release()
    except Exception as e:
        logger.error(f"RTSP capture error: {e}")
        raise
# --- Violation Handling Functions ---
def save_snapshot(frame, snapshot_dir="/snapshots"):
    """Write *frame* as a JPEG into *snapshot_dir* and return its path.

    Args:
        frame: image array accepted by ``cv2.imwrite``.
        snapshot_dir: target directory. It is created if missing and made
            world-writable. Defaults to ``/snapshots``.

    Returns:
        ``"<snapshot_dir>/snapshot_<nanosecond timestamp>.jpg"``.

    Raises:
        OSError: if OpenCV reports that the image could not be written.
        Exception: any other error, re-raised after logging.
    """
    try:
        # Nanosecond resolution avoids the overwrites that whole-second
        # names caused when several snapshots were taken within one second.
        filename = f"snapshot_{time.time_ns()}.jpg"
        os.makedirs(snapshot_dir, exist_ok=True)
        os.chmod(snapshot_dir, 0o777)  # ensure the directory is writable
        snapshot_path = Path(snapshot_dir) / filename
        # cv2.imwrite reports failure by returning False rather than raising.
        if not cv2.imwrite(str(snapshot_path), frame):
            raise OSError(f"cv2.imwrite failed to write {snapshot_path}")
        return f"{snapshot_dir}/{filename}"
    except Exception as e:
        logger.error(f"Snapshot saving error: {e}")
        raise
def log_violation(violation_data, log_path="/snapshots/violation_logs.json"):
    """Append *violation_data* to the JSON array stored at *log_path*.

    The file is created (along with its parent directory) if it does not
    exist. The whole array is rewritten on each call.

    Args:
        violation_data: JSON-serializable record to append.
        log_path: location of the JSON log. Defaults to
            ``/snapshots/violation_logs.json``.

    Raises:
        Exception: any read/parse/write error, re-raised after logging.
    """
    try:
        log_file = Path(log_path)
        # The directory may not exist yet, e.g. before any snapshot has been
        # saved; previously the write then failed with FileNotFoundError.
        log_file.parent.mkdir(parents=True, exist_ok=True)
        logs = []
        if log_file.exists():
            with open(log_file, "r", encoding="utf-8") as f:
                logs = json.load(f)
        logs.append(violation_data)
        # Write a sibling temp file and swap it in, so a crash mid-write
        # cannot leave a truncated (unparseable) log behind.
        tmp_file = log_file.with_name(log_file.name + ".tmp")
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(logs, f, indent=4)
        tmp_file.replace(log_file)
    except Exception as e:
        logger.error(f"Violation logging error: {e}")
        raise
def send_alert(violation):
    """Emit an alert for *violation* as an INFO log line.

    Reads the ``'violation_type'`` and ``'severity'`` keys. Logging is the
    only alert channel here; nothing is sent externally.
    """
    kind = violation['violation_type']
    level = violation['severity']
    logger.info(f"Alert! {kind} detected. Severity: {level}")
def create_salesforce_violation_record(violation_data, timeout=10):
    """Create a ``Safety_Violation_Log__c`` record in Salesforce.

    Args:
        violation_data: dict with keys ``site_id``, ``camera_id``,
            ``violation_type``, ``timestamp``, ``snapshot_url`` and ``severity``.
        timeout: seconds passed to ``requests`` as the connect/read timeout.
            Defaults to 10.

    Returns:
        The decoded JSON body of the Salesforce response.

    Raises:
        Exception: missing keys, network errors, timeouts, or a non-2xx
            status (via ``raise_for_status``), re-raised after logging.
    """
    try:
        salesforce_url = f"{SALESFORCE_URL}/services/data/v60.0/sobjects/Safety_Violation_Log__c/"
        headers = {
            'Authorization': f'Bearer {SALESFORCE_TOKEN}',
            'Content-Type': 'application/json'
        }
        violation_obj = {
            'Site_ID__c': violation_data['site_id'],
            'Camera_ID__c': violation_data['camera_id'],
            'Violation_Type__c': violation_data['violation_type'],
            'Timestamp__c': violation_data['timestamp'],
            'Snapshot_URL__c': violation_data['snapshot_url'],
            'Severity__c': violation_data['severity'],
            # NOTE(review): marked True before any alert has actually been
            # sent by the caller — confirm this is intended.
            'Alert_Sent__c': True,
            'Resolved__c': False
        }
        # Without a timeout, requests waits indefinitely on an unresponsive
        # host, which would stall the caller's detection loop.
        response = requests.post(salesforce_url, headers=headers, json=violation_obj, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        logger.error(f"Salesforce integration error: {e}")
        raise
# --- FastAPI Application ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the YOLO model into the global ``yolo_model`` for the app's lifetime.

    On startup the model is constructed. A failure is logged and re-raised.
    On shutdown the global is reset to None, even if the app exits with an
    error.
    """
    global yolo_model
    logger.info("FastAPI application starting up")
    try:
        yolo_model = YOLOv8Model()
        logger.info("YOLOv8 model initialized successfully")
    except Exception as e:
        logger.error(f"Startup error: {e}")
        raise
    try:
        yield
    finally:
        logger.info("FastAPI application shutting down")
        yolo_model = None
# The application object. `lifespan` loads the model at startup and clears it
# at shutdown.
app = FastAPI(
    title="Safety Violation Detection API",
    description="API for detecting safety violations using YOLOv8",
    version="1.0.0",
    lifespan=lifespan,
)
# --- API Endpoints ---
def _run_violation_detection():
    """Process RTSP_URL until it stops delivering frames, recording each detection.

    This is a blocking call: frame reads, inference, and file and HTTP I/O
    all happen inline. Every detected object is treated as a violation, with
    severity derived from its confidence.

    NOTE(review): a live RTSP stream may never end, in which case this never
    returns.
    """
    for frame, timestamp in capture_rtsp_frames(RTSP_URL):
        results = yolo_model.predict(preprocess_frame(frame))
        if results.empty:
            continue
        # One snapshot per frame, shared by all of that frame's detections
        # (previously a separate file was written per detection row).
        snapshot_url = save_snapshot(frame)
        for _, row in results.iterrows():
            # The YOLOv5-style detection frame names its score column
            # 'confidence'; the old lookup of 'conf' raised KeyError.
            confidence = row['confidence']
            if confidence > 0.8:
                severity = "Critical"
            elif confidence > 0.5:
                severity = "Moderate"
            else:
                severity = "Minor"
            violation = {
                'site_id': "Site1",
                'camera_id': "Camera1",
                'violation_type': row['name'],
                'timestamp': timestamp,
                'snapshot_url': snapshot_url,
                'severity': severity
            }
            log_violation(violation)
            create_salesforce_violation_record(violation)
            send_alert(violation)


@app.post("/detect_violation")
async def detect_violation():
    """Run violation detection over the configured RTSP stream.

    Returns ``{"status": "Violation detection complete."}`` once the stream
    stops delivering frames.

    Raises:
        HTTPException(500): if the model is not loaded or processing fails.
    """
    import asyncio

    # Checked outside the try so this 500 is not re-wrapped as a
    # "processing stream" error.
    if yolo_model is None:
        raise HTTPException(status_code=500, detail="YOLO model not initialized")
    try:
        # The stream loop blocks; run it in a worker thread so the event loop
        # keeps serving other requests (e.g. /health_check) meanwhile.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _run_violation_detection)
    except Exception as e:
        logger.error(f"Error processing stream: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing stream: {e}")
    return {"status": "Violation detection complete."}
@app.post("/upload_image")
async def upload_image(file: UploadFile = File(...)):
    """Run detection on an uploaded image file.

    Returns:
        ``{"results": ...}`` where the value is ``to_dict()`` of the
        detections returned by ``yolo_model.predict`` (assumed to be a
        pandas DataFrame).

    Raises:
        HTTPException(400): if the upload is empty or cannot be decoded as an image.
        HTTPException(500): if the model is not loaded or inference fails.
    """
    if yolo_model is None:
        raise HTTPException(status_code=500, detail="YOLO model not initialized")
    try:
        image_data = await file.read()
        if not image_data:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")
        image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        # imdecode signals undecodable data by returning None rather than raising.
        if image is None:
            raise HTTPException(status_code=400, detail="Uploaded file is not a decodable image")
        results = yolo_model.predict(image)
        return {"results": results.to_dict()}
    except HTTPException:
        # Client errors raised above keep their status code instead of
        # being turned into a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}")
@app.get("/health_check")
async def health_check():
    """Liveness probe: always reports that the service is running.

    Registered as GET /health_check, matching the endpoint list served by the
    root route; previously the function was never attached to the app.
    """
    return {"status": "Running smoothly"}
@app.get("/")
async def root():
    """Describe the API: name, version and the available endpoints.

    Previously the function was defined but never registered as a route.
    """
    return {
        "message": "Safety Violation Detection API",
        "version": "1.0.0",
        "endpoints": {
            "/detect_violation": "POST - Process RTSP stream for violations",
            "/upload_image": "POST - Analyze uploaded image for violations",
            "/health_check": "GET - Check API status"
        }
    }