Gapeleon commited on
Commit
c387bd5
·
verified ·
1 Parent(s): 25bbd00

Create web_ui.py

Browse files
Files changed (1) hide show
  1. web_ui.py +336 -0
web_ui.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import soundfile as sf
4
+ import logging
5
+ import argparse
6
+ import gradio as gr
7
+ from datetime import datetime
8
+ from mira.model import MiraTTS
9
+
10
+ MODEL = None
11
+
12
+ def initialize_model(model_dir="YatharthS/MiraTTS"):
13
+ """Load the MiraTTS model once at the beginning."""
14
+ logging.info(f"Loading MiraTTS model from: {model_dir}")
15
+ model = MiraTTS(model_dir)
16
+ return model
17
+
18
+ def generate_audio(text, prompt_audio_path):
19
+ """Generate audio from text using MiraTTS with voice cloning."""
20
+ global MODEL
21
+
22
+ if MODEL is None:
23
+ MODEL = initialize_model()
24
+
25
+ try:
26
+ # Encode the prompt audio
27
+ context_tokens = MODEL.encode_audio(prompt_audio_path)
28
+
29
+ # Generate audio
30
+ audio = MODEL.generate(text, context_tokens)
31
+
32
+ # Convert to numpy array if it's a tensor and handle dtype
33
+ if torch.is_tensor(audio):
34
+ audio = audio.cpu().numpy()
35
+
36
+ # Ensure correct dtype for soundfile (convert from float16 to float32)
37
+ if audio.dtype == 'float16':
38
+ audio = audio.astype('float32')
39
+ elif audio.dtype not in ['float32', 'float64', 'int16', 'int32']:
40
+ audio = audio.astype('float32')
41
+
42
+ return audio, 48000 # Return audio and sample rate
43
+ except Exception as e:
44
+ logging.error(f"Error during generation: {e}")
45
+ raise e
46
+
47
+ def run_tts(text, prompt_audio_path, save_dir="results"):
48
+ """Perform TTS inference and save the generated audio."""
49
+ logging.info(f"Saving audio to: {save_dir}")
50
+
51
+ # Ensure the save directory exists
52
+ os.makedirs(save_dir, exist_ok=True)
53
+
54
+ # Generate unique filename using timestamp
55
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
56
+ save_path = os.path.join(save_dir, f"mira_tts_{timestamp}.wav")
57
+
58
+ logging.info("Starting MiraTTS inference...")
59
+
60
+ # Generate audio
61
+ audio, sample_rate = generate_audio(text, prompt_audio_path)
62
+
63
+ # Save audio file
64
+ sf.write(save_path, audio, samplerate=sample_rate)
65
+
66
+ logging.info(f"Audio saved at: {save_path}")
67
+ return save_path
68
+
69
+ def voice_clone_callback(text, prompt_audio_upload, prompt_audio_record):
70
+ """Gradio callback for voice cloning using MiraTTS."""
71
+ if not text.strip():
72
+ return None
73
+
74
+ # Use uploaded audio or recorded audio
75
+ prompt_audio = prompt_audio_upload if prompt_audio_upload else prompt_audio_record
76
+
77
+ if not prompt_audio:
78
+ return None
79
+
80
+ try:
81
+ audio_output_path = run_tts(text, prompt_audio)
82
+ return audio_output_path
83
+ except Exception as e:
84
+ logging.error(f"Error in voice cloning: {e}")
85
+ return None
86
+
87
+ def voice_creation_callback(text, temperature, top_p, top_k):
88
+ """Gradio callback for creating synthetic voice with custom parameters."""
89
+ if not text.strip():
90
+ return None
91
+
92
+ global MODEL
93
+
94
+ if MODEL is None:
95
+ MODEL = initialize_model()
96
+
97
+ try:
98
+ # Set custom generation parameters
99
+ MODEL.set_params(
100
+ temperature=temperature,
101
+ top_p=top_p,
102
+ top_k=top_k,
103
+ max_new_tokens=1024,
104
+ repetition_penalty=1.2
105
+ )
106
+
107
+ # Use a default voice context (you may want to provide default audio files)
108
+ # Check multiple possible paths for example audio
109
+ possible_paths = [
110
+ "/models3/src/MiraTTS/models/MiraTTS/example1.wav",
111
+ "models/MiraTTS/example1.wav",
112
+ "./models/MiraTTS/example1.wav"
113
+ ]
114
+
115
+ default_audio = None
116
+ for path in possible_paths:
117
+ if os.path.exists(path):
118
+ default_audio = path
119
+ break
120
+
121
+ if default_audio:
122
+ # Generate audio with dtype conversion
123
+ context_tokens = MODEL.encode_audio(default_audio)
124
+ audio = MODEL.generate(text, context_tokens)
125
+
126
+ # Handle tensor conversion and dtype
127
+ if torch.is_tensor(audio):
128
+ audio = audio.cpu().numpy()
129
+
130
+ # Ensure correct dtype for soundfile
131
+ if audio.dtype == 'float16':
132
+ audio = audio.astype('float32')
133
+ elif audio.dtype not in ['float32', 'float64', 'int16', 'int32']:
134
+ audio = audio.astype('float32')
135
+
136
+ # Save the audio
137
+ os.makedirs("results", exist_ok=True)
138
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
139
+ save_path = os.path.join("results", f"mira_tts_creation_{timestamp}.wav")
140
+ sf.write(save_path, audio, samplerate=48000)
141
+
142
+ return save_path
143
+ else:
144
+ logging.warning("No default audio found for voice creation")
145
+ return None
146
+
147
+ except Exception as e:
148
+ logging.error(f"Error in voice creation: {e}")
149
+ return None
150
+
151
+ def build_ui():
152
+ """Build the Gradio interface similar to SparkTTS."""
153
+
154
+ with gr.Blocks(title="MiraTTS Web Interface") as demo:
155
+ # Title
156
+ gr.HTML('<h1 style="text-align: center;">MiraTTS - High Quality Voice Synthesis</h1>')
157
+
158
+ # Description
159
+ gr.Markdown("""
160
+ MiraTTS is a highly optimized Text-to-Speech model based on Spark-TTS with LMDeploy acceleration.
161
+ It provides over 100x realtime generation speed with high-quality 48kHz audio output.
162
+ """)
163
+
164
+ with gr.Tabs():
165
+ # Voice Clone Tab
166
+ with gr.TabItem("Voice Clone"):
167
+ gr.Markdown("### Clone any voice using a reference audio sample")
168
+
169
+ with gr.Row():
170
+ prompt_audio_upload = gr.Audio(
171
+ sources="upload",
172
+ type="filepath",
173
+ label="Upload Reference Audio (recommended: 3-30 seconds, 16kHz+)",
174
+ )
175
+ prompt_audio_record = gr.Audio(
176
+ sources="microphone",
177
+ type="filepath",
178
+ label="Record Reference Audio",
179
+ )
180
+
181
+ text_input = gr.Textbox(
182
+ label="Text to Synthesize",
183
+ lines=3,
184
+ placeholder="Enter the text you want to convert to speech...",
185
+ value="Hello! This is a demonstration of MiraTTS voice cloning capabilities."
186
+ )
187
+
188
+ with gr.Row():
189
+ clone_button = gr.Button("Generate Audio", variant="primary")
190
+ clear_button = gr.Button("Clear")
191
+
192
+ audio_output_clone = gr.Audio(
193
+ label="Generated Audio",
194
+ autoplay=True
195
+ )
196
+
197
+ clone_button.click(
198
+ voice_clone_callback,
199
+ inputs=[text_input, prompt_audio_upload, prompt_audio_record],
200
+ outputs=[audio_output_clone],
201
+ )
202
+
203
+ clear_button.click(
204
+ lambda: (None, None, "", None),
205
+ outputs=[prompt_audio_upload, prompt_audio_record, text_input, audio_output_clone]
206
+ )
207
+
208
+ # Voice Creation Tab
209
+ with gr.TabItem("Voice Creation"):
210
+ gr.Markdown("### Create synthetic voices with custom parameters")
211
+
212
+ with gr.Row():
213
+ with gr.Column():
214
+ text_input_creation = gr.Textbox(
215
+ label="Text to Synthesize",
216
+ lines=3,
217
+ placeholder="Enter text here...",
218
+ value="You can create customized voices by adjusting the generation parameters below."
219
+ )
220
+
221
+ with gr.Row():
222
+ temperature = gr.Slider(
223
+ minimum=0.1,
224
+ maximum=1.5,
225
+ step=0.1,
226
+ value=0.8,
227
+ label="Temperature (creativity)"
228
+ )
229
+ top_p = gr.Slider(
230
+ minimum=0.1,
231
+ maximum=1.0,
232
+ step=0.05,
233
+ value=0.95,
234
+ label="Top-p (nucleus sampling)"
235
+ )
236
+ top_k = gr.Slider(
237
+ minimum=1,
238
+ maximum=100,
239
+ step=1,
240
+ value=50,
241
+ label="Top-k (vocabulary size)"
242
+ )
243
+
244
+ with gr.Column():
245
+ create_button = gr.Button("Create Voice", variant="primary")
246
+ audio_output_creation = gr.Audio(
247
+ label="Generated Audio",
248
+ autoplay=True
249
+ )
250
+
251
+ create_button.click(
252
+ voice_creation_callback,
253
+ inputs=[text_input_creation, temperature, top_p, top_k],
254
+ outputs=[audio_output_creation],
255
+ )
256
+
257
+ # About Tab
258
+ with gr.TabItem("About"):
259
+ gr.Markdown("""
260
+ ## About MiraTTS
261
+
262
+ MiraTTS is an optimized version of Spark-TTS with the following features:
263
+
264
+ - **Ultra-fast generation**: Over 100x realtime speed using LMDeploy optimization
265
+ - **High quality**: Generates crisp 48kHz audio outputs
266
+ - **Memory efficient**: Works within 6GB VRAM
267
+ - **Low latency**: As low as 100ms generation time
268
+ - **Voice cloning**: Clone any voice from a short audio sample
269
+
270
+ ### Model Information
271
+ - Base model: Spark-TTS-0.5B
272
+ - Optimization: LMDeploy + FlashSR
273
+ - Sample rate: 48kHz
274
+ - Model size: ~500M parameters
275
+
276
+ ### Usage Tips
277
+ - For voice cloning, use clear audio samples between 3-30 seconds
278
+ - Ensure reference audio is at least 16kHz quality
279
+ - Longer text inputs may require more memory
280
+ - Adjust generation parameters for different voice styles
281
+ """)
282
+
283
+ return demo
284
+
285
+ def parse_arguments():
286
+ """Parse command-line arguments."""
287
+ parser = argparse.ArgumentParser(description="MiraTTS Gradio Web Interface")
288
+ parser.add_argument(
289
+ "--model_dir",
290
+ type=str,
291
+ default="YatharthS/MiraTTS",
292
+ help="Path to the MiraTTS model directory or HuggingFace model ID"
293
+ )
294
+ parser.add_argument(
295
+ "--server_name",
296
+ type=str,
297
+ default="127.0.0.1",
298
+ help="Server host/IP for Gradio app"
299
+ )
300
+ parser.add_argument(
301
+ "--server_port",
302
+ type=int,
303
+ default=7860,
304
+ help="Server port for Gradio app"
305
+ )
306
+ parser.add_argument(
307
+ "--share",
308
+ action="store_true",
309
+ help="Create a public shareable link"
310
+ )
311
+ return parser.parse_args()
312
+
313
+ if __name__ == "__main__":
314
+ # Configure logging
315
+ logging.basicConfig(
316
+ level=logging.INFO,
317
+ format='%(asctime)s - %(levelname)s - %(message)s'
318
+ )
319
+
320
+ # Parse arguments
321
+ args = parse_arguments()
322
+
323
+ # Initialize model
324
+ logging.info("Initializing MiraTTS model...")
325
+ MODEL = initialize_model(args.model_dir)
326
+
327
+ # Build and launch interface
328
+ logging.info("Building Gradio interface...")
329
+ demo = build_ui()
330
+
331
+ logging.info(f"Launching web interface on {args.server_name}:{args.server_port}")
332
+ demo.launch(
333
+ server_name=args.server_name,
334
+ server_port=args.server_port,
335
+ share=args.share
336
+ )