# AI_Safety_Demo7 / app.py
# Hugging Face Space source (author: PrashanthB461; commit 125b2ad "Update app.py"; 22.8 kB).
import os
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
import logging
from retrying import retry
import uuid
from multiprocessing import Pool, cpu_count
from functools import partial
# ==========================
# Optimized Configuration
# ==========================
CONFIG = {
    "MODEL_PATH": "yolov8_safety.pt",
    "FALLBACK_MODEL": "yolov8n.pt",
    "OUTPUT_DIR": "static/output",
    # Model class id -> internal violation key.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use",
    },
    # OpenCV draws with BGR tuples, not RGB.
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),  # Red
        "no_harness": (0, 165, 255),  # Orange
        "unsafe_posture": (0, 255, 0),  # Green
        "unsafe_zone": (255, 0, 0),  # Blue
        "improper_tool_use": (0, 255, 255),  # Yellow (BGR; (255, 255, 0) would be cyan)
    },
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture",
        "unsafe_zone": "Unsafe Zone Entry",
        "improper_tool_use": "Improper Tool Use",
    },
    # Credentials must never live in source control. Provide them through the
    # environment (e.g. Space secrets). The values that used to be hard-coded
    # here were publicly exposed and should be rotated. With no credentials,
    # the Salesforce login fails and push_report_to_salesforce() logs the
    # error and returns (None, "").
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", ""),
        "password": os.environ.get("SF_PASSWORD", ""),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", ""),
        "domain": os.environ.get("SF_DOMAIN", "login"),
    },
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    # Minimum model confidence per violation type; detections below are dropped.
    "CONFIDENCE_THRESHOLDS": {
        "no_helmet": 0.75,  # Increased for stricter helmet detection
        "no_harness": 0.4,
        "unsafe_posture": 0.4,
        "unsafe_zone": 0.4,
        "improper_tool_use": 0.4,
    },
    "MIN_VIOLATION_FRAMES": 3,
    # Seconds a helmet-less worker may go unseen before their track is dropped.
    "WORKER_TRACKING_DURATION": 3.0,
    "MAX_PROCESSING_TIME": 60,  # 1 minute limit
    "FRAME_SKIP": 2,  # Process every 2nd frame for speed
    "BATCH_SIZE": 16,  # Frames per batch
    "PARALLEL_WORKERS": max(1, cpu_count() - 1),  # Use all CPU cores except one
}
# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Temp videos, snapshots and PDF reports are all written under OUTPUT_DIR,
# so create it once at import time.
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
# Run inference on the GPU whenever torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
def load_model():
    """Load the YOLO detector, falling back to stock Ultralytics weights.

    Uses CONFIG["MODEL_PATH"] when that file exists. Otherwise it uses
    CONFIG["FALLBACK_MODEL"], downloading it from the Ultralytics assets
    release first if it is not present locally. The model is moved to the
    module-level ``device``.

    Returns:
        The loaded ``ultralytics.YOLO`` model.

    Raises:
        Exception: any download/load error is logged with its traceback and
        re-raised.
    """
    try:
        if os.path.isfile(CONFIG["MODEL_PATH"]):
            model_path = CONFIG["MODEL_PATH"]
        else:
            model_path = CONFIG["FALLBACK_MODEL"]
            # NOTE(review): yolov8n.pt is COCO-pretrained. Its class ids 0-4
            # (person, bicycle, car, motorcycle, airplane) do not correspond to
            # CONFIG["VIOLATION_LABELS"], so fallback "violations" are not
            # meaningful safety detections.
            logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
            if not os.path.isfile(model_path):
                logger.info(f"Downloading fallback model: {model_path}")
                # Derive the asset name from the configured fallback instead of
                # always fetching yolov8n.pt into whatever path is configured.
                url = (
                    "https://github.com/ultralytics/assets/releases/download/v8.3.0/"
                    f"{os.path.basename(model_path)}"
                )
                torch.hub.download_url_to_file(url, model_path)
        model = YOLO(model_path).to(device)
        # Log only after YOLO() succeeded, so the message is truthful.
        logger.info(f"Model loaded: {model_path}")
        return model
    except Exception as e:
        logger.exception(f"Failed to load model: {e}")
        raise


# Loaded once at import so every request reuses the same model instance.
model = load_model()
# ==========================
# Optimized Helper Functions
# ==========================
def draw_detections(frame, detections):
    """Draw a labelled bounding box for each detection onto ``frame`` in place.

    Args:
        frame: image array as returned by ``cv2.VideoCapture.read()``. It is
            modified in place.
        detections: iterable of dicts with optional keys ``"violation"`` (key
            into CONFIG colors/display names), ``"confidence"`` (float) and
            ``"bounding_box"``. The box is assumed to be
            [x_center, y_center, width, height] in pixels, as produced from
            ultralytics ``Boxes.xywh`` in process_video.

    Returns:
        The same ``frame`` object, for call chaining.
    """
    for det in detections:
        label = det.get("violation", "Unknown")
        confidence = det.get("confidence", 0.0)
        x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
        x1, y1 = int(x - w / 2), int(y - h / 2)
        x2, y2 = int(x + w / 2), int(y + h / 2)
        color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
        # The caption normally sits 10px above the box. Clamp its origin so it
        # stays inside the image when the box touches the top/left edge;
        # otherwise the text would be drawn off-canvas and be invisible.
        text_origin = (max(x1, 0), max(y1 - 10, 15))
        cv2.putText(frame, display_text, text_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return frame
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two center-format boxes.

    Args:
        box1, box2: sequences ``(x_center, y_center, width, height)``.

    Returns:
        IoU in [0.0, 1.0]. Returns 0.0 when the boxes do not overlap or when
        both have zero area (the union would be zero).
    """
    x1, y1, w1, h1 = box1
    x2, y2, w2, h2 = box2
    x_left = max(x1 - w1 / 2, x2 - w2 / 2)
    y_top = max(y1 - h1 / 2, y2 - h2 / 2)
    x_right = min(x1 + w1 / 2, x2 + w2 / 2)
    y_bottom = min(y1 + h1 / 2, y2 + h2 / 2)
    if x_right < x_left or y_bottom < y_top:
        return 0.0
    intersection_area = (x_right - x_left) * (y_bottom - y_top)
    union_area = w1 * h1 + w2 * h2 - intersection_area
    # Degenerate (zero-area) boxes would otherwise raise ZeroDivisionError.
    if union_area <= 0:
        return 0.0
    return intersection_area / union_area
def generate_violation_pdf(violations, score):
    """Render a PDF violation report and save a copy under OUTPUT_DIR.

    Args:
        violations: list of detection dicts (``"violation"``, ``"timestamp"``,
            ``"confidence"`` keys are read, with defaults when missing).
        score: compliance score shown in the report header.

    Returns:
        ``(pdf_path, public_url, pdf_file)``, where ``pdf_file`` is a
        ``BytesIO`` holding the PDF bytes. Returns ``("", "", None)`` on any
        error; the error is logged with its traceback, not raised.
    """
    try:
        # The random suffix keeps two reports generated within the same second
        # from overwriting each other.
        pdf_filename = f"violations_{int(time.time())}_{uuid.uuid4().hex[:8]}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        line_height = 0.3 * inch
        c.setFont("Helvetica", 12)
        c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 10)
        y_position = 9.5 * inch
        report_data = {
            "Compliance Score": f"{score}%",
            "Violations Found": len(violations),
            "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        for key, value in report_data.items():
            c.drawString(1 * inch, y_position, f"{key}: {value}")
            y_position -= line_height
        y_position -= line_height
        c.drawString(1 * inch, y_position, "Violation Details:")
        y_position -= line_height
        if not violations:
            c.drawString(1 * inch, y_position, "No violations detected.")
        for v in violations:
            # Break the page *before* drawing. A new page is then started only
            # when there is another line to put on it, which avoids the
            # trailing blank page the old "break after drawing" order produced.
            if y_position < 1 * inch:
                c.showPage()
                c.setFont("Helvetica", 10)
                y_position = 10 * inch
            display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
            text = f"{display_name} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
            c.drawString(1 * inch, y_position, text)
            y_position -= line_height
        c.showPage()
        c.save()
        pdf_file.seek(0)
        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.exception(f"Error generating PDF: {e}")
        return "", "", None
def calculate_safety_score(violations):
    """Return a 0-100 compliance score for a list of detection dicts.

    Every detection subtracts the fixed penalty for its ``"violation"`` label
    from 100. Unknown or missing labels cost nothing, and the result never
    drops below 0.
    """
    penalty_by_label = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25,
    }
    remaining = 100
    for detection in violations:
        remaining -= penalty_by_label.get(detection.get("violation", "Unknown"), 0)
    return remaining if remaining > 0 else 0
# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Open a Salesforce session from CONFIG["SF_CREDENTIALS"] and validate it.

    The ``retrying`` decorator makes up to 3 attempts, 2 seconds apart.

    Returns:
        An authenticated ``simple_salesforce.Salesforce`` client.

    Raises:
        Exception: each failed attempt is logged and re-raised so the
        decorator can retry. The error propagates once attempts run out.
    """
    try:
        sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
        # describe() makes a real API call, so a bad session fails here.
        sf.describe()
        # Only claim success after the session has been validated.
        logger.info("Connected to Salesforce")
        return sf
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach a PDF to a Salesforce record as a ContentVersion.

    Args:
        sf: an authenticated ``simple_salesforce.Salesforce`` client.
        pdf_file: ``BytesIO`` holding the PDF bytes. When falsy, nothing is
            uploaded.
        report_id: Id of the record the file is first published to
            (``FirstPublishLocationId``).

    Returns:
        A download URL for the uploaded version, or "" on any failure. Errors
        are logged, not raised.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        # Take one timestamp so Title and PathOnClient always agree, even
        # across a second boundary.
        stamp = int(time.time())
        content_version_data = {
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        }
        content_version = sf.ContentVersion.create(content_version_data)
        version_id = content_version['id']
        # Confirm the new version is queryable before handing out a link.
        result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{version_id}'")
        if not result['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Salesforce record for this analysis and attach the PDF report.

    Tries to create a ``Safety_Video_Report__c`` record first. If that fails,
    it falls back to a plain ``Account``. When ``pdf_file`` is given, the PDF
    is uploaded and its Salesforce URL is written back onto whichever record
    was created.

    Args:
        violations: list of detection dicts.
        score: compliance score.
        pdf_path: local PDF path, used to build the public URL (may be "").
        pdf_file: ``BytesIO`` with the PDF bytes, or None.

    Returns:
        ``(record_id, pdf_url)``, or ``(None, "")`` if the connection or record
        creation fails. Failures are logged, not raised.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        # Remember which object type was created so the later update targets
        # it. Updating an Account id through Safety_Video_Report__c (or the
        # reverse) can never succeed.
        used_account_fallback = False
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")
            used_account_fallback = True
        record_id = record["id"]
        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                # A failed write-back must not discard the record that was
                # already created, so log it and keep going.
                try:
                    if used_account_fallback:
                        sf.Account.update(record_id, {"Description": uploaded_url})
                        logger.info(f"Updated Account record {record_id} with PDF URL")
                    else:
                        sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                        logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                except Exception as e:
                    sobject = "Account" if used_account_fallback else "Safety_Video_Report__c"
                    logger.error(f"Failed to update {sobject}: {e}")
                pdf_url = uploaded_url
        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
        return None, ""
# ==========================
# Fast Video Processing
# ==========================
def _save_violation_snapshot(frame, detection, snapshot_filename):
    """Draw ``detection`` on a copy of ``frame``, save it as a JPEG, return (path, url) or None."""
    # Copy so that several detections in the same frame each get their own image.
    annotated = draw_detections(frame.copy(), [detection])
    snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
    if not cv2.imwrite(snapshot_path, annotated):
        logger.warning(f"Failed to write snapshot: {snapshot_path}")
        return None
    return snapshot_path, f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"


def _match_helmet_worker(helmet_workers, bbox, iou_threshold=0.4):
    """Return the id of the first tracked worker whose box has IoU > threshold with ``bbox``, else None."""
    for w_id, worker in helmet_workers.items():
        if calculate_iou(bbox, worker["bbox"]) > iou_threshold:
            return w_id
    return None


def _format_violation_table(violations):
    """Render violations as a Markdown table sorted by timestamp."""
    rows = [
        "| Violation | Timestamp (s) | Confidence | Worker ID (Helmet Only) |\n",
        "|------------------------|---------------|------------|--------------------------|\n",
    ]
    for v in sorted(violations, key=lambda x: x["timestamp"]):
        display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
        worker_id = v.get("worker_id", "N/A") if v.get("violation") == "no_helmet" else "N/A"
        rows.append(f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {worker_id} |\n")
    return "".join(rows)


def _format_snapshots_html(snapshots):
    """Render snapshot records as a flex grid of clickable thumbnails."""
    parts = ["<div style='display: flex; flex-wrap: wrap; gap: 10px;'>"]
    for s in snapshots:
        display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
        worker_text = f"Worker {s['worker_id']}" if "worker_id" in s else ""
        parts.append(f"""
<div style='text-align: center; margin: 10px;'>
<a href='{s['snapshot_url']}' target='_blank'>
<img src='{s['snapshot_url']}' style='max-width: 200px; max-height: 150px;'/>
</a>
<p>{display_name} at frame {s['frame']} {worker_text}</p>
</div>
""")
    parts.append("</div>")
    return "".join(parts)


def process_video(video_data):
    """Detect safety violations in an uploaded video, streaming progress.

    Writes ``video_data`` to a temporary file in OUTPUT_DIR. The YOLO model
    then runs over every FRAME_SKIP-th frame, in batches of BATCH_SIZE.
    Violations and annotated JPEG snapshots are recorded, then the PDF report
    is built and pushed to Salesforce. The temp file and capture handle are
    released on every exit path, including errors and early generator close.

    Args:
        video_data: raw bytes of the uploaded video file.

    Yields:
        5-tuples of strings for the Gradio outputs.
        - While processing: a progress message followed by four empty strings.
        - At the end: (Markdown violation table, score text, snapshots HTML,
          Salesforce record text, PDF URL).
        - On error: a single error tuple.
    """
    cap = None
    video_path = None
    try:
        # The uuid suffix keeps concurrent uploads within the same second from
        # sharing (and clobbering) one temp file.
        video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}_{uuid.uuid4().hex}.mp4")
        with open(video_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Could not open video file")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = 30
        duration = total_frames / fps
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")

        # Only no_helmet detections are tracked per worker:
        # {worker_id: {"bbox": [x, y, w, h], "first_seen": t, "last_seen": t}}
        helmet_workers = {}
        violations = []
        snapshots = []
        # Two separate clocks: one for the total-time log line, and one that
        # throttles progress messages to about once per second.
        processing_start = time.time()
        last_progress_update = processing_start
        frame_skip = CONFIG["FRAME_SKIP"]
        next_worker_id = 1
        # NOTE(review): CONFIG["MAX_PROCESSING_TIME"] is not enforced here.

        while True:
            batch_frames = []
            batch_indices = []
            for _ in range(CONFIG["BATCH_SIZE"]):
                frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                # FRAME_COUNT can be 0 for some containers. In that case rely
                # on read() failing at end-of-stream instead of stopping at once.
                if 0 < total_frames <= frame_idx:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                # grab() advances without decoding, so skipped frames are cheap.
                for _ in range(frame_skip - 1):
                    if not cap.grab():
                        break
                batch_frames.append(frame)
                batch_indices.append(frame_idx)
            if not batch_frames:
                break

            results = model(batch_frames, device=device, conf=0.1, verbose=False)

            for frame, result, frame_idx in zip(batch_frames, results, batch_indices):
                current_time = frame_idx / fps
                if time.time() - last_progress_update > 1.0:
                    progress = (frame_idx / total_frames) * 100 if total_frames > 0 else 0.0
                    yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
                    last_progress_update = time.time()

                # Non-helmet labels are recorded at most once per processed frame.
                frame_violations = set()
                for box in result.boxes:
                    cls = int(box.cls)
                    conf = float(box.conf)
                    label = CONFIG["VIOLATION_LABELS"].get(cls)
                    if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
                        continue
                    bbox = [round(float(x), 2) for x in box.xywh.cpu().numpy()[0]]

                    if label == "no_helmet":
                        worker_id = _match_helmet_worker(helmet_workers, bbox)
                        if worker_id is not None:
                            # Known worker: refresh the track and don't record again.
                            helmet_workers[worker_id]["bbox"] = bbox
                            helmet_workers[worker_id]["last_seen"] = current_time
                            continue
                        # New worker: start a track and record the first violation.
                        worker_id = next_worker_id
                        next_worker_id += 1
                        helmet_workers[worker_id] = {
                            "bbox": bbox,
                            "first_seen": current_time,
                            "last_seen": current_time,
                        }
                        detection = {
                            "frame": frame_idx,
                            "violation": label,
                            "confidence": round(conf, 2),
                            "bounding_box": bbox,
                            "timestamp": current_time,
                            "worker_id": worker_id,
                        }
                        violations.append(detection)
                        # Use the already-decoded frame rather than reopening and
                        # seeking the file, which is slow and can be frame-inaccurate.
                        saved = _save_violation_snapshot(frame, detection, f"no_helmet_{worker_id}_{frame_idx}.jpg")
                        if saved is not None:
                            snapshot_path, snapshot_url = saved
                            snapshots.append({
                                "violation": "no_helmet",
                                "frame": frame_idx,
                                "worker_id": worker_id,
                                "snapshot_path": snapshot_path,
                                "snapshot_url": snapshot_url,
                            })
                    elif label not in frame_violations:
                        frame_violations.add(label)
                        detection = {
                            "frame": frame_idx,
                            "violation": label,
                            "confidence": round(conf, 2),
                            "bounding_box": bbox,
                            "timestamp": current_time,
                        }
                        violations.append(detection)
                        saved = _save_violation_snapshot(frame, detection, f"{label}_{frame_idx}.jpg")
                        if saved is not None:
                            snapshot_path, snapshot_url = saved
                            snapshots.append({
                                "violation": label,
                                "frame": frame_idx,
                                "snapshot_path": snapshot_path,
                                "snapshot_url": snapshot_url,
                            })

                # Drop worker tracks not seen within WORKER_TRACKING_DURATION seconds.
                stale_workers = [
                    w_id for w_id, worker in helmet_workers.items()
                    if current_time - worker["last_seen"] > CONFIG["WORKER_TRACKING_DURATION"]
                ]
                for w_id in stale_workers:
                    del helmet_workers[w_id]

        # Release the capture and temp file as soon as decoding is done. The
        # finally block covers the error / early-close paths.
        cap.release()
        cap = None
        os.remove(video_path)
        video_path = None
        processing_time = time.time() - processing_start
        logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")

        if not violations:
            yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
            return

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        yield (
            _format_violation_table(violations),
            f"Safety Score: {score}%",
            _format_snapshots_html(snapshots),
            f"Salesforce Record ID: {report_id or 'N/A'}",
            # Fall back to the locally generated report URL when Salesforce is unavailable.
            final_pdf_url or pdf_url or "N/A",
        )
    except Exception as e:
        logger.error(f"Error processing video: {e}", exc_info=True)
        yield f"Error processing video: {e}", "", "", "", ""
    finally:
        if cap is not None:
            cap.release()
        if video_path and os.path.exists(video_path):
            try:
                os.remove(video_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temp video {video_path}: {cleanup_error}")
# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: stream ``process_video`` results for an uploaded video.

    Args:
        video_file: filesystem path supplied by the ``gr.Video`` input.
            Falsy (None / "") when nothing was uploaded.

    Yields:
        5-tuples (violations/status, score, snapshots HTML, record id text,
        details URL) matching the interface's five outputs.
    """
    if not video_file:
        # This function is a generator, so ``return <value>`` would only set
        # StopIteration.value and Gradio would receive no output at all.
        # Yield the message instead.
        yield "No file uploaded.", "", "No file uploaded.", "", ""
        return
    try:
        with open(video_file, "rb") as f:
            video_data = f.read()
        for status, score, snapshots_html, record_id, details_url in process_video(video_data):
            yield status, score, snapshots_html, record_id, details_url
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", ""
# Gradio UI: one video upload feeding gradio_interface. Each 5-tuple it
# yields maps positionally onto the five output components below.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.HTML(label="Violation Snapshots (Click to enlarge)"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL"),
    ],
    title="Worksite Safety Violation Analyzer",
    description=(
        "Upload site videos to detect safety violations "
        "(No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). "
        "Non-violations are ignored."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    logger.info("Launching Enhanced Safety Analyzer App...")
    interface.launch()