---
# Configuration of the model section (stochastic interpolants, base-distribution
# sampler, network, and validation settings).
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation had been stripped; confirm the placement of every init_args key
# against the corresponding omg class signature.
model:
  si:
    class_path: omg.si.stochastic_interpolants.StochasticInterpolants
    init_args:
      # One entry per data field, in the same order as `data_fields` below.
      stochastic_interpolants:
        # chemical species
        - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
        # fractional coordinates
        - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
          init_args:
            interpolant:
              class_path: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolantVP
              init_args:
                # NOTE(review): `tau` is assumed to be the only argument of
                # the interpolant itself; `epsilon` and the keys after it are
                # placed on the enclosing class, mirroring the lattice entry.
                tau: omg.si.tau.TauConstantSchedule
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 6.613808424917352
            correct_center_of_mass_motion: true
            predict_velocity: true
        # lattice vectors
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.interpolants.LinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 2.447993013544224
            correct_center_of_mass_motion: false
      data_fields:
        # If the order of the data_fields changes, the order of the
        # stochastic_interpolants entries above must change with it.
        - "species"
        - "pos"
        - "cell"
      integration_time_steps: 890
  # Relative cost of each loss term; species_loss is switched off (0.0) and
  # the two remaining values sum to 1 up to floating-point precision.
  relative_si_costs:
    species_loss: 0.0
    pos_loss_b: 0.9597565150933746
    cell_loss_b: 0.04024348490662539
  sampler:
    class_path: omg.sampler.IndependentSampler
    init_args:
      pos_distribution:
        class_path: omg.sampler.position_distributions.NormalPositionDistribution
        init_args:
          scale: 0.22006712732536396
      cell_distribution:
        class_path: omg.sampler.cell_distributions.InformedLatticeDistribution
        init_args:
          # Same dataset name as `model.dataset_name` further down;
          # presumably the two are meant to stay in sync.
          dataset_name: "mpts_52"
      species_distribution:
        class_path: omg.sampler.species_distributions.MirrorSpecies
  model:
    class_path: omg.model.model.Model
    init_args:
      encoder:
        class_path: omg.model.encoders.cspnet_full.CSPNetFull
      head:
        class_path: omg.model.heads.pass_through.PassThrough
      time_embedder:
        class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
        init_args:
          dim: 256
  use_min_perm_dist: true
  float_32_matmul_precision: "high"
  # The trainer's checkpoint callbacks monitor "match_rate" and "mean_rmsd";
  # presumably this mode is what produces those metrics -- TODO confirm.
  validation_mode: "match_rate"
  number_cpus: 7
  dataset_name: "mpts_52"
# Data module configuration: the three dataset splits read the train/val/test
# LMDB files under data/mpts_52 and use identical loading options.
data:
  train_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mpts_52/train.lmdb"
      lazy_storage: true
      niggli_reduce: false
  val_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mpts_52/val.lmdb"
      lazy_storage: true
      niggli_reduce: false
  # The prediction split points at the test set.
  pred_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mpts_52/test.lmdb"
      lazy_storage: true
      niggli_reduce: false
  # Loader options -- presumably forwarded to the torch DataLoader; verify
  # against the data module's constructor.
  batch_size: 64
  num_workers: 4
  pin_memory: true
  persistent_workers: true
# Trainer configuration: four checkpoint callbacks followed by the
# trainer-level options.
trainer:
  callbacks:
    # Best checkpoint by total validation loss (no `mode` given, so the
    # callback's default applies -- "min" per the Lightning documentation).
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_loss_total"
        save_top_k: 1
        monitor: "val_loss_total"
        save_weights_only: true
    # Best checkpoint by match rate; higher is better, hence mode 'max'.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_match_rate"
        save_top_k: 1
        monitor: "match_rate"
        save_weights_only: true
        mode: 'max'
    # Best checkpoint by mean RMSD (default mode, i.e. lowest value wins).
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_rmsd"
        save_top_k: 1
        monitor: "mean_rmsd"
        save_weights_only: true
    # Periodic checkpoints: save_top_k -1 keeps all of them, written every
    # 100 epochs, and not restricted to weights only.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: -1  # Store every checkpoint after 100 epochs.
        monitor: "val_loss_total"
        every_n_epochs: 100
        save_weights_only: false
  gradient_clip_val: 0.5
  num_sanity_val_steps: 0
  precision: "32-true"
  max_epochs: 2000
  enable_progress_bar: true
  # A float here is a fraction: half of the validation batches are used.
  limit_val_batches: 0.5
  # Same period as `every_n_epochs` of the periodic checkpoint above.
  check_val_every_n_epoch: 100
# Optimizer: Adam with an explicit learning rate; all other Adam arguments
# are left at their defaults.
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 2.519765029616902e-05