notmax123 commited on
Commit
615a636
·
1 Parent(s): 35eb04b

Use v2 model repositories for Space runtime

Browse files

Fetch ONNX assets from the v2 bundle and voice-export safetensors from blue-v2 so uploaded reference voices are not mixed with old v1 checkpoints or voice JSONs.

Made-with: Cursor

Files changed (3) hide show
  1. app.py +37 -12
  2. download_models.py +14 -20
  3. export_new_voice.py +16 -16
app.py CHANGED
@@ -20,7 +20,7 @@ from num2words import num2words
20
  import gradio as gr
21
  import onnxruntime as ort
22
 
23
- from download_models import download_blue_models, download_default_voices, download_renikud
24
 
25
  # ------------------------------------------------------------------
26
  # Paths
@@ -42,6 +42,12 @@ VOCAB_PATH = next(
42
  def _needs_download() -> bool:
43
  required = ["text_encoder.onnx", "vector_estimator.onnx", "vocoder.onnx",
44
  "duration_predictor.onnx"]
 
 
 
 
 
 
45
  for fn in required:
46
  p = os.path.join(ONNX_DIR, fn)
47
  if not os.path.exists(p) or os.path.getsize(p) < 1000:
@@ -572,13 +578,25 @@ TTS = BlueTTS(ONNX_DIR, CONFIG_PATH, VOCAB_PATH, RENIKUD_PATH)
572
  def discover_voices() -> Dict[str, str]:
573
  out: Dict[str, str] = {}
574
  for p in sorted(glob.glob(os.path.join(VOICES_DIR, "*.json"))):
 
 
 
 
 
 
 
 
 
 
 
 
575
  label = os.path.splitext(os.path.basename(p))[0]
576
  pretty = label.replace("_", " ").replace("spk ", "Speaker ").title()
577
  out[pretty] = p
578
  return out
579
 
580
 
581
- VOICES: Dict[str, str] = discover_voices() or {"Default": next(iter(discover_voices().values()), "")}
582
  VOICE_STYLES: Dict[str, Style] = {name: load_voice_style([path]) for name, path in VOICES.items()}
583
 
584
 
@@ -603,19 +621,20 @@ def _hash_file(path: str) -> str:
603
 
604
 
605
  def _ensure_pt_weights() -> dict[str, str]:
606
- """Make sure pt checkpoints are on disk; download from notmax123/blue if missing."""
607
  needed: dict[str, Optional[str]] = {k: _find_pt_weight(v) for k, v in PT_WEIGHT_ALIASES.items()}
608
  if any(v is None for v in needed.values()):
609
  from huggingface_hub import hf_hub_download
610
  import shutil
611
  os.makedirs("pt_weights", exist_ok=True)
612
- for fn in ("blue_codec.safetensors", "duration_predictor_final.pt",
613
- "vf_estimetor.pt", "stats_multilingual.pt"):
 
614
  dest = os.path.join("pt_weights", fn)
615
  if not os.path.exists(dest):
616
- print(f"[INFO] Fetching notmax123/blue/{fn} …")
617
  cached = hf_hub_download(
618
- repo_id="notmax123/blue", filename=fn, repo_type="model",
619
  token=os.environ.get("HF_TOKEN") or None,
620
  )
621
  shutil.copy2(cached, dest)
@@ -695,6 +714,12 @@ def synthesize_text(text: str, voice_source: str, voice: str, lang: str, steps:
695
  err = f'<div class="stats-bar"><span class="stat-pill">❌ voice clone failed: {e}</span></div>'
696
  return None, err
697
  else:
 
 
 
 
 
 
698
  style = VOICE_STYLES[voice]
699
  wav, sr = TTS.synthesize(
700
  expand_numbers(text, lang=lang), lang=lang, style=style,
@@ -726,10 +751,10 @@ PT_WEIGHTS_SEARCH = [
726
  "pt_weights",
727
  ]
728
  PT_WEIGHT_ALIASES: dict[str, list[str]] = {
729
- "ae_ckpt": ["blue_codec.safetensors", "blue_codec.pt"],
730
- "ttl_ckpt": ["vf_estimetor.pt", "vf_estimator.pt"],
731
- "dp_ckpt": ["duration_predictor_final.pt", "duration_predictor.pt"],
732
- "stats": ["stats_multilingual.pt", "stats.pt"],
733
  }
734
 
735
 
@@ -870,7 +895,7 @@ with gr.Blocks(title="BlueTTS — Multilingual TTS") as demo:
870
  with gr.Column(elem_classes="ref-panel"):
871
  voice_source_input = gr.Radio(
872
  choices=[("Saved voice", "saved"), ("Uploaded reference", "upload")],
873
- value="saved",
874
  label="Voice source",
875
  )
876
  ref_wav_input = gr.Audio(
 
20
  import gradio as gr
21
  import onnxruntime as ort
22
 
23
+ from download_models import BLUE_REPO, download_blue_models, download_default_voices, download_renikud
24
 
25
  # ------------------------------------------------------------------
26
  # Paths
 
42
  def _needs_download() -> bool:
43
  required = ["text_encoder.onnx", "vector_estimator.onnx", "vocoder.onnx",
44
  "duration_predictor.onnx"]
45
+ repo_marker = os.path.join(ONNX_DIR, ".repo_id")
46
+ if not os.path.exists(repo_marker):
47
+ return True
48
+ with open(repo_marker) as f:
49
+ if f.read().strip() != BLUE_REPO:
50
+ return True
51
  for fn in required:
52
  p = os.path.join(ONNX_DIR, fn)
53
  if not os.path.exists(p) or os.path.getsize(p) < 1000:
 
578
  def discover_voices() -> Dict[str, str]:
579
  out: Dict[str, str] = {}
580
  for p in sorted(glob.glob(os.path.join(VOICES_DIR, "*.json"))):
581
+ try:
582
+ with open(p) as f:
583
+ payload = json.load(f)
584
+ ttl = payload.get("style_ttl")
585
+ if ttl:
586
+ arr = np.array(ttl["data"], dtype=np.float32)
587
+ if float(arr.std()) > 0.3:
588
+ print(f"[INFO] Skipping incompatible voice JSON {p} (style_ttl std={arr.std():.3f})")
589
+ continue
590
+ except Exception as e:
591
+ print(f"[WARN] Skipping unreadable voice JSON {p}: {e}")
592
+ continue
593
  label = os.path.splitext(os.path.basename(p))[0]
594
  pretty = label.replace("_", " ").replace("spk ", "Speaker ").title()
595
  out[pretty] = p
596
  return out
597
 
598
 
599
+ VOICES: Dict[str, str] = discover_voices()
600
  VOICE_STYLES: Dict[str, Style] = {name: load_voice_style([path]) for name, path in VOICES.items()}
601
 
602
 
 
621
 
622
 
623
  def _ensure_pt_weights() -> dict[str, str]:
624
+ """Make sure v2 PyTorch/safetensors checkpoints are on disk."""
625
  needed: dict[str, Optional[str]] = {k: _find_pt_weight(v) for k, v in PT_WEIGHT_ALIASES.items()}
626
  if any(v is None for v in needed.values()):
627
  from huggingface_hub import hf_hub_download
628
  import shutil
629
  os.makedirs("pt_weights", exist_ok=True)
630
+ repo_id = os.environ.get("BLUE_PT_REPO", "notmax123/blue-v2")
631
+ for fn in ("blue_codec.safetensors", "duration_predictor_final.safetensors",
632
+ "vf_estimetor.safetensors", "stats_multilingual.safetensors"):
633
  dest = os.path.join("pt_weights", fn)
634
  if not os.path.exists(dest):
635
+ print(f"[INFO] Fetching {repo_id}/{fn} …")
636
  cached = hf_hub_download(
637
+ repo_id=repo_id, filename=fn, repo_type="model",
638
  token=os.environ.get("HF_TOKEN") or None,
639
  )
640
  shutil.copy2(cached, dest)
 
714
  err = f'<div class="stats-bar"><span class="stat-pill">❌ voice clone failed: {e}</span></div>'
715
  return None, err
716
  else:
717
+ if not VOICE_STYLES:
718
+ err = (
719
+ '<div class="stats-bar"><span class="stat-pill">'
720
+ 'No saved v2 voices are installed. Choose "Uploaded reference" and upload audio.</span></div>'
721
+ )
722
+ return None, err
723
  style = VOICE_STYLES[voice]
724
  wav, sr = TTS.synthesize(
725
  expand_numbers(text, lang=lang), lang=lang, style=style,
 
751
  "pt_weights",
752
  ]
753
  PT_WEIGHT_ALIASES: dict[str, list[str]] = {
754
+ "ae_ckpt": ["blue_codec.safetensors"],
755
+ "ttl_ckpt": ["vf_estimetor.safetensors"],
756
+ "dp_ckpt": ["duration_predictor_final.safetensors"],
757
+ "stats": ["stats_multilingual.safetensors"],
758
  }
759
 
760
 
 
895
  with gr.Column(elem_classes="ref-panel"):
896
  voice_source_input = gr.Radio(
897
  choices=[("Saved voice", "saved"), ("Uploaded reference", "upload")],
898
+ value="saved" if VOICE_STYLES else "upload",
899
  label="Voice source",
900
  )
901
  ref_wav_input = gr.Audio(
download_models.py CHANGED
@@ -1,35 +1,30 @@
1
- """Download the slim BlueTTS ONNX bundle + a couple of sample voices."""
2
  import os
3
  import shutil
4
- from huggingface_hub import hf_hub_download, list_repo_files
5
 
6
- BLUE_REPO = "notmax123/blue-onnx"
7
  RENIKUD_REPO = "thewh1teagle/renikud"
8
 
9
- # Core slim bundle: 4 ONNX files + tts config.
10
  BLUE_FILES = [
11
  "text_encoder.onnx",
12
  "vector_estimator.onnx",
13
  "vocoder.onnx",
14
  "duration_predictor.onnx",
15
- "tts.json",
16
  ]
17
 
18
- # Default voices fetched for the UI. Users can drop additional voice JSONs into
19
- # the same directory (e.g. by exporting with ``export_new_voice.py``) and they
20
- # will be picked up automatically.
21
- DEFAULT_VOICES: dict[str, str] = {
22
- "Female": "voices/all_voices/female/spk_00014.json",
23
- "Male": "voices/all_voices/male/spk_00017.json",
24
- }
25
 
26
 
27
  def _is_valid(path: str, min_bytes: int = 100) -> bool:
28
  return os.path.exists(path) and os.path.getsize(path) >= min_bytes
29
 
30
 
31
- def _fetch(repo_id: str, filename: str, dest: str, min_bytes: int = 100) -> None:
32
- if _is_valid(dest, min_bytes):
33
  print(f"Already present: {dest} ({os.path.getsize(dest):,} bytes)")
34
  return
35
  os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
@@ -44,14 +39,13 @@ def _fetch(repo_id: str, filename: str, dest: str, min_bytes: int = 100) -> None
44
 
45
  def download_blue_models(dest_dir: str = "onnx_slim") -> None:
46
  os.makedirs(dest_dir, exist_ok=True)
 
 
47
  for filename in BLUE_FILES:
48
  dest = os.path.join(dest_dir, filename)
49
- try:
50
- _fetch(BLUE_REPO, filename, dest, min_bytes=100)
51
- except Exception as e:
52
- print(f" FAILED {filename}: {e}")
53
- if filename.endswith(".onnx"):
54
- raise
55
 
56
 
57
  def download_default_voices(dest_dir: str = "voices") -> dict[str, str]:
 
1
+ """Download the slim BlueTTS ONNX bundle + matching sample voices."""
2
  import os
3
  import shutil
4
+ from huggingface_hub import hf_hub_download
5
 
6
+ BLUE_REPO = os.environ.get("BLUE_ONNX_REPO", "notmax123/blue-onnx-v2")
7
  RENIKUD_REPO = "thewh1teagle/renikud"
8
 
9
+ # Core slim bundle. Config is kept in the Space repo as root tts.json.
10
  BLUE_FILES = [
11
  "text_encoder.onnx",
12
  "vector_estimator.onnx",
13
  "vocoder.onnx",
14
  "duration_predictor.onnx",
 
15
  ]
16
 
17
+ # Users can drop matching v2 voice JSONs into ./voices. The v2 ONNX repo does
18
+ # not ship default voices, and old v1 voices are not compatible.
19
+ DEFAULT_VOICES: dict[str, str] = {}
 
 
 
 
20
 
21
 
22
  def _is_valid(path: str, min_bytes: int = 100) -> bool:
23
  return os.path.exists(path) and os.path.getsize(path) >= min_bytes
24
 
25
 
26
+ def _fetch(repo_id: str, filename: str, dest: str, min_bytes: int = 100, *, force: bool = False) -> None:
27
+ if not force and _is_valid(dest, min_bytes):
28
  print(f"Already present: {dest} ({os.path.getsize(dest):,} bytes)")
29
  return
30
  os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
 
39
 
40
  def download_blue_models(dest_dir: str = "onnx_slim") -> None:
41
  os.makedirs(dest_dir, exist_ok=True)
42
+ marker = os.path.join(dest_dir, ".repo_id")
43
+ force = not os.path.exists(marker) or open(marker).read().strip() != BLUE_REPO
44
  for filename in BLUE_FILES:
45
  dest = os.path.join(dest_dir, filename)
46
+ _fetch(BLUE_REPO, filename, dest, min_bytes=100, force=force)
47
+ with open(marker, "w") as f:
48
+ f.write(BLUE_REPO + "\n")
 
 
 
49
 
50
 
51
  def download_default_voices(dest_dir: str = "voices") -> dict[str, str]:
export_new_voice.py CHANGED
@@ -5,17 +5,17 @@ Build a *voice style* JSON for Blue (BlueTTS) from one reference WAV.
5
 
6
  See repo README for usage. Requires the BlueTTS training codebase on
7
  ``PYTHONPATH`` and the PyTorch checkpoints (``blue_codec.safetensors``,
8
- ``vf_estimator.safetensors``, ``duration_predictor.safetensors``,
9
- ``stats_multilingual.pt``).
10
 
11
  PYTHONPATH=training uv run python export_new_voice.py \
12
  --ref_wav /path/to/ref.wav \
13
  --out voices/mine.json \
14
  --config tts.json \
15
  --ae_ckpt pt_weights/blue_codec.safetensors \
16
- --ttl_ckpt pt_weights/vf_estimator.safetensors \
17
- --dp_ckpt pt_weights/duration_predictor.safetensors \
18
- --stats pt_weights/stats_multilingual.pt
19
  """
20
  from __future__ import annotations
21
 
@@ -40,12 +40,12 @@ if _TRAINING not in sys.path:
40
  from bluecodec.autoencoder.latent_encoder import LatentEncoder # noqa: E402
41
  from models.utils import LinearMelSpectrogram, compress_latents, load_ttl_config # noqa: E402
42
 
43
- HF_REPO_ID = "notmax123/blue"
44
  HF_WEIGHT_SIZES: dict[str, int] = {
45
  "blue_codec.safetensors": 245_114_104,
46
- "duration_predictor.safetensors": 2_040_512,
47
- "stats_multilingual.pt": 3_133,
48
- "vf_estimator.safetensors": 179_313_224,
49
  }
50
 
51
 
@@ -93,7 +93,7 @@ def load_stats(device: str, preferred: str, fallback: str = "stats.pt"):
93
  stats_path = preferred if os.path.exists(preferred) else fallback
94
  if not os.path.exists(stats_path):
95
  raise FileNotFoundError(f"Missing stats file: tried {preferred} and {fallback}")
96
- stats = torch.load(stats_path, map_location=device, weights_only=False)
97
  mean = stats["mean"].to(device).view(1, -1, 1)
98
  std = stats["std"].to(device).view(1, -1, 1)
99
  return mean, std, stats_path
@@ -144,9 +144,9 @@ def export_voice_style(
144
  *,
145
  config: str = "tts.json",
146
  ae_ckpt: str = "blue_codec.safetensors",
147
- ttl_ckpt: str = "vf_estimator.safetensors",
148
- dp_ckpt: str = "duration_predictor.safetensors",
149
- stats: str = "stats_multilingual.pt",
150
  device: str = "cpu",
151
  out_pt: str | None = None,
152
  verify_hf_sizes_flag: bool = False,
@@ -320,9 +320,9 @@ def main() -> None:
320
  ap.add_argument("--out", type=str, default="voice.json")
321
  ap.add_argument("--out_pt", type=str, default=None)
322
  ap.add_argument("--ae_ckpt", type=str, default="blue_codec.safetensors")
323
- ap.add_argument("--stats", type=str, default="stats_multilingual.pt")
324
- ap.add_argument("--ttl_ckpt", type=str, default="vf_estimator.safetensors")
325
- ap.add_argument("--dp_ckpt", type=str, default="duration_predictor.safetensors")
326
  ap.add_argument("--verify_hf_sizes", action="store_true")
327
  ap.add_argument("--device", type=str, default="cpu")
328
  ap.add_argument("--config", type=str, default="tts.json")
 
5
 
6
  See repo README for usage. Requires the BlueTTS training codebase on
7
  ``PYTHONPATH`` and the PyTorch checkpoints (``blue_codec.safetensors``,
8
+ ``vf_estimetor.safetensors``, ``duration_predictor_final.safetensors``,
9
+ ``stats_multilingual.safetensors``).
10
 
11
  PYTHONPATH=training uv run python export_new_voice.py \
12
  --ref_wav /path/to/ref.wav \
13
  --out voices/mine.json \
14
  --config tts.json \
15
  --ae_ckpt pt_weights/blue_codec.safetensors \
16
+ --ttl_ckpt pt_weights/vf_estimetor.safetensors \
17
+ --dp_ckpt pt_weights/duration_predictor_final.safetensors \
18
+ --stats pt_weights/stats_multilingual.safetensors
19
  """
20
  from __future__ import annotations
21
 
 
40
  from bluecodec.autoencoder.latent_encoder import LatentEncoder # noqa: E402
41
  from models.utils import LinearMelSpectrogram, compress_latents, load_ttl_config # noqa: E402
42
 
43
+ HF_REPO_ID = "notmax123/blue-v2"
44
  HF_WEIGHT_SIZES: dict[str, int] = {
45
  "blue_codec.safetensors": 245_114_104,
46
+ "duration_predictor_final.safetensors": 2_040_744,
47
+ "stats_multilingual.safetensors": 1_416,
48
+ "vf_estimetor.safetensors": 174_487_392,
49
  }
50
 
51
 
 
93
  stats_path = preferred if os.path.exists(preferred) else fallback
94
  if not os.path.exists(stats_path):
95
  raise FileNotFoundError(f"Missing stats file: tried {preferred} and {fallback}")
96
+ stats = load_torch_or_safetensors(stats_path, map_location=device)
97
  mean = stats["mean"].to(device).view(1, -1, 1)
98
  std = stats["std"].to(device).view(1, -1, 1)
99
  return mean, std, stats_path
 
144
  *,
145
  config: str = "tts.json",
146
  ae_ckpt: str = "blue_codec.safetensors",
147
+ ttl_ckpt: str = "vf_estimetor.safetensors",
148
+ dp_ckpt: str = "duration_predictor_final.safetensors",
149
+ stats: str = "stats_multilingual.safetensors",
150
  device: str = "cpu",
151
  out_pt: str | None = None,
152
  verify_hf_sizes_flag: bool = False,
 
320
  ap.add_argument("--out", type=str, default="voice.json")
321
  ap.add_argument("--out_pt", type=str, default=None)
322
  ap.add_argument("--ae_ckpt", type=str, default="blue_codec.safetensors")
323
+ ap.add_argument("--stats", type=str, default="stats_multilingual.safetensors")
324
+ ap.add_argument("--ttl_ckpt", type=str, default="vf_estimetor.safetensors")
325
+ ap.add_argument("--dp_ckpt", type=str, default="duration_predictor_final.safetensors")
326
  ap.add_argument("--verify_hf_sizes", action="store_true")
327
  ap.add_argument("--device", type=str, default="cpu")
328
  ap.add_argument("--config", type=str, default="tts.json")