Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,6 +11,9 @@ import math
|
|
| 11 |
|
| 12 |
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
|
| 13 |
import av
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
## Build and Load Model
|
| 16 |
def attention_block(inputs, time_steps):
|
|
@@ -286,87 +289,148 @@ class VideoProcessor:
|
|
| 286 |
|
| 287 |
return output_frame
|
| 288 |
|
| 289 |
-
@st.cache()
|
| 290 |
-
def process(self, image):
|
| 291 |
-
|
| 292 |
-
|
| 293 |
|
| 294 |
-
|
| 295 |
-
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
|
| 321 |
-
|
| 322 |
-
|
| 323 |
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
|
| 328 |
-
|
| 329 |
-
|
| 330 |
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
|
| 348 |
-
|
| 349 |
-
|
| 350 |
|
| 351 |
-
def recv(self, frame):
|
| 352 |
-
|
| 353 |
-
|
|
|
|
|
|
|
|
|
|
| 354 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
Args:
|
| 356 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
Returns:
|
| 359 |
-
av.VideoFrame: processed video
|
| 360 |
"""
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
-
## Stream Webcam Video and Run Model
|
| 366 |
# Options
|
| 367 |
RTC_CONFIGURATION = RTCConfiguration(
|
| 368 |
{"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
|
| 369 |
)
|
|
|
|
| 370 |
# Streamer
|
| 371 |
webrtc_ctx = webrtc_streamer(
|
| 372 |
key="AI trainer",
|
|
|
|
| 11 |
|
| 12 |
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
|
| 13 |
import av
|
| 14 |
+
from io import BytesIO
|
| 15 |
+
import av
|
| 16 |
+
from PIL import Image
|
| 17 |
|
| 18 |
## Build and Load Model
|
| 19 |
def attention_block(inputs, time_steps):
|
|
|
|
| 289 |
|
| 290 |
return output_frame
|
| 291 |
|
| 292 |
+
# @st.cache()
|
| 293 |
+
# def process(self, image):
|
| 294 |
+
# """
|
| 295 |
+
# Function to process the video frame from the user's webcam and run the fitness trainer AI
|
| 296 |
|
| 297 |
+
# Args:
|
| 298 |
+
# image (numpy array): input image from the webcam
|
| 299 |
|
| 300 |
+
# Returns:
|
| 301 |
+
# numpy array: processed image with keypoint detection and fitness activity classification visualized
|
| 302 |
+
# """
|
| 303 |
+
# # Pose detection model
|
| 304 |
+
# image.flags.writeable = False
|
| 305 |
+
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 306 |
+
# results = pose.process(image)
|
| 307 |
|
| 308 |
+
# # Draw the pose landmarks on the image.
|
| 309 |
+
# image.flags.writeable = True
|
| 310 |
+
# image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
| 311 |
+
# self.draw_landmarks(image, results)
|
| 312 |
|
| 313 |
+
# # Prediction logic
|
| 314 |
+
# keypoints = self.extract_keypoints(results)
|
| 315 |
+
# self.sequence.append(keypoints.astype('float32',casting='same_kind'))
|
| 316 |
+
# self.sequence = self.sequence[-self.sequence_length:]
|
| 317 |
|
| 318 |
+
# if len(self.sequence) == self.sequence_length:
|
| 319 |
+
# res = model.predict(np.expand_dims(self.sequence, axis=0), verbose=0)[0]
|
| 320 |
+
# # interpreter.set_tensor(self.input_details[0]['index'], np.expand_dims(self.sequence, axis=0))
|
| 321 |
+
# # interpreter.invoke()
|
| 322 |
+
# # res = interpreter.get_tensor(self.output_details[0]['index'])
|
| 323 |
|
| 324 |
+
# self.current_action = self.actions[np.argmax(res)]
|
| 325 |
+
# confidence = np.max(res)
|
| 326 |
|
| 327 |
+
# # Erase current action variable if no probability is above threshold
|
| 328 |
+
# if confidence < self.threshold:
|
| 329 |
+
# self.current_action = ''
|
| 330 |
|
| 331 |
+
# # Viz probabilities
|
| 332 |
+
# image = self.prob_viz(res, image)
|
| 333 |
|
| 334 |
+
# # Count reps
|
| 335 |
+
# try:
|
| 336 |
+
# landmarks = results.pose_landmarks.landmark
|
| 337 |
+
# self.count_reps(
|
| 338 |
+
# image, landmarks, mp_pose)
|
| 339 |
+
# except:
|
| 340 |
+
# pass
|
| 341 |
|
| 342 |
+
# # Display graphical information
|
| 343 |
+
# cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
|
| 344 |
+
# cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
|
| 345 |
+
# cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
|
| 346 |
+
# cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
|
| 347 |
+
# cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
|
| 348 |
+
# cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
|
| 349 |
+
# cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
|
| 350 |
|
| 351 |
+
# # return cv2.flip(image, 1)
|
| 352 |
+
# return image
|
| 353 |
|
| 354 |
+
# def recv(self, frame):
|
| 355 |
+
# """
|
| 356 |
+
# Receive and process video stream from webcam
|
| 357 |
+
|
| 358 |
+
# Args:
|
| 359 |
+
# frame: current video frame
|
| 360 |
|
| 361 |
+
# Returns:
|
| 362 |
+
# av.VideoFrame: processed video frame
|
| 363 |
+
# """
|
| 364 |
+
# img = frame.to_ndarray(format="bgr24")
|
| 365 |
+
# img = self.process(img)
|
| 366 |
+
# return av.VideoFrame.from_ndarray(img, format="bgr24")
|
| 367 |
+
def process_uploaded_file(self, file):
    """
    Run the fitness trainer AI on an uploaded image or video file.

    Args:
        file (BytesIO): uploaded image or video file. It is treated as a
            video when it exposes a ``name`` ending in .mp4, .avi or .mov
            (any letter case); anything else is opened as a still image.

    Returns:
        list: one entry per decoded video frame (a single entry for an
            image), each being whatever ``self.process`` returns for it.
    """
    # NOTE(review): this relies on self.process(image); that method looks
    # commented out in this revision -- confirm it is still defined.
    video_extensions = ('.mp4', '.avi', '.mov')

    # str(...) guards against file objects whose ``name`` is missing or is
    # not a string; lower() makes ".MP4" and ".mp4" behave the same.
    file_name = str(getattr(file, 'name', ''))
    is_video = file_name.lower().endswith(video_extensions)

    if is_video:
        # The context manager closes the container even when decoding or
        # processing raises part-way through the stream.
        with av.open(file) as container:
            # Frames are decoded straight to BGR, the channel order the
            # processed frames are later wrapped in (format="bgr24").
            return [
                self.process(frame.to_ndarray(format="bgr24"))
                for frame in container.decode(video=0)
            ]

    # Still image: force three channels so RGBA / palette / grayscale
    # uploads do not break the colour conversion downstream.
    rgb_pixels = np.array(Image.open(file).convert("RGB"))

    # Reverse the channel axis (RGB -> BGR) and copy into a contiguous,
    # writeable buffer, since a reversed view has negative strides.
    bgr_pixels = np.ascontiguousarray(rgb_pixels[:, :, ::-1])

    return [self.process(bgr_pixels)]
|
| 409 |
+
|
| 410 |
+
def recv_uploaded_file(self, file):
    """
    Process an uploaded file and wrap every resulting frame for PyAV.

    Args:
        file (BytesIO): uploaded file, forwarded unchanged to
            ``process_uploaded_file``

    Returns:
        List[av.VideoFrame]: the processed frames, in the order they were
            produced, each built from its array with format "bgr24"
    """
    # Delegate decoding and inference, then convert each array in one pass.
    return [
        av.VideoFrame.from_ndarray(processed, format="bgr24")
        for processed in self.process_uploaded_file(file)
    ]
|
| 428 |
|
|
|
|
| 429 |
# Options
|
| 430 |
RTC_CONFIGURATION = RTCConfiguration(
|
| 431 |
{"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
|
| 432 |
)
|
| 433 |
+
|
| 434 |
# Streamer
|
| 435 |
webrtc_ctx = webrtc_streamer(
|
| 436 |
key="AI trainer",
|