mkfallah commited on
Commit
58dac37
·
verified ·
1 Parent(s): b2b1119

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -15
app.py CHANGED
@@ -1,6 +1,7 @@
1
  # app.py
2
  # simple gradio space for Persian TTS using kamtera/persian-tts-female-vits (coqui tts)
3
- # all ui messages and comments are in English
 
4
 
5
  import os
6
  import tempfile
@@ -8,42 +9,50 @@ from hazm import Normalizer
8
  from TTS.api import TTS
9
  import gradio as gr
10
 
 
 
 
11
  # -------------------------
12
  # configuration
13
- MODEL_ID = "Kamtera/persian-tts-female-vits"
14
- HF_TOKEN = os.environ.get("HF_TOKEN", None) # optional token for private models
15
- MAX_INPUT_LENGTH = 1200 # safety limit for long text
16
  # -------------------------
17
 
18
  normalizer = Normalizer()
19
 
20
- # load Coqui TTS model
21
- print("loading tts model:", MODEL_ID)
22
- if HF_TOKEN:
23
- tts = TTS(model_name=MODEL_ID, progress_bar=False, gpu=False, use_auth_token=HF_TOKEN)
24
- else:
25
- tts = TTS(model_name=MODEL_ID, progress_bar=False, gpu=False)
 
 
 
 
 
 
 
 
 
26
 
27
  def synthesize(text: str):
28
  """
29
  text: Persian text input
30
- returns: path to the generated wav file
31
  """
32
  if not text or not text.strip():
33
  return None, "please enter some text."
34
 
35
- # limit input length to avoid high latency
36
  if len(text) > MAX_INPUT_LENGTH:
37
  text = text[:MAX_INPUT_LENGTH] + "."
38
 
39
- # normalize persian text
40
  text = normalizer.normalize(text)
41
 
42
- # create a temporary output file
43
  out_fd, out_path = tempfile.mkstemp(suffix=".wav")
44
  os.close(out_fd)
45
 
46
- # generate audio
47
  try:
48
  tts.tts_to_file(text=text, file_path=out_path)
49
  except Exception as e:
 
1
  # app.py
2
  # simple gradio space for Persian TTS using kamtera/persian-tts-female-vits (coqui tts)
3
+ # loads model by first downloading the HuggingFace repo to a local folder,
4
+ # then passes the local path to TTS to avoid Coqui's "model_name parsing" error.
5
 
6
  import os
7
  import tempfile
 
9
  from TTS.api import TTS
10
  import gradio as gr
11
 
12
+ # add huggingface_hub to requirements and import here
13
+ from huggingface_hub import snapshot_download
14
+
15
  # -------------------------
16
  # configuration
17
+ HF_REPO_ID = "Kamtera/persian-tts-female-vits" # huggingface repo id
18
+ HF_TOKEN = os.environ.get("HF_TOKEN", None) # optional token for private models
19
+ MAX_INPUT_LENGTH = 1200 # safety limit for long text
20
  # -------------------------
21
 
22
  normalizer = Normalizer()
23
 
24
+ # download the HuggingFace repo to a local folder (cached by HF Hub)
25
+ print("downloading model repo from huggingface:", HF_REPO_ID)
26
+ try:
27
+ local_model_dir = snapshot_download(repo_id=HF_REPO_ID, use_auth_token=HF_TOKEN)
28
+ print("model downloaded to:", local_model_dir)
29
+ except Exception as e:
30
+ print("error while downloading model repo:", e)
31
+ local_model_dir = None
32
+
33
+ if local_model_dir is None:
34
+ raise RuntimeError("failed to download model repo. set HF_TOKEN if repo is private or check repo id.")
35
+
36
+ # now load model from local dir (coqui expects either a coqui id or a local path)
37
+ print("loading tts model from local folder:", local_model_dir)
38
+ tts = TTS(model_name=local_model_dir, progress_bar=False, gpu=False)
39
 
40
  def synthesize(text: str):
41
  """
42
  text: Persian text input
43
+ returns: tuple(output_path_or_none, status_message)
44
  """
45
  if not text or not text.strip():
46
  return None, "please enter some text."
47
 
 
48
  if len(text) > MAX_INPUT_LENGTH:
49
  text = text[:MAX_INPUT_LENGTH] + "."
50
 
 
51
  text = normalizer.normalize(text)
52
 
 
53
  out_fd, out_path = tempfile.mkstemp(suffix=".wav")
54
  os.close(out_fd)
55
 
 
56
  try:
57
  tts.tts_to_file(text=text, file_path=out_path)
58
  except Exception as e: