# AI_Safety_Demo7 / app.py
# PrashanthB461's picture
# Update app.py
# a0709a7 verified
# raw
# history blame
# 17.4 kB
import os
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
import logging
from retrying import retry
# ==========================
# OPTIMIZED CONFIGURATION
# ==========================
CONFIG = {
    "MODEL_PATH": "yolov8_safety.pt",
    "FALLBACK_MODEL": "yolov8n.pt",
    "OUTPUT_DIR": "static/output",
    # Model class index -> internal violation key.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use",
    },
    # Box colours passed straight to cv2, so they are in OpenCV's BGR order.
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),            # Red
        "no_harness": (0, 165, 255),         # Orange
        "unsafe_posture": (0, 255, 0),       # Green
        "unsafe_zone": (255, 0, 0),          # Blue
        "improper_tool_use": (0, 255, 255),  # Yellow ((255, 255, 0) would be cyan in BGR)
    },
    # Human-readable names used in reports and on-frame labels.
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet",
        "no_harness": "No Harness",
        "unsafe_posture": "Unsafe Posture",
        "unsafe_zone": "Unsafe Zone",
        "improper_tool_use": "Improper Tool Use",
    },
    # Salesforce login. Secrets must not live in source control, so they are read
    # from the environment (e.g. Space secrets). Any credentials previously
    # committed here should be considered leaked and rotated.
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", ""),
        "password": os.environ.get("SF_PASSWORD", ""),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", ""),
        "domain": os.environ.get("SF_DOMAIN", "login"),
    },
    # Prefix for public links to generated files.
    # NOTE(review): the default points at the AI_Safety_Demo2 Space; confirm it is
    # the Space that actually serves these files.
    "PUBLIC_URL_BASE": os.environ.get(
        "PUBLIC_URL_BASE",
        "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    ),
    "FRAME_SKIP": 3,             # run the model on every Nth frame only
    "MAX_PROCESSING_TIME": 60,   # wall-clock seconds before processing stops early
    # Minimum model confidence per violation type.
    "CONFIDENCE_THRESHOLD": {
        "no_helmet": 0.4,
        "no_harness": 0.3,
        "unsafe_posture": 0.25,
        "unsafe_zone": 0.3,
        "improper_tool_use": 0.35,
    },
    "IOU_THRESHOLD": 0.4,        # minimum box overlap to treat detections as the same worker
    "MIN_VIOLATION_FRAMES": 3,
    "MIN_VIOLATION_DURATION": 1.5,  # seconds between first and last detection
}
# Initialize logging for the whole app.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Reports, snapshots and temporary uploads are all written to this directory.
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info("Using device: %s", device)
def load_model():
    """Load the YOLO model onto `device`.

    Uses the custom weights at CONFIG["MODEL_PATH"] when that file exists,
    otherwise falls back to CONFIG["FALLBACK_MODEL"]. Load failures are logged
    and re-raised.
    """
    try:
        use_custom = os.path.exists(CONFIG["MODEL_PATH"])
        weights = CONFIG["MODEL_PATH"] if use_custom else CONFIG["FALLBACK_MODEL"]
        loaded = YOLO(weights).to(device)
        if use_custom:
            logger.info("Loaded custom safety model")
        else:
            logger.warning("Using fallback model - recommend training yolov8_safety.pt")
        return loaded
    except Exception as exc:
        logger.error(f"Model loading failed: {str(exc)}")
        raise


# Loaded once at import time and shared by every request.
model = load_model()
def draw_detections(frame, detections):
    """Draw bounding boxes with labels and confidence scores onto ``frame``.

    Args:
        frame: image array that cv2 draws on in place; the same object is returned.
        detections: iterable of dicts with keys "violation", "confidence" and
            "bounding_box". The box is (center_x, center_y, width, height) in
            pixels, as produced from YOLO ``xywh`` boxes in process_video.

    Returns:
        The annotated frame.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    for det in detections:
        label = det["violation"]
        x, y, w, h = (float(v) for v in det["bounding_box"][:4])
        color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))

        # Convert center/size to corners from the float values, rounding once,
        # rather than truncating the center first and then truncating again.
        x1, y1 = int(round(x - w / 2)), int(round(y - h / 2))
        x2, y2 = int(round(x + w / 2)), int(round(y + h / 2))
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        label_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {det['confidence']:.2f}"
        (text_width, text_height), _ = cv2.getTextSize(label_text, font, 0.5, 1)
        label_height = text_height + 10

        # The label normally sits just above the box. For boxes at the top of the
        # frame that would place it at negative y where it is clipped away, so it
        # is moved inside the box instead. Likewise keep it from starting off-frame
        # on the left.
        top = max(y1, 0)
        label_bottom = top if top >= label_height else top + label_height
        label_left = max(x1, 0)

        cv2.rectangle(
            frame,
            (label_left, label_bottom - label_height),
            (label_left + text_width, label_bottom),
            color,
            -1,
        )
        cv2.putText(frame, label_text, (label_left, label_bottom - 5), font, 0.5, (255, 255, 255), 1)
    return frame
def calculate_iou(box1, box2):
    """Calculate Intersection over Union for two bounding boxes.

    Both boxes are (center_x, center_y, width, height); only the first four
    items are read.

    Returns:
        IoU in [0, 1]. Returns 0.0 when the boxes do not overlap or when the
        union has no area (e.g. two zero-size boxes), instead of dividing by zero.
    """
    def _corners(box):
        # (cx, cy, w, h) -> (left, top, right, bottom)
        cx, cy, w, h = box[0], box[1], box[2], box[3]
        return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2

    a_left, a_top, a_right, a_bottom = _corners(box1)
    b_left, b_top, b_right, b_bottom = _corners(box2)

    x_left = max(a_left, b_left)
    y_top = max(a_top, b_top)
    x_right = min(a_right, b_right)
    y_bottom = min(a_bottom, b_bottom)
    if x_right <= x_left or y_bottom <= y_top:
        return 0.0

    intersection_area = (x_right - x_left) * (y_bottom - y_top)
    box1_area = (a_right - a_left) * (a_bottom - a_top)
    box2_area = (b_right - b_left) * (b_bottom - b_top)
    union_area = box1_area + box2_area - intersection_area
    if union_area <= 0:
        return 0.0
    return intersection_area / float(union_area)
def generate_violation_pdf(violations, score):
    """Render the violation report as a PDF, save it to OUTPUT_DIR and return its locations.

    Args:
        violations: list of violation dicts; each is read for "violation",
            "timestamp", "confidence" and "worker_id".
        score: compliance score, printed as a percentage.

    Returns:
        ``(pdf_path, public_url, pdf_buffer)`` on success, where ``pdf_buffer`` is
        the in-memory BytesIO holding the PDF. Returns ``("", "", None)`` if any
        step fails; the error is logged, not raised.
    """
    try:
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)

        buffer = BytesIO()
        pdf = canvas.Canvas(buffer, pagesize=letter)
        margin = 1 * inch

        # Title.
        pdf.setFont("Helvetica-Bold", 14)
        pdf.drawString(margin, 10.5 * inch, "Worksite Safety Violation Report")

        # Summary lines.
        pdf.setFont("Helvetica", 12)
        cursor = 10 * inch
        summary = (
            ("Compliance Score", f"{score}%"),
            ("Total Violations", len(violations)),
            ("Report Date", time.strftime("%Y-%m-%d %H:%M:%S")),
        )
        for caption, text in summary:
            pdf.drawString(margin, cursor, f"{caption}: {text}")
            cursor -= 0.4 * inch

        # Separator rule and section heading.
        cursor -= 0.2 * inch
        pdf.line(margin, cursor, 7.5 * inch, cursor)
        cursor -= 0.3 * inch
        pdf.setFont("Helvetica-Bold", 12)
        pdf.drawString(margin, cursor, "Violation Details:")
        cursor -= 0.3 * inch

        # One line per violation, paginating when the bottom margin is reached.
        pdf.setFont("Helvetica", 10)
        if violations:
            names = CONFIG["DISPLAY_NAMES"]
            for item in violations:
                kind = item["violation"]
                line = (
                    f"{names.get(kind, kind)} "
                    f"at {item['timestamp']:.2f}s (Confidence: {item['confidence']:.2f}, "
                    f"Worker: {item['worker_id']})"
                )
                pdf.drawString(margin, cursor, line)
                cursor -= 0.25 * inch
                if cursor < margin:
                    # A new page resets the font, so set the body font again.
                    pdf.showPage()
                    cursor = 10 * inch
                    pdf.setFont("Helvetica", 10)
        else:
            pdf.drawString(margin, cursor, "No violations detected.")

        pdf.save()
        buffer.seek(0)
        with open(pdf_path, "wb") as handle:
            handle.write(buffer.getvalue())

        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"Generated PDF report: {public_url}")
        return pdf_path, public_url, buffer
    except Exception as err:
        logger.error(f"PDF generation failed: {str(err)}")
        return "", "", None
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Open an authenticated Salesforce session from CONFIG["SF_CREDENTIALS"].

    Wrapped by ``retry``: up to 3 attempts, 2000 ms apart. Each failed attempt
    is logged and re-raised so the decorator can retry. Once attempts are
    exhausted, the failure propagates to the caller.
    """
    try:
        session = Salesforce(**CONFIG["SF_CREDENTIALS"])
    except Exception as exc:
        logger.error(f"Salesforce connection failed: {str(exc)}")
        raise
    logger.info("Connected to Salesforce")
    return session
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach the PDF in ``pdf_file`` to record ``report_id`` as a Salesforce ContentVersion.

    Args:
        sf: connected Salesforce session.
        pdf_file: buffer exposing ``getvalue()`` with the raw PDF bytes.
        report_id: id of the record the file is first published to.

    Returns:
        A download URL for the new ContentVersion, or "" if the upload fails
        (the error is logged).
    """
    try:
        encoded = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        payload = {
            "Title": f"Safety_Report_{int(time.time())}",
            "PathOnClient": "safety_report.pdf",
            "VersionData": encoded,
            "FirstPublishLocationId": report_id,
        }
        created = sf.ContentVersion.create(payload)
        version_id = created['id']
        return f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
    except Exception as exc:
        logger.error(f"PDF upload failed: {str(exc)}")
        return ""
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Salesforce record for this analysis and attach the PDF report.

    Tries to create a ``Safety_Video_Report__c`` record. If that fails, it falls
    back to an ``Account`` named after the current timestamp. When a PDF buffer
    is given, the PDF is uploaded and its URL stored on the record that was
    actually created (``PDF_Report_URL__c`` or the Account's ``Description``).

    Returns:
        ``(record_id, url)``. ``url`` is the Salesforce download URL when the upload
        succeeded, else the public URL of the local PDF, else "" when no PDF was
        produced. Returns ``(None, "")`` if the connection or record creation fails.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"- {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
            f"at {v['timestamp']:.2f}s (Worker {v['worker_id']}, Confidence: {v['confidence']:.2f})"
            for v in violations
        ) or "No violations detected"
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "New"
        }
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            record_id = record["id"]
            # Remember which object/field the URL belongs on, so the later update
            # targets the sObject that was actually created.
            record_object = sf.Safety_Video_Report__c
            url_field = "PDF_Report_URL__c"
            logger.info(f"Created Salesforce record: {record_id}")
        except Exception as e:
            logger.error(f"Failed to create Safety Report: {str(e)}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            record_id = record["id"]
            record_object = sf.Account
            url_field = "Description"
            logger.warning(f"Created fallback Account record: {record_id}")

        pdf_url = ""
        if pdf_file:
            pdf_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
        if pdf_url:
            try:
                record_object.update(record_id, {url_field: pdf_url})
            except Exception as e:
                # Best effort: the record and the uploaded file already exist, so a
                # failure to store the link must not discard them.
                logger.warning(f"Failed to store PDF URL on record {record_id}: {str(e)}")

        if pdf_url:
            return record_id, pdf_url
        if pdf_path:
            return record_id, f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}"
        # No PDF was generated; don't return a bare directory URL.
        return record_id, ""
    except Exception as e:
        logger.error(f"Salesforce integration failed: {str(e)}")
        return None, ""
def calculate_safety_score(violations):
    """Return a 0-100 compliance score.

    Each distinct (worker_id, violation) pair costs its type's penalty once;
    unknown violation types cost nothing. The score never drops below 0.
    """
    penalty_by_type = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25
    }
    seen_pairs = set()
    for item in violations:
        seen_pairs.add((item["worker_id"], item["violation"]))

    deduction = 0
    for _worker, kind in seen_pairs:
        deduction += penalty_by_type.get(kind, 0)

    return max(100 - deduction, 0)
def _assign_worker(workers, bbox, current_time):
    """Return the id of the tracked worker whose box best overlaps ``bbox``, registering a new worker if none exceeds IOU_THRESHOLD.

    Mutates ``workers``: a matched worker's "bbox"/"last_seen" are refreshed,
    otherwise a new worker entry is appended.
    """
    matched_worker = None
    max_iou = 0
    for worker in workers:
        iou = calculate_iou(bbox, worker["bbox"])
        if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
            max_iou = iou
            matched_worker = worker
    if matched_worker is not None:
        matched_worker["bbox"] = bbox
        matched_worker["last_seen"] = current_time
        return matched_worker["id"]
    worker_id = len(workers) + 1
    workers.append({
        "id": worker_id,
        "bbox": bbox,
        "first_seen": current_time,
        "last_seen": current_time
    })
    return worker_id


def _collect_detections(cap, fps):
    """Run the model on every FRAME_SKIP-th frame of ``cap`` and return {violation label: [detection dicts]}.

    Stops at end of video or after MAX_PROCESSING_TIME wall-clock seconds.
    """
    workers = []
    violation_history = {k: [] for k in CONFIG["VIOLATION_LABELS"].values()}
    frame_count = 0
    start_time = time.time()
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if frame_count % CONFIG["FRAME_SKIP"] != 0:
            frame_count += 1
            continue
        if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
            logger.warning("Processing timeout reached")
            break
        current_time = frame_count / fps
        results = model(frame, device=device, verbose=False)
        for result in results:
            for box in result.boxes:
                cls = int(box.cls)
                conf = float(box.conf)
                label = CONFIG["VIOLATION_LABELS"].get(cls)
                if not label or conf < CONFIG["CONFIDENCE_THRESHOLD"].get(label, 0.3):
                    continue
                bbox = box.xywh.cpu().numpy()[0].tolist()
                worker_id = _assign_worker(workers, bbox, current_time)
                violation_history[label].append({
                    "frame": frame_count,
                    "violation": label,
                    "confidence": round(conf, 2),
                    "bounding_box": bbox,
                    "timestamp": current_time,
                    "worker_id": worker_id
                })
        frame_count += 1
    return violation_history


def _confirm_violations(detections):
    """Group one label's detections by worker and return the highest-confidence detection per persistent violation.

    A worker's violation is kept only if it was detected in at least
    MIN_VIOLATION_FRAMES sampled frames spanning at least MIN_VIOLATION_DURATION seconds.
    """
    worker_groups = {}
    for det in detections:
        worker_groups.setdefault(det["worker_id"], []).append(det)
    confirmed = []
    for worker_dets in worker_groups.values():
        if len(worker_dets) < CONFIG["MIN_VIOLATION_FRAMES"]:
            continue
        duration = worker_dets[-1]["timestamp"] - worker_dets[0]["timestamp"]
        if duration >= CONFIG["MIN_VIOLATION_DURATION"]:
            confirmed.append(max(worker_dets, key=lambda d: d["confidence"]))
    return confirmed


def _save_snapshot(cap, violation_type, det):
    """Re-read ``det``'s frame, draw its box, save a JPEG to OUTPUT_DIR and return snapshot metadata (None if the frame can't be read)."""
    cap.set(cv2.CAP_PROP_POS_FRAMES, det["frame"])
    ret, snapshot_frame = cap.read()
    if not ret:
        return None
    snapshot_frame = draw_detections(snapshot_frame, [det])
    filename = f"{violation_type}_{det['frame']}.jpg"
    path = os.path.join(CONFIG["OUTPUT_DIR"], filename)
    cv2.imwrite(path, snapshot_frame)
    return {
        "violation": violation_type,
        "frame": det["frame"],
        "path": path,
        "url": f"{CONFIG['PUBLIC_URL_BASE']}{filename}"
    }


def process_video(video_data):
    """Analyse raw video bytes for safety violations and publish the results.

    Writes the bytes to a temporary file in OUTPUT_DIR, detects and confirms
    violations, saves one snapshot per violation type, scores the result,
    builds a PDF report and pushes it to Salesforce.

    Returns:
        dict with keys "violations", "snapshots", "score",
        "salesforce_record_id", "violation_details_url" and "message".
        "message" is "" on success. On failure it is "Error: ..." and the
        other fields hold empty placeholders (score 100).
    """
    try:
        temp_video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
        cap = None
        try:
            with open(temp_video_path, "wb") as f:
                f.write(video_data)
            cap = cv2.VideoCapture(temp_video_path)
            if not cap.isOpened():
                # Without this check an unreadable upload produces zero detections
                # and is reported as a clean 100% score.
                raise ValueError("Could not open the uploaded video")
            fps = cap.get(cv2.CAP_PROP_FPS) or 30
            violation_history = _collect_detections(cap, fps)

            violations = []
            snapshots = []
            for violation_type, detections in violation_history.items():
                snapshot_taken = False  # one evidence snapshot per violation type
                for best_det in _confirm_violations(detections):
                    violations.append(best_det)
                    if snapshot_taken:
                        continue
                    snapshot = _save_snapshot(cap, violation_type, best_det)
                    if snapshot is not None:
                        snapshots.append(snapshot)
                        snapshot_taken = True
        finally:
            # Always release the capture and delete the temp copy, even on failure.
            if cap is not None:
                cap.release()
            if os.path.exists(temp_video_path):
                try:
                    os.remove(temp_video_path)
                except OSError as cleanup_err:
                    logger.warning(f"Could not remove temporary video {temp_video_path}: {cleanup_err}")

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        record_id, sf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "salesforce_record_id": record_id,
            "violation_details_url": sf_url or pdf_url,
            "message": ""
        }
    except Exception as e:
        logger.error(f"Video processing failed: {str(e)}")
        return {
            "violations": [],
            "snapshots": [],
            "score": 100,
            "salesforce_record_id": None,
            "violation_details_url": "",
            "message": f"Error: {str(e)}"
        }
def gradio_interface(video_file):
    """Gradio callback: analyse the uploaded video and stream results to the five outputs.

    Yields 5-tuples in output order: (violations markdown, score text,
    snapshots markdown, Salesforce record text, report URL). A progress
    placeholder is yielded first. After that comes either the final results or
    an error message in the first slot.
    """
    try:
        if not video_file:
            yield "Error: please upload a video first.", "", "", "", ""
            return
        yield "Analyzing video...", "", "", "", ""
        with open(video_file, "rb") as f:
            result = process_video(f.read())
        if result.get("message"):
            # process_video reports failures through "message" and returns
            # placeholder values (score 100). Show the error instead of presenting
            # a failed run as a clean report.
            yield result["message"], "", "", "", ""
            return
        violation_table = (
            "| Violation Type | Timestamp | Confidence | Worker ID |\n"
            "|---------------------|-----------|------------|-----------|\n" +
            "\n".join(
                f"| {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation']):<19} | "
                f"{v['timestamp']:.2f} | "
                f"{v['confidence']:.2f} | "
                f"{v['worker_id']} |"
                for v in result["violations"]
            )
        ) if result["violations"] else "No violations detected"
        # NOTE(review): snapshot URLs are built from PUBLIC_URL_BASE; confirm that
        # files written at runtime are actually served there.
        snapshots_md = "\n".join(
            f"![{s['violation']} at frame {s['frame']}]({s['url']})"
            for s in result["snapshots"]
        ) if result["snapshots"] else "No snapshots"
        yield (
            violation_table,
            f"Safety Score: {result['score']}%",
            snapshots_md,
            f"Salesforce ID: {result['salesforce_record_id'] or 'None'}",
            result["violation_details_url"] or "None"
        )
    except Exception as e:
        logger.error(f"Interface error: {str(e)}")
        yield f"Error: {str(e)}", "", "", "", ""
# Gradio app: one video input; five outputs filled, in order, by gradio_interface's yields.
_video_input = gr.Video(label="Upload Site Video")
_result_outputs = [
    gr.Markdown(label="Violations Detected"),
    gr.Textbox(label="Compliance Score"),
    gr.Markdown(label="Evidence Snapshots"),
    gr.Textbox(label="Salesforce Record"),
    gr.Textbox(label="Report URL"),
]
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_video_input,
    outputs=_result_outputs,
    title="AI Safety Compliance Monitor",
    description="Detects 5 violation types: No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use",
)


if __name__ == "__main__":
    interface.launch(share=True)