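"""Multi-speaker text-to-speech node for ComfyUI.

Generates conversations with up to four distinct voices using Microsoft
VibeVoice. Speakers are marked in the input text with "[N]:" labels (N = 1-4).
"""
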
import logging
import os
import re
import tempfile

import torch
import numpy as np
from typing import List, Optional

from .base_vibevoice import BaseVibeVoiceNode

logger = logging.getLogger("VibeVoice")


class VibeVoiceMultipleSpeakersNode(BaseVibeVoiceNode):

    def __init__(self):
        super().__init__()

        # Register with the optional free-memory node so the model cached by
        # this instance can be released on demand. Catching Exception (not a
        # bare except) keeps keyboard interrupts propagating.
        try:
            from .free_memory_node import VibeVoiceFreeMemoryNode
            VibeVoiceFreeMemoryNode.register_multi_speaker(self)
        except Exception:
            pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "text": ("STRING", {
                    "multiline": True,
                    "default": "[1]: Hello, this is the first speaker.\n[2]: Hi there, I'm the second speaker.\n[1]: Nice to meet you!\n[2]: Nice to meet you too!",
                    "tooltip": "Text with speaker labels. Use '[N]:' format where N is 1-4. Gets disabled when connected to another node.",
                    "forceInput": False,
                    "dynamicPrompts": True
                }),
                "model": (["VibeVoice-1.5B", "VibeVoice-Large", "VibeVoice-Large-Quant-4Bit", "VibeVoice-Large-Q8"], {
                    "default": "VibeVoice-Large-Q8",
                    "tooltip": "Model to use. Large is recommended for multi-speaker generation; Quant-4Bit uses less VRAM (CUDA only)"
                }),
                "attention_type": (["auto", "eager", "sdpa", "flash_attention_2", "sage"], {
                    "default": "auto",
                    "tooltip": "Attention implementation. Auto selects the best available; eager is standard; sdpa is optimized PyTorch; flash_attention_2 requires a compatible GPU; sage uses quantized attention for speedup (CUDA only)"
                }),
                "free_memory_after_generate": ("BOOLEAN", {"default": True, "tooltip": "Free model from memory after generation to save VRAM/RAM. Disable to keep the model loaded for faster subsequent generations"}),
                "diffusion_steps": ("INT", {"default": 20, "min": 5, "max": 100, "step": 1, "tooltip": "Number of denoising steps. More steps = better quality but slower. Default: 20"}),
                "seed": ("INT", {"default": 42, "min": 0, "max": 2**32 - 1, "tooltip": "Random seed for generation. Default 42 is used in official examples"}),
                "cfg_scale": ("FLOAT", {"default": 1.3, "min": 0.5, "max": 3.5, "step": 0.05, "tooltip": "Classifier-free guidance scale (official default: 1.3)"}),
                "use_sampling": ("BOOLEAN", {"default": False, "tooltip": "Enable sampling mode. When False (default), uses deterministic generation like the official examples"}),
            },
            "optional": {
                "speaker1_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 1. If not provided, a synthetic voice will be used."}),
                "speaker2_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 2. If not provided, a synthetic voice will be used."}),
                "speaker3_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 3. If not provided, a synthetic voice will be used."}),
                "speaker4_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 4. If not provided, a synthetic voice will be used."}),
                "temperature": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 2.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
                "top_p": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 1.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
            }
        }
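
    # The optional speakerN_voice inputs above correspond positionally to the
    # "[N]:" labels in the script: speaker1_voice is used for "[1]:", and so on.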

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate_speech"
    CATEGORY = "VibeVoiceWrapper"
    DESCRIPTION = "Generate multi-speaker conversations with up to 4 distinct voices using Microsoft VibeVoice"

    def _prepare_voice_sample(self, voice_audio, speaker_idx: int) -> Optional[np.ndarray]:
        """Prepare a single voice sample from input audio (speaker_idx is currently unused)."""
        return self._prepare_audio_from_comfyui(voice_audio)

    def generate_speech(self, text: str = "", model: str = "VibeVoice-Large-Q8",
                        attention_type: str = "auto", free_memory_after_generate: bool = True,
                        diffusion_steps: int = 20, seed: int = 42, cfg_scale: float = 1.3,
                        use_sampling: bool = False, speaker1_voice=None, speaker2_voice=None,
                        speaker3_voice=None, speaker4_voice=None,
                        temperature: float = 0.95, top_p: float = 0.95):
        """Generate multi-speaker speech from text using VibeVoice"""
        try:
            if not text or not text.strip():
                raise Exception("No text provided. Please enter text with speaker labels (e.g., '[1]: Hello' or '[2]: Hi')")

            # Collect the distinct speaker numbers used in "[N]:" labels.
            bracket_pattern = r'\[(\d+)\]\s*:'
            speakers_numbers = sorted(set(int(m) for m in re.findall(bracket_pattern, text)))
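            # Example: for the default script "[1]: Hello...\n[2]: Hi there..."
            # the pattern above yields speakers_numbers == [1, 2].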

            if not speakers_numbers:
                num_speakers = 1
            else:
                num_speakers = min(max(speakers_numbers), 4)
                if max(speakers_numbers) > 4:
                    print(f"[VibeVoice] Warning: Found {max(speakers_numbers)} speakers, limiting to 4")

            converted_text = text
            # Initialized up front so the pause-aware branch below can simply
            # test truthiness; the parsing branches that support pause keywords
            # fill this list.
            speaker_segments_with_pauses = []

            speakers_in_text = sorted(set(int(m) for m in re.findall(bracket_pattern, text)))

            if not speakers_in_text:
                # No "[N]:" labels; fall back to the "Speaker N:" format.
                speaker_pattern = r'Speaker\s+(\d+)\s*:'
                speakers_in_text = sorted(set(int(m) for m in re.findall(speaker_pattern, text)))

                if speakers_in_text:
                    # Convert 1-indexed "Speaker N:" labels to the 0-indexed form.
                    # Process in ascending order: since every label is decremented,
                    # ascending order guarantees a freshly written "Speaker N-1:"
                    # is never re-matched by a later substitution (descending
                    # order would collapse all labels into "Speaker 0:").
                    for speaker_num in sorted(speakers_in_text):
                        pattern = rf'Speaker\s+{speaker_num}\s*:'
                        replacement = f'Speaker {speaker_num - 1}:'
                        converted_text = re.sub(pattern, replacement, converted_text)
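                    # Worked example: "Speaker 1: a / Speaker 2: b" is rewritten
                    # pass by pass to "Speaker 0: a / Speaker 1: b".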
                else:
                    # No recognizable labels at all: treat everything as one speaker.
                    speakers_in_text = [1]

                    # Split the raw text into alternating text/pause segments.
                    pause_segments = self._parse_pause_keywords(text)

                    segments = []

                    for seg_type, seg_content in pause_segments:
                        if seg_type == 'pause':
                            speaker_segments_with_pauses.append(('pause', seg_content, None))
                        else:
                            # Flatten newlines and collapse repeated whitespace.
                            text_clean = seg_content.replace('\n', ' ').replace('\r', ' ')
                            text_clean = ' '.join(text_clean.split())

                            if text_clean:
                                speaker_segments_with_pauses.append(('text', text_clean, 1))
                                segments.append(f"Speaker 0: {text_clean}")

                    converted_text = '\n'.join(segments) if segments else f"Speaker 0: {text}"
            else:
                # "[N]:" labels were found: split the script at each label.
                segments = []

                # Match only the speaker numbers that actually appear.
                speaker_matches = list(re.finditer(
                    f'\\[({"|".join(map(str, speakers_in_text))})\\]\\s*:', converted_text))

                for i, match in enumerate(speaker_matches):
                    speaker_num = int(match.group(1))
                    start = match.end()

                    # A segment runs until the next label, or the end of the text.
                    if i + 1 < len(speaker_matches):
                        end = speaker_matches[i + 1].start()
                    else:
                        end = len(converted_text)

                    speaker_text = converted_text[start:end].strip()

                    # Each speaker's segment may itself contain pause keywords.
                    pause_segments = self._parse_pause_keywords(speaker_text)

                    for seg_type, seg_content in pause_segments:
                        if seg_type == 'pause':
                            speaker_segments_with_pauses.append(('pause', seg_content, None))
                        else:
                            # Flatten newlines and collapse repeated whitespace.
                            text_clean = seg_content.replace('\n', ' ').replace('\r', ' ')
                            text_clean = ' '.join(text_clean.split())

                            if text_clean:
                                speaker_segments_with_pauses.append(('text', text_clean, speaker_num))
                                segments.append(f'Speaker {speaker_num - 1}: {text_clean}')

                converted_text = '\n'.join(segments) if segments else ""
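
            # At this point converted_text uses the model's 0-indexed format,
            # e.g. "Speaker 0: Hello\nSpeaker 1: Hi there".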

            speakers = [f"Speaker {i}" for i in range(len(speakers_in_text))]

            # Resolve the UI model name to its path/identifier and load it.
            model_mapping = self._get_model_mapping()
            model_path = model_mapping.get(model, model)
            self.load_model(model, model_path, attention_type)

            voice_inputs = [speaker1_voice, speaker2_voice, speaker3_voice, speaker4_voice]

            # Build one reference voice per detected speaker, falling back to a
            # synthetic sample when no audio is connected or preparation fails.
            voice_samples = []
            for i, speaker_num in enumerate(speakers_in_text):
                idx = speaker_num - 1

                if idx < len(voice_inputs) and voice_inputs[idx] is not None:
                    voice_sample = self._prepare_voice_sample(voice_inputs[idx], idx)
                    if voice_sample is None:
                        voice_sample = self._create_synthetic_voice_sample(idx)
                else:
                    voice_sample = self._create_synthetic_voice_sample(idx)

                voice_samples.append(voice_sample)
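
            # Example: with a two-speaker script and only speaker1_voice
            # connected, speaker 2 receives a synthetic voice sample.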

            if len(voice_samples) != len(speakers_in_text):
                logger.error(f"Mismatch: {len(speakers_in_text)} speakers but {len(voice_samples)} voice samples!")
                raise Exception(f"Voice sample count mismatch: expected {len(speakers_in_text)}, got {len(voice_samples)}")

            if speaker_segments_with_pauses:
                # Pause-aware path: synthesize each text group separately and
                # splice silence in between.
                all_audio_segments = []
                sample_rate = 24000

                # Group consecutive segments from the same speaker so they can
                # be generated in a single pass.
                grouped_segments = []
                current_group = []
                current_speaker = None

                for seg_type, seg_content, speaker_num in speaker_segments_with_pauses:
                    if seg_type == 'pause':
                        # A pause always closes the current group.
                        if current_group:
                            grouped_segments.append(('text_group', current_group, current_speaker))
                            current_group = []
                            current_speaker = None

                        grouped_segments.append(('pause', seg_content, None))
                    else:
                        if speaker_num == current_speaker:
                            current_group.append(seg_content)
                        else:
                            if current_group:
                                grouped_segments.append(('text_group', current_group, current_speaker))
                            current_group = [seg_content]
                            current_speaker = speaker_num

                # Flush the final group, if any.
                if current_group:
                    grouped_segments.append(('text_group', current_group, current_speaker))
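
                # Worked example: [('text', a, 1), ('text', b, 1), ('pause', 500, None),
                # ('text', c, 2)] groups into [('text_group', [a, b], 1),
                # ('pause', 500, None), ('text_group', [c], 2)].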

                for seg_type, seg_content, speaker_num in grouped_segments:
                    if seg_type == 'pause':
                        duration_ms = seg_content
                        logger.info(f"Adding {duration_ms}ms pause")
                        silence_audio = self._generate_silence(duration_ms, sample_rate)
                        all_audio_segments.append(silence_audio)
                    else:
                        combined_text = ' '.join(seg_content)
                        formatted_text = f"Speaker {speaker_num - 1}: {combined_text}"

                        # Pass only this speaker's reference voice for the segment.
                        speaker_idx = speakers_in_text.index(speaker_num)
                        speaker_voice_samples = [voice_samples[speaker_idx]]

                        logger.info(f"Generating audio for Speaker {speaker_num}: {len(combined_text.split())} words")

                        segment_audio = self._generate_with_vibevoice(
                            formatted_text, speaker_voice_samples, cfg_scale, seed,
                            diffusion_steps, use_sampling, temperature, top_p
                        )

                        all_audio_segments.append(segment_audio)

                if all_audio_segments:
                    logger.info(f"Concatenating {len(all_audio_segments)} audio segments (including pauses)...")

                    # Pull the raw waveform tensors out of the AUDIO dicts.
                    waveforms = []
                    for audio_segment in all_audio_segments:
                        if isinstance(audio_segment, dict) and "waveform" in audio_segment:
                            waveforms.append(audio_segment["waveform"])

                    if waveforms:
                        valid_waveforms = [w for w in waveforms if w is not None]
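
                        # ComfyUI AUDIO waveforms are typically tensors shaped
                        # (batch, channels, samples), so concatenating on the
                        # last dimension appends the segments in time.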
                        if valid_waveforms:
                            # Concatenate along the time axis.
                            combined_waveform = torch.cat(valid_waveforms, dim=-1)

                            audio_dict = {
                                "waveform": combined_waveform,
                                "sample_rate": sample_rate
                            }
                            logger.info("Successfully generated multi-speaker audio with pauses")
                        else:
                            raise Exception("No valid audio waveforms generated")
                    else:
                        raise Exception("Failed to extract waveforms from audio segments")
                else:
                    raise Exception("No audio segments generated")
            else:
                # No pause segments were parsed: generate the whole conversation
                # in a single pass.
                logger.info("Processing without pause support (no pause keywords found)")
                audio_dict = self._generate_with_vibevoice(
                    converted_text, voice_samples, cfg_scale, seed, diffusion_steps,
                    use_sampling, temperature, top_p
                )

            if free_memory_after_generate:
                self.free_memory()

            return (audio_dict,)

        except Exception as e:
            # Let user interrupts propagate unchanged; wrap everything else in a
            # descriptive error.
            import comfy.model_management as mm
            if isinstance(e, mm.InterruptProcessingException):
                logger.info("Generation interrupted by user")
                raise
            else:
                logger.error(f"Multi-speaker speech generation failed: {str(e)}")
                raise Exception(f"Error generating multi-speaker speech: {str(e)}")

    @classmethod
    def IS_CHANGED(cls, text="", model="VibeVoice-Large-Q8",
                   speaker1_voice=None, speaker2_voice=None,
                   speaker3_voice=None, speaker4_voice=None, **kwargs):
        """Cache key for ComfyUI: the node re-executes only when this string changes."""
        voices_hash = hash(str([speaker1_voice, speaker2_voice, speaker3_voice, speaker4_voice]))
        return f"{hash(text)}_{model}_{voices_hash}_{kwargs.get('cfg_scale', 1.3)}_{kwargs.get('seed', 0)}"