Spaces:
Sleeping
Sleeping
| import io | |
| import json | |
| import os | |
| import re | |
| from typing import Any | |
| import importlib | |
| import torch | |
# Compatibility shim: if the installed ``transformers.pytorch_utils`` module has
# no ``isin_mps_friendly`` attribute, install a fallback on it.
# NOTE(review): presumably needed by a package imported further down (the TTS
# import follows this shim) -- confirm which import actually looks it up.
_pt_utils = importlib.import_module("transformers.pytorch_utils")

if not hasattr(_pt_utils, "isin_mps_friendly"):

    def _isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
        """Element-wise membership test that avoids running ``torch.isin`` on MPS.

        Args:
            elements: Tensor whose entries are tested for membership.
            test_elements: Values to test against. A non-tensor value (for
                example a plain int) is wrapped in a tensor first.

        Returns:
            Boolean tensor shaped like ``elements``, placed on ``elements``'
            device.
        """
        if not isinstance(test_elements, torch.Tensor):
            test_elements = torch.tensor(test_elements)
        if "mps" in (elements.device.type, test_elements.device.type):
            # Move BOTH operands to the CPU: moving only one of them leaves the
            # two tensors on different devices, which torch.isin rejects.
            return torch.isin(elements.cpu(), test_elements.cpu()).to(elements.device)
        return torch.isin(elements, test_elements)

    _pt_utils.isin_mps_friendly = _isin_mps_friendly
| import numpy as np | |
| import soundfile as sf | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import Response | |
| from huggingface_hub import snapshot_download | |
| from pydantic import BaseModel, Field, field_validator | |
| from TTS.api import TTS | |
# Network binding used when the module is run as a script.
HOST = "0.0.0.0"
PORT = 7860

# Speaker used for English synthesis when the request does not name one.
DEFAULT_SPEAKER = os.getenv("COQUI_DEFAULT_SPEAKER", "p228")

# Language code -> Hugging Face repo holding that language's model.
# Each entry can be overridden through its environment variable.
REPOS: dict[str, str] = {
    "en": os.getenv("HF_TTS_EN_REPO", "Resilient-Coders/coqui-vctk-en"),
    "es": os.getenv("HF_TTS_ES_REPO", "Resilient-Coders/coqui-css10-es"),
    "vi": os.getenv("HF_TTS_VI_REPO", "Resilient-Coders/mms-tts-vie"),
}

# The Vietnamese model is in Fairseq format. Coqui loads it by model_name
# (resolved to a model directory), going through _load_fairseq_from_dir, which
# never reads config.json. The HF snapshot files are therefore mirrored into
# TTS_HOME so that the model_name lookup finds them locally.
TTS_HOME = os.path.join(os.path.expanduser("~"), ".local", "share", "tts")
VI_MODEL_NAME = "tts_models/vie/fairseq/vits"
# Directory name is the model name with "/" replaced by "--".
VI_TTS_HOME_DIR = os.path.join(TTS_HOME, VI_MODEL_NAME.replace("/", "--"))
# Checkpoint file names probed inside a snapshot directory, in priority order.
WEIGHT_FILE_CANDIDATES = ["model.pth", "model_file.pth.tar", "model_file.pth"]


def resolve_weights(local_dir: str) -> str:
    """Return the path of the first known checkpoint file found in ``local_dir``.

    Args:
        local_dir: Directory to search (non-recursive).

    Returns:
        ``local_dir`` joined with the first name from
        ``WEIGHT_FILE_CANDIDATES`` that exists as a regular file.

    Raises:
        RuntimeError: If none of the candidate files exists.
    """
    candidate_paths = (os.path.join(local_dir, name) for name in WEIGHT_FILE_CANDIDATES)
    found = next((path for path in candidate_paths if os.path.isfile(path)), None)
    if found is None:
        raise RuntimeError(f"No weight file found in {local_dir}")
    return found
# ASGI application object. CORS is fully open: every origin, method and
# header is allowed.
app = FastAPI(title="aiDoc TTS Space", version="1.0.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# Loaded models keyed by language code; filled on demand by get_tts().
tts_instances: dict[str, TTS] = {}
@app.on_event("startup")
async def preload_models() -> None:
    """Load the model of every configured language when the app starts.

    Languages are loaded one after another, in ``REPOS`` order. ``get_tts`` is
    blocking (it downloads and loads a model), so each call runs in a worker
    thread rather than on the event loop.
    """
    import asyncio

    for lang in REPOS:
        # to_thread replaces get_event_loop().run_in_executor(None, ...):
        # get_event_loop() is discouraged inside coroutines.
        await asyncio.to_thread(get_tts, lang)
class SynthesizeRequest(BaseModel):
    """Request body for speech synthesis."""

    # Text to speak; must contain at least one character.
    text: str = Field(min_length=1)
    # Optional speaker name. synthesize() only passes it to the model for "en".
    speaker_idx: str | None = None
    # Language code; lower-cased and checked against REPOS by the validator.
    language: str = "en"

    @field_validator("language")
    @classmethod
    def normalize_language(cls, v: str) -> str:
        """Normalize ``language`` and reject codes that have no configured repo.

        Args:
            v: Language code as supplied by the client.

        Returns:
            The code stripped and lower-cased; an empty value becomes ``"en"``.

        Raises:
            ValueError: If the normalized code is not a key of ``REPOS``.
        """
        key = (v or "en").strip().lower()
        if key not in REPOS:
            raise ValueError(f"Unsupported language: {v!r}. Use one of: {', '.join(sorted(REPOS))}.")
        return key
# Config keys whose values are file paths that may point at another machine.
PATH_KEYS = ("speakers_file", "speaker_ids_file", "d_vector_file")


def _patch_dict(obj: dict, local_dir: str) -> bool:
    """Recursively repoint stale file paths in a config dict, in place.

    A value is rewritten when its key is in ``PATH_KEYS``, it is a non-empty
    string that is not an existing file, and a file with the same base name
    exists in ``local_dir``. Nested dicts are visited, as are dicts held
    directly inside lists.

    Args:
        obj: Config mapping to patch; mutated in place.
        local_dir: Directory searched for a same-named replacement file.

    Returns:
        True if at least one value was rewritten, otherwise False.
    """
    changed = False
    for key, val in obj.items():
        if isinstance(val, dict):
            if _patch_dict(val, local_dir):
                changed = True
        elif isinstance(val, list):
            # Dicts stored inside lists would otherwise be skipped entirely.
            for item in val:
                if isinstance(item, dict) and _patch_dict(item, local_dir):
                    changed = True
        elif key in PATH_KEYS and isinstance(val, str) and val and not os.path.isfile(val):
            candidate = os.path.join(local_dir, os.path.basename(val))
            if os.path.isfile(candidate):
                # Assigning to an existing key does not resize the dict, so it
                # is safe while iterating over items().
                obj[key] = candidate
                changed = True
                print(f"[tts] patched config key {key!r} -> {candidate}", flush=True)
    return changed
def patch_config(local_dir: str) -> str:
    """Patch off-machine absolute paths in ``config.json``, overwriting in place.

    The config stores paths both at the top level and in nested dicts such as
    ``model_args``. ``config.json`` is resolved through any symlink to the real
    file (the HF blob), which is made writable and rewritten only when a path
    actually changed. This is safe in a container that resets on each run.

    Args:
        local_dir: Snapshot directory containing ``config.json``.

    Returns:
        The unresolved path ``<local_dir>/config.json``.
    """
    config_path = os.path.join(local_dir, "config.json")
    real_path = os.path.realpath(config_path)

    # JSON is UTF-8 by specification; do not depend on the locale's encoding.
    with open(real_path, encoding="utf-8") as f:
        cfg = json.load(f)

    if not _patch_dict(cfg, local_dir):
        return config_path

    try:
        os.chmod(real_path, 0o644)
    except OSError as e:
        # Best effort: the write below still surfaces a real permission problem.
        print(f"[tts] chmod warning: {e}", flush=True)
    with open(real_path, "w", encoding="utf-8") as f:
        json.dump(cfg, f)
    print(f"[tts] wrote patched config to {real_path}", flush=True)
    return config_path
| def setup_fairseq_vi(local_dir: str) -> None: | |
| """Mirror HF snapshot files for the Vietnamese fairseq model into TTS_HOME. | |
| Coqui's fairseq loader uses model_name -> model_dir -> _load_fairseq_from_dir, | |
| which creates a blank VitsConfig and never reads config.json. Setting up the | |
| TTS_HOME directory lets us use model_name without re-downloading from Coqui's | |
| (defunct) registry, and avoids the config format incompatibility. | |
| """ | |
| os.makedirs(VI_TTS_HOME_DIR, exist_ok=True) | |
| for fname in os.listdir(local_dir): | |
| if fname.startswith("."): | |
| continue | |
| src = os.path.realpath(os.path.join(local_dir, fname)) | |
| dst = os.path.join(VI_TTS_HOME_DIR, fname) | |
| if not os.path.exists(dst) and os.path.isfile(src): | |
| try: | |
| os.symlink(src, dst) | |
| except OSError: | |
| import shutil | |
| shutil.copy2(src, dst) | |
| print(f"[tts] vi: linked {fname}", flush=True) | |
def get_tts(lang: str) -> TTS:
    """Return the TTS model for ``lang``, downloading and loading it on first use.

    Loaded models are kept in ``tts_instances`` and reused on later calls.
    All models are placed on the CPU.

    Args:
        lang: Language code; must be a key of ``REPOS``.

    Returns:
        The loaded model for that language.

    Raises:
        HTTPException: Status 400 if ``lang`` is not a supported language.
    """
    if lang not in REPOS:
        raise HTTPException(status_code=400, detail=f"Unsupported language: {lang}")

    if lang in tts_instances:
        return tts_instances[lang]

    repo_id = REPOS[lang]
    print(f"[tts] downloading repo for {lang}: {repo_id}", flush=True)
    local_dir = snapshot_download(repo_id=repo_id)

    if lang == "vi":
        # Fairseq format: load by model_name so Coqui routes through
        # _load_fairseq_from_dir (blank VitsConfig, config.json is not parsed).
        setup_fairseq_vi(local_dir)
        print(f"[tts] loading vi via model_name={VI_MODEL_NAME}", flush=True)
        model = TTS(model_name=VI_MODEL_NAME, progress_bar=False)
    else:
        weights = resolve_weights(local_dir)
        config_path = patch_config(local_dir)
        print(f"[tts] loading {weights}", flush=True)
        model = TTS(model_path=weights, config_path=config_path, progress_bar=False)

    tts_instances[lang] = model.to("cpu")
    return tts_instances[lang]
def get_speakers(model: "TTS") -> list[str]:
    """List the speaker names exposed by ``model``, or ``[]`` if there are none.

    The speaker manager is looked up at
    ``model.synthesizer.tts_model.speaker_manager``. Its attributes are probed
    in order and the first one holding the expected container type wins:
    ``speaker_names`` (list), ``name_to_id`` (dict keys), ``speakers`` (dict
    keys).

    Args:
        model: Loaded TTS model; every attribute on the path is optional.

    Returns:
        Speaker names converted to ``str``.
    """
    synthesizer = getattr(model, "synthesizer", None)
    tts_model = getattr(synthesizer, "tts_model", None)
    speaker_manager = getattr(tts_model, "speaker_manager", None)
    if speaker_manager is None:
        return []

    probes = (("speaker_names", list), ("name_to_id", dict), ("speakers", dict))
    for attr, expected_type in probes:
        value: Any = getattr(speaker_manager, attr, None)
        if isinstance(value, expected_type):
            # Iterating a dict yields its keys, i.e. the speaker names.
            return [str(name) for name in value]
    return []
def resolve_sample_rate(model: "TTS") -> int:
    """Return the model's output sample rate, falling back to 22050 Hz.

    Args:
        model: Loaded TTS model; ``model.synthesizer.output_sample_rate`` is
            used when it is a positive ``int``.

    Returns:
        The reported sample rate, or 22050 when it is missing or unusable.
    """
    fallback_rate = 22050

    synthesizer = getattr(model, "synthesizer", None)
    if not synthesizer:
        return fallback_rate

    rate = getattr(synthesizer, "output_sample_rate", None)
    if not isinstance(rate, int) or rate <= 0:
        return fallback_rate
    return rate
@app.get("/")
async def root() -> dict[str, Any]:
    """Describe the service: its name and the paths of its endpoints."""
    return {
        "service": "aidoc-tts",
        "endpoints": ["/health", "/speakers", "/synthesize"],
    }
@app.get("/health")
async def health() -> dict[str, Any]:
    """Report service status and which languages are loaded and supported.

    ``loaded_languages`` lists the models currently held in ``tts_instances``;
    ``supported_languages`` lists every key of ``REPOS``. Both are sorted.
    """
    return {
        "status": "ok",
        # Models are always moved to the CPU when they are loaded.
        "device": "cpu",
        "loaded_languages": sorted(tts_instances),
        "supported_languages": sorted(REPOS),
    }
@app.get("/speakers")
async def speakers() -> dict[str, list[str]]:
    """Return the speaker names of the English model.

    The English model is loaded on first use, so the first call may be slow.
    """
    model = get_tts("en")
    return {"speakers": get_speakers(model)}
def split_sentences(text: str, max_chars: int = 200) -> list[str]:
    """Clean ``text`` and group its sentences into chunks for synthesis.

    Line breaks become spaces, bullet characters are dropped, and runs of
    whitespace collapse to a single space. The text is then split after
    ``.``, ``!`` or ``?`` and consecutive sentences are merged greedily while
    their combined length stays within ``max_chars``.

    Args:
        text: Raw input text.
        max_chars: Soft length budget per chunk. A single sentence longer than
            the budget is kept whole rather than cut.

    Returns:
        Non-empty, stripped chunks in their original order; ``[]`` when the
        text contains nothing speakable.
    """
    text = re.sub(r"[\r\n]+", " ", text)
    text = re.sub(r"[\u2022\u00b7\u2023\u25aa\u25b8\u25ba]+", "", text)
    text = re.sub(r"\s{2,}", " ", text).strip()

    sentences: list[str] = []
    current = ""
    for chunk in re.split(r"(?<=[.!?])\s+", text):
        chunk = chunk.strip()
        if not chunk:
            continue
        if current and len(current) + len(chunk) > max_chars:
            # Budget exceeded: close the running chunk and start a new one.
            sentences.append(current)
            current = chunk
        else:
            current = f"{current} {chunk}".strip()
    if current:
        sentences.append(current)
    return sentences
@app.post("/synthesize")
async def synthesize(payload: SynthesizeRequest) -> Response:
    """Synthesize ``payload.text`` and return it as one WAV file.

    The text is split into chunks, each chunk is synthesized separately, and
    the resulting audio is concatenated. A chunk that fails is logged and
    skipped, so the response may contain only part of the text.

    Args:
        payload: Text, language and optional speaker. The speaker is only
            passed to the model for ``"en"``; it defaults to
            ``DEFAULT_SPEAKER``.

    Returns:
        Response with ``audio/wav`` content.

    Raises:
        HTTPException: 400 if the language is unsupported or the text has
            nothing speakable; 500 if every chunk failed to synthesize.
    """
    lang = payload.language
    model = get_tts(lang)
    sample_rate = resolve_sample_rate(model)

    sentences = split_sentences(payload.text)
    if not sentences:
        raise HTTPException(status_code=400, detail="No speakable text provided")

    # Only the English model is called with a speaker; chosen once per request.
    tts_kwargs: dict[str, Any] = {}
    if lang == "en":
        tts_kwargs["speaker"] = payload.speaker_idx or DEFAULT_SPEAKER

    audio_parts: list[Any] = []
    for sentence in sentences:
        try:
            wav = model.tts(text=sentence, **tts_kwargs)
            audio_parts.append(np.array(wav, dtype=np.float32))
        except Exception as error:
            # Deliberate best effort: one bad chunk must not fail the request.
            print(f"[tts] skipping sentence due to error: {error!r}", flush=True)
            continue

    if not audio_parts:
        raise HTTPException(status_code=500, detail="All sentences failed to synthesize")

    combined = np.concatenate(audio_parts)
    buffer = io.BytesIO()
    sf.write(buffer, combined, samplerate=sample_rate, format="WAV")
    return Response(content=buffer.getvalue(), media_type="audio/wav")
if __name__ == "__main__":
    # Pass the app object rather than the "app:app" import string: the string
    # form makes uvicorn import this module a second time (running all
    # top-level code twice) and only works when the file is named app.py.
    uvicorn.run(app, host=HOST, port=PORT, reload=False)