PrashanthB461 commited on
Commit
8f347c4
·
verified ·
1 Parent(s): 2d42edf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +361 -230
app.py CHANGED
@@ -15,10 +15,10 @@ import logging
15
  from retrying import retry
16
 
17
  # ==========================
18
- # Configuration
19
  # ==========================
20
  CONFIG = {
21
- "MODEL_PATH": "yolov8_safety.pt", # Your custom-trained model
22
  "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
@@ -28,7 +28,7 @@ CONFIG = {
28
  3: "unsafe_zone",
29
  4: "improper_tool_use"
30
  },
31
- "CLASS_COLORS": { # Bounding box colors
32
  "no_helmet": (0, 0, 255), # Red
33
  "no_harness": (0, 165, 255), # Orange
34
  "unsafe_posture": (0, 255, 0), # Green
@@ -36,28 +36,24 @@ CONFIG = {
36
  "improper_tool_use": (255, 255, 0) # Yellow
37
  },
38
  "DISPLAY_NAMES": {
39
- "no_helmet": "No Helmet",
40
- "no_harness": "No Safety Harness",
41
- "unsafe_posture": "Unsafe Posture",
42
  "unsafe_zone": "Unsafe Zone Entry",
43
  "improper_tool_use": "Improper Tool Use"
44
  },
45
- "SF_CREDENTIALS": { # Salesforce credentials
46
  "username": "prashanth1ai@safety.com",
47
  "password": "SaiPrash461",
48
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
- "FRAME_SKIP": 20, # Increased to process every 20th frame (faster)
53
- "MAX_PROCESSING_TIME": 25, # Max processing time (seconds), leaving 5s for post-processing
54
- "CONFIDENCE_THRESHOLD": { # Per-class thresholds
55
- "no_helmet": 0.4,
56
- "no_harness": 0.3,
57
- "unsafe_posture": 0.25,
58
- "unsafe_zone": 0.3,
59
- "improper_tool_use": 0.35
60
- }
61
  }
62
 
63
  # Setup logging
@@ -65,339 +61,474 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
65
  logger = logging.getLogger(__name__)
66
 
67
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
 
68
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
  logger.info(f"Using device: {device}")
70
 
71
- # ==========================
72
- # Load YOLOv8 Model
73
- # ==========================
74
  def load_model():
75
  try:
76
  if os.path.isfile(CONFIG["MODEL_PATH"]):
77
- model = YOLO(CONFIG["MODEL_PATH"]).to(device)
78
- logger.info("Loaded custom safety model")
79
  else:
80
- model = YOLO(CONFIG["FALLBACK_MODEL"]).to(device)
81
- logger.warning("Using fallback model (lower accuracy)")
 
 
 
 
82
  return model
83
  except Exception as e:
84
- logger.error(f"Model load failed: {e}")
85
  raise
86
 
87
  model = load_model()
88
 
89
  # ==========================
90
- # Core Detection Functions
91
  # ==========================
92
  def draw_detections(frame, detections):
93
- """Draw bounding boxes with labels on frame."""
94
  for det in detections:
95
  label = det["violation"]
96
- conf = det["confidence"]
97
  x, y, w, h = det["bounding_box"]
98
- x1, y1 = int(x - w/2), int(y - h/2)
99
- x2, y2 = int(x + w/2), int(y + h/2)
 
 
 
 
100
 
101
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
102
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
103
- cv2.putText(frame, f"{label}: {conf:.2f}", (x1, y1-10),
 
 
104
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
105
  return frame
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  # ==========================
108
- # Salesforce Integration
109
  # ==========================
110
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
111
  def connect_to_salesforce():
112
  try:
113
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
114
- logger.info("Salesforce connection successful")
 
115
  return sf
116
  except Exception as e:
117
- logger.error(f"Salesforce login failed: {e}")
118
  raise
119
 
120
  def generate_violation_pdf(violations, score):
121
- """Generate PDF report with violations."""
122
  try:
 
 
123
  pdf_file = BytesIO()
124
  c = canvas.Canvas(pdf_file, pagesize=letter)
125
- c.setFont("Helvetica-Bold", 14)
126
- c.drawString(1 * inch, 10.5 * inch, "Worksite Safety Violation Report")
127
  c.setFont("Helvetica", 12)
128
-
129
- # Report metadata
130
- y_pos = 9.8 * inch
131
- report_data = [
132
- ("Compliance Score", f"{score}%"),
133
- ("Total Violations", len(violations)),
134
- ("Date", time.strftime("%Y-%m-%d")),
135
- ("Time", time.strftime("%H:%M:%S"))
136
- ]
137
- for label, value in report_data:
138
- c.drawString(1 * inch, y_pos, f"{label}: {value}")
139
- y_pos -= 0.4 * inch
140
-
141
- # Violation details
142
- y_pos -= 0.3 * inch
143
- c.setFont("Helvetica-Bold", 12)
144
- c.drawString(1 * inch, y_pos, "Violation Details:")
145
  c.setFont("Helvetica", 10)
146
- y_pos -= 0.3 * inch
147
-
 
 
 
 
 
 
 
 
 
 
 
 
148
  if not violations:
149
- c.drawString(1 * inch, y_pos, "No violations detected.")
150
  else:
151
  for v in violations:
152
- text = (
153
- f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
154
- f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})"
155
- )
156
- c.drawString(1 * inch, y_pos, text)
157
- y_pos -= 0.25 * inch
158
- if y_pos < 1 * inch:
159
  c.showPage()
160
- y_pos = 10 * inch
161
-
 
 
162
  c.save()
163
  pdf_file.seek(0)
164
-
165
- # Save PDF
166
- pdf_filename = f"violation_report_{int(time.time())}.pdf"
167
- pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
168
  with open(pdf_path, "wb") as f:
169
  f.write(pdf_file.getvalue())
170
-
171
- return pdf_path, f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}", pdf_file
 
172
  except Exception as e:
173
- logger.error(f"PDF generation failed: {e}")
174
- return None, None, None
175
 
176
- def push_report_to_salesforce(violations, score, pdf_file):
177
- """Upload report to Salesforce."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  try:
179
  sf = connect_to_salesforce()
180
-
181
- # Create violation details text
182
  violations_text = "\n".join(
183
- f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
184
- f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})"
185
  for v in violations
186
  ) or "No violations detected."
187
-
188
- # Create Salesforce record
189
  record_data = {
190
  "Compliance_Score__c": score,
191
  "Violations_Found__c": len(violations),
192
  "Violations_Details__c": violations_text,
193
- "Status__c": "Pending Review"
 
194
  }
195
- record = sf.Safety_Video_Report__c.create(record_data)
 
 
 
 
 
 
 
196
  record_id = record["id"]
197
-
198
- # Upload PDF if available
199
- pdf_url = ""
200
  if pdf_file:
201
- encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode("utf-8")
202
- content_version = sf.ContentVersion.create({
203
- "Title": f"Safety_Report_{record_id}",
204
- "PathOnClient": f"report_{record_id}.pdf",
205
- "VersionData": encoded_pdf,
206
- "FirstPublishLocationId": record_id
207
- })
208
- pdf_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
209
-
 
 
210
  return record_id, pdf_url
211
  except Exception as e:
212
- logger.error(f"Salesforce upload failed: {e}")
213
  return None, ""
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  # ==========================
216
- # Video Processing
217
  # ==========================
218
- def process_video(video_path):
219
- """Analyze video for safety violations within 30 seconds."""
220
  try:
221
- start_time = time.time()
222
- cap = cv2.VideoCapture(video_path)
223
- fps = cap.get(cv2.CAP_PROP_FPS) or 30
224
- frame_count = 0
 
 
 
 
 
225
  violations = []
226
  snapshots = []
 
 
 
 
 
 
 
 
 
227
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
228
- frames_to_save = [] # Store frames for snapshot saving later
229
-
230
- while cap.isOpened():
231
- ret, frame = cap.read()
 
 
232
  if not ret:
233
  break
234
-
235
- # Stop if processing time exceeds limit
236
- if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
237
- logger.info("Reached max processing time, stopping frame analysis")
238
- break
239
-
240
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
241
  frame_count += 1
242
  continue
 
 
 
 
 
 
243
 
244
- # Run detection
245
  results = model(frame, device=device)
246
- current_time = frame_count / fps
247
 
 
248
  for result in results:
249
- for box in result.boxes:
 
250
  cls = int(box.cls)
251
  conf = float(box.conf)
252
- label = CONFIG["VIOLATION_LABELS"].get(cls)
253
 
254
- if not label or conf < CONFIG["CONFIDENCE_THRESHOLD"].get(label, 0.3):
255
  continue
 
 
 
 
 
256
 
257
- bbox = box.xywh.cpu().numpy()[0]
258
- detection = {
259
  "frame": frame_count,
260
  "violation": label,
261
- "confidence": conf,
262
  "bounding_box": bbox,
263
  "timestamp": current_time
264
- }
265
-
266
- violations.append(detection)
267
-
268
- # Store frame for snapshot if first detection of this type
269
- if not snapshot_taken[label]:
270
- frames_to_save.append({
271
- "frame": frame.copy(),
272
- "detection": detection,
273
- "label": label
274
- })
275
- snapshot_taken[label] = True
276
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  frame_count += 1
 
 
 
278
 
279
- cap.release()
280
-
281
- # Save snapshots after processing all frames
282
- for item in frames_to_save:
283
- frame = item["frame"]
284
- detection = item["detection"]
285
- label = item["label"]
286
- snapshot_path = os.path.join(
287
- CONFIG["OUTPUT_DIR"],
288
- f"{label}_{detection['frame']}.jpg"
289
- )
290
- cv2.imwrite(snapshot_path, draw_detections(frame, [detection]))
291
- snapshots.append({
292
- "violation": label,
293
- "frame": detection["frame"],
294
- "path": snapshot_path,
295
- "url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}"
296
- })
297
-
298
- # Calculate safety score
299
- penalty_weights = {
300
- "no_helmet": 25,
301
- "no_harness": 30,
302
- "unsafe_posture": 20,
303
- "unsafe_zone": 35,
304
- "improper_tool_use": 25
305
- }
306
- unique_violations = set(v["violation"] for v in violations)
307
- total_penalty = sum(penalty_weights.get(v, 0) for v in unique_violations)
308
- safety_score = max(100 - total_penalty, 0)
309
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  return {
311
  "violations": violations,
312
  "snapshots": snapshots,
313
- "score": safety_score,
 
 
314
  "message": ""
315
  }
316
  except Exception as e:
317
- logger.error(f"Video processing failed: {e}")
318
  return {
319
  "violations": [],
320
  "snapshots": [],
321
  "score": 100,
322
- "message": f"Error: {str(e)}"
 
 
323
  }
324
 
325
  # ==========================
326
  # Gradio Interface
327
  # ==========================
328
- def analyze_video(video_file):
329
- """Gradio interface function."""
330
  if not video_file:
331
- return "No video uploaded", "", "", "", ""
332
-
333
  try:
334
- start_time = time.time()
335
-
336
- # Process video
337
- result = process_video(video_file)
338
- if result["message"]:
339
- return result["message"], "", "", "", ""
340
-
341
- # Generate report
342
- pdf_path, pdf_url, pdf_file = generate_violation_pdf(
343
- result["violations"],
344
- result["score"]
345
- )
346
- record_id, sf_url = push_report_to_salesforce(
347
- result["violations"],
348
- result["score"],
349
- pdf_file
350
- )
351
-
352
- # Check total time
353
- total_time = time.time() - start_time
354
- if total_time > 30:
355
- logger.warning(f"Processing took {total_time:.2f}s, exceeded 30s target")
356
-
357
- # Format outputs
358
- violation_table = (
359
- "| Violation Type | Timestamp (s) | Confidence |\n"
360
- "|------------------------|---------------|------------|\n" +
361
- "\n".join(
362
- f"| {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation']):<22} | "
363
- f"{v['timestamp']:.2f} | {v['confidence']:.2f} |"
364
- for v in result["violations"]
365
- ) if result["violations"] else "No violations detected."
366
- )
367
-
368
- snapshots_md = "\n".join(
369
- f"**{CONFIG['DISPLAY_NAMES'].get(s['violation'], s['violation'])}** "
370
- f"(Frame {s['frame']}): ![]({s['url']})"
371
- for s in result["snapshots"]
372
- ) if result["snapshots"] else "No snapshots available."
373
-
374
- return (
375
  violation_table,
376
  f"Safety Score: {result['score']}%",
377
- snapshots_md,
378
- f"Salesforce Record: {record_id or 'N/A'}",
379
- sf_url or pdf_url or "N/A"
380
  )
381
  except Exception as e:
382
- return f"Error: {str(e)}", "", "", "", ""
 
383
 
384
- # Launch Gradio App
385
  interface = gr.Interface(
386
- fn=analyze_video,
387
  inputs=gr.Video(label="Upload Site Video"),
388
- outputs=[
389
- gr.Markdown("## Detected Violations"),
390
- gr.Textbox(label="Safety Score"),
391
- gr.Markdown("## Violation Snapshots"),
392
  gr.Textbox(label="Salesforce Record ID"),
393
- gr.Textbox(label="Report URL")
394
  ],
395
- title="AI Safety Compliance Analyzer",
396
- description=(
397
- "Upload worksite video to detect safety violations. "
398
- "Supported violations: Missing Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use."
399
- )
400
  )
401
 
402
  if __name__ == "__main__":
403
- interface.launch(share=True)
 
 
15
  from retrying import retry
16
 
17
  # ==========================
18
+ # Enhanced Configuration
19
  # ==========================
20
  CONFIG = {
21
+ "MODEL_PATH": "yolov8_safety.pt",
22
  "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
 
28
  3: "unsafe_zone",
29
  4: "improper_tool_use"
30
  },
31
+ "CLASS_COLORS": {
32
  "no_helmet": (0, 0, 255), # Red
33
  "no_harness": (0, 165, 255), # Orange
34
  "unsafe_posture": (0, 255, 0), # Green
 
36
  "improper_tool_use": (255, 255, 0) # Yellow
37
  },
38
  "DISPLAY_NAMES": {
39
+ "no_helmet": "No Helmet Violation",
40
+ "no_harness": "No Harness Violation",
41
+ "unsafe_posture": "Unsafe Posture Violation",
42
  "unsafe_zone": "Unsafe Zone Entry",
43
  "improper_tool_use": "Improper Tool Use"
44
  },
45
+ "SF_CREDENTIALS": {
46
  "username": "prashanth1ai@safety.com",
47
  "password": "SaiPrash461",
48
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
+ "FRAME_SKIP": 5, # Reduced for better detection
53
+ "MAX_PROCESSING_TIME": 60,
54
+ "CONFIDENCE_THRESHOLD": 0.25, # Lower threshold for all violations
55
+ "IOU_THRESHOLD": 0.4,
56
+ "MIN_VIOLATION_FRAMES": 3 # Minimum consecutive frames to confirm violation
 
 
 
 
57
  }
58
 
59
  # Setup logging
 
61
  logger = logging.getLogger(__name__)
62
 
63
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
64
+
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
  logger.info(f"Using device: {device}")
67
 
 
 
 
68
  def load_model():
69
  try:
70
  if os.path.isfile(CONFIG["MODEL_PATH"]):
71
+ model_path = CONFIG["MODEL_PATH"]
72
+ logger.info(f"Model loaded: {model_path}")
73
  else:
74
+ model_path = CONFIG["FALLBACK_MODEL"]
75
+ logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
76
+ if not os.path.isfile(model_path):
77
+ logger.info(f"Downloading fallback model: {model_path}")
78
+ torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
79
+ model = YOLO(model_path).to(device)
80
  return model
81
  except Exception as e:
82
+ logger.error(f"Failed to load model: {e}")
83
  raise
84
 
85
  model = load_model()
86
 
87
  # ==========================
88
+ # Enhanced Helper Functions
89
  # ==========================
90
  def draw_detections(frame, detections):
91
+ """Draw bounding boxes and labels on frame"""
92
  for det in detections:
93
  label = det["violation"]
94
+ confidence = det["confidence"]
95
  x, y, w, h = det["bounding_box"]
96
+
97
+ # Convert from center coordinates to corner coordinates
98
+ x1 = int(x - w/2)
99
+ y1 = int(y - h/2)
100
+ x2 = int(x + w/2)
101
+ y2 = int(y + h/2)
102
 
103
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
104
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
105
+
106
+ display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
107
+ cv2.putText(frame, display_text, (x1, y1-10),
108
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
109
  return frame
110
 
111
+ def calculate_iou(box1, box2):
112
+ """Calculate Intersection over Union (IoU) for two bounding boxes."""
113
+ x1, y1, w1, h1 = box1
114
+ x2, y2, w2, h2 = box2
115
+
116
+ # Convert to top-left and bottom-right coordinates
117
+ x1_min, y1_min = x1 - w1/2, y1 - h1/2
118
+ x1_max, y1_max = x1 + w1/2, y1 + h1/2
119
+ x2_min, y2_min = x2 - w2/2, y2 - h2/2
120
+ x2_max, y2_max = x2 + w2/2, y2 + h2/2
121
+
122
+ # Calculate intersection
123
+ x_min = max(x1_min, x2_min)
124
+ y_min = max(y1_min, y2_min)
125
+ x_max = min(x1_max, x2_max)
126
+ y_max = min(y1_max, y2_max)
127
+
128
+ intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
129
+ area1 = w1 * h1
130
+ area2 = w2 * h2
131
+ union = area1 + area2 - intersection
132
+
133
+ return intersection / union if union > 0 else 0
134
+
135
  # ==========================
136
+ # Salesforce Integration (unchanged)
137
  # ==========================
138
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
139
  def connect_to_salesforce():
140
  try:
141
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
142
+ logger.info("Connected to Salesforce")
143
+ sf.describe()
144
  return sf
145
  except Exception as e:
146
+ logger.error(f"Salesforce connection failed: {e}")
147
  raise
148
 
149
  def generate_violation_pdf(violations, score):
 
150
  try:
151
+ pdf_filename = f"violations_{int(time.time())}.pdf"
152
+ pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
153
  pdf_file = BytesIO()
154
  c = canvas.Canvas(pdf_file, pagesize=letter)
 
 
155
  c.setFont("Helvetica", 12)
156
+ c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  c.setFont("Helvetica", 10)
158
+
159
+ y_position = 9.5 * inch
160
+ report_data = {
161
+ "Compliance Score": f"{score}%",
162
+ "Violations Found": len(violations),
163
+ "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
164
+ }
165
+ for key, value in report_data.items():
166
+ c.drawString(1 * inch, y_position, f"{key}: {value}")
167
+ y_position -= 0.3 * inch
168
+
169
+ y_position -= 0.3 * inch
170
+ c.drawString(1 * inch, y_position, "Violation Details:")
171
+ y_position -= 0.3 * inch
172
  if not violations:
173
+ c.drawString(1 * inch, y_position, "No violations detected.")
174
  else:
175
  for v in violations:
176
+ display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
177
+ text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
178
+ c.drawString(1 * inch, y_position, text)
179
+ y_position -= 0.3 * inch
180
+ if y_position < 1 * inch:
 
 
181
  c.showPage()
182
+ c.setFont("Helvetica", 10)
183
+ y_position = 10 * inch
184
+
185
+ c.showPage()
186
  c.save()
187
  pdf_file.seek(0)
188
+
 
 
 
189
  with open(pdf_path, "wb") as f:
190
  f.write(pdf_file.getvalue())
191
+ public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
192
+ logger.info(f"PDF generated: {public_url}")
193
+ return pdf_path, public_url, pdf_file
194
  except Exception as e:
195
+ logger.error(f"Error generating PDF: {e}")
196
+ return "", "", None
197
 
198
+ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
199
+ try:
200
+ if not pdf_file:
201
+ logger.error("No PDF file provided for upload")
202
+ return ""
203
+ encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
204
+ content_version_data = {
205
+ "Title": f"Safety_Violation_Report_{int(time.time())}",
206
+ "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
207
+ "VersionData": encoded_pdf,
208
+ "FirstPublishLocationId": report_id
209
+ }
210
+ content_version = sf.ContentVersion.create(content_version_data)
211
+ result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
212
+ if not result['records']:
213
+ logger.error("Failed to retrieve ContentVersion")
214
+ return ""
215
+ file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
216
+ logger.info(f"PDF uploaded to Salesforce: {file_url}")
217
+ return file_url
218
+ except Exception as e:
219
+ logger.error(f"Error uploading PDF to Salesforce: {e}")
220
+ return ""
221
+
222
+ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
223
  try:
224
  sf = connect_to_salesforce()
 
 
225
  violations_text = "\n".join(
226
+ f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
 
227
  for v in violations
228
  ) or "No violations detected."
229
+ pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
230
+
231
  record_data = {
232
  "Compliance_Score__c": score,
233
  "Violations_Found__c": len(violations),
234
  "Violations_Details__c": violations_text,
235
+ "Status__c": "Pending",
236
+ "PDF_Report_URL__c": pdf_url
237
  }
238
+ logger.info(f"Creating Salesforce record with data: {record_data}")
239
+ try:
240
+ record = sf.Safety_Video_Report__c.create(record_data)
241
+ logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
242
+ except Exception as e:
243
+ logger.error(f"Failed to create Safety_Video_Report__c: {e}")
244
+ record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
245
+ logger.warning(f"Fell back to Account record: {record['id']}")
246
  record_id = record["id"]
247
+
 
 
248
  if pdf_file:
249
+ uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
250
+ if uploaded_url:
251
+ try:
252
+ sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
253
+ logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
254
+ except Exception as e:
255
+ logger.error(f"Failed to update Safety_Video_Report__c: {e}")
256
+ sf.Account.update(record_id, {"Description": uploaded_url})
257
+ logger.info(f"Updated Account record {record_id} with PDF URL")
258
+ pdf_url = uploaded_url
259
+
260
  return record_id, pdf_url
261
  except Exception as e:
262
+ logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
263
  return None, ""
264
 
265
+ def calculate_safety_score(violations):
266
+ penalties = {
267
+ "no_helmet": 25,
268
+ "no_harness": 30,
269
+ "unsafe_posture": 20,
270
+ "unsafe_zone": 35,
271
+ "improper_tool_use": 25
272
+ }
273
+ # Count unique violations per worker
274
+ unique_violations = set()
275
+ for v in violations:
276
+ key = (v["worker_id"], v["violation"])
277
+ unique_violations.add(key)
278
+
279
+ total_penalty = sum(penalties.get(violation, 0) for _, violation in unique_violations)
280
+ score = 100 - total_penalty
281
+ return max(score, 0)
282
+
283
  # ==========================
284
+ # Enhanced Video Processing
285
  # ==========================
286
+ def process_video(video_data):
 
287
  try:
288
+ video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
289
+ with open(video_path, "wb") as f:
290
+ f.write(video_data)
291
+ logger.info(f"Video saved: {video_path}")
292
+
293
+ video = cv2.VideoCapture(video_path)
294
+ if not video.isOpened():
295
+ raise ValueError("Could not open video file")
296
+
297
  violations = []
298
  snapshots = []
299
+ frame_count = 0
300
+ start_time = time.time()
301
+ fps = video.get(cv2.CAP_PROP_FPS)
302
+ if fps <= 0:
303
+ fps = 30 # Default assumption if FPS cannot be determined
304
+
305
+ # Structure to track workers and their violations
306
+ workers = []
307
+ violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
308
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
309
+
310
+ logger.info(f"Processing video with FPS: {fps}")
311
+ logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
312
+
313
+ while True:
314
+ ret, frame = video.read()
315
  if not ret:
316
  break
317
+
 
 
 
 
 
318
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
319
  frame_count += 1
320
  continue
321
+
322
+ if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
323
+ logger.info("Processing time limit reached")
324
+ break
325
+
326
+ current_time = frame_count / fps
327
 
328
+ # Run detection on this frame
329
  results = model(frame, device=device)
 
330
 
331
+ current_detections = []
332
  for result in results:
333
+ boxes = result.boxes
334
+ for box in boxes:
335
  cls = int(box.cls)
336
  conf = float(box.conf)
337
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
338
 
339
+ if label is None:
340
  continue
341
+
342
+ if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
343
+ continue
344
+
345
+ bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
346
 
347
+ current_detections.append({
 
348
  "frame": frame_count,
349
  "violation": label,
350
+ "confidence": round(conf, 2),
351
  "bounding_box": bbox,
352
  "timestamp": current_time
353
+ })
354
+
355
+ # Process detections and associate with workers
356
+ for detection in current_detections:
357
+ # Find matching worker
358
+ matched_worker = None
359
+ max_iou = 0
360
+
361
+ for worker in workers:
362
+ iou = calculate_iou(detection["bounding_box"], worker["bbox"])
363
+ if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
364
+ max_iou = iou
365
+ matched_worker = worker
366
+
367
+ if matched_worker:
368
+ # Update worker's position
369
+ matched_worker["bbox"] = detection["bounding_box"]
370
+ matched_worker["last_seen"] = current_time
371
+ worker_id = matched_worker["id"]
372
+ else:
373
+ # New worker
374
+ worker_id = len(workers) + 1
375
+ workers.append({
376
+ "id": worker_id,
377
+ "bbox": detection["bounding_box"],
378
+ "first_seen": current_time,
379
+ "last_seen": current_time
380
+ })
381
+
382
+ # Add to violation history
383
+ detection["worker_id"] = worker_id
384
+ violation_history[detection["violation"]].append(detection)
385
+
386
  frame_count += 1
387
+
388
+ video.release()
389
+ os.remove(video_path)
390
 
391
+ # Process violation history to confirm persistent violations
392
+ for violation_type, detections in violation_history.items():
393
+ if not detections:
394
+ continue
395
+
396
+ # Group by worker
397
+ worker_violations = {}
398
+ for det in detections:
399
+ if det["worker_id"] not in worker_violations:
400
+ worker_violations[det["worker_id"]] = []
401
+ worker_violations[det["worker_id"]].append(det)
402
+
403
+ # Check each worker's violations for persistence
404
+ for worker_id, worker_dets in worker_violations.items():
405
+ if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
406
+ # Take the highest confidence detection
407
+ best_detection = max(worker_dets, key=lambda x: x["confidence"])
408
+ violations.append(best_detection)
409
+
410
+ # Capture snapshot if not already taken
411
+ if not snapshot_taken[violation_type]:
412
+ # Get the frame for this violation
413
+ cap = cv2.VideoCapture(video_path)
414
+ cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
415
+ ret, snapshot_frame = cap.read()
416
+ cap.release()
417
+
418
+ if ret:
419
+ # Draw detections on snapshot
420
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
421
+
422
+ snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
423
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
424
+ cv2.imwrite(snapshot_path, snapshot_frame)
425
+ snapshots.append({
426
+ "violation": violation_type,
427
+ "frame": best_detection["frame"],
428
+ "snapshot_path": snapshot_path,
429
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
430
+ })
431
+ snapshot_taken[violation_type] = True
432
+
433
+ # Final processing
434
+ if not violations:
435
+ logger.info("No persistent violations detected")
436
+ return {
437
+ "violations": [],
438
+ "snapshots": [],
439
+ "score": 100,
440
+ "salesforce_record_id": None,
441
+ "violation_details_url": "",
442
+ "message": "No violations detected in the video."
443
+ }
444
+
445
+ score = calculate_safety_score(violations)
446
+ pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
447
+ report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
448
+
449
  return {
450
  "violations": violations,
451
  "snapshots": snapshots,
452
+ "score": score,
453
+ "salesforce_record_id": report_id,
454
+ "violation_details_url": final_pdf_url,
455
  "message": ""
456
  }
457
  except Exception as e:
458
+ logger.error(f"Error processing video: {e}", exc_info=True)
459
  return {
460
  "violations": [],
461
  "snapshots": [],
462
  "score": 100,
463
+ "salesforce_record_id": None,
464
+ "violation_details_url": "",
465
+ "message": f"Error processing video: {e}"
466
  }
467
 
468
  # ==========================
469
  # Gradio Interface
470
  # ==========================
471
+ def gradio_interface(video_file):
 
472
  if not video_file:
473
+ return "No file uploaded.", "", "No file uploaded.", "", ""
 
474
  try:
475
+ yield "Processing video... please wait.", "", "", "", ""
476
+
477
+ with open(video_file, "rb") as f:
478
+ video_data = f.read()
479
+
480
+ result = process_video(video_data)
481
+
482
+ if result.get("message"):
483
+ yield result["message"], "", "", "", ""
484
+ return
485
+
486
+ violation_table = "No violations detected."
487
+ if result["violations"]:
488
+ header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
489
+ separator = "|------------------------|---------------|------------|-----------|\n"
490
+ rows = []
491
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
492
+ for v in result["violations"]:
493
+ display_name = violation_name_map.get(v["violation"], v["violation"])
494
+ row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
495
+ rows.append(row)
496
+ violation_table = header + separator + "\n".join(rows)
497
+
498
+ snapshots_text = "No snapshots captured."
499
+ if result["snapshots"]:
500
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
501
+ snapshots_text = "\n".join(
502
+ f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
503
+ for s in result["snapshots"]
504
+ )
505
+
506
+ yield (
 
 
 
 
 
 
 
 
 
507
  violation_table,
508
  f"Safety Score: {result['score']}%",
509
+ snapshots_text,
510
+ f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
511
+ result["violation_details_url"] or "N/A"
512
  )
513
  except Exception as e:
514
+ logger.error(f"Error in Gradio interface: {e}", exc_info=True)
515
+ yield f"Error: {str(e)}", "", "Error in processing.", "", ""
516
 
 
517
  interface = gr.Interface(
518
+ fn=gradio_interface,
519
  inputs=gr.Video(label="Upload Site Video"),
520
+ outputs=[
521
+ gr.Markdown(label="Detected Safety Violations"),
522
+ gr.Textbox(label="Compliance Score"),
523
+ gr.Markdown(label="Snapshots"),
524
  gr.Textbox(label="Salesforce Record ID"),
525
+ gr.Textbox(label="Violation Details URL")
526
  ],
527
+ title="Worksite Safety Violation Analyzer",
528
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
529
+ allow_flagging="never"
 
 
530
  )
531
 
532
  if __name__ == "__main__":
533
+ logger.info("Launching Enhanced Safety Analyzer App...")
534
+ interface.launch()