PrashanthB461 commited on
Commit
ceac049
·
verified ·
1 Parent(s): 481fa07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -274
app.py CHANGED
@@ -50,9 +50,9 @@ CONFIG = {
50
  "domain": "login"
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
- "FRAME_SKIP": 1,
54
  "CONFIDENCE_THRESHOLDS": {
55
- "no_helmet": 0.6, # Increased to reduce false positives
56
  "no_harness": 0.15,
57
  "unsafe_posture": 0.15,
58
  "unsafe_zone": 0.15,
@@ -60,8 +60,10 @@ CONFIG = {
60
  },
61
  "IOU_THRESHOLD": 0.4,
62
  "MIN_VIOLATION_FRAMES": 2,
63
- "HELMET_CONFIDENCE_THRESHOLD": 0.65, # Lowered to mark more workers as compliant
64
- "MAX_PROCESSING_TIME": 60
 
 
65
  }
66
 
67
  # Setup logging
@@ -93,7 +95,7 @@ def load_model():
93
  model = load_model()
94
 
95
  # ==========================
96
- # Enhanced Helper Functions
97
  # ==========================
98
  def draw_detections(frame, detections):
99
  for det in detections:
@@ -123,27 +125,14 @@ def calculate_iou(box1, box2):
123
  x2_min, y2_min = x2 - w2/2, y2 - h2/2
124
  x2_max, y2_max = x2 + w2/2, y2 + h2/2
125
 
126
- intersection = max(0, x1_max - x1_min) * max(0, y1_max - y1_min)
 
127
  area1 = w1 * h1
128
  area2 = w2 * h2
129
  union = area1 + area2 - intersection
130
 
131
  return intersection / union if union > 0 else 0
132
 
133
- # ==========================
134
- # Salesforce Integration
135
- # ==========================
136
- @retry(stop_max_attempt_number=3, wait_fixed=2000)
137
- def connect_to_salesforce():
138
- try:
139
- sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
140
- logger.info("Connected to Salesforce")
141
- sf.describe()
142
- return sf
143
- except Exception as e:
144
- logger.error(f"Salesforce connection failed: {e}")
145
- raise
146
-
147
  def generate_violation_pdf(violations, score):
148
  try:
149
  pdf_filename = f"violations_{int(time.time())}.pdf"
@@ -193,299 +182,169 @@ def generate_violation_pdf(violations, score):
193
  logger.error(f"Error generating PDF: {e}")
194
  return "", "", None
195
 
196
- def upload_pdf_to_salesforce(sf, pdf_file, report_id):
197
- try:
198
- if not pdf_file:
199
- logger.error("No PDF file provided for upload")
200
- return ""
201
- encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
202
- content_version_data = {
203
- "Title": f"Safety_Violation_Report_{int(time.time())}",
204
- "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
205
- "VersionData": encoded_pdf,
206
- "FirstPublishLocationId": report_id
207
- }
208
- content_version = sf.ContentVersion.create(content_version_data)
209
- result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
210
- if not result['records']:
211
- logger.error("Failed to retrieve ContentVersion")
212
- return ""
213
- file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
214
- logger.info(f"PDF uploaded to Salesforce: {file_url}")
215
- return file_url
216
- except Exception as e:
217
- logger.error(f"Error uploading PDF to Salesforce: {e}")
218
- return ""
219
-
220
- def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
221
- try:
222
- sf = connect_to_salesforce()
223
- violations_text = "\n".join(
224
- f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
225
- for v in violations
226
- ) or "No violations detected."
227
- pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
228
-
229
- record_data = {
230
- "Compliance_Score__c": score,
231
- "Violations_Found__c": len(violations),
232
- "Violations_Details__c": violations_text,
233
- "Status__c": "Pending",
234
- "PDF_Report_URL__c": pdf_url
235
- }
236
- logger.info(f"Creating Salesforce record with data: {record_data}")
237
- try:
238
- record = sf.Safety_Video_Report__c.create(record_data)
239
- logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
240
- except Exception as e:
241
- logger.error(f"Failed to create Safety_Video_Report__c: {e}")
242
- record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
243
- logger.warning(f"Fell back to Account record: {record['id']}")
244
- record_id = record["id"]
245
-
246
- if pdf_file:
247
- uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
248
- if uploaded_url:
249
- try:
250
- sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
251
- logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
252
- except Exception as e:
253
- logger.error(f"Failed to update Safety_Video_Report__c: {e}")
254
- sf.Account.update(record_id, {"Description": uploaded_url})
255
- logger.info(f"Updated Account record {record_id} with PDF URL")
256
- pdf_url = uploaded_url
257
-
258
- return record_id, pdf_url
259
- except Exception as e:
260
- logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
261
- return None, ""
262
-
263
- def calculate_safety_score(violations):
264
- penalties = {
265
- "no_helmet": 25,
266
- "no_harness": 30,
267
- "unsafe_posture": 20,
268
- "unsafe_zone": 35,
269
- "improper_tool_use": 25
270
- }
271
- total_penalty = sum(penalties.get(v.get("violation", "Unknown"), 0) for v in violations)
272
- score = 100 - total_penalty
273
- return max(score, 0)
274
-
275
  # ==========================
276
- # Enhanced Video Processing
277
  # ==========================
278
  def process_video(video_data):
279
  try:
 
280
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
281
  with open(video_path, "wb") as f:
282
  f.write(video_data)
283
  logger.info(f"Video saved: {video_path}")
284
 
285
- video = cv2.VideoCapture(video_path)
286
- if not video.isOpened():
 
287
  raise ValueError("Could not open video file")
288
 
289
- violations = []
290
- snapshots = []
291
- frame_count = 0
292
- total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
293
- fps = video.get(cv2.CAP_PROP_FPS)
294
  if fps <= 0:
295
  fps = 30
296
- video_duration = total_frames / fps
297
- logger.info(f"Video duration: {video_duration:.2f} seconds, Total frames: {total_frames}, FPS: {fps}")
 
 
 
 
 
 
 
298
 
299
  workers = []
300
- violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
301
- confirmed_violations = {}
302
- snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
303
- helmet_compliance = {}
304
- detection_counts = {label: 0 for label in CONFIG["VIOLATION_LABELS"].values()}
305
  start_time = time.time()
306
-
307
- # Calculate frames to process within 30 seconds
308
- target_frames = int(total_frames / CONFIG["FRAME_SKIP"])
309
- frame_indices = np.linspace(0, total_frames - 1, target_frames, dtype=int)
310
-
311
  processed_frames = 0
312
- for idx in frame_indices:
313
- elapsed_time = time.time() - start_time
314
- if elapsed_time > CONFIG["MAX_PROCESSING_TIME"]:
315
- logger.info(f"Processing time limit of {CONFIG['MAX_PROCESSING_TIME']} seconds reached. Processed {processed_frames}/{target_frames} frames.")
316
- break
317
 
318
- video.set(cv2.CAP_PROP_POS_FRAMES, idx)
319
- ret, frame = video.read()
320
- if not ret:
321
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
- processed_frames += 1
324
- current_time = idx / fps
325
- progress = (processed_frames / target_frames) * 100
326
- yield f"Processing video... {progress:.1f}% complete (Frame {idx}/{total_frames})", "", "", "", ""
327
 
328
- # Run detection on this frame
329
- results = model(frame, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"])
330
 
331
- current_detections = []
332
- for result in results:
 
 
 
 
 
333
  boxes = result.boxes
334
  for box in boxes:
335
  cls = int(box.cls)
336
  conf = float(box.conf)
337
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
338
 
339
- if label is None:
340
- logger.warning(f"Unknown class ID {cls} detected, skipping")
341
- continue
342
-
343
- if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
344
- logger.debug(f"Detection {label} with confidence {conf:.2f} below threshold, skipping")
345
  continue
346
 
347
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
348
-
349
- current_detections.append({
350
- "frame": idx,
351
  "violation": label,
352
  "confidence": round(conf, 2),
353
  "bounding_box": bbox,
354
  "timestamp": current_time
355
- })
356
- detection_counts[label] += 1
357
-
358
- logger.debug(f"Frame {idx}: Detected {len(current_detections)} violations: {[d['violation'] for d in current_detections]}")
359
-
360
- for detection in current_detections:
361
- violation_type = detection.get("violation", None)
362
- if violation_type is None:
363
- logger.error(f"Invalid detection, missing 'violation' key: {detection}")
364
- continue
365
 
366
- if violation_type == "no_helmet":
367
- matched_worker = None
368
  max_iou = 0
369
- for worker in workers:
370
- iou = calculate_iou(detection["bounding_box"], worker["bbox"])
371
  if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
372
  max_iou = iou
373
- matched_worker = worker
374
-
375
- if matched_worker:
376
- worker_id = matched_worker["id"]
377
- if worker_id not in helmet_compliance:
378
- helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
379
- helmet_compliance[worker_id]["no_helmet_frames"] += 1
380
- if detection["confidence"] < CONFIG["HELMET_CONFIDENCE_THRESHOLD"]:
381
- helmet_compliance[worker_id]["compliant"] = True
382
- logger.debug(f"Worker {worker_id} marked as helmet compliant due to low no_helmet confidence: {detection['confidence']}")
383
- else:
384
- logger.debug(f"Worker {worker_id} potential no_helmet violation with confidence: {detection['confidence']}")
385
- if helmet_compliance[worker_id]["compliant"]:
386
- logger.debug(f"Worker {worker_id} has helmet, skipping no_helmet violation")
387
- continue
388
-
389
- matched_worker = None
390
- max_iou = 0
391
-
392
- for worker in workers:
393
- iou = calculate_iou(detection["bounding_box"], worker["bbox"])
394
- if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
395
- max_iou = iou
396
- matched_worker = worker
397
-
398
- if matched_worker:
399
- matched_worker["bbox"] = detection["bounding_box"]
400
- matched_worker["last_seen"] = current_time
401
- worker_id = matched_worker["id"]
402
- else:
403
- worker_id = len(workers) + 1
404
- workers.append({
405
- "id": worker_id,
406
- "bbox": detection["bounding_box"],
407
- "first_seen": current_time,
408
- "last_seen": current_time
409
- })
410
- if worker_id not in helmet_compliance:
411
- helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
412
-
413
- if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
414
- logger.debug(f"Violation {violation_type} already confirmed for worker {worker_id}, skipping")
415
- continue
416
-
417
- detection["worker_id"] = worker_id
418
- violation_history[violation_type].append(detection)
419
 
420
- workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
 
 
 
 
 
 
421
 
422
- logger.info(f"Detection counts: {detection_counts}")
 
423
 
424
- for violation_type, detections in violation_history.items():
425
- if not detections:
426
- logger.info(f"No detections for {violation_type}")
427
- continue
 
 
 
 
 
 
428
 
429
- worker_violations = {}
430
- for det in detections:
431
- if det["worker_id"] not in worker_violations:
432
- worker_violations[det["worker_id"]] = []
433
- worker_violations[det["worker_id"]].append(det)
434
-
435
- for worker_id, worker_dets in worker_violations.items():
436
- if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
437
- if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
438
- continue
439
-
440
- if violation_type == "no_helmet":
441
- if worker_id in helmet_compliance and helmet_compliance[worker_id]["compliant"]:
442
- logger.debug(f"Skipping no_helmet for worker {worker_id} due to helmet compliance")
443
- continue
444
- # Removed the stricter persistence check for no_helmet
445
- logger.info(f"Confirmed no_helmet for worker {worker_id} with {len(worker_dets)} detections")
446
-
447
- best_detection = max(worker_dets, key=lambda x: x["confidence"])
448
- violations.append(best_detection)
449
-
450
- if worker_id not in confirmed_violations:
451
- confirmed_violations[worker_id] = set()
452
- confirmed_violations[worker_id].add(violation_type)
453
-
454
- if not snapshot_taken[violation_type]:
455
- cap = cv2.VideoCapture(video_path)
456
- cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
457
- ret, snapshot_frame = cap.read()
458
- if not ret:
459
- logger.error(f"Failed to capture snapshot for {violation_type} at frame {best_detection['frame']}")
460
- cap.release()
461
- continue
462
- snapshot_frame = draw_detections(snapshot_frame, [best_detection])
463
-
464
- snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
465
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
466
- cv2.imwrite(snapshot_path, snapshot_frame)
467
- snapshots.append({
468
- "violation": violation_type,
469
- "frame": best_detection["frame"],
470
- "snapshot_path": snapshot_path,
471
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
472
- })
473
- snapshot_taken[violation_type] = True
474
- logger.info(f"Snapshot taken for {violation_type} at frame {best_detection['frame']}")
475
- cap.release()
476
 
477
- video.release()
478
  os.remove(video_path)
479
- logger.info(f"Video file {video_path} removed")
480
 
 
481
  if not violations:
482
- logger.info("No persistent violations detected")
483
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
484
  return
485
 
486
- score = calculate_safety_score(violations)
 
 
 
 
 
487
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
488
- report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
489
 
490
  violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
491
  violation_table += "|------------------------|---------------|------------|-----------|\n"
@@ -494,22 +353,19 @@ def process_video(video_data):
494
  row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
495
  violation_table += row
496
 
497
- snapshots_text = "No snapshots captured."
498
- if snapshots:
499
- violation_name_map = CONFIG["DISPLAY_NAMES"]
500
- snapshots_text = "\n".join(
501
- f"- Snapshot for {violation_name_map.get(s.get('violation', 'Unknown'), 'Unknown')} at frame {s.get('frame', 0)}: ![]({s.get('snapshot_base64', '')})"
502
- for s in snapshots
503
- )
504
 
505
- logger.info(f"Processing complete: {len(violations)} violations detected, score: {score}%")
506
  yield (
507
  violation_table,
508
  f"Safety Score: {score}%",
509
  snapshots_text,
510
- f"Salesforce Record ID: {report_id or 'N/A'}",
511
- final_pdf_url or "N/A"
512
  )
 
513
  except Exception as e:
514
  logger.error(f"Error processing video: {e}", exc_info=True)
515
  yield f"Error processing video: {e}", "", "", "", ""
@@ -546,5 +402,5 @@ interface = gr.Interface(
546
  )
547
 
548
  if __name__ == "__main__":
549
- logger.info("Launching Enhanced Safety Analyzer App...")
550
  interface.launch()
 
50
  "domain": "login"
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
+ "FRAME_SKIP": 3, # Increased from 1 to 3 for faster processing
54
  "CONFIDENCE_THRESHOLDS": {
55
+ "no_helmet": 0.6,
56
  "no_harness": 0.15,
57
  "unsafe_posture": 0.15,
58
  "unsafe_zone": 0.15,
 
60
  },
61
  "IOU_THRESHOLD": 0.4,
62
  "MIN_VIOLATION_FRAMES": 2,
63
+ "HELMET_CONFIDENCE_THRESHOLD": 0.65,
64
+ "MAX_PROCESSING_TIME": 30, # Reduced from 60 to 30 seconds
65
+ "BATCH_SIZE": 10, # Process frames in batches for efficiency
66
+ "WORKER_TRACKING_DURATION": 3.0 # Seconds to track a worker without updates
67
  }
68
 
69
  # Setup logging
 
95
  model = load_model()
96
 
97
  # ==========================
98
+ # Optimized Helper Functions
99
  # ==========================
100
  def draw_detections(frame, detections):
101
  for det in detections:
 
125
  x2_min, y2_min = x2 - w2/2, y2 - h2/2
126
  x2_max, y2_max = x2 + w2/2, y2 + h2/2
127
 
128
+ intersection = max(0, min(x1_max, x2_max) - max(x1_min, x2_min)) * \
129
+ max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
130
  area1 = w1 * h1
131
  area2 = w2 * h2
132
  union = area1 + area2 - intersection
133
 
134
  return intersection / union if union > 0 else 0
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def generate_violation_pdf(violations, score):
137
  try:
138
  pdf_filename = f"violations_{int(time.time())}.pdf"
 
182
  logger.error(f"Error generating PDF: {e}")
183
  return "", "", None
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  # ==========================
186
+ # Optimized Video Processing
187
  # ==========================
188
  def process_video(video_data):
189
  try:
190
+ # Create temp video file
191
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
192
  with open(video_path, "wb") as f:
193
  f.write(video_data)
194
  logger.info(f"Video saved: {video_path}")
195
 
196
+ # Open video file
197
+ cap = cv2.VideoCapture(video_path)
198
+ if not cap.isOpened():
199
  raise ValueError("Could not open video file")
200
 
201
+ # Get video properties
202
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
203
+ fps = cap.get(cv2.CAP_PROP_FPS)
 
 
204
  if fps <= 0:
205
  fps = 30
206
+ duration = total_frames / fps
207
+ logger.info(f"Video duration: {duration:.2f}s, Frames: {total_frames}, FPS: {fps}")
208
+
209
+ # Calculate frames to process
210
+ frame_skip = CONFIG["FRAME_SKIP"]
211
+ frames_to_process = total_frames // frame_skip
212
+ if frames_to_process < 10: # Ensure we process at least 10 frames
213
+ frame_skip = max(1, total_frames // 10)
214
+ frames_to_process = total_frames // frame_skip
215
 
216
  workers = []
217
+ violations = []
218
+ helmet_violations = {}
219
+ snapshots = []
 
 
220
  start_time = time.time()
 
 
 
 
 
221
  processed_frames = 0
 
 
 
 
 
222
 
223
+ # Process frames in batches for efficiency
224
+ for batch_start in range(0, total_frames, CONFIG["BATCH_SIZE"] * frame_skip):
225
+ batch_frames = []
226
+ batch_indices = []
227
+
228
+ # Collect frames for this batch
229
+ for i in range(CONFIG["BATCH_SIZE"]):
230
+ frame_idx = batch_start + i * frame_skip
231
+ if frame_idx >= total_frames:
232
+ break
233
+
234
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
235
+ ret, frame = cap.read()
236
+ if not ret:
237
+ continue
238
+
239
+ batch_frames.append(frame)
240
+ batch_indices.append(frame_idx)
241
+ processed_frames += 1
242
 
243
+ # Skip empty batches
244
+ if not batch_frames:
245
+ continue
 
246
 
247
+ # Run batch detection
248
+ results = model(batch_frames, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], verbose=False)
249
 
250
+ # Process results for each frame in batch
251
+ for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
252
+ current_time = frame_idx / fps
253
+ progress = (processed_frames / frames_to_process) * 100
254
+ yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
255
+
256
+ # Process detections in this frame
257
  boxes = result.boxes
258
  for box in boxes:
259
  cls = int(box.cls)
260
  conf = float(box.conf)
261
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
262
 
263
+ if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
 
 
 
 
 
264
  continue
265
 
266
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
267
+ detection = {
268
+ "frame": frame_idx,
 
269
  "violation": label,
270
  "confidence": round(conf, 2),
271
  "bounding_box": bbox,
272
  "timestamp": current_time
273
+ }
 
 
 
 
 
 
 
 
 
274
 
275
+ # Worker tracking and helmet detection optimization
276
+ worker_id = None
277
  max_iou = 0
278
+ for idx, worker in enumerate(workers):
279
+ iou = calculate_iou(bbox, worker["bbox"])
280
  if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
281
  max_iou = iou
282
+ worker_id = worker["id"]
283
+ workers[idx]["bbox"] = bbox # Update worker position
284
+ workers[idx]["last_seen"] = current_time
285
+
286
+ if worker_id is None:
287
+ worker_id = len(workers) + 1
288
+ workers.append({
289
+ "id": worker_id,
290
+ "bbox": bbox,
291
+ "first_seen": current_time,
292
+ "last_seen": current_time
293
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
+ # Special handling for helmet violations
296
+ if label == "no_helmet":
297
+ if worker_id not in helmet_violations:
298
+ helmet_violations[worker_id] = []
299
+ helmet_violations[worker_id].append(detection)
300
+ else:
301
+ violations.append(detection)
302
 
303
+ # Remove workers not seen recently
304
+ workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
305
 
306
+ # Check processing time limit
307
+ if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
308
+ logger.info(f"Processing time limit reached at frame {frame_idx}")
309
+ break
310
+
311
+ # Process helmet violations (more strict criteria)
312
+ for worker_id, detections in helmet_violations.items():
313
+ if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
314
+ best_detection = max(detections, key=lambda x: x["confidence"])
315
+ violations.append(best_detection)
316
 
317
+ # Capture snapshot for this violation
318
+ cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
319
+ ret, snapshot_frame = cap.read()
320
+ if ret:
321
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
322
+ snapshot_filename = f"no_helmet_{best_detection['frame']}.jpg"
323
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
324
+ cv2.imwrite(snapshot_path, snapshot_frame)
325
+ snapshots.append({
326
+ "violation": "no_helmet",
327
+ "frame": best_detection["frame"],
328
+ "snapshot_path": snapshot_path,
329
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
330
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
+ cap.release()
333
  os.remove(video_path)
334
+ logger.info(f"Processing complete. {len(violations)} violations found.")
335
 
336
+ # Generate results
337
  if not violations:
 
338
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
339
  return
340
 
341
+ score = max(0, 100 - sum(25 if v["violation"] == "no_helmet" else
342
+ 30 if v["violation"] == "no_harness" else
343
+ 20 if v["violation"] == "unsafe_posture" else
344
+ 35 if v["violation"] == "unsafe_zone" else
345
+ 25 for v in violations))
346
+
347
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
 
348
 
349
  violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
350
  violation_table += "|------------------------|---------------|------------|-----------|\n"
 
353
  row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
354
  violation_table += row
355
 
356
+ snapshots_text = "\n".join(
357
+ f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
358
+ for s in snapshots
359
+ ) if snapshots else "No snapshots captured."
 
 
 
360
 
 
361
  yield (
362
  violation_table,
363
  f"Safety Score: {score}%",
364
  snapshots_text,
365
+ "Salesforce integration placeholder",
366
+ pdf_url or "N/A"
367
  )
368
+
369
  except Exception as e:
370
  logger.error(f"Error processing video: {e}", exc_info=True)
371
  yield f"Error processing video: {e}", "", "", "", ""
 
402
  )
403
 
404
  if __name__ == "__main__":
405
+ logger.info("Launching Optimized Safety Analyzer App...")
406
  interface.launch()