busy-module-audio / handler.py
EurekaPotato's picture
Upload folder using huggingface_hub
9d8ae5e verified
"""
Audio Feature Extraction — Hugging Face Inference Endpoint Handler
Extracts all 17 voice features from uploaded audio:
v1_snr, v2_noise_* (5), v3_speech_rate, v4/v5_pitch, v6/v7_energy,
v8/v9/v10_pause, v11/v12/v13_emotion
Derived from: src/audio_features.py, src/emotion_features.py
"""
import io
import numpy as np
import librosa
from scipy import signal as scipy_signal
from typing import Dict
import torch
import torch.nn as nn
from torchvision import models
import warnings
warnings.filterwarnings("ignore")
# ──────────────────────────────────────────────────────────────────────── #
# Imports from standardized modules
# ──────────────────────────────────────────────────────────────────────── #
# Bring in the project-local extractor. If the module is not importable as-is,
# add the current working directory to the search path and try once more; a
# second failure propagates and aborts start-up.
try:
    from audio_features import AudioFeatureExtractor
except ImportError:
    import sys

    # Fallback for launches where the working directory is not on sys.path
    sys.path.append(".")
    from audio_features import AudioFeatureExtractor
def _ensure_emotion_models(audio_extractor):
    """Best-effort check that the emotion sub-extractor has models loaded.

    Does nothing when emotion extraction is disabled or no emotion extractor
    is attached. Otherwise, if the emotion extractor reports zero models, a
    download is attempted and the emotion extractor's constructor is re-run
    in place so that it can pick the files up. Any failure is reported as a
    warning and swallowed so that start-up continues.
    """
    if not (audio_extractor.use_emotion and audio_extractor.emotion_extractor):
        return

    print("[INFO] Checking for emotion models...")
    try:
        if len(audio_extractor.emotion_extractor.models) == 0:
            print("[INFO] Models not found, attempting download...")
            audio_extractor.emotion_extractor.download_models()
            # Re-invoke __init__ on the existing object (rather than building
            # a new one) so the reference held by the extractor stays valid.
            audio_extractor.emotion_extractor.__init__(
                models_dir=audio_extractor.emotion_extractor.models_dir
            )
    except Exception as err:
        print(f"[WARN] Failed to download emotion models: {err}")


# One extractor instance is created at import time and shared by every
# request, so that models are loaded once per process.
print("[INFO] Initializing Global AudioFeatureExtractor...")
extractor = AudioFeatureExtractor(
    sample_rate=16000,
    use_emotion=True,
    emotion_models_dir="/app/models",  # absolute path inside the Docker container
)

# Make sure emotion models are downloaded/ready before serving traffic.
_ensure_emotion_models(extractor)
# ──────────────────────────────────────────────────────────────────────── #
# Helper to handle NaN/Inf for JSON
# ──────────────────────────────────────────────────────────────────────── #
def sanitize_features(features: Dict[str, float]) -> Dict[str, float]:
    """Return a copy of *features* whose values are safe to emit as JSON.

    Conversion rules, applied per value:

    * Python or NumPy booleans become a plain ``bool``. They are tested
      first because ``bool`` is a subclass of ``int`` (it would otherwise be
      turned into 0/1) and because NumPy booleans are not JSON-serializable.
    * Python or NumPy floats become a plain ``float``; NaN and +/-Inf, which
      have no JSON representation, are replaced by ``0.0``.
    * Python or NumPy integers become a plain ``int``.
    * Anything else (strings, ``None``, ...) is passed through untouched.

    Args:
        features: Mapping of feature name to raw value.

    Returns:
        A new dict with the same keys in the same order; the input mapping
        is not modified.
    """
    sanitized = {}
    for key, val in features.items():
        if isinstance(val, (bool, np.bool_)):
            sanitized[key] = bool(val)
        elif isinstance(val, (float, np.floating)):
            # isfinite is False for NaN as well as for both infinities
            if np.isfinite(val):
                sanitized[key] = float(val)
            else:
                sanitized[key] = 0.0
        elif isinstance(val, (int, np.integer)):
            sanitized[key] = int(val)
        else:
            sanitized[key] = val  # keep strings and other types as they are
    return sanitized
# ──────────────────────────────────────────────────────────────────────── #
# FastAPI handler for deployment (HF Spaces / Cloud Run / Lambda)
# ──────────────────────────────────────────────────────────────────────── #
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional
import base64
import traceback
# ASGI application object; the route handlers below register themselves on it.
app = FastAPI(
    title="Audio Feature Extraction API",
    version="1.0.0",
)

# Fully permissive CORS policy: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# combination the FastAPI CORS documentation cautions about — confirm that
# credentialed cross-origin requests are really needed here.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Turn any otherwise-unhandled exception into a 200 response with defaults.

    The request URL and the exception are printed together with a traceback.
    The response body is a copy of ``DEFAULT_AUDIO_FEATURES`` extended with
    ``_error`` (the exception text) and ``_handler`` (always ``"global"``),
    so callers receive a well-formed feature payload instead of an HTTP 500.
    """
    print(f"[GLOBAL ERROR] {request.url}: {exc}")
    traceback.print_exc()

    # Copy first so the shared defaults dict is never mutated
    payload = dict(DEFAULT_AUDIO_FEATURES)
    payload["_error"] = str(exc)
    payload["_handler"] = "global"
    return JSONResponse(status_code=200, content=payload)
# Extractor is already initialized globally above
# ──────────────────────────────────────────────────────────────────────── #
# Constants & Defaults
# ──────────────────────────────────────────────────────────────────────── #
# Neutral payload returned whenever features cannot be computed (empty input,
# decode failure, unexpected exception). Every feature is 0.0 except
# "v2_noise_clean", which is 1.0. Handlers copy this dict before adding keys,
# so it should be treated as read-only.
DEFAULT_AUDIO_FEATURES = {
    **dict.fromkeys(
        (
            "v1_snr",
            "v2_noise_traffic",
            "v2_noise_office",
            "v2_noise_crowd",
            "v2_noise_wind",
            "v2_noise_clean",
            "v3_speech_rate",
            "v4_pitch_mean",
            "v5_pitch_std",
            "v6_energy_mean",
            "v7_energy_std",
            "v8_pause_ratio",
            "v9_avg_pause_dur",
            "v10_mid_pause_cnt",
            "v11_emotion_stress",
            "v12_emotion_energy",
            "v13_emotion_valence",
        ),
        0.0,
    ),
    # Overrides the 0.0 above; the key keeps its original position.
    "v2_noise_clean": 1.0,
}
class AudioBase64Request(BaseModel):
    """JSON request body for the base64 extraction endpoint.

    Both fields default to the empty string, so a request that omits them
    still validates.
    """

    # Base64 text of the audio file; a "data:...;base64," prefix is tolerated
    # by the endpoint that consumes this model.
    audio_base64: str = ""

    # Optional transcript handed to the extractor alongside the audio.
    transcript: str = ""
@app.get("/")
async def root():
    """Return the service name, its version and the list of endpoints."""
    available_endpoints = [
        "/health",
        "/extract-audio-features",
        "/extract-audio-features-base64",
    ]
    return {
        "service": "Audio Feature Extraction API",
        "version": "1.0.0",
        "endpoints": available_endpoints,
    }
@app.get("/health")
async def health():
    """Report liveness plus whether the VAD and emotion components exist.

    ``vad_loaded`` is true when the global extractor has a non-None
    ``vad_model``. ``emotion_loaded`` is true only when emotion extraction is
    enabled and an emotion extractor object is attached; it does not check
    that the emotion extractor actually holds any models.
    """
    has_vad = extractor.vad_model is not None

    if extractor.use_emotion:
        has_emotion = extractor.emotion_extractor is not None
    else:
        has_emotion = False

    return {
        "status": "healthy",
        "vad_loaded": has_vad,
        "emotion_loaded": has_emotion,
    }
@app.post("/extract-audio-features")
async def extract_audio_features(audio: UploadFile = File(...), transcript: str = Form("")):
    """Extract all 17 voice features from an uploaded audio file.

    The upload is decoded with librosa to mono at 16 kHz and passed, together
    with the optional transcript, to the global extractor. The result is run
    through ``sanitize_features`` before being returned.

    Args:
        audio: Multipart file upload containing the audio.
        transcript: Optional transcript form field (defaults to "").

    Returns:
        The sanitized feature dict on success. A copy of
        ``DEFAULT_AUDIO_FEATURES`` when the upload is empty or decodes to
        fewer than 100 samples (the same rule the base64 endpoint applies).
        A copy of the defaults plus an ``_error`` key when decoding or
        extraction raises.
    """
    min_samples = 100  # same minimum length the base64 endpoint enforces

    try:
        audio_bytes = await audio.read()

        # Nothing to decode: answer with defaults instead of letting the
        # decoder fail on an empty buffer.
        if not audio_bytes:
            print("[INFO] Empty audio upload, returning defaults")
            return {**DEFAULT_AUDIO_FEATURES}

        # librosa.load returns (samples, sample_rate); the rate is fixed to
        # 16000 by the sr argument, so it is not needed afterwards.
        # NOTE(review): decoding and extraction run synchronously inside this
        # coroutine — confirm that is acceptable for the expected load.
        y, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)

        if len(y) < min_samples:
            print("[WARN] Audio too short after decode, returning defaults")
            return {**DEFAULT_AUDIO_FEATURES}

        # AudioFeatureExtractor.extract_all expects numpy array and optional transcript
        features = extractor.extract_all(y, transcript)
        return sanitize_features(features)
    except Exception as e:
        print(f"[ERROR] extract_audio_features: {e}")
        traceback.print_exc()
        # Return defaults rather than 500
        return {**DEFAULT_AUDIO_FEATURES, "_error": str(e)}
@app.post("/extract-audio-features-base64")
async def extract_audio_features_base64(data: AudioBase64Request):
    """Extract features from base64-encoded audio (for Vercel serverless calls).

    Steps: strip an optional data-URL prefix, base64-decode, decode the audio
    (soundfile first, librosa as fallback), down-mix to mono, convert to
    float32, resample to 16 kHz when needed, then run the global extractor.

    Returns:
        The sanitized feature dict on success. A copy of
        ``DEFAULT_AUDIO_FEATURES`` when the payload is missing or shorter than
        100 characters, or when the decoded signal has fewer than 100 samples.
        A copy of the defaults plus an ``_error`` key when any step raises.
    """
    import soundfile as sf

    encoded = data.audio_base64
    transcript = data.transcript

    # Missing or implausibly small payload: answer with defaults right away.
    if not encoded or len(encoded) < 100:
        print("[INFO] Empty or too-short audio_base64, returning defaults")
        return {**DEFAULT_AUDIO_FEATURES}

    try:
        # A comma within the first 80 characters is taken as the end of a
        # data-URL header (e.g. "data:audio/wav;base64,..."); keep what follows.
        if "," in encoded[:80]:
            encoded = encoded.partition(",")[2]

        raw_bytes = base64.b64decode(encoded)
        print(f"[INFO] Decoded {len(raw_bytes)} bytes of audio")

        # soundfile is tried first; librosa is the fallback decoder and already
        # yields mono audio at 16 kHz.
        try:
            samples, rate = sf.read(io.BytesIO(raw_bytes))
        except Exception as sf_err:
            print(f"[WARN] soundfile failed ({sf_err}), trying librosa...")
            samples, rate = librosa.load(io.BytesIO(raw_bytes), sr=16000, mono=True)

        # Multi-dimensional data is averaged over axis 1 to get a single channel.
        if hasattr(samples, "shape") and len(samples.shape) > 1:
            samples = np.mean(samples, axis=1)
        samples = np.asarray(samples, dtype=np.float32)

        if rate != 16000:
            samples = librosa.resample(samples, orig_sr=rate, target_sr=16000)
        samples = samples.astype(np.float32)

        if len(samples) < 100:
            print("[WARN] Audio too short after decode, returning defaults")
            return {**DEFAULT_AUDIO_FEATURES}

        features = extractor.extract_all(samples, transcript)
        print(f"[OK] Extracted {len(features)} audio features")
        return sanitize_features(features)
    except Exception as err:
        print(f"[ERROR] extract_audio_features_base64: {err}")
        traceback.print_exc()
        # Defaults plus the error text are returned instead of an HTTP 500.
        return {**DEFAULT_AUDIO_FEATURES, "_error": str(err)}
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces. The port is read
    # from the PORT environment variable and falls back to 7860.
    import os

    import uvicorn

    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)