# liveTrain / app.py
# Last commit: "Update app.py" (5aa9372, verified) by pjxcharya
import base64
import io
import logging
import os
import sys
import threading
import time
import traceback
import uuid

import cv2
import numpy as np
from flask import Flask, render_template, Response, request, jsonify, session, redirect, url_for
from flask_cors import CORS
from PIL import Image
# Root logging: DEBUG and above, one stream handler, timestamped records.
_log_handler = logging.StreamHandler()
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[_log_handler],
)
logger = logging.getLogger(__name__)
# The pose-estimation / exercise / feedback modules are mandatory: if any of
# them is missing, report it and terminate the process with exit status 1.
try:
    from pose_estimation.estimation import PoseEstimator
    from exercises.squat import Squat
    from exercises.hammer_curl import HammerCurl
    from exercises.push_up import PushUp
    from feedback.information import get_exercise_info
    from feedback.layout import layout_indicators
    from utils.draw_text_with_background import draw_text_with_background
except ImportError as import_err:
    logger.error(f"Failed to import required modules: {import_err}")
    traceback.print_exc()
    raise SystemExit(1)
else:
    logger.info("Successfully imported pose estimation modules")
# Persistence is optional: when the db-backed WorkoutLogger cannot be imported
# OR constructed, fall back to a no-op stand-in so the trainer still starts.
class DummyWorkoutLogger:
    """No-op replacement for WorkoutLogger: records nothing, reports empty stats."""

    def __init__(self):
        pass

    def log_workout(self, *args, **kwargs):
        """Discard the workout; returns an empty dict."""
        return {}

    def get_recent_workouts(self, *args, **kwargs):
        """Return no workouts."""
        return []

    def get_weekly_stats(self, *args, **kwargs):
        """Return empty weekly stats."""
        return {}

    def get_exercise_distribution(self, *args, **kwargs):
        """Return an empty exercise distribution."""
        return {}

    def get_user_stats(self, *args, **kwargs):
        """Return zeroed totals with the keys the dashboard reads."""
        return {'total_workouts': 0, 'total_exercises': 0, 'streak_days': 0}


try:
    from db.workout_logger import WorkoutLogger
except ImportError:
    logger.warning("WorkoutLogger import failed, creating dummy class")
    workout_logger = DummyWorkoutLogger()
else:
    try:
        workout_logger = WorkoutLogger()
    except Exception:
        # The constructor can fail at runtime (presumably while opening its
        # storage — TODO confirm); degrade instead of aborting start-up.
        logger.exception("WorkoutLogger initialization failed, using dummy logger")
        workout_logger = DummyWorkoutLogger()
    else:
        logger.info("Successfully initialized workout logger")
logger.info("Setting up Flask application")
app = Flask(__name__)
# Flask signs session cookies with this key. Take it from the environment so a
# deployment can supply a private value; the literal is only a development
# fallback (it is public in the source, so anyone could forge cookies with it).
app.secret_key = os.environ.get('FLASK_SECRET_KEY', 'fitness_trainer_secret_key')
CORS(app, origins="*", methods=["GET", "POST", "OPTIONS"], allow_headers=["Content-Type", "Authorization"])

# Lazily created PoseEstimator shared by the /api/* endpoints.
pose_estimator_api = None
# session_id -> exercise tracker (Squat / PushUp / HammerCurl) used by /api/track_exercise_stream.
active_exercise_sessions = {}
# Max number of concurrent sessions to avoid memory issues
MAX_SESSIONS = 100
def get_pose_estimator_api_instance():
    """Return the PoseEstimator shared by the API routes, creating it on first use.

    NOTE(review): the lazy initialisation is not guarded by a lock, so two
    concurrent first requests could each construct an estimator.
    """
    global pose_estimator_api
    if pose_estimator_api is not None:
        return pose_estimator_api
    logger.info("Initializing PoseEstimator for API")
    pose_estimator_api = PoseEstimator()
    return pose_estimator_api
# Global state shared by the MJPEG generator and the webcam control endpoints.
camera = None                 # cv2.VideoCapture opened by initialize_camera()
output_frame = None           # last annotated frame; written while holding `lock`
lock = threading.Lock()
exercise_running = False
current_exercise = None       # tracker instance created by start_exercise()
current_exercise_data = None  # {'type', 'sets', 'reps'} of the current exercise
exercise_counter = exercise_goal = 0
sets_completed = sets_goal = 0
workout_start_time = None     # time.time() taken when the workout started
def initialize_camera():
    """Open the default webcam (device 0) unless it is already open.

    Returns:
        The shared ``cv2.VideoCapture``, or ``None`` if the device could not be
        opened. A failed open is not cached, so the next call tries again.
    """
    global camera
    if camera is None:
        capture = cv2.VideoCapture(0)
        if not capture.isOpened():
            # Do not keep a dead handle in `camera`; otherwise every later call
            # would return it without ever retrying the device.
            logger.warning("Could not open camera device 0")
            capture.release()
            return None
        camera = capture
    return camera
def release_camera():
    """Release the shared webcam handle, if one is open, and reset it to None."""
    global camera
    if camera is None:
        return
    camera.release()
    camera = None
def generate_frames():
    """Yield the annotated webcam feed as multipart MJPEG chunks, forever.

    While an exercise is running, each frame goes through pose estimation. The
    current tracker updates the rep count, and sets/workout progress advance
    when the rep goal is reached. Otherwise a "Select an exercise" prompt is
    drawn. Every annotated frame is also stored in ``output_frame`` under ``lock``.

    Yields:
        bytes: one ``--frame`` multipart part containing a JPEG image.
    """
    global output_frame, lock, exercise_running, current_exercise, current_exercise_data
    global exercise_counter, exercise_goal, sets_completed, sets_goal
    # Back-off while no frame is available, so the loop does not spin a CPU core.
    idle_delay = 0.05
    pose_estimator = PoseEstimator()
    while True:
        # Snapshot the shared handle: stop_exercise() may release it and reset
        # the global to None from a request thread between the check and read().
        cam = camera
        if cam is None:
            time.sleep(idle_delay)
            continue
        success, frame = cam.read()
        if not success:
            time.sleep(idle_delay)
            continue

        # Only process frames if an exercise is running
        if exercise_running and current_exercise:
            exercise_type = current_exercise_data['type']
            results = pose_estimator.estimate_pose(frame, exercise_type)
            if results.pose_landmarks:
                landmarks = results.pose_landmarks.landmark
                # Track exercise based on type
                if exercise_type == "squat":
                    counter, angle, stage = current_exercise.track_squat(landmarks, frame)
                    layout_indicators(frame, exercise_type, (counter, angle, stage))
                    exercise_counter = counter
                elif exercise_type == "push_up":
                    counter, angle, stage = current_exercise.track_push_up(landmarks, frame)
                    layout_indicators(frame, exercise_type, (counter, angle, stage))
                    exercise_counter = counter
                elif exercise_type == "hammer_curl":
                    (counter_right, angle_right, counter_left, angle_left,
                     warning_message_right, warning_message_left, progress_right,
                     progress_left, stage_right, stage_left) = current_exercise.track_hammer_curl(
                        landmarks, frame)
                    layout_indicators(frame, exercise_type,
                                      (counter_right, angle_right, counter_left, angle_left,
                                       warning_message_right, warning_message_left,
                                       progress_right, progress_left, stage_right, stage_left))
                    # A rep counts once the busier arm has done it.
                    exercise_counter = max(counter_right, counter_left)

            # Display exercise information, one line every 30 px from y=50.
            exercise_info = get_exercise_info(exercise_type)
            overlay_lines = (
                f"Exercise: {exercise_info.get('name', 'N/A')}",
                f"Reps Goal: {exercise_goal}",
                f"Sets Goal: {sets_goal}",
                f"Current Set: {sets_completed + 1}",
            )
            for row, text in enumerate(overlay_lines):
                draw_text_with_background(frame, text, (40, 50 + 30 * row),
                                          cv2.FONT_HERSHEY_DUPLEX, 0.7, (255, 255, 255), (118, 29, 14), 1)

            # Check if rep goal is reached for current set
            if exercise_counter >= exercise_goal:
                sets_completed += 1
                exercise_counter = 0
                # Also zero the tracker's own counters; the tracker returns them
                # as the running count, so they would otherwise carry over.
                if exercise_type in ("squat", "push_up"):
                    current_exercise.counter = 0
                elif exercise_type == "hammer_curl":
                    current_exercise.counter_right = 0
                    current_exercise.counter_left = 0

                # Check if all sets are completed
                if sets_completed >= sets_goal:
                    exercise_running = False
                    draw_text_with_background(frame, "WORKOUT COMPLETE!",
                                              (frame.shape[1] // 2 - 150, frame.shape[0] // 2),
                                              cv2.FONT_HERSHEY_DUPLEX, 1.2, (255, 255, 255), (0, 200, 0), 2)
                else:
                    draw_text_with_background(frame, f"SET {sets_completed} COMPLETE! Rest for 30 sec",
                                              (frame.shape[1] // 2 - 200, frame.shape[0] // 2),
                                              cv2.FONT_HERSHEY_DUPLEX, 1.0, (255, 255, 255), (0, 0, 200), 2)
                    # TODO: a rest timer could be added here
        else:
            # Display welcome message if no exercise is running
            cv2.putText(frame, "Select an exercise to begin", (frame.shape[1] // 2 - 150, frame.shape[0] // 2),
                        cv2.FONT_HERSHEY_DUPLEX, 0.8, (255, 255, 255), 1)

        with lock:
            output_frame = frame.copy()

        # Encode the local frame rather than re-reading the shared `output_frame`
        # outside the lock, where another stream may already have replaced it.
        ok, buffer = cv2.imencode('.jpg', frame)
        if not ok:
            continue
        yield (b'--frame\r\n'
               b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')
@app.route('/')
def index():
    """Render the landing page where the user picks an exercise.

    On a template error, responds 500 with the error text.
    """
    logger.info("Rendering index page")
    try:
        page = render_template('index.html')
    except Exception as exc:
        logger.error(f"Error rendering index: {exc}")
        return f"Error rendering template: {str(exc)}", 500
    return page
@app.route('/dashboard')
def dashboard():
    """Render the statistics dashboard from the data held by ``workout_logger``.

    Any failure is logged with a traceback and answered with a 500 text response.
    """
    logger.info("Rendering dashboard page")
    try:
        recent_workouts = workout_logger.get_recent_workouts(5)
        weekly_stats = workout_logger.get_weekly_stats()
        # NOTE(review): fetched but never passed to the template.
        exercise_distribution = workout_logger.get_exercise_distribution()
        user_stats = workout_logger.get_user_stats()

        # Human-friendly rows: "push_up" -> "Push Up", duration as M:SS.
        formatted_workouts = [
            {
                'date': entry['date'],
                'exercise': entry['exercise_type'].replace('_', ' ').title(),
                'sets': entry['sets'],
                'reps': entry['reps'],
                'duration': "{}:{:02d}".format(*divmod(entry['duration_seconds'], 60)),
            }
            for entry in recent_workouts
        ]

        weekly_workout_count = sum(day['workout_count'] for day in weekly_stats.values())

        return render_template('dashboard.html',
                               recent_workouts=formatted_workouts,
                               weekly_workouts=weekly_workout_count,
                               total_workouts=user_stats['total_workouts'],
                               total_exercises=user_stats['total_exercises'],
                               streak_days=user_stats['streak_days'])
    except Exception as exc:
        logger.error(f"Error in dashboard: {exc}")
        traceback.print_exc()
        return f"Error loading dashboard: {str(exc)}", 500
@app.route('/video_feed')
def video_feed():
    """Stream the annotated webcam feed as a multipart MJPEG response."""
    frame_stream = generate_frames()
    return Response(frame_stream, mimetype='multipart/x-mixed-replace; boundary=frame')
@app.route('/start_exercise', methods=['POST'])
def start_exercise():
    """Start a webcam exercise from the user's selection.

    Request JSON:
        exercise_type: "squat", "push_up" or "hammer_curl".
        sets: optional positive integer, default 3.
        reps: optional positive integer (reps per set), default 10.

    Returns ``{'success': True}`` on success. An unknown exercise type returns
    ``{'success': False, 'error': 'Invalid exercise type'}`` (status 200, as
    before). A missing or invalid body, or bad sets/reps, returns the same shape
    with status 400. Invalid requests leave any running workout untouched.
    """
    global exercise_running, current_exercise, current_exercise_data
    global exercise_counter, exercise_goal, sets_completed, sets_goal
    global workout_start_time

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'error': 'Invalid or missing JSON body'}), 400

    trackers = {'squat': Squat, 'push_up': PushUp, 'hammer_curl': HammerCurl}
    exercise_type = data.get('exercise_type')
    tracker_cls = trackers.get(exercise_type) if isinstance(exercise_type, str) else None
    if tracker_cls is None:
        return jsonify({'success': False, 'error': 'Invalid exercise type'})

    try:
        requested_sets = int(data.get('sets', 3))
        requested_reps = int(data.get('reps', 10))
    except (TypeError, ValueError):
        return jsonify({'success': False, 'error': 'sets and reps must be integers'}), 400
    # A goal of 0 would make generate_frames() complete a set on every frame.
    if requested_sets < 1 or requested_reps < 1:
        return jsonify({'success': False, 'error': 'sets and reps must be at least 1'}), 400

    # Initialize camera if not already done
    initialize_camera()

    sets_goal = requested_sets
    exercise_goal = requested_reps
    exercise_counter = 0
    sets_completed = 0
    workout_start_time = time.time()

    current_exercise = tracker_cls()
    current_exercise_data = {
        'type': exercise_type,
        'sets': sets_goal,
        'reps': exercise_goal,
    }
    exercise_running = True
    return jsonify({'success': True})
@app.route('/stop_exercise', methods=['POST'])
def stop_exercise():
    """Stop the current exercise, log the workout once, and release the webcam.

    The workout is logged even if it already finished on its own: generate_frames()
    clears ``exercise_running`` after the last set, which previously skipped
    logging. ``workout_start_time`` is cleared after logging, so repeated stop
    calls do not log the same workout twice. The camera is released even if
    logging raises.

    Returns ``{'success': True}``.
    """
    global exercise_running, current_exercise_data, workout_start_time
    global exercise_counter, exercise_goal, sets_completed, sets_goal
    try:
        if current_exercise_data and workout_start_time is not None:
            duration = int(time.time() - workout_start_time)
            workout_logger.log_workout(
                exercise_type=current_exercise_data['type'],
                sets=sets_completed + (1 if exercise_counter > 0 else 0),  # Include partial set
                reps=exercise_goal,
                duration_seconds=duration,
            )
            workout_start_time = None
    finally:
        release_camera()
        exercise_running = False
    return jsonify({'success': True})
@app.route('/get_status', methods=['GET'])
def get_status():
    """Report live webcam-exercise progress as JSON.

    ``current_set`` is 1-based while an exercise runs and 0 otherwise.
    """
    # The module globals are only read here, so no `global` statement is needed.
    current_set = (sets_completed + 1) if exercise_running else 0
    status = {
        'exercise_running': exercise_running,
        'current_reps': exercise_counter,
        'current_set': current_set,
        'total_sets': sets_goal,
        'rep_goal': exercise_goal,
    }
    return jsonify(status)
@app.route('/profile')
def profile():
    """Placeholder for the user-profile page (not implemented yet)."""
    placeholder = "Profile page - Coming soon!"
    return placeholder
@app.route('/healthz')
def health_check():
    """Liveness probe: always answers ``"OK"`` with status 200."""
    logger.info("Health check endpoint called successfully.")
    body, status = "OK", 200
    return body, status
@app.route('/api/analyze_frame', methods=['POST'])
def analyze_frame():
    """Run pose estimation on one base64-encoded image and return its landmarks.

    Request JSON:
        image: base64 image string; a ``data:<mime>;base64,`` URL prefix is accepted.
        exercise_type: optional. Unknown values fall back to "squat"; per the
            original comment it only affects server-side drawing.

    Returns:
        200 ``{'success': True, 'landmarks': [...]}``; each landmark has
        ``index``, ``x``, ``y``, ``z`` and ``visibility`` (None when absent).
        The list is empty when no pose is found. 400 for a missing or
        undecodable image, 500 for unexpected errors.
    """
    logger.info("API call to /api/analyze_frame")
    try:
        # silent=True: a wrong Content-Type or malformed JSON yields None instead
        # of raising, so it is reported as a 400 below rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not isinstance(data.get('image'), str):
            logger.warning("API /api/analyze_frame: No image provided or invalid JSON.")
            return jsonify({'error': 'No image provided in JSON payload (expected base64 string under "image" key)'}), 400

        # Use 'squat' as a default if not provided. This primarily affects server-side drawing, not the landmarks themselves.
        exercise_type_for_api = data.get('exercise_type', 'squat')
        if exercise_type_for_api not in ["squat", "push_up", "hammer_curl"]:
            logger.warning(f"API /api/analyze_frame: Invalid exercise_type '{exercise_type_for_api}'. Defaulting to 'squat'.")
            exercise_type_for_api = "squat"

        image_data = data['image']
        # canvas.toDataURL() output carries a "data:image/...;base64," prefix. Its
        # letters are valid base64 characters and would corrupt the decoded bytes.
        if image_data.startswith('data:'):
            image_data = image_data.partition(',')[2]
        # Add padding if missing for base64 decoding
        missing_padding = len(image_data) % 4
        if missing_padding:
            image_data += '=' * (4 - missing_padding)
        try:
            image_bytes = base64.b64decode(image_data)
            # Normalise RGBA / greyscale / palette images to 3-channel RGB;
            # COLOR_RGB2BGR rejects any other channel count.
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
            frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        except Exception as e:
            logger.error(f"API /api/analyze_frame: Error decoding base64 image: {e}")
            return jsonify({'error': f'Invalid base64 image data: {str(e)}'}), 400

        estimator = get_pose_estimator_api_instance()
        results = estimator.estimate_pose(frame, exercise_type_for_api)
        if not results.pose_landmarks:
            logger.info(f"API /api/analyze_frame: No landmarks detected in the provided image for exercise type '{exercise_type_for_api}'.")
            return jsonify({'success': True, 'landmarks': []})

        landmarks_list = [
            {
                'index': i,
                'x': landmark.x,
                'y': landmark.y,
                'z': landmark.z,
                'visibility': getattr(landmark, 'visibility', None),
            }
            for i, landmark in enumerate(results.pose_landmarks.landmark)
        ]
        logger.info(f"API /api/analyze_frame: Successfully processed image, found {len(landmarks_list)} landmarks for exercise type '{exercise_type_for_api}'.")
        return jsonify({'success': True, 'landmarks': landmarks_list})
    except Exception as e:
        logger.error(f"API /api/analyze_frame: Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Internal server error: {str(e)}'}), 500
@app.route('/api/track_exercise_stream', methods=['POST'])
def track_exercise_stream():
    """Feed one base64-encoded video frame to a per-client exercise tracker.

    Request JSON (all required):
        session_id: client-chosen key into ``active_exercise_sessions``.
        exercise_type: "squat", "push_up" or "hammer_curl".
        image: base64 image string; a ``data:<mime>;base64,`` URL prefix is accepted.
        frame_width, frame_height: positive ints forwarded to the tracker.

    Returns:
        200 with ``landmarks_detected: False`` when no pose is found.
        200 with ``landmarks_detected: True`` and the tracker output in ``data``.
        400 on invalid input.
        503 when ``MAX_SESSIONS`` trackers exist and this is a new session.
        500 on unexpected errors.
    """
    logger.info("API call to /api/track_exercise_stream")
    # exercise_type -> (tracker class, name of its per-frame tracking method)
    trackers = {
        'squat': (Squat, 'track_squat'),
        'push_up': (PushUp, 'track_push_up'),
        'hammer_curl': (HammerCurl, 'track_hammer_curl'),
    }
    data = {}  # Initialize data to ensure it's defined for logging in case of early errors
    try:
        # silent=True: a wrong Content-Type or malformed JSON yields None -> 400.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({'error': 'No JSON data provided'}), 400

        session_id = data.get('session_id')
        exercise_type = data.get('exercise_type')
        image_data_base64 = data.get('image')
        frame_width = data.get('frame_width')
        frame_height = data.get('frame_height')

        if not all([session_id, exercise_type, image_data_base64, frame_width, frame_height]):
            return jsonify({'error': 'Missing required fields: session_id, exercise_type, image, frame_width, frame_height'}), 400
        if not isinstance(frame_width, int) or not isinstance(frame_height, int) or frame_width <= 0 or frame_height <= 0:
            return jsonify({'error': 'Invalid frame_width or frame_height'}), 400
        if not isinstance(image_data_base64, str):
            return jsonify({'error': 'Invalid base64 image data: expected a string'}), 400

        # Validate the type on every request, not only when creating a session;
        # otherwise an unknown type on an existing session surfaced as a 500.
        tracker = trackers.get(exercise_type) if isinstance(exercise_type, str) else None
        if tracker is None:
            logger.warning(f"Invalid exercise type: {exercise_type} for session {session_id}")
            return jsonify({'error': 'Invalid exercise_type'}), 400
        tracker_cls, track_method_name = tracker

        # Manage session limit
        if len(active_exercise_sessions) >= MAX_SESSIONS and session_id not in active_exercise_sessions:
            logger.warning(f"Max sessions ({MAX_SESSIONS}) reached. Rejecting new session {session_id}.")
            # Optional: could try to evict oldest session
            return jsonify({'error': 'Server busy, max sessions reached. Please try again later.'}), 503

        # Get or create exercise session object
        exercise_session = active_exercise_sessions.get(session_id)
        if exercise_session is None:
            logger.info(f"Creating new session {session_id} for exercise {exercise_type}")
            exercise_session = tracker_cls()
            active_exercise_sessions[session_id] = exercise_session
        elif not isinstance(exercise_session, tracker_cls):
            # The client switched exercises under the same session id. The old
            # tracker lacks the new track_* method, so start a fresh one.
            logger.info(f"Session {session_id} switched to exercise {exercise_type}; starting a new tracker")
            exercise_session = tracker_cls()
            active_exercise_sessions[session_id] = exercise_session

        # Decode image: strip a data-URL prefix, restore padding, force 3-channel RGB.
        if image_data_base64.startswith('data:'):
            image_data_base64 = image_data_base64.partition(',')[2]
        missing_padding = len(image_data_base64) % 4
        if missing_padding:
            image_data_base64 += '=' * (4 - missing_padding)
        try:
            image_bytes = base64.b64decode(image_data_base64)
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
            frame_for_estimation = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)  # For PoseEstimator
        except Exception as e:
            logger.error(f"Session {session_id}: Error decoding base64 image: {e}")
            return jsonify({'error': f'Invalid base64 image data: {str(e)}'}), 400

        pose_estimator = get_pose_estimator_api_instance()
        # 'squat' only selects PoseEstimator's internal drawing logic, which the
        # API response does not use; tracking is done by `exercise_session`.
        results = pose_estimator.estimate_pose(frame_for_estimation, 'squat')
        if not results.pose_landmarks:
            logger.info(f"Session {session_id}: No landmarks detected.")
            return jsonify({
                'success': True,
                'landmarks_detected': False,
                'message': 'No landmarks detected in this frame.',
            })

        track = getattr(exercise_session, track_method_name)
        exercise_data = track(results.pose_landmarks.landmark, frame_width, frame_height)
        if exercise_data:
            logger.debug(f"Session {session_id}: Exercise data: {exercise_data}")
            return jsonify({'success': True, 'landmarks_detected': True, 'data': exercise_data})
        logger.error(f"Session {session_id}: Could not get exercise_data for {exercise_type}")
        return jsonify({'error': 'Failed to process exercise frame.'}), 500
    except Exception as e:
        session_id_log = data.get('session_id', 'unknown_session') if isinstance(data, dict) else 'unknown_session'
        logger.error(f"API /api/track_exercise_stream Error for session {session_id_log}: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Internal server error: {str(e)}'}), 500
@app.route('/api/end_exercise_session', methods=['POST'])
def end_exercise_session():
    """Drop the tracker stored in ``active_exercise_sessions`` for ``session_id``.

    Request JSON: ``session_id``.

    Returns:
        200 with ``success: True`` when the session existed.
        404 when it did not.
        400 for a missing or invalid body or session_id.
        500 for unexpected errors.
    """
    logger.info("API call to /api/end_exercise_session")
    try:
        # silent=True: a wrong Content-Type or malformed JSON yields None -> 400.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({'error': 'No JSON data provided'}), 400
        session_id = data.get('session_id')
        if not session_id:
            return jsonify({'error': 'Missing session_id'}), 400

        # A single pop() instead of `in` + `del`: two concurrent end requests can
        # no longer race into a KeyError.
        if active_exercise_sessions.pop(session_id, None) is None:
            logger.warning(f"Attempted to end non-existent session: {session_id}")
            return jsonify({'success': False, 'message': f'Session {session_id} not found.'}), 404
        logger.info(f"Ended and removed session: {session_id}")
        return jsonify({'success': True, 'message': f'Session {session_id} ended.'})
    except Exception as e:
        logger.error(f"API /api/end_exercise_session Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Internal server error: {str(e)}'}), 500
if __name__ == '__main__':
    # Development entry point.
    # NOTE(review): debug=True enables Flask's reloader and interactive debugger.
    # Never expose this server on a public interface.
    startup_banner = (
        "Starting Fitness Trainer app, please wait...",
        "Open http://127.0.0.1:5000 in your web browser when the server starts",
    )
    try:
        logger.info("Starting the Flask application on http://127.0.0.1:5000")
        for banner_line in startup_banner:
            print(banner_line)
        app.run(debug=True)
    except Exception as exc:
        logger.error(f"Failed to start application: {exc}")
        traceback.print_exc()