# AI_Safety_Demo7 / app.py
# Hugging Face file-view header (kept for provenance):
#   author avatar: PrashanthB461
#   commit: "Update app.py" (8a4e253, verified)
#   size: 15.5 kB
import os
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
import logging
from retrying import retry
import uuid
# ==========================
# Configuration
# ==========================
# Central application settings.
#
# Salesforce credentials may be supplied through the SF_USERNAME, SF_PASSWORD,
# SF_SECURITY_TOKEN and SF_DOMAIN environment variables; the literals below are
# only fallbacks kept so existing deployments continue to work.
# SECURITY NOTE(review): the fallback credentials are committed in plain text.
# They should be rotated and removed from the source once the environment
# variables are configured.
CONFIG = {
    "MODEL_PATH": "yolov8n.pt",  # Lightweight model, must be trained for violations only
    "OUTPUT_DIR": "static/output",
    # Model class index -> internal violation label.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",  # Ignored in processing
    },
    "DISPLAY_NAMES": {  # Mapping for user-friendly violation names
        "no_helmet": "Missing Helmet",
        "no_harness": "Missing Harness",
        "unsafe_posture": "Unsafe Posture",
    },
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", "prashanth1ai@safety.com"),
        "password": os.environ.get("SF_PASSWORD", "SaiPrash461"),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", "AP4AQnPoidIKPvSvNEfAHyoK"),
        "domain": os.environ.get("SF_DOMAIN", "login"),
    },
    # NOTE(review): this URL names the AI_Safety_Demo1 space — confirm it is the
    # space that actually serves the files written to OUTPUT_DIR.
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
    "FRAME_SKIP": 15,  # Process every 15th frame
    "MAX_PROCESSING_TIME": 25,  # Cap video processing at 25s
    "CONFIDENCE_THRESHOLD": 0.5,  # Minimum confidence for violation detection
}
# Setup logging: timestamped INFO-level records for the whole module.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Ensure output directory exists (no error if it is already there).
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# ==========================
# Device Setup
# ==========================
# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")
logger.info(f"Using device: {device}")
# ==========================
# Model Loading
# ==========================
def load_model():
    """Load the YOLO weights named by CONFIG["MODEL_PATH"] onto ``device``.

    Returns:
        The loaded model object.

    Raises:
        Exception: any error raised during loading is logged and re-raised.
    """
    try:
        weights = CONFIG["MODEL_PATH"]
        detector = YOLO(weights).to(device)
        logger.info(f"Model loaded: {weights}")
        # Reminder for operators: the label mapping in CONFIG only makes sense
        # for weights trained on the violation classes.
        logger.warning("Ensure yolov8n.pt is trained to detect ONLY 'no_helmet', 'no_harness', 'unsafe_posture'. Replace with custom-trained yolov8_safety.pt if unexpected classes are detected.")
        return detector
    except Exception as exc:
        logger.error(f"Failed to load model: {exc}")
        raise


# Loaded once at import time and shared by every request.
model = load_model()
# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=2, wait_fixed=1000)
def connect_to_salesforce():
    """Open a Salesforce session using CONFIG["SF_CREDENTIALS"].

    After logging in, ``describe()`` is called on the session before it is
    returned. Any failure is logged and re-raised, which lets the ``@retry``
    decorator make its second attempt.

    Returns:
        The connected ``Salesforce`` client.
    """
    try:
        credentials = CONFIG["SF_CREDENTIALS"]
        client = Salesforce(**credentials)
        logger.info("Connected to Salesforce")
        client.describe()
        return client
    except Exception as exc:
        logger.error(f"Salesforce connection failed: {exc}")
        raise
def generate_violation_pdf(violations, score):
    """Render the violation report as a PDF and store it in the output directory.

    Args:
        violations: list of violation dicts; each is read for its
            "violation", "timestamp" and "confidence" entries.
        score: compliance score printed in the report summary.

    Returns:
        ``(pdf_path, public_url, pdf_file)`` where ``pdf_file`` is the in-memory
        ``BytesIO`` holding the PDF (rewound to position 0), or
        ``("", "", None)`` if anything fails.
    """
    try:
        filename = f"violations_{int(time.time())}.pdf"
        target_path = os.path.join(CONFIG["OUTPUT_DIR"], filename)
        buffer = BytesIO()
        pdf = canvas.Canvas(buffer, pagesize=letter)

        left_margin = 1 * inch
        bottom_margin = 1 * inch
        line_step = 0.3 * inch

        # Title
        pdf.setFont("Helvetica", 12)
        pdf.drawString(left_margin, 10 * inch, "Worksite Safety Violation Report")
        pdf.setFont("Helvetica", 10)

        # Summary section, one "key: value" line per entry
        cursor_y = 9.5 * inch
        summary_lines = (
            f"Compliance Score: {score}%",
            f"Violations Found: {len(violations)}",
            f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}",
        )
        for summary_line in summary_lines:
            pdf.drawString(left_margin, cursor_y, summary_line)
            cursor_y -= line_step

        # Blank line, then the details heading
        cursor_y -= line_step
        pdf.drawString(left_margin, cursor_y, "Violation Details:")
        cursor_y -= line_step

        if violations:
            display_names = CONFIG["DISPLAY_NAMES"]
            for entry in violations:
                shown_name = display_names.get(entry["violation"], entry["violation"])
                pdf.drawString(
                    left_margin,
                    cursor_y,
                    f"{shown_name} at {entry['timestamp']:.2f}s (Confidence: {entry['confidence']})",
                )
                cursor_y -= line_step
                # Start a new page once the cursor drops below the bottom margin.
                if cursor_y < bottom_margin:
                    pdf.showPage()
                    pdf.setFont("Helvetica", 10)
                    cursor_y = 10 * inch
        else:
            pdf.drawString(left_margin, cursor_y, "No violations detected.")

        pdf.showPage()
        pdf.save()
        buffer.seek(0)

        # Persist a copy on disk so it can be referenced through the public URL.
        with open(target_path, "wb") as out:
            out.write(buffer.getvalue())

        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{filename}"
        logger.info(f"PDF generated: {public_url}")
        return target_path, public_url, buffer
    except Exception as exc:
        logger.error(f"Error generating PDF: {exc}")
        return "", "", None
# NOTE(review): every exception is caught inside the function and turned into
# an empty string, so this decorator (configured without a result predicate)
# presumably never gets to make a second attempt — confirm whether that is
# intended before changing it, since a retried create could duplicate files.
@retry(stop_max_attempt_number=2, wait_fixed=1000)
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Upload the PDF as a ContentVersion published to ``report_id``.

    Args:
        sf: connected Salesforce client.
        pdf_file: in-memory file object exposing ``getvalue()``; a falsy value
            aborts the upload.
        report_id: id of the record the file is first published to.

    Returns:
        The download URL of the created ContentVersion, or ``""`` when no file
        was given, the created version cannot be queried back, or any error
        occurs.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""

        # VersionData is sent as base64 text.
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        payload = {
            "Title": f"Safety_Violation_Report_{int(time.time())}",
            "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        }
        created = sf.ContentVersion.create(payload)
        version_id = created['id']

        # Read the row back to confirm the version exists.
        lookup = sf.query(f"SELECT Id FROM ContentVersion WHERE Id = '{version_id}'")
        if not lookup['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""

        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as exc:
        logger.error(f"Error uploading PDF to Salesforce: {exc}")
        return ""
@retry(stop_max_attempt_number=2, wait_fixed=1000)
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Salesforce record for the report and attach the PDF to it.

    A ``Safety_Video_Report__c`` record is created first; if that fails, an
    ``Account`` record is created as a fallback. When a PDF is supplied it is
    uploaded and its URL is written back to whichever record type was
    actually created.

    Args:
        violations: list of violation dicts ("violation", "timestamp",
            "confidence" are read).
        score: compliance score stored on the record.
        pdf_path: path of the PDF on disk; only its basename is used to build
            the public URL. May be empty.
        pdf_file: in-memory PDF passed to ``upload_pdf_to_salesforce``; may be
            ``None``.

    Returns:
        ``(record_id, pdf_url)``. ``pdf_url`` is the Salesforce download URL
        when the upload succeeded, otherwise the public URL (or ``""``).
        Returns ``(None, "")`` if connecting or creating the record fails.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")

        # Track which object type holds the record so the URL update below
        # targets that same type instead of guessing via exceptions.
        used_account_fallback = False
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")
            used_account_fallback = True
        record_id = record["id"]

        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                pdf_url = uploaded_url
                # The record and the upload already exist at this point, so a
                # failed write-back is logged rather than discarding record_id.
                if used_account_fallback:
                    try:
                        sf.Account.update(record_id, {"Description": uploaded_url})
                        logger.info(f"Updated Account record {record_id} with PDF URL")
                    except Exception as e:
                        logger.error(f"Failed to update Account record {record_id}: {e}")
                else:
                    try:
                        sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                        logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                    except Exception as e:
                        logger.error(f"Failed to update Safety_Video_Report__c: {e}")
        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}")
        return None, ""
# ==========================
# Safety Score Calculation
# ==========================
def calculate_safety_score(violations):
    """Return the compliance score (0-100) for a list of violations.

    Starts from 100 and subtracts a fixed penalty for every entry whose
    "violation" value is a known type; unknown types cost nothing. The result
    never goes below 0.
    """
    penalty_by_type = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
    }
    total_penalty = sum(
        penalty_by_type.get(entry["violation"], 0) for entry in violations
    )
    return max(100 - total_penalty, 0)
# ==========================
# Video Processing
# ==========================
def _empty_video_result():
    """Return the payload used when no violations are reported."""
    return {
        "violations": [],
        "snapshots": [],
        "score": 100,
        "salesforce_record_id": None,
        "violation_details_url": "",
    }


def process_video(video_data):
    """Scan an uploaded video for safety violations and report them.

    The bytes are written to a temporary file in the output directory, every
    ``FRAME_SKIP``-th frame is run through the model (up to 60 seconds of
    footage and ``MAX_PROCESSING_TIME`` seconds of wall-clock time), and each
    target violation type is recorded at most once per processed frame. One
    snapshot is saved per violation type. If anything was found, a PDF is
    generated and pushed to Salesforce.

    Args:
        video_data: raw bytes of the video file.

    Returns:
        dict with "violations", "snapshots", "score", "salesforce_record_id"
        and "violation_details_url".

        NOTE: on any error the empty payload (score 100) is returned, as the
        original code did, so callers cannot tell a failure from a clean
        video.
    """
    video_path = None
    video = None
    try:
        # The uuid suffix keeps two uploads arriving in the same second from
        # sharing (and deleting) each other's temp file.
        video_path = os.path.join(
            CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}_{uuid.uuid4().hex}.mp4"
        )
        with open(video_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")

        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file")

        violations, snapshots = [], []
        frame_count = 0
        start_time = time.time()

        fps = video.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0 (or NaN); without a usable rate the
            # frame cap would be 0 and timestamps would divide by zero.
            fallback_fps = 30.0  # assumed nominal rate, used only for the cap and timestamps
            logger.warning(f"Invalid FPS reported ({fps}); assuming {fallback_fps}")
            fps = fallback_fps
        max_frames = int(60 * fps)  # Process up to 1 minute

        # Track one snapshot per violation type; the keys double as the set of
        # violation labels that are processed at all.
        snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}

        while True:
            ret, frame = video.read()
            if not ret or frame_count >= max_frames:
                break
            if frame_count % CONFIG["FRAME_SKIP"] != 0:
                frame_count += 1
                continue
            # Stop if processing time exceeds the configured limit
            if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
                logger.info("Processing time limit reached")
                break

            results = model(frame, device=device)
            seen_violations = set()  # labels already recorded for this frame
            for result in results:
                for box in result.boxes:
                    cls, conf = int(box.cls), float(box.conf)
                    label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
                    # Only process specified violations
                    if label not in snapshot_taken:
                        logger.info(f"Ignoring detection: {label} (cls: {cls}, conf: {conf}) - not a target violation")
                        continue
                    # Apply confidence threshold
                    if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
                        logger.info(f"Skipping low-confidence detection: {label} (conf: {conf})")
                        continue
                    if label in seen_violations:
                        continue
                    seen_violations.add(label)

                    violations.append({
                        "frame": frame_count,
                        "violation": label,
                        "confidence": round(conf, 2),
                        # Convert to plain floats so the values render cleanly
                        # in the report table instead of as numpy scalars.
                        "bounding_box": [round(float(x), 2) for x in box.xywh.cpu().numpy()[0]],
                        "timestamp": frame_count / fps,
                    })

                    # Save only one snapshot per violation type
                    if snapshot_taken[label]:
                        continue
                    snapshot_path = os.path.join(
                        CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg"
                    )
                    if not cv2.imwrite(snapshot_path, frame):
                        # Leave snapshot_taken False so a later frame can retry.
                        logger.warning(f"Could not write snapshot: {snapshot_path}")
                        continue
                    with open(snapshot_path, "rb") as img_file:
                        img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
                    snapshots.append({
                        "violation": label,
                        "frame": frame_count,
                        "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
                        "snapshot_base64": f"data:image/jpeg;base64,{img_base64}",
                    })
                    snapshot_taken[label] = True
            frame_count += 1

        if not violations:
            logger.info("No violations detected")
            return _empty_video_result()

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "salesforce_record_id": report_id,
            "violation_details_url": final_pdf_url,
        }
    except Exception as e:
        logger.error(f"Error processing video: {e}")
        return _empty_video_result()
    finally:
        # Always release the capture and delete the temp file, including on
        # the error path; release first so the file is no longer held open.
        if video is not None:
            video.release()
        if video_path and os.path.exists(video_path):
            try:
                os.remove(video_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temp video {video_path}: {cleanup_error}")
# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: analyse the uploaded video and format the five outputs.

    Args:
        video_file: path of the uploaded video; a falsy value means nothing
            was uploaded.

    Returns:
        A 5-tuple of strings: violations table (Markdown), score text,
        snapshots (Markdown), Salesforce record id text, and the violation
        details URL. On error the first element carries the error message.
    """
    if not video_file:
        return "No file uploaded.", "", "No file uploaded.", "", ""

    try:
        with open(video_file, "rb") as handle:
            raw_bytes = handle.read()
        result = process_video(raw_bytes)

        # Markdown table with one row per recorded violation.
        detected = result["violations"]
        if detected:
            names = CONFIG["DISPLAY_NAMES"]
            table_lines = []
            for item in detected:
                shown = names.get(item["violation"], item["violation"])
                table_lines.append(
                    f"| {shown:<16} | {item['timestamp']:.2f}s | {item['confidence']:.2f} | {item['bounding_box']} | {result['violation_details_url']} |"
                )
            violation_table = (
                "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |\n"
                + "|------------------|-----------|------------|--------------------------|-------------------------|\n"
                + "\n".join(table_lines)
            )
        else:
            violation_table = "No violations detected."

        # Snapshots are embedded inline as base64 data-URI images.
        captured = result["snapshots"]
        if captured:
            snapshot_lines = [
                f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
                for s in captured
            ]
            snapshots_text = "\n".join(snapshot_lines)
        else:
            snapshots_text = "No snapshots captured."

        score_text = f"Safety Score: {result['score']}%"
        record_text = f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}"
        details_text = result["violation_details_url"] or "N/A"
        return violation_table, score_text, snapshots_text, record_text, details_text
    except Exception as exc:
        logger.error(f"Error in Gradio interface: {exc}")
        return f"Error: {str(exc)}", "", "Error in processing.", "", ""
# Gradio UI: a single video input wired to gradio_interface, whose five return
# values map onto the five output components in order.
interface = gr.Interface(
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (Missing Helmet, Missing Harness, Unsafe Posture). Non-violations are ignored.",
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL"),
    ],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    logger.info("Launching Safety Analyzer App...")
    interface.launch()