AI_Safety_Demo7 / app.py
PrashanthB461's picture
Update app.py
937ffb2 verified
raw
history blame
12.7 kB
import os
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
import logging
# ==========================
# Configuration
# ==========================
# Runtime settings for the analyzer.
#
# Salesforce credentials are read from the environment (SF_USERNAME,
# SF_PASSWORD, SF_SECURITY_TOKEN, SF_DOMAIN) so secrets are never committed to
# source control. If a variable is unset its value is None and the Salesforce
# login will fail; that failure is caught and logged by the reporting code.
# NOTE(review): credentials used to be hard-coded here. If that version was ever
# published, rotate that password and security token.
CONFIG = {
    # Custom safety checkpoint; load_model() falls back to FALLBACK_MODEL when
    # this file does not exist.
    "MODEL_PATH": os.getenv("SAFETY_MODEL_PATH", "models/yolov8_safety.pt"),
    "FALLBACK_MODEL": "yolov8n.pt",
    # Where PDFs and snapshot JPEGs are written.
    "OUTPUT_DIR": "static/output",
    # Detector class id -> violation label. Unknown ids are reported as "class_<id>".
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
    },
    "SF_CREDENTIALS": {
        "username": os.getenv("SF_USERNAME"),
        "password": os.getenv("SF_PASSWORD"),
        "security_token": os.getenv("SF_SECURITY_TOKEN"),
        "domain": os.getenv("SF_DOMAIN", "login"),
    },
    # Prefix joined with output file names to build the links shown to users.
    # NOTE(review): files written at runtime to OUTPUT_DIR may not be reachable
    # through a repository "resolve/main" URL. Confirm the hosting setup.
    "PUBLIC_URL_BASE": os.getenv(
        "PUBLIC_URL_BASE",
        "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
    ),
    # Run the detector on every FRAME_SKIP-th frame only.
    "FRAME_SKIP": 5,
    # Maximum number of frames passed to the detector per video.
    "MAX_FRAMES": 100,
    # Wall-clock budget for the detection loop, in seconds.
    "MAX_PROCESSING_TIME": 30,
}

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Ensure output directory exists
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# ==========================
# Device Setup
# ==========================
# Prefer the GPU when PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# ==========================
# Model Loading
# ==========================
def load_model():
    """Load the YOLO detector and move it to ``device``.

    Uses CONFIG["MODEL_PATH"] when that file exists, otherwise
    CONFIG["FALLBACK_MODEL"].

    Returns:
        The loaded ``ultralytics.YOLO`` model.

    Raises:
        Exception: whatever YOLO raised while loading. It is logged with a
        traceback and then re-raised.
    """
    custom_path = CONFIG["MODEL_PATH"]
    if os.path.isfile(custom_path):
        model_path = custom_path
    else:
        model_path = CONFIG["FALLBACK_MODEL"]
        # NOTE(review): the fallback is presumably a generic pretrained
        # checkpoint. Its class ids are not known to line up with
        # VIOLATION_LABELS, so detections would be reported under the wrong
        # violation names. Make that visible in the logs.
        logger.warning(
            "Safety model not found at %s; falling back to %s. Its class ids may "
            "not match VIOLATION_LABELS.",
            custom_path,
            model_path,
        )
    try:
        model = YOLO(model_path).to(device)
    except Exception as e:
        logger.exception("Failed to load model: %s", e)
        raise
    logger.info("Model loaded: %s", model_path)
    return model


# Loaded once at import time and shared by every request.
model = load_model()
# ==========================
# Salesforce Integration
# ==========================
def connect_to_salesforce():
    """Open a Salesforce session with the credentials in CONFIG["SF_CREDENTIALS"].

    Returns:
        The connected ``simple_salesforce.Salesforce`` client.

    Raises:
        Exception: whatever the Salesforce constructor (or the config lookup)
        raised. It is logged first and then re-raised unchanged.
    """
    try:
        client = Salesforce(**CONFIG["SF_CREDENTIALS"])
    except Exception as err:
        logger.error(f"Salesforce connection failed: {err}")
        raise
    else:
        logger.info("Connected to Salesforce")
        return client
def generate_violation_pdf(violations, score):
    """Render a PDF summary of the detected violations and save it locally.

    The PDF is built in memory, then written to CONFIG["OUTPUT_DIR"] as
    ``violations_<unix-seconds>.pdf``.

    Args:
        violations: violation dicts; the ``violation``, ``timestamp`` and
            ``confidence`` keys are printed.
        score: compliance score, printed as a percentage.

    Returns:
        ``(pdf_path, public_url, pdf_buffer)``. ``pdf_buffer`` is a ``BytesIO``
        rewound to position 0 that holds the PDF bytes. On any failure the
        error is logged and ``("", "", None)`` is returned.
    """
    line_height = 0.3 * inch
    bottom_margin = 1 * inch
    try:
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        c.setFont("Helvetica", 12)
        c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 10)
        y_position = 9.5 * inch

        def write_line(text):
            # Break the page *before* drawing, once the cursor has passed the
            # bottom margin. A new page is then only opened when it will get
            # content, which avoids a trailing blank page.
            nonlocal y_position
            if y_position < bottom_margin:
                c.showPage()
                c.setFont("Helvetica", 10)
                y_position = 10 * inch
            c.drawString(1 * inch, y_position, text)
            y_position -= line_height

        report_data = {
            "Compliance Score": f"{score}%",
            "Violations Found": len(violations),
            "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        for key, value in report_data.items():
            write_line(f"{key}: {value}")
        y_position -= line_height  # blank gap before the details section
        write_line("Violation Details:")
        if not violations:
            write_line("No violations detected.")
        for v in violations:
            write_line(f"{v['violation']} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})")

        c.showPage()
        c.save()
        pdf_file.seek(0)
        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.exception(f"Error generating PDF: {e}")
        return "", "", None
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Upload an in-memory PDF to Salesforce as a ContentVersion.

    Args:
        sf: connected ``simple_salesforce.Salesforce`` client.
        pdf_file: object with ``getvalue()`` (a ``BytesIO``) holding the PDF,
            or a falsy value when there is nothing to upload.
        report_id: Salesforce record id, sent as ``FirstPublishLocationId``.

    Returns:
        The version download URL on ``sf.sf_instance``. Returns "" when no file
        was given, when the new version cannot be read back, or when any call
        fails. Errors are logged, never raised.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""
        # One timestamp, so Title and PathOnClient always share the same suffix.
        # Two separate time.time() calls can straddle a second boundary.
        stamp = int(time.time())
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode("utf-8")
        content_version = sf.ContentVersion.create({
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        })
        version_id = content_version["id"]
        # Read the new version back to confirm it exists. version_id comes from
        # Salesforce's own create response, not from user input.
        result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{version_id}'")
        if not result["records"]:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.exception(f"Error uploading PDF to Salesforce: {e}")
        return ""
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Safety_Video_Report__c record and attach the PDF report to it.

    Args:
        violations: violation dicts (``violation``, ``timestamp``, ``confidence``).
        score: value stored in Compliance_Score__c.
        pdf_path: local path of the generated PDF, or "" if there is none. Its
            basename is joined to CONFIG["PUBLIC_URL_BASE"] for the URL that is
            stored at creation time.
        pdf_file: in-memory PDF to upload as a ContentVersion, or None.

    Returns:
        ``(record_id, pdf_url)``. ``(None, "")`` means connecting or creating the
        record failed. If the record was created but attaching the PDF fails,
        the record id is still returned, together with the URL stored at
        creation time.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{v['violation']} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record = sf.Safety_Video_Report__c.create({
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        })
        record_id = record["id"]
        logger.info(f"Salesforce record created: {record_id}")
    except Exception as e:
        logger.exception(f"Salesforce record creation failed: {e}")
        return None, ""

    if not pdf_file:
        return record_id, pdf_url

    # The record already exists at this point. A failure while attaching the
    # PDF must not hide its id from the caller.
    try:
        uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
        if uploaded_url:
            sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
            pdf_url = uploaded_url
    except Exception as e:
        logger.exception(f"Failed to attach PDF to Salesforce record {record_id}: {e}")
    return record_id, pdf_url
# ==========================
# Safety Score Calculation
# ==========================
def calculate_safety_score(violations):
    """Return a 0-100 compliance score for a list of detected violations.

    Starts from 100 and subtracts a fixed penalty for every entry (each
    detection counts separately, even for the same label). Labels without a
    penalty cost nothing. The result is clamped at 0.
    """
    penalty_by_label = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 25,
    }
    total_penalty = 0
    for entry in violations:
        total_penalty += penalty_by_label.get(entry["violation"], 0)
    remaining = 100 - total_penalty
    return remaining if remaining > 0 else 0
# ==========================
# Video Processing
# ==========================
def _discard_file(path):
    """Delete ``path`` if it exists. Removal failures are logged, not raised."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"Could not remove temporary file {path}: {e}")


def _scan_video(video_path, run_id):
    """Run the detector over sampled frames of ``video_path``.

    Every CONFIG["FRAME_SKIP"]-th frame goes to ``model``. Scanning stops at end
    of stream, after CONFIG["MAX_FRAMES"] processed frames, or once
    CONFIG["MAX_PROCESSING_TIME"] seconds have elapsed. Each detected box
    becomes a violation entry. The full frame is saved as a JPEG snapshot once
    per (frame, label) pair.

    Args:
        video_path: file readable by ``cv2.VideoCapture``.
        run_id: token embedded in snapshot names so concurrent runs do not
            overwrite each other's images.

    Returns:
        ``(violations, snapshots)`` lists.

    Raises:
        ValueError: if OpenCV cannot open the file.
    """
    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise ValueError("Could not open video file")

        # Some containers report 0 (or NaN) FPS. Dividing by it would raise,
        # so fall back to the decoder-reported position in that case.
        fps = video.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            logger.warning("Video FPS unavailable; using decoder position for timestamps")

        violations, snapshots = [], []
        saved_snapshots = set()
        frame_count, processed_frames = 0, 0
        start_time = time.time()
        while processed_frames < CONFIG["MAX_FRAMES"]:
            ret, frame = video.read()
            if not ret:
                break
            if frame_count % CONFIG["FRAME_SKIP"] != 0:
                frame_count += 1
                continue

            if fps > 0:
                timestamp = frame_count / fps
            else:
                # Approximate: backend-reported position after this read.
                timestamp = video.get(cv2.CAP_PROP_POS_MSEC) / 1000.0

            results = model(frame, device=device)
            for result in results:
                for box in result.boxes:
                    cls, conf = int(box.cls), float(box.conf)
                    xywh = box.xywh.cpu().numpy()[0]
                    label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
                    violations.append({
                        "frame": frame_count,
                        "violation": label,
                        "confidence": round(conf, 2),
                        # float() so the stored values are plain Python floats,
                        # not numpy scalars.
                        "bounding_box": [round(float(x), 2) for x in xywh],
                        "timestamp": timestamp,
                    })

                    # Several boxes with the same label in one frame would all
                    # write the same image; keep a single snapshot per pair.
                    key = (frame_count, label)
                    if key in saved_snapshots:
                        continue
                    saved_snapshots.add(key)
                    snapshot_name = f"snapshot_{run_id}_{frame_count}_{label}.jpg"
                    snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_name)
                    if not cv2.imwrite(snapshot_path, frame):
                        logger.warning(f"Failed to write snapshot {snapshot_path}")
                        continue
                    snapshots.append({
                        "violation": label,
                        "frame": frame_count,
                        "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_name}",
                    })

            frame_count += 1
            processed_frames += 1
            if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
                logger.warning("Processing time limit exceeded")
                break
        return violations, snapshots
    finally:
        video.release()


def process_video(video_data):
    """Detect safety violations in an uploaded video and publish a report.

    The bytes are written to a uniquely named temporary file in
    CONFIG["OUTPUT_DIR"] and scanned. The temporary file is removed whether or
    not scanning succeeds. The results are then scored, rendered to a PDF and
    pushed to Salesforce.

    Args:
        video_data: raw bytes of the uploaded video file.

    Returns:
        A dict with keys ``violations``, ``snapshots``, ``score``,
        ``salesforce_record_id`` and ``violation_details_url``. On any error
        the lists are empty, ``score`` is 0, the record id is None and the
        URL is "".
    """
    import uuid

    # Unique per request. Second-resolution timestamps collide when two
    # uploads arrive in the same second, and the runs overwrite each other.
    run_id = uuid.uuid4().hex
    video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{run_id}.mp4")
    try:
        try:
            with open(video_path, "wb") as f:
                f.write(video_data)
            logger.info(f"Video saved: {video_path}")
            violations, snapshots = _scan_video(video_path, run_id)
        finally:
            _discard_file(video_path)

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "salesforce_record_id": report_id,
            "violation_details_url": final_pdf_url,
        }
    except Exception as e:
        logger.exception(f"Error processing video: {e}")
        return {
            "violations": [],
            "snapshots": [],
            "score": 0,
            "salesforce_record_id": None,
            "violation_details_url": "",
        }
# ==========================
# Gradio Interface
# ==========================
def _display_violation_name(label):
    """Return the human-readable form of a violation label (unknown labels unchanged)."""
    names = {
        "no_helmet": "no helmet",
        "no_harness": "no harness",
        "unsafe_posture": "unsafe posture",
        "unsafe_zone": "unsafe zone",
    }
    return names.get(label, label)


def _format_violation_table(violations, details_url):
    """Render violations as a Markdown table, or a placeholder when there are none."""
    if not violations:
        return "No violations detected."
    # A Markdown link with an empty target renders as a dead link, so use
    # plain text when no report URL is available.
    details = f"[Details]({details_url})" if details_url else "N/A"
    lines = [
        "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |",
        "|---------------|-----------|------------|--------------------------|-------------------|",
    ]
    for v in violations:
        box = ", ".join(f"{coord:.2f}" for coord in v["bounding_box"])
        lines.append(
            f"| {_display_violation_name(v['violation']):<13} | {v['timestamp']:.2f}s "
            f"| {v['confidence']:.2f} | [{box}] | {details} |"
        )
    return "\n".join(lines)


def _format_snapshots(snapshots):
    """Render snapshots as a Markdown bullet list of clickable thumbnails."""
    if not snapshots:
        return "No snapshots captured."
    # Use the same display names as the table. Stripping "no_" would caption
    # a missing-helmet frame as "helmet".
    return "\n".join(
        f"- Snapshot for {_display_violation_name(s['violation'])} at frame {s['frame']}: "
        f"[![Snapshot]({s['snapshot_url']})]({s['snapshot_url']})"
        for s in snapshots
    )


def gradio_interface(video_file):
    """Gradio callback: analyze the uploaded video and format the five outputs.

    Args:
        video_file: path of the uploaded video file, or a falsy value when
            nothing was uploaded.

    Returns:
        A 5-tuple: violations Markdown, score text, snapshots Markdown,
        Salesforce record text, and report URL. Errors are logged and
        reported in the first and third fields.
    """
    if not video_file:
        return "No file uploaded.", "", "No file uploaded.", "", ""
    try:
        with open(video_file, "rb") as f:
            video_data = f.read()
        result = process_video(video_data)
        return (
            _format_violation_table(result["violations"], result["violation_details_url"]),
            f"Safety Score: {result['score']}%",
            _format_snapshots(result["snapshots"]),
            f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
            result["violation_details_url"] or 'N/A',
        )
    except Exception as e:
        logger.exception(f"Error in Gradio interface: {e}")
        return f"Error: {str(e)}", "", "Error in processing.", "", ""
# Web UI: one video input mapped to the five values returned by gradio_interface.
# Components are created in the same order as they appear in the Interface call.
_input_component = gr.Video(label="Upload Site Video")
_output_components = [
    gr.Markdown(label="Detected Safety Violations"),
    gr.Textbox(label="Compliance Score"),
    gr.Markdown(label="Snapshots"),
    gr.Textbox(label="Salesforce Record ID"),
    gr.Textbox(label="URL"),
]
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_input_component,
    outputs=_output_components,
    title="Worksite Safety Violation Analyzer",
    description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture).",
)

if __name__ == "__main__":
    logger.info("Launching Safety Analyzer App...")
    interface.launch()