Voice_backend / app /main.py
Mohansai2004's picture
Update app/main.py
259e786 verified
"""
Main application entry point.
"""
import os
import asyncio
import uvicorn
from fastapi import FastAPI, WebSocket, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field
from typing import List, Optional
# Redirect cache/config directories used by native dependencies (numba,
# matplotlib, fontconfig, spaCy and XDG-aware libraries) to writable paths
# under /tmp. This has to run BEFORE the ``app.*`` imports below, which
# presumably pull in libraries that read these variables when imported.
# NOTE: HOME is overwritten unconditionally, not just defaulted.
os.environ.update({
    'NUMBA_CACHE_DIR': '/tmp/numba_cache',
    'MPLCONFIGDIR': '/tmp/matplotlib',
    'FONTCONFIG_PATH': '/tmp/fontconfig',
    'HOME': '/tmp',
    'SPACY_HOME': '/tmp',
    'XDG_CACHE_HOME': '/tmp',
})
# These two only take effect when the deployment has not set them already;
# argostranslate packages default to a persisted directory under /app.
os.environ.setdefault('NUMBA_DISABLE_JIT', '0')
os.environ.setdefault('ARGOS_PACKAGES_DIR', '/app/.argos_packages')

from app.config import setup_logging, get_settings, logger
from app.server import get_websocket_server

# Application settings shared by every endpoint in this module.
settings = get_settings()
# Pydantic models for request validation
class TranslateRequest(BaseModel):
    """Body accepted by ``POST /translate``: one text and its language pair."""

    # The text to translate (required).
    text: str = Field(default=..., description="Text to translate")
    # Language code of ``text`` (required).
    source_language: str = Field(default=..., description="Source language code (e.g., 'en')")
    # Language code to translate into (required).
    target_language: str = Field(default=..., description="Target language code (e.g., 'es')")
class BatchTranslateRequest(BaseModel):
    """Body accepted by ``POST /translate/batch``: many texts, one language pair."""

    # Texts to translate; the Field constraint caps the list at 100 entries.
    texts: List[str] = Field(
        default=...,
        description="List of texts to translate",
        max_items=100,
    )
    # Language code shared by every entry in ``texts`` (required).
    source_language: str = Field(default=..., description="Source language code")
    # Language code every entry is translated into (required).
    target_language: str = Field(default=..., description="Target language code")
class TTSRequest(BaseModel):
    """Body accepted by ``POST /tts/generate``."""

    # Text to synthesize (required).
    text: str = Field(default=..., description="Text to synthesize")
    # Language code used to pick the TTS engine (required).
    language: str = Field(default=..., description="Language code (e.g., 'en', 'hi', 'es')")
    # Requested output format; defaults to "wav". Allowed values are checked
    # by the endpoint, not by this model.
    format: str = Field(default="wav", description="Audio format: 'wav' or 'mp3'")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log application startup and shutdown around the serving period.

    Code before ``yield`` runs once when the app starts and code after it
    runs once when the app stops. No models or resources are initialized or
    released yet; only log events are emitted.
    """
    startup_fields = {
        "environment": settings.environment,
        "host": settings.host,
        "port": settings.port,
    }
    logger.info("application_starting", **startup_fields)

    # Startup hook: model/resource initialization would be awaited here.
    yield

    # Shutdown hook: resource cleanup would go here.
    logger.info("application_shutting_down")
# Create FastAPI application
app = FastAPI(
    lifespan=lifespan,
    title="Voice-to-Voice Translator",
    version="1.0.0",
    description="Real-time voice translation system",
)

# CORS: allowed origins come from settings; every method and header is
# accepted and credentialed requests are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=settings.cors_origins_list,
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/")
async def root():
    """Return the service name, version and a static running status."""
    return dict(
        service="Voice-to-Voice Translator",
        version="1.0.0",
        status="running",
    )
@app.get("/health")
async def health_check():
    """Health-check endpoint reporting connection, room and user counts.

    The status field is always "healthy"; the counts come from the connection
    and room managers.
    """
    from app.server import get_connection_manager
    from app.rooms import get_room_manager

    connection_manager = get_connection_manager()
    room_manager = get_room_manager()

    return dict(
        status="healthy",
        connections=connection_manager.get_connection_count(),
        rooms=room_manager.get_room_count(),
        total_users=room_manager.get_total_users(),
    )
@app.get("/stats")
async def get_server_statistics():
    """Report server, connection, room and worker-pool statistics.

    Worker entries start from the configured pool sizes with zero busy and
    zero queued work. Live values replace them only when ``app.workers`` can
    be imported and a pool exposes ``get_busy_count`` / ``get_queue_size``.
    Any exception while collecting them is logged as a warning, and whatever
    was gathered up to that point is kept.
    """
    from app.server import get_connection_manager
    from app.rooms import get_room_manager

    connections = get_connection_manager()
    rooms = get_room_manager()

    # Defaults used when live worker-pool data is unavailable.
    translation_stats = dict(total=settings.translation_workers, busy=0, queue_size=0)
    tts_stats = dict(total=settings.tts_workers, busy=0, queue_size=0)

    # (pool method, stats key) pairs, queried in this order for every pool.
    probes = (("get_busy_count", "busy"), ("get_queue_size", "queue_size"))

    try:
        # Imported lazily so a broken worker module only degrades this endpoint.
        from app.workers import get_translation_pool, get_tts_pool

        pools_and_targets = (
            (get_translation_pool(), translation_stats),
            (get_tts_pool(), tts_stats),
        )
        for pool, target in pools_and_targets:
            for method_name, stat_key in probes:
                if hasattr(pool, method_name):
                    target[stat_key] = getattr(pool, method_name)()
    except Exception as e:
        logger.warning("worker_stats_unavailable", error=str(e))

    return {
        "server": {"version": "1.0.0", "environment": settings.environment},
        "connections": {
            "total": connections.get_connection_count(),
            "active": connections.get_connection_count(),
        },
        "rooms": {
            "total": rooms.get_room_count(),
            "active": rooms.get_room_count(),
        },
        "workers": {"translation": translation_stats, "tts": tts_stats},
    }
@app.get("/config")
async def get_system_configuration():
    """Describe the current system configuration.

    Combines static values read from ``settings`` with the language catalogues
    that ``ModelScanner`` reports for STT, translation and TTS. It also
    returns a flattened ``code -> display name`` map of every supported
    language.
    """
    from app.utils.model_scanner import ModelScanner

    scanner = ModelScanner()
    stt_languages = scanner.get_available_stt_languages()
    translation_languages = scanner.get_available_translation_languages()
    tts_languages = scanner.get_available_tts_languages()

    # Flatten the three catalogues into a single name map. The first
    # catalogue to mention a code wins (STT, then translation, then TTS).
    # Only the TTS catalogue skips its "multilingual" entry.
    supported_langs = {}
    catalogues = (
        (stt_languages, ()),
        (translation_languages, ()),
        (tts_languages, ("multilingual",)),
    )
    for catalogue, excluded_codes in catalogues:
        for code, info in catalogue.items():
            if code in excluded_codes or code in supported_langs:
                continue
            supported_langs[code] = info.get("name", code.upper())

    audio_section = {
        "sample_rate": settings.audio_sample_rate,
        "channels": settings.audio_channels,
        "chunk_size": settings.audio_chunk_size,
        "format": settings.audio_format,
    }
    limits_section = {
        "max_connections": settings.max_connections,
        "max_connections_per_ip": settings.max_connections_per_ip,
        "max_users_per_room": settings.max_users_per_room,
        "max_message_size": settings.max_message_size,
        "idle_timeout": settings.idle_timeout,
        "room_timeout": settings.room_timeout,
    }
    rate_limits_section = {
        "messages_per_second": settings.max_messages_per_second,
        "requests_per_minute": settings.max_requests_per_minute,
        "rate_limit_per_minute": settings.rate_limit_per_minute,
    }
    workers_section = {
        "translation_workers": settings.translation_workers,
        "tts_workers": settings.tts_workers,
        "worker_threads": settings.worker_threads,
        "queue_size": settings.queue_size,
    }
    models_section = {
        "stt_engine": settings.stt_engine,
        "translation_engine": settings.translation_engine,
        "tts_engine": settings.tts_engine,
        "vosk_model_base_path": settings.vosk_model_base_path,
        "argos_model_path": settings.argos_model_path,
        "coqui_model_path": settings.coqui_model_path,
        "available_models": {
            "stt": stt_languages,
            "translation": translation_languages,
            "tts": tts_languages,
        },
        "supported_languages": supported_langs,
    }
    features_section = {
        "authentication_enabled": settings.enable_auth,
        # Reported as always on; not read from settings.
        "rate_limiting_enabled": True,
        "metrics_enabled": settings.enable_metrics,
        "gpu_enabled": settings.enable_gpu,
    }
    websocket_section = {
        "ping_interval": settings.ws_ping_interval,
        "ping_timeout": settings.ws_ping_timeout,
    }
    environment_section = {
        "env": settings.environment,
        "debug": settings.debug,
        "host": settings.host,
        "port": settings.port,
        "cors_origins": settings.cors_origins,
    }

    return {
        "audio": audio_section,
        "limits": limits_section,
        "rate_limits": rate_limits_section,
        "workers": workers_section,
        "models": models_section,
        "features": features_section,
        "websocket": websocket_section,
        "environment": environment_section,
    }
@app.get("/languages/supported")
async def get_supported_languages():
    """List the languages that ``ModelScanner`` reports for each stage.

    The response has one entry per stage: speech-to-text, translation and
    text-to-speech.
    """
    from app.utils.model_scanner import ModelScanner

    model_scanner = ModelScanner()
    stt = model_scanner.get_available_stt_languages()
    translation = model_scanner.get_available_translation_languages()
    tts = model_scanner.get_available_tts_languages()
    return {"stt": stt, "translation": translation, "tts": tts}
@app.get("/languages/pairs")
async def get_translation_pairs():
    """Return the translation pairs reported by ``ModelScanner`` and their count."""
    from app.utils.model_scanner import ModelScanner

    available_pairs = ModelScanner().get_translation_pairs()
    return dict(pairs=available_pairs, count=len(available_pairs))
@app.post("/translate")
async def translate_text(request: TranslateRequest):
    """Translate text without audio processing.

    Request body::

        {
            "text": "Hello, how are you?",
            "source_language": "en",
            "target_language": "es"
        }

    Both language codes must appear among the translation languages that
    ``ModelScanner`` reports. The specific pair is deliberately not checked,
    so the translator can install missing packages on demand.

    Returns:
        dict with the original and translated text, both language codes, and
        ``processing_time_ms``. The timing covers only the translation call
        and is measured with a monotonic clock.

    Raises:
        HTTPException: 422 if either language code is unsupported; 500 if
            translation fails for any other reason.
    """
    from app.pipeline.translate import get_translator
    from app.utils.model_scanner import ModelScanner
    import time

    try:
        scanner = ModelScanner()
        available_languages = scanner.get_available_translation_languages()

        # Validate source first, then target, with the same error shape for both.
        for role, code in (
            ("Source", request.source_language),
            ("Target", request.target_language),
        ):
            if code not in available_languages:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail={
                        "error": f"{role} language '{code}' is not supported",
                        "available_languages": list(available_languages.keys()),
                        "message": "Check /languages/supported for available languages",
                    },
                )

        translator = get_translator()

        # perf_counter is monotonic: time.time() follows the wall clock and
        # can jump (NTP/manual adjustments), giving negative or bogus durations.
        start = time.perf_counter()
        translated_text = await translator.translate(
            request.text,
            request.source_language,
            request.target_language,
        )
        processing_time = (time.perf_counter() - start) * 1000

        return {
            "original_text": request.text,
            "translated_text": translated_text,
            "source_language": request.source_language,
            "target_language": request.target_language,
            "processing_time_ms": round(processing_time, 2),
        }
    except HTTPException:
        # Validation errors above are already client-facing; pass them through.
        raise
    except Exception as e:
        logger.error("translation_error", error=str(e), exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Translation failed: {str(e)}",
        ) from e
@app.post("/translate/batch")
async def translate_batch(request: BatchTranslateRequest):
    """Translate multiple texts in one request.

    Request body::

        {
            "texts": ["Hello", "How are you?", "Goodbye"],
            "source_language": "en",
            "target_language": "es"
        }

    Texts are translated sequentially, in request order, with one shared
    translator. As in ``/translate``, only the individual language codes are
    validated; the pair itself is left to the translator, which can install
    missing packages on demand.

    Returns:
        dict with a ``translations`` list of ``{"original", "translated"}``
        entries, both language codes, ``total`` (number of entries) and
        ``processing_time_ms``. The timing is measured with a monotonic clock
        over the whole loop.

    Raises:
        HTTPException: 422 if either language code is unsupported; 500 if any
            translation fails.
    """
    from app.pipeline.translate import get_translator
    from app.utils.model_scanner import ModelScanner
    import time

    try:
        scanner = ModelScanner()
        available_languages = scanner.get_available_translation_languages()
        # The old unused ``scanner.get_translation_pairs()`` call was removed:
        # pair availability is intentionally not checked, so the scan was
        # wasted work on every request.

        # Same validation and message wording as /translate.
        for role, code in (
            ("Source", request.source_language),
            ("Target", request.target_language),
        ):
            if code not in available_languages:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail={
                        "error": f"{role} language '{code}' is not supported",
                        "available_languages": list(available_languages.keys()),
                        "message": "Check /languages/supported for available languages",
                    },
                )

        translator = get_translator()

        # Monotonic clock so the reported duration cannot go negative.
        start = time.perf_counter()
        translations = []
        for text in request.texts:
            translated_text = await translator.translate(
                text,
                request.source_language,
                request.target_language,
            )
            translations.append({"original": text, "translated": translated_text})
        processing_time = (time.perf_counter() - start) * 1000

        return {
            "translations": translations,
            "source_language": request.source_language,
            "target_language": request.target_language,
            "total": len(translations),
            "processing_time_ms": round(processing_time, 2),
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error("batch_translation_error", error=str(e), exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Batch translation failed: {str(e)}",
        ) from e
@app.post("/tts/generate")
async def generate_tts_audio(request: TTSRequest):
    """Generate TTS audio for the provided text.

    Request body::

        {
            "text": "Hello, this is a test message",
            "language": "en",
            "format": "wav"
        }

    Returns:
        ``Response`` whose body is the bytes returned by the TTS engine. Its
        ``Content-Type`` is ``audio/wav`` or ``audio/mpeg`` depending on
        ``format``. It also carries an attachment ``Content-Disposition`` and
        an ``X-Processing-Time-Ms`` header (synthesis time only, measured
        with a monotonic clock).

    Raises:
        HTTPException: 400 for an unknown ``format``; 422 if the language has
            no available TTS model; 500 if synthesis fails.
    """
    from fastapi.responses import Response
    from app.pipeline.tts import get_tts_factory
    from app.utils.model_scanner import ModelScanner
    import time

    # Accepted output formats and the Content-Type sent for each.
    media_types = {"wav": "audio/wav", "mp3": "audio/mpeg"}

    try:
        if request.format not in media_types:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid format '{request.format}'. Supported formats: 'wav', 'mp3'",
            )

        scanner = ModelScanner()
        available_tts_languages = scanner.get_available_tts_languages()
        if request.language not in available_tts_languages:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail={
                    "error": f"Language '{request.language}' is not supported for TTS",
                    "available_languages": list(available_tts_languages.keys()),
                    "message": "Check /languages/supported for available TTS languages",
                },
            )

        tts_engine = get_tts_factory().get_engine(request.language)

        # Monotonic clock; time.time() can jump with wall-clock adjustments.
        start = time.perf_counter()
        audio_data = await tts_engine.synthesize_async(request.text)
        processing_time_ms = round((time.perf_counter() - start) * 1000, 2)

        logger.info(
            "tts_generated",
            language=request.language,
            text_length=len(request.text),
            audio_size=len(audio_data),
            processing_time_ms=processing_time_ms,
        )

        # NOTE(review): audio_data is returned unchanged for both formats.
        # Nothing here transcodes to MP3, so confirm that the engine's output
        # matches the requested format; otherwise "mp3" responses are
        # mislabeled.
        return Response(
            content=audio_data,
            media_type=media_types[request.format],
            headers={
                "Content-Disposition": f'attachment; filename="tts_output.{request.format}"',
                "X-Processing-Time-Ms": str(processing_time_ms),
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error("tts_generation_error", error=str(e), exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"TTS generation failed: {str(e)}",
        ) from e
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Pass each incoming WebSocket connection to ``get_websocket_server()``'s handler."""
    await get_websocket_server().handle_connection(websocket)
def main():
    """Configure logging, then launch uvicorn serving ``app.main:app``.

    Host, port, log level and WebSocket ping settings all come from
    ``settings``. Auto-reload is on when ``settings.is_development`` is true.
    """
    setup_logging()
    logger.info(
        "starting_server",
        host=settings.host,
        port=settings.port,
        debug=settings.debug,
    )

    server_options = dict(
        host=settings.host,
        port=settings.port,
        reload=settings.is_development,
        log_level=settings.log_level.lower(),
        ws_ping_interval=settings.ws_ping_interval,
        ws_ping_timeout=settings.ws_ping_timeout,
    )
    # The app is passed as an import string so that reload mode can re-import it.
    uvicorn.run("app.main:app", **server_options)


if __name__ == "__main__":
    main()