from typing import List import numpy as np import torch import torchaudio from pydub import AudioSegment from pydub.silence import detect_leading_silence, split_on_silence punctuation = {";", ":", ",", ".", "!", "?", ";", ":", ",", "。", "!", "?"} def chunk_tokens_punctuation(tokens_list: List[str], max_tokens: int = 100): """ Splits the input tokens list into chunks according to punctuations, each with a maximum number of tokens. Args: token_list (list of str): The list of tokens to be split. max_tokens (int): The maximum number of tokens per chunk. Returns: List[str]: A list of text chunks. """ # 1. Split the tokens according to punctuations. sentences = [] current_sentence = [] for token in tokens_list: # If the first token of current sentence is punctuation or blank, # append it to the end of the previous sentence. if ( len(current_sentence) == 0 and len(sentences) != 0 and (token in punctuation or token == " ") ): sentences[-1].append(token) # Otherwise, append the current token to the current sentence. else: current_sentence.append(token) # Split the sentence in positions of punctuations. if token in punctuation: sentences.append(current_sentence) current_sentence = [] # Assume the last few tokens are also a sentence if len(current_sentence) != 0: sentences.append(current_sentence) # 2. Merge short sentences. chunks = [] current_chunk = [] for sentence in sentences: if len(current_chunk) + len(sentence) <= max_tokens: current_chunk.extend(sentence) else: if len(current_chunk) > 0: chunks.append(current_chunk) current_chunk = sentence if len(current_chunk) > 0: chunks.append(current_chunk) return chunks def chunk_tokens_dialog(tokens_list: List[str], max_tokens: int = 100): """ Splits the input tokens list into chunks according to speaker-turn symbol [S1], each with a maximum number of tokens. Args: token_list (list of str): The list of tokens to be split. max_tokens (int): The maximum number of tokens per chunk. Returns: List[str]: A list of text chunks. """ # 1. Split the tokens according to speaker-turn symbol [S1]. dialogs = [] current_dialog = [] for token in tokens_list: if token == "[S1]": if len(current_dialog) != 0: dialogs.append(current_dialog) current_dialog = [] current_dialog.append(token) # Assume the last few tokens are also a dialog if len(current_dialog) != 0: dialogs.append(current_dialog) # 2. Merge short dialogs. chunks = [] current_chunk = [] for dialog in dialogs: if len(current_chunk) + len(dialog) <= max_tokens: current_chunk.extend(dialog) else: if len(current_chunk) > 0: chunks.append(current_chunk) current_chunk = dialog if len(current_chunk) > 0: chunks.append(current_chunk) return chunks def batchify_tokens( tokens_list: List[List[int]], max_duration: float, prompt_duration: float, token_duration: float, ): """ Sort and group the input list of token sequences into batches, where each batch's total duration does not exceed the maximum. Args: tokens_list (List[List[int]]): A list of token sequences, where each inner list represents a sequence of tokens. max_duration (float): The maximum allowed total duration for each batch. prompt_duration (float): The duration cost per prompt in the batch. token_duration (float): The duration cost per token. Returns: batches: List[List[List[int]]]: A list of batches, where each batch is a list of token sequences that fit within the max duration. index: List[int]: The original index of each sentence, used to recover the sequential order in the future. """ # Create index for each sentence indexed_tokens = list(enumerate(tokens_list)) # Sort according to sentence length (for less padding) indexed_sorted_tokens = sorted(indexed_tokens, key=lambda x: len(x[1])) index = [indexed_sorted_tokens[i][0] for i in range(len(indexed_sorted_tokens))] sorted_tokens = [ indexed_sorted_tokens[i][1] for i in range(len(indexed_sorted_tokens)) ] batches = [] batch = [] batch_size = 0 # Total number of tokens in current batch for tokens in sorted_tokens: # Calculate if adding current token sequence would exceed max duration # Formula considers: existing tokens' duration + existing # prompts' duration + new tokens' duration if ( batch_size * token_duration + len(batch) * prompt_duration + len(tokens) * token_duration <= max_duration ): # Add to current batch if within duration limit batch.append(tokens) batch_size += len(tokens) else: # If exceeding limit, finalize current batch (if not empty) if len(batch) > 0: batches.append(batch) # Start new batch with current token sequence batch = [tokens] batch_size = len(tokens) # Add the last batch if it's not empty if len(batch) > 0: batches.append(batch) return batches, index def cross_fade_concat( chunks: List[torch.Tensor], fade_duration: float = 0.1, sample_rate: int = 24000 ) -> torch.Tensor: """ Concatenates audio chunks with cross-fading between consecutive chunks. Args: chunks: List of audio tensors, each with shape (C, T) where C = number of channel, T = time dimension (samples) fade_duration: Duration of cross-fade in seconds sample_rate: Audio sample rate in Hz Returns: Concatenated audio tensor with shape (N, T_total) """ # Handle edge cases: empty input or single chunk if len(chunks) <= 1: return chunks[0] if chunks else torch.tensor([]) # Calculate total fade samples from duration and sample rate fade_samples = int(fade_duration * sample_rate) # Use simple concatenation if fade duration is non-positive if fade_samples <= 0: return torch.cat(chunks, dim=-1) # Initialize final tensor with the first chunk final = chunks[0] # Iterate through remaining chunks to apply cross-fading for next_chunk in chunks[1:]: # Calculate safe fade length (cannot exceed either chunk's duration) k = min(fade_samples, final.shape[-1], next_chunk.shape[-1]) # Fall back to simple concatenation if safe fade length is invalid if k <= 0: final = torch.cat([final, next_chunk], dim=-1) continue # Create fade curve (1 -> 0) with shape (1, k) for broadcasting fade = torch.linspace(1, 0, k, device=final.device)[None] # Concatenate three parts: # 1. Non-overlapping part of previous audio # 2. Cross-faded overlapping region # 3. Non-overlapping part of next audio final = torch.cat( [ final[..., :-k], # All samples except last k from previous final[..., -k:] * fade + next_chunk[..., :k] * (1 - fade), # Cross-fade region next_chunk[..., k:], # All samples except first k from next ], dim=-1, ) return final def add_punctuation(text: str): """Add punctuation if there is not in the end of text""" text = text.strip() if text[-1] not in punctuation: text += "." return text def load_prompt_wav(prompt_wav: str, sampling_rate: int): """ Load the waveform with torchaudio and resampling if needed. Parameters: prompt_wav: path of the prompt wav. sampling_rate: target sampling rate. Returns: Loaded prompt waveform with target sampling rate, PyTorch tensor of shape (C, T) """ prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav) if prompt_sampling_rate != sampling_rate: resampler = torchaudio.transforms.Resample( orig_freq=prompt_sampling_rate, new_freq=sampling_rate ) prompt_wav = resampler(prompt_wav) return prompt_wav def rms_norm(prompt_wav: torch.Tensor, target_rms: float): """ Normalize the rms of prompt_wav is it is smaller than target rms. Parameters: prompt_wav: PyTorch tensor with shape (C, T). target_rms: target rms value Returns: prompt_wav: normalized prompt wav with shape (C, T). promt_rms: rms of original prompt wav. Will be used to re-normalize the generated wav. """ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav))) if prompt_rms < target_rms: prompt_wav = prompt_wav * target_rms / prompt_rms return prompt_wav, prompt_rms def remove_silence( audio: torch.Tensor, sampling_rate: int, only_edge: bool = False, trail_sil: float = 0, ): """ Remove silences longer than 1 second, and edge silences longer than 0.1 seconds Parameters: audio: PyTorch tensor with shape (C, T). sampling_rate: sampling rate of the audio. only_edge: If true, only remove edge silences. trail_sil: the duration of added trailing silence in ms. Returns: PyTorch tensor with shape (C, T), where C is number of channels and T is number of audio samples """ # Load audio file wave = tensor_to_audiosegment(audio, sampling_rate) if not only_edge: # Split audio using silences longer than 1 second non_silent_segs = split_on_silence( wave, min_silence_len=1000, # Silences longer than 1 second (1000ms) silence_thresh=-50, keep_silence=1000, # Keep 1.0 second of silence around segments seek_step=10, ) # Concatenate all non-silent segments wave = AudioSegment.silent(duration=0) for seg in non_silent_segs: wave += seg # Remove silence longer than 0.1 seconds in the begining and ending of wave wave = remove_silence_edges(wave, 100, -50) # Add trailing silence to avoid leaking prompt to generated speech. wave = wave + AudioSegment.silent(duration=trail_sil) # Convert to PyTorch tensor return audiosegment_to_tensor(wave) def remove_silence_edges( audio: AudioSegment, keep_silence: int = 100, silence_threshold: float = -50 ): """ Remove edge silences longer than `keep_silence` ms. Parameters: audio: an AudioSegment object. keep_silence: kept silence in the edge. only_edge: If true, only remove edge silences. silence_threshold: the threshold of silence. Returns: An AudioSegment object """ # Remove leading silence start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold) start_idx = max(0, start_idx - keep_silence) audio = audio[start_idx:] # Remove trailing silence audio = audio.reverse() start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold) start_idx = max(0, start_idx - keep_silence) audio = audio[start_idx:] audio = audio.reverse() return audio def audiosegment_to_tensor(aseg): """ Convert a pydub.AudioSegment to PyTorch audio tensor """ audio_data = np.array(aseg.get_array_of_samples()) # Convert to float32 and normalize to [-1, 1] range audio_data = audio_data.astype(np.float32) / 32768.0 # Handle channels if aseg.channels == 1: # Mono channel: add channel dimension (T) -> (1, T) tensor_data = torch.from_numpy(audio_data).unsqueeze(0) else: # Multi-channel: reshape to (C, T) tensor_data = torch.from_numpy(audio_data.reshape(-1, aseg.channels).T) return tensor_data def tensor_to_audiosegment(tensor, sample_rate): """ Convert a PyTorch audio tensor to pydub.AudioSegment Parameters: tensor: Tensor with shape (C, T), where C is the number of channels and T is the time steps sample_rate: Audio sample rate """ # Convert tensor to numpy array audio_np = tensor.cpu().numpy() # Add channel dimension if single channel if audio_np.ndim == 1: audio_np = audio_np[np.newaxis, :] # Convert to int16 type (common format for pydub) # Assumes tensor values are in [-1, 1] range as floating point audio_np = (audio_np * 32768.0).clip(-32768, 32767).astype(np.int16) # Convert to byte stream # For multi-channel audio, pydub requires interleaved format # (e.g., left-right-left-right) if audio_np.shape[0] > 1: # Convert to interleaved format audio_np = audio_np.transpose(1, 0).flatten() audio_bytes = audio_np.tobytes() # Create AudioSegment audio_segment = AudioSegment( data=audio_bytes, sample_width=2, frame_rate=sample_rate, channels=tensor.shape[0], ) return audio_segment