# Hydra/OmegaConf training configuration (source revision f20d7fc).
---
# Data pipeline: Hydra instantiates CrystDataModule from this node.
data_module:
  _target_: mattergen.common.data.datamodule.CrystDataModule
  # Instantiate nested `_target_` nodes (datasets, transforms) before passing them in.
  _recursive_: true
  # Empty list; shared with the three datasets below via ${data_module.properties}.
  properties: []
  # `_partial_: true` makes Hydra return functools.partial objects rather than
  # calling the transforms; list order is presumably the application order.
  transforms:
  - _target_: mattergen.common.data.transform.symmetrize_lattice
    _partial_: true
  - _target_: mattergen.common.data.transform.set_chemical_system_string
    _partial_: true
  - _target_: mattergen.common.data.transform.set_composition_count
    _partial_: true
  dataset_transforms:
  - _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
    _partial_: true
  # Referenced as ${data_module.average_density} by the cell SDE's limit_density.
  average_density: 0.05771451654022283
  # oc.env has no default here, so resolution fails if PROJECT_ROOT is unset.
  root_dir: ${oc.env:PROJECT_ROOT}/../datasets/cache/broader_csp
  # Each split loads its own cache subdirectory and reuses the shared
  # properties / transforms / dataset_transforms defined above.
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: ${data_module.root_dir}/train
    properties: ${data_module.properties}
    transforms: ${data_module.transforms}
    dataset_transforms: ${data_module.dataset_transforms}
  val_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: ${data_module.root_dir}/val
    properties: ${data_module.properties}
    transforms: ${data_module.transforms}
    dataset_transforms: ${data_module.dataset_transforms}
  test_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: ${data_module.root_dir}/test
    properties: ${data_module.properties}
    transforms: ${data_module.transforms}
    dataset_transforms: ${data_module.dataset_transforms}
  # Presumably forwarded to DataLoader(num_workers=...); 0 loads in the main process.
  num_workers:
    train: 0
    val: 0
    test: 0
  batch_size:
    # Intended per-process batch: a global batch of 512 split across gradient
    # accumulation and all DDP processes (devices * num_nodes). With the
    # trainer values in this file (1, 8, 9) it is 512 // 1 // 72 = 7, i.e. an
    # effective global batch of 504. NOTE(review): confirm that is acceptable.
    # `eval` is not an OmegaConf built-in resolver and must be registered by the
    # loader. The value spans two lines as a folded plain scalar; do not insert
    # a comment between them.
    train: ${eval:'(512 // ${trainer.accumulate_grad_batches}) // (${trainer.devices}
      * ${trainer.num_nodes})'}
    val: 8
    test: 8
  # Read by trainer.max_epochs; Lightning stops at whichever of max_epochs and
  # trainer.max_steps is reached first.
  max_epochs: 900
# PyTorch Lightning Trainer, instantiated by Hydra from `_target_`.
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  # 8 devices per node x 9 nodes = 72 DDP processes. These values and
  # accumulate_grad_batches also determine data_module.batch_size.train.
  devices: 8
  num_nodes: 9
  # Full fp32 precision.
  precision: 32
  max_epochs: ${data_module.max_epochs}
  accumulate_grad_batches: 1
  # "value" clipping clamps each gradient element to [-0.5, 0.5] (not a norm clip).
  gradient_clip_val: 0.5
  gradient_clip_algorithm: value
  # Validation, and hence the loss_val metric monitored by ModelCheckpoint
  # below, only runs every 5th epoch.
  check_val_every_n_epoch: 5
  strategy:
    _target_: pytorch_lightning.strategies.ddp.DDPStrategy
    # Lets DDP tolerate parameters that receive no gradient in a step, at the
    # cost of an extra autograd-graph traversal per iteration.
    find_unused_parameters: true
  callbacks:
  # Logs the learning rate at every optimizer step.
  - _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false
  # Keeps the single best checkpoint (lowest loss_val) plus a last.ckpt.
  - _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: loss_val
    mode: min
    save_top_k: 1
    save_last: true
    verbose: false
    every_n_epochs: 1
    # Quoted: an unquoted value starting with `{` would parse as a flow mapping.
    filename: '{epoch}-{loss_val:.2f}'
  # Redraws the progress bar every 50 batches.
  - _target_: pytorch_lightning.callbacks.TQDMProgressBar
    refresh_rate: 50
  # Project callback; no constructor arguments are configured here.
  - _target_: mattergen.common.data.callback.SetPropertyScalers
  # Hard cap on optimizer steps; training ends at whichever of max_steps and
  # max_epochs is reached first.
  max_steps: 200000
# LightningModule wrapping the diffusion model, its optimizer and LR scheduler.
lightning_module:
  _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
  # `_partial_: true` yields functools.partial(torch.optim.Adam, lr=0.0001);
  # model parameters are presumably supplied later by the module — confirm.
  optimizer_partial:
    lr: 0.0001
    _target_: torch.optim.Adam
    _partial_: true
  # Each entry presumably mirrors Lightning's lr_scheduler config dict
  # (scheduler / interval / frequency / monitor / strict) — confirm in
  # DiffusionLightningModule.configure_optimizers.
  scheduler_partials:
  # ReduceLROnPlateau multiplies the LR by 0.6 once loss_train has not improved
  # for more than `patience` scheduler steps (presumably epochs, given
  # interval: epoch), never going below min_lr.
  - scheduler:
      _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
      factor: 0.6
      patience: 100
      # Written with a dot and a signed exponent so YAML 1.1 loaders (PyYAML)
      # read a float; a bare `1e-6` would load as a string.
      min_lr: 1.0e-06
      _partial_: true
    interval: epoch
    frequency: 1
    # With strict: true, Lightning raises if loss_train is not logged.
    monitor: loss_train
    strict: true
  diffusion_module:
    _target_: mattergen.diffusion.diffusion_module.DiffusionModule
    model:
      _target_: mattergen.denoiser.GemNetTDenoiser
      hidden_dim: 512
      gemnet:
        _target_: mattergen.common.gemnet.gemnet.GemNetT
        num_targets: 1
        # `..` is the parent node (model). With property_embeddings empty this
        # evaluates to 512 * (1 + 0) = 512. `eval` is not an OmegaConf built-in
        # resolver; it must be registered before the config is resolved.
        latent_dim: ${eval:'${..hidden_dim} * (1 + len(${..property_embeddings}))'}
        atom_embedding:
          _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
          # `...` climbs two levels: atom_embedding -> gemnet -> model.
          emb_size: ${...hidden_dim}
          # True only if model.denoise_atom_types is true and
          # model.atom_type_diffusion is "mask"; with the values below it is True.
          # Two-line folded plain scalar: do not insert a comment between its lines.
          with_mask_type: ${eval:'${...denoise_atom_types} and "${...atom_type_diffusion}"
            == "mask"'}
        emb_size_atom: ${..hidden_dim}
        emb_size_edge: ${..hidden_dim}
        max_neighbors: 50
        max_cell_images_per_dim: 5
        # Presumably in Angstrom — confirm in GemNetT.
        cutoff: 7.0
        num_blocks: 4
        regress_stress: true
        otf_graph: true
        # oc.env has no default here, so resolution fails if PROJECT_ROOT is unset.
        scale_file: ${oc.env:PROJECT_ROOT}/common/gemnet/gemnet-dT.json
      # NOTE(review): atom types are denoised by the model, but `corruption`
      # below defines no atomic_numbers SDE and loss_fn sets
      # include_atomic_numbers: false — confirm this is intended.
      denoise_atom_types: true
      atom_type_diffusion: mask
      # Empty mappings; len(property_embeddings) == 0 makes latent_dim equal
      # hidden_dim above.
      property_embeddings_adapt: {}
      property_embeddings: {}
    # Only `pos` and `cell` have SDEs; there is no atomic_numbers entry.
    corruption:
      _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
      sdes:
        pos:
          _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
          # Presumably wraps (fractional) positions with period 1.0 — confirm.
          wrapping_boundary: 1.0
          sigma_max: 5.0
          limit_info_key: num_atoms
        cell:
          _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
          vpsde_config:
            beta_min: 0.1
            beta_max: 20
            # Shared with data_module.average_density.
            limit_density: ${data_module.average_density}
            limit_var_scaling_constant: 0.25
    # Loss terms for pos and cell only; the cell term is weighted 10x the pos
    # term. The meaning of `reduce: sum` is defined by MaterialsLoss.
    loss_fn:
      _target_: mattergen.common.loss.MaterialsLoss
      reduce: sum
      include_pos: true
      include_cell: true
      include_atomic_numbers: false
      weights:
        cell: 1.0
        pos: 0.1
    pre_corruption_fn:
      _target_: mattergen.property_embeddings.SetEmbeddingType
      # Presumably the probability that a sample is treated as unconditional —
      # confirm in SetEmbeddingType.
      p_unconditional: 0.2
      dropout_fields_iid: false
# Top-level flag, not passed to any `_target_` in this file; presumably read by
# the training entry point to resume from an existing checkpoint — confirm.
auto_resume: true