PrashanthB461 commited on
Commit
f4592c4
·
verified ·
1 Parent(s): 64f085e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +267 -70
app.py CHANGED
@@ -15,10 +15,11 @@ import logging
15
  from retrying import retry
16
 
17
  # ==========================
18
- # Optimized Configuration
19
  # ==========================
20
  CONFIG = {
21
- "MODEL_PATH": "yolov8_safety.pt",
 
22
  "OUTPUT_DIR": "static/output",
23
  "VIOLATION_LABELS": {
24
  0: "no_helmet",
@@ -27,18 +28,38 @@ CONFIG = {
27
  3: "unsafe_zone",
28
  4: "improper_tool_use"
29
  },
30
- "CLASS_COLORS": {
31
- "no_helmet": (0, 0, 255),
32
- "no_harness": (0, 165, 255),
33
- "unsafe_posture": (0, 255, 0),
34
- "unsafe_zone": (255, 0, 0),
35
- "improper_tool_use": (255, 255, 0)
36
  },
37
- "FRAME_SKIP": 10, # Increased from 5 to 10 for faster processing
38
- "MAX_PROCESSING_TIME": 30, # Reduced from 60 to 30 seconds
39
- "CONFIDENCE_THRESHOLD": 0.35, # Slightly increased for faster filtering
40
- "IOU_THRESHOLD": 0.5,
41
- "MIN_VIOLATION_FRAMES": 2 # Reduced from 3 to 2 frames
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  }
43
 
44
  # Setup logging
@@ -49,45 +70,188 @@ os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
49
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
  logger.info(f"Using device: {device}")
51
 
52
- # Load model with caching
53
- model = None
 
54
  def load_model():
55
- global model
56
- if model is None:
57
- model = YOLO(CONFIG["MODEL_PATH"]).to(device)
58
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  # ==========================
61
- # Optimized Processing Functions
62
  # ==========================
63
  def process_video(video_path):
64
- """Optimized video processing with early termination"""
65
  try:
66
  cap = cv2.VideoCapture(video_path)
67
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
68
  frame_count = 0
69
  violations = []
 
70
  workers = []
71
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
72
- start_time = time.time()
73
 
74
  while cap.isOpened():
75
  ret, frame = cap.read()
76
  if not ret:
77
  break
78
 
79
- # Skip frames for faster processing
80
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
81
  frame_count += 1
82
  continue
83
 
84
- # Early termination if processing takes too long
85
- if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
86
- logger.info("Processing time limit reached")
87
- break
88
-
89
- # Run detection on current frame
90
- results = load_model()(frame, device=device, verbose=False) # Disable verbose logging
91
 
92
  for result in results:
93
  for box in result.boxes:
@@ -95,7 +259,7 @@ def process_video(video_path):
95
  conf = float(box.conf)
96
  label = CONFIG["VIOLATION_LABELS"].get(cls)
97
 
98
- if not label or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
99
  continue
100
 
101
  bbox = box.xywh.cpu().numpy()[0]
@@ -104,23 +268,26 @@ def process_video(video_path):
104
  "violation": label,
105
  "confidence": conf,
106
  "bounding_box": bbox,
107
- "timestamp": frame_count / fps
108
  }
109
 
110
- # Simplified worker tracking
111
- matched = False
 
112
  for worker in workers:
113
- if calculate_iou(worker["bbox"], bbox) > CONFIG["IOU_THRESHOLD"]:
114
- worker["bbox"] = bbox
115
- detection["worker_id"] = worker["id"]
116
- matched = True
117
- break
118
 
119
- if not matched:
 
 
 
120
  worker_id = len(workers) + 1
121
  workers.append({"id": worker_id, "bbox": bbox})
122
- detection["worker_id"] = worker_id
123
 
 
124
  violations.append(detection)
125
 
126
  # Capture snapshot if first detection of this type
@@ -129,7 +296,13 @@ def process_video(video_path):
129
  CONFIG["OUTPUT_DIR"],
130
  f"{label}_{frame_count}.jpg"
131
  )
132
- cv2.imwrite(snapshot_path, frame)
 
 
 
 
 
 
133
  snapshot_taken[label] = True
134
 
135
  frame_count += 1
@@ -147,9 +320,22 @@ def process_video(video_path):
147
  if violation_counts[(v["worker_id"], v["violation"])] >= CONFIG["MIN_VIOLATION_FRAMES"]:
148
  filtered_violations.append(v)
149
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  return {
151
  "violations": filtered_violations,
152
- "snapshots": [k for k, v in snapshot_taken.items() if v],
 
153
  "message": ""
154
  }
155
  except Exception as e:
@@ -157,68 +343,79 @@ def process_video(video_path):
157
  return {
158
  "violations": [],
159
  "snapshots": [],
 
160
  "message": f"Error: {str(e)}"
161
  }
162
 
163
  # ==========================
164
- # Fast Gradio Interface
165
  # ==========================
166
  def analyze_video(video_file):
167
- """Optimized interface with progress updates"""
168
  if not video_file:
169
  return "No video uploaded", "", "", "", ""
170
 
171
- # Immediate feedback
172
- yield "Processing started...", "", "", "", ""
173
-
174
  try:
175
- # Process video in background
176
  result = process_video(video_file)
177
-
178
  if result["message"]:
179
- yield result["message"], "", "", "", ""
180
- return
 
 
 
 
 
 
 
 
 
 
181
 
182
- # Generate simplified output
183
  violation_table = (
184
- "| Violation Type | Timestamp | Confidence | Worker ID |\n"
185
- "|----------------|-----------|------------|-----------|\n" +
186
  "\n".join(
187
- f"| {v['violation']} | {v['timestamp']:.1f}s | {v['confidence']:.2f} | {v['worker_id']} |"
 
188
  for v in result["violations"]
189
  ) if result["violations"] else "No violations detected."
190
  )
191
 
192
  snapshots_md = "\n".join(
193
- f"**{violation}** detected"
194
- for violation in result["snapshots"]
 
195
  ) if result["snapshots"] else "No snapshots available."
196
 
197
- yield (
198
  violation_table,
199
- f"Analysis complete in {time.time() - start_time:.1f}s",
200
  snapshots_md,
201
- "Salesforce integration placeholder",
202
- "Report URL placeholder"
203
  )
204
  except Exception as e:
205
- yield f"Error: {str(e)}", "", "", "", ""
206
 
207
- # Launch optimized interface
208
  interface = gr.Interface(
209
  fn=analyze_video,
210
  inputs=gr.Video(label="Upload Site Video"),
211
  outputs=[
212
  gr.Markdown("## Detected Violations"),
213
- gr.Textbox(label="Processing Time"),
214
  gr.Markdown("## Violation Snapshots"),
215
- gr.Textbox(label="Salesforce Record"),
216
  gr.Textbox(label="Report URL")
217
  ],
218
- title="Fast Safety Compliance Analyzer",
219
- description="Optimized for quick safety violation detection in worksite videos.",
220
- allow_flagging="never"
 
 
221
  )
222
 
223
  if __name__ == "__main__":
224
- interface.launch()
 
15
  from retrying import retry
16
 
17
  # ==========================
18
+ # Configuration
19
  # ==========================
20
  CONFIG = {
21
+ "MODEL_PATH": "yolov8_safety.pt", # Your custom-trained model
22
+ "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
 
28
  3: "unsafe_zone",
29
  4: "improper_tool_use"
30
  },
31
+ "CLASS_COLORS": { # Bounding box colors
32
+ "no_helmet": (0, 0, 255), # Red
33
+ "no_harness": (0, 165, 255), # Orange
34
+ "unsafe_posture": (0, 255, 0), # Green
35
+ "unsafe_zone": (255, 0, 0), # Blue
36
+ "improper_tool_use": (255, 255, 0) # Yellow
37
  },
38
+ "DISPLAY_NAMES": {
39
+ "no_helmet": "No Helmet",
40
+ "no_harness": "No Safety Harness",
41
+ "unsafe_posture": "Unsafe Posture",
42
+ "unsafe_zone": "Unsafe Zone Entry",
43
+ "improper_tool_use": "Improper Tool Use"
44
+ },
45
+ "SF_CREDENTIALS": { # Salesforce credentials
46
+ "username": "prashanth1ai@safety.com",
47
+ "password": "SaiPrash461",
48
+ "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
49
+ "domain": "login"
50
+ },
51
+ "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
+ "FRAME_SKIP": 5, # Process every 5th frame (balance speed vs. accuracy)
53
+ "MAX_PROCESSING_TIME": 60, # Max processing time (seconds)
54
+ "CONFIDENCE_THRESHOLD": { # Per-class thresholds
55
+ "no_helmet": 0.4,
56
+ "no_harness": 0.3,
57
+ "unsafe_posture": 0.25,
58
+ "unsafe_zone": 0.3,
59
+ "improper_tool_use": 0.35
60
+ },
61
+ "IOU_THRESHOLD": 0.4, # For worker tracking
62
+ "MIN_VIOLATION_FRAMES": 3 # Min frames to confirm a violation
63
  }
64
 
65
  # Setup logging
 
70
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
  logger.info(f"Using device: {device}")
72
 
73
+ # ==========================
74
+ # Load YOLOv8 Model
75
+ # ==========================
76
  def load_model():
77
+ try:
78
+ if os.path.isfile(CONFIG["MODEL_PATH"]):
79
+ model = YOLO(CONFIG["MODEL_PATH"]).to(device)
80
+ logger.info("Loaded custom safety model")
81
+ else:
82
+ model = YOLO(CONFIG["FALLBACK_MODEL"]).to(device)
83
+ logger.warning("Using fallback model (lower accuracy)")
84
+ return model
85
+ except Exception as e:
86
+ logger.error(f"Model load failed: {e}")
87
+ raise
88
+
89
+ model = load_model()
90
+
91
+ # ==========================
92
+ # Core Detection Functions
93
+ # ==========================
94
+ def draw_detections(frame, detections):
95
+ """Draw bounding boxes with labels on frame."""
96
+ for det in detections:
97
+ label = det["violation"]
98
+ conf = det["confidence"]
99
+ x, y, w, h = det["bounding_box"]
100
+ x1, y1 = int(x - w/2), int(y - h/2)
101
+ x2, y2 = int(x + w/2), int(y + h/2)
102
+
103
+ color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
104
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
105
+ cv2.putText(frame, f"{label}: {conf:.2f}", (x1, y1-10),
106
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
107
+ return frame
108
+
109
+ def calculate_iou(box1, box2):
110
+ """Compute Intersection-over-Union for tracking."""
111
+ x1, y1, w1, h1 = box1
112
+ x2, y2, w2, h2 = box2
113
+ x_min = max(x1 - w1/2, x2 - w2/2)
114
+ y_min = max(y1 - h1/2, y2 - h2/2)
115
+ x_max = min(x1 + w1/2, x2 + w2/2)
116
+ y_max = min(y1 + h1/2, y2 + h2/2)
117
+ intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
118
+ union = w1 * h1 + w2 * h2 - intersection
119
+ return intersection / union if union > 0 else 0
120
+
121
+ # ==========================
122
+ # Salesforce Integration
123
+ # ==========================
124
+ @retry(stop_max_attempt_number=3, wait_fixed=2000)
125
+ def connect_to_salesforce():
126
+ try:
127
+ sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
128
+ logger.info("Salesforce connection successful")
129
+ return sf
130
+ except Exception as e:
131
+ logger.error(f"Salesforce login failed: {e}")
132
+ raise
133
+
134
+ def generate_violation_pdf(violations, score):
135
+ """Generate PDF report with violations."""
136
+ try:
137
+ pdf_file = BytesIO()
138
+ c = canvas.Canvas(pdf_file, pagesize=letter)
139
+ c.setFont("Helvetica-Bold", 14)
140
+ c.drawString(1 * inch, 10.5 * inch, "Worksite Safety Violation Report")
141
+ c.setFont("Helvetica", 12)
142
+
143
+ # Report metadata
144
+ y_pos = 9.8 * inch
145
+ report_data = [
146
+ ("Compliance Score", f"{score}%"),
147
+ ("Total Violations", len(violations)),
148
+ ("Date", time.strftime("%Y-%m-%d")),
149
+ ("Time", time.strftime("%H:%M:%S"))
150
+ ]
151
+ for label, value in report_data:
152
+ c.drawString(1 * inch, y_pos, f"{label}: {value}")
153
+ y_pos -= 0.4 * inch
154
+
155
+ # Violation details
156
+ y_pos -= 0.3 * inch
157
+ c.setFont("Helvetica-Bold", 12)
158
+ c.drawString(1 * inch, y_pos, "Violation Details:")
159
+ c.setFont("Helvetica", 10)
160
+ y_pos -= 0.3 * inch
161
+
162
+ if not violations:
163
+ c.drawString(1 * inch, y_pos, "No violations detected.")
164
+ else:
165
+ for v in violations:
166
+ text = (
167
+ f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
168
+ f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})"
169
+ )
170
+ c.drawString(1 * inch, y_pos, text)
171
+ y_pos -= 0.25 * inch
172
+ if y_pos < 1 * inch:
173
+ c.showPage()
174
+ y_pos = 10 * inch
175
+
176
+ c.save()
177
+ pdf_file.seek(0)
178
+
179
+ # Save PDF
180
+ pdf_filename = f"violation_report_{int(time.time())}.pdf"
181
+ pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
182
+ with open(pdf_path, "wb") as f:
183
+ f.write(pdf_file.getvalue())
184
+
185
+ return pdf_path, f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}", pdf_file
186
+ except Exception as e:
187
+ logger.error(f"PDF generation failed: {e}")
188
+ return None, None, None
189
+
190
+ def push_report_to_salesforce(violations, score, pdf_file):
191
+ """Upload report to Salesforce."""
192
+ try:
193
+ sf = connect_to_salesforce()
194
+
195
+ # Create violation details text
196
+ violations_text = "\n".join(
197
+ f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
198
+ f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})"
199
+ for v in violations
200
+ ) or "No violations detected."
201
+
202
+ # Create Salesforce record
203
+ record_data = {
204
+ "Compliance_Score__c": score,
205
+ "Violations_Found__c": len(violations),
206
+ "Violations_Details__c": violations_text,
207
+ "Status__c": "Pending Review"
208
+ }
209
+ record = sf.Safety_Video_Report__c.create(record_data)
210
+ record_id = record["id"]
211
+
212
+ # Upload PDF if available
213
+ pdf_url = ""
214
+ if pdf_file:
215
+ encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode("utf-8")
216
+ content_version = sf.ContentVersion.create({
217
+ "Title": f"Safety_Report_{record_id}",
218
+ "PathOnClient": f"report_{record_id}.pdf",
219
+ "VersionData": encoded_pdf,
220
+ "FirstPublishLocationId": record_id
221
+ })
222
+ pdf_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
223
+
224
+ return record_id, pdf_url
225
+ except Exception as e:
226
+ logger.error(f"Salesforce upload failed: {e}")
227
+ return None, ""
228
 
229
  # ==========================
230
+ # Video Processing
231
  # ==========================
232
  def process_video(video_path):
233
+ """Analyze video for safety violations."""
234
  try:
235
  cap = cv2.VideoCapture(video_path)
236
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
237
  frame_count = 0
238
  violations = []
239
+ snapshots = []
240
  workers = []
241
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
 
242
 
243
  while cap.isOpened():
244
  ret, frame = cap.read()
245
  if not ret:
246
  break
247
 
 
248
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
249
  frame_count += 1
250
  continue
251
 
252
+ # Run detection
253
+ results = model(frame, device=device)
254
+ current_time = frame_count / fps
 
 
 
 
255
 
256
  for result in results:
257
  for box in result.boxes:
 
259
  conf = float(box.conf)
260
  label = CONFIG["VIOLATION_LABELS"].get(cls)
261
 
262
+ if not label or conf < CONFIG["CONFIDENCE_THRESHOLD"].get(label, 0.3):
263
  continue
264
 
265
  bbox = box.xywh.cpu().numpy()[0]
 
268
  "violation": label,
269
  "confidence": conf,
270
  "bounding_box": bbox,
271
+ "timestamp": current_time
272
  }
273
 
274
+ # Track worker
275
+ matched_worker = None
276
+ max_iou = 0
277
  for worker in workers:
278
+ iou = calculate_iou(worker["bbox"], bbox)
279
+ if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
280
+ max_iou = iou
281
+ matched_worker = worker
 
282
 
283
+ if matched_worker:
284
+ worker_id = matched_worker["id"]
285
+ matched_worker["bbox"] = bbox
286
+ else:
287
  worker_id = len(workers) + 1
288
  workers.append({"id": worker_id, "bbox": bbox})
 
289
 
290
+ detection["worker_id"] = worker_id
291
  violations.append(detection)
292
 
293
  # Capture snapshot if first detection of this type
 
296
  CONFIG["OUTPUT_DIR"],
297
  f"{label}_{frame_count}.jpg"
298
  )
299
+ cv2.imwrite(snapshot_path, draw_detections(frame.copy(), [detection]))
300
+ snapshots.append({
301
+ "violation": label,
302
+ "frame": frame_count,
303
+ "path": snapshot_path,
304
+ "url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}"
305
+ })
306
  snapshot_taken[label] = True
307
 
308
  frame_count += 1
 
320
  if violation_counts[(v["worker_id"], v["violation"])] >= CONFIG["MIN_VIOLATION_FRAMES"]:
321
  filtered_violations.append(v)
322
 
323
+ # Calculate safety score
324
+ penalty_weights = {
325
+ "no_helmet": 25,
326
+ "no_harness": 30,
327
+ "unsafe_posture": 20,
328
+ "unsafe_zone": 35,
329
+ "improper_tool_use": 25
330
+ }
331
+ unique_violations = set((v["worker_id"], v["violation"]) for v in filtered_violations)
332
+ total_penalty = sum(penalty_weights.get(v, 0) for _, v in unique_violations)
333
+ safety_score = max(100 - total_penalty, 0)
334
+
335
  return {
336
  "violations": filtered_violations,
337
+ "snapshots": snapshots,
338
+ "score": safety_score,
339
  "message": ""
340
  }
341
  except Exception as e:
 
343
  return {
344
  "violations": [],
345
  "snapshots": [],
346
+ "score": 100,
347
  "message": f"Error: {str(e)}"
348
  }
349
 
350
  # ==========================
351
+ # Gradio Interface
352
  # ==========================
353
  def analyze_video(video_file):
354
+ """Gradio interface function."""
355
  if not video_file:
356
  return "No video uploaded", "", "", "", ""
357
 
 
 
 
358
  try:
359
+ # Process video
360
  result = process_video(video_file)
 
361
  if result["message"]:
362
+ return result["message"], "", "", "", ""
363
+
364
+ # Generate report
365
+ pdf_path, pdf_url, pdf_file = generate_violation_pdf(
366
+ result["violations"],
367
+ result["score"]
368
+ )
369
+ record_id, sf_url = push_report_to_salesforce(
370
+ result["violations"],
371
+ result["score"],
372
+ pdf_file
373
+ )
374
 
375
+ # Format outputs
376
  violation_table = (
377
+ "| Violation Type | Timestamp (s) | Confidence | Worker ID |\n"
378
+ "|------------------------|---------------|------------|-----------|\n" +
379
  "\n".join(
380
+ f"| {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation']):<22} | "
381
+ f"{v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
382
  for v in result["violations"]
383
  ) if result["violations"] else "No violations detected."
384
  )
385
 
386
  snapshots_md = "\n".join(
387
+ f"**{CONFIG['DISPLAY_NAMES'].get(s['violation'], s['violation'])}** "
388
+ f"(Frame {s['frame']}): ![]({s['url']})"
389
+ for s in result["snapshots"]
390
  ) if result["snapshots"] else "No snapshots available."
391
 
392
+ return (
393
  violation_table,
394
+ f"Safety Score: {result['score']}%",
395
  snapshots_md,
396
+ f"Salesforce Record: {record_id or 'N/A'}",
397
+ sf_url or pdf_url or "N/A"
398
  )
399
  except Exception as e:
400
+ return f"Error: {str(e)}", "", "", "", ""
401
 
402
+ # Launch Gradio App
403
  interface = gr.Interface(
404
  fn=analyze_video,
405
  inputs=gr.Video(label="Upload Site Video"),
406
  outputs=[
407
  gr.Markdown("## Detected Violations"),
408
+ gr.Textbox(label="Safety Score"),
409
  gr.Markdown("## Violation Snapshots"),
410
+ gr.Textbox(label="Salesforce Record ID"),
411
  gr.Textbox(label="Report URL")
412
  ],
413
+ title="AI Safety Compliance Analyzer",
414
+ description=(
415
+ "Upload worksite video to detect safety violations. "
416
+ "Supported violations: Missing Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use."
417
+ )
418
  )
419
 
420
  if __name__ == "__main__":
421
+ interface.launch(share=True)