File size: 17,353 Bytes
86e8346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import warnings
import random

try:
    import librosa  # type: ignore
except Exception:  # pragma: no cover
    librosa = None  # Fallback: user must install librosa when using local audio paths

try:
    import resampy  # type: ignore
except Exception:  # pragma: no cover
    resampy = None


def _resample_if_needed(wav: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Return ``wav`` as float32, resampled to ``target_sr`` when the rates differ.

    Resampling backends are tried in order: ``resampy`` first, then ``librosa``.
    If neither module is importable, a ``RuntimeWarning`` is emitted and the
    samples are returned unresampled (only cast to float32).

    Args:
        wav: Audio samples.
        orig_sr: Sampling rate of ``wav``.
        target_sr: Desired sampling rate.

    Returns:
        A float32 array. When no resampling happens the cast uses
        ``copy=False``, so an input that is already float32 is returned as-is.
    """
    rates_differ = orig_sr != target_sr

    if rates_differ and resampy is not None:
        return resampy.resample(wav.astype(np.float32), orig_sr, target_sr)

    if rates_differ and librosa is not None:
        return librosa.resample(y=wav.astype(np.float32), orig_sr=orig_sr, target_sr=target_sr)

    if rates_differ:
        # Neither backend is installed: keep going, but make the mismatch visible.
        warnings.warn(
            "No resampler available; treating audio as target_sr without resampling. Install resampy or librosa.",
            RuntimeWarning,
        )

    return wav.astype(np.float32, copy=False)


# Lightweight HF-style dataset wrapper (optional). Trainer can also pass raw HF datasets directly.
class VibeVoiceDataset:
    """Map-style wrapper exposing ``text``, ``audio`` and ``voice_prompts`` per item.

    Every item read from the wrapped ``dataset`` is converted into a dict with the
    keys ``"text"``, ``"audio"`` and ``"voice_prompts"``. When the item carries no
    usable voice prompt, a random crop of the item's own target audio is used as
    the prompt instead.
    """

    def __init__(
        self,
        dataset: Any,
        text_column: str = "text",
        audio_column: str = "audio",
        voice_prompts_column: Optional[str] = "voice_prompts",
    ) -> None:
        """Store the wrapped dataset and the column names read from each item.

        Args:
            dataset: Indexable, sized collection of mapping-like items.
            text_column: Key holding the transcript.
            audio_column: Key holding the target audio.
            voice_prompts_column: Key holding optional voice prompt(s); a falsy
                value disables the lookup so a prompt is always auto-generated.
        """
        self.dataset = dataset
        self.text_column = text_column
        self.audio_column = audio_column
        self.voice_prompts_column = voice_prompts_column

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    @staticmethod
    def _prompt_is_present(prompt: Any) -> bool:
        """Return True unless ``prompt`` is ``None`` or a zero-length container.

        Plain truthiness cannot be used here: ``bool()`` of a multi-element
        NumPy array or torch tensor raises ``ValueError``.
        """
        if prompt is None:
            return False
        try:
            return len(prompt) > 0
        except TypeError:
            # Unsized objects (e.g. 0-d arrays) are treated as a provided prompt.
            return True

    @staticmethod
    def _crop_prompt_from_target(audio: Any, target_sr: int = 24000) -> Optional[List[np.ndarray]]:
        """Cut a random excerpt out of the target audio to serve as a voice prompt.

        The excerpt length is drawn uniformly between ``min(5 s, len/4)`` and
        ``min(15 s, len/2)``; its start position is drawn uniformly over all
        positions that keep the excerpt inside the audio.

        Returns:
            A one-element list with the cropped samples, or ``None`` when the
            upper length bound is 0.1 s or less.
        """
        wav_array = _load_audio_to_24k(audio, target_sr=target_sr)
        audio_len_seconds = len(wav_array) / target_sr

        min_len_sec = min(5.0, audio_len_seconds / 4.0)
        max_len_sec = min(15.0, audio_len_seconds / 2.0)

        # Defensive clamps: keep the bounds ordered and inside the clip.
        if min_len_sec > max_len_sec:
            min_len_sec = max_len_sec
        max_len_sec = min(max_len_sec, audio_len_seconds)

        if max_len_sec <= 0.1:
            return None

        prompt_len_sec = random.uniform(min_len_sec, max_len_sec)
        prompt_len_samples = int(prompt_len_sec * target_sr)

        max_start_sample = len(wav_array) - prompt_len_samples
        start_sample = random.randint(0, max_start_sample)

        return [wav_array[start_sample : start_sample + prompt_len_samples]]

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the example at ``idx`` as a ``text``/``audio``/``voice_prompts`` dict.

        ``voice_prompts`` is a list of prompts, or ``None`` when no prompt was
        provided and none could be generated from the target audio (a warning
        is emitted if generation failed with an exception).
        """
        item = self.dataset[idx]
        data: Dict[str, Any] = {
            "text": item[self.text_column],
            "audio": item[self.audio_column],
        }

        user_provided_prompt = None
        if self.voice_prompts_column and self.voice_prompts_column in item:
            user_provided_prompt = item[self.voice_prompts_column]

        if self._prompt_is_present(user_provided_prompt):
            # A prompt was provided in the dataset, so we use it (wrapped in a list if needed).
            if isinstance(user_provided_prompt, list):
                data["voice_prompts"] = user_provided_prompt
            else:
                data["voice_prompts"] = [user_provided_prompt]
            return data

        # FALLBACK: No prompt provided, so we auto-generate one from the target audio.
        # Best effort by design: any failure degrades to "no prompt" with a warning.
        try:
            data["voice_prompts"] = self._crop_prompt_from_target(data["audio"])
        except Exception as e:
            warnings.warn(f"Could not create voice prompt for item {idx}: {e}")
            data["voice_prompts"] = None
        return data


def _load_audio_to_24k(audio: Union[str, np.ndarray, torch.Tensor, Dict[str, Any]], *, target_sr: int = 24000) -> np.ndarray:
    """Convert one of the supported audio representations into a float32 NumPy array.

    Supported inputs:
        * ``np.ndarray`` / ``torch.Tensor``: cast to float32 only. No sampling
          rate is available for these, so no resampling is performed.
        * ``str``: treated as a file path, loaded as mono with librosa at its
          native rate and then resampled to ``target_sr``.
        * ``dict`` with ``"array"`` and ``"sampling_rate"`` keys: the array is
          resampled from the stated rate to ``target_sr``.

    Raises:
        RuntimeError: ``audio`` is a path but librosa is not installed.
        ValueError: ``audio`` is none of the supported representations.
    """
    # In-memory samples: nothing to load, only normalise the dtype.
    if isinstance(audio, np.ndarray):
        return audio.astype(np.float32)
    if isinstance(audio, torch.Tensor):
        return audio.detach().cpu().float().numpy()

    if isinstance(audio, str):
        if librosa is None:
            raise RuntimeError("librosa is required to load audio file paths. Please pip install librosa.")
        samples, native_sr = librosa.load(audio, sr=None, mono=True)
        return _resample_if_needed(samples, int(native_sr), target_sr)

    has_decoded_fields = isinstance(audio, dict) and "array" in audio and "sampling_rate" in audio
    if not has_decoded_fields:
        raise ValueError(f"Unsupported audio type: {type(audio)}")

    samples = np.asarray(audio["array"], dtype=np.float32)
    return _resample_if_needed(samples, int(audio["sampling_rate"]), target_sr)


@dataclass
class VibeVoiceCollator:
    """Collate raw examples into padded tensors for one training batch.

    For each example the processor output (text plus optional voice prompts) is
    extended with one ``speech_diffusion_id`` placeholder per target latent frame,
    followed by ``speech_end_id`` and ``eos_token_id``. Sequences are then
    right-padded to the longest one in the batch, and all speech segments
    (voice prompts first, then the target, per example) are stacked and
    zero-padded.

    NOTE: every example contributes a target segment, and the semantic branch
    raises unless ``compute_semantics`` is True and ``processor.semantic_tokenizer``
    is set. With the default ``compute_semantics=False`` a call therefore always
    ends in ``RuntimeError``.
    """

    processor: Any  # VibeVoiceProcessor
    max_length: Optional[int] = None  # Max tokens per sample; excess is cut from the left
    speech_compress_ratio: int = 3200  # Samples per latent frame, used when the tokenizer cannot tell
    semantic_vae_dim: int = 128  # Feature width that semantic features are padded/cropped to
    compute_semantics: bool = False  # Must be True for __call__ to succeed (see class NOTE)
    debug_checks: bool = False  # Enables extra sanity assertions on the output

    text_field: str = "text"
    audio_field: str = "audio"
    voice_prompts_field: str = "voice_prompts"
    voice_prompt_drop_rate: float = 0.0  # Probability of dropping the voice prompts of an example

    def _target_latent_len(self, wav_target: np.ndarray) -> int:
        """Return the number of latent frames to reserve for ``wav_target``.

        Prefers the leading dimension of ``processor.acoustic_tokenizer.encode``'s
        output (looking up to two levels into nested lists/tuples). Any failure,
        or a non-positive result, falls back to
        ``ceil(len(wav_target) / speech_compress_ratio)``, with a minimum of 1.
        """
        frames = None
        try:
            acoustic_tok = getattr(self.processor, "acoustic_tokenizer", None)
            if acoustic_tok is not None and hasattr(acoustic_tok, "encode"):
                enc_out = acoustic_tok.encode(wav_target)
                cand = enc_out
                # Direct array-like with shape (T, D) or (T,) is used as-is;
                # otherwise drill down a couple of levels (nested lists/tuples).
                if not (hasattr(enc_out, "shape") and len(getattr(enc_out, "shape", [])) >= 1):
                    for _ in range(2):
                        if isinstance(cand, (list, tuple)) and len(cand) > 0:
                            cand = cand[0]
                if hasattr(cand, "shape") and len(getattr(cand, "shape", [])) >= 1:
                    frames = int(cand.shape[0])
        except Exception:
            # Deliberate best-effort probe: fall back to the compress-ratio estimate.
            frames = None
        if frames is not None and frames > 0:
            return frames
        return max(1, int(math.ceil(len(wav_target) / float(self.speech_compress_ratio))))

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        """Build the batch dict for ``features``.

        Args:
            features: Non-empty sequence of example dicts providing ``text_field``,
                ``audio_field`` and optionally ``voice_prompts_field``.

        Returns:
            Dict with ``input_ids`` / ``attention_mask`` / ``acoustic_input_mask`` /
            ``acoustic_loss_mask`` of shape ``[batch, max_seq_len]``,
            ``speech_tensors`` ``[segments, max_wave_len]``, ``speech_masks`` and
            ``speeches_loss_input`` ``[segments, max_latent_len]``, and
            ``speech_semantic_tensors`` ``[segments, max_latent_len, semantic_vae_dim]``.

        Raises:
            ValueError: Empty batch, ``max_length`` would cut into acoustic tokens,
                or the tokenizer has neither a pad nor an EOS token id.
            RuntimeError: Semantic features are unavailable or have an unexpected rank.
        """
        if len(features) == 0:
            raise ValueError("VibeVoiceCollator received an empty batch; at least one example is required.")

        sample_input_ids: List[List[int]] = []
        sample_attention_masks: List[List[int]] = []
        sample_acoustic_input_masks: List[List[bool]] = []
        sample_acoustic_loss_masks: List[List[bool]] = []

        all_speech_waveforms: List[np.ndarray] = []
        all_speech_latent_lengths: List[int] = []
        per_segment_is_target: List[bool] = []

        # Clamp drop rate for safety (loop-invariant, so done once)
        _drop_rate = self.voice_prompt_drop_rate
        if _drop_rate < 0.0:
            _drop_rate = 0.0
        elif _drop_rate > 1.0:
            _drop_rate = 1.0

        for ex in features:
            text: str = ex.get(self.text_field, "")
            voice_prompts: Optional[List[Union[str, np.ndarray, torch.Tensor]]] = ex.get(self.voice_prompts_field)
            target_audio: Union[str, np.ndarray, torch.Tensor, Dict[str, Any]] = ex.get(self.audio_field)

            proc = self.processor(
                text=[text],
                # Prompts are randomly dropped with probability `_drop_rate`.
                voice_samples=[voice_prompts] if voice_prompts is not None and random.random() >= _drop_rate else None,
                padding=False,
                truncation=False,
                max_length=self.max_length,
                return_tensors="pt",
            )

            ids = proc["input_ids"][0].tolist()
            attn = proc.get("attention_mask", torch.ones_like(proc["input_ids"]))[0].tolist()
            speech_input_mask = proc.get("speech_input_mask")
            if speech_input_mask is None:
                speech_input_mask = torch.zeros_like(proc["input_ids"], dtype=torch.bool)
            speech_input_mask_list = speech_input_mask[0].tolist()

            wav_target = _load_audio_to_24k(target_audio, target_sr=24000)
            # Prefer exact frame count from acoustic tokenizer if available; fallback to compress ratio
            target_latent_len = self._target_latent_len(wav_target)

            speech_diff_id = self.processor.tokenizer.speech_diffusion_id
            target_placeholders = [speech_diff_id] * target_latent_len

            ids_extended = ids + target_placeholders
            attn_extended = attn + [1] * target_latent_len

            # Target frames are acoustic inputs and the only positions with acoustic loss.
            acoustic_input_mask = speech_input_mask_list + [True] * target_latent_len
            acoustic_loss_mask = ([False] * len(speech_input_mask_list)) + [True] * target_latent_len

            # Add speech_end_id token
            speech_end_id = self.processor.tokenizer.speech_end_id
            ids_extended.append(speech_end_id)
            attn_extended.append(1)
            acoustic_input_mask.append(False)
            acoustic_loss_mask.append(False)

            # Add actual EOS token after speech_end_id to properly terminate generation
            eos_token_id = self.processor.tokenizer.eos_token_id
            ids_extended.append(eos_token_id)
            attn_extended.append(1)
            acoustic_input_mask.append(False)
            acoustic_loss_mask.append(False)

            if self.max_length is not None and len(ids_extended) > self.max_length:
                # Over-long samples are cut from the left, but never into acoustic positions.
                cut = len(ids_extended) - int(self.max_length)
                leading_non_acoustic = 0
                for v in acoustic_input_mask:
                    if v:
                        break
                    leading_non_acoustic += 1
                if cut > leading_non_acoustic:
                    raise ValueError(
                        f"--max_length={self.max_length} would truncate into acoustic tokens. "
                        f"Needed cut={cut}, but only {leading_non_acoustic} leading non-acoustic tokens available. "
                        "Increase max_length or shorten text/voice-prompt preamble."
                    )
                ids_extended = ids_extended[cut:]
                attn_extended = attn_extended[cut:]
                acoustic_input_mask = acoustic_input_mask[cut:]
                acoustic_loss_mask = acoustic_loss_mask[cut:]

            sample_input_ids.append(ids_extended)
            sample_attention_masks.append(attn_extended)
            sample_acoustic_input_masks.append(acoustic_input_mask)
            sample_acoustic_loss_masks.append(acoustic_loss_mask)

            voice_speeches = []
            voice_latent_lengths = []
            if proc.get("speech_tensors") is not None:
                voice_np = proc["speech_tensors"].cpu().numpy()
                voice_masks = proc["speech_masks"].cpu().numpy().astype(bool)
                for seg_idx in range(voice_np.shape[0]):
                    voice_speeches.append(voice_np[seg_idx])
                    voice_latent_lengths.append(int(voice_masks[seg_idx].sum()))

            # Segment order per example: voice prompts first, then the target.
            all_speech_waveforms.extend(voice_speeches)
            all_speech_latent_lengths.extend(voice_latent_lengths)
            per_segment_is_target.extend([False] * len(voice_speeches))

            all_speech_waveforms.append(wav_target)
            all_speech_latent_lengths.append(target_latent_len)
            per_segment_is_target.append(True)

        max_seq_len = max(len(x) for x in sample_input_ids)
        padded_input_ids = []
        padded_attention_masks = []
        padded_acoustic_input_masks = []
        padded_acoustic_loss_masks = []
        tok = self.processor.tokenizer
        # Pad with pad_token_id, falling back to eos_token_id when no pad id is set.
        pad_token_id = getattr(tok, "pad_token_id", None)
        if pad_token_id is None or pad_token_id < 0:
            pad_token_id = getattr(tok, "eos_token_id", None)
            if pad_token_id is None or pad_token_id < 0:
                raise ValueError(
                    "Tokenizer has no pad_token_id or eos_token_id; please set one or pass a valid pad id."
                )
        for ids, attn, ain_mask, aloss_mask in zip(
            sample_input_ids, sample_attention_masks, sample_acoustic_input_masks, sample_acoustic_loss_masks
        ):
            pad_len = max_seq_len - len(ids)
            padded_input_ids.append(ids + [pad_token_id] * pad_len)
            padded_attention_masks.append(attn + [0] * pad_len)
            padded_acoustic_input_masks.append(ain_mask + [False] * pad_len)
            padded_acoustic_loss_masks.append(aloss_mask + [False] * pad_len)

        input_ids_tensor = torch.tensor(padded_input_ids, dtype=torch.long)
        attention_mask_tensor = torch.tensor(padded_attention_masks, dtype=torch.long)
        acoustic_input_mask_tensor = torch.tensor(padded_acoustic_input_masks, dtype=torch.bool)
        acoustic_loss_mask_tensor = torch.tensor(padded_acoustic_loss_masks, dtype=torch.bool)

        if all_speech_waveforms:
            max_wave_len = max(w.shape[0] for w in all_speech_waveforms)
            padded_speeches = np.zeros((len(all_speech_waveforms), max_wave_len), dtype=np.float32)
            for i, w in enumerate(all_speech_waveforms):
                L = w.shape[0]
                padded_speeches[i, :L] = w

            max_latent_len = max(all_speech_latent_lengths) if all_speech_latent_lengths else 1
            speech_masks_np = np.zeros((len(all_speech_waveforms), max_latent_len), dtype=np.bool_)
            for i, L_lat in enumerate(all_speech_latent_lengths):
                speech_masks_np[i, :L_lat] = True

            speech_tensors_tensor = torch.tensor(padded_speeches, dtype=torch.float32)
            speech_masks_tensor = torch.tensor(speech_masks_np, dtype=torch.bool)

            # Only target segments take part in the speech loss; prompt rows stay all-False.
            speeches_loss_input_np = np.zeros_like(speech_masks_np, dtype=np.bool_)
            for i, is_target in enumerate(per_segment_is_target):
                if is_target:
                    speeches_loss_input_np[i] = speech_masks_np[i]
            speeches_loss_input_tensor = torch.tensor(speeches_loss_input_np, dtype=torch.bool)

            # Semantic features
            if self.compute_semantics and hasattr(self.processor, "semantic_tokenizer") and self.processor.semantic_tokenizer is not None:
                sem_feats: List[np.ndarray] = []
                for seg_idx, w in enumerate(all_speech_waveforms):
                    try:
                        # Expect [T, D]  where T ≈ ceil(len(w)/compress_ratio)
                        sem = self.processor.semantic_tokenizer.encode(w)
                        sem = np.asarray(sem, dtype=np.float32)
                    except Exception as exc:
                        # Best-effort fallback to all-zero features, but never silently.
                        warnings.warn(
                            f"Semantic encoding failed for speech segment {seg_idx}; using zero features: {exc}",
                            RuntimeWarning,
                        )
                        sem = np.zeros((0, self.semantic_vae_dim), dtype=np.float32)
                    if sem.ndim != 2:
                        raise RuntimeError(f"Semantic tokenizer returned unexpected shape {sem.shape}. Expect [T, D].")
                    L = sem.shape[0]
                    D = sem.shape[1]
                    # Force the feature width to semantic_vae_dim (zero-pad or crop).
                    if D != self.semantic_vae_dim:
                        if D < self.semantic_vae_dim:
                            pad_d = np.zeros((L, self.semantic_vae_dim - D), dtype=np.float32)
                            sem = np.concatenate([sem, pad_d], axis=1)
                        else:
                            sem = sem[:, : self.semantic_vae_dim]
                    # Force the time axis to max_latent_len (zero-pad or crop).
                    if L < max_latent_len:
                        pad = np.zeros((max_latent_len - L, self.semantic_vae_dim), dtype=np.float32)
                        sem = np.concatenate([sem, pad], axis=0)
                    elif L > max_latent_len:
                        sem = sem[:max_latent_len]
                    sem_feats.append(sem.astype(np.float32))
                speech_semantic_tensors = torch.tensor(np.stack(sem_feats, axis=0), dtype=torch.float32)
            else:
                # Semantic tokenizer unavailable while semantics are required for training.
                # Raise to avoid silently degrading alignment with zeroed features.
                raise RuntimeError(
                    "Semantic features are required but could not be computed. "
                    "Ensure processor.semantic_tokenizer is available or precompute and provide features."
                )
        else:
            # Defensive: not expected in practice, since each example appends its target segment.
            speech_tensors_tensor = None
            speech_masks_tensor = None
            speeches_loss_input_tensor = None
            speech_semantic_tensors = None  # No segments in batch

        if self.debug_checks:
            assert (input_ids_tensor >= 0).all(), "input_ids contains negative indices"
            if speech_tensors_tensor is not None:
                assert speech_tensors_tensor.dim() == 2, "Expected speech_tensors 2D [segments, samples]"

        return {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor,
            "speech_tensors": speech_tensors_tensor,
            "speech_masks": speech_masks_tensor,
            "speech_semantic_tensors": speech_semantic_tensors,
            "acoustic_input_mask": acoustic_input_mask_tensor,
            "acoustic_loss_mask": acoustic_loss_mask_tensor,
            "speeches_loss_input": speeches_loss_input_tensor,
        }