Spaces:
Sleeping
Sleeping
import threading
from dataclasses import dataclass
from typing import List, Optional

import av
import cv2
import mediapipe as mp
import numpy as np
import streamlit as st
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase
# MediaPipe helpers: drawing utilities for the skeleton overlay, and the Pose solution
# whose landmarks drive the rep counting.
mp_drawing, mp_pose = mp.solutions.drawing_utils, mp.solutions.pose
# Page-wide styling, injected as raw HTML (hence unsafe_allow_html=True).
# It sets the background gradient, button colours, heading colour, and the
# .workout-container / .feedback-text panels that main() renders into.
_CUSTOM_CSS = """
<style>
.main {
background: linear-gradient(135deg, #001f3f 0%, #00b4d8 100%);
}
.stButton > button {
background-color: #00b4d8;
color: white;
border: none;
padding: 0.5rem 2rem;
border-radius: 5px;
margin: 0.5rem;
transition: all 0.3s;
}
.stButton > button:hover {
background-color: #0077b6;
}
h1, h2, h3 {
color: #001f3f;
}
.workout-container {
background: rgba(0, 180, 216, 0.1);
padding: 2rem;
border-radius: 10px;
margin: 1rem 0;
}
.feedback-text {
background: rgba(0, 31, 63, 0.1);
padding: 1rem;
border-radius: 5px;
margin: 1rem 0;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@dataclass
class ExerciseState:
    """Rep-tracking state for the current workout.

    One instance is shared between the video-processing code, which updates it,
    and the page, which displays and resets it. Callers guard access with a lock.
    """

    counter: int = 0  # number of completed reps
    stage: Optional[str] = None  # last position reached: "down", "up", or None before the first "down"
    feedback: str = ""  # latest message shown to the user
# Shared exercise state: written by the video-processing thread (VideoTransformer)
# and read or reset by the Streamlit script thread (main()). All access goes
# through ``lock``.
#
# Streamlit re-executes this whole script on every rerun, so plain module-level
# objects would be recreated each time. The transformer instance created by
# webrtc_streamer outlives those reruns and keeps using the objects from the run
# that defined it. Storing both objects in st.session_state makes every rerun of
# a session bind the SAME objects. That keeps the displayed counter and the reset
# on workout change in sync with what the transformer updates.
if "exercise_state" not in st.session_state:
    st.session_state.exercise_state = ExerciseState()
if "exercise_lock" not in st.session_state:
    st.session_state.exercise_lock = threading.Lock()
state = st.session_state.exercise_state
lock = st.session_state.exercise_lock
def calculate_angle(a, b, c):
    """Return the interior angle ABC in degrees, folded into the range [0, 180].

    ``b`` is the vertex. Each point is an (x, y) pair.
    """
    first, vertex, last = (np.array(point) for point in (a, b, c))
    toward_last = np.arctan2(last[1] - vertex[1], last[0] - vertex[0])
    toward_first = np.arctan2(first[1] - vertex[1], first[0] - vertex[0])
    degrees = np.abs(np.degrees(toward_last - toward_first))
    # The raw difference can describe the reflex side; report the smaller angle.
    return 360 - degrees if degrees > 180.0 else degrees
def calculate_lateral_raise_angle(shoulder, wrist, reference=(0.0, 1.0)):
    """Return the arm-raise angle in degrees for a lateral raise.

    The angle is measured between the shoulder->wrist vector and ``reference``.
    The default reference (0, 1) points straight down in image coordinates,
    where y grows downward. With that reference:

    - an arm hanging at the side gives ~0 deg;
    - an arm raised to shoulder height on either side gives ~90 deg;
    - an arm straight overhead gives ~180 deg.

    This is the scale that the lateral-raise thresholds (< 20 = down,
    70-110 = up) assume. Measuring from the horizontal instead would report
    ~90 deg for a hanging arm and invert those thresholds.

    Args:
        shoulder: (x, y) position of the shoulder.
        wrist: (x, y) position of the wrist.
        reference: Direction the angle is measured from. Pass (1, 0) to measure
            from the image horizontal instead.

    Returns:
        Angle in degrees in [0, 180], or 0 when either vector has zero length.
    """
    reference_vector = np.asarray(reference, dtype=float)
    arm_vector = np.array([wrist[0] - shoulder[0], wrist[1] - shoulder[1]], dtype=float)
    magnitude_reference = np.linalg.norm(reference_vector)
    magnitude_arm = np.linalg.norm(arm_vector)
    if magnitude_arm == 0 or magnitude_reference == 0:
        return 0
    cos_angle = np.dot(reference_vector, arm_vector) / (magnitude_reference * magnitude_arm)
    # Clip guards arccos against rounding slightly outside [-1, 1].
    return np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))
class VideoTransformer(VideoTransformerBase):
    """Video processor that counts exercise reps from MediaPipe Pose landmarks.

    Only the left arm (LEFT_SHOULDER, LEFT_ELBOW, LEFT_WRIST) is tracked. Rep
    counting updates the module-level ``state`` while holding ``lock``.
    ``recv`` draws the pose skeleton and overlays the current angle, rep count
    and feedback on each frame.
    """

    def __init__(self):
        super().__init__()
        self.pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
        # Which exercise to score; main() overwrites this from the workout selectbox.
        self.workout_type = "bicep_curl"  # Default workout

    @staticmethod
    def _landmark_xy(landmarks, part):
        """Return ``[x, y]`` of pose landmark ``part`` (a ``PoseLandmark`` member)."""
        point = landmarks[part.value]
        return [point.x, point.y]

    @staticmethod
    def _advance_stage(at_bottom, at_top, bottom_feedback):
        """Advance the shared down -> up rep state machine in ``state`` by one frame.

        Reaching the bottom position (from any other stage) sets stage "down"
        and shows ``bottom_feedback``. Reaching the top position while in "down"
        sets stage "up" and counts one rep.
        """
        with lock:
            if at_bottom and state.stage != "down":
                state.stage = "down"
                state.feedback = bottom_feedback
            elif at_top and state.stage == "down":
                state.stage = "up"
                state.counter += 1
                state.feedback = f"Good rep! Count: {state.counter}"

    def process_bicep_curl(self, landmarks):
        """Score one frame of a bicep curl; return the left elbow angle in degrees.

        Bottom: elbow opened past 160 deg. Top: elbow closed below 40 deg.
        """
        shoulder = self._landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_SHOULDER)
        elbow = self._landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_ELBOW)
        wrist = self._landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_WRIST)
        angle = calculate_angle(shoulder, elbow, wrist)
        self._advance_stage(angle > 160, angle < 40, "Lower the weight")
        return angle

    def process_lateral_raise(self, landmarks):
        """Score one frame of a lateral raise; return the left arm-raise angle in degrees.

        Bottom: angle below 20 deg. Top: angle between 70 and 110 deg inclusive.
        """
        shoulder = self._landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_SHOULDER)
        wrist = self._landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_WRIST)
        angle = calculate_lateral_raise_angle(shoulder, wrist)
        self._advance_stage(angle < 20, 70 <= angle <= 110, "Raise your arms")
        return angle

    def process_shoulder_press(self, landmarks):
        """Score one frame of a shoulder press; return the left elbow angle in degrees.

        Bottom: elbow between 80 and 100 deg inclusive. Top: elbow past 160 deg.
        """
        shoulder = self._landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_SHOULDER)
        elbow = self._landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_ELBOW)
        wrist = self._landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_WRIST)
        angle = calculate_angle(shoulder, elbow, wrist)
        self._advance_stage(80 <= angle <= 100, angle > 160, "Press up!")
        return angle

    def recv(self, frame):
        """Run pose estimation on one frame and return it with the annotations drawn."""
        img = frame.to_ndarray(format="bgr24")
        # MediaPipe is given an RGB copy; annotations are drawn on the original BGR frame.
        results = self.pose.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(
                img, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2),
                mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=2),
            )
            landmarks = results.pose_landmarks.landmark
            workout = self.workout_type
            if workout == "bicep_curl":
                angle = self.process_bicep_curl(landmarks)
            elif workout == "lateral_raise":
                angle = self.process_lateral_raise(landmarks)
            else:  # shoulder_press (and any unrecognised workout_type)
                angle = self.process_shoulder_press(landmarks)
            # Snapshot under the lock so counter and feedback come from the same update.
            with lock:
                counter, feedback = state.counter, state.feedback
            overlay = (f"Angle: {angle:.2f}", f"Counter: {counter}", f"Feedback: {feedback}")
            for row, text in enumerate(overlay):
                cv2.putText(img, text, (10, 30 + 40 * row),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        return av.VideoFrame.from_ndarray(img, format="bgr24")
def main():
    """Render the trainer page: workout picker, description, webcam stream and rep panel."""
    st.title("🏋️♂️ AI Workout Trainer")
    st.markdown("""
<div class='workout-container'>
Welcome to your AI Workout Trainer! This app will help you perfect your form
and track your exercises in real-time. Choose a workout and follow the feedback
to improve your technique.
</div>
""", unsafe_allow_html=True)

    # Display label -> workout_type key understood by VideoTransformer.
    workout_options = {
        "Bicep Curl": "bicep_curl",
        "Lateral Raise": "lateral_raise",
        "Shoulder Press": "shoulder_press",
    }
    selected_workout = st.selectbox(
        "Choose your workout:",
        list(workout_options.keys()),
    )

    # Reset the shared counter whenever the user switches workouts, so reps and
    # stage from the previous exercise are not carried over.
    if st.session_state.get("last_workout") != selected_workout:
        with lock:
            state.counter = 0
            state.stage = None
            state.feedback = ""
        st.session_state.last_workout = selected_workout

    descriptions = {
        "Bicep Curl": "Focus on keeping your upper arm still and curl the weight up smoothly.",
        "Lateral Raise": "Raise your arms to shoulder height, keeping them slightly bent.",
        "Shoulder Press": "Press the weight overhead, fully extending your arms.",
    }
    st.markdown(f"""
<div class='workout-container'>
<h3>{selected_workout}</h3>
<p>{descriptions[selected_workout]}</p>
</div>
""", unsafe_allow_html=True)

    webrtc_ctx = webrtc_streamer(
        key="workout",
        video_transformer_factory=VideoTransformer,
        rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
    )
    # The transformer exists only once the stream has started; keep its workout in sync.
    if webrtc_ctx.video_transformer:
        webrtc_ctx.video_transformer.workout_type = workout_options[selected_workout]

    # This panel is rendered once per script run, so it reflects reps counted up
    # to the latest rerun; the live count is drawn on the video overlay.
    feedback_placeholder = st.empty()
    if webrtc_ctx.state.playing:
        # Read under the same lock the video thread writes with, so counter and
        # feedback come from one consistent update.
        with lock:
            counter, feedback = state.counter, state.feedback
        feedback_placeholder.markdown(f"""
<div class='feedback-text'>
<h4>Current Exercise: {selected_workout}</h4>
<p>Reps Completed: {counter}</p>
<p>Feedback: {feedback}</p>
</div>
""", unsafe_allow_html=True)


if __name__ == "__main__":
    main()