# rd_hubert / distiller_model.py
# Uploaded by Hguimaraes ("Upload model", commit 8ff90c3, verified)
import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel

from .configuration_distiller import DistillerConfig
from .modeling_distiller import DistillerModel
class RDDistillerModel(PreTrainedModel):
    """HF-compatible wrapper around a ``DistillerModel`` speech encoder.

    Accepts raw waveforms (optionally variable-length) and returns hidden
    states in the ``transformers`` output-dict convention
    (``last_hidden_state`` / ``hidden_states``).
    """

    config_class = DistillerConfig

    def __init__(self, config: DistillerConfig):
        super().__init__(config)
        self.model = DistillerModel(config)
        self.post_init()

    def prepare_input_data(
        self,
        wavs: torch.Tensor,
        wav_lens: torch.Tensor = None,
    ):
        """Normalize raw waveform input into a padded batch plus padding mask.

        Args:
            wavs: one of
                - a list of 1-D tensors (variable lengths) — right-padded here;
                - a 1-D tensor ``(T,)`` (single waveform) — promoted to a batch of 1;
                - a 2-D tensor ``(batch, seq_len)``, already padded.
            wav_lens: per-example valid lengths; inferred from ``wavs`` when omitted.

        Returns:
            ``(wavs, pad_mask)``: float32 tensors of shape ``(batch, seq_len)``;
            ``pad_mask`` is 1.0 over valid samples and 0.0 over padding.

        Raises:
            ValueError: if a tensor input has more than 2 dimensions.
            TypeError: if ``wavs`` is neither a list nor a ``torch.Tensor``.
        """
        if isinstance(wavs, list):
            # Variable-length waveforms: record true lengths, then right-pad.
            wav_lens = [len(wave) for wave in wavs]
            wavs = pad_sequence(wavs, batch_first=True)
        elif isinstance(wavs, torch.Tensor):
            # Promote a single waveform (T,) to a batch (1, T) *before*
            # inferring lengths: iterating a 1-D tensor yields 0-dim
            # scalars, whose .shape[0] would raise an IndexError.
            if wavs.dim() == 1:
                wavs = wavs.unsqueeze(0)
            elif wavs.dim() > 2:
                raise ValueError(
                    f"Expected waveform input with 1 or 2 dims, got {wavs.dim()}"
                )
            if wav_lens is None:
                wav_lens = [wav.shape[0] for wav in wavs]
        else:
            raise TypeError(
                f"wavs must be a list of tensors or a torch.Tensor, got {type(wavs)}"
            )

        batch_size, seq_len = wavs.shape
        # 1.0 over each sequence's valid prefix, 0.0 over its padding tail.
        pad_mask = torch.ones(
            batch_size, seq_len, dtype=torch.float32, device=wavs.device
        )
        for idx in range(batch_size):
            pad_mask[idx, wav_lens[idx]:] = 0.0

        wavs = wavs.to(dtype=torch.float32)  # (batch_size, seq_len)
        return wavs, pad_mask

    def forward(
        self,
        wavs: torch.Tensor,
        wav_lens: torch.Tensor = None,
    ):
        """Encode waveforms and expose all intermediate representations.

        Args:
            wavs: raw waveform input (see ``prepare_input_data``).
            wav_lens: per-example valid lengths; inferred when omitted.

        Returns:
            dict with ``last_hidden_state`` (final entry of ``hidden_states``)
            and ``hidden_states`` (list: final feature, per-layer hiddens,
            then the prediction-head outputs split per layer).
        """
        wavs, pad_mask = self.prepare_input_data(wavs, wav_lens)
        _, feat_final, pred, _, layer_hidden = self.model(
            wavs, pad_mask, get_hidden=True, no_pred=False
        )
        # Split the stacked predictions into one tensor per predicted layer.
        # NOTE(review): presumably pred is (batch, n_layers, ...) — the
        # transpose moves the layer axis first before splitting; confirm
        # against DistillerModel's output contract.
        hidden_feats = [
            hid.squeeze(0) for hid in pred.transpose(0, 1).split(1, 0)
        ]
        hidden_feats = [feat_final] + layer_hidden + hidden_feats
        return {
            "last_hidden_state": hidden_feats[-1],
            "hidden_states": hidden_feats,
        }