PrashanthB461 commited on
Commit
4b68be9
·
verified ·
1 Parent(s): ceac049

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -50
app.py CHANGED
@@ -50,7 +50,7 @@ CONFIG = {
50
  "domain": "login"
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
- "FRAME_SKIP": 3, # Increased from 1 to 3 for faster processing
54
  "CONFIDENCE_THRESHOLDS": {
55
  "no_helmet": 0.6,
56
  "no_harness": 0.15,
@@ -59,11 +59,12 @@ CONFIG = {
59
  "improper_tool_use": 0.15
60
  },
61
  "IOU_THRESHOLD": 0.4,
62
- "MIN_VIOLATION_FRAMES": 2,
63
  "HELMET_CONFIDENCE_THRESHOLD": 0.65,
64
- "MAX_PROCESSING_TIME": 30, # Reduced from 60 to 30 seconds
65
- "BATCH_SIZE": 10, # Process frames in batches for efficiency
66
- "WORKER_TRACKING_DURATION": 3.0 # Seconds to track a worker without updates
 
67
  }
68
 
69
  # Setup logging
@@ -120,18 +121,21 @@ def calculate_iou(box1, box2):
120
  x1, y1, w1, h1 = box1
121
  x2, y2, w2, h2 = box2
122
 
123
- x1_min, y1_min = x1 - w1/2, y1 - h1/2
124
- x1_max, y1_max = x1 + w1/2, y1 + h1/2
125
- x2_min, y2_min = x2 - w2/2, y2 - h2/2
126
- x2_max, y2_max = x2 + w2/2, y2 + h2/2
 
127
 
128
- intersection = max(0, min(x1_max, x2_max) - max(x1_min, x2_min)) * \
129
- max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
130
- area1 = w1 * h1
131
- area2 = w2 * h2
132
- union = area1 + area2 - intersection
133
 
134
- return intersection / union if union > 0 else 0
 
 
 
 
 
135
 
136
  def generate_violation_pdf(violations, score):
137
  try:
@@ -182,6 +186,18 @@ def generate_violation_pdf(violations, score):
182
  logger.error(f"Error generating PDF: {e}")
183
  return "", "", None
184
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  # ==========================
186
  # Optimized Video Processing
187
  # ==========================
@@ -202,16 +218,19 @@ def process_video(video_data):
202
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
203
  fps = cap.get(cv2.CAP_PROP_FPS)
204
  if fps <= 0:
205
- fps = 30
206
  duration = total_frames / fps
207
- logger.info(f"Video duration: {duration:.2f}s, Frames: {total_frames}, FPS: {fps}")
 
 
 
208
 
209
- # Calculate frames to process
210
- frame_skip = CONFIG["FRAME_SKIP"]
211
- frames_to_process = total_frames // frame_skip
212
- if frames_to_process < 10: # Ensure we process at least 10 frames
213
- frame_skip = max(1, total_frames // 10)
214
- frames_to_process = total_frames // frame_skip
215
 
216
  workers = []
217
  violations = []
@@ -219,30 +238,35 @@ def process_video(video_data):
219
  snapshots = []
220
  start_time = time.time()
221
  processed_frames = 0
 
222
 
223
- # Process frames in batches for efficiency
224
- for batch_start in range(0, total_frames, CONFIG["BATCH_SIZE"] * frame_skip):
225
  batch_frames = []
226
  batch_indices = []
227
 
228
  # Collect frames for this batch
229
- for i in range(CONFIG["BATCH_SIZE"]):
230
- frame_idx = batch_start + i * frame_skip
231
  if frame_idx >= total_frames:
232
  break
233
 
234
- cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
235
  ret, frame = cap.read()
236
  if not ret:
237
- continue
238
 
239
  batch_frames.append(frame)
240
  batch_indices.append(frame_idx)
241
  processed_frames += 1
 
 
 
 
 
242
 
243
- # Skip empty batches
244
  if not batch_frames:
245
- continue
246
 
247
  # Run batch detection
248
  results = model(batch_frames, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], verbose=False)
@@ -250,8 +274,12 @@ def process_video(video_data):
250
  # Process results for each frame in batch
251
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
252
  current_time = frame_idx / fps
253
- progress = (processed_frames / frames_to_process) * 100
254
- yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
 
 
 
 
255
 
256
  # Process detections in this frame
257
  boxes = result.boxes
@@ -272,7 +300,7 @@ def process_video(video_data):
272
  "timestamp": current_time
273
  }
274
 
275
- # Worker tracking and helmet detection optimization
276
  worker_id = None
277
  max_iou = 0
278
  for idx, worker in enumerate(workers):
@@ -292,6 +320,8 @@ def process_video(video_data):
292
  "last_seen": current_time
293
  })
294
 
 
 
295
  # Special handling for helmet violations
296
  if label == "no_helmet":
297
  if worker_id not in helmet_violations:
@@ -303,14 +333,10 @@ def process_video(video_data):
303
  # Remove workers not seen recently
304
  workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
305
 
306
- # Check processing time limit
307
- if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
308
- logger.info(f"Processing time limit reached at frame {frame_idx}")
309
- break
310
-
311
- # Process helmet violations (more strict criteria)
312
  for worker_id, detections in helmet_violations.items():
313
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
 
314
  best_detection = max(detections, key=lambda x: x["confidence"])
315
  violations.append(best_detection)
316
 
@@ -331,38 +357,54 @@ def process_video(video_data):
331
 
332
  cap.release()
333
  os.remove(video_path)
334
- logger.info(f"Processing complete. {len(violations)} violations found.")
 
335
 
336
  # Generate results
337
  if not violations:
338
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
339
  return
340
 
341
- score = max(0, 100 - sum(25 if v["violation"] == "no_helmet" else
342
- 30 if v["violation"] == "no_harness" else
343
- 20 if v["violation"] == "unsafe_posture" else
344
- 35 if v["violation"] == "unsafe_zone" else
345
- 25 for v in violations))
346
-
347
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
348
 
 
349
  violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
350
  violation_table += "|------------------------|---------------|------------|-----------|\n"
351
- for v in violations:
352
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
353
  row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
354
  violation_table += row
355
 
 
356
  snapshots_text = "\n".join(
357
  f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
358
  for s in snapshots
359
  ) if snapshots else "No snapshots captured."
360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  yield (
362
  violation_table,
363
  f"Safety Score: {score}%",
364
  snapshots_text,
365
- "Salesforce integration placeholder",
366
  pdf_url or "N/A"
367
  )
368
 
@@ -370,6 +412,40 @@ def process_video(video_data):
370
  logger.error(f"Error processing video: {e}", exc_info=True)
371
  yield f"Error processing video: {e}", "", "", "", ""
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  # ==========================
374
  # Gradio Interface
375
  # ==========================
@@ -402,5 +478,5 @@ interface = gr.Interface(
402
  )
403
 
404
  if __name__ == "__main__":
405
- logger.info("Launching Optimized Safety Analyzer App...")
406
  interface.launch()
 
50
  "domain": "login"
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
+ "FRAME_SKIP": 2, # Process every 2nd frame for balance of speed/accuracy
54
  "CONFIDENCE_THRESHOLDS": {
55
  "no_helmet": 0.6,
56
  "no_harness": 0.15,
 
59
  "improper_tool_use": 0.15
60
  },
61
  "IOU_THRESHOLD": 0.4,
62
+ "MIN_VIOLATION_FRAMES": 3, # Require more consistent detections
63
  "HELMET_CONFIDENCE_THRESHOLD": 0.65,
64
+ "WORKER_TRACKING_DURATION": 3.0, # Seconds to track a worker
65
+ "MIN_FRAME_RATE": 5, # Minimum frames per second to process
66
+ "MAX_FRAME_RATE": 15, # Maximum frames per second to process
67
+ "BATCH_SIZE": 8 # Number of frames to process at once
68
  }
69
 
70
  # Setup logging
 
121
  x1, y1, w1, h1 = box1
122
  x2, y2, w2, h2 = box2
123
 
124
+ # Calculate coordinates of the intersection rectangle
125
+ x_left = max(x1 - w1/2, x2 - w2/2)
126
+ y_top = max(y1 - h1/2, y2 - h2/2)
127
+ x_right = min(x1 + w1/2, x2 + w2/2)
128
+ y_bottom = min(y1 + h1/2, y2 + h2/2)
129
 
130
+ if x_right < x_left or y_bottom < y_top:
131
+ return 0.0
 
 
 
132
 
133
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
134
+ box1_area = w1 * h1
135
+ box2_area = w2 * h2
136
+ union_area = box1_area + box2_area - intersection_area
137
+
138
+ return intersection_area / union_area
139
 
140
  def generate_violation_pdf(violations, score):
141
  try:
 
186
  logger.error(f"Error generating PDF: {e}")
187
  return "", "", None
188
 
189
+ def calculate_safety_score(violations):
190
+ penalties = {
191
+ "no_helmet": 25,
192
+ "no_harness": 30,
193
+ "unsafe_posture": 20,
194
+ "unsafe_zone": 35,
195
+ "improper_tool_use": 25
196
+ }
197
+ total_penalty = sum(penalties.get(v.get("violation", "Unknown"), 0) for v in violations)
198
+ score = 100 - total_penalty
199
+ return max(score, 0)
200
+
201
  # ==========================
202
  # Optimized Video Processing
203
  # ==========================
 
218
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
219
  fps = cap.get(cv2.CAP_PROP_FPS)
220
  if fps <= 0:
221
+ fps = 30 # Default assumption if FPS not available
222
  duration = total_frames / fps
223
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
224
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
225
+
226
+ logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
227
 
228
+ # Calculate optimal frame skipping
229
+ original_frame_skip = CONFIG["FRAME_SKIP"]
230
+ target_fps = min(max(fps / original_frame_skip, CONFIG["MIN_FRAME_RATE"]), CONFIG["MAX_FRAME_RATE"])
231
+ actual_frame_skip = max(1, int(fps / target_fps))
232
+ frames_to_process = total_frames // actual_frame_skip
233
+ logger.info(f"Processing strategy: Frame skip={actual_frame_skip}, Target FPS={target_fps:.1f}, Frames to process={frames_to_process}")
234
 
235
  workers = []
236
  violations = []
 
238
  snapshots = []
239
  start_time = time.time()
240
  processed_frames = 0
241
+ last_progress_update = 0
242
 
243
+ # Process frames in batches
244
+ while True:
245
  batch_frames = []
246
  batch_indices = []
247
 
248
  # Collect frames for this batch
249
+ for _ in range(CONFIG["BATCH_SIZE"]):
250
+ frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
251
  if frame_idx >= total_frames:
252
  break
253
 
 
254
  ret, frame = cap.read()
255
  if not ret:
256
+ break
257
 
258
  batch_frames.append(frame)
259
  batch_indices.append(frame_idx)
260
  processed_frames += 1
261
+
262
+ # Skip frames according to our strategy
263
+ for _ in range(actual_frame_skip - 1):
264
+ if not cap.grab():
265
+ break
266
 
267
+ # Break if no more frames
268
  if not batch_frames:
269
+ break
270
 
271
  # Run batch detection
272
  results = model(batch_frames, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], verbose=False)
 
274
  # Process results for each frame in batch
275
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
276
  current_time = frame_idx / fps
277
+
278
+ # Update progress periodically
279
+ if time.time() - last_progress_update > 1.0: # Update every second
280
+ progress = (frame_idx / total_frames) * 100
281
+ yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
282
+ last_progress_update = time.time()
283
 
284
  # Process detections in this frame
285
  boxes = result.boxes
 
300
  "timestamp": current_time
301
  }
302
 
303
+ # Worker tracking
304
  worker_id = None
305
  max_iou = 0
306
  for idx, worker in enumerate(workers):
 
320
  "last_seen": current_time
321
  })
322
 
323
+ detection["worker_id"] = worker_id
324
+
325
  # Special handling for helmet violations
326
  if label == "no_helmet":
327
  if worker_id not in helmet_violations:
 
333
  # Remove workers not seen recently
334
  workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
335
 
336
+ # Process helmet violations (require consistent detections)
 
 
 
 
 
337
  for worker_id, detections in helmet_violations.items():
338
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
339
+ # Find the detection with highest confidence
340
  best_detection = max(detections, key=lambda x: x["confidence"])
341
  violations.append(best_detection)
342
 
 
357
 
358
  cap.release()
359
  os.remove(video_path)
360
+ processing_time = time.time() - start_time
361
+ logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
362
 
363
  # Generate results
364
  if not violations:
365
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
366
  return
367
 
368
+ score = calculate_safety_score(violations)
 
 
 
 
 
369
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
370
 
371
+ # Generate violation table
372
  violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
373
  violation_table += "|------------------------|---------------|------------|-----------|\n"
374
+ for v in sorted(violations, key=lambda x: x["timestamp"]):
375
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
376
  row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
377
  violation_table += row
378
 
379
+ # Generate snapshots text
380
  snapshots_text = "\n".join(
381
  f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
382
  for s in snapshots
383
  ) if snapshots else "No snapshots captured."
384
 
385
+ # Push to Salesforce
386
+ try:
387
+ sf = connect_to_salesforce()
388
+ record_data = {
389
+ "Compliance_Score__c": score,
390
+ "Violations_Found__c": len(violations),
391
+ "Status__c": "Completed",
392
+ "Processing_Time__c": f"{processing_time:.2f}s"
393
+ }
394
+ record = sf.Safety_Video_Report__c.create(record_data)
395
+ record_id = record["id"]
396
+
397
+ if pdf_file:
398
+ pdf_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
399
+ except Exception as e:
400
+ logger.error(f"Salesforce integration failed: {e}")
401
+ record_id = "N/A (Salesforce error)"
402
+
403
  yield (
404
  violation_table,
405
  f"Safety Score: {score}%",
406
  snapshots_text,
407
+ f"Salesforce Record ID: {record_id}",
408
  pdf_url or "N/A"
409
  )
410
 
 
412
  logger.error(f"Error processing video: {e}", exc_info=True)
413
  yield f"Error processing video: {e}", "", "", "", ""
414
 
415
+ # ==========================
416
+ # Salesforce Integration
417
+ # ==========================
418
+ @retry(stop_max_attempt_number=3, wait_fixed=2000)
419
+ def connect_to_salesforce():
420
+ try:
421
+ sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
422
+ logger.info("Connected to Salesforce")
423
+ return sf
424
+ except Exception as e:
425
+ logger.error(f"Salesforce connection failed: {e}")
426
+ raise
427
+
428
+ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
429
+ try:
430
+ encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
431
+ content_version_data = {
432
+ "Title": f"Safety_Violation_Report_{int(time.time())}",
433
+ "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
434
+ "VersionData": encoded_pdf,
435
+ "FirstPublishLocationId": report_id
436
+ }
437
+ content_version = sf.ContentVersion.create(content_version_data)
438
+ result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
439
+ if not result['records']:
440
+ logger.error("Failed to retrieve ContentVersion")
441
+ return ""
442
+ file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
443
+ logger.info(f"PDF uploaded to Salesforce: {file_url}")
444
+ return file_url
445
+ except Exception as e:
446
+ logger.error(f"Error uploading PDF to Salesforce: {e}")
447
+ return ""
448
+
449
  # ==========================
450
  # Gradio Interface
451
  # ==========================
 
478
  )
479
 
480
  if __name__ == "__main__":
481
+ logger.info("Launching Enhanced Safety Analyzer App...")
482
  interface.launch()