splines-ai commited on
Commit
87e8ac6
·
verified ·
1 Parent(s): d357bfb

Upload Structures25 models

Browse files

See also https://datadryad.org/dataset/doi:10.5061/dryad.0cfxpnwcs#readme

trained-on-qm9/hparams.yaml ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task_name: train
2
+ name: ''
3
+ tags:
4
+ - qm9_perturbed_fock
5
+ - kin_plus_xc
6
+ - graphformer
7
+ train: true
8
+ validate: true
9
+ test: false
10
+ ckpt_path: /export/scratch/ialgroup/dft_str25/models/train/runs/009__str25\qm9_tf__seed-100/checkpoints/last.ckpt
11
+ use_original_settings: null
12
+ weight_ckpt_path: null
13
+ seed: 2274360845
14
+ data:
15
+ datamodule:
16
+ _target_: mldft.ml.data.datamodule.OFDataModule
17
+ transforms: ${data.transforms}
18
+ split_file: ${oc.env:DFT_DATA}/${data.dataset_name}/split.pkl
19
+ data_dir: ${oc.env:DFT_DATA}
20
+ basis_info: ${data.basis_info}
21
+ batch_size: 128
22
+ num_workers: 8
23
+ pin_memory: false
24
+ shuffle_train: true
25
+ shuffle_val: false
26
+ shuffle_test: false
27
+ dataset_kwargs:
28
+ add_irreps: true
29
+ cache_in_memory: false
30
+ energy_key: e_${data.target_key}
31
+ gradient_key: grad_${data.target_key}
32
+ limit_scf_iterations:
33
+ - 6
34
+ - 7
35
+ - 8
36
+ - 9
37
+ - 10
38
+ - 11
39
+ - 12
40
+ - 13
41
+ - 14
42
+ - 15
43
+ - 16
44
+ - 17
45
+ - 18
46
+ - 19
47
+ - 20
48
+ - 21
49
+ - 22
50
+ - 23
51
+ - 24
52
+ - 25
53
+ - 26
54
+ - -1
55
+ keep_initial_guess: false
56
+ dataloader_kwargs:
57
+ follow_batch:
58
+ - coeffs
59
+ - atomic_numbers
60
+ list_keys: null
61
+ transforms:
62
+ cached_transforms:
63
+ name: local_frames_global_${data.natural_reparametrization.orthogonalization}_natrep
64
+ additional_pre_transforms:
65
+ - _target_: mldft.ml.data.components.convert_transforms.AddOverlapMatrix
66
+ basis_info: ${data.basis_info}
67
+ transforms:
68
+ - _target_: mldft.ml.data.components.basis_transforms.ToLocalFrames
69
+ sparse: false
70
+ - _target_: mldft.ml.data.components.basis_transforms.ToGlobalNatRep
71
+ orthogonalization: ${data.natural_reparametrization.orthogonalization}
72
+ _target_: mldft.ml.data.components.basis_transforms.MasterTransformation
73
+ name: local_frames_global_${data.natural_reparametrization.orthogonalization}_natrep
74
+ use_cached_data: true
75
+ pre_transforms:
76
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
77
+ float_dtype: torch.float64
78
+ - _target_: mldft.ml.data.components.convert_transforms.ProjectGradient
79
+ - _target_: mldft.ml.data.components.convert_transforms.AddFullEdgeIndex
80
+ - _target_: mldft.ml.data.components.basis_transforms.AddLocalFrames
81
+ basis_transforms: []
82
+ post_transforms:
83
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
84
+ add_transformation_matrix: false
85
+ target_key: kin_plus_xc
86
+ dataset_statistics:
87
+ _target_: mldft.ml.preprocess.dataset_statistics.DatasetStatistics
88
+ path: ${oc.env:DFT_DATA}/${data.dataset_name}/dataset_statistics/dataset_statistics_labels_${data.transforms.name}_${data.datamodule.dataset_kwargs.energy_key}.zarr
89
+ natural_reparametrization:
90
+ orthogonalization: symmetric
91
+ basis_info:
92
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
93
+ path_to_data_info: ${oc.env:DFT_DATA}/${data.dataset_name}/dataset_info.yaml
94
+ atomic_numbers:
95
+ - 1
96
+ - 6
97
+ - 7
98
+ - 8
99
+ - 9
100
+ cutoff: 6.0
101
+ dataset_name: QM9_perturbed_fock
102
+ model:
103
+ optimizer:
104
+ _target_: torch.optim.AdamW
105
+ _partial_: true
106
+ lr: 7.0e-05
107
+ betas:
108
+ - 0.95
109
+ - 0.99
110
+ weight_decay: 1.0e-10
111
+ loss_function:
112
+ _target_: mldft.ml.models.components.loss_function.WeightedLoss
113
+ energy_loss:
114
+ weight: 0.1
115
+ loss:
116
+ _target_: mldft.ml.models.components.loss_function.EnergyLoss
117
+ loss_function:
118
+ _target_: torch.nn.L1Loss
119
+ reduction: none
120
+ sample_weigher:
121
+ _target_: mldft.ml.models.components.sample_weighers.HasEnergyLabelSampleWeigher
122
+ gradient_loss:
123
+ weight: 0.9
124
+ loss:
125
+ _target_: mldft.ml.models.components.loss_function.EnergyGradientLoss
126
+ loss_function:
127
+ _target_: torch.nn.L1Loss
128
+ reduction: none
129
+ sample_weigher:
130
+ _target_: mldft.ml.models.components.sample_weighers.HasEnergyLabelSampleWeigher
131
+ coefficient_loss:
132
+ weight: 0
133
+ loss:
134
+ _target_: mldft.ml.models.components.loss_function.CoefficientLoss
135
+ loss_function:
136
+ _target_: torch.nn.L1Loss
137
+ reduction: none
138
+ sample_weigher: null
139
+ scheduler:
140
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
141
+ _partial_: true
142
+ T_max: ${trainer.max_epochs}
143
+ eta_min: 0
144
+ last_epoch: -1
145
+ _target_: mldft.ml.models.mldft_module.MLDFTLitModule
146
+ variational: true
147
+ target_key: ${data.target_key}
148
+ compile: false
149
+ basis_info: ${data.basis_info}
150
+ metric_interval: 1
151
+ logging_mixin_interval: 1000
152
+ show_logging_mixins_in_progress_bar: false
153
+ net:
154
+ _target_: mldft.ml.models.components.graphformer.Graphformer
155
+ edge_mlp:
156
+ _target_: mldft.ml.models.components.mlp.MLP
157
+ in_channels: 128
158
+ hidden_channels:
159
+ - 768
160
+ - 32
161
+ activation_layer:
162
+ _target_: hydra.utils.get_class
163
+ path: torch.nn.SiLU
164
+ dropout: 0.0
165
+ energy_mlp:
166
+ _target_: mldft.ml.models.components.mlp.MLP
167
+ in_channels: 768
168
+ hidden_channels:
169
+ - 768
170
+ - 1
171
+ activation_layer:
172
+ _target_: hydra.utils.get_class
173
+ path: torch.nn.SiLU
174
+ dropout: 0.0
175
+ disable_dropout_last_layer: true
176
+ disable_activation_last_layer: true
177
+ disable_norm_last_layer: true
178
+ gbf_module:
179
+ _target_: mldft.ml.models.components.gbf_module.GaussianLayer
180
+ basis_info: ${data.basis_info}
181
+ num_gaussians: 128
182
+ init_radius_range:
183
+ - 0
184
+ - 3
185
+ directed: true
186
+ normalized: true
187
+ node_embedding_module:
188
+ _target_: mldft.ml.models.components.node_embedding.NodeEmbedding.from_basis_info
189
+ basis_info: ${data.basis_info}
190
+ out_channels: 768
191
+ dst_in_channels: 128
192
+ p_hidden_channels: 768
193
+ p_num_layers: 3
194
+ p_activation:
195
+ _target_: hydra.utils.get_class
196
+ path: torch.nn.GELU
197
+ p_dropout: 0.0
198
+ dst_hidden_channels: 768
199
+ dst_num_layers: 3
200
+ dst_activation:
201
+ _target_: hydra.utils.get_class
202
+ path: torch.nn.GELU
203
+ dst_dropout: 0.0
204
+ lambda_co: 10.0
205
+ lambda_mul: 0.02
206
+ use_per_basis_func_shrink_gate: true
207
+ cutoff: null
208
+ gnn_module:
209
+ _target_: mldft.ml.models.components.g3d_stack.G3DStack
210
+ g3d_class:
211
+ _partial_: true
212
+ _target_: mldft.ml.models.components.g3d_layer_tf.G3DLayerTF
213
+ in_reps:
214
+ _target_: tensorframes.reps.Irreps
215
+ irreps: 513x0+85x1
216
+ n_layers: 4
217
+ heads: 32
218
+ edge_dim: 1
219
+ dropout: 0.0
220
+ attention_weight_dropout: 0.0
221
+ mlp_hidden_dim: null
222
+ mlp_activation:
223
+ _target_: hydra.utils.get_class
224
+ path: torch.nn.GELU
225
+ norm_layer_class:
226
+ _target_: torch_geometric.nn.norm.layer_norm.LayerNorm
227
+ _partial_: true
228
+ mode: node
229
+ activation_dropout: 0.0
230
+ cutoff: null
231
+ atom_ref_module:
232
+ _target_: mldft.ml.models.components.atom_ref.AtomRef.from_dataset_statistics
233
+ dataset_statistics: ${data.dataset_statistics}
234
+ weigher_key: has_energy_label
235
+ initial_guess_module:
236
+ _target_: mldft.ml.models.components.initial_guess_delta_module.InitialGuessDeltaModule
237
+ input_size: 768
238
+ basis_info: ${data.basis_info}
239
+ dataset_statistics: ${data.dataset_statistics}
240
+ weigher_key: initial_guess_only
241
+ activation_function:
242
+ _target_: hydra.utils.get_class
243
+ path: torch.nn.GELU
244
+ hidden_layers:
245
+ - 768
246
+ dropout: 0.0
247
+ dimension_wise_rescaling_module:
248
+ _target_: mldft.ml.models.components.dimension_wise_rescaling.DimensionWiseRescaling.from_dataset_statistics
249
+ dataset_statistics: ${data.dataset_statistics}
250
+ weigher_key: has_energy_label
251
+ s_coeff: 50
252
+ s_grad: 0.05
253
+ epsilon: 1.0e-08
254
+ callbacks:
255
+ learning_rate_monitor:
256
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
257
+ model_checkpoint:
258
+ _target_: mldft.ml.callbacks.checkpoint.ModelCheckpointWithPermissions
259
+ dirpath: ${paths.output_dir}/checkpoints
260
+ filename: epoch_{epoch:03d}
261
+ monitor: val_loss/total
262
+ verbose: false
263
+ save_last: true
264
+ save_top_k: 1
265
+ mode: min
266
+ auto_insert_metric_name: false
267
+ save_weights_only: false
268
+ every_n_train_steps: null
269
+ train_time_interval: null
270
+ every_n_epochs: null
271
+ save_on_train_epoch_end: null
272
+ model_summary:
273
+ _target_: mldft.ml.callbacks.SubModelSummary
274
+ max_depth: -1
275
+ path_in_model: net
276
+ rich_progress_bar:
277
+ _target_: lightning.pytorch.callbacks.RichProgressBar
278
+ print_overrides:
279
+ _target_: mldft.ml.callbacks.PrintOverrides
280
+ compact: false
281
+ target_pred_scatters:
282
+ _target_: mldft.ml.callbacks.image_logging.LogTargetPredScatters
283
+ with_atom_ref: auto
284
+ train_timing:
285
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
286
+ initial_interval: ${callbacks.interval}
287
+ val_timing:
288
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
289
+ initial_interval: ${callbacks.interval}
290
+ gradient_scatter:
291
+ _target_: mldft.ml.callbacks.image_logging.LogGradientScatter
292
+ train_timing:
293
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
294
+ initial_interval: ${callbacks.interval}
295
+ val_timing:
296
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
297
+ initial_interval: ${callbacks.interval}
298
+ distance_embeddings:
299
+ _target_: mldft.ml.callbacks.image_logging.LogDistanceEmbeddings
300
+ max_distance: 8.0
301
+ n_distances: 1000
302
+ train_timing:
303
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
304
+ initial_interval: ${callbacks.interval}
305
+ val_timing:
306
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
307
+ initial_interval: ${callbacks.interval}
308
+ molecule_mesh_logging:
309
+ log_initial_guess: true
310
+ log_gradient: true
311
+ log_random_basis_functions: false
312
+ _target_: mldft.ml.callbacks.mesh_logging.LogMolecule
313
+ train_timing:
314
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
315
+ initial_interval: ${callbacks.interval}
316
+ val_timing:
317
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
318
+ initial_interval: ${callbacks.interval}
319
+ custom_scalars:
320
+ _target_: mldft.ml.callbacks.custom_scalars.AddMetricAndLossCustomScalars
321
+ interval: 1000
322
+ logger:
323
+ tensorboard:
324
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
325
+ save_dir: ${paths.output_dir}
326
+ max_queue: 10000
327
+ name: null
328
+ log_graph: false
329
+ default_hp_metric: false
330
+ prefix: ''
331
+ version: ''
332
+ trainer:
333
+ _target_: lightning.pytorch.trainer.Trainer
334
+ default_root_dir: ${paths.output_dir}
335
+ min_epochs: 1
336
+ max_epochs: 90
337
+ log_every_n_steps: 200
338
+ inference_mode: false
339
+ accelerator: auto
340
+ devices: 1
341
+ precision: 32
342
+ check_val_every_n_epoch: 1
343
+ deterministic: false
344
+ paths:
345
+ root_dir: ${oc.env:PROJECT_ROOT}
346
+ data_dir: ${oc.env:DFT_DATA}
347
+ log_dir: ${oc.env:DFT_MODELS}
348
+ output_dir: ${hydra:runtime.output_dir}
349
+ work_dir: ${hydra:runtime.cwd}
350
+ extras:
351
+ ignore_warnings: false
352
+ enforce_tags: true
353
+ print_config: true
354
+ hostname: compgpu7
355
+ local: {}
356
+ git:
357
+ sha: 0dbf4dcea9857269d00de317042686c330a76403
358
+ branch: change_tensorframes_version
359
+ is_dirty: false
trained-on-qm9/hparams_resolved.yaml ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task_name: train
2
+ name: ''
3
+ tags:
4
+ - qm9_perturbed_fock
5
+ - kin_plus_xc
6
+ - graphformer
7
+ train: true
8
+ validate: true
9
+ test: false
10
+ ckpt_path: /export/scratch/ialgroup/dft_str25/models/train/runs/009__str25\qm9_tf__seed-100/checkpoints/last.ckpt
11
+ use_original_settings: null
12
+ weight_ckpt_path: null
13
+ seed: 2274360845
14
+ data:
15
+ datamodule:
16
+ _target_: mldft.ml.data.datamodule.OFDataModule
17
+ transforms:
18
+ cached_transforms:
19
+ name: local_frames_global_symmetric_natrep
20
+ additional_pre_transforms:
21
+ - _target_: mldft.ml.data.components.convert_transforms.AddOverlapMatrix
22
+ basis_info:
23
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
24
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_info.yaml
25
+ atomic_numbers:
26
+ - 1
27
+ - 6
28
+ - 7
29
+ - 8
30
+ - 9
31
+ transforms:
32
+ - _target_: mldft.ml.data.components.basis_transforms.ToLocalFrames
33
+ sparse: false
34
+ - _target_: mldft.ml.data.components.basis_transforms.ToGlobalNatRep
35
+ orthogonalization: symmetric
36
+ _target_: mldft.ml.data.components.basis_transforms.MasterTransformation
37
+ name: local_frames_global_symmetric_natrep
38
+ use_cached_data: true
39
+ pre_transforms:
40
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
41
+ float_dtype: torch.float64
42
+ - _target_: mldft.ml.data.components.convert_transforms.ProjectGradient
43
+ - _target_: mldft.ml.data.components.convert_transforms.AddFullEdgeIndex
44
+ - _target_: mldft.ml.data.components.basis_transforms.AddLocalFrames
45
+ basis_transforms: []
46
+ post_transforms:
47
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
48
+ add_transformation_matrix: false
49
+ split_file: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/split.pkl
50
+ data_dir: /export/scratch/ialgroup/dft_data
51
+ basis_info:
52
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
53
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_info.yaml
54
+ atomic_numbers:
55
+ - 1
56
+ - 6
57
+ - 7
58
+ - 8
59
+ - 9
60
+ batch_size: 128
61
+ num_workers: 8
62
+ pin_memory: false
63
+ shuffle_train: true
64
+ shuffle_val: false
65
+ shuffle_test: false
66
+ dataset_kwargs:
67
+ add_irreps: true
68
+ cache_in_memory: false
69
+ energy_key: e_kin_plus_xc
70
+ gradient_key: grad_kin_plus_xc
71
+ limit_scf_iterations:
72
+ - 6
73
+ - 7
74
+ - 8
75
+ - 9
76
+ - 10
77
+ - 11
78
+ - 12
79
+ - 13
80
+ - 14
81
+ - 15
82
+ - 16
83
+ - 17
84
+ - 18
85
+ - 19
86
+ - 20
87
+ - 21
88
+ - 22
89
+ - 23
90
+ - 24
91
+ - 25
92
+ - 26
93
+ - -1
94
+ keep_initial_guess: false
95
+ dataloader_kwargs:
96
+ follow_batch:
97
+ - coeffs
98
+ - atomic_numbers
99
+ list_keys: null
100
+ transforms:
101
+ cached_transforms:
102
+ name: local_frames_global_symmetric_natrep
103
+ additional_pre_transforms:
104
+ - _target_: mldft.ml.data.components.convert_transforms.AddOverlapMatrix
105
+ basis_info:
106
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
107
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_info.yaml
108
+ atomic_numbers:
109
+ - 1
110
+ - 6
111
+ - 7
112
+ - 8
113
+ - 9
114
+ transforms:
115
+ - _target_: mldft.ml.data.components.basis_transforms.ToLocalFrames
116
+ sparse: false
117
+ - _target_: mldft.ml.data.components.basis_transforms.ToGlobalNatRep
118
+ orthogonalization: symmetric
119
+ _target_: mldft.ml.data.components.basis_transforms.MasterTransformation
120
+ name: local_frames_global_symmetric_natrep
121
+ use_cached_data: true
122
+ pre_transforms:
123
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
124
+ float_dtype: torch.float64
125
+ - _target_: mldft.ml.data.components.convert_transforms.ProjectGradient
126
+ - _target_: mldft.ml.data.components.convert_transforms.AddFullEdgeIndex
127
+ - _target_: mldft.ml.data.components.basis_transforms.AddLocalFrames
128
+ basis_transforms: []
129
+ post_transforms:
130
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
131
+ add_transformation_matrix: false
132
+ target_key: kin_plus_xc
133
+ dataset_statistics:
134
+ _target_: mldft.ml.preprocess.dataset_statistics.DatasetStatistics
135
+ path: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_statistics/dataset_statistics_labels_local_frames_global_symmetric_natrep_e_kin_plus_xc.zarr
136
+ natural_reparametrization:
137
+ orthogonalization: symmetric
138
+ basis_info:
139
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
140
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_info.yaml
141
+ atomic_numbers:
142
+ - 1
143
+ - 6
144
+ - 7
145
+ - 8
146
+ - 9
147
+ cutoff: 6.0
148
+ dataset_name: QM9_perturbed_fock
149
+ model:
150
+ optimizer:
151
+ _target_: torch.optim.AdamW
152
+ _partial_: true
153
+ lr: 7.0e-05
154
+ betas:
155
+ - 0.95
156
+ - 0.99
157
+ weight_decay: 1.0e-10
158
+ loss_function:
159
+ _target_: mldft.ml.models.components.loss_function.WeightedLoss
160
+ energy_loss:
161
+ weight: 0.1
162
+ loss:
163
+ _target_: mldft.ml.models.components.loss_function.EnergyLoss
164
+ loss_function:
165
+ _target_: torch.nn.L1Loss
166
+ reduction: none
167
+ sample_weigher:
168
+ _target_: mldft.ml.models.components.sample_weighers.HasEnergyLabelSampleWeigher
169
+ gradient_loss:
170
+ weight: 0.9
171
+ loss:
172
+ _target_: mldft.ml.models.components.loss_function.EnergyGradientLoss
173
+ loss_function:
174
+ _target_: torch.nn.L1Loss
175
+ reduction: none
176
+ sample_weigher:
177
+ _target_: mldft.ml.models.components.sample_weighers.HasEnergyLabelSampleWeigher
178
+ coefficient_loss:
179
+ weight: 0
180
+ loss:
181
+ _target_: mldft.ml.models.components.loss_function.CoefficientLoss
182
+ loss_function:
183
+ _target_: torch.nn.L1Loss
184
+ reduction: none
185
+ sample_weigher: null
186
+ scheduler:
187
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
188
+ _partial_: true
189
+ T_max: 90
190
+ eta_min: 0
191
+ last_epoch: -1
192
+ _target_: mldft.ml.models.mldft_module.MLDFTLitModule
193
+ variational: true
194
+ target_key: kin_plus_xc
195
+ compile: false
196
+ basis_info:
197
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
198
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_info.yaml
199
+ atomic_numbers:
200
+ - 1
201
+ - 6
202
+ - 7
203
+ - 8
204
+ - 9
205
+ metric_interval: 1
206
+ logging_mixin_interval: 1000
207
+ show_logging_mixins_in_progress_bar: false
208
+ net:
209
+ _target_: mldft.ml.models.components.graphformer.Graphformer
210
+ edge_mlp:
211
+ _target_: mldft.ml.models.components.mlp.MLP
212
+ in_channels: 128
213
+ hidden_channels:
214
+ - 768
215
+ - 32
216
+ activation_layer:
217
+ _target_: hydra.utils.get_class
218
+ path: torch.nn.SiLU
219
+ dropout: 0.0
220
+ energy_mlp:
221
+ _target_: mldft.ml.models.components.mlp.MLP
222
+ in_channels: 768
223
+ hidden_channels:
224
+ - 768
225
+ - 1
226
+ activation_layer:
227
+ _target_: hydra.utils.get_class
228
+ path: torch.nn.SiLU
229
+ dropout: 0.0
230
+ disable_dropout_last_layer: true
231
+ disable_activation_last_layer: true
232
+ disable_norm_last_layer: true
233
+ gbf_module:
234
+ _target_: mldft.ml.models.components.gbf_module.GaussianLayer
235
+ basis_info:
236
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
237
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_info.yaml
238
+ atomic_numbers:
239
+ - 1
240
+ - 6
241
+ - 7
242
+ - 8
243
+ - 9
244
+ num_gaussians: 128
245
+ init_radius_range:
246
+ - 0
247
+ - 3
248
+ directed: true
249
+ normalized: true
250
+ node_embedding_module:
251
+ _target_: mldft.ml.models.components.node_embedding.NodeEmbedding.from_basis_info
252
+ basis_info:
253
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
254
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_info.yaml
255
+ atomic_numbers:
256
+ - 1
257
+ - 6
258
+ - 7
259
+ - 8
260
+ - 9
261
+ out_channels: 768
262
+ dst_in_channels: 128
263
+ p_hidden_channels: 768
264
+ p_num_layers: 3
265
+ p_activation:
266
+ _target_: hydra.utils.get_class
267
+ path: torch.nn.GELU
268
+ p_dropout: 0.0
269
+ dst_hidden_channels: 768
270
+ dst_num_layers: 3
271
+ dst_activation:
272
+ _target_: hydra.utils.get_class
273
+ path: torch.nn.GELU
274
+ dst_dropout: 0.0
275
+ lambda_co: 10.0
276
+ lambda_mul: 0.02
277
+ use_per_basis_func_shrink_gate: true
278
+ cutoff: null
279
+ gnn_module:
280
+ _target_: mldft.ml.models.components.g3d_stack.G3DStack
281
+ g3d_class:
282
+ _partial_: true
283
+ _target_: mldft.ml.models.components.g3d_layer_tf.G3DLayerTF
284
+ in_reps:
285
+ _target_: tensorframes.reps.Irreps
286
+ irreps: 513x0+85x1
287
+ n_layers: 4
288
+ heads: 32
289
+ edge_dim: 1
290
+ dropout: 0.0
291
+ attention_weight_dropout: 0.0
292
+ mlp_hidden_dim: null
293
+ mlp_activation:
294
+ _target_: hydra.utils.get_class
295
+ path: torch.nn.GELU
296
+ norm_layer_class:
297
+ _target_: torch_geometric.nn.norm.layer_norm.LayerNorm
298
+ _partial_: true
299
+ mode: node
300
+ activation_dropout: 0.0
301
+ cutoff: null
302
+ atom_ref_module:
303
+ _target_: mldft.ml.models.components.atom_ref.AtomRef.from_dataset_statistics
304
+ dataset_statistics:
305
+ _target_: mldft.ml.preprocess.dataset_statistics.DatasetStatistics
306
+ path: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_statistics/dataset_statistics_labels_local_frames_global_symmetric_natrep_e_kin_plus_xc.zarr
307
+ weigher_key: has_energy_label
308
+ initial_guess_module:
309
+ _target_: mldft.ml.models.components.initial_guess_delta_module.InitialGuessDeltaModule
310
+ input_size: 768
311
+ basis_info:
312
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
313
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_info.yaml
314
+ atomic_numbers:
315
+ - 1
316
+ - 6
317
+ - 7
318
+ - 8
319
+ - 9
320
+ dataset_statistics:
321
+ _target_: mldft.ml.preprocess.dataset_statistics.DatasetStatistics
322
+ path: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_statistics/dataset_statistics_labels_local_frames_global_symmetric_natrep_e_kin_plus_xc.zarr
323
+ weigher_key: initial_guess_only
324
+ activation_function:
325
+ _target_: hydra.utils.get_class
326
+ path: torch.nn.GELU
327
+ hidden_layers:
328
+ - 768
329
+ dropout: 0.0
330
+ dimension_wise_rescaling_module:
331
+ _target_: mldft.ml.models.components.dimension_wise_rescaling.DimensionWiseRescaling.from_dataset_statistics
332
+ dataset_statistics:
333
+ _target_: mldft.ml.preprocess.dataset_statistics.DatasetStatistics
334
+ path: /export/scratch/ialgroup/dft_data/QM9_perturbed_fock/dataset_statistics/dataset_statistics_labels_local_frames_global_symmetric_natrep_e_kin_plus_xc.zarr
335
+ weigher_key: has_energy_label
336
+ s_coeff: 50
337
+ s_grad: 0.05
338
+ epsilon: 1.0e-08
339
+ callbacks:
340
+ learning_rate_monitor:
341
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
342
+ model_checkpoint:
343
+ _target_: mldft.ml.callbacks.checkpoint.ModelCheckpointWithPermissions
344
+ dirpath: /export/scratch/ialgroup/dft_str25/models/train/runs/088__from_checkpoint_009__str25\qm9_tf/checkpoints
345
+ filename: epoch_{epoch:03d}
346
+ monitor: val_loss/total
347
+ verbose: false
348
+ save_last: true
349
+ save_top_k: 1
350
+ mode: min
351
+ auto_insert_metric_name: false
352
+ save_weights_only: false
353
+ every_n_train_steps: null
354
+ train_time_interval: null
355
+ every_n_epochs: null
356
+ save_on_train_epoch_end: null
357
+ model_summary:
358
+ _target_: mldft.ml.callbacks.SubModelSummary
359
+ max_depth: -1
360
+ path_in_model: net
361
+ rich_progress_bar:
362
+ _target_: lightning.pytorch.callbacks.RichProgressBar
363
+ print_overrides:
364
+ _target_: mldft.ml.callbacks.PrintOverrides
365
+ compact: false
366
+ target_pred_scatters:
367
+ _target_: mldft.ml.callbacks.image_logging.LogTargetPredScatters
368
+ with_atom_ref: auto
369
+ train_timing:
370
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
371
+ initial_interval: 1000
372
+ val_timing:
373
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
374
+ initial_interval: 1000
375
+ gradient_scatter:
376
+ _target_: mldft.ml.callbacks.image_logging.LogGradientScatter
377
+ train_timing:
378
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
379
+ initial_interval: 1000
380
+ val_timing:
381
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
382
+ initial_interval: 1000
383
+ distance_embeddings:
384
+ _target_: mldft.ml.callbacks.image_logging.LogDistanceEmbeddings
385
+ max_distance: 8.0
386
+ n_distances: 1000
387
+ train_timing:
388
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
389
+ initial_interval: 1000
390
+ val_timing:
391
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
392
+ initial_interval: 1000
393
+ molecule_mesh_logging:
394
+ log_initial_guess: true
395
+ log_gradient: true
396
+ log_random_basis_functions: false
397
+ _target_: mldft.ml.callbacks.mesh_logging.LogMolecule
398
+ train_timing:
399
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
400
+ initial_interval: 1000
401
+ val_timing:
402
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
403
+ initial_interval: 1000
404
+ custom_scalars:
405
+ _target_: mldft.ml.callbacks.custom_scalars.AddMetricAndLossCustomScalars
406
+ interval: 1000
407
+ logger:
408
+ tensorboard:
409
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
410
+ save_dir: /export/scratch/ialgroup/dft_str25/models/train/runs/088__from_checkpoint_009__str25\qm9_tf
411
+ max_queue: 10000
412
+ name: null
413
+ log_graph: false
414
+ default_hp_metric: false
415
+ prefix: ''
416
+ version: ''
417
+ trainer:
418
+ _target_: lightning.pytorch.trainer.Trainer
419
+ default_root_dir: /export/scratch/ialgroup/dft_str25/models/train/runs/088__from_checkpoint_009__str25\qm9_tf
420
+ min_epochs: 1
421
+ max_epochs: 90
422
+ log_every_n_steps: 200
423
+ inference_mode: false
424
+ accelerator: auto
425
+ devices: 1
426
+ precision: 32
427
+ check_val_every_n_epoch: 1
428
+ deterministic: false
429
+ paths:
430
+ root_dir: /export/home/mickler/sciai-dft
431
+ data_dir: /export/scratch/ialgroup/dft_data
432
+ log_dir: /export/scratch/ialgroup/dft_str25/models
433
+ output_dir: /export/scratch/ialgroup/dft_str25/models/train/runs/088__from_checkpoint_009__str25\qm9_tf
434
+ work_dir: /export/home/mickler/sciai-dft
435
+ extras:
436
+ ignore_warnings: false
437
+ enforce_tags: true
438
+ print_config: true
439
+ hostname: compgpu7
440
+ local: {}
441
+ git:
442
+ sha: 0dbf4dcea9857269d00de317042686c330a76403
443
+ branch: change_tensorframes_version
444
+ is_dirty: false
trained-on-qm9/trained-on-qm9.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9759da26660c619de9c3bbf4c2dc164343ee90e08e22b3fcdcc9682dacb9bd09
3
+ size 224663166
trained-on-qmugs/hparams.yaml ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task_name: train
2
+ name: ''
3
+ tags:
4
+ - qmugs_bin0_qm9_perturbed_fock
5
+ - kin_plus_xc
6
+ - graphformer
7
+ train: true
8
+ validate: true
9
+ test: false
10
+ ckpt_path: null
11
+ use_original_settings: null
12
+ weight_ckpt_path: /export/scratch/ialgroup/dft_str25/models/train/runs/110__from_checkpoint_063__str25\qmugs_hierarc_tf/checkpoints/last.ckpt
13
+ seed: 292311302
14
+ data:
15
+ datamodule:
16
+ _target_: mldft.ml.data.datamodule.OFDataModule
17
+ transforms: ${data.transforms}
18
+ split_file: ${oc.env:DFT_DATA}/${data.dataset_name}/split.pkl
19
+ data_dir: ${oc.env:DFT_DATA}
20
+ basis_info: ${data.basis_info}
21
+ batch_size: 128
22
+ num_workers: 32
23
+ pin_memory: false
24
+ shuffle_train: true
25
+ shuffle_val: false
26
+ shuffle_test: false
27
+ dataset_kwargs:
28
+ add_irreps: true
29
+ cache_in_memory: false
30
+ energy_key: e_${data.target_key}
31
+ gradient_key: grad_${data.target_key}
32
+ limit_scf_iterations:
33
+ - 6
34
+ - 7
35
+ - 8
36
+ - 9
37
+ - 10
38
+ - 11
39
+ - 12
40
+ - 13
41
+ - 14
42
+ - 15
43
+ - 16
44
+ - 17
45
+ - 18
46
+ - 19
47
+ - 20
48
+ - 21
49
+ - 22
50
+ - 23
51
+ - 24
52
+ - 25
53
+ - 26
54
+ - -1
55
+ keep_initial_guess: false
56
+ dataloader_kwargs:
57
+ follow_batch:
58
+ - coeffs
59
+ - atomic_numbers
60
+ list_keys: null
61
+ transforms:
62
+ cached_transforms:
63
+ name: local_frames_global_${data.natural_reparametrization.orthogonalization}_natrep
64
+ additional_pre_transforms:
65
+ - _target_: mldft.ml.data.components.convert_transforms.AddOverlapMatrix
66
+ basis_info: ${data.basis_info}
67
+ transforms:
68
+ - _target_: mldft.ml.data.components.basis_transforms.ToLocalFrames
69
+ sparse: false
70
+ - _target_: mldft.ml.data.components.basis_transforms.ToGlobalNatRep
71
+ orthogonalization: ${data.natural_reparametrization.orthogonalization}
72
+ _target_: mldft.ml.data.components.basis_transforms.MasterTransformation
73
+ name: local_frames_global_${data.natural_reparametrization.orthogonalization}_natrep
74
+ use_cached_data: true
75
+ pre_transforms:
76
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
77
+ float_dtype: torch.float64
78
+ - _target_: mldft.ml.data.components.convert_transforms.ProjectGradient
79
+ - _target_: mldft.ml.data.components.convert_transforms.AddRadiusEdgeIndex
80
+ radius: ${data.cutoff}
81
+ - _target_: mldft.ml.data.components.basis_transforms.AddLocalFrames
82
+ basis_transforms: []
83
+ post_transforms:
84
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
85
+ add_transformation_matrix: false
86
+ target_key: kin_plus_xc
87
+ dataset_statistics:
88
+ _target_: mldft.ml.preprocess.dataset_statistics.DatasetStatistics
89
+ path: ${oc.env:DFT_DATA}/${data.dataset_name}/dataset_statistics/dataset_statistics_labels_${data.transforms.name}_${data.datamodule.dataset_kwargs.energy_key}.zarr
90
+ natural_reparametrization:
91
+ orthogonalization: symmetric
92
+ basis_info:
93
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
94
+ path_to_data_info: ${oc.env:DFT_DATA}/${data.dataset_name}/dataset_info.yaml
95
+ atomic_numbers:
96
+ - 1
97
+ - 6
98
+ - 7
99
+ - 8
100
+ - 9
101
+ cutoff: 6.0
102
+ cutoff_start: 0.0
103
+ dataset_name: QMUGSBin0_perturbed_fock
104
+ model:
105
+ optimizer:
106
+ _target_: torch.optim.AdamW
107
+ _partial_: true
108
+ lr: 1.0e-05
109
+ betas:
110
+ - 0.95
111
+ - 0.99
112
+ weight_decay: 1.0e-10
113
+ loss_function:
114
+ _target_: mldft.ml.models.components.loss_function.WeightedLoss
115
+ energy_loss:
116
+ weight: 0.1
117
+ loss:
118
+ _target_: mldft.ml.models.components.loss_function.EnergyLoss
119
+ loss_function:
120
+ _target_: torch.nn.L1Loss
121
+ reduction: none
122
+ sample_weigher:
123
+ _target_: mldft.ml.models.components.sample_weighers.HasEnergyLabelSampleWeigher
124
+ gradient_loss:
125
+ weight: 0.9
126
+ loss:
127
+ _target_: mldft.ml.models.components.loss_function.EnergyGradientLoss
128
+ loss_function:
129
+ _target_: torch.nn.L1Loss
130
+ reduction: none
131
+ sample_weigher:
132
+ _target_: mldft.ml.models.components.sample_weighers.HasEnergyLabelSampleWeigher
133
+ coefficient_loss:
134
+ weight: 0
135
+ loss:
136
+ _target_: mldft.ml.models.components.loss_function.CoefficientLoss
137
+ loss_function:
138
+ _target_: torch.nn.L1Loss
139
+ reduction: none
140
+ sample_weigher: null
141
+ scheduler:
142
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
143
+ _partial_: true
144
+ T_max: ${trainer.max_epochs}
145
+ eta_min: 0
146
+ last_epoch: -1
147
+ _target_: mldft.ml.models.mldft_module.MLDFTLitModule
148
+ variational: true
149
+ target_key: ${data.target_key}
150
+ compile: false
151
+ basis_info: ${data.basis_info}
152
+ metric_interval: 1
153
+ logging_mixin_interval: 1000
154
+ show_logging_mixins_in_progress_bar: false
155
+ net:
156
+ _target_: mldft.ml.models.components.graphformer.Graphformer
157
+ edge_mlp:
158
+ _target_: mldft.ml.models.components.mlp.MLP
159
+ in_channels: 128
160
+ hidden_channels:
161
+ - 768
162
+ - 32
163
+ activation_layer:
164
+ _target_: hydra.utils.get_class
165
+ path: torch.nn.SiLU
166
+ dropout: 0.0
167
+ energy_mlp:
168
+ _target_: mldft.ml.models.components.graphformer.MLPStack
169
+ in_channels: 768
170
+ hidden_channels:
171
+ - 768
172
+ - 1
173
+ activation_layer:
174
+ _target_: hydra.utils.get_class
175
+ path: torch.nn.SiLU
176
+ dropout: 0.0
177
+ disable_dropout_last_layer: true
178
+ disable_activation_last_layer: true
179
+ disable_norm_last_layer: true
180
+ mlp_class:
181
+ _partial_: true
182
+ _target_: mldft.ml.models.components.mlp.MLP
183
+ n_mlps: 4
184
+ gbf_module:
185
+ _target_: mldft.ml.models.components.gbf_module.GaussianLayer
186
+ basis_info: ${data.basis_info}
187
+ num_gaussians: 128
188
+ init_radius_range:
189
+ - 0
190
+ - 3
191
+ directed: true
192
+ normalized: true
193
+ node_embedding_module:
194
+ _target_: mldft.ml.models.components.node_embedding.NodeEmbedding.from_basis_info
195
+ basis_info: ${data.basis_info}
196
+ out_channels: 768
197
+ dst_in_channels: 128
198
+ p_hidden_channels: 768
199
+ p_num_layers: 3
200
+ p_activation:
201
+ _target_: hydra.utils.get_class
202
+ path: torch.nn.GELU
203
+ p_dropout: 0.0
204
+ dst_hidden_channels: 768
205
+ dst_num_layers: 3
206
+ dst_activation:
207
+ _target_: hydra.utils.get_class
208
+ path: torch.nn.GELU
209
+ dst_dropout: 0.0
210
+ lambda_co: 10.0
211
+ lambda_mul: 0.02
212
+ use_per_basis_func_shrink_gate: true
213
+ cutoff: null
214
+ gnn_module:
215
+ _target_: mldft.ml.models.components.g3d_stack.G3DStack
216
+ g3d_class:
217
+ _partial_: true
218
+ _target_: mldft.ml.models.components.g3d_layer_tf.G3DLayerTF
219
+ in_reps:
220
+ _target_: tensorframes.reps.Irreps
221
+ irreps: 513x0+85x1
222
+ n_layers: 8
223
+ heads: 32
224
+ edge_dim: 1
225
+ dropout: 0.0
226
+ attention_weight_dropout: 0.0
227
+ mlp_hidden_dim: null
228
+ mlp_activation:
229
+ _target_: hydra.utils.get_class
230
+ path: torch.nn.GELU
231
+ norm_layer_class:
232
+ _target_: torch_geometric.nn.norm.layer_norm.LayerNorm
233
+ _partial_: true
234
+ mode: node
235
+ activation_dropout: 0.0
236
+ cutoff: null
237
+ energy_readout_every: 2
238
+ atom_ref_module:
239
+ _target_: mldft.ml.models.components.atom_ref.AtomRef.from_dataset_statistics
240
+ dataset_statistics: ${data.dataset_statistics}
241
+ weigher_key: has_energy_label
242
+ initial_guess_module:
243
+ _target_: mldft.ml.models.components.initial_guess_delta_module.InitialGuessDeltaModule
244
+ input_size: 768
245
+ basis_info: ${data.basis_info}
246
+ dataset_statistics: ${data.dataset_statistics}
247
+ weigher_key: initial_guess_only
248
+ activation_function:
249
+ _target_: hydra.utils.get_class
250
+ path: torch.nn.GELU
251
+ hidden_layers:
252
+ - 768
253
+ dropout: 0.0
254
+ dimension_wise_rescaling_module:
255
+ _target_: mldft.ml.models.components.dimension_wise_rescaling.DimensionWiseRescaling.from_dataset_statistics
256
+ dataset_statistics: ${data.dataset_statistics}
257
+ weigher_key: has_energy_label
258
+ s_coeff: 50
259
+ s_grad: 0.05
260
+ epsilon: 1.0e-08
261
+ callbacks:
262
+ learning_rate_monitor:
263
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
264
+ model_checkpoint:
265
+ _target_: mldft.ml.callbacks.checkpoint.ModelCheckpointWithPermissions
266
+ dirpath: ${paths.output_dir}/checkpoints
267
+ filename: epoch_{epoch:03d}
268
+ monitor: val_loss/total
269
+ verbose: false
270
+ save_last: true
271
+ save_top_k: 1
272
+ mode: min
273
+ auto_insert_metric_name: false
274
+ save_weights_only: false
275
+ every_n_train_steps: null
276
+ train_time_interval: null
277
+ every_n_epochs: null
278
+ save_on_train_epoch_end: null
279
+ model_summary:
280
+ _target_: mldft.ml.callbacks.SubModelSummary
281
+ max_depth: -1
282
+ path_in_model: net
283
+ rich_progress_bar:
284
+ _target_: lightning.pytorch.callbacks.RichProgressBar
285
+ print_overrides:
286
+ _target_: mldft.ml.callbacks.PrintOverrides
287
+ compact: false
288
+ target_pred_scatters:
289
+ _target_: mldft.ml.callbacks.image_logging.LogTargetPredScatters
290
+ with_atom_ref: auto
291
+ train_timing:
292
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
293
+ initial_interval: ${callbacks.interval}
294
+ val_timing:
295
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
296
+ initial_interval: ${callbacks.interval}
297
+ gradient_scatter:
298
+ _target_: mldft.ml.callbacks.image_logging.LogGradientScatter
299
+ train_timing:
300
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
301
+ initial_interval: ${callbacks.interval}
302
+ val_timing:
303
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
304
+ initial_interval: ${callbacks.interval}
305
+ distance_embeddings:
306
+ _target_: mldft.ml.callbacks.image_logging.LogDistanceEmbeddings
307
+ max_distance: 8.0
308
+ n_distances: 1000
309
+ train_timing:
310
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
311
+ initial_interval: ${callbacks.interval}
312
+ val_timing:
313
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
314
+ initial_interval: ${callbacks.interval}
315
+ molecule_mesh_logging:
316
+ log_initial_guess: true
317
+ log_gradient: true
318
+ log_random_basis_functions: false
319
+ _target_: mldft.ml.callbacks.mesh_logging.LogMolecule
320
+ train_timing:
321
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
322
+ initial_interval: ${callbacks.interval}
323
+ val_timing:
324
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
325
+ initial_interval: ${callbacks.interval}
326
+ custom_scalars:
327
+ _target_: mldft.ml.callbacks.custom_scalars.AddMetricAndLossCustomScalars
328
+ interval: 1000
329
+ logger:
330
+ tensorboard:
331
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
332
+ save_dir: ${paths.output_dir}
333
+ max_queue: 10000
334
+ name: null
335
+ log_graph: false
336
+ default_hp_metric: false
337
+ prefix: ''
338
+ version: ''
339
+ trainer:
340
+ _target_: lightning.pytorch.trainer.Trainer
341
+ default_root_dir: ${paths.output_dir}
342
+ min_epochs: 1
343
+ max_epochs: 30
344
+ log_every_n_steps: 200
345
+ inference_mode: false
346
+ accelerator: auto
347
+ devices: 1
348
+ precision: 32
349
+ check_val_every_n_epoch: 1
350
+ deterministic: false
351
+ paths:
352
+ root_dir: ${oc.env:PROJECT_ROOT}
353
+ data_dir: ${oc.env:DFT_DATA}
354
+ log_dir: ${oc.env:DFT_MODELS}
355
+ output_dir: ${hydra:runtime.output_dir}
356
+ work_dir: ${hydra:runtime.cwd}
357
+ extras:
358
+ ignore_warnings: false
359
+ enforce_tags: true
360
+ print_config: true
361
+ hostname: compgpu11
362
+ local: {}
363
+ git:
364
+ sha: 58990f9002b6e36eb94f874ccb8dc9a3609236ab
365
+ branch: main
366
+ is_dirty: true
trained-on-qmugs/hparams_resolved.yaml ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task_name: train
2
+ name: ''
3
+ tags:
4
+ - qmugs_bin0_qm9_perturbed_fock
5
+ - kin_plus_xc
6
+ - graphformer
7
+ train: true
8
+ validate: true
9
+ test: false
10
+ ckpt_path: null
11
+ use_original_settings: null
12
+ weight_ckpt_path: /export/scratch/ialgroup/dft_str25/models/train/runs/110__from_checkpoint_063__str25\qmugs_hierarc_tf/checkpoints/last.ckpt
13
+ seed: 292311302
14
+ data:
15
+ datamodule:
16
+ _target_: mldft.ml.data.datamodule.OFDataModule
17
+ transforms:
18
+ cached_transforms:
19
+ name: local_frames_global_symmetric_natrep
20
+ additional_pre_transforms:
21
+ - _target_: mldft.ml.data.components.convert_transforms.AddOverlapMatrix
22
+ basis_info:
23
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
24
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_info.yaml
25
+ atomic_numbers:
26
+ - 1
27
+ - 6
28
+ - 7
29
+ - 8
30
+ - 9
31
+ transforms:
32
+ - _target_: mldft.ml.data.components.basis_transforms.ToLocalFrames
33
+ sparse: false
34
+ - _target_: mldft.ml.data.components.basis_transforms.ToGlobalNatRep
35
+ orthogonalization: symmetric
36
+ _target_: mldft.ml.data.components.basis_transforms.MasterTransformation
37
+ name: local_frames_global_symmetric_natrep
38
+ use_cached_data: true
39
+ pre_transforms:
40
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
41
+ float_dtype: torch.float64
42
+ - _target_: mldft.ml.data.components.convert_transforms.ProjectGradient
43
+ - _target_: mldft.ml.data.components.convert_transforms.AddRadiusEdgeIndex
44
+ radius: 6.0
45
+ - _target_: mldft.ml.data.components.basis_transforms.AddLocalFrames
46
+ basis_transforms: []
47
+ post_transforms:
48
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
49
+ add_transformation_matrix: false
50
+ split_file: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/split.pkl
51
+ data_dir: /export/scratch/ialgroup/dft_data
52
+ basis_info:
53
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
54
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_info.yaml
55
+ atomic_numbers:
56
+ - 1
57
+ - 6
58
+ - 7
59
+ - 8
60
+ - 9
61
+ batch_size: 128
62
+ num_workers: 32
63
+ pin_memory: false
64
+ shuffle_train: true
65
+ shuffle_val: false
66
+ shuffle_test: false
67
+ dataset_kwargs:
68
+ add_irreps: true
69
+ cache_in_memory: false
70
+ energy_key: e_kin_plus_xc
71
+ gradient_key: grad_kin_plus_xc
72
+ limit_scf_iterations:
73
+ - 6
74
+ - 7
75
+ - 8
76
+ - 9
77
+ - 10
78
+ - 11
79
+ - 12
80
+ - 13
81
+ - 14
82
+ - 15
83
+ - 16
84
+ - 17
85
+ - 18
86
+ - 19
87
+ - 20
88
+ - 21
89
+ - 22
90
+ - 23
91
+ - 24
92
+ - 25
93
+ - 26
94
+ - -1
95
+ keep_initial_guess: false
96
+ dataloader_kwargs:
97
+ follow_batch:
98
+ - coeffs
99
+ - atomic_numbers
100
+ list_keys: null
101
+ transforms:
102
+ cached_transforms:
103
+ name: local_frames_global_symmetric_natrep
104
+ additional_pre_transforms:
105
+ - _target_: mldft.ml.data.components.convert_transforms.AddOverlapMatrix
106
+ basis_info:
107
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
108
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_info.yaml
109
+ atomic_numbers:
110
+ - 1
111
+ - 6
112
+ - 7
113
+ - 8
114
+ - 9
115
+ transforms:
116
+ - _target_: mldft.ml.data.components.basis_transforms.ToLocalFrames
117
+ sparse: false
118
+ - _target_: mldft.ml.data.components.basis_transforms.ToGlobalNatRep
119
+ orthogonalization: symmetric
120
+ _target_: mldft.ml.data.components.basis_transforms.MasterTransformation
121
+ name: local_frames_global_symmetric_natrep
122
+ use_cached_data: true
123
+ pre_transforms:
124
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
125
+ float_dtype: torch.float64
126
+ - _target_: mldft.ml.data.components.convert_transforms.ProjectGradient
127
+ - _target_: mldft.ml.data.components.convert_transforms.AddRadiusEdgeIndex
128
+ radius: 6.0
129
+ - _target_: mldft.ml.data.components.basis_transforms.AddLocalFrames
130
+ basis_transforms: []
131
+ post_transforms:
132
+ - _target_: mldft.ml.data.components.convert_transforms.ToTorch
133
+ add_transformation_matrix: false
134
+ target_key: kin_plus_xc
135
+ dataset_statistics:
136
+ _target_: mldft.ml.preprocess.dataset_statistics.DatasetStatistics
137
+ path: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_statistics/dataset_statistics_labels_local_frames_global_symmetric_natrep_e_kin_plus_xc.zarr
138
+ natural_reparametrization:
139
+ orthogonalization: symmetric
140
+ basis_info:
141
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
142
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_info.yaml
143
+ atomic_numbers:
144
+ - 1
145
+ - 6
146
+ - 7
147
+ - 8
148
+ - 9
149
+ cutoff: 6.0
150
+ cutoff_start: 0.0
151
+ dataset_name: QMUGSBin0_perturbed_fock
152
+ model:
153
+ optimizer:
154
+ _target_: torch.optim.AdamW
155
+ _partial_: true
156
+ lr: 1.0e-05
157
+ betas:
158
+ - 0.95
159
+ - 0.99
160
+ weight_decay: 1.0e-10
161
+ loss_function:
162
+ _target_: mldft.ml.models.components.loss_function.WeightedLoss
163
+ energy_loss:
164
+ weight: 0.1
165
+ loss:
166
+ _target_: mldft.ml.models.components.loss_function.EnergyLoss
167
+ loss_function:
168
+ _target_: torch.nn.L1Loss
169
+ reduction: none
170
+ sample_weigher:
171
+ _target_: mldft.ml.models.components.sample_weighers.HasEnergyLabelSampleWeigher
172
+ gradient_loss:
173
+ weight: 0.9
174
+ loss:
175
+ _target_: mldft.ml.models.components.loss_function.EnergyGradientLoss
176
+ loss_function:
177
+ _target_: torch.nn.L1Loss
178
+ reduction: none
179
+ sample_weigher:
180
+ _target_: mldft.ml.models.components.sample_weighers.HasEnergyLabelSampleWeigher
181
+ coefficient_loss:
182
+ weight: 0
183
+ loss:
184
+ _target_: mldft.ml.models.components.loss_function.CoefficientLoss
185
+ loss_function:
186
+ _target_: torch.nn.L1Loss
187
+ reduction: none
188
+ sample_weigher: null
189
+ scheduler:
190
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
191
+ _partial_: true
192
+ T_max: 30
193
+ eta_min: 0
194
+ last_epoch: -1
195
+ _target_: mldft.ml.models.mldft_module.MLDFTLitModule
196
+ variational: true
197
+ target_key: kin_plus_xc
198
+ compile: false
199
+ basis_info:
200
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
201
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_info.yaml
202
+ atomic_numbers:
203
+ - 1
204
+ - 6
205
+ - 7
206
+ - 8
207
+ - 9
208
+ metric_interval: 1
209
+ logging_mixin_interval: 1000
210
+ show_logging_mixins_in_progress_bar: false
211
+ net:
212
+ _target_: mldft.ml.models.components.graphformer.Graphformer
213
+ edge_mlp:
214
+ _target_: mldft.ml.models.components.mlp.MLP
215
+ in_channels: 128
216
+ hidden_channels:
217
+ - 768
218
+ - 32
219
+ activation_layer:
220
+ _target_: hydra.utils.get_class
221
+ path: torch.nn.SiLU
222
+ dropout: 0.0
223
+ energy_mlp:
224
+ _target_: mldft.ml.models.components.graphformer.MLPStack
225
+ in_channels: 768
226
+ hidden_channels:
227
+ - 768
228
+ - 1
229
+ activation_layer:
230
+ _target_: hydra.utils.get_class
231
+ path: torch.nn.SiLU
232
+ dropout: 0.0
233
+ disable_dropout_last_layer: true
234
+ disable_activation_last_layer: true
235
+ disable_norm_last_layer: true
236
+ mlp_class:
237
+ _partial_: true
238
+ _target_: mldft.ml.models.components.mlp.MLP
239
+ n_mlps: 4
240
+ gbf_module:
241
+ _target_: mldft.ml.models.components.gbf_module.GaussianLayer
242
+ basis_info:
243
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
244
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_info.yaml
245
+ atomic_numbers:
246
+ - 1
247
+ - 6
248
+ - 7
249
+ - 8
250
+ - 9
251
+ num_gaussians: 128
252
+ init_radius_range:
253
+ - 0
254
+ - 3
255
+ directed: true
256
+ normalized: true
257
+ node_embedding_module:
258
+ _target_: mldft.ml.models.components.node_embedding.NodeEmbedding.from_basis_info
259
+ basis_info:
260
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
261
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_info.yaml
262
+ atomic_numbers:
263
+ - 1
264
+ - 6
265
+ - 7
266
+ - 8
267
+ - 9
268
+ out_channels: 768
269
+ dst_in_channels: 128
270
+ p_hidden_channels: 768
271
+ p_num_layers: 3
272
+ p_activation:
273
+ _target_: hydra.utils.get_class
274
+ path: torch.nn.GELU
275
+ p_dropout: 0.0
276
+ dst_hidden_channels: 768
277
+ dst_num_layers: 3
278
+ dst_activation:
279
+ _target_: hydra.utils.get_class
280
+ path: torch.nn.GELU
281
+ dst_dropout: 0.0
282
+ lambda_co: 10.0
283
+ lambda_mul: 0.02
284
+ use_per_basis_func_shrink_gate: true
285
+ cutoff: null
286
+ gnn_module:
287
+ _target_: mldft.ml.models.components.g3d_stack.G3DStack
288
+ g3d_class:
289
+ _partial_: true
290
+ _target_: mldft.ml.models.components.g3d_layer_tf.G3DLayerTF
291
+ in_reps:
292
+ _target_: tensorframes.reps.Irreps
293
+ irreps: 513x0+85x1
294
+ n_layers: 8
295
+ heads: 32
296
+ edge_dim: 1
297
+ dropout: 0.0
298
+ attention_weight_dropout: 0.0
299
+ mlp_hidden_dim: null
300
+ mlp_activation:
301
+ _target_: hydra.utils.get_class
302
+ path: torch.nn.GELU
303
+ norm_layer_class:
304
+ _target_: torch_geometric.nn.norm.layer_norm.LayerNorm
305
+ _partial_: true
306
+ mode: node
307
+ activation_dropout: 0.0
308
+ cutoff: null
309
+ energy_readout_every: 2
310
+ atom_ref_module:
311
+ _target_: mldft.ml.models.components.atom_ref.AtomRef.from_dataset_statistics
312
+ dataset_statistics:
313
+ _target_: mldft.ml.preprocess.dataset_statistics.DatasetStatistics
314
+ path: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_statistics/dataset_statistics_labels_local_frames_global_symmetric_natrep_e_kin_plus_xc.zarr
315
+ weigher_key: has_energy_label
316
+ initial_guess_module:
317
+ _target_: mldft.ml.models.components.initial_guess_delta_module.InitialGuessDeltaModule
318
+ input_size: 768
319
+ basis_info:
320
+ _target_: mldft.ml.data.components.basis_info.BasisInfo.from_dataset_info_yaml
321
+ path_to_data_info: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_info.yaml
322
+ atomic_numbers:
323
+ - 1
324
+ - 6
325
+ - 7
326
+ - 8
327
+ - 9
328
+ dataset_statistics:
329
+ _target_: mldft.ml.preprocess.dataset_statistics.DatasetStatistics
330
+ path: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_statistics/dataset_statistics_labels_local_frames_global_symmetric_natrep_e_kin_plus_xc.zarr
331
+ weigher_key: initial_guess_only
332
+ activation_function:
333
+ _target_: hydra.utils.get_class
334
+ path: torch.nn.GELU
335
+ hidden_layers:
336
+ - 768
337
+ dropout: 0.0
338
+ dimension_wise_rescaling_module:
339
+ _target_: mldft.ml.models.components.dimension_wise_rescaling.DimensionWiseRescaling.from_dataset_statistics
340
+ dataset_statistics:
341
+ _target_: mldft.ml.preprocess.dataset_statistics.DatasetStatistics
342
+ path: /export/scratch/ialgroup/dft_data/QMUGSBin0_perturbed_fock/dataset_statistics/dataset_statistics_labels_local_frames_global_symmetric_natrep_e_kin_plus_xc.zarr
343
+ weigher_key: has_energy_label
344
+ s_coeff: 50
345
+ s_grad: 0.05
346
+ epsilon: 1.0e-08
347
+ callbacks:
348
+ learning_rate_monitor:
349
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
350
+ model_checkpoint:
351
+ _target_: mldft.ml.callbacks.checkpoint.ModelCheckpointWithPermissions
352
+ dirpath: /export/scratch/ialgroup/dft_str25/models/train/runs/214__num_workers-32__qmugs_bin0_perturbed_fock__str25\qmugs_hard_cutoff_hierarc_tf__lr-1e-5__max_epochs-30__from_weight_checkpoint_110/checkpoints
353
+ filename: epoch_{epoch:03d}
354
+ monitor: val_loss/total
355
+ verbose: false
356
+ save_last: true
357
+ save_top_k: 1
358
+ mode: min
359
+ auto_insert_metric_name: false
360
+ save_weights_only: false
361
+ every_n_train_steps: null
362
+ train_time_interval: null
363
+ every_n_epochs: null
364
+ save_on_train_epoch_end: null
365
+ model_summary:
366
+ _target_: mldft.ml.callbacks.SubModelSummary
367
+ max_depth: -1
368
+ path_in_model: net
369
+ rich_progress_bar:
370
+ _target_: lightning.pytorch.callbacks.RichProgressBar
371
+ print_overrides:
372
+ _target_: mldft.ml.callbacks.PrintOverrides
373
+ compact: false
374
+ target_pred_scatters:
375
+ _target_: mldft.ml.callbacks.image_logging.LogTargetPredScatters
376
+ with_atom_ref: auto
377
+ train_timing:
378
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
379
+ initial_interval: 1000
380
+ val_timing:
381
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
382
+ initial_interval: 1000
383
+ gradient_scatter:
384
+ _target_: mldft.ml.callbacks.image_logging.LogGradientScatter
385
+ train_timing:
386
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
387
+ initial_interval: 1000
388
+ val_timing:
389
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
390
+ initial_interval: 1000
391
+ distance_embeddings:
392
+ _target_: mldft.ml.callbacks.image_logging.LogDistanceEmbeddings
393
+ max_distance: 8.0
394
+ n_distances: 1000
395
+ train_timing:
396
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
397
+ initial_interval: 1000
398
+ val_timing:
399
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
400
+ initial_interval: 1000
401
+ molecule_mesh_logging:
402
+ log_initial_guess: true
403
+ log_gradient: true
404
+ log_random_basis_functions: false
405
+ _target_: mldft.ml.callbacks.mesh_logging.LogMolecule
406
+ train_timing:
407
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
408
+ initial_interval: 1000
409
+ val_timing:
410
+ _target_: mldft.ml.callbacks.timing.EveryIncreasingInterval
411
+ initial_interval: 1000
412
+ custom_scalars:
413
+ _target_: mldft.ml.callbacks.custom_scalars.AddMetricAndLossCustomScalars
414
+ interval: 1000
415
+ logger:
416
+ tensorboard:
417
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
418
+ save_dir: /export/scratch/ialgroup/dft_str25/models/train/runs/214__num_workers-32__qmugs_bin0_perturbed_fock__str25\qmugs_hard_cutoff_hierarc_tf__lr-1e-5__max_epochs-30__from_weight_checkpoint_110
419
+ max_queue: 10000
420
+ name: null
421
+ log_graph: false
422
+ default_hp_metric: false
423
+ prefix: ''
424
+ version: ''
425
+ trainer:
426
+ _target_: lightning.pytorch.trainer.Trainer
427
+ default_root_dir: /export/scratch/ialgroup/dft_str25/models/train/runs/214__num_workers-32__qmugs_bin0_perturbed_fock__str25\qmugs_hard_cutoff_hierarc_tf__lr-1e-5__max_epochs-30__from_weight_checkpoint_110
428
+ min_epochs: 1
429
+ max_epochs: 30
430
+ log_every_n_steps: 200
431
+ inference_mode: false
432
+ accelerator: auto
433
+ devices: 1
434
+ precision: 32
435
+ check_val_every_n_epoch: 1
436
+ deterministic: false
437
+ paths:
438
+ root_dir: /export/home/mickler/sciai-dft
439
+ data_dir: /export/scratch/ialgroup/dft_data
440
+ log_dir: /export/scratch/ialgroup/dft_str25/models
441
+ output_dir: /export/scratch/ialgroup/dft_str25/models/train/runs/214__num_workers-32__qmugs_bin0_perturbed_fock__str25\qmugs_hard_cutoff_hierarc_tf__lr-1e-5__max_epochs-30__from_weight_checkpoint_110
442
+ work_dir: /export/home/mickler/sciai-dft
443
+ extras:
444
+ ignore_warnings: false
445
+ enforce_tags: true
446
+ print_config: true
447
+ hostname: compgpu11
448
+ local: {}
449
+ git:
450
+ sha: 58990f9002b6e36eb94f874ccb8dc9a3609236ab
451
+ branch: main
452
+ is_dirty: true
trained-on-qmugs/trained-on-qmugs.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dde9e2e940ebbfcf4c74681b3264c1add71bf3539634e1b81bacffd5bd08be32
3
+ size 417147510