PrashanthB461 commited on
Commit
11ea390
·
verified ·
1 Parent(s): 8f347c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -242
app.py CHANGED
@@ -15,68 +15,60 @@ import logging
15
  from retrying import retry
16
 
17
  # ==========================
18
- # Enhanced Configuration
19
  # ==========================
20
  CONFIG = {
21
- "MODEL_PATH": "yolov8_safety.pt",
22
- "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
  1: "no_harness",
27
- 2: "unsafe_posture",
28
- 3: "unsafe_zone",
29
- 4: "improper_tool_use"
30
  },
31
- "CLASS_COLORS": {
32
- "no_helmet": (0, 0, 255), # Red
33
- "no_harness": (0, 165, 255), # Orange
34
- "unsafe_posture": (0, 255, 0), # Green
35
- "unsafe_zone": (255, 0, 0), # Blue
36
- "improper_tool_use": (255, 255, 0) # Yellow
37
- },
38
- "DISPLAY_NAMES": {
39
- "no_helmet": "No Helmet Violation",
40
- "no_harness": "No Harness Violation",
41
- "unsafe_posture": "Unsafe Posture Violation",
42
- "unsafe_zone": "Unsafe Zone Entry",
43
- "improper_tool_use": "Improper Tool Use"
44
  },
45
  "SF_CREDENTIALS": {
46
- "username": "prashanth1ai@safety.com",
47
- "password": "SaiPrash461",
48
- "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
49
- "domain": "login"
50
  },
51
- "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
- "FRAME_SKIP": 5, # Reduced for better detection
53
- "MAX_PROCESSING_TIME": 60,
54
- "CONFIDENCE_THRESHOLD": 0.25, # Lower threshold for all violations
55
- "IOU_THRESHOLD": 0.4,
56
- "MIN_VIOLATION_FRAMES": 3 # Minimum consecutive frames to confirm violation
57
  }
58
 
59
  # Setup logging
60
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
61
  logger = logging.getLogger(__name__)
62
 
 
63
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
64
 
 
 
 
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  logger.info(f"Using device: {device}")
67
 
 
 
 
68
  def load_model():
69
  try:
70
- if os.path.isfile(CONFIG["MODEL_PATH"]):
71
- model_path = CONFIG["MODEL_PATH"]
72
- logger.info(f"Model loaded: {model_path}")
73
- else:
74
- model_path = CONFIG["FALLBACK_MODEL"]
75
- logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
76
- if not os.path.isfile(model_path):
77
- logger.info(f"Downloading fallback model: {model_path}")
78
- torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
79
  model = YOLO(model_path).to(device)
 
 
 
80
  return model
81
  except Exception as e:
82
  logger.error(f"Failed to load model: {e}")
@@ -85,57 +77,9 @@ def load_model():
85
  model = load_model()
86
 
87
  # ==========================
88
- # Enhanced Helper Functions
89
- # ==========================
90
- def draw_detections(frame, detections):
91
- """Draw bounding boxes and labels on frame"""
92
- for det in detections:
93
- label = det["violation"]
94
- confidence = det["confidence"]
95
- x, y, w, h = det["bounding_box"]
96
-
97
- # Convert from center coordinates to corner coordinates
98
- x1 = int(x - w/2)
99
- y1 = int(y - h/2)
100
- x2 = int(x + w/2)
101
- y2 = int(y + h/2)
102
-
103
- color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
104
- cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
105
-
106
- display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
107
- cv2.putText(frame, display_text, (x1, y1-10),
108
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
109
- return frame
110
-
111
- def calculate_iou(box1, box2):
112
- """Calculate Intersection over Union (IoU) for two bounding boxes."""
113
- x1, y1, w1, h1 = box1
114
- x2, y2, w2, h2 = box2
115
-
116
- # Convert to top-left and bottom-right coordinates
117
- x1_min, y1_min = x1 - w1/2, y1 - h1/2
118
- x1_max, y1_max = x1 + w1/2, y1 + h1/2
119
- x2_min, y2_min = x2 - w2/2, y2 - h2/2
120
- x2_max, y2_max = x2 + w2/2, y2 + h2/2
121
-
122
- # Calculate intersection
123
- x_min = max(x1_min, x2_min)
124
- y_min = max(y1_min, y2_min)
125
- x_max = min(x1_max, x2_max)
126
- y_max = min(y1_max, y2_max)
127
-
128
- intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
129
- area1 = w1 * h1
130
- area2 = w2 * h2
131
- union = area1 + area2 - intersection
132
-
133
- return intersection / union if union > 0 else 0
134
-
135
- # ==========================
136
- # Salesforce Integration (unchanged)
137
  # ==========================
138
- @retry(stop_max_attempt_number=3, wait_fixed=2000)
139
  def connect_to_salesforce():
140
  try:
141
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
@@ -195,6 +139,7 @@ def generate_violation_pdf(violations, score):
195
  logger.error(f"Error generating PDF: {e}")
196
  return "", "", None
197
 
 
198
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
199
  try:
200
  if not pdf_file:
@@ -208,7 +153,7 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
208
  "FirstPublishLocationId": report_id
209
  }
210
  content_version = sf.ContentVersion.create(content_version_data)
211
- result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
212
  if not result['records']:
213
  logger.error("Failed to retrieve ContentVersion")
214
  return ""
@@ -219,6 +164,7 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
219
  logger.error(f"Error uploading PDF to Salesforce: {e}")
220
  return ""
221
 
 
222
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
223
  try:
224
  sf = connect_to_salesforce()
@@ -237,10 +183,10 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
237
  }
238
  logger.info(f"Creating Salesforce record with data: {record_data}")
239
  try:
240
- record = sf.Safety_Video_Report__c.create(record_data)
241
- logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
242
  except Exception as e:
243
- logger.error(f"Failed to create Safety_Video_Report__c: {e}")
244
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
245
  logger.warning(f"Fell back to Account record: {record['id']}")
246
  record_id = record["id"]
@@ -249,39 +195,36 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
249
  uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
250
  if uploaded_url:
251
  try:
252
- sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
253
  logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
254
  except Exception as e:
255
- logger.error(f"Failed to update Safety_Video_Report__c: {e}")
256
  sf.Account.update(record_id, {"Description": uploaded_url})
257
  logger.info(f"Updated Account record {record_id} with PDF URL")
258
  pdf_url = uploaded_url
259
 
260
  return record_id, pdf_url
261
  except Exception as e:
262
- logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
263
  return None, ""
264
 
 
 
 
265
  def calculate_safety_score(violations):
266
  penalties = {
267
  "no_helmet": 25,
268
  "no_harness": 30,
269
- "unsafe_posture": 20,
270
- "unsafe_zone": 35,
271
- "improper_tool_use": 25
272
  }
273
- # Count unique violations per worker
274
- unique_violations = set()
275
  for v in violations:
276
- key = (v["worker_id"], v["violation"])
277
- unique_violations.add(key)
278
-
279
- total_penalty = sum(penalties.get(violation, 0) for _, violation in unique_violations)
280
- score = 100 - total_penalty
281
  return max(score, 0)
282
 
283
  # ==========================
284
- # Enhanced Video Processing
285
  # ==========================
286
  def process_video(video_data):
287
  try:
@@ -294,152 +237,83 @@ def process_video(video_data):
294
  if not video.isOpened():
295
  raise ValueError("Could not open video file")
296
 
297
- violations = []
298
- snapshots = []
299
  frame_count = 0
300
  start_time = time.time()
301
  fps = video.get(cv2.CAP_PROP_FPS)
302
- if fps <= 0:
303
- fps = 30 # Default assumption if FPS cannot be determined
304
-
305
- # Structure to track workers and their violations
306
- workers = []
307
- violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
308
- snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
309
 
310
- logger.info(f"Processing video with FPS: {fps}")
311
- logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
312
 
313
  while True:
314
  ret, frame = video.read()
315
- if not ret:
316
  break
317
 
318
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
319
  frame_count += 1
320
  continue
321
 
 
322
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
323
  logger.info("Processing time limit reached")
324
  break
325
 
326
- current_time = frame_count / fps
327
-
328
- # Run detection on this frame
329
  results = model(frame, device=device)
330
-
331
- current_detections = []
332
  for result in results:
333
- boxes = result.boxes
334
- for box in boxes:
335
- cls = int(box.cls)
336
- conf = float(box.conf)
337
- label = CONFIG["VIOLATION_LABELS"].get(cls, None)
338
-
339
- if label is None:
340
  continue
341
-
342
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
 
343
  continue
 
 
 
344
 
345
- bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
346
-
347
- current_detections.append({
348
  "frame": frame_count,
349
  "violation": label,
350
  "confidence": round(conf, 2),
351
- "bounding_box": bbox,
352
- "timestamp": current_time
353
- })
354
-
355
- # Process detections and associate with workers
356
- for detection in current_detections:
357
- # Find matching worker
358
- matched_worker = None
359
- max_iou = 0
360
-
361
- for worker in workers:
362
- iou = calculate_iou(detection["bounding_box"], worker["bbox"])
363
- if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
364
- max_iou = iou
365
- matched_worker = worker
366
-
367
- if matched_worker:
368
- # Update worker's position
369
- matched_worker["bbox"] = detection["bounding_box"]
370
- matched_worker["last_seen"] = current_time
371
- worker_id = matched_worker["id"]
372
- else:
373
- # New worker
374
- worker_id = len(workers) + 1
375
- workers.append({
376
- "id": worker_id,
377
- "bbox": detection["bounding_box"],
378
- "first_seen": current_time,
379
- "last_seen": current_time
380
- })
381
-
382
- # Add to violation history
383
- detection["worker_id"] = worker_id
384
- violation_history[detection["violation"]].append(detection)
385
 
386
  frame_count += 1
387
 
388
  video.release()
389
  os.remove(video_path)
390
-
391
- # Process violation history to confirm persistent violations
392
- for violation_type, detections in violation_history.items():
393
- if not detections:
394
- continue
395
-
396
- # Group by worker
397
- worker_violations = {}
398
- for det in detections:
399
- if det["worker_id"] not in worker_violations:
400
- worker_violations[det["worker_id"]] = []
401
- worker_violations[det["worker_id"]].append(det)
402
-
403
- # Check each worker's violations for persistence
404
- for worker_id, worker_dets in worker_violations.items():
405
- if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
406
- # Take the highest confidence detection
407
- best_detection = max(worker_dets, key=lambda x: x["confidence"])
408
- violations.append(best_detection)
409
-
410
- # Capture snapshot if not already taken
411
- if not snapshot_taken[violation_type]:
412
- # Get the frame for this violation
413
- cap = cv2.VideoCapture(video_path)
414
- cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
415
- ret, snapshot_frame = cap.read()
416
- cap.release()
417
-
418
- if ret:
419
- # Draw detections on snapshot
420
- snapshot_frame = draw_detections(snapshot_frame, [best_detection])
421
-
422
- snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
423
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
424
- cv2.imwrite(snapshot_path, snapshot_frame)
425
- snapshots.append({
426
- "violation": violation_type,
427
- "frame": best_detection["frame"],
428
- "snapshot_path": snapshot_path,
429
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
430
- })
431
- snapshot_taken[violation_type] = True
432
-
433
- # Final processing
434
  if not violations:
435
- logger.info("No persistent violations detected")
436
  return {
437
  "violations": [],
438
  "snapshots": [],
439
  "score": 100,
440
  "salesforce_record_id": None,
441
- "violation_details_url": "",
442
- "message": "No violations detected in the video."
443
  }
444
 
445
  score = calculate_safety_score(violations)
@@ -451,18 +325,16 @@ def process_video(video_data):
451
  "snapshots": snapshots,
452
  "score": score,
453
  "salesforce_record_id": report_id,
454
- "violation_details_url": final_pdf_url,
455
- "message": ""
456
  }
457
  except Exception as e:
458
- logger.error(f"Error processing video: {e}", exc_info=True)
459
  return {
460
  "violations": [],
461
  "snapshots": [],
462
  "score": 100,
463
  "salesforce_record_id": None,
464
- "violation_details_url": "",
465
- "message": f"Error processing video: {e}"
466
  }
467
 
468
  # ==========================
@@ -472,38 +344,29 @@ def gradio_interface(video_file):
472
  if not video_file:
473
  return "No file uploaded.", "", "No file uploaded.", "", ""
474
  try:
475
- yield "Processing video... please wait.", "", "", "", ""
476
-
477
  with open(video_file, "rb") as f:
478
  video_data = f.read()
479
-
480
  result = process_video(video_data)
481
 
482
- if result.get("message"):
483
- yield result["message"], "", "", "", ""
484
- return
485
-
486
  violation_table = "No violations detected."
487
  if result["violations"]:
488
- header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
489
- separator = "|------------------------|---------------|------------|-----------|\n"
490
  rows = []
491
- violation_name_map = CONFIG["DISPLAY_NAMES"]
492
  for v in result["violations"]:
493
- display_name = violation_name_map.get(v["violation"], v["violation"])
494
- row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
495
  rows.append(row)
496
  violation_table = header + separator + "\n".join(rows)
497
 
498
  snapshots_text = "No snapshots captured."
499
  if result["snapshots"]:
500
- violation_name_map = CONFIG["DISPLAY_NAMES"]
501
  snapshots_text = "\n".join(
502
- f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
503
  for s in result["snapshots"]
504
  )
505
 
506
- yield (
507
  violation_table,
508
  f"Safety Score: {result['score']}%",
509
  snapshots_text,
@@ -511,13 +374,13 @@ def gradio_interface(video_file):
511
  result["violation_details_url"] or "N/A"
512
  )
513
  except Exception as e:
514
- logger.error(f"Error in Gradio interface: {e}", exc_info=True)
515
- yield f"Error: {str(e)}", "", "Error in processing.", "", ""
516
 
517
  interface = gr.Interface(
518
  fn=gradio_interface,
519
  inputs=gr.Video(label="Upload Site Video"),
520
- outputs=[
521
  gr.Markdown(label="Detected Safety Violations"),
522
  gr.Textbox(label="Compliance Score"),
523
  gr.Markdown(label="Snapshots"),
@@ -525,10 +388,9 @@ interface = gr.Interface(
525
  gr.Textbox(label="Violation Details URL")
526
  ],
527
  title="Worksite Safety Violation Analyzer",
528
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
529
- allow_flagging="never"
530
  )
531
 
532
  if __name__ == "__main__":
533
- logger.info("Launching Enhanced Safety Analyzer App...")
534
- interface.launch()
 
15
  from retrying import retry
16
 
17
  # ==========================
18
+ # Configuration
19
  # ==========================
20
  CONFIG = {
21
+ "MODEL_PATH": "yolov8_safety.pt", # Custom-trained model for specific violations
22
+ "FALLBACK_MODEL_PATH": "yolov8n.pt", # Fallback if custom model is missing
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
  1: "no_harness",
27
+ 2: "unsafe_posture"
 
 
28
  },
29
+ "DISPLAY_NAMES": { # Mapping for user-friendly violation names
30
+ "no_helmet": "Missing Helmet",
31
+ "no_harness": "Missing Harness",
32
+ "unsafe_posture": "Unsafe Posture"
 
 
 
 
 
 
 
 
 
33
  },
34
  "SF_CREDENTIALS": {
35
+ "username": "your_username@safety.com",
36
+ "password": "your_password",
37
+ "security_token": "your_security_token",
38
+ "domain": "login" # Use "test" for sandbox
39
  },
40
+ "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
41
+ "FRAME_SKIP": 15, # Process every 15th frame
42
+ "MAX_PROCESSING_TIME": 25, # Cap video processing at 25s
43
+ "CONFIDENCE_THRESHOLD": 0.5 # Minimum confidence for violation detection
 
 
44
  }
45
 
46
  # Setup logging
47
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
48
  logger = logging.getLogger(__name__)
49
 
50
+ # Ensure output directory exists
51
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
52
 
53
+ # ==========================
54
+ # Device Setup
55
+ # ==========================
56
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
  logger.info(f"Using device: {device}")
58
 
59
+ # ==========================
60
+ # Model Loading
61
+ # ==========================
62
  def load_model():
63
  try:
64
+ model_path = CONFIG["MODEL_PATH"]
65
+ if not os.path.exists(model_path):
66
+ logger.warning(f"Custom model {model_path} not found. Falling back to {CONFIG['FALLBACK_MODEL_PATH']}")
67
+ model_path = CONFIG["FALLBACK_MODEL_PATH"]
 
 
 
 
 
68
  model = YOLO(model_path).to(device)
69
+ logger.info(f"Model loaded: {model_path}")
70
+ if model_path == CONFIG["FALLBACK_MODEL_PATH"]:
71
+ logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
72
  return model
73
  except Exception as e:
74
  logger.error(f"Failed to load model: {e}")
 
77
  model = load_model()
78
 
79
  # ==========================
80
+ # Salesforce Integration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # ==========================
82
+ @retry(stop_max_attempt_number=2, wait_fixed=1000)
83
  def connect_to_salesforce():
84
  try:
85
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
 
139
  logger.error(f"Error generating PDF: {e}")
140
  return "", "", None
141
 
142
+ @retry(stop_max_attempt_number=2, wait_fixed=1000)
143
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
144
  try:
145
  if not pdf_file:
 
153
  "FirstPublishLocationId": report_id
154
  }
155
  content_version = sf.ContentVersion.create(content_version_data)
156
+ result = sf.query(f"SELECT Id FROM ContentVersion WHERE Id = '{content_version['id']}'")
157
  if not result['records']:
158
  logger.error("Failed to retrieve ContentVersion")
159
  return ""
 
164
  logger.error(f"Error uploading PDF to Salesforce: {e}")
165
  return ""
166
 
167
+ @retry(stop_max_attempt_number=2, wait_fixed=1000)
168
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
169
  try:
170
  sf = connect_to_salesforce()
 
183
  }
184
  logger.info(f"Creating Salesforce record with data: {record_data}")
185
  try:
186
+ record = sf.Safety_Violation_Report__c.create(record_data)
187
+ logger.info(f"Created Safety_Violation_Report__c record: {record['id']}")
188
  except Exception as e:
189
+ logger.error(f"Failed to create Safety_Violation_Report__c: {e}")
190
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
191
  logger.warning(f"Fell back to Account record: {record['id']}")
192
  record_id = record["id"]
 
195
  uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
196
  if uploaded_url:
197
  try:
198
+ sf.Safety_Violation_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
199
  logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
200
  except Exception as e:
201
+ logger.error(f"Failed to update Safety_Violation_Report__c: {e}")
202
  sf.Account.update(record_id, {"Description": uploaded_url})
203
  logger.info(f"Updated Account record {record_id} with PDF URL")
204
  pdf_url = uploaded_url
205
 
206
  return record_id, pdf_url
207
  except Exception as e:
208
+ logger.error(f"Salesforce record creation failed: {e}")
209
  return None, ""
210
 
211
+ # ==========================
212
+ # Safety Score Calculation
213
+ # ==========================
214
  def calculate_safety_score(violations):
215
  penalties = {
216
  "no_helmet": 25,
217
  "no_harness": 30,
218
+ "unsafe_posture": 20
 
 
219
  }
220
+ score = 100
 
221
  for v in violations:
222
+ if v["violation"] in penalties:
223
+ score -= penalties[v["violation"]]
 
 
 
224
  return max(score, 0)
225
 
226
  # ==========================
227
+ # Video Processing
228
  # ==========================
229
  def process_video(video_data):
230
  try:
 
237
  if not video.isOpened():
238
  raise ValueError("Could not open video file")
239
 
240
+ violations, snapshots = [], []
 
241
  frame_count = 0
242
  start_time = time.time()
243
  fps = video.get(cv2.CAP_PROP_FPS)
244
+ max_frames = int(60 * fps) # Process up to 1 minute
 
 
 
 
 
 
245
 
246
+ # Track one snapshot per violation type
247
+ snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
248
 
249
  while True:
250
  ret, frame = video.read()
251
+ if not ret or frame_count >= max_frames:
252
  break
253
 
254
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
255
  frame_count += 1
256
  continue
257
 
258
+ # Stop if processing time exceeds 25 seconds
259
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
260
  logger.info("Processing time limit reached")
261
  break
262
 
 
 
 
263
  results = model(frame, device=device)
264
+ seen_violations = set()
 
265
  for result in results:
266
+ for box in result.boxes:
267
+ cls, conf = int(box.cls), float(box.conf)
268
+ label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
269
+ # Only process specified violations
270
+ if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
271
+ logger.warning(f"Unexpected detection: {label} (cls: {cls}, conf: {conf}) - ignored")
 
272
  continue
273
+ # Apply confidence threshold
274
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
275
+ logger.info(f"Skipping low-confidence detection: {label} (conf: {conf})")
276
  continue
277
+ if label in seen_violations:
278
+ continue
279
+ seen_violations.add(label)
280
 
281
+ violation = {
 
 
282
  "frame": frame_count,
283
  "violation": label,
284
  "confidence": round(conf, 2),
285
+ "bounding_box": [round(x, 2) for x in box.xywh.cpu().numpy()[0]],
286
+ "timestamp": frame_count / fps
287
+ }
288
+ violations.append(violation)
289
+
290
+ # Save only one snapshot per violation type
291
+ if not snapshot_taken[label]:
292
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
293
+ cv2.imwrite(snapshot_path, frame)
294
+ with open(snapshot_path, "rb") as img_file:
295
+ img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
296
+ snapshots.append({
297
+ "violation": label,
298
+ "frame": frame_count,
299
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
300
+ "snapshot_base64": f"data:image/jpeg;base64,{img_base64}"
301
+ })
302
+ snapshot_taken[label] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
  frame_count += 1
305
 
306
  video.release()
307
  os.remove(video_path)
308
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  if not violations:
310
+ logger.info("No violations detected")
311
  return {
312
  "violations": [],
313
  "snapshots": [],
314
  "score": 100,
315
  "salesforce_record_id": None,
316
+ "violation_details_url": ""
 
317
  }
318
 
319
  score = calculate_safety_score(violations)
 
325
  "snapshots": snapshots,
326
  "score": score,
327
  "salesforce_record_id": report_id,
328
+ "violation_details_url": final_pdf_url
 
329
  }
330
  except Exception as e:
331
+ logger.error(f"Error processing video: {e}")
332
  return {
333
  "violations": [],
334
  "snapshots": [],
335
  "score": 100,
336
  "salesforce_record_id": None,
337
+ "violation_details_url": ""
 
338
  }
339
 
340
  # ==========================
 
344
  if not video_file:
345
  return "No file uploaded.", "", "No file uploaded.", "", ""
346
  try:
 
 
347
  with open(video_file, "rb") as f:
348
  video_data = f.read()
 
349
  result = process_video(video_data)
350
 
 
 
 
 
351
  violation_table = "No violations detected."
352
  if result["violations"]:
353
+ header = "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |\n"
354
+ separator = "|------------------|-----------|------------|--------------------------|-------------------------|\n"
355
  rows = []
 
356
  for v in result["violations"]:
357
+ display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
358
+ row = f"| {display_name:<16} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
359
  rows.append(row)
360
  violation_table = header + separator + "\n".join(rows)
361
 
362
  snapshots_text = "No snapshots captured."
363
  if result["snapshots"]:
 
364
  snapshots_text = "\n".join(
365
+ f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
366
  for s in result["snapshots"]
367
  )
368
 
369
+ return (
370
  violation_table,
371
  f"Safety Score: {result['score']}%",
372
  snapshots_text,
 
374
  result["violation_details_url"] or "N/A"
375
  )
376
  except Exception as e:
377
+ logger.error(f"Error in Gradio interface: {e}")
378
+ return f"Error: {str(e)}", "", "Error in processing.", "", ""
379
 
380
  interface = gr.Interface(
381
  fn=gradio_interface,
382
  inputs=gr.Video(label="Upload Site Video"),
383
+ outputs=[
384
  gr.Markdown(label="Detected Safety Violations"),
385
  gr.Textbox(label="Compliance Score"),
386
  gr.Markdown(label="Snapshots"),
 
388
  gr.Textbox(label="Violation Details URL")
389
  ],
390
  title="Worksite Safety Violation Analyzer",
391
+ description="Upload site videos to detect safety violations (Missing Helmet, Missing Harness, Unsafe Posture). Non-violations are ignored."
 
392
  )
393
 
394
  if __name__ == "__main__":
395
+ logger.info("Launching Safety Analyzer App...")
396
+ interface.launch()