tomegg3 commited on
Commit
ba97b2e
·
verified ·
1 Parent(s): 23e92cb

Upload 22 files

Browse files
EncDec-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b523412c0a025405dbef142ff19d6ae264152a73bc7361a745bd4ed8fc3829e
3
+ size 49646459
EncDec-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
8
+ init_args:
9
+ noise: 0.8465315128078521
10
+ # fractional coordinates
11
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
12
+ init_args:
13
+ interpolant: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
14
+ gamma: omg.si.gamma.LatentGammaEncoderDecoder
15
+ epsilon: null
16
+ differential_equation_type: "ODE"
17
+ integrator_kwargs:
18
+ method: "euler"
19
+ velocity_annealing_factor: 10.274308845621986
20
+ correct_center_of_mass_motion: true
21
+ # lattice vectors
22
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
23
+ init_args:
24
+ interpolant: omg.si.interpolants.LinearInterpolant
25
+ gamma: null
26
+ epsilon: null
27
+ differential_equation_type: "ODE"
28
+ integrator_kwargs:
29
+ method: "euler"
30
+ velocity_annealing_factor: 0.08217016314129522
31
+ correct_center_of_mass_motion: false
32
+ data_fields:
33
+ # if the order of the data_fields changes,
34
+ # the order of the above StochasticInterpolant inputs must also change
35
+ - "species"
36
+ - "pos"
37
+ - "cell"
38
+ integration_time_steps: 840
39
+ relative_si_costs:
40
+ species_loss: 0.2648412544596816
41
+ pos_loss_b: 0.7267924862588087
42
+ cell_loss_b: 0.008366259281509825
43
+ sampler:
44
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
45
+ init_args:
46
+ pos_distribution: null
47
+ cell_distribution:
48
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
49
+ init_args:
50
+ dataset_name: mp_20
51
+ species_distribution:
52
+ class_path: omg.sampler.distributions.MaskDistribution
53
+ model:
54
+ class_path: omg.model.model.Model
55
+ init_args:
56
+ encoder:
57
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
58
+ head:
59
+ class_path: omg.model.heads.pass_through.PassThrough
60
+ time_embedder:
61
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
62
+ init_args:
63
+ dim: 256
64
+ use_min_perm_dist: False
65
+ float_32_matmul_precision: "high"
66
+ validation_mode: "dng_eval"
67
+ dataset_name: "mp_20"
68
+ data:
69
+ train_dataset:
70
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
71
+ init_args:
72
+ dataset:
73
+ class_path: omg.datamodule.datamodule.DataModule
74
+ init_args:
75
+ lmdb_paths:
76
+ - "data/mp_20/train.lmdb"
77
+ niggli: False
78
+ val_dataset:
79
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
80
+ init_args:
81
+ dataset:
82
+ class_path: omg.datamodule.datamodule.DataModule
83
+ init_args:
84
+ lmdb_paths:
85
+ - "data/mp_20/val.lmdb"
86
+ niggli: False
87
+ predict_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/mp_20/test.lmdb"
95
+ niggli: False
96
+ batch_size: 128
97
+ num_workers: 4
98
+ pin_memory: True
99
+ persistent_workers: True
100
+ trainer:
101
+ callbacks:
102
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
103
+ init_args:
104
+ filename: "best_val_loss_total"
105
+ save_top_k: 1
106
+ monitor: "val_loss_total"
107
+ save_weights_only: true
108
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
109
+ init_args:
110
+ filename: "best_val_dng_eval"
111
+ save_top_k: 1
112
+ monitor: "dng_eval"
113
+ save_weights_only: true
114
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
115
+ init_args:
116
+ filename: "best_val_wdist_density"
117
+ save_top_k: 1
118
+ monitor: "wdist_density"
119
+ save_weights_only: true
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_wdist_Nary"
123
+ save_top_k: 1
124
+ monitor: "wdist_Nary"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_wdist_CN"
129
+ save_top_k: 1
130
+ monitor: "wdist_CN"
131
+ save_weights_only: true
132
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
133
+ init_args:
134
+ filename: "best_val_cov_precision"
135
+ save_top_k: 1
136
+ monitor: "cov_precision"
137
+ mode: "max"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ filename: "best_val_cov_recall"
142
+ save_top_k: 1
143
+ monitor: "cov_recall"
144
+ mode: "max"
145
+ save_weights_only: true
146
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
147
+ init_args:
148
+ filename: "best_val_validity"
149
+ save_top_k: 1
150
+ monitor: "validity_rate"
151
+ mode: "max"
152
+ save_weights_only: true
153
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
154
+ init_args:
155
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
156
+ monitor: "val_loss_total"
157
+ every_n_epochs: 100
158
+ save_weights_only: false
159
+ gradient_clip_val: 0.5
160
+ num_sanity_val_steps: 0
161
+ precision: "32-true"
162
+ max_epochs: 10000
163
+ enable_progress_bar: false
164
+ check_val_every_n_epoch: 100
165
+ optimizer:
166
+ class_path: torch.optim.Adam
167
+ init_args:
168
+ lr: 0.00012021943412654004
EncDec-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5242724db2a21bd949624fe805e84986b0f0ccfb229913bf3f9cf79ce5540f3d
3
+ size 148494714
EncDec-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
8
+ init_args:
9
+ noise: 19.77565076697948
10
+ # fractional coordinates
11
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
12
+ init_args:
13
+ interpolant:
14
+ class_path: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
15
+ init_args:
16
+ switch_time: 0.7261434470495144
17
+ power: 0.5
18
+ gamma:
19
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
20
+ init_args:
21
+ a: 0.10379253121526635
22
+ switch_time: 0.7261434470495144
23
+ power: 0.5
24
+ epsilon:
25
+ class_path: omg.si.epsilon.VanishingEpsilon
26
+ init_args:
27
+ c: 7.095366578175936
28
+ mu: 0.18874541537289413
29
+ sigma: 0.020535877713041894
30
+ differential_equation_type: "SDE"
31
+ integrator_kwargs:
32
+ method: "euler"
33
+ dt: 0.0016387520590797067
34
+ velocity_annealing_factor: 7.868988752162741
35
+ correct_center_of_mass_motion: true
36
+ # lattice vectors
37
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
38
+ init_args:
39
+ interpolant: omg.si.interpolants.LinearInterpolant
40
+ gamma:
41
+ class_path: omg.si.gamma.LatentGammaSqrt
42
+ init_args:
43
+ a: 1.6505152121826552
44
+ epsilon: null
45
+ differential_equation_type: "ODE"
46
+ integrator_kwargs:
47
+ method: "euler"
48
+ velocity_annealing_factor: 3.919302532270132
49
+ correct_center_of_mass_motion: false
50
+ data_fields:
51
+ # if the order of the data_fields changes,
52
+ # the order of the above StochasticInterpolant inputs must also change
53
+ - "species"
54
+ - "pos"
55
+ - "cell"
56
+ integration_time_steps: 610
57
+ relative_si_costs:
58
+ species_loss: 0.4341084689667317
59
+ pos_loss_b: 0.21431903999859292
60
+ pos_loss_z: 0.1968226417175204
61
+ cell_loss_b: 0.15474984931715505
62
+ sampler:
63
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
64
+ init_args:
65
+ pos_distribution: null
66
+ cell_distribution:
67
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
68
+ init_args:
69
+ dataset_name: mp_20
70
+ species_distribution:
71
+ class_path: omg.sampler.distributions.MaskDistribution
72
+ model:
73
+ class_path: omg.model.model.Model
74
+ init_args:
75
+ encoder:
76
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
77
+ head:
78
+ class_path: omg.model.heads.pass_through.PassThrough
79
+ time_embedder:
80
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
81
+ init_args:
82
+ dim: 256
83
+ use_min_perm_dist: False
84
+ float_32_matmul_precision: "high"
85
+ validation_mode: "dng_eval"
86
+ dataset_name: "mp_20"
87
+ data:
88
+ train_dataset:
89
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
90
+ init_args:
91
+ dataset:
92
+ class_path: omg.datamodule.datamodule.DataModule
93
+ init_args:
94
+ lmdb_paths:
95
+ - "data/mp_20/train.lmdb"
96
+ niggli: False
97
+ val_dataset:
98
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
99
+ init_args:
100
+ dataset:
101
+ class_path: omg.datamodule.datamodule.DataModule
102
+ init_args:
103
+ lmdb_paths:
104
+ - "data/mp_20/val.lmdb"
105
+ niggli: False
106
+ predict_dataset:
107
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
108
+ init_args:
109
+ dataset:
110
+ class_path: omg.datamodule.datamodule.DataModule
111
+ init_args:
112
+ lmdb_paths:
113
+ - "data/mp_20/test.lmdb"
114
+ niggli: False
115
+ batch_size: 32
116
+ num_workers: 4
117
+ pin_memory: True
118
+ persistent_workers: True
119
+ trainer:
120
+ callbacks:
121
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
122
+ init_args:
123
+ filename: "best_val_loss_total"
124
+ save_top_k: 1
125
+ monitor: "val_loss_total"
126
+ save_weights_only: true
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_dng_eval"
130
+ save_top_k: 1
131
+ monitor: "dng_eval"
132
+ save_weights_only: true
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_wdist_density"
136
+ save_top_k: 1
137
+ monitor: "wdist_density"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ filename: "best_val_wdist_Nary"
142
+ save_top_k: 1
143
+ monitor: "wdist_Nary"
144
+ save_weights_only: true
145
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
146
+ init_args:
147
+ filename: "best_val_wdist_CN"
148
+ save_top_k: 1
149
+ monitor: "wdist_CN"
150
+ save_weights_only: true
151
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
152
+ init_args:
153
+ filename: "best_val_cov_precision"
154
+ save_top_k: 1
155
+ monitor: "cov_precision"
156
+ mode: "max"
157
+ save_weights_only: true
158
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
159
+ init_args:
160
+ filename: "best_val_cov_recall"
161
+ save_top_k: 1
162
+ monitor: "cov_recall"
163
+ mode: "max"
164
+ save_weights_only: true
165
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
166
+ init_args:
167
+ filename: "best_val_validity"
168
+ save_top_k: 1
169
+ monitor: "validity_rate"
170
+ mode: "max"
171
+ save_weights_only: true
172
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
173
+ init_args:
174
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
175
+ monitor: "val_loss_total"
176
+ every_n_epochs: 100
177
+ save_weights_only: false
178
+ gradient_clip_val: 0.5
179
+ gradient_clip_algorithm: "value"
180
+ num_sanity_val_steps: 0
181
+ precision: "32-true"
182
+ max_epochs: 2000
183
+ enable_progress_bar: false
184
+ check_val_every_n_epoch: 100
185
+ optimizer:
186
+ class_path: torch.optim.AdamW
187
+ init_args:
188
+ lr: 3.139610174577985e-05
189
+ weight_decay: 3.560067412494533e-05
190
+ lr_scheduler:
191
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
192
+ init_args:
193
+ T_max: 2000
194
+ eta_min: 1e-07
195
+
Linear-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6280e04b110be03b5bf5b727e34d9c42c3ba3ce58c943a7bea6db8824ddfdc1d
3
+ size 148519290
Linear-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
8
+ init_args:
9
+ noise: 23.870491382634235
10
+ # fractional coordinates
11
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
12
+ init_args:
13
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
14
+ gamma:
15
+ class_path: omg.si.gamma.LatentGammaSqrt
16
+ init_args:
17
+ a: 1.4501684803942854
18
+ epsilon: null
19
+ differential_equation_type: "ODE"
20
+ integrator_kwargs:
21
+ method: "euler"
22
+ velocity_annealing_factor: 14.825022083056373
23
+ correct_center_of_mass_motion: true
24
+ # lattice vectors
25
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
26
+ init_args:
27
+ interpolant:
28
+ class_path: omg.si.interpolants.EncoderDecoderInterpolant
29
+ init_args:
30
+ switch_time: 0.13766881993242777
31
+ power: 1.0
32
+ gamma:
33
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
34
+ init_args:
35
+ a: 7.882046441638109
36
+ switch_time: 0.13766881993242777
37
+ power: 1.0
38
+ epsilon:
39
+ class_path: omg.si.epsilon.VanishingEpsilon
40
+ init_args:
41
+ c: 5.487699104615115
42
+ mu: 0.2899409657474152
43
+ sigma: 0.010062500495585096
44
+ differential_equation_type: "SDE"
45
+ integrator_kwargs:
46
+ method: "euler"
47
+ dt: 0.007736434228718281
48
+ velocity_annealing_factor: 5.9072140831863305
49
+ correct_center_of_mass_motion: false
50
+ data_fields:
51
+ # if the order of the data_fields changes,
52
+ # the order of the above StochasticInterpolant inputs must also change
53
+ - "species"
54
+ - "pos"
55
+ - "cell"
56
+ integration_time_steps: 130
57
+ relative_si_costs:
58
+ species_loss: 0.22163297905676768
59
+ pos_loss_b: 0.7682858351574293
60
+ cell_loss_b: 0.008860420349864338
61
+ cell_loss_z: 0.001220765435938645
62
+ sampler:
63
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
64
+ init_args:
65
+ pos_distribution: null
66
+ cell_distribution:
67
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
68
+ init_args:
69
+ dataset_name: mp_20
70
+ species_distribution:
71
+ class_path: omg.sampler.distributions.MaskDistribution
72
+ model:
73
+ class_path: omg.model.model.Model
74
+ init_args:
75
+ encoder:
76
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
77
+ head:
78
+ class_path: omg.model.heads.pass_through.PassThrough
79
+ time_embedder:
80
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
81
+ init_args:
82
+ dim: 256
83
+ use_min_perm_dist: True
84
+ float_32_matmul_precision: "high"
85
+ validation_mode: "dng_eval"
86
+ dataset_name: "mp_20"
87
+ data:
88
+ train_dataset:
89
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
90
+ init_args:
91
+ dataset:
92
+ class_path: omg.datamodule.datamodule.DataModule
93
+ init_args:
94
+ lmdb_paths:
95
+ - "data/mp_20/train.lmdb"
96
+ niggli: False
97
+ val_dataset:
98
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
99
+ init_args:
100
+ dataset:
101
+ class_path: omg.datamodule.datamodule.DataModule
102
+ init_args:
103
+ lmdb_paths:
104
+ - "data/mp_20/val.lmdb"
105
+ niggli: False
106
+ predict_dataset:
107
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
108
+ init_args:
109
+ dataset:
110
+ class_path: omg.datamodule.datamodule.DataModule
111
+ init_args:
112
+ lmdb_paths:
113
+ - "data/mp_20/test.lmdb"
114
+ niggli: False
115
+ batch_size: 256
116
+ num_workers: 4
117
+ pin_memory: True
118
+ persistent_workers: True
119
+ trainer:
120
+ callbacks:
121
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
122
+ init_args:
123
+ filename: "best_val_loss_total"
124
+ save_top_k: 1
125
+ monitor: "val_loss_total"
126
+ save_weights_only: true
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_dng_eval"
130
+ save_top_k: 1
131
+ monitor: "dng_eval"
132
+ save_weights_only: true
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_wdist_density"
136
+ save_top_k: 1
137
+ monitor: "wdist_density"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ filename: "best_val_wdist_Nary"
142
+ save_top_k: 1
143
+ monitor: "wdist_Nary"
144
+ save_weights_only: true
145
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
146
+ init_args:
147
+ filename: "best_val_wdist_CN"
148
+ save_top_k: 1
149
+ monitor: "wdist_CN"
150
+ save_weights_only: true
151
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
152
+ init_args:
153
+ filename: "best_val_cov_precision"
154
+ save_top_k: 1
155
+ monitor: "cov_precision"
156
+ mode: "max"
157
+ save_weights_only: true
158
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
159
+ init_args:
160
+ filename: "best_val_cov_recall"
161
+ save_top_k: 1
162
+ monitor: "cov_recall"
163
+ mode: "max"
164
+ save_weights_only: true
165
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
166
+ init_args:
167
+ filename: "best_val_validity"
168
+ save_top_k: 1
169
+ monitor: "validity_rate"
170
+ mode: "max"
171
+ save_weights_only: true
172
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
173
+ init_args:
174
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
175
+ monitor: "val_loss_total"
176
+ every_n_epochs: 100
177
+ save_weights_only: false
178
+ gradient_clip_val: 0.5
179
+ gradient_clip_algorithm: "value"
180
+ num_sanity_val_steps: 0
181
+ precision: "32-true"
182
+ max_epochs: 2000
183
+ enable_progress_bar: false
184
+ check_val_every_n_epoch: 100
185
+ optimizer:
186
+ class_path: torch.optim.AdamW
187
+ init_args:
188
+ lr: 0.00019511523812262233
189
+ weight_decay: 0.0003813804936812436
190
+ lr_scheduler:
191
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
192
+ init_args:
193
+ T_max: 2000
194
+ eta_min: 1e-07
195
+
Linear-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d14cec97c25c273f97009965293702ec5eda1f56648cfa0f2fac0acdf6b3459
3
+ size 49646459
Linear-ODE/train.yaml ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
8
+ init_args:
9
+ noise: 7.080372063368751
10
+ # fractional coordinates
11
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
12
+ init_args:
13
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
14
+ gamma: null
15
+ epsilon: null
16
+ differential_equation_type: "ODE"
17
+ integrator_kwargs:
18
+ method: "euler"
19
+ velocity_annealing_factor: 13.620695525269845
20
+ correct_center_of_mass_motion: true
21
+ # lattice vectors
22
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
23
+ init_args:
24
+ interpolant: omg.si.interpolants.LinearInterpolant
25
+ gamma: null
26
+ epsilon: null
27
+ differential_equation_type: "ODE"
28
+ integrator_kwargs:
29
+ method: "euler"
30
+ velocity_annealing_factor: 1.0679662602192312
31
+ correct_center_of_mass_motion: false
32
+ data_fields:
33
+ # if the order of the data_fields changes,
34
+ # the order of the above StochasticInterpolant inputs must also change
35
+ - "species"
36
+ - "pos"
37
+ - "cell"
38
+ integration_time_steps: 150
39
+ relative_si_costs:
40
+ species_loss: 0.021815399034596464
41
+ pos_loss_b: 0.9775483266595605
42
+ cell_loss_b: 0.0006362743058429793
43
+ sampler:
44
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
45
+ init_args:
46
+ pos_distribution: null
47
+ cell_distribution:
48
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
49
+ init_args:
50
+ dataset_name: mp_20
51
+ species_distribution:
52
+ class_path: omg.sampler.distributions.MaskDistribution
53
+ model:
54
+ class_path: omg.model.model.Model
55
+ init_args:
56
+ encoder:
57
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
58
+ head:
59
+ class_path: omg.model.heads.pass_through.PassThrough
60
+ time_embedder:
61
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
62
+ init_args:
63
+ dim: 256
64
+ use_min_perm_dist: True
65
+ float_32_matmul_precision: "high"
66
+ validation_mode: "dng_eval"
67
+ dataset_name: "mp_20"
68
+ data:
69
+ train_dataset:
70
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
71
+ init_args:
72
+ dataset:
73
+ class_path: omg.datamodule.datamodule.DataModule
74
+ init_args:
75
+ lmdb_paths:
76
+ - "data/mp_20/train.lmdb"
77
+ niggli: False
78
+ val_dataset:
79
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
80
+ init_args:
81
+ dataset:
82
+ class_path: omg.datamodule.datamodule.DataModule
83
+ init_args:
84
+ lmdb_paths:
85
+ - "data/mp_20/val.lmdb"
86
+ niggli: False
87
+ predict_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/mp_20/test.lmdb"
95
+ niggli: False
96
+ batch_size: 512
97
+ num_workers: 4
98
+ pin_memory: True
99
+ persistent_workers: True
100
+ trainer:
101
+ callbacks:
102
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
103
+ init_args:
104
+ filename: "best_val_loss_total"
105
+ save_top_k: 1
106
+ monitor: "val_loss_total"
107
+ save_weights_only: true
108
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
109
+ init_args:
110
+ filename: "best_val_dng_eval"
111
+ save_top_k: 1
112
+ monitor: "dng_eval"
113
+ save_weights_only: true
114
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
115
+ init_args:
116
+ filename: "best_val_wdist_density"
117
+ save_top_k: 1
118
+ monitor: "wdist_density"
119
+ save_weights_only: true
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_wdist_Nary"
123
+ save_top_k: 1
124
+ monitor: "wdist_Nary"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_wdist_CN"
129
+ save_top_k: 1
130
+ monitor: "wdist_CN"
131
+ save_weights_only: true
132
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
133
+ init_args:
134
+ filename: "best_val_cov_precision"
135
+ save_top_k: 1
136
+ monitor: "cov_precision"
137
+ mode: "max"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ filename: "best_val_cov_recall"
142
+ save_top_k: 1
143
+ monitor: "cov_recall"
144
+ mode: "max"
145
+ save_weights_only: true
146
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
147
+ init_args:
148
+ filename: "best_val_validity"
149
+ save_top_k: 1
150
+ monitor: "validity_rate"
151
+ mode: "max"
152
+ save_weights_only: true
153
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
154
+ init_args:
155
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
156
+ monitor: "val_loss_total"
157
+ every_n_epochs: 100
158
+ save_weights_only: false
159
+ gradient_clip_val: 0.5
160
+ num_sanity_val_steps: 0
161
+ precision: "32-true"
162
+ max_epochs: 10000
163
+ enable_progress_bar: false
164
+ check_val_every_n_epoch: 100
165
+ optimizer:
166
+ class_path: torch.optim.Adam
167
+ init_args:
168
+ lr: 0.001736512450391209
Linear-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33f3460360ee97b7f2f18fd4d9553b7f1c9a523f4470bb0587e236c9b4022a6a
3
+ size 148494394
Linear-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
8
+ init_args:
9
+ noise: 0.18946955217679085
10
+ # fractional coordinates
11
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
12
+ init_args:
13
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
14
+ gamma:
15
+ class_path: omg.si.gamma.LatentGammaSqrt
16
+ init_args:
17
+ a: 0.018159684059653552
18
+ epsilon:
19
+ class_path: omg.si.epsilon.VanishingEpsilon
20
+ init_args:
21
+ c: 9.74900863316411
22
+ mu: 0.17191546490562354
23
+ sigma: 0.029425925880471573
24
+ differential_equation_type: "SDE"
25
+ integrator_kwargs:
26
+ method: "euler"
27
+ dt: 0.0014076164225116372
28
+ velocity_annealing_factor: 6.334345265874859
29
+ correct_center_of_mass_motion: true
30
+ # lattice vectors
31
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
32
+ init_args:
33
+ interpolant: omg.si.interpolants.LinearInterpolant
34
+ gamma: null
35
+ epsilon: null
36
+ differential_equation_type: "ODE"
37
+ integrator_kwargs:
38
+ method: "euler"
39
+ velocity_annealing_factor: 1.0674474901964888
40
+ correct_center_of_mass_motion: false
41
+ data_fields:
42
+ # if the order of the data_fields changes,
43
+ # the order of the above StochasticInterpolant inputs must also change
44
+ - "species"
45
+ - "pos"
46
+ - "cell"
47
+ integration_time_steps: 710
48
+ relative_si_costs:
49
+ species_loss: 0.5918064683979826
50
+ pos_loss_b: 0.13091891010303253
51
+ pos_loss_z: 0.27077286248215743
52
+ cell_loss_b: 0.006501759016827464
53
+ sampler:
54
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
55
+ init_args:
56
+ pos_distribution: null
57
+ cell_distribution:
58
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
59
+ init_args:
60
+ dataset_name: mp_20
61
+ species_distribution:
62
+ class_path: omg.sampler.distributions.MaskDistribution
63
+ model:
64
+ class_path: omg.model.model.Model
65
+ init_args:
66
+ encoder:
67
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
68
+ head:
69
+ class_path: omg.model.heads.pass_through.PassThrough
70
+ time_embedder:
71
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
72
+ init_args:
73
+ dim: 256
74
+ use_min_perm_dist: True
75
+ float_32_matmul_precision: "high"
76
+ validation_mode: "dng_eval"
77
+ dataset_name: "mp_20"
78
+ data:
79
+ train_dataset:
80
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
81
+ init_args:
82
+ dataset:
83
+ class_path: omg.datamodule.datamodule.DataModule
84
+ init_args:
85
+ lmdb_paths:
86
+ - "data/mp_20/train.lmdb"
87
+ niggli: False
88
+ val_dataset:
89
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
90
+ init_args:
91
+ dataset:
92
+ class_path: omg.datamodule.datamodule.DataModule
93
+ init_args:
94
+ lmdb_paths:
95
+ - "data/mp_20/val.lmdb"
96
+ niggli: False
97
+ predict_dataset:
98
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
99
+ init_args:
100
+ dataset:
101
+ class_path: omg.datamodule.datamodule.DataModule
102
+ init_args:
103
+ lmdb_paths:
104
+ - "data/mp_20/test.lmdb"
105
+ niggli: False
106
+ batch_size: 32
107
+ num_workers: 4
108
+ pin_memory: True
109
+ persistent_workers: True
110
+ trainer:
111
+ callbacks:
112
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
113
+ init_args:
114
+ filename: "best_val_loss_total"
115
+ save_top_k: 1
116
+ monitor: "val_loss_total"
117
+ save_weights_only: true
118
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
119
+ init_args:
120
+ filename: "best_val_dng_eval"
121
+ save_top_k: 1
122
+ monitor: "dng_eval"
123
+ save_weights_only: true
124
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
125
+ init_args:
126
+ filename: "best_val_wdist_density"
127
+ save_top_k: 1
128
+ monitor: "wdist_density"
129
+ save_weights_only: true
130
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
131
+ init_args:
132
+ filename: "best_val_wdist_Nary"
133
+ save_top_k: 1
134
+ monitor: "wdist_Nary"
135
+ save_weights_only: true
136
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
137
+ init_args:
138
+ filename: "best_val_wdist_CN"
139
+ save_top_k: 1
140
+ monitor: "wdist_CN"
141
+ save_weights_only: true
142
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
143
+ init_args:
144
+ filename: "best_val_cov_precision"
145
+ save_top_k: 1
146
+ monitor: "cov_precision"
147
+ mode: "max"
148
+ save_weights_only: true
149
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
150
+ init_args:
151
+ filename: "best_val_cov_recall"
152
+ save_top_k: 1
153
+ monitor: "cov_recall"
154
+ mode: "max"
155
+ save_weights_only: true
156
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
157
+ init_args:
158
+ filename: "best_val_validity"
159
+ save_top_k: 1
160
+ monitor: "validity_rate"
161
+ mode: "max"
162
+ save_weights_only: true
163
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
164
+ init_args:
165
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
166
+ monitor: "val_loss_total"
167
+ every_n_epochs: 100
168
+ save_weights_only: false
169
+ gradient_clip_val: 0.5
170
+ gradient_clip_algorithm: "value"
171
+ num_sanity_val_steps: 0
172
+ precision: "32-true"
173
+ max_epochs: 2000
174
+ enable_progress_bar: false
175
+ check_val_every_n_epoch: 100
176
+ optimizer:
177
+ class_path: torch.optim.AdamW
178
+ init_args:
179
+ lr: 0.00019745455354877462
180
+ weight_decay: 0.0003111161289640361
181
+ lr_scheduler:
182
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
183
+ init_args:
184
+ T_max: 2000
185
+ eta_min: 1e-07
186
+
Trig-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3d3c53b60d074e40860ff25e65bedf14da18eeabef997f1e0a795666ab4559e
3
+ size 148519034
Trig-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
8
+ init_args:
9
+ noise: 27.249112246908787
10
+ # fractional coordinates
11
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
12
+ init_args:
13
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
14
+ gamma:
15
+ class_path: omg.si.gamma.LatentGammaSqrt
16
+ init_args:
17
+ a: 0.02697960692675219
18
+ epsilon: null
19
+ differential_equation_type: "ODE"
20
+ integrator_kwargs:
21
+ method: "euler"
22
+ velocity_annealing_factor: 7.788016364580001
23
+ correct_center_of_mass_motion: true
24
+ # lattice vectors
25
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
26
+ init_args:
27
+ interpolant: omg.si.interpolants.LinearInterpolant
28
+ gamma:
29
+ class_path: omg.si.gamma.LatentGammaSqrt
30
+ init_args:
31
+ a: 0.8480953426626178
32
+ epsilon:
33
+ class_path: omg.si.epsilon.VanishingEpsilon
34
+ init_args:
35
+ c: 3.673041703474942
36
+ mu: 0.0706838941982046
37
+ sigma: 0.01812545867373373
38
+ differential_equation_type: "SDE"
39
+ integrator_kwargs:
40
+ method: "euler"
41
+ dt: 0.001469808747060597
42
+ velocity_annealing_factor: 0.2916081322471492
43
+ correct_center_of_mass_motion: false
44
+ data_fields:
45
+ # if the order of the data_fields changes,
46
+ # the order of the above StochasticInterpolant inputs must also change
47
+ - "species"
48
+ - "pos"
49
+ - "cell"
50
+ integration_time_steps: 680
51
+ relative_si_costs:
52
+ species_loss: 0.43055618267791895
53
+ pos_loss_b: 0.2322254093385872
54
+ cell_loss_b: 0.003464180396862092
55
+ cell_loss_z: 0.33375422758663176
56
+ sampler:
57
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
58
+ init_args:
59
+ pos_distribution: null
60
+ cell_distribution:
61
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
62
+ init_args:
63
+ dataset_name: mp_20
64
+ species_distribution:
65
+ class_path: omg.sampler.distributions.MaskDistribution
66
+ model:
67
+ class_path: omg.model.model.Model
68
+ init_args:
69
+ encoder:
70
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
71
+ head:
72
+ class_path: omg.model.heads.pass_through.PassThrough
73
+ time_embedder:
74
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
75
+ init_args:
76
+ dim: 256
77
+ use_min_perm_dist: True
78
+ float_32_matmul_precision: "high"
79
+ validation_mode: "dng_eval"
80
+ dataset_name: "mp_20"
81
+ data:
82
+ train_dataset:
83
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
84
+ init_args:
85
+ dataset:
86
+ class_path: omg.datamodule.datamodule.DataModule
87
+ init_args:
88
+ lmdb_paths:
89
+ - "data/mp_20/train.lmdb"
90
+ niggli: True
91
+ val_dataset:
92
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
93
+ init_args:
94
+ dataset:
95
+ class_path: omg.datamodule.datamodule.DataModule
96
+ init_args:
97
+ lmdb_paths:
98
+ - "data/mp_20/val.lmdb"
99
+ niggli: True
100
+ predict_dataset:
101
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
102
+ init_args:
103
+ dataset:
104
+ class_path: omg.datamodule.datamodule.DataModule
105
+ init_args:
106
+ lmdb_paths:
107
+ - "data/mp_20/test.lmdb"
108
+ niggli: True
109
+ batch_size: 32
110
+ num_workers: 4
111
+ pin_memory: True
112
+ persistent_workers: True
113
+ trainer:
114
+ callbacks:
115
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
116
+ init_args:
117
+ filename: "best_val_loss_total"
118
+ save_top_k: 1
119
+ monitor: "val_loss_total"
120
+ save_weights_only: true
121
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
122
+ init_args:
123
+ filename: "best_val_dng_eval"
124
+ save_top_k: 1
125
+ monitor: "dng_eval"
126
+ save_weights_only: true
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_wdist_density"
130
+ save_top_k: 1
131
+ monitor: "wdist_density"
132
+ save_weights_only: true
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_wdist_Nary"
136
+ save_top_k: 1
137
+ monitor: "wdist_Nary"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ filename: "best_val_wdist_CN"
142
+ save_top_k: 1
143
+ monitor: "wdist_CN"
144
+ save_weights_only: true
145
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
146
+ init_args:
147
+ filename: "best_val_cov_precision"
148
+ save_top_k: 1
149
+ monitor: "cov_precision"
150
+ mode: "max"
151
+ save_weights_only: true
152
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
153
+ init_args:
154
+ filename: "best_val_cov_recall"
155
+ save_top_k: 1
156
+ monitor: "cov_recall"
157
+ mode: "max"
158
+ save_weights_only: true
159
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
160
+ init_args:
161
+ filename: "best_val_validity"
162
+ save_top_k: 1
163
+ monitor: "validity_rate"
164
+ mode: "max"
165
+ save_weights_only: true
166
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
167
+ init_args:
168
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
169
+ monitor: "val_loss_total"
170
+ every_n_epochs: 100
171
+ save_weights_only: false
172
+ gradient_clip_val: 0.5
173
+ gradient_clip_algorithm: "value"
174
+ num_sanity_val_steps: 0
175
+ precision: "32-true"
176
+ max_epochs: 2000
177
+ enable_progress_bar: false
178
+ check_val_every_n_epoch: 100
179
+ optimizer:
180
+ class_path: torch.optim.AdamW
181
+ init_args:
182
+ lr: 0.00014843344531647814
183
+ weight_decay: 0.00032033785912707614
184
+ lr_scheduler:
185
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
186
+ init_args:
187
+ T_max: 2000
188
+ eta_min: 1e-07
189
+
Trig-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da530d4e3b0e3ae3537da7ab517a3b2cfa754637289d72f7ff91f761a2a66557
3
+ size 148480896
Trig-ODE/train.yaml ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
8
+ init_args:
9
+ noise: 32.687148090341246
10
+ # fractional coordinates
11
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
12
+ init_args:
13
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
14
+ gamma: null
15
+ epsilon: null
16
+ differential_equation_type: "ODE"
17
+ integrator_kwargs:
18
+ method: "euler"
19
+ velocity_annealing_factor: 8.59355026270501
20
+ correct_center_of_mass_motion: true
21
+ # lattice vectors
22
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
23
+ init_args:
24
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
25
+ gamma:
26
+ class_path: omg.si.gamma.LatentGammaSqrt
27
+ init_args:
28
+ a: 1.182707731733833
29
+ epsilon: null
30
+ differential_equation_type: "ODE"
31
+ integrator_kwargs:
32
+ method: "euler"
33
+ velocity_annealing_factor: 0.308921595643179
34
+ correct_center_of_mass_motion: false
35
+ data_fields:
36
+ # if the order of the data_fields changes,
37
+ # the order of the above StochasticInterpolant inputs must also change
38
+ - "species"
39
+ - "pos"
40
+ - "cell"
41
+ integration_time_steps: 860
42
+ relative_si_costs:
43
+ species_loss: 0.6674861836833045
44
+ pos_loss_b: 0.33016840408608256
45
+ cell_loss_b: 0.002345412230612938
46
+ sampler:
47
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
48
+ init_args:
49
+ pos_distribution: null
50
+ cell_distribution:
51
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
52
+ init_args:
53
+ dataset_name: mp_20
54
+ species_distribution:
55
+ class_path: omg.sampler.distributions.MaskDistribution
56
+ model:
57
+ class_path: omg.model.model.Model
58
+ init_args:
59
+ encoder:
60
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
61
+ head:
62
+ class_path: omg.model.heads.pass_through.PassThrough
63
+ time_embedder:
64
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
65
+ init_args:
66
+ dim: 256
67
+ use_min_perm_dist: True
68
+ float_32_matmul_precision: "high"
69
+ validation_mode: "dng_eval"
70
+ dataset_name: "mp_20"
71
+ data:
72
+ train_dataset:
73
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
74
+ init_args:
75
+ dataset:
76
+ class_path: omg.datamodule.datamodule.DataModule
77
+ init_args:
78
+ lmdb_paths:
79
+ - "data/mp_20/train.lmdb"
80
+ niggli: False
81
+ val_dataset:
82
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
83
+ init_args:
84
+ dataset:
85
+ class_path: omg.datamodule.datamodule.DataModule
86
+ init_args:
87
+ lmdb_paths:
88
+ - "data/mp_20/val.lmdb"
89
+ niggli: False
90
+ predict_dataset:
91
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
92
+ init_args:
93
+ dataset:
94
+ class_path: omg.datamodule.datamodule.DataModule
95
+ init_args:
96
+ lmdb_paths:
97
+ - "data/mp_20/test.lmdb"
98
+ niggli: False
99
+ batch_size: 128
100
+ num_workers: 4
101
+ pin_memory: True
102
+ persistent_workers: True
103
+ trainer:
104
+ callbacks:
105
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
106
+ init_args:
107
+ filename: "best_val_loss_total"
108
+ save_top_k: 1
109
+ monitor: "val_loss_total"
110
+ save_weights_only: true
111
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
112
+ init_args:
113
+ filename: "best_val_dng_eval"
114
+ save_top_k: 1
115
+ monitor: "dng_eval"
116
+ save_weights_only: true
117
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
118
+ init_args:
119
+ filename: "best_val_wdist_density"
120
+ save_top_k: 1
121
+ monitor: "wdist_density"
122
+ save_weights_only: true
123
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
124
+ init_args:
125
+ filename: "best_val_wdist_Nary"
126
+ save_top_k: 1
127
+ monitor: "wdist_Nary"
128
+ save_weights_only: true
129
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
130
+ init_args:
131
+ filename: "best_val_wdist_CN"
132
+ save_top_k: 1
133
+ monitor: "wdist_CN"
134
+ save_weights_only: true
135
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
136
+ init_args:
137
+ filename: "best_val_cov_precision"
138
+ save_top_k: 1
139
+ monitor: "cov_precision"
140
+ mode: "max"
141
+ save_weights_only: true
142
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
143
+ init_args:
144
+ filename: "best_val_cov_recall"
145
+ save_top_k: 1
146
+ monitor: "cov_recall"
147
+ mode: "max"
148
+ save_weights_only: true
149
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
150
+ init_args:
151
+ filename: "best_val_validity"
152
+ save_top_k: 1
153
+ monitor: "validity_rate"
154
+ mode: "max"
155
+ save_weights_only: true
156
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
157
+ init_args:
158
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
159
+ monitor: "val_loss_total"
160
+ every_n_epochs: 100
161
+ save_weights_only: false
162
+ gradient_clip_val: 0.5
163
+ gradient_clip_algorithm: "value"
164
+ num_sanity_val_steps: 0
165
+ precision: "32-true"
166
+ max_epochs: 2000
167
+ enable_progress_bar: false
168
+ check_val_every_n_epoch: 100
169
+ optimizer:
170
+ class_path: torch.optim.AdamW
171
+ init_args:
172
+ lr: 0.002704094492670699
173
+ weight_decay: 0.0006070253248675564
174
+ lr_scheduler:
175
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
176
+ init_args:
177
+ T_max: 2000
178
+ eta_min: 1e-07
179
+
Trig-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af001ad47f40c2794e5115785c33bb3e26489cd4a4ff9a874af8ea41594d9f0c
3
+ size 148494458
Trig-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
8
+ init_args:
9
+ noise: 13.14607468893319
10
+ # fractional coordinates
11
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
12
+ init_args:
13
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
14
+ gamma:
15
+ class_path: omg.si.gamma.LatentGammaSqrt
16
+ init_args:
17
+ a: 0.022769980672795356
18
+ epsilon:
19
+ class_path: omg.si.epsilon.VanishingEpsilon
20
+ init_args:
21
+ c: 2.621699079870832
22
+ mu: 0.15417087293483117
23
+ sigma: 0.017962649662214652
24
+ differential_equation_type: "SDE"
25
+ integrator_kwargs:
26
+ method: "euler"
27
+ dt: 0.0013148878933861852
28
+ velocity_annealing_factor: 12.80156329264574
29
+ correct_center_of_mass_motion: true
30
+ # lattice vectors
31
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
32
+ init_args:
33
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
34
+ gamma:
35
+ class_path: omg.si.gamma.LatentGammaSqrt
36
+ init_args:
37
+ a: 0.31566788838946225
38
+ epsilon: null
39
+ differential_equation_type: "ODE"
40
+ integrator_kwargs:
41
+ method: "euler"
42
+ velocity_annealing_factor: 4.36422932474701
43
+ correct_center_of_mass_motion: false
44
+ data_fields:
45
+ # if the order of the data_fields changes,
46
+ # the order of the above StochasticInterpolant inputs must also change
47
+ - "species"
48
+ - "pos"
49
+ - "cell"
50
+ integration_time_steps: 760
51
+ relative_si_costs:
52
+ species_loss: 0.13597807419582586
53
+ pos_loss_b: 0.6304347545056598
54
+ pos_loss_z: 0.0753230160674198
55
+ cell_loss_b: 0.15826415523109444
56
+ sampler:
57
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
58
+ init_args:
59
+ pos_distribution: null
60
+ cell_distribution:
61
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
62
+ init_args:
63
+ dataset_name: mp_20
64
+ species_distribution:
65
+ class_path: omg.sampler.distributions.MaskDistribution
66
+ model:
67
+ class_path: omg.model.model.Model
68
+ init_args:
69
+ encoder:
70
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
71
+ head:
72
+ class_path: omg.model.heads.pass_through.PassThrough
73
+ time_embedder:
74
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
75
+ init_args:
76
+ dim: 256
77
+ use_min_perm_dist: True
78
+ float_32_matmul_precision: "high"
79
+ validation_mode: "dng_eval"
80
+ dataset_name: "mp_20"
81
+ data:
82
+ train_dataset:
83
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
84
+ init_args:
85
+ dataset:
86
+ class_path: omg.datamodule.datamodule.DataModule
87
+ init_args:
88
+ lmdb_paths:
89
+ - "data/mp_20/train.lmdb"
90
+ niggli: False
91
+ val_dataset:
92
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
93
+ init_args:
94
+ dataset:
95
+ class_path: omg.datamodule.datamodule.DataModule
96
+ init_args:
97
+ lmdb_paths:
98
+ - "data/mp_20/val.lmdb"
99
+ niggli: False
100
+ predict_dataset:
101
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
102
+ init_args:
103
+ dataset:
104
+ class_path: omg.datamodule.datamodule.DataModule
105
+ init_args:
106
+ lmdb_paths:
107
+ - "data/mp_20/test.lmdb"
108
+ niggli: False
109
+ batch_size: 256
110
+ num_workers: 4
111
+ pin_memory: True
112
+ persistent_workers: True
113
+ trainer:
114
+ callbacks:
115
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
116
+ init_args:
117
+ filename: "best_val_loss_total"
118
+ save_top_k: 1
119
+ monitor: "val_loss_total"
120
+ save_weights_only: true
121
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
122
+ init_args:
123
+ filename: "best_val_dng_eval"
124
+ save_top_k: 1
125
+ monitor: "dng_eval"
126
+ save_weights_only: true
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_wdist_density"
130
+ save_top_k: 1
131
+ monitor: "wdist_density"
132
+ save_weights_only: true
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_wdist_Nary"
136
+ save_top_k: 1
137
+ monitor: "wdist_Nary"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ filename: "best_val_wdist_CN"
142
+ save_top_k: 1
143
+ monitor: "wdist_CN"
144
+ save_weights_only: true
145
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
146
+ init_args:
147
+ filename: "best_val_cov_precision"
148
+ save_top_k: 1
149
+ monitor: "cov_precision"
150
+ mode: "max"
151
+ save_weights_only: true
152
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
153
+ init_args:
154
+ filename: "best_val_cov_recall"
155
+ save_top_k: 1
156
+ monitor: "cov_recall"
157
+ mode: "max"
158
+ save_weights_only: true
159
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
160
+ init_args:
161
+ filename: "best_val_validity"
162
+ save_top_k: 1
163
+ monitor: "validity_rate"
164
+ mode: "max"
165
+ save_weights_only: true
166
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
167
+ init_args:
168
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
169
+ monitor: "val_loss_total"
170
+ every_n_epochs: 100
171
+ save_weights_only: false
172
+ gradient_clip_val: 0.5
173
+ gradient_clip_algorithm: "value"
174
+ num_sanity_val_steps: 0
175
+ precision: "32-true"
176
+ max_epochs: 2000
177
+ enable_progress_bar: false
178
+ check_val_every_n_epoch: 100
179
+ optimizer:
180
+ class_path: torch.optim.AdamW
181
+ init_args:
182
+ lr: 0.0007969633652411341
183
+ weight_decay: 1.803908894626558e-05
184
+ lr_scheduler:
185
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
186
+ init_args:
187
+ T_max: 2000
188
+ eta_min: 1e-07
189
+
VESBD-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa131d3a8b1c1168441d6164f3247d7f9b364b57f037b3c0c35dbb66297e2951
3
+ size 148519354
VESBD-ODE/train.yaml ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
8
+ init_args:
9
+ noise: 5.870180115373019
10
+ # fractional coordinates
11
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
12
+ init_args:
13
+ interpolant:
14
+ class_path: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolantVE
15
+ init_args:
16
+ sigma:
17
+ class_path: omg.si.sigma.GeometricSigma
18
+ init_args:
19
+ sigma_min: 0.0020828565391521787
20
+ sigma_max: 0.8318656965968637
21
+ epsilon: null
22
+ differential_equation_type: "ODE"
23
+ integrator_kwargs:
24
+ method: "euler"
25
+ velocity_annealing_factor: 12.718434028622262
26
+ correct_center_of_mass_motion: true
27
+ predict_velocity: true
28
+ # lattice vectors
29
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
30
+ init_args:
31
+ interpolant: omg.si.interpolants.LinearInterpolant
32
+ gamma:
33
+ class_path: omg.si.gamma.LatentGammaSqrt
34
+ init_args:
35
+ a: 0.9128957436969677
36
+ epsilon:
37
+ class_path: omg.si.epsilon.VanishingEpsilon
38
+ init_args:
39
+ c: 5.79866710636582
40
+ mu: 0.2500005563766677
41
+ sigma: 0.020369240775387647
42
+ differential_equation_type: "SDE"
43
+ integrator_kwargs:
44
+ method: "euler"
45
+ dt: 0.0030334345065057278
46
+ velocity_annealing_factor: 0.9750213755072251
47
+ correct_center_of_mass_motion: false
48
+ data_fields:
49
+ # if the order of the data_fields changes,
50
+ # the order of the above StochasticInterpolant inputs must also change
51
+ - "species"
52
+ - "pos"
53
+ - "cell"
54
+ integration_time_steps: 330
55
+ relative_si_costs:
56
+ species_loss: 0.09898280162138558
57
+ pos_loss_b: 0.22090708494461975
58
+ cell_loss_b: 0.04296279641692716
59
+ cell_loss_z: 0.6371473170170674
60
+ sampler:
61
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
62
+ init_args:
63
+ pos_distribution:
64
+ class_path: omg.sampler.distributions.NormalDistribution
65
+ init_args:
66
+ scale: 0.45439306223842724
67
+ cell_distribution:
68
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
69
+ init_args:
70
+ dataset_name: mp_20
71
+ species_distribution:
72
+ class_path: omg.sampler.distributions.MaskDistribution
73
+ model:
74
+ class_path: omg.model.model.Model
75
+ init_args:
76
+ encoder:
77
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
78
+ head:
79
+ class_path: omg.model.heads.pass_through.PassThrough
80
+ time_embedder:
81
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
82
+ init_args:
83
+ dim: 256
84
+ use_min_perm_dist: False
85
+ float_32_matmul_precision: "high"
86
+ validation_mode: "dng_eval"
87
+ dataset_name: "mp_20"
88
+ data:
89
+ train_dataset:
90
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
91
+ init_args:
92
+ dataset:
93
+ class_path: omg.datamodule.datamodule.DataModule
94
+ init_args:
95
+ lmdb_paths:
96
+ - "data/mp_20/train.lmdb"
97
+ niggli: True
98
+ val_dataset:
99
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
100
+ init_args:
101
+ dataset:
102
+ class_path: omg.datamodule.datamodule.DataModule
103
+ init_args:
104
+ lmdb_paths:
105
+ - "data/mp_20/val.lmdb"
106
+ niggli: True
107
+ predict_dataset:
108
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
109
+ init_args:
110
+ dataset:
111
+ class_path: omg.datamodule.datamodule.DataModule
112
+ init_args:
113
+ lmdb_paths:
114
+ - "data/mp_20/test.lmdb"
115
+ niggli: True
116
+ batch_size: 256
117
+ num_workers: 4
118
+ pin_memory: True
119
+ persistent_workers: True
120
+ trainer:
121
+ callbacks:
122
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
123
+ init_args:
124
+ filename: "best_val_loss_total"
125
+ save_top_k: 1
126
+ monitor: "val_loss_total"
127
+ save_weights_only: true
128
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
129
+ init_args:
130
+ filename: "best_val_dng_eval"
131
+ save_top_k: 1
132
+ monitor: "dng_eval"
133
+ save_weights_only: true
134
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
135
+ init_args:
136
+ filename: "best_val_wdist_density"
137
+ save_top_k: 1
138
+ monitor: "wdist_density"
139
+ save_weights_only: true
140
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
141
+ init_args:
142
+ filename: "best_val_wdist_Nary"
143
+ save_top_k: 1
144
+ monitor: "wdist_Nary"
145
+ save_weights_only: true
146
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
147
+ init_args:
148
+ filename: "best_val_wdist_CN"
149
+ save_top_k: 1
150
+ monitor: "wdist_CN"
151
+ save_weights_only: true
152
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
153
+ init_args:
154
+ filename: "best_val_cov_precision"
155
+ save_top_k: 1
156
+ monitor: "cov_precision"
157
+ mode: "max"
158
+ save_weights_only: true
159
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
160
+ init_args:
161
+ filename: "best_val_cov_recall"
162
+ save_top_k: 1
163
+ monitor: "cov_recall"
164
+ mode: "max"
165
+ save_weights_only: true
166
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
167
+ init_args:
168
+ filename: "best_val_validity"
169
+ save_top_k: 1
170
+ monitor: "validity_rate"
171
+ mode: "max"
172
+ save_weights_only: true
173
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
174
+ init_args:
175
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
176
+ monitor: "val_loss_total"
177
+ every_n_epochs: 100
178
+ save_weights_only: false
179
+ gradient_clip_val: 0.5
180
+ gradient_clip_algorithm: "value"
181
+ num_sanity_val_steps: 0
182
+ precision: "32-true"
183
+ max_epochs: 2000
184
+ enable_progress_bar: false
185
+ check_val_every_n_epoch: 100
186
+ optimizer:
187
+ class_path: torch.optim.AdamW
188
+ init_args:
189
+ lr: 0.0018098696508625563
190
+ weight_decay: 0.00026498129464991104
191
+ lr_scheduler:
192
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
193
+ init_args:
194
+ T_max: 2000
195
+ eta_min: 1e-07
196
+
VPSBD-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae9cc89dea91443d4f903d7d2e9c18a2f2828e5f1318f4f75a4975cc708f5fd4
3
+ size 148481024
VPSBD-ODE/train.yaml ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
8
+ init_args:
9
+ noise: 20.267191359937392
10
+ # fractional coordinates
11
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
12
+ init_args:
13
+ interpolant: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolant
14
+ epsilon: null
15
+ differential_equation_type: "ODE"
16
+ integrator_kwargs:
17
+ method: "euler"
18
+ velocity_annealing_factor: 2.301841820941901
19
+ correct_center_of_mass_motion: true
20
+ predict_velocity: true
21
+ # lattice vectors
22
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
23
+ init_args:
24
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
25
+ gamma:
26
+ class_path: omg.si.gamma.LatentGammaSqrt
27
+ init_args:
28
+ a: 7.796692096273471
29
+ epsilon: null
30
+ differential_equation_type: "ODE"
31
+ integrator_kwargs:
32
+ method: "euler"
33
+ velocity_annealing_factor: 2.741014121366449
34
+ correct_center_of_mass_motion: false
35
+ data_fields:
36
+ # if the order of the data_fields changes,
37
+ # the order of the above StochasticInterpolant inputs must also change
38
+ - "species"
39
+ - "pos"
40
+ - "cell"
41
+ integration_time_steps: 710
42
+ relative_si_costs:
43
+ species_loss: 0.5499726353395065
44
+ pos_loss_b: 0.40529122146887925
45
+ cell_loss_b: 0.04473614319161423
46
+ sampler:
47
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
48
+ init_args:
49
+ pos_distribution:
50
+ class_path: omg.sampler.distributions.NormalDistribution
51
+ init_args:
52
+ scale: 0.23441087988918383
53
+ cell_distribution:
54
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
55
+ init_args:
56
+ dataset_name: mp_20
57
+ species_distribution:
58
+ class_path: omg.sampler.distributions.MaskDistribution
59
+ model:
60
+ class_path: omg.model.model.Model
61
+ init_args:
62
+ encoder:
63
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
64
+ head:
65
+ class_path: omg.model.heads.pass_through.PassThrough
66
+ time_embedder:
67
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
68
+ init_args:
69
+ dim: 256
70
+ use_min_perm_dist: False
71
+ float_32_matmul_precision: "high"
72
+ validation_mode: "dng_eval"
73
+ dataset_name: "mp_20"
74
+ data:
75
+ train_dataset:
76
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
77
+ init_args:
78
+ dataset:
79
+ class_path: omg.datamodule.datamodule.DataModule
80
+ init_args:
81
+ lmdb_paths:
82
+ - "data/mp_20/train.lmdb"
83
+ niggli: False
84
+ val_dataset:
85
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
86
+ init_args:
87
+ dataset:
88
+ class_path: omg.datamodule.datamodule.DataModule
89
+ init_args:
90
+ lmdb_paths:
91
+ - "data/mp_20/val.lmdb"
92
+ niggli: False
93
+ predict_dataset:
94
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
95
+ init_args:
96
+ dataset:
97
+ class_path: omg.datamodule.datamodule.DataModule
98
+ init_args:
99
+ lmdb_paths:
100
+ - "data/mp_20/test.lmdb"
101
+ niggli: False
102
+ batch_size: 256
103
+ num_workers: 4
104
+ pin_memory: True
105
+ persistent_workers: True
106
+ trainer:
107
+ callbacks:
108
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
109
+ init_args:
110
+ filename: "best_val_loss_total"
111
+ save_top_k: 1
112
+ monitor: "val_loss_total"
113
+ save_weights_only: true
114
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
115
+ init_args:
116
+ filename: "best_val_dng_eval"
117
+ save_top_k: 1
118
+ monitor: "dng_eval"
119
+ save_weights_only: true
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_wdist_density"
123
+ save_top_k: 1
124
+ monitor: "wdist_density"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_wdist_Nary"
129
+ save_top_k: 1
130
+ monitor: "wdist_Nary"
131
+ save_weights_only: true
132
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
133
+ init_args:
134
+ filename: "best_val_wdist_CN"
135
+ save_top_k: 1
136
+ monitor: "wdist_CN"
137
+ save_weights_only: true
138
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
139
+ init_args:
140
+ filename: "best_val_cov_precision"
141
+ save_top_k: 1
142
+ monitor: "cov_precision"
143
+ mode: "max"
144
+ save_weights_only: true
145
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
146
+ init_args:
147
+ filename: "best_val_cov_recall"
148
+ save_top_k: 1
149
+ monitor: "cov_recall"
150
+ mode: "max"
151
+ save_weights_only: true
152
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
153
+ init_args:
154
+ filename: "best_val_validity"
155
+ save_top_k: 1
156
+ monitor: "validity_rate"
157
+ mode: "max"
158
+ save_weights_only: true
159
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
160
+ init_args:
161
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
162
+ monitor: "val_loss_total"
163
+ every_n_epochs: 100
164
+ save_weights_only: false
165
+ gradient_clip_val: 0.5
166
+ gradient_clip_algorithm: "value"
167
+ num_sanity_val_steps: 0
168
+ precision: "32-true"
169
+ max_epochs: 2000
170
+ enable_progress_bar: false
171
+ check_val_every_n_epoch: 100
172
+ optimizer:
173
+ class_path: torch.optim.AdamW
174
+ init_args:
175
+ lr: 0.00797735754708741
176
+ weight_decay: 1.923837446196394e-05
177
+ lr_scheduler:
178
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
179
+ init_args:
180
+ T_max: 2000
181
+ eta_min: 1e-07
182
+
VPSBD-SDE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0bc99defffe2308a601267ab7ea233948ef3bfc033ff1b6e879525a50a0cda55
3
+ size 148532340
VPSBD-SDE/train.yaml ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
8
+ init_args:
9
+ noise: 8.517811450005286
10
+ # fractional coordinates
11
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
12
+ init_args:
13
+ interpolant: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolant
14
+ epsilon:
15
+ class_path: omg.si.epsilon.VanishingEpsilon
16
+ init_args:
17
+ c: 9.268934476283913
18
+ mu: 0.25243331190144214
19
+ sigma: 0.04584669320169394
20
+ differential_equation_type: "SDE"
21
+ integrator_kwargs:
22
+ method: "euler"
23
+ dt: 0.00114844657946378
24
+ velocity_annealing_factor: 9.059507260865466
25
+ correct_center_of_mass_motion: true
26
+ predict_velocity: true
27
+ # lattice vectors
28
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
29
+ init_args:
30
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
31
+ gamma:
32
+ class_path: omg.si.gamma.LatentGammaSqrt
33
+ init_args:
34
+ a: 3.0998874989979193
35
+ epsilon:
36
+ class_path: omg.si.epsilon.VanishingEpsilon
37
+ init_args:
38
+ c: 4.885221687138149
39
+ mu: 0.08252713485380268
40
+ sigma: 0.010096412051586533
41
+ differential_equation_type: "SDE"
42
+ integrator_kwargs:
43
+ method: "euler"
44
+ dt: 0.00114844657946378
45
+ velocity_annealing_factor: 11.76695215049249
46
+ correct_center_of_mass_motion: false
47
+ data_fields:
48
+ # if the order of the data_fields changes,
49
+ # the order of the above StochasticInterpolant inputs must also change
50
+ - "species"
51
+ - "pos"
52
+ - "cell"
53
+ integration_time_steps: 870
54
+ relative_si_costs:
55
+ species_loss: 0.35838215846575894
56
+ pos_loss_b: 0.5183735506925028
57
+ pos_loss_z: 0.0007900924652522647
58
+ cell_loss_b: 0.0044136759736567365
59
+ cell_loss_z: 0.11804052240282935
60
+ sampler:
61
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
62
+ init_args:
63
+ pos_distribution:
64
+ class_path: omg.sampler.distributions.NormalDistribution
65
+ init_args:
66
+ scale: 7.139671140709246
67
+ cell_distribution:
68
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
69
+ init_args:
70
+ dataset_name: mp_20
71
+ species_distribution:
72
+ class_path: omg.sampler.distributions.MaskDistribution
73
+ model:
74
+ class_path: omg.model.model.Model
75
+ init_args:
76
+ encoder:
77
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
78
+ head:
79
+ class_path: omg.model.heads.pass_through.PassThrough
80
+ time_embedder:
81
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
82
+ init_args:
83
+ dim: 256
84
+ use_min_perm_dist: False
85
+ float_32_matmul_precision: "high"
86
+ validation_mode: "dng_eval"
87
+ dataset_name: "mp_20"
88
+ data:
89
+ train_dataset:
90
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
91
+ init_args:
92
+ dataset:
93
+ class_path: omg.datamodule.datamodule.DataModule
94
+ init_args:
95
+ lmdb_paths:
96
+ - "data/mp_20/train.lmdb"
97
+ niggli: False
98
+ val_dataset:
99
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
100
+ init_args:
101
+ dataset:
102
+ class_path: omg.datamodule.datamodule.DataModule
103
+ init_args:
104
+ lmdb_paths:
105
+ - "data/mp_20/val.lmdb"
106
+ niggli: False
107
+ predict_dataset:
108
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
109
+ init_args:
110
+ dataset:
111
+ class_path: omg.datamodule.datamodule.DataModule
112
+ init_args:
113
+ lmdb_paths:
114
+ - "data/mp_20/test.lmdb"
115
+ niggli: False
116
+ batch_size: 512
117
+ num_workers: 4
118
+ pin_memory: True
119
+ persistent_workers: True
120
+ trainer:
121
+ callbacks:
122
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
123
+ init_args:
124
+ filename: "best_val_loss_total"
125
+ save_top_k: 1
126
+ monitor: "val_loss_total"
127
+ save_weights_only: true
128
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
129
+ init_args:
130
+ filename: "best_val_dng_eval"
131
+ save_top_k: 1
132
+ monitor: "dng_eval"
133
+ save_weights_only: true
134
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
135
+ init_args:
136
+ filename: "best_val_wdist_density"
137
+ save_top_k: 1
138
+ monitor: "wdist_density"
139
+ save_weights_only: true
140
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
141
+ init_args:
142
+ filename: "best_val_wdist_Nary"
143
+ save_top_k: 1
144
+ monitor: "wdist_Nary"
145
+ save_weights_only: true
146
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
147
+ init_args:
148
+ filename: "best_val_wdist_CN"
149
+ save_top_k: 1
150
+ monitor: "wdist_CN"
151
+ save_weights_only: true
152
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
153
+ init_args:
154
+ filename: "best_val_cov_precision"
155
+ save_top_k: 1
156
+ monitor: "cov_precision"
157
+ mode: "max"
158
+ save_weights_only: true
159
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
160
+ init_args:
161
+ filename: "best_val_cov_recall"
162
+ save_top_k: 1
163
+ monitor: "cov_recall"
164
+ mode: "max"
165
+ save_weights_only: true
166
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
167
+ init_args:
168
+ filename: "best_val_validity"
169
+ save_top_k: 1
170
+ monitor: "validity_rate"
171
+ mode: "max"
172
+ save_weights_only: true
173
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
174
+ init_args:
175
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
176
+ monitor: "val_loss_total"
177
+ every_n_epochs: 100
178
+ save_weights_only: false
179
+ gradient_clip_val: 0.5
180
+ gradient_clip_algorithm: "value"
181
+ num_sanity_val_steps: 0
182
+ precision: "32-true"
183
+ max_epochs: 2000
184
+ enable_progress_bar: false
185
+ check_val_every_n_epoch: 100
186
+ optimizer:
187
+ class_path: torch.optim.AdamW
188
+ init_args:
189
+ lr: 0.0014461332672089323
190
+ weight_decay: 0.0007097046414614019
191
+ lr_scheduler:
192
+ class_path: torch.optim.lr_scheduler.CosineAnnealingLR
193
+ init_args:
194
+ T_max: 2000
195
+ eta_min: 1e-07
196
+