File size: 34,562 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Tuple

import numpy as np
import torch

from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like
from nemo.collections.asr.parts.utils.audio_utils import db2mag, wrap_to_pi
from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType
from nemo.utils import logging
from nemo.utils.decorators import experimental

try:
    import torchaudio

    HAVE_TORCHAUDIO = True
except ModuleNotFoundError:
    HAVE_TORCHAUDIO = False


# Public names exported by `from <this module> import *`.
# NOTE(review): SpectrogramToMultichannelFeatures and WPEFilter are defined in this
# module but not listed here — confirm whether that is intentional.
__all__ = [
    'MaskEstimatorRNN',
    'MaskReferenceChannel',
    'MaskBasedBeamformer',
    'MaskBasedDereverbWPE',
]


@experimental
class SpectrogramToMultichannelFeatures(NeuralModule):
    """Convert a complex-valued multi-channel spectrogram to
    multichannel features.

    Features consist of magnitudes (optionally reduced across channels)
    and, if enabled, inter-channel phase differences (IPDs) concatenated
    along the feature dimension.

    Args:
        num_subbands: Expected number of subbands in the input signal
        num_input_channels: Optional, provides the number of channels
                            of the input signal. Used to infer the number
                            of output channels.
        mag_reduction: Reduction of magnitude features across channels.
                       One of `None` (magnitude of each channel is kept),
                       `abs_mean`, `mean_abs` or `rms`. Default: `rms`
        use_ipd: Use inter-channel phase difference (IPD).
        mag_normalization: Normalization for magnitude features.
                           Not implemented yet, must be `None`.
        ipd_normalization: Normalization for IPD features.
                           Not implemented yet, must be `None`.

    Raises:
        NotImplementedError: If `mag_normalization` or `ipd_normalization` is not `None`.
        ValueError: If `mag_reduction` is not one of the supported reductions.
    """

    # Supported values for `mag_reduction`
    _MAG_REDUCTIONS = (None, 'abs_mean', 'mean_abs', 'rms')

    def __init__(
        self,
        num_subbands: int,
        num_input_channels: Optional[int] = None,
        mag_reduction: Optional[str] = 'rms',
        use_ipd: bool = False,
        mag_normalization: Optional[str] = None,
        ipd_normalization: Optional[str] = None,
    ):
        super().__init__()
        self.mag_reduction = mag_reduction
        self.use_ipd = use_ipd

        # TODO: normalization
        if mag_normalization is not None:
            raise NotImplementedError(f'Unknown magnitude normalization {mag_normalization}')
        self.mag_normalization = mag_normalization

        if ipd_normalization is not None:
            raise NotImplementedError(f'Unknown ipd normalization {ipd_normalization}')
        self.ipd_normalization = ipd_normalization

        # Fail at construction time instead of on the first forward pass,
        # so that `num_channels` is never derived from an invalid configuration
        if self.mag_reduction not in self._MAG_REDUCTIONS:
            raise ValueError(f'Unexpected magnitude reduction {self.mag_reduction}')

        if self.use_ipd:
            # Magnitudes are expanded to all channels and concatenated with per-channel IPDs
            self._num_features = 2 * num_subbands
            self._num_channels = num_input_channels
        else:
            self._num_features = num_subbands
            # A channel reduction collapses the output to a single channel
            self._num_channels = num_input_channels if self.mag_reduction is None else 1

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @property
    def num_features(self) -> int:
        """Configured number of features
        """
        return self._num_features

    @property
    def num_channels(self) -> int:
        """Configured number of channels

        Raises:
            ValueError: If `num_input_channels` was not provided and the
                        number of channels cannot be inferred.
        """
        if self._num_channels is not None:
            return self._num_channels
        else:
            raise ValueError(
                'Num channels is not configured. To configure this, `num_input_channels` '
                'must be provided when constructing the object.'
            )

    @typecheck()
    def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert input batch of C-channel spectrograms into
        a batch of time-frequency features with dimension num_feat.
        The output number of channels may be the same as input, or
        reduced to 1, e.g., if averaging over magnitude and not appending individual IPDs.

        Args:
            input: Spectrogram for C channels with F subbands and N time frames, (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Tuple of features and their length. Features have num_feat_channels
            channels with num_feat features, shape (B, num_feat_channels, num_feat, N).
            The length is passed through unchanged, shape (B,).

        Raises:
            ValueError: If `mag_reduction` is not supported.
            RuntimeError: If the number of output channels differs from the configured one.
        """
        # Magnitude spectrum
        if self.mag_reduction is None:
            mag = torch.abs(input)
        elif self.mag_reduction == 'abs_mean':
            mag = torch.abs(torch.mean(input, dim=1, keepdim=True))
        elif self.mag_reduction == 'mean_abs':
            mag = torch.mean(torch.abs(input), dim=1, keepdim=True)
        elif self.mag_reduction == 'rms':
            mag = torch.sqrt(torch.mean(torch.abs(input) ** 2, dim=1, keepdim=True))
        else:
            # Defensive: `mag_reduction` may have been changed after construction
            raise ValueError(f'Unexpected magnitude reduction {self.mag_reduction}')

        if self.mag_normalization is not None:
            mag = self.mag_normalization(mag)

        features = mag

        if self.use_ipd:
            # Calculate IPD relative to average spec
            spec_mean = torch.mean(input, dim=1, keepdim=True)
            ipd = torch.angle(input) - torch.angle(spec_mean)
            # Modulo to [-pi, pi]
            ipd = wrap_to_pi(ipd)

            if self.ipd_normalization is not None:
                ipd = self.ipd_normalization(ipd)

            # Concatenate to existing features along the feature dimension;
            # a reduced (single-channel) magnitude is broadcast to all channels
            features = torch.cat([features.expand(ipd.shape), ipd], dim=2)

        if self._num_channels is not None and features.size(1) != self._num_channels:
            raise RuntimeError(
                f'Number of channels in features {features.size(1)} is different than the configured number of channels {self._num_channels}'
            )

        return features, input_length


class MaskEstimatorRNN(NeuralModule):
    """Estimate `num_outputs` masks from the input spectrogram
    using stacked RNNs and projections.

    The module is structured as follows:
        input --> spatial features --> input projection -->
            --> stacked RNNs --> output projection for each output --> sigmoid

    Reference:
        Multi-microphone neural speech separation for far-field multi-talker
        speech recognition (https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8462081)

    Args:
        num_outputs: Number of output masks to estimate
        num_subbands: Number of subbands of the input spectrogram
        num_features: Number of features after the input projections
        num_layers: Number of RNN layers
        num_hidden_features: Number of hidden features in RNN layers. Defaults to `num_features`.
        num_input_channels: Number of input channels
        dropout: If non-zero, introduces dropout on the outputs of each RNN layer except the last layer, with dropout
                 probability equal to `dropout`. Default: 0
        bidirectional: If `True`, use bidirectional RNN.
        rnn_type: Type of RNN, either `lstm` or `gru`. Default: `lstm`
        mag_reduction: Channel-wise reduction for magnitude features
        use_ipd: Use inter-channel phase difference (IPD) features

    Raises:
        ValueError: If `rnn_type` is not `lstm` or `gru`.
    """

    def __init__(
        self,
        num_outputs: int,
        num_subbands: int,
        num_features: int = 1024,
        num_layers: int = 3,
        num_hidden_features: Optional[int] = None,
        num_input_channels: Optional[int] = None,
        dropout: float = 0,
        bidirectional: bool = True,
        rnn_type: str = 'lstm',
        mag_reduction: str = 'rms',
        use_ipd: Optional[bool] = None,
    ):
        super().__init__()
        if num_hidden_features is None:
            num_hidden_features = num_features

        self.features = SpectrogramToMultichannelFeatures(
            num_subbands=num_subbands,
            num_input_channels=num_input_channels,
            mag_reduction=mag_reduction,
            use_ipd=use_ipd,
        )

        # Project the flattened (channels x features) input to `num_features`
        self.input_projection = torch.nn.Linear(
            in_features=self.features.num_features * self.features.num_channels, out_features=num_features
        )

        if rnn_type == 'lstm':
            rnn_class = torch.nn.LSTM
        elif rnn_type == 'gru':
            rnn_class = torch.nn.GRU
        else:
            raise ValueError(f'Unknown rnn_type: {rnn_type}')

        self.rnn = rnn_class(
            input_size=num_features,
            hidden_size=num_hidden_features,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        # The RNN output size is hidden_size * num_directions, independent of `num_features`
        rnn_output_features = 2 * num_hidden_features if bidirectional else num_hidden_features

        # Each output shares the RNN and has a separate projection
        self.output_projections = torch.nn.ModuleList(
            [
                torch.nn.Linear(in_features=rnn_output_features, out_features=num_subbands)
                for _ in range(num_outputs)
            ]
        )
        self.output_nonlinearity = torch.nn.Sigmoid()

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate `num_outputs` masks from the input spectrogram.

        Args:
            input: C-channel input, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Returns `num_outputs` masks in a tensor, shape (B, num_outputs, F, N),
            and output length with shape (B,)
        """
        input, _ = self.features(input=input, input_length=input_length)
        B, num_feature_channels, num_features, N = input.shape

        # (B, num_feat_channels, num_feat, N) -> (B, N, num_feat_channels, num_feat)
        input = input.permute(0, 3, 1, 2)

        # (B, N, num_feat_channels, num_feat) -> (B, N, num_feat_channels * num_features)
        input = input.reshape(B, N, -1)

        # Apply projection on num_feat
        input = self.input_projection(input)

        # Apply RNN on the input sequence
        input_packed = torch.nn.utils.rnn.pack_padded_sequence(
            input, input_length.cpu(), batch_first=True, enforce_sorted=False
        ).to(input.device)
        self.rnn.flatten_parameters()
        input_packed, _ = self.rnn(input_packed)
        # Pad back to the original number of frames N; without `total_length` the output
        # would only be padded to max(input_length), which may be shorter than N
        input, input_length = torch.nn.utils.rnn.pad_packed_sequence(
            input_packed, batch_first=True, total_length=N
        )
        input_length = input_length.to(input.device)

        # Create `num_outputs` masks
        output = []
        for output_projection in self.output_projections:
            # Output projection
            mask = output_projection(input)
            mask = self.output_nonlinearity(mask)

            # Back to the original format
            # (B, N, F) -> (B, F, N)
            mask = mask.transpose(2, 1)

            # Append to the output
            output.append(mask)

        # Stack along channel dimension to get (B, M, F, N)
        output = torch.stack(output, dim=1)

        # Mask frames beyond input length
        length_mask: torch.Tensor = make_seq_mask_like(
            lengths=input_length, like=output, time_dim=-1, valid_ones=False
        )
        output = output.masked_fill(length_mask, 0.0)

        return output, input_length


class MaskReferenceChannel(NeuralModule):
    """Mask processor that applies every estimated mask to a single
    reference channel of the input spectrogram.

    Args:
        ref_channel: Index of the reference channel.
        mask_min_db: Lower threshold for the mask before it is applied, in dB. Default: -200 dB
        mask_max_db: Upper threshold for the mask before it is applied, in dB. Default: 0 dB
    """

    def __init__(self, ref_channel: int = 0, mask_min_db: float = -200, mask_max_db: float = 0):
        super().__init__()
        self.ref_channel = ref_channel
        # Thresholds are given in dB and stored after conversion with db2mag
        self.mask_min, self.mask_max = db2mag(mask_min_db), db2mag(mask_max_db)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return dict(
            input=NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            input_length=NeuralType(('B',), LengthsType()),
            mask=NeuralType(('B', 'C', 'D', 'T'), FloatType()),
        )

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return dict(
            output=NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            output_length=NeuralType(('B',), LengthsType()),
        )

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply each mask on the `ref_channel` of the input signal.
        With an M-channel `mask`, the result is an M-channel signal.

        Args:
            input: Complex-valued input spectrogram, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)
            mask: Masks for M outputs, shape (B, M, F, N)

        Returns:
            Complex-valued M-channel output spectrogram, shape (B, M, F, N),
            and the unchanged input length, shape (B,)
        """
        # Restrict the mask to [mask_min, mask_max]
        limited_mask = mask.clamp(min=self.mask_min, max=self.mask_max)

        # Slice (rather than index) to keep a singleton channel axis, so the
        # reference broadcasts against all M masks
        start = self.ref_channel
        reference = input[:, start : start + 1, ...]

        return limited_mask * reference, input_length


class MaskBasedBeamformer(NeuralModule):
    """Multi-channel processor using masks to estimate signal statistics.

    Args:
        filter_type: string denoting the type of the filter. Currently only
                     `mvdr_souden` is supported. Defaults to `mvdr_souden`
        ref_channel: reference channel for processing
        mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB

    Raises:
        ModuleNotFoundError: If torchaudio is not available.
        ValueError: If `filter_type` is not supported.
    """

    def __init__(
        self,
        filter_type: str = 'mvdr_souden',
        ref_channel: int = 0,
        mask_min_db: float = -200,
        mask_max_db: float = 0,
    ):
        if not HAVE_TORCHAUDIO:
            logging.error('Could not import torchaudio. Some features might not work.')

            # f-string so that the class name is actually interpolated into the message
            raise ModuleNotFoundError(
                f"torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}"
            )

        super().__init__()
        self.ref_channel = ref_channel
        self.filter_type = filter_type
        if self.filter_type == 'mvdr_souden':
            self.psd = torchaudio.transforms.PSD()
            self.filter = torchaudio.transforms.SoudenMVDR()
        else:
            raise ValueError(f'Unknown filter type {filter_type}')
        # Mask thresholding
        self.mask_min = db2mag(mask_min_db)
        self.mask_max = db2mag(mask_max_db)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
            "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply a mask-based beamformer to the input spectrogram.
        This can be used to generate multi-channel output.
        If `mask` has `M` channels, the output will have `M` channels as well.

        Args:
            input: Input signal complex-valued spectrogram, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)
            mask: Mask for M output signals, shape (B, M, F, N)

        Returns:
            M-channel output signal complex-valued spectrogram, shape (B, M, F, N),
            and the unchanged input length, shape (B,)
        """
        # Apply threshold on the mask
        mask = torch.clamp(mask, min=self.mask_min, max=self.mask_max)
        # Length mask, True for frames beyond the valid length
        length_mask: torch.Tensor = make_seq_mask_like(
            lengths=input_length, like=mask[:, 0, ...], time_dim=-1, valid_ones=False
        )
        # Use each mask to generate an output at ref_channel
        output = []
        for m in range(mask.size(1)):
            # Prepare mask for the desired and the undesired signal;
            # padded frames are excluded from both PSD estimates
            mask_desired = mask[:, m, ...].masked_fill(length_mask, 0.0)
            mask_undesired = (1 - mask_desired).masked_fill(length_mask, 0.0)
            # Calculate PSDs
            psd_desired = self.psd(input, mask_desired)
            psd_undesired = self.psd(input, mask_undesired)
            # Apply filter
            output_m = self.filter(input, psd_desired, psd_undesired, reference_channel=self.ref_channel)
            output_m = output_m.masked_fill(length_mask, 0.0)
            # Save the current output (B, F, N)
            output.append(output_m)

        output = torch.stack(output, dim=1)

        return output, input_length


class WPEFilter(NeuralModule):
    """A weighted prediction error filter.
    Given input signal, and expected power of the desired signal, this
    class estimates a multiple-input multiple-output prediction filter
    and returns the filtered signal. Currently, estimation of statistics
    and processing is performed in batch mode.

    Args:
        filter_length: Length of the prediction filter in frames, per channel
        prediction_delay: Prediction delay in frames
        diag_reg: Diagonal regularization for the correlation matrix Q, applied as diag_reg * trace(Q) + eps.
                  Regularization is skipped if `diag_reg` is `None` or 0.
        eps: Small positive constant for regularization

    References:
        - Yoshioka and Nakatani, Generalization of Multi-Channel Linear Prediction
            Methods for Blind MIMO Impulse Response Shortening, 2012
        - Jukić et al, Group sparsity for MIMO speech dereverberation, 2015
    """

    def __init__(
        self, filter_length: int, prediction_delay: int, diag_reg: Optional[float] = 1e-8, eps: float = 1e-10
    ):
        super().__init__()
        self.filter_length = filter_length
        self.prediction_delay = prediction_delay
        self.diag_reg = diag_reg
        self.eps = eps

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "power": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, power: torch.Tensor, input_length: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Given input and the predicted power for the desired signal, estimate
        the WPE filter and return the processed signal.

        Args:
            input: Input signal, shape (B, C, F, N)
            power: Predicted power of the desired signal, shape (B, C, F, N)
            input_length: Optional, length of valid frames in `input`. Defaults to `None`

        Returns:
            Tuple of (processed_signal, output_length). Processed signal has the same
            shape as the input signal (B, C, F, N), and the output length is the same
            as the input length.
        """
        # Temporal weighting: average power over channels, shape (B, F, N)
        weight = torch.mean(power, dim=1)
        # Use inverse power as the weight
        weight = 1 / (weight + self.eps)

        # Multi-channel convolution matrix for each subband
        tilde_input = self.convtensor(input, filter_length=self.filter_length, delay=self.prediction_delay)

        # Estimate correlation matrices
        Q, R = self.estimate_correlations(
            input=input, weight=weight, tilde_input=tilde_input, input_length=input_length
        )

        # Estimate prediction filter
        G = self.estimate_filter(Q=Q, R=R)

        # Apply prediction filter
        undesired_signal = self.apply_filter(filter=G, tilde_input=tilde_input)

        # Dereverberation
        desired_signal = input - undesired_signal

        if input_length is not None:
            # Mask padded frames
            length_mask: torch.Tensor = make_seq_mask_like(
                lengths=input_length, like=desired_signal, time_dim=-1, valid_ones=False
            )
            desired_signal = desired_signal.masked_fill(length_mask, 0.0)

        return desired_signal, input_length

    @classmethod
    def convtensor(
        cls, x: torch.Tensor, filter_length: int, delay: int = 0, n_steps: Optional[int] = None
    ) -> torch.Tensor:
        """Create a tensor equivalent of convmtx_mc for each example in the batch.
        The input signal tensor `x` has shape (B, C, F, N).
        Convtensor returns a view of the input signal `x`.

        Note: We avoid reshaping the output to collapse channels and filter taps into
        a single dimension, e.g., (B, F, N, -1). In this way, the output is a view of the input,
        while an additional reshape would result in a contiguous array and more memory use.

        Args:
            x: input tensor, shape (B, C, F, N)
            filter_length: length of the filter, determines the shape of the convolution tensor
            delay: delay to add to the input signal `x` before constructing the convolution tensor
            n_steps: Optional, number of time steps to keep in the output. Defaults to the number of
                    time steps in the input tensor.

        Returns:
            Return a convolutional tensor with shape (B, C, F, n_steps, filter_length)

        Raises:
            RuntimeError: If `x` is not a 4-D tensor.
        """
        if x.ndim != 4:
            raise RuntimeError(f'Expecting a 4-D input. Received input with shape {x.shape}')

        B, C, F, N = x.shape

        if n_steps is None:
            # Keep the same length as the input signal
            n_steps = N

        # Pad temporal dimension on the left, so that frame n only sees past frames
        # x[n - delay - filter_length + 1], ..., x[n - delay]
        x = torch.nn.functional.pad(x, (filter_length - 1 + delay, 0))

        # Build Toeplitz-like matrix view by unfolding across time
        tilde_X = x.unfold(-1, filter_length, 1)

        # Trim to the set number of time steps
        tilde_X = tilde_X[:, :, :, :n_steps, :]

        return tilde_X

    @classmethod
    def permute_convtensor(cls, x: torch.Tensor) -> torch.Tensor:
        """Reshape and permute columns to convert the result of
        convtensor to be equal to convmtx_mc. This is used for verification
        purposes and it is not required to use the filter.

        Args:
            x: output of self.convtensor, shape (B, C, F, N, filter_length)

        Returns:
            Output has shape (B, F, N, C*filter_length) that corresponds to
            the layout of convmtx_mc.
        """
        B, C, F, N, filter_length = x.shape

        # .view will not work, so a copy will have to be created with .reshape
        # That will result in more memory use, since we don't use a view of the original
        # multi-channel signal
        x = x.permute(0, 2, 3, 1, 4)
        x = x.reshape(B, F, N, C * filter_length)

        # Reverse the order of filter taps within each channel block
        permute = [m * filter_length + k for m in range(C) for k in reversed(range(filter_length))]
        return x[..., permute]

    def estimate_correlations(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        tilde_input: torch.Tensor,
        input_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            input: Input signal, shape (B, C, F, N)
            weight: Time-frequency weight, shape (B, F, N)
            tilde_input: Multi-channel convolution tensor, shape (B, C, F, N, filter_length)
            input_length: Length of each input example, shape (B)

        Returns:
            Returns a tuple of correlation matrices for each batch.

            Let `X` denote the input signal in a single subband,
            `tilde{X}` the corresponding multi-channel convolution matrix,
            and `w` the vector of weights.

            The first output is
                Q = tilde{X}^H * diag(w) * tilde{X}     (1)
            for each (b, f).
            The matrix calculated in (1) has shape (C * filter_length, C * filter_length)
            The output is returned in a tensor with shape (B, F, C, filter_length, C, filter_length).

            The second output is
                R = tilde{X}^H * diag(w) * X            (2)
            for each (b, f).
            The matrix calculated in (2) has shape (C * filter_length, C)
            The output is returned in a tensor with shape (B, F, C, filter_length, C). The last
            dimension corresponds to output channels.
        """
        if input_length is not None:
            # Take only valid samples into account
            length_mask: torch.Tensor = make_seq_mask_like(
                lengths=input_length, like=weight, time_dim=-1, valid_ones=False
            )
            weight = weight.masked_fill(length_mask, 0.0)

        # Calculate (1)
        # result: (B, F, C, filter_length, C, filter_length)
        Q = torch.einsum('bjfik,bmfin->bfjkmn', tilde_input.conj(), weight[:, None, :, :, None] * tilde_input)

        # Calculate (2)
        # result: (B, F, C, filter_length, C)
        R = torch.einsum('bjfik,bmfi->bfjkm', tilde_input.conj(), weight[:, None, :, :] * input)

        return Q, R

    def estimate_filter(self, Q: torch.Tensor, R: torch.Tensor) -> torch.Tensor:
        r"""Estimate the MIMO prediction filter as
            G(b,f) = Q(b,f) \ R(b,f)
        for each subband in each example in the batch (b, f).

        Args:
            Q: shape (B, F, C, filter_length, C, filter_length)
            R: shape (B, F, C, filter_length, C)

        Returns:
            Complex-valued prediction filter, shape (B, C, F, C, filter_length)

        Raises:
            RuntimeError: If the filter length in `Q` does not match `self.filter_length`.
        """
        B, F, C, filter_length, _, _ = Q.shape
        if filter_length != self.filter_length:
            raise RuntimeError(f'Shape of Q {Q.shape} is not matching filter length {self.filter_length}')

        # Reshape to analytical dimensions for each (b, f)
        Q = Q.reshape(B, F, C * filter_length, C * filter_length)
        R = R.reshape(B, F, C * filter_length, C)

        # Diagonal regularization
        if self.diag_reg:
            # Regularization: diag_reg * trace(Q) + eps
            diag_reg = self.diag_reg * torch.diagonal(Q, dim1=-2, dim2=-1).sum(-1).real + self.eps
            # Apply regularization on Q
            Q = Q + torch.diag_embed(diag_reg.unsqueeze(-1) * torch.ones(Q.shape[-1], device=Q.device))

        # Solve for the filter
        G = torch.linalg.solve(Q, R)

        # Reshape to desired representation: (B, F, input channels, filter_length, output channels)
        G = G.reshape(B, F, C, filter_length, C)
        # Move output channels to front: (B, output channels, F, input channels, filter_length)
        G = G.permute(0, 4, 1, 2, 3)

        return G

    def apply_filter(
        self, filter: torch.Tensor, input: Optional[torch.Tensor] = None, tilde_input: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply a prediction filter `filter` on the input `input` as

            output(b,f) = tilde{input(b,f)} * filter(b,f)

        If available, directly use the convolution matrix `tilde_input`.
        Exactly one of `input` and `tilde_input` must be provided.

        Args:
            input: Input signal, shape (B, C, F, N)
            tilde_input: Convolution matrix for the input signal, shape (B, C, F, N, filter_length)
            filter: Prediction filter, shape (B, C, F, C, filter_length)

        Returns:
            Multi-channel signal obtained by applying the prediction filter on
            the input signal, same shape as input (B, C, F, N)

        Raises:
            RuntimeError: If both or neither of `input` and `tilde_input` are provided.
        """
        if input is None and tilde_input is None:
            raise RuntimeError('Both inputs cannot be None simultaneously.')
        if input is not None and tilde_input is not None:
            raise RuntimeError('Both inputs cannot be provided simultaneously.')

        if tilde_input is None:
            tilde_input = self.convtensor(input, filter_length=self.filter_length, delay=self.prediction_delay)

        # For each (batch, output channel, f, time step), sum across (input channel, filter tap)
        output = torch.einsum('bjfik,bmfjk->bmfi', tilde_input, filter)

        return output

class MaskBasedDereverbWPE(NeuralModule):
    """Multi-channel linear prediction-based dereverberation using
    weighted prediction error for filter estimation.

    An optional mask to estimate the signal power can be provided.
    If a time-frequency mask is not provided, the algorithm corresponds
    to the conventional WPE algorithm.

    Args:
        filter_length: Length of the convolutional filter for each channel in frames.
        prediction_delay: Delay of the input signal for multi-channel linear prediction in frames.
        num_iterations: Number of iterations for reweighting, must be at least 1.
            In each iteration the power estimate is refreshed from the current
            output, and the filter is re-estimated and applied on the observed input.
        mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
        diag_reg: Diagonal regularization for WPE
        eps: Small regularization constant

    Raises:
        ValueError: if `num_iterations` is smaller than 1.

    References:
        - Kinoshita et al, Neural network-based spectrum estimation for online WPE dereverberation, 2017
        - Yoshioka and Nakatani, Generalization of Multi-Channel Linear Prediction Methods for Blind MIMO Impulse Response Shortening, 2012
    """

    def __init__(
        self,
        filter_length: int,
        prediction_delay: int,
        num_iterations: int = 1,
        mask_min_db: float = -200,
        mask_max_db: float = 0,
        diag_reg: Optional[float] = 1e-8,
        eps: float = 1e-10,
    ):
        super().__init__()
        # With zero iterations `forward` would never produce an output/length,
        # so reject that configuration early with a clear message.
        if num_iterations < 1:
            raise ValueError(f'num_iterations must be at least 1, got {num_iterations}')
        # Filter setup
        self.filter = WPEFilter(
            filter_length=filter_length, prediction_delay=prediction_delay, diag_reg=diag_reg, eps=eps
        )
        self.num_iterations = num_iterations
        # Mask thresholding bounds, converted from dB to linear magnitude
        self.mask_min = db2mag(mask_min_db)
        self.mask_max = db2mag(mask_max_db)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
            "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Given an input signal `input`, apply the WPE dereverberation algorithm.

        Args:
            input: C-channel complex-valued spectrogram, shape (B, C, F, N)
            input_length: Optional length for each signal in the batch, shape (B,)
            mask: Optional mask, shape (B, 1, F, N) or (B, C, F, N). It is clamped to
                [mask_min, mask_max] and only used to weight the power estimate
                in the first iteration.

        Returns:
            Tuple `(output, output_length)`:
            the processed tensor with the same number of channels as the input,
            shape (B, C, F, N), cast back to the dtype of `input`, and the
            lengths returned by the WPE filter.
        """
        io_dtype = input.dtype

        with torch.cuda.amp.autocast(enabled=False):
            # Run the estimation in double precision, independent of any autocast context
            observed = input.cdouble()
            output = observed

            if mask is not None:
                # Apply thresholds
                mask = torch.clamp(mask, min=self.mask_min, max=self.mask_max)

            for i in range(self.num_iterations):
                magnitude = torch.abs(output)
                if i == 0 and mask is not None:
                    # Mask magnitude (first iteration only)
                    magnitude = mask * magnitude
                # Calculate power
                power = magnitude ** 2
                # Reweighting: only the power estimate is taken from the latest output;
                # the filter is always estimated from and applied to the observed signal,
                # instead of cascading filters on an already processed output.
                output, output_length = self.filter(input=observed, input_length=input_length, power=power)

        return output.to(io_dtype), output_length