# parler-tts-api / app.py
# Author: yukee1992 (commit 2dee5ee, "Create app.py")
from fastapi import FastAPI, HTTPException, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import transformers
import torch
import json
import logging
import os
# Set up logging: INFO level so the startup/model-loading steps below are visible
# in the Space's container logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="TTS API", version="1.0.0")

# Add CORS middleware.
# NOTE(review): wildcard origins/methods/headers — fine for a public demo API,
# but tighten `allow_origins` before fronting anything credentialed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model state, mutated by load_model() at startup and read by every
# endpoint below. `model_type` is one of "parler-tts", "bark", or "none".
model = None
processor = None
model_loaded = False
model_type = "none"
@app.on_event("startup")
async def startup_event():
    """Run once when the server boots: log environment info, then load the model."""
    logger.info("=== TTS API Starting ===")
    # Lazy %-style args: same rendered message, formatting deferred to the handler.
    logger.info("Transformers version: %s", transformers.__version__)
    logger.info("Torch version: %s", torch.__version__)
    await load_model()
async def load_model():
    """Load a TTS model, preferring Parler-TTS with Bark as a fallback.

    Side effects: sets the module globals ``model``, ``processor``,
    ``model_loaded`` and ``model_type``. Never raises — on total failure the
    service stays up in "degraded" mode with ``model_loaded == False``.
    """
    global model, processor, model_loaded, model_type
    try:
        logger.info("Step 1: Importing transformers...")
        from transformers import AutoProcessor, AutoModel
        logger.info("Step 2: Loading Parler-TTS processor...")
        processor = AutoProcessor.from_pretrained(
            "parler-tts/parler-tts-mini-v1",
            trust_remote_code=True
        )
        logger.info("Step 3: Loading Parler-TTS model...")
        model = AutoModel.from_pretrained(
            "parler-tts/parler-tts-mini-v1",
            trust_remote_code=True
        )
        model_loaded = True
        model_type = "parler-tts"
        logger.info("βœ… SUCCESS: Parler-TTS model loaded successfully!")
    except Exception as e:
        logger.error(f"❌ FAILED: Parler-TTS loading error: {e}")
        logger.info("Trying fallback to Bark model...")
        try:
            # BUGFIX: re-import here. If the `from transformers import ...`
            # line above was the step that failed, AutoProcessor/AutoModel
            # would be unbound and this fallback would crash with a NameError
            # instead of actually trying Bark.
            from transformers import AutoProcessor, AutoModel
            processor = AutoProcessor.from_pretrained("suno/bark-small")
            model = AutoModel.from_pretrained("suno/bark-small")
            model_loaded = True
            model_type = "bark"
            logger.info("βœ… SUCCESS: Bark model loaded as fallback!")
        except Exception as fallback_error:
            logger.error(f"❌ FAILED: All models failed: {fallback_error}")
            model_loaded = False
            model_type = "none"
@app.get("/")
async def root():
    """Root endpoint: service banner plus coarse status and version info."""
    status = "operational" if model_loaded else "degraded"
    payload = {
        "message": "TTS API Service",
        "status": status,
        "model_loaded": model_loaded,
        "model_type": model_type,
        "transformers_version": transformers.__version__,
        "torch_version": torch.__version__,
    }
    return payload
@app.get("/health")
async def health():
    """Lightweight health probe: 'healthy' only when a model is loaded."""
    status = "healthy" if model_loaded else "degraded"
    return {"status": status, "model_loaded": model_loaded, "model_type": model_type}
@app.get("/debug")
async def debug():
    """Debug endpoint to check environment.

    Returns interpreter/library versions, model state, and CUDA availability.
    """
    # BUGFIX: report the interpreter actually running, not a hard-coded "3.9"
    # (the Space's base image can change; a stale constant misleads debugging).
    import platform
    return {
        "python_version": platform.python_version(),
        "transformers_version": transformers.__version__,
        "torch_version": torch.__version__,
        "model_loaded": model_loaded,
        "model_type": model_type,
        "cuda_available": torch.cuda.is_available(),
        "space_ready": True
    }
@app.post("/api/generate-voiceovers")
async def generate_voiceovers(
    project_id: str = Form(...),
    voiceover_scenes: str = Form(...),
    upload_to_oci: bool = Form(False)
):
    """Generate voiceovers from text scenes.

    Form fields:
        project_id: caller-supplied project identifier (echoed back).
        voiceover_scenes: JSON array of scene descriptions, as a string.
        upload_to_oci: accepted but not yet acted on in this version.

    Raises:
        HTTPException 503 if no TTS model is loaded,
                      400 on malformed/non-array JSON,
                      500 on unexpected errors.
    """
    if not model_loaded:
        raise HTTPException(
            status_code=503,
            detail="TTS model not loaded. Service unavailable."
        )
    try:
        # Parse input scenes
        scenes = json.loads(voiceover_scenes)
        if not isinstance(scenes, list):
            raise HTTPException(
                status_code=400,
                detail="voiceover_scenes must be a JSON array"
            )
        logger.info(f"Processing {len(scenes)} scenes for project {project_id}")
        # Return success response
        return {
            "status": "success",
            "project_id": project_id,
            "scenes_processed": len(scenes),
            "model_type": model_type,
            "message": f"Ready to process {len(scenes)} voiceover scenes using {model_type}",
            "expected_files": [f"voiceover_{i:02d}.wav" for i in range(1, len(scenes) + 1)]
        }
    except json.JSONDecodeError:
        raise HTTPException(
            status_code=400,
            detail="Invalid JSON format for voiceover_scenes"
        )
    except HTTPException:
        # BUGFIX: let deliberate HTTP errors (the 400 above) propagate.
        # Previously the generic handler below caught them and re-wrapped
        # the intended 400 response as a 500.
        raise
    except Exception as e:
        logger.error(f"Error processing request: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        )
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces routes traffic to port 7860 by convention.
    port = 7860
    logger.info("Starting TTS API server on port 7860...")
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")