# fdn-optimization / src / config.py
# Author: Gloria Dal Santo
# History: "Create main source code" (commit d9e647a)
# Standard library imports
from pathlib import Path
import warnings
# Third-party imports
from typing import Union, Optional, List
import torch
from pydantic import BaseModel, model_validator, Field
class FDNAttenuation(BaseModel):
    """
    Configuration for attenuation filters in FDN.
    """
    # Filter family used to model frequency-dependent decay.
    attenuation_type: str = Field(
        description="Type of attenuation filter. Types can be 'homogeneous', 'geq', or 'first_order_lp'.",
        default="homogeneous",
    )
    # Fallback T60 range (seconds) when no explicit parameter is supplied.
    attenuation_range: List[float] = Field(
        description="Attenuation range in seconds (used only when attenuation_param is not given).",
        default_factory=lambda: [0.5, 3.5],
    )
    # Only meaningful for the first-order low-pass variant.
    rt_nyquist: float = Field(
        description="RT at Nyquist (for first order filter).",
        default=0.2,
    )
    # Explicit T60 values; shape is dictated by attenuation_type.
    attenuation_param: Optional[List[List[float]]] = Field(
        description="T60 parameter. The size depends on the attenuation_type: " \
            "'homogeneous' -> [num, 1]; " \
            "'geq' -> [num, num_bands]; " \
            "'first_order_lp' -> [num, 2].",
        default=None,
    )
    t60_octave_interval: int = Field(
        description="Octave interval for T60.",
        default=1,
    )
    t60_center_freq: List[float] = Field(
        description="Center frequencies for T60.",
        default_factory=lambda: [63, 125, 250, 500, 1000, 2000, 4000, 8000],
    )

    @model_validator(mode="after")
    def check_geq_parameters(self) -> "FDNAttenuation":
        """
        For the 'geq' attenuation type, ensure t60_center_freq has exactly one
        entry per frequency band of attenuation_param (when the latter is given).

        Raises:
            ValueError: On a band-count mismatch.
        """
        # Only the 'geq' type carries per-band data worth cross-checking;
        # `not attenuation_param` covers both None and an empty list.
        if self.attenuation_type != "geq" or not self.attenuation_param:
            return self
        num_bands = len(self.attenuation_param[0])
        if len(self.t60_center_freq) != num_bands:
            raise ValueError(
                f"For 'geq' attenuation type, length of t60_center_freq "
                f"({len(self.t60_center_freq)}) must match the number of frequency bands "
                f"in attenuation_param ({num_bands})"
            )
        return self
class FDNMixing(BaseModel):
    """
    Mixing matrix configuration for FDN.
    """
    # How the feedback matrix is parameterized.
    mixing_type: str = Field(
        description="Type of mixing matrix: 'orthogonal', 'householder', 'hadamard', or 'rotation'.",
        default="orthogonal",
    )
    # Mutually exclusive with is_velvet_noise (enforced by validator below).
    is_scattering: bool = Field(
        description="If filter feedback matrix is used.",
        default=False,
    )
    is_velvet_noise: bool = Field(
        description="If velvet noise is used.",
        default=False,
    )
    sparsity: int = Field(
        description="Density for scattering mapping.",
        default=1,
    )
    n_stages: int = Field(
        description="Number of stages in the scattering mapping.",
        default=3,
    )

    @model_validator(mode="after")
    def check_mixing_exclusivity(self) -> "FDNMixing":
        """
        Reject configurations that enable both scattering and velvet noise.

        Raises:
            ValueError: If both flags are set.
        """
        if not (self.is_scattering and self.is_velvet_noise):
            return self
        raise ValueError("is_scattering and is_velvet_noise cannot both be True")
class FDNConfig(BaseModel):
    """
    FDN Configuration class.

    Structural parameters of a Feedback Delay Network (channel counts,
    delay lines, early reflections) plus nested attenuation and mixing
    configurations.
    """
    in_ch: int = Field(
        default=1,
        description="Input channels."
    )
    out_ch: int = Field(
        default=1,
        description="Output channels."
    )
    fs: int = Field(
        default=48000,
        description="Sampling frequency."
    )
    N: int = Field(
        default=6,
        description="Number of delay lines."
    )
    delay_lengths: Optional[List[int]] = Field(
        default=None,
        description="Delay lengths in samples."
    )
    delay_range_ms: List[float] = Field(
        default_factory=lambda: [20.0, 50.0],
        description="Delay lengths range in ms."
    )
    delay_log_spacing: bool = Field(
        default=False,
        description="If delay lengths should be logarithmically spaced."
    )
    onset_time: List[float] = Field(
        default_factory=lambda: [10],
        description="Onset time in ms."
    )
    early_reflections_type: Optional[str] = Field(
        default=None,
        description="Type of early reflections: 'gain', 'FIR', or None."
    )
    drr: float = Field(
        default=0.25,
        description="Direct to reverberant ratio."
    )
    energy: Optional[float] = Field(
        default=None,
        description="Energy of the FDN."
    )
    gain_init: str = Field(
        default="randn",
        description="Gain initialization distribution."
    )
    attenuation_config: FDNAttenuation = Field(
        default_factory=FDNAttenuation,
        description="Attenuation configuration."
    )
    mixing_matrix_config: FDNMixing = Field(
        default_factory=FDNMixing,
        description="Mixing matrix configuration."
    )
    alias_decay_db: float = Field(
        default=0.0,
        description="Alias decay in dB."
    )

    @model_validator(mode="after")
    def check_delay_lengths(self) -> "FDNConfig":
        """
        Validate that delay_lengths length matches N when provided, and warn
        if the largest onset time exceeds the lower bound of delay_range_ms.

        Raises:
            ValueError: If delay_lengths is given but its length differs from N.
        """
        # Fixed return annotation: this validator belongs to FDNConfig,
        # not BaseConfig.
        if self.delay_lengths is not None:
            if len(self.delay_lengths) != self.N:
                raise ValueError(
                    f"Length of delay_lengths ({len(self.delay_lengths)}) must match N ({self.N})"
                )
        # Report the actual maximum, not the whole list, so the message
        # matches what it claims ("Max onset_time").
        if max(self.onset_time) > self.delay_range_ms[0]:
            warnings.warn(
                f"Max onset_time ({max(self.onset_time)} ms) is larger than "
                f"first element of delay_range_ms ({self.delay_range_ms[0]} ms)"
            )
        return self

    @model_validator(mode="after")
    def check_early_reflections(self) -> "FDNConfig":
        """
        Force drr to 0 when early_reflections_type is None (no early
        reflections means no direct path to weight).
        """
        # Use warnings.warn for consistency with the rest of the module,
        # and only notify when the value actually changes.
        if self.early_reflections_type is None and self.drr != 0.0:
            self.drr = 0.0
            warnings.warn("Setting drr to 0.0 since early_reflections_type is None")
        return self
class FDNOptimConfig(BaseModel):
    """
    FDN Optimization Configuration class.

    Hyperparameters for the colorlessness optimization loop.
    """
    max_epochs: int = Field(
        default=10,
        description="Number of optimization iterations."
    )
    lr: float = Field(
        default=1e-3,
        description="Learning rate."
    )
    batch_size: int = Field(
        default=1,
        description="Batch size."
    )
    device: str = Field(
        default="cuda",
        description="Device to use for optimization."
    )
    dataset_length: int = Field(
        default=100,
        description="Dataset length."
    )
    # Annotated Optional: the default is None, so a bare `str` annotation
    # contradicted the default (pydantic v2 does not validate defaults, so
    # the mismatch went unnoticed at runtime).
    train_dir: Optional[str] = Field(
        default=None,
        description="Training directory."
    )
class BaseConfig(BaseModel):
    """
    Base Configuration class for the overall system.

    Aggregates the FDN model configuration, the optimization settings,
    and global parameters (sampling rate, FFT size, compute device).
    """
    fs: int = Field(
        default=48000,
        description="Sampling frequency."
    )
    nfft: int = Field(
        default=96000,
        description="Number of FFT points."
    )
    # Plain FDNConfig: the original `Union[FDNConfig]` was a single-member
    # Union, which is identical to the bare type but misleading to readers.
    fdn_config: FDNConfig = Field(
        default_factory=FDNConfig,
        description="FDN configuration."
    )
    optimize: bool = Field(
        default=False,
        description="Whether to optimize for colorlessness."
    )
    fdn_optim_config: FDNOptimConfig = Field(
        default_factory=FDNOptimConfig,
        description="Optimization configuration."
    )
    device: str = Field(
        default="cuda",
        description="Device to use."
    )

    @classmethod
    def create_with_fdn_params(
        cls,
        N: int,
        delay_lengths: List[int],
        **kwargs
    ) -> "BaseConfig":
        """
        Convenience method to create BaseConfig with FDN parameters.

        Args:
            N: Number of delay lines
            delay_lengths: List of delay lengths in samples
            **kwargs: Additional parameters for BaseConfig or FDNConfig
                      (prefix with 'fdn_' for FDNConfig parameters)

        Returns:
            BaseConfig instance with configured FDN parameters
        """
        # Route 'fdn_'-prefixed kwargs to FDNConfig, the rest to BaseConfig.
        fdn_kwargs = {}
        base_kwargs = {}
        for key, value in kwargs.items():
            if key.startswith('fdn_'):
                # Strip the 'fdn_' prefix for FDNConfig parameters
                fdn_kwargs[key[4:]] = value
            else:
                base_kwargs[key] = value
        fdn_config = FDNConfig(
            N=N,
            delay_lengths=delay_lengths,
            **fdn_kwargs
        )
        return cls(fdn_config=fdn_config, **base_kwargs)

    @model_validator(mode="after")
    def validate_config(self) -> "BaseConfig":
        """
        Validate FDN config, and check device availability.

        Falls back to an available device (warning, not error) and mirrors
        the final device onto fdn_optim_config.

        Raises:
            ValueError: If fdn_config.fs disagrees with fs.
        """
        if self.fdn_config.fs != self.fs:
            raise ValueError("Sampling frequency in fdn_config must match fs")
        # Device fallback: keep the requested name for the warning text.
        original_device = self.device
        if self.device.startswith("cuda"):
            if not torch.cuda.is_available():
                warnings.warn(f"CUDA not available, switching from '{original_device}' to 'cpu'")
                self.device = "cpu"
            elif self.device != "cuda":  # specific cuda device like "cuda:0"
                try:
                    device_id = int(self.device.split(":")[1])
                    if device_id >= torch.cuda.device_count():
                        warnings.warn(f"CUDA device {device_id} not available, switching to 'cuda:0'")
                        self.device = "cuda:0"
                except (IndexError, ValueError):
                    # e.g. "cuda:" or "cuda:abc" — fall back to the default device
                    warnings.warn(f"Invalid device format '{original_device}', switching to 'cuda'")
                    self.device = "cuda"
        elif self.device == "mps":
            if not torch.backends.mps.is_available():
                warnings.warn(f"MPS not available, switching from '{original_device}' to 'cpu'")
                self.device = "cpu"
        # Keep the optimization config on the same (possibly corrected) device
        self.fdn_optim_config.device = self.device
        return self