PrashanthB461 commited on
Commit
27626a3
·
verified ·
1 Parent(s): 39355e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -248
app.py CHANGED
@@ -15,10 +15,10 @@ import logging
15
  from retrying import retry
16
 
17
  # ==========================
18
- # OPTIMIZED CONFIGURATION
19
  # ==========================
20
  CONFIG = {
21
- "MODEL_PATH": "yolov8_safety.pt",
22
  "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
@@ -28,7 +28,7 @@ CONFIG = {
28
  3: "unsafe_zone",
29
  4: "improper_tool_use"
30
  },
31
- "CLASS_COLORS": {
32
  "no_helmet": (0, 0, 255), # Red
33
  "no_harness": (0, 165, 255), # Orange
34
  "unsafe_posture": (0, 255, 0), # Green
@@ -37,267 +37,221 @@ CONFIG = {
37
  },
38
  "DISPLAY_NAMES": {
39
  "no_helmet": "No Helmet",
40
- "no_harness": "No Harness",
41
  "unsafe_posture": "Unsafe Posture",
42
- "unsafe_zone": "Unsafe Zone",
43
  "improper_tool_use": "Improper Tool Use"
44
  },
45
- "SF_CREDENTIALS": {
46
  "username": "prashanth1ai@safety.com",
47
  "password": "SaiPrash461",
48
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
- "FRAME_SKIP": 3,
53
- "MAX_PROCESSING_TIME": 60,
54
- "CONFIDENCE_THRESHOLD": {
55
  "no_helmet": 0.4,
56
  "no_harness": 0.3,
57
  "unsafe_posture": 0.25,
58
  "unsafe_zone": 0.3,
59
  "improper_tool_use": 0.35
60
  },
61
- "IOU_THRESHOLD": 0.4,
62
- "MIN_VIOLATION_FRAMES": 3,
63
- "MIN_VIOLATION_DURATION": 1.5
64
  }
65
 
66
- # Initialize logging
67
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
68
  logger = logging.getLogger(__name__)
69
- os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
70
 
71
- # Device configuration
72
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
  logger.info(f"Using device: {device}")
74
 
 
 
 
75
  def load_model():
76
  try:
77
- if os.path.exists(CONFIG["MODEL_PATH"]):
78
  model = YOLO(CONFIG["MODEL_PATH"]).to(device)
79
  logger.info("Loaded custom safety model")
80
  else:
81
  model = YOLO(CONFIG["FALLBACK_MODEL"]).to(device)
82
- logger.warning("Using fallback model - recommend training yolov8_safety.pt")
83
  return model
84
  except Exception as e:
85
- logger.error(f"Model loading failed: {str(e)}")
86
  raise
87
 
88
  model = load_model()
89
 
 
 
 
90
  def draw_detections(frame, detections):
91
- """Draw bounding boxes with labels and confidence scores"""
92
  for det in detections:
93
  label = det["violation"]
94
- x, y, w, h = [int(v) for v in det["bounding_box"]]
95
- color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
96
-
97
  x1, y1 = int(x - w/2), int(y - h/2)
98
  x2, y2 = int(x + w/2), int(y + h/2)
99
 
 
100
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
101
-
102
- label_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {det['confidence']:.2f}"
103
- (text_width, text_height), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
104
- cv2.rectangle(frame, (x1, y1 - text_height - 10), (x1 + text_width, y1), color, -1)
105
- cv2.putText(frame, label_text, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
106
  return frame
107
 
108
  def calculate_iou(box1, box2):
109
- """Calculate Intersection over Union for two bounding boxes"""
110
- box1 = [box1[0] - box1[2]/2, box1[1] - box1[3]/2, box1[0] + box1[2]/2, box1[1] + box1[3]/2]
111
- box2 = [box2[0] - box2[2]/2, box2[1] - box2[3]/2, box2[0] + box2[2]/2, box2[1] + box2[3]/2]
112
-
113
- x_left = max(box1[0], box2[0])
114
- y_top = max(box1[1], box2[1])
115
- x_right = min(box1[2], box2[2])
116
- y_bottom = min(box1[3], box2[3])
117
-
118
- if x_right < x_left or y_bottom < y_top:
119
- return 0.0
120
-
121
- intersection_area = (x_right - x_left) * (y_bottom - y_top)
122
- box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
123
- box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
124
-
125
- return intersection_area / float(box1_area + box2_area - intersection_area)
 
 
 
 
 
 
126
 
127
  def generate_violation_pdf(violations, score):
 
128
  try:
129
- pdf_filename = f"violations_{int(time.time())}.pdf"
130
- pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
131
  pdf_file = BytesIO()
132
-
133
  c = canvas.Canvas(pdf_file, pagesize=letter)
134
  c.setFont("Helvetica-Bold", 14)
135
  c.drawString(1 * inch, 10.5 * inch, "Worksite Safety Violation Report")
136
  c.setFont("Helvetica", 12)
137
 
138
- y_position = 10 * inch
 
139
  report_data = [
140
  ("Compliance Score", f"{score}%"),
141
  ("Total Violations", len(violations)),
142
- ("Report Date", time.strftime("%Y-%m-%d %H:%M:%S"))
 
143
  ]
 
 
 
144
 
145
- for key, value in report_data:
146
- c.drawString(1 * inch, y_position, f"{key}: {value}")
147
- y_position -= 0.4 * inch
148
-
149
- y_position -= 0.2 * inch
150
- c.line(1 * inch, y_position, 7.5 * inch, y_position)
151
- y_position -= 0.3 * inch
152
-
153
  c.setFont("Helvetica-Bold", 12)
154
- c.drawString(1 * inch, y_position, "Violation Details:")
155
- y_position -= 0.3 * inch
156
  c.setFont("Helvetica", 10)
 
157
 
158
  if not violations:
159
- c.drawString(1 * inch, y_position, "No violations detected.")
160
  else:
161
  for v in violations:
162
- violation_text = (
163
  f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
164
- f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f}, "
165
- f"Worker: {v['worker_id']})"
166
  )
167
- c.drawString(1 * inch, y_position, violation_text)
168
- y_position -= 0.25 * inch
169
- if y_position < 1 * inch:
170
  c.showPage()
171
- y_position = 10 * inch
172
- c.setFont("Helvetica", 10)
173
 
174
  c.save()
175
  pdf_file.seek(0)
176
 
 
 
 
177
  with open(pdf_path, "wb") as f:
178
  f.write(pdf_file.getvalue())
179
-
180
- public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
181
- logger.info(f"Generated PDF report: {public_url}")
182
- return pdf_path, public_url, pdf_file
183
 
 
184
  except Exception as e:
185
- logger.error(f"PDF generation failed: {str(e)}")
186
- return "", "", None
187
-
188
- @retry(stop_max_attempt_number=3, wait_fixed=2000)
189
- def connect_to_salesforce():
190
- try:
191
- sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
192
- logger.info("Connected to Salesforce")
193
- return sf
194
- except Exception as e:
195
- logger.error(f"Salesforce connection failed: {str(e)}")
196
- raise
197
-
198
- def upload_pdf_to_salesforce(sf, pdf_file, report_id):
199
- try:
200
- encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
201
- content_version = sf.ContentVersion.create({
202
- "Title": f"Safety_Report_{int(time.time())}",
203
- "PathOnClient": "safety_report.pdf",
204
- "VersionData": encoded_pdf,
205
- "FirstPublishLocationId": report_id
206
- })
207
- return f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
208
- except Exception as e:
209
- logger.error(f"PDF upload failed: {str(e)}")
210
- return ""
211
 
212
- def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
 
213
  try:
214
  sf = connect_to_salesforce()
215
 
 
216
  violations_text = "\n".join(
217
- f"- {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
218
- f"at {v['timestamp']:.2f}s (Worker {v['worker_id']}, Confidence: {v['confidence']:.2f})"
219
  for v in violations
220
- ) or "No violations detected"
221
 
 
222
  record_data = {
223
  "Compliance_Score__c": score,
224
  "Violations_Found__c": len(violations),
225
  "Violations_Details__c": violations_text,
226
- "Status__c": "New"
227
  }
 
 
228
 
229
- try:
230
- record = sf.Safety_Video_Report__c.create(record_data)
231
- record_id = record["id"]
232
- logger.info(f"Created Salesforce record: {record_id}")
233
- except Exception as e:
234
- logger.error(f"Failed to create Safety Report: {str(e)}")
235
- record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
236
- record_id = record["id"]
237
- logger.warning(f"Created fallback Account record: {record_id}")
238
-
239
  pdf_url = ""
240
  if pdf_file:
241
- pdf_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
242
- if pdf_url:
243
- try:
244
- sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": pdf_url})
245
- except:
246
- sf.Account.update(record_id, {"Description": pdf_url})
 
 
247
 
248
- return record_id, pdf_url if pdf_url else f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}"
249
-
250
  except Exception as e:
251
- logger.error(f"Salesforce integration failed: {str(e)}")
252
  return None, ""
253
 
254
- def calculate_safety_score(violations):
255
- penalties = {
256
- "no_helmet": 25,
257
- "no_harness": 30,
258
- "unsafe_posture": 20,
259
- "unsafe_zone": 35,
260
- "improper_tool_use": 25
261
- }
262
- unique_violations = {(v["worker_id"], v["violation"]) for v in violations}
263
- total_penalty = sum(penalties.get(v[1], 0) for v in unique_violations)
264
- return max(100 - total_penalty, 0)
265
-
266
- def process_video(video_data):
267
  try:
268
- temp_video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
269
- with open(temp_video_path, "wb") as f:
270
- f.write(video_data)
271
-
272
- cap = cv2.VideoCapture(temp_video_path)
273
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
274
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
275
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
276
-
277
- workers = []
278
  violations = []
279
  snapshots = []
280
- violation_history = {k: [] for k in CONFIG["VIOLATION_LABELS"].values()}
281
- snapshot_taken = {k: False for k in CONFIG["VIOLATION_LABELS"].values()}
282
-
283
- frame_count = 0
284
- start_time = time.time()
285
 
286
  while cap.isOpened():
287
  ret, frame = cap.read()
288
  if not ret:
289
  break
290
-
291
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
292
  frame_count += 1
293
  continue
294
-
295
- if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
296
- logger.warning("Processing timeout reached")
297
- break
298
-
299
  current_time = frame_count / fps
300
- results = model(frame, device=device, verbose=False)
301
 
302
  for result in results:
303
  for box in result.boxes:
@@ -307,150 +261,160 @@ def process_video(video_data):
307
 
308
  if not label or conf < CONFIG["CONFIDENCE_THRESHOLD"].get(label, 0.3):
309
  continue
310
-
311
- bbox = box.xywh.cpu().numpy()[0].tolist()
312
 
 
 
 
 
 
 
 
 
 
 
313
  matched_worker = None
314
  max_iou = 0
315
  for worker in workers:
316
- iou = calculate_iou(bbox, worker["bbox"])
317
  if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
318
  max_iou = iou
319
  matched_worker = worker
320
-
321
  if matched_worker:
322
  worker_id = matched_worker["id"]
323
  matched_worker["bbox"] = bbox
324
- matched_worker["last_seen"] = current_time
325
  else:
326
  worker_id = len(workers) + 1
327
- workers.append({
328
- "id": worker_id,
329
- "bbox": bbox,
330
- "first_seen": current_time,
331
- "last_seen": current_time
332
- })
333
 
334
- violation_history[label].append({
335
- "frame": frame_count,
336
- "violation": label,
337
- "confidence": round(conf, 2),
338
- "bounding_box": bbox,
339
- "timestamp": current_time,
340
- "worker_id": worker_id
341
- })
 
 
 
 
 
 
 
 
 
342
 
343
  frame_count += 1
344
 
345
- for violation_type, detections in violation_history.items():
346
- if not detections:
347
- continue
348
-
349
- worker_groups = {}
350
- for det in detections:
351
- if det["worker_id"] not in worker_groups:
352
- worker_groups[det["worker_id"]] = []
353
- worker_groups[det["worker_id"]].append(det)
354
-
355
- for worker_id, worker_dets in worker_groups.items():
356
- if len(worker_dets) < 2:
357
- continue
358
-
359
- duration = worker_dets[-1]["timestamp"] - worker_dets[0]["timestamp"]
360
- if duration >= CONFIG["MIN_VIOLATION_DURATION"]:
361
- best_det = max(worker_dets, key=lambda x: x["confidence"])
362
- violations.append(best_det)
363
-
364
- if not snapshot_taken[violation_type]:
365
- cap.set(cv2.CAP_PROP_POS_FRAMES, best_det["frame"])
366
- ret, snapshot_frame = cap.read()
367
- if ret:
368
- snapshot_frame = draw_detections(snapshot_frame, [best_det])
369
- filename = f"{violation_type}_{best_det['frame']}.jpg"
370
- path = os.path.join(CONFIG["OUTPUT_DIR"], filename)
371
- cv2.imwrite(path, snapshot_frame)
372
- snapshots.append({
373
- "violation": violation_type,
374
- "frame": best_det["frame"],
375
- "path": path,
376
- "url": f"{CONFIG['PUBLIC_URL_BASE']}{filename}"
377
- })
378
- snapshot_taken[violation_type] = True
379
-
380
  cap.release()
381
- os.remove(temp_video_path)
382
 
383
- score = calculate_safety_score(violations)
384
- pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
385
- record_id, sf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
  return {
388
- "violations": violations,
389
  "snapshots": snapshots,
390
- "score": score,
391
- "salesforce_record_id": record_id,
392
- "violation_details_url": sf_url or pdf_url,
393
  "message": ""
394
  }
395
-
396
  except Exception as e:
397
- logger.error(f"Video processing failed: {str(e)}")
398
  return {
399
  "violations": [],
400
  "snapshots": [],
401
  "score": 100,
402
- "salesforce_record_id": None,
403
- "violation_details_url": "",
404
  "message": f"Error: {str(e)}"
405
  }
406
 
407
- def gradio_interface(video_file):
 
 
 
 
 
 
 
408
  try:
409
- yield "Analyzing video...", "", "", "", ""
 
 
 
410
 
411
- with open(video_file, "rb") as f:
412
- result = process_video(f.read())
 
 
 
 
 
 
 
 
413
 
 
414
  violation_table = (
415
- "| Violation Type | Timestamp | Confidence | Worker ID |\n"
416
- "|---------------------|-----------|------------|-----------|\n" +
417
  "\n".join(
418
- f"| {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation']):<19} | "
419
- f"{v['timestamp']:.2f} | "
420
- f"{v['confidence']:.2f} | "
421
- f"{v['worker_id']} |"
422
  for v in result["violations"]
423
- )
424
- ) if result["violations"] else "No violations detected"
425
 
426
  snapshots_md = "\n".join(
427
- f"![{s['violation']} at frame {s['frame']}]({s['url']})"
 
428
  for s in result["snapshots"]
429
- ) if result["snapshots"] else "No snapshots"
430
 
431
- yield (
432
  violation_table,
433
  f"Safety Score: {result['score']}%",
434
  snapshots_md,
435
- f"Salesforce ID: {result['salesforce_record_id'] or 'None'}",
436
- result["violation_details_url"] or "None"
437
  )
438
  except Exception as e:
439
- logger.error(f"Interface error: {str(e)}")
440
- yield f"Error: {str(e)}", "", "", "", ""
441
 
 
442
  interface = gr.Interface(
443
- fn=gradio_interface,
444
  inputs=gr.Video(label="Upload Site Video"),
445
  outputs=[
446
- gr.Markdown(label="Violations Detected"),
447
- gr.Textbox(label="Compliance Score"),
448
- gr.Markdown(label="Evidence Snapshots"),
449
- gr.Textbox(label="Salesforce Record"),
450
  gr.Textbox(label="Report URL")
451
  ],
452
- title="AI Safety Compliance Monitor",
453
- description="Detects 5 violation types: No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use"
 
 
 
454
  )
455
 
456
  if __name__ == "__main__":
 
15
  from retrying import retry
16
 
17
  # ==========================
18
+ # Configuration
19
  # ==========================
20
  CONFIG = {
21
+ "MODEL_PATH": "yolov8_safety.pt", # Your custom-trained model
22
  "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
 
28
  3: "unsafe_zone",
29
  4: "improper_tool_use"
30
  },
31
+ "CLASS_COLORS": { # Bounding box colors
32
  "no_helmet": (0, 0, 255), # Red
33
  "no_harness": (0, 165, 255), # Orange
34
  "unsafe_posture": (0, 255, 0), # Green
 
37
  },
38
  "DISPLAY_NAMES": {
39
  "no_helmet": "No Helmet",
40
+ "no_harness": "No Safety Harness",
41
  "unsafe_posture": "Unsafe Posture",
42
+ "unsafe_zone": "Unsafe Zone Entry",
43
  "improper_tool_use": "Improper Tool Use"
44
  },
45
+ "SF_CREDENTIALS": { # Salesforce credentials
46
  "username": "prashanth1ai@safety.com",
47
  "password": "SaiPrash461",
48
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
+ "FRAME_SKIP": 5, # Process every 5th frame (balance speed vs. accuracy)
53
+ "MAX_PROCESSING_TIME": 60, # Max processing time (seconds)
54
+ "CONFIDENCE_THRESHOLD": { # Per-class thresholds
55
  "no_helmet": 0.4,
56
  "no_harness": 0.3,
57
  "unsafe_posture": 0.25,
58
  "unsafe_zone": 0.3,
59
  "improper_tool_use": 0.35
60
  },
61
+ "IOU_THRESHOLD": 0.4, # For worker tracking
62
+ "MIN_VIOLATION_FRAMES": 3 # Min frames to confirm a violation
 
63
  }
64
 
65
+ # Setup logging
66
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
67
  logger = logging.getLogger(__name__)
 
68
 
69
+ os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
70
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
  logger.info(f"Using device: {device}")
72
 
73
+ # ==========================
74
+ # Load YOLOv8 Model
75
+ # ==========================
76
  def load_model():
77
  try:
78
+ if os.path.isfile(CONFIG["MODEL_PATH"]):
79
  model = YOLO(CONFIG["MODEL_PATH"]).to(device)
80
  logger.info("Loaded custom safety model")
81
  else:
82
  model = YOLO(CONFIG["FALLBACK_MODEL"]).to(device)
83
+ logger.warning("Using fallback model (lower accuracy)")
84
  return model
85
  except Exception as e:
86
+ logger.error(f"Model load failed: {e}")
87
  raise
88
 
89
  model = load_model()
90
 
91
+ # ==========================
92
+ # Core Detection Functions
93
+ # ==========================
94
  def draw_detections(frame, detections):
95
+ """Draw bounding boxes with labels on frame."""
96
  for det in detections:
97
  label = det["violation"]
98
+ conf = det["confidence"]
99
+ x, y, w, h = det["bounding_box"]
 
100
  x1, y1 = int(x - w/2), int(y - h/2)
101
  x2, y2 = int(x + w/2), int(y + h/2)
102
 
103
+ color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
104
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
105
+ cv2.putText(frame, f"{label}: {conf:.2f}", (x1, y1-10),
106
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
 
 
 
107
  return frame
108
 
109
  def calculate_iou(box1, box2):
110
+ """Compute Intersection-over-Union for tracking."""
111
+ x1, y1, w1, h1 = box1
112
+ x2, y2, w2, h2 = box2
113
+ x_min = max(x1 - w1/2, x2 - w2/2)
114
+ y_min = max(y1 - h1/2, y2 - h2/2)
115
+ x_max = min(x1 + w1/2, x2 + w2/2)
116
+ y_max = min(y1 + h1/2, y2 + h2/2)
117
+ intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
118
+ union = w1 * h1 + w2 * h2 - intersection
119
+ return intersection / union if union > 0 else 0
120
+
121
+ # ==========================
122
+ # Salesforce Integration
123
+ # ==========================
124
+ @retry(stop_max_attempt_number=3, wait_fixed=2000)
125
+ def connect_to_salesforce():
126
+ try:
127
+ sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
128
+ logger.info("Salesforce connection successful")
129
+ return sf
130
+ except Exception as e:
131
+ logger.error(f"Salesforce login failed: {e}")
132
+ raise
133
 
134
  def generate_violation_pdf(violations, score):
135
+ """Generate PDF report with violations."""
136
  try:
 
 
137
  pdf_file = BytesIO()
 
138
  c = canvas.Canvas(pdf_file, pagesize=letter)
139
  c.setFont("Helvetica-Bold", 14)
140
  c.drawString(1 * inch, 10.5 * inch, "Worksite Safety Violation Report")
141
  c.setFont("Helvetica", 12)
142
 
143
+ # Report metadata
144
+ y_pos = 9.8 * inch
145
  report_data = [
146
  ("Compliance Score", f"{score}%"),
147
  ("Total Violations", len(violations)),
148
+ ("Date", time.strftime("%Y-%m-%d")),
149
+ ("Time", time.strftime("%H:%M:%S"))
150
  ]
151
+ for label, value in report_data:
152
+ c.drawString(1 * inch, y_pos, f"{label}: {value}")
153
+ y_pos -= 0.4 * inch
154
 
155
+ # Violation details
156
+ y_pos -= 0.3 * inch
 
 
 
 
 
 
157
  c.setFont("Helvetica-Bold", 12)
158
+ c.drawString(1 * inch, y_pos, "Violation Details:")
 
159
  c.setFont("Helvetica", 10)
160
+ y_pos -= 0.3 * inch
161
 
162
  if not violations:
163
+ c.drawString(1 * inch, y_pos, "No violations detected.")
164
  else:
165
  for v in violations:
166
+ text = (
167
  f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
168
+ f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})"
 
169
  )
170
+ c.drawString(1 * inch, y_pos, text)
171
+ y_pos -= 0.25 * inch
172
+ if y_pos < 1 * inch:
173
  c.showPage()
174
+ y_pos = 10 * inch
 
175
 
176
  c.save()
177
  pdf_file.seek(0)
178
 
179
+ # Save PDF
180
+ pdf_filename = f"violation_report_{int(time.time())}.pdf"
181
+ pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
182
  with open(pdf_path, "wb") as f:
183
  f.write(pdf_file.getvalue())
 
 
 
 
184
 
185
+ return pdf_path, f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}", pdf_file
186
  except Exception as e:
187
+ logger.error(f"PDF generation failed: {e}")
188
+ return None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
+ def push_report_to_salesforce(violations, score, pdf_file):
191
+ """Upload report to Salesforce."""
192
  try:
193
  sf = connect_to_salesforce()
194
 
195
+ # Create violation details text
196
  violations_text = "\n".join(
197
+ f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
198
+ f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})"
199
  for v in violations
200
+ ) or "No violations detected."
201
 
202
+ # Create Salesforce record
203
  record_data = {
204
  "Compliance_Score__c": score,
205
  "Violations_Found__c": len(violations),
206
  "Violations_Details__c": violations_text,
207
+ "Status__c": "Pending Review"
208
  }
209
+ record = sf.Safety_Video_Report__c.create(record_data)
210
+ record_id = record["id"]
211
 
212
+ # Upload PDF if available
 
 
 
 
 
 
 
 
 
213
  pdf_url = ""
214
  if pdf_file:
215
+ encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode("utf-8")
216
+ content_version = sf.ContentVersion.create({
217
+ "Title": f"Safety_Report_{record_id}",
218
+ "PathOnClient": f"report_{record_id}.pdf",
219
+ "VersionData": encoded_pdf,
220
+ "FirstPublishLocationId": record_id
221
+ })
222
+ pdf_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
223
 
224
+ return record_id, pdf_url
 
225
  except Exception as e:
226
+ logger.error(f"Salesforce upload failed: {e}")
227
  return None, ""
228
 
229
+ # ==========================
230
+ # Video Processing
231
+ # ==========================
232
+ def process_video(video_path):
233
+ """Analyze video for safety violations."""
 
 
 
 
 
 
 
 
234
  try:
235
+ cap = cv2.VideoCapture(video_path)
 
 
 
 
236
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
237
+ frame_count = 0
 
 
 
238
  violations = []
239
  snapshots = []
240
+ workers = []
241
+ snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
 
 
 
242
 
243
  while cap.isOpened():
244
  ret, frame = cap.read()
245
  if not ret:
246
  break
247
+
248
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
249
  frame_count += 1
250
  continue
251
+
252
+ # Run detection
253
+ results = model(frame, device=device)
 
 
254
  current_time = frame_count / fps
 
255
 
256
  for result in results:
257
  for box in result.boxes:
 
261
 
262
  if not label or conf < CONFIG["CONFIDENCE_THRESHOLD"].get(label, 0.3):
263
  continue
 
 
264
 
265
+ bbox = box.xywh.cpu().numpy()[0]
266
+ detection = {
267
+ "frame": frame_count,
268
+ "violation": label,
269
+ "confidence": conf,
270
+ "bounding_box": bbox,
271
+ "timestamp": current_time
272
+ }
273
+
274
+ # Track worker
275
  matched_worker = None
276
  max_iou = 0
277
  for worker in workers:
278
+ iou = calculate_iou(worker["bbox"], bbox)
279
  if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
280
  max_iou = iou
281
  matched_worker = worker
282
+
283
  if matched_worker:
284
  worker_id = matched_worker["id"]
285
  matched_worker["bbox"] = bbox
 
286
  else:
287
  worker_id = len(workers) + 1
288
+ workers.append({"id": worker_id, "bbox": bbox})
 
 
 
 
 
289
 
290
+ detection["worker_id"] = worker_id
291
+ violations.append(detection)
292
+
293
+ # Capture snapshot if first detection of this type
294
+ if not snapshot_taken[label]:
295
+ snapshot_path = os.path.join(
296
+ CONFIG["OUTPUT_DIR"],
297
+ f"{label}_{frame_count}.jpg"
298
+ )
299
+ cv2.imwrite(snapshot_path, draw_detections(frame.copy(), [detection]))
300
+ snapshots.append({
301
+ "violation": label,
302
+ "frame": frame_count,
303
+ "path": snapshot_path,
304
+ "url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}"
305
+ })
306
+ snapshot_taken[label] = True
307
 
308
  frame_count += 1
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  cap.release()
 
311
 
312
+ # Filter violations (require min frames)
313
+ filtered_violations = []
314
+ violation_counts = {}
315
+ for v in violations:
316
+ key = (v["worker_id"], v["violation"])
317
+ violation_counts[key] = violation_counts.get(key, 0) + 1
318
+
319
+ for v in violations:
320
+ if violation_counts[(v["worker_id"], v["violation"])] >= CONFIG["MIN_VIOLATION_FRAMES"]:
321
+ filtered_violations.append(v)
322
+
323
+ # Calculate safety score
324
+ penalty_weights = {
325
+ "no_helmet": 25,
326
+ "no_harness": 30,
327
+ "unsafe_posture": 20,
328
+ "unsafe_zone": 35,
329
+ "improper_tool_use": 25
330
+ }
331
+ unique_violations = set((v["worker_id"], v["violation"]) for v in filtered_violations)
332
+ total_penalty = sum(penalty_weights.get(v, 0) for _, v in unique_violations)
333
+ safety_score = max(100 - total_penalty, 0)
334
 
335
  return {
336
+ "violations": filtered_violations,
337
  "snapshots": snapshots,
338
+ "score": safety_score,
 
 
339
  "message": ""
340
  }
 
341
  except Exception as e:
342
+ logger.error(f"Video processing failed: {e}")
343
  return {
344
  "violations": [],
345
  "snapshots": [],
346
  "score": 100,
 
 
347
  "message": f"Error: {str(e)}"
348
  }
349
 
350
+ # ==========================
351
+ # Gradio Interface
352
+ # ==========================
353
+ def analyze_video(video_file):
354
+ """Gradio interface function."""
355
+ if not video_file:
356
+ return "No video uploaded", "", "", "", ""
357
+
358
  try:
359
+ # Process video
360
+ result = process_video(video_file)
361
+ if result["message"]:
362
+ return result["message"], "", "", "", ""
363
 
364
+ # Generate report
365
+ pdf_path, pdf_url, pdf_file = generate_violation_pdf(
366
+ result["violations"],
367
+ result["score"]
368
+ )
369
+ record_id, sf_url = push_report_to_salesforce(
370
+ result["violations"],
371
+ result["score"],
372
+ pdf_file
373
+ )
374
 
375
+ # Format outputs
376
  violation_table = (
377
+ "| Violation Type | Timestamp (s) | Confidence | Worker ID |\n"
378
+ "|------------------------|---------------|------------|-----------|\n" +
379
  "\n".join(
380
+ f"| {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation']):<22} | "
381
+ f"{v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
 
 
382
  for v in result["violations"]
383
+ ) if result["violations"] else "No violations detected."
384
+ )
385
 
386
  snapshots_md = "\n".join(
387
+ f"**{CONFIG['DISPLAY_NAMES'].get(s['violation'], s['violation'])}** "
388
+ f"(Frame {s['frame']}): ![]({s['url']})"
389
  for s in result["snapshots"]
390
+ ) if result["snapshots"] else "No snapshots available."
391
 
392
+ return (
393
  violation_table,
394
  f"Safety Score: {result['score']}%",
395
  snapshots_md,
396
+ f"Salesforce Record: {record_id or 'N/A'}",
397
+ sf_url or pdf_url or "N/A"
398
  )
399
  except Exception as e:
400
+ return f"Error: {str(e)}", "", "", "", ""
 
401
 
402
+ # Launch Gradio App
403
  interface = gr.Interface(
404
+ fn=analyze_video,
405
  inputs=gr.Video(label="Upload Site Video"),
406
  outputs=[
407
+ gr.Markdown("## Detected Violations"),
408
+ gr.Textbox(label="Safety Score"),
409
+ gr.Markdown("## Violation Snapshots"),
410
+ gr.Textbox(label="Salesforce Record ID"),
411
  gr.Textbox(label="Report URL")
412
  ],
413
+ title="AI Safety Compliance Analyzer",
414
+ description=(
415
+ "Upload worksite video to detect safety violations. "
416
+ "Supported violations: Missing Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use."
417
+ )
418
  )
419
 
420
  if __name__ == "__main__":