import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel

from .modeling_distiller import DistillerModel
from .configuration_distiller import DistillerConfig
| |
|
| |
|
class RDDistillerModel(PreTrainedModel):
    """HuggingFace-style wrapper around ``DistillerModel``.

    Normalizes raw waveform input into a padded float32 batch plus a
    padding mask, runs the underlying distiller network, and exposes its
    per-layer features as ``hidden_states`` in a dict, mirroring the
    usual transformers model-output shape.
    """

    config_class = DistillerConfig

    def __init__(self, config: DistillerConfig):
        super().__init__(config)
        # The wrapped distiller network; all computation happens there.
        self.model = DistillerModel(config)
        # Standard transformers hook (weight init / tying).
        self.post_init()

    def prepare_input_data(
        self,
        wavs: torch.Tensor,
        wav_lens: torch.Tensor = None,
    ):
        """Coerce waveform input into a padded batch and a padding mask.

        Args:
            wavs: one of
                - a list of 1-D tensors of varying length (padded here),
                - a 1-D tensor (a single waveform, promoted to batch of 1),
                - a 2-D ``(batch, seq)`` tensor.
            wav_lens: optional per-utterance valid lengths. When omitted,
                lengths are taken from the list elements, or every row of a
                tensor input is treated as fully valid.

        Returns:
            Tuple ``(wavs, pad_mask)``: float32 ``(batch, seq)`` waveforms
            and a float32 ``(batch, seq)`` mask on the same device with 1.0
            on valid samples and 0.0 on padding.

        Raises:
            ValueError: if a tensor input has more than two dimensions.
        """
        if isinstance(wavs, list):
            # Variable-length utterances: record true lengths, then pad
            # into a single (batch, seq) tensor.
            wav_lens = [len(wave) for wave in wavs]
            wavs = pad_sequence(wavs, batch_first=True)

        if wavs.dim() == 1:
            # Single waveform: promote to a batch of one.
            wavs = wavs.unsqueeze(0)
        elif wavs.dim() > 2:
            raise ValueError(
                f"Expected 1-D or 2-D waveform input, got {wavs.dim()}-D"
            )

        if wav_lens is None:
            # No lengths supplied for a tensor input: every row is fully
            # valid. Computed after the unsqueeze above so a single 1-D
            # waveform is handled correctly too.
            wav_lens = [wav.shape[0] for wav in wavs]

        batch_size, seq_len = wavs.shape

        # Start fully valid, then zero out positions past each
        # utterance's true length.
        pad_mask = np.ones((batch_size, seq_len))
        for idx in range(batch_size):
            pad_mask[idx, wav_lens[idx] :] = 0

        wavs = wavs.to(dtype=torch.float32)
        pad_mask = torch.FloatTensor(pad_mask).to(
            device=wavs.device, dtype=torch.float32
        )
        return wavs, pad_mask

    def forward(
        self,
        wavs: torch.Tensor,
        wav_lens: torch.Tensor = None,
    ):
        """Run the distiller and return its features as hidden states.

        Args:
            wavs: waveform input accepted by :meth:`prepare_input_data`.
            wav_lens: optional per-utterance lengths (see
                :meth:`prepare_input_data`).

        Returns:
            Dict with:
                - ``"hidden_states"``: ``[feat_final] + layer_hidden`` plus
                  the per-layer prediction features split along dim 0.
                - ``"last_hidden_state"``: the final entry of that list.
        """
        wavs, pad_mask = self.prepare_input_data(wavs, wav_lens)
        _, feat_final, pred, _, layer_hidden = self.model(
            wavs, pad_mask, get_hidden=True, no_pred=False
        )

        # pred: (num_layers, batch, ...) after transpose(0, 1) it is
        # split into one (1, ...) chunk per layer, then squeezed.
        hidden_feats = pred.transpose(0, 1).split(1, 0)
        hidden_feats = [hid.squeeze(0) for hid in hidden_feats]
        hidden_feats = [feat_final] + layer_hidden + hidden_feats

        return {
            "last_hidden_state": hidden_feats[-1],
            "hidden_states": hidden_feats,
        }
| |
|