---
# Hydra/OmegaConf experiment config: RIN class-conditional diffusion model
# (cad.models.networks.rin.RINClassCond) trained on ImageNet-64 with
# PyTorch Lightning. ${...} values are OmegaConf interpolations resolved at
# load time. NOTE(review): nesting was reconstructed from a flattened dump;
# levels are inferred from the interpolation paths (e.g. ${model.cfg_rate},
# ${data.train_dataset}) — confirm against the original config tree.

model:
  optimizer:
    # LAMB optimizer with LayerNorm/bias params excluded from weight decay.
    optim:
      _target_: utils.optimizers.Lamb
      lr: 0.002
      betas:
        - 0.9
        - 0.999
      weight_decay: 0.01
      exclude_ln_and_biases_from_weight_decay: true
    # Warmup + cosine decay schedule, tied to the trainer's step budget.
    lr_scheduler:
      _partial_: true
      _target_: utils.lr_scheduler.WarmupCosineDecayLR
      warmup_steps: 10000
      total_steps: ${trainer.max_steps}
      rate: 0.7
  # RIN backbone (Recurrent Interface Network), class-conditional variant.
  network:
    _target_: cad.models.networks.rin.RINClassCond
    data_size: ${data.data_resolution}
    data_dim: 512
    num_input_channels: 3
    num_latents: 128
    latents_dim: 768
    label_dim: ${data.label_dim}
    num_cond_tokens: ${data.num_cond_tokens}
    num_processing_layers: 4
    num_blocks: 4
    path_size: 8
    read_write_heads: 16
    compute_heads: 16
    latent_mlp_multiplier: 4
    data_mlp_multiplier: 4
    rw_dropout: 0.0
    compute_dropout: 0
    rw_stochastic_depth: 0
    compute_stochastic_depth: 0
    time_scaling: 1000.0
    noise_embedding_type: positional
    data_positional_embedding_type: learned
    weight_init: xavier_uniform
    bias_init: zeros
    use_cond_token: true
    use_biases: true
    concat_cond_token_to_latents: true
    use_cond_rin_block: false
    use_16_bits_layer_norm: false
  # Noise schedules: linear for training, simple-cosine for inference.
  train_noise_scheduler:
    _target_: cad.models.schedulers.LinearScheduler
    start: 1
    end: 0
    clip_min: 1.0e-09
  inference_noise_scheduler:
    _target_: cad.models.schedulers.CosineSchedulerSimple
    ns: 0.0002
    ds: 0.00025
  preconditioning:
    _target_: cad.models.preconditioning.DDPMPrecond
    num_latents: ${model.network.num_latents}
    latents_dim: ${model.network.latents_dim}
  # Input image and label preprocessing (precomputed keys from the batch dict).
  data_preprocessing:
    _target_: cad.models.preprocessing.PrecomputedPreconditioning
    input_key: image
    output_key_root: x_0
  cond_preprocessing:
    _target_: cad.models.preprocessing.PrecomputedPreconditioning
    input_key: label
    output_key_root: label
    drop_labels: false
  postprocessing:
    _partial_: true
    _target_: utils.image_processing.remap_image_torch
  # DDPM loss with 90% self-conditioning and no label dropping.
  loss:
    _partial_: true
    _target_: cad.models.losses.DDPMLoss
    self_cond_rate: 0.9
    cond_drop_rate: 0.0
    conditioning_key: ${model.cond_preprocessing.output_key_root}
    resample_by_coherence: false
    sample_random_when_drop: false
  # Samplers: fast DDIM (250 steps) for validation, full DDPM (1000) for test.
  val_sampler:
    _partial_: true
    _target_: cad.models.samplers.ddim.ddim_sampler
    num_steps: 250
    cfg_rate: ${model.cfg_rate}
  test_sampler:
    _partial_: true
    _target_: cad.models.samplers.ddpm.ddpm_sampler
    num_steps: 1000
    cfg_rate: ${model.cfg_rate}
  uncond_conditioning:
    _target_: cad.utils.misc.dummy_value_loader
    value: 0.0
  vae_embedding_name_mean: null
  return_image: true
  name: RIN
  ema_decay: 0.9999
  start_ema_step: 0
  cfg_rate: 0.0
  channel_wise_normalisation: false

# Hardware / distributed-run settings (single node, 8 GPUs, bf16, DDP).
computer:
  devices: 8
  num_workers: 64
  progress_bar_refresh_rate: 2
  sync_batchnorm: true
  accelerator: gpu
  precision: bf16-mixed
  strategy: ddp
  num_nodes: 1
  eval_gpu_type: h200

# Dataset: HuggingFace ImageNet-64, class-conditional, one-hot labels.
data:
  # Training augmentation: center-crop to square, resize (bicubic,
  # interpolation=3), random horizontal flip, normalize to [-1, 1].
  train_aug:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.ToTensor
      - _target_: utils.image_processing.CenterCrop
        ratio: '1:1'
      - _target_: torchvision.transforms.Resize
        size: ${data.img_resolution}
        interpolation: 3
        antialias: true
      - _target_: torchvision.transforms.RandomHorizontalFlip
        p: 0.5
      - _target_: torchvision.transforms.Normalize
        mean: 0.5
        std: 0.5
  # Validation augmentation: same as training but without the random flip.
  val_aug:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.ToTensor
      - _target_: utils.image_processing.CenterCrop
        ratio: '1:1'
      - _target_: torchvision.transforms.Resize
        size: ${data.img_resolution}
        interpolation: 3
        antialias: true
      - _target_: torchvision.transforms.Normalize
        mean: 0.5
        std: 0.5
  name: ImageNet_64
  type: class_conditional
  img_resolution: 64
  data_resolution: 64
  label_dim: 1000
  num_cond_tokens: 1
  full_batch_size: 1024
  in_channels: 3
  out_channels: 3
  train_instance:
    _partial_: true
    _target_: cad.data.dataset.HFImageNet64
    split: train
    transform: ${data.train_aug}
    target_transform: ${data.target_transform}
  val_instance:
    _partial_: true
    _target_: cad.data.dataset.HFImageNet64
    split: validation
    transform: ${data.val_aug}
    target_transform: ${data.target_transform}
  target_transform:
    _target_: utils.one_hot_transform.OneHotTransform
    num_classes: ${data.label_dim}
  collate_fn:
    _target_: data.datamodule.collate_to_dict
    keys:
      - image
      - label
  train_dataset: ${data.train_instance}
  val_dataset: ${data.val_instance}

datamodule:
  _target_: data.datamodule.ImageDataModule
  train_dataset: ${data.train_dataset}
  val_dataset: ${data.val_dataset}
  full_batch_size: ${data.full_batch_size}
  collate_fn: ${data.collate_fn}
  num_workers: ${computer.num_workers}
  num_nodes: ${computer.num_nodes}
  num_devices: ${computer.devices}

# Lightning Trainer: 150k steps, validating every 5k steps (step-based,
# hence check_val_every_n_epoch is null).
trainer:
  _target_: pytorch_lightning.Trainer
  max_steps: 150000
  val_check_interval: 5000
  check_val_every_n_epoch: null
  devices: ${computer.devices}
  accelerator: ${computer.accelerator}
  strategy: ${computer.strategy}
  log_every_n_steps: 1
  num_nodes: ${computer.num_nodes}
  precision: ${computer.precision}

logger:
  _target_: pytorch_lightning.loggers.WandbLogger
  save_dir: ${root_dir}/cad/wandb
  name: ${experiment_name}
  project: RIN
  log_model: false
  offline: false

# Checkpoint callback that can also launch conditional-generation evaluation.
checkpoints:
  _target_: callbacks.checkpoint_and_validate.ModelCheckpointValidate
  gpu_type: ${computer.eval_gpu_type}
  validate_when_not_on_cluster: false
  validate_when_on_cluster: false
  eval_set: val
  validate_conditional: true
  validate_unconditional: false
  validate_per_class_metrics: true
  shape:
    - ${model.network.num_input_channels}
    - ${data.data_resolution}
    - ${data.data_resolution}
  num_classes: ${data.label_dim}
  dataset_name: ${data.name}
  dirpath: ${root_dir}/cad/checkpoints/${experiment_name}
  filename: step_{step}
  monitor: val/loss_ema
  save_last: true
  save_top_k: -1
  enable_version_counter: false
  every_n_train_steps: 10000
  auto_insert_metric_name: false

progress_bar:
  _target_: pytorch_lightning.callbacks.TQDMProgressBar
  refresh_rate: ${computer.progress_bar_refresh_rate}

data_dir: ${root_dir}/cad/datasets
root_dir: ${hydra:runtime.cwd}
experiment_name_suffix: hf_h200
experiment_name: ${data.name}_${model.name}_${experiment_name_suffix}