# RIN_imagenet64 / config.yaml
# Provenance: uploaded to Hugging Face by AlienKevin ("Upload checkpoints",
# commit 2b35aa7, verified). Header preserved as comments so the file parses.
# Model section: RIN (Recurrent Interface Network) class-conditional diffusion
# model, instantiated via Hydra `_target_` / `_partial_` keys.
# NOTE(review): indentation was lost in the uploaded file; nesting below is
# reconstructed from the `${...}` interpolation paths — verify against the
# original cad repository config.
model:
  optimizer:
    # LAMB optimizer with LayerNorm/bias params excluded from weight decay.
    optim:
      _target_: utils.optimizers.Lamb
      lr: 0.002
      betas:
        - 0.9
        - 0.999
      weight_decay: 0.01
      exclude_ln_and_biases_from_weight_decay: true
    # Linear warmup then cosine decay over the full training run.
    lr_scheduler:
      _partial_: true
      _target_: utils.lr_scheduler.WarmupCosineDecayLR
      warmup_steps: 10000
      total_steps: ${trainer.max_steps}
      rate: 0.7
  network:
    _target_: cad.models.networks.rin.RINClassCond
    data_size: ${data.data_resolution}
    data_dim: 512
    num_input_channels: 3
    num_latents: 128
    latents_dim: 768
    label_dim: ${data.label_dim}
    num_cond_tokens: ${data.num_cond_tokens}
    num_processing_layers: 4
    num_blocks: 4
    path_size: 8
    read_write_heads: 16
    compute_heads: 16
    latent_mlp_multiplier: 4
    data_mlp_multiplier: 4
    rw_dropout: 0.0
    compute_dropout: 0
    rw_stochastic_depth: 0
    compute_stochastic_depth: 0
    time_scaling: 1000.0
    noise_embedding_type: positional
    data_positional_embedding_type: learned
    weight_init: xavier_uniform
    bias_init: zeros
    use_cond_token: true
    use_biases: true
    concat_cond_token_to_latents: true
    use_cond_rin_block: false
    use_16_bits_layer_norm: false
  # Noise schedule used during training (linear 1 -> 0, clipped away from 0).
  train_noise_scheduler:
    _target_: cad.models.schedulers.LinearScheduler
    start: 1
    end: 0
    clip_min: 1.0e-09
  # Noise schedule used at sampling time.
  inference_noise_scheduler:
    _target_: cad.models.schedulers.CosineSchedulerSimple
    ns: 0.0002
    ds: 0.00025
  preconditioning:
    _target_: cad.models.preconditioning.DDPMPrecond
    num_latents: ${model.network.num_latents}
    latents_dim: ${model.network.latents_dim}
  data_preprocessing:
    _target_: cad.models.preprocessing.PrecomputedPreconditioning
    input_key: image
    output_key_root: x_0
  cond_preprocessing:
    _target_: cad.models.preprocessing.PrecomputedPreconditioning
    input_key: label
    output_key_root: label
    drop_labels: false
  postprocessing:
    _partial_: true
    _target_: utils.image_processing.remap_image_torch
  loss:
    _partial_: true
    _target_: cad.models.losses.DDPMLoss
    self_cond_rate: 0.9
    cond_drop_rate: 0.0
    conditioning_key: ${model.cond_preprocessing.output_key_root}
    resample_by_coherence: false
    sample_random_when_drop: false
  # Validation uses 250-step DDIM; test uses full 1000-step DDPM sampling.
  val_sampler:
    _partial_: true
    _target_: cad.models.samplers.ddim.ddim_sampler
    num_steps: 250
    cfg_rate: ${model.cfg_rate}
  test_sampler:
    _partial_: true
    _target_: cad.models.samplers.ddpm.ddpm_sampler
    num_steps: 1000
    cfg_rate: ${model.cfg_rate}
  uncond_conditioning:
    _target_: cad.utils.misc.dummy_value_loader
    value: 0.0
  vae_embedding_name_mean: null
  return_image: true
  name: RIN
  ema_decay: 0.9999
  start_ema_step: 0
  cfg_rate: 0.0
  channel_wise_normalisation: false
# Compute/cluster settings shared via ${computer.*} interpolations.
# NOTE(review): indentation was lost in the uploaded file; restored here.
computer:
  devices: 8
  num_workers: 64
  progress_bar_refresh_rate: 2
  sync_batchnorm: true
  accelerator: gpu
  precision: bf16-mixed
  strategy: ddp
  num_nodes: 1
  eval_gpu_type: h200
# Dataset section: ImageNet-64 (HF-hosted), class-conditional, with
# torchvision transform pipelines built via Hydra Compose.
# NOTE(review): indentation was lost in the uploaded file; nesting below is
# reconstructed — verify against the original cad repository config.
data:
  train_aug:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.ToTensor
      - _target_: utils.image_processing.CenterCrop
        ratio: '1:1'
      - _target_: torchvision.transforms.Resize
        size: ${data.img_resolution}
        # interpolation 3 = PIL bicubic enum value as accepted by torchvision.
        interpolation: 3
        antialias: true
      - _target_: torchvision.transforms.RandomHorizontalFlip
        p: 0.5
      # Normalize to [-1, 1]: (x - 0.5) / 0.5 per channel.
      - _target_: torchvision.transforms.Normalize
        mean: 0.5
        std: 0.5
  # Same pipeline as train_aug minus the random horizontal flip.
  val_aug:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.ToTensor
      - _target_: utils.image_processing.CenterCrop
        ratio: '1:1'
      - _target_: torchvision.transforms.Resize
        size: ${data.img_resolution}
        interpolation: 3
        antialias: true
      - _target_: torchvision.transforms.Normalize
        mean: 0.5
        std: 0.5
  name: ImageNet_64
  type: class_conditional
  img_resolution: 64
  data_resolution: 64
  label_dim: 1000
  num_cond_tokens: 1
  full_batch_size: 1024
  in_channels: 3
  out_channels: 3
  train_instance:
    _partial_: true
    _target_: cad.data.dataset.HFImageNet64
    split: train
    transform: ${data.train_aug}
    target_transform: ${data.target_transform}
  val_instance:
    _partial_: true
    _target_: cad.data.dataset.HFImageNet64
    split: validation
    transform: ${data.val_aug}
    target_transform: ${data.target_transform}
  # One-hot encode the integer class label to label_dim classes.
  target_transform:
    _target_: utils.one_hot_transform.OneHotTransform
    num_classes: ${data.label_dim}
  collate_fn:
    _target_: data.datamodule.collate_to_dict
    keys:
      - image
      - label
  train_dataset: ${data.train_instance}
  val_dataset: ${data.val_instance}
# Lightning DataModule wiring datasets, batching and worker counts together.
# NOTE(review): indentation was lost in the uploaded file; restored here.
datamodule:
  _target_: data.datamodule.ImageDataModule
  train_dataset: ${data.train_dataset}
  val_dataset: ${data.val_dataset}
  full_batch_size: ${data.full_batch_size}
  collate_fn: ${data.collate_fn}
  num_workers: ${computer.num_workers}
  num_nodes: ${computer.num_nodes}
  num_devices: ${computer.devices}
# PyTorch Lightning Trainer; referenced by ${trainer.max_steps} in the
# lr_scheduler, so this key must stay at top level.
# NOTE(review): indentation was lost in the uploaded file; restored here.
trainer:
  _target_: pytorch_lightning.Trainer
  max_steps: 150000
  val_check_interval: 5000
  # Step-based validation: epoch-based checking explicitly disabled.
  check_val_every_n_epoch: null
  devices: ${computer.devices}
  accelerator: ${computer.accelerator}
  strategy: ${computer.strategy}
  log_every_n_steps: 1
  num_nodes: ${computer.num_nodes}
  precision: ${computer.precision}
# Weights & Biases experiment logger.
# NOTE(review): indentation was lost in the uploaded file; restored here.
logger:
  _target_: pytorch_lightning.loggers.WandbLogger
  save_dir: ${root_dir}/cad/wandb
  name: ${experiment_name}
  project: RIN
  log_model: false
  offline: false
# Checkpoint callback (project subclass of ModelCheckpoint that can also
# launch validation/metric jobs on a separate GPU type).
# NOTE(review): indentation was lost in the uploaded file; restored here.
checkpoints:
  _target_: callbacks.checkpoint_and_validate.ModelCheckpointValidate
  gpu_type: ${computer.eval_gpu_type}
  validate_when_not_on_cluster: false
  validate_when_on_cluster: false
  eval_set: val
  validate_conditional: true
  validate_unconditional: false
  validate_per_class_metrics: true
  # Sample shape (C, H, W) for generation during evaluation.
  shape:
    - ${model.network.num_input_channels}
    - ${data.data_resolution}
    - ${data.data_resolution}
  num_classes: ${data.label_dim}
  dataset_name: ${data.name}
  dirpath: ${root_dir}/cad/checkpoints/${experiment_name}
  filename: step_{step}
  monitor: val/loss_ema
  save_last: true
  # -1 keeps every checkpoint (one every 10k steps).
  save_top_k: -1
  enable_version_counter: false
  every_n_train_steps: 10000
  auto_insert_metric_name: false
# TQDM progress bar; refresh rate shared with the computer section.
# NOTE(review): indentation was lost in the uploaded file; restored here.
progress_bar:
  _target_: pytorch_lightning.callbacks.TQDMProgressBar
  refresh_rate: ${computer.progress_bar_refresh_rate}
# Top-level path and naming scalars. root_dir resolves to the Hydra runtime
# working directory; experiment_name composes dataset name, model name and
# the suffix below (e.g. "ImageNet_64_RIN_hf_h200").
data_dir: ${root_dir}/cad/datasets
root_dir: ${hydra:runtime.cwd}
experiment_name_suffix: hf_h200
experiment_name: ${data.name}_${model.name}_${experiment_name_suffix}