# AI_Safety_Demo7 — app.py
# (removed Hugging Face web-UI residue that had been pasted above the imports;
#  it was not valid Python and prevented this file from running)
import os
import sys
import subprocess
import logging
import warnings
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
from retrying import retry
import uuid
from multiprocessing import Pool, cpu_count
from functools import partial
# ==========================
# Configuration and Setup
# ==========================
# Handle Ultralytics config directory
# Redirect Ultralytics' settings dir to /tmp — presumably because the default
# home directory is not writable on the hosting platform (TODO confirm).
os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
os.makedirs('/tmp/Ultralytics', exist_ok=True)
# Setup logging
# Configure root logging once for the whole app; all modules share this format.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# ==========================
# ByteTrack Implementation
# ==========================
class BYTETracker:
    """Lightweight stand-in for ByteTrack (avoids installing the real package).

    The previous stub handed every detection a brand-new id on every frame,
    so no worker could ever accumulate the MIN_VIOLATION_FRAMES detections
    needed to confirm a violation. This version re-associates detections with
    existing tracks by IoU, so the same worker keeps the same id across
    frames. Boxes are in [x_center, y_center, w, h] format.
    """

    # Minimum IoU overlap required to re-attach a detection to an existing track.
    MIN_IOU = 0.3

    def __init__(self, track_thresh=0.5, track_buffer=30, match_thresh=0.8, frame_rate=30):
        self.track_thresh = track_thresh   # detections scoring below this are ignored
        self.track_buffer = track_buffer   # frames a lost track survives before removal
        self.match_thresh = match_thresh   # kept for interface compatibility
        self.frame_rate = frame_rate       # kept for interface compatibility
        self.next_id = 1
        self._active = []                  # [{'id', 'bbox', 'misses'}]

    @staticmethod
    def _iou(a, b):
        """IoU of two [cx, cy, w, h] boxes."""
        ax1, ay1 = a[0] - a[2] / 2, a[1] - a[3] / 2
        ax2, ay2 = a[0] + a[2] / 2, a[1] + a[3] / 2
        bx1, by1 = b[0] - b[2] / 2, b[1] - b[3] / 2
        bx2, by2 = b[0] + b[2] / 2, b[1] + b[3] / 2
        ix = max(0.0, min(ax2, bx2) - max(ax1, bx1))
        iy = max(0.0, min(ay2, by2) - max(ay1, by1))
        inter = ix * iy
        union = a[2] * a[3] + b[2] * b[3] - inter
        return inter / union if union > 0 else 0.0

    def update(self, dets, scores, cls):
        """Return one track dict per accepted detection, preserving input order.

        Each dict has keys 'id', 'bbox', 'score', 'cls'; ids persist across
        calls for detections that overlap a previously seen box.
        """
        tracks = []
        matched = set()
        for det, score, cl in zip(dets, scores, cls):
            if score < self.track_thresh:
                continue
            bbox = [float(v) for v in det]
            # Greedy IoU match against not-yet-claimed active tracks.
            best_j, best_iou = None, self.MIN_IOU
            for j, tr in enumerate(self._active):
                if j in matched:
                    continue
                iou = self._iou(bbox, tr['bbox'])
                if iou > best_iou:
                    best_j, best_iou = j, iou
            if best_j is not None:
                tr = self._active[best_j]
                tr['bbox'] = bbox
                tr['misses'] = 0
                matched.add(best_j)
                track_id = tr['id']
            else:
                track_id = self.next_id
                self.next_id += 1
                self._active.append({'id': track_id, 'bbox': bbox, 'misses': 0})
                matched.add(len(self._active) - 1)
            tracks.append({'id': track_id, 'bbox': bbox, 'score': score, 'cls': cl})
        # Age out tracks that went unmatched this frame.
        survivors = []
        for j, tr in enumerate(self._active):
            if j not in matched:
                tr['misses'] += 1
                if tr['misses'] > self.track_buffer:
                    continue
            survivors.append(tr)
        self._active = survivors
        return tracks
# ==========================
# Optimized Configuration
# ==========================
# SECURITY NOTE: Salesforce credentials used to be hard-coded here. They are
# now read from environment variables (SF_USERNAME / SF_PASSWORD /
# SF_SECURITY_TOKEN / SF_DOMAIN). Rotate any credentials that were previously
# committed to version control.
CONFIG = {
    "MODEL_PATH": "yolov8_safety.pt",    # custom-trained safety weights (preferred)
    "FALLBACK_MODEL": "yolov8n.pt",      # stock model used when custom weights are absent
    "OUTPUT_DIR": "static/output",
    # Model class index -> internal violation label.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use"
    },
    # BGR colours used by OpenCV when drawing boxes.
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),        # Red
        "no_harness": (0, 165, 255),     # Orange
        "unsafe_posture": (0, 255, 0),   # Green
        "unsafe_zone": (255, 0, 0),      # Blue
        "improper_tool_use": (255, 255, 0)  # Yellow
    },
    # Human-readable names used in reports and the UI.
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture",
        "unsafe_zone": "Unsafe Zone Entry",
        "improper_tool_use": "Improper Tool Use"
    },
    # Credentials come from the environment; never commit secrets to source.
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", ""),
        "password": os.environ.get("SF_PASSWORD", ""),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", ""),
        "domain": os.environ.get("SF_DOMAIN", "login")
    },
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    # Per-class minimum confidence to count a detection at all.
    "CONFIDENCE_THRESHOLDS": {
        "no_helmet": 0.75,
        "no_harness": 0.4,
        "unsafe_posture": 0.4,
        "unsafe_zone": 0.4,
        "improper_tool_use": 0.4
    },
    "MIN_VIOLATION_FRAMES": 3,        # frames a violation must persist to be confirmed
    "WORKER_TRACKING_DURATION": 5.0,
    "MAX_PROCESSING_TIME": 60,
    "FRAME_SKIP": 1,                  # 1 = process every frame
    "BATCH_SIZE": 32,                 # frames per model() inference call
    "PARALLEL_WORKERS": max(1, cpu_count() - 1),
    "TRACK_BUFFER": 30,
    "TRACK_THRESH": 0.4,
    "MATCH_THRESH": 0.8
}
# NOTE: logging was already configured at the top of the file; the duplicate
# basicConfig()/getLogger() calls that used to live here were removed.
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
# Run all inference on GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
def load_model():
    """Load the YOLO detector, preferring the custom safety weights.

    Falls back to the stock yolov8n model (downloading it on first use) when
    yolov8_safety.pt is not present. Any failure is logged and re-raised.
    """
    try:
        if os.path.isfile(CONFIG["MODEL_PATH"]):
            model_path = CONFIG["MODEL_PATH"]
            logger.info(f"Model loaded: {model_path}")
        else:
            model_path = CONFIG["FALLBACK_MODEL"]
            logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
            if not os.path.isfile(model_path):
                logger.info(f"Downloading fallback model: {model_path}")
                torch.hub.download_url_to_file(
                    'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt',
                    model_path,
                )
        return YOLO(model_path).to(device)
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
# Load the model once at import time so every request reuses it.
model = load_model()
# ==========================
# Optimized Helper Functions
# ==========================
def draw_detections(frame, detections):
    """Draw a labelled bounding box on *frame* for every detection dict.

    Each detection supplies a centre-format box ("bounding_box" = [cx, cy, w, h]),
    a "violation" label and a "confidence"; the frame is annotated in place
    and also returned for convenience.
    """
    for detection in detections:
        violation = detection.get("violation", "Unknown")
        score = detection.get("confidence", 0.0)
        cx, cy, bw, bh = detection.get("bounding_box", [0, 0, 0, 0])
        # Convert centre/width/height to the corner coordinates OpenCV expects.
        left, top = int(cx - bw / 2), int(cy - bh / 2)
        right, bottom = int(cx + bw / 2), int(cy + bh / 2)
        box_color = CONFIG["CLASS_COLORS"].get(violation, (0, 0, 255))
        cv2.rectangle(frame, (left, top), (right, bottom), box_color, 2)
        caption = f"{CONFIG['DISPLAY_NAMES'].get(violation, violation)}: {score:.2f}"
        cv2.putText(frame, caption, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)
    return frame
def calculate_safety_score(violations):
    """Compute a 0-100 compliance score.

    Each confirmed violation subtracts a fixed penalty by type; unknown
    violation types cost nothing, and the score never drops below zero.
    """
    penalty_points = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25,
    }
    remaining = 100
    for violation in violations:
        remaining -= penalty_points.get(violation.get("violation", "Unknown"), 0)
    return max(remaining, 0)
def generate_violation_pdf(violations, score):
    """Render the violation report to an in-memory PDF and persist a copy.

    Returns (local_path, public_url, BytesIO) on success, or ("", "", None)
    if anything goes wrong. The BytesIO copy is kept so the PDF can also be
    uploaded to Salesforce without re-reading it from disk.
    """
    try:
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        buffer = BytesIO()
        page = canvas.Canvas(buffer, pagesize=letter)

        # Title line.
        page.setFont("Helvetica", 12)
        page.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")

        # Summary section.
        page.setFont("Helvetica", 10)
        cursor = 9.5 * inch
        summary = {
            "Compliance Score": f"{score}%",
            "Violations Found": len(violations),
            "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        for label, value in summary.items():
            page.drawString(1 * inch, cursor, f"{label}: {value}")
            cursor -= 0.3 * inch

        # Details section.
        cursor -= 0.3 * inch
        page.drawString(1 * inch, cursor, "Violation Details:")
        cursor -= 0.3 * inch
        if not violations:
            page.drawString(1 * inch, cursor, "No violations detected.")
        else:
            for v in violations:
                display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
                line = (
                    f"{display_name} from {v.get('start_timestamp', 0.0):.2f}s "
                    f"to {v.get('end_timestamp', 0.0):.2f}s "
                    f"(Confidence: {v.get('confidence', 0.0):.2f}, "
                    f"Worker ID: {v.get('worker_id', 'N/A')})"
                )
                page.drawString(1 * inch, cursor, line)
                cursor -= 0.3 * inch
                # Start a new page before running off the bottom margin.
                if cursor < 1 * inch:
                    page.showPage()
                    page.setFont("Helvetica", 10)
                    cursor = 10 * inch

        page.showPage()
        page.save()
        buffer.seek(0)

        # Persist a copy so the public URL below can serve it.
        with open(pdf_path, "wb") as f:
            f.write(buffer.getvalue())
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, buffer
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None
# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Open a Salesforce session (retried up to 3 times, 2 s apart).

    A describe() call verifies the session is actually usable before it is
    handed back; failures are logged and re-raised so the retry decorator
    can kick in.
    """
    try:
        session = Salesforce(**CONFIG["SF_CREDENTIALS"])
        logger.info("Connected to Salesforce")
        session.describe()
        return session
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach the in-memory PDF to the given record as a ContentVersion.

    Returns a direct download URL for the uploaded file, or "" when the
    upload (or the follow-up lookup) fails.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""
        stamp = int(time.time())
        payload = {
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            # Salesforce expects the binary body as base64 text.
            "VersionData": base64.b64encode(pdf_file.getvalue()).decode('utf-8'),
            "FirstPublishLocationId": report_id,
        }
        created = sf.ContentVersion.create(payload)
        # Confirm the ContentVersion actually exists before advertising a URL.
        lookup = sf.query(
            f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{created['id']}'"
        )
        if not lookup['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{created['id']}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Safety_Video_Report__c record for this analysis run, then
    attach the PDF; falls back to a bare Account record when the custom
    object is unavailable.

    Args:
        violations: list of consolidated violation dicts.
        score: compliance score (0-100).
        pdf_path: local path of the generated PDF ("" if generation failed).
        pdf_file: BytesIO with the PDF bytes, or None.

    Returns:
        (record_id, pdf_url) on success, (None, "") when everything failed.
    """
    try:
        sf = connect_to_salesforce()
        # Human-readable, newline-separated summary for the long-text field.
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} from {v.get('start_timestamp', 0.0):.2f}s to {v.get('end_timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f}, Worker ID: {v.get('worker_id', 'N/A')})"
            for v in violations
        ) or "No violations detected."
        # Provisional public URL; replaced below if the Salesforce upload succeeds.
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            # The custom object may not exist in this org; fall back to a plain
            # Account so the run still produces a traceable record.
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")
        record_id = record["id"]
        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                try:
                    sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                    logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                except Exception as e:
                    # NOTE(review): when record_id belongs to the Account fallback,
                    # the URL ends up in Account.Description instead.
                    logger.error(f"Failed to update Safety_Video_Report__c: {e}")
                    sf.Account.update(record_id, {"Description": uploaded_url})
                    logger.info(f"Updated Account record {record_id} with PDF URL")
                pdf_url = uploaded_url
        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
        return None, ""
# ==========================
# Fast Video Processing
# ==========================
def process_video(video_data):
    """Analyse raw video bytes for safety violations (generator).

    Yields 5-tuples of (violations markdown, score text, snapshots text,
    Salesforce record id text, PDF URL) — throttled progress tuples first,
    then one final result tuple.

    Fixes over the previous revision:
      * tracked objects are plain dicts, so fields are read with
        obj['id']/obj['cls'] instead of attribute access (which raised
        AttributeError as soon as anything was tracked);
      * the temp video file is deleted only AFTER violation snapshots have
        been captured (it used to be removed first, so every snapshot read
        failed), and cleanup now happens in `finally` on error paths too;
      * total processing time is measured with a dedicated timer instead of
        the progress-throttle timestamp.
    """
    video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
    try:
        # Persist the upload so OpenCV can seek in it.
        with open(video_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Could not open video file")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        duration = total_frames / fps
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")

        tracker = BYTETracker(
            track_thresh=CONFIG["TRACK_THRESH"],
            track_buffer=CONFIG["TRACK_BUFFER"],
            match_thresh=CONFIG["MATCH_THRESH"],
            frame_rate=fps
        )

        violation_tracker = {}   # {worker_id: {violation_type: [detections]}}
        overall_start = time.time()
        last_progress = overall_start
        frame_skip = CONFIG["FRAME_SKIP"]

        while True:
            # ---- collect one batch of frames ----
            batch_frames, batch_indices = [], []
            for _ in range(CONFIG["BATCH_SIZE"]):
                frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                if frame_idx >= total_frames:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                # Optionally skip frames to speed up processing (FRAME_SKIP=1 → none).
                for _ in range(frame_skip - 1):
                    if not cap.grab():
                        break
                batch_frames.append(frame)
                batch_indices.append(frame_idx)
            if not batch_frames:
                break

            # ---- batched inference (low conf here; per-class filtering below) ----
            results = model(batch_frames, device=device, conf=0.1, verbose=False)

            for result, frame_idx in zip(results, batch_indices):
                current_time = frame_idx / fps
                # Throttled progress updates (at most ~1/second).
                if time.time() - last_progress > 1.0:
                    progress = (frame_idx / total_frames) * 100
                    yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
                    last_progress = time.time()

                # Keep only detections that clear their per-class threshold.
                track_inputs = []
                for box in result.boxes:
                    cls = int(box.cls)
                    conf = float(box.conf)
                    label = CONFIG["VIOLATION_LABELS"].get(cls)
                    if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
                        continue
                    track_inputs.append({
                        "bbox": box.xywh.cpu().numpy()[0],  # [cx, cy, w, h]
                        "conf": conf,
                        "cls": cls
                    })

                tracked_objects = tracker.update(
                    np.array([t["bbox"] for t in track_inputs]),
                    np.array([t["conf"] for t in track_inputs]),
                    np.array([t["cls"] for t in track_inputs])
                )

                # The tracker returns plain dicts carrying score/cls — index
                # them directly rather than zipping back to track_inputs.
                for obj in tracked_objects:
                    worker_id = obj["id"]
                    label = CONFIG["VIOLATION_LABELS"].get(int(obj["cls"]))
                    detection = {
                        "frame": frame_idx,
                        "violation": label,
                        "confidence": round(float(obj["score"]), 2),
                        "bounding_box": [round(float(x), 2) for x in obj["bbox"]],
                        "timestamp": current_time,
                        "worker_id": worker_id
                    }
                    violation_tracker.setdefault(worker_id, {}).setdefault(label, []).append(detection)

        cap.release()
        processing_time = time.time() - overall_start
        logger.info(f"Processing complete in {processing_time:.2f}s")

        # ---- consolidate per-worker violations and grab snapshots ----
        violations, snapshots = [], []
        snap_cap = cv2.VideoCapture(video_path)  # temp file still exists here
        for worker_id, worker_violations in violation_tracker.items():
            for label, detections in worker_violations.items():
                if len(detections) < CONFIG["MIN_VIOLATION_FRAMES"]:
                    continue  # not seen on enough frames to be confirmed
                best_detection = max(detections, key=lambda d: d["confidence"])
                best_detection["start_timestamp"] = min(d["timestamp"] for d in detections)
                best_detection["end_timestamp"] = max(d["timestamp"] for d in detections)
                violations.append(best_detection)
                # Capture an annotated snapshot of the highest-confidence frame.
                snap_cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
                ret, snapshot_frame = snap_cap.read()
                if ret:
                    snapshot_frame = draw_detections(snapshot_frame, [best_detection])
                    snapshot_filename = f"{label}_{best_detection['frame']}.jpg"
                    snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
                    cv2.imwrite(snapshot_path, snapshot_frame)
                    snapshots.append({
                        "violation": label,
                        "frame": best_detection["frame"],
                        "snapshot_path": snapshot_path,
                        "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
                    })
        snap_cap.release()

        if not violations:
            yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
            return

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)

        violation_table = "| Violation | Time Range (s) | Confidence | Worker ID |\n"
        violation_table += "|------------------------|----------------|------------|-----------|\n"
        for v in sorted(violations, key=lambda x: x["start_timestamp"]):
            display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
            violation_table += (
                f"| {display_name:<22} | {v.get('start_timestamp', 0.0):.2f}-{v.get('end_timestamp', 0.0):.2f}"
                f" | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
            )

        snapshots_text = "\n".join(
            f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
            for s in snapshots
        ) if snapshots else "No snapshots captured."

        yield (
            violation_table,
            f"Safety Score: {score}%",
            snapshots_text,
            f"Salesforce Record ID: {report_id or 'N/A'}",
            final_pdf_url or "N/A"
        )
    except Exception as e:
        logger.error(f"Error processing video: {e}", exc_info=True)
        yield f"Error processing video: {e}", "", "", "", ""
    finally:
        # Always clean up the temp file, including on error paths.
        if os.path.exists(video_path):
            os.remove(video_path)
# NOTE: a duplicated copy of the device/model initialisation and a set of
# comment-only stub redefinitions (draw_detections, calculate_safety_score,
# generate_violation_pdf, connect_to_salesforce, upload_pdf_to_salesforce,
# push_report_to_salesforce, process_video) used to live here. The stubs had
# empty bodies — a SyntaxError in Python — and, had they parsed, would have
# shadowed the real implementations defined above; the duplicated
# load_model() call also loaded the YOLO weights a second time. All of it has
# been removed: the definitions earlier in this file are the single source of
# truth.
def gradio_interface(video_file):
    """Gradio entry point: stream processing updates for an uploaded video.

    This is a generator, so the no-file case must *yield* its message — the
    previous revision used `return` with a value, which a generator swallows
    (Gradio displayed nothing at all when no file was uploaded).
    """
    if not video_file:
        yield "No file uploaded.", "", "No file uploaded.", "", ""
        return
    try:
        with open(video_file, "rb") as f:
            video_data = f.read()
        # Forward every progress/result tuple from the processing generator.
        yield from process_video(video_data)
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", ""
# ==========================
# Gradio Interface
# ==========================
# NOTE(review): `allow_flagging` is deprecated in newer Gradio releases in
# favour of `flagging_mode`; confirm against the pinned Gradio version.
interface = gr.Interface(
    fn=gradio_interface,  # generator: streams progress tuples, then the final result
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),  # markdown table
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),  # markdown image links to public snapshot URLs
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL")
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
    allow_flagging="never"
)
if __name__ == "__main__":
    logger.info("Launching Enhanced Safety Analyzer App...")
    interface.launch()