File size: 5,018 Bytes
e1c2051
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import functools
from typing import Sequence

import torch
import torchaudio
from torch import nn


class RFFTimeDomainImplicitFilter(nn.Module):
    """Implicit time-domain filter bank parameterized by an MLP.

    A small MLP maps normalized time (optionally lifted through random
    Fourier features, RFFs) to the discrete-time impulse responses of
    ``n_filters`` filters, i.e. to the weights of a convolutional layer.
    """

    def __init__(
        self,
        n_filters: int,
        init_kernel_size: int,
        init_sample_rate: int,
        ch_list: Sequence[int] = (32, 32),
        n_RFFs: int = 32,
        nonlinearity: str = "relu",
        train_RFF: bool = False,
        use_layer_norm: bool = False,
    ) -> None:
        """n_filters: Number of filters.
        init_kernel_size: Initial kernel size (used by the oversampling path).
        init_sample_rate: Initial sample rate (used by the oversampling path).
        ch_list: Channel list of the hidden MLP layers.
        n_RFFs: Number of RFFs. If n_RFFs <= 0, do not use random Fourier feature inputs (i.e., directly input normalized time).
        nonlinearity (str): Nonlinearity name; only "relu" is supported.
        train_RFF (bool): If True, the RFF frequencies are trainable.
        use_layer_norm (bool): If True, insert a layer norm (GroupNorm with one group) after each hidden conv.

        Raises:
            NotImplementedError: If ``nonlinearity`` is not "relu".
        """
        super().__init__()
        self.n_filters = n_filters
        # Stored as float buffers so they move with the module across devices
        # and are saved in the state dict.
        self.register_buffer("init_kernel_size", torch.tensor(init_kernel_size).float())
        self.register_buffer("init_sample_rate", torch.tensor(init_sample_rate).float())

        # nonlinearity
        if nonlinearity == "relu":
            NonlinearityClass = functools.partial(nn.ReLU, inplace=True)
        else:
            raise NotImplementedError

        # MLP: 1x1 convolutions over the time axis act as a per-sample MLP.
        # Input width is n_RFFs*2 (sin and cos features) or 1 (raw time).
        layers = []
        in_ch_list = [n_RFFs * 2 if n_RFFs > 0 else 1, *ch_list]
        out_ch_list = [*ch_list, n_filters]
        for i, (in_ch, out_ch) in enumerate(zip(in_ch_list, out_ch_list)):
            layers.append(nn.Conv1d(in_ch, out_ch, 1))
            # No norm/nonlinearity after the final (output) layer.
            if i < len(in_ch_list) - 1:
                if use_layer_norm:
                    layers.append(nn.GroupNorm(1, out_ch))
                layers.append(NonlinearityClass())
        self.implicit_filter = nn.Sequential(*layers)

        if n_RFFs > 0:
            # Random Fourier feature frequencies, drawn from
            # N(0, (2*pi*10)^2); optionally trainable.
            self.RFF_param = nn.Parameter(
                torch.zeros((n_RFFs,), dtype=torch.float).normal_(
                    0.0,
                    2.0 * torch.pi * 10.0,
                ),
                requires_grad=train_RFF,
            )
        else:
            self.RFF_param = None

        def set_zero_bias(m) -> None:
            # Every conv in the MLP must have a (zero-initialized) bias.
            if isinstance(m, nn.Conv1d):
                if m.bias is None:
                    msg = "bias cannot be none"
                    raise ValueError(msg)
                m.bias.data.fill_(0.0)

        self.implicit_filter.apply(set_zero_bias)

    @property
    def device(self):
        """Device of the module, read off the first MLP layer's weight."""
        return self.implicit_filter[0].weight.device

    def _get_ir(self, normalized_time):
        """Return discrete-time impulse responses (corresponding to the weights of the convolutional layer) from MLP.

        Args:
            normalized_time (torch.Tensor): Normalized time (time).

        Return:
            torch.Tensor: Discrete-time impulse responses (n_filters x time)

        """
        if self.RFF_param is not None:
            RFF = self.RFF_param[:, None] @ normalized_time[None, :]  # n_RFFs x time
            RFF = torch.cat((RFF.sin(), RFF.cos()), dim=0)  # n_RFFs*2 x time
            ir = self.implicit_filter(RFF[None, :, :])  # 1 x n_filters x time
        else:
            # No RFF lifting: feed raw normalized time as a single channel.
            ir = self.implicit_filter(
                normalized_time[None, None, :],
            )  # 1 x n_filters x time
        # Drop the singleton batch dimension -> n_filters x time.
        return ir.view(*(ir.shape[1:]))

    def get_impulse_responses(self, sample_rate: int, kernel_size):
        """Calculating discrete-time impulse responses (corresponding to the weights of the convolutional layer) from MLP."""
        # Oversampling is opt-in via an externally set `use_oversampling`
        # attribute, and only honored at eval time.
        use_oversampling = False
        if not self.training and hasattr(self, "use_oversampling"):
            use_oversampling = self.use_oversampling

        if use_oversampling:
            ir = self.get_impulse_responses_oversampling(sample_rate)
        else:
            normalized_time = torch.linspace(
                -1.0,
                1.0,
                kernel_size,
                device=self.device,
                requires_grad=False,
            )  # time
            ir = self._get_ir(normalized_time)
        return ir

    def get_impulse_responses_oversampling(self, sample_rate: int):
        """Calculating discrete-time impulse responses (corresponding to the weights of the convolutional layer) from MLP with oversampling for anti-aliasing.
        First, calculate the discrete-time impulse responses with the trained sample rate.
        Then, resample the calculated discrete-time impulse responses at the input sample rate.
        """
        normalized_time = torch.linspace(
            -1.0,
            1.0,
            # init_kernel_size is a float buffer; linspace requires an int
            # step count, so cast explicitly (was a TypeError before).
            int(self.init_kernel_size.item()),
            device=self.device,
            requires_grad=False,
        )  # time
        ir = self._get_ir(normalized_time)
        resampled_ir = torchaudio.functional.resample(
            ir,
            int(self.init_sample_rate.item()),
            int(sample_rate),
        )  # resampling
        return resampled_ir.float().to(self.device)