Pranav Mishra committed · 1772a46 · Parent(s): 494577d

Initial backend deployment - Flask API with ML models

This view is limited to 50 files because it contains too many changes.
- .env.example +21 -0
- Dockerfile +43 -0
- README.md +25 -7
- app.py +361 -0
- audio_processors/__init__.py +0 -0
- audio_processors/__pycache__/__init__.cpython-312.pyc +0 -0
- audio_processors/__pycache__/base_processor.cpython-312.pyc +0 -0
- audio_processors/__pycache__/external_api.cpython-312.pyc +0 -0
- audio_processors/__pycache__/faster_whisper_processor.cpython-312.pyc +0 -0
- audio_processors/__pycache__/local_whisper.cpython-312.pyc +0 -0
- audio_processors/__pycache__/mel_spectrogram.cpython-312.pyc +0 -0
- audio_processors/__pycache__/mfcc_processor.cpython-312.pyc +0 -0
- audio_processors/__pycache__/ml_mel_cnn_processor.cpython-312.pyc +0 -0
- audio_processors/__pycache__/ml_mfcc_processor.cpython-312.pyc +0 -0
- audio_processors/__pycache__/ml_raw_cnn_processor.cpython-312.pyc +0 -0
- audio_processors/__pycache__/raw_spectrogram.cpython-312.pyc +0 -0
- audio_processors/__pycache__/wav2vec2_processor.cpython-312.pyc +0 -0
- audio_processors/__pycache__/whisper_digit_processor.cpython-312.pyc +0 -0
- audio_processors/base_processor.py +85 -0
- audio_processors/external_api.py +153 -0
- audio_processors/faster_whisper_processor.py +219 -0
- audio_processors/local_whisper.py +158 -0
- audio_processors/mel_spectrogram.py +74 -0
- audio_processors/mfcc_processor.py +79 -0
- audio_processors/ml_mel_cnn_processor.py +307 -0
- audio_processors/ml_mfcc_processor.py +370 -0
- audio_processors/ml_raw_cnn_processor.py +307 -0
- audio_processors/raw_spectrogram.py +69 -0
- audio_processors/wav2vec2_processor.py +170 -0
- audio_processors/whisper_digit_processor.py +429 -0
- models/mel_cnn_classifier/best_model.pt +3 -0
- models/mfcc_classifier/best_model.pt +3 -0
- models/mfcc_classifier/scaler.pkl +3 -0
- models/raw_cnn_classifier/best_model.pt +3 -0
- requirements_hf.txt +26 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-312.pyc +0 -0
- utils/__pycache__/audio_utils.cpython-312.pyc +0 -0
- utils/__pycache__/enhanced_vad.cpython-312.pyc +0 -0
- utils/__pycache__/logging_utils.cpython-312.pyc +0 -0
- utils/__pycache__/noise_utils.cpython-312.pyc +0 -0
- utils/__pycache__/session_manager.cpython-312.pyc +0 -0
- utils/__pycache__/vad_feature_integration.cpython-312.pyc +0 -0
- utils/__pycache__/webm_converter.cpython-312.pyc +0 -0
- utils/__pycache__/webrtc_vad.cpython-312.pyc +0 -0
- utils/audio_utils.py +427 -0
- utils/enhanced_vad.py +571 -0
- utils/logging_utils.py +201 -0
- utils/noise_utils.py +292 -0
- utils/session_manager.py +340 -0
.env.example
ADDED
@@ -0,0 +1,21 @@
+# Environment variables for HF Spaces deployment
+# Copy this to .env and set your values
+
+# Flask Configuration
+SECRET_KEY=your_secret_key_here
+FLASK_ENV=production
+
+# External API Keys (optional - for external processors)
+HF_TOKEN=your_huggingface_token_here
+OPENAI_API_KEY=your_openai_key_here
+
+# Model Configuration
+DEFAULT_ML_MODEL=ml_mfcc
+ENABLE_EXTERNAL_API=false
+
+# Performance Settings
+MAX_AUDIO_DURATION=10
+MAX_FILE_SIZE=10485760
+
+# Logging
+LOG_LEVEL=INFO
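
For reference, a minimal sketch of how these values could be loaded at startup with python-dotenv (which app.py below already uses); note that app.py as committed hardcodes the performance limits rather than reading them from the environment:

```python
# Minimal sketch, assuming python-dotenv is installed and a .env copied
# from the template above; the variable names mirror the template.
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory

max_duration = int(os.getenv("MAX_AUDIO_DURATION", "10"))    # seconds
max_file_size = int(os.getenv("MAX_FILE_SIZE", "10485760"))  # bytes
hf_token = os.getenv("HF_TOKEN")  # optional; external processors need it
```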
Dockerfile
ADDED
@@ -0,0 +1,43 @@
+# Use Python 3.9 as recommended by HF Spaces
+FROM python:3.9-slim
+
+# Create user for HF Spaces (required)
+RUN useradd -m -u 1000 user
+USER user
+
+# Set environment variables
+ENV PATH="/home/user/.local/bin:$PATH"
+ENV PYTHONPATH="/app:$PYTHONPATH"
+ENV PYTHONUNBUFFERED=1
+
+# Set work directory
+WORKDIR /app
+
+# Install system dependencies (as user, limited packages)
+# Note: HF Spaces has restrictions on system packages
+COPY --chown=user ./requirements_hf.txt requirements.txt
+
+# Install Python dependencies
+RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+# Copy application files (essential files only)
+COPY --chown=user ./app.py ./app.py
+COPY --chown=user ./audio_processors ./audio_processors
+COPY --chown=user ./utils ./utils
+COPY --chown=user ./models ./models
+
+# Copy environment template (users can set their own HF_TOKEN)
+COPY --chown=user ./.env.example ./.env
+
+# Create log directory
+RUN mkdir -p /app/logs
+
+# Expose port (HF Spaces requires 7860)
+EXPOSE 7860
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
+    CMD python -c "import requests; requests.get('http://localhost:7860/api/health').raise_for_status()" || exit 1
+
+# Run the application
+CMD ["python", "app.py"]
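
The HEALTHCHECK above runs an inline Python probe; the same check can be run by hand against a local container. A standalone sketch, assuming the API is reachable on the exposed port 7860:

```python
# Standalone version of the container's health probe; assumes the requests
# package is installed and the server is listening on localhost:7860.
import sys
import requests

try:
    resp = requests.get("http://localhost:7860/api/health", timeout=10)
    resp.raise_for_status()  # non-2xx counts as unhealthy, like the HEALTHCHECK
    print(resp.json().get("status", "unknown"))
except requests.RequestException as exc:
    print(f"unhealthy: {exc}")
    sys.exit(1)
```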
README.md
CHANGED
@@ -1,12 +1,30 @@
 ---
-title: Streaming Digit Classifier
-emoji:
-colorFrom:
-colorTo:
+title: Streaming Digit Classifier API
+emoji: 🎤
+colorFrom: green
+colorTo: blue
 sdk: docker
 pinned: false
-
-short_description: Real-time spoken digit recognition API with 4 ML approaches
+app_port: 7860
 ---
 
-
+# Streaming Digit Classifier API
+
+Backend API for real-time spoken digit recognition (0-9).
+
+## Features
+
+- ML Models: MFCC + Dense NN, Mel CNN, Raw CNN
+- External API integration (Whisper)
+- Real-time audio processing
+- RESTful API endpoints
+
+## API Endpoints
+
+- `GET /` - API status
+- `POST /api/process_audio` - Process audio file
+- `POST /api/process_audio_chunk` - Process streaming chunk
+- `GET /api/health` - Health check
+- `GET /api/processors` - Available processors
+
+Frontend: [Deployed on Vercel](https://your-frontend-url.vercel.app)
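
A sketch of a client for the upload endpoint listed above; the Space URL is a placeholder, and `ml_mfcc` is one of the method keys app.py (below) registers:

```python
# Hypothetical client for POST /api/process_audio; BASE_URL is a placeholder.
import requests

BASE_URL = "https://your-space.hf.space"  # placeholder Space URL

with open("seven.wav", "rb") as f:  # any short WAV/MP3/OGG/M4A/WebM clip
    resp = requests.post(
        f"{BASE_URL}/api/process_audio",
        files={"audio": ("seven.wav", f, "audio/wav")},
        data={"method": "ml_mfcc"},
        timeout=30,
    )
resp.raise_for_status()
print(resp.json()["predicted_digit"])
```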
app.py
ADDED
@@ -0,0 +1,361 @@
+"""
+Audio Digit Classification API for Hugging Face Spaces
+Backend API for spoken digit recognition (0-9) - HF Spaces deployment
+"""
+
+from flask import Flask, request, jsonify
+from flask_cors import CORS
+import os
+import time
+import logging
+from typing import Dict, Any, Optional
+from dotenv import load_dotenv
+import numpy as np
+
+# Import audio processors (only essential ones for deployment)
+from audio_processors.external_api import ExternalAPIProcessor
+from audio_processors.whisper_digit_processor import WhisperDigitProcessor
+from audio_processors.ml_mfcc_processor import MLMFCCProcessor
+from audio_processors.ml_mel_cnn_processor import MLMelCNNProcessor
+from audio_processors.ml_raw_cnn_processor import MLRawCNNProcessor
+
+# Import utilities
+from utils.audio_utils import validate_audio_format, convert_audio_format, get_audio_duration, convert_for_ml_models
+from utils.logging_utils import performance_logger, setup_flask_logging
+
+# Load environment variables
+load_dotenv()
+
+# Initialize Flask app
+app = Flask(__name__)
+app.secret_key = os.getenv('SECRET_KEY', 'hf_spaces_deployment_key')
+
+# Enable CORS for frontend requests from Vercel
+CORS(app, origins=['*'])  # In production, specify your Vercel domain
+
+# Setup logging
+setup_flask_logging(app)
+
+# Configuration for HF Spaces
+MAX_AUDIO_DURATION = 10  # seconds
+MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
+ALLOWED_EXTENSIONS = {'wav', 'mp3', 'ogg', 'm4a', 'webm'}
+
+def allowed_file(filename: str) -> bool:
+    """Check if file extension is allowed."""
+    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+def initialize_processors():
+    """Initialize audio processors optimized for HF Spaces deployment."""
+    procs = {}
+
+    # ML-trained processors (high priority - use best models only)
+    ml_processors = [
+        ('ml_mfcc', MLMFCCProcessor, 'ML MFCC + Dense NN (Best - 98.52%)'),
+        ('ml_mel_cnn', MLMelCNNProcessor, 'ML Mel CNN (Good - 97.22%)'),
+        ('ml_raw_cnn', MLRawCNNProcessor, 'ML Raw CNN (Fair - 91.30%)')
+    ]
+
+    ml_working_count = 0
+    for proc_key, proc_class, proc_name in ml_processors:
+        try:
+            processor = proc_class()
+            if processor.is_configured():
+                procs[proc_key] = processor
+                ml_working_count += 1
+                app.logger.info(f"[OK] {proc_name} loaded successfully")
+            else:
+                app.logger.warning(f"[WARN] {proc_name} not configured (model files missing)")
+        except Exception as e:
+            app.logger.error(f"[FAIL] Failed to initialize {proc_name}: {str(e)}")
+
+    # External API processor as fallback
+    try:
+        external_processor = ExternalAPIProcessor()
+        if external_processor.is_configured():
+            procs['external_api'] = external_processor
+            app.logger.info("[OK] External API processor initialized")
+        else:
+            app.logger.warning("[WARN] External API not configured")
+    except Exception as e:
+        app.logger.error(f"[FAIL] Failed to initialize External API: {str(e)}")
+
+    # Whisper digit processor as another fallback
+    try:
+        whisper_processor = WhisperDigitProcessor()
+        if whisper_processor.is_configured():
+            procs['whisper_digit'] = whisper_processor
+            app.logger.info("[OK] Whisper digit processor initialized")
+    except Exception as e:
+        app.logger.error(f"[FAIL] Failed to initialize Whisper: {str(e)}")
+
+    app.logger.info(f"Processor initialization complete:")
+    app.logger.info(f"  ML Models loaded: {ml_working_count}/3")
+    app.logger.info(f"  Total processors: {len(procs)}")
+
+    return procs
+
+processors = initialize_processors()
+
+@app.route('/')
+def index():
+    """API status endpoint."""
+    return jsonify({
+        'message': 'Streaming Digit Classifier API',
+        'status': 'running',
+        'version': '1.0.0',
+        'available_processors': list(processors.keys()),
+        'documentation': 'Frontend at Vercel, Backend API at HF Spaces'
+    })
+
+@app.route('/api/process_audio', methods=['POST'])
+def process_audio():
+    """
+    Process audio file with selected method and return digit prediction.
+    Expects multipart form data with 'audio' file and 'method' selection.
+    """
+    try:
+        # Validate request
+        if 'audio' not in request.files:
+            return jsonify({'error': 'No audio file provided'}), 400
+
+        if 'method' not in request.form:
+            return jsonify({'error': 'No processing method specified'}), 400
+
+        audio_file = request.files['audio']
+        method = request.form['method']
+
+        # Validate audio file
+        if audio_file.filename == '':
+            return jsonify({'error': 'No file selected'}), 400
+
+        if not allowed_file(audio_file.filename):
+            return jsonify({'error': 'Unsupported file format'}), 400
+
+        # Validate method
+        if method not in processors:
+            return jsonify({'error': f'Unknown processing method: {method}'}), 400
+
+        # Read audio data
+        audio_data = audio_file.read()
+
+        # Check file size
+        if len(audio_data) > MAX_FILE_SIZE:
+            return jsonify({'error': 'Audio file too large'}), 400
+
+        # Convert to standard format
+        try:
+            app.logger.debug(f"Converting audio format. Original size: {len(audio_data)} bytes")
+            standardized_audio = convert_audio_format(audio_data)
+            app.logger.debug(f"Converted audio size: {len(standardized_audio)} bytes")
+        except Exception as e:
+            app.logger.error(f"Audio conversion failed: {str(e)}")
+            return jsonify({'error': 'Failed to process audio format - unsupported format or corrupted file'}), 400
+
+        # Check audio duration
+        duration = get_audio_duration(standardized_audio)
+        if duration > MAX_AUDIO_DURATION:
+            return jsonify({
+                'error': f'Audio too long: {duration:.1f}s (max: {MAX_AUDIO_DURATION}s)'
+            }), 400
+
+        if duration < 0.1:
+            return jsonify({'error': 'Audio too short (minimum: 0.1s)'}), 400
+
+        # Log audio input info
+        performance_logger.log_audio_info(duration, {
+            'filename': audio_file.filename,
+            'size_bytes': len(audio_data),
+            'converted_size': len(standardized_audio),
+            'method': method
+        })
+
+        # Process with selected method
+        processor = processors[method]
+        result = processor.predict_with_timing(standardized_audio)
+
+        # Log performance
+        performance_logger.log_prediction(method, result)
+
+        # Add additional metadata
+        result.update({
+            'audio_duration': round(duration, 3),
+            'file_size': len(audio_data),
+            'api_version': '1.0.0'
+        })
+
+        app.logger.info(f"Processed audio with {method}: '{result['predicted_digit']}' in {result['inference_time']}s")
+
+        return jsonify(result)
+
+    except Exception as e:
+        app.logger.error(f"Audio processing error: {str(e)}")
+        return jsonify({
+            'error': 'Internal processing error',
+            'success': False,
+            'timestamp': time.time()
+        }), 500
+
+@app.route('/api/process_audio_chunk', methods=['POST'])
+def process_audio_chunk():
+    """
+    Process streaming audio chunk for real-time digit recognition.
+    """
+    try:
+        # Validate request
+        if 'audio' not in request.files:
+            return jsonify({'error': 'No audio chunk provided'}), 400
+
+        audio_file = request.files['audio']
+        method = request.form.get('method', 'ml_mfcc')  # Default to best ML model
+
+        # Validate method
+        if method not in processors:
+            return jsonify({'error': f'Unknown processing method: {method}'}), 400
+
+        # Read audio data
+        audio_data = audio_file.read()
+
+        # Check chunk size
+        if len(audio_data) > MAX_FILE_SIZE:
+            return jsonify({'error': 'Audio chunk too large'}), 400
+
+        if len(audio_data) < 100:
+            return jsonify({'error': 'Audio chunk too small'}), 400
+
+        # Convert to standardized format
+        try:
+            standardized_audio = convert_for_ml_models(audio_data, 'streaming')
+        except Exception as e:
+            app.logger.error(f"Audio conversion failed for chunk: {str(e)}")
+            return jsonify({'error': 'Failed to process audio chunk format'}), 400
+
+        # Process audio chunk
+        processor = processors[method]
+        result = processor.predict_with_timing(standardized_audio)
+
+        # Add streaming metadata
+        result.update({
+            'segment_index': 0,
+            'segment_size': len(standardized_audio),
+            'is_streaming': True,
+            'api_version': '1.0.0'
+        })
+
+        app.logger.info(f"Streaming prediction: '{result['predicted_digit']}' "
+                        f"(Inference: {result['inference_time']}s)")
+
+        return jsonify({
+            'success': True,
+            'segments_detected': 1,
+            'total_results': 1,
+            'results': [result],
+            'timestamp': time.time(),
+            'has_fallback': False
+        })
+
+    except Exception as e:
+        app.logger.error(f"Streaming audio processing error: {str(e)}")
+        return jsonify({
+            'error': 'Internal streaming processing error',
+            'success': False,
+            'timestamp': time.time()
+        }), 500
+
+@app.route('/api/processors')
+def get_processors():
+    """Get information about available processors."""
+    try:
+        processor_info = {}
+        for name, processor in processors.items():
+            info = {
+                'name': processor.name,
+                'method': name,
+                'configured': getattr(processor, 'is_configured', lambda: True)()
+            }
+
+            # Add model-specific info if available
+            if hasattr(processor, 'get_model_info'):
+                info.update(processor.get_model_info())
+
+            processor_info[name] = info
+
+        return jsonify(processor_info)
+
+    except Exception as e:
+        app.logger.error(f"Error getting processors: {str(e)}")
+        return jsonify({'error': 'Failed to retrieve processor information'}), 500
+
+@app.route('/api/health')
+def health_check():
+    """Health check endpoint."""
+    try:
+        # Check processor availability
+        processor_health = {}
+        for name, processor in processors.items():
+            processor_health[name] = {
+                'available': True,
+                'configured': getattr(processor, 'is_configured', lambda: True)()
+            }
+
+        return jsonify({
+            'status': 'healthy',
+            'timestamp': time.time(),
+            'processors': processor_health,
+            'version': '1.0.0',
+            'deployment': 'huggingface-spaces'
+        })
+
+    except Exception as e:
+        app.logger.error(f"Health check failed: {str(e)}")
+        return jsonify({
+            'status': 'unhealthy',
+            'error': str(e),
+            'timestamp': time.time()
+        }), 500
+
+@app.errorhandler(404)
+def not_found_error(error):
+    """Handle 404 errors."""
+    return jsonify({'error': 'Endpoint not found', 'status': 404}), 404
+
+@app.errorhandler(500)
+def internal_error(error):
+    """Handle 500 errors."""
+    app.logger.error(f"Internal error: {str(error)}")
+    return jsonify({'error': 'Internal server error', 'status': 500}), 500
+
+@app.errorhandler(413)
+def too_large_error(error):
+    """Handle file too large errors."""
+    return jsonify({'error': 'File too large', 'status': 413}), 413
+
+if __name__ == '__main__':
+    # Log startup information
+    try:
+        import importlib.metadata
+        flask_version = importlib.metadata.version('flask')
+    except:
+        flask_version = 'unknown'
+
+    performance_logger.log_system_info({
+        'python_version': os.sys.version,
+        'flask_version': flask_version,
+        'processors_loaded': list(processors.keys()),
+        'max_audio_duration': MAX_AUDIO_DURATION,
+        'max_file_size': MAX_FILE_SIZE,
+        'deployment': 'huggingface-spaces'
+    })
+
+    # Run server (HF Spaces requires port 7860)
+    port = int(os.getenv('PORT', 7860))
+
+    app.logger.info(f"Starting Audio Digit Classifier API on port {port}")
+    app.logger.info("Deployment: Hugging Face Spaces")
+
+    app.run(
+        host='0.0.0.0',
+        port=port,
+        debug=False,  # Disable debug in production
+        threaded=True
+    )
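
The streaming endpoint wraps each prediction in a `results` list, unlike the single-object response of `/api/process_audio`. A sketch of a chunk client (placeholder URL and file; omitting `method` uses the `ml_mfcc` default server-side):

```python
# Hypothetical client for POST /api/process_audio_chunk; BASE_URL and the
# chunk file are placeholders.
import requests

BASE_URL = "https://your-space.hf.space"  # placeholder Space URL

with open("chunk.wav", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/api/process_audio_chunk",
        files={"audio": ("chunk.wav", f, "audio/wav")},
        timeout=30,
    )
payload = resp.json()
if payload.get("success"):
    for result in payload["results"]:  # one entry per detected segment
        print(result["predicted_digit"], result["inference_time"])
```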
audio_processors/__init__.py
ADDED
File without changes
audio_processors/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (153 Bytes)

audio_processors/__pycache__/base_processor.cpython-312.pyc
ADDED
Binary file (3.79 kB)

audio_processors/__pycache__/external_api.cpython-312.pyc
ADDED
Binary file (7.1 kB)

audio_processors/__pycache__/faster_whisper_processor.cpython-312.pyc
ADDED
Binary file (9.58 kB)

audio_processors/__pycache__/local_whisper.cpython-312.pyc
ADDED
Binary file (6.67 kB)

audio_processors/__pycache__/mel_spectrogram.cpython-312.pyc
ADDED
Binary file (2.8 kB)

audio_processors/__pycache__/mfcc_processor.cpython-312.pyc
ADDED
Binary file (2.85 kB)

audio_processors/__pycache__/ml_mel_cnn_processor.cpython-312.pyc
ADDED
Binary file (13.6 kB)

audio_processors/__pycache__/ml_mfcc_processor.cpython-312.pyc
ADDED
Binary file (16.7 kB)

audio_processors/__pycache__/ml_raw_cnn_processor.cpython-312.pyc
ADDED
Binary file (13.4 kB)

audio_processors/__pycache__/raw_spectrogram.cpython-312.pyc
ADDED
Binary file (2.81 kB)

audio_processors/__pycache__/wav2vec2_processor.cpython-312.pyc
ADDED
Binary file (7.25 kB)

audio_processors/__pycache__/whisper_digit_processor.cpython-312.pyc
ADDED
Binary file (18.1 kB)
audio_processors/base_processor.py
ADDED
@@ -0,0 +1,85 @@
+from abc import ABC, abstractmethod
+from typing import Union, Dict, Any
+import time
+import logging
+
+logger = logging.getLogger(__name__)
+
+class AudioProcessor(ABC):
+    """
+    Abstract base class for all audio digit classification processors.
+    Provides common interface and logging functionality.
+    """
+
+    def __init__(self, name: str):
+        self.name = name
+        self.total_predictions = 0
+        self.total_inference_time = 0.0
+
+    @abstractmethod
+    def process_audio(self, audio_data: bytes) -> str:
+        """
+        Process audio data and return predicted digit as string.
+
+        Args:
+            audio_data: Raw audio bytes
+
+        Returns:
+            Predicted digit as string ('0'-'9')
+        """
+        pass
+
+    def predict_with_timing(self, audio_data: bytes) -> Dict[str, Any]:
+        """
+        Process audio and return prediction with timing information.
+
+        Args:
+            audio_data: Raw audio bytes
+
+        Returns:
+            Dictionary with prediction, timing, and method info
+        """
+        start_time = time.time()
+
+        try:
+            predicted_digit = self.process_audio(audio_data)
+            inference_time = time.time() - start_time
+
+            self.total_predictions += 1
+            self.total_inference_time += inference_time
+
+            result = {
+                'predicted_digit': predicted_digit,
+                'inference_time': round(inference_time, 3),
+                'method': self.name,
+                'timestamp': time.time(),
+                'average_time': round(self.total_inference_time / self.total_predictions, 3),
+                'success': True
+            }
+
+            logger.info(f"{self.name}: Predicted '{predicted_digit}' in {inference_time:.3f}s")
+            return result
+
+        except Exception as e:
+            inference_time = time.time() - start_time
+            logger.error(f"{self.name}: Error processing audio: {str(e)}")
+
+            return {
+                'predicted_digit': 'ERROR',
+                'inference_time': round(inference_time, 3),
+                'method': self.name,
+                'timestamp': time.time(),
+                'success': False,
+                'error': str(e)
+            }
+
+    def get_stats(self) -> Dict[str, float]:
+        """Get processor statistics."""
+        if self.total_predictions == 0:
+            return {'total_predictions': 0, 'average_time': 0.0}
+
+        return {
+            'total_predictions': self.total_predictions,
+            'total_time': round(self.total_inference_time, 3),
+            'average_time': round(self.total_inference_time / self.total_predictions, 3)
+        }
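
Every processor in this commit subclasses AudioProcessor and implements only process_audio; predict_with_timing supplies the timing, stats, and error envelope. A minimal illustrative subclass:

```python
# Demo subclass showing the contract; the fixed '7' is obviously a stand-in
# for real decoding and inference.
from audio_processors.base_processor import AudioProcessor

class ConstantDigitProcessor(AudioProcessor):
    def __init__(self):
        super().__init__("Constant Digit (demo)")

    def process_audio(self, audio_data: bytes) -> str:
        # A real processor would decode audio_data and run a model here.
        return "7"

proc = ConstantDigitProcessor()
result = proc.predict_with_timing(b"\x00" * 1600)  # dummy audio bytes
print(result["predicted_digit"], result["inference_time"], result["success"])
```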
audio_processors/external_api.py
ADDED
@@ -0,0 +1,153 @@
+import requests
+import os
+import re
+import logging
+from typing import Optional
+from .base_processor import AudioProcessor
+
+logger = logging.getLogger(__name__)
+
+class ExternalAPIProcessor(AudioProcessor):
+    """
+    Hugging Face Whisper API integration for digit classification.
+    Uses openai/whisper-base model for speech-to-text conversion.
+    """
+
+    def __init__(self):
+        super().__init__("External API (Whisper)")
+        # Try alternative Whisper model that should be available
+        self.api_url = "https://api-inference.huggingface.co/models/openai/whisper-small"
+        self.token = os.getenv('HUGGING_FACE_TOKEN')
+        self.headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
+
+        if not self.token:
+            logger.warning("HUGGING_FACE_TOKEN not found in environment variables")
+
+    def process_audio(self, audio_data: bytes) -> str:
+        """
+        Process audio using Hugging Face Whisper API.
+
+        Args:
+            audio_data: Raw audio bytes (WAV format preferred)
+
+        Returns:
+            Predicted digit as string ('0'-'9')
+
+        Raises:
+            Exception: If API call fails or no digit found in response
+        """
+        if not self.token:
+            raise Exception("Hugging Face API token not configured")
+
+        try:
+            # Make API request
+            response = requests.post(
+                self.api_url,
+                headers=self.headers,
+                data=audio_data,
+                timeout=15  # Increased timeout
+            )
+
+            if response.status_code == 401:
+                logger.error("Hugging Face API token is invalid or expired")
+                raise Exception("Invalid or expired API token - please update HUGGING_FACE_TOKEN")
+            elif response.status_code == 404:
+                logger.error(f"Model not found or unavailable: {self.api_url}")
+                raise Exception("API model unavailable - may be loading or deprecated")
+            elif response.status_code == 503:
+                logger.warning("Model is loading, this may take a few moments")
+                raise Exception("API model is loading - please try again in a moment")
+            elif response.status_code != 200:
+                logger.error(f"API request failed: {response.status_code} - {response.text}")
+                raise Exception(f"API error {response.status_code}: {response.text[:100]}")
+
+            # Parse response
+            result = response.json()
+
+            if 'text' not in result:
+                logger.error(f"Unexpected API response format: {result}")
+                raise Exception("Invalid API response format")
+
+            transcribed_text = result['text'].strip().lower()
+            logger.debug(f"Whisper transcription: '{transcribed_text}'")
+
+            # Extract digit from transcription
+            predicted_digit = self._extract_digit(transcribed_text)
+
+            if predicted_digit is None:
+                logger.warning(f"No digit found in transcription: '{transcribed_text}'")
+                return "?"
+
+            return predicted_digit
+
+        except requests.exceptions.Timeout:
+            raise Exception("API request timeout (15s) - service may be slow")
+        except requests.exceptions.RequestException as e:
+            raise Exception(f"API request failed: {str(e)}")
+        except Exception as e:
+            logger.error(f"Unexpected error in external API processing: {str(e)}")
+            raise
+
+    def _extract_digit(self, text: str) -> Optional[str]:
+        """
+        Extract digit from transcribed text.
+        Handles both numerical ('1', '2') and word forms ('one', 'two').
+
+        Args:
+            text: Transcribed text from Whisper
+
+        Returns:
+            Digit as string ('0'-'9') or None if not found
+        """
+        # Word to digit mapping
+        word_to_digit = {
+            'zero': '0', 'oh': '0',
+            'one': '1', 'won': '1',
+            'two': '2', 'to': '2', 'too': '2',
+            'three': '3', 'tree': '3',
+            'four': '4', 'for': '4', 'fore': '4',
+            'five': '5',
+            'six': '6', 'sick': '6',
+            'seven': '7',
+            'eight': '8', 'ate': '8',
+            'nine': '9', 'niner': '9'
+        }
+
+        # First, try to find a direct digit
+        digit_match = re.search(r'\b([0-9])\b', text)
+        if digit_match:
+            return digit_match.group(1)
+
+        # Then try word forms
+        words = text.split()
+        for word in words:
+            clean_word = re.sub(r'[^\w]', '', word.lower())
+            if clean_word in word_to_digit:
+                return word_to_digit[clean_word]
+
+        # Try partial matches for robustness
+        for word, digit in word_to_digit.items():
+            if word in text:
+                return digit
+
+        return None
+
+    def is_configured(self) -> bool:
+        """Check if API is properly configured."""
+        return bool(self.token)
+
+    def test_connection(self) -> bool:
+        """Test API connection with a simple request."""
+        if not self.is_configured():
+            return False
+
+        try:
+            # Test with minimal audio data
+            test_response = requests.get(
+                self.api_url,
+                headers=self.headers,
+                timeout=5
+            )
+            return test_response.status_code == 200
+        except:
+            return False
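
_extract_digit resolves in a fixed order: a literal digit wins, then exact word forms (including homophones like 'for' and 'ate'), then a substring fallback. A quick check of that ordering; no API token is needed just to exercise this helper, though it does poke a private method:

```python
# Exercises only the local _extract_digit helper; constructing the processor
# without HUGGING_FACE_TOKEN set merely logs a warning.
from audio_processors.external_api import ExternalAPIProcessor

proc = ExternalAPIProcessor()
for text in ["the number 3", "seven.", "i said for", "hello world"]:
    print(repr(text), "->", proc._extract_digit(text))
# expected: '3', '7', '4', None
```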
audio_processors/faster_whisper_processor.py
ADDED
@@ -0,0 +1,219 @@
+"""
+Faster-Whisper processor with built-in VAD (2025 approach)
+More reliable than manual WebRTC VAD + Whisper coordination
+"""
+
+import numpy as np
+import io
+import time
+import logging
+from typing import Dict, Any, Optional
+
+try:
+    from faster_whisper import WhisperModel
+    FASTER_WHISPER_AVAILABLE = True
+except ImportError:
+    FASTER_WHISPER_AVAILABLE = False
+    WhisperModel = None
+
+from .base_processor import AudioProcessor
+
+logger = logging.getLogger(__name__)
+
+class FasterWhisperDigitProcessor(AudioProcessor):
+    """
+    Modern 2025 approach using faster-whisper with built-in VAD.
+    Much more reliable than manual WebRTC VAD coordination.
+    """
+
+    def __init__(self):
+        """Initialize faster-whisper processor with built-in VAD."""
+        super().__init__("Faster-Whisper with VAD")
+
+        if not FASTER_WHISPER_AVAILABLE:
+            logger.error("faster-whisper not available. Install with: pip install faster-whisper")
+            self.model = None
+            return
+
+        self.model = None
+        self.device = "cuda" if self._cuda_available() else "cpu"
+
+        # Digit mapping
+        self.digit_map = {
+            "zero": "0", "one": "1", "two": "2", "three": "3",
+            "four": "4", "five": "5", "six": "6", "seven": "7",
+            "eight": "8", "nine": "9",
+            "oh": "0", "o": "0", "for": "4", "fore": "4",
+            "to": "2", "too": "2", "tu": "2", "tree": "3",
+            "free": "3", "ate": "8", "ait": "8"
+        }
+
+        # Statistics
+        self.total_predictions = 0
+        self.successful_predictions = 0
+        self.failed_predictions = 0
+
+        self._initialize_model()
+
+    def _cuda_available(self) -> bool:
+        """Check if CUDA is available."""
+        try:
+            import torch
+            return torch.cuda.is_available()
+        except ImportError:
+            return False
+
+    def _initialize_model(self):
+        """Initialize faster-whisper model with VAD."""
+        if not FASTER_WHISPER_AVAILABLE:
+            return
+
+        try:
+            logger.info("Initializing faster-whisper model with built-in VAD...")
+
+            # Initialize faster-whisper model
+            self.model = WhisperModel(
+                "tiny",  # Use tiny model for speed
+                device=self.device,
+                compute_type="float16" if self.device == "cuda" else "int8"
+            )
+
+            logger.info(f"Faster-Whisper model initialized on {self.device}")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize faster-whisper: {e}")
+            self.model = None
+
+    def is_configured(self) -> bool:
+        """Check if processor is configured."""
+        return self.model is not None and FASTER_WHISPER_AVAILABLE
+
+    def process_audio(self, audio_data: bytes) -> str:
+        """
+        Process audio with built-in VAD and return predicted digit.
+
+        Args:
+            audio_data: Raw audio bytes
+
+        Returns:
+            str: Predicted digit (0-9) or error message
+        """
+        if not self.is_configured():
+            return "error: Model not configured"
+
+        try:
+            # Convert audio to numpy array
+            audio_array = self._convert_audio_bytes(audio_data)
+            if audio_array is None:
+                return "error: Audio conversion failed"
+
+            # Use faster-whisper with built-in VAD
+            segments, info = self.model.transcribe(
+                audio_array,
+                language="en",
+                # Built-in VAD parameters - much better than manual VAD
+                vad_filter=True,
+                vad_parameters=dict(
+                    min_silence_duration_ms=100,  # 100ms minimum silence
+                    speech_pad_ms=30  # 30ms padding around speech
+                )
+            )
+
+            # Process transcription results
+            transcriptions = []
+            for segment in segments:
+                text = segment.text.strip().lower()
+                if text:
+                    transcriptions.append(text)
+
+            if not transcriptions:
+                return "error: No speech detected"
+
+            # Combine all segments and extract digit
+            full_text = " ".join(transcriptions)
+            digit = self._text_to_digit(full_text)
+
+            logger.debug(f"Faster-Whisper: '{full_text}' -> '{digit}'")
+
+            if digit in "0123456789":
+                self.successful_predictions += 1
+                return digit
+            else:
+                self.failed_predictions += 1
+                return f"unclear: {full_text}"
+
+        except Exception as e:
+            logger.error(f"Faster-Whisper processing failed: {e}")
+            self.failed_predictions += 1
+            return f"error: {str(e)}"
+        finally:
+            self.total_predictions += 1
+
+    def _convert_audio_bytes(self, audio_data: bytes) -> Optional[np.ndarray]:
+        """Convert audio bytes to numpy array for faster-whisper."""
+        try:
+            # Check if it's a WAV file
+            if audio_data.startswith(b'RIFF'):
+                import soundfile as sf
+                audio_buffer = io.BytesIO(audio_data)
+                audio_array, sample_rate = sf.read(audio_buffer, dtype='float32')
+
+                # Convert stereo to mono if needed
+                if len(audio_array.shape) > 1:
+                    audio_array = np.mean(audio_array, axis=1)
+
+                return audio_array
+            else:
+                # Raw PCM data
+                audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32)
+                return audio_array / 32768.0
+
+        except Exception as e:
+            logger.error(f"Audio conversion failed: {e}")
+            return None
+
+    def _text_to_digit(self, text: str) -> str:
+        """Convert transcribed text to digit."""
+        text = text.strip().lower()
+
+        # Remove common words
+        text = text.replace("the", "").replace("number", "").replace("digit", "")
+        text = text.strip()
+
+        # Direct mapping
+        if text in self.digit_map:
+            return self.digit_map[text]
+
+        # Word-by-word check
+        for word in text.split():
+            if word in self.digit_map:
+                return self.digit_map[word]
+
+        # Check for digits in text
+        digits = [char for char in text if char.isdigit()]
+        if digits:
+            return digits[0]
+
+        return text
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get model information."""
+        return {
+            'model_name': 'faster-whisper-tiny',
+            'model_type': 'Speech-to-Text with VAD',
+            'has_builtin_vad': True,
+            'device': self.device,
+            'available': FASTER_WHISPER_AVAILABLE
+        }
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get processing statistics."""
+        success_rate = self.successful_predictions / max(1, self.total_predictions)
+
+        return {
+            'total_predictions': self.total_predictions,
+            'successful_predictions': self.successful_predictions,
+            'failed_predictions': self.failed_predictions,
+            'success_rate': round(success_rate, 3),
+            'model_available': self.is_configured()
+        }
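
The transcribe call above relies on faster-whisper's bundled Silero VAD rather than a separate WebRTC VAD pass. The same call in isolation, assuming `pip install faster-whisper`; the silent input is a stand-in:

```python
# Standalone sketch of the VAD-filtered transcription; one second of
# silence stands in for real audio, so the segment list comes back empty.
import numpy as np
from faster_whisper import WhisperModel

model = WhisperModel("tiny", device="cpu", compute_type="int8")
audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz

segments, info = model.transcribe(
    audio,
    language="en",
    vad_filter=True,  # Silero VAD trims non-speech before decoding
    vad_parameters=dict(min_silence_duration_ms=100, speech_pad_ms=30),
)
print([seg.text for seg in segments])  # [] for silent input
```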
audio_processors/local_whisper.py
ADDED
@@ -0,0 +1,158 @@
+import logging
+import numpy as np
+from typing import Optional
+from .base_processor import AudioProcessor
+
+logger = logging.getLogger(__name__)
+
+class LocalWhisperProcessor(AudioProcessor):
+    """
+    Local Whisper model using transformers pipeline.
+    Fallback when API is unavailable.
+    """
+
+    def __init__(self):
+        super().__init__("Local Whisper (Tiny)")
+        self.pipeline = None
+        self.model_name = "openai/whisper-tiny"
+        self.is_initialized = False
+
+    def _initialize_model(self):
+        """Lazy initialization of the model"""
+        if self.is_initialized:
+            return
+
+        try:
+            logger.info(f"Loading local Whisper model: {self.model_name}")
+
+            from transformers import pipeline
+            import torch
+
+            # Use CPU for compatibility, GPU if available
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+
+            self.pipeline = pipeline(
+                "automatic-speech-recognition",
+                model=self.model_name,
+                device=device,
+                torch_dtype=torch.float32,  # Use float32 to avoid dtype issues
+                return_timestamps=False  # We only need text
+            )
+
+            logger.info(f"Local Whisper model loaded on {device}")
+            self.is_initialized = True
+
+        except ImportError as e:
+            logger.error("transformers library not installed. Run: pip install transformers torch")
+            raise Exception("transformers library required for local processing")
+        except Exception as e:
+            logger.error(f"Failed to load local Whisper model: {str(e)}")
+            raise Exception(f"Local model initialization failed: {str(e)}")
+
+    def process_audio(self, audio_data: bytes) -> str:
+        """
+        Process audio using local Whisper model.
+
+        Args:
+            audio_data: Raw audio bytes (WAV format preferred)
+
+        Returns:
+            Predicted digit as string ('0'-'9')
+
+        Raises:
+            Exception: If processing fails
+        """
+        try:
+            # Initialize model on first use
+            self._initialize_model()
+
+            # Convert audio bytes to numpy array
+            from utils.audio_utils import audio_to_numpy
+            audio_array, sample_rate = audio_to_numpy(audio_data)
+
+            # Resample to 16kHz if needed (Whisper expects 16kHz)
+            if sample_rate != 16000:
+                logger.debug(f"Resampling from {sample_rate}Hz to 16kHz")
+                import librosa
+                audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
+
+            # Process with pipeline
+            logger.debug(f"Processing audio: {len(audio_array)} samples at 16kHz")
+            result = self.pipeline(audio_array)
+
+            if not result or 'text' not in result:
+                logger.error(f"Unexpected pipeline result: {result}")
+                raise Exception("Invalid pipeline output")
+
+            transcribed_text = result['text'].strip().lower()
+            logger.debug(f"Local Whisper transcription: '{transcribed_text}'")
+
+            # Extract digit from transcription
+            predicted_digit = self._extract_digit(transcribed_text)
+
+            if predicted_digit is None:
+                logger.warning(f"No digit found in transcription: '{transcribed_text}'")
+                return "?"
+
+            return predicted_digit
+
+        except Exception as e:
+            logger.error(f"Local Whisper processing failed: {str(e)}")
+            raise Exception(f"Local processing error: {str(e)}")
+
+    def _extract_digit(self, text: str) -> Optional[str]:
+        """
+        Extract digit from transcribed text.
+        Handles both numerical ('1', '2') and word forms ('one', 'two').
+        """
+        import re
+
+        # Word to digit mapping
+        word_to_digit = {
+            'zero': '0', 'oh': '0',
+            'one': '1', 'won': '1',
+            'two': '2', 'to': '2', 'too': '2',
+            'three': '3', 'tree': '3',
+            'four': '4', 'for': '4', 'fore': '4',
+            'five': '5',
+            'six': '6', 'sick': '6',
+            'seven': '7',
+            'eight': '8', 'ate': '8',
+            'nine': '9', 'niner': '9'
+        }
+
+        # First, try to find a direct digit
+        digit_match = re.search(r'\b([0-9])\b', text)
+        if digit_match:
+            return digit_match.group(1)
+
+        # Then try word forms
+        words = text.split()
+        for word in words:
+            clean_word = re.sub(r'[^\w]', '', word.lower())
+            if clean_word in word_to_digit:
+                return word_to_digit[clean_word]
+
+        # Try partial matches for robustness
+        for word, digit in word_to_digit.items():
+            if word in text:
+                return digit
+
+        return None
+
+    def is_configured(self) -> bool:
+        """Check if local model can be initialized."""
+        try:
+            import transformers
+            import torch
+            return True
+        except ImportError:
+            return False
+
+    def test_connection(self) -> bool:
+        """Test local model functionality."""
+        try:
+            self._initialize_model()
+            return True
+        except:
+            return False
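
The resampling branch matters because Whisper checkpoints are trained on 16 kHz audio. That step in isolation (librosa assumed installed; the 44.1 kHz buffer is a stand-in):

```python
# The resample step on its own; the random 44.1 kHz array is a placeholder.
import numpy as np
import librosa

orig_sr = 44100
audio = np.random.randn(orig_sr).astype(np.float32)  # 1 s at 44.1 kHz
audio_16k = librosa.resample(audio, orig_sr=orig_sr, target_sr=16000)
print(len(audio), "->", len(audio_16k))  # 44100 -> 16000
```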
audio_processors/mel_spectrogram.py
ADDED
@@ -0,0 +1,74 @@
+import numpy as np
+import logging
+from .base_processor import AudioProcessor
+
+logger = logging.getLogger(__name__)
+
+class MelSpectrogramProcessor(AudioProcessor):
+    """
+    Mel Spectrogram processor using mel-scale frequency analysis.
+
+    Future implementation will:
+    - Apply mel filterbank to frequency domain representation
+    - Use perceptually-motivated frequency scaling
+    - Feed mel spectrogram features to deep learning model
+
+    Currently returns placeholder '00' for testing UI functionality.
+    """
+
+    def __init__(self):
+        super().__init__("Mel Spectrogram")
+        logger.info("Mel Spectrogram processor initialized (PLACEHOLDER MODE)")
+
+    def process_audio(self, audio_data: bytes) -> str:
+        """
+        Process audio using mel-scale spectrogram analysis.
+
+        PLACEHOLDER IMPLEMENTATION:
+        Currently returns '00' for UI testing purposes.
+
+        Future implementation will:
+        1. Convert audio bytes to numpy array
+        2. Compute STFT of the audio signal
+        3. Apply mel filterbank to convert to mel scale
+        4. Take logarithm for perceptual scaling
+        5. Feed to trained neural network (CNN/RNN)
+        6. Return predicted digit
+
+        Args:
+            audio_data: Raw audio bytes
+
+        Returns:
+            Predicted digit as string (currently '00')
+        """
+        logger.debug("Processing audio with Mel Spectrogram (placeholder)")
+
+        # Simulate processing time
+        import time
+        time.sleep(0.15)
+
+        # TODO: Implement actual mel spectrogram processing:
+        # 1. audio_array = np.frombuffer(audio_data, dtype=np.float32)
+        # 2. mel_spec = librosa.feature.melspectrogram(
+        #        y=audio_array,
+        #        sr=sample_rate,
+        #        n_mels=128,
+        #        fmax=8000
+        #    )
+        # 3. mel_db = librosa.power_to_db(mel_spec, ref=np.max)
+        # 4. prediction = self.neural_model.predict(mel_db)
+        # 5. return str(np.argmax(prediction))
+
+        return '00'
+
+    def get_model_info(self) -> dict:
+        """Get information about the mel spectrogram model."""
+        return {
+            'method': 'Mel Spectrogram',
+            'status': 'PLACEHOLDER',
+            'features': 'Mel-scale frequency representation',
+            'classifier': 'CNN/RNN (not implemented)',
+            'n_mels': 128,
+            'fmax': 8000,
+            'expected_inference_time': '<500ms'
+        }
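
A runnable counterpart to the TODO above, stopping at the log-mel matrix a trained CNN/RNN would consume (librosa assumed installed; the random signal is a stand-in):

```python
# Feature-extraction sketch only; no classifier exists yet in this commit.
import numpy as np
import librosa

sample_rate = 16000
audio_array = np.random.randn(sample_rate).astype(np.float32)  # stand-in signal

mel_spec = librosa.feature.melspectrogram(y=audio_array, sr=sample_rate,
                                          n_mels=128, fmax=8000)
mel_db = librosa.power_to_db(mel_spec, ref=np.max)  # log (dB) scaling
print(mel_db.shape)  # (128, n_frames)
```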
audio_processors/mfcc_processor.py
ADDED
@@ -0,0 +1,79 @@
+import numpy as np
+import logging
+from .base_processor import AudioProcessor
+
+logger = logging.getLogger(__name__)
+
+class MFCCProcessor(AudioProcessor):
+    """
+    MFCC (Mel-Frequency Cepstral Coefficients) processor.
+
+    Future implementation will:
+    - Extract MFCC features (typically 12-13 coefficients)
+    - Apply DCT (Discrete Cosine Transform) to mel spectrogram
+    - Use traditional ML classifier (SVM, Random Forest, etc.)
+
+    Currently returns placeholder '00' for testing UI functionality.
+    """
+
+    def __init__(self):
+        super().__init__("MFCC")
+        logger.info("MFCC processor initialized (PLACEHOLDER MODE)")
+
+    def process_audio(self, audio_data: bytes) -> str:
+        """
+        Process audio using MFCC feature extraction.
+
+        PLACEHOLDER IMPLEMENTATION:
+        Currently returns '00' for UI testing purposes.
+
+        Future implementation will:
+        1. Convert audio bytes to numpy array
+        2. Compute mel spectrogram of the audio
+        3. Apply DCT to get cepstral coefficients
+        4. Extract first 12-13 MFCC coefficients
+        5. Optionally add delta and delta-delta features
+        6. Feed to trained classifier (SVM/Random Forest)
+        7. Return predicted digit
+
+        Args:
+            audio_data: Raw audio bytes
+
+        Returns:
+            Predicted digit as string (currently '00')
+        """
+        logger.debug("Processing audio with MFCC (placeholder)")
+
+        # Simulate processing time (MFCC should be fastest)
+        import time
+        time.sleep(0.05)
+
+        # TODO: Implement actual MFCC processing:
+        # 1. audio_array = np.frombuffer(audio_data, dtype=np.float32)
+        # 2. mfccs = librosa.feature.mfcc(
+        #        y=audio_array,
+        #        sr=sample_rate,
+        #        n_mfcc=13,
+        #        n_fft=2048,
+        #        hop_length=512
+        #    )
+        # 3. # Optionally add delta features
+        # 4. delta_mfccs = librosa.feature.delta(mfccs)
+        # 5. features = np.concatenate([mfccs, delta_mfccs], axis=0)
+        # 6. prediction = self.svm_model.predict(features.T.flatten().reshape(1, -1))
+        # 7. return str(prediction[0])
+
+        return '00'
+
+    def get_model_info(self) -> dict:
+        """Get information about the MFCC model."""
+        return {
+            'method': 'MFCC (Mel-Frequency Cepstral Coefficients)',
+            'status': 'PLACEHOLDER',
+            'features': 'Cepstral coefficients with delta features',
+            'classifier': 'SVM/Random Forest (not implemented)',
+            'n_mfcc': 13,
+            'n_fft': 2048,
+            'hop_length': 512,
+            'expected_inference_time': '<100ms'
+        }
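
As with the mel processor, a minimal sketch of the MFCC TODO, assuming librosa and a scikit-learn-style `svm_model` (both hypothetical, not part of this commit):

# Hypothetical sketch of the MFCC TODO above; `svm_model` and `sample_rate`
# are assumptions, not part of this commit.
import librosa
import numpy as np

def predict_digit_from_mfcc(audio_data: bytes, svm_model, sample_rate: int) -> str:
    audio_array = np.frombuffer(audio_data, dtype=np.float32)
    # 13 cepstral coefficients, matching the settings in get_model_info()
    mfccs = librosa.feature.mfcc(y=audio_array, sr=sample_rate,
                                 n_mfcc=13, n_fft=2048, hop_length=512)
    # Delta features capture frame-to-frame dynamics
    delta_mfccs = librosa.feature.delta(mfccs)
    features = np.concatenate([mfccs, delta_mfccs], axis=0)
    # Flatten to one feature vector per clip for a classical classifier
    prediction = svm_model.predict(features.T.flatten().reshape(1, -1))
    return str(prediction[0])

Note that a real classical classifier would need a fixed-length feature vector, so the audio would have to be padded or trimmed to a fixed duration first.
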
audio_processors/ml_mel_cnn_processor.py
ADDED
@@ -0,0 +1,307 @@
+"""
+ML Mel CNN Digit Processor
+Uses the trained Mel Spectrogram + 2D CNN model for digit classification
+"""
+
+import os
+import sys
+import time
+import logging
+from pathlib import Path
+from typing import Dict, Any, Optional, Union
+
+import numpy as np
+from .base_processor import AudioProcessor
+
+# Add project root to path for ML imports
+PROJECT_ROOT = Path(__file__).parent.parent
+sys.path.append(str(PROJECT_ROOT))
+
+# Import ML inference
+from ml_training.inference import load_classifier
+
+logger = logging.getLogger(__name__)
+
+class MLMelCNNProcessor(AudioProcessor):
+    """
+    ML-based Mel CNN digit processor using trained 2D CNN model.
+
+    Performance characteristics (based on training results):
+    - Test accuracy: 97.22%
+    - Inference time: ~3-5ms
+    - Model size: ~2.6MB
+    """
+
+    name = "ML Mel CNN (2D Conv)"
+
+    def __init__(self, model_dir: str = "models", device: str = "auto"):
+        """
+        Initialize ML Mel CNN processor.
+
+        Args:
+            model_dir: Directory containing trained models
+            device: Device to run inference on ('cpu', 'cuda', or 'auto')
+        """
+        super().__init__(self.name)
+
+        self.model_dir = Path(model_dir)
+        self.device = device if device != "auto" else None
+        self.classifier = None
+        self._configured = False
+
+        # Performance tracking
+        self.prediction_count = 0
+        self.total_inference_time = 0.0
+        self.last_prediction_time = None
+
+        # Try to load the model
+        self._initialize_classifier()
+
+        logger.info(f"ML Mel CNN Processor initialized (configured: {self._configured})")
+
+    def _initialize_classifier(self):
+        """Initialize the ML classifier."""
+        try:
+            # Check if model directory exists
+            if not self.model_dir.exists():
+                logger.warning(f"Model directory not found: {self.model_dir}")
+                return
+
+            # Load the Mel CNN classifier
+            self.classifier = load_classifier(
+                model_dir=str(self.model_dir),
+                pipeline_type='mel_cnn',
+                device=self.device
+            )
+
+            self._configured = True
+            logger.info("ML Mel CNN classifier loaded successfully")
+            logger.info(f"  Model device: {self.classifier.device}")
+            logger.info(f"  Parameters: {sum(p.numel() for p in self.classifier.model.parameters()):,}")
+
+        except Exception as e:
+            logger.error(f"Failed to load ML Mel CNN classifier: {str(e)}")
+            self.classifier = None
+            self._configured = False
+
+    def is_configured(self) -> bool:
+        """Check if the processor is properly configured."""
+        return self._configured and self.classifier is not None
+
+    def process_audio(self, audio_data: bytes) -> str:
+        """
+        Process audio and return predicted digit (required by base class).
+
+        Args:
+            audio_data: Raw audio data in bytes
+
+        Returns:
+            predicted_digit: Predicted digit as string
+        """
+        return self.predict(audio_data)
+
+    def predict(self, audio_data: bytes) -> str:
+        """
+        Predict digit from audio data.
+
+        Args:
+            audio_data: Raw audio data in bytes
+
+        Returns:
+            predicted_digit: Predicted digit as string
+        """
+        if not self.is_configured():
+            raise RuntimeError("ML Mel CNN processor not properly configured")
+
+        try:
+            # Convert audio with optimized format for ML models
+            from utils.audio_utils import convert_for_ml_models
+            optimized_audio = convert_for_ml_models(audio_data, 'mel_cnn')
+
+            # Convert audio bytes to numpy array
+            audio_array = self._bytes_to_audio_array(optimized_audio)
+
+            # Make prediction using ML classifier
+            start_time = time.time()
+            result = self.classifier.predict(
+                audio_array,
+                return_probabilities=True,
+                return_features=False
+            )
+            inference_time = time.time() - start_time
+
+            # Update performance tracking
+            self.prediction_count += 1
+            self.total_inference_time += inference_time
+            self.last_prediction_time = inference_time
+
+            predicted_digit = str(result['predicted_digit'])
+            confidence = result['confidence']
+
+            logger.debug(f"ML Mel CNN prediction: '{predicted_digit}' "
+                         f"(confidence: {confidence:.3f}, time: {inference_time*1000:.1f}ms)")
+
+            return predicted_digit
+
+        except Exception as e:
+            logger.error(f"ML Mel CNN prediction failed: {str(e)}")
+            raise
+
+    def predict_with_timing(self, audio_data: bytes) -> Dict[str, Any]:
+        """
+        Predict digit with detailed timing and confidence information.
+
+        Args:
+            audio_data: Raw audio data in bytes
+
+        Returns:
+            result: Detailed prediction results
+        """
+        if not self.is_configured():
+            return {
+                'success': False,
+                'error': 'ML Mel CNN processor not properly configured',
+                'predicted_digit': None,
+                'inference_time': 0.0
+            }
+
+        try:
+            # Convert audio with optimized format for ML models
+            from utils.audio_utils import convert_for_ml_models
+            optimized_audio = convert_for_ml_models(audio_data, 'mel_cnn')
+
+            # Convert audio bytes to numpy array
+            audio_array = self._bytes_to_audio_array(optimized_audio)
+
+            # Make prediction using ML classifier
+            start_time = time.time()
+            ml_result = self.classifier.predict(
+                audio_array,
+                return_probabilities=True,
+                return_features=False
+            )
+            inference_time = time.time() - start_time
+
+            # Update performance tracking
+            self.prediction_count += 1
+            self.total_inference_time += inference_time
+            self.last_prediction_time = inference_time
+
+            # Format result
+            result = {
+                'success': True,
+                'predicted_digit': str(ml_result['predicted_digit']),
+                'confidence': ml_result['confidence'],
+                'inference_time': inference_time,
+                'class_probabilities': {
+                    str(k): float(v) for k, v in ml_result['class_probabilities'].items()
+                },
+                'top_3_predictions': [
+                    {
+                        'digit': str(pred['digit']),
+                        'probability': pred['probability']
+                    }
+                    for pred in ml_result['top_3_predictions']
+                ],
+                'method': self.name,
+                'model_type': 'ml_mel_cnn',
+                'timestamp': time.time()
+            }
+
+            logger.debug(f"ML Mel CNN detailed prediction: '{result['predicted_digit']}' "
+                         f"(confidence: {result['confidence']:.3f}, "
+                         f"time: {inference_time*1000:.1f}ms)")
+
+            return result
+
+        except Exception as e:
+            logger.error(f"ML Mel CNN prediction with timing failed: {str(e)}")
+            return {
+                'success': False,
+                'error': str(e),
+                'predicted_digit': None,
+                'inference_time': 0.0,
+                'method': self.name,
+                'model_type': 'ml_mel_cnn',
+                'timestamp': time.time()
+            }
+
+    def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
+        """Convert audio bytes to numpy array."""
+        try:
+            # Try to interpret as int16 PCM first (most common)
+            audio_array = np.frombuffer(audio_data, dtype=np.int16)
+
+            # Convert to float32 and normalize
+            audio_array = audio_array.astype(np.float32) / 32768.0
+
+            # If the array is too short, pad it
+            if len(audio_array) < 1000:  # Less than ~60ms at 16kHz
+                # Pad with zeros to minimum length
+                audio_array = np.pad(audio_array, (0, 1000 - len(audio_array)))
+
+            return audio_array
+
+        except Exception as e:
+            logger.error(f"Failed to convert audio bytes to array: {str(e)}")
+            # Return a small zero array as fallback
+            return np.zeros(1000, dtype=np.float32)
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get processor performance statistics."""
+        stats = super().get_stats()
+
+        if self.prediction_count > 0:
+            stats.update({
+                'ml_predictions': self.prediction_count,
+                'average_inference_time': self.total_inference_time / self.prediction_count,
+                'last_inference_time': self.last_prediction_time,
+                'throughput_per_second': self.prediction_count / self.total_inference_time if self.total_inference_time > 0 else 0,
+                'model_configured': self.is_configured()
+            })
+
+        if self.classifier:
+            # Get ML classifier performance stats
+            ml_stats = self.classifier.get_performance_stats()
+            stats['ml_classifier_stats'] = ml_stats
+
+        return stats
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the loaded model."""
+        if not self.is_configured():
+            return {'error': 'Model not loaded'}
+
+        try:
+            info = {
+                'pipeline_type': 'mel_cnn',
+                'model_class': self.classifier.model.__class__.__name__,
+                'device': str(self.classifier.device),
+                'parameters': sum(p.numel() for p in self.classifier.model.parameters()),
+                'feature_extractor': self.classifier.feature_extractor.__class__.__name__,
+                'has_scaler': self.classifier.scaler is not None,
+                'expected_sample_rate': 8000,
+                'expected_audio_length': 8000,  # 1 second at 8kHz
+                'input_shape': '(1, 64, 51)',  # Mel spectrogram shape
+                'model_architecture': '2D CNN'
+            }
+
+            if hasattr(self.classifier, 'model_path'):
+                info['model_path'] = str(self.classifier.model_path)
+
+            return info
+
+        except Exception as e:
+            logger.error(f"Failed to get model info: {str(e)}")
+            return {'error': str(e)}
+
+    def benchmark_speed(self, num_samples: int = 100) -> Dict[str, Any]:
+        """Benchmark inference speed."""
+        if not self.is_configured():
+            return {'error': 'Model not configured'}
+
+        try:
+            return self.classifier.benchmark_speed(num_samples)
+        except Exception as e:
+            logger.error(f"Benchmark failed: {str(e)}")
+            return {'error': str(e)}
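
A usage sketch of the processor above. It assumes the checkpoint under models/ loads successfully; `sample.raw` is a hypothetical 16-bit PCM recording, not part of this commit.

# Usage sketch; the input file is hypothetical.
from audio_processors.ml_mel_cnn_processor import MLMelCNNProcessor

processor = MLMelCNNProcessor(model_dir="models")
if processor.is_configured():
    with open("sample.raw", "rb") as f:  # hypothetical int16 PCM clip
        result = processor.predict_with_timing(f.read())
    print(result["predicted_digit"], f"{result['confidence']:.3f}")
else:
    print("Model checkpoint not found; predictions unavailable")
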
audio_processors/ml_mfcc_processor.py
ADDED
@@ -0,0 +1,370 @@
+"""
+ML MFCC Digit Processor
+Uses the trained MFCC + Dense NN model for digit classification
+"""
+
+import os
+import sys
+import time
+import logging
+from pathlib import Path
+from typing import Dict, Any, Optional, Union
+
+import numpy as np
+from .base_processor import AudioProcessor
+
+# Add project root to path for ML imports
+PROJECT_ROOT = Path(__file__).parent.parent
+sys.path.append(str(PROJECT_ROOT))
+
+# Import ML inference
+from ml_training.inference import load_classifier
+
+logger = logging.getLogger(__name__)
+
+class MLMFCCProcessor(AudioProcessor):
+    """
+    ML-based MFCC digit processor using trained Dense NN model.
+
+    Performance characteristics (based on training results):
+    - Test accuracy: 98.52%
+    - Inference time: ~1-2ms
+    - Model size: ~0.3MB
+    """
+
+    name = "ML MFCC + Dense NN (Best)"
+
+    def __init__(self, model_dir: str = "models", device: str = "auto"):
+        """
+        Initialize ML MFCC processor.
+
+        Args:
+            model_dir: Directory containing trained models
+            device: Device to run inference on ('cpu', 'cuda', or 'auto')
+        """
+        super().__init__(self.name)
+
+        self.model_dir = Path(model_dir)
+        self.device = device if device != "auto" else None
+        self.classifier = None
+        self._configured = False
+
+        # Performance tracking
+        self.prediction_count = 0
+        self.total_inference_time = 0.0
+        self.last_prediction_time = None
+
+        # Try to load the model
+        self._initialize_classifier()
+
+        logger.info(f"ML MFCC Processor initialized (configured: {self._configured})")
+
+    def _initialize_classifier(self):
+        """Initialize the ML classifier."""
+        try:
+            # Check if model directory exists
+            if not self.model_dir.exists():
+                logger.warning(f"Model directory not found: {self.model_dir}")
+                return
+
+            # Load the MFCC classifier
+            self.classifier = load_classifier(
+                model_dir=str(self.model_dir),
+                pipeline_type='mfcc',
+                device=self.device
+            )
+
+            self._configured = True
+            logger.info("ML MFCC classifier loaded successfully")
+            logger.info(f"  Model device: {self.classifier.device}")
+            logger.info(f"  Parameters: {sum(p.numel() for p in self.classifier.model.parameters()):,}")
+
+        except Exception as e:
+            logger.error(f"Failed to load ML MFCC classifier: {str(e)}")
+            self.classifier = None
+            self._configured = False
+
+    def is_configured(self) -> bool:
+        """Check if the processor is properly configured."""
+        return self._configured and self.classifier is not None
+
+    def process_audio(self, audio_data: bytes) -> str:
+        """
+        Process audio and return predicted digit (required by base class).
+
+        Args:
+            audio_data: Raw audio data in bytes
+
+        Returns:
+            predicted_digit: Predicted digit as string
+        """
+        return self.predict(audio_data)
+
+    def predict(self, audio_data: bytes) -> str:
+        """
+        Predict digit from audio data.
+
+        Args:
+            audio_data: Raw audio data in bytes
+
+        Returns:
+            predicted_digit: Predicted digit as string
+        """
+        if not self.is_configured():
+            raise RuntimeError("ML MFCC processor not properly configured")
+
+        try:
+            # Convert audio with optimized format for ML models
+            from utils.audio_utils import convert_for_ml_models
+            optimized_audio = convert_for_ml_models(audio_data, 'mfcc')
+
+            # Convert audio bytes to numpy array
+            audio_array = self._bytes_to_audio_array(optimized_audio)
+
+            # No audio preprocessing needed - normalization happens at feature level in ML pipeline
+
+            # Make prediction using ML classifier
+            start_time = time.time()
+            result = self.classifier.predict(
+                audio_array,
+                return_probabilities=True,
+                return_features=False
+            )
+            inference_time = time.time() - start_time
+
+            # Update performance tracking
+            self.prediction_count += 1
+            self.total_inference_time += inference_time
+            self.last_prediction_time = inference_time
+
+            predicted_digit = str(result['predicted_digit'])
+            confidence = result['confidence']
+
+            # Debug logging for predictions (temporary)
+            if 'probabilities' in result:
+                probs = result.get('probabilities', [])
+                if len(probs) >= 10:
+                    top_predictions = sorted(enumerate(probs), key=lambda x: x[1], reverse=True)
+                    logger.debug(f"MFCC Top 3 predictions: {[(str(d), f'{p:.3f}') for d, p in top_predictions[:3]]}")
+
+            logger.debug(f"ML MFCC prediction: '{predicted_digit}' "
+                         f"(confidence: {confidence:.3f}, time: {inference_time*1000:.1f}ms)")
+
+            return predicted_digit
+
+        except Exception as e:
+            logger.error(f"ML MFCC prediction failed: {str(e)}")
+            raise
+
+    def predict_with_timing(self, audio_data: bytes) -> Dict[str, Any]:
+        """
+        Predict digit with detailed timing and confidence information.
+
+        Args:
+            audio_data: Raw audio data in bytes
+
+        Returns:
+            result: Detailed prediction results
+        """
+        if not self.is_configured():
+            return {
+                'success': False,
+                'error': 'ML MFCC processor not properly configured',
+                'predicted_digit': None,
+                'inference_time': 0.0
+            }
+
+        try:
+            # Convert audio with optimized format for ML models
+            from utils.audio_utils import convert_for_ml_models
+            optimized_audio = convert_for_ml_models(audio_data, 'mfcc')
+
+            # Convert audio bytes to numpy array
+            audio_array = self._bytes_to_audio_array(optimized_audio)
+
+            # No audio preprocessing needed - normalization happens at feature level in ML pipeline
+
+            # Make prediction using ML classifier
+            start_time = time.time()
+            ml_result = self.classifier.predict(
+                audio_array,
+                return_probabilities=True,
+                return_features=False
+            )
+            inference_time = time.time() - start_time
+
+            # Update performance tracking
+            self.prediction_count += 1
+            self.total_inference_time += inference_time
+            self.last_prediction_time = inference_time
+
+            # Format result
+            result = {
+                'success': True,
+                'predicted_digit': str(ml_result['predicted_digit']),
+                'confidence': ml_result['confidence'],
+                'inference_time': inference_time,
+                'class_probabilities': {
+                    str(k): float(v) for k, v in ml_result['class_probabilities'].items()
+                },
+                'top_3_predictions': [
+                    {
+                        'digit': str(pred['digit']),
+                        'probability': pred['probability']
+                    }
+                    for pred in ml_result['top_3_predictions']
+                ],
+                'method': self.name,
+                'model_type': 'ml_mfcc',
+                'timestamp': time.time()
+            }
+
+            logger.debug(f"ML MFCC detailed prediction: '{result['predicted_digit']}' "
+                         f"(confidence: {result['confidence']:.3f}, "
+                         f"time: {inference_time*1000:.1f}ms)")
+
+            return result
+
+        except Exception as e:
+            logger.error(f"ML MFCC prediction with timing failed: {str(e)}")
+            return {
+                'success': False,
+                'error': str(e),
+                'predicted_digit': None,
+                'inference_time': 0.0,
+                'method': self.name,
+                'model_type': 'ml_mfcc',
+                'timestamp': time.time()
+            }
+
+    def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
+        """Convert audio bytes to numpy array."""
+        try:
+            # Try to interpret as int16 PCM first (most common)
+            audio_array = np.frombuffer(audio_data, dtype=np.int16)
+
+            # Convert to float32 and normalize
+            audio_array = audio_array.astype(np.float32) / 32768.0
+
+            # If the array is too short, pad it
+            if len(audio_array) < 1000:  # Less than ~60ms at 16kHz
+                # Pad with zeros to minimum length
+                audio_array = np.pad(audio_array, (0, 1000 - len(audio_array)))
+
+            return audio_array
+
+        except Exception as e:
+            logger.error(f"Failed to convert audio bytes to array: {str(e)}")
+            # Return a small zero array as fallback
+            return np.zeros(1000, dtype=np.float32)
+
+    def _preprocess_audio_for_mfcc(self, audio_array: np.ndarray) -> np.ndarray:
+        """
+        Apply MFCC-specific audio preprocessing to improve model performance.
+        This compensates for missing scaler normalization.
+
+        Args:
+            audio_array: Raw audio array
+
+        Returns:
+            preprocessed_audio: Audio array optimized for MFCC feature extraction
+        """
+        try:
+            # Remove DC component
+            audio_array = audio_array - np.mean(audio_array)
+
+            # Apply gentle normalization to handle volume variations
+            # This helps compensate for the missing feature scaler
+            max_val = np.max(np.abs(audio_array))
+            if max_val > 0:
+                audio_array = audio_array / max_val * 0.7  # Scale to 70% of max to avoid clipping
+
+            # Apply a gentle high-pass filter to remove low-frequency noise
+            # This improves MFCC feature quality
+            from scipy import signal
+            if len(audio_array) > 100:  # Only apply if we have enough samples
+                # Simple high-pass filter at ~300Hz for 8kHz sample rate
+                sos = signal.butter(2, 300, btype='high', fs=8000, output='sos')
+                audio_array = signal.sosfilt(sos, audio_array)
+
+            # Ensure we don't have any NaN or inf values
+            audio_array = np.nan_to_num(audio_array, nan=0.0, posinf=0.0, neginf=0.0)
+
+            logger.debug(f"MFCC preprocessing applied: range=[{np.min(audio_array):.3f}, {np.max(audio_array):.3f}], "
+                         f"mean={np.mean(audio_array):.3f}, std={np.std(audio_array):.3f}")
+
+            return audio_array
+
+        except ImportError:
+            # Fallback if scipy is not available - just normalize
+            logger.warning("Scipy not available, using basic normalization")
+            audio_array = audio_array - np.mean(audio_array)
+            max_val = np.max(np.abs(audio_array))
+            if max_val > 0:
+                audio_array = audio_array / max_val * 0.7
+            return audio_array
+
+        except Exception as e:
+            logger.error(f"MFCC preprocessing failed: {str(e)}")
+            # Return original array if preprocessing fails
+            return audio_array
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get processor performance statistics."""
+        stats = super().get_stats()
+
+        if self.prediction_count > 0:
+            stats.update({
+                'ml_predictions': self.prediction_count,
+                'average_inference_time': self.total_inference_time / self.prediction_count,
+                'last_inference_time': self.last_prediction_time,
+                'throughput_per_second': self.prediction_count / self.total_inference_time if self.total_inference_time > 0 else 0,
+                'model_configured': self.is_configured()
+            })
+
+        if self.classifier:
+            # Get ML classifier performance stats
+            ml_stats = self.classifier.get_performance_stats()
+            stats['ml_classifier_stats'] = ml_stats
+
+        return stats
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the loaded model."""
+        if not self.is_configured():
+            return {'error': 'Model not loaded'}
+
+        try:
+            info = {
+                'pipeline_type': 'mfcc',
+                'model_class': self.classifier.model.__class__.__name__,
+                'device': str(self.classifier.device),
+                'parameters': sum(p.numel() for p in self.classifier.model.parameters()),
+                'feature_extractor': self.classifier.feature_extractor.__class__.__name__,
+                'has_scaler': self.classifier.scaler is not None,
+                'expected_sample_rate': 8000,
+                'expected_audio_length': 8000  # 1 second at 8kHz
+            }
+
+            if hasattr(self.classifier, 'model_path'):
+                info['model_path'] = str(self.classifier.model_path)
+
+            return info
+
+        except Exception as e:
+            logger.error(f"Failed to get model info: {str(e)}")
+            return {'error': str(e)}
+
+    def benchmark_speed(self, num_samples: int = 100) -> Dict[str, Any]:
+        """Benchmark inference speed."""
+        if not self.is_configured():
+            return {'error': 'Model not configured'}
+
+        try:
+            return self.classifier.benchmark_speed(num_samples)
+        except Exception as e:
+            logger.error(f"Benchmark failed: {str(e)}")
+            return {'error': str(e)}
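
The `_bytes_to_audio_array` helper shared by the ML processors maps int16 PCM onto [-1.0, 1.0) by dividing by 32768. A quick self-contained check of that conversion:

# Verifies the int16 -> float32 normalization used by _bytes_to_audio_array.
import numpy as np

pcm_bytes = np.array([0, 16384, -32768, 32767], dtype=np.int16).tobytes()
audio = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0
print(audio)  # [ 0.         0.5       -1.         0.9999695]
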
audio_processors/ml_raw_cnn_processor.py
ADDED
@@ -0,0 +1,307 @@
+"""
+ML Raw CNN Digit Processor
+Uses the trained Raw Waveform + 1D CNN model for digit classification
+"""
+
+import os
+import sys
+import time
+import logging
+from pathlib import Path
+from typing import Dict, Any, Optional, Union
+
+import numpy as np
+from .base_processor import AudioProcessor
+
+# Add project root to path for ML imports
+PROJECT_ROOT = Path(__file__).parent.parent
+sys.path.append(str(PROJECT_ROOT))
+
+# Import ML inference
+from ml_training.inference import load_classifier
+
+logger = logging.getLogger(__name__)
+
+class MLRawCNNProcessor(AudioProcessor):
+    """
+    ML-based Raw CNN digit processor using trained 1D CNN model.
+
+    Performance characteristics (based on training results):
+    - Test accuracy: 91.30%
+    - Inference time: ~5-8ms
+    - Model size: ~2.6MB
+    """
+
+    name = "ML Raw CNN (1D Conv)"
+
+    def __init__(self, model_dir: str = "models", device: str = "auto"):
+        """
+        Initialize ML Raw CNN processor.
+
+        Args:
+            model_dir: Directory containing trained models
+            device: Device to run inference on ('cpu', 'cuda', or 'auto')
+        """
+        super().__init__(self.name)
+
+        self.model_dir = Path(model_dir)
+        self.device = device if device != "auto" else None
+        self.classifier = None
+        self._configured = False
+
+        # Performance tracking
+        self.prediction_count = 0
+        self.total_inference_time = 0.0
+        self.last_prediction_time = None
+
+        # Try to load the model
+        self._initialize_classifier()
+
+        logger.info(f"ML Raw CNN Processor initialized (configured: {self._configured})")
+
+    def _initialize_classifier(self):
+        """Initialize the ML classifier."""
+        try:
+            # Check if model directory exists
+            if not self.model_dir.exists():
+                logger.warning(f"Model directory not found: {self.model_dir}")
+                return
+
+            # Load the Raw CNN classifier
+            self.classifier = load_classifier(
+                model_dir=str(self.model_dir),
+                pipeline_type='raw_cnn',
+                device=self.device
+            )
+
+            self._configured = True
+            logger.info("ML Raw CNN classifier loaded successfully")
+            logger.info(f"  Model device: {self.classifier.device}")
+            logger.info(f"  Parameters: {sum(p.numel() for p in self.classifier.model.parameters()):,}")
+
+        except Exception as e:
+            logger.error(f"Failed to load ML Raw CNN classifier: {str(e)}")
+            self.classifier = None
+            self._configured = False
+
+    def is_configured(self) -> bool:
+        """Check if the processor is properly configured."""
+        return self._configured and self.classifier is not None
+
+    def process_audio(self, audio_data: bytes) -> str:
+        """
+        Process audio and return predicted digit (required by base class).
+
+        Args:
+            audio_data: Raw audio data in bytes
+
+        Returns:
+            predicted_digit: Predicted digit as string
+        """
+        return self.predict(audio_data)
+
+    def predict(self, audio_data: bytes) -> str:
+        """
+        Predict digit from audio data.
+
+        Args:
+            audio_data: Raw audio data in bytes
+
+        Returns:
+            predicted_digit: Predicted digit as string
+        """
+        if not self.is_configured():
+            raise RuntimeError("ML Raw CNN processor not properly configured")
+
+        try:
+            # Convert audio with optimized format for ML models
+            from utils.audio_utils import convert_for_ml_models
+            optimized_audio = convert_for_ml_models(audio_data, 'raw_cnn')
+
+            # Convert audio bytes to numpy array
+            audio_array = self._bytes_to_audio_array(optimized_audio)
+
+            # Make prediction using ML classifier
+            start_time = time.time()
+            result = self.classifier.predict(
+                audio_array,
+                return_probabilities=True,
+                return_features=False
+            )
+            inference_time = time.time() - start_time
+
+            # Update performance tracking
+            self.prediction_count += 1
+            self.total_inference_time += inference_time
+            self.last_prediction_time = inference_time
+
+            predicted_digit = str(result['predicted_digit'])
+            confidence = result['confidence']
+
+            logger.debug(f"ML Raw CNN prediction: '{predicted_digit}' "
+                         f"(confidence: {confidence:.3f}, time: {inference_time*1000:.1f}ms)")
+
+            return predicted_digit
+
+        except Exception as e:
+            logger.error(f"ML Raw CNN prediction failed: {str(e)}")
+            raise
+
+    def predict_with_timing(self, audio_data: bytes) -> Dict[str, Any]:
+        """
+        Predict digit with detailed timing and confidence information.
+
+        Args:
+            audio_data: Raw audio data in bytes
+
+        Returns:
+            result: Detailed prediction results
+        """
+        if not self.is_configured():
+            return {
+                'success': False,
+                'error': 'ML Raw CNN processor not properly configured',
+                'predicted_digit': None,
+                'inference_time': 0.0
+            }
+
+        try:
+            # Convert audio with optimized format for ML models
+            from utils.audio_utils import convert_for_ml_models
+            optimized_audio = convert_for_ml_models(audio_data, 'raw_cnn')
+
+            # Convert audio bytes to numpy array
+            audio_array = self._bytes_to_audio_array(optimized_audio)
+
+            # Make prediction using ML classifier
+            start_time = time.time()
+            ml_result = self.classifier.predict(
+                audio_array,
+                return_probabilities=True,
+                return_features=False
+            )
+            inference_time = time.time() - start_time
+
+            # Update performance tracking
+            self.prediction_count += 1
+            self.total_inference_time += inference_time
+            self.last_prediction_time = inference_time
+
+            # Format result
+            result = {
+                'success': True,
+                'predicted_digit': str(ml_result['predicted_digit']),
+                'confidence': ml_result['confidence'],
+                'inference_time': inference_time,
+                'class_probabilities': {
+                    str(k): float(v) for k, v in ml_result['class_probabilities'].items()
+                },
+                'top_3_predictions': [
+                    {
+                        'digit': str(pred['digit']),
+                        'probability': pred['probability']
+                    }
+                    for pred in ml_result['top_3_predictions']
+                ],
+                'method': self.name,
+                'model_type': 'ml_raw_cnn',
+                'timestamp': time.time()
+            }
+
+            logger.debug(f"ML Raw CNN detailed prediction: '{result['predicted_digit']}' "
+                         f"(confidence: {result['confidence']:.3f}, "
+                         f"time: {inference_time*1000:.1f}ms)")
+
+            return result
+
+        except Exception as e:
+            logger.error(f"ML Raw CNN prediction with timing failed: {str(e)}")
+            return {
+                'success': False,
+                'error': str(e),
+                'predicted_digit': None,
+                'inference_time': 0.0,
+                'method': self.name,
+                'model_type': 'ml_raw_cnn',
+                'timestamp': time.time()
+            }
+
+    def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
+        """Convert audio bytes to numpy array."""
+        try:
+            # Try to interpret as int16 PCM first (most common)
+            audio_array = np.frombuffer(audio_data, dtype=np.int16)
+
+            # Convert to float32 and normalize
+            audio_array = audio_array.astype(np.float32) / 32768.0
+
+            # If the array is too short, pad it
+            if len(audio_array) < 1000:  # Less than ~60ms at 16kHz
+                # Pad with zeros to minimum length
+                audio_array = np.pad(audio_array, (0, 1000 - len(audio_array)))
+
+            return audio_array
+
+        except Exception as e:
+            logger.error(f"Failed to convert audio bytes to array: {str(e)}")
+            # Return a small zero array as fallback
+            return np.zeros(1000, dtype=np.float32)
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get processor performance statistics."""
+        stats = super().get_stats()
+
+        if self.prediction_count > 0:
+            stats.update({
+                'ml_predictions': self.prediction_count,
+                'average_inference_time': self.total_inference_time / self.prediction_count,
+                'last_inference_time': self.last_prediction_time,
+                'throughput_per_second': self.prediction_count / self.total_inference_time if self.total_inference_time > 0 else 0,
+                'model_configured': self.is_configured()
+            })
+
+        if self.classifier:
+            # Get ML classifier performance stats
+            ml_stats = self.classifier.get_performance_stats()
+            stats['ml_classifier_stats'] = ml_stats
+
+        return stats
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get information about the loaded model."""
+        if not self.is_configured():
+            return {'error': 'Model not loaded'}
+
+        try:
+            info = {
+                'pipeline_type': 'raw_cnn',
+                'model_class': self.classifier.model.__class__.__name__,
+                'device': str(self.classifier.device),
+                'parameters': sum(p.numel() for p in self.classifier.model.parameters()),
+                'feature_extractor': None,  # Raw waveforms don't need feature extraction
+                'has_scaler': False,
+                'expected_sample_rate': 8000,
+                'expected_audio_length': 8000,  # 1 second at 8kHz
+                'input_shape': '(1, 1, 8000)',  # Raw waveform shape
+                'model_architecture': '1D CNN'
+            }
+
+            if hasattr(self.classifier, 'model_path'):
+                info['model_path'] = str(self.classifier.model_path)
+
+            return info
+
+        except Exception as e:
+            logger.error(f"Failed to get model info: {str(e)}")
+            return {'error': str(e)}
+
+    def benchmark_speed(self, num_samples: int = 100) -> Dict[str, Any]:
+        """Benchmark inference speed."""
+        if not self.is_configured():
+            return {'error': 'Model not configured'}
+
+        try:
+            return self.classifier.benchmark_speed(num_samples)
+        except Exception as e:
+            logger.error(f"Benchmark failed: {str(e)}")
+            return {'error': str(e)}
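
Since the three ML processors share one interface, they can be compared side by side. A hypothetical sketch, assuming all three checkpoints under models/ load:

# Hypothetical side-by-side check of the three trained pipelines.
from audio_processors.ml_mfcc_processor import MLMFCCProcessor
from audio_processors.ml_mel_cnn_processor import MLMelCNNProcessor
from audio_processors.ml_raw_cnn_processor import MLRawCNNProcessor

for cls in (MLMFCCProcessor, MLMelCNNProcessor, MLRawCNNProcessor):
    p = cls(model_dir="models")
    print(f"{p.name}: configured={p.is_configured()}")
    if p.is_configured():
        print(p.benchmark_speed(num_samples=20))  # latency stats from the classifier
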
audio_processors/raw_spectrogram.py
ADDED
@@ -0,0 +1,69 @@
+import numpy as np
+import logging
+from .base_processor import AudioProcessor
+
+logger = logging.getLogger(__name__)
+
+class RawSpectrogramProcessor(AudioProcessor):
+    """
+    Raw Spectrogram processor using STFT (Short-Time Fourier Transform).
+
+    Future implementation will:
+    - Apply STFT to audio data for time-frequency representation
+    - Use CNN classifier trained on spectrogram images
+    - Process raw frequency domain features without mel scaling
+
+    Currently returns placeholder '00' for testing UI functionality.
+    """
+
+    def __init__(self):
+        super().__init__("Raw Spectrogram")
+        logger.info("Raw Spectrogram processor initialized (PLACEHOLDER MODE)")
+
+    def process_audio(self, audio_data: bytes) -> str:
+        """
+        Process audio using raw spectrogram analysis.
+
+        PLACEHOLDER IMPLEMENTATION:
+        Currently returns '00' for UI testing purposes.
+
+        Future implementation will:
+        1. Convert audio bytes to numpy array
+        2. Apply STFT with appropriate window size and overlap
+        3. Create time-frequency representation
+        4. Normalize spectrogram values
+        5. Feed to trained CNN model
+        6. Return predicted digit
+
+        Args:
+            audio_data: Raw audio bytes
+
+        Returns:
+            Predicted digit as string (currently '00')
+        """
+        logger.debug("Processing audio with Raw Spectrogram (placeholder)")
+
+        # Simulate processing time
+        import time
+        time.sleep(0.1)
+
+        # TODO: Implement actual STFT-based processing:
+        # 1. audio_array = np.frombuffer(audio_data, dtype=np.float32)
+        # 2. stft_result = np.abs(librosa.stft(audio_array, n_fft=2048, hop_length=512))
+        # 3. spectrogram = librosa.amplitude_to_db(stft_result, ref=np.max)
+        # 4. prediction = self.cnn_model.predict(spectrogram)
+        # 5. return str(np.argmax(prediction))
+
+        return '00'
+
+    def get_model_info(self) -> dict:
+        """Get information about the raw spectrogram model."""
+        return {
+            'method': 'Raw Spectrogram (STFT)',
+            'status': 'PLACEHOLDER',
+            'features': 'Time-frequency representation',
+            'classifier': 'CNN (not implemented)',
+            'window_size': 2048,
+            'hop_length': 512,
+            'expected_inference_time': '<1s'
+        }
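
A minimal sketch of the STFT TODO above, assuming librosa and a trained `cnn_model` (neither is part of this commit):

# Hypothetical sketch of the raw-spectrogram TODO above; `cnn_model` is an
# assumption, not part of this commit.
import librosa
import numpy as np

def predict_digit_from_stft(audio_data: bytes, cnn_model) -> str:
    audio_array = np.frombuffer(audio_data, dtype=np.float32)
    # Magnitude STFT with the window/hop settings from get_model_info()
    stft_mag = np.abs(librosa.stft(audio_array, n_fft=2048, hop_length=512))
    # dB scaling normalizes the wide dynamic range of raw magnitudes
    spectrogram = librosa.amplitude_to_db(stft_mag, ref=np.max)
    prediction = cnn_model.predict(spectrogram[np.newaxis, ..., np.newaxis])
    return str(np.argmax(prediction))
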
audio_processors/wav2vec2_processor.py
ADDED
@@ -0,0 +1,170 @@
+import logging
+import numpy as np
+from typing import Optional
+from .base_processor import AudioProcessor
+
+logger = logging.getLogger(__name__)
+
+class Wav2Vec2Processor(AudioProcessor):
+    """
+    Wav2Vec2 model processor for speech recognition.
+    Lightweight alternative to Whisper.
+    """
+
+    def __init__(self):
+        super().__init__("Wav2Vec2 (Facebook)")
+        self.processor = None
+        self.model = None
+        self.model_name = "facebook/wav2vec2-base-960h"
+        self.is_initialized = False
+
+    def _initialize_model(self):
+        """Lazy initialization of the model"""
+        if self.is_initialized:
+            return
+
+        try:
+            logger.info(f"Loading Wav2Vec2 model: {self.model_name}")
+
+            from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+            import torch
+
+            # Load processor and model
+            self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
+            self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
+
+            # Move to GPU if available
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.model = self.model.to(device)
+            self.device = device
+
+            logger.info(f"Wav2Vec2 model loaded on {device}")
+            self.is_initialized = True
+
+        except ImportError:
+            logger.error("transformers library not installed. Run: pip install transformers torch")
+            raise Exception("transformers library required for Wav2Vec2 processing")
+        except Exception as e:
+            logger.error(f"Failed to load Wav2Vec2 model: {str(e)}")
+            raise Exception(f"Wav2Vec2 model initialization failed: {str(e)}")
+
+    def process_audio(self, audio_data: bytes) -> str:
+        """
+        Process audio using Wav2Vec2 model.
+
+        Args:
+            audio_data: Raw audio bytes (WAV format preferred)
+
+        Returns:
+            Predicted digit as string ('0'-'9')
+
+        Raises:
+            Exception: If processing fails
+        """
+        try:
+            # Initialize model on first use
+            self._initialize_model()
+
+            # Convert audio bytes to numpy array
+            from utils.audio_utils import audio_to_numpy
+            audio_array, sample_rate = audio_to_numpy(audio_data)
+
+            # Resample to 16kHz if needed (Wav2Vec2 expects 16kHz)
+            if sample_rate != 16000:
+                logger.debug(f"Resampling from {sample_rate}Hz to 16kHz")
+                import librosa
+                audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
+
+            logger.debug(f"Processing audio: {len(audio_array)} samples at 16kHz")
+
+            # Process with Wav2Vec2
+            import torch
+
+            # Tokenize audio
+            input_values = self.processor(
+                audio_array,
+                return_tensors="pt",
+                padding="longest",
+                sampling_rate=16000
+            ).input_values.to(self.device)
+
+            # Get logits
+            with torch.no_grad():
+                logits = self.model(input_values).logits
+
+            # Get predicted tokens
+            predicted_ids = torch.argmax(logits, dim=-1)
+
+            # Decode transcription
+            transcription = self.processor.batch_decode(predicted_ids)[0].lower().strip()
+            logger.debug(f"Wav2Vec2 transcription: '{transcription}'")
+
+            # Extract digit from transcription
+            predicted_digit = self._extract_digit(transcription)
+
+            if predicted_digit is None:
+                logger.warning(f"No digit found in transcription: '{transcription}'")
+                return "?"
+
+            return predicted_digit
+
+        except Exception as e:
+            logger.error(f"Wav2Vec2 processing failed: {str(e)}")
+            raise Exception(f"Wav2Vec2 processing error: {str(e)}")
+
+    def _extract_digit(self, text: str) -> Optional[str]:
+        """
+        Extract digit from transcribed text.
+        Handles both numerical ('1', '2') and word forms ('one', 'two').
+        """
+        import re
+
+        # Word to digit mapping
+        word_to_digit = {
+            'zero': '0', 'oh': '0',
+            'one': '1', 'won': '1',
+            'two': '2', 'to': '2', 'too': '2',
+            'three': '3', 'tree': '3',
+            'four': '4', 'for': '4', 'fore': '4', 'full': '4',  # "full" often misheard as "four"
+            'five': '5',
+            'six': '6', 'sick': '6',
+            'seven': '7',
+            'eight': '8', 'ate': '8',
+            'nine': '9', 'niner': '9'
+        }
+
+        # First, try to find a direct digit
+        digit_match = re.search(r'\b([0-9])\b', text)
+        if digit_match:
+            return digit_match.group(1)
+
+        # Then try word forms
+        words = text.split()
+        for word in words:
+            clean_word = re.sub(r'[^\w]', '', word.lower())
+            if clean_word in word_to_digit:
+                return word_to_digit[clean_word]
+
+        # Try partial matches for robustness
+        for word, digit in word_to_digit.items():
+            if word in text:
+                return digit
+
+        return None
+
+    def is_configured(self) -> bool:
+        """Check if Wav2Vec2 model can be initialized."""
+        try:
+            import transformers
+            import torch
+            return True
+        except ImportError:
+            return False
+
+    def test_connection(self) -> bool:
+        """Test Wav2Vec2 model functionality."""
+        try:
+            self._initialize_model()
+            return True
+        except Exception:
+            return False
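
Because `__init__` defers model loading, the transcript-to-digit fallback chain above can be exercised without downloading any weights. A small illustration (the inputs are made up):

# Illustrates the three fallbacks in _extract_digit; no model download needed,
# since the Wav2Vec2 weights load lazily on the first process_audio call.
from audio_processors.wav2vec2_processor import Wav2Vec2Processor

p = Wav2Vec2Processor()
print(p._extract_digit("the number is 7"))  # '7'  -- direct digit match
print(p._extract_digit("tree"))             # '3'  -- mapped common mishearing
print(p._extract_digit("hello"))            # None -- no digit recoverable
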
audio_processors/whisper_digit_processor.py
ADDED
@@ -0,0 +1,429 @@
"""
Whisper-based digit recognition processor
Specialized implementation for spoken digit recognition (0-9)
"""

import numpy as np
import io
import time
import logging
from typing import Dict, Any, Optional
import torch
from transformers import pipeline
import soundfile as sf

from .base_processor import AudioProcessor

logger = logging.getLogger(__name__)

class WhisperDigitProcessor(AudioProcessor):
    """
    Whisper-based digit recognition processor using Hugging Face transformers.
    Optimized for single digit recognition with mapping from text to numbers.
    """

    def __init__(self):
        """Initialize Whisper digit processor with optimized settings."""
        super().__init__("Whisper Digit Recognition")
        self.model = None
        self.device = 0 if torch.cuda.is_available() else -1

        # Digit mapping for text-to-number conversion
        self.digit_map = {
            "zero": "0", "one": "1", "two": "2", "three": "3",
            "four": "4", "five": "5", "six": "6", "seven": "7",
            "eight": "8", "nine": "9",
            # Common variations and alternatives
            "oh": "0", "o": "0",
            "for": "4", "fore": "4", "to": "2", "too": "2", "tu": "2",
            "tree": "3", "free": "3", "ate": "8", "ait": "8"
        }

        # Reverse mapping for validation
        self.number_words = set(self.digit_map.keys())

        # Statistics tracking
        self.total_predictions = 0
        self.successful_predictions = 0
        self.failed_predictions = 0
        self.average_inference_time = 0.0

        self._initialize_model()

    def _initialize_model(self):
        """Initialize the Whisper model with optimal settings for digit recognition."""
        try:
            logger.info("Initializing Whisper model for digit recognition...")

            # Use Whisper tiny model for fast inference
            self.model = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-tiny",
                device=self.device,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                return_timestamps=False  # We don't need timestamps for single digits
            )

            logger.info(f"Whisper model initialized successfully on device: {self.device}")

            # Test model with dummy input
            test_audio = np.random.randn(16000).astype(np.float32)  # 1 second of noise
            try:
                test_result = self.model(test_audio)
                logger.info("Model test successful")
            except Exception as e:
                logger.warning(f"Model test failed but model loaded: {e}")

            return True

        except Exception as e:
            logger.error(f"Failed to initialize Whisper model: {e}")
            return False

    def is_configured(self) -> bool:
        """Check if the processor is properly configured."""
        return self.model is not None

    def process_audio(self, audio_data: bytes) -> str:
        """
        Predict digit from audio data.

        Args:
            audio_data: Raw audio bytes (WAV format preferred)

        Returns:
            str: Predicted digit (0-9) or error message
        """
        if not self.is_configured():
            return "error: Model not configured"

        try:
            # Convert audio bytes to numpy array
            audio_array = self._convert_audio_to_array(audio_data)

            if audio_array is None:
                return "error: Invalid audio format"

            # Ensure proper sample rate and format
            audio_array = self._preprocess_audio(audio_array)

            # Run Whisper inference
            result = self.model(audio_array)
            text = result["text"].strip().lower()

            # Convert text to digit
            digit = self._text_to_digit(text)

            # Enhanced logging to debug transcription issues
            logger.info(f"🎤 Whisper transcription: '{text}' -> digit: '{digit}'")
            logger.info(f"📊 Audio stats: duration={len(audio_array)/16000:.2f}s, samples={len(audio_array)}, max_val={np.max(np.abs(audio_array)):.3f}")

            if digit in "0123456789":
                self.successful_predictions += 1
                return digit
            else:
                self.failed_predictions += 1
                return f"unclear: {text}"

        except Exception as e:
            logger.error(f"Whisper prediction failed: {e}")
            self.failed_predictions += 1
            return f"error: {str(e)}"
        finally:
            self.total_predictions += 1

    def _convert_audio_to_array(self, audio_data: bytes) -> Optional[np.ndarray]:
        """
        Convert audio bytes to numpy array.

        Args:
            audio_data: Raw audio bytes (could be WAV file or raw PCM from VAD)

        Returns:
            np.ndarray: Audio samples or None if conversion failed
        """
        # First check if this looks like raw PCM data from VAD (no file headers)
        if len(audio_data) < 100 or not audio_data.startswith(b'RIFF'):
            # This is likely raw PCM data from WebRTC VAD
            try:
                logger.debug("Processing raw PCM data from VAD segment")
                audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32)
                audio_array = audio_array / 32768.0  # Normalize to [-1, 1]
                self._original_sample_rate = 16000  # WebRTC VAD uses 16kHz
                return audio_array
            except Exception as e:
                logger.error(f"Failed to process raw PCM data: {e}")
                return None

        # This looks like a complete audio file (WAV, etc.)
        try:
            # Try to read as audio file using soundfile
            audio_buffer = io.BytesIO(audio_data)
            audio_array, sample_rate = sf.read(audio_buffer, dtype='float32')

            # Handle stereo to mono conversion
            if len(audio_array.shape) > 1:
                audio_array = np.mean(audio_array, axis=1)

            # Store original sample rate for resampling
            self._original_sample_rate = sample_rate

            logger.debug(f"Successfully loaded audio file: {len(audio_array)} samples at {sample_rate}Hz")
            return audio_array

        except Exception as e:
            logger.warning(f"Audio file conversion failed with soundfile: {e}")

            # Final fallback: treat as raw PCM
            try:
                logger.debug("Fallback: treating as raw PCM data")
                audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32)
                audio_array = audio_array / 32768.0  # Normalize to [-1, 1]
                self._original_sample_rate = 16000  # Assume 16kHz
                return audio_array
            except Exception as e2:
                logger.error(f"All audio conversion methods failed: {e2}")
                return None

    def _preprocess_audio(self, audio_array: np.ndarray) -> np.ndarray:
        """
        Preprocess audio for optimal Whisper performance.

        Args:
            audio_array: Raw audio samples

        Returns:
            np.ndarray: Preprocessed audio
        """
        # Resample to 16kHz if needed (Whisper's expected input)
        target_sample_rate = 16000

        if hasattr(self, '_original_sample_rate') and self._original_sample_rate != target_sample_rate:
            try:
                import librosa
                audio_array = librosa.resample(
                    audio_array,
                    orig_sr=self._original_sample_rate,
                    target_sr=target_sample_rate
                )
                logger.debug(f"Resampled audio from {self._original_sample_rate}Hz to {target_sample_rate}Hz")
            except ImportError:
                logger.warning("librosa not available for resampling, using original audio")
            except Exception as e:
                logger.warning(f"Resampling failed: {e}, using original audio")

        # Trim silence from edges
        audio_array = self._trim_silence(audio_array)

        # Ensure minimum length (Whisper works better with at least 0.1s)
        min_samples = int(0.1 * target_sample_rate)
        if len(audio_array) < min_samples:
            # Pad with silence
            padding = min_samples - len(audio_array)
            audio_array = np.pad(audio_array, (0, padding), mode='constant', constant_values=0)

        # Normalize audio
        max_val = np.max(np.abs(audio_array))
        if max_val > 0:
            audio_array = audio_array / max_val * 0.9  # Prevent clipping

        return audio_array

    def _trim_silence(self, audio_array: np.ndarray, silence_threshold: float = 0.01) -> np.ndarray:
        """
        Trim silence from beginning and end of audio.

        Args:
            audio_array: Audio samples
            silence_threshold: Threshold for silence detection

        Returns:
            np.ndarray: Trimmed audio
        """
        if len(audio_array) == 0:
            return audio_array

        # Find non-silent regions
        energy = audio_array ** 2
        non_silent = energy > silence_threshold

        if not np.any(non_silent):
            return audio_array  # All silence, return as is

        # Find first and last non-silent samples
        first_sound = np.argmax(non_silent)
        last_sound = len(non_silent) - np.argmax(non_silent[::-1]) - 1

        # Add small padding
        padding_samples = int(0.05 * 16000)  # 50ms padding
        first_sound = max(0, first_sound - padding_samples)
        last_sound = min(len(audio_array) - 1, last_sound + padding_samples)

        return audio_array[first_sound:last_sound + 1]

    def _text_to_digit(self, text: str) -> str:
        """
        Convert transcribed text to digit.

        Args:
            text: Transcribed text from Whisper

        Returns:
            str: Digit (0-9) or original text if no match
        """
        # Clean the text
        text = text.strip().lower()

        # Remove common punctuation and extra words
        text = text.replace(",", "").replace(".", "").replace("!", "").replace("?", "")
        text = text.replace("the", "").replace("number", "").replace("digit", "")
        text = text.strip()

        # Try direct mapping
        if text in self.digit_map:
            return self.digit_map[text]

        # Try word-by-word mapping for multi-word responses
        words = text.split()
        for word in words:
            if word in self.digit_map:
                return self.digit_map[word]

        # Check if it's already a digit
        if len(text) == 1 and text.isdigit():
            return text

        # Look for digits in the text
        digits_found = [char for char in text if char.isdigit()]
        if digits_found:
            return digits_found[0]  # Return first digit found

        # No clear digit found
        return text

    def predict_with_timing(self, audio_data: bytes) -> Dict[str, Any]:
        """
        Predict digit with detailed timing and confidence metrics.

        Args:
            audio_data: Raw audio bytes

        Returns:
            dict: Prediction results with timing and metadata
        """
        start_time = time.time()

        predicted_digit = self.process_audio(audio_data)

        inference_time = time.time() - start_time

        # Update average inference time
        if self.total_predictions > 0:
            self.average_inference_time = (
                (self.average_inference_time * (self.total_predictions - 1) + inference_time)
                / self.total_predictions
            )

        # Determine success status
        is_successful = predicted_digit in "0123456789"
        confidence_score = 1.0 if is_successful else 0.0

        # Extract any error information
        error_info = None
        if predicted_digit.startswith("error:"):
            error_info = predicted_digit[6:].strip()
            predicted_digit = "unknown"
        elif predicted_digit.startswith("unclear:"):
            error_info = f"Transcription unclear: {predicted_digit[8:].strip()}"
            predicted_digit = "unknown"

        result = {
            'predicted_digit': predicted_digit,
            'confidence_score': confidence_score,
            'inference_time': round(inference_time, 4),
            'success': is_successful,
            'timestamp': time.time(),
            'model': 'openai/whisper-tiny',
            'method': 'whisper_digit'
        }

        if error_info:
            result['error'] = error_info

        return result

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the loaded model.

        Returns:
            dict: Model information
        """
        return {
            'model_name': 'openai/whisper-tiny',
            'model_type': 'Speech-to-Text (ASR)',
            'specialized_for': 'Digit Recognition (0-9)',
            'device': 'GPU' if self.device >= 0 else 'CPU',
            'torch_device': self.device,
            'supports_streaming': False,
            'supported_languages': ['en'],
            'digit_mappings': len(self.digit_map)
        }

    def get_stats(self) -> Dict[str, Any]:
        """
        Get processor statistics.

        Returns:
            dict: Performance statistics
        """
        success_rate = (
            self.successful_predictions / max(1, self.total_predictions)
        )

        return {
            'total_predictions': self.total_predictions,
            'successful_predictions': self.successful_predictions,
            'failed_predictions': self.failed_predictions,
            'success_rate': round(success_rate, 3),
            'average_inference_time': round(self.average_inference_time, 4),
            'model_loaded': self.is_configured()
        }

    def test_with_sample_audio(self) -> Dict[str, Any]:
        """
        Test the processor with generated sample audio.

        Returns:
            dict: Test results
        """
        if not self.is_configured():
            return {'error': 'Model not configured'}

        try:
            # Generate simple test audio (1 second of tone)
            sample_rate = 16000
            duration = 1.0
            frequency = 440  # A note

            t = np.linspace(0, duration, int(sample_rate * duration))
            test_audio = 0.3 * np.sin(2 * np.pi * frequency * t).astype(np.float32)

            # Run prediction
            start_time = time.time()
            result = self.model(test_audio)
            test_time = time.time() - start_time

            return {
                'test_successful': True,
                'test_time': round(test_time, 4),
                'transcription': result.get('text', 'No text'),
                'model_responsive': True
            }

        except Exception as e:
            return {
                'test_successful': False,
                'error': str(e),
                'model_responsive': False
            }
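For a local smoke test, the processor can be driven end to end from a script. A minimal sketch, assuming the package root is on PYTHONPATH and that sample_seven.wav is a hypothetical 16kHz mono recording of a spoken digit:

from audio_processors.whisper_digit_processor import WhisperDigitProcessor

processor = WhisperDigitProcessor()  # loads openai/whisper-tiny on construction

with open("sample_seven.wav", "rb") as f:  # hypothetical test recording
    audio_bytes = f.read()

result = processor.predict_with_timing(audio_bytes)
print(result["predicted_digit"], result["inference_time"], result["success"])
print(processor.get_stats())  # running success/failure counters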
models/mel_cnn_classifier/best_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:784cd9615368040ec7f4fa393f4bbfa8effa8b66b5a526cb2d82f3c526537ae7
size 7876706
models/mfcc_classifier/best_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:35346777b57dd72acf2599359e153336859ae5af05e991e0419a3c0f8fff0248
size 1019362
models/mfcc_classifier/scaler.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5af2561a5be3934fb43d590605bbb9a6293e93935a975d904c7eac5bfe876c1
size 4202
models/raw_cnn_classifier/best_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fdcae9f8fed4d05a27149a6258ba44e9350924fc571f6fe87deaf9cd4f4a3a0e
size 7728930
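The .pt and .pkl entries above are Git LFS pointer files: the repository tracks only a version line, the SHA-256 object id, and the byte size, while the actual weights live in LFS storage and are fetched at checkout (or via git lfs pull). Once present on disk, a checkpoint loads as usual; a minimal inspection sketch, noting that the saved layout (whole module vs. state_dict) is an assumption here since the classifier definitions live elsewhere in the repo:

import torch

# map_location="cpu" matches the CPU-only torch build pinned in requirements_hf.txt
checkpoint = torch.load("models/mfcc_classifier/best_model.pt", map_location="cpu")

if isinstance(checkpoint, dict):
    # Likely a state_dict (or a dict wrapping one) - list a few parameter names
    print(sorted(checkpoint.keys())[:5])
else:
    # A pickled nn.Module saved whole; switch to inference mode
    checkpoint.eval()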
requirements_hf.txt
ADDED
@@ -0,0 +1,26 @@
# HF Spaces Requirements - Essential packages only
# Core Flask API
Flask==2.3.3
Flask-CORS==4.0.0
requests==2.31.0
python-dotenv==1.0.0

# Audio Processing Core
numpy==1.24.3
librosa==0.10.1
scipy==1.11.4
soundfile==0.12.1

# ML Models - PyTorch (CPU optimized for HF Spaces)
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.0.1+cpu
torchaudio==2.0.2+cpu

# Essential ML utilities
scikit-learn==1.3.2
transformers==4.35.2

# Audio format handling
webrtcvad==2.0.10

# Logging and utilities
tqdm==4.66.1
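pip accepts --extra-index-url only as a standalone line in a requirements file, so the PyTorch CPU index is declared separately from the pinned +cpu wheels; pip install -r requirements_hf.txt then resolves torch and torchaudio from that index. A quick post-install sanity check (illustrative):

import torch

print(torch.__version__)          # expected: "2.0.1+cpu"
print(torch.cuda.is_available())  # False on the CPU-only build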
utils/__init__.py
ADDED
File without changes
utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (142 Bytes)
utils/__pycache__/audio_utils.cpython-312.pyc
ADDED
Binary file (20.5 kB)
utils/__pycache__/enhanced_vad.cpython-312.pyc
ADDED
Binary file (25.7 kB)
utils/__pycache__/logging_utils.cpython-312.pyc
ADDED
Binary file (9.6 kB)
utils/__pycache__/noise_utils.cpython-312.pyc
ADDED
Binary file (12.9 kB)
utils/__pycache__/session_manager.cpython-312.pyc
ADDED
Binary file (16.8 kB)
utils/__pycache__/vad_feature_integration.cpython-312.pyc
ADDED
Binary file (21.2 kB)
utils/__pycache__/webm_converter.cpython-312.pyc
ADDED
Binary file (5.77 kB)
utils/__pycache__/webrtc_vad.cpython-312.pyc
ADDED
Binary file (20.6 kB)
utils/audio_utils.py
ADDED
@@ -0,0 +1,427 @@
import numpy as np
import wave
import io
import logging
import subprocess
import tempfile
import os
from pathlib import Path
from typing import Tuple, Optional

logger = logging.getLogger(__name__)

def check_ffmpeg_available() -> bool:
    """Check if ffmpeg is available on the system."""
    try:
        result = subprocess.run(['ffmpeg', '-version'],
                                capture_output=True,
                                text=True,
                                timeout=5)
        return result.returncode == 0
    except (subprocess.SubprocessError, FileNotFoundError, subprocess.TimeoutExpired):
        return False

def convert_with_ffmpeg(audio_data: bytes, target_sr: int = 8000, target_format: str = 'wav') -> Optional[bytes]:
    """
    Convert audio using ffmpeg for high-quality format conversion.

    Args:
        audio_data: Input audio bytes in any format
        target_sr: Target sampling rate (default: 8000 Hz for ML models)
        target_format: Target audio format (default: wav)

    Returns:
        Converted audio bytes or None if conversion fails
    """
    if not check_ffmpeg_available():
        logger.warning("ffmpeg not available for audio conversion")
        return None

    temp_input = None
    temp_output = None

    try:
        # Create temporary files
        with tempfile.NamedTemporaryFile(suffix='.input', delete=False) as temp_input:
            temp_input.write(audio_data)
            temp_input.flush()

        with tempfile.NamedTemporaryFile(suffix=f'.{target_format}', delete=False) as temp_output:
            pass  # Just need the filename

        # Build ffmpeg command for high-quality conversion
        ffmpeg_cmd = [
            'ffmpeg',
            '-i', temp_input.name,
            '-ar', str(target_sr),    # Resample to target sample rate
            '-ac', '1',               # Convert to mono
            '-acodec', 'pcm_s16le',   # 16-bit PCM (standard for ML)
            '-f', target_format,      # Output format
            '-loglevel', 'error',     # Reduce ffmpeg output
            '-y',                     # Overwrite output
            temp_output.name
        ]

        logger.debug(f"Running ffmpeg conversion: {' '.join(ffmpeg_cmd)}")

        # Run ffmpeg conversion
        result = subprocess.run(ffmpeg_cmd,
                                capture_output=True,
                                text=True,
                                timeout=30)

        if result.returncode == 0:
            # Read converted audio
            with open(temp_output.name, 'rb') as f:
                converted_audio = f.read()

            logger.debug(f"ffmpeg conversion successful: "
                         f"{len(audio_data)} -> {len(converted_audio)} bytes "
                         f"({target_sr}Hz, mono, {target_format})")

            return converted_audio
        else:
            logger.error(f"ffmpeg conversion failed: {result.stderr}")
            return None

    except Exception as e:
        logger.error(f"ffmpeg conversion error: {str(e)}")
        return None

    finally:
        # Clean up temporary files
        try:
            if temp_input and os.path.exists(temp_input.name):
                os.unlink(temp_input.name)
            if temp_output and os.path.exists(temp_output.name):
                os.unlink(temp_output.name)
        except Exception as cleanup_error:
            logger.warning(f"Failed to cleanup temp files: {cleanup_error}")

def convert_for_ml_models(audio_data: bytes, pipeline_type: str = 'mfcc') -> bytes:
    """
    Convert audio specifically for ML model requirements.

    Args:
        audio_data: Input audio bytes
        pipeline_type: ML pipeline type ('mfcc', 'mel_cnn', 'raw_cnn')

    Returns:
        Audio bytes optimized for the specific ML model
    """
    # All our ML models expect 8kHz, mono, 16-bit PCM
    target_sr = 8000

    # Try ffmpeg first for best quality
    converted = convert_with_ffmpeg(audio_data, target_sr=target_sr)
    if converted:
        logger.debug(f"Used ffmpeg for {pipeline_type} model audio conversion")
        return converted

    # Fallback to existing conversion methods
    logger.debug(f"Using fallback audio conversion for {pipeline_type} model")
    return convert_audio_format(audio_data)

def validate_audio_format(audio_data: bytes) -> bool:
    """
    Validate that audio data is in a supported format.

    Args:
        audio_data: Raw audio bytes

    Returns:
        True if format is supported, False otherwise
    """
    # Check minimum size
    if len(audio_data) < 44:  # WAV header is 44 bytes
        logger.debug(f"Audio data too small: {len(audio_data)} bytes (minimum 44 for WAV header)")
        return False

    # Check for null/empty data
    if audio_data[:20] == b'\x00' * 20:
        logger.error("Audio data appears to be empty/null bytes")
        return False

    # Check if it starts with RIFF header
    if not audio_data.startswith(b'RIFF'):
        logger.error(f"Audio data does not start with RIFF header. First 8 bytes: {audio_data[:8]}")
        # Try to provide more diagnostic info
        if len(audio_data) > 20:
            logger.error(f"First 20 bytes as hex: {audio_data[:20].hex()}")
        return False

    try:
        with wave.open(io.BytesIO(audio_data), 'rb') as wav_file:
            # Check basic WAV properties
            channels = wav_file.getnchannels()
            sample_width = wav_file.getsampwidth()
            frame_rate = wav_file.getframerate()
            frames = wav_file.getnframes()

            logger.debug(f"Audio format: {channels} channels, {sample_width} bytes/sample, {frame_rate} Hz, {frames} frames")

            # Be more lenient with streaming chunks
            if channels not in [1, 2]:
                logger.warning(f"Unusual channel count: {channels}")
                return False
            if sample_width not in [1, 2, 4]:  # 8-bit, 16-bit, 32-bit
                logger.warning(f"Unusual sample width: {sample_width}")
                return False
            if frame_rate < 8000 or frame_rate > 48000:  # Wider range
                logger.warning(f"Unusual frame rate: {frame_rate}")
                return False
            if frames == 0:
                logger.warning("No audio frames found")
                return False

            return True
    except wave.Error as e:
        logger.error(f"WAV format error: {str(e)}")
        logger.error(f"Audio data size: {len(audio_data)} bytes")
        if len(audio_data) > 44:
            logger.error(f"WAV header bytes: {audio_data[:44].hex()}")
        return False
    except Exception as e:
        logger.error(f"Audio validation failed: {str(e)}")
        logger.error(f"Audio data size: {len(audio_data)} bytes")
        return False

def convert_audio_format(audio_data: bytes) -> bytes:
    """
    Convert various audio formats (WebM, OGG, MP3, etc.) to WAV format.

    Args:
        audio_data: Input audio bytes in any supported format

    Returns:
        Converted audio bytes in WAV format

    Raises:
        Exception: If conversion fails
    """
    try:
        # First detect the audio format
        from .webm_converter import detect_audio_format, convert_webm_to_wav

        audio_format = detect_audio_format(audio_data)
        logger.debug(f"Detected audio format: {audio_format}")

        # Handle WebM specifically (common from MediaRecorder)
        if audio_format == 'webm':
            logger.info("Converting WebM audio to WAV (fallback method)")
            converted = convert_webm_to_wav(audio_data)
            if converted:
                return converted
            else:
                raise Exception("WebM conversion failed")

        # Try using pydub for format conversion (handles WebM, OGG, MP3, etc.)
        try:
            from pydub import AudioSegment
            import io

            # Load audio from bytes
            audio = AudioSegment.from_file(io.BytesIO(audio_data))

            # Convert to mono and 16kHz
            audio = audio.set_channels(1)        # Mono
            audio = audio.set_frame_rate(16000)  # 16kHz
            audio = audio.set_sample_width(2)    # 16-bit

            # Export as WAV
            output_buffer = io.BytesIO()
            audio.export(output_buffer, format="wav")
            return output_buffer.getvalue()

        except ImportError:
            logger.warning("pydub not installed, falling back to basic WAV conversion")
            # Fall back to basic WAV processing
            return convert_to_mono_16khz(audio_data)
        except Exception as e:
            logger.warning(f"pydub conversion failed: {str(e)}, trying fallback methods")

            # Try WebM converter as fallback
            if audio_format in ['webm', 'unknown']:
                logger.info("Trying WebM fallback converter")
                converted = convert_webm_to_wav(audio_data)
                if converted:
                    return converted

            # Last resort: basic WAV processing
            return convert_to_mono_16khz(audio_data)

    except Exception as e:
        logger.error(f"All audio conversion methods failed: {str(e)}")
        raise Exception(f"Failed to convert audio format: {str(e)}")

def convert_to_mono_16khz(audio_data: bytes) -> bytes:
    """
    Convert audio to mono, 16kHz format suitable for speech recognition.

    Args:
        audio_data: Input audio bytes (WAV format)

    Returns:
        Converted audio bytes in mono 16kHz WAV format

    Raises:
        Exception: If conversion fails
    """
    try:
        with wave.open(io.BytesIO(audio_data), 'rb') as input_wav:
            frames = input_wav.readframes(input_wav.getnframes())
            channels = input_wav.getnchannels()
            sample_width = input_wav.getsampwidth()
            frame_rate = input_wav.getframerate()

            # Convert to numpy array
            if sample_width == 2:
                audio_array = np.frombuffer(frames, dtype=np.int16)
            else:
                raise Exception(f"Unsupported sample width: {sample_width}")

            # Convert stereo to mono if needed
            if channels == 2:
                audio_array = audio_array.reshape(-1, 2)
                audio_array = np.mean(audio_array, axis=1).astype(np.int16)

            # Resample to 16kHz if needed
            if frame_rate != 16000:
                # Simple downsampling (for production, use proper resampling)
                ratio = frame_rate / 16000
                if ratio > 1:
                    # Downsample by taking every nth sample
                    indices = np.arange(0, len(audio_array), ratio).astype(int)
                    audio_array = audio_array[indices]
                else:
                    # Upsample by repeating samples (basic interpolation)
                    audio_array = np.repeat(audio_array, int(1 / ratio))

            # Create output WAV
            output = io.BytesIO()
            with wave.open(output, 'wb') as output_wav:
                output_wav.setnchannels(1)      # Mono
                output_wav.setsampwidth(2)      # 16-bit
                output_wav.setframerate(16000)  # 16kHz
                output_wav.writeframes(audio_array.tobytes())

            return output.getvalue()

    except Exception as e:
        logger.error(f"Audio conversion failed: {str(e)}")
        raise Exception(f"Failed to convert audio: {str(e)}")

def get_audio_duration(audio_data: bytes) -> float:
    """
    Get duration of audio in seconds.

    Args:
        audio_data: WAV audio bytes

    Returns:
        Duration in seconds
    """
    try:
        with wave.open(io.BytesIO(audio_data), 'rb') as wav_file:
            frames = wav_file.getnframes()
            frame_rate = wav_file.getframerate()
            duration = frames / frame_rate
            return duration
    except Exception as e:
        logger.error(f"Failed to get audio duration: {str(e)}")
        return 0.0

def audio_to_numpy(audio_data: bytes) -> Tuple[np.ndarray, int]:
    """
    Convert WAV audio bytes to numpy array.

    Args:
        audio_data: WAV audio bytes

    Returns:
        Tuple of (audio_array, sample_rate)

    Raises:
        Exception: If conversion fails
    """
    try:
        with wave.open(io.BytesIO(audio_data), 'rb') as wav_file:
            frames = wav_file.readframes(wav_file.getnframes())
            sample_rate = wav_file.getframerate()
            channels = wav_file.getnchannels()
            sample_width = wav_file.getsampwidth()

            if sample_width == 2:
                audio_array = np.frombuffer(frames, dtype=np.int16)
            else:
                raise Exception(f"Unsupported sample width: {sample_width}")

            # Convert to float32 and normalize
            audio_array = audio_array.astype(np.float32) / 32767.0

            # Handle stereo
            if channels == 2:
                audio_array = audio_array.reshape(-1, 2)
                audio_array = np.mean(audio_array, axis=1)

            return audio_array, sample_rate

    except Exception as e:
        logger.error(f"Failed to convert audio to numpy: {str(e)}")
        raise Exception(f"Audio conversion failed: {str(e)}")

def create_test_audio(digit: str, duration: float = 1.0, sample_rate: int = 16000) -> bytes:
    """
    Create test audio data for development purposes.

    Args:
        digit: Digit to simulate ('0'-'9')
        duration: Audio duration in seconds
        sample_rate: Sample rate in Hz

    Returns:
        WAV audio bytes
    """
    try:
        # Create simple tone pattern based on digit
        t = np.linspace(0, duration, int(sample_rate * duration), False)

        # Different frequency patterns for each digit
        freq_map = {
            '0': [400, 600],   # Low frequencies
            '1': [800, 1000],  # Higher frequencies
            '2': [600, 800],
            '3': [700, 900],
            '4': [500, 700],
            '5': [900, 1100],
            '6': [450, 650],
            '7': [750, 950],
            '8': [550, 750],
            '9': [850, 1050]
        }

        freqs = freq_map.get(digit, [440, 880])

        # Generate tone
        signal = np.sin(freqs[0] * 2.0 * np.pi * t) * 0.3 + np.sin(freqs[1] * 2.0 * np.pi * t) * 0.3

        # Add some envelope
        envelope = np.exp(-3 * t)
        signal = signal * envelope

        # Convert to int16
        signal = (signal * 32767).astype(np.int16)

        # Create WAV
        output = io.BytesIO()
        with wave.open(output, 'wb') as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(signal.tobytes())

        return output.getvalue()

    except Exception as e:
        logger.error(f"Failed to create test audio: {str(e)}")
        raise Exception(f"Test audio creation failed: {str(e)}")
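These helpers compose into a quick round trip for local testing: synthesize a tone, validate it, then convert it for the 8kHz ML pipelines. A minimal sketch, assuming the module is importable as utils.audio_utils:

from utils.audio_utils import (create_test_audio, validate_audio_format,
                               convert_for_ml_models, get_audio_duration)

wav_bytes = create_test_audio('7', duration=1.0)   # synthetic 16kHz test tone
assert validate_audio_format(wav_bytes)            # passes the RIFF/WAV checks

ml_bytes = convert_for_ml_models(wav_bytes, pipeline_type='mfcc')  # 8kHz mono PCM via ffmpeg when available
print(get_audio_duration(ml_bytes))                # ~1.0 seconds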
utils/enhanced_vad.py
ADDED
@@ -0,0 +1,571 @@
"""
Enhanced VAD Implementation with ffmpeg support and comprehensive debugging
"""

import numpy as np
import logging
import subprocess
import tempfile
import os
import time
import wave
import io
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from threading import Thread, Lock
import asyncio
import concurrent.futures

# Try to import WebRTC VAD
try:
    import webrtcvad
    WEBRTC_AVAILABLE = True
except ImportError:
    WEBRTC_AVAILABLE = False
    logging.warning("webrtcvad not available - using fallback VAD implementation")

logger = logging.getLogger(__name__)

class EnhancedVAD:
    """
    Enhanced Voice Activity Detection with ffmpeg integration and comprehensive debugging.

    Features:
    - ffmpeg-based audio preprocessing
    - Multiple VAD implementations (WebRTC, simple energy-based)
    - Comprehensive audio validation and debugging
    - Async audio chunk saving
    - Real-time performance monitoring
    """

    def __init__(self,
                 sample_rate: int = 16000,
                 frame_duration_ms: int = 30,
                 aggressiveness: int = 1,
                 min_speech_duration: float = 0.4,
                 max_speech_duration: float = 3.0,
                 silence_threshold: float = 0.01):
        """
        Initialize Enhanced VAD.

        Args:
            sample_rate: Target sample rate (Hz)
            frame_duration_ms: Frame duration in milliseconds
            aggressiveness: VAD aggressiveness (0-3)
            min_speech_duration: Minimum speech segment duration (seconds)
            max_speech_duration: Maximum speech segment duration (seconds)
            silence_threshold: Energy threshold for silence detection
        """
        self.sample_rate = sample_rate
        self.frame_duration_ms = frame_duration_ms
        self.frame_size = int(sample_rate * frame_duration_ms / 1000)
        self.aggressiveness = aggressiveness
        self.min_speech_duration = min_speech_duration
        self.max_speech_duration = max_speech_duration
        self.silence_threshold = silence_threshold

        # Initialize WebRTC VAD if available
        self.webrtc_vad = None
        if WEBRTC_AVAILABLE:
            try:
                self.webrtc_vad = webrtcvad.Vad(aggressiveness)
                logger.info(f"WebRTC VAD initialized (aggressiveness: {aggressiveness})")
            except Exception as e:
                logger.error(f"Failed to initialize WebRTC VAD: {e}")
                self.webrtc_vad = None

        # Check ffmpeg availability
        self.ffmpeg_available = self._check_ffmpeg_available()

        # Performance tracking
        self.stats = {
            'total_chunks_processed': 0,
            'speech_segments_detected': 0,
            'processing_time_total': 0.0,
            'last_processing_time': 0.0,
            'ffmpeg_conversions': 0,
            'audio_validation_failures': 0,
            'webrtc_available': WEBRTC_AVAILABLE and self.webrtc_vad is not None,
            'ffmpeg_available': self.ffmpeg_available
        }

        # Async processing
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
        self.save_lock = Lock()

        logger.info(f"Enhanced VAD initialized:")
        logger.info(f"  Sample rate: {sample_rate} Hz")
        logger.info(f"  Frame duration: {frame_duration_ms} ms")
        logger.info(f"  WebRTC VAD: {'Available' if self.webrtc_vad else 'Not available'}")
        logger.info(f"  ffmpeg: {'Available' if self.ffmpeg_available else 'Not available'}")

    def _check_ffmpeg_available(self) -> bool:
        """Check if ffmpeg is available."""
        try:
            result = subprocess.run(['ffmpeg', '-version'],
                                    capture_output=True, text=True, timeout=5)
            return result.returncode == 0
        except Exception:
            return False

    def preprocess_audio_with_ffmpeg(self, audio_data: bytes) -> Optional[bytes]:
        """
        Preprocess audio using ffmpeg for optimal VAD performance.

        Args:
            audio_data: Raw audio bytes

        Returns:
            Preprocessed audio bytes or None if processing fails
        """
        if not self.ffmpeg_available:
            logger.debug("ffmpeg not available for audio preprocessing")
            return None

        temp_input = None
        temp_output = None

        try:
            # Create temporary files
            with tempfile.NamedTemporaryFile(suffix='.input', delete=False) as temp_input:
                temp_input.write(audio_data)
                temp_input.flush()

            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_output:
                pass

            # ffmpeg command for VAD-optimized preprocessing
            ffmpeg_cmd = [
                'ffmpeg',
                '-i', temp_input.name,
                '-ar', str(self.sample_rate),  # Resample to target rate
                '-ac', '1',                    # Convert to mono
                '-acodec', 'pcm_s16le',        # 16-bit PCM
                '-af', 'highpass=f=80,lowpass=f=8000,dynaudnorm=f=10:g=3',  # Audio filters for speech
                '-f', 'wav',
                '-loglevel', 'error',
                '-y',
                temp_output.name
            ]

            result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True, timeout=10)

            if result.returncode == 0:
                with open(temp_output.name, 'rb') as f:
                    preprocessed_audio = f.read()

                self.stats['ffmpeg_conversions'] += 1
                logger.debug(f"ffmpeg preprocessing: {len(audio_data)} -> {len(preprocessed_audio)} bytes")
                return preprocessed_audio
            else:
                logger.error(f"ffmpeg preprocessing failed: {result.stderr}")
                return None

        except Exception as e:
            logger.error(f"ffmpeg preprocessing error: {e}")
            return None

        finally:
            # Cleanup
            try:
                if temp_input and os.path.exists(temp_input.name):
                    os.unlink(temp_input.name)
                if temp_output and os.path.exists(temp_output.name):
                    os.unlink(temp_output.name)
            except Exception:
                pass

    def validate_and_debug_audio(self, audio_data: bytes) -> Dict[str, Any]:
        """
        Comprehensive audio validation and debugging.

        Args:
            audio_data: Audio data to validate

        Returns:
            Validation results and debugging information
        """
        debug_info = {
            'size_bytes': len(audio_data),
            'valid_wav': False,
            'sample_rate': None,
            'channels': None,
            'duration': 0.0,
            'energy_level': 0.0,
            'is_silent': True,
            'format_detected': 'unknown',
            'issues': []
        }

        try:
            # Check minimum size
            if len(audio_data) < 44:
                debug_info['issues'].append(f"Too small: {len(audio_data)} bytes (need ≥44 for WAV)")
                return debug_info

            # Detect format by header
            if audio_data.startswith(b'RIFF') and b'WAVE' in audio_data[:20]:
                debug_info['format_detected'] = 'wav'
            elif audio_data.startswith(b'OggS'):
                debug_info['format_detected'] = 'ogg'
            elif audio_data.startswith(b'\x1a\x45\xdf\xa3'):
                debug_info['format_detected'] = 'webm'

            # Try to parse as WAV
            try:
                with wave.open(io.BytesIO(audio_data), 'rb') as wav:
                    debug_info['valid_wav'] = True
                    debug_info['sample_rate'] = wav.getframerate()
                    debug_info['channels'] = wav.getnchannels()
                    debug_info['duration'] = wav.getnframes() / wav.getframerate()

                    # Read audio samples for analysis
                    wav.rewind()
                    frames = wav.readframes(wav.getnframes())

                    if len(frames) > 0:
                        # Convert to numpy for analysis
                        audio_array = np.frombuffer(frames, dtype=np.int16)

                        # Calculate energy level
                        energy = np.sqrt(np.mean(audio_array.astype(np.float32) ** 2))
                        debug_info['energy_level'] = float(energy)
                        debug_info['is_silent'] = energy < (self.silence_threshold * 32768)

                        # Check for constant beep (common issue)
                        if len(audio_array) > 100:
                            # Check if audio is a constant tone (beep)
                            diff = np.diff(audio_array)
                            if np.std(diff) < 100:  # Very low variation
                                debug_info['issues'].append("Constant tone/beep detected")

                            # Check dynamic range
                            if np.max(audio_array) - np.min(audio_array) < 1000:
                                debug_info['issues'].append("Very low dynamic range")

            except Exception as wav_error:
                debug_info['issues'].append(f"WAV parsing failed: {wav_error}")

            # Additional format-specific checks
            if debug_info['format_detected'] in ['ogg', 'webm'] and not debug_info['valid_wav']:
                debug_info['issues'].append("Non-WAV format detected - requires conversion")

            logger.debug(f"Audio validation: {debug_info}")

            if debug_info['issues']:
                self.stats['audio_validation_failures'] += 1
                logger.warning(f"Audio validation issues: {debug_info['issues']}")

            return debug_info

        except Exception as e:
            debug_info['issues'].append(f"Validation error: {str(e)}")
            logger.error(f"Audio validation failed: {e}")
            return debug_info

    def detect_speech_segments(self, audio_data: bytes) -> List[Tuple[bytes, Dict[str, Any]]]:
        """
        Detect speech segments using multiple methods.

        Args:
            audio_data: Input audio data

        Returns:
            List of (segment_audio, segment_info) tuples
        """
        start_time = time.time()

        # Validate and debug audio
        debug_info = self.validate_and_debug_audio(audio_data)

        segments = []

        try:
            # Preprocess with ffmpeg if available
            processed_audio = self.preprocess_audio_with_ffmpeg(audio_data)
            if processed_audio:
                working_audio = processed_audio
                logger.debug("Using ffmpeg-preprocessed audio for VAD")
            else:
                working_audio = audio_data
                logger.debug("Using original audio for VAD")

            # Re-validate processed audio
            if processed_audio:
                processed_debug = self.validate_and_debug_audio(processed_audio)
                logger.debug(f"Processed audio validation: {processed_debug}")

            # Method 1: WebRTC VAD (if available)
            if self.webrtc_vad and debug_info['valid_wav']:
                webrtc_segments = self._webrtc_vad_detection(working_audio)
                segments.extend(webrtc_segments)
                logger.debug(f"WebRTC VAD found {len(webrtc_segments)} segments")

            # Method 2: Energy-based VAD (fallback)
            if not segments or debug_info['issues']:
                energy_segments = self._energy_based_vad(working_audio)
                segments.extend(energy_segments)
                logger.debug(f"Energy VAD found {len(energy_segments)} segments")

            # Method 3: Simple duration-based segmentation (last resort)
            if not segments and len(audio_data) > 8000:  # > 8KB
                fallback_segment = self._create_fallback_segment(working_audio)
                if fallback_segment:
                    segments.append(fallback_segment)
                    logger.debug("Used fallback segmentation")

            processing_time = time.time() - start_time
            self.stats['total_chunks_processed'] += 1
            self.stats['speech_segments_detected'] += len(segments)
            self.stats['processing_time_total'] += processing_time
            self.stats['last_processing_time'] = processing_time

            logger.debug(f"VAD processing complete: {len(segments)} segments in {processing_time:.3f}s")

            return segments

        except Exception as e:
            logger.error(f"Speech segment detection failed: {e}")
            return []

    def _webrtc_vad_detection(self, audio_data: bytes) -> List[Tuple[bytes, Dict[str, Any]]]:
        """WebRTC-based speech detection."""
        segments = []

        try:
            frame_size_bytes = self.frame_size * 2  # 16-bit = 2 bytes per sample
            frames = []

            # Extract frames
            for i in range(0, len(audio_data) - frame_size_bytes + 1, frame_size_bytes):
                frame = audio_data[i:i + frame_size_bytes]
                if len(frame) == frame_size_bytes:
                    frames.append(frame)

            if len(frames) < 5:  # Need minimum frames
                return segments

            # VAD processing
            speech_frames = []
            for frame in frames:
                try:
                    is_speech = self.webrtc_vad.is_speech(frame, self.sample_rate)
                    speech_frames.append((frame, is_speech))
                except Exception as e:
                    logger.debug(f"WebRTC VAD frame processing failed: {e}")
                    speech_frames.append((frame, False))

            # Group consecutive speech frames
            current_segment = []
            for frame, is_speech in speech_frames:
                if is_speech:
                    current_segment.append(frame)
                else:
                    if len(current_segment) > 0:
                        # End of speech segment
                        segment_audio = b''.join(current_segment)
                        segment_duration = len(current_segment) * self.frame_duration_ms / 1000

                        if segment_duration >= self.min_speech_duration:
                            segments.append((segment_audio, {
                                'duration': segment_duration,
                                'method': 'webrtc_vad',
                                'frames': len(current_segment)
                            }))

                        current_segment = []

            # Handle final segment
            if current_segment:
                segment_audio = b''.join(current_segment)
                segment_duration = len(current_segment) * self.frame_duration_ms / 1000

                if segment_duration >= self.min_speech_duration:
                    segments.append((segment_audio, {
                        'duration': segment_duration,
                        'method': 'webrtc_vad',
                        'frames': len(current_segment)
                    }))

            return segments

        except Exception as e:
            logger.error(f"WebRTC VAD detection failed: {e}")
            return []

    def _energy_based_vad(self, audio_data: bytes) -> List[Tuple[bytes, Dict[str, Any]]]:
        """Energy-based speech detection."""
        segments = []
+
|
| 400 |
+
try:
|
| 401 |
+
# Try to parse as WAV or raw PCM
|
| 402 |
+
try:
|
| 403 |
+
with wave.open(io.BytesIO(audio_data), 'rb') as wav:
|
| 404 |
+
frames = wav.readframes(wav.getnframes())
|
| 405 |
+
sample_rate = wav.getframerate()
|
| 406 |
+
except:
|
| 407 |
+
# Assume raw 16-bit PCM
|
| 408 |
+
frames = audio_data
|
| 409 |
+
sample_rate = self.sample_rate
|
| 410 |
+
|
| 411 |
+
if len(frames) < 1000: # Too short
|
| 412 |
+
return segments
|
| 413 |
+
|
| 414 |
+
# Convert to numpy array
|
| 415 |
+
audio_samples = np.frombuffer(frames, dtype=np.int16)
|
| 416 |
+
audio_float = audio_samples.astype(np.float32) / 32768.0
|
| 417 |
+
|
| 418 |
+
# Calculate energy in overlapping windows
|
| 419 |
+
window_size = int(sample_rate * 0.1) # 100ms windows
|
| 420 |
+
hop_size = window_size // 2
|
| 421 |
+
|
| 422 |
+
energies = []
|
| 423 |
+
for i in range(0, len(audio_float) - window_size, hop_size):
|
| 424 |
+
window = audio_float[i:i + window_size]
|
| 425 |
+
energy = np.sqrt(np.mean(window ** 2))
|
| 426 |
+
energies.append(energy)
|
| 427 |
+
|
| 428 |
+
if len(energies) < 3:
|
| 429 |
+
return segments
|
| 430 |
+
|
| 431 |
+
# Adaptive threshold
|
| 432 |
+
mean_energy = np.mean(energies)
|
| 433 |
+
threshold = max(self.silence_threshold, mean_energy * 0.3)
|
| 434 |
+
|
| 435 |
+
# Find speech segments
|
| 436 |
+
if isinstance(energies, (list, np.ndarray)):
|
| 437 |
+
energies = np.array(energies) # Ensure it's a numpy array
|
| 438 |
+
speech_windows = energies > threshold
|
| 439 |
+
|
| 440 |
+
# Group consecutive speech windows
|
| 441 |
+
speech_start = None
|
| 442 |
+
for i, is_speech in enumerate(speech_windows):
|
| 443 |
+
if is_speech and speech_start is None:
|
| 444 |
+
speech_start = i
|
| 445 |
+
elif not is_speech and speech_start is not None:
|
| 446 |
+
# End of speech
|
| 447 |
+
start_sample = speech_start * hop_size
|
| 448 |
+
end_sample = min(i * hop_size + window_size, len(audio_samples))
|
| 449 |
+
|
| 450 |
+
segment_samples = audio_samples[start_sample:end_sample]
|
| 451 |
+
segment_duration = len(segment_samples) / sample_rate
|
| 452 |
+
|
| 453 |
+
if segment_duration >= self.min_speech_duration:
|
| 454 |
+
# Convert back to bytes
|
| 455 |
+
segment_audio = segment_samples.tobytes()
|
| 456 |
+
|
| 457 |
+
segments.append((segment_audio, {
|
| 458 |
+
'duration': segment_duration,
|
| 459 |
+
'method': 'energy_based',
|
| 460 |
+
'start_time': start_sample / sample_rate,
|
| 461 |
+
'energy_threshold': threshold,
|
| 462 |
+
'mean_energy': mean_energy
|
| 463 |
+
}))
|
| 464 |
+
|
| 465 |
+
speech_start = None
|
| 466 |
+
|
| 467 |
+
return segments
|
| 468 |
+
|
| 469 |
+
except Exception as e:
|
| 470 |
+
logger.error(f"Energy-based VAD failed: {e}")
|
| 471 |
+
return []
|
| 472 |
+
|
| 473 |
+
def _create_fallback_segment(self, audio_data: bytes) -> Optional[Tuple[bytes, Dict[str, Any]]]:
|
| 474 |
+
"""Create a fallback segment when VAD methods fail."""
|
| 475 |
+
try:
|
| 476 |
+
# Use the entire audio as a segment if it's reasonable length
|
| 477 |
+
debug_info = self.validate_and_debug_audio(audio_data)
|
| 478 |
+
|
| 479 |
+
if debug_info['duration'] > 0:
|
| 480 |
+
duration = debug_info['duration']
|
| 481 |
+
else:
|
| 482 |
+
# Estimate duration based on size (assume 16-bit, mono, 16kHz)
|
| 483 |
+
estimated_samples = len(audio_data) // 2
|
| 484 |
+
duration = estimated_samples / self.sample_rate
|
| 485 |
+
|
| 486 |
+
if self.min_speech_duration <= duration <= self.max_speech_duration:
|
| 487 |
+
return (audio_data, {
|
| 488 |
+
'duration': duration,
|
| 489 |
+
'method': 'fallback',
|
| 490 |
+
'estimated': True,
|
| 491 |
+
'issues': debug_info['issues']
|
| 492 |
+
})
|
| 493 |
+
|
| 494 |
+
return None
|
| 495 |
+
|
| 496 |
+
except Exception as e:
|
| 497 |
+
logger.error(f"Fallback segment creation failed: {e}")
|
| 498 |
+
return None
|
| 499 |
+
|
| 500 |
+
async def save_audio_chunk_async(self, audio_data: bytes, session_id: str,
|
| 501 |
+
chunk_type: str = "vad_chunk") -> Optional[str]:
|
| 502 |
+
"""
|
| 503 |
+
Asynchronously save audio chunk to file.
|
| 504 |
+
|
| 505 |
+
Args:
|
| 506 |
+
audio_data: Audio data to save
|
| 507 |
+
session_id: Session identifier
|
| 508 |
+
chunk_type: Type of chunk (for filename)
|
| 509 |
+
|
| 510 |
+
Returns:
|
| 511 |
+
Path to saved file or None if failed
|
| 512 |
+
"""
|
| 513 |
+
def _save_chunk():
|
| 514 |
+
try:
|
| 515 |
+
with self.save_lock:
|
| 516 |
+
timestamp = int(time.time() * 1000)
|
| 517 |
+
filename = f"{chunk_type}_{session_id}_{timestamp}.wav"
|
| 518 |
+
filepath = Path("output") / filename
|
| 519 |
+
|
| 520 |
+
# Ensure output directory exists
|
| 521 |
+
filepath.parent.mkdir(exist_ok=True)
|
| 522 |
+
|
| 523 |
+
# Save as WAV file
|
| 524 |
+
with open(filepath, 'wb') as f:
|
| 525 |
+
f.write(audio_data)
|
| 526 |
+
|
| 527 |
+
logger.debug(f"Saved audio chunk: {filepath}")
|
| 528 |
+
return str(filepath)
|
| 529 |
+
|
| 530 |
+
except Exception as e:
|
| 531 |
+
logger.error(f"Failed to save audio chunk: {e}")
|
| 532 |
+
return None
|
| 533 |
+
|
| 534 |
+
# Run in executor to avoid blocking
|
| 535 |
+
loop = asyncio.get_event_loop()
|
| 536 |
+
result = await loop.run_in_executor(self.executor, _save_chunk)
|
| 537 |
+
return result
|
| 538 |
+
|
| 539 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 540 |
+
"""Get comprehensive VAD statistics."""
|
| 541 |
+
stats = self.stats.copy()
|
| 542 |
+
|
| 543 |
+
if stats['total_chunks_processed'] > 0:
|
| 544 |
+
stats['average_processing_time'] = stats['processing_time_total'] / stats['total_chunks_processed']
|
| 545 |
+
stats['segments_per_chunk'] = stats['speech_segments_detected'] / stats['total_chunks_processed']
|
| 546 |
+
else:
|
| 547 |
+
stats['average_processing_time'] = 0.0
|
| 548 |
+
stats['segments_per_chunk'] = 0.0
|
| 549 |
+
|
| 550 |
+
return stats
|
| 551 |
+
|
| 552 |
+
def cleanup(self):
|
| 553 |
+
"""Clean up resources."""
|
| 554 |
+
if hasattr(self, 'executor'):
|
| 555 |
+
self.executor.shutdown(wait=True)
|
| 556 |
+
logger.info("Enhanced VAD cleaned up")
|
| 557 |
+
|
| 558 |
+
# Convenience function for creating enhanced VAD
|
| 559 |
+
def create_enhanced_vad(config: Optional[Dict[str, Any]] = None) -> EnhancedVAD:
|
| 560 |
+
"""Create enhanced VAD with optional configuration."""
|
| 561 |
+
if config is None:
|
| 562 |
+
config = {}
|
| 563 |
+
|
| 564 |
+
return EnhancedVAD(
|
| 565 |
+
sample_rate=config.get('sample_rate', 16000),
|
| 566 |
+
frame_duration_ms=config.get('frame_duration_ms', 30),
|
| 567 |
+
aggressiveness=config.get('aggressiveness', 1),
|
| 568 |
+
min_speech_duration=config.get('min_speech_duration', 0.4),
|
| 569 |
+
max_speech_duration=config.get('max_speech_duration', 3.0),
|
| 570 |
+
silence_threshold=config.get('silence_threshold', 0.01)
|
| 571 |
+
)
|
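
A minimal usage sketch for the VAD above (illustrative only, not part of the commit): it assumes a 16 kHz mono WAV file named sample.wav and that the repo root is on PYTHONPATH.

# Hypothetical driver script for EnhancedVAD; sample.wav is a placeholder
from utils.enhanced_vad import create_enhanced_vad

vad = create_enhanced_vad({'aggressiveness': 2, 'min_speech_duration': 0.3})
with open('sample.wav', 'rb') as f:
    audio_bytes = f.read()

for segment_audio, info in vad.detect_speech_segments(audio_bytes):
    print(f"{info['method']}: {info['duration']:.2f}s, {len(segment_audio)} bytes")

print(vad.get_stats())
vad.cleanup()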
utils/logging_utils.py
ADDED
@@ -0,0 +1,201 @@
import logging
import time
from typing import Dict, Any
from collections import defaultdict, deque
import json

class PerformanceLogger:
    """
    Performance logger for tracking audio processing metrics.
    Provides detailed logging and statistics for each processing method.
    """

    def __init__(self, max_history: int = 100):
        self.max_history = max_history
        self.method_stats = defaultdict(lambda: {
            'predictions': deque(maxlen=max_history),
            'inference_times': deque(maxlen=max_history),
            'errors': deque(maxlen=max_history),
            'total_calls': 0,
            'total_errors': 0
        })

        # Setup structured logging
        self.setup_logging()

    def setup_logging(self):
        """Setup structured logging with proper formatting."""
        # Create custom formatter
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

        # Setup console handler
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)

        # Setup file handler
        file_handler = logging.FileHandler('audio_digit_classifier.log')
        file_handler.setFormatter(formatter)

        # Configure root logger
        logging.basicConfig(
            level=logging.DEBUG,
            handlers=[console_handler, file_handler]
        )

        self.logger = logging.getLogger(__name__)

    def log_prediction(self, method: str, result: Dict[str, Any]):
        """
        Log a prediction result with performance metrics.

        Args:
            method: Processing method name
            result: Prediction result dictionary
        """
        stats = self.method_stats[method]
        stats['total_calls'] += 1

        if result.get('success', True):
            stats['predictions'].append({
                'digit': result.get('predicted_digit'),
                'timestamp': result.get('timestamp', time.time()),
                'inference_time': result.get('inference_time', 0)
            })
            stats['inference_times'].append(result.get('inference_time', 0))

            self.logger.info(json.dumps({
                'event': 'prediction',
                'method': method,
                'digit': result.get('predicted_digit'),
                'inference_time': result.get('inference_time'),
                'timestamp': result.get('timestamp')
            }))
        else:
            stats['total_errors'] += 1
            stats['errors'].append({
                'error': result.get('error'),
                'timestamp': result.get('timestamp', time.time()),
                'inference_time': result.get('inference_time', 0)
            })

            self.logger.error(json.dumps({
                'event': 'error',
                'method': method,
                'error': result.get('error'),
                'timestamp': result.get('timestamp')
            }))

    def get_method_stats(self, method: str) -> Dict[str, Any]:
        """
        Get performance statistics for a specific method.

        Args:
            method: Processing method name

        Returns:
            Dictionary with performance statistics
        """
        stats = self.method_stats[method]
        inference_times = list(stats['inference_times'])

        if not inference_times:
            return {
                'method': method,
                'total_calls': stats['total_calls'],
                'successful_predictions': 0,
                'error_rate': 0.0,
                'avg_inference_time': 0.0,
                'min_inference_time': 0.0,
                'max_inference_time': 0.0
            }

        successful_predictions = len(inference_times)
        error_rate = stats['total_errors'] / stats['total_calls'] if stats['total_calls'] > 0 else 0

        return {
            'method': method,
            'total_calls': stats['total_calls'],
            'successful_predictions': successful_predictions,
            'error_rate': round(error_rate * 100, 2),
            'avg_inference_time': round(sum(inference_times) / len(inference_times), 3),
            'min_inference_time': round(min(inference_times), 3),
            'max_inference_time': round(max(inference_times), 3),
            'recent_predictions': list(stats['predictions'])[-10:]  # Last 10 predictions
        }

    def get_all_stats(self) -> Dict[str, Any]:
        """Get statistics for all processing methods."""
        all_stats = {}
        for method in self.method_stats.keys():
            all_stats[method] = self.get_method_stats(method)

        return all_stats

    def get_comparison_report(self) -> str:
        """
        Generate a comparison report of all processing methods.

        Returns:
            Formatted string with method comparison
        """
        all_stats = self.get_all_stats()

        if not all_stats:
            return "No statistics available yet."

        report = "\n=== Audio Processing Method Comparison ===\n\n"

        for method, stats in all_stats.items():
            report += f"Method: {method}\n"
            report += f"  Total Calls: {stats['total_calls']}\n"
            report += f"  Successful: {stats['successful_predictions']}\n"
            report += f"  Error Rate: {stats['error_rate']}%\n"
            report += f"  Avg Time: {stats['avg_inference_time']}s\n"
            report += f"  Min/Max: {stats['min_inference_time']}s / {stats['max_inference_time']}s\n"
            report += "\n"

        # Find best performing methods
        if len(all_stats) > 1:
            best_speed = min(all_stats.items(), key=lambda x: x[1]['avg_inference_time'])
            best_accuracy = min(all_stats.items(), key=lambda x: x[1]['error_rate'])

            report += f"Fastest Method: {best_speed[0]} ({best_speed[1]['avg_inference_time']}s avg)\n"
            report += f"Most Accurate: {best_accuracy[0]} ({best_accuracy[1]['error_rate']}% error rate)\n"

        return report

    def log_system_info(self, info: Dict[str, Any]):
        """Log system information for debugging."""
        self.logger.info(json.dumps({
            'event': 'system_info',
            'timestamp': time.time(),
            **info
        }))

    def log_audio_info(self, duration: float, format_info: Dict[str, Any]):
        """Log audio input information."""
        self.logger.debug(json.dumps({
            'event': 'audio_input',
            'duration': duration,
            'format': format_info,
            'timestamp': time.time()
        }))

# Global performance logger instance
performance_logger = PerformanceLogger()

def setup_flask_logging(app):
    """Setup logging configuration for Flask application."""
    if not app.debug:
        # Production logging
        file_handler = logging.FileHandler('flask_app.log')
        file_handler.setFormatter(logging.Formatter(
            '%(asctime)s %(levelname)s %(name)s %(message)s'
        ))
        file_handler.setLevel(logging.INFO)
        app.logger.addHandler(file_handler)
        app.logger.setLevel(logging.INFO)

    app.logger.info('Audio Digit Classifier startup')
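
A quick sketch of how a processor is expected to feed this logger (illustrative only, not part of the commit; the result dicts below are hypothetical but mirror the keys log_prediction reads):

from utils.logging_utils import performance_logger

# Hypothetical prediction results for a method named 'ml_mfcc'
performance_logger.log_prediction('ml_mfcc', {
    'success': True, 'predicted_digit': 7,
    'inference_time': 0.042, 'timestamp': 1700000000.0
})
performance_logger.log_prediction('ml_mfcc', {
    'success': False, 'error': 'empty audio', 'timestamp': 1700000001.0
})

print(performance_logger.get_method_stats('ml_mfcc')['error_rate'])  # 50.0
print(performance_logger.get_comparison_report())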
utils/noise_utils.py
ADDED
@@ -0,0 +1,292 @@
import numpy as np
import wave
import io
import logging
from typing import Literal

logger = logging.getLogger(__name__)

# 'speech' is reserved for future use; the generators below do not implement it.
NoiseType = Literal['white', 'pink', 'brown', 'gaussian', 'background', 'speech']

class NoiseGenerator:
    """
    Audio noise generator for robustness testing.
    Supports various types of noise injection for testing digit recognition.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    def generate_white_noise(self, duration: float, sample_rate: int = 16000,
                             amplitude: float = 0.1) -> np.ndarray:
        """
        Generate a white noise signal.

        Args:
            duration: Duration in seconds
            sample_rate: Sample rate in Hz
            amplitude: Noise amplitude (0.0 to 1.0)

        Returns:
            Numpy array of white noise
        """
        samples = int(duration * sample_rate)
        noise = np.random.normal(0, amplitude, samples)
        return noise.astype(np.float32)

    def generate_pink_noise(self, duration: float, sample_rate: int = 16000,
                            amplitude: float = 0.1) -> np.ndarray:
        """
        Generate pink noise (1/f noise).

        Args:
            duration: Duration in seconds
            sample_rate: Sample rate in Hz
            amplitude: Noise amplitude

        Returns:
            Numpy array of pink noise
        """
        samples = int(duration * sample_rate)

        # Generate white noise
        white = np.random.randn(samples)

        # Build a 1/f amplitude filter in the frequency domain
        freqs = np.fft.fftfreq(samples, 1 / sample_rate)
        freqs[0] = 1  # Avoid division by zero at DC

        filter_response = 1.0 / np.sqrt(np.abs(freqs))
        filter_response[0] = 0

        # Apply the filter
        white_fft = np.fft.fft(white)
        pink_fft = white_fft * filter_response
        pink = np.real(np.fft.ifft(pink_fft))

        # Normalize and scale
        pink = pink / np.std(pink) * amplitude
        return pink.astype(np.float32)

    def generate_brown_noise(self, duration: float, sample_rate: int = 16000,
                             amplitude: float = 0.1) -> np.ndarray:
        """
        Generate brown noise (1/f^2 noise).

        Args:
            duration: Duration in seconds
            sample_rate: Sample rate in Hz
            amplitude: Noise amplitude

        Returns:
            Numpy array of brown noise
        """
        samples = int(duration * sample_rate)

        # Generate white noise and integrate it (cumulative sum)
        white = np.random.randn(samples)
        brown = np.cumsum(white)

        # Normalize and scale
        brown = brown / np.std(brown) * amplitude
        return brown.astype(np.float32)

    def generate_gaussian_noise(self, duration: float, sample_rate: int = 16000,
                                amplitude: float = 0.1) -> np.ndarray:
        """
        Generate Gaussian (normal distribution) noise.

        Args:
            duration: Duration in seconds
            sample_rate: Sample rate in Hz
            amplitude: Noise amplitude (standard deviation)

        Returns:
            Numpy array of Gaussian noise
        """
        samples = int(duration * sample_rate)
        noise = np.random.normal(0, amplitude, samples)
        return noise.astype(np.float32)

    def generate_background_noise(self, duration: float, sample_rate: int = 16000,
                                  amplitude: float = 0.05) -> np.ndarray:
        """
        Generate realistic background noise (a mixture of different noise types).

        Args:
            duration: Duration in seconds
            sample_rate: Sample rate in Hz
            amplitude: Noise amplitude

        Returns:
            Numpy array of background noise
        """
        # Mix different types of noise
        white = self.generate_white_noise(duration, sample_rate, amplitude * 0.3)
        pink = self.generate_pink_noise(duration, sample_rate, amplitude * 0.5)

        # Add some low-frequency rumble
        t = np.linspace(0, duration, int(sample_rate * duration), False)
        rumble = amplitude * 0.2 * np.sin(2 * np.pi * 60 * t)  # 60 Hz hum

        background = white + pink + rumble
        return background.astype(np.float32)

    def _generate_noise(self, noise_type: NoiseType, duration: float,
                        sample_rate: int, amplitude: float) -> np.ndarray:
        """Dispatch to the generator for the requested noise type."""
        generators = {
            'white': self.generate_white_noise,
            'pink': self.generate_pink_noise,
            'brown': self.generate_brown_noise,
            'gaussian': self.generate_gaussian_noise,
            'background': self.generate_background_noise,
        }
        if noise_type not in generators:
            raise Exception(f"Unsupported noise type: {noise_type}")
        return generators[noise_type](duration, sample_rate, amplitude)

    def _match_length(self, noise: np.ndarray, target_length: int) -> np.ndarray:
        """Truncate or zero-pad noise to exactly target_length samples."""
        if len(noise) > target_length:
            return noise[:target_length]
        if len(noise) < target_length:
            return np.pad(noise, (0, target_length - len(noise)))
        return noise

    def inject_noise(self, audio_data: bytes, noise_type: NoiseType,
                     noise_level: float = 0.1) -> bytes:
        """
        Inject noise into existing audio data.

        Args:
            audio_data: Original audio bytes (WAV format)
            noise_type: Type of noise to inject
            noise_level: Noise level relative to the signal (0.0 to 1.0)

        Returns:
            Audio bytes with noise injected

        Raises:
            Exception: If noise injection fails
        """
        try:
            # Decode the input WAV
            with wave.open(io.BytesIO(audio_data), 'rb') as wav_file:
                frames = wav_file.readframes(wav_file.getnframes())
                sample_rate = wav_file.getframerate()
                channels = wav_file.getnchannels()
                sample_width = wav_file.getsampwidth()

            if sample_width != 2:
                raise Exception(f"Unsupported sample width: {sample_width}")

            audio_array = np.frombuffer(frames, dtype=np.int16)

            # Convert to float in [-1, 1]
            audio_float = audio_array.astype(np.float32) / 32767.0

            if channels == 2:
                # Process each stereo channel separately
                audio_float = audio_float.reshape(-1, 2)
                for ch in range(2):
                    channel_data = audio_float[:, ch]
                    duration = len(channel_data) / sample_rate

                    # Generate noise and guard against rounding length mismatches
                    noise = self._generate_noise(noise_type, duration, sample_rate, noise_level)
                    noise = self._match_length(noise, len(channel_data))

                    # Add noise
                    audio_float[:, ch] = channel_data + noise

                # Flatten back to interleaved samples
                audio_float = audio_float.flatten()
            else:
                # Mono processing
                duration = len(audio_float) / sample_rate
                noise = self._generate_noise(noise_type, duration, sample_rate, noise_level)
                noise = self._match_length(noise, len(audio_float))

                # Add noise
                audio_float = audio_float + noise

            # Clip to prevent overflow
            audio_float = np.clip(audio_float, -1.0, 1.0)

            # Convert back to int16
            audio_int16 = (audio_float * 32767).astype(np.int16)

            # Encode the output WAV
            output = io.BytesIO()
            with wave.open(output, 'wb') as output_wav:
                output_wav.setnchannels(channels)
                output_wav.setsampwidth(sample_width)
                output_wav.setframerate(sample_rate)
                output_wav.writeframes(audio_int16.tobytes())

            self.logger.debug(f"Injected {noise_type} noise at level {noise_level}")
            return output.getvalue()

        except Exception as e:
            self.logger.error(f"Noise injection failed: {str(e)}")
            raise Exception(f"Failed to inject noise: {str(e)}")

    def create_pure_noise(self, noise_type: NoiseType, duration: float = 1.0,
                          sample_rate: int = 16000, amplitude: float = 0.3) -> bytes:
        """
        Create a pure noise audio file for testing.

        Args:
            noise_type: Type of noise to generate
            duration: Duration in seconds
            sample_rate: Sample rate in Hz
            amplitude: Noise amplitude

        Returns:
            WAV audio bytes containing pure noise
        """
        try:
            # Generate the noise signal
            noise = self._generate_noise(noise_type, duration, sample_rate, amplitude)

            # Convert to int16
            noise_int16 = (np.clip(noise, -1.0, 1.0) * 32767).astype(np.int16)

            # Encode as WAV
            output = io.BytesIO()
            with wave.open(output, 'wb') as wav_file:
                wav_file.setnchannels(1)  # Mono
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(noise_int16.tobytes())

            return output.getvalue()

        except Exception as e:
            self.logger.error(f"Pure noise generation failed: {str(e)}")
            raise Exception(f"Failed to create pure noise: {str(e)}")

# Global noise generator instance
noise_generator = NoiseGenerator()
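
An illustrative robustness-test snippet (not part of the commit; digit.wav is a placeholder for any 16-bit PCM WAV recording):

from utils.noise_utils import noise_generator

# digit.wav is a hypothetical clean recording of a spoken digit
with open('digit.wav', 'rb') as f:
    clean = f.read()

# Degrade the clean recording with pink noise at 5% amplitude
noisy = noise_generator.inject_noise(clean, 'pink', noise_level=0.05)
with open('digit_pink.wav', 'wb') as f:
    f.write(noisy)

# Pure noise clips serve as non-speech negative test inputs
noise_clip = noise_generator.create_pure_noise('background', duration=1.0)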
utils/session_manager.py
ADDED
@@ -0,0 +1,340 @@
"""
Session Management for Audio Chunk Storage
Handles session creation, audio chunk saving, and folder organization
"""

import time
import uuid
import logging
import wave
import shutil
from typing import Dict, Optional, List
from pathlib import Path
import json
import threading

logger = logging.getLogger(__name__)

class SessionManager:
    """
    Manages audio recording sessions with systematic file storage.
    Each session gets a unique ID and folder for organized chunk storage.
    """

    def __init__(self, base_output_dir: str = "output"):
        """
        Initialize session manager.

        Args:
            base_output_dir: Base directory for all session outputs
        """
        self.base_output_dir = Path(base_output_dir)
        self.base_output_dir.mkdir(exist_ok=True)

        # Active sessions tracking
        self.active_sessions: Dict[str, 'AudioSession'] = {}
        self.lock = threading.Lock()

        logger.info(f"Session manager initialized with output directory: {self.base_output_dir}")

    def create_session(self, session_id: Optional[str] = None) -> str:
        """
        Create a new audio recording session.

        Args:
            session_id: Optional custom session ID, otherwise auto-generated

        Returns:
            str: Session ID
        """
        if not session_id:
            # Generate session ID with timestamp and short UUID
            timestamp = int(time.time())
            short_uuid = str(uuid.uuid4())[:8]
            session_id = f"session{timestamp}_{short_uuid}"

        with self.lock:
            if session_id in self.active_sessions:
                logger.warning(f"Session {session_id} already exists, returning existing session")
                return session_id

            # Create session object
            session = AudioSession(session_id, self.base_output_dir)
            self.active_sessions[session_id] = session

            logger.info(f"Created new session: {session_id}")
            return session_id

    def get_session(self, session_id: str) -> Optional['AudioSession']:
        """Get an existing session by ID."""
        with self.lock:
            return self.active_sessions.get(session_id)

    def close_session(self, session_id: str) -> bool:
        """
        Close and finalize a session.

        Args:
            session_id: Session to close

        Returns:
            bool: True if the session was closed successfully
        """
        with self.lock:
            if session_id not in self.active_sessions:
                logger.warning(f"Session {session_id} not found")
                return False

            session = self.active_sessions[session_id]
            session.finalize()
            del self.active_sessions[session_id]

            logger.info(f"Closed session: {session_id} ({session.chunk_count} chunks saved)")
            return True

    def cleanup_old_sessions(self, max_age_hours: int = 24) -> int:
        """
        Clean up sessions older than the specified age.

        Args:
            max_age_hours: Maximum age in hours before cleanup

        Returns:
            int: Number of sessions cleaned up
        """
        cutoff_time = time.time() - (max_age_hours * 3600)
        cleaned_count = 0

        # Find old session folders
        for session_dir in self.base_output_dir.iterdir():
            if not session_dir.is_dir() or not session_dir.name.startswith('session'):
                continue

            try:
                # Check if the session has a metadata file with a creation time
                metadata_file = session_dir / "session_info.json"
                if metadata_file.exists():
                    with open(metadata_file, 'r') as f:
                        metadata = json.load(f)
                    if metadata.get('created_at', 0) < cutoff_time:
                        shutil.rmtree(session_dir)
                        cleaned_count += 1
                        logger.info(f"Cleaned up old session: {session_dir.name}")
                else:
                    # Fall back to the directory modification time
                    if session_dir.stat().st_mtime < cutoff_time:
                        shutil.rmtree(session_dir)
                        cleaned_count += 1
                        logger.info(f"Cleaned up old session: {session_dir.name}")

            except Exception as e:
                logger.error(f"Error cleaning up session {session_dir.name}: {e}")

        if cleaned_count > 0:
            logger.info(f"Cleaned up {cleaned_count} old sessions")

        return cleaned_count

    def get_session_stats(self) -> Dict:
        """Get statistics about all sessions."""
        with self.lock:
            stats = {
                'active_sessions': len(self.active_sessions),
                'total_chunks_active': sum(s.chunk_count for s in self.active_sessions.values()),
                'session_details': {
                    sid: {
                        'chunk_count': session.chunk_count,
                        'created_at': session.created_at,
                        'folder_path': str(session.session_dir)
                    }
                    for sid, session in self.active_sessions.items()
                }
            }

        # Count total session folders
        total_session_dirs = len([
            d for d in self.base_output_dir.iterdir()
            if d.is_dir() and d.name.startswith('session')
        ])
        stats['total_session_folders'] = total_session_dirs

        return stats


class AudioSession:
    """
    Represents a single audio recording session with systematic chunk storage.
    """

    def __init__(self, session_id: str, base_output_dir: Path):
        """
        Initialize audio session.

        Args:
            session_id: Unique session identifier
            base_output_dir: Base directory for output
        """
        self.session_id = session_id
        self.created_at = time.time()
        self.chunk_count = 0

        # Create session directory
        self.session_dir = base_output_dir / session_id
        self.session_dir.mkdir(exist_ok=True)

        # Create subdirectories
        self.chunks_dir = self.session_dir / "chunks"
        self.chunks_dir.mkdir(exist_ok=True)

        # Session metadata
        self.metadata = {
            'session_id': session_id,
            'created_at': self.created_at,
            'created_at_human': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.created_at)),
            'chunk_count': 0,
            'chunks': []
        }

        self._save_metadata()
        logger.info(f"Session folder created: {self.session_dir}")

    def save_audio_chunk(self, audio_data: bytes, prediction_result: Optional[Dict] = None,
                         chunk_type: str = "speech") -> str:
        """
        Save an audio chunk to the session folder.

        Args:
            audio_data: Raw audio bytes (WAV format preferred)
            prediction_result: Optional prediction results to save alongside
            chunk_type: Type of chunk ("speech", "vad_segment", "raw", etc.)

        Returns:
            str: Path to saved chunk file
        """
        self.chunk_count += 1

        # Generate chunk filename
        chunk_filename = f"{self.chunk_count:03d}.wav"
        chunk_path = self.chunks_dir / chunk_filename

        try:
            # Save audio data
            if self._is_wav_format(audio_data):
                # Already WAV format, save directly
                with open(chunk_path, 'wb') as f:
                    f.write(audio_data)
                logger.debug(f"Saved WAV chunk: {chunk_path}")
            else:
                # Convert raw PCM to WAV
                self._save_pcm_as_wav(audio_data, chunk_path)
                logger.debug(f"Converted and saved PCM chunk: {chunk_path}")

            # Update metadata
            chunk_info = {
                'chunk_id': self.chunk_count,
                'filename': chunk_filename,
                'chunk_type': chunk_type,
                'size_bytes': len(audio_data),
                'saved_at': time.time(),
                'saved_at_human': time.strftime('%Y-%m-%d %H:%M:%S'),
                'audio_format': 'wav' if self._is_wav_format(audio_data) else 'pcm_converted'
            }

            # Add prediction results if provided
            if prediction_result:
                chunk_info['prediction'] = prediction_result

            self.metadata['chunks'].append(chunk_info)
            self.metadata['chunk_count'] = self.chunk_count
            self._save_metadata()

            logger.info(f"Saved audio chunk {self.chunk_count}: {chunk_path}")
            return str(chunk_path)

        except Exception as e:
            logger.error(f"Failed to save audio chunk {self.chunk_count}: {e}")
            # Roll back the chunk count on failure
            self.chunk_count -= 1
            raise

    def _is_wav_format(self, audio_data: bytes) -> bool:
        """Check if audio data is in WAV format."""
        return audio_data.startswith(b'RIFF') and b'WAVE' in audio_data[:12]

    def _save_pcm_as_wav(self, pcm_data: bytes, output_path: Path,
                         sample_rate: int = 16000, channels: int = 1, sample_width: int = 2):
        """
        Convert raw PCM data to WAV format and save.

        Args:
            pcm_data: Raw PCM bytes
            output_path: Output WAV file path
            sample_rate: Sample rate (default 16kHz for speech)
            channels: Number of channels (default mono)
            sample_width: Sample width in bytes (default 16-bit)
        """
        try:
            with wave.open(str(output_path), 'wb') as wav_file:
                wav_file.setnchannels(channels)
                wav_file.setsampwidth(sample_width)
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(pcm_data)

        except Exception as e:
            logger.error(f"PCM to WAV conversion failed: {e}")
            # Fallback: save as raw PCM with a .pcm extension
            raw_path = output_path.with_suffix('.pcm')
            with open(raw_path, 'wb') as f:
                f.write(pcm_data)
            logger.warning(f"Saved as raw PCM instead: {raw_path}")

    def _save_metadata(self):
        """Save session metadata to JSON file."""
        try:
            metadata_path = self.session_dir / "session_info.json"
            with open(metadata_path, 'w') as f:
                json.dump(self.metadata, f, indent=2, default=str)
        except Exception as e:
            logger.error(f"Failed to save session metadata: {e}")

    def finalize(self):
        """Finalize the session and save final metadata."""
        self.metadata['finalized_at'] = time.time()
        self.metadata['finalized_at_human'] = time.strftime('%Y-%m-%d %H:%M:%S')
        self.metadata['final_chunk_count'] = self.chunk_count
        self._save_metadata()

        logger.info(f"📋 Finalized session {self.session_id}: {self.chunk_count} chunks saved")

    def get_chunk_list(self) -> List[str]:
        """Get list of all chunk files in order."""
        chunk_files = []
        for i in range(1, self.chunk_count + 1):
            chunk_file = self.chunks_dir / f"{i:03d}.wav"
            if chunk_file.exists():
                chunk_files.append(str(chunk_file))
            else:
                # Check for a .pcm fallback
                pcm_file = self.chunks_dir / f"{i:03d}.pcm"
                if pcm_file.exists():
                    chunk_files.append(str(pcm_file))
        return chunk_files

    def get_session_summary(self) -> Dict:
        """Get comprehensive session summary."""
        return {
            'session_id': self.session_id,
            'created_at': self.created_at,
            'chunk_count': self.chunk_count,
            'session_dir': str(self.session_dir),
            'chunks_dir': str(self.chunks_dir),
            'chunk_files': self.get_chunk_list(),
            'metadata': self.metadata
        }


# Global session manager instance
session_manager = SessionManager()
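
An illustrative end-to-end session sketch (not part of the commit): one session per recording, with chunks numbered 001.wav, 002.wav, ... under output/<session_id>/chunks/.

import numpy as np
from utils.session_manager import session_manager

# One second of silence as raw 16-bit PCM; save_audio_chunk wraps it in WAV
pcm_bytes = np.zeros(16000, dtype=np.int16).tobytes()

sid = session_manager.create_session()
session = session_manager.get_session(sid)
path = session.save_audio_chunk(pcm_bytes, prediction_result={'predicted_digit': 3})
print(path)  # e.g. output/session<timestamp>_<uuid>/chunks/001.wav
session_manager.close_session(sid)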