# File size: 4,842 Bytes
# 53ecc0b
# Run metadata: a version string and a list of tags, both produced by
# interpolation resolvers.
core:
  # NOTE(review): the `get_flowmm_version` resolver is not defined in this
  # file; presumably registered by the consuming code — confirm.
  version: ${get_flowmm_version:}
  tags:
  # Single tag: the `now` resolver called with a year-month-day format string.
  - ${now:%Y-%m-%d}
# Logging settings: validation cadence plus wandb, wandb_watch and lr_monitor
# option groups.
logging:
  # NOTE(review): the unit of this interval (epochs vs. steps) is not visible
  # in this file — confirm against the code that reads it.
  val_check_interval: 5
  wandb:
    # Project name is built from model.target_distribution and the
    # hydra:runtime.choices.data value.
    project: rfmcsp-${model.target_distribution}-${hydra:runtime.choices.data}
    entity: null
    log_model: true
    mode: online
    # Group name joins the hydra:runtime.choices values for model and
    # vectorfield with the output of `generate_id` (resolver not defined in
    # this file).
    group: ${hydra:runtime.choices.model}-${hydra:runtime.choices.vectorfield}-${generate_id:}
  wandb_watch:
    log: all
    log_freq: 500
  lr_monitor:
    logging_interval: step
    log_momentum: false
# Optimization settings: optimizer, learning-rate scheduler, an interval
# setting and an EMA decay value.
optim:
  # NOTE(review): `_target_` entries are presumably instantiated through
  # Hydra's instantiate mechanism — confirm in the consuming code.
  optimizer:
    _target_: torch.optim.AdamW
    lr: 0.0003
    # Weight decay is set to zero.
    weight_decay: 0.0
  lr_scheduler:
    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
    # Follows data.train_max_epochs (2000 in this file), the same value that
    # train.pl_trainer.max_epochs references.
    T_max: ${data.train_max_epochs}
    eta_min: 1.0e-05
  # NOTE(review): presumably the scheduler stepping interval — confirm.
  interval: epoch
  ema_decay: 0.999
# Training settings: determinism/seed, trainer options, the monitored metric
# and two checkpoint option groups.
train:
  deterministic: warn
  random_seed: 42
  # NOTE(review): these keys look like PyTorch Lightning Trainer keyword
  # arguments — confirm against the training entry point.
  pl_trainer:
    fast_dev_run: false
    devices: 1
    accelerator: gpu
    precision: 32
    # Mirrors data.train_max_epochs (2000 in this file).
    max_epochs: ${data.train_max_epochs}
    accumulate_grad_batches: 1
    num_sanity_val_steps: 1
    # Gradient clipping: threshold 0.5 with the `value` algorithm.
    gradient_clip_val: 0.5
    gradient_clip_algorithm: value
    profiler: simple
  # Metric name and direction (`min`) used for monitoring.
  monitor_metric: val/loss
  monitor_metric_mode: min
  # First checkpoint group: save_top_k of 1.
  model_checkpoints:
    save_top_k: 1
    verbose: false
    save_last: false
  # Second checkpoint group: every 100 epochs with save_top_k of -1.
  # NOTE(review): -1 presumably means "keep every checkpoint" — confirm
  # against the checkpoint callback that receives these options.
  every_n_epochs_checkpoint:
    every_n_epochs: 100
    save_top_k: -1
    verbose: false
    save_last: false
# Validation-time options; the only flag here, compute_nll, is off.
val:
  compute_nll: false
# Test-time options: compute_nll is off, compute_loss is on.
test:
  compute_nll: false
  compute_loss: true
# Integration settings: divergence mode, method and step count, plus
# log-likelihood normalization and two inference annealing parameters.
integrate:
  div_mode: rademacher
  method: euler
  num_steps: 1000
  normalize_loglik: true
  # Both annealing parameters are zero.
  # NOTE(review): presumably this disables inference annealing — confirm.
  inference_anneal_slope: 0.0
  inference_anneal_offset: 0.0
# Top-level boolean switch, off in this configuration.
# NOTE(review): presumably controls whether the base distribution is derived
# from the dataset — confirm against the code that reads it.
base_distribution_from_data: false
# Dataset settings for mp_20: shared preprocessing options, epoch limits, and
# the datamodule definition with per-split datasets, worker counts and batch
# sizes.
data:
  dataset_name: mp_20
  dim_coords: 3
  # Built from the PROJECT_ROOT environment variable via the `oc.env` resolver.
  root_path: ${oc.env:PROJECT_ROOT}/data/mp_20
  prop: formation_energy_per_atom
  num_targets: 1
  niggli: true
  primitive: false
  graph_method: crystalnn
  lattice_scale_method: scale_length
  preprocess_workers: 30
  readout: mean
  max_atoms: 20
  otf_graph: false
  eval_model_name: mp20
  tolerance: 0.1
  use_space_group: false
  use_pos_index: false
  # Referenced by optim.lr_scheduler.T_max and train.pl_trainer.max_epochs.
  train_max_epochs: 2000
  # Larger than train_max_epochs (100000 vs. 2000).
  # NOTE(review): presumably this makes early stopping never trigger — confirm.
  early_stopping_patience: 100000
  teacher_forcing_max_epoch: 500
  datamodule:
    _target_: diffcsp.pl_data.datamodule.CrystDataModule
    # Three splits. `train` is a single mapping; `val` and `test` are each a
    # one-element list. Every dataset entry pulls its options from the
    # data.* keys above via interpolation and differs only in name, csv path
    # and save path.
    datasets:
      train:
        _target_: diffcsp.pl_data.dataset.CrystDataset
        name: Formation energy train
        path: ${data.root_path}/train.csv
        save_path: ${data.root_path}/train_ori.pt
        prop: ${data.prop}
        niggli: ${data.niggli}
        primitive: ${data.primitive}
        graph_method: ${data.graph_method}
        tolerance: ${data.tolerance}
        use_space_group: ${data.use_space_group}
        use_pos_index: ${data.use_pos_index}
        lattice_scale_method: ${data.lattice_scale_method}
        preprocess_workers: ${data.preprocess_workers}
      val:
      - _target_: diffcsp.pl_data.dataset.CrystDataset
        name: Formation energy val
        path: ${data.root_path}/val.csv
        save_path: ${data.root_path}/val_ori.pt
        prop: ${data.prop}
        niggli: ${data.niggli}
        primitive: ${data.primitive}
        graph_method: ${data.graph_method}
        tolerance: ${data.tolerance}
        use_space_group: ${data.use_space_group}
        use_pos_index: ${data.use_pos_index}
        lattice_scale_method: ${data.lattice_scale_method}
        preprocess_workers: ${data.preprocess_workers}
      test:
      - _target_: diffcsp.pl_data.dataset.CrystDataset
        name: Formation energy test
        path: ${data.root_path}/test.csv
        save_path: ${data.root_path}/test_ori.pt
        prop: ${data.prop}
        niggli: ${data.niggli}
        primitive: ${data.primitive}
        graph_method: ${data.graph_method}
        tolerance: ${data.tolerance}
        use_space_group: ${data.use_space_group}
        use_pos_index: ${data.use_pos_index}
        lattice_scale_method: ${data.lattice_scale_method}
        preprocess_workers: ${data.preprocess_workers}
    # Worker count per split (40 for each).
    num_workers:
      train: 40
      val: 40
      test: 40
    # Batch size per split; each split uses a different value.
    batch_size:
      train: 256
      val: 1024
      test: 512
# Model settings: per-term cost values, target distribution, self-conditioning
# flag and the manifold choices.
model:
  # NOTE(review): the cost_* values are presumably loss-term weights —
  # confirm. cost_cross_ent is zero in this configuration.
  cost_coord: 400.0
  cost_lattice: 1.0
  cost_type: 40.0
  cost_cross_ent: 0.0
  affine_combine_costs: true
  # Also used in logging.wandb.project.
  target_distribution: unconditional
  # Also referenced by vectorfield.self_cond.
  self_cond: false
  manifold_getter:
    # Referenced by vectorfield.dim_atomic_rep.
    atom_type_manifold: analog_bits
    coord_manifold: flat_torus_01
    # Referenced by vectorfield.lattice_manifold.
    lattice_manifold: lattice_params
    analog_bits_scale: 1.0
    length_inner_coef: 1.0
# Vector-field network settings targeting flowmm.model.arch.CSPNet.
# NOTE(review): presumably instantiated through Hydra's `_target_` mechanism
# with the remaining keys as constructor arguments — confirm.
vectorfield:
  _target_: flowmm.model.arch.CSPNet
  hidden_dim: 512
  time_dim: 256
  num_layers: 6
  act_fn: silu
  dis_emb: sin
  num_freqs: 128
  edge_style: fc
  max_neighbors: 20
  cutoff: 7.0
  ln: true
  use_log_map: true
  # Computed by the `get_dim_atomic_rep` resolver (not defined in this file)
  # from model.manifold_getter.atom_type_manifold.
  dim_atomic_rep: ${get_dim_atomic_rep:${model.manifold_getter.atom_type_manifold}}
  # Mirrors model.manifold_getter.lattice_manifold.
  lattice_manifold: ${model.manifold_getter.lattice_manifold}
  concat_sum_pool: true
  represent_num_atoms: true
  represent_angle_edge_to_lattice: true
  self_edges: false
  # Mirrors model.self_cond.
  self_cond: ${model.self_cond}