# File size: 4,993 Bytes
# Data pipeline. Hydra instantiates CrystDataModule from `_target_`; the
# train/val/test datasets below reuse this node's properties, transforms and
# dataset_transforms through ${data_module.*} interpolations.
data_module:
  _target_: mattergen.common.data.datamodule.CrystDataModule
  _recursive_: true
  properties: []
  # Transforms are declared with `_partial_: true`, so Hydra builds
  # functools.partial objects instead of calling the targets immediately.
  transforms:
    - _target_: mattergen.common.data.transform.symmetrize_lattice
      _partial_: true
    - _target_: mattergen.common.data.transform.set_chemical_system_string
      _partial_: true
    - _target_: mattergen.common.data.transform.set_composition_count
      _partial_: true
  dataset_transforms:
    - _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
      _partial_: true
  # Also read by lightning_module's lattice SDE (cell.vpsde_config.limit_density).
  average_density: 0.05771451654022283
  # Built from the PROJECT_ROOT environment variable via the oc.env resolver.
  root_dir: ${oc.env:PROJECT_ROOT}/../datasets/cache/broader_csp
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: ${data_module.root_dir}/train
    properties: ${data_module.properties}
    transforms: ${data_module.transforms}
    dataset_transforms: ${data_module.dataset_transforms}
  val_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: ${data_module.root_dir}/val
    properties: ${data_module.properties}
    transforms: ${data_module.transforms}
    dataset_transforms: ${data_module.dataset_transforms}
  test_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: ${data_module.root_dir}/test
    properties: ${data_module.properties}
    transforms: ${data_module.transforms}
    dataset_transforms: ${data_module.dataset_transforms}
  num_workers:
    train: 0
    val: 0
    test: 0
  batch_size:
    # A global batch of 512 split across gradient-accumulation steps and all
    # devices (trainer.devices * trainer.num_nodes). With the trainer values
    # in this file: 512 // 1 // (8 * 9) = 7. The `>-` folded scalar joins the
    # two lines below with a single space, giving a one-line eval string.
    train: >-
      ${eval:'(512 // ${trainer.accumulate_grad_batches})
      // (${trainer.devices} * ${trainer.num_nodes})'}
    val: 8
    test: 8
  # Forwarded to trainer.max_epochs via interpolation.
  max_epochs: 900
# PyTorch Lightning Trainer: DDP across 8 GPUs per node x 9 nodes.
# data_module.batch_size.train divides its 512 global batch by this product.
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  devices: 8
  num_nodes: 9
  precision: 32
  # Interpolated from data_module.max_epochs. max_steps is also set below;
  # NOTE(review): Lightning stops at whichever limit is reached first —
  # confirm both limits are intended.
  max_epochs: ${data_module.max_epochs}
  accumulate_grad_batches: 1
  # Element-wise clipping of gradient values to +/-0.5 (not norm clipping).
  gradient_clip_val: 0.5
  gradient_clip_algorithm: value
  check_val_every_n_epoch: 5
  strategy:
    _target_: pytorch_lightning.strategies.ddp.DDPStrategy
    # NOTE(review): presumably needed because not every model output feeds
    # the loss (include_atomic_numbers is false) — confirm before disabling.
    find_unused_parameters: true
  callbacks:
    - _target_: pytorch_lightning.callbacks.LearningRateMonitor
      logging_interval: step
      log_momentum: false
    # Keeps the single best checkpoint (lowest loss_val) plus a `last` one.
    # NOTE(review): validation runs only every 5 epochs
    # (check_val_every_n_epoch), so loss_val is refreshed at that cadence.
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      monitor: loss_val
      mode: min
      save_top_k: 1
      save_last: true
      verbose: false
      every_n_epochs: 1
      filename: '{epoch}-{loss_val:.2f}'
    - _target_: pytorch_lightning.callbacks.TQDMProgressBar
      refresh_rate: 50
    - _target_: mattergen.common.data.callback.SetPropertyScalers
  max_steps: 200000
# Lightning wrapper holding the optimizer, LR schedule and diffusion module.
lightning_module:
  _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
  # Declared as a partial (`_partial_: true`), not constructed here.
  optimizer_partial:
    lr: 0.0001
    _target_: torch.optim.Adam
    _partial_: true
  # ReduceLROnPlateau on loss_train, stepped once per epoch: multiply LR by
  # 0.6 after 100 steps without improvement, never below 1e-6.
  scheduler_partials:
    - scheduler:
        _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
        factor: 0.6
        patience: 100
        min_lr: 1.0e-06
        _partial_: true
      interval: epoch
      frequency: 1
      monitor: loss_train
      strict: true
  diffusion_module:
    _target_: mattergen.diffusion.diffusion_module.DiffusionModule
    model:
      _target_: mattergen.denoiser.GemNetTDenoiser
      hidden_dim: 512
      gemnet:
        _target_: mattergen.common.gemnet.gemnet.GemNetT
        num_targets: 1
        # Relative interpolations: `..` from keys directly under gemnet and
        # `...` from keys under atom_embedding both resolve in `model`.
        # latent_dim = hidden_dim * (1 + number of property embeddings),
        # i.e. 512 here because property_embeddings is empty.
        latent_dim: ${eval:'${..hidden_dim} * (1 + len(${..property_embeddings}))'}
        atom_embedding:
          _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
          emb_size: ${...hidden_dim}
          # True only when atom types are denoised with "mask" diffusion
          # (true with the model settings below). `>-` folds the two lines
          # into one string joined by a single space.
          with_mask_type: >-
            ${eval:'${...denoise_atom_types} and
            "${...atom_type_diffusion}" == "mask"'}
        emb_size_atom: ${..hidden_dim}
        emb_size_edge: ${..hidden_dim}
        max_neighbors: 50
        max_cell_images_per_dim: 5
        cutoff: 7.0
        num_blocks: 4
        regress_stress: true
        otf_graph: true
        scale_file: ${oc.env:PROJECT_ROOT}/common/gemnet/gemnet-dT.json
      # Model-level keys; gemnet reads them through `..`/`...` interpolations,
      # so they must stay siblings of `gemnet`.
      denoise_atom_types: true
      atom_type_diffusion: mask
      property_embeddings_adapt: {}
      property_embeddings: {}
    # Only positions and the lattice cell are corrupted; consistently, the
    # loss below excludes atomic numbers.
    corruption:
      _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
      sdes:
        pos:
          _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
          wrapping_boundary: 1.0
          sigma_max: 5.0
          limit_info_key: num_atoms
        cell:
          _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
          vpsde_config:
            beta_min: 0.1
            beta_max: 20
            # Shared with data_module.average_density.
            limit_density: ${data_module.average_density}
            limit_var_scaling_constant: 0.25
    loss_fn:
      _target_: mattergen.common.loss.MaterialsLoss
      reduce: sum
      include_pos: true
      include_cell: true
      include_atomic_numbers: false
      weights:
        cell: 1.0
        pos: 0.1
    pre_corruption_fn:
      _target_: mattergen.property_embeddings.SetEmbeddingType
      p_unconditional: 0.2
      dropout_fields_iid: false
# NOTE(review): not referenced by any interpolation in this config; presumably
# read by the training entry point to resume from an existing checkpoint
# (cf. save_last: true on ModelCheckpoint) — confirm.
auto_resume: true
|