# PrashanthB461's picture
# Update app.py
# f8f8eae verified
import json
import logging
import os
import shutil
import sys
import time
import uuid
import warnings
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import requests
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from ultralytics import YOLO
# --- Initial Configuration ---
# Suppress all warnings
# NOTE(review): this silences every warning process-wide, including
# deprecation warnings from the libraries imported above; consider narrowing
# it to specific categories or modules.
warnings.filterwarnings("ignore")
# Set up YOLO_CONFIG_DIR only once
yolo_config_dir = "/tmp/Ultralytics"
# The "already done" flag is stored on the `sys` module, which is shared by the
# whole interpreter, so re-executing this module in the same process skips the
# wipe-and-recreate below.
if not hasattr(sys, '_yolo_config_initialized'):
    try:
        # Start from an empty directory: any existing contents are deleted.
        if os.path.exists(yolo_config_dir):
            shutil.rmtree(yolo_config_dir)
        os.makedirs(yolo_config_dir, exist_ok=True)
        # NOTE(review): 0o777 makes the directory world-writable.
        os.chmod(yolo_config_dir, 0o777) # Ensure directory is writable
        # NOTE(review): `ultralytics` is imported at the top of this file, before
        # this assignment runs. If ultralytics reads YOLO_CONFIG_DIR only at
        # import time, this setting has no effect on it in this process —
        # confirm, and if so set the variable before importing ultralytics.
        os.environ["YOLO_CONFIG_DIR"] = yolo_config_dir
        sys._yolo_config_initialized = True
        # Logging is configured further below, so this reports via print().
        print(f"YOLO_CONFIG_DIR initialized to: {yolo_config_dir}")
    except Exception as e:
        print(f"Failed to set up YOLO_CONFIG_DIR: {e}")
        # Re-raising here aborts the import of this module.
        raise
# --- Logging Configuration ---
# Root logging: INFO and above through a single StreamHandler (stderr by default).
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
# Quieten chatty third-party libraries.
for _lib_name, _lib_level in (
    ("ultralytics", logging.ERROR),
    ("PIL", logging.WARNING),
    ("matplotlib", logging.WARNING),
    ("urllib3", logging.WARNING),
):
    logging.getLogger(_lib_name).setLevel(_lib_level)
del _lib_name, _lib_level
# Report the config directory now that logging is available.
logger.info("YOLO_CONFIG_DIR set to: %s", os.getenv('YOLO_CONFIG_DIR'))
# --- Environment Variables ---
# Each setting is read once at import time. Several fallbacks are placeholder
# strings and need real values supplied through the environment.
RTSP_URL, SALESFORCE_URL, SALESFORCE_TOKEN, HUGGINGFACE_API_URL, HUGGINGFACE_TOKEN = (
    os.getenv(env_name, fallback)
    for env_name, fallback in (
        ("RTSP_URL", "rtsp://localhost:8554/stream"),
        ("SALESFORCE_URL", "https://your_salesforce_instance_url"),
        ("SALESFORCE_TOKEN", "your_salesforce_access_token"),
        ("HUGGINGFACE_API_URL", "https://api-inference.huggingface.co/models/PrashanthB461/SafetyViolationAI1"),
        ("HUGGINGFACE_TOKEN", "your_huggingface_api_token"),
    )
)
# --- Global Model Instance ---
# Populated by the FastAPI lifespan hook on startup and reset to None on shutdown.
yolo_model = None
# --- Model Class ---
class YOLOv8Model:
    """Thin wrapper around an ``ultralytics.YOLO`` detector.

    The weights are loaded once at construction time; :meth:`predict` runs the
    detector and returns its detections as a pandas DataFrame.
    """

    # Column order of the DataFrame returned by ``predict``. The first seven
    # mirror the YOLOv5-style ``results.pandas().xyxy[0]`` layout the original
    # implementation asked for. ``conf`` duplicates ``confidence`` because the
    # /detect_violation/ endpoint reads ``row['conf']``.
    PREDICTION_COLUMNS = ("xmin", "ymin", "xmax", "ymax", "confidence", "class", "name", "conf")

    def __init__(self, model_path='/app/yolov8n.pt'):
        """Load YOLO weights from *model_path*.

        Args:
            model_path: Path to the weights file.

        Raises:
            FileNotFoundError: If *model_path* does not exist.
            Exception: Anything raised while constructing ``YOLO(model_path)``
                is logged and re-raised.
        """
        logger.info("Checking for model weights at %s", model_path)
        if not os.path.exists(model_path):
            logger.error("Model weights file %s not found", model_path)
            raise FileNotFoundError(f"Model weights file {model_path} not found")
        logger.info("Initializing YOLOv8 model")
        try:
            # Discard anything printed to stdout while the model loads.
            original_stdout = sys.stdout
            with open(os.devnull, 'w') as devnull:
                sys.stdout = devnull
                try:
                    self.model = YOLO(model_path)
                finally:
                    sys.stdout = original_stdout
        except Exception as e:
            logger.error(f"Failed to load YOLOv8 model: {e}")
            raise
        logger.info("YOLOv8 model loaded successfully")

    def predict(self, image):
        """Run the detector on one image and return its detections.

        Args:
            image: An inference source accepted by the ultralytics model; this
                file passes decoded OpenCV arrays.

        Returns:
            A ``pandas.DataFrame`` with one row per detected box and the
            columns in ``PREDICTION_COLUMNS``. The coordinates come from
            ``boxes.xyxy``, ``class`` is the integer class id and ``name`` is
            its label from ``result.names``. When nothing is detected, the
            frame is empty but still has those columns.

        Raises:
            Exception: Any inference error is logged and re-raised.
        """
        try:
            # ultralytics returns a list of Results objects (one per input
            # image). It does not return the YOLOv5 hub object that exposed
            # ``.pandas()``, so the table is built from the boxes.
            results = self.model(image)
            rows = []
            if results:
                first = results[0]
                boxes = first.boxes
                if boxes is not None:
                    names = first.names
                    for (x1, y1, x2, y2), score, cls_value in zip(
                        boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist()
                    ):
                        cls_id = int(cls_value)
                        rows.append({
                            "xmin": x1,
                            "ymin": y1,
                            "xmax": x2,
                            "ymax": y2,
                            "confidence": score,
                            "class": cls_id,
                            "name": names[cls_id],
                            "conf": score,
                        })
            return pd.DataFrame(rows, columns=list(self.PREDICTION_COLUMNS))
        except Exception as e:
            logger.error(f"Prediction error: {e}")
            raise
# --- Frame Processing Functions ---
def preprocess_frame(frame):
    """Convert an OpenCV BGR frame to RGB and resize it to 640x640 pixels.

    The resize stretches the image to a square and does not preserve the
    aspect ratio. Any OpenCV error is logged and re-raised.

    NOTE(review): ultralytics documents NumPy-array inputs as BGR, so passing
    this RGB output to a YOLOv8 model may swap the colour channels — confirm.
    """
    try:
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return cv2.resize(rgb, (640, 640))
    except Exception as exc:
        logger.error(f"Frame preprocessing error: {exc}")
        raise
def capture_rtsp_frames(rtsp_url: str):
    """Yield ``(frame, timestamp)`` pairs read from an RTSP stream.

    Frames are yielded as returned by ``cv2.VideoCapture.read``. The timestamp
    is ``datetime.utcnow().isoformat()`` taken right after each read.
    Iteration stops at the first failed read.

    Args:
        rtsp_url: URL of the stream to open.

    Raises:
        ValueError: If the stream cannot be opened.
        Exception: Any other capture error is logged and re-raised.
    """
    cap = None
    try:
        logger.info(f"Attempting to connect to RTSP stream: {rtsp_url}")
        cap = cv2.VideoCapture(rtsp_url)
        if not cap.isOpened():
            logger.error(f"Failed to open RTSP stream: {rtsp_url}")
            raise ValueError("RTSP stream not accessible")
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                logger.warning("Failed to read frame from RTSP stream")
                break
            yield frame, datetime.utcnow().isoformat()
    except Exception as e:
        logger.error(f"RTSP capture error: {e}")
        raise
    finally:
        # Runs on normal exhaustion, on errors, and when the consumer stops
        # iterating early (generator close), so the capture is never leaked.
        if cap is not None:
            cap.release()
# --- Violation Handling Functions ---
def save_snapshot(frame, snapshot_dir="/snapshots"):
    """Write *frame* as a JPEG into *snapshot_dir* and return its path.

    Args:
        frame: Image array accepted by ``cv2.imwrite``.
        snapshot_dir: Target directory; it is created if missing.

    Returns:
        ``f"{snapshot_dir}/{filename}"``. With the default directory this is
        ``/snapshots/snapshot_<epoch-seconds>_<random-hex>.jpg``.

    Raises:
        OSError: If OpenCV reports that the image could not be written.
        Exception: Any other error is logged and re-raised.
    """
    try:
        # The random suffix keeps names unique. The epoch-seconds prefix alone
        # collides when several snapshots are taken within the same second,
        # which silently overwrites earlier images.
        filename = f"snapshot_{int(time.time())}_{uuid.uuid4().hex}.jpg"
        os.makedirs(snapshot_dir, exist_ok=True)
        try:
            os.chmod(snapshot_dir, 0o777)  # Ensure directory is writable
        except PermissionError:
            # Not the directory's owner. Carry on: imwrite below reports a
            # failure if the directory really is unwritable.
            pass
        snapshot_path = Path(snapshot_dir) / filename
        # cv2.imwrite signals failure by returning False rather than raising.
        if not cv2.imwrite(str(snapshot_path), frame):
            raise OSError(f"cv2.imwrite could not write {snapshot_path}")
        return f"{snapshot_dir}/{filename}"
    except Exception as e:
        logger.error(f"Snapshot saving error: {e}")
        raise
def log_violation(violation_data, log_path="/snapshots/violation_logs.json"):
    """Append *violation_data* to the JSON array stored at *log_path*.

    The file holds a single JSON list. Each call reads it, appends one entry
    and rewrites it, so the cost grows with the number of logged entries. The
    new content is written to a sibling ``.tmp`` file and moved into place
    with ``os.replace``, so a crash mid-write cannot leave a truncated log.

    Args:
        violation_data: JSON-serialisable record to append.
        log_path: Location of the log file; missing parent directories are
            created.

    Raises:
        Exception: Any read, parse or write error is logged and re-raised.
    """
    try:
        log_file = Path(log_path)
        logs = []
        if log_file.exists():
            with open(log_file, "r") as f:
                logs = json.load(f)
        logs.append(violation_data)
        log_file.parent.mkdir(parents=True, exist_ok=True)
        tmp_file = log_file.with_name(log_file.name + ".tmp")
        with open(tmp_file, "w") as f:
            json.dump(logs, f, indent=4)
        os.replace(tmp_file, log_file)
    except Exception as e:
        logger.error(f"Violation logging error: {e}")
        raise
def send_alert(violation):
    """Announce *violation* as an INFO-level log line.

    Reads the ``violation_type`` and ``severity`` keys; a missing key raises
    ``KeyError``.
    """
    kind = violation['violation_type']
    severity = violation['severity']
    logger.info("Alert! %s detected. Severity: %s", kind, severity)
def create_salesforce_violation_record(violation_data, *, timeout=10):
    """Create a ``Safety_Violation_Log__c`` record through the Salesforce REST API.

    Args:
        violation_data: Mapping with the keys ``site_id``, ``camera_id``,
            ``violation_type``, ``timestamp``, ``snapshot_url`` and
            ``severity``.
        timeout: Seconds ``requests`` waits for the connection and for the
            response before raising. Without it an unresponsive server would
            block the caller indefinitely.

    Returns:
        The decoded JSON body of the Salesforce response.

    Raises:
        requests.HTTPError: For a non-2xx response.
        Exception: Any error (network, timeout, HTTP) is logged and re-raised.
    """
    try:
        salesforce_url = f"{SALESFORCE_URL}/services/data/v60.0/sobjects/Safety_Violation_Log__c/"
        headers = {
            'Authorization': f'Bearer {SALESFORCE_TOKEN}',
            'Content-Type': 'application/json'
        }
        violation_obj = {
            'Site_ID__c': violation_data['site_id'],
            'Camera_ID__c': violation_data['camera_id'],
            'Violation_Type__c': violation_data['violation_type'],
            'Timestamp__c': violation_data['timestamp'],
            # NOTE(review): stores whatever the caller passes as snapshot_url;
            # confirm it is a URL Salesforce users can open.
            'Snapshot_URL__c': violation_data['snapshot_url'],
            'Severity__c': violation_data['severity'],
            # Hard-coded flags: every record is marked alerted and unresolved.
            'Alert_Sent__c': True,
            'Resolved__c': False
        }
        response = requests.post(
            salesforce_url,
            headers=headers,
            data=json.dumps(violation_obj),
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        logger.error(f"Salesforce integration error: {e}")
        raise
# --- FastAPI Application ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook: load the YOLO model on startup and drop it on shutdown.

    A failure while constructing the model is logged and re-raised before the
    ``yield``, so the application never reaches its serving phase.
    """
    global yolo_model
    logger.info("FastAPI application starting up")
    try:
        loaded = YOLOv8Model()
    except Exception as err:
        logger.error(f"Startup error: {err}")
        raise
    yolo_model = loaded
    logger.info("YOLOv8 model initialized successfully")

    yield

    logger.info("FastAPI application shutting down")
    yolo_model = None
# ASGI application; the `lifespan` hook controls when the model is loaded and released.
app = FastAPI(
    title="Safety Violation Detection API",
    description="API for detecting safety violations using YOLOv8",
    version="1.0.0",
    lifespan=lifespan,
)
# --- API Endpoints ---
@app.post("/detect_violation/")
def detect_violation():
    """Process the configured RTSP stream and record every detection as a violation.

    The model runs on each frame read from ``RTSP_URL``. Each detected object
    becomes a violation whose severity comes from its confidence: > 0.8 is
    Critical, > 0.5 is Moderate, anything else is Minor. Each violation is
    announced with ``send_alert``, appended to the local log and pushed to
    Salesforce. The request returns only after ``capture_rtsp_frames`` stops
    yielding frames.

    This is declared as a plain ``def`` so FastAPI runs it in its threadpool.
    The stream reads, inference, file writes and HTTP calls are all blocking
    and would otherwise stall the event loop for every other request.

    Raises:
        HTTPException: 500 if the model is not loaded or processing fails.
    """
    # Checked outside the try so the 500 keeps its own detail instead of being
    # re-wrapped as "Error processing stream: ...".
    if yolo_model is None:
        raise HTTPException(status_code=500, detail="YOLO model not initialized")
    try:
        for frame, timestamp in capture_rtsp_frames(RTSP_URL):
            # Pass the raw BGR frame. ultralytics expects BGR NumPy arrays and
            # letterboxes/resizes internally; upload_image does the same.
            results = yolo_model.predict(frame)
            if results.empty:
                continue
            # One snapshot per frame, shared by every detection in it.
            snapshot_url = save_snapshot(frame)
            for _, row in results.iterrows():
                severity = "Critical" if row['conf'] > 0.8 else "Moderate" if row['conf'] > 0.5 else "Minor"
                violation = {
                    'site_id': "Site1",
                    'camera_id': "Camera1",
                    'violation_type': row['name'],
                    'timestamp': timestamp,
                    'snapshot_url': snapshot_url,
                    'severity': severity
                }
                log_violation(violation)
                # Alert before the remote call, so a Salesforce failure does
                # not suppress the local alert.
                send_alert(violation)
                create_salesforce_violation_record(violation)
        return {"status": "Violation detection complete."}
    except Exception as e:
        logger.error(f"Error processing stream: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing stream: {e}") from e
@app.post("/upload_image/")
async def upload_image(file: UploadFile = File(...)):
    """Run the detector on an uploaded image and return its detections.

    Returns:
        ``{"results": DataFrame.to_dict()}``, a mapping of column name to
        ``{row index: value}``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 500 if the model is not loaded or inference fails.
    """
    if yolo_model is None:
        raise HTTPException(status_code=500, detail="YOLO model not initialized")
    try:
        image_data = await file.read()
        if not image_data:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")
        # cv2.imdecode returns None (rather than raising) for undecodable bytes.
        image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Uploaded file is not a decodable image")
        # Inference is blocking; run it in the threadpool, off the event loop.
        results = await run_in_threadpool(yolo_model.predict, image)
        return {"results": results.to_dict()}
    except HTTPException:
        # Deliberate client/server errors pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e
@app.get("/health_check/")
async def health_check():
    """Liveness probe: always reports the service as running."""
    status_text = "Running smoothly"
    return {"status": status_text}
@app.get("/")
async def root():
    """Describe the API: its name, its version and a summary of each endpoint.

    NOTE(review): the routes are registered with a trailing slash (for
    example "/detect_violation/"), but the keys listed here omit it.
    """
    endpoint_summaries = {
        "/detect_violation": "POST - Process RTSP stream for violations",
        "/upload_image": "POST - Analyze uploaded image for violations",
        "/health_check": "GET - Check API status",
    }
    return dict(
        message="Safety Violation Detection API",
        version="1.0.0",
        endpoints=endpoint_summaries,
    )