# Wan2GP: models/TTS/yue_handler.py
# Uploaded by Egnalkram via huggingface_hub (commit 4689c2b, verified).
import os
# Hugging Face repositories hosting the YuE checkpoints.
# The stage-1 repo is chosen by the model definition's `yue_audio_prompt` flag:
# the "-icl" checkpoint when audio prompts are enabled, the "-cot" one otherwise.
# Stage 2 uses a single repository in both cases.
YUE_STAGE1_COT_REPO = "m-a-p/YuE-s1-7B-anneal-en-cot"
YUE_STAGE1_ICL_REPO = "m-a-p/YuE-s1-7B-anneal-en-icl"
YUE_STAGE2_REPO = "m-a-p/YuE-s2-1B-general"

# Files fetched from the root folder of the stage-1 / stage-2 repositories.
# Kept as two separate list objects so neither aliases the other.
YUE_STAGE1_FILES = ["config.json"]
YUE_STAGE2_FILES = ["config.json"]
def _get_yue_model_def(model_def):
use_audio_prompt = bool(model_def.get("yue_audio_prompt", False))
yue_def = {
"audio_only": True,
"image_outputs": False,
"sliding_window": False,
"guidance_max_phases": 0,
"no_negative_prompt": True,
"inference_steps": False,
"temperature": True,
"image_prompt_types_allowed": "",
"profiles_dir": ["yue"],
"alt_prompt": {
"label": "Genres / Tags",
"placeholder": "pop, dreamy, warm vocal, female, nostalgic",
"lines": 2,
},
"yue_max_new_tokens": 3000,
"yue_run_n_segments": 2,
"yue_stage2_batch_size": 4,
"yue_segment_duration": 6,
"yue_prompt_start_time": 0.0,
"yue_prompt_end_time": 30.0,
}
if use_audio_prompt:
yue_def.update(
{
"any_audio_prompt": True,
"audio_prompt_choices": True,
"audio_guide_label": "Vocal prompt",
"audio_guide2_label": "Instrumental prompt",
"audio_prompt_type_sources": {
"selection": ["", "A", "AB"],
"labels": {
"": "Lyrics only",
"A": "Mixed audio prompt",
"AB": "Vocal + Instrumental prompts",
},
"letters_filter": "AB",
"default": "",
},
}
)
return yue_def
def _get_yue_download_def(model_def):
    """Describe which files must be downloaded for a YuE model.

    Args:
        model_def: The model definition mapping. Its ``yue_audio_prompt`` flag
            selects the ICL stage-1 repository (truthy) or the COT one.

    Returns:
        A list of download specs, each a dict with ``repoId``,
        ``sourceFolderList``, ``targetFolderList`` and ``fileList`` (one file
        list per source folder). It covers the stage-1 config, the stage-2
        config, the tokenizer taken from the stage-1 repo, and the xcodec
        inference code and checkpoints.
    """
    stage1_repo = (
        YUE_STAGE1_ICL_REPO
        if model_def.get("yue_audio_prompt", False)
        else YUE_STAGE1_COT_REPO
    )
    xcodec_repo = "m-a-p/xcodec_mini_infer"
    xcodec_root = "xcodec_mini_infer"

    # (source folder inside the xcodec repo, files to fetch from it).
    # Empty file lists are forwarded unchanged to the downloader.
    xcodec_layout = (
        ("final_ckpt", ["config.yaml", "ckpt_00360000.pth"]),
        ("decoders", ["config.yaml", "decoder_131000.pth", "decoder_151000.pth"]),
        ("models", []),
        ("modules", []),
        ("quantization", []),
        ("RepCodec", []),
        ("descriptaudiocodec", []),
        ("vocos", []),
        ("semantic_ckpts/hf_1_325000", []),
    )
    xcodec_folders = [folder for folder, _ in xcodec_layout]
    xcodec_files = [files for _, files in xcodec_layout]

    def root_spec(repo_id, target_folder, files):
        # Spec for files that live at the root ("") of a repository.
        return {
            "repoId": repo_id,
            "sourceFolderList": [""],
            "targetFolderList": [target_folder],
            "fileList": [files],
        }

    return [
        root_spec(stage1_repo, os.path.basename(stage1_repo), YUE_STAGE1_FILES),
        root_spec(YUE_STAGE2_REPO, os.path.basename(YUE_STAGE2_REPO), YUE_STAGE2_FILES),
        root_spec(stage1_repo, "mm_tokenizer_v0.2_hf", ["tokenizer.model"]),
        root_spec(xcodec_repo, xcodec_root, ["vocoder.py", "post_process_audio.py"]),
        {
            "repoId": xcodec_repo,
            "sourceFolderList": xcodec_folders,
            "targetFolderList": [xcodec_root] * len(xcodec_folders),
            "fileList": xcodec_files,
        },
    ]
class family_handler:
    """Model-family handler for YuE (lyrics + genre tags to song).

    Every hook is a ``@staticmethod``, so the class acts as a namespace for the
    ``yue`` model type in the ``tts`` family.
    """

    @staticmethod
    def query_supported_types():
        """Return the base model types handled by this family."""
        return ["yue"]

    @staticmethod
    def query_family_maps():
        """Return two empty mappings: YuE declares no family-level model maps."""
        return {}, {}

    @staticmethod
    def query_model_family():
        """Return the family identifier, ``"tts"``."""
        return "tts"

    @staticmethod
    def query_family_infos():
        """Return ``{"tts": (200, "TTS")}``.

        NOTE(review): the tuple's meaning is defined by the caller (presumably
        an ordering key and a display label). Confirm there.
        """
        return {"tts": (200, "TTS")}

    @staticmethod
    def register_lora_cli_args(parser, lora_root):
        """Register ``--lora-dir-tts`` on an argparse-style ``parser``."""
        parser.add_argument(
            "--lora-dir-tts",
            type=str,
            default=None,
            help=f"Path to a directory that contains TTS settings (default: {os.path.join(lora_root, 'tts')})",
        )

    @staticmethod
    def get_lora_dir(base_model_type, args, lora_root):
        """Return ``args.lora_dir_tts`` if set, else ``<lora_root>/tts``."""
        return getattr(args, "lora_dir_tts", None) or os.path.join(lora_root, "tts")

    @staticmethod
    def query_model_def(base_model_type, model_def):
        """Return the YuE model-definition entries for ``model_def``."""
        return _get_yue_model_def(model_def)

    @staticmethod
    def query_model_files(computeList, base_model_type, model_def=None):
        """Return the download specs; a missing ``model_def`` is treated as ``{}``."""
        return _get_yue_download_def(model_def or {})

    @staticmethod
    def load_model(
        model_filename,
        model_type,
        base_model_type,
        model_def,
        quantizeTransformer=False,
        text_encoder_quantization=None,
        dtype=None,
        VAE_dtype=None,
        mixed_precision_transformer=False,
        save_quantized=False,
        submodel_no_list=None,
        text_encoder_filename=None,
        profile=0,
        **kwargs,
    ):
        """Build a ``YuePipeline`` and expose its sub-models.

        Args:
            model_filename: Either a single stage-1 weights path, or a
                list/tuple ``[stage1_weights, stage2_weights]``. Missing
                entries become ``""``.
            model_def: Model definition supplying the ``yue_*`` settings.
                ``None`` is treated as an empty definition.
            The remaining parameters are accepted for interface compatibility
            and are not used by this handler.

        Returns:
            ``(pipeline, pipe)``. ``pipe`` maps component names to the
            pipeline's stage-1/stage-2 models, codec model and vocoders.
        """
        from .yue.pipeline import YuePipeline

        model_def = model_def or {}
        if isinstance(model_filename, (list, tuple)):
            stage1_weights = model_filename[0] if len(model_filename) > 0 else ""
            stage2_weights = model_filename[1] if len(model_filename) > 1 else ""
        else:
            stage1_weights = model_filename or ""
            stage2_weights = ""

        # Fallback values match the defaults published by _get_yue_model_def
        # (3000 tokens, 2 segments, stage-2 batch of 4), so a definition that
        # lacks these keys behaves like the advertised defaults.
        pipeline = YuePipeline(
            stage1_weights_path=stage1_weights,
            stage2_weights_path=stage2_weights,
            use_audio_prompt=bool(model_def.get("yue_audio_prompt", False)),
            max_new_tokens=model_def.get("yue_max_new_tokens", 3000),
            run_n_segments=model_def.get("yue_run_n_segments", 2),
            stage2_batch_size=model_def.get("yue_stage2_batch_size", 4),
            segment_duration=model_def.get("yue_segment_duration", 6),
            prompt_start_time=model_def.get("yue_prompt_start_time", 0.0),
            prompt_end_time=model_def.get("yue_prompt_end_time", 30.0),
        )
        pipe = {
            "transformer": pipeline.model_stage1,
            "transformer2": pipeline.model_stage2,
            "codec_model": pipeline.codec_model,
            "vocoder_vocal": pipeline.vocoder_vocal,
            "vocoder_inst": pipeline.vocoder_inst,
        }
        return pipeline, pipe

    @staticmethod
    def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
        """Add YuE settings missing from ``ui_defaults`` in place; existing values are kept."""
        for key, value in (("alt_prompt", ""), ("audio_prompt_type", "")):
            ui_defaults.setdefault(key, value)

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        """Overwrite ``ui_defaults`` in place with YuE's default UI values."""
        ui_defaults.update(
            {
                "audio_prompt_type": "",
                "alt_prompt": "pop, dreamy, warm vocal, female, nostalgic",
                "repeat_generation": 1,
                "video_length": 0,
                "num_inference_steps": 0,
                "negative_prompt": "",
                "temperature": 1.0,
                "multi_prompts_gen_type": 2,
            }
        )

    @staticmethod
    def validate_generative_prompt(base_model_type, model_def, inputs, one_prompt):
        """Check a YuE generation request.

        Args:
            model_def: Model definition; ``yue_audio_prompt`` enables ICL
                audio prompts and ``yue_prompt_*_time`` give fallback times.
            inputs: Request inputs (``alt_prompt``, ``audio_prompt_type``,
                ``audio_guide``, ``audio_guide2``, optional prompt times).
            one_prompt: The lyrics prompt.

        Returns:
            A user-facing error message, or ``None`` when the request is valid.
        """
        if one_prompt is None or not str(one_prompt).strip():
            return "Lyrics prompt cannot be empty for Yue."
        alt_prompt = inputs.get("alt_prompt", "")
        if alt_prompt is None or not str(alt_prompt).strip():
            return "Genres prompt cannot be empty for Yue."

        has_audio = inputs.get("audio_guide") is not None or inputs.get("audio_guide2") is not None
        if not model_def.get("yue_audio_prompt", False):
            if has_audio:
                return "Yue base model does not support audio prompts. Please use Yue ICL."
            return None

        audio_prompt_type = inputs.get("audio_prompt_type", "") or ""
        if "A" not in audio_prompt_type:
            if has_audio:
                return "Select an audio prompt type for Yue ICL or clear audio prompts."
            return None

        if inputs.get("audio_guide") is None:
            return "You must provide a vocal or mixed audio prompt for Yue ICL."
        if "B" in audio_prompt_type and inputs.get("audio_guide2") is None:
            return "You must provide an instrumental prompt for Yue ICL."

        def prompt_time(key, fallback):
            # An explicit None (in inputs or model_def) means "unset": fall
            # back instead of crashing in float(None).
            value = inputs.get(key)
            if value is None:
                value = model_def.get(key)
            if value is None:
                value = fallback
            return float(value)

        try:
            start_time = prompt_time("yue_prompt_start_time", 0.0)
            end_time = prompt_time("yue_prompt_end_time", 30.0)
        except (TypeError, ValueError):
            return "Audio prompt start and end times must be numbers."
        if start_time >= end_time:
            return "Audio prompt start time must be less than end time."
        if end_time - start_time > 30:
            return "Audio prompt duration should not exceed 30 seconds."
        return None