AI_Safety_Demo7 / app.py
PrashanthB461's picture
Update app.py
4c4b661 verified
raw
history blame
32 kB
import os
import sys
import subprocess
import logging
import warnings
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
from retrying import retry
import uuid
from multiprocessing import Pool, cpu_count
from functools import partial
import tempfile
import shutil
# ========================== # Configuration and Setup # ==========================
# Use a temporary directory for storage to avoid file system issues on Hugging Face Spaces
TEMP_DIR = tempfile.mkdtemp(prefix="Ultralytics_")
# Point Ultralytics at the writable temp dir for its settings/config files
os.environ['YOLO_CONFIG_DIR'] = TEMP_DIR
# Ensure output directory exists within temp directory
# (snapshots and PDF reports produced during processing land here)
OUTPUT_DIR = os.path.join(TEMP_DIR, "output")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Configure logging for better debugging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Check for FFmpeg availability to ensure video processing works
def check_ffmpeg():
    """Return True when the ffmpeg binary runs successfully, False otherwise."""
    try:
        subprocess.run(
            ["ffmpeg", "-version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Binary missing or broken: warn loudly, but let the app start anyway.
        logger.error("FFmpeg is not installed or not found in PATH. Video processing may fail.")
        return False
    logger.info("FFmpeg is available.")
    return True
FFMPEG_AVAILABLE = check_ffmpeg()
# ========================== # ByteTrack Implementation # ==========================
class BYTETracker:
    """Simplified ByteTrack-style multi-object tracker.

    New detections are matched to live tracks greedily by IoU; unmatched
    detections are re-attached to a previously seen worker when they fall
    within a pixel-distance threshold of that worker's last known position,
    otherwise they receive a fresh id.

    NOTE(review): track staleness is measured in wall-clock seconds
    (``time.time()``) against ``track_buffer / frame_rate`` — with frame
    skipping the effective buffer differs from ``track_buffer`` video
    frames; confirm this is intended.
    """
    def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
        self.track_thresh = track_thresh  # minimum detection confidence considered at all
        self.track_buffer = track_buffer  # frames a track may go unseen before it is stale
        self.match_thresh = match_thresh  # minimum IoU to match a detection to a track
        self.frame_rate = frame_rate
        self.next_id = 1  # next fresh worker id to hand out
        self.tracks = {} # Store active tracks
        self.worker_history = {} # Track worker positions over time
        self.last_positions = {} # Last known positions of workers
    def update(self, dets, scores, cls):
        """Ingest one frame's detections and return their track assignments.

        Args:
            dets: iterable of centre-format boxes (x, y, w, h).
            scores: per-detection confidences (same order as dets).
            cls: per-detection class ids (same order as dets).
        Returns:
            List of dicts with keys 'id', 'bbox', 'score', 'cls' — one per
            detection that passed the confidence threshold.
        """
        tracks = []
        current_time = time.time()
        # Update existing tracks with new detections
        for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
            if score < self.track_thresh:
                continue
            x, y, w, h = det
            matched = False
            best_iou = 0
            best_track_id = None
            # Try to match with existing tracks (greedy: best IoU above threshold wins)
            for track_id, track_info in self.tracks.items():
                # Skip tracks that have already gone stale
                if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
                    continue
                tx, ty, tw, th = track_info['bbox']
                iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
                if iou > self.match_thresh and iou > best_iou:
                    best_iou = iou
                    best_track_id = track_id
                    matched = True
            if matched:
                # Update existing track
                self.tracks[best_track_id].update({
                    'bbox': [x, y, w, h],
                    'score': score,
                    'cls': cl,
                    'last_seen': current_time
                })
                # Update position history
                if best_track_id not in self.worker_history:
                    self.worker_history[best_track_id] = []
                self.worker_history[best_track_id].append([x, y])
                self.last_positions[best_track_id] = [x, y]
                tracks.append({
                    'id': best_track_id,
                    'bbox': [x, y, w, h],
                    'score': score,
                    'cls': cl
                })
            else:
                # Create new track — but first check whether this detection sits
                # close to the last known position of a known worker; if so,
                # revive that worker's id instead of minting a new one.
                same_worker = False
                for worker_id, last_pos in self.last_positions.items():
                    if self._is_same_worker([x, y], last_pos):
                        self.tracks[worker_id] = {
                            'bbox': [x, y, w, h],
                            'score': score,
                            'cls': cl,
                            'last_seen': current_time
                        }
                        tracks.append({
                            'id': worker_id,
                            'bbox': [x, y, w, h],
                            'score': score,
                            'cls': cl
                        })
                        same_worker = True
                        break
                if not same_worker:
                    # Genuinely new worker: assign a fresh id
                    self.tracks[self.next_id] = {
                        'bbox': [x, y, w, h],
                        'score': score,
                        'cls': cl,
                        'last_seen': current_time
                    }
                    self.worker_history[self.next_id] = [[x, y]]
                    self.last_positions[self.next_id] = [x, y]
                    tracks.append({
                        'id': self.next_id,
                        'bbox': [x, y, w, h],
                        'score': score,
                        'cls': cl
                    })
                    self.next_id += 1
        # Clean up old tracks (worker_history / last_positions entries too)
        current_time = time.time()
        stale_ids = []
        for track_id, track_info in self.tracks.items():
            if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
                stale_ids.append(track_id)
        for track_id in stale_ids:
            del self.tracks[track_id]
            if track_id in self.worker_history:
                del self.worker_history[track_id]
            if track_id in self.last_positions:
                del self.last_positions[track_id]
        return tracks
    def _calculate_iou(self, box1, box2):
        """Calculate IOU between two centre-format (x, y, w, h) boxes."""
        x1, y1, w1, h1 = box1
        x2, y2, w2, h2 = box2
        # Calculate intersection coordinates (convert centre format to corners)
        x_left = max(x1 - w1/2, x2 - w2/2)
        y_top = max(y1 - h1/2, y2 - h2/2)
        x_right = min(x1 + w1/2, x2 + w2/2)
        y_bottom = min(y1 + h1/2, y2 + h2/2)
        if x_right < x_left or y_bottom < y_top:
            return 0.0
        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        box1_area = w1 * h1
        box2_area = w2 * h2
        iou = intersection_area / (box1_area + box2_area - intersection_area)
        return iou
    def _is_same_worker(self, pos1, pos2, threshold=100):
        """Check if two positions likely belong to the same worker
        (Euclidean distance below *threshold* pixels)."""
        x1, y1 = pos1
        x2, y2 = pos2
        distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
        return distance < threshold
# ========================== # Optimized Configuration # ==========================
CONFIG = {
    "MODEL_PATH": "yolov8_safety.pt",        # custom-trained safety model, if present
    "FALLBACK_MODEL": "yolov8n.pt",          # stock checkpoint used when the custom model is absent
    "OUTPUT_DIR": OUTPUT_DIR,
    # Model class id -> internal violation label (must match the trained model's classes)
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use"
    },
    # BGR colours used when drawing bounding boxes
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255), # Red
        "no_harness": (0, 165, 255), # Orange
        "unsafe_posture": (0, 255, 0), # Green
        "unsafe_zone": (255, 0, 0), # Blue
        "improper_tool_use": (255, 255, 0) # Cyan
    },
    # Human-readable names for reports and UI
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture",
        "unsafe_zone": "Unsafe Zone Entry",
        "improper_tool_use": "Improper Tool Use"
    },
    # SECURITY FIX: the previous version shipped a live username, password and
    # security token as hard-coded getenv() fallbacks — secrets committed to
    # source control. They have been removed; set SF_USERNAME / SF_PASSWORD /
    # SF_SECURITY_TOKEN in the environment (e.g. HF Space secrets). The old
    # credentials should be considered compromised and rotated.
    "SF_CREDENTIALS": {
        "username": os.getenv("SF_USERNAME", ""),
        "password": os.getenv("SF_PASSWORD", ""),
        "security_token": os.getenv("SF_SECURITY_TOKEN", ""),
        "domain": "login"
    },
    # Base URL under which files written to OUTPUT_DIR are assumed to be served
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    # Per-class confidence thresholds applied after inference
    "CONFIDENCE_THRESHOLDS": {
        "no_helmet": 0.5,
        "no_harness": 0.3,
        "unsafe_posture": 0.3,
        "unsafe_zone": 0.3,
        "improper_tool_use": 0.3
    },
    "MIN_VIOLATION_FRAMES": 1,
    "VIOLATION_COOLDOWN": 30.0,
    "WORKER_TRACKING_DURATION": 5.0,
    "MAX_PROCESSING_TIME": 60,
    "FRAME_SKIP": 2,                         # process every FRAME_SKIP-th frame
    "BATCH_SIZE": 8, # Reduced batch size to lower memory usage
    "PARALLEL_WORKERS": max(1, cpu_count() - 1),
    "TRACK_BUFFER": 30,
    "TRACK_THRESH": 0.3,
    "MATCH_THRESH": 0.7,
    "SNAPSHOT_QUALITY": 95,                  # JPEG quality for violation snapshots
    "MAX_WORKER_DISTANCE": 100
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
def load_model():
    """Build the YOLO detector.

    Prefers the custom-trained safety model; otherwise falls back to the
    stock yolov8n checkpoint, downloading it on first use. The model is
    moved to the selected ``device`` before being returned.
    """
    try:
        custom_path = CONFIG["MODEL_PATH"]
        if os.path.isfile(custom_path):
            model_path = custom_path
            logger.info(f"Model loaded: {model_path}")
        else:
            model_path = CONFIG["FALLBACK_MODEL"]
            logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
        if not os.path.isfile(model_path):
            # Fetch the stock checkpoint from the Ultralytics release assets.
            logger.info(f"Downloading fallback model: {model_path}")
            torch.hub.download_url_to_file(
                'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt',
                model_path,
            )
        detector = YOLO(model_path).to(device)
        logger.info(f"Model classes: {detector.names}")
        return detector
    except Exception as exc:
        logger.error(f"Failed to load model: {exc}")
        raise
model = load_model()
# ========================== # Helper Functions # ==========================
def preprocess_frame(frame):
    """Lightly boost contrast/brightness before detection.

    convertScaleAbs computes ``frame * alpha + beta`` and clips to uint8.
    """
    return cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
def draw_detections(frame, detections):
    """Return a copy of *frame* annotated with one labelled box per detection.

    Each detection dict may carry 'violation', 'confidence', 'bounding_box'
    (centre-format x, y, w, h) and 'worker_id'; missing keys fall back to
    neutral defaults.
    """
    annotated = frame.copy()
    for det in detections:
        label = det.get("violation", "Unknown")
        confidence = det.get("confidence", 0.0)
        cx, cy, bw, bh = det.get("bounding_box", [0, 0, 0, 0])
        worker_id = det.get("worker_id", "Unknown")
        # Convert centre-format box to corner coordinates for OpenCV
        left, top = int(cx - bw / 2), int(cy - bh / 2)
        right, bottom = int(cx + bw / 2), int(cy + bh / 2)
        color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
        cv2.rectangle(annotated, (left, top), (right, bottom), color, 3)
        display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
        # Black backing rectangle so the white label text stays readable
        text_w, text_h = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
        cv2.rectangle(annotated, (left, top - text_h - 10), (left + text_w + 10, top), (0, 0, 0), -1)
        cv2.putText(annotated, display_text, (left + 5, top - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        # Confidence rendered just below the box
        cv2.putText(annotated, f"Conf: {confidence:.2f}", (left + 5, bottom + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
    return annotated
def calculate_safety_score(violations):
    """Compute a 0-100 compliance score from detected violations.

    Each unique (worker, violation-type) pair incurs a fixed penalty;
    duplicate reports of the same type for the same worker count once.
    The score is 100 minus the summed penalties, floored at 0.
    """
    penalties = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25
    }
    # Deduplicate: worker id -> set of violation types seen for that worker
    types_per_worker = {}
    for record in violations:
        worker = record.get("worker_id", "Unknown")
        types_per_worker.setdefault(worker, set()).add(record.get("violation", "Unknown"))
    total_penalty = sum(
        penalties.get(kind, 0)
        for kinds in types_per_worker.values()
        for kind in kinds
    )
    return max(0, 100 - total_penalty)
def generate_violation_pdf(violations, score):
    """Generate a PDF report for the detected violations.

    Renders the report into an in-memory buffer, persists a copy under
    CONFIG["OUTPUT_DIR"], and returns (local_path, public_url, BytesIO
    buffer). Returns ("", "", None) on any failure.
    """
    try:
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        # Draw into memory first so the same bytes can also be uploaded to Salesforce
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        # --- Header ---
        c.setFont("Helvetica-Bold", 16)
        c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 12)
        c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
        c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
        c.setFont("Helvetica-Bold", 14)
        c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
        # y_position walks down the page as lines are drawn
        y_position = 8.2 * inch
        c.setFont("Helvetica-Bold", 12)
        c.drawString(1 * inch, y_position, "Summary:")
        y_position -= 0.3 * inch
        # Group violations per worker for the per-worker section below
        worker_violations = {}
        for v in violations:
            worker_id = v.get("worker_id", "Unknown")
            if worker_id not in worker_violations:
                worker_violations[worker_id] = []
            worker_violations[worker_id].append(v)
        c.setFont("Helvetica", 10)
        summary_data = {
            "Total Workers with Violations": len(worker_violations),
            "Total Violations Found": len(violations),
            "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        for key, value in summary_data.items():
            c.drawString(1 * inch, y_position, f"{key}: {value}")
            y_position -= 0.25 * inch
        y_position -= 0.5 * inch
        # --- Per-worker violation listing ---
        c.setFont("Helvetica-Bold", 12)
        c.drawString(1 * inch, y_position, "Violations by Worker:")
        y_position -= 0.3 * inch
        c.setFont("Helvetica", 10)
        for worker_id, worker_vios in worker_violations.items():
            c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
            y_position -= 0.2 * inch
            for v in worker_vios:
                display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
                time_str = f"{v.get('timestamp', 0.0):.2f}s"
                conf_str = f"{v.get('confidence', 0.0):.2f}"
                violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
                c.drawString(1.2 * inch, y_position, violation_text)
                y_position -= 0.2 * inch
                # Start a new page when we run out of vertical space
                if y_position < 1 * inch:
                    c.showPage()
                    c.setFont("Helvetica", 10)
                    y_position = 10 * inch
        c.save()
        pdf_file.seek(0)
        # Persist a copy to disk so it can be served via PUBLIC_URL_BASE
        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Open a Salesforce session (retried up to 3 times, 2s apart).

    A ``describe()`` call validates that the session actually works before
    the connection is handed back; any failure is logged and re-raised so
    the retry decorator can kick in.
    """
    try:
        connection = Salesforce(**CONFIG["SF_CREDENTIALS"])
        logger.info("Connected to Salesforce")
        connection.describe()  # sanity-check the session
        return connection
    except Exception as exc:
        logger.error(f"Salesforce connection failed: {exc}")
        raise
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach the in-memory PDF to a Salesforce record as a ContentVersion.

    Returns a direct-download URL for the uploaded file, or "" when the
    upload (or its verification query) fails.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""
        # ContentVersion requires the file body as base64 text
        payload = {
            "Title": f"Safety_Violation_Report_{int(time.time())}",
            "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
            "VersionData": base64.b64encode(pdf_file.getvalue()).decode('utf-8'),
            "FirstPublishLocationId": report_id
        }
        content_version = sf.ContentVersion.create(payload)
        # Verify the version landed before advertising a download URL
        lookup = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
        if not lookup['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as exc:
        logger.error(f"Error uploading PDF to Salesforce: {exc}")
        return ""
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Push violation report to Salesforce.

    Creates a Safety_Video_Report__c record (falling back to a bare Account
    record if that custom object is unavailable), then uploads the PDF and
    stores its URL on the record. Returns (record_id, pdf_url); on total
    failure returns ("N/A", "Salesforce integration failed.").
    """
    try:
        sf = connect_to_salesforce()
        # Build a plain-text summary of all violations for the record body
        violations_text = ""
        for v in violations:
            display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
            worker_id = v.get('worker_id', 'Unknown')
            timestamp = v.get('timestamp', 0.0)
            confidence = v.get('confidence', 0.0)
            violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
        if not violations_text:
            violations_text = "No violations detected."
        # Provisional URL based on the public output path; replaced below if
        # the direct Salesforce upload succeeds.
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            # Custom object missing or fields mismatched: fall back to a
            # generic Account record so the run still leaves a trace.
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")
        record_id = record["id"]
        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                try:
                    sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                    logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                except Exception as e:
                    # Mirror of the creation fallback: stash the URL on the Account
                    logger.error(f"Failed to update Safety_Video_Report__c: {e}")
                    sf.Account.update(record_id, {"Description": uploaded_url})
                    logger.info(f"Updated Account record {record_id} with PDF URL")
                pdf_url = uploaded_url
        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}")
        return "N/A", "Salesforce integration failed."
def process_video(video_data):
    """Detect safety violations in a video and stream progress and results.

    Generator yielding 5-tuples for the Gradio UI:
    (status / violation table, score text, snapshots markdown,
     Salesforce record id text, PDF/details URL).

    Args:
        video_data: raw video file bytes (MP4 expected).

    BUG FIXES vs. the previous version:
    * One ``start_time`` was reused both to throttle progress updates
      (reset every second) and to report total processing time, so the
      final "Processing complete in Xs" log was always ~0-1s. The two
      timers are now separate.
    * ``finally`` deleted the module-wide TEMP_DIR, which removed the
      Ultralytics config dir, the OUTPUT_DIR holding the snapshots/PDF
      just produced (still linked from the UI), and broke every
      subsequent upload in the same process. Only the per-video temp
      file is cleaned up now; TEMP_DIR must outlive this call.
    """
    video_path = None
    try:
        # Validate video data
        if not video_data:
            raise ValueError("Empty video data provided.")
        # Save video to a temporary file so OpenCV can read it
        video_fd, video_path = tempfile.mkstemp(suffix=".mp4", dir=TEMP_DIR)
        with os.fdopen(video_fd, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")
        # Open video with OpenCV
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Could not open video file. Ensure the video format is supported (e.g., MP4) and FFmpeg is installed.")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        duration = total_frames / fps
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
        if total_frames <= 0:
            raise ValueError("Video has no frames.")
        tracker = BYTETracker(
            track_thresh=CONFIG["TRACK_THRESH"],
            track_buffer=CONFIG["TRACK_BUFFER"],
            match_thresh=CONFIG["MATCH_THRESH"],
            frame_rate=fps
        )
        unique_violations = {}  # worker_id -> {violation label: first-seen timestamp}
        snapshots = []
        overall_start = time.time()   # total processing duration (FIX: dedicated timer)
        last_progress = time.time()   # throttle for progress yields (max one per second)
        frame_skip = CONFIG["FRAME_SKIP"]
        processed_frames = 0
        while processed_frames < total_frames:
            # Collect a batch of frames, skipping FRAME_SKIP-1 frames between reads
            batch_frames = []
            batch_indices = []
            for _ in range(CONFIG["BATCH_SIZE"]):
                frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                if frame_idx >= total_frames:
                    break
                ret, frame = cap.read()
                if not ret:
                    logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
                    break
                frame = preprocess_frame(frame)
                for _ in range(frame_skip - 1):
                    if not cap.grab():
                        break
                batch_frames.append(frame)
                batch_indices.append(frame_idx)
                processed_frames += 1
            if not batch_frames:
                logger.info("No more frames to process.")
                break
            # Run the batch through YOLO (low conf here; per-class thresholds applied below)
            try:
                results = model(batch_frames, device=device, conf=0.1, verbose=False)
            except Exception as e:
                logger.error(f"Model inference failed: {e}")
                raise ValueError(f"Failed to process video frames with YOLO model: {str(e)}")
            for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
                current_time = frame_idx / fps  # video timestamp of this frame
                # Push a progress update at most once per wall-clock second
                if time.time() - last_progress > 1.0:
                    progress = (processed_frames / total_frames) * 100
                    yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames})", "", "", "", ""
                    last_progress = time.time()
                boxes = result.boxes
                # Keep only detections that map to a known violation class and
                # clear that class's confidence threshold.
                track_inputs = []
                for box in boxes:
                    cls = int(box.cls)
                    conf = float(box.conf)
                    label = CONFIG["VIOLATION_LABELS"].get(cls, None)
                    if label is None:
                        continue
                    if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
                        continue
                    bbox = box.xywh.cpu().numpy()[0]
                    track_inputs.append({
                        "bbox": bbox,
                        "conf": conf,
                        "cls": cls
                    })
                if not track_inputs:
                    continue
                tracked_objects = tracker.update(
                    np.array([t["bbox"] for t in track_inputs]),
                    np.array([t["conf"] for t in track_inputs]),
                    np.array([t["cls"] for t in track_inputs])
                )
                for obj in tracked_objects:
                    worker_id = obj['id']
                    label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
                    conf = obj['score']
                    bbox = obj['bbox']
                    if label is None:
                        continue
                    if worker_id not in unique_violations:
                        unique_violations[worker_id] = {}
                    # Record and snapshot each violation type only ONCE per worker
                    if label not in unique_violations[worker_id]:
                        unique_violations[worker_id][label] = current_time
                        detection = {
                            "worker_id": worker_id,
                            "violation": label,
                            "confidence": round(float(conf), 2),  # ensure plain float
                            "bounding_box": bbox,
                            "timestamp": current_time
                        }
                        # Annotated snapshot with box, label and timestamp overlay
                        snapshot_frame = batch_frames[i].copy()
                        snapshot_frame = draw_detections(snapshot_frame, [detection])
                        cv2.putText(
                            snapshot_frame,
                            f"Time: {current_time:.2f}s",
                            (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.7,
                            (255, 255, 255),
                            2
                        )
                        snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
                        snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
                        cv2.imwrite(
                            snapshot_path,
                            snapshot_frame,
                            [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
                        )
                        snapshots.append({
                            "violation": label,
                            "worker_id": worker_id,
                            "timestamp": current_time,
                            "snapshot_path": snapshot_path,
                            "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
                            "confidence": round(float(conf), 2)  # ensure plain float
                        })
                        logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
        # Release the capture and remove the per-video temp file
        cap.release()
        if os.path.exists(video_path):
            os.remove(video_path)
        processing_time = time.time() - overall_start  # FIX: real total duration
        logger.info(f"Processing complete in {processing_time:.2f}s")
        logger.info(f"Snapshots: {snapshots}")
        # Flatten unique_violations into a list of violation records
        violations = []
        for worker_id, worker_violations in unique_violations.items():
            for label, detection_time in worker_violations.items():
                # Pull the confidence recorded with the matching snapshot
                confidence = next(
                    (float(s["confidence"]) for s in snapshots if s["worker_id"] == worker_id and s["violation"] == label),
                    0.0
                )
                violations.append({
                    "worker_id": worker_id,
                    "violation": label,
                    "timestamp": detection_time,
                    "confidence": confidence
                })
        logger.info(f"Violations: {violations}")
        if not violations:
            logger.info("No violations detected after processing")
            yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
            return
        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        # Push to Salesforce (has its own Account fallback)
        record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        # Markdown table of violations, sorted by worker then timestamp
        violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
        violation_table += "|-----------|-----------|----------|------------|\n"
        for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
            display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
            worker_id = v.get("worker_id", "Unknown")
            timestamp = v.get("timestamp", 0.0)
            # Guard against malformed confidence values
            try:
                confidence = float(v.get("confidence", 0.0))
            except (ValueError, TypeError) as e:
                logger.error(f"Invalid confidence value in violation {v}: {e}")
                confidence = 0.0
            violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
        # Markdown gallery of snapshot images (served via PUBLIC_URL_BASE)
        snapshots_text = ""
        for s in snapshots:
            display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
            worker_id = s.get("worker_id", "Unknown")
            timestamp = s.get("timestamp", 0.0)
            snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
            snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
        if not snapshots_text:
            snapshots_text = "No snapshots captured."
        yield (
            violation_table,
            f"Safety Score: {score}%",
            snapshots_text,
            f"Salesforce Record ID: {record_id}",
            final_pdf_url
        )
    except Exception as e:
        logger.error(f"Error processing video: {str(e)}", exc_info=True)
        # Best-effort removal of the per-video temp file only; the shared
        # TEMP_DIR (YOLO config, OUTPUT_DIR with served snapshots/PDFs)
        # must NOT be deleted here — it is reused by later requests.
        if video_path and os.path.exists(video_path):
            os.remove(video_path)
        yield f"Error processing video: {str(e)}", "", "", "", ""
def gradio_interface(video_file):
    """Gradio entry point: stream (status, score, snapshots, record id, url) tuples.

    BUG FIX: this function contains ``yield`` and is therefore a generator —
    the previous ``return <tuple>`` early exits put the message into
    StopIteration instead of emitting it, so the "No file uploaded." and
    FFmpeg-missing messages never reached the UI. Early exits must yield
    the tuple first, then return.
    """
    if not video_file:
        yield "No file uploaded.", "", "No file uploaded.", "", ""
        return
    try:
        with open(video_file, "rb") as f:
            video_data = f.read()
        # Validate FFmpeg availability before attempting any processing
        if not FFMPEG_AVAILABLE:
            yield "FFmpeg is not available in the environment. Please install FFmpeg to process videos.", "", "", "", ""
            return
        # Relay every progress/result tuple from the processing generator
        for status, score, snapshots_text, record_id, details_url in process_video(video_data):
            yield status, score, snapshots_text, record_id, details_url
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", ""
# ========================== # Gradio Interface # ==========================
# gradio_interface is a generator, so Gradio streams updates into the five
# outputs as each status tuple is yielded during processing.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL")
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
    allow_flagging="never"  # NOTE(review): deprecated in newer Gradio releases — confirm installed version
)
if __name__ == "__main__":
    logger.info("Launching Enhanced Safety Analyzer App...")
    interface.launch()