PrashanthB461 commited on
Commit
a0709a7
·
verified ·
1 Parent(s): 6104e09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +234 -311
app.py CHANGED
@@ -15,7 +15,7 @@ import logging
15
  from retrying import retry
16
 
17
  # ==========================
18
- # Enhanced Configuration
19
  # ==========================
20
  CONFIG = {
21
  "MODEL_PATH": "yolov8_safety.pt",
@@ -36,10 +36,10 @@ CONFIG = {
36
  "improper_tool_use": (255, 255, 0) # Yellow
37
  },
38
  "DISPLAY_NAMES": {
39
- "no_helmet": "No Helmet Violation",
40
- "no_harness": "No Harness Violation",
41
- "unsafe_posture": "Unsafe Posture Violation",
42
- "unsafe_zone": "Unsafe Zone Entry",
43
  "improper_tool_use": "Improper Tool Use"
44
  },
45
  "SF_CREDENTIALS": {
@@ -49,217 +49,206 @@ CONFIG = {
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
- "FRAME_SKIP": 5, # Reduced for better detection
53
  "MAX_PROCESSING_TIME": 60,
54
- "CONFIDENCE_THRESHOLD": 0.25, # Lower threshold for all violations
 
 
 
 
 
 
55
  "IOU_THRESHOLD": 0.4,
56
- "MIN_VIOLATION_FRAMES": 3 # Minimum consecutive frames to confirm violation
 
57
  }
58
 
59
- # Setup logging
60
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
61
  logger = logging.getLogger(__name__)
62
-
63
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
64
 
 
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  logger.info(f"Using device: {device}")
67
 
68
  def load_model():
69
  try:
70
- if os.path.isfile(CONFIG["MODEL_PATH"]):
71
- model_path = CONFIG["MODEL_PATH"]
72
- logger.info(f"Model loaded: {model_path}")
73
  else:
74
- model_path = CONFIG["FALLBACK_MODEL"]
75
- logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
76
- if not os.path.isfile(model_path):
77
- logger.info(f"Downloading fallback model: {model_path}")
78
- torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
79
- model = YOLO(model_path).to(device)
80
  return model
81
  except Exception as e:
82
- logger.error(f"Failed to load model: {e}")
83
  raise
84
 
85
  model = load_model()
86
 
87
- # ==========================
88
- # Enhanced Helper Functions
89
- # ==========================
90
  def draw_detections(frame, detections):
91
- """Draw bounding boxes and labels on frame"""
92
  for det in detections:
93
  label = det["violation"]
94
- confidence = det["confidence"]
95
- x, y, w, h = det["bounding_box"]
96
 
97
- # Convert from center coordinates to corner coordinates
98
- x1 = int(x - w/2)
99
- y1 = int(y - h/2)
100
- x2 = int(x + w/2)
101
- y2 = int(y + h/2)
102
 
103
- color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
104
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
105
 
106
- display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
107
- cv2.putText(frame, display_text, (x1, y1-10),
108
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
 
109
  return frame
110
 
111
  def calculate_iou(box1, box2):
112
- """Calculate Intersection over Union (IoU) for two bounding boxes."""
113
- x1, y1, w1, h1 = box1
114
- x2, y2, w2, h2 = box2
115
 
116
- # Convert to top-left and bottom-right coordinates
117
- x1_min, y1_min = x1 - w1/2, y1 - h1/2
118
- x1_max, y1_max = x1 + w1/2, y1 + h1/2
119
- x2_min, y2_min = x2 - w2/2, y2 - h2/2
120
- x2_max, y2_max = x2 + w2/2, y2 + h2/2
121
 
122
- # Calculate intersection
123
- x_min = max(x1_min, x2_min)
124
- y_min = max(y1_min, y2_min)
125
- x_max = min(x1_max, x2_max)
126
- y_max = min(y1_max, y2_max)
127
 
128
- intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
129
- area1 = w1 * h1
130
- area2 = w2 * h2
131
- union = area1 + area2 - intersection
132
 
133
- return intersection / union if union > 0 else 0
134
-
135
- # ==========================
136
- # Salesforce Integration (unchanged)
137
- # ==========================
138
- @retry(stop_max_attempt_number=3, wait_fixed=2000)
139
- def connect_to_salesforce():
140
- try:
141
- sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
142
- logger.info("Connected to Salesforce")
143
- sf.describe()
144
- return sf
145
- except Exception as e:
146
- logger.error(f"Salesforce connection failed: {e}")
147
- raise
148
 
149
  def generate_violation_pdf(violations, score):
150
  try:
151
  pdf_filename = f"violations_{int(time.time())}.pdf"
152
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
153
  pdf_file = BytesIO()
 
154
  c = canvas.Canvas(pdf_file, pagesize=letter)
 
 
155
  c.setFont("Helvetica", 12)
156
- c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
157
- c.setFont("Helvetica", 10)
158
-
159
- y_position = 9.5 * inch
160
- report_data = {
161
- "Compliance Score": f"{score}%",
162
- "Violations Found": len(violations),
163
- "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
164
- }
165
- for key, value in report_data.items():
166
  c.drawString(1 * inch, y_position, f"{key}: {value}")
167
- y_position -= 0.3 * inch
168
-
 
 
169
  y_position -= 0.3 * inch
 
 
170
  c.drawString(1 * inch, y_position, "Violation Details:")
171
  y_position -= 0.3 * inch
 
 
172
  if not violations:
173
  c.drawString(1 * inch, y_position, "No violations detected.")
174
  else:
175
  for v in violations:
176
- display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
177
- text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
178
- c.drawString(1 * inch, y_position, text)
179
- y_position -= 0.3 * inch
 
 
 
180
  if y_position < 1 * inch:
181
  c.showPage()
182
- c.setFont("Helvetica", 10)
183
  y_position = 10 * inch
184
-
185
- c.showPage()
186
  c.save()
187
  pdf_file.seek(0)
188
-
189
  with open(pdf_path, "wb") as f:
190
  f.write(pdf_file.getvalue())
 
191
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
192
- logger.info(f"PDF generated: {public_url}")
193
  return pdf_path, public_url, pdf_file
 
194
  except Exception as e:
195
- logger.error(f"Error generating PDF: {e}")
196
  return "", "", None
197
 
 
 
 
 
 
 
 
 
 
 
198
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
199
  try:
200
- if not pdf_file:
201
- logger.error("No PDF file provided for upload")
202
- return ""
203
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
204
- content_version_data = {
205
- "Title": f"Safety_Violation_Report_{int(time.time())}",
206
- "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
207
  "VersionData": encoded_pdf,
208
  "FirstPublishLocationId": report_id
209
- }
210
- content_version = sf.ContentVersion.create(content_version_data)
211
- result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
212
- if not result['records']:
213
- logger.error("Failed to retrieve ContentVersion")
214
- return ""
215
- file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
216
- logger.info(f"PDF uploaded to Salesforce: {file_url}")
217
- return file_url
218
  except Exception as e:
219
- logger.error(f"Error uploading PDF to Salesforce: {e}")
220
  return ""
221
 
222
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
223
  try:
224
  sf = connect_to_salesforce()
 
225
  violations_text = "\n".join(
226
- f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
 
227
  for v in violations
228
- ) or "No violations detected."
229
- pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
230
-
231
  record_data = {
232
  "Compliance_Score__c": score,
233
  "Violations_Found__c": len(violations),
234
  "Violations_Details__c": violations_text,
235
- "Status__c": "Pending",
236
- "PDF_Report_URL__c": pdf_url
237
  }
238
- logger.info(f"Creating Salesforce record with data: {record_data}")
239
  try:
240
  record = sf.Safety_Video_Report__c.create(record_data)
241
- logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
 
242
  except Exception as e:
243
- logger.error(f"Failed to create Safety_Video_Report__c: {e}")
244
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
245
- logger.warning(f"Fell back to Account record: {record['id']}")
246
- record_id = record["id"]
247
-
 
248
  if pdf_file:
249
- uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
250
- if uploaded_url:
251
  try:
252
- sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
253
- logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
254
- except Exception as e:
255
- logger.error(f"Failed to update Safety_Video_Report__c: {e}")
256
- sf.Account.update(record_id, {"Description": uploaded_url})
257
- logger.info(f"Updated Account record {record_id} with PDF URL")
258
- pdf_url = uploaded_url
259
-
260
- return record_id, pdf_url
261
  except Exception as e:
262
- logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
263
  return None, ""
264
 
265
  def calculate_safety_score(violations):
@@ -270,265 +259,199 @@ def calculate_safety_score(violations):
270
  "unsafe_zone": 35,
271
  "improper_tool_use": 25
272
  }
273
- # Count unique violations per worker
274
- unique_violations = set()
275
- for v in violations:
276
- key = (v["worker_id"], v["violation"])
277
- unique_violations.add(key)
278
-
279
- total_penalty = sum(penalties.get(violation, 0) for _, violation in unique_violations)
280
- score = 100 - total_penalty
281
- return max(score, 0)
282
 
283
- # ==========================
284
- # Enhanced Video Processing
285
- # ==========================
286
  def process_video(video_data):
287
  try:
288
- video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
289
- with open(video_path, "wb") as f:
290
  f.write(video_data)
291
- logger.info(f"Video saved: {video_path}")
292
-
293
- video = cv2.VideoCapture(video_path)
294
- if not video.isOpened():
295
- raise ValueError("Could not open video file")
296
-
 
297
  violations = []
298
  snapshots = []
 
 
 
299
  frame_count = 0
300
  start_time = time.time()
301
- fps = video.get(cv2.CAP_PROP_FPS)
302
- if fps <= 0:
303
- fps = 30 # Default assumption if FPS cannot be determined
304
-
305
- # Structure to track workers and their violations
306
- workers = []
307
- violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
308
- snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
309
-
310
- logger.info(f"Processing video with FPS: {fps}")
311
- logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
312
-
313
- while True:
314
- ret, frame = video.read()
315
  if not ret:
316
  break
317
-
318
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
319
  frame_count += 1
320
  continue
321
-
322
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
323
- logger.info("Processing time limit reached")
324
  break
325
-
326
  current_time = frame_count / fps
 
327
 
328
- # Run detection on this frame
329
- results = model(frame, device=device)
330
-
331
- current_detections = []
332
  for result in results:
333
- boxes = result.boxes
334
- for box in boxes:
335
  cls = int(box.cls)
336
  conf = float(box.conf)
337
- label = CONFIG["VIOLATION_LABELS"].get(cls, None)
338
 
339
- if label is None:
340
  continue
341
 
342
- if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
343
- continue
344
-
345
- bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
 
347
- current_detections.append({
348
  "frame": frame_count,
349
  "violation": label,
350
  "confidence": round(conf, 2),
351
  "bounding_box": bbox,
352
- "timestamp": current_time
353
- })
354
-
355
- # Process detections and associate with workers
356
- for detection in current_detections:
357
- # Find matching worker
358
- matched_worker = None
359
- max_iou = 0
360
-
361
- for worker in workers:
362
- iou = calculate_iou(detection["bounding_box"], worker["bbox"])
363
- if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
364
- max_iou = iou
365
- matched_worker = worker
366
-
367
- if matched_worker:
368
- # Update worker's position
369
- matched_worker["bbox"] = detection["bounding_box"]
370
- matched_worker["last_seen"] = current_time
371
- worker_id = matched_worker["id"]
372
- else:
373
- # New worker
374
- worker_id = len(workers) + 1
375
- workers.append({
376
- "id": worker_id,
377
- "bbox": detection["bounding_box"],
378
- "first_seen": current_time,
379
- "last_seen": current_time
380
  })
381
-
382
- # Add to violation history
383
- detection["worker_id"] = worker_id
384
- violation_history[detection["violation"]].append(detection)
385
-
386
  frame_count += 1
387
-
388
- video.release()
389
- os.remove(video_path)
390
 
391
- # Process violation history to confirm persistent violations
392
  for violation_type, detections in violation_history.items():
393
  if not detections:
394
  continue
395
 
396
- # Group by worker
397
- worker_violations = {}
398
  for det in detections:
399
- if det["worker_id"] not in worker_violations:
400
- worker_violations[det["worker_id"]] = []
401
- worker_violations[det["worker_id"]].append(det)
402
 
403
- # Check each worker's violations for persistence
404
- for worker_id, worker_dets in worker_violations.items():
405
- if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
406
- # Take the highest confidence detection
407
- best_detection = max(worker_dets, key=lambda x: x["confidence"])
408
- violations.append(best_detection)
 
 
409
 
410
- # Capture snapshot if not already taken
411
  if not snapshot_taken[violation_type]:
412
- # Get the frame for this violation
413
- cap = cv2.VideoCapture(video_path)
414
- cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
415
  ret, snapshot_frame = cap.read()
416
- cap.release()
417
-
418
  if ret:
419
- # Draw detections on snapshot
420
- snapshot_frame = draw_detections(snapshot_frame, [best_detection])
421
-
422
- snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
423
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
424
- cv2.imwrite(snapshot_path, snapshot_frame)
425
  snapshots.append({
426
  "violation": violation_type,
427
- "frame": best_detection["frame"],
428
- "snapshot_path": snapshot_path,
429
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
430
  })
431
  snapshot_taken[violation_type] = True
432
-
433
- # Final processing
434
- if not violations:
435
- logger.info("No persistent violations detected")
436
- return {
437
- "violations": [],
438
- "snapshots": [],
439
- "score": 100,
440
- "salesforce_record_id": None,
441
- "violation_details_url": "",
442
- "message": "No violations detected in the video."
443
- }
444
-
445
  score = calculate_safety_score(violations)
446
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
447
- report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
448
-
449
  return {
450
  "violations": violations,
451
  "snapshots": snapshots,
452
  "score": score,
453
- "salesforce_record_id": report_id,
454
- "violation_details_url": final_pdf_url,
455
  "message": ""
456
  }
 
457
  except Exception as e:
458
- logger.error(f"Error processing video: {e}", exc_info=True)
459
  return {
460
  "violations": [],
461
  "snapshots": [],
462
  "score": 100,
463
  "salesforce_record_id": None,
464
  "violation_details_url": "",
465
- "message": f"Error processing video: {e}"
466
  }
467
 
468
- # ==========================
469
- # Gradio Interface
470
- # ==========================
471
  def gradio_interface(video_file):
472
- if not video_file:
473
- return "No file uploaded.", "", "No file uploaded.", "", ""
474
  try:
475
- yield "Processing video... please wait.", "", "", "", ""
476
-
477
  with open(video_file, "rb") as f:
478
- video_data = f.read()
479
-
480
- result = process_video(video_data)
481
-
482
- if result.get("message"):
483
- yield result["message"], "", "", "", ""
484
- return
485
-
486
- violation_table = "No violations detected."
487
- if result["violations"]:
488
- header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
489
- separator = "|------------------------|---------------|------------|-----------|\n"
490
- rows = []
491
- violation_name_map = CONFIG["DISPLAY_NAMES"]
492
- for v in result["violations"]:
493
- display_name = violation_name_map.get(v["violation"], v["violation"])
494
- row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
495
- rows.append(row)
496
- violation_table = header + separator + "\n".join(rows)
497
-
498
- snapshots_text = "No snapshots captured."
499
- if result["snapshots"]:
500
- violation_name_map = CONFIG["DISPLAY_NAMES"]
501
- snapshots_text = "\n".join(
502
- f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
503
- for s in result["snapshots"]
504
  )
505
-
 
 
 
 
 
 
506
  yield (
507
  violation_table,
508
  f"Safety Score: {result['score']}%",
509
- snapshots_text,
510
- f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
511
- result["violation_details_url"] or "N/A"
512
  )
513
  except Exception as e:
514
- logger.error(f"Error in Gradio interface: {e}", exc_info=True)
515
- yield f"Error: {str(e)}", "", "Error in processing.", "", ""
516
 
517
  interface = gr.Interface(
518
  fn=gradio_interface,
519
  inputs=gr.Video(label="Upload Site Video"),
520
- outputs=[
521
- gr.Markdown(label="Detected Safety Violations"),
522
  gr.Textbox(label="Compliance Score"),
523
- gr.Markdown(label="Snapshots"),
524
- gr.Textbox(label="Salesforce Record ID"),
525
- gr.Textbox(label="Violation Details URL")
526
  ],
527
- title="Worksite Safety Violation Analyzer",
528
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
529
- allow_flagging="never"
530
  )
531
 
532
  if __name__ == "__main__":
533
- logger.info("Launching Enhanced Safety Analyzer App...")
534
- interface.launch()
 
15
  from retrying import retry
16
 
17
  # ==========================
18
+ # OPTIMIZED CONFIGURATION
19
  # ==========================
20
  CONFIG = {
21
  "MODEL_PATH": "yolov8_safety.pt",
 
36
  "improper_tool_use": (255, 255, 0) # Yellow
37
  },
38
  "DISPLAY_NAMES": {
39
+ "no_helmet": "No Helmet",
40
+ "no_harness": "No Harness",
41
+ "unsafe_posture": "Unsafe Posture",
42
+ "unsafe_zone": "Unsafe Zone",
43
  "improper_tool_use": "Improper Tool Use"
44
  },
45
  "SF_CREDENTIALS": {
 
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
+ "FRAME_SKIP": 3,
53
  "MAX_PROCESSING_TIME": 60,
54
+ "CONFIDENCE_THRESHOLD": {
55
+ "no_helmet": 0.4,
56
+ "no_harness": 0.3,
57
+ "unsafe_posture": 0.25,
58
+ "unsafe_zone": 0.3,
59
+ "improper_tool_use": 0.35
60
+ },
61
  "IOU_THRESHOLD": 0.4,
62
+ "MIN_VIOLATION_FRAMES": 3,
63
+ "MIN_VIOLATION_DURATION": 1.5
64
  }
65
 
66
+ # Initialize logging
67
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
68
  logger = logging.getLogger(__name__)
 
69
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
70
 
71
+ # Device configuration
72
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
  logger.info(f"Using device: {device}")
74
 
75
  def load_model():
76
  try:
77
+ if os.path.exists(CONFIG["MODEL_PATH"]):
78
+ model = YOLO(CONFIG["MODEL_PATH"]).to(device)
79
+ logger.info("Loaded custom safety model")
80
  else:
81
+ model = YOLO(CONFIG["FALLBACK_MODEL"]).to(device)
82
+ logger.warning("Using fallback model - recommend training yolov8_safety.pt")
 
 
 
 
83
  return model
84
  except Exception as e:
85
+ logger.error(f"Model loading failed: {str(e)}")
86
  raise
87
 
88
  model = load_model()
89
 
 
 
 
90
  def draw_detections(frame, detections):
91
+ """Draw bounding boxes with labels and confidence scores"""
92
  for det in detections:
93
  label = det["violation"]
94
+ x, y, w, h = [int(v) for v in det["bounding_box"]]
95
+ color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
96
 
97
+ x1, y1 = int(x - w/2), int(y - h/2)
98
+ x2, y2 = int(x + w/2), int(y + h/2)
 
 
 
99
 
 
100
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
101
 
102
+ label_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {det['confidence']:.2f}"
103
+ (text_width, text_height), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
104
+ cv2.rectangle(frame, (x1, y1 - text_height - 10), (x1 + text_width, y1), color, -1)
105
+ cv2.putText(frame, label_text, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
106
  return frame
107
 
108
  def calculate_iou(box1, box2):
109
+ """Calculate Intersection over Union for two bounding boxes"""
110
+ box1 = [box1[0] - box1[2]/2, box1[1] - box1[3]/2, box1[0] + box1[2]/2, box1[1] + box1[3]/2]
111
+ box2 = [box2[0] - box2[2]/2, box2[1] - box2[3]/2, box2[0] + box2[2]/2, box2[1] + box2[3]/2]
112
 
113
+ x_left = max(box1[0], box2[0])
114
+ y_top = max(box1[1], box2[1])
115
+ x_right = min(box1[2], box2[2])
116
+ y_bottom = min(box1[3], box2[3])
 
117
 
118
+ if x_right < x_left or y_bottom < y_top:
119
+ return 0.0
 
 
 
120
 
121
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
122
+ box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
123
+ box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
 
124
 
125
+ return intersection_area / float(box1_area + box2_area - intersection_area)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  def generate_violation_pdf(violations, score):
128
  try:
129
  pdf_filename = f"violations_{int(time.time())}.pdf"
130
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
131
  pdf_file = BytesIO()
132
+
133
  c = canvas.Canvas(pdf_file, pagesize=letter)
134
+ c.setFont("Helvetica-Bold", 14)
135
+ c.drawString(1 * inch, 10.5 * inch, "Worksite Safety Violation Report")
136
  c.setFont("Helvetica", 12)
137
+
138
+ y_position = 10 * inch
139
+ report_data = [
140
+ ("Compliance Score", f"{score}%"),
141
+ ("Total Violations", len(violations)),
142
+ ("Report Date", time.strftime("%Y-%m-%d %H:%M:%S"))
143
+ ]
144
+
145
+ for key, value in report_data:
 
146
  c.drawString(1 * inch, y_position, f"{key}: {value}")
147
+ y_position -= 0.4 * inch
148
+
149
+ y_position -= 0.2 * inch
150
+ c.line(1 * inch, y_position, 7.5 * inch, y_position)
151
  y_position -= 0.3 * inch
152
+
153
+ c.setFont("Helvetica-Bold", 12)
154
  c.drawString(1 * inch, y_position, "Violation Details:")
155
  y_position -= 0.3 * inch
156
+ c.setFont("Helvetica", 10)
157
+
158
  if not violations:
159
  c.drawString(1 * inch, y_position, "No violations detected.")
160
  else:
161
  for v in violations:
162
+ violation_text = (
163
+ f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
164
+ f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f}, "
165
+ f"Worker: {v['worker_id']})"
166
+ )
167
+ c.drawString(1 * inch, y_position, violation_text)
168
+ y_position -= 0.25 * inch
169
  if y_position < 1 * inch:
170
  c.showPage()
 
171
  y_position = 10 * inch
172
+ c.setFont("Helvetica", 10)
173
+
174
  c.save()
175
  pdf_file.seek(0)
176
+
177
  with open(pdf_path, "wb") as f:
178
  f.write(pdf_file.getvalue())
179
+
180
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
181
+ logger.info(f"Generated PDF report: {public_url}")
182
  return pdf_path, public_url, pdf_file
183
+
184
  except Exception as e:
185
+ logger.error(f"PDF generation failed: {str(e)}")
186
  return "", "", None
187
 
188
+ @retry(stop_max_attempt_number=3, wait_fixed=2000)
189
+ def connect_to_salesforce():
190
+ try:
191
+ sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
192
+ logger.info("Connected to Salesforce")
193
+ return sf
194
+ except Exception as e:
195
+ logger.error(f"Salesforce connection failed: {str(e)}")
196
+ raise
197
+
198
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
199
  try:
 
 
 
200
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
201
+ content_version = sf.ContentVersion.create({
202
+ "Title": f"Safety_Report_{int(time.time())}",
203
+ "PathOnClient": "safety_report.pdf",
204
  "VersionData": encoded_pdf,
205
  "FirstPublishLocationId": report_id
206
+ })
207
+ return f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
 
 
 
 
 
 
 
208
  except Exception as e:
209
+ logger.error(f"PDF upload failed: {str(e)}")
210
  return ""
211
 
212
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
213
  try:
214
  sf = connect_to_salesforce()
215
+
216
  violations_text = "\n".join(
217
+ f"- {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
218
+ f"at {v['timestamp']:.2f}s (Worker {v['worker_id']}, Confidence: {v['confidence']:.2f})"
219
  for v in violations
220
+ ) or "No violations detected"
221
+
 
222
  record_data = {
223
  "Compliance_Score__c": score,
224
  "Violations_Found__c": len(violations),
225
  "Violations_Details__c": violations_text,
226
+ "Status__c": "New"
 
227
  }
228
+
229
  try:
230
  record = sf.Safety_Video_Report__c.create(record_data)
231
+ record_id = record["id"]
232
+ logger.info(f"Created Salesforce record: {record_id}")
233
  except Exception as e:
234
+ logger.error(f"Failed to create Safety Report: {str(e)}")
235
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
236
+ record_id = record["id"]
237
+ logger.warning(f"Created fallback Account record: {record_id}")
238
+
239
+ pdf_url = ""
240
  if pdf_file:
241
+ pdf_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
242
+ if pdf_url:
243
  try:
244
+ sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": pdf_url})
245
+ except:
246
+ sf.Account.update(record_id, {"Description": pdf_url})
247
+
248
+ return record_id, pdf_url if pdf_url else f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}"
249
+
 
 
 
250
  except Exception as e:
251
+ logger.error(f"Salesforce integration failed: {str(e)}")
252
  return None, ""
253
 
254
  def calculate_safety_score(violations):
 
259
  "unsafe_zone": 35,
260
  "improper_tool_use": 25
261
  }
262
+ unique_violations = {(v["worker_id"], v["violation"]) for v in violations}
263
+ total_penalty = sum(penalties.get(v[1], 0) for v in unique_violations)
264
+ return max(100 - total_penalty, 0)
 
 
 
 
 
 
265
 
 
 
 
266
  def process_video(video_data):
267
  try:
268
+ temp_video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
269
+ with open(temp_video_path, "wb") as f:
270
  f.write(video_data)
271
+
272
+ cap = cv2.VideoCapture(temp_video_path)
273
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30
274
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
275
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
276
+
277
+ workers = []
278
  violations = []
279
  snapshots = []
280
+ violation_history = {k: [] for k in CONFIG["VIOLATION_LABELS"].values()}
281
+ snapshot_taken = {k: False for k in CONFIG["VIOLATION_LABELS"].values()}
282
+
283
  frame_count = 0
284
  start_time = time.time()
285
+
286
+ while cap.isOpened():
287
+ ret, frame = cap.read()
 
 
 
 
 
 
 
 
 
 
 
288
  if not ret:
289
  break
290
+
291
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
292
  frame_count += 1
293
  continue
294
+
295
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
296
+ logger.warning("Processing timeout reached")
297
  break
298
+
299
  current_time = frame_count / fps
300
+ results = model(frame, device=device, verbose=False)
301
 
 
 
 
 
302
  for result in results:
303
+ for box in result.boxes:
 
304
  cls = int(box.cls)
305
  conf = float(box.conf)
306
+ label = CONFIG["VIOLATION_LABELS"].get(cls)
307
 
308
+ if not label or conf < CONFIG["CONFIDENCE_THRESHOLD"].get(label, 0.3):
309
  continue
310
 
311
+ bbox = box.xywh.cpu().numpy()[0].tolist()
312
+
313
+ matched_worker = None
314
+ max_iou = 0
315
+ for worker in workers:
316
+ iou = calculate_iou(bbox, worker["bbox"])
317
+ if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
318
+ max_iou = iou
319
+ matched_worker = worker
320
+
321
+ if matched_worker:
322
+ worker_id = matched_worker["id"]
323
+ matched_worker["bbox"] = bbox
324
+ matched_worker["last_seen"] = current_time
325
+ else:
326
+ worker_id = len(workers) + 1
327
+ workers.append({
328
+ "id": worker_id,
329
+ "bbox": bbox,
330
+ "first_seen": current_time,
331
+ "last_seen": current_time
332
+ })
333
 
334
+ violation_history[label].append({
335
  "frame": frame_count,
336
  "violation": label,
337
  "confidence": round(conf, 2),
338
  "bounding_box": bbox,
339
+ "timestamp": current_time,
340
+ "worker_id": worker_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  })
342
+
 
 
 
 
343
  frame_count += 1
 
 
 
344
 
 
345
  for violation_type, detections in violation_history.items():
346
  if not detections:
347
  continue
348
 
349
+ worker_groups = {}
 
350
  for det in detections:
351
+ if det["worker_id"] not in worker_groups:
352
+ worker_groups[det["worker_id"]] = []
353
+ worker_groups[det["worker_id"]].append(det)
354
 
355
+ for worker_id, worker_dets in worker_groups.items():
356
+ if len(worker_dets) < 2:
357
+ continue
358
+
359
+ duration = worker_dets[-1]["timestamp"] - worker_dets[0]["timestamp"]
360
+ if duration >= CONFIG["MIN_VIOLATION_DURATION"]:
361
+ best_det = max(worker_dets, key=lambda x: x["confidence"])
362
+ violations.append(best_det)
363
 
 
364
  if not snapshot_taken[violation_type]:
365
+ cap.set(cv2.CAP_PROP_POS_FRAMES, best_det["frame"])
 
 
366
  ret, snapshot_frame = cap.read()
 
 
367
  if ret:
368
+ snapshot_frame = draw_detections(snapshot_frame, [best_det])
369
+ filename = f"{violation_type}_{best_det['frame']}.jpg"
370
+ path = os.path.join(CONFIG["OUTPUT_DIR"], filename)
371
+ cv2.imwrite(path, snapshot_frame)
 
 
372
  snapshots.append({
373
  "violation": violation_type,
374
+ "frame": best_det["frame"],
375
+ "path": path,
376
+ "url": f"{CONFIG['PUBLIC_URL_BASE']}{filename}"
377
  })
378
  snapshot_taken[violation_type] = True
379
+
380
+ cap.release()
381
+ os.remove(temp_video_path)
382
+
 
 
 
 
 
 
 
 
 
383
  score = calculate_safety_score(violations)
384
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
385
+ record_id, sf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
386
+
387
  return {
388
  "violations": violations,
389
  "snapshots": snapshots,
390
  "score": score,
391
+ "salesforce_record_id": record_id,
392
+ "violation_details_url": sf_url or pdf_url,
393
  "message": ""
394
  }
395
+
396
  except Exception as e:
397
+ logger.error(f"Video processing failed: {str(e)}")
398
  return {
399
  "violations": [],
400
  "snapshots": [],
401
  "score": 100,
402
  "salesforce_record_id": None,
403
  "violation_details_url": "",
404
+ "message": f"Error: {str(e)}"
405
  }
406
 
 
 
 
407
  def gradio_interface(video_file):
 
 
408
  try:
409
+ yield "Analyzing video...", "", "", "", ""
410
+
411
  with open(video_file, "rb") as f:
412
+ result = process_video(f.read())
413
+
414
+ violation_table = (
415
+ "| Violation Type | Timestamp | Confidence | Worker ID |\n"
416
+ "|---------------------|-----------|------------|-----------|\n" +
417
+ "\n".join(
418
+ f"| {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation']):<19} | "
419
+ f"{v['timestamp']:.2f} | "
420
+ f"{v['confidence']:.2f} | "
421
+ f"{v['worker_id']} |"
422
+ for v in result["violations"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  )
424
+ ) if result["violations"] else "No violations detected"
425
+
426
+ snapshots_md = "\n".join(
427
+ f"![{s['violation']} at frame {s['frame']}]({s['url']})"
428
+ for s in result["snapshots"]
429
+ ) if result["snapshots"] else "No snapshots"
430
+
431
  yield (
432
  violation_table,
433
  f"Safety Score: {result['score']}%",
434
+ snapshots_md,
435
+ f"Salesforce ID: {result['salesforce_record_id'] or 'None'}",
436
+ result["violation_details_url"] or "None"
437
  )
438
  except Exception as e:
439
+ logger.error(f"Interface error: {str(e)}")
440
+ yield f"Error: {str(e)}", "", "", "", ""
441
 
442
  interface = gr.Interface(
443
  fn=gradio_interface,
444
  inputs=gr.Video(label="Upload Site Video"),
445
+ outputs=[
446
+ gr.Markdown(label="Violations Detected"),
447
  gr.Textbox(label="Compliance Score"),
448
+ gr.Markdown(label="Evidence Snapshots"),
449
+ gr.Textbox(label="Salesforce Record"),
450
+ gr.Textbox(label="Report URL")
451
  ],
452
+ title="AI Safety Compliance Monitor",
453
+ description="Detects 5 violation types: No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use"
 
454
  )
455
 
456
  if __name__ == "__main__":
457
+ interface.launch(share=True)