toshuu commited on
Commit
de64ba8
·
verified ·
1 Parent(s): 9bed061

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -118
app.py CHANGED
@@ -1,138 +1,184 @@
1
- # app.py (replace your current file with this)
2
  import os
3
- import threading
4
  import tempfile
5
- import inspect
6
- import traceback
7
- import numpy as np
8
- import soundfile as sf
9
- import gradio as gr
10
-
11
- # torch import is required; HF Spaces requirements will install CPU wheels
12
  import torch
 
 
13
 
 
14
  MODEL_PATH = "v4_indic.pt"
15
- SAMPLE_RATE = 48000
16
-
17
- lock = threading.Lock()
18
- _model = None
19
- _apply_tts_callable = None
20
- _apply_tts_sig = None
21
-
22
- def load_model():
23
- global _model, _apply_tts_callable, _apply_tts_sig
24
- if _model is not None:
25
- return _model
26
 
27
- if not os.path.exists(MODEL_PATH):
28
- raise FileNotFoundError(f"Model file not found in repo root: {MODEL_PATH}")
29
 
30
- print("Loading model from", MODEL_PATH)
31
- pkg = torch.package.PackageImporter(MODEL_PATH)
32
- _model = pkg.load_pickle("tts_models", "model")
33
- print("Model object loaded:", type(_model).__name__)
34
 
35
- # discover apply_tts
36
- if hasattr(_model, "apply_tts"):
37
- _apply_tts_callable = getattr(_model, "apply_tts")
38
- try:
39
- _apply_tts_sig = inspect.signature(_apply_tts_callable)
40
- print("apply_tts signature:", _apply_tts_sig)
41
- except Exception as e:
42
- print("Could not introspect apply_tts signature:", e)
43
- _apply_tts_sig = None
44
- else:
45
- raise RuntimeError("Loaded model does not expose 'apply_tts'")
46
-
47
- return _model
48
-
49
- def _call_apply_tts(text):
 
 
 
 
50
  """
51
- Try a sequence of possible call signatures for apply_tts.
52
- Return numpy array (float32) audio and sample rate.
53
  """
54
- # ensure model loaded
55
- m = load_model()
56
-
57
- # Build candidate calls (ordered by likelihood)
58
- # Each entry is (kwargs dict, args tuple)
59
- candidates = [
60
- ({"text": text}, ()), # apply_tts(text=text)
61
- ({}, (text,)), # apply_tts(text)
62
- ({"text": text, "sample_rate": SAMPLE_RATE}, ()), # apply_tts(text=..., sample_rate=...)
63
- ({}, (text, SAMPLE_RATE)), # apply_tts(text, sample_rate)
64
- ({"text": text, "speaker": 0, "sample_rate": SAMPLE_RATE}, ()), # apply_tts(text=..., speaker=0,...)
65
- ({"text": text, "lang": "hi", "speaker": 0, "sample_rate": SAMPLE_RATE}, ()), # apply_tts(text=..., lang=..., speaker=...)
66
- ({"text": text, "lang_id": 0, "speaker_id": 0, "sample_rate": SAMPLE_RATE}, ()), # older variants
67
- ]
68
-
69
  last_exc = None
70
- for kw, args in candidates:
 
 
71
  try:
72
- # attempt call
73
- if args:
74
- res = m.apply_tts(*args, **kw)
 
 
75
  else:
76
- res = m.apply_tts(**kw)
77
- # success: convert to numpy if torch tensor
78
- if isinstance(res, torch.Tensor):
79
- res = res.detach().cpu().numpy()
80
- res = np.asarray(res, dtype=np.float32)
81
- return res
82
- except TypeError as te:
83
- last_exc = te
84
- # signature mismatch, try next
85
  continue
86
  except Exception as e:
87
- # If a runtime error occurred within model (e.g. tokenizer / input length), raise it
88
- print("Runtime error while calling apply_tts with", kw, args)
89
- traceback.print_exc()
90
  last_exc = e
91
- break
92
-
93
- # If we exit loop without returning, raise a helpful error
94
  raise RuntimeError(f"apply_tts call failed for all known signatures. last error: {last_exc}")
95
 
96
- def synthesize_text_to_wavfile(text):
97
- if not text or not isinstance(text, str) or len(text.strip()) == 0:
98
- raise ValueError("Empty input text")
99
-
100
- audio = _call_apply_tts(text)
101
-
102
- # normalize audio to [-1,1] float32
103
- if audio.dtype != np.float32:
104
- audio = audio.astype(np.float32)
105
- max_abs = np.max(np.abs(audio)) if audio.size > 0 else 1.0
106
- if max_abs > 1.0:
107
- audio = audio / max_abs
108
-
109
- # write to temp WAV
110
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
111
- sf.write(tmp.name, audio, SAMPLE_RATE)
112
- tmp.close()
113
- return tmp.name
114
-
115
- # Gradio function
116
- def tts_gradio_fn(text: str):
117
- with lock:
118
- path = synthesize_text_to_wavfile(text)
119
- return path
120
-
121
- def build_demo():
122
- with gr.Blocks() as demo:
123
- gr.Markdown("# 🔊 Silero v4 Indic — Robust HF Space")
124
- txt = gr.Textbox(label="Text to speak", lines=4, value="नमस्ते, यह टेस्‍ट है।")
125
- btn = gr.Button("Generate")
126
- out = gr.Audio(label="Output audio")
127
- btn.click(fn=tts_gradio_fn, inputs=[txt], outputs=[out])
128
- return demo
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  if __name__ == "__main__":
131
- # preload model on startup to avoid cold-call overhead
132
- try:
133
- load_model()
134
- except Exception as e:
135
- print("Model load failed at startup:", e)
136
- traceback.print_exc()
137
- demo = build_demo()
138
- demo.launch()
 
 
1
  import os
2
+ import sys
3
  import tempfile
 
 
 
 
 
 
 
4
  import torch
5
+ import gradio as gr
6
+ from datetime import datetime
7
 
8
+ # Configuration
9
  MODEL_PATH = "v4_indic.pt"
10
+ DEFAULT_SPEAKER = "hindi_female" # Changed from 'xenia' to valid speaker
11
+ DEFAULT_SAMPLE_RATE = 48000
 
 
 
 
 
 
 
 
 
12
 
13
+ print(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
 
14
 
15
+ # Load the model
16
+ print(f"Loading model from {MODEL_PATH}")
17
+ m = torch.package.PackageImporter(MODEL_PATH).load_pickle("tts_models", "model")
18
+ print(f"Model object loaded: {type(m).__name__}")
19
 
20
+ # Inspect apply_tts signature
21
+ import inspect
22
+ sig = inspect.signature(m.apply_tts)
23
+ print(f"apply_tts signature: {sig}")
24
+
25
+ # Available speakers
26
+ AVAILABLE_SPEAKERS = [
27
+ "bengali_female", "bengali_male",
28
+ "gujarati_female", "gujarati_male",
29
+ "hindi_female", "hindi_male",
30
+ "kannada_female", "kannada_male",
31
+ "malayalam_female", "malayalam_male",
32
+ "manipuri_female",
33
+ "rajasthani_female", "rajasthani_male",
34
+ "tamil_female", "tamil_male",
35
+ "telugu_female", "telugu_male"
36
+ ]
37
+
38
+ def _call_apply_tts(text, speaker=DEFAULT_SPEAKER, sample_rate=DEFAULT_SAMPLE_RATE):
39
  """
40
+ Wrapper to call apply_tts with proper error handling.
 
41
  """
42
+ # Validate speaker
43
+ if speaker not in AVAILABLE_SPEAKERS:
44
+ print(f"Warning: Invalid speaker '{speaker}', using default '{DEFAULT_SPEAKER}'")
45
+ speaker = DEFAULT_SPEAKER
46
+
47
+ kw = {
48
+ 'text': text,
49
+ 'speaker': speaker,
50
+ 'sample_rate': sample_rate
51
+ }
52
+
53
+ print(f"Runtime error while calling apply_tts with {kw}")
 
 
 
54
  last_exc = None
55
+
56
+ # Try different parameter combinations
57
+ for attempt_kw in [kw, {'text': text, 'speaker': speaker}]:
58
  try:
59
+ res = m.apply_tts(**attempt_kw)
60
+
61
+ # Handle different return types
62
+ if isinstance(res, tuple):
63
+ audio = res[0]
64
  else:
65
+ audio = res
66
+
67
+ return audio
68
+
69
+ except TypeError as e:
70
+ last_exc = e
71
+ print(f"Attempt failed with {attempt_kw}: {e}")
 
 
72
  continue
73
  except Exception as e:
 
 
 
74
  last_exc = e
75
+ print(f"Error with {attempt_kw}: {e}")
76
+ raise
77
+
78
  raise RuntimeError(f"apply_tts call failed for all known signatures. last error: {last_exc}")
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ def synthesize_text_to_wavfile(text, speaker=DEFAULT_SPEAKER, sample_rate=DEFAULT_SAMPLE_RATE):
82
+ """
83
+ Synthesize text to audio and save to temporary WAV file.
84
+
85
+ Args:
86
+ text: Text to synthesize
87
+ speaker: Speaker voice to use
88
+ sample_rate: Audio sample rate
89
+
90
+ Returns:
91
+ Path to generated WAV file
92
+ """
93
+ audio = _call_apply_tts(text, speaker, sample_rate)
94
+
95
+ # Create temporary file
96
+ fd, path = tempfile.mkstemp(suffix=".wav")
97
+ os.close(fd)
98
+
99
+ # Save audio
100
+ import scipy.io.wavfile as wavfile
101
+ wavfile.write(path, sample_rate, audio)
102
+
103
+ return path
104
+
105
+
106
+ def tts_gradio_fn(text, speaker, sample_rate):
107
+ """
108
+ Gradio interface function.
109
+
110
+ Args:
111
+ text: Input text
112
+ speaker: Selected speaker voice
113
+ sample_rate: Audio sample rate
114
+
115
+ Returns:
116
+ Path to generated audio file
117
+ """
118
+ if not text or not text.strip():
119
+ raise ValueError("Please enter some text to synthesize")
120
+
121
+ path = synthesize_text_to_wavfile(text, speaker, sample_rate)
122
+ return path
123
+
124
+
125
+ # Create Gradio interface
126
+ with gr.Blocks(title="Silero v4 Indic TTS") as demo:
127
+ gr.Markdown("# Silero v4 Indic Text-to-Speech")
128
+ gr.Markdown("Convert text to speech in multiple Indian languages")
129
+
130
+ with gr.Row():
131
+ with gr.Column():
132
+ text_input = gr.Textbox(
133
+ label="Enter Text",
134
+ placeholder="नमस्ते, यह टेस्ट है। (Enter text in Hindi, Bengali, Tamil, Telugu, etc.)",
135
+ lines=5
136
+ )
137
+
138
+ speaker_dropdown = gr.Dropdown(
139
+ choices=AVAILABLE_SPEAKERS,
140
+ value=DEFAULT_SPEAKER,
141
+ label="Select Speaker Voice"
142
+ )
143
+
144
+ sample_rate_dropdown = gr.Dropdown(
145
+ choices=[8000, 16000, 24000, 48000],
146
+ value=DEFAULT_SAMPLE_RATE,
147
+ label="Sample Rate (Hz)"
148
+ )
149
+
150
+ submit_btn = gr.Button("Generate Speech", variant="primary")
151
+
152
+ with gr.Column():
153
+ audio_output = gr.Audio(
154
+ label="Generated Audio",
155
+ type="filepath"
156
+ )
157
+
158
+ # Examples
159
+ gr.Examples(
160
+ examples=[
161
+ ["नमस्ते, यह टेस्ट है।", "hindi_female", 48000],
162
+ ["হ্যালো, এটি একটি পরীক্ষা।", "bengali_female", 48000],
163
+ ["வணக்கம், இது ஒரு சோதனை.", "tamil_female", 48000],
164
+ ["హలో, ఇది ఒక పరీక్ష.", "telugu_female", 48000],
165
+ ],
166
+ inputs=[text_input, speaker_dropdown, sample_rate_dropdown],
167
+ outputs=audio_output,
168
+ fn=tts_gradio_fn,
169
+ cache_examples=False
170
+ )
171
+
172
+ submit_btn.click(
173
+ fn=tts_gradio_fn,
174
+ inputs=[text_input, speaker_dropdown, sample_rate_dropdown],
175
+ outputs=audio_output
176
+ )
177
+
178
+ # Launch the app
179
  if __name__ == "__main__":
180
+ demo.launch(
181
+ server_name="0.0.0.0",
182
+ server_port=7860,
183
+ ssr_mode=True
184
+ )