Upload 22 files
Browse files- EncDec-ODE-Gamma/checkpoint.ckpt +3 -0
- EncDec-ODE-Gamma/train.yaml +168 -0
- EncDec-SDE-Gamma/checkpoint.ckpt +3 -0
- EncDec-SDE-Gamma/train.yaml +195 -0
- Linear-ODE-Gamma/checkpoint.ckpt +3 -0
- Linear-ODE-Gamma/train.yaml +195 -0
- Linear-ODE/checkpoint.ckpt +3 -0
- Linear-ODE/train.yaml +168 -0
- Linear-SDE-Gamma/checkpoint.ckpt +3 -0
- Linear-SDE-Gamma/train.yaml +186 -0
- Trig-ODE-Gamma/checkpoint.ckpt +3 -0
- Trig-ODE-Gamma/train.yaml +189 -0
- Trig-ODE/checkpoint.ckpt +3 -0
- Trig-ODE/train.yaml +179 -0
- Trig-SDE-Gamma/checkpoint.ckpt +3 -0
- Trig-SDE-Gamma/train.yaml +189 -0
- VESBD-ODE/checkpoint.ckpt +3 -0
- VESBD-ODE/train.yaml +196 -0
- VPSBD-ODE/checkpoint.ckpt +3 -0
- VPSBD-ODE/train.yaml +182 -0
- VPSBD-SDE/checkpoint.ckpt +3 -0
- VPSBD-SDE/train.yaml +196 -0
EncDec-ODE-Gamma/checkpoint.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2b523412c0a025405dbef142ff19d6ae264152a73bc7361a745bd4ed8fc3829e
|
| 3 |
+
size 49646459
|
EncDec-ODE-Gamma/train.yaml
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
si:
|
| 3 |
+
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
|
| 4 |
+
init_args:
|
| 5 |
+
stochastic_interpolants:
|
| 6 |
+
# chemical species
|
| 7 |
+
- class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
|
| 8 |
+
init_args:
|
| 9 |
+
noise: 0.8465315128078521
|
| 10 |
+
# fractional coordinates
|
| 11 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 12 |
+
init_args:
|
| 13 |
+
interpolant: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
|
| 14 |
+
gamma: omg.si.gamma.LatentGammaEncoderDecoder
|
| 15 |
+
epsilon: null
|
| 16 |
+
differential_equation_type: "ODE"
|
| 17 |
+
integrator_kwargs:
|
| 18 |
+
method: "euler"
|
| 19 |
+
velocity_annealing_factor: 10.274308845621986
|
| 20 |
+
correct_center_of_mass_motion: true
|
| 21 |
+
# lattice vectors
|
| 22 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 23 |
+
init_args:
|
| 24 |
+
interpolant: omg.si.interpolants.LinearInterpolant
|
| 25 |
+
gamma: null
|
| 26 |
+
epsilon: null
|
| 27 |
+
differential_equation_type: "ODE"
|
| 28 |
+
integrator_kwargs:
|
| 29 |
+
method: "euler"
|
| 30 |
+
velocity_annealing_factor: 0.08217016314129522
|
| 31 |
+
correct_center_of_mass_motion: false
|
| 32 |
+
data_fields:
|
| 33 |
+
# if the order of the data_fields changes,
|
| 34 |
+
# the order of the above StochasticInterpolant inputs must also change
|
| 35 |
+
- "species"
|
| 36 |
+
- "pos"
|
| 37 |
+
- "cell"
|
| 38 |
+
integration_time_steps: 840
|
| 39 |
+
relative_si_costs:
|
| 40 |
+
species_loss: 0.2648412544596816
|
| 41 |
+
pos_loss_b: 0.7267924862588087
|
| 42 |
+
cell_loss_b: 0.008366259281509825
|
| 43 |
+
sampler:
|
| 44 |
+
class_path: omg.sampler.sample_from_rng.SampleFromRNG
|
| 45 |
+
init_args:
|
| 46 |
+
pos_distribution: null
|
| 47 |
+
cell_distribution:
|
| 48 |
+
class_path: omg.sampler.distributions.InformedLatticeDistribution
|
| 49 |
+
init_args:
|
| 50 |
+
dataset_name: mp_20
|
| 51 |
+
species_distribution:
|
| 52 |
+
class_path: omg.sampler.distributions.MaskDistribution
|
| 53 |
+
model:
|
| 54 |
+
class_path: omg.model.model.Model
|
| 55 |
+
init_args:
|
| 56 |
+
encoder:
|
| 57 |
+
class_path: omg.model.encoders.cspnet_full.CSPNetFull
|
| 58 |
+
head:
|
| 59 |
+
class_path: omg.model.heads.pass_through.PassThrough
|
| 60 |
+
time_embedder:
|
| 61 |
+
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
|
| 62 |
+
init_args:
|
| 63 |
+
dim: 256
|
| 64 |
+
use_min_perm_dist: False
|
| 65 |
+
float_32_matmul_precision: "high"
|
| 66 |
+
validation_mode: "dng_eval"
|
| 67 |
+
dataset_name: "mp_20"
|
| 68 |
+
data:
|
| 69 |
+
train_dataset:
|
| 70 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 71 |
+
init_args:
|
| 72 |
+
dataset:
|
| 73 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 74 |
+
init_args:
|
| 75 |
+
lmdb_paths:
|
| 76 |
+
- "data/mp_20/train.lmdb"
|
| 77 |
+
niggli: False
|
| 78 |
+
val_dataset:
|
| 79 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 80 |
+
init_args:
|
| 81 |
+
dataset:
|
| 82 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 83 |
+
init_args:
|
| 84 |
+
lmdb_paths:
|
| 85 |
+
- "data/mp_20/val.lmdb"
|
| 86 |
+
niggli: False
|
| 87 |
+
predict_dataset:
|
| 88 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 89 |
+
init_args:
|
| 90 |
+
dataset:
|
| 91 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 92 |
+
init_args:
|
| 93 |
+
lmdb_paths:
|
| 94 |
+
- "data/mp_20/test.lmdb"
|
| 95 |
+
niggli: False
|
| 96 |
+
batch_size: 128
|
| 97 |
+
num_workers: 4
|
| 98 |
+
pin_memory: True
|
| 99 |
+
persistent_workers: True
|
| 100 |
+
trainer:
|
| 101 |
+
callbacks:
|
| 102 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 103 |
+
init_args:
|
| 104 |
+
filename: "best_val_loss_total"
|
| 105 |
+
save_top_k: 1
|
| 106 |
+
monitor: "val_loss_total"
|
| 107 |
+
save_weights_only: true
|
| 108 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 109 |
+
init_args:
|
| 110 |
+
filename: "best_val_dng_eval"
|
| 111 |
+
save_top_k: 1
|
| 112 |
+
monitor: "dng_eval"
|
| 113 |
+
save_weights_only: true
|
| 114 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 115 |
+
init_args:
|
| 116 |
+
filename: "best_val_wdist_density"
|
| 117 |
+
save_top_k: 1
|
| 118 |
+
monitor: "wdist_density"
|
| 119 |
+
save_weights_only: true
|
| 120 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 121 |
+
init_args:
|
| 122 |
+
filename: "best_val_wdist_Nary"
|
| 123 |
+
save_top_k: 1
|
| 124 |
+
monitor: "wdist_Nary"
|
| 125 |
+
save_weights_only: true
|
| 126 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 127 |
+
init_args:
|
| 128 |
+
filename: "best_val_wdist_CN"
|
| 129 |
+
save_top_k: 1
|
| 130 |
+
monitor: "wdist_CN"
|
| 131 |
+
save_weights_only: true
|
| 132 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 133 |
+
init_args:
|
| 134 |
+
filename: "best_val_cov_precision"
|
| 135 |
+
save_top_k: 1
|
| 136 |
+
monitor: "cov_precision"
|
| 137 |
+
mode: "max"
|
| 138 |
+
save_weights_only: true
|
| 139 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 140 |
+
init_args:
|
| 141 |
+
filename: "best_val_cov_recall"
|
| 142 |
+
save_top_k: 1
|
| 143 |
+
monitor: "cov_recall"
|
| 144 |
+
mode: "max"
|
| 145 |
+
save_weights_only: true
|
| 146 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 147 |
+
init_args:
|
| 148 |
+
filename: "best_val_validity"
|
| 149 |
+
save_top_k: 1
|
| 150 |
+
monitor: "validity_rate"
|
| 151 |
+
mode: "max"
|
| 152 |
+
save_weights_only: true
|
| 153 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 154 |
+
init_args:
|
| 155 |
+
save_top_k: -1 # Store every checkpoint after 100 epochs.
|
| 156 |
+
monitor: "val_loss_total"
|
| 157 |
+
every_n_epochs: 100
|
| 158 |
+
save_weights_only: false
|
| 159 |
+
gradient_clip_val: 0.5
|
| 160 |
+
num_sanity_val_steps: 0
|
| 161 |
+
precision: "32-true"
|
| 162 |
+
max_epochs: 10000
|
| 163 |
+
enable_progress_bar: false
|
| 164 |
+
check_val_every_n_epoch: 100
|
| 165 |
+
optimizer:
|
| 166 |
+
class_path: torch.optim.Adam
|
| 167 |
+
init_args:
|
| 168 |
+
lr: 0.00012021943412654004
|
EncDec-SDE-Gamma/checkpoint.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5242724db2a21bd949624fe805e84986b0f0ccfb229913bf3f9cf79ce5540f3d
|
| 3 |
+
size 148494714
|
EncDec-SDE-Gamma/train.yaml
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
si:
|
| 3 |
+
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
|
| 4 |
+
init_args:
|
| 5 |
+
stochastic_interpolants:
|
| 6 |
+
# chemical species
|
| 7 |
+
- class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
|
| 8 |
+
init_args:
|
| 9 |
+
noise: 19.77565076697948
|
| 10 |
+
# fractional coordinates
|
| 11 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 12 |
+
init_args:
|
| 13 |
+
interpolant:
|
| 14 |
+
class_path: omg.si.interpolants.PeriodicEncoderDecoderInterpolant
|
| 15 |
+
init_args:
|
| 16 |
+
switch_time: 0.7261434470495144
|
| 17 |
+
power: 0.5
|
| 18 |
+
gamma:
|
| 19 |
+
class_path: omg.si.gamma.LatentGammaEncoderDecoder
|
| 20 |
+
init_args:
|
| 21 |
+
a: 0.10379253121526635
|
| 22 |
+
switch_time: 0.7261434470495144
|
| 23 |
+
power: 0.5
|
| 24 |
+
epsilon:
|
| 25 |
+
class_path: omg.si.epsilon.VanishingEpsilon
|
| 26 |
+
init_args:
|
| 27 |
+
c: 7.095366578175936
|
| 28 |
+
mu: 0.18874541537289413
|
| 29 |
+
sigma: 0.020535877713041894
|
| 30 |
+
differential_equation_type: "SDE"
|
| 31 |
+
integrator_kwargs:
|
| 32 |
+
method: "euler"
|
| 33 |
+
dt: 0.0016387520590797067
|
| 34 |
+
velocity_annealing_factor: 7.868988752162741
|
| 35 |
+
correct_center_of_mass_motion: true
|
| 36 |
+
# lattice vectors
|
| 37 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 38 |
+
init_args:
|
| 39 |
+
interpolant: omg.si.interpolants.LinearInterpolant
|
| 40 |
+
gamma:
|
| 41 |
+
class_path: omg.si.gamma.LatentGammaSqrt
|
| 42 |
+
init_args:
|
| 43 |
+
a: 1.6505152121826552
|
| 44 |
+
epsilon: null
|
| 45 |
+
differential_equation_type: "ODE"
|
| 46 |
+
integrator_kwargs:
|
| 47 |
+
method: "euler"
|
| 48 |
+
velocity_annealing_factor: 3.919302532270132
|
| 49 |
+
correct_center_of_mass_motion: false
|
| 50 |
+
data_fields:
|
| 51 |
+
# if the order of the data_fields changes,
|
| 52 |
+
# the order of the above StochasticInterpolant inputs must also change
|
| 53 |
+
- "species"
|
| 54 |
+
- "pos"
|
| 55 |
+
- "cell"
|
| 56 |
+
integration_time_steps: 610
|
| 57 |
+
relative_si_costs:
|
| 58 |
+
species_loss: 0.4341084689667317
|
| 59 |
+
pos_loss_b: 0.21431903999859292
|
| 60 |
+
pos_loss_z: 0.1968226417175204
|
| 61 |
+
cell_loss_b: 0.15474984931715505
|
| 62 |
+
sampler:
|
| 63 |
+
class_path: omg.sampler.sample_from_rng.SampleFromRNG
|
| 64 |
+
init_args:
|
| 65 |
+
pos_distribution: null
|
| 66 |
+
cell_distribution:
|
| 67 |
+
class_path: omg.sampler.distributions.InformedLatticeDistribution
|
| 68 |
+
init_args:
|
| 69 |
+
dataset_name: mp_20
|
| 70 |
+
species_distribution:
|
| 71 |
+
class_path: omg.sampler.distributions.MaskDistribution
|
| 72 |
+
model:
|
| 73 |
+
class_path: omg.model.model.Model
|
| 74 |
+
init_args:
|
| 75 |
+
encoder:
|
| 76 |
+
class_path: omg.model.encoders.cspnet_full.CSPNetFull
|
| 77 |
+
head:
|
| 78 |
+
class_path: omg.model.heads.pass_through.PassThrough
|
| 79 |
+
time_embedder:
|
| 80 |
+
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
|
| 81 |
+
init_args:
|
| 82 |
+
dim: 256
|
| 83 |
+
use_min_perm_dist: False
|
| 84 |
+
float_32_matmul_precision: "high"
|
| 85 |
+
validation_mode: "dng_eval"
|
| 86 |
+
dataset_name: "mp_20"
|
| 87 |
+
data:
|
| 88 |
+
train_dataset:
|
| 89 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 90 |
+
init_args:
|
| 91 |
+
dataset:
|
| 92 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 93 |
+
init_args:
|
| 94 |
+
lmdb_paths:
|
| 95 |
+
- "data/mp_20/train.lmdb"
|
| 96 |
+
niggli: False
|
| 97 |
+
val_dataset:
|
| 98 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 99 |
+
init_args:
|
| 100 |
+
dataset:
|
| 101 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 102 |
+
init_args:
|
| 103 |
+
lmdb_paths:
|
| 104 |
+
- "data/mp_20/val.lmdb"
|
| 105 |
+
niggli: False
|
| 106 |
+
predict_dataset:
|
| 107 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 108 |
+
init_args:
|
| 109 |
+
dataset:
|
| 110 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 111 |
+
init_args:
|
| 112 |
+
lmdb_paths:
|
| 113 |
+
- "data/mp_20/test.lmdb"
|
| 114 |
+
niggli: False
|
| 115 |
+
batch_size: 32
|
| 116 |
+
num_workers: 4
|
| 117 |
+
pin_memory: True
|
| 118 |
+
persistent_workers: True
|
| 119 |
+
trainer:
|
| 120 |
+
callbacks:
|
| 121 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 122 |
+
init_args:
|
| 123 |
+
filename: "best_val_loss_total"
|
| 124 |
+
save_top_k: 1
|
| 125 |
+
monitor: "val_loss_total"
|
| 126 |
+
save_weights_only: true
|
| 127 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 128 |
+
init_args:
|
| 129 |
+
filename: "best_val_dng_eval"
|
| 130 |
+
save_top_k: 1
|
| 131 |
+
monitor: "dng_eval"
|
| 132 |
+
save_weights_only: true
|
| 133 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 134 |
+
init_args:
|
| 135 |
+
filename: "best_val_wdist_density"
|
| 136 |
+
save_top_k: 1
|
| 137 |
+
monitor: "wdist_density"
|
| 138 |
+
save_weights_only: true
|
| 139 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 140 |
+
init_args:
|
| 141 |
+
filename: "best_val_wdist_Nary"
|
| 142 |
+
save_top_k: 1
|
| 143 |
+
monitor: "wdist_Nary"
|
| 144 |
+
save_weights_only: true
|
| 145 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 146 |
+
init_args:
|
| 147 |
+
filename: "best_val_wdist_CN"
|
| 148 |
+
save_top_k: 1
|
| 149 |
+
monitor: "wdist_CN"
|
| 150 |
+
save_weights_only: true
|
| 151 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 152 |
+
init_args:
|
| 153 |
+
filename: "best_val_cov_precision"
|
| 154 |
+
save_top_k: 1
|
| 155 |
+
monitor: "cov_precision"
|
| 156 |
+
mode: "max"
|
| 157 |
+
save_weights_only: true
|
| 158 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 159 |
+
init_args:
|
| 160 |
+
filename: "best_val_cov_recall"
|
| 161 |
+
save_top_k: 1
|
| 162 |
+
monitor: "cov_recall"
|
| 163 |
+
mode: "max"
|
| 164 |
+
save_weights_only: true
|
| 165 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 166 |
+
init_args:
|
| 167 |
+
filename: "best_val_validity"
|
| 168 |
+
save_top_k: 1
|
| 169 |
+
monitor: "validity_rate"
|
| 170 |
+
mode: "max"
|
| 171 |
+
save_weights_only: true
|
| 172 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 173 |
+
init_args:
|
| 174 |
+
save_top_k: -1 # Store every checkpoint after 100 epochs.
|
| 175 |
+
monitor: "val_loss_total"
|
| 176 |
+
every_n_epochs: 100
|
| 177 |
+
save_weights_only: false
|
| 178 |
+
gradient_clip_val: 0.5
|
| 179 |
+
gradient_clip_algorithm: "value"
|
| 180 |
+
num_sanity_val_steps: 0
|
| 181 |
+
precision: "32-true"
|
| 182 |
+
max_epochs: 2000
|
| 183 |
+
enable_progress_bar: false
|
| 184 |
+
check_val_every_n_epoch: 100
|
| 185 |
+
optimizer:
|
| 186 |
+
class_path: torch.optim.AdamW
|
| 187 |
+
init_args:
|
| 188 |
+
lr: 3.139610174577985e-05
|
| 189 |
+
weight_decay: 3.560067412494533e-05
|
| 190 |
+
lr_scheduler:
|
| 191 |
+
class_path: torch.optim.lr_scheduler.CosineAnnealingLR
|
| 192 |
+
init_args:
|
| 193 |
+
T_max: 2000
|
| 194 |
+
eta_min: 1e-07
|
| 195 |
+
|
Linear-ODE-Gamma/checkpoint.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6280e04b110be03b5bf5b727e34d9c42c3ba3ce58c943a7bea6db8824ddfdc1d
|
| 3 |
+
size 148519290
|
Linear-ODE-Gamma/train.yaml
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
si:
|
| 3 |
+
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
|
| 4 |
+
init_args:
|
| 5 |
+
stochastic_interpolants:
|
| 6 |
+
# chemical species
|
| 7 |
+
- class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
|
| 8 |
+
init_args:
|
| 9 |
+
noise: 23.870491382634235
|
| 10 |
+
# fractional coordinates
|
| 11 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 12 |
+
init_args:
|
| 13 |
+
interpolant: omg.si.interpolants.PeriodicLinearInterpolant
|
| 14 |
+
gamma:
|
| 15 |
+
class_path: omg.si.gamma.LatentGammaSqrt
|
| 16 |
+
init_args:
|
| 17 |
+
a: 1.4501684803942854
|
| 18 |
+
epsilon: null
|
| 19 |
+
differential_equation_type: "ODE"
|
| 20 |
+
integrator_kwargs:
|
| 21 |
+
method: "euler"
|
| 22 |
+
velocity_annealing_factor: 14.825022083056373
|
| 23 |
+
correct_center_of_mass_motion: true
|
| 24 |
+
# lattice vectors
|
| 25 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 26 |
+
init_args:
|
| 27 |
+
interpolant:
|
| 28 |
+
class_path: omg.si.interpolants.EncoderDecoderInterpolant
|
| 29 |
+
init_args:
|
| 30 |
+
switch_time: 0.13766881993242777
|
| 31 |
+
power: 1.0
|
| 32 |
+
gamma:
|
| 33 |
+
class_path: omg.si.gamma.LatentGammaEncoderDecoder
|
| 34 |
+
init_args:
|
| 35 |
+
a: 7.882046441638109
|
| 36 |
+
switch_time: 0.13766881993242777
|
| 37 |
+
power: 1.0
|
| 38 |
+
epsilon:
|
| 39 |
+
class_path: omg.si.epsilon.VanishingEpsilon
|
| 40 |
+
init_args:
|
| 41 |
+
c: 5.487699104615115
|
| 42 |
+
mu: 0.2899409657474152
|
| 43 |
+
sigma: 0.010062500495585096
|
| 44 |
+
differential_equation_type: "SDE"
|
| 45 |
+
integrator_kwargs:
|
| 46 |
+
method: "euler"
|
| 47 |
+
dt: 0.007736434228718281
|
| 48 |
+
velocity_annealing_factor: 5.9072140831863305
|
| 49 |
+
correct_center_of_mass_motion: false
|
| 50 |
+
data_fields:
|
| 51 |
+
# if the order of the data_fields changes,
|
| 52 |
+
# the order of the above StochasticInterpolant inputs must also change
|
| 53 |
+
- "species"
|
| 54 |
+
- "pos"
|
| 55 |
+
- "cell"
|
| 56 |
+
integration_time_steps: 130
|
| 57 |
+
relative_si_costs:
|
| 58 |
+
species_loss: 0.22163297905676768
|
| 59 |
+
pos_loss_b: 0.7682858351574293
|
| 60 |
+
cell_loss_b: 0.008860420349864338
|
| 61 |
+
cell_loss_z: 0.001220765435938645
|
| 62 |
+
sampler:
|
| 63 |
+
class_path: omg.sampler.sample_from_rng.SampleFromRNG
|
| 64 |
+
init_args:
|
| 65 |
+
pos_distribution: null
|
| 66 |
+
cell_distribution:
|
| 67 |
+
class_path: omg.sampler.distributions.InformedLatticeDistribution
|
| 68 |
+
init_args:
|
| 69 |
+
dataset_name: mp_20
|
| 70 |
+
species_distribution:
|
| 71 |
+
class_path: omg.sampler.distributions.MaskDistribution
|
| 72 |
+
model:
|
| 73 |
+
class_path: omg.model.model.Model
|
| 74 |
+
init_args:
|
| 75 |
+
encoder:
|
| 76 |
+
class_path: omg.model.encoders.cspnet_full.CSPNetFull
|
| 77 |
+
head:
|
| 78 |
+
class_path: omg.model.heads.pass_through.PassThrough
|
| 79 |
+
time_embedder:
|
| 80 |
+
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
|
| 81 |
+
init_args:
|
| 82 |
+
dim: 256
|
| 83 |
+
use_min_perm_dist: True
|
| 84 |
+
float_32_matmul_precision: "high"
|
| 85 |
+
validation_mode: "dng_eval"
|
| 86 |
+
dataset_name: "mp_20"
|
| 87 |
+
data:
|
| 88 |
+
train_dataset:
|
| 89 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 90 |
+
init_args:
|
| 91 |
+
dataset:
|
| 92 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 93 |
+
init_args:
|
| 94 |
+
lmdb_paths:
|
| 95 |
+
- "data/mp_20/train.lmdb"
|
| 96 |
+
niggli: False
|
| 97 |
+
val_dataset:
|
| 98 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 99 |
+
init_args:
|
| 100 |
+
dataset:
|
| 101 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 102 |
+
init_args:
|
| 103 |
+
lmdb_paths:
|
| 104 |
+
- "data/mp_20/val.lmdb"
|
| 105 |
+
niggli: False
|
| 106 |
+
predict_dataset:
|
| 107 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 108 |
+
init_args:
|
| 109 |
+
dataset:
|
| 110 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 111 |
+
init_args:
|
| 112 |
+
lmdb_paths:
|
| 113 |
+
- "data/mp_20/test.lmdb"
|
| 114 |
+
niggli: False
|
| 115 |
+
batch_size: 256
|
| 116 |
+
num_workers: 4
|
| 117 |
+
pin_memory: True
|
| 118 |
+
persistent_workers: True
|
| 119 |
+
trainer:
|
| 120 |
+
callbacks:
|
| 121 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 122 |
+
init_args:
|
| 123 |
+
filename: "best_val_loss_total"
|
| 124 |
+
save_top_k: 1
|
| 125 |
+
monitor: "val_loss_total"
|
| 126 |
+
save_weights_only: true
|
| 127 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 128 |
+
init_args:
|
| 129 |
+
filename: "best_val_dng_eval"
|
| 130 |
+
save_top_k: 1
|
| 131 |
+
monitor: "dng_eval"
|
| 132 |
+
save_weights_only: true
|
| 133 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 134 |
+
init_args:
|
| 135 |
+
filename: "best_val_wdist_density"
|
| 136 |
+
save_top_k: 1
|
| 137 |
+
monitor: "wdist_density"
|
| 138 |
+
save_weights_only: true
|
| 139 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 140 |
+
init_args:
|
| 141 |
+
filename: "best_val_wdist_Nary"
|
| 142 |
+
save_top_k: 1
|
| 143 |
+
monitor: "wdist_Nary"
|
| 144 |
+
save_weights_only: true
|
| 145 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 146 |
+
init_args:
|
| 147 |
+
filename: "best_val_wdist_CN"
|
| 148 |
+
save_top_k: 1
|
| 149 |
+
monitor: "wdist_CN"
|
| 150 |
+
save_weights_only: true
|
| 151 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 152 |
+
init_args:
|
| 153 |
+
filename: "best_val_cov_precision"
|
| 154 |
+
save_top_k: 1
|
| 155 |
+
monitor: "cov_precision"
|
| 156 |
+
mode: "max"
|
| 157 |
+
save_weights_only: true
|
| 158 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 159 |
+
init_args:
|
| 160 |
+
filename: "best_val_cov_recall"
|
| 161 |
+
save_top_k: 1
|
| 162 |
+
monitor: "cov_recall"
|
| 163 |
+
mode: "max"
|
| 164 |
+
save_weights_only: true
|
| 165 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 166 |
+
init_args:
|
| 167 |
+
filename: "best_val_validity"
|
| 168 |
+
save_top_k: 1
|
| 169 |
+
monitor: "validity_rate"
|
| 170 |
+
mode: "max"
|
| 171 |
+
save_weights_only: true
|
| 172 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 173 |
+
init_args:
|
| 174 |
+
save_top_k: -1 # Store every checkpoint after 100 epochs.
|
| 175 |
+
monitor: "val_loss_total"
|
| 176 |
+
every_n_epochs: 100
|
| 177 |
+
save_weights_only: false
|
| 178 |
+
gradient_clip_val: 0.5
|
| 179 |
+
gradient_clip_algorithm: "value"
|
| 180 |
+
num_sanity_val_steps: 0
|
| 181 |
+
precision: "32-true"
|
| 182 |
+
max_epochs: 2000
|
| 183 |
+
enable_progress_bar: false
|
| 184 |
+
check_val_every_n_epoch: 100
|
| 185 |
+
optimizer:
|
| 186 |
+
class_path: torch.optim.AdamW
|
| 187 |
+
init_args:
|
| 188 |
+
lr: 0.00019511523812262233
|
| 189 |
+
weight_decay: 0.0003813804936812436
|
| 190 |
+
lr_scheduler:
|
| 191 |
+
class_path: torch.optim.lr_scheduler.CosineAnnealingLR
|
| 192 |
+
init_args:
|
| 193 |
+
T_max: 2000
|
| 194 |
+
eta_min: 1e-07
|
| 195 |
+
|
Linear-ODE/checkpoint.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9d14cec97c25c273f97009965293702ec5eda1f56648cfa0f2fac0acdf6b3459
|
| 3 |
+
size 49646459
|
Linear-ODE/train.yaml
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
si:
|
| 3 |
+
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
|
| 4 |
+
init_args:
|
| 5 |
+
stochastic_interpolants:
|
| 6 |
+
# chemical species
|
| 7 |
+
- class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
|
| 8 |
+
init_args:
|
| 9 |
+
noise: 7.080372063368751
|
| 10 |
+
# fractional coordinates
|
| 11 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 12 |
+
init_args:
|
| 13 |
+
interpolant: omg.si.interpolants.PeriodicLinearInterpolant
|
| 14 |
+
gamma: null
|
| 15 |
+
epsilon: null
|
| 16 |
+
differential_equation_type: "ODE"
|
| 17 |
+
integrator_kwargs:
|
| 18 |
+
method: "euler"
|
| 19 |
+
velocity_annealing_factor: 13.620695525269845
|
| 20 |
+
correct_center_of_mass_motion: true
|
| 21 |
+
# lattice vectors
|
| 22 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 23 |
+
init_args:
|
| 24 |
+
interpolant: omg.si.interpolants.LinearInterpolant
|
| 25 |
+
gamma: null
|
| 26 |
+
epsilon: null
|
| 27 |
+
differential_equation_type: "ODE"
|
| 28 |
+
integrator_kwargs:
|
| 29 |
+
method: "euler"
|
| 30 |
+
velocity_annealing_factor: 1.0679662602192312
|
| 31 |
+
correct_center_of_mass_motion: false
|
| 32 |
+
data_fields:
|
| 33 |
+
# if the order of the data_fields changes,
|
| 34 |
+
# the order of the above StochasticInterpolant inputs must also change
|
| 35 |
+
- "species"
|
| 36 |
+
- "pos"
|
| 37 |
+
- "cell"
|
| 38 |
+
integration_time_steps: 150
|
| 39 |
+
relative_si_costs:
|
| 40 |
+
species_loss: 0.021815399034596464
|
| 41 |
+
pos_loss_b: 0.9775483266595605
|
| 42 |
+
cell_loss_b: 0.0006362743058429793
|
| 43 |
+
sampler:
|
| 44 |
+
class_path: omg.sampler.sample_from_rng.SampleFromRNG
|
| 45 |
+
init_args:
|
| 46 |
+
pos_distribution: null
|
| 47 |
+
cell_distribution:
|
| 48 |
+
class_path: omg.sampler.distributions.InformedLatticeDistribution
|
| 49 |
+
init_args:
|
| 50 |
+
dataset_name: mp_20
|
| 51 |
+
species_distribution:
|
| 52 |
+
class_path: omg.sampler.distributions.MaskDistribution
|
| 53 |
+
model:
|
| 54 |
+
class_path: omg.model.model.Model
|
| 55 |
+
init_args:
|
| 56 |
+
encoder:
|
| 57 |
+
class_path: omg.model.encoders.cspnet_full.CSPNetFull
|
| 58 |
+
head:
|
| 59 |
+
class_path: omg.model.heads.pass_through.PassThrough
|
| 60 |
+
time_embedder:
|
| 61 |
+
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
|
| 62 |
+
init_args:
|
| 63 |
+
dim: 256
|
| 64 |
+
use_min_perm_dist: True
|
| 65 |
+
float_32_matmul_precision: "high"
|
| 66 |
+
validation_mode: "dng_eval"
|
| 67 |
+
dataset_name: "mp_20"
|
| 68 |
+
data:
|
| 69 |
+
train_dataset:
|
| 70 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 71 |
+
init_args:
|
| 72 |
+
dataset:
|
| 73 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 74 |
+
init_args:
|
| 75 |
+
lmdb_paths:
|
| 76 |
+
- "data/mp_20/train.lmdb"
|
| 77 |
+
niggli: False
|
| 78 |
+
val_dataset:
|
| 79 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 80 |
+
init_args:
|
| 81 |
+
dataset:
|
| 82 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 83 |
+
init_args:
|
| 84 |
+
lmdb_paths:
|
| 85 |
+
- "data/mp_20/val.lmdb"
|
| 86 |
+
niggli: False
|
| 87 |
+
predict_dataset:
|
| 88 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 89 |
+
init_args:
|
| 90 |
+
dataset:
|
| 91 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 92 |
+
init_args:
|
| 93 |
+
lmdb_paths:
|
| 94 |
+
- "data/mp_20/test.lmdb"
|
| 95 |
+
niggli: False
|
| 96 |
+
batch_size: 512
|
| 97 |
+
num_workers: 4
|
| 98 |
+
pin_memory: True
|
| 99 |
+
persistent_workers: True
|
| 100 |
+
trainer:
|
| 101 |
+
callbacks:
|
| 102 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 103 |
+
init_args:
|
| 104 |
+
filename: "best_val_loss_total"
|
| 105 |
+
save_top_k: 1
|
| 106 |
+
monitor: "val_loss_total"
|
| 107 |
+
save_weights_only: true
|
| 108 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 109 |
+
init_args:
|
| 110 |
+
filename: "best_val_dng_eval"
|
| 111 |
+
save_top_k: 1
|
| 112 |
+
monitor: "dng_eval"
|
| 113 |
+
save_weights_only: true
|
| 114 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 115 |
+
init_args:
|
| 116 |
+
filename: "best_val_wdist_density"
|
| 117 |
+
save_top_k: 1
|
| 118 |
+
monitor: "wdist_density"
|
| 119 |
+
save_weights_only: true
|
| 120 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 121 |
+
init_args:
|
| 122 |
+
filename: "best_val_wdist_Nary"
|
| 123 |
+
save_top_k: 1
|
| 124 |
+
monitor: "wdist_Nary"
|
| 125 |
+
save_weights_only: true
|
| 126 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 127 |
+
init_args:
|
| 128 |
+
filename: "best_val_wdist_CN"
|
| 129 |
+
save_top_k: 1
|
| 130 |
+
monitor: "wdist_CN"
|
| 131 |
+
save_weights_only: true
|
| 132 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 133 |
+
init_args:
|
| 134 |
+
filename: "best_val_cov_precision"
|
| 135 |
+
save_top_k: 1
|
| 136 |
+
monitor: "cov_precision"
|
| 137 |
+
mode: "max"
|
| 138 |
+
save_weights_only: true
|
| 139 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 140 |
+
init_args:
|
| 141 |
+
filename: "best_val_cov_recall"
|
| 142 |
+
save_top_k: 1
|
| 143 |
+
monitor: "cov_recall"
|
| 144 |
+
mode: "max"
|
| 145 |
+
save_weights_only: true
|
| 146 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 147 |
+
init_args:
|
| 148 |
+
filename: "best_val_validity"
|
| 149 |
+
save_top_k: 1
|
| 150 |
+
monitor: "validity_rate"
|
| 151 |
+
mode: "max"
|
| 152 |
+
save_weights_only: true
|
| 153 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 154 |
+
init_args:
|
| 155 |
+
save_top_k: -1 # Store every checkpoint after 100 epochs.
|
| 156 |
+
monitor: "val_loss_total"
|
| 157 |
+
every_n_epochs: 100
|
| 158 |
+
save_weights_only: false
|
| 159 |
+
gradient_clip_val: 0.5
|
| 160 |
+
num_sanity_val_steps: 0
|
| 161 |
+
precision: "32-true"
|
| 162 |
+
max_epochs: 10000
|
| 163 |
+
enable_progress_bar: false
|
| 164 |
+
check_val_every_n_epoch: 100
|
| 165 |
+
optimizer:
|
| 166 |
+
class_path: torch.optim.Adam
|
| 167 |
+
init_args:
|
| 168 |
+
lr: 0.001736512450391209
|
Linear-SDE-Gamma/checkpoint.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:33f3460360ee97b7f2f18fd4d9553b7f1c9a523f4470bb0587e236c9b4022a6a
|
| 3 |
+
size 148494394
|
Linear-SDE-Gamma/train.yaml
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
si:
|
| 3 |
+
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
|
| 4 |
+
init_args:
|
| 5 |
+
stochastic_interpolants:
|
| 6 |
+
# chemical species
|
| 7 |
+
- class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
|
| 8 |
+
init_args:
|
| 9 |
+
noise: 0.18946955217679085
|
| 10 |
+
# fractional coordinates
|
| 11 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 12 |
+
init_args:
|
| 13 |
+
interpolant: omg.si.interpolants.PeriodicLinearInterpolant
|
| 14 |
+
gamma:
|
| 15 |
+
class_path: omg.si.gamma.LatentGammaSqrt
|
| 16 |
+
init_args:
|
| 17 |
+
a: 0.018159684059653552
|
| 18 |
+
epsilon:
|
| 19 |
+
class_path: omg.si.epsilon.VanishingEpsilon
|
| 20 |
+
init_args:
|
| 21 |
+
c: 9.74900863316411
|
| 22 |
+
mu: 0.17191546490562354
|
| 23 |
+
sigma: 0.029425925880471573
|
| 24 |
+
differential_equation_type: "SDE"
|
| 25 |
+
integrator_kwargs:
|
| 26 |
+
method: "euler"
|
| 27 |
+
dt: 0.0014076164225116372
|
| 28 |
+
velocity_annealing_factor: 6.334345265874859
|
| 29 |
+
correct_center_of_mass_motion: true
|
| 30 |
+
# lattice vectors
|
| 31 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 32 |
+
init_args:
|
| 33 |
+
interpolant: omg.si.interpolants.LinearInterpolant
|
| 34 |
+
gamma: null
|
| 35 |
+
epsilon: null
|
| 36 |
+
differential_equation_type: "ODE"
|
| 37 |
+
integrator_kwargs:
|
| 38 |
+
method: "euler"
|
| 39 |
+
velocity_annealing_factor: 1.0674474901964888
|
| 40 |
+
correct_center_of_mass_motion: false
|
| 41 |
+
data_fields:
|
| 42 |
+
# if the order of the data_fields changes,
|
| 43 |
+
# the order of the above StochasticInterpolant inputs must also change
|
| 44 |
+
- "species"
|
| 45 |
+
- "pos"
|
| 46 |
+
- "cell"
|
| 47 |
+
integration_time_steps: 710
|
| 48 |
+
relative_si_costs:
|
| 49 |
+
species_loss: 0.5918064683979826
|
| 50 |
+
pos_loss_b: 0.13091891010303253
|
| 51 |
+
pos_loss_z: 0.27077286248215743
|
| 52 |
+
cell_loss_b: 0.006501759016827464
|
| 53 |
+
sampler:
|
| 54 |
+
class_path: omg.sampler.sample_from_rng.SampleFromRNG
|
| 55 |
+
init_args:
|
| 56 |
+
pos_distribution: null
|
| 57 |
+
cell_distribution:
|
| 58 |
+
class_path: omg.sampler.distributions.InformedLatticeDistribution
|
| 59 |
+
init_args:
|
| 60 |
+
dataset_name: mp_20
|
| 61 |
+
species_distribution:
|
| 62 |
+
class_path: omg.sampler.distributions.MaskDistribution
|
| 63 |
+
model:
|
| 64 |
+
class_path: omg.model.model.Model
|
| 65 |
+
init_args:
|
| 66 |
+
encoder:
|
| 67 |
+
class_path: omg.model.encoders.cspnet_full.CSPNetFull
|
| 68 |
+
head:
|
| 69 |
+
class_path: omg.model.heads.pass_through.PassThrough
|
| 70 |
+
time_embedder:
|
| 71 |
+
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
|
| 72 |
+
init_args:
|
| 73 |
+
dim: 256
|
| 74 |
+
use_min_perm_dist: True
|
| 75 |
+
float_32_matmul_precision: "high"
|
| 76 |
+
validation_mode: "dng_eval"
|
| 77 |
+
dataset_name: "mp_20"
|
| 78 |
+
data:
|
| 79 |
+
train_dataset:
|
| 80 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 81 |
+
init_args:
|
| 82 |
+
dataset:
|
| 83 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 84 |
+
init_args:
|
| 85 |
+
lmdb_paths:
|
| 86 |
+
- "data/mp_20/train.lmdb"
|
| 87 |
+
niggli: False
|
| 88 |
+
val_dataset:
|
| 89 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 90 |
+
init_args:
|
| 91 |
+
dataset:
|
| 92 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 93 |
+
init_args:
|
| 94 |
+
lmdb_paths:
|
| 95 |
+
- "data/mp_20/val.lmdb"
|
| 96 |
+
niggli: False
|
| 97 |
+
predict_dataset:
|
| 98 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 99 |
+
init_args:
|
| 100 |
+
dataset:
|
| 101 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 102 |
+
init_args:
|
| 103 |
+
lmdb_paths:
|
| 104 |
+
- "data/mp_20/test.lmdb"
|
| 105 |
+
niggli: False
|
| 106 |
+
batch_size: 32
|
| 107 |
+
num_workers: 4
|
| 108 |
+
pin_memory: True
|
| 109 |
+
persistent_workers: True
|
| 110 |
+
trainer:
|
| 111 |
+
callbacks:
|
| 112 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 113 |
+
init_args:
|
| 114 |
+
filename: "best_val_loss_total"
|
| 115 |
+
save_top_k: 1
|
| 116 |
+
monitor: "val_loss_total"
|
| 117 |
+
save_weights_only: true
|
| 118 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 119 |
+
init_args:
|
| 120 |
+
filename: "best_val_dng_eval"
|
| 121 |
+
save_top_k: 1
|
| 122 |
+
monitor: "dng_eval"
|
| 123 |
+
save_weights_only: true
|
| 124 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 125 |
+
init_args:
|
| 126 |
+
filename: "best_val_wdist_density"
|
| 127 |
+
save_top_k: 1
|
| 128 |
+
monitor: "wdist_density"
|
| 129 |
+
save_weights_only: true
|
| 130 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 131 |
+
init_args:
|
| 132 |
+
filename: "best_val_wdist_Nary"
|
| 133 |
+
save_top_k: 1
|
| 134 |
+
monitor: "wdist_Nary"
|
| 135 |
+
save_weights_only: true
|
| 136 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 137 |
+
init_args:
|
| 138 |
+
filename: "best_val_wdist_CN"
|
| 139 |
+
save_top_k: 1
|
| 140 |
+
monitor: "wdist_CN"
|
| 141 |
+
save_weights_only: true
|
| 142 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 143 |
+
init_args:
|
| 144 |
+
filename: "best_val_cov_precision"
|
| 145 |
+
save_top_k: 1
|
| 146 |
+
monitor: "cov_precision"
|
| 147 |
+
mode: "max"
|
| 148 |
+
save_weights_only: true
|
| 149 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 150 |
+
init_args:
|
| 151 |
+
filename: "best_val_cov_recall"
|
| 152 |
+
save_top_k: 1
|
| 153 |
+
monitor: "cov_recall"
|
| 154 |
+
mode: "max"
|
| 155 |
+
save_weights_only: true
|
| 156 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 157 |
+
init_args:
|
| 158 |
+
filename: "best_val_validity"
|
| 159 |
+
save_top_k: 1
|
| 160 |
+
monitor: "validity_rate"
|
| 161 |
+
mode: "max"
|
| 162 |
+
save_weights_only: true
|
| 163 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 164 |
+
init_args:
|
| 165 |
+
save_top_k: -1 # Store every checkpoint after 100 epochs.
|
| 166 |
+
monitor: "val_loss_total"
|
| 167 |
+
every_n_epochs: 100
|
| 168 |
+
save_weights_only: false
|
| 169 |
+
gradient_clip_val: 0.5
|
| 170 |
+
gradient_clip_algorithm: "value"
|
| 171 |
+
num_sanity_val_steps: 0
|
| 172 |
+
precision: "32-true"
|
| 173 |
+
max_epochs: 2000
|
| 174 |
+
enable_progress_bar: false
|
| 175 |
+
check_val_every_n_epoch: 100
|
| 176 |
+
optimizer:
|
| 177 |
+
class_path: torch.optim.AdamW
|
| 178 |
+
init_args:
|
| 179 |
+
lr: 0.00019745455354877462
|
| 180 |
+
weight_decay: 0.0003111161289640361
|
| 181 |
+
lr_scheduler:
|
| 182 |
+
class_path: torch.optim.lr_scheduler.CosineAnnealingLR
|
| 183 |
+
init_args:
|
| 184 |
+
T_max: 2000
|
| 185 |
+
eta_min: 1e-07
|
| 186 |
+
|
Trig-ODE-Gamma/checkpoint.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e3d3c53b60d074e40860ff25e65bedf14da18eeabef997f1e0a795666ab4559e
|
| 3 |
+
size 148519034
|
Trig-ODE-Gamma/train.yaml
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
si:
|
| 3 |
+
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
|
| 4 |
+
init_args:
|
| 5 |
+
stochastic_interpolants:
|
| 6 |
+
# chemical species
|
| 7 |
+
- class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
|
| 8 |
+
init_args:
|
| 9 |
+
noise: 27.249112246908787
|
| 10 |
+
# fractional coordinates
|
| 11 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 12 |
+
init_args:
|
| 13 |
+
interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
|
| 14 |
+
gamma:
|
| 15 |
+
class_path: omg.si.gamma.LatentGammaSqrt
|
| 16 |
+
init_args:
|
| 17 |
+
a: 0.02697960692675219
|
| 18 |
+
epsilon: null
|
| 19 |
+
differential_equation_type: "ODE"
|
| 20 |
+
integrator_kwargs:
|
| 21 |
+
method: "euler"
|
| 22 |
+
velocity_annealing_factor: 7.788016364580001
|
| 23 |
+
correct_center_of_mass_motion: true
|
| 24 |
+
# lattice vectors
|
| 25 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 26 |
+
init_args:
|
| 27 |
+
interpolant: omg.si.interpolants.LinearInterpolant
|
| 28 |
+
gamma:
|
| 29 |
+
class_path: omg.si.gamma.LatentGammaSqrt
|
| 30 |
+
init_args:
|
| 31 |
+
a: 0.8480953426626178
|
| 32 |
+
epsilon:
|
| 33 |
+
class_path: omg.si.epsilon.VanishingEpsilon
|
| 34 |
+
init_args:
|
| 35 |
+
c: 3.673041703474942
|
| 36 |
+
mu: 0.0706838941982046
|
| 37 |
+
sigma: 0.01812545867373373
|
| 38 |
+
differential_equation_type: "SDE"
|
| 39 |
+
integrator_kwargs:
|
| 40 |
+
method: "euler"
|
| 41 |
+
dt: 0.001469808747060597
|
| 42 |
+
velocity_annealing_factor: 0.2916081322471492
|
| 43 |
+
correct_center_of_mass_motion: false
|
| 44 |
+
data_fields:
|
| 45 |
+
# if the order of the data_fields changes,
|
| 46 |
+
# the order of the above StochasticInterpolant inputs must also change
|
| 47 |
+
- "species"
|
| 48 |
+
- "pos"
|
| 49 |
+
- "cell"
|
| 50 |
+
integration_time_steps: 680
|
| 51 |
+
relative_si_costs:
|
| 52 |
+
species_loss: 0.43055618267791895
|
| 53 |
+
pos_loss_b: 0.2322254093385872
|
| 54 |
+
cell_loss_b: 0.003464180396862092
|
| 55 |
+
cell_loss_z: 0.33375422758663176
|
| 56 |
+
sampler:
|
| 57 |
+
class_path: omg.sampler.sample_from_rng.SampleFromRNG
|
| 58 |
+
init_args:
|
| 59 |
+
pos_distribution: null
|
| 60 |
+
cell_distribution:
|
| 61 |
+
class_path: omg.sampler.distributions.InformedLatticeDistribution
|
| 62 |
+
init_args:
|
| 63 |
+
dataset_name: mp_20
|
| 64 |
+
species_distribution:
|
| 65 |
+
class_path: omg.sampler.distributions.MaskDistribution
|
| 66 |
+
model:
|
| 67 |
+
class_path: omg.model.model.Model
|
| 68 |
+
init_args:
|
| 69 |
+
encoder:
|
| 70 |
+
class_path: omg.model.encoders.cspnet_full.CSPNetFull
|
| 71 |
+
head:
|
| 72 |
+
class_path: omg.model.heads.pass_through.PassThrough
|
| 73 |
+
time_embedder:
|
| 74 |
+
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
|
| 75 |
+
init_args:
|
| 76 |
+
dim: 256
|
| 77 |
+
use_min_perm_dist: True
|
| 78 |
+
float_32_matmul_precision: "high"
|
| 79 |
+
validation_mode: "dng_eval"
|
| 80 |
+
dataset_name: "mp_20"
|
| 81 |
+
data:
|
| 82 |
+
train_dataset:
|
| 83 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 84 |
+
init_args:
|
| 85 |
+
dataset:
|
| 86 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 87 |
+
init_args:
|
| 88 |
+
lmdb_paths:
|
| 89 |
+
- "data/mp_20/train.lmdb"
|
| 90 |
+
niggli: True
|
| 91 |
+
val_dataset:
|
| 92 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 93 |
+
init_args:
|
| 94 |
+
dataset:
|
| 95 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 96 |
+
init_args:
|
| 97 |
+
lmdb_paths:
|
| 98 |
+
- "data/mp_20/val.lmdb"
|
| 99 |
+
niggli: True
|
| 100 |
+
predict_dataset:
|
| 101 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 102 |
+
init_args:
|
| 103 |
+
dataset:
|
| 104 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 105 |
+
init_args:
|
| 106 |
+
lmdb_paths:
|
| 107 |
+
- "data/mp_20/test.lmdb"
|
| 108 |
+
niggli: True
|
| 109 |
+
batch_size: 32
|
| 110 |
+
num_workers: 4
|
| 111 |
+
pin_memory: True
|
| 112 |
+
persistent_workers: True
|
| 113 |
+
trainer:
|
| 114 |
+
callbacks:
|
| 115 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 116 |
+
init_args:
|
| 117 |
+
filename: "best_val_loss_total"
|
| 118 |
+
save_top_k: 1
|
| 119 |
+
monitor: "val_loss_total"
|
| 120 |
+
save_weights_only: true
|
| 121 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 122 |
+
init_args:
|
| 123 |
+
filename: "best_val_dng_eval"
|
| 124 |
+
save_top_k: 1
|
| 125 |
+
monitor: "dng_eval"
|
| 126 |
+
save_weights_only: true
|
| 127 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 128 |
+
init_args:
|
| 129 |
+
filename: "best_val_wdist_density"
|
| 130 |
+
save_top_k: 1
|
| 131 |
+
monitor: "wdist_density"
|
| 132 |
+
save_weights_only: true
|
| 133 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 134 |
+
init_args:
|
| 135 |
+
filename: "best_val_wdist_Nary"
|
| 136 |
+
save_top_k: 1
|
| 137 |
+
monitor: "wdist_Nary"
|
| 138 |
+
save_weights_only: true
|
| 139 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 140 |
+
init_args:
|
| 141 |
+
filename: "best_val_wdist_CN"
|
| 142 |
+
save_top_k: 1
|
| 143 |
+
monitor: "wdist_CN"
|
| 144 |
+
save_weights_only: true
|
| 145 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 146 |
+
init_args:
|
| 147 |
+
filename: "best_val_cov_precision"
|
| 148 |
+
save_top_k: 1
|
| 149 |
+
monitor: "cov_precision"
|
| 150 |
+
mode: "max"
|
| 151 |
+
save_weights_only: true
|
| 152 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 153 |
+
init_args:
|
| 154 |
+
filename: "best_val_cov_recall"
|
| 155 |
+
save_top_k: 1
|
| 156 |
+
monitor: "cov_recall"
|
| 157 |
+
mode: "max"
|
| 158 |
+
save_weights_only: true
|
| 159 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 160 |
+
init_args:
|
| 161 |
+
filename: "best_val_validity"
|
| 162 |
+
save_top_k: 1
|
| 163 |
+
monitor: "validity_rate"
|
| 164 |
+
mode: "max"
|
| 165 |
+
save_weights_only: true
|
| 166 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 167 |
+
init_args:
|
| 168 |
+
save_top_k: -1 # Store every checkpoint after 100 epochs.
|
| 169 |
+
monitor: "val_loss_total"
|
| 170 |
+
every_n_epochs: 100
|
| 171 |
+
save_weights_only: false
|
| 172 |
+
gradient_clip_val: 0.5
|
| 173 |
+
gradient_clip_algorithm: "value"
|
| 174 |
+
num_sanity_val_steps: 0
|
| 175 |
+
precision: "32-true"
|
| 176 |
+
max_epochs: 2000
|
| 177 |
+
enable_progress_bar: false
|
| 178 |
+
check_val_every_n_epoch: 100
|
| 179 |
+
optimizer:
|
| 180 |
+
class_path: torch.optim.AdamW
|
| 181 |
+
init_args:
|
| 182 |
+
lr: 0.00014843344531647814
|
| 183 |
+
weight_decay: 0.00032033785912707614
|
| 184 |
+
lr_scheduler:
|
| 185 |
+
class_path: torch.optim.lr_scheduler.CosineAnnealingLR
|
| 186 |
+
init_args:
|
| 187 |
+
T_max: 2000
|
| 188 |
+
eta_min: 1e-07
|
| 189 |
+
|
Trig-ODE/checkpoint.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:da530d4e3b0e3ae3537da7ab517a3b2cfa754637289d72f7ff91f761a2a66557
|
| 3 |
+
size 148480896
|
Trig-ODE/train.yaml
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
si:
|
| 3 |
+
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
|
| 4 |
+
init_args:
|
| 5 |
+
stochastic_interpolants:
|
| 6 |
+
# chemical species
|
| 7 |
+
- class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
|
| 8 |
+
init_args:
|
| 9 |
+
noise: 32.687148090341246
|
| 10 |
+
# fractional coordinates
|
| 11 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 12 |
+
init_args:
|
| 13 |
+
interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
|
| 14 |
+
gamma: null
|
| 15 |
+
epsilon: null
|
| 16 |
+
differential_equation_type: "ODE"
|
| 17 |
+
integrator_kwargs:
|
| 18 |
+
method: "euler"
|
| 19 |
+
velocity_annealing_factor: 8.59355026270501
|
| 20 |
+
correct_center_of_mass_motion: true
|
| 21 |
+
# lattice vectors
|
| 22 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 23 |
+
init_args:
|
| 24 |
+
interpolant: omg.si.interpolants.TrigonometricInterpolant
|
| 25 |
+
gamma:
|
| 26 |
+
class_path: omg.si.gamma.LatentGammaSqrt
|
| 27 |
+
init_args:
|
| 28 |
+
a: 1.182707731733833
|
| 29 |
+
epsilon: null
|
| 30 |
+
differential_equation_type: "ODE"
|
| 31 |
+
integrator_kwargs:
|
| 32 |
+
method: "euler"
|
| 33 |
+
velocity_annealing_factor: 0.308921595643179
|
| 34 |
+
correct_center_of_mass_motion: false
|
| 35 |
+
data_fields:
|
| 36 |
+
# if the order of the data_fields changes,
|
| 37 |
+
# the order of the above StochasticInterpolant inputs must also change
|
| 38 |
+
- "species"
|
| 39 |
+
- "pos"
|
| 40 |
+
- "cell"
|
| 41 |
+
integration_time_steps: 860
|
| 42 |
+
relative_si_costs:
|
| 43 |
+
species_loss: 0.6674861836833045
|
| 44 |
+
pos_loss_b: 0.33016840408608256
|
| 45 |
+
cell_loss_b: 0.002345412230612938
|
| 46 |
+
sampler:
|
| 47 |
+
class_path: omg.sampler.sample_from_rng.SampleFromRNG
|
| 48 |
+
init_args:
|
| 49 |
+
pos_distribution: null
|
| 50 |
+
cell_distribution:
|
| 51 |
+
class_path: omg.sampler.distributions.InformedLatticeDistribution
|
| 52 |
+
init_args:
|
| 53 |
+
dataset_name: mp_20
|
| 54 |
+
species_distribution:
|
| 55 |
+
class_path: omg.sampler.distributions.MaskDistribution
|
| 56 |
+
model:
|
| 57 |
+
class_path: omg.model.model.Model
|
| 58 |
+
init_args:
|
| 59 |
+
encoder:
|
| 60 |
+
class_path: omg.model.encoders.cspnet_full.CSPNetFull
|
| 61 |
+
head:
|
| 62 |
+
class_path: omg.model.heads.pass_through.PassThrough
|
| 63 |
+
time_embedder:
|
| 64 |
+
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
|
| 65 |
+
init_args:
|
| 66 |
+
dim: 256
|
| 67 |
+
use_min_perm_dist: True
|
| 68 |
+
float_32_matmul_precision: "high"
|
| 69 |
+
validation_mode: "dng_eval"
|
| 70 |
+
dataset_name: "mp_20"
|
| 71 |
+
data:
|
| 72 |
+
train_dataset:
|
| 73 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 74 |
+
init_args:
|
| 75 |
+
dataset:
|
| 76 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 77 |
+
init_args:
|
| 78 |
+
lmdb_paths:
|
| 79 |
+
- "data/mp_20/train.lmdb"
|
| 80 |
+
niggli: False
|
| 81 |
+
val_dataset:
|
| 82 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 83 |
+
init_args:
|
| 84 |
+
dataset:
|
| 85 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 86 |
+
init_args:
|
| 87 |
+
lmdb_paths:
|
| 88 |
+
- "data/mp_20/val.lmdb"
|
| 89 |
+
niggli: False
|
| 90 |
+
predict_dataset:
|
| 91 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 92 |
+
init_args:
|
| 93 |
+
dataset:
|
| 94 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 95 |
+
init_args:
|
| 96 |
+
lmdb_paths:
|
| 97 |
+
- "data/mp_20/test.lmdb"
|
| 98 |
+
niggli: False
|
| 99 |
+
batch_size: 128
|
| 100 |
+
num_workers: 4
|
| 101 |
+
pin_memory: True
|
| 102 |
+
persistent_workers: True
|
| 103 |
+
trainer:
|
| 104 |
+
callbacks:
|
| 105 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 106 |
+
init_args:
|
| 107 |
+
filename: "best_val_loss_total"
|
| 108 |
+
save_top_k: 1
|
| 109 |
+
monitor: "val_loss_total"
|
| 110 |
+
save_weights_only: true
|
| 111 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 112 |
+
init_args:
|
| 113 |
+
filename: "best_val_dng_eval"
|
| 114 |
+
save_top_k: 1
|
| 115 |
+
monitor: "dng_eval"
|
| 116 |
+
save_weights_only: true
|
| 117 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 118 |
+
init_args:
|
| 119 |
+
filename: "best_val_wdist_density"
|
| 120 |
+
save_top_k: 1
|
| 121 |
+
monitor: "wdist_density"
|
| 122 |
+
save_weights_only: true
|
| 123 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 124 |
+
init_args:
|
| 125 |
+
filename: "best_val_wdist_Nary"
|
| 126 |
+
save_top_k: 1
|
| 127 |
+
monitor: "wdist_Nary"
|
| 128 |
+
save_weights_only: true
|
| 129 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 130 |
+
init_args:
|
| 131 |
+
filename: "best_val_wdist_CN"
|
| 132 |
+
save_top_k: 1
|
| 133 |
+
monitor: "wdist_CN"
|
| 134 |
+
save_weights_only: true
|
| 135 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 136 |
+
init_args:
|
| 137 |
+
filename: "best_val_cov_precision"
|
| 138 |
+
save_top_k: 1
|
| 139 |
+
monitor: "cov_precision"
|
| 140 |
+
mode: "max"
|
| 141 |
+
save_weights_only: true
|
| 142 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 143 |
+
init_args:
|
| 144 |
+
filename: "best_val_cov_recall"
|
| 145 |
+
save_top_k: 1
|
| 146 |
+
monitor: "cov_recall"
|
| 147 |
+
mode: "max"
|
| 148 |
+
save_weights_only: true
|
| 149 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 150 |
+
init_args:
|
| 151 |
+
filename: "best_val_validity"
|
| 152 |
+
save_top_k: 1
|
| 153 |
+
monitor: "validity_rate"
|
| 154 |
+
mode: "max"
|
| 155 |
+
save_weights_only: true
|
| 156 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 157 |
+
init_args:
|
| 158 |
+
save_top_k: -1 # Store every checkpoint after 100 epochs.
|
| 159 |
+
monitor: "val_loss_total"
|
| 160 |
+
every_n_epochs: 100
|
| 161 |
+
save_weights_only: false
|
| 162 |
+
gradient_clip_val: 0.5
|
| 163 |
+
gradient_clip_algorithm: "value"
|
| 164 |
+
num_sanity_val_steps: 0
|
| 165 |
+
precision: "32-true"
|
| 166 |
+
max_epochs: 2000
|
| 167 |
+
enable_progress_bar: false
|
| 168 |
+
check_val_every_n_epoch: 100
|
| 169 |
+
optimizer:
|
| 170 |
+
class_path: torch.optim.AdamW
|
| 171 |
+
init_args:
|
| 172 |
+
lr: 0.002704094492670699
|
| 173 |
+
weight_decay: 0.0006070253248675564
|
| 174 |
+
lr_scheduler:
|
| 175 |
+
class_path: torch.optim.lr_scheduler.CosineAnnealingLR
|
| 176 |
+
init_args:
|
| 177 |
+
T_max: 2000
|
| 178 |
+
eta_min: 1e-07
|
| 179 |
+
|
Trig-SDE-Gamma/checkpoint.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:af001ad47f40c2794e5115785c33bb3e26489cd4a4ff9a874af8ea41594d9f0c
|
| 3 |
+
size 148494458
|
Trig-SDE-Gamma/train.yaml
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
si:
|
| 3 |
+
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
|
| 4 |
+
init_args:
|
| 5 |
+
stochastic_interpolants:
|
| 6 |
+
# chemical species
|
| 7 |
+
- class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
|
| 8 |
+
init_args:
|
| 9 |
+
noise: 13.14607468893319
|
| 10 |
+
# fractional coordinates
|
| 11 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 12 |
+
init_args:
|
| 13 |
+
interpolant: omg.si.interpolants.PeriodicTrigonometricInterpolant
|
| 14 |
+
gamma:
|
| 15 |
+
class_path: omg.si.gamma.LatentGammaSqrt
|
| 16 |
+
init_args:
|
| 17 |
+
a: 0.022769980672795356
|
| 18 |
+
epsilon:
|
| 19 |
+
class_path: omg.si.epsilon.VanishingEpsilon
|
| 20 |
+
init_args:
|
| 21 |
+
c: 2.621699079870832
|
| 22 |
+
mu: 0.15417087293483117
|
| 23 |
+
sigma: 0.017962649662214652
|
| 24 |
+
differential_equation_type: "SDE"
|
| 25 |
+
integrator_kwargs:
|
| 26 |
+
method: "euler"
|
| 27 |
+
dt: 0.0013148878933861852
|
| 28 |
+
velocity_annealing_factor: 12.80156329264574
|
| 29 |
+
correct_center_of_mass_motion: true
|
| 30 |
+
# lattice vectors
|
| 31 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 32 |
+
init_args:
|
| 33 |
+
interpolant: omg.si.interpolants.TrigonometricInterpolant
|
| 34 |
+
gamma:
|
| 35 |
+
class_path: omg.si.gamma.LatentGammaSqrt
|
| 36 |
+
init_args:
|
| 37 |
+
a: 0.31566788838946225
|
| 38 |
+
epsilon: null
|
| 39 |
+
differential_equation_type: "ODE"
|
| 40 |
+
integrator_kwargs:
|
| 41 |
+
method: "euler"
|
| 42 |
+
velocity_annealing_factor: 4.36422932474701
|
| 43 |
+
correct_center_of_mass_motion: false
|
| 44 |
+
data_fields:
|
| 45 |
+
# if the order of the data_fields changes,
|
| 46 |
+
# the order of the above StochasticInterpolant inputs must also change
|
| 47 |
+
- "species"
|
| 48 |
+
- "pos"
|
| 49 |
+
- "cell"
|
| 50 |
+
integration_time_steps: 760
|
| 51 |
+
relative_si_costs:
|
| 52 |
+
species_loss: 0.13597807419582586
|
| 53 |
+
pos_loss_b: 0.6304347545056598
|
| 54 |
+
pos_loss_z: 0.0753230160674198
|
| 55 |
+
cell_loss_b: 0.15826415523109444
|
| 56 |
+
sampler:
|
| 57 |
+
class_path: omg.sampler.sample_from_rng.SampleFromRNG
|
| 58 |
+
init_args:
|
| 59 |
+
pos_distribution: null
|
| 60 |
+
cell_distribution:
|
| 61 |
+
class_path: omg.sampler.distributions.InformedLatticeDistribution
|
| 62 |
+
init_args:
|
| 63 |
+
dataset_name: mp_20
|
| 64 |
+
species_distribution:
|
| 65 |
+
class_path: omg.sampler.distributions.MaskDistribution
|
| 66 |
+
model:
|
| 67 |
+
class_path: omg.model.model.Model
|
| 68 |
+
init_args:
|
| 69 |
+
encoder:
|
| 70 |
+
class_path: omg.model.encoders.cspnet_full.CSPNetFull
|
| 71 |
+
head:
|
| 72 |
+
class_path: omg.model.heads.pass_through.PassThrough
|
| 73 |
+
time_embedder:
|
| 74 |
+
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
|
| 75 |
+
init_args:
|
| 76 |
+
dim: 256
|
| 77 |
+
use_min_perm_dist: True
|
| 78 |
+
float_32_matmul_precision: "high"
|
| 79 |
+
validation_mode: "dng_eval"
|
| 80 |
+
dataset_name: "mp_20"
|
| 81 |
+
data:
|
| 82 |
+
train_dataset:
|
| 83 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 84 |
+
init_args:
|
| 85 |
+
dataset:
|
| 86 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 87 |
+
init_args:
|
| 88 |
+
lmdb_paths:
|
| 89 |
+
- "data/mp_20/train.lmdb"
|
| 90 |
+
niggli: False
|
| 91 |
+
val_dataset:
|
| 92 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 93 |
+
init_args:
|
| 94 |
+
dataset:
|
| 95 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 96 |
+
init_args:
|
| 97 |
+
lmdb_paths:
|
| 98 |
+
- "data/mp_20/val.lmdb"
|
| 99 |
+
niggli: False
|
| 100 |
+
predict_dataset:
|
| 101 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 102 |
+
init_args:
|
| 103 |
+
dataset:
|
| 104 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 105 |
+
init_args:
|
| 106 |
+
lmdb_paths:
|
| 107 |
+
- "data/mp_20/test.lmdb"
|
| 108 |
+
niggli: False
|
| 109 |
+
batch_size: 256
|
| 110 |
+
num_workers: 4
|
| 111 |
+
pin_memory: True
|
| 112 |
+
persistent_workers: True
|
| 113 |
+
trainer:
|
| 114 |
+
callbacks:
|
| 115 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 116 |
+
init_args:
|
| 117 |
+
filename: "best_val_loss_total"
|
| 118 |
+
save_top_k: 1
|
| 119 |
+
monitor: "val_loss_total"
|
| 120 |
+
save_weights_only: true
|
| 121 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 122 |
+
init_args:
|
| 123 |
+
filename: "best_val_dng_eval"
|
| 124 |
+
save_top_k: 1
|
| 125 |
+
monitor: "dng_eval"
|
| 126 |
+
save_weights_only: true
|
| 127 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 128 |
+
init_args:
|
| 129 |
+
filename: "best_val_wdist_density"
|
| 130 |
+
save_top_k: 1
|
| 131 |
+
monitor: "wdist_density"
|
| 132 |
+
save_weights_only: true
|
| 133 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 134 |
+
init_args:
|
| 135 |
+
filename: "best_val_wdist_Nary"
|
| 136 |
+
save_top_k: 1
|
| 137 |
+
monitor: "wdist_Nary"
|
| 138 |
+
save_weights_only: true
|
| 139 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 140 |
+
init_args:
|
| 141 |
+
filename: "best_val_wdist_CN"
|
| 142 |
+
save_top_k: 1
|
| 143 |
+
monitor: "wdist_CN"
|
| 144 |
+
save_weights_only: true
|
| 145 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 146 |
+
init_args:
|
| 147 |
+
filename: "best_val_cov_precision"
|
| 148 |
+
save_top_k: 1
|
| 149 |
+
monitor: "cov_precision"
|
| 150 |
+
mode: "max"
|
| 151 |
+
save_weights_only: true
|
| 152 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 153 |
+
init_args:
|
| 154 |
+
filename: "best_val_cov_recall"
|
| 155 |
+
save_top_k: 1
|
| 156 |
+
monitor: "cov_recall"
|
| 157 |
+
mode: "max"
|
| 158 |
+
save_weights_only: true
|
| 159 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 160 |
+
init_args:
|
| 161 |
+
filename: "best_val_validity"
|
| 162 |
+
save_top_k: 1
|
| 163 |
+
monitor: "validity_rate"
|
| 164 |
+
mode: "max"
|
| 165 |
+
save_weights_only: true
|
| 166 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 167 |
+
init_args:
|
| 168 |
+
save_top_k: -1 # Store every checkpoint after 100 epochs.
|
| 169 |
+
monitor: "val_loss_total"
|
| 170 |
+
every_n_epochs: 100
|
| 171 |
+
save_weights_only: false
|
| 172 |
+
gradient_clip_val: 0.5
|
| 173 |
+
gradient_clip_algorithm: "value"
|
| 174 |
+
num_sanity_val_steps: 0
|
| 175 |
+
precision: "32-true"
|
| 176 |
+
max_epochs: 2000
|
| 177 |
+
enable_progress_bar: false
|
| 178 |
+
check_val_every_n_epoch: 100
|
| 179 |
+
optimizer:
|
| 180 |
+
class_path: torch.optim.AdamW
|
| 181 |
+
init_args:
|
| 182 |
+
lr: 0.0007969633652411341
|
| 183 |
+
weight_decay: 1.803908894626558e-05
|
| 184 |
+
lr_scheduler:
|
| 185 |
+
class_path: torch.optim.lr_scheduler.CosineAnnealingLR
|
| 186 |
+
init_args:
|
| 187 |
+
T_max: 2000
|
| 188 |
+
eta_min: 1e-07
|
| 189 |
+
|
VESBD-ODE/checkpoint.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa131d3a8b1c1168441d6164f3247d7f9b364b57f037b3c0c35dbb66297e2951
|
| 3 |
+
size 148519354
|
VESBD-ODE/train.yaml
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
si:
|
| 3 |
+
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
|
| 4 |
+
init_args:
|
| 5 |
+
stochastic_interpolants:
|
| 6 |
+
# chemical species
|
| 7 |
+
- class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
|
| 8 |
+
init_args:
|
| 9 |
+
noise: 5.870180115373019
|
| 10 |
+
# fractional coordinates
|
| 11 |
+
- class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
|
| 12 |
+
init_args:
|
| 13 |
+
interpolant:
|
| 14 |
+
class_path: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolantVE
|
| 15 |
+
init_args:
|
| 16 |
+
sigma:
|
| 17 |
+
class_path: omg.si.sigma.GeometricSigma
|
| 18 |
+
init_args:
|
| 19 |
+
sigma_min: 0.0020828565391521787
|
| 20 |
+
sigma_max: 0.8318656965968637
|
| 21 |
+
epsilon: null
|
| 22 |
+
differential_equation_type: "ODE"
|
| 23 |
+
integrator_kwargs:
|
| 24 |
+
method: "euler"
|
| 25 |
+
velocity_annealing_factor: 12.718434028622262
|
| 26 |
+
correct_center_of_mass_motion: true
|
| 27 |
+
predict_velocity: true
|
| 28 |
+
# lattice vectors
|
| 29 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 30 |
+
init_args:
|
| 31 |
+
interpolant: omg.si.interpolants.LinearInterpolant
|
| 32 |
+
gamma:
|
| 33 |
+
class_path: omg.si.gamma.LatentGammaSqrt
|
| 34 |
+
init_args:
|
| 35 |
+
a: 0.9128957436969677
|
| 36 |
+
epsilon:
|
| 37 |
+
class_path: omg.si.epsilon.VanishingEpsilon
|
| 38 |
+
init_args:
|
| 39 |
+
c: 5.79866710636582
|
| 40 |
+
mu: 0.2500005563766677
|
| 41 |
+
sigma: 0.020369240775387647
|
| 42 |
+
differential_equation_type: "SDE"
|
| 43 |
+
integrator_kwargs:
|
| 44 |
+
method: "euler"
|
| 45 |
+
dt: 0.0030334345065057278
|
| 46 |
+
velocity_annealing_factor: 0.9750213755072251
|
| 47 |
+
correct_center_of_mass_motion: false
|
| 48 |
+
data_fields:
|
| 49 |
+
# if the order of the data_fields changes,
|
| 50 |
+
# the order of the above StochasticInterpolant inputs must also change
|
| 51 |
+
- "species"
|
| 52 |
+
- "pos"
|
| 53 |
+
- "cell"
|
| 54 |
+
integration_time_steps: 330
|
| 55 |
+
relative_si_costs:
|
| 56 |
+
species_loss: 0.09898280162138558
|
| 57 |
+
pos_loss_b: 0.22090708494461975
|
| 58 |
+
cell_loss_b: 0.04296279641692716
|
| 59 |
+
cell_loss_z: 0.6371473170170674
|
| 60 |
+
sampler:
|
| 61 |
+
class_path: omg.sampler.sample_from_rng.SampleFromRNG
|
| 62 |
+
init_args:
|
| 63 |
+
pos_distribution:
|
| 64 |
+
class_path: omg.sampler.distributions.NormalDistribution
|
| 65 |
+
init_args:
|
| 66 |
+
scale: 0.45439306223842724
|
| 67 |
+
cell_distribution:
|
| 68 |
+
class_path: omg.sampler.distributions.InformedLatticeDistribution
|
| 69 |
+
init_args:
|
| 70 |
+
dataset_name: mp_20
|
| 71 |
+
species_distribution:
|
| 72 |
+
class_path: omg.sampler.distributions.MaskDistribution
|
| 73 |
+
model:
|
| 74 |
+
class_path: omg.model.model.Model
|
| 75 |
+
init_args:
|
| 76 |
+
encoder:
|
| 77 |
+
class_path: omg.model.encoders.cspnet_full.CSPNetFull
|
| 78 |
+
head:
|
| 79 |
+
class_path: omg.model.heads.pass_through.PassThrough
|
| 80 |
+
time_embedder:
|
| 81 |
+
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
|
| 82 |
+
init_args:
|
| 83 |
+
dim: 256
|
| 84 |
+
use_min_perm_dist: False
|
| 85 |
+
float_32_matmul_precision: "high"
|
| 86 |
+
validation_mode: "dng_eval"
|
| 87 |
+
dataset_name: "mp_20"
|
| 88 |
+
data:
|
| 89 |
+
train_dataset:
|
| 90 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 91 |
+
init_args:
|
| 92 |
+
dataset:
|
| 93 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 94 |
+
init_args:
|
| 95 |
+
lmdb_paths:
|
| 96 |
+
- "data/mp_20/train.lmdb"
|
| 97 |
+
niggli: True
|
| 98 |
+
val_dataset:
|
| 99 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 100 |
+
init_args:
|
| 101 |
+
dataset:
|
| 102 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 103 |
+
init_args:
|
| 104 |
+
lmdb_paths:
|
| 105 |
+
- "data/mp_20/val.lmdb"
|
| 106 |
+
niggli: True
|
| 107 |
+
predict_dataset:
|
| 108 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 109 |
+
init_args:
|
| 110 |
+
dataset:
|
| 111 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 112 |
+
init_args:
|
| 113 |
+
lmdb_paths:
|
| 114 |
+
- "data/mp_20/test.lmdb"
|
| 115 |
+
niggli: True
|
| 116 |
+
batch_size: 256
|
| 117 |
+
num_workers: 4
|
| 118 |
+
pin_memory: True
|
| 119 |
+
persistent_workers: True
|
| 120 |
+
trainer:
|
| 121 |
+
callbacks:
|
| 122 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 123 |
+
init_args:
|
| 124 |
+
filename: "best_val_loss_total"
|
| 125 |
+
save_top_k: 1
|
| 126 |
+
monitor: "val_loss_total"
|
| 127 |
+
save_weights_only: true
|
| 128 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 129 |
+
init_args:
|
| 130 |
+
filename: "best_val_dng_eval"
|
| 131 |
+
save_top_k: 1
|
| 132 |
+
monitor: "dng_eval"
|
| 133 |
+
save_weights_only: true
|
| 134 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 135 |
+
init_args:
|
| 136 |
+
filename: "best_val_wdist_density"
|
| 137 |
+
save_top_k: 1
|
| 138 |
+
monitor: "wdist_density"
|
| 139 |
+
save_weights_only: true
|
| 140 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 141 |
+
init_args:
|
| 142 |
+
filename: "best_val_wdist_Nary"
|
| 143 |
+
save_top_k: 1
|
| 144 |
+
monitor: "wdist_Nary"
|
| 145 |
+
save_weights_only: true
|
| 146 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 147 |
+
init_args:
|
| 148 |
+
filename: "best_val_wdist_CN"
|
| 149 |
+
save_top_k: 1
|
| 150 |
+
monitor: "wdist_CN"
|
| 151 |
+
save_weights_only: true
|
| 152 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 153 |
+
init_args:
|
| 154 |
+
filename: "best_val_cov_precision"
|
| 155 |
+
save_top_k: 1
|
| 156 |
+
monitor: "cov_precision"
|
| 157 |
+
mode: "max"
|
| 158 |
+
save_weights_only: true
|
| 159 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 160 |
+
init_args:
|
| 161 |
+
filename: "best_val_cov_recall"
|
| 162 |
+
save_top_k: 1
|
| 163 |
+
monitor: "cov_recall"
|
| 164 |
+
mode: "max"
|
| 165 |
+
save_weights_only: true
|
| 166 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 167 |
+
init_args:
|
| 168 |
+
filename: "best_val_validity"
|
| 169 |
+
save_top_k: 1
|
| 170 |
+
monitor: "validity_rate"
|
| 171 |
+
mode: "max"
|
| 172 |
+
save_weights_only: true
|
| 173 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 174 |
+
init_args:
|
| 175 |
+
save_top_k: -1 # Store every checkpoint after 100 epochs.
|
| 176 |
+
monitor: "val_loss_total"
|
| 177 |
+
every_n_epochs: 100
|
| 178 |
+
save_weights_only: false
|
| 179 |
+
gradient_clip_val: 0.5
|
| 180 |
+
gradient_clip_algorithm: "value"
|
| 181 |
+
num_sanity_val_steps: 0
|
| 182 |
+
precision: "32-true"
|
| 183 |
+
max_epochs: 2000
|
| 184 |
+
enable_progress_bar: false
|
| 185 |
+
check_val_every_n_epoch: 100
|
| 186 |
+
optimizer:
|
| 187 |
+
class_path: torch.optim.AdamW
|
| 188 |
+
init_args:
|
| 189 |
+
lr: 0.0018098696508625563
|
| 190 |
+
weight_decay: 0.00026498129464991104
|
| 191 |
+
lr_scheduler:
|
| 192 |
+
class_path: torch.optim.lr_scheduler.CosineAnnealingLR
|
| 193 |
+
init_args:
|
| 194 |
+
T_max: 2000
|
| 195 |
+
eta_min: 1e-07
|
| 196 |
+
|
VPSBD-ODE/checkpoint.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae9cc89dea91443d4f903d7d2e9c18a2f2828e5f1318f4f75a4975cc708f5fd4
|
| 3 |
+
size 148481024
|
VPSBD-ODE/train.yaml
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
si:
|
| 3 |
+
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
|
| 4 |
+
init_args:
|
| 5 |
+
stochastic_interpolants:
|
| 6 |
+
# chemical species
|
| 7 |
+
- class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
|
| 8 |
+
init_args:
|
| 9 |
+
noise: 20.267191359937392
|
| 10 |
+
# fractional coordinates
|
| 11 |
+
- class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
|
| 12 |
+
init_args:
|
| 13 |
+
interpolant: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolant
|
| 14 |
+
epsilon: null
|
| 15 |
+
differential_equation_type: "ODE"
|
| 16 |
+
integrator_kwargs:
|
| 17 |
+
method: "euler"
|
| 18 |
+
velocity_annealing_factor: 2.301841820941901
|
| 19 |
+
correct_center_of_mass_motion: true
|
| 20 |
+
predict_velocity: true
|
| 21 |
+
# lattice vectors
|
| 22 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 23 |
+
init_args:
|
| 24 |
+
interpolant: omg.si.interpolants.TrigonometricInterpolant
|
| 25 |
+
gamma:
|
| 26 |
+
class_path: omg.si.gamma.LatentGammaSqrt
|
| 27 |
+
init_args:
|
| 28 |
+
a: 7.796692096273471
|
| 29 |
+
epsilon: null
|
| 30 |
+
differential_equation_type: "ODE"
|
| 31 |
+
integrator_kwargs:
|
| 32 |
+
method: "euler"
|
| 33 |
+
velocity_annealing_factor: 2.741014121366449
|
| 34 |
+
correct_center_of_mass_motion: false
|
| 35 |
+
data_fields:
|
| 36 |
+
# if the order of the data_fields changes,
|
| 37 |
+
# the order of the above StochasticInterpolant inputs must also change
|
| 38 |
+
- "species"
|
| 39 |
+
- "pos"
|
| 40 |
+
- "cell"
|
| 41 |
+
integration_time_steps: 710
|
| 42 |
+
relative_si_costs:
|
| 43 |
+
species_loss: 0.5499726353395065
|
| 44 |
+
pos_loss_b: 0.40529122146887925
|
| 45 |
+
cell_loss_b: 0.04473614319161423
|
| 46 |
+
sampler:
|
| 47 |
+
class_path: omg.sampler.sample_from_rng.SampleFromRNG
|
| 48 |
+
init_args:
|
| 49 |
+
pos_distribution:
|
| 50 |
+
class_path: omg.sampler.distributions.NormalDistribution
|
| 51 |
+
init_args:
|
| 52 |
+
scale: 0.23441087988918383
|
| 53 |
+
cell_distribution:
|
| 54 |
+
class_path: omg.sampler.distributions.InformedLatticeDistribution
|
| 55 |
+
init_args:
|
| 56 |
+
dataset_name: mp_20
|
| 57 |
+
species_distribution:
|
| 58 |
+
class_path: omg.sampler.distributions.MaskDistribution
|
| 59 |
+
model:
|
| 60 |
+
class_path: omg.model.model.Model
|
| 61 |
+
init_args:
|
| 62 |
+
encoder:
|
| 63 |
+
class_path: omg.model.encoders.cspnet_full.CSPNetFull
|
| 64 |
+
head:
|
| 65 |
+
class_path: omg.model.heads.pass_through.PassThrough
|
| 66 |
+
time_embedder:
|
| 67 |
+
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
|
| 68 |
+
init_args:
|
| 69 |
+
dim: 256
|
| 70 |
+
use_min_perm_dist: False
|
| 71 |
+
float_32_matmul_precision: "high"
|
| 72 |
+
validation_mode: "dng_eval"
|
| 73 |
+
dataset_name: "mp_20"
|
| 74 |
+
data:
|
| 75 |
+
train_dataset:
|
| 76 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 77 |
+
init_args:
|
| 78 |
+
dataset:
|
| 79 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 80 |
+
init_args:
|
| 81 |
+
lmdb_paths:
|
| 82 |
+
- "data/mp_20/train.lmdb"
|
| 83 |
+
niggli: False
|
| 84 |
+
val_dataset:
|
| 85 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 86 |
+
init_args:
|
| 87 |
+
dataset:
|
| 88 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 89 |
+
init_args:
|
| 90 |
+
lmdb_paths:
|
| 91 |
+
- "data/mp_20/val.lmdb"
|
| 92 |
+
niggli: False
|
| 93 |
+
predict_dataset:
|
| 94 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 95 |
+
init_args:
|
| 96 |
+
dataset:
|
| 97 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 98 |
+
init_args:
|
| 99 |
+
lmdb_paths:
|
| 100 |
+
- "data/mp_20/test.lmdb"
|
| 101 |
+
niggli: False
|
| 102 |
+
batch_size: 256
|
| 103 |
+
num_workers: 4
|
| 104 |
+
pin_memory: True
|
| 105 |
+
persistent_workers: True
|
| 106 |
+
trainer:
|
| 107 |
+
callbacks:
|
| 108 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 109 |
+
init_args:
|
| 110 |
+
filename: "best_val_loss_total"
|
| 111 |
+
save_top_k: 1
|
| 112 |
+
monitor: "val_loss_total"
|
| 113 |
+
save_weights_only: true
|
| 114 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 115 |
+
init_args:
|
| 116 |
+
filename: "best_val_dng_eval"
|
| 117 |
+
save_top_k: 1
|
| 118 |
+
monitor: "dng_eval"
|
| 119 |
+
save_weights_only: true
|
| 120 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 121 |
+
init_args:
|
| 122 |
+
filename: "best_val_wdist_density"
|
| 123 |
+
save_top_k: 1
|
| 124 |
+
monitor: "wdist_density"
|
| 125 |
+
save_weights_only: true
|
| 126 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 127 |
+
init_args:
|
| 128 |
+
filename: "best_val_wdist_Nary"
|
| 129 |
+
save_top_k: 1
|
| 130 |
+
monitor: "wdist_Nary"
|
| 131 |
+
save_weights_only: true
|
| 132 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 133 |
+
init_args:
|
| 134 |
+
filename: "best_val_wdist_CN"
|
| 135 |
+
save_top_k: 1
|
| 136 |
+
monitor: "wdist_CN"
|
| 137 |
+
save_weights_only: true
|
| 138 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 139 |
+
init_args:
|
| 140 |
+
filename: "best_val_cov_precision"
|
| 141 |
+
save_top_k: 1
|
| 142 |
+
monitor: "cov_precision"
|
| 143 |
+
mode: "max"
|
| 144 |
+
save_weights_only: true
|
| 145 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 146 |
+
init_args:
|
| 147 |
+
filename: "best_val_cov_recall"
|
| 148 |
+
save_top_k: 1
|
| 149 |
+
monitor: "cov_recall"
|
| 150 |
+
mode: "max"
|
| 151 |
+
save_weights_only: true
|
| 152 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 153 |
+
init_args:
|
| 154 |
+
filename: "best_val_validity"
|
| 155 |
+
save_top_k: 1
|
| 156 |
+
monitor: "validity_rate"
|
| 157 |
+
mode: "max"
|
| 158 |
+
save_weights_only: true
|
| 159 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 160 |
+
init_args:
|
| 161 |
+
save_top_k: -1 # Store every checkpoint after 100 epochs.
|
| 162 |
+
monitor: "val_loss_total"
|
| 163 |
+
every_n_epochs: 100
|
| 164 |
+
save_weights_only: false
|
| 165 |
+
gradient_clip_val: 0.5
|
| 166 |
+
gradient_clip_algorithm: "value"
|
| 167 |
+
num_sanity_val_steps: 0
|
| 168 |
+
precision: "32-true"
|
| 169 |
+
max_epochs: 2000
|
| 170 |
+
enable_progress_bar: false
|
| 171 |
+
check_val_every_n_epoch: 100
|
| 172 |
+
optimizer:
|
| 173 |
+
class_path: torch.optim.AdamW
|
| 174 |
+
init_args:
|
| 175 |
+
lr: 0.00797735754708741
|
| 176 |
+
weight_decay: 1.923837446196394e-05
|
| 177 |
+
lr_scheduler:
|
| 178 |
+
class_path: torch.optim.lr_scheduler.CosineAnnealingLR
|
| 179 |
+
init_args:
|
| 180 |
+
T_max: 2000
|
| 181 |
+
eta_min: 1e-07
|
| 182 |
+
|
VPSBD-SDE/checkpoint.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0bc99defffe2308a601267ab7ea233948ef3bfc033ff1b6e879525a50a0cda55
|
| 3 |
+
size 148532340
|
VPSBD-SDE/train.yaml
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
si:
|
| 3 |
+
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
|
| 4 |
+
init_args:
|
| 5 |
+
stochastic_interpolants:
|
| 6 |
+
# chemical species
|
| 7 |
+
- class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
|
| 8 |
+
init_args:
|
| 9 |
+
noise: 8.517811450005286
|
| 10 |
+
# fractional coordinates
|
| 11 |
+
- class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
|
| 12 |
+
init_args:
|
| 13 |
+
interpolant: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolant
|
| 14 |
+
epsilon:
|
| 15 |
+
class_path: omg.si.epsilon.VanishingEpsilon
|
| 16 |
+
init_args:
|
| 17 |
+
c: 9.268934476283913
|
| 18 |
+
mu: 0.25243331190144214
|
| 19 |
+
sigma: 0.04584669320169394
|
| 20 |
+
differential_equation_type: "SDE"
|
| 21 |
+
integrator_kwargs:
|
| 22 |
+
method: "euler"
|
| 23 |
+
dt: 0.00114844657946378
|
| 24 |
+
velocity_annealing_factor: 9.059507260865466
|
| 25 |
+
correct_center_of_mass_motion: true
|
| 26 |
+
predict_velocity: true
|
| 27 |
+
# lattice vectors
|
| 28 |
+
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
|
| 29 |
+
init_args:
|
| 30 |
+
interpolant: omg.si.interpolants.TrigonometricInterpolant
|
| 31 |
+
gamma:
|
| 32 |
+
class_path: omg.si.gamma.LatentGammaSqrt
|
| 33 |
+
init_args:
|
| 34 |
+
a: 3.0998874989979193
|
| 35 |
+
epsilon:
|
| 36 |
+
class_path: omg.si.epsilon.VanishingEpsilon
|
| 37 |
+
init_args:
|
| 38 |
+
c: 4.885221687138149
|
| 39 |
+
mu: 0.08252713485380268
|
| 40 |
+
sigma: 0.010096412051586533
|
| 41 |
+
differential_equation_type: "SDE"
|
| 42 |
+
integrator_kwargs:
|
| 43 |
+
method: "euler"
|
| 44 |
+
dt: 0.00114844657946378
|
| 45 |
+
velocity_annealing_factor: 11.76695215049249
|
| 46 |
+
correct_center_of_mass_motion: false
|
| 47 |
+
data_fields:
|
| 48 |
+
# if the order of the data_fields changes,
|
| 49 |
+
# the order of the above StochasticInterpolant inputs must also change
|
| 50 |
+
- "species"
|
| 51 |
+
- "pos"
|
| 52 |
+
- "cell"
|
| 53 |
+
integration_time_steps: 870
|
| 54 |
+
relative_si_costs:
|
| 55 |
+
species_loss: 0.35838215846575894
|
| 56 |
+
pos_loss_b: 0.5183735506925028
|
| 57 |
+
pos_loss_z: 0.0007900924652522647
|
| 58 |
+
cell_loss_b: 0.0044136759736567365
|
| 59 |
+
cell_loss_z: 0.11804052240282935
|
| 60 |
+
sampler:
|
| 61 |
+
class_path: omg.sampler.sample_from_rng.SampleFromRNG
|
| 62 |
+
init_args:
|
| 63 |
+
pos_distribution:
|
| 64 |
+
class_path: omg.sampler.distributions.NormalDistribution
|
| 65 |
+
init_args:
|
| 66 |
+
scale: 7.139671140709246
|
| 67 |
+
cell_distribution:
|
| 68 |
+
class_path: omg.sampler.distributions.InformedLatticeDistribution
|
| 69 |
+
init_args:
|
| 70 |
+
dataset_name: mp_20
|
| 71 |
+
species_distribution:
|
| 72 |
+
class_path: omg.sampler.distributions.MaskDistribution
|
| 73 |
+
model:
|
| 74 |
+
class_path: omg.model.model.Model
|
| 75 |
+
init_args:
|
| 76 |
+
encoder:
|
| 77 |
+
class_path: omg.model.encoders.cspnet_full.CSPNetFull
|
| 78 |
+
head:
|
| 79 |
+
class_path: omg.model.heads.pass_through.PassThrough
|
| 80 |
+
time_embedder:
|
| 81 |
+
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
|
| 82 |
+
init_args:
|
| 83 |
+
dim: 256
|
| 84 |
+
use_min_perm_dist: False
|
| 85 |
+
float_32_matmul_precision: "high"
|
| 86 |
+
validation_mode: "dng_eval"
|
| 87 |
+
dataset_name: "mp_20"
|
| 88 |
+
data:
|
| 89 |
+
train_dataset:
|
| 90 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 91 |
+
init_args:
|
| 92 |
+
dataset:
|
| 93 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 94 |
+
init_args:
|
| 95 |
+
lmdb_paths:
|
| 96 |
+
- "data/mp_20/train.lmdb"
|
| 97 |
+
niggli: False
|
| 98 |
+
val_dataset:
|
| 99 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 100 |
+
init_args:
|
| 101 |
+
dataset:
|
| 102 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 103 |
+
init_args:
|
| 104 |
+
lmdb_paths:
|
| 105 |
+
- "data/mp_20/val.lmdb"
|
| 106 |
+
niggli: False
|
| 107 |
+
predict_dataset:
|
| 108 |
+
class_path: omg.datamodule.dataloader.OMGTorchDataset
|
| 109 |
+
init_args:
|
| 110 |
+
dataset:
|
| 111 |
+
class_path: omg.datamodule.datamodule.DataModule
|
| 112 |
+
init_args:
|
| 113 |
+
lmdb_paths:
|
| 114 |
+
- "data/mp_20/test.lmdb"
|
| 115 |
+
niggli: False
|
| 116 |
+
batch_size: 512
|
| 117 |
+
num_workers: 4
|
| 118 |
+
pin_memory: True
|
| 119 |
+
persistent_workers: True
|
| 120 |
+
trainer:
|
| 121 |
+
callbacks:
|
| 122 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 123 |
+
init_args:
|
| 124 |
+
filename: "best_val_loss_total"
|
| 125 |
+
save_top_k: 1
|
| 126 |
+
monitor: "val_loss_total"
|
| 127 |
+
save_weights_only: true
|
| 128 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 129 |
+
init_args:
|
| 130 |
+
filename: "best_val_dng_eval"
|
| 131 |
+
save_top_k: 1
|
| 132 |
+
monitor: "dng_eval"
|
| 133 |
+
save_weights_only: true
|
| 134 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 135 |
+
init_args:
|
| 136 |
+
filename: "best_val_wdist_density"
|
| 137 |
+
save_top_k: 1
|
| 138 |
+
monitor: "wdist_density"
|
| 139 |
+
save_weights_only: true
|
| 140 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 141 |
+
init_args:
|
| 142 |
+
filename: "best_val_wdist_Nary"
|
| 143 |
+
save_top_k: 1
|
| 144 |
+
monitor: "wdist_Nary"
|
| 145 |
+
save_weights_only: true
|
| 146 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 147 |
+
init_args:
|
| 148 |
+
filename: "best_val_wdist_CN"
|
| 149 |
+
save_top_k: 1
|
| 150 |
+
monitor: "wdist_CN"
|
| 151 |
+
save_weights_only: true
|
| 152 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 153 |
+
init_args:
|
| 154 |
+
filename: "best_val_cov_precision"
|
| 155 |
+
save_top_k: 1
|
| 156 |
+
monitor: "cov_precision"
|
| 157 |
+
mode: "max"
|
| 158 |
+
save_weights_only: true
|
| 159 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 160 |
+
init_args:
|
| 161 |
+
filename: "best_val_cov_recall"
|
| 162 |
+
save_top_k: 1
|
| 163 |
+
monitor: "cov_recall"
|
| 164 |
+
mode: "max"
|
| 165 |
+
save_weights_only: true
|
| 166 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 167 |
+
init_args:
|
| 168 |
+
filename: "best_val_validity"
|
| 169 |
+
save_top_k: 1
|
| 170 |
+
monitor: "validity_rate"
|
| 171 |
+
mode: "max"
|
| 172 |
+
save_weights_only: true
|
| 173 |
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
| 174 |
+
init_args:
|
| 175 |
+
save_top_k: -1 # Store every checkpoint after 100 epochs.
|
| 176 |
+
monitor: "val_loss_total"
|
| 177 |
+
every_n_epochs: 100
|
| 178 |
+
save_weights_only: false
|
| 179 |
+
gradient_clip_val: 0.5
|
| 180 |
+
gradient_clip_algorithm: "value"
|
| 181 |
+
num_sanity_val_steps: 0
|
| 182 |
+
precision: "32-true"
|
| 183 |
+
max_epochs: 2000
|
| 184 |
+
enable_progress_bar: false
|
| 185 |
+
check_val_every_n_epoch: 100
|
| 186 |
+
optimizer:
|
| 187 |
+
class_path: torch.optim.AdamW
|
| 188 |
+
init_args:
|
| 189 |
+
lr: 0.0014461332672089323
|
| 190 |
+
weight_decay: 0.0007097046414614019
|
| 191 |
+
lr_scheduler:
|
| 192 |
+
class_path: torch.optim.lr_scheduler.CosineAnnealingLR
|
| 193 |
+
init_args:
|
| 194 |
+
T_max: 2000
|
| 195 |
+
eta_min: 1e-07
|
| 196 |
+
|