File size: 4,401 Bytes
c2e0a0d
 
 
 
 
 
 
 
 
 
31243d0
 
 
 
c2e0a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24fd3df
c2e0a0d
24fd3df
 
c2e0a0d
 
24fd3df
 
c2e0a0d
 
 
24fd3df
c2e0a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24fd3df
c2e0a0d
24fd3df
f305d44
24fd3df
c2e0a0d
24fd3df
c2e0a0d
24fd3df
f305d44
24fd3df
 
 
c2e0a0d
24fd3df
f305d44
24fd3df
c2e0a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
model:
  si:
    class_path: omg.si.stochastic_interpolants.StochasticInterpolants
    init_args:
      # One interpolant per entry of data_fields below, listed in the same order.
      stochastic_interpolants:
        # chemical species
        - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
        # fractional coordinates
        - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
          init_args:
            interpolant:
              class_path: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolantVP
              init_args:
                tau: omg.si.tau.TauConstantSchedule
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 6.613808424917352
            correct_center_of_mass_motion: true
            predict_velocity: true
        # lattice vectors
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.interpolants.LinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 2.447993013544224
            correct_center_of_mass_motion: false
      data_fields:
        # If the order of data_fields changes, the entries of the
        # stochastic_interpolants list above must be reordered to match.
        - "species"
        - "pos"
        - "cell"
      integration_time_steps: 890
  # NOTE(review): the pos/cell weights below sum to 1.0 and the species
  # weight is 0.0 (species use the identity interpolant above); presumably
  # these are relative weights of the per-field losses -- confirm.
  relative_si_costs:
    species_loss: 0.0
    pos_loss_b: 0.9597565150933746
    cell_loss_b: 0.04024348490662539
  sampler:
    class_path: omg.sampler.IndependentSampler
    init_args:
      pos_distribution:
        class_path: omg.sampler.position_distributions.NormalPositionDistribution
        init_args:
          scale: 0.22006712732536396
      cell_distribution:
        class_path: omg.sampler.cell_distributions.InformedLatticeDistribution
        init_args:
          dataset_name: "mpts_52"
      species_distribution:
        class_path: omg.sampler.species_distributions.MirrorSpecies
  model:
    class_path: omg.model.model.Model
    init_args:
      encoder:
        class_path: omg.model.encoders.cspnet_full.CSPNetFull
      head:
        class_path: omg.model.heads.pass_through.PassThrough
      time_embedder:
        class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
        init_args:
          dim: 256
  use_min_perm_dist: true
  float_32_matmul_precision: "high"
  validation_mode: "match_rate"
  number_cpus: 7
  dataset_name: "mpts_52"
data:
  # The train/val/pred splits share the same StructureDataset options;
  # only file_path differs between them.
  train_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mpts_52/train.lmdb"
      lazy_storage: true
      niggli_reduce: false
  val_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mpts_52/val.lmdb"
      lazy_storage: true
      niggli_reduce: false
  pred_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/mpts_52/test.lmdb"
      lazy_storage: true
      niggli_reduce: false
  batch_size: 64
  num_workers: 4
  pin_memory: true
  persistent_workers: true
trainer:
  callbacks:
    # Keep the single checkpoint with the lowest total validation loss.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_loss_total"
        save_top_k: 1
        monitor: "val_loss_total"
        mode: "min"
        save_weights_only: true
    # Keep the single checkpoint with the highest match rate.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_match_rate"
        save_top_k: 1
        monitor: "match_rate"
        save_weights_only: true
        mode: "max"
    # Keep the single checkpoint with the lowest mean RMSD.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_rmsd"
        save_top_k: 1
        monitor: "mean_rmsd"
        mode: "min"
        save_weights_only: true
    # Periodic full checkpoints (save_weights_only: false also stores
    # optimizer/trainer state, so training can be resumed from them).
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: -1  # Keep all checkpoints; one is written every 100 epochs.
        monitor: "val_loss_total"
        every_n_epochs: 100
        save_weights_only: false
  gradient_clip_val: 0.5
  num_sanity_val_steps: 0
  precision: "32-true"
  max_epochs: 2000
  enable_progress_bar: true
  limit_val_batches: 0.5
  # Validation (and hence the metrics monitored above) runs every 100 epochs.
  check_val_every_n_epoch: 100
optimizer:
  # Adam with only the learning rate overridden; all other constructor
  # arguments are left unset here and fall back to torch's defaults.
  class_path: torch.optim.Adam
  init_args:
    lr: 2.519765029616902e-05