Spaces:
Sleeping
Sleeping
| """ | |
| TTS Backend for Yorubs | |
| Uses facebook/mms-tts-yor model for Yoruba text-to-speech | |
| Security Features: | |
| - API key validation (X-API-Key header) | |
| - Rate limiting per IP (100 requests/day) | |
| """ | |
| import os | |
| import base64 | |
| import logging | |
| from datetime import datetime, timezone | |
| from typing import Optional | |
| from fastapi import FastAPI, HTTPException, Header, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from tts_service import TTSService | |
| from cache import TTSCache | |
# Configure root logging at INFO and create a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# =============================================================================
# CONFIGURATION
# =============================================================================
# Shared secret for the X-API-Key header, read from the TTS_API_KEY environment
# variable. When it is empty/unset, no X-API-Key value is ever accepted.
API_KEY = os.getenv("TTS_API_KEY", "")

# Rate limiting: maximum number of requests each client may make per day.
MAX_REQUESTS_PER_DAY = 100

# =============================================================================
# RATE LIMITING (In-Memory - resets on restart)
# =============================================================================
# Request counters maintained by check_rate_limit. They live only in process
# memory, so every count is lost when the process restarts.
rate_limit_cache: dict[str, dict] = dict()
def check_rate_limit(
    client_id: str,
    *,
    limit: Optional[int] = None,
    store: Optional[dict] = None,
) -> tuple[bool, int]:
    """
    Check whether ``client_id`` may make another request today, and count it.

    Counters are stored under the client id itself (one entry per client) and
    are reset the first time the client is seen on a new UTC day. This keeps
    the store from accumulating a fresh key per client per day.

    Args:
        client_id: Rate-limit bucket identifier for the caller.
        limit: Maximum requests per UTC day; defaults to MAX_REQUESTS_PER_DAY.
        store: Mapping that holds the counters; defaults to the module-level
            rate_limit_cache (overridable, e.g. for tests).

    Returns:
        (allowed, remaining_requests). When allowed, this request has already
        been counted and ``remaining_requests`` is what is left after it.
        When denied, returns (False, 0) and nothing is counted.
    """
    if limit is None:
        limit = MAX_REQUESTS_PER_DAY
    if store is None:
        store = rate_limit_cache

    today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    entry = store.get(client_id)
    if entry is None or entry.get("date") != today:
        # First request of the (UTC) day for this client: start a fresh counter,
        # replacing yesterday's entry instead of leaving it behind.
        entry = {"count": 0, "date": today}
        store[client_id] = entry

    if entry["count"] >= limit:
        return False, 0

    entry["count"] += 1
    return True, limit - entry["count"]
def cleanup_old_rate_limits(*, store: Optional[dict] = None) -> int:
    """
    Remove rate-limit entries whose recorded date is not today (UTC).

    Args:
        store: Mapping of counters to prune in place; defaults to the
            module-level rate_limit_cache.

    Returns:
        The number of entries removed.
    """
    if store is None:
        store = rate_limit_cache
    today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    # Collect keys first, then delete: a dict must not change size while iterated.
    stale = [key for key, entry in store.items() if entry.get("date") != today]
    for key in stale:
        del store[key]
    return len(stale)
# =============================================================================
# FASTAPI APP
# =============================================================================
app = FastAPI(
    title="Yorubs TTS API",
    version="4.0.0",
    description="Text-to-Speech API for Yoruba language using MMS-TTS-YOR",
)

# CORS: any origin, method and header may call this API from a browser.
# NOTE(review): allow_credentials=True together with a wildcard origin may let
# arbitrary sites make credentialed requests. This service authenticates via
# explicit headers (X-API-Key / Authorization), so credential support may be
# unnecessary; confirm with the frontend before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Initialize services: module-level singletons shared by the endpoint handlers.
tts = TTSService()
cache = TTSCache()
# =============================================================================
# MODELS
# =============================================================================
class TTSRequest(BaseModel):
    """JSON body accepted by the text-to-speech endpoint."""

    # Text to synthesize; the endpoint strips it and requires 1-500 characters.
    text: str
    # Speed multiplier; the endpoint clamps it to [0.5, 1.5], and a falsy
    # value (None or 0) is treated as 1.0.
    speed: Optional[float] = 1.0


class TTSResponse(BaseModel):
    """JSON body returned by the text-to-speech endpoint."""

    audio: str  # base64 encoded WAV
    cached: bool  # True when the audio was served from the TTS cache
    remaining_requests: Optional[int] = None  # caller's remaining daily quota
# =============================================================================
# ENDPOINTS
# =============================================================================
async def root():
    """Return a small status payload identifying the service and its version."""
    return dict(status="ok", service="Yorubs TTS API", version="4.0.0")


async def health():
    """Return a health payload naming the TTS model this service uses."""
    return dict(status="healthy", model="facebook/mms-tts-yor")
async def text_to_speech(
    request: TTSRequest,
    req: Request,
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
    authorization: Optional[str] = Header(None),
):
    """
    Generate speech from text.

    Authentication (first match wins):
    - X-API-Key header equal to TTS_API_KEY (only when TTS_API_KEY is set),
      compared in constant time.
    - Any ``Authorization: Bearer ...`` header, trusted as having been
      forwarded by the Supabase Edge Function after it validated the JWT.

    Rate limiting: MAX_REQUESTS_PER_DAY requests per (auth method, client IP)
    per UTC day. Quota is consumed before the cache lookup, so cache hits
    count too.

    Returns:
        TTSResponse with the base64-encoded audio, whether it was served from
        the cache, and the caller's remaining daily quota.

    Raises:
        HTTPException: 401 for missing/invalid authentication, 400 for empty
            or over-long (>500 chars) text, 429 when the daily quota is used
            up, 500 when synthesis fails.
    """
    import secrets  # function-local so this block needs no new top-level import

    # Authenticate via API key or Bearer token from the Edge Function.
    # compare_digest avoids leaking the key through response timing. Comparing
    # UTF-8 bytes means non-ASCII header values cannot raise TypeError.
    api_key_ok = (
        bool(API_KEY)
        and x_api_key is not None
        and secrets.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8"))
    )
    if api_key_ok:
        client_id = "api_key_client"
    elif authorization and authorization.startswith("Bearer "):
        # SECURITY NOTE(review): the bearer token is NOT verified here. Any
        # caller that can reach this service directly can bypass the API key
        # by sending an arbitrary "Bearer" header. This is only safe if the
        # service is reachable solely through the Edge Function; otherwise
        # verify the JWT here.
        client_id = "edge_function_client"
    else:
        raise HTTPException(status_code=401, detail="Invalid or missing authentication")

    # Use client IP for more granular rate limiting when possible.
    # NOTE(review): behind a proxy or the Edge Function this is presumably the
    # proxy's address, so many end users may share one bucket. Confirm.
    client_ip = req.client.host if req.client else client_id
    rate_key = f"{client_id}_{client_ip}"

    # Validate request
    text = request.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text is required")
    if len(text) > 500:
        raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")

    # Check rate limit
    allowed, remaining = check_rate_limit(rate_key)
    if not allowed:
        raise HTTPException(
            status_code=429,
            detail="Daily rate limit exceeded. Please try again tomorrow.",
        )

    logger.info("TTS request for text: %s... speed: %s", text[:50], request.speed)

    # Clamp speed into [0.5, 1.5]; a missing or zero speed means 1.0.
    speed = max(0.5, min(1.5, request.speed or 1.0))

    # Check cache first
    # v6 prefix forces re-generation after expanded carrier threshold (short phrases)
    cache_key = f"v6:{text}|speed={speed}" if speed != 1.0 else f"v6:{text}"

    # The cache is best-effort: a read failure is treated as a miss.
    try:
        cached_audio = await cache.get(cache_key)
    except Exception:
        logger.exception("TTS cache lookup failed; synthesizing instead")
        cached_audio = None
    if cached_audio:
        logger.info("Returning cached audio")
        return TTSResponse(audio=cached_audio, cached=True, remaining_requests=remaining)

    # Keep the try narrow: only synthesis/encoding failures become a 500.
    try:
        audio_bytes = await tts.synthesize(text, speed=speed)
        audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
    except Exception as e:
        logger.exception("TTS synthesis failed: %s", e)
        raise HTTPException(status_code=500, detail=f"TTS synthesis failed: {e}") from e

    # A cache write failure must not throw away audio that was synthesized successfully.
    try:
        await cache.set(cache_key, audio_b64)
    except Exception:
        logger.exception("Failed to cache TTS audio")

    logger.info("Generated audio: %d bytes", len(audio_bytes))
    return TTSResponse(audio=audio_b64, cached=False, remaining_requests=remaining)
# =============================================================================
# CACHE MANAGEMENT
# =============================================================================
class ClearCacheRequest(BaseModel):
    """JSON body accepted by the cache-clearing endpoint."""

    # Text whose cached audio should be removed. Omitting it (or sending null
    # or an empty string) clears every cached entry instead.
    text: Optional[str] = None
async def clear_cache(
    request: ClearCacheRequest,
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
):
    """
    Clear cached TTS audio. Requires a valid X-API-Key.

    - With ``text``: for the stripped text, deletes the "v5:"-prefixed,
      "v6:"-prefixed and unprefixed keys. Each is cleared at the default
      speed and at the speed variants 0.5, 0.7 and 1.5.
    - Without ``text`` (omitted, null or empty): clears ALL cached entries.

    NOTE(review): the TTS endpoint caches any clamped speed in [0.5, 1.5]
    (e.g. 0.8 or 1.2). Entries for speeds other than the three listed above
    are not removed by the targeted form.

    Returns:
        {"status": "ok", "cleared": <count>, "text": <text>} for a targeted
        clear, or {"status": "ok", "cleared": "all"} for a full clear.

    Raises:
        HTTPException: 401 when TTS_API_KEY is unset or the header is
            missing or wrong.
    """
    import secrets  # function-local so this block needs no new top-level import

    # Constant-time comparison on UTF-8 bytes (safe for non-ASCII header values).
    api_key_ok = (
        bool(API_KEY)
        and x_api_key is not None
        and secrets.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8"))
    )
    if not api_key_ok:
        raise HTTPException(status_code=401, detail="API key required for cache management")

    if not request.text:
        await cache.clear()
        logger.info("Cleared ALL cache entries")
        return {"status": "ok", "cleared": "all"}

    # Clear specific text entries (all known prefix versions and speed variants).
    text = request.text.strip()
    cleared = 0
    for prefix in ("v5:", "v6:", ""):
        base_key = f"{prefix}{text}"
        keys = [base_key] + [f"{base_key}|speed={speed}" for speed in (0.5, 0.7, 1.5)]
        for key in keys:
            # get() first so only entries that actually existed are counted.
            if await cache.get(key):
                await cache.delete(key)
                cleared += 1

    logger.info("Cleared %d cache entries for: %s", cleared, text[:50])
    return {"status": "ok", "cleared": cleared, "text": text}
# =============================================================================
# STARTUP
# =============================================================================
async def startup_event():
    """Prune stale rate-limit entries, then log the startup banner."""
    # NOTE(review): rate_limit_cache is created empty at import time, so this
    # prune presumably has nothing to remove when it runs at startup.
    cleanup_old_rate_limits()
    startup_banner = "TTS API v4.0.0 started — carrier threshold: 3 words / 15 chars"
    logger.info(startup_banner)
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")