PrashanthB461 committed on
Commit
3edce5e
·
verified ·
1 Parent(s): 6cc29db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -45
app.py CHANGED
@@ -23,20 +23,17 @@ from functools import partial
23
  # ==========================
24
  # Configuration and Setup
25
  # ==========================
26
- # Handle Ultralytics config directory
27
  os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
28
  os.makedirs('/tmp/Ultralytics', exist_ok=True)
29
 
30
- # Setup logging
31
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
32
  logger = logging.getLogger(__name__)
33
 
34
  # ==========================
35
  # ByteTrack Implementation
36
  # ==========================
37
  class BYTETracker:
38
- """Robust ByteTrack implementation with fallback"""
39
- def __init__(self, track_thresh=0.5, track_buffer=30, match_thresh=0.8, frame_rate=30):
40
  self.track_thresh = track_thresh
41
  self.track_buffer = track_buffer
42
  self.match_thresh = match_thresh
@@ -47,6 +44,7 @@ class BYTETracker:
47
  tracks = []
48
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
49
  if score < self.track_thresh:
 
50
  continue
51
 
52
  x, y, w, h = det
@@ -74,11 +72,11 @@ CONFIG = {
74
  4: "improper_tool_use"
75
  },
76
  "CLASS_COLORS": {
77
- "no_helmet": (0, 0, 255), # Red
78
- "no_harness": (0, 165, 255), # Orange
79
- "unsafe_posture": (0, 255, 0), # Green
80
- "unsafe_zone": (255, 0, 0), # Blue
81
- "improper_tool_use": (255, 255, 0) # Yellow
82
  },
83
  "DISPLAY_NAMES": {
84
  "no_helmet": "No Helmet Violation",
@@ -93,26 +91,25 @@ CONFIG = {
93
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
94
  "domain": "login"
95
  },
96
- "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
97
  "CONFIDENCE_THRESHOLDS": {
98
- "no_helmet": 0.75,
99
- "no_harness": 0.4,
100
- "unsafe_posture": 0.4,
101
- "unsafe_zone": 0.4,
102
- "improper_tool_use": 0.4
103
  },
104
- "MIN_VIOLATION_FRAMES": 3,
105
  "WORKER_TRACKING_DURATION": 5.0,
106
  "MAX_PROCESSING_TIME": 60,
107
  "FRAME_SKIP": 1,
108
- "BATCH_SIZE": 32,
109
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
110
  "TRACK_BUFFER": 30,
111
- "TRACK_THRESH": 0.4,
112
- "MATCH_THRESH": 0.8
113
  }
114
 
115
- # Initialize device and model
116
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
117
  logger.info(f"Using device: {device}")
118
 
@@ -128,6 +125,7 @@ def load_model():
128
  logger.info(f"Downloading fallback model: {model_path}")
129
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
130
  model = YOLO(model_path).to(device)
 
131
  return model
132
  except Exception as e:
133
  logger.error(f"Failed to load model: {e}")
@@ -138,6 +136,11 @@ model = load_model()
138
  # ==========================
139
  # Helper Functions
140
  # ==========================
 
 
 
 
 
141
  def draw_detections(frame, detections):
142
  for det in detections:
143
  label = det.get("violation", "Unknown")
@@ -297,23 +300,19 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
297
 
298
  def process_video(video_data):
299
  try:
300
- # Ensure output directory exists
301
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
302
  logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
303
 
304
- # Create temp video file
305
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
306
  with open(video_path, "wb") as f:
307
  f.write(video_data)
308
  logger.info(f"Video saved: {video_path}")
309
 
310
- # Open video file
311
  cap = cv2.VideoCapture(video_path)
312
  if not cap.isOpened():
313
  os.remove(video_path)
314
  raise ValueError("Could not open video file")
315
 
316
- # Get video properties
317
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
318
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
319
  duration = total_frames / fps
@@ -321,7 +320,6 @@ def process_video(video_data):
321
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
322
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
323
 
324
- # Initialize ByteTrack
325
  tracker = BYTETracker(
326
  track_thresh=CONFIG["TRACK_THRESH"],
327
  track_buffer=CONFIG["TRACK_BUFFER"],
@@ -329,18 +327,15 @@ def process_video(video_data):
329
  frame_rate=fps
330
  )
331
 
332
- # Track violations by worker ID and type
333
- violation_tracker = {} # {worker_id: {violation_type: [detections]}}
334
  snapshots = []
335
  start_time = time.time()
336
  frame_skip = CONFIG["FRAME_SKIP"]
337
 
338
- # Process frames in batches
339
  while True:
340
  batch_frames = []
341
  batch_indices = []
342
 
343
- # Collect frames for this batch
344
  for _ in range(CONFIG["BATCH_SIZE"]):
345
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
346
  if frame_idx >= total_frames:
@@ -350,7 +345,8 @@ def process_video(video_data):
350
  if not ret:
351
  break
352
 
353
- # Skip frames if needed
 
354
  for _ in range(frame_skip - 1):
355
  if not cap.grab():
356
  break
@@ -358,24 +354,19 @@ def process_video(video_data):
358
  batch_frames.append(frame)
359
  batch_indices.append(frame_idx)
360
 
361
- # Break if no more frames
362
  if not batch_frames:
363
  break
364
 
365
- # Run batch detection
366
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
367
 
368
- # Process results for each frame in batch
369
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
370
  current_time = frame_idx / fps
371
 
372
- # Update progress
373
  if time.time() - start_time > 1.0:
374
  progress = (frame_idx / total_frames) * 100
375
  yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
376
  start_time = time.time()
377
 
378
- # Prepare detections for ByteTrack
379
  boxes = result.boxes
380
  track_inputs = []
381
  for box in boxes:
@@ -383,24 +374,27 @@ def process_video(video_data):
383
  conf = float(box.conf)
384
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
385
 
386
- if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
 
 
 
 
387
  continue
388
 
389
  bbox = box.xywh.cpu().numpy()[0]
390
  track_inputs.append({
391
- "bbox": bbox, # [x, y, w, h]
392
  "conf": conf,
393
  "cls": cls
394
  })
395
 
396
- # Update tracker
397
  tracked_objects = tracker.update(
398
  np.array([t["bbox"] for t in track_inputs]),
399
  np.array([t["conf"] for t in track_inputs]),
400
  np.array([t["cls"] for t in track_inputs])
401
  )
 
402
 
403
- # Process tracked objects
404
  for obj, track_input in zip(tracked_objects, track_inputs):
405
  worker_id = obj['id']
406
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
@@ -415,8 +409,8 @@ def process_video(video_data):
415
  "timestamp": current_time,
416
  "worker_id": worker_id
417
  }
 
418
 
419
- # Track violations by worker_id and type
420
  if worker_id not in violation_tracker:
421
  violation_tracker[worker_id] = {}
422
  if label not in violation_tracker[worker_id]:
@@ -429,19 +423,19 @@ def process_video(video_data):
429
  processing_time = time.time() - start_time
430
  logger.info(f"Processing complete in {processing_time:.2f}s")
431
 
432
- # Consolidate violations
433
  violations = []
434
  for worker_id, worker_violations in violation_tracker.items():
435
  for label, detections in worker_violations.items():
436
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
437
- # Select highest-confidence detection
438
  best_detection = max(detections, key=lambda x: x["confidence"])
439
  best_detection["start_timestamp"] = min(d["timestamp"] for d in detections)
440
  best_detection["end_timestamp"] = max(d["timestamp"] for d in detections)
441
  violations.append(best_detection)
442
 
443
- # Capture snapshot for confirmed violation
444
  cap = cv2.VideoCapture(video_path)
 
 
 
445
  cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
446
  ret, snapshot_frame = cap.read()
447
  if ret:
@@ -457,8 +451,8 @@ def process_video(video_data):
457
  })
458
  cap.release()
459
 
460
- # Generate results
461
  if not violations:
 
462
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
463
  return
464
 
 
23
  # ==========================
24
  # Configuration and Setup
25
  # ==========================
 
26
  os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
27
  os.makedirs('/tmp/Ultralytics', exist_ok=True)
28
 
29
+ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
 
30
  logger = logging.getLogger(__name__)
31
 
32
  # ==========================
33
  # ByteTrack Implementation
34
  # ==========================
35
  class BYTETracker:
36
+ def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
 
37
  self.track_thresh = track_thresh
38
  self.track_buffer = track_buffer
39
  self.match_thresh = match_thresh
 
44
  tracks = []
45
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
46
  if score < self.track_thresh:
47
+ logger.debug(f"Skipping detection with score {score} below threshold {self.track_thresh}")
48
  continue
49
 
50
  x, y, w, h = det
 
72
  4: "improper_tool_use"
73
  },
74
  "CLASS_COLORS": {
75
+ "no_helmet": (0, 0, 255),
76
+ "no_harness": (0, 165, 255),
77
+ "unsafe_posture": (0, 255, 0),
78
+ "unsafe_zone": (255, 0, 0),
79
+ "improper_tool_use": (255, 255, 0)
80
  },
81
  "DISPLAY_NAMES": {
82
  "no_helmet": "No Helmet Violation",
 
91
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
92
  "domain": "login"
93
  },
94
+ "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Sadio2/resolve/main/static/output/",
95
  "CONFIDENCE_THRESHOLDS": {
96
+ "no_helmet": 0.5, # Lowered from 0.75
97
+ "no_harness": 0.3, # Lowered from 0.4
98
+ "unsafe_posture": 0.3, # Lowered from 0.4
99
+ "unsafe_zone": 0.3, # Lowered from 0.4
100
+ "improper_tool_use": 0.3 # Lowered from 0.4
101
  },
102
+ "MIN_VIOLATION_FRAMES": 1, # Lowered from 3
103
  "WORKER_TRACKING_DURATION": 5.0,
104
  "MAX_PROCESSING_TIME": 60,
105
  "FRAME_SKIP": 1,
106
+ "BATCH_SIZE": 16, # Reduced from 32 to prevent memory issues
107
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
108
  "TRACK_BUFFER": 30,
109
+ "TRACK_THRESH": 0.3, # Lowered from 0.4
110
+ "MATCH_THRESH": 0.7 # Lowered from 0.8
111
  }
112
 
 
113
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
  logger.info(f"Using device: {device}")
115
 
 
125
  logger.info(f"Downloading fallback model: {model_path}")
126
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
127
  model = YOLO(model_path).to(device)
128
+ logger.info(f"Model classes: {model.names}")
129
  return model
130
  except Exception as e:
131
  logger.error(f"Failed to load model: {e}")
 
136
  # ==========================
137
  # Helper Functions
138
  # ==========================
139
+ def preprocess_frame(frame):
140
+ """Apply basic preprocessing to enhance detection"""
141
+ frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20) # Increase contrast (alpha gain) and brightness (beta offset)
142
+ return frame
143
+
144
  def draw_detections(frame, detections):
145
  for det in detections:
146
  label = det.get("violation", "Unknown")
 
300
 
301
  def process_video(video_data):
302
  try:
 
303
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
304
  logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
305
 
 
306
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
307
  with open(video_path, "wb") as f:
308
  f.write(video_data)
309
  logger.info(f"Video saved: {video_path}")
310
 
 
311
  cap = cv2.VideoCapture(video_path)
312
  if not cap.isOpened():
313
  os.remove(video_path)
314
  raise ValueError("Could not open video file")
315
 
 
316
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
317
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
318
  duration = total_frames / fps
 
320
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
321
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
322
 
 
323
  tracker = BYTETracker(
324
  track_thresh=CONFIG["TRACK_THRESH"],
325
  track_buffer=CONFIG["TRACK_BUFFER"],
 
327
  frame_rate=fps
328
  )
329
 
330
+ violation_tracker = {}
 
331
  snapshots = []
332
  start_time = time.time()
333
  frame_skip = CONFIG["FRAME_SKIP"]
334
 
 
335
  while True:
336
  batch_frames = []
337
  batch_indices = []
338
 
 
339
  for _ in range(CONFIG["BATCH_SIZE"]):
340
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
341
  if frame_idx >= total_frames:
 
345
  if not ret:
346
  break
347
 
348
+ frame = preprocess_frame(frame)
349
+
350
  for _ in range(frame_skip - 1):
351
  if not cap.grab():
352
  break
 
354
  batch_frames.append(frame)
355
  batch_indices.append(frame_idx)
356
 
 
357
  if not batch_frames:
358
  break
359
 
 
360
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
361
 
 
362
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
363
  current_time = frame_idx / fps
364
 
 
365
  if time.time() - start_time > 1.0:
366
  progress = (frame_idx / total_frames) * 100
367
  yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
368
  start_time = time.time()
369
 
 
370
  boxes = result.boxes
371
  track_inputs = []
372
  for box in boxes:
 
374
  conf = float(box.conf)
375
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
376
 
377
+ if label is None:
378
+ logger.debug(f"Unknown class ID {cls} detected, skipping")
379
+ continue
380
+ if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
381
+ logger.debug(f"Detection for {label} with confidence {conf} below threshold {CONFIG['CONFIDENCE_THRESHOLDS'].get(label, 0.25)}")
382
  continue
383
 
384
  bbox = box.xywh.cpu().numpy()[0]
385
  track_inputs.append({
386
+ "bbox": bbox,
387
  "conf": conf,
388
  "cls": cls
389
  })
390
 
 
391
  tracked_objects = tracker.update(
392
  np.array([t["bbox"] for t in track_inputs]),
393
  np.array([t["conf"] for t in track_inputs]),
394
  np.array([t["cls"] for t in track_inputs])
395
  )
396
+ logger.debug(f"Frame {frame_idx}: {len(tracked_objects)} objects tracked")
397
 
 
398
  for obj, track_input in zip(tracked_objects, track_inputs):
399
  worker_id = obj['id']
400
  label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
 
409
  "timestamp": current_time,
410
  "worker_id": worker_id
411
  }
412
+ logger.debug(f"Detection: {detection}")
413
 
 
414
  if worker_id not in violation_tracker:
415
  violation_tracker[worker_id] = {}
416
  if label not in violation_tracker[worker_id]:
 
423
  processing_time = time.time() - start_time
424
  logger.info(f"Processing complete in {processing_time:.2f}s")
425
 
 
426
  violations = []
427
  for worker_id, worker_violations in violation_tracker.items():
428
  for label, detections in worker_violations.items():
429
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
 
430
  best_detection = max(detections, key=lambda x: x["confidence"])
431
  best_detection["start_timestamp"] = min(d["timestamp"] for d in detections)
432
  best_detection["end_timestamp"] = max(d["timestamp"] for d in detections)
433
  violations.append(best_detection)
434
 
 
435
  cap = cv2.VideoCapture(video_path)
436
+ if not cap.isOpened():
437
+ logger.warning(f"Could not reopen video for snapshot at frame {best_detection['frame']}")
438
+ continue
439
  cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
440
  ret, snapshot_frame = cap.read()
441
  if ret:
 
451
  })
452
  cap.release()
453
 
 
454
  if not violations:
455
+ logger.info("No violations detected after processing")
456
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
457
  return
458