CherithCutestory commited on
Commit
d6d700f
·
1 Parent(s): 126d204

Fixed some of the style tts loading issues

Browse files
Files changed (1) hide show
  1. app.py +34 -5
app.py CHANGED
@@ -12,24 +12,53 @@ def ensure_styletts2():
12
  try:
13
  import styletts2 # noqa: F401
14
  return
15
- except Exception:
16
  pass
17
 
18
  subprocess.check_call([
19
- sys.executable, "-m", "pip", "install", "--no-cache-dir", "--no-deps", "styletts2==0.1.6"
20
  ])
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  ensure_styletts2()
23
 
24
  import io
25
  import uuid
26
  import soundfile as sf
27
  import gradio as gr
28
-
29
  import torch
30
 
31
- from styletts2 import StyleTTS2
32
-
33
 
34
  # ---------------------------
35
  # Global config
 
12
  try:
13
  import styletts2 # noqa: F401
14
  return
15
+ except ModuleNotFoundError:
16
  pass
17
 
18
  subprocess.check_call([
19
+ sys.executable, "-m", "pip", "install", "--upgrade", "--no-cache-dir", "--no-deps", "styletts2==0.1.6"
20
  ])
21
 
22
+ def import_styletts2_class():
23
+ """
24
+ styletts2 PyPI package doesn't export StyleTTS2 at top-level.
25
+ Try a few known module locations and return the class/callable.
26
+ """
27
+ import importlib
28
+
29
+ # Try common locations seen in forks / packaged builds
30
+ candidates = [
31
+ ("styletts2", "StyleTTS2"),
32
+ ("styletts2.model", "StyleTTS2"),
33
+ ("styletts2.styletts2", "StyleTTS2"),
34
+ ("styletts2.api", "StyleTTS2"),
35
+ ]
36
+
37
+ for mod_name, attr in candidates:
38
+ try:
39
+ mod = importlib.import_module(mod_name)
40
+ if hasattr(mod, attr):
41
+ return getattr(mod, attr)
42
+ except Exception:
43
+ pass
44
+
45
+ # If none worked, print what's actually inside and fail loudly
46
+ import styletts2
47
+ raise ImportError(
48
+ "Could not locate StyleTTS2 class. "
49
+ f"styletts2 package loaded from: {getattr(styletts2, '__file__', 'unknown')}. "
50
+ f"Available attrs: {sorted([a for a in dir(styletts2) if not a.startswith('_')])}"
51
+ )
52
+
53
  ensure_styletts2()
54
 
55
  import io
56
  import uuid
57
  import soundfile as sf
58
  import gradio as gr
 
59
  import torch
60
 
61
+ StyleTTS2 = import_styletts2_class()
 
62
 
63
  # ---------------------------
64
  # Global config