| | import os |
| | import queue |
| | from huggingface_hub import snapshot_download |
| | import numpy as np |
| | import wave |
| | import io |
| | import gc |
| | from typing import Callable |
| |
|
| | |
# Fetch the OpenAudio S1-mini weights into ./checkpoints before anything
# else imports them. snapshot_download is idempotent: files already present
# locally are skipped.
os.makedirs("checkpoints", exist_ok=True)
snapshot_download(
    repo_id="fishaudio/s1-mini",
    local_dir="./checkpoints/openaudio-s1-mini",
)
print("All checkpoints downloaded")
| |
|
| | import html |
| | import os |
| | from argparse import ArgumentParser |
| | from pathlib import Path |
| |
|
| | import gradio as gr |
| | import torch |
| | import torchaudio |
| |
|
# Select the soundfile backend for torchaudio I/O.
# NOTE(review): torchaudio.set_audio_backend is deprecated and removed in
# recent torchaudio releases — confirm the pinned torchaudio version still
# provides it, otherwise this line raises AttributeError at startup.
torchaudio.set_audio_backend("soundfile")
| |
|
| | from loguru import logger |
| | from fish_speech.i18n import i18n |
| | from fish_speech.inference_engine import TTSInferenceEngine |
| | from fish_speech.models.dac.inference import load_model as load_decoder_model |
| | from fish_speech.models.text2semantic.inference import launch_thread_safe_queue |
| | from tools.webui.inference import get_inference_wrapper |
| | from fish_speech.utils.schema import ServeTTSRequest |
| |
|
| | |
# Disable einx's traceback filtering so full stack traces surface on errors
# (presumably an einx debugging knob — confirm against the einx docs).
os.environ["EINX_FILTER_TRACEBACK"] = "false"
| |
|
| |
|
# Bilingual (EN/ZH) markdown banner rendered at the top of the Gradio page.
# This is user-facing text — keep it byte-identical when editing code.
HEADER_MD = """# Fish Audio S1

## The demo in this space is Fish Audio S1, Please check [Fish Audio](https://fish.audio) for the best model.
## 该 Demo 为 Fish Audio S1 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.

A text-to-speech model based on DAC & Qwen3 developed by [Fish Audio](https://fish.audio).
由 [Fish Audio](https://fish.audio) 研发的 DAC & Qwen3 多语种语音合成.

You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/s1-mini).
你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/s1-mini) 找到模型.

Related code and weights are released under CC BY-NC-SA 4.0 License.
相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.

We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.

The model running in this WebUI is Fish Audio S1 Mini.
在此 WebUI 中运行的模型是 Fish Audio S1 Mini.
"""

# Placeholder shown in the empty input textbox (EN/ZH).
TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
| |
|
try:
    # On Hugging Face Spaces, spaces.GPU schedules the decorated function on
    # a (Zero)GPU worker.
    import spaces

    GPU_DECORATOR = spaces.GPU
except ImportError:
    # Outside Spaces the decorator is a no-op. BUG FIX: the previous
    # fallback wrapped the function in an undecorated closure, discarding
    # its metadata (__name__, __doc__, signature). Returning the function
    # unchanged is both simpler and metadata-preserving.
    def GPU_DECORATOR(func):
        """Identity decorator used when the `spaces` package is absent."""
        return func
| |
|
def build_html_error_message(error):
    """Render *error* as a red, bold HTML block for the UI error panel.

    Args:
        error: Any object; its str() form is HTML-escaped before embedding.

    Returns:
        str: An HTML <div> fragment safe to display verbatim.
    """
    escaped = html.escape(str(error))
    return (
        "\n"
        '<div style="color: red;\n'
        'font-weight: bold;">\n'
        f"{escaped}\n"
        "</div>\n"
    )
| |
|
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    """Build the RIFF/WAVE header bytes for a PCM stream.

    The header is serialized for zero frames; callers can prefix raw PCM
    chunks with it when streaming audio.

    Args:
        sample_rate: Frames per second.
        bit_depth: Bits per sample (converted to bytes for the wave module).
        channels: Number of interleaved channels.

    Returns:
        bytes: The serialized WAV header.
    """
    header_io = io.BytesIO()
    writer = wave.open(header_io, "wb")
    try:
        writer.setnchannels(channels)
        writer.setsampwidth(bit_depth // 8)
        writer.setframerate(sample_rate)
    finally:
        # close() is what actually writes the header into the buffer.
        writer.close()
    header_bytes = header_io.getvalue()
    header_io.close()
    return header_bytes
| |
|
| |
|
def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
    """Assemble the Gradio Blocks UI for the TTS demo.

    Args:
        inference_fct: Callback wired to the Generate button. It receives
            the component values listed in generate.click() below, in that
            order, and must return (audio, error_html).
        theme: "light" or "dark"; injected into the page URL via JS on load.

    Returns:
        The constructed (not yet launched) gr.Blocks application.
    """
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Force the requested theme by rewriting the URL's __theme query
        # parameter on first load (this reloads the page once if absent).
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % theme,
        )

        # Left column: input text plus generation / reference-audio settings.
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                )

                with gr.Row():
                    with gr.Column():
                        with gr.Tab(label=i18n("Advanced Config")):
                            # Decoding hyper-parameters exposed to the user.
                            with gr.Row():
                                chunk_length = gr.Slider(
                                    label=i18n("Iterative Prompt Length, 0 means off"),
                                    minimum=0,
                                    maximum=500,
                                    value=0,
                                    step=8,
                                )

                                max_new_tokens = gr.Slider(
                                    label=i18n(
                                        "Maximum tokens per batch, 0 means no limit"
                                    ),
                                    minimum=0,
                                    maximum=2048,
                                    value=0,
                                    step=8,
                                )

                            with gr.Row():
                                top_p = gr.Slider(
                                    label="Top-P",
                                    minimum=0.7,
                                    maximum=0.95,
                                    value=0.9,
                                    step=0.01,
                                )

                                repetition_penalty = gr.Slider(
                                    label=i18n("Repetition Penalty"),
                                    minimum=1,
                                    maximum=1.2,
                                    value=1.1,
                                    step=0.01,
                                )

                            with gr.Row():
                                temperature = gr.Slider(
                                    label="Temperature",
                                    minimum=0.7,
                                    maximum=1.0,
                                    value=0.9,
                                    step=0.01,
                                )
                                # Per the info text: 0 = randomized seed,
                                # any other value = deterministic decoding.
                                seed = gr.Number(
                                    label="Seed",
                                    info="0 means randomized inference, otherwise deterministic",
                                    value=0,
                                )

                        with gr.Tab(label=i18n("Reference Audio")):
                            with gr.Row():
                                gr.Markdown(
                                    i18n(
                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
                                    )
                                )
                            with gr.Row():
                                # Server-side reference ID; per the
                                # placeholder, uploaded references are used
                                # only when this is left empty.
                                reference_id = gr.Textbox(
                                    label=i18n("Reference ID"),
                                    placeholder="Leave empty to use uploaded references",
                                )

                            with gr.Row():
                                use_memory_cache = gr.Radio(
                                    label=i18n("Use Memory Cache"),
                                    choices=["on", "off"],
                                    value="on",
                                )

                            with gr.Row():
                                # type="filepath": the callback receives the
                                # uploaded clip as a path, not an array.
                                reference_audio = gr.Audio(
                                    label=i18n("Reference Audio"),
                                    type="filepath",
                                )
                            with gr.Row():
                                reference_text = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                    value="",
                                )

            # Right column: error display, generated audio, Generate button.
            with gr.Column(scale=3):
                with gr.Row():
                    error = gr.HTML(
                        label=i18n("Error Message"),
                        visible=True,
                    )
                with gr.Row():
                    audio = gr.Audio(
                        label=i18n("Generated Audio"),
                        type="numpy",
                        interactive=False,
                        visible=True,
                    )

                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001f3a7 " + i18n("Generate"),
                            variant="primary",
                        )

        # Wire the button to the inference callback. concurrency_limit=1
        # serializes requests — NOTE(review): presumably because the backing
        # inference queue is single-worker; confirm before raising it.
        generate.click(
            inference_fct,
            [
                text,
                reference_id,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                seed,
                use_memory_cache,
            ],
            [audio, error],
            concurrency_limit=1,
        )

    return app
| |
|
def parse_args(argv=None):
    """Parse command-line options for the web UI.

    Args:
        argv: Optional explicit argument list; None (the default) parses
            sys.argv[1:], so existing parse_args() callers are unaffected.
            Accepting a list makes the parser unit-testable.

    Returns:
        argparse.Namespace with checkpoint paths, device, precision flags,
        and UI options.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/openaudio-s1-mini",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/openaudio-s1-mini/codec.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    # BUG FIX: --compile was declared with action="store_true" AND
    # default=True, so compilation could never be disabled from the CLI.
    # Keep --compile for backward compatibility and add --no-compile
    # (writing into the same dest) to actually turn it off.
    parser.add_argument("--compile", action="store_true", default=True)
    parser.add_argument("--no-compile", dest="compile", action="store_false")
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="dark")

    return parser.parse_args(argv)
| |
|
| |
|
if __name__ == "__main__":
    args = parse_args()
    # fp16 when --half is passed, otherwise bf16.
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    # The text2semantic model is served through a thread-safe queue so web
    # requests can be handed to it one at a time.
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    # Audio codec/decoder that turns semantic tokens back into a waveform.
    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("Decoder model loaded, warming up...")

    inference_engine = TTSInferenceEngine(
        llama_queue=llama_queue,
        decoder_model=decoder_model,
        compile=args.compile,
        precision=args.precision,
    )

    # Warm-up: run one complete inference before serving traffic
    # (presumably so torch.compile kernels are built up-front — the first
    # compiled call is slow). list() drains the streaming generator.
    list(
        inference_engine.inference(
            ServeTTSRequest(
                text="Hello world.",
                references=[],
                reference_id=None,
                max_new_tokens=1024,
                chunk_length=200,
                top_p=0.7,
                repetition_penalty=1.5,
                temperature=0.7,
                format="wav",
            )
        )
    )

    logger.info("Warming up done, launching the web UI...")

    # Adapt the engine to the (text, refs, sliders...) -> (audio, error)
    # signature expected by build_app's Generate callback.
    inference_fct = get_inference_wrapper(inference_engine)

    app = build_app(inference_fct, args.theme)
    app.queue(api_open=True).launch(show_error=True, show_api=True)
| |
|