Phani-1 committed
Commit d9b157e · verified · 1 parent: 6b5ca8e

Update app.py

Files changed (1)
  1. app.py +195 -209
app.py CHANGED
@@ -1,209 +1,195 @@
- import cv2
- import tempfile
- import numpy as np
- import mediapipe as mp
- import streamlit as st
- import tensorflow as tf
-
- POSE_CONNECTIONS = [
-     (0, 1), (1, 2), (2, 3), (3, 7),
-     (0, 4), (4, 5), (5, 6), (6, 8),
-     (9, 10), (11, 12), (11, 13), (13, 15),
-     (15, 17), (15, 19), (15, 21), (17, 19),
-     (12, 14), (14, 16), (16, 18), (16, 20),
-     (16, 22), (18, 20), (11, 23), (12, 24),
-     (23, 24), (23, 25), (24, 26), (25, 27),
-     (26, 28), (27, 29), (28, 30), (29, 31),
-     (30, 32)
- ]
-
-
- @st.cache_resource
- def load_model():
-     return tf.saved_model.load("Models/ssd_mobilenet/saved_model")
-
-
- model = load_model()
- mp_pose = mp.solutions.pose
-
- labels = {1: 'person'}
-
-
- def detect_persons(image):
-     tensor_img = tf.convert_to_tensor(image)
-     tensor_img = tensor_img[tf.newaxis, ...]
-
-     detections = model(tensor_img)
-
-     boxes = detections['detection_boxes'][0].numpy()
-     scores = detections['detection_scores'][0].numpy()
-     classes = detections['detection_classes'][0].numpy().astype(np.int32)
-
-     return boxes, scores, classes
-
-
- def draw_landmarks(img, landmarks):
-     height, width, _ = img.shape
-     for lm in landmarks.landmark:
-         cx, cy = int(lm.x * width), int(lm.y * height)
-         cv2.circle(img, (cx, cy), 8, (0, 0, 255), -1)
-
-     for connection in POSE_CONNECTIONS:
-         start_idx, end_idx = connection
-         if landmarks.landmark[start_idx] and landmarks.landmark[end_idx]:
-             start_point = landmarks.landmark[start_idx]
-             end_point = landmarks.landmark[end_idx]
-
-             start_coordinates = (int(start_point.x * width), int(start_point.y * height))
-             end_coordinates = (int(end_point.x * width), int(end_point.y * height))
-
-             cv2.line(img, start_coordinates, end_coordinates, (0, 255, 0), 3)
-
-     return img
-
-
- def draw_bounding_box(img, box, width, height):
-     y_min, x_min, y_max, x_max = box
-     left, right, top, bottom = x_min * width, x_max * width, y_min * height, y_max * height
-     cv2.rectangle(img, (int(left), int(top)), (int(right), int(bottom)), (255, 0, 0), 2)
-
-
- def process_frame(frame, pose, draw_box):
-     image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-     boxes, scores, classes = detect_persons(image_rgb)
-
-     height, width, _ = frame.shape
-     for i in range(len(scores)):
-         if scores[i] > 0.5 and classes[i] == 1:
-             y_min, x_min, y_max, x_max = boxes[i]
-             left, right, top, bottom = x_min * width, x_max * width, y_min * height, y_max * height
-             person_roi = frame[int(top):int(bottom), int(left):int(right)]
-
-             results = pose.process(cv2.cvtColor(person_roi, cv2.COLOR_BGR2RGB))
-
-             if results.pose_landmarks:
-                 person_roi = draw_landmarks(person_roi, results.pose_landmarks)
-
-             frame[int(top):int(bottom), int(left):int(right)] = person_roi
-             if draw_box:
-                 draw_bounding_box(frame, boxes[i], width, height)
-
-     return frame
-
-
- def main():
-     st.markdown(
-         """
-         <style>
-         .title {
-             font-size: 36px;
-             color: #000000;
-             padding-bottom: 40px;
-             border-bottom: 4px solid #000000;
-         }
-         .intro {
-             font-size: 18px;
-             margin-top: 20px;
-             margin-bottom: 20px;
-         }
-         .upload-section {
-             background-color: #f0f0f0;
-             padding: 20px;
-             border-radius: 10px;
-             margin-bottom: 20px;
-         }
-         .button-primary {
-             background-color: #008CBA;
-             color: white;
-             font-weight: bold;
-             padding: 10px 20px;
-             border-radius: 5px;
-             transition: background-color 0.3s ease;
-             text-align: center;
-             display: inline-block;
-             cursor: pointer;
-         }
-         .button-primary:hover {
-             background-color: #005f7f;
-         }
-         </style>
-         """,
-         unsafe_allow_html=True
-     )
-
-     st.markdown("<p class='title'>Multi-Person Pose Estimation</p>", unsafe_allow_html=True)
-     st.markdown("<p class='intro'>Choose an operation type:</p>", unsafe_allow_html=True)
-
-     operation_type = st.radio("Choose operation type", ("Input", "Demo"))
-
-     if operation_type == "Input":
-         input_type = st.radio("Choose input type", ("Image", "Video"))
-
-         if input_type == "Image":
-             uploaded_file = st.file_uploader(
-                 "Upload an image file (.jpg, .jpeg, .png)",
-                 type=["jpg", "jpeg", "png"]
-             )
-         else:
-             uploaded_file = st.file_uploader(
-                 "Upload a video file (.mp4, .mov, .avi, .mkv)",
-                 type=["mp4", "mov", "avi", "mkv"]
-             )
-
-         draw_box = st.checkbox("Draw bounding box", value=False)
-
-         pose = mp_pose.Pose()
-
-         if uploaded_file is not None:
-             with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-                 temp_file.write(uploaded_file.read())
-                 file_path = temp_file.name
-
-             if input_type == "Video":
-                 cam = cv2.VideoCapture(file_path)
-                 st_frame = st.empty()
-
-                 while cam.isOpened():
-                     success, frame = cam.read()
-                     if not success:
-                         break
-
-                     frame = process_frame(frame, pose, draw_box)
-
-                     st_frame.image(frame, channels='BGR', use_column_width=True)
-                     st.empty()
-
-                 st.text("Completed")
-                 cam.release()
-
-             elif input_type == "Image":
-                 image = cv2.imread(file_path)
-                 processed_image = process_frame(image, pose, draw_box)
-
-                 st.image(processed_image, channels='BGR', use_column_width=True)
-
-     elif operation_type == "Demo":
-         st.empty()
-         st.markdown("<p class='intro'>Demo video will be shown below:</p>", unsafe_allow_html=True)
-
-         demo_video_path = "Videos/video.mp4"
-         cam = cv2.VideoCapture(demo_video_path)
-         st_frame = st.empty()
-         pose = mp_pose.Pose()
-
-         while cam.isOpened():
-             success, frame = cam.read()
-             if not success:
-                 break
-
-             frame = process_frame(frame, pose, draw_box=False)
-
-             st_frame.image(frame, channels='BGR', use_column_width=True)
-             st.empty()
-
-         st.text("Completed")
-         cam.release()
-
-
- if __name__ == "__main__":
-     main()
 
+ import gc
+ import cv2
+ import tempfile
+ import numpy as np
+ import mediapipe as mp
+ import streamlit as st
+ import tensorflow as tf
+
+ POSE_CONNECTIONS = [
+     (0, 1), (1, 2), (2, 3), (3, 7),
+     (0, 4), (4, 5), (5, 6), (6, 8),
+     (9, 10), (11, 12), (11, 13), (13, 15),
+     (15, 17), (15, 19), (15, 21), (17, 19),
+     (12, 14), (14, 16), (16, 18), (16, 20),
+     (16, 22), (18, 20), (11, 23), (12, 24),
+     (23, 24), (23, 25), (24, 26), (25, 27),
+     (26, 28), (27, 29), (28, 30), (29, 31),
+     (30, 32)
+ ]
+
+ @st.cache_resource
+ def load_model():
+     return tf.saved_model.load("Models/ssd_mobilenet/saved_model")
+
+ model = load_model()
+ mp_pose = mp.solutions.pose
+
+ labels = {1: 'person'}
+
+ def detect_persons(image):
+     tensor_img = tf.convert_to_tensor(image)
+     tensor_img = tensor_img[tf.newaxis, ...]
+     detections = model(tensor_img)
+     boxes = detections['detection_boxes'][0].numpy()
+     scores = detections['detection_scores'][0].numpy()
+     classes = detections['detection_classes'][0].numpy().astype(np.int32)
+     return boxes, scores, classes
+
+ def draw_landmarks(img, landmarks):
+     height, width, _ = img.shape
+     for lm in landmarks.landmark:
+         cx, cy = int(lm.x * width), int(lm.y * height)
+         cv2.circle(img, (cx, cy), 8, (0, 0, 255), -1)
+     for connection in POSE_CONNECTIONS:
+         start_idx, end_idx = connection
+         if landmarks.landmark[start_idx] and landmarks.landmark[end_idx]:
+             start_point = landmarks.landmark[start_idx]
+             end_point = landmarks.landmark[end_idx]
+             start_coordinates = (int(start_point.x * width), int(start_point.y * height))
+             end_coordinates = (int(end_point.x * width), int(end_point.y * height))
+             cv2.line(img, start_coordinates, end_coordinates, (0, 255, 0), 3)
+     return img
+
+ def draw_bounding_box(img, box, width, height):
+     y_min, x_min, y_max, x_max = box
+     left, right, top, bottom = x_min * width, x_max * width, y_min * height, y_max * height
+     cv2.rectangle(img, (int(left), int(top)), (int(right), int(bottom)), (255, 0, 0), 2)
+
+ def process_frame(frame, pose, draw_box):
+     image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+     boxes, scores, classes = detect_persons(image_rgb)
+     height, width, _ = frame.shape
+     for i in range(len(scores)):
+         if scores[i] > 0.5 and classes[i] == 1:
+             y_min, x_min, y_max, x_max = boxes[i]
+             left, right, top, bottom = x_min * width, x_max * width, y_min * height, y_max * height
+             person_roi = frame[int(top):int(bottom), int(left):int(right)]
+             results = pose.process(cv2.cvtColor(person_roi, cv2.COLOR_BGR2RGB))
+             if results.pose_landmarks:
+                 person_roi = draw_landmarks(person_roi, results.pose_landmarks)
+             frame[int(top):int(bottom), int(left):int(right)] = person_roi
+             if draw_box:
+                 draw_bounding_box(frame, boxes[i], width, height)
+     return frame
+
+ def main():
+     st.markdown(
+         """
+         <style>
+         .title {
+             font-size: 36px;
+             color: #000000;
+             padding-bottom: 40px;
+             border-bottom: 4px solid #000000;
+         }
+         .intro {
+             font-size: 18px;
+             margin-top: 20px;
+             margin-bottom: 20px;
+         }
+         .upload-section {
+             background-color: #f0f0f0;
+             padding: 20px;
+             border-radius: 10px;
+             margin-bottom: 20px;
+         }
+         .button-primary {
+             background-color: #008CBA;
+             color: white;
+             font-weight: bold;
+             padding: 10px 20px;
+             border-radius: 5px;
+             transition: background-color 0.3s ease;
+             text-align: center;
+             display: inline-block;
+             cursor: pointer;
+         }
+         .button-primary:hover {
+             background-color: #005f7f;
+         }
+         </style>
+         """,
+         unsafe_allow_html=True
+     )
+
+     st.markdown("<p class='title'>Multi-Person Pose Estimation</p>", unsafe_allow_html=True)
+     st.markdown("<p class='intro'>Choose an operation type:</p>", unsafe_allow_html=True)
+
+     operation_type = st.radio("Choose operation type", ("Input", "Demo"))
+
+     if operation_type == "Input":
+         input_type = st.radio("Choose input type", ("Image", "Video"))
+
+         if input_type == "Image":
+             uploaded_file = st.file_uploader(
+                 "Upload an image file (.jpg, .jpeg, .png)",
+                 type=["jpg", "jpeg", "png"]
+             )
+         else:
+             uploaded_file = st.file_uploader(
+                 "Upload a video file (.mp4, .mov, .avi, .mkv)",
+                 type=["mp4", "mov", "avi", "mkv"]
+             )
+
+         draw_box = st.checkbox("Draw bounding box", value=False)
+
+         pose = mp_pose.Pose()
+
+         if uploaded_file is not None:
+             with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+                 temp_file.write(uploaded_file.read())
+                 file_path = temp_file.name
+
+             if input_type == "Video":
+                 cam = cv2.VideoCapture(file_path)
+                 st_frame = st.empty()
+
+                 while cam.isOpened():
+                     success, frame = cam.read()
+                     if not success:
+                         break
+
+                     frame = process_frame(frame, pose, draw_box)
+
+                     st_frame.image(frame, channels='BGR', use_column_width=True)
+
+                     # Ensure proper synchronization and frame display
+                     st_frame.empty()
+
+                 st.text("Completed")
+                 cam.release()
+
+             elif input_type == "Image":
+                 image = cv2.imread(file_path)
+                 processed_image = process_frame(image, pose, draw_box)
+
+                 st.image(processed_image, channels='BGR', use_column_width=True)
+
+             gc.collect()
+
+     elif operation_type == "Demo":
+         st.empty()
+         st.markdown("<p class='intro'>Demo video will be shown below:</p>", unsafe_allow_html=True)
+
+         demo_video_path = "Videos/video.mp4"
+         cam = cv2.VideoCapture(demo_video_path)
+         st_frame = st.empty()
+         pose = mp_pose.Pose()
+
+         while cam.isOpened():
+             success, frame = cam.read()
+             if not success:
+                 break
+
+             frame = process_frame(frame, pose, draw_box=False)
+
+             st_frame.image(frame, channels='BGR', use_column_width=True)
+             st_frame.empty()
+
+         st.text("Completed")
+         cam.release()
+         gc.collect()
+
+ if __name__ == "__main__":
+     main()