PrashanthB461 commited on
Commit
0b9ff0b
·
verified ·
1 Parent(s): 6060e8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -30
app.py CHANGED
@@ -37,50 +37,28 @@ logger = logging.getLogger(__name__)
37
  class BYTETracker:
38
  """Custom implementation of ByteTrack to avoid installation issues"""
39
  def __init__(self, track_thresh=0.5, track_buffer=30, match_thresh=0.8, frame_rate=30):
40
- try:
41
- from yolox.tracker.byte_tracker import BYTETracker as OriginalBYTETracker
42
- self.tracker = OriginalBYTETracker(
43
- track_thresh=track_thresh,
44
- track_buffer=track_buffer,
45
- match_thresh=match_thresh,
46
- frame_rate=frame_rate
47
- )
48
- self._original = True
49
- except ImportError:
50
- logger.warning("Using simplified ByteTrack implementation")
51
- self._original = False
52
- self.track_thresh = track_thresh
53
- self.track_buffer = track_buffer
54
- self.match_thresh = match_thresh
55
- self.frame_rate = frame_rate
56
- self.tracked_objects = {}
57
- self.next_id = 1
58
-
59
- def update(self, dets, scores, cls):
60
- if self._original:
61
- return self.tracker.update(dets, scores, cls)
62
-
63
- # Simplified tracking logic for fallback
64
- if len(dets) == 0:
65
- return []
66
 
 
67
  tracks = []
68
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
69
  if score < self.track_thresh:
70
  continue
71
 
72
  x, y, w, h = det
73
- track_id = self.next_id
74
- self.next_id += 1
75
  tracks.append({
76
- 'id': track_id,
77
  'bbox': [x, y, w, h],
78
  'score': score,
79
  'cls': cl
80
  })
 
81
  return tracks
82
 
83
-
84
  # ==========================
85
  # Optimized Configuration
86
  # ==========================
@@ -522,6 +500,67 @@ def process_video(video_data):
522
  yield f"Error processing video: {e}", "", "", "", ""
523
 
524
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  # ==========================
526
  # Gradio Interface
527
  # ==========================
 
37
  class BYTETracker:
38
  """Custom implementation of ByteTrack to avoid installation issues"""
39
  def __init__(self, track_thresh=0.5, track_buffer=30, match_thresh=0.8, frame_rate=30):
40
+ self.track_thresh = track_thresh
41
+ self.track_buffer = track_buffer
42
+ self.match_thresh = match_thresh
43
+ self.frame_rate = frame_rate
44
+ self.next_id = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ def update(self, dets, scores, cls):
47
  tracks = []
48
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
49
  if score < self.track_thresh:
50
  continue
51
 
52
  x, y, w, h = det
 
 
53
  tracks.append({
54
+ 'id': self.next_id,
55
  'bbox': [x, y, w, h],
56
  'score': score,
57
  'cls': cl
58
  })
59
+ self.next_id += 1
60
  return tracks
61
 
 
62
  # ==========================
63
  # Optimized Configuration
64
  # ==========================
 
500
  yield f"Error processing video: {e}", "", "", "", ""
501
 
502
 
503
+ # Initialize device and model
504
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
505
+ logger.info(f"Using device: {device}")
506
+
507
+ def load_model():
508
+ try:
509
+ if os.path.isfile(CONFIG["MODEL_PATH"]):
510
+ model_path = CONFIG["MODEL_PATH"]
511
+ logger.info(f"Model loaded: {model_path}")
512
+ else:
513
+ model_path = CONFIG["FALLBACK_MODEL"]
514
+ logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
515
+ if not os.path.isfile(model_path):
516
+ logger.info(f"Downloading fallback model: {model_path}")
517
+ torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
518
+ model = YOLO(model_path).to(device)
519
+ return model
520
+ except Exception as e:
521
+ logger.error(f"Failed to load model: {e}")
522
+ raise
523
+
524
+ model = load_model()
525
+
526
+ # ==========================
527
+ # Helper Functions
528
+ # ==========================
529
+ def draw_detections(frame, detections):
530
+ # ... [your existing implementation] ...
531
+
532
+ def calculate_safety_score(violations):
533
+ # ... [your existing implementation] ...
534
+
535
+ def generate_violation_pdf(violations, score):
536
+ # ... [your existing implementation] ...
537
+
538
+ @retry(stop_max_attempt_number=3, wait_fixed=2000)
539
+ def connect_to_salesforce():
540
+ # ... [your existing implementation] ...
541
+
542
+ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
543
+ # ... [your existing implementation] ...
544
+
545
+ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
546
+ # ... [your existing implementation] ...
547
+
548
+ def process_video(video_data):
549
+ # ... [your existing implementation] ...
550
+
551
+ def gradio_interface(video_file):
552
+ if not video_file:
553
+ return "No file uploaded.", "", "No file uploaded.", "", ""
554
+ try:
555
+ with open(video_file, "rb") as f:
556
+ video_data = f.read()
557
+
558
+ for status, score, snapshots_text, record_id, details_url in process_video(video_data):
559
+ yield status, score, snapshots_text, record_id, details_url
560
+ except Exception as e:
561
+ logger.error(f"Error in Gradio interface: {e}", exc_info=True)
562
+ yield f"Error: {str(e)}", "", "Error in processing.", "", ""
563
+
564
  # ==========================
565
  # Gradio Interface
566
  # ==========================