mkfallah commited on
Commit
7791cdf
·
verified ·
1 Parent(s): 580b4a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -38
app.py CHANGED
@@ -1,47 +1,188 @@
1
  # app.py
2
- # simple gradio space for Persian TTS using kamtera/persian-tts-female-vits (coqui tts)
3
- # loads model by first downloading the HuggingFace repo to a local folder,
4
- # then passes the local path to TTS to avoid Coqui's "model_name parsing" error.
5
 
6
  import os
 
7
  import tempfile
 
 
 
 
 
8
  from hazm import Normalizer
 
9
  from TTS.api import TTS
10
  import gradio as gr
11
 
12
- # add huggingface_hub to requirements and import here
13
- from huggingface_hub import snapshot_download
14
 
15
- # -------------------------
16
- # configuration
17
- HF_REPO_ID = "Kamtera/persian-tts-female-vits" # huggingface repo id
18
- HF_TOKEN = os.environ.get("HF_TOKEN", None) # optional token for private models
19
- MAX_INPUT_LENGTH = 1200 # safety limit for long text
20
- # -------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  normalizer = Normalizer()
23
 
24
- # download the HuggingFace repo to a local folder (cached by HF Hub)
25
- print("downloading model repo from huggingface:", HF_REPO_ID)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  try:
 
27
  local_model_dir = snapshot_download(repo_id=HF_REPO_ID, use_auth_token=HF_TOKEN)
28
- print("model downloaded to:", local_model_dir)
29
  except Exception as e:
30
- print("error while downloading model repo:", e)
 
31
  local_model_dir = None
32
 
33
  if local_model_dir is None:
34
- raise RuntimeError("failed to download model repo. set HF_TOKEN if repo is private or check repo id.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # now load model from local dir (coqui expects either a coqui id or a local path)
37
- print("loading tts model from local folder:", local_model_dir)
38
- tts = TTS(model_name=local_model_dir, progress_bar=False, gpu=False)
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def synthesize(text: str):
41
- """
42
- text: Persian text input
43
- returns: tuple(output_path_or_none, status_message)
44
- """
45
  if not text or not text.strip():
46
  return None, "please enter some text."
47
 
@@ -49,35 +190,29 @@ def synthesize(text: str):
49
  text = text[:MAX_INPUT_LENGTH] + "."
50
 
51
  text = normalizer.normalize(text)
52
-
53
  out_fd, out_path = tempfile.mkstemp(suffix=".wav")
54
  os.close(out_fd)
55
 
56
  try:
57
  tts.tts_to_file(text=text, file_path=out_path)
58
  except Exception as e:
59
- print("tts generation error:", e)
60
- return None, f"error: {e}"
 
61
 
62
  return out_path, "speech generated successfully."
63
 
64
- # gradio ui
65
- with gr.Blocks(css=".gradio-container {background-color: #fafafa}") as demo:
66
- gr.Markdown("## persian tts — kamtera / persian-tts-female-vits")
67
- text_input = gr.Textbox(
68
- label="persian text (max ~1200 chars)",
69
- lines=6,
70
- placeholder="enter your Persian text here..."
71
- )
72
  generate_btn = gr.Button("generate speech")
73
  audio_output = gr.Audio(label="output audio", type="filepath")
74
  status = gr.Markdown("")
75
-
76
  def run_tts(text):
77
  audio_path, msg = synthesize(text)
78
  return audio_path, msg
79
-
80
  generate_btn.click(fn=run_tts, inputs=text_input, outputs=[audio_output, status])
81
 
82
- if __name__ == "__main__":
83
- demo.launch()
 
1
  # app.py
2
+ # debug-friendly gradio space entrypoint for persian tts
3
+ # this script prints environment info, lists repo files, logs to /tmp/startup.log
4
+ # comments are english and start with lowercase
5
 
6
  import os
7
+ import sys
8
  import tempfile
9
+ import glob
10
+ import traceback
11
+ from typing import Optional, Tuple
12
+
13
+ # external libs
14
  from hazm import Normalizer
15
+ from huggingface_hub import snapshot_download
16
  from TTS.api import TTS
17
  import gradio as gr
18
 
19
LOG_PATH = "/tmp/startup.log"


def log(msg: str, flush: bool = True):
    """print *msg* with a "[startup]" prefix and append it to LOG_PATH.

    best-effort by design: failures while writing the log file or flushing
    stdout are swallowed so that logging can never crash startup.
    """
    line = f"[startup] {msg}"
    print(line)
    try:
        with open(LOG_PATH, "a", encoding="utf-8") as handle:
            handle.write(line + "\n")
    except Exception:
        pass
    if flush:
        try:
            sys.stdout.flush()
        except Exception:
            pass
35
+
36
# clear previous log so each startup produces a fresh trace
try:
    open(LOG_PATH, "w").close()
except Exception:
    pass

log("starting app - debug mode enabled")
log(f"python executable: {sys.executable}")
log(f"python version: {sys.version.replace(chr(10), ' ')}")
log(f"cwd: {os.getcwd()}")
log("environment variables (selected):")
# security fix: report only whether secret tokens are set -- the previous
# version wrote the raw token values to stdout and /tmp/startup.log
for k in ["HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "CUDA_VISIBLE_DEVICES", "PYTHONPATH"]:
    v = os.environ.get(k)
    if k.endswith("TOKEN") and v:
        v = "<set, redacted>"
    log(f" {k}={v}")

# list repo root files (first-level) to help debugging missing files
try:
    root_files = os.listdir(".")
    log("files in repo root (first 100 entries):")
    for name in root_files[:100]:
        log(f" - {name}")
except Exception as e:
    log(f"error listing repo root: {e}")

# basic config (edit as needed)
HF_REPO_ID = "Kamtera/persian-tts-female-vits"
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
MAX_INPUT_LENGTH = 1200  # safety cap on input characters passed to synthesis
63
 
64
  normalizer = Normalizer()
65
 
66
def find_model_files(model_dir: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """try to discover model and config files under model_dir.

    returns (model_path, config_path, vocoder_path, vocoder_config_path);
    any entry may be None when no matching file exists.
    """
    model_patterns = ["**/model.pth", "**/model.pt", "**/*.pth", "**/*.pt"]
    config_patterns = ["**/config.json", "**/model_config.json", "**/config*.json"]
    # bug fix: the old list ended with a catch-all "**/*.pth", which re-matched
    # the model checkpoint itself and wrongly passed it to TTS as a vocoder
    vocoder_patterns = ["**/vocoder.pth", "**/vocoder.pt", "**/hifi-gan*.pth"]
    vocoder_config_patterns = ["**/vocoder_config.json", "**/vocoder-config.json", "**/*vocoder*.json"]

    def glob_first(root, patterns, exclude=()):
        # return the shallowest (then lexicographically first) match for the
        # first pattern that hits anything, skipping excluded paths
        for pat in patterns:
            matches = [m for m in glob.glob(os.path.join(root, pat), recursive=True)
                       if m not in exclude]
            if matches:
                matches.sort(key=lambda p: (len(p.split(os.sep)), p))
                return matches[0]
        return None

    model_path = glob_first(model_dir, model_patterns)
    config_path = glob_first(model_dir, config_patterns)
    # never report the model checkpoint as its own vocoder
    vocoder_path = glob_first(model_dir, vocoder_patterns, exclude=(model_path,))
    vocoder_config_path = glob_first(model_dir, vocoder_config_patterns)

    log("discovered model files:")
    log(f" model_path: {model_path}")
    log(f" config_path: {config_path}")
    log(f" vocoder_path: {vocoder_path}")
    log(f" vocoder_config_path: {vocoder_config_path}")

    return model_path, config_path, vocoder_path, vocoder_config_path
93
+
94
# main: attempt to download and initialize model, but catch and log everything
local_model_dir = None
try:
    log(f"attempting to snapshot_download repo: {HF_REPO_ID}")
    # NOTE(review): use_auth_token is deprecated in newer huggingface_hub
    # releases in favor of token= -- confirm against the pinned hub version
    local_model_dir = snapshot_download(repo_id=HF_REPO_ID, use_auth_token=HF_TOKEN)
    log(f"snapshot_download returned: {local_model_dir}")
except Exception as e:
    # log the full traceback rather than just str(e) so hub/network errors
    # are diagnosable from /tmp/startup.log
    log("snapshot_download raised an exception:")
    log(traceback.format_exc())
    local_model_dir = None

if local_model_dir is None:
    log("failed to download model repo. please ensure HF_TOKEN secret is set if repo is private.")
    # continue to start gradio with a minimal interface that returns the error message
    # NOTE(review): synthesize_error is defined but never wired to the button
    # below (the click handler uses an inline lambda) -- presumably leftover
    def synthesize_error(text: str):
        return None, "model repo not available - check space logs and HF_TOKEN"
    with gr.Blocks() as demo:
        gr.Markdown("## persian tts (debug) - model not loaded")
        txt = gr.Textbox(label="persian text", lines=4, placeholder="enter text...")
        btn = gr.Button("generate (disabled)")
        audio = gr.Audio(label="output audio", type="filepath")
        status = gr.Markdown("model repo not downloaded. check logs.")
        btn.click(lambda t: (None, "model not available"), inputs=txt, outputs=[audio, status])
    # launch() blocks while the space serves the error UI; exit code 0 is
    # intentional so the space stays "running" instead of crash-looping
    demo.launch()
    sys.exit(0)
119
+
120
# locate model files; failures here are logged and leave all paths as None
try:
    model_path, config_path, vocoder_path, vocoder_config_path = find_model_files(local_model_dir)
except Exception:
    log("error during find_model_files:")
    log(traceback.format_exc())
    model_path = config_path = vocoder_path = vocoder_config_path = None

# if not found, print a short tree to aid debugging
if not model_path or not config_path:
    log("model checkpoint or config.json not found automatically - printing repo tree (top levels):")
    try:
        for root, dirs, files in os.walk(local_model_dir):
            rel = os.path.relpath(root, local_model_dir)
            log(f"dir: {rel} - files: {files[:20]}")
            # bug fix: prune traversal below depth 3 by clearing dirs in place;
            # the previous `break` aborted the whole walk at the first deep
            # directory, hiding every sibling directory from the printed tree
            if len(rel.split(os.sep)) >= 3:
                dirs[:] = []
    except Exception:
        log("error while printing tree:")
        log(traceback.format_exc())

    log("cannot proceed to load tts. please inspect the repo structure and share the printed tree.")
    # start a minimal ui showing the problem, then exit cleanly so the
    # space keeps serving the diagnostic page instead of crash-looping
    with gr.Blocks() as demo:
        gr.Markdown("## persian tts (debug) - missing model files")
        gr.Markdown("model checkpoint or config.json not found in the downloaded repo. see /tmp/startup.log for details.")
        txt = gr.Textbox(label="persian text", lines=4)
        btn = gr.Button("generate (disabled)")
        audio = gr.Audio(label="output audio", type="filepath")
        status = gr.Markdown("model files missing. check logs.")
        btn.click(lambda t: (None, "model not available"), inputs=txt, outputs=[audio, status])
    demo.launch()
    sys.exit(0)
154
+
155
# prepare tts kwargs and attempt load; coqui TTS takes an explicit
# checkpoint + config pair here (avoids its model_name parsing), and the
# vocoder entries are only passed when they were actually discovered
tts_kwargs = {"model_path": model_path, "config_path": config_path, "gpu": False}
if vocoder_path:
    tts_kwargs["vocoder_path"] = vocoder_path
if vocoder_config_path:
    tts_kwargs["vocoder_config_path"] = vocoder_config_path

log("initializing TTS with kwargs:")
for k, v in tts_kwargs.items():
    log(f" {k}: {v}")

try:
    # module-level `tts` is used later by synthesize()
    tts = TTS(**tts_kwargs)
    log("tts initialized successfully")
except Exception as e:
    log("tts initialization failed:")
    log(traceback.format_exc())
    # start a minimal ui showing the init error; exiting with 0 keeps the
    # space serving this diagnostic page instead of crash-looping
    with gr.Blocks() as demo:
        gr.Markdown("## persian tts (debug) - tts init failed")
        gr.Markdown("see /tmp/startup.log for stacktrace")
        txt = gr.Textbox(label="persian text", lines=4)
        btn = gr.Button("generate (disabled)")
        audio = gr.Audio(label="output audio", type="filepath")
        status = gr.Markdown("tts init failed. check logs.")
        btn.click(lambda t: (None, "tts not available"), inputs=txt, outputs=[audio, status])
    demo.launch()
    sys.exit(0)
183
+
184
# normal synth function
def synthesize(text: str):
    """normalize persian *text* and synthesize it to a temporary wav file.

    returns (path_to_wav, status_message) on success or (None, error_message)
    on empty input or synthesis failure.
    """
    if not text or not text.strip():
        return None, "please enter some text."

    if len(text) > MAX_INPUT_LENGTH:
        # truncate overly long input and close with a period so the model
        # still receives a sentence ending
        text = text[:MAX_INPUT_LENGTH] + "."

    text = normalizer.normalize(text)
    out_fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(out_fd)

    try:
        tts.tts_to_file(text=text, file_path=out_path)
    except Exception as e:
        log("tts generation error:")
        log(traceback.format_exc())
        # bug fix: remove the empty temp file on failure -- mkstemp already
        # created it on disk and it previously leaked on every failed request
        try:
            os.remove(out_path)
        except OSError:
            pass
        return None, f"error during synthesis: {e}"

    return out_path, "speech generated successfully."
204
 
205
# gradio ui (normal): text in, audio + status message out
with gr.Blocks() as demo:
    gr.Markdown("## persian tts — debug-enabled")
    text_input = gr.Textbox(
        label="persian text (max ~1200 chars)",
        lines=6,
        placeholder="enter your persian text here...",
    )
    generate_btn = gr.Button("generate speech")
    audio_output = gr.Audio(label="output audio", type="filepath")
    status = gr.Markdown("")

    def run_tts(text):
        # delegate to the module-level synthesize helper, which already
        # returns the (audio_path, message) pair the outputs expect
        return synthesize(text)

    generate_btn.click(fn=run_tts, inputs=text_input, outputs=[audio_output, status])

log("launching gradio app now")
demo.launch()