martirossyan commited on
Commit
4fcc7b9
·
verified ·
1 Parent(s): 8279050

Upload 22 files

Browse files
EncDec-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:633dd3695208efed95dce13097427939e74423c993c74cba200ca9e405183108
3
+ size 49644411
EncDec-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
13
+ init_args:
14
+ switch_time: 0.796130965510696
15
+ power: 1.0
16
+ gamma:
17
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
18
+ init_args:
19
+ a: 0.6557615904788995
20
+ switch_time: 0.796130965510696
21
+ power: 1.0
22
+ epsilon: null
23
+ differential_equation_type: "ODE"
24
+ integrator_kwargs:
25
+ method: "euler"
26
+ velocity_annealing_factor: 14.941666601494628
27
+ correct_center_of_mass_motion: true
28
+ # lattice vectors
29
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
30
+ init_args:
31
+ interpolant: omg.si.interpolants.LinearInterpolant
32
+ gamma: null
33
+ epsilon: null
34
+ differential_equation_type: "ODE"
35
+ integrator_kwargs:
36
+ method: "euler"
37
+ velocity_annealing_factor: 0.3178550359129071
38
+ correct_center_of_mass_motion: false
39
+ data_fields:
40
+ # if the order of the data_fields changes,
41
+ # the order of the above StochasticInterpolant inputs must also change
42
+ - "species"
43
+ - "pos"
44
+ - "cell"
45
+ integration_time_steps: 460
46
+ relative_si_costs:
47
+ species_loss: 0.0
48
+ pos_loss_b: 0.8563010628686587
49
+ cell_loss_b: 0.14369893713134133
50
+ sampler:
51
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
52
+ init_args:
53
+ pos_distribution: null
54
+ cell_distribution:
55
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
56
+ init_args:
57
+ dataset_name: perov_5
58
+ species_distribution:
59
+ class_path: omg.sampler.distributions.MirrorData
60
+ model:
61
+ class_path: omg.model.model.Model
62
+ init_args:
63
+ encoder:
64
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
65
+ head:
66
+ class_path: omg.model.heads.pass_through.PassThrough
67
+ time_embedder:
68
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
69
+ init_args:
70
+ dim: 256
71
+ use_min_perm_dist: True
72
+ float_32_matmul_precision: "high"
73
+ validation_mode: "match_rate"
74
+ dataset_name: "perov_5"
75
+ data:
76
+ train_dataset:
77
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
78
+ init_args:
79
+ dataset:
80
+ class_path: omg.datamodule.datamodule.DataModule
81
+ init_args:
82
+ lmdb_paths:
83
+ - "data/perov_5/train.lmdb"
84
+ niggli: True
85
+ val_dataset:
86
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
87
+ init_args:
88
+ dataset:
89
+ class_path: omg.datamodule.datamodule.DataModule
90
+ init_args:
91
+ lmdb_paths:
92
+ - "data/perov_5/val.lmdb"
93
+ niggli: True
94
+ predict_dataset:
95
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
96
+ init_args:
97
+ dataset:
98
+ class_path: omg.datamodule.datamodule.DataModule
99
+ init_args:
100
+ lmdb_paths:
101
+ - "data/perov_5/test.lmdb"
102
+ niggli: True
103
+ batch_size: 128
104
+ num_workers: 4
105
+ pin_memory: True
106
+ persistent_workers: True
107
+ trainer:
108
+ callbacks:
109
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
110
+ init_args:
111
+ filename: "best_val_loss_total"
112
+ save_top_k: 1
113
+ monitor: "val_loss_total"
114
+ save_weights_only: true
115
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
116
+ init_args:
117
+ filename: "best_val_match_rate"
118
+ save_top_k: 1
119
+ monitor: "match_rate"
120
+ save_weights_only: true
121
+ mode: 'max'
122
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
123
+ init_args:
124
+ filename: "best_val_rmsd"
125
+ save_top_k: 1
126
+ monitor: "mean_rmsd"
127
+ save_weights_only: true
128
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
129
+ init_args:
130
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
131
+ monitor: "val_loss_total"
132
+ every_n_epochs: 100
133
+ save_weights_only: false
134
+ gradient_clip_val: 0.5
135
+ num_sanity_val_steps: 0
136
+ precision: "32-true"
137
+ max_epochs: 6000
138
+ enable_progress_bar: false
139
+ check_val_every_n_epoch: 100
140
+ optimizer:
141
+ class_path: torch.optim.Adam
142
+ init_args:
143
+ lr: 7.808103295004345e-05
EncDec-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d59a9c011ebebaaa0d6e3180f5d433a3d4038bd100365577632cd9880ac8da99
3
+ size 49644411
EncDec-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
13
+ init_args:
14
+ switch_time: 0.6055018323069807
15
+ power: 1.0
16
+ gamma:
17
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
18
+ init_args:
19
+ a: 8.454472851641802
20
+ switch_time: 0.6055018323069807
21
+ power: 1.0
22
+ epsilon:
23
+ class_path: omg.si.epsilon.VanishingEpsilon
24
+ init_args:
25
+ c: 4.609299406421399
26
+ mu: 0.2674947568710694
27
+ sigma: 0.04906444616252471
28
+ differential_equation_type: "SDE"
29
+ integrator_kwargs:
30
+ method: "euler"
31
+ dt: 0.001074273488484323
32
+ velocity_annealing_factor: 14.554387706860773
33
+ correct_center_of_mass_motion: true
34
+ # lattice vectors
35
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
36
+ init_args:
37
+ interpolant: omg.si.interpolants.LinearInterpolant
38
+ gamma:
39
+ class_path: omg.si.gamma.LatentGammaSqrt
40
+ init_args:
41
+ a: 0.1539379702797485
42
+ epsilon: null
43
+ differential_equation_type: "ODE"
44
+ integrator_kwargs:
45
+ method: "euler"
46
+ velocity_annealing_factor: 0.07461560076268103
47
+ correct_center_of_mass_motion: false
48
+ data_fields:
49
+ # if the order of the data_fields changes,
50
+ # the order of the above StochasticInterpolant inputs must also change
51
+ - "species"
52
+ - "pos"
53
+ - "cell"
54
+ integration_time_steps: 930
55
+ relative_si_costs:
56
+ species_loss: 0.0
57
+ pos_loss_b: 0.28276509270307465
58
+ pos_loss_z: 0.7168554318065845
59
+ cell_loss_b: 0.0003794754903409129
60
+ sampler:
61
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
62
+ init_args:
63
+ pos_distribution: null
64
+ cell_distribution:
65
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
66
+ init_args:
67
+ dataset_name: perov_5
68
+ species_distribution:
69
+ class_path: omg.sampler.distributions.MirrorData
70
+ model:
71
+ class_path: omg.model.model.Model
72
+ init_args:
73
+ encoder:
74
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
75
+ head:
76
+ class_path: omg.model.heads.pass_through.PassThrough
77
+ time_embedder:
78
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
79
+ init_args:
80
+ dim: 256
81
+ use_min_perm_dist: True
82
+ float_32_matmul_precision: "high"
83
+ validation_mode: "match_rate"
84
+ dataset_name: "perov_5"
85
+ data:
86
+ train_dataset:
87
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
88
+ init_args:
89
+ dataset:
90
+ class_path: omg.datamodule.datamodule.DataModule
91
+ init_args:
92
+ lmdb_paths:
93
+ - "data/perov_5/train.lmdb"
94
+ niggli: False
95
+ val_dataset:
96
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
97
+ init_args:
98
+ dataset:
99
+ class_path: omg.datamodule.datamodule.DataModule
100
+ init_args:
101
+ lmdb_paths:
102
+ - "data/perov_5/val.lmdb"
103
+ niggli: False
104
+ predict_dataset:
105
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
106
+ init_args:
107
+ dataset:
108
+ class_path: omg.datamodule.datamodule.DataModule
109
+ init_args:
110
+ lmdb_paths:
111
+ - "data/perov_5/test.lmdb"
112
+ niggli: False
113
+ batch_size: 128
114
+ num_workers: 4
115
+ pin_memory: True
116
+ persistent_workers: True
117
+ trainer:
118
+ callbacks:
119
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
120
+ init_args:
121
+ filename: "best_val_loss_total"
122
+ save_top_k: 1
123
+ monitor: "val_loss_total"
124
+ save_weights_only: true
125
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
126
+ init_args:
127
+ filename: "best_val_match_rate"
128
+ save_top_k: 1
129
+ monitor: "match_rate"
130
+ save_weights_only: true
131
+ mode: 'max'
132
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
133
+ init_args:
134
+ filename: "best_val_rmsd"
135
+ save_top_k: 1
136
+ monitor: "mean_rmsd"
137
+ save_weights_only: true
138
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
139
+ init_args:
140
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
141
+ monitor: "val_loss_total"
142
+ every_n_epochs: 100
143
+ save_weights_only: false
144
+ gradient_clip_val: 0.5
145
+ num_sanity_val_steps: 0
146
+ precision: "32-true"
147
+ max_epochs: 6000
148
+ enable_progress_bar: false
149
+ check_val_every_n_epoch: 100
150
+ optimizer:
151
+ class_path: torch.optim.Adam
152
+ init_args:
153
+ lr: 0.0002837109869864481
Linear-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e94ea8faa2a598e29fec87b82cf2f498adce196f020b1217a1078988ad39235
3
+ size 49644411
Linear-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.03386737488191369
16
+ epsilon: null
17
+ differential_equation_type: "ODE"
18
+ integrator_kwargs:
19
+ method: "euler"
20
+ velocity_annealing_factor: 0.007950108070075533
21
+ correct_center_of_mass_motion: true
22
+ # lattice vectors
23
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
24
+ init_args:
25
+ interpolant: omg.si.interpolants.LinearInterpolant
26
+ gamma: null
27
+ epsilon: null
28
+ differential_equation_type: "ODE"
29
+ integrator_kwargs:
30
+ method: "euler"
31
+ velocity_annealing_factor: 12.194837909618993
32
+ correct_center_of_mass_motion: false
33
+ data_fields:
34
+ # if the order of the data_fields changes,
35
+ # the order of the above StochasticInterpolant inputs must also change
36
+ - "species"
37
+ - "pos"
38
+ - "cell"
39
+ integration_time_steps: 820
40
+ relative_si_costs:
41
+ species_loss: 0.0
42
+ pos_loss_b: 0.9724021294519893
43
+ cell_loss_b: 0.0275978705480107
44
+ sampler:
45
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
46
+ init_args:
47
+ pos_distribution: null
48
+ cell_distribution:
49
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
50
+ init_args:
51
+ dataset_name: perov_5
52
+ species_distribution:
53
+ class_path: omg.sampler.distributions.MirrorData
54
+ model:
55
+ class_path: omg.model.model.Model
56
+ init_args:
57
+ encoder:
58
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
59
+ head:
60
+ class_path: omg.model.heads.pass_through.PassThrough
61
+ time_embedder:
62
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
63
+ init_args:
64
+ dim: 256
65
+ use_min_perm_dist: True
66
+ float_32_matmul_precision: "high"
67
+ validation_mode: "match_rate"
68
+ dataset_name: "perov_5"
69
+ data:
70
+ train_dataset:
71
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
72
+ init_args:
73
+ dataset:
74
+ class_path: omg.datamodule.datamodule.DataModule
75
+ init_args:
76
+ lmdb_paths:
77
+ - "data/perov_5/train.lmdb"
78
+ niggli: True
79
+ val_dataset:
80
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
81
+ init_args:
82
+ dataset:
83
+ class_path: omg.datamodule.datamodule.DataModule
84
+ init_args:
85
+ lmdb_paths:
86
+ - "data/perov_5/val.lmdb"
87
+ niggli: True
88
+ predict_dataset:
89
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
90
+ init_args:
91
+ dataset:
92
+ class_path: omg.datamodule.datamodule.DataModule
93
+ init_args:
94
+ lmdb_paths:
95
+ - "data/perov_5/test.lmdb"
96
+ niggli: True
97
+ batch_size: 512
98
+ num_workers: 4
99
+ pin_memory: True
100
+ persistent_workers: True
101
+ trainer:
102
+ callbacks:
103
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
104
+ init_args:
105
+ filename: "best_val_loss_total"
106
+ save_top_k: 1
107
+ monitor: "val_loss_total"
108
+ save_weights_only: true
109
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
110
+ init_args:
111
+ filename: "best_val_match_rate"
112
+ save_top_k: 1
113
+ monitor: "match_rate"
114
+ save_weights_only: true
115
+ mode: 'max'
116
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
117
+ init_args:
118
+ filename: "best_val_rmsd"
119
+ save_top_k: 1
120
+ monitor: "mean_rmsd"
121
+ save_weights_only: true
122
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
123
+ init_args:
124
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
125
+ monitor: "val_loss_total"
126
+ every_n_epochs: 100
127
+ save_weights_only: false
128
+ gradient_clip_val: 0.5
129
+ num_sanity_val_steps: 0
130
+ precision: "32-true"
131
+ max_epochs: 6000
132
+ enable_progress_bar: false
133
+ check_val_every_n_epoch: 100
134
+ optimizer:
135
+ class_path: torch.optim.Adam
136
+ init_args:
137
+ lr: 3.6259796277646535e-05
Linear-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53c7377693f3c24c663e019f2beacbb82e31c77c24111ea420e55fa860a005fb
3
+ size 49644411
Linear-ODE/train.yaml ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
12
+ gamma: null
13
+ epsilon: null
14
+ differential_equation_type: "ODE"
15
+ integrator_kwargs:
16
+ method: "euler"
17
+ velocity_annealing_factor: 0.004755207270677389
18
+ correct_center_of_mass_motion: true
19
+ # lattice vectors
20
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
21
+ init_args:
22
+ interpolant:
23
+ class_path: omg.si.interpolants.EncoderDecoderInterpolant
24
+ init_args:
25
+ switch_time: 0.46351945271978645
26
+ power: 1.0
27
+ gamma:
28
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
29
+ init_args:
30
+ a: 0.8167071952445664
31
+ switch_time: 0.46351945271978645
32
+ power: 1.0
33
+ epsilon: null
34
+ differential_equation_type: "ODE"
35
+ integrator_kwargs:
36
+ method: "euler"
37
+ velocity_annealing_factor: 13.921408921615031
38
+ correct_center_of_mass_motion: false
39
+ data_fields:
40
+ # if the order of the data_fields changes,
41
+ # the order of the above StochasticInterpolant inputs must also change
42
+ - "species"
43
+ - "pos"
44
+ - "cell"
45
+ integration_time_steps: 480
46
+ relative_si_costs:
47
+ species_loss: 0.0
48
+ pos_loss_b: 0.9860929911452281
49
+ cell_loss_b: 0.01390700885477196
50
+ sampler:
51
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
52
+ init_args:
53
+ pos_distribution: null
54
+ cell_distribution:
55
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
56
+ init_args:
57
+ dataset_name: perov_5
58
+ species_distribution:
59
+ class_path: omg.sampler.distributions.MirrorData
60
+ model:
61
+ class_path: omg.model.model.Model
62
+ init_args:
63
+ encoder:
64
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
65
+ head:
66
+ class_path: omg.model.heads.pass_through.PassThrough
67
+ time_embedder:
68
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
69
+ init_args:
70
+ dim: 256
71
+ use_min_perm_dist: True
72
+ float_32_matmul_precision: "high"
73
+ validation_mode: "match_rate"
74
+ dataset_name: "perov_5"
75
+ data:
76
+ train_dataset:
77
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
78
+ init_args:
79
+ dataset:
80
+ class_path: omg.datamodule.datamodule.DataModule
81
+ init_args:
82
+ lmdb_paths:
83
+ - "data/perov_5/train.lmdb"
84
+ niggli: True
85
+ val_dataset:
86
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
87
+ init_args:
88
+ dataset:
89
+ class_path: omg.datamodule.datamodule.DataModule
90
+ init_args:
91
+ lmdb_paths:
92
+ - "data/perov_5/val.lmdb"
93
+ niggli: True
94
+ predict_dataset:
95
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
96
+ init_args:
97
+ dataset:
98
+ class_path: omg.datamodule.datamodule.DataModule
99
+ init_args:
100
+ lmdb_paths:
101
+ - "data/perov_5/test.lmdb"
102
+ niggli: True
103
+ batch_size: 1024
104
+ num_workers: 4
105
+ pin_memory: True
106
+ persistent_workers: True
107
+ trainer:
108
+ callbacks:
109
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
110
+ init_args:
111
+ filename: "best_val_loss_total"
112
+ save_top_k: 1
113
+ monitor: "val_loss_total"
114
+ save_weights_only: true
115
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
116
+ init_args:
117
+ filename: "best_val_match_rate"
118
+ save_top_k: 1
119
+ monitor: "match_rate"
120
+ save_weights_only: true
121
+ mode: 'max'
122
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
123
+ init_args:
124
+ filename: "best_val_rmsd"
125
+ save_top_k: 1
126
+ monitor: "mean_rmsd"
127
+ save_weights_only: true
128
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
129
+ init_args:
130
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
131
+ monitor: "val_loss_total"
132
+ every_n_epochs: 100
133
+ save_weights_only: false
134
+ gradient_clip_val: 0.5
135
+ num_sanity_val_steps: 0
136
+ precision: "32-true"
137
+ max_epochs: 6000
138
+ enable_progress_bar: false
139
+ check_val_every_n_epoch: 100
140
+ optimizer:
141
+ class_path: torch.optim.Adam
142
+ init_args:
143
+ lr: 0.001147361965964576
Linear-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:665ec47274b0e0035cca0d1e675f6edefda932249b49a143fcaa0b7f858412e6
3
+ size 49644411
Linear-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.027547642683482473
16
+ epsilon:
17
+ class_path: omg.si.epsilon.VanishingEpsilon
18
+ init_args:
19
+ c: 8.26092465709134
20
+ mu: 0.1083235196756059
21
+ sigma: 0.03686939437589988
22
+ differential_equation_type: "SDE"
23
+ integrator_kwargs:
24
+ method: "euler"
25
+ dt: 0.001097909756936133
26
+ velocity_annealing_factor: 8.19603285406944
27
+ correct_center_of_mass_motion: true
28
+ # lattice vectors
29
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
30
+ init_args:
31
+ interpolant: omg.si.interpolants.LinearInterpolant
32
+ gamma:
33
+ class_path: omg.si.gamma.LatentGammaSqrt
34
+ init_args:
35
+ a: 0.012871444447488238
36
+ epsilon: null
37
+ differential_equation_type: "ODE"
38
+ integrator_kwargs:
39
+ method: "euler"
40
+ velocity_annealing_factor: 1.4603041330880495
41
+ correct_center_of_mass_motion: false
42
+ data_fields:
43
+ # if the order of the data_fields changes,
44
+ # the order of the above StochasticInterpolant inputs must also change
45
+ - "species"
46
+ - "pos"
47
+ - "cell"
48
+ integration_time_steps: 910
49
+ relative_si_costs:
50
+ species_loss: 0.0
51
+ pos_loss_b: 0.0023778886849979398
52
+ pos_loss_z: 0.9924707469401747
53
+ cell_loss_b: 0.0051513643748272876
54
+ sampler:
55
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
56
+ init_args:
57
+ pos_distribution: null
58
+ cell_distribution:
59
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
60
+ init_args:
61
+ dataset_name: perov_5
62
+ species_distribution:
63
+ class_path: omg.sampler.distributions.MirrorData
64
+ model:
65
+ class_path: omg.model.model.Model
66
+ init_args:
67
+ encoder:
68
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
69
+ head:
70
+ class_path: omg.model.heads.pass_through.PassThrough
71
+ time_embedder:
72
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
73
+ init_args:
74
+ dim: 256
75
+ use_min_perm_dist: True
76
+ float_32_matmul_precision: "high"
77
+ validation_mode: "match_rate"
78
+ dataset_name: "perov_5"
79
+ data:
80
+ train_dataset:
81
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
82
+ init_args:
83
+ dataset:
84
+ class_path: omg.datamodule.datamodule.DataModule
85
+ init_args:
86
+ lmdb_paths:
87
+ - "data/perov_5/train.lmdb"
88
+ niggli: True
89
+ val_dataset:
90
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
91
+ init_args:
92
+ dataset:
93
+ class_path: omg.datamodule.datamodule.DataModule
94
+ init_args:
95
+ lmdb_paths:
96
+ - "data/perov_5/val.lmdb"
97
+ niggli: True
98
+ predict_dataset:
99
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
100
+ init_args:
101
+ dataset:
102
+ class_path: omg.datamodule.datamodule.DataModule
103
+ init_args:
104
+ lmdb_paths:
105
+ - "data/perov_5/test.lmdb"
106
+ niggli: True
107
+ batch_size: 128
108
+ num_workers: 4
109
+ pin_memory: True
110
+ persistent_workers: True
111
+ trainer:
112
+ callbacks:
113
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
114
+ init_args:
115
+ filename: "best_val_loss_total"
116
+ save_top_k: 1
117
+ monitor: "val_loss_total"
118
+ save_weights_only: true
119
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
120
+ init_args:
121
+ filename: "best_val_match_rate"
122
+ save_top_k: 1
123
+ monitor: "match_rate"
124
+ save_weights_only: true
125
+ mode: 'max'
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_rmsd"
129
+ save_top_k: 1
130
+ monitor: "mean_rmsd"
131
+ save_weights_only: true
132
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
133
+ init_args:
134
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
135
+ monitor: "val_loss_total"
136
+ every_n_epochs: 100
137
+ save_weights_only: false
138
+ gradient_clip_val: 0.5
139
+ num_sanity_val_steps: 0
140
+ precision: "32-true"
141
+ max_epochs: 6000
142
+ enable_progress_bar: false
143
+ check_val_every_n_epoch: 100
144
+ optimizer:
145
+ class_path: torch.optim.Adam
146
+ init_args:
147
+ lr: 1.3254493006246477e-05
Trig-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36062e9b35956dd19d3c7c087513c5fe75875a9700775290ae7cd0ed8818ac81
3
+ size 49644411
Trig-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.02648500626802044
16
+ epsilon: null
17
+ differential_equation_type: "ODE"
18
+ integrator_kwargs:
19
+ method: "euler"
20
+ velocity_annealing_factor: 3.8566932544902413
21
+ correct_center_of_mass_motion: true
22
+ # lattice vectors
23
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
24
+ init_args:
25
+ interpolant: omg.si.interpolants.LinearInterpolant
26
+ gamma: null
27
+ epsilon: null
28
+ differential_equation_type: "ODE"
29
+ integrator_kwargs:
30
+ method: "euler"
31
+ velocity_annealing_factor: 14.219455036917472
32
+ correct_center_of_mass_motion: false
33
+ data_fields:
34
+ # if the order of the data_fields changes,
35
+ # the order of the above StochasticInterpolant inputs must also change
36
+ - "species"
37
+ - "pos"
38
+ - "cell"
39
+ integration_time_steps: 970
40
+ relative_si_costs:
41
+ species_loss: 0.0
42
+ pos_loss_b: 0.8133671709485343
43
+ cell_loss_b: 0.1866328290514657
44
+ sampler:
45
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
46
+ init_args:
47
+ pos_distribution: null
48
+ cell_distribution:
49
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
50
+ init_args:
51
+ dataset_name: perov_5
52
+ species_distribution:
53
+ class_path: omg.sampler.distributions.MirrorData
54
+ model:
55
+ class_path: omg.model.model.Model
56
+ init_args:
57
+ encoder:
58
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
59
+ head:
60
+ class_path: omg.model.heads.pass_through.PassThrough
61
+ time_embedder:
62
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
63
+ init_args:
64
+ dim: 256
65
+ use_min_perm_dist: False
66
+ float_32_matmul_precision: "high"
67
+ validation_mode: "match_rate"
68
+ dataset_name: "perov_5"
69
+ data:
70
+ train_dataset:
71
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
72
+ init_args:
73
+ dataset:
74
+ class_path: omg.datamodule.datamodule.DataModule
75
+ init_args:
76
+ lmdb_paths:
77
+ - "data/perov_5/train.lmdb"
78
+ niggli: True
79
+ val_dataset:
80
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
81
+ init_args:
82
+ dataset:
83
+ class_path: omg.datamodule.datamodule.DataModule
84
+ init_args:
85
+ lmdb_paths:
86
+ - "data/perov_5/val.lmdb"
87
+ niggli: True
88
+ predict_dataset:
89
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
90
+ init_args:
91
+ dataset:
92
+ class_path: omg.datamodule.datamodule.DataModule
93
+ init_args:
94
+ lmdb_paths:
95
+ - "data/perov_5/test.lmdb"
96
+ niggli: True
97
+ batch_size: 128
98
+ num_workers: 4
99
+ pin_memory: True
100
+ persistent_workers: True
101
+ trainer:
102
+ callbacks:
103
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
104
+ init_args:
105
+ filename: "best_val_loss_total"
106
+ save_top_k: 1
107
+ monitor: "val_loss_total"
108
+ save_weights_only: true
109
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
110
+ init_args:
111
+ filename: "best_val_match_rate"
112
+ save_top_k: 1
113
+ monitor: "match_rate"
114
+ save_weights_only: true
115
+ mode: 'max'
116
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
117
+ init_args:
118
+ filename: "best_val_rmsd"
119
+ save_top_k: 1
120
+ monitor: "mean_rmsd"
121
+ save_weights_only: true
122
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
123
+ init_args:
124
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
125
+ monitor: "val_loss_total"
126
+ every_n_epochs: 100
127
+ save_weights_only: false
128
+ gradient_clip_val: 0.5
129
+ num_sanity_val_steps: 0
130
+ precision: "32-true"
131
+ max_epochs: 6000
132
+ enable_progress_bar: false
133
+ check_val_every_n_epoch: 100
134
+ optimizer:
135
+ class_path: torch.optim.Adam
136
+ init_args:
137
+ lr: 1.620271269284964e-05
Trig-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c922e442102fe8257aca134fecf0fe5a151289f7def154e0cf09ee0684ba4ca
3
+ size 49644411
Trig-ODE/train.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
12
+ gamma: null
13
+ epsilon: null
14
+ differential_equation_type: "ODE"
15
+ integrator_kwargs:
16
+ method: "euler"
17
+ velocity_annealing_factor: 14.9938835509918
18
+ correct_center_of_mass_motion: true
19
+ # lattice vectors
20
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
21
+ init_args:
22
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
23
+ gamma:
24
+ class_path: omg.si.gamma.LatentGammaSqrt
25
+ init_args:
26
+ a: 0.021443243513445315
27
+ epsilon: null
28
+ differential_equation_type: "ODE"
29
+ integrator_kwargs:
30
+ method: "euler"
31
+ velocity_annealing_factor: 14.973558717968908
32
+ correct_center_of_mass_motion: false
33
+ data_fields:
34
+ # if the order of the data_fields changes,
35
+ # the order of the above StochasticInterpolant inputs must also change
36
+ - "species"
37
+ - "pos"
38
+ - "cell"
39
+ integration_time_steps: 880
40
+ relative_si_costs:
41
+ species_loss: 0.0
42
+ pos_loss_b: 0.9983263145571981
43
+ cell_loss_b: 0.00167368544280187
44
+ sampler:
45
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
46
+ init_args:
47
+ pos_distribution: null
48
+ cell_distribution:
49
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
50
+ init_args:
51
+ dataset_name: perov_5
52
+ species_distribution:
53
+ class_path: omg.sampler.distributions.MirrorData
54
+ model:
55
+ class_path: omg.model.model.Model
56
+ init_args:
57
+ encoder:
58
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
59
+ head:
60
+ class_path: omg.model.heads.pass_through.PassThrough
61
+ time_embedder:
62
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
63
+ init_args:
64
+ dim: 256
65
+ use_min_perm_dist: True
66
+ float_32_matmul_precision: "high"
67
+ validation_mode: "match_rate"
68
+ dataset_name: "perov_5"
69
+ data:
70
+ train_dataset:
71
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
72
+ init_args:
73
+ dataset:
74
+ class_path: omg.datamodule.datamodule.DataModule
75
+ init_args:
76
+ lmdb_paths:
77
+ - "data/perov_5/train.lmdb"
78
+ niggli: False
79
+ val_dataset:
80
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
81
+ init_args:
82
+ dataset:
83
+ class_path: omg.datamodule.datamodule.DataModule
84
+ init_args:
85
+ lmdb_paths:
86
+ - "data/perov_5/val.lmdb"
87
+ niggli: False
88
+ predict_dataset:
89
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
90
+ init_args:
91
+ dataset:
92
+ class_path: omg.datamodule.datamodule.DataModule
93
+ init_args:
94
+ lmdb_paths:
95
+ - "data/perov_5/test.lmdb"
96
+ niggli: False
97
+ batch_size: 256
98
+ num_workers: 4
99
+ pin_memory: True
100
+ persistent_workers: True
101
+ trainer:
102
+ callbacks:
103
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
104
+ init_args:
105
+ filename: "best_val_loss_total"
106
+ save_top_k: 1
107
+ monitor: "val_loss_total"
108
+ save_weights_only: true
109
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
110
+ init_args:
111
+ filename: "best_val_match_rate"
112
+ save_top_k: 1
113
+ monitor: "match_rate"
114
+ save_weights_only: true
115
+ mode: 'max'
116
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
117
+ init_args:
118
+ filename: "best_val_rmsd"
119
+ save_top_k: 1
120
+ monitor: "mean_rmsd"
121
+ save_weights_only: true
122
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
123
+ init_args:
124
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
125
+ monitor: "val_loss_total"
126
+ every_n_epochs: 100
127
+ save_weights_only: false
128
+ gradient_clip_val: 0.5
129
+ num_sanity_val_steps: 0
130
+ precision: "32-true"
131
+ max_epochs: 6000
132
+ enable_progress_bar: false
133
+ check_val_every_n_epoch: 100
134
+ optimizer:
135
+ class_path: torch.optim.Adam
136
+ init_args:
137
+ lr: 4.4871577022001995e-05
Trig-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d79dfab91a8fb1dc38329091496843b097f2f81bc3d879770cd8b2775dc4808b
3
+ size 49644411
Trig-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.06271372569234963
16
+ epsilon:
17
+ class_path: omg.si.epsilon.VanishingEpsilon
18
+ init_args:
19
+ c: 7.478617683255472
20
+ mu: 0.06295065489868475
21
+ sigma: 0.03384344419315302
22
+ differential_equation_type: "SDE"
23
+ integrator_kwargs:
24
+ method: "euler"
25
+ dt: 0.0011101224226877093
26
+ velocity_annealing_factor: 3.4362588970266796
27
+ correct_center_of_mass_motion: true
28
+ # lattice vectors
29
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
30
+ init_args:
31
+ interpolant: omg.si.interpolants.LinearInterpolant
32
+ gamma:
33
+ class_path: omg.si.gamma.LatentGammaSqrt
34
+ init_args:
35
+ a: 0.0508912424117229
36
+ epsilon: null
37
+ differential_equation_type: "ODE"
38
+ integrator_kwargs:
39
+ method: "euler"
40
+ velocity_annealing_factor: 0.03360577590810462
41
+ correct_center_of_mass_motion: false
42
+ data_fields:
43
+ # if the order of the data_fields changes,
44
+ # the order of the above StochasticInterpolant inputs must also change
45
+ - "species"
46
+ - "pos"
47
+ - "cell"
48
+ integration_time_steps: 900
49
+ relative_si_costs:
50
+ species_loss: 0.0
51
+ pos_loss_b: 0.6868298746007011
52
+ pos_loss_z: 0.24887105602907683
53
+ cell_loss_b: 0.064299069370222
54
+ sampler:
55
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
56
+ init_args:
57
+ pos_distribution: null
58
+ cell_distribution:
59
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
60
+ init_args:
61
+ dataset_name: perov_5
62
+ species_distribution:
63
+ class_path: omg.sampler.distributions.MirrorData
64
+ model:
65
+ class_path: omg.model.model.Model
66
+ init_args:
67
+ encoder:
68
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
69
+ head:
70
+ class_path: omg.model.heads.pass_through.PassThrough
71
+ time_embedder:
72
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
73
+ init_args:
74
+ dim: 256
75
+ use_min_perm_dist: True
76
+ float_32_matmul_precision: "high"
77
+ validation_mode: "match_rate"
78
+ dataset_name: "perov_5"
79
+ data:
80
+ train_dataset:
81
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
82
+ init_args:
83
+ dataset:
84
+ class_path: omg.datamodule.datamodule.DataModule
85
+ init_args:
86
+ lmdb_paths:
87
+ - "data/perov_5/train.lmdb"
88
+ niggli: True
89
+ val_dataset:
90
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
91
+ init_args:
92
+ dataset:
93
+ class_path: omg.datamodule.datamodule.DataModule
94
+ init_args:
95
+ lmdb_paths:
96
+ - "data/perov_5/val.lmdb"
97
+ niggli: True
98
+ predict_dataset:
99
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
100
+ init_args:
101
+ dataset:
102
+ class_path: omg.datamodule.datamodule.DataModule
103
+ init_args:
104
+ lmdb_paths:
105
+ - "data/perov_5/test.lmdb"
106
+ niggli: True
107
+ batch_size: 512
108
+ num_workers: 4
109
+ pin_memory: True
110
+ persistent_workers: True
111
+ trainer:
112
+ callbacks:
113
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
114
+ init_args:
115
+ filename: "best_val_loss_total"
116
+ save_top_k: 1
117
+ monitor: "val_loss_total"
118
+ save_weights_only: true
119
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
120
+ init_args:
121
+ filename: "best_val_match_rate"
122
+ save_top_k: 1
123
+ monitor: "match_rate"
124
+ save_weights_only: true
125
+ mode: 'max'
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_rmsd"
129
+ save_top_k: 1
130
+ monitor: "mean_rmsd"
131
+ save_weights_only: true
132
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
133
+ init_args:
134
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
135
+ monitor: "val_loss_total"
136
+ every_n_epochs: 100
137
+ save_weights_only: false
138
+ gradient_clip_val: 0.5
139
+ num_sanity_val_steps: 0
140
+ precision: "32-true"
141
+ max_epochs: 6000
142
+ enable_progress_bar: false
143
+ check_val_every_n_epoch: 100
144
+ optimizer:
145
+ class_path: torch.optim.Adam
146
+ init_args:
147
+ lr: 7.173658889538975e-05
VESBD-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a5ec3b1762902e9fb6fd6f0f5a24e34f66f9bed2e69f94bcb03413b476c4080
3
+ size 49642338
VESBD-ODE/train.yaml ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolantVE
13
+ init_args:
14
+ sigma:
15
+ class_path: omg.si.sigma.GeometricSigma
16
+ init_args:
17
+ sigma_min: 0.007753186833706728
18
+ sigma_max: 0.5165059747015202
19
+ epsilon: null
20
+ differential_equation_type: "ODE"
21
+ integrator_kwargs:
22
+ method: "euler"
23
+ velocity_annealing_factor: 0.0030999124784898413
24
+ correct_center_of_mass_motion: true
25
+ predict_velocity: true
26
+ # lattice vectors
27
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
28
+ init_args:
29
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
30
+ gamma:
31
+ class_path: omg.si.gamma.LatentGammaSqrt
32
+ init_args:
33
+ a: 0.024482789522429726
34
+ epsilon:
35
+ class_path: omg.si.epsilon.VanishingEpsilon
36
+ init_args:
37
+ c: 9.940425570212101
38
+ mu: 0.24041599621265147
39
+ sigma: 0.021132860336543085
40
+ differential_equation_type: "SDE"
41
+ integrator_kwargs:
42
+ method: "euler"
43
+ dt: 0.0026332451961934566
44
+ velocity_annealing_factor: 14.933642154361792
45
+ correct_center_of_mass_motion: false
46
+ data_fields:
47
+ # if the order of the data_fields changes,
48
+ # the order of the above StochasticInterpolant inputs must also change
49
+ - "species"
50
+ - "pos"
51
+ - "cell"
52
+ integration_time_steps: 380
53
+ relative_si_costs:
54
+ species_loss: 0.0
55
+ pos_loss_b: 0.979954187812053
56
+ cell_loss_b: 0.01866918394074503
57
+ cell_loss_z: 0.0013766282472020075
58
+ sampler:
59
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
60
+ init_args:
61
+ pos_distribution:
62
+ class_path: omg.sampler.distributions.NormalDistribution
63
+ init_args:
64
+ scale: 8.955438982782663
65
+ cell_distribution:
66
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
67
+ init_args:
68
+ dataset_name: perov_5
69
+ species_distribution:
70
+ class_path: omg.sampler.distributions.MirrorData
71
+ model:
72
+ class_path: omg.model.model.Model
73
+ init_args:
74
+ encoder:
75
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
76
+ head:
77
+ class_path: omg.model.heads.pass_through.PassThrough
78
+ time_embedder:
79
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
80
+ init_args:
81
+ dim: 256
82
+ use_min_perm_dist: False
83
+ float_32_matmul_precision: "high"
84
+ validation_mode: "match_rate"
85
+ dataset_name: "perov_5"
86
+ data:
87
+ train_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/perov_5/train.lmdb"
95
+ niggli: False
96
+ val_dataset:
97
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
98
+ init_args:
99
+ dataset:
100
+ class_path: omg.datamodule.datamodule.DataModule
101
+ init_args:
102
+ lmdb_paths:
103
+ - "data/perov_5/val.lmdb"
104
+ niggli: False
105
+ predict_dataset:
106
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
107
+ init_args:
108
+ dataset:
109
+ class_path: omg.datamodule.datamodule.DataModule
110
+ init_args:
111
+ lmdb_paths:
112
+ - "data/perov_5/test.lmdb"
113
+ niggli: False
114
+ batch_size: 256
115
+ num_workers: 4
116
+ pin_memory: True
117
+ persistent_workers: True
118
+ trainer:
119
+ callbacks:
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_loss_total"
123
+ save_top_k: 1
124
+ monitor: "val_loss_total"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_match_rate"
129
+ save_top_k: 1
130
+ monitor: "match_rate"
131
+ save_weights_only: true
132
+ mode: 'max'
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_rmsd"
136
+ save_top_k: 1
137
+ monitor: "mean_rmsd"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
142
+ monitor: "val_loss_total"
143
+ every_n_epochs: 100
144
+ save_weights_only: false
145
+ gradient_clip_val: 0.5
146
+ num_sanity_val_steps: 0
147
+ precision: "32-true"
148
+ max_epochs: 6000
149
+ enable_progress_bar: false
150
+ check_val_every_n_epoch: 100
151
+ optimizer:
152
+ class_path: torch.optim.Adam
153
+ init_args:
154
+ lr: 0.0077762908469486665
VPSBD-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4352b9d2fa58fd50f7de927ba905c0cc7685a3e2780da3f3b4a00c550f91ebf2
3
+ size 49644411
VPSBD-ODE/train.yaml ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolant
12
+ epsilon: null
13
+ differential_equation_type: "ODE"
14
+ integrator_kwargs:
15
+ method: "euler"
16
+ velocity_annealing_factor: 12.792912174596323
17
+ correct_center_of_mass_motion: true
18
+ predict_velocity: true
19
+ # lattice vectors
20
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
21
+ init_args:
22
+ interpolant: omg.si.interpolants.ScoreBasedDiffusionModelInterpolant
23
+ epsilon:
24
+ class_path: omg.si.epsilon.VanishingEpsilon
25
+ init_args:
26
+ c: 8.480198053500128
27
+ mu: 0.12906782653832816
28
+ sigma: 0.0485371724887369
29
+ differential_equation_type: "SDE"
30
+ integrator_kwargs:
31
+ method: "euler"
32
+ dt: 0.007736434228718281
33
+ velocity_annealing_factor: 2.690266084902449
34
+ correct_center_of_mass_motion: false
35
+ predict_velocity: true
36
+ data_fields:
37
+ # if the order of the data_fields changes,
38
+ # the order of the above StochasticInterpolant inputs must also change
39
+ - "species"
40
+ - "pos"
41
+ - "cell"
42
+ integration_time_steps: 130
43
+ relative_si_costs:
44
+ species_loss: 0.0
45
+ pos_loss_b: 0.003496793110817246
46
+ cell_loss_b: 0.01211289827585164
47
+ cell_loss_z: 0.9843903086133311
48
+ sampler:
49
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
50
+ init_args:
51
+ pos_distribution:
52
+ class_path: omg.sampler.distributions.NormalDistribution
53
+ init_args:
54
+ scale: 0.2775300948889965
55
+ cell_distribution:
56
+ class_path: omg.sampler.distributions.NormalDistribution
57
+ init_args:
58
+ scale: 0.6055540534879594
59
+ species_distribution:
60
+ class_path: omg.sampler.distributions.MirrorData
61
+ model:
62
+ class_path: omg.model.model.Model
63
+ init_args:
64
+ encoder:
65
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
66
+ head:
67
+ class_path: omg.model.heads.pass_through.PassThrough
68
+ time_embedder:
69
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
70
+ init_args:
71
+ dim: 256
72
+ use_min_perm_dist: True
73
+ float_32_matmul_precision: "high"
74
+ validation_mode: "match_rate"
75
+ dataset_name: "perov_5"
76
+ data:
77
+ train_dataset:
78
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
79
+ init_args:
80
+ dataset:
81
+ class_path: omg.datamodule.datamodule.DataModule
82
+ init_args:
83
+ lmdb_paths:
84
+ - "data/perov_5/train.lmdb"
85
+ niggli: True
86
+ val_dataset:
87
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
88
+ init_args:
89
+ dataset:
90
+ class_path: omg.datamodule.datamodule.DataModule
91
+ init_args:
92
+ lmdb_paths:
93
+ - "data/perov_5/val.lmdb"
94
+ niggli: True
95
+ predict_dataset:
96
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
97
+ init_args:
98
+ dataset:
99
+ class_path: omg.datamodule.datamodule.DataModule
100
+ init_args:
101
+ lmdb_paths:
102
+ - "data/perov_5/test.lmdb"
103
+ niggli: True
104
+ batch_size: 512
105
+ num_workers: 4
106
+ pin_memory: True
107
+ persistent_workers: True
108
+ trainer:
109
+ callbacks:
110
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
111
+ init_args:
112
+ filename: "best_val_loss_total"
113
+ save_top_k: 1
114
+ monitor: "val_loss_total"
115
+ save_weights_only: true
116
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
117
+ init_args:
118
+ filename: "best_val_match_rate"
119
+ save_top_k: 1
120
+ monitor: "match_rate"
121
+ save_weights_only: true
122
+ mode: 'max'
123
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
124
+ init_args:
125
+ filename: "best_val_rmsd"
126
+ save_top_k: 1
127
+ monitor: "mean_rmsd"
128
+ save_weights_only: true
129
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
130
+ init_args:
131
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
132
+ monitor: "val_loss_total"
133
+ every_n_epochs: 100
134
+ save_weights_only: false
135
+ gradient_clip_val: 0.5
136
+ num_sanity_val_steps: 0
137
+ precision: "32-true"
138
+ max_epochs: 6000
139
+ enable_progress_bar: false
140
+ check_val_every_n_epoch: 100
141
+ optimizer:
142
+ class_path: torch.optim.Adam
143
+ init_args:
144
+ lr: 0.0008719662356797908
VPSBD-SDE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:281a567887da1041008b32cb86478969531ea2b18d146be3e5fbe2df66316678
3
+ size 49644411
VPSBD-SDE/train.yaml ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolant
12
+ epsilon:
13
+ class_path: omg.si.epsilon.VanishingEpsilon
14
+ init_args:
15
+ c: 6.705334122560177
16
+ mu: 0.1759894495124853
17
+ sigma: 0.02684743624891644
18
+ differential_equation_type: "SDE"
19
+ integrator_kwargs:
20
+ method: "euler"
21
+ dt: 0.002859598957002163
22
+ velocity_annealing_factor: 11.540982308844075
23
+ correct_center_of_mass_motion: true
24
+ predict_velocity: true
25
+ # lattice vectors
26
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
27
+ init_args:
28
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
29
+ gamma:
30
+ class_path: omg.si.gamma.LatentGammaSqrt
31
+ init_args:
32
+ a: 0.028688550336857962
33
+ epsilon:
34
+ class_path: omg.si.epsilon.VanishingEpsilon
35
+ init_args:
36
+ c: 8.616058241205366
37
+ mu: 0.20683060178158524
38
+ sigma: 0.010467959402930785
39
+ differential_equation_type: "SDE"
40
+ integrator_kwargs:
41
+ method: "euler"
42
+ dt: 0.002859598957002163
43
+ velocity_annealing_factor: 11.528499292207702
44
+ correct_center_of_mass_motion: false
45
+ data_fields:
46
+ # if the order of the data_fields changes,
47
+ # the order of the above StochasticInterpolant inputs must also change
48
+ - "species"
49
+ - "pos"
50
+ - "cell"
51
+ integration_time_steps: 350
52
+ relative_si_costs:
53
+ species_loss: 0.0
54
+ pos_loss_b: 0.2897890800401683
55
+ pos_loss_z: 0.3259349777392057
56
+ cell_loss_b: 0.19601072982998402
57
+ cell_loss_z: 0.18826521239064184
58
+ sampler:
59
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
60
+ init_args:
61
+ pos_distribution:
62
+ class_path: omg.sampler.distributions.NormalDistribution
63
+ init_args:
64
+ scale: 0.12777034312154512
65
+ cell_distribution:
66
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
67
+ init_args:
68
+ dataset_name: perov_5
69
+ species_distribution:
70
+ class_path: omg.sampler.distributions.MirrorData
71
+ model:
72
+ class_path: omg.model.model.Model
73
+ init_args:
74
+ encoder:
75
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
76
+ head:
77
+ class_path: omg.model.heads.pass_through.PassThrough
78
+ time_embedder:
79
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
80
+ init_args:
81
+ dim: 256
82
+ use_min_perm_dist: True
83
+ float_32_matmul_precision: "high"
84
+ validation_mode: "match_rate"
85
+ dataset_name: "perov_5"
86
+ data:
87
+ train_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/perov_5/train.lmdb"
95
+ niggli: False
96
+ val_dataset:
97
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
98
+ init_args:
99
+ dataset:
100
+ class_path: omg.datamodule.datamodule.DataModule
101
+ init_args:
102
+ lmdb_paths:
103
+ - "data/perov_5/val.lmdb"
104
+ niggli: False
105
+ predict_dataset:
106
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
107
+ init_args:
108
+ dataset:
109
+ class_path: omg.datamodule.datamodule.DataModule
110
+ init_args:
111
+ lmdb_paths:
112
+ - "data/perov_5/test.lmdb"
113
+ niggli: False
114
+ batch_size: 128
115
+ num_workers: 4
116
+ pin_memory: True
117
+ persistent_workers: True
118
+ trainer:
119
+ callbacks:
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_loss_total"
123
+ save_top_k: 1
124
+ monitor: "val_loss_total"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_match_rate"
129
+ save_top_k: 1
130
+ monitor: "match_rate"
131
+ save_weights_only: true
132
+ mode: 'max'
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_rmsd"
136
+ save_top_k: 1
137
+ monitor: "mean_rmsd"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
142
+ monitor: "val_loss_total"
143
+ every_n_epochs: 100
144
+ save_weights_only: false
145
+ gradient_clip_val: 0.5
146
+ num_sanity_val_steps: 0
147
+ precision: "32-true"
148
+ max_epochs: 6000
149
+ enable_progress_bar: false
150
+ check_val_every_n_epoch: 100
151
+ optimizer:
152
+ class_path: torch.optim.Adam
153
+ init_args:
154
+ lr: 3.829398871139748e-05