PrashanthB461 commited on
Commit
98276c5
·
verified ·
1 Parent(s): d195ce0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -66
app.py CHANGED
@@ -53,19 +53,19 @@ CONFIG = {
53
  },
54
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
55
  "CONFIDENCE_THRESHOLDS": {
56
- "no_helmet": 0.6,
57
- "no_harness": 0.15,
58
- "unsafe_posture": 0.15,
59
- "unsafe_zone": 0.15,
60
- "improper_tool_use": 0.15
61
  },
62
- "IOU_THRESHOLD": 0.4,
63
- "MIN_VIOLATION_FRAMES": 3,
64
- "HELMET_CONFIDENCE_THRESHOLD": 0.65,
65
- "WORKER_TRACKING_DURATION": 3.0,
66
- "MAX_PROCESSING_TIME": 30, # 30 second hard limit
67
- "PARALLEL_WORKERS": max(1, cpu_count() - 1), # Use all but one CPU core
68
- "CHUNK_SIZE": 10 # Frames per parallel batch
69
  }
70
 
71
  # Setup logging
@@ -122,7 +122,6 @@ def calculate_iou(box1, box2):
122
  x1, y1, w1, h1 = box1
123
  x2, y2, w2, h2 = box2
124
 
125
- # Calculate intersection coordinates
126
  x_left = max(x1 - w1/2, x2 - w2/2)
127
  y_top = max(y1 - h1/2, y2 - h2/2)
128
  x_right = min(x1 + w1/2, x2 + w2/2)
@@ -140,7 +139,7 @@ def calculate_iou(box1, box2):
140
 
141
  def process_frame_batch(frame_batch, frame_indices, fps):
142
  batch_results = []
143
- results = model(frame_batch, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], verbose=False)
144
 
145
  for idx, (result, frame_idx) in enumerate(zip(results, frame_indices)):
146
  current_time = frame_idx / fps
@@ -152,7 +151,7 @@ def process_frame_batch(frame_batch, frame_indices, fps):
152
  conf = float(box.conf)
153
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
154
 
155
- if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
156
  continue
157
 
158
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
@@ -234,6 +233,7 @@ def calculate_safety_score(violations):
234
  # ==========================
235
  def process_video(video_data):
236
  try:
 
237
  # Create temp video file
238
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
239
  with open(video_path, "wb") as f:
@@ -256,66 +256,74 @@ def process_video(video_data):
256
 
257
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
258
 
259
- # Prepare for parallel processing
 
 
 
 
 
 
 
 
260
  frame_batches = []
261
  frame_indices_batches = []
262
  current_batch = []
263
  current_indices = []
264
-
265
- # Read all frames upfront for parallel processing
266
- all_frames = []
267
- all_indices = []
268
- for frame_idx in range(total_frames):
269
  ret, frame = cap.read()
270
  if not ret:
271
  break
272
- all_frames.append(frame)
273
- all_indices.append(frame_idx)
274
-
275
- # Organize into batches
276
  if len(current_batch) >= CONFIG["CHUNK_SIZE"]:
277
  frame_batches.append(current_batch)
278
  frame_indices_batches.append(current_indices)
279
  current_batch = []
280
  current_indices = []
281
-
282
- # Add remaining frames
283
  if current_batch:
284
  frame_batches.append(current_batch)
285
  frame_indices_batches.append(current_indices)
286
 
287
  cap.release()
288
-
289
  # Process frames in parallel
290
- workers = []
291
  violations = []
292
  helmet_violations = {}
293
  snapshots = []
294
- start_time = time.time()
295
-
296
- # Use multiprocessing Pool
297
  with Pool(processes=CONFIG["PARALLEL_WORKERS"]) as pool:
298
  process_func = partial(process_frame_batch, fps=fps)
299
  results = pool.starmap(process_func, zip(frame_batches, frame_indices_batches))
300
 
301
- # Flatten results
302
  all_detections = []
303
  for batch_result in results:
304
  all_detections.extend(batch_result)
 
305
 
306
- # Process detections and track workers
307
  workers = []
308
- for frame_idx, detections in sorted(all_detections, key=lambda x: x[0]):
309
  current_time = frame_idx / fps
310
-
311
- # Update progress periodically
312
- if time.time() - start_time > 1.0: # Update every second
313
  progress = (frame_idx / total_frames) * 100
314
  yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
315
- start_time = time.time()
 
 
 
 
 
316
 
317
  for detection in detections:
318
- # Worker tracking
319
  worker_id = None
320
  max_iou = 0
321
  for idx, worker in enumerate(workers):
@@ -337,7 +345,6 @@ def process_video(video_data):
337
 
338
  detection["worker_id"] = worker_id
339
 
340
- # Special handling for helmet violations
341
  if detection["violation"] == "no_helmet":
342
  if worker_id not in helmet_violations:
343
  helmet_violations[worker_id] = []
@@ -345,38 +352,36 @@ def process_video(video_data):
345
  else:
346
  violations.append(detection)
347
 
348
- # Remove workers not seen recently
349
  workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
350
 
351
- # Process helmet violations (require consistent detections)
352
  for worker_id, detections in helmet_violations.items():
353
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
354
- # Find the detection with highest confidence
355
  best_detection = max(detections, key=lambda x: x["confidence"])
356
- violations.append(best_detection)
357
-
358
- # Capture snapshot for this violation
359
- cap = cv2.VideoCapture(video_path)
360
- cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
361
- ret, snapshot_frame = cap.read()
362
- if ret:
363
- snapshot_frame = draw_detections(snapshot_frame, [best_detection])
364
- snapshot_filename = f"no_helmet_{best_detection['frame']}.jpg"
365
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
366
- cv2.imwrite(snapshot_path, snapshot_frame)
367
- snapshots.append({
368
- "violation": "no_helmet",
369
- "frame": best_detection["frame"],
370
- "snapshot_path": snapshot_path,
371
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
372
- })
373
- cap.release()
 
374
 
375
  os.remove(video_path)
376
  processing_time = time.time() - start_time
377
  logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
378
 
379
- # Generate results
380
  if not violations:
381
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
382
  return
@@ -384,7 +389,6 @@ def process_video(video_data):
384
  score = calculate_safety_score(violations)
385
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
386
 
387
- # Generate violation table
388
  violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
389
  violation_table += "|------------------------|---------------|------------|-----------|\n"
390
  for v in sorted(violations, key=lambda x: x["timestamp"]):
@@ -392,13 +396,11 @@ def process_video(video_data):
392
  row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
393
  violation_table += row
394
 
395
- # Generate snapshots text
396
  snapshots_text = "\n".join(
397
  f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
398
  for s in snapshots
399
  ) if snapshots else "No snapshots captured."
400
 
401
- # Push to Salesforce
402
  try:
403
  sf = connect_to_salesforce()
404
  record_data = {
 
53
  },
54
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
55
  "CONFIDENCE_THRESHOLDS": {
56
+ "no_helmet": 0.55,
57
+ "no_harness": 0.1,
58
+ "unsafe_posture": 0.1,
59
+ "unsafe_zone": 0.1,
60
+ "improper_tool_use": 0.1
61
  },
62
+ "IOU_THRESHOLD": 0.45,
63
+ "MIN_VIOLATION_FRAMES": 2,
64
+ "HELMET_CONFIDENCE_THRESHOLD": 0.6,
65
+ "WORKER_TRACKING_DURATION": 2.5,
66
+ "MAX_PROCESSING_TIME": 30,
67
+ "PARALLEL_WORKERS": max(1, cpu_count() - 1),
68
+ "CHUNK_SIZE": 15 # Increased chunk size for faster processing
69
  }
70
 
71
  # Setup logging
 
122
  x1, y1, w1, h1 = box1
123
  x2, y2, w2, h2 = box2
124
 
 
125
  x_left = max(x1 - w1/2, x2 - w2/2)
126
  y_top = max(y1 - h1/2, y2 - h2/2)
127
  x_right = min(x1 + w1/2, x2 + w2/2)
 
139
 
140
  def process_frame_batch(frame_batch, frame_indices, fps):
141
  batch_results = []
142
+ results = model(frame_batch, device=device, conf=0.05, iou=CONFIG["IOU_THRESHOLD"], verbose=False)
143
 
144
  for idx, (result, frame_idx) in enumerate(zip(results, frame_indices)):
145
  current_time = frame_idx / fps
 
151
  conf = float(box.conf)
152
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
153
 
154
+ if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.2):
155
  continue
156
 
157
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
 
233
  # ==========================
234
  def process_video(video_data):
235
  try:
236
+ start_time = time.time()
237
  # Create temp video file
238
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
239
  with open(video_path, "wb") as f:
 
256
 
257
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
258
 
259
+ # Check if processing will exceed time limit
260
+ if duration > CONFIG["MAX_PROCESSING_TIME"]:
261
+ logger.warning(f"Video duration {duration:.2f}s exceeds max processing time {CONFIG['MAX_PROCESSING_TIME']}s")
262
+ cap.release()
263
+ os.remove(video_path)
264
+ yield "Video duration too long. Please upload a shorter video.", "", "", "", ""
265
+ return
266
+
267
+ # Read all frames upfront
268
  frame_batches = []
269
  frame_indices_batches = []
270
  current_batch = []
271
  current_indices = []
272
+ frame_count = 0
273
+
274
+ while True:
 
 
275
  ret, frame = cap.read()
276
  if not ret:
277
  break
278
+ current_batch.append(frame)
279
+ current_indices.append(frame_count)
280
+ frame_count += 1
281
+
282
  if len(current_batch) >= CONFIG["CHUNK_SIZE"]:
283
  frame_batches.append(current_batch)
284
  frame_indices_batches.append(current_indices)
285
  current_batch = []
286
  current_indices = []
287
+
 
288
  if current_batch:
289
  frame_batches.append(current_batch)
290
  frame_indices_batches.append(current_indices)
291
 
292
  cap.release()
293
+
294
  # Process frames in parallel
 
295
  violations = []
296
  helmet_violations = {}
297
  snapshots = []
298
+ last_progress_time = start_time
299
+
 
300
  with Pool(processes=CONFIG["PARALLEL_WORKERS"]) as pool:
301
  process_func = partial(process_frame_batch, fps=fps)
302
  results = pool.starmap(process_func, zip(frame_batches, frame_indices_batches))
303
 
304
+ # Flatten and sort results
305
  all_detections = []
306
  for batch_result in results:
307
  all_detections.extend(batch_result)
308
+ all_detections.sort(key=lambda x: x[0])
309
 
310
+ # Worker tracking
311
  workers = []
312
+ for frame_idx, detections in all_detections:
313
  current_time = frame_idx / fps
314
+
315
+ # Update progress every second
316
+ if time.time() - last_progress_time > 1.0:
317
  progress = (frame_idx / total_frames) * 100
318
  yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
319
+ last_progress_time = time.time()
320
+
321
+ # Early termination if time limit approached
322
+ if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"] - 2:
323
+ logger.warning("Approaching max processing time, terminating early")
324
+ break
325
 
326
  for detection in detections:
 
327
  worker_id = None
328
  max_iou = 0
329
  for idx, worker in enumerate(workers):
 
345
 
346
  detection["worker_id"] = worker_id
347
 
 
348
  if detection["violation"] == "no_helmet":
349
  if worker_id not in helmet_violations:
350
  helmet_violations[worker_id] = []
 
352
  else:
353
  violations.append(detection)
354
 
 
355
  workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
356
 
357
+ # Process helmet violations
358
  for worker_id, detections in helmet_violations.items():
359
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
 
360
  best_detection = max(detections, key=lambda x: x["confidence"])
361
+ if best_detection["confidence"] >= CONFIG["HELMET_CONFIDENCE_THRESHOLD"]:
362
+ violations.append(best_detection)
363
+
364
+ # Capture snapshot
365
+ cap = cv2.VideoCapture(video_path)
366
+ cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
367
+ ret, snapshot_frame = cap.read()
368
+ if ret:
369
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
370
+ snapshot_filename = f"no_helmet_{best_detection['frame']}.jpg"
371
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
372
+ cv2.imwrite(snapshot_path, snapshot_frame)
373
+ snapshots.append({
374
+ "violation": "no_helmet",
375
+ "frame": best_detection["frame"],
376
+ "snapshot_path": snapshot_path,
377
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
378
+ })
379
+ cap.release()
380
 
381
  os.remove(video_path)
382
  processing_time = time.time() - start_time
383
  logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
384
 
 
385
  if not violations:
386
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
387
  return
 
389
  score = calculate_safety_score(violations)
390
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
391
 
 
392
  violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
393
  violation_table += "|------------------------|---------------|------------|-----------|\n"
394
  for v in sorted(violations, key=lambda x: x["timestamp"]):
 
396
  row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
397
  violation_table += row
398
 
 
399
  snapshots_text = "\n".join(
400
  f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
401
  for s in snapshots
402
  ) if snapshots else "No snapshots captured."
403
 
 
404
  try:
405
  sf = connect_to_salesforce()
406
  record_data = {