PrashanthB461 committed on
Commit
188385c
·
verified ·
1 Parent(s): 98276c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -8
app.py CHANGED
@@ -64,8 +64,10 @@ CONFIG = {
64
  "HELMET_CONFIDENCE_THRESHOLD": 0.6,
65
  "WORKER_TRACKING_DURATION": 2.5,
66
  "MAX_PROCESSING_TIME": 30,
67
- "PARALLEL_WORKERS": max(1, cpu_count() - 1),
68
- "CHUNK_SIZE": 15 # Increased chunk size for faster processing
 
 
69
  }
70
 
71
  # Setup logging
@@ -99,6 +101,15 @@ model = load_model()
99
  # ==========================
100
  # Optimized Helper Functions
101
  # ==========================
 
 
 
 
 
 
 
 
 
102
  def draw_detections(frame, detections):
103
  for det in detections:
104
  label = det.get("violation", "Unknown")
@@ -139,7 +150,9 @@ def calculate_iou(box1, box2):
139
 
140
  def process_frame_batch(frame_batch, frame_indices, fps):
141
  batch_results = []
142
- results = model(frame_batch, device=device, conf=0.05, iou=CONFIG["IOU_THRESHOLD"], verbose=False)
 
 
143
 
144
  for idx, (result, frame_idx) in enumerate(zip(results, frame_indices)):
145
  current_time = frame_idx / fps
@@ -169,6 +182,7 @@ def process_frame_batch(frame_batch, frame_indices, fps):
169
 
170
  def generate_violation_pdf(violations, score):
171
  try:
 
172
  pdf_filename = f"violations_{int(time.time())}.pdf"
173
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
174
  pdf_file = BytesIO()
@@ -210,6 +224,7 @@ def generate_violation_pdf(violations, score):
210
  with open(pdf_path, "wb") as f:
211
  f.write(pdf_file.getvalue())
212
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
 
213
  logger.info(f"PDF generated: {public_url}")
214
  return pdf_path, public_url, pdf_file
215
  except Exception as e:
@@ -241,6 +256,7 @@ def process_video(video_data):
241
  logger.info(f"Video saved: {video_path}")
242
 
243
  # Open video file
 
244
  cap = cv2.VideoCapture(video_path)
245
  if not cap.isOpened():
246
  raise ValueError("Could not open video file")
@@ -264,19 +280,32 @@ def process_video(video_data):
264
  yield "Video duration too long. Please upload a shorter video.", "", "", "", ""
265
  return
266
 
267
- # Read all frames upfront
 
 
 
 
 
 
 
 
 
268
  frame_batches = []
269
  frame_indices_batches = []
270
  current_batch = []
271
  current_indices = []
272
  frame_count = 0
 
273
 
274
  while True:
275
  ret, frame = cap.read()
276
  if not ret:
277
  break
278
- current_batch.append(frame)
279
- current_indices.append(frame_count)
 
 
 
280
  frame_count += 1
281
 
282
  if len(current_batch) >= CONFIG["CHUNK_SIZE"]:
@@ -290,6 +319,8 @@ def process_video(video_data):
290
  frame_indices_batches.append(current_indices)
291
 
292
  cap.release()
 
 
293
 
294
  # Process frames in parallel
295
  violations = []
@@ -297,6 +328,7 @@ def process_video(video_data):
297
  snapshots = []
298
  last_progress_time = start_time
299
 
 
300
  with Pool(processes=CONFIG["PARALLEL_WORKERS"]) as pool:
301
  process_func = partial(process_frame_batch, fps=fps)
302
  results = pool.starmap(process_func, zip(frame_batches, frame_indices_batches))
@@ -307,15 +339,18 @@ def process_video(video_data):
307
  all_detections.extend(batch_result)
308
  all_detections.sort(key=lambda x: x[0])
309
 
 
 
310
  # Worker tracking
 
311
  workers = []
312
  for frame_idx, detections in all_detections:
313
  current_time = frame_idx / fps
314
 
315
  # Update progress every second
316
  if time.time() - last_progress_time > 1.0:
317
- progress = (frame_idx / total_frames) * 100
318
- yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
319
  last_progress_time = time.time()
320
 
321
  # Early termination if time limit approached
@@ -354,7 +389,10 @@ def process_video(video_data):
354
 
355
  workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
356
 
 
 
357
  # Process helmet violations
 
358
  for worker_id, detections in helmet_violations.items():
359
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
360
  best_detection = max(detections, key=lambda x: x["confidence"])
@@ -378,6 +416,8 @@ def process_video(video_data):
378
  })
379
  cap.release()
380
 
 
 
381
  os.remove(video_path)
382
  processing_time = time.time() - start_time
383
  logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
@@ -401,6 +441,7 @@ def process_video(video_data):
401
  for s in snapshots
402
  ) if snapshots else "No snapshots captured."
403
 
 
404
  try:
405
  sf = connect_to_salesforce()
406
  record_data = {
@@ -417,6 +458,8 @@ def process_video(video_data):
417
  except Exception as e:
418
  logger.error(f"Salesforce integration failed: {e}")
419
  record_id = "N/A (Salesforce error)"
 
 
420
 
421
  yield (
422
  violation_table,
 
64
  "HELMET_CONFIDENCE_THRESHOLD": 0.6,
65
  "WORKER_TRACKING_DURATION": 2.5,
66
  "MAX_PROCESSING_TIME": 30,
67
+ "PARALLEL_WORKERS": 2, # Reduced for Hugging Face Spaces
68
+ "CHUNK_SIZE": 20, # Increased for faster batch processing
69
+ "FRAME_SAMPLE_RATE": 2, # Process every 2nd frame
70
+ "MAX_FRAME_WIDTH": 640 # Resize frames to this width
71
  }
72
 
73
  # Setup logging
 
101
  # ==========================
102
  # Optimized Helper Functions
103
  # ==========================
104
def resize_frame(frame, max_width):
    """Downscale *frame* so its width does not exceed *max_width*.

    Frames already at or below the limit are returned untouched; wider
    frames are shrunk proportionally (aspect ratio preserved) using
    area interpolation, which suits downscaling.
    """
    height, width = frame.shape[:2]
    if width <= max_width:
        return frame
    ratio = max_width / width
    target_size = (int(width * ratio), int(height * ratio))
    return cv2.resize(frame, target_size, interpolation=cv2.INTER_AREA)
112
+
113
  def draw_detections(frame, detections):
114
  for det in detections:
115
  label = det.get("violation", "Unknown")
 
150
 
151
  def process_frame_batch(frame_batch, frame_indices, fps):
152
  batch_results = []
153
+ start_inference = time.time()
154
+ results = model(frame_batch, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], verbose=False)
155
+ logger.info(f"Inference time for batch of {len(frame_batch)} frames: {time.time() - start_inference:.2f}s")
156
 
157
  for idx, (result, frame_idx) in enumerate(zip(results, frame_indices)):
158
  current_time = frame_idx / fps
 
182
 
183
  def generate_violation_pdf(violations, score):
184
  try:
185
+ start_pdf = time.time()
186
  pdf_filename = f"violations_{int(time.time())}.pdf"
187
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
188
  pdf_file = BytesIO()
 
224
  with open(pdf_path, "wb") as f:
225
  f.write(pdf_file.getvalue())
226
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
227
+ logger.info(f"PDF generation time: {time.time() - start_pdf:.2f}s")
228
  logger.info(f"PDF generated: {public_url}")
229
  return pdf_path, public_url, pdf_file
230
  except Exception as e:
 
256
  logger.info(f"Video saved: {video_path}")
257
 
258
  # Open video file
259
+ start_read = time.time()
260
  cap = cv2.VideoCapture(video_path)
261
  if not cap.isOpened():
262
  raise ValueError("Could not open video file")
 
280
  yield "Video duration too long. Please upload a shorter video.", "", "", "", ""
281
  return
282
 
283
+ # Estimate processing feasibility
284
+ estimated_frames = total_frames // CONFIG["FRAME_SAMPLE_RATE"]
285
+ if estimated_frames * 0.1 > CONFIG["MAX_PROCESSING_TIME"]: # Rough estimate: 0.1s per frame
286
+ logger.warning(f"Too many frames ({estimated_frames}) to process within {CONFIG['MAX_PROCESSING_TIME']}s")
287
+ cap.release()
288
+ os.remove(video_path)
289
+ yield "Video has too many frames to process within 30 seconds.", "", "", "", ""
290
+ return
291
+
292
+ # Read frames with sampling
293
  frame_batches = []
294
  frame_indices_batches = []
295
  current_batch = []
296
  current_indices = []
297
  frame_count = 0
298
+ sampled_frame_count = 0
299
 
300
  while True:
301
  ret, frame = cap.read()
302
  if not ret:
303
  break
304
+ if frame_count % CONFIG["FRAME_SAMPLE_RATE"] == 0:
305
+ frame = resize_frame(frame, CONFIG["MAX_FRAME_WIDTH"])
306
+ current_batch.append(frame)
307
+ current_indices.append(frame_count)
308
+ sampled_frame_count += 1
309
  frame_count += 1
310
 
311
  if len(current_batch) >= CONFIG["CHUNK_SIZE"]:
 
319
  frame_indices_batches.append(current_indices)
320
 
321
  cap.release()
322
+ logger.info(f"Frame reading time: {time.time() - start_read:.2f}s")
323
+ logger.info(f"Total frames: {frame_count}, Sampled frames: {sampled_frame_count}")
324
 
325
  # Process frames in parallel
326
  violations = []
 
328
  snapshots = []
329
  last_progress_time = start_time
330
 
331
+ start_parallel = time.time()
332
  with Pool(processes=CONFIG["PARALLEL_WORKERS"]) as pool:
333
  process_func = partial(process_frame_batch, fps=fps)
334
  results = pool.starmap(process_func, zip(frame_batches, frame_indices_batches))
 
339
  all_detections.extend(batch_result)
340
  all_detections.sort(key=lambda x: x[0])
341
 
342
+ logger.info(f"Parallel processing time: {time.time() - start_parallel:.2f}s")
343
+
344
  # Worker tracking
345
+ start_tracking = time.time()
346
  workers = []
347
  for frame_idx, detections in all_detections:
348
  current_time = frame_idx / fps
349
 
350
  # Update progress every second
351
  if time.time() - last_progress_time > 1.0:
352
+ progress = (frame_idx / frame_count) * 100
353
+ yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{frame_count})", "", "", "", ""
354
  last_progress_time = time.time()
355
 
356
  # Early termination if time limit approached
 
389
 
390
  workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
391
 
392
+ logger.info(f"Worker tracking time: {time.time() - start_tracking:.2f}s")
393
+
394
  # Process helmet violations
395
+ start_snapshot = time.time()
396
  for worker_id, detections in helmet_violations.items():
397
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
398
  best_detection = max(detections, key=lambda x: x["confidence"])
 
416
  })
417
  cap.release()
418
 
419
+ logger.info(f"Snapshot generation time: {time.time() - start_snapshot:.2f}s")
420
+
421
  os.remove(video_path)
422
  processing_time = time.time() - start_time
423
  logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
 
441
  for s in snapshots
442
  ) if snapshots else "No snapshots captured."
443
 
444
+ start_salesforce = time.time()
445
  try:
446
  sf = connect_to_salesforce()
447
  record_data = {
 
458
  except Exception as e:
459
  logger.error(f"Salesforce integration failed: {e}")
460
  record_id = "N/A (Salesforce error)"
461
+
462
+ logger.info(f"Salesforce integration time: {time.time() - start_salesforce:.2f}s")
463
 
464
  yield (
465
  violation_table,