import io
import os
import wave
from typing import Callable

from huggingface_hub import snapshot_download
# Download the checkpoint if it is not already present (snapshot_download skips cached files)
os.makedirs("checkpoints", exist_ok=True)
snapshot_download(repo_id="fishaudio/s1-mini", local_dir="./checkpoints/openaudio-s1-mini")
print("All checkpoints downloaded")
# Heavier imports are deferred until after the checkpoint download
import html
from argparse import ArgumentParser
from pathlib import Path
import gradio as gr
import torch
import torchaudio
# set_audio_backend was deprecated and later removed from torchaudio,
# so guard the call to keep newer versions from crashing at import time
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("soundfile")
from loguru import logger
from fish_speech.i18n import i18n
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.dac.inference import load_model as load_decoder_model
from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from tools.webui.inference import get_inference_wrapper
from fish_speech.utils.schema import ServeTTSRequest
# Disable einx's traceback filtering so full stack traces are shown
os.environ["EINX_FILTER_TRACEBACK"] = "false"
HEADER_MD = """# Fish Audio S1
## This Space demos Fish Audio S1. Please visit [Fish Audio](https://fish.audio) to try the best model.
## ่ฏฅ Demo ไธบ Fish Audio S1 ็‰ˆๆœฌ, ่ฏทๅœจ [Fish Audio](https://fish.audio) ไฝ“้ชŒๆœ€ๆ–ฐ DEMO.
A text-to-speech model based on DAC & Qwen3, developed by [Fish Audio](https://fish.audio).
็”ฑ [Fish Audio](https://fish.audio) ็ ”ๅ‘็š„ DAC & Qwen3 ๅคš่ฏญ็ง่ฏญ้Ÿณๅˆๆˆ.
You can find the source code [here](https://github.com/fishaudio/fish-speech) and the models [here](https://huggingface.co/fishaudio/s1-mini).
ไฝ ๅฏไปฅๅœจ [่ฟ™้‡Œ](https://github.com/fishaudio/fish-speech) ๆ‰พๅˆฐๆบไปฃ็ ๅ’Œ [่ฟ™้‡Œ](https://huggingface.co/fishaudio/s1-mini) ๆ‰พๅˆฐๆจกๅž‹.
The related code and weights are released under the CC BY-NC-SA 4.0 license.
็›ธๅ…ณไปฃ็ ๏ผŒๆƒ้‡ไฝฟ็”จ CC BY-NC-SA 4.0 ่ฎธๅฏ่ฏๅ‘ๅธƒ.
We are not responsible for any misuse of the model; please review your local laws and regulations before using it.
ๆˆ‘ไปฌไธๅฏนๆจกๅž‹็š„ไปปไฝ•ๆปฅ็”จ่ดŸ่ดฃ๏ผŒ่ฏทๅœจไฝฟ็”จไน‹ๅ‰่€ƒ่™‘ๆ‚จๅฝ“ๅœฐ็š„ๆณ•ๅพ‹ๆณ•่ง„.
The model running in this WebUI is Fish Audio S1 Mini.
ๅœจๆญค WebUI ไธญ่ฟ่กŒ็š„ๆจกๅž‹ๆ˜ฏ Fish Audio S1 Mini.
"""
TEXTBOX_PLACEHOLDER = """Put your text here. ๅœจๆญคๅค„่พ“ๅ…ฅๆ–‡ๆœฌ."""
try:
    import spaces

    GPU_DECORATOR = spaces.GPU
except ImportError:
    # Not running on HF Spaces: fall back to a no-op decorator
    def GPU_DECORATOR(func):
        return func
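
# A minimal, hypothetical sketch of how the decorator would wrap a GPU-bound
# function on HF Spaces (this app instead wires inference through
# get_inference_wrapper):
#
#   @GPU_DECORATOR
#   def run_on_gpu(batch):
#       return model(batch)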
def build_html_error_message(error):
return f"""
<div style="color: red;
font-weight: bold;">
{html.escape(str(error))}
</div>
"""
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    """Return the RIFF/WAV header bytes for a stream with the given format."""
buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_file:
wav_file.setnchannels(channels)
wav_file.setsampwidth(bit_depth // 8)
wav_file.setframerate(sample_rate)
wav_header_bytes = buffer.getvalue()
buffer.close()
return wav_header_bytes
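
# A minimal sketch (hypothetical helper, not wired into the app) showing how
# the header above can prefix raw 16-bit PCM chunks to form a streamable WAV:
def _example_stream_wav(pcm_chunks, sample_rate=44100):
    """Yield a WAV header followed by raw PCM chunks (illustrative only)."""
    yield wav_chunk_header(sample_rate=sample_rate, bit_depth=16, channels=1)
    for chunk in pcm_chunks:  # each chunk: bytes of little-endian int16 PCM
        yield chunk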
def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
with gr.Blocks(theme=gr.themes.Base()) as app:
gr.Markdown(HEADER_MD)
        # Force the requested theme via the __theme query parameter on first load
app.load(
None,
None,
js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
% theme,
)
# Inference
with gr.Row():
with gr.Column(scale=3):
text = gr.Textbox(
label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
)
with gr.Row():
with gr.Column():
with gr.Tab(label=i18n("Advanced Config")):
with gr.Row():
chunk_length = gr.Slider(
label=i18n("Iterative Prompt Length, 0 means off"),
minimum=0,
maximum=500,
value=0,
step=8,
)
max_new_tokens = gr.Slider(
label=i18n(
"Maximum tokens per batch, 0 means no limit"
),
minimum=0,
maximum=2048,
value=0,
step=8,
)
with gr.Row():
top_p = gr.Slider(
label="Top-P",
minimum=0.7,
maximum=0.95,
value=0.9,
step=0.01,
)
repetition_penalty = gr.Slider(
label=i18n("Repetition Penalty"),
minimum=1,
maximum=1.2,
value=1.1,
step=0.01,
)
with gr.Row():
temperature = gr.Slider(
label="Temperature",
minimum=0.7,
maximum=1.0,
value=0.9,
step=0.01,
)
seed = gr.Number(
label="Seed",
info="0 means randomized inference, otherwise deterministic",
value=0,
)
with gr.Tab(label=i18n("Reference Audio")):
with gr.Row():
gr.Markdown(
i18n(
"5 to 10 seconds of reference audio, useful for specifying speaker."
)
)
with gr.Row():
reference_id = gr.Textbox(
label=i18n("Reference ID"),
placeholder="Leave empty to use uploaded references",
)
with gr.Row():
use_memory_cache = gr.Radio(
label=i18n("Use Memory Cache"),
choices=["on", "off"],
value="on",
)
with gr.Row():
reference_audio = gr.Audio(
label=i18n("Reference Audio"),
type="filepath",
)
with gr.Row():
reference_text = gr.Textbox(
label=i18n("Reference Text"),
lines=1,
placeholder="ๅœจไธ€ๆ— ๆ‰€็Ÿฅไธญ๏ผŒๆขฆ้‡Œ็š„ไธ€ๅคฉ็ป“ๆŸไบ†๏ผŒไธ€ไธชๆ–ฐ็š„ใ€Œ่ฝฎๅ›žใ€ไพฟไผšๅผ€ๅง‹ใ€‚",
value="",
)
with gr.Column(scale=3):
with gr.Row():
error = gr.HTML(
label=i18n("Error Message"),
visible=True,
)
with gr.Row():
audio = gr.Audio(
label=i18n("Generated Audio"),
type="numpy",
interactive=False,
visible=True,
)
with gr.Row():
with gr.Column(scale=3):
generate = gr.Button(
value="\U0001f3a7 " + i18n("Generate"),
variant="primary",
)
# Submit
generate.click(
inference_fct,
[
text,
reference_id,
reference_audio,
reference_text,
max_new_tokens,
chunk_length,
top_p,
repetition_penalty,
temperature,
seed,
use_memory_cache,
],
[audio, error],
concurrency_limit=1,
)
return app
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--llama-checkpoint-path",
type=Path,
default="checkpoints/openaudio-s1-mini",
)
parser.add_argument(
"--decoder-checkpoint-path",
type=Path,
default="checkpoints/openaudio-s1-mini/codec.pth",
)
parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--half", action="store_true")
parser.add_argument("--compile", action="store_true",default=True)
parser.add_argument("--max-gradio-length", type=int, default=0)
parser.add_argument("--theme", type=str, default="dark")
return parser.parse_args()
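
# Example invocation (hypothetical; paths assume the snapshot downloaded above):
#   python app.py --device cuda --half --theme light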
if __name__ == "__main__":
args = parse_args()
args.precision = torch.half if args.half else torch.bfloat16
logger.info("Loading Llama model...")
llama_queue = launch_thread_safe_queue(
checkpoint_path=args.llama_checkpoint_path,
device=args.device,
precision=args.precision,
compile=args.compile,
)
logger.info("Llama model loaded, loading VQ-GAN model...")
decoder_model = load_decoder_model(
config_name=args.decoder_config_name,
checkpoint_path=args.decoder_checkpoint_path,
device=args.device,
)
logger.info("Decoder model loaded, warming up...")
# Create the inference engine
inference_engine = TTSInferenceEngine(
llama_queue=llama_queue,
decoder_model=decoder_model,
compile=args.compile,
precision=args.precision,
)
    # Dry run to verify the models load correctly and to absorb the first-call
    # latency; inference() returns a generator, so list() drains it fully
list(
inference_engine.inference(
ServeTTSRequest(
text="Hello world.",
references=[],
reference_id=None,
max_new_tokens=1024,
chunk_length=200,
top_p=0.7,
repetition_penalty=1.5,
temperature=0.7,
format="wav",
)
)
)
logger.info("Warming up done, launching the web UI...")
inference_fct = get_inference_wrapper(inference_engine)
app = build_app(inference_fct, args.theme)
app.queue(api_open=True).launch(show_error=True, show_api=True)
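
# A minimal, hypothetical sketch of calling the exposed API once the app is
# running, using the gradio_client package (argument order mirrors the inputs
# wired to generate.click above; check the app's "Use via API" page for the
# actual endpoint name):
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   audio, error = client.predict(
#       "Hello world.",  # text
#       "",              # reference_id
#       None,            # reference_audio
#       "",              # reference_text
#       0,               # max_new_tokens
#       0,               # chunk_length
#       0.9,             # top_p
#       1.1,             # repetition_penalty
#       0.9,             # temperature
#       0,               # seed
#       "on",            # use_memory_cache
#   )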