Spaces:
Runtime error
Runtime error
File size: 9,372 Bytes
a719f85 9955bc2 d157456 388bfb9 3bdfb30 388bfb9 3bffc30 0aa58ed c3ce28d 3bffc30 c3ce28d f8f8eae c3ce28d f8f8eae d157456 c3ce28d d157456 c3ce28d d157456 c3ce28d 0aa58ed 3bffc30 0660ef9 f8f8eae a719f85 0aa58ed c3ce28d 69350d2 388bfb9 c3ce28d 0aa58ed c3ce28d 388bfb9 f8f8eae 69350d2 f8f8eae 0660ef9 c3ce28d 5ef6e4d c3ce28d d157456 69350d2 388bfb9 69350d2 3bffc30 69350d2 388bfb9 c3ce28d 388bfb9 69350d2 388bfb9 69350d2 5ef6e4d 69350d2 c3ce28d 69350d2 388bfb9 c3ce28d 388bfb9 69350d2 5ef6e4d 69350d2 c3ce28d 69350d2 388bfb9 69350d2 388bfb9 69350d2 388bfb9 69350d2 388bfb9 69350d2 388bfb9 c3ce28d 3bffc30 c3ce28d 388bfb9 0aa58ed c3ce28d 388bfb9 0aa58ed 388bfb9 69350d2 388bfb9 69350d2 388bfb9 85f2cb1 69350d2 388bfb9 69350d2 388bfb9 69350d2 388bfb9 0aa58ed c3ce28d 388bfb9 0aa58ed 388bfb9 69350d2 388bfb9 d157456 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 |
import os
import shutil
import sys
import warnings
from fastapi import FastAPI, File, UploadFile, HTTPException
from contextlib import asynccontextmanager
import torch
from ultralytics import YOLO
import cv2
import requests
import json
import time
import numpy as np
from pathlib import Path
from datetime import datetime
import logging
import pandas as pd
# --- Initial Configuration ---
# Silence every warning category for the whole process.
warnings.filterwarnings("ignore")

# Give Ultralytics a scratch configuration directory. The marker attribute
# stored on `sys` makes the setup run only once per interpreter, even if this
# module is imported again.
# NOTE(review): `from ultralytics import YOLO` has already executed by the
# time this runs; if the library reads YOLO_CONFIG_DIR at import time the
# assignment below comes too late -- confirm and move it above the imports.
yolo_config_dir = "/tmp/Ultralytics"
if not hasattr(sys, '_yolo_config_initialized'):
    try:
        # Start from an empty directory: drop anything already at that path.
        if os.path.exists(yolo_config_dir):
            shutil.rmtree(yolo_config_dir)
        os.makedirs(yolo_config_dir, exist_ok=True)
        # Open the permissions so the directory is writable.
        os.chmod(yolo_config_dir, 0o777)
        os.environ["YOLO_CONFIG_DIR"] = yolo_config_dir
        sys._yolo_config_initialized = True
        print(f"YOLO_CONFIG_DIR initialized to: {yolo_config_dir}")
    except Exception as e:
        # Report the failure, then let it abort the import.
        print(f"Failed to set up YOLO_CONFIG_DIR: {e}")
        raise
# --- Logging Configuration ---
# Root logger: INFO and above, timestamped, emitted through a stream handler.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)

# Raise the threshold of chatty third-party loggers.
for _noisy_name, _noisy_level in (
    ("ultralytics", logging.ERROR),
    ("PIL", logging.WARNING),
    ("matplotlib", logging.WARNING),
    ("urllib3", logging.WARNING),
):
    logging.getLogger(_noisy_name).setLevel(_noisy_level)

# Record which config directory the environment currently advertises.
logger.info(f"YOLO_CONFIG_DIR set to: {os.getenv('YOLO_CONFIG_DIR')}")
# --- Environment Variables ---
# Every setting is read from the process environment; the second argument is
# the placeholder used when the variable is not set.
RTSP_URL = os.environ.get("RTSP_URL", "rtsp://localhost:8554/stream")
SALESFORCE_URL = os.environ.get(
    "SALESFORCE_URL", "https://your_salesforce_instance_url"
)
SALESFORCE_TOKEN = os.environ.get(
    "SALESFORCE_TOKEN", "your_salesforce_access_token"
)
HUGGINGFACE_API_URL = os.environ.get(
    "HUGGINGFACE_API_URL",
    "https://api-inference.huggingface.co/models/PrashanthB461/SafetyViolationAI1",
)
HUGGINGFACE_TOKEN = os.environ.get(
    "HUGGINGFACE_TOKEN", "your_huggingface_api_token"
)

# --- Global Model Instance ---
# Empty at import time; holds the shared model wrapper once it has been loaded.
yolo_model = None
# --- Model Class ---
class YOLOv8Model:
    """Thin wrapper around an Ultralytics ``YOLO`` model.

    The weights are loaded once in ``__init__``; ``predict`` runs the model on
    one image and returns its detections as a pandas ``DataFrame``.
    """

    # Column layout of the DataFrame returned by ``predict``.
    _COLUMNS = ["xmin", "ymin", "xmax", "ymax", "conf", "class", "name"]

    def __init__(self, model_path='/app/yolov8n.pt'):
        """Load the YOLO weights stored at *model_path*.

        Args:
            model_path: Filesystem path of the weights file.

        Raises:
            FileNotFoundError: If *model_path* does not exist.
            Exception: Any error raised while constructing the model is
                logged and re-raised unchanged.
        """
        try:
            logger.info("Checking for model weights at %s", model_path)
            if not os.path.exists(model_path):
                logger.error("Model weights file %s not found", model_path)
                raise FileNotFoundError(f"Model weights file {model_path} not found")
            logger.info("Initializing YOLOv8 model")
            # Discard anything printed to stdout while the model loads. The
            # devnull handle is closed by the ``with`` block itself, so the
            # cleanup never closes some other object that happens to be
            # installed as ``sys.stdout`` at that moment.
            original_stdout = sys.stdout
            with open(os.devnull, 'w') as devnull:
                sys.stdout = devnull
                try:
                    self.model = YOLO(model_path)
                    logger.info("YOLOv8 model loaded successfully")
                finally:
                    sys.stdout = original_stdout
        except Exception as e:
            logger.error(f"Failed to load YOLOv8 model: {e}")
            raise

    def predict(self, image):
        """Run the model on *image* and return its detections.

        Calling the model yields a sequence with one result per input image.
        Only the first result is used, and its boxes are flattened into one
        row per detection.

        Args:
            image: Image passed straight to the underlying model.

        Returns:
            pandas.DataFrame with the columns ``xmin``, ``ymin``, ``xmax``,
            ``ymax``, ``conf``, ``class`` and ``name``; it has zero rows when
            nothing was detected.

        Raises:
            Exception: Any inference error is logged and re-raised.
        """
        try:
            results = self.model(image)
            if not results:
                return pd.DataFrame(columns=self._COLUMNS)

            result = results[0]
            boxes = getattr(result, "boxes", None)
            if boxes is None:
                # The result carries no box output at all.
                return pd.DataFrame(columns=self._COLUMNS)

            # Maps a class index to its human-readable label.
            names = getattr(result, "names", None) or {}

            # ``tolist()`` turns tensors and arrays alike into plain Python
            # numbers, so no device/NumPy specific handling is needed here.
            corners = boxes.xyxy.tolist()
            confidences = boxes.conf.tolist()
            class_ids = boxes.cls.tolist()

            rows = [
                {
                    "xmin": x_min,
                    "ymin": y_min,
                    "xmax": x_max,
                    "ymax": y_max,
                    "conf": confidence,
                    "class": int(class_id),
                    "name": names.get(int(class_id), str(int(class_id))),
                }
                for (x_min, y_min, x_max, y_max), confidence, class_id in zip(
                    corners, confidences, class_ids
                )
            ]
            return pd.DataFrame(rows, columns=self._COLUMNS)
        except Exception as e:
            logger.error(f"Prediction error: {e}")
            raise
# --- Frame Processing Functions ---
def preprocess_frame(frame):
    """Return *frame* converted from BGR to RGB and resized to 640x640.

    Any failure is logged and re-raised unchanged.
    """
    # NOTE(review): the result is fed to the YOLO model by the caller; confirm
    # the model really expects RGB here rather than OpenCV's native BGR order.
    try:
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return cv2.resize(rgb_frame, (640, 640))
    except Exception as e:
        logger.error(f"Frame preprocessing error: {e}")
        raise
def capture_rtsp_frames(rtsp_url: str):
    """Yield ``(frame, timestamp)`` pairs read from the stream at *rtsp_url*.

    ``timestamp`` is the ISO-8601 string of ``datetime.utcnow()`` taken right
    after the frame was read. Iteration ends when a read fails or the capture
    reports itself closed.

    Args:
        rtsp_url: Address handed to ``cv2.VideoCapture``.

    Yields:
        tuple: ``(frame, timestamp)`` for every successfully read frame.

    Raises:
        ValueError: If the stream cannot be opened.
        Exception: Any other capture error is logged and re-raised.
    """
    cap = None
    try:
        logger.info(f"Attempting to connect to RTSP stream: {rtsp_url}")
        cap = cv2.VideoCapture(rtsp_url)
        if not cap.isOpened():
            logger.error(f"Failed to open RTSP stream: {rtsp_url}")
            raise ValueError("RTSP stream not accessible")
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                logger.warning("Failed to read frame from RTSP stream")
                break
            timestamp = datetime.utcnow().isoformat()
            yield frame, timestamp
    except Exception as e:
        logger.error(f"RTSP capture error: {e}")
        raise
    finally:
        # Runs on normal exhaustion, on errors, and when the consumer stops
        # iterating early (generator close), so the capture handle is always
        # released.
        if cap is not None:
            cap.release()
# --- Violation Handling Functions ---
def save_snapshot(frame):
    """Write *frame* as a JPEG under ``/snapshots`` and return its path.

    The file name embeds a nanosecond timestamp so that snapshots taken
    within the same second do not overwrite one another.

    Args:
        frame: Image passed to ``cv2.imwrite``.

    Returns:
        str: Path of the written file, ``/snapshots/snapshot_<ns>.jpg``.

    Raises:
        OSError: If ``cv2.imwrite`` reports that nothing was written.
        Exception: Any other failure is logged and re-raised.
    """
    try:
        os.makedirs("/snapshots", exist_ok=True)
        os.chmod("/snapshots", 0o777)  # Ensure directory is writable
        filename = f"snapshot_{time.time_ns()}.jpg"
        snapshot_path = Path("/snapshots") / filename
        # imwrite signals failure through its return value rather than by
        # raising, so the result has to be checked explicitly.
        if not cv2.imwrite(str(snapshot_path), frame):
            raise OSError(f"Failed to write snapshot to {snapshot_path}")
        return f"/snapshots/{filename}"
    except Exception as e:
        logger.error(f"Snapshot saving error: {e}")
        raise
def log_violation(violation_data):
    """Append *violation_data* to ``/snapshots/violation_logs.json``.

    The file holds one JSON list. It is read in full, extended, and written
    back. The new content goes to a temporary sibling file first and is then
    moved over the real log, so a failure while serialising leaves the
    previous log untouched instead of truncated.

    Args:
        violation_data: JSON-serialisable record to append.

    Raises:
        Exception: Any read, serialisation or write error is logged and
            re-raised.
    """
    try:
        log_file = Path("/snapshots/violation_logs.json")
        # The directory may not exist yet if no snapshot was saved before.
        log_file.parent.mkdir(parents=True, exist_ok=True)

        logs = []
        if log_file.exists():
            with open(log_file, "r", encoding="utf-8") as f:
                logs = json.load(f)
        logs.append(violation_data)

        tmp_file = log_file.with_name(log_file.name + ".tmp")
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(logs, f, indent=4)
        os.replace(tmp_file, log_file)
    except Exception as e:
        logger.error(f"Violation logging error: {e}")
        raise
def send_alert(violation):
    """Emit an INFO log line announcing *violation*.

    Reads the ``violation_type`` and ``severity`` keys of the mapping.
    """
    message = f"Alert! {violation['violation_type']} detected. Severity: {violation['severity']}"
    logger.info(message)
def create_salesforce_violation_record(violation_data, timeout=10):
    """Create a ``Safety_Violation_Log__c`` record in Salesforce.

    Args:
        violation_data: Mapping with the keys ``site_id``, ``camera_id``,
            ``violation_type``, ``timestamp``, ``snapshot_url`` and
            ``severity``.
        timeout: Seconds to wait for Salesforce before giving up. Without a
            timeout a stalled connection would block the caller forever.

    Returns:
        The decoded JSON body of the Salesforce response.

    Raises:
        Exception: Missing keys, connection problems, timeouts and non-2xx
            responses are logged and re-raised.
    """
    try:
        salesforce_url = f"{SALESFORCE_URL}/services/data/v60.0/sobjects/Safety_Violation_Log__c/"
        headers = {
            'Authorization': f'Bearer {SALESFORCE_TOKEN}',
            'Content-Type': 'application/json',
        }
        violation_obj = {
            'Site_ID__c': violation_data['site_id'],
            'Camera_ID__c': violation_data['camera_id'],
            'Violation_Type__c': violation_data['violation_type'],
            'Timestamp__c': violation_data['timestamp'],
            'Snapshot_URL__c': violation_data['snapshot_url'],
            'Severity__c': violation_data['severity'],
            'Alert_Sent__c': True,
            'Resolved__c': False,
        }
        response = requests.post(
            salesforce_url,
            headers=headers,
            data=json.dumps(violation_obj),
            timeout=timeout,
        )
        # Turn 4xx/5xx answers into exceptions instead of returning them.
        response.raise_for_status()
        return response.json()
    except Exception as e:
        logger.error(f"Salesforce integration error: {e}")
        raise
# --- FastAPI Application ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the shared YOLO model before serving and drop it afterwards.

    Everything before the ``yield`` is the startup phase: the model wrapper
    is built and stored in the module-level ``yolo_model``. A failure there is
    logged and re-raised. Everything after the ``yield`` is the shutdown
    phase, which resets ``yolo_model`` to ``None``.
    """
    global yolo_model

    # --- startup ---
    logger.info("FastAPI application starting up")
    try:
        loaded_model = YOLOv8Model()
        yolo_model = loaded_model
        logger.info("YOLOv8 model initialized successfully")
    except Exception as e:
        logger.error(f"Startup error: {e}")
        raise

    yield

    # --- shutdown ---
    logger.info("FastAPI application shutting down")
    yolo_model = None
# Application object; `lifespan` supplies the startup and shutdown hooks.
app = FastAPI(
    title="Safety Violation Detection API",
    description="API for detecting safety violations using YOLOv8",
    version="1.0.0",
    lifespan=lifespan,
)
# --- API Endpoints ---
def _process_stream(model, rtsp_url):
    """Run *model* on every frame of *rtsp_url* and report each detection."""
    for frame, timestamp in capture_rtsp_frames(rtsp_url):
        frame_processed = preprocess_frame(frame)
        results = model.predict(frame_processed)

        # The snapshot is identical for every detection of one frame, so it
        # is written at most once, and only if the frame has detections.
        snapshot_url = None
        for _, row in results.iterrows():
            confidence = row['conf']
            if confidence > 0.8:
                severity = "Critical"
            elif confidence > 0.5:
                severity = "Moderate"
            else:
                severity = "Minor"

            if snapshot_url is None:
                snapshot_url = save_snapshot(frame)

            violation = {
                'site_id': "Site1",
                'camera_id': "Camera1",
                'violation_type': row['name'],
                'timestamp': timestamp,
                'snapshot_url': snapshot_url,
                'severity': severity,
            }
            log_violation(violation)
            create_salesforce_violation_record(violation)
            send_alert(violation)


@app.post("/detect_violation/")
async def detect_violation():
    """Process the configured RTSP stream and report every detection.

    Each detection is logged to the local JSON file, pushed to Salesforce and
    announced through ``send_alert``. The request returns once the stream
    ends.

    The frame loop is blocking (video capture, inference, HTTP calls), so it
    is handed to a worker thread; running it directly in this coroutine
    would stall the event loop, and with it every other endpoint, for as
    long as the stream lasts.

    Returns:
        dict: ``{"status": "Violation detection complete."}``.

    Raises:
        HTTPException: 500 if the model is not loaded or processing fails.
    """
    # Imported here so the block does not depend on a new file-level import.
    from fastapi.concurrency import run_in_threadpool

    try:
        if yolo_model is None:
            raise HTTPException(status_code=500, detail="YOLO model not initialized")
        await run_in_threadpool(_process_stream, yolo_model, RTSP_URL)
        return {"status": "Violation detection complete."}
    except HTTPException:
        # Already a well-formed HTTP error; do not wrap it a second time.
        raise
    except Exception as e:
        logger.error(f"Error processing stream: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing stream: {e}")
@app.post("/upload_image/")
async def upload_image(file: UploadFile = File(...)):
    """Run detection on an uploaded image and return the raw results.

    Args:
        file: The uploaded image file.

    Returns:
        dict: ``{"results": ...}`` where the value is the detection table
        converted with ``DataFrame.to_dict()``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 500 if the model is not loaded or inference fails.
    """
    try:
        if yolo_model is None:
            raise HTTPException(status_code=500, detail="YOLO model not initialized")

        image_data = await file.read()
        if not image_data:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        # imdecode returns None instead of raising when the bytes are not a
        # readable image; reject that as a client error up front.
        if image is None:
            raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")

        results = yolo_model.predict(image)
        return {"results": results.to_dict()}
    except HTTPException:
        # Keep the intended status code instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}")
@app.get("/health_check/")
async def health_check():
    """Return a fixed status payload."""
    status_payload = {"status": "Running smoothly"}
    return status_payload
@app.get("/")
async def root():
    """Return the API name, version and a short listing of its endpoints."""
    return {
        "message": "Safety Violation Detection API",
        "version": "1.0.0",
        "endpoints": {
            "/detect_violation": "POST - Process RTSP stream for violations",
            "/upload_image": "POST - Analyze uploaded image for violations",
            "/health_check": "GET - Check API status",
        },
    }