PrashanthB461 commited on
Commit
baeebd1
·
verified ·
1 Parent(s): a3d6280

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -188
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  import sys
3
- import subprocess
4
  import logging
5
  import warnings
6
  import cv2
@@ -16,10 +15,8 @@ from reportlab.lib.units import inch
16
  from io import BytesIO
17
  import base64
18
  from retrying import retry
19
- import uuid
20
- from multiprocessing import Pool, cpu_count
21
- from functools import partial
22
  from collections import defaultdict
 
23
 
24
  # ========================== # Configuration and Setup # ==========================
25
  os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
@@ -29,7 +26,7 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
29
  logger = logging.getLogger(__name__)
30
  warnings.filterwarnings("ignore")
31
 
32
- # ========================== # Position-Based Tracker Implementation # ==========================
33
  class SafetyTracker:
34
  def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
35
  self.track_thresh = track_thresh
@@ -37,13 +34,10 @@ class SafetyTracker:
37
  self.match_thresh = match_thresh
38
  self.frame_rate = frame_rate
39
  self.next_id = 1
 
 
 
40
 
41
- # Tracking stores
42
- self.worker_tracks = {} # Active tracks
43
- self.violation_history = defaultdict(dict) # {worker_id: {violation_type: last_detection_time}}
44
- self.position_history = defaultdict(list) # {worker_id: [positions]}
45
-
46
- # Violation cooldowns (seconds)
47
  self.VIOLATION_COOLDOWNS = {
48
  "no_helmet": 30.0,
49
  "no_harness": 20.0,
@@ -52,8 +46,7 @@ class SafetyTracker:
52
  "improper_tool_use": 15.0
53
  }
54
 
55
- def update(self, detections):
56
- """Update tracks with new detections using position-based matching"""
57
  current_time = time.time()
58
  new_violations = []
59
 
@@ -62,14 +55,11 @@ class SafetyTracker:
62
  label = det['violation']
63
  confidence = det['confidence']
64
 
65
- # Match by position
66
  worker_id = self._match_by_position(bbox, label)
67
-
68
  if worker_id is None:
69
  worker_id = self.next_id
70
  self.next_id += 1
71
 
72
- # Check if new violation
73
  if self._is_new_violation(worker_id, label, current_time):
74
  violation = {
75
  'worker_id': worker_id,
@@ -81,62 +71,41 @@ class SafetyTracker:
81
  new_violations.append(violation)
82
  self.violation_history[worker_id][label] = current_time
83
 
84
- # Update position history
85
- x, y, w, h = bbox
86
- self.position_history[worker_id].append((x, y))
87
-
88
- # Update active tracks
89
  self.worker_tracks[worker_id] = {
90
  'bbox': bbox,
91
  'last_seen': current_time,
92
  'label': label
93
  }
 
94
 
95
- # Cleanup old tracks
96
  self._cleanup_tracks(current_time)
97
-
98
  return new_violations
99
 
100
  def _match_by_position(self, bbox, label):
101
- """Match detection to existing tracks using position"""
102
- x, y, w, h = bbox
103
- current_pos = (x, y)
104
-
105
  for worker_id, positions in self.position_history.items():
106
- # Only match if worker has had this violation type before
107
- if label not in self.violation_history[worker_id]:
108
- continue
109
-
110
- # Check distance to historical positions
111
  for pos in positions[-5:]: # Check last 5 positions
112
- distance = np.sqrt((current_pos[0]-pos[0])**2 + (current_pos[1]-pos[1])**2)
113
- if distance < 100: # Within 100 pixels
114
  return worker_id
115
  return None
116
 
117
  def _is_new_violation(self, worker_id, label, current_time):
118
- """Check if violation is new based on cooldown"""
119
  if label not in self.violation_history[worker_id]:
120
  return True
121
-
122
- last_time = self.violation_history[worker_id][label]
123
- cooldown = self.VIOLATION_COOLDOWNS.get(label, 10.0)
124
- return (current_time - last_time) > cooldown
125
 
126
  def _cleanup_tracks(self, current_time):
127
- """Remove inactive tracks"""
128
  inactive_ids = [
129
- id for id, track in self.worker_tracks.items()
130
  if (current_time - track['last_seen']) > (self.track_buffer / self.frame_rate)
131
  ]
132
- for id in inactive_ids:
133
- self.worker_tracks.pop(id, None)
134
- self.position_history.pop(id, None)
135
- # Keep violation history for longer
136
- if (current_time - max(self.violation_history[id].values(), default=0)) > 300:
137
- self.violation_history.pop(id, None)
138
 
139
- # ========================== # App Configuration # ==========================
140
  CONFIG = {
141
  "MODEL_PATH": "yolov8_safety.pt",
142
  "FALLBACK_MODEL": "yolov8n.pt",
@@ -156,10 +125,10 @@ CONFIG = {
156
  "improper_tool_use": (255, 255, 0)
157
  },
158
  "DISPLAY_NAMES": {
159
- "no_helmet": "No Helmet Violation",
160
- "no_harness": "No Harness Violation",
161
  "unsafe_posture": "Unsafe Posture",
162
- "unsafe_zone": "Unsafe Zone Entry",
163
  "improper_tool_use": "Improper Tool Use"
164
  },
165
  "SF_CREDENTIALS": {
@@ -177,7 +146,7 @@ CONFIG = {
177
  "improper_tool_use": 0.3
178
  },
179
  "FRAME_SKIP": 2,
180
- "BATCH_SIZE": 8, # Reduced for stability
181
  "SNAPSHOT_QUALITY": 90
182
  }
183
 
@@ -188,9 +157,7 @@ logger.info(f"Using device: {device}")
188
  def load_model():
189
  try:
190
  model_path = CONFIG["MODEL_PATH"] if os.path.exists(CONFIG["MODEL_PATH"]) else CONFIG["FALLBACK_MODEL"]
191
- logger.info(f"Loading model: {model_path}")
192
  if not os.path.exists(model_path):
193
- logger.info("Downloading fallback model...")
194
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
195
  return YOLO(model_path).to(device)
196
  except Exception as e:
@@ -199,35 +166,23 @@ def load_model():
199
 
200
  model = load_model()
201
 
202
- def preprocess_frame(frame):
203
- """Basic image enhancement"""
204
- return cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
205
-
206
  def draw_detections(frame, detections):
207
- """Draw bounding boxes with labels"""
208
- result = frame.copy()
209
  for det in detections:
210
  x, y, w, h = det['bbox']
211
  x1, y1 = int(x-w/2), int(y-h/2)
212
  x2, y2 = int(x+w/2), int(y+h/2)
213
- label = CONFIG["DISPLAY_NAMES"].get(det['violation'], det['violation'])
214
- color = CONFIG["CLASS_COLORS"].get(det['violation'], (0,0,255))
215
-
216
- cv2.rectangle(result, (x1, y1), (x2, y2), color, 3)
217
- cv2.putText(result, f"{label} (Worker {det['worker_id']})",
218
- (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)
219
- return result
220
 
221
  def calculate_safety_score(violations):
222
- penalty_map = {
223
- "no_helmet": 25, "no_harness": 30,
224
- "unsafe_posture": 20, "unsafe_zone": 35,
225
- "improper_tool_use": 25
226
- }
227
  unique_violations = {v['violation'] for v in violations}
228
- return max(0, 100 - sum(penalty_map.get(v,0) for v in unique_violations))
229
 
230
- # ========================== # Reporting Functions # ==========================
231
  def generate_violation_pdf(violations, score):
232
  try:
233
  pdf_buffer = BytesIO()
@@ -237,16 +192,16 @@ def generate_violation_pdf(violations, score):
237
  c.setFont("Helvetica-Bold", 16)
238
  c.drawString(1*inch, 10*inch, "Safety Violation Report")
239
  c.setFont("Helvetica", 12)
240
- c.drawString(1*inch, 9.5*inch, f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}")
241
  c.drawString(1*inch, 9*inch, f"Safety Score: {score}%")
242
 
243
- # Violations list
244
  y = 8.5*inch
245
  c.setFont("Helvetica-Bold", 14)
246
  c.drawString(1*inch, y, "Violations Detected:")
247
  y -= 0.3*inch
248
-
249
  c.setFont("Helvetica", 10)
 
250
  for v in violations:
251
  text = f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']]} at {v['timestamp']:.1f}s"
252
  c.drawString(1.2*inch, y, text)
@@ -254,6 +209,7 @@ def generate_violation_pdf(violations, score):
254
  if y < 1*inch:
255
  c.showPage()
256
  y = 10*inch
 
257
 
258
  c.save()
259
  pdf_buffer.seek(0)
@@ -272,85 +228,63 @@ def generate_violation_pdf(violations, score):
272
  def connect_to_salesforce():
273
  try:
274
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
275
- logger.info("Salesforce connection established")
276
  return sf
277
  except Exception as e:
278
  logger.error(f"Salesforce connection failed: {e}")
279
  raise
280
 
281
- def push_report_to_salesforce(violations, score, pdf_path, pdf_buffer):
282
  try:
283
- sf = connect_to_salesforce()
284
-
285
  # Create record
286
- record_data = {
 
 
 
 
 
287
  "Compliance_Score__c": score,
288
  "Violations_Found__c": len(violations),
289
- "Violations_Details__c": "\n".join(
290
- f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']]}"
291
- for v in violations
292
- ),
293
- "Status__c": "New"
294
- }
295
 
296
- record = sf.Safety_Video_Report__c.create(record_data)
297
- record_id = record['id']
298
- logger.info(f"Created Salesforce record: {record_id}")
 
 
 
 
 
299
 
300
- # Upload PDF if available
301
- pdf_url = ""
302
- if pdf_buffer:
303
- encoded = base64.b64encode(pdf_buffer.getvalue()).decode()
304
- content_version = sf.ContentVersion.create({
305
- "Title": f"Safety_Report_{record_id}",
306
- "PathOnClient": "report.pdf",
307
- "VersionData": encoded,
308
- "FirstPublishLocationId": record_id
309
- })
310
- pdf_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
311
- logger.info(f"PDF uploaded: {pdf_url}")
312
-
313
- return record_id, pdf_url
314
  except Exception as e:
315
  logger.error(f"Salesforce upload failed: {e}")
316
- return None, ""
317
 
318
  # ========================== # Video Processing # ==========================
319
  def process_video(video_data):
320
  try:
321
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
322
-
323
- # Save video
324
- video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"input_{int(time.time())}.mp4")
325
  with open(video_path, "wb") as f:
326
  f.write(video_data)
327
 
328
  cap = cv2.VideoCapture(video_path)
329
- if not cap.isOpened():
330
- raise ValueError("Failed to open video")
331
-
332
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
333
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
334
  tracker = SafetyTracker(frame_rate=fps)
335
  snapshots = []
336
 
337
- frame_count = 0
338
- while True:
339
  ret, frame = cap.read()
340
  if not ret:
341
  break
342
 
343
- if frame_count % CONFIG["FRAME_SKIP"] != 0:
344
- frame_count += 1
345
- continue
346
-
347
- # Process frame
348
- frame = preprocess_frame(frame)
349
- results = model(frame, verbose=False)[0]
350
 
351
- # Get detections
352
  detections = []
353
- for box in results.boxes:
354
  cls = int(box.cls)
355
  label = CONFIG["VIOLATION_LABELS"].get(cls)
356
  if label and box.conf > CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.3):
@@ -360,72 +294,49 @@ def process_video(video_data):
360
  'confidence': float(box.conf)
361
  })
362
 
363
- # Update tracker
364
- new_violations = tracker.update(detections)
365
 
366
- # Capture snapshots for new violations
367
  for violation in new_violations:
368
  snapshot = draw_detections(frame.copy(), [violation])
369
  timestamp = time.strftime("%Y%m%d_%H%M%S")
370
- img_path = os.path.join(
371
- CONFIG["OUTPUT_DIR"],
372
- f"violation_{violation['worker_id']}_{violation['violation']}_{timestamp}.jpg"
373
- )
374
- cv2.imwrite(img_path, snapshot, [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
375
  snapshots.append({
376
  'path': img_path,
377
  'url': f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(img_path)}",
378
  'violation': violation
379
  })
380
-
381
- # Update progress
382
- if frame_count % 10 == 0:
383
- progress = min(100, frame_count / total_frames * 100)
384
- yield f"Processing... {progress:.1f}%", "", "", "", ""
385
 
386
- frame_count += 1
387
-
388
  cap.release()
389
  os.remove(video_path)
390
 
391
- # Generate report
392
- violations = [
393
- {
394
- 'worker_id': worker_id,
395
- 'violation': violation_type,
396
- 'timestamp': detection_time
397
- }
398
- for worker_id, violations in tracker.violation_history.items()
399
- for violation_type, detection_time in violations.items()
400
- ]
401
-
402
- if not violations:
403
- yield "No violations found", "Safety Score: 100%", "No snapshots", "N/A", "N/A"
404
  return
405
 
406
- score = calculate_safety_score(violations)
407
- pdf_path, pdf_url, pdf_buffer = generate_violation_pdf(violations, score)
408
- record_id, salesforce_url = push_report_to_salesforce(violations, score, pdf_path, pdf_buffer)
409
-
410
- # Format output
411
- violations_table = "| Violation | Worker ID | Time |\n|-----------|-----------|------|\n"
412
- violations_table += "\n".join(
413
- f"| {CONFIG['DISPLAY_NAMES'][v['violation']]} | {v['worker_id']} | {v['timestamp']:.1f}s |"
414
- for v in violations
415
- )
416
 
417
- snapshots_md = "\n\n".join(
418
- f"### {CONFIG['DISPLAY_NAMES'][s['violation']['violation']]} (Worker {s['violation']['worker_id']})\n"
419
- f"![Snapshot]({s['url']})"
420
- for s in snapshots
 
 
 
 
 
421
  )
422
 
423
  yield (
424
- violations_table,
425
  f"Safety Score: {score}%",
426
- snapshots_md or "No snapshots",
427
  f"Salesforce ID: {record_id or 'N/A'}",
428
- salesforce_url or pdf_url or "N/A"
429
  )
430
 
431
  except Exception as e:
@@ -435,30 +346,22 @@ def process_video(video_data):
435
  yield f"Error: {str(e)}", "", "", "", ""
436
 
437
  # ========================== # Gradio Interface # ==========================
438
- def gradio_interface(video_file):
439
- if not video_file:
440
  return "Upload a video file", "", "", "", ""
441
-
442
- try:
443
- with open(video_file, "rb") as f:
444
- video_data = f.read()
445
-
446
- for output in process_video(video_data):
447
- yield output
448
-
449
- except Exception as e:
450
- logger.error(f"Interface error: {e}")
451
- yield f"Error: {str(e)}", "", "", "", ""
452
 
453
  interface = gr.Interface(
454
  fn=gradio_interface,
455
- inputs=gr.Video(label="Upload Safety Video"),
456
  outputs=[
457
- gr.Markdown("## Detected Violations"),
458
- gr.Textbox(label="Safety Score"),
459
- gr.Markdown("## Evidence Snapshots"),
460
- gr.Textbox(label="Salesforce Record"),
461
- gr.Textbox(label="Report URL")
462
  ],
463
  title="AI Safety Compliance Analyzer",
464
  description="Detects PPE and safety violations in worksite videos"
 
1
  import os
2
  import sys
 
3
  import logging
4
  import warnings
5
  import cv2
 
15
  from io import BytesIO
16
  import base64
17
  from retrying import retry
 
 
 
18
  from collections import defaultdict
19
+ from multiprocessing import cpu_count
20
 
21
  # ========================== # Configuration and Setup # ==========================
22
  os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
 
26
  logger = logging.getLogger(__name__)
27
  warnings.filterwarnings("ignore")
28
 
29
+ # ========================== # Optimized Tracker Implementation (No Face Recognition) # ==========================
30
  class SafetyTracker:
31
  def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
32
  self.track_thresh = track_thresh
 
34
  self.match_thresh = match_thresh
35
  self.frame_rate = frame_rate
36
  self.next_id = 1
37
+ self.worker_tracks = {}
38
+ self.violation_history = defaultdict(dict)
39
+ self.position_history = defaultdict(list)
40
 
 
 
 
 
 
 
41
  self.VIOLATION_COOLDOWNS = {
42
  "no_helmet": 30.0,
43
  "no_harness": 20.0,
 
46
  "improper_tool_use": 15.0
47
  }
48
 
49
+ def update(self, detections, frame):
 
50
  current_time = time.time()
51
  new_violations = []
52
 
 
55
  label = det['violation']
56
  confidence = det['confidence']
57
 
 
58
  worker_id = self._match_by_position(bbox, label)
 
59
  if worker_id is None:
60
  worker_id = self.next_id
61
  self.next_id += 1
62
 
 
63
  if self._is_new_violation(worker_id, label, current_time):
64
  violation = {
65
  'worker_id': worker_id,
 
71
  new_violations.append(violation)
72
  self.violation_history[worker_id][label] = current_time
73
 
 
 
 
 
 
74
  self.worker_tracks[worker_id] = {
75
  'bbox': bbox,
76
  'last_seen': current_time,
77
  'label': label
78
  }
79
+ self.position_history[worker_id].append((bbox[0], bbox[1]))
80
 
 
81
  self._cleanup_tracks(current_time)
 
82
  return new_violations
83
 
84
  def _match_by_position(self, bbox, label):
85
+ x, y = bbox[0], bbox[1]
 
 
 
86
  for worker_id, positions in self.position_history.items():
 
 
 
 
 
87
  for pos in positions[-5:]: # Check last 5 positions
88
+ if np.sqrt((x-pos[0])**2 + (y-pos[1])**2) < 100:
 
89
  return worker_id
90
  return None
91
 
92
  def _is_new_violation(self, worker_id, label, current_time):
 
93
  if label not in self.violation_history[worker_id]:
94
  return True
95
+ return (current_time - self.violation_history[worker_id][label]) > self.VIOLATION_COOLDOWNS.get(label, 10.0)
 
 
 
96
 
97
  def _cleanup_tracks(self, current_time):
 
98
  inactive_ids = [
99
+ wid for wid, track in self.worker_tracks.items()
100
  if (current_time - track['last_seen']) > (self.track_buffer / self.frame_rate)
101
  ]
102
+ for wid in inactive_ids:
103
+ self.worker_tracks.pop(wid, None)
104
+ self.position_history.pop(wid, None)
105
+ if (current_time - max(self.violation_history[wid].values(), default=0)) > 300:
106
+ self.violation_history.pop(wid, None)
 
107
 
108
+ # ========================== # Configuration # ==========================
109
  CONFIG = {
110
  "MODEL_PATH": "yolov8_safety.pt",
111
  "FALLBACK_MODEL": "yolov8n.pt",
 
125
  "improper_tool_use": (255, 255, 0)
126
  },
127
  "DISPLAY_NAMES": {
128
+ "no_helmet": "No Helmet",
129
+ "no_harness": "No Harness",
130
  "unsafe_posture": "Unsafe Posture",
131
+ "unsafe_zone": "Unsafe Zone",
132
  "improper_tool_use": "Improper Tool Use"
133
  },
134
  "SF_CREDENTIALS": {
 
146
  "improper_tool_use": 0.3
147
  },
148
  "FRAME_SKIP": 2,
149
+ "BATCH_SIZE": 8,
150
  "SNAPSHOT_QUALITY": 90
151
  }
152
 
 
157
  def load_model():
158
  try:
159
  model_path = CONFIG["MODEL_PATH"] if os.path.exists(CONFIG["MODEL_PATH"]) else CONFIG["FALLBACK_MODEL"]
 
160
  if not os.path.exists(model_path):
 
161
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
162
  return YOLO(model_path).to(device)
163
  except Exception as e:
 
166
 
167
  model = load_model()
168
 
 
 
 
 
169
  def draw_detections(frame, detections):
170
+ annotated = frame.copy()
 
171
  for det in detections:
172
  x, y, w, h = det['bbox']
173
  x1, y1 = int(x-w/2), int(y-h/2)
174
  x2, y2 = int(x+w/2), int(y+h/2)
175
+ color = CONFIG["CLASS_COLORS"][det['violation']]
176
+ cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
177
+ label = f"{CONFIG['DISPLAY_NAMES'][det['violation']]} (Worker {det['worker_id']})"
178
+ cv2.putText(annotated, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 2)
179
+ return annotated
 
 
180
 
181
  def calculate_safety_score(violations):
182
+ penalties = {"no_helmet":25, "no_harness":30, "unsafe_posture":20, "unsafe_zone":35, "improper_tool_use":25}
 
 
 
 
183
  unique_violations = {v['violation'] for v in violations}
184
+ return max(0, 100 - sum(penalties.get(v,0) for v in unique_violations))
185
 
 
186
  def generate_violation_pdf(violations, score):
187
  try:
188
  pdf_buffer = BytesIO()
 
192
  c.setFont("Helvetica-Bold", 16)
193
  c.drawString(1*inch, 10*inch, "Safety Violation Report")
194
  c.setFont("Helvetica", 12)
195
+ c.drawString(1*inch, 9.5*inch, f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}")
196
  c.drawString(1*inch, 9*inch, f"Safety Score: {score}%")
197
 
198
+ # Violations List
199
  y = 8.5*inch
200
  c.setFont("Helvetica-Bold", 14)
201
  c.drawString(1*inch, y, "Violations Detected:")
202
  y -= 0.3*inch
 
203
  c.setFont("Helvetica", 10)
204
+
205
  for v in violations:
206
  text = f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']]} at {v['timestamp']:.1f}s"
207
  c.drawString(1.2*inch, y, text)
 
209
  if y < 1*inch:
210
  c.showPage()
211
  y = 10*inch
212
+ c.setFont("Helvetica", 10)
213
 
214
  c.save()
215
  pdf_buffer.seek(0)
 
228
  def connect_to_salesforce():
229
  try:
230
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
231
+ sf.describe()
232
  return sf
233
  except Exception as e:
234
  logger.error(f"Salesforce connection failed: {e}")
235
  raise
236
 
237
+ def upload_to_salesforce(sf, pdf_file, violations, score):
238
  try:
 
 
239
  # Create record
240
+ violations_text = "\n".join(
241
+ f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']]} at {v['timestamp']:.1f}s"
242
+ for v in violations
243
+ )
244
+
245
+ record = sf.Safety_Video_Report__c.create({
246
  "Compliance_Score__c": score,
247
  "Violations_Found__c": len(violations),
248
+ "Violations_Details__c": violations_text
249
+ })
 
 
 
 
250
 
251
+ # Upload PDF
252
+ encoded = base64.b64encode(pdf_file.getvalue()).decode()
253
+ content = sf.ContentVersion.create({
254
+ "Title": f"Safety_Report_{int(time.time())}",
255
+ "PathOnClient": "report.pdf",
256
+ "VersionData": encoded,
257
+ "FirstPublishLocationId": record['id']
258
+ })
259
 
260
+ return record['id'], f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content['id']}"
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  except Exception as e:
262
  logger.error(f"Salesforce upload failed: {e}")
263
+ return None, None
264
 
265
  # ========================== # Video Processing # ==========================
266
  def process_video(video_data):
267
  try:
268
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
269
+ video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
 
 
270
  with open(video_path, "wb") as f:
271
  f.write(video_data)
272
 
273
  cap = cv2.VideoCapture(video_path)
 
 
 
274
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
 
275
  tracker = SafetyTracker(frame_rate=fps)
276
  snapshots = []
277
 
278
+ while cap.isOpened():
 
279
  ret, frame = cap.read()
280
  if not ret:
281
  break
282
 
283
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
284
+ results = model(frame, verbose=False)
 
 
 
 
 
285
 
 
286
  detections = []
287
+ for box in results[0].boxes:
288
  cls = int(box.cls)
289
  label = CONFIG["VIOLATION_LABELS"].get(cls)
290
  if label and box.conf > CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.3):
 
294
  'confidence': float(box.conf)
295
  })
296
 
297
+ new_violations = tracker.update(detections, frame)
 
298
 
 
299
  for violation in new_violations:
300
  snapshot = draw_detections(frame.copy(), [violation])
301
  timestamp = time.strftime("%Y%m%d_%H%M%S")
302
+ img_path = os.path.join(CONFIG["OUTPUT_DIR"], f"violation_{violation['worker_id']}_{timestamp}.jpg")
303
+ cv2.imwrite(img_path, cv2.cvtColor(snapshot, cv2.COLOR_RGB2BGR),
304
+ [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
 
 
305
  snapshots.append({
306
  'path': img_path,
307
  'url': f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(img_path)}",
308
  'violation': violation
309
  })
 
 
 
 
 
310
 
311
+ yield f"Processing frame {int(cap.get(cv2.CAP_PROP_POS_FRAMES))}...", "", "", "", ""
312
+
313
  cap.release()
314
  os.remove(video_path)
315
 
316
+ if not snapshots:
317
+ yield "No violations detected", "Safety Score: 100%", "No snapshots", "N/A", "N/A"
 
 
 
 
 
 
 
 
 
 
 
318
  return
319
 
320
+ score = calculate_safety_score([v['violation'] for v in snapshots])
321
+ pdf_path, pdf_url, pdf_file = generate_violation_pdf([v['violation'] for v in snapshots], score)
 
 
 
 
 
 
 
 
322
 
323
+ if pdf_file:
324
+ record_id, sf_url = upload_to_salesforce(connect_to_salesforce(), pdf_file,
325
+ [v['violation'] for v in snapshots], score)
326
+ else:
327
+ record_id, sf_url = None, None
328
+
329
+ snapshots_md = "\n".join(
330
+ f"![{v['violation']['violation']}]({v['url']})"
331
+ for v in snapshots
332
  )
333
 
334
  yield (
335
+ "\n".join(f"- {v['violation']['violation']} (Worker {v['violation']['worker_id']})" for v in snapshots),
336
  f"Safety Score: {score}%",
337
+ snapshots_md,
338
  f"Salesforce ID: {record_id or 'N/A'}",
339
+ sf_url or pdf_url or "N/A"
340
  )
341
 
342
  except Exception as e:
 
346
  yield f"Error: {str(e)}", "", "", "", ""
347
 
348
  # ========================== # Gradio Interface # ==========================
349
+ def gradio_interface(video):
350
+ if not video:
351
  return "Upload a video file", "", "", "", ""
352
+
353
+ for update in process_video(open(video, "rb").read()):
354
+ yield update
 
 
 
 
 
 
 
 
355
 
356
  interface = gr.Interface(
357
  fn=gradio_interface,
358
+ inputs=gr.Video(),
359
  outputs=[
360
+ gr.Markdown("Detected Violations"),
361
+ gr.Textbox("Safety Score"),
362
+ gr.Markdown("Evidence Snapshots"),
363
+ gr.Textbox("Salesforce Record"),
364
+ gr.Textbox("Report URL")
365
  ],
366
  title="AI Safety Compliance Analyzer",
367
  description="Detects PPE and safety violations in worksite videos"