Approximetal commited on
Commit
dff3185
·
verified ·
1 Parent(s): 2a1e401

Update lemas_tts/api.py

Browse files
Files changed (1) hide show
  1. lemas_tts/api.py +41 -5
lemas_tts/api.py CHANGED
@@ -1,10 +1,10 @@
 
1
  import random
2
  import sys
3
  from pathlib import Path
4
  import re, regex
5
  import soundfile as sf
6
  import tqdm
7
- from cached_path import cached_path
8
  from hydra.utils import get_class
9
  from omegaconf import OmegaConf
10
 
@@ -36,11 +36,47 @@ def _find_repo_root(start: Path) -> Path:
36
  return start
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  REPO_ROOT = _find_repo_root(THIS_FILE)
40
- # Local pretrained root (used when running from a repo / Space that bundles weights)
41
- PRETRAINED_ROOT = REPO_ROOT / "pretrained_models"
42
- # Remote pretrained root on Hugging Face Hub (fallback when local files are absent)
43
- HF_PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/pretrained_models"
44
  CKPTS_ROOT = PRETRAINED_ROOT / "ckpts"
45
 
46
  class TTS:
 
1
+ import os
2
  import random
3
  import sys
4
  from pathlib import Path
5
  import re, regex
6
  import soundfile as sf
7
  import tqdm
 
8
  from hydra.utils import get_class
9
  from omegaconf import OmegaConf
10
 
 
36
  return start
37
 
38
 
39
+ def _find_pretrained_root(start: Path) -> Path:
40
+ """
41
+ Locate the `pretrained_models` root, with support for:
42
+ 1) Explicit env override (LEMAS_PRETRAINED_ROOT)
43
+ 2) Hugging Face Spaces model mount under /models
44
+ 3) Local source tree (searching upwards from this file)
45
+ """
46
+ # 1) Explicit override
47
+ env_root = os.environ.get("LEMAS_PRETRAINED_ROOT")
48
+ if env_root:
49
+ p = Path(env_root)
50
+ if p.is_dir():
51
+ return p
52
+
53
+ # 2) HF Spaces model mount: /models/<model_id>/pretrained_models
54
+ models_dir = Path("/models")
55
+ if models_dir.is_dir():
56
+ # Try the expected model name first
57
+ specific = models_dir / "LEMAS-Project__LEMAS-TTS"
58
+ if (specific / "pretrained_models").is_dir():
59
+ return specific / "pretrained_models"
60
+ # Otherwise, pick the first model that has a pretrained_models subdir
61
+ for child in models_dir.iterdir():
62
+ if child.is_dir() and (child / "pretrained_models").is_dir():
63
+ return child / "pretrained_models"
64
+
65
+ # 3) Local repo layout
66
+ repo_root = _find_repo_root(start)
67
+ if (repo_root / "pretrained_models").is_dir():
68
+ return repo_root / "pretrained_models"
69
+
70
+ cwd = Path.cwd()
71
+ if (cwd / "pretrained_models").is_dir():
72
+ return cwd / "pretrained_models"
73
+
74
+ # Fallback: assume under repo root even if directory is missing
75
+ return repo_root / "pretrained_models"
76
+
77
+
78
  REPO_ROOT = _find_repo_root(THIS_FILE)
79
+ PRETRAINED_ROOT = _find_pretrained_root(THIS_FILE)
 
 
 
80
  CKPTS_ROOT = PRETRAINED_ROOT / "ckpts"
81
 
82
  class TTS: