| model: | |
| optimizer: | |
| optim: | |
| _target_: utils.optimizers.Lamb | |
| lr: 0.002 | |
| betas: | |
| - 0.9 | |
| - 0.999 | |
| weight_decay: 0.01 | |
| exclude_ln_and_biases_from_weight_decay: true | |
| lr_scheduler: | |
| _partial_: true | |
| _target_: utils.lr_scheduler.WarmupCosineDecayLR | |
| warmup_steps: 10000 | |
| total_steps: ${trainer.max_steps} | |
| rate: 0.7 | |
| network: | |
| _target_: cad.models.networks.rin.RINClassCond | |
| data_size: ${data.data_resolution} | |
| data_dim: 512 | |
| num_input_channels: 3 | |
| num_latents: 128 | |
| latents_dim: 768 | |
| label_dim: ${data.label_dim} | |
| num_cond_tokens: ${data.num_cond_tokens} | |
| num_processing_layers: 4 | |
| num_blocks: 4 | |
| path_size: 8 | |
| read_write_heads: 16 | |
| compute_heads: 16 | |
| latent_mlp_multiplier: 4 | |
| data_mlp_multiplier: 4 | |
| rw_dropout: 0.0 | |
| compute_dropout: 0 | |
| rw_stochastic_depth: 0 | |
| compute_stochastic_depth: 0 | |
| time_scaling: 1000.0 | |
| noise_embedding_type: positional | |
| data_positional_embedding_type: learned | |
| weight_init: xavier_uniform | |
| bias_init: zeros | |
| use_cond_token: true | |
| use_biases: true | |
| concat_cond_token_to_latents: true | |
| use_cond_rin_block: false | |
| use_16_bits_layer_norm: false | |
| train_noise_scheduler: | |
| _target_: cad.models.schedulers.LinearScheduler | |
| start: 1 | |
| end: 0 | |
| clip_min: 1.0e-09 | |
| inference_noise_scheduler: | |
| _target_: cad.models.schedulers.CosineSchedulerSimple | |
| ns: 0.0002 | |
| ds: 0.00025 | |
| preconditioning: | |
| _target_: cad.models.preconditioning.DDPMPrecond | |
| num_latents: ${model.network.num_latents} | |
| latents_dim: ${model.network.latents_dim} | |
| data_preprocessing: | |
| _target_: cad.models.preprocessing.PrecomputedPreconditioning | |
| input_key: image | |
| output_key_root: x_0 | |
| cond_preprocessing: | |
| _target_: cad.models.preprocessing.PrecomputedPreconditioning | |
| input_key: label | |
| output_key_root: label | |
| drop_labels: false | |
| postprocessing: | |
| _partial_: true | |
| _target_: utils.image_processing.remap_image_torch | |
| loss: | |
| _partial_: true | |
| _target_: cad.models.losses.DDPMLoss | |
| self_cond_rate: 0.9 | |
| cond_drop_rate: 0.0 | |
| conditioning_key: ${model.cond_preprocessing.output_key_root} | |
| resample_by_coherence: false | |
| sample_random_when_drop: false | |
| val_sampler: | |
| _partial_: true | |
| _target_: cad.models.samplers.ddim.ddim_sampler | |
| num_steps: 250 | |
| cfg_rate: ${model.cfg_rate} | |
| test_sampler: | |
| _partial_: true | |
| _target_: cad.models.samplers.ddpm.ddpm_sampler | |
| num_steps: 1000 | |
| cfg_rate: ${model.cfg_rate} | |
| uncond_conditioning: | |
| _target_: cad.utils.misc.dummy_value_loader | |
| value: 0.0 | |
| vae_embedding_name_mean: null | |
| return_image: true | |
| name: RIN | |
| ema_decay: 0.9999 | |
| start_ema_step: 0 | |
| cfg_rate: 0.0 | |
| channel_wise_normalisation: false | |
| computer: | |
| devices: 8 | |
| num_workers: 64 | |
| progress_bar_refresh_rate: 2 | |
| sync_batchnorm: true | |
| accelerator: gpu | |
| precision: bf16-mixed | |
| strategy: ddp | |
| num_nodes: 1 | |
| eval_gpu_type: h200 | |
| data: | |
| train_aug: | |
| _target_: torchvision.transforms.Compose | |
| transforms: | |
| - _target_: torchvision.transforms.ToTensor | |
| - _target_: utils.image_processing.CenterCrop | |
| ratio: '1:1' | |
| - _target_: torchvision.transforms.Resize | |
| size: ${data.img_resolution} | |
| interpolation: 3 | |
| antialias: true | |
| - _target_: torchvision.transforms.RandomHorizontalFlip | |
| p: 0.5 | |
| - _target_: torchvision.transforms.Normalize | |
| mean: 0.5 | |
| std: 0.5 | |
| val_aug: | |
| _target_: torchvision.transforms.Compose | |
| transforms: | |
| - _target_: torchvision.transforms.ToTensor | |
| - _target_: utils.image_processing.CenterCrop | |
| ratio: '1:1' | |
| - _target_: torchvision.transforms.Resize | |
| size: ${data.img_resolution} | |
| interpolation: 3 | |
| antialias: true | |
| - _target_: torchvision.transforms.Normalize | |
| mean: 0.5 | |
| std: 0.5 | |
| name: ImageNet_64 | |
| type: class_conditional | |
| img_resolution: 64 | |
| data_resolution: 64 | |
| label_dim: 1000 | |
| num_cond_tokens: 1 | |
| full_batch_size: 1024 | |
| in_channels: 3 | |
| out_channels: 3 | |
| train_instance: | |
| _partial_: true | |
| _target_: cad.data.dataset.HFImageNet64 | |
| split: train | |
| transform: ${data.train_aug} | |
| target_transform: ${data.target_transform} | |
| val_instance: | |
| _partial_: true | |
| _target_: cad.data.dataset.HFImageNet64 | |
| split: validation | |
| transform: ${data.val_aug} | |
| target_transform: ${data.target_transform} | |
| target_transform: | |
| _target_: utils.one_hot_transform.OneHotTransform | |
| num_classes: ${data.label_dim} | |
| collate_fn: | |
| _target_: data.datamodule.collate_to_dict | |
| keys: | |
| - image | |
| - label | |
| train_dataset: ${data.train_instance} | |
| val_dataset: ${data.val_instance} | |
| datamodule: | |
| _target_: data.datamodule.ImageDataModule | |
| train_dataset: ${data.train_dataset} | |
| val_dataset: ${data.val_dataset} | |
| full_batch_size: ${data.full_batch_size} | |
| collate_fn: ${data.collate_fn} | |
| num_workers: ${computer.num_workers} | |
| num_nodes: ${computer.num_nodes} | |
| num_devices: ${computer.devices} | |
| trainer: | |
| _target_: pytorch_lightning.Trainer | |
| max_steps: 150000 | |
| val_check_interval: 5000 | |
| check_val_every_n_epoch: null | |
| devices: ${computer.devices} | |
| accelerator: ${computer.accelerator} | |
| strategy: ${computer.strategy} | |
| log_every_n_steps: 1 | |
| num_nodes: ${computer.num_nodes} | |
| precision: ${computer.precision} | |
| logger: | |
| _target_: pytorch_lightning.loggers.WandbLogger | |
| save_dir: ${root_dir}/cad/wandb | |
| name: ${experiment_name} | |
| project: RIN | |
| log_model: false | |
| offline: false | |
| checkpoints: | |
| _target_: callbacks.checkpoint_and_validate.ModelCheckpointValidate | |
| gpu_type: ${computer.eval_gpu_type} | |
| validate_when_not_on_cluster: false | |
| validate_when_on_cluster: false | |
| eval_set: val | |
| validate_conditional: true | |
| validate_unconditional: false | |
| validate_per_class_metrics: true | |
| shape: | |
| - ${model.network.num_input_channels} | |
| - ${data.data_resolution} | |
| - ${data.data_resolution} | |
| num_classes: ${data.label_dim} | |
| dataset_name: ${data.name} | |
| dirpath: ${root_dir}/cad/checkpoints/${experiment_name} | |
| filename: step_{step} | |
| monitor: val/loss_ema | |
| save_last: true | |
| save_top_k: -1 | |
| enable_version_counter: false | |
| every_n_train_steps: 10000 | |
| auto_insert_metric_name: false | |
| progress_bar: | |
| _target_: pytorch_lightning.callbacks.TQDMProgressBar | |
| refresh_rate: ${computer.progress_bar_refresh_rate} | |
| data_dir: ${root_dir}/cad/datasets | |
| root_dir: ${hydra:runtime.cwd} | |
| experiment_name_suffix: hf_h200 | |
| experiment_name: ${data.name}_${model.name}_${experiment_name_suffix} | |