File size: 4,504 Bytes
4fcc7b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a7776a
4fcc7b9
5a7776a
 
 
 
4fcc7b9
 
 
5a7776a
4fcc7b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a7776a
4fcc7b9
5a7776a
e58815d
5a7776a
4fcc7b9
5a7776a
4fcc7b9
5a7776a
e58815d
5a7776a
 
 
4fcc7b9
5a7776a
e58815d
5a7776a
4fcc7b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
model:
  si:
    class_path: omg.si.stochastic_interpolants.StochasticInterpolants
    init_args:
      # One interpolant per entry of data_fields below, in the same order.
      stochastic_interpolants:
        # chemical species
        - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
        # fractional coordinates
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant: omg.si.interpolants.PeriodicLinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 0.004755207270677389
            correct_center_of_mass_motion: true
        # lattice vectors
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            interpolant:
              class_path: omg.si.interpolants.EncoderDecoderInterpolant
              init_args:
                switch_time: 0.46351945271978645
                power: 1.0
            # NOTE(review): switch_time and power below repeat the values of
            # the EncoderDecoderInterpolant above; presumably the two must
            # stay in sync -- confirm before tuning either independently.
            gamma:
              class_path: omg.si.gamma.LatentGammaEncoderDecoder
              init_args:
                a: 0.8167071952445664
                switch_time: 0.46351945271978645
                power: 1.0
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 13.921408921615031
            correct_center_of_mass_motion: false
      data_fields:
        # If the order of data_fields changes, the order of the
        # stochastic_interpolants entries above must change to match.
        - "species"
        - "pos"
        - "cell"
      integration_time_steps: 480
  # Presumably relative weights of the per-field losses (confirm in the
  # model code). species_loss is 0.0; pos + cell weights sum to ~1.0.
  relative_si_costs:
    species_loss: 0.0
    pos_loss_b: 0.9860929911452281
    cell_loss_b: 0.01390700885477196
  sampler:
    class_path: omg.sampler.IndependentSampler
    init_args:
      pos_distribution:
        class_path: omg.sampler.position_distributions.UniformPositionDistribution
      cell_distribution:
        class_path: omg.sampler.cell_distributions.InformedLatticeDistribution
        init_args:
          dataset_name: "perov_5"
      species_distribution:
        class_path: omg.sampler.species_distributions.MirrorSpecies
  model:
    class_path: omg.model.model.Model
    init_args:
      encoder:
        class_path: omg.model.encoders.cspnet_full.CSPNetFull
      head:
        class_path: omg.model.heads.pass_through.PassThrough
      time_embedder:
        class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
        init_args:
          dim: 256
  use_min_perm_dist: true
  float_32_matmul_precision: "high"
  validation_mode: "match_rate"
  dataset_name: "perov_5"
data:
  train_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/perov_5/train.lmdb"
      lazy_storage: true
      niggli_reduce: true
  val_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/perov_5/val.lmdb"
      lazy_storage: true
      niggli_reduce: true
  # Prediction reads the held-out test split (test.lmdb).
  pred_dataset:
    class_path: omg.datamodule.StructureDataset
    init_args:
      file_path: "data/perov_5/test.lmdb"
      lazy_storage: true
      niggli_reduce: true
  batch_size: 1024
  num_workers: 4
  pin_memory: true
  persistent_workers: true
trainer:
  # The metrics monitored below are presumably logged during validation,
  # which only runs every check_val_every_n_epoch (= 100) epochs.
  callbacks:
    # Best weights by lowest total validation loss.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_loss_total"
        save_top_k: 1
        monitor: "val_loss_total"
        mode: "min"
        save_weights_only: true
    # Best weights by highest match rate.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_match_rate"
        save_top_k: 1
        monitor: "match_rate"
        save_weights_only: true
        mode: "max"
    # Best weights by lowest mean RMSD.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_rmsd"
        save_top_k: 1
        monitor: "mean_rmsd"
        mode: "min"
        save_weights_only: true
    # Periodic full checkpoints (save_weights_only: false): one is written
    # every 100 epochs, and save_top_k: -1 keeps all of them.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: -1
        monitor: "val_loss_total"
        every_n_epochs: 100
        save_weights_only: false
  gradient_clip_val: 0.5
  num_sanity_val_steps: 0
  precision: "32-true"
  max_epochs: 6000
  enable_progress_bar: false
  check_val_every_n_epoch: 100
# Adam optimizer; only the learning rate is overridden in init_args.
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001147361965964576