# Copyright (c) 2023 OpenAI. (authors: Whisper Team) # 2024 Tsinghua Univ. (authors: Xingchen Song) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py Add rename_weights() & onnx2torch() & make_non_pad_mask() & mask_to_bias() Copy merge_tokenized_segments() from https://github.com/Mddct/s3tokenizer-long/blob/main/example.py """ import os from functools import lru_cache from typing import List, Optional, Union import numpy as np import onnx import torch import torch.nn.functional as F import torchaudio from torch.nn.utils.rnn import pad_sequence def _rename_weights(weights_dict: dict): """ Rename onnx weights to pytorch format. Parameters ---------- weight_dict: dict The dict containing weights in onnx format Returns ------- A new weight dict containing the weights in pytorch format. """ new_weight_dict = {} for k in weights_dict.keys(): if "quantizer" in k: # vq or fsq if k == "/quantizer/rq/model/layers.0/_codebook/Pow_1": new_weight_dict["quantizer._codebook.embed"] = weights_dict[k] elif 'project_down' in k: # v2 new_weight_dict[k] = weights_dict[k] elif "positional_embedding" in k: # positional emb new_weight_dict[k] = weights_dict[k] elif "conv" in k: # 1/2 or 1/4 subsample new_weight_dict[k] = weights_dict[k] else: # transformer blocks assert "blocks" in k new_k = (k[1:].replace('/', '.').replace( 'MatMul', 'weight').replace('Add_1', 'bias').replace( 'Mul', 'weight').replace('Add', 'bias').replace( 'mlp.mlp', 'mlp')).replace('fsmn_block.Conv', 'fsmn_block.weight') new_weight_dict[f"encoder.{new_k}"] = weights_dict[k] return new_weight_dict def onnx2torch(onnx_path: str, torch_path: str = None, verbose: bool = False): """ Open an onnx file and convert to pytorch format. Parameters ---------- onnx_path: str The onnx file to open, typically `speech_tokenizer_v1.onnx` torch_path: str The path to save the torch-formated checkpoint. verbose: bool Logging info or not. Returns ------- A checkpoint dict containing the weights and their names, if torch_path is None. Otherwise save checkpoint dict to the desired path. """ onnx_model = onnx.load(onnx_path) weights_dict = {} initializer_map = { initializer.name: initializer for initializer in onnx_model.graph.initializer } for node in onnx_model.graph.node: for input_name in node.input: if input_name in initializer_map: ln_bias_name, ln_weight_name = None, None # for v2 ln initializer = initializer_map[input_name] if input_name in [ "onnx::Conv_1519", "encoders.conv1.weight", "onnx::Conv_2216", ]: # v1_50hz, v1_25hz, v2_25hz weight_name = "encoder.conv1.weight" elif input_name in [ "onnx::Conv_1520", "encoders.conv1.bias", "onnx::Conv_2217", ]: # v1_50hz, v1_25hz, v2_25hz weight_name = "encoder.conv1.bias" elif input_name in [ "onnx::Conv_1521", "encoders.conv2.weight", "onnx::Conv_2218", ]: weight_name = "encoder.conv2.weight" elif input_name in [ "onnx::Conv_1522", "encoders.conv2.bias", "onnx::Conv_2219", ]: weight_name = "encoder.conv2.bias" elif input_name == "encoders.positional_embedding": weight_name = "encoder.positional_embedding" elif input_name == 'quantizer.project_in.bias': weight_name = "quantizer._codebook.project_down.bias" elif input_name == 'onnx::MatMul_2536': weight_name = "quantizer._codebook.project_down.weight" else: if node.op_type == 'LayerNormalization': # in input_name: ln_name = node.name.replace('/LayerNormalization', '') ln_weight_name = ln_name + '.weight' ln_bias_name = ln_name + '.bias' else: weight_name = node.name if ln_weight_name is not None and ln_bias_name is not None: ln_inputs = node.input scale_name = ln_inputs[1] bias_name = ln_inputs[2] scale = onnx.numpy_helper.to_array( initializer_map[scale_name]).copy( ) if scale_name in initializer_map else None bias = onnx.numpy_helper.to_array( initializer_map[bias_name]).copy( ) if bias_name in initializer_map else None scale.flags.writeable = True bias.flags.writeable = True weight_tensor = torch.from_numpy(scale) bias_tensor = torch.from_numpy(bias) weights_dict[ln_bias_name] = bias_tensor weights_dict[ln_weight_name] = weight_tensor else: weight_array = onnx.numpy_helper.to_array( initializer).copy() weight_array.flags.writeable = True weight_tensor = torch.from_numpy(weight_array) if len(weight_tensor.shape) > 2 or weight_name in [ "encoder.positional_embedding" ]: weights_dict[weight_name] = weight_tensor else: weights_dict[weight_name] = weight_tensor.t() new_weights_dict = _rename_weights(weights_dict) if verbose: for k, v in new_weights_dict.items(): print(f"{k} : {v.shape} {v.dtype}") print(f"PyTorch weights saved to {torch_path}") del weights_dict, onnx_model if torch_path: torch.save(new_weights_dict, torch_path) else: return new_weights_dict def load_audio(file: str, sr: int = 16000): """ Open an audio file and read as mono waveform, resampling as necessary Parameters ---------- file: str The audio file to open sr: int The sample rate to resample the audio if necessary Returns ------- A torch.Tensor containing the audio waveform, in float32 dtype. """ audio, sample_rate = torchaudio.load(file) if sample_rate != sr: audio = torchaudio.transforms.Resample(sample_rate, sr)(audio) audio = audio[0] # get the first channel return audio @lru_cache(maxsize=None) def _mel_filters(device, n_mels: int) -> torch.Tensor: """ load the mel filterbank matrix for projecting STFT into a Mel spectrogram. Allows decoupling librosa dependency; saved using: np.savez_compressed( "mel_filters.npz", mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), ) """ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") with np.load(filters_path, allow_pickle=False) as f: return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) def log_mel_spectrogram( audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = 128, padding: int = 0, device: Optional[Union[str, torch.device]] = None, ): """ Compute the log-Mel spectrogram of Parameters ---------- audio: Union[str, np.ndarray, torch.Tensor], shape = (*) The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz n_mels: int The number of Mel-frequency filters, only 80 is supported padding: int Number of zero samples to pad to the right device: Optional[Union[str, torch.device]] If given, the audio tensor is moved to this device before STFT Returns ------- torch.Tensor, shape = (128, n_frames) A Tensor that contains the Mel spectrogram """ if not torch.is_tensor(audio): if isinstance(audio, str): audio = load_audio(audio) if device is not None: audio = audio.to(device) if padding > 0: audio = F.pad(audio, (0, padding)) window = torch.hann_window(400).to(audio.device) stft = torch.stft(audio, 400, 160, window=window, return_complex=True) magnitudes = stft[..., :-1].abs()**2 filters = _mel_filters(audio.device, n_mels) mel_spec = filters @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 return log_spec def make_non_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: """Make mask tensor containing indices of non-padded part. The sequences in a batch may have different lengths. To enable batch computing, padding is need to make all sequence in same size. To avoid the padding part pass value to context dependent block such as attention or convolution , this padding part is masked. 1 for non-padded part and 0 for padded part. Parameters ---------- lengths (torch.Tensor): Batch of lengths (B,). Returns: ------- torch.Tensor: Mask tensor containing indices of padded part (B, max_T). Examples: >>> import torch >>> import s3tokenizer >>> lengths = torch.tensor([5, 3, 2]) >>> masks = s3tokenizer.make_non_pad_mask(lengths) masks = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 0, 0]] """ batch_size = lengths.size(0) max_len = max_len if max_len > 0 else lengths.max().item() seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device) seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) seq_length_expand = lengths.unsqueeze(-1) mask = seq_range_expand >= seq_length_expand return ~mask def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: """Convert bool-tensor to float-tensor for flash attention. Parameters ---------- lengths (torch.Tensor): Batch of lengths (B, ?). Returns: ------- torch.Tensor: Mask tensor containing indices of padded part (B, ?). Examples: >>> import torch >>> import s3tokenizer >>> lengths = torch.tensor([5, 3, 2]) >>> masks = s3tokenizer.make_non_pad_mask(lengths) masks = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 0, 0]] >>> new_masks = s3tokenizer.mask_to_bias(masks, torch.float32) new_masks = [[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00], [-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10], [-0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10]] """ assert mask.dtype == torch.bool assert dtype in [torch.float32, torch.bfloat16, torch.float16] mask = mask.to(dtype) # attention mask bias # NOTE(Mddct): torch.finfo jit issues # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min mask = (1.0 - mask) * -1.0e+10 return mask def padding(data: List[torch.Tensor]): """ Padding the data into batch data Parameters ---------- data: List[Tensor], shape of Tensor (128, T) Returns: ------- feats [B, 128, T_max], feats lengths [B] """ sample = data assert isinstance(sample, list) feats_lengths = torch.tensor([s.size(1) for s in sample], dtype=torch.int32) feats = [s.t() for s in sample] padded_feats = pad_sequence(feats, batch_first=True, padding_value=0) return padded_feats.transpose(1, 2), feats_lengths def merge_tokenized_segments(tokenized_segments, overlap, token_rate): """ Merges tokenized outputs by keeping the middle and dropping half of the overlapped tokens. Args: - tokenized_segments (List[List[int]]): List of tokenized sequences. - overlap (int): Overlapping duration in seconds (default: 4s). - token_rate (int): Number of tokens per second. Returns: - List[int]: A single merged token sequence. """ merged_tokens = [] overlap_tokens = ( overlap // 2) * token_rate # Tokens corresponding to half of the overlap duration for i, tokens in enumerate(tokenized_segments): l = 0 if i == 0 else overlap_tokens r = -overlap_tokens if i != len(tokenized_segments) - 1 else len(tokens) # Keep only the middle part (drop overlap / 2 from both sides) merged_tokens.extend(tokens[l:r]) return merged_tokens