# SFTok-B / stage3.yaml
# Source: AndyRaoTHU (Hugging Face) — "Rename sftok-b_stage3.yaml to stage3.yaml"
# Commit: 490d3d4 (verified)
---
# Run bookkeeping: naming, checkpoint/eval cadence, and resume behavior.
experiment:
  project: "sftok_b"
  name: "sftok_b_stage3_run1"
  output_dir: "sftok_b_stage3_run1"
  # 1_281_167 matches the ImageNet-1k train-split size — TODO confirm dataset.
  # Underscored integers are YAML 1.1 (PyYAML/OmegaConf); YAML 1.2 parsers read them as strings.
  max_train_examples: 1_281_167
  save_every: 25_000
  eval_every: 50_000
  generate_every: 5_000
  log_every: 50
  log_grad_norm_every: 1_000
  resume: true  # canonical lowercase boolean (was `True`)
  init_weight: ""  # empty string = no initialization checkpoint
# Tokenizer model: ViT encoder/quantizer plus a masked-generation decoder.
# NOTE(review): nesting below was reconstructed from a flattened (indentation-stripped)
# source — verify `decoder:` belongs under `vq_model` against the training code.
model:
  model_type: "sftok"
  vq_model:
    codebook_size: 8192
    token_size: 32
    use_l2_norm: true
    commitment_cost: 0.25
    # ViT encoder architecture
    vit_enc_model_size: "base"
    vit_enc_patch_size: 16
    num_latent_tokens: 64
    num_group: 1
    finetune_decoder: true
    semantic_guide: null
    pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
    decoder:
      # sftok decoder head
      vit_dec_model_size: "base"
      vit_dec_patch_size: 16
      num_latent_tokens: 64
      token_size: 32
      num_proxy_codes: 256
      codebook_size: 1024
      semantic_guide: null
      # sampling hyper-params on the fly
      randomize_temperature: 1.0  # default 1.0
      guidance_scale: 0.0  # e.g. 4.5 for classifier-free guidance
      guidance_decay: "constant"
      embedding_width: 1024
      embedding_init: false
# Loss weighting: reconstruction + perceptual + quantizer + adversarial terms.
losses:
  discriminator_type: dinodisc  # alternative: patchgan
  discriminator_start: 20_000  # step at which the GAN loss turns on
  quantizer_weight: 1.0
  discriminator_factor: 1.0
  discriminator_weight: 0.5
  perceptual_loss: lpips
  perceptual_weight: 1.0
  reconstruction_loss: "l2"
  reconstruction_weight: 1.0
  lecam_regularization_weight: 0.001  # LeCam regularizer on the discriminator
# Data loading and augmentation.
# NOTE(review): `params` vs `preprocessing` nesting reconstructed from a flattened
# source — verify against the dataloader's expected config schema.
dataset:
  params:
    train_shards_path_or_url: /path/to/data/train  # placeholder — set before running
    eval_shards_path_or_url: /path/to/data/val  # placeholder — set before running
    num_workers_per_gpu: 12
    dataset_type: simple_image_dataset
  preprocessing:
    resize_shorter_edge: 256
    crop_size: 256
    random_crop: true
    random_flip: true
# AdamW optimizer settings (separate LR for the discriminator).
optimizer:
  name: adamw
  params:
    # Written as `1.0e-4` (not `1e-4`): YAML 1.1 parsers such as PyYAML only
    # resolve floats containing a dot, so bare `1e-4` loads as the string "1e-4".
    learning_rate: 1.0e-4
    discriminator_learning_rate: 1.0e-4
    beta1: 0.9
    beta2: 0.999
    weight_decay: 1.0e-4
# Cosine LR schedule with linear warmup; base LR mirrors the optimizer via
# OmegaConf interpolation.
lr_scheduler:
  scheduler: "cosine"
  params:
    learning_rate: ${optimizer.params.learning_rate}
    warmup_steps: 5_000
    # `1.0e-5` (not `1e-5`) so YAML 1.1 parsers resolve it as a float.
    end_lr: 1.0e-5
# Training loop: precision, batching, and masked-generation schedule.
training:
  gradient_accumulation_steps: 2
  per_gpu_batch_size: 64
  mixed_precision: "fp16"
  enable_tf32: true
  enable_wandb: false
  use_ema: true
  seed: 42
  max_train_steps: 1_000_000
  num_generated_images: 2
  max_grad_norm: 1.0
  use_mlmloss: true
  single_step_generation: false
  mask_schedule: "arccos"  # one of: "arccos", "linear", "exponential", "cosine"
  mask_power: 1.0
  min_mask_ratio: 0.0
  max_mask_ratio: 1.0
  use_soft_proxycodes: false
  soft_proxycode_sigma: 4.0
  guided_mask: true