richfrain commited on
Commit
04ef997
·
verified ·
1 Parent(s): 36ada67

Upload 2 files

Browse files
Files changed (2) hide show
  1. config.yaml +232 -0
  2. last.ckpt +3 -0
config.yaml ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adapter:
2
+ adapter:
3
+ _target_: mattergen.adapter.GemNetTAdapter
4
+ atom_type_diffusion: mask
5
+ denoise_atom_types: true
6
+ gemnet:
7
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
8
+ atom_embedding:
9
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
10
+ emb_size: 512
11
+ with_mask_type: true
12
+ condition_on_adapt:
13
+ - over_potential
14
+ cutoff: 7.0
15
+ emb_size_atom: 512
16
+ emb_size_edge: 512
17
+ latent_dim: 512
18
+ max_cell_images_per_dim: 5
19
+ max_neighbors: 50
20
+ num_blocks: 4
21
+ num_targets: 1
22
+ otf_graph: true
23
+ regress_stress: true
24
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
25
+ hidden_dim: 512
26
+ property_embeddings: {}
27
+ property_embeddings_adapt:
28
+ over_potential:
29
+ _target_: mattergen.property_embeddings.PropertyEmbedding
30
+ conditional_embedding_module:
31
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
32
+ d_model: 512
33
+ name: over_potential
34
+ scaler:
35
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
36
+ unconditional_embedding_module:
37
+ _target_: mattergen.property_embeddings.EmbeddingVector
38
+ hidden_dim: 512
39
+ full_finetuning: true
40
+ load_epoch: last
41
+ model_path: null
42
+ pretrained_name: mattergen_base
43
+ data_module:
44
+ _recursive_: true
45
+ _target_: mattergen.common.data.datamodule.CrystDataModule
46
+ average_density: 0.05771451654022283
47
+ batch_size:
48
+ test: 8
49
+ train: 32
50
+ val: 8
51
+ dataset_transforms:
52
+ - _partial_: true
53
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
54
+ max_epochs: 50
55
+ num_workers:
56
+ test: 0
57
+ train: 0
58
+ val: 0
59
+ properties:
60
+ - over_potential
61
+ root_dir: /root/mattergen/mattergen/../datasets/cache/mp_20o
62
+ test_dataset:
63
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
64
+ cache_path: /root/mattergen/mattergen/../datasets/cache/mp_20o/test
65
+ dataset_transforms:
66
+ - _partial_: true
67
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
68
+ properties:
69
+ - over_potential
70
+ transforms:
71
+ - _partial_: true
72
+ _target_: mattergen.common.data.transform.symmetrize_lattice
73
+ - _partial_: true
74
+ _target_: mattergen.common.data.transform.set_chemical_system_string
75
+ train_dataset:
76
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
77
+ cache_path: /root/mattergen/mattergen/../datasets/cache/mp_20o/train
78
+ dataset_transforms:
79
+ - _partial_: true
80
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
81
+ properties:
82
+ - over_potential
83
+ transforms:
84
+ - _partial_: true
85
+ _target_: mattergen.common.data.transform.symmetrize_lattice
86
+ - _partial_: true
87
+ _target_: mattergen.common.data.transform.set_chemical_system_string
88
+ transforms:
89
+ - _partial_: true
90
+ _target_: mattergen.common.data.transform.symmetrize_lattice
91
+ - _partial_: true
92
+ _target_: mattergen.common.data.transform.set_chemical_system_string
93
+ val_dataset:
94
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
95
+ cache_path: /root/mattergen/mattergen/../datasets/cache/mp_20o/val
96
+ dataset_transforms:
97
+ - _partial_: true
98
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
99
+ properties:
100
+ - over_potential
101
+ transforms:
102
+ - _partial_: true
103
+ _target_: mattergen.common.data.transform.symmetrize_lattice
104
+ - _partial_: true
105
+ _target_: mattergen.common.data.transform.set_chemical_system_string
106
+ lightning_module:
107
+ _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
108
+ diffusion_module:
109
+ _target_: mattergen.diffusion.diffusion_module.DiffusionModule
110
+ corruption:
111
+ _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
112
+ discrete_corruptions:
113
+ atomic_numbers:
114
+ _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
115
+ d3pm:
116
+ _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
117
+ dim: 101
118
+ schedule:
119
+ _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
120
+ kind: standard
121
+ num_steps: 1000
122
+ offset: 1
123
+ sdes:
124
+ cell:
125
+ _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
126
+ vpsde_config:
127
+ beta_max: 20
128
+ beta_min: 0.1
129
+ limit_density: 0.05771451654022283
130
+ limit_var_scaling_constant: 0.25
131
+ pos:
132
+ _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
133
+ limit_info_key: num_atoms
134
+ sigma_max: 5.0
135
+ wrapping_boundary: 1.0
136
+ loss_fn:
137
+ _target_: mattergen.common.loss.MaterialsLoss
138
+ d3pm_hybrid_lambda: 0.01
139
+ include_atomic_numbers: true
140
+ include_cell: true
141
+ include_pos: true
142
+ reduce: sum
143
+ weights:
144
+ atomic_numbers: 1.0
145
+ cell: 1.0
146
+ pos: 0.1
147
+ model:
148
+ _target_: mattergen.adapter.GemNetTAdapter
149
+ atom_type_diffusion: mask
150
+ denoise_atom_types: true
151
+ gemnet:
152
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
153
+ atom_embedding:
154
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
155
+ emb_size: 512
156
+ with_mask_type: true
157
+ condition_on_adapt:
158
+ - over_potential
159
+ cutoff: 7.0
160
+ emb_size_atom: 512
161
+ emb_size_edge: 512
162
+ latent_dim: 512
163
+ max_cell_images_per_dim: 5
164
+ max_neighbors: 50
165
+ num_blocks: 4
166
+ num_targets: 1
167
+ otf_graph: true
168
+ regress_stress: true
169
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
170
+ hidden_dim: 512
171
+ property_embeddings: {}
172
+ property_embeddings_adapt:
173
+ over_potential:
174
+ _target_: mattergen.property_embeddings.PropertyEmbedding
175
+ conditional_embedding_module:
176
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
177
+ d_model: 512
178
+ name: over_potential
179
+ scaler:
180
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
181
+ unconditional_embedding_module:
182
+ _target_: mattergen.property_embeddings.EmbeddingVector
183
+ hidden_dim: 512
184
+ pre_corruption_fn:
185
+ _target_: mattergen.property_embeddings.SetEmbeddingType
186
+ dropout_fields_iid: false
187
+ p_unconditional: 0.2
188
+ optimizer_partial:
189
+ _partial_: true
190
+ _target_: torch.optim.Adam
191
+ lr: 5.0e-06
192
+ scheduler_partials:
193
+ - frequency: 1
194
+ interval: epoch
195
+ monitor: loss_train
196
+ scheduler:
197
+ _partial_: true
198
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
199
+ factor: 0.6
200
+ min_lr: 1.0e-06
201
+ patience: 100
202
+ verbose: true
203
+ strict: true
204
+ trainer:
205
+ _target_: pytorch_lightning.Trainer
206
+ accelerator: gpu
207
+ accumulate_grad_batches: 1
208
+ callbacks:
209
+ - _target_: pytorch_lightning.callbacks.LearningRateMonitor
210
+ log_momentum: false
211
+ logging_interval: step
212
+ - _target_: pytorch_lightning.callbacks.ModelCheckpoint
213
+ every_n_epochs: 1
214
+ filename: '{epoch}-{loss_val:.2f}'
215
+ mode: min
216
+ monitor: loss_val
217
+ save_last: true
218
+ save_top_k: 1
219
+ verbose: false
220
+ - _target_: pytorch_lightning.callbacks.TQDMProgressBar
221
+ refresh_rate: 50
222
+ - _target_: mattergen.common.data.callback.SetPropertyScalers
223
+ check_val_every_n_epoch: 5
224
+ devices: 1
225
+ gradient_clip_algorithm: value
226
+ gradient_clip_val: 0.5
227
+ max_epochs: 200
228
+ num_nodes: 1
229
+ precision: 32
230
+ strategy:
231
+ _target_: pytorch_lightning.strategies.ddp.DDPStrategy
232
+ find_unused_parameters: true
last.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a62928972e71289c632ec0ed130d6e806c316c2cdd69a69ace8dd7a0de5c3ea
3
+ size 511777342