Gloria Dal Santo committed on
Commit
d9e647a
·
1 Parent(s): 08eeac9

Create main source code

Browse files
src/__pycache__/config.cpython-310.pyc ADDED
Binary file (9.23 kB). View file
 
src/__pycache__/reverb.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
src/config.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard library imports
2
+ from pathlib import Path
3
+ import warnings
4
+
5
+ # Third-party imports
6
+ from typing import Union, Optional, List
7
+ import torch
8
+ from pydantic import BaseModel, model_validator, Field
9
+
10
class FDNAttenuation(BaseModel):
    """
    Configuration for attenuation filters in FDN.

    Selects the attenuation (decay) filter type and its parameters. When
    ``attenuation_param`` is None, a decay time is drawn at random from
    ``attenuation_range`` by the consumer of this config.
    """
    # Which attenuation filter the FDN builds; consumed by BaseFDN._create_attenuation.
    attenuation_type: str = Field(
        default="homogeneous",
        description="Type of attenuation filter. Types can be 'homogeneous', 'geq', or 'first_order_lp'."
    )
    # [min, max] T60 in seconds used for random sampling when no explicit parameter is given.
    attenuation_range: List[float] = Field(
        default_factory=lambda: [0.5, 3.5],
        description="Attenuation range in seconds (used only when attenuation_param is not given)."
    )
    # Reverberation time at Nyquist; only used by the 'first_order_lp' filter.
    rt_nyquist: float = Field(
        default=0.2,
        description="RT at Nyquist (for first order filter)."
    )
    # Explicit T60 values; row layout depends on attenuation_type (see description).
    attenuation_param: Optional[List[List[float]]] = Field(
        default=None,
        description="T60 parameter. The size depends on the attenuation_type: " \
            "'homogeneous' -> [num, 1]; " \
            "'geq' -> [num, num_bands]; " \
            "'first_order_lp' -> [num, 2]."
    )
    # Octave spacing of the GEQ bands ('geq' type only).
    t60_octave_interval: int = Field(
        default=1,
        description="Octave interval for T60."
    )
    # Band center frequencies in Hz; for 'geq' its length must match the
    # number of bands in attenuation_param (enforced below).
    t60_center_freq: List[float] = Field(
        default_factory=lambda: [63, 125, 250, 500, 1000, 2000, 4000, 8000],
        description="Center frequencies for T60."
    )

    @model_validator(mode="after")
    def check_geq_parameters(self) -> "FDNAttenuation":
        """
        Validate that for 'geq' attenuation type, t60_center_freq length matches
        the second dimension of attenuation_param when provided.
        """
        if (self.attenuation_type == "geq" and
            self.attenuation_param is not None and
            len(self.attenuation_param) > 0):

            # Get the number of frequency bands from attenuation_param
            # (assumes all rows have the same length — only row 0 is checked).
            num_bands = len(self.attenuation_param[0])

            if len(self.t60_center_freq) != num_bands:
                raise ValueError(
                    f"For 'geq' attenuation type, length of t60_center_freq "
                    f"({len(self.t60_center_freq)}) must match the number of frequency bands "
                    f"in attenuation_param ({num_bands})"
                )

        return self
63
+
64
class FDNMixing(BaseModel):
    """
    Configuration of the FDN feedback (mixing) matrix.

    Chooses between a plain mixing matrix and the filter-feedback variants
    (scattering or velvet-noise), which are mutually exclusive.
    """

    # Matrix family used when neither scattering nor velvet noise is enabled.
    mixing_type: str = Field(
        default="orthogonal",
        description="Type of mixing matrix: 'orthogonal', 'householder', 'hadamard', or 'rotation'.",
    )
    # Enables the filter-feedback (scattering) matrix.
    is_scattering: bool = Field(
        default=False,
        description="If filter feedback matrix is used.",
    )
    # Enables the velvet-noise feedback matrix.
    is_velvet_noise: bool = Field(
        default=False,
        description="If velvet noise is used.",
    )
    # Sparsity/density control for the scattering mapping.
    sparsity: int = Field(
        default=1,
        description="Density for scattering mapping.",
    )
    # Number of cascaded stages in the scattering mapping.
    n_stages: int = Field(
        default=3,
        description="Number of stages in the scattering mapping.",
    )

    @model_validator(mode="after")
    def check_mixing_exclusivity(self) -> "FDNMixing":
        """
        Reject configurations that enable both scattering and velvet noise.
        """
        if not (self.is_scattering and self.is_velvet_noise):
            return self
        raise ValueError("is_scattering and is_velvet_noise cannot both be True")
97
+
98
class FDNConfig(BaseModel):
    """
    FDN Configuration class.

    Holds the structural parameters of a Feedback Delay Network (channel
    counts, delay-line setup, direct path) together with the attenuation
    and mixing-matrix sub-configurations.
    """
    # Number of input channels.
    in_ch: int = Field(
        default=1,
        description="Input channels."
    )
    # Number of output channels.
    out_ch: int = Field(
        default=1,
        description="Output channels."
    )
    # Sampling frequency in Hz; must match BaseConfig.fs (checked there).
    fs: int = Field(
        default=48000,
        description="Sampling frequency."
    )
    # Number of delay lines; must equal len(delay_lengths) when the latter is given.
    N: int = Field(
        default=6,
        description="Number of delay lines."
    )
    # Explicit delay lengths in samples; when None they are derived from delay_range_ms.
    delay_lengths: Optional[List[int]] = Field(
        default=None,
        description="Delay lengths in samples."
    )
    # [min, max] range in ms used to draw delay lengths when delay_lengths is None.
    delay_range_ms: List[float] = Field(
        default_factory=lambda: [20.0, 50.0],
        description="Delay lengths range in ms."
    )
    delay_log_spacing: bool = Field(
        default=False,
        description="If delay lengths should be logarithmically spaced."
    )
    # Direct-path onset time(s) in ms.
    onset_time: List[float] = Field(
        default_factory=lambda: [10],
        description="Onset time in ms."
    )
    # None disables the direct/early path (drr is then forced to 0 below).
    early_reflections_type: Optional[str] = Field(
        default=None,
        description="Type of early reflections: 'gain', 'FIR', or None."
    )
    drr: float = Field(
        default=0.25,
        description="Direct to reverberant ratio."
    )
    energy: Optional[float] = Field(
        default=None,
        description="Energy of the FDN."
    )
    gain_init: str = Field(
        default="randn",
        description="Gain initialization distribution."
    )
    attenuation_config: FDNAttenuation = Field(
        default_factory=FDNAttenuation,
        description="Attenuation configuration."
    )
    mixing_matrix_config: FDNMixing = Field(
        default_factory=FDNMixing,
        description="Mixing matrix configuration."
    )
    alias_decay_db: float = Field(
        default=0.0,
        description="Alias decay in dB."
    )

    @model_validator(mode="after")
    # FIX: return annotation was "BaseConfig"; this validator runs on FDNConfig.
    def check_delay_lengths(self) -> "FDNConfig":
        """
        Validate that delay_lengths length matches N when provided, and check onset_time vs delay_range_ms.
        """
        if self.delay_lengths is not None:
            if len(self.delay_lengths) != self.N:
                raise ValueError(
                    f"Length of delay_lengths ({len(self.delay_lengths)}) must match N ({self.N})"
                )
        # Warn (not fail) when the direct-path onset falls after the shortest
        # possible delay line: the early path would overlap the FDN tail.
        if max(self.onset_time) > self.delay_range_ms[0]:
            warnings.warn(
                f"Max onset_time ({self.onset_time} ms) is larger than first element of delay_range_ms ({self.delay_range_ms[0]} ms)"
            )
        return self

    @model_validator(mode="after")
    def check_early_reflections(self) -> "FDNConfig":
        """
        Set drr to 0 when early_reflections_type is None.
        """
        if self.early_reflections_type is None:
            self.drr = 0.0
            print("Setting drr to 0.0 since early_reflections_type is None")
        return self
188
+
189
class FDNOptimConfig(BaseModel):
    """
    FDN Optimization Configuration class.

    Training hyper-parameters used when colorless optimization is enabled
    (see BaseConfig.optimize). The device field is kept in sync with
    BaseConfig.device by BaseConfig.validate_config.
    """
    max_epochs: int = Field(
        default=10,
        description="Number of optimization iterations."
    )
    lr: float = Field(
        default=1e-3,
        description="Learning rate."
    )
    batch_size: int = Field(
        default=1,
        description="Batch size."
    )
    device: str = Field(
        default="cuda",
        description="Device to use for optimization."
    )
    dataset_length: int = Field(
        default=100,
        description="Dataset length."
    )
    # FIX: was annotated `str` with default=None, which mis-declares the type
    # (pydantic does not validate defaults, so None silently violated the
    # annotation). Optional[str] makes the None default legitimate.
    train_dir: Optional[str] = Field(
        default=None,
        description="Training directory."
    )
217
+
218
class BaseConfig(BaseModel):
    """
    Base Configuration class for the overall system.

    Aggregates the FDN structural config and the optimization config, and
    validates global consistency (sampling rate, device availability).
    """
    # Global sampling frequency in Hz; must match fdn_config.fs.
    fs: int = Field(
        default=48000,
        description="Sampling frequency."
    )
    nfft: int = Field(
        default=96000,
        description="Number of FFT points."
    )
    # FIX: was `Union[FDNConfig]` — a single-member Union is a no-op and
    # misleading; the plain type is equivalent.
    fdn_config: FDNConfig = Field(
        default_factory=FDNConfig,
        description="FDN configuration."
    )
    optimize: bool = Field(
        default=False,
        description="Whether to optimize for colorlessness."
    )
    fdn_optim_config: FDNOptimConfig = Field(
        default_factory=FDNOptimConfig,
        description="Optimization configuration."
    )
    # Requested device; validate_config may downgrade it if unavailable.
    device: str = Field(
        default="cuda",
        description="Device to use."
    )

    @classmethod
    def create_with_fdn_params(
        cls,
        N: int,
        delay_lengths: List[int],
        **kwargs
    ) -> "BaseConfig":
        """
        Convenience method to create BaseConfig with FDN parameters.

        Args:
            N: Number of delay lines
            delay_lengths: List of delay lengths in samples
            **kwargs: Additional parameters for BaseConfig or FDNConfig
                (prefix with 'fdn_' for FDNConfig parameters)

        Returns:
            BaseConfig instance with configured FDN parameters
        """
        # Separate FDN-specific kwargs from BaseConfig kwargs
        fdn_kwargs = {}
        base_kwargs = {}

        for key, value in kwargs.items():
            if key.startswith('fdn_'):
                # Remove 'fdn_' prefix for FDNConfig parameters
                fdn_kwargs[key[4:]] = value
            else:
                base_kwargs[key] = value

        # Create FDNConfig with N and delay_lengths
        fdn_config = FDNConfig(
            N=N,
            delay_lengths=delay_lengths,
            **fdn_kwargs
        )

        # Create and return BaseConfig
        return cls(fdn_config=fdn_config, **base_kwargs)

    @model_validator(mode="after")
    def validate_config(self) -> "BaseConfig":
        """
        Validate FDN config, and check device availability.
        """

        # Validate FDN configuration
        if self.fdn_config.fs != self.fs:
            raise ValueError("Sampling frequency in fdn_config must match fs")

        # Validate device availability: downgrade gracefully (with a warning)
        # rather than failing when the requested accelerator is absent.
        original_device = self.device
        if self.device.startswith("cuda"):
            if not torch.cuda.is_available():
                warnings.warn(f"CUDA not available, switching from '{original_device}' to 'cpu'")
                self.device = "cpu"
            elif self.device != "cuda":  # specific cuda device like "cuda:0"
                try:
                    device_id = int(self.device.split(":")[1])
                    if device_id >= torch.cuda.device_count():
                        warnings.warn(f"CUDA device {device_id} not available, switching to 'cuda:0'")
                        self.device = "cuda:0"
                except (IndexError, ValueError):
                    warnings.warn(f"Invalid device format '{original_device}', switching to 'cuda'")
                    self.device = "cuda"
        elif self.device == "mps":
            if not torch.backends.mps.is_available():
                warnings.warn(f"MPS not available, switching from '{original_device}' to 'cpu'")
                self.device = "cpu"

        # Sync device with optimization config
        self.fdn_optim_config.device = self.device

        return self
src/reverb.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import List, Literal, Optional, Dict, Any, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from flamo import dsp, system
7
+ from flamo.auxiliary.reverb import (
8
+ parallelFDNAccurateGEQ,
9
+ parallelFirstOrderShelving,
10
+ )
11
+ from flamo.functional import signal_gallery
12
+
13
+ from flareverb.config.config import (
14
+ BaseConfig,
15
+ FDNAttenuation,
16
+ FDNMixing,
17
+ FDNConfig,
18
+ )
19
+
20
+ from flareverb.utils import ms_to_samps, rt2slope
21
+ from flareverb.reverb import MapGamma
22
+
23
class BaseFDN(nn.Module):
    """Base Feedback Delay Network (FDN) class for reverberation modeling.

    Wraps a flamo ``system.Shell`` whose core is a ``system.Parallel`` of
    two branches: branch A is the recursive FDN (input gains -> delay lines
    with a mixing-matrix + attenuation feedback path -> output gains) and
    branch B is the direct/early-reflection path.
    """

    def __init__(
        self,
        config: FDNConfig,
        nfft: int,
        alias_decay_db: float,
        delay_lengths: List[int],
        device: Literal["cpu", "cuda"] = "cuda",
        requires_grad: bool = True,
        output_layer: Literal["freq_complex", "freq_mag", "time"] = "time",
    ) -> None:
        """
        Build the FDN from a configuration object.

        Parameters
        ----------
        config : FDNConfig
            FDN configuration (channels, N, attenuation and mixing sub-configs).
        nfft : int
            Number of FFT points for the frequency-domain processing layers.
        alias_decay_db : float
            Alias decay in dB, passed to every anti-aliased DSP module.
        delay_lengths : List[int]
            Delay-line lengths in samples; length must equal ``config.N``.
        device : Literal["cpu", "cuda"]
            Device on which parameters are allocated.
        requires_grad : bool
            Whether gains and the mixing matrix are learnable.
        output_layer : Literal["freq_complex", "freq_mag", "time"]
            Output representation produced by the shell.
        """
        super().__init__()

        # Fail fast on an inconsistent N / delay_lengths pair.
        self._validate_delays(config, delay_lengths)
        self._initialize_parameters(
            config, nfft, alias_decay_db, delay_lengths, device, requires_grad
        )
        self._setup_fdn_system(config, output_layer)

    def forward(
        self,
        inputs: torch.Tensor,
        ext_params: List[Dict[str, Any]],
    ) -> torch.Tensor:
        """
        Forward pass through the FDN.

        Processes input signals through the Feedback Delay Network to generate
        reverberated output. Each input can have its own set of external parameters
        for dynamic control of the FDN characteristics.

        Parameters
        ----------
        inputs : torch.Tensor
            Input tensor of shape (batch_size, signal_length).
        ext_params : List[Dict[str, Any]]
            List of external parameters for each input signal. Each dictionary
            can contain parameters to modify the FDN behavior during processing.
            The length must match the batch size.

        Returns
        -------
        torch.Tensor
            Processed output tensor. Contains the reverberated signals.
        """
        outputs = []
        # Batch items are processed one at a time because each carries its
        # own external-parameter dictionary.
        for x, ext_param in zip(inputs, ext_params):
            # Apply the FDN with the external parameters
            # (x[..., None] adds the channel dimension the shell expects).
            output = self.shell(x[..., None], ext_param)
            outputs.append(output)

        return torch.stack(outputs).squeeze(-1)

    def get_params(self) -> OrderedDict[str, Any]:
        """
        Get the current parameters of the FDN.

        Extracts all learnable and configurable parameters from the FDN system
        for analysis, storage, or parameter transfer. All parameters are converted
        to CPU NumPy arrays for compatibility.

        Returns
        -------
        OrderedDict[str, Any]
            Dictionary containing all FDN parameters:
            - 'delays': List of delay lengths in samples
            - 'onset_time': List of onset times in milliseconds
            - 'early_reflections': Direct path gain values
            - 'input_gains': Input gain coefficients for each delay line
            - 'output_gains': Output gain coefficients for each delay line
            - 'feedback_matrix': Mixing (feedback) matrix coefficients
            - 'attenuation': Attenuation coefficients for each delay line

        Notes
        -----
        - All parameters are detached from the computation graph and moved to CPU
        - The returned parameters can be used to recreate or modify the FDN
        """
        core = self.shell.get_core()

        params = OrderedDict()
        params["delays"] = self.delay_lengths.cpu().numpy().tolist()
        # NOTE(review): docstring says milliseconds, but self.onset was
        # converted to samples in _initialize_parameters — confirm which
        # unit callers expect.
        params["onset_time"] = self.onset
        params["early_reflections"] = (
            core.branchB.early_reflections.param.cpu().detach().numpy().tolist()
        )
        params["input_gains"] = (
            core.branchA.input_gain.param.cpu().detach().numpy().tolist()
        )
        params["output_gains"] = (
            core.branchA.output_gain.param[0].cpu().detach().numpy().tolist()
        )
        params["feedback_matrix"] = (
            core.branchA.feedback_loop.feedback.mixing_matrix.param.cpu()
            .detach()
            .numpy()
            .tolist()
        )
        params["attenuation"] = (
            core.branchA.feedback_loop.feedback.attenuation.param.cpu()
            .detach()
            .numpy()
            .tolist()
        )
        return params

    def _validate_delays(self, config: BaseConfig, delay_lengths: List[int]) -> None:
        """Validate delay lengths.

        Raises ValueError when config.N disagrees with len(delay_lengths).
        """
        if config.N != len(delay_lengths):
            raise ValueError(
                f"N ({config.N}) must match the length of delay_lengths ({len(delay_lengths)})"
            )

    def _initialize_parameters(
        self,
        config: FDNConfig,
        nfft: int,
        alias_decay_db: float,
        delay_lengths: List[int],
        device: str,
        requires_grad: bool,
    ) -> None:
        """Initialize FDN parameters."""
        self.device = torch.device(device)

        # Core FDN parameters
        self.N = config.N
        self.fs = config.fs
        self.nfft = nfft
        self.alias_decay_db = alias_decay_db
        self.requires_grad = requires_grad

        # Onset configuration (onset times converted from ms to samples)
        self.early_reflections_type = config.early_reflections_type
        self.onset = ms_to_samps(torch.tensor(config.onset_time), config.fs)

        # Channel configuration
        self.in_ch = config.in_ch
        self.out_ch = config.out_ch

        # Delay configuration
        self.delay_lengths = torch.tensor(
            delay_lengths, device=self.device, dtype=torch.int64
        )

    def _setup_fdn_system(self, config: BaseConfig, output_layer: str) -> None:
        """Setup the complete FDN system."""
        # Create FDN branches: A = recursive FDN, B = direct path
        branch_a = self._create_fdn_branch(
            config.attenuation_config, config.mixing_matrix_config
        )
        branch_b = self._create_direct_path(config)

        # Combine branches
        fdn_core = system.Parallel(brA=branch_a, brB=branch_b, sum_output=True)

        # Setup I/O layers (the shell runs input FFT -> core -> output layer)
        input_layer = dsp.FFT(self.nfft)
        output_layer = self._create_output_layer(output_layer)

        # Create shell
        self.shell = system.Shell(
            core=fdn_core,
            input_layer=input_layer,
            output_layer=output_layer,
        )

    def _create_output_layer(self, output_type: str):
        """Create the appropriate output layer based on type.

        "time" -> anti-aliased inverse FFT; "freq_complex" -> identity
        (complex spectrum); "freq_mag" -> magnitude spectrum.
        """
        if output_type == "time":
            return dsp.iFFTAntiAlias(nfft=self.nfft, alias_decay_db=self.alias_decay_db)
        elif output_type == "freq_complex":
            return dsp.Transform(transform=lambda x: x)
        elif output_type == "freq_mag":
            return dsp.Transform(transform=lambda x: torch.abs(x))
        else:
            raise ValueError(f"Unsupported output layer type: {output_type}")

    def _create_fdn_branch(
        self, attenuation_config: FDNAttenuation, mixing_matrix_config: FDNMixing
    ):
        """Create the main FDN branch (branch A)."""
        # Input and output gains
        input_gain = dsp.Gain(
            size=(self.N, self.in_ch),
            nfft=self.nfft,
            requires_grad=self.requires_grad,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )

        output_gain = dsp.Gain(
            size=(self.out_ch, self.N),
            nfft=self.nfft,
            requires_grad=self.requires_grad,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )

        # Feedback loop components
        delays = self._create_delay_lines()
        mixing_matrix = self._create_mixing_matrix(mixing_matrix_config)
        attenuation = self._create_attenuation(attenuation_config)

        # Feedback path: mixing matrix followed by attenuation
        feedback = system.Series(
            OrderedDict({"mixing_matrix": mixing_matrix, "attenuation": attenuation})
        )

        # Recursion: delays in the forward path, mixing+attenuation fed back
        feedback_loop = system.Recursion(fF=delays, fB=feedback)

        # Complete FDN branch
        return system.Series(
            OrderedDict(
                {
                    "input_gain": input_gain,
                    "feedback_loop": feedback_loop,
                    "output_gain": output_gain,
                }
            )
        )

    def _create_delay_lines(self):
        """Create parallel delay lines."""
        delays = dsp.parallelDelay(
            size=(self.N,),
            max_len=self.delay_lengths.max(),
            nfft=self.nfft,
            isint=True,
            requires_grad=False,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )
        # Delays are fixed (requires_grad=False); assign the configured
        # lengths, converted from samples to the module's internal unit.
        delays.assign_value(delays.sample2s(self.delay_lengths))
        return delays

    def _create_mixing_matrix(self, config: FDNMixing):
        """Create orthogonal mixing matrix."""
        if config.is_scattering or config.is_velvet_noise:
            # Random left/right delays for the scattering stages, capped at a
            # tenth of the shortest delay line.
            # NOTE(review): torch.randint requires high > low, so this raises
            # when min(delay_lengths) < 20 — confirm intended minimum.
            m_L = torch.randint(
                low=1,
                high=int(torch.floor(min(self.delay_lengths) / 10)),
                size=[self.N],
            )
            m_R = torch.randint(
                low=1,
                high=int(torch.floor(min(self.delay_lengths) / 10)),
                size=[self.N],
            )
            if config.is_scattering:
                mixing = dsp.ScatteringMatrix(
                    size=(config.n_stages, self.N, self.N),
                    nfft=self.nfft,
                    sparsity=config.sparsity,
                    gain_per_sample=1.0,
                    m_L=m_L,
                    m_R=m_R,
                    requires_grad=self.requires_grad,
                    alias_decay_db=self.alias_decay_db,
                    device=self.device,
                )
            else:
                mixing = dsp.VelvetNoiseMatrix(
                    size=(config.n_stages, self.N, self.N),
                    nfft=self.nfft,
                    density=1 / config.sparsity,
                    gain_per_sample=1.0,
                    m_L=m_L,
                    m_R=m_R,
                    alias_decay_db=self.alias_decay_db,
                    device=self.device,
                )
        elif config.mixing_type == "householder":
            mixing = dsp.HouseholderMatrix(
                size=(self.N, self.N),
                nfft=self.nfft,
                requires_grad=self.requires_grad,
                alias_decay_db=self.alias_decay_db,
                device=self.device,
            )
        else:
            # NOTE(review): bare `except:` swallows everything (including
            # KeyboardInterrupt) and hides the original error — prefer
            # `except Exception as e: raise ValueError(...) from e`.
            try:
                mixing = dsp.Matrix(
                    size=(self.N, self.N),
                    nfft=self.nfft,
                    matrix_type=config.mixing_type,
                    requires_grad=self.requires_grad,
                    alias_decay_db=self.alias_decay_db,
                    device=self.device,
                )  # TODO add hadamard, tiny rotation
            except:
                raise ValueError(f"Unsupported mixing type: {config.mixing_type}")
        return mixing

    def _create_direct_path(self, config: BaseConfig):
        """Create the direct path branch (branch B)."""
        # Pure delay implementing the direct-path onset time.
        onset_delay = dsp.parallelDelay(
            size=(self.in_ch,),
            max_len=self.onset,
            nfft=self.nfft,
            isint=True,
            requires_grad=False,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )

        if config.early_reflections_type == "FIR":
            # FIR early reflections: length spans from the onset up to the
            # shortest FDN delay line so the early path ends where the FDN
            # response begins.
            L = self.delay_lengths.min()
            early_reflections = dsp.parallelFilter(
                size=(L-self.onset, self.in_ch),
                nfft=self.nfft,
                requires_grad=False,
                map=lambda x: x,
                alias_decay_db=self.alias_decay_db,
                device=self.device,
            )
        else:
            # 'gain' or None: a plain (fixed) gain stage.
            early_reflections = dsp.Gain(
                size=(self.in_ch, self.out_ch),
                nfft=self.nfft,
                requires_grad=False,
                map=lambda x: x,
                alias_decay_db=self.alias_decay_db,
                device=self.device,
            )

        self._configure_onset(onset_delay, early_reflections)

        return system.Series(
            OrderedDict(
                {
                    "onset_delay": onset_delay,
                    "early_reflections": early_reflections,
                }
            )
        )

    def _configure_onset(self, onset_delay, early_reflections):
        """Configure onset behavior based on early_reflections_type."""
        # Ensure onset has correct number of values (one per input channel)
        if len(self.onset) != self.in_ch:
            self.onset = self.onset.repeat(self.in_ch)
        if self.early_reflections_type is None:
            # Direct path disabled: zero delay and zero gain.
            # NOTE(review): the zeros shape (in_ch, 1) assumes out_ch == 1
            # for the Gain sized (in_ch, out_ch) above — confirm.
            onset_delay.assign_value(
                onset_delay.sample2s(torch.zeros((self.in_ch,), device=self.device))
            )
            early_reflections.assign_value(torch.zeros((self.in_ch, 1)))

        elif self.early_reflections_type == "gain":
            onset_delay.assign_value(onset_delay.sample2s(torch.tensor(self.onset)))
            early_reflections.assign_value(torch.randn((self.in_ch, 1)))

        elif self.early_reflections_type == "FIR":
            # Fill the FIR taps with a velvet-noise sequence.
            # NOTE(review): `self.fs / early_reflections.size[0] + 1` is
            # (fs / size) + 1 by precedence — confirm the +1 is not meant
            # to apply to the denominator.
            velvet_noise = signal_gallery(
                batch_size=1,
                n_samples=early_reflections.size[0],
                n=self.in_ch,
                signal_type="velvet",
                fs=self.fs,
                rate=max(int(torch.rand(1,) / 100 * self.fs), self.fs / early_reflections.size[0] + 1),
            ).squeeze(0)
            early_reflections.assign_value(velvet_noise)
        else:
            raise ValueError(f"Unsupported onset type: {self.early_reflections_type}")

    def _create_attenuation(self, config: FDNAttenuation):
        """Create attenuation based on configuration type.

        Dispatches on config.attenuation_type to one of the three builders.
        """
        if config.attenuation_type == "homogeneous":
            return self._create_homogeneous_attenuation(config)
        elif config.attenuation_type == "geq":
            return self._create_geq_attenuation(config)
        elif config.attenuation_type == "first_order_lp":
            return self._create_first_order_attenuation(config)
        else:
            raise ValueError(f"Unsupported attenuation type: {config.attenuation_type}")

    def _create_homogeneous_attenuation(self, config: FDNAttenuation):
        """Create homogeneous attenuation (frequency-independent gain per line)."""
        attenuation = dsp.parallelGain(
            size=(self.N,),
            nfft=self.nfft,
            requires_grad=False,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )
        # Map raw parameters to per-line gains scaled by delay length.
        attenuation.map = MapGamma(self.delay_lengths)

        # NOTE(review): `== None` should be `is None` (PEP 8 identity check).
        if config.attenuation_param == None:
            # Random attenuation (T60 in seconds) drawn uniformly within range
            random_rt = (
                torch.rand((1,), device=self.device)
                * (config.attenuation_range[1] - config.attenuation_range[0])
                + config.attenuation_range[0]
            )
            attenuation_value = self._calculate_attenuation_value(random_rt)
        else:
            # Use specific attenuation parameter
            attenuation_value = self._calculate_attenuation_value(
                torch.tensor(config.attenuation_param, device=self.device)
            )

        attenuation.assign_value(attenuation_value)
        return attenuation

    def _calculate_attenuation_value(self, rt_value: torch.Tensor) -> torch.Tensor:
        """Calculate per-line linear attenuation from an RT (T60) value.

        Converts RT to a dB-per-sample slope via rt2slope, broadcasts it to
        the N delay lines, and converts dB to linear gain (10^(dB/20)).
        """
        return 10 ** (
            (rt2slope(rt_value, self.fs) * torch.ones((self.N,), device=self.device))
            / 20
        )

    def _create_geq_attenuation(self, config: FDNAttenuation):
        """Create GEQ-based attenuation (per-band T60 control)."""

        # NOTE(review): device=None here while every other module receives
        # self.device — confirm this is intentional.
        attenuation = parallelFDNAccurateGEQ(
            octave_interval=config.t60_octave_interval,
            nfft=self.nfft,
            fs=self.fs,
            delays=self.delay_lengths,
            alias_decay_db=self.alias_decay_db,
            start_freq=config.t60_center_freq[0],
            end_freq=config.t60_center_freq[-1],
            device=None,
        )
        # Requires config.attenuation_param to be set; only row 0 is used.
        attenuation.assign_value(
            torch.tensor(config.attenuation_param[0], device=self.device)
        )
        return attenuation

    def _create_first_order_attenuation(self, config: FDNAttenuation):
        """Create first-order shelving attenuation."""

        attenuation = parallelFirstOrderShelving(
            nfft=self.nfft,
            fs=self.fs,
            rt_nyquist=config.rt_nyquist,
            delays=self.delay_lengths,
            alias_decay_db=self.alias_decay_db,
            device=self.device,
        )
        # Requires config.attenuation_param to be set; only row 0 is used.
        attenuation.assign_value(
            torch.tensor(config.attenuation_param[0], device=self.device)
        )
        return attenuation