This view is limited to 50 files because it contains too many changes.  See the raw diff here.
Files changed (50) hide show
  1. Geneformer-V2-316M/model.safetensors +0 -3
  2. MANIFEST.in +4 -9
  3. README.md +16 -16
  4. config.json +7 -7
  5. docs/source/geneformer.in_silico_perturber.rst +1 -1
  6. examples/cell_classification.ipynb +5 -8
  7. examples/distributed_multitask_cell_classification.ipynb +0 -149
  8. examples/extract_and_plot_cell_embeddings.ipynb +3 -5
  9. examples/gene_classification.ipynb +7 -11
  10. examples/in_silico_perturbation.ipynb +10 -17
  11. examples/multitask_cell_classification.ipynb +3 -3
  12. examples/tokenizing_scRNAseq_data.ipynb +8 -14
  13. {Geneformer-V2-104M → fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522}/config.json +6 -6
  14. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +3 -0
  15. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/config.json +0 -0
  16. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/optimizer.pt +0 -0
  17. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin +0 -0
  18. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/rng_state.pth +0 -0
  19. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/scheduler.pt +0 -0
  20. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/trainer_state.json +0 -0
  21. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/training_args.bin +0 -0
  22. geneformer/__init__.py +4 -9
  23. geneformer/classifier.py +9 -61
  24. geneformer/classifier_utils.py +6 -36
  25. geneformer/collator_for_classification.py +2 -7
  26. geneformer/emb_extractor.py +51 -132
  27. geneformer/{ensembl_mapping_dict_gc104M.pkl → ensembl_mapping_dict_gc95M.pkl} +0 -0
  28. geneformer/evaluation_utils.py +14 -33
  29. geneformer/{gene_median_dictionary_gc104M.pkl → gene_median_dictionary_gc95M.pkl} +0 -0
  30. geneformer/{gene_name_id_dict_gc104M.pkl → gene_name_id_dict_gc95M.pkl} +2 -2
  31. geneformer/in_silico_perturber.py +11 -53
  32. geneformer/in_silico_perturber_stats.py +5 -25
  33. geneformer/mtl/__init__.py +1 -4
  34. geneformer/mtl/collators.py +2 -2
  35. geneformer/mtl/data.py +123 -210
  36. geneformer/mtl/eval_utils.py +8 -5
  37. geneformer/mtl/imports.py +43 -0
  38. geneformer/mtl/model.py +1 -1
  39. geneformer/mtl/optuna_utils.py +27 -0
  40. geneformer/mtl/train.py +329 -656
  41. geneformer/mtl/train_utils.py +161 -0
  42. geneformer/mtl/utils.py +91 -603
  43. geneformer/mtl_classifier.py +8 -23
  44. geneformer/perturber_utils.py +65 -62
  45. geneformer/pretrainer.py +176 -6
  46. geneformer/{token_dictionary_gc104M.pkl → token_dictionary_gc95M.pkl} +0 -0
  47. geneformer/tokenizer.py +66 -186
  48. generation_config.json +1 -1
  49. gf-12L-30M-i2048/config.json +23 -0
  50. gf-12L-30M-i2048/pytorch_model.bin +3 -0
Geneformer-V2-316M/model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:965ceccea81953d362081ef3843560a0e4fef88d396c28017881f1e94b1246f3
3
- size 1265455076
 
 
 
 
MANIFEST.in CHANGED
@@ -1,9 +1,4 @@
1
- include geneformer/gene_median_dictionary_gc104M.pkl
2
- include geneformer/gene_name_id_dict_gc104M.pkl
3
- include geneformer/ensembl_mapping_dict_gc104M.pkl
4
- include geneformer/token_dictionary_gc104M.pkl
5
-
6
- include geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl
7
- include geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl
8
- include geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl
9
- include geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl
 
1
+ include geneformer/gene_median_dictionary_gc95M.pkl
2
+ include geneformer/gene_name_id_dict_gc95M.pkl
3
+ include geneformer/ensembl_mapping_dict_gc95M.pkl
4
+ include geneformer/token_dictionary_gc95M.pkl
 
 
 
 
 
README.md CHANGED
@@ -1,36 +1,40 @@
1
  ---
2
  datasets: ctheodoris/Genecorpus-30M
3
  license: apache-2.0
4
- tags:
5
- - single-cell
6
- - genomics
7
  ---
8
  # Geneformer
9
  Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
10
 
11
  - See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
12
- - See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model, now trained on ~104 million transcriptomes, and our continual learning, multitask learning, and quantization strategies.
13
  - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
14
 
15
  # Model Description
16
- Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer V1 was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million human single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. The current updated Geneformer V2 is pretrained on ~104 million human single cell transcriptomes (non-cancer). The cancer continual learning V2 variant was continually pretrained on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.
17
 
18
- Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus (~30M for V1, ~104M for V2). The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
19
 
20
  The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
21
 
22
  We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
23
 
24
- During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational AI model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
25
 
26
  The repository includes the following pretrained models:
27
 
28
- - Geneformer-V1-10M: original model trained June 2021 on ~30M human single cell transcriptomes, 10M parameters, input size 2048, vocabulary ~25K protein-coding or non-coding RNA genes
29
- - Geneformer-V2-104M and Geneformer-V2-316M: updated model trained Dec 2024 on ~104M human single cell transcriptomes, 104M or 316M parameters, input size 4096, vocabulary ~20K protein-coding genes
 
 
30
 
31
- The current default model in the main directory of the repository is Geneformer-V2-316M.
 
 
 
32
 
33
- The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, Geneformer-V2-104M_CLcancer.
 
 
34
 
35
  # Application
36
  The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
@@ -78,13 +82,9 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main
78
  - extracting and plotting cell embeddings
79
  - in silico perturbation
80
 
81
- Please also see [here](https://tinyurl.com/geneformertutorial) for a quickstart tutorial for predicting candidate therapeutic targets with Geneformer.
82
-
83
- Complete documentation is available at https://geneformer.readthedocs.io/en/latest/.
84
-
85
  Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
86
 
87
- Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.). Importantly, as usual for deep learning models, there are no uniformly applicable default hyperparameters for Geneformer.
88
 
89
  # Citations
90
  - C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
 
1
  ---
2
  datasets: ctheodoris/Genecorpus-30M
3
  license: apache-2.0
 
 
 
4
  ---
5
  # Geneformer
6
  Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
7
 
8
  - See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
9
+ - See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
10
  - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
11
 
12
  # Model Description
13
+ Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Then, in April 2024, Geneformer was pretrained on ~95 million non-cancer transcriptomes, followed by continual learning on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.
14
 
15
+ Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
16
 
17
  The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
18
 
19
  We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
20
 
21
+ During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
22
 
23
  The repository includes the following pretrained models:
24
 
25
+ L=layers\
26
+ M=millions of cells used for pretraining\
27
+ i=input size\
28
+ (pretraining date)
29
 
30
+ - GF-6L-30M-i2048 (June 2021)
31
+ - GF-12L-30M-i2048 (June 2021)
32
+ - GF-12L-95M-i4096 (April 2024)
33
+ - GF-20L-95M-i4096 (April 2024)
34
 
35
+ The current default model in the main directory of the repository is GF-12L-95M-i4096.
36
+
37
+ The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer.
38
 
39
  # Application
40
  The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
 
82
  - extracting and plotting cell embeddings
83
  - in silico perturbation
84
 
 
 
 
 
85
  Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
86
 
87
+ Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
88
 
89
  # Citations
90
  - C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
config.json CHANGED
@@ -2,22 +2,22 @@
2
  "architectures": [
3
  "BertForMaskedLM"
4
  ],
5
- "attention_probs_dropout_prob": 0.1,
6
  "classifier_dropout": null,
7
  "hidden_act": "relu",
8
- "hidden_dropout_prob": 0.1,
9
- "hidden_size": 1152,
10
  "initializer_range": 0.02,
11
- "intermediate_size": 4608,
12
  "layer_norm_eps": 1e-12,
13
  "max_position_embeddings": 4096,
14
  "model_type": "bert",
15
- "num_attention_heads": 18,
16
- "num_hidden_layers": 18,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
  "torch_dtype": "float32",
20
- "transformers_version": "4.44.2",
21
  "type_vocab_size": 2,
22
  "use_cache": true,
23
  "vocab_size": 20275
 
2
  "architectures": [
3
  "BertForMaskedLM"
4
  ],
5
+ "attention_probs_dropout_prob": 0.02,
6
  "classifier_dropout": null,
7
  "hidden_act": "relu",
8
+ "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
  "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
  "layer_norm_eps": 1e-12,
13
  "max_position_embeddings": 4096,
14
  "model_type": "bert",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 12,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
  "torch_dtype": "float32",
20
+ "transformers_version": "4.37.1",
21
  "type_vocab_size": 2,
22
  "use_cache": true,
23
  "vocab_size": 20275
docs/source/geneformer.in_silico_perturber.rst CHANGED
@@ -5,4 +5,4 @@ geneformer.in\_silico\_perturber
5
  :members:
6
  :undoc-members:
7
  :show-inheritance:
8
- :exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, , isp_perturb_all_special, isp_perturb_set_special, update_perturbation_dictionary
 
5
  :members:
6
  :undoc-members:
7
  :show-inheritance:
8
+ :exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, update_perturbation_dictionary
examples/cell_classification.ipynb CHANGED
@@ -13,7 +13,7 @@
13
  "id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
14
  "metadata": {},
15
  "source": [
16
- "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters."
17
  ]
18
  },
19
  {
@@ -68,8 +68,6 @@
68
  " \"per_device_train_batch_size\": 12,\n",
69
  " \"seed\": 73,\n",
70
  "}\n",
71
- "\n",
72
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
73
  "cc = Classifier(classifier=\"cell\",\n",
74
  " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n",
75
  " filter_data=filter_data_dict,\n",
@@ -78,7 +76,6 @@
78
  " freeze_layers = 2,\n",
79
  " num_crossval_splits = 1,\n",
80
  " forward_batch_size=200,\n",
81
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
82
  " nproc=16)"
83
  ]
84
  },
@@ -128,7 +125,7 @@
128
  " \"train\": train_ids+eval_ids,\n",
129
  " \"test\": test_ids}\n",
130
  "\n",
131
- "# Example input_data_file for 30M model: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
132
  "cc.prepare_data(input_data_file=\"/path/to/human_dcm_hcm_nf_2048_w_length.dataset\",\n",
133
  " output_directory=output_dir,\n",
134
  " output_prefix=output_prefix,\n",
@@ -263,8 +260,8 @@
263
  " \"train\": train_ids,\n",
264
  " \"eval\": eval_ids}\n",
265
  "\n",
266
- "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
267
- "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
268
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n",
269
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
270
  " output_directory=output_dir,\n",
@@ -449,7 +446,7 @@
449
  "name": "python",
450
  "nbconvert_exporter": "python",
451
  "pygments_lexer": "ipython3",
452
- "version": "3.10.13"
453
  }
454
  },
455
  "nbformat": 4,
 
13
  "id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
14
  "metadata": {},
15
  "source": [
16
+ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization."
17
  ]
18
  },
19
  {
 
68
  " \"per_device_train_batch_size\": 12,\n",
69
  " \"seed\": 73,\n",
70
  "}\n",
 
 
71
  "cc = Classifier(classifier=\"cell\",\n",
72
  " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n",
73
  " filter_data=filter_data_dict,\n",
 
76
  " freeze_layers = 2,\n",
77
  " num_crossval_splits = 1,\n",
78
  " forward_batch_size=200,\n",
 
79
  " nproc=16)"
80
  ]
81
  },
 
125
  " \"train\": train_ids+eval_ids,\n",
126
  " \"test\": test_ids}\n",
127
  "\n",
128
+ "# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
129
  "cc.prepare_data(input_data_file=\"/path/to/human_dcm_hcm_nf_2048_w_length.dataset\",\n",
130
  " output_directory=output_dir,\n",
131
  " output_prefix=output_prefix,\n",
 
260
  " \"train\": train_ids,\n",
261
  " \"eval\": eval_ids}\n",
262
  "\n",
263
+ "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
264
+ "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
265
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n",
266
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
267
  " output_directory=output_dir,\n",
 
446
  "name": "python",
447
  "nbconvert_exporter": "python",
448
  "pygments_lexer": "ipython3",
449
+ "version": "3.11.5"
450
  }
451
  },
452
  "nbformat": 4,
examples/distributed_multitask_cell_classification.ipynb DELETED
@@ -1,149 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "id": "b3266a7b",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "import os\n",
11
- "import torch\n",
12
- "from geneformer import MTLClassifier"
13
- ]
14
- },
15
- {
16
- "cell_type": "code",
17
- "execution_count": null,
18
- "id": "3e12ac9f",
19
- "metadata": {},
20
- "outputs": [],
21
- "source": [
22
- "# Define paths\n",
23
- "pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
24
- "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
25
- "train_path = \"/path/to/train/data.dataset\"\n",
26
- "val_path = \"/path/to/val/data.dataset\"\n",
27
- "test_path = \"/path/to/test/data.dataset\"\n",
28
- "results_dir = \"/path/to/results/directory\"\n",
29
- "model_save_path = \"/path/to/model/save/path\"\n",
30
- "tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
31
- "\n",
32
- "# Define tasks and hyperparameters\n",
33
- "# task_columns should be a list of column names from your dataset\n",
34
- "# Each column represents a specific classification task (e.g. cell type, disease state)\n",
35
- "task_columns = [\"cell_type\", \"disease_state\"] # Example task columns"
36
- ]
37
- },
38
- {
39
- "cell_type": "code",
40
- "execution_count": null,
41
- "id": "c9bd7562",
42
- "metadata": {},
43
- "outputs": [],
44
- "source": [
45
- "# Check GPU environment\n",
46
- "num_gpus = torch.cuda.device_count()\n",
47
- "use_distributed = num_gpus > 1\n",
48
- "print(f\"Number of GPUs detected: {num_gpus}\")\n",
49
- "print(f\"Using distributed training: {use_distributed}\")\n",
50
- "\n",
51
- "# Set environment variables for distributed training when multiple GPUs are available\n",
52
- "if use_distributed:\n",
53
- " os.environ[\"MASTER_ADDR\"] = \"localhost\" # hostname\n",
54
- " os.environ[\"MASTER_PORT\"] = \"12355\" # Choose an available port\n",
55
- " print(\"Distributed environment variables set.\")"
56
- ]
57
- },
58
- {
59
- "cell_type": "code",
60
- "execution_count": null,
61
- "id": "b6ff3618",
62
- "metadata": {},
63
- "outputs": [],
64
- "source": [
65
- "#Define Hyperparameters for Optimization\n",
66
- "hyperparameters = {\n",
67
- " \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
68
- " \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
69
- " \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
70
- " \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
71
- " \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
72
- " \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0},\n",
73
- "}"
74
- ]
75
- },
76
- {
77
- "cell_type": "code",
78
- "execution_count": null,
79
- "id": "f665c5a7",
80
- "metadata": {},
81
- "outputs": [],
82
- "source": [
83
- "mc = MTLClassifier(\n",
84
- " task_columns=task_columns, # Our defined classification tasks\n",
85
- " study_name=\"MTLClassifier_distributed\",\n",
86
- " pretrained_path=pretrained_path,\n",
87
- " train_path=train_path,\n",
88
- " val_path=val_path,\n",
89
- " test_path=test_path,\n",
90
- " model_save_path=model_save_path,\n",
91
- " results_dir=results_dir,\n",
92
- " tensorboard_log_dir=tensorboard_log_dir,\n",
93
- " hyperparameters=hyperparameters,\n",
94
- " # Distributed training parameters\n",
95
- " distributed_training=use_distributed, # Enable distributed training if multiple GPUs available\n",
96
- " master_addr=\"localhost\" if use_distributed else None,\n",
97
- " master_port=\"12355\" if use_distributed else None,\n",
98
- " # Other training parameters\n",
99
- " n_trials=15, # Number of trials for hyperparameter optimization\n",
100
- " epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
101
- " batch_size=8, # Adjust based on available GPU memory\n",
102
- " gradient_accumulation_steps=4, # Accumulate gradients over multiple steps\n",
103
- " gradient_clipping=True, # Enable gradient clipping for stability\n",
104
- " max_grad_norm=1.0, # Set maximum gradient norm\n",
105
- " seed=42\n",
106
- ")"
107
- ]
108
- },
109
- {
110
- "cell_type": "code",
111
- "execution_count": null,
112
- "id": "f69f7b6a",
113
- "metadata": {},
114
- "outputs": [],
115
- "source": [
116
- "# Run Hyperparameter Optimization with Distributed Training\n",
117
- "if __name__ == \"__main__\":\n",
118
- " # This guard is required for distributed training to prevent\n",
119
- " # infinite subprocess spawning when using torch.multiprocessing\n",
120
- " mc.run_optuna_study()"
121
- ]
122
- },
123
- {
124
- "cell_type": "code",
125
- "execution_count": null,
126
- "id": "3affd5dd",
127
- "metadata": {},
128
- "outputs": [],
129
- "source": [
130
- "# Evaluate the Model on Test Data\n",
131
- "if __name__ == \"__main__\":\n",
132
- " mc.load_and_evaluate_test_model()"
133
- ]
134
- }
135
- ],
136
- "metadata": {
137
- "kernelspec": {
138
- "display_name": "bio",
139
- "language": "python",
140
- "name": "python3"
141
- },
142
- "language_info": {
143
- "name": "python",
144
- "version": "3.12.8"
145
- }
146
- },
147
- "nbformat": 4,
148
- "nbformat_minor": 5
149
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/extract_and_plot_cell_embeddings.ipynb CHANGED
@@ -18,7 +18,6 @@
18
  "outputs": [],
19
  "source": [
20
  "# initiate EmbExtractor\n",
21
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
22
  "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
23
  " num_classes=3,\n",
24
  " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
@@ -27,13 +26,12 @@
27
  " emb_label=[\"disease\",\"cell_type\"],\n",
28
  " labels_to_plot=[\"disease\"],\n",
29
  " forward_batch_size=200,\n",
30
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
31
  " nproc=16)\n",
32
  "\n",
33
  "# extracts embedding from input data\n",
34
  "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
35
- "# example dataset for V1 model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
36
- "embs = embex.extract_embs(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n",
37
  " \"path/to/input_data/\",\n",
38
  " \"path/to/output_directory/\",\n",
39
  " \"output_prefix\")\n"
@@ -131,7 +129,7 @@
131
  "name": "python",
132
  "nbconvert_exporter": "python",
133
  "pygments_lexer": "ipython3",
134
- "version": "3.10.13"
135
  }
136
  },
137
  "nbformat": 4,
 
18
  "outputs": [],
19
  "source": [
20
  "# initiate EmbExtractor\n",
 
21
  "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
22
  " num_classes=3,\n",
23
  " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
 
26
  " emb_label=[\"disease\",\"cell_type\"],\n",
27
  " labels_to_plot=[\"disease\"],\n",
28
  " forward_batch_size=200,\n",
 
29
  " nproc=16)\n",
30
  "\n",
31
  "# extracts embedding from input data\n",
32
  "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
33
+ "# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
34
+ "embs = embex.extract_embs(\"../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224\",\n",
35
  " \"path/to/input_data/\",\n",
36
  " \"path/to/output_directory/\",\n",
37
  " \"output_prefix\")\n"
 
129
  "name": "python",
130
  "nbconvert_exporter": "python",
131
  "pygments_lexer": "ipython3",
132
+ "version": "3.11.5"
133
  }
134
  },
135
  "nbformat": 4,
examples/gene_classification.ipynb CHANGED
@@ -13,7 +13,7 @@
13
  "id": "79539e95-2c9c-4162-835c-f0d158abb15d",
14
  "metadata": {},
15
  "source": [
16
- "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters."
17
  ]
18
  },
19
  {
@@ -71,14 +71,12 @@
71
  }
72
  ],
73
  "source": [
74
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
75
  "cc = Classifier(classifier=\"gene\",\n",
76
  " gene_class_dict = gene_class_dict,\n",
77
  " max_ncells = 10_000,\n",
78
  " freeze_layers = 4,\n",
79
  " num_crossval_splits = 5,\n",
80
  " forward_batch_size=200,\n",
81
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
82
  " nproc=16)"
83
  ]
84
  },
@@ -104,7 +102,7 @@
104
  }
105
  ],
106
  "source": [
107
- "# Example input_data_file for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification/dosage_sensitive_tfs/gc-30M_sample50k.dataset\n",
108
  "cc.prepare_data(input_data_file=\"/path/to/gc-30M_sample50k.dataset\",\n",
109
  " output_directory=output_dir,\n",
110
  " output_prefix=output_prefix)"
@@ -842,8 +840,8 @@
842
  }
843
  ],
844
  "source": [
845
- "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
846
- "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
847
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
848
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
849
  " output_directory=output_dir,\n",
@@ -1065,14 +1063,12 @@
1065
  }
1066
  ],
1067
  "source": [
1068
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
1069
  "cc = Classifier(classifier=\"gene\",\n",
1070
  " gene_class_dict = gene_class_dict,\n",
1071
  " max_ncells = 10_000,\n",
1072
  " freeze_layers = 4,\n",
1073
  " num_crossval_splits = 0,\n",
1074
  " forward_batch_size=200,\n",
1075
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
1076
  " nproc=16)"
1077
  ]
1078
  },
@@ -1219,8 +1215,8 @@
1219
  }
1220
  ],
1221
  "source": [
1222
- "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
1223
- "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
1224
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
1225
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
1226
  " output_directory=output_dir,\n",
@@ -1244,7 +1240,7 @@
1244
  "name": "python",
1245
  "nbconvert_exporter": "python",
1246
  "pygments_lexer": "ipython3",
1247
- "version": "3.10.13"
1248
  }
1249
  },
1250
  "nbformat": 4,
 
13
  "id": "79539e95-2c9c-4162-835c-f0d158abb15d",
14
  "metadata": {},
15
  "source": [
16
+ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications."
17
  ]
18
  },
19
  {
 
71
  }
72
  ],
73
  "source": [
 
74
  "cc = Classifier(classifier=\"gene\",\n",
75
  " gene_class_dict = gene_class_dict,\n",
76
  " max_ncells = 10_000,\n",
77
  " freeze_layers = 4,\n",
78
  " num_crossval_splits = 5,\n",
79
  " forward_batch_size=200,\n",
 
80
  " nproc=16)"
81
  ]
82
  },
 
102
  }
103
  ],
104
  "source": [
105
+ "# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification/dosage_sensitive_tfs/gc-30M_sample50k.dataset\n",
106
  "cc.prepare_data(input_data_file=\"/path/to/gc-30M_sample50k.dataset\",\n",
107
  " output_directory=output_dir,\n",
108
  " output_prefix=output_prefix)"
 
840
  }
841
  ],
842
  "source": [
843
+ "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
844
+ "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
845
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
846
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
847
  " output_directory=output_dir,\n",
 
1063
  }
1064
  ],
1065
  "source": [
 
1066
  "cc = Classifier(classifier=\"gene\",\n",
1067
  " gene_class_dict = gene_class_dict,\n",
1068
  " max_ncells = 10_000,\n",
1069
  " freeze_layers = 4,\n",
1070
  " num_crossval_splits = 0,\n",
1071
  " forward_batch_size=200,\n",
 
1072
  " nproc=16)"
1073
  ]
1074
  },
 
1215
  }
1216
  ],
1217
  "source": [
1218
+ "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
1219
+ "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\",\n",
1220
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
1221
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
1222
  " output_directory=output_dir,\n",
 
1240
  "name": "python",
1241
  "nbconvert_exporter": "python",
1242
  "pygments_lexer": "ipython3",
1243
+ "version": "3.11.5"
1244
  }
1245
  },
1246
  "nbformat": 4,
examples/in_silico_perturbation.ipynb CHANGED
@@ -39,19 +39,17 @@
39
  "\n",
40
  "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
41
  "\n",
42
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
43
- "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
44
  " num_classes=3,\n",
45
  " filter_data=filter_data_dict,\n",
46
  " max_ncells=1000,\n",
47
  " emb_layer=0,\n",
48
  " summary_stat=\"exact_mean\",\n",
49
  " forward_batch_size=256,\n",
50
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
51
  " nproc=16)\n",
52
  "\n",
53
  "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
54
- " \"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
55
  " \"path/to/input_data\",\n",
56
  " \"path/to/output_directory\",\n",
57
  " \"output_prefix\")"
@@ -66,15 +64,14 @@
66
  },
67
  "outputs": [],
68
  "source": [
69
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
70
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
71
  " perturb_rank_shift=None,\n",
72
  " genes_to_perturb=\"all\",\n",
73
  " combos=0,\n",
74
  " anchor_gene=None,\n",
75
- " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
76
  " num_classes=3,\n",
77
- " emb_mode=\"cell\", # OF NOTE: SET TO \"CELL\" FOR V1 MODEL. FOR V2, SHOULD BE \"CLS\" (current default).\n",
78
  " cell_emb_style=\"mean_pool\",\n",
79
  " filter_data=filter_data_dict,\n",
80
  " cell_states_to_model=cell_states_to_model,\n",
@@ -82,7 +79,6 @@
82
  " max_ncells=2000,\n",
83
  " emb_layer=0,\n",
84
  " forward_batch_size=400,\n",
85
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
86
  " nproc=16)"
87
  ]
88
  },
@@ -94,10 +90,9 @@
94
  "outputs": [],
95
  "source": [
96
  "# outputs intermediate files from in silico perturbation\n",
97
- "\n",
98
- "isp.perturb_data(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n",
99
  " \"path/to/input_data\",\n",
100
- " \"path/to/isp_output_directory\",\n",
101
  " \"output_prefix\")"
102
  ]
103
  },
@@ -108,13 +103,11 @@
108
  "metadata": {},
109
  "outputs": [],
110
  "source": [
111
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
112
  "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
113
  " genes_perturbed=\"all\",\n",
114
  " combos=0,\n",
115
  " anchor_gene=None,\n",
116
- " cell_states_to_model=cell_states_to_model,\n",
117
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL SINCE V1 WAS USED FOR IN SILICO PERTURBATION ABOVE)"
118
  ]
119
  },
120
  {
@@ -125,9 +118,9 @@
125
  "outputs": [],
126
  "source": [
127
  "# extracts data from intermediate files and processes stats to output in final .csv\n",
128
- "ispstats.get_stats(\"path/to/isp_output_directory\", # this should be the directory \n",
129
  " None,\n",
130
- " \"path/to/isp_stats_output_directory\",\n",
131
  " \"output_prefix\")"
132
  ]
133
  }
@@ -148,7 +141,7 @@
148
  "name": "python",
149
  "nbconvert_exporter": "python",
150
  "pygments_lexer": "ipython3",
151
- "version": "3.10.13"
152
  }
153
  },
154
  "nbformat": 4,
 
39
  "\n",
40
  "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
41
  "\n",
42
+ "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
 
43
  " num_classes=3,\n",
44
  " filter_data=filter_data_dict,\n",
45
  " max_ncells=1000,\n",
46
  " emb_layer=0,\n",
47
  " summary_stat=\"exact_mean\",\n",
48
  " forward_batch_size=256,\n",
 
49
  " nproc=16)\n",
50
  "\n",
51
  "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
52
+ " \"path/to/model\",\n",
53
  " \"path/to/input_data\",\n",
54
  " \"path/to/output_directory\",\n",
55
  " \"output_prefix\")"
 
64
  },
65
  "outputs": [],
66
  "source": [
 
67
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
68
  " perturb_rank_shift=None,\n",
69
  " genes_to_perturb=\"all\",\n",
70
  " combos=0,\n",
71
  " anchor_gene=None,\n",
72
+ " model_type=\"CellClassifier\",\n",
73
  " num_classes=3,\n",
74
+ " emb_mode=\"cell\",\n",
75
  " cell_emb_style=\"mean_pool\",\n",
76
  " filter_data=filter_data_dict,\n",
77
  " cell_states_to_model=cell_states_to_model,\n",
 
79
  " max_ncells=2000,\n",
80
  " emb_layer=0,\n",
81
  " forward_batch_size=400,\n",
 
82
  " nproc=16)"
83
  ]
84
  },
 
90
  "outputs": [],
91
  "source": [
92
  "# outputs intermediate files from in silico perturbation\n",
93
+ "isp.perturb_data(\"path/to/model\",\n",
 
94
  " \"path/to/input_data\",\n",
95
+ " \"path/to/output_directory\",\n",
96
  " \"output_prefix\")"
97
  ]
98
  },
 
103
  "metadata": {},
104
  "outputs": [],
105
  "source": [
 
106
  "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
107
  " genes_perturbed=\"all\",\n",
108
  " combos=0,\n",
109
  " anchor_gene=None,\n",
110
+ " cell_states_to_model=cell_states_to_model)"
 
111
  ]
112
  },
113
  {
 
118
  "outputs": [],
119
  "source": [
120
  "# extracts data from intermediate files and processes stats to output in final .csv\n",
121
+ "ispstats.get_stats(\"path/to/input_data\",\n",
122
  " None,\n",
123
+ " \"path/to/output_directory\",\n",
124
  " \"output_prefix\")"
125
  ]
126
  }
 
141
  "name": "python",
142
  "nbconvert_exporter": "python",
143
  "pygments_lexer": "ipython3",
144
+ "version": "3.10.11"
145
  }
146
  },
147
  "nbformat": 4,
examples/multitask_cell_classification.ipynb CHANGED
@@ -286,7 +286,7 @@
286
  " filter_data_dict=filter_data_dict,\n",
287
  " max_ncells=1000, # Number of cells to extract embeddings for\n",
288
  " emb_layer=0, # Use the second to last layer\n",
289
- " emb_mode = \"cls\", # Use CLS token embedding for V2 model\n",
290
  " summary_stat=\"exact_mean\",\n",
291
  " forward_batch_size=8, # Adjust based on available GPU memory\n",
292
  " nproc=4\n",
@@ -324,7 +324,7 @@
324
  " perturb_type=perturb_type,\n",
325
  " genes_to_perturb=\"all\", # Perturb all genes\n",
326
  " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
327
- " emb_mode=\"cls\", # Use CLS token embedding for V2 model\n",
328
  " cell_states_to_model=cell_states_to_model,\n",
329
  " state_embs_dict=state_embs_dict,\n",
330
  " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
@@ -412,7 +412,7 @@
412
  "name": "python",
413
  "nbconvert_exporter": "python",
414
  "pygments_lexer": "ipython3",
415
- "version": "3.10.13"
416
  }
417
  },
418
  "nbformat": 4,
 
286
  " filter_data_dict=filter_data_dict,\n",
287
  " max_ncells=1000, # Number of cells to extract embeddings for\n",
288
  " emb_layer=0, # Use the second to last layer\n",
289
+ " emb_mode = \"cls\",\n",
290
  " summary_stat=\"exact_mean\",\n",
291
  " forward_batch_size=8, # Adjust based on available GPU memory\n",
292
  " nproc=4\n",
 
324
  " perturb_type=perturb_type,\n",
325
  " genes_to_perturb=\"all\", # Perturb all genes\n",
326
  " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
327
+ " emb_mode=\"cls\", # Use CLS token embedding\n",
328
  " cell_states_to_model=cell_states_to_model,\n",
329
  " state_embs_dict=state_embs_dict,\n",
330
  " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
 
412
  "name": "python",
413
  "nbconvert_exporter": "python",
414
  "pygments_lexer": "ipython3",
415
+ "version": "3.11.5"
416
  }
417
  },
418
  "nbformat": 4,
examples/tokenizing_scRNAseq_data.ipynb CHANGED
@@ -12,7 +12,7 @@
12
  },
13
  {
14
  "cell_type": "markdown",
15
- "id": "1fe86f48-5578-47df-b373-58c21ec170ab",
16
  "metadata": {},
17
  "source": [
18
  "#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n",
@@ -25,17 +25,11 @@
25
  "\n",
26
  "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
27
  "\n",
28
- "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer."
29
- ]
30
- },
31
- {
32
- "cell_type": "markdown",
33
- "id": "32c69493-4e5a-4b07-8dc1-958ff2ee7d0b",
34
- "metadata": {},
35
- "source": [
36
- "**********************************************************************************************************\n",
37
- "#### OF NOTE: Please ensure the correct token dictionary, gene median file, special token setting, and model input size is used for the correct model version.\n",
38
- "#### Current defaults are for V2 model series. To auto-select the correct settings for V1, set model_version argument to \"V1\"."
39
  ]
40
  },
41
  {
@@ -55,7 +49,7 @@
55
  "metadata": {},
56
  "outputs": [],
57
  "source": [
58
- "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16) # for V1 model, set model_version=\"V1\"\n",
59
  "tk.tokenize_data(\"loom_data_directory\", \n",
60
  " \"output_directory\", \n",
61
  " \"output_prefix\", \n",
@@ -79,7 +73,7 @@
79
  "name": "python",
80
  "nbconvert_exporter": "python",
81
  "pygments_lexer": "ipython3",
82
- "version": "3.10.13"
83
  }
84
  },
85
  "nbformat": 4,
 
12
  },
13
  {
14
  "cell_type": "markdown",
15
+ "id": "350e6252-b783-494b-9767-f087eb868a15",
16
  "metadata": {},
17
  "source": [
18
  "#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n",
 
25
  "\n",
26
  "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
27
  "\n",
28
+ "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.\n",
29
+ "\n",
30
+ "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
31
+ "\n",
32
+ "#### The 95M model series also require the special_token argument to be set to True and model_input_size to be 4096."
 
 
 
 
 
 
33
  ]
34
  },
35
  {
 
49
  "metadata": {},
50
  "outputs": [],
51
  "source": [
52
+ "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n",
53
  "tk.tokenize_data(\"loom_data_directory\", \n",
54
  " \"output_directory\", \n",
55
  " \"output_prefix\", \n",
 
73
  "name": "python",
74
  "nbconvert_exporter": "python",
75
  "pygments_lexer": "ipython3",
76
+ "version": "3.10.11"
77
  }
78
  },
79
  "nbformat": 4,
{Geneformer-V2-104M → fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522}/config.json RENAMED
@@ -2,22 +2,22 @@
2
  "architectures": [
3
  "BertForMaskedLM"
4
  ],
5
- "attention_probs_dropout_prob": 0.1,
6
  "classifier_dropout": null,
7
  "hidden_act": "relu",
8
- "hidden_dropout_prob": 0.1,
9
- "hidden_size": 768,
10
  "initializer_range": 0.02,
11
- "intermediate_size": 3072,
12
  "layer_norm_eps": 1e-12,
13
  "max_position_embeddings": 4096,
14
  "model_type": "bert",
15
- "num_attention_heads": 12,
16
  "num_hidden_layers": 12,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
  "torch_dtype": "float32",
20
- "transformers_version": "4.44.2",
21
  "type_vocab_size": 2,
22
  "use_cache": true,
23
  "vocab_size": 20275
 
2
  "architectures": [
3
  "BertForMaskedLM"
4
  ],
5
+ "attention_probs_dropout_prob": 0.02,
6
  "classifier_dropout": null,
7
  "hidden_act": "relu",
8
+ "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
  "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
  "layer_norm_eps": 1e-12,
13
  "max_position_embeddings": 4096,
14
  "model_type": "bert",
15
+ "num_attention_heads": 8,
16
  "num_hidden_layers": 12,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
  "torch_dtype": "float32",
20
+ "transformers_version": "4.37.2",
21
  "type_vocab_size": 2,
22
  "use_cache": true,
23
  "vocab_size": 20275
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4
3
+ size 152363342
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/config.json RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/optimizer.pt RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/rng_state.pth RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/scheduler.pt RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/trainer_state.json RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/training_args.bin RENAMED
File without changes
geneformer/__init__.py CHANGED
@@ -4,15 +4,10 @@ from pathlib import Path
4
 
5
  warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
6
 
7
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc104M.pkl"
8
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc104M.pkl"
9
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc104M.pkl"
10
- ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc104M.pkl"
11
-
12
- GENE_MEDIAN_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl"
13
- TOKEN_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/token_dictionary_gc30M.pkl"
14
- ENSEMBL_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl"
15
- ENSEMBL_MAPPING_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl"
16
 
17
  from . import (
18
  collator_for_classification,
 
4
 
5
  warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
6
 
7
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
8
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
9
+ ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
10
+ ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
 
 
 
 
 
11
 
12
  from . import (
13
  collator_for_classification,
geneformer/classifier.py CHANGED
@@ -48,13 +48,11 @@ import logging
48
  import os
49
  import pickle
50
  import subprocess
51
- from packaging.version import parse
52
  from pathlib import Path
53
 
54
  import numpy as np
55
  import pandas as pd
56
  import seaborn as sns
57
- import transformers
58
  from tqdm.auto import tqdm, trange
59
  from transformers import Trainer
60
  from transformers.training_args import TrainingArguments
@@ -73,7 +71,6 @@ sns.set()
73
 
74
  logger = logging.getLogger(__name__)
75
 
76
- transformers_version = parse(transformers.__version__)
77
 
78
  class Classifier:
79
  valid_option_dict = {
@@ -92,7 +89,6 @@ class Classifier:
92
  "no_eval": {bool},
93
  "stratify_splits_col": {None, str},
94
  "forward_batch_size": {int},
95
- "model_version": {"V1", "V2"},
96
  "token_dictionary_file": {None, str},
97
  "nproc": {int},
98
  "ngpu": {int},
@@ -116,7 +112,6 @@ class Classifier:
116
  stratify_splits_col=None,
117
  no_eval=False,
118
  forward_batch_size=100,
119
- model_version="V2",
120
  token_dictionary_file=None,
121
  nproc=4,
122
  ngpu=1,
@@ -193,9 +188,6 @@ class Classifier:
193
  | Otherwise, will perform eval during training.
194
  forward_batch_size : int
195
  | Batch size for forward pass (for evaluation, not training).
196
- model_version : str
197
- | To auto-select settings for model version other than current default.
198
- | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
199
  token_dictionary_file : None, str
200
  | Default is to use token dictionary file from Geneformer
201
  | Otherwise, will load custom gene token dictionary.
@@ -230,16 +222,14 @@ class Classifier:
230
  self.stratify_splits_col = stratify_splits_col
231
  self.no_eval = no_eval
232
  self.forward_batch_size = forward_batch_size
233
- self.model_version = model_version
234
  self.token_dictionary_file = token_dictionary_file
235
  self.nproc = nproc
236
  self.ngpu = ngpu
237
-
238
  if self.training_args is None:
239
  logger.warning(
240
  "Hyperparameter tuning is highly recommended for optimal results. "
241
- "No training_args provided; using default hyperparameters. "
242
- "Please note: these defaults are not recommended to be used uniformly across tasks."
243
  )
244
 
245
  self.validate_options()
@@ -254,10 +244,7 @@ class Classifier:
254
  ] = self.cell_state_dict["states"]
255
 
256
  # load token dictionary (Ensembl IDs:token)
257
- if self.model_version == "V1":
258
- from . import TOKEN_DICTIONARY_FILE_30M
259
- self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
260
- elif self.token_dictionary_file is None:
261
  self.token_dictionary_file = TOKEN_DICTIONARY_FILE
262
  with open(self.token_dictionary_file, "rb") as f:
263
  self.gene_token_dict = pickle.load(f)
@@ -367,7 +354,6 @@ class Classifier:
367
  attr_to_balance=None,
368
  max_trials=100,
369
  pval_threshold=0.1,
370
- id_class_dict_path=None,
371
  ):
372
  """
373
  Prepare data for cell state or gene classification.
@@ -410,10 +396,6 @@ class Classifier:
410
  pval_threshold : None, float
411
  | P-value threshold to use for attribute balancing across splits
412
  | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
413
- id_class_dict_path : Path
414
- | Path to *_id_class_dict.pkl from prior run of prepare_data to reuse for labeling new data
415
- | Dictionary with keys being numeric class labels and values being original dataset class labels
416
- | Note: only available for CellClassifiers
417
  """
418
 
419
  if test_size is None:
@@ -455,25 +437,14 @@ class Classifier:
455
  )
456
  # rename cell state column to "label"
457
  data = cu.rename_cols(data, self.cell_state_dict["state_key"])
458
-
459
- # convert classes to numerical labels and save as id_class_dict
460
- if id_class_dict_path is not None:
461
- with open(id_class_dict_path,"rb") as fp:
462
- id_class_dict = pickle.load(fp)
463
- else:
464
- id_class_dict = None
465
- data, id_class_dict = cu.label_classes(
466
- self.classifier, data, self.cell_state_dict, self.nproc, id_class_dict,
467
- )
468
 
469
- elif self.classifier == "gene":
470
  # convert classes to numerical labels and save as id_class_dict
471
  # of note, will label all genes in gene_class_dict
472
  # if (cross-)validating, genes will be relabeled in column "labels" for each split
473
  # at the time of training with Classifier.validate
474
- data, id_class_dict = cu.label_classes(
475
- self.classifier, data, self.gene_class_dict, self.nproc
476
- )
477
 
478
  # save id_class_dict for future reference
479
  id_class_output_path = (
@@ -801,7 +772,7 @@ class Classifier:
801
  # 5-fold cross-validate
802
  num_cells = len(data)
803
  fifth_cells = int(np.floor(num_cells * 0.2))
804
- num_eval = int(min((self.eval_size * num_cells), fifth_cells))
805
  start = i * fifth_cells
806
  end = start + num_eval
807
  eval_indices = [j for j in range(start, end)]
@@ -1085,8 +1056,6 @@ class Classifier:
1085
  if eval_data is None:
1086
  def_training_args["evaluation_strategy"] = "no"
1087
  def_training_args["load_best_model_at_end"] = False
1088
- if transformers_version >= parse("4.46"):
1089
- def_training_args["eval_strategy"] = def_training_args.pop("evaluation_strategy")
1090
  def_training_args.update(
1091
  {"save_strategy": "epoch", "save_total_limit": 1}
1092
  ) # only save last model for each run
@@ -1258,8 +1227,6 @@ class Classifier:
1258
  if eval_data is None:
1259
  def_training_args["evaluation_strategy"] = "no"
1260
  def_training_args["load_best_model_at_end"] = False
1261
- if transformers_version >= parse("4.46"):
1262
- def_training_args["eval_strategy"] = def_training_args.pop("evaluation_strategy")
1263
  training_args_init = TrainingArguments(**def_training_args)
1264
 
1265
  if self.freeze_layers is not None:
@@ -1313,7 +1280,6 @@ class Classifier:
1313
  predict=False,
1314
  output_directory=None,
1315
  output_prefix=None,
1316
- predict_metadata=None,
1317
  ):
1318
  """
1319
  Evaluate the fine-tuned model.
@@ -1339,11 +1305,9 @@ class Classifier:
1339
 
1340
  ##### Evaluate the model #####
1341
  labels = id_class_dict.keys()
1342
-
1343
- y_pred, y_true, logits_list, predict_metadata_all = eu.classifier_predict(
1344
- model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict, predict_metadata
1345
  )
1346
-
1347
  conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
1348
  y_pred, y_true, logits_list, num_classes, labels
1349
  )
@@ -1353,9 +1317,6 @@ class Classifier:
1353
  "label_ids": y_true,
1354
  "predictions": logits_list,
1355
  }
1356
- if predict_metadata is not None:
1357
- pred_dict["prediction_metadata"] = predict_metadata_all
1358
-
1359
  pred_dict_output_path = (
1360
  Path(output_directory) / f"{output_prefix}_pred_dict"
1361
  ).with_suffix(".pkl")
@@ -1376,7 +1337,6 @@ class Classifier:
1376
  output_directory,
1377
  output_prefix,
1378
  predict=True,
1379
- predict_metadata=None,
1380
  ):
1381
  """
1382
  Evaluate the fine-tuned model.
@@ -1396,8 +1356,6 @@ class Classifier:
1396
  | Prefix for output files
1397
  predict : bool
1398
  | Whether or not to save eval predictions
1399
- predict_metadata : None | list
1400
- | Metadata labels to output with predictions (columns in test_data_file)
1401
  """
1402
 
1403
  # load numerical id to class dictionary (id:class)
@@ -1410,15 +1368,6 @@ class Classifier:
1410
  # load previously filtered and prepared data
1411
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1412
 
1413
- if predict_metadata is not None:
1414
- absent_metadata = []
1415
- for predict_metadata_x in predict_metadata:
1416
- if predict_metadata_x not in test_data.features.keys():
1417
- absent_metadata += [predict_metadata_x]
1418
- if len(absent_metadata)>0:
1419
- logger.error(f"Following predict_metadata was not found as column in test_data_file: {absent_metadata}")
1420
- raise
1421
-
1422
  # load previously fine-tuned model
1423
  model = pu.load_model(
1424
  self.model_type,
@@ -1437,7 +1386,6 @@ class Classifier:
1437
  predict=predict,
1438
  output_directory=output_directory,
1439
  output_prefix=output_prefix,
1440
- predict_metadata=predict_metadata,
1441
  )
1442
 
1443
  all_conf_mat_df = pd.DataFrame(
 
48
  import os
49
  import pickle
50
  import subprocess
 
51
  from pathlib import Path
52
 
53
  import numpy as np
54
  import pandas as pd
55
  import seaborn as sns
 
56
  from tqdm.auto import tqdm, trange
57
  from transformers import Trainer
58
  from transformers.training_args import TrainingArguments
 
71
 
72
  logger = logging.getLogger(__name__)
73
 
 
74
 
75
  class Classifier:
76
  valid_option_dict = {
 
89
  "no_eval": {bool},
90
  "stratify_splits_col": {None, str},
91
  "forward_batch_size": {int},
 
92
  "token_dictionary_file": {None, str},
93
  "nproc": {int},
94
  "ngpu": {int},
 
112
  stratify_splits_col=None,
113
  no_eval=False,
114
  forward_batch_size=100,
 
115
  token_dictionary_file=None,
116
  nproc=4,
117
  ngpu=1,
 
188
  | Otherwise, will perform eval during training.
189
  forward_batch_size : int
190
  | Batch size for forward pass (for evaluation, not training).
 
 
 
191
  token_dictionary_file : None, str
192
  | Default is to use token dictionary file from Geneformer
193
  | Otherwise, will load custom gene token dictionary.
 
222
  self.stratify_splits_col = stratify_splits_col
223
  self.no_eval = no_eval
224
  self.forward_batch_size = forward_batch_size
 
225
  self.token_dictionary_file = token_dictionary_file
226
  self.nproc = nproc
227
  self.ngpu = ngpu
228
+
229
  if self.training_args is None:
230
  logger.warning(
231
  "Hyperparameter tuning is highly recommended for optimal results. "
232
+ "No training_args provided; using default hyperparameters."
 
233
  )
234
 
235
  self.validate_options()
 
244
  ] = self.cell_state_dict["states"]
245
 
246
  # load token dictionary (Ensembl IDs:token)
247
+ if self.token_dictionary_file is None:
 
 
 
248
  self.token_dictionary_file = TOKEN_DICTIONARY_FILE
249
  with open(self.token_dictionary_file, "rb") as f:
250
  self.gene_token_dict = pickle.load(f)
 
354
  attr_to_balance=None,
355
  max_trials=100,
356
  pval_threshold=0.1,
 
357
  ):
358
  """
359
  Prepare data for cell state or gene classification.
 
396
  pval_threshold : None, float
397
  | P-value threshold to use for attribute balancing across splits
398
  | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
 
 
 
 
399
  """
400
 
401
  if test_size is None:
 
437
  )
438
  # rename cell state column to "label"
439
  data = cu.rename_cols(data, self.cell_state_dict["state_key"])
 
 
 
 
 
 
 
 
 
 
440
 
 
441
  # convert classes to numerical labels and save as id_class_dict
442
  # of note, will label all genes in gene_class_dict
443
  # if (cross-)validating, genes will be relabeled in column "labels" for each split
444
  # at the time of training with Classifier.validate
445
+ data, id_class_dict = cu.label_classes(
446
+ self.classifier, data, self.gene_class_dict, self.nproc
447
+ )
448
 
449
  # save id_class_dict for future reference
450
  id_class_output_path = (
 
772
  # 5-fold cross-validate
773
  num_cells = len(data)
774
  fifth_cells = int(np.floor(num_cells * 0.2))
775
+ num_eval = min((self.eval_size * num_cells), fifth_cells)
776
  start = i * fifth_cells
777
  end = start + num_eval
778
  eval_indices = [j for j in range(start, end)]
 
1056
  if eval_data is None:
1057
  def_training_args["evaluation_strategy"] = "no"
1058
  def_training_args["load_best_model_at_end"] = False
 
 
1059
  def_training_args.update(
1060
  {"save_strategy": "epoch", "save_total_limit": 1}
1061
  ) # only save last model for each run
 
1227
  if eval_data is None:
1228
  def_training_args["evaluation_strategy"] = "no"
1229
  def_training_args["load_best_model_at_end"] = False
 
 
1230
  training_args_init = TrainingArguments(**def_training_args)
1231
 
1232
  if self.freeze_layers is not None:
 
1280
  predict=False,
1281
  output_directory=None,
1282
  output_prefix=None,
 
1283
  ):
1284
  """
1285
  Evaluate the fine-tuned model.
 
1305
 
1306
  ##### Evaluate the model #####
1307
  labels = id_class_dict.keys()
1308
+ y_pred, y_true, logits_list = eu.classifier_predict(
1309
+ model, self.classifier, eval_data, self.forward_batch_size
 
1310
  )
 
1311
  conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
1312
  y_pred, y_true, logits_list, num_classes, labels
1313
  )
 
1317
  "label_ids": y_true,
1318
  "predictions": logits_list,
1319
  }
 
 
 
1320
  pred_dict_output_path = (
1321
  Path(output_directory) / f"{output_prefix}_pred_dict"
1322
  ).with_suffix(".pkl")
 
1337
  output_directory,
1338
  output_prefix,
1339
  predict=True,
 
1340
  ):
1341
  """
1342
  Evaluate the fine-tuned model.
 
1356
  | Prefix for output files
1357
  predict : bool
1358
  | Whether or not to save eval predictions
 
 
1359
  """
1360
 
1361
  # load numerical id to class dictionary (id:class)
 
1368
  # load previously filtered and prepared data
1369
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1370
 
 
 
 
 
 
 
 
 
 
1371
  # load previously fine-tuned model
1372
  model = pu.load_model(
1373
  self.model_type,
 
1386
  predict=predict,
1387
  output_directory=output_directory,
1388
  output_prefix=output_prefix,
 
1389
  )
1390
 
1391
  all_conf_mat_df = pd.DataFrame(
geneformer/classifier_utils.py CHANGED
@@ -94,7 +94,7 @@ def remove_rare(data, rare_threshold, label, nproc):
94
  return data
95
 
96
 
97
- def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None):
98
  if classifier == "cell":
99
  label_set = set(data["label"])
100
  elif classifier == "gene":
@@ -113,24 +113,15 @@ def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None):
113
  )
114
  raise
115
 
116
- if id_class_dict is None:
117
- class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
118
- id_class_dict = {v: k for k, v in class_id_dict.items()}
119
- else:
120
- class_id_dict = {v: k for k, v in id_class_dict.items()}
121
-
122
- if classifier == "gene":
123
- inverse_gene_class_dict = {}
124
- for key, value_list in gene_class_dict.items():
125
- for value in value_list:
126
- inverse_gene_class_dict[value] = key
127
 
128
  def classes_to_ids(example):
129
  if classifier == "cell":
130
  example["label"] = class_id_dict[example["label"]]
131
  elif classifier == "gene":
132
  example["labels"] = label_gene_classes(
133
- example, class_id_dict, inverse_gene_class_dict
134
  )
135
  return example
136
 
@@ -138,9 +129,9 @@ def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None):
138
  return data, id_class_dict
139
 
140
 
141
- def label_gene_classes(example, class_id_dict, inverse_gene_class_dict):
142
  return [
143
- class_id_dict.get(inverse_gene_class_dict.get(token_id, -100), -100)
144
  for token_id in example["input_ids"]
145
  ]
146
 
@@ -570,27 +561,6 @@ def compute_metrics(pred):
570
  return {"accuracy": acc, "macro_f1": macro_f1}
571
 
572
 
573
- def robust_compute_objective(metrics: dict):
574
- # tries both prefixed ("eval_") and raw metric names to support different transformers versions
575
- metric_name = "macro_f1"
576
-
577
- # check for the prefixed version
578
- prefixed_metric_name = f"eval_{metric_name}"
579
- if prefixed_metric_name in metrics:
580
- return metrics[prefixed_metric_name]
581
-
582
- # fall back to the raw name
583
- elif metric_name in metrics:
584
- return metrics[metric_name]
585
-
586
- # if neither is found, raise a clear error to help with debugging
587
- raise KeyError(
588
- f"Could not find '{prefixed_metric_name}' or '{metric_name}' in the reported metrics. "
589
- f"Please check your `compute_metrics` function and `TrainingArguments`. "
590
- f"Available metrics: {list(metrics.keys())}"
591
- )
592
-
593
-
594
  def get_default_train_args(model, classifier, data, output_dir):
595
  num_layers = pu.quant_layers(model)
596
  freeze_layers = 0
 
94
  return data
95
 
96
 
97
+ def label_classes(classifier, data, gene_class_dict, nproc):
98
  if classifier == "cell":
99
  label_set = set(data["label"])
100
  elif classifier == "gene":
 
113
  )
114
  raise
115
 
116
+ class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
117
+ id_class_dict = {v: k for k, v in class_id_dict.items()}
 
 
 
 
 
 
 
 
 
118
 
119
  def classes_to_ids(example):
120
  if classifier == "cell":
121
  example["label"] = class_id_dict[example["label"]]
122
  elif classifier == "gene":
123
  example["labels"] = label_gene_classes(
124
+ example, class_id_dict, gene_class_dict
125
  )
126
  return example
127
 
 
129
  return data, id_class_dict
130
 
131
 
132
+ def label_gene_classes(example, class_id_dict, gene_class_dict):
133
  return [
134
+ class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
135
  for token_id in example["input_ids"]
136
  ]
137
 
 
561
  return {"accuracy": acc, "macro_f1": macro_f1}
562
 
563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
  def get_default_train_args(model, classifier, data, output_dir):
565
  num_layers = pu.quant_layers(model)
566
  freeze_layers = 0
geneformer/collator_for_classification.py CHANGED
@@ -26,10 +26,9 @@ LARGE_INTEGER = int(
26
  1e20
27
  ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
28
 
29
- warnings.filterwarnings("ignore", message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach()", category=UserWarning, module="torch")
30
-
31
  # precollator functions
32
 
 
33
  class ExplicitEnum(Enum):
34
  """
35
  Enum with more explicit error message for missing values.
@@ -104,9 +103,6 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
104
  def pad_token_id(self):
105
  return self._pad_token_id
106
 
107
- def save_pretrained(self, save_directory):
108
- pass
109
-
110
  def _get_padding_truncation_strategies(
111
  self,
112
  padding=True,
@@ -645,8 +641,7 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
645
  def __call__(self, features):
646
  batch = self._prepare_batch(features)
647
 
648
- # batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
649
- batch = {k: torch.tensor(v.clone().detach(), dtype=torch.int64) if isinstance(v, torch.Tensor) else torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
650
  return batch
651
 
652
 
 
26
  1e20
27
  ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
28
 
 
 
29
  # precollator functions
30
 
31
+
32
  class ExplicitEnum(Enum):
33
  """
34
  Enum with more explicit error message for missing values.
 
103
  def pad_token_id(self):
104
  return self._pad_token_id
105
 
 
 
 
106
  def _get_padding_truncation_strategies(
107
  self,
108
  padding=True,
 
641
  def __call__(self, features):
642
  batch = self._prepare_batch(features)
643
 
644
+ batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
 
645
  return batch
646
 
647
 
geneformer/emb_extractor.py CHANGED
@@ -42,8 +42,6 @@ def get_embs(
42
  special_token=False,
43
  summary_stat=None,
44
  silent=False,
45
- save_tdigest=False,
46
- tdigest_path=None,
47
  ):
48
  model_input_size = pu.get_model_input_size(model)
49
  total_batch_length = len(filtered_input_data)
@@ -182,18 +180,12 @@ def get_embs(
182
  # calculate summary stat embs from approximated tdigests
183
  elif summary_stat is not None:
184
  if emb_mode == "cell":
185
- if save_tdigest:
186
- with open(f"{tdigest_path}","wb") as fp:
187
- pickle.dump(embs_tdigests, fp)
188
  if summary_stat == "mean":
189
  summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
190
  elif summary_stat == "median":
191
  summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
192
  embs_stack = torch.tensor(summary_emb_list)
193
  elif emb_mode == "gene":
194
- if save_tdigest:
195
- with open(f"{tdigest_path}","wb") as fp:
196
- pickle.dump(embs_tdigests_dict, fp)
197
  if summary_stat == "mean":
198
  [
199
  update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
@@ -260,7 +252,7 @@ def label_cell_embs(embs, downsampled_data, emb_labels):
260
  return embs_df
261
 
262
 
263
- def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mean_pool"):
264
  gene_set = {
265
  element for sublist in downsampled_data["input_ids"] for element in sublist
266
  }
@@ -275,52 +267,25 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mea
275
  )
276
  for k in dict_i.keys():
277
  gene_emb_dict[k].append(dict_i[k])
278
- if gene_emb_style != "all":
279
- for k in gene_emb_dict.keys():
280
- gene_emb_dict[k] = (
281
- torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
282
- .cpu()
283
- .numpy()
284
- )
285
- embs_df = pd.DataFrame(gene_emb_dict).T
286
- else:
287
- embs_df = dict_lol_to_df(gene_emb_dict)
288
  embs_df.index = [token_gene_dict[token] for token in embs_df.index]
289
  return embs_df
290
 
291
- def dict_lol_to_df(data_dict):
292
- # save dictionary with values being list of equal-length lists as dataframe
293
- df_data = []
294
- for key, list_of_lists in data_dict.items():
295
- for i, sublist in enumerate(list_of_lists):
296
- row_data = [key, i] + sublist.tolist()
297
- df_data.append(row_data)
298
-
299
- # determine column names based on the length of sublists
300
- # assuming all sublists have the same length
301
- num_columns_from_sublist = len(list(data_dict.values())[0][0])
302
- column_names = ['Gene', 'Identifier'] + [f'{j}' for j in range(num_columns_from_sublist)]
303
-
304
- # create the dataframe
305
- df = pd.DataFrame(df_data, columns=column_names)
306
-
307
- # set 'Gene' as the index
308
- df = df.set_index('Gene')
309
-
310
- return df
311
-
312
- def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
313
  only_embs_df = embs_df.iloc[:, :emb_dims]
314
  only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
315
  only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
316
  str
317
  )
318
  vars_dict = {"embs": only_embs_df.columns}
319
-
320
- obs_dict = {"cell_id": list(only_embs_df.index)}
321
- for label_i in labels_clean:
322
- obs_dict[label_i] = list(embs_df[label_i])
323
-
324
  adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
325
  sc.tl.pca(adata, svd_solver="arpack")
326
  sc.pp.neighbors(adata, random_state=seed)
@@ -331,26 +296,21 @@ def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory,
331
  if kwargs_dict is not None:
332
  default_kwargs_dict.update(kwargs_dict)
333
 
334
- for label_i in labels_clean:
335
- output_prefix_label = output_prefix + f"_umap_{label_i}"
336
- output_file = (
337
- Path(output_directory) / output_prefix_label
338
- ).with_suffix(".pdf")
339
-
340
- cats = set(embs_df[label_i])
341
-
342
- with plt.rc_context():
343
- ax = sc.pl.umap(adata, color=label_i, show=False, **default_kwargs_dict)
344
- ax.legend(
345
- markerscale=2,
346
- frameon=False,
347
- loc="center left",
348
- bbox_to_anchor=(1, 0.5),
349
- ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
350
- )
351
- plt.show()
352
- plt.savefig(output_file, bbox_inches="tight")
353
-
354
  def gen_heatmap_class_colors(labels, df):
355
  pal = sns.cubehelix_palette(
356
  len(Counter(labels).keys()),
@@ -435,14 +395,13 @@ class EmbExtractor:
435
  "num_classes": {int},
436
  "emb_mode": {"cls", "cell", "gene"},
437
  "cell_emb_style": {"mean_pool"},
438
- "gene_emb_style": {"mean_pool", "all"},
439
  "filter_data": {None, dict},
440
  "max_ncells": {None, int},
441
  "emb_layer": {-1, 0},
442
  "emb_label": {None, list},
443
  "labels_to_plot": {None, list},
444
  "forward_batch_size": {int},
445
- "model_version": {"V1", "V2"},
446
  "token_dictionary_file": {None, str},
447
  "nproc": {int},
448
  "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
@@ -452,7 +411,7 @@ class EmbExtractor:
452
  self,
453
  model_type="Pretrained",
454
  num_classes=0,
455
- emb_mode="cls",
456
  cell_emb_style="mean_pool",
457
  gene_emb_style="mean_pool",
458
  filter_data=None,
@@ -463,8 +422,6 @@ class EmbExtractor:
463
  forward_batch_size=100,
464
  nproc=4,
465
  summary_stat=None,
466
- save_tdigest=False,
467
- model_version="V2",
468
  token_dictionary_file=None,
469
  ):
470
  """
@@ -472,8 +429,8 @@ class EmbExtractor:
472
 
473
  **Parameters:**
474
 
475
- model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "Pretrained-Quantized"}
476
- | Whether model is the pretrained Geneformer (full or quantized) or a fine-tuned gene or cell classifier.
477
  num_classes : int
478
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
479
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
@@ -483,9 +440,9 @@ class EmbExtractor:
483
  cell_emb_style : {"mean_pool"}
484
  | Method for summarizing cell embeddings if not using CLS token.
485
  | Currently only option is mean pooling of gene embeddings for given cell.
486
- gene_emb_style : {"mean_pool", "all}
487
  | Method for summarizing gene embeddings.
488
- | Currently only option is returning all or mean pooling of contextual gene embeddings for given gene.
489
  filter_data : None, dict
490
  | Default is to extract embeddings from all input data.
491
  | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
@@ -515,12 +472,6 @@ class EmbExtractor:
515
  | If mean or median, outputs only approximated mean or median embedding of input data.
516
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
517
  | Non-exact is slower but more memory-efficient.
518
- save_tdigest : bool
519
- | Whether to save a dictionary of tdigests for each gene and embedding dimension
520
- | Only applies when summary_stat is not None
521
- model_version : str
522
- | To auto-select settings for model version other than current default.
523
- | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
524
  token_dictionary_file : Path
525
  | Default is the Geneformer token dictionary
526
  | Path to pickle file containing token dictionary (Ensembl ID:token).
@@ -551,7 +502,6 @@ class EmbExtractor:
551
  self.emb_layer = emb_layer
552
  self.emb_label = emb_label
553
  self.labels_to_plot = labels_to_plot
554
- self.model_version = model_version
555
  self.token_dictionary_file = token_dictionary_file
556
  self.forward_batch_size = forward_batch_size
557
  self.nproc = nproc
@@ -561,29 +511,13 @@ class EmbExtractor:
561
  else:
562
  self.summary_stat = summary_stat
563
  self.exact_summary_stat = None
564
- self.save_tdigest = save_tdigest
565
 
566
  self.validate_options()
567
 
568
- if (summary_stat is None) and (save_tdigest is True):
569
- logger.warning(
570
- "tdigests will not be saved since summary_stat is None."
571
- )
572
- save_tdigest = False
573
-
574
- if self.model_version == "V1":
575
- from . import TOKEN_DICTIONARY_FILE_30M
576
- self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
577
- if self.emb_mode == "cls":
578
- self.emb_mode = "cell"
579
- logger.warning(
580
- "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
581
- )
582
-
583
  # load token dictionary (Ensembl IDs:token)
584
  if self.token_dictionary_file is None:
585
- self.token_dictionary_file = TOKEN_DICTIONARY_FILE
586
- with open(self.token_dictionary_file, "rb") as f:
587
  self.gene_token_dict = pickle.load(f)
588
 
589
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
@@ -662,12 +596,6 @@ class EmbExtractor:
662
  filtered_input_data = pu.load_and_filter(
663
  self.filter_data, self.nproc, input_data_file
664
  )
665
-
666
- # Check to make sure that all the labels exist in the tokenized data:
667
- if self.emb_label is not None:
668
- for label in self.emb_label:
669
- assert label in filtered_input_data.features.keys(), f"Attribute `{label}` not present in dataset features"
670
-
671
  if cell_state is not None:
672
  filtered_input_data = pu.filter_by_dict(
673
  filtered_input_data, cell_state, self.nproc
@@ -677,10 +605,6 @@ class EmbExtractor:
677
  self.model_type, self.num_classes, model_directory, mode="eval"
678
  )
679
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
680
- if self.save_tdigest:
681
- tdigest_path = (Path(output_directory) / f"{output_prefix}_tdigest").with_suffix(".pkl")
682
- else:
683
- tdigest_path = None
684
  embs = get_embs(
685
  model=model,
686
  filtered_input_data=downsampled_data,
@@ -690,8 +614,6 @@ class EmbExtractor:
690
  forward_batch_size=self.forward_batch_size,
691
  token_gene_dict=self.token_gene_dict,
692
  summary_stat=self.summary_stat,
693
- save_tdigest=self.save_tdigest,
694
- tdigest_path=tdigest_path,
695
  )
696
 
697
  if self.emb_mode == "cell":
@@ -701,7 +623,7 @@ class EmbExtractor:
701
  embs_df = pd.DataFrame(embs.cpu().numpy()).T
702
  elif self.emb_mode == "gene":
703
  if self.summary_stat is None:
704
- embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict, self.gene_emb_style)
705
  elif self.summary_stat is not None:
706
  embs_df = pd.DataFrame(embs).T
707
  embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
@@ -717,14 +639,14 @@ class EmbExtractor:
717
  embs = embs.mean(dim=0)
718
  emb_dims = pu.get_model_emb_dims(model)
719
  embs_df = pd.DataFrame(
720
- embs_df.iloc[:, 0 : emb_dims].mean(axis="rows"),
721
  columns=[self.exact_summary_stat],
722
  ).T
723
  elif self.exact_summary_stat == "exact_median":
724
  embs = torch.median(embs, dim=0)[0]
725
  emb_dims = pu.get_model_emb_dims(model)
726
  embs_df = pd.DataFrame(
727
- embs_df.iloc[:, 0 : emb_dims].median(axis="rows"),
728
  columns=[self.exact_summary_stat],
729
  ).T
730
 
@@ -797,12 +719,6 @@ class EmbExtractor:
797
  )
798
  raise
799
 
800
- if self.emb_label is not None:
801
- logger.error(
802
- "For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
803
- )
804
- raise
805
-
806
  state_embs_dict = dict()
807
  state_key = cell_states_to_model["state_key"]
808
  for k, v in cell_states_to_model.items():
@@ -885,14 +801,14 @@ class EmbExtractor:
885
  raise
886
 
887
  if max_ncells_to_plot is not None:
888
- if self.max_ncells is not None:
889
- if max_ncells_to_plot > self.max_ncells:
890
- max_ncells_to_plot = self.max_ncells
891
- logger.warning(
892
- "max_ncells_to_plot must be <= max_ncells. "
893
- f"Changing max_ncells_to_plot to {self.max_ncells}."
894
- )
895
- embs = embs.sample(max_ncells_to_plot, axis=0)
896
 
897
  if self.emb_label is None:
898
  label_len = 0
@@ -913,9 +829,12 @@ class EmbExtractor:
913
  f"Label {label} from labels_to_plot "
914
  f"not present in provided embeddings dataframe."
915
  )
916
-
917
- labels_clean = [label for label in self.labels_to_plot if label in emb_labels]
918
- plot_umap(embs, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict)
 
 
 
919
 
920
  if plot_style == "heatmap":
921
  for label in self.labels_to_plot:
 
42
  special_token=False,
43
  summary_stat=None,
44
  silent=False,
 
 
45
  ):
46
  model_input_size = pu.get_model_input_size(model)
47
  total_batch_length = len(filtered_input_data)
 
180
  # calculate summary stat embs from approximated tdigests
181
  elif summary_stat is not None:
182
  if emb_mode == "cell":
 
 
 
183
  if summary_stat == "mean":
184
  summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
185
  elif summary_stat == "median":
186
  summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
187
  embs_stack = torch.tensor(summary_emb_list)
188
  elif emb_mode == "gene":
 
 
 
189
  if summary_stat == "mean":
190
  [
191
  update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
 
252
  return embs_df
253
 
254
 
255
+ def label_gene_embs(embs, downsampled_data, token_gene_dict):
256
  gene_set = {
257
  element for sublist in downsampled_data["input_ids"] for element in sublist
258
  }
 
267
  )
268
  for k in dict_i.keys():
269
  gene_emb_dict[k].append(dict_i[k])
270
+ for k in gene_emb_dict.keys():
271
+ gene_emb_dict[k] = (
272
+ torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
273
+ .cpu()
274
+ .numpy()
275
+ )
276
+ embs_df = pd.DataFrame(gene_emb_dict).T
 
 
 
277
  embs_df.index = [token_gene_dict[token] for token in embs_df.index]
278
  return embs_df
279
 
280
+
281
+ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  only_embs_df = embs_df.iloc[:, :emb_dims]
283
  only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
284
  only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
285
  str
286
  )
287
  vars_dict = {"embs": only_embs_df.columns}
288
+ obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
 
 
 
 
289
  adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
290
  sc.tl.pca(adata, svd_solver="arpack")
291
  sc.pp.neighbors(adata, random_state=seed)
 
296
  if kwargs_dict is not None:
297
  default_kwargs_dict.update(kwargs_dict)
298
 
299
+ cats = set(embs_df[label])
300
+
301
+ with plt.rc_context():
302
+ ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
303
+ ax.legend(
304
+ markerscale=2,
305
+ frameon=False,
306
+ loc="center left",
307
+ bbox_to_anchor=(1, 0.5),
308
+ ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
309
+ )
310
+ plt.show()
311
+ plt.savefig(output_file, bbox_inches="tight")
312
+
313
+
 
 
 
 
 
314
  def gen_heatmap_class_colors(labels, df):
315
  pal = sns.cubehelix_palette(
316
  len(Counter(labels).keys()),
 
395
  "num_classes": {int},
396
  "emb_mode": {"cls", "cell", "gene"},
397
  "cell_emb_style": {"mean_pool"},
398
+ "gene_emb_style": {"mean_pool"},
399
  "filter_data": {None, dict},
400
  "max_ncells": {None, int},
401
  "emb_layer": {-1, 0},
402
  "emb_label": {None, list},
403
  "labels_to_plot": {None, list},
404
  "forward_batch_size": {int},
 
405
  "token_dictionary_file": {None, str},
406
  "nproc": {int},
407
  "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
 
411
  self,
412
  model_type="Pretrained",
413
  num_classes=0,
414
+ emb_mode="cell",
415
  cell_emb_style="mean_pool",
416
  gene_emb_style="mean_pool",
417
  filter_data=None,
 
422
  forward_batch_size=100,
423
  nproc=4,
424
  summary_stat=None,
 
 
425
  token_dictionary_file=None,
426
  ):
427
  """
 
429
 
430
  **Parameters:**
431
 
432
+ model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
433
+ | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
434
  num_classes : int
435
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
436
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
 
440
  cell_emb_style : {"mean_pool"}
441
  | Method for summarizing cell embeddings if not using CLS token.
442
  | Currently only option is mean pooling of gene embeddings for given cell.
443
+ gene_emb_style : "mean_pool"
444
  | Method for summarizing gene embeddings.
445
+ | Currently only option is mean pooling of contextual gene embeddings for given gene.
446
  filter_data : None, dict
447
  | Default is to extract embeddings from all input data.
448
  | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
 
472
  | If mean or median, outputs only approximated mean or median embedding of input data.
473
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
474
  | Non-exact is slower but more memory-efficient.
 
 
 
 
 
 
475
  token_dictionary_file : Path
476
  | Default is the Geneformer token dictionary
477
  | Path to pickle file containing token dictionary (Ensembl ID:token).
 
502
  self.emb_layer = emb_layer
503
  self.emb_label = emb_label
504
  self.labels_to_plot = labels_to_plot
 
505
  self.token_dictionary_file = token_dictionary_file
506
  self.forward_batch_size = forward_batch_size
507
  self.nproc = nproc
 
511
  else:
512
  self.summary_stat = summary_stat
513
  self.exact_summary_stat = None
 
514
 
515
  self.validate_options()
516
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  # load token dictionary (Ensembl IDs:token)
518
  if self.token_dictionary_file is None:
519
+ token_dictionary_file = TOKEN_DICTIONARY_FILE
520
+ with open(token_dictionary_file, "rb") as f:
521
  self.gene_token_dict = pickle.load(f)
522
 
523
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
 
596
  filtered_input_data = pu.load_and_filter(
597
  self.filter_data, self.nproc, input_data_file
598
  )
 
 
 
 
 
 
599
  if cell_state is not None:
600
  filtered_input_data = pu.filter_by_dict(
601
  filtered_input_data, cell_state, self.nproc
 
605
  self.model_type, self.num_classes, model_directory, mode="eval"
606
  )
607
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
 
 
 
 
608
  embs = get_embs(
609
  model=model,
610
  filtered_input_data=downsampled_data,
 
614
  forward_batch_size=self.forward_batch_size,
615
  token_gene_dict=self.token_gene_dict,
616
  summary_stat=self.summary_stat,
 
 
617
  )
618
 
619
  if self.emb_mode == "cell":
 
623
  embs_df = pd.DataFrame(embs.cpu().numpy()).T
624
  elif self.emb_mode == "gene":
625
  if self.summary_stat is None:
626
+ embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
627
  elif self.summary_stat is not None:
628
  embs_df = pd.DataFrame(embs).T
629
  embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
 
639
  embs = embs.mean(dim=0)
640
  emb_dims = pu.get_model_emb_dims(model)
641
  embs_df = pd.DataFrame(
642
+ embs_df[0 : emb_dims - 1].mean(axis="rows"),
643
  columns=[self.exact_summary_stat],
644
  ).T
645
  elif self.exact_summary_stat == "exact_median":
646
  embs = torch.median(embs, dim=0)[0]
647
  emb_dims = pu.get_model_emb_dims(model)
648
  embs_df = pd.DataFrame(
649
+ embs_df[0 : emb_dims - 1].median(axis="rows"),
650
  columns=[self.exact_summary_stat],
651
  ).T
652
 
 
719
  )
720
  raise
721
 
 
 
 
 
 
 
722
  state_embs_dict = dict()
723
  state_key = cell_states_to_model["state_key"]
724
  for k, v in cell_states_to_model.items():
 
801
  raise
802
 
803
  if max_ncells_to_plot is not None:
804
+ if max_ncells_to_plot > self.max_ncells:
805
+ max_ncells_to_plot = self.max_ncells
806
+ logger.warning(
807
+ "max_ncells_to_plot must be <= max_ncells. "
808
+ f"Changing max_ncells_to_plot to {self.max_ncells}."
809
+ )
810
+ elif max_ncells_to_plot < self.max_ncells:
811
+ embs = embs.sample(max_ncells_to_plot, axis=0)
812
 
813
  if self.emb_label is None:
814
  label_len = 0
 
829
  f"Label {label} from labels_to_plot "
830
  f"not present in provided embeddings dataframe."
831
  )
832
+ continue
833
+ output_prefix_label = output_prefix + f"_umap_{label}"
834
+ output_file = (
835
+ Path(output_directory) / output_prefix_label
836
+ ).with_suffix(".pdf")
837
+ plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
838
 
839
  if plot_style == "heatmap":
840
  for label in self.labels_to_plot:
geneformer/{ensembl_mapping_dict_gc104M.pkl → ensembl_mapping_dict_gc95M.pkl} RENAMED
File without changes
geneformer/evaluation_utils.py CHANGED
@@ -8,7 +8,6 @@ import numpy as np
8
  import pandas as pd
9
  import seaborn as sns
10
  import torch
11
- import datasets
12
  from datasets.utils.logging import disable_progress_bar, enable_progress_bar
13
  from sklearn import preprocessing
14
  from sklearn.metrics import (
@@ -21,15 +20,20 @@ from sklearn.metrics import (
21
  )
22
  from tqdm.auto import trange
23
 
 
24
  from .emb_extractor import make_colorbar
25
 
26
  logger = logging.getLogger(__name__)
27
 
28
 
29
- def preprocess_classifier_batch(cell_batch, max_len, label_name, gene_token_dict):
30
  if max_len is None:
31
  max_len = max([len(i) for i in cell_batch["input_ids"]])
32
 
 
 
 
 
33
  def pad_label_example(example):
34
  example[label_name] = np.pad(
35
  example[label_name],
@@ -77,7 +81,7 @@ def py_softmax(vector):
77
  return e / e.sum()
78
 
79
 
80
- def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict, predict_metadata=None):
81
  if classifier_type == "gene":
82
  label_name = "labels"
83
  elif classifier_type == "cell":
@@ -85,14 +89,6 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
85
 
86
  predict_logits = []
87
  predict_labels = []
88
-
89
- predict_metadata_all = None
90
-
91
- if predict_metadata is not None:
92
- predict_metadata_all = dict()
93
- for metadata_name in predict_metadata:
94
- predict_metadata_all[metadata_name] = []
95
-
96
  model.eval()
97
 
98
  # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
@@ -107,21 +103,11 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
107
  for i in trange(0, evalset_len, forward_batch_size):
108
  max_range = min(i + forward_batch_size, evalset_len)
109
  batch_evalset = evalset.select([i for i in range(i, max_range)])
110
-
111
- if predict_metadata is not None:
112
- for metadata_name in predict_metadata:
113
- predict_metadata_all[metadata_name] += batch_evalset[metadata_name]
114
-
115
  padded_batch = preprocess_classifier_batch(
116
- batch_evalset, max_evalset_len, label_name, gene_token_dict
117
  )
118
-
119
  padded_batch.set_format(type="torch")
120
 
121
- # For datasets>=4.0.0, convert to dict to avoid format issues
122
- if int(datasets.__version__.split(".")[0]) >= 4:
123
- padded_batch = padded_batch[:]
124
-
125
  input_data_batch = padded_batch["input_ids"]
126
  attn_msk_batch = padded_batch["attention_mask"]
127
  label_batch = padded_batch[label_name]
@@ -148,8 +134,7 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
148
  y_pred = [vote(item[0]) for item in logit_label_paired]
149
  y_true = [item[1] for item in logit_label_paired]
150
  logits_list = [item[0] for item in logit_label_paired]
151
-
152
- return y_pred, y_true, logits_list, predict_metadata_all
153
 
154
 
155
  def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
@@ -197,18 +182,14 @@ def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix
197
  for model_name in roc_metric_dict.keys():
198
  mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
199
  mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
 
 
200
  color = model_style_dict[model_name]["color"]
201
  linestyle = model_style_dict[model_name]["linestyle"]
202
- if "roc_auc" not in roc_metric_dict[model_name].keys():
203
- all_roc_auc = roc_metric_dict[model_name]["all_roc_auc"]
204
- label = f"{model_name} (AUC {all_roc_auc:0.2f})"
205
  else:
206
- roc_auc = roc_metric_dict[model_name]["roc_auc"]
207
- roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
208
- if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
209
- label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
210
- else:
211
- label = f"{model_name} (AUC {roc_auc:0.2f})"
212
  plt.plot(
213
  mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
214
  )
 
8
  import pandas as pd
9
  import seaborn as sns
10
  import torch
 
11
  from datasets.utils.logging import disable_progress_bar, enable_progress_bar
12
  from sklearn import preprocessing
13
  from sklearn.metrics import (
 
20
  )
21
  from tqdm.auto import trange
22
 
23
+ from . import TOKEN_DICTIONARY_FILE
24
  from .emb_extractor import make_colorbar
25
 
26
  logger = logging.getLogger(__name__)
27
 
28
 
29
+ def preprocess_classifier_batch(cell_batch, max_len, label_name):
30
  if max_len is None:
31
  max_len = max([len(i) for i in cell_batch["input_ids"]])
32
 
33
+ # load token dictionary (Ensembl IDs:token)
34
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
35
+ gene_token_dict = pickle.load(f)
36
+
37
  def pad_label_example(example):
38
  example[label_name] = np.pad(
39
  example[label_name],
 
81
  return e / e.sum()
82
 
83
 
84
+ def classifier_predict(model, classifier_type, evalset, forward_batch_size):
85
  if classifier_type == "gene":
86
  label_name = "labels"
87
  elif classifier_type == "cell":
 
89
 
90
  predict_logits = []
91
  predict_labels = []
 
 
 
 
 
 
 
 
92
  model.eval()
93
 
94
  # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
 
103
  for i in trange(0, evalset_len, forward_batch_size):
104
  max_range = min(i + forward_batch_size, evalset_len)
105
  batch_evalset = evalset.select([i for i in range(i, max_range)])
 
 
 
 
 
106
  padded_batch = preprocess_classifier_batch(
107
+ batch_evalset, max_evalset_len, label_name
108
  )
 
109
  padded_batch.set_format(type="torch")
110
 
 
 
 
 
111
  input_data_batch = padded_batch["input_ids"]
112
  attn_msk_batch = padded_batch["attention_mask"]
113
  label_batch = padded_batch[label_name]
 
134
  y_pred = [vote(item[0]) for item in logit_label_paired]
135
  y_true = [item[1] for item in logit_label_paired]
136
  logits_list = [item[0] for item in logit_label_paired]
137
+ return y_pred, y_true, logits_list
 
138
 
139
 
140
  def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
 
182
  for model_name in roc_metric_dict.keys():
183
  mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
184
  mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
185
+ roc_auc = roc_metric_dict[model_name]["roc_auc"]
186
+ roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
187
  color = model_style_dict[model_name]["color"]
188
  linestyle = model_style_dict[model_name]["linestyle"]
189
+ if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
190
+ label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
 
191
  else:
192
+ label = f"{model_name} (AUC {roc_auc:0.2f})"
 
 
 
 
 
193
  plt.plot(
194
  mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
195
  )
geneformer/{gene_median_dictionary_gc104M.pkl → gene_median_dictionary_gc95M.pkl} RENAMED
File without changes
geneformer/{gene_name_id_dict_gc104M.pkl → gene_name_id_dict_gc95M.pkl} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fabfa0c2f49c598c59ae432a32c3499a5908c033756c663b5e0cddf58deea8e1
3
- size 1660882
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b0fd0521406ed18b2e341ef0acb5f53aa1a62457a07ca5840e1c142f46dd326
3
+ size 2038812
geneformer/in_silico_perturber.py CHANGED
@@ -40,7 +40,7 @@ import pickle
40
  from collections import defaultdict
41
 
42
  import torch
43
- from datasets import Dataset
44
  from multiprocess import set_start_method
45
  from tqdm.auto import trange
46
 
@@ -48,9 +48,7 @@ from . import TOKEN_DICTIONARY_FILE
48
  from . import perturber_utils as pu
49
  from .emb_extractor import get_embs
50
 
51
- import datasets
52
- datasets.logging.disable_progress_bar()
53
-
54
 
55
  logger = logging.getLogger(__name__)
56
 
@@ -62,7 +60,7 @@ class InSilicoPerturber:
62
  "genes_to_perturb": {"all", list},
63
  "combos": {0, 1},
64
  "anchor_gene": {None, str},
65
- "model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "Pretrained-Quantized", "MTLCellClassifier-Quantized"},
66
  "num_classes": {int},
67
  "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
68
  "cell_emb_style": {"mean_pool"},
@@ -72,7 +70,6 @@ class InSilicoPerturber:
72
  "max_ncells": {None, int},
73
  "cell_inds_to_perturb": {"all", dict},
74
  "emb_layer": {-1, 0},
75
- "model_version": {"V1", "V2"},
76
  "token_dictionary_file": {None, str},
77
  "forward_batch_size": {int},
78
  "nproc": {int},
@@ -87,7 +84,7 @@ class InSilicoPerturber:
87
  anchor_gene=None,
88
  model_type="Pretrained",
89
  num_classes=0,
90
- emb_mode="cls",
91
  cell_emb_style="mean_pool",
92
  filter_data=None,
93
  cell_states_to_model=None,
@@ -97,7 +94,6 @@ class InSilicoPerturber:
97
  emb_layer=-1,
98
  forward_batch_size=100,
99
  nproc=4,
100
- model_version="V2",
101
  token_dictionary_file=None,
102
  clear_mem_ncells=1000,
103
  ):
@@ -134,7 +130,7 @@ class InSilicoPerturber:
134
  | ENSEMBL ID of gene to use as anchor in combination perturbations.
135
  | For example, if combos=1 and anchor_gene="ENSG00000148400":
136
  | anchor gene will be perturbed in combination with each other gene.
137
- model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "Pretrained-Quantized", "MTLCellClassifier-Quantized"}
138
  | Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
139
  num_classes : int
140
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
@@ -186,9 +182,6 @@ class InSilicoPerturber:
186
  | Batch size for forward pass.
187
  nproc : int
188
  | Number of CPU processes to use.
189
- model_version : str
190
- | To auto-select settings for model version other than current default.
191
- | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
192
  token_dictionary_file : Path
193
  | Path to pickle file containing token dictionary (Ensembl ID:token).
194
  clear_mem_ncells : int
@@ -229,30 +222,15 @@ class InSilicoPerturber:
229
  self.emb_layer = emb_layer
230
  self.forward_batch_size = forward_batch_size
231
  self.nproc = nproc
232
- self.model_version = model_version
233
  self.token_dictionary_file = token_dictionary_file
234
- self.clear_mem_ncells = clear_mem_ncells
235
 
236
  self.validate_options()
237
 
238
- if self.model_version == "V1":
239
- from . import TOKEN_DICTIONARY_FILE_30M
240
- self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
241
- if self.emb_mode == "cls":
242
- self.emb_mode = "cell"
243
- logger.warning(
244
- "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
245
- )
246
- if self.emb_mode == "cls_and_gene":
247
- self.emb_mode = "cell_and_gene"
248
- logger.warning(
249
- "model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a <cls> token."
250
- )
251
-
252
  # load token dictionary (Ensembl IDs:token)
253
  if self.token_dictionary_file is None:
254
- self.token_dictionary_file = TOKEN_DICTIONARY_FILE
255
- with open(self.token_dictionary_file, "rb") as f:
256
  self.gene_token_dict = pickle.load(f)
257
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
258
 
@@ -816,8 +794,6 @@ class InSilicoPerturber:
816
  return example
817
 
818
  total_batch_length = len(filtered_input_data)
819
-
820
-
821
  if self.cell_states_to_model is None:
822
  cos_sims_dict = defaultdict(list)
823
  else:
@@ -902,7 +878,7 @@ class InSilicoPerturber:
902
  )
903
 
904
  ##### CLS and Gene Embedding Mode #####
905
- elif self.emb_mode == "cls_and_gene":
906
  full_original_emb = get_embs(
907
  model,
908
  minibatch,
@@ -915,7 +891,6 @@ class InSilicoPerturber:
915
  silent=True,
916
  )
917
  indices_to_perturb = perturbation_batch["perturb_index"]
918
-
919
  # remove indices that were perturbed
920
  original_emb = pu.remove_perturbed_indices_set(
921
  full_original_emb,
@@ -924,7 +899,6 @@ class InSilicoPerturber:
924
  self.tokens_to_perturb,
925
  minibatch["length"],
926
  )
927
-
928
  full_perturbation_emb = get_embs(
929
  model,
930
  perturbation_batch,
@@ -936,7 +910,7 @@ class InSilicoPerturber:
936
  summary_stat=None,
937
  silent=True,
938
  )
939
-
940
  # remove special tokens and padding
941
  original_emb = original_emb[:, 1:-1, :]
942
  if self.perturb_type == "overexpress":
@@ -947,25 +921,9 @@ class InSilicoPerturber:
947
  perturbation_emb = full_perturbation_emb[
948
  :, 1 : max(perturbation_batch["length"]) - 1, :
949
  ]
950
-
951
- n_perturbation_genes = perturbation_emb.size()[1]
952
 
953
- # truncate the original embedding as necessary
954
- if self.perturb_type == "overexpress":
955
- def calc_perturbation_length(ids):
956
- if ids == [-100]:
957
- return 0
958
- else:
959
- return len(ids)
960
-
961
- max_tensor_size = max([length - calc_perturbation_length(ids) - 2 for length, ids in zip(minibatch["length"], indices_to_perturb)])
962
 
963
- max_n_overflow = max(minibatch["n_overflow"])
964
- if max_n_overflow > 0 and perturbation_emb.size()[1] < original_emb.size()[1]:
965
- original_emb = original_emb[:, 0 : perturbation_emb.size()[1], :]
966
- elif perturbation_emb.size()[1] < original_emb.size()[1]:
967
- original_emb = original_emb[:, 0:max_tensor_size, :]
968
-
969
  gene_cos_sims = pu.quant_cos_sims(
970
  perturbation_emb,
971
  original_emb,
 
40
  from collections import defaultdict
41
 
42
  import torch
43
+ from datasets import Dataset, disable_progress_bars
44
  from multiprocess import set_start_method
45
  from tqdm.auto import trange
46
 
 
48
  from . import perturber_utils as pu
49
  from .emb_extractor import get_embs
50
 
51
+ disable_progress_bars()
 
 
52
 
53
  logger = logging.getLogger(__name__)
54
 
 
60
  "genes_to_perturb": {"all", list},
61
  "combos": {0, 1},
62
  "anchor_gene": {None, str},
63
+ "model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"},
64
  "num_classes": {int},
65
  "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
66
  "cell_emb_style": {"mean_pool"},
 
70
  "max_ncells": {None, int},
71
  "cell_inds_to_perturb": {"all", dict},
72
  "emb_layer": {-1, 0},
 
73
  "token_dictionary_file": {None, str},
74
  "forward_batch_size": {int},
75
  "nproc": {int},
 
84
  anchor_gene=None,
85
  model_type="Pretrained",
86
  num_classes=0,
87
+ emb_mode="cell",
88
  cell_emb_style="mean_pool",
89
  filter_data=None,
90
  cell_states_to_model=None,
 
94
  emb_layer=-1,
95
  forward_batch_size=100,
96
  nproc=4,
 
97
  token_dictionary_file=None,
98
  clear_mem_ncells=1000,
99
  ):
 
130
  | ENSEMBL ID of gene to use as anchor in combination perturbations.
131
  | For example, if combos=1 and anchor_gene="ENSG00000148400":
132
  | anchor gene will be perturbed in combination with each other gene.
133
+ model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"}
134
  | Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
135
  num_classes : int
136
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
 
182
  | Batch size for forward pass.
183
  nproc : int
184
  | Number of CPU processes to use.
 
 
 
185
  token_dictionary_file : Path
186
  | Path to pickle file containing token dictionary (Ensembl ID:token).
187
  clear_mem_ncells : int
 
222
  self.emb_layer = emb_layer
223
  self.forward_batch_size = forward_batch_size
224
  self.nproc = nproc
 
225
  self.token_dictionary_file = token_dictionary_file
226
+ self.clear_mem_ncells = clear_mem_ncells
227
 
228
  self.validate_options()
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  # load token dictionary (Ensembl IDs:token)
231
  if self.token_dictionary_file is None:
232
+ token_dictionary_file = TOKEN_DICTIONARY_FILE
233
+ with open(token_dictionary_file, "rb") as f:
234
  self.gene_token_dict = pickle.load(f)
235
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
236
 
 
794
  return example
795
 
796
  total_batch_length = len(filtered_input_data)
 
 
797
  if self.cell_states_to_model is None:
798
  cos_sims_dict = defaultdict(list)
799
  else:
 
878
  )
879
 
880
  ##### CLS and Gene Embedding Mode #####
881
+ elif self.emb_mode == "cls_and_gene":
882
  full_original_emb = get_embs(
883
  model,
884
  minibatch,
 
891
  silent=True,
892
  )
893
  indices_to_perturb = perturbation_batch["perturb_index"]
 
894
  # remove indices that were perturbed
895
  original_emb = pu.remove_perturbed_indices_set(
896
  full_original_emb,
 
899
  self.tokens_to_perturb,
900
  minibatch["length"],
901
  )
 
902
  full_perturbation_emb = get_embs(
903
  model,
904
  perturbation_batch,
 
910
  summary_stat=None,
911
  silent=True,
912
  )
913
+
914
  # remove special tokens and padding
915
  original_emb = original_emb[:, 1:-1, :]
916
  if self.perturb_type == "overexpress":
 
921
  perturbation_emb = full_perturbation_emb[
922
  :, 1 : max(perturbation_batch["length"]) - 1, :
923
  ]
 
 
924
 
925
+ n_perturbation_genes = perturbation_emb.size()[1]
 
 
 
 
 
 
 
 
926
 
 
 
 
 
 
 
927
  gene_cos_sims = pu.quant_cos_sims(
928
  perturbation_emb,
929
  original_emb,
geneformer/in_silico_perturber_stats.py CHANGED
@@ -640,16 +640,10 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
640
  cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
641
 
642
  # quantify number of detections of each gene
643
- if anchor_token is None:
644
- cos_sims_full_df["N_Detections"] = [
645
- n_detections(i, dict_list, "cell", anchor_token)
646
- for i in cos_sims_full_df["Gene"]
647
- ]
648
- else:
649
- cos_sims_full_df["N_Detections"] = [
650
- n_detections(i, dict_list, "gene", anchor_token)
651
- for i in cos_sims_full_df["Gene"]
652
- ]
653
 
654
  if combos == 0:
655
  cos_sims_full_df = cos_sims_full_df.sort_values(
@@ -676,7 +670,6 @@ class InSilicoPerturberStats:
676
  "anchor_gene": {None, str},
677
  "cell_states_to_model": {None, dict},
678
  "pickle_suffix": {None, str},
679
- "model_version": {"V1", "V2"},
680
  }
681
 
682
  def __init__(
@@ -687,7 +680,6 @@ class InSilicoPerturberStats:
687
  anchor_gene=None,
688
  cell_states_to_model=None,
689
  pickle_suffix="_raw.pickle",
690
- model_version="V2",
691
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
692
  gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
693
  ):
@@ -715,7 +707,7 @@ class InSilicoPerturberStats:
715
  | analyzes data for anchor gene perturbed in combination with each other gene.
716
  | However, if combos=0 and anchor_gene="ENSG00000136574":
717
  | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
718
- cell_states_to_model : None, dict
719
  | Cell states to model if testing perturbations that achieve goal state change.
720
  | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
721
  | state_key: key specifying name of column in .dataset that defines the start/goal states
@@ -726,12 +718,6 @@ class InSilicoPerturberStats:
726
  | "start_state": "dcm",
727
  | "goal_state": "nf",
728
  | "alt_states": ["hcm", "other1", "other2"]}
729
- pickle_suffix : None, str
730
- | Suffix to subselect intermediate raw files for analysis.
731
- | Default output of InSilicoPerturber uses suffix "_raw.pickle".
732
- model_version : str
733
- | To auto-select settings for model version other than current default.
734
- | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
735
  token_dictionary_file : Path
736
  | Path to pickle file containing token dictionary (Ensembl ID:token).
737
  gene_name_id_dictionary_file : Path
@@ -744,15 +730,9 @@ class InSilicoPerturberStats:
744
  self.anchor_gene = anchor_gene
745
  self.cell_states_to_model = cell_states_to_model
746
  self.pickle_suffix = pickle_suffix
747
- self.model_version = model_version
748
 
749
  self.validate_options()
750
 
751
- if self.model_version == "V1":
752
- from . import ENSEMBL_DICTIONARY_FILE_30M, TOKEN_DICTIONARY_FILE_30M
753
- token_dictionary_file=TOKEN_DICTIONARY_FILE_30M
754
- gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE_30M
755
-
756
  # load token dictionary (Ensembl IDs:token)
757
  with open(token_dictionary_file, "rb") as f:
758
  self.gene_token_dict = pickle.load(f)
 
640
  cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
641
 
642
  # quantify number of detections of each gene
643
+ cos_sims_full_df["N_Detections"] = [
644
+ n_detections(i, dict_list, "gene", anchor_token)
645
+ for i in cos_sims_full_df["Gene"]
646
+ ]
 
 
 
 
 
 
647
 
648
  if combos == 0:
649
  cos_sims_full_df = cos_sims_full_df.sort_values(
 
670
  "anchor_gene": {None, str},
671
  "cell_states_to_model": {None, dict},
672
  "pickle_suffix": {None, str},
 
673
  }
674
 
675
  def __init__(
 
680
  anchor_gene=None,
681
  cell_states_to_model=None,
682
  pickle_suffix="_raw.pickle",
 
683
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
684
  gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
685
  ):
 
707
  | analyzes data for anchor gene perturbed in combination with each other gene.
708
  | However, if combos=0 and anchor_gene="ENSG00000136574":
709
  | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
710
+ cell_states_to_model: None, dict
711
  | Cell states to model if testing perturbations that achieve goal state change.
712
  | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
713
  | state_key: key specifying name of column in .dataset that defines the start/goal states
 
718
  | "start_state": "dcm",
719
  | "goal_state": "nf",
720
  | "alt_states": ["hcm", "other1", "other2"]}
 
 
 
 
 
 
721
  token_dictionary_file : Path
722
  | Path to pickle file containing token dictionary (Ensembl ID:token).
723
  gene_name_id_dictionary_file : Path
 
730
  self.anchor_gene = anchor_gene
731
  self.cell_states_to_model = cell_states_to_model
732
  self.pickle_suffix = pickle_suffix
 
733
 
734
  self.validate_options()
735
 
 
 
 
 
 
736
  # load token dictionary (Ensembl IDs:token)
737
  with open(token_dictionary_file, "rb") as f:
738
  self.gene_token_dict = pickle.load(f)
geneformer/mtl/__init__.py CHANGED
@@ -1,4 +1 @@
1
- # ruff: noqa: F401
2
-
3
- from . import eval_utils
4
- from . import utils
 
1
+ # ruff: noqa: F401
 
 
 
geneformer/mtl/collators.py CHANGED
@@ -1,8 +1,8 @@
1
  # imports
2
  import torch
3
  import pickle
4
- from geneformer.collator_for_classification import DataCollatorForGeneClassification
5
- from geneformer import TOKEN_DICTIONARY_FILE
6
 
7
  """Geneformer collator for multi-task cell classification."""
8
 
 
1
  # imports
2
  import torch
3
  import pickle
4
+ from ..collator_for_classification import DataCollatorForGeneClassification
5
+ from .. import TOKEN_DICTIONARY_FILE
6
 
7
  """Geneformer collator for multi-task cell classification."""
8
 
geneformer/mtl/data.py CHANGED
@@ -1,237 +1,150 @@
1
  import os
2
- import pickle
3
- import torch
4
- from torch.utils.data import DataLoader, Dataset
5
- from datasets import load_from_disk
6
 
7
  from .collators import DataCollatorForMultitaskCellClassification
 
8
 
9
 
10
- class StreamingMultiTaskDataset(Dataset):
11
-
12
- def __init__(self, dataset_path, config, is_test=False, dataset_type=""):
13
- """Initialize the streaming dataset."""
14
- self.dataset = load_from_disk(dataset_path)
15
- self.config = config
16
- self.is_test = is_test
17
- self.dataset_type = dataset_type
18
- self.cell_id_mapping = {}
19
-
20
- # Setup task and column mappings
21
- self.task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
22
- self.task_to_column = dict(zip(self.task_names, config["task_columns"]))
23
- config["task_names"] = self.task_names
24
-
25
- # Check if unique_cell_id column exists in the dataset
26
- self.has_unique_cell_ids = "unique_cell_id" in self.dataset.column_names
27
- print(f"{'Found' if self.has_unique_cell_ids else 'No'} unique_cell_id column in {dataset_type} dataset")
28
-
29
- # Setup label mappings
30
- self.label_mappings_path = os.path.join(
31
- config["results_dir"],
32
- f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
33
- )
34
-
35
  if not is_test:
36
- self._validate_columns()
37
- self.task_label_mappings, self.num_labels_list = self._create_label_mappings()
38
- self._save_label_mappings()
39
- else:
40
- # Load existing mappings for test data
41
- self.task_label_mappings = self._load_label_mappings()
42
- self.num_labels_list = [len(mapping) for mapping in self.task_label_mappings.values()]
43
-
44
- def _validate_columns(self):
45
- """Ensures required columns are present in the dataset."""
46
- missing_columns = [col for col in self.task_to_column.values()
47
- if col not in self.dataset.column_names]
48
- if missing_columns:
49
- raise KeyError(
50
- f"Missing columns in {self.dataset_type} dataset: {missing_columns}. "
51
- f"Available columns: {self.dataset.column_names}"
52
- )
53
-
54
- def _create_label_mappings(self):
55
- """Creates label mappings for the dataset."""
56
  task_label_mappings = {}
 
57
  num_labels_list = []
58
-
59
- for task, column in self.task_to_column.items():
60
- unique_values = sorted(set(self.dataset[column]))
61
- mapping = {label: idx for idx, label in enumerate(unique_values)}
62
- task_label_mappings[task] = mapping
63
- num_labels_list.append(len(unique_values))
64
-
65
- return task_label_mappings, num_labels_list
66
-
67
- def _save_label_mappings(self):
68
- """Saves label mappings to a pickle file."""
69
- with open(self.label_mappings_path, "wb") as f:
70
- pickle.dump(self.task_label_mappings, f)
71
-
72
- def _load_label_mappings(self):
73
- """Loads label mappings from a pickle file."""
74
- with open(self.label_mappings_path, "rb") as f:
75
- return pickle.load(f)
76
-
77
- def __len__(self):
78
- return len(self.dataset)
79
-
80
- def __getitem__(self, idx):
81
- record = self.dataset[idx]
82
-
83
- # Store cell ID mapping
84
- if self.has_unique_cell_ids:
85
- unique_cell_id = record["unique_cell_id"]
86
- self.cell_id_mapping[idx] = unique_cell_id
87
- else:
88
- self.cell_id_mapping[idx] = f"cell_{idx}"
89
-
90
- # Create transformed record
91
- transformed_record = {
92
- "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
93
- "cell_id": idx,
94
- }
95
-
96
- # Add labels
97
- if not self.is_test:
98
- label_dict = {
99
- task: self.task_label_mappings[task][record[column]]
100
- for task, column in self.task_to_column.items()
101
- }
102
  else:
103
- label_dict = {task: -1 for task in self.config["task_names"]}
104
-
105
- transformed_record["label"] = label_dict
106
-
107
- return transformed_record
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
 
 
109
 
110
- def get_data_loader(dataset, batch_size, sampler=None, shuffle=True):
111
- """Create a DataLoader with the given dataset and parameters."""
112
- return DataLoader(
113
- dataset,
114
- batch_size=batch_size,
115
- sampler=sampler,
116
- shuffle=shuffle if sampler is None else False,
117
- num_workers=0,
118
- pin_memory=True,
119
- collate_fn=DataCollatorForMultitaskCellClassification(),
120
- )
 
121
 
 
122
 
123
- def prepare_data_loaders(config, include_test=False):
124
- """Prepare data loaders for training, validation, and optionally test."""
125
- result = {}
126
-
127
- # Process train data
128
- train_dataset = StreamingMultiTaskDataset(
129
- config["train_path"],
130
- config,
131
- dataset_type="train"
132
- )
133
- result["train_loader"] = get_data_loader(train_dataset, config["batch_size"])
134
-
135
- # Store the cell ID mapping from the dataset
136
- result["train_cell_mapping"] = {k: v for k, v in train_dataset.cell_id_mapping.items()}
137
- print(f"Collected {len(result['train_cell_mapping'])} cell IDs from training dataset")
138
-
139
- result["num_labels_list"] = train_dataset.num_labels_list
140
-
141
- # Process validation data
142
- val_dataset = StreamingMultiTaskDataset(
143
- config["val_path"],
144
- config,
145
- dataset_type="validation"
146
- )
147
- result["val_loader"] = get_data_loader(val_dataset, config["batch_size"])
148
-
149
- # Store the complete cell ID mapping for validation
150
- for idx in range(len(val_dataset)):
151
- _ = val_dataset[idx]
152
-
153
- result["val_cell_mapping"] = {k: v for k, v in val_dataset.cell_id_mapping.items()}
154
- print(f"Collected {len(result['val_cell_mapping'])} cell IDs from validation dataset")
155
-
156
- # Validate label mappings
157
- validate_label_mappings(config)
158
-
159
- # Process test data if requested
160
- if include_test and "test_path" in config:
161
- test_dataset = StreamingMultiTaskDataset(
162
- config["test_path"],
163
- config,
164
- is_test=True,
165
- dataset_type="test"
166
- )
167
- result["test_loader"] = get_data_loader(test_dataset, config["batch_size"])
168
-
169
- for idx in range(len(test_dataset)):
170
- _ = test_dataset[idx]
171
-
172
- result["test_cell_mapping"] = {k: v for k, v in test_dataset.cell_id_mapping.items()}
173
- print(f"Collected {len(result['test_cell_mapping'])} cell IDs from test dataset")
174
-
175
- return result
176
-
177
-
178
- def validate_label_mappings(config):
179
- """Ensures train and validation label mappings are consistent."""
180
- train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
181
- val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
182
-
183
- with open(train_mappings_path, "rb") as f:
184
- train_mappings = pickle.load(f)
185
-
186
- with open(val_mappings_path, "rb") as f:
187
- val_mappings = pickle.load(f)
188
-
189
- for task_name in config["task_names"]:
190
- if train_mappings[task_name] != val_mappings[task_name]:
191
- raise ValueError(
192
- f"Mismatch in label mappings for task '{task_name}'.\n"
193
- f"Train Mapping: {train_mappings[task_name]}\n"
194
- f"Validation Mapping: {val_mappings[task_name]}"
195
- )
196
 
197
 
198
- # Legacy functions for backward compatibility
199
  def preload_and_process_data(config):
200
- """Preloads and preprocesses train and validation datasets."""
201
- data = prepare_data_loaders(config)
202
-
 
 
 
 
203
  return (
204
- data["train_loader"].dataset,
205
- data["train_cell_mapping"],
206
- data["val_loader"].dataset,
207
- data["val_cell_mapping"],
208
- data["num_labels_list"]
209
  )
210
 
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  def preload_data(config):
213
- """Preprocesses train and validation data for trials."""
214
- data = prepare_data_loaders(config)
215
- return data["train_loader"], data["val_loader"]
 
216
 
217
 
218
  def load_and_preprocess_test_data(config):
219
- """Loads and preprocesses test data."""
220
- test_dataset = StreamingMultiTaskDataset(
221
- config["test_path"],
222
- config,
223
- is_test=True,
224
- dataset_type="test"
225
- )
226
-
227
- return (
228
- test_dataset,
229
- test_dataset.cell_id_mapping,
230
- test_dataset.num_labels_list
231
- )
232
 
233
 
234
  def prepare_test_loader(config):
235
- """Prepares DataLoader for test data."""
236
- data = prepare_data_loaders(config, include_test=True)
237
- return data["test_loader"], data["test_cell_mapping"], data["num_labels_list"]
 
 
 
 
 
 
1
  import os
 
 
 
 
2
 
3
  from .collators import DataCollatorForMultitaskCellClassification
4
+ from .imports import *
5
 
6
 
7
+ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
8
+ try:
9
+ dataset = load_from_disk(dataset_path)
10
+
11
+ task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
12
+ task_to_column = dict(zip(task_names, config["task_columns"]))
13
+ config["task_names"] = task_names
14
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  if not is_test:
16
+ available_columns = set(dataset.column_names)
17
+ for column in task_to_column.values():
18
+ if column not in available_columns:
19
+ raise KeyError(
20
+ f"Column {column} not found in the dataset. Available columns: {list(available_columns)}"
21
+ )
22
+
23
+ label_mappings = {}
 
 
 
 
 
 
 
 
 
 
 
 
24
  task_label_mappings = {}
25
+ cell_id_mapping = {}
26
  num_labels_list = []
27
+
28
+ # Load or create task label mappings
29
+ if not is_test:
30
+ for task, column in task_to_column.items():
31
+ unique_values = sorted(set(dataset[column])) # Ensure consistency
32
+ label_mappings[column] = {
33
+ label: idx for idx, label in enumerate(unique_values)
34
+ }
35
+ task_label_mappings[task] = label_mappings[column]
36
+ num_labels_list.append(len(unique_values))
37
+
38
+ # Print the mappings for each task with dataset type prefix
39
+ for task, mapping in task_label_mappings.items():
40
+ print(
41
+ f"{dataset_type.capitalize()} mapping for {task}: {mapping}"
42
+ ) # sanity check, for train/validation splits
43
+
44
+ # Save the task label mappings as a pickle file
45
+ with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
46
+ pickle.dump(task_label_mappings, f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  else:
48
+ # Load task label mappings from pickle file for test data
49
+ with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
50
+ task_label_mappings = pickle.load(f)
51
+
52
+ # Infer num_labels_list from task_label_mappings
53
+ for task, mapping in task_label_mappings.items():
54
+ num_labels_list.append(len(mapping))
55
+
56
+ # Store unique cell IDs in a separate dictionary
57
+ for idx, record in enumerate(dataset):
58
+ cell_id = record.get("unique_cell_id", idx)
59
+ cell_id_mapping[idx] = cell_id
60
+
61
+ # Transform records to the desired format
62
+ transformed_dataset = []
63
+ for idx, record in enumerate(dataset):
64
+ transformed_record = {}
65
+ transformed_record["input_ids"] = torch.tensor(
66
+ record["input_ids"], dtype=torch.long
67
+ )
68
 
69
+ # Use index-based cell ID for internal tracking
70
+ transformed_record["cell_id"] = idx
71
 
72
+ if not is_test:
73
+ # Prepare labels
74
+ label_dict = {}
75
+ for task, column in task_to_column.items():
76
+ label_value = record[column]
77
+ label_index = task_label_mappings[task][label_value]
78
+ label_dict[task] = label_index
79
+ transformed_record["label"] = label_dict
80
+ else:
81
+ # Create dummy labels for test data
82
+ label_dict = {task: -1 for task in config["task_names"]}
83
+ transformed_record["label"] = label_dict
84
 
85
+ transformed_dataset.append(transformed_record)
86
 
87
+ return transformed_dataset, cell_id_mapping, num_labels_list
88
+ except KeyError as e:
89
+ print(f"Missing configuration or dataset key: {e}")
90
+ except Exception as e:
91
+ print(f"An error occurred while loading or preprocessing data: {e}")
92
+ return None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
 
 
95
  def preload_and_process_data(config):
96
+ # Load and preprocess data once
97
+ train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
98
+ config["train_path"], config, dataset_type="train"
99
+ )
100
+ val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(
101
+ config["val_path"], config, dataset_type="validation"
102
+ )
103
  return (
104
+ train_dataset,
105
+ train_cell_id_mapping,
106
+ val_dataset,
107
+ val_cell_id_mapping,
108
+ num_labels_list,
109
  )
110
 
111
 
112
+ def get_data_loader(preprocessed_dataset, batch_size):
113
+ nproc = os.cpu_count() ### I/O operations
114
+
115
+ data_collator = DataCollatorForMultitaskCellClassification()
116
+
117
+ loader = DataLoader(
118
+ preprocessed_dataset,
119
+ batch_size=batch_size,
120
+ shuffle=True,
121
+ collate_fn=data_collator,
122
+ num_workers=nproc,
123
+ pin_memory=True,
124
+ )
125
+ return loader
126
+
127
+
128
  def preload_data(config):
129
+ # Preprocessing the data before the Optuna trials start
130
+ train_loader = get_data_loader("train", config)
131
+ val_loader = get_data_loader("val", config)
132
+ return train_loader, val_loader
133
 
134
 
135
  def load_and_preprocess_test_data(config):
136
+ """
137
+ Load and preprocess test data, treating it as unlabeled.
138
+ """
139
+ return load_and_preprocess_data(config["test_path"], config, is_test=True)
 
 
 
 
 
 
 
 
 
140
 
141
 
142
  def prepare_test_loader(config):
143
+ """
144
+ Prepare DataLoader for the test dataset.
145
+ """
146
+ test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(
147
+ config
148
+ )
149
+ test_loader = get_data_loader(test_dataset, config["batch_size"])
150
+ return test_loader, cell_id_mapping, num_labels_list
geneformer/mtl/eval_utils.py CHANGED
@@ -1,16 +1,19 @@
1
- import os
2
- import json
3
- import torch
4
  import pandas as pd
5
 
6
- from .data import prepare_test_loader
 
7
  from .model import GeneformerMultiTask
8
 
 
9
  def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
10
  task_pred_labels = {task_name: [] for task_name in config["task_names"]}
11
  task_pred_probs = {task_name: [] for task_name in config["task_names"]}
12
  cell_ids = []
13
 
 
 
 
 
14
  model.eval()
15
  with torch.no_grad():
16
  for batch in test_loader:
@@ -82,4 +85,4 @@ def load_and_evaluate_test_model(config):
82
  best_model.to(device)
83
 
84
  evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
85
- print("Evaluation completed.")
 
 
 
 
1
  import pandas as pd
2
 
3
+ from .imports import * # noqa # isort:skip
4
+ from .data import prepare_test_loader # noqa # isort:skip
5
  from .model import GeneformerMultiTask
6
 
7
+
8
  def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
9
  task_pred_labels = {task_name: [] for task_name in config["task_names"]}
10
  task_pred_probs = {task_name: [] for task_name in config["task_names"]}
11
  cell_ids = []
12
 
13
+ # # Load task label mappings from pickle file
14
+ # with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
15
+ # task_label_mappings = pickle.load(f)
16
+
17
  model.eval()
18
  with torch.no_grad():
19
  for batch in test_loader:
 
85
  best_model.to(device)
86
 
87
  evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
88
+ print("Evaluation completed.")
geneformer/mtl/imports.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import gc
3
+ import json
4
+ import os
5
+ import pickle
6
+ import sys
7
+ import warnings
8
+ from enum import Enum
9
+ from itertools import chain
10
+ from typing import Dict, List, Optional, Union
11
+
12
+ import numpy as np
13
+ import optuna
14
+ import pandas as pd
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.optim as optim
19
+ from datasets import load_from_disk
20
+ from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
21
+ from sklearn.model_selection import train_test_split
22
+ from sklearn.preprocessing import LabelEncoder
23
+ from torch.utils.data import DataLoader
24
+ from transformers import (
25
+ AdamW,
26
+ BatchEncoding,
27
+ BertConfig,
28
+ BertModel,
29
+ DataCollatorForTokenClassification,
30
+ SpecialTokensMixin,
31
+ get_cosine_schedule_with_warmup,
32
+ get_linear_schedule_with_warmup,
33
+ get_scheduler,
34
+ )
35
+ from transformers.utils import logging, to_py_obj
36
+
37
+ from .collators import DataCollatorForMultitaskCellClassification
38
+
39
+ # local modules
40
+ from .data import get_data_loader, preload_and_process_data
41
+ from .model import GeneformerMultiTask
42
+ from .optuna_utils import create_optuna_study
43
+ from .utils import save_model
geneformer/mtl/model.py CHANGED
@@ -118,4 +118,4 @@ class GeneformerMultiTask(nn.Module):
118
  f"Error during loss computation for task {task_id}: {e}"
119
  )
120
 
121
- return total_loss, logits, losses if labels is not None else logits
 
118
  f"Error during loss computation for task {task_id}: {e}"
119
  )
120
 
121
+ return total_loss, logits, losses if labels is not None else logits
geneformer/mtl/optuna_utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import optuna
2
+ from optuna.integration import TensorBoardCallback
3
+
4
+
5
+ def save_trial_callback(study, trial, trials_result_path):
6
+ with open(trials_result_path, "a") as f:
7
+ f.write(
8
+ f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
9
+ )
10
+
11
+
12
+ def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
13
+ study = optuna.create_study(direction="maximize")
14
+
15
+ # init TensorBoard callback
16
+ tensorboard_callback = TensorBoardCallback(
17
+ dirname=tensorboard_log_dir, metric_name="F1 Macro"
18
+ )
19
+
20
+ # callback and TensorBoard callback
21
+ callbacks = [
22
+ lambda study, trial: save_trial_callback(study, trial, trials_result_path),
23
+ tensorboard_callback,
24
+ ]
25
+
26
+ study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
27
+ return study
geneformer/mtl/train.py CHANGED
@@ -1,707 +1,380 @@
1
  import os
 
 
 
2
  import pandas as pd
3
  import torch
4
- import torch.distributed as dist
5
- import torch.multiprocessing as mp
6
- from torch.nn.parallel import DistributedDataParallel as DDP
7
  from torch.utils.tensorboard import SummaryWriter
8
  from tqdm import tqdm
9
- import optuna
10
- import functools
11
- import time
12
 
 
13
  from .model import GeneformerMultiTask
14
- from .utils import (
15
- calculate_metrics,
16
- get_layer_freeze_range,
17
- set_seed,
18
- initialize_wandb,
19
- create_model,
20
- setup_optimizer_and_scheduler,
21
- save_model,
22
- save_hyperparameters,
23
- prepare_training_environment,
24
- log_training_step,
25
- log_validation_metrics,
26
- save_validation_predictions,
27
- setup_logging,
28
- setup_distributed_environment,
29
- train_distributed
30
- )
31
-
32
-
33
- class Trainer:
34
- """Trainer class for multi-task learning"""
35
-
36
- def __init__(self, config):
37
- self.config = config
38
- self.device = None
39
- self.model = None
40
- self.optimizer = None
41
- self.scheduler = None
42
- self.writer = None
43
- self.is_distributed = config.get("distributed_training", False)
44
- self.local_rank = config.get("local_rank", 0)
45
- self.is_main_process = not self.is_distributed or self.local_rank == 0
46
-
47
- def train_epoch(self, train_loader, epoch):
48
- """Train the model for one epoch."""
49
- epoch_start = time.time()
50
- self.model.train()
51
-
52
- # For distributed training, we need to be aware of the global batch count
53
- if self.is_distributed:
54
- # Get world size for reporting
55
- world_size = dist.get_world_size()
56
- # Calculate total batches across all GPUs
57
- total_batches_global = len(train_loader) * world_size if self.local_rank == 0 else len(train_loader)
58
- else:
59
- world_size = 1
60
- total_batches_global = len(train_loader)
61
-
62
- progress_bar = None
63
- if self.is_main_process:
64
- # Use the global batch count for progress reporting in distributed mode
65
- progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}",
66
- total=len(train_loader))
67
- iterator = progress_bar
68
-
69
- # Report distributed training information
70
- if self.is_distributed:
71
- print(f"Distributed training: {world_size} GPUs, {len(train_loader)} batches per GPU, "
72
- f"{total_batches_global} total batches globally")
73
- else:
74
- iterator = train_loader
75
-
76
- batch_times = []
77
- forward_times = []
78
- backward_times = []
79
- optimizer_times = []
80
-
81
- # Get gradient accumulation steps from config (default to 1 if not specified)
82
- accumulation_steps = self.config.get("gradient_accumulation_steps", 1)
83
-
84
- # Zero gradients at the beginning
85
- self.optimizer.zero_grad()
86
-
87
- # Track loss for the entire epoch
88
- total_loss = 0.0
89
- num_batches = 0
90
- accumulated_loss = 0.0
91
-
92
- for batch_idx, batch in enumerate(iterator):
93
- batch_start = time.time()
94
-
95
- input_ids = batch["input_ids"].to(self.device)
96
- attention_mask = batch["attention_mask"].to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  labels = [
98
- batch["labels"][task_name].to(self.device) for task_name in self.config["task_names"]
 
99
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- forward_start = time.time()
102
- loss, _, _ = self.model(input_ids, attention_mask, labels)
103
-
104
- # Scale loss by accumulation steps for gradient accumulation
105
- if accumulation_steps > 1:
106
- loss = loss / accumulation_steps
107
-
108
- forward_end = time.time()
109
- forward_times.append(forward_end - forward_start)
110
-
111
- # Track loss - store the unscaled loss for reporting
112
- unscaled_loss = loss.item() * (1 if accumulation_steps == 1 else accumulation_steps)
113
- total_loss += unscaled_loss
114
- num_batches += 1
115
- accumulated_loss += loss.item() # For gradient accumulation tracking
116
-
117
- backward_start = time.time()
118
-
119
- # Use no_sync() for all but the last accumulation step to avoid unnecessary communication
120
- if self.is_distributed and accumulation_steps > 1:
121
- # If this is not the last accumulation step or the last batch
122
- if (batch_idx + 1) % accumulation_steps != 0 and (batch_idx + 1) != len(train_loader):
123
- with self.model.no_sync():
124
- loss.backward()
125
- else:
126
- loss.backward()
127
- else:
128
- # Non-distributed training or accumulation_steps=1
129
- loss.backward()
130
-
131
- backward_end = time.time()
132
- backward_times.append(backward_end - backward_start)
133
-
134
- # Only update weights after accumulation_steps or at the end of the epoch
135
- if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
136
- if self.config["gradient_clipping"]:
137
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
138
-
139
- optimizer_start = time.time()
140
- self.optimizer.step()
141
- self.scheduler.step()
142
- self.optimizer.zero_grad()
143
- optimizer_end = time.time()
144
- optimizer_times.append(optimizer_end - optimizer_start)
145
-
146
- # Log after optimizer step
147
- if self.is_main_process:
148
- # Calculate running average loss
149
- avg_loss = total_loss / num_batches
150
-
151
- log_training_step(avg_loss, self.writer, self.config, epoch, len(train_loader), batch_idx)
152
-
153
- # Update progress bar with just the running average loss
154
- progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})
155
-
156
- accumulated_loss = 0.0
157
- else:
158
- optimizer_times.append(0) # No optimizer step taken
159
-
160
- batch_end = time.time()
161
- batch_times.append(batch_end - batch_start)
162
-
163
- epoch_end = time.time()
164
-
165
- # Calculate average loss for the epoch
166
- epoch_avg_loss = total_loss / num_batches
167
-
168
- # If distributed, gather losses from all processes to compute global average
169
- if self.is_distributed:
170
- # Create a tensor to hold the loss
171
- loss_tensor = torch.tensor([epoch_avg_loss], device=self.device)
172
- # Gather losses from all processes
173
- all_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
174
- dist.all_gather(all_losses, loss_tensor)
175
- # Compute the global average loss across all processes
176
- epoch_avg_loss = torch.mean(torch.stack(all_losses)).item()
177
-
178
- if self.is_main_process:
179
- # douhble check if batch_size has already been adjusted for world_size in the config
180
- # This avoids double-counting the effective batch size
181
- per_gpu_batch_size = self.config['batch_size']
182
- total_effective_batch = per_gpu_batch_size * accumulation_steps * world_size
183
-
184
- print(f"Epoch {epoch+1} timing:")
185
- print(f" Total epoch time: {epoch_end - epoch_start:.2f}s")
186
- print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
187
- print(f" Average forward time: {sum(forward_times)/len(forward_times):.4f}s")
188
- print(f" Average backward time: {sum(backward_times)/len(backward_times):.4f}s")
189
- print(f" Average optimizer time: {sum([t for t in optimizer_times if t > 0])/max(1, len([t for t in optimizer_times if t > 0])):.4f}s")
190
- print(f" Gradient accumulation steps: {accumulation_steps}")
191
- print(f" Batch size per GPU: {per_gpu_batch_size}")
192
- print(f" Effective global batch size: {total_effective_batch}")
193
- print(f" Average training loss: {epoch_avg_loss:.4f}")
194
- if self.is_distributed:
195
- print(f" Total batches processed across all GPUs: {total_batches_global}")
196
- print(f" Communication optimization: Using no_sync() for gradient accumulation")
197
-
198
- return epoch_avg_loss # Return the average loss for the epoch
199
-
200
- def validate_model(self, val_loader):
201
- val_start = time.time()
202
- self.model.eval()
203
- val_loss = 0.0
204
- task_true_labels = {task_name: [] for task_name in self.config["task_names"]}
205
- task_pred_labels = {task_name: [] for task_name in self.config["task_names"]}
206
- task_pred_probs = {task_name: [] for task_name in self.config["task_names"]}
207
-
208
- val_cell_ids = {}
209
- sample_counter = 0
210
-
211
- batch_times = []
212
-
213
- # Print validation dataset size
214
- if self.is_main_process:
215
- print(f"Validation dataset size: {len(val_loader.dataset)} samples")
216
- print(f"Number of validation batches: {len(val_loader)}")
217
-
218
- if self.is_distributed:
219
- world_size = dist.get_world_size()
220
- print(f"Distributed validation: {world_size} GPUs")
221
- if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
222
- samples_per_gpu = val_loader.sampler.num_samples
223
- print(f"Each GPU processes {samples_per_gpu} validation samples")
224
- print(f"Total validation samples processed: {samples_per_gpu * world_size}")
225
-
226
- with torch.no_grad():
227
- for batch in val_loader:
228
- batch_start = time.time()
229
- input_ids = batch["input_ids"].to(self.device)
230
- attention_mask = batch["attention_mask"].to(self.device)
231
- labels = [
232
- batch["labels"][task_name].to(self.device)
233
- for task_name in self.config["task_names"]
234
- ]
235
- loss, logits, _ = self.model(input_ids, attention_mask, labels)
236
- val_loss += loss.item()
237
-
238
- if "cell_id" in batch:
239
- for i, cell_id in enumerate(batch["cell_id"]):
240
- # Store the actual index for later mapping to unique_cell_id
241
- val_cell_ids[sample_counter + i] = cell_id.item()
242
-
243
- for sample_idx in range(len(batch["input_ids"])):
244
- for i, task_name in enumerate(self.config["task_names"]):
245
- true_label = batch["labels"][task_name][sample_idx].item()
246
- pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
247
- # Store the full probability distribution
248
- pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy().tolist()
249
- task_true_labels[task_name].append(true_label)
250
- task_pred_labels[task_name].append(pred_label)
251
- task_pred_probs[task_name].append(pred_prob)
252
-
253
- # Update current index for cell ID tracking
254
- sample_counter += len(batch["input_ids"])
255
-
256
- batch_end = time.time()
257
- batch_times.append(batch_end - batch_start)
258
-
259
- # norm validation loss by the number of batches
260
- val_loss /= len(val_loader)
261
-
262
- # distributed, gather results from all processes
263
- if self.is_distributed:
264
- # Create tensors to hold the local results
265
- loss_tensor = torch.tensor([val_loss], device=self.device)
266
- gathered_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
267
- dist.all_gather(gathered_losses, loss_tensor)
268
-
269
- # Compute average loss across all processes
270
- val_loss = torch.mean(torch.cat(gathered_losses)).item()
271
-
272
- world_size = dist.get_world_size()
273
-
274
- if self.is_main_process:
275
- print(f"Collected predictions from rank {self.local_rank}")
276
- print(f"Number of samples processed by this rank: {sample_counter}")
277
-
278
- val_end = time.time()
279
-
280
- if self.is_main_process:
281
- print(f"Validation timing:")
282
- print(f" Total validation time: {val_end - val_start:.2f}s")
283
- print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
284
- print(f" Collected {len(val_cell_ids)} cell indices from validation data")
285
- print(f" Processed {sample_counter} total samples during validation")
286
-
287
- # Print number of samples per task
288
- for task_name in self.config["task_names"]:
289
- print(f" Task {task_name}: {len(task_true_labels[task_name])} samples")
290
-
291
- return val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids
292
-
293
- def train(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
294
- """Train the model and return validation loss and trained model."""
295
- if self.config.get("use_wandb", False) and self.is_main_process:
296
- initialize_wandb(self.config)
297
-
298
- # Create model
299
- self.model = create_model(self.config, num_labels_list, self.device, self.is_distributed, self.local_rank)
300
-
301
- # Setup optimizer and scheduler
302
- total_steps = len(train_loader) * self.config["epochs"]
303
- self.optimizer, self.scheduler = setup_optimizer_and_scheduler(self.model, self.config, total_steps)
304
-
305
- # Training loop
306
- if self.is_main_process:
307
- epoch_progress = tqdm(range(self.config["epochs"]), desc="Training Progress")
308
- else:
309
- epoch_progress = range(self.config["epochs"])
310
-
311
- best_val_loss = float('inf')
312
- train_losses = []
313
-
314
- with setup_logging(self.config) as self.writer:
315
- for epoch in epoch_progress:
316
- if self.is_distributed:
317
- train_loader.sampler.set_epoch(epoch)
318
-
319
- train_loss = self.train_epoch(train_loader, epoch)
320
- train_losses.append(train_loss)
321
-
322
- # Run validation after each epoch if configured to do so
323
- if self.config.get("validate_each_epoch", False):
324
- val_loss, _, _, _, _ = self.validate_model(val_loader)
325
- if val_loss < best_val_loss:
326
- best_val_loss = val_loss
327
-
328
- if self.is_main_process:
329
- epoch_progress.set_postfix({
330
- "train_loss": f"{train_loss:.4f}",
331
- "val_loss": f"{val_loss:.4f}",
332
- "best_val_loss": f"{best_val_loss:.4f}"
333
- })
334
- else:
335
- if self.is_main_process:
336
- epoch_progress.set_postfix({
337
- "train_loss": f"{train_loss:.4f}"
338
- })
339
-
340
- val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = self.validate_model(val_loader)
341
- task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
342
-
343
- if self.is_main_process:
344
- log_validation_metrics(task_metrics, val_loss, self.config, self.writer, self.config["epochs"])
345
-
346
- # Save validation predictions
347
- save_validation_predictions(
348
- val_cell_ids,
349
- task_true_labels,
350
- task_pred_labels,
351
- task_pred_probs,
352
- {**self.config, "val_cell_mapping": val_cell_id_mapping} # Include the mapping
353
- )
354
-
355
- if self.config.get("use_wandb", False):
356
- import wandb
357
- wandb.finish()
358
-
359
- print(f"\nTraining Summary:")
360
- print(f" Final Training Loss: {train_losses[-1]:.4f}")
361
- print(f" Final Validation Loss: {val_loss:.4f}")
362
- for task_name, metrics in task_metrics.items():
363
- print(f" {task_name} - F1: {metrics['f1']:.4f}, Accuracy: {metrics['accuracy']:.4f}")
364
-
365
- return val_loss, self.model # Return both the validation loss and the trained model
366
-
367
- def setup(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
368
- if self.is_distributed:
369
- self.device = torch.device(f"cuda:{self.local_rank}")
370
- else:
371
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
372
-
373
- self.model = create_model(self.config, num_labels_list, self.device)
374
-
375
- # war model w DDP
376
- if self.is_distributed:
377
- self.model = DDP(self.model, device_ids=[self.local_rank])
378
-
379
- # communication hook to optimize gradient synchronization
380
- from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
381
-
382
- # default hook which maintains full precision
383
- self.model.register_comm_hook(
384
- state=None,
385
- hook=comm_hooks.allreduce_hook
386
  )
387
-
388
- print(f"Rank {self.local_rank}: Registered communication hook for optimized gradient synchronization")
389
-
390
- print(f"Rank {self.local_rank}: Using samplers created in distributed worker")
391
- print(f"Rank {self.local_rank}: Training dataset has {len(train_loader.dataset)} samples")
392
- if hasattr(train_loader, 'sampler') and hasattr(train_loader.sampler, 'num_samples'):
393
- print(f"Rank {self.local_rank}: This GPU will process {train_loader.sampler.num_samples} training samples per epoch")
394
-
395
- if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
396
- print(f"Rank {self.local_rank}: This GPU will process {val_loader.sampler.num_samples} validation samples")
397
-
398
- # Set up optimizer and scheduler
399
- self.optimizer, self.scheduler = setup_optimizer_and_scheduler(
400
- self.model, self.config, len(train_loader)
401
  )
402
 
403
- if self.is_main_process and self.config.get("use_wandb", False):
404
- initialize_wandb(self.config)
405
-
406
- return train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
407
-
408
-
409
- def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
410
- """Train a model with the given configuration and data."""
411
- # Check if distributed training is enabled
412
- if config.get("distributed_training", False):
413
- # Check if we have multiple GPUs
414
- if torch.cuda.device_count() > 1:
415
- result = train_distributed(
416
- Trainer,
417
- config,
418
- train_loader,
419
- val_loader,
420
- train_cell_id_mapping,
421
- val_cell_id_mapping,
422
- num_labels_list
 
 
 
 
423
  )
424
- if result is not None:
425
- return result
426
- else:
427
- print("Distributed training requested but only one GPU found. Falling back to single GPU training.")
428
- config["distributed_training"] = False
429
-
430
- # Non-distributed training
431
- trainer = Trainer(config)
432
- trainer.device = device
433
- return trainer.train(train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list)
434
 
 
 
 
435
 
436
- def objective(
437
- trial,
 
 
438
  train_loader,
439
  val_loader,
440
  train_cell_id_mapping,
441
  val_cell_id_mapping,
442
  num_labels_list,
443
- config,
444
- device,
445
  ):
446
- """Objective function for Optuna hyperparameter optimization."""
447
  set_seed(config["seed"])
448
  initialize_wandb(config)
449
 
450
- trial_config = config.copy()
451
-
452
- # Suggest hyperparameters for this trial
453
- for param_name, param_config in config["hyperparameters"].items():
454
- if param_name == "lr_scheduler_type":
455
- trial_config[param_name] = trial.suggest_categorical(
456
- param_name, param_config["choices"]
457
- )
458
- elif param_name == "task_weights" and config["use_task_weights"]:
459
- weights = [
460
- trial.suggest_float(
461
- f"task_weight_{i}",
462
- param_config["low"],
463
- param_config["high"],
464
- )
465
- for i in range(len(num_labels_list))
466
- ]
467
- weight_sum = sum(weights)
468
- trial_config[param_name] = [w / weight_sum for w in weights]
469
- elif "log" in param_config and param_config["log"]:
470
- trial_config[param_name] = trial.suggest_float(
471
- param_name, param_config["low"], param_config["high"], log=True
472
- )
473
- else:
474
- trial_config[param_name] = trial.suggest_float(
475
- param_name, param_config["low"], param_config["high"]
476
- )
477
-
478
- # Set appropriate max layers to freeze based on pretrained model
479
- if "max_layers_to_freeze" in trial_config:
480
- freeze_range = get_layer_freeze_range(trial_config["pretrained_path"])
481
- trial_config["max_layers_to_freeze"] = int(trial.suggest_int(
482
- "max_layers_to_freeze",
483
- freeze_range["min"],
484
- freeze_range["max"]
485
- ))
486
-
487
- trial_config["run_name"] = f"trial_{trial.number}"
488
-
489
- # Handle distributed training for this trial
490
- if trial_config.get("distributed_training", False) and torch.cuda.device_count() > 1:
491
- manager = mp.Manager()
492
- shared_dict = manager.dict()
493
-
494
- train_distributed(
495
- Trainer,
496
- trial_config,
497
- train_loader,
498
- val_loader,
499
- train_cell_id_mapping,
500
- val_cell_id_mapping,
501
- num_labels_list,
502
- trial.number,
503
- shared_dict
504
- )
505
-
506
- val_loss = shared_dict.get('val_loss', float('inf'))
507
- task_metrics = shared_dict.get('task_metrics', {})
508
-
509
- trial.set_user_attr("model_state_dict", shared_dict.get('model_state_dict', {}))
510
- trial.set_user_attr("task_weights", trial_config["task_weights"])
511
-
512
- if config.get("use_wandb", False):
513
- import wandb
514
- wandb.log({
515
- "trial_number": trial.number,
516
- "val_loss": val_loss,
517
- **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
518
- **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
519
- })
520
- wandb.finish()
521
-
522
- return val_loss
523
-
524
- with setup_logging(trial_config) as writer:
525
- trainer = Trainer(trial_config)
526
- trainer.device = device
527
- trainer.writer = writer
528
-
529
- # Create model with trial hyperparameters
530
- trainer.model = create_model(trial_config, num_labels_list, device)
531
- total_steps = len(train_loader) * config["epochs"]
532
- trainer.optimizer, trainer.scheduler = setup_optimizer_and_scheduler(trainer.model, trial_config, total_steps)
533
-
534
- # Training loop
535
- for epoch in range(config["epochs"]):
536
- trainer.train_epoch(train_loader, epoch)
537
-
538
- val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = trainer.validate_model(val_loader)
539
- task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
540
-
541
- # Log metrics
542
- log_validation_metrics(task_metrics, val_loss, trial_config, writer, config["epochs"])
543
-
544
- # Save validation predictions
545
- save_validation_predictions(
546
- val_cell_ids,
547
- task_true_labels,
548
- task_pred_labels,
549
- task_pred_probs,
550
- {**trial_config, "val_cell_mapping": val_cell_id_mapping},
551
- trial.number,
552
  )
 
553
 
554
- # Store model state dict and task weights in trial user attributes
555
- trial.set_user_attr("model_state_dict", trainer.model.state_dict())
556
- trial.set_user_attr("task_weights", trial_config["task_weights"])
 
557
 
558
- # Report intermediate value to Optuna
559
- trial.report(val_loss, config["epochs"])
560
- if trial.should_prune():
561
- raise optuna.TrialPruned()
562
 
563
- if config.get("use_wandb", False):
564
- import wandb
565
- wandb.log(
566
- {
567
- "trial_number": trial.number,
568
- "val_loss": val_loss,
569
- **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
570
- **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
571
- **{k: v for k, v in trial_config.items() if k in [
572
- "learning_rate", "warmup_ratio", "weight_decay", "dropout_rate",
573
- "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze"
574
- ]},
575
- }
576
- )
577
- wandb.finish()
578
 
579
- return val_loss
 
580
 
 
581
 
582
- def run_manual_tuning(config):
583
- """Run training with manually specified hyperparameters."""
584
- (
585
- device,
586
- train_loader,
587
- val_loader,
588
- train_cell_id_mapping,
589
- val_cell_id_mapping,
590
- num_labels_list,
591
- ) = prepare_training_environment(config)
592
 
593
- print("\nManual hyperparameters being used:")
594
- for key, value in config["manual_hyperparameters"].items():
595
- print(f"{key}: {value}")
596
- print()
597
 
598
- # Update config with manual hyperparameters
599
- for key, value in config["manual_hyperparameters"].items():
600
- config[key] = value
 
 
 
 
 
 
 
 
 
601
 
602
- # Train the model
603
- val_loss, trained_model = train_model(
604
- config,
605
- device,
606
- train_loader,
607
- val_loader,
608
- train_cell_id_mapping,
609
- val_cell_id_mapping,
610
- num_labels_list,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611
  )
612
 
613
- print(f"\nValidation loss with manual hyperparameters: {val_loss}")
614
-
615
- # Save the trained model - only if not using distributed training
616
- # (distributed training saves the model in the worker)
617
- if not config.get("distributed_training", False):
618
- model_save_directory = os.path.join(
619
- config["model_save_path"], "GeneformerMultiTask"
620
- )
621
- save_model(trained_model, model_save_directory)
622
-
623
- # Save the hyperparameters
624
- hyperparams_to_save = {
625
- **config["manual_hyperparameters"],
626
- "dropout_rate": config["dropout_rate"],
627
- "use_task_weights": config["use_task_weights"],
628
- "task_weights": config["task_weights"],
629
- "max_layers_to_freeze": config["max_layers_to_freeze"],
630
- "use_attention_pooling": config["use_attention_pooling"],
631
- }
632
- save_hyperparameters(model_save_directory, hyperparams_to_save)
 
 
 
633
 
634
- return val_loss
 
 
635
 
 
 
636
 
637
- def run_optuna_study(config):
638
- """Run hyperparameter optimization using Optuna."""
639
- # Prepare training environment
640
- (
641
- device,
642
- train_loader,
643
- val_loader,
644
- train_cell_id_mapping,
645
- val_cell_id_mapping,
646
- num_labels_list,
647
- ) = prepare_training_environment(config)
648
-
649
- # If manual hyperparameters are specified, use them instead of running Optuna
650
- if config.get("use_manual_hyperparameters", False):
651
- return run_manual_tuning(config)
652
-
653
- # Create a partial function with fixed arguments for the objective
654
- objective_with_config_and_data = functools.partial(
655
- objective,
656
- train_loader=train_loader,
657
- val_loader=val_loader,
658
- train_cell_id_mapping=train_cell_id_mapping,
659
- val_cell_id_mapping=val_cell_id_mapping,
660
- num_labels_list=num_labels_list,
661
- config=config,
662
- device=device,
663
- )
664
 
665
- # Create and run the Optuna study
666
- study = optuna.create_study(
667
- direction="minimize", # Minimize validation loss
668
- study_name=config["study_name"],
669
- # storage=config["storage"],
670
- load_if_exists=True,
671
  )
 
672
 
673
- study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])
 
674
 
675
- # After finding the best trial
676
- best_params = study.best_trial.params
677
- best_task_weights = study.best_trial.user_attrs["task_weights"]
678
- print("Saving the best model and its hyperparameters...")
679
-
680
- # Create a model with the best hyperparameters
681
- best_model = GeneformerMultiTask(
682
- config["pretrained_path"],
683
- num_labels_list,
684
- dropout_rate=best_params["dropout_rate"],
685
- use_task_weights=config["use_task_weights"],
686
- task_weights=best_task_weights,
687
- max_layers_to_freeze=best_params.get("max_layers_to_freeze", 0),
688
- use_attention_pooling=best_params.get("use_attention_pooling", False),
689
  )
690
 
691
- # Get the best model state dictionary
692
- best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
693
 
694
- best_model_state_dict = {
695
- k.replace("module.", ""): v for k, v in best_model_state_dict.items()
696
- }
697
 
698
- best_model.load_state_dict(best_model_state_dict, strict=False)
 
699
 
700
- model_save_directory = os.path.join(
701
- config["model_save_path"], "GeneformerMultiTask"
702
- )
703
- save_model(best_model, model_save_directory)
704
 
705
- save_hyperparameters(model_save_directory, {**best_params, "task_weights": best_task_weights})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
 
707
- return study.best_trial.value # Return the best validation loss
 
1
  import os
2
+ import random
3
+
4
+ import numpy as np
5
  import pandas as pd
6
  import torch
 
 
 
7
  from torch.utils.tensorboard import SummaryWriter
8
  from tqdm import tqdm
 
 
 
9
 
10
+ from .imports import *
11
  from .model import GeneformerMultiTask
12
+ from .utils import calculate_task_specific_metrics, get_layer_freeze_range
13
+
14
+
15
+ def set_seed(seed):
16
+ random.seed(seed)
17
+ np.random.seed(seed)
18
+ torch.manual_seed(seed)
19
+ torch.cuda.manual_seed_all(seed)
20
+ torch.backends.cudnn.deterministic = True
21
+ torch.backends.cudnn.benchmark = False
22
+
23
+
24
+ def initialize_wandb(config):
25
+ if config.get("use_wandb", False):
26
+ import wandb
27
+
28
+ wandb.init(project=config["wandb_project"], config=config)
29
+ print("Weights & Biases (wandb) initialized and will be used for logging.")
30
+ else:
31
+ print(
32
+ "Weights & Biases (wandb) is not enabled. Logging will use other methods."
33
+ )
34
+
35
+
36
+ def create_model(config, num_labels_list, device):
37
+ model = GeneformerMultiTask(
38
+ config["pretrained_path"],
39
+ num_labels_list,
40
+ dropout_rate=config["dropout_rate"],
41
+ use_task_weights=config["use_task_weights"],
42
+ task_weights=config["task_weights"],
43
+ max_layers_to_freeze=config["max_layers_to_freeze"],
44
+ use_attention_pooling=config["use_attention_pooling"],
45
+ )
46
+ if config["use_data_parallel"]:
47
+ model = nn.DataParallel(model)
48
+ return model.to(device)
49
+
50
+
51
+ def setup_optimizer_and_scheduler(model, config, total_steps):
52
+ optimizer = AdamW(
53
+ model.parameters(),
54
+ lr=config["learning_rate"],
55
+ weight_decay=config["weight_decay"],
56
+ )
57
+ warmup_steps = int(config["warmup_ratio"] * total_steps)
58
+
59
+ if config["lr_scheduler_type"] == "linear":
60
+ scheduler = get_linear_schedule_with_warmup(
61
+ optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
62
+ )
63
+ elif config["lr_scheduler_type"] == "cosine":
64
+ scheduler = get_cosine_schedule_with_warmup(
65
+ optimizer,
66
+ num_warmup_steps=warmup_steps,
67
+ num_training_steps=total_steps,
68
+ num_cycles=0.5,
69
+ )
70
+
71
+ return optimizer, scheduler
72
+
73
+
74
+ def train_epoch(
75
+ model, train_loader, optimizer, scheduler, device, config, writer, epoch
76
+ ):
77
+ model.train()
78
+ progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
79
+ for batch_idx, batch in enumerate(progress_bar):
80
+ optimizer.zero_grad()
81
+ input_ids = batch["input_ids"].to(device)
82
+ attention_mask = batch["attention_mask"].to(device)
83
+ labels = [
84
+ batch["labels"][task_name].to(device) for task_name in config["task_names"]
85
+ ]
86
+
87
+ loss, _, _ = model(input_ids, attention_mask, labels)
88
+ loss.backward()
89
+
90
+ if config["gradient_clipping"]:
91
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
92
+
93
+ optimizer.step()
94
+ scheduler.step()
95
+
96
+ writer.add_scalar(
97
+ "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
98
+ )
99
+ if config.get("use_wandb", False):
100
+ import wandb
101
+
102
+ wandb.log({"Training Loss": loss.item()})
103
+
104
+ # Update progress bar
105
+ progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
106
+
107
+ return loss.item() # Return the last batch loss
108
+
109
+
110
+ def validate_model(model, val_loader, device, config):
111
+ model.eval()
112
+ val_loss = 0.0
113
+ task_true_labels = {task_name: [] for task_name in config["task_names"]}
114
+ task_pred_labels = {task_name: [] for task_name in config["task_names"]}
115
+ task_pred_probs = {task_name: [] for task_name in config["task_names"]}
116
+
117
+ with torch.no_grad():
118
+ for batch in val_loader:
119
+ input_ids = batch["input_ids"].to(device)
120
+ attention_mask = batch["attention_mask"].to(device)
121
  labels = [
122
+ batch["labels"][task_name].to(device)
123
+ for task_name in config["task_names"]
124
  ]
125
+ loss, logits, _ = model(input_ids, attention_mask, labels)
126
+ val_loss += loss.item()
127
+
128
+ for sample_idx in range(len(batch["input_ids"])):
129
+ for i, task_name in enumerate(config["task_names"]):
130
+ true_label = batch["labels"][task_name][sample_idx].item()
131
+ pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
132
+ pred_prob = (
133
+ torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
134
+ )
135
+ task_true_labels[task_name].append(true_label)
136
+ task_pred_labels[task_name].append(pred_label)
137
+ task_pred_probs[task_name].append(pred_prob)
138
+
139
+ val_loss /= len(val_loader)
140
+ return val_loss, task_true_labels, task_pred_labels, task_pred_probs
141
+
142
+
143
+ def log_metrics(task_metrics, val_loss, config, writer, epochs):
144
+ for task_name, metrics in task_metrics.items():
145
+ print(
146
+ f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
147
+ )
148
+ if config.get("use_wandb", False):
149
+ import wandb
150
 
151
+ wandb.log(
152
+ {
153
+ f"{task_name} Validation F1 Macro": metrics["f1"],
154
+ f"{task_name} Validation Accuracy": metrics["accuracy"],
155
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  )
157
+
158
+ writer.add_scalar("Validation Loss", val_loss, epochs)
159
+ for task_name, metrics in task_metrics.items():
160
+ writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
161
+ writer.add_scalar(
162
+ f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
 
 
 
 
 
 
 
 
163
  )
164
 
165
+
166
+ def save_validation_predictions(
167
+ val_cell_id_mapping,
168
+ task_true_labels,
169
+ task_pred_labels,
170
+ task_pred_probs,
171
+ config,
172
+ trial_number=None,
173
+ ):
174
+ if trial_number is not None:
175
+ trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
176
+ os.makedirs(trial_results_dir, exist_ok=True)
177
+ val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
178
+ else:
179
+ val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
180
+
181
+ rows = []
182
+ for sample_idx in range(len(val_cell_id_mapping)):
183
+ row = {"Cell ID": val_cell_id_mapping[sample_idx]}
184
+ for task_name in config["task_names"]:
185
+ row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
186
+ row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
187
+ row[f"{task_name} Probabilities"] = ",".join(
188
+ map(str, task_pred_probs[task_name][sample_idx])
189
  )
190
+ rows.append(row)
 
 
 
 
 
 
 
 
 
191
 
192
+ df = pd.DataFrame(rows)
193
+ df.to_csv(val_preds_file, index=False)
194
+ print(f"Validation predictions saved to {val_preds_file}")
195
 
196
+
197
+ def train_model(
198
+ config,
199
+ device,
200
  train_loader,
201
  val_loader,
202
  train_cell_id_mapping,
203
  val_cell_id_mapping,
204
  num_labels_list,
 
 
205
  ):
 
206
  set_seed(config["seed"])
207
  initialize_wandb(config)
208
 
209
+ model = create_model(config, num_labels_list, device)
210
+ total_steps = len(train_loader) * config["epochs"]
211
+ optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
212
+
213
+ log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
214
+ writer = SummaryWriter(log_dir=log_dir)
215
+
216
+ epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
217
+ for epoch in epoch_progress:
218
+ last_loss = train_epoch(
219
+ model, train_loader, optimizer, scheduler, device, config, writer, epoch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  )
221
+ epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
222
 
223
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
224
+ model, val_loader, device, config
225
+ )
226
+ task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
227
 
228
+ log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
229
+ writer.close()
 
 
230
 
231
+ save_validation_predictions(
232
+ val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
233
+ )
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
+ if config.get("use_wandb", False):
236
+ import wandb
237
 
238
+ wandb.finish()
239
 
240
+ print(f"\nFinal Validation Loss: {val_loss:.4f}")
241
+ return val_loss, model # Return both the validation loss and the trained model
 
 
 
 
 
 
 
 
242
 
 
 
 
 
243
 
244
+ def objective(
245
+ trial,
246
+ train_loader,
247
+ val_loader,
248
+ train_cell_id_mapping,
249
+ val_cell_id_mapping,
250
+ num_labels_list,
251
+ config,
252
+ device,
253
+ ):
254
+ set_seed(config["seed"]) # Set the seed before each trial
255
+ initialize_wandb(config)
256
 
257
+ # Hyperparameters
258
+ config["learning_rate"] = trial.suggest_float(
259
+ "learning_rate",
260
+ config["hyperparameters"]["learning_rate"]["low"],
261
+ config["hyperparameters"]["learning_rate"]["high"],
262
+ log=config["hyperparameters"]["learning_rate"]["log"],
263
+ )
264
+ config["warmup_ratio"] = trial.suggest_float(
265
+ "warmup_ratio",
266
+ config["hyperparameters"]["warmup_ratio"]["low"],
267
+ config["hyperparameters"]["warmup_ratio"]["high"],
268
+ )
269
+ config["weight_decay"] = trial.suggest_float(
270
+ "weight_decay",
271
+ config["hyperparameters"]["weight_decay"]["low"],
272
+ config["hyperparameters"]["weight_decay"]["high"],
273
+ )
274
+ config["dropout_rate"] = trial.suggest_float(
275
+ "dropout_rate",
276
+ config["hyperparameters"]["dropout_rate"]["low"],
277
+ config["hyperparameters"]["dropout_rate"]["high"],
278
+ )
279
+ config["lr_scheduler_type"] = trial.suggest_categorical(
280
+ "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
281
+ )
282
+ config["use_attention_pooling"] = trial.suggest_categorical(
283
+ "use_attention_pooling", [False]
284
  )
285
 
286
+ if config["use_task_weights"]:
287
+ config["task_weights"] = [
288
+ trial.suggest_float(
289
+ f"task_weight_{i}",
290
+ config["hyperparameters"]["task_weights"]["low"],
291
+ config["hyperparameters"]["task_weights"]["high"],
292
+ )
293
+ for i in range(len(num_labels_list))
294
+ ]
295
+ weight_sum = sum(config["task_weights"])
296
+ config["task_weights"] = [
297
+ weight / weight_sum for weight in config["task_weights"]
298
+ ]
299
+ else:
300
+ config["task_weights"] = None
301
+
302
+ # Dynamic range for max_layers_to_freeze
303
+ freeze_range = get_layer_freeze_range(config["pretrained_path"])
304
+ config["max_layers_to_freeze"] = trial.suggest_int(
305
+ "max_layers_to_freeze",
306
+ freeze_range["min"],
307
+ freeze_range["max"]
308
+ )
309
 
310
+ model = create_model(config, num_labels_list, device)
311
+ total_steps = len(train_loader) * config["epochs"]
312
+ optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
313
 
314
+ log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
315
+ writer = SummaryWriter(log_dir=log_dir)
316
 
317
+ for epoch in range(config["epochs"]):
318
+ train_epoch(
319
+ model, train_loader, optimizer, scheduler, device, config, writer, epoch
320
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
323
+ model, val_loader, device, config
 
 
 
 
324
  )
325
+ task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
326
 
327
+ log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
328
+ writer.close()
329
 
330
+ save_validation_predictions(
331
+ val_cell_id_mapping,
332
+ task_true_labels,
333
+ task_pred_labels,
334
+ task_pred_probs,
335
+ config,
336
+ trial.number,
 
 
 
 
 
 
 
337
  )
338
 
339
+ trial.set_user_attr("model_state_dict", model.state_dict())
340
+ trial.set_user_attr("task_weights", config["task_weights"])
341
 
342
+ trial.report(val_loss, config["epochs"])
 
 
343
 
344
+ if trial.should_prune():
345
+ raise optuna.TrialPruned()
346
 
347
+ if config.get("use_wandb", False):
348
+ import wandb
 
 
349
 
350
+ wandb.log(
351
+ {
352
+ "trial_number": trial.number,
353
+ "val_loss": val_loss,
354
+ **{
355
+ f"{task_name}_f1": metrics["f1"]
356
+ for task_name, metrics in task_metrics.items()
357
+ },
358
+ **{
359
+ f"{task_name}_accuracy": metrics["accuracy"]
360
+ for task_name, metrics in task_metrics.items()
361
+ },
362
+ **{
363
+ k: v
364
+ for k, v in config.items()
365
+ if k
366
+ in [
367
+ "learning_rate",
368
+ "warmup_ratio",
369
+ "weight_decay",
370
+ "dropout_rate",
371
+ "lr_scheduler_type",
372
+ "use_attention_pooling",
373
+ "max_layers_to_freeze",
374
+ ]
375
+ },
376
+ }
377
+ )
378
+ wandb.finish()
379
 
380
+ return val_loss
geneformer/mtl/train_utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from .data import get_data_loader, preload_and_process_data
4
+ from .imports import *
5
+ from .model import GeneformerMultiTask
6
+ from .train import objective, train_model
7
+ from .utils import save_model
8
+
9
+
10
+ def set_seed(seed):
11
+ random.seed(seed)
12
+ np.random.seed(seed)
13
+ torch.manual_seed(seed)
14
+ torch.cuda.manual_seed_all(seed)
15
+ torch.backends.cudnn.deterministic = True
16
+ torch.backends.cudnn.benchmark = False
17
+
18
+
19
+ def run_manual_tuning(config):
20
+ # Set seed for reproducibility
21
+ set_seed(config["seed"])
22
+
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ (
25
+ train_dataset,
26
+ train_cell_id_mapping,
27
+ val_dataset,
28
+ val_cell_id_mapping,
29
+ num_labels_list,
30
+ ) = preload_and_process_data(config)
31
+ train_loader = get_data_loader(train_dataset, config["batch_size"])
32
+ val_loader = get_data_loader(val_dataset, config["batch_size"])
33
+
34
+ # Print the manual hyperparameters being used
35
+ print("\nManual hyperparameters being used:")
36
+ for key, value in config["manual_hyperparameters"].items():
37
+ print(f"{key}: {value}")
38
+ print() # Add an empty line for better readability
39
+
40
+ # Use the manual hyperparameters
41
+ for key, value in config["manual_hyperparameters"].items():
42
+ config[key] = value
43
+
44
+ # Train the model
45
+ val_loss, trained_model = train_model(
46
+ config,
47
+ device,
48
+ train_loader,
49
+ val_loader,
50
+ train_cell_id_mapping,
51
+ val_cell_id_mapping,
52
+ num_labels_list,
53
+ )
54
+
55
+ print(f"\nValidation loss with manual hyperparameters: {val_loss}")
56
+
57
+ # Save the trained model
58
+ model_save_directory = os.path.join(
59
+ config["model_save_path"], "GeneformerMultiTask"
60
+ )
61
+ save_model(trained_model, model_save_directory)
62
+
63
+ # Save the hyperparameters
64
+ hyperparams_to_save = {
65
+ **config["manual_hyperparameters"],
66
+ "dropout_rate": config["dropout_rate"],
67
+ "use_task_weights": config["use_task_weights"],
68
+ "task_weights": config["task_weights"],
69
+ "max_layers_to_freeze": config["max_layers_to_freeze"],
70
+ "use_attention_pooling": config["use_attention_pooling"],
71
+ }
72
+ hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
73
+ with open(hyperparams_path, "w") as f:
74
+ json.dump(hyperparams_to_save, f)
75
+ print(f"Manual hyperparameters saved to {hyperparams_path}")
76
+
77
+ return val_loss
78
+
79
+
80
+ def run_optuna_study(config):
81
+ # Set seed for reproducibility
82
+ set_seed(config["seed"])
83
+
84
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
+ (
86
+ train_dataset,
87
+ train_cell_id_mapping,
88
+ val_dataset,
89
+ val_cell_id_mapping,
90
+ num_labels_list,
91
+ ) = preload_and_process_data(config)
92
+ train_loader = get_data_loader(train_dataset, config["batch_size"])
93
+ val_loader = get_data_loader(val_dataset, config["batch_size"])
94
+
95
+ if config["use_manual_hyperparameters"]:
96
+ train_model(
97
+ config,
98
+ device,
99
+ train_loader,
100
+ val_loader,
101
+ train_cell_id_mapping,
102
+ val_cell_id_mapping,
103
+ num_labels_list,
104
+ )
105
+ else:
106
+ objective_with_config_and_data = functools.partial(
107
+ objective,
108
+ train_loader=train_loader,
109
+ val_loader=val_loader,
110
+ train_cell_id_mapping=train_cell_id_mapping,
111
+ val_cell_id_mapping=val_cell_id_mapping,
112
+ num_labels_list=num_labels_list,
113
+ config=config,
114
+ device=device,
115
+ )
116
+
117
+ study = optuna.create_study(
118
+ direction="minimize", # Minimize validation loss
119
+ study_name=config["study_name"],
120
+ # storage=config["storage"],
121
+ load_if_exists=True,
122
+ )
123
+
124
+ study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])
125
+
126
+ # After finding the best trial
127
+ best_params = study.best_trial.params
128
+ best_task_weights = study.best_trial.user_attrs["task_weights"]
129
+ print("Saving the best model and its hyperparameters...")
130
+
131
+ # Saving model as before
132
+ best_model = GeneformerMultiTask(
133
+ config["pretrained_path"],
134
+ num_labels_list,
135
+ dropout_rate=best_params["dropout_rate"],
136
+ use_task_weights=config["use_task_weights"],
137
+ task_weights=best_task_weights,
138
+ )
139
+
140
+ # Get the best model state dictionary
141
+ best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
142
+
143
+ # Remove the "module." prefix from the state dictionary keys if present
144
+ best_model_state_dict = {
145
+ k.replace("module.", ""): v for k, v in best_model_state_dict.items()
146
+ }
147
+
148
+ # Load the modified state dictionary into the model, skipping unexpected keys
149
+ best_model.load_state_dict(best_model_state_dict, strict=False)
150
+
151
+ model_save_directory = os.path.join(
152
+ config["model_save_path"], "GeneformerMultiTask"
153
+ )
154
+ save_model(best_model, model_save_directory)
155
+
156
+ # Additionally, save the best hyperparameters and task weights
157
+ hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
158
+
159
+ with open(hyperparams_path, "w") as f:
160
+ json.dump({**best_params, "task_weights": best_task_weights}, f)
161
+ print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
geneformer/mtl/utils.py CHANGED
@@ -1,641 +1,129 @@
1
- from typing import Dict, List, Optional, Union
2
- import json
3
  import os
4
- import pickle
5
- import random
6
- import torch
7
- import numpy as np
8
- import optuna
9
  from sklearn.metrics import accuracy_score, f1_score
10
  from sklearn.preprocessing import LabelEncoder
11
- from torch.utils.tensorboard import SummaryWriter
12
- from transformers import AutoConfig, BertConfig, BertModel, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
13
- from torch.optim import AdamW
14
- import pandas as pd
15
- import torch.distributed as dist
16
- from torch.nn.parallel import DistributedDataParallel as DDP
17
- import torch.multiprocessing as mp
18
- from contextlib import contextmanager
19
-
20
-
21
- def set_seed(seed):
22
- random.seed(seed)
23
- np.random.seed(seed)
24
- torch.manual_seed(seed)
25
- torch.cuda.manual_seed_all(seed)
26
- torch.backends.cudnn.deterministic = True
27
- torch.backends.cudnn.benchmark = False
28
-
29
-
30
- def initialize_wandb(config):
31
- if config.get("use_wandb", False):
32
- import wandb
33
- wandb.init(
34
- project=config.get("wandb_project", "geneformer_multitask"),
35
- name=config.get("run_name", "experiment"),
36
- config=config,
37
- reinit=True,
38
- )
39
 
40
-
41
- def create_model(config, num_labels_list, device, is_distributed=False, local_rank=0):
42
- """Create and initialize the model based on configuration."""
43
- from .model import GeneformerMultiTask
44
-
45
- model = GeneformerMultiTask(
46
- config["pretrained_path"],
47
- num_labels_list,
48
- dropout_rate=config.get("dropout_rate", 0.1),
49
- use_task_weights=config.get("use_task_weights", False),
50
- task_weights=config.get("task_weights", None),
51
- max_layers_to_freeze=config.get("max_layers_to_freeze", 0),
52
- use_attention_pooling=config.get("use_attention_pooling", False),
53
- )
54
-
55
- # Move model to device
56
- model.to(device)
57
-
58
- if is_distributed:
59
- model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
60
-
61
- return model
62
-
63
-
64
- def setup_optimizer_and_scheduler(model, config, total_steps):
65
- """Set up optimizer and learning rate scheduler."""
66
- no_decay = ["bias", "LayerNorm.weight"]
67
- optimizer_grouped_parameters = [
68
- {
69
- "params": [p for n, p in model.named_parameters()
70
- if not any(nd in n for nd in no_decay) and p.requires_grad],
71
- "weight_decay": config["weight_decay"],
72
- },
73
- {
74
- "params": [p for n, p in model.named_parameters()
75
- if any(nd in n for nd in no_decay) and p.requires_grad],
76
- "weight_decay": 0.0,
77
- },
78
- ]
79
-
80
- optimizer = AdamW(
81
- optimizer_grouped_parameters,
82
- lr=config["learning_rate"],
83
- eps=config.get("adam_epsilon", 1e-8)
84
- )
85
-
86
- # Prepare scheduler
87
- warmup_steps = int(total_steps * config["warmup_ratio"])
88
-
89
- scheduler_map = {
90
- "linear": get_linear_schedule_with_warmup,
91
- "cosine": get_cosine_schedule_with_warmup
92
- }
93
-
94
- scheduler_fn = scheduler_map.get(config["lr_scheduler_type"])
95
- if not scheduler_fn:
96
- raise ValueError(f"Unsupported scheduler type: {config['lr_scheduler_type']}")
97
-
98
- scheduler = scheduler_fn(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
99
-
100
- return optimizer, scheduler
101
 
102
 
103
  def save_model(model, model_save_directory):
104
- """Save model weights and configuration."""
105
- os.makedirs(model_save_directory, exist_ok=True)
106
-
107
- # Handle DDP model
108
- if isinstance(model, DDP):
109
- model_to_save = model.module
 
 
110
  else:
111
- model_to_save = model
112
-
113
- model_state_dict = model_to_save.state_dict()
 
 
 
114
 
115
  model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
116
  torch.save(model_state_dict, model_save_path)
117
 
118
  # Save the model configuration
119
- model_to_save.config.to_json_file(os.path.join(model_save_directory, "config.json"))
120
-
121
- print(f"Model and configuration saved to {model_save_directory}")
122
-
123
-
124
- def save_hyperparameters(model_save_directory, hyperparams):
125
- """Save hyperparameters to a JSON file."""
126
- hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
127
- with open(hyperparams_path, "w") as f:
128
- json.dump(hyperparams, f)
129
- print(f"Hyperparameters saved to {hyperparams_path}")
130
-
131
-
132
- def calculate_metrics(labels=None, preds=None, task_data=None, metric_type="task_specific", return_format="dict"):
133
- if metric_type == "single":
134
- # Calculate metrics for a single task
135
- if labels is None or preds is None:
136
- raise ValueError("Labels and predictions must be provided for single task metrics")
137
-
138
- task_name = None
139
- if isinstance(labels, dict) and len(labels) == 1:
140
- task_name = list(labels.keys())[0]
141
- labels = labels[task_name]
142
- preds = preds[task_name]
143
-
144
- f1 = f1_score(labels, preds, average="macro")
145
- accuracy = accuracy_score(labels, preds)
146
-
147
- if return_format == "tuple":
148
- return f1, accuracy
149
-
150
- result = {"f1": f1, "accuracy": accuracy}
151
- if task_name:
152
- return {task_name: result}
153
- return result
154
-
155
- elif metric_type == "task_specific":
156
- # Calculate metrics for multiple tasks
157
- if task_data:
158
- result = {}
159
- for task_name, (task_labels, task_preds) in task_data.items():
160
- f1 = f1_score(task_labels, task_preds, average="macro")
161
- accuracy = accuracy_score(task_labels, task_preds)
162
- result[task_name] = {"f1": f1, "accuracy": accuracy}
163
- return result
164
- elif isinstance(labels, dict) and isinstance(preds, dict):
165
- result = {}
166
- for task_name in labels:
167
- if task_name in preds:
168
- f1 = f1_score(labels[task_name], preds[task_name], average="macro")
169
- accuracy = accuracy_score(labels[task_name], preds[task_name])
170
- result[task_name] = {"f1": f1, "accuracy": accuracy}
171
- return result
172
- else:
173
- raise ValueError("For task_specific metrics, either task_data or labels and preds dictionaries must be provided")
174
-
175
- elif metric_type == "combined":
176
- # Calculate combined metrics across all tasks
177
- if labels is None or preds is None:
178
- raise ValueError("Labels and predictions must be provided for combined metrics")
179
-
180
- # Handle label encoding for non-numeric labels
181
- if not all(isinstance(x, (int, float)) for x in labels + preds):
182
- le = LabelEncoder()
183
- le.fit(labels + preds)
184
- labels = le.transform(labels)
185
- preds = le.transform(preds)
186
-
187
- f1 = f1_score(labels, preds, average="macro")
188
- accuracy = accuracy_score(labels, preds)
189
-
190
- if return_format == "tuple":
191
- return f1, accuracy
192
- return {"f1": f1, "accuracy": accuracy}
193
-
194
- else:
195
- raise ValueError(f"Unknown metric_type: {metric_type}")
196
-
197
-
198
- def get_layer_freeze_range(pretrained_path):
199
- if not pretrained_path:
200
- return {"min": 0, "max": 0}
201
-
202
- config = AutoConfig.from_pretrained(pretrained_path)
203
- total_layers = config.num_hidden_layers
204
- return {"min": 0, "max": total_layers - 1}
205
-
206
-
207
- def prepare_training_environment(config):
208
- """
209
- Prepare the training environment by setting seed and loading data.
210
-
211
- Returns:
212
- tuple: (device, train_loader, val_loader, train_cell_id_mapping,
213
- val_cell_id_mapping, num_labels_list)
214
- """
215
- from .data import prepare_data_loaders
216
-
217
- # Set seed for reproducibility
218
- set_seed(config["seed"])
219
-
220
- # Set up device - for non-distributed training
221
- if not config.get("distributed_training", False):
222
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
223
  else:
224
- # For distributed training, device will be set per process
225
- device = None
226
-
227
- # Load data using the streaming dataset
228
- data = prepare_data_loaders(config)
229
-
230
- # For distributed training, we'll set up samplers later in the distributed worker
231
- # Don't create DistributedSampler here as process group isn't initialized yet
232
-
233
- return (
234
- device,
235
- data["train_loader"],
236
- data["val_loader"],
237
- data["train_cell_mapping"],
238
- data["val_cell_mapping"],
239
- data["num_labels_list"],
240
- )
241
 
 
242
 
243
- # Optuna hyperparameter optimization utilities
244
- def save_trial_callback(study, trial, trials_result_path):
245
- """
246
- Callback to save Optuna trial results to a file.
247
-
248
- Args:
249
- study: Optuna study object
250
- trial: Current trial object
251
- trials_result_path: Path to save trial results
252
- """
253
- with open(trials_result_path, "a") as f:
254
- f.write(
255
- f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
256
- )
257
 
 
 
 
 
 
 
 
 
 
258
 
259
- def create_optuna_study(objective, n_trials: int, trials_result_path: str, tensorboard_log_dir: str) -> optuna.Study:
260
- """Create and run an Optuna study with TensorBoard logging."""
261
- from optuna.integration import TensorBoardCallback
262
-
263
- study = optuna.create_study(direction="maximize")
264
- study.optimize(
265
- objective,
266
- n_trials=n_trials,
267
- callbacks=[
268
- lambda study, trial: save_trial_callback(study, trial, trials_result_path),
269
- TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro")
270
- ]
271
- )
272
- return study
273
 
 
 
 
274
 
275
- @contextmanager
276
- def setup_logging(config):
277
- run_name = config.get("run_name", "manual_run")
278
- log_dir = os.path.join(config["tensorboard_log_dir"], run_name)
279
- writer = SummaryWriter(log_dir=log_dir)
280
- try:
281
- yield writer
282
- finally:
283
- writer.close()
284
 
 
 
 
 
285
 
286
- def log_training_step(loss, writer, config, epoch, steps_per_epoch, batch_idx):
287
- """Log training step metrics to TensorBoard and optionally W&B."""
288
- writer.add_scalar(
289
- "Training Loss", loss, epoch * steps_per_epoch + batch_idx
290
- )
291
- if config.get("use_wandb", False):
292
- import wandb
293
- wandb.log({"Training Loss": loss})
294
 
 
 
295
 
296
- def log_validation_metrics(task_metrics, val_loss, config, writer, epoch):
297
- """Log validation metrics to console, TensorBoard, and optionally W&B."""
298
- for task_name, metrics in task_metrics.items():
299
- print(
300
- f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
301
- )
302
- if config.get("use_wandb", False):
303
- import wandb
304
- wandb.log(
305
- {
306
- f"{task_name} Validation F1 Macro": metrics["f1"],
307
- f"{task_name} Validation Accuracy": metrics["accuracy"],
308
- }
309
- )
310
 
311
- writer.add_scalar("Validation Loss", val_loss, epoch)
312
- for task_name, metrics in task_metrics.items():
313
- writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epoch)
314
- writer.add_scalar(
315
- f"{task_name} - Validation Accuracy", metrics["accuracy"], epoch
316
- )
317
 
 
 
 
 
 
318
 
319
- def load_label_mappings(results_dir: str, task_names: List[str]) -> Dict[str, Dict]:
320
- """Load or initialize task label mappings."""
321
- label_mappings_path = os.path.join(results_dir, "task_label_mappings_val.pkl")
322
- if os.path.exists(label_mappings_path):
323
- with open(label_mappings_path, 'rb') as f:
324
- return pickle.load(f)
325
- return {task_name: {} for task_name in task_names}
326
 
 
 
 
 
 
327
 
328
- def create_prediction_row(sample_idx: int, val_cell_indices: Dict, task_true_labels: Dict,
329
- task_pred_labels: Dict, task_pred_probs: Dict, task_names: List[str],
330
- inverted_mappings: Dict, val_cell_mapping: Dict) -> Dict:
331
- """Create a row for validation predictions."""
332
- batch_cell_idx = val_cell_indices.get(sample_idx)
333
- cell_id = val_cell_mapping.get(batch_cell_idx, f"unknown_cell_{sample_idx}") if batch_cell_idx is not None else f"unknown_cell_{sample_idx}"
334
-
335
- row = {"Cell ID": cell_id}
336
- for task_name in task_names:
337
- if task_name in task_true_labels and sample_idx < len(task_true_labels[task_name]):
338
- true_idx = task_true_labels[task_name][sample_idx]
339
- pred_idx = task_pred_labels[task_name][sample_idx]
340
- true_label = inverted_mappings.get(task_name, {}).get(true_idx, f"Unknown-{true_idx}")
341
- pred_label = inverted_mappings.get(task_name, {}).get(pred_idx, f"Unknown-{pred_idx}")
342
-
343
- row.update({
344
- f"{task_name}_true_idx": true_idx,
345
- f"{task_name}_pred_idx": pred_idx,
346
- f"{task_name}_true_label": true_label,
347
- f"{task_name}_pred_label": pred_label
348
- })
349
-
350
- if task_name in task_pred_probs and sample_idx < len(task_pred_probs[task_name]):
351
- probs = task_pred_probs[task_name][sample_idx]
352
- if isinstance(probs, (list, np.ndarray)) or (hasattr(probs, '__iter__') and not isinstance(probs, str)):
353
- prob_list = list(probs) if not isinstance(probs, list) else probs
354
- row[f"{task_name}_all_probs"] = ",".join(map(str, prob_list))
355
- for class_idx, prob in enumerate(prob_list):
356
- class_label = inverted_mappings.get(task_name, {}).get(class_idx, f"Unknown-{class_idx}")
357
- row[f"{task_name}_prob_{class_label}"] = prob
358
- else:
359
- row[f"{task_name}_all_probs"] = str(probs)
360
-
361
- return row
362
 
 
 
363
 
364
- def save_validation_predictions(
365
- val_cell_indices,
366
- task_true_labels,
367
- task_pred_labels,
368
- task_pred_probs,
369
- config,
370
- trial_number=None,
371
- ):
372
- """Save validation predictions to a CSV file with class labels and probabilities."""
373
- os.makedirs(config["results_dir"], exist_ok=True)
374
-
375
- if trial_number is not None:
376
- os.makedirs(os.path.join(config["results_dir"], f"trial_{trial_number}"), exist_ok=True)
377
- val_preds_file = os.path.join(config["results_dir"], f"trial_{trial_number}/val_preds.csv")
378
- else:
379
- val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
380
-
381
- if not val_cell_indices or not task_true_labels:
382
- pd.DataFrame().to_csv(val_preds_file, index=False)
383
- return
384
-
385
- try:
386
- label_mappings = load_label_mappings(config["results_dir"], config["task_names"])
387
- inverted_mappings = {task: {idx: label for label, idx in mapping.items()} for task, mapping in label_mappings.items()}
388
- val_cell_mapping = config.get("val_cell_mapping", {})
389
-
390
- # Determine maximum number of samples
391
- max_samples = max(
392
- [len(val_cell_indices)] +
393
- [len(task_true_labels[task]) for task in task_true_labels]
394
- )
395
-
396
- rows = [
397
- create_prediction_row(
398
- sample_idx, val_cell_indices, task_true_labels, task_pred_labels,
399
- task_pred_probs, config["task_names"], inverted_mappings, val_cell_mapping
400
- )
401
- for sample_idx in range(max_samples)
402
- ]
403
-
404
- pd.DataFrame(rows).to_csv(val_preds_file, index=False)
405
- except Exception as e:
406
- pd.DataFrame([{"Error": str(e)}]).to_csv(val_preds_file, index=False)
407
 
 
 
 
 
 
408
 
409
- def setup_distributed_environment(rank, world_size, config):
410
- """
411
- Setup the distributed training environment.
412
-
413
- Args:
414
- rank (int): The rank of the current process
415
- world_size (int): Total number of processes
416
- config (dict): Configuration dictionary
417
- """
418
- os.environ['MASTER_ADDR'] = config.get('master_addr', 'localhost')
419
- os.environ['MASTER_PORT'] = config.get('master_port', '12355')
420
-
421
- # Initialize the process group
422
- dist.init_process_group(
423
- backend='nccl',
424
- init_method='env://',
425
- world_size=world_size,
426
- rank=rank
427
- )
428
-
429
- # Set the device for this process
430
- torch.cuda.set_device(rank)
431
 
432
 
433
- def train_distributed(trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None):
434
- """Run distributed training across multiple GPUs with fallback to single GPU."""
435
- world_size = torch.cuda.device_count()
436
-
437
- if world_size <= 1:
438
- print("Distributed training requested but only one GPU found. Falling back to single GPU training.")
439
- config["distributed_training"] = False
440
- trainer = trainer_class(config)
441
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
442
- trainer.device = device
443
- train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
444
- train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
445
- )
446
- val_loss, model = trainer.train(
447
- train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
448
- )
449
- model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
450
- save_model(model, model_save_directory)
451
- save_hyperparameters(model_save_directory, {
452
- **get_config_value(config, "manual_hyperparameters", {}),
453
- "dropout_rate": config["dropout_rate"],
454
- "use_task_weights": config["use_task_weights"],
455
- "task_weights": config["task_weights"],
456
- "max_layers_to_freeze": config["max_layers_to_freeze"],
457
- "use_attention_pooling": config["use_attention_pooling"],
458
- })
459
-
460
- if shared_dict is not None:
461
- shared_dict['val_loss'] = val_loss
462
- task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(model, val_loader, device, config)
463
- shared_dict['task_metrics'] = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
464
- shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model.state_dict().items()}
465
-
466
- return val_loss, model
467
-
468
- print(f"Using distributed training with {world_size} GPUs")
469
- mp.spawn(
470
- _distributed_worker,
471
- args=(world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number, shared_dict),
472
- nprocs=world_size,
473
- join=True
474
- )
475
-
476
- if trial_number is None and shared_dict is None:
477
- model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
478
- model_path = os.path.join(model_save_directory, "pytorch_model.bin")
479
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
480
- model = create_model(config, num_labels_list, device)
481
- model.load_state_dict(torch.load(model_path))
482
- return 0.0, model
483
-
484
- return None
485
-
486
-
487
- def _distributed_worker(rank, world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None):
488
- """Worker function for distributed training."""
489
- setup_distributed_environment(rank, world_size, config)
490
- config["local_rank"] = rank
491
-
492
- # Set up distributed samplers
493
- from torch.utils.data import DistributedSampler
494
- from .data import get_data_loader
495
-
496
- train_sampler = DistributedSampler(train_loader.dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
497
- val_sampler = DistributedSampler(val_loader.dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
498
-
499
- train_loader = get_data_loader(train_loader.dataset, config["batch_size"], sampler=train_sampler, shuffle=False)
500
- val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=val_sampler, shuffle=False)
501
-
502
- if rank == 0:
503
- print(f"Rank {rank}: Training {len(train_sampler)} samples, Validation {len(val_sampler)} samples")
504
- print(f"Total samples across {world_size} GPUs: Training {len(train_sampler) * world_size}, Validation {len(val_sampler) * world_size}")
505
-
506
- # Create and setup trainer
507
- trainer = trainer_class(config)
508
- train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
509
- train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
510
- )
511
-
512
- # Train the model
513
- val_loss, model = trainer.train(
514
- train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
515
- )
516
-
517
- # Save model only from the main process
518
- if rank == 0:
519
- model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
520
- save_model(model, model_save_directory)
521
-
522
- save_hyperparameters(model_save_directory, {
523
- **get_config_value(config, "manual_hyperparameters", {}),
524
- "dropout_rate": config["dropout_rate"],
525
- "use_task_weights": config["use_task_weights"],
526
- "task_weights": config["task_weights"],
527
- "max_layers_to_freeze": config["max_layers_to_freeze"],
528
- "use_attention_pooling": config["use_attention_pooling"],
529
- })
530
-
531
- # For Optuna trials, store results in shared dictionary
532
- if shared_dict is not None:
533
- shared_dict['val_loss'] = val_loss
534
-
535
- # Run validation on full dataset from rank 0 for consistent metrics
536
- full_val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=None, shuffle=False)
537
-
538
- # Get validation predictions using our utility function
539
- task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(
540
- model, full_val_loader, trainer.device, config
541
- )
542
-
543
- # Calculate metrics
544
- task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
545
- shared_dict['task_metrics'] = task_metrics
546
-
547
- # Store model state dict
548
- if isinstance(model, DDP):
549
- model_state_dict = model.module.state_dict()
550
- else:
551
- model_state_dict = model.state_dict()
552
-
553
- shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model_state_dict.items()}
554
-
555
- # Clean up distributed environment
556
- dist.destroy_process_group()
557
-
558
-
559
- def save_model_without_heads(model_directory):
560
  """
561
- Save a version of the fine-tuned model without classification heads.
562
-
563
  Args:
564
- model_directory (str): Path to the directory containing the fine-tuned model
 
 
565
  """
566
- import torch
567
- from transformers import BertConfig, BertModel
568
-
569
- # Load the full model
570
- model_path = os.path.join(model_directory, "pytorch_model.bin")
571
- config_path = os.path.join(model_directory, "config.json")
572
-
573
- if not os.path.exists(model_path) or not os.path.exists(config_path):
574
- raise FileNotFoundError(f"Model files not found in {model_directory}")
575
-
576
- # Load the configuration
577
- config = BertConfig.from_json_file(config_path)
578
-
579
- # Load the model state dict
580
- state_dict = torch.load(model_path, map_location=torch.device('cpu'))
581
-
582
- # Create a new model without heads
583
- base_model = BertModel(config)
584
-
585
- # Filter out the classification head parameters
586
- base_model_state_dict = {}
587
- for key, value in state_dict.items():
588
- # Only keep parameters that belong to the base model (not classification heads)
589
- if not key.startswith('classification_heads') and not key.startswith('attention_pool'):
590
- base_model_state_dict[key] = value
591
-
592
- # Load the filtered state dict into the base model
593
- base_model.load_state_dict(base_model_state_dict, strict=False)
594
-
595
- # Save the model without heads
596
- output_dir = os.path.join(model_directory, "model_without_heads")
597
- os.makedirs(output_dir, exist_ok=True)
598
-
599
- # Save the model weights
600
- torch.save(base_model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
601
-
602
- # Save the configuration
603
- base_model.config.to_json_file(os.path.join(output_dir, "config.json"))
604
-
605
- print(f"Model without classification heads saved to {output_dir}")
606
- return output_dir
607
-
608
-
609
- def get_config_value(config: Dict, key: str, default=None):
610
-
611
- return config.get(key, default)
612
-
613
-
614
- def collect_validation_predictions(model, val_loader, device, config) -> tuple:
615
- task_true_labels = {}
616
- task_pred_labels = {}
617
- task_pred_probs = {}
618
-
619
- with torch.no_grad():
620
- for batch in val_loader:
621
- input_ids = batch["input_ids"].to(device)
622
- attention_mask = batch["attention_mask"].to(device)
623
- labels = [batch["labels"][task_name].to(device) for task_name in config["task_names"]]
624
- _, logits, _ = model(input_ids, attention_mask, labels)
625
-
626
- for sample_idx in range(len(batch["input_ids"])):
627
- for i, task_name in enumerate(config["task_names"]):
628
- if task_name not in task_true_labels:
629
- task_true_labels[task_name] = []
630
- task_pred_labels[task_name] = []
631
- task_pred_probs[task_name] = []
632
-
633
- true_label = batch["labels"][task_name][sample_idx].item()
634
- pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
635
- pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
636
-
637
- task_true_labels[task_name].append(true_label)
638
- task_pred_labels[task_name].append(pred_label)
639
- task_pred_probs[task_name].append(pred_prob)
640
-
641
- return task_true_labels, task_pred_labels, task_pred_probs
 
 
 
1
  import os
2
+ import shutil
3
+
 
 
 
4
  from sklearn.metrics import accuracy_score, f1_score
5
  from sklearn.preprocessing import LabelEncoder
6
+ from transformers import AutoConfig, BertConfig, BertModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ from .imports import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  def save_model(model, model_save_directory):
12
+ if not os.path.exists(model_save_directory):
13
+ os.makedirs(model_save_directory)
14
+
15
+ # Get the state dict
16
+ if isinstance(model, nn.DataParallel):
17
+ model_state_dict = (
18
+ model.module.state_dict()
19
+ ) # Use model.module to access the underlying model
20
  else:
21
+ model_state_dict = model.state_dict()
22
+
23
+ # Remove the "module." prefix from the keys if present
24
+ model_state_dict = {
25
+ k.replace("module.", ""): v for k, v in model_state_dict.items()
26
+ }
27
 
28
  model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
29
  torch.save(model_state_dict, model_save_path)
30
 
31
  # Save the model configuration
32
+ if isinstance(model, nn.DataParallel):
33
+ model.module.config.to_json_file(
34
+ os.path.join(model_save_directory, "config.json")
35
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  else:
37
+ model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ print(f"Model and configuration saved to {model_save_directory}")
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
43
+ task_metrics = {}
44
+ for task_name in task_true_labels.keys():
45
+ true_labels = task_true_labels[task_name]
46
+ pred_labels = task_pred_labels[task_name]
47
+ f1 = f1_score(true_labels, pred_labels, average="macro")
48
+ accuracy = accuracy_score(true_labels, pred_labels)
49
+ task_metrics[task_name] = {"f1": f1, "accuracy": accuracy}
50
+ return task_metrics
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ def calculate_combined_f1(combined_labels, combined_preds):
54
+ # Initialize the LabelEncoder
55
+ le = LabelEncoder()
56
 
57
+ # Fit and transform combined labels and predictions to numerical values
58
+ le.fit(combined_labels + combined_preds)
59
+ encoded_true_labels = le.transform(combined_labels)
60
+ encoded_pred_labels = le.transform(combined_preds)
 
 
 
 
 
61
 
62
+ # Print out the mapping for sanity check
63
+ print("\nLabel Encoder Mapping:")
64
+ for index, class_label in enumerate(le.classes_):
65
+ print(f"'{class_label}': {index}")
66
 
67
+ # Calculate accuracy
68
+ accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
 
 
 
 
 
 
69
 
70
+ # Calculate F1 Macro score
71
+ f1 = f1_score(encoded_true_labels, encoded_pred_labels, average="macro")
72
 
73
+ return f1, accuracy
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
 
 
 
 
 
 
75
 
76
+ # def save_model_without_heads(original_model_save_directory):
77
+ # # Create a new directory for the model without heads
78
+ # new_model_save_directory = original_model_save_directory + "_No_Heads"
79
+ # if not os.path.exists(new_model_save_directory):
80
+ # os.makedirs(new_model_save_directory)
81
 
82
+ # # Load the model state dictionary
83
+ # model_state_dict = torch.load(
84
+ # os.path.join(original_model_save_directory, "pytorch_model.bin")
85
+ # )
 
 
 
86
 
87
+ # # Initialize a new BERT model without the classification heads
88
+ # config = BertConfig.from_pretrained(
89
+ # os.path.join(original_model_save_directory, "config.json")
90
+ # )
91
+ # model_without_heads = BertModel(config)
92
 
93
+ # # Filter the state dict to exclude classification heads
94
+ # model_without_heads_state_dict = {
95
+ # k: v
96
+ # for k, v in model_state_dict.items()
97
+ # if not k.startswith("classification_heads")
98
+ # }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
+ # # Load the filtered state dict into the model
101
+ # model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
102
 
103
+ # # Save the model without heads
104
+ # model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
105
+ # torch.save(model_without_heads.state_dict(), model_save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ # # Copy the configuration file
108
+ # shutil.copy(
109
+ # os.path.join(original_model_save_directory, "config.json"),
110
+ # new_model_save_directory,
111
+ # )
112
 
113
+ # print(f"Model without classification heads saved to {new_model_save_directory}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
+ def get_layer_freeze_range(pretrained_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  """
118
+ Dynamically determines the number of layers to freeze based on the model depth from its configuration.
 
119
  Args:
120
+ pretrained_path (str): Path to the pretrained model directory or model identifier.
121
+ Returns:
122
+ dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze.
123
  """
124
+ if pretrained_path:
125
+ config = AutoConfig.from_pretrained(pretrained_path)
126
+ total_layers = config.num_hidden_layers
127
+ return {"min": 0, "max": total_layers - 1}
128
+ else:
129
+ return {"min": 0, "max": 0}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
geneformer/mtl_classifier.py CHANGED
@@ -29,7 +29,7 @@ Geneformer multi-task cell classifier.
29
  import logging
30
  import os
31
 
32
- from .mtl import eval_utils, utils, train
33
 
34
  logger = logging.getLogger(__name__)
35
 
@@ -49,9 +49,7 @@ class MTLClassifier:
49
  "max_layers_to_freeze": {None, dict},
50
  "epochs": {None, int},
51
  "tensorboard_log_dir": {None, str},
52
- "distributed_training": {None, bool},
53
- "master_addr": {None, str},
54
- "master_port": {None, str},
55
  "use_attention_pooling": {None, bool},
56
  "use_task_weights": {None, bool},
57
  "hyperparameters": {None, dict},
@@ -63,7 +61,6 @@ class MTLClassifier:
63
  "max_grad_norm": {None, int, float},
64
  "seed": {None, int},
65
  "trials_result_path": {None, str},
66
- "gradient_accumulation_steps": {None, int},
67
  }
68
 
69
  def __init__(
@@ -82,9 +79,7 @@ class MTLClassifier:
82
  max_layers_to_freeze=None,
83
  epochs=1,
84
  tensorboard_log_dir="/results/tblogdir",
85
- distributed_training=False,
86
- master_addr="localhost",
87
- master_port="12355",
88
  use_attention_pooling=True,
89
  use_task_weights=True,
90
  hyperparameters=None, # Default is None
@@ -94,7 +89,6 @@ class MTLClassifier:
94
  wandb_project=None,
95
  gradient_clipping=False,
96
  max_grad_norm=None,
97
- gradient_accumulation_steps=1, # Add this line with default value 1
98
  seed=42, # Default seed value
99
  ):
100
  """
@@ -123,12 +117,8 @@ class MTLClassifier:
123
  | Path to directory to save results
124
  tensorboard_log_dir : None, str
125
  | Path to directory for Tensorboard logging results
126
- distributed_training : None, bool
127
- | Whether to use distributed data parallel training across multiple GPUs
128
- master_addr : None, str
129
- | Master address for distributed training (default: localhost)
130
- master_port : None, str
131
- | Master port for distributed training (default: 12355)
132
  use_attention_pooling : None, bool
133
  | Whether to use attention pooling
134
  use_task_weights : None, bool
@@ -160,8 +150,6 @@ class MTLClassifier:
160
  | Whether to use gradient clipping
161
  max_grad_norm : None, int, float
162
  | Maximum norm for gradient clipping
163
- gradient_accumulation_steps : None, int
164
- | Number of steps to accumulate gradients before performing a backward/update pass
165
  seed : None, int
166
  | Random seed
167
  """
@@ -177,7 +165,6 @@ class MTLClassifier:
177
  self.batch_size = batch_size
178
  self.n_trials = n_trials
179
  self.study_name = study_name
180
- self.gradient_accumulation_steps = gradient_accumulation_steps
181
 
182
  if max_layers_to_freeze is None:
183
  # Dynamically determine the range of layers to freeze
@@ -188,9 +175,7 @@ class MTLClassifier:
188
 
189
  self.epochs = epochs
190
  self.tensorboard_log_dir = tensorboard_log_dir
191
- self.distributed_training = distributed_training
192
- self.master_addr = master_addr
193
- self.master_port = master_port
194
  self.use_attention_pooling = use_attention_pooling
195
  self.use_task_weights = use_task_weights
196
  self.hyperparameters = (
@@ -308,7 +293,7 @@ class MTLClassifier:
308
  self.config["manual_hyperparameters"] = self.manual_hyperparameters
309
  self.config["use_manual_hyperparameters"] = True
310
 
311
- train.run_manual_tuning(self.config)
312
 
313
  def validate_additional_options(self, req_var_dict):
314
  missing_variable = False
@@ -345,7 +330,7 @@ class MTLClassifier:
345
  req_var_dict = dict(zip(required_variable_names, required_variables))
346
  self.validate_additional_options(req_var_dict)
347
 
348
- train.run_optuna_study(self.config)
349
 
350
  def load_and_evaluate_test_model(
351
  self,
 
29
  import logging
30
  import os
31
 
32
+ from .mtl import eval_utils, train_utils, utils
33
 
34
  logger = logging.getLogger(__name__)
35
 
 
49
  "max_layers_to_freeze": {None, dict},
50
  "epochs": {None, int},
51
  "tensorboard_log_dir": {None, str},
52
+ "use_data_parallel": {None, bool},
 
 
53
  "use_attention_pooling": {None, bool},
54
  "use_task_weights": {None, bool},
55
  "hyperparameters": {None, dict},
 
61
  "max_grad_norm": {None, int, float},
62
  "seed": {None, int},
63
  "trials_result_path": {None, str},
 
64
  }
65
 
66
  def __init__(
 
79
  max_layers_to_freeze=None,
80
  epochs=1,
81
  tensorboard_log_dir="/results/tblogdir",
82
+ use_data_parallel=False,
 
 
83
  use_attention_pooling=True,
84
  use_task_weights=True,
85
  hyperparameters=None, # Default is None
 
89
  wandb_project=None,
90
  gradient_clipping=False,
91
  max_grad_norm=None,
 
92
  seed=42, # Default seed value
93
  ):
94
  """
 
117
  | Path to directory to save results
118
  tensorboard_log_dir : None, str
119
  | Path to directory for Tensorboard logging results
120
+ use_data_parallel : None, bool
121
+ | Whether to use data parallelization
 
 
 
 
122
  use_attention_pooling : None, bool
123
  | Whether to use attention pooling
124
  use_task_weights : None, bool
 
150
  | Whether to use gradient clipping
151
  max_grad_norm : None, int, float
152
  | Maximum norm for gradient clipping
 
 
153
  seed : None, int
154
  | Random seed
155
  """
 
165
  self.batch_size = batch_size
166
  self.n_trials = n_trials
167
  self.study_name = study_name
 
168
 
169
  if max_layers_to_freeze is None:
170
  # Dynamically determine the range of layers to freeze
 
175
 
176
  self.epochs = epochs
177
  self.tensorboard_log_dir = tensorboard_log_dir
178
+ self.use_data_parallel = use_data_parallel
 
 
179
  self.use_attention_pooling = use_attention_pooling
180
  self.use_task_weights = use_task_weights
181
  self.hyperparameters = (
 
293
  self.config["manual_hyperparameters"] = self.manual_hyperparameters
294
  self.config["use_manual_hyperparameters"] = True
295
 
296
+ train_utils.run_manual_tuning(self.config)
297
 
298
  def validate_additional_options(self, req_var_dict):
299
  missing_variable = False
 
330
  req_var_dict = dict(zip(required_variable_names, required_variables))
331
  self.validate_additional_options(req_var_dict)
332
 
333
+ train_utils.run_optuna_study(self.config)
334
 
335
  def load_and_evaluate_test_model(
336
  self,
geneformer/perturber_utils.py CHANGED
@@ -1,6 +1,5 @@
1
  import itertools as it
2
  import logging
3
- import os
4
  import pickle
5
  from collections import defaultdict
6
  from pathlib import Path
@@ -9,7 +8,6 @@ from typing import List
9
  import numpy as np
10
  import pandas as pd
11
  import torch
12
- import datasets
13
  from datasets import Dataset, load_from_disk
14
  from peft import LoraConfig, get_peft_model
15
  from transformers import (
@@ -19,6 +17,11 @@ from transformers import (
19
  BitsAndBytesConfig,
20
  )
21
 
 
 
 
 
 
22
  logger = logging.getLogger(__name__)
23
 
24
 
@@ -110,25 +113,15 @@ def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
110
 
111
  # load model to GPU
112
  def load_model(model_type, num_classes, model_directory, mode, quantize=False):
113
- if model_type == "Pretrained-Quantized":
114
- inference_only = True
115
- model_type = "Pretrained"
116
- quantize = True
117
- elif model_type == "MTLCellClassifier-Quantized":
118
- inference_only = True
119
  model_type = "MTLCellClassifier"
120
  quantize = True
121
- else:
122
- inference_only = False
123
 
124
  output_hidden_states = (mode == "eval")
125
 
126
  # Quantization logic
127
- if isinstance(quantize, dict):
128
- quantize_config = quantize.get("bnb_config", None)
129
- peft_config = quantize.get("peft_config", None)
130
- elif quantize:
131
- if inference_only:
132
  quantize_config = BitsAndBytesConfig(load_in_8bit=True)
133
  peft_config = None
134
  else:
@@ -138,22 +131,13 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
138
  bnb_4bit_quant_type="nf4",
139
  bnb_4bit_compute_dtype=torch.bfloat16,
140
  )
141
- try:
142
- peft_config = LoraConfig(
143
- lora_alpha=128,
144
- lora_dropout=0.1,
145
- r=64,
146
- bias="none",
147
- task_type="TokenClassification",
148
- )
149
- except ValueError as e:
150
- peft_config = LoraConfig(
151
- lora_alpha=128,
152
- lora_dropout=0.1,
153
- r=64,
154
- bias="none",
155
- task_type="TOKEN_CLS",
156
- )
157
  else:
158
  quantize_config = None
159
  peft_config = None
@@ -190,34 +174,17 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
190
  model.eval()
191
 
192
  # Handle device placement and PEFT
193
- adapter_config_path = os.path.join(model_directory, "adapter_config.json")
194
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
195
  if not quantize:
196
  # Only move non-quantized models
197
- move_to_cuda(model)
198
- elif os.path.exists(adapter_config_path):
199
- # If adapter files exist, load them into the model using PEFT's from_pretrained
200
- model = PeftModel.from_pretrained(model, model_directory)
201
- move_to_cuda(model)
202
- print("loading lora weights")
203
  elif peft_config:
204
- # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
205
  model.enable_input_require_grads()
206
  model = get_peft_model(model, peft_config)
207
- move_to_cuda(model)
208
 
209
  return model
210
 
211
-
212
- def move_to_cuda(model):
213
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
214
- # get what device model is currently on
215
- model_device = next(model.parameters()).device
216
- # Check if the model is on the CPU and move to cuda if necessary
217
- if (model_device.type == "cpu") and (device.type == "cuda"):
218
- model.to(device)
219
-
220
-
221
  def quant_layers(model):
222
  layer_nums = []
223
  for name, parameter in model.named_parameters():
@@ -431,11 +398,6 @@ def remove_perturbed_indices_set(
431
  def make_perturbation_batch(
432
  example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
433
  ) -> tuple[Dataset, List[int]]:
434
-
435
- # For datasets>=4.0.0, convert to dict to avoid format issues
436
- if int(datasets.__version__.split(".")[0]) >= 4:
437
- example_cell = example_cell[:]
438
-
439
  if combo_lvl == 0 and tokens_to_perturb == "all":
440
  if perturb_type in ["overexpress", "activate"]:
441
  range_start = 1
@@ -508,11 +470,6 @@ def make_perturbation_batch(
508
  def make_perturbation_batch_special(
509
  example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
510
  ) -> tuple[Dataset, List[int]]:
511
-
512
- # For datasets>=4.0.0, convert to dict to avoid format issues
513
- if int(datasets.__version__.split(".")[0]) >= 4:
514
- example_cell = example_cell[:]
515
-
516
  if combo_lvl == 0 and tokens_to_perturb == "all":
517
  if perturb_type in ["overexpress", "activate"]:
518
  range_start = 1
@@ -913,4 +870,50 @@ def validate_cell_states_to_model(cell_states_to_model):
913
  "'goal_state': 'nf', "
914
  "'alt_states': ['hcm', 'other1', 'other2']}"
915
  )
916
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import itertools as it
2
  import logging
 
3
  import pickle
4
  from collections import defaultdict
5
  from pathlib import Path
 
8
  import numpy as np
9
  import pandas as pd
10
  import torch
 
11
  from datasets import Dataset, load_from_disk
12
  from peft import LoraConfig, get_peft_model
13
  from transformers import (
 
17
  BitsAndBytesConfig,
18
  )
19
 
20
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
21
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
22
+ ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
23
+
24
+
25
  logger = logging.getLogger(__name__)
26
 
27
 
 
113
 
114
  # load model to GPU
115
  def load_model(model_type, num_classes, model_directory, mode, quantize=False):
116
+ if model_type == "MTLCellClassifier-Quantized":
 
 
 
 
 
117
  model_type = "MTLCellClassifier"
118
  quantize = True
 
 
119
 
120
  output_hidden_states = (mode == "eval")
121
 
122
  # Quantization logic
123
+ if quantize:
124
+ if model_type == "MTLCellClassifier":
 
 
 
125
  quantize_config = BitsAndBytesConfig(load_in_8bit=True)
126
  peft_config = None
127
  else:
 
131
  bnb_4bit_quant_type="nf4",
132
  bnb_4bit_compute_dtype=torch.bfloat16,
133
  )
134
+ peft_config = LoraConfig(
135
+ lora_alpha=128,
136
+ lora_dropout=0.1,
137
+ r=64,
138
+ bias="none",
139
+ task_type="TokenClassification",
140
+ )
 
 
 
 
 
 
 
 
 
141
  else:
142
  quantize_config = None
143
  peft_config = None
 
174
  model.eval()
175
 
176
  # Handle device placement and PEFT
 
 
177
  if not quantize:
178
  # Only move non-quantized models
179
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
180
+ model = model.to(device)
 
 
 
 
181
  elif peft_config:
182
+ # Apply PEFT for quantized models (except MTLCellClassifier)
183
  model.enable_input_require_grads()
184
  model = get_peft_model(model, peft_config)
 
185
 
186
  return model
187
 
 
 
 
 
 
 
 
 
 
 
188
  def quant_layers(model):
189
  layer_nums = []
190
  for name, parameter in model.named_parameters():
 
398
  def make_perturbation_batch(
399
  example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
400
  ) -> tuple[Dataset, List[int]]:
 
 
 
 
 
401
  if combo_lvl == 0 and tokens_to_perturb == "all":
402
  if perturb_type in ["overexpress", "activate"]:
403
  range_start = 1
 
470
  def make_perturbation_batch_special(
471
  example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
472
  ) -> tuple[Dataset, List[int]]:
 
 
 
 
 
473
  if combo_lvl == 0 and tokens_to_perturb == "all":
474
  if perturb_type in ["overexpress", "activate"]:
475
  range_start = 1
 
870
  "'goal_state': 'nf', "
871
  "'alt_states': ['hcm', 'other1', 'other2']}"
872
  )
873
+ raise
874
+
875
+
876
+ class GeneIdHandler:
877
+ def __init__(self, raise_errors=False):
878
+ def invert_dict(dict_obj):
879
+ return {v: k for k, v in dict_obj.items()}
880
+
881
+ self.raise_errors = raise_errors
882
+
883
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
884
+ self.gene_token_dict = pickle.load(f)
885
+ self.token_gene_dict = invert_dict(self.gene_token_dict)
886
+
887
+ with open(ENSEMBL_DICTIONARY_FILE, "rb") as f:
888
+ self.id_gene_dict = pickle.load(f)
889
+ self.gene_id_dict = invert_dict(self.id_gene_dict)
890
+
891
+ def ens_to_token(self, ens_id):
892
+ if not self.raise_errors:
893
+ return self.gene_token_dict.get(ens_id, ens_id)
894
+ else:
895
+ return self.gene_token_dict[ens_id]
896
+
897
+ def token_to_ens(self, token):
898
+ if not self.raise_errors:
899
+ return self.token_gene_dict.get(token, token)
900
+ else:
901
+ return self.token_gene_dict[token]
902
+
903
+ def ens_to_symbol(self, ens_id):
904
+ if not self.raise_errors:
905
+ return self.gene_id_dict.get(ens_id, ens_id)
906
+ else:
907
+ return self.gene_id_dict[ens_id]
908
+
909
+ def symbol_to_ens(self, symbol):
910
+ if not self.raise_errors:
911
+ return self.id_gene_dict.get(symbol, symbol)
912
+ else:
913
+ return self.id_gene_dict[symbol]
914
+
915
+ def token_to_symbol(self, token):
916
+ return self.ens_to_symbol(self.token_to_ens(token))
917
+
918
+ def symbol_to_token(self, symbol):
919
+ return self.ens_to_token(self.symbol_to_ens(symbol))
geneformer/pretrainer.py CHANGED
@@ -8,12 +8,13 @@ import math
8
  import pickle
9
  import warnings
10
  from enum import Enum
11
- from typing import Dict, List, Optional, Union
12
 
13
  import numpy as np
14
  import torch
15
  from datasets import Dataset
16
  from packaging import version
 
17
  from torch.utils.data.sampler import RandomSampler
18
  from transformers import (
19
  BatchEncoding,
@@ -23,8 +24,11 @@ from transformers import (
23
  )
24
  from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
25
  from transformers.trainer_pt_utils import (
 
 
26
  LengthGroupedSampler,
27
  )
 
28
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
29
  from transformers.utils.generic import _is_tensorflow, _is_torch
30
 
@@ -603,7 +607,7 @@ class GeneformerPretrainer(Trainer):
603
  )
604
  super().__init__(*args, **kwargs)
605
 
606
- # updated to not use distributed sampler since Trainer now distributes with accelerate
607
  def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
608
  if not isinstance(self.train_dataset, collections.abc.Sized):
609
  return None
@@ -626,15 +630,181 @@ class GeneformerPretrainer(Trainer):
626
  if self.tokenizer is not None
627
  else None
628
  )
629
- return LengthGroupedSampler(
 
630
  dataset=self.train_dataset,
631
  batch_size=self.args.train_batch_size,
632
  lengths=lengths,
633
  model_input_name=model_input_name,
634
  generator=generator,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
 
 
 
 
637
  else:
638
- if _is_torch_generator_available:
639
- return RandomSampler(self.train_dataset, generator=generator)
640
- return RandomSampler(self.train_dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import pickle
9
  import warnings
10
  from enum import Enum
11
+ from typing import Dict, Iterator, List, Optional, Union
12
 
13
  import numpy as np
14
  import torch
15
  from datasets import Dataset
16
  from packaging import version
17
+ from torch.utils.data.distributed import DistributedSampler
18
  from torch.utils.data.sampler import RandomSampler
19
  from transformers import (
20
  BatchEncoding,
 
24
  )
25
  from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
26
  from transformers.trainer_pt_utils import (
27
+ DistributedLengthGroupedSampler,
28
+ DistributedSamplerWithLoop,
29
  LengthGroupedSampler,
30
  )
31
+ from transformers.training_args import ParallelMode
32
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
  from transformers.utils.generic import _is_tensorflow, _is_torch
34
 
 
607
  )
608
  super().__init__(*args, **kwargs)
609
 
610
+ # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
611
  def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
612
  if not isinstance(self.train_dataset, collections.abc.Sized):
613
  return None
 
630
  if self.tokenizer is not None
631
  else None
632
  )
633
+ if self.args.world_size <= 1:
634
+ return LengthGroupedSampler(
635
  dataset=self.train_dataset,
636
  batch_size=self.args.train_batch_size,
637
  lengths=lengths,
638
  model_input_name=model_input_name,
639
  generator=generator,
640
+ )
641
+ else:
642
+ return CustomDistributedLengthGroupedSampler(
643
+ dataset=self.train_dataset,
644
+ batch_size=self.args.train_batch_size,
645
+ num_replicas=self.args.world_size,
646
+ rank=self.args.process_index,
647
+ lengths=lengths,
648
+ model_input_name=model_input_name,
649
+ seed=self.args.seed,
650
+ )
651
+
652
+ else:
653
+ if self.args.world_size <= 1:
654
+ if _is_torch_generator_available:
655
+ return RandomSampler(self.train_dataset, generator=generator)
656
+ return RandomSampler(self.train_dataset)
657
+ elif (
658
+ self.args.parallel_mode
659
+ in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
660
+ and not self.args.dataloader_drop_last
661
+ ):
662
+ # Use a loop for TPUs when drop_last is False to have all batches have the same size.
663
+ return DistributedSamplerWithLoop(
664
+ self.train_dataset,
665
+ batch_size=self.args.per_device_train_batch_size,
666
+ num_replicas=self.args.world_size,
667
+ rank=self.args.process_index,
668
+ seed=self.args.seed,
669
+ )
670
+ else:
671
+ return DistributedSampler(
672
+ self.train_dataset,
673
+ num_replicas=self.args.world_size,
674
+ rank=self.args.process_index,
675
+ seed=self.args.seed,
676
+ )
677
+
678
+
679
+ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
680
+ r"""
681
+ Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
682
+ length while keeping a bit of randomness.
683
+ """
684
+
685
+ # Copied and adapted from PyTorch DistributedSampler.
686
+ def __init__(
687
+ self,
688
+ dataset: Dataset,
689
+ batch_size: int,
690
+ num_replicas: Optional[int] = None,
691
+ rank: Optional[int] = None,
692
+ seed: int = 0,
693
+ drop_last: bool = False,
694
+ lengths: Optional[List[int]] = None,
695
+ model_input_name: Optional[str] = None,
696
+ ):
697
+ if num_replicas is None:
698
+ if not dist.is_available():
699
+ raise RuntimeError("Requires distributed package to be available")
700
+ num_replicas = dist.get_world_size()
701
+ if rank is None:
702
+ if not dist.is_available():
703
+ raise RuntimeError("Requires distributed package to be available")
704
+ rank = dist.get_rank()
705
+ self.dataset = dataset
706
+ self.batch_size = batch_size
707
+ self.num_replicas = num_replicas
708
+ self.rank = rank
709
+ self.epoch = 0
710
+ self.drop_last = drop_last
711
+ # If the dataset length is evenly divisible by # of replicas, then there
712
+ # is no need to drop any data, since the dataset will be split equally.
713
+ if self.drop_last and len(self.dataset) % self.num_replicas != 0:
714
+ # Split to nearest available length that is evenly divisible.
715
+ # This is to ensure each rank receives the same amount of data when
716
+ # using this Sampler.
717
+ self.num_samples = math.ceil(
718
+ (len(self.dataset) - self.num_replicas) / self.num_replicas
719
  )
720
+ else:
721
+ self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
722
+ self.total_size = self.num_samples * self.num_replicas
723
+ self.seed = seed
724
+ self.model_input_name = (
725
+ model_input_name if model_input_name is not None else "input_ids"
726
+ )
727
+
728
+ if lengths is None:
729
+ print("Lengths is none - calculating lengths.")
730
+ if (
731
+ not (
732
+ isinstance(dataset[0], dict)
733
+ or isinstance(dataset[0], BatchEncoding)
734
+ )
735
+ or self.model_input_name not in dataset[0]
736
+ ):
737
+ raise ValueError(
738
+ "Can only automatically infer lengths for datasets whose items are dictionaries with an "
739
+ f"'{self.model_input_name}' key."
740
+ )
741
+ lengths = [len(feature[self.model_input_name]) for feature in dataset]
742
+ self.lengths = lengths
743
+
744
+ def __iter__(self) -> Iterator:
745
+ # Deterministically shuffle based on epoch and seed
746
+ g = torch.Generator()
747
+ g.manual_seed(self.seed + self.epoch)
748
+
749
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
750
 
751
+ if not self.drop_last:
752
+ # add extra samples to make it evenly divisible
753
+ indices += indices[: (self.total_size - len(indices))]
754
  else:
755
+ # remove tail of data to make it evenly divisible.
756
+ indices = indices[: self.total_size]
757
+ assert len(indices) == self.total_size
758
+
759
+ # subsample
760
+ indices = indices[self.rank : self.total_size : self.num_replicas]
761
+ assert len(indices) == self.num_samples
762
+
763
+ return iter(indices)
764
+
765
+
766
+ def get_length_grouped_indices(
767
+ lengths, batch_size, mega_batch_mult=None, generator=None
768
+ ):
769
+ """
770
+ Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
771
+ similar lengths. To do this, the indices are:
772
+
773
+ - randomly permuted
774
+ - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
775
+ - sorted by length in each mega-batch
776
+
777
+ The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
778
+ maximum length placed first, so that an OOM happens sooner rather than later.
779
+ """
780
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
781
+ if mega_batch_mult is None:
782
+ # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
783
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
784
+ # Just in case, for tiny datasets
785
+ if mega_batch_mult == 0:
786
+ mega_batch_mult = 1
787
+
788
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
789
+ indices = torch.randperm(len(lengths), generator=generator)
790
+ megabatch_size = mega_batch_mult * batch_size
791
+ megabatches = [
792
+ indices[i : i + megabatch_size].tolist()
793
+ for i in range(0, len(lengths), megabatch_size)
794
+ ]
795
+ megabatches = [
796
+ list(sorted(megabatch, key=lambda i: lengths[i], reverse=True))
797
+ for megabatch in megabatches
798
+ ]
799
+
800
+ # The rest is to get the biggest batch first.
801
+ # Since each megabatch is sorted by descending length, the longest element is the first
802
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
803
+ max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
804
+ # Switch to put the longest element in first position
805
+ megabatches[0][0], megabatches[max_idx][0] = (
806
+ megabatches[max_idx][0],
807
+ megabatches[0][0],
808
+ )
809
+
810
+ return [item for sublist in megabatches for item in sublist]
geneformer/{token_dictionary_gc104M.pkl → token_dictionary_gc95M.pkl} RENAMED
File without changes
geneformer/tokenizer.py CHANGED
@@ -3,7 +3,7 @@ Geneformer tokenizer.
3
 
4
  **Input data:**
5
 
6
- | *Required format:* raw counts scRNAseq data without feature selection as .loom, .h5ad, or .zarr file.
7
  | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
8
  | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
9
 
@@ -20,9 +20,9 @@ Geneformer tokenizer.
20
 
21
  **Description:**
22
 
23
- | Input data is a directory with .loom, .h5ad, or .zarr files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
24
 
25
- | The discussion below references the .loom file format, but the analagous labels are required for .h5ad and .zarr files, just that they will be column instead of row attributes and vice versa due to the transposed format of the file types.
26
 
27
  | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
28
 
@@ -30,9 +30,11 @@ Geneformer tokenizer.
30
 
31
  | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
32
 
33
- | If one's data is in other formats besides .loom, .h5ad, or .zarr, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom, .h5ad, or .zarr format prior to running the transcriptome tokenizer.
34
 
35
- | OF NOTE: Use model_version to auto-select settings for model version other than current default. For V1 model series (original Geneformer pretrained in 2021 on ~30M cells), one must use correct corresponding token dictionary and gene median file, set special_token to False, and set model_input_size to 2048. This argument enables auto-selection of these settings. (For V2 model series, special_token must be True and model_input_size is 4096.)
 
 
36
 
37
  """
38
 
@@ -46,7 +48,6 @@ from collections import Counter
46
  from pathlib import Path
47
  from typing import Literal
48
 
49
- import anndata as ad
50
  import loompy as lp
51
  import numpy as np
52
  import pandas as pd
@@ -87,8 +88,6 @@ def sum_ensembl_ids(
87
  collapse_gene_ids,
88
  gene_mapping_dict,
89
  gene_token_dict,
90
- custom_attr_name_dict,
91
- use_h5ad_index,
92
  file_format="loom",
93
  chunk_size=512,
94
  ):
@@ -104,45 +103,33 @@ def sum_ensembl_ids(
104
  assert (
105
  "ensembl_id_collapsed" not in data.ra.keys()
106
  ), "'ensembl_id_collapsed' column already exists in data.ra.keys()"
107
-
108
- assert (
109
- "n_counts" in data.ca.keys()
110
- ), "'n_counts' column missing from data.ca.keys()"
111
-
112
- if custom_attr_name_dict is not None:
113
- for label in custom_attr_name_dict:
114
- assert label in data.ca.keys(), f"Attribute `{label}` not present in dataset features"
115
-
116
- # Get the ensembl ids that exist in data
117
- ensembl_ids = data.ra.ensembl_id
118
  # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
119
  # Comparing to gene_token_dict here, would not perform any mapping steps
120
- if not collapse_gene_ids:
121
- ensembl_id_check = [
122
- gene for gene in ensembl_ids if gene in gene_token_dict.keys()
123
- ]
124
- if len(ensembl_id_check) == len(set(ensembl_id_check)):
 
125
  return data_directory
126
  else:
127
  raise ValueError("Error: data Ensembl IDs non-unique.")
128
-
129
- # Get the genes that exist in the mapping dictionary and the value of those genes
130
- genes_in_map_dict = [gene for gene in ensembl_ids if gene in gene_mapping_dict.keys()]
131
- vals_from_map_dict = [gene_mapping_dict.get(gene) for gene in genes_in_map_dict]
132
-
133
- # if the genes in the mapping dict and the value of those genes are of the same length,
134
- # simply return the mapped values
135
- if(len(set(genes_in_map_dict)) == len(set(vals_from_map_dict))):
136
- mapped_vals = [gene_mapping_dict.get(gene.upper()) for gene in data.ra["ensembl_id"]]
137
- data.ra["ensembl_id_collapsed"] = mapped_vals
138
  return data_directory
139
- # Genes need to be collapsed
140
  else:
141
  dedup_filename = data_directory.with_name(
142
  data_directory.stem + "__dedup.loom"
143
  )
144
- mapped_vals = [gene_mapping_dict.get(gene.upper()) for gene in data.ra["ensembl_id"]]
145
- data.ra["ensembl_id_collapsed"] = mapped_vals
146
  dup_genes = [
147
  idx
148
  for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
@@ -201,19 +188,13 @@ def sum_ensembl_ids(
201
  dsout.add_columns(processed_array, col_attrs=view.ca)
202
  return dedup_filename
203
 
204
- elif file_format in ["h5ad", "zarr"]:
205
  """
206
  Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
207
  Returns adata object with deduplicated Ensembl IDs.
208
  """
209
 
210
- if file_format == "h5ad":
211
- data = sc.read_h5ad(str(data_directory))
212
- else: # zarr
213
- data = ad.read_zarr(str(data_directory))
214
-
215
- if use_h5ad_index:
216
- data.var["ensembl_id"] = list(data.var.index)
217
 
218
  assert (
219
  "ensembl_id" in data.var.columns
@@ -222,41 +203,33 @@ def sum_ensembl_ids(
222
  assert (
223
  "ensembl_id_collapsed" not in data.var.columns
224
  ), "'ensembl_id_collapsed' column already exists in data.var"
225
- assert (
226
- "n_counts" in data.obs.columns
227
- ), "'n_counts' column missing from data.obs"
228
-
229
- if custom_attr_name_dict is not None:
230
- for label in custom_attr_name_dict:
231
- assert label in data.obs.columns, f"Attribute `{label}` not present in data.obs"
232
 
233
-
234
- # Get the ensembl ids that exist in data
235
- ensembl_ids = data.var.ensembl_id
236
  # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
237
  # Comparing to gene_token_dict here, would not perform any mapping steps
238
- if not collapse_gene_ids:
239
- ensembl_id_check = [
240
- gene for gene in ensembl_ids if gene in gene_token_dict.keys()
241
- ]
242
- if len(ensembl_id_check) == len(set(ensembl_id_check)):
 
243
  return data
244
  else:
245
  raise ValueError("Error: data Ensembl IDs non-unique.")
246
 
247
- # Get the genes that exist in the mapping dictionary and the value of those genes
248
- genes_in_map_dict = [gene for gene in ensembl_ids if gene in gene_mapping_dict.keys()]
249
- vals_from_map_dict = [gene_mapping_dict.get(gene) for gene in genes_in_map_dict]
250
-
251
- # if the genes in the mapping dict and the value of those genes are of the same length,
252
- # simply return the mapped values
253
- if(len(set(genes_in_map_dict)) == len(set(vals_from_map_dict))):
254
- data.var["ensembl_id_collapsed"] = data.var.ensembl_id.str.upper().map(gene_mapping_dict)
 
255
  return data
256
- # Genes need to be collapsed
257
  else:
258
- data.var["ensembl_id_collapsed"] = data.var.ensembl_id.str.upper().map(gene_mapping_dict)
259
- data.var_names = data.var["ensembl_id_collapsed"]
260
  data = data[:, ~data.var.index.isna()]
261
  dup_genes = [
262
  idx for idx, count in Counter(data.var_names).items() if count > 1
@@ -305,9 +278,6 @@ class TranscriptomeTokenizer:
305
  model_input_size=4096,
306
  special_token=True,
307
  collapse_gene_ids=True,
308
- use_h5ad_index=False,
309
- keep_counts=False,
310
- model_version="V2",
311
  gene_median_file=GENE_MEDIAN_FILE,
312
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
313
  gene_mapping_file=ENSEMBL_MAPPING_FILE,
@@ -327,23 +297,15 @@ class TranscriptomeTokenizer:
327
  | Chunk size for anndata tokenizer.
328
  model_input_size : int = 4096
329
  | Max input size of model to truncate input to.
330
- | For the V1 model series, should be 2048. For the V2 model series, should be 4096.
331
  special_token : bool = True
332
  | Adds CLS token before and EOS token after rank value encoding.
333
- | For the V1 model series, should be False. For the V2 model series, should be True.
334
  collapse_gene_ids : bool = True
335
  | Whether to collapse gene IDs based on gene mapping dictionary.
336
- use_h5ad_index : bool = False
337
- | use index as Ensembl IDs (only available for h5ad, only if collapse_gene_ids is True)
338
- keep_counts : bool = False
339
- | Whether to keep a dataset column that represents gene counts normalized by total cell counts
340
- | Counts will be ordered by the gene rank order within the tokenized rank value encoding for each cell.
341
- model_version : str
342
- | To auto-select settings for model version other than current default.
343
- | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
344
  gene_median_file : Path
345
  | Path to pickle file containing dictionary of non-zero median
346
- | gene expression values across Genecorpus.
347
  token_dictionary_file : Path
348
  | Path to pickle file containing token dictionary (Ensembl IDs:token).
349
  gene_mapping_file : None, Path
@@ -365,22 +327,8 @@ class TranscriptomeTokenizer:
365
  # add CLS and EOS tokens
366
  self.special_token = special_token
367
 
368
- # CHANGE DEFAULTS TO BE FOR MODEL OTHER THAN CURRENT
369
- self.model_version = model_version
370
- if self.model_version not in ["V1","V2"]:
371
- logger.error(
372
- "Unrecognized model version. Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells."
373
- )
374
- elif self.model_version == "V1":
375
- self.model_input_size = 2048
376
- self.special_token = False
377
- from . import ENSEMBL_MAPPING_FILE_30M, GENE_MEDIAN_FILE_30M, TOKEN_DICTIONARY_FILE_30M
378
- gene_median_file = GENE_MEDIAN_FILE_30M
379
- token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
380
- gene_mapping_file = ENSEMBL_MAPPING_FILE_30M
381
-
382
  # load dictionary of gene normalization factors
383
- # (non-zero median value of expression across Genecorpus)
384
  with open(gene_median_file, "rb") as f:
385
  self.gene_median_dict = pickle.load(f)
386
 
@@ -403,18 +351,12 @@ class TranscriptomeTokenizer:
403
  "<eos>" in self.gene_token_dict.keys()
404
  ):
405
  logger.warning(
406
- "<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for V2 model series, special_token should be True."
407
  )
408
 
409
  # if collapsing duplicate gene IDs
410
  self.collapse_gene_ids = collapse_gene_ids
411
 
412
- # if using h5ad index as ensembl_ids
413
- self.use_h5ad_index = use_h5ad_index
414
-
415
- # if keeping counts within dataset column
416
- self.keep_counts = keep_counts
417
-
418
  # load gene mappings dictionary (Ensembl IDs:Ensembl ID)
419
  if gene_mapping_file is not None:
420
  with open(gene_mapping_file, "rb") as f:
@@ -439,8 +381,7 @@ class TranscriptomeTokenizer:
439
  data_directory: Path | str,
440
  output_directory: Path | str,
441
  output_prefix: str,
442
- file_format: Literal["loom", "h5ad", "zarr"] = "loom",
443
- input_identifier: str = "",
444
  use_generator: bool = False,
445
  ):
446
  """
@@ -455,21 +396,17 @@ class TranscriptomeTokenizer:
455
  output_prefix : str
456
  | Prefix for output .dataset
457
  file_format : str
458
- | Format of input files. Can be "loom", "h5ad", or "zarr".
459
- input_identifier : str
460
- | Substring identifier for input .loom, .h5ad, or .zarr, only matches are tokenized
461
- | Default is no identifier, tokenizes all files in provided directory.
462
  use_generator : bool
463
  | Whether to use generator or dict for tokenization.
464
 
465
  """
466
- tokenized_cells, cell_metadata, tokenized_counts = self.tokenize_files(
467
- Path(data_directory), file_format, input_identifier
468
  )
469
  tokenized_dataset = self.create_dataset(
470
  tokenized_cells,
471
  cell_metadata,
472
- tokenized_counts,
473
  use_generator=use_generator,
474
  )
475
 
@@ -477,10 +414,9 @@ class TranscriptomeTokenizer:
477
  tokenized_dataset.save_to_disk(str(output_path))
478
 
479
  def tokenize_files(
480
- self, data_directory, file_format: Literal["loom", "h5ad", "zarr"] = "loom", input_identifier: str = ""
481
  ):
482
  tokenized_cells = []
483
- tokenized_counts = []
484
  if self.custom_attr_name_dict is not None:
485
  cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
486
  cell_metadata = {
@@ -489,20 +425,15 @@ class TranscriptomeTokenizer:
489
 
490
  # loops through directories to tokenize .loom files
491
  file_found = 0
492
- # loops through directories to tokenize .loom, .h5ad, or .zarr files
493
  tokenize_file_fn = (
494
  self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
495
  )
496
- if input_identifier == "":
497
- file_match = f"*.{file_format}"
498
- else:
499
- file_match = f"*{input_identifier}*.{file_format}"
500
- for file_path in data_directory.glob(file_match):
501
  file_found = 1
502
  print(f"Tokenizing {file_path}")
503
- file_tokenized_cells, file_cell_metadata, file_tokenized_counts = tokenize_file_fn(file_path, file_format=file_format)
504
  tokenized_cells += file_tokenized_cells
505
- tokenized_counts += file_tokenized_counts
506
  if self.custom_attr_name_dict is not None:
507
  for k in cell_attr:
508
  cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[
@@ -516,17 +447,15 @@ class TranscriptomeTokenizer:
516
  f"No .{file_format} files found in directory {data_directory}."
517
  )
518
  raise
519
- return tokenized_cells, cell_metadata, tokenized_counts
520
 
521
- def tokenize_anndata(self, adata_file_path, target_sum=10_000, file_format="h5ad"):
522
  adata = sum_ensembl_ids(
523
  adata_file_path,
524
  self.collapse_gene_ids,
525
  self.gene_mapping_dict,
526
  self.gene_token_dict,
527
- self.custom_attr_name_dict,
528
- self.use_h5ad_index,
529
- file_format=file_format,
530
  chunk_size=self.chunk_size,
531
  )
532
 
@@ -565,7 +494,6 @@ class TranscriptomeTokenizer:
565
  filter_pass_loc = np.array([i for i in range(adata.shape[0])])
566
 
567
  tokenized_cells = []
568
- tokenized_counts = []
569
 
570
  for i in range(0, len(filter_pass_loc), self.chunk_size):
571
  idx = filter_pass_loc[i : i + self.chunk_size]
@@ -573,23 +501,14 @@ class TranscriptomeTokenizer:
573
  n_counts = adata[idx].obs["n_counts"].values[:, None]
574
  X_view0 = adata[idx, :].X
575
  X_view = X_view0[:, coding_miRNA_loc]
576
- X_norm_unscaled = X_view / n_counts * target_sum
577
- X_norm = X_norm_unscaled / norm_factor_vector
578
  X_norm = sp.csr_matrix(X_norm)
579
- X_norm_unscaled = sp.csr_matrix(X_norm_unscaled)
580
 
581
  tokenized_cells += [
582
  rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
583
  for i in range(X_norm.shape[0])
584
  ]
585
 
586
- if self.keep_counts:
587
- X_norm_unscaled = sp.csr_matrix(X_norm_unscaled)
588
- tokenized_counts += [
589
- rank_genes(X_norm[i].data, X_norm_unscaled[i].data)
590
- for i in range(X_norm.shape[0])
591
- ]
592
-
593
  # add custom attributes for subview to dict
594
  if self.custom_attr_name_dict is not None:
595
  for k in file_cell_metadata.keys():
@@ -597,28 +516,9 @@ class TranscriptomeTokenizer:
597
  else:
598
  file_cell_metadata = None
599
 
600
- # ensure no tokenized_cells are empty
601
- empty_cell_indices = [i for i, cell in enumerate(tokenized_cells) if cell.size == 0]
602
- if len(empty_cell_indices) > 0:
603
- logger.warning(
604
- "Warning: cells without any genes in token dictionary detected. This is unusual and may indicate empty droplets or otherwise invalid cells within the input data. Consider further QC prior to tokenization. Proceeding with excluding empty cells."
605
- )
606
- empty_cell_indices.sort(reverse=True) # for safe deletion
607
- for index in empty_cell_indices:
608
- del tokenized_cells[index]
609
- if self.keep_counts:
610
- del tokenized_counts[index]
611
- # remove corresponding metadata
612
- for k,v in file_cell_metadata.items():
613
- for index in empty_cell_indices:
614
- del v[index]
615
- file_cell_metadata[k] = v
616
-
617
- return tokenized_cells, file_cell_metadata, tokenized_counts
618
 
619
- def tokenize_loom(self, loom_file_path, target_sum=10_000, file_format="loom"):
620
- tokenized_counts = [] # keep_counts not implemented for tokenize_loom
621
-
622
  if self.custom_attr_name_dict is not None:
623
  file_cell_metadata = {
624
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
@@ -631,9 +531,7 @@ class TranscriptomeTokenizer:
631
  self.collapse_gene_ids,
632
  self.gene_mapping_dict,
633
  self.gene_token_dict,
634
- self.custom_attr_name_dict,
635
- use_h5ad_index=False,
636
- file_format=file_format,
637
  chunk_size=self.chunk_size,
638
  )
639
 
@@ -706,33 +604,29 @@ class TranscriptomeTokenizer:
706
  del data.ra["ensembl_id_collapsed"]
707
 
708
 
709
- return tokenized_cells, file_cell_metadata, tokenized_counts
710
 
711
  def create_dataset(
712
  self,
713
  tokenized_cells,
714
  cell_metadata,
715
- tokenized_counts,
716
  use_generator=False,
717
  keep_uncropped_input_ids=False,
718
  ):
719
  print("Creating dataset.")
720
  # create dict for dataset creation
721
  dataset_dict = {"input_ids": tokenized_cells}
722
- if self.keep_counts:
723
- dataset_dict["counts"] = tokenized_counts
724
-
725
  if self.custom_attr_name_dict is not None:
726
  dataset_dict.update(cell_metadata)
727
 
728
  # create dataset
729
  if use_generator:
 
730
  def dict_generator():
731
  for i in range(len(tokenized_cells)):
732
  yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
733
 
734
  output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
735
-
736
  else:
737
  output_dataset = Dataset.from_dict(dataset_dict)
738
 
@@ -755,23 +649,9 @@ class TranscriptomeTokenizer:
755
  len(example["input_ids"]),
756
  self.gene_token_dict.get("<eos>"),
757
  )
758
- if self.keep_counts:
759
- example["counts"] = example["counts"][
760
- 0 : self.model_input_size - 2
761
- ] # truncate to leave space for CLS and EOS token
762
- example["counts"] = np.insert(
763
- example["counts"], 0, 0.0
764
- )
765
- example["counts"] = np.insert(
766
- example["counts"],
767
- len(example["counts"]),
768
- 0.0,
769
- )
770
  else:
771
  # Truncate/Crop input_ids to input size
772
  example["input_ids"] = example["input_ids"][0 : self.model_input_size]
773
- if self.keep_counts:
774
- example["counts"] = example["counts"][0 : self.model_input_size]
775
  example["length"] = len(example["input_ids"])
776
 
777
  return example
 
3
 
4
  **Input data:**
5
 
6
+ | *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file.
7
  | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
8
  | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
9
 
 
20
 
21
  **Description:**
22
 
23
+ | Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
24
 
25
+ | The discussion below references the .loom file format, but the analogous labels are required for .h5ad files, except that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.
26
 
27
  | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
28
 
 
30
 
31
  | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
32
 
33
+ | If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.
34
 
35
+ | OF NOTE: Take care that the correct token dictionary and gene median file are used for the correct model.
36
+
37
+ | OF NOTE: For 95M model series, special_token should be True and model_input_size should be 4096. For 30M model series, special_token should be False and model_input_size should be 2048.
38
 
39
  """
40
 
 
48
  from pathlib import Path
49
  from typing import Literal
50
 
 
51
  import loompy as lp
52
  import numpy as np
53
  import pandas as pd
 
88
  collapse_gene_ids,
89
  gene_mapping_dict,
90
  gene_token_dict,
 
 
91
  file_format="loom",
92
  chunk_size=512,
93
  ):
 
103
  assert (
104
  "ensembl_id_collapsed" not in data.ra.keys()
105
  ), "'ensembl_id_collapsed' column already exists in data.ra.keys()"
 
 
 
 
 
 
 
 
 
 
 
106
  # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
107
  # Comparing to gene_token_dict here, would not perform any mapping steps
108
+ gene_ids_in_dict = [
109
+ gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
110
+ ]
111
+ if collapse_gene_ids is False:
112
+
113
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
114
  return data_directory
115
  else:
116
  raise ValueError("Error: data Ensembl IDs non-unique.")
117
+
118
+ gene_ids_collapsed = [
119
+ gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id
120
+ ]
121
+ gene_ids_collapsed_in_dict = [
122
+ gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
123
+ ]
124
+
125
+ if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
126
+ data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
127
  return data_directory
 
128
  else:
129
  dedup_filename = data_directory.with_name(
130
  data_directory.stem + "__dedup.loom"
131
  )
132
+ data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
 
133
  dup_genes = [
134
  idx
135
  for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
 
188
  dsout.add_columns(processed_array, col_attrs=view.ca)
189
  return dedup_filename
190
 
191
+ elif file_format == "h5ad":
192
  """
193
  Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
194
  Returns adata object with deduplicated Ensembl IDs.
195
  """
196
 
197
+ data = sc.read_h5ad(str(data_directory))
 
 
 
 
 
 
198
 
199
  assert (
200
  "ensembl_id" in data.var.columns
 
203
  assert (
204
  "ensembl_id_collapsed" not in data.var.columns
205
  ), "'ensembl_id_collapsed' column already exists in data.var"
 
 
 
 
 
 
 
206
 
 
 
 
207
  # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
208
  # Comparing to gene_token_dict here, would not perform any mapping steps
209
+ gene_ids_in_dict = [
210
+ gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
211
+ ]
212
+ if collapse_gene_ids is False:
213
+
214
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
215
  return data
216
  else:
217
  raise ValueError("Error: data Ensembl IDs non-unique.")
218
 
219
+ # Check for when if collapse_gene_ids is True
220
+ gene_ids_collapsed = [
221
+ gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
222
+ ]
223
+ gene_ids_collapsed_in_dict = [
224
+ gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
225
+ ]
226
+ if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
227
+ data.var["ensembl_id_collapsed"] = data.var.ensembl_id.map(gene_mapping_dict)
228
  return data
229
+
230
  else:
231
+ data.var["ensembl_id_collapsed"] = gene_ids_collapsed
232
+ data.var_names = gene_ids_collapsed
233
  data = data[:, ~data.var.index.isna()]
234
  dup_genes = [
235
  idx for idx, count in Counter(data.var_names).items() if count > 1
 
278
  model_input_size=4096,
279
  special_token=True,
280
  collapse_gene_ids=True,
 
 
 
281
  gene_median_file=GENE_MEDIAN_FILE,
282
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
283
  gene_mapping_file=ENSEMBL_MAPPING_FILE,
 
297
  | Chunk size for anndata tokenizer.
298
  model_input_size : int = 4096
299
  | Max input size of model to truncate input to.
300
+ | For the 30M model series, should be 2048. For the 95M model series, should be 4096.
301
  special_token : bool = True
302
  | Adds CLS token before and EOS token after rank value encoding.
303
+ | For the 30M model series, should be False. For the 95M model series, should be True.
304
  collapse_gene_ids : bool = True
305
  | Whether to collapse gene IDs based on gene mapping dictionary.
 
 
 
 
 
 
 
 
306
  gene_median_file : Path
307
  | Path to pickle file containing dictionary of non-zero median
308
+ | gene expression values across Genecorpus-30M.
309
  token_dictionary_file : Path
310
  | Path to pickle file containing token dictionary (Ensembl IDs:token).
311
  gene_mapping_file : None, Path
 
327
  # add CLS and EOS tokens
328
  self.special_token = special_token
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  # load dictionary of gene normalization factors
331
+ # (non-zero median value of expression across Genecorpus-30M)
332
  with open(gene_median_file, "rb") as f:
333
  self.gene_median_dict = pickle.load(f)
334
 
 
351
  "<eos>" in self.gene_token_dict.keys()
352
  ):
353
  logger.warning(
354
+ "<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for 95M model series, special_token should be True."
355
  )
356
 
357
  # if collapsing duplicate gene IDs
358
  self.collapse_gene_ids = collapse_gene_ids
359
 
 
 
 
 
 
 
360
  # load gene mappings dictionary (Ensembl IDs:Ensembl ID)
361
  if gene_mapping_file is not None:
362
  with open(gene_mapping_file, "rb") as f:
 
381
  data_directory: Path | str,
382
  output_directory: Path | str,
383
  output_prefix: str,
384
+ file_format: Literal["loom", "h5ad"] = "loom",
 
385
  use_generator: bool = False,
386
  ):
387
  """
 
396
  output_prefix : str
397
  | Prefix for output .dataset
398
  file_format : str
399
+ | Format of input files. Can be "loom" or "h5ad".
 
 
 
400
  use_generator : bool
401
  | Whether to use generator or dict for tokenization.
402
 
403
  """
404
+ tokenized_cells, cell_metadata = self.tokenize_files(
405
+ Path(data_directory), file_format
406
  )
407
  tokenized_dataset = self.create_dataset(
408
  tokenized_cells,
409
  cell_metadata,
 
410
  use_generator=use_generator,
411
  )
412
 
 
414
  tokenized_dataset.save_to_disk(str(output_path))
415
 
416
  def tokenize_files(
417
+ self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
418
  ):
419
  tokenized_cells = []
 
420
  if self.custom_attr_name_dict is not None:
421
  cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
422
  cell_metadata = {
 
425
 
426
  # loops through directories to tokenize .loom files
427
  file_found = 0
428
+ # loops through directories to tokenize .loom or .h5ad files
429
  tokenize_file_fn = (
430
  self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
431
  )
432
+ for file_path in data_directory.glob(f"*.{file_format}"):
 
 
 
 
433
  file_found = 1
434
  print(f"Tokenizing {file_path}")
435
+ file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
436
  tokenized_cells += file_tokenized_cells
 
437
  if self.custom_attr_name_dict is not None:
438
  for k in cell_attr:
439
  cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[
 
447
  f"No .{file_format} files found in directory {data_directory}."
448
  )
449
  raise
450
+ return tokenized_cells, cell_metadata
451
 
452
+ def tokenize_anndata(self, adata_file_path, target_sum=10_000):
453
  adata = sum_ensembl_ids(
454
  adata_file_path,
455
  self.collapse_gene_ids,
456
  self.gene_mapping_dict,
457
  self.gene_token_dict,
458
+ file_format="h5ad",
 
 
459
  chunk_size=self.chunk_size,
460
  )
461
 
 
494
  filter_pass_loc = np.array([i for i in range(adata.shape[0])])
495
 
496
  tokenized_cells = []
 
497
 
498
  for i in range(0, len(filter_pass_loc), self.chunk_size):
499
  idx = filter_pass_loc[i : i + self.chunk_size]
 
501
  n_counts = adata[idx].obs["n_counts"].values[:, None]
502
  X_view0 = adata[idx, :].X
503
  X_view = X_view0[:, coding_miRNA_loc]
504
+ X_norm = X_view / n_counts * target_sum / norm_factor_vector
 
505
  X_norm = sp.csr_matrix(X_norm)
 
506
 
507
  tokenized_cells += [
508
  rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
509
  for i in range(X_norm.shape[0])
510
  ]
511
 
 
 
 
 
 
 
 
512
  # add custom attributes for subview to dict
513
  if self.custom_attr_name_dict is not None:
514
  for k in file_cell_metadata.keys():
 
516
  else:
517
  file_cell_metadata = None
518
 
519
+ return tokenized_cells, file_cell_metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
 
521
+ def tokenize_loom(self, loom_file_path, target_sum=10_000):
 
 
522
  if self.custom_attr_name_dict is not None:
523
  file_cell_metadata = {
524
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
 
531
  self.collapse_gene_ids,
532
  self.gene_mapping_dict,
533
  self.gene_token_dict,
534
+ file_format="loom",
 
 
535
  chunk_size=self.chunk_size,
536
  )
537
 
 
604
  del data.ra["ensembl_id_collapsed"]
605
 
606
 
607
+ return tokenized_cells, file_cell_metadata
608
 
609
  def create_dataset(
610
  self,
611
  tokenized_cells,
612
  cell_metadata,
 
613
  use_generator=False,
614
  keep_uncropped_input_ids=False,
615
  ):
616
  print("Creating dataset.")
617
  # create dict for dataset creation
618
  dataset_dict = {"input_ids": tokenized_cells}
 
 
 
619
  if self.custom_attr_name_dict is not None:
620
  dataset_dict.update(cell_metadata)
621
 
622
  # create dataset
623
  if use_generator:
624
+
625
  def dict_generator():
626
  for i in range(len(tokenized_cells)):
627
  yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
628
 
629
  output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
 
630
  else:
631
  output_dataset = Dataset.from_dict(dataset_dict)
632
 
 
649
  len(example["input_ids"]),
650
  self.gene_token_dict.get("<eos>"),
651
  )
 
 
 
 
 
 
 
 
 
 
 
 
652
  else:
653
  # Truncate/Crop input_ids to input size
654
  example["input_ids"] = example["input_ids"][0 : self.model_input_size]
 
 
655
  example["length"] = len(example["input_ids"])
656
 
657
  return example
generation_config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
  "_from_model_config": true,
3
  "pad_token_id": 0,
4
- "transformers_version": "4.44.2"
5
  }
 
1
  {
2
  "_from_model_config": true,
3
  "pad_token_id": 0,
4
+ "transformers_version": "4.37.1"
5
  }
gf-12L-30M-i2048/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.02,
6
+ "gradient_checkpointing": false,
7
+ "hidden_act": "relu",
8
+ "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 2048,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "transformers_version": "4.6.0",
20
+ "type_vocab_size": 2,
21
+ "use_cache": true,
22
+ "vocab_size": 25426
23
+ }
gf-12L-30M-i2048/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:812f8d85e5ecf9d64c268f052f6ece2c1906bc4f1aecf70d5144b2598386b615
3
+ size 158467410