PrashanthB461's picture
Update app.py
5a3efbb verified
import os
import sys
import warnings
import gradio as gr
import torch
from ultralytics import YOLO
import cv2
import requests
import json
import time
import numpy as np
from pathlib import Path
from datetime import datetime
import logging
import pandas as pd
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from io import BytesIO
import seaborn as sns
import matplotlib.pyplot as plt
import subprocess
from datetime import timezone
import pytz
import shutil
import tempfile
from scipy.spatial import distance
import asyncio
from functools import partial
from concurrent.futures import ThreadPoolExecutor
from simple_salesforce import Salesforce, SalesforceAuthenticationFailed
from retrying import retry
import base64
# --- Initial Configuration ---
# Suppress all Python warnings for the rest of the process.
warnings.filterwarnings("ignore")

# Local path of the YOLOv8-nano weights file (downloaded below if missing).
MODEL_PATH = "./yolov8n.pt"

# Configure logging: INFO level, timestamped records, one stream handler.
_log_handlers = [logging.StreamHandler()]
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
# Check for GPU availability: log whether PyTorch sees CUDA and, if so, GPU 0's name.
_cuda_available = torch.cuda.is_available()
logger.info(f"PyTorch CUDA Available: {_cuda_available}")
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}" if _cuda_available else "Using CPU")
# Download model weights if needed.
# The file is streamed to a temporary ".part" sibling and renamed into place only
# after the download completes. Otherwise a failed or interrupted download would
# leave a truncated MODEL_PATH behind, and the existence check above would then
# skip re-downloading on every later start.
if not os.path.exists(MODEL_PATH):
    logger.info(f"Downloading model weights to {MODEL_PATH}...")
    download_url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt"
    partial_path = f"{MODEL_PATH}.part"
    try:
        with requests.get(download_url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as weights_file:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    weights_file.write(chunk)
        # Atomic rename: MODEL_PATH either does not exist or is complete.
        os.replace(partial_path, MODEL_PATH)
        os.chmod(MODEL_PATH, 0o644)
        logger.info("Model weights downloaded successfully")
    except (requests.RequestException, OSError) as e:
        logger.error(f"Failed to download model: {e}")
        if os.path.exists(partial_path):
            os.remove(partial_path)
        sys.exit(1)
# Set up YOLO config directory: create ./Ultralytics and point YOLO_CONFIG_DIR at it.
# NOTE(review): `ultralytics` is imported at the top of this file, before this
# variable is set. If ultralytics reads YOLO_CONFIG_DIR only when its settings are
# first loaded (at import), this assignment has no effect. Verify, and if so move
# it above `from ultralytics import YOLO`.
yolo_config_dir = "./Ultralytics"
Path(yolo_config_dir).mkdir(parents=True, exist_ok=True)
os.environ["YOLO_CONFIG_DIR"] = yolo_config_dir
# --- Environment Variables ---
RTSP_URL_DEFAULT = os.getenv("RTSP_URL", "")
SALESFORCE_URL = os.getenv("SALESFORCE_URL", "")
SALESFORCE_TOKEN = os.getenv("SALESFORCE_TOKEN", "")
HUGGINGFACE_API_URL = os.getenv("HUGGINGFACE_API_URL", "")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "")
# Salesforce login credentials. They must come from the environment (e.g.
# deployment secrets), never from source code. The values previously hard-coded
# here were committed to version control and must be treated as compromised:
# rotate the password and reset the security token.
SF_USERNAME = os.getenv("SF_USERNAME", "")
SF_PASSWORD = os.getenv("SF_PASSWORD", "")
SF_SECURITY_TOKEN = os.getenv("SF_SECURITY_TOKEN", "")
# Salesforce integration is attempted only when all three credentials are set.
SALESFORCE_ENABLED = all([SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN])
# --- Time Zone Configuration ---
# Wall-clock timestamps written to reports and records in this module use IST.
IST = pytz.timezone("Asia/Kolkata")

# --- Global Variables ---
yolo_model = None                 # YOLO instance, assigned by initialize_model()
recent_violations: list = []
violation_history: list = []
processing_active = False         # set True by process_video while it runs
sf_connection = None              # cached Salesforce client (get_salesforce_connection)
site_id_cache: dict = {}          # Site__c Name -> Salesforce record Id
# --- Model Initialization ---
def initialize_model():
    """Load the YOLOv8 weights at MODEL_PATH into the global ``yolo_model``.

    Returns:
        bool: True if the model loaded; False if loading raised (the error is
        logged rather than propagated).
    """
    global yolo_model
    try:
        logger.info("Initializing YOLOv8 model...")
        yolo_model = YOLO(MODEL_PATH)
        logger.info("YOLOv8 model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")
        return False
    return True


# Abort startup if the detector cannot be loaded: media analysis depends on it.
if not initialize_model():
    logger.error("Critical error: Model initialization failed")
    sys.exit(1)
# --- Salesforce Integration Functions ---
class _SalesforceNotConfiguredError(ConnectionError):
    """Raised when Salesforce credentials are missing; retrying cannot help."""


def _is_retryable_salesforce_error(exc):
    """Return True if ``get_salesforce_connection`` should retry after *exc*.

    Two failures are permanent and are not retried:
    - Authentication failures: a rejected login will not succeed on repeat, and
      repeated bad logins may count toward the org's invalid-login lockout policy.
    - Missing credentials.
    Everything else (network errors, transient API errors) is retried.
    """
    return not isinstance(exc, (SalesforceAuthenticationFailed, _SalesforceNotConfiguredError))


@retry(stop_max_attempt_number=3, wait_fixed=2000,
       retry_on_exception=_is_retryable_salesforce_error)
def get_salesforce_connection():
    """Return a live, cached Salesforce connection, reconnecting if needed.

    A cached connection is validated with a cheap query. If that query fails,
    the cache is dropped and a fresh login is performed. Retryable failures
    are attempted up to 3 times, 2 s apart.

    Returns:
        The connected ``Salesforce`` client, also stored in the module-level
        ``sf_connection``.

    Raises:
        ConnectionError: credentials are not configured (not retried).
        SalesforceAuthenticationFailed: the login was rejected (not retried).
        Exception: any other error from the last failed attempt.
    """
    global sf_connection
    if sf_connection:
        try:
            sf_connection.query("SELECT Id FROM User LIMIT 1")
            logger.info("Salesforce connection is active.")
            return sf_connection
        except Exception:
            logger.warning("Salesforce session expired. Reconnecting...")
            sf_connection = None
    if not SALESFORCE_ENABLED:
        raise _SalesforceNotConfiguredError("Salesforce credentials are not configured.")
    try:
        sf_connection = Salesforce(username=SF_USERNAME, password=SF_PASSWORD, security_token=SF_SECURITY_TOKEN)
        logger.info(f"Successfully connected to Salesforce instance: {sf_connection.sf_instance}")
        return sf_connection
    except SalesforceAuthenticationFailed as e:
        logger.error(f"Salesforce authentication failed: {e}. Check credentials and IP restrictions.")
        raise
    except Exception as e:
        logger.error(f"Failed to connect to Salesforce: {e}")
        raise
def _escape_soql_literal(value):
    """Escape *value* for use inside a single-quoted SOQL string literal.

    SOQL uses backslash escapes. Backslash is escaped first so the
    backslashes added for quotes are not escaped again.
    """
    return str(value).replace("\\", "\\\\").replace("'", "\\'")


def get_or_create_site_id(sf, site_name='SITE001'):
    """
    Return the Salesforce Id of the Site__c record named *site_name*.

    The record is looked up by Name and created if absent. Successful results
    are cached in the module-level ``site_id_cache``, so each name is resolved
    against Salesforce at most once per process.

    Args:
        sf: connected Salesforce client.
        site_name: Site__c Name to look up. It is escaped before being embedded
            in SOQL, so quotes in names are safe.

    Returns:
        The record Id, or None if the lookup or creation failed (logged).
    """
    if site_name in site_id_cache:
        return site_id_cache[site_name]
    try:
        query = f"SELECT Id FROM Site__c WHERE Name = '{_escape_soql_literal(site_name)}' LIMIT 1"
        result = sf.query(query)
        if result['totalSize'] > 0:
            site_id = result['records'][0]['Id']
            logger.info(f"Found existing Site '{site_name}' with ID: {site_id}")
        else:
            logger.info(f"Site '{site_name}' not found. Creating new Site record...")
            create_result = sf.Site__c.create({'Name': site_name})
            if 'id' not in create_result:
                logger.error(f"Failed to create Site record: {create_result.get('errors')}")
                return None
            site_id = create_result['id']
            logger.info(f"Successfully created new Site '{site_name}' with ID: {site_id}")
        site_id_cache[site_name] = site_id
        return site_id
    except Exception as e:
        logger.error(f"Error getting or creating Site ID for '{site_name}': {e}", exc_info=True)
        return None
def create_salesforce_violation_record(sf, violation_data):
    """
    Build the field payload for one Safety_Violation_Log__c record.

    No Safety_Violation_Log__c record is written here; the payload is meant to
    be collected for a later bulk insert. Resolving the site through
    ``get_or_create_site_id`` may, however, create a Site__c record.

    Args:
        sf: connected Salesforce client, passed to the site lookup.
        violation_data: dict describing the violation. Missing optional keys
            fall back to defaults ('Default Site', 'Medium', 'N/A', 'CAM001',
            and a placeholder PDF-URL message).

    Returns:
        tuple: ``(payload, None)`` on success, with None-valued fields dropped;
        ``(None, error_message)`` if the site could not be resolved or any
        error occurred (the error is logged).
    """
    try:
        site_name = violation_data.get('site_id', 'Default Site')
        sf_site_id = get_or_create_site_id(sf, site_name)
        if not sf_site_id:
            logger.error(f"Failed to get or create Site record '{site_name}' in Salesforce.")
            return None, "Site ID creation/retrieval failed."
        lookup = violation_data.get
        field_values = (
            ('Site_ID__c', sf_site_id),
            ('Violation_Type__c', lookup('violation_type')),
            ('Severity__c', lookup('severity', 'Medium')),
            ('Timestamp__c', lookup('timestamp')),
            ('Snapshot_URL__c', lookup('snapshot_url', 'N/A')),
            ('Worker_ID__c', lookup('worker_id', 'N/A')),
            ('Camera_ID__c', lookup('camera_id', 'CAM001')),
            ('Alert_Sent__c', True),
            ('PDF_Report_URL__c', lookup('pdf_url', "Report will be available after processing completion.")),
        )
        # Omit fields that have no value rather than sending explicit nulls.
        payload = {field: value for field, value in field_values if value is not None}
        return payload, None
    except Exception as e:
        logger.error(f"Failed to prepare Salesforce record: {e}", exc_info=True)
        return None, str(e)
def _render_salesforce_report_pdf(violations):
    """Render *violations* as a paginated letter-size PDF and return its bytes.

    Before each entry is drawn, the code checks whether its last line would fall
    below the bottom margin; if so, the whole entry moves to a new page. Text
    therefore never lands off-page.
    """
    top, bottom = 750, 50
    buffer = BytesIO()
    try:
        c = canvas.Canvas(buffer, pagesize=letter)
        c.setFont("Helvetica-Bold", 16)
        c.drawString(100, 750, "Safety Violation Report")
        c.setFont("Helvetica", 12)
        c.drawString(100, 730, f"Generated: {datetime.now(IST).strftime('%Y-%m-%d %H:%M:%S IST')}")
        c.setFont("Helvetica", 10)
        c.drawString(100, 710, "Note: Each violation type reported only once per person per session.")
        y = 680
        for i, violation in enumerate(violations, 1):
            details = [
                f"Severity: {violation['severity']}",
                f"Time: {violation['timestamp']}",
                f"Worker: {violation.get('worker_id', 'UNKNOWN')}",
            ]
            if 'distance' in violation:
                details.append(f"Distance: {violation['distance']}")
            # The last detail line sits 20pt below the title, plus 15pt per extra line.
            if y - 20 - 15 * (len(details) - 1) < bottom:
                c.showPage()
                y = top
            c.setFont("Helvetica-Bold", 12)
            c.drawString(100, y, f"Violation #{i}: {violation['violation_type']}")
            y -= 20
            c.setFont("Helvetica", 10)
            for text in details:
                c.drawString(120, y, text)
                y -= 15
            y -= 20
        c.save()
        return buffer.getvalue()
    finally:
        buffer.close()


def generate_and_upload_report_to_salesforce(sf, violations, record_ids):
    """
    Build the PDF report, attach it in Salesforce, and stamp its URL on records.

    Steps:
    1. Render the PDF.
    2. Upload it as a ContentVersion.
    3. Look up the ContentVersion's ContentDocumentId.
    4. Link that document to every record in *record_ids* with viewer access.
    5. Bulk-update ``PDF_Report_URL__c`` on those records with the download URL.
    6. Save a local temporary copy for the Gradio download.

    Args:
        sf: connected Salesforce client.
        violations: violation dicts (``violation_type``, ``severity``,
            ``timestamp``, optional ``worker_id`` / ``distance``).
        record_ids: Ids of the Safety_Violation_Log__c records to link and update.

    Returns:
        tuple: ``(local_pdf_path, salesforce_pdf_url)``, or ``(None, None)`` if
        any argument is empty/falsy or a step fails (errors are logged). The
        local file is not deleted here.
    """
    if not violations or not record_ids or not sf:
        logger.warning("No violations, record IDs, or Salesforce connection. Skipping report generation.")
        return None, None
    try:
        pdf_bytes = _render_salesforce_report_pdf(violations)

        # Upload the PDF as a ContentVersion; VersionData must be base64-encoded.
        title = f"Safety_Report_{datetime.now(IST).strftime('%Y%m%d_%H%M%S')}"
        logger.info(f"Uploading PDF '{title}.pdf' to Salesforce...")
        cv_result = sf.ContentVersion.create({
            'Title': title,
            'PathOnClient': f'{title}.pdf',
            'VersionData': base64.b64encode(pdf_bytes).decode('utf-8'),
        })
        if not cv_result.get('success'):
            logger.error(f"Failed to create ContentVersion: {cv_result.get('errors')}")
            return None, None
        content_version_id = cv_result['id']
        logger.info(f"Successfully created ContentVersion with ID: {content_version_id}")

        # Links target the parent ContentDocument, not the version.
        cv_query_result = sf.query(
            f"SELECT ContentDocumentId FROM ContentVersion WHERE Id = '{content_version_id}'"
        )
        if not cv_query_result['records']:
            logger.error(f"Could not find ContentDocumentId for ContentVersion {content_version_id}")
            return None, None
        content_document_id = cv_query_result['records'][0]['ContentDocumentId']

        # Link the document to each violation record (best effort, per record).
        logger.info(f"Linking ContentDocument {content_document_id} to {len(record_ids)} records...")
        link_success_count = 0
        for record_id in record_ids:
            try:
                link_result = sf.ContentDocumentLink.create({
                    'ContentDocumentId': content_document_id,
                    'LinkedEntityId': record_id,
                    'ShareType': 'V',  # V = Viewer
                })
                if link_result.get('success'):
                    link_success_count += 1
                else:
                    logger.warning(f"Failed to link to {record_id}: {link_result.get('errors')}")
            except Exception as e:
                logger.error(f"Error creating ContentDocumentLink for {record_id}: {e}")
        logger.info(f"Successfully created {link_success_count}/{len(record_ids)} links.")

        # Stamp the download URL on every record.
        sf_instance_url = sf.sf_instance.replace('https://', '')
        pdf_url = f"https://{sf_instance_url}/sfc/servlet.shepherd/version/download/{content_version_id}"
        logger.info(f"Updating records with Salesforce PDF URL: {pdf_url}")
        update_payloads = [{'Id': record_id, 'PDF_Report_URL__c': pdf_url} for record_id in record_ids]
        update_results = sf.bulk.Safety_Violation_Log__c.update(update_payloads)
        successful_updates = sum(1 for res in update_results if res.get('success'))
        logger.info(f"Successfully updated {successful_updates}/{len(record_ids)} records with the PDF URL.")

        # Keep a local copy for the Gradio download component.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf', prefix='report_') as temp_pdf:
            temp_pdf.write(pdf_bytes)
            temp_pdf_path = temp_pdf.name
        logger.info(f"Salesforce report URL: {pdf_url}")
        logger.info(f"Temporary local report for download: {temp_pdf_path}")
        return temp_pdf_path, pdf_url
    except Exception as e:
        logger.error(f"Error in Salesforce PDF report generation/upload: {e}", exc_info=True)
        return None, None
# --- Enhanced Safety Violation Detector Class with Group Detection ---
class SafetyViolationDetector:
    """Detect helmet, restricted-zone and proximity violations in YOLO results.

    People are tracked across frames with a lightweight centre-distance + IoU
    matcher. Each violation type is reported at most once per tracked person
    per session (see ``session_violations``). Reported violations are also
    drawn onto the frame.

    NOTE(review): the weights this file downloads (stock Ultralytics
    ``yolov8n.pt``) are COCO-trained and have no "hard hat" class. With them,
    ``helmets`` is always empty and every tracked person produces a
    ``no_helmet`` violation. A PPE-trained checkpoint is needed for the helmet
    check to be meaningful.

    NOTE(review): pixel thresholds, zones and annotations all use the
    coordinate space of the boxes in ``results``, i.e. the image given to
    ``predict``. ``frame`` should therefore have that same size.
    """

    def __init__(self):
        # Detection thresholds
        self.helmet_threshold = 0.75   # min confidence for a "hard hat" box
        self.person_threshold = 0.60   # min confidence for a "person" box
        self.unsafe_distance = 50      # pixels between person centres
        self.violation_cooldown = 20   # seconds
        # Unauthorized zones (x1, y1, x2, y2) in detection pixel coordinates
        self.unauthorized_zones = [
            [100, 100, 300, 300],  # Example zone 1
            [400, 200, 600, 400],  # Example zone 2
        ]
        # violation_key -> bookkeeping for violations inside their cooldown
        self.active_violations = {}
        self.violation_history = {}  # not used within this class
        # person_id -> {'first_seen', 'last_seen', 'positions', 'helmet_status', 'violations'}
        self.person_tracker = {}
        # person_id -> recent centres/boxes (max 10) used for ID matching
        self.person_positions_history = {}
        self.next_person_id = 1
        self.max_tracking_distance = 120  # pixels; farther tracks never match
        # person_id -> {violation_type: {'first_detected', 'count'}}
        self.session_violations = {}

    def reset_session(self):
        """Forget all tracks and reported violations (start of a new media file)."""
        self.session_violations = {}
        self.active_violations = {}
        self.person_tracker = {}
        self.person_positions_history = {}
        self.next_person_id = 1
        logger.info("Session violation tracking reset for new video")

    def has_reported_violation(self, person_id, violation_type):
        """Return True if *violation_type* was already reported for *person_id*."""
        if person_id not in self.session_violations:
            return False
        return violation_type in self.session_violations[person_id]

    def mark_violation_reported(self, person_id, violation_type, timestamp):
        """Record that *violation_type* has been reported for *person_id*."""
        if person_id not in self.session_violations:
            self.session_violations[person_id] = {}
        self.session_violations[person_id][violation_type] = {
            'first_detected': timestamp,
            'count': self.session_violations[person_id].get(violation_type, {}).get('count', 0) + 1
        }

    def _get_stable_person_id(self, box, current_time, exclude=None):
        """Return a track ID for *box*, reusing a nearby existing track if possible.

        Candidate tracks are those whose last centre lies within
        ``max_tracking_distance`` px. Each candidate is scored as
        ``0.7 / (1 + dist / 50) + 0.3 * IoU(box, last_box)``. The highest score
        above 0.3 wins; if none qualifies, a new ID is allocated. The chosen
        track's history is extended and capped at its 10 most recent entries.

        Args:
            box: (x1, y1, x2, y2) detection box.
            current_time: epoch seconds, stored as ``last_seen``.
            exclude: IDs already assigned to other detections in the same
                frame. These are never matched, so two people in one frame
                cannot share an ID.
        """
        center = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
        taken = exclude if exclude is not None else ()
        best_match_id = None
        best_match_score = 0
        for person_id, history in self.person_positions_history.items():
            if person_id in taken or not history['positions']:
                continue
            last_x, last_y = history['positions'][-1]
            dist = np.sqrt((center[0] - last_x) ** 2 + (center[1] - last_y) ** 2)
            if dist >= self.max_tracking_distance:
                continue
            score = (1.0 / (1.0 + dist / 50)) * 0.7 + self._iou(box, history['boxes'][-1]) * 0.3
            if score > best_match_score and score > 0.3:
                best_match_score = score
                best_match_id = person_id
        if best_match_id is None:
            person_id = self.next_person_id
            self.next_person_id += 1
            self.person_positions_history[person_id] = {
                'positions': [],
                'boxes': [],
                'first_seen': current_time,
                'last_seen': current_time
            }
        else:
            person_id = best_match_id
        history = self.person_positions_history[person_id]
        history['positions'].append(center)
        history['boxes'].append(box)
        history['last_seen'] = current_time
        if len(history['positions']) > 10:
            history['positions'].pop(0)
            history['boxes'].pop(0)
        return person_id

    def detect_violations(self, results, frame):
        """Find violations in one frame's YOLO results and annotate *frame*.

        Args:
            results: Ultralytics results list; only ``results[0]`` is read.
            frame: image that annotations are drawn on (modified in place).

        Returns:
            list[dict]: violations newly reported in this frame, each with
            ``type``, ``severity``, ``timestamp`` and person / ID fields.
        """
        start_time = time.time()
        current_time = start_time
        violations = []
        boxes = results[0].boxes.xyxy.cpu().numpy()
        confidences = results[0].boxes.conf.cpu().numpy()
        class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
        class_names = results[0].names
        persons = []
        helmets = []
        # IDs already claimed by a detection in this frame.
        current_person_ids = set()
        for box, conf, cls_id in zip(boxes, confidences, class_ids):
            class_name = class_names[cls_id]
            if class_name == "person" and conf >= self.person_threshold:
                person_id = self._get_stable_person_id(box, current_time, exclude=current_person_ids)
                current_person_ids.add(person_id)
                persons.append({
                    'box': box,
                    'confidence': conf,
                    'center': ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2),
                    'id': person_id
                })
            elif class_name == "hard hat" and conf >= self.helmet_threshold:
                helmets.append({
                    'box': box,
                    'confidence': conf,
                    'area': (box[2] - box[0]) * (box[3] - box[1])
                })
        for person in persons:
            person_id = person['id']
            tracked = self.person_tracker.get(person_id)
            if tracked is None:
                self.person_tracker[person_id] = {
                    'first_seen': current_time,
                    'last_seen': current_time,
                    'positions': [person['center']],
                    'helmet_status': False,
                    'violations': {}
                }
            else:
                tracked['last_seen'] = current_time
                tracked['positions'].append(person['center'])
                if len(tracked['positions']) > 10:
                    tracked['positions'].pop(0)
        for person in persons:
            helmet_violation = self._check_helmet_violation(person, helmets, frame, current_time)
            if helmet_violation:
                violations.append(helmet_violation)
            unauthorized_violation = self._check_unauthorized_area(person, frame, current_time)
            if unauthorized_violation:
                violations.append(unauthorized_violation)
        violations.extend(self._check_distance_violations(persons, frame, current_time))
        self._cleanup_violations(current_time)
        self._cleanup_inactive_persons(current_person_ids, current_time)
        logger.info(f"Violation detection time: {time.time() - start_time:.2f}s")
        return violations

    def _record_tracker_violation(self, person_id, violation_type, current_time):
        """Increment the per-person counter for *violation_type* in ``person_tracker``."""
        stats = self.person_tracker[person_id]['violations'].setdefault(
            violation_type, {'count': 0, 'last_time': 0}
        )
        stats['count'] += 1
        stats['last_time'] = current_time

    def _check_helmet_violation(self, person, helmets, frame, current_time):
        """Report a ``no_helmet`` violation if no helmet overlaps the person's head.

        The head region is the top 30% of the person box. A helmet counts if its
        IoU with that region exceeds 0.1. The person's ``helmet_status`` is
        updated either way.
        """
        person_id = person['id']
        person_box = person['box']
        violation_type = 'no_helmet'
        if self.has_reported_violation(person_id, violation_type):
            return None
        x1, y1, x2, y2 = person_box
        head_region = [x1, y1, x2, y1 + (y2 - y1) * 0.3]
        has_helmet = any(self._iou(helmet['box'], head_region) > 0.1 for helmet in helmets)
        self.person_tracker[person_id]['helmet_status'] = has_helmet
        if has_helmet:
            return None
        violation_key = f"no_helmet_{person_id}"
        if (violation_key not in self.active_violations or
                current_time - self.active_violations[violation_key]['last_detected'] > self.violation_cooldown):
            self.mark_violation_reported(person_id, violation_type, current_time)
            self.active_violations[violation_key] = {
                'type': 'no_helmet',
                'person_id': person_id,
                'first_detected': current_time,
                'last_detected': current_time,
                'count': 1
            }
            self._record_tracker_violation(person_id, 'no_helmet', current_time)
            self._annotate_frame(frame, person_box, person_id, "NO HELMET", (0, 0, 255))
            logger.info(f"NEW VIOLATION: No helmet detected for person {person_id}")
            return {
                'type': 'no_helmet',
                'severity': 'Critical',
                'person': person,
                'person_id': person_id,
                'timestamp': current_time
            }
        self.active_violations[violation_key]['last_detected'] = current_time
        self.active_violations[violation_key]['count'] += 1
        return None

    def _check_unauthorized_area(self, person, frame, current_time):
        """Report an ``unauthorized_area`` violation if the person's centre is in a zone."""
        person_id = person['id']
        violation_type = 'unauthorized_area'
        if self.has_reported_violation(person_id, violation_type):
            return None
        x1, y1, x2, y2 = person['box']
        person_center = ((x1 + x2) / 2, (y1 + y2) / 2)
        for zone in self.unauthorized_zones:
            zx1, zy1, zx2, zy2 = zone
            if not (zx1 <= person_center[0] <= zx2 and zy1 <= person_center[1] <= zy2):
                continue
            violation_key = f"unauthorized_area_{person_id}_{zx1}_{zy1}"
            if (violation_key not in self.active_violations or
                    current_time - self.active_violations[violation_key]['last_detected'] > self.violation_cooldown):
                self.mark_violation_reported(person_id, violation_type, current_time)
                self.active_violations[violation_key] = {
                    'type': 'unauthorized_area',
                    'person_id': person_id,
                    'zone': zone,
                    'first_detected': current_time,
                    'last_detected': current_time,
                    'count': 1
                }
                self._record_tracker_violation(person_id, 'unauthorized_area', current_time)
                cv2.rectangle(frame, (zx1, zy1), (zx2, zy2), (255, 0, 255), 2)
                self._annotate_frame(frame, person['box'], person_id, "UNAUTHORIZED", (255, 0, 255))
                logger.info(f"NEW VIOLATION: Unauthorized area detected for person {person_id}")
                return {
                    'type': 'unauthorized_area',
                    'severity': 'High',
                    'person': person,
                    'person_id': person_id,
                    'zone': zone,
                    'timestamp': current_time
                }
            self.active_violations[violation_key]['last_detected'] = current_time
            self.active_violations[violation_key]['count'] += 1
        return None

    def _check_distance_violations(self, persons, frame, current_time):
        """Report ``unsafe_distance`` for each pair of people closer than ``unsafe_distance`` px.

        A pair is skipped if either person already has this violation reported.
        """
        violations = []
        if len(persons) < 2:
            return violations
        violation_type = 'unsafe_distance'
        for i, first in enumerate(persons):
            for second in persons[i + 1:]:
                dist = self._euclidean_distance(first['center'], second['center'])
                if dist >= self.unsafe_distance:
                    continue
                person1_id = first['id']
                person2_id = second['id']
                if (self.has_reported_violation(person1_id, violation_type) or
                        self.has_reported_violation(person2_id, violation_type)):
                    continue
                pair_key = f"{min(person1_id, person2_id)}_{max(person1_id, person2_id)}"
                violation_key = f"unsafe_distance_{pair_key}"
                if (violation_key not in self.active_violations or
                        current_time - self.active_violations[violation_key]['last_detected'] > self.violation_cooldown):
                    self.mark_violation_reported(person1_id, violation_type, current_time)
                    self.mark_violation_reported(person2_id, violation_type, current_time)
                    self.active_violations[violation_key] = {
                        'type': 'unsafe_distance',
                        'person1_id': person1_id,
                        'person2_id': person2_id,
                        'first_detected': current_time,
                        'last_detected': current_time,
                        'count': 1
                    }
                    for pid in (person1_id, person2_id):
                        self._record_tracker_violation(pid, 'unsafe_distance', current_time)
                    self._annotate_distance(frame, first['box'], second['box'],
                                            person1_id, person2_id, dist)
                    logger.info(f"NEW VIOLATION: Unsafe distance detected between persons {person1_id} and {person2_id}")
                    violations.append({
                        'type': 'unsafe_distance',
                        'severity': 'Moderate',
                        'person1': first,
                        'person2': second,
                        'distance': dist,
                        'person1_id': person1_id,
                        'person2_id': person2_id,
                        'timestamp': current_time
                    })
                else:
                    self.active_violations[violation_key]['last_detected'] = current_time
                    self.active_violations[violation_key]['count'] += 1
        return violations

    def _cleanup_violations(self, current_time):
        """Drop ``active_violations`` entries not seen within the cooldown."""
        expired = [
            key for key, value in self.active_violations.items()
            if current_time - value['last_detected'] > self.violation_cooldown
        ]
        for key in expired:
            del self.active_violations[key]

    def _cleanup_inactive_persons(self, current_person_ids, current_time):
        """Forget tracks absent from this frame and unseen for more than 60 s."""
        inactive_timeout = 60
        expired = [
            pid for pid, data in self.person_tracker.items()
            if pid not in current_person_ids and current_time - data['last_seen'] > inactive_timeout
        ]
        for pid in expired:
            del self.person_tracker[pid]
            self.person_positions_history.pop(pid, None)

    def _annotate_frame(self, frame, box, person_id, violation_type, color):
        """Draw *box* and an ``ID:<nnn> <label>`` caption on *frame*."""
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        label = f"ID:{person_id:03d} {violation_type}"
        cv2.putText(frame, label, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    def _annotate_distance(self, frame, box1, box2, id1, id2, dist):
        """Draw both boxes, a line between their centres, and the pixel distance."""
        x1, y1, x2, y2 = map(int, box1)
        x3, y3, x4, y4 = map(int, box2)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 165, 255), 2)
        cv2.rectangle(frame, (x3, y3), (x4, y4), (0, 165, 255), 2)
        center1 = ((x1 + x2) // 2, (y1 + y2) // 2)
        center2 = ((x3 + x4) // 2, (y3 + y4) // 2)
        cv2.line(frame, center1, center2, (0, 165, 255), 2)
        mid_point = ((center1[0] + center2[0]) // 2, (center1[1] + center2[1]) // 2)
        cv2.putText(frame, f"{dist:.1f}px", mid_point,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 165, 255), 2)

    def _iou(self, box1, box2):
        """Return the intersection-over-union of two (x1, y1, x2, y2) boxes."""
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        intersection = max(0, x2 - x1) * max(0, y2 - y1)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        return intersection / (area1 + area2 - intersection + 1e-6)

    def _euclidean_distance(self, point1, point2):
        """Return the Euclidean distance between two 2-D points."""
        return np.sqrt((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2)

    def get_session_summary(self):
        """Summarise reported violations for this session.

        ``total_persons`` counts only persons with at least one reported
        violation, not every person tracked.
        """
        summary = {
            'total_persons': len(self.session_violations),
            'violations_by_type': {},
            'persons_with_violations': []
        }
        for person_id, violations in self.session_violations.items():
            summary['persons_with_violations'].append({
                'person_id': person_id,
                'violations': list(violations.keys()),
                'violation_count': len(violations)
            })
            for violation_type in violations:
                by_type = summary['violations_by_type']
                by_type[violation_type] = by_type.get(violation_type, 0) + 1
        return summary
# --- Frame Processing Functions ---
def preprocess_frame(frame, target_size=None):
    """Prepare an OpenCV (BGR) frame for YOLO inference.

    Applies a fixed contrast/brightness boost (alpha=1.2, beta=20) and,
    optionally, a resize.

    The frame stays in BGR order. Ultralytics treats numpy-array sources as
    BGR, so converting to RGB here would feed the model swapped red and blue
    channels.

    By default the frame is not resized. YOLO letterboxes its input to the
    inference size itself and maps boxes back onto the input's pixel grid, so
    passing the full-size frame keeps detections in the same coordinates as the
    frame that gets annotated. A fixed square resize would put boxes in a
    different, aspect-distorted coordinate space.

    Args:
        frame: BGR image as returned by OpenCV.
        target_size: optional ``(width, height)`` to resize to; None keeps the
            original size.

    Returns:
        The preprocessed BGR image.

    Raises:
        Exception: any OpenCV error is logged and re-raised.
    """
    try:
        frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)  # Increase contrast
        if target_size is not None:
            frame = cv2.resize(frame, tuple(target_size))
        return frame
    except Exception as e:
        logger.error(f"Frame preprocessing error: {e}")
        raise
def capture_rtsp_frames(rtsp_url, max_frames=None):
    """Yield frames from an RTSP stream until it ends or *max_frames* is reached.

    Yields:
        tuple: ``(frame, timestamp, frame_count, None)``, where ``timestamp`` is
        an ISO-8601 string in IST, ``frame_count`` is 1-based, and the fourth
        element is always None.

    Raises:
        ValueError: the stream could not be opened.
    """
    cap = None
    try:
        logger.info(f"Connecting to RTSP stream: {rtsp_url}")
        cap = cv2.VideoCapture(rtsp_url)
        if not cap.isOpened():
            raise ValueError(f"RTSP stream not accessible: {rtsp_url}")
        frame_count = 0
        while cap.isOpened() and (max_frames is None or frame_count < max_frames):
            ret, frame = cap.read()
            if not ret:
                logger.warning("Failed to read frame from RTSP stream")
                break
            timestamp = datetime.now(IST).isoformat()
            frame_count += 1
            yield frame, timestamp, frame_count, None
    except Exception as e:
        logger.error(f"RTSP capture error: {e}")
        raise
    finally:
        # Release on every exit path (normal end, error, or the consumer closing
        # the generator early) so the stream connection is not leaked. No
        # windows are created here, so cv2.destroyAllWindows() is not called; it
        # can raise on OpenCV builds without GUI support.
        if cap is not None:
            cap.release()
# --- Image Processing Function ---
async def process_image(image_path, progress=gr.Progress()):
    """Process a single image for safety violations.

    Runs YOLO on the image. Each detected violation is logged, alerted, and,
    when Salesforce is reachable, prepared as a Safety_Violation_Log__c record.
    Records are bulk-inserted, and a PDF report is produced. The report is
    uploaded to Salesforce when possible; otherwise a local report is
    generated instead.

    Args:
        image_path: path of the image file to analyse.
        progress: Gradio progress tracker.

    Returns:
        tuple: ``(frames, status_message, pdf_path, violations_text)``.
        ``frames`` is ``[annotated_frame]`` when violations were found, else
        None. ``pdf_path`` is a local PDF path or None.

    NOTE(review): model inference and Salesforce calls here are blocking and
    run directly inside this coroutine, blocking the event loop while they run.
    """
    try:
        logger.info(f"Starting image analysis: {image_path}")
        start_time = time.time()
        current_run_violations = []
        new_sf_record_ids = []
        violation_payloads = []
        tracker = SafetyViolationDetector()
        tracker.reset_session()
        logger.info("Starting new image analysis session")

        sf = None
        if SALESFORCE_ENABLED:
            try:
                sf = get_salesforce_connection()
            except Exception as e:
                logger.error(f"Could not connect to Salesforce: {e}")

        progress(0.1, desc="Loading image...")
        frame = cv2.imread(image_path)
        if frame is None:
            error_msg = f"Failed to load image: {image_path}"
            logger.error(error_msg)
            return None, error_msg, None, format_violations_as_text([])

        progress(0.3, desc="Preprocessing image...")
        processed_frame = preprocess_frame(frame)
        progress(0.5, desc="Running AI detection...")
        detections = yolo_model.predict(processed_frame)
        progress(0.7, desc="Analyzing violations...")
        violations = tracker.detect_violations(detections, frame)
        violation_count = len(violations)
        timestamp = datetime.now(IST).isoformat()

        for violation in violations:
            snapshot_url = save_snapshot(frame, save_to_disk=False)
            if violation['type'] == 'unsafe_distance':
                worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
            else:
                worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"
            violation_data = {
                'violation_type': violation['type'].replace('_', ' ').title(),
                'severity': violation['severity'],
                'timestamp': timestamp,
                'snapshot_url': snapshot_url,
                'site_id': 'SITE001',
                'camera_id': 'CAM001',
                'worker_id': worker_id,
                'frame_number': 1  # Single image
            }
            if violation['type'] == 'unsafe_distance':
                violation_data['distance'] = f"{violation['distance']:.1f}px"
            current_run_violations.append(violation_data)
            log_violation(violation_data)
            send_alert(violation_data)
            if sf:
                payload, error = create_salesforce_violation_record(sf, violation_data)
                if payload:
                    violation_payloads.append(payload)
                else:
                    logger.error(f"Salesforce payload creation failed: {error}")

        progress(0.8, desc="Creating Salesforce records...")
        if sf and violation_payloads:
            try:
                insert_results = sf.bulk.Safety_Violation_Log__c.insert(violation_payloads)
                new_sf_record_ids = [res['id'] for res in insert_results if res.get('success')]
                logger.info(f"Created {len(new_sf_record_ids)} Salesforce records")
                for res in insert_results:
                    if not res.get('success'):
                        logger.error(f"Failed to create record: {res.get('errors')}")
            except Exception as e:
                logger.error(f"Failed to create bulk Salesforce records: {e}")

        progress(0.9, desc="Generating report...")
        pdf_temp_path = None
        if sf and new_sf_record_ids and current_run_violations:
            logger.info("Generating and uploading PDF report to Salesforce...")
            pdf_temp_path, _ = generate_and_upload_report_to_salesforce(
                sf, current_run_violations, new_sf_record_ids
            )
            if not pdf_temp_path:
                logger.error("Failed to generate Salesforce report")
        # If the Salesforce path produced no report (no connection, failed
        # insert, or failed upload), fall back to a local PDF. Detected
        # violations then always come with a downloadable report.
        if current_run_violations and not pdf_temp_path:
            pdf_temp_path = generate_local_pdf_report(current_run_violations)

        processing_time = time.time() - start_time
        session_summary = tracker.get_session_summary()
        progress(1.0, desc="Analysis complete!")
        if violation_count > 0:
            status_message = f"""✅ IMAGE ANALYSIS COMPLETED
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
📊 RESULTS:
• Processing Time: {processing_time:.2f}s
• Image: {os.path.basename(image_path)}
👥 UNIQUE PERSONS TRACKED: {session_summary['total_persons']}
🔍 VIOLATION TYPES: {', '.join(session_summary['violations_by_type'].keys())}
🚨 UNIQUE VIOLATIONS: {violation_count}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Each violation reported only once per person"""
        else:
            status_message = f"""✅ IMAGE ANALYSIS COMPLETED
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
📊 RESULTS:
• Processing Time: {processing_time:.2f}s
• Image: {os.path.basename(image_path)}
👥 UNIQUE PERSONS TRACKED: {session_summary['total_persons']}
✅ NO VIOLATIONS DETECTED
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
All safety protocols followed"""
        logger.info(f"Image analysis complete. Processing time: {processing_time:.2f}s")
        output_frames = [frame] if violations else None
        return output_frames, status_message, pdf_temp_path, format_violations_as_text(current_run_violations)
    except Exception as e:
        logger.error(f"Image processing error: {e}", exc_info=True)
        error_message = f"Image processing failed: {str(e)}"
        return None, error_message, None, format_violations_as_text([])
def generate_local_pdf_report(violations):
    """Generate a local PDF violation report, without involving Salesforce.

    Before each entry is drawn, the code checks whether its last line would fall
    below the bottom margin; if so, the whole entry moves to a new page. Text
    therefore never lands off-page.

    Args:
        violations: violation dicts (``violation_type``, ``severity``,
            ``timestamp``, optional ``worker_id`` / ``distance``).

    Returns:
        Path of the created ``safety_report_*.pdf`` temp file (not deleted
        here), or None if generation failed. On failure the error is logged
        and any partial file is removed.
    """
    top, bottom = 750, 50
    pdf_path = None
    try:
        # mkstemp creates the file securely. Close our descriptor immediately,
        # because reportlab opens and writes the path itself.
        fd, pdf_path = tempfile.mkstemp(suffix='.pdf', prefix='safety_report_')
        os.close(fd)
        c = canvas.Canvas(pdf_path, pagesize=letter)
        c.setFont("Helvetica-Bold", 16)
        c.drawString(100, 750, "Safety Violation Report")
        c.setFont("Helvetica", 12)
        c.drawString(100, 730, f"Generated: {datetime.now(IST).strftime('%Y-%m-%d %H:%M:%S IST')}")
        c.setFont("Helvetica", 10)
        c.drawString(100, 710, "Note: Each violation type reported only once per person")
        y = 680
        for i, violation in enumerate(violations, 1):
            details = [
                f"Severity: {violation['severity']}",
                f"Time: {violation['timestamp']}",
                f"Worker: {violation.get('worker_id', 'UNKNOWN')}",
            ]
            if 'distance' in violation:
                details.append(f"Distance: {violation['distance']}")
            # The last detail line sits 20pt below the title, plus 15pt per extra line.
            if y - 20 - 15 * (len(details) - 1) < bottom:
                c.showPage()
                y = top
            c.setFont("Helvetica-Bold", 12)
            c.drawString(100, y, f"Violation #{i}: {violation['violation_type']}")
            y -= 20
            c.setFont("Helvetica", 10)
            for text in details:
                c.drawString(120, y, text)
                y -= 15
            y -= 20
        c.save()
        return pdf_path
    except Exception as e:
        logger.error(f"Local PDF generation error: {e}")
        if pdf_path is not None:
            try:
                os.remove(pdf_path)
            except OSError:
                pass
        return None
# --- Media Processing Handler ---
async def process_media(media_file, frame_skip=5, progress=gr.Progress()):
    """Dispatch an uploaded file to image or video analysis by file extension.

    Args:
        media_file: the upload, or None if nothing was uploaded. Depending on
            the Gradio version/configuration this is either a filesystem path
            (``str`` / ``os.PathLike``) or a file-like wrapper exposing
            ``.name``; both are accepted.
        frame_skip: forwarded to ``process_video``, which analyses every Nth frame.
        progress: Gradio progress tracker, forwarded to the processor.

    Returns:
        The 4-tuple from ``process_image`` / ``process_video``, or
        ``(None, error_message, None, format_violations_as_text([]))`` when
        nothing was uploaded or the extension is unsupported.
    """
    if media_file is None:
        return None, "No file uploaded", None, format_violations_as_text([])
    if isinstance(media_file, (str, os.PathLike)):
        file_path = os.fspath(media_file)
    else:
        file_path = media_file.name
    file_extension = os.path.splitext(file_path)[1].lower()
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}
    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm', '.m4v'}
    if file_extension in image_extensions:
        logger.info(f"Processing image: {file_path}")
        return await process_image(file_path, progress)
    if file_extension in video_extensions:
        logger.info(f"Processing video: {file_path}")
        return await process_video(file_path, frame_skip, progress)
    error_msg = f"Unsupported file format: {file_extension}. Please upload an image or video file."
    logger.error(error_msg)
    return None, error_msg, None, format_violations_as_text([])
# --- Video Processing Functions ---
async def process_video(video_path, frame_skip=5, progress=gr.Progress()):
    """Analyse a video file for safety violations.

    The per-frame work (OpenCV decoding, YOLO inference, Salesforce calls) is
    synchronous. It runs in the default thread-pool executor rather than
    inline in this coroutine. Running it inline would block the event loop
    that serves every other Gradio request for the duration of the video.
    The current context is copied into the worker so that context-variable
    based state (which recent Gradio versions use for progress tracking)
    remains visible there.

    Args:
        video_path: Path of the video file, opened with ``cv2.VideoCapture``.
        frame_skip: Analyse every ``frame_skip``-th frame; values below 1 are
            treated as 1.
        progress: Gradio progress tracker.

    Returns:
        ``(gallery_frames, status_message, pdf_path, violations_text)``.
        ``gallery_frames`` holds up to 10 RGB frames that contained
        violations, or ``None`` on failure or cancellation.
    """
    import contextvars

    loop = asyncio.get_running_loop()
    ctx = contextvars.copy_context()
    return await loop.run_in_executor(
        None, partial(ctx.run, _process_video_sync, video_path, frame_skip, progress)
    )


def _process_video_sync(video_path, frame_skip, progress):
    """Blocking body of ``process_video``; runs in a worker thread."""
    global processing_active
    processing_active = True
    start_total = time.time()
    cap = None
    try:
        # A zero/None step would make the modulo in the frame loop raise.
        frame_skip = max(1, int(frame_skip or 1))
        current_run_violations = []
        new_sf_record_ids = []
        violation_payloads = []
        tracker = SafetyViolationDetector()
        tracker.reset_session()
        logger.info("Starting new video analysis session")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            error_message = f"Failed to open video file: {video_path}"
            logger.error(error_message)
            return None, error_message, None, format_violations_as_text([])
        frames = []  # most recent frames (OpenCV BGR) that contained violations
        max_display_frames = 10
        frame_count = 0
        processed_frames = 0
        violation_count = 0
        # Get Salesforce connection once at the beginning
        sf = None
        if SALESFORCE_ENABLED:
            try:
                sf = get_salesforce_connection()
            except Exception as e:
                logger.error(f"Could not connect to Salesforce at start: {e}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers/codecs report 0 (or a negative value) here.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 else 0
        progress(0, desc="Analyzing video...")
        while cap.isOpened() and processing_active:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            if frame_count % frame_skip != 0:
                continue
            processed_frames += 1
            timestamp = datetime.now(IST).isoformat()
            if total_frames > 0:
                progress(min(1.0, frame_count / total_frames),
                         desc=f"Processing frame {frame_count}/{total_frames}")
            else:
                # Unknown length: report steps completed without a total.
                progress((frame_count, None), desc=f"Processing frame {frame_count}")
            processed_frame = preprocess_frame(frame)
            results = yolo_model.predict(processed_frame)
            violations = tracker.detect_violations(results, frame)
            for violation in violations:
                violation_count += 1
                snapshot_url = save_snapshot(frame, save_to_disk=False)
                if violation['type'] == 'unsafe_distance':
                    worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
                else:
                    worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"
                violation_data = {
                    'violation_type': violation['type'].replace('_', ' ').title(),
                    'severity': violation['severity'],
                    'timestamp': timestamp,
                    'snapshot_url': snapshot_url,
                    'site_id': 'SITE001',
                    'camera_id': 'CAM001',
                    'worker_id': worker_id,
                    'frame_number': frame_count,
                }
                if violation['type'] == 'unsafe_distance':
                    violation_data['distance'] = f"{violation['distance']:.1f}px"
                current_run_violations.append(violation_data)
                log_violation(violation_data)
                send_alert(violation_data)
                if sf:
                    payload, error = create_salesforce_violation_record(sf, violation_data)
                    if payload:
                        violation_payloads.append(payload)
                    else:
                        logger.error(f"Salesforce push failed for violation: {error}")
            if violations:
                # Keep only the latest `max_display_frames` violation frames.
                if len(frames) >= max_display_frames:
                    frames.pop(0)
                frames.append(frame)
        cap.release()  # free the decoder early; release() is idempotent
        if sf and violation_payloads:
            try:
                results = sf.bulk.Safety_Violation_Log__c.insert(violation_payloads)
                new_sf_record_ids = [result['id'] for result in results if result.get('success')]
                logger.info(f"Created {len(new_sf_record_ids)} Salesforce records in bulk")
                for result in results:
                    if not result.get('success'):
                        logger.error(f"Failed to create record: {result.get('errors')}")
            except Exception as e:
                logger.error(f"Failed to create bulk Salesforce records: {e}", exc_info=True)
        processing_time = time.time() - start_total
        actual_fps = processed_frames / processing_time if processing_time > 0 else 0
        if not processing_active:
            return None, "Processing cancelled", None, format_violations_as_text([])
        pdf_temp_path = None
        if not current_run_violations:
            logger.info("No violations detected, skipping report generation.")
        elif sf and new_sf_record_ids:
            logger.info(f"Generating PDF report and uploading to Salesforce for {len(new_sf_record_ids)} violations...")
            pdf_temp_path, _pdf_sf_url = generate_and_upload_report_to_salesforce(
                sf, current_run_violations, new_sf_record_ids)
            if not pdf_temp_path:
                logger.error("Failed to generate and upload Salesforce report.")
        elif sf:
            logger.warning("No Salesforce records were created for this run. Skipping Salesforce report upload.")
        if current_run_violations and not pdf_temp_path:
            # No Salesforce, or the Salesforce path failed: still hand the user a local report.
            pdf_temp_path = generate_local_pdf_report(current_run_violations)
        session_summary = tracker.get_session_summary()
        logger.info(f"Video analysis complete. Session summary: {session_summary}")
        status_message = generate_status_message(
            violation_count > 0,
            frame_count,
            processed_frames,
            duration,
            violation_count,
            processing_time,
            actual_fps,
            session_summary,
        )
        # OpenCV decodes to BGR; gr.Gallery interprets numpy arrays as RGB.
        gallery_frames = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames]
        return gallery_frames, status_message, pdf_temp_path, format_violations_as_text(current_run_violations)
    except Exception as e:
        logger.error(f"Video processing error: {e}", exc_info=True)
        error_message = f"Video processing failed: {str(e)}"
        return None, error_message, None, format_violations_as_text([])
    finally:
        processing_active = False
        if cap is not None:
            cap.release()
        try:
            cv2.destroyAllWindows()
        except cv2.error as e:
            # Headless OpenCV builds raise here. An exception escaping this
            # finally block would replace the function's return value.
            logger.debug(f"cv2.destroyAllWindows unavailable: {e}")
        logger.info(f"Total processing time: {time.time() - start_total:.2f}s")
# --- RTSP Processing ---
async def process_rtsp_stream(rtsp_url, max_frames=None, frame_skip=5, progress=gr.Progress()):
    """Monitor an RTSP stream for safety violations.

    The frame loop is synchronous (frame capture, YOLO inference, Salesforce
    calls), so it runs in the default thread-pool executor. Running it inline
    in this coroutine would block the event loop serving all Gradio requests.
    In that case the "Stop Monitoring" click (``cancel_processing``) could
    not be handled until the stream ended. The current context is copied into
    the worker so that context-variable based state (used by recent Gradio
    versions for progress tracking) remains visible there.

    Args:
        rtsp_url: Stream URL; an empty value produces an error status.
        max_frames: Optional limit forwarded to ``capture_rtsp_frames``; also
            scales the progress bar (100 is assumed when unset).
        frame_skip: Analyse every ``frame_skip``-th frame; values below 1 are
            treated as 1.
        progress: Gradio progress tracker.

    Returns:
        ``(status_message, gallery_frames, violations_text, heatmap_path, pdf_path)``.
    """
    import contextvars

    loop = asyncio.get_running_loop()
    ctx = contextvars.copy_context()
    return await loop.run_in_executor(
        None,
        partial(ctx.run, _process_rtsp_stream_sync, rtsp_url, max_frames, frame_skip, progress),
    )


def _process_rtsp_stream_sync(rtsp_url, max_frames, frame_skip, progress):
    """Blocking body of ``process_rtsp_stream``; runs in a worker thread."""
    global processing_active
    processing_active = True
    start_total = time.time()
    try:
        if not rtsp_url:
            raise ValueError("RTSP URL not provided")
        # A zero/None step would make the modulo in the frame loop raise.
        frame_skip = max(1, int(frame_skip or 1))
        current_run_violations = []
        new_sf_record_ids = []
        violation_payloads = []
        tracker = SafetyViolationDetector()
        tracker.reset_session()
        logger.info("Starting new RTSP stream analysis session")
        # Get Salesforce connection once at the beginning
        sf = None
        if SALESFORCE_ENABLED:
            try:
                sf = get_salesforce_connection()
            except Exception as e:
                logger.error(f"Could not connect to Salesforce at start: {e}")
        # NOTE(review): frames go to the gallery exactly as yielded by
        # capture_rtsp_frames; if those are OpenCV BGR arrays the colours will
        # appear swapped. TODO confirm.
        frames = []
        max_display_frames = 10
        violation_count = 0
        processed_frames = 0  # frames actually analysed (after frame skipping)
        progress(0, desc="Connecting to RTSP stream...")
        for frame, timestamp, fc, _ in capture_rtsp_frames(rtsp_url, max_frames):
            if not processing_active:
                break
            if fc % frame_skip != 0:
                continue
            processed_frames += 1
            progress_percent = min(100, (fc / (max_frames if max_frames else 100)) * 100)
            progress(progress_percent / 100, desc=f"Processing frame {fc}")
            processed_frame = preprocess_frame(frame)
            results = yolo_model.predict(processed_frame)
            violations = tracker.detect_violations(results, frame)
            for violation in violations:
                violation_count += 1
                snapshot_url = save_snapshot(frame, save_to_disk=False)
                if violation['type'] == 'unsafe_distance':
                    worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
                else:
                    worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"
                violation_data = {
                    'violation_type': violation['type'].replace('_', ' ').title(),
                    'severity': violation['severity'],
                    'timestamp': timestamp,
                    'snapshot_url': snapshot_url,
                    'site_id': 'SITE001',
                    'camera_id': 'CAM001',
                    'worker_id': worker_id,
                    'frame_number': fc,
                }
                if violation['type'] == 'unsafe_distance':
                    violation_data['distance'] = f"{violation['distance']:.1f}px"
                current_run_violations.append(violation_data)
                log_violation(violation_data)
                send_alert(violation_data)
                if sf:
                    payload, error = create_salesforce_violation_record(sf, violation_data)
                    if payload:
                        violation_payloads.append(payload)
                    else:
                        logger.error(f"Salesforce push failed for violation: {error}")
            if violations:
                # Keep only the latest `max_display_frames` violation frames.
                if len(frames) >= max_display_frames:
                    frames.pop(0)
                frames.append(frame)
        if sf and violation_payloads:
            try:
                results = sf.bulk.Safety_Violation_Log__c.insert(violation_payloads)
                new_sf_record_ids = [result['id'] for result in results if result.get('success')]
                logger.info(f"Created {len(new_sf_record_ids)} Salesforce records in bulk")
                for result in results:
                    if not result.get('success'):
                        logger.error(f"Failed to create record: {result.get('errors')}")
            except Exception as e:
                logger.error(f"Failed to create bulk Salesforce records: {e}", exc_info=True)
        cancelled = not processing_active
        if cancelled:
            logger.info("Processing cancelled.")
        # A report is still produced for whatever was analysed before a cancel.
        pdf_temp_path = None
        if not current_run_violations:
            logger.info("No violations detected, skipping report generation.")
        elif sf and new_sf_record_ids:
            logger.info(f"Generating PDF report and uploading to Salesforce for {len(new_sf_record_ids)} violations...")
            pdf_temp_path, _pdf_sf_url = generate_and_upload_report_to_salesforce(
                sf, current_run_violations, new_sf_record_ids)
            if not pdf_temp_path:
                logger.error("Failed to generate and upload Salesforce report.")
        elif sf:
            logger.warning("No Salesforce records were created for this run. Skipping Salesforce report upload.")
        if current_run_violations and not pdf_temp_path:
            # No Salesforce, or the Salesforce path failed: still hand the user a local report.
            pdf_temp_path = generate_local_pdf_report(current_run_violations)
        violations_text = format_violations_as_text(current_run_violations)
        # generate=False disables rendering, so the heatmap output receives None.
        heatmap = generate_heatmap(current_run_violations, generate=False)
        if cancelled:
            return "Processing cancelled.", frames, violations_text, heatmap, pdf_temp_path
        session_summary = tracker.get_session_summary()
        logger.info(f"RTSP analysis complete. Session summary: {session_summary}")
        status_message = (
            f"Processed {processed_frames} frames with {violation_count} unique violations. "
            f"Persons tracked: {session_summary['total_persons']}"
        )
        return status_message, frames, violations_text, heatmap, pdf_temp_path
    except Exception as e:
        logger.error(f"RTSP processing error: {e}", exc_info=True)
        error_message = f"RTSP processing failed: {str(e)}"
        return error_message, None, format_violations_as_text([]), None, None
    finally:
        processing_active = False
        try:
            cv2.destroyAllWindows()
        except cv2.error as e:
            # Headless OpenCV builds raise here. An exception escaping this
            # finally block would replace the function's return value.
            logger.debug(f"cv2.destroyAllWindows unavailable: {e}")
        logger.info(f"Total processing time: {time.time() - start_total:.2f}s")
# --- Other Functions ---
def generate_status_message(has_violations, total_frames, processed_frames, duration,
                            violation_count, processing_time, actual_fps, session_summary=None):
    """Build the multi-line status text shown after a video analysis run.

    Args:
        has_violations: Selects the closing section (violations vs. all clear).
        total_frames: Number of frames read from the video.
        processed_frames: Number of frames actually analysed.
        duration: Video duration in seconds.
        violation_count: Number of violations reported.
        processing_time: Wall-clock analysis time in seconds.
        actual_fps: Analysed frames per second of processing time.
        session_summary: Optional dict. ``total_persons`` is printed, and the
            keys of ``violations_by_type`` are listed.

    Returns:
        The status message as a single newline-separated string.
    """
    rule = "━" * 49
    lines = [
        "✅ ANALYSIS COMPLETED",
        rule,
        "📊 RESULTS:",
        f"• Frames: {total_frames} (Processed: {processed_frames})",
        f"• Duration: {duration:.2f}s",
        f"• Processing Time: {processing_time:.2f}s",
        f"• FPS: {actual_fps:.1f}",
    ]
    if session_summary:
        violation_types = session_summary.get('violations_by_type') or {}
        lines.append(f"👥 UNIQUE PERSONS TRACKED: {session_summary.get('total_persons', 0)}")
        lines.append(f"🔍 VIOLATION TYPES: {', '.join(violation_types) if violation_types else 'None'}")
    if has_violations:
        lines += [
            f"🚨 UNIQUE VIOLATIONS: {violation_count}",
            rule,
            "Each violation reported only once per person",
        ]
    else:
        lines += [
            "✅ NO VIOLATIONS DETECTED",
            rule,
            "All safety protocols followed",
        ]
    return "\n".join(lines)
def save_snapshot(frame, save_to_disk=True):
    """Persist a frame as a JPEG under ``./snapshots`` and return its path.

    Args:
        frame: Image array accepted by ``cv2.imwrite``.
        save_to_disk: When False nothing is written and the placeholder
            ``"in_memory_snapshot.jpg"`` is returned.

    Returns:
        The written file path, or ``"snapshot_failed.jpg"`` if writing failed.
    """
    import uuid

    if not save_to_disk:
        return "in_memory_snapshot.jpg"
    try:
        snapshot_dir = "./snapshots"
        os.makedirs(snapshot_dir, exist_ok=True)
        # A whole-second timestamp alone collides when several violations are
        # saved within the same second (each overwrote the previous file);
        # the random suffix keeps every snapshot distinct.
        filename = f"snapshot_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"
        snapshot_path = os.path.join(snapshot_dir, filename)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(snapshot_path, frame):
            logger.error(f"Snapshot error: could not write {snapshot_path}")
            return "snapshot_failed.jpg"
        return snapshot_path
    except Exception as e:
        logger.error(f"Snapshot error: {e}")
        return "snapshot_failed.jpg"
def log_violation(violation_data, log_path="./snapshots/violation_logs.json"):
    """Append a violation to the JSON log file and refresh the in-memory caches.

    The whole log (a JSON list) is read, extended and rewritten. The module
    globals ``recent_violations`` (last 10 entries) and ``violation_history``
    (all entries) are updated from it.

    Args:
        violation_data: JSON-serialisable dict describing one violation.
        log_path: Location of the JSON log. The parent directory is created
            if missing.

    Errors are logged rather than raised, so monitoring continues if the log
    cannot be read or written.
    """
    global recent_violations, violation_history
    try:
        log_file = Path(log_path)
        logs = []
        if log_file.exists():
            with open(log_file, "r", encoding="utf-8") as f:
                logs = json.load(f)
        logs.append(violation_data)
        recent_violations = logs[-10:]
        violation_history = logs
        # ./snapshots is otherwise only created when a snapshot is written to
        # disk. Without it the write below failed and the entry was lost.
        log_file.parent.mkdir(parents=True, exist_ok=True)
        with open(log_file, "w", encoding="utf-8") as f:
            json.dump(logs, f, indent=4)
    except Exception as e:
        logger.error(f"Logging error: {e}")
def send_alert(violation):
    """Announce a violation by writing an INFO-level line to the module logger."""
    kind = violation['violation_type']
    level = violation['severity']
    logger.info(f"ALERT: {kind} detected (Severity: {level})")
def format_violations_as_text(violations):
    """Render violations as the text shown in the violation log panels.

    Args:
        violations: List of violation dicts with keys ``violation_type``,
            ``severity``, ``timestamp``, ``site_id``, ``camera_id`` and
            ``snapshot_url``. ``worker_id`` and ``distance`` are optional.

    Returns:
        A multi-line report with one card per violation and summary
        statistics. For an empty/None list, an "all clear" status stamped
        with the current IST time is returned instead.
    """
    rule = "━" * 42
    if not violations:
        last_updated = datetime.now(IST).strftime('%Y-%m-%d %H:%M:%S IST')
        return "\n".join([
            "🔍 SAFETY MONITORING STATUS",
            "✅ NO VIOLATIONS DETECTED",
            rule,
            "📊 Current Status: ALL CLEAR",
            f"🕐 Last Updated: {last_updated}",
            "🎯 Detection Accuracy: >90% confidence",
            "⚡ Response Time: <5 seconds",
            "The system is actively monitoring for:",
            "• No Helmet violations",
            "• Unsafe Distance violations",
            "• Unauthorized Area violations",
            "All safety protocols are currently being followed.",
        ])
    parts = [
        "🚨 SAFETY VIOLATION ALERTS\n"
        f"{rule}\n"
        f"📊 UNIQUE VIOLATIONS DETECTED: {len(violations)}\n"
        "Note: Each violation type reported only once per person\n"
    ]
    for i, violation in enumerate(violations, 1):
        severity_emoji = "🔴" if violation['severity'] == 'Critical' else "🟡"
        card = [
            "",
            f"┌─ ALERT #{i:02d} ─ {severity_emoji} {violation['violation_type'].upper()}",
            "│",
            f"├─ 🕐 Time: {violation['timestamp']}",
            f"├─ ⚠️ Severity: {violation['severity']}",
            f"├─ 📍 Location: Site {violation['site_id']} | Camera {violation['camera_id']}",
            f"├─ 👷 Worker: {violation.get('worker_id', 'UNKNOWN')}",
        ]
        if 'distance' in violation:
            # Shown for unsafe-distance violations, as the PDF report does.
            card.append(f"├─ 📏 Distance: {violation['distance']}")
        card += [
            f"├─ 📸 Evidence: {violation['snapshot_url']}",
            "│",
            "└" + "─" * 49,
            "",
        ]
        parts.append("\n".join(card))
    critical = sum(1 for v in violations if v['severity'] == 'Critical')
    moderate = sum(1 for v in violations if v['severity'] == 'Moderate')
    parts.append("\n".join([
        "",
        "📈 SUMMARY STATISTICS:",
        rule,
        f"• Total Violations: {len(violations)}",
        f"• Critical: {critical}",
        f"• Moderate: {moderate}",
        f"• Last Alert: {violations[-1]['timestamp']}",
        "🔄 System Status: ACTIVELY MONITORING",
        "⚡ Response Time: <5 seconds",
        "🎯 Detection Accuracy: >90% confidence",
    ]))
    # Collect pieces and join once instead of repeated string concatenation.
    return "".join(parts)
def generate_heatmap(violations, generate=True):
    """Render an hour-of-day × violation-type count heatmap to a temporary PNG.

    Args:
        violations: Violation dicts with at least ``timestamp`` (parseable by
            ``pd.to_datetime``) and ``violation_type``.
        generate: When False, rendering is skipped and None is returned.

    Returns:
        Path of the PNG (created with ``delete=False``, so the caller owns
        it), or None when disabled, when there is nothing to plot, or on error.
    """
    if not generate or not violations:
        return None
    fig = None
    try:
        df = pd.DataFrame(violations)
        df['hour'] = pd.to_datetime(df['timestamp']).dt.hour
        heatmap_data = df.pivot_table(index='hour', columns='violation_type', aggfunc='size', fill_value=0)
        # Use an explicit figure so it can always be closed, even if drawing or saving fails.
        fig, ax = plt.subplots(figsize=(12, 8))
        sns.heatmap(heatmap_data, cmap='YlOrRd', annot=True, fmt='d', ax=ax)
        ax.set_title("Unique Violations by Hour")
        ax.set_xlabel("Violation Type")
        ax.set_ylabel("Hour of Day")
        # Reserve a unique file name and close our handle before matplotlib
        # writes to it. The old code leaked the open handle, and writing to an
        # open NamedTemporaryFile fails on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            heatmap_path = temp_file.name
        fig.savefig(heatmap_path, bbox_inches='tight')
        return heatmap_path
    except Exception as e:
        logger.error(f"Heatmap error: {e}")
        return None
    finally:
        if fig is not None:
            plt.close(fig)
def cancel_processing():
    """Request that the running video/RTSP analysis stop.

    Clears the module-level ``processing_active`` flag, which the processing
    loops check before handling each frame.

    Returns:
        Status text for the RTSP status box.
    """
    global processing_active
    processing_active = False
    try:
        cv2.destroyAllWindows()
    except cv2.error as e:
        # Headless OpenCV builds have no GUI backend and raise here. The
        # cancel has already taken effect via the flag above, so don't fail
        # the click handler over window cleanup.
        logger.debug(f"cv2.destroyAllWindows unavailable: {e}")
    return "Processing cancelled"
# --- Enhanced CSS ---
# Custom stylesheet passed to gr.Blocks(css=...) in the interface definition.
# Everything inside the triple-quoted string is CSS sent to the browser, so
# notes live here: a '#' inside the string would become part of the CSS.
# - The "*" rule zeroes margin/padding on every element, including Gradio's
#   own components; later selectors add spacing back where they need it.
# - Almost every declaration uses !important so it wins over the
#   gr.themes.Soft theme styles.
# - Classes like .professional-card, .section-header, .status-display,
#   .btn-primary and .file-component are attached via elem_classes or raw
#   HTML in the layout; #particle-canvas styles the <canvas> injected with
#   gr.HTML.
enhanced_custom_css = """
/* Reset Default Styles */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
/* Global Theme and Layout */
.gradio-container {
font-family: 'Poppins', 'Inter', 'Segoe UI', 'Roboto', sans-serif !important;
background: linear-gradient(45deg, #0a0a1f, #1a0033, #2a0044, #0a0a1f) !important;
background-size: 400% !important;
animation: gradientShift 12s ease infinite !important;
min-height: 100vh !important;
display: flex !important;
flex-direction: column !important;
justify-content: center !important;
align-items: center !important;
position: relative !important;
overflow: hidden !important;
}
@keyframes gradientShift {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
/* Particle Canvas for Star-Like Animation */
#particle-canvas {
position: absolute !important;
top: 0 !important;
left: 0 !important;
width: 100% !important;
height: 100% !important;
z-index: 1 !important;
pointer-events: none !important;
}
/* Main Header Styling */
.main-header {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border: 1px solid rgba(255, 0, 204, 0.3) !important;
color: #f0f0f5 !important;
text-align: center !important;
padding: 1rem !important;
margin-bottom: 1rem !important;
border-radius: 15px !important;
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important;
z-index: 2 !important;
animation: fadeIn 1s ease-out !important;
}
.header-title {
font-size: 2.5rem !important;
font-weight: 700 !important;
text-shadow: 0 0 12px rgba(255, 0, 204, 0.8) !important;
animation: glow 2s ease-in-out infinite alternate !important;
margin-bottom: 0.5rem !important;
}
.header-subtitle {
font-size: 1rem !important;
font-weight: 400 !important;
color: #ccc !important;
text-shadow: 0 0 5px rgba(255, 0, 204, 0.5) !important;
}
@keyframes glow {
from { text-shadow: 0 0 5px rgba(255, 0, 204, 0.5), 0 0 10px rgba(255, 0, 204, 0.3); }
to { text-shadow: 0 0 12px rgba(255, 0, 204, 0.9), 0 0 20px rgba(255, 0, 204, 0.6); }
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
/* Professional Card System (Glassmorphism) */
.professional-card {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border: 1px solid rgba(255, 0, 204, 0.3) !important;
border-radius: 15px !important;
padding: 1rem !important;
margin: 0.5rem 0 !important;
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important;
z-index: 2 !important;
transition: all 0.3s ease !important;
}
.professional-card:hover {
transform: translateY(-2px) !important;
box-shadow: 0 0 30px rgba(255, 0, 204, 0.35) !important;
}
/* Section Headers */
.section-header {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border: 1px solid rgba(255, 0, 204, 0.3) !important;
color: #f0f0f5 !important;
padding: 0.8rem 1rem !important;
border-radius: 10px !important;
text-align: center !important;
font-weight: 700 !important;
font-size: 1.2rem !important;
margin-bottom: 1rem !important;
text-shadow: 0 0 12px rgba(255, 0, 204, 0.8) !important;
animation: glow 2s ease-in-out infinite alternate !important;
z-index: 2 !important;
}
/* Button Styling */
.btn-primary, .gr-button {
background: linear-gradient(90deg, #00C4B4, #ff00cc) !important;
border: none !important;
border-radius: 8px !important;
padding: 10px 20px !important;
color: white !important;
font-weight: 600 !important;
font-size: 1rem !important;
transition: all 0.3s ease !important;
box-shadow: 0 0 15px rgba(255, 0, 204, 0.4) !important;
position: relative !important;
overflow: hidden !important;
z-index: 2 !important;
}
.btn-primary:hover, .gr-button:hover {
background: linear-gradient(90deg, #00C6B6, #ff33cc) !important;
box-shadow: 0 0 20px rgba(255, 0, 204, 0.6) !important;
transform: translateY(-2px) !important;
}
.btn-primary::before, .gr-button::before {
content: '' !important;
position: absolute !important;
top: 50% !important;
left: 50% !important;
width: 300% !important;
height: 300% !important;
background: rgba(255, 255, 255, 0.1) !important;
transition: all 0.5s ease !important;
transform: translate(-50%, -50%) rotate(45deg) !important;
opacity: 0 !important;
}
.btn-primary:hover::before, .gr-button:hover::before {
opacity: 1 !important;
width: 0 !important;
height: 0 !important;
}
.btn-secondary {
background: linear-gradient(90deg, #11998e, #38ef7d) !important;
border: none !important;
border-radius: 8px !important;
padding: 10px 20px !important;
color: white !important;
font-weight: 600 !important;
font-size: 1rem !important;
transition: all 0.3s ease !important;
box-shadow: 0 0 15px rgba(17, 153, 142, 0.4) !important;
}
.btn-secondary:hover {
background: linear-gradient(90deg, #0e867c, #33d670) !important;
box-shadow: 0 0 20px rgba(17, 153, 142, 0.6) !important;
transform: translateY(-2px) !important;
}
/* Status Display */
.status-display {
background: rgba(255, 255, 255, 0.05) !important;
border: 1px solid rgba(255, 255, 255, 0.15) !important;
border-radius: 8px !important;
padding: 1rem !important;
color: #f0f0f5 !important;
font-size: 0.95rem !important;
font-family: 'Fira Code', 'Consolas', monospace !important;
white-space: pre-wrap !important;
max-height: 300px !important;
overflow-y: auto !important;
z-index: 2 !important;
}
.status-display:focus {
border-color: #ff00cc !important;
box-shadow: 0 0 8px rgba(255, 0, 204, 0.6) !important;
}
.status-display::-webkit-scrollbar {
width: 5px !important;
}
.status-display::-webkit-scrollbar-track {
background: rgba(255, 255, 255, 0.05) !important;
border-radius: 3px !important;
}
.status-display::-webkit-scrollbar-thumb {
background: #ff00cc !important;
border-radius: 3px !important;
}
/* Alert Panel */
.alert-panel {
background: rgba(255, 51, 51, 0.15) !important;
backdrop-filter: blur(10px) !important;
border: 1px solid rgba(255, 51, 51, 0.5) !important;
color: #ff3333 !important;
border-radius: 15px !important;
padding: 1rem !important;
margin: 0.5rem 0 !important;
box-shadow: 0 0 20px rgba(255, 51, 51, 0.3) !important;
animation: alertPulse 2s infinite !important;
z-index: 2 !important;
}
@keyframes alertPulse {
0%, 100% { transform: scale(1); opacity: 1; }
50% { transform: scale(1.01); opacity: 0.95; }
}
/* Success Panel */
.success-panel {
background: rgba(0, 184, 148, 0.15) !important;
backdrop-filter: blur(10px) !important;
border: 1px solid rgba(0, 184, 148, 0.5) !important;
color: #00b894 !important;
border-radius: 15px !important;
padding: 1rem !important;
margin: 0.5rem 0 !important;
box-shadow: 0 0 20px rgba(0, 184, 148, 0.3) !important;
z-index: 2 !important;
}
/* Image Components */
.image-component {
border-radius: 15px !important;
overflow: hidden !important;
box-shadow: 0 0 20px rgba(255, 0, 204, 0.2) !important;
transition: all 0.3s ease !important;
border: 1px solid rgba(255, 0, 204, 0.2) !important;
z-index: 2 !important;
}
.image-component:hover {
transform: scale(1.01) !important;
box-shadow: 0 0 25px rgba(255, 0, 204, 0.3) !important;
}
/* Gallery Styling */
.gallery-component {
border-radius: 15px !important;
overflow: hidden !important;
box-shadow: 0 0 20px rgba(255, 0, 204, 0.2) !important;
background: rgba(255, 255, 255, 0.05) !important;
padding: 0.5rem !important;
z-index: 2 !important;
}
/* File Download Component */
.file-component {
background: rgba(255, 255, 255, 0.05) !important;
border: 1px dashed rgba(255, 0, 204, 0.3) !important;
border-radius: 10px !important;
padding: 1rem !important;
text-align: center !important;
transition: all 0.3s ease !important;
z-index: 2 !important;
}
.file-component:hover {
background: rgba(255, 255, 255, 0.1) !important;
transform: translateY(-2px) !important;
}
/* Analytics Dashboard */
.analytics-panel {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border-radius: 15px !important;
padding: 1rem !important;
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important;
z-index: 2 !important;
}
/* Tab Styling */
.gradio-tabs {
border: none !important;
background: transparent !important;
z-index: 2 !important;
}
.gradio-tab-item {
background: rgba(255, 255, 255, 0.05) !important;
border: 1px solid rgba(255, 0, 204, 0.2) !important;
border-radius: 8px !important;
color: #ccc !important;
padding: 0.5rem 1rem !important;
margin: 0 0.2rem !important;
transition: all 0.3s ease !important;
}
.gradio-tab-item.selected {
background: rgba(255, 255, 255, 0.15) !important;
color: #ff00cc !important;
font-weight: 600 !important;
border-color: #ff00cc !important;
box-shadow: 0 0 10px rgba(255, 0, 204, 0.5) !important;
}
.gradio-tab-item:hover {
background: rgba(255, 255, 255, 0.1) !important;
color: #ff00cc !important;
}
.gradio-tab-content {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border-radius: 15px !important;
padding: 1rem !important;
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important;
z-index: 2 !important;
}
/* Footer Styling */
.footer-info {
background: rgba(255, 255, 255, 0.12) !important;
backdrop-filter: blur(15px) !important;
border: 1px solid rgba(255, 0, 204, 0.3) !important;
border-radius: 15px !important;
padding: 1rem !important;
margin-top: 1rem !important;
text-align: center !important;
color: #f0f0f5 !important;
z-index: 2 !important;
}
.feature-grid {
display: grid !important;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)) !important;
gap: 0.5rem !important;
margin-top: 1rem !important;
}
.feature-item {
background: rgba(255, 255, 255, 0.05) !important;
padding: 0.5rem !important;
border-radius: 8px !important;
text-align: center !important;
transition: all 0.3s ease !important;
color: #ccc !important;
}
.feature-item:hover {
background: rgba(255, 255, 255, 0.1) !important;
transform: translateY(-2px) !important;
color: #ff00cc !important;
}
/* Responsive Design */
@media (max-width: 768px) {
.gradio-container {
padding: 1rem !important;
}
.main-header {
padding: 0.8rem !important;
}
.header-title {
font-size: 2rem !important;
}
.professional-card {
padding: 0.8rem !important;
margin: 0.3rem 0 !important;
}
.section-header {
font-size: 1rem !important;
padding: 0.6rem !important;
}
.btn-primary, .btn-secondary, .gr-button {
padding: 8px 16px !important;
font-size: 0.9rem !important;
}
}
"""
# --- Gradio Interface ---
# Build the Gradio UI. All user-visible labels use proper UTF-8 emoji and
# box-drawing characters; the previous text was mojibake (UTF-8 bytes
# decoded with a Greek code page, e.g. "πŸ“Š" instead of "📊").
with gr.Blocks(
    title="Dynamic Safety Violation Detection using CCTV + AI",
    css=enhanced_custom_css,
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="emerald",
        neutral_hue="slate",
        radius_size="lg",
        spacing_size="sm",
        font=[
            gr.themes.GoogleFont("Poppins"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ],
    ).set(
        # Transparent theme fills so the custom CSS gradient and glass cards show through.
        body_background_fill="none",
        block_background_fill="none",
        block_border_width="0px",
        block_shadow="none",
        block_radius="15px",
        button_primary_background_fill="none",
        button_primary_background_fill_hover="none",
        button_secondary_background_fill="none",
    ),
) as demo:
    # Decorative star-field canvas, styled by #particle-canvas in the CSS.
    # NOTE(review): some Gradio versions do not execute <script> tags inside
    # gr.HTML; if the animation never appears, move this script to
    # gr.Blocks(head=...) or js=... (TODO confirm for the deployed version).
    gr.HTML("""
<canvas id="particle-canvas"></canvas>
<script>
const canvas = document.getElementById('particle-canvas');
const ctx = canvas.getContext('2d');
canvas.width = window.innerWidth;
canvas.height = window.innerHeight;
window.addEventListener('resize', () => {
    canvas.width = window.innerWidth;
    canvas.height = window.innerHeight;
});
const stars = [];
const starCount = 200;
class Star {
    constructor() {
        this.reset();
    }
    reset() {
        this.x = Math.random() * canvas.width;
        this.y = Math.random() * canvas.height;
        this.size = Math.random() * 2 + 0.5;
        this.speedX = Math.random() * 5 + 2;
        this.speedY = (Math.random() - 0.5) * 0.5;
        this.opacity = Math.random() * 0.5 + 0.5;
        this.twinklePhase = Math.random() * Math.PI * 2;
        this.twinkleSpeed = Math.random() * 0.05 + 0.02;
    }
    update() {
        this.x -= this.speedX;
        this.y += this.speedY;
        this.twinklePhase += this.twinkleSpeed;
        this.currentOpacity = this.opacity * (0.5 + 0.5 * Math.sin(this.twinklePhase));
        if (this.x < 0 || this.y < 0 || this.y > canvas.height) {
            this.reset();
        }
    }
    draw() {
        ctx.fillStyle = `rgba(255, 0, 204, ${this.currentOpacity})`;
        ctx.beginPath();
        ctx.arc(this.x, this.y, this.size, 0, Math.PI * 2);
        ctx.fill();
        ctx.fillStyle = `rgba(255, 0, 204, ${this.currentOpacity * 0.3})`;
        ctx.beginPath();
        ctx.arc(this.x, this.y, this.size * 2, 0, Math.PI * 2);
        ctx.fill();
    }
}
function initStars() {
    for (let i = 0; i < starCount; i++) {
        const star = new Star();
        star.x = Math.random() * canvas.width;
        stars.push(star);
    }
}
function animateStars() {
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    stars.forEach(star => {
        star.update();
        star.draw();
    });
    requestAnimationFrame(animateStars);
}
initStars();
animateStars();
</script>
""")
    # Professional Header
    gr.HTML("""
<div class="main-header">
<h1 class="header-title">🔍 Dynamic Safety Violation Detection using CCTV + AI</h1>
<p class="header-subtitle">Enhanced Multi-Person Tracking with Image & Video Analysis - Each violation type detected only once per person</p>
</div>
""")
    # Smart Media Analysis Section
    gr.HTML('<div class="section-header">📷 Smart Media Analysis (Images & Videos)</div>')
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card"]):
                media_input = gr.File(
                    label="📤 Upload Image or Video for Safety Analysis",
                    file_types=["image", "video"],
                    elem_classes=["image-component"],
                    height=200,
                )
                frame_skip_input = gr.Slider(
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=5,
                    label="Frame Skip (Higher = Faster Processing, Videos Only)",
                )
                with gr.Row():
                    media_button = gr.Button(
                        "🔍 Analyze Media",
                        variant="primary",
                        elem_classes=["btn-primary"],
                        size="lg",
                    )
    # Analysis Results Section
    gr.HTML('<div class="section-header">📊 Analysis Results & Violation Details</div>')
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card"]):
                media_output = gr.Gallery(
                    label="🖼️ Processed Media with Detection Results",
                    elem_classes=["gallery-component"],
                    height=260,
                )
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card"]):
                media_status = gr.Textbox(
                    label="📋 Analysis Status",
                    elem_classes=["status-display"],
                    lines=7,
                    max_lines=9,
                    value="📊 Awaiting Media Analysis\n\nUpload an image or video and click 'Analyze Media' to begin safety violation detection.\n\n• Images: Instant analysis\n• Videos: Frame-by-frame processing",
                    interactive=False,
                )
                pdf_output = gr.File(
                    label="📥 Download Professional Report",
                    elem_classes=["file-component"],
                )
    # Violation Details Section
    gr.HTML('<div class="section-header">🚨 Real-time Violation Monitoring</div>')
    with gr.Group(elem_classes=["professional-card", "alert-panel"]):
        violation_log = gr.Textbox(
            label="🚨 Real-time Violation Details",
            elem_classes=["status-display"],
            lines=10,
            max_lines=12,
            value=format_violations_as_text(recent_violations),
            interactive=False,
        )
    # Live Stream Processing Section
    gr.HTML('<div class="section-header">📹 Live Stream Monitoring</div>')
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group(elem_classes=["professional-card"]):
                rtsp_url_input = gr.Textbox(
                    label="📡 RTSP Stream URL",
                    placeholder="rtsp://example.com/stream",
                    value=RTSP_URL_DEFAULT,
                    interactive=True,
                )
                with gr.Row():
                    rtsp_button = gr.Button(
                        "📡 Start Live Monitoring",
                        variant="primary",
                        elem_classes=["btn-primary"],
                        size="lg",
                    )
                    rtsp_cancel_btn = gr.Button(
                        "⏹️ Stop Monitoring",
                        variant="secondary",
                        elem_classes=["btn-secondary"],
                        size="lg",
                    )
                rtsp_status = gr.Textbox(
                    label="📺 Live Stream Processing Status",
                    elem_classes=["status-display"],
                    lines=6,
                    max_lines=8,
                    value="📺 RTSP Stream Processor Ready\n\nEnter an RTSP URL and click 'Start Live Monitoring' to begin real-time monitoring.",
                    interactive=False,
                )
        with gr.Column(scale=3):
            with gr.Group(elem_classes=["professional-card"]):
                rtsp_output = gr.Gallery(
                    label="🎬 Live Stream Frames & Detection Results",
                    elem_classes=["gallery-component"],
                    height=360,
                    columns=3,
                    rows=2,
                    object_fit="cover",
                )
    # Live Violation Log Section
    gr.HTML('<div class="section-header">📊 Live Violation Analytics</div>')
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card", "alert-panel"]):
                rtsp_violation_log = gr.Textbox(
                    label="🚨 Live Violation Log",
                    elem_classes=["status-display"],
                    lines=8,
                    max_lines=10,
                    interactive=False,
                )
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card", "analytics-panel"]):
                heatmap_output = gr.Image(
                    label="🔥 Violation Heatmap - Temporal Analysis",
                    elem_classes=["image-component"],
                    height=320,
                )
                rtsp_pdf_output = gr.File(
                    label="📥 Download RTSP Report",
                    elem_classes=["file-component"],
                )
    # Professional Footer
    gr.HTML("""
<div class="footer-info">
<h3>🛡️ Dynamic Safety Violation Detection using CCTV + AI</h3>
<div class="feature-grid">
<div class="feature-item">
<strong>🎯 Real-time Detection</strong><br>
Advanced YOLOv8 AI with >90% accuracy
</div>
<div class="feature-item">
<strong>⚡ Ultra-fast Response</strong><br>
Alert generation in <5 seconds
</div>
<div class="feature-item">
<strong>📸 Image & Video Support</strong><br>
Process both static images and video files
</div>
<div class="feature-item">
<strong>📱 Responsive Design</strong><br>
Optimized for desktop, tablet & mobile
</div>
</div>
<div style="margin-top: 0.8rem; padding-top: 0.8rem; border-top: 0.5px solid rgba(255,255,255,0.2);">
<p style="margin: 0; font-size: 0.8rem; opacity: 0.7;">
Dynamic Safety Violation Detection using CCTV + AI © 2025
</p>
</div>
</div>
""")
    # Event Handlers
    media_button.click(
        fn=process_media,
        inputs=[media_input, frame_skip_input],
        outputs=[media_output, media_status, pdf_output, violation_log],
    )
    # Only the URL is wired; max_frames and frame_skip use the handler defaults.
    rtsp_button.click(
        fn=process_rtsp_stream,
        inputs=[rtsp_url_input],
        outputs=[rtsp_status, rtsp_output, rtsp_violation_log, heatmap_output, rtsp_pdf_output],
    )
    rtsp_cancel_btn.click(cancel_processing, outputs=[rtsp_status])
if __name__ == "__main__":
    # Serve on all interfaces at port 7860. share=True additionally asks
    # Gradio for a public share link, and inbrowser=True opens a local tab.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "show_error": True,
        "quiet": False,
        "favicon_path": None,
        "auth": None,
        "inbrowser": True,
    }
    queued_app = demo.queue()
    queued_app.launch(**launch_kwargs)