PrashanthB461 committed on
Commit
508af1e
·
verified ·
1 Parent(s): 3edce5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +360 -116
app.py CHANGED
@@ -20,18 +20,14 @@ import uuid
20
  from multiprocessing import Pool, cpu_count
21
  from functools import partial
22
 
23
- # ==========================
24
- # Configuration and Setup
25
- # ==========================
26
  os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
27
  os.makedirs('/tmp/Ultralytics', exist_ok=True)
28
 
29
- logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
30
  logger = logging.getLogger(__name__)
31
 
32
- # ==========================
33
- # ByteTrack Implementation
34
- # ==========================
35
  class BYTETracker:
36
  def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
37
  self.track_thresh = track_thresh
@@ -39,27 +35,102 @@ class BYTETracker:
39
  self.match_thresh = match_thresh
40
  self.frame_rate = frame_rate
41
  self.next_id = 1
42
-
 
43
  def update(self, dets, scores, cls):
44
  tracks = []
 
 
45
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
46
  if score < self.track_thresh:
47
  logger.debug(f"Skipping detection with score {score} below threshold {self.track_thresh}")
48
  continue
49
 
50
  x, y, w, h = det
51
- tracks.append({
52
- 'id': self.next_id,
53
- 'bbox': [x, y, w, h],
54
- 'score': score,
55
- 'cls': cl
56
- })
57
- self.next_id += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  return tracks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # ==========================
61
- # Optimized Configuration
62
- # ==========================
63
  CONFIG = {
64
  "MODEL_PATH": "yolov8_safety.pt",
65
  "FALLBACK_MODEL": "yolov8n.pt",
@@ -72,11 +143,11 @@ CONFIG = {
72
  4: "improper_tool_use"
73
  },
74
  "CLASS_COLORS": {
75
- "no_helmet": (0, 0, 255),
76
- "no_harness": (0, 165, 255),
77
- "unsafe_posture": (0, 255, 0),
78
- "unsafe_zone": (255, 0, 0),
79
- "improper_tool_use": (255, 255, 0)
80
  },
81
  "DISPLAY_NAMES": {
82
  "no_helmet": "No Helmet Violation",
@@ -93,21 +164,23 @@ CONFIG = {
93
  },
94
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Sadio2/resolve/main/static/output/",
95
  "CONFIDENCE_THRESHOLDS": {
96
- "no_helmet": 0.5, # Lowered from 0.75
97
- "no_harness": 0.3, # Lowered from 0.4
98
- "unsafe_posture": 0.3, # Lowered from 0.4
99
- "unsafe_zone": 0.3, # Lowered from 0.4
100
- "improper_tool_use": 0.3 # Lowered from 0.4
101
  },
102
- "MIN_VIOLATION_FRAMES": 1, # Lowered from 3
 
103
  "WORKER_TRACKING_DURATION": 5.0,
104
  "MAX_PROCESSING_TIME": 60,
105
  "FRAME_SKIP": 1,
106
- "BATCH_SIZE": 16, # Reduced from 32 to prevent memory issues
107
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
108
  "TRACK_BUFFER": 30,
109
- "TRACK_THRESH": 0.3, # Lowered from 0.4
110
- "MATCH_THRESH": 0.7 # Lowered from 0.8
 
111
  }
112
 
113
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -124,6 +197,7 @@ def load_model():
124
  if not os.path.isfile(model_path):
125
  logger.info(f"Downloading fallback model: {model_path}")
126
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
 
127
  model = YOLO(model_path).to(device)
128
  logger.info(f"Model classes: {model.names}")
129
  return model
@@ -133,33 +207,47 @@ def load_model():
133
 
134
  model = load_model()
135
 
136
- # ==========================
137
- # Helper Functions
138
- # ==========================
139
  def preprocess_frame(frame):
140
  """Apply basic preprocessing to enhance detection"""
141
  frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20) # Increase contrast
142
  return frame
143
 
144
  def draw_detections(frame, detections):
 
 
 
145
  for det in detections:
146
  label = det.get("violation", "Unknown")
147
  confidence = det.get("confidence", 0.0)
148
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
149
-
150
  x1 = int(x - w/2)
151
  y1 = int(y - h/2)
152
  x2 = int(x + w/2)
153
  y2 = int(y + h/2)
154
 
155
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
156
- cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
157
 
 
 
 
 
158
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
159
- cv2.putText(frame, display_text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
160
- return frame
 
 
 
 
 
 
 
 
 
161
 
162
  def calculate_safety_score(violations):
 
163
  penalties = {
164
  "no_helmet": 25,
165
  "no_harness": 30,
@@ -167,41 +255,78 @@ def calculate_safety_score(violations):
167
  "unsafe_zone": 35,
168
  "improper_tool_use": 25
169
  }
170
- total_penalty = sum(penalties.get(v.get("violation", "Unknown"), 0) for v in violations)
 
 
 
 
 
 
 
171
  score = 100 - total_penalty
172
  return max(score, 0)
173
 
174
  def generate_violation_pdf(violations, score):
 
175
  try:
176
  pdf_filename = f"violations_{int(time.time())}.pdf"
177
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
178
  pdf_file = BytesIO()
179
  c = canvas.Canvas(pdf_file, pagesize=letter)
180
- c.setFont("Helvetica", 12)
181
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
182
- c.setFont("Helvetica", 10)
 
 
 
 
 
 
183
 
184
- y_position = 9.5 * inch
185
- report_data = {
186
- "Compliance Score": f"{score}%",
187
- "Violations Found": len(violations),
188
- "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
 
 
 
 
 
189
  }
190
- for key, value in report_data.items():
 
191
  c.drawString(1 * inch, y_position, f"{key}: {value}")
192
- y_position -= 0.3 * inch
193
 
194
- y_position -= 0.3 * inch
195
- c.drawString(1 * inch, y_position, "Violation Details:")
196
- y_position -= 0.3 * inch
197
  if not violations:
 
198
  c.drawString(1 * inch, y_position, "No violations detected.")
199
  else:
200
- for v in violations:
 
 
 
 
 
 
 
 
 
 
201
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
202
- text = f"{display_name} from {v.get('start_timestamp', 0.0):.2f}s to {v.get('end_timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f}, Worker ID: {v.get('worker_id', 'N/A')})"
 
 
 
 
203
  c.drawString(1 * inch, y_position, text)
 
 
 
 
204
  y_position -= 0.3 * inch
 
205
  if y_position < 1 * inch:
206
  c.showPage()
207
  c.setFont("Helvetica", 10)
@@ -222,6 +347,7 @@ def generate_violation_pdf(violations, score):
222
 
223
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
224
  def connect_to_salesforce():
 
225
  try:
226
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
227
  logger.info("Connected to Salesforce")
@@ -232,10 +358,12 @@ def connect_to_salesforce():
232
  raise
233
 
234
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
 
235
  try:
236
  if not pdf_file:
237
  logger.error("No PDF file provided for upload")
238
  return ""
 
239
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
240
  content_version_data = {
241
  "Title": f"Safety_Violation_Report_{int(time.time())}",
@@ -245,9 +373,11 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
245
  }
246
  content_version = sf.ContentVersion.create(content_version_data)
247
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
 
248
  if not result['records']:
249
  logger.error("Failed to retrieve ContentVersion")
250
  return ""
 
251
  file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
252
  logger.info(f"PDF uploaded to Salesforce: {file_url}")
253
  return file_url
@@ -256,12 +386,24 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
256
  return ""
257
 
258
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
 
259
  try:
260
  sf = connect_to_salesforce()
261
- violations_text = "\n".join(
262
- f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} from {v.get('start_timestamp', 0.0):.2f}s to {v.get('end_timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f}, Worker ID: {v.get('worker_id', 'N/A')})"
263
- for v in violations
264
- ) or "No violations detected."
 
 
 
 
 
 
 
 
 
 
 
265
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
266
 
267
  record_data = {
@@ -271,7 +413,9 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
271
  "Status__c": "Pending",
272
  "PDF_Report_URL__c": pdf_url
273
  }
 
274
  logger.info(f"Creating Salesforce record with data: {record_data}")
 
275
  try:
276
  record = sf.Safety_Video_Report__c.create(record_data)
277
  logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
@@ -279,6 +423,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
279
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
280
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
281
  logger.warning(f"Fell back to Account record: {record['id']}")
 
282
  record_id = record["id"]
283
 
284
  if pdf_file:
@@ -299,6 +444,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
299
  return None, ""
300
 
301
  def process_video(video_data):
 
302
  try:
303
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
304
  logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
@@ -327,12 +473,14 @@ def process_video(video_data):
327
  frame_rate=fps
328
  )
329
 
330
- violation_tracker = {}
 
331
  snapshots = []
332
  start_time = time.time()
333
  frame_skip = CONFIG["FRAME_SKIP"]
 
334
 
335
- while True:
336
  batch_frames = []
337
  batch_indices = []
338
 
@@ -347,6 +495,7 @@ def process_video(video_data):
347
 
348
  frame = preprocess_frame(frame)
349
 
 
350
  for _ in range(frame_skip - 1):
351
  if not cap.grab():
352
  break
@@ -357,11 +506,14 @@ def process_video(video_data):
357
  if not batch_frames:
358
  break
359
 
 
360
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
361
 
362
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
 
363
  current_time = frame_idx / fps
364
 
 
365
  if time.time() - start_time > 1.0:
366
  progress = (frame_idx / total_frames) * 100
367
  yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
@@ -369,6 +521,7 @@ def process_video(video_data):
369
 
370
  boxes = result.boxes
371
  track_inputs = []
 
372
  for box in boxes:
373
  cls = int(box.cls)
374
  conf = float(box.conf)
@@ -377,6 +530,7 @@ def process_video(video_data):
377
  if label is None:
378
  logger.debug(f"Unknown class ID {cls} detected, skipping")
379
  continue
 
380
  if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
381
  logger.debug(f"Detection for {label} with confidence {conf} below threshold {CONFIG['CONFIDENCE_THRESHOLDS'].get(label, 0.25)}")
382
  continue
@@ -388,89 +542,178 @@ def process_video(video_data):
388
  "cls": cls
389
  })
390
 
 
 
 
 
391
  tracked_objects = tracker.update(
392
  np.array([t["bbox"] for t in track_inputs]),
393
  np.array([t["conf"] for t in track_inputs]),
394
  np.array([t["cls"] for t in track_inputs])
395
  )
 
396
  logger.debug(f"Frame {frame_idx}: {len(tracked_objects)} objects tracked")
397
 
398
- for obj, track_input in zip(tracked_objects, track_inputs):
 
399
  worker_id = obj['id']
400
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
401
- bbox = track_input["bbox"]
402
- conf = track_input["conf"]
403
 
404
- detection = {
405
- "frame": frame_idx,
406
- "violation": label,
407
- "confidence": round(conf, 2),
408
- "bounding_box": [round(x, 2) for x in bbox],
409
- "timestamp": current_time,
410
- "worker_id": worker_id
411
- }
412
- logger.debug(f"Detection: {detection}")
413
-
414
- if worker_id not in violation_tracker:
415
- violation_tracker[worker_id] = {}
416
- if label not in violation_tracker[worker_id]:
417
- violation_tracker[worker_id][label] = []
418
- violation_tracker[worker_id][label].append(detection)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
 
420
  cap.release()
421
  if os.path.exists(video_path):
422
  os.remove(video_path)
 
423
  processing_time = time.time() - start_time
424
  logger.info(f"Processing complete in {processing_time:.2f}s")
425
 
 
426
  violations = []
427
- for worker_id, worker_violations in violation_tracker.items():
428
- for label, detections in worker_violations.items():
429
- if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
430
- best_detection = max(detections, key=lambda x: x["confidence"])
431
- best_detection["start_timestamp"] = min(d["timestamp"] for d in detections)
432
- best_detection["end_timestamp"] = max(d["timestamp"] for d in detections)
433
- violations.append(best_detection)
434
-
435
- cap = cv2.VideoCapture(video_path)
436
- if not cap.isOpened():
437
- logger.warning(f"Could not reopen video for snapshot at frame {best_detection['frame']}")
438
- continue
439
- cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
440
- ret, snapshot_frame = cap.read()
441
- if ret:
442
- snapshot_frame = draw_detections(snapshot_frame, [best_detection])
443
- snapshot_filename = f"{label}_{best_detection['frame']}.jpg"
444
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
445
- cv2.imwrite(snapshot_path, snapshot_frame)
446
- snapshots.append({
447
- "violation": label,
448
- "frame": best_detection["frame"],
449
- "snapshot_path": snapshot_path,
450
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
451
- })
452
- cap.release()
453
 
454
  if not violations:
455
  logger.info("No violations detected after processing")
456
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
457
  return
458
 
 
459
  score = calculate_safety_score(violations)
 
 
460
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
 
 
461
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
462
 
463
- violation_table = "| Violation | Time Range (s) | Confidence | Worker ID |\n"
464
- violation_table += "|------------------------|----------------|------------|-----------|\n"
465
- for v in sorted(violations, key=lambda x: x["start_timestamp"]):
 
 
466
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
467
- row = f"| {display_name:<22} | {v.get('start_timestamp', 0.0):.2f}-{v.get('end_timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
 
 
 
 
 
468
  violation_table += row
469
 
470
- snapshots_text = "\n".join(
471
- f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
472
- for s in snapshots
473
- ) if snapshots else "No snapshots captured."
 
 
 
 
 
 
 
 
474
 
475
  yield (
476
  violation_table,
@@ -487,21 +730,22 @@ def process_video(video_data):
487
  yield f"Error processing video: {e}", "", "", "", ""
488
 
489
  def gradio_interface(video_file):
 
490
  if not video_file:
491
  return "No file uploaded.", "", "No file uploaded.", "", ""
 
492
  try:
493
  with open(video_file, "rb") as f:
494
  video_data = f.read()
495
 
496
  for status, score, snapshots_text, record_id, details_url in process_video(video_data):
497
  yield status, score, snapshots_text, record_id, details_url
 
498
  except Exception as e:
499
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
500
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
501
 
502
- # ==========================
503
- # Gradio Interface
504
- # ==========================
505
  interface = gr.Interface(
506
  fn=gradio_interface,
507
  inputs=gr.Video(label="Upload Site Video"),
@@ -513,7 +757,7 @@ interface = gr.Interface(
513
  gr.Textbox(label="Violation Details URL")
514
  ],
515
  title="Worksite Safety Violation Analyzer",
516
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
517
  allow_flagging="never"
518
  )
519
 
 
20
  from multiprocessing import Pool, cpu_count
21
  from functools import partial
22
 
23
+ # ========================== # Configuration and Setup # ==========================
 
 
24
  os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
25
  os.makedirs('/tmp/Ultralytics', exist_ok=True)
26
 
27
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
28
  logger = logging.getLogger(__name__)
29
 
30
+ # ========================== # ByteTrack Implementation # ==========================
 
 
31
class BYTETracker:
    """Lightweight IoU tracker that gives detections persistent IDs across frames.

    Despite the name, this is a greedy single-stage matcher: each detection is
    paired with the same-class track it overlaps most, provided the IoU exceeds
    ``match_thresh``. Boxes are ``[x, y, w, h]`` with ``(x, y)`` the box center
    (see ``_calculate_iou``).
    """

    def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
        """Configure the tracker.

        Args:
            track_thresh: Minimum detection score; lower-scoring detections are ignored.
            track_buffer: Number of frames a track may go unseen before it is dropped
                (converted to seconds using ``frame_rate``).
            match_thresh: IoU a detection must exceed to continue an existing track.
            frame_rate: Frame rate used to convert ``track_buffer`` into seconds.
        """
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh
        self.frame_rate = frame_rate
        self.next_id = 1
        self.tracks = {}  # track_id -> {'bbox', 'score', 'cls', 'last_seen'}

    def update(self, dets, scores, cls):
        """Associate one frame's detections with tracks and return the tracked objects.

        Args:
            dets: Iterable of ``[x, y, w, h]`` boxes.
            scores: Iterable of detection scores, parallel to ``dets``.
            cls: Iterable of class IDs, parallel to ``dets``.

        Returns:
            A list of dicts with keys ``'id'``, ``'bbox'``, ``'score'`` and ``'cls'``,
            one per detection whose score is at least ``track_thresh``.
        """
        tracks = []
        # Monotonic clock: track ageing must not be affected by wall-clock adjustments.
        now = time.monotonic()
        # IDs already handed out in this call, so a track is never shared by two
        # detections of the same frame.
        claimed = set()

        for det, score, cl in zip(dets, scores, cls):
            if score < self.track_thresh:
                logger.debug(
                    "Skipping detection with score %s below threshold %s",
                    score,
                    self.track_thresh,
                )
                continue

            x, y, w, h = det
            bbox = [x, y, w, h]

            # Pick the best-overlapping unclaimed track of the same class rather
            # than the first one that happens to clear the threshold.
            best_id = None
            best_iou = self.match_thresh
            for track_id, track_info in self.tracks.items():
                if track_id in claimed or track_info['cls'] != cl:
                    continue
                iou = self._calculate_iou(bbox, track_info['bbox'])
                if iou > best_iou:
                    best_id, best_iou = track_id, iou

            if best_id is None:
                # No suitable track: start a new one.
                best_id = self.next_id
                self.next_id += 1

            claimed.add(best_id)
            self.tracks[best_id] = {
                'bbox': bbox,
                'score': score,
                'cls': cl,
                'last_seen': now,
            }
            tracks.append({
                'id': best_id,
                'bbox': bbox,
                'score': score,
                'cls': cl,
            })

        # Remove stale tracks. A non-positive frame rate (e.g. a container that
        # reports 0 fps) falls back to the constructor default of 30 instead of
        # dividing by zero.
        frame_rate = self.frame_rate if self.frame_rate and self.frame_rate > 0 else 30
        max_age = self.track_buffer / frame_rate
        stale_ids = [
            track_id
            for track_id, track_info in self.tracks.items()
            if now - track_info['last_seen'] > max_age
        ]
        for track_id in stale_ids:
            del self.tracks[track_id]

        return tracks

    def _calculate_iou(self, box1, box2):
        """Calculate IOU between two boxes in format [x, y, w, h] (center-based).

        Returns 0.0 when the boxes do not overlap or when their union has no area.
        """
        x1, y1, w1, h1 = box1
        x2, y2, w2, h2 = box2

        # Convert to xmin, ymin, xmax, ymax
        xmin1, ymin1 = x1 - w1 / 2, y1 - h1 / 2
        xmax1, ymax1 = x1 + w1 / 2, y1 + h1 / 2
        xmin2, ymin2 = x2 - w2 / 2, y2 - h2 / 2
        xmax2, ymax2 = x2 + w2 / 2, y2 + h2 / 2

        # Intersection rectangle
        x_left = max(xmin1, xmin2)
        y_top = max(ymin1, ymin2)
        x_right = min(xmax1, xmax2)
        y_bottom = min(ymax1, ymax2)

        if x_right < x_left or y_bottom < y_top:
            return 0.0

        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        union_area = w1 * h1 + w2 * h2 - intersection_area

        # Degenerate (zero-area) boxes would otherwise divide by zero.
        if union_area <= 0:
            return 0.0

        return intersection_area / union_area
132
 
133
+ # ========================== # Optimized Configuration # ==========================
 
 
134
  CONFIG = {
135
  "MODEL_PATH": "yolov8_safety.pt",
136
  "FALLBACK_MODEL": "yolov8n.pt",
 
143
  4: "improper_tool_use"
144
  },
145
  "CLASS_COLORS": {
146
+ "no_helmet": (0, 0, 255), # Red in BGR
147
+ "no_harness": (0, 165, 255), # Orange in BGR
148
+ "unsafe_posture": (0, 255, 0), # Green in BGR
149
+ "unsafe_zone": (255, 0, 0), # Blue in BGR
150
+ "improper_tool_use": (255, 255, 0) # Cyan in BGR
151
  },
152
  "DISPLAY_NAMES": {
153
  "no_helmet": "No Helmet Violation",
 
164
  },
165
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Sadio2/resolve/main/static/output/",
166
  "CONFIDENCE_THRESHOLDS": {
167
+ "no_helmet": 0.5,
168
+ "no_harness": 0.3,
169
+ "unsafe_posture": 0.3,
170
+ "unsafe_zone": 0.3,
171
+ "improper_tool_use": 0.3
172
  },
173
+ "MIN_VIOLATION_FRAMES": 1,
174
+ "VIOLATION_COOLDOWN": 5.0, # Time in seconds before same violation type can be detected again for the same worker
175
  "WORKER_TRACKING_DURATION": 5.0,
176
  "MAX_PROCESSING_TIME": 60,
177
  "FRAME_SKIP": 1,
178
+ "BATCH_SIZE": 16,
179
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
180
  "TRACK_BUFFER": 30,
181
+ "TRACK_THRESH": 0.3,
182
+ "MATCH_THRESH": 0.7,
183
+ "SNAPSHOT_QUALITY": 90 # JPEG quality for snapshots (0-100)
184
  }
185
 
186
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
197
  if not os.path.isfile(model_path):
198
  logger.info(f"Downloading fallback model: {model_path}")
199
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
200
+
201
  model = YOLO(model_path).to(device)
202
  logger.info(f"Model classes: {model.names}")
203
  return model
 
207
 
208
  model = load_model()
209
 
210
+ # ========================== # Helper Functions # ==========================
 
 
211
def preprocess_frame(frame):
    """Return a contrast-boosted copy of ``frame`` to help detection.

    The frame is passed through ``cv2.convertScaleAbs`` with a gain of 1.2
    and an offset of 20.
    """
    enhanced = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
    return enhanced
215
 
216
def draw_detections(frame, detections):
    """Draw bounding boxes, violation labels and worker IDs on a copy of ``frame``.

    Args:
        frame: Image array to annotate; it is copied, not modified.
        detections: Iterable of dicts read via ``.get`` for the keys
            ``"violation"``, ``"confidence"``, ``"bounding_box"`` (center-based
            ``[x, y, w, h]``) and ``"worker_id"``.

    Returns:
        The annotated copy of ``frame``.
    """
    result_frame = frame.copy()
    frame_height = result_frame.shape[0]
    font = cv2.FONT_HERSHEY_SIMPLEX

    for det in detections:
        label = det.get("violation", "Unknown")
        confidence = det.get("confidence", 0.0)
        x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])

        # Convert the center-based box to corner coordinates.
        x1 = int(x - w / 2)
        y1 = int(y - h / 2)
        x2 = int(x + w / 2)
        y2 = int(y + h / 2)

        color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))

        # Thick rectangle for better visibility.
        cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)

        # Label on a black banner for readability.
        display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
        text_w, text_h = cv2.getTextSize(display_text, font, 0.6, 2)[0]
        banner_h = text_h + 10

        # The banner normally sits just above the box. If that would leave the
        # image (box touching the top edge), place it just inside the box instead.
        box_top = max(y1, 0)
        label_bottom = box_top if box_top - banner_h >= 0 else box_top + banner_h
        cv2.rectangle(
            result_frame,
            (x1, label_bottom - banner_h),
            (x1 + text_w + 10, label_bottom),
            (0, 0, 0),
            -1,
        )
        cv2.putText(result_frame, display_text, (x1 + 5, label_bottom - 5), font, 0.6, (255, 255, 255), 2)

        # Worker ID below the box, or just inside it when the box touches the
        # bottom edge. Drawn twice (white, then thinner class color) for an outline.
        worker_id = det.get("worker_id", "Unknown")
        worker_text = f"Worker: {worker_id}"
        box_bottom = min(y2, frame_height - 1)
        worker_y = box_bottom + 20 if box_bottom + 20 < frame_height else box_bottom - 10
        cv2.putText(result_frame, worker_text, (x1 + 5, worker_y), font, 0.5, (255, 255, 255), 2)
        cv2.putText(result_frame, worker_text, (x1 + 5, worker_y), font, 0.5, color, 1)

    return result_frame
248
 
249
  def calculate_safety_score(violations):
250
+ """Calculate safety score based on detected violations"""
251
  penalties = {
252
  "no_helmet": 25,
253
  "no_harness": 30,
 
255
  "unsafe_zone": 35,
256
  "improper_tool_use": 25
257
  }
258
+
259
+ # Count unique violation types
260
+ unique_violations = set()
261
+ for v in violations:
262
+ unique_violations.add(v.get("violation", "Unknown"))
263
+
264
+ # Calculate penalty based on unique violation types
265
+ total_penalty = sum(penalties.get(v, 0) for v in unique_violations)
266
  score = 100 - total_penalty
267
  return max(score, 0)
268
 
269
  def generate_violation_pdf(violations, score):
270
+ """Generate a PDF report for the detected violations"""
271
  try:
272
  pdf_filename = f"violations_{int(time.time())}.pdf"
273
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
274
  pdf_file = BytesIO()
275
  c = canvas.Canvas(pdf_file, pagesize=letter)
276
+ c.setFont("Helvetica-Bold", 16)
277
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
278
+
279
+ c.setFont("Helvetica", 12)
280
+ c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
281
+ c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
282
+
283
+ c.setFont("Helvetica-Bold", 14)
284
+ c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
285
 
286
+ y_position = 8.2 * inch
287
+ c.setFont("Helvetica-Bold", 12)
288
+ c.drawString(1 * inch, y_position, "Summary:")
289
+ y_position -= 0.3 * inch
290
+
291
+ c.setFont("Helvetica", 10)
292
+ summary_data = {
293
+ "Total Violations Found": len(violations),
294
+ "Unique Workers with Violations": len(set(v.get("worker_id", "Unknown") for v in violations)),
295
+ "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
296
  }
297
+
298
+ for key, value in summary_data.items():
299
  c.drawString(1 * inch, y_position, f"{key}: {value}")
300
+ y_position -= 0.25 * inch
301
 
 
 
 
302
  if not violations:
303
+ y_position -= 0.3 * inch
304
  c.drawString(1 * inch, y_position, "No violations detected.")
305
  else:
306
+ y_position -= 0.5 * inch
307
+ c.setFont("Helvetica-Bold", 12)
308
+ c.drawString(1 * inch, y_position, "Violation Details:")
309
+ y_position -= 0.3 * inch
310
+
311
+ c.setFont("Helvetica", 10)
312
+ # Sort violations by worker ID and type for better organization
313
+ sorted_violations = sorted(violations, key=lambda v: (v.get("worker_id", "Unknown"), v.get("violation", "Unknown")))
314
+
315
+ for v in sorted_violations:
316
+ worker_id = v.get("worker_id", "Unknown")
317
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
318
+ start_time = v.get('start_timestamp', 0.0)
319
+ end_time = v.get('end_timestamp', 0.0)
320
+ confidence = v.get('confidence', 0.0)
321
+
322
+ text = f"Worker ID: {worker_id} - {display_name}"
323
  c.drawString(1 * inch, y_position, text)
324
+ y_position -= 0.2 * inch
325
+
326
+ details = f" Time: {start_time:.2f}s to {end_time:.2f}s (Confidence: {confidence:.2f})"
327
+ c.drawString(1.2 * inch, y_position, details)
328
  y_position -= 0.3 * inch
329
+
330
  if y_position < 1 * inch:
331
  c.showPage()
332
  c.setFont("Helvetica", 10)
 
347
 
348
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
349
  def connect_to_salesforce():
350
+ """Connect to Salesforce with retry logic"""
351
  try:
352
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
353
  logger.info("Connected to Salesforce")
 
358
  raise
359
 
360
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
361
+ """Upload PDF report to Salesforce"""
362
  try:
363
  if not pdf_file:
364
  logger.error("No PDF file provided for upload")
365
  return ""
366
+
367
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
368
  content_version_data = {
369
  "Title": f"Safety_Violation_Report_{int(time.time())}",
 
373
  }
374
  content_version = sf.ContentVersion.create(content_version_data)
375
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
376
+
377
  if not result['records']:
378
  logger.error("Failed to retrieve ContentVersion")
379
  return ""
380
+
381
  file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
382
  logger.info(f"PDF uploaded to Salesforce: {file_url}")
383
  return file_url
 
386
  return ""
387
 
388
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
389
+ """Push violation report to Salesforce"""
390
  try:
391
  sf = connect_to_salesforce()
392
+
393
+ # Format violations for Salesforce
394
+ violations_text = ""
395
+ for v in violations:
396
+ display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
397
+ worker_id = v.get('worker_id', 'N/A')
398
+ start_time = v.get('start_timestamp', 0.0)
399
+ end_time = v.get('end_timestamp', 0.0)
400
+ confidence = v.get('confidence', 0.0)
401
+
402
+ violations_text += f"Worker {worker_id}: {display_name} ({start_time:.2f}s-{end_time:.2f}s, Conf: {confidence:.2f})\n"
403
+
404
+ if not violations_text:
405
+ violations_text = "No violations detected."
406
+
407
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
408
 
409
  record_data = {
 
413
  "Status__c": "Pending",
414
  "PDF_Report_URL__c": pdf_url
415
  }
416
+
417
  logger.info(f"Creating Salesforce record with data: {record_data}")
418
+
419
  try:
420
  record = sf.Safety_Video_Report__c.create(record_data)
421
  logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
 
423
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
424
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
425
  logger.warning(f"Fell back to Account record: {record['id']}")
426
+
427
  record_id = record["id"]
428
 
429
  if pdf_file:
 
444
  return None, ""
445
 
446
  def process_video(video_data):
447
+ """Process video to detect safety violations"""
448
  try:
449
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
450
  logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
 
473
  frame_rate=fps
474
  )
475
 
476
+ # Track unique violations by worker ID
477
+ unique_violations = {} # {worker_id: {violation_type: {first_detection, last_detection, best_confidence, best_frame, cooldown}}}
478
  snapshots = []
479
  start_time = time.time()
480
  frame_skip = CONFIG["FRAME_SKIP"]
481
+ processed_frames = 0
482
 
483
+ while processed_frames < total_frames:
484
  batch_frames = []
485
  batch_indices = []
486
 
 
495
 
496
  frame = preprocess_frame(frame)
497
 
498
+ # Skip frames if needed
499
  for _ in range(frame_skip - 1):
500
  if not cap.grab():
501
  break
 
506
  if not batch_frames:
507
  break
508
 
509
+ # Process batch with YOLO model
510
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
511
 
512
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
513
+ processed_frames += 1
514
  current_time = frame_idx / fps
515
 
516
+ # Update progress every second
517
  if time.time() - start_time > 1.0:
518
  progress = (frame_idx / total_frames) * 100
519
  yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
 
521
 
522
  boxes = result.boxes
523
  track_inputs = []
524
+
525
  for box in boxes:
526
  cls = int(box.cls)
527
  conf = float(box.conf)
 
530
  if label is None:
531
  logger.debug(f"Unknown class ID {cls} detected, skipping")
532
  continue
533
+
534
  if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
535
  logger.debug(f"Detection for {label} with confidence {conf} below threshold {CONFIG['CONFIDENCE_THRESHOLDS'].get(label, 0.25)}")
536
  continue
 
542
  "cls": cls
543
  })
544
 
545
+ # Skip tracking if no detections
546
+ if not track_inputs:
547
+ continue
548
+
549
  tracked_objects = tracker.update(
550
  np.array([t["bbox"] for t in track_inputs]),
551
  np.array([t["conf"] for t in track_inputs]),
552
  np.array([t["cls"] for t in track_inputs])
553
  )
554
+
555
  logger.debug(f"Frame {frame_idx}: {len(tracked_objects)} objects tracked")
556
 
557
+ # Process tracked objects for violations
558
+ for obj in tracked_objects:
559
  worker_id = obj['id']
560
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
561
+ conf = obj['score']
562
+ bbox = obj['bbox']
563
 
564
+ if label is None:
565
+ continue
566
+
567
+ # Initialize worker if not seen before
568
+ if worker_id not in unique_violations:
569
+ unique_violations[worker_id] = {}
570
+
571
+ # Check if this is a new violation type for this worker or if cooldown has passed
572
+ is_new_violation = False
573
+ if label not in unique_violations[worker_id]:
574
+ # New violation type for this worker
575
+ unique_violations[worker_id][label] = {
576
+ 'first_detection': current_time,
577
+ 'last_detection': current_time,
578
+ 'best_confidence': conf,
579
+ 'best_frame': frame_idx,
580
+ 'best_bbox': bbox,
581
+ 'cooldown': current_time + CONFIG["VIOLATION_COOLDOWN"]
582
+ }
583
+ is_new_violation = True
584
+ elif current_time > unique_violations[worker_id][label]['cooldown']:
585
+ # Cooldown period has passed, treat as a new violation
586
+ unique_violations[worker_id][label] = {
587
+ 'first_detection': current_time,
588
+ 'last_detection': current_time,
589
+ 'best_confidence': conf,
590
+ 'best_frame': frame_idx,
591
+ 'best_bbox': bbox,
592
+ 'cooldown': current_time + CONFIG["VIOLATION_COOLDOWN"]
593
+ }
594
+ is_new_violation = True
595
+ else:
596
+ # Update existing violation
597
+ violation_info = unique_violations[worker_id][label]
598
+ violation_info['last_detection'] = current_time
599
+
600
+ # Update if this is a better detection (higher confidence)
601
+ if conf > violation_info['best_confidence']:
602
+ violation_info['best_confidence'] = conf
603
+ violation_info['best_frame'] = frame_idx
604
+ violation_info['best_bbox'] = bbox
605
+
606
+ # If this is a new violation, capture a snapshot
607
+ if is_new_violation:
608
+ # Create a detection object for the snapshot
609
+ detection = {
610
+ "frame": frame_idx,
611
+ "violation": label,
612
+ "confidence": round(conf, 2),
613
+ "bounding_box": bbox,
614
+ "timestamp": current_time,
615
+ "worker_id": worker_id
616
+ }
617
+
618
+ # Take a snapshot for the new violation
619
+ snapshot_frame = batch_frames[i].copy()
620
+ snapshot_frame = draw_detections(snapshot_frame, [detection])
621
+
622
+ # Add timestamp to the image
623
+ cv2.putText(
624
+ snapshot_frame,
625
+ f"Time: {current_time:.2f}s",
626
+ (10, 30),
627
+ cv2.FONT_HERSHEY_SIMPLEX,
628
+ 0.7,
629
+ (255, 255, 255),
630
+ 2
631
+ )
632
+
633
+ # Save snapshot with high quality
634
+ snapshot_filename = f"{label}_worker{worker_id}_{int(current_time)}_{frame_idx}.jpg"
635
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
636
+
637
+ # Use higher quality for JPEG to ensure better visibility
638
+ cv2.imwrite(
639
+ snapshot_path,
640
+ snapshot_frame,
641
+ [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
642
+ )
643
+
644
+ snapshots.append({
645
+ "violation": label,
646
+ "worker_id": worker_id,
647
+ "frame": frame_idx,
648
+ "timestamp": current_time,
649
+ "snapshot_path": snapshot_path,
650
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
651
+ })
652
+
653
+ logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at frame {frame_idx}")
654
 
655
  cap.release()
656
  if os.path.exists(video_path):
657
  os.remove(video_path)
658
+
659
  processing_time = time.time() - start_time
660
  logger.info(f"Processing complete in {processing_time:.2f}s")
661
 
662
+ # Convert tracked violations to final violation list
663
  violations = []
664
+ for worker_id, worker_violations in unique_violations.items():
665
+ for label, violation_info in worker_violations.items():
666
+ violation = {
667
+ "worker_id": worker_id,
668
+ "violation": label,
669
+ "confidence": violation_info['best_confidence'],
670
+ "start_timestamp": violation_info['first_detection'],
671
+ "end_timestamp": violation_info['last_detection'],
672
+ "frame": violation_info['best_frame'],
673
+ "bounding_box": violation_info['best_bbox']
674
+ }
675
+ violations.append(violation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
 
677
  if not violations:
678
  logger.info("No violations detected after processing")
679
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
680
  return
681
 
682
+ # Calculate safety score
683
  score = calculate_safety_score(violations)
684
+
685
+ # Generate PDF report
686
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
687
+
688
+ # Push report to Salesforce
689
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
690
 
691
+ # Format violations table for display
692
+ violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
693
+ violation_table += "|-----------|-----------|----------|------------|\n"
694
+
695
+ for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("start_timestamp", 0.0))):
696
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
697
+ worker_id = v.get("worker_id", "Unknown")
698
+ start_time = v.get('start_timestamp', 0.0)
699
+ end_time = v.get('end_timestamp', 0.0)
700
+ confidence = v.get('confidence', 0.0)
701
+
702
+ row = f"| {display_name} | {worker_id} | {start_time:.2f}-{end_time:.2f} | {confidence:.2f} |\n"
703
  violation_table += row
704
 
705
+ # Format snapshots for display
706
+ snapshots_text = ""
707
+ for i, s in enumerate(snapshots):
708
+ display_name = CONFIG["DISPLAY_NAMES"].get(s['violation'], "Unknown")
709
+ worker_id = s.get("worker_id", "Unknown")
710
+ timestamp = s.get("timestamp", 0.0)
711
+
712
+ snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
713
+ snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
714
+
715
+ if not snapshots_text:
716
+ snapshots_text = "No snapshots captured."
717
 
718
  yield (
719
  violation_table,
 
730
  yield f"Error processing video: {e}", "", "", "", ""
731
 
732
  def gradio_interface(video_file):
733
+ """Gradio interface for the video processing"""
734
  if not video_file:
735
  return "No file uploaded.", "", "No file uploaded.", "", ""
736
+
737
  try:
738
  with open(video_file, "rb") as f:
739
  video_data = f.read()
740
 
741
  for status, score, snapshots_text, record_id, details_url in process_video(video_data):
742
  yield status, score, snapshots_text, record_id, details_url
743
+
744
  except Exception as e:
745
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
746
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
747
 
748
+ # ========================== # Gradio Interface # ==========================
 
 
749
  interface = gr.Interface(
750
  fn=gradio_interface,
751
  inputs=gr.Video(label="Upload Site Video"),
 
757
  gr.Textbox(label="Violation Details URL")
758
  ],
759
  title="Worksite Safety Violation Analyzer",
760
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
761
  allow_flagging="never"
762
  )
763