PrashanthB461 commited on
Commit
c6381a2
·
verified ·
1 Parent(s): a605375

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +490 -262
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import sys
 
3
  import logging
4
  import warnings
5
  import cv2
@@ -15,6 +16,10 @@ from reportlab.lib.units import inch
15
  from io import BytesIO
16
  import base64
17
  from retrying import retry
 
 
 
 
18
  from collections import defaultdict
19
 
20
  # ========================== # Configuration and Setup # ==========================
@@ -23,98 +28,99 @@ os.makedirs('/tmp/Ultralytics', exist_ok=True)
23
 
24
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
25
  logger = logging.getLogger(__name__)
26
- warnings.filterwarnings("ignore")
27
-
28
- # ========================== # Position-Based Tracker (No Face Recognition) # ==========================
29
- class SafetyTracker:
30
- def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
31
- self.track_thresh = track_thresh
32
- self.track_buffer = track_buffer
33
- self.match_thresh = match_thresh
34
- self.frame_rate = frame_rate
35
- self.next_id = 1
36
 
37
- self.worker_tracks = {} # Active worker tracks
38
- self.violation_history = defaultdict(dict) # Track violations per worker
39
- self.position_history = defaultdict(list) # Track positions for all violations
 
 
40
 
41
- self.VIOLATION_COOLDOWNS = {
42
- "no_helmet": 30.0,
43
- "no_harness": 20.0,
44
- "unsafe_posture": 15.0,
45
- "unsafe_zone": 10.0,
46
- "improper_tool_use": 15.0
47
- }
48
-
49
- def update(self, detections):
50
- current_time = time.time()
51
- new_violations = []
52
 
53
- for det in detections:
54
- bbox = det['bbox']
55
- label = det['violation']
56
- confidence = det['confidence']
57
-
58
- worker_id = self._match_by_position(bbox, label)
59
-
60
- if worker_id is None:
61
- worker_id = self.next_id
62
- self.next_id += 1
63
-
64
- if self._is_new_violation(worker_id, label, current_time):
65
- violation = {
66
- 'worker_id': worker_id,
67
- 'violation': label,
68
- 'confidence': confidence,
69
- 'bbox': bbox,
70
- 'timestamp': current_time
71
- }
72
- new_violations.append(violation)
73
- self.violation_history[worker_id][label] = current_time
74
 
75
- self.worker_tracks[worker_id] = {
76
- 'bbox': bbox,
77
- 'last_seen': current_time,
78
- 'label': label
79
- }
80
- self.position_history[worker_id].append((bbox[0], bbox[1]))
81
 
82
- self._cleanup_tracks(current_time)
83
- return new_violations
84
-
85
- def _match_by_position(self, bbox, label):
86
- x, y, w, h = bbox
87
- current_pos = (x, y)
88
 
89
- for worker_id, positions in self.position_history.items():
90
- if not positions:
91
- continue
92
-
93
- last_pos = positions[-1]
94
- distance = np.sqrt((current_pos[0]-last_pos[0])**2 + (current_pos[1]-last_pos[1])**2)
95
- if distance < 100: # Within 100 pixels
96
- return worker_id
97
- return None
98
-
99
- def _is_new_violation(self, worker_id, label, current_time):
100
- if label not in self.violation_history[worker_id]:
101
- return True
102
 
103
- last_detection = self.violation_history[worker_id][label]
104
- cooldown = self.VIOLATION_COOLDOWNS.get(label, 10.0)
105
- return (current_time - last_detection) > cooldown
106
-
107
- def _cleanup_tracks(self, current_time):
108
- inactive_ids = [
109
- worker_id for worker_id, track in self.worker_tracks.items()
110
- if (current_time - track['last_seen']) > (self.track_buffer / self.frame_rate)
111
- ]
 
 
 
 
 
 
 
 
 
 
112
 
113
- for worker_id in inactive_ids:
114
- self.worker_tracks.pop(worker_id, None)
115
- self.position_history.pop(worker_id, None)
116
- if (current_time - max(self.violation_history[worker_id].values(), default=0)) > 300:
117
- self.violation_history.pop(worker_id, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  # ========================== # Optimized Configuration # ==========================
120
  CONFIG = {
@@ -129,11 +135,11 @@ CONFIG = {
129
  4: "improper_tool_use"
130
  },
131
  "CLASS_COLORS": {
132
- "no_helmet": (0, 0, 255),
133
- "no_harness": (0, 165, 255),
134
- "unsafe_posture": (0, 255, 0),
135
- "unsafe_zone": (255, 0, 0),
136
- "improper_tool_use": (255, 255, 0)
137
  },
138
  "DISPLAY_NAMES": {
139
  "no_helmet": "No Helmet Violation",
@@ -156,9 +162,17 @@ CONFIG = {
156
  "unsafe_zone": 0.3,
157
  "improper_tool_use": 0.3
158
  },
 
 
 
 
159
  "FRAME_SKIP": 2,
160
- "BATCH_SIZE": 8,
161
- "SNAPSHOT_QUALITY": 90
 
 
 
 
162
  }
163
 
164
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -167,301 +181,515 @@ logger.info(f"Using device: {device}")
167
  def load_model():
168
  try:
169
  if os.path.isfile(CONFIG["MODEL_PATH"]):
170
- model = YOLO(CONFIG["MODEL_PATH"]).to(device)
171
- logger.info(f"Loaded custom model: {CONFIG['MODEL_PATH']}")
172
  else:
173
- model = YOLO(CONFIG["FALLBACK_MODEL"]).to(device)
174
- logger.warning("Using fallback YOLOv8n model")
 
 
 
 
 
 
175
  return model
176
  except Exception as e:
177
- logger.error(f"Model loading failed: {e}")
178
  raise
179
 
180
  model = load_model()
181
 
182
- # ========================== # Core Functions # ==========================
183
  def preprocess_frame(frame):
 
184
  frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
185
  return frame
186
 
187
  def draw_detections(frame, detections):
 
188
  result_frame = frame.copy()
 
189
  for det in detections:
190
- label = det["violation"]
191
- confidence = det["confidence"]
192
- x, y, w, h = det["bbox"]
193
- worker_id = det["worker_id"]
194
-
195
- x1, y1 = int(x - w/2), int(y - h/2)
196
- x2, y2 = int(x + w/2), int(y + h/2)
197
- color = CONFIG["CLASS_COLORS"][label]
 
 
 
198
 
 
199
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
200
- text = f"{CONFIG['DISPLAY_NAMES'][label]} (Worker {worker_id})"
201
- (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
202
- cv2.rectangle(result_frame, (x1, y1-th-10), (x1+tw+10, y1), (0,0,0), -1)
203
- cv2.putText(result_frame, text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)
204
- cv2.putText(result_frame, f"Conf: {confidence:.2f}", (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 2)
 
 
 
 
 
 
205
  return result_frame
206
 
207
  def calculate_safety_score(violations):
 
208
  penalties = {
209
- "no_helmet": 25, "no_harness": 30, "unsafe_posture": 20,
210
- "unsafe_zone": 35, "improper_tool_use": 25
 
 
 
211
  }
212
- unique_violations = {v["violation"] for v in violations}
213
- return max(0, 100 - sum(penalties.get(v, 0) for v in unique_violations))
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  def generate_violation_pdf(violations, score):
 
216
  try:
217
- pdf_buffer = BytesIO()
218
- c = canvas.Canvas(pdf_buffer, pagesize=letter)
 
 
219
 
220
- # Header
221
  c.setFont("Helvetica-Bold", 16)
222
- c.drawString(1*inch, 10*inch, "Worksite Safety Violation Report")
 
 
223
  c.setFont("Helvetica", 12)
224
- c.drawString(1*inch, 9.5*inch, f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}")
225
- c.drawString(1*inch, 9*inch, f"Safety Score: {score}%")
226
 
227
- # Violations List
228
- y = 8.5*inch
229
  c.setFont("Helvetica-Bold", 14)
230
- c.drawString(1*inch, y, "Detected Violations:")
231
- y -= 0.3*inch
232
- c.setFont("Helvetica", 10)
 
 
 
 
233
 
 
 
234
  for v in violations:
235
- text = (f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']} "
236
- f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})")
237
- if y < 1*inch:
238
- c.showPage()
239
- y = 10*inch
240
- c.drawString(1.2*inch, y, text)
241
- y -= 0.2*inch
242
 
243
- c.save()
244
- pdf_buffer.seek(0)
 
 
 
 
245
 
246
- # Save to file
247
- pdf_filename = f"violation_report_{int(time.time())}.pdf"
248
- pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  with open(pdf_path, "wb") as f:
250
- f.write(pdf_buffer.getvalue())
251
 
252
- return pdf_path, f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}", pdf_buffer
 
 
253
  except Exception as e:
254
- logger.error(f"PDF generation failed: {e}")
255
- return None, None, None
256
 
257
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
258
  def connect_to_salesforce():
 
259
  try:
260
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
261
- logger.info("Salesforce connection established")
 
262
  return sf
263
  except Exception as e:
264
  logger.error(f"Salesforce connection failed: {e}")
265
  raise
266
 
267
- def upload_to_salesforce(sf, pdf_file, record_id):
 
268
  try:
269
- encoded = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
270
- file_data = {
271
- "Title": f"Safety_Report_{int(time.time())}",
272
- "PathOnClient": "safety_report.pdf",
273
- "VersionData": encoded,
274
- "FirstPublishLocationId": record_id
 
 
 
 
275
  }
276
- result = sf.ContentVersion.create(file_data)
277
- return f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{result['id']}"
 
 
 
 
 
 
 
 
278
  except Exception as e:
279
- logger.error(f"Salesforce upload failed: {e}")
280
- return None
281
 
282
- def create_salesforce_record(violations, score, pdf_url=None):
 
283
  try:
284
  sf = connect_to_salesforce()
285
- violations_text = "\n".join(
286
- f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']]} "
287
- f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})"
288
- for v in violations
289
- )
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  record_data = {
292
  "Compliance_Score__c": score,
293
  "Violations_Found__c": len(violations),
294
  "Violations_Details__c": violations_text,
295
  "Status__c": "Pending",
296
- "PDF_Report_URL__c": pdf_url or ""
297
  }
298
 
 
 
299
  try:
300
  record = sf.Safety_Video_Report__c.create(record_data)
301
- return record["id"], None
302
- except:
 
303
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
304
- return record["id"], "Used Account as fallback"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  except Exception as e:
306
- logger.error(f"Salesforce record creation failed: {e}")
307
- return None, str(e)
308
 
309
  def process_video(video_data):
 
310
  try:
311
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
 
 
312
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
313
  with open(video_path, "wb") as f:
314
  f.write(video_data)
 
315
 
316
  cap = cv2.VideoCapture(video_path)
317
  if not cap.isOpened():
318
- raise ValueError("Failed to open video")
 
319
 
320
- fps = cap.get(cv2.CAP_PROP_FPS) or 30
321
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
322
- tracker = SafetyTracker(frame_rate=fps)
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  snapshots = []
 
 
324
  processed_frames = 0
325
- last_update = time.time()
326
 
327
  while processed_frames < total_frames:
328
  batch_frames = []
 
 
329
  for _ in range(CONFIG["BATCH_SIZE"]):
 
 
 
 
330
  ret, frame = cap.read()
331
  if not ret:
332
  break
333
- batch_frames.append(preprocess_frame(frame))
 
 
 
 
 
 
 
 
 
334
  processed_frames += 1
335
- if CONFIG["FRAME_SKIP"] > 1:
336
- for _ in range(CONFIG["FRAME_SKIP"]-1):
337
- cap.grab()
338
- processed_frames += 1
339
 
340
  if not batch_frames:
341
  break
342
 
 
343
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
344
 
345
- for i, result in enumerate(results):
 
 
 
 
 
 
 
 
 
346
  detections = []
347
- for box in result.boxes:
 
348
  cls = int(box.cls)
349
  conf = float(box.conf)
350
- label = CONFIG["VIOLATION_LABELS"].get(cls)
351
- if label and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.3):
352
- detections.append({
353
- "bbox": box.xywh.cpu().numpy()[0],
354
- "violation": label,
355
- "confidence": conf
356
- })
357
 
358
- new_violations = tracker.update(detections)
359
-
360
- for violation in new_violations:
361
- frame_with_det = draw_detections(batch_frames[i].copy(), [violation])
362
- timestamp = f"Time: {violation['timestamp']:.2f}s"
363
- cv2.putText(frame_with_det, timestamp, (10, 30),
364
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)
 
 
365
 
366
- snap_name = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
367
- snap_path = os.path.join(CONFIG["OUTPUT_DIR"], snap_name)
368
- cv2.imwrite(snap_path, frame_with_det,
369
- [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
- snapshots.append({
372
- "violation": violation['violation'],
373
- "worker_id": violation['worker_id'],
374
- "timestamp": violation['timestamp'],
375
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snap_name}"
376
- })
377
-
378
- if time.time() - last_update > 1:
379
- progress = (processed_frames / total_frames) * 100
380
- yield f"Processing... {progress:.1f}%", "", "", "", ""
381
- last_update = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
  cap.release()
384
  if os.path.exists(video_path):
385
  os.remove(video_path)
386
-
387
- violations = [
388
- {"worker_id": wid, "violation": v, "timestamp": t, "confidence": 0} # Confidence placeholder
389
- for wid, violations in tracker.violation_history.items()
390
- for v, t in violations.items()
391
- ]
392
 
393
  if not violations:
394
- yield "No violations found", "Safety Score: 100%", "No snapshots", "N/A", "N/A"
 
395
  return
396
 
 
397
  score = calculate_safety_score(violations)
 
 
398
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
399
- record_id, sf_error = create_salesforce_record(violations, score, pdf_url)
400
 
401
- if pdf_file and record_id:
402
- uploaded_url = upload_to_salesforce(connect_to_salesforce(), pdf_file, record_id)
403
- if uploaded_url:
404
- pdf_url = uploaded_url
405
 
406
- violation_table = "| Violation | Worker ID | Time (s) |\n|-----------|-----------|----------|\n"
407
- violation_table += "\n".join(
408
- f"| {CONFIG['DISPLAY_NAMES'][v['violation']]} | {v['worker_id']} | {v['timestamp']:.2f} |"
409
- for v in sorted(violations, key=lambda x: x['timestamp'])
410
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
 
412
- snapshots_md = "\n\n".join(
413
- f"### {CONFIG['DISPLAY_NAMES'][s['violation']]} - Worker {s['worker_id']} at {s['timestamp']:.2f}s\n\n"
414
- f"![Snapshot]({s['snapshot_url']})"
415
- for s in snapshots
416
- ) if snapshots else "No snapshots captured"
417
 
418
  yield (
419
  violation_table,
420
  f"Safety Score: {score}%",
421
- snapshots_md,
422
- f"Salesforce ID: {record_id or 'N/A'} {sf_error or ''}",
423
- pdf_url or "N/A"
424
  )
425
 
426
  except Exception as e:
427
- logger.error(f"Video processing failed: {e}")
428
  if 'video_path' in locals() and os.path.exists(video_path):
429
  os.remove(video_path)
430
- yield f"Error: {str(e)}", "", "", "", ""
431
 
432
- # ========================== # Gradio Interface # ==========================
433
- def gradio_interface(video):
434
- if not video:
435
- return "Upload a video file", "", "", "", ""
436
-
437
  try:
438
- with open(video, "rb") as f:
439
  video_data = f.read()
440
-
441
- for output in process_video(video_data):
442
- yield output
 
443
  except Exception as e:
444
- logger.error(f"Interface error: {e}")
445
- yield f"Error: {str(e)}", "", "", "", ""
446
 
447
- with gr.Blocks(title="Safety Compliance Analyzer") as app:
448
- gr.Markdown("# Worksite Safety Violation Analyzer")
449
- gr.Markdown("Upload site videos to detect safety violations (No Helmet, No Harness, etc.)")
450
-
451
- with gr.Row():
452
- video_input = gr.Video(label="Site Video", sources=["upload"])
453
- with gr.Column():
454
- violations_out = gr.Markdown(label="Detected Violations")
455
- score_out = gr.Textbox(label="Safety Score")
456
- snapshots_out = gr.Markdown(label="Violation Snapshots")
457
- salesforce_out = gr.Textbox(label="Salesforce Record")
458
- pdf_out = gr.Textbox(label="Report PDF URL")
459
-
460
- video_input.change(
461
- gradio_interface,
462
- inputs=video_input,
463
- outputs=[violations_out, score_out, snapshots_out, salesforce_out, pdf_out]
464
- )
465
 
466
  if __name__ == "__main__":
467
- app.launch(server_port=7860, server_name="0.0.0.0")
 
 
1
  import os
2
  import sys
3
+ import subprocess
4
  import logging
5
  import warnings
6
  import cv2
 
16
  from io import BytesIO
17
  import base64
18
  from retrying import retry
19
+ import uuid
20
+ from multiprocessing import Pool, cpu_count
21
+ from functools import partial
22
+ import face_recognition
23
  from collections import defaultdict
24
 
25
  # ========================== # Configuration and Setup # ==========================
 
28
 
29
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
30
  logger = logging.getLogger(__name__)
31
+
32
+ # ========================== # Face Recognition Setup # ==========================
33
+ class FaceTracker:
34
+ def __init__(self):
35
+ self.known_faces = {}
36
+ self.next_face_id = 1
37
+ self.tolerance = 0.6
38
+ self.frame_skip = 5 # Process face recognition every N frames
 
 
39
 
40
+ def get_face_encoding(self, frame, box):
41
+ """Extract face encoding from bounding box"""
42
+ x, y, w, h = box
43
+ x1, y1 = int(x - w/2), int(y - h/2)
44
+ x2, y2 = int(x + w/2), int(y + h/2)
45
 
46
+ # Expand the face area slightly
47
+ expand = 0.2
48
+ h_expand = int((y2 - y1) * expand)
49
+ w_expand = int((x2 - x1) * expand)
 
 
 
 
 
 
 
50
 
51
+ y1 = max(0, y1 - h_expand)
52
+ y2 = min(frame.shape[0], y2 + h_expand)
53
+ x1 = max(0, x1 - w_expand)
54
+ x2 = min(frame.shape[1], x2 + w_expand)
55
+
56
+ face_frame = frame[y1:y2, x1:x2]
57
+
58
+ if face_frame.size == 0:
59
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ # Convert to RGB (face_recognition uses RGB)
62
+ rgb_frame = cv2.cvtColor(face_frame, cv2.COLOR_BGR2RGB)
 
 
 
 
63
 
64
+ # Get face encodings
65
+ encodings = face_recognition.face_encodings(rgb_frame)
66
+ return encodings[0] if encodings else None
 
 
 
67
 
68
+ def identify_face(self, frame, box):
69
+ """Identify or register a new face"""
70
+ encoding = self.get_face_encoding(frame, box)
71
+ if encoding is None:
72
+ return None
 
 
 
 
 
 
 
 
73
 
74
+ # Compare with known faces
75
+ for face_id, known_encoding in self.known_faces.items():
76
+ matches = face_recognition.compare_faces([known_encoding], encoding, tolerance=self.tolerance)
77
+ if matches[0]:
78
+ return face_id
79
+
80
+ # If no match, register new face
81
+ face_id = f"face_{self.next_face_id}"
82
+ self.known_faces[face_id] = encoding
83
+ self.next_face_id += 1
84
+ return face_id
85
+
86
+ # ========================== # Position-Based Tracker # ==========================
87
+ class PositionTracker:
88
+ def __init__(self, distance_threshold=100, cooldown=30):
89
+ self.workers = {}
90
+ self.distance_threshold = distance_threshold
91
+ self.cooldown = cooldown
92
+ self.next_id = 1
93
 
94
+ def track(self, position, violation_type, current_time):
95
+ """Track worker position and return worker ID"""
96
+ # Check if this is a known worker
97
+ for worker_id, worker_data in self.workers.items():
98
+ last_pos = worker_data['position']
99
+ last_time = worker_data['last_seen']
100
+
101
+ # Calculate distance and time difference
102
+ distance = np.sqrt((position[0] - last_pos[0])**2 + (position[1] - last_pos[1])**2)
103
+ time_diff = current_time - last_time
104
+
105
+ # If close enough and not too much time has passed
106
+ if distance < self.distance_threshold and time_diff < self.cooldown:
107
+ # Check if this violation type was already recorded
108
+ if violation_type not in worker_data['violations']:
109
+ worker_data['position'] = position
110
+ worker_data['last_seen'] = current_time
111
+ worker_data['violations'].add(violation_type)
112
+ return worker_id
113
+ return None # Violation already recorded
114
+
115
+ # If no match, create new worker
116
+ worker_id = f"worker_{self.next_id}"
117
+ self.workers[worker_id] = {
118
+ 'position': position,
119
+ 'last_seen': current_time,
120
+ 'violations': {violation_type}
121
+ }
122
+ self.next_id += 1
123
+ return worker_id
124
 
125
  # ========================== # Optimized Configuration # ==========================
126
  CONFIG = {
 
135
  4: "improper_tool_use"
136
  },
137
  "CLASS_COLORS": {
138
+ "no_helmet": (0, 0, 255), # Red
139
+ "no_harness": (0, 165, 255), # Orange
140
+ "unsafe_posture": (0, 255, 0), # Green
141
+ "unsafe_zone": (255, 0, 0), # Blue
142
+ "improper_tool_use": (255, 255, 0) # Cyan
143
  },
144
  "DISPLAY_NAMES": {
145
  "no_helmet": "No Helmet Violation",
 
162
  "unsafe_zone": 0.3,
163
  "improper_tool_use": 0.3
164
  },
165
+ "MIN_VIOLATION_FRAMES": 1,
166
+ "VIOLATION_COOLDOWN": 30.0,
167
+ "WORKER_TRACKING_DURATION": 5.0,
168
+ "MAX_PROCESSING_TIME": 60,
169
  "FRAME_SKIP": 2,
170
+ "BATCH_SIZE": 16,
171
+ "PARALLEL_WORKERS": max(1, cpu_count() - 1),
172
+ "FACE_RECOGNITION_INTERVAL": 5, # Process face recognition every N frames
173
+ "POSITION_TRACKING_THRESHOLD": 100, # pixels
174
+ "SNAPSHOT_QUALITY": 95,
175
+ "MAX_WORKER_DISTANCE": 100
176
  }
177
 
178
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
181
  def load_model():
182
  try:
183
  if os.path.isfile(CONFIG["MODEL_PATH"]):
184
+ model_path = CONFIG["MODEL_PATH"]
185
+ logger.info(f"Model loaded: {model_path}")
186
  else:
187
+ model_path = CONFIG["FALLBACK_MODEL"]
188
+ logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
189
+ if not os.path.isfile(model_path):
190
+ logger.info(f"Downloading fallback model: {model_path}")
191
+ torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
192
+
193
+ model = YOLO(model_path).to(device)
194
+ logger.info(f"Model classes: {model.names}")
195
  return model
196
  except Exception as e:
197
+ logger.error(f"Failed to load model: {e}")
198
  raise
199
 
200
  model = load_model()
201
 
202
+ # ========================== # Helper Functions # ==========================
203
  def preprocess_frame(frame):
204
+ """Apply basic preprocessing to enhance detection"""
205
  frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
206
  return frame
207
 
208
  def draw_detections(frame, detections):
209
+ """Draw bounding boxes and labels on detection frame with improved visibility"""
210
  result_frame = frame.copy()
211
+
212
  for det in detections:
213
+ label = det.get("violation", "Unknown")
214
+ confidence = det.get("confidence", 0.0)
215
+ x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
216
+ worker_id = det.get("worker_id", "Unknown")
217
+
218
+ x1 = int(x - w/2)
219
+ y1 = int(y - h/2)
220
+ x2 = int(x + w/2)
221
+ y2 = int(y + h/2)
222
+
223
+ color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
224
 
225
+ # Draw thicker rectangle with border
226
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
227
+
228
+ # Add black background behind text
229
+ display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
230
+ text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
231
+ cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
232
+ cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
233
+
234
+ # Add confidence score
235
+ conf_text = f"Conf: {confidence:.2f}"
236
+ cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
237
+
238
  return result_frame
239
 
240
  def calculate_safety_score(violations):
241
+ """Calculate safety score based on detected violations"""
242
  penalties = {
243
+ "no_helmet": 25,
244
+ "no_harness": 30,
245
+ "unsafe_posture": 20,
246
+ "unsafe_zone": 35,
247
+ "improper_tool_use": 25
248
  }
249
+
250
+ # Count unique violation types per worker
251
+ worker_violations = defaultdict(set)
252
+ for v in violations:
253
+ worker_id = v.get("worker_id", "Unknown")
254
+ violation_type = v.get("violation", "Unknown")
255
+ worker_violations[worker_id].add(violation_type)
256
+
257
+ # Calculate total penalty
258
+ total_penalty = sum(penalties.get(v, 0) for violations_set in worker_violations.values() for v in violations_set)
259
+
260
+ score = max(0, 100 - total_penalty)
261
+ return score
262
 
263
  def generate_violation_pdf(violations, score):
264
+ """Generate a PDF report for the detected violations"""
265
  try:
266
+ pdf_filename = f"violations_{int(time.time())}.pdf"
267
+ pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
268
+ pdf_file = BytesIO()
269
+ c = canvas.Canvas(pdf_file, pagesize=letter)
270
 
271
+ # Title
272
  c.setFont("Helvetica-Bold", 16)
273
+ c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
274
+
275
+ # Basic Information
276
  c.setFont("Helvetica", 12)
277
+ c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
278
+ c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
279
 
280
+ # Safety Score
 
281
  c.setFont("Helvetica-Bold", 14)
282
+ c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
283
+
284
+ # Violation Summary
285
+ y_position = 8.2 * inch
286
+ c.setFont("Helvetica-Bold", 12)
287
+ c.drawString(1 * inch, y_position, "Summary:")
288
+ y_position -= 0.3 * inch
289
 
290
+ # Group violations by worker
291
+ worker_violations = defaultdict(list)
292
  for v in violations:
293
+ worker_id = v.get("worker_id", "Unknown")
294
+ worker_violations[worker_id].append(v)
 
 
 
 
 
295
 
296
+ c.setFont("Helvetica", 10)
297
+ summary_data = {
298
+ "Total Workers with Violations": len(worker_violations),
299
+ "Total Violations Found": len(violations),
300
+ "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
301
+ }
302
 
303
+ for key, value in summary_data.items():
304
+ c.drawString(1 * inch, y_position, f"{key}: {value}")
305
+ y_position -= 0.25 * inch
306
+
307
+ # Detailed Violations by Worker
308
+ y_position -= 0.5 * inch
309
+ c.setFont("Helvetica-Bold", 12)
310
+ c.drawString(1 * inch, y_position, "Violations by Worker:")
311
+ y_position -= 0.3 * inch
312
+
313
+ c.setFont("Helvetica", 10)
314
+ for worker_id, worker_vios in worker_violations.items():
315
+ c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
316
+ y_position -= 0.2 * inch
317
+
318
+ for v in worker_vios:
319
+ display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
320
+ time_str = f"{v.get('timestamp', 0.0):.2f}s"
321
+ conf_str = f"{v.get('confidence', 0.0):.2f}"
322
+
323
+ violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
324
+ c.drawString(1.2 * inch, y_position, violation_text)
325
+ y_position -= 0.2 * inch
326
+
327
+ if y_position < 1 * inch:
328
+ c.showPage()
329
+ c.setFont("Helvetica", 10)
330
+ y_position = 10 * inch
331
+
332
+ c.save()
333
+ pdf_file.seek(0)
334
+
335
+ # Save PDF file
336
  with open(pdf_path, "wb") as f:
337
+ f.write(pdf_file.getvalue())
338
 
339
+ public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
340
+ logger.info(f"PDF generated: {public_url}")
341
+ return pdf_path, public_url, pdf_file
342
  except Exception as e:
343
+ logger.error(f"Error generating PDF: {e}")
344
+ return "", "", None
345
 
346
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
347
  def connect_to_salesforce():
348
+ """Connect to Salesforce with retry logic"""
349
  try:
350
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
351
+ logger.info("Connected to Salesforce")
352
+ sf.describe()
353
  return sf
354
  except Exception as e:
355
  logger.error(f"Salesforce connection failed: {e}")
356
  raise
357
 
358
+ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
359
+ """Upload PDF report to Salesforce"""
360
  try:
361
+ if not pdf_file:
362
+ logger.error("No PDF file provided for upload")
363
+ return ""
364
+
365
+ encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
366
+ content_version_data = {
367
+ "Title": f"Safety_Violation_Report_{int(time.time())}",
368
+ "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
369
+ "VersionData": encoded_pdf,
370
+ "FirstPublishLocationId": report_id
371
  }
372
+ content_version = sf.ContentVersion.create(content_version_data)
373
+ result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
374
+
375
+ if not result['records']:
376
+ logger.error("Failed to retrieve ContentVersion")
377
+ return ""
378
+
379
+ file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
380
+ logger.info(f"PDF uploaded to Salesforce: {file_url}")
381
+ return file_url
382
  except Exception as e:
383
+ logger.error(f"Error uploading PDF to Salesforce: {e}")
384
+ return ""
385
 
386
+ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
387
+ """Push violation report to Salesforce"""
388
  try:
389
  sf = connect_to_salesforce()
 
 
 
 
 
390
 
391
+ # Format violations for Salesforce
392
+ violations_text = ""
393
+ for v in violations:
394
+ display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
395
+ worker_id = v.get('worker_id', 'Unknown')
396
+ timestamp = v.get('timestamp', 0.0)
397
+ confidence = v.get('confidence', 0.0)
398
+
399
+ violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
400
+
401
+ if not violations_text:
402
+ violations_text = "No violations detected."
403
+
404
+ pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
405
+
406
  record_data = {
407
  "Compliance_Score__c": score,
408
  "Violations_Found__c": len(violations),
409
  "Violations_Details__c": violations_text,
410
  "Status__c": "Pending",
411
+ "PDF_Report_URL__c": pdf_url
412
  }
413
 
414
+ logger.info(f"Creating Salesforce record with data: {record_data}")
415
+
416
  try:
417
  record = sf.Safety_Video_Report__c.create(record_data)
418
+ logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
419
+ except Exception as e:
420
+ logger.error(f"Failed to create Safety_Video_Report__c: {e}")
421
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
422
+ logger.warning(f"Fell back to Account record: {record['id']}")
423
+
424
+ record_id = record["id"]
425
+
426
+ if pdf_file:
427
+ uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
428
+ if uploaded_url:
429
+ try:
430
+ sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
431
+ logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
432
+ except Exception as e:
433
+ logger.error(f"Failed to update Safety_Video_Report__c: {e}")
434
+ sf.Account.update(record_id, {"Description": uploaded_url})
435
+ logger.info(f"Updated Account record {record_id} with PDF URL")
436
+ pdf_url = uploaded_url
437
+
438
+ return record_id, pdf_url
439
  except Exception as e:
440
+ logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
441
+ return None, ""
442
 
443
  def process_video(video_data):
444
+ """Process video to detect safety violations"""
445
  try:
446
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
447
+ logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
448
+
449
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
450
  with open(video_path, "wb") as f:
451
  f.write(video_data)
452
+ logger.info(f"Video saved: {video_path}")
453
 
454
  cap = cv2.VideoCapture(video_path)
455
  if not cap.isOpened():
456
+ os.remove(video_path)
457
+ raise ValueError("Could not open video file")
458
 
 
459
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
460
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30
461
+ duration = total_frames / fps
462
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
463
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
464
+ logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
465
+
466
+ # Initialize trackers
467
+ face_tracker = FaceTracker()
468
+ position_tracker = PositionTracker(
469
+ distance_threshold=CONFIG["POSITION_TRACKING_THRESHOLD"],
470
+ cooldown=CONFIG["VIOLATION_COOLDOWN"]
471
+ )
472
+
473
+ violations = []
474
  snapshots = []
475
+ start_time = time.time()
476
+ frame_skip = CONFIG["FRAME_SKIP"]
477
  processed_frames = 0
478
+ frame_count = 0
479
 
480
  while processed_frames < total_frames:
481
  batch_frames = []
482
+ batch_indices = []
483
+
484
  for _ in range(CONFIG["BATCH_SIZE"]):
485
+ frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
486
+ if frame_idx >= total_frames:
487
+ break
488
+
489
  ret, frame = cap.read()
490
  if not ret:
491
  break
492
+
493
+ frame = preprocess_frame(frame)
494
+
495
+ # Skip frames if needed
496
+ for _ in range(frame_skip - 1):
497
+ if not cap.grab():
498
+ break
499
+
500
+ batch_frames.append(frame)
501
+ batch_indices.append(frame_idx)
502
  processed_frames += 1
503
+ frame_count += 1
 
 
 
504
 
505
  if not batch_frames:
506
  break
507
 
508
+ # Process batch with YOLO model
509
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
510
 
511
+ for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
512
+ current_time = frame_idx / fps
513
+
514
+ # Update progress every second
515
+ if time.time() - start_time > 1.0:
516
+ progress = (processed_frames / total_frames) * 100
517
+ yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames})", "", "", "", ""
518
+ start_time = time.time()
519
+
520
+ boxes = result.boxes
521
  detections = []
522
+
523
+ for box in boxes:
524
  cls = int(box.cls)
525
  conf = float(box.conf)
526
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
527
+
528
+ if label is None:
529
+ continue
530
+
531
+ if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
532
+ continue
533
 
534
+ bbox = box.xywh.cpu().numpy()[0]
535
+
536
+ # For helmet violations, use face recognition
537
+ if label == "no_helmet" and frame_count % CONFIG["FACE_RECOGNITION_INTERVAL"] == 0:
538
+ worker_id = face_tracker.identify_face(batch_frames[i], bbox)
539
+ else:
540
+ # For other violations, use position tracking
541
+ position = (bbox[0], bbox[1])
542
+ worker_id = position_tracker.track(position, label, current_time)
543
 
544
+ if worker_id is None:
545
+ continue # Skip if this is a duplicate violation
546
+
547
+ detection = {
548
+ "worker_id": worker_id,
549
+ "violation": label,
550
+ "confidence": round(conf, 2),
551
+ "bounding_box": bbox,
552
+ "timestamp": current_time
553
+ }
554
+ detections.append(detection)
555
+
556
+ # Process new violations
557
+ for detection in detections:
558
+ # Check if we already have this violation for this worker
559
+ existing = next((v for v in violations
560
+ if v["worker_id"] == detection["worker_id"]
561
+ and v["violation"] == detection["violation"]), None)
562
 
563
+ if not existing:
564
+ violations.append(detection)
565
+
566
+ # Take snapshot for the new violation
567
+ snapshot_frame = batch_frames[i].copy()
568
+ snapshot_frame = draw_detections(snapshot_frame, [detection])
569
+
570
+ # Add timestamp to snapshot
571
+ cv2.putText(
572
+ snapshot_frame,
573
+ f"Time: {current_time:.2f}s",
574
+ (10, 30),
575
+ cv2.FONT_HERSHEY_SIMPLEX,
576
+ 0.7,
577
+ (255, 255, 255),
578
+ 2
579
+ )
580
+
581
+ # Save snapshot with high quality
582
+ snapshot_filename = f"violation_{detection['violation']}_worker{detection['worker_id']}_{int(current_time*100)}.jpg"
583
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
584
+
585
+ cv2.imwrite(
586
+ snapshot_path,
587
+ snapshot_frame,
588
+ [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
589
+ )
590
+
591
+ snapshots.append({
592
+ "violation": detection["violation"],
593
+ "worker_id": detection["worker_id"],
594
+ "timestamp": current_time,
595
+ "snapshot_path": snapshot_path,
596
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
597
+ })
598
+
599
+ logger.info(f"Captured snapshot for {detection['violation']} violation by worker {detection['worker_id']} at {current_time:.2f}s")
600
 
601
  cap.release()
602
  if os.path.exists(video_path):
603
  os.remove(video_path)
604
+
605
+ processing_time = time.time() - start_time
606
+ logger.info(f"Processing complete in {processing_time:.2f}s")
 
 
 
607
 
608
  if not violations:
609
+ logger.info("No violations detected after processing")
610
+ yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
611
  return
612
 
613
+ # Calculate safety score
614
  score = calculate_safety_score(violations)
615
+
616
+ # Generate PDF report
617
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
 
618
 
619
+ # Push report to Salesforce
620
+ report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
 
 
621
 
622
+ # Format violations table for display
623
+ violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
624
+ violation_table += "|-----------|-----------|----------|------------|\n"
625
+
626
+ for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
627
+ display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
628
+ worker_id = v.get("worker_id", "Unknown")
629
+ timestamp = v.get("timestamp", 0.0)
630
+ confidence = v.get("confidence", 0.0)
631
+
632
+ violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
633
+
634
+ # Format snapshots for display
635
+ snapshots_text = ""
636
+ for s in snapshots:
637
+ display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
638
+ worker_id = s.get("worker_id", "Unknown")
639
+ timestamp = s.get("timestamp", 0.0)
640
+
641
+ snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
642
+ snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
643
 
644
+ if not snapshots_text:
645
+ snapshots_text = "No snapshots captured."
 
 
 
646
 
647
  yield (
648
  violation_table,
649
  f"Safety Score: {score}%",
650
+ snapshots_text,
651
+ f"Salesforce Record ID: {report_id or 'N/A'}",
652
+ final_pdf_url or "N/A"
653
  )
654
 
655
  except Exception as e:
656
+ logger.error(f"Error processing video: {e}", exc_info=True)
657
  if 'video_path' in locals() and os.path.exists(video_path):
658
  os.remove(video_path)
659
+ yield f"Error processing video: {e}", "", "", "", ""
660
 
661
+ def gradio_interface(video_file):
662
+ """Gradio interface for the video processing"""
663
+ if not video_file:
664
+ return "No file uploaded.", "", "No file uploaded.", "", ""
665
+
666
  try:
667
+ with open(video_file, "rb") as f:
668
  video_data = f.read()
669
+
670
+ for status, score, snapshots_text, record_id, details_url in process_video(video_data):
671
+ yield status, score, snapshots_text, record_id, details_url
672
+
673
  except Exception as e:
674
+ logger.error(f"Error in Gradio interface: {e}", exc_info=True)
675
+ yield f"Error: {str(e)}", "", "Error in processing.", "", ""
676
 
677
+ # ========================== # Gradio Interface # ==========================
678
+ interface = gr.Interface(
679
+ fn=gradio_interface,
680
+ inputs=gr.Video(label="Upload Site Video"),
681
+ outputs=[
682
+ gr.Markdown(label="Detected Safety Violations"),
683
+ gr.Textbox(label="Compliance Score"),
684
+ gr.Markdown(label="Snapshots"),
685
+ gr.Textbox(label="Salesforce Record ID"),
686
+ gr.Textbox(label="Violation Details URL")
687
+ ],
688
+ title="Worksite Safety Violation Analyzer",
689
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
690
+ allow_flagging="never"
691
+ )
 
 
 
692
 
693
  if __name__ == "__main__":
694
+ logger.info("Launching Enhanced Safety Analyzer App...")
695
+ interface.launch()