import base64
import io
import logging
import os
import sys
import threading
import time
import traceback
import uuid

import cv2
import numpy as np
from flask import Flask, render_template, Response, request, jsonify, session, redirect, url_for
from flask_cors import CORS
from PIL import Image
|
|
| |
# Configure the root logger once for the whole process: DEBUG level, a
# timestamped "time - logger - level - message" layout, and a console
# StreamHandler (stderr by default).
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
|
|
| |
# The pose-estimation, exercise-tracking and overlay helpers are mandatory for
# the camera and API code below, so the process exits if any import fails.
try:
    from pose_estimation.estimation import PoseEstimator
    from exercises.squat import Squat
    from exercises.hammer_curl import HammerCurl
    from exercises.push_up import PushUp
    from feedback.information import get_exercise_info
    from feedback.layout import layout_indicators
    from utils.draw_text_with_background import draw_text_with_background
except ImportError as import_error:
    logger.error(f"Failed to import required modules: {import_error}")
    traceback.print_exc()
    sys.exit(1)
else:
    logger.info("Successfully imported pose estimation modules")
|
|
| |
class DummyWorkoutLogger:
    """No-op stand-in used when the real ``WorkoutLogger`` is unavailable.

    Provides the methods this module calls on ``workout_logger`` and returns
    empty results: ``log_workout`` discards its input and the dashboard
    renders with zeroed statistics.
    """

    def __init__(self):
        pass

    def log_workout(self, *args, **kwargs):
        return {}

    def get_recent_workouts(self, *args, **kwargs):
        return []

    def get_weekly_stats(self, *args, **kwargs):
        return {}

    def get_exercise_distribution(self, *args, **kwargs):
        return {}

    def get_user_stats(self, *args, **kwargs):
        return {'total_workouts': 0, 'total_exercises': 0, 'streak_days': 0}


# Workout history is optional. Fall back to the dummy both when the db module
# cannot be imported and when constructing WorkoutLogger itself fails (for
# example, its storage cannot be opened); previously only ImportError was
# caught, so a constructor failure crashed the whole app at import time.
try:
    from db.workout_logger import WorkoutLogger
    workout_logger = WorkoutLogger()
except ImportError:
    logger.warning("WorkoutLogger import failed, creating dummy class")
    workout_logger = DummyWorkoutLogger()
except Exception:
    logger.exception("WorkoutLogger initialization failed, falling back to dummy logger")
    workout_logger = DummyWorkoutLogger()
else:
    logger.info("Successfully initialized workout logger")
|
|
logger.info("Setting up Flask application")
app = Flask(__name__)
# The secret key signs Flask session cookies, so anyone who knows it can forge
# them. Read it from the environment; the previously hard-coded value is kept
# only as a development fallback and its use is logged.
app.secret_key = os.environ.get('FLASK_SECRET_KEY')
if not app.secret_key:
    logger.warning("FLASK_SECRET_KEY is not set; using the insecure built-in development key")
    app.secret_key = 'fitness_trainer_secret_key'
# NOTE(review): CORS allows every origin; consider restricting `origins` in deployments.
CORS(app, origins="*", methods=["GET", "POST", "OPTIONS"], allow_headers=["Content-Type", "Authorization"])
|
|
# Shared state for the /api/* endpoints.
MAX_SESSIONS = 100              # cap on concurrently tracked API sessions
active_exercise_sessions = {}   # session_id -> exercise tracker instance
pose_estimator_api = None       # created lazily by get_pose_estimator_api_instance()
|
|
def get_pose_estimator_api_instance():
    """Return the PoseEstimator shared by the JSON API, building it on first use."""
    global pose_estimator_api
    if pose_estimator_api is not None:
        return pose_estimator_api
    logger.info("Initializing PoseEstimator for API")
    pose_estimator_api = PoseEstimator()
    return pose_estimator_api
|
|
| |
# Webcam-mode state shared between the request handlers and every running
# generate_frames() generator.
camera = output_frame = None                  # cv2.VideoCapture / last annotated frame
lock = threading.Lock()                       # guards writes to output_frame
exercise_running = False
current_exercise = current_exercise_data = None
exercise_counter = exercise_goal = 0          # reps in the current set / reps per set
sets_completed = sets_goal = 0
workout_start_time = None                     # time.time() when the workout started
|
|
def initialize_camera():
    """Open the default webcam (device 0) on first use and return it.

    Subsequent calls return the same capture object until release_camera()
    clears it.
    """
    global camera
    if camera is not None:
        return camera
    camera = cv2.VideoCapture(0)
    return camera
|
|
def release_camera():
    """Release the open webcam, if any, and reset the shared handle to None."""
    global camera
    if camera is None:
        return
    camera.release()
    camera = None
|
|
def generate_frames(idle_sleep_seconds=0.01):
    """Yield the webcam feed as MJPEG parts, annotated with workout progress.

    Each iteration reads one frame from the shared ``camera``. While an
    exercise is running, the pose is estimated, the matching tracker updates
    the rep count, and the HUD (exercise name, goals, current set) plus
    set/workout completion banners are drawn onto the frame. The frame is
    stored in ``output_frame`` and yielded as one ``--frame`` part of the
    ``multipart/x-mixed-replace`` response built by ``video_feed``.

    Rep and set progress is written back to the module-level counters so
    ``/get_status`` and ``/stop_exercise`` can read it. Finishing the last set
    sets ``exercise_running`` to False.

    Args:
        idle_sleep_seconds: delay before retrying when no camera is open or a
            frame could not be read, so the loop does not spin a CPU core.
    """
    global output_frame, lock, exercise_running, current_exercise, current_exercise_data
    global exercise_counter, exercise_goal, sets_completed, sets_goal

    pose_estimator = PoseEstimator()
    text_colour = (255, 255, 255)
    hud_background = (118, 29, 14)

    while True:
        # Read the global once: /stop_exercise may release the camera (setting
        # it to None) between a None check and the read() call.
        cam = camera
        if cam is None:
            time.sleep(idle_sleep_seconds)
            continue

        success, frame = cam.read()
        if not success:
            time.sleep(idle_sleep_seconds)
            continue

        # Snapshot the tracker and its metadata so one frame works on a
        # consistent pair even if /start_exercise replaces them meanwhile.
        exercise = current_exercise
        exercise_data = current_exercise_data

        if exercise_running and exercise and exercise_data:
            try:
                exercise_type = exercise_data['type']
                results = pose_estimator.estimate_pose(frame, exercise_type)

                if results.pose_landmarks:
                    landmarks = results.pose_landmarks.landmark

                    if exercise_type == "squat":
                        counter, angle, stage = exercise.track_squat(landmarks, frame)
                        layout_indicators(frame, exercise_type, (counter, angle, stage))
                        exercise_counter = counter

                    elif exercise_type == "push_up":
                        counter, angle, stage = exercise.track_push_up(landmarks, frame)
                        layout_indicators(frame, exercise_type, (counter, angle, stage))
                        exercise_counter = counter

                    elif exercise_type == "hammer_curl":
                        (counter_right, angle_right, counter_left, angle_left,
                         warning_message_right, warning_message_left, progress_right,
                         progress_left, stage_right, stage_left) = exercise.track_hammer_curl(
                            landmarks, frame)
                        layout_indicators(frame, exercise_type,
                                          (counter_right, angle_right, counter_left, angle_left,
                                           warning_message_right, warning_message_left,
                                           progress_right, progress_left, stage_right, stage_left))
                        # Set progress follows whichever arm has done more reps.
                        exercise_counter = max(counter_right, counter_left)

                    exercise_info = get_exercise_info(exercise_type)
                    hud_lines = (
                        f"Exercise: {exercise_info.get('name', 'N/A')}",
                        f"Reps Goal: {exercise_goal}",
                        f"Sets Goal: {sets_goal}",
                        f"Current Set: {sets_completed + 1}",
                    )
                    for row, text in enumerate(hud_lines):
                        draw_text_with_background(frame, text, (40, 50 + 30 * row),
                                                  cv2.FONT_HERSHEY_DUPLEX, 0.7, text_colour,
                                                  hud_background, 1)

                    # With exercise_goal <= 0 the check "counter >= goal" would be
                    # true on every frame and complete a set per frame.
                    if exercise_goal > 0 and exercise_counter >= exercise_goal:
                        sets_completed += 1
                        exercise_counter = 0

                        # Reset the tracker's own counters so the next set starts at 0.
                        if exercise_type in ("squat", "push_up"):
                            exercise.counter = 0
                        elif exercise_type == "hammer_curl":
                            exercise.counter_right = 0
                            exercise.counter_left = 0

                        centre_x = frame.shape[1] // 2
                        centre_y = frame.shape[0] // 2
                        if sets_completed >= sets_goal:
                            exercise_running = False
                            draw_text_with_background(frame, "WORKOUT COMPLETE!",
                                                      (centre_x - 150, centre_y),
                                                      cv2.FONT_HERSHEY_DUPLEX, 1.2, text_colour,
                                                      (0, 200, 0), 2)
                        else:
                            draw_text_with_background(frame,
                                                      f"SET {sets_completed} COMPLETE! Rest for 30 sec",
                                                      (centre_x - 200, centre_y),
                                                      cv2.FONT_HERSHEY_DUPLEX, 1.0, text_colour,
                                                      (0, 0, 200), 2)
            except Exception:
                # An exception escaping the generator would terminate this
                # client's /video_feed response; log it and keep streaming.
                logger.exception("Error while processing exercise frame")
        else:
            cv2.putText(frame, "Select an exercise to begin",
                        (frame.shape[1] // 2 - 150, frame.shape[0] // 2),
                        cv2.FONT_HERSHEY_DUPLEX, 0.8, (255, 255, 255), 1)

        with lock:
            output_frame = frame.copy()

        # Encode this generator's own frame: every /video_feed request runs its
        # own loop, so the shared output_frame may already hold another
        # client's frame by the time it would be read here.
        encoded, buffer = cv2.imencode('.jpg', frame)
        if not encoded:
            continue
        yield (b'--frame\r\n'
               b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')
|
|
@app.route('/')
def index():
    """Render the landing page where the user picks an exercise.

    Returns a plain-text 500 response describing the error if the template
    cannot be rendered.
    """
    logger.info("Rendering index page")
    try:
        page = render_template('index.html')
    except Exception as err:
        logger.error(f"Error rendering index: {err}")
        return f"Error rendering template: {str(err)}", 500
    return page
|
|
@app.route('/dashboard')
def dashboard():
    """Render the dashboard with recent workouts and aggregate statistics.

    All data comes from ``workout_logger``, which may be the no-op dummy.
    Returns the rendered page, or a plain-text 500 response if loading the
    statistics or rendering fails.
    """
    logger.info("Rendering dashboard page")
    try:
        recent_workouts = workout_logger.get_recent_workouts(5)
        weekly_stats = workout_logger.get_weekly_stats()
        exercise_distribution = workout_logger.get_exercise_distribution()
        user_stats = workout_logger.get_user_stats()

        formatted_workouts = []
        for workout in recent_workouts:
            # int() also accepts float durations, which would otherwise make
            # the ":02d" format below raise and fail the whole page.
            minutes, seconds = divmod(int(workout['duration_seconds']), 60)
            formatted_workouts.append({
                'date': workout['date'],
                'exercise': workout['exercise_type'].replace('_', ' ').title(),
                'sets': workout['sets'],
                'reps': workout['reps'],
                'duration': f"{minutes}:{seconds:02d}",
            })

        weekly_workout_count = sum(day['workout_count'] for day in weekly_stats.values())

        # exercise_distribution was previously fetched but never used; it is
        # now passed through so the template can render it.
        return render_template('dashboard.html',
                               recent_workouts=formatted_workouts,
                               weekly_workouts=weekly_workout_count,
                               total_workouts=user_stats['total_workouts'],
                               total_exercises=user_stats['total_exercises'],
                               streak_days=user_stats['streak_days'],
                               exercise_distribution=exercise_distribution)
    except Exception as e:
        logger.error(f"Error in dashboard: {e}")
        traceback.print_exc()
        return f"Error loading dashboard: {str(e)}", 500
|
|
@app.route('/video_feed')
def video_feed():
    """Stream annotated webcam frames as a multipart MJPEG response."""
    frames = generate_frames()
    return Response(frames, mimetype='multipart/x-mixed-replace; boundary=frame')
|
|
@app.route('/start_exercise', methods=['POST'])
def start_exercise():
    """Start a new webcam exercise based on the user's selection.

    Expects a JSON body with ``exercise_type`` (``"squat"``, ``"push_up"`` or
    ``"hammer_curl"``) and optional positive integers ``sets`` (default 3)
    and ``reps`` (default 10). Input is validated before any shared state is
    touched, so a bad request can no longer reset a workout in progress.
    Opens the camera and resets the globals consumed by ``generate_frames``.

    Returns:
        ``{'success': True}`` on success; ``{'success': False, 'error': ...}``
        for an unknown exercise type (status 200, as before) or for missing,
        non-integer or non-positive ``sets``/``reps`` (status 400).
    """
    global exercise_running, current_exercise, current_exercise_data
    global exercise_counter, exercise_goal, sets_completed, sets_goal
    global workout_start_time

    # silent=True: a missing or non-JSON body gives None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    exercise_type = data.get('exercise_type')

    exercise_classes = {
        'squat': Squat,
        'push_up': PushUp,
        'hammer_curl': HammerCurl,
    }
    exercise_class = exercise_classes.get(exercise_type) if isinstance(exercise_type, str) else None
    if exercise_class is None:
        return jsonify({'success': False, 'error': 'Invalid exercise type'})

    try:
        new_sets_goal = int(data.get('sets', 3))
        new_reps_goal = int(data.get('reps', 10))
    except (TypeError, ValueError):
        return jsonify({'success': False, 'error': 'sets and reps must be integers'}), 400
    if new_sets_goal <= 0 or new_reps_goal <= 0:
        # A reps goal of 0 would make generate_frames complete a set every frame.
        return jsonify({'success': False, 'error': 'sets and reps must be positive'}), 400

    # Pause frame processing while the shared state is swapped, so
    # generate_frames does not pair the new tracker with stale metadata.
    exercise_running = False

    initialize_camera()

    sets_goal = new_sets_goal
    exercise_goal = new_reps_goal
    exercise_counter = 0
    sets_completed = 0
    workout_start_time = time.time()

    current_exercise = exercise_class()
    current_exercise_data = {
        'type': exercise_type,
        'sets': sets_goal,
        'reps': exercise_goal,
    }

    exercise_running = True

    return jsonify({'success': True})
|
|
@app.route('/stop_exercise', methods=['POST'])
def stop_exercise():
    """Stop the current exercise, log the workout and release the camera.

    The workout is logged whenever one was started and has not been logged
    yet. This includes workouts that ``generate_frames`` already finished,
    because it sets ``exercise_running`` to False after the last set;
    previously those were never logged. A partially completed set counts as
    one set. The camera is always released, even if logging fails.

    Returns:
        ``{'success': True}``.
    """
    global exercise_running, current_exercise_data, workout_start_time
    global exercise_counter, exercise_goal, sets_completed, sets_goal

    # Stop frame processing first so generate_frames stops updating the
    # counters that are about to be logged.
    exercise_running = False

    if current_exercise_data and workout_start_time is not None:
        duration = int(time.time() - workout_start_time)
        try:
            workout_logger.log_workout(
                exercise_type=current_exercise_data['type'],
                sets=sets_completed + (1 if exercise_counter > 0 else 0),
                reps=exercise_goal,
                duration_seconds=duration,
            )
        except Exception:
            # A storage failure must not leave the camera open.
            logger.exception("Failed to log workout")
        # Mark the workout as handled so a repeated stop does not log it twice.
        workout_start_time = None

    release_camera()
    return jsonify({'success': True})
|
|
@app.route('/get_status', methods=['GET'])
def get_status():
    """Report webcam workout progress as JSON.

    ``current_set`` is 1-based while an exercise is running and 0 otherwise.
    """
    current_set = sets_completed + 1 if exercise_running else 0
    status = {
        'exercise_running': exercise_running,
        'current_reps': exercise_counter,
        'current_set': current_set,
        'total_sets': sets_goal,
        'rep_goal': exercise_goal,
    }
    return jsonify(status)
|
|
@app.route('/profile')
def profile():
    """Placeholder user-profile page that returns a fixed text message."""
    placeholder = "Profile page - Coming soon!"
    return placeholder
|
|
@app.route('/healthz')
def health_check():
    """Liveness probe: log the call and answer with plain-text "OK" (200)."""
    logger.info("Health check endpoint called successfully.")
    body, status_code = "OK", 200
    return body, status_code
|
|
@app.route('/api/analyze_frame', methods=['POST'])
def analyze_frame():
    """Return the pose landmarks detected in a single base64-encoded image.

    Expects a JSON object with:
        image: base64-encoded image bytes. Whitespace and an optional
            ``data:<mime>;base64,`` prefix are stripped, and missing ``=``
            padding is added before decoding.
        exercise_type: optional; ``squat``, ``push_up`` or ``hammer_curl``.
            Anything else falls back to ``squat``.

    Returns:
        200 ``{'success': True, 'landmarks': [...]}``. Each landmark has
        ``index``, ``x``, ``y``, ``z`` and ``visibility`` (``None`` when the
        landmark lacks that attribute). The list is empty when no pose is
        detected.
        400 for a missing, non-string or undecodable image.
        500 for any other error.
    """
    logger.info("API call to /api/analyze_frame")
    try:
        # silent=True: a missing or non-JSON body yields None (-> 400) instead
        # of raising, which the broad handler below reported as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not isinstance(data.get('image'), str):
            logger.warning("API /api/analyze_frame: No image provided or invalid JSON.")
            return jsonify({'error': 'No image provided in JSON payload (expected base64 string under "image" key)'}), 400

        exercise_type_for_api = data.get('exercise_type', 'squat')
        if exercise_type_for_api not in ["squat", "push_up", "hammer_curl"]:
            logger.warning(f"API /api/analyze_frame: Invalid exercise_type '{exercise_type_for_api}'. Defaulting to 'squat'.")
            exercise_type_for_api = "squat"

        # Drop whitespace first so the padding arithmetic counts only payload
        # characters.
        image_data = ''.join(data['image'].split())
        if image_data.startswith('data:'):
            # canvas.toDataURL() produces "data:image/png;base64,<payload>".
            # Left in place, the header's letters and "/" would be decoded as
            # base64 and corrupt the image bytes.
            image_data = image_data.partition(',')[2]

        missing_padding = len(image_data) % 4
        if missing_padding:
            image_data += '=' * (4 - missing_padding)

        try:
            image_bytes = base64.b64decode(image_data)
            # Normalise any PIL mode (grayscale 'L', palette 'P', 'LA', CMYK,
            # RGBA, ...) to 3-channel RGB before the RGB->BGR conversion.
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
            frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        except Exception as e:
            logger.error(f"API /api/analyze_frame: Error decoding base64 image: {e}")
            return jsonify({'error': f'Invalid base64 image data: {str(e)}'}), 400

        estimator = get_pose_estimator_api_instance()
        results = estimator.estimate_pose(frame, exercise_type_for_api)

        if not results.pose_landmarks:
            logger.info(f"API /api/analyze_frame: No landmarks detected in the provided image for exercise type '{exercise_type_for_api}'.")
            return jsonify({'success': True, 'landmarks': []})

        landmarks_list = [
            {
                'index': i,
                'x': landmark.x,
                'y': landmark.y,
                'z': landmark.z,
                'visibility': getattr(landmark, 'visibility', None),
            }
            for i, landmark in enumerate(results.pose_landmarks.landmark)
        ]
        logger.info(f"API /api/analyze_frame: Successfully processed image, found {len(landmarks_list)} landmarks for exercise type '{exercise_type_for_api}'.")
        return jsonify({'success': True, 'landmarks': landmarks_list})

    except Exception as e:
        logger.error(f"API /api/analyze_frame: Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Internal server error: {str(e)}'}), 500
|
|
# Clients that vanish (closed tab, crashed app) never call
# /api/end_exercise_session. Without expiry their sessions would fill
# MAX_SESSIONS and every new client would get 503 until a restart.
SESSION_IDLE_TIMEOUT_SECONDS = 300
# session_id -> time.monotonic() of the most recent frame for that session.
_session_last_seen = {}


def _evict_idle_exercise_sessions(now):
    """Remove sessions whose last frame is older than SESSION_IDLE_TIMEOUT_SECONDS.

    Args:
        now: the current ``time.monotonic()`` reading.
    """
    # Iterate over a snapshot because other request threads may modify the dicts.
    for stale_id, last_seen in list(_session_last_seen.items()):
        if now - last_seen > SESSION_IDLE_TIMEOUT_SECONDS:
            _session_last_seen.pop(stale_id, None)
            if active_exercise_sessions.pop(stale_id, None) is not None:
                logger.info(f"Evicted idle session {stale_id}")


@app.route('/api/track_exercise_stream', methods=['POST'])
def track_exercise_stream():
    """Track reps for one frame of a client-driven exercise session.

    JSON body (all fields required):
        session_id: client-chosen string/int id. A tracker is created on first
            use and reused for later frames with the same id. Sessions idle for
            longer than SESSION_IDLE_TIMEOUT_SECONDS are discarded.
        exercise_type: ``squat``, ``push_up`` or ``hammer_curl``. It must
            match the type the session was created with.
        image: base64 image data. Whitespace and an optional
            ``data:...;base64,`` prefix are stripped.
        frame_width, frame_height: positive ints, passed to the tracker.

    Returns:
        200 with ``landmarks_detected`` False when no pose is found, or with
        the tracker's result under ``data``.
        400 for invalid input.
        503 when MAX_SESSIONS is reached.
        500 on unexpected errors or when the tracker returns nothing.
    """
    logger.info("API call to /api/track_exercise_stream")
    data = {}
    try:
        # silent=True: a non-JSON body yields None (-> 400) instead of raising.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not data:
            return jsonify({'error': 'No JSON data provided'}), 400

        session_id = data.get('session_id')
        exercise_type = data.get('exercise_type')
        image_data_base64 = data.get('image')
        frame_width = data.get('frame_width')
        frame_height = data.get('frame_height')

        if not all([session_id, exercise_type, image_data_base64, frame_width, frame_height]):
            return jsonify({'error': 'Missing required fields: session_id, exercise_type, image, frame_width, frame_height'}), 400

        if not isinstance(frame_width, int) or not isinstance(frame_height, int) or frame_width <= 0 or frame_height <= 0:
            return jsonify({'error': 'Invalid frame_width or frame_height'}), 400

        # JSON lists/objects are unhashable and cannot be session keys.
        if not isinstance(session_id, (str, int)):
            return jsonify({'error': 'Invalid session_id'}), 400
        if not isinstance(image_data_base64, str):
            return jsonify({'error': 'Invalid base64 image data: image must be a string'}), 400

        # exercise_type -> (tracker class, tracking method name)
        trackers = {
            'squat': (Squat, 'track_squat'),
            'push_up': (PushUp, 'track_push_up'),
            'hammer_curl': (HammerCurl, 'track_hammer_curl'),
        }
        tracker = trackers.get(exercise_type) if isinstance(exercise_type, str) else None
        if tracker is None:
            logger.warning(f"Invalid exercise type: {exercise_type} for session {session_id}")
            return jsonify({'error': 'Invalid exercise_type'}), 400
        exercise_class, track_method_name = tracker

        now = time.monotonic()
        _evict_idle_exercise_sessions(now)

        exercise_session = active_exercise_sessions.get(session_id)
        if exercise_session is None:
            if len(active_exercise_sessions) >= MAX_SESSIONS:
                logger.warning(f"Max sessions ({MAX_SESSIONS}) reached. Rejecting new session {session_id}.")
                return jsonify({'error': 'Server busy, max sessions reached. Please try again later.'}), 503
            logger.info(f"Creating new session {session_id} for exercise {exercise_type}")
            exercise_session = exercise_class()
            active_exercise_sessions[session_id] = exercise_session
        elif not isinstance(exercise_session, exercise_class):
            # e.g. calling track_push_up on a Squat tracker would raise
            # AttributeError and surface as a 500.
            logger.warning(f"Session {session_id}: exercise_type {exercise_type} does not match existing session")
            return jsonify({'error': 'exercise_type does not match the existing session'}), 400
        _session_last_seen[session_id] = now

        image_data_base64 = ''.join(image_data_base64.split())
        if image_data_base64.startswith('data:'):
            # canvas.toDataURL() produces "data:image/png;base64,<payload>".
            image_data_base64 = image_data_base64.partition(',')[2]
        missing_padding = len(image_data_base64) % 4
        if missing_padding:
            image_data_base64 += '=' * (4 - missing_padding)
        try:
            image_bytes = base64.b64decode(image_data_base64)
            # Normalise any PIL mode (grayscale, palette, CMYK, RGBA, ...) to
            # 3-channel RGB before the RGB->BGR conversion.
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
            frame_for_estimation = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        except Exception as e:
            logger.error(f"Session {session_id}: Error decoding base64 image: {e}")
            return jsonify({'error': f'Invalid base64 image data: {str(e)}'}), 400

        pose_estimator = get_pose_estimator_api_instance()
        # Pass the real exercise type (it was hard-coded to 'squat'), matching
        # /api/analyze_frame and the webcam loop.
        results = pose_estimator.estimate_pose(frame_for_estimation, exercise_type)

        if not results.pose_landmarks:
            logger.info(f"Session {session_id}: No landmarks detected.")
            return jsonify({
                'success': True,
                'landmarks_detected': False,
                'message': 'No landmarks detected in this frame.',
            })

        track = getattr(exercise_session, track_method_name)
        exercise_data = track(results.pose_landmarks.landmark, frame_width, frame_height)

        if exercise_data:
            logger.debug(f"Session {session_id}: Exercise data: {exercise_data}")
            return jsonify({'success': True, 'landmarks_detected': True, 'data': exercise_data})

        logger.error(f"Session {session_id}: Could not get exercise_data for {exercise_type}")
        return jsonify({'error': 'Failed to process exercise frame.'}), 500

    except Exception as e:
        session_id_log = data.get('session_id', 'unknown_session') if isinstance(data, dict) else 'unknown_session'
        logger.error(f"API /api/track_exercise_stream Error for session {session_id_log}: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Internal server error: {str(e)}'}), 500
|
|
@app.route('/api/end_exercise_session', methods=['POST'])
def end_exercise_session():
    """Discard the tracker stored for the given ``session_id``.

    Expects a JSON object with a non-empty ``session_id``.

    Returns:
        200 when the session existed and was removed.
        404 when no such session exists.
        400 for a missing or non-JSON body or a missing ``session_id``.
        500 on unexpected errors.
    """
    logger.info("API call to /api/end_exercise_session")
    try:
        # silent=True: a non-JSON body yields None (-> 400) instead of raising.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not data:
            return jsonify({'error': 'No JSON data provided'}), 400

        session_id = data.get('session_id')
        if not session_id:
            return jsonify({'error': 'Missing session_id'}), 400

        # A single pop() replaces `in` + `del`. Two concurrent requests for the
        # same id can no longer both pass the membership test, with one then
        # failing on KeyError.
        if active_exercise_sessions.pop(session_id, None) is not None:
            logger.info(f"Ended and removed session: {session_id}")
            return jsonify({'success': True, 'message': f'Session {session_id} ended.'})

        logger.warning(f"Attempted to end non-existent session: {session_id}")
        return jsonify({'success': False, 'message': f'Session {session_id} not found.'}), 404
    except Exception as e:
        logger.error(f"API /api/end_exercise_session Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Internal server error: {str(e)}'}), 500
|
|
if __name__ == '__main__':
    # Development entry point: Flask's built-in server with debug=True, which
    # turns on the reloader and interactive debugger. Do not expose it publicly.
    try:
        logger.info("Starting the Flask application on http://127.0.0.1:5000")
        for banner_line in ("Starting Fitness Trainer app, please wait...",
                            "Open http://127.0.0.1:5000 in your web browser when the server starts"):
            print(banner_line)
        app.run(debug=True)
    except Exception as startup_error:
        logger.error(f"Failed to start application: {startup_error}")
        traceback.print_exc()