# scldm_cd4 / config.yaml (Hugging Face repository file; tagged Safetensors)
# Uploaded by jperera-czbio in commit 45486d7:
#   "Upload inference checkpoint and config"
---
# Filesystem roots. experiment_path is interpolated by the training loggers'
# save_dir entries as ${paths.experiment_path}.
paths:
  data_path: /mnt/data/
  experiment_path: /mnt/payam/results
# Model configuration, written in Hydra `instantiate` conventions: `_target_`
# names the object to build, `_partial_: true` asks for a partial, and `${...}`
# are OmegaConf interpolations.
# NOTE(review): this section was recovered from a dump whose indentation had
# been stripped. Keys that some ${...} interpolation points at follow those
# paths: model.module.vae_model.encoder.*, model.module.diffusion_model.*,
# model.decoder_name, model.decoder_head.*, model.batch_size and
# model.test_batch_size.
# The following placements are best guesses -- confirm them against the
# original config:
#   - decoder_head under vae_model;
#   - n_inducing_points/sigma/v/timesteps/cfm and transport under
#     diffusion_model (rather than nnet or module);
#   - the ema_*, compile*, vae_as_tokenizer, generation_args and
#     eval_generation keys under module (rather than model);
#   - seed under generation_args;
#   - store_latents under eval_generation.
model:
  module:
    _target_: scg_vae.models.LatentDiffusion
    vae_model:
      _target_: scg_vae.vae.TransformerVAE
      encoder:
        _target_: scg_vae.nnets.Encoder
        n_layer: 8
        n_inducing_points: 256
        n_embed: 256
        n_embed_latent: 16
        n_head: 4
        n_head_cross: 4
        dropout: 0.0
        bias: false
        multiple_of: 4
        layernorm_eps: 1.0e-08
        norm_layer: layernorm
        positional_encoding: false
        latent_projection_type: mlp
      # The decoder takes its sizes from the encoder via interpolation, so the
      # two stay in sync.
      decoder:
        _target_: scg_vae.nnets.Decoder
        n_genes: ${datamodule.label_encoder.n_genes}
        n_embed: ${model.module.vae_model.encoder.n_embed}
        n_embed_latent: ${model.module.vae_model.encoder.n_embed_latent}
        n_head: ${model.module.vae_model.encoder.n_head}
        n_head_cross: ${model.module.vae_model.encoder.n_head_cross}
        n_layer: ${model.module.vae_model.encoder.n_layer}
        n_inducing_points: ${model.module.vae_model.encoder.n_inducing_points}
        dropout: ${model.module.vae_model.encoder.dropout}
        bias: ${model.module.vae_model.encoder.bias}
        multiple_of: ${model.module.vae_model.encoder.multiple_of}
        layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
        norm_layer: ${model.module.vae_model.encoder.norm_layer}
        shared_embedding: true
        use_adaln: false
        latent_projection_type: ${model.module.vae_model.encoder.latent_projection_type}
      input_layer:
        _target_: scg_vae.layers.InputTransformerVAE
        _partial_: false
        n_genes: ${datamodule.label_encoder.n_genes}
        n_embed: ${model.module.vae_model.encoder.n_embed}
        agg_func: softbin
      # Picks one entry of model.decoder_head by the name in model.decoder_name.
      decoder_head: ${model.decoder_head.${model.decoder_name}}
    vae_scheduler:
      _target_: scg_vae._utils.wsd_schedule
      num_training_steps: 894
      final_lr_factor: 0.1
      num_warmup_steps: 89
      init_div_factor: 100
      fract_decay: 0.1
      decay_type: sqrt
    vae_optimizer:
      _target_: scg_vae.optimizers.AdamWLegacy
      _partial_: true
      lr: 0.032
      weight_decay: 0.0
      betas:
        - 0.9
        - 0.95
      caution: false
    diffusion_model:
      _target_: scg_vae.diffusion.FlowMatching
      # NOTE(review): these sizes are literals and differ from the nnet below
      # (n_embed 256 here vs 512 in nnet; bias true vs false). get_flops reads
      # the values at this level -- confirm which set is authoritative.
      n_embed: 256
      n_embed_input: 16
      seq_len: 256
      n_layer: 8
      n_head: 8
      dropout: 0.0
      bias: true
      norm_layer: layernorm
      multiple_of: 4
      layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
      class_vocab_sizes: ${datamodule.label_encoder.class_vocab_sizes}
      cfg_dropout_prob: 0.8
      condition_strategy: mutually_exclusive
      nnet:
        _target_: scg_vae.nnets.DiT
        n_embed_input: ${model.module.vae_model.encoder.n_embed_latent}
        n_embed: 512
        n_layer: 8
        n_head: 8
        dropout: 0.0
        bias: false
        norm_layer: layernorm
        multiple_of: 4
        layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
        class_vocab_sizes: ${datamodule.label_encoder.class_vocab_sizes}
        seq_len: ${model.module.vae_model.encoder.n_inducing_points}
        use_gpt_for_gene_ko: false
        gene_ko_class_name: guide_target_ensembl
        control_perturbation_name: NTC
        # Must be inside nnet: diffusion_model already has its own
        # condition_strategy key above.
        condition_strategy: joint
      n_inducing_points: ${model.module.vae_model.encoder.n_inducing_points}
      sigma: 0.0001
      v: 0.0
      timesteps: 50
      cfm: false
      transport:
        _target_: scg_vae.transport.create_transport
        path_type: Linear
        prediction: velocity
        loss_weight: velocity
        train_eps: 1.0e-05
        sample_eps: 1.0e-05
    diffusion_optimizer:
      _target_: torch.optim.AdamW
      _partial_: true
      lr: 0.00096
      weight_decay: 0.01
      betas:
        - 0.9
        - 0.999
      eps: 1.0e-08
    diffusion_scheduler:
      _target_: scg_vae._utils.wsd_schedule
      num_training_steps: 894
      final_lr_factor: 0.1
      num_warmup_steps: 89
      init_div_factor: 100
      fract_decay: 1.0
      decay_type: cosine
    ema_decay: 0.9999
    ema_update_every: 10
    # NOTE(review): 10000 exceeds training.trainer.max_steps (894). If this
    # follows ema_pytorch's update_after_step semantics (the neighbouring
    # allow_different_devices / use_foreach keys match that library's
    # arguments), the EMA would never start averaging in this run -- confirm.
    update_after_step: 10000
    allow_different_devices: true
    use_foreach: true
    calculate_grad_norms: false
    compile: true
    compile_mode: default
    vae_as_tokenizer: null
    generation_args:
      conditioning_strategy: joint
      guidance_weight: ${datamodule.dataset_params.${datamodule.dataset}.guidance_weight}
      num_inference_steps: 50
      seed: 56
    eval_generation:
      freq: 50
      sample_size: 10000
      enabled: true
      # Half of training.num_epochs (integer division), via the `eval` resolver.
      warmup_epochs: ${eval:'${training.num_epochs} // 2'}
      store_latents: false
  # Key into decoder_head below; model.module.vae_model.decoder_head reads it.
  decoder_name: negative_binomial_shared_theta
  decoder_head:
    negative_binomial_shared_theta:
      _target_: scg_vae.stochastic_layers.NegativeBinomialTransformerLayer
      # Same sources as the sibling heads, so the selected head follows the
      # label encoder and encoder sizes (currently 3699 / 256 / layernorm /
      # 1.0e-08).
      n_genes: ${datamodule.label_encoder.n_genes}
      shared_theta: true
      n_embed: ${model.module.vae_model.encoder.n_embed}
      norm_layer: ${model.module.vae_model.encoder.norm_layer}
      layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
    negative_binomial_unshared_theta:
      _target_: scg_vae.stochastic_layers.NegativeBinomialTransformerLayer
      n_genes: ${datamodule.label_encoder.n_genes}
      shared_theta: false
      n_embed: ${model.module.vae_model.encoder.n_embed}
      norm_layer: ${model.module.vae_model.encoder.norm_layer}
      layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
    discretized_gaussian:
      _target_: scg_vae.stochastic_layers.TruncatedDiscretizedGaussianTransformerLayer
      n_embed: ${model.module.vae_model.encoder.n_embed}
      norm_layer: ${model.module.vae_model.encoder.norm_layer}
      layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
    discretized_logistic:
      _target_: scg_vae.stochastic_layers.TruncatedDiscretizedLogisticTransformerLayer
      n_embed: ${model.module.vae_model.encoder.n_embed}
      norm_layer: ${model.module.vae_model.encoder.norm_layer}
      layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
    poisson:
      _target_: scg_vae.stochastic_layers.PossionTransformerLayer
      n_genes: ${datamodule.label_encoder.n_genes}
      n_embed: ${model.module.vae_model.encoder.n_embed}
      norm_layer: ${model.module.vae_model.encoder.norm_layer}
      layernorm_eps: ${model.module.vae_model.encoder.layernorm_eps}
    negative_binomial_decoupled_shared_theta:
      _target_: scg_vae.stochastic_layers.NegativeBinomialTransformerLayerDecoupled
      n_genes: ${datamodule.label_encoder.n_genes}
      shared_theta: true
      n_embed: ${model.module.vae_model.encoder.n_embed}
      use_gene_bias: true
      min_theta: 1.0e-05
      eps: 0.0
    negative_binomial_decoupled_unshared_theta:
      _target_: scg_vae.stochastic_layers.NegativeBinomialTransformerLayerDecoupled
      n_genes: ${datamodule.label_encoder.n_genes}
      shared_theta: false
      n_embed: ${model.module.vae_model.encoder.n_embed}
      use_gene_bias: true
      min_theta: 1.0e-05
      eps: 0.0
  # Interpolated by datamodule.datamodule.batch_size / test_batch_size.
  batch_size: 256
  test_batch_size: 128
  num_parameters: 59380351
  flops: 15112077312
  # NOTE(review): d_model/key_size/ffw_size read diffusion_model.n_embed (256),
  # while the DiT nnet is configured with n_embed 512; vocab_size also reads
  # seq_len. Confirm these are the intended sources for the FLOP estimate.
  get_flops:
    _target_: scg_vae.flops.get_flops
    seq_len: ${model.module.diffusion_model.seq_len}
    vocab_size: ${model.module.diffusion_model.seq_len}
    num_heads: ${model.module.diffusion_model.n_head}
    swiglu: false
    n_layers: ${model.module.diffusion_model.n_layer}
    d_model: ${model.module.diffusion_model.n_embed}
    key_size: ${model.module.diffusion_model.n_embed}
    ffw_size: ${eval:'${model.module.diffusion_model.n_embed} * ${model.module.diffusion_model.multiple_of}'}
# Training configuration: a partial Lightning Trainer plus loggers and
# callbacks declared alongside it. num_epochs is interpolated elsewhere as
# ${training.num_epochs}.
# NOTE(review): recovered from a dump whose indentation had been stripped;
# logger and callbacks are placed beside (not inside) trainer because trainer
# is built with _partial_: true -- confirm against the original config.
training:
  trainer:
    _target_: pytorch_lightning.Trainer
    _partial_: true
    max_steps: 894
    enable_progress_bar: false
    precision: 32
    log_every_n_steps: 30
    val_check_interval: null
    limit_val_batches: null
    check_val_every_n_epoch: null
    sync_batchnorm: true
    accelerator: gpu
    enable_checkpointing: true
    deterministic: false
    gradient_clip_val: 10.0
    gradient_clip_algorithm: norm
    accumulate_grad_batches: 1
  logger:
    csv:
      _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
      save_dir: ${paths.experiment_path}/csv_logger
      name: ${experiment_name}
      version: null
      prefix: ''
    wandb:
      _target_: pytorch_lightning.loggers.wandb.WandbLogger
      _partial_: true
      save_dir: ${paths.experiment_path}/wandb_logger
      project: ${experiment_name}
      entity: null
      name: ${experiment_name}
      job_type: sweep
      id: null
      resume: allow
      resume_from: null
      settings:
        init_timeout: 120
  callbacks:
    lr_monitor:
      _target_: pytorch_lightning.callbacks.LearningRateMonitor
      logging_interval: step
      log_weight_decay: true
    model_checkpoints:
      _target_: pytorch_lightning.callbacks.ModelCheckpoint
      # NOTE(review): hard-coded rather than derived from
      # ${paths.experiment_path}; the trailing "foo=bar" looks like a
      # placeholder override -- confirm.
      dirpath: /mnt/payam/results/checkpoints/FM_marson_expr_31/foo=bar
      # NOTE(review): Lightning's ModelCheckpoint appends ".ckpt" to
      # `filename`, so this likely writes "last.ckpt.ckpt". save_last: true
      # separately writes "last.ckpt". With a fixed name and
      # enable_version_counter: false, each save also overwrites the previous
      # one despite save_top_k: -1. Confirm the intended naming.
      filename: last.ckpt
      save_weights_only: false
      save_on_train_epoch_end: true
      save_top_k: -1
      monitor: val_loss
      mode: min
      enable_version_counter: false
      save_last: true
  num_epochs: 1
# Data configuration. The data module and generation args look up per-dataset
# values in dataset_params using the name in datamodule.dataset.
# NOTE(review): recovered from a dump whose indentation had been stripped;
# nesting follows the ${datamodule.*} interpolation paths used in this file.
datamodule:
  datamodule:
    _target_: scg_vae.datamodule.SimplifiedDataModule
    train_adata_path: ${datamodule.dataset_params.${datamodule.dataset}.adata_train}
    test_adata_path: ${datamodule.dataset_params.${datamodule.dataset}.adata_test}
    adata_attr: ${datamodule.dataset_params.${datamodule.dataset}.adata_attr}
    adata_key: ${datamodule.dataset_params.${datamodule.dataset}.adata_key}
    vocabulary_encoder: ${datamodule.label_encoder}
    batch_size: ${model.batch_size}
    test_batch_size: ${model.test_batch_size}
    num_workers: 16
    seed: ${seed}
    prefetch_factor: 8
    persistent_workers: true
    drop_last_indices: true
    drop_incomplete_batch: true
    # The string "none", not YAML null; quoted so the intent is explicit.
    sample_genes: 'none'
    genes_seq_len: ${datamodule.dataset_params.${datamodule.dataset}.genes_seq_len}
    val_as_test: false
    hvg_genes_path: ${datamodule.dataset_params.${datamodule.dataset}.hvg_genes_path}
  # n_genes and class_vocab_sizes are interpolated by the model section.
  label_encoder:
    _target_: scg_vae.encoder.VocabularyEncoderSimplified
    adata_path: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded//train_hvg/adata_0.h5ad
    mask_token: '<MASK>'
    mask_token_idx: 0
    n_genes: 3699
    class_vocab_sizes:
      donor_id: 4
      guide_target_ensembl: 10571
      experimental_perturbation_time_point: 3
    guidance_weight:
      donor_id: 1.0
      guide_target_ensembl: 1.0
      experimental_perturbation_time_point: 1.0
    mu_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_mu.pkl
    sd_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_sd.pkl
    condition_strategy: joint
    gpt_gene_embeddings_path: null
  # Selects the dataset_params entry below.
  dataset: marson_expr_2_hvg
  marson_base_path: /mnt/data/Marson/ML_splits
  # NOTE(review): datamodule.datamodule interpolates hvg_genes_path and
  # genes_seq_len from the selected entry. marson_expr_0,
  # czb_cd4_naive_holdout and replogle lack hvg_genes_path, and replogle also
  # lacks genes_seq_len, so selecting them would fail to resolve -- confirm.
  dataset_params:
    marson_expr_2_hvg_gpt:
      adata_train: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_allG_sharded/train_hvg/adata_0.h5ad
      adata_test: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_allG_sharded/valid_hvg/adata_0.h5ad
      hvg_genes_path: null
      gpt_gene_embeddings_path: /mnt/data/Marson/embeddings/gpt/gpt_embeddings_from_json_postfiltered_ensembl.pkl
      n_genes: 3699
      mu_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_mu.pkl
      sd_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_sd.pkl
      genes_seq_len: 3699
      adata_attr: X
      adata_key: null
      class_vocab_sizes:
        donor_id: 4
        guide_target_ensembl: 10571
        experimental_perturbation_time_point: 3
      guidance_weight:
        donor_id: 1.0
        guide_target_ensembl: 1.0
        experimental_perturbation_time_point: 1.0
      condition_strategy: joint
    marson_expr_2_hvg:
      adata_train: ${datamodule.base_data_path}/train_hvg/adata_0.h5ad
      adata_test: ${datamodule.base_data_path}/test_hvg/adata_0.h5ad
      hvg_genes_path: null
      gpt_gene_embeddings_path: null
      n_genes: 3699
      mu_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_mu.pkl
      sd_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/size_factors_hvg/log_size_factor_sd.pkl
      genes_seq_len: 3699
      adata_attr: X
      adata_key: null
      class_vocab_sizes:
        donor_id: 4
        guide_target_ensembl: 10571
        experimental_perturbation_time_point: 3
      guidance_weight:
        donor_id: 1.0
        guide_target_ensembl: 1.0
        experimental_perturbation_time_point: 1.0
      condition_strategy: joint
    marson_expr_2:
      adata_train: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_allG_sharded/train/adata_0.h5ad
      adata_test: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_allG_sharded/valid/adata_0.h5ad
      hvg_genes_path: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_allG_sharded/hvgs_union_3699.pkl
      gpt_gene_embeddings_path: null
      n_genes: 18069
      mu_size_factor: null
      sd_size_factor: null
      genes_seq_len: 18069
      adata_attr: X
      adata_key: null
      class_vocab_sizes:
        donor_id: 4
        guide_target_ensembl: 10571
        experimental_perturbation_time_point: 3
      guidance_weight:
        donor_id: 1.0
        guide_target_ensembl: 1.0
        experimental_perturbation_time_point: 1.0
      condition_strategy: joint
    marson_expr_1:
      adata_train: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_sharded/train/adata_0.h5ad
      adata_test: ${datamodule.marson_base_path}/donor_timepoint_single_guide_fewshot_sharded/test/adata_0.h5ad
      hvg_genes_path: null
      gpt_gene_embeddings_path: null
      n_genes: 3699
      mu_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_sharded/marson_log_size_factor_mu.pkl
      sd_size_factor: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_sharded/marson_log_size_factor_sd.pkl
      genes_seq_len: 3699
      adata_attr: X
      adata_key: null
      class_vocab_sizes:
        donor_id: 4
        guide_target_ensembl: 10571
        experimental_perturbation_time_point: 3
      guidance_weight:
        donor_id: 1.0
        guide_target_ensembl: 1.0
        experimental_perturbation_time_point: 1.0
      condition_strategy: joint
    marson_expr_0:
      adata_train: ${datamodule.base_data_path}/train/adata.h5ad
      # NOTE(review): same file as adata_train -- confirm a held-out test
      # split was not intended.
      adata_test: ${datamodule.base_data_path}/train/adata.h5ad
      gpt_gene_embeddings_path: null
      n_genes: 3699
      mu_size_factor: null
      sd_size_factor: null
      genes_seq_len: ${datamodule.dataset_params.${datamodule.dataset}.n_genes}
      adata_attr: X
      adata_key: null
      class_vocab_sizes:
        donor_id: 4
        guide_target_ensembl: 9674
        experimental_perturbation_time_point: 3
      guidance_weight:
        donor_id: 1.0
        guide_target_ensembl: 1.0
        experimental_perturbation_time_point: 1.0
      condition_strategy: joint
    # NOTE(review): the two entries below resolve against base_data_path,
    # which currently points at a Marson-specific split directory -- confirm
    # the paths.
    czb_cd4_naive_holdout:
      adata_train: ${datamodule.base_data_path}/CD4_hold_out/train.h5ad
      adata_test: ${datamodule.base_data_path}/CD4_hold_out/test.h5ad
      gpt_gene_embeddings_path: null
      n_genes: 2000
      mu_size_factor: null
      sd_size_factor: null
      genes_seq_len: 2000
      adata_attr: X
      adata_key: null
      class_vocab_sizes:
        cell_type: 18
        cytokine: 91
      guidance_weight:
        cell_type: 1.0
        cytokine: 1.0
      condition_strategy: joint
    replogle:
      adata_train: ${datamodule.base_data_path}/czb_data/Replogle_Nadig_from_STATE/processed/train_asSTATE_hvg.h5ad
      adata_test: ${datamodule.base_data_path}/czb_data/Replogle_Nadig_from_STATE/processed/test_asSTATE_hvg.h5ad
      gpt_gene_embeddings_path: null
      n_genes: 2000
      class_vocab_sizes:
        cell_line: 4
        gene: 2024
      guidance_weight:
        cell_line: 1.0
        gene: 1.0
      adata_attr: X
      adata_key: null
      condition_strategy: joint
      mu_size_factor: null
      sd_size_factor: null
  base_data_path: /mnt/data/Marson/ML_splits/donor_timepoint_single_guide_fewshot_allG_sharded/
# Run-level settings. experiment_name feeds the loggers' name/project via
# ${experiment_name}; seed feeds the data module via ${seed}.
experiment_name: 'inference_fm_31'
seed: 56
inference_path: '/mnt/payam/results/generated_cells/inference_fm_test_31'
dataset_generation_idx: 0
ckpt_file: 'last.ckpt'