| | import gc |
| | import html |
| | import io |
| | import os |
| | import queue |
| | import wave |
| | from argparse import ArgumentParser |
| | from functools import partial |
| | from pathlib import Path |
| |
|
| | import gradio as gr |
| | import librosa |
| | import numpy as np |
| | import pyrootutils |
| | import torch |
| | from loguru import logger |
| | from transformers import AutoTokenizer |
| |
|
# Register the project root (marked by a ".project-root" file) on sys.path so
# the absolute `fish_speech` / `tools` imports below resolve when this file is
# run directly as a script.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
| |
|
| |
|
| | from fish_speech.i18n import i18n |
| | from fish_speech.text.chn_text_norm.text import Text as ChnNormedText |
| | from fish_speech.utils import autocast_exclude_mps |
| | from tools.api import decode_vq_tokens, encode_reference |
| | from tools.llama.generate import ( |
| | GenerateRequest, |
| | GenerateResponse, |
| | WrappedGenerateResponse, |
| | launch_thread_safe_queue, |
| | ) |
| | from tools.vqgan.inference import load_model as load_decoder_model |
| |
|
| | |
| | os.environ["EINX_FILTER_TRACEBACK"] = "false" |
| |
|
| |
|
# Markdown banner rendered at the top of the Gradio page; built once at import
# time. Every user-facing sentence goes through i18n() for localization.
HEADER_MD = f"""# Fish Speech

{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}

{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}

{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}

{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
"""

# Placeholder text shown in the main input textbox.
TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
# NOTE(review): SPACE_IMPORTED is never read in this file — looks like a
# leftover Hugging Face Spaces flag; confirm before removing.
SPACE_IMPORTED = False
| |
|
| |
|
def build_html_error_message(error):
    """Render *error* as a red, bold HTML snippet for a gr.HTML component."""
    # Escape first so arbitrary exception text cannot inject markup.
    escaped = html.escape(str(error))
    return f"""
    <div style="color: red;
    font-weight: bold;">
    {escaped}
    </div>
    """
| |
|
| |
|
@torch.inference_mode()
def inference(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    streaming=False,
):
    """Synthesize speech for *text*; a generator of UI update tuples.

    Yields ``(stream_bytes, final_audio, error_html)``:
      * ``stream_bytes`` — raw 16-bit PCM chunks (only when ``streaming=True``),
      * ``final_audio``  — ``(sample_rate, np.ndarray)`` for gr.Audio,
      * ``error_html``   — an HTML error message from build_html_error_message.

    Relies on module-level globals bound in ``__main__``: ``args``,
    ``llama_queue`` and ``decoder_model``.
    """
    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
        # BUG FIX: this function is a generator, so a bare ``return (...)``
        # only sets StopIteration.value — the caller never received the
        # error message. Yield the error tuple, then stop.
        yield (
            None,
            None,
            i18n("Text is too long, please keep it under {} characters.").format(
                args.max_gradio_length
            ),
        )
        return

    # Encode the reference audio (speaker prompt) into VQ tokens.
    prompt_tokens = encode_reference(
        decoder_model=decoder_model,
        reference_audio=reference_audio,
        enable_reference_audio=enable_reference_audio,
    )

    # LLAMA inference: hand the request to the worker thread and collect
    # responses through a private per-request queue.
    request = dict(
        device=decoder_model.device,
        max_new_tokens=max_new_tokens,
        text=text,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=args.compile,
        iterative_prompt=chunk_length > 0,
        chunk_length=chunk_length,
        max_length=2048,
        prompt_tokens=prompt_tokens if enable_reference_audio else None,
        prompt_text=reference_text if enable_reference_audio else None,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    if streaming:
        # Emit a WAV header first so the browser can start playback early.
        yield wav_chunk_header(), None, None

    segments = []

    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            yield None, None, build_html_error_message(result.response)
            # BUG FIX: was ``break``, which fell through to the
            # "No audio generated" branch below and masked the real error.
            # Stop right after reporting the specific one.
            return

        result: GenerateResponse = result.response
        if result.action == "next":
            # Worker signals end of generation.
            break

        with autocast_exclude_mps(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()
        segments.append(fake_audios)

        if streaming:
            # Convert float samples (assumed in [-1, 1]) to 16-bit PCM bytes.
            yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None

    if len(segments) == 0:
        # BUG FIX: same generator-return issue as above — yield the error
        # tuple so the caller can actually display it.
        yield (
            None,
            None,
            build_html_error_message(
                i18n("No audio generated, please check the input text.")
            ),
        )
        return

    # Whether streaming or not, always deliver the final concatenated audio.
    audio = np.concatenate(segments, axis=0)
    yield None, (decoder_model.spec_transform.sample_rate, audio), None

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
| |
|
| |
|
# Streaming variant of ``inference`` used by the "Streaming Generate" button.
inference_stream = partial(inference, streaming=True)

# Fixed number of audio/error output slots rendered for batch inference.
n_audios = 4

# Output components created inside build_app(); the click handlers there
# reference these module-level lists.
global_audio_list = []
global_error_list = []
| |
|
| |
|
def inference_wrapper(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    batch_infer_num,
):
    """Run non-streaming inference ``batch_infer_num`` times for batch mode.

    Returns ``(None, *audios, *errors)`` matching the fixed ``n_audios``
    gr.Audio / gr.HTML output slots: the first ``batch_infer_num`` slots
    become visible with results, the remainder are hidden.
    """
    audios = []
    errors = []

    for _ in range(batch_infer_num):
        result = inference(
            text,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            chunk_length,
            top_p,
            repetition_penalty,
            temperature,
        )

        # ROBUSTNESS FIX: ``inference`` is a generator; if it finishes
        # without yielding (e.g. an early ``return``), a bare ``next()``
        # raises StopIteration and crashes the whole handler. Fall back to
        # a generic error message instead.
        _, audio_data, error_message = next(
            result,
            (
                None,
                None,
                build_html_error_message(
                    i18n("No audio generated, please check the input text.")
                ),
            ),
        )

        audios.append(
            gr.Audio(value=audio_data if audio_data else None, visible=True),
        )
        errors.append(
            gr.HTML(value=error_message if error_message else None, visible=True),
        )

    # Hide the unused output slots.
    for _ in range(batch_infer_num, n_audios):
        audios.append(
            gr.Audio(value=None, visible=False),
        )
        errors.append(
            gr.HTML(value=None, visible=False),
        )

    return None, *audios, *errors
| |
|
| |
|
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    """Return the bytes of a WAV file header with zero audio frames.

    Sent as the first chunk of a streaming response so the browser can
    begin playback before any PCM data arrives.
    """
    with io.BytesIO() as header_buf:
        writer = wave.open(header_buf, "wb")
        writer.setnchannels(channels)
        writer.setsampwidth(bit_depth // 8)
        writer.setframerate(sample_rate)
        # Closing the writer finalizes the RIFF/WAVE header in the buffer.
        writer.close()
        return header_buf.getvalue()
| |
|
| |
|
def normalize_text(user_input, use_normalization):
    """Apply Chinese text normalization to *user_input* when enabled.

    With normalization off the input passes through unchanged.
    """
    if not use_normalization:
        return user_input
    return ChnNormedText(raw_text=user_input).normalize()
| |
|
| |
|
# NOTE(review): asr_model is never used in this file — presumably a
# placeholder for an ASR model loaded elsewhere; confirm before removing.
asr_model = None
| |
|
| |
|
def build_app():
    """Construct the Gradio Blocks UI and wire up the event handlers.

    Reads module-level globals: ``args`` (theme), ``n_audios``,
    ``inference_wrapper`` / ``inference_stream``, and appends the created
    output components to ``global_audio_list`` / ``global_error_list``.

    Returns the gr.Blocks app (launched by ``__main__``).
    """
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Force the chosen theme via the "__theme" query parameter on load.
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % args.theme,
        )

        # Left column: text input + normalization preview + settings tabs.
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                )
                # Read-only preview, updated live by normalize_text below.
                refined_text = gr.Textbox(
                    label=i18n("Realtime Transform Text"),
                    placeholder=i18n(
                        "Normalization Result Preview (Currently Only Chinese)"
                    ),
                    lines=5,
                    interactive=False,
                )

                with gr.Row():
                    if_refine_text = gr.Checkbox(
                        label=i18n("Text Normalization"),
                        value=False,
                        scale=1,
                    )

                with gr.Row():
                    # Sampling / chunking hyper-parameters.
                    with gr.Tab(label=i18n("Advanced Config")):
                        chunk_length = gr.Slider(
                            label=i18n("Iterative Prompt Length, 0 means off"),
                            minimum=50,
                            maximum=300,
                            value=200,
                            step=8,
                        )

                        max_new_tokens = gr.Slider(
                            label=i18n("Maximum tokens per batch, 0 means no limit"),
                            minimum=0,
                            maximum=2048,
                            value=1024,
                            step=8,
                        )

                        top_p = gr.Slider(
                            label="Top-P",
                            minimum=0.6,
                            maximum=0.9,
                            value=0.7,
                            step=0.01,
                        )

                        repetition_penalty = gr.Slider(
                            label=i18n("Repetition Penalty"),
                            minimum=1,
                            maximum=1.5,
                            value=1.2,
                            step=0.01,
                        )

                        temperature = gr.Slider(
                            label="Temperature",
                            minimum=0.6,
                            maximum=0.9,
                            value=0.7,
                            step=0.01,
                        )

                    # Optional speaker-conditioning inputs.
                    with gr.Tab(label=i18n("Reference Audio")):
                        gr.Markdown(
                            i18n(
                                "5 to 10 seconds of reference audio, useful for specifying speaker."
                            )
                        )

                        enable_reference_audio = gr.Checkbox(
                            label=i18n("Enable Reference Audio"),
                        )
                        reference_audio = gr.Audio(
                            label=i18n("Reference Audio"),
                            type="filepath",
                        )
                        with gr.Row():
                            reference_text = gr.Textbox(
                                label=i18n("Reference Text"),
                                lines=1,
                                placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                value="",
                            )
                    with gr.Tab(label=i18n("Batch Inference")):
                        batch_infer_num = gr.Slider(
                            label="Batch infer nums",
                            minimum=1,
                            maximum=n_audios,
                            step=1,
                            value=1,
                        )

            # Right column: fixed pool of n_audios error/audio output slots
            # (only the first is visible initially), plus the streaming player.
            with gr.Column(scale=3):
                for _ in range(n_audios):
                    with gr.Row():
                        error = gr.HTML(
                            label=i18n("Error Message"),
                            visible=True if _ == 0 else False,
                        )
                        global_error_list.append(error)
                    with gr.Row():
                        audio = gr.Audio(
                            label=i18n("Generated Audio"),
                            type="numpy",
                            interactive=False,
                            visible=True if _ == 0 else False,
                        )
                        global_audio_list.append(audio)

                with gr.Row():
                    stream_audio = gr.Audio(
                        label=i18n("Streaming Audio"),
                        streaming=True,
                        autoplay=True,
                        interactive=False,
                        show_download_button=True,
                    )
                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001F3A7 " + i18n("Generate"), variant="primary"
                        )
                        generate_stream = gr.Button(
                            value="\U0001F3A7 " + i18n("Streaming Generate"),
                            variant="primary",
                        )

        # Live normalization preview while the user types.
        text.input(
            fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
        )

        # Batch (non-streaming) generation fills the audio/error slots.
        generate.click(
            inference_wrapper,
            [
                refined_text,
                enable_reference_audio,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                batch_infer_num,
            ],
            [stream_audio, *global_audio_list, *global_error_list],
            concurrency_limit=1,
        )

        # Streaming generation feeds the streaming player and first slot pair.
        generate_stream.click(
            inference_stream,
            [
                refined_text,
                enable_reference_audio,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
            ],
            [stream_audio, global_audio_list[0], global_error_list[0]],
            concurrency_limit=10,
        )
    return app
| |
|
| |
|
def parse_args():
    """Parse command-line options for the web UI.

    Returns an ``argparse.Namespace``; ``__main__`` later attaches a derived
    ``precision`` attribute to it.
    """
    parser = ArgumentParser()
    # IMPROVEMENT: added help= strings so `--help` documents each option;
    # names, types and defaults are unchanged.
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.4",
        help="Directory containing the Llama text2semantic checkpoint",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
        help="Path to the VQ-GAN decoder generator checkpoint",
    )
    parser.add_argument(
        "--decoder-config-name",
        type=str,
        default="firefly_gan_vq",
        help="Config name used to build the decoder model",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="Torch device (cuda / cpu / mps)"
    )
    parser.add_argument(
        "--half", action="store_true", help="Use float16 instead of bfloat16"
    )
    parser.add_argument(
        "--compile", action="store_true", help="torch.compile the Llama model"
    )
    parser.add_argument(
        "--max-gradio-length",
        type=int,
        default=0,
        help="Reject input texts longer than this many characters (0 = unlimited)",
    )
    parser.add_argument(
        "--theme", type=str, default="light", help="Gradio theme (light or dark)"
    )

    return parser.parse_args()
| |
|
| |
|
if __name__ == "__main__":
    args = parse_args()
    # bfloat16 by default; --half switches to float16.
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    # Worker thread + request queue used by inference() above.
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("Decoder model loaded, warming up...")

    # Dry run to surface load-time errors and warm the model up (also
    # triggers torch.compile, if enabled, before the first real request).
    list(
        inference(
            text="Hello, world!",
            enable_reference_audio=False,
            reference_audio=None,
            reference_text="",
            max_new_tokens=1024,
            chunk_length=200,
            top_p=0.7,
            repetition_penalty=1.2,
            temperature=0.7,
        )
    )

    logger.info("Warming up done, launching the web UI...")

    app = build_app()
    app.launch(show_api=True)
| |
|