arittrabag commited on
Commit
95ad330
·
verified ·
1 Parent(s): 7cfad87

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cv2
3
+ import tempfile
4
+ import torch
5
+ import numpy as np
6
+ from ultralytics import YOLO
7
+ from deep_sort_realtime.deepsort_tracker import DeepSort
8
+ import warnings
9
+ import os
10
+ import platform
11
+
12
+ # Suppress ScriptRunContext warnings from threads
13
+ warnings.filterwarnings("ignore", message=".*missing ScriptRunContext.*")
14
+
15
+ # Check if running in headless environment
16
+ IS_HEADLESS = platform.system() == 'Linux'
17
+
18
+ # Initialize YOLO model
19
+ @st.cache_resource
20
+ def load_yolo_model():
21
+ try:
22
+ return YOLO("best.pt")
23
+ except Exception as e:
24
+ st.error(f"Error loading YOLO model: {str(e)}")
25
+ return None
26
+
27
+ # Main app
28
+ st.title("📦 Inventory Management")
29
+
30
+ # Settings
31
+ st.sidebar.header("Settings")
32
+ CONF_THRESHOLD = st.sidebar.slider(
33
+ "Confidence Threshold",
34
+ min_value=0.0,
35
+ max_value=1.0,
36
+ value=0.4,
37
+ help="Higher values mean more confident detections but might miss objects"
38
+ )
39
+
40
+ FRAME_SKIP = st.sidebar.slider(
41
+ "Frame Skip",
42
+ min_value=0,
43
+ max_value=10,
44
+ value=2,
45
+ help="Process every Nth frame (higher values = faster processing but may miss objects)"
46
+ )
47
+
48
+ # Load YOLO model
49
+ yolo_model = load_yolo_model()
50
+
51
+ if yolo_model is None:
52
+ st.error("Failed to load YOLO model. Please check if the model file exists.")
53
+ st.stop()
54
+
55
+ # File uploader
56
+ uploaded_file = st.sidebar.file_uploader(
57
+ "Upload Video",
58
+ type=["mp4", "avi", "mov"],
59
+ help="Supported formats: MP4, AVI, MOV"
60
+ )
61
+
62
+ if uploaded_file is not None:
63
+ try:
64
+ tfile = tempfile.NamedTemporaryFile(delete=False)
65
+ tfile.write(uploaded_file.read())
66
+ video_path = tfile.name
67
+
68
+ if st.sidebar.button("Start Processing"):
69
+ tracker = DeepSort(
70
+ embedder="mobilenet",
71
+ embedder_gpu=torch.cuda.is_available(),
72
+ max_age=30 # Increase max_age for longer tracking retention
73
+ )
74
+
75
+ cap = cv2.VideoCapture(video_path)
76
+ if not cap.isOpened():
77
+ st.error("Error opening video file")
78
+ st.stop()
79
+
80
+ # Get video properties for processing
81
+ fps = cap.get(cv2.CAP_PROP_FPS)
82
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
83
+
84
+ counted_objects = set()
85
+ frame_placeholder = st.empty()
86
+ status_text = st.sidebar.empty()
87
+ progress_bar = st.progress(0)
88
+
89
+ # Counter for frame skipping
90
+ frame_counter = 0
91
+
92
+ while True:
93
+ ret, frame = cap.read()
94
+ if not ret:
95
+ break
96
+
97
+ frame_counter += 1
98
+ current_position = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
99
+ progress = current_position / frame_count
100
+ progress_bar.progress(progress)
101
+
102
+ # Skip frames based on user setting
103
+ if FRAME_SKIP > 0 and frame_counter % (FRAME_SKIP + 1) != 0:
104
+ continue
105
+
106
+ try:
107
+ # Resize frame for faster processing (if needed)
108
+ # h, w = frame.shape[:2]
109
+ # if w > 1280: # Only resize if the frame is large
110
+ # frame = cv2.resize(frame, (1280, int(h * 1280 / w)))
111
+
112
+ results = yolo_model(frame, verbose=False) # Turn off verbose output for speed
113
+ detections = []
114
+ for result in results:
115
+ for box in result.boxes.data.tolist():
116
+ x1, y1, x2, y2, score, class_id = box
117
+ if score > CONF_THRESHOLD:
118
+ detections.append([[x1, y1, x2 - x1, y2 - y1], score, int(class_id)])
119
+
120
+ tracks = tracker.update_tracks(detections, frame=frame)
121
+ for track in tracks:
122
+ if not track.is_confirmed():
123
+ continue
124
+ track_id = track.track_id
125
+ ltrb = track.to_ltrb()
126
+ x1, y1, x2, y2 = map(int, ltrb)
127
+ counted_objects.add(track_id)
128
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
129
+ cv2.putText(frame, f"ID: {track_id}", (x1, y1-10),
130
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
131
+
132
+ cv2.putText(frame, f"Total Objects: {len(counted_objects)}",
133
+ (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
134
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
135
+ frame_placeholder.image(frame_rgb, channels="RGB", use_column_width=True)
136
+ status_text.info(f"Processing... Current count: {len(counted_objects)}")
137
+
138
+ # Remove sleep to maximize performance
139
+ # time.sleep(0.01)
140
+
141
+ except Exception as e:
142
+ st.error(f"Error processing frame: {str(e)}")
143
+ continue
144
+
145
+ cap.release()
146
+ progress_bar.progress(1.0)
147
+ st.sidebar.success(f"Final count: {len(counted_objects)} objects")
148
+ st.balloons()
149
+
150
+ except Exception as e:
151
+ st.error(f"Error processing video: {str(e)}")
152
+ finally:
153
+ if 'tfile' in locals():
154
+ tfile.close()