| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import torch |
| | import soundfile as sf |
| | import logging |
| | import argparse |
| | import gradio as gr |
| | import platform |
| |
|
| | from datetime import datetime |
| | from cli.SparkTTS import SparkTTS |
| | from sparktts.utils.token_parser import LEVELS_MAP_UI |
| |
|
| |
|
def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0):
    """Load the Spark-TTS model once, choosing the best available backend.

    Backend selection order:
      1. Apple MPS, only when running on macOS AND the installed PyTorch
         build reports MPS as available;
      2. CUDA, when ``torch.cuda.is_available()``;
      3. CPU otherwise.

    Args:
        model_dir: Directory passed straight to ``SparkTTS``.
        device: Integer device index used to build ``mps:<i>`` or
            ``cuda:<i>``; ignored when falling back to CPU.

    Returns:
        The constructed ``SparkTTS`` instance.
    """
    logging.info(f"Loading model from: {model_dir}")

    # Checking the OS alone is not enough: Intel Macs and PyTorch builds
    # compiled without MPS also report "Darwin", and an MPS device is not
    # usable there. Only pick MPS when the backend says it is available,
    # otherwise fall through to CUDA / CPU like every other platform.
    # getattr guards very old torch versions that lack torch.backends.mps.
    mps_backend = getattr(torch.backends, "mps", None)
    use_mps = (
        platform.system() == "Darwin"
        and mps_backend is not None
        and mps_backend.is_available()
    )

    if use_mps:
        device = torch.device(f"mps:{device}")
        logging.info(f"Using MPS device: {device}")
    elif torch.cuda.is_available():
        device = torch.device(f"cuda:{device}")
        logging.info(f"Using CUDA device: {device}")
    else:
        device = torch.device("cpu")
        logging.info("GPU acceleration not available, using CPU")

    model = SparkTTS(model_dir, device)
    return model
| |
|
| |
|
def run_tts(
    text,
    model,
    prompt_text=None,
    prompt_speech=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="example/results",
    sample_rate=16000,
):
    """Run one TTS inference and write the result to a new WAV file.

    Args:
        text: Text to synthesise.
        model: Loaded model exposing
            ``inference(text, prompt_speech, prompt_text, gender, pitch, speed)``.
        prompt_text: Transcript of the reference audio. Strings of length
            0 or 1 are treated as absent (``None``).
        prompt_speech: Reference audio forwarded to the model (a file path
            when called from the Gradio UI in this file).
        gender, pitch, speed: Voice-creation controls forwarded to the model.
        save_dir: Output directory; created if it does not exist.
        sample_rate: Sample rate written into the WAV file. Defaults to
            16000, the value this script has always used.

    Returns:
        Path of the written WAV file.
    """
    logging.info(f"Saving audio to: {save_dir}")

    # A one-character transcript carries no useful information; drop it.
    if prompt_text is not None and len(prompt_text) <= 1:
        prompt_text = None

    os.makedirs(save_dir, exist_ok=True)

    # Include microseconds: with second resolution, two requests finishing
    # in the same second (e.g. the clone and creation tabs) would write the
    # same file and one result would silently overwrite the other.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    # Inference only; no autograd graph is needed.
    with torch.no_grad():
        wav = model.inference(
            text,
            prompt_speech,
            prompt_text,
            gender,
            pitch,
            speed,
        )

    sf.write(save_path, wav, samplerate=sample_rate)

    logging.info(f"Audio saved at: {save_path}")

    return save_path
| |
|
| |
|
def build_ui(model_dir, device=0):
    """Load the model once and build the two-tab Spark-TTS Gradio app.

    Args:
        model_dir: Model directory forwarded to ``initialize_model``.
        device: Accelerator index forwarded to ``initialize_model``.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """
    # Loaded once here; both callbacks below close over this instance, so
    # individual requests never reload the weights.
    model = initialize_model(model_dir, device=device)

    def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
        """
        Gradio callback: synthesise `text` in the voice of a reference clip.
        - text: The input text to be synthesised (required).
        - prompt_text: Transcript of the reference audio (optional).
        - prompt_wav_upload/prompt_wav_record: Reference audio file paths;
          the uploaded file takes precedence over the microphone recording.
          At least one is required.
        Raises gr.Error (shown to the user) when required input is missing.
        """
        if not text or not text.strip():
            raise gr.Error("Please enter the text to synthesise.")

        prompt_speech = prompt_wav_upload or prompt_wav_record
        # Cloning is called without gender/pitch/speed, so without a
        # reference clip the model would have no voice to copy.
        if not prompt_speech:
            raise gr.Error("Please upload or record a reference audio clip.")

        # Guard None as well as very short strings (the old len() crashed on None).
        prompt_text_clean = (
            prompt_text if prompt_text and len(prompt_text) >= 2 else None
        )

        return run_tts(
            text,
            model,
            prompt_text=prompt_text_clean,
            prompt_speech=prompt_speech,
        )

    def voice_creation(text, gender, pitch, speed):
        """
        Gradio callback: synthesise `text` with a parametric voice.
        - text: The input text for synthesis (required).
        - gender: 'male' or 'female' (from the radio button).
        - pitch/speed: Slider positions, translated through LEVELS_MAP_UI.
        Raises gr.Error (shown to the user) when the text is empty.
        """
        if not text or not text.strip():
            raise gr.Error("Please enter the text to synthesise.")

        # Cast defensively to int before using slider values as map keys.
        pitch_val = LEVELS_MAP_UI[int(pitch)]
        speed_val = LEVELS_MAP_UI[int(speed)]
        return run_tts(
            text,
            model,
            gender=gender,
            pitch=pitch_val,
            speed=speed_val,
        )

    with gr.Blocks() as demo:
        gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')
        with gr.Tabs():
            # Tab 1: clone the voice of an uploaded or recorded clip.
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload reference audio or recording (上传参考音频或者录音)"
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
                    )
                    prompt_wav_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record the prompt audio file.",
                    )

                with gr.Row():
                    text_input = gr.Textbox(
                        label="Text", lines=3, placeholder="Enter text here"
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
                        lines=3,
                        placeholder="Enter text of the prompt speech.",
                    )

                clone_audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )

                generate_button_clone = gr.Button("Generate")

                generate_button_clone.click(
                    voice_clone,
                    inputs=[
                        text_input,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[clone_audio_output],
                )

            # Tab 2: synthesise with a gender/pitch/speed-controlled voice.
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create your own voice based on the following parameters"
                )

                with gr.Row():
                    with gr.Column():
                        gender = gr.Radio(
                            choices=["male", "female"], value="male", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed"
                        )
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Input Text",
                            lines=3,
                            placeholder="Enter text here",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                        create_button = gr.Button("Create Voice")

                creation_audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )
                create_button.click(
                    voice_creation,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[creation_audio_output],
                )

    return demo
| |
|
| |
|
def parse_arguments():
    """Parse the Gradio server's command line (``sys.argv``).

    Returns:
        argparse.Namespace with ``model_dir`` (str), ``device`` (int),
        ``server_name`` (str) and ``server_port`` (int).
    """
    # (flag, type, default, help) for every supported option, in --help order.
    option_specs = (
        (
            "--model_dir",
            str,
            "pretrained_models/Spark-TTS-0.5B",
            "Path to the model directory.",
        ),
        (
            "--device",
            int,
            0,
            "ID of the GPU device to use (e.g., 0 for cuda:0).",
        ),
        (
            "--server_name",
            str,
            "0.0.0.0",
            "Server host/IP for Gradio app.",
        ),
        (
            "--server_port",
            int,
            7860,
            "Server port for Gradio app.",
        ),
    )

    cli = argparse.ArgumentParser(description="Spark TTS Gradio server.")
    for flag, value_type, fallback, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    return cli.parse_args()
| |
|
if __name__ == "__main__":
    # Script entry point: parse CLI options, build the UI, serve it.
    args = parse_arguments()

    # Without this, every logging.info(...) call in this file is dropped
    # because the root logger defaults to WARNING. basicConfig does nothing
    # if the root logger already has handlers, so it cannot clobber an
    # existing logging setup.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )

    # Loads the model, then builds the Gradio app around it.
    demo = build_ui(
        model_dir=args.model_dir,
        device=args.device,
    )

    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
    )