t-reents committed on
Commit
d4ccbcc
·
verified ·
1 Parent(s): 9e7cfc9

Add model checkpoints for XtalPaint presented in https://arxiv.org/abs/2601.01959

Browse files

Model checkpoints for the `pos-only` and `TD-pos-only` models presented in "Score-based diffusion models for accurate crystal-structure inpainting and reconstruction of hydrogen positions" (https://arxiv.org/abs/2601.01959). Both models are retrained versions of the MatterGen (https://github.com/microsoft/mattergen) diffusion model for crystal structures.

TD-pos-only/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64b5bd5124d44117b0ceafbf06f108703296f3148f3e7ca1c630d4240c571bfe
3
+ size 439587436
TD-pos-only/config.yaml ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ auto_resume: false
2
+ checkpoint_path: null
3
+ data_module:
4
+ _recursive_: true
5
+ _target_: mattergen.common.data.datamodule.CrystDataModule
6
+ average_density: 0.05771451654022283
7
+ batch_size:
8
+ train: 128
9
+ val: 128
10
+ dataset_transforms:
11
+ - _partial_: true
12
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
13
+ max_epochs: 2200
14
+ num_workers:
15
+ train: 128
16
+ val: 128
17
+ properties: []
18
+ root_dir: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/../datasets/cache/alex_mp_20_wo_mc3d_H
19
+ train_dataset:
20
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
21
+ cache_path: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/../datasets/cache/alex_mp_20_wo_mc3d_H/train
22
+ dataset_transforms:
23
+ - _partial_: true
24
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
25
+ properties: []
26
+ transforms:
27
+ - _partial_: true
28
+ _target_: mattergen.common.data.transform.symmetrize_lattice
29
+ - _partial_: true
30
+ _target_: mattergen.common.data.transform.set_chemical_system_string
31
+ transforms:
32
+ - _partial_: true
33
+ _target_: mattergen.common.data.transform.symmetrize_lattice
34
+ - _partial_: true
35
+ _target_: mattergen.common.data.transform.set_chemical_system_string
36
+ val_dataset:
37
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
38
+ cache_path: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/../datasets/cache/alex_mp_20_wo_mc3d_H/val
39
+ dataset_transforms:
40
+ - _partial_: true
41
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
42
+ properties: []
43
+ transforms:
44
+ - _partial_: true
45
+ _target_: mattergen.common.data.transform.symmetrize_lattice
46
+ - _partial_: true
47
+ _target_: mattergen.common.data.transform.set_chemical_system_string
48
+ lightning_module:
49
+ _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
50
+ diffusion_module:
51
+ _target_: dbcsi_inpainting.time_dependent.diffusion_module.TDDiffusionModule
52
+ corruption:
53
+ _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
54
+ sdes:
55
+ pos:
56
+ _target_: dbcsi_inpainting.time_dependent.corruption.TDNumAtomsVarianceAdjustedWrappedVESDE
57
+ limit_info_key: num_atoms
58
+ sigma_max: 5.0
59
+ wrapping_boundary: 1.0
60
+ loss_fn:
61
+ _target_: dbcsi_inpainting.time_dependent.loss.TDMaterialsLoss
62
+ d3pm_hybrid_lambda: 0.01
63
+ include_atomic_numbers: false
64
+ include_cell: false
65
+ include_pos: true
66
+ reduce: sum
67
+ weights:
68
+ pos: 1
69
+ model:
70
+ _target_: mattergen.denoiser.GemNetTDenoiser
71
+ atom_type_diffusion: mask
72
+ denoise_atom_types: false
73
+ gemnet:
74
+ _target_: mattergen.common.gemnet.gemnet.GemNetT
75
+ atom_embedding:
76
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
77
+ emb_size: 512
78
+ with_mask_type: false
79
+ cutoff: 7.0
80
+ emb_size_atom: 512
81
+ emb_size_edge: 512
82
+ latent_dim: 512
83
+ max_cell_images_per_dim: 5
84
+ max_neighbors: 50
85
+ num_blocks: 4
86
+ num_targets: 1
87
+ otf_graph: true
88
+ regress_stress: true
89
+ scale_file: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/common/gemnet/gemnet-dT.json
90
+ hidden_dim: 512
91
+ property_embeddings: {}
92
+ property_embeddings_adapt: {}
93
+ p_replace: 0.2
94
+ pre_corruption_fn:
95
+ _target_: mattergen.property_embeddings.SetEmbeddingType
96
+ dropout_fields_iid: false
97
+ p_unconditional: 0.2
98
+ t_replace: 0.001
99
+ optimizer_partial:
100
+ _partial_: true
101
+ _target_: torch.optim.Adam
102
+ lr: 0.0001
103
+ scheduler_partials:
104
+ - frequency: 1
105
+ interval: epoch
106
+ monitor: loss_train
107
+ scheduler:
108
+ _partial_: true
109
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
110
+ factor: 0.6
111
+ min_lr: 1.0e-06
112
+ patience: 100
113
+ verbose: true
114
+ strict: true
115
+ load_original: false
116
+ params: {}
117
+ trainer:
118
+ _target_: pytorch_lightning.Trainer
119
+ accelerator: gpu
120
+ accumulate_grad_batches: 1
121
+ callbacks:
122
+ - _target_: pytorch_lightning.callbacks.EarlyStopping
123
+ min_delta: 0.01
124
+ mode: min
125
+ monitor: loss_val
126
+ patience: 20
127
+ strict: true
128
+ verbose: true
129
+ check_val_every_n_epoch: 5
130
+ devices: 4
131
+ gradient_clip_algorithm: value
132
+ gradient_clip_val: 0.5
133
+ max_epochs: 2200
134
+ num_nodes: 1
135
+ precision: 32
136
+ strategy:
137
+ _target_: pytorch_lightning.strategies.ddp.DDPStrategy
138
+ find_unused_parameters: true
pos-only/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd6335286fd98e633a27109331261e5b7cb84afc8328849745855305aab2797a
3
+ size 439588140
pos-only/config.yaml ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ auto_resume: false
2
+ checkpoint_path: null
3
+ data_module:
4
+ _recursive_: true
5
+ _target_: mattergen.common.data.datamodule.CrystDataModule
6
+ average_density: 0.05771451654022283
7
+ batch_size:
8
+ train: 128
9
+ val: 128
10
+ dataset_transforms:
11
+ - _partial_: true
12
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
13
+ max_epochs: 2200
14
+ num_workers:
15
+ train: 128
16
+ val: 128
17
+ properties: []
18
+ root_dir: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/../datasets/cache/alex_mp_20_wo_mc3d_H
19
+ train_dataset:
20
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
21
+ cache_path: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/../datasets/cache/alex_mp_20_wo_mc3d_H/train
22
+ dataset_transforms:
23
+ - _partial_: true
24
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
25
+ properties: []
26
+ transforms:
27
+ - _partial_: true
28
+ _target_: mattergen.common.data.transform.symmetrize_lattice
29
+ - _partial_: true
30
+ _target_: mattergen.common.data.transform.set_chemical_system_string
31
+ transforms:
32
+ - _partial_: true
33
+ _target_: mattergen.common.data.transform.symmetrize_lattice
34
+ - _partial_: true
35
+ _target_: mattergen.common.data.transform.set_chemical_system_string
36
+ val_dataset:
37
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
38
+ cache_path: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/../datasets/cache/alex_mp_20_wo_mc3d_H/val
39
+ dataset_transforms:
40
+ - _partial_: true
41
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
42
+ properties: []
43
+ transforms:
44
+ - _partial_: true
45
+ _target_: mattergen.common.data.transform.symmetrize_lattice
46
+ - _partial_: true
47
+ _target_: mattergen.common.data.transform.set_chemical_system_string
48
+ lightning_module:
49
+ _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
50
+ diffusion_module:
51
+ _target_: mattergen.diffusion.diffusion_module.DiffusionModule
52
+ corruption:
53
+ _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
54
+ sdes:
55
+ pos:
56
+ _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
57
+ limit_info_key: num_atoms
58
+ sigma_max: 5.0
59
+ wrapping_boundary: 1.0
60
+ loss_fn:
61
+ _target_: mattergen.common.loss.MaterialsLoss
62
+ d3pm_hybrid_lambda: 0.01
63
+ include_atomic_numbers: false
64
+ include_cell: false
65
+ include_pos: true
66
+ reduce: sum
67
+ weights:
68
+ pos: 1.0
69
+ model:
70
+ _target_: mattergen.denoiser.GemNetTDenoiser
71
+ atom_type_diffusion: mask
72
+ denoise_atom_types: false
73
+ gemnet:
74
+ _target_: mattergen.common.gemnet.gemnet.GemNetT
75
+ atom_embedding:
76
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
77
+ emb_size: 512
78
+ with_mask_type: false
79
+ cutoff: 7.0
80
+ emb_size_atom: 512
81
+ emb_size_edge: 512
82
+ latent_dim: 512
83
+ max_cell_images_per_dim: 5
84
+ max_neighbors: 50
85
+ num_blocks: 4
86
+ num_targets: 1
87
+ otf_graph: true
88
+ regress_stress: true
89
+ scale_file: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/common/gemnet/gemnet-dT.json
90
+ hidden_dim: 512
91
+ property_embeddings: {}
92
+ property_embeddings_adapt: {}
93
+ pre_corruption_fn:
94
+ _target_: mattergen.property_embeddings.SetEmbeddingType
95
+ dropout_fields_iid: false
96
+ p_unconditional: 0.2
97
+ optimizer_partial:
98
+ _partial_: true
99
+ _target_: torch.optim.Adam
100
+ lr: 0.0001
101
+ scheduler_partials:
102
+ - frequency: 1
103
+ interval: epoch
104
+ monitor: loss_train
105
+ scheduler:
106
+ _partial_: true
107
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
108
+ factor: 0.6
109
+ min_lr: 1.0e-06
110
+ patience: 100
111
+ verbose: true
112
+ strict: true
113
+ load_original: false
114
+ params: {}
115
+ trainer:
116
+ _target_: pytorch_lightning.Trainer
117
+ accelerator: gpu
118
+ accumulate_grad_batches: 1
119
+ callbacks:
120
+ - _target_: pytorch_lightning.callbacks.LearningRateMonitor
121
+ log_momentum: false
122
+ logging_interval: step
123
+ - _target_: pytorch_lightning.callbacks.ModelCheckpoint
124
+ every_n_epochs: 1
125
+ filename: '{epoch}-{loss_val:.2f}'
126
+ mode: min
127
+ monitor: loss_val
128
+ save_last: true
129
+ save_top_k: 1
130
+ verbose: false
131
+ - _target_: pytorch_lightning.callbacks.TQDMProgressBar
132
+ refresh_rate: 50
133
+ - _target_: mattergen.common.data.callback.SetPropertyScalers
134
+ - _target_: pytorch_lightning.callbacks.EarlyStopping
135
+ min_delta: 0.01
136
+ mode: min
137
+ monitor: loss_val
138
+ patience: 30
139
+ strict: true
140
+ verbose: true
141
+ check_val_every_n_epoch: 5
142
+ devices: 4
143
+ gradient_clip_algorithm: value
144
+ gradient_clip_val: 0.5
145
+ max_epochs: 2200
146
+ num_nodes: 1
147
+ precision: 32
148
+ strategy:
149
+ _target_: pytorch_lightning.strategies.ddp.DDPStrategy
150
+ find_unused_parameters: true