from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.concurrency import run_in_threadpool
from typing import List, Dict, Optional, Tuple
from contextlib import closing
from itertools import islice
import cv2
import torch
import time
import requests
import json
from io import BytesIO
import os
import functools
import logging
import tempfile
import threading
import uuid
import numpy as np
import pandas as pd
from pathlib import Path

# --- Configuration (config.py) ---
# Every value can be overridden through an environment variable so real
# endpoints and credentials do not have to be committed to source control.
# The defaults are the original placeholders.
RTSP_URL = os.getenv("RTSP_URL", "rtsp://your_rtsp_stream_url")
SALESFORCE_URL = os.getenv("SALESFORCE_URL", "https://your_salesforce_instance_url")
SALESFORCE_TOKEN = os.getenv("SALESFORCE_TOKEN", "your_salesforce_access_token")
# Placeholder: set to a real REST API version string such as "v59.0".
SALESFORCE_API_VERSION = os.getenv("SALESFORCE_API_VERSION", "vXX.0")
SALESFORCE_TIMEOUT_SECONDS = 10.0
HUGGINGFACE_API_URL = os.getenv("HUGGINGFACE_API_URL", "https://huggingface.co/your_model_endpoint")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "your_huggingface_api_token")
SNAPSHOT_DIR = Path(os.getenv("SNAPSHOT_DIR", "./snapshots"))
# NOTE(review): nothing in this app serves SNAPSHOT_DIR over HTTP; URLs built
# from this base only resolve if a web server / static mount exposes the dir.
SNAPSHOT_BASE_URL = os.getenv("SNAPSHOT_BASE_URL", "http://localhost/snapshots")
VIOLATION_LOG_FILE = Path(os.getenv("VIOLATION_LOG_FILE", "./violation_logs.json"))
# torch.hub entrypoint loaded by default (see YOLOv8Model).
YOLO_MODEL_NAME = os.getenv("YOLO_MODEL_NAME", "yolov5s")

logger = logging.getLogger(__name__)

# --- Initialize FastAPI app ---
app = FastAPI()


# --- YOLO Model (yolo_model.py) ---
class YOLOv8Model:
    """Object detector loaded through ``torch.hub``.

    Despite the class name, the default repository is ``ultralytics/yolov5``,
    which exposes YOLOv5 entrypoints (``yolov5n``, ``yolov5s``, ...) plus
    ``custom`` (pass ``path=`` to a weights file). The former default,
    ``'yolov8'``, is not an entrypoint of that repository and could not load.

    Args:
        model_name: torch.hub entrypoint to load.
        repo: torch.hub repository (``"owner/name[:ref]"``).
        **hub_kwargs: forwarded to ``torch.hub.load``, e.g. ``path=...`` for
            ``model_name="custom"`` when using a custom PPE/violation model.
    """

    def __init__(self, model_name: str = YOLO_MODEL_NAME, repo: str = "ultralytics/yolov5", **hub_kwargs):
        self.model = torch.hub.load(repo, model_name, **hub_kwargs)
        self.model.eval()

    def predict(self, image) -> pd.DataFrame:
        """Run inference on one image and return its detections.

        Args:
            image: image accepted by the hub model; this file passes RGB
                numpy arrays (HWC).

        Returns:
            The first DataFrame of ``results.pandas().xywh``: one row per
            detection, including the ``confidence`` and ``name`` columns used
            elsewhere in this file.
        """
        results = self.model(image)
        return results.pandas().xywh[0]


@functools.lru_cache(maxsize=1)
def get_model() -> YOLOv8Model:
    """Return the process-wide detector, loading it on first use.

    Loading through torch.hub may download code and weights, so it must not
    happen per frame or per request.
    """
    return YOLOv8Model()


# --- Preprocessing RTSP Frame (preprocess.py) ---
def preprocess_frame(frame, size: Optional[Tuple[int, int]] = (640, 640)):
    """Convert an OpenCV (BGR) frame to RGB and optionally resize it.

    Args:
        frame: HWC BGR image as returned by OpenCV.
        size: ``(width, height)`` for ``cv2.resize``, or ``None`` to keep the
            original resolution.

    Returns:
        The RGB image, resized when ``size`` is given.
    """
    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    if size is None:
        return img
    # NOTE(review): a plain resize to a square changes the aspect ratio, and
    # box coordinates are then in resized space; yolov5 hub models resize
    # internally, so size=None may be preferable -- confirm before changing.
    return cv2.resize(img, size)


def process_and_predict(frame, size: Optional[Tuple[int, int]] = (640, 640)) -> pd.DataFrame:
    """Preprocess a BGR frame and run the shared detector on it."""
    return get_model().predict(preprocess_frame(frame, size=size))


# --- RTSP Stream Handler (rtsp_stream.py) ---
def capture_rtsp_frames(rtsp_url: str, interval: float = 1.0):
    """Yield ``(frame, timestamp)`` pairs from a video stream.

    At most one frame is yielded per ``interval`` seconds (measured with a
    monotonic clock); frames read in between are dropped. ``interval=0``
    yields every frame. ``timestamp`` is ``time.time()`` when the frame was
    read.

    The capture is released when the stream ends or the generator is closed.

    Raises:
        ConnectionError: the stream could not be opened. The URL is left out
            of the message because RTSP URLs often embed credentials.
    """
    cap = cv2.VideoCapture(rtsp_url)
    try:
        if not cap.isOpened():
            raise ConnectionError("Could not open video stream")
        last_yield = None
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break
            now = time.monotonic()
            if last_yield is not None and now - last_yield < interval:
                continue
            last_yield = now
            yield frame, time.time()
    finally:
        cap.release()


# --- Save Violations and Snapshots (violation_log.py) ---
def save_snapshot(frame, snapshot_dir=None) -> str:
    """Write ``frame`` as a JPEG and return the URL it is expected to be served at.

    Args:
        frame: BGR image (``cv2.imwrite`` expects OpenCV channel order).
        snapshot_dir: target directory; defaults to ``SNAPSHOT_DIR``.

    Returns:
        ``SNAPSHOT_BASE_URL`` joined with the generated file name.

    Raises:
        IOError: OpenCV reported that the image could not be written.
    """
    directory = Path(snapshot_dir) if snapshot_dir is not None else SNAPSHOT_DIR
    directory.mkdir(parents=True, exist_ok=True)
    # A random suffix prevents two snapshots taken in the same second from
    # overwriting each other.
    filename = f"snapshot_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"
    snapshot_path = directory / filename
    if not cv2.imwrite(str(snapshot_path), frame):
        raise IOError(f"Failed to write snapshot {filename}")
    return f"{SNAPSHOT_BASE_URL.rstrip('/')}/{filename}"


# Serializes read-modify-write of the log file across request threads.
_log_lock = threading.Lock()


def log_violation(violation_data, log_file=None) -> None:
    """Append ``violation_data`` to a JSON array stored on disk.

    The file is rewritten through a temporary file plus ``os.replace`` so a
    crash mid-write cannot leave a truncated log behind.

    Args:
        violation_data: JSON-serializable dict describing the violation.
        log_file: path of the JSON log; defaults to ``VIOLATION_LOG_FILE``.

    Raises:
        json.JSONDecodeError: the existing log is not valid JSON (it is not
            silently discarded).
    """
    path = Path(log_file) if log_file is not None else VIOLATION_LOG_FILE
    with _log_lock:
        logs = []
        if path.exists():
            with open(path, "r", encoding="utf-8") as f:
                logs = json.load(f)
        logs.append(violation_data)
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(logs, f, indent=4)
        os.replace(tmp_path, path)


# --- Notification System (notification.py) ---
def send_alert(violation) -> None:
    """Notify site authorities about a detected violation.

    Placeholder: currently only prints the violation type and severity.
    """
    # Integrate an actual notification service here (e.g. SendGrid, Twilio).
    print(f"Alert! {violation['violation_type']} detected. Severity: {violation['severity']}")


# --- Salesforce Integration (create_violation.py) ---
def create_salesforce_violation_record(violation_data, timeout: float = SALESFORCE_TIMEOUT_SECONDS):
    """Create a ``Safety_Violation_Log__c`` record via the Salesforce REST API.

    ``Alert_Sent__c`` is always sent as ``True``, so call this after the alert
    has been sent.

    Args:
        violation_data: dict with ``site_id``, ``violation_type``,
            ``timestamp``, ``snapshot_url`` and ``severity`` keys.
        timeout: seconds to wait for Salesforce before giving up.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.RequestException: network failure, timeout, or non-2xx status.
    """
    salesforce_url = (
        f"{SALESFORCE_URL}/services/data/{SALESFORCE_API_VERSION}"
        "/sobjects/Safety_Violation_Log__c/"
    )
    headers = {
        'Authorization': f'Bearer {SALESFORCE_TOKEN}',
        'Content-Type': 'application/json',
    }
    # NOTE(review): Timestamp__c receives epoch seconds and Severity__c a
    # confidence float; if those fields are DateTime / picklist in the org
    # schema they need converting -- confirm against the object definition.
    violation_obj = {
        'Site_ID__c': violation_data['site_id'],
        'Violation_Type__c': violation_data['violation_type'],
        'Timestamp__c': violation_data['timestamp'],
        'Snapshot_URL__c': violation_data['snapshot_url'],
        'Severity__c': violation_data['severity'],
        'Alert_Sent__c': True,
        'Resolved__c': False,
    }
    response = requests.post(salesforce_url, headers=headers, json=violation_obj, timeout=timeout)
    response.raise_for_status()
    return response.json()


# --- API Routes (api.py) ---
@app.post("/detect_violation/")
def detect_violation(
    max_frames: int = Query(5, ge=1, description="Number of sampled frames to analyse."),
    interval: float = Query(1.0, ge=0, description="Minimum seconds between sampled frames."),
):
    """Sample frames from ``RTSP_URL``, detect violations, and report them.

    Declared as a plain ``def`` so FastAPI runs the blocking capture and
    inference in its threadpool instead of stalling the event loop. The
    number of frames is bounded so the request terminates on a live stream.

    For every detection the handler sends an alert, appends to the local JSON
    log, and creates a Salesforce record. Salesforce failures are logged and
    counted rather than aborting the remaining frames.
    """
    frames_processed = 0
    violations_found = 0
    salesforce_failures = 0
    try:
        with closing(capture_rtsp_frames(RTSP_URL, interval=interval)) as frames:
            for frame, timestamp in islice(frames, max_frames):
                frames_processed += 1
                results = process_and_predict(frame)
                if results.empty:
                    continue
                # One snapshot per frame, shared by all detections in it.
                snapshot_url = save_snapshot(frame)
                # NOTE(review): every detected class is treated as a violation,
                # which assumes a model trained on violation classes.
                for _, row in results.iterrows():
                    violation = {
                        'site_id': "Site1",  # Placeholder, should be dynamic
                        'violation_type': str(row['name']),
                        'timestamp': timestamp,
                        'snapshot_url': snapshot_url,
                        'severity': float(row['confidence']),
                    }
                    violations_found += 1
                    send_alert(violation)  # Send alert to site HSE first
                    log_violation(violation)
                    try:
                        create_salesforce_violation_record(violation)
                    except requests.RequestException:
                        salesforce_failures += 1
                        logger.exception("Failed to create Salesforce record for %s", violation['violation_type'])
    except Exception as e:
        logger.exception("Error processing stream")
        raise HTTPException(status_code=500, detail=f"Error processing stream: {e}") from e
    return {
        "status": "Violation detection complete.",
        "frames_processed": frames_processed,
        "violations_found": violations_found,
        "salesforce_failures": salesforce_failures,
    }


def _decode_image(data: bytes):
    """Decode encoded image bytes to a BGR array, or return None if undecodable."""
    if not data:
        return None
    try:
        return cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    except cv2.error:
        return None


@app.post("/upload_image/")
async def upload_image(file: UploadFile = File(...)):
    """Run detection on an uploaded image and return the detections table.

    The response is ``{"results": {column: {row_index: value}}}``; it is
    built through ``DataFrame.to_json`` so every value is a native JSON type.

    Raises:
        HTTPException: 400 if the upload is not a decodable image, 500 on any
            other failure.
    """
    try:
        image = _decode_image(await file.read())
        if image is None:
            raise HTTPException(status_code=400, detail="Uploaded file is not a decodable image.")
        # Inference is CPU/GPU-bound; keep it off the event loop. size=None
        # keeps box coordinates in the uploaded image's resolution.
        results = await run_in_threadpool(process_and_predict, image, None)
        return {"results": json.loads(results.to_json(double_precision=15))}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error processing image")
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e


@app.get("/health_check/")
async def health_check():
    """Liveness probe."""
    return {"status": "Running smoothly"}


# --- Dockerfile ---
# Dockerfile for deployment. The CMD below assumes this module is saved as
# app.py. On python:*-slim, use opencv-python-headless (or apt-get install
# libgl1) because opencv-python needs system GUI libraries.
# FROM python:3.9-slim
# WORKDIR /app
# COPY requirements.txt .
# RUN pip install -r requirements.txt
# COPY . .
# EXPOSE 8000
# CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]

# --- requirements.txt ---
# fastapi
# uvicorn
# python-multipart      # required by FastAPI for File/UploadFile form parsing
# torch
# opencv-python-headless
# requests
# numpy
# pandas
# simple_salesforce     # currently unused: the REST API is called via requests
# pydantic
# NOTE: torch.hub's ultralytics/yolov5 also needs that repository's own
# requirements (e.g. pillow, pyyaml, tqdm) and network access on first load.

# --- tests/test_inference.py ---
# pytest is deliberately not imported so this module can start without it.
def test_inference():
    # A blank frame should still yield a well-formed (possibly empty) table.
    frame = np.zeros((640, 640, 3), dtype=np.uint8)
    result = process_and_predict(frame)
    assert isinstance(result, pd.DataFrame), "Expected results in DataFrame format"
    assert "name" in result.columns, "Violation class (name) missing in results"


# --- tests/test_violation_log.py ---
def test_log_violation():
    # Use a throwaway directory so the real log and snapshot folder are untouched.
    with tempfile.TemporaryDirectory() as tmp:
        tmp_dir = Path(tmp)
        dummy_frame = np.zeros((10, 10, 3), dtype=np.uint8)
        violation_data = {
            'site_id': "Site1",
            'violation_type': "No Helmet",
            'timestamp': 1234567890,
            'snapshot_url': save_snapshot(dummy_frame, snapshot_dir=tmp_dir / "snapshots"),
            'severity': "Critical",
        }
        log_file = tmp_dir / "violation_logs.json"
        log_violation(violation_data, log_file=log_file)
        assert log_file.exists(), "Violation logs file not found"
        with open(log_file, "r", encoding="utf-8") as f:
            assert json.load(f)[-1] == violation_data, "Logged entry does not match input"


# --- tests/test_notification.py ---
def test_send_alert():
    violation = {'violation_type': 'No Helmet', 'severity': 'Critical'}
    send_alert(violation)
    # Here you can mock or check if send_alert() works as expected