# plaid / PLAID-2B / config.yaml
# Uploaded by amyxlu ("Upload PLAID-2B"), commit d95bcd9 (verified).
# Top-level run settings (these keys are not nested under any section).
resume_from_model_id: null
# Run id of the compression model; quoted so an all-digit id could never be
# re-typed as a number by the YAML parser.
compression_model_id: 'j1v1wv6w'
use_old_ema_module: false
# Filesystem locations. Other sections reference these through OmegaConf
# interpolation (`${paths.<key>}`), so every key must be nested under `paths`;
# at column 0 `paths` would be null and those interpolations would fail.
paths:
  project_dir: /homefs/home/lux70/code/plaid
  bucket_dir: /data/lux70/plaid
  data_dir: /data/lux70/data
  home_dir: /homefs/home/lux70
  log_dir: ${paths.bucket_dir}/logs
  checkpoint_dir: ${paths.bucket_dir}/checkpoints/plaid-compositional
  artifacts_dir: ${paths.bucket_dir}/artifacts
  # NOTE(review): nested here because it follows the path keys in the dump
  # order; confirm against its consumer (presumably a W&B entity name).
  entity: lu-amy-al1
# Keyword arguments for `lightning.pytorch.Trainer`; `_target_` names the class
# to instantiate (Hydra convention).
trainer:
  _target_: lightning.pytorch.Trainer
  accelerator: gpu
  strategy: ddp_find_unused_parameters_true
  # -1 selects every visible device on each node.
  devices: -1
  num_nodes: 8
  # Quoted so the value stays the string '32' rather than the integer 32.
  precision: '32'
  gradient_clip_val: 0.5
  log_every_n_steps: 50
  num_sanity_val_steps: 0
  gradient_clip_algorithm: norm
  max_epochs: 20000
  default_root_dir: ${paths.log_dir}
# Data module for the pre-compressed Pfam shards. The compression run id is
# interpolated from the top-level `compression_model_id` (j1v1wv6w) instead of
# being hard-coded in four places, so the paths cannot drift from it.
datamodule:
  _target_: plaid.datasets.FunctionOrganismDataModule
  # `{0000..4423}` is shard-range syntax for the loader (presumably
  # brace-expanded); quoted so the braces are never read as YAML flow syntax.
  train_shards: '${paths.data_dir}/pfam/compressed/${compression_model_id}/train/shard{0000..4423}.tar'
  val_shards: '${paths.data_dir}/pfam/compressed/${compression_model_id}/val/shard{0000..0863}.tar'
  config_file: ${paths.data_dir}/pfam/compressed/${compression_model_id}/config.json
  go_metadata_fpath: ${paths.data_dir}/pfam/pfam2go.csv
  organism_metadata_fpath: ${paths.data_dir}/pfam/organism_counts.csv
  cache_dir: ${paths.home_dir}/.cache/plaid_data/${compression_model_id}
  train_epoch_num_batches: 1000000
  val_epoch_num_batches: 1000
  shuffle_buffer: 20000
  shuffle_initial: 20000
  max_length: 256
  batch_size: 32
  num_workers: 8
  prefetch_factor: 4
# Denoiser network constructor arguments.
denoiser:
  _target_: plaid.denoisers.FunctionOrganismUDiT
  hidden_size: 2048
  # NOTE(review): presumably must agree with datamodule.max_length (both 256).
  max_seq_len: 256
  depth: 23
  num_heads: 16
  mlp_ratio: 4.0
  use_self_conditioning: true
  timestep_embedding_strategy: sinusoidal
  use_skip_connect: false
  attention_mode: xformers_memory_efficient
# Diffusion module constructor arguments, including its optimizer and
# learning-rate schedule settings.
diffusion:
  _target_: plaid.diffusion.FunctionOrganismDiffusion
  # Noise schedule.
  beta_scheduler_name: sigmoid
  beta_scheduler_start: -3
  beta_scheduler_end: 3
  beta_scheduler_tau: 1
  x_downscale_factor: 1.0
  timesteps: 1000
  # Training objective and loss weighting.
  objective: pred_v
  min_snr_loss_weight: true
  min_snr_gamma: 5
  x_clip_val: 1.0
  # Per-condition drop probabilities for the function and organism labels.
  function_y_cond_drop_prob: 0.1
  organism_y_cond_drop_prob: 0.1
  # NOTE(review): same value as callbacks.ema.decay; keep the two in sync if
  # they are meant to describe the same EMA.
  ema_decay: 0.9999
  # Optimizer / learning-rate schedule.
  lr: 0.0001
  lr_adam_betas:
    - 0.9
    - 0.999
  lr_sched_type: cosine_with_restarts
  lr_num_warmup_steps: 10000
  lr_num_training_steps: 1000000
  lr_num_cycles: 1
# Training callbacks, one entry per callback name.
callbacks:
  # Periodic checkpointing. With `monitor: step`, `mode: max` and
  # `save_top_k: 1`, only the checkpoint with the highest step is kept
  # (assuming EMAModelCheckpoint keeps Lightning ModelCheckpoint semantics,
  # where `save_last: link` makes `last` a link to the latest checkpoint).
  checkpoint:
    _target_: plaid.callbacks.EMAModelCheckpoint
    save_last: link
    # Quoted so the `{...}` placeholders are never read as YAML flow syntax.
    filename: 'epoch{epoch}-step{step}'
    verbose: true
    every_n_train_steps: 10000
    monitor: step
    save_top_k: 1
    mode: max
    auto_insert_metric_name: false
    dirpath: ${paths.checkpoint_dir}
  ema:
    _target_: plaid.callbacks.EMA
    # NOTE(review): same value as diffusion.ema_decay; keep the two in sync if
    # they are meant to describe the same EMA.
    decay: 0.9999
    apply_ema_every_n_steps: 1
    start_step: 0
    save_ema_weights_in_callback_state: false
    evaluate_ema_weights_instead: true
# Weights & Biases logger constructor arguments.
logger:
  _target_: lightning.pytorch.loggers.WandbLogger
  project: plaid-compositional-conditioning
  entity: prescient-design
  name: UDiT_XXL
  tags: null
  group: null