---
# LightningCLI configuration (jsonargparse class_path / init_args style)
# for training on the mp_20 dataset.
#
# NOTE(review): recovered from a copy in which all indentation had been lost
# and every line ended in a stray " |" (which made the whole file parse as a
# single literal block-scalar string). The nesting below follows the
# class_path / init_args convention; confirm it against the constructors'
# signatures.

model:
  si:
    class_path: omg.si.stochastic_interpolants.StochasticInterpolants
    init_args:
      # One interpolant per entry of data_fields below; presumably they are
      # paired by position, so keep both lists in the same order (TODO confirm).
      stochastic_interpolants:
        # species: identity interpolant (presumably leaves species unchanged)
        - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
        # pos: periodic linear interpolant
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.interpolants.PeriodicLinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 10.182659004291072
            correct_center_of_mass_motion: true
        # cell: linear interpolant
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.interpolants.LinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 1.824475401606087
            correct_center_of_mass_motion: false
      data_fields:
        - "species"
        - "pos"
        - "cell"
      integration_time_steps: 210

  # Per-field loss weights. species_loss is 0.0; the two nonzero weights sum
  # to 1.
  relative_si_costs:
    species_loss: 0.0
    pos_loss_b: 0.9994149341846618
    cell_loss_b: 0.0005850658153382233

  # Sampler built from one distribution per data field.
  sampler:
    class_path: omg.sampler.IndependentSampler
    init_args:
      pos_distribution:
        class_path: omg.sampler.position_distributions.UniformPositionDistribution
      cell_distribution:
        class_path: omg.sampler.cell_distributions.InformedLatticeDistribution
        init_args:
          dataset_name: "mp_20"
      # presumably copies species from the reference structure (TODO confirm)
      species_distribution:
        class_path: omg.sampler.species_distributions.MirrorSpecies

  model:
    class_path: omg.model.model.Model
    init_args:
      encoder:
        class_path: omg.model.encoders.cspnet_full.CSPNetFull
      head:
        class_path: omg.model.heads.pass_through.PassThrough
      time_embedder:
        class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
        init_args:
          dim: 256

  use_min_perm_dist: false
  # presumably forwarded to torch.set_float32_matmul_precision (TODO confirm)
  float_32_matmul_precision: "high"
  validation_mode: "match_rate"
  dataset_name: "mp_20"

data:
  train_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mp_20/train.lmdb"
      lazy_storage: true
      niggli_reduce: false
  val_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mp_20/val.lmdb"
      lazy_storage: true
      niggli_reduce: false
  pred_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mp_20/test.lmdb"
      lazy_storage: true
      niggli_reduce: false
  batch_size: 1024
  num_workers: 4
  pin_memory: true
  persistent_workers: true

trainer:
  callbacks:
    # Best checkpoint by validation loss (ModelCheckpoint's mode defaults
    # to "min").
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_loss_total"
        save_top_k: 1
        monitor: "val_loss_total"
        save_weights_only: true
    # Best checkpoint by match rate; higher is better, hence mode "max".
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_match_rate"
        save_top_k: 1
        monitor: "match_rate"
        save_weights_only: true
        mode: "max"
    # Best checkpoint by mean RMSD (default mode "min").
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_rmsd"
        save_top_k: 1
        monitor: "mean_rmsd"
        save_weights_only: true
    # Periodic full checkpoints (save_weights_only: false also stores
    # optimizer/trainer state); save_top_k -1 keeps every one of them.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: -1
        monitor: "val_loss_total"
        every_n_epochs: 100
        save_weights_only: false
  gradient_clip_val: 0.5
  num_sanity_val_steps: 0
  precision: "32-true"
  max_epochs: 2000
  enable_progress_bar: false
  # Validation runs only every 100 epochs; the monitored metrics above are
  # presumably logged during validation, so the best_val_* checkpoints can
  # only update at those epochs.
  check_val_every_n_epoch: 100

optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.0006689636445843722