# yoruba-tts / main.py
# Uploaded by Yurikks — commit c7d5077 ("Upload 7 files")
"""
TTS Backend for Yorubs
Uses facebook/mms-tts-yor model for Yoruba text-to-speech
Security Features:
- API key validation (X-API-Key header)
- Rate limiting per IP (100 requests/day)
"""
import base64
import hmac
import logging
import os
from datetime import datetime, timezone
from typing import Optional

from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from tts_service import TTSService
from cache import TTSCache
# Logging: INFO and above, one module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# CONFIGURATION
# =============================================================================
# Shared secret expected in the X-API-Key header. When unset (empty string),
# API-key authentication is effectively disabled.
API_KEY = os.getenv("TTS_API_KEY", "")
# Daily request quota per rate-limit key.
MAX_REQUESTS_PER_DAY = 100
# =============================================================================
# RATE LIMITING (In-Memory - resets on restart)
# =============================================================================
# Request counters; each value is a dict of the form {"count": int, "date": "YYYY-MM-DD"}.
rate_limit_cache: dict[str, dict] = {}
def check_rate_limit(client_id: str) -> tuple[bool, int]:
    """
    Count one request against ``client_id``'s daily quota.

    Each client has a single entry holding the UTC date and the number of
    requests counted on that date. The counter restarts as soon as the stored
    date differs from today. Keying by client alone (rather than client+date)
    keeps the dict from gaining a fresh entry per client every day.

    Args:
        client_id: Opaque key identifying the caller (e.g. an auth/IP composite).

    Returns:
        ``(allowed, remaining_requests)``:
        - When the client has already used MAX_REQUESTS_PER_DAY requests today,
          returns ``(False, 0)`` and does not count the request.
        - Otherwise counts the request and returns ``True`` together with the
          quota left after it.
    """
    today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    entry = rate_limit_cache.get(client_id)
    # New client, or the stored counter belongs to an earlier day: start over.
    if entry is None or entry.get("date") != today:
        entry = {"count": 0, "date": today}
        rate_limit_cache[client_id] = entry
    if entry["count"] >= MAX_REQUESTS_PER_DAY:
        return False, 0
    entry["count"] += 1
    return True, MAX_REQUESTS_PER_DAY - entry["count"]
def cleanup_old_rate_limits():
    """Delete every rate-limit entry whose recorded date is not today's UTC date.

    Entries with no "date" field are treated as stale and removed as well.
    The dict is mutated in place.
    """
    current_day = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    # Iterate over a snapshot of the keys so deleting while looping is safe.
    for key in list(rate_limit_cache):
        if rate_limit_cache[key].get("date") != current_day:
            del rate_limit_cache[key]
# =============================================================================
# FASTAPI APP
# =============================================================================
app = FastAPI(
    title="Yorubs TTS API",
    version="4.0.0",
    description="Text-to-Speech API for Yoruba language using MMS-TTS-YOR",
)
# Open CORS policy: every origin, method and header is allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive (Starlette reflects the caller's Origin in this configuration —
# confirm). Authentication here is header-based (X-API-Key / Authorization),
# so credentialed CORS may not be needed at all.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Module-level singletons shared by all requests; construction order kept as-is.
tts = TTSService()
cache = TTSCache()
# =============================================================================
# MODELS
# =============================================================================
class TTSRequest(BaseModel):
    # Request body for POST /tts.
    # Text to synthesise; the endpoint strips it and requires 1-500 characters.
    text: str
    # Speed multiplier; the endpoint clamps it to [0.5, 1.5] and treats None/0 as 1.0.
    speed: Optional[float] = 1.0
class TTSResponse(BaseModel):
    # Response body for POST /tts.
    # Synthesised audio, base64-encoded WAV.
    audio: str
    # True when the audio was served from the cache instead of being generated.
    cached: bool
    # Requests left in the caller's daily quota after this one.
    remaining_requests: Optional[int] = None
# =============================================================================
# ENDPOINTS
# =============================================================================
@app.get("/")
async def root():
return {"status": "ok", "service": "Yorubs TTS API", "version": "4.0.0"}
@app.get("/health")
async def health():
return {
"status": "healthy",
"model": "facebook/mms-tts-yor",
}
@app.post("/tts", response_model=TTSResponse)
async def text_to_speech(
request: TTSRequest,
req: Request,
x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
authorization: Optional[str] = Header(None),
):
"""
Generate speech from text.
Authentication:
- API Key via X-API-Key header
- Bearer token forwarded from Supabase Edge Function
Rate limiting: 100 requests per client per day
"""
# Authenticate via API key or Bearer token from Edge Function
if API_KEY and x_api_key == API_KEY:
client_id = "api_key_client"
elif authorization and authorization.startswith("Bearer "):
# Trust the Edge Function that already validated the JWT
client_id = "edge_function_client"
else:
raise HTTPException(status_code=401, detail="Invalid or missing authentication")
# Use client IP for more granular rate limiting when possible
client_ip = req.client.host if req.client else client_id
rate_key = f"{client_id}_{client_ip}"
# Validate request
text = request.text.strip()
if not text:
raise HTTPException(status_code=400, detail="Text is required")
if len(text) > 500:
raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")
# Check rate limit
allowed, remaining = check_rate_limit(rate_key)
if not allowed:
raise HTTPException(
status_code=429,
detail="Daily rate limit exceeded. Please try again tomorrow."
)
logger.info(f"TTS request for text: {text[:50]}... speed: {request.speed}")
# Normalize speed
speed = max(0.5, min(1.5, request.speed or 1.0))
# Check cache first
# v6 prefix forces re-generation after expanded carrier threshold (short phrases)
cache_key = f"v6:{text}|speed={speed}" if speed != 1.0 else f"v6:{text}"
cached_audio = await cache.get(cache_key)
if cached_audio:
logger.info("Returning cached audio")
return TTSResponse(audio=cached_audio, cached=True, remaining_requests=remaining)
try:
audio_bytes = await tts.synthesize(text, speed=speed)
audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
await cache.set(cache_key, audio_b64)
logger.info(f"Generated audio: {len(audio_bytes)} bytes")
return TTSResponse(audio=audio_b64, cached=False, remaining_requests=remaining)
except Exception as e:
logger.error(f"TTS synthesis failed: {e}")
raise HTTPException(status_code=500, detail=f"TTS synthesis failed: {str(e)}")
# =============================================================================
# CACHE MANAGEMENT
# =============================================================================
class ClearCacheRequest(BaseModel):
    # Request body for POST /clear-cache.
    # Text whose cached audio should be removed. None or "" clears the whole cache.
    text: Optional[str] = None
@app.post("/clear-cache")
async def clear_cache(
request: ClearCacheRequest,
x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
):
"""
Clear cached TTS audio. Requires API key.
- With text: clears only that specific entry (both v5 and v6 prefixed)
- Without text: clears ALL cached entries
"""
if not API_KEY or x_api_key != API_KEY:
raise HTTPException(status_code=401, detail="API key required for cache management")
if request.text:
# Clear specific text entries (all known prefix versions)
text = request.text.strip()
cleared = 0
for prefix in ["v5:", "v6:", ""]:
key = f"{prefix}{text}"
existing = await cache.get(key)
if existing:
await cache.delete(key)
cleared += 1
# Also clear with speed variants
for speed in [0.5, 0.7, 1.5]:
speed_key = f"{prefix}{text}|speed={speed}"
existing = await cache.get(speed_key)
if existing:
await cache.delete(speed_key)
cleared += 1
logger.info(f"Cleared {cleared} cache entries for: {text[:50]}")
return {"status": "ok", "cleared": cleared, "text": text}
else:
await cache.clear()
logger.info("Cleared ALL cache entries")
return {"status": "ok", "cleared": "all"}
# =============================================================================
# STARTUP
# =============================================================================
@app.on_event("startup")
async def startup_event():
cleanup_old_rate_limits()
logger.info("TTS API v4.0.0 started — carrier threshold: 3 words / 15 chars")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)