AlienKevin commited on
Commit
2b35aa7
·
verified ·
1 Parent(s): 63b50f6

Upload checkpoints

Browse files
config.yaml ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ optimizer:
3
+ optim:
4
+ _target_: utils.optimizers.Lamb
5
+ lr: 0.002
6
+ betas:
7
+ - 0.9
8
+ - 0.999
9
+ weight_decay: 0.01
10
+ exclude_ln_and_biases_from_weight_decay: true
11
+ lr_scheduler:
12
+ _partial_: true
13
+ _target_: utils.lr_scheduler.WarmupCosineDecayLR
14
+ warmup_steps: 10000
15
+ total_steps: ${trainer.max_steps}
16
+ rate: 0.7
17
+ network:
18
+ _target_: cad.models.networks.rin.RINClassCond
19
+ data_size: ${data.data_resolution}
20
+ data_dim: 512
21
+ num_input_channels: 3
22
+ num_latents: 128
23
+ latents_dim: 768
24
+ label_dim: ${data.label_dim}
25
+ num_cond_tokens: ${data.num_cond_tokens}
26
+ num_processing_layers: 4
27
+ num_blocks: 4
28
+ path_size: 8
29
+ read_write_heads: 16
30
+ compute_heads: 16
31
+ latent_mlp_multiplier: 4
32
+ data_mlp_multiplier: 4
33
+ rw_dropout: 0.0
34
+ compute_dropout: 0
35
+ rw_stochastic_depth: 0
36
+ compute_stochastic_depth: 0
37
+ time_scaling: 1000.0
38
+ noise_embedding_type: positional
39
+ data_positional_embedding_type: learned
40
+ weight_init: xavier_uniform
41
+ bias_init: zeros
42
+ use_cond_token: true
43
+ use_biases: true
44
+ concat_cond_token_to_latents: true
45
+ use_cond_rin_block: false
46
+ use_16_bits_layer_norm: false
47
+ train_noise_scheduler:
48
+ _target_: cad.models.schedulers.LinearScheduler
49
+ start: 1
50
+ end: 0
51
+ clip_min: 1.0e-09
52
+ inference_noise_scheduler:
53
+ _target_: cad.models.schedulers.CosineSchedulerSimple
54
+ ns: 0.0002
55
+ ds: 0.00025
56
+ preconditioning:
57
+ _target_: cad.models.preconditioning.DDPMPrecond
58
+ num_latents: ${model.network.num_latents}
59
+ latents_dim: ${model.network.latents_dim}
60
+ data_preprocessing:
61
+ _target_: cad.models.preprocessing.PrecomputedPreconditioning
62
+ input_key: image
63
+ output_key_root: x_0
64
+ cond_preprocessing:
65
+ _target_: cad.models.preprocessing.PrecomputedPreconditioning
66
+ input_key: label
67
+ output_key_root: label
68
+ drop_labels: false
69
+ postprocessing:
70
+ _partial_: true
71
+ _target_: utils.image_processing.remap_image_torch
72
+ loss:
73
+ _partial_: true
74
+ _target_: cad.models.losses.DDPMLoss
75
+ self_cond_rate: 0.9
76
+ cond_drop_rate: 0.0
77
+ conditioning_key: ${model.cond_preprocessing.output_key_root}
78
+ resample_by_coherence: false
79
+ sample_random_when_drop: false
80
+ val_sampler:
81
+ _partial_: true
82
+ _target_: cad.models.samplers.ddim.ddim_sampler
83
+ num_steps: 250
84
+ cfg_rate: ${model.cfg_rate}
85
+ test_sampler:
86
+ _partial_: true
87
+ _target_: cad.models.samplers.ddpm.ddpm_sampler
88
+ num_steps: 1000
89
+ cfg_rate: ${model.cfg_rate}
90
+ uncond_conditioning:
91
+ _target_: cad.utils.misc.dummy_value_loader
92
+ value: 0.0
93
+ vae_embedding_name_mean: null
94
+ return_image: true
95
+ name: RIN
96
+ ema_decay: 0.9999
97
+ start_ema_step: 0
98
+ cfg_rate: 0.0
99
+ channel_wise_normalisation: false
100
+ computer:
101
+ devices: 8
102
+ num_workers: 64
103
+ progress_bar_refresh_rate: 2
104
+ sync_batchnorm: true
105
+ accelerator: gpu
106
+ precision: bf16-mixed
107
+ strategy: ddp
108
+ num_nodes: 1
109
+ eval_gpu_type: h200
110
+ data:
111
+ train_aug:
112
+ _target_: torchvision.transforms.Compose
113
+ transforms:
114
+ - _target_: torchvision.transforms.ToTensor
115
+ - _target_: utils.image_processing.CenterCrop
116
+ ratio: '1:1'
117
+ - _target_: torchvision.transforms.Resize
118
+ size: ${data.img_resolution}
119
+ interpolation: 3
120
+ antialias: true
121
+ - _target_: torchvision.transforms.RandomHorizontalFlip
122
+ p: 0.5
123
+ - _target_: torchvision.transforms.Normalize
124
+ mean: 0.5
125
+ std: 0.5
126
+ val_aug:
127
+ _target_: torchvision.transforms.Compose
128
+ transforms:
129
+ - _target_: torchvision.transforms.ToTensor
130
+ - _target_: utils.image_processing.CenterCrop
131
+ ratio: '1:1'
132
+ - _target_: torchvision.transforms.Resize
133
+ size: ${data.img_resolution}
134
+ interpolation: 3
135
+ antialias: true
136
+ - _target_: torchvision.transforms.Normalize
137
+ mean: 0.5
138
+ std: 0.5
139
+ name: ImageNet_64
140
+ type: class_conditional
141
+ img_resolution: 64
142
+ data_resolution: 64
143
+ label_dim: 1000
144
+ num_cond_tokens: 1
145
+ full_batch_size: 1024
146
+ in_channels: 3
147
+ out_channels: 3
148
+ train_instance:
149
+ _partial_: true
150
+ _target_: cad.data.dataset.HFImageNet64
151
+ split: train
152
+ transform: ${data.train_aug}
153
+ target_transform: ${data.target_transform}
154
+ val_instance:
155
+ _partial_: true
156
+ _target_: cad.data.dataset.HFImageNet64
157
+ split: validation
158
+ transform: ${data.val_aug}
159
+ target_transform: ${data.target_transform}
160
+ target_transform:
161
+ _target_: utils.one_hot_transform.OneHotTransform
162
+ num_classes: ${data.label_dim}
163
+ collate_fn:
164
+ _target_: data.datamodule.collate_to_dict
165
+ keys:
166
+ - image
167
+ - label
168
+ train_dataset: ${data.train_instance}
169
+ val_dataset: ${data.val_instance}
170
+ datamodule:
171
+ _target_: data.datamodule.ImageDataModule
172
+ train_dataset: ${data.train_dataset}
173
+ val_dataset: ${data.val_dataset}
174
+ full_batch_size: ${data.full_batch_size}
175
+ collate_fn: ${data.collate_fn}
176
+ num_workers: ${computer.num_workers}
177
+ num_nodes: ${computer.num_nodes}
178
+ num_devices: ${computer.devices}
179
+ trainer:
180
+ _target_: pytorch_lightning.Trainer
181
+ max_steps: 150000
182
+ val_check_interval: 5000
183
+ check_val_every_n_epoch: null
184
+ devices: ${computer.devices}
185
+ accelerator: ${computer.accelerator}
186
+ strategy: ${computer.strategy}
187
+ log_every_n_steps: 1
188
+ num_nodes: ${computer.num_nodes}
189
+ precision: ${computer.precision}
190
+ logger:
191
+ _target_: pytorch_lightning.loggers.WandbLogger
192
+ save_dir: ${root_dir}/cad/wandb
193
+ name: ${experiment_name}
194
+ project: RIN
195
+ log_model: false
196
+ offline: false
197
+ checkpoints:
198
+ _target_: callbacks.checkpoint_and_validate.ModelCheckpointValidate
199
+ gpu_type: ${computer.eval_gpu_type}
200
+ validate_when_not_on_cluster: false
201
+ validate_when_on_cluster: false
202
+ eval_set: val
203
+ validate_conditional: true
204
+ validate_unconditional: false
205
+ validate_per_class_metrics: true
206
+ shape:
207
+ - ${model.network.num_input_channels}
208
+ - ${data.data_resolution}
209
+ - ${data.data_resolution}
210
+ num_classes: ${data.label_dim}
211
+ dataset_name: ${data.name}
212
+ dirpath: ${root_dir}/cad/checkpoints/${experiment_name}
213
+ filename: step_{step}
214
+ monitor: val/loss_ema
215
+ save_last: true
216
+ save_top_k: -1
217
+ enable_version_counter: false
218
+ every_n_train_steps: 10000
219
+ auto_insert_metric_name: false
220
+ progress_bar:
221
+ _target_: pytorch_lightning.callbacks.TQDMProgressBar
222
+ refresh_rate: ${computer.progress_bar_refresh_rate}
223
+ data_dir: ${root_dir}/cad/datasets
224
+ root_dir: ${hydra:runtime.cwd}
225
+ experiment_name_suffix: hf_h200
226
+ experiment_name: ${data.name}_${model.name}_${experiment_name_suffix}
last.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df4939cc08ede413910bcf869efacc85d245ea50d5fa144a622d4b764534ec89
3
+ size 2524974006
step_10000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff198235baac1677879c485e696d6cb7e807cd68fc23dbb2b8dc34be36cba278
3
+ size 2524968580
step_100000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b760e61e1fbacf3492bedf3b4393c6d0e32dacc7c071f3463b1bc592c0d3cd67
3
+ size 2524972347
step_110000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c7a2147c0026b36f69a0e2897f0897300cfbdead12a3b1839e58ea160590a9d
3
+ size 2524972666
step_120000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81092bc67a2b288e40adb97434b8372c3d7e1dc85af87365c6713f81bc64b0f0
3
+ size 2524972985
step_130000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c2641880e1c123b7d3f46903ee61daf6c57d3220510533ddc7b772b28fc1860
3
+ size 2524973368
step_140000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:047a14b5a54ea861fa7fb02cd0e68c92316f64d52091e61231663e32ff9df0fd
3
+ size 2524973687
step_150000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df4939cc08ede413910bcf869efacc85d245ea50d5fa144a622d4b764534ec89
3
+ size 2524974006
step_20000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0f92bfeb6c2a3c4c64b0234565437afdf07267a24acec87f6b578df2552c974
3
+ size 2524968899
step_30000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d046919e2642306400fcd2a341ae0534d555c5ef0c785750f8d9ac74a2feec22
3
+ size 2524969282
step_40000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:663d4d8426ee3737c2567dd3b8813cd00e909579745738fc0e859051ac900cb7
3
+ size 2524969601
step_50000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d8746e816eb6252ca2ad4644df7984f18f21f0602ee8e6c724627e0106ad842
3
+ size 2524969920
step_60000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:610ca9c76bd769c3d911d94fdd66e67852e7267ddfb4b9b6a19712108449078f
3
+ size 2524970303
step_70000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36053f6fcbe99a7601e7f1025b314316169f398ed5d042ea8ddd910efdb41c7e
3
+ size 2524971326
step_80000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9310b5249f3416d790db5ed5062ae888828caabd52987110734a8fab868b0fcd
3
+ size 2524971645
step_90000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44b17040577d2118a4ecd7cd1c2958de9c1adcd8398540f662ca7070441f6bfe
3
+ size 2524971964