martirossyan commited on
Commit
f780140
·
verified ·
1 Parent(s): 55c80ef

Upload 2 files

Browse files
EncDec-SDE-Gamma/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ea7b9cdd147c4f463b65fd3250146c509dfc17fd678acde7c00eabcecc9576c
3
+ size 49642338
EncDec-SDE-Gamma/train.yaml ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
13
+ init_args:
14
+ switch_time: 0.42184997325946555
15
+ power: 0.5
16
+ gamma:
17
+ class_path: omg.si.gamma.LatentGammaEncoderDecoder
18
+ init_args:
19
+ a: 0.03989185248799893
20
+ switch_time: 0.42184997325946555
21
+ power: 0.5
22
+ epsilon:
23
+ class_path: omg.si.epsilon.VanishingEpsilon
24
+ init_args:
25
+ c: 2.3996529332194574
26
+ mu: 0.25251095399328916
27
+ sigma: 0.03759134500470063
28
+ differential_equation_type: "SDE"
29
+ integrator_kwargs:
30
+ method: "euler"
31
+ dt: 0.0014076164225116372
32
+ velocity_annealing_factor: 3.7755089557808477
33
+ correct_center_of_mass_motion: true
34
+ # lattice vectors
35
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
36
+ init_args:
37
+ interpolant: omg.si.interpolants.LinearInterpolant
38
+ gamma:
39
+ class_path: omg.si.gamma.LatentGammaSqrt
40
+ init_args:
41
+ a: 4.961271013084809
42
+ epsilon: null
43
+ differential_equation_type: "ODE"
44
+ integrator_kwargs:
45
+ method: "euler"
46
+ velocity_annealing_factor: 1.1379701544400436
47
+ correct_center_of_mass_motion: false
48
+ data_fields:
49
+ # if the order of the data_fields changes,
50
+ # the order of the above StochasticInterpolant inputs must also change
51
+ - "species"
52
+ - "pos"
53
+ - "cell"
54
+ integration_time_steps: 710
55
+ relative_si_costs:
56
+ species_loss: 0.0
57
+ pos_loss_b: 0.6143090042317803
58
+ pos_loss_z: 0.3794040725288834
59
+ cell_loss_b: 0.00628692323933625
60
+ sampler:
61
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
62
+ init_args:
63
+ pos_distribution: null
64
+ cell_distribution:
65
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
66
+ init_args:
67
+ dataset_name: mp_20
68
+ species_distribution:
69
+ class_path: omg.sampler.distributions.MirrorData
70
+ model:
71
+ class_path: omg.model.model.Model
72
+ init_args:
73
+ encoder:
74
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
75
+ head:
76
+ class_path: omg.model.heads.pass_through.PassThrough
77
+ time_embedder:
78
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
79
+ init_args:
80
+ dim: 256
81
+ use_min_perm_dist: False
82
+ float_32_matmul_precision: "high"
83
+ validation_mode: "match_rate"
84
+ dataset_name: "mp_20"
85
+ data:
86
+ train_dataset:
87
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
88
+ init_args:
89
+ dataset:
90
+ class_path: omg.datamodule.datamodule.DataModule
91
+ init_args:
92
+ lmdb_paths:
93
+ - "data/mp_20/train.lmdb"
94
+ niggli: True
95
+ val_dataset:
96
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
97
+ init_args:
98
+ dataset:
99
+ class_path: omg.datamodule.datamodule.DataModule
100
+ init_args:
101
+ lmdb_paths:
102
+ - "data/mp_20/val.lmdb"
103
+ niggli: True
104
+ predict_dataset:
105
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
106
+ init_args:
107
+ dataset:
108
+ class_path: omg.datamodule.datamodule.DataModule
109
+ init_args:
110
+ lmdb_paths:
111
+ - "data/mp_20/test.lmdb"
112
+ niggli: True
113
+ batch_size: 32
114
+ num_workers: 4
115
+ pin_memory: True
116
+ persistent_workers: True
117
+ trainer:
118
+ callbacks:
119
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
120
+ init_args:
121
+ filename: "best_val_loss_total"
122
+ save_top_k: 1
123
+ monitor: "val_loss_total"
124
+ save_weights_only: true
125
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
126
+ init_args:
127
+ filename: "best_val_match_rate"
128
+ save_top_k: 1
129
+ monitor: "match_rate"
130
+ save_weights_only: true
131
+ mode: 'max'
132
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
133
+ init_args:
134
+ filename: "best_val_rmsd"
135
+ save_top_k: 1
136
+ monitor: "mean_rmsd"
137
+ save_weights_only: true
138
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
139
+ init_args:
140
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
141
+ monitor: "val_loss_total"
142
+ every_n_epochs: 100
143
+ save_weights_only: false
144
+ gradient_clip_val: 0.5
145
+ num_sanity_val_steps: 0
146
+ precision: "32-true"
147
+ max_epochs: 2000
148
+ enable_progress_bar: false
149
+ check_val_every_n_epoch: 100
150
+ optimizer:
151
+ class_path: torch.optim.Adam
152
+ init_args:
153
+ lr: 0.00018567271191860665