"""Safety Violation Detection API.

FastAPI service that runs a YOLOv8 model over frames from an RTSP stream or
over uploaded images, stores a snapshot and a JSON log entry for each
detection, and pushes a record to Salesforce.
"""
import json
import logging
import os
import shutil
import sys
import time
import warnings
from contextlib import asynccontextmanager, redirect_stdout
from datetime import datetime
from pathlib import Path

# --- Initial Configuration ---
# Suppress all warnings
warnings.filterwarnings("ignore")

# Set up YOLO_CONFIG_DIR only once.
# This must run BEFORE ultralytics is imported below: the package resolves its
# configuration directory at import time, so exporting the variable afterwards
# has no effect.
yolo_config_dir = "/tmp/Ultralytics"
if not hasattr(sys, '_yolo_config_initialized'):
    try:
        if os.path.exists(yolo_config_dir):
            shutil.rmtree(yolo_config_dir)
        os.makedirs(yolo_config_dir, exist_ok=True)
        os.chmod(yolo_config_dir, 0o777)  # Ensure directory is writable
        os.environ["YOLO_CONFIG_DIR"] = yolo_config_dir
        sys._yolo_config_initialized = True
        print(f"YOLO_CONFIG_DIR initialized to: {yolo_config_dir}")
    except Exception as e:
        print(f"Failed to set up YOLO_CONFIG_DIR: {e}")
        raise

# Third-party imports are deliberately placed after the environment setup.
import cv2  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import requests  # noqa: E402
import torch  # noqa: E402,F401
from fastapi import FastAPI, File, HTTPException, UploadFile  # noqa: E402
from ultralytics import YOLO  # noqa: E402

# --- Logging Configuration ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)

# Get logger instance
logger = logging.getLogger(__name__)

# Suppress third-party logging
logging.getLogger("ultralytics").setLevel(logging.ERROR)
logging.getLogger("PIL").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)

# Log environment variable
logger.info(f"YOLO_CONFIG_DIR set to: {os.getenv('YOLO_CONFIG_DIR')}")

# --- Environment Variables ---
RTSP_URL = os.getenv("RTSP_URL", "rtsp://localhost:8554/stream")
SALESFORCE_URL = os.getenv("SALESFORCE_URL", "https://your_salesforce_instance_url")
SALESFORCE_TOKEN = os.getenv("SALESFORCE_TOKEN", "your_salesforce_access_token")
HUGGINGFACE_API_URL = os.getenv("HUGGINGFACE_API_URL", "https://api-inference.huggingface.co/models/PrashanthB461/SafetyViolationAI1")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "your_huggingface_api_token")

# --- Internal Constants ---
# Directory that holds snapshots and the JSON violation log.
SNAPSHOT_DIR = Path("/snapshots")
# Upper bound (seconds) for the Salesforce HTTP call so a stalled connection
# cannot block a request forever.
SALESFORCE_TIMEOUT_SECONDS = 30
# Column layout of the DataFrame returned by YOLOv8Model.predict.
DETECTION_COLUMNS = ["xmin", "ymin", "xmax", "ymax", "conf", "class", "name"]

# --- Global Model Instance ---
yolo_model = None


# --- Model Class ---
class YOLOv8Model:
    """Thin wrapper around an Ultralytics YOLO model."""

    def __init__(self, model_path='/app/yolov8n.pt'):
        """Load the model weights from *model_path*.

        Raises:
            FileNotFoundError: if the weights file does not exist.
            Exception: any error raised while loading the model is logged and
                re-raised.
        """
        try:
            logger.info("Checking for model weights at %s", model_path)
            if not os.path.exists(model_path):
                logger.error("Model weights file %s not found", model_path)
                raise FileNotFoundError(f"Model weights file {model_path} not found")
            logger.info("Initializing YOLOv8 model")
            # Silence anything the loader prints to stdout; the context
            # managers restore stdout and close the devnull handle even if
            # loading fails.
            with open(os.devnull, 'w') as devnull, redirect_stdout(devnull):
                self.model = YOLO(model_path)
            logger.info("YOLOv8 model loaded successfully")
        except Exception as e:
            logger.error("Failed to load YOLOv8 model: %s", e)
            raise

    def predict(self, image):
        """Run inference on one image and return the detections.

        Returns:
            pandas.DataFrame with the columns listed in DETECTION_COLUMNS, one
            row per detected box (empty when nothing was detected).

        Raises:
            Exception: inference errors are logged and re-raised.
        """
        try:
            results = self.model(image)
            if not results:
                return pd.DataFrame(columns=DETECTION_COLUMNS)
            # A single input image yields a single Results object.
            return self._results_to_frame(results[0])
        except Exception as e:
            logger.error("Prediction error: %s", e)
            raise

    @staticmethod
    def _results_to_frame(result):
        """Convert one Ultralytics Results object into a DataFrame."""
        boxes = result.boxes
        if boxes is None or len(boxes) == 0:
            return pd.DataFrame(columns=DETECTION_COLUMNS)
        # .tolist() yields plain Python numbers, which keeps the frame
        # JSON-friendly.
        coords = boxes.xyxy.tolist()
        confidences = boxes.conf.tolist()
        class_ids = [int(c) for c in boxes.cls.tolist()]
        rows = [
            [*xyxy, conf, class_id, result.names[class_id]]
            for xyxy, conf, class_id in zip(coords, confidences, class_ids)
        ]
        return pd.DataFrame(rows, columns=DETECTION_COLUMNS)


# --- Frame Processing Functions ---
def preprocess_frame(frame):
    """Convert a BGR frame to RGB and resize it to 640x640.

    NOTE(review): Ultralytics documents numpy inputs as BGR; confirm that the
    RGB conversion here is intended.
    """
    try:
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img_resized = cv2.resize(img, (640, 640))
        return img_resized
    except Exception as e:
        logger.error("Frame preprocessing error: %s", e)
        raise


def capture_rtsp_frames(rtsp_url: str):
    """Yield ``(frame, timestamp)`` pairs read from an RTSP stream.

    The timestamp is the UTC capture time as an ISO-8601 string. Iteration
    stops at the first frame that cannot be read. The capture handle is
    released when the generator finishes, fails, or is closed early.

    Raises:
        ValueError: if the stream cannot be opened.
    """
    cap = None
    try:
        logger.info("Attempting to connect to RTSP stream: %s", rtsp_url)
        cap = cv2.VideoCapture(rtsp_url)
        if not cap.isOpened():
            logger.error("Failed to open RTSP stream: %s", rtsp_url)
            raise ValueError("RTSP stream not accessible")
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                logger.warning("Failed to read frame from RTSP stream")
                break
            timestamp = datetime.utcnow().isoformat()
            yield frame, timestamp
    except Exception as e:
        logger.error("RTSP capture error: %s", e)
        raise
    finally:
        # Runs on normal exhaustion, on errors, and when the consumer
        # abandons the generator, so the stream handle is never leaked.
        if cap is not None:
            cap.release()


# --- Violation Handling Functions ---
def save_snapshot(frame):
    """Write *frame* as a JPEG under SNAPSHOT_DIR and return its path.

    Raises:
        OSError: if OpenCV reports that the image could not be written.
    """
    try:
        # Nanosecond resolution so frames saved within the same second do not
        # overwrite each other.
        filename = f"snapshot_{time.time_ns()}.jpg"
        snapshot_path = SNAPSHOT_DIR / filename
        os.makedirs(SNAPSHOT_DIR, exist_ok=True)
        os.chmod(SNAPSHOT_DIR, 0o777)  # Ensure directory is writable
        # cv2.imwrite signals failure through its return value, not an
        # exception.
        if not cv2.imwrite(str(snapshot_path), frame):
            raise OSError(f"Failed to write snapshot to {snapshot_path}")
        return f"/snapshots/{filename}"
    except Exception as e:
        logger.error("Snapshot saving error: %s", e)
        raise


def log_violation(violation_data):
    """Append *violation_data* to the JSON violation log in SNAPSHOT_DIR."""
    try:
        log_file = SNAPSHOT_DIR / "violation_logs.json"
        # Do not rely on save_snapshot having created the directory first.
        log_file.parent.mkdir(parents=True, exist_ok=True)
        logs = []
        if log_file.exists():
            with open(log_file, "r") as f:
                logs = json.load(f)
        logs.append(violation_data)
        with open(log_file, "w") as f:
            json.dump(logs, f, indent=4)
    except Exception as e:
        logger.error("Violation logging error: %s", e)
        raise


def send_alert(violation):
    """Log an alert line for *violation*."""
    logger.info(
        "Alert! %s detected. Severity: %s",
        violation['violation_type'],
        violation['severity'],
    )


def create_salesforce_violation_record(violation_data):
    """Create a Safety_Violation_Log__c record in Salesforce.

    Returns:
        The decoded JSON body of the Salesforce response.

    Raises:
        Exception: HTTP, timeout and decoding errors are logged and re-raised.
    """
    try:
        salesforce_url = f"{SALESFORCE_URL}/services/data/v60.0/sobjects/Safety_Violation_Log__c/"
        headers = {
            'Authorization': f'Bearer {SALESFORCE_TOKEN}',
            'Content-Type': 'application/json'
        }
        violation_obj = {
            'Site_ID__c': violation_data['site_id'],
            'Camera_ID__c': violation_data['camera_id'],
            'Violation_Type__c': violation_data['violation_type'],
            'Timestamp__c': violation_data['timestamp'],
            'Snapshot_URL__c': violation_data['snapshot_url'],
            'Severity__c': violation_data['severity'],
            'Alert_Sent__c': True,
            'Resolved__c': False
        }
        response = requests.post(
            salesforce_url,
            headers=headers,
            data=json.dumps(violation_obj),
            timeout=SALESFORCE_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        logger.error("Salesforce integration error: %s", e)
        raise


def _severity_for(confidence):
    """Map a detection confidence to a severity label."""
    if confidence > 0.8:
        return "Critical"
    if confidence > 0.5:
        return "Moderate"
    return "Minor"


# --- FastAPI Application ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the YOLO model on startup and drop the reference on shutdown."""
    global yolo_model
    logger.info("FastAPI application starting up")
    try:
        yolo_model = YOLOv8Model()
        logger.info("YOLOv8 model initialized successfully")
    except Exception as e:
        logger.error("Startup error: %s", e)
        raise
    yield
    logger.info("FastAPI application shutting down")
    yolo_model = None


app = FastAPI(
    lifespan=lifespan,
    title="Safety Violation Detection API",
    description="API for detecting safety violations using YOLOv8",
    version="1.0.0"
)


# --- API Endpoints ---
@app.post("/detect_violation/")
async def detect_violation():
    """Process the RTSP stream and record every detection as a violation.

    The request runs until the stream can no longer be read. Each frame with
    detections is saved once; every detection in it is logged, sent to
    Salesforce and alerted on.
    """
    try:
        if yolo_model is None:
            raise HTTPException(status_code=500, detail="YOLO model not initialized")
        for frame, timestamp in capture_rtsp_frames(RTSP_URL):
            frame_processed = preprocess_frame(frame)
            results = yolo_model.predict(frame_processed)
            if results.empty:
                continue
            # One snapshot per frame: all detections in the frame share it.
            snapshot_url = save_snapshot(frame)
            for _, row in results.iterrows():
                violation = {
                    'site_id': "Site1",
                    'camera_id': "Camera1",
                    'violation_type': row['name'],
                    'timestamp': timestamp,
                    'snapshot_url': snapshot_url,
                    'severity': _severity_for(float(row['conf']))
                }
                log_violation(violation)
                create_salesforce_violation_record(violation)
                send_alert(violation)
        return {"status": "Violation detection complete."}
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it.
        raise
    except Exception as e:
        logger.error("Error processing stream: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing stream: {e}")


@app.post("/upload_image/")
async def upload_image(file: UploadFile = File(...)):
    """Run the model on an uploaded image and return the detections.

    Responds with 400 when the upload is empty or is not a decodable image.
    """
    try:
        if yolo_model is None:
            raise HTTPException(status_code=500, detail="YOLO model not initialized")
        image_data = await file.read()
        if not image_data:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")
        image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        # imdecode returns None instead of raising when the bytes are not an
        # image.
        if image is None:
            raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")
        results = yolo_model.predict(image)
        return {"results": results.to_dict()}
    except HTTPException:
        # Keep the intended status code instead of turning it into a 500.
        raise
    except Exception as e:
        logger.error("Error processing image: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}")


@app.get("/health_check/")
async def health_check():
    """Report that the service is up."""
    return {"status": "Running smoothly"}


@app.get("/")
async def root():
    """Describe the API and its endpoints."""
    return {
        "message": "Safety Violation Detection API",
        "version": "1.0.0",
        "endpoints": {
            "/detect_violation": "POST - Process RTSP stream for violations",
            "/upload_image": "POST - Analyze uploaded image for violations",
            "/health_check": "GET - Check API status"
        }
    }