PrashanthB461 commited on
Commit
8aef0a6
·
verified ·
1 Parent(s): 11ea390

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +242 -104
app.py CHANGED
@@ -15,60 +15,68 @@ import logging
15
  from retrying import retry
16
 
17
  # ==========================
18
- # Configuration
19
  # ==========================
20
  CONFIG = {
21
- "MODEL_PATH": "yolov8_safety.pt", # Custom-trained model for specific violations
22
- "FALLBACK_MODEL_PATH": "yolov8n.pt", # Fallback if custom model is missing
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
  1: "no_harness",
27
- 2: "unsafe_posture"
 
 
28
  },
29
- "DISPLAY_NAMES": { # Mapping for user-friendly violation names
30
- "no_helmet": "Missing Helmet",
31
- "no_harness": "Missing Harness",
32
- "unsafe_posture": "Unsafe Posture"
 
 
 
 
 
 
 
 
 
33
  },
34
  "SF_CREDENTIALS": {
35
- "username": "your_username@safety.com",
36
- "password": "your_password",
37
- "security_token": "your_security_token",
38
- "domain": "login" # Use "test" for sandbox
39
  },
40
- "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
41
- "FRAME_SKIP": 15, # Process every 15th frame
42
- "MAX_PROCESSING_TIME": 25, # Cap video processing at 25s
43
- "CONFIDENCE_THRESHOLD": 0.5 # Minimum confidence for violation detection
 
 
44
  }
45
 
46
  # Setup logging
47
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
48
  logger = logging.getLogger(__name__)
49
 
50
- # Ensure output directory exists
51
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
52
 
53
- # ==========================
54
- # Device Setup
55
- # ==========================
56
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
  logger.info(f"Using device: {device}")
58
 
59
- # ==========================
60
- # Model Loading
61
- # ==========================
62
  def load_model():
63
  try:
64
- model_path = CONFIG["MODEL_PATH"]
65
- if not os.path.exists(model_path):
66
- logger.warning(f"Custom model {model_path} not found. Falling back to {CONFIG['FALLBACK_MODEL_PATH']}")
67
- model_path = CONFIG["FALLBACK_MODEL_PATH"]
68
- model = YOLO(model_path).to(device)
69
- logger.info(f"Model loaded: {model_path}")
70
- if model_path == CONFIG["FALLBACK_MODEL_PATH"]:
71
  logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
 
 
 
 
72
  return model
73
  except Exception as e:
74
  logger.error(f"Failed to load model: {e}")
@@ -77,9 +85,57 @@ def load_model():
77
  model = load_model()
78
 
79
  # ==========================
80
- # Salesforce Integration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # ==========================
82
- @retry(stop_max_attempt_number=2, wait_fixed=1000)
83
  def connect_to_salesforce():
84
  try:
85
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
@@ -139,7 +195,6 @@ def generate_violation_pdf(violations, score):
139
  logger.error(f"Error generating PDF: {e}")
140
  return "", "", None
141
 
142
- @retry(stop_max_attempt_number=2, wait_fixed=1000)
143
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
144
  try:
145
  if not pdf_file:
@@ -153,7 +208,7 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
153
  "FirstPublishLocationId": report_id
154
  }
155
  content_version = sf.ContentVersion.create(content_version_data)
156
- result = sf.query(f"SELECT Id FROM ContentVersion WHERE Id = '{content_version['id']}'")
157
  if not result['records']:
158
  logger.error("Failed to retrieve ContentVersion")
159
  return ""
@@ -164,7 +219,6 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
164
  logger.error(f"Error uploading PDF to Salesforce: {e}")
165
  return ""
166
 
167
- @retry(stop_max_attempt_number=2, wait_fixed=1000)
168
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
169
  try:
170
  sf = connect_to_salesforce()
@@ -183,10 +237,10 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
183
  }
184
  logger.info(f"Creating Salesforce record with data: {record_data}")
185
  try:
186
- record = sf.Safety_Violation_Report__c.create(record_data)
187
- logger.info(f"Created Safety_Violation_Report__c record: {record['id']}")
188
  except Exception as e:
189
- logger.error(f"Failed to create Safety_Violation_Report__c: {e}")
190
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
191
  logger.warning(f"Fell back to Account record: {record['id']}")
192
  record_id = record["id"]
@@ -195,36 +249,39 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
195
  uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
196
  if uploaded_url:
197
  try:
198
- sf.Safety_Violation_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
199
  logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
200
  except Exception as e:
201
- logger.error(f"Failed to update Safety_Violation_Report__c: {e}")
202
  sf.Account.update(record_id, {"Description": uploaded_url})
203
  logger.info(f"Updated Account record {record_id} with PDF URL")
204
  pdf_url = uploaded_url
205
 
206
  return record_id, pdf_url
207
  except Exception as e:
208
- logger.error(f"Salesforce record creation failed: {e}")
209
  return None, ""
210
 
211
- # ==========================
212
- # Safety Score Calculation
213
- # ==========================
214
  def calculate_safety_score(violations):
215
  penalties = {
216
  "no_helmet": 25,
217
  "no_harness": 30,
218
- "unsafe_posture": 20
 
 
219
  }
220
- score = 100
 
221
  for v in violations:
222
- if v["violation"] in penalties:
223
- score -= penalties[v["violation"]]
 
 
 
224
  return max(score, 0)
225
 
226
  # ==========================
227
- # Video Processing
228
  # ==========================
229
  def process_video(video_data):
230
  try:
@@ -237,83 +294,152 @@ def process_video(video_data):
237
  if not video.isOpened():
238
  raise ValueError("Could not open video file")
239
 
240
- violations, snapshots = [], []
 
241
  frame_count = 0
242
  start_time = time.time()
243
  fps = video.get(cv2.CAP_PROP_FPS)
244
- max_frames = int(60 * fps) # Process up to 1 minute
 
 
 
 
 
 
245
 
246
- # Track one snapshot per violation type
247
- snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
248
 
249
  while True:
250
  ret, frame = video.read()
251
- if not ret or frame_count >= max_frames:
252
  break
253
 
254
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
255
  frame_count += 1
256
  continue
257
 
258
- # Stop if processing time exceeds 25 seconds
259
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
260
  logger.info("Processing time limit reached")
261
  break
262
 
 
 
 
263
  results = model(frame, device=device)
264
- seen_violations = set()
 
265
  for result in results:
266
- for box in result.boxes:
267
- cls, conf = int(box.cls), float(box.conf)
268
- label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
269
- # Only process specified violations
270
- if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
271
- logger.warning(f"Unexpected detection: {label} (cls: {cls}, conf: {conf}) - ignored")
 
272
  continue
273
- # Apply confidence threshold
274
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
275
- logger.info(f"Skipping low-confidence detection: {label} (conf: {conf})")
276
  continue
277
- if label in seen_violations:
278
- continue
279
- seen_violations.add(label)
280
 
281
- violation = {
 
 
282
  "frame": frame_count,
283
  "violation": label,
284
  "confidence": round(conf, 2),
285
- "bounding_box": [round(x, 2) for x in box.xywh.cpu().numpy()[0]],
286
- "timestamp": frame_count / fps
287
- }
288
- violations.append(violation)
289
-
290
- # Save only one snapshot per violation type
291
- if not snapshot_taken[label]:
292
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
293
- cv2.imwrite(snapshot_path, frame)
294
- with open(snapshot_path, "rb") as img_file:
295
- img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
296
- snapshots.append({
297
- "violation": label,
298
- "frame": frame_count,
299
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
300
- "snapshot_base64": f"data:image/jpeg;base64,{img_base64}"
301
- })
302
- snapshot_taken[label] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
  frame_count += 1
305
 
306
  video.release()
307
  os.remove(video_path)
308
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  if not violations:
310
- logger.info("No violations detected")
311
  return {
312
  "violations": [],
313
  "snapshots": [],
314
  "score": 100,
315
  "salesforce_record_id": None,
316
- "violation_details_url": ""
 
317
  }
318
 
319
  score = calculate_safety_score(violations)
@@ -325,16 +451,18 @@ def process_video(video_data):
325
  "snapshots": snapshots,
326
  "score": score,
327
  "salesforce_record_id": report_id,
328
- "violation_details_url": final_pdf_url
 
329
  }
330
  except Exception as e:
331
- logger.error(f"Error processing video: {e}")
332
  return {
333
  "violations": [],
334
  "snapshots": [],
335
  "score": 100,
336
  "salesforce_record_id": None,
337
- "violation_details_url": ""
 
338
  }
339
 
340
  # ==========================
@@ -344,29 +472,38 @@ def gradio_interface(video_file):
344
  if not video_file:
345
  return "No file uploaded.", "", "No file uploaded.", "", ""
346
  try:
 
 
347
  with open(video_file, "rb") as f:
348
  video_data = f.read()
 
349
  result = process_video(video_data)
350
 
 
 
 
 
351
  violation_table = "No violations detected."
352
  if result["violations"]:
353
- header = "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |\n"
354
- separator = "|------------------|-----------|------------|--------------------------|-------------------------|\n"
355
  rows = []
 
356
  for v in result["violations"]:
357
- display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
358
- row = f"| {display_name:<16} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
359
  rows.append(row)
360
  violation_table = header + separator + "\n".join(rows)
361
 
362
  snapshots_text = "No snapshots captured."
363
  if result["snapshots"]:
 
364
  snapshots_text = "\n".join(
365
- f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
366
  for s in result["snapshots"]
367
  )
368
 
369
- return (
370
  violation_table,
371
  f"Safety Score: {result['score']}%",
372
  snapshots_text,
@@ -374,13 +511,13 @@ def gradio_interface(video_file):
374
  result["violation_details_url"] or "N/A"
375
  )
376
  except Exception as e:
377
- logger.error(f"Error in Gradio interface: {e}")
378
- return f"Error: {str(e)}", "", "Error in processing.", "", ""
379
 
380
  interface = gr.Interface(
381
  fn=gradio_interface,
382
  inputs=gr.Video(label="Upload Site Video"),
383
- outputs=[
384
  gr.Markdown(label="Detected Safety Violations"),
385
  gr.Textbox(label="Compliance Score"),
386
  gr.Markdown(label="Snapshots"),
@@ -388,9 +525,10 @@ interface = gr.Interface(
388
  gr.Textbox(label="Violation Details URL")
389
  ],
390
  title="Worksite Safety Violation Analyzer",
391
- description="Upload site videos to detect safety violations (Missing Helmet, Missing Harness, Unsafe Posture). Non-violations are ignored."
 
392
  )
393
 
394
  if __name__ == "__main__":
395
- logger.info("Launching Safety Analyzer App...")
396
- interface.launch()
 
15
  from retrying import retry
16
 
17
  # ==========================
18
+ # Enhanced Configuration
19
  # ==========================
20
  CONFIG = {
21
+ "MODEL_PATH": "yolov8_safety.pt",
22
+ "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
  1: "no_harness",
27
+ 2: "unsafe_posture",
28
+ 3: "unsafe_zone",
29
+ 4: "improper_tool_use"
30
  },
31
+ "CLASS_COLORS": {
32
+ "no_helmet": (0, 0, 255), # Red
33
+ "no_harness": (0, 165, 255), # Orange
34
+ "unsafe_posture": (0, 255, 0), # Green
35
+ "unsafe_zone": (255, 0, 0), # Blue
36
+ "improper_tool_use": (255, 255, 0) # Yellow
37
+ },
38
+ "DISPLAY_NAMES": {
39
+ "no_helmet": "No Helmet Violation",
40
+ "no_harness": "No Harness Violation",
41
+ "unsafe_posture": "Unsafe Posture Violation",
42
+ "unsafe_zone": "Unsafe Zone Entry",
43
+ "improper_tool_use": "Improper Tool Use"
44
  },
45
  "SF_CREDENTIALS": {
46
+ "username": "prashanth1ai@safety.com",
47
+ "password": "SaiPrash461",
48
+ "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
49
+ "domain": "login"
50
  },
51
+ "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
+ "FRAME_SKIP": 5, # Reduced for better detection
53
+ "MAX_PROCESSING_TIME": 60,
54
+ "CONFIDENCE_THRESHOLD": 0.25, # Lower threshold for all violations
55
+ "IOU_THRESHOLD": 0.4,
56
+ "MIN_VIOLATION_FRAMES": 3 # Minimum consecutive frames to confirm violation
57
  }
58
 
59
  # Setup logging
60
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
61
  logger = logging.getLogger(__name__)
62
 
 
63
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
64
 
 
 
 
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  logger.info(f"Using device: {device}")
67
 
 
 
 
68
  def load_model():
69
  try:
70
+ if os.path.isfile(CONFIG["MODEL_PATH"]):
71
+ model_path = CONFIG["MODEL_PATH"]
72
+ logger.info(f"Model loaded: {model_path}")
73
+ else:
74
+ model_path = CONFIG["FALLBACK_MODEL"]
 
 
75
  logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
76
+ if not os.path.isfile(model_path):
77
+ logger.info(f"Downloading fallback model: {model_path}")
78
+ torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
79
+ model = YOLO(model_path).to(device)
80
  return model
81
  except Exception as e:
82
  logger.error(f"Failed to load model: {e}")
 
85
  model = load_model()
86
 
87
  # ==========================
88
+ # Enhanced Helper Functions
89
+ # ==========================
90
+ def draw_detections(frame, detections):
91
+ """Draw bounding boxes and labels on frame"""
92
+ for det in detections:
93
+ label = det["violation"]
94
+ confidence = det["confidence"]
95
+ x, y, w, h = det["bounding_box"]
96
+
97
+ # Convert from center coordinates to corner coordinates
98
+ x1 = int(x - w/2)
99
+ y1 = int(y - h/2)
100
+ x2 = int(x + w/2)
101
+ y2 = int(y + h/2)
102
+
103
+ color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
104
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
105
+
106
+ display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
107
+ cv2.putText(frame, display_text, (x1, y1-10),
108
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
109
+ return frame
110
+
111
+ def calculate_iou(box1, box2):
112
+ """Calculate Intersection over Union (IoU) for two bounding boxes."""
113
+ x1, y1, w1, h1 = box1
114
+ x2, y2, w2, h2 = box2
115
+
116
+ # Convert to top-left and bottom-right coordinates
117
+ x1_min, y1_min = x1 - w1/2, y1 - h1/2
118
+ x1_max, y1_max = x1 + w1/2, y1 + h1/2
119
+ x2_min, y2_min = x2 - w2/2, y2 - h2/2
120
+ x2_max, y2_max = x2 + w2/2, y2 + h2/2
121
+
122
+ # Calculate intersection
123
+ x_min = max(x1_min, x2_min)
124
+ y_min = max(y1_min, y2_min)
125
+ x_max = min(x1_max, x2_max)
126
+ y_max = min(y1_max, y2_max)
127
+
128
+ intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
129
+ area1 = w1 * h1
130
+ area2 = w2 * h2
131
+ union = area1 + area2 - intersection
132
+
133
+ return intersection / union if union > 0 else 0
134
+
135
+ # ==========================
136
+ # Salesforce Integration (unchanged)
137
  # ==========================
138
+ @retry(stop_max_attempt_number=3, wait_fixed=2000)
139
  def connect_to_salesforce():
140
  try:
141
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
 
195
  logger.error(f"Error generating PDF: {e}")
196
  return "", "", None
197
 
 
198
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
199
  try:
200
  if not pdf_file:
 
208
  "FirstPublishLocationId": report_id
209
  }
210
  content_version = sf.ContentVersion.create(content_version_data)
211
+ result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
212
  if not result['records']:
213
  logger.error("Failed to retrieve ContentVersion")
214
  return ""
 
219
  logger.error(f"Error uploading PDF to Salesforce: {e}")
220
  return ""
221
 
 
222
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
223
  try:
224
  sf = connect_to_salesforce()
 
237
  }
238
  logger.info(f"Creating Salesforce record with data: {record_data}")
239
  try:
240
+ record = sf.Safety_Video_Report__c.create(record_data)
241
+ logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
242
  except Exception as e:
243
+ logger.error(f"Failed to create Safety_Video_Report__c: {e}")
244
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
245
  logger.warning(f"Fell back to Account record: {record['id']}")
246
  record_id = record["id"]
 
249
  uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
250
  if uploaded_url:
251
  try:
252
+ sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
253
  logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
254
  except Exception as e:
255
+ logger.error(f"Failed to update Safety_Video_Report__c: {e}")
256
  sf.Account.update(record_id, {"Description": uploaded_url})
257
  logger.info(f"Updated Account record {record_id} with PDF URL")
258
  pdf_url = uploaded_url
259
 
260
  return record_id, pdf_url
261
  except Exception as e:
262
+ logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
263
  return None, ""
264
 
 
 
 
265
  def calculate_safety_score(violations):
266
  penalties = {
267
  "no_helmet": 25,
268
  "no_harness": 30,
269
+ "unsafe_posture": 20,
270
+ "unsafe_zone": 35,
271
+ "improper_tool_use": 25
272
  }
273
+ # Count unique violations per worker
274
+ unique_violations = set()
275
  for v in violations:
276
+ key = (v["worker_id"], v["violation"])
277
+ unique_violations.add(key)
278
+
279
+ total_penalty = sum(penalties.get(violation, 0) for _, violation in unique_violations)
280
+ score = 100 - total_penalty
281
  return max(score, 0)
282
 
283
  # ==========================
284
+ # Enhanced Video Processing
285
  # ==========================
286
  def process_video(video_data):
287
  try:
 
294
  if not video.isOpened():
295
  raise ValueError("Could not open video file")
296
 
297
+ violations = []
298
+ snapshots = []
299
  frame_count = 0
300
  start_time = time.time()
301
  fps = video.get(cv2.CAP_PROP_FPS)
302
+ if fps <= 0:
303
+ fps = 30 # Default assumption if FPS cannot be determined
304
+
305
+ # Structure to track workers and their violations
306
+ workers = []
307
+ violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
308
+ snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
309
 
310
+ logger.info(f"Processing video with FPS: {fps}")
311
+ logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
312
 
313
  while True:
314
  ret, frame = video.read()
315
+ if not ret:
316
  break
317
 
318
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
319
  frame_count += 1
320
  continue
321
 
 
322
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
323
  logger.info("Processing time limit reached")
324
  break
325
 
326
+ current_time = frame_count / fps
327
+
328
+ # Run detection on this frame
329
  results = model(frame, device=device)
330
+
331
+ current_detections = []
332
  for result in results:
333
+ boxes = result.boxes
334
+ for box in boxes:
335
+ cls = int(box.cls)
336
+ conf = float(box.conf)
337
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
338
+
339
+ if label is None:
340
  continue
341
+
342
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
 
343
  continue
 
 
 
344
 
345
+ bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
346
+
347
+ current_detections.append({
348
  "frame": frame_count,
349
  "violation": label,
350
  "confidence": round(conf, 2),
351
+ "bounding_box": bbox,
352
+ "timestamp": current_time
353
+ })
354
+
355
+ # Process detections and associate with workers
356
+ for detection in current_detections:
357
+ # Find matching worker
358
+ matched_worker = None
359
+ max_iou = 0
360
+
361
+ for worker in workers:
362
+ iou = calculate_iou(detection["bounding_box"], worker["bbox"])
363
+ if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
364
+ max_iou = iou
365
+ matched_worker = worker
366
+
367
+ if matched_worker:
368
+ # Update worker's position
369
+ matched_worker["bbox"] = detection["bounding_box"]
370
+ matched_worker["last_seen"] = current_time
371
+ worker_id = matched_worker["id"]
372
+ else:
373
+ # New worker
374
+ worker_id = len(workers) + 1
375
+ workers.append({
376
+ "id": worker_id,
377
+ "bbox": detection["bounding_box"],
378
+ "first_seen": current_time,
379
+ "last_seen": current_time
380
+ })
381
+
382
+ # Add to violation history
383
+ detection["worker_id"] = worker_id
384
+ violation_history[detection["violation"]].append(detection)
385
 
386
  frame_count += 1
387
 
388
  video.release()
389
  os.remove(video_path)
390
+
391
+ # Process violation history to confirm persistent violations
392
+ for violation_type, detections in violation_history.items():
393
+ if not detections:
394
+ continue
395
+
396
+ # Group by worker
397
+ worker_violations = {}
398
+ for det in detections:
399
+ if det["worker_id"] not in worker_violations:
400
+ worker_violations[det["worker_id"]] = []
401
+ worker_violations[det["worker_id"]].append(det)
402
+
403
+ # Check each worker's violations for persistence
404
+ for worker_id, worker_dets in worker_violations.items():
405
+ if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
406
+ # Take the highest confidence detection
407
+ best_detection = max(worker_dets, key=lambda x: x["confidence"])
408
+ violations.append(best_detection)
409
+
410
+ # Capture snapshot if not already taken
411
+ if not snapshot_taken[violation_type]:
412
+ # Get the frame for this violation
413
+ cap = cv2.VideoCapture(video_path)
414
+ cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
415
+ ret, snapshot_frame = cap.read()
416
+ cap.release()
417
+
418
+ if ret:
419
+ # Draw detections on snapshot
420
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
421
+
422
+ snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
423
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
424
+ cv2.imwrite(snapshot_path, snapshot_frame)
425
+ snapshots.append({
426
+ "violation": violation_type,
427
+ "frame": best_detection["frame"],
428
+ "snapshot_path": snapshot_path,
429
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
430
+ })
431
+ snapshot_taken[violation_type] = True
432
+
433
+ # Final processing
434
  if not violations:
435
+ logger.info("No persistent violations detected")
436
  return {
437
  "violations": [],
438
  "snapshots": [],
439
  "score": 100,
440
  "salesforce_record_id": None,
441
+ "violation_details_url": "",
442
+ "message": "No violations detected in the video."
443
  }
444
 
445
  score = calculate_safety_score(violations)
 
451
  "snapshots": snapshots,
452
  "score": score,
453
  "salesforce_record_id": report_id,
454
+ "violation_details_url": final_pdf_url,
455
+ "message": ""
456
  }
457
  except Exception as e:
458
+ logger.error(f"Error processing video: {e}", exc_info=True)
459
  return {
460
  "violations": [],
461
  "snapshots": [],
462
  "score": 100,
463
  "salesforce_record_id": None,
464
+ "violation_details_url": "",
465
+ "message": f"Error processing video: {e}"
466
  }
467
 
468
  # ==========================
 
472
  if not video_file:
473
  return "No file uploaded.", "", "No file uploaded.", "", ""
474
  try:
475
+ yield "Processing video... please wait.", "", "", "", ""
476
+
477
  with open(video_file, "rb") as f:
478
  video_data = f.read()
479
+
480
  result = process_video(video_data)
481
 
482
+ if result.get("message"):
483
+ yield result["message"], "", "", "", ""
484
+ return
485
+
486
  violation_table = "No violations detected."
487
  if result["violations"]:
488
+ header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
489
+ separator = "|------------------------|---------------|------------|-----------|\n"
490
  rows = []
491
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
492
  for v in result["violations"]:
493
+ display_name = violation_name_map.get(v["violation"], v["violation"])
494
+ row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
495
  rows.append(row)
496
  violation_table = header + separator + "\n".join(rows)
497
 
498
  snapshots_text = "No snapshots captured."
499
  if result["snapshots"]:
500
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
501
  snapshots_text = "\n".join(
502
+ f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
503
  for s in result["snapshots"]
504
  )
505
 
506
+ yield (
507
  violation_table,
508
  f"Safety Score: {result['score']}%",
509
  snapshots_text,
 
511
  result["violation_details_url"] or "N/A"
512
  )
513
  except Exception as e:
514
+ logger.error(f"Error in Gradio interface: {e}", exc_info=True)
515
+ yield f"Error: {str(e)}", "", "Error in processing.", "", ""
516
 
517
  interface = gr.Interface(
518
  fn=gradio_interface,
519
  inputs=gr.Video(label="Upload Site Video"),
520
+ outputs=[
521
  gr.Markdown(label="Detected Safety Violations"),
522
  gr.Textbox(label="Compliance Score"),
523
  gr.Markdown(label="Snapshots"),
 
525
  gr.Textbox(label="Violation Details URL")
526
  ],
527
  title="Worksite Safety Violation Analyzer",
528
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
529
+ allow_flagging="never"
530
  )
531
 
532
  if __name__ == "__main__":
533
+ logger.info("Launching Enhanced Safety Analyzer App...")
534
+ interface.launch()