# Author: VijayPulmamidi
# Last commit: 236ae88 "Update app.py" (verified)
import functools
import itertools
import json
import os
import time
import uuid
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional

import cv2
import numpy as np
import requests
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
# --- Configuration (config.py) ---
# Deployment settings. Every value can be overridden by an environment variable
# of the same name; the literals are placeholders only, so real credentials
# never have to be committed to source control.
RTSP_URL = os.getenv("RTSP_URL", "rtsp://your_rtsp_stream_url")
SALESFORCE_URL = os.getenv("SALESFORCE_URL", "https://your_salesforce_instance_url")
SALESFORCE_TOKEN = os.getenv("SALESFORCE_TOKEN", "your_salesforce_access_token")
HUGGINGFACE_API_URL = os.getenv(
    "HUGGINGFACE_API_URL", "https://huggingface.co/your_model_endpoint"
)
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "your_huggingface_api_token")
# --- Initialize FastAPI app ---
# The single ASGI application object: every route decorator below registers on
# it, and the Dockerfile's `uvicorn app:app` command serves it.
app: FastAPI = FastAPI()
# --- YOLOv8 Model (yolo_model.py) ---
@functools.lru_cache(maxsize=4)
def _load_hub_model(model_name):
    """Load an ``ultralytics/yolov5`` hub model in eval mode, cached per name.

    The routes construct a ``YOLOv8Model`` per request (and previously per
    frame); caching here means ``torch.hub.load`` runs once per model name
    instead of on every construction. A failed load is not cached.
    """
    model = torch.hub.load('ultralytics/yolov5', model_name)
    model.eval()
    return model


class YOLOv8Model:
    """Wrapper around an object-detection model loaded through ``torch.hub``.

    Despite the class name, the weights come from the ``ultralytics/yolov5``
    hub repository, which only provides YOLOv5 entry points (e.g.
    ``'yolov5s'``). The class name is kept so existing callers keep working.
    """

    def __init__(self, model_name='yolov5s'):
        """Load (or reuse an already-loaded) hub model.

        Args:
            model_name: entry-point name in the ``ultralytics/yolov5`` hub
                repository. ``'yolov8'`` is not one of them, so it cannot be
                the default.
        """
        self.model = _load_hub_model(model_name)

    def predict(self, image):
        """Run inference on a single image and return its detections.

        Args:
            image: image to run detection on. For numpy input the hub model
                expects an HxWx3 RGB array (``preprocess_frame`` produces one).

        Returns:
            pandas DataFrame with one row per detection; callers in this file
            read its ``name`` and ``confidence`` columns.
        """
        results = self.model(image)
        return results.pandas().xywh[0]  # detections of the first (only) image
# --- Preprocessing RTSP Frame (preprocess.py) ---
def preprocess_frame(frame):
    """Prepare an OpenCV frame for the detector.

    Reorders channels from BGR to RGB, then stretches the image to a fixed
    640x640 canvas (the aspect ratio is not preserved).
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    target_size = (640, 640)  # (width, height), the order cv2.resize expects
    return cv2.resize(rgb, target_size)
# --- RTSP Stream Handler (rtsp_stream.py) ---
def capture_rtsp_frames(rtsp_url: str, min_interval: float = 1.0):
    """Yield ``(frame, timestamp)`` pairs sampled from a video stream.

    Every frame is read, but at most one frame per ``min_interval`` seconds is
    yielded; the rest are dropped. Sleeping instead would likely let buffered
    frames go stale. Iteration ends when the stream is not (or no longer)
    open or a read fails.

    Args:
        rtsp_url: any source ``cv2.VideoCapture`` accepts (RTSP URL, file, ...).
        min_interval: minimum seconds between yielded frames; ``0`` yields
            every frame.

    Yields:
        The frame as returned by ``cap.read()`` and the ``time.time()`` value
        at which it was read.
    """
    cap = cv2.VideoCapture(rtsp_url)
    last_yield = None
    try:
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break
            now = time.time()
            if last_yield is not None and now - last_yield < min_interval:
                continue  # still inside the sampling window: drop this frame
            last_yield = now
            yield frame, now
    finally:
        # Runs on normal exhaustion *and* when the consumer stops early
        # (generator close or exception), so the capture is always released.
        cap.release()
# --- Save Violations and Snapshots (violation_log.py) ---
def save_snapshot(frame):
    """Write ``frame`` as a JPEG under ``./snapshots`` and return its URL.

    The file name combines the Unix second with a random suffix, so several
    snapshots taken within the same second do not overwrite each other.

    Args:
        frame: image array accepted by ``cv2.imwrite``.

    Returns:
        A ``http://localhost/snapshots/<file>`` URL intended for local testing;
        this function does not serve that path itself.

    Raises:
        OSError: if OpenCV reports that the image could not be written.
    """
    snapshot_dir = Path("./snapshots")
    snapshot_dir.mkdir(parents=True, exist_ok=True)
    filename = f"snapshot_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"
    snapshot_path = snapshot_dir / filename
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(str(snapshot_path), frame):
        raise OSError(f"Failed to write snapshot to {snapshot_path}")
    return f"http://localhost/snapshots/{filename}"
def log_violation(violation_data, log_path="./violation_logs.json"):
    """Append ``violation_data`` to a log file holding a JSON array.

    The whole array is rewritten on each call. The new content goes to a
    sibling ``.tmp`` file that then replaces the log via ``os.replace``, so an
    interrupted write cannot leave a truncated, unparsable log behind.

    Args:
        violation_data: JSON-serializable record to append.
        log_path: location of the log file (default ``./violation_logs.json``).

    Raises:
        ValueError: if the existing log is not valid JSON
            (``json.JSONDecodeError``) or is not a JSON array; in both cases
            the existing file is left untouched rather than overwritten.
    """
    log_file = Path(log_path)
    logs = []
    if log_file.exists():
        raw = log_file.read_text(encoding="utf-8")
        if raw.strip():  # an empty file counts as an empty log
            logs = json.loads(raw)
            if not isinstance(logs, list):
                raise ValueError(f"{log_file} does not contain a JSON array")
    logs.append(violation_data)
    tmp_file = log_file.with_name(log_file.name + ".tmp")
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(logs, f, indent=4)
    os.replace(tmp_file, log_file)
# --- Notification System (notification.py) ---
def send_alert(violation):
    """Notify site authorities about a detected violation.

    Currently a stub that prints a one-line alert to stdout. A real channel
    (e.g. SendGrid e-mail or Twilio SMS) would be wired in here.
    """
    kind = violation['violation_type']
    level = violation['severity']
    print(f"Alert! {kind} detected. Severity: {level}")
# --- Salesforce Integration (create_violation.py) ---
def create_salesforce_violation_record(violation_data, api_version="v59.0", timeout=10):
    """Create a ``Safety_Violation_Log__c`` record via the Salesforce REST API.

    Args:
        violation_data: dict with ``site_id``, ``violation_type``,
            ``timestamp``, ``snapshot_url`` and ``severity`` keys.
        api_version: REST API version path segment, e.g. ``"v59.0"``; set it
            to a version the target org supports.
        timeout: seconds to wait for Salesforce, so a stalled connection cannot
            block the caller indefinitely.

    Returns:
        The decoded JSON body of the successful response.

    Raises:
        requests.HTTPError: if Salesforce answers with a 4xx/5xx status.
        requests.RequestException: on connection failures or timeouts.
    """
    salesforce_url = (
        f"{SALESFORCE_URL}/services/data/{api_version}/sobjects/Safety_Violation_Log__c/"
    )
    headers = {
        'Authorization': f'Bearer {SALESFORCE_TOKEN}',
        'Content-Type': 'application/json',
    }
    violation_obj = {
        'Site_ID__c': violation_data['site_id'],
        'Violation_Type__c': violation_data['violation_type'],
        # NOTE(review): passed through unchanged; detect_violation supplies
        # epoch seconds (float) -- confirm Timestamp__c's field type accepts it.
        'Timestamp__c': violation_data['timestamp'],
        'Snapshot_URL__c': violation_data['snapshot_url'],
        'Severity__c': violation_data['severity'],
        # NOTE(review): detect_violation calls send_alert only after this
        # record is created, so True is optimistic -- confirm the intent.
        'Alert_Sent__c': True,
        'Resolved__c': False,
    }
    response = requests.post(
        salesforce_url, headers=headers, json=violation_obj, timeout=timeout
    )
    response.raise_for_status()  # surface failures instead of returning error bodies
    return response.json()
# --- API Routes (api.py) ---
@app.post("/detect_violation/")
def detect_violation(max_frames: Optional[int] = None):
    """Run detection on frames sampled from ``RTSP_URL`` and report each hit.

    For every frame with detections, one snapshot is saved and shared by all of
    that frame's detections. Each detection is then pushed to Salesforce and an
    alert is sent.

    Declared with plain ``def`` rather than ``async def``: capture, inference
    and the Salesforce call all block, and FastAPI runs ``def`` endpoints in a
    worker thread instead of on the event loop, so other requests keep being
    served meanwhile.

    Args:
        max_frames: optional query parameter; stop after this many sampled
            frames. By default the stream is consumed until it ends, which for
            a live camera may be never.
    """
    try:
        model = YOLOv8Model()  # load once per request, not once per frame
        frames = capture_rtsp_frames(RTSP_URL)
        if max_frames is not None:
            frames = itertools.islice(frames, max(max_frames, 0))
        for frame, timestamp in frames:
            results = model.predict(preprocess_frame(frame))
            if results.empty:
                continue
            snapshot_url = save_snapshot(frame)  # one snapshot per frame
            for _, row in results.iterrows():
                violation = {
                    'site_id': "Site1",  # Placeholder, should be dynamic
                    'violation_type': row['name'],
                    'timestamp': timestamp,
                    'snapshot_url': snapshot_url,
                    'severity': row['confidence'],
                }
                create_salesforce_violation_record(violation)  # Log to Salesforce
                send_alert(violation)  # Send alert to site HSE
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing stream: {e}") from e
    return {"status": "Violation detection complete."}
@app.post("/upload_image/")
async def upload_image(file: UploadFile = File(...)):
    """Run object detection on an uploaded image.

    Returns ``{"results": <DataFrame.to_dict()>}``. Responds 400 when the
    upload is empty or cannot be decoded as an image, and 500 on any other
    processing failure.
    """
    try:
        image_data = await file.read()
        # cv2.imdecode asserts on an empty buffer and returns None for bytes it
        # cannot decode; both cases are client errors, not server errors.
        image = (
            cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
            if image_data
            else None
        )
        if image is None:
            raise HTTPException(
                status_code=400, detail="Uploaded file is not a decodable image."
            )
        # OpenCV decodes to BGR; feed the model RGB, as the RTSP path does via
        # preprocess_frame.
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = YOLOv8Model().predict(rgb_image)
        return {"results": results.to_dict()}
    except HTTPException:
        raise  # keep deliberate 4xx responses instead of masking them as 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e
@app.get("/health_check/")
async def health_check():
    """Liveness endpoint: return a fixed status payload while the app is up."""
    payload = {"status": "Running smoothly"}
    return payload
# --- Dockerfile ---
# Dockerfile for deployment
# Create a Dockerfile to containerize the FastAPI app
# FROM python:3.9-slim
# WORKDIR /app
# COPY requirements.txt .
# RUN pip install -r requirements.txt
# COPY . .
# EXPOSE 8000
# CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
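# --- requirements.txt ---
# fastapi
# uvicorn
# python-multipart  # required by FastAPI for File/UploadFile form uploads
# torch
# torchvision  # required by the ultralytics/yolov5 hub model
# opencv-python-headless  # headless build: python:3.9-slim lacks the libGL that opencv-python loads
# requests
# numpy
# pandas
# simple_salesforce
# pydantic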
# --- tests/test_inference.py ---
# Removed pytest import to avoid the error on app startup
def test_inference():
    """Smoke-test preprocessing plus inference on a blank frame.

    The detector is loaded through ``YOLOv8Model`` (``torch.hub``), so the
    first run needs network access to download the model.
    """
    import pandas as pd  # pandas is not imported at module level

    frame = np.zeros((640, 640, 3), dtype=np.uint8)
    # The same two steps the /detect_violation/ route applies to each frame.
    result = YOLOv8Model().predict(preprocess_frame(frame))
    assert isinstance(result, pd.DataFrame), "Expected results in DataFrame format"
    assert "name" in result.columns, "Violation class (name) missing in results"
# --- tests/test_violation_log.py ---
def test_log_violation():
    """Logging a violation creates the log file and stores the record last.

    Runs inside a temporary working directory so the real
    ``./violation_logs.json`` and ``./snapshots`` are not touched.
    """
    import tempfile

    original_cwd = os.getcwd()
    with tempfile.TemporaryDirectory() as tmp_dir:
        os.chdir(tmp_dir)
        try:
            violation_data = {
                'site_id': "Site1",
                'violation_type': "No Helmet",
                'timestamp': 1234567890,
                # save_snapshot hands the frame to cv2.imwrite, so it must be
                # an image array, not a placeholder string.
                'snapshot_url': save_snapshot(np.zeros((8, 8, 3), dtype=np.uint8)),
                'severity': "Critical",
            }
            log_violation(violation_data)
            assert os.path.exists("./violation_logs.json"), "Violation logs file not found"
            with open("./violation_logs.json", encoding="utf-8") as f:
                assert json.load(f)[-1] == violation_data, "Logged record does not match"
        finally:
            # Leave the temp dir before it is deleted (required on Windows).
            os.chdir(original_cwd)
# --- tests/test_notification.py ---
def test_send_alert():
    """send_alert prints a one-line alert naming the violation and severity."""
    import contextlib
    import io

    violation = {'violation_type': 'No Helmet', 'severity': 'Critical'}
    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        send_alert(violation)
    assert captured.getvalue() == "Alert! No Helmet detected. Severity: Critical\n"