PrashanthB461 commited on
Commit
7ecb9d9
·
verified ·
1 Parent(s): baeebd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -139
app.py CHANGED
@@ -16,7 +16,6 @@ from io import BytesIO
16
  import base64
17
  from retrying import retry
18
  from collections import defaultdict
19
- from multiprocessing import cpu_count
20
 
21
  # ========================== # Configuration and Setup # ==========================
22
  os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
@@ -26,7 +25,7 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
26
  logger = logging.getLogger(__name__)
27
  warnings.filterwarnings("ignore")
28
 
29
- # ========================== # Optimized Tracker Implementation (No Face Recognition) # ==========================
30
  class SafetyTracker:
31
  def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
32
  self.track_thresh = track_thresh
@@ -34,9 +33,10 @@ class SafetyTracker:
34
  self.match_thresh = match_thresh
35
  self.frame_rate = frame_rate
36
  self.next_id = 1
37
- self.worker_tracks = {}
38
- self.violation_history = defaultdict(dict)
39
- self.position_history = defaultdict(list)
 
40
 
41
  self.VIOLATION_COOLDOWNS = {
42
  "no_helmet": 30.0,
@@ -46,7 +46,7 @@ class SafetyTracker:
46
  "improper_tool_use": 15.0
47
  }
48
 
49
- def update(self, detections, frame):
50
  current_time = time.time()
51
  new_violations = []
52
 
@@ -56,6 +56,7 @@ class SafetyTracker:
56
  confidence = det['confidence']
57
 
58
  worker_id = self._match_by_position(bbox, label)
 
59
  if worker_id is None:
60
  worker_id = self.next_id
61
  self.next_id += 1
@@ -82,30 +83,40 @@ class SafetyTracker:
82
  return new_violations
83
 
84
  def _match_by_position(self, bbox, label):
85
- x, y = bbox[0], bbox[1]
 
 
86
  for worker_id, positions in self.position_history.items():
87
- for pos in positions[-5:]: # Check last 5 positions
88
- if np.sqrt((x-pos[0])**2 + (y-pos[1])**2) < 100:
89
- return worker_id
 
 
 
 
90
  return None
91
 
92
  def _is_new_violation(self, worker_id, label, current_time):
93
  if label not in self.violation_history[worker_id]:
94
  return True
95
- return (current_time - self.violation_history[worker_id][label]) > self.VIOLATION_COOLDOWNS.get(label, 10.0)
 
 
 
96
 
97
  def _cleanup_tracks(self, current_time):
98
  inactive_ids = [
99
- wid for wid, track in self.worker_tracks.items()
100
  if (current_time - track['last_seen']) > (self.track_buffer / self.frame_rate)
101
  ]
102
- for wid in inactive_ids:
103
- self.worker_tracks.pop(wid, None)
104
- self.position_history.pop(wid, None)
105
- if (current_time - max(self.violation_history[wid].values(), default=0)) > 300:
106
- self.violation_history.pop(wid, None)
 
107
 
108
- # ========================== # Configuration # ==========================
109
  CONFIG = {
110
  "MODEL_PATH": "yolov8_safety.pt",
111
  "FALLBACK_MODEL": "yolov8n.pt",
@@ -125,10 +136,10 @@ CONFIG = {
125
  "improper_tool_use": (255, 255, 0)
126
  },
127
  "DISPLAY_NAMES": {
128
- "no_helmet": "No Helmet",
129
- "no_harness": "No Harness",
130
  "unsafe_posture": "Unsafe Posture",
131
- "unsafe_zone": "Unsafe Zone",
132
  "improper_tool_use": "Improper Tool Use"
133
  },
134
  "SF_CREDENTIALS": {
@@ -153,35 +164,53 @@ CONFIG = {
153
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154
  logger.info(f"Using device: {device}")
155
 
156
- # ========================== # Core Functions # ==========================
157
  def load_model():
158
  try:
159
- model_path = CONFIG["MODEL_PATH"] if os.path.exists(CONFIG["MODEL_PATH"]) else CONFIG["FALLBACK_MODEL"]
160
- if not os.path.exists(model_path):
161
- torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
162
- return YOLO(model_path).to(device)
 
 
 
163
  except Exception as e:
164
  logger.error(f"Model loading failed: {e}")
165
  raise
166
 
167
  model = load_model()
168
 
 
 
 
 
 
169
  def draw_detections(frame, detections):
170
- annotated = frame.copy()
171
  for det in detections:
172
- x, y, w, h = det['bbox']
173
- x1, y1 = int(x-w/2), int(y-h/2)
174
- x2, y2 = int(x+w/2), int(y+h/2)
175
- color = CONFIG["CLASS_COLORS"][det['violation']]
176
- cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
177
- label = f"{CONFIG['DISPLAY_NAMES'][det['violation']]} (Worker {det['worker_id']})"
178
- cv2.putText(annotated, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 2)
179
- return annotated
 
 
 
 
 
 
 
 
180
 
181
  def calculate_safety_score(violations):
182
- penalties = {"no_helmet":25, "no_harness":30, "unsafe_posture":20, "unsafe_zone":35, "improper_tool_use":25}
183
- unique_violations = {v['violation'] for v in violations}
184
- return max(0, 100 - sum(penalties.get(v,0) for v in unique_violations))
 
 
 
185
 
186
  def generate_violation_pdf(violations, score):
187
  try:
@@ -190,7 +219,7 @@ def generate_violation_pdf(violations, score):
190
 
191
  # Header
192
  c.setFont("Helvetica-Bold", 16)
193
- c.drawString(1*inch, 10*inch, "Safety Violation Report")
194
  c.setFont("Helvetica", 12)
195
  c.drawString(1*inch, 9.5*inch, f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}")
196
  c.drawString(1*inch, 9*inch, f"Safety Score: {score}%")
@@ -198,28 +227,29 @@ def generate_violation_pdf(violations, score):
198
  # Violations List
199
  y = 8.5*inch
200
  c.setFont("Helvetica-Bold", 14)
201
- c.drawString(1*inch, y, "Violations Detected:")
202
  y -= 0.3*inch
203
  c.setFont("Helvetica", 10)
204
 
205
  for v in violations:
206
- text = f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']]} at {v['timestamp']:.1f}s"
207
- c.drawString(1.2*inch, y, text)
208
- y -= 0.2*inch
209
  if y < 1*inch:
210
  c.showPage()
211
  y = 10*inch
212
- c.setFont("Helvetica", 10)
 
213
 
214
  c.save()
215
  pdf_buffer.seek(0)
216
 
217
  # Save to file
218
- pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], f"report_{int(time.time())}.pdf")
 
219
  with open(pdf_path, "wb") as f:
220
  f.write(pdf_buffer.getvalue())
221
 
222
- return pdf_path, f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}", pdf_buffer
223
  except Exception as e:
224
  logger.error(f"PDF generation failed: {e}")
225
  return None, None, None
@@ -228,119 +258,173 @@ def generate_violation_pdf(violations, score):
228
  def connect_to_salesforce():
229
  try:
230
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
231
- sf.describe()
232
  return sf
233
  except Exception as e:
234
  logger.error(f"Salesforce connection failed: {e}")
235
  raise
236
 
237
- def upload_to_salesforce(sf, pdf_file, violations, score):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  try:
239
- # Create record
240
  violations_text = "\n".join(
241
- f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']]} at {v['timestamp']:.1f}s"
 
242
  for v in violations
243
  )
244
 
245
- record = sf.Safety_Video_Report__c.create({
246
  "Compliance_Score__c": score,
247
  "Violations_Found__c": len(violations),
248
- "Violations_Details__c": violations_text
249
- })
250
-
251
- # Upload PDF
252
- encoded = base64.b64encode(pdf_file.getvalue()).decode()
253
- content = sf.ContentVersion.create({
254
- "Title": f"Safety_Report_{int(time.time())}",
255
- "PathOnClient": "report.pdf",
256
- "VersionData": encoded,
257
- "FirstPublishLocationId": record['id']
258
- })
259
 
260
- return record['id'], f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content['id']}"
 
 
 
 
 
261
  except Exception as e:
262
- logger.error(f"Salesforce upload failed: {e}")
263
- return None, None
264
 
265
- # ========================== # Video Processing # ==========================
266
  def process_video(video_data):
267
  try:
268
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
269
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
270
  with open(video_path, "wb") as f:
271
  f.write(video_data)
272
-
273
  cap = cv2.VideoCapture(video_path)
 
 
 
274
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
 
275
  tracker = SafetyTracker(frame_rate=fps)
276
  snapshots = []
277
-
278
- while cap.isOpened():
279
- ret, frame = cap.read()
280
- if not ret:
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  break
282
-
283
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
284
- results = model(frame, verbose=False)
285
-
286
- detections = []
287
- for box in results[0].boxes:
288
- cls = int(box.cls)
289
- label = CONFIG["VIOLATION_LABELS"].get(cls)
290
- if label and box.conf > CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.3):
291
- detections.append({
292
- 'bbox': box.xywh[0].cpu().numpy(),
293
- 'violation': label,
294
- 'confidence': float(box.conf)
295
- })
296
-
297
- new_violations = tracker.update(detections, frame)
298
 
299
- for violation in new_violations:
300
- snapshot = draw_detections(frame.copy(), [violation])
301
- timestamp = time.strftime("%Y%m%d_%H%M%S")
302
- img_path = os.path.join(CONFIG["OUTPUT_DIR"], f"violation_{violation['worker_id']}_{timestamp}.jpg")
303
- cv2.imwrite(img_path, cv2.cvtColor(snapshot, cv2.COLOR_RGB2BGR),
304
- [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
305
- snapshots.append({
306
- 'path': img_path,
307
- 'url': f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(img_path)}",
308
- 'violation': violation
309
- })
 
 
 
310
 
311
- yield f"Processing frame {int(cap.get(cv2.CAP_PROP_POS_FRAMES))}...", "", "", "", ""
312
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  cap.release()
314
- os.remove(video_path)
315
-
316
- if not snapshots:
317
- yield "No violations detected", "Safety Score: 100%", "No snapshots", "N/A", "N/A"
 
 
 
 
 
 
 
318
  return
319
-
320
- score = calculate_safety_score([v['violation'] for v in snapshots])
321
- pdf_path, pdf_url, pdf_file = generate_violation_pdf([v['violation'] for v in snapshots], score)
 
322
 
323
- if pdf_file:
324
- record_id, sf_url = upload_to_salesforce(connect_to_salesforce(), pdf_file,
325
- [v['violation'] for v in snapshots], score)
326
- else:
327
- record_id, sf_url = None, None
328
-
329
- snapshots_md = "\n".join(
330
- f"![{v['violation']['violation']}]({v['url']})"
331
- for v in snapshots
332
  )
333
-
 
 
 
 
 
 
334
  yield (
335
- "\n".join(f"- {v['violation']['violation']} (Worker {v['violation']['worker_id']})" for v in snapshots),
336
  f"Safety Score: {score}%",
337
  snapshots_md,
338
- f"Salesforce ID: {record_id or 'N/A'}",
339
- sf_url or pdf_url or "N/A"
340
  )
341
-
342
  except Exception as e:
343
- logger.error(f"Processing failed: {e}")
344
  if 'video_path' in locals() and os.path.exists(video_path):
345
  os.remove(video_path)
346
  yield f"Error: {str(e)}", "", "", "", ""
@@ -350,22 +434,34 @@ def gradio_interface(video):
350
  if not video:
351
  return "Upload a video file", "", "", "", ""
352
 
353
- for update in process_video(open(video, "rb").read()):
354
- yield update
355
-
356
- interface = gr.Interface(
357
- fn=gradio_interface,
358
- inputs=gr.Video(),
359
- outputs=[
360
- gr.Markdown("Detected Violations"),
361
- gr.Textbox("Safety Score"),
362
- gr.Markdown("Evidence Snapshots"),
363
- gr.Textbox("Salesforce Record"),
364
- gr.Textbox("Report URL")
365
- ],
366
- title="AI Safety Compliance Analyzer",
367
- description="Detects PPE and safety violations in worksite videos"
368
- )
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
  if __name__ == "__main__":
371
- interface.launch()
 
16
  import base64
17
  from retrying import retry
18
  from collections import defaultdict
 
19
 
20
  # ========================== # Configuration and Setup # ==========================
21
  os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
 
25
  logger = logging.getLogger(__name__)
26
  warnings.filterwarnings("ignore")
27
 
28
+ # ========================== # Position-Based Tracker (No Face Recognition) # ==========================
29
  class SafetyTracker:
30
  def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
31
  self.track_thresh = track_thresh
 
33
  self.match_thresh = match_thresh
34
  self.frame_rate = frame_rate
35
  self.next_id = 1
36
+
37
+ self.worker_tracks = {} # Active worker tracks
38
+ self.violation_history = defaultdict(dict) # Track violations per worker
39
+ self.position_history = defaultdict(list) # Track positions for all violations
40
 
41
  self.VIOLATION_COOLDOWNS = {
42
  "no_helmet": 30.0,
 
46
  "improper_tool_use": 15.0
47
  }
48
 
49
+ def update(self, detections):
50
  current_time = time.time()
51
  new_violations = []
52
 
 
56
  confidence = det['confidence']
57
 
58
  worker_id = self._match_by_position(bbox, label)
59
+
60
  if worker_id is None:
61
  worker_id = self.next_id
62
  self.next_id += 1
 
83
  return new_violations
84
 
85
  def _match_by_position(self, bbox, label):
86
+ x, y, w, h = bbox
87
+ current_pos = (x, y)
88
+
89
  for worker_id, positions in self.position_history.items():
90
+ if not positions:
91
+ continue
92
+
93
+ last_pos = positions[-1]
94
+ distance = np.sqrt((current_pos[0]-last_pos[0])**2 + (current_pos[1]-last_pos[1])**2)
95
+ if distance < 100: # Within 100 pixels
96
+ return worker_id
97
  return None
98
 
99
  def _is_new_violation(self, worker_id, label, current_time):
100
  if label not in self.violation_history[worker_id]:
101
  return True
102
+
103
+ last_detection = self.violation_history[worker_id][label]
104
+ cooldown = self.VIOLATION_COOLDOWNS.get(label, 10.0)
105
+ return (current_time - last_detection) > cooldown
106
 
107
  def _cleanup_tracks(self, current_time):
108
  inactive_ids = [
109
+ worker_id for worker_id, track in self.worker_tracks.items()
110
  if (current_time - track['last_seen']) > (self.track_buffer / self.frame_rate)
111
  ]
112
+
113
+ for worker_id in inactive_ids:
114
+ self.worker_tracks.pop(worker_id, None)
115
+ self.position_history.pop(worker_id, None)
116
+ if (current_time - max(self.violation_history[worker_id].values(), default=0)) > 300:
117
+ self.violation_history.pop(worker_id, None)
118
 
119
+ # ========================== # Optimized Configuration # ==========================
120
  CONFIG = {
121
  "MODEL_PATH": "yolov8_safety.pt",
122
  "FALLBACK_MODEL": "yolov8n.pt",
 
136
  "improper_tool_use": (255, 255, 0)
137
  },
138
  "DISPLAY_NAMES": {
139
+ "no_helmet": "No Helmet Violation",
140
+ "no_harness": "No Harness Violation",
141
  "unsafe_posture": "Unsafe Posture",
142
+ "unsafe_zone": "Unsafe Zone Entry",
143
  "improper_tool_use": "Improper Tool Use"
144
  },
145
  "SF_CREDENTIALS": {
 
164
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
165
  logger.info(f"Using device: {device}")
166
 
 
167
  def load_model():
168
  try:
169
+ if os.path.isfile(CONFIG["MODEL_PATH"]):
170
+ model = YOLO(CONFIG["MODEL_PATH"]).to(device)
171
+ logger.info(f"Loaded custom model: {CONFIG['MODEL_PATH']}")
172
+ else:
173
+ model = YOLO(CONFIG["FALLBACK_MODEL"]).to(device)
174
+ logger.warning("Using fallback YOLOv8n model")
175
+ return model
176
  except Exception as e:
177
  logger.error(f"Model loading failed: {e}")
178
  raise
179
 
180
  model = load_model()
181
 
182
+ # ========================== # Core Functions # ==========================
183
+ def preprocess_frame(frame):
184
+ frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
185
+ return frame
186
+
187
  def draw_detections(frame, detections):
188
+ result_frame = frame.copy()
189
  for det in detections:
190
+ label = det["violation"]
191
+ confidence = det["confidence"]
192
+ x, y, w, h = det["bbox"]
193
+ worker_id = det["worker_id"]
194
+
195
+ x1, y1 = int(x - w/2), int(y - h/2)
196
+ x2, y2 = int(x + w/2), int(y + h/2)
197
+ color = CONFIG["CLASS_COLORS"][label]
198
+
199
+ cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
200
+ text = f"{CONFIG['DISPLAY_NAMES'][label]} (Worker {worker_id})"
201
+ (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
202
+ cv2.rectangle(result_frame, (x1, y1-th-10), (x1+tw+10, y1), (0,0,0), -1)
203
+ cv2.putText(result_frame, text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)
204
+ cv2.putText(result_frame, f"Conf: {confidence:.2f}", (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 2)
205
+ return result_frame
206
 
207
  def calculate_safety_score(violations):
208
+ penalties = {
209
+ "no_helmet": 25, "no_harness": 30, "unsafe_posture": 20,
210
+ "unsafe_zone": 35, "improper_tool_use": 25
211
+ }
212
+ unique_violations = {v["violation"] for v in violations}
213
+ return max(0, 100 - sum(penalties.get(v, 0) for v in unique_violations))
214
 
215
  def generate_violation_pdf(violations, score):
216
  try:
 
219
 
220
  # Header
221
  c.setFont("Helvetica-Bold", 16)
222
+ c.drawString(1*inch, 10*inch, "Worksite Safety Violation Report")
223
  c.setFont("Helvetica", 12)
224
  c.drawString(1*inch, 9.5*inch, f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}")
225
  c.drawString(1*inch, 9*inch, f"Safety Score: {score}%")
 
227
  # Violations List
228
  y = 8.5*inch
229
  c.setFont("Helvetica-Bold", 14)
230
+ c.drawString(1*inch, y, "Detected Violations:")
231
  y -= 0.3*inch
232
  c.setFont("Helvetica", 10)
233
 
234
  for v in violations:
235
+ text = (f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']} "
236
+ f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})")
 
237
  if y < 1*inch:
238
  c.showPage()
239
  y = 10*inch
240
+ c.drawString(1.2*inch, y, text)
241
+ y -= 0.2*inch
242
 
243
  c.save()
244
  pdf_buffer.seek(0)
245
 
246
  # Save to file
247
+ pdf_filename = f"violation_report_{int(time.time())}.pdf"
248
+ pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
249
  with open(pdf_path, "wb") as f:
250
  f.write(pdf_buffer.getvalue())
251
 
252
+ return pdf_path, f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}", pdf_buffer
253
  except Exception as e:
254
  logger.error(f"PDF generation failed: {e}")
255
  return None, None, None
 
258
  def connect_to_salesforce():
259
  try:
260
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
261
+ logger.info("Salesforce connection established")
262
  return sf
263
  except Exception as e:
264
  logger.error(f"Salesforce connection failed: {e}")
265
  raise
266
 
267
+ def upload_to_salesforce(sf, pdf_file, record_id):
268
+ try:
269
+ encoded = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
270
+ file_data = {
271
+ "Title": f"Safety_Report_{int(time.time())}",
272
+ "PathOnClient": "safety_report.pdf",
273
+ "VersionData": encoded,
274
+ "FirstPublishLocationId": record_id
275
+ }
276
+ result = sf.ContentVersion.create(file_data)
277
+ return f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{result['id']}"
278
+ except Exception as e:
279
+ logger.error(f"Salesforce upload failed: {e}")
280
+ return None
281
+
282
+ def create_salesforce_record(violations, score, pdf_url=None):
283
  try:
284
+ sf = connect_to_salesforce()
285
  violations_text = "\n".join(
286
+ f"Worker {v['worker_id']}: {CONFIG['DISPLAY_NAMES'][v['violation']]} "
287
+ f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f})"
288
  for v in violations
289
  )
290
 
291
+ record_data = {
292
  "Compliance_Score__c": score,
293
  "Violations_Found__c": len(violations),
294
+ "Violations_Details__c": violations_text,
295
+ "Status__c": "Pending",
296
+ "PDF_Report_URL__c": pdf_url or ""
297
+ }
 
 
 
 
 
 
 
298
 
299
+ try:
300
+ record = sf.Safety_Video_Report__c.create(record_data)
301
+ return record["id"], None
302
+ except:
303
+ record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
304
+ return record["id"], "Used Account as fallback"
305
  except Exception as e:
306
+ logger.error(f"Salesforce record creation failed: {e}")
307
+ return None, str(e)
308
 
 
309
  def process_video(video_data):
310
  try:
311
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
312
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
313
  with open(video_path, "wb") as f:
314
  f.write(video_data)
315
+
316
  cap = cv2.VideoCapture(video_path)
317
+ if not cap.isOpened():
318
+ raise ValueError("Failed to open video")
319
+
320
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
321
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
322
  tracker = SafetyTracker(frame_rate=fps)
323
  snapshots = []
324
+ processed_frames = 0
325
+ last_update = time.time()
326
+
327
+ while processed_frames < total_frames:
328
+ batch_frames = []
329
+ for _ in range(CONFIG["BATCH_SIZE"]):
330
+ ret, frame = cap.read()
331
+ if not ret:
332
+ break
333
+ batch_frames.append(preprocess_frame(frame))
334
+ processed_frames += 1
335
+ if CONFIG["FRAME_SKIP"] > 1:
336
+ for _ in range(CONFIG["FRAME_SKIP"]-1):
337
+ cap.grab()
338
+ processed_frames += 1
339
+
340
+ if not batch_frames:
341
  break
342
+
343
+ results = model(batch_frames, device=device, conf=0.1, verbose=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
+ for i, result in enumerate(results):
346
+ detections = []
347
+ for box in result.boxes:
348
+ cls = int(box.cls)
349
+ conf = float(box.conf)
350
+ label = CONFIG["VIOLATION_LABELS"].get(cls)
351
+ if label and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.3):
352
+ detections.append({
353
+ "bbox": box.xywh.cpu().numpy()[0],
354
+ "violation": label,
355
+ "confidence": conf
356
+ })
357
+
358
+ new_violations = tracker.update(detections)
359
 
360
+ for violation in new_violations:
361
+ frame_with_det = draw_detections(batch_frames[i].copy(), [violation])
362
+ timestamp = f"Time: {violation['timestamp']:.2f}s"
363
+ cv2.putText(frame_with_det, timestamp, (10, 30),
364
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)
365
+
366
+ snap_name = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
367
+ snap_path = os.path.join(CONFIG["OUTPUT_DIR"], snap_name)
368
+ cv2.imwrite(snap_path, frame_with_det,
369
+ [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
370
+
371
+ snapshots.append({
372
+ "violation": violation['violation'],
373
+ "worker_id": violation['worker_id'],
374
+ "timestamp": violation['timestamp'],
375
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snap_name}"
376
+ })
377
+
378
+ if time.time() - last_update > 1:
379
+ progress = (processed_frames / total_frames) * 100
380
+ yield f"Processing... {progress:.1f}%", "", "", "", ""
381
+ last_update = time.time()
382
+
383
  cap.release()
384
+ if os.path.exists(video_path):
385
+ os.remove(video_path)
386
+
387
+ violations = [
388
+ {"worker_id": wid, "violation": v, "timestamp": t, "confidence": 0} # Confidence placeholder
389
+ for wid, violations in tracker.violation_history.items()
390
+ for v, t in violations.items()
391
+ ]
392
+
393
+ if not violations:
394
+ yield "No violations found", "Safety Score: 100%", "No snapshots", "N/A", "N/A"
395
  return
396
+
397
+ score = calculate_safety_score(violations)
398
+ pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
399
+ record_id, sf_error = create_salesforce_record(violations, score, pdf_url)
400
 
401
+ if pdf_file and record_id:
402
+ uploaded_url = upload_to_salesforce(connect_to_salesforce(), pdf_file, record_id)
403
+ if uploaded_url:
404
+ pdf_url = uploaded_url
405
+
406
+ violation_table = "| Violation | Worker ID | Time (s) |\n|-----------|-----------|----------|\n"
407
+ violation_table += "\n".join(
408
+ f"| {CONFIG['DISPLAY_NAMES'][v['violation']]} | {v['worker_id']} | {v['timestamp']:.2f} |"
409
+ for v in sorted(violations, key=lambda x: x['timestamp'])
410
  )
411
+
412
+ snapshots_md = "\n\n".join(
413
+ f"### {CONFIG['DISPLAY_NAMES'][s['violation']]} - Worker {s['worker_id']} at {s['timestamp']:.2f}s\n\n"
414
+ f"![Snapshot]({s['snapshot_url']})"
415
+ for s in snapshots
416
+ ) if snapshots else "No snapshots captured"
417
+
418
  yield (
419
+ violation_table,
420
  f"Safety Score: {score}%",
421
  snapshots_md,
422
+ f"Salesforce ID: {record_id or 'N/A'} {sf_error or ''}",
423
+ pdf_url or "N/A"
424
  )
425
+
426
  except Exception as e:
427
+ logger.error(f"Video processing failed: {e}")
428
  if 'video_path' in locals() and os.path.exists(video_path):
429
  os.remove(video_path)
430
  yield f"Error: {str(e)}", "", "", "", ""
 
434
  if not video:
435
  return "Upload a video file", "", "", "", ""
436
 
437
+ try:
438
+ with open(video, "rb") as f:
439
+ video_data = f.read()
440
+
441
+ for output in process_video(video_data):
442
+ yield output
443
+ except Exception as e:
444
+ logger.error(f"Interface error: {e}")
445
+ yield f"Error: {str(e)}", "", "", "", ""
446
+
447
+ with gr.Blocks(title="Safety Compliance Analyzer") as app:
448
+ gr.Markdown("# Worksite Safety Violation Analyzer")
449
+ gr.Markdown("Upload site videos to detect safety violations (No Helmet, No Harness, etc.)")
450
+
451
+ with gr.Row():
452
+ video_input = gr.Video(label="Site Video", sources=["upload"])
453
+ with gr.Column():
454
+ violations_out = gr.Markdown(label="Detected Violations")
455
+ score_out = gr.Textbox(label="Safety Score")
456
+ snapshots_out = gr.Markdown(label="Violation Snapshots")
457
+ salesforce_out = gr.Textbox(label="Salesforce Record")
458
+ pdf_out = gr.Textbox(label="Report PDF URL")
459
+
460
+ video_input.change(
461
+ gradio_interface,
462
+ inputs=video_input,
463
+ outputs=[violations_out, score_out, snapshots_out, salesforce_out, pdf_out]
464
+ )
465
 
466
  if __name__ == "__main__":
467
+ app.launch(server_port=7860, server_name="0.0.0.0")