Safetensors
jperera-czbio committed on
Commit
45486d7
·
1 Parent(s): a2218ad

Upload inference checkpoint and config

Browse files
Files changed (1) hide show
  1. config.yaml +441 -0
config.yaml ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ paths:
2
+ data_path: /mnt/data/
3
+ experiment_path: /mnt/payam/results
4
+ model:
5
+ module:
6
+ _target_: scg_vae.models.LatentDiffusion
7
+ vae_model:
8
+ _target_: scg_vae.vae.TransformerVAE
9
+ encoder:
10
+ _target_: scg_vae.nnets.Encoder
11
+ n_layer: 8
12
+ n_inducing_points: 256
13
+ n_embed: 256
14
+ n_embed_latent: 16
15
+ n_head: 4
16
+ n_head_cross: 4
17
+ dropout: 0.0
18
+ bias: false
19
+ multiple_of: 4
20
+ layernorm_eps: 1.0e-08
21
+ norm_layer: layernorm
22
+ positional_encoding: false
23
+ latent_projection_type: mlp
24
+ decoder:
25
+ _target_: scg_vae.nnets.Decoder
26
+ n_genes: ${datamodule.label_encoder.n_genes}
27
+ n_embed: ${model.module.vae_model.encoder.n_embed}
28
+ n_embed_latent: ${model.module.vae_model.encoder.n_embed_latent}
29
+ n_head: ${model.module.vae_model.encoder.n_head}
30
+ n_head_cross: ${model.module.vae_model.encoder.n_head_cross}
31
+ n_layer: ${model.module.vae_model.encoder.n_layer}
32
+ n_inducing_points: ${model.module.vae_model.encoder.n_inducing_points}
33
+ dropout: ${model.module.vae_model.encoder.dropout}
34
+ bias: ${model.module.vae_model.encoder.bias}
35
+ multiple_of: ${model.module.vae_model.encoder.multiple_of}
36
+ layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
37
+ norm_layer: ${model.module.vae_model.encoder.norm_layer}
38
+ shared_embedding: true
39
+ use_adaln: false
40
+ latent_projection_type: ${model.module.vae_model.encoder.latent_projection_type}
41
+ input_layer:
42
+ _target_: scg_vae.layers.InputTransformerVAE
43
+ _partial_: false
44
+ n_genes: ${datamodule.label_encoder.n_genes}
45
+ n_embed: ${model.module.vae_model.encoder.n_embed}
46
+ agg_func: softbin
47
+ decoder_head: ${model.decoder_head.${model.decoder_name}}
48
+ vae_scheduler:
49
+ _target_: scg_vae._utils.wsd_schedule
50
+ num_training_steps: 894
51
+ final_lr_factor: 0.1
52
+ num_warmup_steps: 89
53
+ init_div_factor: 100
54
+ fract_decay: 0.1
55
+ decay_type: sqrt
56
+ vae_optimizer:
57
+ _target_: scg_vae.optimizers.AdamWLegacy
58
+ _partial_: true
59
+ lr: 0.032
60
+ weight_decay: 0.0
61
+ betas:
62
+ - 0.9
63
+ - 0.95
64
+ caution: false
65
+ diffusion_model:
66
+ _target_: scg_vae.diffusion.FlowMatching
67
+ n_embed: 256
68
+ n_embed_input: 16
69
+ seq_len: 256
70
+ n_layer: 8
71
+ n_head: 8
72
+ dropout: 0.0
73
+ bias: true
74
+ norm_layer: layernorm
75
+ multiple_of: 4
76
+ layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
77
+ class_vocab_sizes: ${datamodule.label_encoder.class_vocab_sizes}
78
+ cfg_dropout_prob: 0.8
79
+ condition_strategy: mutually_exclusive
80
+ nnet:
81
+ _target_: scg_vae.nnets.DiT
82
+ n_embed_input: ${model.module.vae_model.encoder.n_embed_latent}
83
+ n_embed: 512
84
+ n_layer: 8
85
+ n_head: 8
86
+ dropout: 0.0
87
+ bias: false
88
+ norm_layer: layernorm
89
+ multiple_of: 4
90
+ layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
91
+ class_vocab_sizes: ${datamodule.label_encoder.class_vocab_sizes}
92
+ seq_len: ${model.module.vae_model.encoder.n_inducing_points}
93
+ use_gpt_for_gene_ko: false
94
+ gene_ko_class_name: guide_target_ensembl
95
+ control_perturbation_name: NTC
96
+ condition_strategy: joint
97
+ n_inducing_points: ${model.module.vae_model.encoder.n_inducing_points}
98
+ sigma: 0.0001
99
+ v: 0.0
100
+ timesteps: 50
101
+ cfm: false
102
+ transport:
103
+ _target_: scg_vae.transport.create_transport
104
+ path_type: Linear
105
+ prediction: velocity
106
+ loss_weight: velocity
107
+ train_eps: 1.0e-05
108
+ sample_eps: 1.0e-05
109
+ diffusion_optimizer:
110
+ _target_: torch.optim.AdamW
111
+ _partial_: true
112
+ lr: 0.00096
113
+ weight_decay: 0.01
114
+ betas:
115
+ - 0.9
116
+ - 0.999
117
+ eps: 1.0e-08
118
+ diffusion_scheduler:
119
+ _target_: scg_vae._utils.wsd_schedule
120
+ num_training_steps: 894
121
+ final_lr_factor: 0.1
122
+ num_warmup_steps: 89
123
+ init_div_factor: 100
124
+ fract_decay: 1.0
125
+ decay_type: cosine
126
+ ema_decay: 0.9999
127
+ ema_update_every: 10
128
+ update_after_step: 10000
129
+ allow_different_devices: true
130
+ use_foreach: true
131
+ calculate_grad_norms: false
132
+ compile: true
133
+ compile_mode: default
134
+ vae_as_tokenizer: null
135
+ generation_args:
136
+ conditioning_strategy: joint
137
+ guidance_weight: ${datamodule.dataset_params.${datamodule.dataset}.guidance_weight}
138
+ num_inference_steps: 50
139
+ seed: 56
140
+ eval_generation:
141
+ freq: 50
142
+ sample_size: 10000
143
+ enabled: true
144
+ warmup_epochs: ${eval:'${training.num_epochs} // 2'}
145
+ store_latents: false
146
+ decoder_name: negative_binomial_shared_theta
147
+ decoder_head:
148
+ negative_binomial_shared_theta:
149
+ _target_: scg_vae.stochastic_layers.NegativeBinomialTransformerLayer
150
+ n_genes: 3699
151
+ shared_theta: true
152
+ n_embed: 256
153
+ norm_layer: layernorm
154
+ layernorm_eps: 1.0e-08
155
+ negative_binomial_unshared_theta:
156
+ _target_: scg_vae.stochastic_layers.NegativeBinomialTransformerLayer
157
+ n_genes: ${datamodule.label_encoder.n_genes}
158
+ shared_theta: false
159
+ n_embed: ${model.module.vae_model.encoder.n_embed}
160
+ norm_layer: ${model.module.vae_model.encoder.norm_layer}
161
+ layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
162
+ discretized_gaussian:
163
+ _target_: scg_vae.stochastic_layers.TruncatedDiscretizedGaussianTransformerLayer
164
+ n_embed: ${model.module.vae_model.encoder.n_embed}
165
+ norm_layer: ${model.module.vae_model.encoder.norm_layer}
166
+ layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
167
+ discretized_logistic:
168
+ _target_: scg_vae.stochastic_layers.TruncatedDiscretizedLogisticTransformerLayer
169
+ n_embed: ${model.module.vae_model.encoder.n_embed}
170
+ norm_layer: ${model.module.vae_model.encoder.norm_layer}
171
+ layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
172
+ poisson:
173
+ _target_: scg_vae.stochastic_layers.PossionTransformerLayer
174
+ n_genes: ${datamodule.label_encoder.n_genes}
175
+ n_embed: ${model.module.vae_model.encoder.n_embed}
176
+ norm_layer: ${model.module.vae_model.encoder.norm_layer}
177
+ layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
178
+ negative_binomial_decoupled_shared_theta:
179
+ _target_: scg_vae.stochastic_layers.NegativeBinomialTransformerLayerDecoupled
180
+ n_genes: ${datamodule.label_encoder.n_genes}
181
+ shared_theta: true
182
+ n_embed: ${model.module.vae_model.encoder.n_embed}
183
+ use_gene_bias: true
184
+ min_theta: 1.0e-05
185
+ eps: 0.0
186
+ negative_binomial_decoupled_unshared_theta:
187
+ _target_: scg_vae.stochastic_layers.NegativeBinomialTransformerLayerDecoupled
188
+ n_genes: ${datamodule.label_encoder.n_genes}
189
+ shared_theta: false
190
+ n_embed: ${model.module.vae_model.encoder.n_embed}
191
+ use_gene_bias: true
192
+ min_theta: 1.0e-05
193
+ eps: 0.0
194
+ batch_size: 256
195
+ test_batch_size: 128
196
+ num_parameters: 59380351
197
+ flops: 15112077312
198
+ get_flops:
199
+ _target_: scg_vae.flops.get_flops
200
+ seq_len: ${model.module.diffusion_model.seq_len}
201
+ vocab_size: ${model.module.diffusion_model.seq_len}
202
+ num_heads: ${model.module.diffusion_model.n_head}
203
+ swiglu: false
204
+ n_layers: ${model.module.diffusion_model.n_layer}
205
+ d_model: ${model.module.diffusion_model.n_embed}
206
+ key_size: ${model.module.diffusion_model.n_embed}
207
+ ffw_size: ${eval:'${model.module.diffusion_model.n_embed} * ${model.module.diffusion_model.multiple_of}'}
208
+ training:
209
+ trainer:
210
+ _target_: pytorch_lightning.Trainer
211
+ _partial_: true
212
+ max_steps: 894
213
+ enable_progress_bar: false
214
+ precision: 32
215
+ log_every_n_steps: 30
216
+ val_check_interval: null
217
+ limit_val_batches: null
218
+ check_val_every_n_epoch: null
219
+ sync_batchnorm: true
220
+ accelerator: gpu
221
+ enable_checkpointing: true
222
+ deterministic: false
223
+ gradient_clip_val: 10.0
224
+ gradient_clip_algorithm: norm
225
+ accumulate_grad_batches: 1
226
+ logger:
227
+ csv:
228
+ _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
229
+ save_dir: ${paths.experiment_path}/csv_logger
230
+ name: ${experiment_name}
231
+ version: null
232
+ prefix: ''
233
+ wandb:
234
+ _target_: pytorch_lightning.loggers.wandb.WandbLogger
235
+ _partial_: true
236
+ save_dir: ${paths.experiment_path}/wandb_logger
237
+ project: ${experiment_name}
238
+ entity: null
239
+ name: ${experiment_name}
240
+ job_type: sweep
241
+ id: null
242
+ resume: allow
243
+ resume_from: null
244
+ settings:
245
+ init_timeout: 120
246
+ callbacks:
247
+ lr_monitor:
248
+ _target_: pytorch_lightning.callbacks.LearningRateMonitor
249
+ logging_interval: step
250
+ log_weight_decay: true
251
+ model_checkpoints:
252
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
253
+ dirpath: /mnt/payam/results/checkpoints/FM_marson_expr_31/foo=bar
254
+ filename: last.ckpt
255
+ save_weights_only: false
256
+ save_on_train_epoch_end: true
257
+ save_top_k: -1
258
+ monitor: val_loss
259
+ mode: min
260
+ enable_version_counter: false
261
+ save_last: true
262
+ num_epochs: 1
263
+ datamodule:
264
+ datamodule:
265
+ _target_: scg_vae.datamodule.SimplifiedDataModule
266
+ train_adata_path: ${datamodule.dataset_params.${datamodule.dataset}.adata_train}
267
+ test_adata_path: ${datamodule.dataset_params.${datamodule.dataset}.adata_test}
268
+ adata_attr: ${datamodule.dataset_params.${datamodule.dataset}.adata_attr}
269
+ adata_key: ${datamodule.dataset_params.${datamodule.dataset}.adata_key}
270
+ vocabulary_encoder: ${datamodule.label_encoder}
271
+ batch_size: ${model.batch_size}
272
+ test_batch_size: ${model.test_batch_size}
273
+ num_workers: 16
274
+ seed: ${seed}
275
+ prefetch_factor: 8
276
+ persistent_workers: true
277
+ drop_last_indices: true
278
+ drop_incomplete_batch: true
279
+ sample_genes: none
280
+ genes_seq_len: ${datamodule.dataset_params.${datamodule.dataset}.genes_seq_len}
281
+ val_as_test: false
282
+ hvg_genes_path: ${datamodule.dataset_params.${datamodule.dataset}.hvg_genes_path}
283
+ label_encoder:
284
+ _target_: scg_vae.encoder.VocabularyEncoderSimplified
285
+ adata_path: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded//train_hvg/adata_0.h5ad
286
+ mask_token: <MASK>
287
+ mask_token_idx: 0
288
+ n_genes: 3699
289
+ class_vocab_sizes:
290
+ donor_id: 4
291
+ guide_target_ensembl: 10571
292
+ experimental_perturbation_time_point: 3
293
+ guidance_weight:
294
+ donor_id: 1.0
295
+ guide_target_ensembl: 1.0
296
+ experimental_perturbation_time_point: 1.0
297
+ mu_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_mu.pkl
298
+ sd_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_sd.pkl
299
+ condition_strategy: joint
300
+ gpt_gene_embeddings_path: null
301
+ dataset: marson_expr_2_hvg
302
+ marson_base_path: /mnt/data/Marson/ML_splits
303
+ dataset_params:
304
+ marson_expr_2_hvg_gpt:
305
+ adata_train: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_allG_sharded/train_hvg/adata_0.h5ad
306
+ adata_test: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_allG_sharded/valid_hvg/adata_0.h5ad
307
+ hvg_genes_path: null
308
+ gpt_gene_embeddings_path: /mnt/data/Marson/embeddings/gpt/gpt_embeddings_from_json_postfiltered_ensembl.pkl
309
+ n_genes: 3699
310
+ mu_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_mu.pkl
311
+ sd_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_sd.pkl
312
+ genes_seq_len: 3699
313
+ adata_attr: X
314
+ adata_key: null
315
+ class_vocab_sizes:
316
+ donor_id: 4
317
+ guide_target_ensembl: 10571
318
+ experimental_perturbation_time_point: 3
319
+ guidance_weight:
320
+ donor_id: 1.0
321
+ guide_target_ensembl: 1.0
322
+ experimental_perturbation_time_point: 1.0
323
+ condition_strategy: joint
324
+ marson_expr_2_hvg:
325
+ adata_train: ${datamodule.base_data_path}/train_hvg/adata_0.h5ad
326
+ adata_test: ${datamodule.base_data_path}/test_hvg/adata_0.h5ad
327
+ hvg_genes_path: null
328
+ gpt_gene_embeddings_path: null
329
+ n_genes: 3699
330
+ mu_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_mu.pkl
331
+ sd_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_sd.pkl
332
+ genes_seq_len: 3699
333
+ adata_attr: X
334
+ adata_key: null
335
+ class_vocab_sizes:
336
+ donor_id: 4
337
+ guide_target_ensembl: 10571
338
+ experimental_perturbation_time_point: 3
339
+ guidance_weight:
340
+ donor_id: 1.0
341
+ guide_target_ensembl: 1.0
342
+ experimental_perturbation_time_point: 1.0
343
+ condition_strategy: joint
344
+ marson_expr_2:
345
+ adata_train: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_allG_sharded/train/adata_0.h5ad
346
+ adata_test: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_allG_sharded/valid/adata_0.h5ad
347
+ hvg_genes_path: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_allG_sharded/hvgs_union_3699.pkl
348
+ gpt_gene_embeddings_path: null
349
+ n_genes: 18069
350
+ mu_size_factor: null
351
+ sd_size_factor: null
352
+ genes_seq_len: 18069
353
+ adata_attr: X
354
+ adata_key: null
355
+ class_vocab_sizes:
356
+ donor_id: 4
357
+ guide_target_ensembl: 10571
358
+ experimental_perturbation_time_point: 3
359
+ guidance_weight:
360
+ donor_id: 1.0
361
+ guide_target_ensembl: 1.0
362
+ experimental_perturbation_time_point: 1.0
363
+ condition_strategy: joint
364
+ marson_expr_1:
365
+ adata_train: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_sharded/train/adata_0.h5ad
366
+ adata_test: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_sharded/test/adata_0.h5ad
367
+ hvg_genes_path: null
368
+ gpt_gene_embeddings_path: null
369
+ n_genes: 3699
370
+ mu_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_sharded/marson_log_size_factor_mu.pkl
371
+ sd_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_sharded/marson_log_size_factor_sd.pkl
372
+ genes_seq_len: 3699
373
+ adata_attr: X
374
+ adata_key: null
375
+ class_vocab_sizes:
376
+ donor_id: 4
377
+ guide_target_ensembl: 10571
378
+ experimental_perturbation_time_point: 3
379
+ guidance_weight:
380
+ donor_id: 1.0
381
+ guide_target_ensembl: 1.0
382
+ experimental_perturbation_time_point: 1.0
383
+ condition_strategy: joint
384
+ marson_expr_0:
385
+ adata_train: ${datamodule.base_data_path}/train/adata.h5ad
386
+ adata_test: ${datamodule.base_data_path}/train/adata.h5ad
387
+ gpt_gene_embeddings_path: null
388
+ n_genes: 3699
389
+ mu_size_factor: null
390
+ sd_size_factor: null
391
+ genes_seq_len: ${datamodule.dataset_params.${datamodule.dataset}.n_genes}
392
+ adata_attr: X
393
+ adata_key: null
394
+ class_vocab_sizes:
395
+ donor_id: 4
396
+ guide_target_ensembl: 9674
397
+ experimental_perturbation_time_point: 3
398
+ guidance_weight:
399
+ donor_id: 1.0
400
+ guide_target_ensembl: 1.0
401
+ experimental_perturbation_time_point: 1.0
402
+ condition_strategy: joint
403
+ czb_cd4_naive_holdout:
404
+ adata_train: ${datamodule.base_data_path}/CD4_hold_out/train.h5ad
405
+ adata_test: ${datamodule.base_data_path}/CD4_hold_out/test.h5ad
406
+ gpt_gene_embeddings_path: null
407
+ n_genes: 2000
408
+ mu_size_factor: null
409
+ sd_size_factor: null
410
+ genes_seq_len: 2000
411
+ adata_attr: X
412
+ adata_key: null
413
+ class_vocab_sizes:
414
+ cell_type: 18
415
+ cytokine: 91
416
+ guidance_weight:
417
+ cell_type: 1.0
418
+ cytokine: 1.0
419
+ condition_strategy: joint
420
+ replogle:
421
+ adata_train: ${datamodule.base_data_path}/czb_data/Replogle_Nadig_from_STATE/processed/train_asSTATE_hvg.h5ad
422
+ adata_test: ${datamodule.base_data_path}/czb_data/Replogle_Nadig_from_STATE/processed/test_asSTATE_hvg.h5ad
423
+ gpt_gene_embeddings_path: null
424
+ n_genes: 2000
425
+ class_vocab_sizes:
426
+ cell_line: 4
427
+ gene: 2024
428
+ guidance_weight:
429
+ cell_line: 1.0
430
+ gene: 1.0
431
+ adata_attr: X
432
+ adata_key: null
433
+ condition_strategy: joint
434
+ mu_size_factor: null
435
+ sd_size_factor: null
436
+ base_data_path: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/
437
+ experiment_name: inference_fm_31
438
+ seed: 56
439
+ inference_path: /mnt/payam/results/generated_cells/inference_fm_test_31
440
+ dataset_generation_idx: 0
441
+ ckpt_file: last.ckpt