PrashanthB461 committed on
Commit
220ca2f
·
verified ·
1 Parent(s): a277b4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -137
app.py CHANGED
@@ -216,7 +216,7 @@ CONFIG = {
216
  "MATCH_THRESH": 0.5,
217
  "SNAPSHOT_QUALITY": 95,
218
  "MAX_WORKER_DISTANCE": 300,
219
- "MODEL_INPUT_SIZE": (640, 640) # Updated to match YOLO input requirements
220
  }
221
 
222
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -246,56 +246,11 @@ def load_model():
246
  model = load_model()
247
 
248
  # ========================== # Helper Functions # ==========================
249
- def preprocess_frame(frame, original_shape):
250
- # Resize while preserving aspect ratio, then pad to MODEL_INPUT_SIZE (640x640)
251
- target_size = CONFIG["MODEL_INPUT_SIZE"] # (640, 640)
252
- h, w = frame.shape[:2]
253
- scale = min(target_size[0] / w, target_size[1] / h)
254
- new_w, new_h = int(w * scale), int(h * scale)
255
-
256
- # Resize the frame
257
- frame_resized = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
258
-
259
- # Create a new 640x640 image with padding
260
- padded_frame = np.zeros((target_size[1], target_size[0], 3), dtype=np.uint8)
261
- top = (target_size[1] - new_h) // 2
262
- left = (target_size[0] - new_w) // 2
263
- padded_frame[top:top+new_h, left:left+new_w] = frame_resized
264
-
265
- # Apply contrast adjustment
266
- padded_frame = cv2.convertScaleAbs(padded_frame, alpha=1.2, beta=20)
267
-
268
- # Store padding info to adjust bounding boxes later
269
- padding_info = {
270
- "scale": scale,
271
- "top": top,
272
- "left": left,
273
- "original_shape": original_shape
274
- }
275
-
276
- return padded_frame, padding_info
277
-
278
- def adjust_bbox(bbox, padding_info):
279
- # Adjust bounding box coordinates from padded 640x640 space back to original frame space
280
- scale = padding_info["scale"]
281
- top = padding_info["top"]
282
- left = padding_info["left"]
283
-
284
- x, y, w, h = bbox
285
- # Remove padding offset and scale back
286
- x = (x - left) / scale
287
- y = (y - top) / scale
288
- w = w / scale
289
- h = h / scale
290
-
291
- # Ensure coordinates are within original frame bounds
292
- orig_h, orig_w = padding_info["original_shape"][:2]
293
- x = max(0, min(x, orig_w))
294
- y = max(0, min(y, orig_h))
295
- w = max(0, min(w, orig_w - x))
296
- h = max(0, min(h, orig_h - y))
297
-
298
- return [x, y, w, h]
299
 
300
  def draw_detections(frame, detections):
301
  result_frame = frame.copy()
@@ -496,7 +451,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
496
 
497
  try:
498
  record = sf.Safety_Video_Report__c.create(record_data)
499
- logger.info(f"Created Safety_Violation_Report__c record: {record['id']}")
500
  except Exception as e:
501
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
502
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
@@ -593,16 +548,15 @@ def process_video(video_data, temp_dir):
593
 
594
  worker_id_mapping = {}
595
  unique_violations = {}
596
- snapshots = []
597
  start_time = time.time()
598
  frame_skip = CONFIG["FRAME_SKIP"]
599
  processed_frames = 0
 
600
 
601
  while processed_frames < total_frames:
602
  batch_frames = []
603
  batch_indices = []
604
- batch_padding_info = []
605
- batch_original_frames = []
606
 
607
  for _ in range(CONFIG["BATCH_SIZE"]):
608
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
@@ -614,9 +568,7 @@ def process_video(video_data, temp_dir):
614
  logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
615
  break
616
 
617
- # Keep a copy of the original frame for drawing detections
618
- original_frame = frame.copy()
619
- frame, padding_info = preprocess_frame(frame, original_shape=frame.shape)
620
 
621
  for _ in range(frame_skip - 1):
622
  if not cap.grab():
@@ -624,8 +576,6 @@ def process_video(video_data, temp_dir):
624
 
625
  batch_frames.append(frame)
626
  batch_indices.append(frame_idx)
627
- batch_padding_info.append(padding_info)
628
- batch_original_frames.append(original_frame)
629
  processed_frames += 1
630
 
631
  if not batch_frames:
@@ -633,9 +583,9 @@ def process_video(video_data, temp_dir):
633
  break
634
 
635
  try:
636
- # Convert frames to tensor and move to device
637
- batch_frames_tensor = [torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0 for frame in batch_frames]
638
- batch_frames_tensor = torch.stack(batch_frames_tensor).to(device)
639
  if device.type == "cuda":
640
  batch_frames_tensor = batch_frames_tensor.half()
641
 
@@ -648,16 +598,17 @@ def process_video(video_data, temp_dir):
648
  if device.type == "cuda":
649
  torch.cuda.empty_cache()
650
 
651
- for i, (result, frame_idx, padding_info, original_frame) in enumerate(zip(results, batch_indices, batch_padding_info, batch_original_frames)):
 
 
 
 
 
 
 
 
652
  current_time = frame_idx / fps
653
 
654
- if time.time() - start_time > 0.5:
655
- progress = (processed_frames / total_frames) * 100
656
- elapsed_time = time.time() - start_time
657
- fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
658
- yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
659
- start_time = time.time()
660
-
661
  boxes = result.boxes
662
  track_inputs = []
663
 
@@ -673,8 +624,6 @@ def process_video(video_data, temp_dir):
673
  continue
674
 
675
  bbox = box.xywh.cpu().numpy()[0]
676
- # Adjust bounding box coordinates to original frame space
677
- bbox = adjust_bbox(bbox, padding_info)
678
  track_inputs.append({
679
  "bbox": bbox,
680
  "conf": conf,
@@ -710,79 +659,92 @@ def process_video(video_data, temp_dir):
710
 
711
  if violation_key not in unique_violations:
712
  unique_violations[violation_key] = current_time
713
-
714
- detection = {
715
- "worker_id": worker_id,
716
- "violation": label,
717
- "confidence": round(float(conf), 2),
718
- "bounding_box": bbox,
719
- "timestamp": current_time
720
- }
721
-
722
- # Use the original frame for drawing detections
723
- snapshot_frame = original_frame.copy()
724
- snapshot_frame = draw_detections(snapshot_frame, [detection])
725
-
726
- cv2.putText(
727
- snapshot_frame,
728
- f"Time: {current_time:.2f}s",
729
- (10, 30),
730
- cv2.FONT_HERSHEY_SIMPLEX,
731
- 0.7,
732
- (255, 255, 255),
733
- 2
734
- )
735
-
736
- snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
737
- snapshot_path = os.path.join(output_dir, snapshot_filename)
738
-
739
- cv2.imwrite(
740
- snapshot_path,
741
- snapshot_frame,
742
- [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
743
- )
744
-
745
- snapshots.append({
746
- "violation": label,
747
- "worker_id": worker_id,
748
- "timestamp": current_time,
749
- "snapshot_path": snapshot_path,
750
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
751
- "confidence": round(float(conf), 2)
752
- })
753
-
754
- logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
755
-
756
- if len(snapshots) > 100:
757
- snapshots = snapshots[-10:]
758
 
759
  cap.release()
760
  processing_time = time.time() - start_time
761
  logger.info(f"Processing complete in {processing_time:.2f}s")
762
 
763
- logger.info(f"Snapshots: {snapshots}")
764
-
765
  violations = []
766
  for (worker_id, label), detection_time in unique_violations.items():
767
- confidence = next(
768
- (float(s["confidence"]) for s in snapshots if s["worker_id"] == worker_id and s["violation"] == label),
769
- 0.0
770
- )
771
- violation = {
772
  "worker_id": worker_id,
773
  "violation": label,
774
  "timestamp": detection_time,
775
- "confidence": confidence
776
- }
777
- violations.append(violation)
778
-
779
- logger.info(f"Violations: {violations}")
780
 
781
  if not violations:
782
  logger.info("No violations detected after processing")
783
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
784
  return
785
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
786
  score = calculate_safety_score(violations)
787
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
788
 
@@ -795,12 +757,7 @@ def process_video(video_data, temp_dir):
795
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
796
  worker_id = v.get("worker_id", "Unknown")
797
  timestamp = v.get("timestamp", 0.0)
798
- try:
799
- confidence = float(v.get("confidence", 0.0))
800
- except (ValueError, TypeError) as e:
801
- logger.error(f"Invalid confidence value in violation {v}: {e}")
802
- confidence = 0.0
803
-
804
  violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
805
 
806
  snapshots_text = ""
@@ -808,7 +765,6 @@ def process_video(video_data, temp_dir):
808
  display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
809
  worker_id = s.get("worker_id", "Unknown")
810
  timestamp = s.get("timestamp", 0.0)
811
-
812
  snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
813
  snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
814
 
 
216
  "MATCH_THRESH": 0.5,
217
  "SNAPSHOT_QUALITY": 95,
218
  "MAX_WORKER_DISTANCE": 300,
219
+ "TARGET_RESOLUTION": (384, 384) # Changed to 384x384 (divisible by 32)
220
  }
221
 
222
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
246
  model = load_model()
247
 
248
  # ========================== # Helper Functions # ==========================
249
def preprocess_frame(frame):
    """Resize a frame to the configured target resolution and boost contrast.

    Returns a new uint8 BGR image of shape TARGET_RESOLUTION; the input
    frame is not modified in place.
    """
    # NOTE(review): cv2.resize takes dsize as (width, height). TARGET_RESOLUTION
    # is square (384, 384) here so the ordering is moot — confirm if it changes.
    resized = cv2.resize(
        frame,
        CONFIG["TARGET_RESOLUTION"],
        interpolation=cv2.INTER_LINEAR,
    )
    # Linear contrast/brightness lift: out = clip(in * 1.2 + 20, 0, 255).
    return cv2.convertScaleAbs(resized, alpha=1.2, beta=20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  def draw_detections(frame, detections):
256
  result_frame = frame.copy()
 
451
 
452
  try:
453
  record = sf.Safety_Video_Report__c.create(record_data)
454
+ logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
455
  except Exception as e:
456
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
457
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
 
548
 
549
  worker_id_mapping = {}
550
  unique_violations = {}
551
+ violation_frames = {} # Store frame indices for violations
552
  start_time = time.time()
553
  frame_skip = CONFIG["FRAME_SKIP"]
554
  processed_frames = 0
555
+ last_yield_time = start_time
556
 
557
  while processed_frames < total_frames:
558
  batch_frames = []
559
  batch_indices = []
 
 
560
 
561
  for _ in range(CONFIG["BATCH_SIZE"]):
562
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
 
568
  logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
569
  break
570
 
571
+ frame = preprocess_frame(frame)
 
 
572
 
573
  for _ in range(frame_skip - 1):
574
  if not cap.grab():
 
576
 
577
  batch_frames.append(frame)
578
  batch_indices.append(frame_idx)
 
 
579
  processed_frames += 1
580
 
581
  if not batch_frames:
 
583
  break
584
 
585
  try:
586
+ batch_frames_np = np.array(batch_frames) # Shape: (batch, height, width, channels)
587
+ batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
588
+ batch_frames_tensor = batch_frames_tensor.to(device)
589
  if device.type == "cuda":
590
  batch_frames_tensor = batch_frames_tensor.half()
591
 
 
598
  if device.type == "cuda":
599
  torch.cuda.empty_cache()
600
 
601
+ current_time = time.time()
602
+ if current_time - last_yield_time > 0.1:
603
+ progress = (processed_frames / total_frames) * 100
604
+ elapsed_time = current_time - start_time
605
+ fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
606
+ yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", ""
607
+ last_yield_time = current_time
608
+
609
+ for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
610
  current_time = frame_idx / fps
611
 
 
 
 
 
 
 
 
612
  boxes = result.boxes
613
  track_inputs = []
614
 
 
624
  continue
625
 
626
  bbox = box.xywh.cpu().numpy()[0]
 
 
627
  track_inputs.append({
628
  "bbox": bbox,
629
  "conf": conf,
 
659
 
660
  if violation_key not in unique_violations:
661
  unique_violations[violation_key] = current_time
662
+ violation_frames[violation_key] = frame_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663
 
664
  cap.release()
665
  processing_time = time.time() - start_time
666
  logger.info(f"Processing complete in {processing_time:.2f}s")
667
 
 
 
668
  violations = []
669
  for (worker_id, label), detection_time in unique_violations.items():
670
+ violations.append({
 
 
 
 
671
  "worker_id": worker_id,
672
  "violation": label,
673
  "timestamp": detection_time,
674
+ "confidence": 0.0, # Will be updated after reprocessing frames
675
+ "frame_idx": violation_frames[(worker_id, label)]
676
+ })
 
 
677
 
678
  if not violations:
679
  logger.info("No violations detected after processing")
680
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
681
  return
682
 
683
+ # Reopen video to capture snapshots for violations
684
+ snapshots = []
685
+ cap = cv2.VideoCapture(video_path)
686
+ for violation in violations:
687
+ frame_idx = violation["frame_idx"]
688
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
689
+ ret, frame = cap.read()
690
+ if not ret:
691
+ logger.warning(f"Failed to read frame {frame_idx} for snapshot.")
692
+ continue
693
+
694
+ frame = preprocess_frame(frame)
695
+ frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
696
+ frame_tensor = frame_tensor.unsqueeze(0).to(device)
697
+ if device.type == "cuda":
698
+ frame_tensor = frame_tensor.half()
699
+
700
+ result = model(frame_tensor, device=device, conf=0.1, verbose=False)[0]
701
+ boxes = result.boxes
702
+
703
+ for box in boxes:
704
+ cls = int(box.cls)
705
+ conf = float(box.conf)
706
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
707
+ if label == violation["violation"]:
708
+ violation["confidence"] = round(conf, 2)
709
+ bbox = box.xywh.cpu().numpy()[0]
710
+ detection = {
711
+ "worker_id": violation["worker_id"],
712
+ "violation": label,
713
+ "confidence": violation["confidence"],
714
+ "bounding_box": bbox,
715
+ "timestamp": violation["timestamp"]
716
+ }
717
+ snapshot_frame = frame.copy()
718
+ snapshot_frame = draw_detections(snapshot_frame, [detection])
719
+ cv2.putText(
720
+ snapshot_frame,
721
+ f"Time: {violation['timestamp']:.2f}s",
722
+ (10, 30),
723
+ cv2.FONT_HERSHEY_SIMPLEX,
724
+ 0.7,
725
+ (255, 255, 255),
726
+ 2
727
+ )
728
+ snapshot_filename = f"violation_{label}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
729
+ snapshot_path = os.path.join(output_dir, snapshot_filename)
730
+ cv2.imwrite(
731
+ snapshot_path,
732
+ snapshot_frame,
733
+ [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
734
+ )
735
+ snapshots.append({
736
+ "violation": label,
737
+ "worker_id": violation["worker_id"],
738
+ "timestamp": violation["timestamp"],
739
+ "snapshot_path": snapshot_path,
740
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
741
+ "confidence": violation["confidence"]
742
+ })
743
+ logger.info(f"Captured snapshot for {label} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
744
+ break
745
+
746
+ cap.release()
747
+
748
  score = calculate_safety_score(violations)
749
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
750
 
 
757
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
758
  worker_id = v.get("worker_id", "Unknown")
759
  timestamp = v.get("timestamp", 0.0)
760
+ confidence = v.get("confidence", 0.0)
 
 
 
 
 
761
  violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
762
 
763
  snapshots_text = ""
 
765
  display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
766
  worker_id = s.get("worker_id", "Unknown")
767
  timestamp = s.get("timestamp", 0.0)
 
768
  snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
769
  snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
770