Spaces:
Sleeping
Sleeping
Sean Carnahan
Patch for Hugging Face Spaces: fix matplotlib config, check .gitignore, prep for model file inclusion
f20fe1f | from flask import Flask, render_template, request, jsonify, send_from_directory, url_for | |
| from flask_cors import CORS | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| import os | |
| from werkzeug.utils import secure_filename | |
| import sys | |
| import traceback | |
| from tensorflow.keras.models import load_model | |
| from tensorflow.keras.preprocessing import image | |
# Make the in-repo ``bodybuilding_pose_analyzer`` package importable.
# Appending '.' only works when the process is started from the directory
# that contains this file, so anchor on the file's own location instead.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
if _APP_DIR not in sys.path:
    sys.path.append(_APP_DIR)
from bodybuilding_pose_analyzer.src.movenet_analyzer import MoveNetAnalyzer
from bodybuilding_pose_analyzer.src.pose_analyzer import PoseAnalyzer

app = Flask(__name__, static_url_path='/static', static_folder='static')
CORS(app, resources={r"/*": {"origins": "*"}})
# Absolute path. Uploads/outputs are written with plain file APIs (relative
# to the CWD), but served via send_from_directory (relative to app.root_path).
# Anchoring on the app directory keeps both pointing at the same folder.
app.config['UPLOAD_FOLDER'] = os.path.join(_APP_DIR, 'static', 'uploads')
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size
try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
except PermissionError as e:
    # Best-effort: keep starting up (e.g. on read-only hosts such as HF Spaces).
    # Say so, though, because saving uploads/outputs there will fail later.
    print(f"[WARN] Could not create upload folder {app.config['UPLOAD_FOLDER']}: {e}")

# CNN model for bodybuilding pose classification, loaded once at import time.
# Raises at startup if the .h5 file is missing.
cnn_model_path = os.path.join(
    _APP_DIR, 'external', 'BodybuildingPoseClassifier', 'bodybuilding_pose_classifier.h5'
)
cnn_model = load_model(cnn_model_path)
# Output index -> label mapping used by predict_pose_cnn (order must match the model's outputs).
cnn_class_labels = ['Side Chest', 'Front Double Biceps', 'Back Double Biceps', 'Front Lat Spread', 'Back Lat Spread']
def predict_pose_cnn(img_path):
    """Classify the bodybuilding pose shown in an image file with the CNN.

    The image is loaded resized to 150x150, divided by 255 and fed to
    ``cnn_model`` as a batch of one.

    Args:
        img_path: Path of an image file readable by Keras' ``load_img``.

    Returns:
        ``(label, confidence)``: the entry of ``cnn_class_labels`` with the
        highest model score, and that score as a Python float.
    """
    img = image.load_img(img_path, target_size=(150, 150))
    batch = np.expand_dims(image.img_to_array(img), axis=0) / 255.0
    # verbose=0: this is called once per ~4 s of every processed video. Keras'
    # default progress bar would otherwise be printed to the log on each call.
    scores = cnn_model.predict(batch, verbose=0)[0]
    best_index = int(np.argmax(scores))
    return cnn_class_labels[best_index], float(scores[best_index])
def serve_video(filename):
    """Serve a file from UPLOAD_FOLDER inline (not as a download).

    Responses for ``.mp4`` files get an explicit ``video/mp4`` Content-Type.
    """
    # NOTE(review): the process_video_* functions build URLs with
    # url_for('serve_video', ...), so this must be registered as a route
    # endpoint. No route decorator is visible in this view; confirm it exists.
    resp = send_from_directory(app.config['UPLOAD_FOLDER'], filename, as_attachment=False)
    # Ensure correct content type, especially for Safari/iOS playback.
    is_mp4 = filename.lower().endswith('.mp4')
    if is_mp4:
        resp.headers['Content-Type'] = 'video/mp4'
    return resp
def after_request(response):
    """Make sure permissive CORS headers are present on ``response``; return it.

    Uses ``setdefault`` instead of ``add``. flask_cors (``CORS(app, ...)``
    configured for all routes) already sets these headers. Adding a second
    ``Access-Control-Allow-Origin`` value produces a duplicated header that
    browsers reject. Values that are already present are left untouched.
    """
    # NOTE(review): no @app.after_request decorator is visible here. Without
    # registration this hook never runs; confirm it is registered.
    headers = response.headers
    headers.setdefault('Access-Control-Allow-Origin', '*')
    headers.setdefault('Access-Control-Allow-Headers', 'Content-Type,Authorization,X-Requested-With,Accept')
    headers.setdefault('Access-Control-Allow-Methods', 'GET,PUT,POST,DELETE,OPTIONS')
    return response
def process_video_movenet(video_path, model_variant='lightning', pose_type='front_double_biceps'):
    """Annotate a video with MoveNet pose analysis and write it to UPLOAD_FOLDER.

    Each frame goes through ``MoveNetAnalyzer.process_frame``. The returned
    frame is then overlaid with the model name, the best pose guess, joint
    angles and corrections from ``analyze_pose``.

    Every ``segment_length`` frames (about 4 s of video) the current pose is
    re-classified: first with ``analyzer.classify_pose`` on the landmarks, then
    with the CNN (``predict_pose_cnn``). A CNN result with confidence >= 0.3
    overrides the rule-based one.

    Args:
        video_path: Path of the input video on disk.
        model_variant: MoveNet variant name forwarded to ``MoveNetAnalyzer``.
            It is also embedded (sanitized) in the output filename.
        pose_type: Pose analysed until the first classification happens.

    Returns:
        Relative URL of the annotated video, built with
        ``url_for('serve_video', ...)``.

    Raises:
        FileNotFoundError: ``video_path`` does not exist.
        ValueError: the input cannot be opened, the output writer cannot be
            created, or no frames were read.
    """
    import tempfile  # local import: per-request scratch files for CNN snapshots

    cap = None
    out = None
    try:
        print(f"[PROCESS_VIDEO_MOVENET] Called with video_path: {video_path}, model_variant: {model_variant}, pose_type: {pose_type}")
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found: {video_path}")
        analyzer = MoveNetAnalyzer(model_name=model_variant)
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Failed to open video file: {video_path}")
        # CAP_PROP_FPS is fractional (e.g. 29.97) and may be 0/NaN for some
        # containers. Keep it fractional so the output keeps the input's
        # duration. Fall back to 30 fps when it is unknown.
        reported_fps = cap.get(cv2.CAP_PROP_FPS)
        fps = reported_fps if reported_fps > 0 else 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"Processing video with MoveNet ({model_variant}): {width}x{height} @ {fps}fps")
        # model_variant comes from the request form: strip path separators etc.
        output_filename = secure_filename(f'output_movenet_{model_variant}.mp4')
        output_path = os.path.join(app.config['UPLOAD_FOLDER'], output_filename)
        fourcc = cv2.VideoWriter_fourcc(*'avc1')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        # An unopened writer silently drops every frame; fail loudly instead.
        if not out.isOpened():
            raise ValueError(f"Failed to create output video: {output_path}")

        segment_length = max(1, int(round(4 * fps)))  # ~4 seconds worth of frames
        frame_count = 0
        current_pose = pose_type  # e.g. 'front_double_biceps' until first classification
        last_valid_landmarks = None
        cnn_pose_pred, cnn_conf = None, 0.0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            # Detect pose and get landmarks, reusing last valid landmarks if needed.
            frame_with_pose, _landmarks_analysis, landmarks = analyzer.process_frame(
                frame, current_pose, last_valid_landmarks=last_valid_landmarks)
            if landmarks:
                last_valid_landmarks = landmarks

            # Every segment (incl. the first frame): rule-based then CNN classification.
            if (frame_count - 1) % segment_length == 0:
                if landmarks:
                    detected_pose = analyzer.classify_pose(landmarks)
                    print(f"[AUTO-POSE] Frame {frame_count}: Detected pose: {detected_pose}")
                    current_pose = detected_pose
                else:
                    print(f"[AUTO-POSE] Frame {frame_count}: No landmarks detected, keeping previous pose: {current_pose}")
                # Unique temp file in the system temp dir: concurrent requests
                # can't clobber each other, and the CWD may be read-only.
                fd, temp_img_path = tempfile.mkstemp(prefix='cnn_frame_', suffix='.jpg')
                os.close(fd)
                try:
                    cv2.imwrite(temp_img_path, frame)
                    cnn_pose_pred, cnn_conf = predict_pose_cnn(temp_img_path)
                    print(f"[CNN] Frame {frame_count}: Pose: {cnn_pose_pred}, Conf: {cnn_conf:.2f}")
                    if cnn_conf >= 0.3:
                        # NOTE(review): CNN labels are Title Case ('Front Double Biceps')
                        # while pose_type defaults to snake_case ('front_double_biceps').
                        # Confirm analyzer.analyze_pose accepts the CNN form.
                        current_pose = cnn_pose_pred
                except Exception as e:
                    # CNN is best-effort: a failure just means no override this segment.
                    print(f"[CNN] Error predicting pose: {e}")
                    cnn_pose_pred, cnn_conf = None, 0.0
                finally:
                    if os.path.exists(temp_img_path):
                        os.remove(temp_img_path)

            # Label shown on the frame: confident CNN > rule-based > 'Uncertain'.
            if cnn_conf >= 0.3:
                best_pose = cnn_pose_pred
            elif landmarks:
                best_pose = analyzer.classify_pose(landmarks)
            else:
                best_pose = 'Uncertain'

            # Analyze using the current pose.
            analysis = analyzer.analyze_pose(landmarks, current_pose) if landmarks else {'error': 'No pose detected'}

            # Overlay results.
            y_offset = 90
            if 'error' not in analysis:
                display_model_name = f"Gladiator {model_variant.capitalize()}"
                cv2.putText(frame_with_pose, f"Model: {display_model_name}",
                            (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
                cv2.putText(frame_with_pose, f"Gladiator Pose: {best_pose}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
                for joint, angle in analysis.get('angles', {}).items():
                    text_to_display = f"{joint.capitalize()}: {angle:.1f} deg"
                    cv2.putText(frame_with_pose, text_to_display,
                                (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
                    y_offset += 25
                for correction in analysis.get('corrections', []):
                    cv2.putText(frame_with_pose, correction,
                                (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
                    y_offset += 25
            else:
                cv2.putText(frame_with_pose, analysis['error'],
                            (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
            out.write(frame_with_pose)

        # Release before reporting so the output file is fully flushed.
        cap.release()
        out.release()
        if frame_count == 0:
            raise ValueError("No frames were processed from the video by MoveNet")
        print(f"MoveNet video processing completed. Processed {frame_count} frames. Output: {output_path}")
        return url_for('serve_video', filename=output_filename, _external=False)
    except Exception as e:
        print(f'Error in process_video_movenet: {e}')
        traceback.print_exc()
        raise
    finally:
        # Also runs on error paths. Releasing twice is a no-op in OpenCV.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
def process_video_mediapipe(video_path):
    """Annotate a video with MediaPipe (PoseAnalyzer) analysis and save it.

    Each frame goes through ``PoseAnalyzer.process_frame``. The returned frame
    is overlaid with the model name, the CNN pose guess, and joint angles and
    corrections (or the analysis error message). It is then written to
    ``UPLOAD_FOLDER/output_mediapipe.mp4``.

    The CNN (``predict_pose_cnn``) re-classifies the pose every
    ``segment_length`` frames (about 4 s of video). Its label is shown only
    when its confidence is >= 0.3, otherwise 'Uncertain'.

    Args:
        video_path: Path of the input video on disk.

    Returns:
        Relative URL of the annotated video, built with
        ``url_for('serve_video', ...)``.

    Raises:
        FileNotFoundError: ``video_path`` does not exist.
        ValueError: the input cannot be opened, the output writer cannot be
            created, or no frames were read.
    """
    import tempfile  # local import: per-request scratch files for CNN snapshots

    cap = None
    out = None
    try:
        print(f"[PROCESS_VIDEO_MEDIAPIPE] Called with video_path: {video_path}")
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found: {video_path}")
        analyzer = PoseAnalyzer()
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Failed to open video file: {video_path}")
        # CAP_PROP_FPS is fractional (e.g. 29.97) and may be 0/NaN for some
        # containers. Keep it fractional; fall back to 30 fps when unknown.
        reported_fps = cap.get(cv2.CAP_PROP_FPS)
        fps = reported_fps if reported_fps > 0 else 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"Processing video with MediaPipe: {width}x{height} @ {fps}fps")
        output_filename = 'output_mediapipe.mp4'
        output_path = os.path.join(app.config['UPLOAD_FOLDER'], output_filename)
        fourcc = cv2.VideoWriter_fourcc(*'avc1')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        # An unopened writer silently drops every frame; fail loudly instead.
        if not out.isOpened():
            raise ValueError(f"Failed to create output video: {output_path}")

        segment_length = max(1, int(round(4 * fps)))  # ~4 seconds worth of frames
        frame_count = 0
        last_valid_landmarks = None
        cnn_pose, cnn_conf = None, 0.0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            # Detect pose and analyze, reusing last valid landmarks if needed.
            frame_with_pose, analysis, landmarks = analyzer.process_frame(frame, last_valid_landmarks=last_valid_landmarks)
            if landmarks:
                last_valid_landmarks = landmarks

            # Every segment (incl. the first frame): classify the pose with the CNN.
            if (frame_count - 1) % segment_length == 0:
                # Unique temp file: the previous fixed name 'temp_frame_for_cnn.jpg'
                # was shared by all concurrent requests (and the CWD may be read-only).
                fd, temp_img_path = tempfile.mkstemp(prefix='cnn_frame_', suffix='.jpg')
                os.close(fd)
                try:
                    cv2.imwrite(temp_img_path, frame)
                    cnn_pose, cnn_conf = predict_pose_cnn(temp_img_path)
                    print(f"[CNN] Confidence: {cnn_conf:.3f} for pose: {cnn_pose}")
                except Exception as e:
                    # CNN is best-effort: a failure just shows 'Uncertain'.
                    print(f"[CNN] Error predicting pose: {e}")
                    cnn_pose, cnn_conf = None, 0.0
                finally:
                    if os.path.exists(temp_img_path):
                        os.remove(temp_img_path)

            best_pose = cnn_pose if cnn_conf >= 0.3 else 'Uncertain'

            # Overlay results.
            y_offset = 30
            cv2.putText(frame_with_pose, "Model: Gladiator SupaDot", (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
            y_offset += 30
            cv2.putText(frame_with_pose, f"Gladiator Pose: {best_pose}", (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
            y_offset += 30
            if 'error' not in analysis:
                for joint, angle in analysis.get('angles', {}).items():
                    text_to_display = f"{joint.capitalize()}: {angle:.1f} deg"
                    cv2.putText(frame_with_pose, text_to_display, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
                    y_offset += 25
                for correction in analysis.get('corrections', []):
                    cv2.putText(frame_with_pose, correction, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
                    y_offset += 25
            else:
                cv2.putText(frame_with_pose, analysis['error'], (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
            out.write(frame_with_pose)

        # Release before reporting so the output file is fully flushed.
        cap.release()
        out.release()
        if frame_count == 0:
            raise ValueError("No frames were processed from the video by MediaPipe")
        print(f"MediaPipe video processing completed. Processed {frame_count} frames. Output: {output_path}")
        return url_for('serve_video', filename=output_filename, _external=False)
    except Exception as e:
        print(f'Error in process_video_mediapipe: {e}')
        traceback.print_exc()
        raise
    finally:
        # Also runs on error paths. Releasing twice is a no-op in OpenCV.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
def index():
    """Render and return the ``index.html`` template."""
    # NOTE(review): no route decorator is visible in this view; confirm this
    # view function is registered.
    template_name = 'index.html'
    return render_template(template_name)
def upload_file():
    """Accept a video upload, run the selected pose model on it, return JSON.

    Reads the multipart field ``video`` plus the optional form fields
    ``model_choice`` ('movenet' selects MoveNet, anything else MediaPipe) and
    ``movenet_variant``. The uploaded input file is deleted afterwards.

    Returns:
        A Flask JSON response:
        - ``message``/``output_path`` on success (HTTP 200);
        - an ``error`` object with 400 for missing/invalid input;
        - an ``error`` object with 500 if processing or saving fails.
    """
    import uuid  # local import: unique per-request upload names

    try:
        if 'video' not in request.files:
            return jsonify({'error': 'No video file provided'}), 400
        file = request.files['video']
        # filename is '' when no file was chosen; treat None the same way.
        if not file.filename:
            return jsonify({'error': 'No selected file'}), 400
        allowed_extensions = {'mp4', 'avi', 'mov', 'mkv'}
        extension = file.filename.rsplit('.', 1)[1].lower() if '.' in file.filename else ''
        if extension not in allowed_extensions:
            return jsonify({'error': 'Invalid file format. Allowed formats: mp4, avi, mov, mkv'}), 400
        # Random, per-request name. With the client's name, two concurrent
        # uploads of "video.mp4" would share one path, and one request's cleanup
        # below could delete the other's input. The extension is one of the
        # validated values above, so the name is filesystem-safe.
        filename = f"upload_{uuid.uuid4().hex}.{extension}"
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)
        print(f"File saved to: {filepath}")
        try:
            model_choice = request.form.get('model_choice', 'Gladiator SupaDot')
            if model_choice == 'movenet':
                movenet_variant = request.form.get('movenet_variant', 'lightning')
                output_path_url = process_video_movenet(filepath, model_variant=movenet_variant)
            else:
                output_path_url = process_video_mediapipe(filepath)
            print(f"[DEBUG] Generated video URL for client: {output_path_url}")
            return jsonify({
                'message': f'Video processed successfully with {model_choice}',
                'output_path': output_path_url
            })
        except Exception as e:
            print(f"Error processing video: {e}")
            traceback.print_exc()
            return jsonify({'error': f'Error processing video: {str(e)}'}), 500
        finally:
            # Best-effort cleanup. An OSError raised here would otherwise
            # replace the response computed above.
            try:
                if os.path.exists(filepath):
                    os.remove(filepath)
            except OSError as cleanup_error:
                print(f"[WARN] Could not remove upload {filepath}: {cleanup_error}")
    except Exception as e:
        print(f"Error in upload_file: {e}")
        traceback.print_exc()
        return jsonify({'error': 'Internal server error'}), 500
if __name__ == "__main__":
    # debug=True while listening on 0.0.0.0 exposes the Werkzeug interactive
    # debugger to anyone who can reach the port (arbitrary code execution).
    # Make it opt-in via FLASK_DEBUG. The port defaults to 7860 and can be
    # overridden with PORT.
    debug_enabled = os.environ.get("FLASK_DEBUG", "0").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")), debug=debug_enabled)