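# Tts-api / app.py
# Page header captured from the repository file viewer (not part of the program):
#   author avatar: yukee1992 · commit: "Create app.py" (330aa32, verified) · size: 11.1 kB
#   viewer links: raw, history, blame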
import os
import tempfile
import time
import uuid
from datetime import datetime
from typing import List, Optional

import numpy as np
import requests
import torch
from fastapi import FastAPI, HTTPException, Form, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from TTS.api import TTS
# Configure environment: make sure the scratch directories for uploaded
# voice samples and generated audio exist before any request is served.
for _scratch_dir in ("/tmp/voices", "/tmp/output"):
    os.makedirs(_scratch_dir, exist_ok=True)
# Initialize FastAPI app
app = FastAPI(
    title="TTS API",
    description="API for text-to-speech with Coqui TTS and voice cloning",
)

# Add CORS middleware: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is
# very permissive — confirm this is intended for the deployment.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Configuration
DEFAULT_MODEL = "tts_models/multilingual/multi-dataset/xtts_v2"
# Base URL of the upload service; overridable through the environment.
OCI_UPLOAD_API_URL = os.getenv("OCI_UPLOAD_API_URL", "http://localhost:7860")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"βœ… Using device: {DEVICE}")
def _load_tts_model():
    """Load DEFAULT_MODEL onto DEVICE; log and return None if loading fails."""
    try:
        model = TTS(DEFAULT_MODEL).to(DEVICE)
        print("βœ… TTS model loaded successfully")
        return model
    except Exception as e:
        # Keep the module importable without a model; the endpoints check
        # for ``tts is None`` themselves.
        print(f"❌ Failed to load TTS model: {e}")
        return None


# Initialize TTS model
tts = _load_tts_model()
class TTSRequest(BaseModel):
    """Request body for POST /api/tts (single-text synthesis)."""

    # Text to synthesize.
    text: str
    # Project identifier forwarded to the upload helper.
    project_id: str
    # "default" means no speaker sample is passed to the model; any other
    # value is looked up as /tmp/voices/<voice_name>.wav by the endpoint.
    voice_name: Optional[str] = "default"
    # Language code passed through to the TTS call.
    language: Optional[str] = "en"
class BatchTTSRequest(BaseModel):
    """Request body for POST /api/batch-tts (one audio file per text)."""

    # Texts to synthesize; the endpoint names outputs by 1-based position.
    texts: List[str]
    # Project identifier forwarded to the upload helper.
    project_id: str
    # "default" means no speaker sample is passed to the model; any other
    # value is looked up as /tmp/voices/<voice_name>.wav by the endpoint.
    voice_name: Optional[str] = "default"
    # Language code passed through to the TTS call for every text.
    language: Optional[str] = "en"
class VoiceCloneRequest(BaseModel):
    """Metadata describing a voice to clone.

    NOTE(review): no endpoint in the visible source takes this model (the
    voice upload endpoint uses form fields instead) — confirm it is still
    needed.
    """

    # Project the voice belongs to.
    project_id: str
    # Name the voice is stored under.
    voice_name: str
    # Optional free-text description; empty by default.
    description: Optional[str] = ""
def upload_to_oci(file_path: str, filename: str, project_id: str, file_type="voiceover"):
    """Upload a local file to OCI through the external upload API.

    Args:
        file_path: Path of the local file to send.
        filename: Name the file is uploaded under.
        project_id: Project identifier sent with the upload.
        file_type: Subfolder requested inside the project. The default,
            "voiceover", asks for the ``project_id/voiceover/`` layout.

    Returns:
        ``(result, None)`` when the API answers HTTP 200 with
        ``status == "success"`` (``result`` is the decoded JSON body),
        otherwise ``(None, error_message)``. Every exception is converted
        into the error-tuple form, so this function does not raise.
    """
    try:
        if not OCI_UPLOAD_API_URL:
            return None, "OCI upload API URL not configured"

        url = f"{OCI_UPLOAD_API_URL}/api/upload"
        # Honour the caller's file_type instead of a hard-coded subfolder;
        # the default keeps the previous "voiceover" behaviour.
        data = {"project_id": project_id, "subfolder": file_type}

        # The request must be sent while the file handle is still open.
        with open(file_path, "rb") as f:
            files = {"file": (filename, f, "audio/wav")}
            response = requests.post(url, files=files, data=data, timeout=30)

        if response.status_code != 200:
            return None, f"Upload failed with status {response.status_code}"

        result = response.json()
        if result.get("status") == "success":
            return result, None
        return None, result.get("message", "Upload failed")
    except Exception as e:
        # Network errors, unreadable files and malformed JSON all end here.
        return None, f"Upload error: {str(e)}"
def upload_to_oci_with_retry(file_path: str, filename: str, project_id: str, file_type="voiceover", max_retries=3):
    """Upload a file to OCI, retrying failed attempts with exponential backoff.

    Relies on the module-level ``time`` import for the backoff sleep.

    Args:
        file_path: Path of the local file to send.
        filename: Name the file is uploaded under.
        project_id: Project identifier sent with the upload.
        file_type: Forwarded unchanged to ``upload_to_oci``.
        max_retries: Total number of attempts. Waits of 1s, 2s, 4s, ...
            are inserted between consecutive attempts.

    Returns:
        ``(result, None)`` on the first successful attempt, otherwise
        ``(None, error_message)`` from the last attempt. With
        ``max_retries <= 0`` no attempt is made and a generic error is
        returned.
    """
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        wait_time = 2 ** attempt  # Exponential backoff
        try:
            print(f"πŸ”„ Upload attempt {attempt + 1} of {max_retries} for {filename}")
            result, error = upload_to_oci(file_path, filename, project_id, file_type)
        except Exception as e:
            if is_last_attempt:
                return None, f"Upload failed after {max_retries} attempts: {str(e)}"
            retry_note = f"Upload exception, retrying in {wait_time}s: {str(e)}"
        else:
            if not error:
                return result, None
            if is_last_attempt:
                return None, error
            retry_note = f"Upload failed, retrying in {wait_time}s: {error}"

        print(f"⏳ {retry_note}")
        time.sleep(wait_time)

    # Only reachable when the loop body never ran (max_retries <= 0).
    return None, "Upload failed: unexpected error"
@app.post("/api/tts")
async def generate_tts(request: TTSRequest):
    """Generate speech for a single text and upload the resulting WAV file.

    Args:
        request: Text, project id, voice name and language to use. A missing
            or empty voice name is treated as "default" (no speaker sample).

    Returns:
        A dict with ``status == "success"`` plus ``filename`` and
        ``oci_path`` when the upload worked, or ``status ==
        "partial_success"`` plus ``local_file`` when the audio was generated
        but the upload failed (the local file is kept in that case).

    Raises:
        HTTPException: 500 when the model is not loaded or synthesis fails,
            400 for a voice name that is not a plain file name, 404 when the
            requested voice sample does not exist.
    """
    try:
        if tts is None:
            raise HTTPException(status_code=500, detail="TTS model not loaded")
        print(f"πŸ“₯ TTS request for project: {request.project_id}")
        print(f" Text length: {len(request.text)} characters")
        print(f" Voice: {request.voice_name}")
        print(f" Language: {request.language}")

        # Resolve the speaker sample. voice_name is client input, so reject
        # anything that is not a bare file name before building a path.
        voice_name = request.voice_name or "default"
        speaker_wav = None
        if voice_name != "default":
            if voice_name in (".", "..") or os.path.basename(voice_name) != voice_name:
                raise HTTPException(status_code=400, detail="Invalid voice name")
            speaker_wav = f"/tmp/voices/{voice_name}.wav"
            if not os.path.isfile(speaker_wav):
                raise HTTPException(status_code=404, detail=f"Voice sample not found: {voice_name}")

        # The timestamp alone has one-second resolution; the random suffix
        # keeps concurrent requests from overwriting each other's output.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"voiceover_{timestamp}_{uuid.uuid4().hex[:8]}.wav"
        output_path = f"/tmp/output/{filename}"

        # Generate TTS
        tts.tts_to_file(
            text=request.text,
            speaker_wav=speaker_wav,
            language=request.language,
            file_path=output_path,
        )
        print(f"βœ… TTS generated: {output_path}")

        # Upload to OCI
        upload_result, error = upload_to_oci_with_retry(
            output_path, filename, request.project_id, "voiceover"
        )
        if error:
            print(f"❌ OCI upload failed: {error}")
            # Still return the local file path if upload fails
            return {
                "status": "partial_success",
                "message": f"TTS generated but upload failed: {error}",
                "local_file": output_path,
                "filename": filename,
            }
        print(f"βœ… Upload successful: {filename}")

        # Clean up local file; best effort, a leftover temp file is harmless.
        try:
            os.remove(output_path)
        except OSError:
            pass

        return {
            "status": "success",
            "message": "TTS generated and uploaded successfully",
            "filename": filename,
            "oci_path": upload_result.get("path", f"{request.project_id}/voiceover/{filename}"),
        }
    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except Exception as e:
        print(f"❌ TTS generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
@app.post("/api/batch-tts")
async def batch_generate_tts(request: BatchTTSRequest):
    """Generate speech for several texts, uploading one WAV file per text.

    Uploaded files are named ``voiceover_01.wav``, ``voiceover_02.wav``, ...
    by position in ``request.texts``.

    Args:
        request: Texts, project id, voice name and language to use. A
            missing or empty voice name is treated as "default" (no speaker
            sample).

    Returns:
        A dict with ``status == "completed"``, the ``project_id`` and a
        ``results`` list holding one entry per text: ``"success"`` entries
        carry ``oci_path``; ``"partial_success"`` entries carry
        ``local_file`` (kept on disk because the upload failed).

    Raises:
        HTTPException: 500 when the model is not loaded or any synthesis
            fails, 400 for a voice name that is not a plain file name, 404
            when the requested voice sample does not exist.
    """
    try:
        if tts is None:
            raise HTTPException(status_code=500, detail="TTS model not loaded")
        total = len(request.texts)
        print(f"πŸ“₯ Batch TTS request for project: {request.project_id}")
        print(f" Number of texts: {total}")
        print(f" Voice: {request.voice_name}")
        print(f" Language: {request.language}")

        # Resolve the speaker sample once for the whole batch. voice_name is
        # client input, so reject anything that is not a bare file name.
        voice_name = request.voice_name or "default"
        speaker_wav = None
        if voice_name != "default":
            if voice_name in (".", "..") or os.path.basename(voice_name) != voice_name:
                raise HTTPException(status_code=400, detail="Invalid voice name")
            speaker_wav = f"/tmp/voices/{voice_name}.wav"
            if not os.path.isfile(speaker_wav):
                raise HTTPException(status_code=404, detail=f"Voice sample not found: {voice_name}")

        # Sequential names repeat across batches, so local files get a
        # per-batch prefix to keep concurrent requests from colliding.
        batch_id = uuid.uuid4().hex
        results = []
        for i, text in enumerate(request.texts):
            print(f" Processing text {i+1}/{total}")
            # Generate sequential filename (used for the upload)
            filename = f"voiceover_{i+1:02d}.wav"
            output_path = f"/tmp/output/{batch_id}_{filename}"

            # Generate TTS
            tts.tts_to_file(
                text=text,
                speaker_wav=speaker_wav,
                language=request.language,
                file_path=output_path,
            )

            # Upload to OCI
            upload_result, error = upload_to_oci_with_retry(
                output_path, filename, request.project_id, "voiceover"
            )
            if error:
                print(f"❌ OCI upload failed for {filename}: {error}")
                # Keep the local file: its path is reported to the caller.
                results.append({
                    "text_index": i,
                    "status": "partial_success",
                    "message": f"TTS generated but upload failed: {error}",
                    "local_file": output_path,
                    "filename": filename,
                })
                continue

            print(f"βœ… Upload successful: {filename}")
            results.append({
                "text_index": i,
                "status": "success",
                "message": "TTS generated and uploaded successfully",
                "filename": filename,
                "oci_path": upload_result.get("path", f"{request.project_id}/voiceover/{filename}"),
            })
            # Clean up local file; best effort, a leftover temp file is harmless.
            try:
                os.remove(output_path)
            except OSError:
                pass

        return {
            "status": "completed",
            "project_id": request.project_id,
            "results": results,
        }
    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except Exception as e:
        print(f"❌ Batch TTS generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Batch TTS generation failed: {str(e)}")
@app.post("/api/upload-voice")
async def upload_voice_sample(
    project_id: str = Form(...),
    voice_name: str = Form(...),
    file: UploadFile = File(...)
):
    """Store an uploaded voice sample as /tmp/voices/<voice_name>.wav.

    Args:
        project_id: Project the sample is uploaded for (only logged here).
        voice_name: Name the sample is saved under; must be a plain file
            name without directory components.
        file: Uploaded audio file; its name must end in .wav, .mp3, .ogg or
            .flac.

    Returns:
        A dict with ``status == "success"``, the ``voice_name`` and the
        ``local_path`` of the saved sample.

    Raises:
        HTTPException: 400 for an invalid voice name or a non-audio file
            name, 500 when reading or saving the upload fails.
    """
    try:
        print(f"πŸ“₯ Voice upload request: {voice_name} for project {project_id}")

        # voice_name is client input and becomes part of a filesystem path,
        # so only a bare file name is accepted.
        if voice_name in ("", ".", "..") or os.path.basename(voice_name) != voice_name:
            raise HTTPException(status_code=400, detail="Invalid voice name")

        # Validate file type (the upload may arrive without a file name)
        upload_name = file.filename or ""
        if not upload_name.lower().endswith(('.wav', '.mp3', '.ogg', '.flac')):
            raise HTTPException(status_code=400, detail="Only audio files are allowed")

        # Read the upload before touching the destination so a failed read
        # does not leave an empty sample behind.
        content = await file.read()

        # Save voice sample
        # NOTE(review): the bytes are stored under a .wav name whatever the
        # uploaded extension was; nothing here converts the audio — confirm
        # the TTS model accepts the other formats.
        voice_path = f"/tmp/voices/{voice_name}.wav"
        with open(voice_path, "wb") as f:
            f.write(content)
        print(f"βœ… Voice sample saved: {voice_path}")

        return {
            "status": "success",
            "message": "Voice sample uploaded successfully",
            "voice_name": voice_name,
            "local_path": voice_path,
        }
    except HTTPException:
        # Keep validation errors as 4xx instead of re-wrapping them as 500.
        raise
    except Exception as e:
        print(f"❌ Voice upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Voice upload failed: {str(e)}")
@app.get("/api/health")
async def health_check():
    """Report service liveness, model availability, device and current time."""
    model_ready = tts is not None
    checked_at = datetime.now().isoformat()
    report = {"status": "healthy"}
    report["tts_loaded"] = model_ready
    report["device"] = DEVICE
    report["timestamp"] = checked_at
    return report
@app.get("/")
async def root():
    """Describe the API: a banner message, the route list and the model name."""
    if tts:
        model_label = DEFAULT_MODEL
    else:
        model_label = "Not loaded"

    # Human-readable summary of every route exposed by this service.
    endpoint_summary = {
        "POST /api/tts": "Generate TTS for a single text",
        "POST /api/batch-tts": "Generate TTS for multiple texts",
        "POST /api/upload-voice": "Upload a voice sample for cloning",
        "GET /api/health": "Health check",
    }
    return {
        "message": "TTS API with Coqui TTS and Voice Cloning",
        "endpoints": endpoint_summary,
        "model": model_label,
    }
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    startup_banner = (
        "πŸš€ Starting TTS API with Coqui TTS and Voice Cloning...",
        "πŸ“Š API endpoints available at: http://localhost:7860/",
        "πŸ“š Documentation available at: http://localhost:7860/docs",
    )
    for banner_line in startup_banner:
        print(banner_line)
    # Listen on all interfaces on the port announced in the banner.
    uvicorn.run(app, host="0.0.0.0", port=7860)