yoyolicoris committed on
Commit
df0ae2d
·
1 Parent(s): 40b18c2

add regression model checkpoints and necessary dependencies

Browse files
ltng/regression.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ import lightning.pytorch as pl
6
+ from typing import Tuple, List, Optional
7
+
8
+
9
+ class ParamPrediction(pl.LightningModule):
10
+ def __init__(
11
+ self,
12
+ predictor: nn.Module,
13
+ condition: str = "wet",
14
+ **kwargs,
15
+ ) -> None:
16
+ super().__init__()
17
+
18
+ self.predictor = predictor
19
+ self.condition = condition
20
+
21
+ def forward(
22
+ self,
23
+ wet: Optional[torch.Tensor] = None,
24
+ dry: Optional[torch.Tensor] = None,
25
+ ):
26
+ match self.condition:
27
+ case "wet":
28
+ return self.predictor(wet)
29
+ case "dry":
30
+ return self.predictor(dry)
31
+ case "both":
32
+ return self.predictor(wet, dry)
33
+ case _:
34
+ raise ValueError(f"Invalid condition: {self.condition}")
35
+
36
+ def training_step(self, batch, batch_idx):
37
+ x, cond, dry, rel_path = batch
38
+ pred = self(cond, dry)
39
+
40
+ loss = F.mse_loss(pred, x)
41
+
42
+ self.log("loss", loss.item(), prog_bar=True, sync_dist=True)
43
+
44
+ return loss
45
+
46
+ def on_validation_epoch_start(self) -> None:
47
+ self.tmp_val_outputs = []
48
+
49
+ def validation_step(self, batch, batch_idx):
50
+ x, cond, dry, *_ = batch
51
+
52
+ pred = self(cond, dry)
53
+ loss = F.mse_loss(pred, x)
54
+
55
+ values = {
56
+ "loss": loss.item(),
57
+ "N": x.shape[0],
58
+ }
59
+ self.tmp_val_outputs.append(values)
60
+ return loss
61
+
62
+ def on_validation_epoch_end(self) -> None:
63
+ outputs = self.tmp_val_outputs
64
+ weights = [x["N"] for x in outputs]
65
+ avg_loss = np.average([x["loss"] for x in outputs], weights=weights)
66
+
67
+ self.log_dict(
68
+ {
69
+ "val_loss": avg_loss,
70
+ },
71
+ prog_bar=True,
72
+ sync_dist=True,
73
+ )
74
+
75
+ delattr(self, "tmp_val_outputs")
76
+
77
+ def on_test_epoch_start(self) -> None:
78
+ self.tmp_test_outputs = []
79
+
80
+ def test_step(self, batch, batch_idx):
81
+ x, cond, dry, *_ = batch
82
+
83
+ pred = self(cond, dry)
84
+ loss = F.mse_loss(pred, x)
85
+
86
+ values = {
87
+ "loss": loss.item(),
88
+ "N": x.shape[0],
89
+ }
90
+ self.tmp_test_outputs.append(values)
91
+ return loss
92
+
93
+ def on_test_epoch_end(self) -> None:
94
+ outputs = self.tmp_test_outputs
95
+ weights = [x["N"] for x in outputs]
96
+ avg_loss = np.average([x["loss"] for x in outputs], weights=weights)
97
+
98
+ self.log_dict(
99
+ {
100
+ "test_loss": avg_loss,
101
+ },
102
+ prog_bar=True,
103
+ sync_dist=True,
104
+ )
105
+
106
+ delattr(self, "tmp_test_outputs")
modules/encoder.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from functools import partial, reduce
5
+ from typing import Optional, List
6
+
7
+ from .utils import chain_functions
8
+
9
+
10
class LogSpectralCentroid(nn.Module):
    """Log of the spectral centroid over normalised frequency [0, 1].

    Input ``spec`` is assumed to be (..., freq, time); output is
    (..., 1, time).
    """

    def forward(self, spec):
        freq_axis = torch.linspace(0, 1, spec.size(-2), device=spec.device)
        energy = spec.transpose(-1, -2)  # (..., time, freq)
        total = energy.sum(-1, keepdim=True).clamp_min(1e-8)
        centroid = (energy / total) @ freq_axis  # (..., time)
        return (centroid + 1e-8).log().unsqueeze(-2)
17
+
18
+
19
class LogSpectralFlatness(nn.Module):
    """Log spectral flatness: log geometric mean minus log arithmetic mean
    of the power spectrum along the frequency axis.

    Input ``spec`` is assumed to be (..., freq, time); output is
    (..., 1, time). Always <= 0; closer to 0 means flatter (noisier) spectrum.
    """

    def forward(self, spec):
        power = spec.clamp(1e-8).square()
        geometric = power.log().mean(-2, keepdim=True)
        arithmetic = power.mean(-2, keepdim=True).log()
        return geometric - arithmetic
26
+
27
+
28
+ class LogSpectralBandwidth(nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+ self.centroid = LogSpectralCentroid()
32
+
33
+ def forward(self, spec):
34
+ # assume spec is of shape (..., freq, time)
35
+ freqs = torch.linspace(0, 1, spec.size(-2), device=spec.device)
36
+ centroid = self.centroid(spec).exp()
37
+ normalised_spec = spec / spec.sum(-2, keepdim=True).clamp_min(1e-8)
38
+ return (
39
+ torch.log(
40
+ (normalised_spec * (freqs[:, None] - centroid).square()).sum(
41
+ -2, keepdim=True
42
+ )
43
+ + 1e-8
44
+ )
45
+ * 0.5
46
+ )
47
+
48
+
49
class LogRMS(nn.Module):
    """Log root-mean-square of each frame along the sample axis (-2).

    Input ``frame`` is assumed to be (..., frame_length, n_frames); output
    is (..., 1, n_frames).
    """

    def forward(self, frame):
        mean_square = frame.square().mean(-2, keepdim=True)
        return (mean_square.sqrt() + 1e-8).log()
52
+
53
+
54
+ class LogCrest(nn.Module):
55
+ def __init__(self):
56
+ super().__init__()
57
+ self.rms = LogRMS()
58
+
59
+ def forward(self, frame):
60
+ log_rms = self.rms(frame)
61
+ return frame.abs().amax(-2, keepdim=True).add(1e-8).log() - log_rms
62
+
63
+
64
+ class LogSpread(nn.Module):
65
+ def __init__(self):
66
+ super().__init__()
67
+ self.rms = LogRMS()
68
+
69
+ def forward(self, frame):
70
+ log_rms = self.rms(frame)
71
+ return (frame.abs().add(1e-8).log() - log_rms).mean(-2, keepdim=True)
72
+
73
+
74
class MapAndMerge(nn.Module):
    """Apply each module in ``funcs`` to the same input and concatenate
    the results along ``dim``."""

    def __init__(self, funcs: List[nn.Module], dim=-1):
        super().__init__()
        # ModuleList so sub-module parameters are registered/trained.
        self.funcs = nn.ModuleList(funcs)
        self.dim = dim

    def forward(self, frame):
        branch_outputs = [fn(frame) for fn in self.funcs]
        return torch.cat(branch_outputs, dim=self.dim)
82
+
83
+
84
class Frame(nn.Module):
    """Slice a waveform into overlapping frames.

    Output is (..., frame_length, n_frames). With ``center=True`` the
    waveform is zero-padded by frame_length // 2 on each side first.
    """

    def __init__(self, frame_length, hop_length, center=False):
        super().__init__()
        self.frame_length = frame_length
        self.hop_length = hop_length
        self.center = center

    def forward(self, waveform):
        if self.center:
            half = self.frame_length // 2
            waveform = F.pad(waveform, (half, half))
        frames = waveform.unfold(-1, self.frame_length, self.hop_length)
        # unfold yields (..., n_frames, frame_length); put samples on -2.
        return frames.transpose(-1, -2)
95
+
96
+
97
class StatisticReduction(nn.Module):
    """Reduce ``dim`` to four summary statistics: mean, (population) std,
    skewness, and excess kurtosis, concatenated along the same ``dim``."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        mean = x.mean(self.dim, keepdim=True)
        centered = x - mean
        stdev = centered.square().mean(self.dim, keepdim=True).sqrt()
        # Clamp avoids division by zero for constant inputs.
        z = centered / stdev.clamp_min(1e-8)
        skewness = z.pow(3).mean(self.dim, keepdim=True)
        # Subtract 3 so a Gaussian has zero (excess) kurtosis.
        excess_kurtosis = z.pow(4).mean(self.dim, keepdim=True) - 3
        return torch.cat([mean, stdev, skewness, excess_kurtosis], dim=self.dim)
reg-ckpts/checkpoints/epoch=99-step=6500-val_loss=0.842.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3559ca46370e3d00c96498107e81e98119d65c9ee3ecc2bd45e5d92f8b51c9a5
3
+ size 111225779
reg-ckpts/config.yaml ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # lightning.pytorch==2.4.0
2
+ seed_everything: false
3
+ trainer:
4
+ accelerator: gpu
5
+ strategy: auto
6
+ devices: 1
7
+ num_nodes: 1
8
+ precision: null
9
+ logger:
10
+ class_path: lightning.pytorch.loggers.WandbLogger
11
+ init_args:
12
+ name: null
13
+ save_dir: .
14
+ version: null
15
+ offline: false
16
+ dir: null
17
+ id: null
18
+ anonymous: null
19
+ project: vocal-fx-regression
20
+ log_model: false
21
+ experiment: null
22
+ prefix: ''
23
+ checkpoint_name: null
24
+ job_type: null
25
+ config: null
26
+ entity: null
27
+ reinit: null
28
+ tags: null
29
+ group: null
30
+ notes: null
31
+ magic: null
32
+ config_exclude_keys: null
33
+ config_include_keys: null
34
+ mode: null
35
+ allow_val_change: null
36
+ resume: null
37
+ force: null
38
+ tensorboard: null
39
+ sync_tensorboard: null
40
+ monitor_gym: null
41
+ save_code: null
42
+ fork_from: null
43
+ resume_from: null
44
+ settings: null
45
+ callbacks:
46
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
47
+ init_args:
48
+ dirpath: null
49
+ filename: '{epoch}-{step}-{val_loss:.3f}'
50
+ monitor: val_loss
51
+ verbose: false
52
+ save_last: true
53
+ save_top_k: 3
54
+ save_weights_only: false
55
+ mode: min
56
+ auto_insert_metric_name: true
57
+ every_n_train_steps: null
58
+ train_time_interval: null
59
+ every_n_epochs: 10
60
+ save_on_train_epoch_end: null
61
+ enable_version_counter: true
62
+ fast_dev_run: false
63
+ max_epochs: null
64
+ min_epochs: null
65
+ max_steps: 100000
66
+ min_steps: null
67
+ max_time: null
68
+ limit_train_batches: null
69
+ limit_val_batches: null
70
+ limit_test_batches: null
71
+ limit_predict_batches: null
72
+ overfit_batches: 0.0
73
+ val_check_interval: null
74
+ check_val_every_n_epoch: 10
75
+ num_sanity_val_steps: 2
76
+ log_every_n_steps: 1
77
+ enable_checkpointing: null
78
+ enable_progress_bar: null
79
+ enable_model_summary: null
80
+ accumulate_grad_batches: 1
81
+ gradient_clip_val: null
82
+ gradient_clip_algorithm: null
83
+ deterministic: null
84
+ benchmark: null
85
+ inference_mode: true
86
+ use_distributed_sampler: true
87
+ profiler: null
88
+ detect_anomaly: false
89
+ barebones: false
90
+ plugins: null
91
+ sync_batchnorm: false
92
+ reload_dataloaders_every_n_epochs: 0
93
+ default_root_dir: null
94
+ ckpt_path: null
95
+ data:
96
+ class_path: ltng.aug_data.GenDataModule
97
+ init_args:
98
+ train_root: /data2/chin-yun/sub_train
99
+ batch_size: 64
100
+ val_root: /data2/chin-yun/sub_val
101
+ test_root: null
102
+ optimizer:
103
+ class_path: torch.optim.AdamW
104
+ init_args:
105
+ lr: 0.001
106
+ betas:
107
+ - 0.9
108
+ - 0.999
109
+ eps: 1.0e-08
110
+ weight_decay: 0.01
111
+ amsgrad: false
112
+ maximize: false
113
+ foreach: null
114
+ capturable: false
115
+ differentiable: false
116
+ fused: null
117
+ model:
118
+ class_path: ltng.regression.ParamPrediction
119
+ init_args:
120
+ predictor:
121
+ class_path: modules.model.LightningSequential
122
+ init_args:
123
+ modules:
124
+ - class_path: modules.encoder.MapAndMerge
125
+ init_args:
126
+ funcs:
127
+ - class_path: torch.nn.Identity
128
+ - class_path: modules.fx.Hadamard
129
+ dim: 1
130
+ - class_path: modules.encoder.MapAndMerge
131
+ init_args:
132
+ funcs:
133
+ - class_path: modules.model.LightningSequential
134
+ init_args:
135
+ modules:
136
+ - class_path: modules.encoder.Frame
137
+ init_args:
138
+ frame_length: 1024
139
+ hop_length: 256
140
+ center: true
141
+ - class_path: modules.encoder.MapAndMerge
142
+ init_args:
143
+ funcs:
144
+ - class_path: modules.encoder.LogRMS
145
+ - class_path: modules.encoder.LogCrest
146
+ - class_path: modules.encoder.LogSpread
147
+ dim: -2
148
+ - class_path: modules.model.LogMelSpectrogram
149
+ init_args:
150
+ sample_rate: 44100
151
+ n_fft: 1024
152
+ win_length: null
153
+ hop_length: 256
154
+ f_min: 0.0
155
+ f_max: null
156
+ pad: 0
157
+ n_mels: 80
158
+ window_fn: torch.hann_window
159
+ power: 2.0
160
+ normalized: false
161
+ wkwargs: null
162
+ center: true
163
+ pad_mode: reflect
164
+ onesided: null
165
+ norm: null
166
+ mel_scale: htk
167
+ dim: -2
168
+ - class_path: torch.nn.Flatten
169
+ init_args:
170
+ start_dim: 1
171
+ end_dim: -2
172
+ - class_path: torch.nn.Conv1d
173
+ init_args:
174
+ in_channels: 332
175
+ out_channels: 512
176
+ kernel_size: 5
177
+ stride: 1
178
+ padding: 0
179
+ dilation: 1
180
+ groups: 1
181
+ bias: true
182
+ padding_mode: zeros
183
+ device: null
184
+ dtype: null
185
+ - class_path: torch.nn.AvgPool1d
186
+ init_args:
187
+ kernel_size: 3
188
+ stride: 3
189
+ padding: 0
190
+ ceil_mode: false
191
+ count_include_pad: true
192
+ - class_path: torch.nn.BatchNorm1d
193
+ init_args:
194
+ num_features: 512
195
+ eps: 1.0e-05
196
+ momentum: 0.1
197
+ affine: true
198
+ track_running_stats: true
199
+ device: null
200
+ dtype: null
201
+ - class_path: torch.nn.ReLU
202
+ init_args:
203
+ inplace: false
204
+ - class_path: torch.nn.Conv1d
205
+ init_args:
206
+ in_channels: 512
207
+ out_channels: 512
208
+ kernel_size: 5
209
+ stride: 1
210
+ padding: 0
211
+ dilation: 1
212
+ groups: 1
213
+ bias: true
214
+ padding_mode: zeros
215
+ device: null
216
+ dtype: null
217
+ - class_path: torch.nn.AvgPool1d
218
+ init_args:
219
+ kernel_size: 3
220
+ stride: 3
221
+ padding: 0
222
+ ceil_mode: false
223
+ count_include_pad: true
224
+ - class_path: torch.nn.BatchNorm1d
225
+ init_args:
226
+ num_features: 512
227
+ eps: 1.0e-05
228
+ momentum: 0.1
229
+ affine: true
230
+ track_running_stats: true
231
+ device: null
232
+ dtype: null
233
+ - class_path: torch.nn.ReLU
234
+ init_args:
235
+ inplace: false
236
+ - class_path: torch.nn.Conv1d
237
+ init_args:
238
+ in_channels: 512
239
+ out_channels: 768
240
+ kernel_size: 5
241
+ stride: 1
242
+ padding: 0
243
+ dilation: 1
244
+ groups: 1
245
+ bias: true
246
+ padding_mode: zeros
247
+ device: null
248
+ dtype: null
249
+ - class_path: torch.nn.AvgPool1d
250
+ init_args:
251
+ kernel_size: 3
252
+ stride: 3
253
+ padding: 0
254
+ ceil_mode: false
255
+ count_include_pad: true
256
+ - class_path: torch.nn.BatchNorm1d
257
+ init_args:
258
+ num_features: 768
259
+ eps: 1.0e-05
260
+ momentum: 0.1
261
+ affine: true
262
+ track_running_stats: true
263
+ device: null
264
+ dtype: null
265
+ - class_path: torch.nn.ReLU
266
+ init_args:
267
+ inplace: false
268
+ - class_path: torch.nn.Conv1d
269
+ init_args:
270
+ in_channels: 768
271
+ out_channels: 1024
272
+ kernel_size: 5
273
+ stride: 1
274
+ padding: 0
275
+ dilation: 1
276
+ groups: 1
277
+ bias: true
278
+ padding_mode: zeros
279
+ device: null
280
+ dtype: null
281
+ - class_path: torch.nn.AvgPool1d
282
+ init_args:
283
+ kernel_size: 3
284
+ stride: 3
285
+ padding: 0
286
+ ceil_mode: false
287
+ count_include_pad: true
288
+ - class_path: torch.nn.BatchNorm1d
289
+ init_args:
290
+ num_features: 1024
291
+ eps: 1.0e-05
292
+ momentum: 0.1
293
+ affine: true
294
+ track_running_stats: true
295
+ device: null
296
+ dtype: null
297
+ - class_path: torch.nn.ReLU
298
+ init_args:
299
+ inplace: false
300
+ - class_path: torch.nn.Conv1d
301
+ init_args:
302
+ in_channels: 1024
303
+ out_channels: 1024
304
+ kernel_size: 1
305
+ stride: 1
306
+ padding: 0
307
+ dilation: 1
308
+ groups: 1
309
+ bias: true
310
+ padding_mode: zeros
311
+ device: null
312
+ dtype: null
313
+ - class_path: torch.nn.AdaptiveMaxPool1d
314
+ init_args:
315
+ output_size: 1
316
+ return_indices: false
317
+ - class_path: torch.nn.Flatten
318
+ init_args:
319
+ start_dim: 1
320
+ end_dim: -1
321
+ - class_path: torch.nn.Linear
322
+ init_args:
323
+ in_features: 1024
324
+ out_features: 130
325
+ bias: true
326
+ device: null
327
+ dtype: null
328
+ condition: wet
reg-ckpts/param_stats.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ddbef7000cb8d9ac735dfb3ccd6429df0668532c8779ac52774c032fb9058b4e
3
+ size 2480