# voiceblock/voicebox/src/data/dataproperties.py
# Author: ALeLacheur
# Commit: "Voiceblock demo: Attempt 8" (957e2dc)
import torch
################################################################################
# Allow project-wide access to persistent data properties
################################################################################
class DataProperties:
    """
    Allow shared access to data properties (e.g. sample rate) across all audio
    processing modules. Each dataset registers its properties with the
    DataProperties class upon initialization, eliminating the need to repeatedly
    pass properties as parameters.

    All state is stored on the class itself (in ``properties``) and every
    method is a classmethod, so no instance needs to be created.
    """

    # Default data properties: 1-second 16kHz audio scaled to [-1, 1]
    properties = {
        "sample_rate": 16000,
        "scale": 1.0,
        "signal_length": 16000,
    }

    @classmethod
    def register_properties(cls, **kwargs) -> None:
        """
        Register data properties by name.

        The keyword arguments *replace* the whole property table: any
        previously registered property (including the defaults above) that is
        not passed again is no longer available afterwards.
        """
        cls.properties = kwargs

    @classmethod
    def get(cls, *args: str):
        """
        Access one or more data properties by name.

        Returns the bare value when a single name is given, and a tuple of
        values (in argument order) when several names are given.

        Raises:
            IndexError: if no property name is given.
            KeyError: if a requested property has not been registered.
        """
        if not args:
            # Same exception type the bare `args[0]` lookup would produce, but
            # with a message that tells the caller what went wrong.
            raise IndexError("At least one property name is required")

        if len(args) > 1:
            return tuple(cls.properties[name] for name in args)
        return cls.properties[args[0]]

    @classmethod
    def format_input(cls, x: torch.Tensor) -> torch.Tensor:
        """
        Ensure input is correctly formatted (batch/channels/samples). If input
        cannot be reshaped to required dimensions, raise error.

        Inputs with at most one dimension are treated as a single example;
        otherwise the leading dimension is taken as the batch size. The result
        has shape ``(n_batch, 1, signal_length)``.

        Raises:
            ValueError: if the `signal_length` property is not registered, or
                if `x` cannot be reshaped to the required dimensions.
        """
        try:
            signal_length = cls.properties["signal_length"]
        except KeyError:
            # The KeyError carries no extra information; suppress it so the
            # traceback shows only the actionable message.
            raise ValueError(
                "Data property `signal_length` must be defined to"
                " format inputs"
            ) from None

        n_batch = 1 if x.ndim <= 1 else x.shape[0]

        try:
            return x.reshape(n_batch, 1, signal_length)
        except RuntimeError as err:
            # Keep the underlying reshape error as the explicit cause: it
            # names the target shape and the actual number of elements.
            raise ValueError(
                f"Invalid input dimensions {list(x.shape)}"
            ) from err