# AI_Safety_Demo7 / app.py
# Author: PrashanthB461
# Commit 6104e09 (verified): "Update app.py"
# Size: 20.9 kB
import os
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
import logging
from retrying import retry
# ==========================
# Enhanced Configuration
# ==========================
# Central settings shared by model loading, detection, reporting and the
# Salesforce integration.
CONFIG = {
    # Custom safety-trained weights; used when this file exists locally.
    "MODEL_PATH": "yolov8_safety.pt",
    # Generic weights used when MODEL_PATH is missing.
    # NOTE(review): class ids from these weights are still mapped through
    # VIOLATION_LABELS below — confirm that is intended for the fallback.
    "FALLBACK_MODEL": "yolov8n.pt",
    # Where PDFs, snapshots and temporary uploads are written.
    "OUTPUT_DIR": "static/output",
    # Model class id -> internal violation label.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use",
    },
    # BGR colours (OpenCV channel order) used when drawing each label.
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),  # Red
        "no_harness": (0, 165, 255),  # Orange
        "unsafe_posture": (0, 255, 0),  # Green
        "unsafe_zone": (255, 0, 0),  # Blue
        "improper_tool_use": (255, 255, 0),  # Cyan in BGR
    },
    # Human-readable names for reports and the UI.
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture Violation",
        "unsafe_zone": "Unsafe Zone Entry",
        "improper_tool_use": "Improper Tool Use",
    },
    # Salesforce login, read from the environment so secrets are never
    # committed to source control. Any credentials previously hard-coded
    # here should be considered leaked and rotated.
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", ""),
        "password": os.environ.get("SF_PASSWORD", ""),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", ""),
        "domain": os.environ.get("SF_DOMAIN", "login"),
    },
    # Prefix joined with file names in OUTPUT_DIR to build public links.
    # NOTE(review): the default points at the AI_Safety_Demo2 Space — confirm
    # it matches the deployment, or set PUBLIC_URL_BASE in the environment.
    "PUBLIC_URL_BASE": os.environ.get(
        "PUBLIC_URL_BASE",
        "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    ),
    "FRAME_SKIP": 5,  # Run detection on every Nth frame only
    "MAX_PROCESSING_TIME": 60,  # Wall-clock seconds before frame processing stops
    "CONFIDENCE_THRESHOLD": 0.25,  # Detections below this confidence are discarded
    "IOU_THRESHOLD": 0.4,  # Minimum IoU to match a detection to a tracked worker
    # Minimum number of sampled frames in which the same worker must be
    # flagged for a violation to be confirmed (frames need not be consecutive).
    "MIN_VIOLATION_FRAMES": 3,
}
# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Ensure the directory for uploads, PDFs and snapshots exists before use.
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
def load_model():
    """Load the YOLO detector and move it to ``device``.

    Uses CONFIG["MODEL_PATH"] when that file exists. Otherwise it falls back to
    CONFIG["FALLBACK_MODEL"], downloading the yolov8n weights first if that
    file is also missing.

    Returns:
        The ``YOLO`` model placed on ``device``.

    Raises:
        Any exception from downloading or loading the weights (logged first).
    """
    try:
        custom_weights = CONFIG["MODEL_PATH"]
        if not os.path.isfile(custom_weights):
            weights = CONFIG["FALLBACK_MODEL"]
            logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
            if not os.path.isfile(weights):
                logger.info(f"Downloading fallback model: {weights}")
                torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', weights)
        else:
            weights = custom_weights
            # NOTE: logged before YOLO() below actually reads the file.
            logger.info(f"Model loaded: {weights}")
        return YOLO(weights).to(device)
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise


# Loaded once at import time and shared by every request.
model = load_model()
# ==========================
# Enhanced Helper Functions
# ==========================
def draw_detections(frame, detections):
    """Draw a box and a "<display name>: <confidence>" caption per detection.

    Args:
        frame: image as read by ``cv2.VideoCapture``; drawn on in place.
        detections: iterable of dicts with "violation", "confidence" and
            "bounding_box" keys, where "bounding_box" is
            (center_x, center_y, width, height) in pixels.

    Returns:
        The same ``frame`` object, for convenience.
    """
    for det in detections:
        label = det["violation"]
        cx, cy, w, h = det["bounding_box"]
        # Boxes are stored center-based; cv2 drawing wants corner points.
        x1, y1 = int(cx - w / 2), int(cy - h / 2)
        x2, y2 = int(cx + w / 2), int(cy + h / 2)
        color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        caption = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {det['confidence']:.2f}"
        # Place the caption above the box, but keep it inside the image when
        # the box touches the top/left edge (y1 - 10 would be off-screen).
        text_origin = (max(x1, 0), max(y1 - 10, 15))
        cv2.putText(frame, caption, text_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return frame
def calculate_iou(box1, box2):
    """Return the Intersection over Union of two center-format boxes.

    Each box is (center_x, center_y, width, height). Returns 0 when the
    union area is not positive.
    """
    def corners(box):
        cx, cy, w, h = box
        return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2

    a_left, a_top, a_right, a_bottom = corners(box1)
    b_left, b_top, b_right, b_bottom = corners(box2)

    overlap_w = max(0, min(a_right, b_right) - max(a_left, b_left))
    overlap_h = max(0, min(a_bottom, b_bottom) - max(a_top, b_top))
    inter = overlap_w * overlap_h

    union = box1[2] * box1[3] + box2[2] * box2[3] - inter
    if union <= 0:
        return 0
    return inter / union
# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Open a Salesforce session using CONFIG["SF_CREDENTIALS"].

    ``describe()`` is called as a round-trip check on the new session. The
    ``retry`` decorator makes up to 3 attempts, 2000 ms apart; each failure is
    logged and re-raised so the decorator can try again.
    """
    try:
        session = Salesforce(**CONFIG["SF_CREDENTIALS"])
        logger.info("Connected to Salesforce")
        session.describe()
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise
    return session
def generate_violation_pdf(violations, score):
    """Render the violation report as a PDF and save a copy in OUTPUT_DIR.

    Args:
        violations: list of violation dicts ("violation", "timestamp",
            "confidence" keys are read).
        score: compliance score printed in the report header.

    Returns:
        (pdf_path, public_url, pdf_buffer), where pdf_buffer is a BytesIO
        rewound to the start. Returns ("", "", None) on any failure, which is
        logged.
    """
    line_step = 0.3 * inch
    left = 1 * inch
    try:
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        buffer = BytesIO()
        pdf = canvas.Canvas(buffer, pagesize=letter)

        pdf.setFont("Helvetica", 12)
        pdf.drawString(left, 10 * inch, "Worksite Safety Violation Report")
        pdf.setFont("Helvetica", 10)

        y = 9.5 * inch
        header_lines = [
            ("Compliance Score", f"{score}%"),
            ("Violations Found", len(violations)),
            ("Timestamp", time.strftime("%Y-%m-%d %H:%M:%S")),
        ]
        for key, value in header_lines:
            pdf.drawString(left, y, f"{key}: {value}")
            y -= line_step

        y -= line_step  # blank line before the details section
        pdf.drawString(left, y, "Violation Details:")
        y -= line_step

        if violations:
            for v in violations:
                name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
                pdf.drawString(left, y, f"{name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})")
                y -= line_step
                # Continue on a fresh page once past the bottom margin.
                if y < 1 * inch:
                    pdf.showPage()
                    pdf.setFont("Helvetica", 10)
                    y = 10 * inch
        else:
            pdf.drawString(left, y, "No violations detected.")

        pdf.showPage()
        pdf.save()
        buffer.seek(0)
        with open(pdf_path, "wb") as out:
            out.write(buffer.getvalue())

        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, buffer
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach the PDF to a Salesforce record as a ContentVersion.

    Args:
        sf: authenticated simple_salesforce session.
        pdf_file: BytesIO holding the PDF bytes; if falsy, nothing is uploaded.
        report_id: Id of the record the file is first published to.

    Returns:
        The download URL of the new ContentVersion, or "" on any failure
        (errors are logged, not raised).
    """
    if not pdf_file:
        logger.error("No PDF file provided for upload")
        return ""
    try:
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode("utf-8")
        # Take one timestamp so the Title and the file name always agree.
        stamp = int(time.time())
        content_version = sf.ContentVersion.create({
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        })
        version_id = content_version["id"]
        # Check the new version can be queried back before handing out a link.
        result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{version_id}'")
        if not result["records"]:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a report record in Salesforce and attach the PDF to it.

    First tries to create a ``Safety_Video_Report__c`` record. If that fails,
    it falls back to an ``Account`` named ``Safety_Report_<epoch>``. When a PDF
    buffer is given, it is uploaded and the resulting download URL is written
    back to the record that was actually created: ``PDF_Report_URL__c`` on the
    custom object, or ``Description`` on the Account fallback.

    Args:
        violations: list of violation dicts ("violation", "timestamp",
            "confidence" keys are read).
        score: compliance score stored on the record.
        pdf_path: local PDF path ("" if none). Its file name is appended to
            CONFIG["PUBLIC_URL_BASE"] to form the initial report URL.
        pdf_file: BytesIO with the PDF bytes, or None to skip the upload.

    Returns:
        (record_id, pdf_url), or (None, "") if the connection or record
        creation fails (the error is logged).
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
            is_custom_record = True
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")
            is_custom_record = False
        record_id = record["id"]

        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                # Update only the object type that was actually created. Using
                # this Id with the other type always fails, and previously that
                # failure discarded an already-created record.
                if is_custom_record:
                    try:
                        sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                        logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                    except Exception as e:
                        logger.error(f"Failed to update Safety_Video_Report__c: {e}")
                else:
                    try:
                        sf.Account.update(record_id, {"Description": uploaded_url})
                        logger.info(f"Updated Account record {record_id} with PDF URL")
                    except Exception as e:
                        logger.error(f"Failed to update Account record {record_id}: {e}")
                # The file is uploaded either way, so its URL is still valid.
                pdf_url = uploaded_url
        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
        return None, ""
def calculate_safety_score(violations):
    """Return a 0-100 compliance score for a list of confirmed violations.

    Each distinct (worker_id, violation) pair is penalised once, using the
    per-type penalties below. Unknown violation types cost nothing, and the
    score never goes below 0.
    """
    penalty_by_type = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25,
    }
    # Repeat detections of the same violation on the same worker count once.
    distinct_pairs = {(v["worker_id"], v["violation"]) for v in violations}
    total = 0
    for _worker, kind in distinct_pairs:
        total += penalty_by_type.get(kind, 0)
    return max(100 - total, 0)
# ==========================
# Enhanced Video Processing
# ==========================
def _empty_result(message):
    """Build the result dict returned when no Salesforce report is produced."""
    return {
        "violations": [],
        "snapshots": [],
        "score": 100,
        "salesforce_record_id": None,
        "violation_details_url": "",
        "message": message,
    }


def _detect_in_frame(frame, frame_index, timestamp):
    """Run the model on one frame and return its violation detections.

    Class ids are translated through CONFIG["VIOLATION_LABELS"]. Ids missing
    from that mapping, and detections below CONFIDENCE_THRESHOLD, are dropped.
    NOTE(review): the mapping is applied to whatever weights load_model()
    chose, including the generic fallback model, whose class ids presumably
    mean something else — confirm before trusting fallback results.
    """
    detections = []
    for result in model(frame, device=device):
        for box in result.boxes:
            label = CONFIG["VIOLATION_LABELS"].get(int(box.cls))
            conf = float(box.conf)
            if label is None or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
                continue
            detections.append({
                "frame": frame_index,
                "violation": label,
                "confidence": round(conf, 2),
                # (center_x, center_y, width, height), as draw_detections expects.
                "bounding_box": [round(x, 2) for x in box.xywh.cpu().numpy()[0]],
                "timestamp": timestamp,
            })
    return detections


def _assign_worker(detection, workers, timestamp):
    """Return the id of the tracked worker this detection belongs to.

    The detection joins the worker whose last box overlaps it with the
    highest IoU above CONFIG["IOU_THRESHOLD"]; that worker's box and
    last_seen are updated. If no worker matches, a new one is appended to
    ``workers``.
    """
    best, best_iou = None, 0
    for worker in workers:
        iou = calculate_iou(detection["bounding_box"], worker["bbox"])
        if iou > best_iou and iou > CONFIG["IOU_THRESHOLD"]:
            best, best_iou = worker, iou
    if best is not None:
        best["bbox"] = detection["bounding_box"]
        best["last_seen"] = timestamp
        return best["id"]
    worker_id = len(workers) + 1
    workers.append({
        "id": worker_id,
        "bbox": detection["bounding_box"],
        "first_seen": timestamp,
        "last_seen": timestamp,
    })
    return worker_id


def _capture_snapshot(video_path, detection, violation_type):
    """Re-read the detection's frame, draw its box, save a JPEG; return a snapshot dict or None."""
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, detection["frame"])
        ok, image = cap.read()
    finally:
        cap.release()
    if not ok:
        return None
    image = draw_detections(image, [detection])
    filename = f"{violation_type}_{detection['frame']}.jpg"
    path = os.path.join(CONFIG["OUTPUT_DIR"], filename)
    cv2.imwrite(path, image)
    return {
        "violation": violation_type,
        "frame": detection["frame"],
        "snapshot_path": path,
        # Despite its name this key holds a public URL, not base64 data; the
        # key is kept because gradio_interface reads it.
        "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{filename}",
    }


def _confirm_violations(violation_history, video_path):
    """Keep one detection per (worker, violation type) that was seen often enough.

    A worker's violation is confirmed once it was detected in at least
    CONFIG["MIN_VIOLATION_FRAMES"] sampled frames (not necessarily
    consecutive). The highest-confidence detection represents it. The first
    confirmed violation of each type whose frame can be re-read gets a
    snapshot. ``video_path`` must still exist when this is called.

    Returns:
        (violations, snapshots) lists.
    """
    violations, snapshots = [], []
    for violation_type, detections in violation_history.items():
        by_worker = {}
        for det in detections:
            by_worker.setdefault(det["worker_id"], []).append(det)
        snapshot_done = False
        for worker_dets in by_worker.values():
            if len(worker_dets) < CONFIG["MIN_VIOLATION_FRAMES"]:
                continue
            best = max(worker_dets, key=lambda d: d["confidence"])
            violations.append(best)
            if not snapshot_done:
                snapshot = _capture_snapshot(video_path, best, violation_type)
                if snapshot is not None:
                    snapshots.append(snapshot)
                    snapshot_done = True
    return violations, snapshots


def process_video(video_data):
    """Detect safety violations in an uploaded video and file a report.

    The bytes are written to a uniquely named temporary file in OUTPUT_DIR.
    Every FRAME_SKIP-th frame is run through the model until the video ends
    or MAX_PROCESSING_TIME seconds elapse. Detections are associated with
    "workers" by IoU, and confirmed violations are scored, written to a PDF
    and pushed to Salesforce. The temporary file is always removed, but only
    after snapshots have been taken from it.

    Args:
        video_data: raw bytes of the uploaded video file.

    Returns:
        dict with keys "violations", "snapshots", "score",
        "salesforce_record_id", "violation_details_url" and "message".
        "message" is empty on success and set for the no-violation and error
        cases.
    """
    import tempfile  # local: only needed for the per-request upload file

    video_path = None
    video = None
    try:
        # Unique name so concurrent uploads never overwrite each other.
        fd, video_path = tempfile.mkstemp(prefix="temp_", suffix=".mp4", dir=CONFIG["OUTPUT_DIR"])
        with os.fdopen(fd, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")

        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file")

        fps = video.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):  # also catches NaN, which `fps <= 0` would not
            fps = 30  # Default assumption if FPS cannot be determined

        workers = []
        violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
        logger.info(f"Processing video with FPS: {fps}")
        logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")

        start_time = time.time()
        frame_count = 0
        while True:
            if frame_count % CONFIG["FRAME_SKIP"] != 0:
                # Skipped frame: advance the stream without decoding it.
                if not video.grab():
                    break
                frame_count += 1
                continue
            ret, frame = video.read()
            if not ret:
                break
            if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
                logger.info("Processing time limit reached")
                break
            current_time = frame_count / fps
            for detection in _detect_in_frame(frame, frame_count, current_time):
                detection["worker_id"] = _assign_worker(detection, workers, current_time)
                violation_history[detection["violation"]].append(detection)
            frame_count += 1
        video.release()

        # Snapshots re-read frames from the temp file, so it must still exist here.
        violations, snapshots = _confirm_violations(violation_history, video_path)

        if not violations:
            logger.info("No persistent violations detected")
            return _empty_result("No violations detected in the video.")

        score = calculate_safety_score(violations)
        pdf_path, _pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "salesforce_record_id": report_id,
            "violation_details_url": final_pdf_url,
            "message": "",
        }
    except Exception as e:
        logger.error(f"Error processing video: {e}", exc_info=True)
        return _empty_result(f"Error processing video: {e}")
    finally:
        # Release the capture and delete the upload on every path, including errors.
        if video is not None:
            video.release()
        if video_path and os.path.exists(video_path):
            os.remove(video_path)
# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio handler: analyse an uploaded video and stream results to the UI.

    This is a generator, so every result — including the "no file" case — must
    be delivered with ``yield``. A ``return value`` inside a generator is
    discarded and would leave every output blank.

    Args:
        video_file: path of the uploaded video from the ``gr.Video`` input
            (falsy when nothing was uploaded).

    Yields:
        5-tuples matching the interface outputs: (violations markdown,
        score text, snapshots markdown, record id text, report URL).
    """
    if not video_file:
        yield "No file uploaded.", "", "No file uploaded.", "", ""
        return
    try:
        yield "Processing video... please wait.", "", "", "", ""
        with open(video_file, "rb") as f:
            video_data = f.read()
        result = process_video(video_data)
        if result.get("message"):
            yield result["message"], "", "", "", ""
            return

        display_names = CONFIG["DISPLAY_NAMES"]

        violation_table = "No violations detected."
        if result["violations"]:
            header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
            separator = "|------------------------|---------------|------------|-----------|\n"
            rows = [
                f"| {display_names.get(v['violation'], v['violation']):<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
                for v in result["violations"]
            ]
            violation_table = header + separator + "\n".join(rows)

        snapshots_text = "No snapshots captured."
        if result["snapshots"]:
            snapshots_text = "\n".join(
                f"- Snapshot for {display_names.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
                for s in result["snapshots"]
            )

        yield (
            violation_table,
            f"Safety Score: {result['score']}%",
            snapshots_text,
            f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
            result["violation_details_url"] or "N/A",
        )
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", ""
# Input and output components, in the order gradio_interface consumes/yields them.
_input_component = gr.Video(label="Upload Site Video")
_output_components = [
    gr.Markdown(label="Detected Safety Violations"),
    gr.Textbox(label="Compliance Score"),
    gr.Markdown(label="Snapshots"),
    gr.Textbox(label="Salesforce Record ID"),
    gr.Textbox(label="Violation Details URL"),
]

interface = gr.Interface(
    fn=gradio_interface,
    inputs=_input_component,
    outputs=_output_components,
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
    # NOTE(review): newer Gradio releases deprecate this in favour of
    # `flagging_mode` — confirm against the installed version.
    allow_flagging="never",
)

if __name__ == "__main__":
    logger.info("Launching Enhanced Safety Analyzer App...")
    interface.launch()