# Source: DrugFlow/configs/training/preference_alignment.yml
# Uploaded by mority ("Upload 53 files", commit 6e7d4ba, verified)
# Top-level settings for the preference-alignment run.
run_name: drugflow_preference_alignment
# Placeholder path: must be pointed at the reference checkpoint before use.
checkpoint: ./reference.ckpt # TODO: specify reference checkpoint
dpo_mode: single_dpo_comp_v3
pocket_representation: CA+
virtual_nodes: [0, 10]
# The loss weights lambda_chi (flexible) and lambda_trans / lambda_rot
# (flexible_bb) are documented in this file as only effective when the
# corresponding flag below is True.
flexible: False
flexible_bb: False
# Training-run settings.
# Child keys are indented so they are parsed as members of train_params
# instead of as unrelated top-level keys (with train_params left as null).
# NOTE(review): nesting reconstructed from the section header; confirm that
# the consumer reads all of these keys from train_params.
train_params:
  logdir: ./runs  # symlink to any location you like
  datadir: ./processed_crossdocked  # symlink to the dataset location
  enable_progress_bar: True
  num_sanity_val_steps: 0
  batch_size: 64
  accumulate_grad_batches: 2
  lr: 5.0e-5
  n_epochs: 500
  num_workers: 0
  gpus: 1
  clip_grad: True
  gnina: gnina  # path to gnina binary
  sample_from_clusters: False
  sharded_dataset: False
# Weights & Biases logging settings.
# Child keys are indented so they are parsed as members of wandb_params.
wandb_params:
  mode: online  # disabled, offline, online
  # Explicit null: the value was previously left empty, which parses to the
  # same null but is easy to misread as a missing line.
  entity: null
  group: crossdocked
# Loss configuration, including the DPO-specific (dpo_*) entries.
# Child keys are indented so they are parsed as members of loss_params
# instead of as unrelated top-level keys (with loss_params left as null).
# NOTE(review): nesting reconstructed from the section header; confirm that
# the consumer reads all of these keys from loss_params.
loss_params:
  discrete_loss: VLB  # VLB or CE
  lambda_x: 1.0
  lambda_h: 500
  dpo_lambda_h: 2500
  lambda_e: 500
  dpo_lambda_e: 2500
  lambda_chi: 0.5  # only effective if flexible=True
  lambda_trans: 1.0  # only effective if flexible_bb=True
  lambda_rot: 0.1  # only effective if flexible_bb=True
  lambda_clash: null
  timestep_weights: null  # sigmoid_a=1_b=10 # null, sigmoid_a=?_b=?
  dpo_beta: 100.0
  # Quoted so the value is unambiguously the string "t".
  dpo_beta_schedule: 't'
  dpo_lambda_w: 1.0
  dpo_lambda_l: 0.2
  clamp_dpo: False
# Simulation settings.
# Child keys are indented so they are parsed as members of
# simulation_params instead of as unrelated top-level keys.
# NOTE(review): predict_final and predict_confidence are placed here because
# they sit between this header and the next section header; confirm that the
# consumer reads them from simulation_params.
simulation_params:
  n_steps: 5000
  prior_h: marginal  # uniform, marginal
  prior_e: uniform  # uniform, marginal
  predict_final: False
  predict_confidence: False
# Evaluation, sampling, and visualization settings.
# Child keys are indented so they are parsed as members of eval_params
# instead of as unrelated top-level keys (with eval_params left as null).
# NOTE(review): nesting reconstructed from the section header; confirm that
# the consumer reads all of these keys from eval_params.
eval_params:
  eval_epochs: 4
  n_eval_samples: 1
  n_sampling_steps: 500
  eval_batch_size: 16
  visualize_sample_epoch: 1
  n_visualize_samples: 10
  visualize_chain_epoch: 1
  keep_frames: 100
  sample_with_ground_truth_size: True
# Predictor network settings.
# Child keys are indented so they are parsed as members of predictor_params
# instead of as unrelated top-level keys (with predictor_params left as null).
predictor_params:
  heterogeneous_graph: True
  backbone: gvp
  num_rbf_time: 16
  edge_cutoff_ligand: null
  edge_cutoff_pocket: 10.0
  edge_cutoff_interaction: 10.0
  cycle_counts: True
  spectral_feat: False
  reflection_equivariant: False
  num_rbf: 16
  d_max: 15.0
  self_conditioning: True
  augment_residue_sc: False
  augment_ligand_sc: False
  normal_modes: False
  add_chi_as_feature: False
  angle_act_fn: null
  add_all_atom_diff: False
  # Settings for the backbone selected by `backbone: gvp` above.
  # NOTE(review): nested under predictor_params on that basis; confirm that
  # the consumer reads gvp_params from predictor_params and not from the
  # top level.
  gvp_params:
    n_layers: 5
    node_h_dim: [128, 32]  # (s, V)
    edge_h_dim: [128, 32]
    dropout: 0.0
    vector_gate: True