Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from typing import List, Dict | |
| import cv2 | |
| import torch | |
| import time | |
| import requests | |
| import json | |
| from io import BytesIO | |
| import os | |
| import numpy as np | |
| from pathlib import Path | |
# --- Configuration (config.py) ---
# Each setting can be overridden by an environment variable of the same name,
# so real endpoints and credentials never have to be committed to source.
# The literal fallbacks are placeholders and must be replaced before use.
RTSP_URL = os.environ.get("RTSP_URL", "rtsp://your_rtsp_stream_url")
SALESFORCE_URL = os.environ.get("SALESFORCE_URL", "https://your_salesforce_instance_url")
SALESFORCE_TOKEN = os.environ.get("SALESFORCE_TOKEN", "your_salesforce_access_token")
HUGGINGFACE_API_URL = os.environ.get("HUGGINGFACE_API_URL", "https://huggingface.co/your_model_endpoint")
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", "your_huggingface_api_token")
# --- Initialize FastAPI app ---
# The ASGI application object; the Dockerfile notes below serve it as `app:app`.
app: FastAPI = FastAPI()
| # --- YOLOv8 Model (yolo_model.py) --- | |
class YOLOv8Model:
    """Thin wrapper around a YOLO detector loaded through ``torch.hub``.

    Despite the class name, the weights come from the ``ultralytics/yolov5``
    hub repository by default. That repository exposes YOLOv5 entry points
    such as ``yolov5s`` and has no ``yolov8`` entry point, so the former
    default of ``'yolov8'`` could never be loaded.
    """

    def __init__(self, model_name='yolov5s', repo='ultralytics/yolov5'):
        """Load *model_name* from the ``torch.hub`` repository *repo*.

        Args:
            model_name: hub entry point to load, e.g. ``'yolov5s'``.
            repo: ``torch.hub`` repository in ``owner/name`` form.

        Note: the first call downloads code and weights (network access);
        later calls are served from the torch.hub cache.
        """
        self.model = torch.hub.load(repo, model_name)
        self.model.eval()  # switch the network to inference mode

    def predict(self, image):
        """Run inference on a single image and return its detections.

        Args:
            image: an image the hub model accepts. NOTE(review): per the
                YOLOv5 hub docs, numpy arrays are expected in RGB order.

        Returns:
            The detections of the first (only) image as a pandas DataFrame in
            box-centre/width/height form. Callers in this file read its
            ``name`` and ``confidence`` columns.
        """
        results = self.model(image)
        return results.pandas().xywh[0]
| # --- Preprocessing RTSP Frame (preprocess.py) --- | |
def preprocess_frame(frame, size=(640, 640)):
    """Convert an OpenCV BGR frame to RGB and optionally resize it.

    Args:
        frame: image array as produced by ``cv2.VideoCapture.read`` or
            ``cv2.imdecode`` (BGR channel order).
        size: ``(width, height)`` passed to ``cv2.resize``. ``None`` keeps the
            original resolution. A plain resize does not preserve the aspect
            ratio, and detections are reported in the coordinates of the
            returned image, not of *frame*.

    Returns:
        The RGB image, resized to *size* unless *size* is ``None``.

    Raises:
        ValueError: if *frame* is ``None`` (e.g. a failed read or decode).
    """
    if frame is None:
        raise ValueError("preprocess_frame received no frame (None)")
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    if size is None:
        return rgb
    return cv2.resize(rgb, tuple(size))
| # --- RTSP Stream Handler (rtsp_stream.py) --- | |
def capture_rtsp_frames(rtsp_url: str, min_interval: float = 0.0):
    """Yield ``(frame, timestamp)`` pairs read from a video/RTSP stream.

    Args:
        rtsp_url: source passed to ``cv2.VideoCapture`` (e.g. an RTSP URL).
        min_interval: minimum number of seconds between yielded frames. Frames
            are still read continuously; any frame arriving sooner than
            *min_interval* after the last yielded one is dropped. The default
            ``0.0`` yields every frame.

    Yields:
        ``(frame, timestamp)``, where *frame* is the BGR image returned by
        ``cap.read()`` and *timestamp* is ``time.time()`` taken after the read.

    Raises:
        ConnectionError: if the stream cannot be opened.

    Iteration stops at the first failed read. The capture is released even
    if the consumer stops iterating early (the generator is closed).
    """
    cap = cv2.VideoCapture(rtsp_url)
    try:
        if not cap.isOpened():
            # The URL is deliberately left out of the message: it may embed credentials.
            raise ConnectionError("Unable to open video stream")
        last_yield = None
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break
            now = time.monotonic()  # monotonic clock: immune to wall-clock jumps
            if last_yield is not None and now - last_yield < min_interval:
                continue
            last_yield = now
            yield frame, time.time()
    finally:
        cap.release()
| # --- Save Violations and Snapshots (violation_log.py) --- | |
def save_snapshot(frame):
    """Write *frame* as a JPEG under ``./snapshots`` and return a URL for it.

    The file name combines the current epoch second with a random suffix, so
    several snapshots taken within the same second no longer overwrite each
    other.

    Args:
        frame: image array accepted by ``cv2.imwrite`` (BGR channel order).

    Returns:
        ``http://localhost/snapshots/<file name>``. NOTE(review): nothing in
        this file serves ``./snapshots`` over HTTP, so the URL is only
        meaningful once a static-file route or external host is set up.

    Raises:
        OSError: if OpenCV reports that the image could not be written.
    """
    import uuid  # local import keeps this helper self-contained

    snapshot_dir = Path("./snapshots")
    snapshot_dir.mkdir(parents=True, exist_ok=True)
    filename = f"snapshot_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"
    snapshot_path = snapshot_dir / filename
    if not cv2.imwrite(str(snapshot_path), frame):
        raise OSError(f"Failed to write snapshot to {snapshot_path}")
    return f"http://localhost/snapshots/{filename}"
def log_violation(violation_data, log_path="./violation_logs.json"):
    """Append *violation_data* to a JSON array stored on disk.

    Args:
        violation_data: JSON-serialisable record to append.
        log_path: path of the JSON log file. It is created on first use, and
            an existing empty file is treated as an empty log.

    Raises:
        json.JSONDecodeError: if an existing, non-empty log is not valid JSON.
            The file is left untouched rather than overwritten.
        TypeError: if *violation_data* is not JSON-serialisable. The existing
            log is left intact.

    The new content is written to a temporary file in the same directory and
    moved into place with ``os.replace``, so a failure mid-write cannot leave
    a truncated log behind. Concurrent writers are not coordinated.
    """
    import tempfile

    log_file = Path(log_path)
    logs = []
    if log_file.exists():
        content = log_file.read_text(encoding="utf-8")
        if content.strip():
            logs = json.loads(content)
    logs.append(violation_data)

    fd, tmp_name = tempfile.mkstemp(
        dir=log_file.parent, prefix=f".{log_file.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(logs, f, indent=4)
        os.replace(tmp_name, log_file)
    except BaseException:
        # Discard the partial temp file; the original log is still intact.
        os.unlink(tmp_name)
        raise
| # --- Notification System (notification.py) --- | |
def send_alert(violation):
    """Notify site authorities of a detected violation.

    Currently a stub that prints a one-line message to stdout. Replace the
    body with a real channel (e-mail, SMS, a messaging service) when one is
    available.

    Args:
        violation: mapping that must contain ``'violation_type'`` and
            ``'severity'``. A missing key raises ``KeyError``.
    """
    kind = violation['violation_type']
    level = violation['severity']
    message = f"Alert! {kind} detected. Severity: {level}"
    print(message)
| # --- Salesforce Integration (create_violation.py) --- | |
| def create_salesforce_violation_record(violation_data): | |
| """ | |
| Create a violation record in Salesforce using the REST API. | |
| """ | |
| salesforce_url = f"{SALESFORCE_URL}/services/data/vXX.0/sobjects/Safety_Violation_Log__c/" | |
| headers = { | |
| 'Authorization': f'Bearer {SALESFORCE_TOKEN}', | |
| 'Content-Type': 'application/json' | |
| } | |
| violation_obj = { | |
| 'Site_ID__c': violation_data['site_id'], | |
| 'Violation_Type__c': violation_data['violation_type'], | |
| 'Timestamp__c': violation_data['timestamp'], | |
| 'Snapshot_URL__c': violation_data['snapshot_url'], | |
| 'Severity__c': violation_data['severity'], | |
| 'Alert_Sent__c': True, | |
| 'Resolved__c': False | |
| } | |
| response = requests.post(salesforce_url, headers=headers, data=json.dumps(violation_obj)) | |
| return response.json() | |
| # --- API Routes (api.py) --- | |
def _run_violation_detection():
    """Blocking loop: read frames, run inference, record and announce detections."""
    model = None  # loaded on the first frame, then reused for every later frame
    for frame, timestamp in capture_rtsp_frames(RTSP_URL):
        if model is None:
            model = YOLOv8Model()
        results = model.predict(preprocess_frame(frame))
        if results.empty:
            continue
        # One snapshot per frame, shared by every detection in that frame.
        snapshot_url = save_snapshot(frame)
        for _, row in results.iterrows():
            violation = {
                'site_id': "Site1",  # Placeholder, should be dynamic
                'violation_type': row['name'],
                'timestamp': timestamp,
                'snapshot_url': snapshot_url,
                'severity': row['confidence'],
            }
            create_salesforce_violation_record(violation)  # Log to Salesforce
            send_alert(violation)  # Send alert to site HSE


@app.post("/detect_violation")
async def detect_violation():
    """Run violation detection over the configured RTSP stream.

    Each frame is preprocessed and passed to the YOLO model, and every
    detection is sent to Salesforce and announced through ``send_alert``. The
    blocking capture/inference loop runs in a worker thread
    (``asyncio.to_thread``), so the event loop keeps serving other requests.

    NOTE(review): the loop ends only when a stream read fails, so for a live
    stream this request stays open indefinitely. Every detected class is
    currently treated as a violation; consider filtering by class name.

    Returns:
        ``{"status": "Violation detection complete."}`` once the stream ends.

    Raises:
        HTTPException: 500 carrying the underlying error message on any failure.
    """
    import asyncio

    try:
        await asyncio.to_thread(_run_violation_detection)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing stream: {e}") from e
    return {"status": "Violation detection complete."}
_upload_model = None  # YOLOv8Model shared by /upload_image requests; created on first use


def _get_upload_model():
    """Return the shared detector for uploaded images, loading it on the first call."""
    global _upload_model
    if _upload_model is None:
        _upload_model = YOLOv8Model()
    return _upload_model


@app.post("/upload_image")
async def upload_image(file: UploadFile = File(...)):
    """Run object detection on an uploaded image file.

    Args:
        file: multipart upload containing an encoded image (JPEG, PNG, ...).

    Returns:
        ``{"results": <detections DataFrame as dict>}``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 500 on any other processing error.

    NOTE(review): inference runs on the event-loop thread and blocks other
    requests while it executes.
    """
    try:
        image_data = await file.read()
        image = None
        if image_data:  # cv2.imdecode raises on an empty buffer instead of returning None
            image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(
                status_code=400, detail="Uploaded file is empty or not a decodable image."
            )
        # imdecode yields BGR; feed the model RGB, as preprocess_frame does for stream frames.
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = _get_upload_model().predict(rgb_image)
        return {"results": results.to_dict()}
    except HTTPException:
        raise  # keep client errors as-is instead of re-wrapping them as 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e
@app.get("/health")
async def health_check():
    """Report that the API process is up and serving requests.

    Returns:
        ``{"status": "Running smoothly"}``.
    """
    return {"status": "Running smoothly"}
| # --- Dockerfile --- | |
| # Dockerfile for deployment | |
| # Create a Dockerfile to containerize the FastAPI app | |
| # FROM python:3.9-slim | |
| # WORKDIR /app | |
| # COPY requirements.txt . | |
| # RUN pip install -r requirements.txt | |
| # COPY . . | |
| # EXPOSE 8000 | |
| # CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"] | |
# --- requirements.txt ---
# fastapi
# uvicorn
# python-multipart  # needed by FastAPI to parse File/UploadFile form uploads
# torch
# opencv-python-headless  # headless build: opencv-python needs libGL, absent from python:3.9-slim
# requests
# numpy
# pandas
# simple_salesforce
# pydantic
| # --- tests/test_inference.py --- | |
| # Removed pytest import to avoid the error on app startup | |
def test_inference():
    """Smoke-test the preprocessing + inference pipeline on a blank frame.

    NOTE: constructing ``YOLOv8Model`` loads weights through ``torch.hub``,
    which needs network access on the first run.
    """
    import pandas as pd  # pandas is not imported at module level

    # Create a dummy frame for testing
    frame = np.zeros((640, 640, 3), dtype=np.uint8)
    result = YOLOv8Model().predict(preprocess_frame(frame))
    # Check that result is in the expected format
    assert isinstance(result, pd.DataFrame), "Expected results in DataFrame format"
    assert "name" in result.columns, "Violation class (name) missing in results"
| # --- tests/test_violation_log.py --- | |
def test_log_violation():
    """A logged violation must be persisted as the last entry of the JSON log.

    NOTE: writes to ``./snapshots`` and ``./violation_logs.json`` in the
    current working directory.
    """
    # cv2.imwrite needs an image array; the former "dummy_frame" string made it raise.
    frame = np.zeros((64, 64, 3), dtype=np.uint8)
    violation_data = {
        'site_id': "Site1",
        'violation_type': "No Helmet",
        'timestamp': 1234567890,
        'snapshot_url': save_snapshot(frame),
        'severity': "Critical",
    }
    log_violation(violation_data)
    log_file = Path("./violation_logs.json")
    assert log_file.exists(), "Violation logs file not found"
    with open(log_file, encoding="utf-8") as f:
        logged = json.load(f)
    assert logged[-1] == violation_data, "Violation was not appended to the log"
| # --- tests/test_notification.py --- | |
def test_send_alert():
    """send_alert must emit a message naming the violation type and severity.

    The current implementation prints to stdout, so stdout is captured here;
    update this test if alerts move to a real notification channel.
    """
    import contextlib
    import io

    violation = {'violation_type': 'No Helmet', 'severity': 'Critical'}
    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        send_alert(violation)
    output = captured.getvalue()
    assert 'No Helmet' in output, "Alert message is missing the violation type"
    assert 'Critical' in output, "Alert message is missing the severity"