from typing import List

import numpy as np
import torch
import torchaudio
from pydub import AudioSegment
from pydub.silence import detect_leading_silence, split_on_silence

# Sentence-ending punctuation: ASCII plus the full-width (CJK) equivalents
punctuation = {";", ":", ",", ".", "!", "?", ";", ":", ",", "。", "!", "?"}


def chunk_tokens_punctuation(tokens_list: List[str], max_tokens: int = 100):
    """
    Splits the input token list into chunks at punctuation marks,
    each with a maximum number of tokens.

    Args:
        tokens_list (List[str]): The list of tokens to be split.
        max_tokens (int): The maximum number of tokens per chunk.

    Returns:
        List[List[str]]: A list of token chunks.
    """
    # 1. Split the tokens at punctuation marks.
    sentences = []
    current_sentence = []
    for token in tokens_list:
        # If the first token of the current sentence is punctuation or blank,
        # append it to the end of the previous sentence.
        if (
            len(current_sentence) == 0
            and len(sentences) != 0
            and (token in punctuation or token == " ")
        ):
            sentences[-1].append(token)
        # Otherwise, append the current token to the current sentence.
        else:
            current_sentence.append(token)
            # Close the sentence at punctuation positions.
            if token in punctuation:
                sentences.append(current_sentence)
                current_sentence = []
    # Treat any remaining tokens as a final sentence.
    if len(current_sentence) != 0:
        sentences.append(current_sentence)

    # 2. Merge short sentences.
    chunks = []
    current_chunk = []
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= max_tokens:
            current_chunk.extend(sentence)
        else:
            if len(current_chunk) > 0:
                chunks.append(current_chunk)
            current_chunk = sentence
    if len(current_chunk) > 0:
        chunks.append(current_chunk)
    return chunks
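

# A minimal usage sketch of the chunking behavior (hypothetical tokens,
# traced by hand; not taken from the original source):
#
#     chunk_tokens_punctuation(["Hi", ",", "there", "."], max_tokens=2)
#     # -> [["Hi", ","], ["there", "."]]
#
# With max_tokens=4 the two sentences would be merged into a single chunk.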


def chunk_tokens_dialog(tokens_list: List[str], max_tokens: int = 100):
    """
    Splits the input token list into chunks at the speaker-turn
    symbol [S1], each with a maximum number of tokens.

    Args:
        tokens_list (List[str]): The list of tokens to be split.
        max_tokens (int): The maximum number of tokens per chunk.

    Returns:
        List[List[str]]: A list of token chunks.
    """
    # 1. Split the tokens at the speaker-turn symbol [S1].
    dialogs = []
    current_dialog = []
    for token in tokens_list:
        if token == "[S1]":
            if len(current_dialog) != 0:
                dialogs.append(current_dialog)
                current_dialog = []
        current_dialog.append(token)
    # Treat any remaining tokens as a final dialog turn.
    if len(current_dialog) != 0:
        dialogs.append(current_dialog)

    # 2. Merge short dialog turns.
    chunks = []
    current_chunk = []
    for dialog in dialogs:
        if len(current_chunk) + len(dialog) <= max_tokens:
            current_chunk.extend(dialog)
        else:
            if len(current_chunk) > 0:
                chunks.append(current_chunk)
            current_chunk = dialog
    if len(current_chunk) > 0:
        chunks.append(current_chunk)
    return chunks
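

# A minimal usage sketch (hypothetical tokens, traced by hand):
#
#     chunk_tokens_dialog(["[S1]", "hi", "[S1]", "yo"], max_tokens=3)
#     # -> [["[S1]", "hi"], ["[S1]", "yo"]]
#
# Turns are only ever merged whole, so a chunk boundary never splits a
# speaker turn (a single over-long turn is kept intact as its own chunk).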


def batchify_tokens(
    tokens_list: List[List[int]],
    max_duration: float,
    prompt_duration: float,
    token_duration: float,
):
    """
    Sort and group the input list of token sequences into batches, where each
    batch's total duration does not exceed the maximum.

    Args:
        tokens_list (List[List[int]]): A list of token sequences, where each
            inner list represents a sequence of tokens.
        max_duration (float): The maximum allowed total duration per batch.
        prompt_duration (float): The duration cost per prompt in the batch.
        token_duration (float): The duration cost per token.

    Returns:
        batches (List[List[List[int]]]): A list of batches, where each batch
            is a list of token sequences that fit within the max duration.
        index (List[int]): The original index of each sequence, used to
            restore the original order afterwards.
    """
    # Attach the original index to each sequence
    indexed_tokens = list(enumerate(tokens_list))
    # Sort by sequence length (for less padding)
    indexed_sorted_tokens = sorted(indexed_tokens, key=lambda x: len(x[1]))
    index = [item[0] for item in indexed_sorted_tokens]
    sorted_tokens = [item[1] for item in indexed_sorted_tokens]

    batches = []
    batch = []
    batch_size = 0  # Total number of tokens in the current batch
    for tokens in sorted_tokens:
        # Check whether adding the current sequence would exceed max_duration.
        # The formula accounts for: existing tokens' duration + existing
        # prompts' duration + the new tokens' duration.
        if (
            batch_size * token_duration
            + len(batch) * prompt_duration
            + len(tokens) * token_duration
            <= max_duration
        ):
            # Add to the current batch if within the duration limit
            batch.append(tokens)
            batch_size += len(tokens)
        else:
            # If exceeding the limit, finalize the current batch (if not empty)
            if len(batch) > 0:
                batches.append(batch)
            # Start a new batch with the current sequence
            batch = [tokens]
            batch_size = len(tokens)
    # Add the last batch if it's not empty
    if len(batch) > 0:
        batches.append(batch)
    return batches, index
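

# A minimal usage sketch (illustrative values, traced by hand):
#
#     batchify_tokens(
#         [[1, 2, 3], [4], [5, 6]],
#         max_duration=0.3,
#         prompt_duration=0.05,
#         token_duration=0.05,
#     )
#     # -> batches = [[[4], [5, 6]], [[1, 2, 3]]], index = [1, 2, 0]
#
# index[i] is the original position of the i-th sorted sequence, so the
# generated outputs can be scattered back into their original order.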


def cross_fade_concat(
    chunks: List[torch.Tensor], fade_duration: float = 0.1, sample_rate: int = 24000
) -> torch.Tensor:
    """
    Concatenates audio chunks with cross-fading between consecutive chunks.

    Args:
        chunks: List of audio tensors, each with shape (C, T), where
            C = number of channels and T = time dimension (samples)
        fade_duration: Duration of the cross-fade in seconds
        sample_rate: Audio sample rate in Hz

    Returns:
        Concatenated audio tensor with shape (C, T_total)
    """
    # Handle edge cases: empty input or a single chunk
    if len(chunks) <= 1:
        return chunks[0] if chunks else torch.tensor([])
    # Calculate the fade length in samples from duration and sample rate
    fade_samples = int(fade_duration * sample_rate)
    # Use simple concatenation if the fade duration is non-positive
    if fade_samples <= 0:
        return torch.cat(chunks, dim=-1)
    # Initialize the result with the first chunk
    final = chunks[0]
    # Iterate through the remaining chunks, cross-fading each one in
    for next_chunk in chunks[1:]:
        # Calculate a safe fade length (cannot exceed either chunk's duration)
        k = min(fade_samples, final.shape[-1], next_chunk.shape[-1])
        # Fall back to simple concatenation if the safe fade length is invalid
        if k <= 0:
            final = torch.cat([final, next_chunk], dim=-1)
            continue
        # Create a fade-out curve (1 -> 0) with shape (1, k) for broadcasting
        fade = torch.linspace(1, 0, k, device=final.device)[None]
        # Concatenate three parts:
        # 1. Non-overlapping part of the previous audio
        # 2. Cross-faded overlapping region
        # 3. Non-overlapping part of the next audio
        final = torch.cat(
            [
                final[..., :-k],  # All samples except the last k from previous
                final[..., -k:] * fade
                + next_chunk[..., :k] * (1 - fade),  # Cross-fade region
                next_chunk[..., k:],  # All samples except the first k from next
            ],
            dim=-1,
        )
    return final
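

# A minimal shape sketch (hypothetical one-second chunks at 24 kHz):
#
#     a, b = torch.zeros(1, 24000), torch.ones(1, 24000)
#     out = cross_fade_concat([a, b], fade_duration=0.1, sample_rate=24000)
#     out.shape  # -> torch.Size([1, 45600]); the 2400-sample overlap is shared
#
# The two ramps are complementary (fade + (1 - fade) == 1), so fully
# correlated signals pass through the overlap at constant amplitude.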


def add_punctuation(text: str):
    """Append a period if the text does not already end with punctuation."""
    text = text.strip()
    # Guard against empty input before indexing the last character
    if text and text[-1] not in punctuation:
        text += "."
    return text
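

# Usage sketch:
#
#     add_punctuation("hello")   # -> "hello."
#     add_punctuation("hello!")  # -> "hello!"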


def load_prompt_wav(prompt_wav: str, sampling_rate: int):
    """
    Load a waveform with torchaudio and resample it if needed.

    Parameters:
        prompt_wav: path to the prompt wav file.
        sampling_rate: target sampling rate.

    Returns:
        The loaded prompt waveform at the target sampling rate,
        a PyTorch tensor of shape (C, T)
    """
    prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
    if prompt_sampling_rate != sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=prompt_sampling_rate, new_freq=sampling_rate
        )
        prompt_wav = resampler(prompt_wav)
    return prompt_wav
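

# Usage sketch (hypothetical path):
#
#     wav = load_prompt_wav("prompt.wav", sampling_rate=24000)  # -> (C, T)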


def rms_norm(prompt_wav: torch.Tensor, target_rms: float):
    """
    Normalize the RMS of prompt_wav if it is smaller than the target RMS.

    Parameters:
        prompt_wav: PyTorch tensor with shape (C, T).
        target_rms: target RMS value

    Returns:
        prompt_wav: normalized prompt wav with shape (C, T).
        prompt_rms: RMS of the original prompt wav. Used later to
            re-normalize the generated wav.
    """
    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms
    return prompt_wav, prompt_rms
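

# A minimal numeric sketch (illustrative values):
#
#     wav = torch.full((1, 4), 0.05)           # RMS = 0.05
#     out, rms = rms_norm(wav, target_rms=0.1)
#     # out is scaled by 2 (all samples 0.1); rms is the original 0.05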


def remove_silence(
    audio: torch.Tensor,
    sampling_rate: int,
    only_edge: bool = False,
    trail_sil: float = 0,
):
    """
    Remove silences longer than 1 second, and edge silences longer
    than 0.1 seconds.

    Parameters:
        audio: PyTorch tensor with shape (C, T).
        sampling_rate: sampling rate of the audio.
        only_edge: If true, only remove edge silences.
        trail_sil: the duration of the added trailing silence in ms.

    Returns:
        PyTorch tensor with shape (C, T), where C is the number of channels
        and T is the number of audio samples
    """
    # Convert the tensor to a pydub AudioSegment
    wave = tensor_to_audiosegment(audio, sampling_rate)
    if not only_edge:
        # Split the audio on silences longer than 1 second
        non_silent_segs = split_on_silence(
            wave,
            min_silence_len=1000,  # Silences longer than 1 second (1000 ms)
            silence_thresh=-50,
            keep_silence=1000,  # Keep 1.0 second of silence around segments
            seek_step=10,
        )
        # Concatenate all non-silent segments
        wave = AudioSegment.silent(duration=0)
        for seg in non_silent_segs:
            wave += seg
    # Remove silences longer than 0.1 seconds at the beginning and end
    wave = remove_silence_edges(wave, 100, -50)
    # Add trailing silence to avoid leaking the prompt into generated speech.
    wave = wave + AudioSegment.silent(duration=trail_sil)
    # Convert back to a PyTorch tensor
    return audiosegment_to_tensor(wave)
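

# Usage sketch (illustrative values): trim long internal pauses and edge
# silence from a (C, T) tensor `wav`, then pad 200 ms of trailing silence:
#
#     cleaned = remove_silence(wav, sampling_rate=24000, trail_sil=200)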


def remove_silence_edges(
    audio: AudioSegment, keep_silence: int = 100, silence_threshold: float = -50
):
    """
    Remove edge silences longer than `keep_silence` ms.

    Parameters:
        audio: an AudioSegment object.
        keep_silence: the amount of silence (in ms) to keep at each edge.
        silence_threshold: the silence threshold in dBFS.

    Returns:
        An AudioSegment object
    """
    # Remove leading silence
    start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
    start_idx = max(0, start_idx - keep_silence)
    audio = audio[start_idx:]
    # Remove trailing silence by reversing, trimming the lead, reversing back
    audio = audio.reverse()
    start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
    start_idx = max(0, start_idx - keep_silence)
    audio = audio[start_idx:]
    audio = audio.reverse()
    return audio


def audiosegment_to_tensor(aseg):
    """
    Convert a pydub.AudioSegment to a PyTorch audio tensor of shape (C, T)
    """
    audio_data = np.array(aseg.get_array_of_samples())
    # Convert to float32 and normalize to the [-1, 1] range
    # (assumes 16-bit samples, i.e. sample_width == 2)
    audio_data = audio_data.astype(np.float32) / 32768.0
    # Handle channels
    if aseg.channels == 1:
        # Mono: add a channel dimension, (T,) -> (1, T)
        tensor_data = torch.from_numpy(audio_data).unsqueeze(0)
    else:
        # Multi-channel: de-interleave and reshape to (C, T)
        tensor_data = torch.from_numpy(audio_data.reshape(-1, aseg.channels).T)
    return tensor_data


def tensor_to_audiosegment(tensor, sample_rate):
    """
    Convert a PyTorch audio tensor to a pydub.AudioSegment

    Parameters:
        tensor: Tensor with shape (C, T), where C is the number of channels
            and T is the number of time steps
        sample_rate: Audio sample rate
    """
    # Convert the tensor to a numpy array
    audio_np = tensor.cpu().numpy()
    # Add a channel dimension if the input is single-channel, (T,) -> (1, T)
    if audio_np.ndim == 1:
        audio_np = audio_np[np.newaxis, :]
    channels = audio_np.shape[0]
    # Convert to int16 (a common format for pydub)
    # Assumes tensor values are floating point in the [-1, 1] range
    audio_np = (audio_np * 32768.0).clip(-32768, 32767).astype(np.int16)
    # Convert to a byte stream
    # For multi-channel audio, pydub requires interleaved samples
    # (e.g., left-right-left-right)
    if channels > 1:
        # Convert to interleaved format
        audio_np = audio_np.transpose(1, 0).flatten()
    audio_bytes = audio_np.tobytes()
    # Create the AudioSegment
    audio_segment = AudioSegment(
        data=audio_bytes,
        sample_width=2,
        frame_rate=sample_rate,
        channels=channels,
    )
    return audio_segment
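

# Round-trip sketch: tensor -> AudioSegment -> tensor. Values survive up to
# int16 quantization (illustrative shapes):
#
#     wav = torch.zeros(1, 24000)
#     seg = tensor_to_audiosegment(wav, sample_rate=24000)
#     back = audiosegment_to_tensor(seg)  # shape (1, 24000), dtype float32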