File size: 13,487 Bytes
2967cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
from typing import List

import numpy as np
import torch
import torchaudio
from pydub import AudioSegment
from pydub.silence import detect_leading_silence, split_on_silence

punctuation = {";", ":", ",", ".", "!", "?", ";", ":", ",", "。", "!", "?"}


def chunk_tokens_punctuation(tokens_list: List[str], max_tokens: int = 100):
    """
    Group tokens into punctuation-delimited sentences, then pack consecutive
        sentences into chunks of at most ``max_tokens`` tokens.

    A single sentence that is longer than ``max_tokens`` is never split; it
    becomes a chunk of its own.

    Args:
        tokens_list (list of str): The tokens to be split.
        max_tokens (int): The maximum number of tokens per chunk.

    Returns:
        List[List[str]]: The chunks, each one a list of tokens, in input order.
    """

    # Pass 1: cut the token stream right after every punctuation token.
    segments = []
    pending = []
    for tok in tokens_list:
        is_mark = tok in punctuation
        # A punctuation mark or blank that would open a new sentence is glued
        # onto the end of the previous sentence instead.
        if not pending and segments and (is_mark or tok == " "):
            segments[-1].append(tok)
            continue
        pending.append(tok)
        if is_mark:
            segments.append(pending)
            pending = []
    # Whatever is left without closing punctuation still counts as a sentence.
    if pending:
        segments.append(pending)

    # Pass 2: greedily merge neighbouring sentences while they fit.
    merged = []
    buffer = []
    for segment in segments:
        if len(buffer) + len(segment) <= max_tokens:
            buffer.extend(segment)
            continue
        # The sentence does not fit: flush the buffer and restart from it.
        if buffer:
            merged.append(buffer)
        buffer = segment

    if buffer:
        merged.append(buffer)

    return merged


def chunk_tokens_dialog(tokens_list: List[str], max_tokens: int = 100):
    """
    Group tokens into dialog rounds that each start at the speaker-turn
        symbol [S1], then pack consecutive rounds into chunks of at most
        ``max_tokens`` tokens.

    A single round that is longer than ``max_tokens`` is never split; it
    becomes a chunk of its own.

    Args:
        tokens_list (list of str): The tokens to be split.
        max_tokens (int): The maximum number of tokens per chunk.

    Returns:
        List[List[str]]: The chunks, each one a list of tokens, in input order.
    """

    # Pass 1: open a new round every time "[S1]" is seen.
    rounds = []
    turn = []
    for tok in tokens_list:
        if tok == "[S1]" and turn:
            rounds.append(turn)
            turn = []
        turn.append(tok)
    # Trailing tokens form the final round.
    if turn:
        rounds.append(turn)

    # Pass 2: greedily merge neighbouring rounds while they fit.
    merged = []
    buffer = []
    for dialog_round in rounds:
        if len(buffer) + len(dialog_round) <= max_tokens:
            buffer.extend(dialog_round)
            continue
        # The round does not fit: flush the buffer and restart from it.
        if buffer:
            merged.append(buffer)
        buffer = dialog_round

    if buffer:
        merged.append(buffer)

    return merged


def batchify_tokens(
    tokens_list: List[List[int]],
    max_duration: float,
    prompt_duration: float,
    token_duration: float,
):
    """
    Order the token sequences by length and greedily pack them into batches
        bounded by an estimated duration budget.

    Args:
        tokens_list (List[List[int]]): The token sequences to batch.
        max_duration (float): Duration budget of a single batch.
        prompt_duration (float): Duration charged for each prompt already
            placed in the batch.
        token_duration (float): Duration charged for each token.

    Returns:
        batches: List[List[List[int]]]: The batches, each one a list of token
            sequences, in ascending-length order.
        index: List[int]: For every sequence in ascending-length order, its
            position in ``tokens_list``; lets the caller restore input order.
    """
    # Shortest first (stable), so sequences of similar length share a batch
    # and need less padding.
    ranked = sorted(enumerate(tokens_list), key=lambda pair: len(pair[1]))
    index = [position for position, _ in ranked]

    batches = []
    current = []
    current_tokens = 0  # Number of tokens already placed in `current`

    for _, seq in ranked:
        # Estimated cost if `seq` joined the batch: tokens already there,
        # prompts already there, plus the tokens of the candidate itself.
        projected = (
            current_tokens * token_duration
            + len(current) * prompt_duration
            + len(seq) * token_duration
        )
        if projected <= max_duration:
            current.append(seq)
            current_tokens += len(seq)
            continue

        # Over budget: close the running batch and open a new one with `seq`.
        if current:
            batches.append(current)
        current = [seq]
        current_tokens = len(seq)

    # Do not lose the batch that was still being filled.
    if current:
        batches.append(current)

    return batches, index


def cross_fade_concat(
    chunks: List[torch.Tensor], fade_duration: float = 0.1, sample_rate: int = 24000
) -> torch.Tensor:
    """
    Join audio chunks along the time axis, blending each seam with a linear
    cross-fade.

    Args:
        chunks: List of audio tensors, each with shape (C, T) where
                C = number of channel, T = time dimension (samples)
        fade_duration: Duration of cross-fade in seconds
        sample_rate: Audio sample rate in Hz

    Returns:
        Concatenated audio tensor with shape (C, T_total). Every cross-faded
        seam shortens the result by the number of overlapped samples.
    """
    # Nothing to join: empty list -> empty tensor, single chunk -> itself.
    if not chunks:
        return torch.tensor([])
    if len(chunks) == 1:
        return chunks[0]

    # Requested overlap, in samples.
    overlap = int(fade_duration * sample_rate)

    # No overlap requested: plain concatenation.
    if overlap <= 0:
        return torch.cat(chunks, dim=-1)

    merged, *remaining = chunks

    for incoming in remaining:
        # The overlap can never be longer than either side of the seam.
        k = min(overlap, merged.shape[-1], incoming.shape[-1])

        # One side has no samples: just append.
        if k <= 0:
            merged = torch.cat([merged, incoming], dim=-1)
            continue

        # Linear ramp 1 -> 0, shaped (1, k) so it broadcasts over channels.
        ramp_down = torch.linspace(1, 0, k, device=merged.device).unsqueeze(0)

        # Tail of the audio so far fades out while the head of the incoming
        # chunk fades in.
        blended = merged[..., -k:] * ramp_down + incoming[..., :k] * (1 - ramp_down)

        merged = torch.cat(
            [
                merged[..., :-k],  # untouched part of the audio so far
                blended,  # cross-faded seam
                incoming[..., k:],  # untouched part of the incoming chunk
            ],
            dim=-1,
        )

    return merged


def add_punctuation(text: str):
    """
    Strip surrounding whitespace and make sure the text ends with punctuation.

    Parameters:
        text: the input text.

    Returns:
        The stripped text, with "." appended when its last character is not
        in the module-level ``punctuation`` set. Empty or whitespace-only
        input yields an empty string.
    """
    text = text.strip()
    # Nothing to punctuate; also avoids indexing an empty string below.
    if not text:
        return text
    if text[-1] not in punctuation:
        text += "."
    return text


def load_prompt_wav(prompt_wav: str, sampling_rate: int):
    """
    Read a prompt waveform from disk with torchaudio and bring it to the
    requested sampling rate.

    Parameters:
        prompt_wav: path of the prompt wav.
        sampling_rate: target sampling rate.

    Returns:
        The prompt waveform at the target sampling rate,
        PyTorch tensor of shape (C, T)
    """
    waveform, source_rate = torchaudio.load(prompt_wav)

    # Already at the target rate: return the loaded audio untouched.
    if source_rate == sampling_rate:
        return waveform

    to_target_rate = torchaudio.transforms.Resample(
        orig_freq=source_rate, new_freq=sampling_rate
    )
    return to_target_rate(waveform)


def rms_norm(prompt_wav: torch.Tensor, target_rms: float):
    """
    Scale prompt_wav up to the target rms if its rms is smaller than that.

    Audio that is already at or above ``target_rms`` is returned unchanged.
    Audio whose rms is exactly zero (digital silence) is also returned
    unchanged, because scaling it would divide by zero and fill the
    waveform with NaN.

    Parameters:
        prompt_wav: PyTorch tensor with shape (C, T).
        target_rms: target rms value

    Returns:
        prompt_wav: normalized prompt wav with shape (C, T).
        prompt_rms: rms of original prompt wav (0-dim tensor). Will be used
            to re-normalize the generated wav.
    """
    # RMS is computed over every element, i.e. jointly across channels.
    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    # Only amplify quiet audio, and only when there is signal to amplify.
    if 0 < prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms
    return prompt_wav, prompt_rms


def remove_silence(
    audio: torch.Tensor,
    sampling_rate: int,
    only_edge: bool = False,
    trail_sil: float = 0,
):
    """
    Shorten long silences inside the audio and trim silences at its edges.

    Unless ``only_edge`` is set, the audio is split on silences of at least
    1 second (below -50 dBFS) and the non-silent segments are joined again,
    each keeping up to 1 second of surrounding silence. Afterwards, edge
    silences longer than 0.1 seconds are trimmed and ``trail_sil`` ms of
    silence is appended.

    Parameters:
        audio: PyTorch tensor with shape (C, T).
        sampling_rate: sampling rate of the audio.
        only_edge: If true, only remove edge silences.
        trail_sil: the duration of added trailing silence in ms.

    Returns:
        PyTorch tensor with shape (C, T), where C is number of channels
            and T is number of audio samples
    """
    # Wrap the tensor in a pydub segment so the pydub silence tools apply.
    wave = tensor_to_audiosegment(audio, sampling_rate)

    if not only_edge:
        # Split audio using silences longer than 1 second
        non_silent_segs = split_on_silence(
            wave,
            min_silence_len=1000,  # Silences longer than 1 second (1000ms)
            silence_thresh=-50,
            keep_silence=1000,  # Keep 1.0 second of silence around segments
            seek_step=10,
        )

        # Concatenate all non-silent segments. The empty seed is created at
        # the audio's own rate: pydub's silent() otherwise uses its default
        # frame rate, and concatenating segments of different rates makes
        # pydub resample, so the output rate could stop matching
        # `sampling_rate` (also when no non-silent segment is found).
        wave = AudioSegment.silent(duration=0, frame_rate=sampling_rate)
        for seg in non_silent_segs:
            wave += seg

    # Remove silence longer than 0.1 seconds in the beginning and ending of wave
    wave = remove_silence_edges(wave, 100, -50)

    # Add trailing silence to avoid leaking prompt to generated speech.
    # Generated at `sampling_rate` for the same reason as above.
    wave = wave + AudioSegment.silent(duration=trail_sil, frame_rate=sampling_rate)

    # Convert to PyTorch tensor
    return audiosegment_to_tensor(wave)


def remove_silence_edges(
    audio: AudioSegment, keep_silence: int = 100, silence_threshold: float = -50
):
    """
    Trim leading and trailing silence, leaving at most `keep_silence` ms of
    it at each edge.

    Parameters:
        audio: an AudioSegment object.
        keep_silence: kept silence in the edge, in ms.
        silence_threshold: the threshold of silence.

    Returns:
        An AudioSegment object
    """

    def _strip_front(segment):
        # Where the leading silence ends, backed off by the part to keep.
        lead = detect_leading_silence(segment, silence_threshold=silence_threshold)
        return segment[max(0, lead - keep_silence):]

    # Leading edge first.
    front_trimmed = _strip_front(audio)

    # The trailing edge is the leading edge of the reversed audio; flip,
    # trim, and flip back.
    return _strip_front(front_trimmed.reverse()).reverse()


def audiosegment_to_tensor(aseg):
    """
    Turn a pydub.AudioSegment into a float32 PyTorch audio tensor of shape
    (C, T).

    Samples are divided by 32768, i.e. the segment is treated as 16-bit PCM.
    """
    # Raw integer samples -> float32 scaled into [-1, 1].
    samples = np.array(aseg.get_array_of_samples()).astype(np.float32) / 32768.0

    if aseg.channels != 1:
        # Samples are interleaved per frame: regroup as (T, C), then flip to (C, T).
        return torch.from_numpy(samples.reshape(-1, aseg.channels).T)

    # Mono: prepend the channel axis, (T) -> (1, T).
    return torch.from_numpy(samples).unsqueeze(0)


def tensor_to_audiosegment(tensor, sample_rate):
    """
    Convert a PyTorch audio tensor to a 16-bit pydub.AudioSegment

    Parameters:
        tensor: Tensor with shape (C, T), where C is the number of channels
            and T is the time steps. A 1-D tensor of shape (T,) is treated
            as a single channel. Values are expected to be floating point
            in the [-1, 1] range; anything outside is clipped.
        sample_rate: Audio sample rate

    Returns:
        An AudioSegment with 2-byte samples, `sample_rate` frame rate and
        C channels.
    """
    # Convert tensor to numpy array (detach so tensors tracking gradients
    # can be converted too).
    audio_np = tensor.detach().cpu().numpy()

    # Add channel dimension if single channel
    if audio_np.ndim == 1:
        audio_np = audio_np[np.newaxis, :]

    # Read the channel count from the array, after the channel axis is in
    # place; for a 1-D input, tensor.shape[0] would be the number of samples.
    num_channels = audio_np.shape[0]

    # Convert to int16 type (common format for pydub)
    # Assumes tensor values are in [-1, 1] range as floating point
    audio_np = (audio_np * 32768.0).clip(-32768, 32767).astype(np.int16)

    # Convert to byte stream
    # For multi-channel audio, pydub requires interleaved format
    # (e.g., left-right-left-right)
    if num_channels > 1:
        # Convert to interleaved format
        audio_np = audio_np.transpose(1, 0).flatten()
    audio_bytes = audio_np.tobytes()

    # Create AudioSegment
    audio_segment = AudioSegment(
        data=audio_bytes,
        sample_width=2,
        frame_rate=sample_rate,
        channels=num_channels,
    )

    return audio_segment