File size: 4,078 Bytes
6a8ebba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bfd651
6a8ebba
0bfd651
 
6a8ebba
0bfd651
6a8ebba
 
 
0bfd651
6a8ebba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bfd651
6a8ebba
0bfd651
9556f6e
0bfd651
6a8ebba
0bfd651
6a8ebba
0bfd651
9556f6e
0bfd651
 
 
6a8ebba
0bfd651
9556f6e
0bfd651
6a8ebba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Model section: stochastic interpolants, loss weights, base-distribution
# sampler, network architecture, and a few module-level switches.
model:
  si:
    class_path: omg.si.StochasticInterpolants
    init_args:
      # one interpolant per entry of data_fields below, in the same order
      stochastic_interpolants:
        # chemical species
        - class_path: omg.si.SingleStochasticInterpolantIdentity
        # fractional coordinates
        - class_path: omg.si.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.PeriodicLinearInterpolant
            # gamma and epsilon are left unset for this ODE-type interpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 10.182659004291072
            correct_center_of_mass_motion: true
        # lattice vectors
        - class_path: omg.si.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.LinearInterpolant
            # gamma and epsilon are left unset for this ODE-type interpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 1.824475401606087
            correct_center_of_mass_motion: false
      data_fields:
        # if the order of the data_fields changes,
        # the order of the above StochasticInterpolant inputs must also change
        - "species"
        - "pos"
        - "cell"
      integration_time_steps: 210
  # per-field loss weights; the three values sum to 1 (up to float rounding)
  # and the species term is weighted 0.0
  relative_si_costs:
    species_loss: 0.0
    pos_loss_b: 0.9994149341846618
    cell_loss_b: 0.0005850658153382233
  # one distribution per data field (positions, cell, species)
  sampler:
    class_path: omg.sampler.IndependentSampler
    init_args:
      pos_distribution:
        class_path: omg.sampler.position_distributions.UniformPositionDistribution
      cell_distribution:
        class_path: omg.sampler.cell_distributions.InformedLatticeDistribution
        init_args:
          # same dataset name as model.dataset_name at the end of this section
          dataset_name: "alex_mp_20"
      species_distribution:
        class_path: omg.sampler.species_distributions.MirrorSpecies
  # network: encoder + head + time embedding
  model:
    class_path: omg.model.model.Model
    init_args:
      encoder:
        class_path: omg.model.encoders.cspnet_full.CSPNetFull
      head:
        class_path: omg.model.heads.pass_through.PassThrough
      time_embedder:
        class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
        init_args:
          dim: 256
  use_min_perm_dist: false
  float_32_matmul_precision: "high"
  # NOTE(review): "match_rate" is also the metric monitored by one of the
  # trainer checkpoints - presumably this switch enables that metric; confirm.
  validation_mode: "match_rate"
  number_cpus: 7
  dataset_name: "alex_mp_20"
# Data section: train/val/pred splits of alex_mp_20, each read from its own
# LMDB file with identical StructureDataset options, plus loader settings.
data:
  train_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/alex_mp_20/train.lmdb"
      lazy_storage: true
      niggli_reduce: false
  val_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/alex_mp_20/val.lmdb"
      lazy_storage: true
      niggli_reduce: false
  # the prediction split reads the test file
  pred_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/alex_mp_20/test.lmdb"
      lazy_storage: true
      niggli_reduce: false
  # NOTE(review): the four keys below look like data-loader options shared by
  # all splits - confirm against the datamodule signature.
  batch_size: 512
  num_workers: 4
  pin_memory: true
  persistent_workers: true
# Trainer section: four checkpoint callbacks (three "best of" trackers and
# one periodic full snapshot) followed by the trainer settings.
trainer:
  callbacks:
    # best total validation loss; no mode given, so the callback's default
    # applies (lower is better in Lightning)
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_loss_total"
        save_top_k: 1
        monitor: "val_loss_total"
        save_weights_only: true
    # best match rate; higher is better, hence the explicit "max"
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_match_rate"
        save_top_k: 1
        monitor: "match_rate"
        save_weights_only: true
        mode: "max"
    # best mean RMSD; no mode given, so the callback's default applies
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_rmsd"
        save_top_k: 1
        monitor: "mean_rmsd"
        save_weights_only: true
    # periodic snapshot; the only callback with save_weights_only disabled
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: -1  # Keep every checkpoint; one is written every 100 epochs.
        monitor: "val_loss_total"
        every_n_epochs: 100
        save_weights_only: false
  gradient_clip_val: 0.5
  num_sanity_val_steps: 0
  precision: "32-true"
  max_epochs: 2000
  enable_progress_bar: true
  # validation uses a 0.1 fraction of the validation batches and runs every
  # 100 epochs, the same interval as the periodic checkpoint above
  limit_val_batches: 0.1
  check_val_every_n_epoch: 100
# Optimizer section: Adam with only the learning rate set here.
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.0006689636445843722