File size: 4,504 Bytes
model:
  si:
    class_path: omg.si.stochastic_interpolants.StochasticInterpolants
    init_args:
      stochastic_interpolants:
        # chemical species
        - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
        # fractional coordinates
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.interpolants.PeriodicLinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 0.004755207270677389
            correct_center_of_mass_motion: true
        # lattice vectors
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant:
              class_path: omg.si.interpolants.EncoderDecoderInterpolant
              init_args:
                switch_time: 0.46351945271978645
                power: 1.0
            # NOTE(review): switch_time/power repeat the interpolant's values
            # above; presumably the two must agree — confirm before changing one.
            gamma:
              class_path: omg.si.gamma.LatentGammaEncoderDecoder
              init_args:
                a: 0.8167071952445664
                switch_time: 0.46351945271978645
                power: 1.0
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 13.921408921615031
            correct_center_of_mass_motion: false
      data_fields:
        # Entries correspond positionally to stochastic_interpolants above:
        # if the order of data_fields changes, that list must be reordered too.
        - "species"
        - "pos"
        - "cell"
      integration_time_steps: 480
  # Per-field loss weights; species is weighted 0.0 and the pos/cell
  # weights sum to 1.0.
  relative_si_costs:
    species_loss: 0.0
    pos_loss_b: 0.9860929911452281
    cell_loss_b: 0.01390700885477196
  sampler:
    class_path: omg.sampler.IndependentSampler
    init_args:
      pos_distribution:
        class_path: omg.sampler.position_distributions.UniformPositionDistribution
      cell_distribution:
        class_path: omg.sampler.cell_distributions.InformedLatticeDistribution
        init_args:
          dataset_name: "perov_5"
      species_distribution:
        class_path: omg.sampler.species_distributions.MirrorSpecies
  model:
    class_path: omg.model.model.Model
    init_args:
      encoder:
        class_path: omg.model.encoders.cspnet_full.CSPNetFull
      head:
        class_path: omg.model.heads.pass_through.PassThrough
      time_embedder:
        class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
        init_args:
          dim: 256
  use_min_perm_dist: true
  float_32_matmul_precision: "high"
  validation_mode: "match_rate"
  dataset_name: "perov_5"
data:
  train_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/perov_5/train.lmdb"
      lazy_storage: true
      niggli_reduce: true
  val_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/perov_5/val.lmdb"
      lazy_storage: true
      niggli_reduce: true
  # The prediction dataset is read from the test split.
  pred_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/perov_5/test.lmdb"
      lazy_storage: true
      niggli_reduce: true
  batch_size: 1024
  num_workers: 4
  pin_memory: true
  persistent_workers: true
trainer:
  callbacks:
    # Keep the single checkpoint with the lowest val_loss_total
    # (ModelCheckpoint's mode defaults to "min").
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_loss_total"
        save_top_k: 1
        monitor: "val_loss_total"
        save_weights_only: true
    # Keep the single checkpoint with the highest match_rate.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_match_rate"
        save_top_k: 1
        monitor: "match_rate"
        save_weights_only: true
        mode: "max"
    # Keep the single checkpoint with the lowest mean_rmsd (default mode "min").
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_rmsd"
        save_top_k: 1
        monitor: "mean_rmsd"
        save_weights_only: true
    # Write a checkpoint every 100 epochs and keep all of them (save_top_k: -1).
    # Full state is saved (not weights only), presumably so training can be
    # resumed from any of them.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: -1
        monitor: "val_loss_total"
        every_n_epochs: 100
        save_weights_only: false
  gradient_clip_val: 0.5
  num_sanity_val_steps: 0
  precision: "32-true"
  max_epochs: 6000
  enable_progress_bar: false
  # Validation, and therefore the metrics monitored above, runs every 100 epochs.
  check_val_every_n_epoch: 100
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001147361965964576
|