PrashanthB461 commited on
Commit
dc9b74e
·
verified ·
1 Parent(s): 27626a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -359
app.py CHANGED
@@ -2,24 +2,16 @@ import os
2
  import cv2
3
  import gradio as gr
4
  import torch
5
- import numpy as np
6
  from ultralytics import YOLO
7
  import time
8
- from simple_salesforce import Salesforce
9
- from reportlab.lib.pagesizes import letter
10
- from reportlab.pdfgen import canvas
11
- from reportlab.lib.units import inch
12
- from io import BytesIO
13
- import base64
14
  import logging
15
- from retrying import retry
16
 
17
  # ==========================
18
- # Configuration
19
  # ==========================
20
  CONFIG = {
21
- "MODEL_PATH": "yolov8_safety.pt", # Your custom-trained model
22
- "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
@@ -28,38 +20,11 @@ CONFIG = {
28
  3: "unsafe_zone",
29
  4: "improper_tool_use"
30
  },
31
- "CLASS_COLORS": { # Bounding box colors
32
- "no_helmet": (0, 0, 255), # Red
33
- "no_harness": (0, 165, 255), # Orange
34
- "unsafe_posture": (0, 255, 0), # Green
35
- "unsafe_zone": (255, 0, 0), # Blue
36
- "improper_tool_use": (255, 255, 0) # Yellow
37
- },
38
- "DISPLAY_NAMES": {
39
- "no_helmet": "No Helmet",
40
- "no_harness": "No Safety Harness",
41
- "unsafe_posture": "Unsafe Posture",
42
- "unsafe_zone": "Unsafe Zone Entry",
43
- "improper_tool_use": "Improper Tool Use"
44
- },
45
- "SF_CREDENTIALS": { # Salesforce credentials
46
- "username": "prashanth1ai@safety.com",
47
- "password": "SaiPrash461",
48
- "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
49
- "domain": "login"
50
- },
51
- "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
- "FRAME_SKIP": 5, # Process every 5th frame (balance speed vs. accuracy)
53
- "MAX_PROCESSING_TIME": 60, # Max processing time (seconds)
54
- "CONFIDENCE_THRESHOLD": { # Per-class thresholds
55
- "no_helmet": 0.4,
56
- "no_harness": 0.3,
57
- "unsafe_posture": 0.25,
58
- "unsafe_zone": 0.3,
59
- "improper_tool_use": 0.35
60
- },
61
- "IOU_THRESHOLD": 0.4, # For worker tracking
62
- "MIN_VIOLATION_FRAMES": 3 # Min frames to confirm a violation
63
  }
64
 
65
  # Setup logging
@@ -67,355 +32,113 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
67
  logger = logging.getLogger(__name__)
68
 
69
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
70
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
  logger.info(f"Using device: {device}")
72
 
73
- # ==========================
74
- # Load YOLOv8 Model
75
- # ==========================
76
- def load_model():
77
- try:
78
- if os.path.isfile(CONFIG["MODEL_PATH"]):
79
- model = YOLO(CONFIG["MODEL_PATH"]).to(device)
80
- logger.info("Loaded custom safety model")
81
- else:
82
- model = YOLO(CONFIG["FALLBACK_MODEL"]).to(device)
83
- logger.warning("Using fallback model (lower accuracy)")
84
- return model
85
- except Exception as e:
86
- logger.error(f"Model load failed: {e}")
87
- raise
88
-
89
- model = load_model()
90
-
91
- # ==========================
92
- # Core Detection Functions
93
- # ==========================
94
- def draw_detections(frame, detections):
95
- """Draw bounding boxes with labels on frame."""
96
- for det in detections:
97
- label = det["violation"]
98
- conf = det["confidence"]
99
- x, y, w, h = det["bounding_box"]
100
- x1, y1 = int(x - w/2), int(y - h/2)
101
- x2, y2 = int(x + w/2), int(y + h/2)
102
-
103
- color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
104
- cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
105
- cv2.putText(frame, f"{label}: {conf:.2f}", (x1, y1-10),
106
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
107
- return frame
108
-
109
- def calculate_iou(box1, box2):
110
- """Compute Intersection-over-Union for tracking."""
111
- x1, y1, w1, h1 = box1
112
- x2, y2, w2, h2 = box2
113
- x_min = max(x1 - w1/2, x2 - w2/2)
114
- y_min = max(y1 - h1/2, y2 - h2/2)
115
- x_max = min(x1 + w1/2, x2 + w2/2)
116
- y_max = min(y1 + h1/2, y2 + h2/2)
117
- intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
118
- union = w1 * h1 + w2 * h2 - intersection
119
- return intersection / union if union > 0 else 0
120
-
121
- # ==========================
122
- # Salesforce Integration
123
- # ==========================
124
- @retry(stop_max_attempt_number=3, wait_fixed=2000)
125
- def connect_to_salesforce():
126
- try:
127
- sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
128
- logger.info("Salesforce connection successful")
129
- return sf
130
- except Exception as e:
131
- logger.error(f"Salesforce login failed: {e}")
132
- raise
133
 
134
- def generate_violation_pdf(violations, score):
135
- """Generate PDF report with violations."""
136
- try:
137
- pdf_file = BytesIO()
138
- c = canvas.Canvas(pdf_file, pagesize=letter)
139
- c.setFont("Helvetica-Bold", 14)
140
- c.drawString(1 * inch, 10.5 * inch, "Worksite Safety Violation Report")
141
- c.setFont("Helvetica", 12)
142
-
143
- # Report metadata
144
- y_pos = 9.8 * inch
145
- report_data = [
146
- ("Compliance Score", f"{score}%"),
147
- ("Total Violations", len(violations)),
148
- ("Date", time.strftime("%Y-%m-%d")),
149
- ("Time", time.strftime("%H:%M:%S"))
150
- ]
151
- for label, value in report_data:
152
- c.drawString(1 * inch, y_pos, f"{label}: {value}")
153
- y_pos -= 0.4 * inch
154
-
155
- # Violation details
156
- y_pos -= 0.3 * inch
157
- c.setFont("Helvetica-Bold", 12)
158
- c.drawString(1 * inch, y_pos, "Violation Details:")
159
- c.setFont("Helvetica", 10)
160
- y_pos -= 0.3 * inch
161
-
162
- if not violations:
163
- c.drawString(1 * inch, y_pos, "No violations detected.")
164
- else:
165
- for v in violations:
166
- text = (
167
- f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
168
- f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})"
169
- )
170
- c.drawString(1 * inch, y_pos, text)
171
- y_pos -= 0.25 * inch
172
- if y_pos < 1 * inch:
173
- c.showPage()
174
- y_pos = 10 * inch
175
-
176
- c.save()
177
- pdf_file.seek(0)
178
-
179
- # Save PDF
180
- pdf_filename = f"violation_report_{int(time.time())}.pdf"
181
- pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
182
- with open(pdf_path, "wb") as f:
183
- f.write(pdf_file.getvalue())
184
-
185
- return pdf_path, f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}", pdf_file
186
- except Exception as e:
187
- logger.error(f"PDF generation failed: {e}")
188
- return None, None, None
189
-
190
- def push_report_to_salesforce(violations, score, pdf_file):
191
- """Upload report to Salesforce."""
192
- try:
193
- sf = connect_to_salesforce()
194
-
195
- # Create violation details text
196
- violations_text = "\n".join(
197
- f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
198
- f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})"
199
- for v in violations
200
- ) or "No violations detected."
201
-
202
- # Create Salesforce record
203
- record_data = {
204
- "Compliance_Score__c": score,
205
- "Violations_Found__c": len(violations),
206
- "Violations_Details__c": violations_text,
207
- "Status__c": "Pending Review"
208
- }
209
- record = sf.Safety_Video_Report__c.create(record_data)
210
- record_id = record["id"]
211
-
212
- # Upload PDF if available
213
- pdf_url = ""
214
- if pdf_file:
215
- encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode("utf-8")
216
- content_version = sf.ContentVersion.create({
217
- "Title": f"Safety_Report_{record_id}",
218
- "PathOnClient": f"report_{record_id}.pdf",
219
- "VersionData": encoded_pdf,
220
- "FirstPublishLocationId": record_id
221
- })
222
- pdf_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
223
-
224
- return record_id, pdf_url
225
- except Exception as e:
226
- logger.error(f"Salesforce upload failed: {e}")
227
- return None, ""
228
 
229
- # ==========================
230
- # Video Processing
231
- # ==========================
232
  def process_video(video_path):
233
- """Analyze video for safety violations."""
234
- try:
235
- cap = cv2.VideoCapture(video_path)
236
- fps = cap.get(cv2.CAP_PROP_FPS) or 30
237
- frame_count = 0
238
- violations = []
239
- snapshots = []
240
- workers = []
241
- snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
242
-
243
  while cap.isOpened():
244
  ret, frame = cap.read()
245
  if not ret:
246
  break
247
-
248
- if frame_count % CONFIG["FRAME_SKIP"] != 0:
249
- frame_count += 1
250
- continue
251
-
252
- # Run detection
253
- results = model(frame, device=device)
254
- current_time = frame_count / fps
255
-
256
- for result in results:
257
- for box in result.boxes:
258
- cls = int(box.cls)
259
- conf = float(box.conf)
260
- label = CONFIG["VIOLATION_LABELS"].get(cls)
261
-
262
- if not label or conf < CONFIG["CONFIDENCE_THRESHOLD"].get(label, 0.3):
263
- continue
264
-
265
- bbox = box.xywh.cpu().numpy()[0]
266
- detection = {
267
- "frame": frame_count,
268
- "violation": label,
269
- "confidence": conf,
270
- "bounding_box": bbox,
271
- "timestamp": current_time
272
- }
273
-
274
- # Track worker
275
- matched_worker = None
276
- max_iou = 0
277
- for worker in workers:
278
- iou = calculate_iou(worker["bbox"], bbox)
279
- if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
280
- max_iou = iou
281
- matched_worker = worker
282
-
283
- if matched_worker:
284
- worker_id = matched_worker["id"]
285
- matched_worker["bbox"] = bbox
286
- else:
287
- worker_id = len(workers) + 1
288
- workers.append({"id": worker_id, "bbox": bbox})
289
-
290
- detection["worker_id"] = worker_id
291
- violations.append(detection)
292
-
293
- # Capture snapshot if first detection of this type
294
- if not snapshot_taken[label]:
295
- snapshot_path = os.path.join(
296
- CONFIG["OUTPUT_DIR"],
297
- f"{label}_{frame_count}.jpg"
298
- )
299
- cv2.imwrite(snapshot_path, draw_detections(frame.copy(), [detection]))
300
- snapshots.append({
301
- "violation": label,
302
- "frame": frame_count,
303
- "path": snapshot_path,
304
- "url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}"
305
- })
306
- snapshot_taken[label] = True
307
 
308
  frame_count += 1
 
 
 
309
 
310
- cap.release()
311
-
312
- # Filter violations (require min frames)
313
- filtered_violations = []
314
- violation_counts = {}
315
- for v in violations:
316
- key = (v["worker_id"], v["violation"])
317
- violation_counts[key] = violation_counts.get(key, 0) + 1
318
-
319
- for v in violations:
320
- if violation_counts[(v["worker_id"], v["violation"])] >= CONFIG["MIN_VIOLATION_FRAMES"]:
321
- filtered_violations.append(v)
322
-
323
- # Calculate safety score
324
- penalty_weights = {
325
- "no_helmet": 25,
326
- "no_harness": 30,
327
- "unsafe_posture": 20,
328
- "unsafe_zone": 35,
329
- "improper_tool_use": 25
330
- }
331
- unique_violations = set((v["worker_id"], v["violation"]) for v in filtered_violations)
332
- total_penalty = sum(penalty_weights.get(v, 0) for _, v in unique_violations)
333
- safety_score = max(100 - total_penalty, 0)
334
-
335
- return {
336
- "violations": filtered_violations,
337
- "snapshots": snapshots,
338
- "score": safety_score,
339
- "message": ""
340
- }
341
- except Exception as e:
342
- logger.error(f"Video processing failed: {e}")
343
- return {
344
- "violations": [],
345
- "snapshots": [],
346
- "score": 100,
347
- "message": f"Error: {str(e)}"
348
- }
349
 
350
- # ==========================
351
- # Gradio Interface
352
- # ==========================
353
  def analyze_video(video_file):
354
- """Gradio interface function."""
355
  if not video_file:
356
- return "No video uploaded", "", "", "", ""
357
 
358
  try:
359
- # Process video
360
- result = process_video(video_file)
361
- if result["message"]:
362
- return result["message"], "", "", "", ""
363
 
364
- # Generate report
365
- pdf_path, pdf_url, pdf_file = generate_violation_pdf(
366
- result["violations"],
367
- result["score"]
368
- )
369
- record_id, sf_url = push_report_to_salesforce(
370
- result["violations"],
371
- result["score"],
372
- pdf_file
373
- )
374
-
375
- # Format outputs
376
  violation_table = (
377
- "| Violation Type | Timestamp (s) | Confidence | Worker ID |\n"
378
- "|------------------------|---------------|------------|-----------|\n" +
379
  "\n".join(
380
- f"| {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation']):<22} | "
381
- f"{v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
382
- for v in result["violations"]
383
- ) if result["violations"] else "No violations detected."
384
  )
385
 
386
- snapshots_md = "\n".join(
387
- f"**{CONFIG['DISPLAY_NAMES'].get(s['violation'], s['violation'])}** "
388
- f"(Frame {s['frame']}): ![]({s['url']})"
389
- for s in result["snapshots"]
390
- ) if result["snapshots"] else "No snapshots available."
391
-
392
  return (
393
  violation_table,
394
- f"Safety Score: {result['score']}%",
395
- snapshots_md,
396
- f"Salesforce Record: {record_id or 'N/A'}",
397
- sf_url or pdf_url or "N/A"
398
  )
399
  except Exception as e:
400
- return f"Error: {str(e)}", "", "", "", ""
401
 
402
- # Launch Gradio App
403
  interface = gr.Interface(
404
  fn=analyze_video,
405
  inputs=gr.Video(label="Upload Site Video"),
406
  outputs=[
407
  gr.Markdown("## Detected Violations"),
408
- gr.Textbox(label="Safety Score"),
409
- gr.Markdown("## Violation Snapshots"),
410
- gr.Textbox(label="Salesforce Record ID"),
411
- gr.Textbox(label="Report URL")
412
  ],
413
- title="AI Safety Compliance Analyzer",
414
- description=(
415
- "Upload worksite video to detect safety violations. "
416
- "Supported violations: Missing Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use."
417
- )
418
  )
419
 
420
  if __name__ == "__main__":
421
- interface.launch(share=True)
 
2
  import cv2
3
  import gradio as gr
4
  import torch
 
5
  from ultralytics import YOLO
6
  import time
 
 
 
 
 
 
7
  import logging
8
+ from concurrent.futures import ThreadPoolExecutor
9
 
10
  # ==========================
11
+ # Optimized Configuration
12
  # ==========================
13
  CONFIG = {
14
+ "MODEL_PATH": "yolov8_safety.pt",
 
15
  "OUTPUT_DIR": "static/output",
16
  "VIOLATION_LABELS": {
17
  0: "no_helmet",
 
20
  3: "unsafe_zone",
21
  4: "improper_tool_use"
22
  },
23
+ "FRAME_SKIP": 10, # Increased for faster processing
24
+ "MAX_PROCESSING_TIME": 60,
25
+ "CONFIDENCE_THRESHOLD": 0.35, # Balanced threshold
26
+ "MIN_VIOLATION_FRAMES": 2, # Reduced from 3 to 2
27
+ "GPU_ACCELERATION": True # Enable GPU if available
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  }
29
 
30
  # Setup logging
 
32
  logger = logging.getLogger(__name__)
33
 
34
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
35
+ device = torch.device("cuda" if torch.cuda.is_available() and CONFIG["GPU_ACCELERATION"] else "cpu")
36
  logger.info(f"Using device: {device}")
37
 
38
+ # Load model
39
+ model = YOLO(CONFIG["MODEL_PATH"]).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ def process_frame(frame, frame_count, fps):
42
+ """Process a single frame for violations"""
43
+ results = model(frame, device=device, verbose=False) # Disable verbose logging
44
+ detections = []
45
+
46
+ for result in results:
47
+ for box in result.boxes:
48
+ cls = int(box.cls)
49
+ conf = float(box.conf)
50
+ if conf > CONFIG["CONFIDENCE_THRESHOLD"] and cls in CONFIG["VIOLATION_LABELS"]:
51
+ detections.append({
52
+ "frame": frame_count,
53
+ "violation": CONFIG["VIOLATION_LABELS"][cls],
54
+ "confidence": conf,
55
+ "bounding_box": box.xywh.cpu().numpy()[0],
56
+ "timestamp": frame_count / fps
57
+ })
58
+ return detections
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
 
 
 
60
  def process_video(video_path):
61
+ """Optimized video processing with parallel frame processing"""
62
+ start_time = time.time()
63
+ cap = cv2.VideoCapture(video_path)
64
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30
65
+ violations = []
66
+ frame_count = 0
67
+
68
+ with ThreadPoolExecutor() as executor:
69
+ futures = []
 
70
  while cap.isOpened():
71
  ret, frame = cap.read()
72
  if not ret:
73
  break
74
+
75
+ if frame_count % CONFIG["FRAME_SKIP"] == 0:
76
+ futures.append(executor.submit(process_frame, frame.copy(), frame_count, fps))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  frame_count += 1
79
+ if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
80
+ logger.info("Processing time limit reached")
81
+ break
82
 
83
+ for future in futures:
84
+ violations.extend(future.result())
85
+
86
+ cap.release()
87
+
88
+ # Filter violations by frequency
89
+ violation_counts = {}
90
+ for v in violations:
91
+ key = (v["violation"], int(v["timestamp"]))
92
+ violation_counts[key] = violation_counts.get(key, 0) + 1
93
+
94
+ filtered_violations = [
95
+ v for v in violations
96
+ if violation_counts.get((v["violation"], int(v["timestamp"])), 0) >= CONFIG["MIN_VIOLATION_FRAMES"]
97
+ ]
98
+
99
+ return filtered_violations, time.time() - start_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
 
 
 
101
  def analyze_video(video_file):
102
+ """Optimized Gradio interface function"""
103
  if not video_file:
104
+ return "No video uploaded", "", "", ""
105
 
106
  try:
107
+ start_time = time.time()
108
+ violations, processing_time = process_video(video_file)
 
 
109
 
110
+ # Generate simple output (removed PDF generation for speed)
 
 
 
 
 
 
 
 
 
 
 
111
  violation_table = (
112
+ "| Violation Type | Timestamp (s) | Confidence |\n"
113
+ "|----------------|---------------|------------|\n" +
114
  "\n".join(
115
+ f"| {v['violation']:<14} | {v['timestamp']:.1f} | {v['confidence']:.2f} |"
116
+ for v in violations
117
+ ) if violations else "No violations detected."
 
118
  )
119
 
 
 
 
 
 
 
120
  return (
121
  violation_table,
122
+ f"Processing Time: {processing_time:.1f}s",
123
+ f"Violations Found: {len(violations)}",
124
+ f"Analysis Completed in {time.time()-start_time:.1f}s"
 
125
  )
126
  except Exception as e:
127
+ return f"Error: {str(e)}", "", "", ""
128
 
129
+ # Simplified Gradio Interface
130
  interface = gr.Interface(
131
  fn=analyze_video,
132
  inputs=gr.Video(label="Upload Site Video"),
133
  outputs=[
134
  gr.Markdown("## Detected Violations"),
135
+ gr.Textbox(label="Processing Info"),
136
+ gr.Textbox(label="Summary"),
137
+ gr.Textbox(label="Status")
 
138
  ],
139
+ title="Fast Safety Compliance Analyzer",
140
+ description="Optimized version for quick safety violation detection"
 
 
 
141
  )
142
 
143
  if __name__ == "__main__":
144
+ interface.launch()