# AI_Safety_Demo7 / app.py
# Last commit: 4b68be9 "Update app.py" by PrashanthB461 (verified, 19.1 kB)
import os
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
import logging
from retrying import retry
import uuid
# ==========================
# Enhanced Configuration
# ==========================
# Salesforce credentials come from the environment instead of source control:
# set SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN (and optionally SF_DOMAIN)
# before launching. If they are missing, the Salesforce step in
# process_video fails, is logged, and the rest of the analysis still runs.
# NOTE(review): the credentials that used to be hard-coded here were published
# with the source and should be rotated.
CONFIG = {
    "MODEL_PATH": "yolov8_safety.pt",
    "FALLBACK_MODEL": "yolov8n.pt",
    # Temp videos, snapshots and PDF reports are all written here.
    "OUTPUT_DIR": "static/output",
    # Model class id -> internal violation key.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use"
    },
    # BGR colours (OpenCV order) used when drawing each violation.
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),
        "no_harness": (0, 165, 255),
        "unsafe_posture": (0, 255, 0),
        "unsafe_zone": (255, 0, 0),
        "improper_tool_use": (255, 255, 0)
    },
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture Violation",
        "unsafe_zone": "Unsafe Zone Entry",
        "improper_tool_use": "Improper Tool Use"
    },
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", ""),
        "password": os.environ.get("SF_PASSWORD", ""),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", ""),
        "domain": os.environ.get("SF_DOMAIN", "login")
    },
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    # Nominal frame skip; process_video derives the actual skip from this
    # together with MIN_FRAME_RATE / MAX_FRAME_RATE.
    "FRAME_SKIP": 2,
    # Per-class minimum confidence for a detection to be kept.
    "CONFIDENCE_THRESHOLDS": {
        "no_helmet": 0.6,
        "no_harness": 0.15,
        "unsafe_posture": 0.15,
        "unsafe_zone": 0.15,
        "improper_tool_use": 0.15
    },
    # Used both as the detector's NMS IoU and as the minimum IoU for matching
    # a detection to an already-tracked worker.
    "IOU_THRESHOLD": 0.4,
    # A worker must be flagged for no_helmet in at least this many sampled
    # frames before the violation is reported.
    "MIN_VIOLATION_FRAMES": 3,
    # NOTE(review): not read anywhere in this file.
    "HELMET_CONFIDENCE_THRESHOLD": 0.65,
    # Workers not seen for this many seconds of video are dropped.
    "WORKER_TRACKING_DURATION": 3.0,
    "MIN_FRAME_RATE": 5,   # Lower clamp on sampled frames per second
    "MAX_FRAME_RATE": 15,  # Upper clamp on sampled frames per second
    "BATCH_SIZE": 8        # Frames passed to the model per call
}
# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
# Prefer the GPU whenever torch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
def load_model():
    """Load the YOLO detector and move it to ``device``.

    Uses CONFIG["MODEL_PATH"] if that file exists. Otherwise it falls back to
    CONFIG["FALLBACK_MODEL"], downloading the stock yolov8n weights to that
    path first when they are not on disk.

    Returns:
        The loaded ``YOLO`` model.

    Raises:
        Any exception from downloading or loading the weights, after logging it.
    """
    try:
        if os.path.isfile(CONFIG["MODEL_PATH"]):
            model_path = CONFIG["MODEL_PATH"]
        else:
            model_path = CONFIG["FALLBACK_MODEL"]
            # NOTE(review): the stock yolov8n weights are COCO-trained, so their
            # class ids 0-4 are not the violation classes in
            # CONFIG["VIOLATION_LABELS"]; detections will be mislabelled.
            logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
            if not os.path.isfile(model_path):
                logger.info(f"Downloading fallback model: {model_path}")
                torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
        model = YOLO(model_path).to(device)
        # Only report success once the weights have actually loaded.
        logger.info(f"Model loaded: {model_path}")
        return model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
model = load_model()
# ==========================
# Optimized Helper Functions
# ==========================
def draw_detections(frame, detections):
    """Draw a labelled bounding box for each detection onto ``frame``.

    Args:
        frame: image array accepted by ``cv2.rectangle`` / ``cv2.putText``;
            it is drawn on in place.
        detections: iterable of dicts. Reads "violation" (default "Unknown"),
            "confidence" (default 0.0) and "bounding_box", a center-format
            ``[x_center, y_center, width, height]`` in pixels (default all 0).

    Returns:
        The same ``frame`` object.
    """
    for det in detections:
        label = det.get("violation", "Unknown")
        confidence = det.get("confidence", 0.0)
        x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
        # Convert center/size to corner coordinates.
        x1, y1 = int(x - w / 2), int(y - h / 2)
        x2, y2 = int(x + w / 2), int(y + h / 2)
        color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
        # Placing the caption 10px above the box would put it off-image for
        # boxes near the top edge, so in that case draw it just inside the box.
        text_y = y1 - 10 if y1 > 10 else max(y1, 0) + 15
        cv2.putText(frame, display_text, (max(x1, 0), text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return frame
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two center-format boxes.

    Args:
        box1: ``(x_center, y_center, width, height)``.
        box2: ``(x_center, y_center, width, height)``.

    Returns:
        IoU as a float. Returns 0.0 when the boxes do not overlap, only touch,
        or have no positive union area. Degenerate (zero-area) boxes would
        otherwise divide by zero.
    """
    x1, y1, w1, h1 = box1
    x2, y2, w2, h2 = box2
    # Corners of the intersection rectangle.
    x_left = max(x1 - w1 / 2, x2 - w2 / 2)
    y_top = max(y1 - h1 / 2, y2 - h2 / 2)
    x_right = min(x1 + w1 / 2, x2 + w2 / 2)
    y_bottom = min(y1 + h1 / 2, y2 + h2 / 2)
    if x_right <= x_left or y_bottom <= y_top:
        return 0.0
    intersection_area = (x_right - x_left) * (y_bottom - y_top)
    union_area = w1 * h1 + w2 * h2 - intersection_area
    if union_area <= 0:
        return 0.0
    return intersection_area / union_area
def generate_violation_pdf(violations, score):
    """Render a PDF report of ``violations`` and save it in CONFIG["OUTPUT_DIR"].

    Args:
        violations: list of detection dicts. Reads "violation", "timestamp"
            and "confidence", each with a default.
        score: compliance score printed in the header as ``"<score>%"``.

    Returns:
        ``(pdf_path, public_url, pdf_buffer)``. ``pdf_buffer`` is a BytesIO
        rewound to 0 that holds the same bytes written to ``pdf_path``.
        Returns ``("", "", None)`` on any failure; the error is logged.
    """
    try:
        # The random suffix keeps reports generated within the same second
        # from overwriting each other.
        pdf_filename = f"violations_{int(time.time())}_{uuid.uuid4().hex[:8]}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        c.setFont("Helvetica", 12)
        c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 10)
        y_position = 9.5 * inch
        report_data = {
            "Compliance Score": f"{score}%",
            "Violations Found": len(violations),
            "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        for key, value in report_data.items():
            c.drawString(1 * inch, y_position, f"{key}: {value}")
            y_position -= 0.3 * inch
        y_position -= 0.3 * inch
        c.drawString(1 * inch, y_position, "Violation Details:")
        y_position -= 0.3 * inch
        if not violations:
            c.drawString(1 * inch, y_position, "No violations detected.")
        for v in violations:
            # Break the page *before* drawing, so the report never ends with an
            # empty page when the last line lands at the bottom margin.
            if y_position < 1 * inch:
                c.showPage()
                c.setFont("Helvetica", 10)  # font resets on a new page
                y_position = 10 * inch
            display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
            text = f"{display_name} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
            c.drawString(1 * inch, y_position, text)
            y_position -= 0.3 * inch
        # save() emits the pending page itself.
        c.save()
        pdf_file.seek(0)
        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.error(f"Error generating PDF: {e}", exc_info=True)
        return "", "", None
def calculate_safety_score(violations):
    """Return a 0-100 compliance score after per-violation penalties.

    Each violation dict is looked up by its "violation" key. Unknown or
    missing types cost nothing. The result never drops below 0.
    """
    penalty_by_type = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25
    }
    deducted = 0
    for item in violations:
        deducted += penalty_by_type.get(item.get("violation", "Unknown"), 0)
    remaining = 100 - deducted
    return remaining if remaining > 0 else 0
# ==========================
# Optimized Video Processing
# ==========================
def _match_worker(workers, bbox):
    """Return the tracked worker whose box best overlaps ``bbox``, or None.

    Only workers with IoU strictly above CONFIG["IOU_THRESHOLD"] qualify. The
    one with the highest IoU wins. ``workers`` is not modified.
    """
    best_worker = None
    best_iou = CONFIG["IOU_THRESHOLD"]
    for worker in workers:
        iou = calculate_iou(bbox, worker["bbox"])
        if iou > best_iou:
            best_iou = iou
            best_worker = worker
    return best_worker


def process_video(video_data):
    """Detect safety violations in an uploaded video, streaming UI updates.

    This is a generator that yields 5-tuples matching the Gradio outputs:
    ``(violations_markdown_or_status, score_text, snapshots_markdown,
    record_text, report_url)``. Progress yields carry a status message in the
    first slot and empty strings elsewhere. The final yield carries the
    results or an error message.

    Pipeline:
      1. Write ``video_data`` to a temp file.
      2. Run the detector on every ``actual_frame_skip``-th frame, in batches
         of CONFIG["BATCH_SIZE"].
      3. Associate each detection with a tracked worker by IoU.
      4. Keep "no_helmet" only for workers flagged in at least
         CONFIG["MIN_VIOLATION_FRAMES"] sampled frames.
      5. Build the score, PDF report, snapshots and Salesforce record.

    The temp video and the capture handle are released even when an error
    occurs or the consumer stops iterating early.

    Args:
        video_data: raw bytes of the uploaded video file.
    """
    # The random suffix keeps concurrent uploads within the same second from
    # sharing (and deleting) each other's temp file.
    video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}_{uuid.uuid4().hex[:8]}.mp4")
    cap = None
    try:
        with open(video_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Could not open video file")
        # Video properties
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = 30  # Default assumption if FPS not available
        duration = total_frames / fps
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
        # Sample about fps / FRAME_SKIP frames per second, clamped to
        # [MIN_FRAME_RATE, MAX_FRAME_RATE], then convert back to a frame skip.
        target_fps = min(max(fps / CONFIG["FRAME_SKIP"], CONFIG["MIN_FRAME_RATE"]), CONFIG["MAX_FRAME_RATE"])
        actual_frame_skip = max(1, int(fps / target_fps))
        frames_to_process = total_frames // actual_frame_skip
        logger.info(f"Processing strategy: Frame skip={actual_frame_skip}, Target FPS={target_fps:.1f}, Frames to process={frames_to_process}")

        workers = []
        next_worker_id = 1  # monotonically increasing: ids are never reused after pruning
        violations = []
        helmet_violations = {}  # worker_id -> list of no_helmet detections
        snapshots = []
        start_time = time.time()
        last_progress_update = 0

        while True:
            batch_frames = []
            batch_indices = []
            for _ in range(CONFIG["BATCH_SIZE"]):
                frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                if frame_idx >= total_frames:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                batch_frames.append(frame)
                batch_indices.append(frame_idx)
                # grab() advances without decoding, which is cheaper than read().
                for _ in range(actual_frame_skip - 1):
                    if not cap.grab():
                        break
            if not batch_frames:
                break

            # conf=0.1 is only a low floor; per-class thresholds filter below.
            results = model(batch_frames, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], verbose=False)
            for result, frame_idx in zip(results, batch_indices):
                current_time = frame_idx / fps
                if time.time() - last_progress_update > 1.0:
                    progress = (frame_idx / total_frames) * 100
                    yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
                    last_progress_update = time.time()

                for box in result.boxes:
                    cls = int(box.cls)
                    conf = float(box.conf)
                    label = CONFIG["VIOLATION_LABELS"].get(cls, None)
                    if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
                        continue
                    bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
                    detection = {
                        "frame": frame_idx,
                        "violation": label,
                        "confidence": round(conf, 2),
                        "bounding_box": bbox,
                        "timestamp": current_time
                    }
                    # Update only the single best-matching worker, or start
                    # tracking a new one.
                    worker = _match_worker(workers, bbox)
                    if worker is None:
                        worker_id = next_worker_id
                        next_worker_id += 1
                        workers.append({
                            "id": worker_id,
                            "bbox": bbox,
                            "first_seen": current_time,
                            "last_seen": current_time
                        })
                    else:
                        worker_id = worker["id"]
                        worker["bbox"] = bbox
                        worker["last_seen"] = current_time
                    detection["worker_id"] = worker_id

                    if label == "no_helmet":
                        # Deferred: confirmed after the scan once enough frames agree.
                        helmet_violations.setdefault(worker_id, []).append(detection)
                    else:
                        # NOTE(review): every other class is recorded once per sampled
                        # frame, so one sustained violation is counted (and penalised
                        # by calculate_safety_score) many times. Confirm this is intended.
                        violations.append(detection)

                # Drop workers not seen recently.
                workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]

        # Confirm helmet violations that were seen consistently.
        for worker_id, detections in helmet_violations.items():
            if len(detections) < CONFIG["MIN_VIOLATION_FRAMES"]:
                continue
            best_detection = max(detections, key=lambda d: d["confidence"])
            violations.append(best_detection)
            # Re-read the chosen frame to save an annotated snapshot.
            cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
            ret, snapshot_frame = cap.read()
            if ret:
                snapshot_frame = draw_detections(snapshot_frame, [best_detection])
                snapshot_filename = f"no_helmet_{best_detection['frame']}.jpg"
                snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
                cv2.imwrite(snapshot_path, snapshot_frame)
                snapshots.append({
                    "violation": "no_helmet",
                    "frame": best_detection["frame"],
                    "snapshot_path": snapshot_path,
                    "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
                })

        # Release the video early; the finally clause is only a safety net.
        cap.release()
        cap = None
        os.remove(video_path)
        processing_time = time.time() - start_time
        logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")

        if not violations:
            yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
            return

        score = calculate_safety_score(violations)
        _, pdf_url, pdf_file = generate_violation_pdf(violations, score)

        violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
        violation_table += "|------------------------|---------------|------------|-----------|\n"
        for v in sorted(violations, key=lambda x: x["timestamp"]):
            display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
            row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
            violation_table += row

        snapshots_text = "\n".join(
            f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
            for s in snapshots
        ) if snapshots else "No snapshots captured."

        # Push to Salesforce (best effort: failures are logged, not raised).
        try:
            sf = connect_to_salesforce()
            record_data = {
                "Compliance_Score__c": score,
                "Violations_Found__c": len(violations),
                "Status__c": "Completed",
                "Processing_Time__c": f"{processing_time:.2f}s"
            }
            record = sf.Safety_Video_Report__c.create(record_data)
            record_id = record["id"]
            if pdf_file:
                pdf_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
        except Exception as e:
            logger.error(f"Salesforce integration failed: {e}")
            record_id = "N/A (Salesforce error)"

        yield (
            violation_table,
            f"Safety Score: {score}%",
            snapshots_text,
            f"Salesforce Record ID: {record_id}",
            pdf_url or "N/A"
        )
    except Exception as e:
        logger.error(f"Error processing video: {e}", exc_info=True)
        yield f"Error processing video: {e}", "", "", "", ""
    finally:
        # Runs on success, on error, and when the consumer closes the generator.
        if cap is not None:
            cap.release()
        if os.path.exists(video_path):
            try:
                os.remove(video_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temp video {video_path}: {cleanup_error}")
# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Open an authenticated Salesforce session from CONFIG["SF_CREDENTIALS"].

    The ``retry`` decorator makes up to 3 attempts, 2000 ms apart. Every
    failed attempt is logged here before the exception is re-raised, and the
    last failure propagates to the caller.
    """
    try:
        session = Salesforce(**CONFIG["SF_CREDENTIALS"])
        logger.info("Connected to Salesforce")
    except Exception as err:
        logger.error(f"Salesforce connection failed: {err}")
        raise
    else:
        return session
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach a PDF to a Salesforce record and return its download URL.

    Creates a ContentVersion whose FirstPublishLocationId is ``report_id``,
    confirms it with a SOQL query, and builds a version-download URL on the
    client's instance.

    Args:
        sf: connected ``simple_salesforce.Salesforce`` client.
        pdf_file: object with ``getvalue()`` returning the PDF bytes
            (a BytesIO in this file).
        report_id: Id of the record the file is linked to.

    Returns:
        The download URL, or "" on any failure. Errors are logged, not raised.
    """
    try:
        # One timestamp so Title and PathOnClient always share the same suffix.
        stamp = int(time.time())
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        content_version_data = {
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id
        }
        content_version = sf.ContentVersion.create(content_version_data)
        version_id = content_version['id']
        # version_id comes from Salesforce's own response, not from user input.
        result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{version_id}'")
        if not result['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""
# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: stream ``process_video`` results for an uploaded video.

    Args:
        video_file: file path of the uploaded video, or a falsy value when
            nothing was uploaded.

    Yields:
        5-tuples matching the interface's five output components.
    """
    if not video_file:
        # This is a generator, so ``return <tuple>`` would be swallowed (it only
        # sets StopIteration.value) and the UI would show nothing. Yield instead.
        yield "No file uploaded.", "", "No file uploaded.", "", ""
        return
    try:
        with open(video_file, "rb") as f:
            video_data = f.read()
        # yield from also closes process_video (and runs its cleanup) if the
        # consumer stops iterating early.
        yield from process_video(video_data)
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", ""
# Gradio UI. The five output components line up, in order, with the 5-tuples
# yielded by gradio_interface(): violations table (or a status/progress/error
# message), compliance score, snapshot markdown, Salesforce record text, and
# report URL.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL")
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode`; confirm against the installed Gradio version.
    allow_flagging="never"
)
# Launch the web app only when this file is run directly, not on import.
if __name__ == "__main__":
    logger.info("Launching Enhanced Safety Analyzer App...")
    interface.launch()