Spaces:
Sleeping
Sleeping
| from collections import OrderedDict | |
| from typing import List, Literal, Optional, Dict, Any, Union | |
| import torch | |
| from torch import nn | |
| from flamo import dsp, system | |
| from flamo.auxiliary.reverb import ( | |
| parallelFDNAccurateGEQ, | |
| parallelFirstOrderShelving, | |
| ) | |
| from flamo.functional import signal_gallery | |
| from flareverb.config.config import ( | |
| BaseConfig, | |
| FDNAttenuation, | |
| FDNMixing, | |
| FDNConfig, | |
| ) | |
| from flareverb.utils import ms_to_samps, rt2slope | |
| from flareverb.reverb import MapGamma | |
class BaseFDN(nn.Module):
    """Base Feedback Delay Network (FDN) class for reverberation modeling.

    The network is assembled from flamo modules as two parallel branches whose
    outputs are summed (``system.Parallel(..., sum_output=True)``):

    * branch A (``brA``): input gains -> recursion (delay lines in the forward
      path; mixing matrix followed by attenuation in the feedback path) ->
      output gains;
    * branch B (``brB``): an onset delay followed by an early-reflections stage
      (zero gain, random gain, or a velvet-noise FIR, depending on
      ``config.early_reflections_type``).

    Both branches are wrapped in a ``system.Shell`` with an FFT input layer and
    an output layer selected by ``output_layer``.
    """

    def __init__(
        self,
        config: "FDNConfig",
        nfft: int,
        alias_decay_db: float,
        delay_lengths: List[int],
        device: Literal["cpu", "cuda"] = "cuda",
        requires_grad: bool = True,
        output_layer: Literal["freq_complex", "freq_mag", "time"] = "time",
    ) -> None:
        """Build the FDN from a configuration and explicit delay lengths.

        Parameters
        ----------
        config : FDNConfig
            FDN configuration. Fields read by this class: ``N``, ``fs``,
            ``in_ch``, ``out_ch``, ``onset_time``, ``early_reflections_type``,
            ``attenuation_config`` and ``mixing_matrix_config``.
        nfft : int
            FFT size used by every frequency-domain module.
        alias_decay_db : float
            Anti time-aliasing decay forwarded to every flamo module.
        delay_lengths : List[int]
            Delay-line lengths in samples; must contain ``config.N`` entries.
        device : {"cpu", "cuda"}
            Device on which modules and tensors are created.
        requires_grad : bool
            Forwarded to the input/output gains and to the scattering,
            Householder and generic mixing matrices.
        output_layer : {"freq_complex", "freq_mag", "time"}
            Output representation produced by the shell.

        Raises
        ------
        ValueError
            If ``len(delay_lengths) != config.N``, or if the output layer,
            mixing type, attenuation type or early-reflections type is
            unsupported.
        """
        super().__init__()
        self._validate_delays(config, delay_lengths)
        self._initialize_parameters(
            config, nfft, alias_decay_db, delay_lengths, device, requires_grad
        )
        self._setup_fdn_system(config, output_layer)

    def forward(
        self,
        inputs: torch.Tensor,
        ext_params: List[Dict[str, Any]],
    ) -> torch.Tensor:
        """
        Forward pass through the FDN.

        Processes input signals through the Feedback Delay Network to generate
        reverberated output. Each input can have its own set of external parameters
        for dynamic control of the FDN characteristics.

        Parameters
        ----------
        inputs : torch.Tensor
            Input tensor of shape (batch_size, signal_length).
        ext_params : List[Dict[str, Any]]
            List of external parameters for each input signal. Each dictionary
            can contain parameters to modify the FDN behavior during processing.
            The length must match the batch size.

        Returns
        -------
        torch.Tensor
            Processed outputs stacked along a new leading batch axis, with the
            trailing singleton axis squeezed.

        Raises
        ------
        ValueError
            If ``len(ext_params)`` differs from the batch size of ``inputs``.
        """
        # zip() would silently drop unmatched items; enforce the documented contract.
        if len(ext_params) != len(inputs):
            raise ValueError(
                f"ext_params has {len(ext_params)} entries but inputs has "
                f"batch size {len(inputs)}"
            )
        # Each signal gets a trailing channel axis before entering the shell,
        # and is processed with its own external parameters.
        outputs = [
            self.shell(x[..., None], ext_param)
            for x, ext_param in zip(inputs, ext_params)
        ]
        return torch.stack(outputs).squeeze(-1)

    def get_params(self) -> OrderedDict[str, Any]:
        """
        Get the current parameters of the FDN.

        Extracts the main parameters from the FDN system for analysis, storage,
        or parameter transfer.

        Returns
        -------
        OrderedDict[str, Any]
            Dictionary containing:
            - 'delays': delay lengths in samples (list of int)
            - 'onset_time': ``self.onset`` as stored on the module, i.e. the
              configured onset after ``ms_to_samps`` conversion (presumably
              samples) -- returned as a tensor, not a list
            - 'early_reflections': parameters of the branch-B
              early-reflections stage
            - 'input_gains': input gain coefficients (squeezed)
            - 'output_gains': ``param[0]`` of the output-gain module (squeezed)
            - 'feedback_matrix': mixing-matrix parameters after applying the
              matrix's ``map`` (squeezed)

        Notes
        -----
        - Every entry except 'onset_time' is detached, moved to CPU and
          converted to (nested) Python lists.
        - Attenuation parameters are currently not exported.
        """
        core = self.shell.get_core()
        mixing_module = core.branchA.feedback_loop.feedback.mixing_matrix
        params = OrderedDict()
        params["delays"] = self.delay_lengths.cpu().numpy().tolist()
        params["onset_time"] = self.onset
        params["early_reflections"] = (
            core.branchB.early_reflections.param.cpu().detach().numpy().tolist()
        )
        params["input_gains"] = (
            core.branchA.input_gain.param.cpu().squeeze().detach().numpy().tolist()
        )
        params["output_gains"] = (
            core.branchA.output_gain.param[0].cpu().squeeze().detach().numpy().tolist()
        )
        # The stored parameter is mapped (e.g. to an orthogonal matrix) before export.
        params["feedback_matrix"] = (
            mixing_module.map(mixing_module.param)
            .cpu()
            .detach()
            .squeeze()
            .numpy()
            .tolist()
        )
        return params

    def _validate_delays(self, config: "BaseConfig", delay_lengths: List[int]) -> None:
        """Raise ``ValueError`` unless there is exactly one delay length per line."""
        if config.N != len(delay_lengths):
            raise ValueError(
                f"N ({config.N}) must match the length of delay_lengths ({len(delay_lengths)})"
            )

    def _initialize_parameters(
        self,
        config: "FDNConfig",
        nfft: int,
        alias_decay_db: float,
        delay_lengths: List[int],
        device: str,
        requires_grad: bool,
    ) -> None:
        """Store configuration values and derived tensors on the instance."""
        self.device = torch.device(device)
        # Core FDN parameters
        self.N = config.N
        self.fs = config.fs
        self.nfft = nfft
        self.alias_decay_db = alias_decay_db
        self.requires_grad = requires_grad
        # Onset configuration: onset_time is converted with ms_to_samps.
        self.early_reflections_type = config.early_reflections_type
        self.onset = ms_to_samps(torch.tensor(config.onset_time), config.fs)
        # Channel configuration
        self.in_ch = config.in_ch
        self.out_ch = config.out_ch
        # Delay configuration
        self.delay_lengths = torch.tensor(
            delay_lengths, device=self.device, dtype=torch.int64
        )

    def _setup_fdn_system(self, config: "BaseConfig", output_layer: str) -> None:
        """Assemble both branches and wrap them in ``self.shell``."""
        # Create FDN branches
        branch_a = self._create_fdn_branch(
            config.attenuation_config, config.mixing_matrix_config
        )
        branch_b = self._create_direct_path(config)
        # Combine branches; their outputs are summed.
        fdn_core = system.Parallel(brA=branch_a, brB=branch_b, sum_output=True)
        # Setup I/O layers
        input_layer = dsp.FFT(self.nfft)
        output_layer = self._create_output_layer(output_layer)
        # Create shell
        self.shell = system.Shell(
            core=fdn_core,
            input_layer=input_layer,
            output_layer=output_layer,
        )

    def _create_output_layer(self, output_type: str):
        """Return the output layer for ``output_type``.

        ``"time"`` gives an anti-aliased inverse FFT, ``"freq_complex"`` passes
        the spectrum through unchanged and ``"freq_mag"`` takes its magnitude.

        Raises
        ------
        ValueError
            For any other ``output_type``.
        """
        if output_type == "time":
            return dsp.iFFTAntiAlias(nfft=self.nfft, alias_decay_db=self.alias_decay_db)
        elif output_type == "freq_complex":
            return dsp.Transform(transform=lambda x: x)
        elif output_type == "freq_mag":
            return dsp.Transform(transform=lambda x: torch.abs(x))
        else:
            raise ValueError(f"Unsupported output layer type: {output_type}")

    def _create_fdn_branch(
        self, attenuation_config: "FDNAttenuation", mixing_matrix_config: "FDNMixing"
    ):
        """Create the main FDN branch (branch A).

        Input gains -> recursion (delays forward, mixing + attenuation in the
        feedback path) -> output gains.
        """
        # Input and output gains
        input_gain = dsp.Gain(
            size=(self.N, self.in_ch),
            nfft=self.nfft,
            requires_grad=self.requires_grad,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )
        output_gain = dsp.Gain(
            size=(self.out_ch, self.N),
            nfft=self.nfft,
            requires_grad=self.requires_grad,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )
        # Feedback loop components
        delays = self._create_delay_lines()
        mixing_matrix = self._create_mixing_matrix(mixing_matrix_config)
        attenuation = self._create_attenuation(attenuation_config)
        # Feedback path: mixing matrix followed by attenuation.
        feedback = system.Series(
            OrderedDict({"mixing_matrix": mixing_matrix, "attenuation": attenuation})
        )
        # Recursion: delays in the forward path, feedback path in the loop.
        feedback_loop = system.Recursion(fF=delays, fB=feedback)
        # Complete FDN branch
        return system.Series(
            OrderedDict(
                {
                    "input_gain": input_gain,
                    "feedback_loop": feedback_loop,
                    "output_gain": output_gain,
                }
            )
        )

    def _create_delay_lines(self):
        """Create the N non-learnable integer delay lines set to ``self.delay_lengths``."""
        delays = dsp.parallelDelay(
            size=(self.N,),
            max_len=self.delay_lengths.max(),
            nfft=self.nfft,
            isint=True,
            requires_grad=False,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )
        # assign_value expects seconds, hence the sample -> seconds conversion.
        delays.assign_value(delays.sample2s(self.delay_lengths))
        return delays

    def _create_mixing_matrix(self, config: "FDNMixing"):
        """Create the feedback mixing matrix selected by ``config``.

        ``is_scattering`` / ``is_velvet_noise`` take precedence over
        ``config.mixing_type``. ``"householder"`` builds a Householder matrix;
        any other ``mixing_type`` is forwarded to ``dsp.Matrix`` as
        ``matrix_type``.

        Raises
        ------
        ValueError
            If ``dsp.Matrix`` cannot build the requested ``mixing_type``; the
            original exception is chained as ``__cause__``.
        """
        if config.is_scattering or config.is_velvet_noise:
            # Random integers in [1, floor(min(delay_lengths) / 10)) passed as
            # m_L / m_R (presumably per-channel shifts). torch.randint needs
            # high > low, i.e. the shortest delay must be at least 20 samples.
            max_shift = int(torch.floor(min(self.delay_lengths) / 10))
            m_L = torch.randint(low=1, high=max_shift, size=[self.N])
            m_R = torch.randint(low=1, high=max_shift, size=[self.N])
            if config.is_scattering:
                mixing = dsp.ScatteringMatrix(
                    size=(config.n_stages, self.N, self.N),
                    nfft=self.nfft,
                    sparsity=config.sparsity,
                    gain_per_sample=1.0,
                    m_L=m_L,
                    m_R=m_R,
                    requires_grad=self.requires_grad,
                    alias_decay_db=self.alias_decay_db,
                    device=self.device,
                )
            else:
                mixing = dsp.VelvetNoiseMatrix(
                    size=(config.n_stages, self.N, self.N),
                    nfft=self.nfft,
                    density=1 / config.sparsity,
                    gain_per_sample=1.0,
                    m_L=m_L,
                    m_R=m_R,
                    alias_decay_db=self.alias_decay_db,
                    device=self.device,
                )
        elif config.mixing_type == "householder":
            mixing = dsp.HouseholderMatrix(
                size=(self.N, self.N),
                nfft=self.nfft,
                requires_grad=self.requires_grad,
                alias_decay_db=self.alias_decay_db,
                device=self.device,
            )
        else:
            try:
                mixing = dsp.Matrix(
                    size=(self.N, self.N),
                    nfft=self.nfft,
                    matrix_type=config.mixing_type,
                    requires_grad=self.requires_grad,
                    alias_decay_db=self.alias_decay_db,
                    device=self.device,
                )  # TODO add hadamard, tiny rotation
            except Exception as err:
                # Narrower than a bare except (lets KeyboardInterrupt/SystemExit
                # through) and keeps the underlying error as the cause.
                raise ValueError(
                    f"Unsupported mixing type: {config.mixing_type}"
                ) from err
        return mixing

    def _create_direct_path(self, config: "BaseConfig"):
        """Create the direct path branch (branch B): onset delay -> early reflections."""
        onset_delay = dsp.parallelDelay(
            size=(self.in_ch,),
            max_len=self.onset,
            nfft=self.nfft,
            isint=True,
            requires_grad=False,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )
        if config.early_reflections_type == "FIR":
            # The FIR spans from the onset up to the shortest delay line.
            L = self.delay_lengths.min()
            early_reflections = dsp.parallelFilter(
                size=(L - self.onset, self.in_ch),
                nfft=self.nfft,
                requires_grad=False,
                map=lambda x: x,
                alias_decay_db=self.alias_decay_db,
                device=self.device,
            )
        else:
            # NOTE(review): size is (in_ch, out_ch), while _configure_onset
            # assigns (in_ch, 1) values -- only consistent when out_ch == 1.
            early_reflections = dsp.Gain(
                size=(self.in_ch, self.out_ch),
                nfft=self.nfft,
                requires_grad=False,
                map=lambda x: x,
                alias_decay_db=self.alias_decay_db,
                device=self.device,
            )
        self._configure_onset(onset_delay, early_reflections)
        return system.Series(
            OrderedDict(
                {
                    "onset_delay": onset_delay,
                    "early_reflections": early_reflections,
                }
            )
        )

    def _configure_onset(self, onset_delay, early_reflections):
        """Assign values to the onset delay and early reflections.

        * ``None``: zero onset delay and zero early-reflection gains.
        * ``"gain"``: onset delay of ``self.onset`` and random normal gains.
        * ``"FIR"``: velvet-noise FIR (the onset delay keeps its default value).

        Raises
        ------
        ValueError
            For any other ``early_reflections_type``.
        """
        # Ensure onset has one value per input channel.
        if len(self.onset) != self.in_ch:
            self.onset = self.onset.repeat(self.in_ch)
        if self.early_reflections_type is None:
            onset_delay.assign_value(
                onset_delay.sample2s(torch.zeros((self.in_ch,), device=self.device))
            )
            early_reflections.assign_value(torch.zeros((self.in_ch, 1)))
        elif self.early_reflections_type == "gain":
            # as_tensor avoids the copy-construct warning torch.tensor() emits
            # when given an existing tensor.
            onset_delay.assign_value(onset_delay.sample2s(torch.as_tensor(self.onset)))
            early_reflections.assign_value(torch.randn((self.in_ch, 1)))
        elif self.early_reflections_type == "FIR":
            n_taps = early_reflections.size[0]
            velvet_noise = signal_gallery(
                batch_size=1,
                n_samples=n_taps,
                n=self.in_ch,
                signal_type="velvet",
                fs=self.fs,
                rate=max(int(torch.rand(1,) / 100 * self.fs), self.fs / n_taps + 1),
            ).squeeze(0)
            early_reflections.assign_value(velvet_noise)
        else:
            raise ValueError(f"Unsupported onset type: {self.early_reflections_type}")

    def _create_attenuation(self, config: "FDNAttenuation"):
        """Dispatch on ``config.attenuation_type``.

        The supported types are ``"homogeneous"``, ``"geq"`` and
        ``"first_order_lp"``.
        """
        if config.attenuation_type == "homogeneous":
            return self._create_homogeneous_attenuation(config)
        elif config.attenuation_type == "geq":
            return self._create_geq_attenuation(config)
        elif config.attenuation_type == "first_order_lp":
            return self._create_first_order_attenuation(config)
        else:
            raise ValueError(f"Unsupported attenuation type: {config.attenuation_type}")

    def _create_homogeneous_attenuation(self, config: "FDNAttenuation"):
        """Create a frequency-independent attenuation gain per delay line.

        Without ``config.attenuation_param``, the RT is drawn uniformly from
        ``config.attenuation_range``. The gain's ``map`` is ``MapGamma`` built
        from the delay lengths.
        """
        attenuation = dsp.parallelGain(
            size=(self.N,),
            nfft=self.nfft,
            requires_grad=False,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )
        attenuation.map = MapGamma(self.delay_lengths)
        if config.attenuation_param is None:
            # Random attenuation within range
            low, high = config.attenuation_range[0], config.attenuation_range[1]
            random_rt = torch.rand((1,), device=self.device) * (high - low) + low
            attenuation_value = self._calculate_attenuation_value(random_rt)
        else:
            # Use specific attenuation parameter
            attenuation_value = self._calculate_attenuation_value(
                torch.tensor(config.attenuation_param, device=self.device)
            )
        attenuation.assign_value(attenuation_value)
        return attenuation

    def _calculate_attenuation_value(self, rt_value: torch.Tensor) -> torch.Tensor:
        """Convert an RT value into N identical linear gains.

        Computes ``10 ** (rt2slope(rt_value, fs) / 20)`` broadcast over the N
        lines (assumes ``rt2slope`` returns a slope in dB -- confirm).
        """
        return 10 ** (
            (rt2slope(rt_value, self.fs) * torch.ones((self.N,), device=self.device))
            / 20
        )

    def _create_geq_attenuation(self, config: "FDNAttenuation"):
        """Create GEQ-based attenuation initialised from ``config.attenuation_param[0]``."""
        attenuation = parallelFDNAccurateGEQ(
            octave_interval=config.t60_octave_interval,
            nfft=self.nfft,
            fs=self.fs,
            delays=self.delay_lengths,
            alias_decay_db=self.alias_decay_db,
            start_freq=config.t60_center_freq[0],
            end_freq=config.t60_center_freq[-1],
            # NOTE(review): every other module receives self.device; confirm
            # whether None here is intentional.
            device=None,
        )
        attenuation.assign_value(
            torch.tensor(config.attenuation_param[0], device=self.device)
        )
        return attenuation

    def _create_first_order_attenuation(self, config: "FDNAttenuation"):
        """Create first-order shelving attenuation initialised from ``config.attenuation_param[0]``."""
        attenuation = parallelFirstOrderShelving(
            nfft=self.nfft,
            fs=self.fs,
            rt_nyquist=config.rt_nyquist,
            delays=self.delay_lengths,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )
        attenuation.assign_value(
            torch.tensor(config.attenuation_param[0], device=self.device)
        )
        return attenuation