aidoc-tts / app.py
JustinJoshi's picture
Add vi back via fairseq model_name + TTS_HOME mirror approach
a23c103
import io
import json
import os
import re
from typing import Any
import importlib
import torch
# Install a fallback for transformers.pytorch_utils.isin_mps_friendly when the
# installed transformers release does not define it (something downstream
# expects the name to exist). NOTE(review): confirm which caller needs it.
_pt_utils = importlib.import_module("transformers.pytorch_utils")
if not hasattr(_pt_utils, "isin_mps_friendly"):

    def _isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) -> torch.Tensor:
        """Fallback for ``isin_mps_friendly``: element-wise ``elements in test_elements``.

        *test_elements* may be a tensor or a plain Python scalar. When either
        operand lives on the MPS device, the test is computed on CPU and the
        result is moved back to ``elements.device``. This keeps both operands
        on the same device; the previous shim moved only ``test_elements``,
        which made ``torch.isin`` fail with a device mismatch.
        """
        test_elements = torch.as_tensor(test_elements)
        if "mps" in (elements.device.type, test_elements.device.type):
            return torch.isin(elements.cpu(), test_elements.cpu()).to(elements.device)
        if test_elements.device != elements.device:
            test_elements = test_elements.to(elements.device)
        return torch.isin(elements, test_elements)

    _pt_utils.isin_mps_friendly = _isin_mps_friendly
import numpy as np
import soundfile as sf
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field, field_validator
from TTS.api import TTS
HOST = "0.0.0.0"
PORT = 7860
DEFAULT_SPEAKER = os.environ.get("COQUI_DEFAULT_SPEAKER", "p228")
# Hugging Face repo per supported language code; each can be overridden via env.
REPOS: dict[str, str] = {
    "en": os.environ.get("HF_TTS_EN_REPO", "Resilient-Coders/coqui-vctk-en"),
    "es": os.environ.get("HF_TTS_ES_REPO", "Resilient-Coders/coqui-css10-es"),
    "vi": os.environ.get("HF_TTS_VI_REPO", "Resilient-Coders/mms-tts-vie"),
}


def resolve_tts_home() -> str:
    """Return the directory Coqui TTS keeps downloaded models in (``<data dir>/tts``).

    Follows the same lookup order as Coqui's ``get_user_data_dir``:
    ``$TTS_HOME``, then ``$XDG_DATA_HOME``, then the platform default
    (``~/Library/Application Support`` on macOS, ``~/.local/share`` elsewhere).
    The Vietnamese mirror must land exactly here, otherwise the ``model_name``
    lookup misses it. NOTE(review): re-check if the installed Coqui fork
    changes this lookup.
    """
    import sys

    tts_home = os.environ.get("TTS_HOME")
    xdg_data_home = os.environ.get("XDG_DATA_HOME")
    if tts_home is not None:
        root = os.path.realpath(os.path.expanduser(tts_home))
    elif xdg_data_home is not None:
        root = os.path.realpath(os.path.expanduser(xdg_data_home))
    elif sys.platform == "darwin":
        root = os.path.expanduser("~/Library/Application Support")
    else:
        root = os.path.join(os.path.expanduser("~"), ".local", "share")
    return os.path.join(root, "tts")


# Vietnamese uses a fairseq-format checkpoint. Loading it by model_name makes
# Coqui route through _load_fairseq_from_dir, which starts from a blank
# VitsConfig instead of parsing config.json with Coqui's config schema
# (NOTE(review): the fairseq loader may still read fairseq fields from
# config.json - confirm). We mirror the HF snapshot files into TTS_HOME so the
# model_name lookup finds them locally.
TTS_HOME = resolve_tts_home()
VI_MODEL_NAME = "tts_models/vie/fairseq/vits"
# Directory name is VI_MODEL_NAME with "/" replaced by "--".
VI_TTS_HOME_DIR = os.path.join(TTS_HOME, "tts_models--vie--fairseq--vits")
# Checkpoint filenames tried, in order, by resolve_weights().
WEIGHT_FILE_CANDIDATES = ["model.pth", "model_file.pth.tar", "model_file.pth"]
def resolve_weights(local_dir: str) -> str:
    """Return the first checkpoint from WEIGHT_FILE_CANDIDATES present in *local_dir*.

    Candidates are probed in list order and must be regular files.

    Raises:
        RuntimeError: none of the candidate files exists.
    """
    probes = (os.path.join(local_dir, candidate) for candidate in WEIGHT_FILE_CANDIDATES)
    found = next((path for path in probes if os.path.isfile(path)), None)
    if found is None:
        raise RuntimeError(f"No weight file found in {local_dir}")
    return found
# The ASGI application. CORS is fully open: any origin, method and header.
app = FastAPI(version="1.0.0", title="aiDoc TTS Space")
app.add_middleware(CORSMiddleware, allow_headers=["*"], allow_methods=["*"], allow_origins=["*"])

# Loaded Coqui models keyed by language code; populated lazily by get_tts().
tts_instances: dict[str, TTS] = {}
@app.on_event("startup")
async def preload_models() -> None:
    """Warm the model cache for every language in REPOS during startup.

    Each load runs in the default thread-pool executor, so the blocking
    download and initialisation happen off the event-loop thread. A language
    that fails to load is logged and skipped rather than aborting startup.
    The remaining languages stay servable, and get_tts() retries the failed one
    on its first request because it is absent from ``tts_instances``.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    for lang in REPOS:
        try:
            await loop.run_in_executor(None, get_tts, lang)
        except Exception as error:
            print(f"[tts] preload failed for {lang}: {error!r}", flush=True)
class SynthesizeRequest(BaseModel):
    """JSON body accepted by POST /synthesize."""

    # Text to speak; must be non-empty.
    text: str = Field(min_length=1)
    # Speaker id; the /synthesize handler only passes it to the "en" model.
    speaker_idx: str | None = None
    # Language code; normalised and checked against REPOS by the validator below.
    language: str = "en"

    @field_validator("language")
    @classmethod
    def normalize_language(cls, v: str) -> str:
        """Trim and lower-case the code (empty -> "en"); reject codes not in REPOS."""
        normalized = (v or "en").strip().lower()
        if normalized in REPOS:
            return normalized
        supported = ", ".join(sorted(REPOS))
        raise ValueError(f"Unsupported language: {v!r}. Use one of: {supported}.")
PATH_KEYS = ("speakers_file", "speaker_ids_file", "d_vector_file")
def _patch_dict(obj: dict, local_dir: str) -> bool:
"""Recursively fix off-machine absolute paths in a config dict. Returns True if anything changed."""
changed = False
for key, val in obj.items():
if isinstance(val, dict):
if _patch_dict(val, local_dir):
changed = True
elif key in PATH_KEYS and isinstance(val, str) and val and not os.path.isfile(val):
candidate = os.path.join(local_dir, os.path.basename(val))
if os.path.isfile(candidate):
obj[key] = candidate
changed = True
print(f"[tts] patched config key {key!r} -> {candidate}", flush=True)
return changed
def patch_config(local_dir: str) -> str:
    """Patch off-machine absolute paths in ``<local_dir>/config.json``, in place.

    The config stores paths in both top-level and nested (model_args) dicts.
    The ``config.json`` symlink is resolved to its real file (the HF cache
    blob). If _patch_dict changed anything, the file is made writable (0o644,
    best effort) and overwritten. This is safe in a container that resets each
    run.

    Returns:
        The unresolved ``config.json`` path inside *local_dir*.
    """
    config_path = os.path.join(local_dir, "config.json")
    real_path = os.path.realpath(config_path)
    # JSON is UTF-8 by spec; don't depend on the container's locale encoding.
    with open(real_path, encoding="utf-8") as f:
        cfg = json.load(f)
    if not _patch_dict(cfg, local_dir):
        return config_path

    # Serialise before opening for write, so a serialisation error cannot
    # leave a truncated config behind in the cache.
    serialized = json.dumps(cfg)
    try:
        os.chmod(real_path, 0o644)
    except OSError as e:
        print(f"[tts] chmod warning: {e}", flush=True)
    with open(real_path, "w", encoding="utf-8") as f:
        f.write(serialized)
    print(f"[tts] wrote patched config to {real_path}", flush=True)
    return config_path
def setup_fairseq_vi(local_dir: str) -> None:
"""Mirror HF snapshot files for the Vietnamese fairseq model into TTS_HOME.
Coqui's fairseq loader uses model_name -> model_dir -> _load_fairseq_from_dir,
which creates a blank VitsConfig and never reads config.json. Setting up the
TTS_HOME directory lets us use model_name without re-downloading from Coqui's
(defunct) registry, and avoids the config format incompatibility.
"""
os.makedirs(VI_TTS_HOME_DIR, exist_ok=True)
for fname in os.listdir(local_dir):
if fname.startswith("."):
continue
src = os.path.realpath(os.path.join(local_dir, fname))
dst = os.path.join(VI_TTS_HOME_DIR, fname)
if not os.path.exists(dst) and os.path.isfile(src):
try:
os.symlink(src, dst)
except OSError:
import shutil
shutil.copy2(src, dst)
print(f"[tts] vi: linked {fname}", flush=True)
def get_tts(lang: str) -> TTS:
    """Return the Coqui model for *lang*, downloading and loading it on first use.

    Loaded models are cached in ``tts_instances`` and moved to CPU.

    Raises:
        HTTPException: 400 when *lang* is not a key of REPOS.
    """
    if lang not in REPOS:
        raise HTTPException(status_code=400, detail=f"Unsupported language: {lang}")
    if lang in tts_instances:
        return tts_instances[lang]

    repo_id = REPOS[lang]
    print(f"[tts] downloading repo for {lang}: {repo_id}", flush=True)
    local_dir = snapshot_download(repo_id=repo_id)
    if lang != "vi":
        weights = resolve_weights(local_dir)
        config_path = patch_config(local_dir)
        print(f"[tts] loading {weights}", flush=True)
        model = TTS(model_path=weights, config_path=config_path, progress_bar=False)
    else:
        # Fairseq checkpoint: load by model_name so Coqui uses _load_fairseq_from_dir
        # on the files mirrored into TTS_HOME (blank VitsConfig).
        setup_fairseq_vi(local_dir)
        print(f"[tts] loading vi via model_name={VI_MODEL_NAME}", flush=True)
        model = TTS(model_name=VI_MODEL_NAME, progress_bar=False)
    tts_instances[lang] = model.to("cpu")
    return tts_instances[lang]
def get_speakers(model: TTS) -> list[str]:
    """List the speaker names exposed by *model*'s speaker manager, or [] if none.

    Probes ``speaker_names`` (list), then ``name_to_id`` (dict keys), then
    ``speakers`` (dict keys). The first attribute with the expected type wins.
    """
    tts_model = getattr(getattr(model, "synthesizer", None), "tts_model", None)
    speaker_manager = getattr(tts_model, "speaker_manager", None)
    if speaker_manager is None:
        return []
    probes: tuple[tuple[str, type], ...] = (
        ("speaker_names", list),
        ("name_to_id", dict),
        ("speakers", dict),
    )
    for attr_name, expected_type in probes:
        value: Any = getattr(speaker_manager, attr_name, None)
        if isinstance(value, expected_type):
            # Iterating a dict yields its keys, matching the .keys() lookup.
            return [str(name) for name in value]
    return []
def resolve_sample_rate(model: TTS) -> int:
    """Return the synthesizer's positive integer output sample rate, else 22050."""
    fallback = 22050
    synthesizer = getattr(model, "synthesizer", None)
    if not synthesizer:
        return fallback
    rate = getattr(synthesizer, "output_sample_rate", None)
    return rate if isinstance(rate, int) and rate > 0 else fallback
@app.get("/")
async def root() -> dict[str, Any]:
    """Identify the service and list its public endpoints."""
    endpoints = ["/health", "/speakers", "/synthesize"]
    return {"service": "aidoc-tts", "endpoints": endpoints}
@app.get("/health")
async def health() -> dict[str, Any]:
    """Liveness probe; also reports loaded vs. supported language codes."""
    loaded = sorted(tts_instances)
    supported = sorted(REPOS)
    return dict(
        status="ok",
        device="cpu",
        loaded_languages=loaded,
        supported_languages=supported,
    )
@app.get("/speakers")
async def speakers() -> dict[str, list[str]]:
    """Return the speaker names available on the English ("en") model."""
    return {"speakers": get_speakers(get_tts("en"))}
def split_sentences(text: str, max_chars: int = 200) -> list[str]:
    """Normalise *text* and group its sentences into chunks of at most *max_chars*.

    Line breaks become spaces, bullet glyphs are removed, and whitespace runs
    collapse to one space. The text is then split after ``.``, ``!`` or ``?``
    followed by whitespace. Consecutive sentences are joined with a single
    space while the joined chunk stays within *max_chars*. A single sentence
    longer than *max_chars* is kept whole.

    Returns:
        Non-empty chunks in order; ``[]`` for text with nothing speakable.
    """
    text = re.sub(r"[\r\n]+", " ", text)
    text = re.sub(r"[\u2022\u00b7\u2023\u25aa\u25b8\u25ba]+", "", text)
    text = re.sub(r"\s{2,}", " ", text).strip()
    chunks: list[str] = []
    current = ""
    for piece in re.split(r"(?<=[.!?])\s+", text):
        piece = piece.strip()
        if not piece:
            continue
        if not current:
            current = piece
        # +1 accounts for the joining space (previously omitted, allowing max_chars + 1).
        elif len(current) + 1 + len(piece) > max_chars:
            chunks.append(current)
            current = piece
        else:
            current = f"{current} {piece}"
    if current:
        chunks.append(current)
    return chunks
def _synthesize_wav(payload: SynthesizeRequest) -> bytes:
    """Render *payload* to WAV bytes. Blocking; ``synthesize`` calls it in a worker thread.

    Loads (or reuses) the model for ``payload.language`` and synthesizes each
    chunk from split_sentences() separately. Chunks that fail are logged and
    skipped, and the rest are concatenated into one WAV at the model's output
    sample rate.

    Raises:
        HTTPException: 400 for no speakable text or an explicit ``speaker_idx``
            the English model does not list; 500 when every chunk failed.
    """
    lang = payload.language
    model = get_tts(lang)
    sample_rate = resolve_sample_rate(model)
    sentences = split_sentences(payload.text)
    if not sentences:
        raise HTTPException(status_code=400, detail="No speakable text provided")

    speaker: str | None = None
    if lang == "en":
        requested = payload.speaker_idx
        # Reject unknown speakers up front. Otherwise every chunk fails inside
        # model.tts() and the client gets a misleading 500.
        if requested:
            known = get_speakers(model)
            if known and requested not in known:
                raise HTTPException(status_code=400, detail=f"Unknown speaker: {requested!r}")
        speaker = requested or DEFAULT_SPEAKER

    audio_parts: list[Any] = []
    for sentence in sentences:
        try:
            if speaker is None:
                wav = model.tts(text=sentence)
            else:
                wav = model.tts(text=sentence, speaker=speaker)
            audio_parts.append(np.array(wav, dtype=np.float32))
        except Exception as error:
            # Deliberate best effort: one bad chunk should not sink the whole request.
            print(f"[tts] skipping sentence due to error: {error!r}", flush=True)
    if not audio_parts:
        raise HTTPException(status_code=500, detail="All sentences failed to synthesize")

    buffer = io.BytesIO()
    sf.write(buffer, np.concatenate(audio_parts), samplerate=sample_rate, format="WAV")
    return buffer.getvalue()


@app.post("/synthesize")
async def synthesize(payload: SynthesizeRequest) -> Response:
    """Synthesize ``payload.text`` and return it as ``audio/wav``.

    Model loading and Coqui inference are CPU-bound, blocking calls. Running
    them directly in this coroutine would stall the event loop, and every other
    endpoint including /health, for the whole synthesis, so they run in a
    worker thread.
    NOTE(review): concurrent requests now run inference in parallel threads on
    the shared cached model; confirm that is safe for Coqui, or serialize.
    """
    import asyncio

    wav_bytes = await asyncio.to_thread(_synthesize_wav, payload)
    return Response(content=wav_bytes, media_type="audio/wav")
if __name__ == "__main__":
    # Pass the app object, not the "app:app" import string. The string form
    # makes uvicorn import this file a second time under the module name "app"
    # (re-running all module-level code) and ties startup to the filename.
    # Switch back to an import string only if reload/workers are enabled.
    uvicorn.run(app, host=HOST, port=PORT, reload=False)