""" FastAPI Backend for Hugging Face Spaces Provides REST API endpoints for audio processing + Text-to-Speech """ from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Request from fastapi.responses import JSONResponse, FileResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import soundfile as sf import tempfile import os from pathlib import Path import logging from typing import Optional import time from collections import defaultdict from datetime import datetime, timedelta import asyncio from huggingface_hub import hf_hub_download # Direct import (no 'backend.' prefix for HF Spaces) from inference_pipeline import EnhancementPipeline # Setup logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) BASE_DIR = Path(__file__).parent.resolve() # Security: Allowed file types ALLOWED_EXTENSIONS = {'.wav', '.mp3', '.m4a', '.ogg', '.flac', '.webm'} ALLOWED_MIMETYPES = { 'audio/wav', 'audio/wave', 'audio/x-wav', 'audio/mpeg', 'audio/mp3', 'audio/mp4', 'audio/m4a', 'audio/x-m4a', 'audio/ogg', 'audio/flac', 'audio/webm' } # Initialize FastAPI app app = FastAPI( title="ClearSpeech API", description="Speech Enhancement, Transcription & Text-to-Speech", version="2.1.0", docs_url="/docs", redoc_url="/redoc" ) # CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Global pipeline instance pipeline = None temp_files = {} # ============================================================================ # SECURITY: Rate Limiting & File Validation # ============================================================================ class SimpleRateLimiter: """Simple in-memory rate limiter for demo protection""" def __init__(self, max_requests: int = 100, window_minutes: int = 60): self.max_requests = max_requests self.window = timedelta(minutes=window_minutes) self.requests = defaultdict(list) self.lock = asyncio.Lock() async def check_rate_limit(self, client_ip: str) -> bool: async with self.lock: now = datetime.now() self.requests[client_ip] = [ ts for ts in self.requests[client_ip] if now - ts < self.window ] if len(self.requests[client_ip]) >= self.max_requests: return False self.requests[client_ip].append(now) return True async def cleanup(self): while True: await asyncio.sleep(3600) async with self.lock: now = datetime.now() for ip in list(self.requests.keys()): self.requests[ip] = [ts for ts in self.requests[ip] if now - ts < self.window] if not self.requests[ip]: del self.requests[ip] rate_limiter = SimpleRateLimiter(max_requests=20, window_minutes=60) def get_client_ip(request: Request) -> str: """Get client IP from request""" forwarded = request.headers.get("X-Forwarded-For") if forwarded: return forwarded.split(",")[0].strip() real_ip = request.headers.get("X-Real-IP") if real_ip: return real_ip return request.client.host if request.client else "unknown" def validate_audio_file(file: UploadFile) -> None: """Validate uploaded file is a safe audio file""" file_ext = Path(file.filename).suffix.lower() if file_ext not in ALLOWED_EXTENSIONS: raise HTTPException( status_code=400, detail=f"Invalid file type '{file_ext}'. Allowed: {', '.join(ALLOWED_EXTENSIONS)}" ) if file.content_type and file.content_type not in ALLOWED_MIMETYPES: raise HTTPException( status_code=400, detail=f"Invalid content type: {file.content_type}" ) if '..' in file.filename or '/' in file.filename or '\\' in file.filename: raise HTTPException(status_code=400, detail="Invalid filename") # Configuration class Config: # Hugging Face Hub Configuration HF_REPO_ID = os.getenv("HF_REPO_ID", "thecodeworm/clearspeech-unet") HF_CHECKPOINT_FILENAME = "best_model.pt" # Local paths CHECKPOINT_DIR = Path(tempfile.gettempdir()) / "clearspeech_models" CNN_CHECKPOINT = CHECKPOINT_DIR / HF_CHECKPOINT_FILENAME # Model configuration WHISPER_MODEL = os.getenv("WHISPER_MODEL", "small") # Can use 'base' with 16GB RAM! DEVICE = os.getenv("DEVICE", "cpu") USE_FP16 = False # Limits MAX_FILE_SIZE = int(os.getenv("MAX_FILE_SIZE", 50 * 1024 * 1024)) TEMP_DIR = Path(tempfile.gettempdir()) / "clearspeech" @classmethod def setup(cls): """Setup: Download checkpoint from Hugging Face Hub""" cls.TEMP_DIR.mkdir(parents=True, exist_ok=True) cls.CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True) # Download from HF Hub if not exists if not cls.CNN_CHECKPOINT.exists(): logger.info("="*70) logger.info("📥 Downloading model checkpoint from Hugging Face Hub") logger.info("="*70) logger.info(f"Repository: {cls.HF_REPO_ID}") logger.info(f"Filename: {cls.HF_CHECKPOINT_FILENAME}") try: downloaded_path = hf_hub_download( repo_id=cls.HF_REPO_ID, filename=cls.HF_CHECKPOINT_FILENAME, cache_dir=str(cls.CHECKPOINT_DIR.parent), local_dir=str(cls.CHECKPOINT_DIR), local_dir_use_symlinks=False ) cls.CNN_CHECKPOINT = Path(downloaded_path) logger.info(f"✅ Checkpoint downloaded successfully!") logger.info(f" Saved to: {cls.CNN_CHECKPOINT}") logger.info("="*70) except Exception as e: logger.error("="*70) logger.error("❌ Failed to download checkpoint") logger.error("="*70) logger.error(f"Error: {e}") logger.error(f"Please verify HF_REPO_ID: {cls.HF_REPO_ID}") raise else: logger.info(f"✅ Using cached checkpoint: {cls.CNN_CHECKPOINT}") # Response models class ProcessResponse(BaseModel): success: bool transcript: str duration: float language: str enhanced_audio_url: str tts_audio_url: Optional[str] = None segments: list = [] processing_time: float class EnhanceResponse(BaseModel): success: bool enhanced_audio_url: str duration: float processing_time: float class TranscribeResponse(BaseModel): success: bool transcript: str duration: float language: str segments: list = [] processing_time: float class TTSRequest(BaseModel): text: str language: str = "en" voice: str = "default" class HealthResponse(BaseModel): status: str models_loaded: bool cnn_checkpoint: str whisper_model: str device: str tts_available: bool @app.on_event("startup") async def startup_event(): """Load models on server startup""" global pipeline logger.info("🚀 Starting ClearSpeech API Server on Hugging Face Spaces...") try: Config.setup() if not Config.CNN_CHECKPOINT.exists(): raise FileNotFoundError(f"Checkpoint not found: {Config.CNN_CHECKPOINT}") pipeline = EnhancementPipeline( cnn_checkpoint_path=str(Config.CNN_CHECKPOINT), whisper_model_name=Config.WHISPER_MODEL, device=Config.DEVICE, use_fp16=Config.USE_FP16 ) logger.info("✅ Models loaded successfully!") logger.info(f"📍 CNN Checkpoint: {Config.CNN_CHECKPOINT}") logger.info(f"📍 Whisper Model: {Config.WHISPER_MODEL}") logger.info(f"📍 Device: {Config.DEVICE}") # Check TTS try: import gtts logger.info("✅ TTS (gtts) available") except ImportError: logger.warning("⚠️ TTS not available") logger.info("="*70) logger.info("Server ready! Visit /docs for API documentation") logger.info("="*70) # Start rate limiter cleanup asyncio.create_task(rate_limiter.cleanup()) except Exception as e: logger.error(f"❌ Failed to load models: {e}") raise @app.on_event("shutdown") async def shutdown_event(): """Cleanup on server shutdown""" logger.info("Shutting down server...") for filepath in temp_files.values(): try: if Path(filepath).exists(): os.remove(filepath) except Exception as e: logger.warning(f"Failed to cleanup {filepath}: {e}") temp_files.clear() # ============================================================================ # TTS FUNCTIONS # ============================================================================ def generate_tts_gtts(text: str, output_path: str, language: str = "en"): """Generate TTS using gTTS""" try: from gtts import gTTS tts = gTTS(text=text, lang=language, slow=False) tts.save(output_path) return True except Exception as e: logger.error(f"gTTS failed: {e}") return False def generate_tts(text: str, output_path: str, language: str = "en"): """Generate TTS""" return generate_tts_gtts(text, output_path, language) # ============================================================================ # API ENDPOINTS # ============================================================================ @app.get("/") async def root(): """Health check endpoint""" return { "status": "online", "message": "ClearSpeech API - Speech Enhancement, Transcription & TTS", "version": "2.1.0", "platform": "Hugging Face Spaces", "endpoints": { "docs": "/docs", "health": "/health", "process": "/process (POST)", "enhance": "/enhance (POST)", "transcribe": "/transcribe (POST)", "tts": "/tts (POST)", "download": "/download/{filename} (GET)" } } @app.get("/health", response_model=HealthResponse) async def health_check(): """Detailed health check""" tts_available = False try: import gtts tts_available = True except ImportError: pass return { "status": "healthy" if pipeline is not None else "unhealthy", "models_loaded": pipeline is not None, "cnn_checkpoint": str(Config.CNN_CHECKPOINT), "whisper_model": Config.WHISPER_MODEL, "device": Config.DEVICE, "tts_available": tts_available } @app.post("/process", response_model=ProcessResponse) async def process_audio( request: Request, file: UploadFile = File(...), language: Optional[str] = Form(default="en"), skip_enhancement: Optional[str] = Form(default="false"), generate_tts_param: Optional[str] = Form(default="false", alias="generate_tts") ): """Complete pipeline: enhance + transcribe + optional TTS""" # Rate limiting client_ip = get_client_ip(request) if not await rate_limiter.check_rate_limit(client_ip): raise HTTPException( status_code=429, detail="Rate limit exceeded. Max 20 requests per hour." ) if pipeline is None: raise HTTPException(status_code=503, detail="Models not loaded") # File validation validate_audio_file(file) # Convert string parameters to boolean skip_enhancement_bool = skip_enhancement.lower() in ['true', '1', 'yes'] generate_tts_bool = generate_tts_param.lower() in ['true', '1', 'yes'] start_time = time.time() try: contents = await file.read() if len(contents) > Config.MAX_FILE_SIZE: raise HTTPException( status_code=413, detail=f"File too large. Max: {Config.MAX_FILE_SIZE / 1024 / 1024}MB" ) logger.info(f"📥 Processing: {file.filename} ({len(contents)/1024:.1f} KB)") # Process audio result = pipeline.process( contents, language=language, skip_enhancement=skip_enhancement_bool ) # Save enhanced audio temp_filename = f"enhanced_{int(time.time())}_{file.filename}" if not temp_filename.endswith('.wav'): temp_filename = temp_filename.rsplit('.', 1)[0] + '.wav' temp_path = Config.TEMP_DIR / temp_filename sf.write(temp_path, result['enhanced_audio'], result['sample_rate']) temp_files[temp_filename] = str(temp_path) enhanced_audio_url = f"/download/{temp_filename}" # Generate TTS if requested tts_audio_url = None if generate_tts_bool and result['transcript']: tts_filename = f"tts_{int(time.time())}_{file.filename}" if not tts_filename.endswith('.wav'): tts_filename = tts_filename.rsplit('.', 1)[0] + '.wav' tts_path = Config.TEMP_DIR / tts_filename if generate_tts(result['transcript'], str(tts_path), language): temp_files[tts_filename] = str(tts_path) tts_audio_url = f"/download/{tts_filename}" logger.info(f"✅ Generated TTS") else: logger.warning(f"⚠️ TTS generation failed") processing_time = time.time() - start_time response = { "success": True, "transcript": result['transcript'], "duration": result['duration'], "language": result['language'], "enhanced_audio_url": enhanced_audio_url, "tts_audio_url": tts_audio_url, "segments": result.get('segments', []), "processing_time": round(processing_time, 2) } logger.info(f"✅ Processed in {processing_time:.2f}s") return JSONResponse(content=response) except HTTPException: raise except Exception as e: logger.error(f"❌ Error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}") @app.post("/enhance", response_model=EnhanceResponse) async def enhance_only( request: Request, file: UploadFile = File(...) ): """Enhancement only (no transcription)""" # Rate limiting client_ip = get_client_ip(request) if not await rate_limiter.check_rate_limit(client_ip): raise HTTPException(status_code=429, detail="Rate limit exceeded") if pipeline is None: raise HTTPException(status_code=503, detail="Models not loaded") # File validation validate_audio_file(file) start_time = time.time() try: contents = await file.read() # Load and enhance audio = pipeline.audio_processor.load_audio(contents) enhanced_audio = pipeline.enhance_audio(audio) # Save temp_filename = f"enhanced_{int(time.time())}_{file.filename}" if not temp_filename.endswith('.wav'): temp_filename = temp_filename.rsplit('.', 1)[0] + '.wav' temp_path = Config.TEMP_DIR / temp_filename sf.write(temp_path, enhanced_audio, pipeline.audio_processor.sample_rate) temp_files[temp_filename] = str(temp_path) duration = len(enhanced_audio) / pipeline.audio_processor.sample_rate processing_time = time.time() - start_time return { "success": True, "enhanced_audio_url": f"/download/{temp_filename}", "duration": duration, "processing_time": round(processing_time, 2) } except Exception as e: logger.error(f"❌ Enhancement error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.post("/transcribe", response_model=TranscribeResponse) async def transcribe_only( request: Request, file: UploadFile = File(...), language: Optional[str] = Form(default="en"), enhance_first: Optional[str] = Form(default="true") ): """Transcription with optional enhancement""" # Rate limiting client_ip = get_client_ip(request) if not await rate_limiter.check_rate_limit(client_ip): raise HTTPException(status_code=429, detail="Rate limit exceeded") if pipeline is None: raise HTTPException(status_code=503, detail="Models not loaded") # File validation validate_audio_file(file) enhance_bool = enhance_first.lower() in ['true', '1', 'yes'] start_time = time.time() try: contents = await file.read() # Load audio audio = pipeline.audio_processor.load_audio(contents) # Optionally enhance if enhance_bool: audio = pipeline.enhance_audio(audio) # Transcribe result = pipeline.transcribe_audio(audio, language) duration = len(audio) / pipeline.audio_processor.sample_rate processing_time = time.time() - start_time return { "success": True, "transcript": result['text'].strip(), "duration": duration, "language": result.get('language', language), "segments": result.get('segments', []), "processing_time": round(processing_time, 2) } except Exception as e: logger.error(f"❌ Transcription error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.post("/tts") async def text_to_speech(request: TTSRequest): """Convert text to speech""" if not request.text: raise HTTPException(status_code=400, detail="No text provided") try: temp_filename = f"tts_{int(time.time())}.wav" temp_path = Config.TEMP_DIR / temp_filename if not generate_tts(request.text, str(temp_path), request.language): raise HTTPException( status_code=500, detail="TTS failed. Install gtts." ) return FileResponse( temp_path, media_type="audio/wav", filename=temp_filename ) except HTTPException: raise except Exception as e: logger.error(f"❌ TTS error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.get("/download/{filename}") async def download_file(filename: str): """Download processed audio file""" if filename not in temp_files: raise HTTPException(status_code=404, detail="File not found or expired") file_path = Path(temp_files[filename]) if not file_path.exists(): raise HTTPException(status_code=404, detail="File not found") return FileResponse( file_path, media_type="audio/wav", filename=filename ) @app.delete("/cleanup/{filename}") async def cleanup_file(filename: str): """Manually cleanup a temporary file""" if filename not in temp_files: raise HTTPException(status_code=404, detail="File not found") try: file_path = Path(temp_files[filename]) if file_path.exists(): os.remove(file_path) del temp_files[filename] return {"success": True, "message": "File deleted"} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": import uvicorn # HF Spaces uses port 7860 uvicorn.run( app, host="0.0.0.0", port=7860, log_level="info" )