PrashanthB461 committed on
Commit
b0446d7
·
verified ·
1 Parent(s): 6c75df2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +415 -708
app.py CHANGED
@@ -1,774 +1,481 @@
1
  import os
2
  import sys
3
- import subprocess
4
  import logging
5
- import warnings
6
  import cv2
7
- import gradio as gr
8
- import torch
9
  import numpy as np
 
 
10
  from ultralytics import YOLO
11
- import time
12
- from simple_salesforce import Salesforce
13
- from reportlab.lib.pagesizes import letter
14
- from reportlab.pdfgen import canvas
15
- from reportlab.lib.units import inch
16
- from io import BytesIO
17
- import base64
18
- from retrying import retry
19
- import uuid
20
- from multiprocessing import Pool, cpu_count
21
- from functools import partial
22
-
23
- # ========================== # Configuration and Setup # ==========================
24
- os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
25
- os.makedirs('/tmp/Ultralytics', exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
27
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
28
  logger = logging.getLogger(__name__)
29
 
30
- # ========================== # ByteTrack Implementation # ==========================
31
- class BYTETracker:
32
- def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
33
- self.track_thresh = track_thresh
34
- self.track_buffer = track_buffer
35
- self.match_thresh = match_thresh
36
- self.frame_rate = frame_rate
37
- self.next_id = 1
38
- self.tracks = {} # Store active tracks
39
- self.worker_history = {} # Track worker positions over time
40
- self.last_positions = {} # Last known positions of workers
41
-
42
- def update(self, dets, scores, cls):
43
- tracks = []
44
- current_time = time.time()
45
-
46
- # Update existing tracks with new detections
47
- for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
48
- if score < self.track_thresh:
49
- continue
50
-
51
- x, y, w, h = det
52
- matched = False
53
- best_iou = 0
54
- best_track_id = None
55
-
56
- # Try to match with existing tracks
57
- for track_id, track_info in self.tracks.items():
58
- if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
59
- continue
60
-
61
- tx, ty, tw, th = track_info['bbox']
62
- iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
63
-
64
- if iou > self.match_thresh and iou > best_iou:
65
- best_iou = iou
66
- best_track_id = track_id
67
- matched = True
68
-
69
- if matched:
70
- # Update existing track
71
- self.tracks[best_track_id].update({
72
- 'bbox': [x, y, w, h],
73
- 'score': score,
74
- 'cls': cl,
75
- 'last_seen': current_time
76
- })
77
-
78
- # Update position history
79
- if best_track_id not in self.worker_history:
80
- self.worker_history[best_track_id] = []
81
- self.worker_history[best_track_id].append([x, y])
82
- self.last_positions[best_track_id] = [x, y]
83
-
84
- tracks.append({
85
- 'id': best_track_id,
86
- 'bbox': [x, y, w, h],
87
- 'score': score,
88
- 'cls': cl
89
- })
90
- else:
91
- # Create new track
92
- # Check if this detection might be the same worker from a different angle
93
- same_worker = False
94
- for worker_id, last_pos in self.last_positions.items():
95
- if self._is_same_worker([x, y], last_pos):
96
- self.tracks[worker_id] = {
97
- 'bbox': [x, y, w, h],
98
- 'score': score,
99
- 'cls': cl,
100
- 'last_seen': current_time
101
- }
102
- tracks.append({
103
- 'id': worker_id,
104
- 'bbox': [x, y, w, h],
105
- 'score': score,
106
- 'cls': cl
107
- })
108
- same_worker = True
109
- break
110
-
111
- if not same_worker:
112
- self.tracks[self.next_id] = {
113
- 'bbox': [x, y, w, h],
114
- 'score': score,
115
- 'cls': cl,
116
- 'last_seen': current_time
117
- }
118
- self.worker_history[self.next_id] = [[x, y]]
119
- self.last_positions[self.next_id] = [x, y]
120
- tracks.append({
121
- 'id': self.next_id,
122
- 'bbox': [x, y, w, h],
123
- 'score': score,
124
- 'cls': cl
125
- })
126
- self.next_id += 1
127
-
128
- # Clean up old tracks
129
- current_time = time.time()
130
- stale_ids = []
131
- for track_id, track_info in self.tracks.items():
132
- if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
133
- stale_ids.append(track_id)
134
-
135
- for track_id in stale_ids:
136
- del self.tracks[track_id]
137
- if track_id in self.worker_history:
138
- del self.worker_history[track_id]
139
- if track_id in self.last_positions:
140
- del self.last_positions[track_id]
141
 
142
- return tracks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- def _calculate_iou(self, box1, box2):
145
- """Calculate IOU between two boxes"""
146
  x1, y1, w1, h1 = box1
147
  x2, y2, w2, h2 = box2
148
 
149
- # Calculate intersection coordinates
150
- x_left = max(x1 - w1/2, x2 - w2/2)
151
- y_top = max(y1 - h1/2, y2 - h2/2)
152
- x_right = min(x1 + w1/2, x2 + w2/2)
153
- y_bottom = min(y1 + h1/2, y2 + h2/2)
154
 
155
  if x_right < x_left or y_bottom < y_top:
156
  return 0.0
157
 
158
  intersection_area = (x_right - x_left) * (y_bottom - y_top)
159
-
160
  box1_area = w1 * h1
161
  box2_area = w2 * h2
162
 
163
- iou = intersection_area / (box1_area + box2_area - intersection_area)
164
- return iou
165
 
166
- def _is_same_worker(self, pos1, pos2, threshold=100):
167
- """Check if two positions likely belong to the same worker"""
168
  x1, y1 = pos1
169
  x2, y2 = pos2
170
  distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
171
- return distance < threshold
172
-
173
- # ========================== # Optimized Configuration # ==========================
174
- CONFIG = {
175
- "MODEL_PATH": "yolov8_safety.pt",
176
- "FALLBACK_MODEL": "yolov8n.pt",
177
- "OUTPUT_DIR": "static/output",
178
- "VIOLATION_LABELS": {
179
- 0: "no_helmet",
180
- 1: "no_harness",
181
- 2: "unsafe_posture",
182
- 3: "unsafe_zone",
183
- 4: "improper_tool_use"
184
- },
185
- "CLASS_COLORS": {
186
- "no_helmet": (0, 0, 255), # Red
187
- "no_harness": (0, 165, 255), # Orange
188
- "unsafe_posture": (0, 255, 0), # Green
189
- "unsafe_zone": (255, 0, 0), # Blue
190
- "improper_tool_use": (255, 255, 0) # Cyan
191
- },
192
- "DISPLAY_NAMES": {
193
- "no_helmet": "No Helmet Violation",
194
- "no_harness": "No Harness Violation",
195
- "unsafe_posture": "Unsafe Posture",
196
- "unsafe_zone": "Unsafe Zone Entry",
197
- "improper_tool_use": "Improper Tool Use"
198
- },
199
- "SF_CREDENTIALS": {
200
- "username": "prashanth1ai@safety.com",
201
- "password": "<REDACTED>",
202
- "security_token": "<REDACTED>",
203
- "domain": "login"
204
- },
205
- "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
206
- "CONFIDENCE_THRESHOLDS": {
207
- "no_helmet": 0.5,
208
- "no_harness": 0.3,
209
- "unsafe_posture": 0.3,
210
- "unsafe_zone": 0.3,
211
- "improper_tool_use": 0.3
212
- },
213
- "MIN_VIOLATION_FRAMES": 1,
214
- "VIOLATION_COOLDOWN": 30.0, # Increased cooldown period
215
- "WORKER_TRACKING_DURATION": 5.0,
216
- "MAX_PROCESSING_TIME": 60,
217
- "FRAME_SKIP": 2, # Skip more frames for faster processing
218
- "BATCH_SIZE": 16,
219
- "PARALLEL_WORKERS": max(1, cpu_count() - 1),
220
- "TRACK_BUFFER": 30,
221
- "TRACK_THRESH": 0.3,
222
- "MATCH_THRESH": 0.7,
223
- "SNAPSHOT_QUALITY": 95, # Higher quality for better visibility
224
- "MAX_WORKER_DISTANCE": 100 # Maximum pixel distance to consider same worker
225
- }
226
-
227
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
228
- logger.info(f"Using device: {device}")
229
-
230
- def load_model():
231
- try:
232
- if os.path.isfile(CONFIG["MODEL_PATH"]):
233
- model_path = CONFIG["MODEL_PATH"]
234
- logger.info(f"Model loaded: {model_path}")
235
- else:
236
- model_path = CONFIG["FALLBACK_MODEL"]
237
- logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
238
- if not os.path.isfile(model_path):
239
- logger.info(f"Downloading fallback model: {model_path}")
240
- torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
241
-
242
- model = YOLO(model_path).to(device)
243
- logger.info(f"Model classes: {model.names}")
244
- return model
245
- except Exception as e:
246
- logger.error(f"Failed to load model: {e}")
247
- raise
248
-
249
- model = load_model()
250
-
251
- # ========================== # Helper Functions # ==========================
252
- def preprocess_frame(frame):
253
- """Apply basic preprocessing to enhance detection"""
254
- frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
255
- return frame
256
-
257
- def draw_detections(frame, detections):
258
- """Draw bounding boxes and labels on detection frame with improved visibility"""
259
- result_frame = frame.copy()
260
 
261
- for det in detections:
262
- label = det.get("violation", "Unknown")
263
- confidence = det.get("confidence", 0.0)
264
- x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
265
- worker_id = det.get("worker_id", "Unknown")
266
-
267
- x1 = int(x - w/2)
268
- y1 = int(y - h/2)
269
- x2 = int(x + w/2)
270
- y2 = int(y + h/2)
271
-
272
- color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
273
-
274
- # Draw thicker rectangle with border
275
- cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
276
-
277
- # Add black background behind text
278
- display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
279
- text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
280
- cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
281
- cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
282
-
283
- # Add confidence score
284
- conf_text = f"Conf: {confidence:.2f}"
285
- cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
286
 
287
- return result_frame
288
-
289
- def calculate_safety_score(violations):
290
- """Calculate safety score based on detected violations"""
291
- penalties = {
292
- "no_helmet": 25,
293
- "no_harness": 30,
294
- "unsafe_posture": 20,
295
- "unsafe_zone": 35,
296
- "improper_tool_use": 25
297
- }
298
-
299
- # Count unique violation types per worker
300
- worker_violations = {}
301
- for v in violations:
302
- worker_id = v.get("worker_id", "Unknown")
303
- violation_type = v.get("violation", "Unknown")
304
-
305
- if worker_id not in worker_violations:
306
- worker_violations[worker_id] = set()
307
- worker_violations[worker_id].add(violation_type)
308
-
309
- # Calculate total penalty
310
- total_penalty = 0
311
- for worker_violations_set in worker_violations.values():
312
- worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
313
- total_penalty += worker_penalty
314
-
315
- score = max(0, 100 - total_penalty)
316
- return score
317
-
318
- def generate_violation_pdf(violations, score):
319
- """Generate a PDF report for the detected violations"""
320
- try:
321
- pdf_filename = f"violations_{int(time.time())}.pdf"
322
- pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
323
- pdf_file = BytesIO()
324
- c = canvas.Canvas(pdf_file, pagesize=letter)
325
-
326
- # Title
327
- c.setFont("Helvetica-Bold", 16)
328
- c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
329
-
330
- # Basic Information
331
- c.setFont("Helvetica", 12)
332
- c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
333
- c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
334
-
335
- # Safety Score
336
- c.setFont("Helvetica-Bold", 14)
337
- c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
338
-
339
- # Violation Summary
340
- y_position = 8.2 * inch
341
- c.setFont("Helvetica-Bold", 12)
342
- c.drawString(1 * inch, y_position, "Summary:")
343
- y_position -= 0.3 * inch
344
-
345
- # Group violations by worker
346
- worker_violations = {}
347
- for v in violations:
348
- worker_id = v.get("worker_id", "Unknown")
349
- if worker_id not in worker_violations:
350
- worker_violations[worker_id] = []
351
- worker_violations[worker_id].append(v)
352
-
353
- c.setFont("Helvetica", 10)
354
- summary_data = {
355
- "Total Workers with Violations": len(worker_violations),
356
- "Total Violations Found": len(violations),
357
- "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
358
- }
359
-
360
- for key, value in summary_data.items():
361
- c.drawString(1 * inch, y_position, f"{key}: {value}")
362
- y_position -= 0.25 * inch
363
-
364
- # Detailed Violations by Worker
365
- y_position -= 0.5 * inch
366
- c.setFont("Helvetica-Bold", 12)
367
- c.drawString(1 * inch, y_position, "Violations by Worker:")
368
- y_position -= 0.3 * inch
369
-
370
- c.setFont("Helvetica", 10)
371
- for worker_id, worker_vios in worker_violations.items():
372
- c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
373
- y_position -= 0.2 * inch
374
 
375
- for v in worker_vios:
376
- display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
377
- time_str = f"{v.get('timestamp', 0.0):.2f}s"
378
- conf_str = f"{v.get('confidence', 0.0):.2f}"
 
 
 
 
379
 
380
- violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
381
- c.drawString(1.2 * inch, y_position, violation_text)
382
- y_position -= 0.2 * inch
 
383
 
384
- if y_position < 1 * inch:
385
- c.showPage()
386
- c.setFont("Helvetica", 10)
387
- y_position = 10 * inch
388
-
389
- c.save()
390
- pdf_file.seek(0)
391
-
392
- # Save PDF file
393
- with open(pdf_path, "wb") as f:
394
- f.write(pdf_file.getvalue())
 
 
 
395
 
396
- public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
397
- logger.info(f"PDF generated: {public_url}")
398
- return pdf_path, public_url, pdf_file
399
- except Exception as e:
400
- logger.error(f"Error generating PDF: {e}")
401
- return "", "", None
402
-
403
- @retry(stop_max_attempt_number=3, wait_fixed=2000)
404
- def connect_to_salesforce():
405
- """Connect to Salesforce with retry logic"""
406
- try:
407
- sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
408
- logger.info("Connected to Salesforce")
409
- sf.describe()
410
- return sf
411
- except Exception as e:
412
- logger.error(f"Salesforce connection failed: {e}")
413
- raise
414
-
415
- def upload_pdf_to_salesforce(sf, pdf_file, report_id):
416
- """Upload PDF report to Salesforce"""
417
- try:
418
- if not pdf_file:
419
- logger.error("No PDF file provided for upload")
420
- return ""
421
 
422
- encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
423
- content_version_data = {
424
- "Title": f"Safety_Violation_Report_{int(time.time())}",
425
- "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
426
- "VersionData": encoded_pdf,
427
- "FirstPublishLocationId": report_id
428
- }
429
- content_version = sf.ContentVersion.create(content_version_data)
430
- result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
431
-
432
- if not result['records']:
433
- logger.error("Failed to retrieve ContentVersion")
434
- return ""
435
 
436
- file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
437
- logger.info(f"PDF uploaded to Salesforce: {file_url}")
438
- return file_url
439
- except Exception as e:
440
- logger.error(f"Error uploading PDF to Salesforce: {e}")
441
- return ""
442
-
443
- def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
444
- """Push violation report to Salesforce"""
445
- try:
446
- sf = connect_to_salesforce()
447
-
448
- # Format violations for Salesforce
449
- violations_text = ""
450
- for v in violations:
451
- display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
452
- worker_id = v.get('worker_id', 'Unknown')
453
- timestamp = v.get('timestamp', 0.0)
454
- confidence = v.get('confidence', 0.0)
 
 
 
 
455
 
456
- violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
 
458
- if not violations_text:
459
- violations_text = "No violations detected."
 
 
 
460
 
461
- pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
462
 
463
- record_data = {
464
- "Compliance_Score__c": score,
465
- "Violations_Found__c": len(violations),
466
- "Violations_Details__c": violations_text,
467
- "Status__c": "Pending",
468
- "PDF_Report_URL__c": pdf_url
469
- }
470
-
471
- logger.info(f"Creating Salesforce record with data: {record_data}")
472
 
 
 
473
  try:
474
- record = sf.Safety_Video_Report__c.create(record_data)
475
- logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
 
 
 
 
 
476
  except Exception as e:
477
- logger.error(f"Failed to create Safety_Video_Report__c: {e}")
478
- record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
479
- logger.warning(f"Fell back to Account record: {record['id']}")
480
 
481
- record_id = record["id"]
482
-
483
- if pdf_file:
484
- uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
485
- if uploaded_url:
486
- try:
487
- sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
488
- logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
489
- except Exception as e:
490
- logger.error(f"Failed to update Safety_Video_Report__c: {e}")
491
- sf.Account.update(record_id, {"Description": uploaded_url})
492
- logger.info(f"Updated Account record {record_id} with PDF URL")
493
- pdf_url = uploaded_url
494
-
495
- return record_id, pdf_url
496
- except Exception as e:
497
- logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
498
- return None, ""
499
-
500
- def process_video(video_data):
501
- """Process video to detect safety violations"""
502
- try:
503
- os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
504
- logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
505
-
506
- video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
507
- with open(video_path, "wb") as f:
508
- f.write(video_data)
509
- logger.info(f"Video saved: {video_path}")
510
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
511
  cap = cv2.VideoCapture(video_path)
512
  if not cap.isOpened():
513
- os.remove(video_path)
514
- raise ValueError("Could not open video file")
515
-
516
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
517
- fps = cap.get(cv2.CAP_PROP_FPS) or 30
518
- duration = total_frames / fps
519
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
520
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
521
- logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
522
-
523
- tracker = BYTETracker(
524
- track_thresh=CONFIG["TRACK_THRESH"],
525
- track_buffer=CONFIG["TRACK_BUFFER"],
526
- match_thresh=CONFIG["MATCH_THRESH"],
527
- frame_rate=fps
528
- )
529
-
530
- # Track unique violations by worker ID
531
- unique_violations = {} # {worker_id: {violation_type: first_detection_time}}
532
  snapshots = []
533
- start_time = time.time()
534
- frame_skip = CONFIG["FRAME_SKIP"]
535
- processed_frames = 0
536
-
537
- while processed_frames < total_frames:
538
- batch_frames = []
539
- batch_indices = []
540
-
541
- for _ in range(CONFIG["BATCH_SIZE"]):
542
- frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
543
- if frame_idx >= total_frames:
544
- break
545
-
546
- ret, frame = cap.read()
547
- if not ret:
548
- break
549
-
550
- frame = preprocess_frame(frame)
551
 
552
- # Skip frames if needed
553
- for _ in range(frame_skip - 1):
554
- if not cap.grab():
555
- break
556
 
557
- batch_frames.append(frame)
558
- batch_indices.append(frame_idx)
559
- processed_frames += 1
560
-
561
- if not batch_frames:
562
- break
563
-
564
- # Process batch with YOLO model
565
- results = model(batch_frames, device=device, conf=0.1, verbose=False)
566
 
567
- for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
568
- current_time = frame_idx / fps
 
569
 
570
- # Update progress every second
571
- if time.time() - start_time > 1.0:
572
- progress = (processed_frames / total_frames) * 100
573
- yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames})", "", "", "", ""
574
- start_time = time.time()
575
-
576
- boxes = result.boxes
577
- track_inputs = []
578
 
579
- for box in boxes:
580
- cls = int(box.cls)
581
- conf = float(box.conf)
582
- label = CONFIG["VIOLATION_LABELS"].get(cls, None)
583
-
584
- if label is None:
585
- continue
586
-
587
- if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
588
- continue
589
-
590
- bbox = box.xywh.cpu().numpy()[0]
591
- track_inputs.append({
592
- "bbox": bbox,
593
- "conf": conf,
594
- "cls": cls
595
- })
596
-
597
- if not track_inputs:
598
- continue
599
-
600
- tracked_objects = tracker.update(
601
- np.array([t["bbox"] for t in track_inputs]),
602
- np.array([t["conf"] for t in track_inputs]),
603
- np.array([t["cls"] for t in track_inputs])
604
- )
605
-
606
- # Process tracked objects for violations
607
- for obj in tracked_objects:
608
- worker_id = obj['id']
609
- label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
610
- conf = obj['score']
611
- bbox = obj['bbox']
612
-
613
- if label is None:
614
- continue
615
-
616
- # Initialize worker if not seen before
617
- if worker_id not in unique_violations:
618
- unique_violations[worker_id] = {}
619
-
620
- # Check if this violation type has been recorded for this worker
621
- if label not in unique_violations[worker_id]:
622
- # This is a new violation type for this worker
623
- unique_violations[worker_id][label] = current_time
624
-
625
- # Create detection object
626
- detection = {
627
- "worker_id": worker_id,
628
- "violation": label,
629
- "confidence": round(conf, 2),
630
- "bounding_box": bbox,
631
- "timestamp": current_time
632
- }
633
-
634
- # Take snapshot for the new violation
635
- snapshot_frame = batch_frames[i].copy()
636
- snapshot_frame = draw_detections(snapshot_frame, [detection])
637
-
638
- # Add timestamp to snapshot
639
- cv2.putText(
640
- snapshot_frame,
641
- f"Time: {current_time:.2f}s",
642
- (10, 30),
643
- cv2.FONT_HERSHEY_SIMPLEX,
644
- 0.7,
645
- (255, 255, 255),
646
- 2
647
- )
648
-
649
- # Save snapshot with high quality
650
- snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
651
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
652
-
653
- cv2.imwrite(
654
- snapshot_path,
655
- snapshot_frame,
656
- [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
657
- )
658
-
659
- snapshots.append({
660
- "violation": label,
661
- "worker_id": worker_id,
662
- "timestamp": current_time,
663
- "snapshot_path": snapshot_path,
664
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
665
- })
666
-
667
- logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
668
-
669
- cap.release()
670
- if os.path.exists(video_path):
671
- os.remove(video_path)
672
 
673
- processing_time = time.time() - start_time
674
- logger.info(f"Processing complete in {processing_time:.2f}s")
675
-
676
- # Convert tracked violations to final violation list
677
- violations = []
678
- for worker_id, worker_violations in unique_violations.items():
679
- for label, detection_time in worker_violations.items():
680
- violation = {
681
- "worker_id": worker_id,
682
- "violation": label,
683
- "timestamp": detection_time
684
- }
685
- violations.append(violation)
686
-
687
- if not violations:
688
- logger.info("No violations detected after processing")
689
- yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
690
- return
691
-
692
- # Calculate safety score
693
- score = calculate_safety_score(violations)
694
 
695
- # Generate PDF report
696
- pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
 
 
 
 
 
697
 
698
- # Push report to Salesforce
699
- report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
700
-
701
- # Format violations table for display
702
- violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
703
- violation_table += "|-----------|-----------|----------|------------|\n"
704
-
705
- for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
706
- display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
707
- worker_id = v.get("worker_id", "Unknown")
708
- timestamp = v.get("timestamp", 0.0)
709
- confidence = v.get("confidence", 0.0)
710
-
711
- violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
712
 
713
- # Format snapshots for display
714
- snapshots_text = ""
715
- for s in snapshots:
716
- display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
717
- worker_id = s.get("worker_id", "Unknown")
718
- timestamp = s.get("timestamp", 0.0)
719
-
720
- snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
721
- snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
 
 
 
 
 
 
 
 
 
722
 
723
- if not snapshots_text:
724
- snapshots_text = "No snapshots captured."
 
 
 
 
 
 
 
 
 
 
725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726
  yield (
727
- violation_table,
728
  f"Safety Score: {score}%",
729
- snapshots_text,
730
- f"Salesforce Record ID: {report_id or 'N/A'}",
731
- final_pdf_url or "N/A"
732
  )
733
-
734
- except Exception as e:
735
- logger.error(f"Error processing video: {e}", exc_info=True)
736
- if 'video_path' in locals() and os.path.exists(video_path):
737
- os.remove(video_path)
738
- yield f"Error processing video: {e}", "", "", "", ""
739
-
740
- def gradio_interface(video_file):
741
- """Gradio interface for the video processing"""
742
- if not video_file:
743
- return "No file uploaded.", "", "No file uploaded.", "", ""
744
 
745
- try:
746
- with open(video_file, "rb") as f:
747
- video_data = f.read()
748
-
749
- for status, score, snapshots_text, record_id, details_url in process_video(video_data):
750
- yield status, score, snapshots_text, record_id, details_url
 
 
 
 
 
 
 
 
 
751
 
752
- except Exception as e:
753
- logger.error(f"Error in Gradio interface: {e}", exc_info=True)
754
- yield f"Error: {str(e)}", "", "Error in processing.", "", ""
755
-
756
- # ========================== # Gradio Interface # ==========================
757
- interface = gr.Interface(
758
- fn=gradio_interface,
759
- inputs=gr.Video(label="Upload Site Video"),
760
- outputs=[
761
- gr.Markdown(label="Detected Safety Violations"),
762
- gr.Textbox(label="Compliance Score"),
763
- gr.Markdown(label="Snapshots"),
764
- gr.Textbox(label="Salesforce Record ID"),
765
- gr.Textbox(label="Violation Details URL")
766
- ],
767
- title="Worksite Safety Violation Analyzer",
768
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
769
- allow_flagging="never"
770
- )
771
 
772
  if __name__ == "__main__":
773
- logger.info("Launching Enhanced Safety Analyzer App...")
 
774
  interface.launch()
 
1
  import os
2
  import sys
 
3
  import logging
4
+ import time
5
  import cv2
 
 
6
  import numpy as np
7
+ import torch
8
+ import gradio as gr
9
  from ultralytics import YOLO
10
+ from collections import defaultdict
11
+ from typing import List, Dict, Tuple, Optional
12
+
13
+ # ========================== # Configuration # ==========================
14
+ class Config:
15
+ # Model settings
16
+ MODEL_PATH = "yolov8_safety.pt"
17
+ FALLBACK_MODEL = "yolov8n.pt"
18
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+ # Violation settings
21
+ VIOLATION_LABELS = {
22
+ 0: "no_helmet",
23
+ 1: "no_harness",
24
+ 2: "unsafe_posture",
25
+ 3: "unsafe_zone",
26
+ 4: "improper_tool_use"
27
+ }
28
+
29
+ # Tracking settings
30
+ FACE_TRACKING = True # Enable face tracking for helmet violations
31
+ TRACK_BUFFER = 30 # Number of frames to keep track of a person
32
+ MIN_FACE_CONFIDENCE = 0.7 # Minimum confidence for face detection
33
+ MIN_VIOLATION_CONFIDENCE = 0.5 # Minimum confidence for violation detection
34
+
35
+ # Output settings
36
+ OUTPUT_DIR = "static/output"
37
+ SNAPSHOT_QUALITY = 90
38
+ FRAME_SKIP = 2 # Process every nth frame for efficiency
39
+
40
+ # Violation suppression
41
+ VIOLATION_COOLDOWN = 5.0 # Seconds before same violation can be reported again
42
+ POSITION_THRESHOLD = 50 # Pixel distance to consider same location
43
+
44
+ # Display colors
45
+ COLOR_MAP = {
46
+ "no_helmet": (0, 0, 255),
47
+ "no_harness": (0, 165, 255),
48
+ "unsafe_posture": (0, 255, 0),
49
+ "unsafe_zone": (255, 0, 0),
50
+ "improper_tool_use": (255, 255, 0)
51
+ }
52
+
53
+ # Penalty scores for safety calculation
54
+ PENALTIES = {
55
+ "no_helmet": 25,
56
+ "no_harness": 30,
57
+ "unsafe_posture": 20,
58
+ "unsafe_zone": 35,
59
+ "improper_tool_use": 25
60
+ }
61
 
62
+ # Initialize logging
63
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
64
  logger = logging.getLogger(__name__)
65
 
66
+ # ========================== # Face Recognition # ==========================
67
+ class FaceTracker:
68
+ def __init__(self):
69
+ try:
70
+ # Try to load face detection model
71
+ self.face_cascade = cv2.CascadeClassifier(
72
+ cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
73
+ self.enabled = True
74
+ except:
75
+ logger.warning("Face detection model not available - falling back to position tracking")
76
+ self.enabled = False
77
+
78
+ def detect_faces(self, frame: np.ndarray) -> List[Tuple[int, int, int, int]]:
79
+ """Detect faces in a frame and return bounding boxes"""
80
+ if not self.enabled:
81
+ return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
84
+ faces = self.face_cascade.detectMultiScale(
85
+ gray,
86
+ scaleFactor=1.1,
87
+ minNeighbors=5,
88
+ minSize=(30, 30))
89
+ return faces
90
+
91
# ========================== # Violation Tracker # ==========================
class ViolationTracker:
    """Deduplicates violation detections across frames.

    A "track" is a worker identified either by face (for no_helmet
    violations, when face tracking is enabled) or by spatial proximity of
    bounding-box centers. A given violation type is reported at most once
    per track, and globally rate-limited by Config.VIOLATION_COOLDOWN.
    """

    def __init__(self):
        # Active tracks: {track_id: {last_seen, violations, face_bbox, position, bbox}}
        self.tracked_workers = {}
        # Global per-type report times: {violation_type: [detection_times]}
        self.violation_history = defaultdict(list)
        # Monotonically increasing id for newly created tracks.
        self.next_id = 1
        self.face_tracker = FaceTracker() if Config.FACE_TRACKING else None

    def _calculate_iou(self, box1: Tuple, box2: Tuple) -> float:
        """Calculate Intersection over Union for two (x, y, w, h) boxes."""
        x1, y1, w1, h1 = box1
        x2, y2, w2, h2 = box2

        # Coordinates of the intersection rectangle.
        x_left = max(x1, x2)
        y_top = max(y1, y2)
        x_right = min(x1 + w1, x2 + w2)
        y_bottom = min(y1 + h1, y2 + h2)

        # No overlap at all.
        if x_right < x_left or y_bottom < y_top:
            return 0.0

        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        box1_area = w1 * h1
        box2_area = w2 * h2

        # Union = sum of areas minus the double-counted intersection.
        return intersection_area / (box1_area + box2_area - intersection_area)

    def _is_same_position(self, pos1: Tuple, pos2: Tuple) -> bool:
        """Check if two (x, y) centers are within Config.POSITION_THRESHOLD pixels."""
        x1, y1 = pos1
        x2, y2 = pos2
        distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
        return distance < Config.POSITION_THRESHOLD

    def _get_face_id(self, frame: np.ndarray, bbox: Tuple) -> Optional[int]:
        """Try to identify a person by matching their face to a stored track.

        Returns the matching track id, or None when face tracking is off,
        no face is found, or no stored face overlaps enough.
        """
        if not self.face_tracker or not self.face_tracker.enabled:
            return None

        x, y, w, h = bbox
        # Crop the person's bounding box; clamp the top-left to the frame.
        face_region = frame[max(0, y):y+h, max(0, x):x+w]

        faces = self.face_tracker.detect_faces(face_region)
        if len(faces) == 0:
            return None

        # Get the largest face in the region (by pixel area).
        largest_face = max(faces, key=lambda f: f[2]*f[3])
        fx, fy, fw, fh = largest_face

        # Check if we've seen this face before.
        # NOTE(review): face boxes are relative to each crop, so this IoU
        # compares positions within crops, not within the full frame —
        # confirm that is the intended matching behavior.
        for track_id, info in self.tracked_workers.items():
            if 'face_bbox' not in info:
                continue

            # Compare with stored face
            iou = self._calculate_iou((fx, fy, fw, fh), info['face_bbox'])
            if iou > 0.5:  # Threshold for face matching
                return track_id

        return None

    def update(self, frame: np.ndarray, detections: List[Dict], timestamp: float) -> List[Dict]:
        """Update tracker with new detections and return only *new* violations.

        Args:
            frame: current BGR frame (used for face identification).
            detections: dicts with 'bbox' (x, y, w, h), 'violation', 'confidence'.
            timestamp: video timestamp in seconds for the reported violations.

        Returns:
            List of violation dicts that were not suppressed by track or
            cooldown deduplication.
        """
        current_time = time.time()
        unique_violations = []

        # First pass - try to match each detection with an existing track.
        for det in detections:
            bbox = det['bbox']
            violation_type = det['violation']
            confidence = det['confidence']
            x, y, w, h = bbox
            center = (x + w/2, y + h/2)

            # Try face recognition first for helmet violations.
            track_id = None
            if violation_type == "no_helmet" and self.face_tracker:
                track_id = self._get_face_id(frame, bbox)

            # If no face match, try position matching against live tracks.
            if track_id is None:
                for tid, info in self.tracked_workers.items():
                    # Skip tracks not seen within the buffer window
                    # (TRACK_BUFFER is in frames; /30 converts to seconds
                    # assuming ~30 FPS).
                    if current_time - info['last_seen'] > Config.TRACK_BUFFER / 30.0:
                        continue

                    if 'position' in info and self._is_same_position(center, info['position']):
                        track_id = tid
                        break

            # If still no match, create a brand-new track.
            if track_id is None:
                track_id = self.next_id
                self.next_id += 1
                self.tracked_workers[track_id] = {
                    'last_seen': current_time,
                    'violations': set(),
                    'position': center
                }

                # Store the largest face for the new track, if one is visible.
                if violation_type == "no_helmet" and self.face_tracker:
                    faces = self.face_tracker.detect_faces(frame[y:y+h, x:x+w])
                    if len(faces) > 0:
                        self.tracked_workers[track_id]['face_bbox'] = max(
                            faces, key=lambda f: f[2]*f[3])

            # Refresh track bookkeeping with the latest observation.
            self.tracked_workers[track_id].update({
                'last_seen': current_time,
                'position': center,
                'bbox': bbox
            })

            # Report only if this violation is new for this track...
            if violation_type not in self.tracked_workers[track_id]['violations']:
                # ...and this violation type has not fired anywhere recently.
                recent_violations = [t for t in self.violation_history[violation_type]
                                     if current_time - t < Config.VIOLATION_COOLDOWN]

                if not recent_violations:
                    self.tracked_workers[track_id]['violations'].add(violation_type)
                    self.violation_history[violation_type].append(current_time)

                    unique_violations.append({
                        'track_id': track_id,
                        'violation': violation_type,
                        'bbox': bbox,
                        'confidence': confidence,
                        'timestamp': timestamp
                    })

        # Clean up tracks that have not been seen within the buffer window.
        stale_ids = [tid for tid, info in self.tracked_workers.items()
                     if current_time - info['last_seen'] > Config.TRACK_BUFFER / 30.0]
        for tid in stale_ids:
            del self.tracked_workers[tid]

        return unique_violations
231
 
232
# ========================== # Video Processor # ==========================
class VideoProcessor:
    """Runs YOLO detection over a video and tracks unique safety violations."""

    def __init__(self):
        self.model = self._load_model()
        self.tracker = ViolationTracker()

    def _load_model(self):
        """Load YOLOv8 model with fallback to the stock nano weights.

        Raises:
            Exception: re-raised after logging if neither model can load.
        """
        try:
            if os.path.exists(Config.MODEL_PATH):
                model = YOLO(Config.MODEL_PATH).to(Config.DEVICE)
                logger.info(f"Loaded custom model from {Config.MODEL_PATH}")
            else:
                model = YOLO(Config.FALLBACK_MODEL).to(Config.DEVICE)
                logger.warning("Using fallback YOLOv8n model")
            return model
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def _draw_detection(self, frame: np.ndarray, detection: Dict) -> np.ndarray:
        """Draw a single detection (box, label, id, confidence) on the frame."""
        x, y, w, h = detection['bbox']
        label = detection['violation']
        confidence = detection['confidence']
        track_id = detection.get('track_id', 0)

        # Default to red for unknown violation types.
        color = Config.COLOR_MAP.get(label, (0, 0, 255))

        # Bounding box.
        cv2.rectangle(frame, (x, y), (x+w, y+h), color, 2)

        # Filled label background sized to the text.
        label_text = f"{label} (ID: {track_id})"
        (text_width, text_height), _ = cv2.getTextSize(
            label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        cv2.rectangle(frame, (x, y - text_height - 10),
                      (x + text_width + 10, y), color, -1)

        # Label text (white on the colored background).
        cv2.putText(frame, label_text, (x + 5, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        # Confidence below the box.
        conf_text = f"{confidence:.2f}"
        cv2.putText(frame, conf_text, (x + 5, y + h + 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

        return frame

    def _process_frame(self, frame: np.ndarray, frame_idx: int, fps: float) -> Tuple[List[Dict], np.ndarray]:
        """Process a single video frame.

        Returns:
            (unique_violations, annotated_frame) — only violations that the
            tracker considers new are drawn and returned.
        """
        timestamp = frame_idx / fps

        # Run YOLO detection.
        results = self.model(frame, verbose=False)
        detections = []

        for result in results:
            for box in result.boxes:
                cls_id = int(box.cls)
                conf = float(box.conf)
                label = Config.VIOLATION_LABELS.get(cls_id)

                # Skip classes we don't map and low-confidence hits.
                if label is None or conf < Config.MIN_VIOLATION_CONFIDENCE:
                    continue

                # Convert bounding box from (x1, y1, x2, y2) to (x, y, w, h).
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                detections.append({
                    'bbox': (x1, y1, x2 - x1, y2 - y1),
                    'violation': label,
                    'confidence': conf
                })

        # Update tracker and get violations not seen before.
        unique_violations = self.tracker.update(frame, detections, timestamp)

        # Draw only the new violations on a copy of the frame.
        output_frame = frame.copy()
        for violation in unique_violations:
            output_frame = self._draw_detection(output_frame, violation)

        return unique_violations, output_frame

    def process_video(self, video_path: str) -> Tuple[List[Dict], List[Tuple[float, np.ndarray]]]:
        """Process an entire video file, yielding progress as a generator.

        Yields:
            (progress 0..1, violations-so-far, snapshots-so-far); the final
            yield has progress 1.0.

        Raises:
            ValueError: if the video cannot be opened.
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        # BUG FIX: some containers/streams report 0 FPS, which would cause a
        # ZeroDivisionError when computing timestamps. Assume 30 FPS then.
        if not fps or fps <= 0:
            fps = 30.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        violations = []
        snapshots = []

        logger.info(f"Processing video: {video_path} ({total_frames} frames, {fps:.1f} FPS)")

        frame_idx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Skip frames for efficiency.
            if frame_idx % Config.FRAME_SKIP != 0:
                frame_idx += 1
                continue

            # Detect + track on this frame.
            frame_violations, output_frame = self._process_frame(frame, frame_idx, fps)

            # Save snapshots for new violations.
            for violation in frame_violations:
                violations.append(violation)

                timestamp = frame_idx / fps
                snapshot_path = os.path.join(
                    Config.OUTPUT_DIR,
                    f"violation_{violation['violation']}_id{violation['track_id']}_{int(timestamp*100)}.jpg")

                cv2.imwrite(snapshot_path, output_frame,
                            [cv2.IMWRITE_JPEG_QUALITY, Config.SNAPSHOT_QUALITY])

                snapshots.append((timestamp, snapshot_path))
                logger.info(f"Detected {violation['violation']} at {timestamp:.2f}s (ID: {violation['track_id']})")

            frame_idx += 1

            # Yield progress every 10 frames.
            if frame_idx % 10 == 0:
                # BUG FIX: guard against total_frames == 0 (unknown length).
                progress = frame_idx / total_frames if total_frames > 0 else 0.0
                yield progress, violations, snapshots

        cap.release()
        yield 1.0, violations, snapshots  # Final yield

    def calculate_safety_score(self, violations: List[Dict]) -> int:
        """Calculate safety compliance score (0-100).

        Each *distinct* violation type incurs its configured penalty once,
        regardless of how many times it occurred.
        """
        penalty = 0
        violation_types = set()

        for violation in violations:
            violation_types.add(violation['violation'])

        for v_type in violation_types:
            penalty += Config.PENALTIES.get(v_type, 0)

        return max(0, 100 - penalty)
 
 
 
 
 
 
 
383
 
384
+ # ========================== # Gradio Interface # ==========================
385
+ def setup_output_dir():
386
+ """Ensure output directory exists"""
387
+ os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
388
+ logger.info(f"Output directory: {Config.OUTPUT_DIR}")
389
+
390
def format_violation_table(violations: List[Dict]) -> str:
    """Render the detected violations as a markdown table, oldest first."""
    if not violations:
        return "No violations detected"

    rows = [
        "| Violation | Worker ID | Time (s) | Confidence |",
        "|-----------|-----------|----------|------------|",
    ]
    for entry in sorted(violations, key=lambda item: item['timestamp']):
        rows.append(
            f"| {entry['violation']} | {entry['track_id']} | "
            f"{entry['timestamp']:.2f} | {entry['confidence']:.2f} |"
        )

    # Trailing newline matches the original line-by-line concatenation.
    return "\n".join(rows) + "\n"
402
 
403
def format_snapshots(snapshots: List[Tuple[float, str]]) -> str:
    """Format violation snapshots as markdown image sections.

    Args:
        snapshots: (timestamp_seconds, image_path) pairs.

    Returns:
        Markdown with one heading + image per snapshot, or a placeholder
        string when there are none.
    """
    if not snapshots:
        return "No violations captured"

    # FIX: removed an unused `filename = os.path.basename(path)` local.
    markdown = ""
    for timestamp, path in snapshots:
        markdown += f"### Violation at {timestamp:.2f}s\n\n"
        markdown += f"![Violation](file/{path})\n\n"

    return markdown
415
 
416
def process_video_wrapper(video_file: str):
    """Gradio generator: copy the upload, stream progress, then final results.

    Yields:
        (status, score, violations_markdown, snapshots_markdown) tuples;
        intermediate yields carry only the status, the final yield carries
        the full report.
    """
    setup_output_dir()

    # Copy the upload into our output dir so processing is independent of
    # Gradio's temp-file lifecycle.
    temp_video_path = os.path.join(Config.OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
    # BUG FIX: the source file handle was previously opened without being
    # closed; both handles are now managed by the `with` block.
    with open(video_file, "rb") as src, open(temp_video_path, "wb") as dst:
        dst.write(src.read())

    processor = VideoProcessor()

    try:
        for progress, violations, snapshots in processor.process_video(temp_video_path):
            yield (
                f"Processing... {progress*100:.1f}% complete",
                "",
                "",
                ""
            )

        # Calculate final results (the generator always yields at least once,
        # so `violations`/`snapshots` are bound here).
        score = processor.calculate_safety_score(violations)
        violation_table = format_violation_table(violations)
        snapshots_md = format_snapshots(snapshots)

        yield (
            "Processing complete",
            f"Safety Score: {score}%",
            violation_table,
            snapshots_md
        )
    finally:
        # Always remove the temporary copy, even on failure.
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)
451
+
452
+ # ========================== # Main Application # ==========================
453
+ def create_interface():
454
+ """Create Gradio interface"""
455
+ with gr.Blocks(title="Safety Compliance Analyzer") as interface:
456
+ gr.Markdown("# 🚧 Safety Compliance Video Analyzer")
457
+ gr.Markdown("Upload site videos to detect safety violations (No Helmet, No Harness, etc.)")
458
+
459
+ with gr.Row():
460
+ with gr.Column():
461
+ video_input = gr.Video(label="Upload Site Video")
462
+ submit_btn = gr.Button("Analyze Video", variant="primary")
463
 
464
+ with gr.Column():
465
+ progress_out = gr.Textbox(label="Status")
466
+ score_out = gr.Textbox(label="Safety Score")
467
+ violations_out = gr.Markdown(label="Detected Violations")
468
+ snapshots_out = gr.Markdown(label="Violation Snapshots")
469
+
470
+ submit_btn.click(
471
+ fn=process_video_wrapper,
472
+ inputs=video_input,
473
+ outputs=[progress_out, score_out, violations_out, snapshots_out]
474
+ )
475
+
476
+ return interface
 
 
 
 
 
 
477
 
478
if __name__ == "__main__":
    # Prepare the snapshot directory, then start the Gradio app.
    setup_output_dir()
    create_interface().launch()