# Source: A2D2 / a2d2_mol/config_mol.yaml
# Commit 8019be0 ("initial commit", author: Sophia), file size 1.34 kB
# Top-level run selectors.
# NOTE(review): presumably these pick the training routine and the dataset
# pipeline by name - confirm against the code that loads this file.
trainer: 'any-order-flow'
dataset: 'safe-drugs'
# HuggingFace dataset configuration.
# Child keys are indented under `hf_dataset` so they form a mapping; left at
# column 0 the header parses as null and `name` clashes with other top-level
# `name` keys.
hf_dataset:
  name: "datamol-io/safe-gpt"
  smiles_column: "smiles"  # Adjust based on actual column name in the dataset
# Model hyperparameters, nested under `model` so the section is a mapping
# rather than a null header followed by loose top-level keys.
model:
  hidden_size: 768
  n_heads: 12
  cond_dim: 128
  dropout: 0.05
  n_blocks: 12
  torch_dtype: 'float32'  # Options: 'float32', 'float16', 'bfloat16'
# Interpolant settings.
# NOTE(review): nesting was reconstructed from a copy that had lost its
# indentation. `insert_schedule` and `unmask_schedule` are assumed to be
# children of `interpolant`, each with its own `type`; without nesting the
# three `type` keys are duplicates and most loaders silently keep only the
# last one. Confirm the layout against the config loader.
interpolant:
  type: "any-order"
  tokens: null  # filled in automatically
  pad_token: null  # filled in automatically
  mask_token: null  # filled in automatically
  max_length: 256
  insert_schedule:
    type: "linear"
  unmask_schedule:
    type: "linear"
# Training / optimisation settings.
# NOTE(review): nesting was reconstructed from a copy that had lost its
# indentation. `loss_fn` and the keys that follow it (`reset_lr`,
# `warmup_steps`, `ema_decay`, `filter_max_length`) are assumed to belong to
# `training` - confirm against the code that reads this section.
training:
  only_embed_insert: true
  batch_size: 2048
  per_gpu_batch_size: 64  # Gradient accumulation happens automatically
  cpus: 4
  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. PyYAML)
  # resolve a bare `3e-4` as the string "3e-4" instead of a float.
  learning_rate: 3.0e-4
  nodes: 1
  devices: 2
  max_steps: 500000
  weight_decay: 0.03
  checkpoint_dir: "checkpoints/pretrain_mol"
  save_top_k: 3
  save_every_n_steps: 1000  # Save checkpoint every 1k steps (for streaming datasets)
  # save_every_n_epochs: 1  # Not used with streaming datasets
  loss_fn:
    unmask: "elbo"
    insert: "expectation"
  reset_lr: false
  warmup_steps: 2000
  ema_decay: 0.9999
  filter_max_length: false
# Weights & Biases logging settings, nested under `wandb` so that `name` and
# `path` are scoped to this section instead of colliding with other top-level
# keys.
# NOTE(review): `path` is assumed to belong to `wandb` - confirm.
wandb:
  entity: null  # set to your W&B entity, or leave null to use the default
  project: "a2d2-mol"
  name: "a2d2-mol"
  path: "./wandb"