File size: 1,949 Bytes
957e2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch

################################################################################
# Allow project-wide access to persistent data properties
################################################################################


class DataProperties:
    """
    Allow shared access to data properties (e.g. sample rate) across all audio
    processing modules.

    Each dataset registers its properties with the DataProperties class upon
    initialization, eliminating the need to repeatedly pass properties as
    parameters. Properties are stored on the class itself (not on instances),
    so all methods are classmethods and no instance is ever needed.
    """

    # Default data properties: 1-second 16kHz audio scaled to [-1, 1]
    properties = {
        "sample_rate": 16000,
        "scale": 1.0,
        "signal_length": 16000,
    }

    @classmethod
    def register_properties(cls, **kwargs):
        """
        Register data properties by name.

        The given keyword arguments REPLACE the current property mapping
        wholesale: any property (including the defaults above) that is not
        passed here is no longer available afterwards.
        """
        cls.properties = kwargs

    @classmethod
    def get(cls, *args):
        """
        Access one or more data properties by name.

        :param args: one or more property names
        :return: the property value if exactly one name is given, otherwise a
                 tuple of values in the same order as the requested names
        :raises TypeError: if no property name is given
        :raises KeyError: if a requested property has not been registered
        """
        if not args:
            # Previously surfaced as an opaque IndexError from args[0]
            raise TypeError(
                "DataProperties.get() requires at least one property name"
            )
        if len(args) == 1:
            return cls.properties[args[0]]
        return tuple(cls.properties[a] for a in args)

    @classmethod
    def format_input(cls, x: torch.Tensor) -> torch.Tensor:
        """
        Ensure input is correctly formatted as (batch, channels=1, samples)
        using the registered `signal_length`.

        Inputs with at most one dimension are treated as a single example
        (batch size 1). Otherwise the leading dimension is taken as the batch
        size, and the remaining elements of each example must fill exactly one
        channel of `signal_length` samples.

        NOTE(review): only the total element count is checked, so e.g. an
        input of shape (1, 2, 8000) with signal_length 16000 is accepted and
        its two rows are concatenated into a single channel.

        :param x: input tensor
        :return: tensor of shape (n_batch, 1, signal_length)
        :raises ValueError: if `signal_length` is not registered, or if `x`
                            cannot be reshaped to the required dimensions
        """
        try:
            signal_length = cls.properties["signal_length"]
        except KeyError:
            # The bare KeyError traceback adds nothing beyond this message
            raise ValueError(
                "Data property `signal_length` must be defined to format inputs"
            ) from None

        # 0-D / 1-D inputs carry no batch axis, so treat them as one example
        n_batch = 1 if x.ndim <= 1 else x.shape[0]

        try:
            return x.reshape(n_batch, 1, signal_length)
        except RuntimeError as err:
            # Chain torch's error so the element-count mismatch stays visible
            raise ValueError(
                f"Invalid input dimensions {list(x.shape)}"
            ) from err