PrashanthB461's picture
Update app.py
1aa6cf2 verified
import os
import sys
import warnings
import gradio as gr
import torch
from ultralytics import YOLO
import cv2
import requests
import json
import time
import numpy as np
from pathlib import Path
from datetime import datetime
import logging
import pandas as pd
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from io import BytesIO
import seaborn as sns
import matplotlib.pyplot as plt
import subprocess
from datetime import timezone
import pytz
import shutil
import tempfile
from scipy.spatial import distance
import asyncio
from functools import partial
from concurrent.futures import ThreadPoolExecutor
from simple_salesforce import Salesforce, SalesforceAuthenticationFailed
from retrying import retry
import base64
# --- Initial Configuration ---
# Silence every Python warning for the whole process.
warnings.filterwarnings("ignore")
# Local file the YOLOv8-nano weights are loaded from (downloaded below if absent).
MODEL_PATH = "./yolov8n.pt"
# Logging: INFO and above, timestamped, emitted through one stream handler.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Download model weights if needed
if not os.path.exists(MODEL_PATH):
    logger.info(f"Downloading model weights to {MODEL_PATH}...")
    # Download to a temporary name and move it into place only on success.
    # `wget -O FILE` creates/truncates FILE before fetching, so a failed
    # download would otherwise leave a partial MODEL_PATH behind. The
    # existence check above would then skip the download on every later start.
    _partial_model_path = f"{MODEL_PATH}.part"
    try:
        download_url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt"
        subprocess.run(["wget", download_url, "-O", _partial_model_path], check=True)
        os.replace(_partial_model_path, MODEL_PATH)
        os.chmod(MODEL_PATH, 0o644)
        logger.info("Model weights downloaded successfully")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError also covers `wget` not being installed (FileNotFoundError).
        logger.error(f"Failed to download model: {e}")
        try:
            os.remove(_partial_model_path)
        except OSError:
            pass  # nothing was written, or it is already gone
        sys.exit(1)
# Set up YOLO config directory
# NOTE(review): `from ultralytics import YOLO` runs earlier in this file, and
# Ultralytics appears to read YOLO_CONFIG_DIR at import time. Setting it here
# may therefore not change where its settings are stored. Confirm; it likely
# needs to be set before that import.
yolo_config_dir = "./Ultralytics"
os.environ["YOLO_CONFIG_DIR"] = yolo_config_dir
os.makedirs(yolo_config_dir, exist_ok=True)
# --- Environment Variables ---
RTSP_URL_DEFAULT = os.getenv("RTSP_URL", "")
SALESFORCE_URL = os.getenv("SALESFORCE_URL", "")
SALESFORCE_TOKEN = os.getenv("SALESFORCE_TOKEN", "")
HUGGINGFACE_API_URL = os.getenv("HUGGINGFACE_API_URL", "")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "")
# Salesforce login credentials must come from the environment (e.g. deployment
# secrets), never from source code. The values that used to be hard-coded here
# were committed to source control; treat them as compromised and rotate them.
SF_USERNAME = os.getenv("SF_USERNAME", "")
SF_PASSWORD = os.getenv("SF_PASSWORD", "")
SF_SECURITY_TOKEN = os.getenv("SF_SECURITY_TOKEN", "")
# The Salesforce integration is active only when all three credentials are set.
SALESFORCE_ENABLED = all([SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN])
# --- Time Zone Configuration ---
# Timestamps produced with datetime.now(IST) are in India Standard Time.
IST = pytz.timezone("Asia/Kolkata")
# --- Global Variables ---
# Mutable module-level state shared by the functions in this file.
yolo_model = None          # populated by initialize_model()
sf_connection = None       # cached Salesforce session (get_salesforce_connection)
processing_active = False  # processing loops stop once this becomes False
recent_violations = []     # last 10 logged violations (rebuilt by log_violation)
violation_history = []     # all logged violations (rebuilt by log_violation)
site_id_cache = {}         # Site__c Name -> Salesforce record Id
# --- Model Initialization ---
def initialize_model():
    """Load the YOLO weights at MODEL_PATH into the global ``yolo_model``.

    Returns:
        True if the model loaded; False if loading raised (the error is logged).
    """
    global yolo_model
    try:
        logger.info("Initializing YOLOv8 model...")
        yolo_model = YOLO(MODEL_PATH)
        logger.info("YOLOv8 model loaded successfully")
    except Exception as err:
        logger.error(f"Failed to initialize model: {err}")
        return False
    return True


# Nothing in the app works without the detector, so abort start-up on failure.
if not initialize_model():
    logger.error("Critical error: Model initialization failed")
    sys.exit(1)
# --- Salesforce Integration Functions ---
def _should_retry_salesforce(exc):
    """Return True if a Salesforce login error is worth retrying.

    Authentication failures are not transient (bad credentials or IP
    restrictions). Retrying them only delays the error and adds failed-login
    attempts against the account.
    """
    return not isinstance(exc, SalesforceAuthenticationFailed)


@retry(stop_max_attempt_number=3, wait_fixed=2000,
       retry_on_exception=_should_retry_salesforce)
def _open_salesforce_session():
    """Log in to Salesforce, retrying transient failures (3 attempts, 2 s apart)."""
    try:
        session = Salesforce(username=SF_USERNAME, password=SF_PASSWORD,
                             security_token=SF_SECURITY_TOKEN)
    except SalesforceAuthenticationFailed as e:
        logger.error(f"Salesforce authentication failed: {e}. Check credentials and IP restrictions.")
        raise
    except Exception as e:
        logger.error(f"Failed to connect to Salesforce: {e}")
        raise
    logger.info(f"Successfully connected to Salesforce instance: {session.sf_instance}")
    return session


def get_salesforce_connection():
    """Return a live Salesforce connection, reusing the cached one when healthy.

    A cached connection is health-checked with a cheap query. If that query
    fails, the connection is discarded and a new login is performed.

    Returns:
        A connected ``Salesforce`` client, also stored in ``sf_connection``.

    Raises:
        ConnectionError: credentials are not configured (raised immediately).
        SalesforceAuthenticationFailed: credentials were rejected (not retried).
        Exception: any other login error, re-raised after 3 attempts.
    """
    global sf_connection
    if sf_connection:
        try:
            sf_connection.query("SELECT Id FROM User LIMIT 1")
            logger.info("Salesforce connection is active.")
            return sf_connection
        except Exception:
            logger.warning("Salesforce session expired. Reconnecting...")
            sf_connection = None
    if not SALESFORCE_ENABLED:
        # Configuration problem, not a transient error: fail fast, no retries.
        raise ConnectionError("Salesforce credentials are not configured.")
    sf_connection = _open_salesforce_session()
    return sf_connection
def get_or_create_site_id(sf, site_name='SITE001'):
    """Return the Salesforce Id of the Site__c record named *site_name*.

    The site is looked up by Name and created if no match exists. Successful
    results are memoized in the module-level ``site_id_cache``; failures are
    not cached.

    Args:
        sf: Connected simple_salesforce client.
        site_name: Value matched against Site__c.Name.

    Returns:
        The record Id, or None if the lookup/creation failed (errors are logged).
    """
    if site_name in site_id_cache:
        return site_id_cache[site_name]
    try:
        # Escape backslashes and single quotes so that names such as
        # "O'Neil Yard" form a valid SOQL string literal instead of breaking
        # (or injecting into) the query.
        soql_literal = str(site_name).replace("\\", "\\\\").replace("'", "\\'")
        query = f"SELECT Id FROM Site__c WHERE Name = '{soql_literal}' LIMIT 1"
        result = sf.query(query)
        if result['totalSize'] > 0:
            site_id = result['records'][0]['Id']
            logger.info(f"Found existing Site '{site_name}' with ID: {site_id}")
            site_id_cache[site_name] = site_id
            return site_id
        logger.info(f"Site '{site_name}' not found. Creating new Site record...")
        create_result = sf.Site__c.create({'Name': site_name})
        if 'id' in create_result:
            site_id = create_result['id']
            logger.info(f"Successfully created new Site '{site_name}' with ID: {site_id}")
            site_id_cache[site_name] = site_id
            return site_id
        logger.error(f"Failed to create Site record: {create_result.get('errors')}")
        return None
    except Exception as e:
        logger.error(f"Error getting or creating Site ID for '{site_name}': {e}", exc_info=True)
        return None
def create_salesforce_violation_record(sf, violation_data):
    """Create one Safety_Violation_Log__c record for a detected violation.

    ``violation_data['site_id']`` holds a site *name* (default 'Default Site').
    It is resolved to a Site__c record Id through get_or_create_site_id.
    Missing keys fall back to the defaults below, and fields whose value is
    None are dropped from the payload.

    Args:
        sf: Connected simple_salesforce client.
        violation_data: Violation dict built by the processing loops.

    Returns:
        ``(record_id, "Record created.")`` on success, otherwise
        ``(None, error_message)``. Exceptions are logged and turned into the
        failure tuple rather than propagated.
    """
    try:
        site_name = violation_data.get('site_id', 'Default Site')
        site_id_from_sf = get_or_create_site_id(sf, site_name)
        if not site_id_from_sf:
            logger.error(f"Failed to get or create Site record '{site_name}' in Salesforce.")
            return None, "Site ID creation/retrieval failed."
        payload = {
            'Site_ID__c': site_id_from_sf,
            'Violation_Type__c': violation_data.get('violation_type'),
            'Severity__c': violation_data.get('severity', 'Medium'),
            'Timestamp__c': violation_data.get('timestamp'),
            'Snapshot_URL__c': violation_data.get('snapshot_url', 'N/A'),
            'Worker_ID__c': violation_data.get('worker_id', 'N/A'),
            'Camera_ID__c': violation_data.get('camera_id', 'CAM001'),
            'Alert_Sent__c': True,
            'PDF_Report_URL__c': violation_data.get('pdf_url', "Report will be available after processing completion.")
        }
        payload = {k: v for k, v in payload.items() if v is not None}
        logger.info(f"Creating Salesforce record with payload: {json.dumps(payload)}")
        result = sf.Safety_Violation_Log__c.create(payload)
        record_id = result.get('id')
        if record_id:
            logger.info(f"Successfully created Salesforce record: {record_id}")
            return record_id, "Record created."
        # `or` also covers an empty/None 'errors' value, where errors[0] used
        # to raise IndexError and surface "list index out of range" instead.
        errors = result.get('errors') or ['Unknown error']
        logger.error(f"Salesforce record creation failed. Response: {errors}")
        return None, errors[0]
    except Exception as e:
        logger.error(f"Failed to create Salesforce record: {e}", exc_info=True)
        return None, str(e)
def _render_violation_report_pdf(violations):
    """Render *violations* as a letter-size PDF report and return its bytes."""
    buffer = BytesIO()
    try:
        c = canvas.Canvas(buffer, pagesize=letter)
        c.setFont("Helvetica-Bold", 16)
        c.drawString(100, 750, "Safety Violation Report")
        c.setFont("Helvetica", 12)
        c.drawString(100, 730, f"Generated: {datetime.now(IST).strftime('%Y-%m-%d %H:%M:%S IST')}")
        c.setFont("Helvetica", 10)
        c.drawString(100, 710, "Note: Each violation type reported only once per person per session.")
        bottom_margin = 50
        y = 680
        for i, violation in enumerate(violations, 1):
            has_distance = 'distance' in violation
            # Entry height: title (20) + three detail lines (15 each)
            # + optional distance line (15) + trailing gap (20). Break the page
            # *before* an entry that would not fit; checking only afterwards let
            # entries run below the margin or off the page.
            entry_height = 20 + 3 * 15 + (15 if has_distance else 0) + 20
            if y - entry_height < bottom_margin:
                c.showPage()
                y = 750
            c.setFont("Helvetica-Bold", 12)
            c.drawString(100, y, f"Violation #{i}: {violation['violation_type']}")
            y -= 20
            c.setFont("Helvetica", 10)
            c.drawString(120, y, f"Severity: {violation['severity']}")
            y -= 15
            c.drawString(120, y, f"Time: {violation['timestamp']}")
            y -= 15
            c.drawString(120, y, f"Worker: {violation.get('worker_id', 'UNKNOWN')}")
            y -= 15
            if has_distance:
                c.drawString(120, y, f"Distance: {violation['distance']}")
                y -= 15
            y -= 20
        c.save()
        return buffer.getvalue()
    finally:
        buffer.close()


def generate_and_upload_report_to_salesforce(sf, violations, record_ids):
    """Build a PDF report, attach it in Salesforce and point the records at it.

    Steps:
      1. render the PDF;
      2. upload it as a ContentVersion;
      3. look up its ContentDocumentId;
      4. link the document to every record in *record_ids* (viewer share);
      5. bulk-update each record's PDF_Report_URL__c with the download URL;
      6. write a local temporary copy for the UI.

    Returns:
        ``(temp_pdf_path, pdf_url)``, or ``(None, None)`` when inputs are
        missing or any step fails (errors are logged). Individual link or
        update failures are only logged.
    """
    if not violations or not record_ids or not sf:
        logger.warning("No violations, record IDs, or Salesforce connection. Skipping report generation.")
        return None, None
    try:
        # 1. Generate PDF in memory
        pdf_bytes = _render_violation_report_pdf(violations)
        # 2. Upload ContentVersion to Salesforce
        title = f"Safety_Report_{datetime.now(IST).strftime('%Y%m%d_%H%M%S')}"
        b64_pdf = base64.b64encode(pdf_bytes).decode('utf-8')
        logger.info(f"Uploading PDF '{title}.pdf' to Salesforce...")
        cv_result = sf.ContentVersion.create({
            'Title': title,
            'PathOnClient': f'{title}.pdf',
            'VersionData': b64_pdf
        })
        if not cv_result.get('success'):
            logger.error(f"Failed to create ContentVersion: {cv_result.get('errors')}")
            return None, None
        content_version_id = cv_result['id']
        logger.info(f"Successfully created ContentVersion with ID: {content_version_id}")
        # 3. Get ContentDocumentId (the Id comes from Salesforce, not user input)
        query = f"SELECT ContentDocumentId FROM ContentVersion WHERE Id = '{content_version_id}'"
        cv_query_result = sf.query(query)
        if not cv_query_result['records']:
            logger.error(f"Could not find ContentDocumentId for ContentVersion {content_version_id}")
            return None, None
        content_document_id = cv_query_result['records'][0]['ContentDocumentId']
        # 4. Link ContentDocument to all violation records
        logger.info(f"Linking ContentDocument {content_document_id} to {len(record_ids)} records...")
        link_success_count = 0
        for record_id in record_ids:
            payload = {
                'ContentDocumentId': content_document_id,
                'LinkedEntityId': record_id,
                'ShareType': 'V'  # V = Viewer
            }
            try:
                link_result = sf.ContentDocumentLink.create(payload)
                if link_result.get('success'):
                    link_success_count += 1
                else:
                    logger.warning(f"Failed to link to {payload['LinkedEntityId']}: {link_result.get('errors')}")
            except Exception as e:
                logger.error(f"Error creating ContentDocumentLink for {payload['LinkedEntityId']}: {e}")
        logger.info(f"Successfully created {link_success_count}/{len(record_ids)} links.")
        # 5. Construct URL and Update records
        sf_instance_url = sf.sf_instance.replace('https://', '')
        pdf_url = f"https://{sf_instance_url}/sfc/servlet.shepherd/version/download/{content_version_id}"
        logger.info(f"Updating records with Salesforce PDF URL: {pdf_url}")
        update_payloads = [{'Id': record_id, 'PDF_Report_URL__c': pdf_url} for record_id in record_ids]
        update_results = sf.bulk.Safety_Violation_Log__c.update(update_payloads)
        successful_updates = sum(1 for res in update_results if res.get('success'))
        logger.info(f"Successfully updated {successful_updates}/{len(record_ids)} records with the PDF URL.")
        # 6. Save PDF to a temporary file for Gradio output (delete=False: the
        #    UI needs the file after this function returns)
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf', prefix='report_') as temp_pdf:
            temp_pdf.write(pdf_bytes)
            temp_pdf_path = temp_pdf.name
        logger.info(f"Salesforce report URL: {pdf_url}")
        logger.info(f"Temporary local report for download: {temp_pdf_path}")
        return temp_pdf_path, pdf_url
    except Exception as e:
        logger.error(f"Error in Salesforce PDF report generation/upload: {e}", exc_info=True)
        return None, None
# --- Safety Violation Detector Class ---
class SafetyViolationDetector:
    """Track people across frames and report PPE, zone and distance violations.

    For each processed frame, ``detect_violations`` assigns stable IDs to person
    detections and checks three rules:
      - no hard hat over the person's head region;
      - the person's box centre lies inside an unauthorised zone;
      - two person centres are closer than ``unsafe_distance`` pixels.

    A candidate must be seen in 5 processed frames before it is reported, with
    no gap longer than ``violation_cooldown`` seconds (otherwise its counter is
    discarded). Each violation type is reported at most once per person per
    session. All timing uses wall-clock ``time.time()``.
    """

    def __init__(self):
        # Detection thresholds (fine-tuned for better accuracy)
        self.helmet_threshold = 0.75
        self.person_threshold = 0.60
        self.unsafe_distance = 50  # pixels, between person-box centres
        self.violation_cooldown = 20  # seconds without a sighting before a candidate is dropped
        # Unauthorized zones (x1, y1, x2, y2), in the same pixel space as the detection boxes
        self.unauthorized_zones = [
            [100, 100, 300, 300],  # Example zone 1
            [400, 200, 600, 400],  # Example zone 2
        ]
        self.active_violations = {}  # candidate key -> counter/state dict
        self.session_violations = {}  # person_id -> {violation_type: info} already reported
        self.person_tracker = {}  # person_id -> per-person state and violation counts
        self.person_positions_history = {}  # person_id -> recent centres/boxes used for matching
        self.next_person_id = 1
        self.max_tracking_distance = 120
        self.max_history_length = 20  # Increased for better tracking continuity
        self.min_iou_for_match = 0.4  # Adjusted for stricter matching
        self.min_score_for_match = 0.5  # Adjusted for stricter matching

    def reset_session(self):
        """Forget all tracks and reported violations (start of a new video/stream)."""
        self.session_violations = {}
        self.active_violations = {}
        self.person_tracker = {}
        self.person_positions_history = {}
        self.next_person_id = 1
        logger.info("Session violation tracking reset for new video")

    def has_reported_violation(self, person_id, violation_type):
        """Return True if *violation_type* was already reported for *person_id*."""
        if person_id not in self.session_violations:
            return False
        return violation_type in self.session_violations[person_id]

    def mark_violation_reported(self, person_id, violation_type, timestamp):
        """Record that *violation_type* has been reported for *person_id*."""
        if person_id not in self.session_violations:
            self.session_violations[person_id] = {}
        self.session_violations[person_id][violation_type] = {
            'first_detected': timestamp,
            'count': self.session_violations[person_id].get(violation_type, {}).get('count', 0) + 1
        }

    def _get_stable_person_id(self, box, current_time):
        """Match *box* to an existing track or start a new one; return its ID.

        A track is a candidate only if the IoU with its last box exceeds
        ``min_iou_for_match`` and a blended score (30% centre proximity, 70%
        IoU) exceeds ``min_score_for_match``. The best-scoring candidate wins.
        Matching is done per box, so two detections in one frame can end up
        with the same ID.
        """
        center = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
        best_match_id = None
        best_match_score = 0
        for candidate_id, candidate in self.person_positions_history.items():
            if not candidate['positions']:
                continue
            last_position = candidate['positions'][-1]
            last_box = candidate['boxes'][-1]
            dist = np.sqrt((center[0] - last_position[0]) ** 2 + (center[1] - last_position[1]) ** 2)
            iou = self._iou(box, last_box)
            # Weighted scoring: more emphasis on IoU for stable tracking
            score = (1.0 / (1.0 + dist / 50)) * 0.3 + iou * 0.7
            if score > best_match_score and score > self.min_score_for_match and iou > self.min_iou_for_match:
                best_match_score = score
                best_match_id = candidate_id
        if best_match_id is not None:
            person_id = best_match_id
        else:
            person_id = self.next_person_id
            self.next_person_id += 1
            self.person_positions_history[person_id] = {
                'positions': [],
                'boxes': [],
                'first_seen': current_time,
                'last_seen': current_time,
                'features': []  # Store features for enhanced tracking
            }
        # Update tracking history
        history = self.person_positions_history[person_id]
        history['positions'].append(center)
        history['boxes'].append(box)
        history['last_seen'] = current_time
        # Limit history length to prevent memory issues
        if len(history['positions']) > self.max_history_length:
            history['positions'].pop(0)
            history['boxes'].pop(0)
        return person_id

    def detect_violations(self, results, frame):
        """Update tracks from one frame's YOLO results and return new violations.

        Args:
            results: Ultralytics results; only ``results[0].boxes`` (xyxy, conf,
                cls) and ``results[0].names`` are read.
            frame: Image that confirmed violations are drawn on in place. The
                boxes are assumed to be in this image's pixel coordinates.

        Returns:
            List of violation dicts confirmed on this frame (possibly empty).
        """
        current_time = time.time()
        violations = []
        boxes = results[0].boxes.xyxy.cpu().numpy()
        confidences = results[0].boxes.conf.cpu().numpy()
        class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
        class_names = results[0].names
        persons = []
        helmets = []
        for box, conf, cls_id in zip(boxes, confidences, class_ids):
            class_name = class_names[cls_id]
            if class_name == "person" and conf >= self.person_threshold:
                person_id = self._get_stable_person_id(box, current_time)
                persons.append({
                    'box': box,
                    'confidence': conf,
                    'center': ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2),
                    'id': person_id
                })
            elif class_name == "hard hat" and conf >= self.helmet_threshold:
                helmets.append({
                    'box': box,
                    'confidence': conf,
                    'area': (box[2] - box[0]) * (box[3] - box[1])
                })
        current_person_ids = set()
        for person in persons:
            person_id = person['id']
            current_person_ids.add(person_id)
            if person_id not in self.person_tracker:
                self.person_tracker[person_id] = {
                    'first_seen': current_time,
                    'last_seen': current_time,
                    'positions': [person['center']],
                    'helmet_status': False,
                    'violations': {},
                    'last_violation_frame': {}
                }
            else:
                track = self.person_tracker[person_id]
                track['last_seen'] = current_time
                track['positions'].append(person['center'])
                if len(track['positions']) > self.max_history_length:
                    track['positions'].pop(0)
        for person in persons:
            helmet_violation = self._check_helmet_violation(person, helmets, frame, current_time)
            if helmet_violation:
                violations.append(helmet_violation)
            unauthorized_violation = self._check_unauthorized_area(person, frame, current_time)
            if unauthorized_violation:
                violations.append(unauthorized_violation)
        violations.extend(self._check_distance_violations(persons, frame, current_time))
        self._cleanup_violations(current_time)
        self._cleanup_inactive_persons(current_person_ids, current_time)
        return violations

    def _check_helmet_violation(self, person, helmets, frame, current_time):
        """Flag a person whose head region overlaps no detected hard hat.

        Returns a violation dict on the 5th sighting of the candidate (while it
        has not expired), otherwise None.
        """
        person_id = person['id']
        person_box = person['box']
        violation_type = 'no_helmet'
        if self.has_reported_violation(person_id, violation_type):
            return None
        # Head region = top 30% of the person box. The previous code set both
        # the top and the bottom edge to the 30% line. That gave a zero-height
        # box whose IoU with any helmet was 0, so every person was flagged.
        head_bottom = person_box[1] + (person_box[3] - person_box[1]) * 0.3
        head_region = [person_box[0], person_box[1], person_box[2], head_bottom]
        # NOTE(review): helmets are only collected for the class name
        # "hard hat". Confirm the loaded weights actually emit that label
        # (stock COCO-trained yolov8n weights are not expected to).
        has_helmet = any(self._iou(helmet['box'], head_region) > 0.1 for helmet in helmets)
        self.person_tracker[person_id]['helmet_status'] = has_helmet
        if not has_helmet:
            violation_key = f"no_helmet_{person_id}"
            # Ensure persistent violation check across frames
            if violation_key not in self.active_violations:
                self.active_violations[violation_key] = {
                    'type': 'no_helmet',
                    'person_id': person_id,
                    'first_detected': current_time,
                    'last_detected': current_time,
                    'count': 0,
                    'confirmed': False
                }
            candidate = self.active_violations[violation_key]
            candidate['count'] += 1
            candidate['last_detected'] = current_time
            # Require consistent detection over multiple frames for confirmation
            if candidate['count'] >= 5 and not candidate['confirmed']:
                self.mark_violation_reported(person_id, violation_type, current_time)
                candidate['confirmed'] = True
                counts = self.person_tracker[person_id]['violations'].setdefault(
                    'no_helmet', {'count': 0, 'last_time': 0})
                counts['count'] += 1
                counts['last_time'] = current_time
                self._annotate_frame(frame, person_box, person_id, "NO HELMET", (0, 0, 255))
                logger.info(f"CONFIRMED VIOLATION: No helmet detected for person {person_id}")
                return {
                    'type': 'no_helmet',
                    'severity': 'Critical',
                    'person': person,
                    'person_id': person_id,
                    'timestamp': current_time
                }
        return None

    def _check_unauthorized_area(self, person, frame, current_time):
        """Flag a person whose box centre lies inside an unauthorised zone.

        Returns a violation dict on the 5th sighting in the same zone, else None.
        """
        person_id = person['id']
        violation_type = 'unauthorized_area'
        if self.has_reported_violation(person_id, violation_type):
            return None
        x1, y1, x2, y2 = person['box']
        person_center = ((x1 + x2) / 2, (y1 + y2) / 2)
        for zone in self.unauthorized_zones:
            zx1, zy1, zx2, zy2 = zone
            if not (zx1 <= person_center[0] <= zx2 and zy1 <= person_center[1] <= zy2):
                continue
            violation_key = f"unauthorized_area_{person_id}_{zx1}_{zy1}"
            if violation_key not in self.active_violations:
                self.active_violations[violation_key] = {
                    'type': 'unauthorized_area',
                    'person_id': person_id,
                    'zone': zone,
                    'first_detected': current_time,
                    'last_detected': current_time,
                    'count': 0,
                    'confirmed': False
                }
            candidate = self.active_violations[violation_key]
            candidate['count'] += 1
            candidate['last_detected'] = current_time
            if candidate['count'] >= 5 and not candidate['confirmed']:
                self.mark_violation_reported(person_id, violation_type, current_time)
                candidate['confirmed'] = True
                counts = self.person_tracker[person_id]['violations'].setdefault(
                    'unauthorized_area', {'count': 0, 'last_time': 0})
                counts['count'] += 1
                counts['last_time'] = current_time
                cv2.rectangle(frame, (zx1, zy1), (zx2, zy2), (255, 0, 255), 2)
                self._annotate_frame(frame, person['box'], person_id, "UNAUTHORIZED", (255, 0, 255))
                logger.info(f"CONFIRMED VIOLATION: Unauthorized area detected for person {person_id}")
                return {
                    'type': 'unauthorized_area',
                    'severity': 'High',
                    'person': person,
                    'person_id': person_id,
                    'zone': zone,
                    'timestamp': current_time
                }
        return None

    def _check_distance_violations(self, persons, frame, current_time):
        """Flag pairs of people whose centres are closer than ``unsafe_distance``.

        Returns the list of pair violations confirmed on this frame.
        """
        violations = []
        if len(persons) < 2:
            return violations
        violation_type = 'unsafe_distance'
        for i, first in enumerate(persons):
            for second in persons[i + 1:]:
                dist = self._euclidean_distance(first['center'], second['center'])
                if dist >= self.unsafe_distance:
                    continue
                person1_id = first['id']
                person2_id = second['id']
                if (self.has_reported_violation(person1_id, violation_type)
                        or self.has_reported_violation(person2_id, violation_type)):
                    continue
                pair_key = f"{min(person1_id, person2_id)}_{max(person1_id, person2_id)}"
                violation_key = f"unsafe_distance_{pair_key}"
                if violation_key not in self.active_violations:
                    self.active_violations[violation_key] = {
                        'type': 'unsafe_distance',
                        'person1_id': person1_id,
                        'person2_id': person2_id,
                        'first_detected': current_time,
                        'last_detected': current_time,
                        'count': 0,
                        'confirmed': False
                    }
                candidate = self.active_violations[violation_key]
                candidate['count'] += 1
                candidate['last_detected'] = current_time
                if candidate['count'] >= 5 and not candidate['confirmed']:
                    self.mark_violation_reported(person1_id, violation_type, current_time)
                    self.mark_violation_reported(person2_id, violation_type, current_time)
                    candidate['confirmed'] = True
                    for pid in (person1_id, person2_id):
                        counts = self.person_tracker[pid]['violations'].setdefault(
                            'unsafe_distance', {'count': 0, 'last_time': 0})
                        counts['count'] += 1
                        counts['last_time'] = current_time
                    self._annotate_distance(frame, first['box'], second['box'],
                                            person1_id, person2_id, dist)
                    logger.info(f"CONFIRMED VIOLATION: Unsafe distance detected between persons {person1_id} and {person2_id}")
                    violations.append({
                        'type': 'unsafe_distance',
                        'severity': 'Moderate',
                        'person1': first,
                        'person2': second,
                        'distance': dist,
                        'person1_id': person1_id,
                        'person2_id': person2_id,
                        'timestamp': current_time
                    })
        return violations

    def _cleanup_violations(self, current_time):
        """Drop candidates not seen for more than ``violation_cooldown`` seconds."""
        expired_violations = [
            k for k, v in self.active_violations.items()
            if current_time - v['last_detected'] > self.violation_cooldown
        ]
        for key in expired_violations:
            del self.active_violations[key]

    def _cleanup_inactive_persons(self, current_person_ids, current_time):
        """Forget tracks absent from this frame and unseen for over 60 seconds."""
        inactive_timeout = 60
        expired_persons = [
            pid for pid, data in self.person_tracker.items()
            if pid not in current_person_ids and
            current_time - data['last_seen'] > inactive_timeout
        ]
        for pid in expired_persons:
            del self.person_tracker[pid]
            if pid in self.person_positions_history:
                del self.person_positions_history[pid]

    def _annotate_frame(self, frame, box, person_id, violation_type, color):
        """Draw *box* and an "ID:nnn <label>" caption on *frame* in place."""
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        label = f"ID:{person_id:03d} {violation_type}"
        cv2.putText(frame, label, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    def _annotate_distance(self, frame, box1, box2, id1, id2, dist):
        """Draw both boxes, a line between their centres and the pixel distance."""
        x1, y1, x2, y2 = map(int, box1)
        x3, y3, x4, y4 = map(int, box2)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 165, 255), 2)
        cv2.rectangle(frame, (x3, y3), (x4, y4), (0, 165, 255), 2)
        center1 = ((x1 + x2) // 2, (y1 + y2) // 2)
        center2 = ((x3 + x4) // 2, (y3 + y4) // 2)
        cv2.line(frame, center1, center2, (0, 165, 255), 2)
        mid_point = ((center1[0] + center2[0]) // 2, (center1[1] + center2[1]) // 2)
        cv2.putText(frame, f"{dist:.1f}px", mid_point,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 165, 255), 2)

    def _iou(self, box1, box2):
        """Intersection-over-union of two xyxy boxes (epsilon avoids division by zero)."""
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        intersection = max(0, x2 - x1) * max(0, y2 - y1)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        return intersection / (area1 + area2 - intersection + 1e-6)

    def _euclidean_distance(self, point1, point2):
        """Straight-line distance between two (x, y) points."""
        return np.sqrt((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2)

    def get_session_summary(self):
        """Summarise the session.

        Returns a dict with:
          - ``total_persons``: unique person IDs assigned this session (it used
            to count only persons with a reported violation);
          - ``violations_by_type``: {type: number of persons reported};
          - ``persons_with_violations``: per-person lists of violation types.
        """
        summary = {
            'total_persons': self.next_person_id - 1,
            'violations_by_type': {},
            'persons_with_violations': []
        }
        for person_id, violations in self.session_violations.items():
            summary['persons_with_violations'].append({
                'person_id': person_id,
                'violations': list(violations.keys()),
                'violation_count': len(violations)
            })
            for violation_type in violations:
                summary['violations_by_type'][violation_type] = (
                    summary['violations_by_type'].get(violation_type, 0) + 1)
        return summary
# --- Frame Processing Functions ---
def preprocess_frame(frame):
    """Boost the contrast/brightness of a BGR frame before YOLO inference.

    The frame is returned in BGR channel order at its original resolution.
    Ultralytics expects NumPy inputs as BGR and letterboxes them internally,
    and it reports boxes in the pixel space of the array it was given.
    Converting to RGB swapped the channels the model saw. Squashing to 640x640
    put the boxes in a different coordinate space from the original frame that
    detections are drawn on and checked against.

    Raises:
        Any OpenCV error, after logging it.
    """
    try:
        # alpha scales contrast, beta shifts brightness; the result is 8-bit.
        return cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
    except Exception as e:
        logger.error(f"Frame preprocessing error: {e}")
        raise
def capture_rtsp_frames(rtsp_url, max_frames=None):
    """Yield frames read from an RTSP (or any OpenCV-readable) source.

    Args:
        rtsp_url: Source passed to ``cv2.VideoCapture``.
        max_frames: Stop after this many frames; None reads until a read fails.

    Yields:
        ``(frame, timestamp, frame_count, None)``. ``timestamp`` is the IST
        wall-clock read time in ISO-8601 and ``frame_count`` is 1-based.

    Raises:
        ValueError: the source cannot be opened.
    """
    cap = None
    try:
        logger.info(f"Connecting to RTSP stream: {rtsp_url}")
        cap = cv2.VideoCapture(rtsp_url)
        if not cap.isOpened():
            raise ValueError(f"RTSP stream not accessible: {rtsp_url}")
        frame_count = 0
        while cap.isOpened() and (max_frames is None or frame_count < max_frames):
            ret, frame = cap.read()
            if not ret:
                logger.warning("Failed to read frame from RTSP stream")
                break
            timestamp = datetime.now(IST).isoformat()
            frame_count += 1
            yield frame, timestamp, frame_count, None
    except Exception as e:
        logger.error(f"RTSP capture error: {e}")
        raise
    finally:
        # Release here so the capture is freed on errors too, and when the
        # consumer stops iterating early. Closing the generator raises
        # GeneratorExit at the `yield`, which skips code placed after the loop.
        if cap is not None:
            cap.release()
        cv2.destroyAllWindows()
# --- Video Processing Functions ---
async def process_video(video_path, frame_skip=1, progress=gr.Progress()):
    """Run safety-violation detection over a video file.

    Every ``frame_skip``-th frame goes through YOLO and a fresh
    SafetyViolationDetector. Each confirmed violation is:
      - snapshotted, logged and alerted;
      - pushed as a Safety_Violation_Log__c record when Salesforce is reachable.
    Afterwards a PDF report is uploaded and linked to the created records.

    Args:
        video_path: File path opened with ``cv2.VideoCapture``.
        frame_skip: Analyse one frame out of every ``frame_skip``. 0/None and
            values below 1 are treated as 1.
        progress: Gradio progress tracker.

    Returns:
        ``(frames, status_message, pdf_path, violations_text)``. ``frames`` is
        the list of frames on which violations were reported (None on failure
        or cancellation). ``pdf_path`` is a local copy of the report, or None.

    Note: OpenCV, YOLO and Salesforce calls are synchronous, so this coroutine
    blocks the event loop for the whole run.
    """
    global processing_active
    processing_active = True
    cap = None
    try:
        # `frame_count % 0` would raise ZeroDivisionError.
        frame_skip = max(1, int(frame_skip or 1))
        current_run_violations = []
        new_sf_record_ids = []
        tracker = SafetyViolationDetector()
        tracker.reset_session()
        logger.info("Starting new video analysis session")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            error_message = f"Failed to open video file: {video_path}"
            logger.error(error_message)
            return None, error_message, None, format_violations_as_text([])
        frames = []
        start_time = time.time()
        frame_count = 0
        processed_frames = 0
        violation_count = 0
        # Get Salesforce connection once at the beginning
        sf = None
        if SALESFORCE_ENABLED:
            try:
                sf = get_salesforce_connection()
            except Exception as e:
                logger.error(f"Could not connect to Salesforce at start: {e}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or -1) frames: treat that as "unknown"
        # instead of dividing by it.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 and total_frames > 0 else 0
        progress(0, desc="Analyzing video...")
        while cap.isOpened() and processing_active:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            if frame_count % frame_skip != 0:
                continue
            processed_frames += 1
            timestamp = datetime.now(IST).isoformat()
            if total_frames > 0:
                progress(min(1.0, frame_count / total_frames),
                         desc=f"Processing frame {frame_count}/{total_frames}")
            else:
                progress(0, desc=f"Processing frame {frame_count}")
            processed_frame = preprocess_frame(frame)
            results = yolo_model.predict(processed_frame)
            violations = tracker.detect_violations(results, frame)
            for violation in violations:
                violation_count += 1
                snapshot_url = save_snapshot(frame)
                worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"
                if violation['type'] == 'unsafe_distance':
                    worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
                violation_data = {
                    'violation_type': violation['type'].replace('_', ' ').title(),
                    'severity': violation['severity'],
                    'timestamp': timestamp,
                    'snapshot_url': snapshot_url,
                    'site_id': 'SITE001',
                    'camera_id': 'CAM001',
                    'worker_id': worker_id,
                    'frame_number': frame_count
                }
                if violation['type'] == 'unsafe_distance':
                    violation_data['distance'] = f"{violation['distance']:.1f}px"
                current_run_violations.append(violation_data)
                log_violation(violation_data)
                send_alert(violation_data)
                if sf:
                    record_id, message = create_salesforce_violation_record(sf, violation_data)
                    if record_id:
                        new_sf_record_ids.append(record_id)
                    else:
                        logger.error(f"Salesforce push failed for violation: {message}")
            if violations:
                frames.append(frame)
        cap.release()
        processing_time = time.time() - start_time
        actual_fps = processed_frames / processing_time if processing_time > 0 else 0
        if not processing_active:
            return None, "Processing cancelled", None, format_violations_as_text([])
        # Generate and upload report to Salesforce
        pdf_temp_path = None
        if sf and new_sf_record_ids and current_run_violations:
            logger.info(f"Generating PDF report and uploading to Salesforce for {len(new_sf_record_ids)} violations...")
            pdf_temp_path, _ = generate_and_upload_report_to_salesforce(sf, current_run_violations, new_sf_record_ids)
            if not pdf_temp_path:
                logger.error("Failed to generate and upload Salesforce report.")
        elif not current_run_violations:
            logger.info("No violations detected, skipping report generation.")
        else:
            logger.warning("Salesforce not configured or no violations recorded. Skipping Salesforce report upload.")
        session_summary = tracker.get_session_summary()
        logger.info(f"Video analysis complete. Session summary: {session_summary}")
        status_message = generate_status_message(
            violation_count > 0,
            frame_count,
            processed_frames,
            duration,
            violation_count,
            processing_time,
            actual_fps,
            session_summary
        )
        return frames, status_message, pdf_temp_path, format_violations_as_text(current_run_violations)
    except Exception as e:
        logger.error(f"Video processing error: {e}", exc_info=True)
        error_message = f"Video processing failed: {str(e)}"
        return None, error_message, None, format_violations_as_text([])
    finally:
        # Release on every path (including exceptions); release() is idempotent.
        if cap is not None:
            cap.release()
        processing_active = False
        cv2.destroyAllWindows()
# --- RTSP Processing ---
async def process_rtsp_stream(rtsp_url, max_frames=None, frame_skip=1, progress=gr.Progress()):
    """Run safety-violation detection on frames read from an RTSP stream.

    Args:
        rtsp_url: Stream URL; an empty value is reported as an error.
        max_frames: Stop after this many frames have been read (None = until
            the stream ends or processing is cancelled).
        frame_skip: Analyse one frame out of every ``frame_skip``. 0/None and
            values below 1 are treated as 1.
        progress: Gradio progress tracker.

    Returns:
        ``(status_message, frames, violations_text, heatmap, pdf_path)``.
        ``frames`` holds only the frames on which violations were reported.

    Note: OpenCV, YOLO and Salesforce calls are synchronous, so this coroutine
    blocks the event loop while it runs.
    """
    global processing_active
    processing_active = True
    try:
        if not rtsp_url:
            raise ValueError("RTSP URL not provided")
        # `fc % 0` would raise ZeroDivisionError.
        frame_skip = max(1, int(frame_skip or 1))
        current_run_violations = []
        new_sf_record_ids = []
        tracker = SafetyViolationDetector()
        tracker.reset_session()
        logger.info("Starting new RTSP stream analysis session")
        # Get Salesforce connection once at the beginning
        sf = None
        if SALESFORCE_ENABLED:
            try:
                sf = get_salesforce_connection()
            except Exception as e:
                logger.error(f"Could not connect to Salesforce at start: {e}")
        frames = []
        violation_count = 0
        processed_frames = 0
        progress(0, desc="Connecting to RTSP stream...")
        for frame, timestamp, fc, _ in capture_rtsp_frames(rtsp_url, max_frames):
            if not processing_active:
                break
            if fc % frame_skip != 0:
                continue
            processed_frames += 1
            # Without max_frames the total is unknown; 100 is only a nominal scale.
            progress_percent = min(100, (fc / (max_frames if max_frames else 100)) * 100)
            progress(progress_percent / 100, desc=f"Processing frame {fc}")
            processed_frame = preprocess_frame(frame)
            results = yolo_model.predict(processed_frame)
            violations = tracker.detect_violations(results, frame)
            for violation in violations:
                violation_count += 1
                snapshot_url = save_snapshot(frame)
                worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"
                if violation['type'] == 'unsafe_distance':
                    worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
                violation_data = {
                    'violation_type': violation['type'].replace('_', ' ').title(),
                    'severity': violation['severity'],
                    'timestamp': timestamp,
                    'snapshot_url': snapshot_url,
                    'site_id': 'SITE001',
                    'camera_id': 'CAM001',
                    'worker_id': worker_id,
                    'frame_number': fc
                }
                if violation['type'] == 'unsafe_distance':
                    violation_data['distance'] = f"{violation['distance']:.1f}px"
                current_run_violations.append(violation_data)
                log_violation(violation_data)
                send_alert(violation_data)
                if sf:
                    record_id, message = create_salesforce_violation_record(sf, violation_data)
                    if record_id:
                        new_sf_record_ids.append(record_id)
                    else:
                        logger.error(f"Salesforce push failed for violation: {message}")
            if violations:
                frames.append(frame)
        if not processing_active:
            logger.info("Processing cancelled by user.")
        # Generate and upload report to Salesforce
        pdf_temp_path = None
        if sf and new_sf_record_ids and current_run_violations:
            logger.info(f"Generating PDF report and uploading to Salesforce for {len(new_sf_record_ids)} violations...")
            pdf_temp_path, _ = generate_and_upload_report_to_salesforce(sf, current_run_violations, new_sf_record_ids)
            if not pdf_temp_path:
                logger.error("Failed to generate and upload Salesforce report.")
        elif not current_run_violations:
            logger.info("No violations detected, skipping report generation.")
        else:
            logger.warning("Salesforce not configured or no violations recorded. Skipping Salesforce report upload.")
        if not processing_active:
            return "Processing cancelled.", frames, format_violations_as_text(current_run_violations), generate_heatmap(current_run_violations), pdf_temp_path
        session_summary = tracker.get_session_summary()
        logger.info(f"RTSP analysis complete. Session summary: {session_summary}")
        # Report frames actually analysed; len(frames) only counts frames that
        # contained violations.
        status_message = f"Processed {processed_frames} frames with {violation_count} unique violations. Persons tracked: {session_summary['total_persons']}"
        return status_message, frames, format_violations_as_text(current_run_violations), generate_heatmap(current_run_violations), pdf_temp_path
    except Exception as e:
        logger.error(f"RTSP processing error: {e}", exc_info=True)
        error_message = f"RTSP processing failed: {str(e)}"
        return error_message, None, format_violations_as_text([]), None, None
    finally:
        processing_active = False
        cv2.destroyAllWindows()
# --- Other Functions ---
def generate_status_message(has_violations, total_frames, processed_frames, duration,
                            violation_count, processing_time, actual_fps, session_summary=None):
    """Compose the multi-line summary text shown after a media analysis run.

    Args:
        has_violations: truthy when at least one violation was found; selects
            which closing section is appended.
        total_frames: total frame count shown in the RESULTS section.
        processed_frames: number of frames actually analysed.
        duration: media duration, rendered with two decimals and an "s" suffix.
        violation_count: number of unique violations; only shown when
            ``has_violations`` is truthy.
        processing_time: analysis time, rendered with two decimals and "s".
        actual_fps: processing throughput, rendered with one decimal.
        session_summary: optional tracker summary. When truthy it must provide
            'total_persons' and a mapping 'violations_by_type' whose keys are
            listed (or 'None' when the mapping is empty).

    Returns:
        str: the formatted status message.
    """
    report = f"""โœ… ANALYSIS COMPLETED
โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”
๐Ÿ“Š RESULTS:
โ€ข Frames: {total_frames} (Processed: {processed_frames})
โ€ข Duration: {duration:.2f}s
โ€ข Processing Time: {processing_time:.2f}s
โ€ข FPS: {actual_fps:.1f}"""
    if session_summary:
        # Read 'total_persons' first so a malformed summary fails on the same key.
        persons = session_summary['total_persons']
        type_map = session_summary['violations_by_type']
        type_names = ', '.join(type_map.keys()) if type_map else 'None'
        report += f"""
๐Ÿ‘ฅ UNIQUE PERSONS TRACKED: {persons}
๐Ÿ” VIOLATION TYPES: {type_names}"""
    if not has_violations:
        return report + """
โœ… NO VIOLATIONS DETECTED
โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”
All safety protocols followed"""
    return report + f"""
๐Ÿšจ UNIQUE VIOLATIONS: {violation_count}
โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”
Each violation reported only once per person"""
def save_snapshot(frame):
    """Write *frame* as a JPEG under ``./snapshots`` and return its path.

    The filename combines the current Unix time (whole seconds) with a short
    random suffix. Several snapshots taken within the same second therefore
    get distinct files instead of overwriting each other.

    Args:
        frame: image handed directly to ``cv2.imwrite`` (presumably an OpenCV
            BGR array; confirm against callers).

    Returns:
        str: path of the written snapshot, or the placeholder
        ``"snapshot_failed.jpg"`` if the write failed or raised.
    """
    import uuid  # local import: only used to make snapshot names unique

    try:
        snapshot_dir = "./snapshots"
        os.makedirs(snapshot_dir, exist_ok=True)
        # Second resolution alone collides when multiple violations are
        # captured in the same second; the random suffix prevents overwrites.
        filename = f"snapshot_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"
        snapshot_path = os.path.join(snapshot_dir, filename)
        # cv2.imwrite reports failure by returning False rather than raising.
        if not cv2.imwrite(snapshot_path, frame):
            logger.error(f"Snapshot error: cv2.imwrite failed for {snapshot_path}")
            return "snapshot_failed.jpg"
        return snapshot_path
    except Exception as e:
        logger.error(f"Snapshot error: {e}")
        return "snapshot_failed.jpg"
def log_violation(violation_data, log_path="./snapshots/violation_logs.json"):
    """Append one violation record to the JSON log file and refresh history.

    The log is a single JSON array that is read, extended and rewritten on
    every call. The module-level globals ``violation_history`` (all entries)
    and ``recent_violations`` (the last 10 entries) are updated to mirror it.

    Args:
        violation_data: JSON-serializable record (a dict at the visible call
            sites) to append.
        log_path: location of the JSON log; defaults to
            ``./snapshots/violation_logs.json``.

    Errors are logged through ``logger`` and are never raised to the caller.
    """
    global recent_violations, violation_history
    try:
        log_file = Path(log_path)
        # The directory may not exist yet (e.g. before any snapshot is saved).
        log_file.parent.mkdir(parents=True, exist_ok=True)
        logs = []
        if log_file.exists():
            try:
                logs = json.loads(log_file.read_text(encoding="utf-8"))
            except json.JSONDecodeError:
                logs = None
            if not isinstance(logs, list):
                # An unreadable or unexpected log would make every future call
                # fail; keep it aside for inspection and start a fresh log.
                backup = log_file.with_name(f"{log_file.name}.corrupt-{int(time.time())}")
                os.replace(log_file, backup)
                logger.warning(f"Violation log {log_file} was unreadable; moved to {backup}")
                logs = []
        logs.append(violation_data)
        recent_violations = logs[-10:]
        violation_history = logs
        # Serialize before touching the file, then swap the new content in
        # atomically. A non-serializable record or an interrupted write can no
        # longer truncate the existing log.
        payload = json.dumps(logs, indent=4)
        tmp_file = log_file.with_name(log_file.name + ".tmp")
        tmp_file.write_text(payload, encoding="utf-8")
        os.replace(tmp_file, log_file)
    except Exception as e:
        logger.error(f"Logging error: {e}")
def send_alert(violation):
    """Emit an ALERT line for *violation* through the module logger.

    Reads the 'violation_type' and 'severity' keys (KeyError if either is
    missing). Logging at INFO level is the only delivery channel here.
    """
    kind = violation['violation_type']
    level = violation['severity']
    logger.info(f"ALERT: {kind} detected (Severity: {level})")
def format_violations_as_text(violations):
    """Render violation records as the text shown in the violation-log boxes.

    Args:
        violations: sequence of violation dicts. For non-empty input every
            record must provide 'severity', 'violation_type' (a string),
            'timestamp', 'site_id', 'camera_id' and 'snapshot_url'.
            'worker_id' is optional and shown as 'UNKNOWN' when absent.

    Returns:
        str: when *violations* is empty, an "all clear" panel stamped with
        ``datetime.now(IST)``. Otherwise a header, one boxed alert per record
        (numbered from 01), and summary statistics counting the 'Critical' and
        'Moderate' severities plus the last record's timestamp.
    """
    if not violations:
        updated_at = datetime.now(IST).strftime('%Y-%m-%d %H:%M:%S IST')
        return """๐Ÿ” SAFETY MONITORING STATUS
โœ… NO VIOLATIONS DETECTED
โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”
๐Ÿ“Š Current Status: ALL CLEAR
๐Ÿ• Last Updated: """ + updated_at + """
๐ŸŽฏ Detection Accuracy: >90% confidence
โšก Response Time: <5 seconds
The system is actively monitoring for:
โ€ข No Helmet violations
โ€ข Unsafe Distance violations
โ€ข Unauthorized Area violations
All safety protocols are currently being followed."""

    header = f"""๐Ÿšจ SAFETY VIOLATION ALERTS
โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”
๐Ÿ“Š UNIQUE VIOLATIONS DETECTED: {len(violations)}
Note: Each violation type reported only once per person
"""
    alert_cards = []
    for position, record in enumerate(violations, start=1):
        # Red marker for 'Critical'; every other severity gets the yellow one.
        badge = "๐Ÿ”ด" if record['severity'] == 'Critical' else "๐ŸŸก"
        alert_cards.append(f"""
โ”Œโ”€ ALERT #{position:02d} โ”€ {badge} {record['violation_type'].upper()}
โ”‚
โ”œโ”€ ๐Ÿ• Time: {record['timestamp']}
โ”œโ”€ โš ๏ธ Severity: {record['severity']}
โ”œโ”€ ๐Ÿ“ Location: Site {record['site_id']} | Camera {record['camera_id']}
โ”œโ”€ ๐Ÿ‘ท Worker: {record.get('worker_id', 'UNKNOWN')}
โ”œโ”€ ๐Ÿ“ธ Evidence: {record['snapshot_url']}
โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€\n""")

    critical_total = sum(1 for r in violations if r['severity'] == 'Critical')
    moderate_total = sum(1 for r in violations if r['severity'] == 'Moderate')
    # violations is non-empty here, so the last record always exists.
    last_timestamp = violations[-1]['timestamp']
    summary = f"""
๐Ÿ“ˆ SUMMARY STATISTICS:
โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”
โ€ข Total Violations: {len(violations)}
โ€ข Critical: {critical_total}
โ€ข Moderate: {moderate_total}
โ€ข Last Alert: {last_timestamp}
๐Ÿ”„ System Status: ACTIVELY MONITORING
โšก Response Time: <5 seconds
๐ŸŽฏ Detection Accuracy: >90% confidence"""
    return header + "".join(alert_cards) + summary
def generate_heatmap(violations):
    """Plot a violations-per-hour heatmap and return the PNG file path.

    Args:
        violations: list of violation dicts. Each must provide a 'timestamp'
            parseable by ``pd.to_datetime`` and a 'violation_type'.

    Returns:
        str | None: path of a temporary PNG (created with ``delete=False``, so
        the caller owns it), or None when *violations* is empty or plotting
        fails.
    """
    if not violations:
        return None
    fig = None
    png_path = None
    try:
        df = pd.DataFrame(violations)
        df['hour'] = pd.to_datetime(df['timestamp']).dt.hour
        # Rows: hour of day; columns: violation type; cells: record counts.
        heatmap_data = df.pivot_table(index='hour', columns='violation_type', aggfunc='size', fill_value=0)
        # Use an explicit figure instead of pyplot's implicit current figure,
        # so the finally-block can always release it even if plotting raises.
        fig, ax = plt.subplots(figsize=(12, 8))
        sns.heatmap(heatmap_data, cmap='YlOrRd', annot=True, fmt='d', ax=ax)
        ax.set_title("Unique Violations by Hour")
        ax.set_xlabel("Violation Type")
        ax.set_ylabel("Hour of Day")
        # Only reserve a unique name. Closing the handle avoids leaking a file
        # descriptor and lets savefig reopen the path (required on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
            png_path = tmp.name
        fig.savefig(png_path, bbox_inches='tight')
        return png_path
    except Exception as e:
        logger.error(f"Heatmap error: {e}")
        # Don't leave an empty/partial PNG behind when saving failed.
        if png_path is not None:
            try:
                os.remove(png_path)
            except OSError:
                pass
        return None
    finally:
        if fig is not None:
            plt.close(fig)
def cancel_processing():
    """Stop the active RTSP loop and return a status message for the UI.

    Clears the module-level ``processing_active`` flag, which the RTSP
    processing loop checks to stop early. It then makes a best-effort attempt
    to close any OpenCV windows.

    Returns:
        str: ``"Processing cancelled"``, displayed in the RTSP status box.
    """
    global processing_active
    processing_active = False
    # GUI teardown is best-effort: headless OpenCV builds (typical on servers)
    # raise cv2.error from destroyAllWindows, and that must not turn a
    # successful cancel into a UI error.
    try:
        cv2.destroyAllWindows()
    except cv2.error as e:
        logger.debug(f"cv2.destroyAllWindows unavailable: {e}")
    return "Processing cancelled"
# --- Enhanced CSS ---
# Custom stylesheet for the whole page, passed to gr.Blocks(css=...).
# Class selectors below (.main-header, .section-header, .professional-card,
# .status-display, .alert-panel, .analytics-panel, .image-component,
# .gallery-component, .file-component, .btn-primary/.btn-secondary,
# .footer-info, .feature-grid/.feature-item) are attached to components via
# `elem_classes` or raw HTML in the UI definition. #particle-canvas styles the
# <canvas> element emitted through gr.HTML.
# `!important` is used throughout so these rules take precedence over the
# Gradio theme's own styles.
# NOTE(review): the `*` reset removes margin/padding from every element on the
# page, including Gradio's internal markup, and `.gradio-container` sets
# `overflow: hidden`. Verify neither clips or cramps components in the
# deployed Gradio version. The .gradio-tab* rules have no matching Tabs in the
# UI built in this file section. Confirm whether they are still needed.
enhanced_custom_css = """
/* Reset Default Styles */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
/* Global Theme and Layout */
.gradio-container {
font-family: 'Poppins', 'Inter', 'Segoe UI', 'Roboto', sans-serif !important;
background: linear-gradient(45deg, #0a0a1f, #1a0033, #2a0044, #0a0a1f) !important;
background-size: 400% !important;
animation: gradientShift 12s ease infinite !important;
min-height: 100vh !important;
display: flex !important;
flex-direction: column !important;
justify-content: center !important;
align-items: center !important;
position: relative !important;
overflow: hidden !important;
}
@keyframes gradientShift {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
/* Particle Canvas for Star-Like Animation */
#particle-canvas {
position: absolute !important;
top: 0 !important;
left: 0 !important;
width: 100% !important;
height: 100% !important;
z-index: 1 !important;
pointer-events: none !important;
}
/* Main Header Styling */
.main-header {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border: 1px solid rgba(255, 0, 204, 0.3) !important;
color: #f0f0f5 !important;
text-align: center !important;
padding: 1rem !important;
margin-bottom: 1rem !important;
border-radius: 15px !important;
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important;
z-index: 2 !important;
animation: fadeIn 1s ease-out !important;
}
.header-title {
font-size: 2.5rem !important;
font-weight: 700 !important;
text-shadow: 0 0 12px rgba(255, 0, 204, 0.8) !important;
animation: glow 2s ease-in-out infinite alternate !important;
margin-bottom: 0.5rem !important;
}
.header-subtitle {
font-size: 1rem !important;
font-weight: 400 !important;
color: #ccc !important;
text-shadow: 0 0 5px rgba(255, 0, 204, 0.5) !important;
}
@keyframes glow {
from { text-shadow: 0 0 5px rgba(255, 0, 204, 0.5), 0 0 10px rgba(255, 0, 204, 0.3); }
to { text-shadow: 0 0 12px rgba(255, 0, 204, 0.9), 0 0 20px rgba(255, 0, 204, 0.6); }
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
/* Professional Card System (Glassmorphism) */
.professional-card {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border: 1px solid rgba(255, 0, 204, 0.3) !important;
border-radius: 15px !important;
padding: 1rem !important;
margin: 0.5rem 0 !important;
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important;
z-index: 2 !important;
transition: all 0.3s ease !important;
}
.professional-card:hover {
transform: translateY(-2px) !important;
box-shadow: 0 0 30px rgba(255, 0, 204, 0.35) !important;
}
/* Section Headers */
.section-header {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border: 1px solid rgba(255, 0, 204, 0.3) !important;
color: #f0f0f5 !important;
padding: 0.8rem 1rem !important;
border-radius: 10px !important;
text-align: center !important;
font-weight: 700 !important;
font-size: 1.2rem !important;
margin-bottom: 1rem !important;
text-shadow: 0 0 12px rgba(255, 0, 204, 0.8) !important;
animation: glow 2s ease-in-out infinite alternate !important;
z-index: 2 !important;
}
/* Button Styling */
.btn-primary, .gr-button {
background: linear-gradient(90deg, #00C4B4, #ff00cc) !important;
border: none !important;
border-radius: 8px !important;
padding: 10px 20px !important;
color: white !important;
font-weight: 600 !important;
font-size: 1rem !important;
transition: all 0.3s ease !important;
box-shadow: 0 0 15px rgba(255, 0, 204, 0.4) !important;
position: relative !important;
overflow: hidden !important;
z-index: 2 !important;
}
.btn-primary:hover, .gr-button:hover {
background: linear-gradient(90deg, #00C6B6, #ff33cc) !important;
box-shadow: 0 0 20px rgba(255, 0, 204, 0.6) !important;
transform: translateY(-2px) !important;
}
.btn-primary::before, .gr-button::before {
content: '' !important;
position: absolute !important;
top: 50% !important;
left: 50% !important;
width: 300% !important;
height: 300% !important;
background: rgba(255, 255, 255, 0.1) !important;
transition: all 0.5s ease !important;
transform: translate(-50%, -50%) rotate(45deg) !important;
opacity: 0 !important;
}
.btn-primary:hover::before, .gr-button:hover::before {
opacity: 1 !important;
width: 0 !important;
height: 0 !important;
}
.btn-secondary {
background: linear-gradient(90deg, #11998e, #38ef7d) !important;
border: none !important;
border-radius: 8px !important;
padding: 10px 20px !important;
color: white !important;
font-weight: 600 !important;
font-size: 1rem !important;
transition: all 0.3s ease !important;
box-shadow: 0 0 15px rgba(17, 153, 142, 0.4) !important;
}
.btn-secondary:hover {
background: linear-gradient(90deg, #0e867c, #33d670) !important;
box-shadow: 0 0 20px rgba(17, 153, 142, 0.6) !important;
transform: translateY(-2px) !important;
}
/* Status Display */
.status-display {
background: rgba(255, 255, 255, 0.05) !important;
border: 1px solid rgba(255, 255, 255, 0.15) !important;
border-radius: 8px !important;
padding: 1rem !important;
color: #f0f0f5 !important;
font-size: 0.95rem !important;
font-family: 'Fira Code', 'Consolas', monospace !important;
white-space: pre-wrap !important;
max-height: 300px !important;
overflow-y: auto !important;
z-index: 2 !important;
}
.status-display:focus {
border-color: #ff00cc !important;
box-shadow: 0 0 8px rgba(255, 0, 204, 0.6) !important;
}
.status-display::-webkit-scrollbar {
width: 5px !important;
}
.status-display::-webkit-scrollbar-track {
background: rgba(255, 255, 255, 0.05) !important;
border-radius: 3px !important;
}
.status-display::-webkit-scrollbar-thumb {
background: #ff00cc !important;
border-radius: 3px !important;
}
/* Alert Panel */
.alert-panel {
background: rgba(255, 51, 51, 0.15) !important;
backdrop-filter: blur(10px) !important;
border: 1px solid rgba(255, 51, 51, 0.5) !important;
color: #ff3333 !important;
border-radius: 15px !important;
padding: 1rem !important;
margin: 0.5rem 0 !important;
box-shadow: 0 0 20px rgba(255, 51, 51, 0.3) !important;
animation: alertPulse 2s infinite !important;
z-index: 2 !important;
}
@keyframes alertPulse {
0%, 100% { transform: scale(1); opacity: 1; }
50% { transform: scale(1.01); opacity: 0.95; }
}
/* Success Panel */
.success-panel {
background: rgba(0, 184, 148, 0.15) !important;
backdrop-filter: blur(10px) !important;
border: 1px solid rgba(0, 184, 148, 0.5) !important;
color: #00b894 !important;
border-radius: 15px !important;
padding: 1rem !important;
margin: 0.5rem 0 !important;
box-shadow: 0 0 20px rgba(0, 184, 148, 0.3) !important;
z-index: 2 !important;
}
/* Image Components */
.image-component {
border-radius: 15px !important;
overflow: hidden !important;
box-shadow: 0 0 20px rgba(255, 0, 204, 0.2) !important;
transition: all 0.3s ease !important;
border: 1px solid rgba(255, 0, 204, 0.2) !important;
z-index: 2 !important;
}
.image-component:hover {
transform: scale(1.01) !important;
box-shadow: 0 0 25px rgba(255, 0, 204, 0.3) !important;
}
/* Gallery Styling */
.gallery-component {
border-radius: 15px !important;
overflow: hidden !important;
box-shadow: 0 0 20px rgba(255, 0, 204, 0.2) !important;
background: rgba(255, 255, 255, 0.05) !important;
padding: 0.5rem !important;
z-index: 2 !important;
}
/* File Download Component */
.file-component {
background: rgba(255, 255, 255, 0.05) !important;
border: 1px dashed rgba(255, 0, 204, 0.3) !important;
border-radius: 10px !important;
padding: 1rem !important;
text-align: center !important;
transition: all 0.3s ease !important;
z-index: 2 !important;
}
.file-component:hover {
background: rgba(255, 255, 255, 0.1) !important;
transform: translateY(-2px) !important;
}
/* Analytics Dashboard */
.analytics-panel {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border-radius: 15px !important;
padding: 1rem !important;
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important;
z-index: 2 !important;
}
/* Tab Styling */
.gradio-tabs {
border: none !important;
background: transparent !important;
z-index: 2 !important;
}
.gradio-tab-item {
background: rgba(255, 255, 255, 0.05) !important;
border: 1px solid rgba(255, 0, 204, 0.2) !important;
border-radius: 8px !important;
color: #ccc !important;
padding: 0.5rem 1rem !important;
margin: 0 0.2rem !important;
transition: all 0.3s ease !important;
}
.gradio-tab-item.selected {
background: rgba(255, 255, 255, 0.15) !important;
color: #ff00cc !important;
font-weight: 600 !important;
border-color: #ff00cc !important;
box-shadow: 0 0 10px rgba(255, 0, 204, 0.5) !important;
}
.gradio-tab-item:hover {
background: rgba(255, 255, 255, 0.1) !important;
color: #ff00cc !important;
}
.gradio-tab-content {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border-radius: 15px !important;
padding: 1rem !important;
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important;
z-index: 2 !important;
}
/* Footer Styling */
.footer-info {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border: 1px solid rgba(255, 0, 204, 0.3) !important;
border-radius: 15px !important;
padding: 1rem !important;
margin-top: 1rem !important;
text-align: center !important;
color: #f0f0f5 !important;
z-index: 2 !important;
}
.feature-grid {
display: grid !important;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)) !important;
gap: 0.5rem !important;
margin-top: 1rem !important;
}
.feature-item {
background: rgba(255, 255, 255, 0.05) !important;
padding: 0.5rem !important;
border-radius: 8px !important;
text-align: center !important;
transition: all 0.3s ease !important;
color: #ccc !important;
}
.feature-item:hover {
background: rgba(255, 255, 255, 0.1) !important;
transform: translateY(-2px) !important;
color: #ff00cc !important;
}
/* Responsive Design */
@media (max-width: 768px) {
.gradio-container {
padding: 1rem !important;
}
.main-header {
padding: 0.8rem !important;
}
.header-title {
font-size: 2rem !important;
}
.professional-card {
padding: 0.8rem !important;
margin: 0.3rem 0 !important;
}
.section-header {
font-size: 1rem !important;
padding: 0.6rem !important;
}
.btn-primary, .btn-secondary, .gr-button {
padding: 8px 16px !important;
font-size: 0.9rem !important;
}
}
"""
# --- Gradio Interface ---
# Build the whole web UI and bind it to the module-level name `demo`.
# Gradio lays components out in creation order inside the enclosing
# Row/Column/Group context managers, so statement order and nesting below
# define the page layout.
with gr.Blocks(
    title="Dynamic Safety Violation Detection using CCTV + AI",
    css=enhanced_custom_css,
    # Soft theme with body/block/button backgrounds, borders and shadows
    # neutralised via .set(); presumably so the custom CSS gradients and
    # glass panels show through unobstructed.
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="emerald",
        neutral_hue="slate",
        radius_size="lg",
        spacing_size="sm",
        font=[
            gr.themes.GoogleFont("Poppins"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif"
        ]
    ).set(
        body_background_fill="none",
        block_background_fill="none",
        block_border_width="0px",
        block_shadow="none",
        block_radius="15px",
        button_primary_background_fill="none",
        button_primary_background_fill_hover="none",
        button_secondary_background_fill="none"
    )
) as demo:
    # Particle Canvas for Enhanced Star Animation
    # NOTE(review): gr.HTML inserts this markup as HTML. Browsers generally do
    # not execute <script> tags inserted that way, so the star animation may
    # never start. Confirm in the deployed Gradio version; if it doesn't run,
    # move the script into gr.Blocks(js=...) or gr.Blocks(head=...).
    gr.HTML("""
<canvas id="particle-canvas"></canvas>
<script>
const canvas = document.getElementById('particle-canvas');
const ctx = canvas.getContext('2d');
canvas.width = window.innerWidth;
canvas.height = window.innerHeight;
window.addEventListener('resize', () => {
canvas.width = window.innerWidth;
canvas.height = window.innerHeight;
});
const stars = [];
const starCount = 200;
class Star {
constructor() {
this.reset();
}
reset() {
this.x = Math.random() * canvas.width;
this.y = Math.random() * canvas.height;
this.size = Math.random() * 2 + 0.5;
this.speedX = Math.random() * 5 + 2;
this.speedY = (Math.random() - 0.5) * 0.5;
this.opacity = Math.random() * 0.5 + 0.5;
this.twinklePhase = Math.random() * Math.PI * 2;
this.twinkleSpeed = Math.random() * 0.05 + 0.02;
}
update() {
this.x -= this.speedX;
this.y += this.speedY;
this.twinklePhase += this.twinkleSpeed;
this.currentOpacity = this.opacity * (0.5 + 0.5 * Math.sin(this.twinklePhase));
if (this.x < 0 || this.y < 0 || this.y > canvas.height) {
this.reset();
}
}
draw() {
ctx.fillStyle = `rgba(255, 0, 204, ${this.currentOpacity})`;
ctx.beginPath();
ctx.arc(this.x, this.y, this.size, 0, Math.PI * 2);
ctx.fill();
ctx.fillStyle = `rgba(255, 0, 204, ${this.currentOpacity * 0.3})`;
ctx.beginPath();
ctx.arc(this.x, this.y, this.size * 2, 0, Math.PI * 2);
ctx.fill();
}
}
function initStars() {
for (let i = 0; i < starCount; i++) {
const star = new Star();
star.x = Math.random() * canvas.width;
stars.push(star);
}
}
function animateStars() {
ctx.clearRect(0, 0, canvas.width, canvas.height);
stars.forEach(star => {
star.update();
star.draw();
});
requestAnimationFrame(animateStars);
}
initStars();
animateStars();
</script>
""")
    # Professional Header
    gr.HTML("""
<div class="main-header">
<h1 class="header-title">๐Ÿ” Dynamic Safety Violation Detection using CCTV + AI</h1>
<p class="header-subtitle">Enhanced Multi-Person Tracking - Each violation type detected only once per person per video</p>
</div>
""")
    # Smart Media Analysis Section: upload widget plus the analyze button.
    gr.HTML('<div class="section-header">๐Ÿ“ท Smart Media Analysis</div>')
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card"]):
                media_input = gr.File(
                    label="๐Ÿ“ค Upload Image or Video for Safety Analysis",
                    file_types=["image", "video"],
                    elem_classes=["image-component"],
                    height=200
                )
                with gr.Row():
                    media_button = gr.Button(
                        "๐Ÿ” Analyze Media",
                        variant="primary",
                        elem_classes=["btn-primary"],
                        size="lg"
                    )
    # Analysis Results Section: annotated media (left), status and report (right).
    gr.HTML('<div class="section-header">๐Ÿ“Š Analysis Results & Violation Details</div>')
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card"]):
                media_output = gr.Gallery(
                    label="๐Ÿ–ผ๏ธ Processed Media with Detection Results",
                    elem_classes=["gallery-component"],
                    height=260
                )
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card"]):
                media_status = gr.Textbox(
                    label="๐Ÿ“‹ Analysis Status",
                    elem_classes=["status-display"],
                    lines=7,
                    max_lines=9,
                    value="๐Ÿ“Š Awaiting Media Analysis\n\nUpload an image/video and click 'Analyze Media' to begin safety violation detection.",
                    interactive=False
                )
                pdf_output = gr.File(
                    label="๐Ÿ“ฅ Download Professional Report",
                    elem_classes=["file-component"]
                )
    # Violation Details Section
    gr.HTML('<div class="section-header">๐Ÿšจ Real-time Violation Monitoring</div>')
    with gr.Group(elem_classes=["professional-card", "alert-panel"]):
        # The initial value is computed once, when the UI is built. It shows
        # `recent_violations` as of startup and is not re-read automatically.
        violation_log = gr.Textbox(
            label="๐Ÿšจ Real-time Violation Details",
            elem_classes=["status-display"],
            lines=10,
            max_lines=12,
            value=format_violations_as_text(recent_violations),
            interactive=False
        )
    # Live Stream Processing Section: RTSP URL, start/stop buttons and status
    # (left); gallery of processed frames (right).
    gr.HTML('<div class="section-header">๐Ÿ“น Live Stream Monitoring</div>')
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group(elem_classes=["professional-card"]):
                rtsp_url_input = gr.Textbox(
                    label="๐Ÿ“ก RTSP Stream URL",
                    placeholder="rtsp://example.com/stream",
                    value=RTSP_URL_DEFAULT,
                    interactive=True
                )
                with gr.Row():
                    rtsp_button = gr.Button(
                        "๐Ÿ“ก Start Live Monitoring",
                        variant="primary",
                        elem_classes=["btn-primary"],
                        size="lg"
                    )
                    rtsp_cancel_btn = gr.Button(
                        "โน๏ธ Stop Monitoring",
                        variant="secondary",
                        elem_classes=["btn-secondary"],
                        size="lg"
                    )
                rtsp_status = gr.Textbox(
                    label="๐Ÿ“บ Live Stream Processing Status",
                    elem_classes=["status-display"],
                    lines=6,
                    max_lines=8,
                    value="๐Ÿ“บ RTSP Stream Processor Ready\n\nEnter an RTSP URL and click 'Start Live Monitoring' to begin real-time monitoring.",
                    interactive=False
                )
        with gr.Column(scale=3):
            with gr.Group(elem_classes=["professional-card"]):
                rtsp_output = gr.Gallery(
                    label="๐ŸŽฌ Live Stream Frames & Detection Results",
                    elem_classes=["gallery-component"],
                    height=360,
                    columns=3,
                    rows=2,
                    object_fit="cover"
                )
    # Live Violation Log Section: text log (left); heatmap and report (right).
    gr.HTML('<div class="section-header">๐Ÿ“Š Live Violation Analytics</div>')
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card", "alert-panel"]):
                rtsp_violation_log = gr.Textbox(
                    label="๐Ÿšจ Live Violation Log",
                    elem_classes=["status-display"],
                    lines=8,
                    max_lines=10,
                    interactive=False
                )
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card", "analytics-panel"]):
                heatmap_output = gr.Image(
                    label="๐Ÿ”ฅ Violation Heatmap - Temporal Analysis",
                    elem_classes=["image-component"],
                    height=320
                )
                rtsp_pdf_output = gr.File(
                    label="๐Ÿ“ฅ Download RTSP Professional Report",
                    elem_classes=["file-component"]
                )
    # Professional Footer
    gr.HTML(f"""
<div class="footer-info">
<h3>๐Ÿ›ก๏ธ Dynamic Safety Violation Detection using CCTV + AI</h3>
<div class="feature-grid">
<div class="feature-item">
<strong>๐ŸŽฏ Real-time Detection</strong><br>
Advanced YOLOv8 AI with >90% accuracy
</div>
<div class="feature-item">
<strong>โšก Ultra-fast Response</strong><br>
Alert generation in <5 seconds
</div>
<div class="feature-item">
<strong>๐Ÿ“ฑ Responsive Design</strong><br>
Optimized for desktop, tablet & mobile
</div>
<div class="feature-item">
<strong>๐Ÿ”— Salesforce Integration</strong><br>
Reports are uploaded directly to Salesforce
</div>
</div>
<div style="margin-top: 0.8rem; padding-top: 0.8rem; border-top: 0.5px solid rgba(255,255,255,0.2);">
<p style="margin: 0; font-size: 0.8rem; opacity: 0.7;">
Dynamic Safety Violation Detection using CCTV + AI ยฉ 2025 | Cloud-integrated safety monitoring.
</p>
</div>
</div>
""")
    # Event Handlers
    # Gradio maps a handler's return values positionally onto `outputs`.
    # process_video is therefore expected to return
    # (gallery frames, status text, report file, violation text).
    media_process_event = media_button.click(
        fn=process_video,
        inputs=media_input,
        outputs=[media_output, media_status, pdf_output, violation_log]
    )
    # process_rtsp_stream returns a 5-tuple (status, frames, violation text,
    # heatmap path, pdf path) in exactly this output order.
    rtsp_process_event = rtsp_button.click(
        fn=process_rtsp_stream,
        inputs=[rtsp_url_input],
        outputs=[rtsp_status, rtsp_output, rtsp_violation_log, heatmap_output, rtsp_pdf_output]
    )
    # Stop works two ways: cancel_processing clears the loop's flag, and
    # `cancels=` asks Gradio to cancel the running RTSP event.
    rtsp_cancel_btn.click(
        fn=cancel_processing,
        outputs=[rtsp_status],
        cancels=[rtsp_process_event]
    )
if __name__ == "__main__":
    # scipy is a hard dependency: the module top already runs
    # `from scipy.spatial import distance` unconditionally, so a missing scipy
    # fails at import time, before this block could run. Install it through
    # the project requirements rather than pip-installing at runtime.
    # NOTE(review): share=True publishes a public *.gradio.live tunnel and
    # auth=None leaves the app unauthenticated. Confirm that is intended for a
    # CCTV-monitoring app that pushes data to Salesforce.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
        quiet=False,
        favicon_path=None,
        auth=None,
        inbrowser=True
    )