lucascamillomd committed on
Commit
af6512a
·
verified ·
1 Parent(s): 11413ed

Upload config/age.yaml with huggingface_hub

Browse files
Files changed (1) hide show
  1. config/age.yaml +142 -0
config/age.yaml ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task_name: train
2
+ tags:
3
+ - finetune_altumage_small_age
4
+ - finetuning
5
+ - small
6
+ - altumage
7
+ - age
8
+ train: true
9
+ test: true
10
+ ckpt_path: null
11
+ model_ckpt_path: dependencies/models/pretrained/small.ckpt
12
+ strict_load: false
13
+ seed: 42
14
+ data:
15
+ batch_size: 16
16
+ dna_llm: nucleotide-transformer-v2-500m-multi-species
17
+ max_length: 10000
18
+ sorting_strategy: sorted_chromosome
19
+ dna_context_len: 2001
20
+ num_workers: 8
21
+ pin_memory: false
22
+ _target_: cpgpt.data.cpgpt_datamodule.CpGPTDataModule
23
+ train_dir: ${paths.data_dir}/altumage/processed/train
24
+ val_dir: ${paths.data_dir}/altumage/processed/val
25
+ test_dir: ${paths.data_dir}/altumage/processed/test
26
+ dependencies_dir: ${paths.dependencies_dir}/human
27
+ model:
28
+ optimizer:
29
+ _target_: schedulefree.AdamWScheduleFree
30
+ _partial_: true
31
+ lr: 0.0001
32
+ weight_decay: 0.01
33
+ betas:
34
+ - 0.9
35
+ - 0.95
36
+ net:
37
+ _target_: cpgpt.model.components.model.CpGPT
38
+ d_embedding: 128
39
+ d_hidden: 128
40
+ d_dna_embedding: 1024
41
+ n_attention_heads: 8
42
+ n_layers: 8
43
+ n_mlp_blocks: 3
44
+ dropout: 0.01
45
+ architecture: transformer
46
+ activation: swiglu
47
+ positional_encoding: rotary
48
+ sample_embedding_method: cls
49
+ use_power_norm: false
50
+ fft: false
51
+ use_condition_decoder: true
52
+ condition_size: 1
53
+ use_noise_decoder: false
54
+ mlp_block_bias: false
55
+ mlp_block_norm_type: rmsnorm
56
+ mlp_block_pre_norm: false
57
+ mlp_block_post_norm: false
58
+ transformer_block_bias: false
59
+ transformer_block_norm_type: rmsnorm
60
+ transformer_block_norm_first: true
61
+ transformer_block_dropout: 0.0
62
+ training:
63
+ generative_splits: 2
64
+ binarize_input: false
65
+ contrastive_threshold: 0.5
66
+ diffusion: false
67
+ reconstruct_mode: all
68
+ diffusion_params:
69
+ num_timesteps: 1000
70
+ loss_weights:
71
+ m_mae: 10.0
72
+ m_mae_unc: 1.0
73
+ betas_mae: 0.0
74
+ betas_kld: 0.0
75
+ betas_beta: 0.0
76
+ betas_wd: 1.0
77
+ contrastive: 1.0
78
+ sample_kld: 1.0
79
+ diffusion_mse: 0.0
80
+ condition_loss: 0.1
81
+ condition_decoder_loss: mae
82
+ _target_: cpgpt.model.cpgpt_module.CpGPTLitModule
83
+ compile: true
84
+ callbacks:
85
+ model_checkpoint:
86
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
87
+ dirpath: ${paths.output_dir}/checkpoints
88
+ filename: ${tags[0]}
89
+ monitor: val/condition_loss
90
+ verbose: false
91
+ save_last: true
92
+ save_top_k: 1
93
+ mode: min
94
+ auto_insert_metric_name: false
95
+ save_weights_only: false
96
+ every_n_train_steps: null
97
+ train_time_interval: null
98
+ every_n_epochs: null
99
+ save_on_train_epoch_end: null
100
+ model_summary:
101
+ _target_: lightning.pytorch.callbacks.RichModelSummary
102
+ max_depth: -1
103
+ rich_progress_bar:
104
+ _target_: lightning.pytorch.callbacks.RichProgressBar
105
+ logger:
106
+ wandb:
107
+ _target_: lightning.pytorch.loggers.wandb.WandbLogger
108
+ save_dir: ${paths.output_dir}
109
+ offline: false
110
+ id: null
111
+ anonymous: null
112
+ project: CpGPT
113
+ log_model: true
114
+ prefix: ''
115
+ entity: lucascamillo
116
+ group: ''
117
+ tags: ${tags}
118
+ job_type: ''
119
+ trainer:
120
+ _target_: lightning.pytorch.trainer.Trainer
121
+ default_root_dir: ${paths.output_dir}
122
+ min_steps: 1000
123
+ max_steps: 100000
124
+ accelerator: auto
125
+ devices: 1
126
+ precision: 16-mixed
127
+ val_check_interval: 1000
128
+ log_every_n_steps: 1
129
+ detect_anomaly: false
130
+ deterministic: false
131
+ accumulate_grad_batches: 1
132
+ paths:
133
+ root_dir: ${oc.env:PROJECT_ROOT}
134
+ data_dir: ${paths.root_dir}/data/
135
+ dependencies_dir: ${paths.root_dir}/dependencies/
136
+ log_dir: ${paths.root_dir}/logs/
137
+ output_dir: ${hydra:runtime.output_dir}
138
+ work_dir: ${hydra:runtime.cwd}
139
+ extras:
140
+ ignore_warnings: true
141
+ enforce_tags: true
142
+ print_config: true