martirossyan commited on
Commit
6a8ebba
·
verified ·
1 Parent(s): d0441b6

Upload 10 files

Browse files
EncDec-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7dded5b6fe149b90baa87c42b4ac9624f1337c6c5a3dba22834f0ea6b6f5b67
3
+ size 148101112
EncDec-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
13
+ init_args:
14
+ switch_time: 0.6487086666110259
15
+ power: 1.0
16
+ gamma:
17
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
18
+ init_args:
19
+ a: 1.9883383838119686
20
+ switch_time: 0.6487086666110259
21
+ power: 1.0
22
+ epsilon: null
23
+ differential_equation_type: "ODE"
24
+ integrator_kwargs:
25
+ method: "euler"
26
+ velocity_annealing_factor: 12.290317841755964
27
+ correct_center_of_mass_motion: true
28
+ # lattice vectors
29
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
30
+ init_args:
31
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
32
+ gamma:
33
+ class_path: omg.si.gamma.LatentGammaSqrt
34
+ init_args:
35
+ a: 0.21935645939922985
36
+ epsilon:
37
+ class_path: omg.si.epsilon.VanishingEpsilon
38
+ init_args:
39
+ c: 9.431054439782873
40
+ mu: 0.21809909486896933
41
+ sigma: 0.03292165737293197
42
+ differential_equation_type: "SDE"
43
+ integrator_kwargs:
44
+ method: "euler"
45
+ dt: 0.001218559336848557
46
+ velocity_annealing_factor: 4.302804708170181
47
+ correct_center_of_mass_motion: false
48
+ data_fields:
49
+ # if the order of the data_fields changes,
50
+ # the order of the above StochasticInterpolant inputs must also change
51
+ - "species"
52
+ - "pos"
53
+ - "cell"
54
+ integration_time_steps: 820
55
+ relative_si_costs:
56
+ species_loss: 0.0
57
+ pos_loss_b: 0.689192251322191
58
+ cell_loss_b: 0.12351464867571432
59
+ cell_loss_z: 0.18729310000209468
60
+ sampler:
61
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
62
+ init_args:
63
+ pos_distribution: null
64
+ cell_distribution:
65
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
66
+ init_args:
67
+ dataset_name: alex_mp_20
68
+ species_distribution:
69
+ class_path: omg.sampler.distributions.MirrorData
70
+ model:
71
+ class_path: omg.model.model.Model
72
+ init_args:
73
+ encoder:
74
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
75
+ head:
76
+ class_path: omg.model.heads.pass_through.PassThrough
77
+ time_embedder:
78
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
79
+ init_args:
80
+ dim: 256
81
+ use_min_perm_dist: False
82
+ float_32_matmul_precision: "high"
83
+ validation_mode: "match_rate"
84
+ number_cpus: 7
85
+ dataset_name: "alex_mp_20"
86
+ data:
87
+ train_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/alex_mp_20/train.lmdb"
95
+ niggli: True
96
+ val_dataset:
97
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
98
+ init_args:
99
+ dataset:
100
+ class_path: omg.datamodule.datamodule.DataModule
101
+ init_args:
102
+ lmdb_paths:
103
+ - "data/alex_mp_20/val.lmdb"
104
+ niggli: True
105
+ predict_dataset:
106
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
107
+ init_args:
108
+ dataset:
109
+ class_path: omg.datamodule.datamodule.DataModule
110
+ init_args:
111
+ lmdb_paths:
112
+ - "data/alex_mp_20/test.lmdb"
113
+ niggli: True
114
+ batch_size: 512
115
+ num_workers: 4
116
+ pin_memory: True
117
+ persistent_workers: True
118
+ trainer:
119
+ callbacks:
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_loss_total"
123
+ save_top_k: 1
124
+ monitor: "val_loss_total"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_match_rate"
129
+ save_top_k: 1
130
+ monitor: "match_rate"
131
+ save_weights_only: true
132
+ mode: 'max'
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_rmsd"
136
+ save_top_k: 1
137
+ monitor: "mean_rmsd"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
142
+ monitor: "val_loss_total"
143
+ every_n_epochs: 100
144
+ save_weights_only: false
145
+ gradient_clip_val: 0.5
146
+ num_sanity_val_steps: 0
147
+ precision: "32-true"
148
+ max_epochs: 2000
149
+ enable_progress_bar: true
150
+ limit_val_batches: 0.1
151
+ check_val_every_n_epoch: 100
152
+ optimizer:
153
+ class_path: torch.optim.Adam
154
+ init_args:
155
+ lr: 0.00047748599389170053
EncDec-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e60e7c48ca7a5dc422a470dada73871f90095e7299f2b09858a1f6bb02a8dc2
3
+ size 148075262
EncDec-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
13
+ init_args:
14
+ switch_time: 0.42184997325946555
15
+ power: 0.5
16
+ gamma:
17
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
18
+ init_args:
19
+ a: 0.03989185248799893
20
+ switch_time: 0.42184997325946555
21
+ power: 0.5
22
+ epsilon:
23
+ class_path: omg.si.epsilon.VanishingEpsilon
24
+ init_args:
25
+ c: 2.3996529332194574
26
+ mu: 0.25251095399328916
27
+ sigma: 0.03759134500470063
28
+ differential_equation_type: "SDE"
29
+ integrator_kwargs:
30
+ method: "euler"
31
+ dt: 0.0014076164225116372
32
+ velocity_annealing_factor: 3.7755089557808477
33
+ correct_center_of_mass_motion: true
34
+ # lattice vectors
35
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
36
+ init_args:
37
+ interpolant: omg.si.interpolants.LinearInterpolant
38
+ gamma:
39
+ class_path: omg.si.gamma.LatentGammaSqrt
40
+ init_args:
41
+ a: 4.961271013084809
42
+ epsilon: null
43
+ differential_equation_type: "ODE"
44
+ integrator_kwargs:
45
+ method: "euler"
46
+ velocity_annealing_factor: 1.1379701544400436
47
+ correct_center_of_mass_motion: false
48
+ data_fields:
49
+ # if the order of the data_fields changes,
50
+ # the order of the above StochasticInterpolant inputs must also change
51
+ - "species"
52
+ - "pos"
53
+ - "cell"
54
+ integration_time_steps: 710
55
+ relative_si_costs:
56
+ species_loss: 0.0
57
+ pos_loss_b: 0.6143090042317803
58
+ pos_loss_z: 0.3794040725288834
59
+ cell_loss_b: 0.00628692323933625
60
+ sampler:
61
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
62
+ init_args:
63
+ pos_distribution: null
64
+ cell_distribution:
65
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
66
+ init_args:
67
+ dataset_name: alex_mp_20
68
+ species_distribution:
69
+ class_path: omg.sampler.distributions.MirrorData
70
+ model:
71
+ class_path: omg.model.model.Model
72
+ init_args:
73
+ encoder:
74
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
75
+ head:
76
+ class_path: omg.model.heads.pass_through.PassThrough
77
+ time_embedder:
78
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
79
+ init_args:
80
+ dim: 256
81
+ use_min_perm_dist: False
82
+ float_32_matmul_precision: "high"
83
+ validation_mode: "match_rate"
84
+ number_cpus: 7
85
+ dataset_name: "alex_mp_20"
86
+ data:
87
+ train_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/alex_mp_20/train.lmdb"
95
+ niggli: True
96
+ val_dataset:
97
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
98
+ init_args:
99
+ dataset:
100
+ class_path: omg.datamodule.datamodule.DataModule
101
+ init_args:
102
+ lmdb_paths:
103
+ - "data/alex_mp_20/val.lmdb"
104
+ niggli: True
105
+ predict_dataset:
106
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
107
+ init_args:
108
+ dataset:
109
+ class_path: omg.datamodule.datamodule.DataModule
110
+ init_args:
111
+ lmdb_paths:
112
+ - "data/alex_mp_20/test.lmdb"
113
+ niggli: True
114
+ batch_size: 512 # Used to be 32
115
+ num_workers: 4
116
+ pin_memory: True
117
+ persistent_workers: True
118
+ trainer:
119
+ callbacks:
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_loss_total"
123
+ save_top_k: 1
124
+ monitor: "val_loss_total"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_match_rate"
129
+ save_top_k: 1
130
+ monitor: "match_rate"
131
+ save_weights_only: true
132
+ mode: 'max'
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_rmsd"
136
+ save_top_k: 1
137
+ monitor: "mean_rmsd"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
142
+ monitor: "val_loss_total"
143
+ every_n_epochs: 100
144
+ save_weights_only: false
145
+ gradient_clip_val: 0.5
146
+ num_sanity_val_steps: 0
147
+ precision: "32-true"
148
+ max_epochs: 2000
149
+ enable_progress_bar: true
150
+ limit_val_batches: 0.1
151
+ check_val_every_n_epoch: 200 # Used to be 100
152
+ optimizer:
153
+ class_path: torch.optim.Adam
154
+ init_args:
155
+ lr: 0.00018567271191860665
Linear-ODE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:552ee843d60a93c942d831eaff0c360c8c43046aec2c00c2a7711c8a31d67507
3
+ size 49644475
Linear-ODE-Gamma/train.yaml ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.2575112227566439
16
+ epsilon: null
17
+ differential_equation_type: "ODE"
18
+ integrator_kwargs:
19
+ method: "euler"
20
+ velocity_annealing_factor: 7.7611189744870925
21
+ correct_center_of_mass_motion: true
22
+ # lattice vectors
23
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
24
+ init_args:
25
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
26
+ gamma:
27
+ class_path: omg.si.gamma.LatentGammaSqrt
28
+ init_args:
29
+ a: 2.9759856920732597
30
+ epsilon: null
31
+ differential_equation_type: "ODE"
32
+ integrator_kwargs:
33
+ method: "euler"
34
+ velocity_annealing_factor: 4.116061496782678
35
+ correct_center_of_mass_motion: false
36
+ data_fields:
37
+ # if the order of the data_fields changes,
38
+ # the order of the above StochasticInterpolant inputs must also change
39
+ - "species"
40
+ - "pos"
41
+ - "cell"
42
+ integration_time_steps: 690
43
+ relative_si_costs:
44
+ species_loss: 0.0
45
+ pos_loss_b: 0.9976417941296929
46
+ cell_loss_b: 0.002358205870307133
47
+ sampler:
48
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
49
+ init_args:
50
+ pos_distribution: null
51
+ cell_distribution:
52
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
53
+ init_args:
54
+ dataset_name: alex_mp_20
55
+ species_distribution:
56
+ class_path: omg.sampler.distributions.MirrorData
57
+ model:
58
+ class_path: omg.model.model.Model
59
+ init_args:
60
+ encoder:
61
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
62
+ head:
63
+ class_path: omg.model.heads.pass_through.PassThrough
64
+ time_embedder:
65
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
66
+ init_args:
67
+ dim: 256
68
+ use_min_perm_dist: False
69
+ float_32_matmul_precision: "high"
70
+ validation_mode: "match_rate"
71
+ number_cpus: 7
72
+ dataset_name: "alex_mp_20"
73
+ data:
74
+ train_dataset:
75
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
76
+ init_args:
77
+ dataset:
78
+ class_path: omg.datamodule.datamodule.DataModule
79
+ init_args:
80
+ lmdb_paths:
81
+ - "data/alex_mp_20/train.lmdb"
82
+ niggli: True
83
+ val_dataset:
84
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
85
+ init_args:
86
+ dataset:
87
+ class_path: omg.datamodule.datamodule.DataModule
88
+ init_args:
89
+ lmdb_paths:
90
+ - "data/alex_mp_20/val.lmdb"
91
+ niggli: True
92
+ predict_dataset:
93
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
94
+ init_args:
95
+ dataset:
96
+ class_path: omg.datamodule.datamodule.DataModule
97
+ init_args:
98
+ lmdb_paths:
99
+ - "data/alex_mp_20/test.lmdb"
100
+ niggli: True
101
+ batch_size: 128
102
+ num_workers: 4
103
+ pin_memory: True
104
+ persistent_workers: True
105
+ trainer:
106
+ callbacks:
107
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
108
+ init_args:
109
+ filename: "best_val_loss_total"
110
+ save_top_k: 1
111
+ monitor: "val_loss_total"
112
+ save_weights_only: true
113
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
114
+ init_args:
115
+ filename: "best_val_match_rate"
116
+ save_top_k: 1
117
+ monitor: "match_rate"
118
+ save_weights_only: true
119
+ mode: 'max'
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_rmsd"
123
+ save_top_k: 1
124
+ monitor: "mean_rmsd"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
129
+ monitor: "val_loss_total"
130
+ every_n_epochs: 100
131
+ save_weights_only: false
132
+ gradient_clip_val: 0.5
133
+ num_sanity_val_steps: 0
134
+ precision: "32-true"
135
+ max_epochs: 2000
136
+ enable_progress_bar: true
137
+ limit_val_batches: 0.1
138
+ check_val_every_n_epoch: 100
139
+ optimizer:
140
+ class_path: torch.optim.Adam
141
+ init_args:
142
+ lr: 4.006249666984122e-05
Linear-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f720adab2daa3cb9598484ea6ff37c39367bdf41b3c9dc49ef09400abaab87f3
3
+ size 148064122
Linear-ODE/train.yaml ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.PeriodicLinearInterpolant
12
+ gamma: null
13
+ epsilon: null
14
+ differential_equation_type: "ODE"
15
+ integrator_kwargs:
16
+ method: "euler"
17
+ velocity_annealing_factor: 10.182659004291072
18
+ correct_center_of_mass_motion: true
19
+ # lattice vectors
20
+ - class_path: omg.si.SingleStochasticInterpolant
21
+ init_args:
22
+ interpolant: omg.si.LinearInterpolant
23
+ gamma: null
24
+ epsilon: null
25
+ differential_equation_type: "ODE"
26
+ integrator_kwargs:
27
+ method: "euler"
28
+ velocity_annealing_factor: 1.824475401606087
29
+ correct_center_of_mass_motion: false
30
+ data_fields:
31
+ # if the order of the data_fields changes,
32
+ # the order of the above StochasticInterpolant inputs must also change
33
+ - "species"
34
+ - "pos"
35
+ - "cell"
36
+ integration_time_steps: 210
37
+ relative_si_costs:
38
+ species_loss: 0.0
39
+ pos_loss_b: 0.9994149341846618
40
+ cell_loss_b: 0.0005850658153382233
41
+ sampler:
42
+ class_path: omg.sampler.SampleFromRNG
43
+ init_args:
44
+ pos_distribution: null
45
+ cell_distribution:
46
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
47
+ init_args:
48
+ dataset_name: alex_mp_20
49
+ species_distribution:
50
+ class_path: omg.sampler.distributions.MirrorData
51
+ model:
52
+ class_path: omg.model.model.Model
53
+ init_args:
54
+ encoder:
55
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
56
+ head:
57
+ class_path: omg.model.heads.pass_through.PassThrough
58
+ time_embedder:
59
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
60
+ init_args:
61
+ dim: 256
62
+ use_min_perm_dist: False
63
+ float_32_matmul_precision: "high"
64
+ validation_mode: "match_rate"
65
+ number_cpus: 7
66
+ dataset_name: "alex_mp_20"
67
+ data:
68
+ train_dataset:
69
+ class_path: omg.datamodule.OMGTorchDataset
70
+ init_args:
71
+ dataset:
72
+ class_path: omg.datamodule.DataModule
73
+ init_args:
74
+ lmdb_paths:
75
+ - data/alex_mp_20/train.lmdb
76
+ niggli: False
77
+ val_dataset:
78
+ class_path: omg.datamodule.OMGTorchDataset
79
+ init_args:
80
+ dataset:
81
+ class_path: omg.datamodule.DataModule
82
+ init_args:
83
+ lmdb_paths:
84
+ - data/alex_mp_20/val.lmdb
85
+ niggli: False
86
+ predict_dataset:
87
+ class_path: omg.datamodule.OMGTorchDataset
88
+ init_args:
89
+ dataset:
90
+ class_path: omg.datamodule.DataModule
91
+ init_args:
92
+ lmdb_paths:
93
+ - data/alex_mp_20/test.lmdb
94
+ niggli: False
95
+ batch_size: 512
96
+ num_workers: 4
97
+ pin_memory: True
98
+ persistent_workers: True
99
+ trainer:
100
+ callbacks:
101
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
102
+ init_args:
103
+ filename: "best_val_loss_total"
104
+ save_top_k: 1
105
+ monitor: "val_loss_total"
106
+ save_weights_only: true
107
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
108
+ init_args:
109
+ filename: "best_val_match_rate"
110
+ save_top_k: 1
111
+ monitor: "match_rate"
112
+ save_weights_only: true
113
+ mode: 'max'
114
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
115
+ init_args:
116
+ filename: "best_val_rmsd"
117
+ save_top_k: 1
118
+ monitor: "mean_rmsd"
119
+ save_weights_only: true
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
123
+ monitor: "val_loss_total"
124
+ every_n_epochs: 100
125
+ save_weights_only: false
126
+ gradient_clip_val: 0.5
127
+ num_sanity_val_steps: 0
128
+ precision: "32-true"
129
+ max_epochs: 2000
130
+ enable_progress_bar: true
131
+ limit_val_batches: 0.1
132
+ check_val_every_n_epoch: 100
133
+ optimizer:
134
+ class_path: torch.optim.Adam
135
+ init_args:
136
+ lr: 0.0006689636445843722
Linear-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b48a00fe7e9fa442885ed8c65a06613f8d5f0f1739bc87ed1c9e7fa7d3ed3977
3
+ size 49644475
Linear-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.06285652866840548
16
+ epsilon:
17
+ class_path: omg.si.epsilon.VanishingEpsilon
18
+ init_args:
19
+ c: 6.097168392667226
20
+ mu: 0.21833859329765842
21
+ sigma: 0.04985718977712428
22
+ differential_equation_type: "SDE"
23
+ integrator_kwargs:
24
+ method: "euler"
25
+ dt: 0.0032297736033797264
26
+ velocity_annealing_factor: 11.58289329358004
27
+ correct_center_of_mass_motion: true
28
+ # lattice vectors
29
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
30
+ init_args:
31
+ interpolant: omg.si.interpolants.LinearInterpolant
32
+ gamma:
33
+ class_path: omg.si.gamma.LatentGammaSqrt
34
+ init_args:
35
+ a: 0.1317493001266121
36
+ epsilon:
37
+ class_path: omg.si.epsilon.VanishingEpsilon
38
+ init_args:
39
+ c: 9.612495617660462
40
+ mu: 0.08389382419092543
41
+ sigma: 0.033192886798663945
42
+ differential_equation_type: "SDE"
43
+ integrator_kwargs:
44
+ method: "euler"
45
+ dt: 0.0032297736033797264
46
+ velocity_annealing_factor: 5.081210983525862
47
+ correct_center_of_mass_motion: false
48
+ data_fields:
49
+ # if the order of the data_fields changes,
50
+ # the order of the above StochasticInterpolant inputs must also change
51
+ - "species"
52
+ - "pos"
53
+ - "cell"
54
+ integration_time_steps: 310
55
+ relative_si_costs:
56
+ species_loss: 0.0
57
+ pos_loss_b: 0.007345481151868809
58
+ pos_loss_z: 0.9153543617007412
59
+ cell_loss_b: 0.06421063793348068
60
+ cell_loss_z: 0.013089519213909303
61
+ sampler:
62
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
63
+ init_args:
64
+ pos_distribution: null
65
+ cell_distribution:
66
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
67
+ init_args:
68
+ dataset_name: alex_mp_20
69
+ species_distribution:
70
+ class_path: omg.sampler.distributions.MirrorData
71
+ model:
72
+ class_path: omg.model.model.Model
73
+ init_args:
74
+ encoder:
75
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
76
+ head:
77
+ class_path: omg.model.heads.pass_through.PassThrough
78
+ time_embedder:
79
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
80
+ init_args:
81
+ dim: 256
82
+ use_min_perm_dist: False
83
+ float_32_matmul_precision: "high"
84
+ validation_mode: "match_rate"
85
+ number_cpus: 7
86
+ dataset_name: "alex_mp_20"
87
+ data:
88
+ train_dataset:
89
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
90
+ init_args:
91
+ dataset:
92
+ class_path: omg.datamodule.datamodule.DataModule
93
+ init_args:
94
+ lmdb_paths:
95
+ - "data/alex_mp_20/train.lmdb"
96
+ niggli: False
97
+ val_dataset:
98
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
99
+ init_args:
100
+ dataset:
101
+ class_path: omg.datamodule.datamodule.DataModule
102
+ init_args:
103
+ lmdb_paths:
104
+ - "data/alex_mp_20/val.lmdb"
105
+ niggli: False
106
+ predict_dataset:
107
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
108
+ init_args:
109
+ dataset:
110
+ class_path: omg.datamodule.datamodule.DataModule
111
+ init_args:
112
+ lmdb_paths:
113
+ - "data/alex_mp_20/test.lmdb"
114
+ niggli: False
115
+ batch_size: 256
116
+ num_workers: 4
117
+ pin_memory: True
118
+ persistent_workers: True
119
+ trainer:
120
+ callbacks:
121
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
122
+ init_args:
123
+ filename: "best_val_loss_total"
124
+ save_top_k: 1
125
+ monitor: "val_loss_total"
126
+ save_weights_only: true
127
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
128
+ init_args:
129
+ filename: "best_val_match_rate"
130
+ save_top_k: 1
131
+ monitor: "match_rate"
132
+ save_weights_only: true
133
+ mode: 'max'
134
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
135
+ init_args:
136
+ filename: "best_val_rmsd"
137
+ save_top_k: 1
138
+ monitor: "mean_rmsd"
139
+ save_weights_only: true
140
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
141
+ init_args:
142
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
143
+ monitor: "val_loss_total"
144
+ every_n_epochs: 100
145
+ save_weights_only: false
146
+ gradient_clip_val: 0.5
147
+ num_sanity_val_steps: 0
148
+ precision: "32-true"
149
+ max_epochs: 2000
150
+ enable_progress_bar: true
151
+ limit_val_batches: 0.1
152
+ check_val_every_n_epoch: 100
153
+ optimizer:
154
+ class_path: torch.optim.Adam
155
+ init_args:
156
+ lr: 0.0002629870131361822