martirossyan commited on
Commit
d93de2f
·
verified ·
1 Parent(s): 6a8ebba

Upload 12 files

Browse files
Trig-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:055a50c0491baed117149cf4d183da0100ce21081b9f3de668ed0d7a89f2a88a
3
+ size 148099774
Trig-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.03337798944475465
16
+ epsilon: null
17
+ differential_equation_type: "ODE"
18
+ integrator_kwargs:
19
+ method: "euler"
20
+ velocity_annealing_factor: 13.545929738762764
21
+ correct_center_of_mass_motion: true
22
+ # lattice vectors
23
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
24
+ init_args:
25
+ interpolant: omg.si.interpolants.LinearInterpolant
26
+ gamma:
27
+ class_path: omg.si.gamma.LatentGammaSqrt
28
+ init_args:
29
+ a: 0.017261010545698854
30
+ epsilon:
31
+ class_path: omg.si.epsilon.VanishingEpsilon
32
+ init_args:
33
+ c: 0.8758328635983847
34
+ mu: 0.29744423858325936
35
+ sigma: 0.0052236060273636595
36
+ differential_equation_type: "SDE"
37
+ integrator_kwargs:
38
+ method: "euler"
39
+ dt: 0.0012811297783628106
40
+ velocity_annealing_factor: 2.380421528846764
41
+ correct_center_of_mass_motion: false
42
+ data_fields:
43
+ # if the order of the data_fields changes,
44
+ # the order of the above StochasticInterpolant inputs must also change
45
+ - "species"
46
+ - "pos"
47
+ - "cell"
48
+ integration_time_steps: 780
49
+ relative_si_costs:
50
+ species_loss: 0.0
51
+ pos_loss_b: 0.983015308902659
52
+ cell_loss_b: 0.01673796318800159
53
+ cell_loss_z: 0.0002467279093394523
54
+ sampler:
55
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
56
+ init_args:
57
+ pos_distribution: null
58
+ cell_distribution:
59
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
60
+ init_args:
61
+ dataset_name: alex_mp_20
62
+ species_distribution:
63
+ class_path: omg.sampler.distributions.MirrorData
64
+ model:
65
+ class_path: omg.model.model.Model
66
+ init_args:
67
+ encoder:
68
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
69
+ head:
70
+ class_path: omg.model.heads.pass_through.PassThrough
71
+ time_embedder:
72
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
73
+ init_args:
74
+ dim: 256
75
+ use_min_perm_dist: False
76
+ float_32_matmul_precision: "high"
77
+ validation_mode: "match_rate"
78
+ number_cpus: 7
79
+ dataset_name: "alex_mp_20"
80
+ data:
81
+ train_dataset:
82
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
83
+ init_args:
84
+ dataset:
85
+ class_path: omg.datamodule.datamodule.DataModule
86
+ init_args:
87
+ lmdb_paths:
88
+ - "data/alex_mp_20/train.lmdb"
89
+ niggli: True
90
+ val_dataset:
91
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
92
+ init_args:
93
+ dataset:
94
+ class_path: omg.datamodule.datamodule.DataModule
95
+ init_args:
96
+ lmdb_paths:
97
+ - "data/alex_mp_20/val.lmdb"
98
+ niggli: True
99
+ predict_dataset:
100
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
101
+ init_args:
102
+ dataset:
103
+ class_path: omg.datamodule.datamodule.DataModule
104
+ init_args:
105
+ lmdb_paths:
106
+ - "data/alex_mp_20/test.lmdb"
107
+ niggli: True
108
+ batch_size: 32
109
+ num_workers: 4
110
+ pin_memory: True
111
+ persistent_workers: True
112
+ trainer:
113
+ callbacks:
114
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
115
+ init_args:
116
+ filename: "best_val_loss_total"
117
+ save_top_k: 1
118
+ monitor: "val_loss_total"
119
+ save_weights_only: true
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_match_rate"
123
+ save_top_k: 1
124
+ monitor: "match_rate"
125
+ save_weights_only: true
126
+ mode: 'max'
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_rmsd"
130
+ save_top_k: 1
131
+ monitor: "mean_rmsd"
132
+ save_weights_only: true
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
136
+ monitor: "val_loss_total"
137
+ every_n_epochs: 100
138
+ save_weights_only: false
139
+ gradient_clip_val: 0.5
140
+ num_sanity_val_steps: 0
141
+ precision: "32-true"
142
+ max_epochs: 2000
143
+ enable_progress_bar: true
144
+ limit_val_batches: 0.1
145
+ check_val_every_n_epoch: 100
146
+ optimizer:
147
+ class_path: torch.optim.Adam
148
+ init_args:
149
+ lr: 8.341737878937152e-05
Trig-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:047eabb29f750a4f9d49d088d3f1e7cb1231ba58f3c5456d63859453fb347ff4
3
+ size 148100284
Trig-ODE/train.yaml ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
12
+ gamma: null
13
+ epsilon: null
14
+ differential_equation_type: "ODE"
15
+ integrator_kwargs:
16
+ method: "euler"
17
+ velocity_annealing_factor: 12.34532470785473
18
+ correct_center_of_mass_motion: true
19
+ # lattice vectors
20
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
21
+ init_args:
22
+ interpolant:
23
+ class_path: omg.si.interpolants.EncoderDecoderInterpolant
24
+ init_args:
25
+ switch_time: 0.4080329374611481
26
+ power: 0.5
27
+ gamma:
28
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
29
+ init_args:
30
+ a: 5.270616141661882
31
+ switch_time: 0.4080329374611481
32
+ power: 0.5
33
+ epsilon:
34
+ class_path: omg.si.epsilon.VanishingEpsilon
35
+ init_args:
36
+ c: 4.354817546796119
37
+ mu: 0.2923928859901851
38
+ sigma: 0.04742031136770322
39
+ differential_equation_type: "SDE"
40
+ integrator_kwargs:
41
+ method: "euler"
42
+ dt: 0.005905325524508953
43
+ velocity_annealing_factor: 3.6141717997883447
44
+ correct_center_of_mass_motion: false
45
+ data_fields:
46
+ # if the order of the data_fields changes,
47
+ # the order of the above StochasticInterpolant inputs must also change
48
+ - "species"
49
+ - "pos"
50
+ - "cell"
51
+ integration_time_steps: 170
52
+ relative_si_costs:
53
+ species_loss: 0.0
54
+ pos_loss_b: 0.9967455480681945
55
+ cell_loss_b: 0.002271914623580616
56
+ cell_loss_z: 0.0009825373082248405
57
+ sampler:
58
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
59
+ init_args:
60
+ pos_distribution: null
61
+ cell_distribution:
62
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
63
+ init_args:
64
+ dataset_name: alex_mp_20
65
+ species_distribution:
66
+ class_path: omg.sampler.distributions.MirrorData
67
+ model:
68
+ class_path: omg.model.model.Model
69
+ init_args:
70
+ encoder:
71
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
72
+ head:
73
+ class_path: omg.model.heads.pass_through.PassThrough
74
+ time_embedder:
75
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
76
+ init_args:
77
+ dim: 256
78
+ use_min_perm_dist: False
79
+ float_32_matmul_precision: "high"
80
+ validation_mode: "match_rate"
81
+ number_cpus: 7
82
+ dataset_name: "alex_mp_20"
83
+ data:
84
+ train_dataset:
85
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
86
+ init_args:
87
+ dataset:
88
+ class_path: omg.datamodule.datamodule.DataModule
89
+ init_args:
90
+ lmdb_paths:
91
+ - "data/alex_mp_20/train.lmdb"
92
+ niggli: False
93
+ val_dataset:
94
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
95
+ init_args:
96
+ dataset:
97
+ class_path: omg.datamodule.datamodule.DataModule
98
+ init_args:
99
+ lmdb_paths:
100
+ - "data/alex_mp_20/val.lmdb"
101
+ niggli: False
102
+ predict_dataset:
103
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
104
+ init_args:
105
+ dataset:
106
+ class_path: omg.datamodule.datamodule.DataModule
107
+ init_args:
108
+ lmdb_paths:
109
+ - "data/alex_mp_20/test.lmdb"
110
+ niggli: False
111
+ batch_size: 32
112
+ num_workers: 4
113
+ pin_memory: True
114
+ persistent_workers: True
115
+ trainer:
116
+ callbacks:
117
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
118
+ init_args:
119
+ filename: "best_val_loss_total"
120
+ save_top_k: 1
121
+ monitor: "val_loss_total"
122
+ save_weights_only: true
123
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
124
+ init_args:
125
+ filename: "best_val_match_rate"
126
+ save_top_k: 1
127
+ monitor: "match_rate"
128
+ save_weights_only: true
129
+ mode: 'max'
130
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
131
+ init_args:
132
+ filename: "best_val_rmsd"
133
+ save_top_k: 1
134
+ monitor: "mean_rmsd"
135
+ save_weights_only: true
136
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
137
+ init_args:
138
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
139
+ monitor: "val_loss_total"
140
+ every_n_epochs: 100
141
+ save_weights_only: false
142
+ gradient_clip_val: 0.5
143
+ num_sanity_val_steps: 0
144
+ precision: "32-true"
145
+ max_epochs: 2000
146
+ enable_progress_bar: true
147
+ limit_val_batches: 0.1
148
+ check_val_every_n_epoch: 100
149
+ optimizer:
150
+ class_path: torch.optim.Adam
151
+ init_args:
152
+ lr: 3.629490873183724e-05
Trig-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf07f1ccfea92f1382ce5d1a9e3802869bbe9a01f134a145c0eca45493d5b82e
3
+ size 148075198
Trig-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.049242906264339095
16
+ epsilon:
17
+ class_path: omg.si.epsilon.VanishingEpsilon
18
+ init_args:
19
+ c: 9.418703639528207
20
+ mu: 0.1967838464371502
21
+ sigma: 0.040028404066547216
22
+ differential_equation_type: "SDE"
23
+ integrator_kwargs:
24
+ method: "euler"
25
+ dt: 0.0013504737289622426
26
+ velocity_annealing_factor: 11.483173553510193
27
+ correct_center_of_mass_motion: true
28
+ # lattice vectors
29
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
30
+ init_args:
31
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
32
+ gamma: null
33
+ epsilon: null
34
+ differential_equation_type: "ODE"
35
+ integrator_kwargs:
36
+ method: "euler"
37
+ velocity_annealing_factor: 0.4337356395028541
38
+ correct_center_of_mass_motion: false
39
+ data_fields:
40
+ # if the order of the data_fields changes,
41
+ # the order of the above StochasticInterpolant inputs must also change
42
+ - "species"
43
+ - "pos"
44
+ - "cell"
45
+ integration_time_steps: 740
46
+ relative_si_costs:
47
+ species_loss: 0.0
48
+ pos_loss_b: 0.24677273761024368
49
+ pos_loss_z: 0.7231540118244248
50
+ cell_loss_b: 0.030073250565331323
51
+ sampler:
52
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
53
+ init_args:
54
+ pos_distribution: null
55
+ cell_distribution:
56
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
57
+ init_args:
58
+ dataset_name: alex_mp_20
59
+ species_distribution:
60
+ class_path: omg.sampler.distributions.MirrorData
61
+ model:
62
+ class_path: omg.model.model.Model
63
+ init_args:
64
+ encoder:
65
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
66
+ head:
67
+ class_path: omg.model.heads.pass_through.PassThrough
68
+ time_embedder:
69
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
70
+ init_args:
71
+ dim: 256
72
+ use_min_perm_dist: True
73
+ float_32_matmul_precision: "high"
74
+ validation_mode: "match_rate"
75
+ number_cpus: 7
76
+ dataset_name: "alex_mp_20"
77
+ data:
78
+ train_dataset:
79
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
80
+ init_args:
81
+ dataset:
82
+ class_path: omg.datamodule.datamodule.DataModule
83
+ init_args:
84
+ lmdb_paths:
85
+ - "data/alex_mp_20/train.lmdb"
86
+ niggli: True
87
+ val_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/alex_mp_20/val.lmdb"
95
+ niggli: True
96
+ predict_dataset:
97
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
98
+ init_args:
99
+ dataset:
100
+ class_path: omg.datamodule.datamodule.DataModule
101
+ init_args:
102
+ lmdb_paths:
103
+ - "data/alex_mp_20/test.lmdb"
104
+ niggli: True
105
+ batch_size: 32
106
+ num_workers: 4
107
+ pin_memory: True
108
+ persistent_workers: True
109
+ trainer:
110
+ callbacks:
111
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
112
+ init_args:
113
+ filename: "best_val_loss_total"
114
+ save_top_k: 1
115
+ monitor: "val_loss_total"
116
+ save_weights_only: true
117
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
118
+ init_args:
119
+ filename: "best_val_match_rate"
120
+ save_top_k: 1
121
+ monitor: "match_rate"
122
+ save_weights_only: true
123
+ mode: 'max'
124
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
125
+ init_args:
126
+ filename: "best_val_rmsd"
127
+ save_top_k: 1
128
+ monitor: "mean_rmsd"
129
+ save_weights_only: true
130
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
131
+ init_args:
132
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
133
+ monitor: "val_loss_total"
134
+ every_n_epochs: 100
135
+ save_weights_only: false
136
+ gradient_clip_val: 0.5
137
+ num_sanity_val_steps: 0
138
+ precision: "32-true"
139
+ max_epochs: 2000
140
+ enable_progress_bar: true
141
+ limit_val_batches: 0.1
142
+ check_val_every_n_epoch: 100
143
+ optimizer:
144
+ class_path: torch.optim.Adam
145
+ init_args:
146
+ lr: 9.320780466656964e-05
VESBD-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4442b50d380f7df53088cbc48ebb683ad5f768bf89dda93fa695ea1f339e289
3
+ size 148100602
VESBD-ODE/train.yaml ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolantVE
13
+ init_args:
14
+ sigma:
15
+ class_path: omg.si.sigma.GeometricSigma
16
+ init_args:
17
+ sigma_min: 0.004705415831077799
18
+ sigma_max: 0.9967130801483843
19
+ epsilon: null
20
+ differential_equation_type: "ODE"
21
+ integrator_kwargs:
22
+ method: "euler"
23
+ velocity_annealing_factor: 8.284579088906593
24
+ correct_center_of_mass_motion: true
25
+ predict_velocity: true
26
+ # lattice vectors
27
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
28
+ init_args:
29
+ interpolant: omg.si.interpolants.LinearInterpolant
30
+ gamma:
31
+ class_path: omg.si.gamma.LatentGammaSqrt
32
+ init_args:
33
+ a: 0.016616684357970132
34
+ epsilon:
35
+ class_path: omg.si.epsilon.VanishingEpsilon
36
+ init_args:
37
+ c: 3.9372558236242052
38
+ mu: 0.2649556265396099
39
+ sigma: 0.03578203230805775
40
+ differential_equation_type: "SDE"
41
+ integrator_kwargs:
42
+ method: "euler"
43
+ dt: 0.0015144158387556672
44
+ velocity_annealing_factor: 0.42775377056075214
45
+ correct_center_of_mass_motion: false
46
+ data_fields:
47
+ # if the order of the data_fields changes,
48
+ # the order of the above StochasticInterpolant inputs must also change
49
+ - "species"
50
+ - "pos"
51
+ - "cell"
52
+ integration_time_steps: 660
53
+ relative_si_costs:
54
+ species_loss: 0.0
55
+ pos_loss_b: 0.9813067351598369
56
+ cell_loss_b: 0.0005256953168558359
57
+ cell_loss_z: 0.018167569523307267
58
+ sampler:
59
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
60
+ init_args:
61
+ pos_distribution:
62
+ class_path: omg.sampler.distributions.NormalDistribution
63
+ init_args:
64
+ scale: 9.77149759679434
65
+ cell_distribution:
66
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
67
+ init_args:
68
+ dataset_name: alex_mp_20
69
+ species_distribution:
70
+ class_path: omg.sampler.distributions.MirrorData
71
+ model:
72
+ class_path: omg.model.model.Model
73
+ init_args:
74
+ encoder:
75
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
76
+ head:
77
+ class_path: omg.model.heads.pass_through.PassThrough
78
+ time_embedder:
79
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
80
+ init_args:
81
+ dim: 256
82
+ use_min_perm_dist: False
83
+ float_32_matmul_precision: "high"
84
+ validation_mode: "match_rate"
85
+ number_cpus: 7
86
+ dataset_name: "alex_mp_20"
87
+ data:
88
+ train_dataset:
89
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
90
+ init_args:
91
+ dataset:
92
+ class_path: omg.datamodule.datamodule.DataModule
93
+ init_args:
94
+ lmdb_paths:
95
+ - "data/alex_mp_20/train.lmdb"
96
+ niggli: False
97
+ val_dataset:
98
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
99
+ init_args:
100
+ dataset:
101
+ class_path: omg.datamodule.datamodule.DataModule
102
+ init_args:
103
+ lmdb_paths:
104
+ - "data/alex_mp_20/val.lmdb"
105
+ niggli: False
106
+ predict_dataset:
107
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
108
+ init_args:
109
+ dataset:
110
+ class_path: omg.datamodule.datamodule.DataModule
111
+ init_args:
112
+ lmdb_paths:
113
+ - "data/alex_mp_20/test.lmdb"
114
+ niggli: False
115
+ batch_size: 512
116
+ num_workers: 4
117
+ pin_memory: True
118
+ persistent_workers: True
119
+ trainer:
120
+ callbacks:
121
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
122
+ init_args:
123
+ filename: "best_val_loss_total"
124
+ save_top_k: 1
125
+ monitor: "val_loss_total"
126
+ save_weights_only: true
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_match_rate"
130
+ save_top_k: 1
131
+ monitor: "match_rate"
132
+ save_weights_only: true
133
+ mode: 'max'
134
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
135
+ init_args:
136
+ filename: "best_val_rmsd"
137
+ save_top_k: 1
138
+ monitor: "mean_rmsd"
139
+ save_weights_only: true
140
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
141
+ init_args:
142
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
143
+ monitor: "val_loss_total"
144
+ every_n_epochs: 100
145
+ save_weights_only: false
146
+ gradient_clip_val: 0.5
147
+ num_sanity_val_steps: 0
148
+ precision: "32-true"
149
+ max_epochs: 2000
150
+ enable_progress_bar: true
151
+ limit_val_batches: 0.1
152
+ check_val_every_n_epoch: 100
153
+ optimizer:
154
+ class_path: torch.optim.Adam
155
+ init_args:
156
+ lr: 0.000296636127734534
VPSBD-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f041997c8526a5e3162958f0196584f6ae086f2a09f32774b5bea7465dcc2a76
3
+ size 148062466
VPSBD-ODE/train.yaml ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolant
12
+ epsilon: null
13
+ differential_equation_type: "ODE"
14
+ integrator_kwargs:
15
+ method: "euler"
16
+ velocity_annealing_factor: 6.613808424917352
17
+ correct_center_of_mass_motion: true
18
+ predict_velocity: true
19
+ # lattice vectors
20
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
21
+ init_args:
22
+ interpolant: omg.si.interpolants.LinearInterpolant
23
+ gamma: null
24
+ epsilon: null
25
+ differential_equation_type: "ODE"
26
+ integrator_kwargs:
27
+ method: "euler"
28
+ velocity_annealing_factor: 2.447993013544224
29
+ correct_center_of_mass_motion: false
30
+ data_fields:
31
+ # if the order of the data_fields changes,
32
+ # the order of the above StochasticInterpolant inputs must also change
33
+ - "species"
34
+ - "pos"
35
+ - "cell"
36
+ integration_time_steps: 890
37
+ relative_si_costs:
38
+ species_loss: 0.0
39
+ pos_loss_b: 0.9597565150933746
40
+ cell_loss_b: 0.04024348490662539
41
+ sampler:
42
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
43
+ init_args:
44
+ pos_distribution:
45
+ class_path: omg.sampler.distributions.NormalDistribution
46
+ init_args:
47
+ scale: 0.22006712732536396
48
+ cell_distribution:
49
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
50
+ init_args:
51
+ dataset_name: alex_mp_20
52
+ species_distribution:
53
+ class_path: omg.sampler.distributions.MirrorData
54
+ model:
55
+ class_path: omg.model.model.Model
56
+ init_args:
57
+ encoder:
58
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
59
+ head:
60
+ class_path: omg.model.heads.pass_through.PassThrough
61
+ time_embedder:
62
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
63
+ init_args:
64
+ dim: 256
65
+ use_min_perm_dist: True
66
+ float_32_matmul_precision: "high"
67
+ validation_mode: "match_rate"
68
+ number_cpus: 7
69
+ dataset_name: "alex_mp_20"
70
+ data:
71
+ train_dataset:
72
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
73
+ init_args:
74
+ dataset:
75
+ class_path: omg.datamodule.datamodule.DataModule
76
+ init_args:
77
+ lmdb_paths:
78
+ - "data/alex_mp_20/train.lmdb"
79
+ niggli: True
80
+ val_dataset:
81
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
82
+ init_args:
83
+ dataset:
84
+ class_path: omg.datamodule.datamodule.DataModule
85
+ init_args:
86
+ lmdb_paths:
87
+ - "data/alex_mp_20/val.lmdb"
88
+ niggli: True
89
+ predict_dataset:
90
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
91
+ init_args:
92
+ dataset:
93
+ class_path: omg.datamodule.datamodule.DataModule
94
+ init_args:
95
+ lmdb_paths:
96
+ - "data/alex_mp_20/test.lmdb"
97
+ niggli: True
98
+ batch_size: 64
99
+ num_workers: 4
100
+ pin_memory: True
101
+ persistent_workers: True
102
+ trainer:
103
+ callbacks:
104
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
105
+ init_args:
106
+ filename: "best_val_loss_total"
107
+ save_top_k: 1
108
+ monitor: "val_loss_total"
109
+ save_weights_only: true
110
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
111
+ init_args:
112
+ filename: "best_val_match_rate"
113
+ save_top_k: 1
114
+ monitor: "match_rate"
115
+ save_weights_only: true
116
+ mode: 'max'
117
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
118
+ init_args:
119
+ filename: "best_val_rmsd"
120
+ save_top_k: 1
121
+ monitor: "mean_rmsd"
122
+ save_weights_only: true
123
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
124
+ init_args:
125
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
126
+ monitor: "val_loss_total"
127
+ every_n_epochs: 100
128
+ save_weights_only: false
129
+ gradient_clip_val: 0.5
130
+ num_sanity_val_steps: 0
131
+ precision: "32-true"
132
+ max_epochs: 2000
133
+ enable_progress_bar: true
134
+ limit_val_batches: 0.1
135
+ check_val_every_n_epoch: 100
136
+ optimizer:
137
+ class_path: torch.optim.Adam
138
+ init_args:
139
+ lr: 2.519765029616902e-05
VPSBD-SDE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07cee771817165d46f06d84f3593e3b9f01ef3a8f782de027a70c411ca88bfc1
3
+ size 49644411
VPSBD-SDE/train.yaml ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolant
12
+ epsilon:
13
+ class_path: omg.si.epsilon.VanishingEpsilon
14
+ init_args:
15
+ c: 2.4729222108905815
16
+ mu: 0.17656358406313838
17
+ sigma: 0.02379822283154629
18
+ differential_equation_type: "SDE"
19
+ integrator_kwargs:
20
+ method: "euler"
21
+ dt: 0.0016661101253703237
22
+ velocity_annealing_factor: 6.459028320375323
23
+ correct_center_of_mass_motion: true
24
+ predict_velocity: true
25
+ # lattice vectors
26
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
27
+ init_args:
28
+ interpolant: omg.si.interpolants.LinearInterpolant
29
+ gamma:
30
+ class_path: omg.si.gamma.LatentGammaSqrt
31
+ init_args:
32
+ a: 3.683542379054881
33
+ epsilon: null
34
+ differential_equation_type: "ODE"
35
+ integrator_kwargs:
36
+ method: "euler"
37
+ velocity_annealing_factor: 0.6692350794589719
38
+ correct_center_of_mass_motion: false
39
+ data_fields:
40
+ # if the order of the data_fields changes,
41
+ # the order of the above StochasticInterpolant inputs must also change
42
+ - "species"
43
+ - "pos"
44
+ - "cell"
45
+ integration_time_steps: 600
46
+ relative_si_costs:
47
+ species_loss: 0.0
48
+ pos_loss_b: 0.6060249654155797
49
+ pos_loss_z: 0.3828230559814603
50
+ cell_loss_b: 0.011151978602959979
51
+ sampler:
52
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
53
+ init_args:
54
+ pos_distribution:
55
+ class_path: omg.sampler.distributions.NormalDistribution
56
+ init_args:
57
+ scale: 2.2937003279036148
58
+ cell_distribution:
59
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
60
+ init_args:
61
+ dataset_name: alex_mp_20
62
+ species_distribution:
63
+ class_path: omg.sampler.distributions.MirrorData
64
+ model:
65
+ class_path: omg.model.model.Model
66
+ init_args:
67
+ encoder:
68
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
69
+ head:
70
+ class_path: omg.model.heads.pass_through.PassThrough
71
+ time_embedder:
72
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
73
+ init_args:
74
+ dim: 256
75
+ use_min_perm_dist: True
76
+ float_32_matmul_precision: "high"
77
+ validation_mode: "match_rate"
78
+ number_cpus: 7
79
+ dataset_name: "alex_mp_20"
80
+ data:
81
+ train_dataset:
82
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
83
+ init_args:
84
+ dataset:
85
+ class_path: omg.datamodule.datamodule.DataModule
86
+ init_args:
87
+ lmdb_paths:
88
+ - "data/alex_mp_20/train.lmdb"
89
+ niggli: True
90
+ val_dataset:
91
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
92
+ init_args:
93
+ dataset:
94
+ class_path: omg.datamodule.datamodule.DataModule
95
+ init_args:
96
+ lmdb_paths:
97
+ - "data/alex_mp_20/val.lmdb"
98
+ niggli: True
99
+ predict_dataset:
100
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
101
+ init_args:
102
+ dataset:
103
+ class_path: omg.datamodule.datamodule.DataModule
104
+ init_args:
105
+ lmdb_paths:
106
+ - "data/alex_mp_20/test.lmdb"
107
+ niggli: True
108
+ batch_size: 64
109
+ num_workers: 4
110
+ pin_memory: True
111
+ persistent_workers: True
112
+ trainer:
113
+ callbacks:
114
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
115
+ init_args:
116
+ filename: "best_val_loss_total"
117
+ save_top_k: 1
118
+ monitor: "val_loss_total"
119
+ save_weights_only: true
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_match_rate"
123
+ save_top_k: 1
124
+ monitor: "match_rate"
125
+ save_weights_only: true
126
+ mode: 'max'
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_rmsd"
130
+ save_top_k: 1
131
+ monitor: "mean_rmsd"
132
+ save_weights_only: true
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
136
+ monitor: "val_loss_total"
137
+ every_n_epochs: 100
138
+ save_weights_only: false
139
+ gradient_clip_val: 0.5
140
+ num_sanity_val_steps: 0
141
+ precision: "32-true"
142
+ max_epochs: 2000
143
+ enable_progress_bar: true
144
+ limit_val_batches: 0.1
145
+ check_val_every_n_epoch: 100
146
+ optimizer:
147
+ class_path: torch.optim.Adam
148
+ init_args:
149
+ lr: 0.0003030820420973639