# Vibevoice_1_5_lora/data_vibevoice.py
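"""Data utilities for VibeVoice LoRA fine-tuning: audio loading and resampling
helpers, a thin dataset wrapper, and a collator that builds padded batches."""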
import math
import random
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Union

import numpy as np
import torch
try:
import librosa # type: ignore
except Exception: # pragma: no cover
librosa = None # Fallback: user must install librosa when using local audio paths
try:
import resampy # type: ignore
except Exception: # pragma: no cover
resampy = None
def _resample_if_needed(wav: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample `wav` to `target_sr`, preferring resampy, then librosa, else warn and pass through."""
if orig_sr == target_sr:
return wav.astype(np.float32, copy=False)
if resampy is not None:
return resampy.resample(wav.astype(np.float32), orig_sr, target_sr)
if librosa is not None:
return librosa.resample(y=wav.astype(np.float32), orig_sr=orig_sr, target_sr=target_sr)
warnings.warn(
"No resampler available; treating audio as target_sr without resampling. Install resampy or librosa.",
RuntimeWarning,
)
return wav.astype(np.float32, copy=False)
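# Example (minimal sketch): bring a 16 kHz clip up to the model's 24 kHz rate.
#   wav_16k = np.zeros(16000, dtype=np.float32)   # 1 s of silence at 16 kHz
#   wav_24k = _resample_if_needed(wav_16k, orig_sr=16000, target_sr=24000)
#   wav_24k.shape  # -> (24000,)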
# Lightweight HF-style dataset wrapper (optional). Trainer can also pass raw HF datasets directly.
class VibeVoiceDataset:
    """Normalizes items to {"text", "audio", "voice_prompts"}, auto-cropping a voice prompt when none is given."""
def __init__(
self,
dataset: Any,
text_column: str = "text",
audio_column: str = "audio",
voice_prompts_column: Optional[str] = "voice_prompts",
) -> None:
self.dataset = dataset
self.text_column = text_column
self.audio_column = audio_column
self.voice_prompts_column = voice_prompts_column
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, idx: int) -> Dict[str, Any]:
item = self.dataset[idx]
data: Dict[str, Any] = {}
data["text"] = item[self.text_column]
data["audio"] = item[self.audio_column]
user_provided_prompt = None
if self.voice_prompts_column and self.voice_prompts_column in item:
user_provided_prompt = item[self.voice_prompts_column]
if user_provided_prompt:
# A prompt was provided in the dataset, so we use it.
if not isinstance(user_provided_prompt, list):
data["voice_prompts"] = [user_provided_prompt]
else:
data["voice_prompts"] = user_provided_prompt
else:
# FALLBACK: No prompt provided, so we auto-generate one from the target audio.
try:
target_sr = 24000
wav_array = _load_audio_to_24k(item[self.audio_column], target_sr=target_sr)
audio_len_seconds = len(wav_array) / target_sr
                # Crop a random window of roughly 5-15 s (min_len_sec <= max_len_sec
                # always holds for these bounds, so no extra clamping is needed).
                min_len_sec = min(5.0, audio_len_seconds / 4.0)
                max_len_sec = min(15.0, audio_len_seconds / 2.0)
                if max_len_sec > 0.1:
                    prompt_len_sec = random.uniform(min_len_sec, max_len_sec)
                    prompt_len_samples = int(prompt_len_sec * target_sr)
                    # Guard against rounding pushing the crop past the clip end.
                    max_start_sample = max(0, len(wav_array) - prompt_len_samples)
                    start_sample = random.randint(0, max_start_sample)
                    prompt_crop = wav_array[start_sample : start_sample + prompt_len_samples]
data["voice_prompts"] = [prompt_crop]
else:
data["voice_prompts"] = None
except Exception as e:
warnings.warn(f"Could not create voice prompt for item {idx}: {e}")
data["voice_prompts"] = None
return data
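# Example usage (sketch; the dataset id and column names are assumptions):
#   from datasets import load_dataset
#   hf_ds = load_dataset("some-user/tts-corpus", split="train")  # hypothetical id
#   ds = VibeVoiceDataset(hf_ds, text_column="text", audio_column="audio")
#   ds[0]  # -> {"text": ..., "audio": ..., "voice_prompts": [...] or None}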
def _load_audio_to_24k(audio: Union[str, np.ndarray, torch.Tensor, Dict[str, Any]], *, target_sr: int = 24000) -> np.ndarray:
    """Normalize supported audio inputs to a float32 waveform at `target_sr`.

    Raw arrays/tensors are assumed to already be sampled at `target_sr`; only
    file paths and HF-style {"array", "sampling_rate"} dicts are resampled.
    """
    if isinstance(audio, np.ndarray):
        return audio.astype(np.float32)
    if isinstance(audio, torch.Tensor):
        return audio.detach().cpu().float().numpy()
if isinstance(audio, str):
if librosa is None:
raise RuntimeError("librosa is required to load audio file paths. Please pip install librosa.")
wav, sr = librosa.load(audio, sr=None, mono=True)
wav = _resample_if_needed(wav, int(sr), target_sr)
return wav
if isinstance(audio, dict) and "array" in audio and "sampling_rate" in audio:
arr = np.asarray(audio["array"], dtype=np.float32)
sr = int(audio["sampling_rate"])
arr = _resample_if_needed(arr, sr, target_sr)
return arr
raise ValueError(f"Unsupported audio type: {type(audio)}")
@dataclass
class VibeVoiceCollator:
    """Collates (text, audio, voice-prompt) examples into padded model inputs."""

    processor: Any  # VibeVoiceProcessor
max_length: Optional[int] = None
speech_compress_ratio: int = 3200
semantic_vae_dim: int = 128
compute_semantics: bool = False
debug_checks: bool = False
text_field: str = "text"
audio_field: str = "audio"
voice_prompts_field: str = "voice_prompts"
voice_prompt_drop_rate: float = 0.0
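    # Note: speech_compress_ratio is waveform samples per latent frame; at 24 kHz,
    # 3200 samples/frame gives 24000 / 3200 = 7.5 latent frames per second.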
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
batch_size = len(features)
sample_input_ids: List[List[int]] = []
sample_attention_masks: List[List[int]] = []
sample_acoustic_input_masks: List[List[bool]] = []
sample_acoustic_loss_masks: List[List[bool]] = []
all_speech_waveforms: List[np.ndarray] = []
all_speech_latent_lengths: List[int] = []
per_segment_is_target: List[bool] = []
for ex in features:
text: str = ex.get(self.text_field, "")
voice_prompts: Optional[List[Union[str, np.ndarray, torch.Tensor]]] = ex.get(self.voice_prompts_field)
target_audio: Union[str, np.ndarray, torch.Tensor, Dict[str, Any]] = ex.get(self.audio_field)
            # Clamp the configured drop rate into [0, 1] for safety.
            _drop_rate = min(max(self.voice_prompt_drop_rate, 0.0), 1.0)
proc = self.processor(
text=[text],
voice_samples=[voice_prompts] if voice_prompts is not None and random.random() >= _drop_rate else None,
padding=False,
truncation=False,
max_length=self.max_length,
return_tensors="pt",
)
ids = proc["input_ids"][0].tolist()
            attn_src = proc.get("attention_mask")  # may be absent or None depending on the processor
            attn = (attn_src if attn_src is not None else torch.ones_like(proc["input_ids"]))[0].tolist()
speech_input_mask = proc.get("speech_input_mask")
if speech_input_mask is None:
speech_input_mask = torch.zeros_like(proc["input_ids"], dtype=torch.bool)
speech_input_mask_list = speech_input_mask[0].tolist()
wav_target = _load_audio_to_24k(target_audio, target_sr=24000)
# Prefer exact frame count from acoustic tokenizer if available; fallback to compress ratio
target_latent_len = None
try:
acoustic_tok = getattr(self.processor, "acoustic_tokenizer", None)
if acoustic_tok is not None and hasattr(acoustic_tok, "encode"):
enc_out = acoustic_tok.encode(wav_target)
# Normalize various possible return formats to get time dimension
T = None
try:
# Direct array-like with shape (T, D) or (T,)
if hasattr(enc_out, "shape") and len(getattr(enc_out, "shape", [])) >= 1:
T = int(enc_out.shape[0])
else:
# Nested lists/tuples or ModelOutput-like
cand = enc_out
# Drill down a couple of levels safely
for _ in range(2):
if isinstance(cand, (list, tuple)) and len(cand) > 0:
cand = cand[0]
if hasattr(cand, "shape") and len(getattr(cand, "shape", [])) >= 1:
T = int(cand.shape[0])
except Exception:
T = None
if T is not None and T > 0:
target_latent_len = T
except Exception:
target_latent_len = None
if target_latent_len is None:
target_latent_len = max(1, int(math.ceil(len(wav_target) / float(self.speech_compress_ratio))))
speech_diff_id = self.processor.tokenizer.speech_diffusion_id
target_placeholders = [speech_diff_id] * target_latent_len
ids_extended = ids + target_placeholders
attn_extended = attn + [1] * target_latent_len
acoustic_input_mask = speech_input_mask_list + [True] * target_latent_len
acoustic_loss_mask = ([False] * len(speech_input_mask_list)) + [True] * target_latent_len
# Add speech_end_id token
speech_end_id = self.processor.tokenizer.speech_end_id
ids_extended.append(speech_end_id)
attn_extended.append(1)
acoustic_input_mask.append(False)
acoustic_loss_mask.append(False)
            # Append the tokenizer's EOS after speech_end_id so generation terminates properly.
eos_token_id = self.processor.tokenizer.eos_token_id
ids_extended.append(eos_token_id)
attn_extended.append(1)
acoustic_input_mask.append(False)
acoustic_loss_mask.append(False)
            if self.max_length is not None and len(ids_extended) > self.max_length:
                # Truncate from the left so the acoustic tail (placeholders + EOS) stays intact.
                cut = len(ids_extended) - int(self.max_length)
leading_non_acoustic = 0
for v in acoustic_input_mask:
if v:
break
leading_non_acoustic += 1
if cut > leading_non_acoustic:
raise ValueError(
f"--max_length={self.max_length} would truncate into acoustic tokens. "
f"Needed cut={cut}, but only {leading_non_acoustic} leading non-acoustic tokens available. "
"Increase max_length or shorten text/voice-prompt preamble."
)
ids_extended = ids_extended[cut:]
attn_extended = attn_extended[cut:]
acoustic_input_mask = acoustic_input_mask[cut:]
acoustic_loss_mask = acoustic_loss_mask[cut:]
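                # Worked example: len(ids_extended)=4100 with max_length=4096 drops
                # the 4 left-most (text) tokens, never the acoustic placeholders.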
sample_input_ids.append(ids_extended)
sample_attention_masks.append(attn_extended)
sample_acoustic_input_masks.append(acoustic_input_mask)
sample_acoustic_loss_masks.append(acoustic_loss_mask)
voice_speeches = []
voice_latent_lengths = []
if proc.get("speech_tensors") is not None:
voice_np = proc["speech_tensors"].cpu().numpy()
voice_masks = proc["speech_masks"].cpu().numpy().astype(bool)
for seg_idx in range(voice_np.shape[0]):
voice_speeches.append(voice_np[seg_idx])
voice_latent_lengths.append(int(voice_masks[seg_idx].sum()))
all_speech_waveforms.extend(voice_speeches)
all_speech_latent_lengths.extend(voice_latent_lengths)
per_segment_is_target.extend([False] * len(voice_speeches))
all_speech_waveforms.append(wav_target)
all_speech_latent_lengths.append(target_latent_len)
per_segment_is_target.append(True)
max_seq_len = max(len(x) for x in sample_input_ids)
padded_input_ids = []
padded_attention_masks = []
padded_acoustic_input_masks = []
padded_acoustic_loss_masks = []
tok = self.processor.tokenizer
pad_token_id = getattr(tok, "pad_token_id", None)
if pad_token_id is None or pad_token_id < 0:
pad_token_id = getattr(tok, "eos_token_id", None)
if pad_token_id is None or pad_token_id < 0:
raise ValueError(
"Tokenizer has no pad_token_id or eos_token_id; please set one or pass a valid pad id."
)
for ids, attn, ain_mask, aloss_mask in zip(
sample_input_ids, sample_attention_masks, sample_acoustic_input_masks, sample_acoustic_loss_masks
):
pad_len = max_seq_len - len(ids)
padded_input_ids.append(ids + [pad_token_id] * pad_len)
padded_attention_masks.append(attn + [0] * pad_len)
padded_acoustic_input_masks.append(ain_mask + [False] * pad_len)
padded_acoustic_loss_masks.append(aloss_mask + [False] * pad_len)
input_ids_tensor = torch.tensor(padded_input_ids, dtype=torch.long)
attention_mask_tensor = torch.tensor(padded_attention_masks, dtype=torch.long)
acoustic_input_mask_tensor = torch.tensor(padded_acoustic_input_masks, dtype=torch.bool)
acoustic_loss_mask_tensor = torch.tensor(padded_acoustic_loss_masks, dtype=torch.bool)
if all_speech_waveforms:
max_wave_len = max(w.shape[0] for w in all_speech_waveforms)
padded_speeches = np.zeros((len(all_speech_waveforms), max_wave_len), dtype=np.float32)
for i, w in enumerate(all_speech_waveforms):
L = w.shape[0]
padded_speeches[i, :L] = w
max_latent_len = max(all_speech_latent_lengths) if all_speech_latent_lengths else 1
speech_masks_np = np.zeros((len(all_speech_waveforms), max_latent_len), dtype=np.bool_)
for i, L_lat in enumerate(all_speech_latent_lengths):
speech_masks_np[i, :L_lat] = True
speech_tensors_tensor = torch.tensor(padded_speeches, dtype=torch.float32)
speech_masks_tensor = torch.tensor(speech_masks_np, dtype=torch.bool)
speeches_loss_input_np = np.zeros_like(speech_masks_np, dtype=np.bool_)
for i, is_target in enumerate(per_segment_is_target):
if is_target:
speeches_loss_input_np[i] = speech_masks_np[i]
speeches_loss_input_tensor = torch.tensor(speeches_loss_input_np, dtype=torch.bool)
# Semantic features
if self.compute_semantics and hasattr(self.processor, "semantic_tokenizer") and self.processor.semantic_tokenizer is not None:
sem_feats: List[np.ndarray] = []
for w in all_speech_waveforms:
try:
# Expect [T, D] where T ≈ ceil(len(w)/compress_ratio)
sem = self.processor.semantic_tokenizer.encode(w)
sem = np.asarray(sem, dtype=np.float32)
except Exception:
sem = np.zeros((0, self.semantic_vae_dim), dtype=np.float32)
if sem.ndim != 2:
raise RuntimeError(f"Semantic tokenizer returned unexpected shape {sem.shape}. Expect [T, D].")
L = sem.shape[0]
D = sem.shape[1]
if D != self.semantic_vae_dim:
if D < self.semantic_vae_dim:
pad_d = np.zeros((L, self.semantic_vae_dim - D), dtype=np.float32)
sem = np.concatenate([sem, pad_d], axis=1)
else:
sem = sem[:, : self.semantic_vae_dim]
if L < max_latent_len:
pad = np.zeros((max_latent_len - L, self.semantic_vae_dim), dtype=np.float32)
sem = np.concatenate([sem, pad], axis=0)
elif L > max_latent_len:
sem = sem[:max_latent_len]
sem_feats.append(sem.astype(np.float32))
speech_semantic_tensors = torch.tensor(np.stack(sem_feats, axis=0), dtype=torch.float32)
            elif self.compute_semantics:
                # Semantics were requested but no semantic tokenizer is available.
                # Raise rather than silently degrading alignment with zeroed features.
                raise RuntimeError(
                    "Semantic features were requested but could not be computed. "
                    "Ensure processor.semantic_tokenizer is available or precompute and provide features."
                )
            else:
                # Semantics not requested: emit zero features of the expected shape
                # so downstream code that expects the tensor still works.
                speech_semantic_tensors = torch.zeros(
                    (len(all_speech_waveforms), max_latent_len, self.semantic_vae_dim),
                    dtype=torch.float32,
                )
else:
speech_tensors_tensor = None
speech_masks_tensor = None
speeches_loss_input_tensor = None
speech_semantic_tensors = None # No segments in batch
if self.debug_checks:
assert (input_ids_tensor >= 0).all(), "input_ids contains negative indices"
if speech_tensors_tensor is not None:
assert speech_tensors_tensor.dim() == 2, "Expected speech_tensors 2D [segments, samples]"
return {
"input_ids": input_ids_tensor,
"attention_mask": attention_mask_tensor,
"speech_tensors": speech_tensors_tensor,
"speech_masks": speech_masks_tensor,
"speech_semantic_tensors": speech_semantic_tensors,
"acoustic_input_mask": acoustic_input_mask_tensor,
"acoustic_loss_mask": acoustic_loss_mask_tensor,
"speeches_loss_input": speeches_loss_input_tensor,
}
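# End-to-end sketch (hypothetical `processor` and `hf_ds`; adapt to your setup):
#   from torch.utils.data import DataLoader
#   collator = VibeVoiceCollator(processor=processor, max_length=4096)
#   loader = DataLoader(VibeVoiceDataset(hf_ds), batch_size=2, collate_fn=collator)
#   batch = next(iter(loader))  # dict of tensors as returned by __call__ above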