# Provenance: Hugging Face Space (page showed "Runtime error" at capture time).
# Commit a26c11a by Pranav Mishra:
#   "Fix performance issues: updated requirements with exact versions and added model caching"
| """ | |
| Audio Digit Classification API for Hugging Face Spaces | |
| Backend API for spoken digit recognition (0-9) - HF Spaces deployment | |
| """ | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| import os | |
| import time | |
| import logging | |
| from typing import Dict, Any, Optional | |
| from dotenv import load_dotenv | |
| import numpy as np | |
| # Import audio processors (only essential ones for deployment) | |
| from audio_processors.external_api import ExternalAPIProcessor | |
| from audio_processors.whisper_digit_processor import WhisperDigitProcessor | |
| from audio_processors.ml_mfcc_processor import MLMFCCProcessor | |
| from audio_processors.ml_mel_cnn_processor import MLMelCNNProcessor | |
| from audio_processors.ml_raw_cnn_processor import MLRawCNNProcessor | |
| # Import utilities | |
| from utils.audio_utils import validate_audio_format, convert_audio_format, get_audio_duration, convert_for_ml_models | |
| from utils.logging_utils import performance_logger, setup_flask_logging | |
# Load environment variables from a local .env file (no-op if absent)
load_dotenv()

# Initialize Flask app
app = Flask(__name__)
# SECURITY NOTE(review): the fallback secret key is a hard-coded constant;
# set SECRET_KEY in the environment for any real deployment.
app.secret_key = os.getenv('SECRET_KEY', 'hf_spaces_deployment_key')

# Enable CORS for frontend requests from Vercel.
# SECURITY NOTE(review): origins=['*'] lets any site call this API;
# restrict to the actual Vercel domain in production.
CORS(app, origins=['*'])  # In production, specify your Vercel domain

# Setup logging
setup_flask_logging(app)
# Configuration for HF Spaces
MAX_AUDIO_DURATION = 10  # seconds
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
ALLOWED_EXTENSIONS = {'wav', 'mp3', 'ogg', 'm4a', 'webm'}


def allowed_file(filename: str) -> bool:
    """Return True when *filename* carries an extension in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
# Module-level cache so the (expensive) model loads happen once per process
_processor_cache = {}


def initialize_processors():
    """Initialize audio processors optimized for HF Spaces deployment with caching."""
    global _processor_cache

    # A non-empty cache means a previous call already loaded everything.
    if _processor_cache:
        app.logger.info(f"Using cached processors: {len(_processor_cache)} available")
        return _processor_cache

    loaded = {}

    # ML-trained processors (high priority - use best models only)
    ml_candidates = (
        ('ml_mfcc', MLMFCCProcessor, 'ML MFCC + Dense NN (Best - 98.52%)'),
        ('ml_mel_cnn', MLMelCNNProcessor, 'ML Mel CNN (Good - 97.22%)'),
        ('ml_raw_cnn', MLRawCNNProcessor, 'ML Raw CNN (Fair - 91.30%)'),
    )
    ml_ok = 0
    for key, cls, label in ml_candidates:
        try:
            app.logger.info(f"Loading {label}...")
            candidate = cls()
            if candidate.is_configured():
                loaded[key] = candidate
                ml_ok += 1
                app.logger.info(f"[OK] {label} loaded successfully (cached)")
            else:
                app.logger.warning(f"[WARN] {label} not configured (model files missing)")
        except Exception as e:
            # A single failed model must not prevent the others from loading.
            app.logger.error(f"[FAIL] Failed to initialize {label}: {str(e)}")

    # External API processor as fallback
    try:
        external = ExternalAPIProcessor()
        if external.is_configured():
            loaded['external_api'] = external
            app.logger.info("[OK] External API processor initialized (cached)")
        else:
            app.logger.warning("[WARN] External API not configured")
    except Exception as e:
        app.logger.error(f"[FAIL] Failed to initialize External API: {str(e)}")

    # Whisper digit processor as another fallback
    try:
        whisper = WhisperDigitProcessor()
        if whisper.is_configured():
            loaded['whisper_digit'] = whisper
            app.logger.info("[OK] Whisper digit processor initialized (cached)")
    except Exception as e:
        app.logger.error(f"[FAIL] Failed to initialize Whisper: {str(e)}")

    # Publish the result to the module-level cache before returning.
    _processor_cache = loaded
    app.logger.info(f"Processor initialization complete:")
    app.logger.info(f" ML Models loaded: {ml_ok}/3")
    app.logger.info(f" Total processors cached: {len(loaded)}")
    return loaded


# Initialize processors on startup (cached globally)
processors = initialize_processors()
def index():
    """API status endpoint.

    NOTE(review): no @app.route decorator is visible on this handler — it may
    have been lost when the file was extracted; confirm route registration.
    """
    status_payload = {
        'message': 'Streaming Digit Classifier API',
        'status': 'running',
        'version': '1.0.0',
        'available_processors': list(processors.keys()),
        'documentation': 'Frontend at Vercel, Backend API at HF Spaces',
    }
    return jsonify(status_payload)
def process_audio():
    """
    Process audio file with selected method and return digit prediction.
    Expects multipart form data with 'audio' file and 'method' selection.

    Returns the processor's prediction dict (plus duration/size metadata) as
    JSON on success, or an ``{'error': ...}`` payload with a 4xx/5xx status.

    NOTE(review): no @app.route decorator is visible on this handler — it may
    have been lost during extraction; confirm route registration.
    """
    try:
        # Validate request: both the file part and the method field are required.
        if 'audio' not in request.files:
            return jsonify({'error': 'No audio file provided'}), 400
        if 'method' not in request.form:
            return jsonify({'error': 'No processing method specified'}), 400

        audio_file = request.files['audio']
        method = request.form['method']

        # Validate audio file: an empty filename means no file was selected.
        if audio_file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        if not allowed_file(audio_file.filename):
            return jsonify({'error': 'Unsupported file format'}), 400

        # Validate method against the processors loaded at startup.
        if method not in processors:
            return jsonify({'error': f'Unknown processing method: {method}'}), 400

        # Read audio data fully into memory (bounded by the size check below).
        audio_data = audio_file.read()

        # Check file size
        if len(audio_data) > MAX_FILE_SIZE:
            return jsonify({'error': 'Audio file too large'}), 400

        # Convert to standard format
        try:
            app.logger.debug(f"Converting audio format. Original size: {len(audio_data)} bytes")
            standardized_audio = convert_audio_format(audio_data)
            app.logger.debug(f"Converted audio size: {len(standardized_audio)} bytes")
        except Exception as e:
            app.logger.error(f"Audio conversion failed: {str(e)}")
            return jsonify({'error': 'Failed to process audio format - unsupported format or corrupted file'}), 400

        # Check audio duration (measured on the converted audio, not the upload).
        duration = get_audio_duration(standardized_audio)
        if duration > MAX_AUDIO_DURATION:
            return jsonify({
                'error': f'Audio too long: {duration:.1f}s (max: {MAX_AUDIO_DURATION}s)'
            }), 400
        if duration < 0.1:
            return jsonify({'error': 'Audio too short (minimum: 0.1s)'}), 400

        # Log audio input info
        performance_logger.log_audio_info(duration, {
            'filename': audio_file.filename,
            'size_bytes': len(audio_data),
            'converted_size': len(standardized_audio),
            'method': method
        })

        # Process with selected method
        processor = processors[method]
        result = processor.predict_with_timing(standardized_audio)

        # Log performance
        performance_logger.log_prediction(method, result)

        # Add additional metadata to the processor's result dict before returning.
        result.update({
            'audio_duration': round(duration, 3),
            'file_size': len(audio_data),
            'api_version': '1.0.0'
        })

        # NOTE(review): assumes every processor result contains
        # 'predicted_digit' and 'inference_time' keys — confirm against the
        # processor implementations.
        app.logger.info(f"Processed audio with {method}: '{result['predicted_digit']}' in {result['inference_time']}s")
        return jsonify(result)

    except Exception as e:
        # Catch-all boundary: log the failure and return a generic 500
        # without leaking internals to the client.
        app.logger.error(f"Audio processing error: {str(e)}")
        return jsonify({
            'error': 'Internal processing error',
            'success': False,
            'timestamp': time.time()
        }), 500
def process_audio_chunk():
    """
    Process streaming audio chunk for real-time digit recognition.

    Expects multipart form data with an 'audio' chunk; 'method' is optional
    and defaults to 'ml_mfcc'. Returns a streaming-shaped JSON envelope whose
    'results' list always contains exactly one prediction.

    NOTE(review): no @app.route decorator is visible on this handler — it may
    have been lost during extraction; confirm route registration.
    """
    try:
        # Validate request
        if 'audio' not in request.files:
            return jsonify({'error': 'No audio chunk provided'}), 400

        audio_file = request.files['audio']
        method = request.form.get('method', 'ml_mfcc')  # Default to best ML model

        # Validate method
        if method not in processors:
            return jsonify({'error': f'Unknown processing method: {method}'}), 400

        # Read audio data
        audio_data = audio_file.read()

        # Check chunk size: reject oversized chunks, and fragments under 100
        # bytes which cannot hold meaningful audio.
        if len(audio_data) > MAX_FILE_SIZE:
            return jsonify({'error': 'Audio chunk too large'}), 400
        if len(audio_data) < 100:
            return jsonify({'error': 'Audio chunk too small'}), 400

        # Convert to standardized format ('streaming' selects the streaming
        # conversion path in convert_for_ml_models — see utils.audio_utils).
        try:
            standardized_audio = convert_for_ml_models(audio_data, 'streaming')
        except Exception as e:
            app.logger.error(f"Audio conversion failed for chunk: {str(e)}")
            return jsonify({'error': 'Failed to process audio chunk format'}), 400

        # Process audio chunk
        processor = processors[method]
        result = processor.predict_with_timing(standardized_audio)

        # Add streaming metadata; this endpoint treats each chunk as a single
        # segment, so segment_index is always 0.
        result.update({
            'segment_index': 0,
            'segment_size': len(standardized_audio),
            'is_streaming': True,
            'api_version': '1.0.0'
        })

        # NOTE(review): assumes 'predicted_digit' and 'inference_time' keys in
        # the processor result — confirm against processor implementations.
        app.logger.info(f"Streaming prediction: '{result['predicted_digit']}' "
                        f"(Inference: {result['inference_time']}s)")

        return jsonify({
            'success': True,
            'segments_detected': 1,
            'total_results': 1,
            'results': [result],
            'timestamp': time.time(),
            'has_fallback': False
        })

    except Exception as e:
        # Catch-all boundary for unexpected failures during streaming.
        app.logger.error(f"Streaming audio processing error: {str(e)}")
        return jsonify({
            'error': 'Internal streaming processing error',
            'success': False,
            'timestamp': time.time()
        }), 500
def get_processors():
    """Get information about available processors.

    NOTE(review): no @app.route decorator is visible on this handler — it may
    have been lost during extraction; confirm route registration.
    """
    try:
        details = {}
        for key, proc in processors.items():
            entry = {
                'name': proc.name,
                'method': key,
                'configured': getattr(proc, 'is_configured', lambda: True)()
            }
            # Merge in model-specific info when the processor exposes it.
            model_info = getattr(proc, 'get_model_info', None)
            if model_info is not None:
                entry.update(model_info())
            details[key] = entry
        return jsonify(details)
    except Exception as e:
        app.logger.error(f"Error getting processors: {str(e)}")
        return jsonify({'error': 'Failed to retrieve processor information'}), 500
def health_check():
    """Health check endpoint.

    NOTE(review): no @app.route decorator is visible on this handler — it may
    have been lost during extraction; confirm route registration.
    """
    try:
        # Report configuration state for every processor loaded at startup.
        processor_health = {
            key: {
                'available': True,
                'configured': getattr(proc, 'is_configured', lambda: True)()
            }
            for key, proc in processors.items()
        }
        return jsonify({
            'status': 'healthy',
            'timestamp': time.time(),
            'processors': processor_health,
            'version': '1.0.0',
            'deployment': 'huggingface-spaces'
        })
    except Exception as e:
        app.logger.error(f"Health check failed: {str(e)}")
        return jsonify({
            'status': 'unhealthy',
            'error': str(e),
            'timestamp': time.time()
        }), 500
def not_found_error(error):
    """Handle 404 errors with a JSON body.

    NOTE(review): expected an @app.errorhandler(404) decorator here; it may
    have been dropped during extraction — confirm registration.
    """
    body = {'error': 'Endpoint not found', 'status': 404}
    return jsonify(body), 404
def internal_error(error):
    """Handle 500 errors: log the cause, return a JSON body.

    NOTE(review): expected an @app.errorhandler(500) decorator here; it may
    have been dropped during extraction — confirm registration.
    """
    app.logger.error(f"Internal error: {str(error)}")
    body = {'error': 'Internal server error', 'status': 500}
    return jsonify(body), 500
def too_large_error(error):
    """Handle file too large errors with a JSON body.

    NOTE(review): expected an @app.errorhandler(413) decorator here; it may
    have been dropped during extraction — confirm registration.
    """
    body = {'error': 'File too large', 'status': 413}
    return jsonify(body), 413
if __name__ == '__main__':
    import sys  # stdlib; local import keeps the top-of-file import block untouched

    # Log startup information. Flask's version is looked up via package
    # metadata; fall back to 'unknown' if flask's distribution is not found.
    try:
        import importlib.metadata
        flask_version = importlib.metadata.version('flask')
    except Exception:
        # Was a bare `except:` — that would also swallow SystemExit and
        # KeyboardInterrupt; Exception keeps the best-effort behavior safely.
        flask_version = 'unknown'

    performance_logger.log_system_info({
        # Was `os.sys.version`: `os.sys` is an undocumented alias that only
        # works because the os module happens to import sys internally.
        'python_version': sys.version,
        'flask_version': flask_version,
        'processors_loaded': list(processors.keys()),
        'max_audio_duration': MAX_AUDIO_DURATION,
        'max_file_size': MAX_FILE_SIZE,
        'deployment': 'huggingface-spaces'
    })

    # Run server (HF Spaces requires port 7860)
    port = int(os.getenv('PORT', 7860))
    app.logger.info(f"Starting Audio Digit Classifier API on port {port}")
    app.logger.info("Deployment: Hugging Face Spaces")
    app.run(
        host='0.0.0.0',
        port=port,
        debug=False,  # Disable debug in production
        threaded=True
    )