indextts2-api / app.py
ataberkkilavuzcu's picture
Update app.py
e522f34 verified
raw
history blame
7.34 kB
import os
import uuid
import tempfile
import base64
from pathlib import Path
from threading import Lock
from typing import Optional, Dict
import requests
import torch
import torchaudio
from torchaudio.transforms import Resample
from fastapi import FastAPI, Body, Header, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel, Field, HttpUrl
# ========== Configuration ==========
# Shared secret that `_require_api_key` compares against the `x_api_key`
# header parameter; when it is unset or empty the check is skipped.
SPACE_API_KEY = os.environ.get("SPACE_API_KEY")

# Hub access token: the first truthy value among the three variable names
# below (if none is truthy, the result of the last lookup is kept).
HF_TOKEN = (
    os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    or os.environ.get("HF_TOKEN")
)

# Repository to download and the local directory that holds the snapshot.
MODEL_REPO = "IndexTeam/IndexTTS-2"
MODEL_DIR = os.environ.get("MODEL_DIR", "/data/indextts2")
os.makedirs(MODEL_DIR, exist_ok=True)

# Upper bound on the number of characters accepted in the `text` field.
MAX_TEXT_LENGTH = 1000

# Speaker prompts are resampled to 16 kHz: less data to process and store.
TARGET_SR = 16000

# Keep PyTorch to a single intra-op thread (set unconditionally, so it
# applies whether or not CUDA is available).
torch.set_num_threads(1)

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# ========== Download / Load Model ==========
# Both packages are imported here rather than at the top of the file so a
# missing dependency is reported with one explicit, actionable message.
try:
    from huggingface_hub import snapshot_download
    from indextts.infer_v2 import IndexTTS2
except Exception as e:
    raise RuntimeError(
        "Required library missing: ensure `huggingface_hub` and `indextts` are installed."
    ) from e

# The presence of config.yaml is the only marker used to decide whether the
# snapshot has already been downloaded into MODEL_DIR.
config_file = Path(MODEL_DIR, "config.yaml")
if not config_file.exists():
    print(f"Downloading model {MODEL_REPO} to {MODEL_DIR} …")
    snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=MODEL_DIR,
        token=HF_TOKEN,
    )
    print("Download complete.")

# Single shared model instance, created at import time with fp16, the custom
# CUDA kernel and DeepSpeed all disabled.
tts_model = IndexTTS2(
    cfg_path=str(config_file),
    model_dir=MODEL_DIR,
    use_fp16=False,
    use_cuda_kernel=False,
    use_deepspeed=False,
)
# DEVICE is only reported here; it is not passed to the IndexTTS2 constructor.
print("IndexTTS-2 loaded, device:", DEVICE)
# ========== FastAPI app ==========
app = FastAPI(title="IndexTTS2 API")
# In-memory job registry keyed by job id. Each entry is a small dict that the
# helpers below fill with "status" and, depending on the outcome,
# "output_file" or "error". Entries live only in this process.
JOBS: Dict[str, Dict[str, str]] = {}
# Held by _set_job / _get_job / _pop_job around every access to JOBS.
JOB_LOCK = Lock()
class GenerateRequest(BaseModel):
    # Request body for POST /generate.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as it was.)

    # Text to synthesize; between 1 and MAX_TEXT_LENGTH characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=MAX_TEXT_LENGTH,
    )
    # Reference voice: a URL to download, or the audio itself as base64.
    speaker_wav: str = Field(
        ...,
        description="HTTPS URL or base64-encoded audio",
    )
    # Language code, defaulting to "en".
    # NOTE(review): _run_generate_job does not read this field — confirm
    # whether it is meant to be forwarded to the model.
    language: Optional[str] = Field(
        default="en",
        description="Language code",
    )
def _require_api_key(x_api_key: Optional[str]):
    """Reject the request unless it carries the configured API key.

    When ``SPACE_API_KEY`` is unset or empty, authentication is disabled and
    every request is accepted.

    Args:
        x_api_key: Value of the API-key header sent by the client, or ``None``
            when the header is absent.

    Raises:
        HTTPException: 401 when a key is configured and the supplied value
            does not match it.
    """
    import hmac  # function-scope import keeps this block self-contained

    if not SPACE_API_KEY:
        return
    # Compare as bytes in constant time: `!=` on strings short-circuits at the
    # first differing character, which leaks timing information about the
    # secret. Encoding first also lets compare_digest accept non-ASCII input.
    supplied = (x_api_key or "").encode("utf-8")
    expected = SPACE_API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
def _write_temp_audio_from_url(url: HttpUrl) -> str:
    """Download the speaker audio at *url* into a temporary file.

    The only caller in this file passes a plain ``str``, so the URL is
    normalised with ``str()`` and parsed with ``urlparse`` instead of relying
    on a ``.path`` attribute, which ``str`` does not have.

    Args:
        url: Location of the audio, as a string or a URL object.

    Returns:
        Path of the temporary file holding the downloaded bytes. Its suffix is
        taken from the URL path and falls back to ``.wav``. The caller is
        responsible for deleting the file.

    Raises:
        HTTPException: 400 when the server answers with a status >= 400.
        requests.RequestException: On connection errors or timeouts.
    """
    from urllib.parse import urlparse  # function-scope: not imported at file top

    url_text = str(url)
    suffix = Path(urlparse(url_text).path).suffix or ".wav"

    # The context manager releases the streamed connection even when we bail
    # out early or the download fails midway.
    with requests.get(url_text, stream=True, timeout=30) as response:
        if response.status_code >= 400:
            raise HTTPException(status_code=400, detail=f"Could not fetch speaker audio: {response.status_code}")

        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        try:
            with tmp:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        tmp.write(chunk)
        except Exception:
            # Do not leave a partial download behind; removal is best-effort.
            try:
                os.remove(tmp.name)
            except OSError:
                pass
            raise
    return tmp.name
def _write_temp_audio_from_base64(payload: str) -> str:
try:
raw = base64.b64decode(payload)
except Exception as exc:
raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav") from exc
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(raw)
return tmp.name
def _temp_speaker_file(speaker_wav: str) -> str:
    """Materialise the speaker reference as a temporary file and return its path.

    Values beginning with ``http://`` or ``https://`` are downloaded; anything
    else is treated as base64-encoded audio.
    """
    looks_like_url = speaker_wav.startswith(("http://", "https://"))
    if not looks_like_url:
        return _write_temp_audio_from_base64(speaker_wav)
    return _write_temp_audio_from_url(speaker_wav)
def _preprocess_audio_wav(path: str, target_sr: int = TARGET_SR, target_peak: float = 0.98) -> str:
    """Normalise the audio file at *path* in place and return *path*.

    Steps, in order: average multiple channels down to one, resample to
    *target_sr* when the rate differs, scale so the largest absolute sample
    equals *target_peak* (skipped for empty or all-zero audio), then write the
    result back to the same path with 16 bits per sample.

    NOTE(review): the file is saved under its original name, so the container
    format presumably follows the existing suffix — confirm this behaves as
    intended for non-.wav inputs.
    """
    waveform, sample_rate = torchaudio.load(path)

    # Down-mix: first dimension is the one averaged away when it exceeds 1.
    channel_count = waveform.shape[0]
    if channel_count > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sample_rate != target_sr:
        waveform = Resample(orig_freq=sample_rate, new_freq=target_sr)(waveform)
        sample_rate = target_sr

    # Peak normalisation; guard against empty tensors and pure silence.
    loudest = 0.0
    if waveform.numel():
        loudest = waveform.abs().max().item()
    if loudest > 0:
        gain = target_peak / loudest
        waveform = waveform * gain

    torchaudio.save(path, waveform, sample_rate, bits_per_sample=16)
    return path
def _set_job(job_id: str, **kwargs):
    """Create or update the record for *job_id*, merging *kwargs* into it.

    Existing keys not mentioned in *kwargs* are kept. A fresh dict is stored
    each time, so previously handed-out copies are never mutated.
    """
    with JOB_LOCK:
        record = dict(JOBS.get(job_id, {}))
        record.update(kwargs)
        JOBS[job_id] = record
def _get_job(job_id: str) -> Optional[Dict[str, str]]:
    """Return a copy of the record for *job_id*, or ``None``.

    ``None`` is returned both when the id is unknown and when the stored
    record is empty.
    """
    with JOB_LOCK:
        record = JOBS.get(job_id)
        if not record:
            return None
        return dict(record)
def _pop_job(job_id: str) -> Optional[Dict[str, str]]:
    """Remove the record for *job_id* and return it, or ``None`` if absent."""
    with JOB_LOCK:
        removed = JOBS.pop(job_id, None)
    return removed
def _cleanup_files(*paths: str):
for p in paths:
try:
os.remove(p)
except OSError:
pass
def _run_generate_job(job_id: str, payload: Dict[str, str]):
    """Run one synthesis job and record its outcome in the job registry.

    The job moves from "processing" to either "completed" (with
    ``output_file`` set) or "error" (with ``error`` set to the exception
    text). Exceptions are never propagated to the caller.

    Args:
        job_id: Identifier of the registry entry to update.
        payload: Request data; ``"text"`` and ``"speaker_wav"`` are read.
    """
    speaker_file = None
    output_file = None
    _set_job(job_id, status="processing")
    try:
        speaker_file = _temp_speaker_file(payload["speaker_wav"])
        speaker_file = _preprocess_audio_wav(speaker_file, target_sr=TARGET_SR)
        output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")
        # Use spk_audio_prompt — this model requires audio prompt
        tts_model.infer(
            text=payload["text"],
            spk_audio_prompt=speaker_file,
            output_path=output_file,
            use_random=False,
            verbose=False,
        )
        if not Path(output_file).exists():
            raise RuntimeError("TTS generation failed — output file not created.")
        _set_job(job_id, status="completed", output_file=output_file)
    except Exception as exc:
        # Top-level boundary of a background task: report the failure through
        # the job record and discard any partial output.
        if output_file:
            _cleanup_files(output_file)
        _set_job(job_id, status="error", error=str(exc))
    finally:
        # The speaker prompt is only needed during inference. Remove it on
        # success as well as on failure, otherwise every completed job leaves
        # a temporary file behind.
        if speaker_file:
            _cleanup_files(speaker_file)
@app.post("/generate")
def generate(
    payload: GenerateRequest = Body(...),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    x_api_key: Optional[str] = Header(default=None),
):
    # Queue a synthesis job and answer 202 immediately with the job id plus
    # the URLs to poll for status and to fetch the result.
    # (Comments are used instead of a docstring so the generated API
    # description is unchanged.)
    # NOTE(review): the `BackgroundTasks()` default is a single instance
    # created at definition time; it is only used if this function is called
    # without the framework supplying one — confirm that never happens.
    _require_api_key(x_api_key)

    job_id = str(uuid.uuid4())
    _set_job(job_id, status="queued")
    # The job runs after the response has been sent.
    background_tasks.add_task(_run_generate_job, job_id, payload.dict())

    accepted = {
        "job_id": job_id,
        "status": "queued",
        "status_url": f"/status/{job_id}",
        "result_url": f"/result/{job_id}",
    }
    return JSONResponse(status_code=202, content=accepted)
@app.get("/status/{job_id}")
def status(job_id: str, x_api_key: Optional[str] = Header(default=None)):
    # Report the current state of a job; 404 when the id is unknown.
    # The "error" key is included only when the job record carries one.
    _require_api_key(x_api_key)

    job = _get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    report = {
        "job_id": job_id,
        "status": job.get("status", "unknown"),
    }
    try:
        report["error"] = job["error"]
    except KeyError:
        pass
    return report
@app.get("/result/{job_id}")
def result(job_id: str, x_api_key: Optional[str] = Header(default=None)):
    # Return the generated audio for a completed job, exactly once.
    #   404 - unknown job id
    #   409 - job exists but has not completed
    #   410 - job completed but its output file is gone
    # On success the job record is dropped and the file is deleted after the
    # response has been sent, so a second request for the same id gets 404.
    _require_api_key(x_api_key)

    job = _get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    finished = job.get("status") == "completed"
    if not finished:
        raise HTTPException(status_code=409, detail=f"Job not ready (status={job.get('status')})")

    output_file = job.get("output_file")
    if output_file and Path(output_file).exists():
        # cleanup after sending
        after_send = BackgroundTasks()
        after_send.add_task(_cleanup_files, output_file)
        _pop_job(job_id)
        return FileResponse(
            output_file,
            media_type="audio/wav",
            filename="output.wav",
            background=after_send,
        )

    # Completed, but nothing left to serve: forget the job and say so.
    _pop_job(job_id)
    raise HTTPException(status_code=410, detail="Result missing or expired")