defaults:
  - _self_
  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
  - /data: openwebtext
  - /model: small
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: constant_warmup

mode: train  # train / ppl_eval / sample_eval
diffusion: absorbing_state
backbone: dit  # dit / dimamba / ar : backbone for diffusion
ebm_backbone: null  # dit / dimamba / ar : backbone for EBM
parameterization: subs  # subs / d3pm / sedd
time_conditioning: True
T: 0  # 0 (continuous time) / 1000
subs_masking: False
seed: 1

loader:
  global_batch_size: 512
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per device**
  # (the global sizes are split across trainer.devices * trainer.num_nodes).
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True

sampling:
  predictor: ddpm_cache  # analytic / ddpm / ddpm_cache
  steps: 128
  noise_removal: True
  # TODO(yair): @subham, why aren't these params under `eval`?
  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  num_sample_log: 2
  semi_ar: False
  stride_length: 1
  num_strides: 1
  # importance sampling
  is_size: 2
  is_start: 0.6
  is_end: 0.4
  is_temp: 1
  # AR EBM
  ar_carry_over: True

training:
  ema: 0.9999
  antithetic_sampling: True
  importance_sampling: False
  sampling_eps: 1e-3
  change_of_variables: False

eval:
  checkpoint_path: ''  # Used to evaluate a checkpoint after training.
  disable_ema: False
  compute_generative_perplexity: False
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: False
  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large / meta-llama/Llama-2-7b-hf / meta-llama/Meta-Llama-3-8B
  generate_samples: True

optim:
  weight_decay: 0
  lr: 3e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: ${device_count:}
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: 'bf16'
  num_sanity_val_steps: 2
  max_steps: 1_000_000
  log_every_n_steps: 10
  limit_train_batches: 1.0  # train on the full dataset; lower for a quick run
  limit_val_batches: 1.0  # validate on the full dataset; lower for a quick run
  val_check_interval: 10000

wandb:
  project: text-diffusion
  notes: Mulan for text
  group: null
  job_type: null
  name: null
  id: ${.name}_${seed}
  tags:
    - ${noise.type}
    - ${data.train}
    - ${data.valid}

hydra:
  run:
    dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    chdir: true

checkpointing:
  # Use a custom `save_dir` if, e.g., saving to an S3 bucket; otherwise leave this parameter as is.
  save_dir: ${cwd:}
  # Note: the `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`.
  resume_from_ckpt: true
  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
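
# ---------------------------------------------------------------------------
# Worked example of the derived batch settings (illustrative only; assumes the
# `div_up` resolver is ceiling division and a single node with 8 GPUs):
#   loader.batch_size               = div_up(512, 8 * 1)     = 64  per device
#   trainer.accumulate_grad_batches = div_up(512, 8 * 64 * 1) = 1
# so the effective global batch size remains loader.global_batch_size = 512.
# ---------------------------------------------------------------------------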
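
# Example Hydra CLI overrides for the other modes (the `main.py` entry-point
# name is an assumption; substitute the repo's actual launch script):
#   python main.py mode=ppl_eval eval.checkpoint_path=/path/to/last.ckpt
#   python main.py mode=sample_eval sampling.steps=1000 sampling.num_sample_batches=4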