File size: 9,927 Bytes
8e7db59
c98bffb
8e7db59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a23c103
8e7db59
 
a23c103
 
 
 
 
 
 
8e7db59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f520c8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c98bffb
f2c9b97
 
f520c8c
 
 
f2c9b97
c98bffb
f520c8c
f2c9b97
 
c98bffb
 
f520c8c
f2c9b97
 
 
 
 
 
 
c98bffb
 
 
 
a23c103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e7db59
 
 
 
 
 
 
a23c103
 
 
 
 
 
 
 
 
 
 
 
8e7db59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
import asyncio
import importlib
import io
import json
import os
import re
from typing import Any

import torch

# Compatibility shim: when the installed transformers has no
# pytorch_utils.isin_mps_friendly, install a fallback. This runs before
# `from TTS.api import TTS` below, presumably because TTS imports the helper
# at import time — confirm against the pinned TTS/transformers versions.
_pt_utils = importlib.import_module("transformers.pytorch_utils")
if not hasattr(_pt_utils, "isin_mps_friendly"):

    def _isin_mps_friendly(elements: torch.Tensor, test_elements: "torch.Tensor | int") -> torch.Tensor:
        """Return the boolean mask ``torch.isin(elements, test_elements)``.

        *test_elements* may be a tensor or a plain int, the same inputs the
        upstream helper accepts. Both operands are first placed on
        ``elements``' device. When that device is MPS, the membership test runs
        on CPU and the mask is moved back, so callers get a mask on the device
        they passed in.
        """
        # as_tensor accepts ints and tensors alike, and aligns the devices.
        test_elements = torch.as_tensor(test_elements, device=elements.device)
        if elements.device.type == "mps":
            mask = torch.isin(elements.cpu(), test_elements.cpu())
            return mask.to(elements.device)
        # Positional arguments, as the upstream transformers helper uses.
        return torch.isin(elements, test_elements)

    _pt_utils.isin_mps_friendly = _isin_mps_friendly

import numpy as np
import soundfile as sf
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field, field_validator
from TTS.api import TTS

HOST = "0.0.0.0"
PORT = 7860
DEFAULT_SPEAKER = os.environ.get("COQUI_DEFAULT_SPEAKER", "p228")

# Hugging Face model repo per supported language code; each is overridable via env var.
REPOS: dict[str, str] = {
    "en": os.environ.get("HF_TTS_EN_REPO", "Resilient-Coders/coqui-vctk-en"),
    "es": os.environ.get("HF_TTS_ES_REPO", "Resilient-Coders/coqui-css10-es"),
    "vi": os.environ.get("HF_TTS_VI_REPO", "Resilient-Coders/mms-tts-vie"),
}


def _coqui_user_data_dir(appname: str = "tts") -> str:
    """Return the directory Coqui TTS resolves for its local model store.

    Uses the same precedence as Coqui's ``get_user_data_dir``:
    ``$TTS_HOME/<appname>`` when TTS_HOME is set, else
    ``$XDG_DATA_HOME/<appname>``, else ``~/.local/share/<appname>``.
    The last is the Linux default; Coqui's macOS/Windows defaults are not
    mirrored here.
    """
    for env_var in ("TTS_HOME", "XDG_DATA_HOME"):
        base = os.environ.get(env_var)
        if base is not None:
            return os.path.join(os.path.realpath(os.path.expanduser(base)), appname)
    return os.path.join(os.path.expanduser("~"), ".local", "share", appname)


# Vietnamese uses Fairseq format. Coqui loads it via model_name (model_dir path),
# which calls _load_fairseq_from_dir and never reads config.json.
# We mirror the HF snapshot files into TTS_HOME so model_name lookup finds them.
# TTS_HOME must therefore be the directory Coqui itself resolves, including
# when the TTS_HOME / XDG_DATA_HOME environment variables relocate it.
TTS_HOME = _coqui_user_data_dir()
VI_MODEL_NAME = "tts_models/vie/fairseq/vits"
VI_TTS_HOME_DIR = os.path.join(TTS_HOME, "tts_models--vie--fairseq--vits")

# Weight file names, tried in this order by resolve_weights.
WEIGHT_FILE_CANDIDATES = ["model.pth", "model_file.pth.tar", "model_file.pth"]


def resolve_weights(local_dir: str) -> str:
    """Locate the model weight file inside a downloaded snapshot directory.

    The names in WEIGHT_FILE_CANDIDATES are probed in order. The first one
    that exists as a regular file under *local_dir* wins.

    Raises:
        RuntimeError: if none of the candidate files is present.
    """
    candidate_paths = (os.path.join(local_dir, name) for name in WEIGHT_FILE_CANDIDATES)
    found = next((path for path in candidate_paths if os.path.isfile(path)), None)
    if found is None:
        raise RuntimeError(f"No weight file found in {local_dir}")
    return found


# Per-language cache of loaded TTS engines; get_tts fills it on first use.
tts_instances: dict[str, TTS] = {}

app = FastAPI(version="1.0.0", title="aiDoc TTS Space")
# CORS is fully open: any origin, method and header may call this service.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)


@app.on_event("startup")
async def preload_models() -> None:
    """Warm the model cache for every language in REPOS during startup.

    Each load (download plus model construction) runs in a worker thread via
    ``asyncio.to_thread``, so the blocking work stays off the event-loop
    thread. Loading is best-effort: if one language fails, the error is logged
    and startup continues. The other languages stay usable, /health lists only
    what actually loaded, and get_tts retries the failed language on its next
    request because nothing was cached for it.
    """
    for lang in REPOS:
        try:
            await asyncio.to_thread(get_tts, lang)
        except Exception as error:  # deliberate best-effort warm-up, see docstring
            print(f"[tts] preload failed for {lang}: {error!r}", flush=True)


class SynthesizeRequest(BaseModel):
    """JSON body accepted by POST /synthesize."""

    # Text to speak; must contain at least one character.
    text: str = Field(min_length=1)
    # Optional speaker name; None means "use the service default".
    speaker_idx: str | None = None
    # Language code; normalized and checked against REPOS by the validator below.
    language: str = "en"

    @field_validator("language")
    @classmethod
    def normalize_language(cls, v: str) -> str:
        """Strip and lower-case *v* (empty -> "en"); reject codes missing from REPOS."""
        normalized = v.strip().lower() if v else "en"
        if normalized in REPOS:
            return normalized
        raise ValueError(f"Unsupported language: {v!r}. Use one of: {', '.join(sorted(REPOS))}.")


# Config keys whose values are file paths that may point at another machine.
PATH_KEYS = ("speakers_file", "speaker_ids_file", "d_vector_file")


def _patch_dict(obj: dict, local_dir: str) -> bool:
    """Rewrite stale path values in a config dict (and nested dicts) in place.

    For every key listed in PATH_KEYS whose non-empty string value does not
    name an existing file, the basename is looked up in *local_dir*. The value
    is replaced only if that local file exists. Lists are not descended into.

    Returns:
        True if at least one value was rewritten, else False.
    """
    any_patched = False
    for key, value in obj.items():
        if isinstance(value, dict):
            # Recurse unconditionally; `any_patched or ...` must not short-circuit it.
            nested_patched = _patch_dict(value, local_dir)
            any_patched = any_patched or nested_patched
            continue
        if key not in PATH_KEYS or not isinstance(value, str) or not value or os.path.isfile(value):
            continue
        candidate = os.path.join(local_dir, os.path.basename(value))
        if not os.path.isfile(candidate):
            continue
        obj[key] = candidate
        any_patched = True
        print(f"[tts] patched config key {key!r} -> {candidate}", flush=True)
    return any_patched


def patch_config(local_dir: str) -> str:
    """Patch any off-machine absolute paths in config.json, overwriting in place.

    The config stores paths in both top-level and nested (model_args) dicts.
    We resolve the symlink to the actual HF blob, chmod it writable, patch all
    occurrences, and overwrite in place. Safe in a container that resets each run.
    NOTE(review): with a persistent HF cache, the rewritten blob no longer
    matches its content hash. Confirm that is acceptable for the deployment.

    The file is read and written as UTF-8 explicitly (the JSON encoding)
    rather than in the platform's locale encoding.

    Args:
        local_dir: Snapshot directory containing config.json.

    Returns:
        The (unchanged) path ``<local_dir>/config.json``, suitable for TTS(config_path=...).
    """
    config_path = os.path.join(local_dir, "config.json")
    real_path = os.path.realpath(config_path)

    with open(real_path, encoding="utf-8") as f:
        cfg = json.load(f)

    if _patch_dict(cfg, local_dir):
        try:
            # The resolved blob may be read-only; failure here is only a warning,
            # and the write below will surface any real permission problem.
            os.chmod(real_path, 0o644)
        except OSError as e:
            print(f"[tts] chmod warning: {e}", flush=True)
        with open(real_path, "w", encoding="utf-8") as f:
            json.dump(cfg, f)
        print(f"[tts] wrote patched config to {real_path}", flush=True)

    return config_path


def setup_fairseq_vi(local_dir: str) -> None:
    """Mirror HF snapshot files for the Vietnamese fairseq model into TTS_HOME.

    Coqui's fairseq loader uses model_name -> model_dir -> _load_fairseq_from_dir,
    which creates a blank VitsConfig and never reads config.json. Setting up the
    TTS_HOME directory lets us use model_name without re-downloading from Coqui's
    (defunct) registry, and avoids the config format incompatibility.

    Each regular, non-hidden file in *local_dir* is symlinked (or copied, if
    symlinks are unavailable) into VI_TTS_HOME_DIR. A symlink already there is
    kept only if it resolves to the current snapshot file. Otherwise (dangling,
    or pointing at an older blob) it is replaced. A regular file already
    present is left untouched.
    """
    import shutil

    os.makedirs(VI_TTS_HOME_DIR, exist_ok=True)
    for fname in os.listdir(local_dir):
        if fname.startswith("."):
            continue
        src = os.path.realpath(os.path.join(local_dir, fname))
        if not os.path.isfile(src):
            continue
        dst = os.path.join(VI_TTS_HOME_DIR, fname)
        if os.path.islink(dst):
            # os.path.exists() is False for a dangling link, so check links first:
            # otherwise symlink() would fail and copy2 would write through the stale link.
            if os.path.realpath(dst) == src:
                continue
            os.unlink(dst)
        elif os.path.exists(dst):
            continue
        try:
            os.symlink(src, dst)
        except OSError:
            shutil.copy2(src, dst)
        print(f"[tts] vi: linked {fname}", flush=True)


def get_tts(lang: str) -> TTS:
    """Return the TTS engine for *lang*, downloading and loading it on first use.

    Loaded engines are cached in ``tts_instances`` and moved to CPU.
    Vietnamese is loaded by model name from the mirrored TTS_HOME directory.
    Every other language is loaded from its snapshot's weight file plus the
    path-patched config.json.
    NOTE(review): the check-then-load is not locked; two concurrent first
    calls for the same language may each load the model.

    Raises:
        HTTPException: 400 when *lang* is not a key of REPOS.
    """
    if lang not in REPOS:
        raise HTTPException(status_code=400, detail=f"Unsupported language: {lang}")
    if lang in tts_instances:
        return tts_instances[lang]

    repo_id = REPOS[lang]
    print(f"[tts] downloading repo for {lang}: {repo_id}", flush=True)
    local_dir = snapshot_download(repo_id=repo_id)

    if lang != "vi":
        weights = resolve_weights(local_dir)
        config_path = patch_config(local_dir)
        print(f"[tts] loading {weights}", flush=True)
        engine = TTS(model_path=weights, config_path=config_path, progress_bar=False)
    else:
        # Fairseq checkpoints go through model_name so Coqui uses its
        # fairseq directory loader instead of parsing config.json.
        setup_fairseq_vi(local_dir)
        print(f"[tts] loading vi via model_name={VI_MODEL_NAME}", flush=True)
        engine = TTS(model_name=VI_MODEL_NAME, progress_bar=False)

    tts_instances[lang] = engine.to("cpu")
    return tts_instances[lang]


def get_speakers(model: TTS) -> list[str]:
    """List the speaker names known to *model*, or [] if none can be found.

    Looks up ``model.synthesizer.tts_model.speaker_manager``. Then, in order,
    it uses the first of ``speaker_names`` (a list), ``name_to_id`` (dict keys)
    or ``speakers`` (dict keys) that has the expected type.
    """
    synthesizer = getattr(model, "synthesizer", None)
    tts_model = getattr(synthesizer, "tts_model", None)
    speaker_manager = getattr(tts_model, "speaker_manager", None)
    if speaker_manager is None:
        return []
    for attr_name, expected_type in (("speaker_names", list), ("name_to_id", dict), ("speakers", dict)):
        source: Any = getattr(speaker_manager, attr_name, None)
        if isinstance(source, expected_type):
            # Iterating a dict yields its keys, matching .keys() order.
            return [str(entry) for entry in source]
    return []


def resolve_sample_rate(model: TTS) -> int:
    """Return the synthesizer's output sample rate, or 22050 when unavailable.

    Only a positive int ``model.synthesizer.output_sample_rate`` is trusted.
    Anything else, including a missing or falsy synthesizer, yields the fallback.
    """
    fallback_rate = 22050
    synthesizer = getattr(model, "synthesizer", None)
    if not synthesizer:
        return fallback_rate
    rate = getattr(synthesizer, "output_sample_rate", None)
    return rate if isinstance(rate, int) and rate > 0 else fallback_rate


@app.get("/")
async def root() -> dict[str, Any]:
    """Identify the service and list its public endpoints."""
    endpoint_paths = ["/health", "/speakers", "/synthesize"]
    return {"service": "aidoc-tts", "endpoints": endpoint_paths}


@app.get("/health")
async def health() -> dict[str, Any]:
    """Liveness probe reporting loaded vs. supported language codes (both sorted)."""
    loaded = sorted(tts_instances)
    supported = sorted(REPOS)
    return dict(
        status="ok",
        device="cpu",
        loaded_languages=loaded,
        supported_languages=supported,
    )


@app.get("/speakers")
async def speakers() -> dict[str, list[str]]:
    """Return the speaker names of the English model, loading it if needed."""
    english_model = get_tts("en")
    names = get_speakers(english_model)
    return {"speakers": names}


def split_sentences(text: str, *, max_chars: int = 200) -> list[str]:
    """Normalize *text* and group its sentences into chunks for synthesis.

    Line breaks become spaces, common bullet glyphs are removed, and runs of
    whitespace collapse to a single space. The text is then split after
    ``.``, ``!`` or ``?`` followed by whitespace. Consecutive sentences are
    merged greedily as long as the merged chunk, including the joining space,
    stays within *max_chars*. A single sentence longer than *max_chars* is
    kept whole, not cut.

    Args:
        text: Raw input text.
        max_chars: Upper bound on the length of a merged chunk (default 200).

    Returns:
        Non-empty chunks in original order; ``[]`` if nothing speakable remains.
    """
    text = re.sub(r"[\r\n]+", " ", text)
    text = re.sub(r"[\u2022\u00b7\u2023\u25aa\u25b8\u25ba]+", "", text)
    text = re.sub(r"\s{2,}", " ", text).strip()

    chunks: list[str] = []
    current = ""
    for piece in re.split(r"(?<=[.!?])\s+", text):
        piece = piece.strip()
        if not piece:
            continue
        # The +1 counts the space that joins `current` and `piece`.
        if current and len(current) + 1 + len(piece) > max_chars:
            chunks.append(current)
            current = piece
        else:
            current = f"{current} {piece}" if current else piece
    if current:
        chunks.append(current)
    return chunks


# Serializes synthesis, including any lazy model load it triggers, so a loaded
# TTS model is never driven from two worker threads at once. The event loop
# stays free for other requests (e.g. /health) while this waits.
_synthesis_lock = asyncio.Lock()


def _render_wav(payload: SynthesizeRequest) -> bytes:
    """Synthesize *payload* and return WAV-encoded bytes (blocking; run off-loop).

    The text is chunked with split_sentences and each chunk is synthesized
    separately; the float32 waveforms are then concatenated. A chunk that
    raises is logged and skipped, so one bad sentence does not fail the whole
    request.

    Raises:
        HTTPException: 400 if a client-supplied English speaker is not listed
            by the model, or if no speakable text remains; 500 if every chunk
            failed to synthesize.
    """
    lang = payload.language
    model = get_tts(lang)
    sample_rate = resolve_sample_rate(model)
    sentences = split_sentences(payload.text)

    if not sentences:
        raise HTTPException(status_code=400, detail="No speakable text provided")

    # Only the English model is given a speaker; other languages ignore speaker_idx.
    speaker: str | None = None
    if lang == "en":
        speaker = payload.speaker_idx or DEFAULT_SPEAKER
        if payload.speaker_idx:
            known = get_speakers(model)
            # Without this check an unknown speaker makes every sentence fail
            # and the caller only sees a generic 500.
            if known and speaker not in known:
                raise HTTPException(status_code=400, detail=f"Unknown speaker: {speaker!r}")

    audio_parts: list[np.ndarray] = []
    for sentence in sentences:
        try:
            if speaker is not None:
                wav = model.tts(text=sentence, speaker=speaker)
            else:
                wav = model.tts(text=sentence)
            audio_parts.append(np.array(wav, dtype=np.float32))
        except Exception as error:
            print(f"[tts] skipping sentence due to error: {error!r}", flush=True)

    if not audio_parts:
        raise HTTPException(status_code=500, detail="All sentences failed to synthesize")

    buffer = io.BytesIO()
    sf.write(buffer, np.concatenate(audio_parts), samplerate=sample_rate, format="WAV")
    return buffer.getvalue()


@app.post("/synthesize")
async def synthesize(payload: SynthesizeRequest) -> Response:
    """Synthesize the request text and return it as an ``audio/wav`` response.

    Loading and inference run in a worker thread via ``asyncio.to_thread``,
    so the event loop is not blocked. Any HTTPException raised by
    ``_render_wav`` propagates to the client unchanged.
    """
    async with _synthesis_lock:
        wav_bytes = await asyncio.to_thread(_render_wav, payload)
    return Response(content=wav_bytes, media_type="audio/wav")


if __name__ == "__main__":
    # Pass the app object, not the "app:app" import string. The string form
    # makes uvicorn re-import this file as a module named "app", which fails
    # when the file has another name and otherwise executes the module twice.
    # The object form is valid because reload is disabled.
    uvicorn.run(app, host=HOST, port=PORT, reload=False)