| import os |
|
|
| import torch |
|
|
| from shared.mps import mps_device_or |
| from shared.utils import files_locator as fl |
|
|
| from .prompt_enhancers import HEARTMULA_LYRIC_PROMPT |
|
|
| HEARTMULA_VERSION = "3B" |
|
|
|
|
| def _get_heartmula_model_def(): |
| return { |
| "audio_only": True, |
| "image_outputs": False, |
| "sliding_window": False, |
| "guidance_max_phases": 1, |
| "no_negative_prompt": True, |
| "inference_steps": False, |
| "temperature": True, |
| "image_prompt_types_allowed": "", |
| "supports_early_stop": True, |
| "profiles_dir": ["heartmula_oss_3b"], |
| "alt_prompt": { |
| "label": "Keywords / Tags", |
| "placeholder": "piano,happy,wedding", |
| "lines": 2, |
| }, |
| "lm_engines": ["cg"], |
| "text_prompt_enhancer_instructions": HEARTMULA_LYRIC_PROMPT, |
| "prompt_enhancer_button_label": "Compose Lyrics", |
| "duration_slider": { |
| "label": "Duration of the Song (in seconds)", |
| "min": 30, |
| "max": 240, |
| "increment": 0.1, |
| "default": 120, |
| }, |
| "top_k_slider": True, |
| "heartmula_cfg_scale": 1.5, |
| "heartmula_topk": 50, |
| "heartmula_max_audio_length_ms": 120000, |
| "heartmula_codec_guidance_scale": 1.25, |
| "heartmula_codec_steps": 10, |
| "heartmula_codec_version": "", |
| "compile": False, |
| } |
|
|
|
|
| def _get_heartmula_download_def(model_def): |
| codec_version = (model_def or {}).get("heartmula_codec_version", "") |
| codec_suffix = f"_{codec_version}" if codec_version else "" |
| repo_id = "DeepBeepMeep/TTS" |
| gen_files = [ |
| "gen_config.json", |
| "tokenizer.json", |
| f"codec_config{codec_suffix}.json", |
| f"HeartMula_codec{codec_suffix}.safetensors", |
| ] |
| return [ |
| { |
| "repoId": repo_id, |
| "sourceFolderList": ["HeartMula"], |
| "fileList": [gen_files], |
| }, |
| ] |
|
|
|
|
| class family_handler: |
| @staticmethod |
| def query_supported_types(): |
| return ["heartmula_oss_3b"] |
|
|
| @staticmethod |
| def query_family_maps(): |
| return {}, {} |
|
|
| @staticmethod |
| def query_model_family(): |
| return "tts" |
|
|
| @staticmethod |
| def query_family_infos(): |
| return {"tts": (200, "TTS")} |
|
|
| @staticmethod |
| def register_lora_cli_args(parser, lora_root): |
| parser.add_argument( |
| "--lora-heart_mula", |
| type=str, |
| default=None, |
| help=f"Path to a directory that contains Heart Mula settings (default: {os.path.join(lora_root, 'heart_mula')})", |
| ) |
|
|
| @staticmethod |
| def get_lora_dir(base_model_type, args, lora_root): |
| return getattr(args, "lora_heart_mula", None) or os.path.join(lora_root, "heart_mula") |
|
|
| @staticmethod |
| def query_model_def(base_model_type, model_def): |
| return _get_heartmula_model_def() |
|
|
| @staticmethod |
| def query_model_files(computeList, base_model_type, model_def=None): |
| return _get_heartmula_download_def(model_def or {}) |
|
|
| @staticmethod |
| def load_model( |
| model_filename, |
| model_type, |
| base_model_type, |
| model_def, |
| quantizeTransformer=False, |
| text_encoder_quantization=None, |
| dtype=None, |
| VAE_dtype=None, |
| mixed_precision_transformer=False, |
| save_quantized=False, |
| submodel_no_list=None, |
| text_encoder_filename=None, |
| profile=0, |
| lm_decoder_engine="legacy", |
| **kwargs, |
| ): |
| from .HeartMula.pipeline import HeartMuLaPipeline |
|
|
| ckpt_root = fl.get_download_location() |
| weights_candidate = None |
| if isinstance(model_filename, (list, tuple)): |
| if len(model_filename) > 0: |
| weights_candidate = model_filename[0] |
| else: |
| weights_candidate = model_filename |
| heartmula_weights_path = None |
| if weights_candidate: |
| heartmula_weights_path = fl.locate_file( |
| weights_candidate, error_if_none=False |
| ) |
| if heartmula_weights_path is None: |
| heartmula_weights_path = weights_candidate |
| pipeline = HeartMuLaPipeline( |
| ckpt_root=ckpt_root, |
| device=mps_device_or(torch.device("cpu")), |
| version=HEARTMULA_VERSION, |
| VAE_dtype=VAE_dtype, |
| heartmula_weights_path=heartmula_weights_path, |
| cfg_scale=model_def.get("heartmula_cfg_scale", 1.5), |
| topk=model_def.get("heartmula_topk", 50), |
| max_audio_length_ms=model_def.get("heartmula_max_audio_length_ms", 120000), |
| codec_steps=model_def.get("heartmula_codec_steps", 10), |
| codec_guidance_scale=model_def.get("heartmula_codec_guidance_scale", 1.25), |
| codec_version=model_def.get("heartmula_codec_version", ""), |
| lm_decoder_engine=lm_decoder_engine, |
| ) |
| if lm_decoder_engine=="cg": |
| pipeline.mula._budget = 0 |
|
|
| pipeline.mula.decoder[0].layers._compile_me = False |
| pipeline.mula.backbone.layers._compile_me = False |
| pipe = { |
| "transformer": pipeline.mula, |
| "transformer2": pipeline.mula.decoder[0], |
| "codec": pipeline.codec, |
| } |
| pipe = { |
| "pipe": pipe, |
| "coTenantsMap": { |
| "transformer": ["transformer2"], |
| "transformer2": ["transformer"], |
| }, |
| } |
|
|
| if int(profile) in (2, 4, 5): |
| pipe["budgets"] = {"transformer2": 200} |
|
|
| return pipeline, pipe |
|
|
| @staticmethod |
| def fix_settings(base_model_type, settings_version, model_def, ui_defaults): |
| if "alt_prompt" not in ui_defaults: |
| ui_defaults["alt_prompt"] = "" |
|
|
| defaults = { |
| "audio_prompt_type": "", |
| } |
| for key, value in defaults.items(): |
| ui_defaults.setdefault(key, value) |
|
|
| if settings_version < 2.44: |
| ui_defaults["guidance_scale"] = model_def.get("heartmula_cfg_scale", 1.5) |
| ui_defaults["top_k"] = model_def.get("heartmula_topk", 50) |
|
|
| @staticmethod |
| def update_default_settings(base_model_type, model_def, ui_defaults): |
| duration_def = model_def.get("duration_slider", {}) |
| ui_defaults.update( |
| { |
| "audio_prompt_type": "", |
| "alt_prompt": "piano,happy,wedding", |
| "repeat_generation": 1, |
| "duration_seconds": duration_def.get("default", 120), |
| "video_length": 0, |
| "num_inference_steps": 0, |
| "negative_prompt": "", |
| "temperature": 1.0, |
| "guidance_scale": model_def.get("heartmula_cfg_scale", 1.5), |
| "top_k": model_def.get("heartmula_topk", 50), |
| "multi_prompts_gen_type": "FG", |
| } |
| ) |
|
|
| @staticmethod |
| def validate_generative_prompt(base_model_type, model_def, inputs, one_prompt): |
| alt_prompt = inputs.get("alt_prompt", "") |
| if alt_prompt is None or len(str(alt_prompt).strip()) == 0: |
| return "Keywords prompt cannot be empty for HeartMuLa." |
| if inputs.get("audio_guide") is not None or inputs.get("audio_guide2") is not None: |
| return "HeartMuLa does not support reference audio yet." |
| return None |
|
|