"""HuBERT feature extractor: a thin wrapper around a fairseq HuBERT checkpoint
that runs frozen in fp16 and returns frame-level features (one frame per 20 ms
of 16 kHz audio)."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from fairseq import checkpoint_utils


def get_mask_from_lengths(lengths, max_len=None):
    """Return a (batch, max_len) bool mask that is True at padded positions."""
    max_len = torch.max(lengths).item() if max_len is None else max_len
    ids = torch.arange(0, max_len, device=lengths.device)
    mask = ~(ids < lengths.unsqueeze(1))
    return mask
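
# Worked example: get_mask_from_lengths(torch.tensor([3, 5])) builds
# ids = [0, 1, 2, 3, 4] and marks the padded tail of each sequence:
#   [[False, False, False,  True,  True],
#    [False, False, False, False, False]]
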

class HuBERT(nn.Module):
    """Frozen fp16 HuBERT feature extractor loaded from a fairseq checkpoint."""

    def __init__(self, model_path, sampling_rate=16000):
        super().__init__()
        models, saved_cfg, _ = checkpoint_utils.load_model_ensemble_and_task(
            [model_path],
            suffix="",
        )
        model = models[0]
        model = model.half()
        model.eval()
        self.model = model
        # Freeze everything; this module is used only for feature extraction.
        for param in self.parameters():
            param.requires_grad = False
        self.sampling_rate = sampling_rate
        # Whether the checkpoint expects per-utterance layer-norm of the waveform.
        self.normalize = saved_cfg.task.normalize

    @torch.no_grad()
    # Disable any ambient autocast; the model is already cast to fp16 explicitly.
    @torch.cuda.amp.autocast(enabled=False, dtype=torch.float16)
    def forward(self, inp, length=None, split=True, split_size=4):
        self.model.eval()
        if self.training and split:
            # Process the batch in `split_size` chunks to bound peak memory.
            chunk = int(math.ceil(inp.shape[0] / split_size))
            outs, out_lens = [], []
            for i in range(0, inp.shape[0], chunk):
                inp_, length_ = inp[i : i + chunk], length[i : i + chunk]
                out_, out_len_ = self._extract(inp_, length_)
                outs.append(out_)
                out_lens.append(out_len_)
            # Pad every chunk to the longest frame count before concatenating.
            max_length = int(max(ol.max() for ol in out_lens))
            return torch.cat(
                [F.pad(o, (0, 0, 0, max_length - o.shape[1]), value=0) for o in outs],
                dim=0,
            ), torch.cat(out_lens, dim=0)
        else:
            return self._extract(inp, length)

    @torch.no_grad()
    def _extract(self, inp, length):
        # HuBERT emits one frame per 20 ms of audio (320 samples at 16 kHz).
        frame_samples = int(self.sampling_rate * 0.02)
        device = inp.device
        if len(inp.shape) == 3:
            inp = inp.squeeze(1) if inp.shape[1] == 1 else inp.squeeze(2)
        assert len(inp.shape) == 2
        assert self.sampling_rate == 16000
        if length is None:
            # Assume every item spans the full waveform when no lengths are given.
            length = torch.full(
                (inp.shape[0],), inp.shape[1], dtype=torch.long, device=device
            )
        feats = inp
        # Pad with zeros; longer than the conv frontend's receptive field.
        padding_size = 3200
        feats = F.pad(feats, (0, padding_size), mode="constant", value=0)
        # Normalize volume with per-utterance layer norm if the checkpoint expects it.
        feats = self._postprocess(
            feats, length + padding_size, normalize=self.normalize
        )
        # Convert sample lengths to frame lengths and build a frame-level padding mask.
        length = torch.ceil(length / frame_samples).int()
        padding_mask = get_mask_from_lengths(length).bool()
        # Extend the mask so it matches the frame count of the zero-padded input.
        padding_mask = F.pad(padding_mask, (0, 9), value=True)
        inputs = {
            "source": feats.half().to(device),
            "padding_mask": padding_mask.to(device),
            "mask": False,
        }
        logits, _ = self.model.extract_features(**inputs)
        # Drop the frames produced by the trailing zero padding.
        logits = logits[:, : length.max()].float()
        return logits, length

    def _postprocess(self, feats, lengths, normalize=False):
        """Optionally layer-normalize each waveform over its valid samples."""
        assert feats.dim() == 2, feats.dim()
        if normalize:
            with torch.no_grad():
                feats = [
                    F.layer_norm(feat[:length], feat[:length].shape)
                    for feat, length in zip(feats, lengths)
                ]
                feats = pad_sequence(feats, batch_first=True, padding_value=0)
        return feats
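

# --- Minimal usage sketch (illustration only, not part of the original module) ---
# Assumptions: the checkpoint path below is hypothetical, and a CUDA device is
# assumed because the model is cast to fp16 in __init__.
if __name__ == "__main__":
    hubert = HuBERT("/path/to/hubert_checkpoint.pt").cuda()  # hypothetical path
    # Two fake 16 kHz waveforms of different lengths, padded into one batch.
    waves = [torch.randn(16000), torch.randn(24000)]
    lengths = torch.tensor([w.shape[0] for w in waves]).cuda()
    batch = pad_sequence(waves, batch_first=True).cuda()
    feats, feat_lens = hubert(batch, lengths)
    # feats: (batch, frames, hidden); feat_lens: valid frame count per item.
    print(feats.shape, feat_lens)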