martirossyan commited on
Commit
1ef2652
·
verified ·
1 Parent(s): b82baed

Upload 4 files

Browse files
EncDec-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73ab9ce457f02553142b9fa08c6c93fa5d14d772d973c63a05a7f8af44beb09d
3
+ size 148070612
EncDec-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
13
+ init_args:
14
+ switch_time: 0.42184997325946555
15
+ power: 0.5
16
+ gamma:
17
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
18
+ init_args:
19
+ a: 0.03989185248799893
20
+ switch_time: 0.42184997325946555
21
+ power: 0.5
22
+ epsilon:
23
+ class_path: omg.si.epsilon.VanishingEpsilon
24
+ init_args:
25
+ c: 2.3996529332194574
26
+ mu: 0.25251095399328916
27
+ sigma: 0.03759134500470063
28
+ differential_equation_type: "SDE"
29
+ integrator_kwargs:
30
+ method: "euler"
31
+ dt: 0.0014076164225116372
32
+ velocity_annealing_factor: 3.7755089557808477
33
+ correct_center_of_mass_motion: true
34
+ # lattice vectors
35
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
36
+ init_args:
37
+ interpolant: omg.si.interpolants.LinearInterpolant
38
+ gamma:
39
+ class_path: omg.si.gamma.LatentGammaSqrt
40
+ init_args:
41
+ a: 4.961271013084809
42
+ epsilon: null
43
+ differential_equation_type: "ODE"
44
+ integrator_kwargs:
45
+ method: "euler"
46
+ velocity_annealing_factor: 1.1379701544400436
47
+ correct_center_of_mass_motion: false
48
+ data_fields:
49
+ # if the order of the data_fields changes,
50
+ # the order of the above StochasticInterpolant inputs must also change
51
+ - "species"
52
+ - "pos"
53
+ - "cell"
54
+ integration_time_steps: 710
55
+ relative_si_costs:
56
+ species_loss: 0.0
57
+ pos_loss_b: 0.6143090042317803
58
+ pos_loss_z: 0.3794040725288834
59
+ cell_loss_b: 0.00628692323933625
60
+ sampler:
61
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
62
+ init_args:
63
+ pos_distribution: null
64
+ cell_distribution:
65
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
66
+ init_args:
67
+ dataset_name: mp_20
68
+ species_distribution:
69
+ class_path: omg.sampler.distributions.MirrorData
70
+ model:
71
+ class_path: omg.model.model.Model
72
+ init_args:
73
+ encoder:
74
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
75
+ head:
76
+ class_path: omg.model.heads.pass_through.PassThrough
77
+ time_embedder:
78
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
79
+ init_args:
80
+ dim: 256
81
+ use_min_perm_dist: False
82
+ float_32_matmul_precision: "high"
83
+ validation_mode: "match_rate"
84
+ dataset_name: "mp_20"
85
+ data:
86
+ train_dataset:
87
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
88
+ init_args:
89
+ dataset:
90
+ class_path: omg.datamodule.datamodule.DataModule
91
+ init_args:
92
+ lmdb_paths:
93
+ - "data/mp_20/train.lmdb"
94
+ niggli: True
95
+ val_dataset:
96
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
97
+ init_args:
98
+ dataset:
99
+ class_path: omg.datamodule.datamodule.DataModule
100
+ init_args:
101
+ lmdb_paths:
102
+ - "data/mp_20/val.lmdb"
103
+ niggli: True
104
+ predict_dataset:
105
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
106
+ init_args:
107
+ dataset:
108
+ class_path: omg.datamodule.datamodule.DataModule
109
+ init_args:
110
+ lmdb_paths:
111
+ - "data/mp_20/test.lmdb"
112
+ niggli: True
113
+ batch_size: 32
114
+ num_workers: 4
115
+ pin_memory: True
116
+ persistent_workers: True
117
+ trainer:
118
+ callbacks:
119
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
120
+ init_args:
121
+ filename: "best_val_loss_total"
122
+ save_top_k: 1
123
+ monitor: "val_loss_total"
124
+ save_weights_only: true
125
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
126
+ init_args:
127
+ filename: "best_val_match_rate"
128
+ save_top_k: 1
129
+ monitor: "match_rate"
130
+ save_weights_only: true
131
+ mode: 'max'
132
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
133
+ init_args:
134
+ filename: "best_val_rmsd"
135
+ save_top_k: 1
136
+ monitor: "mean_rmsd"
137
+ save_weights_only: true
138
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
139
+ init_args:
140
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
141
+ monitor: "val_loss_total"
142
+ every_n_epochs: 100
143
+ save_weights_only: false
144
+ gradient_clip_val: 0.5
145
+ num_sanity_val_steps: 0
146
+ precision: "32-true"
147
+ max_epochs: 2000
148
+ enable_progress_bar: false
149
+ check_val_every_n_epoch: 100
150
+ optimizer:
151
+ class_path: torch.optim.Adam
152
+ init_args:
153
+ lr: 0.00018567271191860665
Linear-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9755c37adfda19e98a32e14751649874e060d75f607f70f434f61ef45a2cb9e9
3
+ size 148108247
Linear-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant: omg.si.interpolants.PeriodicLinearInterpolant
12
+ gamma:
13
+ class_path: omg.si.gamma.LatentGammaSqrt
14
+ init_args:
15
+ a: 0.06285652866840548
16
+ epsilon:
17
+ class_path: omg.si.epsilon.VanishingEpsilon
18
+ init_args:
19
+ c: 6.097168392667226
20
+ mu: 0.21833859329765842
21
+ sigma: 0.04985718977712428
22
+ differential_equation_type: "SDE"
23
+ integrator_kwargs:
24
+ method: "euler"
25
+ dt: 0.0032297736033797264
26
+ velocity_annealing_factor: 11.58289329358004
27
+ correct_center_of_mass_motion: true
28
+ # lattice vectors
29
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
30
+ init_args:
31
+ interpolant: omg.si.interpolants.LinearInterpolant
32
+ gamma:
33
+ class_path: omg.si.gamma.LatentGammaSqrt
34
+ init_args:
35
+ a: 0.1317493001266121
36
+ epsilon:
37
+ class_path: omg.si.epsilon.VanishingEpsilon
38
+ init_args:
39
+ c: 9.612495617660462
40
+ mu: 0.08389382419092543
41
+ sigma: 0.033192886798663945
42
+ differential_equation_type: "SDE"
43
+ integrator_kwargs:
44
+ method: "euler"
45
+ dt: 0.0032297736033797264
46
+ velocity_annealing_factor: 5.081210983525862
47
+ correct_center_of_mass_motion: false
48
+ data_fields:
49
+ # if the order of the data_fields changes,
50
+ # the order of the above StochasticInterpolant inputs must also change
51
+ - "species"
52
+ - "pos"
53
+ - "cell"
54
+ integration_time_steps: 310
55
+ relative_si_costs:
56
+ species_loss: 0.0
57
+ pos_loss_b: 0.007345481151868809
58
+ pos_loss_z: 0.9153543617007412
59
+ cell_loss_b: 0.06421063793348068
60
+ cell_loss_z: 0.013089519213909303
61
+ sampler:
62
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
63
+ init_args:
64
+ pos_distribution: null
65
+ cell_distribution:
66
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
67
+ init_args:
68
+ dataset_name: mp_20
69
+ species_distribution:
70
+ class_path: omg.sampler.distributions.MirrorData
71
+ model:
72
+ class_path: omg.model.model.Model
73
+ init_args:
74
+ encoder:
75
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
76
+ head:
77
+ class_path: omg.model.heads.pass_through.PassThrough
78
+ time_embedder:
79
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
80
+ init_args:
81
+ dim: 256
82
+ use_min_perm_dist: False
83
+ float_32_matmul_precision: "high"
84
+ validation_mode: "match_rate"
85
+ dataset_name: "mp_20"
86
+ data:
87
+ train_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/mp_20/train.lmdb"
95
+ niggli: False
96
+ val_dataset:
97
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
98
+ init_args:
99
+ dataset:
100
+ class_path: omg.datamodule.datamodule.DataModule
101
+ init_args:
102
+ lmdb_paths:
103
+ - "data/mp_20/val.lmdb"
104
+ niggli: False
105
+ predict_dataset:
106
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
107
+ init_args:
108
+ dataset:
109
+ class_path: omg.datamodule.datamodule.DataModule
110
+ init_args:
111
+ lmdb_paths:
112
+ - "data/mp_20/test.lmdb"
113
+ niggli: False
114
+ batch_size: 256
115
+ num_workers: 4
116
+ pin_memory: True
117
+ persistent_workers: True
118
+ trainer:
119
+ callbacks:
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_loss_total"
123
+ save_top_k: 1
124
+ monitor: "val_loss_total"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_match_rate"
129
+ save_top_k: 1
130
+ monitor: "match_rate"
131
+ save_weights_only: true
132
+ mode: 'max'
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_rmsd"
136
+ save_top_k: 1
137
+ monitor: "mean_rmsd"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
142
+ monitor: "val_loss_total"
143
+ every_n_epochs: 100
144
+ save_weights_only: false
145
+ gradient_clip_val: 0.5
146
+ num_sanity_val_steps: 0
147
+ precision: "32-true"
148
+ max_epochs: 2000
149
+ enable_progress_bar: false
150
+ check_val_every_n_epoch: 100
151
+ optimizer:
152
+ class_path: torch.optim.Adam
153
+ init_args:
154
+ lr: 0.0002629870131361822