# SEDD-500M-Coder / config.yaml
# Uploaded by asdf213 with huggingface_hub ("Upload folder using huggingface_hub"),
# commit d2136fb (verified).
# Hydra defaults list: the config groups composed into this primary config.
# `_self_` comes first, so values from the groups below are merged after (and
# therefore take precedence over) the values written in this file.
defaults:
  - _self_
  - /callbacks:
      - checkpoint_every_n_steps
      - checkpoint_monitor
      - learning_rate_monitor
  - /data: openwebtext
  - /model: small
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: constant_warmup
# Top-level run options.
mode: train  # one of: train / ppl_eval / sample_eval
diffusion: absorbing_state
backbone: dit  # backbone for the diffusion model: dit / dimamba / ar
ebm_backbone: null  # backbone for the EBM: dit / dimamba / ar
parameterization: subs  # one of: subs / d3pm / sedd
time_conditioning: true
T: 0  # 0 (continuous time) / 1000
subs_masking: false
seed: 1
# Data-loader settings. Keys must be nested under `loader:`; the relative
# interpolations `${.global_batch_size}` / `${.eval_global_batch_size}`
# resolve against this section.
loader:
  global_batch_size: 512
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per device**: the global size is
  # divided by trainer.devices * trainer.num_nodes. `div_up` and `eval` are
  # custom resolvers registered outside this file (div_up presumably rounds up).
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  # One worker per CPU this process may run on. os.sched_getaffinity is only
  # available on some Unix platforms (e.g. Linux); override this value elsewhere.
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: true
# Sampling settings (nested under `sampling:`).
sampling:
  predictor: ddpm_cache  # one of: analytic, ddpm, ddpm_cache
  steps: 128
  noise_removal: true
  # TODO(yair): @subham, why aren't these params under `eval`?
  # Total samples: `num_gpus` * `loader.eval_batch_size` * `num_sample_batches`.
  num_sample_batches: 2
  num_sample_log: 2
  semi_ar: false
  stride_length: 1
  num_strides: 1
  # Importance sampling.
  is_size: 2
  is_start: 0.6
  is_end: 0.4
  is_temp: 1
  # AR EBM.
  ar_carry_over: true
# Training-objective settings (nested under `training:`).
training:
  ema: 0.9999
  antithetic_sampling: true
  importance_sampling: false
  # Written as 1.0e-3 rather than 1e-3: plain PyYAML (YAML 1.1) only treats
  # exponent notation as a float when it has a decimal point; OmegaConf reads
  # both forms as the float 0.001.
  sampling_eps: 1.0e-3
  change_of_variables: false
# Evaluation settings (nested under `eval:`).
eval:
  checkpoint_path: ''  # Used to evaluate a checkpoint after training.
  disable_ema: false
  compute_generative_perplexity: false
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: false
  # e.g. gpt2-large, meta-llama/Llama-2-7b-hf, meta-llama/Meta-Llama-3-8B
  gen_ppl_eval_model_name_or_path: gpt2-large
  generate_samples: true
# Optimizer hyperparameters (nested under `optim:`). Exponent-notation values
# include a decimal point so plain PyYAML (YAML 1.1) also loads them as floats.
optim:
  weight_decay: 0
  lr: 3.0e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1.0e-8
# Arguments for the Lightning Trainer (`_target_` is the class Hydra instantiates).
trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: ${device_count:}  # custom resolver registered outside this file
  # Accumulation steps so that devices * loader.batch_size * num_nodes reaches
  # loader.global_batch_size (`div_up` presumably rounds up; defined elsewhere).
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: 'bf16'
  num_sanity_val_steps: 2
  max_steps: 1_000_000
  log_every_n_steps: 10
  limit_train_batches: 1.0  # 1.0 = full training set; lower it for a quick run
  limit_val_batches: 1.0  # 1.0 = full validation set; lower it for a quick run
  val_check_interval: 10000
# Weights & Biases logging settings (nested under `wandb:`).
wandb:
  project: text-diffusion
  notes: Mulan for text
  group: null
  job_type: null
  name: null
  # Run id built from the run name (relative `${.name}`) and the global seed.
  id: ${.name}_${seed}
  # Tags taken from the selected /noise and /data config groups.
  tags:
    - ${noise.type}
    - ${data.train}
    - ${data.valid}
# Hydra runtime settings: each run gets its own output directory
# (outputs/<train dataset>/<date>/<time>), and `chdir: true` makes that
# directory the job's working directory.
hydra:
  run:
    dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    chdir: true
# Checkpoint location and resume behaviour (nested under `checkpointing:`;
# `${.save_dir}` below is resolved relative to this section).
checkpointing:
  # Set a custom `save_dir` if, e.g., saving to an S3 bucket; otherwise leave
  # the default (`cwd` is a custom resolver registered outside this file).
  save_dir: ${cwd:}
  resume_from_ckpt: true
  # Note: the `checkpoints` subdirectory here must match
  # `checkpoint_every_n_steps.dirpath`.
  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt