primepake committed
Commit 067b9b6 · Parent: 0672778

add training code and model

flowae/configs/datasets/dae.yaml CHANGED
@@ -62,7 +62,7 @@ datasets:
 
   # Visualization
   visualize_ae_dir: /mnt/nvme/dito_audio
- visualize_ae_random_n_samples: 32
+ visualize_ae_random_n_samples: 8
  eval_ae_max_samples: 100
  val_idx: [0, 1, 2, 3, 4, 5, 6, 7]
 
flowae/configs/experiments/dito-B-audio.yaml CHANGED
@@ -43,6 +43,6 @@ model:
   name: fm
   args: {timescale: 1000.0}
 
- render_sampler: {name: fm_euler_sampler_audio}
+ render_sampler: {name: fm_euler_sampler}
  render_n_steps: 50
 
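For context, `fm_euler_sampler` names a flow-matching Euler sampler; samplers of this kind integrate the learned velocity field from noise to data with fixed-size Euler steps. A minimal sketch under that assumption (`velocity_net` and its call signature are placeholders, not the repo's actual sampler API):

    import torch

    def fm_euler_sample(velocity_net, noise, n_steps=50):
        # Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data) with Euler steps.
        x = noise
        dt = 1.0 / n_steps
        for i in range(n_steps):
            t = torch.full((x.shape[0],), i * dt, device=x.device)
            v = velocity_net(x, t)  # predicted velocity at (x, t)
            x = x + dt * v          # Euler update
        return x

With `render_n_steps: 50`, the renderer would take 50 such steps per sample.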
flowae/configs/trainers/dito.yaml ADDED
@@ -0,0 +1,19 @@
+ trainer: audio_ldm_trainer
+ 
+ autocast_bfloat16: true
+ 
+ max_iter: 300000
+ epoch_iter: 10000
+ eval_iter: 50000
+ save_iter: 50000
+ vis_iter: 50000
+ 
+ optimizers:
+   encoder:
+     name: adamw
+     args: {lr: 1.e-4}
+   renderer:
+     name: adamw
+     args: {lr: 1.e-4}
+ 
+ evaluate_ae: true
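The optimizer specs above are consumed by `utils.make_optimizer` in the trainers. A minimal sketch of how a `{name: ..., args: {...}}` spec could map onto `torch.optim` (the repo's actual factory may differ; this is an illustration, not its implementation):

    import torch

    def make_optimizer(params, spec):
        # Map a config spec like {'name': 'adamw', 'args': {'lr': 1e-4}}
        # onto the corresponding torch.optim constructor.
        opt_classes = {'adam': torch.optim.Adam, 'adamw': torch.optim.AdamW}
        return opt_classes[spec['name']](params, **spec.get('args', {}))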
flowae/configs/trainers/glpto.yaml ADDED
@@ -0,0 +1,24 @@
+ trainer: ldm_trainer
+ 
+ autocast_bfloat16: true
+ 
+ max_iter: 300000
+ epoch_iter: 10000
+ eval_iter: 50000
+ save_iter: 50000
+ vis_iter: 50000
+ 
+ optimizers:
+   encoder:
+     name: adam
+     args: {lr: 1.e-4, betas: [0.5, 0.9]}
+   renderer:
+     name: adam
+     args: {lr: 1.e-4, betas: [0.5, 0.9]}
+   disc:
+     name: adam
+     args: {lr: 1.e-4, betas: [0.5, 0.9]}
+ gan_start_after_iters: 50000
+ find_unused_parameters: true
+ 
+ evaluate_ae: true
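This config adds a discriminator (`disc`) whose optimizer is stepped separately, and `gan_start_after_iters: 50000` suggests the adversarial term is held off until the autoencoder has warmed up. A hypothetical sketch of that gating (the real logic lives in the model/trainer code, not in this config; names here are assumptions):

    def total_loss(recon_loss, gan_loss, it, gan_start_after_iters=50000):
        # Train reconstruction-only first; enable the GAN term after warmup.
        if it < gan_start_after_iters:
            return recon_loss
        return recon_loss + gan_loss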
flowae/configs/trainers/zdm.yaml ADDED
@@ -0,0 +1,20 @@
+ trainer: ldm_trainer
+ 
+ autocast_bfloat16: true
+ 
+ max_iter: 400000
+ epoch_iter: 10000
+ eval_iter: 100000
+ save_iter: 100000
+ vis_iter: 100000
+ ckpt_select_metric:
+   name: zdm_ema_loss
+   type: min
+ 
+ optimizers:
+   zdm:
+     name: adamw
+     args: {lr: 1.e-4, weight_decay: 0.0}
+ find_unused_parameters: true
+ 
+ evaluate_zdm: true
flowae/models/networks/consistency_audio_decoder_unet.py CHANGED
@@ -135,13 +135,13 @@ class AudioUpsample(nn.Module):
 
         gn_1 = F.silu(self.gn_1(x))
         # 1D interpolation upsampling
-        upsample = F.interpolate(gn_1, scale_factor=self.upsample_factor, mode='nearest')
+        upsample = F.interpolate(gn_1, scale_factor=self.upsample_factor, mode='linear')
         f_1 = self.f_1(upsample)
         gn_2 = self.gn_2(f_1)
 
         f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
 
-        return f_2 + F.interpolate(x_skip, scale_factor=self.upsample_factor, mode='nearest')
+        return f_2 + F.interpolate(x_skip, scale_factor=self.upsample_factor, mode='linear')
 
 
 @register('audio_diffusion_unet')
@@ -272,7 +272,7 @@ class AudioDiffusionUNet(nn.Module):
         z_proj = F.interpolate(
             z_proj,
             size=x.shape[-1],
-            mode='nearest' # or 'linear' for smoother interpolation
+            mode='linear'  # smoother than 'nearest' for upsampling the latent
        )
 
        # Add latent conditioning to audio features
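The switch from 'nearest' to 'linear' matters for 1D audio upsampling: nearest-neighbor duplicates samples and introduces staircase discontinuities that show up as high-frequency artifacts, while linear interpolation produces a smooth ramp. A small self-contained comparison:

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[[0.0, 1.0, 0.0]]])  # [batch, channels, samples]

    print(F.interpolate(x, scale_factor=2, mode='nearest'))
    # tensor([[[0., 0., 1., 1., 0., 0.]]])  -- duplicated samples
    print(F.interpolate(x, scale_factor=2, mode='linear', align_corners=False))
    # tensor([[[0.0000, 0.2500, 0.7500, 0.7500, 0.2500, 0.0000]]])  -- smooth ramp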
flowae/run.py CHANGED
@@ -13,7 +13,7 @@ def make_args():
     parser.add_argument('--tag', '-t', default=None)
     parser.add_argument('--resume', '-r', action='store_true')
     parser.add_argument('--force-replace', '-f', action='store_true')
-    parser.add_argument('--wandb', '-w', action='store_true')
+    parser.add_argument('--comet', '-c', action='store_true', help='Enable Comet ML logging')
     parser.add_argument('--save-root', default='save')
     parser.add_argument('--eval-only', action='store_true')
     args = parser.parse_args()
@@ -45,7 +45,7 @@ def make_env(args):
     env['exp_name'] = exp_name
 
     env['save_dir'] = os.path.join(args.save_root, exp_name)
-    env['wandb'] = args.wandb
+    env['comet'] = args.comet
     env['resume'] = args.resume
     env['force_replace'] = args.force_replace
     return env
flowae/run.sh CHANGED
@@ -1,2 +1,2 @@
 torchrun --nnodes=1 --nproc-per-node=1 run.py --config configs/experiments/dito-B-f8c4-noise-sync.yaml --save-root /mnt/nvme/dito
-torchrun --nnodes=1 --nproc-per-node=1 run.py --config configs/experiments/dito-B-audio.yaml --save-root /mnt/nvme/ditogit ad
+torchrun --nnodes=1 --nproc-per-node=1 run.py --config configs/experiments/dito-B-audio.yaml --save-root /mnt/nvme/dit2 --comet
flowae/trainers/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .trainers import register, trainers_dict
+ from . import base_trainer
+ from . import ldm_trainer
+ from . import audio_ldm_trainer
flowae/trainers/audio_ldm_trainer.py ADDED
@@ -0,0 +1,715 @@
+ import os
+ import random
+ 
+ import torch
+ import torch.distributed as dist
+ from PIL import Image
+ 
+ import utils
+ from .trainers import register
+ from trainers.base_trainer import BaseTrainer
+ from models.ldm.dac.audiotools import AudioSignal
+ import soundfile as sf
+ import numpy as np
+ import torchaudio
+ import time
+ 
+ from datetime import datetime
+ import matplotlib.pyplot as plt
+ from tqdm import tqdm
+ 
+ 
+ @register('audio_ldm_trainer')
+ class AudioLDMTrainer(BaseTrainer):
+ 
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+ 
+     def make_model(self):
+         super().make_model()
+         self.has_optimizer = dict()
+         total_params = 0
+         for name, m in self.model.named_children():
+             params = utils.compute_num_params(m, text=False)
+             self.log(f'  .{name} {params}')
+             total_params = total_params + params
+             # Log per-module parameter counts to Comet
+             if self.experiment:
+                 self.experiment.log_metric(f"model/{name}_params", params)
+ 
+         if self.experiment:
+             self.experiment.log_metric("model/total_params", total_params)
+ 
+     def make_optimizers(self):
+         self.optimizers = dict()
+         self.has_optimizer = dict()
+         for name, spec in self.config.optimizers.items():
+             self.optimizers[name] = utils.make_optimizer(self.model.get_parameters(name), spec)
+             self.has_optimizer[name] = True
+ 
+             # Log optimizer config to Comet
+             # NOTE: lr/weight_decay live under spec['args'] in the configs, so
+             # these .get() calls fall back to their defaults.
+             if self.experiment:
+                 self.experiment.log_parameters({
+                     f"optimizer/{name}/type": spec.get("type", "adam"),
+                     f"optimizer/{name}/lr": spec.get("lr", 1e-4),
+                     f"optimizer/{name}/weight_decay": spec.get("weight_decay", 0),
+                 })
+ 
+     def train_step(self, data, bp=True):
+         kwargs = {'has_optimizer': self.has_optimizer}
+ 
+         # Start timing
+         step_start_time = time.time()
+ 
+         # Audio-specific data preparation
+         if 'signal' in data:
+             # Convert AudioSignal to the tensor format expected by the model
+             audio_data = data['signal'].audio_data  # [batch, channels, samples]
+             sample_rate = data['signal'].sample_rate
+ 
+             # Prepare data dict for model
+             model_data = {
+                 'inp': audio_data,
+                 'gt': audio_data,  # For autoencoder training
+                 'sample_rate': sample_rate
+             }
+         else:
+             model_data = data
+ 
+         # self.log(f'Audio data shape: {model_data["inp"].shape}')
+ 
+         # Log batch info to Comet
+         if self.experiment and self.iter % 500 == 0:
+             self.experiment.log_metric("train/batch_size", model_data["inp"].shape[0], step=self.iter)
+             self.experiment.log_metric("train/audio_length_samples", model_data["inp"].shape[-1], step=self.iter)
+             self.experiment.log_metric("train/audio_duration_sec",
+                                        model_data["inp"].shape[-1] / model_data.get("sample_rate", 24000),
+                                        step=self.iter)
+ 
+         if self.config.get('autocast_bfloat16', False):
+             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                 ret = self.model_ddp(model_data, mode='loss', **kwargs)
+         else:
+             ret = self.model_ddp(model_data, mode='loss', **kwargs)
+ 
+         loss = ret.pop('loss')
+         ret['loss'] = loss.item()
+ 
+         if bp:
+             self.model_ddp.zero_grad(set_to_none=True)
+             loss.backward()
+ 
+             # Log gradients to Comet
+             if self.experiment and self.iter % 5 == 0:
+                 self._log_gradients()
+ 
+             for name, o in self.optimizers.items():
+                 if name != 'disc':
+                     o.step()
+ 
+             if hasattr(self.model, 'update_ema'):
+                 self.model.update_ema()
+ 
+         # Log training metrics to Comet
+         if self.experiment:
+             # Log all losses
+             for k, v in ret.items():
+                 if 'loss' in k.lower():
+                     self.experiment.log_metric(f"train/{k}", v, step=self.iter)
+ 
+             # Log learning rates
+             for name, opt in self.optimizers.items():
+                 lr = opt.param_groups[0]['lr']
+                 self.experiment.log_metric(f"train/lr_{name}", lr, step=self.iter)
+ 
+             # Log timing
+             step_time = time.time() - step_start_time
+             self.experiment.log_metric("train/step_time", step_time, step=self.iter)
+ 
+             # Log GPU memory usage (in GB)
+             if torch.cuda.is_available():
+                 self.experiment.log_metric("train/gpu_memory_allocated",
+                                            torch.cuda.memory_allocated() / 1e9,
+                                            step=self.iter)
+                 self.experiment.log_metric("train/gpu_memory_reserved",
+                                            torch.cuda.memory_reserved() / 1e9,
+                                            step=self.iter)
+         return ret
+ 
+     def _log_gradients(self):
+         """Log gradient statistics to Comet ML"""
+         if not self.experiment:
+             return
+ 
+         grad_stats = {}
+         for name, param in self.model.named_parameters():
+             if param.grad is not None:
+                 grad_norm = param.grad.norm().item()
+                 grad_mean = param.grad.mean().item()
+                 grad_std = param.grad.std().item()
+ 
+                 # Aggregate stats by top-level module
+                 module_name = name.split('.')[0]
+                 if module_name not in grad_stats:
+                     grad_stats[module_name] = {
+                         'norm': [],
+                         'mean': [],
+                         'std': []
+                     }
+                 grad_stats[module_name]['norm'].append(grad_norm)
+                 grad_stats[module_name]['mean'].append(grad_mean)
+                 grad_stats[module_name]['std'].append(grad_std)
+ 
+         # Log aggregated stats
+         for module, stats in grad_stats.items():
+             self.experiment.log_metric(f"gradients/{module}/norm_mean", np.mean(stats['norm']), step=self.iter)
+             self.experiment.log_metric(f"gradients/{module}/norm_max", np.max(stats['norm']), step=self.iter)
+ 
+     def run_training(self):
+         config = self.config
+         max_iter = config['max_iter']
+         epoch_iter = config['epoch_iter']
+         assert max_iter % epoch_iter == 0
+         max_epoch = max_iter // epoch_iter
+ 
+         save_iter = config.get('save_iter')
+         if save_iter is not None:
+             assert save_iter % epoch_iter == 0
+             save_epoch = save_iter // epoch_iter
+             print('save_epoch', save_epoch)
+         else:
+             save_epoch = max_epoch + 1
+ 
+         eval_iter = config.get('eval_iter')
+         if eval_iter is not None:
+             assert eval_iter % epoch_iter == 0
+             eval_epoch = eval_iter // epoch_iter
+         else:
+             eval_epoch = max_epoch + 1
+ 
+         vis_iter = config.get('vis_iter')
+         if vis_iter is not None:
+             assert vis_iter % epoch_iter == 0
+             vis_epoch = vis_iter // epoch_iter
+         else:
+             vis_epoch = max_epoch + 1
+ 
+         if config.get('ckpt_select_metric') is not None:
+             m = config.ckpt_select_metric
+             self.ckpt_select_metric = m.name
+             self.ckpt_select_type = m.type
+             if m.type == 'min':
+                 self.ckpt_select_v = 1e18
+             elif m.type == 'max':
+                 self.ckpt_select_v = -1e18
+         else:
+             self.ckpt_select_metric = None
+             self.ckpt_select_v = 0
+ 
+         self.train_loader = self.loaders['train']
+         self.train_loader_sampler = self.loader_samplers['train']
+         self.train_loader_epoch = 0
+         self.train_loader_iter = None
+ 
+         self.iter = 0
+ 
+         if self.resume_ckpt is not None:
+             for _ in range(self.resume_ckpt['iter']):
+                 self.iter += 1
+                 self.at_train_iter_start()
+             self.ckpt_select_v = self.resume_ckpt['ckpt_select_v']
+             self.train_loader_epoch = self.resume_ckpt['train_loader_epoch']
+             self.train_loader_iter = None
+             self.resume_ckpt = None
+             self.log('Resumed iter status.')
+ 
+         self.visualize()
+ 
+         start_epoch = self.iter // epoch_iter + 1
+ 
+         for epoch in range(start_epoch, max_epoch + 1):
+             self.log_buffer = [f'Epoch {epoch}']
+ 
+             for sampler in self.loader_samplers.values():
+                 if sampler is not self.train_loader_sampler:
+                     sampler.set_epoch(epoch)
+ 
+             self.model_ddp.train()
+ 
+             pbar = range(1, epoch_iter + 1)
+             if self.is_master and epoch == start_epoch:
+                 pbar = tqdm(pbar, desc='train', leave=False)
+ 
+             t_data = 0
+             t_nondata = 0
+             t_before_data = time.time()
+ 
+             for _ in pbar:
+                 self.iter += 1
+                 self.at_train_iter_start()
+ 
+                 try:
+                     if self.train_loader_iter is None:
+                         raise StopIteration
+                     data = next(self.train_loader_iter)
+                 except StopIteration:
+                     self.train_loader_epoch += 1
+                     self.train_loader_sampler.set_epoch(self.train_loader_epoch)
+                     self.train_loader_iter = iter(self.train_loader)
+                     data = next(self.train_loader_iter)
+ 
+                 t_after_data = time.time()
+                 t_data += t_after_data - t_before_data
+ 
+                 for k, v in data.items():
+                     data[k] = v.to(self.device) if torch.is_tensor(v) else v
+ 
+                 ret = self.train_step(data)
+ 
+                 t_before_data = time.time()
+                 t_nondata += t_before_data - t_after_data
+ 
+                 if self.is_master and epoch == start_epoch:
+                     pbar.set_description(desc=f'train: loss={ret["loss"]:.4f}')
+ 
+                 # Save a checkpoint every 2000 iterations
+                 if self.iter % 2000 == 0:
+                     self.save_ckpt(f'ckpt-{self.iter}.pth')
+ 
+             self.save_ckpt('ckpt-last.pth')
+ 
+             if epoch % save_epoch == 0 and epoch != max_epoch:
+                 self.save_ckpt(f'ckpt-{self.iter}.pth')
+ 
+             if epoch % eval_epoch == 0:
+                 with torch.no_grad():
+                     eval_ave_scalars = self.evaluate()
+                 if self.ckpt_select_metric is not None:
+                     v = eval_ave_scalars[self.ckpt_select_metric].item()
+                     if ((self.ckpt_select_type == 'min' and v < self.ckpt_select_v) or
+                             (self.ckpt_select_type == 'max' and v > self.ckpt_select_v)):
+                         self.ckpt_select_v = v
+                         self.save_ckpt('ckpt-best.pth')
+ 
+             if epoch % vis_epoch == 0:
+                 with torch.no_grad():
+                     self.visualize()
+ 
+     def evaluate(self):
+         self.model_ddp.eval()
+ 
+         ave_scalars = dict()
+         pbar = self.loaders['val']
+ 
+         for data in pbar:
+             # Prepare audio data for GPU
+             if 'signal' in data:
+                 data['signal'] = data['signal'].to(self.device)
+             else:
+                 for k, v in data.items():
+                     data[k] = v.to(self.device) if torch.is_tensor(v) else v
+ 
+             ret = self.train_step(data, bp=False)
+ 
+             bs = data['signal'].batch_size if 'signal' in data else len(next(iter(data.values())))
+             for k, v in ret.items():
+                 if ave_scalars.get(k) is None:
+                     ave_scalars[k] = utils.Averager()
+                 ave_scalars[k].add(v, n=bs)
+ 
+         self.sync_ave_scalars(ave_scalars)
+ 
+         # Audio-specific evaluation
+         if self.config.get('evaluate_ae', False):
+             ave_scalars.update(self.evaluate_audio_ae())
+ 
+         if self.config.get('evaluate_zdm', False):
+             ema = self.config.get('evaluate_zdm_ema', True)
+             ave_scalars.update(self.evaluate_audio_zdm(ema=ema))
+ 
+         logtext = 'val:'
+         for k, v in ave_scalars.items():
+             logtext += f' {k}={v.item():.4f}'
+             self.log_scalar('val/' + k, v.item())
+ 
+             # Log to Comet
+             if self.experiment:
+                 self.experiment.log_metric(f"val/{k}", v.item(), step=self.iter)
+ 
+         self.log_buffer.append(logtext)
+ 
+         return ave_scalars
+ 
+     def visualize(self):
+         self.model_ddp.eval()
+ 
+         if self.config.get('evaluate_ae', False):
+             self.visualize_audio_ae_random()
+ 
+         if self.config.get('evaluate_zdm', False):
+             ema = self.config.get('evaluate_zdm_ema', True)
+             self.visualize_audio_zdm_random(ema=ema)
+ 
+     def evaluate_audio_ae(self):
+         """Audio autoencoder evaluation with spectral metrics"""
+         max_samples = self.config.get('eval_ae_max_samples', 1000)
+         self.loader_samplers['eval_ae'].set_epoch(0)
+ 
+         l1_loss_avg = utils.Averager()
+         snr_avg = utils.Averager()
+         spectral_convergence_avg = utils.Averager()
+         cnt = 0
+ 
+         # Create cache directories for audio samples
+         cache_gen_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gen')
+         cache_gt_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gt')
+         if self.is_master:
+             utils.ensure_path(cache_gen_dir, force_replace=True)
+             utils.ensure_path(cache_gt_dir, force_replace=True)
+         dist.barrier()
+ 
+         for data in self.loaders['eval_ae']:
+             if 'signal' in data:
+                 data['signal'] = data['signal'].to(self.device)
+                 signal = data['signal']
+             else:
+                 for k, v in data.items():
+                     data[k] = v.to(self.device) if torch.is_tensor(v) else v
+                 signal = AudioSignal(data['inp'], data.get('sample_rate', 22050))
+ 
+             # Get reconstruction
+             pred_audio = self.model(data, mode='pred')
+             if isinstance(pred_audio, dict):
+                 pred_audio = pred_audio.get('audio', pred_audio.get('recons', pred_audio))
+ 
+             recons = AudioSignal(pred_audio, signal.sample_rate)
+ 
+             # SNR calculation
+             signal_power = (signal.audio_data ** 2).mean()
+             noise_power = ((recons.audio_data - signal.audio_data) ** 2).mean()
+             snr = 10 * torch.log10(signal_power / (noise_power + 1e-8))
+             snr_avg.add(snr.item())
+ 
+             # Spectral convergence
+             stft_transform = torchaudio.transforms.Spectrogram(
+                 n_fft=1024,
+                 hop_length=256,
+                 power=2
+             ).to(self.device)
+ 
+             orig_spec = stft_transform(signal.audio_data)
+             recon_spec = stft_transform(recons.audio_data)
+ 
+             spec_diff = torch.norm(orig_spec - recon_spec, p='fro')
+             spec_norm = torch.norm(orig_spec, p='fro')
+             spectral_convergence = spec_diff / (spec_norm + 1e-8)
+             spectral_convergence_avg.add(spectral_convergence.item())
+ 
+             l1_loss = torch.nn.functional.l1_loss(recons.audio_data, signal.audio_data).item()
+             l1_loss_avg.add(l1_loss)
+ 
+             # Save audio samples for potential subjective evaluation (up to 5 per batch)
+             for i in range(min(signal.batch_size, 5)):
+                 idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
+                 if max_samples is None or idx < max_samples:
+                     # Save as wav files
+                     sf.write(
+                         os.path.join(cache_gen_dir, f'{idx}.wav'),
+                         recons[i].audio_data.cpu().numpy().T,
+                         int(recons[i].sample_rate)
+                     )
+                     sf.write(
+                         os.path.join(cache_gt_dir, f'{idx}.wav'),
+                         signal[i].audio_data.cpu().numpy().T,
+                         int(signal[i].sample_rate)
+                     )
+                 cnt += 1
+ 
+         dist.barrier()
+ 
+         # Sync metrics across processes
+         for avg_metric in [l1_loss_avg, snr_avg, spectral_convergence_avg]:
+             vt = torch.tensor(avg_metric.item(), device=self.device)
+             dist.all_reduce(vt, op=dist.ReduceOp.SUM)
+             torch.cuda.synchronize()
+             avg_metric.v = vt.item() / int(os.environ['WORLD_SIZE'])
+ 
+         if self.is_master:
+             prefix = 'eval_ae'
+             ret = {
+                 f'{prefix}/L1_Loss': l1_loss_avg.item(),
+                 f'{prefix}/SNR': snr_avg.item(),
+                 f'{prefix}/Spectral_Convergence': spectral_convergence_avg.item(),
+             }
+         else:
+             ret = {}
+         dist.barrier()
+ 
+         ret = {k: utils.Averager(v) for k, v in ret.items()}
+         return ret
+ 
+     def evaluate_audio_zdm(self, ema):
+         """Audio latent diffusion model evaluation"""
+         max_samples = self.config.get('eval_zdm_max_samples', 1000)
+         self.loader_samplers['eval_zdm'].set_epoch(0)
+ 
+         cnt = 0
+         l1_loss_avg = utils.Averager()
+         cache_gen_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gen')
+         cache_gt_dir = os.path.join(self.env['save_dir'], 'cache', 'audio_gt')
+         if self.is_master:
+             utils.ensure_path(cache_gen_dir, force_replace=True)
+             utils.ensure_path(cache_gt_dir, force_replace=True)
+         dist.barrier()
+ 
+         for data in self.loaders['eval_zdm']:
+             if 'signal' in data:
+                 data['signal'] = data['signal'].to(self.device)
+                 gt_signal = data['signal']
+             else:
+                 for k, v in data.items():
+                     data[k] = v.to(self.device) if torch.is_tensor(v) else v
+                 gt_signal = AudioSignal(data['inp'], data.get('sample_rate', 22050))
+ 
+             # Generate samples from latent diffusion model
+             net_kwargs = dict()
+             uncond_net_kwargs = dict()
+             # Add conditioning if available (e.g., for conditional generation)
+ 
+             pred_audio = self.model.generate_samples(
+                 batch_size=gt_signal.batch_size,
+                 n_steps=self.model.zdm_n_steps,
+                 net_kwargs=net_kwargs,
+                 uncond_net_kwargs=uncond_net_kwargs,
+                 ema=ema
+             )
+ 
+             pred_signal = AudioSignal(pred_audio, gt_signal.sample_rate)
+ 
+             l1_loss = torch.nn.functional.l1_loss(pred_signal.audio_data, gt_signal.audio_data).item()
+             l1_loss_avg.add(l1_loss)
+ 
+             # Save samples
+             for i in range(min(gt_signal.batch_size, 5)):
+                 idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
+                 if max_samples is None or idx < max_samples:
+                     sf.write(
+                         os.path.join(cache_gen_dir, f'{idx}.wav'),
+                         pred_signal[i].audio_data.cpu().numpy().T,
+                         int(pred_signal[i].sample_rate)
+                     )
+                     sf.write(
+                         os.path.join(cache_gt_dir, f'{idx}.wav'),
+                         gt_signal[i].audio_data.cpu().numpy().T,
+                         int(gt_signal[i].sample_rate)
+                     )
+                 cnt += 1
+ 
+         dist.barrier()
+ 
+         # Sync metrics
+         for avg_metric in [l1_loss_avg]:
+             vt = torch.tensor(avg_metric.item(), device=self.device)
+             dist.all_reduce(vt, op=dist.ReduceOp.SUM)
+             torch.cuda.synchronize()
+             avg_metric.v = vt.item() / int(os.environ['WORLD_SIZE'])
+ 
+         if self.is_master:
+             prefix = 'eval_zdm' + ('_ema' if ema else '')
+             ret = {
+                 f'{prefix}/l1_loss_avg': l1_loss_avg.item(),
+             }
+         else:
+             ret = {}
+         dist.barrier()
+ 
+         ret = {k: utils.Averager(v) for k, v in ret.items()}
+         return ret
+ 
+     def visualize_audio_ae_random(self):
+         """Save random audio reconstructions for listening"""
+         if self.is_master:
+             idx_list = list(range(len(self.datasets['eval_ae'])))
+             random.shuffle(idx_list)
+             n_samples = self.config.get('visualize_ae_random_n_samples', 8)
+ 
+             audio_samples = []
+ 
+             for idx in idx_list[:n_samples]:
+                 data = self.datasets['eval_ae'][idx]
+ 
+                 # Prepare data
+                 if 'signal' in data:
+                     signal = data['signal'].unsqueeze(0).to(self.device)
+                     model_data = {
+                         'inp': signal.audio_data,
+                         'gt': signal.audio_data,
+                         'sample_rate': signal.sample_rate
+                     }
+                 else:
+                     for k, v in data.items():
+                         data[k] = v.unsqueeze(0).to(self.device) if torch.is_tensor(v) else v
+                     signal = AudioSignal(data['inp'], data.get('sample_rate', 24000))
+                     model_data = data
+ 
+                 # Get reconstruction
+                 pred_audio = self.model(model_data, mode='pred')
+                 if isinstance(pred_audio, dict):
+                     pred_audio = pred_audio.get('audio', pred_audio.get('recons', pred_audio))
+ 
+                 recons = AudioSignal(pred_audio, signal.sample_rate)
+ 
+                 # Save to file and log to Comet
+                 self.save_audio_sample(signal, f'audio_ae_original_{idx}')
+                 self.save_audio_sample(recons, f'audio_ae_recons_{idx}')
+ 
+         dist.barrier()
+ 
+     def visualize_audio_zdm_random(self, ema):
+         """Save random audio generations from latent diffusion model"""
+         if self.is_master:
+             n_samples = self.config.get('visualize_zdm_random_n_samples', 8)
+ 
+             for i in range(n_samples):
+                 # Generate random sample
+                 net_kwargs = dict()
+                 uncond_net_kwargs = dict()
+ 
+                 # Get a reference from dataset for parameters like sample_rate
+                 ref_data = self.datasets['eval_ae'][0]
+                 if 'signal' in ref_data:
+                     ref_signal = ref_data['signal']
+                     sample_rate = ref_signal.sample_rate
+                     batch_size = 1
+                 else:
+                     sample_rate = ref_data.get('sample_rate', 24000)
+                     batch_size = 1
+ 
+                 pred_audio = self.model.generate_samples(
+                     batch_size=batch_size,
+                     n_steps=self.model.zdm_n_steps,
+                     net_kwargs=net_kwargs,
+                     uncond_net_kwargs=uncond_net_kwargs,
+                     ema=ema
+                 )
+ 
+                 pred_signal = AudioSignal(pred_audio, sample_rate)
+ 
+                 # Save generated audio
+                 self.save_audio_sample(pred_signal, f'audio_zdm_generated_{i}')
+ 
+         dist.barrier()
+ 
+     def save_audio_sample(self, audio_signal, name):
+         """Save audio sample and log to Comet ML"""
+         try:
+             # Ensure audio is in correct format
+             audio_data = audio_signal.audio_data.cpu()
+ 
+             # Handle different dimensions
+             if audio_data.dim() == 3:  # [batch, channels, samples]
+                 audio_data = audio_data[0]  # Take first sample
+             if audio_data.dim() == 2:  # [channels, samples]
+                 audio_data = audio_data.transpose(0, 1)  # [samples, channels]
+             elif audio_data.dim() == 1:  # [samples]
+                 audio_data = audio_data.unsqueeze(1)  # [samples, 1]
+ 
+             audio_data = audio_data.numpy()
+ 
+             # Normalize if needed
+             if np.abs(audio_data).max() > 1.0:
+                 audio_data = audio_data / np.abs(audio_data).max()
+ 
+             # Save to file
+             save_path = os.path.join(self.env['save_dir'], 'audio_samples')
+             os.makedirs(save_path, exist_ok=True)
+ 
+             file_path = os.path.join(save_path, f'{name}_step_{self.iter}.wav')
+             sf.write(file_path, audio_data, int(audio_signal.sample_rate))
+ 
+             # Log to Comet ML
+             if self.experiment:
+                 self.experiment.log_audio(
+                     file_path,
+                     metadata={
+                         'name': name,
+                         'step': self.iter,
+                         'sample_rate': int(audio_signal.sample_rate),
+                         'duration': len(audio_data) / audio_signal.sample_rate,
+                         'channels': audio_data.shape[1] if audio_data.ndim > 1 else 1
+                     },
+                     step=self.iter
+                 )
+ 
+                 # Also log spectrograms for visualization
+                 if self.iter % self.config.get('spectrogram_log_freq', 1000) == 0:
+                     self._log_spectrogram(audio_signal, name)
+ 
+             self.log(f"Saved audio sample: {file_path}")
+ 
+         except Exception as e:
+             self.log(f"Error saving audio sample {name}: {e}")
+             if self.experiment:
+                 self.experiment.log_text(f"Error saving audio {name}: {str(e)}", step=self.iter)
+ 
+     def _log_spectrogram(self, audio_signal, name):
+         """Log spectrogram visualization to Comet ML"""
+         if not self.experiment:
+             return
+ 
+         try:
+             # Compute spectrogram
+             stft_transform = torchaudio.transforms.Spectrogram(
+                 n_fft=2048,
+                 hop_length=512,
+                 power=2
+             )
+ 
+             audio_data = audio_signal.audio_data
+             if audio_data.dim() == 3:
+                 audio_data = audio_data[0]
+             if audio_data.dim() == 2:
+                 audio_data = audio_data[0]  # Take first channel
+ 
+             spec = stft_transform(audio_data.cpu())
+             spec_db = 10 * torch.log10(spec + 1e-8)
+ 
+             # Create figure
+             fig, ax = plt.subplots(figsize=(10, 4))
+             im = ax.imshow(
+                 spec_db.numpy(),
+                 aspect='auto',
+                 origin='lower',
+                 cmap='viridis',
+                 extent=[0, len(audio_data) / audio_signal.sample_rate, 0, audio_signal.sample_rate / 2]
+             )
+             ax.set_xlabel('Time (s)')
+             ax.set_ylabel('Frequency (Hz)')
+             ax.set_title(f'{name} - Spectrogram')
+             plt.colorbar(im, ax=ax, label='dB')
+ 
+             # Log to Comet
+             self.experiment.log_figure(f"spectrogram/{name}", fig, step=self.iter)
+             plt.close(fig)
+ 
+         except Exception as e:
+             self.log(f"Error logging spectrogram for {name}: {e}")
+ 
+     def save_checkpoint(self, tag="latest"):
+         """Save checkpoint and log to Comet ML"""
+         # NOTE: BaseTrainer defines save_ckpt(), not save_checkpoint(), so this
+         # super() call assumes a matching method exists on the base class.
+         checkpoint_path = super().save_checkpoint(tag)
+ 
+         if self.experiment and checkpoint_path:
+             # Log checkpoint to Comet
+             self.experiment.log_model(
+                 f"checkpoint_{tag}",
+                 checkpoint_path,
+                 metadata={
+                     "step": self.iter,
+                     "tag": tag,
+                     "timestamp": datetime.now().isoformat()
+                 }
+             )
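The spectral-convergence metric in `evaluate_audio_ae` is the Frobenius distance between the two spectrograms, normalized by the reference spectrogram's norm. It can be exercised standalone; the tensor shape below is an assumption for illustration:

    import torch
    import torchaudio

    spec = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256, power=2)

    ref = torch.randn(1, 24000)  # 1 s of audio at 24 kHz (assumed shape)
    recon = ref + 0.05 * torch.randn_like(ref)

    s_ref, s_rec = spec(ref), spec(recon)
    sc = torch.norm(s_ref - s_rec, p='fro') / (torch.norm(s_ref, p='fro') + 1e-8)
    print(f'spectral convergence: {sc.item():.4f}')  # lower is better, 0 = identical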
flowae/trainers/base_trainer.py ADDED
@@ -0,0 +1,418 @@
+ import os
+ import time
+ import copy
+ from datetime import timedelta
+ 
+ import yaml
+ import torch
+ import torch.distributed as dist
+ from omegaconf import OmegaConf
+ from tqdm import tqdm
+ from torch.utils.data import IterableDataset, DataLoader
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel import DistributedDataParallel
+ 
+ import datasets
+ import models
+ import utils
+ from .trainers import register
+ from comet_ml import Experiment
+ 
+ from datetime import datetime
+ 
+ 
+ @register('base_trainer')
+ class BaseTrainer():
+ 
+     def __init__(self, env, config):
+         self.env = env
+         self.config = config
+         self.config_dict = OmegaConf.to_container(config, resolve=True)
+ 
+         if config.get('allow_tf32', False):
+             torch.backends.cuda.matmul.allow_tf32 = True
+             torch.backends.cudnn.allow_tf32 = True
+ 
+         dist.init_process_group(backend='nccl', timeout=timedelta(minutes=240))
+         self.rank = int(os.environ['RANK'])
+         self.local_rank = int(os.environ['LOCAL_RANK'])
+         self.world_size = int(os.environ['WORLD_SIZE'])
+         self.node_id = int(os.environ['GROUP_RANK'])
+         self.node_tot = self.world_size // int(os.environ['LOCAL_WORLD_SIZE'])
+         self.is_master = (self.rank == 0)
+ 
+         torch.cuda.set_device(self.local_rank)
+         self.device = torch.device('cuda', torch.cuda.current_device())
+ 
+         if self.is_master:
+             # Setup path
+             if env['resume']:
+                 replace = False
+                 force_replace = False
+             else:
+                 replace = True
+                 force_replace = env['force_replace']
+             utils.ensure_path(env['save_dir'], replace=replace, force_replace=force_replace)
+ 
+             # Save config
+             with open(os.path.join(env['save_dir'], 'config.yaml'), 'w') as f:
+                 yaml.dump(self.config_dict, f, sort_keys=False)
+ 
+             # Setup logging
+             logger = utils.set_logger(os.path.join(env['save_dir'], 'log.txt'))
+             self.log = logger.info
+ 
+         # Initialize Comet ML experiment
+         self.experiment = None
+         if self.is_master:  # Only log from master process
+             self.experiment = Experiment(
+                 project_name=self.config.get("comet_project", "audio-ldm"),
+                 workspace=os.environ.get("COMET_WORKSPACE"),
+                 experiment_name=self.config.get("exp_name", f"audio_ldm_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
+             )
+ 
+             # Log hyperparameters
+             self.experiment.log_parameters(self.config)
+ 
+             # Add tags
+             tags = self.config.get("tags", ["audio", "ldm", "diffusion"])
+             for tag in tags:
+                 self.experiment.add_tag(tag)
+         else:
+             self.log = lambda *args, **kwargs: None
+             self.experiment = None
+         dist.barrier()
+ 
+         self.log(f'Environment setup done. World size: {self.world_size}.')
+ 
+     def run(self, eval_only=False):
+         self.make_datasets()
+ 
+         resume_ckpt = os.path.join(self.env['save_dir'], 'ckpt-last.pth')
+         resume = (self.env['resume'] and os.path.isfile(resume_ckpt))
+         if resume:
+             self.resume_ckpt = torch.load(resume_ckpt, map_location='cpu')
+         else:
+             self.resume_ckpt = None
+ 
+         self.make_model()
+         if resume:
+             self.model.load_state_dict(self.resume_ckpt['model']['sd'])
+             self.resume_ckpt['model'] = None
+             self.log(f'Resumed model from checkpoint {resume_ckpt}.')
+ 
+         if eval_only:
+             self.model_ddp = self.model
+             with torch.no_grad():
+                 self.log_buffer = ['Eval']
+                 self.iter = 0
+                 self.evaluate()
+                 self.visualize()
+                 self.log(', '.join(self.log_buffer))
+ 
+         else:
+             self.model_ddp = DistributedDataParallel(
+                 self.model,
+                 device_ids=[self.local_rank],
+                 find_unused_parameters=self.config.get('find_unused_parameters', False)
+             )
+ 
+             self.make_optimizers()
+             if resume:
+                 for name, optimizer in self.resume_ckpt['optimizers'].items():
+                     self.optimizers[name].load_state_dict(optimizer['sd'])
+                 self.resume_ckpt['optimizers'] = None
+                 self.log('Resumed optimizers.')
+ 
+             self.run_training()
+ 
+         self.on_train_end()
+ 
+     def on_train_end(self):
+         """Called at the end of training"""
+         if self.experiment:
+             # Log final model
+             model_path = os.path.join(self.env['save_dir'], 'final_model.pt')
+             torch.save(self.model.state_dict(), model_path)
+             self.experiment.log_model("final_model", model_path)
+ 
+             # End the experiment
+             self.experiment.end()
+ 
+     def make_distributed_loader(self, dataset, batch_size, shuffle, drop_last, num_workers, pin_memory):
+         assert batch_size % self.world_size == 0
+         assert num_workers % self.world_size == 0
+         if isinstance(dataset, IterableDataset):
+             sampler = None
+         else:
+             sampler = DistributedSampler(dataset, shuffle=shuffle)
+         loader = DataLoader(
+             dataset,
+             batch_size=batch_size // self.world_size,
+             drop_last=drop_last,
+             sampler=sampler,
+             num_workers=num_workers // self.world_size,
+             pin_memory=pin_memory
+         )
+         return loader, sampler
+ 
+     def make_datasets(self):
+         self.datasets = dict()
+         self.loaders = dict()
+         self.loader_samplers = dict()
+ 
+         for split, spec in self.config.datasets.items():
+             loader_spec = spec.pop('loader')
+ 
+             dataset = datasets.make(spec)
+             self.datasets[split] = dataset
+             if isinstance(dataset, IterableDataset):
+                 self.log(f'Dataset {split}: IterableDataset')
+             else:
+                 self.log(f'Dataset {split}: len={len(dataset)}')
+ 
+             drop_last = loader_spec.get('drop_last', True)
+             shuffle = loader_spec.get('shuffle', True)
+             self.loaders[split], self.loader_samplers[split] = self.make_distributed_loader(
+                 dataset,
+                 loader_spec.batch_size,
+                 shuffle,
+                 drop_last,
+                 loader_spec.num_workers,
+                 loader_spec.get('pin_memory', True)
+             )
+ 
+     def make_model(self):
+         model = models.make(self.config.model)
+         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+         self.model = model.to(self.device)
+         self.log(f'Model: #params={utils.compute_num_params(model)}')
+ 
+     def make_optimizers(self):
+         self.optimizers = {'model': utils.make_optimizer(self.model.parameters(), self.config.optimizers['model'])}
+ 
+     def run_training(self):
+         config = self.config
+         max_iter = config['max_iter']
+         epoch_iter = config['epoch_iter']
+         assert max_iter % epoch_iter == 0
+         max_epoch = max_iter // epoch_iter
+ 
+         save_iter = config.get('save_iter')
+         if save_iter is not None:
+             assert save_iter % epoch_iter == 0
+             save_epoch = save_iter // epoch_iter
+             print('save_epoch', save_epoch)
+         else:
+             save_epoch = max_epoch + 1
+ 
+         eval_iter = config.get('eval_iter')
+         if eval_iter is not None:
+             assert eval_iter % epoch_iter == 0
+             eval_epoch = eval_iter // epoch_iter
+         else:
+             eval_epoch = max_epoch + 1
+ 
+         vis_iter = config.get('vis_iter')
+         if vis_iter is not None:
+             assert vis_iter % epoch_iter == 0
+             vis_epoch = vis_iter // epoch_iter
+         else:
+             vis_epoch = max_epoch + 1
+ 
+         if config.get('ckpt_select_metric') is not None:
+             m = config.ckpt_select_metric
+             self.ckpt_select_metric = m.name
+             self.ckpt_select_type = m.type
+             if m.type == 'min':
+                 self.ckpt_select_v = 1e18
+             elif m.type == 'max':
+                 self.ckpt_select_v = -1e18
+         else:
+             self.ckpt_select_metric = None
+             self.ckpt_select_v = 0
+ 
+         self.train_loader = self.loaders['train']
+         self.train_loader_sampler = self.loader_samplers['train']
+         self.train_loader_epoch = 0
+         self.train_loader_iter = None
+ 
+         self.iter = 0
+ 
+         if self.resume_ckpt is not None:
+             for _ in range(self.resume_ckpt['iter']):
+                 self.iter += 1
+                 self.at_train_iter_start()
+             self.ckpt_select_v = self.resume_ckpt['ckpt_select_v']
+             self.train_loader_epoch = self.resume_ckpt['train_loader_epoch']
+             self.train_loader_iter = None
+             self.resume_ckpt = None
+             self.log('Resumed iter status.')
+ 
+         if config.get('vis_before_training', False):
+             self.visualize()
+ 
+         start_epoch = self.iter // epoch_iter + 1
+         epoch_timer = utils.EpochTimer(max_epoch - start_epoch + 1)
+ 
+         for epoch in range(start_epoch, max_epoch + 1):
+             self.log_buffer = [f'Epoch {epoch}']
+ 
+             for sampler in self.loader_samplers.values():
+                 if sampler is not self.train_loader_sampler:
+                     sampler.set_epoch(epoch)
+ 
+             self.model_ddp.train()
+ 
+             ave_scalars = dict()
+             pbar = range(1, epoch_iter + 1)
+             if self.is_master and epoch == start_epoch:
+                 pbar = tqdm(pbar, desc='train', leave=False)
+ 
+             t_data = 0
+             t_nondata = 0
+             t_before_data = time.time()
+ 
+             for _ in pbar:
+                 self.iter += 1
+                 self.at_train_iter_start()
+ 
+                 try:
+                     if self.train_loader_iter is None:
+                         raise StopIteration
+                     data = next(self.train_loader_iter)
+                 except StopIteration:
+                     self.train_loader_epoch += 1
+                     self.train_loader_sampler.set_epoch(self.train_loader_epoch)
+                     self.train_loader_iter = iter(self.train_loader)
+                     data = next(self.train_loader_iter)
+ 
+                 t_after_data = time.time()
+                 t_data += t_after_data - t_before_data
+ 
+                 for k, v in data.items():
+                     data[k] = v.to(self.device) if torch.is_tensor(v) else v
+ 
+                 ret = self.train_step(data)
+ 
+                 t_before_data = time.time()
+                 t_nondata += t_before_data - t_after_data
+ 
+                 if self.is_master and epoch == start_epoch:
+                     pbar.set_description(desc=f'train: loss={ret["loss"]:.4f}')
+ 
+                 # Save a checkpoint every 100 iterations
+                 if self.iter % 100 == 0:
+                     self.save_ckpt(f'ckpt-{self.iter}.pth')
+ 
+             self.save_ckpt('ckpt-last.pth')
+ 
+             if epoch % save_epoch == 0 and epoch != max_epoch:
+                 self.save_ckpt(f'ckpt-{self.iter}.pth')
+ 
+             if epoch % eval_epoch == 0:
+                 with torch.no_grad():
+                     eval_ave_scalars = self.evaluate()
+                 if self.ckpt_select_metric is not None:
+                     v = eval_ave_scalars[self.ckpt_select_metric].item()
+                     if ((self.ckpt_select_type == 'min' and v < self.ckpt_select_v) or
+                             (self.ckpt_select_type == 'max' and v > self.ckpt_select_v)):
+                         self.ckpt_select_v = v
+                         self.save_ckpt('ckpt-best.pth')
+ 
+             if epoch % vis_epoch == 0:
+                 with torch.no_grad():
+                     self.visualize()
+ 
+     def at_train_iter_start(self):
+         pass
+ 
+     def train_step(self, data, bp=True):
+         # Debug: dump the raw batch
+         print('data', data)
+         if self.config.get('autocast_bfloat16', False):
+             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                 ret = self.model_ddp(data)
+         else:
+             ret = self.model_ddp(data)
+ 
+         loss = ret.pop('loss')
+         ret['loss'] = loss.item()
+         if bp:
+             self.model_ddp.zero_grad()
+             loss.backward()
+             for o in self.optimizers.values():
+                 o.step()
+         return ret
+ 
+     def evaluate(self):
+         self.model_ddp.eval()
+ 
+         ave_scalars = dict()
+         pbar = self.loaders['val']
+ 
+         for data in pbar:
+             for k, v in data.items():
+                 data[k] = v.to(self.device) if torch.is_tensor(v) else v
+ 
+             ret = self.train_step(data, bp=False)
+ 
+             bs = len(next(iter(data.values())))
+             for k, v in ret.items():
+                 if ave_scalars.get(k) is None:
+                     ave_scalars[k] = utils.Averager()
+                 ave_scalars[k].add(v, n=bs)
+ 
+         self.sync_ave_scalars(ave_scalars)
+ 
+         logtext = 'val:'
+         for k, v in ave_scalars.items():
+             logtext += f' {k}={v.item():.4f}'
+             self.log_scalar('val/' + k, v.item())
+         self.log_buffer.append(logtext)
+ 
+         return ave_scalars
+ 
+     def visualize(self):
+         pass
+ 
+     def save_ckpt(self, filename):
+         if self.is_master:
+             model_spec = copy.copy(self.config_dict['model'])
+             model_spec['sd'] = self.model.state_dict()
+             optimizers_spec = dict()
+             for name, spec in self.config_dict['optimizers'].items():
+                 spec = copy.copy(spec)
+                 spec['sd'] = self.optimizers[name].state_dict()
+                 optimizers_spec[name] = spec
+             ckpt = {
+                 'config': self.config_dict,
+                 'model': model_spec,
+                 'optimizers': optimizers_spec,
+                 'iter': self.iter,
+                 'train_loader_epoch': self.train_loader_epoch,
+                 'ckpt_select_v': self.ckpt_select_v,
+             }
+             torch.save(ckpt, os.path.join(self.env['save_dir'], filename))
+         dist.barrier()
+ 
+     def sync_ave_scalars(self, ave_scalars):
+         keys = sorted(list(ave_scalars.keys()))
+         for k in keys:
+             if not k.startswith('_'):
+                 v = ave_scalars[k]
+                 vt = torch.tensor(v.item(), device=self.device)
+                 dist.all_reduce(vt, op=dist.ReduceOp.SUM)
+                 torch.cuda.synchronize()
+                 ave_scalars[k].v = vt.item() / self.world_size
+                 ave_scalars[k].n *= self.world_size
+ 
+     def log_scalar(self, k, v):
+         if self.experiment:
+             self.experiment.log_metric(k, v, step=self.iter)
+ 
+     def log_image(self, k, v):
+         if self.experiment:
+             self.experiment.log_image(k, v, step=self.iter)
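`sync_ave_scalars` averages each scalar across ranks by all-reducing the per-rank value with SUM and dividing by the world size. The same pattern in isolation (requires an initialized process group, as in BaseTrainer.__init__):

    import torch
    import torch.distributed as dist

    def all_reduce_mean(value, device, world_size):
        # SUM-reduce a python scalar across ranks, then divide by world size.
        vt = torch.tensor(value, device=device)
        dist.all_reduce(vt, op=dist.ReduceOp.SUM)
        return vt.item() / world_size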
flowae/trainers/ldm_trainer.py ADDED
@@ -0,0 +1,443 @@
+ import os
+ import random
+ 
+ import torch
+ import torch.distributed as dist
+ import torch_fidelity
+ import torchvision
+ from PIL import Image
+ from torchvision import transforms
+ 
+ import utils
+ from utils.geometry import make_coord_scale_grid
+ from .trainers import register
+ from trainers.base_trainer import BaseTrainer
+ from models.ldm.dac.audiotools import AudioSignal
+ import soundfile as sf
+ import numpy as np
+ 
+ from models.ldm.dac.loss import (GANLoss, L1Loss, MelSpectrogramLoss,
+                                  MultiScaleSTFTLoss, kl_loss)
+ 
+ 
+ @register('ldm_trainer')
+ class LDMTrainer(BaseTrainer):
+ 
+     def make_model(self):
+         super().make_model()
+         self.has_optimizer = dict()
+         for name, m in self.model.named_children():
+             self.log(f'  .{name} {utils.compute_num_params(m)}')
+ 
+     def make_optimizers(self):
+         self.optimizers = dict()
+         self.has_optimizer = dict()
+         for name, spec in self.config.optimizers.items():
+             self.optimizers[name] = utils.make_optimizer(self.model.get_parameters(name), spec)
+             self.has_optimizer[name] = True
+ 
+     def train_step(self, data, bp=True):
+         kwargs = {'has_optimizer': self.has_optimizer}
+         # Debug: inspect the incoming batch
+         print('data', data.keys())
+         print('inp', data['inp'].shape)
+         print('gt', data['gt'].shape)
+ 
+         if self.config.get('autocast_bfloat16', False):
+             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                 ret = self.model_ddp(data, mode='loss', **kwargs)
+         else:
+             ret = self.model_ddp(data, mode='loss', **kwargs)
+ 
+         loss = ret.pop('loss')
+         ret['loss'] = loss.item()
+         if bp:
+             self.model_ddp.zero_grad()
+             loss.backward()
+             for name, o in self.optimizers.items():
+                 if name != 'disc':
+                     o.step()
+ 
+             self.model.update_ema()
+ 
+         return ret
+ 
+     def evaluate(self):
+         self.model_ddp.eval()
+ 
+         ave_scalars = dict()
+         pbar = self.loaders['val']
+ 
+         for data in pbar:
+             for k, v in data.items():
+                 data[k] = v.to(self.device) if torch.is_tensor(v) else v
+ 
+             ret = self.train_step(data, bp=False)
+ 
+             bs = len(next(iter(data.values())))
+             for k, v in ret.items():
+                 if ave_scalars.get(k) is None:
+                     ave_scalars[k] = utils.Averager()
+                 ave_scalars[k].add(v, n=bs)
+ 
+         self.sync_ave_scalars(ave_scalars)
+ 
+         # Extra evaluation #
+         if self.config.get('evaluate_ae', False):
+             ave_scalars.update(self.evaluate_ae())
+ 
+         if self.config.get('evaluate_zdm', False):
+             ema = self.config.get('evaluate_zdm_ema', True)
+             ave_scalars.update(self.evaluate_zdm(ema=ema))
+         # - #
+ 
+         logtext = 'val:'
+         for k, v in ave_scalars.items():
+             logtext += f' {k}={v.item():.4f}'
+             self.log_scalar('val/' + k, v.item())
+         self.log_buffer.append(logtext)
+ 
+         return ave_scalars
+ 
+     def visualize(self):
+         self.model_ddp.eval()
+ 
+         if self.config.get('evaluate_ae', False):
+             # self.visualize_ae_fixset()
+             self.visualize_ae_random()
+ 
+         if self.config.get('evaluate_zdm', False):
+             ema = self.config.get('evaluate_zdm_ema', True)
+             # self.visualize_zdm_fixset(ema=ema)
+             self.visualize_zdm_random(ema=ema)
+             # self.visualize_zdm_denoising(ema=ema)
+ 
+     def evaluate_ae(self):
+         max_samples = self.config.get('eval_ae_max_samples')
+         self.loader_samplers['eval_ae'].set_epoch(0)
+ 
+         to_pil = transforms.ToPILImage()
+         psnr_value = utils.Averager()
+         cnt = 0
+ 
+         cache_gen_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gen')
+         cache_gt_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gt')
+         if self.is_master:
+             utils.ensure_path(cache_gen_dir, force_replace=True)
+             utils.ensure_path(cache_gt_dir, force_replace=True)
+         dist.barrier()
+ 
+         for data in self.loaders['eval_ae']:
+             for k, v in data.items():
+                 data[k] = v.to(self.device) if torch.is_tensor(v) else v
+ 
+             pred = self.model(data, mode='pred')
+             gt_patch = data['gt'][:, :3, ...]
+ 
+             pred = (pred * 0.5 + 0.5).clamp(0, 1)
+             gt_patch = (gt_patch * 0.5 + 0.5).clamp(0, 1)
+ 
+             # PSNR
+             mse = (pred - gt_patch).pow(2).mean(dim=[1, 2, 3])
+             psnr_value.add((-10 * torch.log10(mse)).mean().item())
+ 
+             # FID
+             for i in range(len(pred)):
+                 idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
+                 if max_samples is None or idx < max_samples:
+                     to_pil(pred[i]).save(os.path.join(cache_gen_dir, f'{idx}.png'))
+                     to_pil(gt_patch[i]).save(os.path.join(cache_gt_dir, f'{idx}.png'))
+                 cnt += 1
+         dist.barrier()
+ 
+         vt = torch.tensor(psnr_value.item(), device=self.device)
+         dist.all_reduce(vt, op=dist.ReduceOp.SUM)
+         torch.cuda.synchronize()
+         psnr_value = vt.item() / int(os.environ['WORLD_SIZE'])
+ 
+         if self.is_master:
+             metrics = torch_fidelity.calculate_metrics(
+                 input1=cache_gen_dir,
+                 input2=cache_gt_dir,
+                 cuda=True,
+                 fid=True,
+                 verbose=False,
+             )
+             prefix = 'eval_ae'
+             ret = {
+                 f'{prefix}/PSNR': psnr_value,
+                 f'{prefix}/FID': metrics['frechet_inception_distance'],
+             }
+         else:
+             ret = {}
+         dist.barrier()
+ 
+         ret = {k: utils.Averager(v) for k, v in ret.items()}
+         return ret
+ 
+     def evaluate_zdm(self, ema):
+         max_samples = self.config.get('eval_zdm_max_samples')
+         self.loader_samplers['eval_zdm'].set_epoch(0)
+ 
+         to_pil = transforms.ToPILImage()
+         cnt = 0
+ 
+         cache_gen_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gen')
+         cache_gt_dir = os.path.join(self.env['save_dir'], 'cache', 'fid_gt')
+         if self.is_master:
+             utils.ensure_path(cache_gen_dir, force_replace=True)
+             utils.ensure_path(cache_gt_dir, force_replace=True)
+         dist.barrier()
+ 
+         for data in self.loaders['eval_zdm']:
+             for k, v in data.items():
+                 data[k] = v.to(self.device) if torch.is_tensor(v) else v
+ 
+             gt_patch = data['inp']
+ 
+             net_kwargs = dict()
+             uncond_net_kwargs = dict()
+             if self.model.zdm_class_cond is not None:
+                 net_kwargs['class_labels'] = data['class_labels']
+ 
+                 setting = self.config['visualize_zdm_setting']
+                 uncond_net_kwargs['class_labels'] = setting['n_classes'] * torch.ones(
+                     len(data['class_labels']), dtype=torch.long, device=self.device)
+ 
+             pred = self.model.generate_samples(
+                 batch_size=gt_patch.shape[0],
+                 n_steps=self.model.zdm_n_steps,
+                 net_kwargs=net_kwargs,
+                 uncond_net_kwargs=uncond_net_kwargs,
+                 ema=ema
+             )
+ 
+             pred = (pred * 0.5 + 0.5).clamp(0, 1)
+             gt_patch = (gt_patch * 0.5 + 0.5).clamp(0, 1)
+ 
+             # FID
+             for i in range(len(pred)):
+                 idx = int(os.environ['RANK']) + cnt * int(os.environ['WORLD_SIZE'])
+                 if max_samples is None or idx < max_samples:
+                     to_pil(pred[i]).save(os.path.join(cache_gen_dir, f'{idx}.png'))
+                     to_pil(gt_patch[i]).save(os.path.join(cache_gt_dir, f'{idx}.png'))
+                 cnt += 1
+         dist.barrier()
+ 
+         if self.is_master:
+             metrics = torch_fidelity.calculate_metrics(
+                 input1=cache_gen_dir,
+                 input2=cache_gt_dir,
+                 cuda=True,
+                 fid=True,
+                 verbose=False,
+             )
+             prefix = 'eval_zdm' + ('_ema' if ema else '')
+             ret = {
+                 f'{prefix}/FID': metrics['frechet_inception_distance'],
+             }
+         else:
+             ret = {}
+         dist.barrier()
+ 
+         ret = {k: utils.Averager(v) for k, v in ret.items()}
+         return ret
+ 
+     def visualize_ae_fixset(self):
+         if self.config.get('visualize_ae_dir') is None:
+             return
+         to_tensor = transforms.ToTensor()
+         if self.is_master:
+             files = sorted(os.listdir(self.config['visualize_ae_dir']))
+             vis_images = []
+ 
+             for f in files:
+                 image = Image.open(os.path.join(self.config['visualize_ae_dir'], f)).convert('RGB')
+                 x = to_tensor(image).unsqueeze(0).to(self.device)
+                 x = (x - 0.5) / 0.5
+                 gt_dummy = torch.zeros(x.shape[0], 7, x.shape[2], x.shape[3], device=self.device)
+ 
+                 pred1 = self.model({'inp': x, 'gt': gt_dummy}, mode='pred')
+                 pred2 = self.model({'inp': x, 'gt': gt_dummy}, mode='pred')
+                 vis_images.extend([x, pred1, pred2])
+ 
+             vis_images = torch.cat(vis_images, dim=0)
+             vis_images = torchvision.utils.make_grid(vis_images, normalize=True, value_range=(-1, 1), nrow=6)
+             self.log_image('vis_ae_fixset', vis_images)
+         dist.barrier()
+ 
+     def visualize_ae_random(self):
+         if self.is_master:
+             idx_list = list(range(len(self.datasets['eval_ae'])))
+             random.shuffle(idx_list)
+             n_samples = self.config['visualize_ae_random_n_samples']
+             vis_images = []
+ 
+             for idx in idx_list[:n_samples]:
+                 data = self.datasets['eval_ae'][idx]
+                 for k, v in data.items():
+                     data[k] = v.unsqueeze(0).to(self.device) if torch.is_tensor(v) else v
+ 
+                 pred1 = self.model(data, mode='pred')
+                 pred2 = self.model(data, mode='pred')
+                 gt_patch = data['gt'][:, :3, ...]
+                 vis_images.extend([gt_patch, pred1, pred2])
+ 
+             vis_images = torch.cat(vis_images, dim=0)
+             vis_images = torchvision.utils.make_grid(vis_images, normalize=True, value_range=(-1, 1), nrow=6)
+             self.log_image('vis_ae_random', vis_images)
+         dist.barrier()
+ 
+     def visualize_zdm_fixset(self, ema):
+         if self.is_master:
+             vis_file = torch.load(self.config['visualize_zdm_file'], map_location='cpu')
+             for k, v in vis_file.items():
+                 vis_file[k] = v.to(self.device) if torch.is_tensor(v) else v
+             n_samples = len(vis_file['noise'])
+ 
+             batch_size = self.config.get('visualize_zdm_batch_size', 1)
+             guidance_list = [1.0] + self.config.get('visualize_zdm_guidance_list', [])
+ 
+             vis_images = []
+ 
+             for i in range(0, n_samples, batch_size):
+                 cur_batch_size = min(batch_size, n_samples - i)
+ 
+                 net_kwargs = dict()
+                 uncond_net_kwargs = dict()
+                 if self.config.get('visualize_zdm_setting') is not None:
+                     setting = self.config['visualize_zdm_setting']
+                     if setting['name'] == 'class':
+                         net_kwargs['class_labels'] = vis_file['class_labels'][i:i + cur_batch_size]
+                         uncond_net_kwargs['class_labels'] = setting['n_classes'] * torch.ones(
+                             cur_batch_size, dtype=torch.long, device=self.device)
+                     else:
+                         raise NotImplementedError
+ 
+                 for guidance in guidance_list:
+                     pred = self.model.generate_samples(
+                         batch_size=cur_batch_size,
+                         n_steps=self.model.zdm_n_steps,
+                         net_kwargs=net_kwargs,
+                         uncond_net_kwargs=uncond_net_kwargs,
+                         ema=ema,
+                         guidance=guidance,
+                         noise=vis_file['noise'][i:i + cur_batch_size],
+                     )
+                     vis_images.append(pred)
+ 
+             vis_images = torch.cat(vis_images, dim=0)
+             vis_images = torchvision.utils.make_grid(vis_images, normalize=True, value_range=(-1, 1), nrow=batch_size)
+             name = 'vis_zdm_fixset'
+             name += '_ema' if ema else ''
+             name += '_cfg' + str(guidance_list[1:])[1:-1] if len(guidance_list) > 1 else ''
+             self.log_image(name, vis_images)
+         dist.barrier()
+ 
+     def visualize_zdm_random(self, ema):
+         n_samples = self.config['visualize_zdm_random_n_samples']
+         batch_size = self.config.get('visualize_zdm_batch_size', 1)
+         guidance_list = [1.0] + self.config.get('visualize_zdm_guidance_list', [])
+ 
+         vis_images = []
+ 
+         if self.is_master:
+             for i in range(0, n_samples, batch_size):
+                 cur_batch_size = min(batch_size, n_samples - i)
+ 
+                 net_kwargs = dict()
+                 uncond_net_kwargs = dict()
+                 if self.config.get('visualize_zdm_setting') is not None:
+                     setting = self.config['visualize_zdm_setting']
+                     if setting['name'] == 'class':
+                         net_kwargs['class_labels'] = torch.randint(
+                             setting['n_classes'], size=(cur_batch_size,), device=self.device)
+                         uncond_net_kwargs['class_labels'] = setting['n_classes'] * torch.ones(
+                             cur_batch_size, dtype=torch.long, device=self.device)
+                     else:
+                         raise NotImplementedError
+ 
+                 for guidance in guidance_list:
+                     pred = self.model.generate_samples(
+                         batch_size=cur_batch_size,
+                         n_steps=self.model.zdm_n_steps,
+                         net_kwargs=net_kwargs,
+                         uncond_net_kwargs=uncond_net_kwargs,
+                         ema=ema,
+                         guidance=guidance,
+                     )
+                     vis_images.append(pred)
+ 
+             vis_images = torch.cat(vis_images, dim=0)
+             vis_images = torchvision.utils.make_grid(vis_images, normalize=True, value_range=(-1, 1), nrow=batch_size)
+             name = 'vis_zdm_random'
+             name += '_ema' if ema else ''
+             name += '_cfg' + str(guidance_list[1:])[1:-1] if len(guidance_list) > 1 else ''
+             self.log_image(name, vis_images)
+         dist.barrier()
+ 
+     def visualize_zdm_denoising(self, ema, n_selected_timesteps=5):
+         if self.is_master:
+             vis_file = torch.load(self.config['visualize_zdm_denoising_file'], map_location='cpu')
+ 
+             vis_images = []
+ 
+             for i in range(len(vis_file['inp'])):
+                 x = (
+                     vis_file['inp'][i]
+                     .to(self.device)
+                     .unsqueeze(0)
+                     .expand(n_selected_timesteps, -1, -1, -1)
+                 )
+ 
+                 z = self.model.encode(x)
+                 z = self.model.normalize_for_zdm(z)
+                 t = torch.linspace(0, 1, n_selected_timesteps + 1, device=self.device)[1:]
+                 noise = (
+                     vis_file['noise'][i]
+                     .to(self.device)
+                     .unsqueeze(0)
+                     .expand(n_selected_timesteps, -1, -1, -1)
+                 )
+                 z_t, _ = self.model.zdm_diffusion.add_noise(z, t, noise=noise)
+ 
+                 # Visualize noisy latents
+                 zp = self.model.denormalize_for_zdm(z_t)
+                 z_dec = self.model.decode(zp)
+                 coord, scale = make_coord_scale_grid(x.shape[-2:], device=self.device, batch_size=n_selected_timesteps)
+                 coord = coord.permute(0, 3, 1, 2)
+                 scale = scale.permute(0, 3, 1, 2)
+                 x_out = self.model.render(z_dec, coord, scale)
+                 vis_images.append(x_out)
+ 
+                 # Generate denoised latents
+                 net = self.model.zdm_net_ema if ema else self.model.zdm_net
+                 net_kwargs = dict()
+                 if self.config.get('visualize_zdm_setting') is not None:
+                     setting = self.config['visualize_zdm_setting']
+                     if setting['name'] == 'class':
+                         net_kwargs['class_labels'] = (
+                             vis_file['class_labels'][i]
+                             .to(self.device)
+                             .unsqueeze(0)
+                             .expand(n_selected_timesteps)
+                         )
+                     else:
+                         raise NotImplementedError
+                 pred = self.model.zdm_diffusion.get_prediction(net, z_t, t, net_kwargs=net_kwargs)
+                 zp = []
+                 for j in range(len(pred)):
+                     zp.append(self.model.zdm_diffusion.convert_sample_prediction(z_t[j], float(t[j]), pred[j]))
+                 zp = torch.stack(zp, dim=0)
+ 
+                 # Visualize denoised latents
+                 zp = self.model.denormalize_for_zdm(zp)
+                 z_dec = self.model.decode(zp)
+                 coord, scale = make_coord_scale_grid(x.shape[-2:], device=self.device, batch_size=n_selected_timesteps)
+                 coord = coord.permute(0, 3, 1, 2)
+                 scale = scale.permute(0, 3, 1, 2)
+                 x_out = self.model.render(z_dec, coord, scale)
+                 vis_images.append(x_out)
+ 
+             vis_images = torch.cat(vis_images, dim=0)
+             vis_images = torchvision.utils.make_grid(vis_images, normalize=True, value_range=(-1, 1), nrow=n_selected_timesteps)
+             self.log_image('vis_zdm' + ('_ema' if ema else '') + '_denoising', vis_images)
+         dist.barrier()
flowae/trainers/trainers.py ADDED
@@ -0,0 +1,8 @@
+ trainers_dict = dict()
+ 
+ 
+ def register(name):
+     def decorator(cls):
+         trainers_dict[name] = cls
+         return cls
+     return decorator
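This registry is what the `@register('audio_ldm_trainer')` / `@register('ldm_trainer')` decorators in the files above hook into: the decorator stores each class in `trainers_dict` under its config name. A minimal usage sketch (the `make` factory here is hypothetical, shown only to illustrate the lookup):

    from trainers.trainers import register, trainers_dict

    @register('dummy_trainer')
    class DummyTrainer:
        def __init__(self, env, config):
            self.env, self.config = env, config

    def make(name, env, config):
        # Hypothetical factory: look the trainer class up by its registered name.
        return trainers_dict[name](env, config)

    trainer = make('dummy_trainer', env={}, config={})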