from typing import List

import torch
import torch.nn as nn
from torch import Tensor
from transformers import Wav2Vec2Processor, Wav2Vec2Model

SAMPLE_RATE = 16000


class UpstreamExpert(nn.Module):
    def __init__(self, ckpt: str = None, model_config: str = None, **kwargs):
        super().__init__()

        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

    def get_downsample_rates(self, key: str) -> int:
        return 320
    def forward(self, wavs: List[Tensor]):
        # Silence-padding experiments: the blocks below are alternative variants and
        # only one should be active at a time (enabling several at once would grow the
        # batch and, with the pass-through assignment, append to `wavs` while iterating
        # over it). The unused variants are kept here commented out.
        wavs_silence = []

        # Variant 1 (active): pass the waveforms through unchanged.
        wavs_silence = wavs

        # Variant 2: prepend silence equal to 1/5 of each waveform's length.
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 5).to(wav.device)
        #     wavs_silence.append(torch.cat((temp_wav, wav)))

        # Variant 3: prepend silence equal to 1/10 of each waveform's length.
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 10).to(wav.device)
        #     wavs_silence.append(torch.cat((temp_wav, wav)))

        # Variant 4: prepend silence equal to 1/20 of each waveform's length.
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 20).to(wav.device)
        #     wavs_silence.append(torch.cat((temp_wav, wav)))

        # Variant 5: append silence equal to 1/5 of each waveform's length.
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 5).to(wav.device)
        #     wavs_silence.append(torch.cat((wav, temp_wav)))

        # Variant 6: append silence equal to 1/10 of each waveform's length.
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 10).to(wav.device)
        #     wavs_silence.append(torch.cat((wav, temp_wav)))

        # Variant 7: append silence equal to 1/20 of each waveform's length.
        # for wav in wavs:
        #     temp_wav = torch.zeros(len(wav) // 20).to(wav.device)
        #     wavs_silence.append(torch.cat((wav, temp_wav)))

        wavs = wavs_silence

        device = wavs[0].device
        # The processor pads the batch to the longest waveform; this checkpoint may not
        # return an attention mask, hence the guarded lookup below.
        processor_outputs = self.processor(
            [wav.cpu().numpy() for wav in wavs],
            return_tensors="pt",
            sampling_rate=SAMPLE_RATE,
            padding="longest",
        )
        attention_mask = processor_outputs.get("attention_mask", None)
        if isinstance(attention_mask, torch.Tensor):
            attention_mask = attention_mask.to(device)
        model_outputs = self.model(
            processor_outputs.input_values.to(device),
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        return {
            "last_hidden_state": model_outputs.last_hidden_state,
            "hidden_states": model_outputs.hidden_states,
        }
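
# Minimal usage sketch (not part of the original file): instantiates the expert and runs
# two random waveforms through it, assuming the "facebook/wav2vec2-base-960h" weights are
# available locally or can be downloaded.
if __name__ == "__main__":
    expert = UpstreamExpert()
    expert.eval()
    dummy_wavs = [torch.randn(SAMPLE_RATE), torch.randn(2 * SAMPLE_RATE)]  # 1 s and 2 s of noise
    with torch.no_grad():
        outputs = expert(dummy_wavs)
    print(outputs["last_hidden_state"].shape)  # (2, frames, 768) for the base model
    print(len(outputs["hidden_states"]))       # 13: embeddings plus 12 transformer layers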