# Standard library imports
from pathlib import Path
from typing import Union, Optional, List
import warnings

# Third-party imports
import torch
from pydantic import BaseModel, model_validator, Field


class FDNAttenuation(BaseModel):
    """
    Configuration for attenuation filters in FDN.

    When ``attenuation_type`` is 'geq' and ``attenuation_param`` is given,
    the number of entries in its first row must equal the number of
    ``t60_center_freq`` entries (enforced by ``check_geq_parameters``).
    """

    attenuation_type: str = Field(
        default="homogeneous",
        description="Type of attenuation filter. Types can be 'homogeneous', 'geq', or 'first_order_lp'."
    )
    attenuation_range: List[float] = Field(
        default_factory=lambda: [0.5, 3.5],
        description="Attenuation range in seconds (used only when attenuation_param is not given)."
    )
    rt_nyquist: float = Field(
        default=0.2,
        description="RT at Nyquist (for first order filter)."
    )
    attenuation_param: Optional[List[List[float]]] = Field(
        default=None,
        description="T60 parameter. The size depends on the attenuation_type: "
                    "'homogeneous' -> [num, 1]; "
                    "'geq' -> [num, num_bands]; "
                    "'first_order_lp' -> [num, 2]."
    )
    t60_octave_interval: int = Field(
        default=1,
        description="Octave interval for T60."
    )
    t60_center_freq: List[float] = Field(
        default_factory=lambda: [63, 125, 250, 500, 1000, 2000, 4000, 8000],
        description="Center frequencies for T60."
    )

    @model_validator(mode="after")
    def check_geq_parameters(self) -> "FDNAttenuation":
        """
        Validate that for 'geq' attenuation type, t60_center_freq length matches
        the second dimension of attenuation_param when provided.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If the band counts disagree.
        """
        # A missing or empty attenuation_param means there is nothing to compare.
        if self.attenuation_type == "geq" and self.attenuation_param:
            # The first row defines the number of frequency bands.
            num_bands = len(self.attenuation_param[0])
            if len(self.t60_center_freq) != num_bands:
                raise ValueError(
                    f"For 'geq' attenuation type, length of t60_center_freq "
                    f"({len(self.t60_center_freq)}) must match the number of frequency bands "
                    f"in attenuation_param ({num_bands})"
                )
        return self


class FDNMixing(BaseModel):
    """
    Mixing matrix configuration for FDN.

    ``is_scattering`` and ``is_velvet_noise`` are mutually exclusive.
    """

    mixing_type: str = Field(
        default="orthogonal",
        description="Type of mixing matrix: 'orthogonal', 'householder', 'hadamard', or 'rotation'."
    )
    is_scattering: bool = Field(
        default=False,
        description="If filter feedback matrix is used."
    )
    is_velvet_noise: bool = Field(
        default=False,
        description="If velvet noise is used."
    )
    sparsity: int = Field(
        default=1,
        description="Density for scattering mapping."
    )
    n_stages: int = Field(
        default=3,
        description="Number of stages in the scattering mapping."
    )

    @model_validator(mode="after")
    def check_mixing_exclusivity(self) -> "FDNMixing":
        """
        Validate that is_scattering and is_velvet_noise are not both True.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If both flags are set.
        """
        if self.is_scattering and self.is_velvet_noise:
            raise ValueError("is_scattering and is_velvet_noise cannot both be True")
        return self


class FDNConfig(BaseModel):
    """
    FDN Configuration class.

    Validation enforces that ``delay_lengths`` (when given) has ``N`` entries
    and forces ``drr`` to 0.0 when ``early_reflections_type`` is None.
    """

    in_ch: int = Field(
        default=1,
        description="Input channels."
    )
    out_ch: int = Field(
        default=1,
        description="Output channels."
    )
    fs: int = Field(
        default=48000,
        description="Sampling frequency."
    )
    N: int = Field(
        default=6,
        description="Number of delay lines."
    )
    delay_lengths: Optional[List[int]] = Field(
        default=None,
        description="Delay lengths in samples."
    )
    delay_range_ms: List[float] = Field(
        default_factory=lambda: [20.0, 50.0],
        description="Delay lengths range in ms."
    )
    delay_log_spacing: bool = Field(
        default=False,
        description="If delay lengths should be logarithmically spaced."
    )
    onset_time: List[float] = Field(
        default_factory=lambda: [10],
        description="Onset time in ms."
    )
    early_reflections_type: Optional[str] = Field(
        default=None,
        description="Type of early reflections: 'gain', 'FIR', or None."
    )
    drr: float = Field(
        default=0.25,
        description="Direct to reverberant ratio."
    )
    energy: Optional[float] = Field(
        default=None,
        description="Energy of the FDN."
    )
    gain_init: str = Field(
        default="randn",
        description="Gain initialization distribution."
    )
    attenuation_config: FDNAttenuation = Field(
        default_factory=FDNAttenuation,
        description="Attenuation configuration."
    )
    mixing_matrix_config: FDNMixing = Field(
        default_factory=FDNMixing,
        description="Mixing matrix configuration."
    )
    alias_decay_db: float = Field(
        default=0.0,
        description="Alias decay in dB."
    )

    @model_validator(mode="after")
    def check_delay_lengths(self) -> "FDNConfig":
        """
        Validate that delay_lengths length matches N when provided, and
        check onset_time vs delay_range_ms.

        A warning (not an error) is emitted when the largest onset time
        exceeds the first element of delay_range_ms.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If delay_lengths is given and its length differs from N.
        """
        if self.delay_lengths is not None and len(self.delay_lengths) != self.N:
            raise ValueError(
                f"Length of delay_lengths ({len(self.delay_lengths)}) must match N ({self.N})"
            )

        # Report the offending maximum itself rather than the whole list.
        max_onset = max(self.onset_time)
        if max_onset > self.delay_range_ms[0]:
            warnings.warn(
                f"Max onset_time ({max_onset} ms) is larger than first element of delay_range_ms ({self.delay_range_ms[0]} ms)"
            )
        return self

    @model_validator(mode="after")
    def check_early_reflections(self) -> "FDNConfig":
        """
        Set drr to 0 when early_reflections_type is None.

        Returns:
            The validated model instance (with drr possibly overwritten).
        """
        if self.early_reflections_type is None:
            self.drr = 0.0
            print("Setting drr to 0.0 since early_reflections_type is None")
        return self


class FDNOptimConfig(BaseModel):
    """
    FDN Optimization Configuration class.
    """

    max_epochs: int = Field(
        default=10,
        description="Number of optimization iterations."
    )
    lr: float = Field(
        default=1e-3,
        description="Learning rate."
    )
    batch_size: int = Field(
        default=1,
        description="Batch size."
    )
    device: str = Field(
        default="cuda",
        description="Device to use for optimization."
    )
    dataset_length: int = Field(
        default=100,
        description="Dataset length."
    )
    # Optional because the default is None; a plain `str` annotation would
    # reject an explicitly passed None.
    train_dir: Optional[str] = Field(
        default=None,
        description="Training directory."
    )


class BaseConfig(BaseModel):
    """
    Base Configuration class for the overall system.

    Validation requires ``fdn_config.fs`` to equal ``fs``, falls back to an
    available device when the requested one is not usable, and copies the
    resolved device into ``fdn_optim_config.device``.
    """

    fs: int = Field(
        default=48000,
        description="Sampling frequency."
    )
    nfft: int = Field(
        default=96000,
        description="Number of FFT points."
    )
    fdn_config: FDNConfig = Field(
        default_factory=FDNConfig,
        description="FDN configuration."
    )
    optimize: bool = Field(
        default=False,
        description="Whether to optimize for colorlessness."
    )
    fdn_optim_config: FDNOptimConfig = Field(
        default_factory=FDNOptimConfig,
        description="Optimization configuration."
    )
    device: str = Field(
        default="cuda",
        description="Device to use."
    )

    @classmethod
    def create_with_fdn_params(
        cls,
        N: int,
        delay_lengths: List[int],
        **kwargs
    ) -> "BaseConfig":
        """
        Convenience method to create BaseConfig with FDN parameters.

        Args:
            N: Number of delay lines
            delay_lengths: List of delay lengths in samples
            **kwargs: Additional parameters for BaseConfig or FDNConfig
                     (prefix with 'fdn_' for FDNConfig parameters).
                     Names that are BaseConfig fields themselves (such as
                     'fdn_optim_config') are passed to BaseConfig even
                     though they start with 'fdn_'. A 'fdn_config' keyword
                     is ignored with a warning, because the FDNConfig is
                     built here from N and delay_lengths. If 'fs' is given
                     without 'fdn_fs', the same value is used for the
                     FDNConfig.

        Returns:
            BaseConfig instance with configured FDN parameters
        """
        prefix = "fdn_"
        fdn_kwargs = {}
        base_kwargs = {}
        for key, value in kwargs.items():
            if key == "fdn_config":
                # This method constructs fdn_config itself; a caller-supplied
                # one cannot be honoured, so say so instead of dropping it
                # silently.
                warnings.warn(
                    "Ignoring 'fdn_config' keyword: create_with_fdn_params builds "
                    "the FDNConfig from N and delay_lengths"
                )
            elif key in cls.model_fields:
                # Real BaseConfig fields win over the prefix rule; otherwise
                # 'fdn_optim_config' would be stripped to 'optim_config' and
                # silently discarded by FDNConfig.
                base_kwargs[key] = value
            elif key.startswith(prefix):
                # Remove 'fdn_' prefix for FDNConfig parameters
                fdn_kwargs[key[len(prefix):]] = value
            else:
                base_kwargs[key] = value

        # validate_config requires both sampling rates to agree, so reuse the
        # top-level fs unless the caller set the FDN one explicitly.
        if "fs" in base_kwargs:
            fdn_kwargs.setdefault("fs", base_kwargs["fs"])

        # Create FDNConfig with N and delay_lengths
        fdn_config = FDNConfig(
            N=N,
            delay_lengths=delay_lengths,
            **fdn_kwargs
        )

        # Create and return BaseConfig
        return cls(fdn_config=fdn_config, **base_kwargs)

    @model_validator(mode="after")
    def validate_config(self) -> "BaseConfig":
        """
        Validate FDN config, and check device availability.

        Device fallbacks (each emits a warning):
            - any 'cuda...' device when CUDA is unavailable -> 'cpu'
            - 'cuda:<i>' with <i> outside the available device range -> 'cuda:0'
            - 'cuda...' with an unparsable index -> 'cuda'
            - 'mps' when MPS is unavailable -> 'cpu'

        The resolved device is then written to fdn_optim_config.device.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If fdn_config.fs differs from fs.
        """
        # Validate FDN configuration
        if self.fdn_config.fs != self.fs:
            raise ValueError("Sampling frequency in fdn_config must match fs")

        # Validate device availability
        original_device = self.device
        if self.device.startswith("cuda"):
            if not torch.cuda.is_available():
                warnings.warn(f"CUDA not available, switching from '{original_device}' to 'cpu'")
                self.device = "cpu"
            elif self.device != "cuda":
                # Specific cuda device like "cuda:0"
                try:
                    device_id = int(self.device.split(":")[1])
                except (IndexError, ValueError):
                    warnings.warn(f"Invalid device format '{original_device}', switching to 'cuda'")
                    self.device = "cuda"
                else:
                    # Negative indices are rejected by torch as well, so
                    # check both ends of the range.
                    if not 0 <= device_id < torch.cuda.device_count():
                        warnings.warn(f"CUDA device {device_id} not available, switching to 'cuda:0'")
                        self.device = "cuda:0"
        elif self.device == "mps":
            if not torch.backends.mps.is_available():
                warnings.warn(f"MPS not available, switching from '{original_device}' to 'cpu'")
                self.device = "cpu"

        # Sync device with optimization config
        self.fdn_optim_config.device = self.device
        return self