PrashanthB461 committed on
Commit
a87de62
·
verified ·
1 Parent(s): b0446d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +671 -413
app.py CHANGED
@@ -1,481 +1,739 @@
1
  import os
2
  import sys
 
3
  import logging
4
- import time
5
  import cv2
6
- import numpy as np
7
- import torch
8
  import gradio as gr
 
 
9
  from ultralytics import YOLO
 
 
 
 
 
 
 
 
 
 
 
 
10
  from collections import defaultdict
11
- from typing import List, Dict, Tuple, Optional
12
-
13
- # ========================== # Configuration # ==========================
14
- class Config:
15
- # Model settings
16
- MODEL_PATH = "yolov8_safety.pt"
17
- FALLBACK_MODEL = "yolov8n.pt"
18
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
-
20
- # Violation settings
21
- VIOLATION_LABELS = {
22
- 0: "no_helmet",
23
- 1: "no_harness",
24
- 2: "unsafe_posture",
25
- 3: "unsafe_zone",
26
- 4: "improper_tool_use"
27
- }
28
-
29
- # Tracking settings
30
- FACE_TRACKING = True # Enable face tracking for helmet violations
31
- TRACK_BUFFER = 30 # Number of frames to keep track of a person
32
- MIN_FACE_CONFIDENCE = 0.7 # Minimum confidence for face detection
33
- MIN_VIOLATION_CONFIDENCE = 0.5 # Minimum confidence for violation detection
34
-
35
- # Output settings
36
- OUTPUT_DIR = "static/output"
37
- SNAPSHOT_QUALITY = 90
38
- FRAME_SKIP = 2 # Process every nth frame for efficiency
39
-
40
- # Violation suppression
41
- VIOLATION_COOLDOWN = 5.0 # Seconds before same violation can be reported again
42
- POSITION_THRESHOLD = 50 # Pixel distance to consider same location
43
-
44
- # Display colors
45
- COLOR_MAP = {
46
- "no_helmet": (0, 0, 255),
47
- "no_harness": (0, 165, 255),
48
- "unsafe_posture": (0, 255, 0),
49
- "unsafe_zone": (255, 0, 0),
50
- "improper_tool_use": (255, 255, 0)
51
- }
52
-
53
- # Penalty scores for safety calculation
54
- PENALTIES = {
55
- "no_helmet": 25,
56
- "no_harness": 30,
57
- "unsafe_posture": 20,
58
- "unsafe_zone": 35,
59
- "improper_tool_use": 25
60
- }
61
 
62
- # Initialize logging
 
 
 
63
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
64
  logger = logging.getLogger(__name__)
65
 
66
- # ========================== # Face Recognition # ==========================
67
- class FaceTracker:
68
- def __init__(self):
69
- try:
70
- # Try to load face detection model
71
- self.face_cascade = cv2.CascadeClassifier(
72
- cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
73
- self.enabled = True
74
- except:
75
- logger.warning("Face detection model not available - falling back to position tracking")
76
- self.enabled = False
77
-
78
- def detect_faces(self, frame: np.ndarray) -> List[Tuple[int, int, int, int]]:
79
- """Detect faces in a frame and return bounding boxes"""
80
- if not self.enabled:
81
- return []
82
-
83
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
84
- faces = self.face_cascade.detectMultiScale(
85
- gray,
86
- scaleFactor=1.1,
87
- minNeighbors=5,
88
- minSize=(30, 30)
89
- return faces
90
-
91
- # ========================== # Violation Tracker # ==========================
92
- class ViolationTracker:
93
- def __init__(self):
94
- self.tracked_workers = {} # {track_id: {last_seen, violations, face_bbox, position}}
95
- self.violation_history = defaultdict(list) # {violation_type: [detection_times]}
96
  self.next_id = 1
97
- self.face_tracker = FaceTracker() if Config.FACE_TRACKING else None
98
-
99
- def _calculate_iou(self, box1: Tuple, box2: Tuple) -> float:
100
- """Calculate Intersection over Union for two bounding boxes"""
101
- x1, y1, w1, h1 = box1
102
- x2, y2, w2, h2 = box2
103
-
104
- # Calculate coordinates
105
- x_left = max(x1, x2)
106
- y_top = max(y1, y2)
107
- x_right = min(x1 + w1, x2 + w2)
108
- y_bottom = min(y1 + h1, y2 + h2)
109
-
110
- if x_right < x_left or y_bottom < y_top:
111
- return 0.0
112
-
113
- intersection_area = (x_right - x_left) * (y_bottom - y_top)
114
- box1_area = w1 * h1
115
- box2_area = w2 * h2
116
 
117
- return intersection_area / (box1_area + box2_area - intersection_area)
118
-
119
- def _is_same_position(self, pos1: Tuple, pos2: Tuple) -> bool:
120
- """Check if two positions are close enough to be considered the same"""
121
- x1, y1 = pos1
122
- x2, y2 = pos2
123
- distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
124
- return distance < Config.POSITION_THRESHOLD
125
-
126
- def _get_face_id(self, frame: np.ndarray, bbox: Tuple) -> Optional[int]:
127
- """Try to identify a person by their face"""
128
- if not self.face_tracker or not self.face_tracker.enabled:
129
- return None
130
-
131
- x, y, w, h = bbox
132
- face_region = frame[max(0, y):y+h, max(0, x):x+w]
133
-
134
- faces = self.face_tracker.detect_faces(face_region)
135
- if len(faces) == 0:
136
- return None
137
-
138
- # Get the largest face in the region
139
- largest_face = max(faces, key=lambda f: f[2]*f[3])
140
- fx, fy, fw, fh = largest_face
141
-
142
- # Check if we've seen this face before
143
- for track_id, info in self.tracked_workers.items():
144
- if 'face_bbox' not in info:
145
- continue
146
-
147
- # Compare with stored face
148
- iou = self._calculate_iou((fx, fy, fw, fh), info['face_bbox'])
149
- if iou > 0.5: # Threshold for face matching
150
- return track_id
151
-
152
- return None
153
-
154
- def update(self, frame: np.ndarray, detections: List[Dict], timestamp: float) -> List[Dict]:
155
- """Update tracker with new detections and return unique violations"""
156
  current_time = time.time()
157
- unique_violations = []
 
158
 
159
- # First pass - try to match with existing tracks
160
  for det in detections:
161
  bbox = det['bbox']
162
- violation_type = det['violation']
163
  confidence = det['confidence']
164
- x, y, w, h = bbox
165
- center = (x + w/2, y + h/2)
166
-
167
- # Try face recognition for helmet violations
168
- track_id = None
169
- if violation_type == "no_helmet" and self.face_tracker:
170
- track_id = self._get_face_id(frame, bbox)
171
 
172
- # If no face match, try position matching
173
- if track_id is None:
174
- for tid, info in self.tracked_workers.items():
175
- if current_time - info['last_seen'] > Config.TRACK_BUFFER / 30.0:
176
- continue
177
-
178
- if 'position' in info and self._is_same_position(center, info['position']):
179
- track_id = tid
180
- break
181
 
182
- # If still no match, create new track
183
- if track_id is None:
184
- track_id = self.next_id
185
  self.next_id += 1
186
- self.tracked_workers[track_id] = {
187
- 'last_seen': current_time,
188
- 'violations': set(),
189
- 'position': center
 
 
 
 
 
 
190
  }
 
191
 
192
- # Store face if available
193
- if violation_type == "no_helmet" and self.face_tracker:
194
- faces = self.face_tracker.detect_faces(frame[y:y+h, x:x+w])
195
- if len(faces) > 0:
196
- self.tracked_workers[track_id]['face_bbox'] = max(
197
- faces, key=lambda f: f[2]*f[3])
198
 
199
- # Update track info
200
- self.tracked_workers[track_id].update({
 
201
  'last_seen': current_time,
202
- 'position': center,
203
- 'bbox': bbox
204
- })
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- # Check if this is a new violation for this track
207
- if violation_type not in self.tracked_workers[track_id]['violations']:
208
- # Check if this violation type has been seen recently
209
- recent_violations = [t for t in self.violation_history[violation_type]
210
- if current_time - t < Config.VIOLATION_COOLDOWN]
211
 
212
- if not recent_violations:
213
- self.tracked_workers[track_id]['violations'].add(violation_type)
214
- self.violation_history[violation_type].append(current_time)
 
 
 
 
215
 
216
- unique_violations.append({
217
- 'track_id': track_id,
218
- 'violation': violation_type,
219
- 'bbox': bbox,
220
- 'confidence': confidence,
221
- 'timestamp': timestamp
222
- })
 
 
223
 
224
- # Clean up old tracks
225
- stale_ids = [tid for tid, info in self.tracked_workers.items()
226
- if current_time - info['last_seen'] > Config.TRACK_BUFFER / 30.0]
227
- for tid in stale_ids:
228
- del self.tracked_workers[tid]
 
 
 
 
 
 
 
 
 
 
 
229
 
230
- return unique_violations
 
 
 
231
 
232
- # ========================== # Video Processor # ==========================
233
- class VideoProcessor:
234
- def __init__(self):
235
- self.model = self._load_model()
236
- self.tracker = ViolationTracker()
237
 
238
- def _load_model(self):
239
- """Load YOLOv8 model with fallback"""
 
240
  try:
241
- if os.path.exists(Config.MODEL_PATH):
242
- model = YOLO(Config.MODEL_PATH).to(Config.DEVICE)
243
- logger.info(f"Loaded custom model from {Config.MODEL_PATH}")
244
- else:
245
- model = YOLO(Config.FALLBACK_MODEL).to(Config.DEVICE)
246
- logger.warning("Using fallback YOLOv8n model")
247
- return model
248
  except Exception as e:
249
- logger.error(f"Failed to load model: {e}")
250
- raise
 
 
 
 
 
 
 
 
 
 
 
251
 
252
- def _draw_detection(self, frame: np.ndarray, detection: Dict) -> np.ndarray:
253
- """Draw a single detection on the frame"""
254
- x, y, w, h = detection['bbox']
255
- label = detection['violation']
256
- confidence = detection['confidence']
257
- track_id = detection.get('track_id', 0)
258
-
259
- color = Config.COLOR_MAP.get(label, (0, 0, 255))
260
-
261
- # Draw bounding box
262
- cv2.rectangle(frame, (x, y), (x+w, y+h), color, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
- # Draw label background
265
- label_text = f"{label} (ID: {track_id})"
266
- (text_width, text_height), _ = cv2.getTextSize(
267
- label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
268
 
269
- cv2.rectangle(frame, (x, y - text_height - 10),
270
- (x + text_width + 10, y), color, -1)
271
 
272
- # Draw label text
273
- cv2.putText(frame, label_text, (x + 5, y - 5),
274
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
 
 
275
 
276
- # Draw confidence
277
- conf_text = f"{confidence:.2f}"
278
- cv2.putText(frame, conf_text, (x + 5, y + h + 15),
279
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
280
 
281
- return frame
 
 
 
 
 
 
 
 
 
 
282
 
283
- def _process_frame(self, frame: np.ndarray, frame_idx: int, fps: float) -> Tuple[List[Dict], np.ndarray]:
284
- """Process a single video frame"""
285
- timestamp = frame_idx / fps
286
-
287
- # Run YOLO detection
288
- results = self.model(frame, verbose=False)
289
- detections = []
290
-
291
- for result in results:
292
- for box in result.boxes:
293
- cls_id = int(box.cls)
294
- conf = float(box.conf)
295
- label = Config.VIOLATION_LABELS.get(cls_id)
296
-
297
- if label is None or conf < Config.MIN_VIOLATION_CONFIDENCE:
298
- continue
299
-
300
- # Convert bounding box to (x, y, w, h) format
301
- x1, y1, x2, y2 = map(int, box.xyxy[0])
302
- detections.append({
303
- 'bbox': (x1, y1, x2 - x1, y2 - y1),
304
- 'violation': label,
305
- 'confidence': conf
306
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
- # Update tracker and get unique violations
309
- unique_violations = self.tracker.update(frame, detections, timestamp)
 
 
 
 
 
 
 
 
 
 
310
 
311
- # Draw detections on frame
312
- output_frame = frame.copy()
313
- for violation in unique_violations:
314
- output_frame = self._draw_detection(output_frame, violation)
315
 
316
- return unique_violations, output_frame
317
-
318
- def process_video(self, video_path: str) -> Tuple[List[Dict], List[Tuple[float, np.ndarray]]]:
319
- """Process entire video file"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  cap = cv2.VideoCapture(video_path)
321
  if not cap.isOpened():
322
- raise ValueError(f"Could not open video: {video_path}")
323
-
324
- fps = cap.get(cv2.CAP_PROP_FPS)
325
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
326
- violations = []
 
 
 
 
 
 
327
  snapshots = []
328
-
329
- logger.info(f"Processing video: {video_path} ({total_frames} frames, {fps:.1f} FPS)")
330
-
331
- frame_idx = 0
332
- while cap.isOpened():
333
- ret, frame = cap.read()
334
- if not ret:
335
- break
 
 
 
 
 
336
 
337
- # Skip frames for efficiency
338
- if frame_idx % Config.FRAME_SKIP != 0:
339
- frame_idx += 1
340
- continue
 
341
 
342
- # Process frame
343
- frame_violations, output_frame = self._process_frame(frame, frame_idx, fps)
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
- # Save snapshots for new violations
346
- for violation in frame_violations:
347
- violations.append(violation)
348
 
349
- # Save snapshot
350
- timestamp = frame_idx / fps
351
- snapshot_path = os.path.join(
352
- Config.OUTPUT_DIR,
353
- f"violation_{violation['violation']}_id{violation['track_id']}_{int(timestamp*100)}.jpg")
 
 
 
354
 
355
- cv2.imwrite(snapshot_path, output_frame,
356
- [cv2.IMWRITE_JPEG_QUALITY, Config.SNAPSHOT_QUALITY])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
358
- snapshots.append((timestamp, snapshot_path))
359
- logger.info(f"Detected {violation['violation']} at {timestamp:.2f}s (ID: {violation['track_id']})")
360
-
361
- frame_idx += 1
362
-
363
- # Yield progress
364
- if frame_idx % 10 == 0:
365
- progress = frame_idx / total_frames
366
- yield progress, violations, snapshots
367
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  cap.release()
369
- yield 1.0, violations, snapshots # Final yield
370
-
371
- def calculate_safety_score(self, violations: List[Dict]) -> int:
372
- """Calculate safety compliance score (0-100)"""
373
- penalty = 0
374
- violation_types = set()
375
-
376
- for violation in violations:
377
- violation_types.add(violation['violation'])
378
-
379
- for v_type in violation_types:
380
- penalty += Config.PENALTIES.get(v_type, 0)
381
-
382
- return max(0, 100 - penalty)
383
 
384
- # ========================== # Gradio Interface # ==========================
385
- def setup_output_dir():
386
- """Ensure output directory exists"""
387
- os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
388
- logger.info(f"Output directory: {Config.OUTPUT_DIR}")
389
-
390
- def format_violation_table(violations: List[Dict]) -> str:
391
- """Format violations as markdown table"""
392
- if not violations:
393
- return "No violations detected"
394
-
395
- table = "| Violation | Worker ID | Time (s) | Confidence |\n"
396
- table += "|-----------|-----------|----------|------------|\n"
397
-
398
- for v in sorted(violations, key=lambda x: x['timestamp']):
399
- table += f"| {v['violation']} | {v['track_id']} | {v['timestamp']:.2f} | {v['confidence']:.2f} |\n"
400
-
401
- return table
402
 
403
- def format_snapshots(snapshots: List[Tuple[float, str]]) -> str:
404
- """Format snapshots as markdown"""
405
- if not snapshots:
406
- return "No violations captured"
407
-
408
- markdown = ""
409
- for timestamp, path in snapshots:
410
- filename = os.path.basename(path)
411
- markdown += f"### Violation at {timestamp:.2f}s\n\n"
412
- markdown += f"![Violation](file/{path})\n\n"
413
-
414
- return markdown
415
 
416
- def process_video_wrapper(video_file: str):
417
- """Wrapper for Gradio interface"""
418
- setup_output_dir()
419
-
420
- # Create temporary video path
421
- temp_video_path = os.path.join(Config.OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
422
- with open(temp_video_path, "wb") as f:
423
- f.write(open(video_file, "rb").read())
424
-
425
- processor = VideoProcessor()
426
-
427
- try:
428
- for progress, violations, snapshots in processor.process_video(temp_video_path):
429
- yield (
430
- f"Processing... {progress*100:.1f}% complete",
431
- "",
432
- "",
433
- ""
434
- )
435
 
436
- # Calculate final results
437
- score = processor.calculate_safety_score(violations)
438
- violation_table = format_violation_table(violations)
439
- snapshots_md = format_snapshots(snapshots)
440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  yield (
442
- "Processing complete",
443
- f"Safety Score: {score}%",
444
  violation_table,
445
- snapshots_md
 
 
 
446
  )
 
 
 
 
 
 
 
 
 
 
 
447
 
448
- finally:
449
- if os.path.exists(temp_video_path):
450
- os.remove(temp_video_path)
451
-
452
- # ========================== # Main Application # ==========================
453
- def create_interface():
454
- """Create Gradio interface"""
455
- with gr.Blocks(title="Safety Compliance Analyzer") as interface:
456
- gr.Markdown("# 🚧 Safety Compliance Video Analyzer")
457
- gr.Markdown("Upload site videos to detect safety violations (No Helmet, No Harness, etc.)")
458
-
459
- with gr.Row():
460
- with gr.Column():
461
- video_input = gr.Video(label="Upload Site Video")
462
- submit_btn = gr.Button("Analyze Video", variant="primary")
463
 
464
- with gr.Column():
465
- progress_out = gr.Textbox(label="Status")
466
- score_out = gr.Textbox(label="Safety Score")
467
- violations_out = gr.Markdown(label="Detected Violations")
468
- snapshots_out = gr.Markdown(label="Violation Snapshots")
469
-
470
- submit_btn.click(
471
- fn=process_video_wrapper,
472
- inputs=video_input,
473
- outputs=[progress_out, score_out, violations_out, snapshots_out]
474
- )
475
-
476
- return interface
 
 
 
 
 
 
477
 
478
  if __name__ == "__main__":
479
- setup_output_dir()
480
- interface = create_interface()
481
  interface.launch()
 
1
  import os
2
  import sys
3
+ import subprocess
4
  import logging
5
+ import warnings
6
  import cv2
 
 
7
  import gradio as gr
8
+ import torch
9
+ import numpy as np
10
  from ultralytics import YOLO
11
+ import time
12
+ from simple_salesforce import Salesforce
13
+ from reportlab.lib.pagesizes import letter
14
+ from reportlab.pdfgen import canvas
15
+ from reportlab.lib.units import inch
16
+ from io import BytesIO
17
+ import base64
18
+ from retrying import retry
19
+ import uuid
20
+ from multiprocessing import Pool, cpu_count
21
+ from functools import partial
22
+ import face_recognition
23
  from collections import defaultdict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # ========================== # Configuration and Setup # ==========================
26
+ os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
27
+ os.makedirs('/tmp/Ultralytics', exist_ok=True)
28
+
29
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
30
  logger = logging.getLogger(__name__)
31
 
32
+ # Suppress warnings
33
+ warnings.filterwarnings("ignore")
34
+
35
+ # ========================== # Enhanced Tracker Implementation # ==========================
36
class SafetyTracker:
    """Track workers across frames and deduplicate safety violations.

    Helmet violations are re-identified via face encodings (using the
    `face_recognition` package); every other violation type falls back to
    coarse position matching.  Each (worker, violation-type) pair is
    rate-limited by a per-type cooldown so the same violation is not
    re-reported every frame.
    """

    def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer  # frames a track survives without updates
        self.match_thresh = match_thresh
        self.frame_rate = frame_rate
        self.next_id = 1

        # Trackers for different purposes
        self.worker_tracks = {}  # worker_id -> {'bbox', 'last_seen', 'label'}
        self.violation_history = defaultdict(dict)  # worker_id -> {label: last_seen_time}
        self.face_encodings = {}  # worker_id -> [face encodings] (helmet violations only)
        self.position_history = defaultdict(list)  # worker_id -> [(x, y), ...]

        # Cooldown periods (in seconds) before the same violation type is
        # reported again for the same worker.
        self.VIOLATION_COOLDOWNS = {
            "no_helmet": 30.0,
            "no_harness": 20.0,
            "unsafe_posture": 15.0,
            "unsafe_zone": 10.0,
            "improper_tool_use": 15.0
        }

    def update(self, detections, frame):
        """Update tracks with new detections and return only the NEW violations.

        detections: list of dicts with keys 'bbox' (x, y, w, h), 'violation'
        (label string) and 'confidence'.  Returns a list of violation dicts
        for detections that passed the per-worker cooldown check.
        """
        current_time = time.time()
        new_violations = []

        for det in detections:
            bbox = det['bbox']
            label = det['violation']
            confidence = det['confidence']

            # For helmet violations, use face recognition; for everything
            # else fall back to position tracking.
            if label == "no_helmet":
                worker_id = self._match_by_face(bbox, frame)
            else:
                worker_id = self._match_by_position(bbox, label)

            if worker_id is None:
                worker_id = self.next_id
                self.next_id += 1

            # BUGFIX: record the worker's position so _match_by_position can
            # actually match a returning worker — previously nothing ever
            # appended to position_history, so every detection created a new
            # worker id and the cooldown dedup never triggered.
            self.position_history[worker_id].append((bbox[0], bbox[1]))
            # Bound the per-worker history so long videos don't grow memory
            # without limit.
            if len(self.position_history[worker_id]) > 50:
                self.position_history[worker_id] = self.position_history[worker_id][-50:]

            # Check if this is a new violation for this worker
            if self._is_new_violation(worker_id, label, current_time):
                # Record the violation
                violation = {
                    'worker_id': worker_id,
                    'violation': label,
                    'confidence': confidence,
                    'bbox': bbox,
                    'timestamp': current_time
                }
                new_violations.append(violation)

                # Update violation history
                self.violation_history[worker_id][label] = current_time

                # For helmet violations, store face encoding
                if label == "no_helmet":
                    self._store_face_encoding(worker_id, bbox, frame)

            # Keep track of active workers
            self.worker_tracks[worker_id] = {
                'bbox': bbox,
                'last_seen': current_time,
                'label': label
            }

        # Clean up old tracks
        self._cleanup_tracks(current_time)

        return new_violations

    def _match_by_face(self, bbox, frame):
        """Match detection by face recognition (for helmet violations).

        Returns a known worker_id on a face match, else None.
        NOTE(review): the crop treats (x, y) as the box centre — confirm the
        detection pipeline actually produces centre-format boxes here.
        """
        x, y, w, h = bbox
        face_region = frame[max(0, int(y - h / 2)):int(y + h / 2), max(0, int(x - w / 2)):int(x + w / 2)]

        if face_region.size == 0:
            return None

        try:
            # Get face encodings from current detection
            face_locations = face_recognition.face_locations(face_region)
            if not face_locations:
                return None

            current_encoding = face_recognition.face_encodings(face_region, face_locations)[0]

            # Compare with known faces
            for worker_id, encodings in self.face_encodings.items():
                matches = face_recognition.compare_faces(encodings, current_encoding, tolerance=0.6)
                if any(matches):
                    return worker_id

        except Exception as e:
            logger.warning(f"Face recognition error: {e}")

        return None

    def _match_by_position(self, bbox, label):
        """Match detection by position (for non-helmet violations).

        A worker matches when it has previously committed this violation
        type and any of its recorded positions is within 100 px.
        """
        x, y, w, h = bbox
        current_pos = (x, y)

        for worker_id, positions in self.position_history.items():
            if label not in self.violation_history[worker_id]:
                continue

            # Check if current position is near any previous positions for this worker
            for pos in positions:
                distance = np.sqrt((current_pos[0] - pos[0]) ** 2 + (current_pos[1] - pos[1]) ** 2)
                if distance < 100:  # Within 100 pixels
                    return worker_id

        return None

    def _is_new_violation(self, worker_id, label, current_time):
        """Return True when this worker's cooldown for `label` has elapsed."""
        if label not in self.violation_history[worker_id]:
            return True

        last_detection = self.violation_history[worker_id][label]
        cooldown = self.VIOLATION_COOLDOWNS.get(label, 10.0)

        return (current_time - last_detection) > cooldown

    def _store_face_encoding(self, worker_id, bbox, frame):
        """Store a face encoding for a worker (best-effort; errors are logged)."""
        x, y, w, h = bbox
        face_region = frame[max(0, int(y - h / 2)):int(y + h / 2), max(0, int(x - w / 2)):int(x + w / 2)]

        if face_region.size == 0:
            return

        try:
            face_locations = face_recognition.face_locations(face_region)
            if face_locations:
                encoding = face_recognition.face_encodings(face_region, face_locations)[0]
                if worker_id not in self.face_encodings:
                    self.face_encodings[worker_id] = []
                self.face_encodings[worker_id].append(encoding)
        except Exception as e:
            logger.warning(f"Error storing face encoding: {e}")

    def _cleanup_tracks(self, current_time):
        """Drop tracks unseen for longer than the track buffer.

        Face encodings and violation history are kept for a longer grace
        period (5 minutes after the worker's last violation) so a helmet
        violator who leaves and returns is still recognised.
        """
        # Remove inactive workers
        inactive_ids = [
            worker_id for worker_id, track in self.worker_tracks.items()
            if (current_time - track['last_seen']) > (self.track_buffer / self.frame_rate)
        ]

        for worker_id in inactive_ids:
            self.worker_tracks.pop(worker_id, None)
            self.position_history.pop(worker_id, None)

            # Keep face encodings for a longer period (for helmet violations)
            if (current_time - max(self.violation_history[worker_id].values(), default=0)) > 300:  # 5 minutes
                self.face_encodings.pop(worker_id, None)
                self.violation_history.pop(worker_id, None)
+
201
+ # ========================== # Optimized Configuration # ==========================
202
+ CONFIG = {
203
+ "MODEL_PATH": "yolov8_safety.pt",
204
+ "FALLBACK_MODEL": "yolov8n.pt",
205
+ "OUTPUT_DIR": "static/output",
206
+ "VIOLATION_LABELS": {
207
+ 0: "no_helmet",
208
+ 1: "no_harness",
209
+ 2: "unsafe_posture",
210
+ 3: "unsafe_zone",
211
+ 4: "improper_tool_use"
212
+ },
213
+ "CLASS_COLORS": {
214
+ "no_helmet": (0, 0, 255), # Red
215
+ "no_harness": (0, 165, 255), # Orange
216
+ "unsafe_posture": (0, 255, 0), # Green
217
+ "unsafe_zone": (255, 0, 0), # Blue
218
+ "improper_tool_use": (255, 255, 0) # Cyan
219
+ },
220
+ "DISPLAY_NAMES": {
221
+ "no_helmet": "No Helmet Violation",
222
+ "no_harness": "No Harness Violation",
223
+ "unsafe_posture": "Unsafe Posture",
224
+ "unsafe_zone": "Unsafe Zone Entry",
225
+ "improper_tool_use": "Improper Tool Use"
226
+ },
227
+ "SF_CREDENTIALS": {
228
+ "username": "prashanth1ai@safety.com",
229
+ "password": "SaiPrash461",
230
+ "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
231
+ "domain": "login"
232
+ },
233
+ "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
234
+ "CONFIDENCE_THRESHOLDS": {
235
+ "no_helmet": 0.5,
236
+ "no_harness": 0.3,
237
+ "unsafe_posture": 0.3,
238
+ "unsafe_zone": 0.3,
239
+ "improper_tool_use": 0.3
240
+ },
241
+ "MIN_VIOLATION_FRAMES": 1,
242
+ "FRAME_SKIP": 2,
243
+ "BATCH_SIZE": 16,
244
+ "PARALLEL_WORKERS": max(1, cpu_count() - 1),
245
+ "SNAPSHOT_QUALITY": 95,
246
+ "FACE_RECOGNITION_INTERVAL": 5 # Process face recognition every 5 frames
247
+ }
248
+
249
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
250
+ logger.info(f"Using device: {device}")
251
+
252
def load_model():
    """Load the safety YOLO model, preferring the custom weights.

    Falls back to the stock yolov8n checkpoint (downloading it if needed)
    when `yolov8_safety.pt` is not present.  Raises on any load failure.
    """
    try:
        if os.path.isfile(CONFIG["MODEL_PATH"]):
            model_path = CONFIG["MODEL_PATH"]
            logger.info(f"Model loaded: {model_path}")
        else:
            model_path = CONFIG["FALLBACK_MODEL"]
            logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
            if not os.path.isfile(model_path):
                # Fetch the stock checkpoint from the ultralytics release assets.
                logger.info(f"Downloading fallback model: {model_path}")
                torch.hub.download_url_to_file(
                    'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt',
                    model_path)

        yolo = YOLO(model_path).to(device)
        logger.info(f"Model classes: {yolo.names}")
        return yolo
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
270
+
271
+ model = load_model()
272
+
273
+ # ========================== # Helper Functions # ==========================
274
+ def preprocess_frame(frame):
275
+ """Apply basic preprocessing to enhance detection"""
276
+ frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
277
+ return frame
278
+
279
+ def draw_detections(frame, detections):
280
+ """Draw bounding boxes and labels on detection frame with improved visibility"""
281
+ result_frame = frame.copy()
282
+
283
+ for det in detections:
284
+ label = det.get("violation", "Unknown")
285
+ confidence = det.get("confidence", 0.0)
286
+ x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
287
+ worker_id = det.get("worker_id", "Unknown")
288
+
289
+ x1 = int(x - w/2)
290
+ y1 = int(y - h/2)
291
+ x2 = int(x + w/2)
292
+ y2 = int(y + h/2)
293
 
294
+ color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
 
 
 
295
 
296
+ # Draw thicker rectangle with border
297
+ cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
298
 
299
+ # Add black background behind text
300
+ display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
301
+ text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
302
+ cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
303
+ cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
304
 
305
+ # Add confidence score
306
+ conf_text = f"Conf: {confidence:.2f}"
307
+ cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
 
308
 
309
+ return result_frame
310
+
311
def calculate_safety_score(violations):
    """Calculate safety score based on detected violations.

    Starts from 100 and subtracts a fixed penalty once per *distinct*
    violation type present, clamping at 0.  Repeated occurrences of the
    same type do not compound.
    """
    penalties = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25
    }

    # Collapse to the distinct violation types seen.
    unique_violations = {v.get("violation", "Unknown") for v in violations}

    total_penalty = sum(penalties.get(v, 0) for v in unique_violations)
    return max(0, 100 - total_penalty)
330
+
331
def generate_violation_pdf(violations, score):
    """Generate a PDF report for the detected violations.

    Renders the report into memory, also writes it under
    CONFIG["OUTPUT_DIR"], and returns (local_path, public_url, BytesIO).
    On any failure returns ("", "", None).
    """
    try:
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        buffer = BytesIO()
        page = canvas.Canvas(buffer, pagesize=letter)

        # Title
        page.setFont("Helvetica-Bold", 16)
        page.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")

        # Basic information
        page.setFont("Helvetica", 12)
        page.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
        page.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")

        # Safety score
        page.setFont("Helvetica-Bold", 14)
        page.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")

        # Summary section
        cursor = 8.2 * inch
        page.setFont("Helvetica-Bold", 12)
        page.drawString(1 * inch, cursor, "Summary:")
        cursor -= 0.3 * inch

        page.setFont("Helvetica", 10)
        summary_data = {
            "Total Violations Found": len(violations),
            "Unique Violation Types": len(set(v['violation'] for v in violations)),
            "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        for key, value in summary_data.items():
            page.drawString(1 * inch, cursor, f"{key}: {value}")
            cursor -= 0.25 * inch

        # Detailed violations section
        cursor -= 0.5 * inch
        page.setFont("Helvetica-Bold", 12)
        page.drawString(1 * inch, cursor, "Violation Details:")
        cursor -= 0.3 * inch

        page.setFont("Helvetica", 10)
        for v in violations:
            display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
            worker_id = v.get("worker_id", "Unknown")
            time_str = f"{v.get('timestamp', 0.0):.2f}s"
            conf_str = f"{v.get('confidence', 0.0):.2f}"

            violation_text = f"- {display_name} by Worker {worker_id} at {time_str} (Confidence: {conf_str})"
            page.drawString(1.2 * inch, cursor, violation_text)
            cursor -= 0.2 * inch

            # Start a fresh page once we run out of vertical space.
            if cursor < 1 * inch:
                page.showPage()
                page.setFont("Helvetica", 10)
                cursor = 10 * inch

        page.save()
        buffer.seek(0)

        # Persist the rendered PDF alongside the other outputs.
        with open(pdf_path, "wb") as f:
            f.write(buffer.getvalue())

        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, buffer
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None
404
+
405
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Connect to Salesforce with retry logic (3 attempts, 2 s apart).

    Raises the underlying exception on failure so `retry` can re-run it.
    """
    try:
        session = Salesforce(**CONFIG["SF_CREDENTIALS"])
        logger.info("Connected to Salesforce")
        # Cheap round-trip to verify the session is actually usable.
        session.describe()
        return session
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise
416
+
417
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Upload the PDF report to Salesforce as a ContentVersion.

    Args:
        sf: Authenticated Salesforce client.
        pdf_file: In-memory file object (e.g. BytesIO) holding the PDF bytes.
        report_id: Salesforce record Id the file is published against.

    Returns:
        Public download URL for the uploaded file, or "" on any failure.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""

        # BUG FIX: use one timestamp for both Title and PathOnClient —
        # the original called int(time.time()) twice, which can straddle
        # a second boundary and desynchronize the two names.
        stamp = int(time.time())
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        content_version_data = {
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id
        }
        content_version = sf.ContentVersion.create(content_version_data)

        # Round-trip query only to confirm the ContentVersion really exists
        # before we hand out a URL for it.
        result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
        if not result['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""

        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""
444
+
445
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a safety-report record in Salesforce and attach the PDF.

    Tries the custom Safety_Video_Report__c object first; if the org lacks
    it, falls back to a plain Account record. Returns (record_id, pdf_url),
    or (None, "") when everything fails.
    """
    try:
        sf = connect_to_salesforce()

        # One human-readable line per violation for the record's text field.
        summary_lines = []
        for entry in violations:
            name = CONFIG['DISPLAY_NAMES'].get(entry.get('violation', 'Unknown'), 'Unknown')
            who = entry.get('worker_id', 'Unknown')
            when = entry.get('timestamp', 0.0)
            how_sure = entry.get('confidence', 0.0)
            summary_lines.append(f"Worker {who}: {name} at {when:.2f}s (Conf: {how_sure:.2f})\n")

        violations_text = "".join(summary_lines) or "No violations detected."

        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""

        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        }

        logger.info(f"Creating Salesforce record with data: {record_data}")

        # Prefer the custom object; park the data on an Account if it is missing.
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")

        record_id = record["id"]

        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                # Mirror the create-time fallback when writing the URL back.
                try:
                    sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                    logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                except Exception as e:
                    logger.error(f"Failed to update Safety_Video_Report__c: {e}")
                    sf.Account.update(record_id, {"Description": uploaded_url})
                    logger.info(f"Updated Account record {record_id} with PDF URL")
                pdf_url = uploaded_url

        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
        return None, ""
501
+
502
def process_video(video_data):
    """Analyze a video for safety violations, streaming progress to the UI.

    This is a generator: it yields 5-tuples matching the Gradio outputs
    (violation table / status text, score text, snapshots markdown,
    Salesforce record id, PDF report URL). Intermediate yields carry
    progress text; the final yield carries the full results.

    Args:
        video_data: Raw bytes of the uploaded video file.
    """
    try:
        os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
        logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")

        # Persist the upload so OpenCV can open it from a path.
        video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
        with open(video_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            os.remove(video_path)
            raise ValueError("Could not open video file")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30  # guard against missing FPS metadata
        duration = total_frames / fps
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")

        tracker = SafetyTracker(frame_rate=fps)
        snapshots = []
        frame_skip = CONFIG["FRAME_SKIP"]
        processed_frames = 0

        # BUG FIX: the original reused a single `start_time` both to
        # throttle the once-per-second progress updates AND to measure
        # total runtime, so the final "Processing complete" log reported
        # only the time since the last progress yield. Keep two clocks.
        overall_start = time.time()
        last_progress = overall_start

        while processed_frames < total_frames:
            batch_frames = []
            batch_indices = []

            # Collect up to BATCH_SIZE sampled frames, honoring FRAME_SKIP.
            for _ in range(CONFIG["BATCH_SIZE"]):
                frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                if frame_idx >= total_frames:
                    break

                ret, frame = cap.read()
                if not ret:
                    break

                frame = preprocess_frame(frame)

                # Skip intermediate frames cheaply (grab() does not decode).
                for _ in range(frame_skip - 1):
                    if not cap.grab():
                        break

                batch_frames.append(frame)
                batch_indices.append(frame_idx)
                processed_frames += 1

            if not batch_frames:
                break

            # One batched YOLO call; conf=0.1 is deliberately low — the real
            # per-class thresholds are applied when filtering boxes below.
            results = model(batch_frames, device=device, conf=0.1, verbose=False)

            for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
                # Throttled progress update (at most once per second).
                if time.time() - last_progress > 1.0:
                    progress = (processed_frames / total_frames) * 100
                    yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames})", "", "", "", ""
                    last_progress = time.time()

                detections = []
                for box in result.boxes:
                    cls = int(box.cls)
                    conf = float(box.conf)
                    label = CONFIG["VIOLATION_LABELS"].get(cls, None)

                    if label is None:
                        continue

                    # Per-class confidence threshold (defaults to 0.25).
                    if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
                        continue

                    bbox = box.xywh.cpu().numpy()[0]
                    detections.append({
                        "bbox": bbox,
                        "violation": label,
                        "confidence": conf
                    })

                if not detections:
                    continue

                # The tracker deduplicates: only genuinely new violations
                # (per worker, per violation type) are returned here.
                new_violations = tracker.update(detections, batch_frames[i])

                for violation in new_violations:
                    # Build an annotated snapshot for the report.
                    snapshot_frame = batch_frames[i].copy()
                    snapshot_frame = draw_detections(snapshot_frame, [violation])

                    cv2.putText(
                        snapshot_frame,
                        f"Time: {violation['timestamp']:.2f}s",
                        (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.7,
                        (255, 255, 255),
                        2
                    )

                    snapshot_filename = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
                    snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)

                    cv2.imwrite(
                        snapshot_path,
                        snapshot_frame,
                        [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
                    )

                    snapshots.append({
                        "violation": violation['violation'],
                        "worker_id": violation['worker_id'],
                        "timestamp": violation['timestamp'],
                        "snapshot_path": snapshot_path,
                        "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
                    })

                    logger.info(f"Captured snapshot for {violation['violation']} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")

        cap.release()
        if os.path.exists(video_path):
            os.remove(video_path)

        processing_time = time.time() - overall_start
        logger.info(f"Processing complete in {processing_time:.2f}s")

        # Flatten the tracker's per-worker history into one violation list.
        violations = []
        for worker_id, worker_violations in tracker.violation_history.items():
            for label, detection_time in worker_violations.items():
                violations.append({
                    "worker_id": worker_id,
                    "violation": label,
                    "timestamp": detection_time
                })

        if not violations:
            logger.info("No violations detected after processing")
            yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
            return

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)

        # Markdown table of violations, ordered by time of occurrence.
        violation_table = "| Violation | Worker ID | Time (s) |\n"
        violation_table += "|-----------|-----------|----------|\n"
        for v in sorted(violations, key=lambda x: x.get("timestamp", 0.0)):
            display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
            worker_id = v.get("worker_id", "Unknown")
            timestamp = v.get("timestamp", 0.0)
            violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} |\n"

        # Markdown gallery of captured snapshots.
        snapshots_text = ""
        for s in snapshots:
            display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
            worker_id = s.get("worker_id", "Unknown")
            timestamp = s.get("timestamp", 0.0)
            snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
            snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"

        if not snapshots_text:
            snapshots_text = "No snapshots captured."

        yield (
            violation_table,
            f"Safety Score: {score}%",
            snapshots_text,
            f"Salesforce Record ID: {report_id or 'N/A'}",
            final_pdf_url or "N/A"
        )

    except Exception as e:
        logger.error(f"Error processing video: {e}", exc_info=True)
        # Best-effort cleanup of the temp file before surfacing the error.
        if 'video_path' in locals() and os.path.exists(video_path):
            os.remove(video_path)
        yield f"Error processing video: {e}", "", "", "", ""
704
+
705
def gradio_interface(video_file):
    """Gradio entry point: read the uploaded video and stream results.

    Yields 5-tuples matching the interface outputs (status/table, score,
    snapshots markdown, record id, details URL).
    """
    # BUG FIX: this function is a generator (it yields below), so the
    # original `return <tuple>` for the no-file case was silently discarded
    # as the generator's StopIteration value and Gradio never displayed the
    # message. The tuple must be yielded instead.
    if not video_file:
        yield "No file uploaded.", "", "No file uploaded.", "", ""
        return

    try:
        with open(video_file, "rb") as f:
            video_data = f.read()

        # Forward every progress/result tuple from the processing generator.
        for status, score, snapshots_text, record_id, details_url in process_video(video_data):
            yield status, score, snapshots_text, record_id, details_url

    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", ""
720
+
721
+ # ========================== # Gradio Interface # ==========================
722
# ========================== # Gradio Interface # ==========================
# Wires gradio_interface (a generator) to five output widgets so progress
# updates stream into the UI as the video is processed.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL"),
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
    allow_flagging="never",
)
736
 
737
if __name__ == "__main__":
    # Script entry point: log startup and serve the Gradio app.
    logger.info("Launching Enhanced Safety Analyzer App...")
    interface.launch()