import base64
import hashlib
import inspect
import os
import tempfile
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from threading import Lock
from typing import Dict, Optional
from urllib.parse import urlparse

import requests
import torch
import torchaudio
from torchaudio.transforms import Resample

from fastapi import BackgroundTasks, Body, FastAPI, Header, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel, Field

# ----------------------------
# Config / Tunables
# ----------------------------
SPACE_API_KEY = os.getenv("SPACE_API_KEY")
HF_TOKEN = (
    os.getenv("HUGGING_FACE_HUB_TOKEN")
    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    or os.getenv("HF_TOKEN")
)
MODEL_REPO = os.getenv("MODEL_REPO", "IndexTeam/IndexTTS-2")
MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
os.makedirs(MODEL_DIR, exist_ok=True)

MAX_TEXT_LENGTH = int(os.getenv("MAX_TEXT_LENGTH", "1000"))
DEFAULT_LANGUAGE = os.getenv("DEFAULT_LANGUAGE", "en")
TARGET_SR = int(os.getenv("TARGET_SR", "16000"))  # lowered to 16 kHz for speed
TORCH_NUM_THREADS = int(os.getenv("TORCH_NUM_THREADS", "2"))

# Embedding cache settings
EMBED_CACHE_MAX = int(os.getenv("EMBED_CACHE_MAX", "128"))  # max entries
EMBED_CACHE_TTL = int(os.getenv("EMBED_CACHE_TTL", str(60 * 60 * 24)))  # 24 h by default

# Thread pool for bounded parallel jobs (keeps worker threads limited)
WORKER_COUNT = int(os.getenv("WORKER_COUNT", "1"))  # keep low on CPU
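
# As a reference point, a typical CPU-Space configuration might look like the
# following (illustrative values, not requirements; anything left unset falls
# back to the defaults above):
#
#   SPACE_API_KEY=change-me     # enables API-key auth on every endpoint
#   MODEL_DIR=/data/indextts2   # persistent storage for the model snapshot
#   TORCH_NUM_THREADS=2         # roughly match the available vCPUs
#   WORKER_COUNT=1              # raise only if there are cores to spare
#   EMBED_CACHE_MAX=128
#   EMBED_CACHE_TTL=86400       # seconds (24 h)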

# ----------------------------
# Torch settings
# ----------------------------
torch.set_num_threads(TORCH_NUM_THREADS)
try:
    # Optional: also limit interop threads (must be set before parallel work starts)
    torch.set_num_interop_threads(max(1, TORCH_NUM_THREADS // 2))
except Exception:
    pass

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ----------------------------
# Hugging Face login (if token)
# ----------------------------
if HF_TOKEN:
    os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
    os.environ["HF_TOKEN"] = HF_TOKEN
    try:
        from huggingface_hub import login

        login(token=HF_TOKEN, add_to_git_credential=False)
    except Exception:
        pass

# ----------------------------
# Optionally download model snapshot (only if missing)
# ----------------------------
try:
    from huggingface_hub import snapshot_download

    cfg_path = Path(MODEL_DIR) / "config.yaml"
    if not cfg_path.exists():
        print(f"Config missing; downloading model snapshot {MODEL_REPO} to {MODEL_DIR} ...")
        snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, token=HF_TOKEN)
        print("Download complete.")
except Exception as exc:
    print(f"Warning: snapshot_download skipped or failed: {exc}")

# ----------------------------
# Load IndexTTS2 model (CPU-safe defaults)
# ----------------------------
try:
    from indextts.infer_v2 import IndexTTS2
except Exception as exc:
    raise RuntimeError("indextts.infer_v2 import failed. Make sure IndexTTS2 is installed.") from exc

cfg_path = os.path.join(MODEL_DIR, "config.yaml")
if not Path(cfg_path).exists():
    raise FileNotFoundError(f"Config file not found at {cfg_path}. Place model files in {MODEL_DIR}.")

# CPU-safe options. If a GPU becomes available, you can toggle use_fp16 / use_cuda_kernel.
tts_model = IndexTTS2(
    cfg_path=cfg_path,
    model_dir=MODEL_DIR,
    use_fp16=False,  # FP16 is unreliable on CPU
    use_cuda_kernel=False,
    use_deepspeed=False,
)
print("IndexTTS2 loaded.")

# ----------------------------
# App + job state
# ----------------------------
app = FastAPI(title="indextts2-api-optimized", version="1.0.0")

JOBS: Dict[str, Dict[str, str]] = {}
JOB_LOCK = Lock()

# Thread pool for running TTS jobs; limits concurrency to WORKER_COUNT
EXECUTOR = ThreadPoolExecutor(max_workers=WORKER_COUNT)

# ----------------------------
# Simple LRU-like embedding cache (in-memory)
# ----------------------------
class _EmbedCacheEntry:
    def __init__(self, emb_tensor: torch.Tensor):
        self.emb = emb_tensor.detach().cpu()  # keep on CPU, detached
        self.ts = time.time()


EMBED_CACHE: Dict[str, _EmbedCacheEntry] = {}
EMBED_CACHE_LOCK = Lock()


def _evict_cache_if_needed():
    with EMBED_CACHE_LOCK:
        if len(EMBED_CACHE) <= EMBED_CACHE_MAX:
            return
        # Simple eviction: drop the oldest entries first
        items = sorted(EMBED_CACHE.items(), key=lambda kv: kv[1].ts)
        for key, _ in items[: max(1, len(items) - EMBED_CACHE_MAX)]:
            EMBED_CACHE.pop(key, None)

def _get_cache_key_for_file(path: str) -> str:
    # Hash the file contents (fast enough for short audio)
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            chunk = f.read(8192)
            if not chunk:
                break
            h.update(chunk)
    return h.hexdigest()


def _cache_get(key: str) -> Optional[torch.Tensor]:
    with EMBED_CACHE_LOCK:
        entry = EMBED_CACHE.get(key)
        if not entry:
            return None
        if (time.time() - entry.ts) > EMBED_CACHE_TTL:
            EMBED_CACHE.pop(key, None)
            return None
        # Refresh the timestamp for LRU-ish behavior
        entry.ts = time.time()
        return entry.emb.clone()

def _cache_set(key: str, emb: torch.Tensor):
    with EMBED_CACHE_LOCK:
        EMBED_CACHE[key] = _EmbedCacheEntry(emb)
    # Evict outside the critical section: EMBED_CACHE_LOCK is not reentrant,
    # and _evict_cache_if_needed() acquires it again.
    _evict_cache_if_needed()

# ----------------------------
# Utilities for audio input handling
# ----------------------------
def _write_temp_audio_from_url(url: str) -> str:
    response = requests.get(url, stream=True, timeout=30)
    if response.status_code >= 400:
        raise HTTPException(status_code=400, detail=f"Could not fetch speaker audio: {response.status_code}")
    suffix = Path(urlparse(url).path).suffix or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                tmp.write(chunk)
    return tmp.name


def _write_temp_audio_from_base64(payload: str) -> str:
    try:
        raw = base64.b64decode(payload)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav") from exc
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(raw)
    return tmp.name


def _temp_speaker_file(speaker_wav: str) -> str:
    # Keep the URL as a plain string: constructing pydantic's HttpUrl directly
    # behaves differently across pydantic versions, and requests wants a str anyway.
    if speaker_wav.startswith(("http://", "https://")):
        return _write_temp_audio_from_url(speaker_wav)
    return _write_temp_audio_from_base64(speaker_wav)

def _preprocess_audio_wav(path: str, target_sr: int = TARGET_SR, target_peak: float = 0.98) -> str:
    """
    Convert to mono, resample to target_sr, and peak-limit (attenuate only,
    never amplify). Overwrites the input file in place.
    """
    wav, sr = torchaudio.load(path)
    # Downmix to mono
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    # Resample if needed
    if sr != target_sr:
        resampler = Resample(orig_freq=sr, new_freq=target_sr)
        wav = resampler(wav)
        sr = target_sr
    # Scale down if the peak exceeds target_peak; quiet audio is left untouched
    peak = wav.abs().max().item() if wav.numel() else 0.0
    if peak > 0:
        scale = min(target_peak / peak, 1.0)
        wav = wav * scale
    # Overwrite the file as 16-bit PCM
    torchaudio.save(path, wav, sr, encoding="PCM_S", bits_per_sample=16)
    return path

# ----------------------------
# Embedding extraction helper (tries multiple API variants)
# ----------------------------
def _compute_spk_embedding(speaker_path: str) -> torch.Tensor:
    """
    Returns a CPU tensor containing the speaker embedding.
    Tries several extraction methods, since the exposed API varies across
    IndexTTS2 versions (get_spk_emb, extract_spk_emb, spk_encoder, ...).
    """
    # Cache key: hash of the file contents
    key = _get_cache_key_for_file(speaker_path)
    cached = _cache_get(key)
    if cached is not None:
        return cached
    # Ensure the audio is preprocessed (mono / resample / normalize); this is
    # effectively a no-op if the caller already preprocessed the file.
    _preprocess_audio_wav(speaker_path, target_sr=TARGET_SR)
    # Try known wrapper method names (depending on the IndexTTS2 version)
    emb = None
    try:
        if hasattr(tts_model, "get_spk_emb"):
            emb = tts_model.get_spk_emb(speaker_path)
        elif hasattr(tts_model, "extract_spk_emb"):
            emb = tts_model.extract_spk_emb(speaker_path)
        elif hasattr(tts_model, "spk_encoder") and hasattr(tts_model.spk_encoder, "embed_utterance"):
            # Some wrappers expose their internal speaker encoder directly
            wav, sr = torchaudio.load(speaker_path)
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            wav = wav.squeeze(0).numpy()  # 1-D float array, as some encoders expect
            emb = tts_model.spk_encoder.embed_utterance(wav)
            emb = torch.from_numpy(emb)
        else:
            raise RuntimeError("No known speaker embedding method available on tts_model.")
    except Exception as exc:
        # Re-raise with context. The caller (_run_generate_job) catches this and
        # falls back to passing the audio path to infer(), which computes the
        # embedding internally.
        raise RuntimeError(f"Failed to compute speaker embedding: {exc}") from exc
    # Normalize & store on CPU as float32
    if isinstance(emb, torch.Tensor):
        emb_cpu = emb.detach().cpu().float()
    else:
        emb_cpu = torch.tensor(emb, dtype=torch.float32, device="cpu")
    _cache_set(key, emb_cpu)
    return emb_cpu
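
# Optional warm-up, a sketch under the assumption that the deployment mostly
# serves one known reference voice; the path below is hypothetical. Computing
# the embedding once at startup hides the extraction cost on the first request.
#
# if os.path.exists("/data/ref_voice.wav"):
#     try:
#         _compute_spk_embedding("/data/ref_voice.wav")
#     except Exception as exc:
#         print(f"Embedding warm-up skipped: {exc}")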

# ----------------------------
# Job helpers
# ----------------------------
def _set_job(job_id: str, **kwargs):
    with JOB_LOCK:
        JOBS[job_id] = {**JOBS.get(job_id, {}), **kwargs}


def _get_job(job_id: str) -> Optional[Dict[str, str]]:
    with JOB_LOCK:
        data = JOBS.get(job_id)
        return dict(data) if data else None


def _pop_job(job_id: str) -> Optional[Dict[str, str]]:
    with JOB_LOCK:
        return JOBS.pop(job_id, None)


def _cleanup_files(*files: str):
    for file_path in files:
        if file_path and Path(file_path).exists():
            try:
                Path(file_path).unlink(missing_ok=True)
            except Exception:
                pass

def _run_generate_job(job_id: str, payload: Dict[str, str]):
    """
    Worker function: computes (or reuses) the speaker embedding and performs TTS.
    """
    speaker_file = None
    output_file = None
    _set_job(job_id, status="processing")
    try:
        # Prepare the speaker audio (download or decode it to a temp file)
        speaker_file = _temp_speaker_file(payload["speaker_wav"])
        # Preprocess (mono + resample + normalize)
        speaker_file = _preprocess_audio_wav(speaker_file, target_sr=TARGET_SR)
        # Compute or fetch the cached embedding
        try:
            spk_emb = _compute_spk_embedding(speaker_file)
            use_spk_emb = True
        except Exception as exc_emb:
            # If embedding extraction fails, fall back to passing the audio path to infer()
            spk_emb = None
            use_spk_emb = False
            print(f"Warning: embedding extraction failed, falling back to audio prompt: {exc_emb}")
        output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")
        # Call inference: prefer the embedding when available.
        infer_kwargs = {
            "text": payload["text"],
            "output_path": output_file,
            "use_random": False,
            "verbose": False,
        }
        # Only pass sample_rate if this wrapper's infer() actually accepts it
        if "sample_rate" in inspect.signature(tts_model.infer).parameters:
            infer_kwargs["sample_rate"] = TARGET_SR
        if use_spk_emb and spk_emb is not None:
            # Embedding path: the argument name varies between wrapper versions
            try:
                tts_model.infer(spk_emb=spk_emb, **infer_kwargs)
            except TypeError:
                tts_model.infer(speaker_emb=spk_emb, **infer_kwargs)
        else:
            # Audio-prompt path (slower: the model computes the embedding internally)
            tts_model.infer(spk_audio_prompt=speaker_file, **infer_kwargs)
        # Minimal validation: make sure the output file was created
        if not Path(output_file).exists():
            raise RuntimeError(f"TTS generation failed: output file not created at {output_file}")
        # Do NOT re-run the heavy preprocess; only resample if the model returned a different sr (rare)
        try:
            out_wav, out_sr = torchaudio.load(output_file)
            if out_sr != TARGET_SR:
                resampler = Resample(orig_freq=out_sr, new_freq=TARGET_SR)
                out_wav = resampler(out_wav)
                torchaudio.save(output_file, out_wav, TARGET_SR, encoding="PCM_S", bits_per_sample=16)
        except Exception:
            # If this fails, still return the original output file
            pass
        # Clean up the speaker temp file (the output is kept until the client downloads it)
        if speaker_file:
            try:
                Path(speaker_file).unlink(missing_ok=True)
            except Exception:
                pass
        _set_job(job_id, status="completed", output_file=output_file)
    except Exception as exc:
        _cleanup_files(speaker_file, output_file)
        _set_job(job_id, status="error", error=str(exc))

# ----------------------------
# FastAPI endpoints
# ----------------------------
class GenerateRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
    speaker_wav: str = Field(..., description="HTTPS URL or base64-encoded audio")
    language: Optional[str] = Field(DEFAULT_LANGUAGE, description="ISO code, default en")
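
# Example request body for POST /generate (illustrative; the URL is a placeholder):
#
#   {
#     "text": "Hello from IndexTTS2.",
#     "speaker_wav": "https://example.com/reference.wav",
#     "language": "en"
#   }
#
# Note that `language` is accepted for forward compatibility but is not
# currently forwarded to the model by this server.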

def _require_api_key(x_api_key: Optional[str]):
    if not SPACE_API_KEY:
        return
    if x_api_key != SPACE_API_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")


@app.get("/health")
def health(x_api_key: Optional[str] = Header(default=None)):
    _require_api_key(x_api_key)
    return {"status": "ok", "model": "indextts2", "device": DEVICE, "torch_threads": torch.get_num_threads()}

@app.post("/generate")
def generate(
    payload: GenerateRequest = Body(...),
    x_api_key: Optional[str] = Header(default=None),
):
    _require_api_key(x_api_key)
    job_id = str(uuid.uuid4())
    _set_job(job_id, status="queued")
    # Submit to the bounded thread pool to avoid uncontrolled concurrency on CPU
    EXECUTOR.submit(_run_generate_job, job_id, payload.dict())
    return JSONResponse(
        status_code=202,
        content={
            "job_id": job_id,
            "status": "queued",
            "status_url": f"/status/{job_id}",
            "result_url": f"/result/{job_id}",
        },
    )

@app.get("/status/{job_id}")
def job_status(job_id: str, x_api_key: Optional[str] = Header(default=None)):
    _require_api_key(x_api_key)
    job = _get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    payload: Dict[str, str] = {"job_id": job_id, "status": job.get("status", "unknown")}
    if "error" in job:
        payload["error"] = job["error"]
    return payload

@app.get("/result/{job_id}")
def job_result(
    job_id: str,
    background_tasks: BackgroundTasks,
    x_api_key: Optional[str] = Header(default=None),
):
    _require_api_key(x_api_key)
    job = _get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    status = job.get("status")
    if status != "completed":
        raise HTTPException(status_code=409, detail=f"Job not ready (status={status})")
    output_file = job.get("output_file")
    if not output_file or not Path(output_file).exists():
        _pop_job(job_id)
        raise HTTPException(status_code=410, detail="Result expired or missing")
    # Remove the job from memory and delete the output file after it is sent
    _pop_job(job_id)
    background_tasks.add_task(_cleanup_files, output_file)
    return FileResponse(output_file, media_type="audio/wav", filename="output.wav")

@app.get("/")
def root():
    return {"name": "indextts2-api-optimized", "endpoints": ["/health", "/generate", "/status/{job_id}", "/result/{job_id}"]}
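
# Local entry point: a minimal sketch assuming uvicorn is installed. Port 7860
# is the Hugging Face Spaces convention; adjust for other deployments.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example client flow (illustrative shell commands; host and API key are placeholders):
#
#   curl -s -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" -H "X-API-Key: change-me" \
#     -d '{"text": "Hello.", "speaker_wav": "https://example.com/ref.wav"}'
#   # -> {"job_id": "...", "status": "queued", ...}
#
#   curl -s http://localhost:7860/status/<job_id> -H "X-API-Key: change-me"
#   curl -s -o output.wav http://localhost:7860/result/<job_id> -H "X-API-Key: change-me"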
| return {"name": "indextts2-api-optimized", "endpoints": ["/health", "/generate", "/status/{job_id}", "/result/{job_id}"]} | |