File size: 3,510 Bytes
faadabf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from fairseq import checkpoint_utils
from torch.nn.utils.rnn import pad_sequence

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Args:
        lengths: int tensor of shape (batch,) with the valid length of
            each sequence.
        max_len: width of the mask; defaults to ``lengths.max()``.

    Returns:
        Bool tensor of shape (batch, max_len) that is True at padded
        positions (column index >= length) and False at valid ones.
    """
    if max_len is None:
        max_len = int(torch.max(lengths).item())
    # Allocate the index row directly on the target device instead of
    # creating it on CPU and transferring it with .to().
    ids = torch.arange(max_len, device=lengths.device)
    # The comparison already yields a bool tensor; no extra cast needed.
    return ids.unsqueeze(0) >= lengths.unsqueeze(1)


class HuBERT(nn.Module):
    """Frozen HuBERT feature extractor wrapping a fairseq checkpoint.

    Loads a pretrained model, converts it to fp16, freezes every
    parameter, and exposes frame-level features through ``forward``.
    The module is inference-only: gradients are disabled both via
    ``requires_grad = False`` and ``@torch.no_grad()``.
    """

    def __init__(self, model_path: str, sampling_rate: int = 16000):
        """Load the checkpoint at *model_path* and freeze the model.

        Args:
            model_path: path to a fairseq HuBERT checkpoint file.
            sampling_rate: expected input sample rate; ``_extract``
                asserts this is 16000.
        """

        super().__init__()

        models, saved_cfg, _ = checkpoint_utils.load_model_ensemble_and_task(
            [model_path],
            suffix="",
        )

        model = models[0]
        # fp16 inference only; parameters are frozen below, so no
        # mixed-precision training concerns apply.
        model = model.half()
        model.eval()
        self.model = model

        for param in self.parameters():
            param.requires_grad = False

        self.sampling_rate = sampling_rate
        # Whether the checkpoint's task config requests per-utterance
        # layer-norm of the raw waveform (applied in _postprocess).
        self.normalize = saved_cfg.task.normalize

    @torch.no_grad()
    @torch.cuda.amp.autocast(enabled=False, dtype=torch.float16)  # autocast explicitly disabled; fp16 is handled manually
    def forward(self, inp, length=None, split=True, split_size=4):
        """Extract features, optionally processing the batch in chunks.

        Args:
            inp: waveform batch; (batch, samples) or (batch, 1, samples)
                / (batch, samples, 1) — 3-D inputs are squeezed in
                ``_extract``.
            length: int tensor of per-utterance sample counts.
                NOTE(review): the splitting branch and ``_extract`` index
                and add to ``length`` unconditionally, so ``None`` would
                raise there — confirm whether the default is ever used.
            split: if True (and the module is in training mode), run the
                batch through ``_extract`` in chunks to bound peak memory.
            split_size: NOTE(review): this argument is immediately
                overwritten below with ceil(batch / 4) — the hard-coded 4
                (number of chunks) shadows the parameter; verify intent.

        Returns:
            Tuple ``(features, frame_lengths)`` — features padded to the
            longest chunk length when splitting.
        """
        self.model.eval()
        # self.training reflects the *wrapper's* train/eval mode (set by a
        # parent module), not self.model, which is forced to eval above.
        if self.training and split:
            split_size = int(math.ceil(inp.shape[0] / 4))
            outs, out_lens = [], []
            for i in range(0, inp.shape[0], split_size):
                inp_, length_ = inp[i : i + split_size], length[i : i + split_size]
                out_, out_len_ = self._extract(inp_, length_)
                outs.append(out_)
                out_lens.append(out_len_)
            # Longest frame count across all chunks; each chunk was only
            # padded to its own max, so re-pad before concatenating.
            max_length = max([max(ols) for ols in out_lens])

            return torch.cat(
                [F.pad(o, (0, 0, 0, max_length - o.shape[1]), value=0) for o in outs],
                dim=0,
            ), torch.cat(out_lens, dim=0)
        else:
            return self._extract(inp, length)

    @torch.no_grad()
    def _extract(self, inp, length):
        """Run the model on one (sub-)batch and return (features, frame_lengths)."""
        # 320 samples per 20 ms frame at 16 kHz.
        # NOTE(review): this variable is never used — the /320 below is
        # hard-coded; consider removing or using it.
        frame_samples = int(self.sampling_rate * 0.02)
        device = inp.device

        # Accept (B, 1, T) or (B, T, 1) and squeeze to (B, T).
        if len(inp.shape) == 3:
            inp = inp.squeeze(1) if inp.shape[1] == 1 else inp.squeeze(2)
        assert len(inp.shape) == 2
        assert self.sampling_rate == 16000

        feats = inp

        # Padding with 0
        padding_size = 3200  # Longer to cover receptive field
        feats = F.pad(feats, (0, padding_size), mode="constant", value=0)

        # Norm volume using LN
        # NOTE(review): requires length to be a tensor — would raise for
        # length=None despite the branch below.
        feats = self._postprocess(
            feats, length + padding_size, normalize=self.normalize
        )

        if length is None:
            # NOTE(review): apparently dead branch — length is used
            # unconditionally above and via length.max() below.
            padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        else:
            # Convert sample lengths to frame lengths (320 samples/frame).
            length = torch.ceil(length / 320).int()
            padding_mask = get_mask_from_lengths(length).bool()
            # Presumably accounts for the 3200-sample (≈10-frame) padding
            # added above — TODO confirm the 9 vs 10 frame offset.
            padding_mask = F.pad(padding_mask, (0, 9), value=True)

        inputs = {
            "source": feats.half().to(device),
            "padding_mask": padding_mask.to(device),
            "mask": False,  # no SpecAugment-style masking at inference
        }
        logits, _ = self.model.extract_features(**inputs)
        # Trim padding frames and promote back to fp32 for downstream use.
        logits = logits[:, : length.max()].float()

        return logits, length

    def _postprocess(self, feats, lengths, normalize: bool = False):
        """Optionally layer-norm each waveform over its valid samples only.

        Args:
            feats: (batch, samples) waveform batch (possibly padded).
            lengths: per-utterance valid sample counts.
            normalize: apply per-utterance layer-norm when True;
                otherwise return *feats* unchanged.
        """
        assert feats.dim() == 2, feats.dim()

        if normalize:
            with torch.no_grad():
                # Normalize each utterance over its own valid span so the
                # zero padding does not skew the statistics, then re-pad.
                feats = [
                    F.layer_norm(feat[:length], feat[:length].shape)
                    for feat, length in zip(feats, lengths)
                ]
                feats = pad_sequence(feats, batch_first=True, padding_value=0)
        return feats