GenePooler
#429
by
jamieb-nvs
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- Geneformer-V2-316M/model.safetensors +0 -3
- MANIFEST.in +4 -9
- README.md +16 -16
- config.json +7 -7
- docs/source/geneformer.in_silico_perturber.rst +1 -1
- examples/cell_classification.ipynb +5 -8
- examples/distributed_multitask_cell_classification.ipynb +0 -149
- examples/extract_and_plot_cell_embeddings.ipynb +3 -5
- examples/gene_classification.ipynb +7 -11
- examples/in_silico_perturbation.ipynb +10 -17
- examples/multitask_cell_classification.ipynb +3 -3
- examples/tokenizing_scRNAseq_data.ipynb +8 -14
- {Geneformer-V2-104M → fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522}/config.json +6 -6
- fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +3 -0
- fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/config.json +0 -0
- fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/optimizer.pt +0 -0
- fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin +0 -0
- fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/rng_state.pth +0 -0
- fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/scheduler.pt +0 -0
- fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/trainer_state.json +0 -0
- fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/training_args.bin +0 -0
- geneformer/__init__.py +4 -9
- geneformer/classifier.py +9 -61
- geneformer/classifier_utils.py +6 -36
- geneformer/collator_for_classification.py +2 -7
- geneformer/emb_extractor.py +51 -132
- geneformer/{ensembl_mapping_dict_gc104M.pkl → ensembl_mapping_dict_gc95M.pkl} +0 -0
- geneformer/evaluation_utils.py +14 -33
- geneformer/{gene_median_dictionary_gc104M.pkl → gene_median_dictionary_gc95M.pkl} +0 -0
- geneformer/{gene_name_id_dict_gc104M.pkl → gene_name_id_dict_gc95M.pkl} +2 -2
- geneformer/in_silico_perturber.py +11 -53
- geneformer/in_silico_perturber_stats.py +5 -25
- geneformer/mtl/__init__.py +1 -4
- geneformer/mtl/collators.py +2 -2
- geneformer/mtl/data.py +123 -210
- geneformer/mtl/eval_utils.py +8 -5
- geneformer/mtl/imports.py +43 -0
- geneformer/mtl/model.py +1 -1
- geneformer/mtl/optuna_utils.py +27 -0
- geneformer/mtl/train.py +329 -656
- geneformer/mtl/train_utils.py +161 -0
- geneformer/mtl/utils.py +91 -603
- geneformer/mtl_classifier.py +8 -23
- geneformer/perturber_utils.py +65 -62
- geneformer/pretrainer.py +176 -6
- geneformer/{token_dictionary_gc104M.pkl → token_dictionary_gc95M.pkl} +0 -0
- geneformer/tokenizer.py +66 -186
- generation_config.json +1 -1
- gf-12L-30M-i2048/config.json +23 -0
- gf-12L-30M-i2048/pytorch_model.bin +3 -0
Geneformer-V2-316M/model.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:965ceccea81953d362081ef3843560a0e4fef88d396c28017881f1e94b1246f3
|
| 3 |
-
size 1265455076
|
|
|
|
|
|
|
|
|
|
|
|
MANIFEST.in
CHANGED
|
@@ -1,9 +1,4 @@
|
|
| 1 |
-
include geneformer/
|
| 2 |
-
include geneformer/
|
| 3 |
-
include geneformer/
|
| 4 |
-
include geneformer/
|
| 5 |
-
|
| 6 |
-
include geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl
|
| 7 |
-
include geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl
|
| 8 |
-
include geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl
|
| 9 |
-
include geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl
|
|
|
|
| 1 |
+
include geneformer/gene_median_dictionary_gc95M.pkl
|
| 2 |
+
include geneformer/gene_name_id_dict_gc95M.pkl
|
| 3 |
+
include geneformer/ensembl_mapping_dict_gc95M.pkl
|
| 4 |
+
include geneformer/token_dictionary_gc95M.pkl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,36 +1,40 @@
|
|
| 1 |
---
|
| 2 |
datasets: ctheodoris/Genecorpus-30M
|
| 3 |
license: apache-2.0
|
| 4 |
-
tags:
|
| 5 |
-
- single-cell
|
| 6 |
-
- genomics
|
| 7 |
---
|
| 8 |
# Geneformer
|
| 9 |
Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
|
| 10 |
|
| 11 |
- See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
|
| 12 |
-
- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model
|
| 13 |
- See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
|
| 14 |
|
| 15 |
# Model Description
|
| 16 |
-
Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer
|
| 17 |
|
| 18 |
-
Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus
|
| 19 |
|
| 20 |
The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
|
| 21 |
|
| 22 |
We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
|
| 23 |
|
| 24 |
-
During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational
|
| 25 |
|
| 26 |
The repository includes the following pretrained models:
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
The
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# Application
|
| 36 |
The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
|
|
@@ -78,13 +82,9 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main
|
|
| 78 |
- extracting and plotting cell embeddings
|
| 79 |
- in silico perturbation
|
| 80 |
|
| 81 |
-
Please also see [here](https://tinyurl.com/geneformertutorial) for a quickstart tutorial for predicting candidate therapeutic targets with Geneformer.
|
| 82 |
-
|
| 83 |
-
Complete documentation is available at https://geneformer.readthedocs.io/en/latest/.
|
| 84 |
-
|
| 85 |
Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
|
| 86 |
|
| 87 |
-
Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
|
| 88 |
|
| 89 |
# Citations
|
| 90 |
- C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
|
|
|
|
| 1 |
---
|
| 2 |
datasets: ctheodoris/Genecorpus-30M
|
| 3 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
| 4 |
---
|
| 5 |
# Geneformer
|
| 6 |
Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
|
| 7 |
|
| 8 |
- See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
|
| 9 |
+
- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
|
| 10 |
- See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
|
| 11 |
|
| 12 |
# Model Description
|
| 13 |
+
Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Then, in April 2024, Geneformer was pretrained on ~95 million non-cancer transcriptomes, followed by continual learning on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.
|
| 14 |
|
| 15 |
+
Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
|
| 16 |
|
| 17 |
The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
|
| 18 |
|
| 19 |
We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
|
| 20 |
|
| 21 |
+
During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
|
| 22 |
|
| 23 |
The repository includes the following pretrained models:
|
| 24 |
|
| 25 |
+
L=layers\
|
| 26 |
+
M=millions of cells used for pretraining\
|
| 27 |
+
i=input size\
|
| 28 |
+
(pretraining date)
|
| 29 |
|
| 30 |
+
- GF-6L-30M-i2048 (June 2021)
|
| 31 |
+
- GF-12L-30M-i2048 (June 2021)
|
| 32 |
+
- GF-12L-95M-i4096 (April 2024)
|
| 33 |
+
- GF-20L-95M-i4096 (April 2024)
|
| 34 |
|
| 35 |
+
The current default model in the main directory of the repository is GF-12L-95M-i4096.
|
| 36 |
+
|
| 37 |
+
The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer.
|
| 38 |
|
| 39 |
# Application
|
| 40 |
The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
|
|
|
|
| 82 |
- extracting and plotting cell embeddings
|
| 83 |
- in silico perturbation
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
|
| 86 |
|
| 87 |
+
Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
|
| 88 |
|
| 89 |
# Citations
|
| 90 |
- C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
|
config.json
CHANGED
|
@@ -2,22 +2,22 @@
|
|
| 2 |
"architectures": [
|
| 3 |
"BertForMaskedLM"
|
| 4 |
],
|
| 5 |
-
"attention_probs_dropout_prob": 0.
|
| 6 |
"classifier_dropout": null,
|
| 7 |
"hidden_act": "relu",
|
| 8 |
-
"hidden_dropout_prob": 0.
|
| 9 |
-
"hidden_size":
|
| 10 |
"initializer_range": 0.02,
|
| 11 |
-
"intermediate_size":
|
| 12 |
"layer_norm_eps": 1e-12,
|
| 13 |
"max_position_embeddings": 4096,
|
| 14 |
"model_type": "bert",
|
| 15 |
-
"num_attention_heads":
|
| 16 |
-
"num_hidden_layers":
|
| 17 |
"pad_token_id": 0,
|
| 18 |
"position_embedding_type": "absolute",
|
| 19 |
"torch_dtype": "float32",
|
| 20 |
-
"transformers_version": "4.
|
| 21 |
"type_vocab_size": 2,
|
| 22 |
"use_cache": true,
|
| 23 |
"vocab_size": 20275
|
|
|
|
| 2 |
"architectures": [
|
| 3 |
"BertForMaskedLM"
|
| 4 |
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.02,
|
| 6 |
"classifier_dropout": null,
|
| 7 |
"hidden_act": "relu",
|
| 8 |
+
"hidden_dropout_prob": 0.02,
|
| 9 |
+
"hidden_size": 512,
|
| 10 |
"initializer_range": 0.02,
|
| 11 |
+
"intermediate_size": 1024,
|
| 12 |
"layer_norm_eps": 1e-12,
|
| 13 |
"max_position_embeddings": 4096,
|
| 14 |
"model_type": "bert",
|
| 15 |
+
"num_attention_heads": 8,
|
| 16 |
+
"num_hidden_layers": 12,
|
| 17 |
"pad_token_id": 0,
|
| 18 |
"position_embedding_type": "absolute",
|
| 19 |
"torch_dtype": "float32",
|
| 20 |
+
"transformers_version": "4.37.1",
|
| 21 |
"type_vocab_size": 2,
|
| 22 |
"use_cache": true,
|
| 23 |
"vocab_size": 20275
|
docs/source/geneformer.in_silico_perturber.rst
CHANGED
|
@@ -5,4 +5,4 @@ geneformer.in\_silico\_perturber
|
|
| 5 |
:members:
|
| 6 |
:undoc-members:
|
| 7 |
:show-inheritance:
|
| 8 |
-
:exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set,
|
|
|
|
| 5 |
:members:
|
| 6 |
:undoc-members:
|
| 7 |
:show-inheritance:
|
| 8 |
+
:exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, update_perturbation_dictionary
|
examples/cell_classification.ipynb
CHANGED
|
@@ -13,7 +13,7 @@
|
|
| 13 |
"id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
|
| 14 |
"metadata": {},
|
| 15 |
"source": [
|
| 16 |
-
"### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization.
|
| 17 |
]
|
| 18 |
},
|
| 19 |
{
|
|
@@ -68,8 +68,6 @@
|
|
| 68 |
" \"per_device_train_batch_size\": 12,\n",
|
| 69 |
" \"seed\": 73,\n",
|
| 70 |
"}\n",
|
| 71 |
-
"\n",
|
| 72 |
-
"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
|
| 73 |
"cc = Classifier(classifier=\"cell\",\n",
|
| 74 |
" cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n",
|
| 75 |
" filter_data=filter_data_dict,\n",
|
|
@@ -78,7 +76,6 @@
|
|
| 78 |
" freeze_layers = 2,\n",
|
| 79 |
" num_crossval_splits = 1,\n",
|
| 80 |
" forward_batch_size=200,\n",
|
| 81 |
-
" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
|
| 82 |
" nproc=16)"
|
| 83 |
]
|
| 84 |
},
|
|
@@ -128,7 +125,7 @@
|
|
| 128 |
" \"train\": train_ids+eval_ids,\n",
|
| 129 |
" \"test\": test_ids}\n",
|
| 130 |
"\n",
|
| 131 |
-
"# Example input_data_file
|
| 132 |
"cc.prepare_data(input_data_file=\"/path/to/human_dcm_hcm_nf_2048_w_length.dataset\",\n",
|
| 133 |
" output_directory=output_dir,\n",
|
| 134 |
" output_prefix=output_prefix,\n",
|
|
@@ -263,8 +260,8 @@
|
|
| 263 |
" \"train\": train_ids,\n",
|
| 264 |
" \"eval\": eval_ids}\n",
|
| 265 |
"\n",
|
| 266 |
-
"#
|
| 267 |
-
"all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\"
|
| 268 |
" prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n",
|
| 269 |
" id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
|
| 270 |
" output_directory=output_dir,\n",
|
|
@@ -449,7 +446,7 @@
|
|
| 449 |
"name": "python",
|
| 450 |
"nbconvert_exporter": "python",
|
| 451 |
"pygments_lexer": "ipython3",
|
| 452 |
-
"version": "3.
|
| 453 |
}
|
| 454 |
},
|
| 455 |
"nbformat": 4,
|
|
|
|
| 13 |
"id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
|
| 14 |
"metadata": {},
|
| 15 |
"source": [
|
| 16 |
+
"### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization."
|
| 17 |
]
|
| 18 |
},
|
| 19 |
{
|
|
|
|
| 68 |
" \"per_device_train_batch_size\": 12,\n",
|
| 69 |
" \"seed\": 73,\n",
|
| 70 |
"}\n",
|
|
|
|
|
|
|
| 71 |
"cc = Classifier(classifier=\"cell\",\n",
|
| 72 |
" cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n",
|
| 73 |
" filter_data=filter_data_dict,\n",
|
|
|
|
| 76 |
" freeze_layers = 2,\n",
|
| 77 |
" num_crossval_splits = 1,\n",
|
| 78 |
" forward_batch_size=200,\n",
|
|
|
|
| 79 |
" nproc=16)"
|
| 80 |
]
|
| 81 |
},
|
|
|
|
| 125 |
" \"train\": train_ids+eval_ids,\n",
|
| 126 |
" \"test\": test_ids}\n",
|
| 127 |
"\n",
|
| 128 |
+
"# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
|
| 129 |
"cc.prepare_data(input_data_file=\"/path/to/human_dcm_hcm_nf_2048_w_length.dataset\",\n",
|
| 130 |
" output_directory=output_dir,\n",
|
| 131 |
" output_prefix=output_prefix,\n",
|
|
|
|
| 260 |
" \"train\": train_ids,\n",
|
| 261 |
" \"eval\": eval_ids}\n",
|
| 262 |
"\n",
|
| 263 |
+
"# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
|
| 264 |
+
"all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
|
| 265 |
" prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n",
|
| 266 |
" id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
|
| 267 |
" output_directory=output_dir,\n",
|
|
|
|
| 446 |
"name": "python",
|
| 447 |
"nbconvert_exporter": "python",
|
| 448 |
"pygments_lexer": "ipython3",
|
| 449 |
+
"version": "3.11.5"
|
| 450 |
}
|
| 451 |
},
|
| 452 |
"nbformat": 4,
|
examples/distributed_multitask_cell_classification.ipynb
DELETED
|
@@ -1,149 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": null,
|
| 6 |
-
"id": "b3266a7b",
|
| 7 |
-
"metadata": {},
|
| 8 |
-
"outputs": [],
|
| 9 |
-
"source": [
|
| 10 |
-
"import os\n",
|
| 11 |
-
"import torch\n",
|
| 12 |
-
"from geneformer import MTLClassifier"
|
| 13 |
-
]
|
| 14 |
-
},
|
| 15 |
-
{
|
| 16 |
-
"cell_type": "code",
|
| 17 |
-
"execution_count": null,
|
| 18 |
-
"id": "3e12ac9f",
|
| 19 |
-
"metadata": {},
|
| 20 |
-
"outputs": [],
|
| 21 |
-
"source": [
|
| 22 |
-
"# Define paths\n",
|
| 23 |
-
"pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
|
| 24 |
-
"# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
|
| 25 |
-
"train_path = \"/path/to/train/data.dataset\"\n",
|
| 26 |
-
"val_path = \"/path/to/val/data.dataset\"\n",
|
| 27 |
-
"test_path = \"/path/to/test/data.dataset\"\n",
|
| 28 |
-
"results_dir = \"/path/to/results/directory\"\n",
|
| 29 |
-
"model_save_path = \"/path/to/model/save/path\"\n",
|
| 30 |
-
"tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
|
| 31 |
-
"\n",
|
| 32 |
-
"# Define tasks and hyperparameters\n",
|
| 33 |
-
"# task_columns should be a list of column names from your dataset\n",
|
| 34 |
-
"# Each column represents a specific classification task (e.g. cell type, disease state)\n",
|
| 35 |
-
"task_columns = [\"cell_type\", \"disease_state\"] # Example task columns"
|
| 36 |
-
]
|
| 37 |
-
},
|
| 38 |
-
{
|
| 39 |
-
"cell_type": "code",
|
| 40 |
-
"execution_count": null,
|
| 41 |
-
"id": "c9bd7562",
|
| 42 |
-
"metadata": {},
|
| 43 |
-
"outputs": [],
|
| 44 |
-
"source": [
|
| 45 |
-
"# Check GPU environment\n",
|
| 46 |
-
"num_gpus = torch.cuda.device_count()\n",
|
| 47 |
-
"use_distributed = num_gpus > 1\n",
|
| 48 |
-
"print(f\"Number of GPUs detected: {num_gpus}\")\n",
|
| 49 |
-
"print(f\"Using distributed training: {use_distributed}\")\n",
|
| 50 |
-
"\n",
|
| 51 |
-
"# Set environment variables for distributed training when multiple GPUs are available\n",
|
| 52 |
-
"if use_distributed:\n",
|
| 53 |
-
" os.environ[\"MASTER_ADDR\"] = \"localhost\" # hostname\n",
|
| 54 |
-
" os.environ[\"MASTER_PORT\"] = \"12355\" # Choose an available port\n",
|
| 55 |
-
" print(\"Distributed environment variables set.\")"
|
| 56 |
-
]
|
| 57 |
-
},
|
| 58 |
-
{
|
| 59 |
-
"cell_type": "code",
|
| 60 |
-
"execution_count": null,
|
| 61 |
-
"id": "b6ff3618",
|
| 62 |
-
"metadata": {},
|
| 63 |
-
"outputs": [],
|
| 64 |
-
"source": [
|
| 65 |
-
"#Define Hyperparameters for Optimization\n",
|
| 66 |
-
"hyperparameters = {\n",
|
| 67 |
-
" \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
|
| 68 |
-
" \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
|
| 69 |
-
" \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
|
| 70 |
-
" \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
|
| 71 |
-
" \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
|
| 72 |
-
" \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0},\n",
|
| 73 |
-
"}"
|
| 74 |
-
]
|
| 75 |
-
},
|
| 76 |
-
{
|
| 77 |
-
"cell_type": "code",
|
| 78 |
-
"execution_count": null,
|
| 79 |
-
"id": "f665c5a7",
|
| 80 |
-
"metadata": {},
|
| 81 |
-
"outputs": [],
|
| 82 |
-
"source": [
|
| 83 |
-
"mc = MTLClassifier(\n",
|
| 84 |
-
" task_columns=task_columns, # Our defined classification tasks\n",
|
| 85 |
-
" study_name=\"MTLClassifier_distributed\",\n",
|
| 86 |
-
" pretrained_path=pretrained_path,\n",
|
| 87 |
-
" train_path=train_path,\n",
|
| 88 |
-
" val_path=val_path,\n",
|
| 89 |
-
" test_path=test_path,\n",
|
| 90 |
-
" model_save_path=model_save_path,\n",
|
| 91 |
-
" results_dir=results_dir,\n",
|
| 92 |
-
" tensorboard_log_dir=tensorboard_log_dir,\n",
|
| 93 |
-
" hyperparameters=hyperparameters,\n",
|
| 94 |
-
" # Distributed training parameters\n",
|
| 95 |
-
" distributed_training=use_distributed, # Enable distributed training if multiple GPUs available\n",
|
| 96 |
-
" master_addr=\"localhost\" if use_distributed else None,\n",
|
| 97 |
-
" master_port=\"12355\" if use_distributed else None,\n",
|
| 98 |
-
" # Other training parameters\n",
|
| 99 |
-
" n_trials=15, # Number of trials for hyperparameter optimization\n",
|
| 100 |
-
" epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
|
| 101 |
-
" batch_size=8, # Adjust based on available GPU memory\n",
|
| 102 |
-
" gradient_accumulation_steps=4, # Accumulate gradients over multiple steps\n",
|
| 103 |
-
" gradient_clipping=True, # Enable gradient clipping for stability\n",
|
| 104 |
-
" max_grad_norm=1.0, # Set maximum gradient norm\n",
|
| 105 |
-
" seed=42\n",
|
| 106 |
-
")"
|
| 107 |
-
]
|
| 108 |
-
},
|
| 109 |
-
{
|
| 110 |
-
"cell_type": "code",
|
| 111 |
-
"execution_count": null,
|
| 112 |
-
"id": "f69f7b6a",
|
| 113 |
-
"metadata": {},
|
| 114 |
-
"outputs": [],
|
| 115 |
-
"source": [
|
| 116 |
-
"# Run Hyperparameter Optimization with Distributed Training\n",
|
| 117 |
-
"if __name__ == \"__main__\":\n",
|
| 118 |
-
" # This guard is required for distributed training to prevent\n",
|
| 119 |
-
" # infinite subprocess spawning when using torch.multiprocessing\n",
|
| 120 |
-
" mc.run_optuna_study()"
|
| 121 |
-
]
|
| 122 |
-
},
|
| 123 |
-
{
|
| 124 |
-
"cell_type": "code",
|
| 125 |
-
"execution_count": null,
|
| 126 |
-
"id": "3affd5dd",
|
| 127 |
-
"metadata": {},
|
| 128 |
-
"outputs": [],
|
| 129 |
-
"source": [
|
| 130 |
-
"# Evaluate the Model on Test Data\n",
|
| 131 |
-
"if __name__ == \"__main__\":\n",
|
| 132 |
-
" mc.load_and_evaluate_test_model()"
|
| 133 |
-
]
|
| 134 |
-
}
|
| 135 |
-
],
|
| 136 |
-
"metadata": {
|
| 137 |
-
"kernelspec": {
|
| 138 |
-
"display_name": "bio",
|
| 139 |
-
"language": "python",
|
| 140 |
-
"name": "python3"
|
| 141 |
-
},
|
| 142 |
-
"language_info": {
|
| 143 |
-
"name": "python",
|
| 144 |
-
"version": "3.12.8"
|
| 145 |
-
}
|
| 146 |
-
},
|
| 147 |
-
"nbformat": 4,
|
| 148 |
-
"nbformat_minor": 5
|
| 149 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/extract_and_plot_cell_embeddings.ipynb
CHANGED
|
@@ -18,7 +18,6 @@
|
|
| 18 |
"outputs": [],
|
| 19 |
"source": [
|
| 20 |
"# initiate EmbExtractor\n",
|
| 21 |
-
"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
|
| 22 |
"embex = EmbExtractor(model_type=\"CellClassifier\",\n",
|
| 23 |
" num_classes=3,\n",
|
| 24 |
" filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
|
|
@@ -27,13 +26,12 @@
|
|
| 27 |
" emb_label=[\"disease\",\"cell_type\"],\n",
|
| 28 |
" labels_to_plot=[\"disease\"],\n",
|
| 29 |
" forward_batch_size=200,\n",
|
| 30 |
-
" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
|
| 31 |
" nproc=16)\n",
|
| 32 |
"\n",
|
| 33 |
"# extracts embedding from input data\n",
|
| 34 |
"# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
|
| 35 |
-
"# example dataset
|
| 36 |
-
"embs = embex.extract_embs(\"../fine_tuned_models/
|
| 37 |
" \"path/to/input_data/\",\n",
|
| 38 |
" \"path/to/output_directory/\",\n",
|
| 39 |
" \"output_prefix\")\n"
|
|
@@ -131,7 +129,7 @@
|
|
| 131 |
"name": "python",
|
| 132 |
"nbconvert_exporter": "python",
|
| 133 |
"pygments_lexer": "ipython3",
|
| 134 |
-
"version": "3.
|
| 135 |
}
|
| 136 |
},
|
| 137 |
"nbformat": 4,
|
|
|
|
| 18 |
"outputs": [],
|
| 19 |
"source": [
|
| 20 |
"# initiate EmbExtractor\n",
|
|
|
|
| 21 |
"embex = EmbExtractor(model_type=\"CellClassifier\",\n",
|
| 22 |
" num_classes=3,\n",
|
| 23 |
" filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
|
|
|
|
| 26 |
" emb_label=[\"disease\",\"cell_type\"],\n",
|
| 27 |
" labels_to_plot=[\"disease\"],\n",
|
| 28 |
" forward_batch_size=200,\n",
|
|
|
|
| 29 |
" nproc=16)\n",
|
| 30 |
"\n",
|
| 31 |
"# extracts embedding from input data\n",
|
| 32 |
"# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
|
| 33 |
+
"# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
|
| 34 |
+
"embs = embex.extract_embs(\"../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224\",\n",
|
| 35 |
" \"path/to/input_data/\",\n",
|
| 36 |
" \"path/to/output_directory/\",\n",
|
| 37 |
" \"output_prefix\")\n"
|
|
|
|
| 129 |
"name": "python",
|
| 130 |
"nbconvert_exporter": "python",
|
| 131 |
"pygments_lexer": "ipython3",
|
| 132 |
+
"version": "3.11.5"
|
| 133 |
}
|
| 134 |
},
|
| 135 |
"nbformat": 4,
|
examples/gene_classification.ipynb
CHANGED
|
@@ -13,7 +13,7 @@
|
|
| 13 |
"id": "79539e95-2c9c-4162-835c-f0d158abb15d",
|
| 14 |
"metadata": {},
|
| 15 |
"source": [
|
| 16 |
-
"### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications.
|
| 17 |
]
|
| 18 |
},
|
| 19 |
{
|
|
@@ -71,14 +71,12 @@
|
|
| 71 |
}
|
| 72 |
],
|
| 73 |
"source": [
|
| 74 |
-
"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
|
| 75 |
"cc = Classifier(classifier=\"gene\",\n",
|
| 76 |
" gene_class_dict = gene_class_dict,\n",
|
| 77 |
" max_ncells = 10_000,\n",
|
| 78 |
" freeze_layers = 4,\n",
|
| 79 |
" num_crossval_splits = 5,\n",
|
| 80 |
" forward_batch_size=200,\n",
|
| 81 |
-
" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
|
| 82 |
" nproc=16)"
|
| 83 |
]
|
| 84 |
},
|
|
@@ -104,7 +102,7 @@
|
|
| 104 |
}
|
| 105 |
],
|
| 106 |
"source": [
|
| 107 |
-
"# Example input_data_file
|
| 108 |
"cc.prepare_data(input_data_file=\"/path/to/gc-30M_sample50k.dataset\",\n",
|
| 109 |
" output_directory=output_dir,\n",
|
| 110 |
" output_prefix=output_prefix)"
|
|
@@ -842,8 +840,8 @@
|
|
| 842 |
}
|
| 843 |
],
|
| 844 |
"source": [
|
| 845 |
-
"#
|
| 846 |
-
"all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\"
|
| 847 |
" prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
|
| 848 |
" id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
|
| 849 |
" output_directory=output_dir,\n",
|
|
@@ -1065,14 +1063,12 @@
|
|
| 1065 |
}
|
| 1066 |
],
|
| 1067 |
"source": [
|
| 1068 |
-
"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
|
| 1069 |
"cc = Classifier(classifier=\"gene\",\n",
|
| 1070 |
" gene_class_dict = gene_class_dict,\n",
|
| 1071 |
" max_ncells = 10_000,\n",
|
| 1072 |
" freeze_layers = 4,\n",
|
| 1073 |
" num_crossval_splits = 0,\n",
|
| 1074 |
" forward_batch_size=200,\n",
|
| 1075 |
-
" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
|
| 1076 |
" nproc=16)"
|
| 1077 |
]
|
| 1078 |
},
|
|
@@ -1219,8 +1215,8 @@
|
|
| 1219 |
}
|
| 1220 |
],
|
| 1221 |
"source": [
|
| 1222 |
-
"#
|
| 1223 |
-
"trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\"
|
| 1224 |
" prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
|
| 1225 |
" id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
|
| 1226 |
" output_directory=output_dir,\n",
|
|
@@ -1244,7 +1240,7 @@
|
|
| 1244 |
"name": "python",
|
| 1245 |
"nbconvert_exporter": "python",
|
| 1246 |
"pygments_lexer": "ipython3",
|
| 1247 |
-
"version": "3.
|
| 1248 |
}
|
| 1249 |
},
|
| 1250 |
"nbformat": 4,
|
|
|
|
| 13 |
"id": "79539e95-2c9c-4162-835c-f0d158abb15d",
|
| 14 |
"metadata": {},
|
| 15 |
"source": [
|
| 16 |
+
"### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications."
|
| 17 |
]
|
| 18 |
},
|
| 19 |
{
|
|
|
|
| 71 |
}
|
| 72 |
],
|
| 73 |
"source": [
|
|
|
|
| 74 |
"cc = Classifier(classifier=\"gene\",\n",
|
| 75 |
" gene_class_dict = gene_class_dict,\n",
|
| 76 |
" max_ncells = 10_000,\n",
|
| 77 |
" freeze_layers = 4,\n",
|
| 78 |
" num_crossval_splits = 5,\n",
|
| 79 |
" forward_batch_size=200,\n",
|
|
|
|
| 80 |
" nproc=16)"
|
| 81 |
]
|
| 82 |
},
|
|
|
|
| 102 |
}
|
| 103 |
],
|
| 104 |
"source": [
|
| 105 |
+
"# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification/dosage_sensitive_tfs/gc-30M_sample50k.dataset\n",
|
| 106 |
"cc.prepare_data(input_data_file=\"/path/to/gc-30M_sample50k.dataset\",\n",
|
| 107 |
" output_directory=output_dir,\n",
|
| 108 |
" output_prefix=output_prefix)"
|
|
|
|
| 840 |
}
|
| 841 |
],
|
| 842 |
"source": [
|
| 843 |
+
"# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
|
| 844 |
+
"all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
|
| 845 |
" prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
|
| 846 |
" id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
|
| 847 |
" output_directory=output_dir,\n",
|
|
|
|
| 1063 |
}
|
| 1064 |
],
|
| 1065 |
"source": [
|
|
|
|
| 1066 |
"cc = Classifier(classifier=\"gene\",\n",
|
| 1067 |
" gene_class_dict = gene_class_dict,\n",
|
| 1068 |
" max_ncells = 10_000,\n",
|
| 1069 |
" freeze_layers = 4,\n",
|
| 1070 |
" num_crossval_splits = 0,\n",
|
| 1071 |
" forward_batch_size=200,\n",
|
|
|
|
| 1072 |
" nproc=16)"
|
| 1073 |
]
|
| 1074 |
},
|
|
|
|
| 1215 |
}
|
| 1216 |
],
|
| 1217 |
"source": [
|
| 1218 |
+
"# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
|
| 1219 |
+
"trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\",\n",
|
| 1220 |
" prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
|
| 1221 |
" id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
|
| 1222 |
" output_directory=output_dir,\n",
|
|
|
|
| 1240 |
"name": "python",
|
| 1241 |
"nbconvert_exporter": "python",
|
| 1242 |
"pygments_lexer": "ipython3",
|
| 1243 |
+
"version": "3.11.5"
|
| 1244 |
}
|
| 1245 |
},
|
| 1246 |
"nbformat": 4,
|
examples/in_silico_perturbation.ipynb
CHANGED
|
@@ -39,19 +39,17 @@
|
|
| 39 |
"\n",
|
| 40 |
"filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
|
| 41 |
"\n",
|
| 42 |
-
"
|
| 43 |
-
"embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
|
| 44 |
" num_classes=3,\n",
|
| 45 |
" filter_data=filter_data_dict,\n",
|
| 46 |
" max_ncells=1000,\n",
|
| 47 |
" emb_layer=0,\n",
|
| 48 |
" summary_stat=\"exact_mean\",\n",
|
| 49 |
" forward_batch_size=256,\n",
|
| 50 |
-
" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
|
| 51 |
" nproc=16)\n",
|
| 52 |
"\n",
|
| 53 |
"state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
|
| 54 |
-
" \"
|
| 55 |
" \"path/to/input_data\",\n",
|
| 56 |
" \"path/to/output_directory\",\n",
|
| 57 |
" \"output_prefix\")"
|
|
@@ -66,15 +64,14 @@
|
|
| 66 |
},
|
| 67 |
"outputs": [],
|
| 68 |
"source": [
|
| 69 |
-
"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
|
| 70 |
"isp = InSilicoPerturber(perturb_type=\"delete\",\n",
|
| 71 |
" perturb_rank_shift=None,\n",
|
| 72 |
" genes_to_perturb=\"all\",\n",
|
| 73 |
" combos=0,\n",
|
| 74 |
" anchor_gene=None,\n",
|
| 75 |
-
" model_type=\"CellClassifier\"
|
| 76 |
" num_classes=3,\n",
|
| 77 |
-
" emb_mode=\"cell\"
|
| 78 |
" cell_emb_style=\"mean_pool\",\n",
|
| 79 |
" filter_data=filter_data_dict,\n",
|
| 80 |
" cell_states_to_model=cell_states_to_model,\n",
|
|
@@ -82,7 +79,6 @@
|
|
| 82 |
" max_ncells=2000,\n",
|
| 83 |
" emb_layer=0,\n",
|
| 84 |
" forward_batch_size=400,\n",
|
| 85 |
-
" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
|
| 86 |
" nproc=16)"
|
| 87 |
]
|
| 88 |
},
|
|
@@ -94,10 +90,9 @@
|
|
| 94 |
"outputs": [],
|
| 95 |
"source": [
|
| 96 |
"# outputs intermediate files from in silico perturbation\n",
|
| 97 |
-
"\n",
|
| 98 |
-
"isp.perturb_data(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n",
|
| 99 |
" \"path/to/input_data\",\n",
|
| 100 |
-
" \"path/to/
|
| 101 |
" \"output_prefix\")"
|
| 102 |
]
|
| 103 |
},
|
|
@@ -108,13 +103,11 @@
|
|
| 108 |
"metadata": {},
|
| 109 |
"outputs": [],
|
| 110 |
"source": [
|
| 111 |
-
"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
|
| 112 |
"ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
|
| 113 |
" genes_perturbed=\"all\",\n",
|
| 114 |
" combos=0,\n",
|
| 115 |
" anchor_gene=None,\n",
|
| 116 |
-
" cell_states_to_model=cell_states_to_model
|
| 117 |
-
" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL SINCE V1 WAS USED FOR IN SILICO PERTURBATION ABOVE)"
|
| 118 |
]
|
| 119 |
},
|
| 120 |
{
|
|
@@ -125,9 +118,9 @@
|
|
| 125 |
"outputs": [],
|
| 126 |
"source": [
|
| 127 |
"# extracts data from intermediate files and processes stats to output in final .csv\n",
|
| 128 |
-
"ispstats.get_stats(\"path/to/
|
| 129 |
" None,\n",
|
| 130 |
-
" \"path/to/
|
| 131 |
" \"output_prefix\")"
|
| 132 |
]
|
| 133 |
}
|
|
@@ -148,7 +141,7 @@
|
|
| 148 |
"name": "python",
|
| 149 |
"nbconvert_exporter": "python",
|
| 150 |
"pygments_lexer": "ipython3",
|
| 151 |
-
"version": "3.10.
|
| 152 |
}
|
| 153 |
},
|
| 154 |
"nbformat": 4,
|
|
|
|
| 39 |
"\n",
|
| 40 |
"filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
|
| 41 |
"\n",
|
| 42 |
+
"embex = EmbExtractor(model_type=\"CellClassifier\",\n",
|
|
|
|
| 43 |
" num_classes=3,\n",
|
| 44 |
" filter_data=filter_data_dict,\n",
|
| 45 |
" max_ncells=1000,\n",
|
| 46 |
" emb_layer=0,\n",
|
| 47 |
" summary_stat=\"exact_mean\",\n",
|
| 48 |
" forward_batch_size=256,\n",
|
|
|
|
| 49 |
" nproc=16)\n",
|
| 50 |
"\n",
|
| 51 |
"state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
|
| 52 |
+
" \"path/to/model\",\n",
|
| 53 |
" \"path/to/input_data\",\n",
|
| 54 |
" \"path/to/output_directory\",\n",
|
| 55 |
" \"output_prefix\")"
|
|
|
|
| 64 |
},
|
| 65 |
"outputs": [],
|
| 66 |
"source": [
|
|
|
|
| 67 |
"isp = InSilicoPerturber(perturb_type=\"delete\",\n",
|
| 68 |
" perturb_rank_shift=None,\n",
|
| 69 |
" genes_to_perturb=\"all\",\n",
|
| 70 |
" combos=0,\n",
|
| 71 |
" anchor_gene=None,\n",
|
| 72 |
+
" model_type=\"CellClassifier\",\n",
|
| 73 |
" num_classes=3,\n",
|
| 74 |
+
" emb_mode=\"cell\",\n",
|
| 75 |
" cell_emb_style=\"mean_pool\",\n",
|
| 76 |
" filter_data=filter_data_dict,\n",
|
| 77 |
" cell_states_to_model=cell_states_to_model,\n",
|
|
|
|
| 79 |
" max_ncells=2000,\n",
|
| 80 |
" emb_layer=0,\n",
|
| 81 |
" forward_batch_size=400,\n",
|
|
|
|
| 82 |
" nproc=16)"
|
| 83 |
]
|
| 84 |
},
|
|
|
|
| 90 |
"outputs": [],
|
| 91 |
"source": [
|
| 92 |
"# outputs intermediate files from in silico perturbation\n",
|
| 93 |
+
"isp.perturb_data(\"path/to/model\",\n",
|
|
|
|
| 94 |
" \"path/to/input_data\",\n",
|
| 95 |
+
" \"path/to/output_directory\",\n",
|
| 96 |
" \"output_prefix\")"
|
| 97 |
]
|
| 98 |
},
|
|
|
|
| 103 |
"metadata": {},
|
| 104 |
"outputs": [],
|
| 105 |
"source": [
|
|
|
|
| 106 |
"ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
|
| 107 |
" genes_perturbed=\"all\",\n",
|
| 108 |
" combos=0,\n",
|
| 109 |
" anchor_gene=None,\n",
|
| 110 |
+
" cell_states_to_model=cell_states_to_model)"
|
|
|
|
| 111 |
]
|
| 112 |
},
|
| 113 |
{
|
|
|
|
| 118 |
"outputs": [],
|
| 119 |
"source": [
|
| 120 |
"# extracts data from intermediate files and processes stats to output in final .csv\n",
|
| 121 |
+
"ispstats.get_stats(\"path/to/input_data\",\n",
|
| 122 |
" None,\n",
|
| 123 |
+
" \"path/to/output_directory\",\n",
|
| 124 |
" \"output_prefix\")"
|
| 125 |
]
|
| 126 |
}
|
|
|
|
| 141 |
"name": "python",
|
| 142 |
"nbconvert_exporter": "python",
|
| 143 |
"pygments_lexer": "ipython3",
|
| 144 |
+
"version": "3.10.11"
|
| 145 |
}
|
| 146 |
},
|
| 147 |
"nbformat": 4,
|
examples/multitask_cell_classification.ipynb
CHANGED
|
@@ -286,7 +286,7 @@
|
|
| 286 |
" filter_data_dict=filter_data_dict,\n",
|
| 287 |
" max_ncells=1000, # Number of cells to extract embeddings for\n",
|
| 288 |
" emb_layer=0, # Use the second to last layer\n",
|
| 289 |
-
" emb_mode = \"cls\"
|
| 290 |
" summary_stat=\"exact_mean\",\n",
|
| 291 |
" forward_batch_size=8, # Adjust based on available GPU memory\n",
|
| 292 |
" nproc=4\n",
|
|
@@ -324,7 +324,7 @@
|
|
| 324 |
" perturb_type=perturb_type,\n",
|
| 325 |
" genes_to_perturb=\"all\", # Perturb all genes\n",
|
| 326 |
" model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
|
| 327 |
-
" emb_mode=\"cls\", # Use CLS token embedding
|
| 328 |
" cell_states_to_model=cell_states_to_model,\n",
|
| 329 |
" state_embs_dict=state_embs_dict,\n",
|
| 330 |
" max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
|
|
@@ -412,7 +412,7 @@
|
|
| 412 |
"name": "python",
|
| 413 |
"nbconvert_exporter": "python",
|
| 414 |
"pygments_lexer": "ipython3",
|
| 415 |
-
"version": "3.
|
| 416 |
}
|
| 417 |
},
|
| 418 |
"nbformat": 4,
|
|
|
|
| 286 |
" filter_data_dict=filter_data_dict,\n",
|
| 287 |
" max_ncells=1000, # Number of cells to extract embeddings for\n",
|
| 288 |
" emb_layer=0, # Use the second to last layer\n",
|
| 289 |
+
" emb_mode = \"cls\",\n",
|
| 290 |
" summary_stat=\"exact_mean\",\n",
|
| 291 |
" forward_batch_size=8, # Adjust based on available GPU memory\n",
|
| 292 |
" nproc=4\n",
|
|
|
|
| 324 |
" perturb_type=perturb_type,\n",
|
| 325 |
" genes_to_perturb=\"all\", # Perturb all genes\n",
|
| 326 |
" model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
|
| 327 |
+
" emb_mode=\"cls\", # Use CLS token embedding\n",
|
| 328 |
" cell_states_to_model=cell_states_to_model,\n",
|
| 329 |
" state_embs_dict=state_embs_dict,\n",
|
| 330 |
" max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
|
|
|
|
| 412 |
"name": "python",
|
| 413 |
"nbconvert_exporter": "python",
|
| 414 |
"pygments_lexer": "ipython3",
|
| 415 |
+
"version": "3.11.5"
|
| 416 |
}
|
| 417 |
},
|
| 418 |
"nbformat": 4,
|
examples/tokenizing_scRNAseq_data.ipynb
CHANGED
|
@@ -12,7 +12,7 @@
|
|
| 12 |
},
|
| 13 |
{
|
| 14 |
"cell_type": "markdown",
|
| 15 |
-
"id": "
|
| 16 |
"metadata": {},
|
| 17 |
"source": [
|
| 18 |
"#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n",
|
|
@@ -25,17 +25,11 @@
|
|
| 25 |
"\n",
|
| 26 |
"#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
|
| 27 |
"\n",
|
| 28 |
-
"#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
"id": "32c69493-4e5a-4b07-8dc1-958ff2ee7d0b",
|
| 34 |
-
"metadata": {},
|
| 35 |
-
"source": [
|
| 36 |
-
"**********************************************************************************************************\n",
|
| 37 |
-
"#### OF NOTE: Please ensure the correct token dictionary, gene median file, special token setting, and model input size is used for the correct model version.\n",
|
| 38 |
-
"#### Current defaults are for V2 model series. To auto-select the correct settings for V1, set model_version argument to \"V1\"."
|
| 39 |
]
|
| 40 |
},
|
| 41 |
{
|
|
@@ -55,7 +49,7 @@
|
|
| 55 |
"metadata": {},
|
| 56 |
"outputs": [],
|
| 57 |
"source": [
|
| 58 |
-
"tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)
|
| 59 |
"tk.tokenize_data(\"loom_data_directory\", \n",
|
| 60 |
" \"output_directory\", \n",
|
| 61 |
" \"output_prefix\", \n",
|
|
@@ -79,7 +73,7 @@
|
|
| 79 |
"name": "python",
|
| 80 |
"nbconvert_exporter": "python",
|
| 81 |
"pygments_lexer": "ipython3",
|
| 82 |
-
"version": "3.10.
|
| 83 |
}
|
| 84 |
},
|
| 85 |
"nbformat": 4,
|
|
|
|
| 12 |
},
|
| 13 |
{
|
| 14 |
"cell_type": "markdown",
|
| 15 |
+
"id": "350e6252-b783-494b-9767-f087eb868a15",
|
| 16 |
"metadata": {},
|
| 17 |
"source": [
|
| 18 |
"#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n",
|
|
|
|
| 25 |
"\n",
|
| 26 |
"#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
|
| 27 |
"\n",
|
| 28 |
+
"#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"#### The 95M model series also require the special_token argument to be set to True and model_input_size to be 4096."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
]
|
| 34 |
},
|
| 35 |
{
|
|
|
|
| 49 |
"metadata": {},
|
| 50 |
"outputs": [],
|
| 51 |
"source": [
|
| 52 |
+
"tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n",
|
| 53 |
"tk.tokenize_data(\"loom_data_directory\", \n",
|
| 54 |
" \"output_directory\", \n",
|
| 55 |
" \"output_prefix\", \n",
|
|
|
|
| 73 |
"name": "python",
|
| 74 |
"nbconvert_exporter": "python",
|
| 75 |
"pygments_lexer": "ipython3",
|
| 76 |
+
"version": "3.10.11"
|
| 77 |
}
|
| 78 |
},
|
| 79 |
"nbformat": 4,
|
{Geneformer-V2-104M → fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522}/config.json
RENAMED
|
@@ -2,22 +2,22 @@
|
|
| 2 |
"architectures": [
|
| 3 |
"BertForMaskedLM"
|
| 4 |
],
|
| 5 |
-
"attention_probs_dropout_prob": 0.
|
| 6 |
"classifier_dropout": null,
|
| 7 |
"hidden_act": "relu",
|
| 8 |
-
"hidden_dropout_prob": 0.
|
| 9 |
-
"hidden_size":
|
| 10 |
"initializer_range": 0.02,
|
| 11 |
-
"intermediate_size":
|
| 12 |
"layer_norm_eps": 1e-12,
|
| 13 |
"max_position_embeddings": 4096,
|
| 14 |
"model_type": "bert",
|
| 15 |
-
"num_attention_heads":
|
| 16 |
"num_hidden_layers": 12,
|
| 17 |
"pad_token_id": 0,
|
| 18 |
"position_embedding_type": "absolute",
|
| 19 |
"torch_dtype": "float32",
|
| 20 |
-
"transformers_version": "4.
|
| 21 |
"type_vocab_size": 2,
|
| 22 |
"use_cache": true,
|
| 23 |
"vocab_size": 20275
|
|
|
|
| 2 |
"architectures": [
|
| 3 |
"BertForMaskedLM"
|
| 4 |
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.02,
|
| 6 |
"classifier_dropout": null,
|
| 7 |
"hidden_act": "relu",
|
| 8 |
+
"hidden_dropout_prob": 0.02,
|
| 9 |
+
"hidden_size": 512,
|
| 10 |
"initializer_range": 0.02,
|
| 11 |
+
"intermediate_size": 1024,
|
| 12 |
"layer_norm_eps": 1e-12,
|
| 13 |
"max_position_embeddings": 4096,
|
| 14 |
"model_type": "bert",
|
| 15 |
+
"num_attention_heads": 8,
|
| 16 |
"num_hidden_layers": 12,
|
| 17 |
"pad_token_id": 0,
|
| 18 |
"position_embedding_type": "absolute",
|
| 19 |
"torch_dtype": "float32",
|
| 20 |
+
"transformers_version": "4.37.2",
|
| 21 |
"type_vocab_size": 2,
|
| 22 |
"use_cache": true,
|
| 23 |
"vocab_size": 20275
|
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4
|
| 3 |
+
size 152363342
|
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/config.json
RENAMED
|
File without changes
|
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/optimizer.pt
RENAMED
|
File without changes
|
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin
RENAMED
|
File without changes
|
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/rng_state.pth
RENAMED
|
File without changes
|
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/scheduler.pt
RENAMED
|
File without changes
|
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/trainer_state.json
RENAMED
|
File without changes
|
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/training_args.bin
RENAMED
|
File without changes
|
geneformer/__init__.py
CHANGED
|
@@ -4,15 +4,10 @@ from pathlib import Path
|
|
| 4 |
|
| 5 |
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
|
| 6 |
|
| 7 |
-
GENE_MEDIAN_FILE = Path(__file__).parent / "
|
| 8 |
-
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "
|
| 9 |
-
ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "
|
| 10 |
-
ENSEMBL_MAPPING_FILE = Path(__file__).parent / "
|
| 11 |
-
|
| 12 |
-
GENE_MEDIAN_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl"
|
| 13 |
-
TOKEN_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/token_dictionary_gc30M.pkl"
|
| 14 |
-
ENSEMBL_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl"
|
| 15 |
-
ENSEMBL_MAPPING_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl"
|
| 16 |
|
| 17 |
from . import (
|
| 18 |
collator_for_classification,
|
|
|
|
| 4 |
|
| 5 |
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
|
| 6 |
|
| 7 |
+
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
|
| 8 |
+
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
|
| 9 |
+
ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
|
| 10 |
+
ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
from . import (
|
| 13 |
collator_for_classification,
|
geneformer/classifier.py
CHANGED
|
@@ -48,13 +48,11 @@ import logging
|
|
| 48 |
import os
|
| 49 |
import pickle
|
| 50 |
import subprocess
|
| 51 |
-
from packaging.version import parse
|
| 52 |
from pathlib import Path
|
| 53 |
|
| 54 |
import numpy as np
|
| 55 |
import pandas as pd
|
| 56 |
import seaborn as sns
|
| 57 |
-
import transformers
|
| 58 |
from tqdm.auto import tqdm, trange
|
| 59 |
from transformers import Trainer
|
| 60 |
from transformers.training_args import TrainingArguments
|
|
@@ -73,7 +71,6 @@ sns.set()
|
|
| 73 |
|
| 74 |
logger = logging.getLogger(__name__)
|
| 75 |
|
| 76 |
-
transformers_version = parse(transformers.__version__)
|
| 77 |
|
| 78 |
class Classifier:
|
| 79 |
valid_option_dict = {
|
|
@@ -92,7 +89,6 @@ class Classifier:
|
|
| 92 |
"no_eval": {bool},
|
| 93 |
"stratify_splits_col": {None, str},
|
| 94 |
"forward_batch_size": {int},
|
| 95 |
-
"model_version": {"V1", "V2"},
|
| 96 |
"token_dictionary_file": {None, str},
|
| 97 |
"nproc": {int},
|
| 98 |
"ngpu": {int},
|
|
@@ -116,7 +112,6 @@ class Classifier:
|
|
| 116 |
stratify_splits_col=None,
|
| 117 |
no_eval=False,
|
| 118 |
forward_batch_size=100,
|
| 119 |
-
model_version="V2",
|
| 120 |
token_dictionary_file=None,
|
| 121 |
nproc=4,
|
| 122 |
ngpu=1,
|
|
@@ -193,9 +188,6 @@ class Classifier:
|
|
| 193 |
| Otherwise, will perform eval during training.
|
| 194 |
forward_batch_size : int
|
| 195 |
| Batch size for forward pass (for evaluation, not training).
|
| 196 |
-
model_version : str
|
| 197 |
-
| To auto-select settings for model version other than current default.
|
| 198 |
-
| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
|
| 199 |
token_dictionary_file : None, str
|
| 200 |
| Default is to use token dictionary file from Geneformer
|
| 201 |
| Otherwise, will load custom gene token dictionary.
|
|
@@ -230,16 +222,14 @@ class Classifier:
|
|
| 230 |
self.stratify_splits_col = stratify_splits_col
|
| 231 |
self.no_eval = no_eval
|
| 232 |
self.forward_batch_size = forward_batch_size
|
| 233 |
-
self.model_version = model_version
|
| 234 |
self.token_dictionary_file = token_dictionary_file
|
| 235 |
self.nproc = nproc
|
| 236 |
self.ngpu = ngpu
|
| 237 |
-
|
| 238 |
if self.training_args is None:
|
| 239 |
logger.warning(
|
| 240 |
"Hyperparameter tuning is highly recommended for optimal results. "
|
| 241 |
-
"No training_args provided; using default hyperparameters.
|
| 242 |
-
"Please note: these defaults are not recommended to be used uniformly across tasks."
|
| 243 |
)
|
| 244 |
|
| 245 |
self.validate_options()
|
|
@@ -254,10 +244,7 @@ class Classifier:
|
|
| 254 |
] = self.cell_state_dict["states"]
|
| 255 |
|
| 256 |
# load token dictionary (Ensembl IDs:token)
|
| 257 |
-
if self.
|
| 258 |
-
from . import TOKEN_DICTIONARY_FILE_30M
|
| 259 |
-
self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
| 260 |
-
elif self.token_dictionary_file is None:
|
| 261 |
self.token_dictionary_file = TOKEN_DICTIONARY_FILE
|
| 262 |
with open(self.token_dictionary_file, "rb") as f:
|
| 263 |
self.gene_token_dict = pickle.load(f)
|
|
@@ -367,7 +354,6 @@ class Classifier:
|
|
| 367 |
attr_to_balance=None,
|
| 368 |
max_trials=100,
|
| 369 |
pval_threshold=0.1,
|
| 370 |
-
id_class_dict_path=None,
|
| 371 |
):
|
| 372 |
"""
|
| 373 |
Prepare data for cell state or gene classification.
|
|
@@ -410,10 +396,6 @@ class Classifier:
|
|
| 410 |
pval_threshold : None, float
|
| 411 |
| P-value threshold to use for attribute balancing across splits
|
| 412 |
| E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
|
| 413 |
-
id_class_dict_path : Path
|
| 414 |
-
| Path to *_id_class_dict.pkl from prior run of prepare_data to reuse for labeling new data
|
| 415 |
-
| Dictionary with keys being numeric class labels and values being original dataset class labels
|
| 416 |
-
| Note: only available for CellClassifiers
|
| 417 |
"""
|
| 418 |
|
| 419 |
if test_size is None:
|
|
@@ -455,25 +437,14 @@ class Classifier:
|
|
| 455 |
)
|
| 456 |
# rename cell state column to "label"
|
| 457 |
data = cu.rename_cols(data, self.cell_state_dict["state_key"])
|
| 458 |
-
|
| 459 |
-
# convert classes to numerical labels and save as id_class_dict
|
| 460 |
-
if id_class_dict_path is not None:
|
| 461 |
-
with open(id_class_dict_path,"rb") as fp:
|
| 462 |
-
id_class_dict = pickle.load(fp)
|
| 463 |
-
else:
|
| 464 |
-
id_class_dict = None
|
| 465 |
-
data, id_class_dict = cu.label_classes(
|
| 466 |
-
self.classifier, data, self.cell_state_dict, self.nproc, id_class_dict,
|
| 467 |
-
)
|
| 468 |
|
| 469 |
-
elif self.classifier == "gene":
|
| 470 |
# convert classes to numerical labels and save as id_class_dict
|
| 471 |
# of note, will label all genes in gene_class_dict
|
| 472 |
# if (cross-)validating, genes will be relabeled in column "labels" for each split
|
| 473 |
# at the time of training with Classifier.validate
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
|
| 478 |
# save id_class_dict for future reference
|
| 479 |
id_class_output_path = (
|
|
@@ -801,7 +772,7 @@ class Classifier:
|
|
| 801 |
# 5-fold cross-validate
|
| 802 |
num_cells = len(data)
|
| 803 |
fifth_cells = int(np.floor(num_cells * 0.2))
|
| 804 |
-
num_eval =
|
| 805 |
start = i * fifth_cells
|
| 806 |
end = start + num_eval
|
| 807 |
eval_indices = [j for j in range(start, end)]
|
|
@@ -1085,8 +1056,6 @@ class Classifier:
|
|
| 1085 |
if eval_data is None:
|
| 1086 |
def_training_args["evaluation_strategy"] = "no"
|
| 1087 |
def_training_args["load_best_model_at_end"] = False
|
| 1088 |
-
if transformers_version >= parse("4.46"):
|
| 1089 |
-
def_training_args["eval_strategy"] = def_training_args.pop("evaluation_strategy")
|
| 1090 |
def_training_args.update(
|
| 1091 |
{"save_strategy": "epoch", "save_total_limit": 1}
|
| 1092 |
) # only save last model for each run
|
|
@@ -1258,8 +1227,6 @@ class Classifier:
|
|
| 1258 |
if eval_data is None:
|
| 1259 |
def_training_args["evaluation_strategy"] = "no"
|
| 1260 |
def_training_args["load_best_model_at_end"] = False
|
| 1261 |
-
if transformers_version >= parse("4.46"):
|
| 1262 |
-
def_training_args["eval_strategy"] = def_training_args.pop("evaluation_strategy")
|
| 1263 |
training_args_init = TrainingArguments(**def_training_args)
|
| 1264 |
|
| 1265 |
if self.freeze_layers is not None:
|
|
@@ -1313,7 +1280,6 @@ class Classifier:
|
|
| 1313 |
predict=False,
|
| 1314 |
output_directory=None,
|
| 1315 |
output_prefix=None,
|
| 1316 |
-
predict_metadata=None,
|
| 1317 |
):
|
| 1318 |
"""
|
| 1319 |
Evaluate the fine-tuned model.
|
|
@@ -1339,11 +1305,9 @@ class Classifier:
|
|
| 1339 |
|
| 1340 |
##### Evaluate the model #####
|
| 1341 |
labels = id_class_dict.keys()
|
| 1342 |
-
|
| 1343 |
-
|
| 1344 |
-
model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict, predict_metadata
|
| 1345 |
)
|
| 1346 |
-
|
| 1347 |
conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
|
| 1348 |
y_pred, y_true, logits_list, num_classes, labels
|
| 1349 |
)
|
|
@@ -1353,9 +1317,6 @@ class Classifier:
|
|
| 1353 |
"label_ids": y_true,
|
| 1354 |
"predictions": logits_list,
|
| 1355 |
}
|
| 1356 |
-
if predict_metadata is not None:
|
| 1357 |
-
pred_dict["prediction_metadata"] = predict_metadata_all
|
| 1358 |
-
|
| 1359 |
pred_dict_output_path = (
|
| 1360 |
Path(output_directory) / f"{output_prefix}_pred_dict"
|
| 1361 |
).with_suffix(".pkl")
|
|
@@ -1376,7 +1337,6 @@ class Classifier:
|
|
| 1376 |
output_directory,
|
| 1377 |
output_prefix,
|
| 1378 |
predict=True,
|
| 1379 |
-
predict_metadata=None,
|
| 1380 |
):
|
| 1381 |
"""
|
| 1382 |
Evaluate the fine-tuned model.
|
|
@@ -1396,8 +1356,6 @@ class Classifier:
|
|
| 1396 |
| Prefix for output files
|
| 1397 |
predict : bool
|
| 1398 |
| Whether or not to save eval predictions
|
| 1399 |
-
predict_metadata : None | list
|
| 1400 |
-
| Metadata labels to output with predictions (columns in test_data_file)
|
| 1401 |
"""
|
| 1402 |
|
| 1403 |
# load numerical id to class dictionary (id:class)
|
|
@@ -1410,15 +1368,6 @@ class Classifier:
|
|
| 1410 |
# load previously filtered and prepared data
|
| 1411 |
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
| 1412 |
|
| 1413 |
-
if predict_metadata is not None:
|
| 1414 |
-
absent_metadata = []
|
| 1415 |
-
for predict_metadata_x in predict_metadata:
|
| 1416 |
-
if predict_metadata_x not in test_data.features.keys():
|
| 1417 |
-
absent_metadata += [predict_metadata_x]
|
| 1418 |
-
if len(absent_metadata)>0:
|
| 1419 |
-
logger.error(f"Following predict_metadata was not found as column in test_data_file: {absent_metadata}")
|
| 1420 |
-
raise
|
| 1421 |
-
|
| 1422 |
# load previously fine-tuned model
|
| 1423 |
model = pu.load_model(
|
| 1424 |
self.model_type,
|
|
@@ -1437,7 +1386,6 @@ class Classifier:
|
|
| 1437 |
predict=predict,
|
| 1438 |
output_directory=output_directory,
|
| 1439 |
output_prefix=output_prefix,
|
| 1440 |
-
predict_metadata=predict_metadata,
|
| 1441 |
)
|
| 1442 |
|
| 1443 |
all_conf_mat_df = pd.DataFrame(
|
|
|
|
| 48 |
import os
|
| 49 |
import pickle
|
| 50 |
import subprocess
|
|
|
|
| 51 |
from pathlib import Path
|
| 52 |
|
| 53 |
import numpy as np
|
| 54 |
import pandas as pd
|
| 55 |
import seaborn as sns
|
|
|
|
| 56 |
from tqdm.auto import tqdm, trange
|
| 57 |
from transformers import Trainer
|
| 58 |
from transformers.training_args import TrainingArguments
|
|
|
|
| 71 |
|
| 72 |
logger = logging.getLogger(__name__)
|
| 73 |
|
|
|
|
| 74 |
|
| 75 |
class Classifier:
|
| 76 |
valid_option_dict = {
|
|
|
|
| 89 |
"no_eval": {bool},
|
| 90 |
"stratify_splits_col": {None, str},
|
| 91 |
"forward_batch_size": {int},
|
|
|
|
| 92 |
"token_dictionary_file": {None, str},
|
| 93 |
"nproc": {int},
|
| 94 |
"ngpu": {int},
|
|
|
|
| 112 |
stratify_splits_col=None,
|
| 113 |
no_eval=False,
|
| 114 |
forward_batch_size=100,
|
|
|
|
| 115 |
token_dictionary_file=None,
|
| 116 |
nproc=4,
|
| 117 |
ngpu=1,
|
|
|
|
| 188 |
| Otherwise, will perform eval during training.
|
| 189 |
forward_batch_size : int
|
| 190 |
| Batch size for forward pass (for evaluation, not training).
|
|
|
|
|
|
|
|
|
|
| 191 |
token_dictionary_file : None, str
|
| 192 |
| Default is to use token dictionary file from Geneformer
|
| 193 |
| Otherwise, will load custom gene token dictionary.
|
|
|
|
| 222 |
self.stratify_splits_col = stratify_splits_col
|
| 223 |
self.no_eval = no_eval
|
| 224 |
self.forward_batch_size = forward_batch_size
|
|
|
|
| 225 |
self.token_dictionary_file = token_dictionary_file
|
| 226 |
self.nproc = nproc
|
| 227 |
self.ngpu = ngpu
|
| 228 |
+
|
| 229 |
if self.training_args is None:
|
| 230 |
logger.warning(
|
| 231 |
"Hyperparameter tuning is highly recommended for optimal results. "
|
| 232 |
+
"No training_args provided; using default hyperparameters."
|
|
|
|
| 233 |
)
|
| 234 |
|
| 235 |
self.validate_options()
|
|
|
|
| 244 |
] = self.cell_state_dict["states"]
|
| 245 |
|
| 246 |
# load token dictionary (Ensembl IDs:token)
|
| 247 |
+
if self.token_dictionary_file is None:
|
|
|
|
|
|
|
|
|
|
| 248 |
self.token_dictionary_file = TOKEN_DICTIONARY_FILE
|
| 249 |
with open(self.token_dictionary_file, "rb") as f:
|
| 250 |
self.gene_token_dict = pickle.load(f)
|
|
|
|
| 354 |
attr_to_balance=None,
|
| 355 |
max_trials=100,
|
| 356 |
pval_threshold=0.1,
|
|
|
|
| 357 |
):
|
| 358 |
"""
|
| 359 |
Prepare data for cell state or gene classification.
|
|
|
|
| 396 |
pval_threshold : None, float
|
| 397 |
| P-value threshold to use for attribute balancing across splits
|
| 398 |
| E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
"""
|
| 400 |
|
| 401 |
if test_size is None:
|
|
|
|
| 437 |
)
|
| 438 |
# rename cell state column to "label"
|
| 439 |
data = cu.rename_cols(data, self.cell_state_dict["state_key"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
|
|
|
|
| 441 |
# convert classes to numerical labels and save as id_class_dict
|
| 442 |
# of note, will label all genes in gene_class_dict
|
| 443 |
# if (cross-)validating, genes will be relabeled in column "labels" for each split
|
| 444 |
# at the time of training with Classifier.validate
|
| 445 |
+
data, id_class_dict = cu.label_classes(
|
| 446 |
+
self.classifier, data, self.gene_class_dict, self.nproc
|
| 447 |
+
)
|
| 448 |
|
| 449 |
# save id_class_dict for future reference
|
| 450 |
id_class_output_path = (
|
|
|
|
| 772 |
# 5-fold cross-validate
|
| 773 |
num_cells = len(data)
|
| 774 |
fifth_cells = int(np.floor(num_cells * 0.2))
|
| 775 |
+
num_eval = min((self.eval_size * num_cells), fifth_cells)
|
| 776 |
start = i * fifth_cells
|
| 777 |
end = start + num_eval
|
| 778 |
eval_indices = [j for j in range(start, end)]
|
|
|
|
| 1056 |
if eval_data is None:
|
| 1057 |
def_training_args["evaluation_strategy"] = "no"
|
| 1058 |
def_training_args["load_best_model_at_end"] = False
|
|
|
|
|
|
|
| 1059 |
def_training_args.update(
|
| 1060 |
{"save_strategy": "epoch", "save_total_limit": 1}
|
| 1061 |
) # only save last model for each run
|
|
|
|
| 1227 |
if eval_data is None:
|
| 1228 |
def_training_args["evaluation_strategy"] = "no"
|
| 1229 |
def_training_args["load_best_model_at_end"] = False
|
|
|
|
|
|
|
| 1230 |
training_args_init = TrainingArguments(**def_training_args)
|
| 1231 |
|
| 1232 |
if self.freeze_layers is not None:
|
|
|
|
| 1280 |
predict=False,
|
| 1281 |
output_directory=None,
|
| 1282 |
output_prefix=None,
|
|
|
|
| 1283 |
):
|
| 1284 |
"""
|
| 1285 |
Evaluate the fine-tuned model.
|
|
|
|
| 1305 |
|
| 1306 |
##### Evaluate the model #####
|
| 1307 |
labels = id_class_dict.keys()
|
| 1308 |
+
y_pred, y_true, logits_list = eu.classifier_predict(
|
| 1309 |
+
model, self.classifier, eval_data, self.forward_batch_size
|
|
|
|
| 1310 |
)
|
|
|
|
| 1311 |
conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
|
| 1312 |
y_pred, y_true, logits_list, num_classes, labels
|
| 1313 |
)
|
|
|
|
| 1317 |
"label_ids": y_true,
|
| 1318 |
"predictions": logits_list,
|
| 1319 |
}
|
|
|
|
|
|
|
|
|
|
| 1320 |
pred_dict_output_path = (
|
| 1321 |
Path(output_directory) / f"{output_prefix}_pred_dict"
|
| 1322 |
).with_suffix(".pkl")
|
|
|
|
| 1337 |
output_directory,
|
| 1338 |
output_prefix,
|
| 1339 |
predict=True,
|
|
|
|
| 1340 |
):
|
| 1341 |
"""
|
| 1342 |
Evaluate the fine-tuned model.
|
|
|
|
| 1356 |
| Prefix for output files
|
| 1357 |
predict : bool
|
| 1358 |
| Whether or not to save eval predictions
|
|
|
|
|
|
|
| 1359 |
"""
|
| 1360 |
|
| 1361 |
# load numerical id to class dictionary (id:class)
|
|
|
|
| 1368 |
# load previously filtered and prepared data
|
| 1369 |
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
| 1370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1371 |
# load previously fine-tuned model
|
| 1372 |
model = pu.load_model(
|
| 1373 |
self.model_type,
|
|
|
|
| 1386 |
predict=predict,
|
| 1387 |
output_directory=output_directory,
|
| 1388 |
output_prefix=output_prefix,
|
|
|
|
| 1389 |
)
|
| 1390 |
|
| 1391 |
all_conf_mat_df = pd.DataFrame(
|
geneformer/classifier_utils.py
CHANGED
|
@@ -94,7 +94,7 @@ def remove_rare(data, rare_threshold, label, nproc):
|
|
| 94 |
return data
|
| 95 |
|
| 96 |
|
| 97 |
-
def label_classes(classifier, data, gene_class_dict, nproc
|
| 98 |
if classifier == "cell":
|
| 99 |
label_set = set(data["label"])
|
| 100 |
elif classifier == "gene":
|
|
@@ -113,24 +113,15 @@ def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None):
|
|
| 113 |
)
|
| 114 |
raise
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
id_class_dict = {v: k for k, v in class_id_dict.items()}
|
| 119 |
-
else:
|
| 120 |
-
class_id_dict = {v: k for k, v in id_class_dict.items()}
|
| 121 |
-
|
| 122 |
-
if classifier == "gene":
|
| 123 |
-
inverse_gene_class_dict = {}
|
| 124 |
-
for key, value_list in gene_class_dict.items():
|
| 125 |
-
for value in value_list:
|
| 126 |
-
inverse_gene_class_dict[value] = key
|
| 127 |
|
| 128 |
def classes_to_ids(example):
|
| 129 |
if classifier == "cell":
|
| 130 |
example["label"] = class_id_dict[example["label"]]
|
| 131 |
elif classifier == "gene":
|
| 132 |
example["labels"] = label_gene_classes(
|
| 133 |
-
example, class_id_dict,
|
| 134 |
)
|
| 135 |
return example
|
| 136 |
|
|
@@ -138,9 +129,9 @@ def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None):
|
|
| 138 |
return data, id_class_dict
|
| 139 |
|
| 140 |
|
| 141 |
-
def label_gene_classes(example, class_id_dict,
|
| 142 |
return [
|
| 143 |
-
class_id_dict.get(
|
| 144 |
for token_id in example["input_ids"]
|
| 145 |
]
|
| 146 |
|
|
@@ -570,27 +561,6 @@ def compute_metrics(pred):
|
|
| 570 |
return {"accuracy": acc, "macro_f1": macro_f1}
|
| 571 |
|
| 572 |
|
| 573 |
-
def robust_compute_objective(metrics: dict):
|
| 574 |
-
# tries both prefixed ("eval_") and raw metric names to support different transformers versions
|
| 575 |
-
metric_name = "macro_f1"
|
| 576 |
-
|
| 577 |
-
# check for the prefixed version
|
| 578 |
-
prefixed_metric_name = f"eval_{metric_name}"
|
| 579 |
-
if prefixed_metric_name in metrics:
|
| 580 |
-
return metrics[prefixed_metric_name]
|
| 581 |
-
|
| 582 |
-
# fall back to the raw name
|
| 583 |
-
elif metric_name in metrics:
|
| 584 |
-
return metrics[metric_name]
|
| 585 |
-
|
| 586 |
-
# if neither is found, raise a clear error to help with debugging
|
| 587 |
-
raise KeyError(
|
| 588 |
-
f"Could not find '{prefixed_metric_name}' or '{metric_name}' in the reported metrics. "
|
| 589 |
-
f"Please check your `compute_metrics` function and `TrainingArguments`. "
|
| 590 |
-
f"Available metrics: {list(metrics.keys())}"
|
| 591 |
-
)
|
| 592 |
-
|
| 593 |
-
|
| 594 |
def get_default_train_args(model, classifier, data, output_dir):
|
| 595 |
num_layers = pu.quant_layers(model)
|
| 596 |
freeze_layers = 0
|
|
|
|
| 94 |
return data
|
| 95 |
|
| 96 |
|
| 97 |
+
def label_classes(classifier, data, gene_class_dict, nproc):
|
| 98 |
if classifier == "cell":
|
| 99 |
label_set = set(data["label"])
|
| 100 |
elif classifier == "gene":
|
|
|
|
| 113 |
)
|
| 114 |
raise
|
| 115 |
|
| 116 |
+
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
|
| 117 |
+
id_class_dict = {v: k for k, v in class_id_dict.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
def classes_to_ids(example):
|
| 120 |
if classifier == "cell":
|
| 121 |
example["label"] = class_id_dict[example["label"]]
|
| 122 |
elif classifier == "gene":
|
| 123 |
example["labels"] = label_gene_classes(
|
| 124 |
+
example, class_id_dict, gene_class_dict
|
| 125 |
)
|
| 126 |
return example
|
| 127 |
|
|
|
|
| 129 |
return data, id_class_dict
|
| 130 |
|
| 131 |
|
| 132 |
+
def label_gene_classes(example, class_id_dict, gene_class_dict):
|
| 133 |
return [
|
| 134 |
+
class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
|
| 135 |
for token_id in example["input_ids"]
|
| 136 |
]
|
| 137 |
|
|
|
|
| 561 |
return {"accuracy": acc, "macro_f1": macro_f1}
|
| 562 |
|
| 563 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
def get_default_train_args(model, classifier, data, output_dir):
|
| 565 |
num_layers = pu.quant_layers(model)
|
| 566 |
freeze_layers = 0
|
geneformer/collator_for_classification.py
CHANGED
|
@@ -26,10 +26,9 @@ LARGE_INTEGER = int(
|
|
| 26 |
1e20
|
| 27 |
) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
|
| 28 |
|
| 29 |
-
warnings.filterwarnings("ignore", message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach()", category=UserWarning, module="torch")
|
| 30 |
-
|
| 31 |
# precollator functions
|
| 32 |
|
|
|
|
| 33 |
class ExplicitEnum(Enum):
|
| 34 |
"""
|
| 35 |
Enum with more explicit error message for missing values.
|
|
@@ -104,9 +103,6 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
|
|
| 104 |
def pad_token_id(self):
|
| 105 |
return self._pad_token_id
|
| 106 |
|
| 107 |
-
def save_pretrained(self, save_directory):
|
| 108 |
-
pass
|
| 109 |
-
|
| 110 |
def _get_padding_truncation_strategies(
|
| 111 |
self,
|
| 112 |
padding=True,
|
|
@@ -645,8 +641,7 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
|
|
| 645 |
def __call__(self, features):
|
| 646 |
batch = self._prepare_batch(features)
|
| 647 |
|
| 648 |
-
|
| 649 |
-
batch = {k: torch.tensor(v.clone().detach(), dtype=torch.int64) if isinstance(v, torch.Tensor) else torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
|
| 650 |
return batch
|
| 651 |
|
| 652 |
|
|
|
|
| 26 |
1e20
|
| 27 |
) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
|
| 28 |
|
|
|
|
|
|
|
| 29 |
# precollator functions
|
| 30 |
|
| 31 |
+
|
| 32 |
class ExplicitEnum(Enum):
|
| 33 |
"""
|
| 34 |
Enum with more explicit error message for missing values.
|
|
|
|
| 103 |
def pad_token_id(self):
|
| 104 |
return self._pad_token_id
|
| 105 |
|
|
|
|
|
|
|
|
|
|
| 106 |
def _get_padding_truncation_strategies(
|
| 107 |
self,
|
| 108 |
padding=True,
|
|
|
|
| 641 |
def __call__(self, features):
|
| 642 |
batch = self._prepare_batch(features)
|
| 643 |
|
| 644 |
+
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
|
|
|
|
| 645 |
return batch
|
| 646 |
|
| 647 |
|
geneformer/emb_extractor.py
CHANGED
|
@@ -42,8 +42,6 @@ def get_embs(
|
|
| 42 |
special_token=False,
|
| 43 |
summary_stat=None,
|
| 44 |
silent=False,
|
| 45 |
-
save_tdigest=False,
|
| 46 |
-
tdigest_path=None,
|
| 47 |
):
|
| 48 |
model_input_size = pu.get_model_input_size(model)
|
| 49 |
total_batch_length = len(filtered_input_data)
|
|
@@ -182,18 +180,12 @@ def get_embs(
|
|
| 182 |
# calculate summary stat embs from approximated tdigests
|
| 183 |
elif summary_stat is not None:
|
| 184 |
if emb_mode == "cell":
|
| 185 |
-
if save_tdigest:
|
| 186 |
-
with open(f"{tdigest_path}","wb") as fp:
|
| 187 |
-
pickle.dump(embs_tdigests, fp)
|
| 188 |
if summary_stat == "mean":
|
| 189 |
summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
|
| 190 |
elif summary_stat == "median":
|
| 191 |
summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
|
| 192 |
embs_stack = torch.tensor(summary_emb_list)
|
| 193 |
elif emb_mode == "gene":
|
| 194 |
-
if save_tdigest:
|
| 195 |
-
with open(f"{tdigest_path}","wb") as fp:
|
| 196 |
-
pickle.dump(embs_tdigests_dict, fp)
|
| 197 |
if summary_stat == "mean":
|
| 198 |
[
|
| 199 |
update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
|
|
@@ -260,7 +252,7 @@ def label_cell_embs(embs, downsampled_data, emb_labels):
|
|
| 260 |
return embs_df
|
| 261 |
|
| 262 |
|
| 263 |
-
def label_gene_embs(embs, downsampled_data, token_gene_dict
|
| 264 |
gene_set = {
|
| 265 |
element for sublist in downsampled_data["input_ids"] for element in sublist
|
| 266 |
}
|
|
@@ -275,52 +267,25 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mea
|
|
| 275 |
)
|
| 276 |
for k in dict_i.keys():
|
| 277 |
gene_emb_dict[k].append(dict_i[k])
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
gene_emb_dict[k] =
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
embs_df = pd.DataFrame(gene_emb_dict).T
|
| 286 |
-
else:
|
| 287 |
-
embs_df = dict_lol_to_df(gene_emb_dict)
|
| 288 |
embs_df.index = [token_gene_dict[token] for token in embs_df.index]
|
| 289 |
return embs_df
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
df_data = []
|
| 294 |
-
for key, list_of_lists in data_dict.items():
|
| 295 |
-
for i, sublist in enumerate(list_of_lists):
|
| 296 |
-
row_data = [key, i] + sublist.tolist()
|
| 297 |
-
df_data.append(row_data)
|
| 298 |
-
|
| 299 |
-
# determine column names based on the length of sublists
|
| 300 |
-
# assuming all sublists have the same length
|
| 301 |
-
num_columns_from_sublist = len(list(data_dict.values())[0][0])
|
| 302 |
-
column_names = ['Gene', 'Identifier'] + [f'{j}' for j in range(num_columns_from_sublist)]
|
| 303 |
-
|
| 304 |
-
# create the dataframe
|
| 305 |
-
df = pd.DataFrame(df_data, columns=column_names)
|
| 306 |
-
|
| 307 |
-
# set 'Gene' as the index
|
| 308 |
-
df = df.set_index('Gene')
|
| 309 |
-
|
| 310 |
-
return df
|
| 311 |
-
|
| 312 |
-
def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
|
| 313 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
| 314 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
| 315 |
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
| 316 |
str
|
| 317 |
)
|
| 318 |
vars_dict = {"embs": only_embs_df.columns}
|
| 319 |
-
|
| 320 |
-
obs_dict = {"cell_id": list(only_embs_df.index)}
|
| 321 |
-
for label_i in labels_clean:
|
| 322 |
-
obs_dict[label_i] = list(embs_df[label_i])
|
| 323 |
-
|
| 324 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
| 325 |
sc.tl.pca(adata, svd_solver="arpack")
|
| 326 |
sc.pp.neighbors(adata, random_state=seed)
|
|
@@ -331,26 +296,21 @@ def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory,
|
|
| 331 |
if kwargs_dict is not None:
|
| 332 |
default_kwargs_dict.update(kwargs_dict)
|
| 333 |
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
|
| 350 |
-
)
|
| 351 |
-
plt.show()
|
| 352 |
-
plt.savefig(output_file, bbox_inches="tight")
|
| 353 |
-
|
| 354 |
def gen_heatmap_class_colors(labels, df):
|
| 355 |
pal = sns.cubehelix_palette(
|
| 356 |
len(Counter(labels).keys()),
|
|
@@ -435,14 +395,13 @@ class EmbExtractor:
|
|
| 435 |
"num_classes": {int},
|
| 436 |
"emb_mode": {"cls", "cell", "gene"},
|
| 437 |
"cell_emb_style": {"mean_pool"},
|
| 438 |
-
"gene_emb_style": {"mean_pool"
|
| 439 |
"filter_data": {None, dict},
|
| 440 |
"max_ncells": {None, int},
|
| 441 |
"emb_layer": {-1, 0},
|
| 442 |
"emb_label": {None, list},
|
| 443 |
"labels_to_plot": {None, list},
|
| 444 |
"forward_batch_size": {int},
|
| 445 |
-
"model_version": {"V1", "V2"},
|
| 446 |
"token_dictionary_file": {None, str},
|
| 447 |
"nproc": {int},
|
| 448 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
|
@@ -452,7 +411,7 @@ class EmbExtractor:
|
|
| 452 |
self,
|
| 453 |
model_type="Pretrained",
|
| 454 |
num_classes=0,
|
| 455 |
-
emb_mode="
|
| 456 |
cell_emb_style="mean_pool",
|
| 457 |
gene_emb_style="mean_pool",
|
| 458 |
filter_data=None,
|
|
@@ -463,8 +422,6 @@ class EmbExtractor:
|
|
| 463 |
forward_batch_size=100,
|
| 464 |
nproc=4,
|
| 465 |
summary_stat=None,
|
| 466 |
-
save_tdigest=False,
|
| 467 |
-
model_version="V2",
|
| 468 |
token_dictionary_file=None,
|
| 469 |
):
|
| 470 |
"""
|
|
@@ -472,8 +429,8 @@ class EmbExtractor:
|
|
| 472 |
|
| 473 |
**Parameters:**
|
| 474 |
|
| 475 |
-
model_type : {"Pretrained", "GeneClassifier", "CellClassifier"
|
| 476 |
-
| Whether model is the pretrained Geneformer
|
| 477 |
num_classes : int
|
| 478 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
| 479 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
|
@@ -483,9 +440,9 @@ class EmbExtractor:
|
|
| 483 |
cell_emb_style : {"mean_pool"}
|
| 484 |
| Method for summarizing cell embeddings if not using CLS token.
|
| 485 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
| 486 |
-
gene_emb_style :
|
| 487 |
| Method for summarizing gene embeddings.
|
| 488 |
-
| Currently only option is
|
| 489 |
filter_data : None, dict
|
| 490 |
| Default is to extract embeddings from all input data.
|
| 491 |
| Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
|
@@ -515,12 +472,6 @@ class EmbExtractor:
|
|
| 515 |
| If mean or median, outputs only approximated mean or median embedding of input data.
|
| 516 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
| 517 |
| Non-exact is slower but more memory-efficient.
|
| 518 |
-
save_tdigest : bool
|
| 519 |
-
| Whether to save a dictionary of tdigests for each gene and embedding dimension
|
| 520 |
-
| Only applies when summary_stat is not None
|
| 521 |
-
model_version : str
|
| 522 |
-
| To auto-select settings for model version other than current default.
|
| 523 |
-
| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
|
| 524 |
token_dictionary_file : Path
|
| 525 |
| Default is the Geneformer token dictionary
|
| 526 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
|
@@ -551,7 +502,6 @@ class EmbExtractor:
|
|
| 551 |
self.emb_layer = emb_layer
|
| 552 |
self.emb_label = emb_label
|
| 553 |
self.labels_to_plot = labels_to_plot
|
| 554 |
-
self.model_version = model_version
|
| 555 |
self.token_dictionary_file = token_dictionary_file
|
| 556 |
self.forward_batch_size = forward_batch_size
|
| 557 |
self.nproc = nproc
|
|
@@ -561,29 +511,13 @@ class EmbExtractor:
|
|
| 561 |
else:
|
| 562 |
self.summary_stat = summary_stat
|
| 563 |
self.exact_summary_stat = None
|
| 564 |
-
self.save_tdigest = save_tdigest
|
| 565 |
|
| 566 |
self.validate_options()
|
| 567 |
|
| 568 |
-
if (summary_stat is None) and (save_tdigest is True):
|
| 569 |
-
logger.warning(
|
| 570 |
-
"tdigests will not be saved since summary_stat is None."
|
| 571 |
-
)
|
| 572 |
-
save_tdigest = False
|
| 573 |
-
|
| 574 |
-
if self.model_version == "V1":
|
| 575 |
-
from . import TOKEN_DICTIONARY_FILE_30M
|
| 576 |
-
self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
| 577 |
-
if self.emb_mode == "cls":
|
| 578 |
-
self.emb_mode = "cell"
|
| 579 |
-
logger.warning(
|
| 580 |
-
"model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
|
| 581 |
-
)
|
| 582 |
-
|
| 583 |
# load token dictionary (Ensembl IDs:token)
|
| 584 |
if self.token_dictionary_file is None:
|
| 585 |
-
|
| 586 |
-
with open(
|
| 587 |
self.gene_token_dict = pickle.load(f)
|
| 588 |
|
| 589 |
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
|
@@ -662,12 +596,6 @@ class EmbExtractor:
|
|
| 662 |
filtered_input_data = pu.load_and_filter(
|
| 663 |
self.filter_data, self.nproc, input_data_file
|
| 664 |
)
|
| 665 |
-
|
| 666 |
-
# Check to make sure that all the labels exist in the tokenized data:
|
| 667 |
-
if self.emb_label is not None:
|
| 668 |
-
for label in self.emb_label:
|
| 669 |
-
assert label in filtered_input_data.features.keys(), f"Attribute `{label}` not present in dataset features"
|
| 670 |
-
|
| 671 |
if cell_state is not None:
|
| 672 |
filtered_input_data = pu.filter_by_dict(
|
| 673 |
filtered_input_data, cell_state, self.nproc
|
|
@@ -677,10 +605,6 @@ class EmbExtractor:
|
|
| 677 |
self.model_type, self.num_classes, model_directory, mode="eval"
|
| 678 |
)
|
| 679 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
| 680 |
-
if self.save_tdigest:
|
| 681 |
-
tdigest_path = (Path(output_directory) / f"{output_prefix}_tdigest").with_suffix(".pkl")
|
| 682 |
-
else:
|
| 683 |
-
tdigest_path = None
|
| 684 |
embs = get_embs(
|
| 685 |
model=model,
|
| 686 |
filtered_input_data=downsampled_data,
|
|
@@ -690,8 +614,6 @@ class EmbExtractor:
|
|
| 690 |
forward_batch_size=self.forward_batch_size,
|
| 691 |
token_gene_dict=self.token_gene_dict,
|
| 692 |
summary_stat=self.summary_stat,
|
| 693 |
-
save_tdigest=self.save_tdigest,
|
| 694 |
-
tdigest_path=tdigest_path,
|
| 695 |
)
|
| 696 |
|
| 697 |
if self.emb_mode == "cell":
|
|
@@ -701,7 +623,7 @@ class EmbExtractor:
|
|
| 701 |
embs_df = pd.DataFrame(embs.cpu().numpy()).T
|
| 702 |
elif self.emb_mode == "gene":
|
| 703 |
if self.summary_stat is None:
|
| 704 |
-
embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict
|
| 705 |
elif self.summary_stat is not None:
|
| 706 |
embs_df = pd.DataFrame(embs).T
|
| 707 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
|
@@ -717,14 +639,14 @@ class EmbExtractor:
|
|
| 717 |
embs = embs.mean(dim=0)
|
| 718 |
emb_dims = pu.get_model_emb_dims(model)
|
| 719 |
embs_df = pd.DataFrame(
|
| 720 |
-
embs_df
|
| 721 |
columns=[self.exact_summary_stat],
|
| 722 |
).T
|
| 723 |
elif self.exact_summary_stat == "exact_median":
|
| 724 |
embs = torch.median(embs, dim=0)[0]
|
| 725 |
emb_dims = pu.get_model_emb_dims(model)
|
| 726 |
embs_df = pd.DataFrame(
|
| 727 |
-
embs_df
|
| 728 |
columns=[self.exact_summary_stat],
|
| 729 |
).T
|
| 730 |
|
|
@@ -797,12 +719,6 @@ class EmbExtractor:
|
|
| 797 |
)
|
| 798 |
raise
|
| 799 |
|
| 800 |
-
if self.emb_label is not None:
|
| 801 |
-
logger.error(
|
| 802 |
-
"For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
|
| 803 |
-
)
|
| 804 |
-
raise
|
| 805 |
-
|
| 806 |
state_embs_dict = dict()
|
| 807 |
state_key = cell_states_to_model["state_key"]
|
| 808 |
for k, v in cell_states_to_model.items():
|
|
@@ -885,14 +801,14 @@ class EmbExtractor:
|
|
| 885 |
raise
|
| 886 |
|
| 887 |
if max_ncells_to_plot is not None:
|
| 888 |
-
if self.max_ncells
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
|
| 897 |
if self.emb_label is None:
|
| 898 |
label_len = 0
|
|
@@ -913,9 +829,12 @@ class EmbExtractor:
|
|
| 913 |
f"Label {label} from labels_to_plot "
|
| 914 |
f"not present in provided embeddings dataframe."
|
| 915 |
)
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
|
|
|
|
|
|
|
|
|
| 919 |
|
| 920 |
if plot_style == "heatmap":
|
| 921 |
for label in self.labels_to_plot:
|
|
|
|
| 42 |
special_token=False,
|
| 43 |
summary_stat=None,
|
| 44 |
silent=False,
|
|
|
|
|
|
|
| 45 |
):
|
| 46 |
model_input_size = pu.get_model_input_size(model)
|
| 47 |
total_batch_length = len(filtered_input_data)
|
|
|
|
| 180 |
# calculate summary stat embs from approximated tdigests
|
| 181 |
elif summary_stat is not None:
|
| 182 |
if emb_mode == "cell":
|
|
|
|
|
|
|
|
|
|
| 183 |
if summary_stat == "mean":
|
| 184 |
summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
|
| 185 |
elif summary_stat == "median":
|
| 186 |
summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
|
| 187 |
embs_stack = torch.tensor(summary_emb_list)
|
| 188 |
elif emb_mode == "gene":
|
|
|
|
|
|
|
|
|
|
| 189 |
if summary_stat == "mean":
|
| 190 |
[
|
| 191 |
update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
|
|
|
|
| 252 |
return embs_df
|
| 253 |
|
| 254 |
|
| 255 |
+
def label_gene_embs(embs, downsampled_data, token_gene_dict):
|
| 256 |
gene_set = {
|
| 257 |
element for sublist in downsampled_data["input_ids"] for element in sublist
|
| 258 |
}
|
|
|
|
| 267 |
)
|
| 268 |
for k in dict_i.keys():
|
| 269 |
gene_emb_dict[k].append(dict_i[k])
|
| 270 |
+
for k in gene_emb_dict.keys():
|
| 271 |
+
gene_emb_dict[k] = (
|
| 272 |
+
torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
|
| 273 |
+
.cpu()
|
| 274 |
+
.numpy()
|
| 275 |
+
)
|
| 276 |
+
embs_df = pd.DataFrame(gene_emb_dict).T
|
|
|
|
|
|
|
|
|
|
| 277 |
embs_df.index = [token_gene_dict[token] for token in embs_df.index]
|
| 278 |
return embs_df
|
| 279 |
|
| 280 |
+
|
| 281 |
+
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
| 283 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
| 284 |
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
| 285 |
str
|
| 286 |
)
|
| 287 |
vars_dict = {"embs": only_embs_df.columns}
|
| 288 |
+
obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
| 290 |
sc.tl.pca(adata, svd_solver="arpack")
|
| 291 |
sc.pp.neighbors(adata, random_state=seed)
|
|
|
|
| 296 |
if kwargs_dict is not None:
|
| 297 |
default_kwargs_dict.update(kwargs_dict)
|
| 298 |
|
| 299 |
+
cats = set(embs_df[label])
|
| 300 |
+
|
| 301 |
+
with plt.rc_context():
|
| 302 |
+
ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
|
| 303 |
+
ax.legend(
|
| 304 |
+
markerscale=2,
|
| 305 |
+
frameon=False,
|
| 306 |
+
loc="center left",
|
| 307 |
+
bbox_to_anchor=(1, 0.5),
|
| 308 |
+
ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
|
| 309 |
+
)
|
| 310 |
+
plt.show()
|
| 311 |
+
plt.savefig(output_file, bbox_inches="tight")
|
| 312 |
+
|
| 313 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
def gen_heatmap_class_colors(labels, df):
|
| 315 |
pal = sns.cubehelix_palette(
|
| 316 |
len(Counter(labels).keys()),
|
|
|
|
| 395 |
"num_classes": {int},
|
| 396 |
"emb_mode": {"cls", "cell", "gene"},
|
| 397 |
"cell_emb_style": {"mean_pool"},
|
| 398 |
+
"gene_emb_style": {"mean_pool"},
|
| 399 |
"filter_data": {None, dict},
|
| 400 |
"max_ncells": {None, int},
|
| 401 |
"emb_layer": {-1, 0},
|
| 402 |
"emb_label": {None, list},
|
| 403 |
"labels_to_plot": {None, list},
|
| 404 |
"forward_batch_size": {int},
|
|
|
|
| 405 |
"token_dictionary_file": {None, str},
|
| 406 |
"nproc": {int},
|
| 407 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
|
|
|
| 411 |
self,
|
| 412 |
model_type="Pretrained",
|
| 413 |
num_classes=0,
|
| 414 |
+
emb_mode="cell",
|
| 415 |
cell_emb_style="mean_pool",
|
| 416 |
gene_emb_style="mean_pool",
|
| 417 |
filter_data=None,
|
|
|
|
| 422 |
forward_batch_size=100,
|
| 423 |
nproc=4,
|
| 424 |
summary_stat=None,
|
|
|
|
|
|
|
| 425 |
token_dictionary_file=None,
|
| 426 |
):
|
| 427 |
"""
|
|
|
|
| 429 |
|
| 430 |
**Parameters:**
|
| 431 |
|
| 432 |
+
model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
|
| 433 |
+
| Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
|
| 434 |
num_classes : int
|
| 435 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
| 436 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
|
|
|
| 440 |
cell_emb_style : {"mean_pool"}
|
| 441 |
| Method for summarizing cell embeddings if not using CLS token.
|
| 442 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
| 443 |
+
gene_emb_style : "mean_pool"
|
| 444 |
| Method for summarizing gene embeddings.
|
| 445 |
+
| Currently only option is mean pooling of contextual gene embeddings for given gene.
|
| 446 |
filter_data : None, dict
|
| 447 |
| Default is to extract embeddings from all input data.
|
| 448 |
| Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
|
|
|
| 472 |
| If mean or median, outputs only approximated mean or median embedding of input data.
|
| 473 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
| 474 |
| Non-exact is slower but more memory-efficient.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
token_dictionary_file : Path
|
| 476 |
| Default is the Geneformer token dictionary
|
| 477 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
|
|
|
| 502 |
self.emb_layer = emb_layer
|
| 503 |
self.emb_label = emb_label
|
| 504 |
self.labels_to_plot = labels_to_plot
|
|
|
|
| 505 |
self.token_dictionary_file = token_dictionary_file
|
| 506 |
self.forward_batch_size = forward_batch_size
|
| 507 |
self.nproc = nproc
|
|
|
|
| 511 |
else:
|
| 512 |
self.summary_stat = summary_stat
|
| 513 |
self.exact_summary_stat = None
|
|
|
|
| 514 |
|
| 515 |
self.validate_options()
|
| 516 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
# load token dictionary (Ensembl IDs:token)
|
| 518 |
if self.token_dictionary_file is None:
|
| 519 |
+
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
| 520 |
+
with open(token_dictionary_file, "rb") as f:
|
| 521 |
self.gene_token_dict = pickle.load(f)
|
| 522 |
|
| 523 |
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
|
|
|
| 596 |
filtered_input_data = pu.load_and_filter(
|
| 597 |
self.filter_data, self.nproc, input_data_file
|
| 598 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 599 |
if cell_state is not None:
|
| 600 |
filtered_input_data = pu.filter_by_dict(
|
| 601 |
filtered_input_data, cell_state, self.nproc
|
|
|
|
| 605 |
self.model_type, self.num_classes, model_directory, mode="eval"
|
| 606 |
)
|
| 607 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
embs = get_embs(
|
| 609 |
model=model,
|
| 610 |
filtered_input_data=downsampled_data,
|
|
|
|
| 614 |
forward_batch_size=self.forward_batch_size,
|
| 615 |
token_gene_dict=self.token_gene_dict,
|
| 616 |
summary_stat=self.summary_stat,
|
|
|
|
|
|
|
| 617 |
)
|
| 618 |
|
| 619 |
if self.emb_mode == "cell":
|
|
|
|
| 623 |
embs_df = pd.DataFrame(embs.cpu().numpy()).T
|
| 624 |
elif self.emb_mode == "gene":
|
| 625 |
if self.summary_stat is None:
|
| 626 |
+
embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
|
| 627 |
elif self.summary_stat is not None:
|
| 628 |
embs_df = pd.DataFrame(embs).T
|
| 629 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
|
|
|
| 639 |
embs = embs.mean(dim=0)
|
| 640 |
emb_dims = pu.get_model_emb_dims(model)
|
| 641 |
embs_df = pd.DataFrame(
|
| 642 |
+
embs_df[0 : emb_dims - 1].mean(axis="rows"),
|
| 643 |
columns=[self.exact_summary_stat],
|
| 644 |
).T
|
| 645 |
elif self.exact_summary_stat == "exact_median":
|
| 646 |
embs = torch.median(embs, dim=0)[0]
|
| 647 |
emb_dims = pu.get_model_emb_dims(model)
|
| 648 |
embs_df = pd.DataFrame(
|
| 649 |
+
embs_df[0 : emb_dims - 1].median(axis="rows"),
|
| 650 |
columns=[self.exact_summary_stat],
|
| 651 |
).T
|
| 652 |
|
|
|
|
| 719 |
)
|
| 720 |
raise
|
| 721 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 722 |
state_embs_dict = dict()
|
| 723 |
state_key = cell_states_to_model["state_key"]
|
| 724 |
for k, v in cell_states_to_model.items():
|
|
|
|
| 801 |
raise
|
| 802 |
|
| 803 |
if max_ncells_to_plot is not None:
|
| 804 |
+
if max_ncells_to_plot > self.max_ncells:
|
| 805 |
+
max_ncells_to_plot = self.max_ncells
|
| 806 |
+
logger.warning(
|
| 807 |
+
"max_ncells_to_plot must be <= max_ncells. "
|
| 808 |
+
f"Changing max_ncells_to_plot to {self.max_ncells}."
|
| 809 |
+
)
|
| 810 |
+
elif max_ncells_to_plot < self.max_ncells:
|
| 811 |
+
embs = embs.sample(max_ncells_to_plot, axis=0)
|
| 812 |
|
| 813 |
if self.emb_label is None:
|
| 814 |
label_len = 0
|
|
|
|
| 829 |
f"Label {label} from labels_to_plot "
|
| 830 |
f"not present in provided embeddings dataframe."
|
| 831 |
)
|
| 832 |
+
continue
|
| 833 |
+
output_prefix_label = output_prefix + f"_umap_{label}"
|
| 834 |
+
output_file = (
|
| 835 |
+
Path(output_directory) / output_prefix_label
|
| 836 |
+
).with_suffix(".pdf")
|
| 837 |
+
plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
|
| 838 |
|
| 839 |
if plot_style == "heatmap":
|
| 840 |
for label in self.labels_to_plot:
|
geneformer/{ensembl_mapping_dict_gc104M.pkl → ensembl_mapping_dict_gc95M.pkl}
RENAMED
|
File without changes
|
geneformer/evaluation_utils.py
CHANGED
|
@@ -8,7 +8,6 @@ import numpy as np
|
|
| 8 |
import pandas as pd
|
| 9 |
import seaborn as sns
|
| 10 |
import torch
|
| 11 |
-
import datasets
|
| 12 |
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
| 13 |
from sklearn import preprocessing
|
| 14 |
from sklearn.metrics import (
|
|
@@ -21,15 +20,20 @@ from sklearn.metrics import (
|
|
| 21 |
)
|
| 22 |
from tqdm.auto import trange
|
| 23 |
|
|
|
|
| 24 |
from .emb_extractor import make_colorbar
|
| 25 |
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
| 28 |
|
| 29 |
-
def preprocess_classifier_batch(cell_batch, max_len, label_name
|
| 30 |
if max_len is None:
|
| 31 |
max_len = max([len(i) for i in cell_batch["input_ids"]])
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def pad_label_example(example):
|
| 34 |
example[label_name] = np.pad(
|
| 35 |
example[label_name],
|
|
@@ -77,7 +81,7 @@ def py_softmax(vector):
|
|
| 77 |
return e / e.sum()
|
| 78 |
|
| 79 |
|
| 80 |
-
def classifier_predict(model, classifier_type, evalset, forward_batch_size
|
| 81 |
if classifier_type == "gene":
|
| 82 |
label_name = "labels"
|
| 83 |
elif classifier_type == "cell":
|
|
@@ -85,14 +89,6 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
|
|
| 85 |
|
| 86 |
predict_logits = []
|
| 87 |
predict_labels = []
|
| 88 |
-
|
| 89 |
-
predict_metadata_all = None
|
| 90 |
-
|
| 91 |
-
if predict_metadata is not None:
|
| 92 |
-
predict_metadata_all = dict()
|
| 93 |
-
for metadata_name in predict_metadata:
|
| 94 |
-
predict_metadata_all[metadata_name] = []
|
| 95 |
-
|
| 96 |
model.eval()
|
| 97 |
|
| 98 |
# ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
|
|
@@ -107,21 +103,11 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
|
|
| 107 |
for i in trange(0, evalset_len, forward_batch_size):
|
| 108 |
max_range = min(i + forward_batch_size, evalset_len)
|
| 109 |
batch_evalset = evalset.select([i for i in range(i, max_range)])
|
| 110 |
-
|
| 111 |
-
if predict_metadata is not None:
|
| 112 |
-
for metadata_name in predict_metadata:
|
| 113 |
-
predict_metadata_all[metadata_name] += batch_evalset[metadata_name]
|
| 114 |
-
|
| 115 |
padded_batch = preprocess_classifier_batch(
|
| 116 |
-
batch_evalset, max_evalset_len, label_name
|
| 117 |
)
|
| 118 |
-
|
| 119 |
padded_batch.set_format(type="torch")
|
| 120 |
|
| 121 |
-
# For datasets>=4.0.0, convert to dict to avoid format issues
|
| 122 |
-
if int(datasets.__version__.split(".")[0]) >= 4:
|
| 123 |
-
padded_batch = padded_batch[:]
|
| 124 |
-
|
| 125 |
input_data_batch = padded_batch["input_ids"]
|
| 126 |
attn_msk_batch = padded_batch["attention_mask"]
|
| 127 |
label_batch = padded_batch[label_name]
|
|
@@ -148,8 +134,7 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
|
|
| 148 |
y_pred = [vote(item[0]) for item in logit_label_paired]
|
| 149 |
y_true = [item[1] for item in logit_label_paired]
|
| 150 |
logits_list = [item[0] for item in logit_label_paired]
|
| 151 |
-
|
| 152 |
-
return y_pred, y_true, logits_list, predict_metadata_all
|
| 153 |
|
| 154 |
|
| 155 |
def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
|
|
@@ -197,18 +182,14 @@ def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix
|
|
| 197 |
for model_name in roc_metric_dict.keys():
|
| 198 |
mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
|
| 199 |
mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
|
|
|
|
|
|
|
| 200 |
color = model_style_dict[model_name]["color"]
|
| 201 |
linestyle = model_style_dict[model_name]["linestyle"]
|
| 202 |
-
if
|
| 203 |
-
|
| 204 |
-
label = f"{model_name} (AUC {all_roc_auc:0.2f})"
|
| 205 |
else:
|
| 206 |
-
|
| 207 |
-
roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
|
| 208 |
-
if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
|
| 209 |
-
label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
|
| 210 |
-
else:
|
| 211 |
-
label = f"{model_name} (AUC {roc_auc:0.2f})"
|
| 212 |
plt.plot(
|
| 213 |
mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
|
| 214 |
)
|
|
|
|
| 8 |
import pandas as pd
|
| 9 |
import seaborn as sns
|
| 10 |
import torch
|
|
|
|
| 11 |
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
| 12 |
from sklearn import preprocessing
|
| 13 |
from sklearn.metrics import (
|
|
|
|
| 20 |
)
|
| 21 |
from tqdm.auto import trange
|
| 22 |
|
| 23 |
+
from . import TOKEN_DICTIONARY_FILE
|
| 24 |
from .emb_extractor import make_colorbar
|
| 25 |
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
| 28 |
|
| 29 |
+
def preprocess_classifier_batch(cell_batch, max_len, label_name):
|
| 30 |
if max_len is None:
|
| 31 |
max_len = max([len(i) for i in cell_batch["input_ids"]])
|
| 32 |
|
| 33 |
+
# load token dictionary (Ensembl IDs:token)
|
| 34 |
+
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
|
| 35 |
+
gene_token_dict = pickle.load(f)
|
| 36 |
+
|
| 37 |
def pad_label_example(example):
|
| 38 |
example[label_name] = np.pad(
|
| 39 |
example[label_name],
|
|
|
|
| 81 |
return e / e.sum()
|
| 82 |
|
| 83 |
|
| 84 |
+
def classifier_predict(model, classifier_type, evalset, forward_batch_size):
|
| 85 |
if classifier_type == "gene":
|
| 86 |
label_name = "labels"
|
| 87 |
elif classifier_type == "cell":
|
|
|
|
| 89 |
|
| 90 |
predict_logits = []
|
| 91 |
predict_labels = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
model.eval()
|
| 93 |
|
| 94 |
# ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
|
|
|
|
| 103 |
for i in trange(0, evalset_len, forward_batch_size):
|
| 104 |
max_range = min(i + forward_batch_size, evalset_len)
|
| 105 |
batch_evalset = evalset.select([i for i in range(i, max_range)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
padded_batch = preprocess_classifier_batch(
|
| 107 |
+
batch_evalset, max_evalset_len, label_name
|
| 108 |
)
|
|
|
|
| 109 |
padded_batch.set_format(type="torch")
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
input_data_batch = padded_batch["input_ids"]
|
| 112 |
attn_msk_batch = padded_batch["attention_mask"]
|
| 113 |
label_batch = padded_batch[label_name]
|
|
|
|
| 134 |
y_pred = [vote(item[0]) for item in logit_label_paired]
|
| 135 |
y_true = [item[1] for item in logit_label_paired]
|
| 136 |
logits_list = [item[0] for item in logit_label_paired]
|
| 137 |
+
return y_pred, y_true, logits_list
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
|
|
|
|
| 182 |
for model_name in roc_metric_dict.keys():
|
| 183 |
mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
|
| 184 |
mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
|
| 185 |
+
roc_auc = roc_metric_dict[model_name]["roc_auc"]
|
| 186 |
+
roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
|
| 187 |
color = model_style_dict[model_name]["color"]
|
| 188 |
linestyle = model_style_dict[model_name]["linestyle"]
|
| 189 |
+
if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
|
| 190 |
+
label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
|
|
|
|
| 191 |
else:
|
| 192 |
+
label = f"{model_name} (AUC {roc_auc:0.2f})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
plt.plot(
|
| 194 |
mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
|
| 195 |
)
|
geneformer/{gene_median_dictionary_gc104M.pkl → gene_median_dictionary_gc95M.pkl}
RENAMED
|
File without changes
|
geneformer/{gene_name_id_dict_gc104M.pkl → gene_name_id_dict_gc95M.pkl}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b0fd0521406ed18b2e341ef0acb5f53aa1a62457a07ca5840e1c142f46dd326
|
| 3 |
+
size 2038812
|
geneformer/in_silico_perturber.py
CHANGED
|
@@ -40,7 +40,7 @@ import pickle
|
|
| 40 |
from collections import defaultdict
|
| 41 |
|
| 42 |
import torch
|
| 43 |
-
from datasets import Dataset
|
| 44 |
from multiprocess import set_start_method
|
| 45 |
from tqdm.auto import trange
|
| 46 |
|
|
@@ -48,9 +48,7 @@ from . import TOKEN_DICTIONARY_FILE
|
|
| 48 |
from . import perturber_utils as pu
|
| 49 |
from .emb_extractor import get_embs
|
| 50 |
|
| 51 |
-
|
| 52 |
-
datasets.logging.disable_progress_bar()
|
| 53 |
-
|
| 54 |
|
| 55 |
logger = logging.getLogger(__name__)
|
| 56 |
|
|
@@ -62,7 +60,7 @@ class InSilicoPerturber:
|
|
| 62 |
"genes_to_perturb": {"all", list},
|
| 63 |
"combos": {0, 1},
|
| 64 |
"anchor_gene": {None, str},
|
| 65 |
-
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "
|
| 66 |
"num_classes": {int},
|
| 67 |
"emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
|
| 68 |
"cell_emb_style": {"mean_pool"},
|
|
@@ -72,7 +70,6 @@ class InSilicoPerturber:
|
|
| 72 |
"max_ncells": {None, int},
|
| 73 |
"cell_inds_to_perturb": {"all", dict},
|
| 74 |
"emb_layer": {-1, 0},
|
| 75 |
-
"model_version": {"V1", "V2"},
|
| 76 |
"token_dictionary_file": {None, str},
|
| 77 |
"forward_batch_size": {int},
|
| 78 |
"nproc": {int},
|
|
@@ -87,7 +84,7 @@ class InSilicoPerturber:
|
|
| 87 |
anchor_gene=None,
|
| 88 |
model_type="Pretrained",
|
| 89 |
num_classes=0,
|
| 90 |
-
emb_mode="
|
| 91 |
cell_emb_style="mean_pool",
|
| 92 |
filter_data=None,
|
| 93 |
cell_states_to_model=None,
|
|
@@ -97,7 +94,6 @@ class InSilicoPerturber:
|
|
| 97 |
emb_layer=-1,
|
| 98 |
forward_batch_size=100,
|
| 99 |
nproc=4,
|
| 100 |
-
model_version="V2",
|
| 101 |
token_dictionary_file=None,
|
| 102 |
clear_mem_ncells=1000,
|
| 103 |
):
|
|
@@ -134,7 +130,7 @@ class InSilicoPerturber:
|
|
| 134 |
| ENSEMBL ID of gene to use as anchor in combination perturbations.
|
| 135 |
| For example, if combos=1 and anchor_gene="ENSG00000148400":
|
| 136 |
| anchor gene will be perturbed in combination with each other gene.
|
| 137 |
-
model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "
|
| 138 |
| Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
|
| 139 |
num_classes : int
|
| 140 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
|
@@ -186,9 +182,6 @@ class InSilicoPerturber:
|
|
| 186 |
| Batch size for forward pass.
|
| 187 |
nproc : int
|
| 188 |
| Number of CPU processes to use.
|
| 189 |
-
model_version : str
|
| 190 |
-
| To auto-select settings for model version other than current default.
|
| 191 |
-
| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
|
| 192 |
token_dictionary_file : Path
|
| 193 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
| 194 |
clear_mem_ncells : int
|
|
@@ -229,30 +222,15 @@ class InSilicoPerturber:
|
|
| 229 |
self.emb_layer = emb_layer
|
| 230 |
self.forward_batch_size = forward_batch_size
|
| 231 |
self.nproc = nproc
|
| 232 |
-
self.model_version = model_version
|
| 233 |
self.token_dictionary_file = token_dictionary_file
|
| 234 |
-
self.clear_mem_ncells = clear_mem_ncells
|
| 235 |
|
| 236 |
self.validate_options()
|
| 237 |
|
| 238 |
-
if self.model_version == "V1":
|
| 239 |
-
from . import TOKEN_DICTIONARY_FILE_30M
|
| 240 |
-
self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
| 241 |
-
if self.emb_mode == "cls":
|
| 242 |
-
self.emb_mode = "cell"
|
| 243 |
-
logger.warning(
|
| 244 |
-
"model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
|
| 245 |
-
)
|
| 246 |
-
if self.emb_mode == "cls_and_gene":
|
| 247 |
-
self.emb_mode = "cell_and_gene"
|
| 248 |
-
logger.warning(
|
| 249 |
-
"model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a <cls> token."
|
| 250 |
-
)
|
| 251 |
-
|
| 252 |
# load token dictionary (Ensembl IDs:token)
|
| 253 |
if self.token_dictionary_file is None:
|
| 254 |
-
|
| 255 |
-
with open(
|
| 256 |
self.gene_token_dict = pickle.load(f)
|
| 257 |
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
| 258 |
|
|
@@ -816,8 +794,6 @@ class InSilicoPerturber:
|
|
| 816 |
return example
|
| 817 |
|
| 818 |
total_batch_length = len(filtered_input_data)
|
| 819 |
-
|
| 820 |
-
|
| 821 |
if self.cell_states_to_model is None:
|
| 822 |
cos_sims_dict = defaultdict(list)
|
| 823 |
else:
|
|
@@ -902,7 +878,7 @@ class InSilicoPerturber:
|
|
| 902 |
)
|
| 903 |
|
| 904 |
##### CLS and Gene Embedding Mode #####
|
| 905 |
-
elif self.emb_mode == "cls_and_gene":
|
| 906 |
full_original_emb = get_embs(
|
| 907 |
model,
|
| 908 |
minibatch,
|
|
@@ -915,7 +891,6 @@ class InSilicoPerturber:
|
|
| 915 |
silent=True,
|
| 916 |
)
|
| 917 |
indices_to_perturb = perturbation_batch["perturb_index"]
|
| 918 |
-
|
| 919 |
# remove indices that were perturbed
|
| 920 |
original_emb = pu.remove_perturbed_indices_set(
|
| 921 |
full_original_emb,
|
|
@@ -924,7 +899,6 @@ class InSilicoPerturber:
|
|
| 924 |
self.tokens_to_perturb,
|
| 925 |
minibatch["length"],
|
| 926 |
)
|
| 927 |
-
|
| 928 |
full_perturbation_emb = get_embs(
|
| 929 |
model,
|
| 930 |
perturbation_batch,
|
|
@@ -936,7 +910,7 @@ class InSilicoPerturber:
|
|
| 936 |
summary_stat=None,
|
| 937 |
silent=True,
|
| 938 |
)
|
| 939 |
-
|
| 940 |
# remove special tokens and padding
|
| 941 |
original_emb = original_emb[:, 1:-1, :]
|
| 942 |
if self.perturb_type == "overexpress":
|
|
@@ -947,25 +921,9 @@ class InSilicoPerturber:
|
|
| 947 |
perturbation_emb = full_perturbation_emb[
|
| 948 |
:, 1 : max(perturbation_batch["length"]) - 1, :
|
| 949 |
]
|
| 950 |
-
|
| 951 |
-
n_perturbation_genes = perturbation_emb.size()[1]
|
| 952 |
|
| 953 |
-
|
| 954 |
-
if self.perturb_type == "overexpress":
|
| 955 |
-
def calc_perturbation_length(ids):
|
| 956 |
-
if ids == [-100]:
|
| 957 |
-
return 0
|
| 958 |
-
else:
|
| 959 |
-
return len(ids)
|
| 960 |
-
|
| 961 |
-
max_tensor_size = max([length - calc_perturbation_length(ids) - 2 for length, ids in zip(minibatch["length"], indices_to_perturb)])
|
| 962 |
|
| 963 |
-
max_n_overflow = max(minibatch["n_overflow"])
|
| 964 |
-
if max_n_overflow > 0 and perturbation_emb.size()[1] < original_emb.size()[1]:
|
| 965 |
-
original_emb = original_emb[:, 0 : perturbation_emb.size()[1], :]
|
| 966 |
-
elif perturbation_emb.size()[1] < original_emb.size()[1]:
|
| 967 |
-
original_emb = original_emb[:, 0:max_tensor_size, :]
|
| 968 |
-
|
| 969 |
gene_cos_sims = pu.quant_cos_sims(
|
| 970 |
perturbation_emb,
|
| 971 |
original_emb,
|
|
|
|
| 40 |
from collections import defaultdict
|
| 41 |
|
| 42 |
import torch
|
| 43 |
+
from datasets import Dataset, disable_progress_bars
|
| 44 |
from multiprocess import set_start_method
|
| 45 |
from tqdm.auto import trange
|
| 46 |
|
|
|
|
| 48 |
from . import perturber_utils as pu
|
| 49 |
from .emb_extractor import get_embs
|
| 50 |
|
| 51 |
+
disable_progress_bars()
|
|
|
|
|
|
|
| 52 |
|
| 53 |
logger = logging.getLogger(__name__)
|
| 54 |
|
|
|
|
| 60 |
"genes_to_perturb": {"all", list},
|
| 61 |
"combos": {0, 1},
|
| 62 |
"anchor_gene": {None, str},
|
| 63 |
+
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"},
|
| 64 |
"num_classes": {int},
|
| 65 |
"emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
|
| 66 |
"cell_emb_style": {"mean_pool"},
|
|
|
|
| 70 |
"max_ncells": {None, int},
|
| 71 |
"cell_inds_to_perturb": {"all", dict},
|
| 72 |
"emb_layer": {-1, 0},
|
|
|
|
| 73 |
"token_dictionary_file": {None, str},
|
| 74 |
"forward_batch_size": {int},
|
| 75 |
"nproc": {int},
|
|
|
|
| 84 |
anchor_gene=None,
|
| 85 |
model_type="Pretrained",
|
| 86 |
num_classes=0,
|
| 87 |
+
emb_mode="cell",
|
| 88 |
cell_emb_style="mean_pool",
|
| 89 |
filter_data=None,
|
| 90 |
cell_states_to_model=None,
|
|
|
|
| 94 |
emb_layer=-1,
|
| 95 |
forward_batch_size=100,
|
| 96 |
nproc=4,
|
|
|
|
| 97 |
token_dictionary_file=None,
|
| 98 |
clear_mem_ncells=1000,
|
| 99 |
):
|
|
|
|
| 130 |
| ENSEMBL ID of gene to use as anchor in combination perturbations.
|
| 131 |
| For example, if combos=1 and anchor_gene="ENSG00000148400":
|
| 132 |
| anchor gene will be perturbed in combination with each other gene.
|
| 133 |
+
model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"}
|
| 134 |
| Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
|
| 135 |
num_classes : int
|
| 136 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
|
|
|
| 182 |
| Batch size for forward pass.
|
| 183 |
nproc : int
|
| 184 |
| Number of CPU processes to use.
|
|
|
|
|
|
|
|
|
|
| 185 |
token_dictionary_file : Path
|
| 186 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
| 187 |
clear_mem_ncells : int
|
|
|
|
| 222 |
self.emb_layer = emb_layer
|
| 223 |
self.forward_batch_size = forward_batch_size
|
| 224 |
self.nproc = nproc
|
|
|
|
| 225 |
self.token_dictionary_file = token_dictionary_file
|
| 226 |
+
self.clear_mem_ncells = clear_mem_ncells
|
| 227 |
|
| 228 |
self.validate_options()
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
# load token dictionary (Ensembl IDs:token)
|
| 231 |
if self.token_dictionary_file is None:
|
| 232 |
+
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
| 233 |
+
with open(token_dictionary_file, "rb") as f:
|
| 234 |
self.gene_token_dict = pickle.load(f)
|
| 235 |
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
| 236 |
|
|
|
|
| 794 |
return example
|
| 795 |
|
| 796 |
total_batch_length = len(filtered_input_data)
|
|
|
|
|
|
|
| 797 |
if self.cell_states_to_model is None:
|
| 798 |
cos_sims_dict = defaultdict(list)
|
| 799 |
else:
|
|
|
|
| 878 |
)
|
| 879 |
|
| 880 |
##### CLS and Gene Embedding Mode #####
|
| 881 |
+
elif self.emb_mode == "cls_and_gene":
|
| 882 |
full_original_emb = get_embs(
|
| 883 |
model,
|
| 884 |
minibatch,
|
|
|
|
| 891 |
silent=True,
|
| 892 |
)
|
| 893 |
indices_to_perturb = perturbation_batch["perturb_index"]
|
|
|
|
| 894 |
# remove indices that were perturbed
|
| 895 |
original_emb = pu.remove_perturbed_indices_set(
|
| 896 |
full_original_emb,
|
|
|
|
| 899 |
self.tokens_to_perturb,
|
| 900 |
minibatch["length"],
|
| 901 |
)
|
|
|
|
| 902 |
full_perturbation_emb = get_embs(
|
| 903 |
model,
|
| 904 |
perturbation_batch,
|
|
|
|
| 910 |
summary_stat=None,
|
| 911 |
silent=True,
|
| 912 |
)
|
| 913 |
+
|
| 914 |
# remove special tokens and padding
|
| 915 |
original_emb = original_emb[:, 1:-1, :]
|
| 916 |
if self.perturb_type == "overexpress":
|
|
|
|
| 921 |
perturbation_emb = full_perturbation_emb[
|
| 922 |
:, 1 : max(perturbation_batch["length"]) - 1, :
|
| 923 |
]
|
|
|
|
|
|
|
| 924 |
|
| 925 |
+
n_perturbation_genes = perturbation_emb.size()[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 926 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 927 |
gene_cos_sims = pu.quant_cos_sims(
|
| 928 |
perturbation_emb,
|
| 929 |
original_emb,
|
geneformer/in_silico_perturber_stats.py
CHANGED
|
@@ -640,16 +640,10 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
|
|
| 640 |
cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
|
| 641 |
|
| 642 |
# quantify number of detections of each gene
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
]
|
| 648 |
-
else:
|
| 649 |
-
cos_sims_full_df["N_Detections"] = [
|
| 650 |
-
n_detections(i, dict_list, "gene", anchor_token)
|
| 651 |
-
for i in cos_sims_full_df["Gene"]
|
| 652 |
-
]
|
| 653 |
|
| 654 |
if combos == 0:
|
| 655 |
cos_sims_full_df = cos_sims_full_df.sort_values(
|
|
@@ -676,7 +670,6 @@ class InSilicoPerturberStats:
|
|
| 676 |
"anchor_gene": {None, str},
|
| 677 |
"cell_states_to_model": {None, dict},
|
| 678 |
"pickle_suffix": {None, str},
|
| 679 |
-
"model_version": {"V1", "V2"},
|
| 680 |
}
|
| 681 |
|
| 682 |
def __init__(
|
|
@@ -687,7 +680,6 @@ class InSilicoPerturberStats:
|
|
| 687 |
anchor_gene=None,
|
| 688 |
cell_states_to_model=None,
|
| 689 |
pickle_suffix="_raw.pickle",
|
| 690 |
-
model_version="V2",
|
| 691 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
| 692 |
gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
|
| 693 |
):
|
|
@@ -715,7 +707,7 @@ class InSilicoPerturberStats:
|
|
| 715 |
| analyzes data for anchor gene perturbed in combination with each other gene.
|
| 716 |
| However, if combos=0 and anchor_gene="ENSG00000136574":
|
| 717 |
| analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
|
| 718 |
-
cell_states_to_model
|
| 719 |
| Cell states to model if testing perturbations that achieve goal state change.
|
| 720 |
| Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
| 721 |
| state_key: key specifying name of column in .dataset that defines the start/goal states
|
|
@@ -726,12 +718,6 @@ class InSilicoPerturberStats:
|
|
| 726 |
| "start_state": "dcm",
|
| 727 |
| "goal_state": "nf",
|
| 728 |
| "alt_states": ["hcm", "other1", "other2"]}
|
| 729 |
-
pickle_suffix : None, str
|
| 730 |
-
| Suffix to subselect intermediate raw files for analysis.
|
| 731 |
-
| Default output of InSilicoPerturber uses suffix "_raw.pickle".
|
| 732 |
-
model_version : str
|
| 733 |
-
| To auto-select settings for model version other than current default.
|
| 734 |
-
| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
|
| 735 |
token_dictionary_file : Path
|
| 736 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
| 737 |
gene_name_id_dictionary_file : Path
|
|
@@ -744,15 +730,9 @@ class InSilicoPerturberStats:
|
|
| 744 |
self.anchor_gene = anchor_gene
|
| 745 |
self.cell_states_to_model = cell_states_to_model
|
| 746 |
self.pickle_suffix = pickle_suffix
|
| 747 |
-
self.model_version = model_version
|
| 748 |
|
| 749 |
self.validate_options()
|
| 750 |
|
| 751 |
-
if self.model_version == "V1":
|
| 752 |
-
from . import ENSEMBL_DICTIONARY_FILE_30M, TOKEN_DICTIONARY_FILE_30M
|
| 753 |
-
token_dictionary_file=TOKEN_DICTIONARY_FILE_30M
|
| 754 |
-
gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE_30M
|
| 755 |
-
|
| 756 |
# load token dictionary (Ensembl IDs:token)
|
| 757 |
with open(token_dictionary_file, "rb") as f:
|
| 758 |
self.gene_token_dict = pickle.load(f)
|
|
|
|
| 640 |
cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
|
| 641 |
|
| 642 |
# quantify number of detections of each gene
|
| 643 |
+
cos_sims_full_df["N_Detections"] = [
|
| 644 |
+
n_detections(i, dict_list, "gene", anchor_token)
|
| 645 |
+
for i in cos_sims_full_df["Gene"]
|
| 646 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 647 |
|
| 648 |
if combos == 0:
|
| 649 |
cos_sims_full_df = cos_sims_full_df.sort_values(
|
|
|
|
| 670 |
"anchor_gene": {None, str},
|
| 671 |
"cell_states_to_model": {None, dict},
|
| 672 |
"pickle_suffix": {None, str},
|
|
|
|
| 673 |
}
|
| 674 |
|
| 675 |
def __init__(
|
|
|
|
| 680 |
anchor_gene=None,
|
| 681 |
cell_states_to_model=None,
|
| 682 |
pickle_suffix="_raw.pickle",
|
|
|
|
| 683 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
| 684 |
gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
|
| 685 |
):
|
|
|
|
| 707 |
| analyzes data for anchor gene perturbed in combination with each other gene.
|
| 708 |
| However, if combos=0 and anchor_gene="ENSG00000136574":
|
| 709 |
| analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
|
| 710 |
+
cell_states_to_model: None, dict
|
| 711 |
| Cell states to model if testing perturbations that achieve goal state change.
|
| 712 |
| Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
| 713 |
| state_key: key specifying name of column in .dataset that defines the start/goal states
|
|
|
|
| 718 |
| "start_state": "dcm",
|
| 719 |
| "goal_state": "nf",
|
| 720 |
| "alt_states": ["hcm", "other1", "other2"]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 721 |
token_dictionary_file : Path
|
| 722 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
| 723 |
gene_name_id_dictionary_file : Path
|
|
|
|
| 730 |
self.anchor_gene = anchor_gene
|
| 731 |
self.cell_states_to_model = cell_states_to_model
|
| 732 |
self.pickle_suffix = pickle_suffix
|
|
|
|
| 733 |
|
| 734 |
self.validate_options()
|
| 735 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
# load token dictionary (Ensembl IDs:token)
|
| 737 |
with open(token_dictionary_file, "rb") as f:
|
| 738 |
self.gene_token_dict = pickle.load(f)
|
geneformer/mtl/__init__.py
CHANGED
|
@@ -1,4 +1 @@
|
|
| 1 |
-
# ruff: noqa: F401
|
| 2 |
-
|
| 3 |
-
from . import eval_utils
|
| 4 |
-
from . import utils
|
|
|
|
| 1 |
+
# ruff: noqa: F401
|
|
|
|
|
|
|
|
|
geneformer/mtl/collators.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
# imports
|
| 2 |
import torch
|
| 3 |
import pickle
|
| 4 |
-
from
|
| 5 |
-
from
|
| 6 |
|
| 7 |
"""Geneformer collator for multi-task cell classification."""
|
| 8 |
|
|
|
|
| 1 |
# imports
|
| 2 |
import torch
|
| 3 |
import pickle
|
| 4 |
+
from ..collator_for_classification import DataCollatorForGeneClassification
|
| 5 |
+
from .. import TOKEN_DICTIONARY_FILE
|
| 6 |
|
| 7 |
"""Geneformer collator for multi-task cell classification."""
|
| 8 |
|
geneformer/mtl/data.py
CHANGED
|
@@ -1,237 +1,150 @@
|
|
| 1 |
import os
|
| 2 |
-
import pickle
|
| 3 |
-
import torch
|
| 4 |
-
from torch.utils.data import DataLoader, Dataset
|
| 5 |
-
from datasets import load_from_disk
|
| 6 |
|
| 7 |
from .collators import DataCollatorForMultitaskCellClassification
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
self.cell_id_mapping = {}
|
| 19 |
-
|
| 20 |
-
# Setup task and column mappings
|
| 21 |
-
self.task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
|
| 22 |
-
self.task_to_column = dict(zip(self.task_names, config["task_columns"]))
|
| 23 |
-
config["task_names"] = self.task_names
|
| 24 |
-
|
| 25 |
-
# Check if unique_cell_id column exists in the dataset
|
| 26 |
-
self.has_unique_cell_ids = "unique_cell_id" in self.dataset.column_names
|
| 27 |
-
print(f"{'Found' if self.has_unique_cell_ids else 'No'} unique_cell_id column in {dataset_type} dataset")
|
| 28 |
-
|
| 29 |
-
# Setup label mappings
|
| 30 |
-
self.label_mappings_path = os.path.join(
|
| 31 |
-
config["results_dir"],
|
| 32 |
-
f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
if not is_test:
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def _validate_columns(self):
|
| 45 |
-
"""Ensures required columns are present in the dataset."""
|
| 46 |
-
missing_columns = [col for col in self.task_to_column.values()
|
| 47 |
-
if col not in self.dataset.column_names]
|
| 48 |
-
if missing_columns:
|
| 49 |
-
raise KeyError(
|
| 50 |
-
f"Missing columns in {self.dataset_type} dataset: {missing_columns}. "
|
| 51 |
-
f"Available columns: {self.dataset.column_names}"
|
| 52 |
-
)
|
| 53 |
-
|
| 54 |
-
def _create_label_mappings(self):
|
| 55 |
-
"""Creates label mappings for the dataset."""
|
| 56 |
task_label_mappings = {}
|
|
|
|
| 57 |
num_labels_list = []
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
return len(self.dataset)
|
| 79 |
-
|
| 80 |
-
def __getitem__(self, idx):
|
| 81 |
-
record = self.dataset[idx]
|
| 82 |
-
|
| 83 |
-
# Store cell ID mapping
|
| 84 |
-
if self.has_unique_cell_ids:
|
| 85 |
-
unique_cell_id = record["unique_cell_id"]
|
| 86 |
-
self.cell_id_mapping[idx] = unique_cell_id
|
| 87 |
-
else:
|
| 88 |
-
self.cell_id_mapping[idx] = f"cell_{idx}"
|
| 89 |
-
|
| 90 |
-
# Create transformed record
|
| 91 |
-
transformed_record = {
|
| 92 |
-
"input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
|
| 93 |
-
"cell_id": idx,
|
| 94 |
-
}
|
| 95 |
-
|
| 96 |
-
# Add labels
|
| 97 |
-
if not self.is_test:
|
| 98 |
-
label_dict = {
|
| 99 |
-
task: self.task_label_mappings[task][record[column]]
|
| 100 |
-
for task, column in self.task_to_column.items()
|
| 101 |
-
}
|
| 102 |
else:
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
|
|
|
|
|
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
|
|
|
| 121 |
|
|
|
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
config["train_path"],
|
| 130 |
-
config,
|
| 131 |
-
dataset_type="train"
|
| 132 |
-
)
|
| 133 |
-
result["train_loader"] = get_data_loader(train_dataset, config["batch_size"])
|
| 134 |
-
|
| 135 |
-
# Store the cell ID mapping from the dataset
|
| 136 |
-
result["train_cell_mapping"] = {k: v for k, v in train_dataset.cell_id_mapping.items()}
|
| 137 |
-
print(f"Collected {len(result['train_cell_mapping'])} cell IDs from training dataset")
|
| 138 |
-
|
| 139 |
-
result["num_labels_list"] = train_dataset.num_labels_list
|
| 140 |
-
|
| 141 |
-
# Process validation data
|
| 142 |
-
val_dataset = StreamingMultiTaskDataset(
|
| 143 |
-
config["val_path"],
|
| 144 |
-
config,
|
| 145 |
-
dataset_type="validation"
|
| 146 |
-
)
|
| 147 |
-
result["val_loader"] = get_data_loader(val_dataset, config["batch_size"])
|
| 148 |
-
|
| 149 |
-
# Store the complete cell ID mapping for validation
|
| 150 |
-
for idx in range(len(val_dataset)):
|
| 151 |
-
_ = val_dataset[idx]
|
| 152 |
-
|
| 153 |
-
result["val_cell_mapping"] = {k: v for k, v in val_dataset.cell_id_mapping.items()}
|
| 154 |
-
print(f"Collected {len(result['val_cell_mapping'])} cell IDs from validation dataset")
|
| 155 |
-
|
| 156 |
-
# Validate label mappings
|
| 157 |
-
validate_label_mappings(config)
|
| 158 |
-
|
| 159 |
-
# Process test data if requested
|
| 160 |
-
if include_test and "test_path" in config:
|
| 161 |
-
test_dataset = StreamingMultiTaskDataset(
|
| 162 |
-
config["test_path"],
|
| 163 |
-
config,
|
| 164 |
-
is_test=True,
|
| 165 |
-
dataset_type="test"
|
| 166 |
-
)
|
| 167 |
-
result["test_loader"] = get_data_loader(test_dataset, config["batch_size"])
|
| 168 |
-
|
| 169 |
-
for idx in range(len(test_dataset)):
|
| 170 |
-
_ = test_dataset[idx]
|
| 171 |
-
|
| 172 |
-
result["test_cell_mapping"] = {k: v for k, v in test_dataset.cell_id_mapping.items()}
|
| 173 |
-
print(f"Collected {len(result['test_cell_mapping'])} cell IDs from test dataset")
|
| 174 |
-
|
| 175 |
-
return result
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
def validate_label_mappings(config):
|
| 179 |
-
"""Ensures train and validation label mappings are consistent."""
|
| 180 |
-
train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
|
| 181 |
-
val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
|
| 182 |
-
|
| 183 |
-
with open(train_mappings_path, "rb") as f:
|
| 184 |
-
train_mappings = pickle.load(f)
|
| 185 |
-
|
| 186 |
-
with open(val_mappings_path, "rb") as f:
|
| 187 |
-
val_mappings = pickle.load(f)
|
| 188 |
-
|
| 189 |
-
for task_name in config["task_names"]:
|
| 190 |
-
if train_mappings[task_name] != val_mappings[task_name]:
|
| 191 |
-
raise ValueError(
|
| 192 |
-
f"Mismatch in label mappings for task '{task_name}'.\n"
|
| 193 |
-
f"Train Mapping: {train_mappings[task_name]}\n"
|
| 194 |
-
f"Validation Mapping: {val_mappings[task_name]}"
|
| 195 |
-
)
|
| 196 |
|
| 197 |
|
| 198 |
-
# Legacy functions for backward compatibility
|
| 199 |
def preload_and_process_data(config):
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
return (
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
)
|
| 210 |
|
| 211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
def preload_data(config):
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
|
|
|
| 216 |
|
| 217 |
|
| 218 |
def load_and_preprocess_test_data(config):
|
| 219 |
-
"""
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
is_test=True,
|
| 224 |
-
dataset_type="test"
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
-
return (
|
| 228 |
-
test_dataset,
|
| 229 |
-
test_dataset.cell_id_mapping,
|
| 230 |
-
test_dataset.num_labels_list
|
| 231 |
-
)
|
| 232 |
|
| 233 |
|
| 234 |
def prepare_test_loader(config):
|
| 235 |
-
"""
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from .collators import DataCollatorForMultitaskCellClassification
|
| 4 |
+
from .imports import *
|
| 5 |
|
| 6 |
|
| 7 |
+
def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
|
| 8 |
+
try:
|
| 9 |
+
dataset = load_from_disk(dataset_path)
|
| 10 |
+
|
| 11 |
+
task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
|
| 12 |
+
task_to_column = dict(zip(task_names, config["task_columns"]))
|
| 13 |
+
config["task_names"] = task_names
|
| 14 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
if not is_test:
|
| 16 |
+
available_columns = set(dataset.column_names)
|
| 17 |
+
for column in task_to_column.values():
|
| 18 |
+
if column not in available_columns:
|
| 19 |
+
raise KeyError(
|
| 20 |
+
f"Column {column} not found in the dataset. Available columns: {list(available_columns)}"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
label_mappings = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
task_label_mappings = {}
|
| 25 |
+
cell_id_mapping = {}
|
| 26 |
num_labels_list = []
|
| 27 |
+
|
| 28 |
+
# Load or create task label mappings
|
| 29 |
+
if not is_test:
|
| 30 |
+
for task, column in task_to_column.items():
|
| 31 |
+
unique_values = sorted(set(dataset[column])) # Ensure consistency
|
| 32 |
+
label_mappings[column] = {
|
| 33 |
+
label: idx for idx, label in enumerate(unique_values)
|
| 34 |
+
}
|
| 35 |
+
task_label_mappings[task] = label_mappings[column]
|
| 36 |
+
num_labels_list.append(len(unique_values))
|
| 37 |
+
|
| 38 |
+
# Print the mappings for each task with dataset type prefix
|
| 39 |
+
for task, mapping in task_label_mappings.items():
|
| 40 |
+
print(
|
| 41 |
+
f"{dataset_type.capitalize()} mapping for {task}: {mapping}"
|
| 42 |
+
) # sanity check, for train/validation splits
|
| 43 |
+
|
| 44 |
+
# Save the task label mappings as a pickle file
|
| 45 |
+
with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
|
| 46 |
+
pickle.dump(task_label_mappings, f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
else:
|
| 48 |
+
# Load task label mappings from pickle file for test data
|
| 49 |
+
with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
|
| 50 |
+
task_label_mappings = pickle.load(f)
|
| 51 |
+
|
| 52 |
+
# Infer num_labels_list from task_label_mappings
|
| 53 |
+
for task, mapping in task_label_mappings.items():
|
| 54 |
+
num_labels_list.append(len(mapping))
|
| 55 |
+
|
| 56 |
+
# Store unique cell IDs in a separate dictionary
|
| 57 |
+
for idx, record in enumerate(dataset):
|
| 58 |
+
cell_id = record.get("unique_cell_id", idx)
|
| 59 |
+
cell_id_mapping[idx] = cell_id
|
| 60 |
+
|
| 61 |
+
# Transform records to the desired format
|
| 62 |
+
transformed_dataset = []
|
| 63 |
+
for idx, record in enumerate(dataset):
|
| 64 |
+
transformed_record = {}
|
| 65 |
+
transformed_record["input_ids"] = torch.tensor(
|
| 66 |
+
record["input_ids"], dtype=torch.long
|
| 67 |
+
)
|
| 68 |
|
| 69 |
+
# Use index-based cell ID for internal tracking
|
| 70 |
+
transformed_record["cell_id"] = idx
|
| 71 |
|
| 72 |
+
if not is_test:
|
| 73 |
+
# Prepare labels
|
| 74 |
+
label_dict = {}
|
| 75 |
+
for task, column in task_to_column.items():
|
| 76 |
+
label_value = record[column]
|
| 77 |
+
label_index = task_label_mappings[task][label_value]
|
| 78 |
+
label_dict[task] = label_index
|
| 79 |
+
transformed_record["label"] = label_dict
|
| 80 |
+
else:
|
| 81 |
+
# Create dummy labels for test data
|
| 82 |
+
label_dict = {task: -1 for task in config["task_names"]}
|
| 83 |
+
transformed_record["label"] = label_dict
|
| 84 |
|
| 85 |
+
transformed_dataset.append(transformed_record)
|
| 86 |
|
| 87 |
+
return transformed_dataset, cell_id_mapping, num_labels_list
|
| 88 |
+
except KeyError as e:
|
| 89 |
+
print(f"Missing configuration or dataset key: {e}")
|
| 90 |
+
except Exception as e:
|
| 91 |
+
print(f"An error occurred while loading or preprocessing data: {e}")
|
| 92 |
+
return None, None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
|
|
|
|
| 95 |
def preload_and_process_data(config):
|
| 96 |
+
# Load and preprocess data once
|
| 97 |
+
train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
|
| 98 |
+
config["train_path"], config, dataset_type="train"
|
| 99 |
+
)
|
| 100 |
+
val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(
|
| 101 |
+
config["val_path"], config, dataset_type="validation"
|
| 102 |
+
)
|
| 103 |
return (
|
| 104 |
+
train_dataset,
|
| 105 |
+
train_cell_id_mapping,
|
| 106 |
+
val_dataset,
|
| 107 |
+
val_cell_id_mapping,
|
| 108 |
+
num_labels_list,
|
| 109 |
)
|
| 110 |
|
| 111 |
|
| 112 |
+
def get_data_loader(preprocessed_dataset, batch_size):
|
| 113 |
+
nproc = os.cpu_count() ### I/O operations
|
| 114 |
+
|
| 115 |
+
data_collator = DataCollatorForMultitaskCellClassification()
|
| 116 |
+
|
| 117 |
+
loader = DataLoader(
|
| 118 |
+
preprocessed_dataset,
|
| 119 |
+
batch_size=batch_size,
|
| 120 |
+
shuffle=True,
|
| 121 |
+
collate_fn=data_collator,
|
| 122 |
+
num_workers=nproc,
|
| 123 |
+
pin_memory=True,
|
| 124 |
+
)
|
| 125 |
+
return loader
|
| 126 |
+
|
| 127 |
+
|
| 128 |
def preload_data(config):
|
| 129 |
+
# Preprocessing the data before the Optuna trials start
|
| 130 |
+
train_loader = get_data_loader("train", config)
|
| 131 |
+
val_loader = get_data_loader("val", config)
|
| 132 |
+
return train_loader, val_loader
|
| 133 |
|
| 134 |
|
| 135 |
def load_and_preprocess_test_data(config):
|
| 136 |
+
"""
|
| 137 |
+
Load and preprocess test data, treating it as unlabeled.
|
| 138 |
+
"""
|
| 139 |
+
return load_and_preprocess_data(config["test_path"], config, is_test=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
|
| 142 |
def prepare_test_loader(config):
|
| 143 |
+
"""
|
| 144 |
+
Prepare DataLoader for the test dataset.
|
| 145 |
+
"""
|
| 146 |
+
test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(
|
| 147 |
+
config
|
| 148 |
+
)
|
| 149 |
+
test_loader = get_data_loader(test_dataset, config["batch_size"])
|
| 150 |
+
return test_loader, cell_id_mapping, num_labels_list
|
geneformer/mtl/eval_utils.py
CHANGED
|
@@ -1,16 +1,19 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import json
|
| 3 |
-
import torch
|
| 4 |
import pandas as pd
|
| 5 |
|
| 6 |
-
from .
|
|
|
|
| 7 |
from .model import GeneformerMultiTask
|
| 8 |
|
|
|
|
| 9 |
def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
|
| 10 |
task_pred_labels = {task_name: [] for task_name in config["task_names"]}
|
| 11 |
task_pred_probs = {task_name: [] for task_name in config["task_names"]}
|
| 12 |
cell_ids = []
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
model.eval()
|
| 15 |
with torch.no_grad():
|
| 16 |
for batch in test_loader:
|
|
@@ -82,4 +85,4 @@ def load_and_evaluate_test_model(config):
|
|
| 82 |
best_model.to(device)
|
| 83 |
|
| 84 |
evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
|
| 85 |
-
print("Evaluation completed.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
|
| 3 |
+
from .imports import * # noqa # isort:skip
|
| 4 |
+
from .data import prepare_test_loader # noqa # isort:skip
|
| 5 |
from .model import GeneformerMultiTask
|
| 6 |
|
| 7 |
+
|
| 8 |
def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
|
| 9 |
task_pred_labels = {task_name: [] for task_name in config["task_names"]}
|
| 10 |
task_pred_probs = {task_name: [] for task_name in config["task_names"]}
|
| 11 |
cell_ids = []
|
| 12 |
|
| 13 |
+
# # Load task label mappings from pickle file
|
| 14 |
+
# with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
|
| 15 |
+
# task_label_mappings = pickle.load(f)
|
| 16 |
+
|
| 17 |
model.eval()
|
| 18 |
with torch.no_grad():
|
| 19 |
for batch in test_loader:
|
|
|
|
| 85 |
best_model.to(device)
|
| 86 |
|
| 87 |
evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
|
| 88 |
+
print("Evaluation completed.")
|
geneformer/mtl/imports.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import gc
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import pickle
|
| 6 |
+
import sys
|
| 7 |
+
import warnings
|
| 8 |
+
from enum import Enum
|
| 9 |
+
from itertools import chain
|
| 10 |
+
from typing import Dict, List, Optional, Union
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import optuna
|
| 14 |
+
import pandas as pd
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import torch.optim as optim
|
| 19 |
+
from datasets import load_from_disk
|
| 20 |
+
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
|
| 21 |
+
from sklearn.model_selection import train_test_split
|
| 22 |
+
from sklearn.preprocessing import LabelEncoder
|
| 23 |
+
from torch.utils.data import DataLoader
|
| 24 |
+
from transformers import (
|
| 25 |
+
AdamW,
|
| 26 |
+
BatchEncoding,
|
| 27 |
+
BertConfig,
|
| 28 |
+
BertModel,
|
| 29 |
+
DataCollatorForTokenClassification,
|
| 30 |
+
SpecialTokensMixin,
|
| 31 |
+
get_cosine_schedule_with_warmup,
|
| 32 |
+
get_linear_schedule_with_warmup,
|
| 33 |
+
get_scheduler,
|
| 34 |
+
)
|
| 35 |
+
from transformers.utils import logging, to_py_obj
|
| 36 |
+
|
| 37 |
+
from .collators import DataCollatorForMultitaskCellClassification
|
| 38 |
+
|
| 39 |
+
# local modules
|
| 40 |
+
from .data import get_data_loader, preload_and_process_data
|
| 41 |
+
from .model import GeneformerMultiTask
|
| 42 |
+
from .optuna_utils import create_optuna_study
|
| 43 |
+
from .utils import save_model
|
geneformer/mtl/model.py
CHANGED
|
@@ -118,4 +118,4 @@ class GeneformerMultiTask(nn.Module):
|
|
| 118 |
f"Error during loss computation for task {task_id}: {e}"
|
| 119 |
)
|
| 120 |
|
| 121 |
-
return total_loss, logits, losses if labels is not None else logits
|
|
|
|
| 118 |
f"Error during loss computation for task {task_id}: {e}"
|
| 119 |
)
|
| 120 |
|
| 121 |
+
return total_loss, logits, losses if labels is not None else logits
|
geneformer/mtl/optuna_utils.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import optuna
|
| 2 |
+
from optuna.integration import TensorBoardCallback
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def save_trial_callback(study, trial, trials_result_path):
|
| 6 |
+
with open(trials_result_path, "a") as f:
|
| 7 |
+
f.write(
|
| 8 |
+
f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
|
| 13 |
+
study = optuna.create_study(direction="maximize")
|
| 14 |
+
|
| 15 |
+
# init TensorBoard callback
|
| 16 |
+
tensorboard_callback = TensorBoardCallback(
|
| 17 |
+
dirname=tensorboard_log_dir, metric_name="F1 Macro"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# callback and TensorBoard callback
|
| 21 |
+
callbacks = [
|
| 22 |
+
lambda study, trial: save_trial_callback(study, trial, trials_result_path),
|
| 23 |
+
tensorboard_callback,
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
|
| 27 |
+
return study
|
geneformer/mtl/train.py
CHANGED
|
@@ -1,707 +1,380 @@
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
import torch
|
| 4 |
-
import torch.distributed as dist
|
| 5 |
-
import torch.multiprocessing as mp
|
| 6 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 7 |
from torch.utils.tensorboard import SummaryWriter
|
| 8 |
from tqdm import tqdm
|
| 9 |
-
import optuna
|
| 10 |
-
import functools
|
| 11 |
-
import time
|
| 12 |
|
|
|
|
| 13 |
from .model import GeneformerMultiTask
|
| 14 |
-
from .utils import
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
labels = [
|
| 98 |
-
batch["labels"][task_name].to(
|
|
|
|
| 99 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
loss = loss / accumulation_steps
|
| 107 |
-
|
| 108 |
-
forward_end = time.time()
|
| 109 |
-
forward_times.append(forward_end - forward_start)
|
| 110 |
-
|
| 111 |
-
# Track loss - store the unscaled loss for reporting
|
| 112 |
-
unscaled_loss = loss.item() * (1 if accumulation_steps == 1 else accumulation_steps)
|
| 113 |
-
total_loss += unscaled_loss
|
| 114 |
-
num_batches += 1
|
| 115 |
-
accumulated_loss += loss.item() # For gradient accumulation tracking
|
| 116 |
-
|
| 117 |
-
backward_start = time.time()
|
| 118 |
-
|
| 119 |
-
# Use no_sync() for all but the last accumulation step to avoid unnecessary communication
|
| 120 |
-
if self.is_distributed and accumulation_steps > 1:
|
| 121 |
-
# If this is not the last accumulation step or the last batch
|
| 122 |
-
if (batch_idx + 1) % accumulation_steps != 0 and (batch_idx + 1) != len(train_loader):
|
| 123 |
-
with self.model.no_sync():
|
| 124 |
-
loss.backward()
|
| 125 |
-
else:
|
| 126 |
-
loss.backward()
|
| 127 |
-
else:
|
| 128 |
-
# Non-distributed training or accumulation_steps=1
|
| 129 |
-
loss.backward()
|
| 130 |
-
|
| 131 |
-
backward_end = time.time()
|
| 132 |
-
backward_times.append(backward_end - backward_start)
|
| 133 |
-
|
| 134 |
-
# Only update weights after accumulation_steps or at the end of the epoch
|
| 135 |
-
if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
|
| 136 |
-
if self.config["gradient_clipping"]:
|
| 137 |
-
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
|
| 138 |
-
|
| 139 |
-
optimizer_start = time.time()
|
| 140 |
-
self.optimizer.step()
|
| 141 |
-
self.scheduler.step()
|
| 142 |
-
self.optimizer.zero_grad()
|
| 143 |
-
optimizer_end = time.time()
|
| 144 |
-
optimizer_times.append(optimizer_end - optimizer_start)
|
| 145 |
-
|
| 146 |
-
# Log after optimizer step
|
| 147 |
-
if self.is_main_process:
|
| 148 |
-
# Calculate running average loss
|
| 149 |
-
avg_loss = total_loss / num_batches
|
| 150 |
-
|
| 151 |
-
log_training_step(avg_loss, self.writer, self.config, epoch, len(train_loader), batch_idx)
|
| 152 |
-
|
| 153 |
-
# Update progress bar with just the running average loss
|
| 154 |
-
progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})
|
| 155 |
-
|
| 156 |
-
accumulated_loss = 0.0
|
| 157 |
-
else:
|
| 158 |
-
optimizer_times.append(0) # No optimizer step taken
|
| 159 |
-
|
| 160 |
-
batch_end = time.time()
|
| 161 |
-
batch_times.append(batch_end - batch_start)
|
| 162 |
-
|
| 163 |
-
epoch_end = time.time()
|
| 164 |
-
|
| 165 |
-
# Calculate average loss for the epoch
|
| 166 |
-
epoch_avg_loss = total_loss / num_batches
|
| 167 |
-
|
| 168 |
-
# If distributed, gather losses from all processes to compute global average
|
| 169 |
-
if self.is_distributed:
|
| 170 |
-
# Create a tensor to hold the loss
|
| 171 |
-
loss_tensor = torch.tensor([epoch_avg_loss], device=self.device)
|
| 172 |
-
# Gather losses from all processes
|
| 173 |
-
all_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
|
| 174 |
-
dist.all_gather(all_losses, loss_tensor)
|
| 175 |
-
# Compute the global average loss across all processes
|
| 176 |
-
epoch_avg_loss = torch.mean(torch.stack(all_losses)).item()
|
| 177 |
-
|
| 178 |
-
if self.is_main_process:
|
| 179 |
-
# douhble check if batch_size has already been adjusted for world_size in the config
|
| 180 |
-
# This avoids double-counting the effective batch size
|
| 181 |
-
per_gpu_batch_size = self.config['batch_size']
|
| 182 |
-
total_effective_batch = per_gpu_batch_size * accumulation_steps * world_size
|
| 183 |
-
|
| 184 |
-
print(f"Epoch {epoch+1} timing:")
|
| 185 |
-
print(f" Total epoch time: {epoch_end - epoch_start:.2f}s")
|
| 186 |
-
print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
|
| 187 |
-
print(f" Average forward time: {sum(forward_times)/len(forward_times):.4f}s")
|
| 188 |
-
print(f" Average backward time: {sum(backward_times)/len(backward_times):.4f}s")
|
| 189 |
-
print(f" Average optimizer time: {sum([t for t in optimizer_times if t > 0])/max(1, len([t for t in optimizer_times if t > 0])):.4f}s")
|
| 190 |
-
print(f" Gradient accumulation steps: {accumulation_steps}")
|
| 191 |
-
print(f" Batch size per GPU: {per_gpu_batch_size}")
|
| 192 |
-
print(f" Effective global batch size: {total_effective_batch}")
|
| 193 |
-
print(f" Average training loss: {epoch_avg_loss:.4f}")
|
| 194 |
-
if self.is_distributed:
|
| 195 |
-
print(f" Total batches processed across all GPUs: {total_batches_global}")
|
| 196 |
-
print(f" Communication optimization: Using no_sync() for gradient accumulation")
|
| 197 |
-
|
| 198 |
-
return epoch_avg_loss # Return the average loss for the epoch
|
| 199 |
-
|
| 200 |
-
def validate_model(self, val_loader):
|
| 201 |
-
val_start = time.time()
|
| 202 |
-
self.model.eval()
|
| 203 |
-
val_loss = 0.0
|
| 204 |
-
task_true_labels = {task_name: [] for task_name in self.config["task_names"]}
|
| 205 |
-
task_pred_labels = {task_name: [] for task_name in self.config["task_names"]}
|
| 206 |
-
task_pred_probs = {task_name: [] for task_name in self.config["task_names"]}
|
| 207 |
-
|
| 208 |
-
val_cell_ids = {}
|
| 209 |
-
sample_counter = 0
|
| 210 |
-
|
| 211 |
-
batch_times = []
|
| 212 |
-
|
| 213 |
-
# Print validation dataset size
|
| 214 |
-
if self.is_main_process:
|
| 215 |
-
print(f"Validation dataset size: {len(val_loader.dataset)} samples")
|
| 216 |
-
print(f"Number of validation batches: {len(val_loader)}")
|
| 217 |
-
|
| 218 |
-
if self.is_distributed:
|
| 219 |
-
world_size = dist.get_world_size()
|
| 220 |
-
print(f"Distributed validation: {world_size} GPUs")
|
| 221 |
-
if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
|
| 222 |
-
samples_per_gpu = val_loader.sampler.num_samples
|
| 223 |
-
print(f"Each GPU processes {samples_per_gpu} validation samples")
|
| 224 |
-
print(f"Total validation samples processed: {samples_per_gpu * world_size}")
|
| 225 |
-
|
| 226 |
-
with torch.no_grad():
|
| 227 |
-
for batch in val_loader:
|
| 228 |
-
batch_start = time.time()
|
| 229 |
-
input_ids = batch["input_ids"].to(self.device)
|
| 230 |
-
attention_mask = batch["attention_mask"].to(self.device)
|
| 231 |
-
labels = [
|
| 232 |
-
batch["labels"][task_name].to(self.device)
|
| 233 |
-
for task_name in self.config["task_names"]
|
| 234 |
-
]
|
| 235 |
-
loss, logits, _ = self.model(input_ids, attention_mask, labels)
|
| 236 |
-
val_loss += loss.item()
|
| 237 |
-
|
| 238 |
-
if "cell_id" in batch:
|
| 239 |
-
for i, cell_id in enumerate(batch["cell_id"]):
|
| 240 |
-
# Store the actual index for later mapping to unique_cell_id
|
| 241 |
-
val_cell_ids[sample_counter + i] = cell_id.item()
|
| 242 |
-
|
| 243 |
-
for sample_idx in range(len(batch["input_ids"])):
|
| 244 |
-
for i, task_name in enumerate(self.config["task_names"]):
|
| 245 |
-
true_label = batch["labels"][task_name][sample_idx].item()
|
| 246 |
-
pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
|
| 247 |
-
# Store the full probability distribution
|
| 248 |
-
pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy().tolist()
|
| 249 |
-
task_true_labels[task_name].append(true_label)
|
| 250 |
-
task_pred_labels[task_name].append(pred_label)
|
| 251 |
-
task_pred_probs[task_name].append(pred_prob)
|
| 252 |
-
|
| 253 |
-
# Update current index for cell ID tracking
|
| 254 |
-
sample_counter += len(batch["input_ids"])
|
| 255 |
-
|
| 256 |
-
batch_end = time.time()
|
| 257 |
-
batch_times.append(batch_end - batch_start)
|
| 258 |
-
|
| 259 |
-
# norm validation loss by the number of batches
|
| 260 |
-
val_loss /= len(val_loader)
|
| 261 |
-
|
| 262 |
-
# distributed, gather results from all processes
|
| 263 |
-
if self.is_distributed:
|
| 264 |
-
# Create tensors to hold the local results
|
| 265 |
-
loss_tensor = torch.tensor([val_loss], device=self.device)
|
| 266 |
-
gathered_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
|
| 267 |
-
dist.all_gather(gathered_losses, loss_tensor)
|
| 268 |
-
|
| 269 |
-
# Compute average loss across all processes
|
| 270 |
-
val_loss = torch.mean(torch.cat(gathered_losses)).item()
|
| 271 |
-
|
| 272 |
-
world_size = dist.get_world_size()
|
| 273 |
-
|
| 274 |
-
if self.is_main_process:
|
| 275 |
-
print(f"Collected predictions from rank {self.local_rank}")
|
| 276 |
-
print(f"Number of samples processed by this rank: {sample_counter}")
|
| 277 |
-
|
| 278 |
-
val_end = time.time()
|
| 279 |
-
|
| 280 |
-
if self.is_main_process:
|
| 281 |
-
print(f"Validation timing:")
|
| 282 |
-
print(f" Total validation time: {val_end - val_start:.2f}s")
|
| 283 |
-
print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
|
| 284 |
-
print(f" Collected {len(val_cell_ids)} cell indices from validation data")
|
| 285 |
-
print(f" Processed {sample_counter} total samples during validation")
|
| 286 |
-
|
| 287 |
-
# Print number of samples per task
|
| 288 |
-
for task_name in self.config["task_names"]:
|
| 289 |
-
print(f" Task {task_name}: {len(task_true_labels[task_name])} samples")
|
| 290 |
-
|
| 291 |
-
return val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids
|
| 292 |
-
|
| 293 |
-
def train(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
|
| 294 |
-
"""Train the model and return validation loss and trained model."""
|
| 295 |
-
if self.config.get("use_wandb", False) and self.is_main_process:
|
| 296 |
-
initialize_wandb(self.config)
|
| 297 |
-
|
| 298 |
-
# Create model
|
| 299 |
-
self.model = create_model(self.config, num_labels_list, self.device, self.is_distributed, self.local_rank)
|
| 300 |
-
|
| 301 |
-
# Setup optimizer and scheduler
|
| 302 |
-
total_steps = len(train_loader) * self.config["epochs"]
|
| 303 |
-
self.optimizer, self.scheduler = setup_optimizer_and_scheduler(self.model, self.config, total_steps)
|
| 304 |
-
|
| 305 |
-
# Training loop
|
| 306 |
-
if self.is_main_process:
|
| 307 |
-
epoch_progress = tqdm(range(self.config["epochs"]), desc="Training Progress")
|
| 308 |
-
else:
|
| 309 |
-
epoch_progress = range(self.config["epochs"])
|
| 310 |
-
|
| 311 |
-
best_val_loss = float('inf')
|
| 312 |
-
train_losses = []
|
| 313 |
-
|
| 314 |
-
with setup_logging(self.config) as self.writer:
|
| 315 |
-
for epoch in epoch_progress:
|
| 316 |
-
if self.is_distributed:
|
| 317 |
-
train_loader.sampler.set_epoch(epoch)
|
| 318 |
-
|
| 319 |
-
train_loss = self.train_epoch(train_loader, epoch)
|
| 320 |
-
train_losses.append(train_loss)
|
| 321 |
-
|
| 322 |
-
# Run validation after each epoch if configured to do so
|
| 323 |
-
if self.config.get("validate_each_epoch", False):
|
| 324 |
-
val_loss, _, _, _, _ = self.validate_model(val_loader)
|
| 325 |
-
if val_loss < best_val_loss:
|
| 326 |
-
best_val_loss = val_loss
|
| 327 |
-
|
| 328 |
-
if self.is_main_process:
|
| 329 |
-
epoch_progress.set_postfix({
|
| 330 |
-
"train_loss": f"{train_loss:.4f}",
|
| 331 |
-
"val_loss": f"{val_loss:.4f}",
|
| 332 |
-
"best_val_loss": f"{best_val_loss:.4f}"
|
| 333 |
-
})
|
| 334 |
-
else:
|
| 335 |
-
if self.is_main_process:
|
| 336 |
-
epoch_progress.set_postfix({
|
| 337 |
-
"train_loss": f"{train_loss:.4f}"
|
| 338 |
-
})
|
| 339 |
-
|
| 340 |
-
val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = self.validate_model(val_loader)
|
| 341 |
-
task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
|
| 342 |
-
|
| 343 |
-
if self.is_main_process:
|
| 344 |
-
log_validation_metrics(task_metrics, val_loss, self.config, self.writer, self.config["epochs"])
|
| 345 |
-
|
| 346 |
-
# Save validation predictions
|
| 347 |
-
save_validation_predictions(
|
| 348 |
-
val_cell_ids,
|
| 349 |
-
task_true_labels,
|
| 350 |
-
task_pred_labels,
|
| 351 |
-
task_pred_probs,
|
| 352 |
-
{**self.config, "val_cell_mapping": val_cell_id_mapping} # Include the mapping
|
| 353 |
-
)
|
| 354 |
-
|
| 355 |
-
if self.config.get("use_wandb", False):
|
| 356 |
-
import wandb
|
| 357 |
-
wandb.finish()
|
| 358 |
-
|
| 359 |
-
print(f"\nTraining Summary:")
|
| 360 |
-
print(f" Final Training Loss: {train_losses[-1]:.4f}")
|
| 361 |
-
print(f" Final Validation Loss: {val_loss:.4f}")
|
| 362 |
-
for task_name, metrics in task_metrics.items():
|
| 363 |
-
print(f" {task_name} - F1: {metrics['f1']:.4f}, Accuracy: {metrics['accuracy']:.4f}")
|
| 364 |
-
|
| 365 |
-
return val_loss, self.model # Return both the validation loss and the trained model
|
| 366 |
-
|
| 367 |
-
def setup(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
|
| 368 |
-
if self.is_distributed:
|
| 369 |
-
self.device = torch.device(f"cuda:{self.local_rank}")
|
| 370 |
-
else:
|
| 371 |
-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 372 |
-
|
| 373 |
-
self.model = create_model(self.config, num_labels_list, self.device)
|
| 374 |
-
|
| 375 |
-
# war model w DDP
|
| 376 |
-
if self.is_distributed:
|
| 377 |
-
self.model = DDP(self.model, device_ids=[self.local_rank])
|
| 378 |
-
|
| 379 |
-
# communication hook to optimize gradient synchronization
|
| 380 |
-
from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
|
| 381 |
-
|
| 382 |
-
# default hook which maintains full precision
|
| 383 |
-
self.model.register_comm_hook(
|
| 384 |
-
state=None,
|
| 385 |
-
hook=comm_hooks.allreduce_hook
|
| 386 |
)
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
print(f"Rank {self.local_rank}: This GPU will process {train_loader.sampler.num_samples} training samples per epoch")
|
| 394 |
-
|
| 395 |
-
if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
|
| 396 |
-
print(f"Rank {self.local_rank}: This GPU will process {val_loader.sampler.num_samples} validation samples")
|
| 397 |
-
|
| 398 |
-
# Set up optimizer and scheduler
|
| 399 |
-
self.optimizer, self.scheduler = setup_optimizer_and_scheduler(
|
| 400 |
-
self.model, self.config, len(train_loader)
|
| 401 |
)
|
| 402 |
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
if
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
)
|
| 424 |
-
|
| 425 |
-
return result
|
| 426 |
-
else:
|
| 427 |
-
print("Distributed training requested but only one GPU found. Falling back to single GPU training.")
|
| 428 |
-
config["distributed_training"] = False
|
| 429 |
-
|
| 430 |
-
# Non-distributed training
|
| 431 |
-
trainer = Trainer(config)
|
| 432 |
-
trainer.device = device
|
| 433 |
-
return trainer.train(train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list)
|
| 434 |
|
|
|
|
|
|
|
|
|
|
| 435 |
|
| 436 |
-
|
| 437 |
-
|
|
|
|
|
|
|
| 438 |
train_loader,
|
| 439 |
val_loader,
|
| 440 |
train_cell_id_mapping,
|
| 441 |
val_cell_id_mapping,
|
| 442 |
num_labels_list,
|
| 443 |
-
config,
|
| 444 |
-
device,
|
| 445 |
):
|
| 446 |
-
"""Objective function for Optuna hyperparameter optimization."""
|
| 447 |
set_seed(config["seed"])
|
| 448 |
initialize_wandb(config)
|
| 449 |
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
f"task_weight_{i}",
|
| 462 |
-
param_config["low"],
|
| 463 |
-
param_config["high"],
|
| 464 |
-
)
|
| 465 |
-
for i in range(len(num_labels_list))
|
| 466 |
-
]
|
| 467 |
-
weight_sum = sum(weights)
|
| 468 |
-
trial_config[param_name] = [w / weight_sum for w in weights]
|
| 469 |
-
elif "log" in param_config and param_config["log"]:
|
| 470 |
-
trial_config[param_name] = trial.suggest_float(
|
| 471 |
-
param_name, param_config["low"], param_config["high"], log=True
|
| 472 |
-
)
|
| 473 |
-
else:
|
| 474 |
-
trial_config[param_name] = trial.suggest_float(
|
| 475 |
-
param_name, param_config["low"], param_config["high"]
|
| 476 |
-
)
|
| 477 |
-
|
| 478 |
-
# Set appropriate max layers to freeze based on pretrained model
|
| 479 |
-
if "max_layers_to_freeze" in trial_config:
|
| 480 |
-
freeze_range = get_layer_freeze_range(trial_config["pretrained_path"])
|
| 481 |
-
trial_config["max_layers_to_freeze"] = int(trial.suggest_int(
|
| 482 |
-
"max_layers_to_freeze",
|
| 483 |
-
freeze_range["min"],
|
| 484 |
-
freeze_range["max"]
|
| 485 |
-
))
|
| 486 |
-
|
| 487 |
-
trial_config["run_name"] = f"trial_{trial.number}"
|
| 488 |
-
|
| 489 |
-
# Handle distributed training for this trial
|
| 490 |
-
if trial_config.get("distributed_training", False) and torch.cuda.device_count() > 1:
|
| 491 |
-
manager = mp.Manager()
|
| 492 |
-
shared_dict = manager.dict()
|
| 493 |
-
|
| 494 |
-
train_distributed(
|
| 495 |
-
Trainer,
|
| 496 |
-
trial_config,
|
| 497 |
-
train_loader,
|
| 498 |
-
val_loader,
|
| 499 |
-
train_cell_id_mapping,
|
| 500 |
-
val_cell_id_mapping,
|
| 501 |
-
num_labels_list,
|
| 502 |
-
trial.number,
|
| 503 |
-
shared_dict
|
| 504 |
-
)
|
| 505 |
-
|
| 506 |
-
val_loss = shared_dict.get('val_loss', float('inf'))
|
| 507 |
-
task_metrics = shared_dict.get('task_metrics', {})
|
| 508 |
-
|
| 509 |
-
trial.set_user_attr("model_state_dict", shared_dict.get('model_state_dict', {}))
|
| 510 |
-
trial.set_user_attr("task_weights", trial_config["task_weights"])
|
| 511 |
-
|
| 512 |
-
if config.get("use_wandb", False):
|
| 513 |
-
import wandb
|
| 514 |
-
wandb.log({
|
| 515 |
-
"trial_number": trial.number,
|
| 516 |
-
"val_loss": val_loss,
|
| 517 |
-
**{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
|
| 518 |
-
**{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
|
| 519 |
-
})
|
| 520 |
-
wandb.finish()
|
| 521 |
-
|
| 522 |
-
return val_loss
|
| 523 |
-
|
| 524 |
-
with setup_logging(trial_config) as writer:
|
| 525 |
-
trainer = Trainer(trial_config)
|
| 526 |
-
trainer.device = device
|
| 527 |
-
trainer.writer = writer
|
| 528 |
-
|
| 529 |
-
# Create model with trial hyperparameters
|
| 530 |
-
trainer.model = create_model(trial_config, num_labels_list, device)
|
| 531 |
-
total_steps = len(train_loader) * config["epochs"]
|
| 532 |
-
trainer.optimizer, trainer.scheduler = setup_optimizer_and_scheduler(trainer.model, trial_config, total_steps)
|
| 533 |
-
|
| 534 |
-
# Training loop
|
| 535 |
-
for epoch in range(config["epochs"]):
|
| 536 |
-
trainer.train_epoch(train_loader, epoch)
|
| 537 |
-
|
| 538 |
-
val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = trainer.validate_model(val_loader)
|
| 539 |
-
task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
|
| 540 |
-
|
| 541 |
-
# Log metrics
|
| 542 |
-
log_validation_metrics(task_metrics, val_loss, trial_config, writer, config["epochs"])
|
| 543 |
-
|
| 544 |
-
# Save validation predictions
|
| 545 |
-
save_validation_predictions(
|
| 546 |
-
val_cell_ids,
|
| 547 |
-
task_true_labels,
|
| 548 |
-
task_pred_labels,
|
| 549 |
-
task_pred_probs,
|
| 550 |
-
{**trial_config, "val_cell_mapping": val_cell_id_mapping},
|
| 551 |
-
trial.number,
|
| 552 |
)
|
|
|
|
| 553 |
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
|
|
|
| 557 |
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
if trial.should_prune():
|
| 561 |
-
raise optuna.TrialPruned()
|
| 562 |
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
{
|
| 567 |
-
"trial_number": trial.number,
|
| 568 |
-
"val_loss": val_loss,
|
| 569 |
-
**{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
|
| 570 |
-
**{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
|
| 571 |
-
**{k: v for k, v in trial_config.items() if k in [
|
| 572 |
-
"learning_rate", "warmup_ratio", "weight_decay", "dropout_rate",
|
| 573 |
-
"lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze"
|
| 574 |
-
]},
|
| 575 |
-
}
|
| 576 |
-
)
|
| 577 |
-
wandb.finish()
|
| 578 |
|
| 579 |
-
|
|
|
|
| 580 |
|
|
|
|
| 581 |
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
(
|
| 585 |
-
device,
|
| 586 |
-
train_loader,
|
| 587 |
-
val_loader,
|
| 588 |
-
train_cell_id_mapping,
|
| 589 |
-
val_cell_id_mapping,
|
| 590 |
-
num_labels_list,
|
| 591 |
-
) = prepare_training_environment(config)
|
| 592 |
|
| 593 |
-
print("\nManual hyperparameters being used:")
|
| 594 |
-
for key, value in config["manual_hyperparameters"].items():
|
| 595 |
-
print(f"{key}: {value}")
|
| 596 |
-
print()
|
| 597 |
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
|
| 602 |
-
#
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
)
|
| 612 |
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
|
|
|
|
|
|
|
|
|
| 633 |
|
| 634 |
-
|
|
|
|
|
|
|
| 635 |
|
|
|
|
|
|
|
| 636 |
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
device,
|
| 642 |
-
train_loader,
|
| 643 |
-
val_loader,
|
| 644 |
-
train_cell_id_mapping,
|
| 645 |
-
val_cell_id_mapping,
|
| 646 |
-
num_labels_list,
|
| 647 |
-
) = prepare_training_environment(config)
|
| 648 |
-
|
| 649 |
-
# If manual hyperparameters are specified, use them instead of running Optuna
|
| 650 |
-
if config.get("use_manual_hyperparameters", False):
|
| 651 |
-
return run_manual_tuning(config)
|
| 652 |
-
|
| 653 |
-
# Create a partial function with fixed arguments for the objective
|
| 654 |
-
objective_with_config_and_data = functools.partial(
|
| 655 |
-
objective,
|
| 656 |
-
train_loader=train_loader,
|
| 657 |
-
val_loader=val_loader,
|
| 658 |
-
train_cell_id_mapping=train_cell_id_mapping,
|
| 659 |
-
val_cell_id_mapping=val_cell_id_mapping,
|
| 660 |
-
num_labels_list=num_labels_list,
|
| 661 |
-
config=config,
|
| 662 |
-
device=device,
|
| 663 |
-
)
|
| 664 |
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
direction="minimize", # Minimize validation loss
|
| 668 |
-
study_name=config["study_name"],
|
| 669 |
-
# storage=config["storage"],
|
| 670 |
-
load_if_exists=True,
|
| 671 |
)
|
|
|
|
| 672 |
|
| 673 |
-
|
|
|
|
| 674 |
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
config["pretrained_path"],
|
| 683 |
-
num_labels_list,
|
| 684 |
-
dropout_rate=best_params["dropout_rate"],
|
| 685 |
-
use_task_weights=config["use_task_weights"],
|
| 686 |
-
task_weights=best_task_weights,
|
| 687 |
-
max_layers_to_freeze=best_params.get("max_layers_to_freeze", 0),
|
| 688 |
-
use_attention_pooling=best_params.get("use_attention_pooling", False),
|
| 689 |
)
|
| 690 |
|
| 691 |
-
|
| 692 |
-
|
| 693 |
|
| 694 |
-
|
| 695 |
-
k.replace("module.", ""): v for k, v in best_model_state_dict.items()
|
| 696 |
-
}
|
| 697 |
|
| 698 |
-
|
|
|
|
| 699 |
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
)
|
| 703 |
-
save_model(best_model, model_save_directory)
|
| 704 |
|
| 705 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 706 |
|
| 707 |
-
return
|
|
|
|
| 1 |
import os
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
import pandas as pd
|
| 6 |
import torch
|
|
|
|
|
|
|
|
|
|
| 7 |
from torch.utils.tensorboard import SummaryWriter
|
| 8 |
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
from .imports import *
|
| 11 |
from .model import GeneformerMultiTask
|
| 12 |
+
from .utils import calculate_task_specific_metrics, get_layer_freeze_range
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def set_seed(seed):
|
| 16 |
+
random.seed(seed)
|
| 17 |
+
np.random.seed(seed)
|
| 18 |
+
torch.manual_seed(seed)
|
| 19 |
+
torch.cuda.manual_seed_all(seed)
|
| 20 |
+
torch.backends.cudnn.deterministic = True
|
| 21 |
+
torch.backends.cudnn.benchmark = False
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def initialize_wandb(config):
|
| 25 |
+
if config.get("use_wandb", False):
|
| 26 |
+
import wandb
|
| 27 |
+
|
| 28 |
+
wandb.init(project=config["wandb_project"], config=config)
|
| 29 |
+
print("Weights & Biases (wandb) initialized and will be used for logging.")
|
| 30 |
+
else:
|
| 31 |
+
print(
|
| 32 |
+
"Weights & Biases (wandb) is not enabled. Logging will use other methods."
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def create_model(config, num_labels_list, device):
|
| 37 |
+
model = GeneformerMultiTask(
|
| 38 |
+
config["pretrained_path"],
|
| 39 |
+
num_labels_list,
|
| 40 |
+
dropout_rate=config["dropout_rate"],
|
| 41 |
+
use_task_weights=config["use_task_weights"],
|
| 42 |
+
task_weights=config["task_weights"],
|
| 43 |
+
max_layers_to_freeze=config["max_layers_to_freeze"],
|
| 44 |
+
use_attention_pooling=config["use_attention_pooling"],
|
| 45 |
+
)
|
| 46 |
+
if config["use_data_parallel"]:
|
| 47 |
+
model = nn.DataParallel(model)
|
| 48 |
+
return model.to(device)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def setup_optimizer_and_scheduler(model, config, total_steps):
|
| 52 |
+
optimizer = AdamW(
|
| 53 |
+
model.parameters(),
|
| 54 |
+
lr=config["learning_rate"],
|
| 55 |
+
weight_decay=config["weight_decay"],
|
| 56 |
+
)
|
| 57 |
+
warmup_steps = int(config["warmup_ratio"] * total_steps)
|
| 58 |
+
|
| 59 |
+
if config["lr_scheduler_type"] == "linear":
|
| 60 |
+
scheduler = get_linear_schedule_with_warmup(
|
| 61 |
+
optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
|
| 62 |
+
)
|
| 63 |
+
elif config["lr_scheduler_type"] == "cosine":
|
| 64 |
+
scheduler = get_cosine_schedule_with_warmup(
|
| 65 |
+
optimizer,
|
| 66 |
+
num_warmup_steps=warmup_steps,
|
| 67 |
+
num_training_steps=total_steps,
|
| 68 |
+
num_cycles=0.5,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
return optimizer, scheduler
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def train_epoch(
|
| 75 |
+
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
| 76 |
+
):
|
| 77 |
+
model.train()
|
| 78 |
+
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
|
| 79 |
+
for batch_idx, batch in enumerate(progress_bar):
|
| 80 |
+
optimizer.zero_grad()
|
| 81 |
+
input_ids = batch["input_ids"].to(device)
|
| 82 |
+
attention_mask = batch["attention_mask"].to(device)
|
| 83 |
+
labels = [
|
| 84 |
+
batch["labels"][task_name].to(device) for task_name in config["task_names"]
|
| 85 |
+
]
|
| 86 |
+
|
| 87 |
+
loss, _, _ = model(input_ids, attention_mask, labels)
|
| 88 |
+
loss.backward()
|
| 89 |
+
|
| 90 |
+
if config["gradient_clipping"]:
|
| 91 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
|
| 92 |
+
|
| 93 |
+
optimizer.step()
|
| 94 |
+
scheduler.step()
|
| 95 |
+
|
| 96 |
+
writer.add_scalar(
|
| 97 |
+
"Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
|
| 98 |
+
)
|
| 99 |
+
if config.get("use_wandb", False):
|
| 100 |
+
import wandb
|
| 101 |
+
|
| 102 |
+
wandb.log({"Training Loss": loss.item()})
|
| 103 |
+
|
| 104 |
+
# Update progress bar
|
| 105 |
+
progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
|
| 106 |
+
|
| 107 |
+
return loss.item() # Return the last batch loss
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def validate_model(model, val_loader, device, config):
|
| 111 |
+
model.eval()
|
| 112 |
+
val_loss = 0.0
|
| 113 |
+
task_true_labels = {task_name: [] for task_name in config["task_names"]}
|
| 114 |
+
task_pred_labels = {task_name: [] for task_name in config["task_names"]}
|
| 115 |
+
task_pred_probs = {task_name: [] for task_name in config["task_names"]}
|
| 116 |
+
|
| 117 |
+
with torch.no_grad():
|
| 118 |
+
for batch in val_loader:
|
| 119 |
+
input_ids = batch["input_ids"].to(device)
|
| 120 |
+
attention_mask = batch["attention_mask"].to(device)
|
| 121 |
labels = [
|
| 122 |
+
batch["labels"][task_name].to(device)
|
| 123 |
+
for task_name in config["task_names"]
|
| 124 |
]
|
| 125 |
+
loss, logits, _ = model(input_ids, attention_mask, labels)
|
| 126 |
+
val_loss += loss.item()
|
| 127 |
+
|
| 128 |
+
for sample_idx in range(len(batch["input_ids"])):
|
| 129 |
+
for i, task_name in enumerate(config["task_names"]):
|
| 130 |
+
true_label = batch["labels"][task_name][sample_idx].item()
|
| 131 |
+
pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
|
| 132 |
+
pred_prob = (
|
| 133 |
+
torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
|
| 134 |
+
)
|
| 135 |
+
task_true_labels[task_name].append(true_label)
|
| 136 |
+
task_pred_labels[task_name].append(pred_label)
|
| 137 |
+
task_pred_probs[task_name].append(pred_prob)
|
| 138 |
+
|
| 139 |
+
val_loss /= len(val_loader)
|
| 140 |
+
return val_loss, task_true_labels, task_pred_labels, task_pred_probs
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def log_metrics(task_metrics, val_loss, config, writer, epochs):
|
| 144 |
+
for task_name, metrics in task_metrics.items():
|
| 145 |
+
print(
|
| 146 |
+
f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
|
| 147 |
+
)
|
| 148 |
+
if config.get("use_wandb", False):
|
| 149 |
+
import wandb
|
| 150 |
|
| 151 |
+
wandb.log(
|
| 152 |
+
{
|
| 153 |
+
f"{task_name} Validation F1 Macro": metrics["f1"],
|
| 154 |
+
f"{task_name} Validation Accuracy": metrics["accuracy"],
|
| 155 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
)
|
| 157 |
+
|
| 158 |
+
writer.add_scalar("Validation Loss", val_loss, epochs)
|
| 159 |
+
for task_name, metrics in task_metrics.items():
|
| 160 |
+
writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
|
| 161 |
+
writer.add_scalar(
|
| 162 |
+
f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
|
| 165 |
+
|
| 166 |
+
def save_validation_predictions(
|
| 167 |
+
val_cell_id_mapping,
|
| 168 |
+
task_true_labels,
|
| 169 |
+
task_pred_labels,
|
| 170 |
+
task_pred_probs,
|
| 171 |
+
config,
|
| 172 |
+
trial_number=None,
|
| 173 |
+
):
|
| 174 |
+
if trial_number is not None:
|
| 175 |
+
trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
|
| 176 |
+
os.makedirs(trial_results_dir, exist_ok=True)
|
| 177 |
+
val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
|
| 178 |
+
else:
|
| 179 |
+
val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
|
| 180 |
+
|
| 181 |
+
rows = []
|
| 182 |
+
for sample_idx in range(len(val_cell_id_mapping)):
|
| 183 |
+
row = {"Cell ID": val_cell_id_mapping[sample_idx]}
|
| 184 |
+
for task_name in config["task_names"]:
|
| 185 |
+
row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
|
| 186 |
+
row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
|
| 187 |
+
row[f"{task_name} Probabilities"] = ",".join(
|
| 188 |
+
map(str, task_pred_probs[task_name][sample_idx])
|
| 189 |
)
|
| 190 |
+
rows.append(row)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
+
df = pd.DataFrame(rows)
|
| 193 |
+
df.to_csv(val_preds_file, index=False)
|
| 194 |
+
print(f"Validation predictions saved to {val_preds_file}")
|
| 195 |
|
| 196 |
+
|
| 197 |
+
def train_model(
|
| 198 |
+
config,
|
| 199 |
+
device,
|
| 200 |
train_loader,
|
| 201 |
val_loader,
|
| 202 |
train_cell_id_mapping,
|
| 203 |
val_cell_id_mapping,
|
| 204 |
num_labels_list,
|
|
|
|
|
|
|
| 205 |
):
|
|
|
|
| 206 |
set_seed(config["seed"])
|
| 207 |
initialize_wandb(config)
|
| 208 |
|
| 209 |
+
model = create_model(config, num_labels_list, device)
|
| 210 |
+
total_steps = len(train_loader) * config["epochs"]
|
| 211 |
+
optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
|
| 212 |
+
|
| 213 |
+
log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
|
| 214 |
+
writer = SummaryWriter(log_dir=log_dir)
|
| 215 |
+
|
| 216 |
+
epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
|
| 217 |
+
for epoch in epoch_progress:
|
| 218 |
+
last_loss = train_epoch(
|
| 219 |
+
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
)
|
| 221 |
+
epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
|
| 222 |
|
| 223 |
+
val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
|
| 224 |
+
model, val_loader, device, config
|
| 225 |
+
)
|
| 226 |
+
task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
|
| 227 |
|
| 228 |
+
log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
|
| 229 |
+
writer.close()
|
|
|
|
|
|
|
| 230 |
|
| 231 |
+
save_validation_predictions(
|
| 232 |
+
val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
|
| 233 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
+
if config.get("use_wandb", False):
|
| 236 |
+
import wandb
|
| 237 |
|
| 238 |
+
wandb.finish()
|
| 239 |
|
| 240 |
+
print(f"\nFinal Validation Loss: {val_loss:.4f}")
|
| 241 |
+
return val_loss, model # Return both the validation loss and the trained model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
+
def objective(
|
| 245 |
+
trial,
|
| 246 |
+
train_loader,
|
| 247 |
+
val_loader,
|
| 248 |
+
train_cell_id_mapping,
|
| 249 |
+
val_cell_id_mapping,
|
| 250 |
+
num_labels_list,
|
| 251 |
+
config,
|
| 252 |
+
device,
|
| 253 |
+
):
|
| 254 |
+
set_seed(config["seed"]) # Set the seed before each trial
|
| 255 |
+
initialize_wandb(config)
|
| 256 |
|
| 257 |
+
# Hyperparameters
|
| 258 |
+
config["learning_rate"] = trial.suggest_float(
|
| 259 |
+
"learning_rate",
|
| 260 |
+
config["hyperparameters"]["learning_rate"]["low"],
|
| 261 |
+
config["hyperparameters"]["learning_rate"]["high"],
|
| 262 |
+
log=config["hyperparameters"]["learning_rate"]["log"],
|
| 263 |
+
)
|
| 264 |
+
config["warmup_ratio"] = trial.suggest_float(
|
| 265 |
+
"warmup_ratio",
|
| 266 |
+
config["hyperparameters"]["warmup_ratio"]["low"],
|
| 267 |
+
config["hyperparameters"]["warmup_ratio"]["high"],
|
| 268 |
+
)
|
| 269 |
+
config["weight_decay"] = trial.suggest_float(
|
| 270 |
+
"weight_decay",
|
| 271 |
+
config["hyperparameters"]["weight_decay"]["low"],
|
| 272 |
+
config["hyperparameters"]["weight_decay"]["high"],
|
| 273 |
+
)
|
| 274 |
+
config["dropout_rate"] = trial.suggest_float(
|
| 275 |
+
"dropout_rate",
|
| 276 |
+
config["hyperparameters"]["dropout_rate"]["low"],
|
| 277 |
+
config["hyperparameters"]["dropout_rate"]["high"],
|
| 278 |
+
)
|
| 279 |
+
config["lr_scheduler_type"] = trial.suggest_categorical(
|
| 280 |
+
"lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
|
| 281 |
+
)
|
| 282 |
+
config["use_attention_pooling"] = trial.suggest_categorical(
|
| 283 |
+
"use_attention_pooling", [False]
|
| 284 |
)
|
| 285 |
|
| 286 |
+
if config["use_task_weights"]:
|
| 287 |
+
config["task_weights"] = [
|
| 288 |
+
trial.suggest_float(
|
| 289 |
+
f"task_weight_{i}",
|
| 290 |
+
config["hyperparameters"]["task_weights"]["low"],
|
| 291 |
+
config["hyperparameters"]["task_weights"]["high"],
|
| 292 |
+
)
|
| 293 |
+
for i in range(len(num_labels_list))
|
| 294 |
+
]
|
| 295 |
+
weight_sum = sum(config["task_weights"])
|
| 296 |
+
config["task_weights"] = [
|
| 297 |
+
weight / weight_sum for weight in config["task_weights"]
|
| 298 |
+
]
|
| 299 |
+
else:
|
| 300 |
+
config["task_weights"] = None
|
| 301 |
+
|
| 302 |
+
# Dynamic range for max_layers_to_freeze
|
| 303 |
+
freeze_range = get_layer_freeze_range(config["pretrained_path"])
|
| 304 |
+
config["max_layers_to_freeze"] = trial.suggest_int(
|
| 305 |
+
"max_layers_to_freeze",
|
| 306 |
+
freeze_range["min"],
|
| 307 |
+
freeze_range["max"]
|
| 308 |
+
)
|
| 309 |
|
| 310 |
+
model = create_model(config, num_labels_list, device)
|
| 311 |
+
total_steps = len(train_loader) * config["epochs"]
|
| 312 |
+
optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
|
| 313 |
|
| 314 |
+
log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
|
| 315 |
+
writer = SummaryWriter(log_dir=log_dir)
|
| 316 |
|
| 317 |
+
for epoch in range(config["epochs"]):
|
| 318 |
+
train_epoch(
|
| 319 |
+
model, train_loader, optimizer, scheduler, device, config, writer, epoch
|
| 320 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
+
val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
|
| 323 |
+
model, val_loader, device, config
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
)
|
| 325 |
+
task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
|
| 326 |
|
| 327 |
+
log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
|
| 328 |
+
writer.close()
|
| 329 |
|
| 330 |
+
save_validation_predictions(
|
| 331 |
+
val_cell_id_mapping,
|
| 332 |
+
task_true_labels,
|
| 333 |
+
task_pred_labels,
|
| 334 |
+
task_pred_probs,
|
| 335 |
+
config,
|
| 336 |
+
trial.number,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
)
|
| 338 |
|
| 339 |
+
trial.set_user_attr("model_state_dict", model.state_dict())
|
| 340 |
+
trial.set_user_attr("task_weights", config["task_weights"])
|
| 341 |
|
| 342 |
+
trial.report(val_loss, config["epochs"])
|
|
|
|
|
|
|
| 343 |
|
| 344 |
+
if trial.should_prune():
|
| 345 |
+
raise optuna.TrialPruned()
|
| 346 |
|
| 347 |
+
if config.get("use_wandb", False):
|
| 348 |
+
import wandb
|
|
|
|
|
|
|
| 349 |
|
| 350 |
+
wandb.log(
|
| 351 |
+
{
|
| 352 |
+
"trial_number": trial.number,
|
| 353 |
+
"val_loss": val_loss,
|
| 354 |
+
**{
|
| 355 |
+
f"{task_name}_f1": metrics["f1"]
|
| 356 |
+
for task_name, metrics in task_metrics.items()
|
| 357 |
+
},
|
| 358 |
+
**{
|
| 359 |
+
f"{task_name}_accuracy": metrics["accuracy"]
|
| 360 |
+
for task_name, metrics in task_metrics.items()
|
| 361 |
+
},
|
| 362 |
+
**{
|
| 363 |
+
k: v
|
| 364 |
+
for k, v in config.items()
|
| 365 |
+
if k
|
| 366 |
+
in [
|
| 367 |
+
"learning_rate",
|
| 368 |
+
"warmup_ratio",
|
| 369 |
+
"weight_decay",
|
| 370 |
+
"dropout_rate",
|
| 371 |
+
"lr_scheduler_type",
|
| 372 |
+
"use_attention_pooling",
|
| 373 |
+
"max_layers_to_freeze",
|
| 374 |
+
]
|
| 375 |
+
},
|
| 376 |
+
}
|
| 377 |
+
)
|
| 378 |
+
wandb.finish()
|
| 379 |
|
| 380 |
+
return val_loss
|
geneformer/mtl/train_utils.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
from .data import get_data_loader, preload_and_process_data
|
| 4 |
+
from .imports import *
|
| 5 |
+
from .model import GeneformerMultiTask
|
| 6 |
+
from .train import objective, train_model
|
| 7 |
+
from .utils import save_model
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def set_seed(seed):
    """Seed every RNG in play (python, numpy, torch CPU + CUDA) for reproducibility.

    Also pins cuDNN to deterministic kernel selection, trading autotuning
    speed for run-to-run repeatability.
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # no-op when CUDA is unavailable
    ):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def run_manual_tuning(config):
    """Train once with user-chosen hyperparameters and persist the result.

    Reads config["manual_hyperparameters"], promotes those values to
    top-level config keys, trains a GeneformerMultiTask model, saves the
    trained weights and the effective hyperparameters under
    config["model_save_path"]/GeneformerMultiTask, and returns the
    validation loss.
    """
    # Seed all RNGs before any data loading or weight initialization.
    set_seed(config["seed"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Datasets, cell-id mappings and per-task label counts come from the
    # shared preprocessing pipeline.
    (
        train_dataset,
        train_cell_id_mapping,
        val_dataset,
        val_cell_id_mapping,
        num_labels_list,
    ) = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])

    manual_params = config["manual_hyperparameters"]

    # Echo the chosen hyperparameters so the run log is self-documenting.
    print("\nManual hyperparameters being used:")
    for key, value in manual_params.items():
        print(f"{key}: {value}")
    print()  # blank line for readability

    # Promote the manual values to the top-level keys train_model reads.
    config.update(manual_params)

    val_loss, trained_model = train_model(
        config,
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    )

    print(f"\nValidation loss with manual hyperparameters: {val_loss}")

    # Persist the trained model.
    model_save_directory = os.path.join(
        config["model_save_path"], "GeneformerMultiTask"
    )
    save_model(trained_model, model_save_directory)

    # Save the effective hyperparameters next to the model weights so the
    # run can be reproduced later.
    hyperparams_to_save = dict(manual_params)
    hyperparams_to_save.update(
        {
            "dropout_rate": config["dropout_rate"],
            "use_task_weights": config["use_task_weights"],
            "task_weights": config["task_weights"],
            "max_layers_to_freeze": config["max_layers_to_freeze"],
            "use_attention_pooling": config["use_attention_pooling"],
        }
    )
    hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
    with open(hyperparams_path, "w") as f:
        json.dump(hyperparams_to_save, f)
    print(f"Manual hyperparameters saved to {hyperparams_path}")

    return val_loss
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def run_optuna_study(config):
    """Run an Optuna hyperparameter search (or a single manual run) and save the best model.

    When config["use_manual_hyperparameters"] is truthy, a single
    train_model run is executed and this function returns without saving
    anything here.  Otherwise an Optuna study minimizes the validation
    loss over config["n_trials"] trials, then rebuilds, loads and saves
    the best trial's model plus its hyperparameters.
    """
    # Set seed for reproducibility
    set_seed(config["seed"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Datasets, cell-id mappings and per-task label counts from the shared pipeline.
    (
        train_dataset,
        train_cell_id_mapping,
        val_dataset,
        val_cell_id_mapping,
        num_labels_list,
    ) = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])

    if config["use_manual_hyperparameters"]:
        # NOTE(review): the manual path discards train_model's return value and
        # saves nothing — use run_manual_tuning for a persisted manual run.
        train_model(
            config,
            device,
            train_loader,
            val_loader,
            train_cell_id_mapping,
            val_cell_id_mapping,
            num_labels_list,
        )
    else:
        # Bind the data/config arguments so Optuna only supplies the trial.
        objective_with_config_and_data = functools.partial(
            objective,
            train_loader=train_loader,
            val_loader=val_loader,
            train_cell_id_mapping=train_cell_id_mapping,
            val_cell_id_mapping=val_cell_id_mapping,
            num_labels_list=num_labels_list,
            config=config,
            device=device,
        )

        # NOTE(review): with storage commented out the study is in-memory,
        # so load_if_exists has no effect — confirm whether persistent
        # storage was intended.
        study = optuna.create_study(
            direction="minimize",  # Minimize validation loss
            study_name=config["study_name"],
            # storage=config["storage"],
            load_if_exists=True,
        )

        study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])

        # After finding the best trial
        best_params = study.best_trial.params
        # The objective stores task weights as a trial user attribute.
        best_task_weights = study.best_trial.user_attrs["task_weights"]
        print("Saving the best model and its hyperparameters...")

        # Rebuild a fresh model with the winning hyperparameters before
        # loading the best trial's weights into it.
        best_model = GeneformerMultiTask(
            config["pretrained_path"],
            num_labels_list,
            dropout_rate=best_params["dropout_rate"],
            use_task_weights=config["use_task_weights"],
            task_weights=best_task_weights,
        )

        # Get the best model state dictionary (stored by the objective as a user attr).
        best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]

        # Remove the "module." prefix from the state dictionary keys if present
        # (added by DataParallel/DDP wrappers during training).
        best_model_state_dict = {
            k.replace("module.", ""): v for k, v in best_model_state_dict.items()
        }

        # Load the modified state dictionary into the model, skipping unexpected keys
        best_model.load_state_dict(best_model_state_dict, strict=False)

        model_save_directory = os.path.join(
            config["model_save_path"], "GeneformerMultiTask"
        )
        save_model(best_model, model_save_directory)

        # Additionally, save the best hyperparameters and task weights
        hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")

        with open(hyperparams_path, "w") as f:
            json.dump({**best_params, "task_weights": best_task_weights}, f)
        print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
|
geneformer/mtl/utils.py
CHANGED
|
@@ -1,641 +1,129 @@
|
|
| 1 |
-
from typing import Dict, List, Optional, Union
|
| 2 |
-
import json
|
| 3 |
import os
|
| 4 |
-
import
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import numpy as np
|
| 8 |
-
import optuna
|
| 9 |
from sklearn.metrics import accuracy_score, f1_score
|
| 10 |
from sklearn.preprocessing import LabelEncoder
|
| 11 |
-
from
|
| 12 |
-
from transformers import AutoConfig, BertConfig, BertModel, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
|
| 13 |
-
from torch.optim import AdamW
|
| 14 |
-
import pandas as pd
|
| 15 |
-
import torch.distributed as dist
|
| 16 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 17 |
-
import torch.multiprocessing as mp
|
| 18 |
-
from contextlib import contextmanager
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def set_seed(seed):
|
| 22 |
-
random.seed(seed)
|
| 23 |
-
np.random.seed(seed)
|
| 24 |
-
torch.manual_seed(seed)
|
| 25 |
-
torch.cuda.manual_seed_all(seed)
|
| 26 |
-
torch.backends.cudnn.deterministic = True
|
| 27 |
-
torch.backends.cudnn.benchmark = False
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def initialize_wandb(config):
|
| 31 |
-
if config.get("use_wandb", False):
|
| 32 |
-
import wandb
|
| 33 |
-
wandb.init(
|
| 34 |
-
project=config.get("wandb_project", "geneformer_multitask"),
|
| 35 |
-
name=config.get("run_name", "experiment"),
|
| 36 |
-
config=config,
|
| 37 |
-
reinit=True,
|
| 38 |
-
)
|
| 39 |
|
| 40 |
-
|
| 41 |
-
def create_model(config, num_labels_list, device, is_distributed=False, local_rank=0):
|
| 42 |
-
"""Create and initialize the model based on configuration."""
|
| 43 |
-
from .model import GeneformerMultiTask
|
| 44 |
-
|
| 45 |
-
model = GeneformerMultiTask(
|
| 46 |
-
config["pretrained_path"],
|
| 47 |
-
num_labels_list,
|
| 48 |
-
dropout_rate=config.get("dropout_rate", 0.1),
|
| 49 |
-
use_task_weights=config.get("use_task_weights", False),
|
| 50 |
-
task_weights=config.get("task_weights", None),
|
| 51 |
-
max_layers_to_freeze=config.get("max_layers_to_freeze", 0),
|
| 52 |
-
use_attention_pooling=config.get("use_attention_pooling", False),
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
# Move model to device
|
| 56 |
-
model.to(device)
|
| 57 |
-
|
| 58 |
-
if is_distributed:
|
| 59 |
-
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
|
| 60 |
-
|
| 61 |
-
return model
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def setup_optimizer_and_scheduler(model, config, total_steps):
|
| 65 |
-
"""Set up optimizer and learning rate scheduler."""
|
| 66 |
-
no_decay = ["bias", "LayerNorm.weight"]
|
| 67 |
-
optimizer_grouped_parameters = [
|
| 68 |
-
{
|
| 69 |
-
"params": [p for n, p in model.named_parameters()
|
| 70 |
-
if not any(nd in n for nd in no_decay) and p.requires_grad],
|
| 71 |
-
"weight_decay": config["weight_decay"],
|
| 72 |
-
},
|
| 73 |
-
{
|
| 74 |
-
"params": [p for n, p in model.named_parameters()
|
| 75 |
-
if any(nd in n for nd in no_decay) and p.requires_grad],
|
| 76 |
-
"weight_decay": 0.0,
|
| 77 |
-
},
|
| 78 |
-
]
|
| 79 |
-
|
| 80 |
-
optimizer = AdamW(
|
| 81 |
-
optimizer_grouped_parameters,
|
| 82 |
-
lr=config["learning_rate"],
|
| 83 |
-
eps=config.get("adam_epsilon", 1e-8)
|
| 84 |
-
)
|
| 85 |
-
|
| 86 |
-
# Prepare scheduler
|
| 87 |
-
warmup_steps = int(total_steps * config["warmup_ratio"])
|
| 88 |
-
|
| 89 |
-
scheduler_map = {
|
| 90 |
-
"linear": get_linear_schedule_with_warmup,
|
| 91 |
-
"cosine": get_cosine_schedule_with_warmup
|
| 92 |
-
}
|
| 93 |
-
|
| 94 |
-
scheduler_fn = scheduler_map.get(config["lr_scheduler_type"])
|
| 95 |
-
if not scheduler_fn:
|
| 96 |
-
raise ValueError(f"Unsupported scheduler type: {config['lr_scheduler_type']}")
|
| 97 |
-
|
| 98 |
-
scheduler = scheduler_fn(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
|
| 99 |
-
|
| 100 |
-
return optimizer, scheduler
|
| 101 |
|
| 102 |
|
| 103 |
def save_model(model, model_save_directory):
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
#
|
| 108 |
-
if isinstance(model,
|
| 109 |
-
|
|
|
|
|
|
|
| 110 |
else:
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
|
| 116 |
torch.save(model_state_dict, model_save_path)
|
| 117 |
|
| 118 |
# Save the model configuration
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def save_hyperparameters(model_save_directory, hyperparams):
|
| 125 |
-
"""Save hyperparameters to a JSON file."""
|
| 126 |
-
hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
|
| 127 |
-
with open(hyperparams_path, "w") as f:
|
| 128 |
-
json.dump(hyperparams, f)
|
| 129 |
-
print(f"Hyperparameters saved to {hyperparams_path}")
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def calculate_metrics(labels=None, preds=None, task_data=None, metric_type="task_specific", return_format="dict"):
|
| 133 |
-
if metric_type == "single":
|
| 134 |
-
# Calculate metrics for a single task
|
| 135 |
-
if labels is None or preds is None:
|
| 136 |
-
raise ValueError("Labels and predictions must be provided for single task metrics")
|
| 137 |
-
|
| 138 |
-
task_name = None
|
| 139 |
-
if isinstance(labels, dict) and len(labels) == 1:
|
| 140 |
-
task_name = list(labels.keys())[0]
|
| 141 |
-
labels = labels[task_name]
|
| 142 |
-
preds = preds[task_name]
|
| 143 |
-
|
| 144 |
-
f1 = f1_score(labels, preds, average="macro")
|
| 145 |
-
accuracy = accuracy_score(labels, preds)
|
| 146 |
-
|
| 147 |
-
if return_format == "tuple":
|
| 148 |
-
return f1, accuracy
|
| 149 |
-
|
| 150 |
-
result = {"f1": f1, "accuracy": accuracy}
|
| 151 |
-
if task_name:
|
| 152 |
-
return {task_name: result}
|
| 153 |
-
return result
|
| 154 |
-
|
| 155 |
-
elif metric_type == "task_specific":
|
| 156 |
-
# Calculate metrics for multiple tasks
|
| 157 |
-
if task_data:
|
| 158 |
-
result = {}
|
| 159 |
-
for task_name, (task_labels, task_preds) in task_data.items():
|
| 160 |
-
f1 = f1_score(task_labels, task_preds, average="macro")
|
| 161 |
-
accuracy = accuracy_score(task_labels, task_preds)
|
| 162 |
-
result[task_name] = {"f1": f1, "accuracy": accuracy}
|
| 163 |
-
return result
|
| 164 |
-
elif isinstance(labels, dict) and isinstance(preds, dict):
|
| 165 |
-
result = {}
|
| 166 |
-
for task_name in labels:
|
| 167 |
-
if task_name in preds:
|
| 168 |
-
f1 = f1_score(labels[task_name], preds[task_name], average="macro")
|
| 169 |
-
accuracy = accuracy_score(labels[task_name], preds[task_name])
|
| 170 |
-
result[task_name] = {"f1": f1, "accuracy": accuracy}
|
| 171 |
-
return result
|
| 172 |
-
else:
|
| 173 |
-
raise ValueError("For task_specific metrics, either task_data or labels and preds dictionaries must be provided")
|
| 174 |
-
|
| 175 |
-
elif metric_type == "combined":
|
| 176 |
-
# Calculate combined metrics across all tasks
|
| 177 |
-
if labels is None or preds is None:
|
| 178 |
-
raise ValueError("Labels and predictions must be provided for combined metrics")
|
| 179 |
-
|
| 180 |
-
# Handle label encoding for non-numeric labels
|
| 181 |
-
if not all(isinstance(x, (int, float)) for x in labels + preds):
|
| 182 |
-
le = LabelEncoder()
|
| 183 |
-
le.fit(labels + preds)
|
| 184 |
-
labels = le.transform(labels)
|
| 185 |
-
preds = le.transform(preds)
|
| 186 |
-
|
| 187 |
-
f1 = f1_score(labels, preds, average="macro")
|
| 188 |
-
accuracy = accuracy_score(labels, preds)
|
| 189 |
-
|
| 190 |
-
if return_format == "tuple":
|
| 191 |
-
return f1, accuracy
|
| 192 |
-
return {"f1": f1, "accuracy": accuracy}
|
| 193 |
-
|
| 194 |
-
else:
|
| 195 |
-
raise ValueError(f"Unknown metric_type: {metric_type}")
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
def get_layer_freeze_range(pretrained_path):
|
| 199 |
-
if not pretrained_path:
|
| 200 |
-
return {"min": 0, "max": 0}
|
| 201 |
-
|
| 202 |
-
config = AutoConfig.from_pretrained(pretrained_path)
|
| 203 |
-
total_layers = config.num_hidden_layers
|
| 204 |
-
return {"min": 0, "max": total_layers - 1}
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
def prepare_training_environment(config):
|
| 208 |
-
"""
|
| 209 |
-
Prepare the training environment by setting seed and loading data.
|
| 210 |
-
|
| 211 |
-
Returns:
|
| 212 |
-
tuple: (device, train_loader, val_loader, train_cell_id_mapping,
|
| 213 |
-
val_cell_id_mapping, num_labels_list)
|
| 214 |
-
"""
|
| 215 |
-
from .data import prepare_data_loaders
|
| 216 |
-
|
| 217 |
-
# Set seed for reproducibility
|
| 218 |
-
set_seed(config["seed"])
|
| 219 |
-
|
| 220 |
-
# Set up device - for non-distributed training
|
| 221 |
-
if not config.get("distributed_training", False):
|
| 222 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 223 |
else:
|
| 224 |
-
|
| 225 |
-
device = None
|
| 226 |
-
|
| 227 |
-
# Load data using the streaming dataset
|
| 228 |
-
data = prepare_data_loaders(config)
|
| 229 |
-
|
| 230 |
-
# For distributed training, we'll set up samplers later in the distributed worker
|
| 231 |
-
# Don't create DistributedSampler here as process group isn't initialized yet
|
| 232 |
-
|
| 233 |
-
return (
|
| 234 |
-
device,
|
| 235 |
-
data["train_loader"],
|
| 236 |
-
data["val_loader"],
|
| 237 |
-
data["train_cell_mapping"],
|
| 238 |
-
data["val_cell_mapping"],
|
| 239 |
-
data["num_labels_list"],
|
| 240 |
-
)
|
| 241 |
|
|
|
|
| 242 |
|
| 243 |
-
# Optuna hyperparameter optimization utilities
|
| 244 |
-
def save_trial_callback(study, trial, trials_result_path):
|
| 245 |
-
"""
|
| 246 |
-
Callback to save Optuna trial results to a file.
|
| 247 |
-
|
| 248 |
-
Args:
|
| 249 |
-
study: Optuna study object
|
| 250 |
-
trial: Current trial object
|
| 251 |
-
trials_result_path: Path to save trial results
|
| 252 |
-
"""
|
| 253 |
-
with open(trials_result_path, "a") as f:
|
| 254 |
-
f.write(
|
| 255 |
-
f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
|
| 256 |
-
)
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
-
def create_optuna_study(objective, n_trials: int, trials_result_path: str, tensorboard_log_dir: str) -> optuna.Study:
|
| 260 |
-
"""Create and run an Optuna study with TensorBoard logging."""
|
| 261 |
-
from optuna.integration import TensorBoardCallback
|
| 262 |
-
|
| 263 |
-
study = optuna.create_study(direction="maximize")
|
| 264 |
-
study.optimize(
|
| 265 |
-
objective,
|
| 266 |
-
n_trials=n_trials,
|
| 267 |
-
callbacks=[
|
| 268 |
-
lambda study, trial: save_trial_callback(study, trial, trials_result_path),
|
| 269 |
-
TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro")
|
| 270 |
-
]
|
| 271 |
-
)
|
| 272 |
-
return study
|
| 273 |
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
writer = SummaryWriter(log_dir=log_dir)
|
| 280 |
-
try:
|
| 281 |
-
yield writer
|
| 282 |
-
finally:
|
| 283 |
-
writer.close()
|
| 284 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
writer.add_scalar(
|
| 289 |
-
"Training Loss", loss, epoch * steps_per_epoch + batch_idx
|
| 290 |
-
)
|
| 291 |
-
if config.get("use_wandb", False):
|
| 292 |
-
import wandb
|
| 293 |
-
wandb.log({"Training Loss": loss})
|
| 294 |
|
|
|
|
|
|
|
| 295 |
|
| 296 |
-
|
| 297 |
-
"""Log validation metrics to console, TensorBoard, and optionally W&B."""
|
| 298 |
-
for task_name, metrics in task_metrics.items():
|
| 299 |
-
print(
|
| 300 |
-
f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
|
| 301 |
-
)
|
| 302 |
-
if config.get("use_wandb", False):
|
| 303 |
-
import wandb
|
| 304 |
-
wandb.log(
|
| 305 |
-
{
|
| 306 |
-
f"{task_name} Validation F1 Macro": metrics["f1"],
|
| 307 |
-
f"{task_name} Validation Accuracy": metrics["accuracy"],
|
| 308 |
-
}
|
| 309 |
-
)
|
| 310 |
|
| 311 |
-
writer.add_scalar("Validation Loss", val_loss, epoch)
|
| 312 |
-
for task_name, metrics in task_metrics.items():
|
| 313 |
-
writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epoch)
|
| 314 |
-
writer.add_scalar(
|
| 315 |
-
f"{task_name} - Validation Accuracy", metrics["accuracy"], epoch
|
| 316 |
-
)
|
| 317 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
with open(label_mappings_path, 'rb') as f:
|
| 324 |
-
return pickle.load(f)
|
| 325 |
-
return {task_name: {} for task_name in task_names}
|
| 326 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
row = {"Cell ID": cell_id}
|
| 336 |
-
for task_name in task_names:
|
| 337 |
-
if task_name in task_true_labels and sample_idx < len(task_true_labels[task_name]):
|
| 338 |
-
true_idx = task_true_labels[task_name][sample_idx]
|
| 339 |
-
pred_idx = task_pred_labels[task_name][sample_idx]
|
| 340 |
-
true_label = inverted_mappings.get(task_name, {}).get(true_idx, f"Unknown-{true_idx}")
|
| 341 |
-
pred_label = inverted_mappings.get(task_name, {}).get(pred_idx, f"Unknown-{pred_idx}")
|
| 342 |
-
|
| 343 |
-
row.update({
|
| 344 |
-
f"{task_name}_true_idx": true_idx,
|
| 345 |
-
f"{task_name}_pred_idx": pred_idx,
|
| 346 |
-
f"{task_name}_true_label": true_label,
|
| 347 |
-
f"{task_name}_pred_label": pred_label
|
| 348 |
-
})
|
| 349 |
-
|
| 350 |
-
if task_name in task_pred_probs and sample_idx < len(task_pred_probs[task_name]):
|
| 351 |
-
probs = task_pred_probs[task_name][sample_idx]
|
| 352 |
-
if isinstance(probs, (list, np.ndarray)) or (hasattr(probs, '__iter__') and not isinstance(probs, str)):
|
| 353 |
-
prob_list = list(probs) if not isinstance(probs, list) else probs
|
| 354 |
-
row[f"{task_name}_all_probs"] = ",".join(map(str, prob_list))
|
| 355 |
-
for class_idx, prob in enumerate(prob_list):
|
| 356 |
-
class_label = inverted_mappings.get(task_name, {}).get(class_idx, f"Unknown-{class_idx}")
|
| 357 |
-
row[f"{task_name}_prob_{class_label}"] = prob
|
| 358 |
-
else:
|
| 359 |
-
row[f"{task_name}_all_probs"] = str(probs)
|
| 360 |
-
|
| 361 |
-
return row
|
| 362 |
|
|
|
|
|
|
|
| 363 |
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
task_pred_labels,
|
| 368 |
-
task_pred_probs,
|
| 369 |
-
config,
|
| 370 |
-
trial_number=None,
|
| 371 |
-
):
|
| 372 |
-
"""Save validation predictions to a CSV file with class labels and probabilities."""
|
| 373 |
-
os.makedirs(config["results_dir"], exist_ok=True)
|
| 374 |
-
|
| 375 |
-
if trial_number is not None:
|
| 376 |
-
os.makedirs(os.path.join(config["results_dir"], f"trial_{trial_number}"), exist_ok=True)
|
| 377 |
-
val_preds_file = os.path.join(config["results_dir"], f"trial_{trial_number}/val_preds.csv")
|
| 378 |
-
else:
|
| 379 |
-
val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
|
| 380 |
-
|
| 381 |
-
if not val_cell_indices or not task_true_labels:
|
| 382 |
-
pd.DataFrame().to_csv(val_preds_file, index=False)
|
| 383 |
-
return
|
| 384 |
-
|
| 385 |
-
try:
|
| 386 |
-
label_mappings = load_label_mappings(config["results_dir"], config["task_names"])
|
| 387 |
-
inverted_mappings = {task: {idx: label for label, idx in mapping.items()} for task, mapping in label_mappings.items()}
|
| 388 |
-
val_cell_mapping = config.get("val_cell_mapping", {})
|
| 389 |
-
|
| 390 |
-
# Determine maximum number of samples
|
| 391 |
-
max_samples = max(
|
| 392 |
-
[len(val_cell_indices)] +
|
| 393 |
-
[len(task_true_labels[task]) for task in task_true_labels]
|
| 394 |
-
)
|
| 395 |
-
|
| 396 |
-
rows = [
|
| 397 |
-
create_prediction_row(
|
| 398 |
-
sample_idx, val_cell_indices, task_true_labels, task_pred_labels,
|
| 399 |
-
task_pred_probs, config["task_names"], inverted_mappings, val_cell_mapping
|
| 400 |
-
)
|
| 401 |
-
for sample_idx in range(max_samples)
|
| 402 |
-
]
|
| 403 |
-
|
| 404 |
-
pd.DataFrame(rows).to_csv(val_preds_file, index=False)
|
| 405 |
-
except Exception as e:
|
| 406 |
-
pd.DataFrame([{"Error": str(e)}]).to_csv(val_preds_file, index=False)
|
| 407 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
|
| 409 |
-
|
| 410 |
-
"""
|
| 411 |
-
Setup the distributed training environment.
|
| 412 |
-
|
| 413 |
-
Args:
|
| 414 |
-
rank (int): The rank of the current process
|
| 415 |
-
world_size (int): Total number of processes
|
| 416 |
-
config (dict): Configuration dictionary
|
| 417 |
-
"""
|
| 418 |
-
os.environ['MASTER_ADDR'] = config.get('master_addr', 'localhost')
|
| 419 |
-
os.environ['MASTER_PORT'] = config.get('master_port', '12355')
|
| 420 |
-
|
| 421 |
-
# Initialize the process group
|
| 422 |
-
dist.init_process_group(
|
| 423 |
-
backend='nccl',
|
| 424 |
-
init_method='env://',
|
| 425 |
-
world_size=world_size,
|
| 426 |
-
rank=rank
|
| 427 |
-
)
|
| 428 |
-
|
| 429 |
-
# Set the device for this process
|
| 430 |
-
torch.cuda.set_device(rank)
|
| 431 |
|
| 432 |
|
| 433 |
-
def
|
| 434 |
-
"""Run distributed training across multiple GPUs with fallback to single GPU."""
|
| 435 |
-
world_size = torch.cuda.device_count()
|
| 436 |
-
|
| 437 |
-
if world_size <= 1:
|
| 438 |
-
print("Distributed training requested but only one GPU found. Falling back to single GPU training.")
|
| 439 |
-
config["distributed_training"] = False
|
| 440 |
-
trainer = trainer_class(config)
|
| 441 |
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 442 |
-
trainer.device = device
|
| 443 |
-
train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
|
| 444 |
-
train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
|
| 445 |
-
)
|
| 446 |
-
val_loss, model = trainer.train(
|
| 447 |
-
train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
|
| 448 |
-
)
|
| 449 |
-
model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
|
| 450 |
-
save_model(model, model_save_directory)
|
| 451 |
-
save_hyperparameters(model_save_directory, {
|
| 452 |
-
**get_config_value(config, "manual_hyperparameters", {}),
|
| 453 |
-
"dropout_rate": config["dropout_rate"],
|
| 454 |
-
"use_task_weights": config["use_task_weights"],
|
| 455 |
-
"task_weights": config["task_weights"],
|
| 456 |
-
"max_layers_to_freeze": config["max_layers_to_freeze"],
|
| 457 |
-
"use_attention_pooling": config["use_attention_pooling"],
|
| 458 |
-
})
|
| 459 |
-
|
| 460 |
-
if shared_dict is not None:
|
| 461 |
-
shared_dict['val_loss'] = val_loss
|
| 462 |
-
task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(model, val_loader, device, config)
|
| 463 |
-
shared_dict['task_metrics'] = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
|
| 464 |
-
shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model.state_dict().items()}
|
| 465 |
-
|
| 466 |
-
return val_loss, model
|
| 467 |
-
|
| 468 |
-
print(f"Using distributed training with {world_size} GPUs")
|
| 469 |
-
mp.spawn(
|
| 470 |
-
_distributed_worker,
|
| 471 |
-
args=(world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number, shared_dict),
|
| 472 |
-
nprocs=world_size,
|
| 473 |
-
join=True
|
| 474 |
-
)
|
| 475 |
-
|
| 476 |
-
if trial_number is None and shared_dict is None:
|
| 477 |
-
model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
|
| 478 |
-
model_path = os.path.join(model_save_directory, "pytorch_model.bin")
|
| 479 |
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 480 |
-
model = create_model(config, num_labels_list, device)
|
| 481 |
-
model.load_state_dict(torch.load(model_path))
|
| 482 |
-
return 0.0, model
|
| 483 |
-
|
| 484 |
-
return None
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
def _distributed_worker(rank, world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None):
|
| 488 |
-
"""Worker function for distributed training."""
|
| 489 |
-
setup_distributed_environment(rank, world_size, config)
|
| 490 |
-
config["local_rank"] = rank
|
| 491 |
-
|
| 492 |
-
# Set up distributed samplers
|
| 493 |
-
from torch.utils.data import DistributedSampler
|
| 494 |
-
from .data import get_data_loader
|
| 495 |
-
|
| 496 |
-
train_sampler = DistributedSampler(train_loader.dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
|
| 497 |
-
val_sampler = DistributedSampler(val_loader.dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
|
| 498 |
-
|
| 499 |
-
train_loader = get_data_loader(train_loader.dataset, config["batch_size"], sampler=train_sampler, shuffle=False)
|
| 500 |
-
val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=val_sampler, shuffle=False)
|
| 501 |
-
|
| 502 |
-
if rank == 0:
|
| 503 |
-
print(f"Rank {rank}: Training {len(train_sampler)} samples, Validation {len(val_sampler)} samples")
|
| 504 |
-
print(f"Total samples across {world_size} GPUs: Training {len(train_sampler) * world_size}, Validation {len(val_sampler) * world_size}")
|
| 505 |
-
|
| 506 |
-
# Create and setup trainer
|
| 507 |
-
trainer = trainer_class(config)
|
| 508 |
-
train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
|
| 509 |
-
train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
|
| 510 |
-
)
|
| 511 |
-
|
| 512 |
-
# Train the model
|
| 513 |
-
val_loss, model = trainer.train(
|
| 514 |
-
train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
|
| 515 |
-
)
|
| 516 |
-
|
| 517 |
-
# Save model only from the main process
|
| 518 |
-
if rank == 0:
|
| 519 |
-
model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
|
| 520 |
-
save_model(model, model_save_directory)
|
| 521 |
-
|
| 522 |
-
save_hyperparameters(model_save_directory, {
|
| 523 |
-
**get_config_value(config, "manual_hyperparameters", {}),
|
| 524 |
-
"dropout_rate": config["dropout_rate"],
|
| 525 |
-
"use_task_weights": config["use_task_weights"],
|
| 526 |
-
"task_weights": config["task_weights"],
|
| 527 |
-
"max_layers_to_freeze": config["max_layers_to_freeze"],
|
| 528 |
-
"use_attention_pooling": config["use_attention_pooling"],
|
| 529 |
-
})
|
| 530 |
-
|
| 531 |
-
# For Optuna trials, store results in shared dictionary
|
| 532 |
-
if shared_dict is not None:
|
| 533 |
-
shared_dict['val_loss'] = val_loss
|
| 534 |
-
|
| 535 |
-
# Run validation on full dataset from rank 0 for consistent metrics
|
| 536 |
-
full_val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=None, shuffle=False)
|
| 537 |
-
|
| 538 |
-
# Get validation predictions using our utility function
|
| 539 |
-
task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(
|
| 540 |
-
model, full_val_loader, trainer.device, config
|
| 541 |
-
)
|
| 542 |
-
|
| 543 |
-
# Calculate metrics
|
| 544 |
-
task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
|
| 545 |
-
shared_dict['task_metrics'] = task_metrics
|
| 546 |
-
|
| 547 |
-
# Store model state dict
|
| 548 |
-
if isinstance(model, DDP):
|
| 549 |
-
model_state_dict = model.module.state_dict()
|
| 550 |
-
else:
|
| 551 |
-
model_state_dict = model.state_dict()
|
| 552 |
-
|
| 553 |
-
shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model_state_dict.items()}
|
| 554 |
-
|
| 555 |
-
# Clean up distributed environment
|
| 556 |
-
dist.destroy_process_group()
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
def save_model_without_heads(model_directory):
|
| 560 |
"""
|
| 561 |
-
|
| 562 |
-
|
| 563 |
Args:
|
| 564 |
-
|
|
|
|
|
|
|
| 565 |
"""
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
if not os.path.exists(model_path) or not os.path.exists(config_path):
|
| 574 |
-
raise FileNotFoundError(f"Model files not found in {model_directory}")
|
| 575 |
-
|
| 576 |
-
# Load the configuration
|
| 577 |
-
config = BertConfig.from_json_file(config_path)
|
| 578 |
-
|
| 579 |
-
# Load the model state dict
|
| 580 |
-
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
| 581 |
-
|
| 582 |
-
# Create a new model without heads
|
| 583 |
-
base_model = BertModel(config)
|
| 584 |
-
|
| 585 |
-
# Filter out the classification head parameters
|
| 586 |
-
base_model_state_dict = {}
|
| 587 |
-
for key, value in state_dict.items():
|
| 588 |
-
# Only keep parameters that belong to the base model (not classification heads)
|
| 589 |
-
if not key.startswith('classification_heads') and not key.startswith('attention_pool'):
|
| 590 |
-
base_model_state_dict[key] = value
|
| 591 |
-
|
| 592 |
-
# Load the filtered state dict into the base model
|
| 593 |
-
base_model.load_state_dict(base_model_state_dict, strict=False)
|
| 594 |
-
|
| 595 |
-
# Save the model without heads
|
| 596 |
-
output_dir = os.path.join(model_directory, "model_without_heads")
|
| 597 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 598 |
-
|
| 599 |
-
# Save the model weights
|
| 600 |
-
torch.save(base_model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
|
| 601 |
-
|
| 602 |
-
# Save the configuration
|
| 603 |
-
base_model.config.to_json_file(os.path.join(output_dir, "config.json"))
|
| 604 |
-
|
| 605 |
-
print(f"Model without classification heads saved to {output_dir}")
|
| 606 |
-
return output_dir
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
def get_config_value(config: Dict, key: str, default=None):
|
| 610 |
-
|
| 611 |
-
return config.get(key, default)
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
def collect_validation_predictions(model, val_loader, device, config) -> tuple:
|
| 615 |
-
task_true_labels = {}
|
| 616 |
-
task_pred_labels = {}
|
| 617 |
-
task_pred_probs = {}
|
| 618 |
-
|
| 619 |
-
with torch.no_grad():
|
| 620 |
-
for batch in val_loader:
|
| 621 |
-
input_ids = batch["input_ids"].to(device)
|
| 622 |
-
attention_mask = batch["attention_mask"].to(device)
|
| 623 |
-
labels = [batch["labels"][task_name].to(device) for task_name in config["task_names"]]
|
| 624 |
-
_, logits, _ = model(input_ids, attention_mask, labels)
|
| 625 |
-
|
| 626 |
-
for sample_idx in range(len(batch["input_ids"])):
|
| 627 |
-
for i, task_name in enumerate(config["task_names"]):
|
| 628 |
-
if task_name not in task_true_labels:
|
| 629 |
-
task_true_labels[task_name] = []
|
| 630 |
-
task_pred_labels[task_name] = []
|
| 631 |
-
task_pred_probs[task_name] = []
|
| 632 |
-
|
| 633 |
-
true_label = batch["labels"][task_name][sample_idx].item()
|
| 634 |
-
pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
|
| 635 |
-
pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
|
| 636 |
-
|
| 637 |
-
task_true_labels[task_name].append(true_label)
|
| 638 |
-
task_pred_labels[task_name].append(pred_label)
|
| 639 |
-
task_pred_probs[task_name].append(pred_prob)
|
| 640 |
-
|
| 641 |
-
return task_true_labels, task_pred_labels, task_pred_probs
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import shutil
|
| 3 |
+
|
|
|
|
|
|
|
|
|
|
| 4 |
from sklearn.metrics import accuracy_score, f1_score
|
| 5 |
from sklearn.preprocessing import LabelEncoder
|
| 6 |
+
from transformers import AutoConfig, BertConfig, BertModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
from .imports import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def save_model(model, model_save_directory):
|
| 12 |
+
if not os.path.exists(model_save_directory):
|
| 13 |
+
os.makedirs(model_save_directory)
|
| 14 |
+
|
| 15 |
+
# Get the state dict
|
| 16 |
+
if isinstance(model, nn.DataParallel):
|
| 17 |
+
model_state_dict = (
|
| 18 |
+
model.module.state_dict()
|
| 19 |
+
) # Use model.module to access the underlying model
|
| 20 |
else:
|
| 21 |
+
model_state_dict = model.state_dict()
|
| 22 |
+
|
| 23 |
+
# Remove the "module." prefix from the keys if present
|
| 24 |
+
model_state_dict = {
|
| 25 |
+
k.replace("module.", ""): v for k, v in model_state_dict.items()
|
| 26 |
+
}
|
| 27 |
|
| 28 |
model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
|
| 29 |
torch.save(model_state_dict, model_save_path)
|
| 30 |
|
| 31 |
# Save the model configuration
|
| 32 |
+
if isinstance(model, nn.DataParallel):
|
| 33 |
+
model.module.config.to_json_file(
|
| 34 |
+
os.path.join(model_save_directory, "config.json")
|
| 35 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
else:
|
| 37 |
+
model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
print(f"Model and configuration saved to {model_save_directory}")
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
|
| 43 |
+
task_metrics = {}
|
| 44 |
+
for task_name in task_true_labels.keys():
|
| 45 |
+
true_labels = task_true_labels[task_name]
|
| 46 |
+
pred_labels = task_pred_labels[task_name]
|
| 47 |
+
f1 = f1_score(true_labels, pred_labels, average="macro")
|
| 48 |
+
accuracy = accuracy_score(true_labels, pred_labels)
|
| 49 |
+
task_metrics[task_name] = {"f1": f1, "accuracy": accuracy}
|
| 50 |
+
return task_metrics
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
+
def calculate_combined_f1(combined_labels, combined_preds):
|
| 54 |
+
# Initialize the LabelEncoder
|
| 55 |
+
le = LabelEncoder()
|
| 56 |
|
| 57 |
+
# Fit and transform combined labels and predictions to numerical values
|
| 58 |
+
le.fit(combined_labels + combined_preds)
|
| 59 |
+
encoded_true_labels = le.transform(combined_labels)
|
| 60 |
+
encoded_pred_labels = le.transform(combined_preds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
# Print out the mapping for sanity check
|
| 63 |
+
print("\nLabel Encoder Mapping:")
|
| 64 |
+
for index, class_label in enumerate(le.classes_):
|
| 65 |
+
print(f"'{class_label}': {index}")
|
| 66 |
|
| 67 |
+
# Calculate accuracy
|
| 68 |
+
accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
# Calculate F1 Macro score
|
| 71 |
+
f1 = f1_score(encoded_true_labels, encoded_pred_labels, average="macro")
|
| 72 |
|
| 73 |
+
return f1, accuracy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
+
# def save_model_without_heads(original_model_save_directory):
|
| 77 |
+
# # Create a new directory for the model without heads
|
| 78 |
+
# new_model_save_directory = original_model_save_directory + "_No_Heads"
|
| 79 |
+
# if not os.path.exists(new_model_save_directory):
|
| 80 |
+
# os.makedirs(new_model_save_directory)
|
| 81 |
|
| 82 |
+
# # Load the model state dictionary
|
| 83 |
+
# model_state_dict = torch.load(
|
| 84 |
+
# os.path.join(original_model_save_directory, "pytorch_model.bin")
|
| 85 |
+
# )
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
+
# # Initialize a new BERT model without the classification heads
|
| 88 |
+
# config = BertConfig.from_pretrained(
|
| 89 |
+
# os.path.join(original_model_save_directory, "config.json")
|
| 90 |
+
# )
|
| 91 |
+
# model_without_heads = BertModel(config)
|
| 92 |
|
| 93 |
+
# # Filter the state dict to exclude classification heads
|
| 94 |
+
# model_without_heads_state_dict = {
|
| 95 |
+
# k: v
|
| 96 |
+
# for k, v in model_state_dict.items()
|
| 97 |
+
# if not k.startswith("classification_heads")
|
| 98 |
+
# }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
+
# # Load the filtered state dict into the model
|
| 101 |
+
# model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
|
| 102 |
|
| 103 |
+
# # Save the model without heads
|
| 104 |
+
# model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
|
| 105 |
+
# torch.save(model_without_heads.state_dict(), model_save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
+
# # Copy the configuration file
|
| 108 |
+
# shutil.copy(
|
| 109 |
+
# os.path.join(original_model_save_directory, "config.json"),
|
| 110 |
+
# new_model_save_directory,
|
| 111 |
+
# )
|
| 112 |
|
| 113 |
+
# print(f"Model without classification heads saved to {new_model_save_directory}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
+
def get_layer_freeze_range(pretrained_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
"""
|
| 118 |
+
Dynamically determines the number of layers to freeze based on the model depth from its configuration.
|
|
|
|
| 119 |
Args:
|
| 120 |
+
pretrained_path (str): Path to the pretrained model directory or model identifier.
|
| 121 |
+
Returns:
|
| 122 |
+
dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze.
|
| 123 |
"""
|
| 124 |
+
if pretrained_path:
|
| 125 |
+
config = AutoConfig.from_pretrained(pretrained_path)
|
| 126 |
+
total_layers = config.num_hidden_layers
|
| 127 |
+
return {"min": 0, "max": total_layers - 1}
|
| 128 |
+
else:
|
| 129 |
+
return {"min": 0, "max": 0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
geneformer/mtl_classifier.py
CHANGED
|
@@ -29,7 +29,7 @@ Geneformer multi-task cell classifier.
|
|
| 29 |
import logging
|
| 30 |
import os
|
| 31 |
|
| 32 |
-
from .mtl import eval_utils,
|
| 33 |
|
| 34 |
logger = logging.getLogger(__name__)
|
| 35 |
|
|
@@ -49,9 +49,7 @@ class MTLClassifier:
|
|
| 49 |
"max_layers_to_freeze": {None, dict},
|
| 50 |
"epochs": {None, int},
|
| 51 |
"tensorboard_log_dir": {None, str},
|
| 52 |
-
"
|
| 53 |
-
"master_addr": {None, str},
|
| 54 |
-
"master_port": {None, str},
|
| 55 |
"use_attention_pooling": {None, bool},
|
| 56 |
"use_task_weights": {None, bool},
|
| 57 |
"hyperparameters": {None, dict},
|
|
@@ -63,7 +61,6 @@ class MTLClassifier:
|
|
| 63 |
"max_grad_norm": {None, int, float},
|
| 64 |
"seed": {None, int},
|
| 65 |
"trials_result_path": {None, str},
|
| 66 |
-
"gradient_accumulation_steps": {None, int},
|
| 67 |
}
|
| 68 |
|
| 69 |
def __init__(
|
|
@@ -82,9 +79,7 @@ class MTLClassifier:
|
|
| 82 |
max_layers_to_freeze=None,
|
| 83 |
epochs=1,
|
| 84 |
tensorboard_log_dir="/results/tblogdir",
|
| 85 |
-
|
| 86 |
-
master_addr="localhost",
|
| 87 |
-
master_port="12355",
|
| 88 |
use_attention_pooling=True,
|
| 89 |
use_task_weights=True,
|
| 90 |
hyperparameters=None, # Default is None
|
|
@@ -94,7 +89,6 @@ class MTLClassifier:
|
|
| 94 |
wandb_project=None,
|
| 95 |
gradient_clipping=False,
|
| 96 |
max_grad_norm=None,
|
| 97 |
-
gradient_accumulation_steps=1, # Add this line with default value 1
|
| 98 |
seed=42, # Default seed value
|
| 99 |
):
|
| 100 |
"""
|
|
@@ -123,12 +117,8 @@ class MTLClassifier:
|
|
| 123 |
| Path to directory to save results
|
| 124 |
tensorboard_log_dir : None, str
|
| 125 |
| Path to directory for Tensorboard logging results
|
| 126 |
-
|
| 127 |
-
| Whether to use
|
| 128 |
-
master_addr : None, str
|
| 129 |
-
| Master address for distributed training (default: localhost)
|
| 130 |
-
master_port : None, str
|
| 131 |
-
| Master port for distributed training (default: 12355)
|
| 132 |
use_attention_pooling : None, bool
|
| 133 |
| Whether to use attention pooling
|
| 134 |
use_task_weights : None, bool
|
|
@@ -160,8 +150,6 @@ class MTLClassifier:
|
|
| 160 |
| Whether to use gradient clipping
|
| 161 |
max_grad_norm : None, int, float
|
| 162 |
| Maximum norm for gradient clipping
|
| 163 |
-
gradient_accumulation_steps : None, int
|
| 164 |
-
| Number of steps to accumulate gradients before performing a backward/update pass
|
| 165 |
seed : None, int
|
| 166 |
| Random seed
|
| 167 |
"""
|
|
@@ -177,7 +165,6 @@ class MTLClassifier:
|
|
| 177 |
self.batch_size = batch_size
|
| 178 |
self.n_trials = n_trials
|
| 179 |
self.study_name = study_name
|
| 180 |
-
self.gradient_accumulation_steps = gradient_accumulation_steps
|
| 181 |
|
| 182 |
if max_layers_to_freeze is None:
|
| 183 |
# Dynamically determine the range of layers to freeze
|
|
@@ -188,9 +175,7 @@ class MTLClassifier:
|
|
| 188 |
|
| 189 |
self.epochs = epochs
|
| 190 |
self.tensorboard_log_dir = tensorboard_log_dir
|
| 191 |
-
self.
|
| 192 |
-
self.master_addr = master_addr
|
| 193 |
-
self.master_port = master_port
|
| 194 |
self.use_attention_pooling = use_attention_pooling
|
| 195 |
self.use_task_weights = use_task_weights
|
| 196 |
self.hyperparameters = (
|
|
@@ -308,7 +293,7 @@ class MTLClassifier:
|
|
| 308 |
self.config["manual_hyperparameters"] = self.manual_hyperparameters
|
| 309 |
self.config["use_manual_hyperparameters"] = True
|
| 310 |
|
| 311 |
-
|
| 312 |
|
| 313 |
def validate_additional_options(self, req_var_dict):
|
| 314 |
missing_variable = False
|
|
@@ -345,7 +330,7 @@ class MTLClassifier:
|
|
| 345 |
req_var_dict = dict(zip(required_variable_names, required_variables))
|
| 346 |
self.validate_additional_options(req_var_dict)
|
| 347 |
|
| 348 |
-
|
| 349 |
|
| 350 |
def load_and_evaluate_test_model(
|
| 351 |
self,
|
|
|
|
| 29 |
import logging
|
| 30 |
import os
|
| 31 |
|
| 32 |
+
from .mtl import eval_utils, train_utils, utils
|
| 33 |
|
| 34 |
logger = logging.getLogger(__name__)
|
| 35 |
|
|
|
|
| 49 |
"max_layers_to_freeze": {None, dict},
|
| 50 |
"epochs": {None, int},
|
| 51 |
"tensorboard_log_dir": {None, str},
|
| 52 |
+
"use_data_parallel": {None, bool},
|
|
|
|
|
|
|
| 53 |
"use_attention_pooling": {None, bool},
|
| 54 |
"use_task_weights": {None, bool},
|
| 55 |
"hyperparameters": {None, dict},
|
|
|
|
| 61 |
"max_grad_norm": {None, int, float},
|
| 62 |
"seed": {None, int},
|
| 63 |
"trials_result_path": {None, str},
|
|
|
|
| 64 |
}
|
| 65 |
|
| 66 |
def __init__(
|
|
|
|
| 79 |
max_layers_to_freeze=None,
|
| 80 |
epochs=1,
|
| 81 |
tensorboard_log_dir="/results/tblogdir",
|
| 82 |
+
use_data_parallel=False,
|
|
|
|
|
|
|
| 83 |
use_attention_pooling=True,
|
| 84 |
use_task_weights=True,
|
| 85 |
hyperparameters=None, # Default is None
|
|
|
|
| 89 |
wandb_project=None,
|
| 90 |
gradient_clipping=False,
|
| 91 |
max_grad_norm=None,
|
|
|
|
| 92 |
seed=42, # Default seed value
|
| 93 |
):
|
| 94 |
"""
|
|
|
|
| 117 |
| Path to directory to save results
|
| 118 |
tensorboard_log_dir : None, str
|
| 119 |
| Path to directory for Tensorboard logging results
|
| 120 |
+
use_data_parallel : None, bool
|
| 121 |
+
| Whether to use data parallelization
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
use_attention_pooling : None, bool
|
| 123 |
| Whether to use attention pooling
|
| 124 |
use_task_weights : None, bool
|
|
|
|
| 150 |
| Whether to use gradient clipping
|
| 151 |
max_grad_norm : None, int, float
|
| 152 |
| Maximum norm for gradient clipping
|
|
|
|
|
|
|
| 153 |
seed : None, int
|
| 154 |
| Random seed
|
| 155 |
"""
|
|
|
|
| 165 |
self.batch_size = batch_size
|
| 166 |
self.n_trials = n_trials
|
| 167 |
self.study_name = study_name
|
|
|
|
| 168 |
|
| 169 |
if max_layers_to_freeze is None:
|
| 170 |
# Dynamically determine the range of layers to freeze
|
|
|
|
| 175 |
|
| 176 |
self.epochs = epochs
|
| 177 |
self.tensorboard_log_dir = tensorboard_log_dir
|
| 178 |
+
self.use_data_parallel = use_data_parallel
|
|
|
|
|
|
|
| 179 |
self.use_attention_pooling = use_attention_pooling
|
| 180 |
self.use_task_weights = use_task_weights
|
| 181 |
self.hyperparameters = (
|
|
|
|
| 293 |
self.config["manual_hyperparameters"] = self.manual_hyperparameters
|
| 294 |
self.config["use_manual_hyperparameters"] = True
|
| 295 |
|
| 296 |
+
train_utils.run_manual_tuning(self.config)
|
| 297 |
|
| 298 |
def validate_additional_options(self, req_var_dict):
|
| 299 |
missing_variable = False
|
|
|
|
| 330 |
req_var_dict = dict(zip(required_variable_names, required_variables))
|
| 331 |
self.validate_additional_options(req_var_dict)
|
| 332 |
|
| 333 |
+
train_utils.run_optuna_study(self.config)
|
| 334 |
|
| 335 |
def load_and_evaluate_test_model(
|
| 336 |
self,
|
geneformer/perturber_utils.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import itertools as it
|
| 2 |
import logging
|
| 3 |
-
import os
|
| 4 |
import pickle
|
| 5 |
from collections import defaultdict
|
| 6 |
from pathlib import Path
|
|
@@ -9,7 +8,6 @@ from typing import List
|
|
| 9 |
import numpy as np
|
| 10 |
import pandas as pd
|
| 11 |
import torch
|
| 12 |
-
import datasets
|
| 13 |
from datasets import Dataset, load_from_disk
|
| 14 |
from peft import LoraConfig, get_peft_model
|
| 15 |
from transformers import (
|
|
@@ -19,6 +17,11 @@ from transformers import (
|
|
| 19 |
BitsAndBytesConfig,
|
| 20 |
)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
|
|
@@ -110,25 +113,15 @@ def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
|
|
| 110 |
|
| 111 |
# load model to GPU
|
| 112 |
def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
| 113 |
-
if model_type == "
|
| 114 |
-
inference_only = True
|
| 115 |
-
model_type = "Pretrained"
|
| 116 |
-
quantize = True
|
| 117 |
-
elif model_type == "MTLCellClassifier-Quantized":
|
| 118 |
-
inference_only = True
|
| 119 |
model_type = "MTLCellClassifier"
|
| 120 |
quantize = True
|
| 121 |
-
else:
|
| 122 |
-
inference_only = False
|
| 123 |
|
| 124 |
output_hidden_states = (mode == "eval")
|
| 125 |
|
| 126 |
# Quantization logic
|
| 127 |
-
if
|
| 128 |
-
|
| 129 |
-
peft_config = quantize.get("peft_config", None)
|
| 130 |
-
elif quantize:
|
| 131 |
-
if inference_only:
|
| 132 |
quantize_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 133 |
peft_config = None
|
| 134 |
else:
|
|
@@ -138,22 +131,13 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
| 138 |
bnb_4bit_quant_type="nf4",
|
| 139 |
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 140 |
)
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
)
|
| 149 |
-
except ValueError as e:
|
| 150 |
-
peft_config = LoraConfig(
|
| 151 |
-
lora_alpha=128,
|
| 152 |
-
lora_dropout=0.1,
|
| 153 |
-
r=64,
|
| 154 |
-
bias="none",
|
| 155 |
-
task_type="TOKEN_CLS",
|
| 156 |
-
)
|
| 157 |
else:
|
| 158 |
quantize_config = None
|
| 159 |
peft_config = None
|
|
@@ -190,34 +174,17 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
| 190 |
model.eval()
|
| 191 |
|
| 192 |
# Handle device placement and PEFT
|
| 193 |
-
adapter_config_path = os.path.join(model_directory, "adapter_config.json")
|
| 194 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 195 |
if not quantize:
|
| 196 |
# Only move non-quantized models
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
# If adapter files exist, load them into the model using PEFT's from_pretrained
|
| 200 |
-
model = PeftModel.from_pretrained(model, model_directory)
|
| 201 |
-
move_to_cuda(model)
|
| 202 |
-
print("loading lora weights")
|
| 203 |
elif peft_config:
|
| 204 |
-
# Apply PEFT for quantized models (except MTLCellClassifier
|
| 205 |
model.enable_input_require_grads()
|
| 206 |
model = get_peft_model(model, peft_config)
|
| 207 |
-
move_to_cuda(model)
|
| 208 |
|
| 209 |
return model
|
| 210 |
|
| 211 |
-
|
| 212 |
-
def move_to_cuda(model):
|
| 213 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 214 |
-
# get what device model is currently on
|
| 215 |
-
model_device = next(model.parameters()).device
|
| 216 |
-
# Check if the model is on the CPU and move to cuda if necessary
|
| 217 |
-
if (model_device.type == "cpu") and (device.type == "cuda"):
|
| 218 |
-
model.to(device)
|
| 219 |
-
|
| 220 |
-
|
| 221 |
def quant_layers(model):
|
| 222 |
layer_nums = []
|
| 223 |
for name, parameter in model.named_parameters():
|
|
@@ -431,11 +398,6 @@ def remove_perturbed_indices_set(
|
|
| 431 |
def make_perturbation_batch(
|
| 432 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 433 |
) -> tuple[Dataset, List[int]]:
|
| 434 |
-
|
| 435 |
-
# For datasets>=4.0.0, convert to dict to avoid format issues
|
| 436 |
-
if int(datasets.__version__.split(".")[0]) >= 4:
|
| 437 |
-
example_cell = example_cell[:]
|
| 438 |
-
|
| 439 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 440 |
if perturb_type in ["overexpress", "activate"]:
|
| 441 |
range_start = 1
|
|
@@ -508,11 +470,6 @@ def make_perturbation_batch(
|
|
| 508 |
def make_perturbation_batch_special(
|
| 509 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 510 |
) -> tuple[Dataset, List[int]]:
|
| 511 |
-
|
| 512 |
-
# For datasets>=4.0.0, convert to dict to avoid format issues
|
| 513 |
-
if int(datasets.__version__.split(".")[0]) >= 4:
|
| 514 |
-
example_cell = example_cell[:]
|
| 515 |
-
|
| 516 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 517 |
if perturb_type in ["overexpress", "activate"]:
|
| 518 |
range_start = 1
|
|
@@ -913,4 +870,50 @@ def validate_cell_states_to_model(cell_states_to_model):
|
|
| 913 |
"'goal_state': 'nf', "
|
| 914 |
"'alt_states': ['hcm', 'other1', 'other2']}"
|
| 915 |
)
|
| 916 |
-
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import itertools as it
|
| 2 |
import logging
|
|
|
|
| 3 |
import pickle
|
| 4 |
from collections import defaultdict
|
| 5 |
from pathlib import Path
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
| 10 |
import torch
|
|
|
|
| 11 |
from datasets import Dataset, load_from_disk
|
| 12 |
from peft import LoraConfig, get_peft_model
|
| 13 |
from transformers import (
|
|
|
|
| 17 |
BitsAndBytesConfig,
|
| 18 |
)
|
| 19 |
|
| 20 |
+
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
|
| 21 |
+
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
|
| 22 |
+
ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
logger = logging.getLogger(__name__)
|
| 26 |
|
| 27 |
|
|
|
|
| 113 |
|
| 114 |
# load model to GPU
|
| 115 |
def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
| 116 |
+
if model_type == "MTLCellClassifier-Quantized":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
model_type = "MTLCellClassifier"
|
| 118 |
quantize = True
|
|
|
|
|
|
|
| 119 |
|
| 120 |
output_hidden_states = (mode == "eval")
|
| 121 |
|
| 122 |
# Quantization logic
|
| 123 |
+
if quantize:
|
| 124 |
+
if model_type == "MTLCellClassifier":
|
|
|
|
|
|
|
|
|
|
| 125 |
quantize_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 126 |
peft_config = None
|
| 127 |
else:
|
|
|
|
| 131 |
bnb_4bit_quant_type="nf4",
|
| 132 |
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 133 |
)
|
| 134 |
+
peft_config = LoraConfig(
|
| 135 |
+
lora_alpha=128,
|
| 136 |
+
lora_dropout=0.1,
|
| 137 |
+
r=64,
|
| 138 |
+
bias="none",
|
| 139 |
+
task_type="TokenClassification",
|
| 140 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
else:
|
| 142 |
quantize_config = None
|
| 143 |
peft_config = None
|
|
|
|
| 174 |
model.eval()
|
| 175 |
|
| 176 |
# Handle device placement and PEFT
|
|
|
|
|
|
|
| 177 |
if not quantize:
|
| 178 |
# Only move non-quantized models
|
| 179 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 180 |
+
model = model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
elif peft_config:
|
| 182 |
+
# Apply PEFT for quantized models (except MTLCellClassifier)
|
| 183 |
model.enable_input_require_grads()
|
| 184 |
model = get_peft_model(model, peft_config)
|
|
|
|
| 185 |
|
| 186 |
return model
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
def quant_layers(model):
|
| 189 |
layer_nums = []
|
| 190 |
for name, parameter in model.named_parameters():
|
|
|
|
| 398 |
def make_perturbation_batch(
|
| 399 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 400 |
) -> tuple[Dataset, List[int]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 402 |
if perturb_type in ["overexpress", "activate"]:
|
| 403 |
range_start = 1
|
|
|
|
| 470 |
def make_perturbation_batch_special(
|
| 471 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 472 |
) -> tuple[Dataset, List[int]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 474 |
if perturb_type in ["overexpress", "activate"]:
|
| 475 |
range_start = 1
|
|
|
|
| 870 |
"'goal_state': 'nf', "
|
| 871 |
"'alt_states': ['hcm', 'other1', 'other2']}"
|
| 872 |
)
|
| 873 |
+
raise
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
class GeneIdHandler:
|
| 877 |
+
def __init__(self, raise_errors=False):
|
| 878 |
+
def invert_dict(dict_obj):
|
| 879 |
+
return {v: k for k, v in dict_obj.items()}
|
| 880 |
+
|
| 881 |
+
self.raise_errors = raise_errors
|
| 882 |
+
|
| 883 |
+
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
|
| 884 |
+
self.gene_token_dict = pickle.load(f)
|
| 885 |
+
self.token_gene_dict = invert_dict(self.gene_token_dict)
|
| 886 |
+
|
| 887 |
+
with open(ENSEMBL_DICTIONARY_FILE, "rb") as f:
|
| 888 |
+
self.id_gene_dict = pickle.load(f)
|
| 889 |
+
self.gene_id_dict = invert_dict(self.id_gene_dict)
|
| 890 |
+
|
| 891 |
+
def ens_to_token(self, ens_id):
|
| 892 |
+
if not self.raise_errors:
|
| 893 |
+
return self.gene_token_dict.get(ens_id, ens_id)
|
| 894 |
+
else:
|
| 895 |
+
return self.gene_token_dict[ens_id]
|
| 896 |
+
|
| 897 |
+
def token_to_ens(self, token):
|
| 898 |
+
if not self.raise_errors:
|
| 899 |
+
return self.token_gene_dict.get(token, token)
|
| 900 |
+
else:
|
| 901 |
+
return self.token_gene_dict[token]
|
| 902 |
+
|
| 903 |
+
def ens_to_symbol(self, ens_id):
|
| 904 |
+
if not self.raise_errors:
|
| 905 |
+
return self.gene_id_dict.get(ens_id, ens_id)
|
| 906 |
+
else:
|
| 907 |
+
return self.gene_id_dict[ens_id]
|
| 908 |
+
|
| 909 |
+
def symbol_to_ens(self, symbol):
|
| 910 |
+
if not self.raise_errors:
|
| 911 |
+
return self.id_gene_dict.get(symbol, symbol)
|
| 912 |
+
else:
|
| 913 |
+
return self.id_gene_dict[symbol]
|
| 914 |
+
|
| 915 |
+
def token_to_symbol(self, token):
|
| 916 |
+
return self.ens_to_symbol(self.token_to_ens(token))
|
| 917 |
+
|
| 918 |
+
def symbol_to_token(self, symbol):
|
| 919 |
+
return self.ens_to_token(self.symbol_to_ens(symbol))
|
geneformer/pretrainer.py
CHANGED
|
@@ -8,12 +8,13 @@ import math
|
|
| 8 |
import pickle
|
| 9 |
import warnings
|
| 10 |
from enum import Enum
|
| 11 |
-
from typing import Dict, List, Optional, Union
|
| 12 |
|
| 13 |
import numpy as np
|
| 14 |
import torch
|
| 15 |
from datasets import Dataset
|
| 16 |
from packaging import version
|
|
|
|
| 17 |
from torch.utils.data.sampler import RandomSampler
|
| 18 |
from transformers import (
|
| 19 |
BatchEncoding,
|
|
@@ -23,8 +24,11 @@ from transformers import (
|
|
| 23 |
)
|
| 24 |
from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
|
| 25 |
from transformers.trainer_pt_utils import (
|
|
|
|
|
|
|
| 26 |
LengthGroupedSampler,
|
| 27 |
)
|
|
|
|
| 28 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
| 29 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
| 30 |
|
|
@@ -603,7 +607,7 @@ class GeneformerPretrainer(Trainer):
|
|
| 603 |
)
|
| 604 |
super().__init__(*args, **kwargs)
|
| 605 |
|
| 606 |
-
#
|
| 607 |
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
| 608 |
if not isinstance(self.train_dataset, collections.abc.Sized):
|
| 609 |
return None
|
|
@@ -626,15 +630,181 @@ class GeneformerPretrainer(Trainer):
|
|
| 626 |
if self.tokenizer is not None
|
| 627 |
else None
|
| 628 |
)
|
| 629 |
-
|
|
|
|
| 630 |
dataset=self.train_dataset,
|
| 631 |
batch_size=self.args.train_batch_size,
|
| 632 |
lengths=lengths,
|
| 633 |
model_input_name=model_input_name,
|
| 634 |
generator=generator,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 636 |
|
|
|
|
|
|
|
|
|
|
| 637 |
else:
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
import pickle
|
| 9 |
import warnings
|
| 10 |
from enum import Enum
|
| 11 |
+
from typing import Dict, Iterator, List, Optional, Union
|
| 12 |
|
| 13 |
import numpy as np
|
| 14 |
import torch
|
| 15 |
from datasets import Dataset
|
| 16 |
from packaging import version
|
| 17 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 18 |
from torch.utils.data.sampler import RandomSampler
|
| 19 |
from transformers import (
|
| 20 |
BatchEncoding,
|
|
|
|
| 24 |
)
|
| 25 |
from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
|
| 26 |
from transformers.trainer_pt_utils import (
|
| 27 |
+
DistributedLengthGroupedSampler,
|
| 28 |
+
DistributedSamplerWithLoop,
|
| 29 |
LengthGroupedSampler,
|
| 30 |
)
|
| 31 |
+
from transformers.training_args import ParallelMode
|
| 32 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
| 33 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
| 34 |
|
|
|
|
| 607 |
)
|
| 608 |
super().__init__(*args, **kwargs)
|
| 609 |
|
| 610 |
+
# modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
|
| 611 |
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
| 612 |
if not isinstance(self.train_dataset, collections.abc.Sized):
|
| 613 |
return None
|
|
|
|
| 630 |
if self.tokenizer is not None
|
| 631 |
else None
|
| 632 |
)
|
| 633 |
+
if self.args.world_size <= 1:
|
| 634 |
+
return LengthGroupedSampler(
|
| 635 |
dataset=self.train_dataset,
|
| 636 |
batch_size=self.args.train_batch_size,
|
| 637 |
lengths=lengths,
|
| 638 |
model_input_name=model_input_name,
|
| 639 |
generator=generator,
|
| 640 |
+
)
|
| 641 |
+
else:
|
| 642 |
+
return CustomDistributedLengthGroupedSampler(
|
| 643 |
+
dataset=self.train_dataset,
|
| 644 |
+
batch_size=self.args.train_batch_size,
|
| 645 |
+
num_replicas=self.args.world_size,
|
| 646 |
+
rank=self.args.process_index,
|
| 647 |
+
lengths=lengths,
|
| 648 |
+
model_input_name=model_input_name,
|
| 649 |
+
seed=self.args.seed,
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
else:
|
| 653 |
+
if self.args.world_size <= 1:
|
| 654 |
+
if _is_torch_generator_available:
|
| 655 |
+
return RandomSampler(self.train_dataset, generator=generator)
|
| 656 |
+
return RandomSampler(self.train_dataset)
|
| 657 |
+
elif (
|
| 658 |
+
self.args.parallel_mode
|
| 659 |
+
in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
|
| 660 |
+
and not self.args.dataloader_drop_last
|
| 661 |
+
):
|
| 662 |
+
# Use a loop for TPUs when drop_last is False to have all batches have the same size.
|
| 663 |
+
return DistributedSamplerWithLoop(
|
| 664 |
+
self.train_dataset,
|
| 665 |
+
batch_size=self.args.per_device_train_batch_size,
|
| 666 |
+
num_replicas=self.args.world_size,
|
| 667 |
+
rank=self.args.process_index,
|
| 668 |
+
seed=self.args.seed,
|
| 669 |
+
)
|
| 670 |
+
else:
|
| 671 |
+
return DistributedSampler(
|
| 672 |
+
self.train_dataset,
|
| 673 |
+
num_replicas=self.args.world_size,
|
| 674 |
+
rank=self.args.process_index,
|
| 675 |
+
seed=self.args.seed,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
    r"""
    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
    length while keeping a bit of randomness.

    Length grouping is delegated to the module-level ``get_length_grouped_indices`` helper, so each
    ``batch_size`` slice of the yielded indices covers examples of similar length.
    """

    # Copied and adapted from PyTorch DistributedSampler.
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        drop_last: bool = False,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
    ):
        """
        Args:
            dataset: Dataset to sample from.
            batch_size: Batch size; consecutive slices of this size share similar lengths.
            num_replicas: Number of distributed processes; defaults to ``dist.get_world_size()``.
            rank: Rank of the current process; defaults to ``dist.get_rank()``.
            seed: Base seed combined with the epoch for deterministic shuffling.
            drop_last: If True, drop the tail of the data so every replica receives the same
                number of samples; otherwise pad by repeating leading indices.
            lengths: Precomputed example lengths; inferred from ``model_input_name`` if omitted.
            model_input_name: Feature key used to infer lengths (defaults to ``"input_ids"``).

        Raises:
            RuntimeError: If ``num_replicas``/``rank`` must be inferred but torch.distributed
                is unavailable.
            ValueError: If lengths must be inferred and dataset items are not dict-like with
                the ``model_input_name`` key.
        """
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.seed = seed
        self.model_input_name = (
            model_input_name if model_input_name is not None else "input_ids"
        )

        if lengths is None:
            print("Lengths is none - calculating lengths.")
            # Lengths can only be inferred when each item is a mapping carrying the input key.
            if (
                not isinstance(dataset[0], (dict, BatchEncoding))
                or self.model_input_name not in dataset[0]
            ):
                raise ValueError(
                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                    f"'{self.model_input_name}' key."
                )
            lengths = [len(feature[self.model_input_name]) for feature in dataset]
        self.lengths = lengths

    def __iter__(self) -> Iterator:
        """Yield this rank's share of length-grouped indices, reshuffled every epoch."""
        # Deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            indices += indices[: (self.total_size - len(indices))]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample: each rank takes every num_replicas-th index starting at its own rank
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
def get_length_grouped_indices(
    lengths, batch_size, mega_batch_mult=None, generator=None
):
    """
    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
    similar lengths. To do this, the indices are:

    - randomly permuted
    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
    - sorted by length in each mega-batch

    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
    maximum length placed first, so that an OOM happens sooner rather than later.

    Args:
        lengths: Per-example lengths; the returned list is a permutation of ``range(len(lengths))``.
        batch_size: Size of each batch slice of similar-length examples.
        mega_batch_mult: Megabatch size as a multiple of ``batch_size``; auto-chosen if ``None``.
        generator: Optional ``torch.Generator`` for deterministic permutation.
    """
    # Degenerate input: nothing to permute (original code would crash on argmax of empty tensor).
    if len(lengths) == 0:
        return []

    # Default for mega_batch_mult: 1000 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [
        indices[i : i + megabatch_size].tolist()
        for i in range(0, len(lengths), megabatch_size)
    ]
    # sorted() already returns a list; no extra list() wrapper needed
    megabatches = [
        sorted(megabatch, key=lambda i: lengths[i], reverse=True)
        for megabatch in megabatches
    ]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
    # Switch to put the longest element in first position
    megabatches[0][0], megabatches[max_idx][0] = (
        megabatches[max_idx][0],
        megabatches[0][0],
    )

    return [item for sublist in megabatches for item in sublist]
|
geneformer/{token_dictionary_gc104M.pkl → token_dictionary_gc95M.pkl}
RENAMED
|
File without changes
|
geneformer/tokenizer.py
CHANGED
|
@@ -3,7 +3,7 @@ Geneformer tokenizer.
|
|
| 3 |
|
| 4 |
**Input data:**
|
| 5 |
|
| 6 |
-
| *Required format:* raw counts scRNAseq data without feature selection as .loom
|
| 7 |
| *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
|
| 8 |
| *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
|
| 9 |
|
|
@@ -20,9 +20,9 @@ Geneformer tokenizer.
|
|
| 20 |
|
| 21 |
**Description:**
|
| 22 |
|
| 23 |
-
| Input data is a directory with .loom
|
| 24 |
|
| 25 |
-
| The discussion below references the .loom file format, but the analogous labels are required for .h5ad
|
| 26 |
|
| 27 |
| Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
|
| 28 |
|
|
@@ -30,9 +30,11 @@ Geneformer tokenizer.
|
|
| 30 |
|
| 31 |
| Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
|
| 32 |
|
| 33 |
-
| If one's data is in other formats besides .loom
|
| 34 |
|
| 35 |
-
| OF NOTE:
|
|
|
|
|
|
|
| 36 |
|
| 37 |
"""
|
| 38 |
|
|
@@ -46,7 +48,6 @@ from collections import Counter
|
|
| 46 |
from pathlib import Path
|
| 47 |
from typing import Literal
|
| 48 |
|
| 49 |
-
import anndata as ad
|
| 50 |
import loompy as lp
|
| 51 |
import numpy as np
|
| 52 |
import pandas as pd
|
|
@@ -87,8 +88,6 @@ def sum_ensembl_ids(
|
|
| 87 |
collapse_gene_ids,
|
| 88 |
gene_mapping_dict,
|
| 89 |
gene_token_dict,
|
| 90 |
-
custom_attr_name_dict,
|
| 91 |
-
use_h5ad_index,
|
| 92 |
file_format="loom",
|
| 93 |
chunk_size=512,
|
| 94 |
):
|
|
@@ -104,45 +103,33 @@ def sum_ensembl_ids(
|
|
| 104 |
assert (
|
| 105 |
"ensembl_id_collapsed" not in data.ra.keys()
|
| 106 |
), "'ensembl_id_collapsed' column already exists in data.ra.keys()"
|
| 107 |
-
|
| 108 |
-
assert (
|
| 109 |
-
"n_counts" in data.ca.keys()
|
| 110 |
-
), "'n_counts' column missing from data.ca.keys()"
|
| 111 |
-
|
| 112 |
-
if custom_attr_name_dict is not None:
|
| 113 |
-
for label in custom_attr_name_dict:
|
| 114 |
-
assert label in data.ca.keys(), f"Attribute `{label}` not present in dataset features"
|
| 115 |
-
|
| 116 |
-
# Get the ensembl ids that exist in data
|
| 117 |
-
ensembl_ids = data.ra.ensembl_id
|
| 118 |
# Check for duplicate Ensembl IDs if collapse_gene_ids is False.
|
| 119 |
# Comparing to gene_token_dict here, would not perform any mapping steps
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
| 125 |
return data_directory
|
| 126 |
else:
|
| 127 |
raise ValueError("Error: data Ensembl IDs non-unique.")
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
data.ra["ensembl_id_collapsed"] =
|
| 138 |
return data_directory
|
| 139 |
-
# Genes need to be collapsed
|
| 140 |
else:
|
| 141 |
dedup_filename = data_directory.with_name(
|
| 142 |
data_directory.stem + "__dedup.loom"
|
| 143 |
)
|
| 144 |
-
|
| 145 |
-
data.ra["ensembl_id_collapsed"] = mapped_vals
|
| 146 |
dup_genes = [
|
| 147 |
idx
|
| 148 |
for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
|
|
@@ -201,19 +188,13 @@ def sum_ensembl_ids(
|
|
| 201 |
dsout.add_columns(processed_array, col_attrs=view.ca)
|
| 202 |
return dedup_filename
|
| 203 |
|
| 204 |
-
elif file_format
|
| 205 |
"""
|
| 206 |
Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
|
| 207 |
Returns adata object with deduplicated Ensembl IDs.
|
| 208 |
"""
|
| 209 |
|
| 210 |
-
|
| 211 |
-
data = sc.read_h5ad(str(data_directory))
|
| 212 |
-
else: # zarr
|
| 213 |
-
data = ad.read_zarr(str(data_directory))
|
| 214 |
-
|
| 215 |
-
if use_h5ad_index:
|
| 216 |
-
data.var["ensembl_id"] = list(data.var.index)
|
| 217 |
|
| 218 |
assert (
|
| 219 |
"ensembl_id" in data.var.columns
|
|
@@ -222,41 +203,33 @@ def sum_ensembl_ids(
|
|
| 222 |
assert (
|
| 223 |
"ensembl_id_collapsed" not in data.var.columns
|
| 224 |
), "'ensembl_id_collapsed' column already exists in data.var"
|
| 225 |
-
assert (
|
| 226 |
-
"n_counts" in data.obs.columns
|
| 227 |
-
), "'n_counts' column missing from data.obs"
|
| 228 |
-
|
| 229 |
-
if custom_attr_name_dict is not None:
|
| 230 |
-
for label in custom_attr_name_dict:
|
| 231 |
-
assert label in data.obs.columns, f"Attribute `{label}` not present in data.obs"
|
| 232 |
|
| 233 |
-
|
| 234 |
-
# Get the ensembl ids that exist in data
|
| 235 |
-
ensembl_ids = data.var.ensembl_id
|
| 236 |
# Check for duplicate Ensembl IDs if collapse_gene_ids is False.
|
| 237 |
# Comparing to gene_token_dict here, would not perform any mapping steps
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
|
|
|
| 243 |
return data
|
| 244 |
else:
|
| 245 |
raise ValueError("Error: data Ensembl IDs non-unique.")
|
| 246 |
|
| 247 |
-
#
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
|
|
|
| 255 |
return data
|
| 256 |
-
|
| 257 |
else:
|
| 258 |
-
data.var["ensembl_id_collapsed"] =
|
| 259 |
-
data.var_names =
|
| 260 |
data = data[:, ~data.var.index.isna()]
|
| 261 |
dup_genes = [
|
| 262 |
idx for idx, count in Counter(data.var_names).items() if count > 1
|
|
@@ -305,9 +278,6 @@ class TranscriptomeTokenizer:
|
|
| 305 |
model_input_size=4096,
|
| 306 |
special_token=True,
|
| 307 |
collapse_gene_ids=True,
|
| 308 |
-
use_h5ad_index=False,
|
| 309 |
-
keep_counts=False,
|
| 310 |
-
model_version="V2",
|
| 311 |
gene_median_file=GENE_MEDIAN_FILE,
|
| 312 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
| 313 |
gene_mapping_file=ENSEMBL_MAPPING_FILE,
|
|
@@ -327,23 +297,15 @@ class TranscriptomeTokenizer:
|
|
| 327 |
| Chunk size for anndata tokenizer.
|
| 328 |
model_input_size : int = 4096
|
| 329 |
| Max input size of model to truncate input to.
|
| 330 |
-
| For the
|
| 331 |
special_token : bool = True
|
| 332 |
| Adds CLS token before and EOS token after rank value encoding.
|
| 333 |
-
| For the
|
| 334 |
collapse_gene_ids : bool = True
|
| 335 |
| Whether to collapse gene IDs based on gene mapping dictionary.
|
| 336 |
-
use_h5ad_index : bool = False
|
| 337 |
-
| use index as Ensembl IDs (only available for h5ad, only if collapse_gene_ids is True)
|
| 338 |
-
keep_counts : bool = False
|
| 339 |
-
| Whether to keep a dataset column that represents gene counts normalized by total cell counts
|
| 340 |
-
| Counts will be ordered by the gene rank order within the tokenized rank value encoding for each cell.
|
| 341 |
-
model_version : str
|
| 342 |
-
| To auto-select settings for model version other than current default.
|
| 343 |
-
| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
|
| 344 |
gene_median_file : Path
|
| 345 |
| Path to pickle file containing dictionary of non-zero median
|
| 346 |
-
| gene expression values across Genecorpus.
|
| 347 |
token_dictionary_file : Path
|
| 348 |
| Path to pickle file containing token dictionary (Ensembl IDs:token).
|
| 349 |
gene_mapping_file : None, Path
|
|
@@ -365,22 +327,8 @@ class TranscriptomeTokenizer:
|
|
| 365 |
# add CLS and EOS tokens
|
| 366 |
self.special_token = special_token
|
| 367 |
|
| 368 |
-
# CHANGE DEFAULTS TO BE FOR MODEL OTHER THAN CURRENT
|
| 369 |
-
self.model_version = model_version
|
| 370 |
-
if self.model_version not in ["V1","V2"]:
|
| 371 |
-
logger.error(
|
| 372 |
-
"Unrecognized model version. Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells."
|
| 373 |
-
)
|
| 374 |
-
elif self.model_version == "V1":
|
| 375 |
-
self.model_input_size = 2048
|
| 376 |
-
self.special_token = False
|
| 377 |
-
from . import ENSEMBL_MAPPING_FILE_30M, GENE_MEDIAN_FILE_30M, TOKEN_DICTIONARY_FILE_30M
|
| 378 |
-
gene_median_file = GENE_MEDIAN_FILE_30M
|
| 379 |
-
token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
| 380 |
-
gene_mapping_file = ENSEMBL_MAPPING_FILE_30M
|
| 381 |
-
|
| 382 |
# load dictionary of gene normalization factors
|
| 383 |
-
# (non-zero median value of expression across Genecorpus)
|
| 384 |
with open(gene_median_file, "rb") as f:
|
| 385 |
self.gene_median_dict = pickle.load(f)
|
| 386 |
|
|
@@ -403,18 +351,12 @@ class TranscriptomeTokenizer:
|
|
| 403 |
"<eos>" in self.gene_token_dict.keys()
|
| 404 |
):
|
| 405 |
logger.warning(
|
| 406 |
-
"<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for
|
| 407 |
)
|
| 408 |
|
| 409 |
# if collapsing duplicate gene IDs
|
| 410 |
self.collapse_gene_ids = collapse_gene_ids
|
| 411 |
|
| 412 |
-
# if using h5ad index as ensembl_ids
|
| 413 |
-
self.use_h5ad_index = use_h5ad_index
|
| 414 |
-
|
| 415 |
-
# if keeping counts within dataset column
|
| 416 |
-
self.keep_counts = keep_counts
|
| 417 |
-
|
| 418 |
# load gene mappings dictionary (Ensembl IDs:Ensembl ID)
|
| 419 |
if gene_mapping_file is not None:
|
| 420 |
with open(gene_mapping_file, "rb") as f:
|
|
@@ -439,8 +381,7 @@ class TranscriptomeTokenizer:
|
|
| 439 |
data_directory: Path | str,
|
| 440 |
output_directory: Path | str,
|
| 441 |
output_prefix: str,
|
| 442 |
-
file_format: Literal["loom", "h5ad"
|
| 443 |
-
input_identifier: str = "",
|
| 444 |
use_generator: bool = False,
|
| 445 |
):
|
| 446 |
"""
|
|
@@ -455,21 +396,17 @@ class TranscriptomeTokenizer:
|
|
| 455 |
output_prefix : str
|
| 456 |
| Prefix for output .dataset
|
| 457 |
file_format : str
|
| 458 |
-
| Format of input files. Can be "loom"
|
| 459 |
-
input_identifier : str
|
| 460 |
-
| Substring identifier for input .loom, .h5ad, or .zarr, only matches are tokenized
|
| 461 |
-
| Default is no identifier, tokenizes all files in provided directory.
|
| 462 |
use_generator : bool
|
| 463 |
| Whether to use generator or dict for tokenization.
|
| 464 |
|
| 465 |
"""
|
| 466 |
-
tokenized_cells, cell_metadata
|
| 467 |
-
Path(data_directory), file_format
|
| 468 |
)
|
| 469 |
tokenized_dataset = self.create_dataset(
|
| 470 |
tokenized_cells,
|
| 471 |
cell_metadata,
|
| 472 |
-
tokenized_counts,
|
| 473 |
use_generator=use_generator,
|
| 474 |
)
|
| 475 |
|
|
@@ -477,10 +414,9 @@ class TranscriptomeTokenizer:
|
|
| 477 |
tokenized_dataset.save_to_disk(str(output_path))
|
| 478 |
|
| 479 |
def tokenize_files(
|
| 480 |
-
self, data_directory, file_format: Literal["loom", "h5ad"
|
| 481 |
):
|
| 482 |
tokenized_cells = []
|
| 483 |
-
tokenized_counts = []
|
| 484 |
if self.custom_attr_name_dict is not None:
|
| 485 |
cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
|
| 486 |
cell_metadata = {
|
|
@@ -489,20 +425,15 @@ class TranscriptomeTokenizer:
|
|
| 489 |
|
| 490 |
# loops through directories to tokenize .loom files
|
| 491 |
file_found = 0
|
| 492 |
-
# loops through directories to tokenize .loom
|
| 493 |
tokenize_file_fn = (
|
| 494 |
self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
|
| 495 |
)
|
| 496 |
-
|
| 497 |
-
file_match = f"*.{file_format}"
|
| 498 |
-
else:
|
| 499 |
-
file_match = f"*{input_identifier}*.{file_format}"
|
| 500 |
-
for file_path in data_directory.glob(file_match):
|
| 501 |
file_found = 1
|
| 502 |
print(f"Tokenizing {file_path}")
|
| 503 |
-
file_tokenized_cells, file_cell_metadata
|
| 504 |
tokenized_cells += file_tokenized_cells
|
| 505 |
-
tokenized_counts += file_tokenized_counts
|
| 506 |
if self.custom_attr_name_dict is not None:
|
| 507 |
for k in cell_attr:
|
| 508 |
cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[
|
|
@@ -516,17 +447,15 @@ class TranscriptomeTokenizer:
|
|
| 516 |
f"No .{file_format} files found in directory {data_directory}."
|
| 517 |
)
|
| 518 |
raise
|
| 519 |
-
return tokenized_cells, cell_metadata
|
| 520 |
|
| 521 |
-
def tokenize_anndata(self, adata_file_path, target_sum=10_000
|
| 522 |
adata = sum_ensembl_ids(
|
| 523 |
adata_file_path,
|
| 524 |
self.collapse_gene_ids,
|
| 525 |
self.gene_mapping_dict,
|
| 526 |
self.gene_token_dict,
|
| 527 |
-
|
| 528 |
-
self.use_h5ad_index,
|
| 529 |
-
file_format=file_format,
|
| 530 |
chunk_size=self.chunk_size,
|
| 531 |
)
|
| 532 |
|
|
@@ -565,7 +494,6 @@ class TranscriptomeTokenizer:
|
|
| 565 |
filter_pass_loc = np.array([i for i in range(adata.shape[0])])
|
| 566 |
|
| 567 |
tokenized_cells = []
|
| 568 |
-
tokenized_counts = []
|
| 569 |
|
| 570 |
for i in range(0, len(filter_pass_loc), self.chunk_size):
|
| 571 |
idx = filter_pass_loc[i : i + self.chunk_size]
|
|
@@ -573,23 +501,14 @@ class TranscriptomeTokenizer:
|
|
| 573 |
n_counts = adata[idx].obs["n_counts"].values[:, None]
|
| 574 |
X_view0 = adata[idx, :].X
|
| 575 |
X_view = X_view0[:, coding_miRNA_loc]
|
| 576 |
-
|
| 577 |
-
X_norm = X_norm_unscaled / norm_factor_vector
|
| 578 |
X_norm = sp.csr_matrix(X_norm)
|
| 579 |
-
X_norm_unscaled = sp.csr_matrix(X_norm_unscaled)
|
| 580 |
|
| 581 |
tokenized_cells += [
|
| 582 |
rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
|
| 583 |
for i in range(X_norm.shape[0])
|
| 584 |
]
|
| 585 |
|
| 586 |
-
if self.keep_counts:
|
| 587 |
-
X_norm_unscaled = sp.csr_matrix(X_norm_unscaled)
|
| 588 |
-
tokenized_counts += [
|
| 589 |
-
rank_genes(X_norm[i].data, X_norm_unscaled[i].data)
|
| 590 |
-
for i in range(X_norm.shape[0])
|
| 591 |
-
]
|
| 592 |
-
|
| 593 |
# add custom attributes for subview to dict
|
| 594 |
if self.custom_attr_name_dict is not None:
|
| 595 |
for k in file_cell_metadata.keys():
|
|
@@ -597,28 +516,9 @@ class TranscriptomeTokenizer:
|
|
| 597 |
else:
|
| 598 |
file_cell_metadata = None
|
| 599 |
|
| 600 |
-
|
| 601 |
-
empty_cell_indices = [i for i, cell in enumerate(tokenized_cells) if cell.size == 0]
|
| 602 |
-
if len(empty_cell_indices) > 0:
|
| 603 |
-
logger.warning(
|
| 604 |
-
"Warning: cells without any genes in token dictionary detected. This is unusual and may indicate empty droplets or otherwise invalid cells within the input data. Consider further QC prior to tokenization. Proceeding with excluding empty cells."
|
| 605 |
-
)
|
| 606 |
-
empty_cell_indices.sort(reverse=True) # for safe deletion
|
| 607 |
-
for index in empty_cell_indices:
|
| 608 |
-
del tokenized_cells[index]
|
| 609 |
-
if self.keep_counts:
|
| 610 |
-
del tokenized_counts[index]
|
| 611 |
-
# remove corresponding metadata
|
| 612 |
-
for k,v in file_cell_metadata.items():
|
| 613 |
-
for index in empty_cell_indices:
|
| 614 |
-
del v[index]
|
| 615 |
-
file_cell_metadata[k] = v
|
| 616 |
-
|
| 617 |
-
return tokenized_cells, file_cell_metadata, tokenized_counts
|
| 618 |
|
| 619 |
-
def tokenize_loom(self, loom_file_path, target_sum=10_000
|
| 620 |
-
tokenized_counts = [] # keep_counts not implemented for tokenize_loom
|
| 621 |
-
|
| 622 |
if self.custom_attr_name_dict is not None:
|
| 623 |
file_cell_metadata = {
|
| 624 |
attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
|
|
@@ -631,9 +531,7 @@ class TranscriptomeTokenizer:
|
|
| 631 |
self.collapse_gene_ids,
|
| 632 |
self.gene_mapping_dict,
|
| 633 |
self.gene_token_dict,
|
| 634 |
-
|
| 635 |
-
use_h5ad_index=False,
|
| 636 |
-
file_format=file_format,
|
| 637 |
chunk_size=self.chunk_size,
|
| 638 |
)
|
| 639 |
|
|
@@ -706,33 +604,29 @@ class TranscriptomeTokenizer:
|
|
| 706 |
del data.ra["ensembl_id_collapsed"]
|
| 707 |
|
| 708 |
|
| 709 |
-
return tokenized_cells, file_cell_metadata
|
| 710 |
|
| 711 |
def create_dataset(
|
| 712 |
self,
|
| 713 |
tokenized_cells,
|
| 714 |
cell_metadata,
|
| 715 |
-
tokenized_counts,
|
| 716 |
use_generator=False,
|
| 717 |
keep_uncropped_input_ids=False,
|
| 718 |
):
|
| 719 |
print("Creating dataset.")
|
| 720 |
# create dict for dataset creation
|
| 721 |
dataset_dict = {"input_ids": tokenized_cells}
|
| 722 |
-
if self.keep_counts:
|
| 723 |
-
dataset_dict["counts"] = tokenized_counts
|
| 724 |
-
|
| 725 |
if self.custom_attr_name_dict is not None:
|
| 726 |
dataset_dict.update(cell_metadata)
|
| 727 |
|
| 728 |
# create dataset
|
| 729 |
if use_generator:
|
|
|
|
| 730 |
def dict_generator():
|
| 731 |
for i in range(len(tokenized_cells)):
|
| 732 |
yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
|
| 733 |
|
| 734 |
output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
|
| 735 |
-
|
| 736 |
else:
|
| 737 |
output_dataset = Dataset.from_dict(dataset_dict)
|
| 738 |
|
|
@@ -755,23 +649,9 @@ class TranscriptomeTokenizer:
|
|
| 755 |
len(example["input_ids"]),
|
| 756 |
self.gene_token_dict.get("<eos>"),
|
| 757 |
)
|
| 758 |
-
if self.keep_counts:
|
| 759 |
-
example["counts"] = example["counts"][
|
| 760 |
-
0 : self.model_input_size - 2
|
| 761 |
-
] # truncate to leave space for CLS and EOS token
|
| 762 |
-
example["counts"] = np.insert(
|
| 763 |
-
example["counts"], 0, 0.0
|
| 764 |
-
)
|
| 765 |
-
example["counts"] = np.insert(
|
| 766 |
-
example["counts"],
|
| 767 |
-
len(example["counts"]),
|
| 768 |
-
0.0,
|
| 769 |
-
)
|
| 770 |
else:
|
| 771 |
# Truncate/Crop input_ids to input size
|
| 772 |
example["input_ids"] = example["input_ids"][0 : self.model_input_size]
|
| 773 |
-
if self.keep_counts:
|
| 774 |
-
example["counts"] = example["counts"][0 : self.model_input_size]
|
| 775 |
example["length"] = len(example["input_ids"])
|
| 776 |
|
| 777 |
return example
|
|
|
|
| 3 |
|
| 4 |
**Input data:**
|
| 5 |
|
| 6 |
+
| *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file.
|
| 7 |
| *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
|
| 8 |
| *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
|
| 9 |
|
|
|
|
| 20 |
|
| 21 |
**Description:**
|
| 22 |
|
| 23 |
+
| Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
|
| 24 |
|
| 25 |
+
| The discussion below references the .loom file format, but the analogous labels are required for .h5ad files, except that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.
|
| 26 |
|
| 27 |
| Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
|
| 28 |
|
|
|
|
| 30 |
|
| 31 |
| Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
|
| 32 |
|
| 33 |
+
| If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.
|
| 34 |
|
| 35 |
+
| OF NOTE: Take care that the correct token dictionary and gene median file is used for the correct model.
|
| 36 |
+
|
| 37 |
+
| OF NOTE: For 95M model series, special_token should be True and model_input_size should be 4096. For 30M model series, special_token should be False and model_input_size should be 2048.
|
| 38 |
|
| 39 |
"""
|
| 40 |
|
|
|
|
| 48 |
from pathlib import Path
|
| 49 |
from typing import Literal
|
| 50 |
|
|
|
|
| 51 |
import loompy as lp
|
| 52 |
import numpy as np
|
| 53 |
import pandas as pd
|
|
|
|
| 88 |
collapse_gene_ids,
|
| 89 |
gene_mapping_dict,
|
| 90 |
gene_token_dict,
|
|
|
|
|
|
|
| 91 |
file_format="loom",
|
| 92 |
chunk_size=512,
|
| 93 |
):
|
|
|
|
| 103 |
assert (
|
| 104 |
"ensembl_id_collapsed" not in data.ra.keys()
|
| 105 |
), "'ensembl_id_collapsed' column already exists in data.ra.keys()"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
# Check for duplicate Ensembl IDs if collapse_gene_ids is False.
|
| 107 |
# Comparing to gene_token_dict here, would not perform any mapping steps
|
| 108 |
+
gene_ids_in_dict = [
|
| 109 |
+
gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
|
| 110 |
+
]
|
| 111 |
+
if collapse_gene_ids is False:
|
| 112 |
+
|
| 113 |
+
if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
|
| 114 |
return data_directory
|
| 115 |
else:
|
| 116 |
raise ValueError("Error: data Ensembl IDs non-unique.")
|
| 117 |
+
|
| 118 |
+
gene_ids_collapsed = [
|
| 119 |
+
gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id
|
| 120 |
+
]
|
| 121 |
+
gene_ids_collapsed_in_dict = [
|
| 122 |
+
gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
|
| 126 |
+
data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
|
| 127 |
return data_directory
|
|
|
|
| 128 |
else:
|
| 129 |
dedup_filename = data_directory.with_name(
|
| 130 |
data_directory.stem + "__dedup.loom"
|
| 131 |
)
|
| 132 |
+
data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
|
|
|
|
| 133 |
dup_genes = [
|
| 134 |
idx
|
| 135 |
for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
|
|
|
|
| 188 |
dsout.add_columns(processed_array, col_attrs=view.ca)
|
| 189 |
return dedup_filename
|
| 190 |
|
| 191 |
+
elif file_format == "h5ad":
|
| 192 |
"""
|
| 193 |
Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
|
| 194 |
Returns adata object with deduplicated Ensembl IDs.
|
| 195 |
"""
|
| 196 |
|
| 197 |
+
data = sc.read_h5ad(str(data_directory))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
assert (
|
| 200 |
"ensembl_id" in data.var.columns
|
|
|
|
| 203 |
assert (
|
| 204 |
"ensembl_id_collapsed" not in data.var.columns
|
| 205 |
), "'ensembl_id_collapsed' column already exists in data.var"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
|
|
|
|
|
|
|
|
|
| 207 |
# Check for duplicate Ensembl IDs if collapse_gene_ids is False.
|
| 208 |
# Comparing to gene_token_dict here, would not perform any mapping steps
|
| 209 |
+
gene_ids_in_dict = [
|
| 210 |
+
gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
|
| 211 |
+
]
|
| 212 |
+
if collapse_gene_ids is False:
|
| 213 |
+
|
| 214 |
+
if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
|
| 215 |
return data
|
| 216 |
else:
|
| 217 |
raise ValueError("Error: data Ensembl IDs non-unique.")
|
| 218 |
|
| 219 |
+
# Check for when if collapse_gene_ids is True
|
| 220 |
+
gene_ids_collapsed = [
|
| 221 |
+
gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
|
| 222 |
+
]
|
| 223 |
+
gene_ids_collapsed_in_dict = [
|
| 224 |
+
gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
|
| 225 |
+
]
|
| 226 |
+
if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
|
| 227 |
+
data.var["ensembl_id_collapsed"] = data.var.ensembl_id.map(gene_mapping_dict)
|
| 228 |
return data
|
| 229 |
+
|
| 230 |
else:
|
| 231 |
+
data.var["ensembl_id_collapsed"] = gene_ids_collapsed
|
| 232 |
+
data.var_names = gene_ids_collapsed
|
| 233 |
data = data[:, ~data.var.index.isna()]
|
| 234 |
dup_genes = [
|
| 235 |
idx for idx, count in Counter(data.var_names).items() if count > 1
|
|
|
|
| 278 |
model_input_size=4096,
|
| 279 |
special_token=True,
|
| 280 |
collapse_gene_ids=True,
|
|
|
|
|
|
|
|
|
|
| 281 |
gene_median_file=GENE_MEDIAN_FILE,
|
| 282 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
| 283 |
gene_mapping_file=ENSEMBL_MAPPING_FILE,
|
|
|
|
| 297 |
| Chunk size for anndata tokenizer.
|
| 298 |
model_input_size : int = 4096
|
| 299 |
| Max input size of model to truncate input to.
|
| 300 |
+
| For the 30M model series, should be 2048. For the 95M model series, should be 4096.
|
| 301 |
special_token : bool = True
|
| 302 |
| Adds CLS token before and EOS token after rank value encoding.
|
| 303 |
+
| For the 30M model series, should be False. For the 95M model series, should be True.
|
| 304 |
collapse_gene_ids : bool = True
|
| 305 |
| Whether to collapse gene IDs based on gene mapping dictionary.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
gene_median_file : Path
|
| 307 |
| Path to pickle file containing dictionary of non-zero median
|
| 308 |
+
| gene expression values across Genecorpus-30M.
|
| 309 |
token_dictionary_file : Path
|
| 310 |
| Path to pickle file containing token dictionary (Ensembl IDs:token).
|
| 311 |
gene_mapping_file : None, Path
|
|
|
|
| 327 |
# add CLS and EOS tokens
|
| 328 |
self.special_token = special_token
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
# load dictionary of gene normalization factors
|
| 331 |
+
# (non-zero median value of expression across Genecorpus-30M)
|
| 332 |
with open(gene_median_file, "rb") as f:
|
| 333 |
self.gene_median_dict = pickle.load(f)
|
| 334 |
|
|
|
|
| 351 |
"<eos>" in self.gene_token_dict.keys()
|
| 352 |
):
|
| 353 |
logger.warning(
|
| 354 |
+
"<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for 95M model series, special_token should be True."
|
| 355 |
)
|
| 356 |
|
| 357 |
# if collapsing duplicate gene IDs
|
| 358 |
self.collapse_gene_ids = collapse_gene_ids
|
| 359 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
# load gene mappings dictionary (Ensembl IDs:Ensembl ID)
|
| 361 |
if gene_mapping_file is not None:
|
| 362 |
with open(gene_mapping_file, "rb") as f:
|
|
|
|
| 381 |
data_directory: Path | str,
|
| 382 |
output_directory: Path | str,
|
| 383 |
output_prefix: str,
|
| 384 |
+
file_format: Literal["loom", "h5ad"] = "loom",
|
|
|
|
| 385 |
use_generator: bool = False,
|
| 386 |
):
|
| 387 |
"""
|
|
|
|
| 396 |
output_prefix : str
|
| 397 |
| Prefix for output .dataset
|
| 398 |
file_format : str
|
| 399 |
+
| Format of input files. Can be "loom" or "h5ad".
|
|
|
|
|
|
|
|
|
|
| 400 |
use_generator : bool
|
| 401 |
| Whether to use generator or dict for tokenization.
|
| 402 |
|
| 403 |
"""
|
| 404 |
+
tokenized_cells, cell_metadata = self.tokenize_files(
|
| 405 |
+
Path(data_directory), file_format
|
| 406 |
)
|
| 407 |
tokenized_dataset = self.create_dataset(
|
| 408 |
tokenized_cells,
|
| 409 |
cell_metadata,
|
|
|
|
| 410 |
use_generator=use_generator,
|
| 411 |
)
|
| 412 |
|
|
|
|
| 414 |
tokenized_dataset.save_to_disk(str(output_path))
|
| 415 |
|
| 416 |
def tokenize_files(
|
| 417 |
+
self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
|
| 418 |
):
|
| 419 |
tokenized_cells = []
|
|
|
|
| 420 |
if self.custom_attr_name_dict is not None:
|
| 421 |
cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
|
| 422 |
cell_metadata = {
|
|
|
|
| 425 |
|
| 426 |
# loops through directories to tokenize .loom files
|
| 427 |
file_found = 0
|
| 428 |
+
# loops through directories to tokenize .loom or .h5ad files
|
| 429 |
tokenize_file_fn = (
|
| 430 |
self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
|
| 431 |
)
|
| 432 |
+
for file_path in data_directory.glob(f"*.{file_format}"):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
file_found = 1
|
| 434 |
print(f"Tokenizing {file_path}")
|
| 435 |
+
file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
|
| 436 |
tokenized_cells += file_tokenized_cells
|
|
|
|
| 437 |
if self.custom_attr_name_dict is not None:
|
| 438 |
for k in cell_attr:
|
| 439 |
cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[
|
|
|
|
| 447 |
f"No .{file_format} files found in directory {data_directory}."
|
| 448 |
)
|
| 449 |
raise
|
| 450 |
+
return tokenized_cells, cell_metadata
|
| 451 |
|
| 452 |
+
def tokenize_anndata(self, adata_file_path, target_sum=10_000):
|
| 453 |
adata = sum_ensembl_ids(
|
| 454 |
adata_file_path,
|
| 455 |
self.collapse_gene_ids,
|
| 456 |
self.gene_mapping_dict,
|
| 457 |
self.gene_token_dict,
|
| 458 |
+
file_format="h5ad",
|
|
|
|
|
|
|
| 459 |
chunk_size=self.chunk_size,
|
| 460 |
)
|
| 461 |
|
|
|
|
| 494 |
filter_pass_loc = np.array([i for i in range(adata.shape[0])])
|
| 495 |
|
| 496 |
tokenized_cells = []
|
|
|
|
| 497 |
|
| 498 |
for i in range(0, len(filter_pass_loc), self.chunk_size):
|
| 499 |
idx = filter_pass_loc[i : i + self.chunk_size]
|
|
|
|
| 501 |
n_counts = adata[idx].obs["n_counts"].values[:, None]
|
| 502 |
X_view0 = adata[idx, :].X
|
| 503 |
X_view = X_view0[:, coding_miRNA_loc]
|
| 504 |
+
X_norm = X_view / n_counts * target_sum / norm_factor_vector
|
|
|
|
| 505 |
X_norm = sp.csr_matrix(X_norm)
|
|
|
|
| 506 |
|
| 507 |
tokenized_cells += [
|
| 508 |
rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
|
| 509 |
for i in range(X_norm.shape[0])
|
| 510 |
]
|
| 511 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
# add custom attributes for subview to dict
|
| 513 |
if self.custom_attr_name_dict is not None:
|
| 514 |
for k in file_cell_metadata.keys():
|
|
|
|
| 516 |
else:
|
| 517 |
file_cell_metadata = None
|
| 518 |
|
| 519 |
+
return tokenized_cells, file_cell_metadata
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
|
| 521 |
+
def tokenize_loom(self, loom_file_path, target_sum=10_000):
|
|
|
|
|
|
|
| 522 |
if self.custom_attr_name_dict is not None:
|
| 523 |
file_cell_metadata = {
|
| 524 |
attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
|
|
|
|
| 531 |
self.collapse_gene_ids,
|
| 532 |
self.gene_mapping_dict,
|
| 533 |
self.gene_token_dict,
|
| 534 |
+
file_format="loom",
|
|
|
|
|
|
|
| 535 |
chunk_size=self.chunk_size,
|
| 536 |
)
|
| 537 |
|
|
|
|
| 604 |
del data.ra["ensembl_id_collapsed"]
|
| 605 |
|
| 606 |
|
| 607 |
+
return tokenized_cells, file_cell_metadata
|
| 608 |
|
| 609 |
def create_dataset(
|
| 610 |
self,
|
| 611 |
tokenized_cells,
|
| 612 |
cell_metadata,
|
|
|
|
| 613 |
use_generator=False,
|
| 614 |
keep_uncropped_input_ids=False,
|
| 615 |
):
|
| 616 |
print("Creating dataset.")
|
| 617 |
# create dict for dataset creation
|
| 618 |
dataset_dict = {"input_ids": tokenized_cells}
|
|
|
|
|
|
|
|
|
|
| 619 |
if self.custom_attr_name_dict is not None:
|
| 620 |
dataset_dict.update(cell_metadata)
|
| 621 |
|
| 622 |
# create dataset
|
| 623 |
if use_generator:
|
| 624 |
+
|
| 625 |
def dict_generator():
|
| 626 |
for i in range(len(tokenized_cells)):
|
| 627 |
yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
|
| 628 |
|
| 629 |
output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
|
|
|
|
| 630 |
else:
|
| 631 |
output_dataset = Dataset.from_dict(dataset_dict)
|
| 632 |
|
|
|
|
| 649 |
len(example["input_ids"]),
|
| 650 |
self.gene_token_dict.get("<eos>"),
|
| 651 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
else:
|
| 653 |
# Truncate/Crop input_ids to input size
|
| 654 |
example["input_ids"] = example["input_ids"][0 : self.model_input_size]
|
|
|
|
|
|
|
| 655 |
example["length"] = len(example["input_ids"])
|
| 656 |
|
| 657 |
return example
|
generation_config.json
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
{
|
| 2 |
"_from_model_config": true,
|
| 3 |
"pad_token_id": 0,
|
| 4 |
-
"transformers_version": "4.
|
| 5 |
}
|
|
|
|
| 1 |
{
|
| 2 |
"_from_model_config": true,
|
| 3 |
"pad_token_id": 0,
|
| 4 |
+
"transformers_version": "4.37.1"
|
| 5 |
}
|
gf-12L-30M-i2048/config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForMaskedLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.02,
|
| 6 |
+
"gradient_checkpointing": false,
|
| 7 |
+
"hidden_act": "relu",
|
| 8 |
+
"hidden_dropout_prob": 0.02,
|
| 9 |
+
"hidden_size": 512,
|
| 10 |
+
"initializer_range": 0.02,
|
| 11 |
+
"intermediate_size": 1024,
|
| 12 |
+
"layer_norm_eps": 1e-12,
|
| 13 |
+
"max_position_embeddings": 2048,
|
| 14 |
+
"model_type": "bert",
|
| 15 |
+
"num_attention_heads": 8,
|
| 16 |
+
"num_hidden_layers": 12,
|
| 17 |
+
"pad_token_id": 0,
|
| 18 |
+
"position_embedding_type": "absolute",
|
| 19 |
+
"transformers_version": "4.6.0",
|
| 20 |
+
"type_vocab_size": 2,
|
| 21 |
+
"use_cache": true,
|
| 22 |
+
"vocab_size": 25426
|
| 23 |
+
}
|
gf-12L-30M-i2048/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:812f8d85e5ecf9d64c268f052f6ece2c1906bc4f1aecf70d5144b2598386b615
|
| 3 |
+
size 158467410
|