PIEthonista committed
Commit dff2967 · 1 Parent(s): 4deaa5d
Files changed (21)
  1. 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml +386 -0
  2. 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/args_135_iter_224808.pickle +3 -0
  3. 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/generative_model_135_iter_224808.npy +3 -0
  4. 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml +386 -0
  5. 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/args_75_iter_125628.pickle +3 -0
  6. 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/generative_model_75_iter_125628.npy +3 -0
  7. 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.9__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x_resume/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.9__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml +378 -0
  8. 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.9__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x_resume/args_75_iter_125628.pickle +3 -0
  9. 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.9__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x_resume/generative_model_75_iter_125628.npy +3 -0
  10. 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml +386 -0
  11. 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/args_75_iter_125628.pickle +3 -0
  12. 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/generative_model_75_iter_125628.npy +3 -0
  13. 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml +386 -0
  14. 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/args_75_iter_125628.pickle +3 -0
  15. 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/generative_model_75_iter_125628.npy +3 -0
  16. 03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A_2x_resume/03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A.yaml +381 -0
  17. 03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A_2x_resume/args_5_iter_56988.pickle +3 -0
  18. 03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A_2x_resume/generative_model_5_iter_56988.npy +3 -0
  19. 03_latent2_nf256_ds1k_fusReplace_CA__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A/03_latent2_nf256_ds1k_fusReplace_CA__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml +377 -0
  20. 03_latent2_nf256_ds1k_fusReplace_CA__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A/args_135_iter_224808.pickle +3 -0
  21. 03_latent2_nf256_ds1k_fusReplace_CA__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A/generative_model_135_iter_224808.npy +3 -0
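
Each experiment directory in this commit pairs a YAML training config with an args_*.pickle snapshot and a generative_model_*.npy checkpoint. A minimal sketch of inspecting one folder, assuming PyYAML is available and that (as in upstream GeoLDM) the *.npy checkpoints are torch.save()'d state dicts and the pickle is a dumped argparse.Namespace — both conventions are assumptions, not confirmed by this diff:

import pickle
import yaml
import torch

exp_dir = "03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume"

# The YAML holds the full training configuration shown in the diffs below.
with open(f"{exp_dir}/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml") as f:
    cfg = yaml.safe_load(f)
print(cfg["fusion_mode"], cfg["fusion_weights"])

# The pickle snapshots the exact args the run was launched with (assumed Namespace).
with open(f"{exp_dir}/args_135_iter_224808.pickle", "rb") as f:
    args = pickle.load(f)

# Despite the .npy suffix, the checkpoint is assumed to be a torch state dict.
state_dict = torch.load(f"{exp_dir}/generative_model_135_iter_224808.npy", map_location="cpu")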
03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml ADDED
@@ -0,0 +1,386 @@
+ proj_name: Control-GeoLDM
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_2x
+ exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x
+
+ # ========================================================================================================== Training Mode (ldm/vae/both)
+ # Train the second-stage LatentDiffusionModel
+ train_diffusion: true
+
+ # training mode: VAE | LDM | ControlNet
+ training_mode: ControlNet
+ loss_analysis: false
+
+ # Specify ligand & pocket VAE weights paths, set to null for random initialisation
+ # set checkpoint (ckpt) to null to automatically select the best
+ ligand_ae_path: outputs_selected/vae_ligands/AMP__01_VAE_vaenorm_True10__bfloat16__latent2_nf256_epoch100_bs36_lr1e-4_InvClassFreq_Smooth0.25_x10_h5_NoEMA__DecOnly_KL-0__20240623__10A__LG_Only
+ ligand_ae_ckpt: generative_model_2_iter_6336.npy
+ pocket_ae_path: outputs_selected/vae_pockets/AMP__01_VAE_vaenorm_True10__float32__latent2_nf256_epoch100_bs12_lr1e-4_InvClassFreq_Smooth0.25_XH_x30_h15_NoEMA__20240623__10A__PKT_CA_Only
+ pocket_ae_ckpt: generative_model_3_iter_33308.npy
+
+ # Specify LDM weights path, set to null for random initialisation
+ ldm_path: outputs_selected/ldm/AMP__02_LDM_vaenorm_True10__float32__latent2_nf256_epoch200_bs36_lr1e-4_NoEMA__VAE_DecOnly_KL-0__20240623__10A_9x_resume
+ ldm_ckpt: generative_model_108_iter_230208.npy
+
+ # Zero out all fusion-block weights instead of initialising them randomly
+ zero_fusion_block_weights: false
+
+ # Train 1st-stage AutoEncoder model (no effect if train_diffusion=False)
+ trainable_ligand_ae_encoder: false
+ trainable_ligand_ae_decoder: false
+ trainable_pocket_ae_encoder: false
+
+ # Train 2nd-stage LDM model
+ trainable_ldm: false
+
+ # Train 3rd-stage ControlNet
+ trainable_controlnet: true
+ trainable_fusion_blocks: true
+
+ # can contain multiple: homo | onehot | lumo | num_atoms | etc
+ conditioning: []
+
+ # include atom charge, according to the periodic table
+ include_charges: false # true for qm9
+
+ # only works for the LDM, not for the VAE
+ condition_time: true
+
+ # Time Noisy, t/2, adopted from [https://arxiv.org/abs/2405.06659]
+ time_noisy: false
+
+ vis_activations: false
+ vis_activations_batch_samples: 5
+ vis_activations_batch_size: 1
+ vis_activations_specific_ylim: [0, 40]
+
+ # random_seed: 0
+ random_seed: 42
+
+ # ========================================================================================================== Dataset
+
+ # pre-computed dataset stats
+ dataset: d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only__10A__LIGAND
+
+ # pre-computed training dataset
+ data_file: ./data/d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only/d_20241203_CrossDocked_LG_PKT_MMseq2_split__10.0A__CA_Only.npz
+ data_splitted: true
+
+ # Quick Vina 2.1
+ compute_qvina: true
+ qvina_search_size: 20 # search size (all 3 axes) in Angstroms around ligand center
+ qvina_exhaustiveness: 16
+ qvina_seed: 42
+ qvina_cleanup_files: true # clean up tmp pdb, pdbqt files
+ qvina_save_csv: true # save results in csv
+ pocket_pdb_dir: ./data/d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only/test_val_paired_files/val_pocket
+ match_raw_file_by_id: true
+ mgltools_env_name: mgltools-python2 # for pdb -> pdbqt conversion
+
+ ligand_add_H: false # add hydrogens via: [mgltools] prepare_ligand4.py -l .. -o .. -A hydrogens
+ pocket_add_H: false # add hydrogens via: [mgltools] prepare_receptor4.py -r .. -o .. -A checkhydrogens
+ pocket_remove_nonstd_resi: false # remove any pocket residues not in this list:
+ # ['CYS','ILE','SER','VAL','GLN','LYS','ASN',
+ #  'PRO','THR','PHE','ALA','HIS','GLY','ASP',
+ #  'LEU','ARG','TRP','GLU','TYR','MET',
+ #  'HID','HSP','HIE','HIP','CYX','CSS']
+
+ # Set to null if you're running this dataset for the first time;
+ # the script will generate a random permutation to shuffle the dataset.
+ # Please set the path to the DATASET_permutation.npy file after it is generated.
+ # permutation_file_path: ./data/d_20240623_CrossDocked_LG_PKT/d_20240623_CrossDocked_LG_PKT__10.0A_LG100_PKT600_permutation.npy
+ permutation_file_path: null
+
+ # what data to load for VAE training: ligand | pocket | all
+ vae_data_mode: ligand
+
+ # When set to an integer value, QM9 will only contain molecules with that number of atoms, default null
+ filter_n_atoms: null
+
+ # Only use molecules below this size. Int, default null ~!geom
+ filter_molecule_size: 100
+ filter_pocket_size: 80
+
+ # Organize data by size to reduce average memory usage. ~!geom
+ sequential: false
+
+ # Number of workers for the dataloader
+ num_workers: 32 # match cpu count
+
+ # use data augmentation (i.e. random rotation of x atom coordinates)
+ data_augmentation: false
+
+ # remove hydrogen atoms
+ remove_h: false
+
+ # ========================================================================================================== Training Params
+ start_epoch: 0
+ test_epochs: 5 # 4
+
+ n_epochs: 1000 # 3000 takes 20 epochs in the paper (bs:32), hence 80 epochs for bs:8
+ batch_size: 60 # 14
+ lr: 1.0e-4
+
+ # weight of the KL term in the ELBO, default 0.01
+ kl_weight: 0.0
+
+ # ode_regularization weightage, default 1e-3
+ ode_regularization: 0.001
+ # brute_force: false
+ # actnorm: true
+ break_train_epoch: false
+
+ # Data Parallel for multi-GPU support
+ dp: true
+ clip_grad: true
+
+ # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
+ ema_decay: 0 # 0.99
+
+ # add noise to x before encoding (data augmentation)
+ augment_noise: 0
+
+ # Number of samples for computing stability, default 500
+ n_stability_samples: 90 # 98, 50
+ n_stability_samples_batch_size: 10 # 7, 14
+
+ # Dataset partition that pocket samples are drawn from when analyzing
+ # generated ligands' stability: train | test | val
+ n_stability_eval_split: val
+
+ # disables CUDA training
+ no_cuda: false
+
+ # hutch | exact
+ trace: hutch
+
+ # verbose logging
+ verbose: false
+
+ dtype: torch.float32
+
+ # enable mixed-precision training (fp32, fp16)
+ mixed_precision_training: true
+ mixed_precision_autocast_dtype: torch.bfloat16
+
+ # use model checkpointing during training to reduce GPU memory usage
+ use_checkpointing: true
+
+ # sqrt: checkpointing is done on the sqrt(block_num)'th EquivariantBlock of each EGNN for optimal perf
+ # all: checkpointing is done on all EquivariantBlocks. Not optimal, but helps if the input size is too large
+ checkpointing_mode: sqrt
+
+ # splits tensors into manageable chunks and performs forward propagation without breaking the GPU memory limit
+ forward_tensor_chunk_size: 50000
+
+ # ========================================================================================================== LDM
+ # our_dynamics | schnet | simple_dynamics | kernel_dynamics | egnn_dynamics | gnn_dynamics
+ model: egnn_dynamics
+
+ probabilistic_model: diffusion
+
+ # Training complexity is O(1) (unaffected), but sampling complexity is O(steps), default 500
+ diffusion_steps: 1000
+
+ # learned, cosine, polynomial_<power>
+ diffusion_noise_schedule: polynomial_2
+
+ # default 1e-5
+ diffusion_noise_precision: 1.0e-05 # ~!fp16
+
+ # vlb | l2
+ diffusion_loss_type: l2
+
+ # number of latent features, default 4
+ latent_nf: 2
+
+ # normalize factors for [x, h_cat/categorical/one-hot, h_int/integer/charges]
+ normalize_factors: [1, 4, 10]
+
+ vae_normalize_x: true
+ vae_normalize_method: scale # scale | linear
+ vae_normalize_factors: [10, 1, 1]
+
+ reweight_class_loss: "inv_class_freq"
+ smoothing_factor: 0.25 # [0.1 - 1.0); 1.0 essentially disables it
+
+ error_x_weight: 10 # error_x custom weighting
+ error_h_weight: 5
+
+ # ========================================================================================================== Network Architecture
+
+ # number of EquivariantBlock layers to use in the VAE's Encoder
+ encoder_n_layers: 1
+
+ # number of EquivariantBlock layers to use in the LDM and the VAE's Decoder
+ n_layers: 4
+
+ # number of GCL blocks to use in each EquivariantBlock
+ inv_sublayers: 1
+
+ # model's internal operating number of features
+ nf: 256
+
+ # use tanh in the coord_mlp
+ tanh: true
+
+ # use attention in the EGNN
+ attention: true
+
+ # diff/(|diff| + norm_constant)
+ norm_constant: 1
+
+ # whether or not to use the sin embedding
+ sin_embedding: false
+
+ # uniform | variational | argmax_variational | deterministic
+ dequantization: argmax_variational
+
+ # Normalize the sum aggregation of EGNN
+ normalization_factor: 1
+
+ # EGNN aggregation method: sum | mean
+ aggregation_method: sum
+
+ # Fusion Block specific settings
+ fusion_weights: [0.1, 0.1, 0.1, 0.1] # [0.25, 0.5, 0.75, 1]
+ # Condition fusion method:
+ # - scaled_sum   : (h1_i,x1_i) = (h1_i,x1_i) + w_i * (f_h1_i,f_x1_i)
+ # - balanced_sum : (h1_i,x1_i) = [(1 - w_i) * (h1_i,x1_i)] + [w_i * (f_h1_i,f_x1_i)]
+ # - replace      : (h1_i,x1_i) = (f_h1_i,f_x1_i)
+ fusion_mode: balanced_sum
+
+ # Initial Noise Injection / Feedback Mechanism
+ noise_injection_weights: [0.5, 0.5] # pkt = w[0]*lg + w[1]*pkt
+ noise_injection_aggregation_method: mean # mean | sum
+ noise_injection_normalization_factor: 1 # aggregation normalization factor
+
+ # ========================================================================================================== Logging
+ # Can be used to visualize multiple times per epoch, default 1e8
+ visualize_sample_chain: true
+ visualize_every_batch: 20000
+ visualize_sample_chain_epochs: 2 # for the 1% testing dataset; set to 1 otherwise
+ n_report_steps: 50
+
+ # ========================================================================================================== Saving & Resuming
+ # resume: null
+ # resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
+ # resume_model_ckpt: generative_model_8_iter_14049.npy
+ # resume_optim_ckpt: optim_8_iter_14049.npy
+ # resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x_resume
+ # resume_model_ckpt: generative_model_84_iter_140505.npy
+ # resume_optim_ckpt: optim_84_iter_140505.npy
+ resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_2x_resume
+ resume_model_ckpt: generative_model_76_iter_127281.npy
+ resume_optim_ckpt: optim_76_iter_127281.npy
+
+ save_model: true
+
+ # ========================================================================================================== Wandb
+ # disable wandb
+ no_wandb: false
+ wandb_usr: gohyixian456
+ # True = wandb online -- False = wandb offline
+ online: true
+
+ pocket_vae:
+   dataset: d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only__10A__POCKET
+   vae_data_mode: pocket
+   remove_h: false
+   ca_only: true
+
+   # can contain multiple: homo | onehot | lumo | num_atoms | etc
+   conditioning: []
+
+   # egnn_dynamics
+   model: egnn_dynamics
+
+   # include atom charge, according to the periodic table
+   include_charges: false
+
+   # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
+   ema_decay: 0
+
+   # weight of the KL term in the ELBO, default 0.01
+   kl_weight: 0.01
+
+   # number of latent features, default 4 (must match the ligand VAE's & LDM's latent_nf)
+   latent_nf: 2
+
+   # number of EquivariantBlock layers to use in the VAE's Encoder
+   encoder_n_layers: 1
+
+   # number of EquivariantBlock layers to use in the VAE's Decoder
+   n_layers: 4
+
+   # number of GCL blocks to use in each EquivariantBlock
+   inv_sublayers: 1
+
+   # model's internal operating number of features
+   nf: 256
+
+   # use tanh in the coord_mlp
+   tanh: true
+
+   # use attention in the EGNN
+   attention: true
+
+   # diff/(|diff| + norm_constant)
+   norm_constant: 1
+
+   # whether or not to use the sin embedding
+   sin_embedding: false
+
+   # uniform | variational | argmax_variational | deterministic
+   dequantization: argmax_variational
+
+   # Normalize the sum aggregation of EGNN
+   normalization_factor: 1
+
+   # EGNN aggregation method: sum | mean
+   aggregation_method: sum
+
+   # normalize factors for [x, h_cat/categorical/one-hot, h_int/integer/charges]
+   normalize_factors: [1, 4, 10]
+
+   vae_normalize_x: true
+   vae_normalize_method: scale # scale | linear
+   vae_normalize_factors: [10, 1, 1]
+
+   reweight_class_loss: "inv_class_freq"
+   reweight_coords_loss: "inv_class_freq"
+   smoothing_factor: 0.25 # [0.1 - 1.0); 1.0 essentially disables it
+
+   error_x_weight: 30
+   error_h_weight: 15
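
The three fusion rules in the comments above translate directly into code. A minimal sketch (hypothetical helper, not from the repo), with w taken per conditioned block from fusion_weights:

import torch

def fuse(h, x, f_h, f_x, w: float, mode: str = "balanced_sum"):
    """Apply one fusion block's output (f_h, f_x) to the LDM stream (h, x),
    following the three rules documented in the config comments above."""
    if mode == "scaled_sum":    # (h,x) = (h,x) + w * (f_h,f_x)
        return h + w * f_h, x + w * f_x
    if mode == "balanced_sum":  # (h,x) = (1-w)*(h,x) + w*(f_h,f_x)
        return (1 - w) * h + w * f_h, (1 - w) * x + w * f_x
    if mode == "replace":       # (h,x) = (f_h,f_x)
        return f_h, f_x
    raise ValueError(f"unknown fusion_mode: {mode}")

With fusion_weights: [0.1, 0.1, 0.1, 0.1] and balanced_sum, every conditioned block keeps 90% of the frozen LDM stream and mixes in 10% of the ControlNet stream.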
03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/args_135_iter_224808.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:81c69beda6fc5c42a0523fcedb22afeed8eb170c216b15215bdd705627b6e317
+ size 5710
03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/generative_model_135_iter_224808.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:041b33cb54d1628bf5516aa1ae250dd8c7f631835463dc86bc2ce00becebbe0d
+ size 53576312
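
The pickle and .npy entries in this commit are Git LFS pointers, not the binaries themselves; the three version/oid/size lines above are the entire file content. A small sketch for parsing that pointer format (hypothetical helper, assuming an un-smudged checkout where the pointer text is what's on disk):

def parse_lfs_pointer(path: str) -> dict:
    """Read a Git LFS pointer file into {'version': ..., 'oid': ..., 'size': ...}."""
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

info = parse_lfs_pointer("generative_model_135_iter_224808.npy")  # pointer, not the real weights
assert info["oid"].startswith("sha256:") and int(info["size"]) > 0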
03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml ADDED
@@ -0,0 +1,386 @@
+ proj_name: Control-GeoLDM
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_2x
+ exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x

[The rest of this file is identical to the conditionAll_0.1 config above, apart from the following lines.]

+ fusion_weights: [0.5, 0.5, 0.5, 0.5] # [0.25, 0.5, 0.75, 1]

+ # resume: null
+ # resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
+ # resume_model_ckpt: generative_model_8_iter_14049.npy
+ # resume_optim_ckpt: optim_8_iter_14049.npy
+ # resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x_resume
+ # resume_model_ckpt: generative_model_84_iter_140505.npy
+ # resume_optim_ckpt: optim_84_iter_140505.npy
+ resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_2x_resume
+ resume_model_ckpt: generative_model_75_iter_125628.npy
+ resume_optim_ckpt: optim_75_iter_125628.npy
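
All of these configs pair diffusion_steps: 1000 with diffusion_noise_schedule: polynomial_2 and diffusion_noise_precision: 1.0e-05. A sketch of an EDM/GeoLDM-style polynomial schedule under those settings; the repo's exact clipping details are an assumption here, not taken from this diff:

import numpy as np

def polynomial_schedule(timesteps: int = 1000, s: float = 1.0e-5, power: float = 2.0):
    """Squared signal level alpha^2 per step for a polynomial_<power> schedule."""
    steps = timesteps + 1
    t = np.linspace(0, steps, steps)
    alphas2 = (1 - (t / steps) ** power) ** 2  # decays from 1 toward 0
    precision = 1 - 2 * s                      # keep values inside [s, 1 - s]
    return precision * alphas2 + s

alphas2 = polynomial_schedule(1000, s=1.0e-5, power=2.0)  # matches the config values above
sigmas2 = 1 - alphas2                                     # variance-preserving noise level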
03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/args_75_iter_125628.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d8916a0dbab4f3c448ffd3c169d7f4b6cb5fc1164a4c853d4b061027f4bed00c
+ size 5710
03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/generative_model_75_iter_125628.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ce0e20a07221263b534873efc0116dc2b946a992668f37f25cdadb261aca2040
+ size 53575942
03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.9__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x_resume/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.9__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml ADDED
@@ -0,0 +1,378 @@
+ proj_name: Control-GeoLDM
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.9__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
+ exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.9__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x

[The rest of this file is identical to the conditionAll_0.1 config above, apart from the following lines.]

+ fusion_weights: [0.9, 0.9, 0.9, 0.9] # [0.25, 0.5, 0.75, 1]

+ # resume: null
+ resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.9__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
+ resume_model_ckpt: generative_model_80_iter_133893.npy
+ resume_optim_ckpt: optim_80_iter_133893.npy
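
Across the three conditionAll files, only the shared balanced_sum fusion weight changes (0.1, 0.5, 0.9 on all four blocks), so the sweep interpolates between the frozen LDM stream and the ControlNet stream. The noise-injection comment pkt = w[0]*lg + w[1]*pkt likewise pins down the feedback step; a minimal sketch with hypothetical tensor names:

import torch

def inject_feedback(lg_latent: torch.Tensor, pkt_latent: torch.Tensor,
                    w=(0.5, 0.5)) -> torch.Tensor:
    """Initial Noise Injection / Feedback Mechanism from the configs:
    blend the ligand latent into the pocket latent with noise_injection_weights."""
    return w[0] * lg_latent + w[1] * pkt_latent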
03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.9__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x_resume/args_75_iter_125628.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:87d55dd84aac61dbcb37450676b8999cab749933af7ee14565ccc0349aeae494
+ size 5700
03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.9__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x_resume/generative_model_75_iter_125628.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c234296d787f7121cf40b9f958a3e4714a899c122a37de7d3a61fbb9555f8653
+ size 53575942
03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml ADDED
@@ -0,0 +1,386 @@
1
+ proj_name: Control-GeoLDM
2
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
3
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x
4
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_2x
5
+ exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x
6
+
7
+
8
+ # ========================================================================================================== Training Mode (ldm/vae/both)
9
+ # Train second stage LatentDiffusionModel model
10
+ train_diffusion: true
11
+
12
+ # training mode: VAE | LDM | ControlNet
13
+ training_mode: ControlNet
14
+ loss_analysis: false
15
+
16
+ # Specify ligand & pocket VAE weights path, set to null for random initialisation
17
+ # set checkpoint (ckpt) to null to automatically select best
18
+ ligand_ae_path: outputs_selected/vae_ligands/AMP__01_VAE_vaenorm_True10__bfloat16__latent2_nf256_epoch100_bs36_lr1e-4_InvClassFreq_Smooth0.25_x10_h5_NoEMA__DecOnly_KL-0__20240623__10A__LG_Only
19
+ ligand_ae_ckpt: generative_model_2_iter_6336.npy
20
+ pocket_ae_path: outputs_selected/vae_pockets/AMP__01_VAE_vaenorm_True10__float32__latent2_nf256_epoch100_bs12_lr1e-4_InvClassFreq_Smooth0.25_XH_x30_h15_NoEMA__20240623__10A__PKT_CA_Only
21
+ pocket_ae_ckpt: generative_model_3_iter_33308.npy
22
+
23
+ # Specify LDM weights path, set to null for random initialisation
24
+ ldm_path: outputs_selected/ldm/AMP__02_LDM_vaenorm_True10__float32__latent2_nf256_epoch200_bs36_lr1e-4_NoEMA__VAE_DecOnly_KL-0__20240623__10A_9x_resume
25
+ ldm_ckpt: generative_model_108_iter_230208.npy
26
+
27
+ # Zero out all weights of fusion blocks instead of randomly instantiated
28
+ zero_fusion_block_weights: false
29
+
30
+
31
+ # Train 1st stage AutoEncoder model (no effect if train_diffusion=False)
32
+ trainable_ligand_ae_encoder: false
33
+ trainable_ligand_ae_decoder: false
34
+ trainable_pocket_ae_encoder: false
35
+
36
+ # Train 2nd stage LDM model
37
+ trainable_ldm: false
38
+
39
+ # Train 3rd stage ControlNet
40
+ trainable_controlnet: true
41
+ trainable_fusion_blocks: true
42
+
43
+
44
+ # can contain multiple: homo | onehot | lumo | num_atoms | etc
45
+ conditioning: []
46
+
47
+ # include atom charge, according to periodic table
48
+ include_charges: false # true for qm9
49
+
50
+ # only works for ldm, not for VAE
51
+ condition_time: true
52
+
53
+ # Time Noisy, t/2, adopted from [https://arxiv.org/abs/2405.06659]
54
+ time_noisy: false
55
+
56
+ vis_activations: false
57
+ vis_activations_batch_samples: 5
58
+ vis_activations_batch_size: 1
59
+ vis_activations_specific_ylim: [0, 40]
60
+
61
+ # random_seed: 0
62
+ random_seed: 42
63
+
64
+
65
+ # ========================================================================================================== Dataset
66
+
67
+ # pre-computed dataset stats
68
+ dataset: d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only__10A__LIGAND
69
+
70
+ # pre-computed training dataset
71
+ data_file: ./data/d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only/d_20241203_CrossDocked_LG_PKT_MMseq2_split__10.0A__CA_Only.npz
72
+ data_splitted: true
73
+
74
+ # Quick Vina 2.1
75
+ compute_qvina: true
76
+ qvina_search_size: 20 # search size (all 3 axes) in Angstroms around ligand center
77
+ qvina_exhaustiveness: 16
78
+ qvina_seed: 42
79
+ qvina_cleanup_files: true # cleanup tmp pdb, pdbqt files
80
+ qvina_save_csv: true # save results in csv
81
+ pocket_pdb_dir: ./data/d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only/test_val_paired_files/val_pocket
82
+ match_raw_file_by_id: true
83
+ mgltools_env_name: mgltools-python2 # for pdb -> pdbqt conversion
84
+
85
+ ligand_add_H: false # add hydrogens via: [mgltools] prepare_ligand4.py -l .. -o .. -A hydrogens
86
+ pocket_add_H: false # add hydrogens via: [mgltools] prepare_receptor4.py -r .. -o .. -A checkhydrogens
87
+ pocket_remove_nonstd_resi: false # remove any pocket residues not in this list:
88
+ # ['CYS','ILE','SER','VAL','GLN','LYS','ASN',
89
+ # 'PRO','THR','PHE','ALA','HIS','GLY','ASP',
90
+ # 'LEU', 'ARG', 'TRP', 'GLU', 'TYR','MET',
91
+ # 'HID', 'HSP', 'HIE', 'HIP', 'CYX', 'CSS']
92
+
93
+
94
+ # set to null if you're running this dataset for the first time.
95
+ # Script will generate a random permutation to shuffle the dataset.
96
+ # Please set the path to the DATASET_permutation.npy file after it is generated.
97
+ # permutation_file_path: ./data/d_20240623_CrossDocked_LG_PKT/d_20240623_CrossDocked_LG_PKT__10.0A_LG100_PKT600_permutation.npy
98
+ permutation_file_path: null
99
+
100
+ # what data to load for VAE training: ligand | pocket | all
101
+ vae_data_mode: ligand
102
+
103
+ # When set to an integer value, QM9 will only contain molecules of that amount of atoms, default null
104
+ filter_n_atoms: null
105
+
106
+ # Only use molecules below this size. Int, default null ~!geom
107
+ filter_molecule_size: 100
108
+ filter_pocket_size: 80
109
+
110
+ # Organize data by size to reduce average memory usage. ~!geom
111
+ sequential: false
112
+
113
+ # Number of worker for the dataloader
114
+ num_workers: 32 # match cpu count
115
+
116
+ # use data augmentation (i.e. random rotation of x atom coordinates)
117
+ data_augmentation: false
118
+
119
+ # remove hydrogen atoms
120
+ remove_h: false
121
+
122
+
123
+
124
+
125
+ # ========================================================================================================== Training Params
126
+ start_epoch: 0
127
+ test_epochs: 5 # 4
128
+
129
+
130
+ n_epochs: 1000 # 3000 takes 20 epoches on paper (bs:32), hence 80 epochs for bs:8
131
+ batch_size: 60 # 14
132
+ lr: 1.0e-4
133
+
134
+ # weight of KL term in ELBO, default 0.01
135
+ kl_weight: 0.0
136
+
137
+ # ode_regularization weightage, default 1e-3
138
+ ode_regularization: 0.001
139
+ # brute_force: false
140
+ # actnorm: true
141
+ break_train_epoch: false
142
+
143
+ # Data Parallel for multi GPU support
144
+ dp: true
145
+ clip_grad: true
146
+
147
+ # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
148
+ ema_decay: 0 # 0.99
149
+
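With ema_decay: 0 the EMA branch is skipped entirely; for a nonzero decay d, the usual update (a generic sketch, not this repo's exact helper) is p_ema <- d * p_ema + (1 - d) * p:

import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    # blend the current weights into the shadow copy used for evaluation
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)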
150
+ # add noise to x before encoding, data augmenting
151
+ augment_noise: 0
152
+
153
+ # Number of samples to compute the stability, default 500
154
+ n_stability_samples: 90 # 98, 50
155
+ n_stability_samples_batch_size: 10 # 7, 14
156
+
157
+ # Dataset partition where pocket samples will be drawn from for analyzing
158
+ # generated ligands' stability: train | test | val
159
+ n_stability_eval_split: val
160
+
161
+
162
+ # disables CUDA training
163
+ no_cuda: false
164
+
165
+ # hutch | exact
166
+ trace: hutch
167
+
168
+ # verbose logging
169
+ verbose: false
170
+
171
+ dtype: torch.float32
172
+
173
+ # enable mixed precision training (fp32, fp16)
174
+ mixed_precision_training: true
175
+ mixed_precision_autocast_dtype: torch.bfloat16
176
+
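The two settings above correspond to keeping fp32 master weights (dtype) while running the forward pass under a bfloat16 autocast region; a minimal sketch, where model, batch, and compute_loss are placeholders:

import torch

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = compute_loss(model, batch)   # matmuls run in bf16 inside this region
loss.backward()                         # gradients land on the fp32 parameters

Unlike fp16, bf16 autocast normally needs no GradScaler, since its exponent range matches fp32.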
177
+ # use model checkpointing during training to reduce GPU memory usage
178
+ use_checkpointing: true
179
+
180
+ # sqrt: checkpointing is done on the sqrt(block_num)'th Equivariant block of each EGNN for the best performance
181
+ # all: checkpointing is done on all Equivariant blocks. Not optimal but helps if input size is too large
182
+ checkpointing_mode: sqrt
183
+
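One plausible reading of the sqrt mode (an assumption; the repo may slice the blocks differently) is to re-materialise activations only at every sqrt(n)-th EquivariantBlock, trading compute for memory, while "all" degenerates to checkpointing every block:

import math
import torch.utils.checkpoint as cp

def run_blocks(blocks, h, x, mode="sqrt"):
    stride = max(1, int(math.sqrt(len(blocks)))) if mode == "sqrt" else 1
    for i, block in enumerate(blocks):
        if i % stride == 0:
            # recompute this block's activations during backward instead of storing them
            h, x = cp.checkpoint(block, h, x, use_reentrant=False)
        else:
            h, x = block(h, x)
    return h, x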
184
+ # splits tensors into manageable chunks and performs forward propagation without exceeding the GPU memory limit
185
+ forward_tensor_chunk_size: 50000
186
+
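forward_tensor_chunk_size bounds peak activation memory by slicing along one dimension; a sketch of the idea, valid when the module treats rows independently (the helper name is illustrative):

import torch

def chunked_forward(module, tensor, chunk_size=50000, dim=0):
    # run the module slice by slice, then stitch the outputs back together
    outputs = [module(chunk) for chunk in torch.split(tensor, chunk_size, dim=dim)]
    return torch.cat(outputs, dim=dim)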
187
+
188
+
189
+
190
+
191
+
192
+
193
+
194
+ # ========================================================================================================== LDM
195
+ # our_dynamics | schnet | simple_dynamics | kernel_dynamics | egnn_dynamics | gnn_dynamics
196
+ model: egnn_dynamics
197
+
198
+ probabilistic_model: diffusion
199
+
200
+ # Training complexity is O(1) (unaffected), but sampling complexity is O(steps), default 500
201
+ diffusion_steps: 1000
202
+
203
+ # learned, cosine, polynomial_<power>
204
+ diffusion_noise_schedule: polynomial_2
205
+
206
+ # default 1e-5
207
+ diffusion_noise_precision: 1.0e-05 # ~!fp16
208
+
209
+ # vlb | l2
210
+ diffusion_loss_type: l2
211
+
212
+ # number of latent features, default 4
213
+ latent_nf: 2
214
+
215
+ # normalize factors for [x, h_cat/categorical/one-hot, h_int/integer/charges]
216
+ normalize_factors: [1, 4, 10]
217
+
218
+ vae_normalize_x: true
219
+ vae_normalize_method: scale # scale | linear
220
+ vae_normalize_factors: [10, 1, 1]
221
+
222
+ reweight_class_loss: "inv_class_freq"
223
+ smoothing_factor: 0.25 # range [0.1, 1.0); 1.0 essentially disables smoothing
224
+
225
+ error_x_weight: 10 # error_x custom weighting
226
+ error_h_weight: 5
227
+
228
+
229
+ # ========================================================================================================== Network Architecture
230
+
231
+ # number of layers of EquivariantBlock to use in VAE's Encoder
232
+ encoder_n_layers: 1
233
+
234
+ # number of layers of EquivariantBlock to use in LDM and VAE's Decoder
235
+ n_layers: 4
236
+
237
+ # number of GCL Blocks to use in each EquivariantBlock
238
+ inv_sublayers: 1
239
+
240
+ # model's internal operating number of features
241
+ nf: 256
242
+
243
+ # use tanh in the coord_mlp
244
+ tanh: true
245
+
246
+ # use attention in the EGNN
247
+ attention: true
248
+
249
+ # diff/(|diff| + norm_constant)
250
+ norm_constant: 1
251
+
252
+ # whether or not to use the sin embedding
253
+ sin_embedding: false
254
+
255
+ # uniform | variational | argmax_variational | deterministic
256
+ dequantization: argmax_variational
257
+
258
+ # Normalize the sum aggregation of EGNN
259
+ normalization_factor: 1
260
+
261
+ # EGNN aggregation method: sum | mean
262
+ aggregation_method: sum
263
+
264
+
265
+ # Fusion Block specific settings
266
+ fusion_weights: [0, 0, 0.1, 0.1] # [0.25, 0.5, 0.75, 1]
267
+ # Condition fusion method:
268
+ # - scaled_sum : (h1_i,x1_i) = (h1_i,x1_i) + w_i * (f_h1_i,f_x1_i)
269
+ # - balanced_sum : (h1_i,x1_i) = [(1 - w_i) * (h1_i,x1_i)] + [w_i * (f_h1_i,f_x1_i)]
270
+ # - replace : (h1_i,x1_i) = (f_h1_i,f_x1_i)
271
+ fusion_mode: balanced_sum
272
+
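The three fusion formulas above reduce to a few lines; with fusion_weights [0, 0, 0.1, 0.1] under balanced_sum, blocks 1-2 (w=0) pass the trunk through untouched and blocks 3-4 blend in 10% of the ControlNet features. A sketch restating the comment (not the repo's fusion block module):

def fuse(h1, x1, f_h1, f_x1, w, mode="balanced_sum"):
    if mode == "scaled_sum":      # trunk plus weighted condition
        return h1 + w * f_h1, x1 + w * f_x1
    if mode == "balanced_sum":    # convex blend of trunk and condition
        return (1 - w) * h1 + w * f_h1, (1 - w) * x1 + w * f_x1
    if mode == "replace":         # condition overrides the trunk
        return f_h1, f_x1
    raise ValueError(f"unknown fusion_mode: {mode}")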
273
+ # Initial Noise Injection / Feedback Mechanism
274
+ noise_injection_weights: [0.5, 0.5] # pkt = w[0]*lg + w[1]*pkt
275
+ noise_injection_aggregation_method: mean # mean | sum
276
+ noise_injection_normalization_factor: 1 # aggregation normalization factor
277
+
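The feedback mechanism is just the weighted sum from the comment, applied to the initial pocket noise; as a one-line sketch:

def inject_noise(z_lg, z_pkt, w=(0.5, 0.5)):
    return w[0] * z_lg + w[1] * z_pkt   # pkt = w[0]*lg + w[1]*pkt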
278
+
279
+
280
+
281
+ # ========================================================================================================== Logging
282
+ # Can be used to visualize multiple times per epoch, default 1e8
283
+ visualize_sample_chain: true
284
+ visualize_every_batch: 20000
285
+ visualize_sample_chain_epochs: 2 # for 1% testing dataset, others set to 1
286
+ n_report_steps: 50
287
+
288
+
289
+
290
+
291
+ # ========================================================================================================== Saving & Resuming
292
+ # resume: null
293
+ # resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
294
+ # resume_model_ckpt: generative_model_8_iter_14049.npy
295
+ # resume_optim_ckpt: optim_8_iter_14049.npy
296
+ # resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x_resume
297
+ # resume_model_ckpt: generative_model_82_iter_137199.npy
298
+ # resume_optim_ckpt: optim_82_iter_137199.npy
299
+ resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_2x_resume
300
+ resume_model_ckpt: generative_model_75_iter_125628.npy
301
+ resume_optim_ckpt: optim_75_iter_125628.npy
302
+
303
+ save_model: true
304
+
305
+
306
+
307
+ # ========================================================================================================== Wandb
308
+ # disable wandb
309
+ no_wandb: false
310
+ wandb_usr: gohyixian456
311
+ # True = wandb online -- False = wandb offline
312
+ online: true
313
+
314
+
315
+
316
+
317
+ pocket_vae:
318
+ dataset: d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only__10A__POCKET
319
+ vae_data_mode: pocket
320
+ remove_h: false
321
+ ca_only: true
322
+
323
+ # can contain multiple: homo | onehot | lumo | num_atoms | etc
324
+ conditioning: []
325
+
326
+ # egnn_dynamics
327
+ model: egnn_dynamics
328
+
329
+ # include atom charge, according to periodic table
330
+ include_charges: false
331
+
332
+ # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
333
+ ema_decay: 0
334
+
335
+ # weight of KL term in ELBO, default 0.01
336
+ kl_weight: 0.01
337
+
338
+ # number of latent features, default 4 (have to match ligand VAE & LDM's latent_nf)
339
+ latent_nf: 2
340
+
341
+ # number of layers of EquivariantBlock to use in VAE's Encoder
342
+ encoder_n_layers: 1
343
+
344
+ # number of layers of EquivariantBlock to use in VAE's Decoder
345
+ n_layers: 4
346
+
347
+ # number of GCL Blocks to use in each EquivariantBlock
348
+ inv_sublayers: 1
349
+
350
+ # model's internal operating number of features
351
+ nf: 256
352
+
353
+ # use tanh in the coord_mlp
354
+ tanh: true
355
+
356
+ # use attention in the EGNN
357
+ attention: true
358
+
359
+ # diff/(|diff| + norm_constant)
360
+ norm_constant: 1
361
+
362
+ # whether or not to use the sin embedding
363
+ sin_embedding: false
364
+
365
+ # uniform | variational | argmax_variational | deterministic
366
+ dequantization: argmax_variational
367
+
368
+ # Normalize the sum aggregation of EGNN
369
+ normalization_factor: 1
370
+
371
+ # EGNN aggregation method: sum | mean
372
+ aggregation_method: sum
373
+
374
+ # normalize factors for [x, h_cat/categorical/one-hot, h_int/integer/charges]
375
+ normalize_factors: [1, 4, 10]
376
+
377
+ vae_normalize_x: true
378
+ vae_normalize_method: scale # scale | linear
379
+ vae_normalize_factors: [10, 1, 1]
380
+
381
+ reweight_class_loss: "inv_class_freq"
382
+ reweight_coords_loss: "inv_class_freq"
383
+ smoothing_factor: 0.25 # range [0.1, 1.0); 1.0 essentially disables smoothing
384
+
385
+ error_x_weight: 30
386
+ error_h_weight: 15
03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/args_75_iter_125628.pickle ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a37e79ffff388edaaf86eed139cf88521b7087099e25135cda14c0a17747d81
3
+ size 5706
03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.1__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/generative_model_75_iter_125628.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:710e0ed2b42e05e78713e9918ed196a14e1ef86b6aa5f7be7e47268c779eb7af
3
+ size 53575942
03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml ADDED
@@ -0,0 +1,386 @@
1
+ proj_name: Control-GeoLDM
2
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
3
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x
4
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_2x
5
+ exp_name: 03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x
6
+
7
+
8
+ # ========================================================================================================== Training Mode (VAE/LDM/ControlNet)
9
+ # Train 2nd stage LatentDiffusionModel
10
+ train_diffusion: true
11
+
12
+ # training mode: VAE | LDM | ControlNet
13
+ training_mode: ControlNet
14
+ loss_analysis: false
15
+
16
+ # Specify ligand & pocket VAE weights path, set to null for random initialisation
17
+ # set checkpoint (ckpt) to null to automatically select best
18
+ ligand_ae_path: outputs_selected/vae_ligands/AMP__01_VAE_vaenorm_True10__bfloat16__latent2_nf256_epoch100_bs36_lr1e-4_InvClassFreq_Smooth0.25_x10_h5_NoEMA__DecOnly_KL-0__20240623__10A__LG_Only
19
+ ligand_ae_ckpt: generative_model_2_iter_6336.npy
20
+ pocket_ae_path: outputs_selected/vae_pockets/AMP__01_VAE_vaenorm_True10__float32__latent2_nf256_epoch100_bs12_lr1e-4_InvClassFreq_Smooth0.25_XH_x30_h15_NoEMA__20240623__10A__PKT_CA_Only
21
+ pocket_ae_ckpt: generative_model_3_iter_33308.npy
22
+
23
+ # Specify LDM weights path, set to null for random initialisation
24
+ ldm_path: outputs_selected/ldm/AMP__02_LDM_vaenorm_True10__float32__latent2_nf256_epoch200_bs36_lr1e-4_NoEMA__VAE_DecOnly_KL-0__20240623__10A_9x_resume
25
+ ldm_ckpt: generative_model_108_iter_230208.npy
26
+
27
+ # Zero out all fusion block weights instead of initialising them randomly
28
+ zero_fusion_block_weights: false
29
+
30
+
31
+ # Train 1st stage AutoEncoder model (no effect if train_diffusion=False)
32
+ trainable_ligand_ae_encoder: false
33
+ trainable_ligand_ae_decoder: false
34
+ trainable_pocket_ae_encoder: false
35
+
36
+ # Train 2nd stage LDM model
37
+ trainable_ldm: false
38
+
39
+ # Train 3rd stage ControlNet
40
+ trainable_controlnet: true
41
+ trainable_fusion_blocks: true
42
+
43
+
44
+ # can contain multiple: homo | onehot | lumo | num_atoms | etc
45
+ conditioning: []
46
+
47
+ # include atom charge, according to periodic table
48
+ include_charges: false # true for qm9
49
+
50
+ # only works for ldm, not for VAE
51
+ condition_time: true
52
+
53
+ # Time Noisy, t/2, adopted from [https://arxiv.org/abs/2405.06659]
54
+ time_noisy: false
55
+
56
+ vis_activations: false
57
+ vis_activations_batch_samples: 5
58
+ vis_activations_batch_size: 1
59
+ vis_activations_specific_ylim: [0, 40]
60
+
61
+ # random_seed: 0
62
+ random_seed: 42
63
+
64
+
65
+ # ========================================================================================================== Dataset
66
+
67
+ # pre-computed dataset stats
68
+ dataset: d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only__10A__LIGAND
69
+
70
+ # pre-computed training dataset
71
+ data_file: ./data/d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only/d_20241203_CrossDocked_LG_PKT_MMseq2_split__10.0A__CA_Only.npz
72
+ data_splitted: true
73
+
74
+ # Quick Vina 2.1
75
+ compute_qvina: true
76
+ qvina_search_size: 20 # search size (all 3 axes) in Angstroms around ligand center
77
+ qvina_exhaustiveness: 16
78
+ qvina_seed: 42
79
+ qvina_cleanup_files: true # cleanup tmp pdb, pdbqt files
80
+ qvina_save_csv: true # save results in csv
81
+ pocket_pdb_dir: ./data/d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only/test_val_paired_files/val_pocket
82
+ match_raw_file_by_id: true
83
+ mgltools_env_name: mgltools-python2 # for pdb -> pdbqt conversion
84
+
85
+ ligand_add_H: false # add hydrogens via: [mgltools] prepare_ligand4.py -l .. -o .. -A hydrogens
86
+ pocket_add_H: false # add hydrogens via: [mgltools] prepare_receptor4.py -r .. -o .. -A checkhydrogens
87
+ pocket_remove_nonstd_resi: false # remove any pocket residues not in this list:
88
+ # ['CYS','ILE','SER','VAL','GLN','LYS','ASN',
89
+ # 'PRO','THR','PHE','ALA','HIS','GLY','ASP',
90
+ # 'LEU', 'ARG', 'TRP', 'GLU', 'TYR','MET',
91
+ # 'HID', 'HSP', 'HIE', 'HIP', 'CYX', 'CSS']
92
+
93
+
94
+ # set to null if you're running this dataset for the first time.
95
+ # Script will generate a random permutation to shuffle the dataset.
96
+ # Please set the path to the DATASET_permutation.npy file after it is generated.
97
+ # permutation_file_path: ./data/d_20240623_CrossDocked_LG_PKT/d_20240623_CrossDocked_LG_PKT__10.0A_LG100_PKT600_permutation.npy
98
+ permutation_file_path: null
99
+
100
+ # what data to load for VAE training: ligand | pocket | all
101
+ vae_data_mode: ligand
102
+
103
+ # When set to an integer value, QM9 will only contain molecules with that number of atoms, default null
104
+ filter_n_atoms: null
105
+
106
+ # Only use molecules below this size. Int, default null ~!geom
107
+ filter_molecule_size: 100
108
+ filter_pocket_size: 80
109
+
110
+ # Organize data by size to reduce average memory usage. ~!geom
111
+ sequential: false
112
+
113
+ # Number of workers for the dataloader
114
+ num_workers: 32 # match cpu count
115
+
116
+ # use data augmentation (i.e. random rotation of x atom coordinates)
117
+ data_augmentation: false
118
+
119
+ # remove hydrogen atoms
120
+ remove_h: false
121
+
122
+
123
+
124
+
125
+ # ========================================================================================================== Training Params
126
+ start_epoch: 0
127
+ test_epochs: 5 # 4
128
+
129
+
130
+ n_epochs: 1000 # 3000 takes 20 epochs on paper (bs:32), hence 80 epochs for bs:8
131
+ batch_size: 60 # 14
132
+ lr: 1.0e-4
133
+
134
+ # weight of KL term in ELBO, default 0.01
135
+ kl_weight: 0.0
136
+
137
+ # ode_regularization weightage, default 1e-3
138
+ ode_regularization: 0.001
139
+ # brute_force: false
140
+ # actnorm: true
141
+ break_train_epoch: false
142
+
143
+ # Data Parallel for multi GPU support
144
+ dp: true
145
+ clip_grad: true
146
+
147
+ # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
148
+ ema_decay: 0 # 0.99
149
+
150
+ # add noise to x before encoding, data augmenting
151
+ augment_noise: 0
152
+
153
+ # Number of samples to compute the stability, default 500
154
+ n_stability_samples: 90 # 98, 50
155
+ n_stability_samples_batch_size: 10 # 7, 14
156
+
157
+ # Dataset partition where pocket samples will be drawn from for analyzing
158
+ # generated ligands' stability: train | test | val
159
+ n_stability_eval_split: val
160
+
161
+
162
+ # disables CUDA training
163
+ no_cuda: false
164
+
165
+ # hutch | exact
166
+ trace: hutch
167
+
168
+ # verbose logging
169
+ verbose: false
170
+
171
+ dtype: torch.float32
172
+
173
+ # enable mixed precision training (fp32, fp16)
174
+ mixed_precision_training: true
175
+ mixed_precision_autocast_dtype: torch.bfloat16
176
+
177
+ # use model checkpointing during training to reduce GPU memory usage
178
+ use_checkpointing: true
179
+
180
+ # sqrt: checkpointing is done on the sqrt(block_num)'th Equivariant block of each EGNN for the best performance
181
+ # all: checkpointing is done on all Equivariant blocks. Not optimal but helps if input size is too large
182
+ checkpointing_mode: sqrt
183
+
184
+ # splits tensors into manageable chunks and performs forward propagation without exceeding the GPU memory limit
185
+ forward_tensor_chunk_size: 50000
186
+
187
+
188
+
189
+
190
+
191
+
192
+
193
+
194
+ # ========================================================================================================== LDM
195
+ # our_dynamics | schnet | simple_dynamics | kernel_dynamics | egnn_dynamics | gnn_dynamics
196
+ model: egnn_dynamics
197
+
198
+ probabilistic_model: diffusion
199
+
200
+ # Training complexity is O(1) (unaffected), but sampling complexity is O(steps), default 500
201
+ diffusion_steps: 1000
202
+
203
+ # learned, cosine, polynomial_<power>
204
+ diffusion_noise_schedule: polynomial_2
205
+
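polynomial_2 refers to an alpha^2 curve that decays as (1 - (t/T)^2)^2, with diffusion_noise_precision keeping the endpoints strictly inside (0, 1); a sketch of the common EDM-style construction (assumed, not copied from this repo):

import numpy as np

def polynomial_schedule(timesteps=1000, power=2.0, s=1.0e-5):
    t = np.linspace(0, timesteps, timesteps + 1)
    alphas2 = (1.0 - (t / timesteps) ** power) ** 2
    return s + (1.0 - 2.0 * s) * alphas2   # squeeze the curve into [s, 1 - s]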
206
+ # default 1e-5
207
+ diffusion_noise_precision: 1.0e-05 # ~!fp16
208
+
209
+ # vlb | l2
210
+ diffusion_loss_type: l2
211
+
212
+ # number of latent features, default 4
213
+ latent_nf: 2
214
+
215
+ # normalize factors for [x, h_cat/categorical/one-hot, h_int/integer/charges]
216
+ normalize_factors: [1, 4, 10]
217
+
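The factors divide each feature stream before it enters the model, i.e. coordinates by 1, one-hot atom types by 4, integer charges by 10; a sketch:

def normalize(x, h_cat, h_int, factors=(1, 4, 10)):
    return x / factors[0], h_cat / factors[1], h_int / factors[2]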
218
+ vae_normalize_x: true
219
+ vae_normalize_method: scale # scale | linear
220
+ vae_normalize_factors: [10, 1, 1]
221
+
222
+ reweight_class_loss: "inv_class_freq"
223
+ smoothing_factor: 0.25 # range [0.1, 1.0); 1.0 essentially disables smoothing
224
+
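inv_class_freq up-weights rare atom types by the inverse of their frequency; one plausible smoothing (an assumption, the repo's exact formula may differ) blends those weights toward uniform, so a smoothing_factor of 1.0 collapses to uniform weights, matching the "essentially disables" note:

import numpy as np

def class_weights(class_counts, smoothing_factor=0.25):
    freq = np.asarray(class_counts, dtype=np.float64)
    w = freq.sum() / (len(freq) * freq)                  # inverse class frequency, mean ~1
    return smoothing_factor + (1.0 - smoothing_factor) * w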
225
+ error_x_weight: 10 # error_x custom weighting
226
+ error_h_weight: 5
227
+
228
+
229
+ # ========================================================================================================== Network Architecture
230
+
231
+ # number of layers of EquivariantBlock to use in VAE's Encoder
232
+ encoder_n_layers: 1
233
+
234
+ # number of layers of EquivariantBlock to use in LDM and VAE's Decoder
235
+ n_layers: 4
236
+
237
+ # number of GCL Blocks to use in each EquivariantBlock
238
+ inv_sublayers: 1
239
+
240
+ # model's internal operating number of features
241
+ nf: 256
242
+
243
+ # use tanh in the coord_mlp
244
+ tanh: true
245
+
246
+ # use attention in the EGNN
247
+ attention: true
248
+
249
+ # diff/(|diff| + norm_constant)
250
+ norm_constant: 1
251
+
252
+ # whether or not to use the sin embedding
253
+ sin_embedding: false
254
+
255
+ # uniform | variational | argmax_variational | deterministic
256
+ dequantization: argmax_variational
257
+
258
+ # Normalize the sum aggregation of EGNN
259
+ normalization_factor: 1
260
+
261
+ # EGNN aggregation method: sum | mean
262
+ aggregation_method: sum
263
+
264
+
265
+ # Fusion Block specific settings
266
+ fusion_weights: [0, 0, 0.5, 0.5] # [0.25, 0.5, 0.75, 1]
267
+ # Condition fusion method:
268
+ # - scaled_sum : (h1_i,x1_i) = (h1_i,x1_i) + w_i * (f_h1_i,f_x1_i)
269
+ # - balanced_sum : (h1_i,x1_i) = [(1 - w_i) * (h1_i,x1_i)] + [w_i * (f_h1_i,f_x1_i)]
270
+ # - replace : (h1_i,x1_i) = (f_h1_i,f_x1_i)
271
+ fusion_mode: balanced_sum
272
+
273
+ # Initial Noise Injection / Feedback Mechanism
274
+ noise_injection_weights: [0.5, 0.5] # pkt = w[0]*lg + w[1]*pkt
275
+ noise_injection_aggregation_method: mean # mean | sum
276
+ noise_injection_normalization_factor: 1 # aggregation normalization factor
277
+
278
+
279
+
280
+
281
+ # ========================================================================================================== Logging
282
+ # Can be used to visualize multiple times per epoch, default 1e8
283
+ visualize_sample_chain: true
284
+ visualize_every_batch: 20000
285
+ visualize_sample_chain_epochs: 2 # for 1% testing dataset, others set to 1
286
+ n_report_steps: 50
287
+
288
+
289
+
290
+
291
+ # ========================================================================================================== Saving & Resuming
292
+ # resume: null
293
+ # resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
294
+ # resume_model_ckpt: generative_model_8_iter_14049.npy
295
+ # resume_optim_ckpt: optim_8_iter_14049.npy
296
+ # resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_1x_resume
297
+ # resume_model_ckpt: generative_model_89_iter_148770.npy
298
+ # resume_optim_ckpt: optim_89_iter_148770.npy
299
+ resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_2x_resume
300
+ resume_model_ckpt: generative_model_75_iter_125628.npy
301
+ resume_optim_ckpt: optim_75_iter_125628.npy
302
+
303
+ save_model: true
304
+
305
+
306
+
307
+ # ========================================================================================================== Wandb
308
+ # disable wandb
309
+ no_wandb: false
310
+ wandb_usr: gohyixian456
311
+ # True = wandb online -- False = wandb offline
312
+ online: true
313
+
314
+
315
+
316
+
317
+ pocket_vae:
318
+ dataset: d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only__10A__POCKET
319
+ vae_data_mode: pocket
320
+ remove_h: false
321
+ ca_only: true
322
+
323
+ # can contain multiple: homo | onehot | lumo | num_atoms | etc
324
+ conditioning: []
325
+
326
+ # egnn_dynamics
327
+ model: egnn_dynamics
328
+
329
+ # include atom charge, according to periodic table
330
+ include_charges: false
331
+
332
+ # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
333
+ ema_decay: 0
334
+
335
+ # weight of KL term in ELBO, default 0.01
336
+ kl_weight: 0.01
337
+
338
+ # number of latent features, default 4 (have to match ligand VAE & LDM's latent_nf)
339
+ latent_nf: 2
340
+
341
+ # number of layers of EquivariantBlock to use in VAE's Encoder
342
+ encoder_n_layers: 1
343
+
344
+ # number of layers of EquivariantBlock to use in VAE's Decoder
345
+ n_layers: 4
346
+
347
+ # number of GCL Blocks to use in each EquivariantBlock
348
+ inv_sublayers: 1
349
+
350
+ # model's internal operating number of features
351
+ nf: 256
352
+
353
+ # use tanh in the coord_mlp
354
+ tanh: true
355
+
356
+ # use attention in the EGNN
357
+ attention: true
358
+
359
+ # diff/(|diff| + norm_constant)
360
+ norm_constant: 1
361
+
362
+ # whether or not to use the sin embedding
363
+ sin_embedding: false
364
+
365
+ # uniform | variational | argmax_variational | deterministic
366
+ dequantization: argmax_variational
367
+
368
+ # Normalize the sum aggregation of EGNN
369
+ normalization_factor: 1
370
+
371
+ # EGNN aggregation method: sum | mean
372
+ aggregation_method: sum
373
+
374
+ # normalize factors for [x, h_cat/categorical/one-hot, h_int/integer/charges]
375
+ normalize_factors: [1, 4, 10]
376
+
377
+ vae_normalize_x: true
378
+ vae_normalize_method: scale # scale | linear
379
+ vae_normalize_factors: [10, 1, 1]
380
+
381
+ reweight_class_loss: "inv_class_freq"
382
+ reweight_coords_loss: "inv_class_freq"
383
+ smoothing_factor: 0.25 # range [0.1, 1.0); 1.0 essentially disables smoothing
384
+
385
+ error_x_weight: 30
386
+ error_h_weight: 15
03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/args_75_iter_125628.pickle ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a9ea77f2e066fb9ef3f6f492394c18fd1687822cc66172907da6e2f67d15eec
3
+ size 5706
03_latent2_nf256_ds1k_fusBSum_CA_conditionBlocks34_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A_3x_resume/generative_model_75_iter_125628.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b78e4f16f584b7bb3d619c2ab445bfa0fd55162787588c4214324553e605905b
3
+ size 53575942
03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A_2x_resume/03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A.yaml ADDED
@@ -0,0 +1,381 @@
1
+ proj_name: Control-GeoLDM
2
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A
3
+ # exp_name: 03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A_1x
4
+ exp_name: 03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A_2x
5
+
6
+
7
+ # ========================================================================================================== Training Mode (VAE/LDM/ControlNet)
8
+ # Train 2nd stage LatentDiffusionModel
9
+ train_diffusion: true
10
+
11
+ # training mode: VAE | LDM | ControlNet
12
+ training_mode: ControlNet
13
+ loss_analysis: false
14
+
15
+ # Specify ligand & pocket VAE weights path, set to null for random initialisation
16
+ # set checkpoint (ckpt) to null to automatically select best
17
+ ligand_ae_path: outputs_selected/vae_ligands/AMP__01_VAE_vaenorm_True10__bfloat16__latent2_nf256_epoch100_bs36_lr1e-4_InvClassFreq_Smooth0.25_x10_h5_NoEMA__DecOnly_KL-0__20240623__10A__LG_Only
18
+ ligand_ae_ckpt: generative_model_2_iter_6336.npy
19
+ pocket_ae_path: outputs_selected/vae_pockets/AMP__01_VAE_vaenorm_True10__float32__latent2_nf256_epoch100_bs4_lr1e-5_InvClassFreq_Smooth0.25_XH_x30_h15_NoEMA__20240623__10A__PKT_Only_2x_resume
20
+ pocket_ae_ckpt: generative_model_3_iter_76064.npy
21
+
22
+ # Specify LDM weights path, set to null for random initialisation
23
+ ldm_path: outputs_selected/ldm/AMP__02_LDM_vaenorm_True10__float32__latent2_nf256_epoch200_bs36_lr1e-4_NoEMA__VAE_DecOnly_KL-0__20240623__10A_9x_resume
24
+ ldm_ckpt: generative_model_108_iter_230208.npy
25
+
26
+ # Zero out all fusion block weights instead of initialising them randomly
27
+ zero_fusion_block_weights: false
28
+
29
+
30
+ # Train 1st stage AutoEncoder model (no effect if train_diffusion=False)
31
+ trainable_ligand_ae_encoder: false
32
+ trainable_ligand_ae_decoder: false
33
+ trainable_pocket_ae_encoder: false
34
+
35
+ # Train 2nd stage LDM model
36
+ trainable_ldm: false
37
+
38
+ # Train 3rd stage ControlNet
39
+ trainable_controlnet: true
40
+ trainable_fusion_blocks: true
41
+
42
+
43
+ # can contain multiple: homo | onehot | lumo | num_atoms | etc
44
+ conditioning: []
45
+
46
+ # include atom charge, according to periodic table
47
+ include_charges: false # true for qm9
48
+
49
+ # only works for ldm, not for VAE
50
+ condition_time: true
51
+
52
+ # Time Noisy, t/2, adopted from [https://arxiv.org/abs/2405.06659]
53
+ time_noisy: false
54
+
55
+ vis_activations: false
56
+ vis_activations_batch_samples: 5
57
+ vis_activations_batch_size: 1
58
+ vis_activations_specific_ylim: [0, 40]
59
+
60
+ random_seed: 42
61
+
62
+
63
+ # ========================================================================================================== Dataset
64
+
65
+ # pre-computed dataset stats
66
+ dataset: d_20240623_CrossDocked_LG_PKT__10A__LIGAND
67
+
68
+ # pre-computed training dataset
69
+ data_file: ./data/d_20241115_CrossDocked_LG_PKT_MMseq2_split/d_20241115_CrossDocked_LG_PKT_MMseq2_split__10.0A.npz
70
+ data_splitted: true
71
+
72
+ # Quick Vina 2.1
73
+ compute_qvina: true
74
+ qvina_search_size: 20 # search size (all 3 axes) in Angstroms around ligand center
75
+ qvina_exhaustiveness: 16
76
+ qvina_seed: 42
77
+ qvina_cleanup_files: false # cleanup tmp pdb, pdbqt files
78
+ qvina_save_csv: true # save results in csv
79
+ pocket_pdb_dir: ./data/d_20241115_CrossDocked_LG_PKT_MMseq2_split/test_val_paired_files/val_pocket
80
+ match_raw_file_by_id: true
81
+ mgltools_env_name: mgltools-python2 # for pdb -> pdbqt conversion
82
+
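The qvina_* settings map onto a standard Vina-style command line; a hedged sketch of one docking call (the binary name and file paths are placeholders, and QuickVina 2 is assumed to share AutoDock Vina's flags):

import subprocess

cmd = [
    "qvina2",                                    # placeholder binary name
    "--receptor", "pocket.pdbqt",
    "--ligand",   "ligand.pdbqt",
    "--center_x", "0.0", "--center_y", "0.0", "--center_z", "0.0",  # ligand center
    "--size_x", "20", "--size_y", "20", "--size_z", "20",           # qvina_search_size
    "--exhaustiveness", "16",                                        # qvina_exhaustiveness
    "--seed", "42",                                                  # qvina_seed
]
subprocess.run(cmd, check=True)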
83
+ ligand_add_H: false # add hydrogens via: [mgltools] prepare_ligand4.py -l .. -o .. -A hydrogens
84
+ pocket_add_H: false # add hydrogens via: [mgltools] prepare_receptor4.py -r .. -o .. -A checkhydrogens
85
+ pocket_remove_nonstd_resi: false # remove any pocket residues not in this list:
86
+ # ['CYS','ILE','SER','VAL','GLN','LYS','ASN',
87
+ # 'PRO','THR','PHE','ALA','HIS','GLY','ASP',
88
+ # 'LEU', 'ARG', 'TRP', 'GLU', 'TYR','MET',
89
+ # 'HID', 'HSP', 'HIE', 'HIP', 'CYX', 'CSS']
90
+
91
+
92
+ # set to null if you're running this dataset for the first time.
93
+ # Script will generate a random permutation to shuffle the dataset.
94
+ # Please set the path to the DATASET_permutation.npy file after it is generated.
95
+ # permutation_file_path: ./data/d_20240623_CrossDocked_LG_PKT/d_20240623_CrossDocked_LG_PKT__10.0A_LG100_PKT600_permutation.npy
96
+ permutation_file_path: null
97
+
98
+ # what data to load for VAE training: ligand | pocket | all
99
+ vae_data_mode: ligand
100
+
101
+ # When set to an integer value, QM9 will only contain molecules with that number of atoms, default null
102
+ filter_n_atoms: null
103
+
104
+ # Only use molecules below this size. Int, default null ~!geom
105
+ filter_molecule_size: 100
106
+ filter_pocket_size: 600 # refer EDA
107
+
108
+ # Organize data by size to reduce average memory usage. ~!geom
109
+ sequential: false
110
+
111
+ # Number of workers for the dataloader
112
+ num_workers: 60 # match cpu count
113
+
114
+ # use data augmentation (i.e. random rotation of x atom coordinates)
115
+ data_augmentation: false
116
+
117
+ # remove hydrogen atoms
118
+ remove_h: false
119
+
120
+
121
+
122
+
123
+ # ========================================================================================================== Training Params
124
+ start_epoch: 0
125
+ test_epochs: 1 # 4
126
+
127
+
128
+ n_epochs: 1000 # 3000 takes 20 epochs on paper (bs:32), hence 80 epochs for bs:8
129
+ batch_size: 10 # 14
130
+ lr: 1.0e-4
131
+
132
+ # weight of KL term in ELBO, default 0.01
133
+ kl_weight: 0.0
134
+
135
+ # ode_regularization weightage, default 1e-3
136
+ ode_regularization: 0.001
137
+ # brute_force: false
138
+ # actnorm: true
139
+ break_train_epoch: false
140
+
141
+ # Data Parallel for multi GPU support
142
+ dp: true
143
+ clip_grad: true
144
+
145
+ # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
146
+ ema_decay: 0 # 0.99
147
+
148
+ # add noise to x before encoding, data augmenting
149
+ augment_noise: 0
150
+
151
+ # Number of samples to compute the stability, default 500
152
+ n_stability_samples: 90 # 98, 50
153
+ n_stability_samples_batch_size: 10 # 7, 14
154
+
155
+ # Dataset partition where pocket samples will be drawn from for analyzing
156
+ # generated ligands' stability: train | test | val
157
+ n_stability_eval_split: val
158
+
159
+
160
+ # disables CUDA training
161
+ no_cuda: false
162
+
163
+ # hutch | exact
164
+ trace: hutch
165
+
166
+ # verbose logging
167
+ verbose: false
168
+
169
+ dtype: torch.float32
170
+
171
+ # enable mixed precision training (fp32, fp16)
172
+ mixed_precision_training: true
173
+ mixed_precision_autocast_dtype: torch.bfloat16
174
+
175
+ # use model checkpointing during training to reduce GPU memory usage
176
+ use_checkpointing: true
177
+
178
+ # sqrt: checkpointing is done on the sqrt(block_num)'th Equivariant block of each EGNN for the best performance
179
+ # all: checkpointing is done on all Equivariant blocks. Not optimal but helps if input size is too large
180
+ checkpointing_mode: sqrt
181
+
182
+ # splits tensors into manageable chunks and performs forward propagation without exceeding the GPU memory limit
183
+ forward_tensor_chunk_size: 50000
184
+
185
+
186
+
187
+
188
+
189
+
190
+
191
+
192
+ # ========================================================================================================== LDM
193
+ # our_dynamics | schnet | simple_dynamics | kernel_dynamics | egnn_dynamics | gnn_dynamics
194
+ model: egnn_dynamics
195
+
196
+ probabilistic_model: diffusion
197
+
198
+ # Training complexity is O(1) (unaffected), but sampling complexity is O(steps), default 500
199
+ diffusion_steps: 1000
200
+
201
+ # learned, cosine, polynomial_<power>
202
+ diffusion_noise_schedule: polynomial_2
203
+
204
+ # default 1e-5
205
+ diffusion_noise_precision: 1.0e-05 # ~!fp16
206
+
207
+ # vlb | l2
208
+ diffusion_loss_type: l2
209
+
210
+ # number of latent features, default 4
211
+ latent_nf: 2
212
+
213
+ # normalize factors for [x, h_cat/categorical/one-hot, h_int/integer/charges]
214
+ normalize_factors: [1, 4, 10]
215
+
216
+ vae_normalize_x: true
217
+ vae_normalize_method: scale # scale | linear
218
+ vae_normalize_factors: [10, 1, 1]
219
+
220
+ reweight_class_loss: "inv_class_freq"
221
+ smoothing_factor: 0.25 # range [0.1, 1.0); 1.0 essentially disables smoothing
222
+
223
+ error_x_weight: 10 # error_x custom weighting
224
+ error_h_weight: 5
225
+
226
+
227
+ # ========================================================================================================== Network Architecture
228
+
229
+ # number of layers of EquivariantBlock to use in VAE's Encoder
230
+ encoder_n_layers: 1
231
+
232
+ # number of layers of EquivariantBlock to use in LDM and VAE's Decoder
233
+ n_layers: 4
234
+
235
+ # number of GCL Blocks to use in each EquivariantBlock
236
+ inv_sublayers: 1
237
+
238
+ # model's internal operating number of features
239
+ nf: 256
240
+
241
+ # use tanh in the coord_mlp
242
+ tanh: true
243
+
244
+ # use attention in the EGNN
245
+ attention: true
246
+
247
+ # diff/(|diff| + norm_constant)
248
+ norm_constant: 1
249
+
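norm_constant enters the EGNN coordinate update as diff/(|diff| + norm_constant), which bounds the step size for nearby atoms; a sketch:

import torch

def normalized_diff(x_i, x_j, norm_constant=1.0):
    diff = x_i - x_j
    return diff / (diff.norm(dim=-1, keepdim=True) + norm_constant)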
250
+ # whether or not to use the sin embedding
251
+ sin_embedding: false
252
+
253
+ # uniform | variational | argmax_variational | deterministic
254
+ dequantization: argmax_variational
255
+
256
+ # Normalize the sum aggregation of EGNN
257
+ normalization_factor: 1
258
+
259
+ # EGNN aggregation method: sum | mean
260
+ aggregation_method: sum
261
+
262
+
263
+ # Fusion Block specific settings
264
+ fusion_weights: [0, 0, 0.1, 0.1] # [0.25, 0.5, 0.75, 1]
265
+ # Condition fusion method:
266
+ # - scaled_sum : (h1_i,x1_i) = (h1_i,x1_i) + w_i * (f_h1_i,f_x1_i)
267
+ # - balanced_sum : (h1_i,x1_i) = [(1 - w_i) * (h1_i,x1_i)] + [w_i * (f_h1_i,f_x1_i)]
268
+ # - replace : (h1_i,x1_i) = (f_h1_i,f_x1_i)
269
+ fusion_mode: balanced_sum
270
+
271
+ # Initial Noise Injection / Feedback Mechanism
272
+ noise_injection_weights: [0.5, 0.5] # pkt = w[0]*lg + w[1]*pkt
273
+ noise_injection_aggregation_method: mean # mean | sum
274
+ noise_injection_normalization_factor: 1 # aggregation normalization factor
275
+
276
+
277
+
278
+
279
+ # ========================================================================================================== Logging
280
+ # Can be used to visualize multiple times per epoch, default 1e8
281
+ visualize_sample_chain: true
282
+ visualize_every_batch: 20000
283
+ visualize_sample_chain_epochs: 2 # for 1% testing dataset, others set to 1
284
+ n_report_steps: 50
285
+
286
+
287
+
288
+
289
+ # ========================================================================================================== Saving & Resuming
290
+ # resume: null
291
+ # resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A
292
+ # resume_model_ckpt: generative_model_5_iter_56988.npy
293
+ # resume_optim_ckpt: optim_5_iter_56988.npy
294
+ resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A_1x_resume
295
+ resume_model_ckpt: generative_model_6_iter_66486.npy
296
+ resume_optim_ckpt: optim_6_iter_66486.npy
297
+
298
+ save_model: true
299
+
300
+
301
+
302
+ # ========================================================================================================== Wandb
303
+ # disable wandb
304
+ no_wandb: false
305
+ wandb_usr: gohyixian456
306
+ # True = wandb online -- False = wandb offline
307
+ online: true
308
+
309
+
310
+
311
+
312
+ pocket_vae:
313
+ dataset: d_20240623_CrossDocked_LG_PKT__10A__LIGAND+POCKET
314
+ vae_data_mode: pocket
315
+ remove_h: false
316
+ ca_only: false
317
+
318
+ # can contain multiple: homo | onehot | lumo | num_atoms | etc
319
+ conditioning: []
320
+
321
+ # egnn_dynamics
322
+ model: egnn_dynamics
323
+
324
+ # include atom charge, according to periodic table
325
+ include_charges: false
326
+
327
+ # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
328
+ ema_decay: 0
329
+
330
+ # weight of KL term in ELBO, default 0.01
331
+ kl_weight: 0.01
332
+
333
+ # number of latent features, default 4 (have to match ligand VAE & LDM's latent_nf)
334
+ latent_nf: 2
335
+
336
+ # number of layers of EquivariantBlock to use in VAE's Encoder
337
+ encoder_n_layers: 1
338
+
339
+ # number of layers of EquivariantBlock to use in VAE's Decoder
340
+ n_layers: 4
341
+
342
+ # number of GCL Blocks to use in each EquivariantBlock
343
+ inv_sublayers: 1
344
+
345
+ # model's internal operating number of features
346
+ nf: 256
347
+
348
+ # use tanh in the coord_mlp
349
+ tanh: true
350
+
351
+ # use attention in the EGNN
352
+ attention: true
353
+
354
+ # diff/(|diff| + norm_constant)
355
+ norm_constant: 1
356
+
357
+ # whether or not to use the sin embedding
358
+ sin_embedding: false
359
+
360
+ # uniform | variational | argmax_variational | deterministic
361
+ dequantization: argmax_variational
362
+
363
+ # Normalize the sum aggregation of EGNN
364
+ normalization_factor: 1
365
+
366
+ # EGNN aggregation method: sum | mean
367
+ aggregation_method: sum
368
+
369
+ # normalize factors for [x, h_cat/categorical/one-hot, h_int/integer/charges]
370
+ normalize_factors: [1, 4, 10]
371
+
372
+ vae_normalize_x: true
373
+ vae_normalize_method: scale # scale | linear
374
+ vae_normalize_factors: [10, 1, 1]
375
+
376
+ reweight_class_loss: "inv_class_freq"
377
+ reweight_coords_loss: "inv_class_freq"
378
+ smoothing_factor: 0.25 # range [0.1, 1.0); 1.0 essentially disables smoothing
379
+
380
+ error_x_weight: 30
381
+ error_h_weight: 15
03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A_2x_resume/args_5_iter_56988.pickle ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e265c6c80c375510446ea18e8c7c918bcbbbf7c4d7cae1e379f9aa8b27c7979d
3
+ size 5422
03_latent2_nf256_ds1k_fusBSum_conditionBlocks34_0.1__epoch1k_bs10_lr1e-4_NoEMA__20241115__10A_2x_resume/generative_model_5_iter_56988.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bfe6dc165698fecbafb4948e34f69a4e0c1d62c79be4a22dda8b69baeeb69b3
3
+ size 53550562
03_latent2_nf256_ds1k_fusReplace_CA__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A/03_latent2_nf256_ds1k_fusReplace_CA__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A.yaml ADDED
@@ -0,0 +1,377 @@
1
+ proj_name: Control-GeoLDM
2
+ exp_name: 03_latent2_nf256_ds1k_fusReplace_CA__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
3
+
4
+
5
+ # ========================================================================================================== Training Mode (VAE/LDM/ControlNet)
6
+ # Train 2nd stage LatentDiffusionModel
7
+ train_diffusion: true
8
+
9
+ # training mode: VAE | LDM | ControlNet
10
+ training_mode: ControlNet
11
+ loss_analysis: false
12
+
13
+ # Specify ligand & pocket VAE weights path, set to null for random initialisation
14
+ # set checkpoint (ckpt) to null to automatically select best
15
+ ligand_ae_path: outputs_selected/vae_ligands/AMP__01_VAE_vaenorm_True10__bfloat16__latent2_nf256_epoch100_bs36_lr1e-4_InvClassFreq_Smooth0.25_x10_h5_NoEMA__DecOnly_KL-0__20240623__10A__LG_Only
16
+ ligand_ae_ckpt: generative_model_2_iter_6336.npy
17
+ pocket_ae_path: outputs_selected/vae_pockets/AMP__01_VAE_vaenorm_True10__float32__latent2_nf256_epoch100_bs12_lr1e-4_InvClassFreq_Smooth0.25_XH_x30_h15_NoEMA__20240623__10A__PKT_CA_Only
18
+ pocket_ae_ckpt: generative_model_3_iter_33308.npy
19
+
20
+ # Specify LDM weights path, set to null for random initialisation
21
+ ldm_path: outputs_selected/ldm/AMP__02_LDM_vaenorm_True10__float32__latent2_nf256_epoch200_bs36_lr1e-4_NoEMA__VAE_DecOnly_KL-0__20240623__10A_9x_resume
22
+ ldm_ckpt: generative_model_108_iter_230208.npy
23
+
24
+ # Zero out all fusion block weights instead of initialising them randomly
25
+ zero_fusion_block_weights: false
26
+
27
+
28
+ # Train 1st stage AutoEncoder model (no effect if train_diffusion=False)
29
+ trainable_ligand_ae_encoder: false
30
+ trainable_ligand_ae_decoder: false
31
+ trainable_pocket_ae_encoder: false
32
+
33
+ # Train 2nd stage LDM model
34
+ trainable_ldm: false
35
+
36
+ # Train 3rd stage ControlNet
37
+ trainable_controlnet: true
38
+ trainable_fusion_blocks: true
39
+
40
+
41
+ # can contain multiple: homo | onehot | lumo | num_atoms | etc
42
+ conditioning: []
43
+
44
+ # include atom charge, according to periodic table
45
+ include_charges: false # true for qm9
46
+
47
+ # only works for ldm, not for VAE
48
+ condition_time: true
49
+
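condition_time appends the (normalized) diffusion timestep to every node's latent features before the LDM's EGNN runs; a generic sketch of the usual construction (not this repo's exact code):

import torch

def append_time(h, t, timesteps=1000):
    t_feat = torch.full((h.shape[0], 1), float(t) / timesteps, device=h.device)
    return torch.cat([h, t_feat], dim=-1)   # one extra feature column per node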
50
+ # Time Noisy, t/2, adopted from [https://arxiv.org/abs/2405.06659]
51
+ time_noisy: false
52
+
53
+ vis_activations: false
54
+ vis_activations_batch_samples: 5
55
+ vis_activations_batch_size: 1
56
+ vis_activations_specific_ylim: [0, 40]
57
+
58
+ # random_seed: 0
59
+ random_seed: 42
60
+
61
+
62
+ # ========================================================================================================== Dataset
63
+
64
+ # pre-computed dataset stats
65
+ dataset: d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only__10A__LIGAND
66
+
67
+ # pre-computed training dataset
68
+ data_file: ./data/d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only/d_20241203_CrossDocked_LG_PKT_MMseq2_split__10.0A__CA_Only.npz
69
+ data_splitted: true
70
+
71
+ # Quick Vina 2.1
72
+ compute_qvina: true
73
+ qvina_search_size: 20 # search size (all 3 axes) in Angstroms around ligand center
74
+ qvina_exhaustiveness: 16
75
+ qvina_seed: 42
76
+ qvina_cleanup_files: true # cleanup tmp pdb, pdbqt files
77
+ qvina_save_csv: true # save results in csv
78
+ pocket_pdb_dir: ./data/d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only/test_val_paired_files/val_pocket
79
+ match_raw_file_by_id: true
80
+ mgltools_env_name: mgltools-python2 # for pdb -> pdbqt conversion
81
+
82
+ ligand_add_H: false # add hydrogens via: [mgltools] prepare_ligand4.py -l .. -o .. -A hydrogens
83
+ pocket_add_H: false # add hydrogens via: [mgltools] prepare_receptor4.py -r .. -o .. -A checkhydrogens
84
+ pocket_remove_nonstd_resi: false # remove any pocket residues not in this list:
85
+ # ['CYS','ILE','SER','VAL','GLN','LYS','ASN',
86
+ # 'PRO','THR','PHE','ALA','HIS','GLY','ASP',
87
+ # 'LEU', 'ARG', 'TRP', 'GLU', 'TYR','MET',
88
+ # 'HID', 'HSP', 'HIE', 'HIP', 'CYX', 'CSS']
89
+
90
+
91
+ # set to null if you're running this dataset for the first time.
92
+ # Script will generate a random permutation to shuffle the dataset.
93
+ # Please set the path to the DATASET_permutation.npy file after it is generated.
94
+ # permutation_file_path: ./data/d_20240623_CrossDocked_LG_PKT/d_20240623_CrossDocked_LG_PKT__10.0A_LG100_PKT600_permutation.npy
95
+ permutation_file_path: null
96
+
97
+ # what data to load for VAE training: ligand | pocket | all
98
+ vae_data_mode: ligand
99
+
100
+ # When set to an integer value, QM9 will only contain molecules with that number of atoms, default null
101
+ filter_n_atoms: null
102
+
103
+ # Only use molecules below this size. Int, default null ~!geom
104
+ filter_molecule_size: 100
105
+ filter_pocket_size: 80
106
+
107
+ # Organize data by size to reduce average memory usage. ~!geom
108
+ sequential: false
109
+
110
+ # Number of workers for the dataloader
111
+ num_workers: 32 # match cpu count
112
+
113
+ # use data augmentation (i.e. random rotation of x atom coordinates)
114
+ data_augmentation: false
115
+
116
+ # remove hydrogen atoms
117
+ remove_h: false
118
+
119
+
120
+
121
+
122
+ # ========================================================================================================== Training Params
123
+ start_epoch: 0
124
+ test_epochs: 5 # 4
125
+
126
+
127
+ n_epochs: 1000 # 3000 takes 20 epochs on paper (bs:32), hence 80 epochs for bs:8
128
+ batch_size: 60 # 14
129
+ lr: 1.0e-4
130
+
131
+ # weight of KL term in ELBO, default 0.01
132
+ kl_weight: 0.0
133
+
134
+ # ode_regularization weightage, default 1e-3
135
+ ode_regularization: 0.001
136
+ # brute_force: false
137
+ # actnorm: true
138
+ break_train_epoch: false
139
+
140
+ # Data Parallel for multi GPU support
141
+ dp: true
142
+ clip_grad: true
143
+
144
+ # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
145
+ ema_decay: 0 # 0.99
146
+
147
+ # add noise to x before encoding, data augmenting
148
+ augment_noise: 0
149
+
150
+ # Number of samples to compute the stability, default 500
151
+ n_stability_samples: 90 # 98, 50
152
+ n_stability_samples_batch_size: 10 # 7, 14
153
+
154
+ # Dataset partition where pocket samples will be drawn from for analyzing
155
+ # generated ligands' stability: train | test | val
156
+ n_stability_eval_split: val
157
+
158
+
159
+ # disables CUDA training
160
+ no_cuda: false
161
+
162
+ # hutch | exact
163
+ trace: hutch
164
+
165
+ # verbose logging
166
+ verbose: false
167
+
168
+ dtype: torch.float32
169
+
170
+ # enable mixed precision training (fp32, fp16)
171
+ mixed_precision_training: true
172
+ mixed_precision_autocast_dtype: torch.bfloat16
173
+
174
+ # use model checkpointing during training to reduce GPU memory usage
175
+ use_checkpointing: true
176
+
177
+ # sqrt: checkpointing is done on the sqrt(block_num)'th Equivariant block of each EGNN for the best performance
178
+ # all: checkpointing is done on all Equivariant blocks. Not optimal but helps if input size is too large
179
+ checkpointing_mode: sqrt
180
+
181
+ # splits tensors into manageable chunks and performs forward propagation without exceeding the GPU memory limit
182
+ forward_tensor_chunk_size: 50000
183
+
184
+
185
+
186
+
187
+
188
+
189
+
190
+
191
+ # ========================================================================================================== LDM
192
+ # our_dynamics | schnet | simple_dynamics | kernel_dynamics | egnn_dynamics | gnn_dynamics
193
+ model: egnn_dynamics
194
+
195
+ probabilistic_model: diffusion
196
+
197
+ # Training complexity is O(1) (unaffected), but sampling complexity is O(steps), default 500
198
+ diffusion_steps: 1000
199
+
200
+ # learned, cosine, polynomial_<power>
201
+ diffusion_noise_schedule: polynomial_2
202
+
203
+ # default 1e-5
204
+ diffusion_noise_precision: 1.0e-05 # ~!fp16
205
+
206
+ # vlb | l2
207
+ diffusion_loss_type: l2
208
+
209
+ # number of latent features, default 4
210
+ latent_nf: 2
211
+
212
+ # normalize factors for [x, h_cat/categorical/one-hot, h_int/integer/charges]
213
+ normalize_factors: [1, 4, 10]
214
+
215
+ vae_normalize_x: true
216
+ vae_normalize_method: scale # scale | linear
217
+ vae_normalize_factors: [10, 1, 1]
218
+
219
+ reweight_class_loss: "inv_class_freq"
220
+ smoothing_factor: 0.25 # range [0.1, 1.0); 1.0 essentially disables smoothing
221
+
222
+ error_x_weight: 10 # error_x custom weighting
223
+ error_h_weight: 5
224
+
225
+
226
+ # ========================================================================================================== Network Architecture
227
+
228
+ # number of layers of EquivariantBlock to use in VAE's Encoder
229
+ encoder_n_layers: 1
230
+
231
+ # number of layers of EquivariantBlock to use in LDM and VAE's Decoder
232
+ n_layers: 4
233
+
234
+ # number of GCL Blocks to use in each EquivariantBlock
235
+ inv_sublayers: 1
236
+
237
+ # model's internal operating number of features
238
+ nf: 256
239
+
240
+ # use tanh in the coord_mlp
241
+ tanh: true
242
+
243
+ # use attention in the EGNN
244
+ attention: true
245
+
246
+ # diff/(|diff| + norm_constant)
247
+ norm_constant: 1
248
+
249
+ # whether or not to use the sin embedding
250
+ sin_embedding: false
251
+
252
+ # uniform | variational | argmax_variational | deterministic
253
+ dequantization: argmax_variational
254
+
255
+ # Normalize the sum aggregation of EGNN
256
+ normalization_factor: 1
257
+
258
+ # EGNN aggregation method: sum | mean
259
+ aggregation_method: sum
260
+
261
+
262
+ # Fusion Block specific settings
263
+ fusion_weights: [1.0, 1.0, 1.0, 1.0] # [0.25, 0.5, 0.75, 1]
264
+ # Condition fusion method:
265
+ # - scaled_sum : (h1_i,x1_i) = (h1_i,x1_i) + w_i * (f_h1_i,f_x1_i)
266
+ # - balanced_sum : (h1_i,x1_i) = [(1 - w_i) * (h1_i,x1_i)] + [w_i * (f_h1_i,f_x1_i)]
267
+ # - replace : (h1_i,x1_i) = (f_h1_i,f_x1_i)
268
+ fusion_mode: replace
269
+
270
+ # Initial Noise Injection / Feedback Mechanism
271
+ noise_injection_weights: [0.5, 0.5] # pkt = w[0]*lg + w[1]*pkt
272
+ noise_injection_aggregation_method: mean # mean | sum
273
+ noise_injection_normalization_factor: 1 # aggregation normalization factor
274
+
275
+
276
+
277
+
278
+ # ========================================================================================================== Logging
279
+ # Can be used to visualize multiple times per epoch, default 1e8
280
+ visualize_sample_chain: true
281
+ visualize_every_batch: 20000
282
+ visualize_sample_chain_epochs: 2 # for 1% testing dataset, others set to 1
283
+ n_report_steps: 50
284
+
285
+
286
+
287
+
288
+ # ========================================================================================================== Saving & Resuming
289
+ # resume: outputs_selected/controlnet/03_latent2_nf256_ds1k_fusBSum_CA_conditionAll_0.5__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A
290
+ # resume_model_ckpt: generative_model_8_iter_14049.npy
291
+ # resume_optim_ckpt: optim_8_iter_14049.npy
292
+ resume: null
293
+
294
+ save_model: true
295
+
296
+
297
+
298
+ # ========================================================================================================== Wandb
299
+ # disable wandb
300
+ no_wandb: false
301
+ wandb_usr: gohyixian456
302
+ # True = wandb online -- False = wandb offline
303
+ online: true
304
+
305
+
306
+
307
+
308
+ pocket_vae:
309
+ dataset: d_20241203_CrossDocked_LG_PKT_MMseq2_split_CA_only__10A__POCKET
310
+ vae_data_mode: pocket
311
+ remove_h: false
312
+ ca_only: true
313
+
314
+ # can contain multiple: homo | onehot | lumo | num_atoms | etc
315
+ conditioning: []
316
+
317
+ # egnn_dynamics
318
+ model: egnn_dynamics
319
+
320
+ # include atom charge, according to periodic table
321
+ include_charges: false
322
+
323
+ # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
324
+ ema_decay: 0
325
+
326
+ # weight of KL term in ELBO, default 0.01
327
+ kl_weight: 0.01
328
+
329
+ # number of latent features, default 4 (have to match ligand VAE & LDM's latent_nf)
330
+ latent_nf: 2
331
+
332
+ # number of layers of EquivariantBlock to use in VAE's Encoder
333
+ encoder_n_layers: 1
334
+
335
+ # number of layers of EquivariantBlock to use in VAE's Decoder
336
+ n_layers: 4
337
+
338
+ # number of GCL Blocks to use in each EquivariantBlock
339
+ inv_sublayers: 1
340
+
341
+ # model's internal operating number of features
342
+ nf: 256
343
+
344
+ # use tanh in the coord_mlp
345
+ tanh: true
346
+
347
+ # use attention in the EGNN
348
+ attention: true
349
+
350
+ # diff/(|diff| + norm_constant)
351
+ norm_constant: 1
352
+
353
+ # whether or not to use the sin embedding
354
+ sin_embedding: false
355
+
356
+ # uniform | variational | argmax_variational | deterministic
357
+ dequantization: argmax_variational
358
+
359
+ # Normalize the sum aggregation of EGNN
360
+ normalization_factor: 1
361
+
362
+ # EGNN aggregation method: sum | mean
363
+ aggregation_method: sum
364
+
365
+ # normalize factors for [x, h_cat/categorical/one-hot, h_int/integer/charges]
366
+ normalize_factors: [1, 4, 10]
367
+
368
+ vae_normalize_x: true
369
+ vae_normalize_method: scale # scale | linear
370
+ vae_normalize_factors: [10, 1, 1]
371
+
372
+ reweight_class_loss: "inv_class_freq"
373
+ reweight_coords_loss: "inv_class_freq"
374
+ smoothing_factor: 0.25 # range [0.1, 1.0); 1.0 essentially disables smoothing
375
+
376
+ error_x_weight: 30
377
+ error_h_weight: 15
03_latent2_nf256_ds1k_fusReplace_CA__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A/args_135_iter_224808.pickle ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a53ec1e5f606b364dd6579772c91d961b0576c88991fffe029868968a346e71
3
+ size 5430
03_latent2_nf256_ds1k_fusReplace_CA__epoch1k_bs60_lr1e-4_NoEMA__20241203__10A/generative_model_135_iter_224808.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74803916b1fd755f3b6fa346885792b031824d4dc6ed5fded39d04580f358397
3
+ size 53576312