PrashanthB461 commited on
Commit
40b5e84
·
verified ·
1 Parent(s): 6827f40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -40
app.py CHANGED
@@ -13,6 +13,7 @@ from io import BytesIO
13
  import base64
14
  import logging
15
  from retrying import retry
 
16
 
17
  # ==========================
18
  # Enhanced Configuration
@@ -49,9 +50,21 @@ CONFIG = {
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
- "FRAME_SKIP": 5, # Reduced for better detection
 
 
 
 
 
 
53
  "MAX_PROCESSING_TIME": 60,
54
- "CONFIDENCE_THRESHOLD": 0.25, # Lower threshold for all violations
 
 
 
 
 
 
55
  "IOU_THRESHOLD": 0.4,
56
  "MIN_VIOLATION_FRAMES": 3 # Minimum consecutive frames to confirm violation
57
  }
@@ -94,7 +107,6 @@ def draw_detections(frame, detections):
94
  confidence = det["confidence"]
95
  x, y, w, h = det["bounding_box"]
96
 
97
- # Convert from center coordinates to corner coordinates
98
  x1 = int(x - w/2)
99
  y1 = int(y - h/2)
100
  x2 = int(x + w/2)
@@ -113,19 +125,12 @@ def calculate_iou(box1, box2):
113
  x1, y1, w1, h1 = box1
114
  x2, y2, w2, h2 = box2
115
 
116
- # Convert to top-left and bottom-right coordinates
117
  x1_min, y1_min = x1 - w1/2, y1 - h1/2
118
  x1_max, y1_max = x1 + w1/2, y1 + h1/2
119
  x2_min, y2_min = x2 - w2/2, y2 - h2/2
120
  x2_max, y2_max = x2 + w2/2, y2 + h2/2
121
 
122
- # Calculate intersection
123
- x_min = max(x1_min, x2_min)
124
- y_min = max(y1_min, y2_min)
125
- x_max = min(x1_max, x2_max)
126
- y_max = min(y1_max, y2_max)
127
-
128
- intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
129
  area1 = w1 * h1
130
  area2 = w2 * h2
131
  union = area1 + area2 - intersection
@@ -133,7 +138,7 @@ def calculate_iou(box1, box2):
133
  return intersection / union if union > 0 else 0
134
 
135
  # ==========================
136
- # Salesforce Integration (unchanged)
137
  # ==========================
138
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
139
  def connect_to_salesforce():
@@ -270,13 +275,7 @@ def calculate_safety_score(violations):
270
  "unsafe_zone": 35,
271
  "improper_tool_use": 25
272
  }
273
- # Count unique violations per worker
274
- unique_violations = set()
275
- for v in violations:
276
- key = (v["worker_id"], v["violation"])
277
- unique_violations.add(key)
278
-
279
- total_penalty = sum(penalties.get(violation, 0) for _, violation in unique_violations)
280
  score = 100 - total_penalty
281
  return max(score, 0)
282
 
@@ -302,10 +301,11 @@ def process_video(video_data):
302
  if fps <= 0:
303
  fps = 30 # Default assumption if FPS cannot be determined
304
 
305
- # Structure to track workers and their violations
306
  workers = []
307
  violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
 
308
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
 
309
 
310
  logger.info(f"Processing video with FPS: {fps}")
311
  logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
@@ -315,18 +315,18 @@ def process_video(video_data):
315
  if not ret:
316
  break
317
 
318
- if frame_count % CONFIG["FRAME_SKIP"] != 0:
319
- frame_count += 1
320
- continue
321
-
322
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
323
  logger.info("Processing time limit reached")
324
  break
325
 
326
  current_time = frame_count / fps
327
-
 
 
 
 
328
  # Run detection on this frame
329
- results = model(frame, device=device)
330
 
331
  current_detections = []
332
  for result in results:
@@ -339,14 +339,14 @@ def process_video(video_data):
339
  if label is None:
340
  continue
341
 
342
- if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
343
  continue
344
 
345
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
346
 
347
  current_detections.append({
348
  "frame": frame_count,
349
- "violation": label,
350
  "confidence": round(conf, 2),
351
  "bounding_box": bbox,
352
  "timestamp": current_time
@@ -354,7 +354,21 @@ def process_video(video_data):
354
 
355
  # Process detections and associate with workers
356
  for detection in current_detections:
357
- # Find matching worker
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  matched_worker = None
359
  max_iou = 0
360
 
@@ -365,12 +379,10 @@ def process_video(video_data):
365
  matched_worker = worker
366
 
367
  if matched_worker:
368
- # Update worker's position
369
  matched_worker["bbox"] = detection["bounding_box"]
370
  matched_worker["last_seen"] = current_time
371
  worker_id = matched_worker["id"]
372
  else:
373
- # New worker
374
  worker_id = len(workers) + 1
375
  workers.append({
376
  "id": worker_id,
@@ -379,9 +391,19 @@ def process_video(video_data):
379
  "last_seen": current_time
380
  })
381
 
382
- # Add to violation history
 
 
 
383
  detection["worker_id"] = worker_id
384
- violation_history[detection["violation"]].append(detection)
 
 
 
 
 
 
 
385
 
386
  frame_count += 1
387
 
@@ -393,30 +415,36 @@ def process_video(video_data):
393
  if not detections:
394
  continue
395
 
396
- # Group by worker
397
  worker_violations = {}
398
  for det in detections:
399
  if det["worker_id"] not in worker_violations:
400
  worker_violations[det["worker_id"]] = []
401
  worker_violations[det["worker_id"]].append(det)
402
 
403
- # Check each worker's violations for persistence
404
  for worker_id, worker_dets in worker_violations.items():
405
  if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
406
- # Take the highest confidence detection
 
 
 
 
 
 
 
407
  best_detection = max(worker_dets, key=lambda x: x["confidence"])
408
  violations.append(best_detection)
409
 
410
- # Capture snapshot if not already taken
 
 
 
411
  if not snapshot_taken[violation_type]:
412
- # Get the frame for this violation
413
  cap = cv2.VideoCapture(video_path)
414
  cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
415
  ret, snapshot_frame = cap.read()
416
  cap.release()
417
 
418
  if ret:
419
- # Draw detections on snapshot
420
  snapshot_frame = draw_detections(snapshot_frame, [best_detection])
421
 
422
  snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
@@ -430,7 +458,6 @@ def process_video(video_data):
430
  })
431
  snapshot_taken[violation_type] = True
432
 
433
- # Final processing
434
  if not violations:
435
  logger.info("No persistent violations detected")
436
  return {
 
13
  import base64
14
  import logging
15
  from retrying import retry
16
+ import uuid
17
 
18
  # ==========================
19
  # Enhanced Configuration
 
50
  "domain": "login"
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
+ "FRAME_SKIP": {
54
+ "no_helmet": 3, # Lower skip for frequent violations
55
+ "no_harness": 2,
56
+ "unsafe_posture": 2,
57
+ "unsafe_zone": 2,
58
+ "improper_tool_use": 2
59
+ },
60
  "MAX_PROCESSING_TIME": 60,
61
+ "CONFIDENCE_THRESHOLDS": {
62
+ "no_helmet": 0.3, # Slightly higher to reduce false positives
63
+ "no_harness": 0.2,
64
+ "unsafe_posture": 0.2,
65
+ "unsafe_zone": 0.2,
66
+ "improper_tool_use": 0.2
67
+ },
68
  "IOU_THRESHOLD": 0.4,
69
  "MIN_VIOLATION_FRAMES": 3 # Minimum consecutive frames to confirm violation
70
  }
 
107
  confidence = det["confidence"]
108
  x, y, w, h = det["bounding_box"]
109
 
 
110
  x1 = int(x - w/2)
111
  y1 = int(y - h/2)
112
  x2 = int(x + w/2)
 
125
  x1, y1, w1, h1 = box1
126
  x2, y2, w2, h2 = box2
127
 
 
128
  x1_min, y1_min = x1 - w1/2, y1 - h1/2
129
  x1_max, y1_max = x1 + w1/2, y1 + h1/2
130
  x2_min, y2_min = x2 - w2/2, y2 - h2/2
131
  x2_max, y2_max = x2 + w2/2, y2 + h2/2
132
 
133
+ intersection = max(0, x1_max - x1_min) * max(0, y1_max - y1_min)
 
 
 
 
 
 
134
  area1 = w1 * h1
135
  area2 = w2 * h2
136
  union = area1 + area2 - intersection
 
138
  return intersection / union if union > 0 else 0
139
 
140
  # ==========================
141
+ # Salesforce Integration
142
  # ==========================
143
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
144
  def connect_to_salesforce():
 
275
  "unsafe_zone": 35,
276
  "improper_tool_use": 25
277
  }
278
+ total_penalty = sum(penalties.get(v["violation"], 0) for v in violations)
 
 
 
 
 
 
279
  score = 100 - total_penalty
280
  return max(score, 0)
281
 
 
301
  if fps <= 0:
302
  fps = 30 # Default assumption if FPS cannot be determined
303
 
 
304
  workers = []
305
  violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
306
+ confirmed_violations = {} # Track confirmed violations per worker
307
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
308
+ helmet_compliance = {} # Track workers with helmets
309
 
310
  logger.info(f"Processing video with FPS: {fps}")
311
  logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
 
315
  if not ret:
316
  break
317
 
 
 
 
 
318
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
319
  logger.info("Processing time limit reached")
320
  break
321
 
322
  current_time = frame_count / fps
323
+ min_frame_skip = min(CONFIG["FRAME_SKIP"].values())
324
+ if frame_count % min_frame_skip != 0:
325
+ frame_count += 1
326
+ continue
327
+
328
  # Run detection on this frame
329
+ results = model(frame, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], agnostic_nms=True)
330
 
331
  current_detections = []
332
  for result in results:
 
339
  if label is None:
340
  continue
341
 
342
+ if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
343
  continue
344
 
345
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
346
 
347
  current_detections.append({
348
  "frame": frame_count,
349
+ "vi臉olation": label,
350
  "confidence": round(conf, 2),
351
  "bounding_box": bbox,
352
  "timestamp": current_time
 
354
 
355
  # Process detections and associate with workers
356
  for detection in current_detections:
357
+ violation_type = detection["violation"]
358
+ # Skip No Helmet detection if worker is compliant
359
+ if violation_type == "no_helmet":
360
+ matched_worker = None
361
+ max_iou = 0
362
+ for worker in workers:
363
+ iou = calculate_iou(detection["bounding_box"], worker["bbox"])
364
+ if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
365
+ max_iou = iou
366
+ matched_worker = worker
367
+
368
+ if matched_worker and matched_worker["id"] in helmet_compliance:
369
+ continue # Skip if worker is known to wear a helmet
370
+
371
+ # Find or create worker
372
  matched_worker = None
373
  max_iou = 0
374
 
 
379
  matched_worker = worker
380
 
381
  if matched_worker:
 
382
  matched_worker["bbox"] = detection["bounding_box"]
383
  matched_worker["last_seen"] = current_time
384
  worker_id = matched_worker["id"]
385
  else:
 
386
  worker_id = len(workers) + 1
387
  workers.append({
388
  "id": worker_id,
 
391
  "last_seen": current_time
392
  })
393
 
394
+ # Skip if this violation type is already confirmed for this worker
395
+ if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
396
+ continue
397
+
398
  detection["worker_id"] = worker_id
399
+ violation_history[violation_type].append(detection)
400
+
401
+ # Update helmet compliance (simulate by checking if No Helmet confidence is low)
402
+ if violation_type == "no_helmet" and conf < 0.5:
403
+ helmet_compliance[worker_id] = True
404
+
405
+ # Clean up old workers
406
+ workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
407
 
408
  frame_count += 1
409
 
 
415
  if not detections:
416
  continue
417
 
 
418
  worker_violations = {}
419
  for det in detections:
420
  if det["worker_id"] not in worker_violations:
421
  worker_violations[det["worker_id"]] = []
422
  worker_violations[det["worker_id"]].append(det)
423
 
 
424
  for worker_id, worker_dets in worker_violations.items():
425
  if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
426
+ # Skip if already confirmed
427
+ if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
428
+ continue
429
+
430
+ # Skip No Helmet if worker is compliant
431
+ if violation_type == "no_helmet" and worker_id in helmet_compliance:
432
+ continue
433
+
434
  best_detection = max(worker_dets, key=lambda x: x["confidence"])
435
  violations.append(best_detection)
436
 
437
+ if worker_id not in confirmed_violations:
438
+ confirmed_violations[worker_id] = set()
439
+ confirmed_violations[worker_id].add(violation_type)
440
+
441
  if not snapshot_taken[violation_type]:
 
442
  cap = cv2.VideoCapture(video_path)
443
  cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
444
  ret, snapshot_frame = cap.read()
445
  cap.release()
446
 
447
  if ret:
 
448
  snapshot_frame = draw_detections(snapshot_frame, [best_detection])
449
 
450
  snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
 
458
  })
459
  snapshot_taken[violation_type] = True
460
 
 
461
  if not violations:
462
  logger.info("No persistent violations detected")
463
  return {