Spaces:
Sleeping
Sleeping
| """ | |
| VoiceForge - FastAPI Main Application | |
| Production-grade Speech-to-Text & Text-to-Speech API | |
| """ | |
| import logging | |
# WARN: PyTorch 2.6+ security workaround for Pyannote.
# Must run before any other torch import so the environment variable and the
# safe-globals allowlist are in place before any checkpoint is loaded.
import os

# NOTE(review): "0" appears to be the default for this variable (it only
# *forces* weights_only loading when truthy), so this line may be a no-op.
# The opt-out switch is presumably TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 - confirm
# against the PyTorch serialization notes before changing it, because that
# re-enables full (unsafe) unpickling of checkpoints.
os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"

import torch.serialization

try:
    # Allowlist ``dict`` for the weights-only unpickler.
    torch.serialization.add_safe_globals([dict])
except AttributeError:
    # torch builds without add_safe_globals(): there is nothing to allowlist.
    # Only this case is ignored, so Ctrl-C / SystemExit and genuine errors are
    # no longer silently swallowed as they were with a bare ``except``.
    pass
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from fastapi.openapi.utils import get_openapi | |
| from prometheus_fastapi_instrumentator import Instrumentator | |
| from .core.config import get_settings | |
| from .api.routes import ( | |
| stt_router, | |
| tts_router, | |
| health_router, | |
| transcripts_router, | |
| ws_router, | |
| translation_router, | |
| batch_router, | |
| analysis_router, | |
| audio_router, | |
| cloning_router, | |
| sign_router, | |
| auth_router, | |
| s2s_router, | |
| sign_bridge # Import the module | |
| ) | |
| from .models import Base, engine | |
# Logging setup: INFO threshold, each record stamped with time, logger name
# and level.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger and the settings object used throughout this file.
logger = logging.getLogger(__name__)
settings = get_settings()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan handler.

    Code before the ``yield`` runs on startup; code after it runs on shutdown.

    Startup steps, in order:
      1. Create the database tables declared on ``Base.metadata``.
      2. Pre-load the "distil-small.en" and "small" Whisper models.
      3. Fetch the TTS voice list once (logged as "cached").

    Steps 2 and 3 are best-effort: any exception is logged as a warning and
    startup continues.

    Args:
        app: The FastAPI application instance (not used inside the body).
    """
    # Startup
    logger.info(f"Starting {settings.app_name} v{settings.app_version}")

    # Create database tables
    logger.info("Creating database tables...")
    Base.metadata.create_all(bind=engine)

    # Pre-warm Whisper models for faster first request
    logger.info("Pre-warming AI models...")
    try:
        # Imported here rather than at module top, so an import failure is
        # also caught by this best-effort block.
        from .services.whisper_stt_service import get_whisper_model

        # Pre-load English Distil model (most common)
        get_whisper_model("distil-small.en")
        logger.info("✅ Distil-Whisper model loaded")

        # Pre-load multilingual model
        get_whisper_model("small")
        logger.info("✅ Whisper-small model loaded")
    except Exception as e:
        logger.warning(f"Model pre-warming failed: {e}")

    # Pre-cache TTS voice list
    try:
        from .services.tts_service import get_tts_service

        tts_service = get_tts_service()
        await tts_service.get_voices()
        logger.info("✅ TTS voice list cached")
    except Exception as e:
        logger.warning(f"Voice list caching failed: {e}")

    # Logged unconditionally, even if a warm-up step above only warned.
    logger.info("🚀 Startup complete - All models warmed up!")

    yield

    # Shutdown
    logger.info("Shutting down...")
    # TODO: Close database connections
    # TODO: Close Redis connections
    logger.info("Shutdown complete")
# Markdown description passed to the application object below.
_API_DESCRIPTION = """
## VoiceForge API
Production-grade Speech-to-Text and Text-to-Speech API.
### Features
- 🎤 **Speech-to-Text**: Transcribe audio files with word-level timestamps
- 🔊 **Text-to-Speech**: Synthesize speech with 300+ neural voices
- 🌍 **Multi-language**: Support for 10+ languages
- 🧠 **AI Analysis**: Sentiment, keywords, and summarization
- 🌐 **Translation**: Translate text/audio between 20+ languages
- ⚡ **Free & Fast**: Local Whisper + Edge TTS - no API costs
"""

# Application object: name and version come from settings, interactive docs
# are served at /docs and /redoc, and startup/shutdown is handled by lifespan.
app = FastAPI(
    lifespan=lifespan,
    title=settings.app_name,
    version=settings.app_version,
    description=_API_DESCRIPTION,
    docs_url="/docs",
    redoc_url="/redoc",
)
| from slowapi import _rate_limit_exceeded_handler | |
| from slowapi.errors import RateLimitExceeded | |
| from slowapi.middleware import SlowAPIMiddleware | |
| from .core.limiter import limiter | |
| from .core.security_headers import SecurityHeadersMiddleware | |
# Add Rate Limiting (default: 60 requests/min per IP)
# The limiter is stored on app.state, and RateLimitExceeded is mapped to
# slowapi's stock handler.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)

# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin requests. Replace "*" with an explicit
# origin list before production (values left unchanged here).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security headers.
# Starlette runs middleware in reverse order of registration: the one added
# most recently is outermost. Registering this AFTER CORS makes it wrap the
# CORS middleware, so responses that CORS produces itself (e.g. preflight
# replies) also pass through it. Registered before CORS, it sat inside CORS
# and never saw those responses.
app.add_middleware(SecurityHeadersMiddleware)

# Prometheus Metrics
Instrumentator().instrument(app).expose(app)
# Router registration.
# The health router is mounted without a prefix; every other router is mounted
# under the versioned prefix, in the order listed below.
_API_PREFIX = "/api/v1"

app.include_router(health_router)

_VERSIONED_ROUTERS = (
    auth_router,
    stt_router,
    tts_router,
    transcripts_router,
    ws_router,
    translation_router,
    batch_router,
    analysis_router,
    audio_router,
    cloning_router,
    sign_router,
    s2s_router,
    sign_bridge.router,  # sign_bridge is imported as a module, not a router
)
for _versioned_router in _VERSIONED_ROUTERS:
    app.include_router(_versioned_router, prefix=_API_PREFIX)
# Exception handlers
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """
    Global exception handler for unhandled errors.

    Registered on the app for ``Exception``. Logs the error with its traceback
    and answers with a generic HTTP 500 JSON body.

    Args:
        request: The request being processed when the error occurred (unused).
        exc: The exception that was raised.

    Returns:
        JSONResponse with status 500 and the keys "error", "message" and
        "detail". "detail" holds ``str(exc)`` only when ``settings.debug`` is
        truthy; otherwise it is ``None``, so internals are not exposed.
    """
    logger.exception(f"Unhandled error: {exc}")
    return JSONResponse(
        status_code=500,
        content={
            "error": "internal_server_error",
            "message": "An unexpected error occurred",
            "detail": str(exc) if settings.debug else None,
        },
    )
@app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError):
    """
    Handler for validation errors.

    Registered on the app for ``ValueError`` and turns it into an HTTP 400
    JSON response.

    Args:
        request: The request being processed when the error occurred (unused).
        exc: The ValueError that was raised; its text is sent to the client.

    Returns:
        JSONResponse with status 400 and the keys "error" and "message".
    """
    return JSONResponse(
        status_code=400,
        content={
            "error": "validation_error",
            "message": str(exc),
        },
    )
# Root endpoint
@app.get("/")
async def root():
    """
    API root - returns basic info.

    Served at ``GET /``.

    Returns:
        dict with the application name and version from settings, a fixed
        "running" status, and the paths of the docs and health endpoints.
    """
    return {
        "name": settings.app_name,
        "version": settings.app_version,
        "status": "running",
        "docs": "/docs",
        "health": "/health",
    }
# Custom OpenAPI schema
def custom_openapi():
    """
    Build the OpenAPI schema once and reuse it afterwards.

    If ``app.openapi_schema`` is already populated it is returned as-is.
    Otherwise the schema is generated from the app's routes, given an
    ``x-logo`` entry and a fixed list of tag descriptions, then stored on
    ``app.openapi_schema``.

    Returns:
        The schema stored on ``app.openapi_schema``.
    """
    if not app.openapi_schema:
        schema = get_openapi(
            title=settings.app_name,
            version=settings.app_version,
            description=app.description,
            routes=app.routes,
        )

        # Add custom logo
        schema["info"]["x-logo"] = {"url": "https://example.com/logo.png"}

        # Add tags with descriptions, as (name, description) pairs
        tag_descriptions = (
            ("Health", "Health check endpoints for monitoring"),
            (
                "Speech-to-Text",
                "Convert audio to text with timestamps and speaker detection",
            ),
            ("Text-to-Speech", "Convert text to natural-sounding speech"),
        )
        schema["tags"] = [
            {"name": tag_name, "description": tag_text}
            for tag_name, tag_text in tag_descriptions
        ]

        app.openapi_schema = schema

    return app.openapi_schema


# Replace the app's schema generator with the customised one above.
app.openapi = custom_openapi
if __name__ == "__main__":
    # Direct execution: start uvicorn on the host/port from settings.
    # Auto-reload follows the debug flag.
    import uvicorn

    server_options = {
        "host": settings.api_host,
        "port": settings.api_port,
        "reload": settings.debug,
    }
    uvicorn.run("app.main:app", **server_options)