# Source: ColabWan/models/TTS/kugelaudio_handler.py -- uploaded by 1ripon1 with huggingface_hub (commit 7344bef, verified, 11.5 kB).
import os
import re
import torch
from shared.mps import mps_device_or
from shared.utils import files_locator as fl
from .prompt_enhancers import TTS_MONOLOGUE_PROMPT, TTS_QWEN3_DIALOGUE_PROMPT
# --- Where the KugelAudio assets live ---------------------------------------
# Repository identifier and folder / file names used when locating or
# downloading model assets.
KUGELAUDIO_REPO_ID = "DeepBeepMeep/TTS"
KUGELAUDIO_ASSET_DIR = "kugelaudio"
KUGELAUDIO_TOKENIZER_DIR = "kugelaudio_text_tokenizer"
KUGELAUDIO_CONFIG_NAME = "config.json"
KUGELAUDIO_GENERATION_CONFIG_NAME = "generation_config.json"

# Files requested from the tokenizer folder (see the download definition).
KUGELAUDIO_TOKENIZER_FILES = [
    "merges.txt",
    "tokenizer.json",
    "tokenizer_config.json",
    "vocab.json",
    "preprocessor_config.json",
]

# --- UI: maximum-duration slider ---------------------------------------------
KUGELAUDIO_DURATION_SLIDER = {
    "label": "Max duration (seconds)",
    "min": 1,
    "max": 600,
    "increment": 1,
    "default": 20,
}

# --- UI: optional "auto split" custom setting --------------------------------
# The setting is stored under this key inside the ``custom_settings`` dict and
# must lie within [MIN, MAX] seconds when provided.
KUGELAUDIO_AUTO_SPLIT_SETTING_ID = "auto_split_every_s"
KUGELAUDIO_AUTO_SPLIT_MIN_SECONDS = 5.0
KUGELAUDIO_AUTO_SPLIT_MAX_SECONDS = 90.0

KUGELAUDIO_CUSTOM_SETTINGS = [
    {
        "id": KUGELAUDIO_AUTO_SPLIT_SETTING_ID,
        # Adjacent literals are concatenated by the parser into one string.
        "label": (
            "Auto Split Every s (5-90, optional), to avoid Acceleration Effect. "
            "Empty Lines will force anyway Manual Splits."
        ),
        "name": "Auto Split Every s",
        "type": "float",
    },
]
def _configure_diffusion_compile_targets(model):
    """Flag only the diffusion prediction head's layers with ``_compile_me``.

    Every submodule yielded by ``model.named_modules()`` first gets
    ``_compile_me = False``. The prediction head is then looked up on
    ``model`` itself or, failing that, on ``model.model``; each entry of its
    ``layers`` attribute and its ``final_layer`` attribute (when present) get
    ``_compile_me = True``.

    Raises:
        RuntimeError: if no ``prediction_head`` can be found in either place.
    """
    # Start from a clean slate: nothing is a compile target by default.
    for _name, module in model.named_modules():
        module._compile_me = False

    head = getattr(model, "prediction_head", None)
    if head is None:
        # Fall back to a wrapped inner model; getattr on None simply yields None.
        inner_model = getattr(model, "model", None)
        head = getattr(inner_model, "prediction_head", None)
        if head is None:
            raise RuntimeError("KugelAudio diffusion head is missing; cannot configure compile targets.")

    # Gather the targets (layers first, then the final layer), then flag them.
    head_layers = getattr(head, "layers", None)
    targets = [] if head_layers is None else list(head_layers)
    last_layer = getattr(head, "final_layer", None)
    if last_layer is not None:
        targets.append(last_layer)
    for target in targets:
        target._compile_me = True
def _get_kugelaudio_model_def():
    """Return a newly built model-definition dictionary for KugelAudio.

    A fresh dictionary is assembled on every call. The duration slider and the
    custom-settings entries are shallow copies of the module-level templates,
    so callers may edit the result without touching those templates.
    """
    # Choices for how reference voices are supplied.
    voice_source_options = {
        "selection": ["", "A", "AB"],
        "labels": {
            "": "Text only",
            "A": "Voice cloning (1 reference audio)",
            "AB": "Voice cloning (2 reference audios: Speaker 1 and Speaker 2)",
        },
        "letters_filter": "AB",
        "default": "",
    }

    # Choices offered by the prompt enhancer.
    enhancer_options = {
        "selection": ["T", "T1"],
        "labels": {
            "T": "A Speech based on current Prompt",
            "T1": "A Dialogue between two People based on current Prompt",
        },
        "default": "T",
    }

    slider_copy = dict(KUGELAUDIO_DURATION_SLIDER)
    settings_copy = [dict(entry) for entry in KUGELAUDIO_CUSTOM_SETTINGS]

    model_def = {
        "audio_only": True,
        "image_outputs": False,
        "sliding_window": False,
        "guidance_max_phases": 1,
        "no_negative_prompt": True,
        "inference_steps": False,
        "temperature": True,
        "image_prompt_types_allowed": "",
        "supports_early_stop": True,
        "profiles_dir": ["kugelaudio_0_open"],
        "lm_engines": ["cg"],
        "duration_slider": slider_copy,
        "custom_settings": settings_copy,
        "preserve_empty_prompt_lines": True,
        "pause_between_sentences": True,
        "any_audio_prompt": True,
        "audio_guide_label": "Reference voice (optional)",
        "audio_prompt_choices": True,
        "audio_prompt_type_sources": voice_source_options,
        "text_prompt_enhancer_instructions": TTS_MONOLOGUE_PROMPT,
        "text_prompt_enhancer_instructions1": TTS_QWEN3_DIALOGUE_PROMPT,
        "text_prompt_enhancer_max_tokens": 512,
        "text_prompt_enhancer_max_tokens1": 512,
        "prompt_enhancer_def": enhancer_options,
        "prompt_enhancer_button_label": "Write",
        "compile": ["transformer"],
    }
    return model_def
def _get_kugelaudio_download_def():
    """Return the list of download descriptors for KugelAudio.

    There is a single descriptor, covering the text-tokenizer folder of
    ``KUGELAUDIO_REPO_ID``.
    """
    tokenizer_download = {
        "repoId": KUGELAUDIO_REPO_ID,
        "sourceFolderList": [KUGELAUDIO_TOKENIZER_DIR],
        # One inner list for the one source folder above -- presumably the
        # downloader pairs folders with file lists by position; confirm there.
        # The module-level list object is passed through as-is (not copied).
        "fileList": [KUGELAUDIO_TOKENIZER_FILES],
    }
    return [tokenizer_download]
class family_handler:
    """Static entry points describing, loading and validating KugelAudio TTS.

    Every method is a ``@staticmethod``; the class serves as a namespace and
    is never instantiated in this file.
    """

    # Dialogue tag such as "Speaker 1:" (any letter case, optional spaces).
    # Group 1 captures the numeric speaker ID.
    _SPEAKER_TAG_RE = re.compile(r"Speaker\s*(\d+)\s*:", flags=re.IGNORECASE)

    @staticmethod
    def query_supported_types():
        """Return the list of model type names handled by this class."""
        return ["kugelaudio_0_open"]

    @staticmethod
    def query_family_maps():
        """Return two empty dictionaries (no family mappings are declared)."""
        return {}, {}

    @staticmethod
    def query_model_family():
        """Return the family key of this handler (``"tts"``)."""
        return "tts"

    @staticmethod
    def query_family_infos():
        """Return family metadata keyed by family name.

        The tuple is ``(200, "TTS")`` -- presumably a sort weight and a display
        label; confirm against the consumer of this mapping.
        """
        return {"tts": (200, "TTS")}

    @staticmethod
    def register_lora_cli_args(parser, lora_root):
        """Add the ``--lora-dir-kugelaudio`` option to ``parser``.

        The option defaults to ``None``; ``lora_root`` is only used to show the
        effective default directory in the help text.
        """
        parser.add_argument(
            "--lora-dir-kugelaudio",
            type=str,
            default=None,
            help=f"Path to a directory that contains KugelAudio settings (default: {os.path.join(lora_root, 'kugelaudio')})",
        )

    @staticmethod
    def get_lora_dir(base_model_type, args, lora_root):
        """Return the KugelAudio settings directory.

        Uses ``args.lora_dir_kugelaudio`` when it is set to a truthy value,
        otherwise ``<lora_root>/kugelaudio``. ``base_model_type`` is unused.
        """
        return getattr(args, "lora_dir_kugelaudio", None) or os.path.join(lora_root, "kugelaudio")

    @staticmethod
    def query_model_def(base_model_type, model_def):
        """Return a fresh KugelAudio model definition; both arguments are ignored."""
        return _get_kugelaudio_model_def()

    @staticmethod
    def query_model_files(computeList, base_model_type, model_def=None):
        """Return the KugelAudio download descriptors; all arguments are ignored."""
        return _get_kugelaudio_download_def()

    @staticmethod
    def load_model(
        model_filename,
        model_type,
        base_model_type,
        model_def,
        quantizeTransformer=False,
        text_encoder_quantization=None,
        dtype=None,
        VAE_dtype=None,
        mixed_precision_transformer=False,
        save_quantized=False,
        submodel_no_list=None,
        text_encoder_filename=None,
        profile=0,
        lm_decoder_engine="legacy",
        **kwargs,
    ):
        """Build the KugelAudio pipeline and optionally save a quantized copy.

        Only ``model_filename`` (its first element is the weights path),
        ``model_type``, ``dtype``, ``save_quantized`` and ``lm_decoder_engine``
        are read here; the remaining parameters are accepted but unused in
        this method.

        Returns:
            ``(pipeline, pipe)`` where ``pipe`` maps ``"transformer"`` to
            ``pipeline.model``.
        """
        # Imported lazily so the pipeline package is only loaded when a model
        # is actually requested.
        from .kugelaudio.pipeline import KugelAudioPipeline

        weights_path = model_filename[0]
        pipeline = KugelAudioPipeline(
            model_weights_path=weights_path,
            ckpt_root=fl.get_download_location(),
            # presumably the MPS device when available, else this CPU fallback
            # -- confirm in shared.mps.
            device=mps_device_or(torch.device("cpu")),
            lm_decoder_engine=lm_decoder_engine,
        )

        if lm_decoder_engine == "cg":
            # NOTE(review): ``_budget = 0`` is set on the model and, when
            # present, on its inner language model for the "cg" engine only;
            # its meaning is defined by the host's memory manager -- confirm.
            pipeline.model._budget = 0
            language_model = getattr(getattr(pipeline.model, "model", None), "language_model", None)
            if language_model is not None:
                language_model._budget = 0

        _configure_diffusion_compile_targets(pipeline.model)
        pipe = {"transformer": pipeline.model}

        if save_quantized and weights_path:
            from wgp import save_quantized_model

            # Prefer the config shipped next to this module, then a located
            # asset, and finally the bare relative path.
            config_path = os.path.join(
                os.path.dirname(__file__), "kugelaudio", "configs", "kugelaudio", KUGELAUDIO_CONFIG_NAME
            )
            if not os.path.isfile(config_path):
                config_path = fl.locate_file(
                    os.path.join(KUGELAUDIO_ASSET_DIR, KUGELAUDIO_CONFIG_NAME),
                    error_if_none=False,
                )
                if config_path is None:
                    config_path = os.path.join(KUGELAUDIO_ASSET_DIR, KUGELAUDIO_CONFIG_NAME)
            save_quantized_model(
                pipeline.model,
                model_type,
                weights_path,
                dtype or torch.bfloat16,
                config_path,
            )

        return pipeline, pipe

    @staticmethod
    def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
        """Fill in missing keys of ``ui_defaults`` in place.

        ``alt_prompt`` and ``audio_prompt_type`` default to the empty string;
        values that are already present are left untouched.
        """
        ui_defaults.setdefault("alt_prompt", "")
        ui_defaults.setdefault("audio_prompt_type", "")

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        """Overwrite ``ui_defaults`` in place with KugelAudio's default settings.

        ``duration_seconds`` comes from ``model_def["duration_slider"]["default"]``
        and falls back to 60 when that entry is absent.
        """
        duration_def = model_def.get("duration_slider", {})
        ui_defaults.update(
            {
                "audio_prompt_type": "",
                "prompt": "Hello! This is KugelAudio speaking in a clear, friendly voice.",
                "repeat_generation": 1,
                "duration_seconds": duration_def.get("default", 60),
                "pause_seconds": 0.5,
                "video_length": 0,
                "num_inference_steps": 0,
                "negative_prompt": "",
                "temperature": 1.0,
                "guidance_scale": 3.0,
                "multi_prompts_gen_type": "FG",
            }
        )

    @staticmethod
    def validate_generative_prompt(base_model_type, model_def, inputs, one_prompt):
        """Check one prompt against the selected voice mode.

        Speaker tags are recognised with a single case-insensitive pattern
        (``Speaker <number>:``) for every check, so prose that merely contains
        the word "speaker" is not mistaken for a dialogue, and upper-case tags
        are not missed.

        Args:
            inputs: settings mapping; ``audio_prompt_type``, ``audio_guide``
                and ``audio_guide2`` are read.
            one_prompt: the prompt text (converted with ``str``).

        Returns:
            An error message, or ``None`` when the prompt is acceptable.
        """
        audio_prompt_type = inputs.get("audio_prompt_type", "") or ""
        if one_prompt is None or not str(one_prompt).strip():
            return "Prompt text cannot be empty for KugelAudio."

        text = str(one_prompt)
        speaker_ids = {int(match.group(1)) for match in family_handler._SPEAKER_TAG_RE.finditer(text)}
        two_voice_mode = "B" in audio_prompt_type

        # Any speaker tag requires both reference voices to be selected.
        if speaker_ids and not ("A" in audio_prompt_type and two_voice_mode):
            return "Multi-speaker prompts require two reference voice audio samples. Provide a voice sample or remove Speaker tags."

        if not two_voice_mode:
            return None

        if inputs.get("audio_guide") is None or inputs.get("audio_guide2") is None:
            return "Two-voice cloning requires two reference audio files."
        if not speaker_ids:
            return (
                "Two-voice cloning requires prompt lines with Speaker 1: and Speaker 2: "
                "(or any two numeric speaker IDs). For headless settings, keep "
                "'multi_prompts_gen_type' = 'FG' so dialogue lines stay in one prompt."
            )
        if len(speaker_ids) != 2:
            return (
                "Two-voice cloning requires exactly two speaker IDs. Use Speaker 1: and Speaker 2:. "
                "For headless settings, keep 'multi_prompts_gen_type' = 'FG'."
            )
        return None

    @staticmethod
    def validate_generative_settings(base_model_type, model_def, inputs):
        """Validate and normalise the optional auto-split custom setting.

        ``inputs["custom_settings"]`` is updated in place: a blank string
        removes the setting (and the whole entry becomes ``None`` when nothing
        else is left), and a valid value is stored back as a ``float``.
        Booleans, non-numeric values, NaN and values outside the allowed range
        are rejected.

        Returns:
            An error message, or ``None`` when the settings are acceptable.
        """
        custom_settings = inputs.get("custom_settings", None)
        if custom_settings is None:
            return None
        if not isinstance(custom_settings, dict):
            return "Custom settings must be a dictionary."

        raw_value = custom_settings.get(KUGELAUDIO_AUTO_SPLIT_SETTING_ID, None)
        if raw_value is None:
            return None

        if isinstance(raw_value, str):
            raw_value = raw_value.strip()
            if not raw_value:
                # A blank field means "not set": drop the key entirely.
                custom_settings.pop(KUGELAUDIO_AUTO_SPLIT_SETTING_ID, None)
                inputs["custom_settings"] = custom_settings if custom_settings else None
                return None

        not_a_number_message = (
            f"Auto Split Every s must be a number between "
            f"{int(KUGELAUDIO_AUTO_SPLIT_MIN_SECONDS)} and {int(KUGELAUDIO_AUTO_SPLIT_MAX_SECONDS)} seconds."
        )
        # bool is a subclass of int, so float(True) would silently yield 1.0.
        if isinstance(raw_value, bool):
            return not_a_number_message
        try:
            auto_split_seconds = float(raw_value)
        except (TypeError, ValueError, OverflowError):
            return not_a_number_message

        # A chained comparison is False for NaN, so NaN is rejected here;
        # separate "<" / ">" tests would both be False and let it through.
        if not (KUGELAUDIO_AUTO_SPLIT_MIN_SECONDS <= auto_split_seconds <= KUGELAUDIO_AUTO_SPLIT_MAX_SECONDS):
            return (
                f"Auto Split Every s must be between "
                f"{int(KUGELAUDIO_AUTO_SPLIT_MIN_SECONDS)} and {int(KUGELAUDIO_AUTO_SPLIT_MAX_SECONDS)} seconds."
            )

        custom_settings[KUGELAUDIO_AUTO_SPLIT_SETTING_ID] = auto_split_seconds
        inputs["custom_settings"] = custom_settings
        return None