lucascamillomd committed on
Commit
2009665
·
verified ·
1 Parent(s): 1f18ad9

Upload config/small.yaml with huggingface_hub

Browse files
Files changed (1) hide show
  1. config/small.yaml +140 -0
config/small.yaml ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task_name: train
2
+ tags:
3
+ - pretrain_small2
4
+ - pretraining
5
+ - small
6
+ - cpgcorpus
7
+ train: true
8
+ test: true
9
+ ckpt_path: null
10
+ model_ckpt_path: dependencies/models/pretrained/small.ckpt
11
+ strict_load: true
12
+ seed: 42
13
+ data:
14
+ batch_size: 8
15
+ dna_llm: nucleotide-transformer-v2-500m-multi-species
16
+ max_length: 20000
17
+ sorting_strategy: sorted_chromosome
18
+ dna_context_len: 2001
19
+ num_workers: 8
20
+ pin_memory: false
21
+ _target_: cpgpt.data.cpgpt_datamodule.CpGPTDataModule
22
+ train_dir: ${paths.data_dir}/cpgcorpus/processed/train
23
+ val_dir: ${paths.data_dir}/cpgcorpus/processed/val
24
+ test_dir: ${paths.data_dir}/cpgcorpus/processed/test
25
+ dependencies_dir: ${paths.dependencies_dir}/human
26
+ model:
27
+ optimizer:
28
+ _target_: schedulefree.AdamWScheduleFree
29
+ _partial_: true
30
+ lr: 0.0005
31
+ weight_decay: 0.01
32
+ betas:
33
+ - 0.9
34
+ - 0.95
35
+ net:
36
+ _target_: cpgpt.model.components.model.CpGPT
37
+ d_embedding: 128
38
+ d_hidden: 128
39
+ d_dna_embedding: 1024
40
+ n_attention_heads: 8
41
+ n_layers: 8
42
+ n_mlp_blocks: 3
43
+ dropout: 0.01
44
+ architecture: transformer
45
+ activation: swiglu
46
+ positional_encoding: rotary
47
+ sample_embedding_method: cls
48
+ use_power_norm: false
49
+ fft: false
50
+ use_condition_decoder: false
51
+ condition_size: 0
52
+ use_noise_decoder: false
53
+ mlp_block_bias: false
54
+ mlp_block_norm_type: rmsnorm
55
+ mlp_block_pre_norm: false
56
+ mlp_block_post_norm: false
57
+ transformer_block_bias: false
58
+ transformer_block_norm_type: rmsnorm
59
+ transformer_block_norm_first: true
60
+ transformer_block_dropout: 0.0
61
+ training:
62
+ generative_splits: 4
63
+ binarize_input: false
64
+ contrastive_threshold: 0.4
65
+ diffusion: false
66
+ reconstruct_mode: all
67
+ diffusion_params:
68
+ num_timesteps: 1000
69
+ loss_weights:
70
+ m_mae: 10.0
71
+ m_mae_unc: 1.0
72
+ betas_mae: 0.0
73
+ betas_kld: 0.0
74
+ betas_beta: 0.0
75
+ betas_wd: 1.0
76
+ contrastive: 1.0
77
+ sample_kld: 1.0
78
+ diffusion_mse: 0.0
79
+ condition_loss: 0.0
80
+ _target_: cpgpt.model.cpgpt_module.CpGPTLitModule
81
+ compile: true
82
+ callbacks:
83
+ model_checkpoint:
84
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
85
+ dirpath: ${paths.output_dir}/checkpoints
86
+ filename: ${tags[0]}
87
+ monitor: val/loss
88
+ verbose: false
89
+ save_last: true
90
+ save_top_k: 1
91
+ mode: min
92
+ auto_insert_metric_name: false
93
+ save_weights_only: false
94
+ every_n_train_steps: null
95
+ train_time_interval: null
96
+ every_n_epochs: null
97
+ save_on_train_epoch_end: null
98
+ model_summary:
99
+ _target_: lightning.pytorch.callbacks.RichModelSummary
100
+ max_depth: -1
101
+ rich_progress_bar:
102
+ _target_: lightning.pytorch.callbacks.RichProgressBar
103
+ logger:
104
+ wandb:
105
+ _target_: lightning.pytorch.loggers.wandb.WandbLogger
106
+ save_dir: ${paths.output_dir}
107
+ offline: false
108
+ id: null
109
+ anonymous: null
110
+ project: CpGPT
111
+ log_model: true
112
+ prefix: ''
113
+ entity: lucascamillo
114
+ group: ''
115
+ tags: ${tags}
116
+ job_type: ''
117
+ trainer:
118
+ _target_: lightning.pytorch.trainer.Trainer
119
+ default_root_dir: ${paths.output_dir}
120
+ min_steps: 1000
121
+ max_steps: 100000
122
+ accelerator: auto
123
+ devices: 1
124
+ precision: 16-mixed
125
+ val_check_interval: 1000
126
+ log_every_n_steps: 1
127
+ detect_anomaly: false
128
+ deterministic: false
129
+ accumulate_grad_batches: 1
130
+ paths:
131
+ root_dir: ${oc.env:PROJECT_ROOT}
132
+ data_dir: ${paths.root_dir}/data/
133
+ dependencies_dir: ${paths.root_dir}/dependencies/
134
+ log_dir: ${paths.root_dir}/logs/
135
+ output_dir: ${hydra:runtime.output_dir}
136
+ work_dir: ${hydra:runtime.cwd}
137
+ extras:
138
+ ignore_warnings: true
139
+ enforce_tags: true
140
+ print_config: true