# VoxCPM2 / app.py
# Kremon96's picture
# Upload 9 files
# c73d0e1 verified
# ComfyUI-VoxCPMTTS
# A clean, efficient ComfyUI custom node for VoxCPM TTS (Text-to-Speech) functionality. This implementation provides high-quality speech generation and voice cloning capabilities using the VoxCPM model.
#
# Models License Notice:
# - VoxCPM1.5: Apache-2.0 License (https://huggingface.co/openbmb/VoxCPM1.5)
#
# This integration script follows GPL-3.0 License.
# When using or modifying this code, please respect both the original model licenses
# and this integration's license terms.
#
# Source: https://github.com/1038lab/ComfyUI-VoxCPMTTS
import torch
import numpy as np
import os
import folder_paths
import comfy.model_management as model_management
import soundfile as sf
from huggingface_hub import snapshot_download
from voxcpm import VoxCPM
# Hugging Face repo id of the default model; the VOXCPM_MODEL_ID environment
# variable overrides it.
DEFAULT_MODEL_ID = os.getenv("VOXCPM_MODEL_ID", "openbmb/VoxCPM1.5")

# Folder name for the default model; the VOXCPM_MODEL_DIR environment variable
# overrides it, otherwise the last path component of the repo id is used.
DEFAULT_MODEL_DIRNAME = os.getenv(
    "VOXCPM_MODEL_DIR",
    DEFAULT_MODEL_ID.rsplit("/", 1)[-1],
)

# Display label shown in the node UI -> Hugging Face repo id.
MODEL_REPO_MAP = {
    "VoxCPM 1.5": "openbmb/VoxCPM1.5",
    "VoxCPM 0.5B": "openbmb/VoxCPM-0.5B",
}

# Labels offered in the "model" dropdown, in declaration order.
MODEL_CHOICES = list(MODEL_REPO_MAP)

# Label whose repo matches DEFAULT_MODEL_ID; falls back to the first choice
# when the configured repo id is not one of the known models.
DEFAULT_MODEL_CHOICE = next(
    filter(lambda label: MODEL_REPO_MAP[label] == DEFAULT_MODEL_ID, MODEL_CHOICES),
    MODEL_CHOICES[0],
)

# Loaded VoxCPM instances, keyed by (repo_id, model_directory).
_voxcpm_models = {}

# Lazily created speech-recognition model and the name of the backend that
# produced it ("faster_whisper" or "openai_whisper"); both start unset.
_asr_model = None
_asr_backend = None
def get_available_devices():
    """Return the device names selectable in the node UI.

    The list always starts with ``"auto"`` and ends with ``"cpu"``; ``"cuda"``
    and ``"mps"`` are inserted in between (in that order) when torch reports
    the corresponding backend as available.
    """
    cuda_ready = torch.cuda.is_available()
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ready = mps_backend is not None and mps_backend.is_available()
    accelerators = [
        name
        for name, ready in (("cuda", cuda_ready), ("mps", mps_ready))
        if ready
    ]
    return ["auto", *accelerators, "cpu"]
def get_best_device():
    """Pick the device used when the node's device option is ``"auto"``.

    Preference order: the device reported by ComfyUI's model management when
    CUDA is available, then Apple MPS, then the CPU.
    """
    if torch.cuda.is_available():
        return model_management.get_torch_device()

    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    return torch.device("cpu")
def unload_model_instance(repo_id: str):
    """Drop every cached VoxCPM instance that was loaded from ``repo_id``.

    Each matching instance has its ``tts_model`` moved to the CPU on a
    best-effort basis (failures are ignored), is removed from the module
    cache, and finally ComfyUI is asked to release cached device memory.

    Args:
        repo_id: Hugging Face repo id the instances were loaded from.
    """
    # Snapshot the keys first so the cache is not mutated while iterating it.
    matching_keys = [key for key in _voxcpm_models if key[0] == repo_id]

    for key in matching_keys:
        cached = _voxcpm_models.pop(key, None)
        try:
            if hasattr(cached, "tts_model"):
                cached.tts_model.to(torch.device("cpu"))
        except Exception:
            # Moving to CPU is only a courtesy before the reference is dropped.
            pass

    model_management.soft_empty_cache()
def load_model(model_repo: str = DEFAULT_MODEL_ID, model_dirname: str = None):
    """Return a cached VoxCPM instance, downloading the weights if needed.

    The model lives in ``<ComfyUI models dir>/TTS/<dirname>``. When that
    folder is missing or empty, the repo is fetched with
    ``snapshot_download`` first. Instances are cached per
    ``(repo_id, target_dir)`` so repeated calls reuse the loaded model.

    Args:
        model_repo: Hugging Face repo id; a falsy value selects
            ``DEFAULT_MODEL_ID``.
        model_dirname: Folder name below ``TTS``. When omitted, the
            ``VOXCPM_MODEL_DIR`` environment variable is used for the default
            repo, and the last path component of the repo id otherwise.

    Returns:
        The ``VoxCPM`` instance for the requested repo.

    Raises:
        Exception: Whatever the download or the ``VoxCPM`` constructor raised;
            the error is logged and re-raised unchanged.
    """
    repo_id = model_repo or DEFAULT_MODEL_ID

    # The directory override describes where the *default* model lives.
    # Applying it to every repo would make all model choices share a single
    # folder, so a second model would silently load the first one's weights.
    env_dirname = os.getenv("VOXCPM_MODEL_DIR") if repo_id == DEFAULT_MODEL_ID else None
    dirname = model_dirname or env_dirname or repo_id.split("/")[-1]

    target_dir = os.path.join(folder_paths.models_dir, "TTS", dirname)
    cache_key = (repo_id, target_dir)

    cached = _voxcpm_models.get(cache_key)
    if cached is not None:
        return cached

    if not os.path.isdir(target_dir) or not os.listdir(target_dir):
        try:
            # makedirs creates the intermediate "TTS" folder as well.
            os.makedirs(target_dir, exist_ok=True)
            snapshot_download(
                repo_id=repo_id,
                local_dir=target_dir,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"[ERROR] Failed to download VoxCPM model: {str(e)}")
            raise

    try:
        instance = VoxCPM(
            voxcpm_model_path=target_dir,
            enable_denoiser=False,
            optimize=False,
        )
    except Exception as e:
        print(f"[ERROR] VoxCPM initialization failed: {str(e)}")
        raise

    _voxcpm_models[cache_key] = instance
    return instance
class AILab_VoxCPMTTS:
    """Simplified VoxCPM text-to-speech node with fixed generation presets.

    Exposes only the model choice, the text, optional voice-cloning reference
    inputs, the target device and a seed. Guidance scale, step count, maximum
    length, fade-in and retry settings are hard-coded inside ``generate``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input declaration dictionary for this node."""
        available_devices = get_available_devices()
        return {
            "required": {
                "model": (MODEL_CHOICES, {"default": DEFAULT_MODEL_CHOICE, "tooltip": "Select VoxCPM model version"}),
                "text": ("STRING", {"multiline": True, "default": "Hello, this is VoxCPM TTS.", "tooltip": "Text to synthesize into speech"}),
            },
            "optional": {
                "reference_audio": ("AUDIO", {"tooltip": "Reference audio for voice cloning"}),
                "reference_text": ("STRING", {"multiline": True, "default": "", "tooltip": "Text transcript of reference audio (leave empty for auto-transcription)"}),
                "auto_transcribe_reference": ("BOOLEAN", {"default": False, "tooltip": "If reference text is empty, auto-run ASR to fill it"}),
                "unload_model": ("BOOLEAN", {"default": False, "tooltip": "Unload model after generation to free VRAM"}),
                "device": (available_devices, {"default": available_devices[0], "tooltip": "Device to run the model on"}),
                "seed": ("INT", {"default": -1, "min": -1, "max": 0xffffffffffffffff, "tooltip": "Random seed (-1 for random)"}),
            },
        }

    RETURN_TYPES = ("STRING", "AUDIO")
    RETURN_NAMES = ("REFERENCE_TEXT", "AUDIO")
    FUNCTION = "generate"
    CATEGORY = "🧪AILab/🔊TTS/VoxCPMTTS"

    @staticmethod
    def _resolve_device(device):
        """Translate the ``device`` option into a ``torch.device``.

        A requested accelerator that is not available falls back to the CPU.
        """
        if device == "auto":
            return get_best_device()
        if device == "cuda" and torch.cuda.is_available():
            return model_management.get_torch_device()
        mps_backend = getattr(torch.backends, "mps", None)
        if device == "mps" and mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @staticmethod
    def _normalize_reference_waveform(waveform):
        """Return ``waveform`` as a float CPU tensor with exactly three dims.

        Extra leading dimensions are reduced by keeping their first entry, a
        3-D input keeps only its first item along dim 0, a 2-D input with more
        than one row is averaged over dim 0, and 1-D / 0-D inputs are expanded
        to ``[1, 1, T]``.
        """
        waveform = waveform.detach().float()
        # Index instead of squeeze(0): squeeze is a no-op on a leading
        # dimension larger than one, which would never leave this loop.
        while waveform.dim() > 3:
            waveform = waveform[0]
        if waveform.dim() == 3:
            if waveform.shape[0] > 1:
                waveform = waveform[0:1, ...]
        elif waveform.dim() == 2:
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            waveform = waveform.unsqueeze(0)
        elif waveform.dim() == 1:
            waveform = waveform.unsqueeze(0).unsqueeze(0)
        else:
            waveform = waveform.reshape(1, 1, -1)
        return waveform.contiguous().cpu()

    @staticmethod
    def _remove_file(path):
        """Delete ``path`` if it exists; removal errors are ignored."""
        if path and os.path.exists(path):
            try:
                os.unlink(path)
            except OSError:
                pass

    @classmethod
    def _write_temp_wav(cls, waveform, sample_rate):
        """Write row ``[0, 0]`` of ``waveform`` to a temporary WAV file.

        Returns the file path; the caller is responsible for deleting it. A
        partially written file is removed before the error is re-raised.
        """
        import tempfile

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            wav_path = tmp_file.name
        try:
            sf.write(wav_path, waveform[0, 0].numpy(), sample_rate)
        except Exception:
            cls._remove_file(wav_path)
            raise
        return wav_path

    @staticmethod
    def _transcribe(wav_path):
        """Transcribe ``wav_path`` with the shared, lazily loaded ASR model.

        faster-whisper is tried first and openai-whisper second; the model
        name comes from the ``VOXCPM_ASR_MODEL`` environment variable
        (default ``"small"``). Returns the stripped transcript, or ``""``
        when no backend could be loaded.
        """
        global _asr_model, _asr_backend
        if _asr_model is None:
            model_name = os.environ.get("VOXCPM_ASR_MODEL", "small")
            # Prefer faster-whisper; fallback to openai-whisper
            try:
                from faster_whisper import WhisperModel  # type: ignore
                asr_device = "cuda" if torch.cuda.is_available() else "cpu"
                compute_type = "float16" if torch.cuda.is_available() else "int8"
                print(f"[ASR] Loading faster-whisper: {model_name} on {asr_device}/{compute_type}")
                _asr_model = WhisperModel(model_name, device=asr_device, compute_type=compute_type)
                _asr_backend = "faster_whisper"
            except Exception:
                try:
                    import whisper  # type: ignore
                    print(f"[ASR] Loading openai-whisper: {model_name}")
                    _asr_model = whisper.load_model(model_name)
                    _asr_backend = "openai_whisper"
                except Exception as e_load:
                    print(f"[ASR] Failed to load ASR backends: {e_load}")
                    _asr_backend = None

        text_out = ""
        if _asr_backend == "faster_whisper":
            segments, _ = _asr_model.transcribe(wav_path, language=None, beam_size=1)
            text_out = " ".join([seg.text for seg in segments]).strip()
        elif _asr_backend == "openai_whisper":
            res = _asr_model.transcribe(wav_path)
            text_out = (res.get("text", "") or "").strip()
        print(f"[ASR] Transcription: {text_out[:120]}" + ("..." if len(text_out) > 120 else ""))
        return text_out

    @staticmethod
    def _to_audio_tensor(wav, fade_in_ms, sample_rate):
        """Convert generated audio to a float CPU tensor and fade it in.

        1-D input becomes ``[1, 1, T]`` and 2-D input gains one leading
        dimension. A linear 0 -> 1 ramp covering ``fade_in_ms`` milliseconds
        is applied in place when the audio is longer than the ramp.
        """
        if isinstance(wav, np.ndarray):
            audio = torch.from_numpy(wav).float()
        else:
            # .cpu() also covers non-CUDA accelerators such as MPS.
            audio = wav.detach().float().cpu()
        if audio.dim() == 1:
            audio = audio.unsqueeze(0).unsqueeze(0)
        elif audio.dim() == 2:
            audio = audio.unsqueeze(0)
        try:
            n = int(max(0, fade_in_ms) * sample_rate / 1000)
            if n > 0 and audio.shape[-1] > n:
                ramp = torch.linspace(0.0, 1.0, n, dtype=audio.dtype, device=audio.device)
                audio[..., :n] *= ramp
        except Exception as e:
            # The fade is cosmetic; keep the audio even if it cannot be applied.
            print(f"[WARN] Failed to apply fade-in: {e}")
        return audio

    def generate(self, model, text, device="auto", seed=-1, reference_audio=None, reference_text="", auto_transcribe_reference=False, unload_model=False):
        """Synthesize ``text`` and return ``(reference_text, audio)``.

        Args:
            model: Label from ``MODEL_CHOICES`` (or a repo id).
            text: Text to synthesize; must not be blank.
            device: ``"auto"``, ``"cuda"``, ``"mps"`` or ``"cpu"``.
            seed: Torch seed; ``-1`` leaves the RNG state untouched.
            reference_audio: Optional dict with ``"waveform"`` and
                ``"sample_rate"`` used as the voice-cloning prompt.
            reference_text: Transcript of ``reference_audio``.
            auto_transcribe_reference: Run ASR on the reference when
                ``reference_text`` is blank.
            unload_model: Drop the model from the cache after generation.

        Returns:
            Tuple of the reference transcript actually used (``""`` without a
            reference) and a dict with ``"waveform"`` and ``"sample_rate"``.

        Raises:
            ValueError: If ``text`` is blank, or if reference audio is given
                and no transcript is supplied or could be produced.
        """
        # Hidden presets for simplified node
        cfg_value = 2.0
        inference_steps = 10
        max_length = 4096
        fade_in_ms = 20
        retry_attempts = 2
        retry_threshold = 8.0
        normalize = True

        # Validate before the (potentially multi-gigabyte) model is loaded.
        text = text.strip()
        if not text:
            raise ValueError("Please input text to synthesize.")

        if seed != -1:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        model_management.soft_empty_cache()
        repo_id = MODEL_REPO_MAP.get(model, MODEL_REPO_MAP.get(model.strip(), model))
        tts = load_model(model_repo=repo_id)
        target_sr = 44100 if "1.5" in repo_id else 16000

        target_device = self._resolve_device(device)
        try:
            if hasattr(tts, "tts_model"):
                tts.tts_model.to(target_device)
                if hasattr(tts.tts_model, "audio_vae"):
                    tts.tts_model.audio_vae.to(target_device)
        except Exception as e:
            print(f"[WARN] Failed to move VoxCPM model to {target_device}: {e}")

        prompt_wav_path = None
        prompt_text = None
        try:
            if reference_audio:
                prompt_waveform = self._normalize_reference_waveform(reference_audio["waveform"])
                prompt_sample_rate = int(reference_audio["sample_rate"])
                # One temporary file serves both the ASR pass and the prompt.
                try:
                    prompt_wav_path = self._write_temp_wav(prompt_waveform, prompt_sample_rate)
                except Exception as e:
                    print(f"[WARN] Failed to write prompt waveform: {e}")

                if isinstance(reference_text, str) and reference_text.strip():
                    prompt_text = reference_text.strip()
                elif auto_transcribe_reference and prompt_wav_path is not None:
                    try:
                        prompt_text = self._transcribe(prompt_wav_path) or None
                    except Exception as e:
                        prompt_text = None
                        print(f"[ASR] Failed: {e}")

                if prompt_text is None:
                    raise ValueError("reference_text is required when providing reference_audio (or enable auto_transcribe_reference and ensure ASR succeeds)")

            wav = tts.generate(
                text=text,
                prompt_wav_path=prompt_wav_path,
                prompt_text=prompt_text,
                cfg_value=float(cfg_value),
                inference_timesteps=int(inference_steps),
                max_len=max_length,
                normalize=normalize,
                retry_badcase=retry_attempts > 0,
                retry_badcase_max_times=retry_attempts,
                retry_badcase_ratio_threshold=retry_threshold,
            )
        finally:
            self._remove_file(prompt_wav_path)

        audio_tensor = self._to_audio_tensor(wav, fade_in_ms, target_sr)
        used_reference_text = prompt_text if isinstance(prompt_text, str) else ""
        result = (used_reference_text, {"waveform": audio_tensor, "sample_rate": target_sr})
        if unload_model:
            unload_model_instance(repo_id)
        return result
class AILab_VoxCPMTTS_Advanced:
    """VoxCPM text-to-speech node exposing every generation setting.

    In addition to the model, text, reference inputs, device and seed, this
    node lets the caller set guidance scale, step count, maximum length, text
    normalization, fade-in duration, retry behaviour and log verbosity.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input declaration dictionary for this node."""
        available_devices = get_available_devices()
        return {
            "required": {
                "model": (MODEL_CHOICES, {"default": DEFAULT_MODEL_CHOICE, "tooltip": "Select VoxCPM model version"}),
                "text": ("STRING", {"multiline": True, "default": "Hello, this is VoxCPM TTS.", "tooltip": "Text to synthesize into speech"}),
                "cfg_value": ("FLOAT", {"default": 2.0, "min": 1.0, "max": 10.0, "step": 0.1, "tooltip": "Guidance scale: higher = more adherence to prompt, lower = more natural"}),
                "inference_steps": ("INT", {"default": 10, "min": 1, "max": 100, "step": 1, "tooltip": "Number of diffusion steps: higher = better quality, lower = faster"}),
                "max_length": ("INT", {"default": 4096, "min": 256, "max": 8192, "step": 256, "tooltip": "Maximum token length during generation"}),
                "normalize": ("BOOLEAN", {"default": True, "tooltip": "Enable text normalization for numbers, punctuation, etc."}),
            },
            "optional": {
                "reference_audio": ("AUDIO", {"tooltip": "Reference audio for voice cloning"}),
                "reference_text": ("STRING", {"multiline": True, "default": "", "tooltip": "Text transcript of reference audio (leave empty for auto-transcription)"}),
                "fade_in_ms": ("INT", {"default": 20, "min": 0, "max": 1000, "step": 5, "tooltip": "Fade-in duration to reduce initial artifacts (0-1000ms)"}),
                "auto_transcribe_reference": ("BOOLEAN", {"default": False, "tooltip": "Automatically transcribe reference audio when no text is provided"}),
                "show_transcription_log": ("BOOLEAN", {"default": True, "tooltip": "Show ASR transcription logs in console"}),
                "unload_model": ("BOOLEAN", {"default": False, "tooltip": "Unload model after generation to free VRAM"}),
                "retry_attempts": ("INT", {"default": 2, "min": 0, "max": 10, "step": 1, "tooltip": "Number of retry attempts for bad cases"}),
                "retry_threshold": ("FLOAT", {"default": 8.0, "min": 2.0, "max": 20.0, "step": 0.1, "tooltip": "Audio-to-text ratio threshold for retry detection"}),
                "device": (available_devices, {"default": available_devices[0], "tooltip": "Device to run the model on"}),
                "seed": ("INT", {"default": -1, "min": -1, "max": 0xffffffffffffffff, "tooltip": "Random seed (-1 for random)"}),
            },
        }

    RETURN_TYPES = ("AUDIO", "STRING")
    FUNCTION = "generate"
    CATEGORY = "🧪AILab/🔊TTS/VoxCPMTTS"

    @staticmethod
    def _resolve_device(device):
        """Translate the ``device`` option into a ``torch.device``.

        A requested accelerator that is not available falls back to the CPU.
        """
        if device == "auto":
            return get_best_device()
        if device == "cuda" and torch.cuda.is_available():
            return model_management.get_torch_device()
        mps_backend = getattr(torch.backends, "mps", None)
        if device == "mps" and mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @staticmethod
    def _normalize_reference_waveform(waveform):
        """Return ``waveform`` as a float CPU tensor with exactly three dims.

        Extra leading dimensions are reduced by keeping their first entry, a
        3-D input keeps only its first item along dim 0, a 2-D input with more
        than one row is averaged over dim 0, and 1-D / 0-D inputs are expanded
        to ``[1, 1, T]``.
        """
        waveform = waveform.detach().float()
        # Index instead of squeeze(0): squeeze is a no-op on a leading
        # dimension larger than one, which would never leave this loop.
        while waveform.dim() > 3:
            waveform = waveform[0]
        if waveform.dim() == 3:
            if waveform.shape[0] > 1:
                waveform = waveform[0:1, ...]
        elif waveform.dim() == 2:
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            waveform = waveform.unsqueeze(0)
        elif waveform.dim() == 1:
            waveform = waveform.unsqueeze(0).unsqueeze(0)
        else:
            waveform = waveform.reshape(1, 1, -1)
        return waveform.contiguous().cpu()

    @staticmethod
    def _remove_file(path):
        """Delete ``path`` if it exists; removal errors are ignored."""
        if path and os.path.exists(path):
            try:
                os.unlink(path)
            except OSError:
                pass

    @classmethod
    def _write_temp_wav(cls, waveform, sample_rate):
        """Write row ``[0, 0]`` of ``waveform`` to a temporary WAV file.

        Returns the file path; the caller is responsible for deleting it. A
        partially written file is removed before the error is re-raised.
        """
        import tempfile

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            wav_path = tmp_file.name
        try:
            sf.write(wav_path, waveform[0, 0].numpy(), sample_rate)
        except Exception:
            cls._remove_file(wav_path)
            raise
        return wav_path

    @staticmethod
    def _transcribe(wav_path, verbose=True):
        """Transcribe ``wav_path`` with the shared, lazily loaded ASR model.

        faster-whisper is tried first and openai-whisper second; the model
        name comes from the ``VOXCPM_ASR_MODEL`` environment variable
        (default ``"small"``). Console output is emitted only when
        ``verbose`` is true. Returns the stripped transcript, or ``""`` when
        no backend could be loaded.
        """
        global _asr_model, _asr_backend
        if _asr_model is None:
            model_name = os.environ.get("VOXCPM_ASR_MODEL", "small")
            try:
                from faster_whisper import WhisperModel  # type: ignore
                asr_device = "cuda" if torch.cuda.is_available() else "cpu"
                compute_type = "float16" if torch.cuda.is_available() else "int8"
                if verbose:
                    print(f"[ASR] Loading faster-whisper: {model_name} on {asr_device}/{compute_type}")
                _asr_model = WhisperModel(model_name, device=asr_device, compute_type=compute_type)
                _asr_backend = "faster_whisper"
            except Exception:
                try:
                    import whisper  # type: ignore
                    if verbose:
                        print(f"[ASR] Loading openai-whisper: {model_name}")
                    _asr_model = whisper.load_model(model_name)
                    _asr_backend = "openai_whisper"
                except Exception as e_load:
                    if verbose:
                        print(f"[ASR] Failed to load ASR backends: {e_load}")
                    _asr_backend = None

        text_out = ""
        if _asr_backend == "faster_whisper":
            segments, _ = _asr_model.transcribe(wav_path, language=None, beam_size=1)
            text_out = " ".join([seg.text for seg in segments]).strip()
        elif _asr_backend == "openai_whisper":
            res = _asr_model.transcribe(wav_path)
            text_out = (res.get("text", "") or "").strip()
        if verbose:
            print(f"[ASR] Transcription: {text_out[:120]}" + ("..." if len(text_out) > 120 else ""))
        return text_out

    @staticmethod
    def _to_audio_tensor(wav, fade_in_ms, sample_rate):
        """Convert generated audio to a float CPU tensor and fade it in.

        1-D input becomes ``[1, 1, T]`` and 2-D input gains one leading
        dimension. A linear 0 -> 1 ramp covering ``fade_in_ms`` milliseconds
        is applied in place when the audio is longer than the ramp.
        """
        if isinstance(wav, np.ndarray):
            audio = torch.from_numpy(wav).float()
        else:
            # .cpu() also covers non-CUDA accelerators such as MPS.
            audio = wav.detach().float().cpu()
        if audio.dim() == 1:
            audio = audio.unsqueeze(0).unsqueeze(0)
        elif audio.dim() == 2:
            audio = audio.unsqueeze(0)
        try:
            n = int(max(0, fade_in_ms) * sample_rate / 1000)
            if n > 0 and audio.shape[-1] > n:
                ramp = torch.linspace(0.0, 1.0, n, dtype=audio.dtype, device=audio.device)
                audio[..., :n] *= ramp
        except Exception as e:
            # The fade is cosmetic; keep the audio even if it cannot be applied.
            print(f"[WARN] Failed to apply fade-in: {e}")
        return audio

    def generate(
        self,
        model,
        text,
        cfg_value,
        inference_steps,
        max_length,
        normalize,
        seed=-1,
        device="auto",
        retry_attempts=2,
        retry_threshold=8.0,
        reference_audio=None,
        reference_text="",
        fade_in_ms=20,
        auto_transcribe_reference=False,
        show_transcription_log=True,
        unload_model=False,
    ):
        """Synthesize ``text`` and return ``(audio, reference_text)``.

        Args:
            model: Label from ``MODEL_CHOICES`` (or a repo id).
            text: Text to synthesize; must not be blank.
            cfg_value: Guidance scale passed to the model.
            inference_steps: Number of inference timesteps.
            max_length: Maximum generation length passed as ``max_len``.
            normalize: Whether the model should normalize the text.
            seed: Torch seed; ``-1`` leaves the RNG state untouched.
            device: ``"auto"``, ``"cuda"``, ``"mps"`` or ``"cpu"``.
            retry_attempts: Maximum retries; ``0`` disables retrying.
            retry_threshold: Ratio threshold used by the retry check.
            reference_audio: Optional dict with ``"waveform"`` and
                ``"sample_rate"`` used as the voice-cloning prompt.
            reference_text: Transcript of ``reference_audio``.
            fade_in_ms: Length of the linear fade-in in milliseconds.
            auto_transcribe_reference: Run ASR on the reference when
                ``reference_text`` is blank.
            show_transcription_log: Print ASR and reference-related messages.
            unload_model: Drop the model from the cache after generation.

        Returns:
            Tuple of a dict with ``"waveform"`` and ``"sample_rate"`` and the
            reference transcript actually used (``""`` without a reference).

        Raises:
            ValueError: If ``text`` is blank, or if reference audio is given
                and no transcript is supplied or could be produced.
        """
        # Validate before the (potentially multi-gigabyte) model is loaded.
        text = text.strip()
        if not text:
            raise ValueError("Please input text to synthesize.")

        if seed != -1:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        model_management.soft_empty_cache()
        repo_id = MODEL_REPO_MAP.get(model, MODEL_REPO_MAP.get(model.strip(), model))
        tts = load_model(model_repo=repo_id)
        target_sr = 44100 if "1.5" in repo_id else 16000

        target_device = self._resolve_device(device)
        try:
            if hasattr(tts, "tts_model"):
                tts.tts_model.to(target_device)
                if hasattr(tts.tts_model, "audio_vae"):
                    tts.tts_model.audio_vae.to(target_device)
        except Exception as e:
            print(f"[WARN] Failed to move VoxCPM model to {target_device}: {e}")

        prompt_wav_path = None
        prompt_text = None
        try:
            if reference_audio:
                prompt_waveform = self._normalize_reference_waveform(reference_audio["waveform"])
                prompt_sample_rate = int(reference_audio["sample_rate"])
                # One temporary file serves both the ASR pass and the prompt.
                try:
                    prompt_wav_path = self._write_temp_wav(prompt_waveform, prompt_sample_rate)
                except Exception as e:
                    if show_transcription_log:
                        print(f"[WARN] Failed to write prompt waveform: {e}")

                if isinstance(reference_text, str) and reference_text.strip():
                    prompt_text = reference_text.strip()
                elif auto_transcribe_reference and prompt_wav_path is not None:
                    try:
                        prompt_text = self._transcribe(prompt_wav_path, show_transcription_log) or None
                    except Exception as e:
                        prompt_text = None
                        if show_transcription_log:
                            print(f"[ASR] Failed: {e}")

                if prompt_text is None:
                    if show_transcription_log:
                        print("[ERROR] reference_text is required when providing reference_audio (or enable auto_transcribe_reference and ensure ASR succeeds)")
                    raise ValueError("reference_text is required when providing reference_audio (or enable auto_transcribe_reference and ensure ASR succeeds)")

            wav = tts.generate(
                text=text,
                prompt_wav_path=prompt_wav_path,
                prompt_text=prompt_text,
                cfg_value=float(cfg_value),
                inference_timesteps=int(inference_steps),
                max_len=max_length,
                normalize=normalize,
                retry_badcase=retry_attempts > 0,
                retry_badcase_max_times=retry_attempts,
                retry_badcase_ratio_threshold=retry_threshold,
            )
        finally:
            self._remove_file(prompt_wav_path)

        audio_tensor = self._to_audio_tensor(wav, fade_in_ms, target_sr)
        used_reference_text = prompt_text if isinstance(prompt_text, str) else ""
        result = ({"waveform": audio_tensor, "sample_rate": target_sr}, used_reference_text)
        if unload_model:
            unload_model_instance(repo_id)
        return result
# One row per exported node: (node id, implementing class, display title).
_VOXCPM_NODES = (
    ("AILab_VoxCPMTTS", AILab_VoxCPMTTS, "VoxCPM TTS"),
    ("AILab_VoxCPMTTS_Advanced", AILab_VoxCPMTTS_Advanced, "VoxCPM TTS (Advanced)"),
)

# Node id -> class implementing the node.
NODE_CLASS_MAPPINGS = {node_id: node_cls for node_id, node_cls, _ in _VOXCPM_NODES}

# Node id -> human-readable title.
NODE_DISPLAY_NAME_MAPPINGS = {node_id: title for node_id, _, title in _VOXCPM_NODES}