import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel

from .configuration_distiller import DistillerConfig
from .modeling_distiller import DistillerModel
|
|
|
|
class RDDistillerModel(PreTrainedModel):
    """HuggingFace wrapper around :class:`DistillerModel`.

    Accepts raw waveforms (a list of 1-D tensors, a single 1-D tensor, or a
    padded ``(batch, seq_len)`` tensor), builds the padding mask the inner
    model expects, and exposes the model's hidden states in the standard
    ``transformers`` output-dict shape.
    """

    config_class = DistillerConfig

    def __init__(self, config: DistillerConfig):
        super().__init__(config)
        self.model = DistillerModel(config)
        # Standard HF hook: weight init / weight tying after construction.
        self.post_init()

    def prepare_input_data(
        self,
        wavs: torch.Tensor,
        wav_lens: torch.Tensor = None,
    ):
        """Normalize waveform input into a padded float batch plus a pad mask.

        Args:
            wavs: list of 1-D tensors (variable length), a single 1-D tensor,
                or an already-padded ``(batch, seq_len)`` tensor.
            wav_lens: per-example true lengths. If ``None``, each waveform is
                assumed to span the full padded width.

        Returns:
            Tuple ``(wavs, pad_mask)`` — both ``float32``, ``pad_mask`` is 1
            over valid samples and 0 over padding, on ``wavs.device``.

        Raises:
            ValueError: if ``wavs`` has rank greater than 2.
        """
        if isinstance(wavs, list):
            # Variable-length inputs: record true lengths, then zero-pad
            # into a single (batch, max_len) tensor.
            wav_lens = [len(wave) for wave in wavs]
            wavs = pad_sequence(wavs, batch_first=True)

        if wavs.dim() == 1:
            # Single waveform -> (1, seq_len) batch.
            wavs = wavs.unsqueeze(0)
        elif wavs.dim() > 2:
            raise ValueError(
                f"Expected waveforms of rank 1 or 2, got rank {wavs.dim()}"
            )

        if wav_lens is None:
            # Lengths must be derived AFTER batching: a 1-D tensor yields
            # 0-dim elements whose .shape[0] would raise IndexError.
            wav_lens = [wav.shape[0] for wav in wavs]

        batch_size, seq_len = wavs.shape

        # Mask starts all-valid; zero out positions past each true length.
        pad_mask = np.ones((batch_size, seq_len))
        for idx in range(batch_size):
            pad_mask[idx, wav_lens[idx]:] = 0

        wavs = wavs.to(dtype=torch.float32)
        pad_mask = torch.FloatTensor(pad_mask).to(
            device=wavs.device, dtype=torch.float32
        )
        return wavs, pad_mask

    def forward(
        self,
        wavs: torch.Tensor,
        wav_lens: torch.Tensor = None,
    ):
        """Run the distiller and collect all hidden representations.

        Returns a dict with ``last_hidden_state`` and ``hidden_states``
        (final feature, per-layer hiddens, then per-head predictions).
        """
        wavs, pad_mask = self.prepare_input_data(wavs, wav_lens)
        _, feat_final, pred, _, layer_hidden = self.model(
            wavs, pad_mask, get_hidden=True, no_pred=False
        )

        # pred is (heads, batch, ...); split into one tensor per head.
        hidden_feats = pred.transpose(0, 1).split(1, 0)
        hidden_feats = [hid.squeeze(0) for hid in hidden_feats]
        hidden_feats = [feat_final] + layer_hidden + hidden_feats

        return {
            "last_hidden_state": hidden_feats[-1],
            "hidden_states": hidden_feats,
        }
|
|