PrashanthB461 committed on
Commit
67959fd
·
verified ·
1 Parent(s): 1df6374

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -72
app.py CHANGED
@@ -14,8 +14,6 @@ import base64
14
  import logging
15
  from retrying import retry
16
  import uuid
17
- from multiprocessing import Pool, cpu_count
18
- from functools import partial
19
 
20
  # ==========================
21
  # Optimized Configuration
@@ -53,17 +51,17 @@ CONFIG = {
53
  },
54
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
55
  "CONFIDENCE_THRESHOLDS": {
56
- "no_helmet": 0.6, # Higher threshold to reduce false positives
57
  "no_harness": 0.4,
58
  "unsafe_posture": 0.4,
59
  "unsafe_zone": 0.4,
60
  "improper_tool_use": 0.4
61
  },
62
- "MIN_VIOLATION_FRAMES": 3, # Require 3+ detections to confirm violation
63
- "WORKER_TRACKING_DURATION": 3.0, # Track workers for 3 seconds
64
- "MAX_PROCESSING_TIME": 30, # 30-second limit
65
- "PARALLEL_WORKERS": max(1, cpu_count() - 1), # Use all CPU cores
66
- "CHUNK_SIZE": 8 # Frames per batch
67
  }
68
 
69
  # Setup logging
@@ -119,7 +117,6 @@ def calculate_iou(box1, box2):
119
  x1, y1, w1, h1 = box1
120
  x2, y2, w2, h2 = box2
121
 
122
- # Calculate intersection area
123
  x_left = max(x1 - w1/2, x2 - w2/2)
124
  y_top = max(y1 - h1/2, y2 - h2/2)
125
  x_right = min(x1 + w1/2, x2 + w2/2)
@@ -135,36 +132,6 @@ def calculate_iou(box1, box2):
135
 
136
  return intersection_area / union_area
137
 
138
- def process_frame_batch(frame_batch, frame_indices, fps):
139
- batch_results = []
140
- results = model(frame_batch, device=device, conf=0.1, verbose=False)
141
-
142
- for idx, (result, frame_idx) in enumerate(zip(results, frame_indices)):
143
- current_time = frame_idx / fps
144
- detections = []
145
-
146
- boxes = result.boxes
147
- for box in boxes:
148
- cls = int(box.cls)
149
- conf = float(box.conf)
150
- label = CONFIG["VIOLATION_LABELS"].get(cls, None)
151
-
152
- if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
153
- continue
154
-
155
- bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
156
- detections.append({
157
- "frame": frame_idx,
158
- "violation": label,
159
- "confidence": round(conf, 2),
160
- "bounding_box": bbox,
161
- "timestamp": current_time
162
- })
163
-
164
- batch_results.append((frame_idx, detections))
165
-
166
- return batch_results
167
-
168
  def generate_violation_pdf(violations, score):
169
  try:
170
  pdf_filename = f"violations_{int(time.time())}.pdf"
@@ -227,7 +194,7 @@ def calculate_safety_score(violations):
227
  return max(score, 0)
228
 
229
  # ==========================
230
- # Optimized Video Processing
231
  # ==========================
232
  def process_video(video_data):
233
  try:
@@ -246,62 +213,97 @@ def process_video(video_data):
246
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
247
  fps = cap.get(cv2.CAP_PROP_FPS)
248
  if fps <= 0:
249
- fps = 30 # Default assumption if FPS not available
250
  duration = total_frames / fps
251
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
252
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
253
 
254
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
255
 
256
- # Read all frames upfront
257
- all_frames = []
258
- all_indices = []
259
- for frame_idx in range(total_frames):
260
- ret, frame = cap.read()
261
- if not ret:
262
- break
263
- all_frames.append(frame)
264
- all_indices.append(frame_idx)
265
- cap.release()
266
-
267
- # Process frames in parallel batches
268
  workers = []
269
  violations = []
270
  helmet_violations = {}
271
  snapshots = []
272
  start_time = time.time()
 
 
273
 
274
- # Split frames into batches
275
- frame_batches = [all_frames[i:i + CONFIG["CHUNK_SIZE"]] for i in range(0, len(all_frames), CONFIG["CHUNK_SIZE"])]
276
- frame_indices_batches = [all_indices[i:i + CONFIG["CHUNK_SIZE"]] for i in range(0, len(all_indices), CONFIG["CHUNK_SIZE"])]
277
-
278
- # Process batches in parallel
279
- with Pool(processes=CONFIG["PARALLEL_WORKERS"]) as pool:
280
- process_func = partial(process_frame_batch, fps=fps)
281
- results = pool.starmap(process_func, zip(frame_batches, frame_indices_batches))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
- # Flatten results and track workers
284
- for batch_result in results:
285
- for frame_idx, detections in batch_result:
 
 
286
  current_time = frame_idx / fps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
- for detection in detections:
289
  # Worker tracking
290
  worker_id = None
291
  max_iou = 0
292
  for idx, worker in enumerate(workers):
293
- iou = calculate_iou(detection["bounding_box"], worker["bbox"])
294
  if iou > max_iou and iou > 0.4: # IOU threshold
295
  max_iou = iou
296
  worker_id = worker["id"]
297
- workers[idx]["bbox"] = detection["bounding_box"]
298
  workers[idx]["last_seen"] = current_time
299
 
300
  if worker_id is None:
301
  worker_id = len(workers) + 1
302
  workers.append({
303
  "id": worker_id,
304
- "bbox": detection["bounding_box"],
305
  "first_seen": current_time,
306
  "last_seen": current_time
307
  })
@@ -319,6 +321,11 @@ def process_video(video_data):
319
  # Remove inactive workers
320
  workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
321
 
 
 
 
 
 
322
  # Confirm helmet violations (require multiple detections)
323
  for worker_id, detections in helmet_violations.items():
324
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
@@ -342,10 +349,6 @@ def process_video(video_data):
342
  })
343
  cap.release()
344
 
345
- os.remove(video_path)
346
- processing_time = time.time() - start_time
347
- logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
348
-
349
  # Generate results
350
  if not violations:
351
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
 
14
  import logging
15
  from retrying import retry
16
  import uuid
 
 
17
 
18
  # ==========================
19
  # Optimized Configuration
 
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
  "CONFIDENCE_THRESHOLDS": {
54
+ "no_helmet": 0.6,
55
  "no_harness": 0.4,
56
  "unsafe_posture": 0.4,
57
  "unsafe_zone": 0.4,
58
  "improper_tool_use": 0.4
59
  },
60
+ "MIN_VIOLATION_FRAMES": 3,
61
+ "WORKER_TRACKING_DURATION": 3.0,
62
+ "MAX_PROCESSING_TIME": 60, # 1 minute limit
63
+ "FRAME_SKIP": 2, # Process every 2nd frame for speed
64
+ "BATCH_SIZE": 16 # Frames per batch
65
  }
66
 
67
  # Setup logging
 
117
  x1, y1, w1, h1 = box1
118
  x2, y2, w2, h2 = box2
119
 
 
120
  x_left = max(x1 - w1/2, x2 - w2/2)
121
  y_top = max(y1 - h1/2, y2 - h2/2)
122
  x_right = min(x1 + w1/2, x2 + w2/2)
 
132
 
133
  return intersection_area / union_area
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def generate_violation_pdf(violations, score):
136
  try:
137
  pdf_filename = f"violations_{int(time.time())}.pdf"
 
194
  return max(score, 0)
195
 
196
  # ==========================
197
+ # Fast Video Processing
198
  # ==========================
199
  def process_video(video_data):
200
  try:
 
213
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
214
  fps = cap.get(cv2.CAP_PROP_FPS)
215
  if fps <= 0:
216
+ fps = 30
217
  duration = total_frames / fps
218
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
219
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
220
 
221
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
222
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  workers = []
224
  violations = []
225
  helmet_violations = {}
226
  snapshots = []
227
  start_time = time.time()
228
+ processed_frames = 0
229
+ frame_skip = CONFIG["FRAME_SKIP"]
230
 
231
+ # Process frames in batches
232
+ while True:
233
+ batch_frames = []
234
+ batch_indices = []
235
+
236
+ # Collect frames for this batch
237
+ for _ in range(CONFIG["BATCH_SIZE"]):
238
+ frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
239
+ if frame_idx >= total_frames:
240
+ break
241
+
242
+ ret, frame = cap.read()
243
+ if not ret:
244
+ break
245
+
246
+ # Skip frames if needed
247
+ for _ in range(frame_skip - 1):
248
+ if not cap.grab():
249
+ break
250
+
251
+ batch_frames.append(frame)
252
+ batch_indices.append(frame_idx)
253
+ processed_frames += 1
254
+
255
+ # Break if no more frames
256
+ if not batch_frames:
257
+ break
258
 
259
+ # Run batch detection
260
+ results = model(batch_frames, device=device, conf=0.1, verbose=False)
261
+
262
+ # Process results for each frame in batch
263
+ for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
264
  current_time = frame_idx / fps
265
+
266
+ # Update progress periodically
267
+ if time.time() - start_time > 1.0: # Update every second
268
+ progress = (frame_idx / total_frames) * 100
269
+ yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
270
+ start_time = time.time()
271
+
272
+ # Process detections in this frame
273
+ boxes = result.boxes
274
+ for box in boxes:
275
+ cls = int(box.cls)
276
+ conf = float(box.conf)
277
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
278
+
279
+ if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
280
+ continue
281
+
282
+ bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
283
+ detection = {
284
+ "frame": frame_idx,
285
+ "violation": label,
286
+ "confidence": round(conf, 2),
287
+ "bounding_box": bbox,
288
+ "timestamp": current_time
289
+ }
290
 
 
291
  # Worker tracking
292
  worker_id = None
293
  max_iou = 0
294
  for idx, worker in enumerate(workers):
295
+ iou = calculate_iou(bbox, worker["bbox"])
296
  if iou > max_iou and iou > 0.4: # IOU threshold
297
  max_iou = iou
298
  worker_id = worker["id"]
299
+ workers[idx]["bbox"] = bbox
300
  workers[idx]["last_seen"] = current_time
301
 
302
  if worker_id is None:
303
  worker_id = len(workers) + 1
304
  workers.append({
305
  "id": worker_id,
306
+ "bbox": bbox,
307
  "first_seen": current_time,
308
  "last_seen": current_time
309
  })
 
321
  # Remove inactive workers
322
  workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
323
 
324
+ cap.release()
325
+ os.remove(video_path)
326
+ processing_time = time.time() - start_time
327
+ logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
328
+
329
  # Confirm helmet violations (require multiple detections)
330
  for worker_id, detections in helmet_violations.items():
331
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
 
349
  })
350
  cap.release()
351
 
 
 
 
 
352
  # Generate results
353
  if not violations:
354
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"