# AI_Safety_Demo7 / app.py
# Author: PrashanthB461 (Hugging Face Space)
# Last commit: "Update app.py" (9cc7878, verified), 21.1 kB
import os
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
import logging
from retrying import retry
# ==========================
# Configuration
# ==========================
CONFIG = {
    # Custom safety weights; load_model falls back to FALLBACK_MODEL if absent.
    "MODEL_PATH": "yolov8_safety.pt",
    "FALLBACK_MODEL": "yolov8n.pt",
    # Temporary videos, snapshots and PDF reports are written here.
    "OUTPUT_DIR": "static/output",
    # Model class index -> internal violation label.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
    },
    # Internal violation label -> human-readable name used in reports and the UI.
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture Violation",
    },
    # Salesforce login is read from the environment instead of being committed
    # to source control.
    # NOTE(review): credentials previously hard-coded here were published with
    # this file and should be treated as compromised and rotated.
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", ""),
        "password": os.environ.get("SF_PASSWORD", ""),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", ""),
        "domain": os.environ.get("SF_DOMAIN", "login"),
    },
    # Prefix for public links to files written into OUTPUT_DIR.
    # NOTE(review): points at the AI_Safety_Demo2 Space -- confirm it matches the
    # Space that actually serves these files.
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    # Only every FRAME_SKIP-th frame is run through the model.
    "FRAME_SKIP": 15,
    # Wall-clock seconds after which frame processing stops.
    "MAX_PROCESSING_TIME": 30,
    "CONFIDENCE_THRESHOLD": 0.3,  # Lowered threshold for detecting all violations
    "IOU_THRESHOLD": 0.5,  # Minimum IoU to match a detection to a tracked worker
}

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Make sure the output directory exists before anything is written into it.
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# Inference runs on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
def load_model():
    """Load the YOLO detector, preferring the custom safety weights.

    Uses CONFIG["MODEL_PATH"] when that file exists. Otherwise it falls back to
    CONFIG["FALLBACK_MODEL"], downloading it first if it is missing too. The
    model is moved to the module-level ``device``.

    Returns:
        The loaded ``ultralytics.YOLO`` model.

    Raises:
        Exception: any download or load error is logged with its traceback and
        re-raised.
    """
    try:
        if os.path.isfile(CONFIG["MODEL_PATH"]):
            model_path = CONFIG["MODEL_PATH"]
        else:
            model_path = CONFIG["FALLBACK_MODEL"]
            logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
            if not os.path.isfile(model_path):
                logger.info(f"Downloading fallback model: {model_path}")
                # NOTE(review): the URL is fixed to yolov8n.pt regardless of
                # FALLBACK_MODEL; keep the two in sync if either changes.
                torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
        model = YOLO(model_path).to(device)
        # Logged only after YOLO() succeeded, so the message is never a lie.
        logger.info(f"Model loaded: {model_path}")
        return model
    except Exception as e:
        logger.exception(f"Failed to load model: {e}")
        raise


# Loaded once at import time so every request reuses the same model instance.
model = load_model()
# ==========================
# Helper Functions
# ==========================
def calculate_iou(box1, box2):
    """Return the Intersection over Union of two centre-format boxes.

    Each box is ``(cx, cy, w, h)``: centre coordinates followed by width and
    height. The result is in [0, 1]; 0 is returned when the union area is not
    positive (e.g. both boxes are degenerate).
    """
    (ax, ay, aw, ah), (bx, by, bw, bh) = box1, box2

    # Corner coordinates of each box.
    a_left, a_top = ax - aw / 2, ay - ah / 2
    a_right, a_bottom = ax + aw / 2, ay + ah / 2
    b_left, b_top = bx - bw / 2, by - bh / 2
    b_right, b_bottom = bx + bw / 2, by + bh / 2

    # Overlap extent per axis, clamped at zero for disjoint boxes.
    overlap_w = max(0, min(a_right, b_right) - max(a_left, b_left))
    overlap_h = max(0, min(a_bottom, b_bottom) - max(a_top, b_top))
    shared = overlap_w * overlap_h

    combined = aw * ah + bw * bh - shared
    if combined > 0:
        return shared / combined
    return 0
# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Open a Salesforce session with CONFIG["SF_CREDENTIALS"].

    After logging in, ``describe()`` is called as a sanity check on the session.
    Any failure is logged and re-raised. The ``@retry`` decorator then retries,
    for at most 3 attempts with a fixed 2000 ms wait between them.

    Returns:
        The connected ``simple_salesforce.Salesforce`` client.
    """
    try:
        session = Salesforce(**CONFIG["SF_CREDENTIALS"])
        logger.info("Connected to Salesforce")
        session.describe()
    except Exception as err:
        logger.error(f"Salesforce connection failed: {err}")
        raise
    return session
def generate_violation_pdf(violations, score):
    """Render the violation report as a PDF and save a copy under OUTPUT_DIR.

    The report contains a summary (score, violation count, timestamp) followed
    by one line per violation. A new page starts when the cursor drops below
    one inch.

    Args:
        violations: violation dicts with "violation", "timestamp" and
            "confidence" keys.
        score: compliance score shown in the summary.

    Returns:
        ``(pdf_path, public_url, pdf_buffer)``, where ``pdf_buffer`` is a
        ``BytesIO`` rewound to position 0. Returns ``("", "", None)`` if
        anything fails; the error is logged.
    """
    margin = 1 * inch
    step = 0.3 * inch
    try:
        report_name = f"violations_{int(time.time())}.pdf"
        report_path = os.path.join(CONFIG["OUTPUT_DIR"], report_name)
        buffer = BytesIO()
        pdf = canvas.Canvas(buffer, pagesize=letter)

        pdf.setFont("Helvetica", 12)
        pdf.drawString(margin, 10 * inch, "Worksite Safety Violation Report")
        pdf.setFont("Helvetica", 10)
        cursor = 9.5 * inch

        summary_lines = (
            ("Compliance Score", f"{score}%"),
            ("Violations Found", len(violations)),
            ("Timestamp", time.strftime("%Y-%m-%d %H:%M:%S")),
        )
        for caption, value in summary_lines:
            pdf.drawString(margin, cursor, f"{caption}: {value}")
            cursor -= step

        # Blank gap, then the detail section heading.
        cursor -= step
        pdf.drawString(margin, cursor, "Violation Details:")
        cursor -= step

        if violations:
            for item in violations:
                pretty = CONFIG["DISPLAY_NAMES"].get(item["violation"], item["violation"])
                pdf.drawString(margin, cursor, f"{pretty} at {item['timestamp']:.2f}s (Confidence: {item['confidence']})")
                cursor -= step
                if cursor < 1 * inch:
                    # Out of room: start a fresh page (font resets per page).
                    pdf.showPage()
                    pdf.setFont("Helvetica", 10)
                    cursor = 10 * inch
        else:
            pdf.drawString(margin, cursor, "No violations detected.")

        pdf.showPage()
        pdf.save()
        buffer.seek(0)
        with open(report_path, "wb") as handle:
            handle.write(buffer.getvalue())

        link = f"{CONFIG['PUBLIC_URL_BASE']}{report_name}"
        logger.info(f"PDF generated: {link}")
        return report_path, link, buffer
    except Exception as err:
        logger.error(f"Error generating PDF: {err}")
        return "", "", None
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach a generated PDF to a Salesforce record as a ContentVersion.

    Args:
        sf: a connected ``simple_salesforce.Salesforce`` client.
        pdf_file: ``BytesIO`` holding the PDF bytes, or None.
        report_id: id of the record the file is published to
            (``FirstPublishLocationId``).

    Returns:
        The download URL of the new ContentVersion, or "" on any failure.
        Errors are logged, never raised.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        # One timestamp for both names so Title and PathOnClient always agree.
        stamp = int(time.time())
        content_version_data = {
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        }
        content_version = sf.ContentVersion.create(content_version_data)
        version_id = content_version['id']
        # Confirm the version is queryable before handing out a link to it.
        result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{version_id}'")
        if not result['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Salesforce record for an analysed video and attach its PDF.

    A ``Safety_Video_Report__c`` record is created first. If that fails, an
    ``Account`` named ``Safety_Report_<epoch>`` is created instead. When
    *pdf_file* is given, it is uploaded with upload_pdf_to_salesforce. The
    created record's URL field is then pointed at the upload:
    ``PDF_Report_URL__c`` on the custom object, or ``Description`` on the
    Account fallback.

    Args:
        violations: violation dicts with "violation", "timestamp", "confidence".
        score: compliance score stored on the record.
        pdf_path: local path of the generated PDF ("" if generation failed).
            Only its basename is used, to build a PUBLIC_URL_BASE link.
        pdf_file: ``BytesIO`` with the PDF bytes, or None.

    Returns:
        ``(record_id, pdf_url)``. ``pdf_url`` is the Salesforce download URL
        when the upload succeeded, otherwise the PUBLIC_URL_BASE link (or "").
        Returns ``(None, "")`` when connecting or creating the record fails.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
            # Object type and field that will later hold the uploaded PDF URL.
            url_object, url_field = "Safety_Video_Report__c", "PDF_Report_URL__c"
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")
            url_object, url_field = "Account", "Description"
        record_id = record["id"]
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
        return None, ""

    if pdf_file:
        uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
        if uploaded_url:
            # The record already exists and the upload is published to it, so
            # failing to store the URL is logged instead of losing record_id.
            try:
                getattr(sf, url_object).update(record_id, {url_field: uploaded_url})
                logger.info(f"Updated {url_object} record {record_id} with PDF URL: {uploaded_url}")
            except Exception as e:
                logger.error(f"Failed to update {url_object} record {record_id} with PDF URL: {e}")
            pdf_url = uploaded_url
    return record_id, pdf_url
def calculate_safety_score(violations):
    """Return a compliance score: 100 minus a fixed penalty per violation.

    Penalties are 25 for "no_helmet", 30 for "no_harness" and 20 for
    "unsafe_posture". Unknown labels cost nothing. The penalty is charged per
    recorded violation, not per distinct type, and the result is clamped at 0.
    """
    penalty_for = {"no_helmet": 25, "no_harness": 30, "unsafe_posture": 20}

    total_penalty = 0
    for entry in violations:
        total_penalty += penalty_for.get(entry["violation"], 0)
    logger.info(f"Total Penalty: {total_penalty}")

    raw_score = 100 - total_penalty
    logger.info(f"Calculated Score: {raw_score}")
    return raw_score if raw_score > 0 else 0
# Frame rate assumed when the container reports no usable one. It only affects
# the reported timestamps, not which frames are analysed.
_FALLBACK_FPS = 30.0


def _empty_video_result(message):
    """Build the result dict used when no report is produced.

    This covers both the "no violations" and the error case; *message* is the
    text shown to the user.
    """
    return {
        "violations": [],
        "snapshots": [],
        "score": 100,
        "salesforce_record_id": None,
        "violation_details_url": "",
        "message": message,
    }


def _frame_detections(results, frame_count, fps):
    """Turn the model results for one frame into violation detection dicts.

    A detection is kept only if its class is in CONFIG["VIOLATION_LABELS"] and
    its confidence is at least CONFIG["CONFIDENCE_THRESHOLD"].
    """
    detections = []
    for result in results:
        boxes = result.boxes
        logger.info(f"Frame {frame_count}: Found {len(boxes)} potential detections")
        for box in boxes:
            cls, conf = int(box.cls), float(box.conf)
            label = CONFIG["VIOLATION_LABELS"].get(cls, None)
            logger.info(f"Detection: class={cls}, conf={conf:.2f}, label={label}")
            if label is None:
                logger.info(f"Skipping unknown class: {cls}")
                continue
            if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
                logger.info(f"Skipping low confidence: {conf:.2f} < {CONFIG['CONFIDENCE_THRESHOLD']}")
                continue
            # Centre-format (x, y, w, h) box, as calculate_iou expects.
            bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
            logger.info(f"Valid detection: {label} with confidence: {conf:.2f}")
            detections.append({
                "violation": label,
                "confidence": round(conf, 2),
                "bounding_box": bbox,
                "timestamp": frame_count / fps,
                "frame": frame_count,
            })
    return detections


def _match_worker(bbox, workers):
    """Return the tracked worker whose last box best overlaps *bbox*, or None.

    A match requires an IoU strictly above CONFIG["IOU_THRESHOLD"].
    """
    best_worker, best_iou = None, 0
    for worker in workers:
        iou = calculate_iou(bbox, worker["bbox"])
        if iou > best_iou and iou > CONFIG["IOU_THRESHOLD"]:
            best_worker, best_iou = worker, iou
    return best_worker


def _save_snapshot(frame, violation, frame_count):
    """Write *frame* to OUTPUT_DIR and return its snapshot record.

    The "snapshot_base64" key actually holds a public URL, not base64 data.
    The key name is kept because gradio_interface reads it.
    """
    # NOTE(review): the filename is not unique per video, so a later upload can
    # overwrite an earlier run's snapshot with the same label and frame number.
    snapshot_filename = f"{violation}_{frame_count}.jpg"
    snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
    cv2.imwrite(snapshot_path, frame)
    return {
        "violation": violation,
        "frame": frame_count,
        "snapshot_path": snapshot_path,
        "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
    }


def _scan_capture(video):
    """Detect violations on every FRAME_SKIP-th frame of an opened capture.

    Detections are matched to "workers" by IoU against each worker's last box.
    Each (worker, violation type) pair is recorded once, and the first frame
    showing each violation type is saved as a snapshot. Scanning stops at the
    end of the video or once MAX_PROCESSING_TIME seconds have elapsed.

    Returns:
        ``(violations, snapshots)``, both lists of dicts.
    """
    violations, snapshots = [], []
    snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
    workers = []  # tracked workers: id, violation set, last bbox, last frame seen
    next_worker_id = 1  # monotonic, so ids are never reused after pruning
    frame_count = 0
    start_time = time.time()
    fps = video.get(cv2.CAP_PROP_FPS)
    # Some containers report 0 (or NaN) FPS; avoid dividing by it below.
    if not (fps > 0):
        logger.warning(f"Video reports unusable FPS ({fps}); assuming {_FALLBACK_FPS}")
        fps = _FALLBACK_FPS
    logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
    logger.info(f"Using confidence threshold: {CONFIG['CONFIDENCE_THRESHOLD']}")
    # Workers unmatched for this many frames are dropped from tracking.
    stale_after = CONFIG["FRAME_SKIP"] * 5

    while True:
        ret, frame = video.read()
        if not ret:
            break
        if frame_count % CONFIG["FRAME_SKIP"] != 0:
            frame_count += 1
            continue
        if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
            logger.info("Processing time limit reached")
            break

        results = model(frame, device=device)
        for detection in _frame_detections(results, frame_count, fps):
            label = detection["violation"]
            worker = _match_worker(detection["bounding_box"], workers)
            if worker is None:
                worker = {
                    "id": next_worker_id,
                    "violations": set(),
                    "bbox": detection["bounding_box"],
                    "last_frame": frame_count,
                }
                next_worker_id += 1
                workers.append(worker)
                logger.info(f"New worker {worker['id']} with violation: {label}")
            elif label not in worker["violations"]:
                logger.info(f"New violation for worker {worker['id']}: {label}")

            if label not in worker["violations"]:
                worker["violations"].add(label)
                violations.append({
                    "frame": frame_count,
                    "violation": label,
                    "confidence": detection["confidence"],
                    "bounding_box": detection["bounding_box"],
                    "timestamp": detection["timestamp"],
                    "worker_id": worker["id"],
                })
                if not snapshot_taken[label]:
                    snapshots.append(_save_snapshot(frame, label, frame_count))
                    snapshot_taken[label] = True

            # Keep the worker's position current for matching on later frames.
            worker["bbox"] = detection["bounding_box"]
            worker["last_frame"] = frame_count

        active_workers = [w for w in workers if frame_count - w["last_frame"] < stale_after]
        if len(active_workers) != len(workers):
            logger.info(f"Cleaned up {len(workers) - len(active_workers)} inactive workers")
        workers = active_workers
        frame_count += 1

    return violations, snapshots


def _scan_video_bytes(video_data):
    """Write *video_data* to a unique temp file, scan it, then clean up.

    The capture is always released and the temp file always removed, even if
    opening or scanning raises.
    """
    import tempfile

    # mkstemp gives a unique name, so concurrent uploads cannot clobber each other.
    fd, video_path = tempfile.mkstemp(prefix="temp_", suffix=".mp4", dir=CONFIG["OUTPUT_DIR"])
    video = None
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")
        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file")
        return _scan_capture(video)
    finally:
        if video is not None:
            video.release()
        try:
            os.remove(video_path)
        except OSError as e:
            logger.warning(f"Could not remove temporary video {video_path}: {e}")


def process_video(video_data):
    """Detect safety violations in an uploaded video and publish a report.

    Args:
        video_data: raw bytes of the uploaded video file.

    Returns:
        A dict with "violations", "snapshots", "score", "salesforce_record_id",
        "violation_details_url" and "message". "message" is non-empty only when
        no report was produced, i.e. there were no violations or an error
        occurred. Exceptions are logged and turned into that error result.
    """
    from collections import Counter

    try:
        violations, snapshots = _scan_video_bytes(video_data)

        violation_types = dict(Counter(v["violation"] for v in violations))
        logger.info(f"Detection complete. Found violations: {violation_types}")
        if not violations:
            logger.info("No violations detected")
            return _empty_video_result("No violations detected in the video.")

        score = calculate_safety_score(violations)
        pdf_path, _pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "salesforce_record_id": report_id,
            "violation_details_url": final_pdf_url,
            "message": "",
        }
    except Exception as e:
        logger.error(f"Error processing video: {e}", exc_info=True)
        return _empty_video_result(f"Error processing video: {e}")
def _format_violation_table(violations):
    """Render violations as a Markdown table: name, timestamp, confidence, worker."""
    if not violations:
        return "No violations detected."
    header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
    separator = "|------------------------|---------------|------------|-----------|\n"
    violation_name_map = CONFIG["DISPLAY_NAMES"]
    rows = []
    for v in violations:
        display_name = violation_name_map.get(v["violation"], v["violation"])
        rows.append(f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |")
    return header + separator + "\n".join(rows)


def _format_snapshots(snapshots):
    """Render snapshots as a Markdown bullet list with inline images."""
    if not snapshots:
        return "No snapshots captured."
    violation_name_map = CONFIG["DISPLAY_NAMES"]
    return "\n".join(
        f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
        for s in snapshots
    )


def gradio_interface(video_file):
    """Gradio handler: analyse an uploaded video and stream results to the UI.

    This is a generator. Every yielded value is a 5-tuple matching the
    interface outputs: violation table, score text, snapshots markdown,
    record-id text and PDF URL.

    Args:
        video_file: path of the uploaded video, or a falsy value when nothing
            was uploaded.
    """
    if not video_file:
        # In a generator, ``return value`` only ends iteration; the message
        # must be yielded for Gradio to display it.
        yield "No file uploaded.", "", "No file uploaded.", "", ""
        return
    try:
        yield "Processing video... please wait.", "", "", "", ""
        with open(video_file, "rb") as f:
            video_data = f.read()
        result = process_video(video_data)
        if result.get("message"):
            yield result["message"], "", "", "", ""
            return
        yield (
            _format_violation_table(result["violations"]),
            f"Safety Score: {result['score']}%",
            _format_snapshots(result["snapshots"]),
            f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
            result["violation_details_url"] or "N/A",
        )
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", ""
# Web UI: a single video input and five outputs, in the same order as the
# 5-tuples yielded by gradio_interface.
# NOTE(review): newer Gradio releases deprecate ``allow_flagging`` in favour of
# ``flagging_mode``; confirm against the installed version.
interface = gr.Interface(
    fn=gradio_interface,
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (No Helmet Violation, No Harness Violation, Unsafe Posture Violation). Non-violations are ignored.",
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL"),
    ],
    allow_flagging="never",
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    logger.info("Launching Safety Analyzer App...")
    interface.launch()