# Model section: stochastic interpolants, loss weights, sampler, and network.
model:
  si:
    class_path: omg.si.stochastic_interpolants.StochasticInterpolants
    init_args:
      # NOTE(review): three interpolants are listed here and three data_fields
      # follow below; presumably they pair up by position (species, pos, cell)
      # - confirm before reordering either list.
      stochastic_interpolants:
        - class_path: omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask
          init_args:
            noise: 7.080372063368751
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.interpolants.PeriodicLinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 13.620695525269845
            correct_center_of_mass_motion: true
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.interpolants.LinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 1.0679662602192312
            correct_center_of_mass_motion: false
      data_fields:
        - "species"
        - "pos"
        - "cell"
      integration_time_steps: 150
  # The three weights below sum to 1 (up to floating-point rounding).
  relative_si_costs:
    species_loss: 0.021815399034596464
    pos_loss_b: 0.9775483266595605
    cell_loss_b: 0.0006362743058429793
  sampler:
    class_path: omg.sampler.IndependentSampler
    init_args:
      pos_distribution:
        class_path: omg.sampler.position_distributions.UniformPositionDistribution
      cell_distribution:
        class_path: omg.sampler.cell_distributions.InformedLatticeDistribution
        init_args:
          # Same dataset name as the model-level dataset_name below.
          dataset_name: "mp_20"
      species_distribution:
        class_path: omg.sampler.species_distributions.MaskSpeciesDistribution
  model:
    class_path: omg.model.model.Model
    init_args:
      encoder:
        class_path: omg.model.encoders.cspnet_full.CSPNetFull
      head:
        class_path: omg.model.heads.pass_through.PassThrough
      time_embedder:
        class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
        init_args:
          dim: 256
  use_min_perm_dist: true
  float_32_matmul_precision: "high"
  # Same string as the metric monitored by one of the trainer checkpoints.
  validation_mode: "dng_eval"
  dataset_name: "mp_20"

# Data section: one StructureDataset per split plus the loader settings.
data:
  train_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mp_20/train.lmdb"
      lazy_storage: true
      niggli_reduce: false
  val_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mp_20/val.lmdb"
      lazy_storage: true
      niggli_reduce: false
  # The prediction split reads the test file.
  pred_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mp_20/test.lmdb"
      lazy_storage: true
      niggli_reduce: false
  batch_size: 512
  num_workers: 4
  pin_memory: true
  persistent_workers: true

# Trainer section: checkpoint callbacks and training-loop settings.
trainer:
  callbacks:
    # One weights-only, top-1 checkpoint per monitored metric.
    # NOTE(review): entries without an explicit `mode` fall back to
    # ModelCheckpoint's default ("min" per the Lightning docs) - confirm that
    # lower is better for each of those metrics, "dng_eval" in particular.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_loss_total"
        save_top_k: 1
        monitor: "val_loss_total"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_dng_eval"
        save_top_k: 1
        monitor: "dng_eval"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_wdist_density"
        save_top_k: 1
        monitor: "wdist_density"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_wdist_narity"
        save_top_k: 1
        monitor: "wdist_narity"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_wdist_coordination_numbers"
        save_top_k: 1
        monitor: "wdist_coordination_numbers"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_cov_precision"
        save_top_k: 1
        monitor: "cov_precision"
        mode: "max"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_cov_recall"
        save_top_k: 1
        monitor: "cov_recall"
        mode: "max"
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_valid_rate"
        save_top_k: 1
        monitor: "valid_rate"
        mode: "max"
        save_weights_only: true
    # Periodic checkpoint: save_top_k of -1 keeps every checkpoint (per the
    # Lightning docs), and save_weights_only is off for this entry only.
    # every_n_epochs equals check_val_every_n_epoch below.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: -1
        monitor: "val_loss_total"
        every_n_epochs: 100
        save_weights_only: false
  gradient_clip_val: 0.5
  num_sanity_val_steps: 0
  precision: "32-true"
  max_epochs: 10000
  enable_progress_bar: false
  check_val_every_n_epoch: 100

# Optimizer section: Adam with a fixed learning rate.
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001736512450391209