lucascamillomd commited on
Commit
90e2d14
·
verified ·
1 Parent(s): b57c330

Upload config/hannum_cot.yaml with huggingface_hub

Browse files
Files changed (1) hide show
  1. config/hannum_cot.yaml +147 -0
config/hannum_cot.yaml ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task_name: train
2
+ tags:
3
+ - finetune_hannum_small
4
+ - finetuning
5
+ - small
6
+ - hannum
7
+ train: true
8
+ test: true
9
+ trainer_ckpt_path: null
10
+ model_ckpt_path: dependencies/model/weights/small.ckpt
11
+ strict_load: true
12
+ seed: 42
13
+ data:
14
+ batch_size: 6
15
+ dna_llm: nucleotide-transformer-v2-500m-multi-species
16
+ max_length: 10000
17
+ sorting_strategy: sorted_chromosome
18
+ dna_context_len: 2001
19
+ num_workers: 8
20
+ pin_memory: false
21
+ _target_: cpgpt.data.cpgpt_datamodule.CpGPTDataModule
22
+ train_dir: ${paths.data_dir}/hannum/processed/train
23
+ val_dir: ${paths.data_dir}/hannum/processed/val
24
+ test_dir: ${paths.data_dir}/hannum/processed/test
25
+ dependencies_dir: ${paths.dependencies_dir}/human
26
+ model:
27
+ optimizer:
28
+ _target_: schedulefree.AdamWScheduleFree
29
+ _partial_: true
30
+ lr: 0.0001
31
+ weight_decay: 0.01
32
+ betas:
33
+ - 0.9
34
+ - 0.95
35
+ warmup_steps: ${trainer.min_steps}
36
+ scheduler:
37
+ _target_: torch.optim.lr_scheduler.ConstantLR
38
+ _partial_: true
39
+ factor: 1.0
40
+ total_iters: 1
41
+ net:
42
+ _target_: cpgpt.model.components.model.CpGPT
43
+ d_embedding: 128
44
+ d_hidden: 128
45
+ d_dna_embedding: 1024
46
+ n_attention_heads: 8
47
+ n_layers: 8
48
+ n_mlp_blocks: 3
49
+ dropout: 0.01
50
+ architecture: transformer
51
+ activation: swiglu
52
+ positional_encoding: rotary
53
+ sample_embedding_method: cls
54
+ use_power_norm: false
55
+ fft: false
56
+ use_condition_decoder: false
57
+ condition_size: 0
58
+ use_noise_decoder: false
59
+ mlp_block_bias: false
60
+ mlp_block_norm_type: rmsnorm
61
+ mlp_block_pre_norm: false
62
+ mlp_block_post_norm: false
63
+ transformer_block_bias: false
64
+ transformer_block_norm_type: rmsnorm
65
+ transformer_block_norm_first: true
66
+ transformer_block_dropout: 0.0
67
+ training:
68
+ generative_splits: 10
69
+ binarize_input: true
70
+ contrastive_threshold: 0.5
71
+ diffusion: false
72
+ reconstruct_mode: all
73
+ diffusion_params:
74
+ num_timesteps: 1000
75
+ loss_weights:
76
+ m_mae: 10.0
77
+ m_mae_unc: 1.0
78
+ betas_mae: 0.0
79
+ betas_kld: 0.0
80
+ betas_beta: 0.0
81
+ betas_wd: 1.0
82
+ contrastive: 1.0
83
+ sample_kld: 1.0
84
+ diffusion_mse: 0.0
85
+ condition_loss: 0.0
86
+ _target_: cpgpt.model.cpgpt_module.CpGPTLitModule
87
+ compile: true
88
+ callbacks:
89
+ model_checkpoint:
90
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
91
+ dirpath: ${paths.output_dir}/checkpoints
92
+ filename: ${tags[0]}
93
+ monitor: val/loss
94
+ verbose: false
95
+ save_last: true
96
+ save_top_k: 1
97
+ mode: min
98
+ auto_insert_metric_name: false
99
+ save_weights_only: false
100
+ every_n_train_steps: null
101
+ train_time_interval: null
102
+ every_n_epochs: null
103
+ save_on_train_epoch_end: null
104
+ model_summary:
105
+ _target_: lightning.pytorch.callbacks.RichModelSummary
106
+ max_depth: -1
107
+ rich_progress_bar:
108
+ _target_: lightning.pytorch.callbacks.RichProgressBar
109
+ logger:
110
+ wandb:
111
+ _target_: lightning.pytorch.loggers.wandb.WandbLogger
112
+ save_dir: ${paths.output_dir}
113
+ offline: false
114
+ id: null
115
+ anonymous: null
116
+ project: CpGPT
117
+ log_model: true
118
+ prefix: ''
119
+ entity: lucascamillo
120
+ group: ''
121
+ tags: ${tags}
122
+ job_type: ''
123
+ trainer:
124
+ _target_: lightning.pytorch.trainer.Trainer
125
+ default_root_dir: ${paths.output_dir}
126
+ min_steps: 1000
127
+ max_steps: 50000
128
+ accelerator: auto
129
+ devices: 1
130
+ precision: 16-mixed
131
+ val_check_interval: 1000
132
+ check_val_every_n_epoch: null
133
+ log_every_n_steps: 1
134
+ detect_anomaly: false
135
+ deterministic: false
136
+ accumulate_grad_batches: 1
137
+ paths:
138
+ root_dir: ${oc.env:PROJECT_ROOT}
139
+ data_dir: ${paths.root_dir}/data/
140
+ dependencies_dir: ${paths.root_dir}/dependencies/
141
+ log_dir: ${paths.root_dir}/logs/
142
+ output_dir: ${hydra:runtime.output_dir}
143
+ work_dir: ${hydra:runtime.cwd}
144
+ extras:
145
+ ignore_warnings: true
146
+ enforce_tags: true
147
+ print_config: true