| import os
|
|
|
| import gradio as gr
|
|
|
| from shared.utils import files_locator as fl
|
|
|
| from .prompt_enhancers import TTS_MONOLOGUE_PROMPT
|
|
|
|
|
| _FALLBACK_SUPPORTED_LANGUAGES = {
|
| "ar": "Arabic",
|
| "da": "Danish",
|
| "de": "German",
|
| "el": "Greek",
|
| "en": "English",
|
| "es": "Spanish",
|
| "fi": "Finnish",
|
| "fr": "French",
|
| "he": "Hebrew",
|
| "hi": "Hindi",
|
| "it": "Italian",
|
| "ja": "Japanese",
|
| "ko": "Korean",
|
| "ms": "Malay",
|
| "nl": "Dutch",
|
| "no": "Norwegian",
|
| "pl": "Polish",
|
| "pt": "Portuguese",
|
| "ru": "Russian",
|
| "sv": "Swedish",
|
| "sw": "Swahili",
|
| "tr": "Turkish",
|
| "zh": "Chinese",
|
| }
|
|
|
def _get_supported_languages() -> dict:
    """Return the mapping of language code -> language name for Chatterbox.

    The authoritative table ``SUPPORTED_LANGUAGES`` from
    ``chatterbox.mtl_tts`` is preferred. If importing that module raises
    any exception, the static ``_FALLBACK_SUPPORTED_LANGUAGES`` table is
    returned instead.
    """
    try:
        from .chatterbox.mtl_tts import SUPPORTED_LANGUAGES as languages
    except Exception:
        # Deliberately broad: this only feeds UI metadata (the language
        # choices), so an import failure must degrade, not crash.
        languages = _FALLBACK_SUPPORTED_LANGUAGES
    return languages
|
|
|
|
|
def _get_language_choices() -> list[tuple[str, str]]:
    """Build ``(label, value)`` pairs for the language selector.

    Entries are ordered alphabetically by language name (a stable sort, so
    identical names keep the table's order). Each label looks like
    ``"English (en)"`` and each value is the bare language code.
    """
    supported = _get_supported_languages()
    by_name = sorted(supported.items(), key=lambda pair: pair[1])
    return [(f"{lang_name} ({lang_code})", lang_code) for lang_code, lang_name in by_name]
|
|
|
# Settings-version threshold associated with moving the legacy top-level
# ``exaggeration`` / ``pace`` values into the ``custom_settings`` dict.
CHATTERBOX_CUSTOM_SETTINGS_MIGRATION_VERSION = 2.50

# Default value of each Chatterbox-specific setting, keyed by setting id.
CHATTERBOX_DEFAULT_CUSTOM_SETTINGS = dict(exaggeration=0.5, pace=0.5)

# Descriptors for the Chatterbox-specific float settings. Each default is
# read from CHATTERBOX_DEFAULT_CUSTOM_SETTINGS, so the two tables cannot
# disagree. Rows are (id, short name, full label).
CHATTERBOX_CUSTOM_SETTINGS = [
    {
        "id": setting_id,
        "label": setting_label,
        "name": display_name,
        "type": "float",
        "default": CHATTERBOX_DEFAULT_CUSTOM_SETTINGS[setting_id],
    }
    for setting_id, display_name, setting_label in (
        ("exaggeration", "Exaggeration", "Emotion Exaggeration (0.25-2.0, 0.5 = neutral)"),
        ("pace", "Pace", "Pace (0.2-1.0)"),
    )
]
|
|
|
|
|
def _get_chatterbox_model_def():
    """Assemble the model-definition dict describing the Chatterbox model.

    A new dict is built on every call. The language selector's choices come
    from ``_get_language_choices`` (default ``"en"``). Every entry of
    ``CHATTERBOX_CUSTOM_SETTINGS`` is shallow-copied, so mutating the
    returned descriptors does not alter the module-level table.
    """
    language_mode = {
        "choices": _get_language_choices(),
        "default": "en",
        "label": "Language",
    }
    settings_copies = [dict(descriptor) for descriptor in CHATTERBOX_CUSTOM_SETTINGS]

    return {
        "audio_only": True,
        "image_outputs": False,
        "sliding_window": False,
        "guidance_max_phases": 0,
        "no_negative_prompt": True,
        "inference_steps": False,
        "temperature": True,
        "image_prompt_types_allowed": "",
        "profiles_dir": ["chatterbox"],
        "audio_guide_label": "Voice to Replicate",
        "model_modes": language_mode,
        "any_audio_prompt": True,
        "custom_settings": settings_copies,
        "text_prompt_enhancer_instructions": TTS_MONOLOGUE_PROMPT,
        "prompt_enhancer_button_label": "Write Speech",
    }
|
|
|
|
|
| def _get_chatterbox_download_def():
|
| mandatory_files = [
|
| "ve.safetensors",
|
| "t3_mtl23ls_v2.safetensors",
|
| "s3gen.pt",
|
| "grapheme_mtl_merged_expanded_v1.json",
|
| "conds.pt",
|
| "Cangjie5_TC.json",
|
| ]
|
| return {
|
| "repoId": "ResembleAI/chatterbox",
|
| "sourceFolderList": [""],
|
| "targetFolderList": ["chatterbox"],
|
| "fileList": [mandatory_files],
|
| }
|
|
|
|
|
class family_handler:
    """Model-family hooks for the Chatterbox text-to-speech model.

    Every hook is a ``@staticmethod``; nothing here instantiates the class.
    NOTE(review): which hooks the host calls, and when, is defined outside
    this file — confirm there before changing signatures.
    """

    @staticmethod
    def query_supported_types():
        """Return the list of base model types handled by this family."""
        return ["chatterbox"]

    @staticmethod
    def query_family_maps():
        """Return a pair of empty dicts: this family contributes no family mappings."""
        return {}, {}

    @staticmethod
    def query_model_family():
        """Return this family's identifier, ``"tts"``."""
        return "tts"

    @staticmethod
    def query_family_infos():
        """Return display info for the ``"tts"`` family.

        NOTE(review): ``200`` looks like an ordering/priority value and
        ``"TTS"`` a display name — confirm against the hook's consumer.
        """
        return {"tts": (200, "TTS")}

    @staticmethod
    def register_lora_cli_args(parser, lora_root):
        """Add the ``--lora-dir-chatterbox`` option to an argparse-style *parser*.

        The option defaults to ``None``. In that case ``get_lora_dir`` falls
        back to ``<lora_root>/chatterbox``, which is the path the help text
        advertises.
        """
        parser.add_argument(
            "--lora-dir-chatterbox",
            type=str,
            default=None,
            help=f"Path to a directory that contains chatterbox settings (default: {os.path.join(lora_root, 'chatterbox')})",
        )

    @staticmethod
    def get_lora_dir(base_model_type, args, lora_root):
        """Return the Chatterbox settings directory.

        Uses ``args.lora_dir_chatterbox`` when that attribute exists and is
        truthy; otherwise returns ``os.path.join(lora_root, "chatterbox")``.
        ``base_model_type`` is unused.
        """
        return getattr(args, "lora_dir_chatterbox", None) or os.path.join(lora_root, "chatterbox")

    @staticmethod
    def query_model_def(base_model_type, model_def):
        """Return a freshly built Chatterbox model definition (arguments are unused)."""
        return _get_chatterbox_model_def()

    @staticmethod
    def query_model_files(computeList, base_model_type, model_def=None):
        """Return the download spec (repo id, folders, file list) for Chatterbox."""
        return _get_chatterbox_download_def()

    @staticmethod
    def load_model(
        model_filename,
        model_type,
        base_model_type,
        model_def,
        quantizeTransformer=False,
        text_encoder_quantization=None,
        dtype=None,
        VAE_dtype=None,
        mixed_precision_transformer=False,
        save_quantized=False,
        submodel_no_list=None,
        text_encoder_filename=None,
        profile=0,
        **kwargs,
    ):
        """Build the Chatterbox pipeline on CPU from the shared download location.

        Most parameters belong to the generic loader interface and are not
        used here.

        Returns:
            ``(pipeline, pipe)``, where *pipe* maps the names ``"ve"``,
            ``"s3gen"``, ``"t3"`` and ``"conds"`` to the corresponding
            attributes of ``pipeline.model``. Presumably the host uses this
            to manage the components individually (NOTE(review): confirm).
        """
        # Imported lazily so the heavy pipeline module only loads when a
        # model is actually requested.
        from .chatterbox.pipeline import ChatterboxPipeline

        ckpt_root = fl.get_download_location()
        pipeline = ChatterboxPipeline(ckpt_root=ckpt_root, device="cpu")
        pipe = {
            "ve": pipeline.model.ve,
            "s3gen": pipeline.model.s3gen,
            "t3": pipeline.model.t3,
            "conds": pipeline.model.conds,
        }
        return pipeline, pipe

    @staticmethod
    def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
        """Upgrade a saved Chatterbox settings dict in place.

        Args:
            base_model_type: Unused.
            settings_version: Version the settings were saved with; compared
                numerically against version thresholds.
            model_def: Unused.
            ui_defaults: Settings dict; mutated in place.

        Effects on *ui_defaults*:
            * ``alt_prompt``, ``audio_prompt_type`` and ``model_mode`` are
              filled in when missing.
            * ``guidance_scale`` is forced to 1.0 for settings older than 2.44.
            * Legacy top-level ``exaggeration`` / ``pace`` keys are always
              removed. Their non-``None`` values are moved into
              ``custom_settings`` unless that dict already has the key.
            * ``custom_settings`` is replaced by a copy (a non-dict value is
              discarded), with missing Chatterbox defaults filled in.
        """
        ui_defaults.setdefault("alt_prompt", "")
        ui_defaults.setdefault("audio_prompt_type", "A")
        ui_defaults.setdefault("model_mode", "en")

        if settings_version < 2.44:
            ui_defaults["guidance_scale"] = 1.0

        # Always pop the legacy keys so they never linger at the top level.
        legacy_values = {
            "exaggeration": ui_defaults.pop("exaggeration", None),
            "pace": ui_defaults.pop("pace", None),
        }

        # Work on a copy so the caller's nested dict is not mutated.
        custom_settings = ui_defaults.get("custom_settings", None)
        custom_settings = custom_settings.copy() if isinstance(custom_settings, dict) else {}

        # Legacy values are merged for every settings_version, not only those
        # older than CHATTERBOX_CUSTOM_SETTINGS_MIGRATION_VERSION. A single
        # pass suffices, and an explicit custom_settings entry always takes
        # precedence over a legacy value.
        for key, value in legacy_values.items():
            if value is not None:
                custom_settings.setdefault(key, value)

        for key, value in CHATTERBOX_DEFAULT_CUSTOM_SETTINGS.items():
            custom_settings.setdefault(key, value)
        ui_defaults["custom_settings"] = custom_settings

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        """Overwrite *ui_defaults* in place with the Chatterbox default settings."""
        ui_defaults.update(
            {
                "audio_prompt_type": "A",
                "model_mode": "en",
                "repeat_generation": 1,
                "video_length": 0,
                "num_inference_steps": 0,
                "negative_prompt": "",
                # Copy so later edits do not leak into the module-level defaults.
                "custom_settings": dict(CHATTERBOX_DEFAULT_CUSTOM_SETTINGS),
                "temperature": 0.8,
                "guidance_scale": 1.0,
                "multi_prompts_gen_type": 2,
            }
        )

    @staticmethod
    def validate_generative_prompt(base_model_type, model_def, inputs, one_prompt):
        """Warn (via ``gr.Info``) when *one_prompt* exceeds 300 characters.

        The prompt is never rejected; the method always returns ``None``.
        """
        if len(one_prompt) > 300:
            gr.Info(
                "It is recommended to use a prompt that has less than 300 characters,"
                " otherwise you may get unexpected results."
            )
|
|
|