|
|
import os |
|
|
import sys |
|
|
import warnings |
|
|
import gradio as gr |
|
|
import torch |
|
|
from ultralytics import YOLO |
|
|
import cv2 |
|
|
import requests |
|
|
import json |
|
|
import time |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
import logging |
|
|
import pandas as pd |
|
|
from reportlab.lib.pagesizes import letter |
|
|
from reportlab.pdfgen import canvas |
|
|
from io import BytesIO |
|
|
import seaborn as sns |
|
|
import matplotlib.pyplot as plt |
|
|
import subprocess |
|
|
from datetime import timezone |
|
|
import pytz |
|
|
import shutil |
|
|
import tempfile |
|
|
from scipy.spatial import distance |
|
|
import asyncio |
|
|
from functools import partial |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
from simple_salesforce import Salesforce, SalesforceAuthenticationFailed |
|
|
from retrying import retry |
|
|
import base64 |
|
|
|
|
|
|
|
|
# Suppress every Python warning for the lifetime of the process.
warnings.filterwarnings("ignore")

# Path of the YOLOv8 weights file used by the model loader below.
MODEL_PATH = "./yolov8n.pt"

# Root logging configuration: INFO level, timestamped records, one stream handler.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_HANDLERS = [logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_LOG_HANDLERS)

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Log whether PyTorch can see a CUDA device, and which device will be used.
_cuda_available = torch.cuda.is_available()
logger.info(f"PyTorch CUDA Available: {_cuda_available}")
if not _cuda_available:
    logger.info("Using CPU")
else:
    logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
|
|
|
|
|
|
|
# Download the YOLOv8 weights on first start. The process exits with status 1
# when the download fails, because nothing else in this file can run without
# the model.
if not os.path.exists(MODEL_PATH):
    logger.info(f"Downloading model weights to {MODEL_PATH}...")
    try:
        download_url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt"
        subprocess.run(["wget", download_url, "-O", MODEL_PATH], check=True)
        os.chmod(MODEL_PATH, 0o644)
        logger.info("Model weights downloaded successfully")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing `wget` binary (FileNotFoundError) and a
        # failing chmod. Remove whatever was written to MODEL_PATH: a partial
        # or empty file would make the existence check above skip the download
        # on the next start and then fail at model load.
        if os.path.exists(MODEL_PATH):
            try:
                os.remove(MODEL_PATH)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove partial download {MODEL_PATH}: {cleanup_error}")
        logger.error(f"Failed to download model: {e}")
        sys.exit(1)
|
|
|
|
|
|
|
|
# Directory exported to Ultralytics through the YOLO_CONFIG_DIR variable.
# NOTE(review): `ultralytics` is imported near the top of this file, before this
# assignment runs - confirm the library reads YOLO_CONFIG_DIR late enough for
# the setting to take effect.
yolo_config_dir = "./Ultralytics"
Path(yolo_config_dir).mkdir(parents=True, exist_ok=True)
os.environ["YOLO_CONFIG_DIR"] = yolo_config_dir
|
|
|
|
|
|
|
|
# Runtime configuration, read from environment variables. Every value falls
# back to an empty string when its variable is unset.
RTSP_URL_DEFAULT = os.getenv("RTSP_URL", "")
SALESFORCE_URL = os.getenv("SALESFORCE_URL", "")
SALESFORCE_TOKEN = os.getenv("SALESFORCE_TOKEN", "")
HUGGINGFACE_API_URL = os.getenv("HUGGINGFACE_API_URL", "")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "")

# Salesforce login credentials. These must never be committed to source
# control: supply them through SF_USERNAME / SF_PASSWORD / SF_SECURITY_TOKEN.
# Any credentials that were previously hard-coded here should be rotated.
SF_USERNAME = os.getenv("SF_USERNAME", "")
SF_PASSWORD = os.getenv("SF_PASSWORD", "")
SF_SECURITY_TOKEN = os.getenv("SF_SECURITY_TOKEN", "")

# Salesforce integration is active only when all three credentials are present.
SALESFORCE_ENABLED = all([SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN])
|
|
|
|
|
|
|
|
# Timezone used for every timestamp this file produces (report headers,
# report titles, frame and violation timestamps).
_IST_ZONE_NAME = "Asia/Kolkata"
IST = pytz.timezone(_IST_ZONE_NAME)
|
|
|
|
|
|
|
|
# Shared module-level state.
yolo_model = None          # assigned by initialize_model()
sf_connection = None       # cached by get_salesforce_connection()
processing_active = False
recent_violations = list()
violation_history = list()
site_id_cache = dict()     # site name -> Salesforce record Id, filled by get_or_create_site_id()
|
|
|
|
|
|
|
|
def initialize_model():
    """Load the YOLO weights at MODEL_PATH into the module-level `yolo_model`.

    Returns:
        True when the model was constructed, False when construction raised
        (the error is logged, not propagated).
    """
    global yolo_model
    loaded = False
    try:
        logger.info("Initializing YOLOv8 model...")
        yolo_model = YOLO(MODEL_PATH)
        logger.info("YOLOv8 model loaded successfully")
        loaded = True
    except Exception as load_error:
        logger.error(f"Failed to initialize model: {load_error}")
    return loaded
|
|
|
|
|
# Load the model at import time; without it the application cannot run, so
# terminate the process with exit status 1.
_model_ready = initialize_model()
if not _model_ready:
    logger.error("Critical error: Model initialization failed")
    raise SystemExit(1)
|
|
|
|
|
|
|
|
class _SalesforceNotConfiguredError(ConnectionError):
    """Raised when Salesforce credentials are missing.

    Subclasses ConnectionError so callers that catch ConnectionError (or
    Exception) keep working, while the retry predicate can tell it apart from
    connection failures that are worth another attempt.
    """


def _should_retry_salesforce_connect(exc):
    """Return True when another connection attempt could plausibly succeed.

    Missing configuration and rejected credentials cannot be fixed by waiting,
    and re-submitting bad credentials only adds failed logins, so neither is
    retried.
    """
    return not isinstance(exc, (SalesforceAuthenticationFailed, _SalesforceNotConfiguredError))


@retry(stop_max_attempt_number=3, wait_fixed=2000,
       retry_on_exception=_should_retry_salesforce_connect)
def get_salesforce_connection():
    """Return a live Salesforce connection, reusing the cached one when possible.

    A cached connection is probed with a one-row query; if the probe fails the
    cache is dropped and a fresh login is attempted. Failures other than
    authentication and configuration errors are retried up to three attempts,
    two seconds apart.

    Returns:
        The connected `Salesforce` client (also stored in `sf_connection`).

    Raises:
        ConnectionError: credentials are not configured (SALESFORCE_ENABLED is false).
        SalesforceAuthenticationFailed: Salesforce rejected the credentials.
        Exception: any other login failure, re-raised after the final attempt.
    """
    global sf_connection

    if sf_connection:
        try:
            # Cheap query used only to check that the cached session still works.
            sf_connection.query("SELECT Id FROM User LIMIT 1")
            logger.info("Salesforce connection is active.")
            return sf_connection
        except Exception:
            logger.warning("Salesforce session expired. Reconnecting...")
            sf_connection = None

    if not SALESFORCE_ENABLED:
        raise _SalesforceNotConfiguredError("Salesforce credentials are not configured.")

    try:
        sf_connection = Salesforce(username=SF_USERNAME, password=SF_PASSWORD, security_token=SF_SECURITY_TOKEN)
        logger.info(f"Successfully connected to Salesforce instance: {sf_connection.sf_instance}")
        return sf_connection
    except SalesforceAuthenticationFailed as e:
        logger.error(f"Salesforce authentication failed: {e}. Check credentials and IP restrictions.")
        raise
    except Exception as e:
        logger.error(f"Failed to connect to Salesforce: {e}")
        raise
|
|
|
|
|
def get_or_create_site_id(sf, site_name='SITE001'):
    """Return the Salesforce Id of the `Site__c` record named `site_name`.

    The record is looked up by name and created when no match exists.
    Successful lookups and creations are stored in `site_id_cache`, so later
    calls for the same name do not query Salesforce again.

    Args:
        sf: Connected Salesforce client exposing `query()` and `Site__c.create()`.
        site_name: Value of the site's `Name` field.

    Returns:
        The record Id, or None when the lookup or the creation failed (the
        failure is logged, not raised).
    """
    if site_name in site_id_cache:
        return site_id_cache[site_name]

    try:
        # Escape backslashes and single quotes so a name such as "O'Brien Yard"
        # cannot terminate the SOQL string literal early.
        escaped_name = str(site_name).replace("\\", "\\\\").replace("'", "\\'")
        query = f"SELECT Id FROM Site__c WHERE Name = '{escaped_name}' LIMIT 1"
        result = sf.query(query)

        if result['totalSize'] > 0:
            site_id = result['records'][0]['Id']
            logger.info(f"Found existing Site '{site_name}' with ID: {site_id}")
            site_id_cache[site_name] = site_id
            return site_id

        logger.info(f"Site '{site_name}' not found. Creating new Site record...")
        create_result = sf.Site__c.create({'Name': site_name})
        if 'id' in create_result:
            site_id = create_result['id']
            logger.info(f"Successfully created new Site '{site_name}' with ID: {site_id}")
            site_id_cache[site_name] = site_id
            return site_id

        logger.error(f"Failed to create Site record: {create_result.get('errors')}")
        return None
    except Exception as e:
        logger.error(f"Error getting or creating Site ID for '{site_name}': {e}", exc_info=True)
        return None
|
|
|
|
|
def create_salesforce_violation_record(sf, violation_data):
    """Build the field payload for one `Safety_Violation_Log__c` record.

    Nothing is inserted here; the payload is returned for the caller to submit.

    Args:
        sf: Connected Salesforce client, used to resolve the site record.
        violation_data: Dict describing the violation; its 'site_id' entry
            holds the site name (defaults to 'Default Site').

    Returns:
        `(payload, None)` on success, or `(None, error_message)` when the site
        could not be resolved or an exception occurred.
    """
    try:
        site_name = violation_data.get('site_id', 'Default Site')
        sf_site_id = get_or_create_site_id(sf, site_name)
        if not sf_site_id:
            logger.error(f"Failed to get or create Site record '{site_name}' in Salesforce.")
            return None, "Site ID creation/retrieval failed."

        lookup = violation_data.get
        candidate_fields = (
            ('Site_ID__c', sf_site_id),
            ('Violation_Type__c', lookup('violation_type')),
            ('Severity__c', lookup('severity', 'Medium')),
            ('Timestamp__c', lookup('timestamp')),
            ('Snapshot_URL__c', lookup('snapshot_url', 'N/A')),
            ('Worker_ID__c', lookup('worker_id', 'N/A')),
            ('Camera_ID__c', lookup('camera_id', 'CAM001')),
            ('Alert_Sent__c', True),
            ('PDF_Report_URL__c', lookup('pdf_url', "Report will be available after processing completion.")),
        )

        # Fields whose value is None are left out of the payload entirely.
        record = {}
        for field_name, field_value in candidate_fields:
            if field_value is not None:
                record[field_name] = field_value
        return record, None
    except Exception as e:
        logger.error(f"Failed to prepare Salesforce record: {e}", exc_info=True)
        return None, str(e)
|
|
|
|
|
def _render_violation_report_pdf(violations):
    """Render the violation list as a letter-size PDF and return its bytes."""
    bottom_margin = 50
    buffer = BytesIO()
    c = canvas.Canvas(buffer, pagesize=letter)
    c.setFont("Helvetica-Bold", 16)
    c.drawString(100, 750, "Safety Violation Report")
    c.setFont("Helvetica", 12)
    c.drawString(100, 730, f"Generated: {datetime.now(IST).strftime('%Y-%m-%d %H:%M:%S IST')}")
    c.setFont("Helvetica", 10)
    c.drawString(100, 710, "Note: Each violation type reported only once per person per session.")

    y = 680
    for i, violation in enumerate(violations, 1):
        has_distance = 'distance' in violation
        # Distance from the entry heading down to its last text line. Start a
        # new page *before* drawing when the entry would not fit above the
        # bottom margin, so no line lands in the margin or off the page.
        last_line_offset = 65 if has_distance else 50
        if y - last_line_offset < bottom_margin:
            c.showPage()
            y = 750

        c.setFont("Helvetica-Bold", 12)
        c.drawString(100, y, f"Violation #{i}: {violation['violation_type']}")
        y -= 20
        c.setFont("Helvetica", 10)
        c.drawString(120, y, f"Severity: {violation['severity']}")
        y -= 15
        c.drawString(120, y, f"Time: {violation['timestamp']}")
        y -= 15
        c.drawString(120, y, f"Worker: {violation.get('worker_id', 'UNKNOWN')}")
        y -= 15
        if has_distance:
            c.drawString(120, y, f"Distance: {violation['distance']}")
            y -= 15
        y -= 20  # gap before the next entry

    c.save()
    pdf_bytes = buffer.getvalue()
    buffer.close()
    return pdf_bytes


def generate_and_upload_report_to_salesforce(sf, violations, record_ids):
    """Create a PDF report, attach it to Salesforce records and store its URL.

    Steps: render the PDF, upload it as a `ContentVersion`, link the resulting
    `ContentDocument` to every record in `record_ids`, write the download URL
    into each record's `PDF_Report_URL__c`, and save a local copy of the PDF.

    Args:
        sf: Connected Salesforce client.
        violations: List of violation dicts with 'violation_type', 'severity'
            and 'timestamp'; 'worker_id' and 'distance' are optional.
        record_ids: Ids of the `Safety_Violation_Log__c` records to link/update.

    Returns:
        `(temp_pdf_path, pdf_url)` on success; `(None, None)` when any input is
        empty, the upload fails, or an exception occurs (logged, not raised).
        The temporary file is not deleted here.
    """
    if not violations or not record_ids or not sf:
        logger.warning("No violations, record IDs, or Salesforce connection. Skipping report generation.")
        return None, None

    try:
        pdf_bytes = _render_violation_report_pdf(violations)

        # Upload the PDF; the file content is sent base64-encoded.
        title = f"Safety_Report_{datetime.now(IST).strftime('%Y%m%d_%H%M%S')}"
        b64_pdf = base64.b64encode(pdf_bytes).decode('utf-8')

        logger.info(f"Uploading PDF '{title}.pdf' to Salesforce...")
        cv_result = sf.ContentVersion.create({
            'Title': title,
            'PathOnClient': f'{title}.pdf',
            'VersionData': b64_pdf
        })

        if not cv_result.get('success'):
            logger.error(f"Failed to create ContentVersion: {cv_result.get('errors')}")
            return None, None

        content_version_id = cv_result['id']
        logger.info(f"Successfully created ContentVersion with ID: {content_version_id}")

        # The links below need the ContentDocumentId, which is read back from
        # the ContentVersion that was just created.
        query = f"SELECT ContentDocumentId FROM ContentVersion WHERE Id = '{content_version_id}'"
        cv_query_result = sf.query(query)
        if not cv_query_result['records']:
            logger.error(f"Could not find ContentDocumentId for ContentVersion {content_version_id}")
            return None, None
        content_document_id = cv_query_result['records'][0]['ContentDocumentId']

        # Link the document to each record. A failed link is logged and does
        # not stop the remaining links.
        logger.info(f"Linking ContentDocument {content_document_id} to {len(record_ids)} records...")
        link_success_count = 0
        for record_id in record_ids:
            link_payload = {
                'ContentDocumentId': content_document_id,
                'LinkedEntityId': record_id,
                'ShareType': 'V'
            }
            try:
                link_result = sf.ContentDocumentLink.create(link_payload)
                if link_result.get('success'):
                    link_success_count += 1
                else:
                    logger.warning(f"Failed to link to {record_id}: {link_result.get('errors')}")
            except Exception as e:
                logger.error(f"Error creating ContentDocumentLink for {record_id}: {e}")

        logger.info(f"Successfully created {link_success_count}/{len(record_ids)} links.")

        # Build the download URL and write it onto every record.
        sf_instance_url = sf.sf_instance.replace('https://', '')
        pdf_url = f"https://{sf_instance_url}/sfc/servlet.shepherd/version/download/{content_version_id}"
        logger.info(f"Updating records with Salesforce PDF URL: {pdf_url}")

        update_payloads = [{'Id': record_id, 'PDF_Report_URL__c': pdf_url} for record_id in record_ids]
        update_results = sf.bulk.Safety_Violation_Log__c.update(update_payloads)

        successful_updates = sum(1 for res in update_results if res.get('success'))
        logger.info(f"Successfully updated {successful_updates}/{len(record_ids)} records with the PDF URL.")

        # Keep a local copy so the caller can return a file path.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf', prefix='report_') as temp_pdf:
            temp_pdf.write(pdf_bytes)
            temp_pdf_path = temp_pdf.name

        logger.info(f"Salesforce report URL: {pdf_url}")
        logger.info(f"Temporary local report for download: {temp_pdf_path}")

        return temp_pdf_path, pdf_url

    except Exception as e:
        logger.error(f"Error in Salesforce PDF report generation/upload: {e}", exc_info=True)
        return None, None
|
|
|
|
|
|
|
|
class SafetyViolationDetector:
    """Assign stable ids to detected persons and flag safety violations.

    Three checks are run on every frame: missing helmet, presence inside an
    unauthorized zone, and two persons closer than `unsafe_distance`. Each
    violation type is reported at most once per person per session; call
    `reset_session()` to start a new session.

    All geometry (boxes, zones, distances) is in pixels of whatever image the
    detection boxes refer to.
    """

    def __init__(self):
        # Minimum detection confidence for a box to count as helmet / person.
        self.helmet_threshold = 0.75
        self.person_threshold = 0.60
        # Two person centers closer than this many pixels are a violation.
        self.unsafe_distance = 50
        # Seconds (differences of time.time() values) after which an active
        # violation expires and may be raised again.
        self.violation_cooldown = 20

        # Restricted rectangles as [x1, y1, x2, y2].
        self.unauthorized_zones = [
            [100, 100, 300, 300],
            [400, 200, 600, 400]
        ]

        self.active_violations = {}          # violation key -> state of an ongoing violation
        self.violation_history = {}
        self.person_tracker = {}             # person id -> per-person state and violation counters
        self.person_positions_history = {}   # person id -> recent centers/boxes used for matching
        self.next_person_id = 1
        # A detection farther than this (pixels) from a track's last center
        # cannot be matched to that track.
        self.max_tracking_distance = 120

        # person id -> {violation type -> {'first_detected', 'count'}}
        self.session_violations = {}

    def reset_session(self):
        """Forget all tracked persons and reported violations."""
        self.session_violations = {}
        self.active_violations = {}
        self.person_tracker = {}
        self.person_positions_history = {}
        self.next_person_id = 1
        logger.info("Session violation tracking reset for new video")

    def has_reported_violation(self, person_id, violation_type):
        """Return True if `violation_type` was already reported for `person_id`."""
        return violation_type in self.session_violations.get(person_id, {})

    def mark_violation_reported(self, person_id, violation_type, timestamp):
        """Record that `violation_type` was reported for `person_id`.

        The first call stores `timestamp` as 'first_detected'; later calls keep
        that original timestamp and only increase 'count'.
        """
        person_entry = self.session_violations.setdefault(person_id, {})
        previous = person_entry.get(violation_type)
        person_entry[violation_type] = {
            'first_detected': previous['first_detected'] if previous else timestamp,
            'count': (previous.get('count', 0) if previous else 0) + 1
        }

    def _get_stable_person_id(self, box, current_time):
        """Match `box` to an existing track or start a new one; return the id.

        Each candidate track within `max_tracking_distance` is scored from
        center distance (weight 0.7) and IoU with the track's last box
        (weight 0.3). The best score above 0.3 wins; otherwise a new id is
        issued. Only the 10 most recent positions per track are kept.
        """
        center = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)

        best_match_id = None
        best_match_score = 0

        for person_id, history in self.person_positions_history.items():
            if not history['positions']:
                continue

            last_position = history['positions'][-1]
            last_box = history['boxes'][-1]

            dist = np.sqrt((center[0] - last_position[0])**2 + (center[1] - last_position[1])**2)
            if dist >= self.max_tracking_distance:
                continue

            iou = self._iou(box, last_box)
            score = (1.0 / (1.0 + dist/50)) * 0.7 + iou * 0.3
            if score > best_match_score and score > 0.3:
                best_match_score = score
                best_match_id = person_id

        if best_match_id is not None:
            person_id = best_match_id
        else:
            person_id = self.next_person_id
            self.next_person_id += 1
            self.person_positions_history[person_id] = {
                'positions': [],
                'boxes': [],
                'first_seen': current_time,
                'last_seen': current_time
            }

        track = self.person_positions_history[person_id]
        track['positions'].append(center)
        track['boxes'].append(box)
        track['last_seen'] = current_time

        if len(track['positions']) > 10:
            track['positions'].pop(0)
            track['boxes'].pop(0)

        return person_id

    def detect_violations(self, results, frame):
        """Run all checks on one frame's detections.

        Args:
            results: Detection results; `results[0].boxes` must expose `xyxy`,
                `conf` and `cls` tensors and `results[0].names` the class names.
            frame: Image that violations are drawn onto (modified in place).

        Returns:
            List of violation dicts newly raised for this frame.
        """
        start_time = time.time()
        current_time = start_time
        violations = []
        boxes = results[0].boxes.xyxy.cpu().numpy()
        confidences = results[0].boxes.conf.cpu().numpy()
        class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
        class_names = results[0].names

        persons = []
        helmets = []

        for box, conf, cls_id in zip(boxes, confidences, class_ids):
            class_name = class_names[cls_id]
            if class_name == "person" and conf >= self.person_threshold:
                person_id = self._get_stable_person_id(box, current_time)
                persons.append({
                    'box': box,
                    'confidence': conf,
                    'center': ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2),
                    'id': person_id
                })
            elif class_name == "hard hat" and conf >= self.helmet_threshold:
                helmets.append({
                    'box': box,
                    'confidence': conf,
                    'area': (box[2] - box[0]) * (box[3] - box[1])
                })

        # Create or refresh the per-person tracker entries before running the
        # checks, which read and update them.
        current_person_ids = set()
        for person in persons:
            person_id = person['id']
            current_person_ids.add(person_id)

            if person_id not in self.person_tracker:
                self.person_tracker[person_id] = {
                    'first_seen': current_time,
                    'last_seen': current_time,
                    'positions': [person['center']],
                    'helmet_status': False,
                    'violations': {}
                }
            else:
                entry = self.person_tracker[person_id]
                entry['last_seen'] = current_time
                entry['positions'].append(person['center'])
                if len(entry['positions']) > 10:
                    entry['positions'].pop(0)

        for person in persons:
            helmet_violation = self._check_helmet_violation(person, helmets, frame, current_time)
            if helmet_violation:
                violations.append(helmet_violation)

            unauthorized_violation = self._check_unauthorized_area(person, frame, current_time)
            if unauthorized_violation:
                violations.append(unauthorized_violation)

        distance_violations = self._check_distance_violations(persons, frame, current_time)
        violations.extend(distance_violations)

        self._cleanup_violations(current_time)
        self._cleanup_inactive_persons(current_person_ids, current_time)

        logger.info(f"Violation detection time: {time.time() - start_time:.2f}s")
        return violations

    def _register_person_violation(self, person_id, violation_type, current_time):
        """Increment the per-person counter for `violation_type` and stamp its time."""
        counters = self.person_tracker[person_id]['violations']
        if violation_type not in counters:
            counters[violation_type] = {'count': 0, 'last_time': 0}
        counters[violation_type]['count'] += 1
        counters[violation_type]['last_time'] = current_time

    def _check_helmet_violation(self, person, helmets, frame, current_time):
        """Return a 'no_helmet' violation dict for `person`, or None.

        A helmet counts as worn when its box overlaps the person's head region
        (the top 30% of the person box) with IoU above 0.1.
        """
        person_id = person['id']
        person_box = person['box']
        violation_type = 'no_helmet'

        if self.has_reported_violation(person_id, violation_type):
            return None

        # Head region: from the top edge of the person box down to 30% of its
        # height. The top edge must be the box top - using the 30% line for
        # both edges would give a zero-height region that no helmet overlaps.
        head_bottom = person_box[1] + (person_box[3] - person_box[1]) * 0.3
        head_region = [person_box[0], person_box[1], person_box[2], head_bottom]

        has_helmet = any(self._iou(helmet['box'], head_region) > 0.1 for helmet in helmets)
        self.person_tracker[person_id]['helmet_status'] = has_helmet

        if has_helmet:
            return None

        violation_key = f"no_helmet_{person_id}"
        active = self.active_violations.get(violation_key)
        if active is not None and current_time - active['last_detected'] <= self.violation_cooldown:
            # Still inside the cooldown window: refresh the ongoing violation.
            active['last_detected'] = current_time
            active['count'] += 1
            return None

        self.mark_violation_reported(person_id, violation_type, current_time)
        self.active_violations[violation_key] = {
            'type': 'no_helmet',
            'person_id': person_id,
            'first_detected': current_time,
            'last_detected': current_time,
            'count': 1
        }
        self._register_person_violation(person_id, 'no_helmet', current_time)

        self._annotate_frame(frame, person_box, person_id, "NO HELMET", (0, 0, 255))
        logger.info(f"NEW VIOLATION: No helmet detected for person {person_id}")

        return {
            'type': 'no_helmet',
            'severity': 'Critical',
            'person': person,
            'person_id': person_id,
            'timestamp': current_time
        }

    def _check_unauthorized_area(self, person, frame, current_time):
        """Return an 'unauthorized_area' violation dict for `person`, or None.

        The person is inside a zone when the center of their box lies within
        the zone rectangle (edges inclusive).
        """
        person_id = person['id']
        violation_type = 'unauthorized_area'

        if self.has_reported_violation(person_id, violation_type):
            return None

        x1, y1, x2, y2 = person['box']
        center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2

        for zone in self.unauthorized_zones:
            zx1, zy1, zx2, zy2 = zone
            if not (zx1 <= center_x <= zx2 and zy1 <= center_y <= zy2):
                continue

            violation_key = f"unauthorized_area_{person_id}_{zx1}_{zy1}"
            active = self.active_violations.get(violation_key)
            if active is not None and current_time - active['last_detected'] <= self.violation_cooldown:
                # Ongoing violation for this zone: refresh it and keep
                # checking the remaining zones.
                active['last_detected'] = current_time
                active['count'] += 1
                continue

            self.mark_violation_reported(person_id, violation_type, current_time)
            self.active_violations[violation_key] = {
                'type': 'unauthorized_area',
                'person_id': person_id,
                'zone': zone,
                'first_detected': current_time,
                'last_detected': current_time,
                'count': 1
            }
            self._register_person_violation(person_id, 'unauthorized_area', current_time)

            cv2.rectangle(frame, (zx1, zy1), (zx2, zy2), (255, 0, 255), 2)
            self._annotate_frame(frame, person['box'], person_id, "UNAUTHORIZED", (255, 0, 255))
            logger.info(f"NEW VIOLATION: Unauthorized area detected for person {person_id}")

            return {
                'type': 'unauthorized_area',
                'severity': 'High',
                'person': person,
                'person_id': person_id,
                'zone': zone,
                'timestamp': current_time
            }
        return None

    def _check_distance_violations(self, persons, frame, current_time):
        """Return 'unsafe_distance' violations for every too-close pair.

        A pair is skipped when either person already has an 'unsafe_distance'
        violation reported in this session.
        """
        violations = []
        if len(persons) < 2:
            return violations

        violation_type = 'unsafe_distance'
        for i, first in enumerate(persons):
            for second in persons[i + 1:]:
                dist = self._euclidean_distance(first['center'], second['center'])
                if dist >= self.unsafe_distance:
                    continue

                person1_id = first['id']
                person2_id = second['id']

                if (self.has_reported_violation(person1_id, violation_type) or
                        self.has_reported_violation(person2_id, violation_type)):
                    continue

                # Order-independent key so (a, b) and (b, a) share one entry.
                pair_key = f"{min(person1_id, person2_id)}_{max(person1_id, person2_id)}"
                violation_key = f"unsafe_distance_{pair_key}"

                active = self.active_violations.get(violation_key)
                if active is not None and current_time - active['last_detected'] <= self.violation_cooldown:
                    active['last_detected'] = current_time
                    active['count'] += 1
                    continue

                self.mark_violation_reported(person1_id, violation_type, current_time)
                self.mark_violation_reported(person2_id, violation_type, current_time)

                self.active_violations[violation_key] = {
                    'type': 'unsafe_distance',
                    'person1_id': person1_id,
                    'person2_id': person2_id,
                    'first_detected': current_time,
                    'last_detected': current_time,
                    'count': 1
                }

                for pid in (person1_id, person2_id):
                    self._register_person_violation(pid, 'unsafe_distance', current_time)

                self._annotate_distance(frame, first['box'], second['box'],
                                        person1_id, person2_id, dist)
                logger.info(f"NEW VIOLATION: Unsafe distance detected between persons {person1_id} and {person2_id}")

                violations.append({
                    'type': 'unsafe_distance',
                    'severity': 'Moderate',
                    'person1': first,
                    'person2': second,
                    'distance': dist,
                    'person1_id': person1_id,
                    'person2_id': person2_id,
                    'timestamp': current_time
                })
        return violations

    def _cleanup_violations(self, current_time):
        """Drop active violations not seen for longer than the cooldown."""
        expired_violations = [
            key for key, state in self.active_violations.items()
            if current_time - state['last_detected'] > self.violation_cooldown
        ]
        for key in expired_violations:
            del self.active_violations[key]

    def _cleanup_inactive_persons(self, current_person_ids, current_time):
        """Forget persons absent from this frame and unseen for over 60 seconds."""
        inactive_timeout = 60
        expired_persons = [
            pid for pid, data in self.person_tracker.items()
            if pid not in current_person_ids and
            current_time - data['last_seen'] > inactive_timeout
        ]
        for pid in expired_persons:
            del self.person_tracker[pid]
            self.person_positions_history.pop(pid, None)

    def _annotate_frame(self, frame, box, person_id, violation_type, color):
        """Draw `box` and an "ID:nnn <violation>" label on `frame` in `color`."""
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        label = f"ID:{person_id:03d} {violation_type}"
        cv2.putText(frame, label, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    def _annotate_distance(self, frame, box1, box2, id1, id2, dist):
        """Draw both boxes, a line joining their centers and the distance label."""
        x1, y1, x2, y2 = map(int, box1)
        x3, y3, x4, y4 = map(int, box2)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 165, 255), 2)
        cv2.rectangle(frame, (x3, y3), (x4, y4), (0, 165, 255), 2)
        center1 = ((x1 + x2) // 2, (y1 + y2) // 2)
        center2 = ((x3 + x4) // 2, (y3 + y4) // 2)
        cv2.line(frame, center1, center2, (0, 165, 255), 2)
        mid_point = ((center1[0] + center2[0]) // 2, (center1[1] + center2[1]) // 2)
        cv2.putText(frame, f"{dist:.1f}px", mid_point,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 165, 255), 2)

    def _iou(self, box1, box2):
        """Return intersection-over-union of two [x1, y1, x2, y2] boxes.

        A small epsilon in the denominator avoids division by zero for
        degenerate boxes.
        """
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        intersection = max(0, x2 - x1) * max(0, y2 - y1)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        return intersection / (area1 + area2 - intersection + 1e-6)

    def _euclidean_distance(self, point1, point2):
        """Return the straight-line distance between two (x, y) points."""
        return np.sqrt((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)

    def get_session_summary(self):
        """Summarize the violations reported in the current session.

        Returns:
            Dict with 'total_persons' (number of persons that have at least one
            reported violation), 'violations_by_type' (type -> number of
            persons) and 'persons_with_violations' (one entry per person).
        """
        summary = {
            'total_persons': len(self.session_violations),
            'violations_by_type': {},
            'persons_with_violations': []
        }

        by_type = summary['violations_by_type']
        for person_id, violations in self.session_violations.items():
            summary['persons_with_violations'].append({
                'person_id': person_id,
                'violations': list(violations),
                'violation_count': len(violations)
            })
            for violation_type in violations:
                by_type[violation_type] = by_type.get(violation_type, 0) + 1

        return summary
|
|
|
|
|
|
|
|
def preprocess_frame(frame):
    """Prepare a BGR frame for detection.

    Applies a linear contrast/brightness adjustment (alpha=1.2, beta=20),
    converts BGR to RGB and resizes to 320x320.

    Returns:
        The resized RGB image.

    Raises:
        Exception: whatever OpenCV raised; it is logged and re-raised.
    """
    try:
        boosted = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
        rgb = cv2.cvtColor(boosted, cv2.COLOR_BGR2RGB)
        return cv2.resize(rgb, (320, 320))
    except Exception as e:
        logger.error(f"Frame preprocessing error: {e}")
        raise
|
|
|
|
|
def capture_rtsp_frames(rtsp_url, max_frames=None):
    """Yield frames read from an RTSP stream.

    Args:
        rtsp_url: Stream address passed to `cv2.VideoCapture`.
        max_frames: Stop after this many frames; None reads until the stream
            ends or a read fails.

    Yields:
        `(frame, timestamp, frame_count, None)` where `timestamp` is the
        ISO-formatted current time in IST and `frame_count` starts at 1.

    Raises:
        ValueError: the stream could not be opened.
        Exception: any other capture error; it is logged and re-raised.
    """
    cap = None
    try:
        logger.info(f"Connecting to RTSP stream: {rtsp_url}")
        cap = cv2.VideoCapture(rtsp_url)
        if not cap.isOpened():
            raise ValueError(f"RTSP stream not accessible: {rtsp_url}")

        frame_count = 0
        while cap.isOpened() and (max_frames is None or frame_count < max_frames):
            ret, frame = cap.read()
            if not ret:
                logger.warning("Failed to read frame from RTSP stream")
                break
            timestamp = datetime.now(IST).isoformat()
            frame_count += 1
            yield frame, timestamp, frame_count, None
    except Exception as e:
        logger.error(f"RTSP capture error: {e}")
        raise
    finally:
        # Release in `finally` so the capture is also closed when the consumer
        # stops iterating early (generator close) or an error occurs.
        if cap is not None:
            cap.release()
        cv2.destroyAllWindows()
|
|
|
|
|
|
|
|
async def process_image(image_path, progress=gr.Progress()):
    """Analyze a single image for safety violations.

    Loads the image, runs detection, logs and alerts on every violation, and
    produces a PDF report. When a Salesforce connection is available, records
    are bulk-inserted and the report is uploaded there; otherwise a local
    report is generated.

    Args:
        image_path: Path of the image file to analyze.
        progress: Gradio progress callback.

    Returns:
        `(output_frames, status_message, pdf_path, violations_text)`.
        `output_frames` is `[frame]` when violations were found, else None.
        On failure the first and third items are None and `status_message`
        carries the error text.
    """
    try:
        logger.info(f"Starting image analysis: {image_path}")
        start_time = time.time()

        current_run_violations = []
        new_sf_record_ids = []
        violation_payloads = []
        tracker = SafetyViolationDetector()

        tracker.reset_session()
        logger.info("Starting new image analysis session")

        # Salesforce is optional: a failed connection is logged and the
        # analysis continues with the local report path.
        sf = None
        if SALESFORCE_ENABLED:
            try:
                sf = get_salesforce_connection()
            except Exception as e:
                logger.error(f"Could not connect to Salesforce: {e}")

        progress(0.1, desc="Loading image...")

        frame = cv2.imread(image_path)
        if frame is None:
            error_msg = f"Failed to load image: {image_path}"
            logger.error(error_msg)
            return None, error_msg, None, format_violations_as_text([])

        progress(0.3, desc="Preprocessing image...")
        processed_frame = preprocess_frame(frame)

        progress(0.5, desc="Running AI detection...")
        detections = yolo_model.predict(processed_frame)

        progress(0.7, desc="Analyzing violations...")
        # NOTE(review): detection runs on the resized `processed_frame`, while
        # zones are checked and annotations drawn against the original `frame`;
        # confirm the two share one coordinate space or rescale the boxes.
        violations = tracker.detect_violations(detections, frame)

        violation_count = len(violations)
        timestamp = datetime.now(IST).isoformat()

        for violation in violations:
            snapshot_url = save_snapshot(frame, save_to_disk=False)
            if violation['type'] == 'unsafe_distance':
                worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
            else:
                worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"

            violation_data = {
                'violation_type': violation['type'].replace('_', ' ').title(),
                'severity': violation['severity'],
                'timestamp': timestamp,
                'snapshot_url': snapshot_url,
                'site_id': 'SITE001',
                'camera_id': 'CAM001',
                'worker_id': worker_id,
                'frame_number': 1
            }

            if violation['type'] == 'unsafe_distance':
                violation_data['distance'] = f"{violation['distance']:.1f}px"

            current_run_violations.append(violation_data)
            log_violation(violation_data)
            send_alert(violation_data)

            if sf:
                payload, error = create_salesforce_violation_record(sf, violation_data)
                if payload:
                    violation_payloads.append(payload)
                else:
                    logger.error(f"Salesforce payload creation failed: {error}")

        progress(0.8, desc="Creating Salesforce records...")

        if sf and violation_payloads:
            try:
                insert_results = sf.bulk.Safety_Violation_Log__c.insert(violation_payloads)
                new_sf_record_ids = [result['id'] for result in insert_results if result.get('success')]
                logger.info(f"Created {len(new_sf_record_ids)} Salesforce records")
                for result in insert_results:
                    if not result.get('success'):
                        logger.error(f"Failed to create record: {result.get('errors')}")
            except Exception as e:
                logger.error(f"Failed to create bulk Salesforce records: {e}")

        progress(0.9, desc="Generating report...")

        pdf_temp_path = None
        if sf and new_sf_record_ids and current_run_violations:
            logger.info("Generating and uploading PDF report to Salesforce...")
            pdf_temp_path, _ = generate_and_upload_report_to_salesforce(
                sf, current_run_violations, new_sf_record_ids
            )
            if not pdf_temp_path:
                logger.error("Failed to generate Salesforce report")
        elif current_run_violations and not sf:
            # No Salesforce connection: fall back to a local-only report.
            pdf_temp_path = generate_local_pdf_report(current_run_violations)

        processing_time = time.time() - start_time
        session_summary = tracker.get_session_summary()

        progress(1.0, desc="Analysis complete!")

        # Status text shown in the UI. The symbols were mis-decoded in the
        # previous version of this file and are restored here.
        rule = "━" * 49
        status_lines = [
            "✅ IMAGE ANALYSIS COMPLETED",
            rule,
            "📊 RESULTS:",
            f"• Processing Time: {processing_time:.2f}s",
            f"• Image: {os.path.basename(image_path)}",
            f"👥 UNIQUE PERSONS TRACKED: {session_summary['total_persons']}",
        ]
        if violation_count > 0:
            status_lines += [
                f"📋 VIOLATION TYPES: {', '.join(session_summary['violations_by_type'].keys())}",
                f"🚨 UNIQUE VIOLATIONS: {violation_count}",
                rule,
                "Each violation reported only once per person",
            ]
        else:
            status_lines += [
                "✅ NO VIOLATIONS DETECTED",
                rule,
                "All safety protocols followed",
            ]
        status_message = "\n".join(status_lines)

        logger.info(f"Image analysis complete. Processing time: {processing_time:.2f}s")

        # The annotated frame is returned only when something was flagged.
        output_frames = [frame] if violations else None

        return output_frames, status_message, pdf_temp_path, format_violations_as_text(current_run_violations)

    except Exception as e:
        logger.error(f"Image processing error: {e}", exc_info=True)
        error_message = f"Image processing failed: {str(e)}"
        return None, error_message, None, format_violations_as_text([])
|
|
|
|
|
def generate_local_pdf_report(violations):
    """Generate a local PDF report when Salesforce is not available.

    Args:
        violations: Iterable of violation dicts. Each entry must provide
            ``violation_type``, ``severity`` and ``timestamp``;
            ``worker_id`` and ``distance`` are optional.

    Returns:
        Path of the generated PDF file, or ``None`` if generation failed.
    """
    top_y = 750
    bottom_margin = 50
    # Lowest line of one entry relative to its heading:
    # heading gap (20) + three further detail lines (15 each).
    entry_depth = 20 + 3 * 15

    pdf_path = None
    try:
        # Only the unique file name is needed. Close our own handle right
        # away so reportlab is the sole writer and no descriptor is leaked
        # if drawing fails half-way.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix='.pdf', prefix='safety_report_'
        ) as temp_file:
            pdf_path = temp_file.name

        c = canvas.Canvas(pdf_path, pagesize=letter)
        c.setFont("Helvetica-Bold", 16)
        c.drawString(100, 750, "Safety Violation Report")
        c.setFont("Helvetica", 12)
        c.drawString(100, 730, f"Generated: {datetime.now(IST).strftime('%Y-%m-%d %H:%M:%S IST')}")
        c.setFont("Helvetica", 10)
        c.drawString(100, 710, "Note: Each violation type reported only once per person")

        y = 680
        for i, violation in enumerate(violations, 1):
            # Break the page *before* drawing, otherwise the last lines of an
            # entry that starts near the bottom are drawn off the page.
            if y - entry_depth < bottom_margin:
                c.showPage()
                y = top_y

            c.setFont("Helvetica-Bold", 12)
            c.drawString(100, y, f"Violation #{i}: {violation['violation_type']}")
            y -= 20

            c.setFont("Helvetica", 10)
            c.drawString(120, y, f"Severity: {violation['severity']}")
            y -= 15
            c.drawString(120, y, f"Time: {violation['timestamp']}")
            y -= 15
            c.drawString(120, y, f"Worker: {violation.get('worker_id', 'UNKNOWN')}")
            y -= 15
            if 'distance' in violation:
                c.drawString(120, y, f"Distance: {violation['distance']}")
                y -= 15

            # Vertical gap between entries.
            y -= 20

        c.save()
        return pdf_path
    except Exception as e:
        logger.error(f"Local PDF generation error: {e}", exc_info=True)
        # Do not leave an empty or partial report behind.
        if pdf_path and os.path.exists(pdf_path):
            try:
                os.remove(pdf_path)
            except OSError:
                pass
        return None
|
|
|
|
|
|
|
|
async def process_media(media_file, frame_skip=5, progress=gr.Progress()):
    """Handle both image and video processing.

    Dispatches the uploaded file to ``process_image`` or ``process_video``
    based on its (case-insensitive) file extension.

    Args:
        media_file: Uploaded file. Either a path (``str`` / ``os.PathLike``)
            or an object exposing the path as ``.name``; ``None`` when
            nothing was uploaded.
        frame_skip: Forwarded to ``process_video``.
        progress: Gradio progress callback forwarded to the worker.

    Returns:
        The 4-tuple returned by the selected worker, or
        ``(None, message, None, violations_text)`` when no file was given or
        the extension is unsupported.
    """
    if media_file is None:
        return None, "No file uploaded", None, format_violations_as_text([])

    # Depending on how the upload component is configured the value is
    # either a plain path or a file-like wrapper carrying the path in
    # `.name`; accept both.
    if isinstance(media_file, (str, os.PathLike)):
        file_path = os.fspath(media_file)
    else:
        file_path = media_file.name

    file_extension = os.path.splitext(file_path)[1].lower()

    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}
    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm', '.m4v'}

    if file_extension in image_extensions:
        logger.info(f"Processing image: {file_path}")
        return await process_image(file_path, progress)

    if file_extension in video_extensions:
        logger.info(f"Processing video: {file_path}")
        return await process_video(file_path, frame_skip, progress)

    error_msg = f"Unsupported file format: {file_extension}. Please upload an image or video file."
    logger.error(error_msg)
    return None, error_msg, None, format_violations_as_text([])
|
|
|
|
|
|
|
|
async def process_video(video_path, frame_skip=5, progress=gr.Progress()):
    """Run safety-violation detection over a video file.

    Every ``frame_skip``-th frame is preprocessed, run through the YOLO model
    and handed to a fresh ``SafetyViolationDetector``. Each reported
    violation is logged, alerted and, when a Salesforce connection is
    available, queued for one bulk insert after the video has been read.
    A PDF report is produced at the end of the run.

    Args:
        video_path: Path of the video file to open with OpenCV.
        frame_skip: Analyse one frame out of this many. Values below 1 are
            treated as 1.
        progress: Gradio progress callback.

    Returns:
        ``(frames, status_message, pdf_path, violations_text)``. ``frames``
        holds at most the last 10 analysed frames that contained violations.
        ``frames`` and ``pdf_path`` are ``None`` when the run fails or is
        cancelled.
    """
    global processing_active
    processing_active = True
    start_total = time.time()
    cap = None

    try:
        # A zero or fractional value coming from the UI would break the
        # modulo test in the read loop.
        frame_skip = max(1, int(frame_skip))

        current_run_violations = []
        new_sf_record_ids = []
        violation_payloads = []
        tracker = SafetyViolationDetector()

        tracker.reset_session()
        logger.info("Starting new video analysis session")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            error_message = f"Failed to open video file: {video_path}"
            logger.error(error_message)
            return None, error_message, None, format_violations_as_text([])

        frames = []
        max_display_frames = 10
        frame_count = 0
        processed_frames = 0
        violation_count = 0

        # Salesforce is optional: a failed connection only disables the upload.
        sf = None
        if SALESFORCE_ENABLED:
            try:
                sf = get_salesforce_connection()
            except Exception as e:
                logger.error(f"Could not connect to Salesforce at start: {e}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 else 0

        progress(0, desc="Analyzing video...")

        while cap.isOpened() and processing_active:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            if frame_count % frame_skip != 0:
                continue

            processed_frames += 1
            timestamp = datetime.now(IST).isoformat()

            # Guard against a reported frame count of 0, which would
            # otherwise raise ZeroDivisionError here.
            if total_frames > 0:
                progress(
                    min(1.0, frame_count / total_frames),
                    desc=f"Processing frame {frame_count}/{total_frames}",
                )
            else:
                progress(0, desc=f"Processing frame {frame_count}")

            processed_frame = preprocess_frame(frame)
            results = yolo_model.predict(processed_frame)

            violations = tracker.detect_violations(results, frame)

            for violation in violations:
                violation_count += 1
                snapshot_url = save_snapshot(frame, save_to_disk=False)
                worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"
                if violation['type'] == 'unsafe_distance':
                    worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
                violation_data = {
                    'violation_type': violation['type'].replace('_', ' ').title(),
                    'severity': violation['severity'],
                    'timestamp': timestamp,
                    'snapshot_url': snapshot_url,
                    'site_id': 'SITE001',
                    'camera_id': 'CAM001',
                    'worker_id': worker_id,
                    'frame_number': frame_count
                }

                if violation['type'] == 'unsafe_distance':
                    violation_data['distance'] = f"{violation['distance']:.1f}px"

                current_run_violations.append(violation_data)
                log_violation(violation_data)
                send_alert(violation_data)

                if sf:
                    payload, error = create_salesforce_violation_record(sf, violation_data)
                    if payload:
                        violation_payloads.append(payload)
                    else:
                        logger.error(f"Salesforce push failed for violation: {error}")

            # Keep only the most recent `max_display_frames` violation frames.
            if violations:
                if len(frames) >= max_display_frames:
                    frames.pop(0)
                frames.append(frame)

        # Free the capture before the network-bound steps below; the
        # `finally` clause covers the error paths.
        cap.release()

        if sf and violation_payloads:
            try:
                insert_results = sf.bulk.Safety_Violation_Log__c.insert(violation_payloads)
                new_sf_record_ids = [result['id'] for result in insert_results if result.get('success')]
                logger.info(f"Created {len(new_sf_record_ids)} Salesforce records in bulk")
                for result in insert_results:
                    if not result.get('success'):
                        logger.error(f"Failed to create record: {result.get('errors')}")
            except Exception as e:
                logger.error(f"Failed to create bulk Salesforce records: {e}", exc_info=True)

        processing_time = time.time() - start_total
        actual_fps = processed_frames / processing_time if processing_time > 0 else 0

        if not processing_active:
            return None, "Processing cancelled", None, format_violations_as_text([])

        pdf_temp_path = None
        if sf and new_sf_record_ids and current_run_violations:
            logger.info(f"Generating PDF report and uploading to Salesforce for {len(new_sf_record_ids)} violations...")
            pdf_temp_path, _ = generate_and_upload_report_to_salesforce(sf, current_run_violations, new_sf_record_ids)
            if not pdf_temp_path:
                logger.error("Failed to generate and upload Salesforce report.")
        elif current_run_violations and not sf:
            pdf_temp_path = generate_local_pdf_report(current_run_violations)
        elif not current_run_violations:
            logger.info("No violations detected, skipping report generation.")
        else:
            # Only reachable with a live connection, recorded violations and
            # no successfully created Salesforce records.
            logger.warning("Salesforce connected but no violation records were created. Skipping Salesforce report upload.")

        session_summary = tracker.get_session_summary()
        logger.info(f"Video analysis complete. Session summary: {session_summary}")

        status_message = generate_status_message(
            violation_count > 0,
            frame_count,
            processed_frames,
            duration,
            violation_count,
            processing_time,
            actual_fps,
            session_summary
        )

        return frames, status_message, pdf_temp_path, format_violations_as_text(current_run_violations)

    except Exception as e:
        logger.error(f"Video processing error: {e}", exc_info=True)
        error_message = f"Video processing failed: {str(e)}"
        return None, error_message, None, format_violations_as_text([])

    finally:
        processing_active = False
        # Release on every exit path; calling release() again after the
        # normal-path release above is harmless.
        if cap is not None:
            cap.release()
        cv2.destroyAllWindows()
        logger.info(f"Total processing time: {time.time() - start_total:.2f}s")
|
|
|
|
|
|
|
|
async def process_rtsp_stream(rtsp_url, max_frames=None, frame_skip=5, progress=gr.Progress()):
    """Run safety-violation detection over frames read from an RTSP stream.

    Frames are pulled from ``capture_rtsp_frames``; every ``frame_skip``-th
    one is preprocessed, run through the YOLO model and handed to a fresh
    ``SafetyViolationDetector``. Violations are logged, alerted and, when a
    Salesforce connection is available, bulk-inserted once the capture ends.

    Args:
        rtsp_url: Stream URL; an empty value is reported as an error.
        max_frames: Forwarded to ``capture_rtsp_frames`` and used to scale
            the progress bar (100 is assumed when it is not set).
        frame_skip: Analyse one frame out of this many. Values below 1 are
            treated as 1.
        progress: Gradio progress callback.

    Returns:
        ``(status_message, frames, violations_text, heatmap_path, pdf_path)``.
        Note the order differs from ``process_video``. On failure ``frames``,
        ``heatmap_path`` and ``pdf_path`` are ``None``.
    """
    global processing_active
    processing_active = True
    start_total = time.time()

    try:
        if not rtsp_url:
            raise ValueError("RTSP URL not provided")

        # A zero or fractional value coming from the UI would break the
        # modulo test in the frame loop.
        frame_skip = max(1, int(frame_skip))

        current_run_violations = []
        new_sf_record_ids = []
        violation_payloads = []
        tracker = SafetyViolationDetector()

        tracker.reset_session()
        logger.info("Starting new RTSP stream analysis session")

        # Salesforce is optional: a failed connection only disables the upload.
        sf = None
        if SALESFORCE_ENABLED:
            try:
                sf = get_salesforce_connection()
            except Exception as e:
                logger.error(f"Could not connect to Salesforce at start: {e}")

        frames = []
        max_display_frames = 10
        processed_frames = 0
        violation_count = 0

        progress(0, desc="Connecting to RTSP stream...")

        for frame, timestamp, fc, _ in capture_rtsp_frames(rtsp_url, max_frames):
            if not processing_active:
                break

            if fc % frame_skip != 0:
                continue

            processed_frames += 1

            progress(min(1.0, fc / (max_frames if max_frames else 100)), desc=f"Processing frame {fc}")

            processed_frame = preprocess_frame(frame)
            results = yolo_model.predict(processed_frame)

            violations = tracker.detect_violations(results, frame)

            for violation in violations:
                violation_count += 1
                snapshot_url = save_snapshot(frame, save_to_disk=False)
                worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"
                if violation['type'] == 'unsafe_distance':
                    worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
                violation_data = {
                    'violation_type': violation['type'].replace('_', ' ').title(),
                    'severity': violation['severity'],
                    'timestamp': timestamp,
                    'snapshot_url': snapshot_url,
                    'site_id': 'SITE001',
                    'camera_id': 'CAM001',
                    'worker_id': worker_id,
                    'frame_number': fc
                }

                if violation['type'] == 'unsafe_distance':
                    violation_data['distance'] = f"{violation['distance']:.1f}px"

                current_run_violations.append(violation_data)
                log_violation(violation_data)
                send_alert(violation_data)

                if sf:
                    payload, error = create_salesforce_violation_record(sf, violation_data)
                    if payload:
                        violation_payloads.append(payload)
                    else:
                        logger.error(f"Salesforce push failed for violation: {error}")

            # Keep only the most recent `max_display_frames` violation frames.
            if violations:
                if len(frames) >= max_display_frames:
                    frames.pop(0)
                frames.append(frame)

        if sf and violation_payloads:
            try:
                insert_results = sf.bulk.Safety_Violation_Log__c.insert(violation_payloads)
                new_sf_record_ids = [result['id'] for result in insert_results if result.get('success')]
                logger.info(f"Created {len(new_sf_record_ids)} Salesforce records in bulk")
                for result in insert_results:
                    if not result.get('success'):
                        logger.error(f"Failed to create record: {result.get('errors')}")
            except Exception as e:
                logger.error(f"Failed to create bulk Salesforce records: {e}", exc_info=True)

        if not processing_active:
            logger.info("Processing cancelled.")

        # A cancelled run still gets a report for what was captured so far.
        pdf_temp_path = None
        if sf and new_sf_record_ids and current_run_violations:
            logger.info(f"Generating PDF report and uploading to Salesforce for {len(new_sf_record_ids)} violations...")
            pdf_temp_path, _ = generate_and_upload_report_to_salesforce(sf, current_run_violations, new_sf_record_ids)
            if not pdf_temp_path:
                logger.error("Failed to generate and upload Salesforce report.")
        elif current_run_violations and not sf:
            pdf_temp_path = generate_local_pdf_report(current_run_violations)
        elif not current_run_violations:
            logger.info("No violations detected, skipping report generation.")
        else:
            # Only reachable with a live connection, recorded violations and
            # no successfully created Salesforce records.
            logger.warning("Salesforce connected but no violation records were created. Skipping Salesforce report upload.")

        if not processing_active:
            return "Processing cancelled.", frames, format_violations_as_text(current_run_violations), generate_heatmap(current_run_violations, generate=False), pdf_temp_path

        session_summary = tracker.get_session_summary()
        logger.info(f"RTSP analysis complete. Session summary: {session_summary}")

        # Report the number of analysed frames, not the size of the capped
        # display buffer.
        status_message = f"Processed {processed_frames} frames with {violation_count} unique violations. Persons tracked: {session_summary['total_persons']}"

        return status_message, frames, format_violations_as_text(current_run_violations), generate_heatmap(current_run_violations, generate=False), pdf_temp_path

    except Exception as e:
        logger.error(f"RTSP processing error: {e}", exc_info=True)
        error_message = f"RTSP processing failed: {str(e)}"
        return error_message, None, format_violations_as_text([]), None, None

    finally:
        processing_active = False
        cv2.destroyAllWindows()
        logger.info(f"Total processing time: {time.time() - start_total:.2f}s")
|
|
|
|
|
|
|
|
def generate_status_message(has_violations, total_frames, processed_frames, duration,
                            violation_count, processing_time, actual_fps, session_summary=None):
    """Build the multi-line status text shown after a video analysis run.

    Args:
        has_violations: Selects the "violations" or the "all clear" footer.
        total_frames: Number of frames read from the source.
        processed_frames: Number of frames actually analysed.
        duration: Video duration in seconds.
        violation_count: Number of violations reported in this run.
        processing_time: Wall-clock processing time in seconds.
        actual_fps: Analysed frames per second of processing time.
        session_summary: Optional dict with ``total_persons`` and
            ``violations_by_type``; when falsy the two summary lines are
            omitted.

    Returns:
        The status message as a single newline-separated string.
    """
    separator = "━" * 49

    lines = [
        "✅ ANALYSIS COMPLETED",
        separator,
        "📊 RESULTS:",
        f"• Frames: {total_frames} (Processed: {processed_frames})",
        f"• Duration: {duration:.2f}s",
        f"• Processing Time: {processing_time:.2f}s",
        f"• FPS: {actual_fps:.1f}",
    ]

    if session_summary:
        by_type = session_summary['violations_by_type']
        types_text = ', '.join(by_type.keys()) if by_type else 'None'
        lines.append(f"👥 UNIQUE PERSONS TRACKED: {session_summary['total_persons']}")
        lines.append(f"📋 VIOLATION TYPES: {types_text}")

    if has_violations:
        lines.extend([
            f"🚨 UNIQUE VIOLATIONS: {violation_count}",
            separator,
            "Each violation reported only once per person",
        ])
    else:
        lines.extend([
            "✅ NO VIOLATIONS DETECTED",
            separator,
            "All safety protocols followed",
        ])

    return "\n".join(lines)
|
|
|
|
|
def save_snapshot(frame, save_to_disk=True):
    """Persist a frame as a JPEG snapshot and return its reference.

    Args:
        frame: Image to write; passed unchanged to ``cv2.imwrite``.
        save_to_disk: When false nothing is written and a fixed placeholder
            name is returned.

    Returns:
        The path of the written file under ``./snapshots``,
        ``"in_memory_snapshot.jpg"`` when ``save_to_disk`` is false, or
        ``"snapshot_failed.jpg"`` when the snapshot could not be written.
    """
    try:
        if not save_to_disk:
            return "in_memory_snapshot.jpg"

        snapshot_dir = "./snapshots"
        os.makedirs(snapshot_dir, exist_ok=True)

        # Nanosecond resolution: several violations detected within the same
        # second must not overwrite each other's snapshot.
        filename = f"snapshot_{time.time_ns()}.jpg"
        snapshot_path = os.path.join(snapshot_dir, filename)

        # imwrite signals a failed write through its boolean return value,
        # so an unchecked call would report a path that does not exist.
        if not cv2.imwrite(snapshot_path, frame):
            logger.error(f"Snapshot error: could not write {snapshot_path}")
            return "snapshot_failed.jpg"

        return snapshot_path
    except Exception as e:
        logger.error(f"Snapshot error: {e}")
        return "snapshot_failed.jpg"
|
|
|
|
|
def log_violation(violation_data):
    """Append a violation to the JSON log and refresh the in-memory history.

    The log lives at ``./snapshots/violation_logs.json`` and holds one JSON
    list. After the call the module globals ``violation_history`` (all
    entries) and ``recent_violations`` (last ten) reflect the updated list.
    Errors are logged and swallowed so that logging never interrupts
    frame processing.

    Args:
        violation_data: JSON-serialisable dict describing the violation.
    """
    global recent_violations, violation_history

    try:
        log_file = Path("./snapshots/violation_logs.json")
        logs = []

        if log_file.exists():
            try:
                with open(log_file, "r", encoding="utf-8") as f:
                    logs = json.load(f)
            except json.JSONDecodeError as e:
                # Keep the unreadable file for inspection and start a fresh
                # log instead of dropping every future violation.
                logger.error(f"Logging error: unreadable violation log, starting a new one: {e}")
                log_file.replace(log_file.with_name(log_file.name + ".corrupt"))
                logs = []
            if not isinstance(logs, list):
                logs = []

        logs.append(violation_data)

        recent_violations = logs[-10:]
        violation_history = logs

        # Serialise before opening the file so a value that cannot be
        # encoded does not leave a truncated log behind.
        serialized = json.dumps(logs, indent=4)

        # The directory is not guaranteed to exist, e.g. when snapshots are
        # kept in memory only.
        log_file.parent.mkdir(parents=True, exist_ok=True)
        with open(log_file, "w", encoding="utf-8") as f:
            f.write(serialized)
    except Exception as e:
        logger.error(f"Logging error: {e}")
|
|
|
|
|
def send_alert(violation):
    """Emit an INFO log line announcing the given violation.

    Args:
        violation: Dict providing ``violation_type`` and ``severity``.
    """
    violation_type = violation['violation_type']
    severity = violation['severity']
    logger.info(f"ALERT: {violation_type} detected (Severity: {severity})")
|
|
|
|
|
def format_violations_as_text(violations):
    """Render the violations of a run as the text shown in the alert panel.

    Args:
        violations: List of violation dicts. Each entry must provide
            ``violation_type``, ``severity``, ``timestamp``, ``site_id``,
            ``camera_id`` and ``snapshot_url``; ``worker_id`` is optional.

    Returns:
        An "all clear" status text when ``violations`` is empty, otherwise a
        header, one boxed alert per violation and summary statistics.
    """
    separator = "━" * 42

    if not violations:
        last_updated = datetime.now(IST).strftime('%Y-%m-%d %H:%M:%S IST')
        return "\n".join([
            "🔍 SAFETY MONITORING STATUS",
            "",
            "✅ NO VIOLATIONS DETECTED",
            separator,
            "",
            "📊 Current Status: ALL CLEAR",
            f"🕐 Last Updated: {last_updated}",
            "🎯 Detection Accuracy: >90% confidence",
            "⚡ Response Time: <5 seconds",
            "",
            "The system is actively monitoring for:",
            "• No Helmet violations",
            "• Unsafe Distance violations",
            "• Unauthorized Area violations",
            "",
            "All safety protocols are currently being followed.",
        ])

    box_bottom = "└" + "─" * 49

    # Collect the pieces and join once instead of growing a string with +=.
    parts = [
        "\n".join([
            "🚨 SAFETY VIOLATION ALERTS",
            separator,
            "",
            f"📊 UNIQUE VIOLATIONS DETECTED: {len(violations)}",
            "Note: Each violation type reported only once per person",
            "",
            "",
        ])
    ]

    critical_count = 0
    moderate_count = 0
    for i, violation in enumerate(violations, 1):
        severity = violation['severity']
        if severity == 'Critical':
            critical_count += 1
        elif severity == 'Moderate':
            moderate_count += 1
        # Anything that is not Critical gets the yellow marker.
        severity_emoji = "🔴" if severity == 'Critical' else "🟡"

        parts.append("\n".join([
            "",
            f"┌─ ALERT #{i:02d} ─ {severity_emoji} {violation['violation_type'].upper()}",
            "│",
            f"├─ 🕐 Time: {violation['timestamp']}",
            f"├─ ⚠️ Severity: {severity}",
            f"├─ 📍 Location: Site {violation['site_id']} | Camera {violation['camera_id']}",
            f"├─ 👷 Worker: {violation.get('worker_id', 'UNKNOWN')}",
            f"├─ 📸 Evidence: {violation['snapshot_url']}",
            "│",
            box_bottom,
            "",
        ]))

    parts.append("\n".join([
        "",
        "",
        "📈 SUMMARY STATISTICS:",
        separator,
        f"• Total Violations: {len(violations)}",
        f"• Critical: {critical_count}",
        f"• Moderate: {moderate_count}",
        # `violations` is known to be non-empty on this path.
        f"• Last Alert: {violations[-1]['timestamp']}",
        "",
        "🔄 System Status: ACTIVELY MONITORING",
        "⚡ Response Time: <5 seconds",
        "🎯 Detection Accuracy: >90% confidence",
    ]))

    return "".join(parts)
|
|
|
|
|
def generate_heatmap(violations, generate=True):
    """Render an hour-by-violation-type heatmap and return the PNG path.

    Args:
        violations: List of violation dicts providing ``timestamp`` (parsed
            with ``pd.to_datetime``) and ``violation_type``.
        generate: When false no image is produced.

    Returns:
        Path of the generated PNG file, or ``None`` when generation is
        disabled, there are no violations, or plotting failed.
    """
    if not generate or not violations:
        return None

    try:
        df = pd.DataFrame(violations)
        df['hour'] = pd.to_datetime(df['timestamp']).dt.hour
        heatmap_data = df.pivot_table(index='hour', columns='violation_type', aggfunc='size', fill_value=0)

        fig = plt.figure(figsize=(12, 8))
        try:
            sns.heatmap(heatmap_data, cmap='YlOrRd', annot=True, fmt='d')
            plt.title("Unique Violations by Hour")
            plt.xlabel("Violation Type")
            plt.ylabel("Hour of Day")

            # Only the unique name is needed; close our handle before
            # matplotlib opens the same path for writing.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
                image_path = temp_file.name
            fig.savefig(image_path, bbox_inches='tight')
        finally:
            # Close the figure on failure too, so failed renders do not
            # accumulate open figures.
            plt.close(fig)

        return image_path
    except Exception as e:
        logger.error(f"Heatmap error: {e}")
        return None
|
|
|
|
|
def cancel_processing() -> str:
    """Request cancellation of the running analysis.

    Clears the module-level ``processing_active`` flag, which the video and
    RTSP loops check on each iteration, and closes any OpenCV windows.

    Returns:
        The status text ``"Processing cancelled"``.
    """
    global processing_active

    processing_active = False
    cv2.destroyAllWindows()

    message = "Processing cancelled"
    return message
|
|
|
|
|
|
|
|
enhanced_custom_css = """ |
|
|
/* Reset Default Styles */ |
|
|
* { |
|
|
margin: 0; |
|
|
padding: 0; |
|
|
box-sizing: border-box; |
|
|
} |
|
|
|
|
|
/* Global Theme and Layout */ |
|
|
.gradio-container { |
|
|
font-family: 'Poppins', 'Inter', 'Segoe UI', 'Roboto', sans-serif !important; |
|
|
background: linear-gradient(45deg, #0a0a1f, #1a0033, #2a0044, #0a0a1f) !important; |
|
|
background-size: 400% !important; |
|
|
animation: gradientShift 12s ease infinite !important; |
|
|
min-height: 100vh !important; |
|
|
display: flex !important; |
|
|
flex-direction: column !important; |
|
|
justify-content: center !important; |
|
|
align-items: center !important; |
|
|
position: relative !important; |
|
|
overflow: hidden !important; |
|
|
} |
|
|
|
|
|
@keyframes gradientShift { |
|
|
0% { background-position: 0% 50%; } |
|
|
50% { background-position: 100% 50%; } |
|
|
100% { background-position: 0% 50%; } |
|
|
} |
|
|
|
|
|
/* Particle Canvas for Star-Like Animation */ |
|
|
#particle-canvas { |
|
|
position: absolute !important; |
|
|
top: 0 !important; |
|
|
left: 0 !important; |
|
|
width: 100% !important; |
|
|
height: 100% !important; |
|
|
z-index: 1 !important; |
|
|
pointer-events: none !important; |
|
|
} |
|
|
|
|
|
/* Main Header Styling */ |
|
|
.main-header { |
|
|
background: rgba(255, 255, 255, 0.12) !important; |
|
|
backdrop-filter: blur(15px) !important; |
|
|
border: 1px solid rgba(255, 0, 204, 0.3) !important; |
|
|
color: #f0f0f5 !important; |
|
|
text-align: center !important; |
|
|
padding: 1rem !important; |
|
|
margin-bottom: 1rem !important; |
|
|
border-radius: 15px !important; |
|
|
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important; |
|
|
z-index: 2 !important; |
|
|
animation: fadeIn 1s ease-out !important; |
|
|
} |
|
|
|
|
|
.header-title { |
|
|
font-size: 2.5rem !important; |
|
|
font-weight: 700 !important; |
|
|
text-shadow: 0 0 12px rgba(255, 0, 204, 0.8) !important; |
|
|
animation: glow 2s ease-in-out infinite alternate !important; |
|
|
margin-bottom: 0.5rem !important; |
|
|
} |
|
|
|
|
|
.header-subtitle { |
|
|
font-size: 1rem !important; |
|
|
font-weight: 400 !important; |
|
|
color: #ccc !important; |
|
|
text-shadow: 0 0 5px rgba(255, 0, 204, 0.5) !important; |
|
|
} |
|
|
|
|
|
@keyframes glow { |
|
|
from { text-shadow: 0 0 5px rgba(255, 0, 204, 0.5), 0 0 10px rgba(255, 0, 204, 0.3); } |
|
|
to { text-shadow: 0 0 12px rgba(255, 0, 204, 0.9), 0 0 20px rgba(255, 0, 204, 0.6); } |
|
|
} |
|
|
|
|
|
@keyframes fadeIn { |
|
|
from { opacity: 0; transform: translateY(20px); } |
|
|
to { opacity: 1; transform: translateY(0); } |
|
|
} |
|
|
|
|
|
/* Professional Card System (Glassmorphism) */ |
|
|
.professional-card { |
|
|
background: rgba(255, 255, 255, 0.12) !important; |
|
|
backdrop-filter: blur(15px) !important; |
|
|
border: 1px solid rgba(255, 0, 204, 0.3) !important; |
|
|
border-radius: 15px !important; |
|
|
padding: 1rem !important; |
|
|
margin: 0.5rem 0 !important; |
|
|
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important; |
|
|
z-index: 2 !important; |
|
|
transition: all 0.3s ease !important; |
|
|
} |
|
|
|
|
|
.professional-card:hover { |
|
|
transform: translateY(-2px) !important; |
|
|
box-shadow: 0 0 30px rgba(255, 0, 204, 0.35) !important; |
|
|
} |
|
|
|
|
|
/* Section Headers */ |
|
|
.section-header { |
|
|
background: rgba(255, 255, 255, 0.12) !important; |
|
|
backdrop-filter: blur(15px) !important; |
|
|
border: 1px solid rgba(255, 0, 204, 0.3) !important; |
|
|
color: #f0f0f5 !important; |
|
|
padding: 0.8rem 1rem !important; |
|
|
border-radius: 10px !important; |
|
|
text-align: center !important; |
|
|
font-weight: 700 !important; |
|
|
font-size: 1.2rem !important; |
|
|
margin-bottom: 1rem !important; |
|
|
text-shadow: 0 0 12px rgba(255, 0, 204, 0.8) !important; |
|
|
animation: glow 2s ease-in-out infinite alternate !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
/* Button Styling */ |
|
|
.btn-primary, .gr-button { |
|
|
background: linear-gradient(90deg, #00C4B4, #ff00cc) !important; |
|
|
border: none !important; |
|
|
border-radius: 8px !important; |
|
|
padding: 10px 20px !important; |
|
|
color: white !important; |
|
|
font-weight: 600 !important; |
|
|
font-size: 1rem !important; |
|
|
transition: all 0.3s ease !important; |
|
|
box-shadow: 0 0 15px rgba(255, 0, 204, 0.4) !important; |
|
|
position: relative !important; |
|
|
overflow: hidden !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
.btn-primary:hover, .gr-button:hover { |
|
|
background: linear-gradient(90deg, #00C6B6, #ff33cc) !important; |
|
|
box-shadow: 0 0 20px rgba(255, 0, 204, 0.6) !important; |
|
|
transform: translateY(-2px) !important; |
|
|
} |
|
|
|
|
|
.btn-primary::before, .gr-button::before { |
|
|
content: '' !important; |
|
|
position: absolute !important; |
|
|
top: 50% !important; |
|
|
left: 50% !important; |
|
|
width: 300% !important; |
|
|
height: 300% !important; |
|
|
background: rgba(255, 255, 255, 0.1) !important; |
|
|
transition: all 0.5s ease !important; |
|
|
transform: translate(-50%, -50%) rotate(45deg) !important; |
|
|
opacity: 0 !important; |
|
|
} |
|
|
|
|
|
.btn-primary:hover::before, .gr-button:hover::before { |
|
|
opacity: 1 !important; |
|
|
width: 0 !important; |
|
|
height: 0 !important; |
|
|
} |
|
|
|
|
|
.btn-secondary { |
|
|
background: linear-gradient(90deg, #11998e, #38ef7d) !important; |
|
|
border: none !important; |
|
|
border-radius: 8px !important; |
|
|
padding: 10px 20px !important; |
|
|
color: white !important; |
|
|
font-weight: 600 !important; |
|
|
font-size: 1rem !important; |
|
|
transition: all 0.3s ease !important; |
|
|
box-shadow: 0 0 15px rgba(17, 153, 142, 0.4) !important; |
|
|
} |
|
|
|
|
|
.btn-secondary:hover { |
|
|
background: linear-gradient(90deg, #0e867c, #33d670) !important; |
|
|
box-shadow: 0 0 20px rgba(17, 153, 142, 0.6) !important; |
|
|
transform: translateY(-2px) !important; |
|
|
} |
|
|
|
|
|
/* Status Display */ |
|
|
.status-display { |
|
|
background: rgba(255, 255, 255, 0.05) !important; |
|
|
border: 1px solid rgba(255, 255, 255, 0.15) !important; |
|
|
border-radius: 8px !important; |
|
|
padding: 1rem !important; |
|
|
color: #f0f0f5 !important; |
|
|
font-size: 0.95rem !important; |
|
|
font-family: 'Fira Code', 'Consolas', monospace !important; |
|
|
white-space: pre-wrap !important; |
|
|
max-height: 300px !important; |
|
|
overflow-y: auto !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
.status-display:focus { |
|
|
border-color: #ff00cc !important; |
|
|
box-shadow: 0 0 8px rgba(255, 0, 204, 0.6) !important; |
|
|
} |
|
|
|
|
|
.status-display::-webkit-scrollbar { |
|
|
width: 5px !important; |
|
|
} |
|
|
|
|
|
.status-display::-webkit-scrollbar-track { |
|
|
background: rgba(255, 255, 255, 0.05) !important; |
|
|
border-radius: 3px !important; |
|
|
} |
|
|
|
|
|
.status-display::-webkit-scrollbar-thumb { |
|
|
background: #ff00cc !important; |
|
|
border-radius: 3px !important; |
|
|
} |
|
|
|
|
|
/* Alert Panel */ |
|
|
.alert-panel { |
|
|
background: rgba(255, 51, 51, 0.15) !important; |
|
|
backdrop-filter: blur(10px) !important; |
|
|
border: 1px solid rgba(255, 51, 51, 0.5) !important; |
|
|
color: #ff3333 !important; |
|
|
border-radius: 15px !important; |
|
|
padding: 1rem !important; |
|
|
margin: 0.5rem 0 !important; |
|
|
box-shadow: 0 0 20px rgba(255, 51, 51, 0.3) !important; |
|
|
animation: alertPulse 2s infinite !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
@keyframes alertPulse { |
|
|
0%, 100% { transform: scale(1); opacity: 1; } |
|
|
50% { transform: scale(1.01); opacity: 0.95; } |
|
|
} |
|
|
|
|
|
/* Success Panel */ |
|
|
.success-panel { |
|
|
background: rgba(0, 184, 148, 0.15) !important; |
|
|
backdrop-filter: blur(10px) !important; |
|
|
border: 1px solid rgba(0, 184, 148, 0.5) !important; |
|
|
color: #00b894 !important; |
|
|
border-radius: 15px !important; |
|
|
padding: 1rem !important; |
|
|
margin: 0.5rem 0 !important; |
|
|
box-shadow: 0 0 20px rgba(0, 184, 148, 0.3) !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
/* Image Components */ |
|
|
.image-component { |
|
|
border-radius: 15px !important; |
|
|
overflow: hidden !important; |
|
|
box-shadow: 0 0 20px rgba(255, 0, 204, 0.2) !important; |
|
|
transition: all 0.3s ease !important; |
|
|
border: 1px solid rgba(255, 0, 204, 0.2) !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
.image-component:hover { |
|
|
transform: scale(1.01) !important; |
|
|
box-shadow: 0 0 25px rgba(255, 0, 204, 0.3) !important; |
|
|
} |
|
|
|
|
|
/* Gallery Styling */ |
|
|
.gallery-component { |
|
|
border-radius: 15px !important; |
|
|
overflow: hidden !important; |
|
|
box-shadow: 0 0 20px rgba(255, 0, 204, 0.2) !important; |
|
|
background: rgba(255, 255, 255, 0.05) !important; |
|
|
padding: 0.5rem !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
/* File Download Component */ |
|
|
.file-component { |
|
|
background: rgba(255, 255, 255, 0.05) !important; |
|
|
border: 1px dashed rgba(255, 0, 204, 0.3) !important; |
|
|
border-radius: 10px !important; |
|
|
padding: 1rem !important; |
|
|
text-align: center !important; |
|
|
transition: all 0.3s ease !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
.file-component:hover { |
|
|
background: rgba(255, 255, 255, 0.1) !important; |
|
|
transform: translateY(-2px) !important; |
|
|
} |
|
|
|
|
|
/* Analytics Dashboard */ |
|
|
.analytics-panel { |
|
|
background: rgba(255, 255, 255, 0.12) !important; |
|
|
backdrop-filter: blur(15px) !important; |
|
|
border-radius: 15px !important; |
|
|
padding: 1rem !important; |
|
|
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
/* Tab Styling */ |
|
|
.gradio-tabs { |
|
|
border: none !important; |
|
|
background: transparent !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
.gradio-tab-item { |
|
|
background: rgba(255, 255, 255, 0.05) !important; |
|
|
border: 1px solid rgba(255, 0, 204, 0.2) !important; |
|
|
border-radius: 8px !important; |
|
|
color: #ccc !important; |
|
|
padding: 0.5rem 1rem !important; |
|
|
margin: 0 0.2rem !important; |
|
|
transition: all 0.3s ease !important; |
|
|
} |
|
|
|
|
|
.gradio-tab-item.selected { |
|
|
background: rgba(255, 255, 255, 0.15) !important; |
|
|
color: #ff00cc !important; |
|
|
font-weight: 600 !important; |
|
|
border-color: #ff00cc !important; |
|
|
box-shadow: 0 0 10px rgba(255, 0, 204, 0.5) !important; |
|
|
} |
|
|
|
|
|
.gradio-tab-item:hover { |
|
|
background: rgba(255, 255, 255, 0.1) !important; |
|
|
color: #ff00cc !important; |
|
|
} |
|
|
|
|
|
.gradio-tab-content { |
|
|
background: rgba(255, 255, 255, 0.12) !important; |
|
|
backdrop-filter: blur(15px) !important; |
|
|
border-radius: 15px !important; |
|
|
padding: 1rem !important; |
|
|
box-shadow: 0 0 25px rgba(255, 0, 204, 0.25) !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
/* Footer Styling */ |
|
|
.footer-info { |
|
|
background: rgba(255, 255, 255, 0.12) !important; |
|
|
backdrop-filter: blur(15px) !important; |
|
|
border: 1px solid rgba(255, 0, 204, 0.3) !important; |
|
|
border-radius: 15px !important; |
|
|
padding: 1rem !important; |
|
|
margin-top: 1rem !important; |
|
|
text-align: center !important; |
|
|
color: #f0f0f5 !important; |
|
|
z-index: 2 !important; |
|
|
} |
|
|
|
|
|
.feature-grid { |
|
|
display: grid !important; |
|
|
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)) !important; |
|
|
gap: 0.5rem !important; |
|
|
margin-top: 1rem !important; |
|
|
} |
|
|
|
|
|
.feature-item { |
|
|
background: rgba(255, 255, 255, 0.05) !important; |
|
|
padding: 0.5rem !important; |
|
|
border-radius: 8px !important; |
|
|
text-align: center !important; |
|
|
transition: all 0.3s ease !important; |
|
|
color: #ccc !important; |
|
|
} |
|
|
|
|
|
.feature-item:hover { |
|
|
background: rgba(255, 255, 255, 0.1) !important; |
|
|
transform: translateY(-2px) !important; |
|
|
color: #ff00cc !important; |
|
|
} |
|
|
|
|
|
/* Responsive Design */ |
|
|
@media (max-width: 768px) { |
|
|
.gradio-container { |
|
|
padding: 1rem !important; |
|
|
} |
|
|
.main-header { |
|
|
padding: 0.8rem !important; |
|
|
} |
|
|
.header-title { |
|
|
font-size: 2rem !important; |
|
|
} |
|
|
.professional-card { |
|
|
padding: 0.8rem !important; |
|
|
margin: 0.3rem 0 !important; |
|
|
} |
|
|
.section-header { |
|
|
font-size: 1rem !important; |
|
|
padding: 0.6rem !important; |
|
|
} |
|
|
.btn-primary, .btn-secondary, .gr-button { |
|
|
padding: 8px 16px !important; |
|
|
font-size: 0.9rem !important; |
|
|
} |
|
|
} |
|
|
""" |
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio user interface.
#
# Builds the `demo` Blocks app: a decorative canvas background, a page header,
# an upload-and-analyse section for images/videos, an RTSP live-monitoring
# section, an analytics section, a footer, and finally the button -> handler
# wiring.
#
# NOTE(review): emoji and symbols in labels were restored from mis-decoded
# UTF-8. Where the original code point could not be recovered (page title,
# "Analyze Media", "Analysis Results", "Analysis Status", "Awaiting Media
# Analysis", "Live Violation Analytics") an icon matching the label was
# chosen - confirm against the intended design.
# NOTE(review): container nesting below was reconstructed from statement
# order - confirm the rendered layout.
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="Dynamic Safety Violation Detection using CCTV + AI",
    css=enhanced_custom_css,
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="emerald",
        neutral_hue="slate",
        radius_size="lg",
        spacing_size="sm",
        font=[
            gr.themes.GoogleFont("Poppins"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ],
    ).set(
        # Theme fills/borders are switched off here; presumably the custom
        # CSS passed above is meant to supply the visible styling instead.
        body_background_fill="none",
        block_background_fill="none",
        block_border_width="0px",
        block_shadow="none",
        block_radius="15px",
        button_primary_background_fill="none",
        button_primary_background_fill_hover="none",
        button_secondary_background_fill="none",
    ),
) as demo:

    # --- Animated star-field background ------------------------------------
    # A full-window <canvas> with 200 drifting, twinkling dots.
    # NOTE(review): browsers do not run <script> elements that are inserted
    # as markup after page load, which is presumably how gr.HTML content is
    # mounted - confirm the animation actually starts; if it does not, move
    # the script to the `js=` / `head=` argument of gr.Blocks.
    gr.HTML("""
    <canvas id="particle-canvas"></canvas>
    <script>
        const canvas = document.getElementById('particle-canvas');
        const ctx = canvas.getContext('2d');

        canvas.width = window.innerWidth;
        canvas.height = window.innerHeight;

        window.addEventListener('resize', () => {
            canvas.width = window.innerWidth;
            canvas.height = window.innerHeight;
        });

        const stars = [];
        const starCount = 200;

        class Star {
            constructor() {
                this.reset();
            }

            reset() {
                this.x = Math.random() * canvas.width;
                this.y = Math.random() * canvas.height;
                this.size = Math.random() * 2 + 0.5;
                this.speedX = Math.random() * 5 + 2;
                this.speedY = (Math.random() - 0.5) * 0.5;
                this.opacity = Math.random() * 0.5 + 0.5;
                this.twinklePhase = Math.random() * Math.PI * 2;
                this.twinkleSpeed = Math.random() * 0.05 + 0.02;
            }

            update() {
                this.x -= this.speedX;
                this.y += this.speedY;
                this.twinklePhase += this.twinkleSpeed;
                this.currentOpacity = this.opacity * (0.5 + 0.5 * Math.sin(this.twinklePhase));
                if (this.x < 0 || this.y < 0 || this.y > canvas.height) {
                    this.reset();
                }
            }

            draw() {
                ctx.fillStyle = `rgba(255, 0, 204, ${this.currentOpacity})`;
                ctx.beginPath();
                ctx.arc(this.x, this.y, this.size, 0, Math.PI * 2);
                ctx.fill();

                ctx.fillStyle = `rgba(255, 0, 204, ${this.currentOpacity * 0.3})`;
                ctx.beginPath();
                ctx.arc(this.x, this.y, this.size * 2, 0, Math.PI * 2);
                ctx.fill();
            }
        }

        function initStars() {
            for (let i = 0; i < starCount; i++) {
                const star = new Star();
                star.x = Math.random() * canvas.width;
                stars.push(star);
            }
        }

        function animateStars() {
            ctx.clearRect(0, 0, canvas.width, canvas.height);
            stars.forEach(star => {
                star.update();
                star.draw();
            });
            requestAnimationFrame(animateStars);
        }

        initStars();
        animateStars();
    </script>
    """)

    # --- Page header --------------------------------------------------------
    gr.HTML("""
    <div class="main-header">
        <h1 class="header-title">🔍 Dynamic Safety Violation Detection using CCTV + AI</h1>
        <p class="header-subtitle">Enhanced Multi-Person Tracking with Image &amp; Video Analysis - Each violation type detected only once per person</p>
    </div>
    """)

    # --- Media upload and analysis controls ---------------------------------
    gr.HTML('<div class="section-header">📷 Smart Media Analysis (Images &amp; Videos)</div>')
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card"]):
                media_input = gr.File(
                    label="📤 Upload Image or Video for Safety Analysis",
                    file_types=["image", "video"],
                    elem_classes=["image-component"],
                    height=200,
                )
                frame_skip_input = gr.Slider(
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=5,
                    label="Frame Skip (Higher = Faster Processing, Videos Only)",
                )
                with gr.Row():
                    media_button = gr.Button(
                        "🔍 Analyze Media",
                        variant="primary",
                        elem_classes=["btn-primary"],
                        size="lg",
                    )

    # --- Media analysis results ---------------------------------------------
    gr.HTML('<div class="section-header">📊 Analysis Results &amp; Violation Details</div>')
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card"]):
                media_output = gr.Gallery(
                    label="🖼️ Processed Media with Detection Results",
                    elem_classes=["gallery-component"],
                    height=260,
                )
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card"]):
                media_status = gr.Textbox(
                    label="📊 Analysis Status",
                    elem_classes=["status-display"],
                    lines=7,
                    max_lines=9,
                    value=(
                        "🔄 Awaiting Media Analysis\n\n"
                        "Upload an image or video and click 'Analyze Media' "
                        "to begin safety violation detection.\n\n"
                        "• Images: Instant analysis\n"
                        "• Videos: Frame-by-frame processing"
                    ),
                    interactive=False,
                )
                pdf_output = gr.File(
                    label="📥 Download Professional Report",
                    elem_classes=["file-component"],
                )

    # --- Violation log for uploaded media -----------------------------------
    gr.HTML('<div class="section-header">🚨 Real-time Violation Monitoring</div>')
    with gr.Group(elem_classes=["professional-card", "alert-panel"]):
        violation_log = gr.Textbox(
            label="🚨 Real-time Violation Details",
            elem_classes=["status-display"],
            lines=10,
            max_lines=12,
            # Initial text is computed once, while the UI is being built.
            value=format_violations_as_text(recent_violations),
            interactive=False,
        )

    # --- RTSP live stream controls and frames -------------------------------
    gr.HTML('<div class="section-header">📹 Live Stream Monitoring</div>')
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group(elem_classes=["professional-card"]):
                rtsp_url_input = gr.Textbox(
                    label="📡 RTSP Stream URL",
                    placeholder="rtsp://example.com/stream",
                    value=RTSP_URL_DEFAULT,
                    interactive=True,
                )
                with gr.Row():
                    rtsp_button = gr.Button(
                        "📡 Start Live Monitoring",
                        variant="primary",
                        elem_classes=["btn-primary"],
                        size="lg",
                    )
                    rtsp_cancel_btn = gr.Button(
                        "⏹️ Stop Monitoring",
                        variant="secondary",
                        elem_classes=["btn-secondary"],
                        size="lg",
                    )
                rtsp_status = gr.Textbox(
                    label="📺 Live Stream Processing Status",
                    elem_classes=["status-display"],
                    lines=6,
                    max_lines=8,
                    value=(
                        "📺 RTSP Stream Processor Ready\n\n"
                        "Enter an RTSP URL and click 'Start Live Monitoring' "
                        "to begin real-time monitoring."
                    ),
                    interactive=False,
                )
        with gr.Column(scale=3):
            with gr.Group(elem_classes=["professional-card"]):
                rtsp_output = gr.Gallery(
                    label="🎬 Live Stream Frames & Detection Results",
                    elem_classes=["gallery-component"],
                    height=360,
                    columns=3,
                    rows=2,
                    object_fit="cover",
                )

    # --- Live stream analytics ----------------------------------------------
    gr.HTML('<div class="section-header">📈 Live Violation Analytics</div>')
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card", "alert-panel"]):
                rtsp_violation_log = gr.Textbox(
                    label="🚨 Live Violation Log",
                    elem_classes=["status-display"],
                    lines=8,
                    max_lines=10,
                    interactive=False,
                )
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["professional-card", "analytics-panel"]):
                heatmap_output = gr.Image(
                    label="🔥 Violation Heatmap - Temporal Analysis",
                    elem_classes=["image-component"],
                    height=320,
                )
                rtsp_pdf_output = gr.File(
                    label="📥 Download RTSP Report",
                    elem_classes=["file-component"],
                )

    # --- Footer -------------------------------------------------------------
    # Text content uses HTML entities for &, < and > so the markup is valid.
    gr.HTML("""
    <div class="footer-info">
        <h3>🛡️ Dynamic Safety Violation Detection using CCTV + AI</h3>
        <div class="feature-grid">
            <div class="feature-item">
                <strong>🎯 Real-time Detection</strong><br>
                Advanced YOLOv8 AI with &gt;90% accuracy
            </div>
            <div class="feature-item">
                <strong>⚡ Ultra-fast Response</strong><br>
                Alert generation in &lt;5 seconds
            </div>
            <div class="feature-item">
                <strong>📸 Image &amp; Video Support</strong><br>
                Process both static images and video files
            </div>
            <div class="feature-item">
                <strong>📱 Responsive Design</strong><br>
                Optimized for desktop, tablet &amp; mobile
            </div>
        </div>
        <div style="margin-top: 0.8rem; padding-top: 0.8rem; border-top: 0.5px solid rgba(255,255,255,0.2);">
            <p style="margin: 0; font-size: 0.8rem; opacity: 0.7;">
                Dynamic Safety Violation Detection using CCTV + AI © 2025
            </p>
        </div>
    </div>
    """)

    # --- Event wiring -------------------------------------------------------
    # Uploaded media: file + frame-skip value in; gallery, status text, PDF
    # file and violation log out.
    media_button.click(
        fn=process_media,
        inputs=[media_input, frame_skip_input],
        outputs=[media_output, media_status, pdf_output, violation_log],
    )

    # Live stream: RTSP URL in; status, frame gallery, violation log, heatmap
    # image and PDF file out.
    rtsp_button.click(
        fn=process_rtsp_stream,
        inputs=[rtsp_url_input],
        outputs=[rtsp_status, rtsp_output, rtsp_violation_log, heatmap_output, rtsp_pdf_output],
    )
    # The stop button only calls cancel_processing and shows its return value
    # in the stream status box.
    rtsp_cancel_btn.click(cancel_processing, outputs=[rtsp_status])
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: enable request queuing on the app, then start the
    # web server. The launch settings are collected in one mapping so they
    # can be read (and reviewed) at a glance.
    launch_options = {
        "server_name": "0.0.0.0",  # bind on every network interface
        "server_port": 7860,
        "share": True,  # also request a public Gradio share link
        "show_error": True,
        "quiet": False,
        "favicon_path": None,
        "auth": None,  # no login is configured
        "inbrowser": True,  # try to open a browser tab on start-up
    }
    queued_app = demo.queue()
    queued_app.launch(**launch_options)