import importlib
import io
import json
import os
import re
from typing import Any

import torch

# Compatibility shim: some ``transformers`` releases do not provide
# ``transformers.pytorch_utils.isin_mps_friendly``. It is installed here, before
# ``from TTS.api import TTS`` below, presumably because the TTS package (or code
# it imports) looks the helper up at import time.
_pt_utils = importlib.import_module("transformers.pytorch_utils")
if not hasattr(_pt_utils, "isin_mps_friendly"):

    def _isin_mps_friendly(
        elements: torch.Tensor, test_elements: torch.Tensor | int
    ) -> torch.Tensor:
        """Return ``torch.isin(elements, test_elements)`` with MPS and int handling.

        Args:
            elements: Tensor whose entries are tested for membership.
            test_elements: Tensor (or plain int / sequence of ints) holding the
                values to look for.

        Returns:
            Boolean tensor shaped like ``elements``, on ``elements.device``.
        """
        if not isinstance(test_elements, torch.Tensor):
            # The upstream helper's signature also allows a bare int; wrap it so
            # the ``.device`` checks below work.
            test_elements = torch.as_tensor(test_elements, device=elements.device)
        if "mps" in (elements.device.type, test_elements.device.type):
            # Run the membership test on CPU with BOTH operands there (moving only
            # one side leaves a device mismatch), then hand the result back on
            # the caller's device.
            return torch.isin(elements.cpu(), test_elements.cpu()).to(elements.device)
        return torch.isin(elements, test_elements.to(elements.device))

    _pt_utils.isin_mps_friendly = _isin_mps_friendly

import numpy as np
import soundfile as sf
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field, field_validator
from TTS.api import TTS

HOST = "0.0.0.0"  # listen on all interfaces
PORT = 7860
# Speaker used for English requests that do not name one.
DEFAULT_SPEAKER = os.environ.get("COQUI_DEFAULT_SPEAKER", "p228")

# Hugging Face repo id per supported language code; each is overridable via env.
REPOS: dict[str, str] = {
    "en": os.environ.get("HF_TTS_EN_REPO", "Resilient-Coders/coqui-vctk-en"),
    "es": os.environ.get("HF_TTS_ES_REPO", "Resilient-Coders/coqui-css10-es"),
    "vi": os.environ.get("HF_TTS_VI_REPO", "Resilient-Coders/mms-tts-vie"),
}

# Vietnamese uses the Fairseq format. Coqui loads it via model_name (a model_dir
# path), which calls _load_fairseq_from_dir and never reads config.json.
# We mirror the HF snapshot files into TTS_HOME so the model_name lookup finds them.
# Stdlib modules used by the service code below: asyncio/threading keep blocking
# model work off the event loop and guard the shared model cache; shutil is the
# copy fallback when symlinking is not permitted.
import asyncio
import shutil
import threading


def _coqui_tts_home() -> str:
    """Return the directory Coqui's model manager uses for downloaded models.

    Follows the lookup order of Coqui's ``get_user_data_dir`` on Linux:
    ``$TTS_HOME``, then ``$XDG_DATA_HOME``, falling back to ``~/.local/share``;
    models live in a ``tts`` sub-directory of that base.
    NOTE(review): re-check against the installed TTS version if it changes.
    """
    base = os.environ.get("TTS_HOME")
    if base is None:
        base = os.environ.get("XDG_DATA_HOME")
    if base is None:
        base = os.path.join(os.path.expanduser("~"), ".local", "share")
    return os.path.join(os.path.abspath(os.path.expanduser(base)), "tts")


TTS_HOME = _coqui_tts_home()
VI_MODEL_NAME = "tts_models/vie/fairseq/vits"
# Directory name derived from VI_MODEL_NAME ("/" becomes "--").
VI_TTS_HOME_DIR = os.path.join(TTS_HOME, "tts_models--vie--fairseq--vits")

# Weight file names probed, in order, inside a downloaded snapshot.
WEIGHT_FILE_CANDIDATES = ["model.pth", "model_file.pth.tar", "model_file.pth"]

# Default soft limit (in characters) for one chunk built by split_sentences.
MAX_CHUNK_CHARS = 200


def resolve_weights(local_dir: str) -> str:
    """Return the first existing file from WEIGHT_FILE_CANDIDATES in *local_dir*.

    Raises:
        RuntimeError: if none of the candidate files exists.
    """
    for name in WEIGHT_FILE_CANDIDATES:
        path = os.path.join(local_dir, name)
        if os.path.isfile(path):
            return path
    raise RuntimeError(f"No weight file found in {local_dir}")


app = FastAPI(title="aiDoc TTS Space", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Loaded models keyed by language code; filled lazily by get_tts().
tts_instances: dict[str, TTS] = {}
# Serializes model loading: get_tts runs in worker threads, and two concurrent
# first requests must not download/load the same model twice.
_tts_load_lock = threading.Lock()
# One lock per language so a model runs at most one synthesis at a time.
# NOTE(review): conservative -- Coqui's TTS object is not documented as
# thread-safe, and before threading was introduced all synthesis was serialized.
_tts_infer_locks: dict[str, threading.Lock] = {lang: threading.Lock() for lang in REPOS}


@app.on_event("startup")
async def preload_models() -> None:
    """Load every configured model before serving, one at a time, off the event loop."""
    for lang in REPOS:
        await asyncio.to_thread(get_tts, lang)


class SynthesizeRequest(BaseModel):
    """Request body for ``POST /synthesize``."""

    text: str = Field(min_length=1)
    # Speaker name; only used for "en" (falls back to DEFAULT_SPEAKER).
    speaker_idx: str | None = None
    language: str = "en"

    @field_validator("language")
    @classmethod
    def normalize_language(cls, v: str) -> str:
        """Strip/lower-case the code (empty -> "en") and reject codes not in REPOS."""
        key = (v or "en").strip().lower()
        if key not in REPOS:
            raise ValueError(f"Unsupported language: {v!r}. Use one of: {', '.join(sorted(REPOS))}.")
        return key


# Config keys holding file paths that may point at another machine's filesystem.
PATH_KEYS = ("speakers_file", "speaker_ids_file", "d_vector_file")


def _patch_dict(obj: dict, local_dir: str) -> bool:
    """Recursively fix off-machine absolute paths in a config dict.

    A PATH_KEYS entry whose file does not exist is redirected to the file with
    the same basename inside *local_dir*, when that file exists. *obj* is
    modified in place (nested dicts included).

    Returns:
        True if anything changed.
    """
    changed = False
    for key, val in obj.items():
        if isinstance(val, dict):
            if _patch_dict(val, local_dir):
                changed = True
        elif key in PATH_KEYS and isinstance(val, str) and val and not os.path.isfile(val):
            candidate = os.path.join(local_dir, os.path.basename(val))
            if os.path.isfile(candidate):
                obj[key] = candidate
                changed = True
                print(f"[tts] patched config key {key!r} -> {candidate}", flush=True)
    return changed


def patch_config(local_dir: str) -> str:
    """Patch any off-machine absolute paths in config.json, overwriting in place.

    The config stores paths in both top-level and nested (model_args) dicts.
    We resolve the symlink to the actual HF blob, chmod it writable, patch all
    occurrences, and overwrite in place. Safe in a container that resets each run.

    Returns:
        The (unresolved) ``config.json`` path inside *local_dir*.
    """
    config_path = os.path.join(local_dir, "config.json")
    real_path = os.path.realpath(config_path)
    with open(real_path, encoding="utf-8") as f:
        cfg = json.load(f)
    if _patch_dict(cfg, local_dir):
        try:
            os.chmod(real_path, 0o644)
        except OSError as e:
            # Best effort: the write below reports the real failure if any.
            print(f"[tts] chmod warning: {e}", flush=True)
        with open(real_path, "w", encoding="utf-8") as f:
            json.dump(cfg, f)
        print(f"[tts] wrote patched config to {real_path}", flush=True)
    return config_path


def setup_fairseq_vi(local_dir: str) -> None:
    """Mirror HF snapshot files for the Vietnamese fairseq model into TTS_HOME.

    Coqui's fairseq loader uses model_name -> model_dir -> _load_fairseq_from_dir,
    which creates a blank VitsConfig and never reads config.json. Setting up the
    TTS_HOME directory lets us use model_name without re-downloading from Coqui's
    (defunct) registry, and avoids the config format incompatibility.

    Regular files are symlinked (copied if symlinking fails); hidden entries and
    non-files are skipped. Existing entries are kept, except dangling symlinks,
    which are replaced.
    """
    os.makedirs(VI_TTS_HOME_DIR, exist_ok=True)
    for fname in os.listdir(local_dir):
        if fname.startswith("."):
            continue
        src = os.path.realpath(os.path.join(local_dir, fname))
        if not os.path.isfile(src):
            continue
        dst = os.path.join(VI_TTS_HOME_DIR, fname)
        if os.path.lexists(dst):
            if os.path.exists(dst):
                continue  # already mirrored
            # Dangling symlink (its target blob is gone): os.path.exists() is
            # False for it, but os.symlink would still fail on it -- remove it.
            os.remove(dst)
        try:
            os.symlink(src, dst)
            action = "linked"
        except OSError:
            shutil.copy2(src, dst)
            action = "copied"
        print(f"[tts] vi: {action} {fname}", flush=True)


def _load_tts(lang: str) -> TTS:
    """Download the snapshot for *lang* and build a CPU TTS instance (no caching)."""
    repo_id = REPOS[lang]
    print(f"[tts] downloading repo for {lang}: {repo_id}", flush=True)
    local_dir = snapshot_download(repo_id=repo_id)
    if lang == "vi":
        # Fairseq format: use model_name so Coqui routes through
        # _load_fairseq_from_dir (blank VitsConfig, bypasses config.json parse).
        setup_fairseq_vi(local_dir)
        print(f"[tts] loading vi via model_name={VI_MODEL_NAME}", flush=True)
        return TTS(model_name=VI_MODEL_NAME, progress_bar=False).to("cpu")
    weights = resolve_weights(local_dir)
    config_path = patch_config(local_dir)
    print(f"[tts] loading {weights}", flush=True)
    return TTS(model_path=weights, config_path=config_path, progress_bar=False).to("cpu")


def get_tts(lang: str) -> TTS:
    """Return the cached model for *lang*, loading it on first use.

    Safe to call from several threads: loading happens at most once per language.

    Raises:
        HTTPException: 400 if *lang* is not a key of REPOS.
    """
    if lang not in REPOS:
        raise HTTPException(status_code=400, detail=f"Unsupported language: {lang}")
    model = tts_instances.get(lang)
    if model is not None:
        return model
    with _tts_load_lock:
        # Re-check: another thread may have finished loading while we waited.
        if lang not in tts_instances:
            tts_instances[lang] = _load_tts(lang)
        return tts_instances[lang]


def get_speakers(model: TTS) -> list[str]:
    """Best-effort list of speaker names from the model's speaker manager.

    Tries ``speaker_names`` (list), then the keys of ``name_to_id`` or
    ``speakers`` (dicts); returns [] when none of these is available.
    """
    manager = getattr(getattr(model, "synthesizer", None), "tts_model", None)
    speaker_manager = getattr(manager, "speaker_manager", None)
    if speaker_manager is None:
        return []
    speaker_names: Any = getattr(speaker_manager, "speaker_names", None)
    if isinstance(speaker_names, list):
        return [str(name) for name in speaker_names]
    name_to_id: Any = getattr(speaker_manager, "name_to_id", None)
    if isinstance(name_to_id, dict):
        return [str(name) for name in name_to_id]
    speakers_map: Any = getattr(speaker_manager, "speakers", None)
    if isinstance(speakers_map, dict):
        return [str(name) for name in speakers_map]
    return []


def resolve_sample_rate(model: TTS) -> int:
    """Return the synthesizer's positive integer output rate, else 22050 Hz."""
    synthesizer = getattr(model, "synthesizer", None)
    rate = getattr(synthesizer, "output_sample_rate", None) if synthesizer else None
    if isinstance(rate, int) and rate > 0:
        return rate
    return 22050


@app.get("/")
async def root() -> dict[str, Any]:
    """Describe the service and list its endpoints."""
    return {
        "service": "aidoc-tts",
        "endpoints": ["/health", "/speakers", "/synthesize"],
    }


@app.get("/health")
async def health() -> dict[str, Any]:
    """Report liveness plus loaded and supported languages."""
    return {
        "status": "ok",
        "device": "cpu",
        "loaded_languages": sorted(tts_instances.keys()),
        "supported_languages": sorted(REPOS.keys()),
    }


@app.get("/speakers")
async def speakers(language: str = "en") -> dict[str, list[str]]:
    """List speaker names for *language* (query param, default "en")."""
    lang = (language or "en").strip().lower()
    # get_tts may download/load a model; keep that off the event loop.
    model = await asyncio.to_thread(get_tts, lang)
    return {"speakers": get_speakers(model)}


def split_sentences(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list[str]:
    """Normalize *text* and group its sentences into chunks for synthesis.

    Newlines become spaces, bullet characters are removed, and runs of
    whitespace collapse. The text is split after ``.``, ``!`` or ``?`` and
    consecutive sentences are joined (with one space) while the chunk stays
    within *max_chars*. A single sentence longer than *max_chars* is kept whole.

    Returns:
        Non-empty chunks in order (empty list for blank input).
    """
    text = re.sub(r"[\r\n]+", " ", text)
    text = re.sub(r"[\u2022\u00b7\u2023\u25aa\u25b8\u25ba]+", "", text)
    text = re.sub(r"\s{2,}", " ", text).strip()
    chunks: list[str] = []
    current = ""
    for piece in re.split(r"(?<=[.!?])\s+", text):
        piece = piece.strip()
        if not piece:
            continue
        # +1 for the space inserted when the piece is appended to `current`.
        if current and len(current) + 1 + len(piece) > max_chars:
            chunks.append(current)
            current = piece
        else:
            current = f"{current} {piece}" if current else piece
    if current:
        chunks.append(current)
    return chunks


def _synthesize_wav(payload: SynthesizeRequest) -> bytes:
    """Blocking part of /synthesize: render *payload* to WAV bytes.

    Sentences that raise during synthesis are logged and skipped.

    Raises:
        HTTPException: 400 for blank text or an unknown explicit speaker;
            500 if every sentence failed.
    """
    lang = payload.language
    sentences = split_sentences(payload.text)
    if not sentences:
        raise HTTPException(status_code=400, detail="No speakable text provided")
    model = get_tts(lang)
    sample_rate = resolve_sample_rate(model)

    speaker: str | None = None
    if lang == "en":
        speaker = payload.speaker_idx or DEFAULT_SPEAKER
        known = get_speakers(model)
        # Reject a client-supplied unknown speaker up front; otherwise every
        # sentence would fail and the client would see a misleading 500.
        if payload.speaker_idx and known and payload.speaker_idx not in known:
            raise HTTPException(status_code=400, detail=f"Unknown speaker: {payload.speaker_idx!r}")

    audio_parts: list[np.ndarray] = []
    with _tts_infer_locks[lang]:
        for sentence in sentences:
            try:
                if speaker is not None:
                    wav = model.tts(text=sentence, speaker=speaker)
                else:
                    wav = model.tts(text=sentence)
                audio_parts.append(np.array(wav, dtype=np.float32))
            except Exception as error:
                # Deliberate best effort: one bad sentence should not sink the request.
                print(f"[tts] skipping sentence due to error: {error!r}", flush=True)

    if not audio_parts:
        raise HTTPException(status_code=500, detail="All sentences failed to synthesize")
    buffer = io.BytesIO()
    sf.write(buffer, np.concatenate(audio_parts), samplerate=sample_rate, format="WAV")
    return buffer.getvalue()


@app.post("/synthesize")
async def synthesize(payload: SynthesizeRequest) -> Response:
    """Synthesize ``payload.text`` and return it as a single ``audio/wav`` file."""
    # Model loading and inference are CPU-bound and blocking; run them in a
    # worker thread so /health and other requests stay responsive.
    wav_bytes = await asyncio.to_thread(_synthesize_wav, payload)
    return Response(content=wav_bytes, media_type="audio/wav")


if __name__ == "__main__":
    # Pass the app object rather than an "app:app" import string so this works
    # whatever the module's filename is (a string is only needed for reload).
    uvicorn.run(app, host=HOST, port=PORT, reload=False)