# app.py — Multi-Language TTS API (FastAPI service for a Hugging Face Space)
import os
import tempfile
import uuid
import time
import shutil
from datetime import datetime
from typing import List, Optional, Dict
from pathlib import Path
import requests
from fastapi import FastAPI, HTTPException, Form, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import numpy as np
# Configure environment with storage limits
# Working directories on the ephemeral /tmp volume:
#   /tmp/voices -> uploaded reference voice clips (cleaned after 24 h)
#   /tmp/output -> generated WAV files awaiting upload (cleaned after 1 h)
os.makedirs("/tmp/voices", exist_ok=True)
os.makedirs("/tmp/output", exist_ok=True)
# Initialize FastAPI app
app = FastAPI(title="Multi-Language TTS API", description="API for text-to-speech with English and Chinese support")
# Add CORS middleware
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive; fine for a demo Space, tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Configuration
# Base URL of the companion OCI upload service (overridable via env var).
OCI_UPLOAD_API_URL = os.getenv("OCI_UPLOAD_API_URL", "https://yukee1992-oci-video-storage.hf.space")
# Prefer GPU when available; the TTS models run on either device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"✅ Using device: {DEVICE}")
# SIMPLIFIED: Use compatible models that work with current PyTorch
# Registry of loadable Coqui TTS models, keyed by the short ids used in
# VOICE_STYLES["recommended_model"] and load_tts_model(). "size_mb" and
# "quality" are informational only; "languages" drives the smoke-test text.
AVAILABLE_MODELS = {
    # English single-speaker model.
    "tacotron2-ddc": {
        "name": "Tacotron2-DDC",
        "model_name": "tts_models/en/ljspeech/tacotron2-DDC",
        "description": "High-quality English TTS",
        "languages": ["en"],
        "voice_cloning": False,
        "size_mb": 150,
        "quality": "excellent",
        "multi_speaker": False
    },
    # Mandarin single-speaker model.
    "fastspeech2": {
        "name": "FastSpeech2-Mandarin",
        "model_name": "tts_models/zh-CN/baker/fastspeech2",
        "description": "High-quality Chinese TTS",
        "languages": ["zh"],
        "voice_cloning": False,
        "size_mb": 120,
        "quality": "excellent",
        "multi_speaker": False
    }
}
# Voice styles for compatible models
# Style catalog exposed via /api/voice-styles. Each entry maps a style key to
# display metadata plus the AVAILABLE_MODELS key to load for that style.
# NOTE(review): all styles of one language share the same single-speaker
# model, so the styles differ only in labeling, not in actual voice.
VOICE_STYLES = {
    # English Voice Styles
    "default": {
        "name": "Default English Voice",
        "description": "Clear and natural English voice",
        "gender": "neutral",
        "language": "en",
        "recommended_model": "tacotron2-ddc"
    },
    "clear": {
        "name": "Clear English Voice",
        "description": "Very clear and articulate English voice",
        "gender": "neutral",
        "language": "en",
        "recommended_model": "tacotron2-ddc"
    },
    "professional": {
        "name": "Professional English Voice",
        "description": "Professional and authoritative English voice",
        "gender": "neutral",
        "language": "en",
        "recommended_model": "tacotron2-ddc"
    },
    # Chinese Voice Styles
    "chinese_default": {
        "name": "默认中文语音",
        "description": "清晰自然的中文语音",
        "gender": "neutral",
        "language": "zh",
        "recommended_model": "fastspeech2"
    },
    "chinese_clear": {
        "name": "清晰中文语音",
        "description": "非常清晰和标准的中文语音",
        "gender": "neutral",
        "language": "zh",
        "recommended_model": "fastspeech2"
    },
    "chinese_professional": {
        "name": "专业中文语音",
        "description": "专业和正式的中文语音",
        "gender": "neutral",
        "language": "zh",
        "recommended_model": "fastspeech2"
    }
}
# Global state
# tts: the currently loaded Coqui TTS instance (None until first load).
tts = None
# model_loaded: True once a model has loaded successfully.
model_loaded = False
# current_model: key into AVAILABLE_MODELS for the active model ("" = none).
current_model = ""
# model_loading: re-entrancy guard — load_tts_model rejects overlapping calls.
model_loading = False
# Pydantic models
class TTSRequest(BaseModel):
    """Request body for /api/tts: one text to synthesize and upload."""
    # Text to synthesize (cleaned via clean_text before generation).
    text: str
    # Project identifier; used as the OCI upload folder.
    project_id: str
    # Key into VOICE_STYLES; also selects which model is loaded.
    voice_style: Optional[str] = "default"
    # Playback speed multiplier. NOTE(review): not applied during generation.
    speed: Optional[float] = 1.0
    # "en", "zh", or "auto" (detect from the text).
    language: Optional[str] = "auto"
class BatchTTSRequest(BaseModel):
    """Request body for /api/batch-tts: several texts, shared settings."""
    # Texts to synthesize, processed sequentially in order.
    texts: List[str]
    # Project identifier shared by all generated files.
    project_id: str
    # Voice style applied to every text.
    voice_style: Optional[str] = "default"
    # Playback speed multiplier (forwarded to each per-text request).
    speed: Optional[float] = 1.0
    # "en", "zh", or "auto" (detected per text when "auto").
    language: Optional[str] = "auto"
# Language detection function
def detect_language(text: str) -> str:
    """Return "zh" when *text* is predominantly Chinese, otherwise "en".

    Heuristic: count CJK Unified Ideographs (U+4E00–U+9FFF) and treat the
    text as Chinese once they exceed 30% of the stripped character count.
    Empty input defaults to English.
    """
    stripped = text.strip()
    if not stripped:
        return "en"  # nothing to classify; default to English
    han_count = sum(1 for ch in stripped if '\u4e00' <= ch <= '\u9fff')
    return "zh" if han_count / len(stripped) > 0.3 else "en"
# Get appropriate model based on voice style and language
def get_model_for_voice_style(voice_style: str, language: str = "auto") -> str:
    """Resolve a voice style (with language fallback) to a model key.

    Known styles use their "recommended_model"; unknown styles fall back
    to the Chinese model for "zh" and the English model otherwise.
    """
    style = VOICE_STYLES.get(voice_style)
    if style is not None:
        return style.get("recommended_model", "tacotron2-ddc")
    # Unknown style: choose by language alone.
    return "fastspeech2" if language == "zh" else "tacotron2-ddc"
# Storage management functions
def cleanup_old_files():
    """Delete stale audio under /tmp and then log storage usage.

    Generated WAVs in /tmp/output expire after 1 hour; voice samples in
    /tmp/voices expire after 24 hours. Any error is logged, never raised.
    """
    try:
        output_dir = Path("/tmp/output")
        if output_dir.exists():
            for stale in output_dir.glob("*.wav"):
                if stale.stat().st_mtime < time.time() - 3600:  # 1 hour
                    stale.unlink()
                    print(f"🧹 Cleaned up old file: {stale}")
        voices_dir = Path("/tmp/voices")
        if voices_dir.exists():
            for stale in voices_dir.rglob("*.wav"):
                if stale.stat().st_mtime < time.time() - 86400:  # 24 hours
                    stale.unlink()
                    print(f"🧹 Cleaned up old voice file: {stale}")
        # Report remaining free space.
        check_storage_usage()
    except Exception as e:
        print(f"⚠️ Cleanup error: {e}")
def check_storage_usage():
    """Log free space under /tmp; return False when below the 2 GB floor.

    Returns True on healthy storage and also on any probing error (the
    check is best-effort and must never block generation).
    """
    try:
        import shutil
        usage = shutil.disk_usage("/tmp")
        gib = 2 ** 30
        print(f"💾 Storage: {usage.free // gib}GB free of {usage.total // gib}GB total")
        if usage.free < 2 * gib:  # Less than 2GB free
            print("🚨 WARNING: Low storage space!")
            return False
        return True
    except Exception as e:
        print(f"⚠️ Storage check error: {e}")
        return True
# Text cleaning with language support
def clean_text(text, language="auto"):
    """Normalize *text* for TTS generation.

    Strips characters outside the allowed set for the language, collapses
    whitespace, and appends sentence-final punctuation to longer texts.
    Empty or non-string input yields a short greeting in the target
    language so the synthesizer always has something to say.
    """
    import re
    if not text or not isinstance(text, str):
        return "你好" if language == "zh" else "Hello"
    # Resolve "auto" before choosing the character whitelist.
    if language == "auto":
        language = detect_language(text)
    if language == "zh":
        # Keep word chars, basic Latin punctuation, CJK ideographs, and
        # CJK/fullwidth punctuation blocks.
        text = re.sub(r'[^\w\s\.\,\!\?\-\'\"\:\;\u4e00-\u9fff\u3000-\u303f\uff00-\uffef]', '', text)
    else:
        # Keep word chars and basic English punctuation.
        text = re.sub(r'[^\w\s\.\,\!\?\-\'\"\:\;]', '', text)
    text = re.sub(r'\s+', ' ', text)
    # Longer texts get sentence-final punctuation if missing.
    if len(text) > 10 and not re.search(r'[\.\!\?。!?]$', text):
        text += '。' if language == "zh" else '.'
    text = text.strip()
    return text or ("你好世界" if language == "zh" else "Hello world")
def upload_to_oci(file_path: str, filename: str, project_id: str, file_type="voiceover"):
    """POST a WAV file to the OCI upload service.

    Returns a (result_dict, None) pair on success or (None, error_message)
    on any failure; never raises. The *file_type* parameter is accepted
    for interface compatibility; the subfolder is always "voiceover".
    """
    try:
        if not OCI_UPLOAD_API_URL:
            return None, "OCI upload API URL not configured"
        endpoint = f"{OCI_UPLOAD_API_URL}/api/upload"
        with open(file_path, "rb") as audio:
            response = requests.post(
                endpoint,
                files={"file": (filename, audio, "audio/wav")},
                data={
                    "project_id": project_id,
                    "subfolder": "voiceover"
                },
                timeout=30,
            )
        if response.status_code != 200:
            return None, f"Upload failed with status {response.status_code}"
        result = response.json()
        if result.get("status") == "success":
            return result, None
        return None, result.get("message", "Upload failed")
    except Exception as e:
        return None, f"Upload error: {str(e)}"
# COMPATIBLE: Model loading with error handling
def load_tts_model(model_type="tacotron2-ddc"):
    """Load (or switch to) the Coqui TTS model identified by *model_type*.

    Updates the module globals (tts, model_loaded, current_model) and
    returns True on success, False on any failure. The model_loading flag
    makes overlapping load attempts return False immediately.
    """
    global tts, model_loaded, current_model, model_loading
    # Reject re-entrant calls while another load is in progress.
    if model_loading:
        print("⏳ Model is already being loaded...")
        return False
    if model_type not in AVAILABLE_MODELS:
        print(f"❌ Model type '{model_type}' not found.")
        return False
    # If we're already using the correct model, no need to reload
    if model_loaded and current_model == model_type:
        print(f"✅ Model {model_type} is already loaded")
        return True
    model_loading = True
    try:
        # Clean up before loading new model
        cleanup_old_files()
        # Import TTS lazily so the API can start even when Coqui TTS is
        # missing or broken; the failure surfaces here instead of at boot.
        try:
            from TTS.api import TTS
        except ImportError as e:
            print(f"❌ TTS import failed: {e}")
            return False
        # Coqui prompts on stdin for license (TOS) acceptance the first time
        # a model is downloaded; feed it an automatic "y" and restore stdin
        # in the finally below.
        import sys
        from io import StringIO
        old_stdin = sys.stdin
        sys.stdin = StringIO('y\n')
        try:
            model_config = AVAILABLE_MODELS[model_type]
            print(f"🚀 Loading {model_config['name']}...")
            print(f" Languages: {', '.join(model_config['languages'])}")
            # Clear current model from memory first if exists
            # NOTE(review): `del tts` unbinds the global; if the subsequent
            # load fails before reassignment, later `tts is not None` checks
            # would raise NameError — confirm this path is acceptable.
            if tts is not None:
                print("🧹 Clearing previous model from memory...")
                del tts
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            # Load the selected model; retry without the .to(DEVICE) move if
            # the device transfer fails (e.g. CUDA misconfiguration).
            try:
                tts = TTS(model_config["model_name"]).to(DEVICE)
            except Exception as e:
                print(f"❌ TTS initialization failed: {e}")
                # Try alternative initialization
                try:
                    tts = TTS(model_config["model_name"])
                    print("✅ Model loaded without device specification")
                except Exception as e2:
                    print(f"❌ Alternative loading also failed: {e2}")
                    return False
            # Smoke-test the model with a short language-appropriate phrase;
            # a test failure is logged but does not abort the load.
            test_path = "/tmp/test_output.wav"
            if "zh" in model_config["languages"]:
                test_text = "你好"  # Chinese test
            else:
                test_text = "Hello"  # English test
            try:
                tts.tts_to_file(text=test_text, file_path=test_path)
                if os.path.exists(test_path):
                    os.remove(test_path)
                print("✅ Model tested successfully!")
            except Exception as e:
                print(f"⚠️ Model test failed but continuing: {e}")
                # Continue even if test fails
            model_loaded = True
            current_model = model_type
            print(f"✅ {model_config['name']} loaded successfully!")
            print(f" Size: ~{model_config['size_mb']}MB")
            print(f" Quality: {model_config['quality']}")
            print(f" Languages: {model_config['languages']}")
            return True
        except Exception as e:
            print(f"❌ Model failed to load: {e}")
            return False
        finally:
            # Always restore the real stdin.
            sys.stdin = old_stdin
    except Exception as e:
        print(f"❌ Failed to initialize TTS: {e}")
        return False
    finally:
        # Always release the re-entrancy guard.
        model_loading = False
# Model switching logic
def ensure_correct_model(voice_style: str, text: str, language: str = "auto"):
    """Make sure the model matching *voice_style*/*language* is active.

    Returns True when the right model is (already or newly) loaded,
    False when loading it failed.
    """
    global tts, model_loaded, current_model
    target = get_model_for_voice_style(voice_style, language)
    print(f"🔍 Model selection: voice_style={voice_style}, language={language}, target_model={target}")
    # Already serving the right model — nothing to do.
    if model_loaded and current_model == target:
        return True
    print(f"🔄 Switching to model: {target} for voice style: {voice_style}, language: {language}")
    return load_tts_model(target)
# TTS generation with language-specific models
@app.post("/api/tts")
async def generate_tts(request: TTSRequest):
    """Generate speech for request.text and upload the WAV to OCI.

    Picks the model from the voice style and (auto-detected) language,
    writes the WAV under /tmp/output, uploads it, then removes the local
    copy. Returns a status dict ("success" / "partial_success" / "error")
    instead of raising, so batch callers can aggregate per-item outcomes.
    """
    try:
        # Reclaim /tmp space before generating a new file.
        cleanup_old_files()
        # Resolve "auto" to a concrete language for model selection.
        if request.language == "auto":
            detected_language = detect_language(request.text)
            print(f"🌐 Auto-detected language: {detected_language}")
        else:
            detected_language = request.language
        # Ensure correct model is loaded
        if not ensure_correct_model(request.voice_style, request.text, detected_language):
            return {
                "status": "error",
                "message": f"Failed to load appropriate TTS model for {detected_language}",
                "requires_tos_acceptance": True,
                "tos_url": "https://coqui.ai/cpml.txt"
            }
        print(f"📥 TTS request for project: {request.project_id}")
        print(f" Voice Style: {request.voice_style}")
        print(f" Language: {detected_language}")
        print(f" Text length: {len(request.text)} characters")
        print(f" Current Model: {current_model}")
        # Generate a unique filename. The uuid suffix prevents collisions
        # (and silent overwrites) when several requests land within the
        # same second, e.g. during batch processing.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"voiceover_{timestamp}_{uuid.uuid4().hex[:8]}.wav"
        # FIX: the output path previously did not interpolate the filename,
        # so every request wrote to the same literal path.
        output_path = f"/tmp/output/{filename}"
        # Ensure output directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # Clean the text with language support
        cleaned_text = clean_text(request.text, detected_language)
        print(f"📝 Text: '{cleaned_text}'")
        # Generate TTS
        try:
            # Use the appropriate model based on language
            if current_model == "fastspeech2" and detected_language == "zh":
                print("🎯 Using FastSpeech2 for Chinese text")
                tts.tts_to_file(text=cleaned_text, file_path=output_path)
            elif current_model == "tacotron2-ddc" and detected_language == "en":
                print("🎯 Using Tacotron2-DDC for English text")
                tts.tts_to_file(text=cleaned_text, file_path=output_path)
            else:
                # Language-model mismatch: try one corrective model switch.
                print(f"🔄 Language-model mismatch detected, attempting correction...")
                correct_model = get_model_for_voice_style(request.voice_style, detected_language)
                if load_tts_model(correct_model):
                    tts.tts_to_file(text=cleaned_text, file_path=output_path)
                else:
                    raise Exception(f"Cannot process {detected_language} text with current model")
        except Exception as tts_error:
            print(f"❌ TTS generation failed: {tts_error}")
            raise tts_error
        # Verify the file was created
        if not os.path.exists(output_path):
            raise Exception(f"TTS failed to create output file")
        file_size = os.path.getsize(output_path)
        print(f"✅ TTS generated: {output_path} ({file_size} bytes)")
        # Upload to OCI
        upload_result, error = upload_to_oci(output_path, filename, request.project_id)
        if error:
            print(f"❌ OCI upload failed: {error}")
            return {
                "status": "partial_success",
                "message": f"TTS generated but upload failed: {error}",
                "local_file": output_path,
                "filename": filename,
                "file_size": file_size
            }
        # FIX: log the actual filename instead of a placeholder.
        print(f"✅ Upload successful: {filename}")
        # Clean up local file immediately after upload
        try:
            os.remove(output_path)
            print(f"🧹 Cleaned up local file: {output_path}")
        except Exception as cleanup_error:
            print(f"⚠️ Could not clean up file: {cleanup_error}")
        return {
            "status": "success",
            "message": "TTS generated and uploaded successfully",
            "filename": filename,
            # FIX: the fallback path now includes the real filename.
            "oci_path": upload_result.get("path", f"{request.project_id}/voiceover/{filename}"),
            "model_used": current_model,
            "voice_style": request.voice_style,
            "language": detected_language,
            "text_preview": cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
        }
    except Exception as e:
        print(f"❌ TTS generation error: {str(e)}")
        return {
            "status": "error",
            "message": f"TTS generation failed: {str(e)}"
        }
# Batch TTS processing
@app.post("/api/batch-tts")
async def batch_generate_tts(request: BatchTTSRequest):
    """Run the single-text TTS pipeline over every text in the request.

    Texts are processed sequentially; per-item failures are recorded in
    the results list instead of aborting the batch. A summary with
    success/failure counts accompanies the detailed results.
    """
    try:
        cleanup_old_files()
        print(f"📥 Batch TTS request for {len(request.texts)} texts")
        print(f" Project: {request.project_id}")
        print(f" Voice Style: {request.voice_style}")
        print(f" Language: {request.language}")
        results = []
        total = len(request.texts)
        for index, text in enumerate(request.texts):
            preview = text[:30] + "..." if len(text) > 30 else text
            try:
                # Resolve language per text so mixed-language batches work.
                item_language = detect_language(text) if request.language == "auto" else request.language
                print(f" Processing text {index+1}/{total}: {item_language} - {text[:50]}...")
                item_result = await generate_tts(TTSRequest(
                    text=text,
                    project_id=request.project_id,
                    voice_style=request.voice_style,
                    speed=request.speed,
                    language=item_language,
                ))
                results.append({
                    "text_index": index,
                    "text_preview": preview,
                    "status": item_result.get("status", "error"),
                    "message": item_result.get("message", ""),
                    "filename": item_result.get("filename", ""),
                    "oci_path": item_result.get("oci_path", ""),
                    "language": item_result.get("language", "unknown")
                })
            except Exception as e:
                print(f"❌ Failed to process text {index}: {str(e)}")
                results.append({
                    "text_index": index,
                    "text_preview": preview,
                    "status": "error",
                    "message": f"Failed to generate TTS: {str(e)}"
                })
        # Summarize outcomes for the caller and the log.
        succeeded = sum(1 for item in results if item.get("status") == "success")
        failed = sum(1 for item in results if item.get("status") == "error")
        print(f"📊 Batch completed: {succeeded} successful, {failed} failed")
        return {
            "status": "completed",
            "project_id": request.project_id,
            "summary": {
                "total": len(results),
                "successful": succeeded,
                "failed": failed
            },
            "results": results,
            "model_used": current_model
        }
    except Exception as e:
        print(f"❌ Batch TTS generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Batch TTS generation failed: {str(e)}")
@app.get("/api/voice-styles")
async def get_voice_styles():
    """List all voice styles, grouped by language, plus the active model."""
    # Partition the catalog by its "language" tag in a single pass.
    grouped = {"en": {}, "zh": {}}
    for key, style in VOICE_STYLES.items():
        lang = style.get("language")
        if lang in grouped:
            grouped[lang][key] = style
    return {
        "status": "success",
        "voice_styles": VOICE_STYLES,
        "english_styles": grouped["en"],
        "chinese_styles": grouped["zh"],
        "current_model": current_model if model_loaded else None,
        "supported_languages": ["en", "zh", "auto"]
    }
# Language detection endpoint
@app.post("/api/detect-language")
async def detect_text_language(text: str = Form(...)):
    """Classify submitted form text as English or Chinese."""
    try:
        detected = detect_language(text)
        # Longer samples give the ratio heuristic more signal.
        return {
            "status": "success",
            "language": detected,
            "confidence": "high" if len(text) > 10 else "medium",
            "text_preview": text[:100] + "..." if len(text) > 100 else text
        }
    except Exception as e:
        return {
            "status": "error",
            "message": f"Language detection failed: {str(e)}"
        }
@app.get("/api/health")
async def health_check():
    """Report model, storage, and device status for monitoring."""
    storage_ok = check_storage_usage()
    healthy = model_loaded and storage_ok
    # Advertise the active model's languages only once a model is loaded.
    languages = AVAILABLE_MODELS.get(current_model, {}).get("languages", []) if model_loaded else []
    return {
        "status": "healthy" if healthy else "warning",
        "tts_loaded": model_loaded,
        "model": current_model,
        "storage_ok": storage_ok,
        "device": DEVICE,
        "supported_languages": languages
    }
@app.post("/api/cleanup")
async def manual_cleanup():
    """Trigger the stale-file sweep on demand."""
    try:
        cleanup_old_files()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}")
    return {
        "status": "success",
        "message": "Cleanup completed successfully"
    }
@app.get("/")
async def root():
    """Service banner with current model status."""
    active_model = current_model if model_loaded else "None"
    return {
        "message": "Multi-Language TTS API",
        "model_loaded": model_loaded,
        "model": active_model,
        "languages_supported": ["English", "Chinese"],
        "storage_optimized": True
    }
if __name__ == "__main__":
    # Entry point when run directly. Port 7860 is the Hugging Face Spaces
    # convention; the model itself is loaded lazily on the first request.
    import uvicorn
    print("🚀 Starting Multi-Language TTS API...")
    print("💾 Storage management enabled")
    print("🌐 Supporting English and Chinese")
    print("🔊 Using Tacotron2-DDC (English) and FastSpeech2 (Chinese)")
    check_storage_usage()
    uvicorn.run(app, host="0.0.0.0", port=7860)