Spaces:
Sleeping
Sleeping
| import torch | |
| ################################################################################ | |
| # Allow project-wide access to persistent data properties | |
| ################################################################################ | |
class DataProperties(object):
    """
    Allow shared access to data properties (e.g. sample rate) across all audio
    processing modules. Each dataset registers its properties with the
    DataProperties class upon initialization, eliminating the need to repeatedly
    pass properties as parameters.

    All state lives on the class itself and every method is a classmethod, so
    callers use ``DataProperties.get(...)`` etc. without creating an instance.
    """

    # Default data properties: 1-second 16kHz audio scaled to [-1, 1]
    properties = {
        "sample_rate": 16000,
        "scale": 1.0,
        "signal_length": 16000,
    }

    @classmethod
    def register_properties(cls, **kwargs):
        """
        Register data properties by name.

        The given keyword arguments *replace* the current property mapping
        wholesale; any previously registered property (including the defaults)
        that is not passed again is no longer available afterwards.

        :param kwargs: property names mapped to their values
        """
        cls.properties = kwargs

    @classmethod
    def get(cls, *args):
        """
        Access one or more data properties by name.

        :param args: one or more property names
        :return: the property value if a single name is given, otherwise a
                 tuple of values in the order the names were given
        :raises KeyError: if a requested name has not been registered
        :raises IndexError: if no name is given
        """
        if len(args) > 1:
            return tuple(cls.properties[a] for a in args)
        else:
            return cls.properties[args[0]]

    @classmethod
    def format_input(cls, x: torch.Tensor):
        """
        Ensure input is correctly formatted (batch/channels/samples). If input
        cannot be reshaped to required dimensions, raise error.

        :param x: tensor to reshape to (n_batch, 1, signal_length), where
                  n_batch is 1 for inputs with at most one dimension and
                  ``x.shape[0]`` otherwise
        :return: the reshaped tensor
        :raises ValueError: if the `signal_length` property is not registered,
                            or if `x` cannot be reshaped to the target shape
        """
        try:
            signal_length = cls.properties["signal_length"]
        except KeyError as err:
            raise ValueError(f"Data property `signal_length` must be defined to"
                             f" format inputs") from err

        # Inputs without an explicit batch dimension are treated as one item
        if x.ndim <= 1:
            n_batch = 1
        else:
            n_batch = x.shape[0]

        try:
            x = x.reshape(n_batch, 1, signal_length)
        except RuntimeError as err:
            # `x` is still the un-reshaped input here, so the message reports
            # the offending original dimensions
            raise ValueError(
                f"Invalid input dimensions {list(x.shape)}"
            ) from err

        return x