# AI_Safety_Demo / app.py
# Author: PrashanthB461
# Last commit: "Update app.py" (5a78d5c, verified)
import os
import cv2
import gradio as gr
import torch
import numpy as np
from simple_salesforce import Salesforce
import time
# Salesforce connection details come from the environment so that no
# credentials live in the source file:
#   SALESFORCE_USERNAME, SALESFORCE_PASSWORD, SALESFORCE_SECURITY_TOKEN
# NOTE(review): the values that used to be hard-coded here have been exposed
# in source history and should be rotated.
#
# `sf` stays None when the login fails, so a Salesforce outage (or missing
# credentials) is reported here instead of crashing the import of the app.
# Code that uses `sf` must be prepared for it to be None.
sf = None
try:
    sf = Salesforce(
        username=os.getenv("SALESFORCE_USERNAME"),
        password=os.getenv("SALESFORCE_PASSWORD"),
        security_token=os.getenv("SALESFORCE_SECURITY_TOKEN"),
    )
    # Simple query to verify the connection and the custom object's visibility.
    sf.query("SELECT Id FROM Safety_Video_Report__c LIMIT 1")
    print("βœ… Salesforce connection successful.")
except Exception as e:
    print(f"❌ Error connecting to Salesforce: {e}")
try:
from ultralytics import YOLO
except ImportError as e:
print("❌ Ultralytics not installed. Run: pip install ultralytics")
raise
# ==========================
# Configuration
# ==========================
# Location of the custom safety checkpoint shipped with the app.
DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
# Stock nano checkpoint used when the custom one is missing.
FALLBACK_MODEL = "yolov8n.pt"
# SAFETY_MODEL_PATH in the environment overrides the default location.
MODEL_PATH = os.environ.get("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)
# Class index reported by the model -> violation name used in reports.
VIOLATION_LABELS = dict(
    enumerate(("no_helmet", "no_harness", "unsafe_posture", "unsafe_zone"))
)
# ==========================
# Device Setup
# ==========================
# Use the GPU when CUDA is available, otherwise the CPU. Any failure while
# probing is reported and the CPU is used instead.
try:
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"βœ… Using device: {device}")
except Exception as probe_error:
    print(f"⚠️ Error setting device: {probe_error}")
    device = torch.device("cpu")
# ==========================
# Load Model (Use YOLOv8n for Faster Inference)
# ==========================
# Choose the checkpoint: the configured path when it exists on disk,
# otherwise the fallback model name.
if not os.path.isfile(MODEL_PATH):
    print(f"⚠️ Model file '{MODEL_PATH}' not found. Falling back to: {FALLBACK_MODEL}")
    selected_model = FALLBACK_MODEL
else:
    selected_model = MODEL_PATH
    print(f"βœ… Found model at: {selected_model}")

# A load failure is logged and then re-raised: the app cannot run without a model.
try:
    model = YOLO(selected_model)
    print(f"βœ… Model loaded: {selected_model}")
except Exception as load_error:
    print(f"❌ Failed to load model: {load_error}")
    raise
# ==========================
# Video Processing with Optimizations
# ==========================
def process_video(video_path, frame_skip=5, max_frames=100, max_seconds=30):
    """Run the safety model over a video and report the detected violations.

    Only every ``frame_skip``-th frame is analysed. Analysis stops at the end
    of the video, after ``max_frames`` analysed frames, or once more than
    ``max_seconds`` seconds have elapsed, whichever comes first. The result is
    then scored and forwarded to Salesforce on a best-effort basis.

    Args:
        video_path: Path of the video file to open with OpenCV.
        frame_skip: Analyse one frame out of this many.
        max_frames: Maximum number of frames sent through the model.
        max_seconds: Wall-clock budget for the analysis loop, in seconds.

    Returns:
        ``(violations, score)`` where ``violations`` is a list of dicts with
        the keys ``frame``, ``violation``, ``confidence`` and ``bounding_box``
        (the values of ``box.xywh`` rounded to two decimals), and ``score`` is
        the value returned by ``calculate_safety_score``. If analysis fails,
        ``([], "Error: <message>")`` is returned instead.
    """
    video = None
    try:
        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file.")

        frame_count = 0
        processed_frame_count = 0
        violations = []
        # Monotonic clock: the budget must not be affected by system clock changes.
        start_time = time.monotonic()

        while True:
            ret, frame = video.read()
            if not ret:
                break

            # Skip frames to reduce processing time (process only every 'frame_skip' frame)
            if frame_count % frame_skip != 0:
                frame_count += 1
                continue

            # Model inference for detecting violations
            results = model(frame, device=device)
            for result in results:
                for box in result.boxes:
                    cls = int(box.cls)
                    conf = float(box.conf)
                    xywh = box.xywh.cpu().numpy()[0]
                    label = VIOLATION_LABELS.get(cls, f"class_{cls}")
                    violations.append({
                        "frame": frame_count,
                        "violation": label,
                        "confidence": round(conf, 2),
                        # Plain Python floats keep the record JSON-serialisable.
                        "bounding_box": [round(float(x), 2) for x in xywh],
                    })

            frame_count += 1
            processed_frame_count += 1

            # Stop processing after a fixed number of frames to save time
            if processed_frame_count >= max_frames:
                break

            # Stop once the wall-clock budget is exhausted
            if time.monotonic() - start_time > max_seconds:
                print(f"⏰ Exceeded {max_seconds} seconds of processing time.")
                break

        score = calculate_safety_score(violations)
    except Exception as e:
        print(f"❌ Error processing video: {e}")
        return [], f"Error: {e}"
    finally:
        # Release the capture on every path, including errors raised mid-loop.
        if video is not None:
            video.release()

    # Reporting is best-effort: a Salesforce failure must not discard the
    # detections that were already computed.
    try:
        send_to_salesforce(violations, score, video_path)
    except Exception as e:
        print(f"❌ Error sending report to Salesforce: {e}")

    return violations, score
# ==========================
# Safety Score Calculation
# ==========================
def calculate_safety_score(violations, penalties=None, base_score=100):
    """Return the compliance score for a list of detected violations.

    Each violation subtracts its penalty from ``base_score``; violation names
    without a penalty entry cost nothing. The score never drops below zero.

    Args:
        violations: Iterable of dicts carrying a ``"violation"`` key, or
            ``None`` / empty for "no violations".
        penalties: Optional mapping of violation name to penalty points.
            Defaults to the built-in table below.
        base_score: Score of a video without violations.

    Returns:
        ``base_score`` minus the summed penalties, floored at 0.

    Raises:
        KeyError: If an entry in ``violations`` has no ``"violation"`` key.
    """
    if penalties is None:
        penalties = {
            "no_helmet": 25,
            "no_harness": 30,
            "unsafe_posture": 20,
            "unsafe_zone": 25,
        }
    total_penalty = sum(penalties.get(v["violation"], 0) for v in violations or ())
    return max(base_score - total_penalty, 0)
def send_to_salesforce(
    violations,
    score,
    video_path,
    *,
    site_name="Construction Site 1",
    uploaded_by="JohnDoe",
    upload_date=None,
    status="Reviewed",
    pdf_report_url="http://path_to_pdf_report",
):
    """Upload the video and create a ``Safety_Video_Report__c`` record for it.

    Errors from the record creation are printed, not raised.

    Args:
        violations: List of violation dicts with the keys ``frame``,
            ``violation`` and ``confidence``.
        score: Compliance score stored in ``Compliance_Score__c``.
        video_path: Path of the analysed video; it is uploaded first and the
            returned ID is stored in ``Video_File__c``.
        site_name: Value for ``Site__c``.
        uploaded_by: Value for ``Uploaded_By__c``.
        upload_date: Value for ``Upload_Date__c``; defaults to the current UTC
            time formatted as ``YYYY-MM-DDTHH:MM:SSZ``.
        status: Value for ``Status__c``.
        pdf_report_url: Value for ``PDF_Report_URL__c``. The default is a
            placeholder; pass the real URL once a report is generated.
    """
    from datetime import datetime, timezone

    if upload_date is None:
        # Record when the report was actually produced, not a fixed date.
        upload_date = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    print("Starting file upload to Salesforce...")
    # Upload the video and get the file ID (None when the upload failed).
    file_id = upload_video_to_salesforce(video_path)
    if file_id is not None:
        print(f"File uploaded to Salesforce with ID: {file_id}")
    else:
        print("⚠️ Video upload failed; the report will be created without an attached file.")

    # One line per violation: frame number, violation type, confidence level.
    violations_details = "\n".join(
        f"Frame {v['frame']}: {v['violation']} (Confidence: {v['confidence']})"
        for v in violations
    )

    # Data to be inserted into Salesforce
    data = {
        'Site__c': site_name,
        'Uploaded_By__c': uploaded_by,
        'Upload_Date__c': upload_date,
        'Status__c': status,
        'Compliance_Score__c': score,
        'Violations_Found__c': len(violations),
        'Violations_Details__c': violations_details,
        'Video_File__c': file_id,  # ID returned by the ContentVersion upload
        'PDF_Report_URL__c': pdf_report_url,
    }

    try:
        print(f"Creating Salesforce record with data: {data}")  # Log data being sent to Salesforce
        result = sf.Safety_Video_Report__c.create(data)
        print(f"βœ… Successfully created new report in Salesforce with ID: {result['id']}")
    except Exception as e:
        print(f"❌ Error creating/updating record in Salesforce: {e}")
# ==========================
# Function to Upload Video to Salesforce
# ==========================
def upload_video_to_salesforce(video_path, library_id=None):
    """Upload a video file to Salesforce as a new ``ContentVersion`` record.

    Args:
        video_path: Path of the video file to upload.
        library_id: Optional ID to publish the file to
            (``FirstPublishLocationId``). Omitted from the request when not
            given.

    Returns:
        The ID of the created ``ContentVersion`` record, or ``None`` when the
        upload fails (the error is printed).

    Raises:
        OSError: If ``video_path`` cannot be opened or read.
    """
    import base64

    with open(video_path, 'rb') as file:
        video_data = file.read()

    content_version_data = {
        'Title': 'Safety Video',
        # Use the real file name so the stored file keeps its actual extension.
        'PathOnClient': os.path.basename(video_path) or 'safety_video.mp4',
        # The request body is JSON, so the binary content must be sent as
        # base64 text; raw bytes cannot be serialised.
        'VersionData': base64.b64encode(video_data).decode('ascii'),
    }
    if library_id:
        content_version_data['FirstPublishLocationId'] = library_id

    try:
        # Create a new ContentVersion record to upload the video
        content_version = sf.ContentVersion.create(content_version_data)
        # The create result carries the new record's ID under the lowercase
        # 'id' key; this is the ContentVersion ID.
        file_id = content_version['id']
        print(f"βœ… Video uploaded to Salesforce with File ID: {file_id}")
        return file_id
    except Exception as e:
        print(f"❌ Error uploading video to Salesforce: {e}")
        return None
# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: analyse the uploaded video and format the outputs.

    Args:
        video_file: Path of the uploaded video, or a falsy value when nothing
            was uploaded.

    Returns:
        ``(violations, message)`` for the JSON and text outputs. ``message``
        is the formatted safety score, the prompt to upload a file, or the
        error text reported by ``process_video``.
    """
    if not video_file:
        # The prompt belongs in the text output; a plain string handed to the
        # JSON output would be parsed as JSON and fail.
        return [], "Please upload a video file."

    violations, score = process_video(video_file)
    if isinstance(score, str):
        # process_video reports failures as an "Error: ..." string; show it
        # as-is instead of wrapping it as a percentage.
        return violations, score
    return violations, f"Safety Score: {score}%"
# UI definition: one video input; the callback's two return values feed the
# JSON violations panel and the score text box, in that order.
interface = gr.Interface(
    gradio_interface,
    gr.Video(label="Upload Site Video"),
    [
        gr.JSON(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture).",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    print("πŸš€ Launching Safety Analyzer App...")
    interface.launch()