Update app.py

app.py CHANGED
@@ -35,18 +35,7 @@ def attention_block(inputs, time_steps):

@st.cache(allow_output_mutation=True)
def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
-    """
-    Function used to build the deep neural network model on startup
-
-    Args:
-        HIDDEN_UNITS (int, optional): Number of hidden units for each neural network hidden layer. Defaults to 256.
-        sequence_length (int, optional): Input sequence length (i.e., number of frames). Defaults to 30.
-        num_input_values (int, optional): Input size of the neural network model. Defaults to 33*4 (i.e., number of keypoints x number of metrics).
-        num_classes (int, optional): Number of classification categories (i.e., model output size). Defaults to 3.
-
-    Returns:
-        keras model: neural network with pre-trained weights
-    """
+
    # Input
    inputs = Input(shape=(sequence_length, num_input_values))
    # Bi-LSTM
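Note: the commit drops `build_model`'s docstring but not the architecture — a cached Keras model whose elided body runs the 30-frame, 33*4-feature keypoint sequence through a Bi-LSTM and the `attention_block` named in the hunk header. For orientation, a minimal sketch of such a model, assuming a standard soft-attention block; everything beyond the signature and the two lines visible in the diff is illustrative, not this app's actual code:

```python
import tensorflow.keras.backend as K
from tensorflow.keras.layers import (LSTM, Bidirectional, Dense, Input,
                                     Lambda, Multiply, Permute)
from tensorflow.keras.models import Model

def attention_block(inputs, time_steps):
    # Soft attention over time steps (illustrative stand-in for the app's own
    # attention_block): score each frame, softmax, reweight, pool over time.
    a = Permute((2, 1))(inputs)                     # (batch, features, time)
    a = Dense(time_steps, activation="softmax")(a)  # per-feature frame weights
    a = Permute((2, 1))(a)                          # back to (batch, time, features)
    weighted = Multiply()([inputs, a])
    return Lambda(lambda x: K.sum(x, axis=1))(weighted)  # (batch, features)

def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
    inputs = Input(shape=(sequence_length, num_input_values))
    # Bi-LSTM over the keypoint sequence, as the diff's comment indicates
    x = Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True))(inputs)
    x = attention_block(x, sequence_length)
    outputs = Dense(num_classes, activation="softmax")(x)
    return Model(inputs=inputs, outputs=outputs)
```

The `@st.cache(allow_output_mutation=True)` decorator keeps Streamlit from rebuilding (and re-loading weights into) the model on every rerun of the script.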
@@ -70,24 +59,10 @@ def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):

HIDDEN_UNITS = 256
model = build_model(HIDDEN_UNITS)
-
-## App
-st.write("# AI Personal Fitness Trainer Web App")
-
-st.markdown("**Development Note**")
-st.markdown("Currently, the exercise recognition model uses the x, y, and z coordinates of each anatomical landmark from the MediaPipe Pose model. These coordinates are normalized with respect to the image frame (e.g., the top left corner represents (x=0,y=0) and the bottom right corner represents (x=1,y=1)).")
-st.markdown("I'm currently developing and testing two new feature engineering strategies:")
-st.markdown("- Normalizing coordinates by the detected bounding box of the user")
-st.markdown("- Using joint angles rather than keypoint coordinates as features")
-st.write("Stay Tuned!")
-
-st.write("## Settings")
threshold1 = st.slider("Minimum Keypoint Detection Confidence", 0.00, 1.00, 0.50)
threshold2 = st.slider("Minimum Tracking Confidence", 0.00, 1.00, 0.50)
threshold3 = st.slider("Minimum Activity Classification Confidence", 0.00, 1.00, 0.50)

-st.write("## Activate the AI 🤖🏋️‍♂️")
-
## Mediapipe
mp_pose = mp.solutions.pose # Pre-trained pose estimation model from Google Mediapipe
mp_drawing = mp.solutions.drawing_utils # Supported Mediapipe visualization tools
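Note: the three sliders feed two different consumers. The first two correspond to the standard `min_detection_confidence` and `min_tracking_confidence` arguments of MediaPipe's `Pose` constructor, while `threshold3` gates the activity classifier's softmax output downstream. A minimal sketch of the wiring — the `Pose()` call is real MediaPipe API; the surrounding lines mirror the diff:

```python
import mediapipe as mp
import streamlit as st

threshold1 = st.slider("Minimum Keypoint Detection Confidence", 0.00, 1.00, 0.50)
threshold2 = st.slider("Minimum Tracking Confidence", 0.00, 1.00, 0.50)
threshold3 = st.slider("Minimum Activity Classification Confidence", 0.00, 1.00, 0.50)

mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
    min_detection_confidence=threshold1,  # keypoint detection cutoff
    min_tracking_confidence=threshold2,   # landmark tracking cutoff
)
# threshold3 is applied later, to the classifier's predicted probabilities,
# not to MediaPipe itself.
```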
@@ -182,8 +157,62 @@ class VideoProcessor:
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
        )
        return
-
    @st.cache()
+    def process_video(self, video_file):
+        """
+        Processes each frame of the input video, performs pose estimation,
+        and counts repetitions of each exercise.
+
+        Args:
+            video_file (BytesIO): Input video file.
+
+        Returns:
+            tuple: A tuple containing the processed video frames with annotations
+            and the final count of repetitions for each exercise.
+        """
+        cap = cv2.VideoCapture(video_file)
+        out_frames = []
+        # Initialize repetition counters
+        self.curl_counter = 0
+        self.press_counter = 0
+        self.squat_counter = 0
+
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            # Convert frame to RGB (Mediapipe requires RGB input)
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+            # Pose estimation
+            results = pose.process(frame_rgb)
+
+            # Draw landmarks
+            self.draw_landmarks(frame, results)
+
+            # Extract keypoints
+            keypoints = self.extract_keypoints(results)
+
+            # Count repetitions
+            self.count_reps(frame, results.pose_landmarks, mp_pose)
+
+            # Visualize probabilities
+            if len(self.sequence) == self.sequence_length:
+                sequence = np.array([self.sequence])
+                res = model.predict(sequence)
+                frame = self.prob_viz(res[0], frame)
+
+            # Append frame to output frames
+            out_frames.append(frame)
+
+        # Release video capture
+        cap.release()
+
+        # Return annotated frames and repetition counts
+        return out_frames, {'curl': self.curl_counter, 'press': self.press_counter, 'squat': self.squat_counter}
+    @st.cache()
+
    def count_reps(self, image, landmarks, mp_pose):
        """
        Counts repetitions of each exercise. Global count and stage (i.e., state) variables are updated within this function.
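Two caveats on the added `process_video`. First, `cv2.VideoCapture()` accepts a filesystem path or camera index, not the BytesIO the docstring promises, so a Streamlit upload would typically be spooled to a temporary file first. Second, `self.sequence` is only read here; the `if len(self.sequence) == self.sequence_length:` branch can only fire if the extracted `keypoints` are appended to `self.sequence` (and trimmed to the last `sequence_length` entries) somewhere, and that bookkeeping is not visible in this hunk. A minimal sketch of the file handling, with `open_uploaded_video` a hypothetical helper name and only standard-library/OpenCV calls:

```python
import tempfile

import cv2

def open_uploaded_video(uploaded_file):
    # Hypothetical helper: write the uploaded BytesIO to a temp file so that
    # cv2.VideoCapture(), which wants a path, can open it.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    tmp.write(uploaded_file.read())
    tmp.close()
    return cv2.VideoCapture(tmp.name)
```

Separately, `@st.cache()` on instance methods that mutate `self` (the rep counters) is likely to behave surprisingly; Streamlit's cache keys on the function arguments, which here include the processor instance itself.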
@@ -288,155 +317,82 @@ class VideoProcessor:
            cv2.putText(output_frame, self.actions[num], (0, 85+num*40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)

        return output_frame
-
-# @st.cache()
-# def process(self, image):
-#     """
-#     Function to process the video frame from the user's webcam and run the fitness trainer AI
-
-#     Args:
-#         image (numpy array): input image from the webcam
-
-#     Returns:
-#         numpy array: processed image with keypoint detection and fitness activity classification visualized
-#     """
-#     # Pose detection model
-#     image.flags.writeable = False
-#     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-#     results = pose.process(image)

-#     # Draw the hand annotations on the image.
-#     image.flags.writeable = True
-#     image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-#     self.draw_landmarks(image, results)

-
-
-
-
-
-
-
-
-
-
-
-#     self.current_action = self.actions[np.argmax(res)]
-#     confidence = np.max(res)
-
-#     # Erase current action variable if no probability is above threshold
-#     if confidence < self.threshold:
-#         self.current_action = ''
+video_processor.process_video_input(threshold1, threshold2, threshold3)
+# def process_uploaded_file(self, file):
+#     """
+#     Function to process an uploaded image or video file and run the fitness trainer AI
+#     Args:
+#         file (BytesIO): uploaded image or video file
+#     Returns:
+#         numpy array: processed image with keypoint detection and fitness activity classification visualized
+#     """
+#     # Initialize an empty list to store processed frames
+#     processed_frames = []

-
-
-#     # Count reps
-#     try:
-#         landmarks = results.pose_landmarks.landmark
-#         self.count_reps(
-#             image, landmarks, mp_pose)
-#     except:
-#         pass
-
-#     # Display graphical information
-#     cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
-#     cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
-#                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-#     cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
-#                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-#     cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
-#                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-
-#     # return cv2.flip(image, 1)
-#     return image
-
-# def recv(self, frame):
-#     """
-#     Receive and process video stream from webcam
+# # Check if the uploaded file is a video
+# is_video = hasattr(file, 'name') and file.name.endswith(('.mp4', '.avi', '.mov'))

-
-
-
-
-
-#     img = frame.to_ndarray(format="bgr24")
-#     img = self.process(img)
-#     return av.VideoFrame.from_ndarray(img, format="bgr24")
-    def process_uploaded_file(self, file):
-        """
-        Function to process an uploaded image or video file and run the fitness trainer AI
-        Args:
-            file (BytesIO): uploaded image or video file
-        Returns:
-            numpy array: processed image with keypoint detection and fitness activity classification visualized
-        """
-        # Initialize an empty list to store processed frames
-        processed_frames = []
-
-        # Check if the uploaded file is a video
-        is_video = hasattr(file, 'name') and file.name.endswith(('.mp4', '.avi', '.mov'))
-
-        if is_video:
-            container = av.open(file)
-            for frame in container.decode(video=0):
-                # Convert the frame to OpenCV format
-                image = frame.to_image().convert("RGB")
-                image = np.array(image)
+# if is_video:
+#     container = av.open(file)
+#     for frame in container.decode(video=0):
+#         # Convert the frame to OpenCV format
+#         image = frame.to_image().convert("RGB")
+#         image = np.array(image)

-                # Process the frame
-                processed_frame = self.process(image)
+#         # Process the frame
+#         processed_frame = self.process(image)

-                # Append the processed frame to the list
-                processed_frames.append(processed_frame)
+#         # Append the processed frame to the list
+#         processed_frames.append(processed_frame)

-            # Close the video file container
-            container.close()
-        else:
-            # If the uploaded file is an image
-            # Load the image from the BytesIO object
-            image = Image.open(file)
-            image = np.array(image)
+#     # Close the video file container
+#     container.close()
+# else:
+#     # If the uploaded file is an image
+#     # Load the image from the BytesIO object
+#     image = Image.open(file)
+#     image = np.array(image)

-            # Process the image
-            processed_frame = self.process(image)
+#     # Process the image
+#     processed_frame = self.process(image)

-            # Append the processed frame to the list
-            processed_frames.append(processed_frame)
+#     # Append the processed frame to the list
+#     processed_frames.append(processed_frame)

-        return processed_frames
+# return processed_frames

-    def recv_uploaded_file(self, file):
-        """
-        Receive and process an uploaded video file
-        Args:
-            file (BytesIO): uploaded video file
-        Returns:
-            List[av.VideoFrame]: list of processed video frames
-        """
-        # Process the uploaded file
-        processed_frames = self.process_uploaded_file(file)
+# def recv_uploaded_file(self, file):
+#     """
+#     Receive and process an uploaded video file
+#     Args:
+#         file (BytesIO): uploaded video file
+#     Returns:
+#         List[av.VideoFrame]: list of processed video frames
+#     """
+#     # Process the uploaded file
+#     processed_frames = self.process_uploaded_file(file)

-        # Convert processed frames to av.VideoFrame objects
-        av_frames = []
-        for frame in processed_frames:
-            av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
-            av_frames.append(av_frame)
+# # Convert processed frames to av.VideoFrame objects
+# av_frames = []
+# for frame in processed_frames:
+#     av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
+#     av_frames.append(av_frame)

-        return av_frames
+# return av_frames

-# Options
-RTC_CONFIGURATION = RTCConfiguration(
-    {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
-)
+# # Options
+# RTC_CONFIGURATION = RTCConfiguration(
+#     {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
+# )

-# Streamer
-webrtc_ctx = webrtc_streamer(
-    key="AI trainer",
-    mode=WebRtcMode.SENDRECV,
-    rtc_configuration=RTC_CONFIGURATION,
-    media_stream_constraints={"video": True, "audio": False},
-    video_processor_factory=VideoProcessor,
-    async_processing=True,
-)
+# # Streamer
+# webrtc_ctx = webrtc_streamer(
+#     key="AI trainer",
+#     mode=WebRtcMode.SENDRECV,
+#     rtc_configuration=RTC_CONFIGURATION,
+#     media_stream_constraints={"video": True, "audio": False},
+#     video_processor_factory=VideoProcessor,
+#     async_processing=True,
+# )
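This last hunk swaps the live WebRTC path (streamlit-webrtc's `webrtc_streamer`, now commented out) for an upload-and-process flow. One thing worth verifying: the new module-level call is `video_processor.process_video_input(threshold1, threshold2, threshold3)`, but the method added in this commit is named `process_video` and takes a video file, so unless `process_video_input` is defined elsewhere in the file this looks like a name mismatch. Since `process_video` returns `(frames, counts)`, the app still has to render them; a sketch of one way to do that, where `show_processed_frames` is a hypothetical helper and the codec/fps are illustrative choices:

```python
import cv2
import streamlit as st

def show_processed_frames(frames, counts, fps=30):
    # Hypothetical display helper for the (frames, counts) pair returned by
    # process_video: stitch the annotated frames into a clip and show it.
    if not frames:
        st.warning("No frames were processed.")
        return
    height, width = frames[0].shape[:2]
    writer = cv2.VideoWriter("annotated.mp4",
                             cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    for frame in frames:
        writer.write(frame)  # frames are BGR ndarrays straight from cv2
    writer.release()
    st.video("annotated.mp4")
    st.write(counts)  # e.g. {'curl': 10, 'press': 8, 'squat': 12}
```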