ColabWan / models /TTS /omnivoice_handler.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
16.6 kB
import json
import os
import re
import torch
import whisper
from accelerate import init_empty_weights
from mmgp import offload
from shared.deepy.transcription import (
WHISPER_MEDIUM_CONFIG_FILENAME,
WHISPER_MEDIUM_FOLDER,
WHISPER_MEDIUM_REPO,
WHISPER_MEDIUM_WEIGHTS_FILENAME,
)
from shared.mps import mps_device_or
from shared.utils import files_locator as fl
from .omnivoice.pipeline import (
OMNIVOICE_ASSET_DIR,
OMNIVOICE_AUDIO_TOKENIZER_DIR,
OMNIVOICE_AUDIO_TOKENIZER_WEIGHTS,
OMNIVOICE_AUTO_END_TRIM_FLAG,
OMNIVOICE_AUTO_SPLIT_MAX_SECONDS,
OMNIVOICE_AUTO_SPLIT_MIN_SECONDS,
OMNIVOICE_AUTO_SPLIT_SETTING_ID,
OMNIVOICE_CONFIG_NAME,
OMNIVOICE_DEFAULT_VOICE_INSTRUCTION,
is_omnivoice_voice_instruction,
normalize_omnivoice_voice_instruction,
)
from .omnivoice.modeling_omnivoice import _resolve_instruct
from .omnivoice.utils.voice_design import _INSTRUCT_VALID_EN, _INSTRUCT_VALID_ZH, _ZH_RE
from .prompt_enhancers import TTS_MONOLOGUE_PROMPT, TTS_QWEN3_DIALOGUE_PROMPT
OMNIVOICE_REPO_ID = "DeepBeepMeep/TTS"
OMNIVOICE_MAIN_FILENAME = "omnivoice_bf16.safetensors"
OMNIVOICE_QUANT_FILENAME = "omnivoice_quanto_bf16_int8.safetensors"
OMNIVOICE_TOKENIZER_FILES = [
OMNIVOICE_CONFIG_NAME,
"tokenizer.json",
"tokenizer_config.json",
"chat_template.jinja",
]
OMNIVOICE_AUDIO_TOKENIZER_FILES = [
"config.json",
"preprocessor_config.json",
OMNIVOICE_AUDIO_TOKENIZER_WEIGHTS,
]
OMNIVOICE_WHISPER_FILES = [
WHISPER_MEDIUM_CONFIG_FILENAME,
WHISPER_MEDIUM_WEIGHTS_FILENAME,
]
OMNIVOICE_LANGUAGE_CHOICES = [
("Auto", "auto"),
("English", "english"),
("Chinese", "chinese"),
("French", "french"),
("German", "german"),
("Italian", "italian"),
("Japanese", "japanese"),
("Korean", "korean"),
("Portuguese", "portuguese"),
("Spanish", "spanish"),
("Arabic", "arabic"),
("Hindi", "hindi"),
("Russian", "russian"),
]
OMNIVOICE_DURATION_SLIDER = {
"label": "Max duration (seconds, 0 = auto)",
"min": 0,
"max": 600,
"increment": 1,
"default": 0,
}
OMNIVOICE_AUDIO_PROMPT_TYPE_SOURCES = {
"selection": ["", "A", "AB"],
"labels": {
"": "Voice design",
"A": "Voice cloning (1 reference audio)",
"AB": "Voice cloning dialogue (Speaker 1 and Speaker 2)",
},
"letters_filter": "AB",
"default": "",
}
OMNIVOICE_AUDIO_PROMPT_TYPE_CUSTOM_OPTION = {
"label": "Auto Detect Segment End",
"flag": OMNIVOICE_AUTO_END_TRIM_FLAG,
}
OMNIVOICE_CUSTOM_SETTINGS = [
{
"id": OMNIVOICE_AUTO_SPLIT_SETTING_ID,
"label": "Auto Split Every s (5-90, optional), may reduce VRAM requirements for very long speeches.",
"name": "Auto Split Every s",
"type": "float",
},
]
OMNIVOICE_PROMPT_SPECIAL_TAGS = [
"[laughter]",
"[sigh]",
"[confirmation-en]",
"[question-en]",
"[question-ah]",
"[question-oh]",
"[question-ei]",
"[question-yi]",
"[surprise-ah]",
"[surprise-oh]",
"[surprise-wa]",
"[surprise-yo]",
"[dissatisfaction-hnn]",
]
def _format_markdown_items(items):
return ", ".join(f"`{item}`" for item in sorted(items))
def _read_omnivoice_text_input(value):
if value is None:
return ""
if isinstance(value, str) and os.path.isfile(value):
with open(value, "r", encoding="utf-8") as reader:
return reader.read()
return str(value)
def _validate_omnivoice_instruction(instruction, target_text):
try:
_resolve_instruct(instruction, use_zh=bool(target_text and _ZH_RE.search(target_text)))
except ValueError as error:
return f"Invalid OmniVoice voice instruction:\n{error}"
return None
OMNIVOICE_INFOS = f"""
## Prompt special tags
These tags can be inserted directly in the main prompt text:
{", ".join(f"`{tag}`" for tag in OMNIVOICE_PROMPT_SPECIAL_TAGS)}
## Voice instruction / reference transcript(s)
Use this field differently depending on the selected voice mode.
### Voice Design
Leave it blank for Auto Voice. To design a voice, enter comma-separated voice tags.
Valid English tags:
{_format_markdown_items(_INSTRUCT_VALID_EN)}
Valid Chinese tags:
{_format_markdown_items(_INSTRUCT_VALID_ZH)}
Examples:
```text
female, young adult, low pitch, british accent
male, middle-aged, whisper
```
Use only valid tags here. Square-bracket non-verbal tags such as `[laughter]` belong in the main prompt, not in this field.
### Voice cloning
Upload a reference audio file and either leave this field blank or enter the exact transcript of the reference audio. If left blank, WanGP transcribes the reference with Whisper.
The transcript must describe the reference audio, not the target prompt. For best results, use a clean 3-10 second reference clip, preferably in the same language as the text you want to generate.
If this field contains only valid voice tags such as `female` or `male, british accent`, WanGP treats it as a voice instruction rather than a reference transcript.
### Two-speaker cloning
Upload both reference voices and provide transcripts like this, or leave blank for Whisper transcription:
```text
Speaker 1: Exact words spoken in the first reference audio.
Speaker 2: Exact words spoken in the second reference audio.
```
"""
def _detach_whisper_alignment_heads(whisper_model):
alignment_heads = getattr(whisper_model, "alignment_heads", None)
if alignment_heads is not None and getattr(alignment_heads, "layout", None) == torch.sparse_coo:
whisper_model._buffers.pop("alignment_heads", None)
object.__setattr__(whisper_model, "alignment_heads", alignment_heads)
def _load_omnivoice_whisper_medium():
model_dir = fl.locate_folder(WHISPER_MEDIUM_FOLDER)
config_path = os.path.join(model_dir, WHISPER_MEDIUM_CONFIG_FILENAME)
weights_path = fl.locate_file(os.path.join(WHISPER_MEDIUM_FOLDER, WHISPER_MEDIUM_WEIGHTS_FILENAME))
with open(config_path, "r", encoding="utf-8") as reader:
config = json.load(reader)
dims = whisper.model.ModelDimensions(**dict(config.get("dims", {}) or {}))
with init_empty_weights(include_buffers=False):
whisper_model = whisper.model.Whisper(dims)
whisper_model._buffers.pop("alignment_heads", None)
offload.load_model_data(whisper_model, weights_path, default_dtype=torch.float32, writable_tensors=False)
whisper_model.to(dtype=torch.float32)
alignment_heads = str(config.get("alignment_heads", "") or "").strip()
if len(alignment_heads) > 0:
whisper_model.set_alignment_heads(alignment_heads.encode("ascii"))
_detach_whisper_alignment_heads(whisper_model)
whisper_model.eval().requires_grad_(False)
whisper_model._model_dtype = torch.float32
return whisper_model
def _get_omnivoice_model_def():
return {
"audio_only": True,
"image_outputs": False,
"sliding_window": False,
"guidance_max_phases": 1,
"no_negative_prompt": True,
"inference_steps": True,
"temperature": False,
"image_prompt_types_allowed": "",
"supports_early_stop": True,
"profiles_dir": ["omnivoice"],
"duration_slider": dict(OMNIVOICE_DURATION_SLIDER),
"infos": OMNIVOICE_INFOS,
"model_modes": {
"choices": list(OMNIVOICE_LANGUAGE_CHOICES),
"default": "auto",
"label": "Language",
},
"alt_prompt": {
"label": "Voice instruction / reference transcript(s)",
"placeholder": "Voice Design: optional voice tags such as female\nVoice clone: optional transcript; blank uses Whisper to autotranscribe",
"lines": 4,
},
"preserve_empty_prompt_lines": True,
"pause_between_sentences": True,
"any_audio_prompt": True,
"audio_prompt_choices": True,
"audio_prompt_type_sources": dict(OMNIVOICE_AUDIO_PROMPT_TYPE_SOURCES),
"audio_prompt_type_custom_option": dict(OMNIVOICE_AUDIO_PROMPT_TYPE_CUSTOM_OPTION),
"custom_settings": [one.copy() for one in OMNIVOICE_CUSTOM_SETTINGS],
"audio_guide_label": "Speaker 1 reference voice",
"audio_guide2_label": "Speaker 2 reference voice",
"text_prompt_enhancer_instructions": TTS_MONOLOGUE_PROMPT,
"text_prompt_enhancer_instructions1": TTS_QWEN3_DIALOGUE_PROMPT,
"text_prompt_enhancer_max_tokens": 512,
"text_prompt_enhancer_max_tokens1": 512,
"prompt_enhancer_def": {
"selection": ["T", "T1"],
"labels": {
"T": "A Speech based on current Prompt",
"T1": "A Dialogue between two People based on current Prompt",
},
"default": "T",
},
"prompt_enhancer_button_label": "Write",
"compile": False,
}
def _get_omnivoice_download_def():
return [
{
"repoId": OMNIVOICE_REPO_ID,
"sourceFolderList": [OMNIVOICE_ASSET_DIR, OMNIVOICE_AUDIO_TOKENIZER_DIR],
"fileList": [OMNIVOICE_TOKENIZER_FILES, OMNIVOICE_AUDIO_TOKENIZER_FILES],
},
{
"repoId": WHISPER_MEDIUM_REPO,
"sourceFolderList": [WHISPER_MEDIUM_FOLDER],
"fileList": [OMNIVOICE_WHISPER_FILES],
}
]
class family_handler:
@staticmethod
def query_supported_types():
return ["omnivoice"]
@staticmethod
def query_family_maps():
return {}, {}
@staticmethod
def query_model_family():
return "tts"
@staticmethod
def query_family_infos():
return {"tts": (200, "TTS")}
@staticmethod
def register_lora_cli_args(parser, lora_root):
parser.add_argument(
"--lora-dir-omnivoice",
type=str,
default=None,
help=f"Path to a directory that contains OmniVoice settings (default: {os.path.join(lora_root, 'omnivoice')})",
)
@staticmethod
def get_lora_dir(base_model_type, args, lora_root):
return getattr(args, "lora_dir_omnivoice", None) or os.path.join(lora_root, "omnivoice")
@staticmethod
def query_model_def(base_model_type, model_def):
return _get_omnivoice_model_def()
@staticmethod
def query_model_files(computeList, base_model_type, model_def=None):
return _get_omnivoice_download_def()
@staticmethod
def load_model(
model_filename,
model_type,
base_model_type,
model_def,
quantizeTransformer=False,
text_encoder_quantization=None,
dtype=None,
VAE_dtype=None,
mixed_precision_transformer=False,
save_quantized=False,
submodel_no_list=None,
text_encoder_filename=None,
profile=0,
lm_decoder_engine="legacy",
**kwargs,
):
from .omnivoice.pipeline import OmniVoicePipeline
weights_path = model_filename[0] if isinstance(model_filename, (list, tuple)) else model_filename
pipeline = OmniVoicePipeline(
model_weights_path=weights_path,
ckpt_root=fl.get_download_location(),
device=mps_device_or(torch.device("cpu")),
dtype=dtype or torch.bfloat16,
)
whisper_model = _load_omnivoice_whisper_medium()
pipeline.set_whisper_model(whisper_model)
pipe = {
"transformer": pipeline.model,
"audio_tokenizer": pipeline.audio_tokenizer,
"whisper": whisper_model,
}
if save_quantized and weights_path:
from wgp import save_quantized_model
config_path = fl.locate_file(os.path.join(OMNIVOICE_ASSET_DIR, OMNIVOICE_CONFIG_NAME))
save_quantized_model(pipeline.model, model_type, weights_path, dtype or torch.bfloat16, config_path)
return pipeline, pipe
@staticmethod
def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
ui_defaults.setdefault("audio_prompt_type", "")
ui_defaults.setdefault("model_mode", "auto")
ui_defaults.setdefault("alt_prompt", "")
ui_defaults["alt_prompt"] = normalize_omnivoice_voice_instruction(str(ui_defaults.get("alt_prompt") or ""))
ui_defaults.setdefault("pause_seconds", 0.2)
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
duration_def = model_def.get("duration_slider", {})
ui_defaults.update(
{
"audio_prompt_type": "",
"model_mode": "auto",
"prompt": "The lights are already on, so we can start whenever you are ready.",
"alt_prompt": OMNIVOICE_DEFAULT_VOICE_INSTRUCTION,
"repeat_generation": 1,
"duration_seconds": duration_def.get("default", 0),
"pause_seconds": 0.2,
"video_length": 0,
"num_inference_steps": 32,
"negative_prompt": "",
"temperature": 0.1,
"guidance_scale": 2.0,
"multi_prompts_gen_type": "FG",
}
)
@staticmethod
def validate_generative_prompt(base_model_type, model_def, inputs, one_prompt):
if one_prompt is None or len(str(one_prompt).strip()) == 0:
return "Prompt text cannot be empty for OmniVoice."
audio_prompt_type = str(inputs.get("audio_prompt_type", "") or "").upper()
text = str(one_prompt)
instruction_or_ref = normalize_omnivoice_voice_instruction(_read_omnivoice_text_input(inputs.get("alt_prompt", ""))).strip()
if instruction_or_ref and ("A" not in audio_prompt_type or is_omnivoice_voice_instruction(instruction_or_ref)):
instruction_error = _validate_omnivoice_instruction(instruction_or_ref, text)
if instruction_error is not None:
return instruction_error
has_speaker_syntax = re.search(r"Speaker\s*\d+\s*:", text, flags=re.IGNORECASE) is not None
if "A" in audio_prompt_type and "B" not in audio_prompt_type and inputs.get("audio_guide") is None:
return "OmniVoice voice cloning requires a reference audio file."
if "B" in audio_prompt_type:
if inputs.get("audio_guide") is None or inputs.get("audio_guide2") is None:
return "OmniVoice dialogue mode requires two reference audio files."
speaker_matches = list(re.finditer(r"Speaker\s*(\d+)\s*:", text, flags=re.IGNORECASE))
if not speaker_matches:
return "OmniVoice dialogue mode requires prompt lines using Speaker 1: and Speaker 2:."
speaker_ids = sorted({int(m.group(1)) for m in speaker_matches})
if len(speaker_ids) != 2:
return "OmniVoice dialogue mode requires exactly two speaker IDs. Use Speaker 1: and Speaker 2:."
elif has_speaker_syntax:
return "Speaker-tag dialogue requires OmniVoice two-speaker mode."
return None
@staticmethod
def validate_generative_settings(base_model_type, model_def, inputs):
custom_settings = inputs.get("custom_settings", None)
if custom_settings is None:
return None
if not isinstance(custom_settings, dict):
return "Custom settings must be a dictionary."
raw_value = custom_settings.get(OMNIVOICE_AUTO_SPLIT_SETTING_ID, None)
if raw_value is None:
return None
if isinstance(raw_value, str):
raw_value = raw_value.strip()
if len(raw_value) == 0:
custom_settings.pop(OMNIVOICE_AUTO_SPLIT_SETTING_ID, None)
inputs["custom_settings"] = custom_settings if len(custom_settings) > 0 else None
return None
try:
if isinstance(raw_value, bool):
raise ValueError()
auto_split_seconds = float(raw_value)
except Exception:
return (
f"Auto Split Every s must be a number between "
f"{int(OMNIVOICE_AUTO_SPLIT_MIN_SECONDS)} and {int(OMNIVOICE_AUTO_SPLIT_MAX_SECONDS)} seconds."
)
if auto_split_seconds < OMNIVOICE_AUTO_SPLIT_MIN_SECONDS or auto_split_seconds > OMNIVOICE_AUTO_SPLIT_MAX_SECONDS:
return (
f"Auto Split Every s must be between "
f"{int(OMNIVOICE_AUTO_SPLIT_MIN_SECONDS)} and {int(OMNIVOICE_AUTO_SPLIT_MAX_SECONDS)} seconds."
)
custom_settings[OMNIVOICE_AUTO_SPLIT_SETTING_ID] = auto_split_seconds
inputs["custom_settings"] = custom_settings
return None