# AI_Safety_Demo7 / app.py
# Hugging Face Space by PrashanthB461 — commit 25e32db ("Update app.py")
import base64
import logging
import os
import tempfile
import time
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import retrying
import torch
from reportlab.lib.pagesizes import letter
from reportlab.lib.units import inch
from reportlab.pdfgen import canvas
from simple_salesforce import Salesforce
from ultralytics import YOLO
# ==========================
# Configuration
# ==========================
CONFIG = {
    # Custom safety-detection weights; overridable via SAFETY_MODEL_PATH.
    "MODEL_PATH": os.getenv("SAFETY_MODEL_PATH", "models/yolov8_safety.pt"),
    # Loaded by load_model only when MODEL_PATH is not an existing file.
    "FALLBACK_MODEL": "yolov8n.pt",
    # Generated PDFs and snapshot images are written here.
    "OUTPUT_DIR": "static/output",
    # Model class index -> violation label used in reports and scoring.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
    },
    # Salesforce login, read from the environment (e.g. Space secrets).
    # Secrets must never be committed: the password and security token that
    # used to be hard-coded here were exposed in source control and should be
    # rotated in Salesforce.
    "SF_CREDENTIALS": {
        "username": os.getenv("SF_USERNAME", "prashanth1ai@safety.com"),
        "password": os.getenv("SF_PASSWORD", ""),
        "security_token": os.getenv("SF_SECURITY_TOKEN", ""),
        "domain": os.getenv("SF_DOMAIN", "login"),
    },
    # Prefix used to build public links to files in OUTPUT_DIR.
    # NOTE(review): this points at the AI_Safety_Demo1 Space while the file
    # header names AI_Safety_Demo7, and files written at runtime are not part
    # of the repo served under resolve/main — confirm these links resolve.
    "PUBLIC_URL_BASE": os.getenv(
        "PUBLIC_URL_BASE",
        "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
    ),
    # Run the detector on every FRAME_SKIP-th frame only.
    "FRAME_SKIP": 5,
    # Wall-clock budget (seconds) for the detection loop.
    "MAX_PROCESSING_TIME": 30,
}
# Logging: timestamped, INFO-level records for the whole app.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Create the output directory up front so later writes into it succeed.
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# ==========================
# Device Setup
# ==========================
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
# ==========================
# Model Loading
# ==========================
def load_model():
    """Load the YOLO detector and move it onto ``device``.

    Uses ``CONFIG["MODEL_PATH"]`` when that file exists, otherwise falls back
    to ``CONFIG["FALLBACK_MODEL"]`` and logs a warning, because the fallback's
    class indices are not known to match ``CONFIG["VIOLATION_LABELS"]``.

    Returns:
        The loaded ``YOLO`` model.

    Raises:
        Exception: whatever ``YOLO(...)`` or ``.to(device)`` raised, re-raised
        after being logged with its traceback.
    """
    custom_path = CONFIG["MODEL_PATH"]
    if os.path.isfile(custom_path):
        model_path = custom_path
    else:
        model_path = CONFIG["FALLBACK_MODEL"]
        # NOTE(review): the stock yolov8n.pt weights are presumably general
        # (COCO) object classes, so class 0/1/2/3 would be reported as
        # no_helmet/no_harness/... — make the fallback loud, not silent.
        logger.warning(
            f"Safety model not found at {custom_path}; falling back to {model_path}. "
            "Its class indices may not correspond to VIOLATION_LABELS, so "
            "reported violations may be meaningless."
        )
    try:
        model = YOLO(model_path).to(device)
    except Exception as e:
        logger.exception(f"Failed to load model: {e}")
        raise
    logger.info(f"Model loaded: {model_path}")
    return model


# Loaded once at import time and shared by every request.
model = load_model()
# ==========================
# Salesforce Integration
# ==========================
@retrying.retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Return an authenticated ``Salesforce`` session built from CONFIG.

    The ``retrying.retry`` decorator makes up to 3 attempts, 2000 ms apart.
    Every failed attempt is logged and re-raised so the decorator can retry.
    """
    try:
        session = Salesforce(**CONFIG["SF_CREDENTIALS"])
    except Exception as exc:
        logger.error(f"Salesforce connection failed: {exc}")
        raise
    else:
        logger.info("Connected to Salesforce")
        return session
def generate_violation_pdf(violations, score):
    """Render a PDF violation report and save it under ``CONFIG["OUTPUT_DIR"]``.

    Args:
        violations: dicts with ``violation``, ``timestamp`` (seconds, float)
            and ``confidence`` keys.
        score: compliance score shown in the report header.

    Returns:
        ``(pdf_path, public_url, pdf_file)``: the local file path, its URL
        under ``CONFIG["PUBLIC_URL_BASE"]``, and a BytesIO holding the PDF bytes
        rewound to position 0. On any error the error is logged and
        ``("", "", None)`` is returned.
    """
    left_margin = 1 * inch
    top_of_page = 10 * inch
    bottom_margin = 1 * inch
    line_height = 0.3 * inch
    try:
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)

        c.setFont("Helvetica", 12)
        c.drawString(left_margin, top_of_page, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 10)
        y_position = 9.5 * inch

        report_data = {
            "Compliance Score": f"{score}%",
            "Violations Found": len(violations),
            "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        for key, value in report_data.items():
            c.drawString(left_margin, y_position, f"{key}: {value}")
            y_position -= line_height

        y_position -= line_height  # blank line before the details section
        c.drawString(left_margin, y_position, "Violation Details:")
        y_position -= line_height

        for v in violations:
            # Break the page BEFORE drawing, so a new page is only started when
            # a line actually needs it. Breaking after drawing left a trailing
            # page holding nothing but the font change when the final line hit
            # the bottom margin.
            if y_position < bottom_margin:
                c.showPage()
                c.setFont("Helvetica", 10)  # font must be set again on the new page
                y_position = top_of_page
            text = f"{v['violation']} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
            c.drawString(left_margin, y_position, text)
            y_position -= line_height

        # save() finishes the current page itself; no explicit showPage() needed.
        c.save()
        pdf_file.seek(0)
        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())

        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.exception(f"Error generating PDF: {e}")
        return "", "", None
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Upload the PDF as a Salesforce ContentVersion linked to *report_id*.

    Args:
        sf: authenticated ``simple_salesforce.Salesforce`` session.
        pdf_file: BytesIO with the PDF bytes, or None.
        report_id: Id of the record the file is first published to
            (``FirstPublishLocationId``).

    Returns:
        The version's download URL on this Salesforce instance, or "" when
        there is no file or any step fails. Errors are logged, never raised.
    """
    if not pdf_file:
        logger.error("No PDF file provided for upload")
        return ""
    try:
        # One timestamp for both fields so the title and file name always agree
        # (two time.time() calls could straddle a second boundary).
        stamp = int(time.time())
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode("utf-8")
        content_version_data = {
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        }
        logger.info(f"Creating ContentVersion for report ID: {report_id}")
        content_version = sf.ContentVersion.create(content_version_data)
        version_id = content_version.get("id")
        if not version_id:
            logger.error("ContentVersion creation failed: No ID returned")
            return ""
        # The Id comes back from Salesforce itself, so interpolating it into
        # SOQL is not an injection risk; the query only confirms the version
        # is readable before a link to it is handed out.
        result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{version_id}'")
        if not result["records"]:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.exception(f"Error uploading PDF to Salesforce: {e}")
        return ""
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a ``Safety_Video_Report__c`` record and attach the PDF to it.

    Args:
        violations: dicts with ``violation``, ``timestamp`` and ``confidence``.
        score: compliance score, stored as a float.
        pdf_path: local path of the generated PDF, or "" if none was made.
        pdf_file: BytesIO with the PDF bytes, or None.

    Returns:
        ``(record_id, pdf_url)``. ``(None, "")`` when connecting or creating
        the record fails. Once the record exists its Id is always returned.
        ``pdf_url`` is the Salesforce file URL if the upload and the follow-up
        record update both succeed, otherwise the public URL built from
        ``pdf_path``.
    """
    # Phase 1: connect and create the record. Nothing exists yet, so any
    # failure here means there is no record to report.
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{v['violation']} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": float(score),
            "Violations_Found__c": len(violations),
            # NOTE(review): one line per detection; confirm this fits the
            # field's length limit in the org for long videos.
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        }
        logger.info(f"Creating Safety_Video_Report__c with data: {record_data}")
        record = sf.Safety_Video_Report__c.create(record_data)
        record_id = record.get("id")
    except Exception as e:
        logger.exception(f"Salesforce record creation failed: {e}")
        return None, ""
    if not record_id:
        logger.error("Safety_Video_Report__c creation failed: No ID returned")
        return None, ""
    logger.info(f"Salesforce record created: {record_id}")

    # Phase 2: attach the PDF. The record already exists, so a failure here
    # must not hide its Id (previously the outer except returned (None, "")).
    if pdf_file:
        uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
        if uploaded_url:
            logger.info(f"Updating record {record_id} with PDF URL: {uploaded_url}")
            try:
                sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
            except Exception as e:
                logger.exception(f"Failed to update record {record_id} with PDF URL: {e}")
            else:
                pdf_url = uploaded_url
    return record_id, pdf_url
# ==========================
# Safety Score Calculation
# ==========================
def calculate_safety_score(violations):
    """Return a compliance score in [0, 100] for a list of violation dicts.

    Each entry subtracts a fixed penalty looked up by its ``"violation"``
    label. Labels without a listed penalty cost nothing. The result is clamped
    so it never drops below 0.
    """
    penalty_for = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 25,
    }
    total_penalty = 0
    for item in violations:
        total_penalty += penalty_for.get(item["violation"], 0)
    remaining = 100 - total_penalty
    return remaining if remaining > 0 else 0
# ==========================
# Video Processing
# ==========================
def _video_fps(video, default_fps=30.0):
    """Return the capture's frame rate, or *default_fps* when it is unusable.

    OpenCV can report 0 when a container carries no rate metadata. Dividing
    by it used to raise ZeroDivisionError and abort the whole analysis.
    """
    fps = video.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 0:
        # Assumed rate: timestamps are only approximate for such videos.
        logger.warning(f"Video reports no usable FPS ({fps}); assuming {default_fps} fps for timestamps")
        return default_fps
    return fps


def _detect_violations(video, run_tag):
    """Run the detector over every FRAME_SKIP-th frame of an opened capture.

    Stops at end of stream or once MAX_PROCESSING_TIME seconds have elapsed.
    Snapshot images are named with *run_tag* so different runs never
    overwrite each other's files. Each (frame, label) pair is written and
    listed once, however many boxes share it.

    Returns:
        ``(violations, snapshots)``: one violation dict per detected box, and
        one snapshot dict per distinct (frame, label).
    """
    fps = _video_fps(video)
    violations, snapshots = [], []
    saved = set()  # (frame, label) pairs already written to disk
    frame_count = 0
    start_time = time.time()
    while True:
        ret, frame = video.read()
        if not ret:
            break
        if frame_count % CONFIG["FRAME_SKIP"] != 0:
            frame_count += 1
            continue
        results = model(frame, device=device)
        for result in results:
            for box in result.boxes:
                cls, conf = int(box.cls), float(box.conf)
                xywh = box.xywh.cpu().numpy()[0]
                label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
                violations.append({
                    "frame": frame_count,
                    "violation": label,
                    "confidence": round(conf, 2),
                    # float() gives plain Python numbers; numpy scalars would
                    # render as "np.float32(...)" in the UI table under NumPy 2.
                    "bounding_box": [round(float(x), 2) for x in xywh],
                    "timestamp": frame_count / fps,
                })
                if (frame_count, label) in saved:
                    continue
                saved.add((frame_count, label))
                snapshot_name = f"snapshot_{run_tag}_{frame_count}_{label}.jpg"
                cv2.imwrite(os.path.join(CONFIG["OUTPUT_DIR"], snapshot_name), frame)
                snapshots.append({
                    "violation": label,
                    "frame": frame_count,
                    "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_name}",
                })
        frame_count += 1
        if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
            logger.warning("Processing time limit exceeded")
            break
    return violations, snapshots


def process_video(video_data):
    """Analyse an uploaded video end to end.

    Writes *video_data* (raw bytes) to a unique temp file, runs detection,
    scores the result, builds the PDF and pushes the report to Salesforce.
    The capture is always released and the temp file always removed, even
    when a step fails.

    Returns:
        Dict with ``violations``, ``snapshots``, ``score``,
        ``salesforce_record_id`` and ``violation_details_url``. On any error
        the error is logged with its traceback and an empty result (score 0,
        no record) is returned.
    """
    video_path = None
    video = None
    try:
        # mkstemp gives a unique name, so concurrent/rapid uploads cannot
        # clobber each other (the old temp_<seconds>.mp4 name could collide).
        fd, video_path = tempfile.mkstemp(prefix="safety_", suffix=".mp4")
        with os.fdopen(fd, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")

        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file")

        run_tag = os.path.splitext(os.path.basename(video_path))[0]
        violations, snapshots = _detect_violations(video, run_tag)

        score = calculate_safety_score(violations)
        pdf_path, _public_pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "salesforce_record_id": report_id,
            "violation_details_url": final_pdf_url,
        }
    except Exception as e:
        logger.exception(f"Error processing video: {e}")
        return {
            "violations": [],
            "snapshots": [],
            "score": 0,
            "salesforce_record_id": None,
            "violation_details_url": "",
        }
    finally:
        if video is not None:
            video.release()
        if video_path is not None:
            try:
                os.remove(video_path)
            except OSError as e:
                logger.warning(f"Could not remove temp video {video_path}: {e}")
# ==========================
# Gradio Interface
# ==========================
def _strip_label_prefixes(label):
    """Shorten a label for display by removing every "no_" and "unsafe_"."""
    return label.replace("no_", "").replace("unsafe_", "")


def _format_violation_table(violations, details_url):
    """Render violations as a Markdown table ("no_helmet" keeps its full name)."""
    if not violations:
        return "No violations detected."
    lines = [
        "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |",
        "|---------------|-----------|------------|--------------------------|-------------------------|",
    ]
    for item in violations:
        label = item["violation"]
        shown = label if label == "no_helmet" else _strip_label_prefixes(label)
        lines.append(
            f"| {shown:<13} | {item['timestamp']:.2f}s | {item['confidence']:.2f} | {item['bounding_box']} | {details_url} |"
        )
    return "\n".join(lines)


def _format_snapshot_list(snapshots):
    """Render snapshots as a Markdown bullet list of self-titled image links."""
    if not snapshots:
        return "No snapshots captured."
    bullets = []
    for snap in snapshots:
        link = snap["snapshot_url"]
        bullets.append(
            f"- Snapshot for {_strip_label_prefixes(snap['violation'])} at frame {snap['frame']}: [{link}]({link})"
        )
    return "\n".join(bullets)


def gradio_interface(video_file):
    """Gradio callback: analyse the uploaded video file at path *video_file*.

    Returns a 5-tuple matching the interface outputs: violations table
    (Markdown), score text, snapshot list (Markdown), record-id text, and the
    PDF URL. Missing input and unexpected errors yield placeholder tuples.
    """
    if not video_file:
        return "No file uploaded.", "", "No file uploaded.", "", ""
    try:
        with open(video_file, "rb") as handle:
            payload = handle.read()
        outcome = process_video(payload)
        table_md = _format_violation_table(outcome["violations"], outcome["violation_details_url"])
        snapshots_md = _format_snapshot_list(outcome["snapshots"])
        record_text = outcome["salesforce_record_id"] or "N/A"
        return (
            table_md,
            f"Safety Score: {outcome['score']}%",
            snapshots_md,
            f"Salesforce Record ID: {record_text}",
            outcome["violation_details_url"] or "N/A",
        )
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}")
        return f"Error: {str(e)}", "", "Error in processing.", "", ""
# Gradio UI: one video upload feeds gradio_interface, whose 5-tuple return
# value maps positionally onto the five output components below.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),  # violations table
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),  # bullet list of snapshot links
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL")
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture)."
)
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    logger.info("Launching Safety Analyzer App...")
    interface.launch()