Approximetal committed on
Commit
39c9309
·
verified ·
1 Parent(s): d56376a

Update gradio_mix.py

Browse files
Files changed (1) hide show
  1. gradio_mix.py +49 -18
gradio_mix.py CHANGED
@@ -15,6 +15,7 @@ import jieba, zhconv
15
  from pypinyin.core import Pinyin
16
  from pypinyin import Style
17
 
 
18
  from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
19
  from lemas_tts.infer.edit_multilingual import gen_wav_multilingual
20
  from lemas_tts.infer.text_norm.txt2pinyin import (
@@ -46,21 +47,18 @@ DEMO_PATH = os.getenv("DEMO_PATH", "./pretrained_models/demo")
46
  TMP_PATH = os.getenv("TMP_PATH", "./pretrained_models/demo/temp")
47
  MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
48
 
49
- # Pick device for the TTS editing model. By default we try CUDA, but fall
50
- # back to CPU if the CUDA stack is not actually usable (e.g. kernel image
51
- # mismatch on older GPUs). You can override via LEMAS_DEVICE env (e.g. "cpu"
52
- # or "cuda").
 
 
 
53
  def _pick_device():
54
  forced = os.getenv("LEMAS_DEVICE")
55
  if forced:
56
  return forced
57
- if torch.cuda.is_available():
58
- try:
59
- torch.zeros(1).to("cuda")
60
- return "cuda"
61
- except Exception as e:
62
- logging.warning("CUDA appears available but failed (%s); falling back to CPU.", e)
63
- return "cpu"
64
 
65
  device = _pick_device()
66
  ASR_DEVICE = "cpu" # force whisperx/pyannote to CPU to avoid cuDNN issues
@@ -355,10 +353,8 @@ class MMSAlignModel:
355
  def __init__(self):
356
  from torchaudio.pipelines import MMS_FA as bundle
357
  self.mms_model = bundle.get_model()
358
- # MMS forced alignment is relatively light; keep it on CPU to avoid
359
- # CUDA kernel / arch mismatches on environments where the main TTS
360
- # model still uses GPU.
361
- self.mms_model.to("cpu")
362
  self.mms_tokenizer = bundle.get_tokenizer()
363
  self.mms_aligner = bundle.get_aligner()
364
  self.text_normalizer = ur.Uroman()
@@ -380,7 +376,7 @@ class MMSAlignModel:
380
 
381
  def compute_alignments(self, waveform: torch.Tensor, tokens):
382
  with torch.inference_mode():
383
- emission, _ = self.mms_model(waveform.to("cpu"))
384
  token_spans = self.mms_aligner(emission[0], tokens)
385
  return emission, token_spans
386
 
@@ -399,7 +395,7 @@ class MMSAlignModel:
399
  assert len(text_normed) == len(raw_text), f"normalized text len != raw text len: {len(text_normed)} != {len(raw_text)}"
400
  tokens = self.mms_tokenizer(text_normed)
401
  with torch.inference_mode():
402
- emission, _ = self.mms_model(waveform.to("cpu"))
403
  token_spans = self.mms_aligner(emission[0], tokens)
404
  num_frames = emission.size(1)
405
  ratio = waveform.size(1) / num_frames
@@ -562,12 +558,41 @@ def load_models(lemas_model_name, whisper_model_name, alignment_model_name, deno
562
  # Load LEMAS-TTS editing model (selected multilingual variant)
563
  from pathlib import Path
564
 
 
565
  ckpt_dir = Path(CKPTS_ROOT) / lemas_model_name
566
  ckpt_candidates = sorted(
567
  list(ckpt_dir.glob("*.safetensors")) + list(ckpt_dir.glob("*.pt"))
568
  )
 
 
 
 
 
 
 
 
 
 
 
569
  if not ckpt_candidates:
570
- raise gr.Error(f"No LEMAS-TTS ckpt found under {ckpt_dir}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  ckpt_file = str(ckpt_candidates[-1])
572
 
573
  vocab_file = Path(PRETRAINED_ROOT) / "data" / lemas_model_name / "vocab.txt"
@@ -1201,6 +1226,12 @@ if __name__ == "__main__":
1201
  parser.add_argument("--port", default=41020, type=int, help="App port")
1202
  parser.add_argument("--share", action="store_true", help="Launch with public url")
1203
  parser.add_argument("--server_name", default="0.0.0.0", type=str, help="Server name for launching the app. 127.0.0.1 for localhost; 0.0.0.0 to allow access from other machines in the local network. Might also give access to external users depends on the firewall settings.")
 
 
 
 
 
 
1204
 
1205
  os.environ["USER"] = os.getenv("USER", "user")
1206
  args = parser.parse_args()
 
15
  from pypinyin.core import Pinyin
16
  from pypinyin import Style
17
 
18
+ from cached_path import cached_path
19
  from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
20
  from lemas_tts.infer.edit_multilingual import gen_wav_multilingual
21
  from lemas_tts.infer.text_norm.txt2pinyin import (
 
47
  TMP_PATH = os.getenv("TMP_PATH", "./pretrained_models/demo/temp")
48
  MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
49
 
50
+ # HF location for large TTS checkpoints (too big for Space storage).
51
+ # Mirrors LEMAS-TTS `inference_gradio.py`.
52
+ HF_PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/pretrained_models"
53
+
54
+ # Pick device for the TTS editing model.
55
+ # - Default: "cuda" if available, else "cpu"
56
+ # - You can override via LEMAS_DEVICE env (e.g. "cpu" or "cuda").
57
  def _pick_device():
58
  forced = os.getenv("LEMAS_DEVICE")
59
  if forced:
60
  return forced
61
+ return "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
62
 
63
  device = _pick_device()
64
  ASR_DEVICE = "cpu" # force whisperx/pyannote to CPU to avoid cuDNN issues
 
353
  def __init__(self):
354
  from torchaudio.pipelines import MMS_FA as bundle
355
  self.mms_model = bundle.get_model()
356
+ # Keep MMS on the same device as the main edit model unless overridden.
357
+ self.mms_model.to(device)
 
 
358
  self.mms_tokenizer = bundle.get_tokenizer()
359
  self.mms_aligner = bundle.get_aligner()
360
  self.text_normalizer = ur.Uroman()
 
376
 
377
  def compute_alignments(self, waveform: torch.Tensor, tokens):
378
  with torch.inference_mode():
379
+ emission, _ = self.mms_model(waveform.to(device))
380
  token_spans = self.mms_aligner(emission[0], tokens)
381
  return emission, token_spans
382
 
 
395
  assert len(text_normed) == len(raw_text), f"normalized text len != raw text len: {len(text_normed)} != {len(raw_text)}"
396
  tokens = self.mms_tokenizer(text_normed)
397
  with torch.inference_mode():
398
+ emission, _ = self.mms_model(waveform.to(device))
399
  token_spans = self.mms_aligner(emission[0], tokens)
400
  num_frames = emission.size(1)
401
  ratio = waveform.size(1) / num_frames
 
558
  # Load LEMAS-TTS editing model (selected multilingual variant)
559
  from pathlib import Path
560
 
561
+ # Local ckpt search under the standard CKPTS_ROOT layout
562
  ckpt_dir = Path(CKPTS_ROOT) / lemas_model_name
563
  ckpt_candidates = sorted(
564
  list(ckpt_dir.glob("*.safetensors")) + list(ckpt_dir.glob("*.pt"))
565
  )
566
+ # Fallbacks for simpler layouts: allow ckpts directly under CKPTS_ROOT,
567
+ # e.g. ./pretrained_models/ckpts/multilingual_grl.safetensors
568
+ if not ckpt_candidates:
569
+ root_candidates = sorted(
570
+ list(Path(CKPTS_ROOT).glob(f"{lemas_model_name}*.safetensors"))
571
+ + list(Path(CKPTS_ROOT).glob(f"{lemas_model_name}*.pt"))
572
+ )
573
+ ckpt_candidates = root_candidates
574
+
575
+ # If no local ckpt is found, fall back to remote HF checkpoints
576
+ # (using the same mapping as LEMAS-TTS `inference_gradio.py`).
577
  if not ckpt_candidates:
578
+ remote_ckpts = {
579
+ "multilingual_grl": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_grl/multilingual_grl.safetensors",
580
+ "multilingual_prosody": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_prosody/multilingual_prosody.safetensors",
581
+ }
582
+ remote_path = remote_ckpts.get(lemas_model_name)
583
+ if remote_path is not None:
584
+ try:
585
+ resolved = cached_path(remote_path)
586
+ ckpt_candidates = [Path(resolved)]
587
+ logging.info("Resolved remote ckpt %s -> %s", remote_path, resolved)
588
+ except Exception as e:
589
+ raise gr.Error(f"Failed to download remote ckpt {remote_path}: {e}")
590
+
591
+ if not ckpt_candidates:
592
+ raise gr.Error(
593
+ f"No LEMAS-TTS ckpt found for '{lemas_model_name}' under {ckpt_dir} "
594
+ f"or {CKPTS_ROOT}"
595
+ )
596
  ckpt_file = str(ckpt_candidates[-1])
597
 
598
  vocab_file = Path(PRETRAINED_ROOT) / "data" / lemas_model_name / "vocab.txt"
 
1226
  parser.add_argument("--port", default=41020, type=int, help="App port")
1227
  parser.add_argument("--share", action="store_true", help="Launch with public url")
1228
  parser.add_argument("--server_name", default="0.0.0.0", type=str, help="Server name for launching the app. 127.0.0.1 for localhost; 0.0.0.0 to allow access from other machines in the local network. Might also give access to external users depends on the firewall settings.")
1229
+ parser.add_argument(
1230
+ "--models-path",
1231
+ default="./pretrained_models",
1232
+ dest="models_path",
1233
+ help="Path to pretrained_models root (mirrors LEMAS-TTS layout).",
1234
+ )
1235
 
1236
  os.environ["USER"] = os.getenv("USER", "user")
1237
  args = parser.parse_args()