# Philipp Hoellmer
# Use lazy storage
# 6af2cc9
model:
  si:
    class_path: omg.si.stochastic_interpolants.StochasticInterpolants
    init_args:
      # One interpolant per entry of data_fields below, in the same order.
      stochastic_interpolants:
        # chemical species
        - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
          init_args:
            noise: 23.870491382634235
        # fractional coordinates
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.interpolants.PeriodicLinearInterpolant
            gamma:
              class_path: omg.si.gamma.LatentGammaSqrt
              init_args:
                a: 1.4501684803942854
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 14.825022083056373
            correct_center_of_mass_motion: true
        # lattice vectors
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant:
              class_path: omg.si.interpolants.EncoderDecoderInterpolant
              init_args:
                switch_time: 0.13766881993242777
                power: 1.0
            # NOTE(review): switch_time and power repeat the interpolant's values
            # above; presumably the two must stay in sync — confirm.
            gamma:
              class_path: omg.si.gamma.LatentGammaEncoderDecoder
              init_args:
                a: 7.882046441638109
                switch_time: 0.13766881993242777
                power: 1.0
            epsilon:
              class_path: omg.si.epsilon.VanishingEpsilon
              init_args:
                c: 5.487699104615115
                mu: 0.2899409657474152
                sigma: 0.010062500495585096
            differential_equation_type: "SDE"
            integrator_kwargs:
              method: "euler"
              dt: 0.007736434228718281
            velocity_annealing_factor: 5.9072140831863305
            correct_center_of_mass_motion: false
      # The order of data_fields must match the order of the
      # stochastic_interpolants list above; if one changes, change the other.
      data_fields:
        - "species"
        - "pos"
        - "cell"
      integration_time_steps: 130
  # Relative weights of the individual losses; the four values sum to 1.
  relative_si_costs:
    species_loss: 0.22163297905676768
    pos_loss_b: 0.7682858351574293
    cell_loss_b: 0.008860420349864338
    cell_loss_z: 0.001220765435938645
  sampler:
    class_path: omg.sampler.IndependentSampler
    init_args:
      pos_distribution:
        class_path: omg.sampler.position_distributions.UniformPositionDistribution
      cell_distribution:
        class_path: omg.sampler.cell_distributions.InformedLatticeDistribution
        init_args:
          # NOTE(review): presumably should match model.dataset_name below.
          dataset_name: "mp_20"
      species_distribution:
        class_path: omg.sampler.species_distributions.MaskSpeciesDistribution
  model:
    class_path: omg.model.model.Model
    init_args:
      encoder:
        class_path: omg.model.encoders.cspnet_full.CSPNetFull
      head:
        class_path: omg.model.heads.pass_through.PassThrough
      time_embedder:
        class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
        init_args:
          dim: 256
  use_min_perm_dist: true
  float_32_matmul_precision: "high"
  validation_mode: "dng_eval"
  dataset_name: "mp_20"
data:
  # All splits are read from LMDB files with lazy storage enabled and
  # Niggli reduction disabled.
  train_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mp_20/train.lmdb"
      lazy_storage: true
      niggli_reduce: false
  val_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mp_20/val.lmdb"
      lazy_storage: true
      niggli_reduce: false
  # Prediction reads the test split.
  pred_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mp_20/test.lmdb"
      lazy_storage: true
      niggli_reduce: false
  batch_size: 256
  num_workers: 4
  pin_memory: true
  persistent_workers: true
trainer:
  callbacks:
    # One "best" weights-only checkpoint per validation metric. Callbacks
    # without an explicit mode use ModelCheckpoint's default ("min"); the
    # coverage and validity metrics set mode "max".
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_loss_total"
        save_top_k: 1
        monitor: "val_loss_total"
        save_weights_only: true
    # NOTE(review): relies on the default mode ("min"); confirm that a lower
    # dng_eval value is better.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_dng_eval"
        save_top_k: 1
        monitor: "dng_eval"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_wdist_density"
        save_top_k: 1
        monitor: "wdist_density"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_wdist_narity"
        save_top_k: 1
        monitor: "wdist_narity"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_wdist_coordination_numbers"
        save_top_k: 1
        monitor: "wdist_coordination_numbers"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_cov_precision"
        save_top_k: 1
        monitor: "cov_precision"
        mode: "max"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_cov_recall"
        save_top_k: 1
        monitor: "cov_recall"
        mode: "max"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_valid_rate"
        save_top_k: 1
        monitor: "valid_rate"
        mode: "max"
        save_weights_only: true
    # Periodic full checkpoints (not weights-only): one is written every
    # 100 epochs and save_top_k: -1 keeps all of them.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: -1
        monitor: "val_loss_total"
        every_n_epochs: 100
        save_weights_only: false
  gradient_clip_val: 0.5
  gradient_clip_algorithm: "value"
  num_sanity_val_steps: 0
  precision: "32-true"
  max_epochs: 2000
  enable_progress_bar: false
  # Validation, and therefore every metric-based checkpoint above, only runs
  # every 100 epochs.
  check_val_every_n_epoch: 100
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    # Tuned values, written in scientific notation (same floats as before).
    lr: 1.9511523812262233e-04
    weight_decay: 3.813804936812436e-04
lr_scheduler:
  class_path: torch.optim.lr_scheduler.CosineAnnealingLR
  init_args:
    # Equal to trainer.max_epochs (2000), so the cosine decay spans the whole
    # run — assuming the scheduler is stepped once per epoch (Lightning's
    # default interval).
    T_max: 2000
    # Written with a mantissa dot: YAML 1.1 float resolvers (e.g. plain
    # PyYAML) load "1e-07" as a string, but "1.0e-07" as a float.
    eta_min: 1.0e-07