lucascamillomd commited on
Commit
10fc2ea
·
verified ·
1 Parent(s): ba0aa00

Upload config/proteins.yaml with huggingface_hub

Browse files
Files changed (1) hide show
  1. config/proteins.yaml +149 -0
config/proteins.yaml ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task_name: train
2
+ tags:
3
+ - finetune_boa_small_proteins
4
+ - finetuning
5
+ - small
6
+ - boa
7
+ - proteins
8
+ train: true
9
+ test: true
10
+ trainer_ckpt_path: null
11
+ model_ckpt_path: ${paths.dependencies_dir}/model/weights/small.ckpt
12
+ strict_load: false
13
+ seed: 42
14
+ data:
15
+ batch_size: 4
16
+ dna_llm: nucleotide-transformer-v2-500m-multi-species
17
+ max_length: 10000
18
+ sorting_strategy: sorted_chromosome
19
+ dna_context_len: 2001
20
+ num_workers: 8
21
+ pin_memory: false
22
+ _target_: cpgpt.data.cpgpt_datamodule.CpGPTDataModule
23
+ train_dir: ${paths.data_dir}/boa/processed/train
24
+ val_dir: ${paths.data_dir}/boa/processed/test
25
+ test_dir: ${paths.data_dir}/boa/processed/test
26
+ dependencies_dir: ${paths.dependencies_dir}/human
27
+ model:
28
+ optimizer:
29
+ _target_: schedulefree.AdamWScheduleFree
30
+ _partial_: true
31
+ lr: 0.0001
32
+ weight_decay: 0.1
33
+ betas:
34
+ - 0.9
35
+ - 0.95
36
+ warmup_steps: ${trainer.min_steps}
37
+ scheduler:
38
+ _target_: torch.optim.lr_scheduler.ConstantLR
39
+ _partial_: true
40
+ factor: 1.0
41
+ total_iters: 1
42
+ net:
43
+ _target_: cpgpt.model.components.model.CpGPT
44
+ d_embedding: 128
45
+ d_hidden: 128
46
+ d_dna_embedding: 1024
47
+ n_attention_heads: 8
48
+ n_layers: 8
49
+ n_mlp_blocks: 3
50
+ dropout: 0.01
51
+ architecture: transformer
52
+ activation: swiglu
53
+ positional_encoding: rotary
54
+ sample_embedding_method: cls
55
+ use_power_norm: false
56
+ fft: false
57
+ use_condition_decoder: true
58
+ condition_size: 322
59
+ use_noise_decoder: false
60
+ mlp_block_bias: false
61
+ mlp_block_norm_type: rmsnorm
62
+ mlp_block_pre_norm: false
63
+ mlp_block_post_norm: false
64
+ transformer_block_bias: false
65
+ transformer_block_norm_type: rmsnorm
66
+ transformer_block_norm_first: true
67
+ transformer_block_dropout: 0.0
68
+ training:
69
+ generative_splits: 5
70
+ binarize_input: false
71
+ contrastive_threshold: 0.5
72
+ diffusion: false
73
+ reconstruct_mode: all
74
+ diffusion_params:
75
+ num_timesteps: 1000
76
+ loss_weights:
77
+ m_mae: 10.0
78
+ m_mae_unc: 1.0
79
+ betas_mae: 0.0
80
+ betas_kld: 0.0
81
+ betas_beta: 0.0
82
+ betas_wd: 1.0
83
+ contrastive: 1.0
84
+ sample_kld: 1.0
85
+ diffusion_mse: 0.0
86
+ condition_loss: 5.0
87
+ condition_decoder_loss: mae
88
+ _target_: cpgpt.model.cpgpt_module.CpGPTLitModule
89
+ compile: true
90
+ callbacks:
91
+ model_checkpoint:
92
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
93
+ dirpath: ${paths.output_dir}/checkpoints
94
+ filename: step_{step:06d}
95
+ monitor: val/condition_loss
96
+ verbose: false
97
+ save_last: true
98
+ save_top_k: 1
99
+ mode: min
100
+ auto_insert_metric_name: false
101
+ save_weights_only: false
102
+ every_n_train_steps: null
103
+ train_time_interval: null
104
+ every_n_epochs: null
105
+ save_on_train_epoch_end: null
106
+ model_summary:
107
+ _target_: lightning.pytorch.callbacks.RichModelSummary
108
+ max_depth: 3
109
+ rich_progress_bar:
110
+ _target_: lightning.pytorch.callbacks.RichProgressBar
111
+ logger:
112
+ wandb:
113
+ _target_: lightning.pytorch.loggers.wandb.WandbLogger
114
+ save_dir: ${paths.output_dir}
115
+ offline: false
116
+ id: null
117
+ anonymous: null
118
+ project: CpGPT
119
+ log_model: true
120
+ prefix: ''
121
+ entity: lucascamillo
122
+ group: ''
123
+ tags: ${tags}
124
+ job_type: ''
125
+ trainer:
126
+ _target_: lightning.pytorch.trainer.Trainer
127
+ default_root_dir: ${paths.output_dir}
128
+ min_steps: 600
129
+ max_steps: 6000
130
+ accelerator: auto
131
+ devices: 1
132
+ precision: 16-mixed
133
+ val_check_interval: 100
134
+ check_val_every_n_epoch: null
135
+ log_every_n_steps: 1
136
+ detect_anomaly: false
137
+ deterministic: false
138
+ accumulate_grad_batches: 1
139
+ paths:
140
+ root_dir: ${oc.env:PROJECT_ROOT}
141
+ data_dir: ${paths.root_dir}/data/
142
+ dependencies_dir: ${paths.root_dir}/dependencies/
143
+ log_dir: ${paths.root_dir}/logs/
144
+ output_dir: ${hydra:runtime.output_dir}
145
+ work_dir: ${hydra:runtime.cwd}
146
+ extras:
147
+ ignore_warnings: true
148
+ enforce_tags: true
149
+ print_config: true