---
# Training configuration for the "sftok" tokenizer, stage-3 run 1
# (project sftok_b). Reconstructed into block YAML from a collapsed
# single-line file; key nesting below the marked points is inferred
# from token order — NOTE(review): confirm against the consuming code.

experiment:
  project: "sftok_b"
  name: "sftok_b_stage3_run1"
  output_dir: "sftok_b_stage3_run1"
  max_train_examples: 1_281_167
  save_every: 25_000
  eval_every: 50_000
  generate_every: 5_000
  log_every: 50
  log_grad_norm_every: 1_000
  resume: true
  init_weight: ""

model:
  model_type: "sftok"
  # vq_model:
  codebook_size: 8192
  token_size: 32
  use_l2_norm: true
  commitment_cost: 0.25
  # vit arch
  vit_enc_model_size: "base"
  vit_enc_patch_size: 16
  num_latent_tokens: 64
  num_group: 1
  finetune_decoder: true
  semantic_guide: null
  pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
  decoder:
    # sftok
    vit_dec_model_size: "large"
    vit_dec_patch_size: 16
    num_latent_tokens: 64
    token_size: 32
    num_proxy_codes: 256
    codebook_size: 1024
    semantic_guide: null
    # sampling hyper-params on the flight
    randomize_temperature: 1.0  # 1.0
    guidance_scale: 0.0  # 4.5
    guidance_decay: "constant"
    # NOTE(review): the two keys below may belong one level up (under
    # model) — original nesting lost in flattening; verify with consumer.
    embedding_width: 1024
    embedding_init: false

losses:
  discriminator_type: dinodisc  # patchgan
  discriminator_start: 20_000  # 20_000
  quantizer_weight: 1.0
  discriminator_factor: 1.0
  discriminator_weight: 0.5
  perceptual_loss: lpips
  perceptual_weight: 1.0
  reconstruction_loss: "l2"
  reconstruction_weight: 1.0
  lecam_regularization_weight: 0.001

dataset:
  params:
    train_shards_path_or_url: /path/to/data/train
    eval_shards_path_or_url: /path/to/data/val
    num_workers_per_gpu: 12
    dataset_type: simple_image_dataset
  preprocessing:
    resize_shorter_edge: 256
    crop_size: 256
    random_crop: true
    random_flip: true

optimizer:
  name: adamw
  params:
    learning_rate: 1e-4
    discriminator_learning_rate: 1e-4
    beta1: 0.9
    beta2: 0.999
    weight_decay: 1e-4

lr_scheduler:
  scheduler: "cosine"
  params:
    # OmegaConf-style interpolation: reuses the optimizer's base LR.
    learning_rate: ${optimizer.params.learning_rate}
    warmup_steps: 5_000
    end_lr: 1e-5

training:
  gradient_accumulation_steps: 2
  per_gpu_batch_size: 64
  mixed_precision: "fp16"
  enable_tf32: true
  enable_wandb: false
  use_ema: true
  seed: 42
  max_train_steps: 1_000_000
  num_generated_images: 2
  max_grad_norm: 1.0
  use_mlmloss: true
  single_step_generation: false
  # Options: "arccos", "linear", "exponential", "cosine"
  mask_schedule: "arccos"
  mask_power: 1.0
  min_mask_ratio: 0.0
  max_mask_ratio: 1.0
  use_soft_proxycodes: false
  soft_proxycode_sigma: 4.0
  guided_mask: true