| | import torch |
| | from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor |
| |
|
class AstralQuantizer(torch.nn.Module):
    """Quantize 16 kHz waveforms via an SSL feature model, an encoder and a quantizer.

    The pipeline implemented by :meth:`forward` is: waveform -> SSL feature
    extractor -> (truncated) SSL model -> ``encoder`` -> ``quantizer``.

    A tokenizer is also loaded and stored on the instance; nothing in this
    class uses it directly.
    """

    def __init__(
        self,
        tokenizer_name: str,
        ssl_model_name: str,
        ssl_output_layer: int,
        encoder: torch.nn.Module,
        quantizer: torch.nn.Module,
        skip_ssl: bool = False,
    ):
        """Build the quantizer and load the pretrained helper models.

        Args:
            tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
            ssl_model_name: Name or path passed to both
                ``Wav2Vec2FeatureExtractor.from_pretrained`` and
                ``AutoModel.from_pretrained``.
            ssl_output_layer: Number of leading SSL encoder layers to keep.
            encoder: Module applied to the SSL hidden states.
            quantizer: Module applied to the encoder output.
            skip_ssl: When true, the SSL model is not loaded here and must be
                passed to :meth:`forward` instead.
        """
        super().__init__()
        self.encoder = encoder
        self.quantizer = quantizer
        self.tokenizer_name = tokenizer_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        self.ssl_model_name = ssl_model_name
        self.ssl_output_layer = ssl_output_layer
        self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)

        if skip_ssl:
            self.ssl_model = None
        else:
            self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
            # Keep only the first `ssl_output_layer` encoder layers and disable
            # the encoder's final layer norm, so `last_hidden_state` is the raw
            # output of the last retained layer.
            self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
            self.ssl_model.encoder.layer_norm = torch.nn.Identity()

    @staticmethod
    def _strip_module_prefix(state_dict, prefix="module."):
        """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys."""
        return {
            (name[len(prefix):] if name.startswith(prefix) else name): value
            for name, value in state_dict.items()
        }

    def load_separate_checkpoint(self, checkpoint_path):
        """Load sub-module weights from a checkpoint file.

        The file must contain a ``'net'`` mapping with ``'encoder'`` and
        ``'vq'`` state dicts. ``'decoder'`` and ``'predictor'`` are read only
        when this instance has a non-``None`` ``decoder`` / ``asr_decoder``
        attribute. A leading ``"module."`` on parameter names is stripped
        before loading.

        Args:
            checkpoint_path: Path (or file-like object) accepted by ``torch.load``.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        params = torch.load(checkpoint_path, map_location='cpu')['net']

        self.encoder.load_state_dict(self._strip_module_prefix(params['encoder']))
        self.quantizer.load_state_dict(self._strip_module_prefix(params['vq']))

        # `decoder` / `asr_decoder` are never assigned in __init__, and plain
        # attribute access on an nn.Module raises AttributeError for unknown
        # names. Treat a missing attribute the same as None.
        decoder = getattr(self, 'decoder', None)
        if decoder is not None:
            decoder.load_state_dict(self._strip_module_prefix(params['decoder']))

        asr_decoder = getattr(self, 'asr_decoder', None)
        if asr_decoder is not None:
            asr_decoder.load_state_dict(
                self._strip_module_prefix(params['predictor']), strict=False
            )

    def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
        """Encode and quantize a batch of waveforms.

        Args:
            waves_16k: Batch of waveforms, indexed as ``[batch, sample]``. The
                feature extractor is told the sampling rate is 16000.
            wave_16k_lens: Per-item number of valid samples in ``waves_16k``.
            ssl_model: External SSL model, used only when this instance has no
                SSL model of its own (``skip_ssl=True``).

        Returns:
            Tuple ``(x_quantized, indices, feature_lens)``: the first two
            outputs of the quantizer, and the per-item feature lengths.
        """
        ssl_fn = self.ssl_model if self.ssl_model is not None else ssl_model
        assert ssl_fn is not None, "In case in-class SSL model loading is skipped, external ssl_model must be provided"

        # The feature extractor takes a list of unpadded numpy arrays, so trim
        # each waveform to its valid length first.
        trimmed_waves = [
            wave[:wave_16k_lens[idx]].cpu().numpy()
            for idx, wave in enumerate(waves_16k)
        ]
        alt_inputs = self.ssl_feature_extractor(
            trimmed_waves,
            return_tensors='pt',
            return_attention_mask=True,
            padding=True,
            sampling_rate=16000,
        ).to(waves_16k.device)

        # Valid samples per item divided by 320; presumably 320 is the SSL
        # model's samples-per-frame downsampling factor (TODO confirm).
        feature_lens = alt_inputs.data['attention_mask'].sum(-1) // 320

        outputs = ssl_fn(
            alt_inputs.input_values,
            attention_mask=alt_inputs.attention_mask,
        )
        hidden = outputs.last_hidden_state
        # Drop frames beyond the longest item, then clamp lengths in case the
        # model produced fewer frames than the estimate above.
        hidden = hidden[:, :feature_lens.max(), :]
        feature_lens = feature_lens.clamp(max=hidden.size(1))

        # The encoder is fed with dims 1 and 2 swapped, and its output is
        # swapped back before quantization.
        x_hidden = self.encoder(hidden.transpose(1, 2), feature_lens)
        x_hidden = x_hidden.transpose(1, 2)
        x_quantized, indices = self.quantizer(x_hidden)[:2]
        return x_quantized, indices, feature_lens