# Hugging Face Hub page header (not part of the configuration):
# last commit f20d7fc by sathyae, "Upload folder using huggingface_hub"
# (verified); file size 4.99 kB.
# Data module settings. The `_target_` / `_partial_` / `_recursive_` keys
# follow Hydra's instantiate convention; `${...}` values are OmegaConf
# interpolations resolved against the root of this config.
data_module:
  _target_: mattergen.common.data.datamodule.CrystDataModule
  _recursive_: true
  # Empty list; each dataset split below reuses it via
  # ${data_module.properties}.
  properties: []
  # Shared by all three dataset splits via ${data_module.transforms}.
  transforms:
    - _target_: mattergen.common.data.transform.symmetrize_lattice
      _partial_: true
    - _target_: mattergen.common.data.transform.set_chemical_system_string
      _partial_: true
    - _target_: mattergen.common.data.transform.set_composition_count
      _partial_: true
  # Shared by all three dataset splits via ${data_module.dataset_transforms}.
  dataset_transforms:
    - _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
      _partial_: true
  # Read by the cell corruption as `limit_density`
  # (${data_module.average_density}).
  average_density: 0.05771451654022283
  # `oc.env` reads the PROJECT_ROOT environment variable when this value is
  # resolved.
  root_dir: ${oc.env:PROJECT_ROOT}/../datasets/cache/broader_csp
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: ${data_module.root_dir}/train
    properties: ${data_module.properties}
    transforms: ${data_module.transforms}
    dataset_transforms: ${data_module.dataset_transforms}
  val_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: ${data_module.root_dir}/val
    properties: ${data_module.properties}
    transforms: ${data_module.transforms}
    dataset_transforms: ${data_module.dataset_transforms}
  test_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: ${data_module.root_dir}/test
    properties: ${data_module.properties}
    transforms: ${data_module.transforms}
    dataset_transforms: ${data_module.dataset_transforms}
  num_workers:
    train: 0
    val: 0
    test: 0
  batch_size:
    # `eval` is not a built-in OmegaConf resolver; presumably the project
    # registers it to evaluate the quoted expression as Python -- confirm.
    # Under that assumption, with the trainer values in this file
    # (accumulate_grad_batches=1, devices=8, num_nodes=9) this yields
    # 512 // 1 // 72 = 7.
    train: ${eval:'(512 // ${trainer.accumulate_grad_batches}) // (${trainer.devices}
      * ${trainer.num_nodes})'}
    val: 8
    test: 8
  # Consumed by the trainer as ${data_module.max_epochs}.
  max_epochs: 900
# Keyword arguments for pytorch_lightning.Trainer (Hydra `_target_`
# instantiation).
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  # devices and num_nodes also feed the train batch-size expression in
  # data_module.batch_size.train.
  devices: 8
  num_nodes: 9
  precision: 32
  max_epochs: ${data_module.max_epochs}
  accumulate_grad_batches: 1
  gradient_clip_val: 0.5
  gradient_clip_algorithm: value
  check_val_every_n_epoch: 5
  strategy:
    _target_: pytorch_lightning.strategies.ddp.DDPStrategy
    find_unused_parameters: true
  callbacks:
    - _target_: pytorch_lightning.callbacks.LearningRateMonitor
      logging_interval: step
      log_momentum: false
    # Keeps the single best checkpoint by `loss_val` plus the last one.
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      monitor: loss_val
      mode: min
      save_top_k: 1
      save_last: true
      verbose: false
      every_n_epochs: 1
      filename: '{epoch}-{loss_val:.2f}'
    - _target_: pytorch_lightning.callbacks.TQDMProgressBar
      refresh_rate: 50
    - _target_: mattergen.common.data.callback.SetPropertyScalers
  # NOTE(review): the original nesting of this key was lost. It is placed
  # here as a Trainer argument (Trainer accepts `max_steps`) rather than as
  # an argument of the SetPropertyScalers callback above -- confirm.
  max_steps: 200000
# Lightning module wiring: optimizer, LR schedule and the diffusion module
# (denoiser model, corruption processes, loss). Nesting matters: the relative
# interpolations (`${..key}`, `${...key}`) below only resolve with this layout.
lightning_module:
  _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
  optimizer_partial:
    lr: 0.0001
    _target_: torch.optim.Adam
    _partial_: true
  scheduler_partials:
    - scheduler:
        _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
        factor: 0.6
        patience: 100
        min_lr: 1.0e-06
        _partial_: true
      # Siblings of `scheduler`, not ReduceLROnPlateau arguments; presumably
      # the Lightning lr-scheduler config fields -- confirm against the
      # consumer.
      interval: epoch
      frequency: 1
      monitor: loss_train
      strict: true
  diffusion_module:
    _target_: mattergen.diffusion.diffusion_module.DiffusionModule
    model:
      _target_: mattergen.denoiser.GemNetTDenoiser
      hidden_dim: 512
      gemnet:
        _target_: mattergen.common.gemnet.gemnet.GemNetT
        num_targets: 1
        # `..` refers to the enclosing `model` mapping. With the empty
        # `property_embeddings` below this is 512 * (1 + 0), assuming the
        # project's `eval` resolver evaluates the expression as Python.
        latent_dim: ${eval:'${..hidden_dim} * (1 + len(${..property_embeddings}))'}
        atom_embedding:
          _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
          # `...` climbs two levels (atom_embedding -> gemnet -> model).
          emb_size: ${...hidden_dim}
          with_mask_type: ${eval:'${...denoise_atom_types} and "${...atom_type_diffusion}"
            == "mask"'}
        emb_size_atom: ${..hidden_dim}
        emb_size_edge: ${..hidden_dim}
        max_neighbors: 50
        max_cell_images_per_dim: 5
        cutoff: 7.0
        num_blocks: 4
        regress_stress: true
        otf_graph: true
        scale_file: ${oc.env:PROJECT_ROOT}/common/gemnet/gemnet-dT.json
      # NOTE(review): atom-type denoising is enabled here, yet `corruption`
      # below defines only `pos` and `cell` processes and `loss_fn` sets
      # include_atomic_numbers: false -- confirm this combination is intended.
      denoise_atom_types: true
      atom_type_diffusion: mask
      property_embeddings_adapt: {}
      property_embeddings: {}
    corruption:
      _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
      sdes:
        pos:
          _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
          wrapping_boundary: 1.0
          sigma_max: 5.0
          limit_info_key: num_atoms
        cell:
          _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
          # NOTE(review): limit_density and limit_var_scaling_constant are
          # placed inside vpsde_config; the original nesting was lost --
          # confirm against from_vpsde_config.
          vpsde_config:
            beta_min: 0.1
            beta_max: 20
            limit_density: ${data_module.average_density}
            limit_var_scaling_constant: 0.25
    loss_fn:
      _target_: mattergen.common.loss.MaterialsLoss
      reduce: sum
      include_pos: true
      include_cell: true
      include_atomic_numbers: false
      # One weight per included field.
      weights:
        cell: 1.0
        pos: 0.1
    pre_corruption_fn:
      _target_: mattergen.property_embeddings.SetEmbeddingType
      p_unconditional: 0.2
      dropout_fields_iid: false
# NOTE(review): presumably a root-level flag telling the training entry point
# to resume from an existing checkpoint -- confirm against the consumer.
auto_resume: true