| | from __future__ import annotations
|
| |
|
| | import os
|
| |
|
| | os.environ["USE_LIBUV"] = "0"
|
| | import datetime
|
| | import html
|
| | import json
|
| | import platform
|
| | import shutil
|
| | import signal
|
| | import subprocess
|
| | import sys
|
| | from pathlib import Path
|
| |
|
| | import gradio as gr
|
| | import psutil
|
| | import yaml
|
| | from loguru import logger
|
| | from tqdm import tqdm
|
| |
|
# Resolve the interpreter used to launch all child tool processes.
# PYTHON_FOLDERPATH lets a portable/bundled environment override the
# system interpreter; when unset this resolves to plain "python" on PATH.
PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
# Ensure the current working directory is importable by child scripts.
sys.path.insert(0, "")
print(sys.path)
cur_work_dir = Path(os.getcwd()).resolve()
print("You are in ", str(cur_work_dir))
|
| |
|
| | from fish_speech.i18n import i18n
|
| | from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
|
| |
|
# Shipped hydra config files; used below to pull default hyper-parameters.
config_path = cur_work_dir / "fish_speech" / "configs"
vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
llama_yml_path = config_path / "text2semantic_finetune.yaml"

# Environment handed to every subprocess; local addresses must bypass any
# configured HTTP proxy or the spawned web servers become unreachable.
env = os.environ.copy()
env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"

# Gradio theme instance shared by the whole UI.
seafoam = Seafoam()
|
| |
|
| |
|
def build_html_error_message(error):
    """Render *error* as a red, bold HTML block (content is HTML-escaped)."""
    safe_text = html.escape(error)
    return f"""
    <div style="color: red; font-weight: bold;">
        {safe_text}
    </div>
    """
|
| |
|
| |
|
def build_html_ok_message(msg):
    """Render *msg* as a green, bold HTML block (content is HTML-escaped)."""
    safe_text = html.escape(msg)
    return f"""
    <div style="color: green; font-weight: bold;">
        {safe_text}
    </div>
    """
|
| |
|
| |
|
def build_html_href(link, desc, msg):
    """Render *msg* followed by an anchor (*desc* pointing at *link*) in green.

    Note: *msg* is HTML-escaped, but *link* and *desc* are interpolated
    verbatim — callers pass trusted, locally-built values only.
    """
    safe_msg = html.escape(msg)
    return f"""
    <span style="color: green; font-weight: bold; display: inline-block">
        {safe_msg}
        <a href="{link}">{desc}</a>
    </span>
    """
|
| |
|
| |
|
def load_data_in_raw(path):
    """Read a UTF-8 text file and return its entire contents as a string."""
    return Path(path).read_text(encoding="utf-8")
|
| |
|
| |
|
def kill_proc_tree(pid, including_parent=True):
    """Send SIGTERM to a process tree: all descendants of *pid*, then
    (when *including_parent*) the process itself.

    Silently returns if the process is already gone; individual kill
    failures (e.g. a child exiting first) are ignored.
    """
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return

    # Children first, parent last — same order as terminating one by one.
    targets = parent.children(recursive=True)
    if including_parent:
        targets.append(parent)
    for proc in targets:
        try:
            os.kill(proc.pid, signal.SIGTERM)
        except OSError:
            pass
|
| |
|
| |
|
# Platform name ("Windows" / "Linux" / "Darwin") used to pick the kill strategy.
system = platform.system()
# Handles to child processes spawned by this UI; None while not running.
p_label = None  # external ASR labeling tool
p_infer = None  # inference WebUI subprocess
p_tensorboard = None  # tensorboard subprocess
|
| |
|
| |
|
def kill_process(pid):
    """Terminate *pid* and its whole child tree, platform-appropriately.

    Windows has no process-group SIGTERM semantics, so taskkill /t is used
    there; elsewhere we walk the tree ourselves via kill_proc_tree().
    """
    if system == "Windows":
        # Argument-list form instead of a formatted shell string: no word
        # splitting surprises, and pid is passed through exactly once.
        subprocess.run(["taskkill", "/t", "/f", "/pid", str(pid)])
    else:
        kill_proc_tree(pid)
|
| |
|
| |
|
def change_label(if_label):
    """Start or stop the external ASR labeling tool.

    Checking the box launches the platform-specific asr-label executable
    (best effort — a missing binary only logs a warning) and yields a link
    to the hosted fallback labeler; unchecking kills the subprocess.

    Yields an HTML status fragment for the Gradio status area.
    """
    global p_label
    if if_label and p_label is None:
        remote_url = "https://text-labeler.pages.dev/"
        try:
            p_label = subprocess.Popen(
                [
                    (
                        "asr-label-linux-x64"
                        if sys.platform == "linux"
                        else "asr-label-win-x64.exe"
                    )
                ]
            )
        except FileNotFoundError:
            logger.warning("asr-label execution not found!")

        # Shown even if the local binary failed: the remote labeler still works.
        yield build_html_href(
            link=remote_url,
            desc=i18n("Optional online ver"),
            msg=i18n("Opened labeler in browser"),
        )

    elif not if_label and p_label is not None:
        kill_process(p_label.pid)
        p_label = None
        yield build_html_ok_message("Nothing")
|
| |
|
| |
|
def clean_infer_cache():
    """Best-effort removal of gradio's temporary audio cache directory."""
    import tempfile

    gradio_dir = str(Path(tempfile.gettempdir()) / "gradio")
    try:
        shutil.rmtree(gradio_dir)
    except PermissionError:
        logger.info(f"Permission denied: Unable to delete {gradio_dir}")
    except FileNotFoundError:
        logger.info(f"{gradio_dir} was not found")
    except Exception as e:
        logger.info(f"An error occurred: {e}")
    else:
        logger.info(f"Deleted cached audios: {gradio_dir}")
|
| |
|
| |
|
def change_infer(
    if_infer,
    host,
    port,
    infer_decoder_model,
    infer_decoder_config,
    infer_llama_model,
    infer_compile,
):
    """Start or stop the inference WebUI (tools/webui.py) subprocess.

    Checking the box clears gradio's audio cache and launches the server
    bound to *host*:*port* with the selected decoder/LLAMA checkpoints;
    unchecking kills the subprocess.

    Yields an HTML status fragment for the Gradio status area.
    """
    global p_infer
    # Truthiness / `is None` instead of `== True` / `== None` (PEP 8).
    if if_infer and p_infer is None:
        env = os.environ.copy()

        # Gradio reads its bind address from these environment variables.
        env["GRADIO_SERVER_NAME"] = host
        env["GRADIO_SERVER_PORT"] = port

        url = f"http://{host}:{port}"
        yield build_html_ok_message(
            i18n("Inferring interface is launched at {}").format(url)
        )

        clean_infer_cache()

        p_infer = subprocess.Popen(
            [
                PYTHON,
                "tools/webui.py",
                "--decoder-checkpoint-path",
                infer_decoder_model,
                "--decoder-config-name",
                infer_decoder_config,
                "--llama-checkpoint-path",
                infer_llama_model,
            ]
            + (["--compile"] if infer_compile == "Yes" else []),
            env=env,
        )

    elif not if_infer and p_infer is not None:
        kill_process(p_infer.pid)
        p_infer = None
        yield build_html_error_message(i18n("Infer interface is closed"))
|
| |
|
| |
|
# Custom JS/CSS injected into the Gradio page.
js = load_data_in_raw("fish_speech/webui/js/animate.js")
css = load_data_in_raw("fish_speech/webui/css/style.css")

# Where preprocessed training data and trained models are written.
data_pre_output = (cur_work_dir / "data").resolve()
default_model_output = (cur_work_dir / "results").resolve()
default_filelist = data_pre_output / "detect.list"
data_pre_output.mkdir(parents=True, exist_ok=True)

# Registered data sources: `items` keeps insertion order for the checkbox
# group, `dict_items` maps each path to its preprocessing options.
items = []
dict_items = {}
|
| |
|
| |
|
def load_yaml_data_in_fact(yml_path):
    """Parse a YAML file (UTF-8) and return the loaded object."""
    with open(yml_path, "r", encoding="utf-8") as fp:
        return yaml.safe_load(fp)
|
| |
|
| |
|
def write_yaml_data_in_fact(yml, yml_path):
    """Serialize *yml* to *yml_path* as YAML (UTF-8) and return it unchanged."""
    with open(yml_path, "w", encoding="utf-8") as fp:
        yaml.safe_dump(yml, fp, allow_unicode=True)
    return yml
|
| |
|
| |
|
def generate_tree(directory, depth=0, max_depth=None, prefix=""):
    """Render *directory* as an HTML tree (one ``<br />``-terminated row per
    entry), recursing into subdirectories up to *max_depth* levels.

    Directories are listed before files; within each group, filesystem
    (os.listdir) order is kept.
    """
    if max_depth is not None and depth > max_depth:
        return ""

    raw_entries = os.listdir(directory)
    directories = [
        e for e in raw_entries if os.path.isdir(os.path.join(directory, e))
    ]
    files = [
        e for e in raw_entries if not os.path.isdir(os.path.join(directory, e))
    ]

    entries = directories + files
    rows = []
    for i, entry in enumerate(entries):
        is_last = i == len(entries) - 1
        connector = "└── " if is_last else "├── "
        rows.append(f"{prefix}{connector}{entry}<br />")
        # Directories come first in `entries`, so index < len(directories)
        # identifies exactly the directory rows to recurse into.
        if i < len(directories):
            extension = " " if is_last else "│ "
            rows.append(
                generate_tree(
                    os.path.join(directory, entry),
                    depth + 1,
                    max_depth,
                    prefix=prefix + extension,
                )
            )
    return "".join(rows)
|
| |
|
| |
|
def new_explorer(data_path, max_depth):
    """Return a scrollable Markdown widget showing the tree under *data_path*."""
    tree_markup = generate_tree(data_path, max_depth=max_depth)
    return gr.Markdown(elem_classes=["scrollable-component"], value=tree_markup)
|
| |
|
| |
|
def add_item(
    folder: str,
    method: str,
    label_lang: str,
    if_initial_prompt: bool,
    initial_prompt: str | None,
):
    """Register a data-source folder for preprocessing.

    Rejects non-directories, duplicates, and anything located inside the
    preprocessing output dir (to avoid recursive self-imports).

    Returns (updated checkbox group, HTML status message).
    """
    # Tolerate copy-pasted paths padded with spaces or wrapped in quotes.
    folder = folder.strip(" ").strip('"')

    folder_path = Path(folder)

    # `data_pre_output not in folder_path.parents` guards against adding a
    # folder that already lives under the output tree.
    if folder and folder not in items and data_pre_output not in folder_path.parents:
        if folder_path.is_dir():
            items.append(folder)
            dict_items[folder] = dict(
                type="folder",
                method=method,
                label_lang=label_lang,
                # Only keep the prompt when the checkbox enabled it.
                initial_prompt=initial_prompt if if_initial_prompt else None,
            )
        elif folder:
            err = folder
            return gr.Checkboxgroup(choices=items), build_html_error_message(
                i18n("Invalid path: {}").format(err)
            )

    formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
    logger.info("After Adding: " + formatted_data)
    gr.Info(formatted_data)
    return gr.Checkboxgroup(choices=items), build_html_ok_message(
        i18n("Added path successfully!")
    )
|
| |
|
| |
|
def remove_items(selected_items):
    """Drop the selected entries from the global data-source registry.

    Returns (cleared checkbox group, HTML status message).
    """
    global items, dict_items
    for entry in [it for it in items if it in selected_items]:
        del dict_items[entry]
    # Keep `items` in sync (and in its original order) with dict_items.
    items = [it for it in items if it in dict_items.keys()]
    formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
    logger.info(formatted_data)
    gr.Warning("After Removing: " + formatted_data)
    return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
        i18n("Removed path successfully!")
    )
|
| |
|
| |
|
def show_selected(options):
    """Summarize the currently checked data sources as one status line."""
    if not options:
        return i18n("No selected options")
    return i18n("Selected: {}").format(", ".join(options))
|
| |
|
| |
|
| | from pydub import AudioSegment
|
| |
|
| |
|
def convert_to_mono_in_place(audio_path: Path):
    """Downmix a multi-channel audio file to mono, overwriting it in place.

    Already-mono files are left untouched.  The export format is taken from
    the file's own extension.
    """
    audio = AudioSegment.from_file(audio_path)
    if audio.channels > 1:
        audio.set_channels(1).export(audio_path, format=audio_path.suffix[1:])
        logger.info(f"Convert {audio_path} successfully")
|
| |
|
| |
|
def list_copy(list_file_path, method):
    """Import the audio + ``.lab`` transcript pairs referenced by a filelist
    into the preprocessing output directory.

    Each filelist line is ``wav_path|speaker|language|text``.  Files land
    under ``data_pre_output/<original parent folder name>/``.  When *method*
    is the localized "Move", the filelist is rewritten afterwards so its
    entries point at the new locations.

    Returns an HTML status message.
    """
    wav_root = data_pre_output
    lst = []  # rewritten filelist lines, pointing at the target paths
    with list_file_path.open("r", encoding="utf-8") as file:
        for line in tqdm(file, desc="Processing audio/transcript"):
            wav_path, speaker_name, language, text = line.strip().split("|")
            original_wav_path = Path(wav_path)
            target_wav_path = (
                wav_root / original_wav_path.parent.name / original_wav_path.name
            )
            lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
            # NOTE(review): this `continue` also skips the .lab handling below
            # for already-imported audio — presumably fine because the .lab was
            # imported in the same earlier run, but confirm.
            if target_wav_path.is_file():
                continue
            target_wav_path.parent.mkdir(parents=True, exist_ok=True)
            if method == i18n("Copy"):
                shutil.copy(original_wav_path, target_wav_path)
            else:
                shutil.move(original_wav_path, target_wav_path.parent)
            convert_to_mono_in_place(target_wav_path)
            original_lab_path = original_wav_path.with_suffix(".lab")
            target_lab_path = (
                wav_root
                / original_wav_path.parent.name
                / original_wav_path.with_suffix(".lab").name
            )
            if target_lab_path.is_file():
                continue
            if method == i18n("Copy"):
                shutil.copy(original_lab_path, target_lab_path)
            else:
                shutil.move(original_lab_path, target_lab_path.parent)

    # After a move the original paths no longer exist: rewrite the filelist.
    if method == i18n("Move"):
        with list_file_path.open("w", encoding="utf-8") as file:
            file.writelines("\n".join(lst))

    del lst
    return build_html_ok_message(i18n("Use filelist"))
|
| |
|
| |
|
def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
    """Pre-process every registered data source into *data_path*.

    Folder sources are copied/moved into the training data dir, all audio in
    them is downmixed to mono, and — unless the label language is "IGNORE" —
    whisper ASR is run to produce transcripts.  Filelist sources are imported
    via list_copy().

    Returns (HTML status message, refreshed directory-tree widget).
    """
    global dict_items
    data_path = Path(data_path)
    gr.Warning("Pre-processing begins...")
    for item, content in dict_items.items():
        item_path = Path(item)
        tar_path = data_path / item_path.name

        if content["type"] == "folder" and item_path.is_dir():
            if content["method"] == i18n("Copy"):
                os.makedirs(tar_path, exist_ok=True)
                shutil.copytree(
                    src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
                )
            elif not tar_path.is_dir():
                # "Move": only when the target doesn't exist yet (no merging).
                shutil.move(src=str(item_path), dst=str(tar_path))

            for suf in ["wav", "flac", "mp3"]:
                for audio_path in tar_path.glob(f"**/*.{suf}"):
                    convert_to_mono_in_place(audio_path)

            cur_lang = content["label_lang"]
            initial_prompt = content["initial_prompt"]

            transcribe_cmd = [
                PYTHON,
                "tools/whisper_asr.py",
                "--model-size",
                label_model,
                "--device",
                label_device,
                "--audio-dir",
                tar_path,
                "--save-dir",
                tar_path,
                "--language",
                cur_lang,
            ]

            if initial_prompt is not None:
                transcribe_cmd += ["--initial-prompt", initial_prompt]

            if cur_lang != "IGNORE":
                try:
                    gr.Warning("Begin To Transcribe")
                    subprocess.run(
                        transcribe_cmd,
                        env=env,
                    )
                except Exception:
                    # Best-effort: a failed transcription must not abort the
                    # rest of the preprocessing loop.
                    print("Transcription error occurred")

        elif content["type"] == "file" and item_path.is_file():
            list_copy(item_path, content["method"])

    return build_html_ok_message(i18n("Move files successfully")), new_explorer(
        data_path, max_depth=max_depth
    )
|
| |
|
| |
|
def generate_folder_name():
    """Return a timestamp-based project folder name, e.g. ``20240131_235959``."""
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| |
|
| |
|
def train_process(
    data_path: str,
    option: str,
    # LLAMA hyper-parameters, forwarded to fish_speech/train.py:
    llama_ckpt,
    llama_base_config,
    llama_lr,
    llama_maxsteps,
    llama_data_num_workers,
    llama_data_batch_size,
    llama_data_max_length,
    llama_precision,
    llama_check_interval,
    llama_grad_batches,
    llama_use_speaker,
    llama_use_lora,
):
    """Run the selected training pipeline.

    "VQGAN" is a no-op (the shipped decoder needs no fine-tuning).  "LLAMA"
    extracts VQ tokens from the preprocessed audio, builds the protobuf
    dataset, then launches fish_speech/train.py with the chosen
    hyper-parameters.  Returns an HTML status message.
    """
    # nccl is Linux-only; other platforms fall back to gloo.
    backend = "nccl" if sys.platform == "linux" else "gloo"

    new_project = generate_folder_name()
    print("New Project Name: ", new_project)

    if option == "VQGAN":
        msg = "Skipped VQGAN Training."
        gr.Warning(msg)
        logger.info(msg)

    if option == "LLAMA":
        msg = "LLAMA Training begins..."
        gr.Warning(msg)
        logger.info(msg)
        # Step 1: quantize all preprocessed audio into VQ token files.
        subprocess.run(
            [
                PYTHON,
                "tools/vqgan/extract_vq.py",
                str(data_pre_output),
                "--num-workers",
                "1",
                "--batch-size",
                "16",
                "--config-name",
                "firefly_gan_vq",
                "--checkpoint-path",
                "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
            ]
        )

        # Step 2: pack (.lab transcript, VQ token) pairs into the dataset.
        subprocess.run(
            [
                PYTHON,
                "tools/llama/build_dataset.py",
                "--input",
                str(data_pre_output),
                "--text-extension",
                ".lab",
                "--num-workers",
                "16",
            ]
        )
        # NOTE(review): ckpt_path appears unused below — confirm it can go.
        ckpt_path = "checkpoints/fish-speech-1.4/model.pth"
        lora_prefix = "lora_" if llama_use_lora else ""
        llama_name = lora_prefix + "text2semantic_" + new_project
        # Most recent matching project under results/ (names are timestamps,
        # so reverse lexicographic sort = newest first); falls back to the
        # freshly generated name when none exists yet.
        latest = next(
            iter(
                sorted(
                    [
                        str(p.relative_to("results"))
                        for p in Path("results").glob(lora_prefix + "text2sem*/")
                    ],
                    reverse=True,
                )
            ),
            llama_name,
        )
        # Resolve the project dir: brand new, resume latest, or an explicit path.
        project = (
            llama_name
            if llama_ckpt == i18n("new")
            else (
                latest
                if llama_ckpt == i18n("latest")
                else Path(llama_ckpt).relative_to("results")
            )
        )
        logger.info(project)

        # Saving more often than we train makes no sense; clamp the interval.
        if llama_check_interval > llama_maxsteps:
            llama_check_interval = llama_maxsteps

        train_cmd = [
            PYTHON,
            "fish_speech/train.py",
            "--config-name",
            "text2semantic_finetune",
            f"project={project}",
            f"trainer.strategy.process_group_backend={backend}",
            f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
            f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
            f"model.optimizer.lr={llama_lr}",
            f"trainer.max_steps={llama_maxsteps}",
            f"data.num_workers={llama_data_num_workers}",
            f"data.batch_size={llama_data_batch_size}",
            f"max_length={llama_data_max_length}",
            f"trainer.precision={llama_precision}",
            f"trainer.val_check_interval={llama_check_interval}",
            f"trainer.accumulate_grad_batches={llama_grad_batches}",
            f"train_dataset.interactive_prob={llama_use_speaker}",
        ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
        logger.info(train_cmd)
        subprocess.run(train_cmd)

    return build_html_ok_message(i18n("Training stopped"))
|
| |
|
| |
|
def tensorboard_process(
    if_tensorboard: bool,
    tensorboard_dir: str,
    host: str,
    port: str,
):
    """Start or stop a tensorboard subprocess serving *tensorboard_dir*.

    Yields an HTML status fragment for the Gradio status area.
    """
    global p_tensorboard
    # Truthiness / `is None` instead of `== True` / `== None` (PEP 8).
    if if_tensorboard and p_tensorboard is None:
        url = f"http://{host}:{port}"
        yield build_html_ok_message(
            i18n("Tensorboard interface is launched at {}").format(url)
        )
        prefix = ["tensorboard"]
        # The portable Windows bundle ships its own interpreter + executable.
        if Path("fishenv").exists():
            prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]

        p_tensorboard = subprocess.Popen(
            prefix
            + [
                "--logdir",
                tensorboard_dir,
                "--host",
                host,
                "--port",
                port,
                "--reload_interval",
                "120",
            ]
        )
    elif not if_tensorboard and p_tensorboard is not None:
        kill_process(p_tensorboard.pid)
        p_tensorboard = None
        yield build_html_error_message(i18n("Tensorboard interface is closed"))
|
| |
|
| |
|
def fresh_tb_dir():
    """Refresh the dropdown with every tensorboard log dir under results/."""
    tb_dirs = [str(p) for p in Path("results").glob("**/tensorboard/")]
    return gr.Dropdown(choices=tb_dirs)
|
| |
|
| |
|
def list_decoder_models():
    """Return decoder (vocoder) checkpoint paths under checkpoints/fish*/."""
    found = list(map(str, Path("checkpoints").glob("fish*/firefly*.pth")))
    if not found:
        logger.warning("No decoder model found")
    return found
|
| |
|
| |
|
def list_llama_models():
    """Return LLAMA checkpoint folders (merged, fish-speech, fs releases),
    newest-looking first (reverse lexicographic)."""
    patterns = ("merged*/*model*.pth", "fish*/*model*.pth", "fs*/*model*.pth")
    choices = [
        str(p.parent)
        for pattern in patterns
        for p in Path("checkpoints").glob(pattern)
    ]
    choices = sorted(choices, reverse=True)
    if not choices:
        logger.warning("No LLaMA model found")
    return choices
|
| |
|
| |
|
def list_lora_llama_models():
    """Return LoRA checkpoint files under results/lora*/, newest-looking first."""
    found = [str(p) for p in Path("results").glob("lora*/**/*.ckpt")]
    found.sort(reverse=True)
    if not found:
        logger.warning("No LoRA LLaMA model found")
    return found
|
| |
|
| |
|
def fresh_decoder_model():
    """Refresh the decoder-model dropdown from disk."""
    current_choices = list_decoder_models()
    return gr.Dropdown(choices=current_choices)
|
| |
|
| |
|
def fresh_llama_ckpt(llama_use_lora):
    """Refresh the ckpt dropdown: 'latest'/'new' plus matching results/ projects
    (LoRA projects when *llama_use_lora*, plain text2semantic otherwise)."""
    pattern = "lora_*/" if llama_use_lora else "text2sem*/"
    existing = [str(p) for p in Path("results").glob(pattern)]
    return gr.Dropdown(choices=[i18n("latest"), i18n("new")] + existing)
|
| |
|
| |
|
def fresh_llama_model():
    """Refresh the LLAMA-model dropdown from disk."""
    current_choices = list_llama_models()
    return gr.Dropdown(choices=current_choices)
|
| |
|
| |
|
def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
    """Merge a LoRA checkpoint into a base LLAMA model via
    tools/llama/merge_lora.py, writing to a timestamped output folder.

    Returns an HTML status message.
    """
    if (
        lora_weight is None
        or not Path(lora_weight).exists()
        or not Path(llama_weight).exists()
    ):
        return build_html_error_message(
            i18n(
                "Path error, please check the model file exists in the corresponding path"
            )
        )
    gr.Warning("Merging begins...")
    # NOTE(review): llama_weight and lora_llama_config are validated above but
    # never passed on the command line (the LoRA config is hard-coded to
    # "r_8_alpha_16") — confirm merge_lora.py's defaults cover them.
    merge_cmd = [
        PYTHON,
        "tools/llama/merge_lora.py",
        "--lora-config",
        "r_8_alpha_16",
        "--lora-weight",
        lora_weight,
        "--output",
        llama_lora_output + "_" + generate_folder_name(),
    ]
    logger.info(merge_cmd)
    subprocess.run(merge_cmd)
    return build_html_ok_message(i18n("Merge successfully"))
|
| |
|
| |
|
def llama_quantify(llama_weight, quantify_mode):
    """Quantize a LLAMA checkpoint via tools/llama/quantize.py.

    *quantify_mode* is "int8" or "int4"; the resulting checkpoint folder name
    embeds the mode and a timestamp.  Returns an HTML status message.
    """
    if llama_weight is None or not Path(llama_weight).exists():
        return build_html_error_message(
            i18n(
                "Path error, please check the model file exists in the corresponding path"
            )
        )

    gr.Warning("Quantifying begins...")

    now = generate_folder_name()
    quantify_cmd = [
        PYTHON,
        "tools/llama/quantize.py",
        "--checkpoint-path",
        llama_weight,
        "--mode",
        quantify_mode,
        "--timestamp",
        now,
    ]
    logger.info(quantify_cmd)
    subprocess.run(quantify_cmd)

    # int4 output folders additionally carry the group size (g128).
    if quantify_mode == "int8":
        folder = f"fs-1.2-{quantify_mode}-{now}"
    else:
        folder = f"fs-1.2-{quantify_mode}-g128-{now}"
    quantize_path = str(Path(os.getcwd()) / "checkpoints" / folder)

    return build_html_ok_message(
        i18n("Quantify successfully") + f"Path: {quantize_path}"
    )
|
| |
|
| |
|
# Defaults shown in the training tab come straight from the shipped configs.
init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
|
| |
|
| | with gr.Blocks(
|
| | head="<style>\n" + css + "\n</style>",
|
| | js=js,
|
| | theme=seafoam,
|
| | analytics_enabled=False,
|
| | title="Fish Speech",
|
| | ) as demo:
|
| | with gr.Row():
|
| | with gr.Column():
|
| | with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
|
| | with gr.Row():
|
| | textbox = gr.Textbox(
|
| | label="\U0000270F "
|
| | + i18n("Input Audio & Source Path for Transcription"),
|
| | info=i18n("Speaker is identified by the folder name"),
|
| | interactive=True,
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | with gr.Column():
|
| | output_radio = gr.Radio(
|
| | label="\U0001F4C1 "
|
| | + i18n("Select source file processing method"),
|
| | choices=[i18n("Copy"), i18n("Move")],
|
| | value=i18n("Copy"),
|
| | interactive=True,
|
| | )
|
| | with gr.Column():
|
| | error = gr.HTML(label=i18n("Error Message"))
|
| | if_label = gr.Checkbox(
|
| | label=i18n("Open Labeler WebUI"), scale=0, show_label=True
|
| | )
|
| |
|
| | with gr.Row():
|
| | label_device = gr.Dropdown(
|
| | label=i18n("Labeling Device"),
|
| | info=i18n(
|
| | "It is recommended to use CUDA, if you have low configuration, use CPU"
|
| | ),
|
| | choices=["cpu", "cuda"],
|
| | value="cuda",
|
| | interactive=True,
|
| | )
|
| | label_model = gr.Dropdown(
|
| | label=i18n("Whisper Model"),
|
| | info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
|
| | choices=["large-v3", "medium"],
|
| | value="large-v3",
|
| | interactive=True,
|
| | )
|
| | label_radio = gr.Dropdown(
|
| | label=i18n("Optional Label Language"),
|
| | info=i18n(
|
| | "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
|
| | ),
|
| | choices=[
|
| | (i18n("Chinese"), "zh"),
|
| | (i18n("English"), "en"),
|
| | (i18n("Japanese"), "ja"),
|
| | (i18n("Disabled"), "IGNORE"),
|
| | (i18n("auto"), "auto"),
|
| | ],
|
| | value="IGNORE",
|
| | interactive=True,
|
| | )
|
| |
|
| | with gr.Row():
|
| | if_initial_prompt = gr.Checkbox(
|
| | value=False,
|
| | label=i18n("Enable Initial Prompt"),
|
| | min_width=120,
|
| | scale=0,
|
| | )
|
| | initial_prompt = gr.Textbox(
|
| | label=i18n("Initial Prompt"),
|
| | info=i18n(
|
| | "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
|
| | ),
|
| | placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
|
| | interactive=False,
|
| | )
|
| |
|
| | with gr.Row():
|
| | add_button = gr.Button(
|
| | "\U000027A1 " + i18n("Add to Processing Area"),
|
| | variant="primary",
|
| | )
|
| | remove_button = gr.Button(
|
| | "\U000026D4 " + i18n("Remove Selected Data")
|
| | )
|
| |
|
| | with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
|
| | with gr.Row():
|
| | model_type_radio = gr.Radio(
|
| | label=i18n(
|
| | "Select the model to be trained (Depending on the Tab page you are on)"
|
| | ),
|
| | interactive=False,
|
| | choices=["VQGAN", "LLAMA"],
|
| | value="VQGAN",
|
| | )
|
| | with gr.Row():
|
| | with gr.Tabs():
|
| | with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
|
| | gr.HTML("You don't need to train this model!")
|
| |
|
| | with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
|
| | with gr.Row(equal_height=False):
|
| | llama_use_lora = gr.Checkbox(
|
| | label=i18n("Use LoRA"),
|
| | info=i18n(
|
| | "Use LoRA can save GPU memory, but may reduce the quality of the model"
|
| | ),
|
| | value=True,
|
| | interactive=True,
|
| | )
|
| | llama_ckpt = gr.Dropdown(
|
| | label=i18n("Select LLAMA ckpt"),
|
| | choices=[i18n("latest"), i18n("new")]
|
| | + [
|
| | str(p)
|
| | for p in Path("results").glob("text2sem*/")
|
| | ]
|
| | + [str(p) for p in Path("results").glob("lora*/")],
|
| | value=i18n("latest"),
|
| | interactive=True,
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | llama_lr_slider = gr.Slider(
|
| | label=i18n("Initial Learning Rate"),
|
| | info=i18n(
|
| | "lr smaller -> usually train slower but more stable"
|
| | ),
|
| | interactive=True,
|
| | minimum=1e-5,
|
| | maximum=1e-4,
|
| | step=1e-5,
|
| | value=5e-5,
|
| | )
|
| | llama_maxsteps_slider = gr.Slider(
|
| | label=i18n("Maximum Training Steps"),
|
| | info=i18n(
|
| | "recommend: max_steps = num_audios // batch_size * (2 to 5)"
|
| | ),
|
| | interactive=True,
|
| | minimum=1,
|
| | maximum=10000,
|
| | step=1,
|
| | value=50,
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | llama_base_config = gr.Dropdown(
|
| | label=i18n("Model Size"),
|
| | choices=[
|
| | "text2semantic_finetune",
|
| | ],
|
| | value="text2semantic_finetune",
|
| | )
|
| | llama_data_num_workers_slider = gr.Slider(
|
| | label=i18n("Number of Workers"),
|
| | minimum=1,
|
| | maximum=32,
|
| | step=1,
|
| | value=4,
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | llama_data_batch_size_slider = gr.Slider(
|
| | label=i18n("Batch Size"),
|
| | interactive=True,
|
| | minimum=1,
|
| | maximum=32,
|
| | step=1,
|
| | value=2,
|
| | )
|
| | llama_data_max_length_slider = gr.Slider(
|
| | label=i18n("Maximum Length per Sample"),
|
| | interactive=True,
|
| | minimum=1024,
|
| | maximum=4096,
|
| | step=128,
|
| | value=2048,
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | llama_precision_dropdown = gr.Dropdown(
|
| | label=i18n("Precision"),
|
| | info=i18n(
|
| | "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
|
| | ),
|
| | interactive=True,
|
| | choices=["32", "bf16-true", "16-mixed"],
|
| | value="bf16-true",
|
| | )
|
| | llama_check_interval_slider = gr.Slider(
|
| | label=i18n("Save model every n steps"),
|
| | info=i18n(
|
| | "make sure that it's not greater than max_steps"
|
| | ),
|
| | interactive=True,
|
| | minimum=1,
|
| | maximum=1000,
|
| | step=1,
|
| | value=50,
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | llama_grad_batches = gr.Slider(
|
| | label=i18n("Accumulate Gradient Batches"),
|
| | interactive=True,
|
| | minimum=1,
|
| | maximum=20,
|
| | step=1,
|
| | value=init_llama_yml["trainer"][
|
| | "accumulate_grad_batches"
|
| | ],
|
| | )
|
| | llama_use_speaker = gr.Slider(
|
| | label=i18n(
|
| | "Probability of applying Speaker Condition"
|
| | ),
|
| | interactive=True,
|
| | minimum=0.1,
|
| | maximum=1.0,
|
| | step=0.05,
|
| | value=init_llama_yml["train_dataset"][
|
| | "interactive_prob"
|
| | ],
|
| | )
|
| |
|
| | with gr.Tab(label=i18n("Merge LoRA"), id=4):
|
| | with gr.Row(equal_height=False):
|
| | llama_weight = gr.Dropdown(
|
| | label=i18n("Base LLAMA Model"),
|
| | info=i18n(
|
| | "Type the path or select from the dropdown"
|
| | ),
|
| | choices=[
|
| | "checkpoints/fish-speech-1.4/model.pth",
|
| | ],
|
| | value="checkpoints/fish-speech-1.4/model.pth",
|
| | allow_custom_value=True,
|
| | interactive=True,
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | lora_weight = gr.Dropdown(
|
| | label=i18n("LoRA Model to be merged"),
|
| | info=i18n(
|
| | "Type the path or select from the dropdown"
|
| | ),
|
| | choices=[
|
| | str(p)
|
| | for p in Path("results").glob("lora*/**/*.ckpt")
|
| | ],
|
| | allow_custom_value=True,
|
| | interactive=True,
|
| | )
|
| | lora_llama_config = gr.Dropdown(
|
| | label=i18n("LLAMA Model Config"),
|
| | info=i18n(
|
| | "Type the path or select from the dropdown"
|
| | ),
|
| | choices=[
|
| | "text2semantic_finetune",
|
| | ],
|
| | value="text2semantic_finetune",
|
| | allow_custom_value=True,
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | llama_lora_output = gr.Dropdown(
|
| | label=i18n("Output Path"),
|
| | info=i18n(
|
| | "Type the path or select from the dropdown"
|
| | ),
|
| | value="checkpoints/merged",
|
| | choices=["checkpoints/merged"],
|
| | allow_custom_value=True,
|
| | interactive=True,
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | llama_lora_merge_btn = gr.Button(
|
| | value=i18n("Merge"), variant="primary"
|
| | )
|
| |
|
| | with gr.Tab(label=i18n("Model Quantization"), id=5):
|
| | with gr.Row(equal_height=False):
|
| | llama_weight_to_quantify = gr.Dropdown(
|
| | label=i18n("Base LLAMA Model"),
|
| | info=i18n(
|
| | "Type the path or select from the dropdown"
|
| | ),
|
| | choices=list_llama_models(),
|
| | value="checkpoints/fish-speech-1.4",
|
| | allow_custom_value=True,
|
| | interactive=True,
|
| | )
|
| | quantify_mode = gr.Dropdown(
|
| | label=i18n("Post-quantification Precision"),
|
| | info=i18n(
|
| | "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
|
| | ),
|
| | choices=["int8", "int4"],
|
| | value="int8",
|
| | allow_custom_value=False,
|
| | interactive=True,
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | llama_quantify_btn = gr.Button(
|
| | value=i18n("Quantify"), variant="primary"
|
| | )
|
| |
|
| | with gr.Tab(label="Tensorboard", id=6):
|
| | with gr.Row(equal_height=False):
|
| | tb_host = gr.Textbox(
|
| | label=i18n("Tensorboard Host"), value="127.0.0.1"
|
| | )
|
| | tb_port = gr.Textbox(
|
| | label=i18n("Tensorboard Port"), value="11451"
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | tb_dir = gr.Dropdown(
|
| | label=i18n("Tensorboard Log Path"),
|
| | allow_custom_value=True,
|
| | choices=[
|
| | str(p)
|
| | for p in Path("results").glob("**/tensorboard/")
|
| | ],
|
| | )
|
| | with gr.Row(equal_height=False):
|
| | if_tb = gr.Checkbox(
|
| | label=i18n("Open Tensorboard"),
|
| | )
|
| |
|
| | with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")):
|
| | with gr.Column():
|
| | with gr.Row():
|
| | with gr.Accordion(
|
| | label="\U0001F5A5 "
|
| | + i18n("Inference Server Configuration"),
|
| | open=False,
|
| | ):
|
| | with gr.Row():
|
| | infer_host_textbox = gr.Textbox(
|
| | label=i18n("WebUI Host"), value="127.0.0.1"
|
| | )
|
| | infer_port_textbox = gr.Textbox(
|
| | label=i18n("WebUI Port"), value="7862"
|
| | )
|
| | with gr.Row():
|
| | infer_decoder_model = gr.Dropdown(
|
| | label=i18n("Decoder Model Path"),
|
| | info=i18n(
|
| | "Type the path or select from the dropdown"
|
| | ),
|
| | choices=list_decoder_models(),
|
| | value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
| | allow_custom_value=True,
|
| | )
|
| | infer_decoder_config = gr.Dropdown(
|
| | label=i18n("Decoder Model Config"),
|
| | info=i18n("Changing with the Model Path"),
|
| | value="firefly_gan_vq",
|
| | choices=[
|
| | "firefly_gan_vq",
|
| | ],
|
| | allow_custom_value=True,
|
| | )
|
| | with gr.Row():
|
| | infer_llama_model = gr.Dropdown(
|
| | label=i18n("LLAMA Model Path"),
|
| | info=i18n(
|
| | "Type the path or select from the dropdown"
|
| | ),
|
| | value="checkpoints/fish-speech-1.4",
|
| | choices=list_llama_models(),
|
| | allow_custom_value=True,
|
| | )
|
| |
|
| | with gr.Row():
|
| | infer_compile = gr.Radio(
|
| | label=i18n("Compile Model"),
|
| | info=i18n(
|
| | "Compile the model can significantly reduce the inference time, but will increase cold start time"
|
| | ),
|
| | choices=["Yes", "No"],
|
| | value=(
|
| | "Yes" if (sys.platform == "linux") else "No"
|
| | ),
|
| | interactive=is_module_installed("triton"),
|
| | )
|
| |
|
| | with gr.Row():
|
| | infer_checkbox = gr.Checkbox(
|
| | label=i18n("Open Inference Server")
|
| | )
|
| | infer_error = gr.HTML(label=i18n("Inference Server Error"))
|
| |
|
with gr.Column():
    # Training status/error output, reused by several handlers below.
    train_error = gr.HTML(label=i18n("Training Error"))
    # Selected data sources; populated/updated by the add_item /
    # remove_items handlers wired up further below.
    checkbox_group = gr.CheckboxGroup(
        label="\U0001F4CA " + i18n("Data Source"),
        info=i18n(
            "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list."
        ),
        elem_classes=["data_src"],
    )
    # Read-only displays of the fixed preprocessing and model-output paths.
    train_box = gr.Textbox(
        label=i18n("Data Preprocessing Path"),
        value=str(data_pre_output),
        interactive=False,
    )
    model_box = gr.Textbox(
        label="\U0001F4BE " + i18n("Model Output Path"),
        value=str(default_model_output),
        interactive=False,
    )

    with gr.Accordion(
        i18n(
            "View the status of the preprocessing folder (use the slider to control the depth of the tree)"
        ),
        elem_classes=["scrollable-component"],
        elem_id="file_accordion",
    ):
        # Tree depth 0-3 for the folder preview rendered by new_explorer.
        tree_slider = gr.Slider(
            minimum=0,
            maximum=3,
            value=0,
            step=1,
            show_label=False,
            container=False,
        )
        file_markdown = new_explorer(str(data_pre_output), 0)
    with gr.Row(equal_height=False):
        admit_btn = gr.Button(
            "\U00002705 " + i18n("File Preprocessing"),
            variant="primary",
        )
        # Refresh-tree and help buttons (icon-only, fixed width).
        fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
        help_button = gr.Button("\U00002753", scale=0, min_width=80)
        train_btn = gr.Button(i18n("Start Training"), variant="primary")
|
| |
|
# Render the shared footer fragment, substituting the version banner and
# the HTTP-API documentation link into its format placeholders.
footer = load_data_in_raw("fish_speech/webui/html/footer.html").format(
    versions=versions_html(),
    api_docs="https://speech.fish.audio/inference/#http-api",
)
gr.HTML(footer, elem_id="footer")
|
# ---- Event wiring -----------------------------------------------------
# Selecting a training tab records which model type the run targets.
vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
llama_page.select(lambda: "LLAMA", None, model_type_radio)
# Dataset-list management: add/remove entries and echo the selection.
add_button.click(
    fn=add_item,
    inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
    outputs=[checkbox_group, error],
)
remove_button.click(
    fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
)
checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
# Client-side only (fn=None): open the documentation site in a popup.
help_button.click(
    fn=None,
    js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
    'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
)
if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
# Toggling the initial-prompt switch clears and (un)locks the prompt box.
if_initial_prompt.change(
    fn=lambda x: gr.Textbox(value="", interactive=x),
    inputs=[if_initial_prompt],
    outputs=[initial_prompt],
)
|
# Kick off training with the full set of LLAMA hyper-parameters; any
# error text from train_process is rendered into train_error.
train_btn.click(
    fn=train_process,
    inputs=[
        train_box,
        model_type_radio,

        # LLAMA-specific training options follow.
        llama_ckpt,
        llama_base_config,
        llama_lr_slider,
        llama_maxsteps_slider,
        llama_data_num_workers_slider,
        llama_data_batch_size_slider,
        llama_data_max_length_slider,
        llama_precision_dropdown,
        llama_check_interval_slider,
        llama_grad_batches,
        llama_use_speaker,
        llama_use_lora,
    ],
    outputs=[train_error],
)
|
# Tensorboard start/stop and log-dir refresh.
if_tb.change(
    fn=tensorboard_process,
    inputs=[if_tb, tb_dir, tb_host, tb_port],
    outputs=[train_error],
)
tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
# Re-scan checkpoint folders whenever a model dropdown is touched.
infer_decoder_model.change(
    fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
)
infer_llama_model.change(
    fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
)
llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
# File preprocessing: check/label the files, then refresh the tree view.
admit_btn.click(
    fn=check_files,
    inputs=[train_box, tree_slider, label_model, label_device],
    outputs=[error, file_markdown],
)
fresh_btn.click(
    fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
)
|
# LoRA bookkeeping: refresh checkpoint choices when the LoRA switch or
# the checkpoint selection itself changes.
llama_use_lora.change(
    fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
)
# NOTE(review): this handler writes back into its own trigger component;
# looks intentional (re-listing choices) but verify it cannot re-fire.
llama_ckpt.change(
    fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
)
lora_weight.change(
    fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
    inputs=[],
    outputs=[lora_weight],
)
# Merge a LoRA weight into the base model / quantize a checkpoint; both
# report status through train_error.
llama_lora_merge_btn.click(
    fn=llama_lora_merge,
    inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
    outputs=[train_error],
)
llama_quantify_btn.click(
    fn=llama_quantify,
    inputs=[llama_weight_to_quantify, quantify_mode],
    outputs=[train_error],
)
|
# Start/stop the local inference server; change_infer receives the full
# server configuration and reports failures as HTML into infer_error.
infer_checkbox.change(
    fn=change_infer,
    inputs=[
        infer_checkbox,
        infer_host_textbox,
        infer_port_textbox,
        infer_decoder_model,
        infer_decoder_config,
        infer_llama_model,
        infer_compile,
    ],
    outputs=[infer_error],
)
|
| |
|
# Launch the Gradio app, opening the default browser automatically.
demo.launch(inbrowser=True)
|
| |
|