Spaces:
Sleeping
Sleeping
| # Standard library imports | |
| from pathlib import Path | |
| import warnings | |
| # Third-party imports | |
| from typing import Union, Optional, List | |
| import torch | |
| from pydantic import BaseModel, model_validator, Field | |
class FDNAttenuation(BaseModel):
    """
    Configuration for attenuation filters in an FDN (Feedback Delay Network).

    Attributes describe which attenuation filter type is used and its
    parameters; ``check_geq_parameters`` cross-validates the 'geq' case.
    """

    # Filter family: 'homogeneous', 'geq' (graphic EQ), or 'first_order_lp'.
    attenuation_type: str = Field(
        default="homogeneous",
        description="Type of attenuation filter. Types can be 'homogeneous', 'geq', or 'first_order_lp'."
    )
    # Fallback T60 range (seconds) used only when attenuation_param is absent.
    attenuation_range: List[float] = Field(
        default_factory=lambda: [0.5, 3.5],
        description="Attenuation range in seconds (used only when attenuation_param is not given)."
    )
    rt_nyquist: float = Field(
        default=0.2,
        description="RT at Nyquist (for first order filter)."
    )
    # Shape depends on attenuation_type; see description.
    attenuation_param: Optional[List[List[float]]] = Field(
        default=None,
        description="T60 parameter. The size depends on the attenuation_type: " \
            "'homogeneous' -> [num, 1]; " \
            "'geq' -> [num, num_bands]; " \
            "'first_order_lp' -> [num, 2]."
    )
    t60_octave_interval: int = Field(
        default=1,
        description="Octave interval for T60."
    )
    t60_center_freq: List[float] = Field(
        default_factory=lambda: [63, 125, 250, 500, 1000, 2000, 4000, 8000],
        description="Center frequencies for T60."
    )

    # BUG FIX: without the decorator this method was never invoked by
    # pydantic, so the 'geq' consistency check silently did nothing.
    @model_validator(mode="after")
    def check_geq_parameters(self) -> "FDNAttenuation":
        """
        Validate that for 'geq' attenuation type, t60_center_freq length matches
        the second dimension of attenuation_param when provided.

        Raises:
            ValueError: if the band counts disagree.
        """
        if (self.attenuation_type == "geq" and
                self.attenuation_param is not None and
                len(self.attenuation_param) > 0):
            # Number of frequency bands is the inner dimension of the param.
            num_bands = len(self.attenuation_param[0])
            if len(self.t60_center_freq) != num_bands:
                raise ValueError(
                    f"For 'geq' attenuation type, length of t60_center_freq "
                    f"({len(self.t60_center_freq)}) must match the number of frequency bands "
                    f"in attenuation_param ({num_bands})"
                )
        return self
class FDNMixing(BaseModel):
    """
    Mixing matrix configuration for an FDN.

    ``is_scattering`` and ``is_velvet_noise`` are mutually exclusive modes;
    ``check_mixing_exclusivity`` enforces this.
    """

    mixing_type: str = Field(
        default="orthogonal",
        description="Type of mixing matrix: 'orthogonal', 'householder', 'hadamard', or 'rotation'."
    )
    is_scattering: bool = Field(
        default=False,
        description="If filter feedback matrix is used."
    )
    is_velvet_noise: bool = Field(
        default=False,
        description="If velvet noise is used."
    )
    sparsity: int = Field(
        default=1,
        description="Density for scattering mapping."
    )
    n_stages: int = Field(
        default=3,
        description="Number of stages in the scattering mapping."
    )

    # BUG FIX: the decorator was missing, so pydantic never ran this check
    # and a model with both flags set would be accepted silently.
    @model_validator(mode="after")
    def check_mixing_exclusivity(self) -> "FDNMixing":
        """
        Validate that is_scattering and is_velvet_noise are not both True.

        Raises:
            ValueError: if both flags are set.
        """
        if self.is_scattering and self.is_velvet_noise:
            raise ValueError("is_scattering and is_velvet_noise cannot both be True")
        return self
class FDNConfig(BaseModel):
    """
    FDN Configuration class.

    Aggregates topology (channels, delay lines), delay-line settings,
    early-reflection options, and nested attenuation/mixing configs.
    """

    in_ch: int = Field(
        default=1,
        description="Input channels."
    )
    out_ch: int = Field(
        default=1,
        description="Output channels."
    )
    fs: int = Field(
        default=48000,
        description="Sampling frequency."
    )
    N: int = Field(
        default=6,
        description="Number of delay lines."
    )
    # Explicit delay lengths override delay_range_ms when provided.
    delay_lengths: Optional[List[int]] = Field(
        default=None,
        description="Delay lengths in samples."
    )
    delay_range_ms: List[float] = Field(
        default_factory=lambda: [20.0, 50.0],
        description="Delay lengths range in ms."
    )
    delay_log_spacing: bool = Field(
        default=False,
        description="If delay lengths should be logarithmically spaced."
    )
    onset_time: List[float] = Field(
        default_factory=lambda: [10],
        description="Onset time in ms."
    )
    early_reflections_type: Optional[str] = Field(
        default=None,
        description="Type of early reflections: 'gain', 'FIR', or None."
    )
    drr: float = Field(
        default=0.25,
        description="Direct to reverberant ratio."
    )
    energy: Optional[float] = Field(
        default=None,
        description="Energy of the FDN."
    )
    gain_init: str = Field(
        default="randn",
        description="Gain initialization distribution."
    )
    attenuation_config: FDNAttenuation = Field(
        default_factory=FDNAttenuation,
        description="Attenuation configuration."
    )
    mixing_matrix_config: FDNMixing = Field(
        default_factory=FDNMixing,
        description="Mixing matrix configuration."
    )
    alias_decay_db: float = Field(
        default=0.0,
        description="Alias decay in dB."
    )

    # BUG FIX: decorator was missing (validator never ran) and the return
    # annotation wrongly said "BaseConfig" for a method of FDNConfig.
    @model_validator(mode="after")
    def check_delay_lengths(self) -> "FDNConfig":
        """
        Validate that delay_lengths length matches N when provided, and check onset_time vs delay_range_ms.

        Raises:
            ValueError: if len(delay_lengths) != N.
        """
        if self.delay_lengths is not None:
            if len(self.delay_lengths) != self.N:
                raise ValueError(
                    f"Length of delay_lengths ({len(self.delay_lengths)}) must match N ({self.N})"
                )
        if max(self.onset_time) > self.delay_range_ms[0]:
            # BUG FIX: the message announced "Max onset_time" but interpolated
            # the whole list; report the actual maximum instead.
            warnings.warn(
                f"Max onset_time ({max(self.onset_time)} ms) is larger than first element of delay_range_ms ({self.delay_range_ms[0]} ms)"
            )
        return self

    # BUG FIX: decorator was missing, so drr was never zeroed when
    # early_reflections_type is None.
    @model_validator(mode="after")
    def check_early_reflections(self) -> "FDNConfig":
        """
        Set drr to 0 when early_reflections_type is None.
        """
        if self.early_reflections_type is None:
            self.drr = 0.0
            print("Setting drr to 0.0 since early_reflections_type is None")
        return self
class FDNOptimConfig(BaseModel):
    """
    FDN Optimization Configuration class.

    Hyperparameters for the colorlessness-optimization loop.
    """

    max_epochs: int = Field(
        default=10,
        description="Number of optimization iterations."
    )
    lr: float = Field(
        default=1e-3,
        description="Learning rate."
    )
    batch_size: int = Field(
        default=1,
        description="Batch size."
    )
    device: str = Field(
        default="cuda",
        description="Device to use for optimization."
    )
    dataset_length: int = Field(
        default=100,
        description="Dataset length."
    )
    # BUG FIX: annotation was plain `str` with default=None, which
    # contradicts the type and breaks strict validation; None must be
    # representable since no directory is configured by default.
    train_dir: Optional[str] = Field(
        default=None,
        description="Training directory."
    )
class BaseConfig(BaseModel):
    """
    Base Configuration class for the overall system.

    Holds global signal parameters, the nested FDN config, and the
    optimization config; validates device availability and keeps
    sub-config devices/sampling rates in sync.
    """

    fs: int = Field(
        default=48000,
        description="Sampling frequency."
    )
    nfft: int = Field(
        default=96000,
        description="Number of FFT points."
    )
    # BUG FIX: `Union[FDNConfig]` is a degenerate single-member Union;
    # the plain type is equivalent and correct.
    fdn_config: FDNConfig = Field(
        default_factory=FDNConfig,
        description="FDN configuration."
    )
    optimize: bool = Field(
        default=False,
        description="Whether to optimize for colorlessness."
    )
    fdn_optim_config: FDNOptimConfig = Field(
        default_factory=FDNOptimConfig,
        description="Optimization configuration."
    )
    device: str = Field(
        default="cuda",
        description="Device to use."
    )

    # BUG FIX: the method takes `cls` and is documented as a factory, but
    # the @classmethod decorator was missing — calling it on the class
    # would have bound the first positional argument to `cls` incorrectly.
    @classmethod
    def create_with_fdn_params(
        cls,
        N: int,
        delay_lengths: List[int],
        **kwargs
    ) -> "BaseConfig":
        """
        Convenience method to create BaseConfig with FDN parameters.

        Args:
            N: Number of delay lines
            delay_lengths: List of delay lengths in samples
            **kwargs: Additional parameters for BaseConfig or FDNConfig
                (prefix with 'fdn_' for FDNConfig parameters)

        Returns:
            BaseConfig instance with configured FDN parameters
        """
        # Separate FDN-specific kwargs from BaseConfig kwargs
        fdn_kwargs = {}
        base_kwargs = {}
        for key, value in kwargs.items():
            if key.startswith('fdn_'):
                # Remove 'fdn_' prefix for FDNConfig parameters
                fdn_kwargs[key[4:]] = value
            else:
                base_kwargs[key] = value
        # Create FDNConfig with N and delay_lengths
        fdn_config = FDNConfig(
            N=N,
            delay_lengths=delay_lengths,
            **fdn_kwargs
        )
        # Create and return BaseConfig
        return cls(fdn_config=fdn_config, **base_kwargs)

    # BUG FIX: decorator was missing, so the fs consistency check and
    # device fallback never executed.
    @model_validator(mode="after")
    def validate_config(self) -> "BaseConfig":
        """
        Validate FDN config, and check device availability.

        Falls back to 'cpu' (or 'cuda:0'/'cuda') with a warning when the
        requested device is unavailable or malformed, then propagates the
        final device to fdn_optim_config.

        Raises:
            ValueError: if fdn_config.fs disagrees with fs.
        """
        # Sampling frequencies must agree between top level and FDN.
        if self.fdn_config.fs != self.fs:
            raise ValueError("Sampling frequency in fdn_config must match fs")
        # Validate device availability; remember original for messages.
        original_device = self.device
        if self.device.startswith("cuda"):
            if not torch.cuda.is_available():
                warnings.warn(f"CUDA not available, switching from '{original_device}' to 'cpu'")
                self.device = "cpu"
            elif self.device != "cuda":  # specific cuda device like "cuda:0"
                try:
                    device_id = int(self.device.split(":")[1])
                    if device_id >= torch.cuda.device_count():
                        warnings.warn(f"CUDA device {device_id} not available, switching to 'cuda:0'")
                        self.device = "cuda:0"
                except (IndexError, ValueError):
                    warnings.warn(f"Invalid device format '{original_device}', switching to 'cuda'")
                    self.device = "cuda"
        elif self.device == "mps":
            if not torch.backends.mps.is_available():
                warnings.warn(f"MPS not available, switching from '{original_device}' to 'cpu'")
                self.device = "cpu"
        # Sync the (possibly corrected) device with the optimization config.
        self.fdn_optim_config.device = self.device
        return self