File size: 4,173 Bytes
c2e0a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24fd3df
c2e0a0d
24fd3df
 
c2e0a0d
24fd3df
c2e0a0d
 
 
24fd3df
c2e0a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24fd3df
c2e0a0d
24fd3df
f305d44
24fd3df
c2e0a0d
24fd3df
c2e0a0d
24fd3df
f305d44
24fd3df
 
 
c2e0a0d
24fd3df
f305d44
24fd3df
c2e0a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# `model` section: stochastic interpolants (`si`), relative loss weights,
# the sampler, the network (`model`), and run-level flags.
model:
  si:
    class_path: omg.si.stochastic_interpolants.StochasticInterpolants
    init_args:
      # One entry per name in `data_fields` below, matched by position.
      stochastic_interpolants:
        # chemical species
        - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
        # fractional coordinates
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.interpolants.PeriodicLinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 12.752963137656907
            correct_center_of_mass_motion: true
        # lattice vectors
        # Differs from the coordinate entry in the interpolant class, the
        # annealing factor, and the center-of-mass correction (off here).
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.interpolants.LinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 0.9964121490291458
            correct_center_of_mass_motion: false
      data_fields:
        # The order here must mirror the order of the interpolant entries
        # above; if one list is reordered, reorder the other as well.
        - "species"
        - "pos"
        - "cell"
      integration_time_steps: 100
  # The three weights add up to 1 (up to floating-point rounding);
  # the species weight is 0.0.
  relative_si_costs:
    species_loss: 0.0
    pos_loss_b: 0.9983149306572928
    cell_loss_b: 0.0016850693427072152
  # One distribution per field (positions, cell, species).
  sampler:
    class_path: omg.sampler.IndependentSampler
    init_args:
      pos_distribution:
        class_path: omg.sampler.position_distributions.UniformPositionDistribution
      cell_distribution:
        class_path: omg.sampler.cell_distributions.InformedLatticeDistribution
        init_args:
          # NOTE(review): same value as `dataset_name` at the end of this
          # section; presumably the two must stay in sync -- confirm.
          dataset_name: "mpts_52"
      species_distribution:
        class_path: omg.sampler.species_distributions.MirrorSpecies
  # Network: encoder + head + time embedder.
  model:
    class_path: omg.model.model.Model
    init_args:
      encoder:
        class_path: omg.model.encoders.cspnet_full.CSPNetFull
      head:
        class_path: omg.model.heads.pass_through.PassThrough
      time_embedder:
        class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
        init_args:
          dim: 256
  use_min_perm_dist: false
  float_32_matmul_precision: "high"
  validation_mode: "match_rate"
  dataset_name: "mpts_52"
# `data` section: the train / validation / prediction datasets and the
# loader options shared by them. All three use
# omg.datamodule.StructureDataset with identical flags and differ only in
# the LMDB file they read.
data:
  train_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mpts_52/train.lmdb"
      lazy_storage: true
      niggli_reduce: false
  val_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mpts_52/val.lmdb"
      lazy_storage: true
      niggli_reduce: false
  # Prediction reads the test split file.
  pred_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mpts_52/test.lmdb"
      lazy_storage: true
      niggli_reduce: false
  # NOTE(review): these names match torch DataLoader arguments and are
  # presumably forwarded to it -- confirm against the datamodule.
  batch_size: 512
  num_workers: 4
  pin_memory: true
  persistent_workers: true
# `trainer` section: four ModelCheckpoint callbacks plus trainer options.
# Validation runs every 100 epochs (`check_val_every_n_epoch`), the same
# interval the periodic checkpoint below uses.
trainer:
  callbacks:
    # Single best checkpoint by "val_loss_total" (weights only).
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_loss_total"
        save_top_k: 1
        monitor: "val_loss_total"
        save_weights_only: true
    # Single best checkpoint by "match_rate" (weights only); mode "max"
    # keeps the checkpoint with the highest monitored value.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_match_rate"
        save_top_k: 1
        monitor: "match_rate"
        save_weights_only: true
        mode: "max"
    # Single best checkpoint by "mean_rmsd" (weights only). No `mode` is
    # given, so the callback's default applies ("min" per the Lightning
    # ModelCheckpoint documentation).
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_rmsd"
        save_top_k: 1
        monitor: "mean_rmsd"
        save_weights_only: true
    # Periodic checkpoints, not weights-only, with no custom filename.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: -1  # Keep all checkpoints; one is written every 100 epochs.
        monitor: "val_loss_total"
        every_n_epochs: 100
        save_weights_only: false
  gradient_clip_val: 0.5
  num_sanity_val_steps: 0
  precision: "32-true"
  max_epochs: 10000
  enable_progress_bar: false
  check_val_every_n_epoch: 100
# `optimizer` section: torch.optim.Adam; the learning rate is the only
# init_arg set in this block.
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.0005546288717347031