PrashanthB461 commited on
Commit
ac968c9
·
verified ·
1 Parent(s): 60c81b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -446
app.py CHANGED
@@ -12,523 +12,295 @@ from reportlab.lib.units import inch
12
  from io import BytesIO
13
  import base64
14
  import logging
15
- from retrying import retry
16
 
17
  # ==========================
18
- # Enhanced Configuration
19
  # ==========================
20
  CONFIG = {
21
  "MODEL_PATH": "yolov8_safety.pt",
22
- "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
- 1: "no_harness",
27
  2: "unsafe_posture",
28
  3: "unsafe_zone",
29
  4: "improper_tool_use"
30
  },
31
  "CLASS_COLORS": {
32
- "no_helmet": (0, 0, 255), # Red
33
- "no_harness": (0, 165, 255), # Orange
34
- "unsafe_posture": (0, 255, 0), # Green
35
- "unsafe_zone": (255, 0, 0), # Blue
36
- "improper_tool_use": (255, 255, 0) # Yellow
37
- },
38
- "DISPLAY_NAMES": {
39
- "no_helmet": "No Helmet Violation",
40
- "no_harness": "No Harness Violation",
41
- "unsafe_posture": "Unsafe Posture Violation",
42
- "unsafe_zone": "Unsafe Zone Entry",
43
- "improper_tool_use": "Improper Tool Use"
44
  },
 
 
 
 
 
45
  "SF_CREDENTIALS": {
46
  "username": "prashanth1ai@safety.com",
47
  "password": "SaiPrash461",
48
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
49
  "domain": "login"
50
  },
51
- "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
- "FRAME_SKIP": 5, # Reduced for better detection
53
- "MAX_PROCESSING_TIME": 60,
54
- "CONFIDENCE_THRESHOLD": 0.25, # Lower threshold for all violations
55
- "IOU_THRESHOLD": 0.4,
56
- "MIN_VIOLATION_FRAMES": 3 # Minimum consecutive frames to confirm violation
57
  }
58
 
59
- # Setup logging
60
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
61
- logger = logging.getLogger(__name__)
62
-
63
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
64
-
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
- logger.info(f"Using device: {device}")
67
-
68
- def load_model():
69
- try:
70
- if os.path.isfile(CONFIG["MODEL_PATH"]):
71
- model_path = CONFIG["MODEL_PATH"]
72
- logger.info(f"Model loaded: {model_path}")
73
- else:
74
- model_path = CONFIG["FALLBACK_MODEL"]
75
- logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
76
- if not os.path.isfile(model_path):
77
- logger.info(f"Downloading fallback model: {model_path}")
78
- torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
79
- model = YOLO(model_path).to(device)
80
- return model
81
- except Exception as e:
82
- logger.error(f"Failed to load model: {e}")
83
- raise
84
 
85
- model = load_model()
 
 
86
 
87
  # ==========================
88
- # Enhanced Helper Functions
89
  # ==========================
90
  def draw_detections(frame, detections):
91
- """Draw bounding boxes and labels on frame"""
92
  for det in detections:
93
- label = det["violation"]
94
- confidence = det["confidence"]
95
- x, y, w, h = det["bounding_box"]
96
-
97
- # Convert from center coordinates to corner coordinates
98
- x1 = int(x - w/2)
99
- y1 = int(y - h/2)
100
- x2 = int(x + w/2)
101
- y2 = int(y + h/2)
102
 
103
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
104
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
105
-
106
- display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
107
- cv2.putText(frame, display_text, (x1, y1-10),
108
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
109
  return frame
110
 
111
- def calculate_iou(box1, box2):
112
- """Calculate Intersection over Union (IoU) for two bounding boxes."""
113
- x1, y1, w1, h1 = box1
114
- x2, y2, w2, h2 = box2
115
 
116
- # Convert to top-left and bottom-right coordinates
117
- x1_min, y1_min = x1 - w1/2, y1 - h1/2
118
- x1_max, y1_max = x1 + w1/2, y1 + h1/2
119
- x2_min, y2_min = x2 - w2/2, y2 - h2/2
120
- x2_max, y2_max = x2 + w2/2, y2 + h2/2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- # Calculate intersection
123
- x_min = max(x1_min, x2_min)
124
- y_min = max(y1_min, y2_min)
125
- x_max = min(x1_max, x2_max)
126
- y_max = min(y1_max, y2_max)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
129
- area1 = w1 * h1
130
- area2 = w2 * h2
131
- union = area1 + area2 - intersection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
- return intersection / union if union > 0 else 0
134
 
135
  # ==========================
136
- # Salesforce Integration (unchanged)
137
  # ==========================
138
- @retry(stop_max_attempt_number=3, wait_fixed=2000)
139
- def connect_to_salesforce():
140
- try:
141
- sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
142
- logger.info("Connected to Salesforce")
143
- sf.describe()
144
- return sf
145
- except Exception as e:
146
- logger.error(f"Salesforce connection failed: {e}")
147
- raise
148
-
149
- def generate_violation_pdf(violations, score):
150
  try:
151
- pdf_filename = f"violations_{int(time.time())}.pdf"
152
- pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
153
  pdf_file = BytesIO()
154
  c = canvas.Canvas(pdf_file, pagesize=letter)
 
 
 
 
155
  c.setFont("Helvetica", 12)
156
- c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
 
 
 
 
 
 
 
 
 
 
 
157
  c.setFont("Helvetica", 10)
158
-
159
- y_position = 9.5 * inch
160
- report_data = {
161
- "Compliance Score": f"{score}%",
162
- "Violations Found": len(violations),
163
- "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
164
- }
165
- for key, value in report_data.items():
166
- c.drawString(1 * inch, y_position, f"{key}: {value}")
167
- y_position -= 0.3 * inch
168
-
169
- y_position -= 0.3 * inch
170
- c.drawString(1 * inch, y_position, "Violation Details:")
171
- y_position -= 0.3 * inch
172
- if not violations:
173
- c.drawString(1 * inch, y_position, "No violations detected.")
174
- else:
175
- for v in violations:
176
- display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
177
- text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
178
- c.drawString(1 * inch, y_position, text)
179
- y_position -= 0.3 * inch
180
- if y_position < 1 * inch:
181
- c.showPage()
182
- c.setFont("Helvetica", 10)
183
- y_position = 10 * inch
184
-
185
- c.showPage()
186
  c.save()
187
  pdf_file.seek(0)
188
-
189
- with open(pdf_path, "wb") as f:
190
- f.write(pdf_file.getvalue())
191
- public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
192
- logger.info(f"PDF generated: {public_url}")
193
- return pdf_path, public_url, pdf_file
194
- except Exception as e:
195
- logger.error(f"Error generating PDF: {e}")
196
- return "", "", None
197
-
198
- def upload_pdf_to_salesforce(sf, pdf_file, report_id):
199
- try:
200
- if not pdf_file:
201
- logger.error("No PDF file provided for upload")
202
- return ""
203
- encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
204
- content_version_data = {
205
- "Title": f"Safety_Violation_Report_{int(time.time())}",
206
- "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
207
- "VersionData": encoded_pdf,
208
- "FirstPublishLocationId": report_id
209
- }
210
- content_version = sf.ContentVersion.create(content_version_data)
211
- result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
212
- if not result['records']:
213
- logger.error("Failed to retrieve ContentVersion")
214
- return ""
215
- file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
216
- logger.info(f"PDF uploaded to Salesforce: {file_url}")
217
- return file_url
218
- except Exception as e:
219
- logger.error(f"Error uploading PDF to Salesforce: {e}")
220
- return ""
221
-
222
- def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
223
- try:
224
- sf = connect_to_salesforce()
225
- violations_text = "\n".join(
226
- f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
227
- for v in violations
228
- ) or "No violations detected."
229
- pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
230
-
231
- record_data = {
232
- "Compliance_Score__c": score,
233
- "Violations_Found__c": len(violations),
234
- "Violations_Details__c": violations_text,
235
- "Status__c": "Pending",
236
- "PDF_Report_URL__c": pdf_url
237
- }
238
- logger.info(f"Creating Salesforce record with data: {record_data}")
239
- try:
240
- record = sf.Safety_Video_Report__c.create(record_data)
241
- logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
242
- except Exception as e:
243
- logger.error(f"Failed to create Safety_Video_Report__c: {e}")
244
- record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
245
- logger.warning(f"Fell back to Account record: {record['id']}")
246
- record_id = record["id"]
247
-
248
- if pdf_file:
249
- uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
250
- if uploaded_url:
251
- try:
252
- sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
253
- logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
254
- except Exception as e:
255
- logger.error(f"Failed to update Safety_Video_Report__c: {e}")
256
- sf.Account.update(record_id, {"Description": uploaded_url})
257
- logger.info(f"Updated Account record {record_id} with PDF URL")
258
- pdf_url = uploaded_url
259
-
260
- return record_id, pdf_url
261
  except Exception as e:
262
- logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
263
- return None, ""
264
-
265
- def calculate_safety_score(violations):
266
- penalties = {
267
- "no_helmet": 25,
268
- "no_harness": 30,
269
- "unsafe_posture": 20,
270
- "unsafe_zone": 35,
271
- "improper_tool_use": 25
272
- }
273
- # Count unique violations per worker
274
- unique_violations = set()
275
- for v in violations:
276
- key = (v["worker_id"], v["violation"])
277
- unique_violations.add(key)
278
-
279
- total_penalty = sum(penalties.get(violation, 0) for _, violation in unique_violations)
280
- score = 100 - total_penalty
281
- return max(score, 0)
282
 
283
  # ==========================
284
- # Enhanced Video Processing
285
  # ==========================
286
- def process_video(video_data):
 
287
  try:
288
- video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
289
- with open(video_path, "wb") as f:
290
- f.write(video_data)
291
- logger.info(f"Video saved: {video_path}")
292
-
293
- video = cv2.VideoCapture(video_path)
294
- if not video.isOpened():
295
- raise ValueError("Could not open video file")
296
-
297
- violations = []
298
- snapshots = []
299
- frame_count = 0
300
- start_time = time.time()
301
- fps = video.get(cv2.CAP_PROP_FPS)
302
- if fps <= 0:
303
- fps = 30 # Default assumption if FPS cannot be determined
304
-
305
- # Structure to track workers and their violations
306
- workers = []
307
- violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
308
- snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
309
-
310
- logger.info(f"Processing video with FPS: {fps}")
311
- logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
312
-
313
- while True:
314
- ret, frame = video.read()
315
- if not ret:
316
- break
317
-
318
- if frame_count % CONFIG["FRAME_SKIP"] != 0:
319
- frame_count += 1
320
- continue
321
-
322
- if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
323
- logger.info("Processing time limit reached")
324
- break
325
-
326
- current_time = frame_count / fps
327
-
328
- # Run detection on this frame
329
- results = model(frame, device=device)
330
-
331
- current_detections = []
332
- for result in results:
333
- boxes = result.boxes
334
- for box in boxes:
335
- cls = int(box.cls)
336
- conf = float(box.conf)
337
- label = CONFIG["VIOLATION_LABELS"].get(cls, None)
338
-
339
- if label is None:
340
- continue
341
-
342
- if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
343
- continue
344
-
345
- bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
346
-
347
- current_detections.append({
348
- "frame": frame_count,
349
- "violation": label,
350
- "confidence": round(conf, 2),
351
- "bounding_box": bbox,
352
- "timestamp": current_time
353
- })
354
-
355
- # Process detections and associate with workers
356
- for detection in current_detections:
357
- # Find matching worker
358
- matched_worker = None
359
- max_iou = 0
360
-
361
- for worker in workers:
362
- iou = calculate_iou(detection["bounding_box"], worker["bbox"])
363
- if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
364
- max_iou = iou
365
- matched_worker = worker
366
-
367
- if matched_worker:
368
- # Update worker's position
369
- matched_worker["bbox"] = detection["bounding_box"]
370
- matched_worker["last_seen"] = current_time
371
- worker_id = matched_worker["id"]
372
- else:
373
- # New worker
374
- worker_id = len(workers) + 1
375
- workers.append({
376
- "id": worker_id,
377
- "bbox": detection["bounding_box"],
378
- "first_seen": current_time,
379
- "last_seen": current_time
380
- })
381
-
382
- # Add to violation history
383
- detection["worker_id"] = worker_id
384
- violation_history[detection["violation"]].append(detection)
385
-
386
- frame_count += 1
387
-
388
- video.release()
389
- os.remove(video_path)
390
 
391
- # Process violation history to confirm persistent violations
392
- for violation_type, detections in violation_history.items():
393
- if not detections:
394
- continue
395
-
396
- # Group by worker
397
- worker_violations = {}
398
- for det in detections:
399
- if det["worker_id"] not in worker_violations:
400
- worker_violations[det["worker_id"]] = []
401
- worker_violations[det["worker_id"]].append(det)
402
-
403
- # Check each worker's violations for persistence
404
- for worker_id, worker_dets in worker_violations.items():
405
- if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
406
- # Take the highest confidence detection
407
- best_detection = max(worker_dets, key=lambda x: x["confidence"])
408
- violations.append(best_detection)
409
-
410
- # Capture snapshot if not already taken
411
- if not snapshot_taken[violation_type]:
412
- # Get the frame for this violation
413
- cap = cv2.VideoCapture(video_path)
414
- cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
415
- ret, snapshot_frame = cap.read()
416
- cap.release()
417
-
418
- if ret:
419
- # Draw detections on snapshot
420
- snapshot_frame = draw_detections(snapshot_frame, [best_detection])
421
-
422
- snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
423
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
424
- cv2.imwrite(snapshot_path, snapshot_frame)
425
- snapshots.append({
426
- "violation": violation_type,
427
- "frame": best_detection["frame"],
428
- "snapshot_path": snapshot_path,
429
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
430
- })
431
- snapshot_taken[violation_type] = True
432
-
433
- # Final processing
434
- if not violations:
435
- logger.info("No persistent violations detected")
436
- return {
437
- "violations": [],
438
- "snapshots": [],
439
- "score": 100,
440
- "salesforce_record_id": None,
441
- "violation_details_url": "",
442
- "message": "No violations detected in the video."
443
- }
444
-
445
- score = calculate_safety_score(violations)
446
- pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
447
- report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
448
-
449
- return {
450
- "violations": violations,
451
- "snapshots": snapshots,
452
- "score": score,
453
- "salesforce_record_id": report_id,
454
- "violation_details_url": final_pdf_url,
455
- "message": ""
456
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  except Exception as e:
458
- logger.error(f"Error processing video: {e}", exc_info=True)
459
- return {
460
- "violations": [],
461
- "snapshots": [],
462
- "score": 100,
463
- "salesforce_record_id": None,
464
- "violation_details_url": "",
465
- "message": f"Error processing video: {e}"
466
- }
467
 
468
  # ==========================
469
  # Gradio Interface
470
  # ==========================
471
- def gradio_interface(video_file):
 
472
  if not video_file:
473
- return "No file uploaded.", "", "No file uploaded.", "", ""
 
474
  try:
475
- yield "Processing video... please wait.", "", "", "", ""
476
-
477
- with open(video_file, "rb") as f:
478
- video_data = f.read()
479
-
480
- result = process_video(video_data)
481
-
482
- if result.get("message"):
483
- yield result["message"], "", "", "", ""
484
- return
485
-
486
- violation_table = "No violations detected."
487
- if result["violations"]:
488
- header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
489
- separator = "|------------------------|---------------|------------|-----------|\n"
490
- rows = []
491
- violation_name_map = CONFIG["DISPLAY_NAMES"]
492
- for v in result["violations"]:
493
- display_name = violation_name_map.get(v["violation"], v["violation"])
494
- row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
495
- rows.append(row)
496
- violation_table = header + separator + "\n".join(rows)
497
-
498
- snapshots_text = "No snapshots captured."
499
- if result["snapshots"]:
500
- violation_name_map = CONFIG["DISPLAY_NAMES"]
501
- snapshots_text = "\n".join(
502
- f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
503
- for s in result["snapshots"]
504
  )
505
-
506
- yield (
507
- violation_table,
508
- f"Safety Score: {result['score']}%",
509
- snapshots_text,
510
- f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
511
- result["violation_details_url"] or "N/A"
512
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  except Exception as e:
514
- logger.error(f"Error in Gradio interface: {e}", exc_info=True)
515
- yield f"Error: {str(e)}", "", "Error in processing.", "", ""
516
 
 
517
  interface = gr.Interface(
518
- fn=gradio_interface,
519
  inputs=gr.Video(label="Upload Site Video"),
520
- outputs=[
521
- gr.Markdown(label="Detected Safety Violations"),
522
- gr.Textbox(label="Compliance Score"),
523
- gr.Markdown(label="Snapshots"),
524
- gr.Textbox(label="Salesforce Record ID"),
525
- gr.Textbox(label="Violation Details URL")
526
  ],
527
- title="Worksite Safety Violation Analyzer",
528
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
 
 
 
529
  allow_flagging="never"
530
  )
531
 
532
  if __name__ == "__main__":
533
- logger.info("Launching Enhanced Safety Analyzer App...")
534
- interface.launch()
 
12
  from io import BytesIO
13
  import base64
14
  import logging
15
+ from concurrent.futures import ThreadPoolExecutor
16
 
17
  # ==========================
18
+ # Optimized Configuration
19
  # ==========================
20
  CONFIG = {
21
  "MODEL_PATH": "yolov8_safety.pt",
 
22
  "OUTPUT_DIR": "static/output",
23
  "VIOLATION_LABELS": {
24
  0: "no_helmet",
25
+ 1: "no_harness",
26
  2: "unsafe_posture",
27
  3: "unsafe_zone",
28
  4: "improper_tool_use"
29
  },
30
  "CLASS_COLORS": {
31
+ "no_helmet": (0, 0, 255),
32
+ "no_harness": (0, 165, 255),
33
+ "unsafe_posture": (0, 255, 0),
34
+ "unsafe_zone": (255, 0, 0),
35
+ "improper_tool_use": (255, 255, 0)
 
 
 
 
 
 
 
36
  },
37
+ "FRAME_SKIP": 8, # Balanced speed/accuracy
38
+ "CONFIDENCE_THRESHOLD": 0.35,
39
+ "MIN_DETECTIONS": 2,
40
+ "MAX_WORKERS": 4,
41
+ "PROCESSING_RESOLUTION": (640, 640),
42
  "SF_CREDENTIALS": {
43
  "username": "prashanth1ai@safety.com",
44
  "password": "SaiPrash461",
45
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
46
  "domain": "login"
47
  },
48
+ "MAX_PROCESSING_SECONDS": 30 # Timeout
 
 
 
 
 
49
  }
50
 
51
+ # Setup
 
 
 
52
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
 
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+ logging.basicConfig(level=logging.INFO)
55
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ # Load and warm up model
58
+ model = YOLO(CONFIG["MODEL_PATH"]).to(device)
59
+ model.warmup(imgsz=[1, 3, *CONFIG["PROCESSING_RESOLUTION"]])
60
 
61
  # ==========================
62
+ # Core Processing Functions
63
  # ==========================
64
  def draw_detections(frame, detections):
65
+ """Draw bounding boxes on frame"""
66
  for det in detections:
67
+ label = det["label"]
68
+ x, y, w, h = det["bbox"]
69
+ x1, y1 = int(x - w/2), int(y - h/2)
70
+ x2, y2 = int(x + w/2), int(y + h/2)
 
 
 
 
 
71
 
72
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
73
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
74
+ cv2.putText(frame, f"{label}: {det['confidence']:.2f}",
75
+ (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
 
 
76
  return frame
77
 
78
+ def process_frame(frame, frame_count, fps):
79
+ """Process single frame for violations"""
80
+ resized = cv2.resize(frame, CONFIG["PROCESSING_RESOLUTION"])
81
+ results = model(resized, verbose=False)
82
 
83
+ frame_violations = []
84
+ for result in results:
85
+ for box in result.boxes:
86
+ conf = box.conf.item()
87
+ if conf > CONFIG["CONFIDENCE_THRESHOLD"]:
88
+ label = CONFIG["VIOLATION_LABELS"].get(int(box.cls.item()))
89
+ if label:
90
+ frame_violations.append({
91
+ "label": label,
92
+ "confidence": conf,
93
+ "timestamp": frame_count / fps,
94
+ "frame": frame_count,
95
+ "bbox": box.xywh.cpu().numpy()[0]
96
+ })
97
+ return frame_violations
98
+
99
+ def process_video(video_path):
100
+ """Optimized video processing pipeline"""
101
+ start_time = time.time()
102
+ cap = cv2.VideoCapture(video_path)
103
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30
104
+ frame_count = 0
105
+ all_violations = []
106
+ snapshots = {}
107
+ last_update = time.time()
108
 
109
+ with ThreadPoolExecutor(max_workers=CONFIG["MAX_WORKERS"]) as executor:
110
+ futures = []
111
+
112
+ while cap.isOpened():
113
+ if time.time() - start_time > CONFIG["MAX_PROCESSING_SECONDS"]:
114
+ logger.warning("Processing timeout reached")
115
+ break
116
+
117
+ ret, frame = cap.read()
118
+ if not ret:
119
+ break
120
+
121
+ if frame_count % CONFIG["FRAME_SKIP"] == 0:
122
+ futures.append(executor.submit(
123
+ process_frame,
124
+ frame.copy(),
125
+ frame_count,
126
+ fps
127
+ ))
128
+
129
+ frame_count += 1
130
+
131
+ # Process completed frames
132
+ while futures and futures[0].done():
133
+ for violation in futures.pop(0).result():
134
+ all_violations.append(violation)
135
+ if violation["label"] not in snapshots:
136
+ snapshots[violation["label"]] = {
137
+ "frame": draw_detections(frame.copy(), [violation]),
138
+ "timestamp": violation["timestamp"],
139
+ "confidence": violation["confidence"]
140
+ }
141
+
142
+ cap.release()
143
 
144
+ # Filter violations
145
+ confirmed_violations = []
146
+ violation_counts = {}
147
+ for v in all_violations:
148
+ key = v["label"]
149
+ violation_counts[key] = violation_counts.get(key, 0) + 1
150
+
151
+ for v in all_violations:
152
+ if violation_counts[v["label"]] >= CONFIG["MIN_DETECTIONS"]:
153
+ confirmed_violations.append(v)
154
+
155
+ # Save snapshots
156
+ snapshot_paths = []
157
+ for label, data in snapshots.items():
158
+ if violation_counts.get(label, 0) >= CONFIG["MIN_DETECTIONS"]:
159
+ path = os.path.join(CONFIG["OUTPUT_DIR"], f"{label}_{int(time.time())}.jpg")
160
+ cv2.imwrite(path, data["frame"])
161
+ snapshot_paths.append({
162
+ "label": label,
163
+ "path": path,
164
+ "timestamp": data["timestamp"],
165
+ "confidence": data["confidence"]
166
+ })
167
 
168
+ return confirmed_violations, snapshot_paths, time.time() - start_time
169
 
170
  # ==========================
171
+ # Reporting Functions
172
  # ==========================
173
+ def generate_report(violations, snapshots, processing_time):
174
+ """Generate PDF report"""
 
 
 
 
 
 
 
 
 
 
175
  try:
 
 
176
  pdf_file = BytesIO()
177
  c = canvas.Canvas(pdf_file, pagesize=letter)
178
+
179
+ # Header
180
+ c.setFont("Helvetica-Bold", 14)
181
+ c.drawString(1*inch, 10.5*inch, "Safety Violation Report")
182
  c.setFont("Helvetica", 12)
183
+
184
+ # Summary
185
+ y_pos = 9.8*inch
186
+ c.drawString(1*inch, y_pos, f"Processing Time: {processing_time:.1f}s")
187
+ y_pos -= 0.4*inch
188
+ c.drawString(1*inch, y_pos, f"Total Violations: {len(violations)}")
189
+ y_pos -= 0.6*inch
190
+
191
+ # Violations
192
+ c.setFont("Helvetica-Bold", 12)
193
+ c.drawString(1*inch, y_pos, "Violation Details:")
194
+ y_pos -= 0.3*inch
195
  c.setFont("Helvetica", 10)
196
+
197
+ for v in violations[:50]: # Limit to first 50
198
+ text = f"{v['label']} at {v['timestamp']:.1f}s (Confidence: {v['confidence']:.2f})"
199
+ c.drawString(1*inch, y_pos, text)
200
+ y_pos -= 0.2*inch
201
+ if y_pos < 1*inch:
202
+ c.showPage()
203
+ y_pos = 10*inch
204
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  c.save()
206
  pdf_file.seek(0)
207
+ return pdf_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  except Exception as e:
209
+ logger.error(f"Report generation failed: {e}")
210
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  # ==========================
213
+ # Salesforce Integration
214
  # ==========================
215
+ def upload_to_salesforce(violations, snapshots, processing_time):
216
+ """Upload results to Salesforce"""
217
  try:
218
+ sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
+ # Create record
221
+ record_data = {
222
+ "Processing_Time__c": processing_time,
223
+ "Violation_Count__c": len(violations),
224
+ "Status__c": "Completed"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  }
226
+ record = sf.Safety_Video_Report__c.create(record_data)
227
+
228
+ # Upload report
229
+ pdf = generate_report(violations, snapshots, processing_time)
230
+ if pdf:
231
+ encoded = base64.b64encode(pdf.getvalue()).decode("utf-8")
232
+ sf.ContentVersion.create({
233
+ "Title": f"Safety_Report_{record['id']}",
234
+ "PathOnClient": "report.pdf",
235
+ "VersionData": encoded,
236
+ "FirstPublishLocationId": record['id']
237
+ })
238
+
239
+ return record['id']
240
  except Exception as e:
241
+ logger.error(f"Salesforce upload failed: {e}")
242
+ return None
 
 
 
 
 
 
 
243
 
244
  # ==========================
245
  # Gradio Interface
246
  # ==========================
247
+ def analyze_video(video_file):
248
+ """Main processing function"""
249
  if not video_file:
250
+ return "No video uploaded", "", "", ""
251
+
252
  try:
253
+ # Process video
254
+ violations, snapshots, proc_time = process_video(video_file)
255
+
256
+ # Generate outputs
257
+ violation_table = (
258
+ "| Violation | Time (s) | Confidence |\n"
259
+ "|-----------|----------|------------|\n" +
260
+ "\n".join(
261
+ f"| {v['label']} | {v['timestamp']:.1f} | {v['confidence']:.2f} |"
262
+ for v in violations[:20] # Show first 20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  )
 
 
 
 
 
 
 
264
  )
265
+
266
+ snapshot_markdown = "\n".join(
267
+ f"**{s['label']}** ({s['timestamp']:.1f}s, {s['confidence']:.2f}):\n"
268
+ f"![]({s['path']})"
269
+ for s in snapshots
270
+ )
271
+
272
+ # Salesforce upload (async)
273
+ sf_id = "Not uploaded"
274
+ if violations:
275
+ sf_id = upload_to_salesforce(violations, snapshots, proc_time) or "Upload failed"
276
+
277
+ return (
278
+ f"Processed in {proc_time:.1f}s\n{violation_table}",
279
+ snapshot_markdown,
280
+ f"Found {len(violations)} violations",
281
+ f"Salesforce ID: {sf_id}"
282
+ )
283
+
284
  except Exception as e:
285
+ return f"Error: {str(e)}", "", "", ""
 
286
 
287
+ # Launch interface
288
  interface = gr.Interface(
289
+ fn=analyze_video,
290
  inputs=gr.Video(label="Upload Site Video"),
291
+ outputs=[
292
+ gr.Markdown("## Violation Results"),
293
+ gr.Markdown("## Evidence Snapshots"),
294
+ gr.Textbox(label="Summary"),
295
+ gr.Textbox(label="Salesforce Record")
 
296
  ],
297
+ title="AI Safety Compliance Inspector",
298
+ description=(
299
+ "Upload worksite video for automated safety violation detection. "
300
+ "Detects: Missing helmets, improper harnessing, unsafe zones, and more."
301
+ ),
302
  allow_flagging="never"
303
  )
304
 
305
  if __name__ == "__main__":
306
+ interface.launch(server_name="0.0.0.0", server_port=7860)