# AI_Safety_Demo1 / app.py
# Author: PrashanthB461
# Last commit: "Update app.py" (8a4e253, verified)
import os
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
import logging
from retrying import retry
import uuid
# ==========================
# Configuration
# ==========================
CONFIG = {
    # Lightweight model; it must be trained to detect only the violation
    # classes listed in VIOLATION_LABELS.
    "MODEL_PATH": "yolov8n.pt",
    # Where PDFs, snapshots and temporary video copies are written.
    "OUTPUT_DIR": "static/output",
    # Model class index -> internal violation label.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",  # Ignored in processing
    },
    # User-friendly names for the violation labels that are reported.
    "DISPLAY_NAMES": {
        "no_helmet": "Missing Helmet",
        "no_harness": "Missing Harness",
        "unsafe_posture": "Unsafe Posture",
    },
    # Salesforce login, read from the environment (e.g. Space secrets) so no
    # secrets live in source control.
    # NOTE(review): earlier revisions of this file committed real
    # credentials; treat them as compromised and rotate them.
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", ""),
        "password": os.environ.get("SF_PASSWORD", ""),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", ""),
        "domain": os.environ.get("SF_DOMAIN", "login"),
    },
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
    "FRAME_SKIP": 15,  # Process every 15th frame
    "MAX_PROCESSING_TIME": 25,  # Cap video processing at 25s
    "CONFIDENCE_THRESHOLD": 0.5,  # Minimum confidence for violation detection
}
# Logging: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The output directory holds generated PDFs, snapshots and temp videos;
# create it up front so later writes do not fail.
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# ==========================
# Device Setup
# ==========================
# Prefer the GPU when torch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
# ==========================
# Model Loading
# ==========================
def load_model():
    """Load the YOLO detector named by CONFIG["MODEL_PATH"] onto ``device``.

    process_video turns class indices into labels through
    CONFIG["VIOLATION_LABELS"]. A model whose own class names disagree with
    that table would have its detections reported as the wrong violations.
    This function therefore logs a warning listing every mismatched index.

    Returns:
        The loaded ``YOLO`` model.

    Raises:
        Exception: anything raised while loading; it is logged and re-raised.
    """
    try:
        model = YOLO(CONFIG["MODEL_PATH"]).to(device)
    except Exception as e:
        logger.exception(f"Failed to load model: {e}")
        raise
    logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")

    # NOTE(review): assumes ultralytics exposes class names as `model.names`
    # (a dict, or a list indexed by class id) - confirm for the pinned version.
    names = getattr(model, "names", None)
    if isinstance(names, (list, tuple)):
        names = dict(enumerate(names))
    if not isinstance(names, dict):
        logger.warning(
            "Could not read class names from the model; the mapping in "
            "CONFIG['VIOLATION_LABELS'] is unverified."
        )
        return model

    mismatched = {
        idx: names.get(idx)
        for idx, expected in CONFIG["VIOLATION_LABELS"].items()
        if names.get(idx) != expected
    }
    if mismatched:
        logger.warning(
            "Model class names do not match CONFIG['VIOLATION_LABELS'] at "
            f"indices {mismatched}; detections will be mislabelled. Use a model "
            "trained for 'no_helmet', 'no_harness', 'unsafe_posture'."
        )
    return model


model = load_model()
# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=2, wait_fixed=1000)
def connect_to_salesforce():
    """Open a Salesforce session with CONFIG["SF_CREDENTIALS"] and check it works.

    The ``retry`` decorator makes up to 2 attempts, 1 s apart. Exceptions are
    logged and re-raised so the decorator (and then the caller) sees them.

    Returns:
        A ``Salesforce`` client on which ``describe()`` succeeded.
    """
    try:
        sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
        # describe() is called as a sanity check of the session; report
        # success only after it returns.
        sf.describe()
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise
    logger.info("Connected to Salesforce")
    return sf
def generate_violation_pdf(violations, score):
    """Render the violation report as a PDF, save it to OUTPUT_DIR and return it.

    Args:
        violations: violation dicts with "violation", "timestamp" (seconds)
            and "confidence" keys.
        score: compliance score printed in the header as a percentage.

    Returns:
        A tuple ``(pdf_path, public_url, pdf_file)``:
            - pdf_path: the path of the saved file.
            - public_url: the file's URL under CONFIG["PUBLIC_URL_BASE"].
            - pdf_file: the PDF as a ``BytesIO`` rewound to position 0.
        Returns ``("", "", None)`` if anything fails; the error is logged.
    """
    left, top, bottom, line_height = 1 * inch, 10 * inch, 1 * inch, 0.3 * inch
    try:
        # Random suffix so two reports generated in the same second (e.g.
        # concurrent requests) do not overwrite each other.
        pdf_filename = f"violations_{int(time.time())}_{uuid.uuid4().hex[:8]}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        c.setFont("Helvetica", 12)
        c.drawString(left, top, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 10)
        y_position = 9.5 * inch

        def draw_line(text):
            """Draw one body line, starting a new page first if this one is full."""
            nonlocal y_position
            # Breaking *before* drawing means a new page is only started when
            # there is text to put on it, so no trailing blank page.
            if y_position < bottom:
                c.showPage()
                c.setFont("Helvetica", 10)
                y_position = top
            c.drawString(left, y_position, text)
            y_position -= line_height

        draw_line(f"Compliance Score: {score}%")
        draw_line(f"Violations Found: {len(violations)}")
        draw_line(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
        y_position -= line_height  # blank line before the details section
        draw_line("Violation Details:")
        if not violations:
            draw_line("No violations detected.")
        for v in violations:
            display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
            draw_line(f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})")
        c.showPage()
        c.save()

        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        pdf_file.seek(0)
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None
@retry(stop_max_attempt_number=2, wait_fixed=1000)
def _create_content_version(sf, content_version_data):
    """Create one ContentVersion; exceptions propagate so ``retry`` can retry once.

    NOTE(review): if a create succeeds server-side but the response is lost,
    the retry uploads a duplicate file.
    """
    return sf.ContentVersion.create(content_version_data)


def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Upload the report PDF as a Salesforce file first published to ``report_id``.

    The retry lives on ``_create_content_version``. A ``@retry`` on this
    function would never fire, because every exception is handled below.

    Args:
        sf: an authenticated ``Salesforce`` client.
        pdf_file: in-memory PDF (anything whose ``getvalue()`` returns bytes).
        report_id: id of the record the file is attached to.

    Returns:
        The download URL of the new ContentVersion, or "" on any failure.
        Failures are logged, never raised.
    """
    if not pdf_file:
        logger.error("No PDF file provided for upload")
        return ""
    try:
        # One timestamp so Title and PathOnClient always agree.
        stamp = int(time.time())
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode("utf-8")
        content_version = _create_content_version(sf, {
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        })
        version_id = content_version["id"]
        # Confirm the version is queryable before handing out a URL. The id
        # comes from the Salesforce response, not from user input.
        result = sf.query(f"SELECT Id FROM ContentVersion WHERE Id = '{version_id}'")
        if not result["records"]:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Salesforce report record for an analysed video and attach its PDF.

    A ``Safety_Video_Report__c`` record is tried first. If that create fails,
    the function falls back to an ``Account`` named ``Safety_Report_<epoch>``.
    When a PDF is supplied, it is uploaded and its URL is written to the
    record that was actually created: ``PDF_Report_URL__c``, or
    ``Description`` for the Account fallback.

    Deliberately not decorated with ``@retry``. Every failure is handled
    here, and a retry after a successful create would duplicate the record.

    Args:
        violations: violation dicts with "violation", "timestamp", "confidence".
        score: compliance score stored in ``Compliance_Score__c``.
        pdf_path: local path of the generated PDF ("" if none). It is used
            to build the initial public URL.
        pdf_file: in-memory PDF to upload, or None.

    Returns:
        ``(record_id, pdf_url)``. Returns ``(None, "")`` when no record could
        be created. Once a record exists, its id is returned even if
        attaching the PDF fails.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            target, url_field = sf.Safety_Video_Report__c, "PDF_Report_URL__c"
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            target, url_field = sf.Account, "Description"
            logger.warning(f"Fell back to Account record: {record['id']}")
        record_id = record["id"]
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}")
        return None, ""

    if pdf_file:
        # Best effort: the record already exists, so a failure here must not
        # throw its id away.
        try:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                pdf_url = uploaded_url
                target.update(record_id, {url_field: uploaded_url})
                logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
        except Exception as e:
            logger.error(f"Failed to attach PDF URL to record {record_id}: {e}")
    return record_id, pdf_url
# ==========================
# Safety Score Calculation
# ==========================
def calculate_safety_score(violations):
    """Return a 0-100 compliance score for a list of violation records.

    Scoring rules:
        - The score starts at 100.
        - Every record whose "violation" is a known type deducts that type's
          penalty.
        - Unknown types cost nothing.
        - The result never drops below 0.
    """
    penalty_for = {"no_helmet": 25, "no_harness": 30, "unsafe_posture": 20}
    total_penalty = sum(penalty_for.get(record["violation"], 0) for record in violations)
    return max(100 - total_penalty, 0)
# ==========================
# Video Processing
# ==========================
# Detector labels reported as violations. Any other label, including
# "unsafe_zone" and unknown class ids, is ignored.
_TARGET_VIOLATIONS = ("no_helmet", "no_harness", "unsafe_posture")
# Used when OpenCV cannot report a usable frame rate (0 or NaN).
_FALLBACK_FPS = 30.0


def _empty_video_result():
    """Return the result dict used when there is nothing to report."""
    return {
        "violations": [],
        "snapshots": [],
        "score": 100,
        "salesforce_record_id": None,
        "violation_details_url": "",
    }


def _discard_temp_file(path):
    """Delete ``path`` if it exists; other OS errors are logged, not raised."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass  # the temp file was never written
    except OSError as e:
        logger.warning(f"Could not remove temporary file {path}: {e}")


def _detect_frame_violations(frame, frame_count, fps):
    """Run the model on one frame and return at most one violation per label.

    Only labels in _TARGET_VIOLATIONS are kept, and only when their
    confidence is at or above CONFIG["CONFIDENCE_THRESHOLD"].
    """
    found, seen = [], set()
    for result in model(frame, device=device):
        for box in result.boxes:
            cls, conf = int(box.cls), float(box.conf)
            label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
            if label not in _TARGET_VIOLATIONS:
                logger.info(f"Ignoring detection: {label} (cls: {cls}, conf: {conf}) - not a target violation")
                continue
            if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
                logger.info(f"Skipping low-confidence detection: {label} (conf: {conf})")
                continue
            if label in seen:
                continue
            seen.add(label)
            found.append({
                "frame": frame_count,
                "violation": label,
                "confidence": round(conf, 2),
                # float() yields plain Python floats; NumPy >= 2 would
                # otherwise print the values as np.float32(...) in the UI.
                "bounding_box": [round(float(x), 2) for x in box.xywh.cpu().numpy()[0]],
                "timestamp": frame_count / fps,
            })
    return found


def _save_snapshot(frame, frame_count, label, run_id):
    """Save ``frame`` as a JPEG in OUTPUT_DIR and return its snapshot record."""
    snapshot_name = f"snapshot_{run_id}_{frame_count}_{label}.jpg"
    snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_name)
    cv2.imwrite(snapshot_path, frame)
    with open(snapshot_path, "rb") as img_file:
        img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
    return {
        "violation": label,
        "frame": frame_count,
        "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_name}",
        "snapshot_base64": f"data:image/jpeg;base64,{img_base64}",
    }


def _scan_video(video_path, run_id):
    """Sample the saved video and collect violations and snapshots.

    Scanning behaviour:
        - Every CONFIG["FRAME_SKIP"]-th frame of the first minute is analysed.
        - Scanning stops early once CONFIG["MAX_PROCESSING_TIME"] seconds
          have passed.
        - One snapshot is kept per violation type.

    Raises:
        ValueError: if OpenCV cannot open the file.
    """
    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise ValueError("Could not open video file")
        fps = video.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):  # also rejects NaN
            logger.warning(f"Invalid FPS reported ({fps}); assuming {_FALLBACK_FPS}")
            fps = _FALLBACK_FPS
        max_frames = int(60 * fps)  # Process up to 1 minute
        violations, snapshots, snapshot_labels = [], [], set()
        start_time = time.time()
        frame_count = 0
        while frame_count < max_frames:
            ret, frame = video.read()
            if not ret:
                break
            if frame_count % CONFIG["FRAME_SKIP"] == 0:
                if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
                    logger.info("Processing time limit reached")
                    break
                for violation in _detect_frame_violations(frame, frame_count, fps):
                    violations.append(violation)
                    label = violation["violation"]
                    if label not in snapshot_labels:
                        snapshots.append(_save_snapshot(frame, frame_count, label, run_id))
                        snapshot_labels.add(label)
            frame_count += 1
        return violations, snapshots
    finally:
        video.release()


def process_video(video_data):
    """Analyse an uploaded video for safety violations and report them.

    Steps:
        1. Write the bytes to a uniquely named temp file under OUTPUT_DIR.
           The file is always removed afterwards, even on error.
        2. Scan the file with the model.
        3. If violations are found, score them, render a PDF and push the
           report to Salesforce.

    Args:
        video_data: raw bytes of the uploaded video file.

    Returns:
        A dict with "violations", "snapshots", "score",
        "salesforce_record_id" and "violation_details_url".
        - No violations: the lists are empty and the score is 100.
        - Error: the same empty result, plus an "error" key holding the
          message, so callers can tell a failure from a clean video.
    """
    # run_id keeps uploads that arrive within the same second apart.
    run_id = uuid.uuid4().hex[:8]
    video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}_{run_id}.mp4")
    try:
        with open(video_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")
        violations, snapshots = _scan_video(video_path, run_id)
        if not violations:
            logger.info("No violations detected")
            return _empty_video_result()
        score = calculate_safety_score(violations)
        pdf_path, _, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "salesforce_record_id": report_id,
            "violation_details_url": final_pdf_url,
        }
    except Exception as e:
        logger.error(f"Error processing video: {e}")
        result = _empty_video_result()
        result["error"] = str(e)
        return result
    finally:
        _discard_temp_file(video_path)
# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: analyse the uploaded video and format the five output panels.

    Args:
        video_file: path of the uploaded video as passed in by ``gr.Video``.
            It is falsy when nothing was uploaded.

    Returns:
        A 5-tuple of strings:
            - the violations markdown table
            - the compliance score text
            - the snapshots markdown
            - the Salesforce record text
            - the details URL
        On failure, the first element carries the error message.
    """
    if not video_file:
        return "No file uploaded.", "", "No file uploaded.", "", ""
    try:
        with open(video_file, "rb") as handle:
            raw_bytes = handle.read()
        result = process_video(raw_bytes)

        def pretty(label):
            # User-facing name for a violation label, falling back to the label itself.
            return CONFIG["DISPLAY_NAMES"].get(label, label)

        if result["violations"]:
            table_lines = [
                "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |",
                "|------------------|-----------|------------|--------------------------|-------------------------|",
            ]
            table_lines += [
                f"| {pretty(v['violation']):<16} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
                for v in result["violations"]
            ]
            violation_table = "\n".join(table_lines)
        else:
            violation_table = "No violations detected."

        if result["snapshots"]:
            snapshots_text = "\n".join(
                f"- Snapshot for {pretty(s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
                for s in result["snapshots"]
            )
        else:
            snapshots_text = "No snapshots captured."

        score_text = f"Safety Score: {result['score']}%"
        record_text = f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}"
        details_text = result["violation_details_url"] or "N/A"
        return violation_table, score_text, snapshots_text, record_text, details_text
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}")
        return f"Error: {str(e)}", "", "Error in processing.", "", ""
# Gradio UI: one video input and five outputs, listed in the same order as
# the tuple returned by gradio_interface.
interface = gr.Interface(
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (Missing Helmet, Missing Harness, Unsafe Posture). Non-violations are ignored.",
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL"),
    ],
)

if __name__ == "__main__":
    logger.info("Launching Safety Analyzer App...")
    interface.launch()