# IndexTTS-2 text-to-speech API (Hugging Face Space)
import base64
import os
import secrets
import tempfile
import uuid
from pathlib import Path
from threading import Lock
from typing import Dict, Optional
from urllib.parse import urlsplit

import requests
import torch
import torchaudio
from torchaudio.transforms import Resample

from fastapi import FastAPI, Body, Header, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel, Field, HttpUrl
# ========== Configuration ==========
# Optional shared secret for the API; when unset, authentication is disabled
# (see _require_api_key, which returns early on a falsy key).
SPACE_API_KEY = os.getenv("SPACE_API_KEY")
# Hugging Face token — accept any of the commonly used env-var spellings.
HF_TOKEN = (
    os.getenv("HUGGING_FACE_HUB_TOKEN")
    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    or os.getenv("HF_TOKEN")
)
MODEL_REPO = "IndexTeam/IndexTTS-2"
# Model weights directory; created eagerly so snapshot_download can write into it.
MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
os.makedirs(MODEL_DIR, exist_ok=True)
# Max length for input text
MAX_TEXT_LENGTH = 1000
# Use 16 kHz sample rate for faster/audio-size tradeoff
TARGET_SR = 16000
# Limit PyTorch threads on CPU
torch.set_num_threads(1)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ========== Download / Load Model ==========
# Import here (not at top of file) so a missing optional dependency produces a
# single actionable error message instead of a raw ImportError traceback.
try:
    from huggingface_hub import snapshot_download
    from indextts.infer_v2 import IndexTTS2
except Exception as e:
    raise RuntimeError("Required library missing: ensure `huggingface_hub` and `indextts` are installed.") from e
# Only download if not already present
# config.yaml is used as the sentinel for "snapshot already downloaded".
config_file = Path(MODEL_DIR) / "config.yaml"
if not config_file.exists():
    print(f"Downloading model {MODEL_REPO} to {MODEL_DIR} …")
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, token=HF_TOKEN)
    print("Download complete.")
# Module-level singleton: loaded once at import time, shared by all requests.
# NOTE(review): DEVICE is computed and printed but never passed to IndexTTS2 —
# presumably the library auto-detects CUDA; confirm against indextts docs.
tts_model = IndexTTS2(cfg_path=str(config_file), model_dir=MODEL_DIR, use_fp16=False, use_cuda_kernel=False, use_deepspeed=False)
print("IndexTTS-2 loaded, device:", DEVICE)
# ========== FastAPI app ==========
app = FastAPI(title="IndexTTS2 API")
# In-memory job registry: job_id -> {"status": ..., "output_file": ..., "error": ...}.
# Guarded by JOB_LOCK; state is lost on process restart (single-process design).
JOBS: Dict[str, Dict[str, str]] = {}
JOB_LOCK = Lock()
class GenerateRequest(BaseModel):
    """Request body for speech generation.

    `speaker_wav` is either an http(s) URL pointing at the reference audio or
    the base64-encoded audio bytes themselves (see _temp_speaker_file).
    """

    text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
    speaker_wav: str = Field(..., description="HTTPS URL or base64-encoded audio")
    # NOTE(review): language is accepted but never forwarded to tts_model.infer
    # in _run_generate_job — confirm whether this is intentional.
    language: Optional[str] = Field("en", description="Language code")
def _require_api_key(x_api_key: Optional[str]) -> None:
    """Authorize a request against the configured API key.

    No-op when SPACE_API_KEY is unset (auth disabled). Otherwise raises
    HTTPException(401) unless the header value matches.
    """
    if not SPACE_API_KEY:
        return
    # Constant-time comparison — a plain `!=` short-circuits on the first
    # differing character, leaking key prefixes through response timing.
    if not (x_api_key and secrets.compare_digest(x_api_key, SPACE_API_KEY)):
        raise HTTPException(status_code=401, detail="Unauthorized")
def _write_temp_audio_from_url(url: str) -> str:
    """Download speaker reference audio from *url* into a temp file.

    Returns the path of a NamedTemporaryFile (delete=False — the caller owns
    cleanup). Raises HTTPException(400) on any HTTP error status.
    """
    response = requests.get(url, stream=True, timeout=30)
    if response.status_code >= 400:
        raise HTTPException(status_code=400, detail=f"Could not fetch speaker audio: {response.status_code}")
    # BUG FIX: callers pass a plain str (GenerateRequest.speaker_wav), which has
    # no `.path` attribute — the old `Path(url.path)` raised AttributeError.
    # Parse the URL properly to recover the path's file extension.
    suffix = Path(urlsplit(str(url)).path).suffix or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        # Stream in 8 KiB chunks so large clips never sit fully in memory.
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                tmp.write(chunk)
    return tmp.name
| def _write_temp_audio_from_base64(payload: str) -> str: | |
| try: | |
| raw = base64.b64decode(payload) | |
| except Exception as exc: | |
| raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav") from exc | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: | |
| tmp.write(raw) | |
| return tmp.name | |
def _temp_speaker_file(speaker_wav: str) -> str:
    """Materialize the caller-supplied speaker reference as a local temp file.

    Treats an http(s) value as a URL to download; anything else is assumed to
    be base64-encoded audio bytes.
    """
    is_url = speaker_wav.startswith(("http://", "https://"))
    return _write_temp_audio_from_url(speaker_wav) if is_url else _write_temp_audio_from_base64(speaker_wav)
def _preprocess_audio_wav(path: str, target_sr: int = TARGET_SR, target_peak: float = 0.98) -> str:
    """Normalize a speaker clip in place: mono, *target_sr* Hz, peak-scaled, 16-bit.

    Overwrites the file at *path* and returns the same path.
    """
    audio, sample_rate = torchaudio.load(path)
    # Collapse multi-channel audio to a single mono channel (mean over channels).
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)
    # Resample only when the clip is not already at the target rate.
    if sample_rate != target_sr:
        audio = Resample(orig_freq=sample_rate, new_freq=target_sr)(audio)
        sample_rate = target_sr
    # Peak-normalize to target_peak; skip empty or silent clips to avoid
    # dividing by zero.
    peak = audio.abs().max().item() if audio.numel() else 0.0
    if peak > 0:
        audio = audio * (target_peak / peak)
    torchaudio.save(path, audio, sample_rate, bits_per_sample=16)
    return path
def _set_job(job_id: str, **kwargs) -> None:
    """Create or update a job record under the registry lock.

    Existing fields are preserved; *kwargs* override or extend them. The record
    is replaced wholesale (never mutated), so snapshots handed out by _get_job
    stay consistent.
    """
    with JOB_LOCK:
        record = dict(JOBS.get(job_id, {}))
        record.update(kwargs)
        JOBS[job_id] = record
def _get_job(job_id: str) -> Optional[Dict[str, str]]:
    """Return a snapshot copy of the job record, or None when unknown/empty."""
    with JOB_LOCK:
        record = JOBS.get(job_id)
        # Copy under the lock so the caller can't observe later updates.
        return dict(record) if record else None
def _pop_job(job_id: str) -> Optional[Dict[str, str]]:
    """Atomically remove and return the job record (None when absent)."""
    with JOB_LOCK:
        removed = JOBS.pop(job_id, None)
    return removed
| def _cleanup_files(*paths: str): | |
| for p in paths: | |
| try: | |
| os.remove(p) | |
| except OSError: | |
| pass | |
def _run_generate_job(job_id: str, payload: Dict[str, str]) -> None:
    """Background worker: synthesize speech for one job.

    Updates the job record through queued -> processing -> completed/error.
    On success the record gains an `output_file` path served by /result;
    on failure it gains an `error` message.
    """
    speaker_file = None
    output_file = None
    _set_job(job_id, status="processing")
    try:
        speaker_file = _temp_speaker_file(payload["speaker_wav"])
        speaker_file = _preprocess_audio_wav(speaker_file, target_sr=TARGET_SR)
        output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")
        # Use spk_audio_prompt — this model requires audio prompt
        tts_model.infer(
            text=payload["text"],
            spk_audio_prompt=speaker_file,
            output_path=output_file,
            use_random=False,
            verbose=False,
        )
        if not Path(output_file).exists():
            raise RuntimeError("TTS generation failed — output file not created.")
        _set_job(job_id, status="completed", output_file=output_file)
    except Exception as exc:
        # Only the (partial) output is discarded here; the result no longer exists.
        if output_file:
            _cleanup_files(output_file)
        _set_job(job_id, status="error", error=str(exc))
    finally:
        # BUG FIX: the speaker temp file was previously leaked on every
        # successful job — it is an input artifact, so remove it always.
        if speaker_file:
            _cleanup_files(speaker_file)
# BUG FIX: the handler was never registered with the app (no route decorator
# anywhere in the file), so the endpoint was unreachable.
@app.post("/generate")
def generate(
    payload: GenerateRequest = Body(...),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    x_api_key: Optional[str] = Header(default=None),
):
    """Queue a TTS job and return 202 with polling URLs.

    Generation runs in a background task; poll `status_url` until the job is
    completed, then fetch the audio from `result_url`.
    """
    _require_api_key(x_api_key)
    job_id = str(uuid.uuid4())
    _set_job(job_id, status="queued")
    # Pass a plain dict so the worker has no dependency on the pydantic model.
    background_tasks.add_task(_run_generate_job, job_id, payload.dict())
    return JSONResponse(
        status_code=202,
        content={
            "job_id": job_id,
            "status": "queued",
            "status_url": f"/status/{job_id}",
            "result_url": f"/result/{job_id}",
        },
    )
# BUG FIX: route decorator was missing, leaving the status endpoint
# unregistered despite /generate advertising `/status/{job_id}`.
@app.get("/status/{job_id}")
def status(job_id: str, x_api_key: Optional[str] = Header(default=None)):
    """Report the current state of a job: queued, processing, completed, or error.

    Raises 404 for unknown job ids; includes the error message for failed jobs.
    """
    _require_api_key(x_api_key)
    job = _get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    resp = {"job_id": job_id, "status": job.get("status", "unknown")}
    if "error" in job:
        resp["error"] = job["error"]
    return resp
# BUG FIX: route decorator was missing, leaving the result endpoint
# unregistered despite /generate advertising `/result/{job_id}`.
@app.get("/result/{job_id}")
def result(job_id: str, x_api_key: Optional[str] = Header(default=None)):
    """Return the generated WAV for a completed job, then expire the job.

    Raises 404 for unknown jobs, 409 when the job is not yet completed, and
    410 when the output file has disappeared. The result is one-shot: the job
    record is popped and the file deleted after the response is sent.
    """
    _require_api_key(x_api_key)
    job = _get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.get("status") != "completed":
        raise HTTPException(status_code=409, detail=f"Job not ready (status={job.get('status')})")
    output_file = job.get("output_file")
    if not output_file or not Path(output_file).exists():
        _pop_job(job_id)
        raise HTTPException(status_code=410, detail="Result missing or expired")
    # cleanup after sending — Starlette runs the background task once the
    # response body has been fully streamed to the client.
    background = BackgroundTasks()
    background.add_task(_cleanup_files, output_file)
    _pop_job(job_id)
    return FileResponse(output_file, media_type="audio/wav", filename="output.wav", background=background)