Spaces:
Runtime error
Runtime error
| # Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Dict, List, Optional, Tuple | |
| import torch | |
| from torch import nn | |
| import torchaudio | |
| import torchaudio.compliance.kaldi as kaldi | |
| from .adapter import CNNAdapter, CNNSubsampling, LinearAdapter | |
| from .cmvn import GlobalCMVN, load_cmvn | |
| from .module.encoder.encoder import whaleEncoder | |
class audioEncoderProcessor:
    """Convert an audio file into Kaldi-compatible fbank features.

    Feature settings are read from ``dataset_conf`` as a nested dict:
    ``dataset_conf["resample_conf"]["resample_rate"]`` and the
    ``dataset_conf["fbank_conf"]`` entries ``num_mel_bins``,
    ``frame_length``, ``frame_shift`` and ``dither``.
    """

    def __init__(
        self,
        dataset_conf: Optional[dict] = None,
    ):
        self.dataset_conf = dataset_conf

    def process(self, wav_path):
        """Load ``wav_path`` and compute its fbank feature matrix.

        Args:
            wav_path: path of an audio file readable by ``torchaudio.load``.

        Returns:
            A tuple ``(mat, num_tokens)``. ``mat`` is the matrix returned by
            ``kaldi.fbank`` (one row per frame). ``num_tokens`` is the number
            of frames left after taking every second frame three times
            (start offsets 2, 2 and 0).
            NOTE(review): presumably this mirrors the encoder/adapter
            subsampling so it predicts the encoder output length - confirm.

        Raises:
            RuntimeError: if the file cannot be loaded. The loader's original
                exception is chained as ``__cause__``.
        """
        try:
            waveform, sample_rate = torchaudio.load(wav_path)
        except Exception as e:
            # Fail loudly: continuing would use an unbound ``sample_rate``.
            raise RuntimeError(f"cannot open {wav_path}") from e

        target_rate = self.dataset_conf["resample_conf"]["resample_rate"]
        if sample_rate != target_rate:
            waveform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=target_rate
            )(waveform)
            sample_rate = target_rate

        # torchaudio.load returns floats normalized to [-1, 1] by default;
        # rescale to the 16-bit integer range Kaldi fbank is computed on.
        waveform = waveform * (1 << 15)

        fbank_conf = self.dataset_conf["fbank_conf"]
        mat = kaldi.fbank(
            waveform,
            num_mel_bins=fbank_conf["num_mel_bins"],
            frame_length=fbank_conf["frame_length"],
            frame_shift=fbank_conf["frame_shift"],
            dither=fbank_conf["dither"],
            energy_floor=0.0,
            sample_frequency=sample_rate,
        )

        # Count how many frames survive three successive stride-2 slicings.
        attn_mask = torch.ones(mat.shape[0])
        attn_mask = attn_mask[2::2][2::2][0::2]
        return mat, attn_mask.shape[0]
class audioEncoder(torch.nn.Module):
    """Speech encoder followed by an adapter that projects into an LLM embedding space.

    ``forward`` runs ``encoder`` on the speech features, passes the result and
    its mask through the adapter selected by ``adpter_type``, and returns a
    dict with ``inputs_embeds`` and ``attention_mask``. It can optionally wrap
    the audio with bos/eos embeddings and/or prepend learned prompt embeddings.

    Constructor arguments ``llm_path``, ``freeze_llm``, ``task_num`` and
    ``chat_template`` are accepted but not used by the code in this class.
    ``task_before_audio`` and ``task_type`` are stored but not read here.
    """

    def __init__(
        self,
        encoder: torch.nn.Module,
        llm_path: str,
        freeze_llm: bool = True,
        enc_out_dim: int = 512,
        llm_embed_dim: int = 4096,
        kernel_size: int = 3,
        IGNORE_ID: int = -100,
        adpter_type: str = "cnn",
        add_audio_bos_eos: bool = False,
        task_num: int = 10,
        task_before_audio: bool = False,
        task_type: str = "prompt",
        freeze_encoder: bool = False,
        freeze_adpter: bool = False,
        audio_prompt_finetune: bool = False,
        audio_prompt_num: int = 25,
        activation_func: str = "relu",
        norm: str = "batch",
        chat_template=None,
    ):
        """Build the adapter and optional prompt embeddings.

        Raises:
            ValueError: if ``adpter_type`` is not ``"cnn"``, ``"linear"`` or
                ``"subsampling"``.
        """
        super().__init__()
        self.encoder = encoder
        self.enc_out_dim = enc_out_dim
        self.llm_embed_dim = llm_embed_dim
        self.IGNORE_ID = IGNORE_ID
        self.add_audio_bos_eos = add_audio_bos_eos
        self.task_before_audio = task_before_audio
        self.task_type = task_type
        self.freeze_encoder = freeze_encoder
        self.freeze_adpter = freeze_adpter
        self.audio_prompt_finetune = audio_prompt_finetune
        self.audio_prompt_num = audio_prompt_num

        if adpter_type == "cnn":
            self.adpter = CNNAdapter(enc_out_dim, llm_embed_dim, kernel_size)
        elif adpter_type == "linear":
            self.adpter = LinearAdapter(enc_out_dim, llm_embed_dim)
        elif adpter_type == "subsampling":
            self.adpter = CNNSubsampling(
                enc_out_dim, llm_embed_dim, kernel_size, activation_func, norm
            )
        else:
            # Without this, self.adpter would be missing and forward() would
            # fail much later with an opaque AttributeError.
            raise ValueError(
                f"unknown adpter_type {adpter_type!r}; "
                "expected 'cnn', 'linear' or 'subsampling'"
            )

        # NOTE(review): .eval() here is undone by a later model.train() call;
        # only requires_grad=False persists - confirm this is intended.
        if self.freeze_encoder:
            self.encoder.eval()
            for param in self.encoder.parameters():
                param.requires_grad = False
        if self.freeze_adpter:
            self.adpter.eval()
            for param in self.adpter.parameters():
                param.requires_grad = False

        if self.audio_prompt_finetune:
            self.prompt_embeddings = nn.Embedding(audio_prompt_num, llm_embed_dim)
            # Plain tensor attribute (not a buffer); moved to the input's
            # device on each forward call.
            self.prompt_ids = torch.arange(audio_prompt_num, dtype=torch.long)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Encode ``speech`` and project it to LLM input embeddings.

        Args:
            speech: batch of input features, cast to the module's parameter dtype.
            speech_lengths: per-utterance lengths, passed through to ``encoder``.

        Returns:
            ``{"inputs_embeds": (B, T', D), "attention_mask": (B, T')}``.
            ``T'`` includes the bos/eos and prompt positions when those
            options are enabled, and the mask is extended to match.
        """
        speech = speech.to(next(self.parameters()).dtype)

        # 1. Encoder + adapter.
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        inputs_embeds, encoder_mask = self.adpter(encoder_out, encoder_mask)  # B, T, D
        attention_mask = encoder_mask.squeeze(1)  # B, T
        assert inputs_embeds.size(1) == attention_mask.size(1)

        # 2. Optional audio bos/eos. No training target exists at this
        # stage, so none is passed (the original referenced an undefined name).
        if self.add_audio_bos_eos:
            inputs_embeds, attention_mask, _ = self._add_bos_eos(
                "audio", "/audio", inputs_embeds, attention_mask
            )

        B, _, _ = inputs_embeds.shape

        # 3. Optional learned prompt prepended to the sequence.
        if self.audio_prompt_finetune:
            prompt_ids = self.prompt_ids.to(inputs_embeds.device).repeat(B, 1)
            prompt_embeds = self.prompt_embeddings(prompt_ids)  # B, P, D
            # Prompt positions are always attended; keep the mask aligned
            # with inputs_embeds.
            prompt_mask = torch.ones(
                B,
                prompt_embeds.size(1),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            inputs_embeds = torch.cat((prompt_embeds, inputs_embeds), 1)  # B, (P+T), D
            attention_mask = torch.cat((prompt_mask, attention_mask), 1)  # B, (P+T)

        return {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
        }

    def _add_bos_eos(self, bos, eos, inputs_embeds, attention_mask, target=None):
        """Wrap ``inputs_embeds`` with bos/eos task embeddings.

        Requires ``self.task_embeddings`` (called with a (B, 1) id tensor) and
        ``self.task_ids`` (indexed by ``bos``/``eos``). NOTE(review): neither
        is created in ``__init__`` here, so this raises AttributeError unless
        something else sets them - confirm.

        Returns:
            ``(inputs_embeds, attention_mask, target)``, each lengthened by
            two positions. ``target`` is padded with ``IGNORE_ID`` when given.
        """
        B = len(inputs_embeds)
        bos_embed = self.task_embeddings(
            torch.full([B, 1], self.task_ids[bos]).to(inputs_embeds.device)
        )  # B, 1, D
        eos_embed = self.task_embeddings(
            torch.full([B, 1], self.task_ids[eos]).to(inputs_embeds.device)
        )  # B, 1, D
        bos_eos_target = torch.full([B, 2], self.IGNORE_ID).to(inputs_embeds.device)  # B, 2
        bos_eos_mask = torch.full([B, 1], True).to(inputs_embeds.device)  # B, 1

        inputs_embeds = torch.cat((bos_embed, inputs_embeds), 1)  # B, (1+T), D
        inputs_embeds = torch.cat((inputs_embeds, eos_embed), 1)  # B, (1+T+1), D
        attention_mask = torch.cat((bos_eos_mask, attention_mask), 1)  # B, (1+T)
        attention_mask = torch.cat((attention_mask, bos_eos_mask), 1)  # B, (1+T+1)
        if target is not None:
            target = torch.cat((target, bos_eos_target), 1)  # B, (T+2)

        return inputs_embeds, attention_mask, target
def init_model(configs):
    """Assemble an ``audioEncoder`` (with its feature processor) from a config dict.

    Keys read from ``configs``: ``cmvn_file`` (``None`` disables global CMVN),
    ``is_json_cmvn`` (only when ``cmvn_file`` is set), ``input_dim``,
    ``encoder_conf``, ``model_conf`` and ``dataset_conf``.

    Returns:
        The ``audioEncoder``, with an ``audioEncoderProcessor`` attached as
        ``model.audio_processor``.
    """
    global_cmvn = None
    cmvn_path = configs["cmvn_file"]
    if cmvn_path is not None:
        mean, istd = load_cmvn(cmvn_path, configs["is_json_cmvn"])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float(),
        )

    encoder = whaleEncoder(
        configs["input_dim"], global_cmvn=global_cmvn, **configs["encoder_conf"]
    )
    model = audioEncoder(encoder=encoder, **configs["model_conf"])
    model.audio_processor = audioEncoderProcessor(dataset_conf=configs["dataset_conf"])
    return model