amyxlu committed
Commit 54209a4 · verified · 1 Parent(s): 5be476e

Upload 2 files

Files changed (2):
  1. PLAID-100M/config.yaml +101 -0
  2. PLAID-100M/last.ckpt +3 -0
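Both uploaded files can be pulled programmatically instead of cloning the whole repository. A minimal sketch using huggingface_hub, with the repo id left as a placeholder since it is not shown in this commit view:

from huggingface_hub import hf_hub_download

# Placeholder repo id -- substitute the Hub repository this commit belongs to.
repo_id = "<namespace>/<repo-name>"

config_path = hf_hub_download(repo_id=repo_id, filename="PLAID-100M/config.yaml")
ckpt_path = hf_hub_download(repo_id=repo_id, filename="PLAID-100M/last.ckpt")  # ~1.6 GB, resolved through Git LFS

print(config_path)
print(ckpt_path)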
PLAID-100M/config.yaml ADDED
@@ -0,0 +1,101 @@
+ resume_from_model_id: null
+ compression_model_id: j1v1wv6w
+ use_old_ema_module: false
+ paths:
+   project_dir: /homefs/home/lux70/code/plaid
+   bucket_dir: /data/lux70/plaid
+   data_dir: /data/lux70/data
+   home_dir: /homefs/home/lux70
+   log_dir: ${paths.bucket_dir}/logs
+   checkpoint_dir: ${paths.bucket_dir}/checkpoints/plaid-compositional
+   artifacts_dir: ${paths.bucket_dir}/artifacts
+ entity: lu-amy-al1
+ trainer:
+   _target_: lightning.pytorch.Trainer
+   accelerator: gpu
+   strategy: ddp_find_unused_parameters_true
+   devices: -1
+   num_nodes: 1
+   precision: '32'
+   gradient_clip_val: 1.0
+   log_every_n_steps: 50
+   num_sanity_val_steps: 0
+   gradient_clip_algorithm: norm
+   max_epochs: 20000
+   default_root_dir: ${paths.log_dir}
+ datamodule:
+   _target_: plaid.datasets.FunctionOrganismDataModule
+   train_shards: ${paths.data_dir}/pfam/compressed/j1v1wv6w/train/shard{0000..4423}.tar
+   val_shards: ${paths.data_dir}/pfam/compressed/j1v1wv6w/val/shard{0000..0863}.tar
+   config_file: ${paths.data_dir}/pfam/compressed/j1v1wv6w/config.json
+   go_metadata_fpath: ${paths.data_dir}/pfam/pfam2go.csv
+   organism_metadata_fpath: ${paths.data_dir}/pfam/organism_counts.csv
+   cache_dir: ${paths.home_dir}/.cache/plaid_data/j1v1wv6w
+   train_epoch_num_batches: 1000000
+   val_epoch_num_batches: 1000
+   shuffle_buffer: 20000
+   shuffle_initial: 20000
+   max_length: 256
+   batch_size: 256
+   num_workers: 8
+   prefetch_factor: 4
+ denoiser:
+   _target_: plaid.denoisers.FunctionOrganismUDiT
+   hidden_size: 768
+   max_seq_len: 512
+   depth: 12
+   num_heads: 12
+   mlp_ratio: 4.0
+   use_self_conditioning: true
+   timestep_embedding_strategy: fourier
+   use_skip_connect: false
+   attention_mode: xformers_memory_efficient
+ diffusion:
+   _target_: plaid.diffusion.FunctionOrganismDiffusion
+   beta_scheduler_name: sigmoid
+   beta_scheduler_start: -3
+   beta_scheduler_end: 3
+   beta_scheduler_tau: 1
+   x_downscale_factor: 1.0
+   timesteps: 1000
+   objective: pred_v
+   min_snr_loss_weight: true
+   min_snr_gamma: 5
+   x_clip_val: 1.0
+   function_y_cond_drop_prob: 0.1
+   organism_y_cond_drop_prob: 0.1
+   ema_decay: 0.9999
+   lr: 0.0001
+   lr_adam_betas:
+   - 0.9
+   - 0.999
+   lr_sched_type: cosine_with_restarts
+   lr_num_warmup_steps: 10000
+   lr_num_training_steps: 1000000
+   lr_num_cycles: 1
+ callbacks:
+   checkpoint:
+     _target_: plaid.callbacks.EMAModelCheckpoint
+     save_last: link
+     filename: epoch{epoch}-step{step}
+     verbose: true
+     every_n_train_steps: 10000
+     monitor: step
+     save_top_k: 1
+     mode: max
+     auto_insert_metric_name: false
+     dirpath: ${paths.checkpoint_dir}
+   ema:
+     _target_: plaid.callbacks.EMA
+     decay: 0.9999
+     apply_ema_every_n_steps: 1
+     start_step: 0
+     save_ema_weights_in_callback_state: false
+     evaluate_ema_weights_instead: true
+ logger:
+   _target_: lightning.pytorch.loggers.WandbLogger
+   project: plaid-compositional-conditioning
+   entity: prescient-design
+   name: UDiT_B
+   tags: null
+   group: null
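The _target_ entries above follow the Hydra instantiation convention, so the configured components can be rebuilt directly from this file. A minimal sketch, assuming the plaid package and the classes it references are importable; how the denoiser is wired into the diffusion wrapper, and the callbacks/logger into the trainer, is handled by the package's own training entry point, which is not part of this commit:

from omegaconf import OmegaConf
from hydra.utils import instantiate

cfg = OmegaConf.load("PLAID-100M/config.yaml")

# Each section with a _target_ key names a class; instantiate() imports it and
# calls it with the remaining keys as keyword arguments (interpolations such as
# ${paths.data_dir} resolve against the same config).
datamodule = instantiate(cfg.datamodule)  # plaid.datasets.FunctionOrganismDataModule
denoiser = instantiate(cfg.denoiser)      # plaid.denoisers.FunctionOrganismUDiT
trainer = instantiate(cfg.trainer)        # lightning.pytorch.Trainer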
PLAID-100M/last.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5c0eac0cbf1e653c913578d45f878ca8bbf28e417db19d53077de3233b5a4ab7
+ size 1608486946
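The three lines above are only a Git LFS pointer (spec version, SHA-256 oid, and byte size); the actual ~1.6 GB Lightning checkpoint is fetched when the repository is cloned with LFS enabled or downloaded through the Hub. A minimal sketch for inspecting it once downloaded, assuming the standard Lightning checkpoint layout:

import torch

# weights_only=False because Lightning checkpoints store training state beyond raw tensors.
ckpt = torch.load("PLAID-100M/last.ckpt", map_location="cpu", weights_only=False)

# Typical top-level keys: 'state_dict', 'optimizer_states', 'epoch', 'global_step', ...
print(sorted(ckpt.keys()))
print(f"{len(ckpt['state_dict'])} parameter tensors in state_dict")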