Spaces:
Sleeping
Sleeping
import logging

# Logging is configured before any other import so that the start-up banner
# below is emitted whenever the script runs, even if a later import fails.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
logger.info("APP.PY EXECUTION STARTED - VERY FIRST LINE")
| from flask import Flask, render_template, Response, request, jsonify, session, redirect, url_for, send_from_directory | |
| from flask_socketio import SocketIO, emit | |
| import os | |
| import cv2 | |
| import base64 | |
| import io | |
| import numpy as np | |
| from PIL import Image | |
| import threading | |
| import time | |
| import uuid | |
| import sys | |
| import traceback | |
| from flask_cors import CORS | |
| logger.info("All imports completed in app.py") | |
# Project-local modules are imported defensively: a failure is logged together
# with its traceback, and start-up continues.
# NOTE(review): if this import fails, names such as PoseEstimator or Squat stay
# undefined and the handlers below will raise NameError on first use. Consider
# sys.exit(1) here if these modules are essential for start-up.
try:
    from pose_estimation.estimation import PoseEstimator
    from exercises.squat import Squat
    from exercises.hammer_curl import HammerCurl
    from exercises.push_up import PushUp
    from feedback.information import get_exercise_info
    from feedback.layout import layout_indicators
    from utils.draw_text_with_background import draw_text_with_background
except ImportError as e:
    logger.error(f"Failed to import custom modules: {e}")
    traceback.print_exc()
else:
    logger.info("Successfully imported custom pose estimation and exercise modules")
class DummyWorkoutLogger:
    """No-op stand-in used when the real WorkoutLogger cannot be used.

    Every method accepts arbitrary arguments and returns an empty result of
    the shape the dashboard code reads (dict, list, or the stats dict with
    zeroed counters).
    """

    def __init__(self):
        pass

    def log_workout(self, *args, **kwargs):
        return {}

    def get_recent_workouts(self, *args, **kwargs):
        return []

    def get_weekly_stats(self, *args, **kwargs):
        return {}

    def get_exercise_distribution(self, *args, **kwargs):
        return {}

    def get_user_stats(self, *args, **kwargs):
        return {'total_workouts': 0, 'total_exercises': 0, 'streak_days': 0}


# Try to set up the real WorkoutLogger and fall back to the dummy otherwise.
# The fallback covers a failing constructor as well as a failing import, so a
# broken persistence layer does not abort application start-up.
try:
    from db.workout_logger import WorkoutLogger
    workout_logger = WorkoutLogger()
    logger.info("Successfully initialized workout logger")
except ImportError:
    logger.warning("WorkoutLogger import failed, creating dummy class")
    workout_logger = DummyWorkoutLogger()
except Exception:
    # Top-level boundary: log the full traceback and degrade gracefully.
    logger.exception("WorkoutLogger initialization failed, using dummy workout logger")
    workout_logger = DummyWorkoutLogger()
# --- Application objects ---
app = Flask(__name__)
logger.info("Flask app object created.")

# NOTE(review): when FLASK_SECRET_KEY is unset the key falls back to a literal
# that lives in the source file; confirm the variable is set in deployment.
app.config.update(
    SECRET_KEY=os.environ.get('FLASK_SECRET_KEY', 'a_very_secret_key_for_production_final'),
)

# Cross-origin requests are accepted from any origin on every path.
CORS(app, resources={r"/*": {"origins": "*"}})
logger.info("CORS configured for Flask app.")

socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
logger.info("SocketIO initialized with eventlet async_mode and CORS.")

# --- Global Variables & Helper Functions ---
# Lazily created by get_pose_estimator().
pose_estimator_instance = None
# Maps a session id to {'exercise': ..., 'type': ..., 'source': 'http' | 'websocket'}.
active_exercise_sessions = {}
MAX_SESSIONS = 100

# State for the original /video_feed functionality (server-side camera, for index.html).
lock_original_feed = threading.Lock()
camera_original_feed = None
output_frame_original_feed = None
exercise_running_original_feed = False
current_exercise_original_feed = None
current_exercise_data_original_feed = None
workout_start_time_original_feed = None
exercise_counter_original_feed = 0
exercise_goal_original_feed = 0
sets_completed_original_feed = 0
sets_goal_original_feed = 0
def get_pose_estimator():
    """Return the shared PoseEstimator, creating it on first use.

    The instance is cached in the module-level ``pose_estimator_instance`` so
    later calls reuse the same object.
    """
    global pose_estimator_instance
    if pose_estimator_instance is None:
        logger.info("Initializing PoseEstimator instance.")
        pose_estimator_instance = PoseEstimator()
    return pose_estimator_instance


# Several handlers in this file call the accessor under this name; keep both
# spellings bound to the same function so neither raises NameError.
get_pose_estimator_instance = get_pose_estimator
| # --- Routes for original server-side rendering functionality --- | |
def initialize_camera_original():
    """Open the server-side camera used by /video_feed, if not already open.

    Returns:
        The module-level capture object, or None when the device could not be
        opened.
    """
    global camera_original_feed
    if camera_original_feed is not None:
        return camera_original_feed

    logger.info("Initializing camera for /video_feed")
    try:
        capture = cv2.VideoCapture(0)  # Or appropriate camera index/source
        if capture.isOpened():
            # Publish the capture only once it is known to be usable.
            camera_original_feed = capture
        else:
            logger.error("Could not open video capture device for original feed.")
            # Free the handle instead of just dropping the reference.
            capture.release()
    except Exception as e:
        logger.error(f"Exception opening camera for original feed: {e}")
        camera_original_feed = None
    return camera_original_feed
def release_camera_original():
    """Release the /video_feed camera and clear the module-level reference."""
    global camera_original_feed
    if camera_original_feed is None:
        # Nothing is open; nothing to do.
        return
    logger.info("Releasing camera for /video_feed")
    camera_original_feed.release()
    camera_original_feed = None
def _mjpeg_chunk(frame):
    """Encode ``frame`` as JPEG and wrap it as one multipart stream part."""
    _, buffer = cv2.imencode('.jpg', frame)
    return (b'--frame\r\n'
            b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')


def generate_frames_original():
    """Yield an endless multipart JPEG stream from the server-side camera.

    While an exercise is running, each frame is passed through the pose
    estimator and the active exercise tracker, and the rep/set overlays are
    drawn. The module-level rep and set counters are updated as a side effect.
    When the camera is unavailable a placeholder frame is yielded once per
    second.
    """
    global output_frame_original_feed, lock_original_feed, exercise_running_original_feed
    global current_exercise_original_feed, current_exercise_data_original_feed
    global exercise_counter_original_feed, exercise_goal_original_feed
    global sets_completed_original_feed, sets_goal_original_feed
    global camera_original_feed

    local_pose_estimator = get_pose_estimator()
    if camera_original_feed is None:
        camera_original_feed = initialize_camera_original()

    text_color = (255, 255, 255)
    info_background = (118, 29, 14)

    while True:
        # Snapshot the global: the stop route may set it to None at any time.
        camera = camera_original_feed
        if camera is None or not camera.isOpened():
            logger.warning("Original feed camera not available in generate_frames_original.")
            blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
            cv2.putText(blank_frame, "Camera not available", (50, 240),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            yield _mjpeg_chunk(blank_frame)
            time.sleep(1)  # Prevent busy-looping if camera is not found
            continue

        success, frame = camera.read()
        if not success:
            logger.warning("Failed to read frame from original feed camera in loop.")
            time.sleep(0.1)
            continue

        exercise = current_exercise_original_feed
        exercise_data = current_exercise_data_original_feed
        if exercise_running_original_feed and exercise and exercise_data:
            try:
                exercise_type = exercise_data['type']
                results = local_pose_estimator.estimate_pose(frame, exercise_type)
                if results.pose_landmarks:
                    landmarks = results.pose_landmarks.landmark
                    exercise_info = get_exercise_info(exercise_type)

                    if exercise_type == "squat":
                        counter, angle, stage = exercise.track_squat(landmarks, frame)
                        layout_indicators(frame, exercise_type, (counter, angle, stage))
                        exercise_counter_original_feed = counter
                    elif exercise_type == "push_up":
                        counter, angle, stage = exercise.track_push_up(landmarks, frame)
                        layout_indicators(frame, exercise_type, (counter, angle, stage))
                        exercise_counter_original_feed = counter
                    elif exercise_type == "hammer_curl":
                        (counter_right, angle_right, counter_left, angle_left,
                         warning_message_right, warning_message_left, progress_right,
                         progress_left, stage_right, stage_left) = exercise.track_hammer_curl(
                            landmarks, frame)
                        layout_indicators(frame, exercise_type,
                                          (counter_right, angle_right, counter_left, angle_left,
                                           warning_message_right, warning_message_left,
                                           progress_right, progress_left, stage_right, stage_left))
                        # The rep count for the set is the better of the two arms.
                        exercise_counter_original_feed = max(counter_right, counter_left)

                    overlay_lines = (
                        (f"Exercise: {exercise_info.get('name', 'N/A')}", 50),
                        (f"Reps Goal: {exercise_goal_original_feed}", 80),
                        (f"Sets Goal: {sets_goal_original_feed}", 110),
                        (f"Current Set: {sets_completed_original_feed + 1}", 140),
                    )
                    for text, y in overlay_lines:
                        draw_text_with_background(frame, text, (40, y),
                                                  cv2.FONT_HERSHEY_DUPLEX, 0.7,
                                                  text_color, info_background, 1)

                    if exercise_counter_original_feed >= exercise_goal_original_feed:
                        # Set finished: bump the set count and reset rep counters.
                        sets_completed_original_feed += 1
                        exercise_counter_original_feed = 0
                        if exercise_type in ["squat", "push_up"]:
                            exercise.counter = 0
                        elif exercise_type == "hammer_curl":
                            exercise.counter_right = 0
                            exercise.counter_left = 0

                        if sets_completed_original_feed >= sets_goal_original_feed:
                            exercise_running_original_feed = False
                            draw_text_with_background(
                                frame, "WORKOUT COMPLETE!",
                                (frame.shape[1]//2 - 150, frame.shape[0]//2),
                                cv2.FONT_HERSHEY_DUPLEX, 1.2, text_color, (0, 200, 0), 2)
                        else:
                            draw_text_with_background(
                                frame,
                                f"SET {sets_completed_original_feed} COMPLETE! Rest for 30 sec",
                                (frame.shape[1]//2 - 200, frame.shape[0]//2),
                                cv2.FONT_HERSHEY_DUPLEX, 1.0, text_color, (0, 0, 200), 2)
                # With no pose landmarks the frame is streamed unannotated.
            except Exception as e:
                # Boundary of the streaming loop: log and keep streaming.
                logger.error(f"Error during pose estimation or drawing for original feed: {e}", exc_info=True)
        else:
            cv2.putText(frame, "Select an exercise on main page to begin", (50, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

        snapshot = frame.copy()
        with lock_original_feed:
            output_frame_original_feed = snapshot
        yield _mjpeg_chunk(snapshot)
| # --- Flask Routes --- | |
# NOTE(review): no @app.route decorator is visible on this view in this copy
# of the file; confirm the route is registered.
def index():
    """Render the original main page, index.html."""
    logger.info("Rendering index.html (original main page)")
    page = render_template('index.html')
    return page
# NOTE(review): no @app.route decorator is visible on this view in this copy
# of the file; confirm the route is registered.
def live_test_page_route():
    """Serve live_test.html from the application's root directory."""
    logger.info("Serving live_test.html page via /live route.")
    # '.' is resolved relative to the app's root directory, so app.py and
    # live_test.html are expected to sit side by side there.
    directory, filename = '.', 'live_test.html'
    return send_from_directory(directory, filename)
# NOTE(review): no @app.route decorator is visible on this view in this copy
# of the file; confirm the route is registered.
def dashboard():
    """Render dashboard.html with recent workouts and aggregate statistics.

    Data comes from ``workout_logger``; each accessor is used only if the
    logger object provides it. On any error a plain-text message is returned
    with HTTP status 500.
    """
    logger.info("Rendering dashboard.html")
    try:
        default_user_stats = {'total_workouts': 0, 'total_exercises': 0, 'streak_days': 0}

        recent_workouts_data = (
            workout_logger.get_recent_workouts(5)
            if hasattr(workout_logger, 'get_recent_workouts') else []
        )
        weekly_stats_data = (
            workout_logger.get_weekly_stats()
            if hasattr(workout_logger, 'get_weekly_stats') else {}
        )
        # Fetched as in the original flow; the template below does not receive it.
        exercise_distribution_data = (
            workout_logger.get_exercise_distribution()
            if hasattr(workout_logger, 'get_exercise_distribution') else {}
        )
        user_stats_data = (
            workout_logger.get_user_stats()
            if hasattr(workout_logger, 'get_user_stats') else default_user_stats
        )

        formatted_workouts = []
        for workout in recent_workouts_data or ():
            row = {
                'date': workout.get('date', 'N/A'),
                'exercise': workout.get('exercise_type', 'Unknown').replace('_', ' ').title(),
                'sets': workout.get('sets', 0),
                'reps': workout.get('reps', 0),
            }
            # Duration is shown as minutes:seconds with zero-padded seconds.
            secs = workout.get('duration_seconds', 0)
            row['duration'] = f"{secs // 60}:{secs % 60:02d}"
            formatted_workouts.append(row)

        if weekly_stats_data:
            weekly_workout_count = sum(
                day.get('workout_count', 0) for day in weekly_stats_data.values()
            )
        else:
            weekly_workout_count = 0

        template_context = {
            'recent_workouts': formatted_workouts,
            'weekly_workouts': weekly_workout_count,
            'total_workouts': user_stats_data.get('total_workouts', 0),
            'total_exercises': user_stats_data.get('total_exercises', 0),
            'streak_days': user_stats_data.get('streak_days', 0),
        }
        return render_template('dashboard.html', **template_context)
    except Exception as e:
        logger.error(f"Error in dashboard: {e}", exc_info=True)
        return f"Error loading dashboard: {str(e)}", 500
# NOTE(review): no @app.route decorator is visible on this view in this copy
# of the file; confirm the route is registered.
def video_feed():
    """Stream frames from generate_frames_original() as a multipart response.

    This is the original server-side camera feature: it reads a local camera
    (index 0), so it will likely not work in typical serverless or container
    environments and may need to be disabled or re-thought for cloud use.
    """
    logger.info("Access to /video_feed (original server-side camera streaming)")
    frame_stream = generate_frames_original()
    return Response(frame_stream, mimetype='multipart/x-mixed-replace; boundary=frame')
# NOTE(review): no @app.route decorator is visible on this view in this copy
# of the file; confirm the route is registered.
def start_exercise_route():
    """Start a server-side-camera workout from a JSON request.

    Expected JSON keys: ``exercise_type`` ("squat", "push_up" or
    "hammer_curl"), optional ``sets`` (default 3) and ``reps`` (default 10).

    The request is validated in full before any module state is changed or the
    camera is opened, so a rejected request leaves a running workout untouched.

    Returns:
        JSON ``{'success': True}`` on success, or ``{'success': False,
        'error': ...}`` describing the problem.
    """
    global exercise_running_original_feed, current_exercise_original_feed, current_exercise_data_original_feed
    global exercise_counter_original_feed, exercise_goal_original_feed, sets_completed_original_feed, sets_goal_original_feed
    global workout_start_time_original_feed

    # silent=True yields None instead of raising on a missing or invalid body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    exercise_type = data.get('exercise_type')
    if exercise_type not in ("squat", "push_up", "hammer_curl"):
        return jsonify({'success': False, 'error': 'Invalid exercise type'})

    try:
        sets_goal = int(data.get('sets', 3))
        reps_goal = int(data.get('reps', 10))
    except (TypeError, ValueError):
        return jsonify({'success': False, 'error': 'sets and reps must be positive integers'}), 400
    if sets_goal <= 0 or reps_goal <= 0:
        # A zero goal would make the frame loop complete a set on every frame.
        return jsonify({'success': False, 'error': 'sets and reps must be positive integers'}), 400

    if initialize_camera_original() is None:  # Try to initialize if not already
        logger.error("Failed to initialize camera for /start_exercise")
        return jsonify({'success': False, 'error': 'Camera not available'})

    if exercise_type == "squat":
        exercise_instance = Squat()
    elif exercise_type == "push_up":
        exercise_instance = PushUp()
    else:
        exercise_instance = HammerCurl()

    sets_goal_original_feed = sets_goal
    exercise_goal_original_feed = reps_goal
    exercise_counter_original_feed = 0
    sets_completed_original_feed = 0
    workout_start_time_original_feed = time.time()
    current_exercise_original_feed = exercise_instance
    current_exercise_data_original_feed = {
        'type': exercise_type,
        'sets': sets_goal,
        'reps': reps_goal,
    }
    # Flip the flag last so the frame loop never sees a half-initialised workout.
    exercise_running_original_feed = True
    logger.info(f"Original exercise started via HTTP: {exercise_type}")
    return jsonify({'success': True})
# NOTE(review): no @app.route decorator is visible on this view in this copy
# of the file; confirm the route is registered.
def stop_exercise_route():
    """Stop the server-side-camera workout, log it, and release the camera.

    A workout is logged whenever one was started and not yet logged. This
    includes a workout that already finished on its own, because the frame
    loop clears the running flag when the last set completes. The start time
    is cleared after logging so a repeated stop request does not log twice.

    Returns:
        JSON ``{'success': True}``.
    """
    global exercise_running_original_feed, current_exercise_data_original_feed, workout_start_time_original_feed
    global sets_completed_original_feed, exercise_counter_original_feed, exercise_goal_original_feed

    if current_exercise_data_original_feed and workout_start_time_original_feed:
        duration = int(time.time() - workout_start_time_original_feed)
        # A partially completed set counts as one more set.
        partial_set = 1 if exercise_counter_original_feed > 0 else 0
        workout_logger.log_workout(
            exercise_type=current_exercise_data_original_feed['type'],
            sets=sets_completed_original_feed + partial_set,
            reps=exercise_goal_original_feed,
            duration_seconds=duration
        )
        workout_start_time_original_feed = None

    release_camera_original()
    exercise_running_original_feed = False
    logger.info("Original exercise stopped via HTTP.")
    return jsonify({'success': True})
# NOTE(review): no @app.route decorator is visible on this view in this copy
# of the file; confirm the route is registered.
def get_status_route():
    """Return the current workout counters of the server-side feed as JSON."""
    running = exercise_running_original_feed
    # The set number is 1-based while a workout runs and 0 otherwise.
    if running:
        current_set = sets_completed_original_feed + 1
    else:
        current_set = 0
    status = {
        'exercise_running': running,
        'current_reps': exercise_counter_original_feed,
        'current_set': current_set,
        'total_sets': sets_goal_original_feed,
        'rep_goal': exercise_goal_original_feed,
    }
    return jsonify(status)
# NOTE(review): no @app.route decorator is visible on this view in this copy
# of the file; confirm the route is registered.
def profile_route():
    """Return the placeholder text for the profile page."""
    placeholder = "Profile page - Coming soon!"
    return placeholder
| # --- HTTP API Endpoints (from original app) --- | |
# NOTE(review): no @app.route decorator is visible on this view in this copy
# of the file; confirm the route is registered.
def api_analyze_frame():
    """Run pose estimation on one base64-encoded image and return landmarks.

    Expected JSON keys: ``image`` (base64 string, padding optional) and
    optional ``exercise_type`` (default "squat").

    Returns:
        JSON ``{'success': True, 'landmarks': [...]}``; 400 when no image is
        supplied; 500 with the error text on any other failure.
    """
    logger.info("API call to /api/analyze_frame")
    try:
        data = request.json
        if not data or 'image' not in data:
            return jsonify({'error': 'No image provided'}), 400

        exercise_type_for_api = data.get('exercise_type', 'squat')
        image_data = data['image']
        # Restore any '=' padding the client stripped from the base64 text.
        missing_padding = len(image_data) % 4
        if missing_padding:
            image_data += '=' * (4 - missing_padding)
        image_bytes = base64.b64decode(image_data)

        # Force 3-channel RGB so RGBA, palette, or grayscale uploads do not
        # break the RGB->BGR conversion below.
        pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

        estimator = get_pose_estimator()
        results = estimator.estimate_pose(frame, exercise_type_for_api)

        landmarks_list = []
        if results.pose_landmarks:
            landmarks_list = [
                {
                    'index': i, 'x': landmark.x, 'y': landmark.y, 'z': landmark.z,
                    'visibility': getattr(landmark, 'visibility', None),
                }
                for i, landmark in enumerate(results.pose_landmarks.landmark)
            ]
        return jsonify({'success': True, 'landmarks': landmarks_list})
    except Exception as e:
        # Endpoint boundary: log with traceback and report a 500.
        logger.error(f"API /api/analyze_frame: Error: {e}", exc_info=True)
        return jsonify({'error': f'Internal server error: {str(e)}'}), 500
# NOTE(review): no @app.route decorator is visible on this view in this copy
# of the file; confirm the route is registered.
def api_track_exercise_stream():
    """Track one frame of an exercise for an HTTP-managed session.

    Expected JSON keys: ``session_id``, ``exercise_type`` ("squat", "push_up"
    or "hammer_curl"), ``image`` (base64), ``frame_width`` and
    ``frame_height`` (positive ints).

    A session is created on first use. It is re-created when the stored
    session is not an HTTP session or was started for a different exercise
    type, so the tracker object always matches the requested exercise.

    Returns:
        JSON with ``success``/``landmarks_detected`` and, when tracking
        produced a result, ``data``; 400 for invalid input, 503 when the
        session limit is reached, 500 on unexpected errors.
    """
    logger.info("API call to /api/track_exercise_stream")
    data = {}
    try:
        data = request.json
        if not data:
            return jsonify({'error': 'No JSON data provided'}), 400

        session_id = data.get('session_id')
        exercise_type = data.get('exercise_type')
        image_data_base64 = data.get('image')
        frame_width = data.get('frame_width')
        frame_height = data.get('frame_height')

        if not all([session_id, exercise_type, image_data_base64,
                    frame_width is not None, frame_height is not None]):
            return jsonify({'error': 'Missing required fields'}), 400
        if (not isinstance(frame_width, int) or not isinstance(frame_height, int)
                or frame_width <= 0 or frame_height <= 0):
            return jsonify({'error': 'Invalid frame_width or frame_height'}), 400

        existing_session = active_exercise_sessions.get(session_id)
        if len(active_exercise_sessions) >= MAX_SESSIONS and existing_session is None:
            logger.warning(f"Max sessions ({MAX_SESSIONS}) reached for HTTP. Rejecting new session {session_id}.")
            return jsonify({'error': 'Server busy, max sessions reached.'}), 503

        is_http_session = existing_session is not None and existing_session.get('source') == 'http'
        if not is_http_session or existing_session.get('type') != exercise_type:
            if is_http_session:
                # Same id, different exercise: the stored tracker has no
                # matching track_* method, so start over with a fresh one.
                logger.info(f"Re-initializing HTTP session {session_id} for new exercise {exercise_type}")
            else:
                logger.info(f"Creating new HTTP session {session_id} for exercise {exercise_type}")
            if exercise_type == 'squat':
                exercise_instance_http = Squat()
            elif exercise_type == 'push_up':
                exercise_instance_http = PushUp()
            elif exercise_type == 'hammer_curl':
                exercise_instance_http = HammerCurl()
            else:
                return jsonify({'error': 'Invalid exercise_type'}), 400
            active_exercise_sessions[session_id] = {
                'exercise': exercise_instance_http, 'type': exercise_type, 'source': 'http'}

        session_data = active_exercise_sessions[session_id]
        exercise_session_obj = session_data['exercise']

        # Restore any '=' padding the client stripped from the base64 text.
        missing_padding = len(image_data_base64) % 4
        if missing_padding:
            image_data_base64 += '=' * (4 - missing_padding)
        image_bytes = base64.b64decode(image_data_base64)
        # Force 3-channel RGB so RGBA, palette, or grayscale uploads do not
        # break the RGB->BGR conversion below.
        pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        frame_for_estimation = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

        pose_estimator = get_pose_estimator()
        results = pose_estimator.estimate_pose(frame_for_estimation, exercise_type)
        if not results.pose_landmarks:
            return jsonify({'success': True, 'landmarks_detected': False,
                            'message': 'No landmarks detected in this frame.'})

        # NOTE(review): generate_frames_original passes the frame itself to
        # these track_* methods, while this endpoint passes (width, height);
        # confirm which signature the exercise classes expose.
        landmarks = results.pose_landmarks.landmark
        exercise_data_output = None
        if exercise_type == 'squat':
            exercise_data_output = exercise_session_obj.track_squat(landmarks, frame_width, frame_height)
        elif exercise_type == 'push_up':
            exercise_data_output = exercise_session_obj.track_push_up(landmarks, frame_width, frame_height)
        elif exercise_type == 'hammer_curl':
            exercise_data_output = exercise_session_obj.track_hammer_curl(landmarks, frame_width, frame_height)

        if exercise_data_output:
            return jsonify({'success': True, 'landmarks_detected': True, 'data': exercise_data_output})
        return jsonify({'error': 'Failed to process exercise frame.'}), 500
    except Exception as e:
        # Endpoint boundary: log with traceback and report a 500.
        session_id_log = data.get('session_id', 'unknown_session') if isinstance(data, dict) else 'unknown_session'
        logger.error(f"API /api/track_exercise_stream Error for session {session_id_log}: {e}", exc_info=True)
        return jsonify({'error': f'Internal server error: {str(e)}'}), 500
# NOTE(review): no @app.route decorator is visible on this view in this copy
# of the file; confirm the route is registered.
def api_end_exercise_session():
    """Remove an HTTP-managed exercise session named by ``session_id``.

    Returns:
        JSON success message; 400 when the body or ``session_id`` is missing;
        404 when the session does not exist or is not an HTTP session; 500 on
        unexpected errors.
    """
    logger.info("API call to /api/end_exercise_session")
    try:
        data = request.json
        if not data:
            return jsonify({'error': 'No JSON data provided'}), 400

        session_id = data.get('session_id')
        if not session_id:
            return jsonify({'error': 'Missing session_id'}), 400

        entry = active_exercise_sessions.get(session_id)
        if entry is None or entry.get('source') != 'http':
            logger.warning(f"Attempted to end non-existent or non-HTTP session: {session_id}")
            not_found = {'success': False,
                         'message': f'Session {session_id} not found or not an HTTP session.'}
            return jsonify(not_found), 404

        del active_exercise_sessions[session_id]
        logger.info(f"Ended and removed HTTP API session: {session_id}")
        return jsonify({'success': True, 'message': f'Session {session_id} ended.'})
    except Exception as e:
        logger.error(f"API /api/end_exercise_session Error: {e}", exc_info=True)
        return jsonify({'error': f'Internal server error: {str(e)}'}), 500
| # --- WebSocket Event Handlers (for live_test.html) --- | |
# NOTE(review): no @socketio.on decorator is visible on this handler in this
# copy of the file; confirm the event is registered.
def ws_connect():
    """Acknowledge a new WebSocket client with a 'connection_ack' event."""
    logger.info(f"WebSocket client connected: {request.sid}")
    sid = request.sid
    ack_payload = {'message': 'Successfully connected via WebSocket!', 'sid': sid}
    # Emit to the specific client only (room=<sid>).
    socketio.emit('connection_ack', ack_payload, room=sid)
# NOTE(review): no @socketio.on decorator is visible on this handler in this
# copy of the file; confirm the event is registered.
def ws_disconnect():
    """Drop the WebSocket exercise session of a disconnecting client, if any."""
    logger.info(f"WebSocket client disconnected: {request.sid}")
    sid = request.sid
    entry = active_exercise_sessions.get(sid)
    if entry is None or entry.get('source') != 'websocket':
        # No WebSocket-owned session under this id; nothing to clean up.
        return
    del active_exercise_sessions[sid]
    logger.info(f"Cleaned up WebSocket exercise session for client {request.sid}")
# NOTE(review): no @socketio.on decorator is visible on this handler in this
# copy of the file; confirm the event is registered.
def ws_start_exercise_session(data):
    """Create (or re-create) the WebSocket exercise session for this client.

    Args:
        data: Event payload; expected to be a dict with ``exercise_type``
            ("squat", "push_up" or "hammer_curl").

    Emits 'session_started' to the client on success, or 'session_error' with
    an ``error`` message when the payload is invalid or the WebSocket session
    limit is reached.
    """
    client_sid = request.sid  # Use this for all operations related to this client
    # A non-dict payload carries no exercise_type; treat it like a missing one
    # instead of crashing on .get().
    exercise_type = data.get('exercise_type') if isinstance(data, dict) else None
    logger.info(f"WebSocket: Attempting to start exercise session for {client_sid} with type: {exercise_type}")

    if not exercise_type:
        logger.warning(f"WebSocket session start for {client_sid} failed: no exercise_type provided.")
        socketio.emit('session_error', {'error': 'exercise_type is required.'}, room=client_sid)
        return

    # Only WebSocket sessions count towards the limit here; HTTP sessions are
    # limited separately by their own endpoint.
    websocket_session_count = sum(
        1 for s in active_exercise_sessions.values() if s.get('source') == 'websocket'
    )
    if websocket_session_count >= MAX_SESSIONS and client_sid not in active_exercise_sessions:
        logger.warning(f"Max WebSocket sessions ({MAX_SESSIONS}) reached. Rejecting new session for {client_sid}.")
        socketio.emit('session_error', {'error': 'Server busy, max WebSocket sessions reached.'}, room=client_sid)
        return

    existing = active_exercise_sessions.get(client_sid)
    if existing is not None and existing.get('source') == 'websocket':
        logger.info(f"WebSocket session for {client_sid} already exists. Re-initializing for new exercise: {exercise_type}")

    if exercise_type == 'squat':
        exercise_instance = Squat()
    elif exercise_type == 'push_up':
        exercise_instance = PushUp()
    elif exercise_type == 'hammer_curl':
        exercise_instance = HammerCurl()
    else:
        logger.warning(f"Invalid exercise type: {exercise_type} for WebSocket session {client_sid}")
        socketio.emit('session_error', {'error': 'Invalid exercise_type'}, room=client_sid)
        return

    active_exercise_sessions[client_sid] = {
        'exercise': exercise_instance, 'type': exercise_type, 'source': 'websocket'}
    logger.info(f"Successfully created WebSocket exercise session for {client_sid}, type: {exercise_type}")
    socketio.emit('session_started', {'session_id': client_sid, 'exercise_type': exercise_type}, room=client_sid)
# NOTE(review): no @socketio.on decorator is visible on this handler in this
# copy of the file; confirm the event is registered.
def ws_process_frame(data):
    """Process one frame sent over WebSocket for the client's active session.

    Args:
        data: Event payload; expected to be a dict with ``image`` (base64
            string), ``frame_width`` and ``frame_height`` (positive ints).

    Emits 'exercise_update' with the tracking result, or 'frame_error' with an
    ``error`` message when there is no session, the payload is invalid, the
    image cannot be decoded, or processing fails.
    """
    client_sid = request.sid

    session_details = active_exercise_sessions.get(client_sid)
    if not session_details or session_details.get('source') != 'websocket':
        logger.warning(f"WebSocket: Frame received from {client_sid} without an active WebSocket session.")
        socketio.emit('frame_error', {'error': 'No active WebSocket session. Please start an exercise session first.'}, room=client_sid)
        return

    exercise_session_obj = session_details['exercise']
    session_exercise_type = session_details['type']

    # A non-dict payload is handled like one with every field missing.
    payload = data if isinstance(data, dict) else {}
    image_data_base64 = payload.get('image')
    frame_width = payload.get('frame_width')
    frame_height = payload.get('frame_height')

    # Same rule as the HTTP endpoint: dimensions must be positive integers.
    dimensions_valid = all(
        isinstance(value, int) and value > 0 for value in (frame_width, frame_height)
    )
    if not image_data_base64 or not dimensions_valid:
        logger.warning(f"WebSocket: Missing or invalid data in process_frame for {client_sid}.")
        socketio.emit('frame_error', {'error': 'Missing or invalid image, frame_width, or frame_height.'}, room=client_sid)
        return

    try:
        # Decoding problems are reported separately as a client-side error.
        try:
            missing_padding = len(image_data_base64) % 4
            if missing_padding:
                image_data_base64 += '=' * (4 - missing_padding)
            image_bytes = base64.b64decode(image_data_base64)
            # Force 3-channel RGB so RGBA, palette, or grayscale frames do not
            # break the RGB->BGR conversion below.
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
            frame_for_estimation = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        except Exception as e:
            logger.error(f"WebSocket: Error decoding base64 image for {client_sid}: {e}", exc_info=True)
            socketio.emit('frame_error', {'error': f'Invalid base64 image data: {str(e)}'}, room=client_sid)
            return

        pose_estimator = get_pose_estimator()
        results = pose_estimator.estimate_pose(frame_for_estimation, session_exercise_type)

        if not results.pose_landmarks:
            socketio.emit('exercise_update', {'success': True, 'landmarks_detected': False, 'message': 'No landmarks detected in frame.'}, room=client_sid)
            return

        landmarks = results.pose_landmarks.landmark
        exercise_data_result = None
        if session_exercise_type == 'squat':
            exercise_data_result = exercise_session_obj.track_squat(landmarks, frame_width, frame_height)
        elif session_exercise_type == 'push_up':
            exercise_data_result = exercise_session_obj.track_push_up(landmarks, frame_width, frame_height)
        elif session_exercise_type == 'hammer_curl':
            exercise_data_result = exercise_session_obj.track_hammer_curl(landmarks, frame_width, frame_height)

        if exercise_data_result:
            socketio.emit('exercise_update', {'success': True, 'landmarks_detected': True, 'data': exercise_data_result}, room=client_sid)
        else:
            logger.warning(f"WebSocket: Exercise tracking for {session_exercise_type} returned no data for {client_sid}, despite landmarks detected.")
            socketio.emit('exercise_update', {'success': False, 'landmarks_detected': True, 'message': 'Could not process exercise data.'}, room=client_sid)
    except Exception as e:
        # Handler boundary: exc_info=True already records the traceback.
        logger.error(f"WebSocket: Error processing frame for {client_sid}, exercise {session_exercise_type}: {e}", exc_info=True)
        socketio.emit('frame_error', {'error': f'Internal server error during frame processing: {str(e)}'}, room=client_sid)
logger.info("Route and SocketIO handlers defined.")

if __name__ == '__main__':
    # Direct execution is meant for local development. Under Gunicorn the
    # worker class provides the async mode and this branch is not taken; the
    # async_mode passed to SocketIO() is what matters there.
    port = int(os.environ.get('PORT', 7860))  # Default to 7860 for HF Spaces consistency
    logger.info(f"Starting Flask-SocketIO app directly (for local development) using eventlet on host 0.0.0.0, port {port}")
    # debug=True can cause issues with some SocketIO setups and multiple
    # workers, so it stays off for production-like testing.
    run_options = {
        'host': '0.0.0.0',
        'port': port,
        'debug': False,
        'use_reloader': False,
    }
    socketio.run(app, **run_options)

logger.info("App.py script finished executing top-level statements (this line may not be reached if server is running).")