JustinJoshi committed on
Commit
a23c103
·
1 Parent(s): 8fedf61

Add vi back via fairseq model_name + TTS_HOME mirror approach

Browse files
Files changed (1) hide show
  1. app.py +43 -6
app.py CHANGED
@@ -32,10 +32,16 @@ DEFAULT_SPEAKER = os.environ.get("COQUI_DEFAULT_SPEAKER", "p228")
32
  REPOS: dict[str, str] = {
33
  "en": os.environ.get("HF_TTS_EN_REPO", "Resilient-Coders/coqui-vctk-en"),
34
  "es": os.environ.get("HF_TTS_ES_REPO", "Resilient-Coders/coqui-css10-es"),
35
- # "vi" (Vietnamese MMS fairseq) uses a config format incompatible with
36
- # coqui-tts 0.27.x. Vietnamese is handled by the local sidecar / cloud fallback.
37
  }
38
 
 
 
 
 
 
 
 
39
  WEIGHT_FILE_CANDIDATES = ["model.pth", "model_file.pth.tar", "model_file.pth"]
40
 
41
 
@@ -124,6 +130,29 @@ def patch_config(local_dir: str) -> str:
124
  return config_path
125
 
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  def get_tts(lang: str) -> TTS:
128
  if lang not in REPOS:
129
  raise HTTPException(status_code=400, detail=f"Unsupported language: {lang}")
@@ -131,10 +160,18 @@ def get_tts(lang: str) -> TTS:
131
  repo_id = REPOS[lang]
132
  print(f"[tts] downloading repo for {lang}: {repo_id}", flush=True)
133
  local_dir = snapshot_download(repo_id=repo_id)
134
- weights = resolve_weights(local_dir)
135
- config_path = patch_config(local_dir)
136
- print(f"[tts] loading {weights}", flush=True)
137
- tts_instances[lang] = TTS(model_path=weights, config_path=config_path, progress_bar=False).to("cpu")
 
 
 
 
 
 
 
 
138
  return tts_instances[lang]
139
 
140
 
 
32
  REPOS: dict[str, str] = {
33
  "en": os.environ.get("HF_TTS_EN_REPO", "Resilient-Coders/coqui-vctk-en"),
34
  "es": os.environ.get("HF_TTS_ES_REPO", "Resilient-Coders/coqui-css10-es"),
35
+ "vi": os.environ.get("HF_TTS_VI_REPO", "Resilient-Coders/mms-tts-vie"),
 
36
  }
37
 
38
# Vietnamese uses Fairseq format. Coqui loads it via model_name (model_dir path),
# which (per the original author's notes) calls _load_fairseq_from_dir and never
# reads config.json. We mirror the HF snapshot files into Coqui's model
# directory so the model_name lookup finds them locally.
#
# The mirror root must be the directory Coqui itself searches. Coqui's model
# manager is understood to resolve its data root from the TTS_HOME environment
# variable, then XDG_DATA_HOME, and only then fall back to ~/.local/share on
# Linux, appending "tts" in every case.
# NOTE(review): confirm this lookup order against the installed coqui-tts
# version (TTS.utils.generic_utils.get_user_data_dir).
_tts_data_root = os.environ.get("TTS_HOME")
if _tts_data_root is None:
    _tts_data_root = os.environ.get("XDG_DATA_HOME")
if _tts_data_root is not None:
    # Expand "~" and resolve symlinks so the path is absolute and stable.
    _tts_data_root = os.path.realpath(os.path.expanduser(_tts_data_root))
else:
    # Default used when no override is set (identical to the previous value).
    _tts_data_root = os.path.join(os.path.expanduser("~"), ".local", "share")

# Directory that holds one sub-directory per downloaded Coqui model.
TTS_HOME = os.path.join(_tts_data_root, "tts")
# Registry-style model name passed to TTS(model_name=...).
VI_MODEL_NAME = "tts_models/vie/fairseq/vits"
# On-disk directory for that model: the model name with "/" replaced by "--".
VI_TTS_HOME_DIR = os.path.join(TTS_HOME, "tts_models--vie--fairseq--vits")
44
+
45
  WEIGHT_FILE_CANDIDATES = ["model.pth", "model_file.pth.tar", "model_file.pth"]
46
 
47
 
 
130
  return config_path
131
 
132
 
133
def setup_fairseq_vi(local_dir: str, target_dir: "str | None" = None) -> None:
    """Mirror HF snapshot files for the Vietnamese fairseq model into TTS_HOME.

    Per the original author's notes, Coqui's fairseq loader goes
    model_name -> model_dir -> _load_fairseq_from_dir, which builds a blank
    VitsConfig and never reads config.json. Pre-populating the model directory
    lets TTS(model_name=...) load locally instead of fetching from Coqui's
    registry, and avoids the config format incompatibility.

    Args:
        local_dir: Directory of the downloaded Hugging Face snapshot. Only its
            top-level regular files are mirrored; dot-files and
            sub-directories are skipped.
        target_dir: Directory to mirror into. Defaults to VI_TTS_HOME_DIR.

    Each file is symlinked to its resolved source; if symlinking fails the
    file is copied instead. Existing regular files in the target are kept.
    Existing symlinks are kept only when they already resolve to the current
    source; dangling or outdated links are replaced.
    """
    import shutil  # local import: only this function needs it

    if target_dir is None:
        target_dir = VI_TTS_HOME_DIR
    os.makedirs(target_dir, exist_ok=True)

    for fname in sorted(os.listdir(local_dir)):
        if fname.startswith("."):
            continue
        # HF snapshots are themselves symlinks into the blob store; resolve
        # them so our link points at the real file.
        src = os.path.realpath(os.path.join(local_dir, fname))
        if not os.path.isfile(src):
            continue

        dst = os.path.join(target_dir, fname)
        if os.path.islink(dst):
            if os.path.realpath(dst) == src:
                continue  # already mirrored
            # Dangling or pointing at an older snapshot. os.path.exists()
            # would miss a dangling link, and copying onto it would write
            # through the link to its old target, so remove it first.
            os.unlink(dst)
        elif os.path.exists(dst):
            continue  # real file from an earlier copy: leave it alone

        try:
            os.symlink(src, dst)
            action = "linked"
        except OSError:
            # Symlinks unavailable (e.g. restricted filesystem): fall back
            # to a real copy.
            shutil.copy2(src, dst)
            action = "copied"
        print(f"[tts] vi: {action} {fname}", flush=True)
154
+
155
+
156
  def get_tts(lang: str) -> TTS:
157
  if lang not in REPOS:
158
  raise HTTPException(status_code=400, detail=f"Unsupported language: {lang}")
 
160
  repo_id = REPOS[lang]
161
  print(f"[tts] downloading repo for {lang}: {repo_id}", flush=True)
162
  local_dir = snapshot_download(repo_id=repo_id)
163
+
164
+ if lang == "vi":
165
+ # Fairseq format: use model_name so Coqui routes through
166
+ # _load_fairseq_from_dir (blank VitsConfig, bypasses config.json parse).
167
+ setup_fairseq_vi(local_dir)
168
+ print(f"[tts] loading vi via model_name={VI_MODEL_NAME}", flush=True)
169
+ tts_instances[lang] = TTS(model_name=VI_MODEL_NAME, progress_bar=False).to("cpu")
170
+ else:
171
+ weights = resolve_weights(local_dir)
172
+ config_path = patch_config(local_dir)
173
+ print(f"[tts] loading {weights}", flush=True)
174
+ tts_instances[lang] = TTS(model_path=weights, config_path=config_path, progress_bar=False).to("cpu")
175
  return tts_instances[lang]
176
 
177