This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Geneformer-V2-316M/model.safetensors +0 -3
  2. MANIFEST.in +4 -9
  3. README.md +18 -16
  4. config.json +7 -7
  5. docs/source/geneformer.in_silico_perturber.rst +1 -1
  6. examples/cell_classification.ipynb +7 -6
  7. examples/distributed_multitask_cell_classification.ipynb +0 -149
  8. examples/extract_and_plot_cell_embeddings.ipynb +7 -6
  9. examples/gene_classification.ipynb +9 -10
  10. examples/in_silico_perturbation.ipynb +13 -10
  11. examples/multitask_cell_classification.ipynb +3 -3
  12. examples/tokenizing_scRNAseq_data.ipynb +8 -4
  13. {Geneformer-V2-104M → fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522}/config.json +6 -6
  14. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +3 -0
  15. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/config.json +0 -0
  16. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/optimizer.pt +0 -0
  17. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin +0 -0
  18. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/rng_state.pth +0 -0
  19. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/scheduler.pt +0 -0
  20. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/trainer_state.json +0 -0
  21. fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/training_args.bin +0 -0
  22. geneformer/__init__.py +4 -9
  23. geneformer/classifier.py +7 -53
  24. geneformer/classifier_utils.py +3 -27
  25. geneformer/collator_for_classification.py +2 -7
  26. geneformer/emb_extractor.py +46 -115
  27. geneformer/{ensembl_mapping_dict_gc104M.pkl → ensembl_mapping_dict_gc95M.pkl} +0 -0
  28. geneformer/evaluation_utils.py +9 -24
  29. geneformer/{gene_median_dictionary_gc104M.pkl → gene_median_dictionary_gc95M.pkl} +0 -0
  30. geneformer/{gene_name_id_dict_gc104M.pkl → gene_name_id_dict_gc95M.pkl} +0 -0
  31. geneformer/in_silico_perturber.py +3 -23
  32. geneformer/in_silico_perturber_stats.py +1 -15
  33. geneformer/mtl/__init__.py +1 -4
  34. geneformer/mtl/collators.py +2 -2
  35. geneformer/mtl/data.py +117 -192
  36. geneformer/mtl/eval_utils.py +8 -5
  37. geneformer/mtl/imports.py +43 -0
  38. geneformer/mtl/model.py +1 -1
  39. geneformer/mtl/optuna_utils.py +27 -0
  40. geneformer/mtl/train.py +327 -666
  41. geneformer/mtl/train_utils.py +161 -0
  42. geneformer/mtl/utils.py +91 -603
  43. geneformer/mtl_classifier.py +8 -23
  44. geneformer/perturber_utils.py +62 -52
  45. geneformer/{token_dictionary_gc104M.pkl → token_dictionary_gc95M.pkl} +0 -0
  46. geneformer/tokenizer.py +33 -130
  47. generation_config.json +1 -1
  48. gf-12L-30M-i2048/config.json +23 -0
  49. gf-12L-30M-i2048/pytorch_model.bin +3 -0
  50. {Geneformer-V2-316M → gf-12L-30M-i2048}/training_args.bin +2 -2
Geneformer-V2-316M/model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:965ceccea81953d362081ef3843560a0e4fef88d396c28017881f1e94b1246f3
- size 1265455076
MANIFEST.in CHANGED
@@ -1,9 +1,4 @@
- include geneformer/gene_median_dictionary_gc104M.pkl
- include geneformer/gene_name_id_dict_gc104M.pkl
- include geneformer/ensembl_mapping_dict_gc104M.pkl
- include geneformer/token_dictionary_gc104M.pkl
-
- include geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl
- include geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl
- include geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl
- include geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl
+ include geneformer/gene_median_dictionary_gc95M.pkl
+ include geneformer/gene_name_id_dict_gc95M.pkl
+ include geneformer/ensembl_mapping_dict_gc95M.pkl
+ include geneformer/token_dictionary_gc95M.pkl
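The four dictionary `.pkl` files listed in MANIFEST.in are plain pickled Python dicts. As a quick illustration (a hedged sketch: the miniature dictionary and temp path below are hypothetical, and only the Ensembl-ID-to-token-id shape is assumed), they round-trip with the standard library alone:

```python
import os
import pickle
import tempfile

# Hypothetical miniature token dictionary in the shape the packaged
# token dictionary pickles are assumed to use: special tokens plus
# Ensembl gene IDs mapped to integer token ids.
token_dict = {"<pad>": 0, "<mask>": 1, "ENSG00000139618": 2, "ENSG00000141510": 3}

# Round-trip through pickle, the same way a packaged dictionary would be loaded.
path = os.path.join(tempfile.mkdtemp(), "token_dictionary_demo.pkl")
with open(path, "wb") as f:
    pickle.dump(token_dict, f)
with open(path, "rb") as f:
    loaded = pickle.load(f)

print(loaded == token_dict)  # True
```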
README.md CHANGED
@@ -6,32 +6,38 @@ tags:
  - genomics
  ---
  # Geneformer
- Geneformer is a foundational transformer model pretrained on a large-scale corpus of human single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
+ Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.

  - See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
- - See [our manuscript](https://rdcu.be/famFk) for details of the expanded model, now trained on ~104 million transcriptomes, and our quantization implementation for resource-efficient predictions.
- - See [our preprint](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of our continual and multitask learning strategies.
+ - See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
  - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.

  # Model Description
- Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer V1 was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million human single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. The current updated Geneformer V2 is pretrained on ~104 million human single cell transcriptomes (non-cancer) from [Genecorpus-104M](https://huggingface.co/datasets/theodoris-lab/Genecorpus-104M). The cancer continual learning V2 variant was continually pretrained on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.
+ Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Then, in April 2024, Geneformer was pretrained on ~95 million non-cancer transcriptomes, followed by continual learning on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.

- Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus (~30M for V1, ~104M for V2). The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
+ Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.

  The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.

  We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).

- During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational AI model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
+ During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.

  The repository includes the following pretrained models:

- - Geneformer-V1-10M: original model trained June 2021 on ~30M human single cell transcriptomes, 10M parameters, input size 2048, vocabulary ~25K protein-coding or non-coding RNA genes
- - Geneformer-V2-104M and Geneformer-V2-316M: updated model trained Dec 2024 on ~104M human single cell transcriptomes, 104M or 316M parameters, input size 4096, vocabulary ~20K protein-coding genes
+ L=layers\
+ M=millions of cells used for pretraining\
+ i=input size\
+ (pretraining date)

+ - GF-6L-30M-i2048 (June 2021)
+ - GF-12L-30M-i2048 (June 2021)
+ - GF-12L-95M-i4096 (April 2024)
+ - GF-20L-95M-i4096 (April 2024)

- The current default model in the main directory of the repository is Geneformer-V2-316M.
+ The current default model in the main directory of the repository is GF-12L-95M-i4096.
+
- The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, Geneformer-V2-104M_CLcancer.
+ The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer.

  # Application
  The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
@@ -79,14 +85,10 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main
  - extracting and plotting cell embeddings
  - in silico perturbation

- Please also see [here](https://tinyurl.com/geneformertutorial) for a quickstart tutorial for predicting candidate therapeutic targets with Geneformer.
-
- Complete documentation is available at https://geneformer.readthedocs.io/en/latest/.
-
  Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.

- Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.). Importantly, as usual for deep learning models, there are no uniformly applicable default hyperparameters for Geneformer.
+ Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).

  # Citations
  - C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
- - H Chen*, M S Venkatesh*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka†, C V Theodoris†#. Scaling and quantization of large-scale foundation model enables resource-efficient predictions in network biology. _**Nature Computational Science**_, 27 Mar 2026. (*co-first authors, †co-senior authors, #corresponding author)
+ - H Chen*, M S Venkatesh*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka†, C V Theodoris†#. Quantized multi-task learning for context-specific representations of gene network dynamics. _**bioRxiv**_, 19 Aug 2024. (*co-first authors, †co-senior authors, #corresponding author)
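The rank value encoding described in the README text above can be sketched in a few lines (an illustrative toy only: the gene names, counts, and median-scaling rule are assumed simplifications, not the actual Geneformer tokenizer):

```python
# Toy sketch of the rank value encoding described above (assumed
# simplification of the actual tokenizer): each gene's count in the cell
# is scaled by that gene's expression (here, a median) across the corpus,
# then genes are ordered highest-scaled first.
cell_counts = {"HK_GENE": 900.0, "TF_A": 12.0, "GENE_B": 50.0}      # hypothetical cell
corpus_medians = {"HK_GENE": 1000.0, "TF_A": 2.0, "GENE_B": 40.0}   # hypothetical corpus medians

scaled = {g: cell_counts[g] / corpus_medians[g] for g in cell_counts}
rank_encoding = sorted(scaled, key=scaled.get, reverse=True)

# The ubiquitously high housekeeping gene drops to the bottom; the lowly
# expressed but cell-state-distinguishing transcription factor rises to the top.
print(rank_encoding)  # ['TF_A', 'GENE_B', 'HK_GENE']
```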
config.json CHANGED
@@ -2,22 +2,22 @@
  "architectures": [
    "BertForMaskedLM"
  ],
- "attention_probs_dropout_prob": 0.1,
+ "attention_probs_dropout_prob": 0.02,
  "classifier_dropout": null,
  "hidden_act": "relu",
- "hidden_dropout_prob": 0.1,
- "hidden_size": 1152,
+ "hidden_dropout_prob": 0.02,
+ "hidden_size": 512,
  "initializer_range": 0.02,
- "intermediate_size": 4608,
+ "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 4096,
  "model_type": "bert",
- "num_attention_heads": 18,
- "num_hidden_layers": 18,
+ "num_attention_heads": 8,
+ "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
- "transformers_version": "4.44.2",
+ "transformers_version": "4.37.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 20275
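A quick consistency check on the updated default config above (a standard-library sketch; the dict merely restates values from the `+` side of the diff): in a BERT-style model, hidden_size must divide evenly among the attention heads.

```python
# Values restated from the updated config.json above.
new_config = {
    "hidden_size": 512,
    "intermediate_size": 1024,
    "num_attention_heads": 8,
    "num_hidden_layers": 12,
    "max_position_embeddings": 4096,
    "vocab_size": 20275,
}

# BERT-style attention requires hidden_size % num_attention_heads == 0.
assert new_config["hidden_size"] % new_config["num_attention_heads"] == 0
head_dim = new_config["hidden_size"] // new_config["num_attention_heads"]
print(head_dim)  # 64
```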
docs/source/geneformer.in_silico_perturber.rst CHANGED
@@ -5,4 +5,4 @@ geneformer.in\_silico\_perturber
  :members:
  :undoc-members:
  :show-inheritance:
- :exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, , isp_perturb_all_special, isp_perturb_set_special, update_perturbation_dictionary
+ :exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, update_perturbation_dictionary
examples/cell_classification.ipynb CHANGED
@@ -13,7 +13,7 @@
  "id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
  "metadata": {},
  "source": [
- "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters."
+ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization."
  ]
  },
  {
@@ -69,7 +69,9 @@
  " \"seed\": 73,\n",
  "}\n",
  "\n",
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
+ "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
+ "# (otherwise the Classifier will use the current default model dictionary)\n",
+ "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
  "cc = Classifier(classifier=\"cell\",\n",
  " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n",
  " filter_data=filter_data_dict,\n",
@@ -78,7 +80,6 @@
  " freeze_layers = 2,\n",
  " num_crossval_splits = 1,\n",
  " forward_batch_size=200,\n",
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
  " nproc=16)"
  ]
  },
@@ -263,8 +264,8 @@
  " \"train\": train_ids,\n",
  " \"eval\": eval_ids}\n",
  "\n",
- "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
- "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
+ "# Example 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n",
+ "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n",
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
  " output_directory=output_dir,\n",
@@ -449,7 +450,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.13"
+ "version": "3.10.15"
  }
  },
  "nbformat": 4,
examples/distributed_multitask_cell_classification.ipynb DELETED
@@ -1,149 +0,0 @@
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b3266a7b",
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "import torch\n",
- "from geneformer import MTLClassifier"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3e12ac9f",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define paths\n",
- "pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
- "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
- "train_path = \"/path/to/train/data.dataset\"\n",
- "val_path = \"/path/to/val/data.dataset\"\n",
- "test_path = \"/path/to/test/data.dataset\"\n",
- "results_dir = \"/path/to/results/directory\"\n",
- "model_save_path = \"/path/to/model/save/path\"\n",
- "tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
- "\n",
- "# Define tasks and hyperparameters\n",
- "# task_columns should be a list of column names from your dataset\n",
- "# Each column represents a specific classification task (e.g. cell type, disease state)\n",
- "task_columns = [\"cell_type\", \"disease_state\"] # Example task columns"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c9bd7562",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Check GPU environment\n",
- "num_gpus = torch.cuda.device_count()\n",
- "use_distributed = num_gpus > 1\n",
- "print(f\"Number of GPUs detected: {num_gpus}\")\n",
- "print(f\"Using distributed training: {use_distributed}\")\n",
- "\n",
- "# Set environment variables for distributed training when multiple GPUs are available\n",
- "if use_distributed:\n",
- " os.environ[\"MASTER_ADDR\"] = \"localhost\" # hostname\n",
- " os.environ[\"MASTER_PORT\"] = \"12355\" # Choose an available port\n",
- " print(\"Distributed environment variables set.\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b6ff3618",
- "metadata": {},
- "outputs": [],
- "source": [
- "#Define Hyperparameters for Optimization\n",
- "hyperparameters = {\n",
- " \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
- " \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
- " \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
- " \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
- " \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
- " \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0},\n",
- "}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f665c5a7",
- "metadata": {},
- "outputs": [],
- "source": [
- "mc = MTLClassifier(\n",
- " task_columns=task_columns, # Our defined classification tasks\n",
- " study_name=\"MTLClassifier_distributed\",\n",
- " pretrained_path=pretrained_path,\n",
- " train_path=train_path,\n",
- " val_path=val_path,\n",
- " test_path=test_path,\n",
- " model_save_path=model_save_path,\n",
- " results_dir=results_dir,\n",
- " tensorboard_log_dir=tensorboard_log_dir,\n",
- " hyperparameters=hyperparameters,\n",
- " # Distributed training parameters\n",
- " distributed_training=use_distributed, # Enable distributed training if multiple GPUs available\n",
- " master_addr=\"localhost\" if use_distributed else None,\n",
- " master_port=\"12355\" if use_distributed else None,\n",
- " # Other training parameters\n",
- " n_trials=15, # Number of trials for hyperparameter optimization\n",
- " epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
- " batch_size=8, # Adjust based on available GPU memory\n",
- " gradient_accumulation_steps=4, # Accumulate gradients over multiple steps\n",
- " gradient_clipping=True, # Enable gradient clipping for stability\n",
- " max_grad_norm=1.0, # Set maximum gradient norm\n",
- " seed=42\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f69f7b6a",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Run Hyperparameter Optimization with Distributed Training\n",
- "if __name__ == \"__main__\":\n",
- " # This guard is required for distributed training to prevent\n",
- " # infinite subprocess spawning when using torch.multiprocessing\n",
- " mc.run_optuna_study()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3affd5dd",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Evaluate the Model on Test Data\n",
- "if __name__ == \"__main__\":\n",
- " mc.load_and_evaluate_test_model()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "bio",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "name": "python",
- "version": "3.12.8"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
- }
examples/extract_and_plot_cell_embeddings.ipynb CHANGED
@@ -18,7 +18,8 @@
  "outputs": [],
  "source": [
  "# initiate EmbExtractor\n",
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
+ "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
+ "# (otherwise the EmbExtractor will use the current default model dictionary)\n",
  "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
  " num_classes=3,\n",
  " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
@@ -27,13 +28,13 @@
  " emb_label=[\"disease\",\"cell_type\"],\n",
  " labels_to_plot=[\"disease\"],\n",
  " forward_batch_size=200,\n",
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
- " nproc=16)\n",
+ " nproc=16,\n",
+ " token_dictionary_file=\"./gene_dictionaries_30m/token_dictionary_gc30M.pkl\") # change from current default dictionary for 30M model series\n",
  "\n",
  "# extracts embedding from input data\n",
  "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
- "# example dataset for V1 model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
- "embs = embex.extract_embs(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n",
+ "# example dataset for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
+ "embs = embex.extract_embs(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
  " \"path/to/input_data/\",\n",
  " \"path/to/output_directory/\",\n",
  " \"output_prefix\")\n"
@@ -131,7 +132,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.13"
+ "version": "3.10.15"
  }
  },
  "nbformat": 4,
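The hunks above replace the `model_version` switch with an explicit `token_dictionary_file` argument pointing at the gc-30M dictionary. A minimal sketch of that selection logic as plain Python (the series names and the `None`-means-package-default convention are assumptions for illustration, not Geneformer API):

```python
from pathlib import Path

# Illustrative only: maps a model series to the token dictionary it expects.
# The 30M path mirrors the one passed to EmbExtractor in the hunk above;
# None means "fall back to the package's bundled default dictionary".
DICT_FOR_SERIES = {
    "30M": Path("./gene_dictionaries_30m/token_dictionary_gc30M.pkl"),
    "95M": None,
}

def token_dictionary_for(series: str):
    """Return the token dictionary path for a model series, or None for the default."""
    try:
        return DICT_FOR_SERIES[series]
    except KeyError:
        raise ValueError(f"unknown model series: {series}")
```

Keeping the mapping in one place avoids silently scoring a 30M-series model with the default dictionary, which is the failure mode the notebook comments warn about.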
examples/gene_classification.ipynb CHANGED
@@ -13,7 +13,7 @@
  "id": "79539e95-2c9c-4162-835c-f0d158abb15d",
  "metadata": {},
  "source": [
- "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters."
+ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications."
  ]
  },
  {
@@ -71,14 +71,15 @@
  }
  ],
  "source": [
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
+ "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
+ "# (otherwise the Classifier will use the current default model dictionary)\n",
+ "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
  "cc = Classifier(classifier=\"gene\",\n",
  " gene_class_dict = gene_class_dict,\n",
  " max_ncells = 10_000,\n",
  " freeze_layers = 4,\n",
  " num_crossval_splits = 5,\n",
  " forward_batch_size=200,\n",
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
  " nproc=16)"
  ]
  },
@@ -842,8 +843,8 @@
  }
  ],
  "source": [
- "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
- "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
+ "# 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n",
+ "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
  " output_directory=output_dir,\n",
@@ -1065,14 +1066,12 @@
  }
  ],
  "source": [
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
  "cc = Classifier(classifier=\"gene\",\n",
  " gene_class_dict = gene_class_dict,\n",
  " max_ncells = 10_000,\n",
  " freeze_layers = 4,\n",
  " num_crossval_splits = 0,\n",
  " forward_batch_size=200,\n",
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
  " nproc=16)"
  ]
  },
@@ -1219,8 +1218,8 @@
  }
  ],
  "source": [
- "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
- "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
+ "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
+ "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\",\n",
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
  " output_directory=output_dir,\n",
@@ -1244,7 +1243,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.13"
+ "version": "3.10.15"
  }
  },
  "nbformat": 4,
examples/in_silico_perturbation.ipynb CHANGED
@@ -39,7 +39,9 @@
  "\n",
  "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
  "\n",
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
+ "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
+ "# (otherwise the EmbExtractor will use the current default model dictionary)\n",
+ "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
  "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
  " num_classes=3,\n",
  " filter_data=filter_data_dict,\n",
@@ -47,7 +49,6 @@
  " emb_layer=0,\n",
  " summary_stat=\"exact_mean\",\n",
  " forward_batch_size=256,\n",
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
  " nproc=16)\n",
  "\n",
  "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
@@ -66,7 +67,9 @@
  },
  "outputs": [],
  "source": [
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
+ "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
+ "# (otherwise the InSilicoPerturber will use the current default model dictionary)\n",
+ "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
  " perturb_rank_shift=None,\n",
  " genes_to_perturb=\"all\",\n",
@@ -74,7 +77,7 @@
  " anchor_gene=None,\n",
  " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
  " num_classes=3,\n",
- " emb_mode=\"cell\", # OF NOTE: SET TO \"CELL\" FOR V1 MODEL. FOR V2, SHOULD BE \"CLS\" (current default).\n",
+ " emb_mode=\"cell\",\n",
  " cell_emb_style=\"mean_pool\",\n",
  " filter_data=filter_data_dict,\n",
  " cell_states_to_model=cell_states_to_model,\n",
@@ -82,7 +85,6 @@
  " max_ncells=2000,\n",
  " emb_layer=0,\n",
  " forward_batch_size=400,\n",
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
  " nproc=16)"
  ]
  },
@@ -95,7 +97,7 @@
  "source": [
  "# outputs intermediate files from in silico perturbation\n",
  "\n",
- "isp.perturb_data(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n",
+ "isp.perturb_data(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
  " \"path/to/input_data\",\n",
  " \"path/to/isp_output_directory\",\n",
  " \"output_prefix\")"
@@ -108,13 +110,14 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
+ "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
+ "# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n",
+ "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
  "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
  " genes_perturbed=\"all\",\n",
  " combos=0,\n",
  " anchor_gene=None,\n",
- " cell_states_to_model=cell_states_to_model,\n",
- " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL SINCE V1 WAS USED FOR IN SILICO PERTURBATION ABOVE)"
+ " cell_states_to_model=cell_states_to_model)"
  ]
  },
  {
@@ -148,7 +151,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.13"
+ "version": "3.10.15"
  }
  },
  "nbformat": 4,
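In `goal_state_shift` mode, `InSilicoPerturberStats` ranks perturbations by how far they shift cell embeddings toward the goal state defined in `cell_states_to_model`. A toy, dependency-free sketch of that idea using cosine similarity (a simplified stand-in for illustration, not the package's actual statistic):

```python
import math

def cosine(a, b):
    # cosine similarity between two 1-D embedding vectors
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def goal_state_shift(original, perturbed, goal):
    """Positive value: the perturbation moved the cell embedding toward the goal state."""
    return cosine(perturbed, goal) - cosine(original, goal)

# toy 3-D "embeddings": the perturbed cell gains a component along the goal axis
goal = [1.0, 0.0, 0.0]
orig = [0.0, 1.0, 0.0]   # orthogonal to goal
pert = [1.0, 1.0, 0.0]   # partway toward goal
shift = goal_state_shift(orig, pert, goal)
```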
examples/multitask_cell_classification.ipynb CHANGED
@@ -286,7 +286,7 @@
  " filter_data_dict=filter_data_dict,\n",
  " max_ncells=1000, # Number of cells to extract embeddings for\n",
  " emb_layer=0, # Use the second to last layer\n",
- " emb_mode = \"cls\", # Use CLS token embedding for V2 model\n",
+ " emb_mode = \"cls\",\n",
  " summary_stat=\"exact_mean\",\n",
  " forward_batch_size=8, # Adjust based on available GPU memory\n",
  " nproc=4\n",
@@ -324,7 +324,7 @@
  " perturb_type=perturb_type,\n",
  " genes_to_perturb=\"all\", # Perturb all genes\n",
  " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
- " emb_mode=\"cls\", # Use CLS token embedding for V2 model\n",
+ " emb_mode=\"cls\", # Use CLS token embedding\n",
  " cell_states_to_model=cell_states_to_model,\n",
  " state_embs_dict=state_embs_dict,\n",
  " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
@@ -412,7 +412,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.13"
+ "version": "3.11.5"
  }
  },
  "nbformat": 4,
examples/tokenizing_scRNAseq_data.ipynb CHANGED
@@ -34,8 +34,12 @@
  "metadata": {},
  "source": [
  "**********************************************************************************************************\n",
- "#### OF NOTE: Please ensure the correct token dictionary, gene median file, special token setting, and model input size is used for the correct model version.\n",
- "#### Current defaults are for V2 model series. To auto-select the correct settings for V1, set model_version argument to \"V1\"."
+ "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
+ "#### 95M: current defaults; 30M: https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer/gene_dictionaries_30m\n",
+ "\n",
+ "#### ADDITIONALLY:\n",
+ "#### The 95M model series require the special_token argument to be set to True and model_input_size to be 4096. (current defaults)\n",
+ "#### The 30M model series require the special_token argument to be set to False and the model_input_size to be 2048."
  ]
  },
  {
@@ -55,7 +59,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16) # for V1 model, set model_version=\"V1\"\n",
+ "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n",
  "tk.tokenize_data(\"loom_data_directory\", \n",
  " \"output_directory\", \n",
  " \"output_prefix\", \n",
@@ -79,7 +83,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.13"
+ "version": "3.10.15"
  }
  },
  "nbformat": 4,
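The note above pairs each model series with fixed tokenizer settings (95M: `special_token=True`, input size 4096; 30M: `special_token=False`, input size 2048). Captured as a lookup table for illustration (the helper function is an assumption, not part of the Geneformer API):

```python
# Illustrative pairing of model series with tokenizer settings, taken from
# the tokenizing notebook's note; this table is not Geneformer API.
TOKENIZER_SETTINGS = {
    "95M": {"special_token": True, "model_input_size": 4096},   # current defaults
    "30M": {"special_token": False, "model_input_size": 2048},
}

def tokenizer_settings(series: str) -> dict:
    """Look up the special-token flag and input size for a model series."""
    if series not in TOKENIZER_SETTINGS:
        raise ValueError(f"unknown model series: {series}")
    return TOKENIZER_SETTINGS[series]
```

Encoding the pairing once means the special-token flag and input size cannot drift apart when switching between model series.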
{Geneformer-V2-104M → fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522}/config.json RENAMED
@@ -2,22 +2,22 @@
  "architectures": [
  "BertForMaskedLM"
  ],
- "attention_probs_dropout_prob": 0.1,
+ "attention_probs_dropout_prob": 0.02,
  "classifier_dropout": null,
  "hidden_act": "relu",
- "hidden_dropout_prob": 0.1,
- "hidden_size": 768,
+ "hidden_dropout_prob": 0.02,
+ "hidden_size": 512,
  "initializer_range": 0.02,
- "intermediate_size": 3072,
+ "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 4096,
  "model_type": "bert",
- "num_attention_heads": 12,
+ "num_attention_heads": 8,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
- "transformers_version": "4.44.2",
+ "transformers_version": "4.37.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 20275
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4
+ size 152363342
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/config.json RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/optimizer.pt RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/rng_state.pth RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/scheduler.pt RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/trainer_state.json RENAMED
File without changes
fine_tuned_models/{Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/training_args.bin RENAMED
File without changes
geneformer/__init__.py CHANGED
@@ -4,15 +4,10 @@ from pathlib import Path
  
  warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
  
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc104M.pkl"
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc104M.pkl"
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc104M.pkl"
- ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc104M.pkl"
-
- GENE_MEDIAN_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl"
- TOKEN_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/token_dictionary_gc30M.pkl"
- ENSEMBL_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl"
- ENSEMBL_MAPPING_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl"
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
+ ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
+ ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
  
  from . import (
  collator_for_classification,
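The `geneformer/__init__.py` hunk above pins the package defaults to the gc95M dictionaries via `Path(__file__).parent`. A small sketch of the same default-or-override pattern (the constant names mirror the diff; `resolve_token_dictionary` and the literal `"geneformer"` directory are hypothetical stand-ins):

```python
from pathlib import Path

# Stand-in for Path(__file__).parent in the real package module.
PACKAGE_DIR = Path("geneformer")

# Mirrors the default set in the __init__.py hunk above.
TOKEN_DICTIONARY_FILE = PACKAGE_DIR / "token_dictionary_gc95M.pkl"

def resolve_token_dictionary(custom=None):
    """Use the caller's dictionary path if given, else the packaged default."""
    return Path(custom) if custom is not None else TOKEN_DICTIONARY_FILE
```

This is the pattern the `Classifier` uses: a `token_dictionary_file=None` argument falls through to the module-level default.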
geneformer/classifier.py CHANGED
@@ -48,13 +48,11 @@ import logging
  import os
  import pickle
  import subprocess
- from packaging.version import parse
  from pathlib import Path
  
  import numpy as np
  import pandas as pd
  import seaborn as sns
- import transformers
  from tqdm.auto import tqdm, trange
  from transformers import Trainer
  from transformers.training_args import TrainingArguments
@@ -73,7 +71,6 @@ sns.set()
  
  logger = logging.getLogger(__name__)
  
- transformers_version = parse(transformers.__version__)
  
  class Classifier:
  valid_option_dict = {
@@ -92,7 +89,6 @@ class Classifier:
  "no_eval": {bool},
  "stratify_splits_col": {None, str},
  "forward_batch_size": {int},
- "model_version": {"V1", "V2"},
  "token_dictionary_file": {None, str},
  "nproc": {int},
  "ngpu": {int},
@@ -116,7 +112,6 @@
  stratify_splits_col=None,
  no_eval=False,
  forward_batch_size=100,
- model_version="V2",
  token_dictionary_file=None,
  nproc=4,
  ngpu=1,
@@ -193,9 +188,6 @@
  | Otherwise, will perform eval during training.
  forward_batch_size : int
  | Batch size for forward pass (for evaluation, not training).
- model_version : str
- | To auto-select settings for model version other than current default.
- | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
  token_dictionary_file : None, str
  | Default is to use token dictionary file from Geneformer
  | Otherwise, will load custom gene token dictionary.
@@ -230,16 +222,14 @@
  self.stratify_splits_col = stratify_splits_col
  self.no_eval = no_eval
  self.forward_batch_size = forward_batch_size
- self.model_version = model_version
  self.token_dictionary_file = token_dictionary_file
  self.nproc = nproc
  self.ngpu = ngpu
-
+
  if self.training_args is None:
  logger.warning(
  "Hyperparameter tuning is highly recommended for optimal results. "
- "No training_args provided; using default hyperparameters. "
- "Please note: these defaults are not recommended to be used uniformly across tasks."
+ "No training_args provided; using default hyperparameters."
  )
  
  self.validate_options()
@@ -254,10 +244,7 @@
  ] = self.cell_state_dict["states"]
  
  # load token dictionary (Ensembl IDs:token)
- if self.model_version == "V1":
- from . import TOKEN_DICTIONARY_FILE_30M
- self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
- elif self.token_dictionary_file is None:
+ if self.token_dictionary_file is None:
  self.token_dictionary_file = TOKEN_DICTIONARY_FILE
  with open(self.token_dictionary_file, "rb") as f:
  self.gene_token_dict = pickle.load(f)
@@ -367,7 +354,6 @@
  attr_to_balance=None,
  max_trials=100,
  pval_threshold=0.1,
- id_class_dict_path=None,
  ):
  """
  Prepare data for cell state or gene classification.
@@ -410,10 +396,6 @@
  pval_threshold : None, float
  | P-value threshold to use for attribute balancing across splits
  | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
- id_class_dict_path : Path
- | Path to *_id_class_dict.pkl from prior run of prepare_data to reuse for labeling new data
- | Dictionary with keys being numeric class labels and values being original dataset class labels
- | Note: only available for CellClassifiers
  """
  
  if test_size is None:
@@ -457,13 +439,8 @@
  data = cu.rename_cols(data, self.cell_state_dict["state_key"])
  
  # convert classes to numerical labels and save as id_class_dict
- if id_class_dict_path is not None:
- with open(id_class_dict_path,"rb") as fp:
- id_class_dict = pickle.load(fp)
- else:
- id_class_dict = None
  data, id_class_dict = cu.label_classes(
- self.classifier, data, self.cell_state_dict, self.nproc, id_class_dict,
+ self.classifier, data, self.cell_state_dict, self.nproc
  )
  
  elif self.classifier == "gene":
@@ -801,7 +778,7 @@
  # 5-fold cross-validate
  num_cells = len(data)
  fifth_cells = int(np.floor(num_cells * 0.2))
- num_eval = int(min((self.eval_size * num_cells), fifth_cells))
+ num_eval = min((self.eval_size * num_cells), fifth_cells)
  start = i * fifth_cells
  end = start + num_eval
  eval_indices = [j for j in range(start, end)]
@@ -1085,8 +1062,6 @@
  if eval_data is None:
  def_training_args["evaluation_strategy"] = "no"
  def_training_args["load_best_model_at_end"] = False
- if transformers_version >= parse("4.46"):
- def_training_args["eval_strategy"] = def_training_args.pop("evaluation_strategy")
  def_training_args.update(
  {"save_strategy": "epoch", "save_total_limit": 1}
  ) # only save last model for each run
@@ -1258,8 +1233,6 @@
  if eval_data is None:
  def_training_args["evaluation_strategy"] = "no"
  def_training_args["load_best_model_at_end"] = False
- if transformers_version >= parse("4.46"):
- def_training_args["eval_strategy"] = def_training_args.pop("evaluation_strategy")
  training_args_init = TrainingArguments(**def_training_args)
  
  if self.freeze_layers is not None:
@@ -1313,7 +1286,6 @@
  predict=False,
  output_directory=None,
  output_prefix=None,
- predict_metadata=None,
  ):
  """
  Evaluate the fine-tuned model.
@@ -1339,11 +1311,9 @@
  
  ##### Evaluate the model #####
  labels = id_class_dict.keys()
-
- y_pred, y_true, logits_list, predict_metadata_all = eu.classifier_predict(
- model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict, predict_metadata
+ y_pred, y_true, logits_list = eu.classifier_predict(
+ model, self.classifier, eval_data, self.forward_batch_size
  )
-
  conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
  y_pred, y_true, logits_list, num_classes, labels
  )
@@ -1353,9 +1323,6 @@
  "label_ids": y_true,
  "predictions": logits_list,
  }
- if predict_metadata is not None:
- pred_dict["prediction_metadata"] = predict_metadata_all
-
  pred_dict_output_path = (
  Path(output_directory) / f"{output_prefix}_pred_dict"
  ).with_suffix(".pkl")
@@ -1376,7 +1343,6 @@
  output_directory,
  output_prefix,
  predict=True,
- predict_metadata=None,
  ):
  """
  Evaluate the fine-tuned model.
@@ -1396,8 +1362,6 @@
  | Prefix for output files
  predict : bool
  | Whether or not to save eval predictions
- predict_metadata : None | list
- | Metadata labels to output with predictions (columns in test_data_file)
  """
  
  # load numerical id to class dictionary (id:class)
@@ -1410,15 +1374,6 @@
  # load previously filtered and prepared data
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
  
- if predict_metadata is not None:
- absent_metadata = []
- for predict_metadata_x in predict_metadata:
- if predict_metadata_x not in test_data.features.keys():
- absent_metadata += [predict_metadata_x]
- if len(absent_metadata)>0:
- logger.error(f"Following predict_metadata was not found as column in test_data_file: {absent_metadata}")
- raise
-
  # load previously fine-tuned model
  model = pu.load_model(
  self.model_type,
@@ -1437,7 +1392,6 @@
  predict=predict,
  output_directory=output_directory,
  output_prefix=output_prefix,
- predict_metadata=predict_metadata,
  )
  
  all_conf_mat_df = pd.DataFrame(
  )
 
1323
  "label_ids": y_true,
1324
  "predictions": logits_list,
1325
  }
 
 
 
1326
  pred_dict_output_path = (
1327
  Path(output_directory) / f"{output_prefix}_pred_dict"
1328
  ).with_suffix(".pkl")
 
1343
  output_directory,
1344
  output_prefix,
1345
  predict=True,
 
1346
  ):
1347
  """
1348
  Evaluate the fine-tuned model.
 
1362
  | Prefix for output files
1363
  predict : bool
1364
  | Whether or not to save eval predictions
 
 
1365
  """
1366
 
1367
  # load numerical id to class dictionary (id:class)
 
1374
  # load previously filtered and prepared data
1375
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1376
 
 
 
 
 
 
 
 
 
 
1377
  # load previously fine-tuned model
1378
  model = pu.load_model(
1379
  self.model_type,
 
1392
  predict=predict,
1393
  output_directory=output_directory,
1394
  output_prefix=output_prefix,
 
1395
  )
1396
 
1397
  all_conf_mat_df = pd.DataFrame(
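The eval-fold arithmetic restored in `Classifier`'s 5-fold cross-validation above can be sketched in isolation. This is a minimal stand-in, not the library code: `num_cells` and `eval_size` are hypothetical values (the real code takes them from the dataset and `self.eval_size`), and `int()` is applied to `num_eval` here so `range()` accepts it.

```python
import math

# Sketch of the 5-fold eval-index arithmetic: each fold i evaluates a
# contiguous fifth of the cells (values hypothetical).
num_cells = 100
eval_size = 0.2

fifth_cells = int(math.floor(num_cells * 0.2))          # cells per fold
num_eval = int(min(eval_size * num_cells, fifth_cells))  # eval cells per fold

folds = []
for i in range(5):
    start = i * fifth_cells  # fold i starts at a multiple of fifth_cells
    end = start + num_eval
    folds.append(list(range(start, end)))
```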
geneformer/classifier_utils.py CHANGED
@@ -94,7 +94,7 @@ def remove_rare(data, rare_threshold, label, nproc):
     return data
 
 
-def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None):
+def label_classes(classifier, data, gene_class_dict, nproc):
     if classifier == "cell":
         label_set = set(data["label"])
     elif classifier == "gene":
@@ -113,11 +113,8 @@ def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None):
         )
         raise
 
-    if id_class_dict is None:
-        class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
-        id_class_dict = {v: k for k, v in class_id_dict.items()}
-    else:
-        class_id_dict = {v: k for k, v in id_class_dict.items()}
+    class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
+    id_class_dict = {v: k for k, v in class_id_dict.items()}
 
     if classifier == "gene":
         inverse_gene_class_dict = {}
@@ -570,27 +567,6 @@ def compute_metrics(pred):
     return {"accuracy": acc, "macro_f1": macro_f1}
 
 
-def robust_compute_objective(metrics: dict):
-    # tries both prefixed ("eval_") and raw metric names to support different transformers versions
-    metric_name = "macro_f1"
-
-    # check for the prefixed version
-    prefixed_metric_name = f"eval_{metric_name}"
-    if prefixed_metric_name in metrics:
-        return metrics[prefixed_metric_name]
-
-    # fall back to the raw name
-    elif metric_name in metrics:
-        return metrics[metric_name]
-
-    # if neither is found, raise a clear error to help with debugging
-    raise KeyError(
-        f"Could not find '{prefixed_metric_name}' or '{metric_name}' in the reported metrics. "
-        f"Please check your `compute_metrics` function and `TrainingArguments`. "
-        f"Available metrics: {list(metrics.keys())}"
-    )
-
-
 def get_default_train_args(model, classifier, data, output_dir):
     num_layers = pu.quant_layers(model)
     freeze_layers = 0
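The label-to-id assignment that `label_classes` now applies unconditionally can be illustrated on its own. A minimal sketch with hypothetical label values; `sorted()` is used here only to make the example deterministic, while the real function iterates the set directly.

```python
# Sketch of the class_id_dict / id_class_dict construction in label_classes.
# Label values are hypothetical examples, not from the Geneformer datasets.
label_set = {"dcm", "hcm", "nf"}

class_id_dict = dict(zip(sorted(label_set), range(len(label_set))))
id_class_dict = {v: k for k, v in class_id_dict.items()}  # invert the mapping
```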
geneformer/collator_for_classification.py CHANGED
@@ -26,10 +26,9 @@ LARGE_INTEGER = int(
     1e20
 )  # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
 
-warnings.filterwarnings("ignore", message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach()", category=UserWarning, module="torch")
-
 # precollator functions
 
+
 class ExplicitEnum(Enum):
     """
     Enum with more explicit error message for missing values.
@@ -104,9 +103,6 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
     def pad_token_id(self):
         return self._pad_token_id
 
-    def save_pretrained(self, save_directory):
-        pass
-
     def _get_padding_truncation_strategies(
         self,
         padding=True,
@@ -645,8 +641,7 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
     def __call__(self, features):
         batch = self._prepare_batch(features)
 
-        # batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
-        batch = {k: torch.tensor(v.clone().detach(), dtype=torch.int64) if isinstance(v, torch.Tensor) else torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
+        batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
         return batch
 
 
geneformer/emb_extractor.py CHANGED
@@ -42,8 +42,6 @@ def get_embs(
     special_token=False,
     summary_stat=None,
     silent=False,
-    save_tdigest=False,
-    tdigest_path=None,
 ):
     model_input_size = pu.get_model_input_size(model)
     total_batch_length = len(filtered_input_data)
@@ -182,18 +180,12 @@ def get_embs(
     # calculate summary stat embs from approximated tdigests
     elif summary_stat is not None:
         if emb_mode == "cell":
-            if save_tdigest:
-                with open(f"{tdigest_path}","wb") as fp:
-                    pickle.dump(embs_tdigests, fp)
             if summary_stat == "mean":
                 summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
             elif summary_stat == "median":
                 summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
             embs_stack = torch.tensor(summary_emb_list)
         elif emb_mode == "gene":
-            if save_tdigest:
-                with open(f"{tdigest_path}","wb") as fp:
-                    pickle.dump(embs_tdigests_dict, fp)
             if summary_stat == "mean":
                 [
                     update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
@@ -260,7 +252,7 @@ def label_cell_embs(embs, downsampled_data, emb_labels):
     return embs_df
 
 
-def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mean_pool"):
+def label_gene_embs(embs, downsampled_data, token_gene_dict):
     gene_set = {
         element for sublist in downsampled_data["input_ids"] for element in sublist
     }
@@ -275,52 +267,25 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mea
     )
     for k in dict_i.keys():
         gene_emb_dict[k].append(dict_i[k])
-    if gene_emb_style != "all":
-        for k in gene_emb_dict.keys():
-            gene_emb_dict[k] = (
-                torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
-                .cpu()
-                .numpy()
-            )
-        embs_df = pd.DataFrame(gene_emb_dict).T
-    else:
-        embs_df = dict_lol_to_df(gene_emb_dict)
+    for k in gene_emb_dict.keys():
+        gene_emb_dict[k] = (
+            torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
+            .cpu()
+            .numpy()
+        )
+    embs_df = pd.DataFrame(gene_emb_dict).T
     embs_df.index = [token_gene_dict[token] for token in embs_df.index]
     return embs_df
 
-def dict_lol_to_df(data_dict):
-    # save dictionary with values being list of equal-length lists as dataframe
-    df_data = []
-    for key, list_of_lists in data_dict.items():
-        for i, sublist in enumerate(list_of_lists):
-            row_data = [key, i] + sublist.tolist()
-            df_data.append(row_data)
-
-    # determine column names based on the length of sublists
-    # assuming all sublists have the same length
-    num_columns_from_sublist = len(list(data_dict.values())[0][0])
-    column_names = ['Gene', 'Identifier'] + [f'{j}' for j in range(num_columns_from_sublist)]
-
-    # create the dataframe
-    df = pd.DataFrame(df_data, columns=column_names)
-
-    # set 'Gene' as the index
-    df = df.set_index('Gene')
-
-    return df
-
-def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
+
+def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
     only_embs_df = embs_df.iloc[:, :emb_dims]
     only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
     only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
         str
     )
     vars_dict = {"embs": only_embs_df.columns}
-
-    obs_dict = {"cell_id": list(only_embs_df.index)}
-    for label_i in labels_clean:
-        obs_dict[label_i] = list(embs_df[label_i])
-
+    obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
     adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
     sc.tl.pca(adata, svd_solver="arpack")
     sc.pp.neighbors(adata, random_state=seed)
@@ -331,26 +296,21 @@ def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory,
     if kwargs_dict is not None:
         default_kwargs_dict.update(kwargs_dict)
 
-    for label_i in labels_clean:
-        output_prefix_label = output_prefix + f"_umap_{label_i}"
-        output_file = (
-            Path(output_directory) / output_prefix_label
-        ).with_suffix(".pdf")
-
-        cats = set(embs_df[label_i])
-
-        with plt.rc_context():
-            ax = sc.pl.umap(adata, color=label_i, show=False, **default_kwargs_dict)
-            ax.legend(
-                markerscale=2,
-                frameon=False,
-                loc="center left",
-                bbox_to_anchor=(1, 0.5),
-                ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
-            )
-            plt.show()
-            plt.savefig(output_file, bbox_inches="tight")
-
+    cats = set(embs_df[label])
+
+    with plt.rc_context():
+        ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
+        ax.legend(
+            markerscale=2,
+            frameon=False,
+            loc="center left",
+            bbox_to_anchor=(1, 0.5),
+            ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
+        )
+        plt.show()
+        plt.savefig(output_file, bbox_inches="tight")
+
+
 def gen_heatmap_class_colors(labels, df):
     pal = sns.cubehelix_palette(
         len(Counter(labels).keys()),
@@ -435,14 +395,13 @@ class EmbExtractor:
         "num_classes": {int},
         "emb_mode": {"cls", "cell", "gene"},
         "cell_emb_style": {"mean_pool"},
-        "gene_emb_style": {"mean_pool", "all"},
+        "gene_emb_style": {"mean_pool"},
         "filter_data": {None, dict},
         "max_ncells": {None, int},
         "emb_layer": {-1, 0},
         "emb_label": {None, list},
         "labels_to_plot": {None, list},
         "forward_batch_size": {int},
-        "model_version": {"V1", "V2"},
         "token_dictionary_file": {None, str},
         "nproc": {int},
         "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
@@ -463,8 +422,6 @@ class EmbExtractor:
         forward_batch_size=100,
         nproc=4,
         summary_stat=None,
-        save_tdigest=False,
-        model_version="V2",
         token_dictionary_file=None,
     ):
         """
@@ -483,9 +440,9 @@ class EmbExtractor:
         cell_emb_style : {"mean_pool"}
             | Method for summarizing cell embeddings if not using CLS token.
             | Currently only option is mean pooling of gene embeddings for given cell.
-        gene_emb_style : {"mean_pool", "all}
+        gene_emb_style : "mean_pool"
             | Method for summarizing gene embeddings.
-            | Currently only option is returning all or mean pooling of contextual gene embeddings for given gene.
+            | Currently only option is mean pooling of contextual gene embeddings for given gene.
         filter_data : None, dict
             | Default is to extract embeddings from all input data.
             | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
@@ -515,12 +472,6 @@ class EmbExtractor:
         | If mean or median, outputs only approximated mean or median embedding of input data.
         | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
         | Non-exact is slower but more memory-efficient.
-        save_tdigest : bool
-            | Whether to save a dictionary of tdigests for each gene and embedding dimension
-            | Only applies when summary_stat is not None
-        model_version : str
-            | To auto-select settings for model version other than current default.
-            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
         token_dictionary_file : Path
             | Default is the Geneformer token dictionary
             | Path to pickle file containing token dictionary (Ensembl ID:token).
@@ -551,7 +502,6 @@ class EmbExtractor:
         self.emb_layer = emb_layer
         self.emb_label = emb_label
         self.labels_to_plot = labels_to_plot
-        self.model_version = model_version
         self.token_dictionary_file = token_dictionary_file
         self.forward_batch_size = forward_batch_size
         self.nproc = nproc
@@ -561,29 +511,13 @@ class EmbExtractor:
         else:
             self.summary_stat = summary_stat
             self.exact_summary_stat = None
-        self.save_tdigest = save_tdigest
 
         self.validate_options()
 
-        if (summary_stat is None) and (save_tdigest is True):
-            logger.warning(
-                "tdigests will not be saved since summary_stat is None."
-            )
-            save_tdigest = False
-
-        if self.model_version == "V1":
-            from . import TOKEN_DICTIONARY_FILE_30M
-            self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
-            if self.emb_mode == "cls":
-                self.emb_mode = "cell"
-                logger.warning(
-                    "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
-                )
-
         # load token dictionary (Ensembl IDs:token)
         if self.token_dictionary_file is None:
-            self.token_dictionary_file = TOKEN_DICTIONARY_FILE
-        with open(self.token_dictionary_file, "rb") as f:
+            token_dictionary_file = TOKEN_DICTIONARY_FILE
+        with open(token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)
 
         self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
@@ -677,10 +611,6 @@ class EmbExtractor:
             self.model_type, self.num_classes, model_directory, mode="eval"
         )
         layer_to_quant = pu.quant_layers(model) + self.emb_layer
-        if self.save_tdigest:
-            tdigest_path = (Path(output_directory) / f"{output_prefix}_tdigest").with_suffix(".pkl")
-        else:
-            tdigest_path = None
         embs = get_embs(
             model=model,
             filtered_input_data=downsampled_data,
@@ -690,8 +620,6 @@ class EmbExtractor:
             forward_batch_size=self.forward_batch_size,
             token_gene_dict=self.token_gene_dict,
             summary_stat=self.summary_stat,
-            save_tdigest=self.save_tdigest,
-            tdigest_path=tdigest_path,
         )
 
         if self.emb_mode == "cell":
@@ -701,7 +629,7 @@ class EmbExtractor:
             embs_df = pd.DataFrame(embs.cpu().numpy()).T
         elif self.emb_mode == "gene":
             if self.summary_stat is None:
-                embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict, self.gene_emb_style)
+                embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
             elif self.summary_stat is not None:
                 embs_df = pd.DataFrame(embs).T
                 embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
@@ -885,14 +813,14 @@ class EmbExtractor:
             raise
 
         if max_ncells_to_plot is not None:
-            if self.max_ncells is not None:
-                if max_ncells_to_plot > self.max_ncells:
-                    max_ncells_to_plot = self.max_ncells
-                    logger.warning(
-                        "max_ncells_to_plot must be <= max_ncells. "
-                        f"Changing max_ncells_to_plot to {self.max_ncells}."
-                    )
-            embs = embs.sample(max_ncells_to_plot, axis=0)
+            if max_ncells_to_plot > self.max_ncells:
+                max_ncells_to_plot = self.max_ncells
+                logger.warning(
+                    "max_ncells_to_plot must be <= max_ncells. "
+                    f"Changing max_ncells_to_plot to {self.max_ncells}."
+                )
+            elif max_ncells_to_plot < self.max_ncells:
+                embs = embs.sample(max_ncells_to_plot, axis=0)
 
         if self.emb_label is None:
             label_len = 0
@@ -913,9 +841,12 @@ class EmbExtractor:
                     f"Label {label} from labels_to_plot "
                     f"not present in provided embeddings dataframe."
                 )
-
-        labels_clean = [label for label in self.labels_to_plot if label in emb_labels]
-        plot_umap(embs, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict)
+                continue
+            output_prefix_label = output_prefix + f"_umap_{label}"
+            output_file = (
+                Path(output_directory) / output_prefix_label
+            ).with_suffix(".pdf")
+            plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
 
         if plot_style == "heatmap":
             for label in self.labels_to_plot:
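The per-label output naming that `plot_embs` uses when saving one UMAP PDF per entry in `labels_to_plot` can be sketched on its own. The directory, prefix, and labels below are hypothetical stand-ins for the caller's arguments.

```python
from pathlib import Path

# Sketch of the output-file naming for per-label UMAP plots:
# <output_directory>/<output_prefix>_umap_<label>.pdf
output_directory = "/tmp/embs"      # hypothetical
output_prefix = "test"              # hypothetical
labels_to_plot = ["cell_type", "disease"]  # hypothetical labels

output_files = [
    str((Path(output_directory) / (output_prefix + f"_umap_{label}")).with_suffix(".pdf"))
    for label in labels_to_plot
]
```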
geneformer/{ensembl_mapping_dict_gc104M.pkl → ensembl_mapping_dict_gc95M.pkl} RENAMED
File without changes
geneformer/evaluation_utils.py CHANGED
@@ -8,7 +8,6 @@ import numpy as np
 import pandas as pd
 import seaborn as sns
 import torch
-import datasets
 from datasets.utils.logging import disable_progress_bar, enable_progress_bar
 from sklearn import preprocessing
 from sklearn.metrics import (
@@ -21,15 +20,20 @@ from sklearn.metrics import (
 )
 from tqdm.auto import trange
 
+from . import TOKEN_DICTIONARY_FILE
 from .emb_extractor import make_colorbar
 
 logger = logging.getLogger(__name__)
 
 
-def preprocess_classifier_batch(cell_batch, max_len, label_name, gene_token_dict):
+def preprocess_classifier_batch(cell_batch, max_len, label_name):
     if max_len is None:
         max_len = max([len(i) for i in cell_batch["input_ids"]])
 
+    # load token dictionary (Ensembl IDs:token)
+    with open(TOKEN_DICTIONARY_FILE, "rb") as f:
+        gene_token_dict = pickle.load(f)
+
     def pad_label_example(example):
         example[label_name] = np.pad(
             example[label_name],
@@ -77,7 +81,7 @@ def py_softmax(vector):
     return e / e.sum()
 
 
-def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict, predict_metadata=None):
+def classifier_predict(model, classifier_type, evalset, forward_batch_size):
     if classifier_type == "gene":
         label_name = "labels"
     elif classifier_type == "cell":
@@ -85,14 +89,6 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
 
     predict_logits = []
     predict_labels = []
-
-    predict_metadata_all = None
-
-    if predict_metadata is not None:
-        predict_metadata_all = dict()
-        for metadata_name in predict_metadata:
-            predict_metadata_all[metadata_name] = []
-
     model.eval()
 
     # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
@@ -107,21 +103,11 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
     for i in trange(0, evalset_len, forward_batch_size):
         max_range = min(i + forward_batch_size, evalset_len)
         batch_evalset = evalset.select([i for i in range(i, max_range)])
-
-        if predict_metadata is not None:
-            for metadata_name in predict_metadata:
-                predict_metadata_all[metadata_name] += batch_evalset[metadata_name]
-
         padded_batch = preprocess_classifier_batch(
-            batch_evalset, max_evalset_len, label_name, gene_token_dict
+            batch_evalset, max_evalset_len, label_name
         )
-
         padded_batch.set_format(type="torch")
 
-        # For datasets>=4.0.0, convert to dict to avoid format issues
-        if int(datasets.__version__.split(".")[0]) >= 4:
-            padded_batch = padded_batch[:]
-
         input_data_batch = padded_batch["input_ids"]
         attn_msk_batch = padded_batch["attention_mask"]
         label_batch = padded_batch[label_name]
@@ -148,8 +134,7 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
     y_pred = [vote(item[0]) for item in logit_label_paired]
     y_true = [item[1] for item in logit_label_paired]
     logits_list = [item[0] for item in logit_label_paired]
-
-    return y_pred, y_true, logits_list, predict_metadata_all
+    return y_pred, y_true, logits_list
 
 
 def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
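The batch-window arithmetic that `classifier_predict` uses to walk the eval set can be sketched in isolation: the set is traversed in windows of `forward_batch_size`, with the final window truncated to the set length. `evalset_len` and `forward_batch_size` below are hypothetical values.

```python
# Sketch of the batch selection in classifier_predict: each iteration selects
# indices [i, max_range) from the eval set.
evalset_len = 10        # hypothetical
forward_batch_size = 4  # hypothetical

batches = []
for i in range(0, evalset_len, forward_batch_size):
    max_range = min(i + forward_batch_size, evalset_len)  # clamp last window
    batches.append([j for j in range(i, max_range)])
```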
geneformer/{gene_median_dictionary_gc104M.pkl → gene_median_dictionary_gc95M.pkl} RENAMED
File without changes
geneformer/{gene_name_id_dict_gc104M.pkl → gene_name_id_dict_gc95M.pkl} RENAMED
File without changes
geneformer/in_silico_perturber.py CHANGED
@@ -72,7 +72,6 @@ class InSilicoPerturber:
72
  "max_ncells": {None, int},
73
  "cell_inds_to_perturb": {"all", dict},
74
  "emb_layer": {-1, 0},
75
- "model_version": {"V1", "V2"},
76
  "token_dictionary_file": {None, str},
77
  "forward_batch_size": {int},
78
  "nproc": {int},
@@ -97,7 +96,6 @@ class InSilicoPerturber:
97
  emb_layer=-1,
98
  forward_batch_size=100,
99
  nproc=4,
100
- model_version="V2",
101
  token_dictionary_file=None,
102
  clear_mem_ncells=1000,
103
    ):
@@ -186,9 +184,6 @@ class InSilicoPerturber:
            | Batch size for forward pass.
        nproc : int
            | Number of CPU processes to use.
-        model_version : str
-            | To auto-select settings for model version other than current default.
-            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
        token_dictionary_file : Path
            | Path to pickle file containing token dictionary (Ensembl ID:token).
        clear_mem_ncells : int
@@ -229,30 +224,15 @@ class InSilicoPerturber:
        self.emb_layer = emb_layer
        self.forward_batch_size = forward_batch_size
        self.nproc = nproc
-        self.model_version = model_version
        self.token_dictionary_file = token_dictionary_file
-        self.clear_mem_ncells = clear_mem_ncells
+        self.clear_mem_ncells = clear_mem_ncells

        self.validate_options()

-        if self.model_version == "V1":
-            from . import TOKEN_DICTIONARY_FILE_30M
-            self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
-            if self.emb_mode == "cls":
-                self.emb_mode = "cell"
-                logger.warning(
-                    "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
-                )
-            if self.emb_mode == "cls_and_gene":
-                self.emb_mode = "cell_and_gene"
-                logger.warning(
-                    "model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a <cls> token."
-                )
-
        # load token dictionary (Ensembl IDs:token)
        if self.token_dictionary_file is None:
-            self.token_dictionary_file = TOKEN_DICTIONARY_FILE
-        with open(self.token_dictionary_file, "rb") as f:
+            token_dictionary_file = TOKEN_DICTIONARY_FILE
+        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)
        self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
geneformer/in_silico_perturber_stats.py CHANGED
@@ -676,7 +676,6 @@ class InSilicoPerturberStats:
        "anchor_gene": {None, str},
        "cell_states_to_model": {None, dict},
        "pickle_suffix": {None, str},
-        "model_version": {"V1", "V2"},
    }

    def __init__(
@@ -687,7 +686,6 @@ class InSilicoPerturberStats:
        anchor_gene=None,
        cell_states_to_model=None,
        pickle_suffix="_raw.pickle",
-        model_version="V2",
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
        gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
    ):
@@ -715,7 +713,7 @@ class InSilicoPerturberStats:
            | analyzes data for anchor gene perturbed in combination with each other gene.
            | However, if combos=0 and anchor_gene="ENSG00000136574":
            | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
-        cell_states_to_model : None, dict
+        cell_states_to_model: None, dict
            | Cell states to model if testing perturbations that achieve goal state change.
            | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
            | state_key: key specifying name of column in .dataset that defines the start/goal states
@@ -726,12 +724,6 @@ class InSilicoPerturberStats:
            | "start_state": "dcm",
            | "goal_state": "nf",
            | "alt_states": ["hcm", "other1", "other2"]}
-        pickle_suffix : None, str
-            | Suffix to subselect intermediate raw files for analysis.
-            | Default output of InSilicoPerturber uses suffix "_raw.pickle".
-        model_version : str
-            | To auto-select settings for model version other than current default.
-            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
        token_dictionary_file : Path
            | Path to pickle file containing token dictionary (Ensembl ID:token).
        gene_name_id_dictionary_file : Path
@@ -744,15 +736,9 @@ class InSilicoPerturberStats:
        self.anchor_gene = anchor_gene
        self.cell_states_to_model = cell_states_to_model
        self.pickle_suffix = pickle_suffix
-        self.model_version = model_version

        self.validate_options()

-        if self.model_version == "V1":
-            from . import ENSEMBL_DICTIONARY_FILE_30M, TOKEN_DICTIONARY_FILE_30M
-            token_dictionary_file=TOKEN_DICTIONARY_FILE_30M
-            gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE_30M
-
        # load token dictionary (Ensembl IDs:token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)
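The docstring above pins down the four-key `cell_states_to_model` dictionary. A hedged sketch of how such a dictionary could be validated before use (the helper name `check_cell_states_to_model` and the exact checks are illustrative, not the library's actual `validate_options` logic):

```python
def check_cell_states_to_model(cell_states_to_model):
    """Validate the four-item dictionary described in the docstring."""
    if cell_states_to_model is None:
        return  # modelling a state transition is optional
    required_keys = {"state_key", "start_state", "goal_state", "alt_states"}
    missing = required_keys - cell_states_to_model.keys()
    if missing:
        raise ValueError(f"cell_states_to_model missing keys: {sorted(missing)}")
    if cell_states_to_model["start_state"] == cell_states_to_model["goal_state"]:
        raise ValueError("start_state and goal_state must differ")

# a dictionary shaped like the docstring's example passes validation
check_cell_states_to_model(
    {"state_key": "disease",
     "start_state": "dcm",
     "goal_state": "nf",
     "alt_states": ["hcm", "other1", "other2"]}
)

# an incomplete dictionary is rejected
try:
    check_cell_states_to_model({"state_key": "disease"})
    missing_detected = False
except ValueError:
    missing_detected = True
```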
geneformer/mtl/__init__.py CHANGED
@@ -1,4 +1 @@
-# ruff: noqa: F401
-
-from . import eval_utils
-from . import utils
+# ruff: noqa: F401
geneformer/mtl/collators.py CHANGED
@@ -1,8 +1,8 @@
 # imports
 import torch
 import pickle
-from geneformer.collator_for_classification import DataCollatorForGeneClassification
-from geneformer import TOKEN_DICTIONARY_FILE
+from ..collator_for_classification import DataCollatorForGeneClassification
+from .. import TOKEN_DICTIONARY_FILE

 """Geneformer collator for multi-task cell classification."""
geneformer/mtl/data.py CHANGED
@@ -1,190 +1,126 @@
 import os
-import pickle
-import torch
-from torch.utils.data import DataLoader, Dataset
-from datasets import load_from_disk
-
 from .collators import DataCollatorForMultitaskCellClassification
+from .imports import *
+
+
+def validate_columns(dataset, required_columns, dataset_type):
+    """Ensures required columns are present in the dataset."""
+    missing_columns = [col for col in required_columns if col not in dataset.column_names]
+    if missing_columns:
+        raise KeyError(
+            f"Missing columns in {dataset_type} dataset: {missing_columns}. "
+            f"Available columns: {dataset.column_names}"
+        )


-class StreamingMultiTaskDataset(Dataset):
-
-    def __init__(self, dataset_path, config, is_test=False, dataset_type=""):
-        """Initialize the streaming dataset."""
-        self.dataset = load_from_disk(dataset_path)
-        self.config = config
-        self.is_test = is_test
-        self.dataset_type = dataset_type
-        self.cell_id_mapping = {}
-
-        # Setup task and column mappings
-        self.task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
-        self.task_to_column = dict(zip(self.task_names, config["task_columns"]))
-        config["task_names"] = self.task_names
-
-        # Check if unique_cell_id column exists in the dataset
-        self.has_unique_cell_ids = "unique_cell_id" in self.dataset.column_names
-        print(f"{'Found' if self.has_unique_cell_ids else 'No'} unique_cell_id column in {dataset_type} dataset")
-
-        # Setup label mappings
-        self.label_mappings_path = os.path.join(
-            config["results_dir"],
-            f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
-        )
-
-        if not is_test:
-            self._validate_columns()
-            self.task_label_mappings, self.num_labels_list = self._create_label_mappings()
-            self._save_label_mappings()
-        else:
-            # Load existing mappings for test data
-            self.task_label_mappings = self._load_label_mappings()
-            self.num_labels_list = [len(mapping) for mapping in self.task_label_mappings.values()]
-
-    def _validate_columns(self):
-        """Ensures required columns are present in the dataset."""
-        missing_columns = [col for col in self.task_to_column.values()
-                           if col not in self.dataset.column_names]
-        if missing_columns:
-            raise KeyError(
-                f"Missing columns in {self.dataset_type} dataset: {missing_columns}. "
-                f"Available columns: {self.dataset.column_names}"
-            )
-
-    def _create_label_mappings(self):
-        """Creates label mappings for the dataset."""
-        task_label_mappings = {}
-        num_labels_list = []
-
-        for task, column in self.task_to_column.items():
-            unique_values = sorted(set(self.dataset[column]))
-            mapping = {label: idx for idx, label in enumerate(unique_values)}
-            task_label_mappings[task] = mapping
-            num_labels_list.append(len(unique_values))
-
-        return task_label_mappings, num_labels_list
-
-    def _save_label_mappings(self):
-        """Saves label mappings to a pickle file."""
-        with open(self.label_mappings_path, "wb") as f:
-            pickle.dump(self.task_label_mappings, f)
-
-    def _load_label_mappings(self):
-        """Loads label mappings from a pickle file."""
-        with open(self.label_mappings_path, "rb") as f:
-            return pickle.load(f)
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, idx):
-        record = self.dataset[idx]
-
-        # Store cell ID mapping
-        if self.has_unique_cell_ids:
-            unique_cell_id = record["unique_cell_id"]
-            self.cell_id_mapping[idx] = unique_cell_id
-        else:
-            self.cell_id_mapping[idx] = f"cell_{idx}"
-
-        # Create transformed record
+def create_label_mappings(dataset, task_to_column):
+    """Creates label mappings for the dataset."""
+    task_label_mappings = {}
+    num_labels_list = []
+    for task, column in task_to_column.items():
+        unique_values = sorted(set(dataset[column]))
+        mapping = {label: idx for idx, label in enumerate(unique_values)}
+        task_label_mappings[task] = mapping
+        num_labels_list.append(len(unique_values))
+    return task_label_mappings, num_labels_list
+
+
+def save_label_mappings(mappings, path):
+    """Saves label mappings to a pickle file."""
+    with open(path, "wb") as f:
+        pickle.dump(mappings, f)
+
+
+def load_label_mappings(path):
+    """Loads label mappings from a pickle file."""
+    with open(path, "rb") as f:
+        return pickle.load(f)
+
+
+def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test):
+    """Transforms the dataset to the required format."""
+    transformed_dataset = []
+    cell_id_mapping = {}
+
+    for idx, record in enumerate(dataset):
         transformed_record = {
             "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
-            "cell_id": idx,
+            "cell_id": idx,  # Index-based cell ID
         }
-
-        # Add labels
-        if not self.is_test:
+
+        if not is_test:
             label_dict = {
-                task: self.task_label_mappings[task][record[column]]
-                for task, column in self.task_to_column.items()
+                task: task_label_mappings[task][record[column]]
+                for task, column in task_to_column.items()
             }
         else:
-            label_dict = {task: -1 for task in self.config["task_names"]}
-
+            label_dict = {task: -1 for task in config["task_names"]}
+
         transformed_record["label"] = label_dict
-
-        return transformed_record
+        transformed_dataset.append(transformed_record)
+        cell_id_mapping[idx] = record.get("unique_cell_id", idx)

+    return transformed_dataset, cell_id_mapping


-def get_data_loader(dataset, batch_size, sampler=None, shuffle=True):
-    """Create a DataLoader with the given dataset and parameters."""
-    return DataLoader(
-        dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        shuffle=shuffle if sampler is None else False,
-        num_workers=0,
-        pin_memory=True,
-        collate_fn=DataCollatorForMultitaskCellClassification(),
-    )
-
-
-def prepare_data_loaders(config, include_test=False):
-    """Prepare data loaders for training, validation, and optionally test."""
-    result = {}
-
-    # Process train data
-    train_dataset = StreamingMultiTaskDataset(
-        config["train_path"],
-        config,
-        dataset_type="train"
-    )
-    result["train_loader"] = get_data_loader(train_dataset, config["batch_size"])
-
-    # Store the cell ID mapping from the dataset
-    result["train_cell_mapping"] = {k: v for k, v in train_dataset.cell_id_mapping.items()}
-    print(f"Collected {len(result['train_cell_mapping'])} cell IDs from training dataset")
-
-    result["num_labels_list"] = train_dataset.num_labels_list
-
-    # Process validation data
-    val_dataset = StreamingMultiTaskDataset(
-        config["val_path"],
-        config,
-        dataset_type="validation"
-    )
-    result["val_loader"] = get_data_loader(val_dataset, config["batch_size"])
-
-    # Store the complete cell ID mapping for validation
-    for idx in range(len(val_dataset)):
-        _ = val_dataset[idx]
-
-    result["val_cell_mapping"] = {k: v for k, v in val_dataset.cell_id_mapping.items()}
-    print(f"Collected {len(result['val_cell_mapping'])} cell IDs from validation dataset")
-
-    # Validate label mappings
-    validate_label_mappings(config)
-
-    # Process test data if requested
-    if include_test and "test_path" in config:
-        test_dataset = StreamingMultiTaskDataset(
-            config["test_path"],
-            config,
-            is_test=True,
-            dataset_type="test"
+def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
+    """Main function to load and preprocess data."""
+    try:
+        dataset = load_from_disk(dataset_path)
+
+        # Setup task and column mappings
+        task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
+        task_to_column = dict(zip(task_names, config["task_columns"]))
+        config["task_names"] = task_names
+
+        label_mappings_path = os.path.join(
+            config["results_dir"],
+            f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
         )
-        result["test_loader"] = get_data_loader(test_dataset, config["batch_size"])
-
-        for idx in range(len(test_dataset)):
-            _ = test_dataset[idx]
-
-        result["test_cell_mapping"] = {k: v for k, v in test_dataset.cell_id_mapping.items()}
-        print(f"Collected {len(result['test_cell_mapping'])} cell IDs from test dataset")
-
-    return result
+
+        if not is_test:
+            validate_columns(dataset, task_to_column.values(), dataset_type)
+
+            # Create and save label mappings
+            task_label_mappings, num_labels_list = create_label_mappings(dataset, task_to_column)
+            save_label_mappings(task_label_mappings, label_mappings_path)
+        else:
+            # Load existing mappings for test data
+            task_label_mappings = load_label_mappings(label_mappings_path)
+            num_labels_list = [len(mapping) for mapping in task_label_mappings.values()]
+
+        # Transform dataset
+        transformed_dataset, cell_id_mapping = transform_dataset(
+            dataset, task_to_column, task_label_mappings, config, is_test
+        )
+
+        return transformed_dataset, cell_id_mapping, num_labels_list
+
+    except KeyError as e:
+        raise ValueError(f"Configuration error or dataset key missing: {e}")
+    except Exception as e:
+        raise RuntimeError(f"Error during data loading or preprocessing: {e}")
+
+
+def preload_and_process_data(config):
+    """Preloads and preprocesses train and validation datasets."""
+    # Process train data and save mappings
+    train_data = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
+
+    # Process validation data and save mappings
+    val_data = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
+
+    # Validate that the mappings match
+    validate_label_mappings(config)
+
+    return (*train_data[:2], *val_data)  # Return train and val data along with mappings


 def validate_label_mappings(config):
     """Ensures train and validation label mappings are consistent."""
     train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
     val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
-
-    with open(train_mappings_path, "rb") as f:
-        train_mappings = pickle.load(f)
-
-    with open(val_mappings_path, "rb") as f:
-        val_mappings = pickle.load(f)
+    train_mappings = load_label_mappings(train_mappings_path)
+    val_mappings = load_label_mappings(val_mappings_path)

     for task_name in config["task_names"]:
         if train_mappings[task_name] != val_mappings[task_name]:
@@ -195,43 +131,32 @@ def validate_label_mappings(config):
             )


-# Legacy functions for backward compatibility
-def preload_and_process_data(config):
-    """Preloads and preprocesses train and validation datasets."""
-    data = prepare_data_loaders(config)
-
-    return (
-        data["train_loader"].dataset,
-        data["train_cell_mapping"],
-        data["val_loader"].dataset,
-        data["val_cell_mapping"],
-        data["num_labels_list"]
+def get_data_loader(preprocessed_dataset, batch_size):
+    """Creates a DataLoader with optimal settings."""
+    return DataLoader(
+        preprocessed_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        collate_fn=DataCollatorForMultitaskCellClassification(),
+        num_workers=os.cpu_count(),
+        pin_memory=True,
     )


 def preload_data(config):
     """Preprocesses train and validation data for trials."""
-    data = prepare_data_loaders(config)
-    return data["train_loader"], data["val_loader"]
+    train_loader = get_data_loader(*preload_and_process_data(config)[:2], config["batch_size"])
+    val_loader = get_data_loader(*preload_and_process_data(config)[2:4], config["batch_size"])
+    return train_loader, val_loader


 def load_and_preprocess_test_data(config):
     """Loads and preprocesses test data."""
-    test_dataset = StreamingMultiTaskDataset(
-        config["test_path"],
-        config,
-        is_test=True,
-        dataset_type="test"
-    )
-
-    return (
-        test_dataset,
-        test_dataset.cell_id_mapping,
-        test_dataset.num_labels_list
-    )
+    return load_and_preprocess_data(config["test_path"], config, is_test=True)


 def prepare_test_loader(config):
     """Prepares DataLoader for test data."""
-    data = prepare_data_loaders(config, include_test=True)
-    return data["test_loader"], data["test_cell_mapping"], data["num_labels_list"]
+    test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
+    test_loader = get_data_loader(test_dataset, config["batch_size"])
+    return test_loader, cell_id_mapping, num_labels_list
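The refactor extracts label-mapping construction into a standalone `create_label_mappings` function; its core logic runs without any of the surrounding machinery (the toy dataset below stands in for a `datasets.Dataset` that supports column access by name):

```python
def create_label_mappings(dataset, task_to_column):
    """Map each task's sorted unique column values to integer labels."""
    task_label_mappings = {}
    num_labels_list = []
    for task, column in task_to_column.items():
        unique_values = sorted(set(dataset[column]))
        mapping = {label: idx for idx, label in enumerate(unique_values)}
        task_label_mappings[task] = mapping
        num_labels_list.append(len(unique_values))
    return task_label_mappings, num_labels_list

# toy stand-in for a datasets.Dataset with column access by name
dataset = {"disease": ["dcm", "nf", "hcm", "nf"], "cell_type": ["cm", "fib", "cm", "cm"]}
mappings, num_labels = create_label_mappings(dataset, {"task1": "disease", "task2": "cell_type"})
print(mappings["task1"])  # {'dcm': 0, 'hcm': 1, 'nf': 2}
print(num_labels)         # [3, 2]
```

Sorting the unique values before enumeration is what makes the mapping deterministic across runs, which is why the train and validation mappings can later be compared for equality in `validate_label_mappings`.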
geneformer/mtl/eval_utils.py CHANGED
@@ -1,16 +1,19 @@
-import os
-import json
-import torch
 import pandas as pd

-from .data import prepare_test_loader
+from .imports import *  # noqa # isort:skip
+from .data import prepare_test_loader  # noqa # isort:skip
 from .model import GeneformerMultiTask

+
 def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
     task_pred_labels = {task_name: [] for task_name in config["task_names"]}
     task_pred_probs = {task_name: [] for task_name in config["task_names"]}
     cell_ids = []

+    # # Load task label mappings from pickle file
+    # with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
+    #     task_label_mappings = pickle.load(f)
+
     model.eval()
     with torch.no_grad():
         for batch in test_loader:
@@ -82,4 +85,4 @@ def load_and_evaluate_test_model(config):
     best_model.to(device)

     evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
-    print("Evaluation completed.")
+    print("Evaluation completed.")
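`evaluate_test_dataset` accumulates per-task predictions into parallel dicts keyed by task name. The accumulation pattern in isolation, with invented toy batch values in place of real model outputs:

```python
task_names = ["task1", "task2"]
task_pred_labels = {name: [] for name in task_names}
task_pred_probs = {name: [] for name in task_names}

# two invented "batches" of (predicted label, predicted probability) per task
batches = [{"task1": (0, 0.9), "task2": (1, 0.7)},
           {"task1": (1, 0.6), "task2": (1, 0.8)}]
for batch in batches:
    for name in task_names:
        label, prob = batch[name]
        task_pred_labels[name].append(label)
        task_pred_probs[name].append(prob)

print(task_pred_labels)  # {'task1': [0, 1], 'task2': [1, 1]}
```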
geneformer/mtl/imports.py ADDED
@@ -0,0 +1,43 @@
+import functools
+import gc
+import json
+import os
+import pickle
+import sys
+import warnings
+from enum import Enum
+from itertools import chain
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import optuna
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from datasets import load_from_disk
+from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import LabelEncoder
+from torch.utils.data import DataLoader
+from transformers import (
+    AdamW,
+    BatchEncoding,
+    BertConfig,
+    BertModel,
+    DataCollatorForTokenClassification,
+    SpecialTokensMixin,
+    get_cosine_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
+    get_scheduler,
+)
+from transformers.utils import logging, to_py_obj
+
+from .collators import DataCollatorForMultitaskCellClassification
+
+# local modules
+from .data import get_data_loader, preload_and_process_data
+from .model import GeneformerMultiTask
+from .optuna_utils import create_optuna_study
+from .utils import save_model
geneformer/mtl/model.py CHANGED
@@ -118,4 +118,4 @@ class GeneformerMultiTask(nn.Module):
                 f"Error during loss computation for task {task_id}: {e}"
             )

-        return total_loss, logits, losses if labels is not None else logits
+        return total_loss, logits, losses if labels is not None else logits
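One thing this diff leaves untouched is the return statement, which is a precedence trap worth noting: the conditional expression binds only to `losses`, so the call always returns a 3-tuple and never bare `logits`. A minimal demonstration, independent of the model code:

```python
def tuple_return(labels):
    total_loss, logits, losses = 1.0, "logits", [0.5]
    # parsed as: (total_loss, logits, (losses if labels is not None else logits))
    return total_loss, logits, losses if labels is not None else logits

print(tuple_return(labels=[1]))   # (1.0, 'logits', [0.5])
# with labels=None one might expect bare logits, but a 3-tuple comes back:
print(tuple_return(labels=None))  # (1.0, 'logits', 'logits')
```

Wrapping the whole tuple in parentheses, i.e. `return (total_loss, logits, losses) if labels is not None else logits`, would express the presumably intended either/or behavior.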
geneformer/mtl/optuna_utils.py ADDED
@@ -0,0 +1,27 @@
+import optuna
+from optuna.integration import TensorBoardCallback
+
+
+def save_trial_callback(study, trial, trials_result_path):
+    with open(trials_result_path, "a") as f:
+        f.write(
+            f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
+        )
+
+
+def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
+    study = optuna.create_study(direction="maximize")
+
+    # init TensorBoard callback
+    tensorboard_callback = TensorBoardCallback(
+        dirname=tensorboard_log_dir, metric_name="F1 Macro"
+    )
+
+    # callback and TensorBoard callback
+    callbacks = [
+        lambda study, trial: save_trial_callback(study, trial, trials_result_path),
+        tensorboard_callback,
+    ]
+
+    study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
+    return study
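The new `save_trial_callback` appends one line per trial; its output format can be exercised without Optuna by passing a minimal stand-in for the trial object (the `SimpleNamespace` trial below is invented for illustration):

```python
import tempfile
from types import SimpleNamespace

def save_trial_callback(study, trial, trials_result_path):
    # append one human-readable line per completed trial (same format as above)
    with open(trials_result_path, "a") as f:
        f.write(
            f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
        )

# stand-in for an optuna FrozenTrial: only .number, .value, .params are used
trial = SimpleNamespace(number=0, value=0.91, params={"lr": 0.0001})
with tempfile.NamedTemporaryFile("r+", suffix=".txt") as tmp:
    save_trial_callback(study=None, trial=trial, trials_result_path=tmp.name)
    tmp.seek(0)
    line = tmp.read().strip()
print(line)
```

Because the file is opened in append mode, repeated Optuna runs pointed at the same `trials_result_path` keep accumulating trial history rather than overwriting it.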
geneformer/mtl/train.py CHANGED
@@ -1,719 +1,380 @@
1
  import os
 
 
 
2
  import pandas as pd
3
  import torch
4
- import torch.distributed as dist
5
- import torch.multiprocessing as mp
6
- from torch.nn.parallel import DistributedDataParallel as DDP
7
  from torch.utils.tensorboard import SummaryWriter
8
  from tqdm import tqdm
9
- import optuna
10
- import functools
11
- import time
12
 
 
13
  from .model import GeneformerMultiTask
14
- from .utils import (
15
- calculate_metrics,
16
- get_layer_freeze_range,
17
- set_seed,
18
- initialize_wandb,
19
- create_model,
20
- setup_optimizer_and_scheduler,
21
- save_model,
22
- save_hyperparameters,
23
- prepare_training_environment,
24
- log_training_step,
25
- log_validation_metrics,
26
- save_validation_predictions,
27
- setup_logging,
28
- setup_distributed_environment,
29
- train_distributed
30
- )
31
-
32
-
33
- class Trainer:
34
- """Trainer class for multi-task learning"""
35
-
36
- def __init__(self, config):
37
- self.config = config
38
- self.device = None
39
- self.model = None
40
- self.optimizer = None
41
- self.scheduler = None
42
- self.writer = None
43
- self.is_distributed = config.get("distributed_training", False)
44
- self.local_rank = config.get("local_rank", 0)
45
- self.is_main_process = not self.is_distributed or self.local_rank == 0
46
-
47
- def train_epoch(self, train_loader, epoch):
48
- """Train the model for one epoch."""
49
- epoch_start = time.time()
50
- self.model.train()
51
-
52
- # For distributed training, we need to be aware of the global batch count
53
- if self.is_distributed:
54
- # Get world size for reporting
55
- world_size = dist.get_world_size()
56
- # Calculate total batches across all GPUs
57
- total_batches_global = len(train_loader) * world_size if self.local_rank == 0 else len(train_loader)
58
- else:
59
- world_size = 1
60
- total_batches_global = len(train_loader)
61
-
62
- progress_bar = None
63
- if self.is_main_process:
64
- # Use the global batch count for progress reporting in distributed mode
65
- progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}",
66
- total=len(train_loader))
67
- iterator = progress_bar
68
-
69
- # Report distributed training information
70
- if self.is_distributed:
71
- print(f"Distributed training: {world_size} GPUs, {len(train_loader)} batches per GPU, "
72
- f"{total_batches_global} total batches globally")
73
- else:
74
- iterator = train_loader
75
-
76
- batch_times = []
77
- forward_times = []
78
- backward_times = []
79
- optimizer_times = []
80
-
81
- # Get gradient accumulation steps from config (default to 1 if not specified)
82
- accumulation_steps = self.config.get("gradient_accumulation_steps", 1)
83
-
84
- # Zero gradients at the beginning
85
- self.optimizer.zero_grad()
86
-
87
- # Track loss for the entire epoch
88
- total_loss = 0.0
89
- num_batches = 0
90
- accumulated_loss = 0.0
91
-
92
- for batch_idx, batch in enumerate(iterator):
93
- batch_start = time.time()
94
-
95
- input_ids = batch["input_ids"].to(self.device)
96
- attention_mask = batch["attention_mask"].to(self.device)
97
- labels = [
98
- batch["labels"][task_name].to(self.device) for task_name in self.config["task_names"]
99
- ]
100
 
101
- forward_start = time.time()
102
- loss, _, _ = self.model(input_ids, attention_mask, labels)
103
-
104
- # Scale loss by accumulation steps for gradient accumulation
105
- if accumulation_steps > 1:
106
- loss = loss / accumulation_steps
107
-
108
- forward_end = time.time()
109
- forward_times.append(forward_end - forward_start)
110
-
111
- # Track loss - store the unscaled loss for reporting
112
- unscaled_loss = loss.item() * (1 if accumulation_steps == 1 else accumulation_steps)
113
- total_loss += unscaled_loss
114
- num_batches += 1
115
- accumulated_loss += loss.item() # For gradient accumulation tracking
116
-
117
- backward_start = time.time()
118
-
119
- # Use no_sync() for all but the last accumulation step to avoid unnecessary communication
120
- if self.is_distributed and accumulation_steps > 1:
121
- # If this is not the last accumulation step or the last batch
122
- if (batch_idx + 1) % accumulation_steps != 0 and (batch_idx + 1) != len(train_loader):
123
- with self.model.no_sync():
124
- loss.backward()
125
- else:
126
- loss.backward()
127
- else:
128
- # Non-distributed training or accumulation_steps=1
129
- loss.backward()
130
-
131
- backward_end = time.time()
132
- backward_times.append(backward_end - backward_start)
133
-
134
- # Only update weights after accumulation_steps or at the end of the epoch
135
- if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
136
- if self.config["gradient_clipping"]:
137
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
138
-
139
- optimizer_start = time.time()
140
- self.optimizer.step()
141
- self.scheduler.step()
142
- self.optimizer.zero_grad()
143
- optimizer_end = time.time()
144
- optimizer_times.append(optimizer_end - optimizer_start)
145
-
146
- # Log after optimizer step
147
- if self.is_main_process:
148
- # Calculate running average loss
149
- avg_loss = total_loss / num_batches
150
-
151
-                     log_training_step(avg_loss, self.writer, self.config, epoch, len(train_loader), batch_idx)
-
-                     # Update progress bar with just the running average loss
-                     progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})
-
-                 accumulated_loss = 0.0
-             else:
-                 optimizer_times.append(0)  # No optimizer step taken
-
-             batch_end = time.time()
-             batch_times.append(batch_end - batch_start)
-
-         epoch_end = time.time()
-
-         # Calculate average loss for the epoch
-         epoch_avg_loss = total_loss / num_batches
-
-         # If distributed, gather losses from all processes to compute the global average
-         if self.is_distributed:
-             # Create a tensor to hold the loss
-             loss_tensor = torch.tensor([epoch_avg_loss], device=self.device)
-             # Gather losses from all processes
-             all_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
-             dist.all_gather(all_losses, loss_tensor)
-             # Compute the global average loss across all processes
-             epoch_avg_loss = torch.mean(torch.stack(all_losses)).item()
-
-         if self.is_main_process:
-             # Double-check whether batch_size has already been adjusted for world_size in the config;
-             # this avoids double-counting the effective batch size
-             per_gpu_batch_size = self.config['batch_size']
-             total_effective_batch = per_gpu_batch_size * accumulation_steps * world_size
-
-             print(f"Epoch {epoch+1} timing:")
-             print(f" Total epoch time: {epoch_end - epoch_start:.2f}s")
-             print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
-             print(f" Average forward time: {sum(forward_times)/len(forward_times):.4f}s")
-             print(f" Average backward time: {sum(backward_times)/len(backward_times):.4f}s")
-             print(f" Average optimizer time: {sum([t for t in optimizer_times if t > 0])/max(1, len([t for t in optimizer_times if t > 0])):.4f}s")
-             print(f" Gradient accumulation steps: {accumulation_steps}")
-             print(f" Batch size per GPU: {per_gpu_batch_size}")
-             print(f" Effective global batch size: {total_effective_batch}")
-             print(f" Average training loss: {epoch_avg_loss:.4f}")
-             if self.is_distributed:
-                 print(f" Total batches processed across all GPUs: {total_batches_global}")
-                 print(f" Communication optimization: Using no_sync() for gradient accumulation")
-
-         return epoch_avg_loss  # Return the average loss for the epoch
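The "Effective global batch size" printed above is just per-GPU batch size × gradient-accumulation steps × world size; a standalone sketch of that arithmetic (the helper name is illustrative, not part of the codebase):

```python
def effective_batch_size(per_gpu_batch_size, accumulation_steps, world_size):
    """Number of samples that contribute to a single optimizer step across all GPUs."""
    return per_gpu_batch_size * accumulation_steps * world_size

# e.g. 8 samples per GPU, 4 accumulation steps, 2 GPUs -> 64 samples per update
print(effective_batch_size(8, 4, 2))
```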
-     def validate_model(self, val_loader):
-         val_start = time.time()
-         self.model.eval()
-         val_loss = 0.0
-         task_true_labels = {task_name: [] for task_name in self.config["task_names"]}
-         task_pred_labels = {task_name: [] for task_name in self.config["task_names"]}
-         task_pred_probs = {task_name: [] for task_name in self.config["task_names"]}
-
-         val_cell_ids = {}
-         sample_counter = 0
-
-         batch_times = []
-
-         # Print validation dataset size
-         if self.is_main_process:
-             print(f"Validation dataset size: {len(val_loader.dataset)} samples")
-             print(f"Number of validation batches: {len(val_loader)}")
-
-             if self.is_distributed:
-                 world_size = dist.get_world_size()
-                 print(f"Distributed validation: {world_size} GPUs")
-                 if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
-                     samples_per_gpu = val_loader.sampler.num_samples
-                     print(f"Each GPU processes {samples_per_gpu} validation samples")
-                     print(f"Total validation samples processed: {samples_per_gpu * world_size}")
-
-         with torch.no_grad():
-             for batch in val_loader:
-                 batch_start = time.time()
-                 input_ids = batch["input_ids"].to(self.device)
-                 attention_mask = batch["attention_mask"].to(self.device)
-                 labels = [
-                     batch["labels"][task_name].to(self.device)
-                     for task_name in self.config["task_names"]
-                 ]
-                 loss, logits, _ = self.model(input_ids, attention_mask, labels)
-                 val_loss += loss.item()
-
-                 if "cell_id" in batch:
-                     for i, cell_id in enumerate(batch["cell_id"]):
-                         # Store the actual index for later mapping to unique_cell_id
-                         val_cell_ids[sample_counter + i] = cell_id.item()
-
-                 for sample_idx in range(len(batch["input_ids"])):
-                     for i, task_name in enumerate(self.config["task_names"]):
-                         true_label = batch["labels"][task_name][sample_idx].item()
-                         pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
-                         # Store the full probability distribution
-                         pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy().tolist()
-                         task_true_labels[task_name].append(true_label)
-                         task_pred_labels[task_name].append(pred_label)
-                         task_pred_probs[task_name].append(pred_prob)
-
-                 # Update current index for cell ID tracking
-                 sample_counter += len(batch["input_ids"])
-
-                 batch_end = time.time()
-                 batch_times.append(batch_end - batch_start)
-
-         # Normalize the validation loss by the number of batches
-         val_loss /= len(val_loader)
-
-         # If distributed, gather results from all processes
-         if self.is_distributed:
-             # Create tensors to hold the local results
-             loss_tensor = torch.tensor([val_loss], device=self.device)
-             gathered_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())]
-             dist.all_gather(gathered_losses, loss_tensor)
-
-             # Compute the average loss across all processes
-             val_loss = torch.mean(torch.cat(gathered_losses)).item()
-
-             world_size = dist.get_world_size()
-
-             if self.is_main_process:
-                 print(f"Collected predictions from rank {self.local_rank}")
-                 print(f"Number of samples processed by this rank: {sample_counter}")
-
-         val_end = time.time()
-
-         if self.is_main_process:
-             print(f"Validation timing:")
-             print(f" Total validation time: {val_end - val_start:.2f}s")
-             print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s")
-             print(f" Collected {len(val_cell_ids)} cell indices from validation data")
-             print(f" Processed {sample_counter} total samples during validation")
-
-             # Print the number of samples per task
-             for task_name in self.config["task_names"]:
-                 print(f" Task {task_name}: {len(task_true_labels[task_name])} samples")
-
-         return val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids
-     def train(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
-         """Train the model and return validation loss and trained model."""
-         if self.config.get("use_wandb", False) and self.is_main_process:
-             initialize_wandb(self.config)
-
-         # Create the model
-         self.model = create_model(self.config, num_labels_list, self.device, self.is_distributed, self.local_rank)
-
-         # Set up the optimizer and scheduler
-         total_steps = len(train_loader) * self.config["epochs"]
-         self.optimizer, self.scheduler = setup_optimizer_and_scheduler(self.model, self.config, total_steps)
-
-         # Training loop
-         if self.is_main_process:
-             epoch_progress = tqdm(range(self.config["epochs"]), desc="Training Progress")
-         else:
-             epoch_progress = range(self.config["epochs"])
-
-         best_val_loss = float('inf')
-         train_losses = []
-
-         with setup_logging(self.config) as self.writer:
-             for epoch in epoch_progress:
-                 if self.is_distributed:
-                     train_loader.sampler.set_epoch(epoch)
-
-                 train_loss = self.train_epoch(train_loader, epoch)
-                 train_losses.append(train_loss)
-
-                 # Run validation after each epoch if configured to do so
-                 if self.config.get("validate_each_epoch", False):
-                     val_loss, _, _, _, _ = self.validate_model(val_loader)
-                     if val_loss < best_val_loss:
-                         best_val_loss = val_loss
-
-                     if self.is_main_process:
-                         epoch_progress.set_postfix({
-                             "train_loss": f"{train_loss:.4f}",
-                             "val_loss": f"{val_loss:.4f}",
-                             "best_val_loss": f"{best_val_loss:.4f}"
-                         })
-                 else:
-                     if self.is_main_process:
-                         epoch_progress.set_postfix({
-                             "train_loss": f"{train_loss:.4f}"
-                         })
-
-             val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = self.validate_model(val_loader)
-             task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
-
-             if self.is_main_process:
-                 log_validation_metrics(task_metrics, val_loss, self.config, self.writer, self.config["epochs"])
-
-                 # Save validation predictions
-                 save_validation_predictions(
-                     val_cell_ids,
-                     task_true_labels,
-                     task_pred_labels,
-                     task_pred_probs,
-                     {**self.config, "val_cell_mapping": val_cell_id_mapping}  # Include the mapping
-                 )
-
-                 if self.config.get("use_wandb", False):
-                     import wandb
-                     wandb.finish()
-
-                 print(f"\nTraining Summary:")
-                 print(f" Final Training Loss: {train_losses[-1]:.4f}")
-                 print(f" Final Validation Loss: {val_loss:.4f}")
-                 for task_name, metrics in task_metrics.items():
-                     print(f" {task_name} - F1: {metrics['f1']:.4f}, Accuracy: {metrics['accuracy']:.4f}")
-
-         return val_loss, self.model  # Return both the validation loss and the trained model
-     def setup(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
-         if self.is_distributed:
-             self.device = torch.device(f"cuda:{self.local_rank}")
-         else:
-             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-         self.model = create_model(self.config, num_labels_list, self.device)
-
-         # Wrap the model with DDP
-         if self.is_distributed:
-             self.model = DDP(self.model, device_ids=[self.local_rank])
-
-             # Register a communication hook to optimize gradient synchronization
-             from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
-
-             # Default hook, which maintains full precision
-             self.model.register_comm_hook(
-                 state=None,
-                 hook=comm_hooks.allreduce_hook
-             )
-
-             print(f"Rank {self.local_rank}: Registered communication hook for optimized gradient synchronization")
-
-         print(f"Rank {self.local_rank}: Using samplers created in distributed worker")
-         print(f"Rank {self.local_rank}: Training dataset has {len(train_loader.dataset)} samples")
-         if hasattr(train_loader, 'sampler') and hasattr(train_loader.sampler, 'num_samples'):
-             print(f"Rank {self.local_rank}: This GPU will process {train_loader.sampler.num_samples} training samples per epoch")
-
-         if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'):
-             print(f"Rank {self.local_rank}: This GPU will process {val_loader.sampler.num_samples} validation samples")
-
-         # Set up the optimizer and scheduler
-         self.optimizer, self.scheduler = setup_optimizer_and_scheduler(
-             self.model, self.config, len(train_loader)
          )

-         if self.is_main_process and self.config.get("use_wandb", False):
-             initialize_wandb(self.config)
-
-         return train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
-
-
- def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
-     """Train a model with the given configuration and data."""
-     # Check whether distributed training is enabled
-     if config.get("distributed_training", False):
-         # Check whether we have multiple GPUs
-         if torch.cuda.device_count() > 1:
-             result = train_distributed(
-                 Trainer,
-                 config,
-                 train_loader,
-                 val_loader,
-                 train_cell_id_mapping,
-                 val_cell_id_mapping,
-                 num_labels_list
-             )
-             if result is not None:
-                 return result
-         else:
-             print("Distributed training requested but only one GPU found. Falling back to single-GPU training.")
-             config["distributed_training"] = False
-
-     # Non-distributed training
-     trainer = Trainer(config)
-     trainer.device = device
-     return trainer.train(train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list)

- def objective(
-     trial,
-     train_loader,
-     val_loader,
-     train_cell_id_mapping,
-     val_cell_id_mapping,
-     num_labels_list,
-     config,
-     device,
- ):
-     """Objective function for Optuna hyperparameter optimization."""
-     set_seed(config["seed"])
-     initialize_wandb(config)
-
-     trial_config = config.copy()
-
-     # Suggest hyperparameters for this trial
-     for param_name, param_config in config["hyperparameters"].items():
-         if param_name == "lr_scheduler_type":
-             trial_config[param_name] = trial.suggest_categorical(
-                 param_name, param_config["choices"]
-             )
-         elif param_name == "task_weights" and config["use_task_weights"]:
-             weights = [
-                 trial.suggest_float(
-                     f"task_weight_{i}",
-                     param_config["low"],
-                     param_config["high"],
-                 )
-                 for i in range(len(num_labels_list))
-             ]
-             weight_sum = sum(weights)
-             trial_config[param_name] = [w / weight_sum for w in weights]
-         elif "log" in param_config and param_config["log"]:
-             trial_config[param_name] = trial.suggest_float(
-                 param_name, param_config["low"], param_config["high"], log=True
-             )
-         else:
-             trial_config[param_name] = trial.suggest_float(
-                 param_name, param_config["low"], param_config["high"]
-             )
-
-     # Set the appropriate max layers to freeze based on the pretrained model or a user-specified range
-     if "max_layers_to_freeze" in trial_config:
-         if trial_config["max_layers_to_freeze"] is None:
-             # Infer the range from the pretrained model
-             freeze_range = get_layer_freeze_range(trial_config["pretrained_path"])
-             trial_config["max_layers_to_freeze"] = int(
-                 trial.suggest_int(
-                     "max_layers_to_freeze",
-                     freeze_range["min"],
-                     freeze_range["max"],
-                 )
              )
-         else:
-             # User-specified range
-             min_freeze = trial_config["max_layers_to_freeze"]["min"]
-             max_freeze = trial_config["max_layers_to_freeze"]["max"]

-             trial_config["max_layers_to_freeze"] = int(
-                 trial.suggest_int("max_layers_to_freeze", min_freeze, max_freeze)
-             )

-     trial_config["run_name"] = f"trial_{trial.number}"
-
-     # Handle distributed training for this trial
-     if trial_config.get("distributed_training", False) and torch.cuda.device_count() > 1:
-         manager = mp.Manager()
-         shared_dict = manager.dict()
-
-         train_distributed(
-             Trainer,
-             trial_config,
-             train_loader,
-             val_loader,
-             train_cell_id_mapping,
-             val_cell_id_mapping,
-             num_labels_list,
-             trial.number,
-             shared_dict
          )
-
-         val_loss = shared_dict.get('val_loss', float('inf'))
-         task_metrics = shared_dict.get('task_metrics', {})
-
-         trial.set_user_attr("model_state_dict", shared_dict.get('model_state_dict', {}))
-         trial.set_user_attr("task_weights", trial_config["task_weights"])
-
          if config.get("use_wandb", False):
              import wandb
-             wandb.log({
-                 "trial_number": trial.number,
-                 "val_loss": val_loss,
-                 **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
-                 **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
-             })
-             wandb.finish()
-
-         return val_loss
-
-     with setup_logging(trial_config) as writer:
-         trainer = Trainer(trial_config)
-         trainer.device = device
-         trainer.writer = writer
-
-         # Create the model with trial hyperparameters
-         trainer.model = create_model(trial_config, num_labels_list, device)
-         total_steps = len(train_loader) * config["epochs"]
-         trainer.optimizer, trainer.scheduler = setup_optimizer_and_scheduler(trainer.model, trial_config, total_steps)
-
-         # Training loop
-         for epoch in range(config["epochs"]):
-             trainer.train_epoch(train_loader, epoch)
-
-         val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = trainer.validate_model(val_loader)
-         task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
-
-         # Log metrics
-         log_validation_metrics(task_metrics, val_loss, trial_config, writer, config["epochs"])
-
-         # Save validation predictions
-         save_validation_predictions(
-             val_cell_ids,
-             task_true_labels,
-             task_pred_labels,
-             task_pred_probs,
-             {**trial_config, "val_cell_mapping": val_cell_id_mapping},
-             trial.number,
-         )

-         # Store the model state dict and task weights in trial user attributes
-         trial.set_user_attr("model_state_dict", trainer.model.state_dict())
-         trial.set_user_attr("task_weights", trial_config["task_weights"])

-         # Report the intermediate value to Optuna
-         trial.report(val_loss, config["epochs"])
-         if trial.should_prune():
-             raise optuna.TrialPruned()

          if config.get("use_wandb", False):
              import wandb

              wandb.log(
                  {
-                     "trial_number": trial.number,
-                     "val_loss": val_loss,
-                     **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()},
-                     **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()},
-                     **{k: v for k, v in trial_config.items() if k in [
-                         "learning_rate", "warmup_ratio", "weight_decay", "dropout_rate",
-                         "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze"
-                     ]},
                  }
              )
-             wandb.finish()

-         return val_loss

- def run_manual_tuning(config):
-     """Run training with manually specified hyperparameters."""
-     (
-         device,
-         train_loader,
-         val_loader,
-         train_cell_id_mapping,
-         val_cell_id_mapping,
-         num_labels_list,
-     ) = prepare_training_environment(config)
-
-     print("\nManual hyperparameters being used:")
-     for key, value in config["manual_hyperparameters"].items():
-         print(f"{key}: {value}")
-     print()
-
-     # Update config with manual hyperparameters
-     for key, value in config["manual_hyperparameters"].items():
-         config[key] = value
-
-     # Train the model
-     val_loss, trained_model = train_model(
-         config,
-         device,
-         train_loader,
-         val_loader,
-         train_cell_id_mapping,
-         val_cell_id_mapping,
-         num_labels_list,
-     )
-
-     print(f"\nValidation loss with manual hyperparameters: {val_loss}")
-
-     # Save the trained model - only if not using distributed training
-     # (distributed training saves the model in the worker)
-     if not config.get("distributed_training", False):
-         model_save_directory = os.path.join(
-             config["model_save_path"], "GeneformerMultiTask"
          )
-         save_model(trained_model, model_save_directory)
-
-         # Save the hyperparameters
-         hyperparams_to_save = {
-             **config["manual_hyperparameters"],
-             "dropout_rate": config["dropout_rate"],
-             "use_task_weights": config["use_task_weights"],
-             "task_weights": config["task_weights"],
-             "max_layers_to_freeze": config["max_layers_to_freeze"],
-             "use_attention_pooling": config["use_attention_pooling"],
-         }
-         save_hyperparameters(model_save_directory, hyperparams_to_save)
-
-     return val_loss

- def run_optuna_study(config):
-     """Run hyperparameter optimization using Optuna."""
-     # Prepare the training environment
-     (
-         device,
-         train_loader,
-         val_loader,
-         train_cell_id_mapping,
-         val_cell_id_mapping,
-         num_labels_list,
-     ) = prepare_training_environment(config)
-
-     # If manual hyperparameters are specified, use them instead of running Optuna
-     if config.get("use_manual_hyperparameters", False):
-         return run_manual_tuning(config)
-
-     # Create a partial function with fixed arguments for the objective
-     objective_with_config_and_data = functools.partial(
-         objective,
-         train_loader=train_loader,
-         val_loader=val_loader,
-         train_cell_id_mapping=train_cell_id_mapping,
-         val_cell_id_mapping=val_cell_id_mapping,
-         num_labels_list=num_labels_list,
-         config=config,
-         device=device,
      )

-     # Create and run the Optuna study
-     study = optuna.create_study(
-         direction="minimize",  # Minimize validation loss
-         study_name=config["study_name"],
-         # storage=config["storage"],
-         load_if_exists=True,
-     )

-     study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])

-     # After finding the best trial
-     best_params = study.best_trial.params
-     best_task_weights = study.best_trial.user_attrs["task_weights"]
-     print("Saving the best model and its hyperparameters...")

-     # Create a model with the best hyperparameters
-     best_model = GeneformerMultiTask(
-         config["pretrained_path"],
-         num_labels_list,
-         dropout_rate=best_params["dropout_rate"],
-         use_task_weights=config["use_task_weights"],
-         task_weights=best_task_weights,
-         max_layers_to_freeze=best_params.get("max_layers_to_freeze", 0),
-         use_attention_pooling=best_params.get("use_attention_pooling", False),
      )

-     # Get the best model state dictionary
-     best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]

-     best_model_state_dict = {
-         k.replace("module.", ""): v for k, v in best_model_state_dict.items()
-     }

-     best_model.load_state_dict(best_model_state_dict, strict=False)

-     model_save_directory = os.path.join(
-         config["model_save_path"], "GeneformerMultiTask"
      )
-     save_model(best_model, model_save_directory)

-     save_hyperparameters(model_save_directory, {**best_params, "task_weights": best_task_weights})

-     return study.best_trial.value  # Return the best validation loss

  import os
+ import random
+
+ import numpy as np
  import pandas as pd
  import torch
  from torch.utils.tensorboard import SummaryWriter
  from tqdm import tqdm

+ from .imports import *
  from .model import GeneformerMultiTask
+ from .utils import calculate_task_specific_metrics, get_layer_freeze_range


+ def set_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
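`set_seed` above pins Python's, NumPy's, and PyTorch's random number generators so runs are repeatable. The core idea, illustrated with only the standard-library `random` module (`set_seed_demo` is a hypothetical name for this sketch):

```python
import random


def set_seed_demo(seed):
    # Re-seeding the generator resets it to a known state,
    # so the same seed always yields the same draws.
    random.seed(seed)
    return [random.randint(0, 100) for _ in range(3)]


# Two runs with the same seed produce identical sequences
assert set_seed_demo(42) == set_seed_demo(42)
```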
+
+ def initialize_wandb(config):
+     if config.get("use_wandb", False):
+         import wandb
+
+         wandb.init(project=config["wandb_project"], config=config)
+         print("Weights & Biases (wandb) initialized and will be used for logging.")
+     else:
+         print(
+             "Weights & Biases (wandb) is not enabled. Logging will use other methods."
          )


+ def create_model(config, num_labels_list, device):
+     model = GeneformerMultiTask(
+         config["pretrained_path"],
+         num_labels_list,
+         dropout_rate=config["dropout_rate"],
+         use_task_weights=config["use_task_weights"],
+         task_weights=config["task_weights"],
+         max_layers_to_freeze=config["max_layers_to_freeze"],
+         use_attention_pooling=config["use_attention_pooling"],
+     )
+     if config["use_data_parallel"]:
+         model = nn.DataParallel(model)
+     return model.to(device)

+ def setup_optimizer_and_scheduler(model, config, total_steps):
+     optimizer = AdamW(
+         model.parameters(),
+         lr=config["learning_rate"],
+         weight_decay=config["weight_decay"],
+     )
+     warmup_steps = int(config["warmup_ratio"] * total_steps)
+
+     if config["lr_scheduler_type"] == "linear":
+         scheduler = get_linear_schedule_with_warmup(
+             optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
+         )
+     elif config["lr_scheduler_type"] == "cosine":
+         scheduler = get_cosine_schedule_with_warmup(
+             optimizer,
+             num_warmup_steps=warmup_steps,
+             num_training_steps=total_steps,
+             num_cycles=0.5,
          )
+
+     return optimizer, scheduler
+
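The warmup length above is the warmup ratio applied to the total number of optimizer steps (`len(train_loader) * epochs`), truncated to an integer. In isolation (the helper name is illustrative):

```python
def warmup_steps(warmup_ratio, steps_per_epoch, epochs):
    """Number of scheduler steps spent ramping the learning rate up from zero."""
    total_steps = steps_per_epoch * epochs
    return int(warmup_ratio * total_steps)


# 10% warmup over 250 batches/epoch * 4 epochs = 100 warmup steps
print(warmup_steps(0.1, 250, 4))
```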
 

+ def train_epoch(
+     model, train_loader, optimizer, scheduler, device, config, writer, epoch
+ ):
+     model.train()
+     progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
+     for batch_idx, batch in enumerate(progress_bar):
+         optimizer.zero_grad()
+         input_ids = batch["input_ids"].to(device)
+         attention_mask = batch["attention_mask"].to(device)
+         labels = [
+             batch["labels"][task_name].to(device) for task_name in config["task_names"]
+         ]
+
+         loss, _, _ = model(input_ids, attention_mask, labels)
+         loss.backward()
+
+         if config["gradient_clipping"]:
+             torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
+
+         optimizer.step()
+         scheduler.step()
+
+         writer.add_scalar(
+             "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
          )
          if config.get("use_wandb", False):
              import wandb

+             wandb.log({"Training Loss": loss.item()})
+
+         # Update the progress bar
+         progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

+     return loss.item()  # Return the last batch loss


+ def validate_model(model, val_loader, device, config):
+     model.eval()
+     val_loss = 0.0
+     task_true_labels = {task_name: [] for task_name in config["task_names"]}
+     task_pred_labels = {task_name: [] for task_name in config["task_names"]}
+     task_pred_probs = {task_name: [] for task_name in config["task_names"]}
+
+     with torch.no_grad():
+         for batch in val_loader:
+             input_ids = batch["input_ids"].to(device)
+             attention_mask = batch["attention_mask"].to(device)
+             labels = [
+                 batch["labels"][task_name].to(device)
+                 for task_name in config["task_names"]
+             ]
+             loss, logits, _ = model(input_ids, attention_mask, labels)
+             val_loss += loss.item()
+
+             for sample_idx in range(len(batch["input_ids"])):
+                 for i, task_name in enumerate(config["task_names"]):
+                     true_label = batch["labels"][task_name][sample_idx].item()
+                     pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
+                     pred_prob = (
+                         torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
+                     )
+                     task_true_labels[task_name].append(true_label)
+                     task_pred_labels[task_name].append(pred_label)
+                     task_pred_probs[task_name].append(pred_prob)
+
+     val_loss /= len(val_loader)
+     return val_loss, task_true_labels, task_pred_labels, task_pred_probs
+
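Per sample and per task, `validate_model` takes the argmax of the logits as the predicted label and the softmax as the stored probability vector. The same computation on plain Python lists, as a torch-free stand-in (`softmax` and `predict` here are illustrative helpers, not part of the module):

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating
    exps = [math.exp(x - max(logits)) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits):
    probs = softmax(logits)
    return probs.index(max(probs)), probs


label, probs = predict([2.0, 1.0, 0.1])
# The largest logit (index 0) wins, and the probabilities sum to 1
```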
+ def log_metrics(task_metrics, val_loss, config, writer, epochs):
+     for task_name, metrics in task_metrics.items():
+         print(
+             f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
+         )
          if config.get("use_wandb", False):
              import wandb
+
              wandb.log(
                  {
+                     f"{task_name} Validation F1 Macro": metrics["f1"],
+                     f"{task_name} Validation Accuracy": metrics["accuracy"],
                  }
              )

+     writer.add_scalar("Validation Loss", val_loss, epochs)
+     for task_name, metrics in task_metrics.items():
+         writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
+         writer.add_scalar(
+             f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
+         )


+ def save_validation_predictions(
+     val_cell_id_mapping,
+     task_true_labels,
+     task_pred_labels,
+     task_pred_probs,
+     config,
+     trial_number=None,
+ ):
+     if trial_number is not None:
+         trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
+         os.makedirs(trial_results_dir, exist_ok=True)
+         val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
+     else:
+         val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
+
+     rows = []
+     for sample_idx in range(len(val_cell_id_mapping)):
+         row = {"Cell ID": val_cell_id_mapping[sample_idx]}
+         for task_name in config["task_names"]:
+             row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
+             row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
+             row[f"{task_name} Probabilities"] = ",".join(
+                 map(str, task_pred_probs[task_name][sample_idx])
+             )
+         rows.append(row)
+
+     df = pd.DataFrame(rows)
+     df.to_csv(val_preds_file, index=False)
+     print(f"Validation predictions saved to {val_preds_file}")
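Each CSV row pairs a cell ID with, per task, the true label, the predicted label, and the probability vector flattened into a comma-joined string. A minimal sketch of the row-building loop with hypothetical task names and data:

```python
# Hypothetical inputs standing in for one validation sample
task_names = ["disease", "cell_type"]
true = {"disease": [1], "cell_type": [0]}
pred = {"disease": [1], "cell_type": [2]}
probs = {"disease": [[0.1, 0.9]], "cell_type": [[0.5, 0.2, 0.3]]}

row = {"Cell ID": "cell_0"}
for task in task_names:
    row[f"{task} True"] = true[task][0]
    row[f"{task} Pred"] = pred[task][0]
    # Probabilities are serialized as a single comma-joined string column
    row[f"{task} Probabilities"] = ",".join(map(str, probs[task][0]))
```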
 

+ def train_model(
+     config,
+     device,
+     train_loader,
+     val_loader,
+     train_cell_id_mapping,
+     val_cell_id_mapping,
+     num_labels_list,
+ ):
+     set_seed(config["seed"])
+     initialize_wandb(config)
+
+     model = create_model(config, num_labels_list, device)
+     total_steps = len(train_loader) * config["epochs"]
+     optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
+
+     log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
+     writer = SummaryWriter(log_dir=log_dir)
+
+     epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
+     for epoch in epoch_progress:
+         last_loss = train_epoch(
+             model, train_loader, optimizer, scheduler, device, config, writer, epoch
          )
+         epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
+
+     val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
+         model, val_loader, device, config
+     )
+     task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
+
+     log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
+     writer.close()
+
+     save_validation_predictions(
+         val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
      )
+
+     if config.get("use_wandb", False):
+         import wandb

+         wandb.finish()

+     print(f"\nFinal Validation Loss: {val_loss:.4f}")
+     return val_loss, model  # Return both the validation loss and the trained model


+ def objective(
+     trial,
+     train_loader,
+     val_loader,
+     train_cell_id_mapping,
+     val_cell_id_mapping,
+     num_labels_list,
+     config,
+     device,
+ ):
+     set_seed(config["seed"])  # Set the seed before each trial
+     initialize_wandb(config)
+
+     # Hyperparameters
+     config["learning_rate"] = trial.suggest_float(
+         "learning_rate",
+         config["hyperparameters"]["learning_rate"]["low"],
+         config["hyperparameters"]["learning_rate"]["high"],
+         log=config["hyperparameters"]["learning_rate"]["log"],
+     )
+     config["warmup_ratio"] = trial.suggest_float(
+         "warmup_ratio",
+         config["hyperparameters"]["warmup_ratio"]["low"],
+         config["hyperparameters"]["warmup_ratio"]["high"],
+     )
+     config["weight_decay"] = trial.suggest_float(
+         "weight_decay",
+         config["hyperparameters"]["weight_decay"]["low"],
+         config["hyperparameters"]["weight_decay"]["high"],
+     )
+     config["dropout_rate"] = trial.suggest_float(
+         "dropout_rate",
+         config["hyperparameters"]["dropout_rate"]["low"],
+         config["hyperparameters"]["dropout_rate"]["high"],
+     )
+     config["lr_scheduler_type"] = trial.suggest_categorical(
+         "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
+     )
+     config["use_attention_pooling"] = trial.suggest_categorical(
+         "use_attention_pooling", [False]
      )

+     if config["use_task_weights"]:
+         config["task_weights"] = [
+             trial.suggest_float(
+                 f"task_weight_{i}",
+                 config["hyperparameters"]["task_weights"]["low"],
+                 config["hyperparameters"]["task_weights"]["high"],
+             )
+             for i in range(len(num_labels_list))
+         ]
+         weight_sum = sum(config["task_weights"])
+         config["task_weights"] = [
+             weight / weight_sum for weight in config["task_weights"]
+         ]
+     else:
+         config["task_weights"] = None
+
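The suggested task weights above are normalized to sum to 1 before being passed to the model, so their relative proportions are what the trial actually tunes. The normalization step in isolation (the helper name is illustrative):

```python
def normalize_weights(weights):
    """Rescale weights so they sum to 1, preserving their ratios."""
    total = sum(weights)
    return [w / total for w in weights]


print(normalize_weights([2.0, 1.0, 1.0]))  # [0.5, 0.25, 0.25]
```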
+     # Dynamic range for max_layers_to_freeze
+     freeze_range = get_layer_freeze_range(config["pretrained_path"])
+     config["max_layers_to_freeze"] = trial.suggest_int(
+         "max_layers_to_freeze",
+         freeze_range["min"],
+         freeze_range["max"],
+     )
+
+     model = create_model(config, num_labels_list, device)
+     total_steps = len(train_loader) * config["epochs"]
+     optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
+
+     log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
+     writer = SummaryWriter(log_dir=log_dir)
+
+     for epoch in range(config["epochs"]):
+         train_epoch(
+             model, train_loader, optimizer, scheduler, device, config, writer, epoch
+         )
+
+     val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
+         model, val_loader, device, config
      )
+     task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
+
+     log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
+     writer.close()
+
+     save_validation_predictions(
+         val_cell_id_mapping,
+         task_true_labels,
+         task_pred_labels,
+         task_pred_probs,
+         config,
+         trial.number,
+     )
+
+     trial.set_user_attr("model_state_dict", model.state_dict())
+     trial.set_user_attr("task_weights", config["task_weights"])
+
+     trial.report(val_loss, config["epochs"])
+
+     if trial.should_prune():
+         raise optuna.TrialPruned()
+
+     if config.get("use_wandb", False):
+         import wandb
+
+         wandb.log(
+             {
+                 "trial_number": trial.number,
+                 "val_loss": val_loss,
+                 **{
+                     f"{task_name}_f1": metrics["f1"]
+                     for task_name, metrics in task_metrics.items()
+                 },
+                 **{
+                     f"{task_name}_accuracy": metrics["accuracy"]
+                     for task_name, metrics in task_metrics.items()
+                 },
+                 **{
+                     k: v
+                     for k, v in config.items()
+                     if k
+                     in [
+                         "learning_rate",
+                         "warmup_ratio",
+                         "weight_decay",
+                         "dropout_rate",
+                         "lr_scheduler_type",
+                         "use_attention_pooling",
+                         "max_layers_to_freeze",
+                     ]
+                 },
+             }
+         )
+         wandb.finish()
379
+
380
+ return val_loss
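When `use_task_weights` is enabled, the `objective` above renormalizes the sampled per-task weights so they sum to 1 before passing them to the model. A minimal standalone sketch of that step, using hypothetical sampled values in place of the `trial.suggest_float` results:

```python
# Hypothetical per-task weights as Optuna might sample them (one per task).
sampled_weights = [0.8, 1.6, 1.6]

# Same normalization as in the objective: divide each weight by the total.
weight_sum = sum(sampled_weights)
task_weights = [w / weight_sum for w in sampled_weights]

print(task_weights)  # weights now sum to 1.0
```

Normalizing keeps the multi-task loss scale comparable across trials, so the validation losses Optuna compares are not dominated by whichever trial happened to sample larger raw weights.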
geneformer/mtl/train_utils.py ADDED
@@ -0,0 +1,161 @@
+ import random
+
+ from .data import get_data_loader, preload_and_process_data
+ from .imports import *
+ from .model import GeneformerMultiTask
+ from .train import objective, train_model
+ from .utils import save_model
+
+
+ def set_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def run_manual_tuning(config):
+     # Set seed for reproducibility
+     set_seed(config["seed"])
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     (
+         train_dataset,
+         train_cell_id_mapping,
+         val_dataset,
+         val_cell_id_mapping,
+         num_labels_list,
+     ) = preload_and_process_data(config)
+     train_loader = get_data_loader(train_dataset, config["batch_size"])
+     val_loader = get_data_loader(val_dataset, config["batch_size"])
+
+     # Print the manual hyperparameters being used
+     print("\nManual hyperparameters being used:")
+     for key, value in config["manual_hyperparameters"].items():
+         print(f"{key}: {value}")
+     print()  # Add an empty line for better readability
+
+     # Use the manual hyperparameters
+     for key, value in config["manual_hyperparameters"].items():
+         config[key] = value
+
+     # Train the model
+     val_loss, trained_model = train_model(
+         config,
+         device,
+         train_loader,
+         val_loader,
+         train_cell_id_mapping,
+         val_cell_id_mapping,
+         num_labels_list,
+     )
+
+     print(f"\nValidation loss with manual hyperparameters: {val_loss}")
+
+     # Save the trained model
+     model_save_directory = os.path.join(
+         config["model_save_path"], "GeneformerMultiTask"
+     )
+     save_model(trained_model, model_save_directory)
+
+     # Save the hyperparameters
+     hyperparams_to_save = {
+         **config["manual_hyperparameters"],
+         "dropout_rate": config["dropout_rate"],
+         "use_task_weights": config["use_task_weights"],
+         "task_weights": config["task_weights"],
+         "max_layers_to_freeze": config["max_layers_to_freeze"],
+         "use_attention_pooling": config["use_attention_pooling"],
+     }
+     hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
+     with open(hyperparams_path, "w") as f:
+         json.dump(hyperparams_to_save, f)
+     print(f"Manual hyperparameters saved to {hyperparams_path}")
+
+     return val_loss
+
+
+ def run_optuna_study(config):
+     # Set seed for reproducibility
+     set_seed(config["seed"])
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     (
+         train_dataset,
+         train_cell_id_mapping,
+         val_dataset,
+         val_cell_id_mapping,
+         num_labels_list,
+     ) = preload_and_process_data(config)
+     train_loader = get_data_loader(train_dataset, config["batch_size"])
+     val_loader = get_data_loader(val_dataset, config["batch_size"])
+
+     if config["use_manual_hyperparameters"]:
+         train_model(
+             config,
+             device,
+             train_loader,
+             val_loader,
+             train_cell_id_mapping,
+             val_cell_id_mapping,
+             num_labels_list,
+         )
+     else:
+         objective_with_config_and_data = functools.partial(
+             objective,
+             train_loader=train_loader,
+             val_loader=val_loader,
+             train_cell_id_mapping=train_cell_id_mapping,
+             val_cell_id_mapping=val_cell_id_mapping,
+             num_labels_list=num_labels_list,
+             config=config,
+             device=device,
+         )
+
+         study = optuna.create_study(
+             direction="minimize",  # Minimize validation loss
+             study_name=config["study_name"],
+             # storage=config["storage"],
+             load_if_exists=True,
+         )
+
+         study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])
+
+         # After finding the best trial
+         best_params = study.best_trial.params
+         best_task_weights = study.best_trial.user_attrs["task_weights"]
+         print("Saving the best model and its hyperparameters...")
+
+         # Saving model as before
+         best_model = GeneformerMultiTask(
+             config["pretrained_path"],
+             num_labels_list,
+             dropout_rate=best_params["dropout_rate"],
+             use_task_weights=config["use_task_weights"],
+             task_weights=best_task_weights,
+         )
+
+         # Get the best model state dictionary
+         best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
+
+         # Remove the "module." prefix from the state dictionary keys if present
+         best_model_state_dict = {
+             k.replace("module.", ""): v for k, v in best_model_state_dict.items()
+         }
+
+         # Load the modified state dictionary into the model, skipping unexpected keys
+         best_model.load_state_dict(best_model_state_dict, strict=False)
+
+         model_save_directory = os.path.join(
+             config["model_save_path"], "GeneformerMultiTask"
+         )
+         save_model(best_model, model_save_directory)
+
+         # Additionally, save the best hyperparameters and task weights
+         hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
+
+         with open(hyperparams_path, "w") as f:
+             json.dump({**best_params, "task_weights": best_task_weights}, f)
+         print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
geneformer/mtl/utils.py CHANGED
@@ -1,641 +1,129 @@
- from typing import Dict, List, Optional, Union
- import json
  import os
- import pickle
- import random
- import torch
- import numpy as np
- import optuna
  from sklearn.metrics import accuracy_score, f1_score
  from sklearn.preprocessing import LabelEncoder
- from torch.utils.tensorboard import SummaryWriter
- from transformers import AutoConfig, BertConfig, BertModel, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
- from torch.optim import AdamW
- import pandas as pd
- import torch.distributed as dist
- from torch.nn.parallel import DistributedDataParallel as DDP
- import torch.multiprocessing as mp
- from contextlib import contextmanager
-
-
- def set_seed(seed):
-     random.seed(seed)
-     np.random.seed(seed)
-     torch.manual_seed(seed)
-     torch.cuda.manual_seed_all(seed)
-     torch.backends.cudnn.deterministic = True
-     torch.backends.cudnn.benchmark = False
-
-
- def initialize_wandb(config):
-     if config.get("use_wandb", False):
-         import wandb
-         wandb.init(
-             project=config.get("wandb_project", "geneformer_multitask"),
-             name=config.get("run_name", "experiment"),
-             config=config,
-             reinit=True,
-         )

-
- def create_model(config, num_labels_list, device, is_distributed=False, local_rank=0):
-     """Create and initialize the model based on configuration."""
-     from .model import GeneformerMultiTask
-
-     model = GeneformerMultiTask(
-         config["pretrained_path"],
-         num_labels_list,
-         dropout_rate=config.get("dropout_rate", 0.1),
-         use_task_weights=config.get("use_task_weights", False),
-         task_weights=config.get("task_weights", None),
-         max_layers_to_freeze=config.get("max_layers_to_freeze", 0),
-         use_attention_pooling=config.get("use_attention_pooling", False),
-     )
-
-     # Move model to device
-     model.to(device)
-
-     if is_distributed:
-         model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
-
-     return model
-
-
- def setup_optimizer_and_scheduler(model, config, total_steps):
-     """Set up optimizer and learning rate scheduler."""
-     no_decay = ["bias", "LayerNorm.weight"]
-     optimizer_grouped_parameters = [
-         {
-             "params": [p for n, p in model.named_parameters()
-                        if not any(nd in n for nd in no_decay) and p.requires_grad],
-             "weight_decay": config["weight_decay"],
-         },
-         {
-             "params": [p for n, p in model.named_parameters()
-                        if any(nd in n for nd in no_decay) and p.requires_grad],
-             "weight_decay": 0.0,
-         },
-     ]
-
-     optimizer = AdamW(
-         optimizer_grouped_parameters,
-         lr=config["learning_rate"],
-         eps=config.get("adam_epsilon", 1e-8)
-     )
-
-     # Prepare scheduler
-     warmup_steps = int(total_steps * config["warmup_ratio"])
-
-     scheduler_map = {
-         "linear": get_linear_schedule_with_warmup,
-         "cosine": get_cosine_schedule_with_warmup
-     }
-
-     scheduler_fn = scheduler_map.get(config["lr_scheduler_type"])
-     if not scheduler_fn:
-         raise ValueError(f"Unsupported scheduler type: {config['lr_scheduler_type']}")
-
-     scheduler = scheduler_fn(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
-
-     return optimizer, scheduler


  def save_model(model, model_save_directory):
-     """Save model weights and configuration."""
-     os.makedirs(model_save_directory, exist_ok=True)
-
-     # Handle DDP model
-     if isinstance(model, DDP):
-         model_to_save = model.module
      else:
-         model_to_save = model
-
-     model_state_dict = model_to_save.state_dict()

      model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
      torch.save(model_state_dict, model_save_path)

      # Save the model configuration
-     model_to_save.config.to_json_file(os.path.join(model_save_directory, "config.json"))
-
-     print(f"Model and configuration saved to {model_save_directory}")
-
-
- def save_hyperparameters(model_save_directory, hyperparams):
-     """Save hyperparameters to a JSON file."""
-     hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
-     with open(hyperparams_path, "w") as f:
-         json.dump(hyperparams, f)
-     print(f"Hyperparameters saved to {hyperparams_path}")
-
-
- def calculate_metrics(labels=None, preds=None, task_data=None, metric_type="task_specific", return_format="dict"):
-     if metric_type == "single":
-         # Calculate metrics for a single task
-         if labels is None or preds is None:
-             raise ValueError("Labels and predictions must be provided for single task metrics")
-
-         task_name = None
-         if isinstance(labels, dict) and len(labels) == 1:
-             task_name = list(labels.keys())[0]
-             labels = labels[task_name]
-             preds = preds[task_name]
-
-         f1 = f1_score(labels, preds, average="macro")
-         accuracy = accuracy_score(labels, preds)
-
-         if return_format == "tuple":
-             return f1, accuracy
-
-         result = {"f1": f1, "accuracy": accuracy}
-         if task_name:
-             return {task_name: result}
-         return result
-
-     elif metric_type == "task_specific":
-         # Calculate metrics for multiple tasks
-         if task_data:
-             result = {}
-             for task_name, (task_labels, task_preds) in task_data.items():
-                 f1 = f1_score(task_labels, task_preds, average="macro")
-                 accuracy = accuracy_score(task_labels, task_preds)
-                 result[task_name] = {"f1": f1, "accuracy": accuracy}
-             return result
-         elif isinstance(labels, dict) and isinstance(preds, dict):
-             result = {}
-             for task_name in labels:
-                 if task_name in preds:
-                     f1 = f1_score(labels[task_name], preds[task_name], average="macro")
-                     accuracy = accuracy_score(labels[task_name], preds[task_name])
-                     result[task_name] = {"f1": f1, "accuracy": accuracy}
-             return result
-         else:
-             raise ValueError("For task_specific metrics, either task_data or labels and preds dictionaries must be provided")
-
-     elif metric_type == "combined":
-         # Calculate combined metrics across all tasks
-         if labels is None or preds is None:
-             raise ValueError("Labels and predictions must be provided for combined metrics")
-
-         # Handle label encoding for non-numeric labels
-         if not all(isinstance(x, (int, float)) for x in labels + preds):
-             le = LabelEncoder()
-             le.fit(labels + preds)
-             labels = le.transform(labels)
-             preds = le.transform(preds)
-
-         f1 = f1_score(labels, preds, average="macro")
-         accuracy = accuracy_score(labels, preds)
-
-         if return_format == "tuple":
-             return f1, accuracy
-         return {"f1": f1, "accuracy": accuracy}
-
-     else:
-         raise ValueError(f"Unknown metric_type: {metric_type}")
-
-
- def get_layer_freeze_range(pretrained_path):
-     if not pretrained_path:
-         return {"min": 0, "max": 0}
-
-     config = AutoConfig.from_pretrained(pretrained_path)
-     total_layers = config.num_hidden_layers
-     return {"min": 0, "max": total_layers - 1}
-
-
- def prepare_training_environment(config):
-     """
-     Prepare the training environment by setting seed and loading data.
-
-     Returns:
-         tuple: (device, train_loader, val_loader, train_cell_id_mapping,
-                 val_cell_id_mapping, num_labels_list)
-     """
-     from .data import prepare_data_loaders
-
-     # Set seed for reproducibility
-     set_seed(config["seed"])
-
-     # Set up device - for non-distributed training
-     if not config.get("distributed_training", False):
-         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      else:
-         # For distributed training, device will be set per process
-         device = None
-
-     # Load data using the streaming dataset
-     data = prepare_data_loaders(config)
-
-     # For distributed training, we'll set up samplers later in the distributed worker
-     # Don't create DistributedSampler here as process group isn't initialized yet
-
-     return (
-         device,
-         data["train_loader"],
-         data["val_loader"],
-         data["train_cell_mapping"],
-         data["val_cell_mapping"],
-         data["num_labels_list"],
-     )


- # Optuna hyperparameter optimization utilities
- def save_trial_callback(study, trial, trials_result_path):
-     """
-     Callback to save Optuna trial results to a file.
-
-     Args:
-         study: Optuna study object
-         trial: Current trial object
-         trials_result_path: Path to save trial results
-     """
-     with open(trials_result_path, "a") as f:
-         f.write(
-             f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
-         )

- def create_optuna_study(objective, n_trials: int, trials_result_path: str, tensorboard_log_dir: str) -> optuna.Study:
-     """Create and run an Optuna study with TensorBoard logging."""
-     from optuna.integration import TensorBoardCallback
-
-     study = optuna.create_study(direction="maximize")
-     study.optimize(
-         objective,
-         n_trials=n_trials,
-         callbacks=[
-             lambda study, trial: save_trial_callback(study, trial, trials_result_path),
-             TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro")
-         ]
-     )
-     return study
-
-
- @contextmanager
- def setup_logging(config):
-     run_name = config.get("run_name", "manual_run")
-     log_dir = os.path.join(config["tensorboard_log_dir"], run_name)
-     writer = SummaryWriter(log_dir=log_dir)
-     try:
-         yield writer
-     finally:
-         writer.close()
-
-
- def log_training_step(loss, writer, config, epoch, steps_per_epoch, batch_idx):
-     """Log training step metrics to TensorBoard and optionally W&B."""
-     writer.add_scalar(
-         "Training Loss", loss, epoch * steps_per_epoch + batch_idx
-     )
-     if config.get("use_wandb", False):
-         import wandb
-         wandb.log({"Training Loss": loss})
-
-
- def log_validation_metrics(task_metrics, val_loss, config, writer, epoch):
-     """Log validation metrics to console, TensorBoard, and optionally W&B."""
-     for task_name, metrics in task_metrics.items():
-         print(
-             f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
-         )
-         if config.get("use_wandb", False):
-             import wandb
-             wandb.log(
-                 {
-                     f"{task_name} Validation F1 Macro": metrics["f1"],
-                     f"{task_name} Validation Accuracy": metrics["accuracy"],
-                 }
-             )

-     writer.add_scalar("Validation Loss", val_loss, epoch)
-     for task_name, metrics in task_metrics.items():
-         writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epoch)
-         writer.add_scalar(
-             f"{task_name} - Validation Accuracy", metrics["accuracy"], epoch
-         )
-
-
- def load_label_mappings(results_dir: str, task_names: List[str]) -> Dict[str, Dict]:
-     """Load or initialize task label mappings."""
-     label_mappings_path = os.path.join(results_dir, "task_label_mappings_val.pkl")
-     if os.path.exists(label_mappings_path):
-         with open(label_mappings_path, 'rb') as f:
-             return pickle.load(f)
-     return {task_name: {} for task_name in task_names}
-
-
- def create_prediction_row(sample_idx: int, val_cell_indices: Dict, task_true_labels: Dict,
-                           task_pred_labels: Dict, task_pred_probs: Dict, task_names: List[str],
-                           inverted_mappings: Dict, val_cell_mapping: Dict) -> Dict:
-     """Create a row for validation predictions."""
-     batch_cell_idx = val_cell_indices.get(sample_idx)
-     cell_id = val_cell_mapping.get(batch_cell_idx, f"unknown_cell_{sample_idx}") if batch_cell_idx is not None else f"unknown_cell_{sample_idx}"
-
-     row = {"Cell ID": cell_id}
-     for task_name in task_names:
-         if task_name in task_true_labels and sample_idx < len(task_true_labels[task_name]):
-             true_idx = task_true_labels[task_name][sample_idx]
-             pred_idx = task_pred_labels[task_name][sample_idx]
-             true_label = inverted_mappings.get(task_name, {}).get(true_idx, f"Unknown-{true_idx}")
-             pred_label = inverted_mappings.get(task_name, {}).get(pred_idx, f"Unknown-{pred_idx}")
-
-             row.update({
-                 f"{task_name}_true_idx": true_idx,
-                 f"{task_name}_pred_idx": pred_idx,
-                 f"{task_name}_true_label": true_label,
-                 f"{task_name}_pred_label": pred_label
-             })
-
-             if task_name in task_pred_probs and sample_idx < len(task_pred_probs[task_name]):
-                 probs = task_pred_probs[task_name][sample_idx]
-                 if isinstance(probs, (list, np.ndarray)) or (hasattr(probs, '__iter__') and not isinstance(probs, str)):
-                     prob_list = list(probs) if not isinstance(probs, list) else probs
-                     row[f"{task_name}_all_probs"] = ",".join(map(str, prob_list))
-                     for class_idx, prob in enumerate(prob_list):
-                         class_label = inverted_mappings.get(task_name, {}).get(class_idx, f"Unknown-{class_idx}")
-                         row[f"{task_name}_prob_{class_label}"] = prob
-                 else:
-                     row[f"{task_name}_all_probs"] = str(probs)
-
-     return row
-
-
- def save_validation_predictions(
-     val_cell_indices,
-     task_true_labels,
-     task_pred_labels,
-     task_pred_probs,
-     config,
-     trial_number=None,
- ):
-     """Save validation predictions to a CSV file with class labels and probabilities."""
-     os.makedirs(config["results_dir"], exist_ok=True)
-
-     if trial_number is not None:
-         os.makedirs(os.path.join(config["results_dir"], f"trial_{trial_number}"), exist_ok=True)
-         val_preds_file = os.path.join(config["results_dir"], f"trial_{trial_number}/val_preds.csv")
-     else:
-         val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
-
-     if not val_cell_indices or not task_true_labels:
-         pd.DataFrame().to_csv(val_preds_file, index=False)
-         return
-
-     try:
-         label_mappings = load_label_mappings(config["results_dir"], config["task_names"])
-         inverted_mappings = {task: {idx: label for label, idx in mapping.items()} for task, mapping in label_mappings.items()}
-         val_cell_mapping = config.get("val_cell_mapping", {})
-
-         # Determine maximum number of samples
-         max_samples = max(
-             [len(val_cell_indices)] +
-             [len(task_true_labels[task]) for task in task_true_labels]
-         )
-
-         rows = [
-             create_prediction_row(
-                 sample_idx, val_cell_indices, task_true_labels, task_pred_labels,
-                 task_pred_probs, config["task_names"], inverted_mappings, val_cell_mapping
-             )
-             for sample_idx in range(max_samples)
-         ]
-
-         pd.DataFrame(rows).to_csv(val_preds_file, index=False)
-     except Exception as e:
-         pd.DataFrame([{"Error": str(e)}]).to_csv(val_preds_file, index=False)
-
-
- def setup_distributed_environment(rank, world_size, config):
-     """
-     Setup the distributed training environment.
-
-     Args:
-         rank (int): The rank of the current process
-         world_size (int): Total number of processes
-         config (dict): Configuration dictionary
-     """
-     os.environ['MASTER_ADDR'] = config.get('master_addr', 'localhost')
-     os.environ['MASTER_PORT'] = config.get('master_port', '12355')
-
-     # Initialize the process group
-     dist.init_process_group(
-         backend='nccl',
-         init_method='env://',
-         world_size=world_size,
-         rank=rank
-     )
-
-     # Set the device for this process
-     torch.cuda.set_device(rank)


- def train_distributed(trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None):
-     """Run distributed training across multiple GPUs with fallback to single GPU."""
-     world_size = torch.cuda.device_count()
-
-     if world_size <= 1:
-         print("Distributed training requested but only one GPU found. Falling back to single GPU training.")
-         config["distributed_training"] = False
-         trainer = trainer_class(config)
-         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         trainer.device = device
-         train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
-             train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
-         )
-         val_loss, model = trainer.train(
-             train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
-         )
-         model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
-         save_model(model, model_save_directory)
-         save_hyperparameters(model_save_directory, {
-             **get_config_value(config, "manual_hyperparameters", {}),
-             "dropout_rate": config["dropout_rate"],
-             "use_task_weights": config["use_task_weights"],
-             "task_weights": config["task_weights"],
-             "max_layers_to_freeze": config["max_layers_to_freeze"],
-             "use_attention_pooling": config["use_attention_pooling"],
-         })
-
-         if shared_dict is not None:
-             shared_dict['val_loss'] = val_loss
-             task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(model, val_loader, device, config)
-             shared_dict['task_metrics'] = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
-             shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model.state_dict().items()}
-
-         return val_loss, model
-
-     print(f"Using distributed training with {world_size} GPUs")
-     mp.spawn(
-         _distributed_worker,
-         args=(world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number, shared_dict),
-         nprocs=world_size,
-         join=True
-     )
-
-     if trial_number is None and shared_dict is None:
-         model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
-         model_path = os.path.join(model_save_directory, "pytorch_model.bin")
-         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         model = create_model(config, num_labels_list, device)
-         model.load_state_dict(torch.load(model_path))
-         return 0.0, model
-
-     return None
-
-
- def _distributed_worker(rank, world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None):
-     """Worker function for distributed training."""
-     setup_distributed_environment(rank, world_size, config)
-     config["local_rank"] = rank
-
-     # Set up distributed samplers
-     from torch.utils.data import DistributedSampler
-     from .data import get_data_loader
-
-     train_sampler = DistributedSampler(train_loader.dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
-     val_sampler = DistributedSampler(val_loader.dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
-
-     train_loader = get_data_loader(train_loader.dataset, config["batch_size"], sampler=train_sampler, shuffle=False)
-     val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=val_sampler, shuffle=False)
-
-     if rank == 0:
-         print(f"Rank {rank}: Training {len(train_sampler)} samples, Validation {len(val_sampler)} samples")
-         print(f"Total samples across {world_size} GPUs: Training {len(train_sampler) * world_size}, Validation {len(val_sampler) * world_size}")
-
-     # Create and setup trainer
-     trainer = trainer_class(config)
-     train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup(
-         train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
-     )
-
-     # Train the model
-     val_loss, model = trainer.train(
-         train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list
-     )
-
-     # Save model only from the main process
-     if rank == 0:
-         model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
-         save_model(model, model_save_directory)
-
-         save_hyperparameters(model_save_directory, {
-             **get_config_value(config, "manual_hyperparameters", {}),
-             "dropout_rate": config["dropout_rate"],
-             "use_task_weights": config["use_task_weights"],
-             "task_weights": config["task_weights"],
-             "max_layers_to_freeze": config["max_layers_to_freeze"],
-             "use_attention_pooling": config["use_attention_pooling"],
-         })
-
-         # For Optuna trials, store results in shared dictionary
-         if shared_dict is not None:
-             shared_dict['val_loss'] = val_loss
-
-             # Run validation on full dataset from rank 0 for consistent metrics
-             full_val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=None, shuffle=False)
-
-             # Get validation predictions using our utility function
-             task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(
-                 model, full_val_loader, trainer.device, config
-             )
-
-             # Calculate metrics
-             task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific")
-             shared_dict['task_metrics'] = task_metrics
-
-             # Store model state dict
-             if isinstance(model, DDP):
-                 model_state_dict = model.module.state_dict()
-             else:
-                 model_state_dict = model.state_dict()
-
-             shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model_state_dict.items()}
-
-     # Clean up distributed environment
-     dist.destroy_process_group()
-
-
- def save_model_without_heads(model_directory):
      """
-     Save a version of the fine-tuned model without classification heads.
-
      Args:
-         model_directory (str): Path to the directory containing the fine-tuned model
      """
-     import torch
-     from transformers import BertConfig, BertModel
-
-     # Load the full model
-     model_path = os.path.join(model_directory, "pytorch_model.bin")
-     config_path = os.path.join(model_directory, "config.json")
-
-     if not os.path.exists(model_path) or not os.path.exists(config_path):
-         raise FileNotFoundError(f"Model files not found in {model_directory}")
-
-     # Load the configuration
-     config = BertConfig.from_json_file(config_path)
-
-     # Load the model state dict
-     state_dict = torch.load(model_path, map_location=torch.device('cpu'))
-
-     # Create a new model without heads
-     base_model = BertModel(config)
-
-     # Filter out the classification head parameters
-     base_model_state_dict = {}
-     for key, value in state_dict.items():
-         # Only keep parameters that belong to the base model (not classification heads)
-         if not key.startswith('classification_heads') and not key.startswith('attention_pool'):
-             base_model_state_dict[key] = value
-
-     # Load the filtered state dict into the base model
-     base_model.load_state_dict(base_model_state_dict, strict=False)
-
-     # Save the model without heads
-     output_dir = os.path.join(model_directory, "model_without_heads")
-     os.makedirs(output_dir, exist_ok=True)
-
-     # Save the model weights
-     torch.save(base_model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
-
-     # Save the configuration
-     base_model.config.to_json_file(os.path.join(output_dir, "config.json"))
-
-     print(f"Model without classification heads saved to {output_dir}")
-     return output_dir
-
-
- def get_config_value(config: Dict, key: str, default=None):
-
-     return config.get(key, default)
-
-
- def collect_validation_predictions(model, val_loader, device, config) -> tuple:
-     task_true_labels = {}
-     task_pred_labels = {}
-     task_pred_probs = {}
-
-     with torch.no_grad():
-         for batch in val_loader:
-             input_ids = batch["input_ids"].to(device)
-             attention_mask = batch["attention_mask"].to(device)
-             labels = [batch["labels"][task_name].to(device) for task_name in config["task_names"]]
-             _, logits, _ = model(input_ids, attention_mask, labels)
-
-             for sample_idx in range(len(batch["input_ids"])):
-                 for i, task_name in enumerate(config["task_names"]):
-                     if task_name not in task_true_labels:
-                         task_true_labels[task_name] = []
-                         task_pred_labels[task_name] = []
-                         task_pred_probs[task_name] = []
-
-                     true_label = batch["labels"][task_name][sample_idx].item()
-                     pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
-                     pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
-
-                     task_true_labels[task_name].append(true_label)
-                     task_pred_labels[task_name].append(pred_label)
-                     task_pred_probs[task_name].append(pred_prob)
-
-     return task_true_labels, task_pred_labels, task_pred_probs
 import os
+import shutil
+
 from sklearn.metrics import accuracy_score, f1_score
 from sklearn.preprocessing import LabelEncoder
+from transformers import AutoConfig, BertConfig, BertModel

+from .imports import *


 def save_model(model, model_save_directory):
+    if not os.path.exists(model_save_directory):
+        os.makedirs(model_save_directory)
+
+    # Get the state dict
+    if isinstance(model, nn.DataParallel):
+        model_state_dict = (
+            model.module.state_dict()
+        )  # Use model.module to access the underlying model
     else:
+        model_state_dict = model.state_dict()
+
+    # Remove the "module." prefix from the keys if present
+    model_state_dict = {
+        k.replace("module.", ""): v for k, v in model_state_dict.items()
+    }

     model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
     torch.save(model_state_dict, model_save_path)

     # Save the model configuration
+    if isinstance(model, nn.DataParallel):
+        model.module.config.to_json_file(
+            os.path.join(model_save_directory, "config.json")
+        )
     else:
+        model.config.to_json_file(os.path.join(model_save_directory, "config.json"))

+    print(f"Model and configuration saved to {model_save_directory}")
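The `save_model` hunk above normalizes checkpoints saved from an `nn.DataParallel` wrapper, whose parameter names all carry a `module.` prefix that a bare model cannot load. A minimal sketch of that key-renaming step, with plain dicts standing in for real state dicts:

```python
def strip_module_prefix(state_dict):
    """Remove the DataParallel "module." prefix from parameter names."""
    return {k.replace("module.", ""): v for k, v in state_dict.items()}

# A DataParallel checkpoint prefixes every key with "module."
wrapped = {
    "module.bert.embeddings.weight": [0.1, 0.2],
    "module.classifier.bias": [0.0],
}
unwrapped = strip_module_prefix(wrapped)
print(sorted(unwrapped))  # ['bert.embeddings.weight', 'classifier.bias']
```

Keys already lacking the prefix pass through unchanged, so the same function is safe for checkpoints saved from an unwrapped model.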


+def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
+    task_metrics = {}
+    for task_name in task_true_labels.keys():
+        true_labels = task_true_labels[task_name]
+        pred_labels = task_pred_labels[task_name]
+        f1 = f1_score(true_labels, pred_labels, average="macro")
+        accuracy = accuracy_score(true_labels, pred_labels)
+        task_metrics[task_name] = {"f1": f1, "accuracy": accuracy}
+    return task_metrics


+def calculate_combined_f1(combined_labels, combined_preds):
+    # Initialize the LabelEncoder
+    le = LabelEncoder()

+    # Fit and transform combined labels and predictions to numerical values
+    le.fit(combined_labels + combined_preds)
+    encoded_true_labels = le.transform(combined_labels)
+    encoded_pred_labels = le.transform(combined_preds)

+    # Print out the mapping for sanity check
+    print("\nLabel Encoder Mapping:")
+    for index, class_label in enumerate(le.classes_):
+        print(f"'{class_label}': {index}")

+    # Calculate accuracy
+    accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)

+    # Calculate F1 Macro score
+    f1 = f1_score(encoded_true_labels, encoded_pred_labels, average="macro")

+    return f1, accuracy
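`calculate_combined_f1` fits the `LabelEncoder` on the concatenation of true labels and predictions, so a class the model predicts but that never occurs in the true labels still receives an integer code instead of crashing `transform`. A pure-Python sketch of that mapping (`LabelEncoder` assigns codes in sorted class order; the toy labels are illustrative):

```python
def fit_label_mapping(labels, preds):
    # Mirror sklearn's LabelEncoder: codes follow sorted class order,
    # fit on labels AND predictions so no class is missing a code.
    classes = sorted(set(labels) | set(preds))
    return {c: i for i, c in enumerate(classes)}

labels = ["hcm", "nf", "nf"]
preds = ["hcm", "nf", "dcm"]  # "dcm" never appears in the true labels
mapping = fit_label_mapping(labels, preds)
print(mapping)  # {'dcm': 0, 'hcm': 1, 'nf': 2}

encoded_true = [mapping[c] for c in labels]  # no KeyError for "dcm"
encoded_pred = [mapping[c] for c in preds]
```

Fitting on the true labels alone would raise on the unseen predicted class, which is exactly the failure mode the combined fit avoids.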
 
 
 
 
 
 
 
 
 
 
 
 
 


+# def save_model_without_heads(original_model_save_directory):
+#     # Create a new directory for the model without heads
+#     new_model_save_directory = original_model_save_directory + "_No_Heads"
+#     if not os.path.exists(new_model_save_directory):
+#         os.makedirs(new_model_save_directory)

+#     # Load the model state dictionary
+#     model_state_dict = torch.load(
+#         os.path.join(original_model_save_directory, "pytorch_model.bin")
+#     )

+#     # Initialize a new BERT model without the classification heads
+#     config = BertConfig.from_pretrained(
+#         os.path.join(original_model_save_directory, "config.json")
+#     )
+#     model_without_heads = BertModel(config)

+#     # Filter the state dict to exclude classification heads
+#     model_without_heads_state_dict = {
+#         k: v
+#         for k, v in model_state_dict.items()
+#         if not k.startswith("classification_heads")
+#     }

+#     # Load the filtered state dict into the model
+#     model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)

+#     # Save the model without heads
+#     model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
+#     torch.save(model_without_heads.state_dict(), model_save_path)

+#     # Copy the configuration file
+#     shutil.copy(
+#         os.path.join(original_model_save_directory, "config.json"),
+#         new_model_save_directory,
+#     )

+#     print(f"Model without classification heads saved to {new_model_save_directory}")


+def get_layer_freeze_range(pretrained_path):
     """
+    Dynamically determines the number of layers to freeze based on the model depth from its configuration.
     Args:
+        pretrained_path (str): Path to the pretrained model directory or model identifier.
+    Returns:
+        dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze.
     """
+    if pretrained_path:
+        config = AutoConfig.from_pretrained(pretrained_path)
+        total_layers = config.num_hidden_layers
+        return {"min": 0, "max": total_layers - 1}
+    else:
+        return {"min": 0, "max": 0}
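A sketch of how a `{"min", "max"}` range like the one `get_layer_freeze_range` returns could be turned into concrete layer indices. The `layers_to_freeze` helper and the `encoder.layer.N` name pattern (BERT's convention) are illustrative assumptions, not part of the package's training code:

```python
def layers_to_freeze(freeze_range, n_frozen):
    """Indices of the lowest n_frozen layers, capped to the valid range."""
    upper = min(freeze_range["min"] + n_frozen, freeze_range["max"] + 1)
    return list(range(freeze_range["min"], upper))

# For a 6-layer model, get_layer_freeze_range yields {"min": 0, "max": 5}
freeze_range = {"min": 0, "max": 5}
frozen = layers_to_freeze(freeze_range, 4)
print(frozen)  # [0, 1, 2, 3] -- the top two layers stay trainable

# In a real model one would then match parameter names, e.g.:
# for name, p in model.named_parameters():
#     if any(f"encoder.layer.{i}." in name for i in frozen):
#         p.requires_grad = False
```

Capping at `max + 1` keeps a requested freeze count from exceeding the model's actual depth.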
geneformer/mtl_classifier.py CHANGED
@@ -29,7 +29,7 @@ Geneformer multi-task cell classifier.
 import logging
 import os

-from .mtl import eval_utils, utils, train
+from .mtl import eval_utils, train_utils, utils

 logger = logging.getLogger(__name__)

@@ -49,9 +49,7 @@ class MTLClassifier:
         "max_layers_to_freeze": {None, dict},
         "epochs": {None, int},
         "tensorboard_log_dir": {None, str},
-        "distributed_training": {None, bool},
-        "master_addr": {None, str},
-        "master_port": {None, str},
+        "use_data_parallel": {None, bool},
         "use_attention_pooling": {None, bool},
         "use_task_weights": {None, bool},
         "hyperparameters": {None, dict},
@@ -63,7 +61,6 @@ class MTLClassifier:
         "max_grad_norm": {None, int, float},
         "seed": {None, int},
         "trials_result_path": {None, str},
-        "gradient_accumulation_steps": {None, int},
     }

     def __init__(
@@ -82,9 +79,7 @@ class MTLClassifier:
         max_layers_to_freeze=None,
         epochs=1,
         tensorboard_log_dir="/results/tblogdir",
-        distributed_training=False,
-        master_addr="localhost",
-        master_port="12355",
+        use_data_parallel=False,
         use_attention_pooling=True,
         use_task_weights=True,
         hyperparameters=None,  # Default is None
@@ -94,7 +89,6 @@ class MTLClassifier:
         wandb_project=None,
         gradient_clipping=False,
         max_grad_norm=None,
-        gradient_accumulation_steps=1,  # Add this line with default value 1
         seed=42,  # Default seed value
     ):
         """
@@ -123,12 +117,8 @@ class MTLClassifier:
            | Path to directory to save results
        tensorboard_log_dir : None, str
            | Path to directory for Tensorboard logging results
-        distributed_training : None, bool
-            | Whether to use distributed data parallel training across multiple GPUs
-        master_addr : None, str
-            | Master address for distributed training (default: localhost)
-        master_port : None, str
-            | Master port for distributed training (default: 12355)
+        use_data_parallel : None, bool
+            | Whether to use data parallelization
        use_attention_pooling : None, bool
            | Whether to use attention pooling
        use_task_weights : None, bool
@@ -160,8 +150,6 @@ class MTLClassifier:
            | Whether to use gradient clipping
        max_grad_norm : None, int, float
            | Maximum norm for gradient clipping
-        gradient_accumulation_steps : None, int
-            | Number of steps to accumulate gradients before performing a backward/update pass
        seed : None, int
            | Random seed
        """
@@ -177,7 +165,6 @@ class MTLClassifier:
         self.batch_size = batch_size
         self.n_trials = n_trials
         self.study_name = study_name
-        self.gradient_accumulation_steps = gradient_accumulation_steps

         if max_layers_to_freeze is None:
             # Dynamically determine the range of layers to freeze
@@ -188,9 +175,7 @@ class MTLClassifier:

         self.epochs = epochs
         self.tensorboard_log_dir = tensorboard_log_dir
-        self.distributed_training = distributed_training
-        self.master_addr = master_addr
-        self.master_port = master_port
+        self.use_data_parallel = use_data_parallel
         self.use_attention_pooling = use_attention_pooling
         self.use_task_weights = use_task_weights
         self.hyperparameters = (
@@ -308,7 +293,7 @@ class MTLClassifier:
         self.config["manual_hyperparameters"] = self.manual_hyperparameters
         self.config["use_manual_hyperparameters"] = True

-        train.run_manual_tuning(self.config)
+        train_utils.run_manual_tuning(self.config)

     def validate_additional_options(self, req_var_dict):
         missing_variable = False
@@ -345,7 +330,7 @@ class MTLClassifier:
         req_var_dict = dict(zip(required_variable_names, required_variables))
         self.validate_additional_options(req_var_dict)

-        train.run_optuna_study(self.config)
+        train_utils.run_optuna_study(self.config)

     def load_and_evaluate_test_model(
         self,
geneformer/perturber_utils.py CHANGED
@@ -1,6 +1,5 @@
 import itertools as it
 import logging
-import os
 import pickle
 from collections import defaultdict
 from pathlib import Path
@@ -9,7 +8,6 @@ from typing import List
 import numpy as np
 import pandas as pd
 import torch
-import datasets
 from datasets import Dataset, load_from_disk
 from peft import LoraConfig, get_peft_model
 from transformers import (
@@ -19,6 +17,11 @@ from transformers import (
     BitsAndBytesConfig,
 )

+from . import (
+    TOKEN_DICTIONARY_FILE,
+    ENSEMBL_DICTIONARY_FILE,
+)
+
 logger = logging.getLogger(__name__)

@@ -124,10 +127,7 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
     output_hidden_states = (mode == "eval")

     # Quantization logic
-    if isinstance(quantize, dict):
-        quantize_config = quantize.get("bnb_config", None)
-        peft_config = quantize.get("peft_config", None)
-    elif quantize:
+    if quantize:
         if inference_only:
             quantize_config = BitsAndBytesConfig(load_in_8bit=True)
             peft_config = None
@@ -138,22 +138,13 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
                 bnb_4bit_quant_type="nf4",
                 bnb_4bit_compute_dtype=torch.bfloat16,
             )
-            try:
-                peft_config = LoraConfig(
-                    lora_alpha=128,
-                    lora_dropout=0.1,
-                    r=64,
-                    bias="none",
-                    task_type="TokenClassification",
-                )
-            except ValueError as e:
-                peft_config = LoraConfig(
-                    lora_alpha=128,
-                    lora_dropout=0.1,
-                    r=64,
-                    bias="none",
-                    task_type="TOKEN_CLS",
-                )
+            peft_config = LoraConfig(
+                lora_alpha=128,
+                lora_dropout=0.1,
+                r=64,
+                bias="none",
+                task_type="TokenClassification",
+            )
     else:
         quantize_config = None
         peft_config = None
@@ -190,34 +181,17 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
         model.eval()

     # Handle device placement and PEFT
-    adapter_config_path = os.path.join(model_directory, "adapter_config.json")
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if not quantize:
         # Only move non-quantized models
-        move_to_cuda(model)
-    elif os.path.exists(adapter_config_path):
-        # If adapter files exist, load them into the model using PEFT's from_pretrained
-        model = PeftModel.from_pretrained(model, model_directory)
-        move_to_cuda(model)
-        print("loading lora weights")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = model.to(device)
     elif peft_config:
         # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
         model.enable_input_require_grads()
         model = get_peft_model(model, peft_config)
-        move_to_cuda(model)

     return model

-
-def move_to_cuda(model):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    # get what device model is currently on
-    model_device = next(model.parameters()).device
-    # Check if the model is on the CPU and move to cuda if necessary
-    if (model_device.type == "cpu") and (device.type == "cuda"):
-        model.to(device)
-
-
 def quant_layers(model):
     layer_nums = []
     for name, parameter in model.named_parameters():
@@ -431,11 +405,6 @@ def remove_perturbed_indices_set(
 def make_perturbation_batch(
     example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
 ) -> tuple[Dataset, List[int]]:
-
-    # For datasets>=4.0.0, convert to dict to avoid format issues
-    if int(datasets.__version__.split(".")[0]) >= 4:
-        example_cell = example_cell[:]
-
     if combo_lvl == 0 and tokens_to_perturb == "all":
         if perturb_type in ["overexpress", "activate"]:
             range_start = 1
@@ -508,11 +477,6 @@ def make_perturbation_batch_special(
 def make_perturbation_batch_special(
     example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
 ) -> tuple[Dataset, List[int]]:
-
-    # For datasets>=4.0.0, convert to dict to avoid format issues
-    if int(datasets.__version__.split(".")[0]) >= 4:
-        example_cell = example_cell[:]
-
     if combo_lvl == 0 and tokens_to_perturb == "all":
         if perturb_type in ["overexpress", "activate"]:
             range_start = 1
@@ -913,4 +877,50 @@ def validate_cell_states_to_model(cell_states_to_model):
             "'goal_state': 'nf', "
             "'alt_states': ['hcm', 'other1', 'other2']}"
         )
-        raise
+        raise
+
+
+class GeneIdHandler:
+    def __init__(self, raise_errors=False):
+        def invert_dict(dict_obj):
+            return {v: k for k, v in dict_obj.items()}
+
+        self.raise_errors = raise_errors
+
+        with open(TOKEN_DICTIONARY_FILE, "rb") as f:
+            self.gene_token_dict = pickle.load(f)
+            self.token_gene_dict = invert_dict(self.gene_token_dict)
+
+        with open(ENSEMBL_DICTIONARY_FILE, "rb") as f:
+            self.id_gene_dict = pickle.load(f)
+            self.gene_id_dict = invert_dict(self.id_gene_dict)
+
+    def ens_to_token(self, ens_id):
+        if not self.raise_errors:
+            return self.gene_token_dict.get(ens_id, ens_id)
+        else:
+            return self.gene_token_dict[ens_id]
+
+    def token_to_ens(self, token):
+        if not self.raise_errors:
+            return self.token_gene_dict.get(token, token)
+        else:
+            return self.token_gene_dict[token]
+
+    def ens_to_symbol(self, ens_id):
+        if not self.raise_errors:
+            return self.gene_id_dict.get(ens_id, ens_id)
+        else:
+            return self.gene_id_dict[ens_id]
+
+    def symbol_to_ens(self, symbol):
+        if not self.raise_errors:
+            return self.id_gene_dict.get(symbol, symbol)
+        else:
+            return self.id_gene_dict[symbol]
+
+    def token_to_symbol(self, token):
+        return self.ens_to_symbol(self.token_to_ens(token))
+
+    def symbol_to_token(self, symbol):
+        return self.ens_to_token(self.symbol_to_ens(symbol))
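The new `GeneIdHandler` chains two dictionary lookups per conversion, with `raise_errors=False` letting unknown identifiers pass through unchanged via `dict.get(key, key)`. A toy sketch of the same chain, with inline dictionaries standing in for the pickled token and Ensembl dictionaries (the gene and token values are illustrative):

```python
# Illustrative stand-ins for the pickled dictionaries
ens_to_token = {"ENSG00000141510": 7}        # Ensembl ID -> model token
ens_to_symbol = {"ENSG00000141510": "TP53"}  # Ensembl ID -> gene symbol
token_to_ens = {v: k for k, v in ens_to_token.items()}
symbol_to_ens = {v: k for k, v in ens_to_symbol.items()}

def lookup_symbol(token):
    ens = token_to_ens.get(token, token)  # token -> Ensembl ID
    return ens_to_symbol.get(ens, ens)    # Ensembl ID -> symbol

def lookup_token(symbol):
    ens = symbol_to_ens.get(symbol, symbol)
    return ens_to_token.get(ens, ens)

print(lookup_symbol(7))       # TP53
print(lookup_token("TP53"))   # 7
print(lookup_symbol(999))     # 999: unknown ids fall through unchanged
```

With `raise_errors=True` the real class indexes the dictionaries directly, so an unknown identifier raises a `KeyError` instead of falling through.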
geneformer/{token_dictionary_gc104M.pkl → token_dictionary_gc95M.pkl} RENAMED
File without changes
geneformer/tokenizer.py CHANGED
@@ -3,7 +3,7 @@ Geneformer tokenizer.

 **Input data:**

-| *Required format:* raw counts scRNAseq data without feature selection as .loom, .h5ad, or .zarr file.
 | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
 | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.

@@ -20,9 +20,9 @@ Geneformer tokenizer.

 **Description:**

-| Input data is a directory with .loom, .h5ad, or .zarr files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.

-| The discussion below references the .loom file format, but the analogous labels are required for .h5ad and .zarr files, just that they will be column instead of row attributes and vice versa due to the transposed format of the file types.

 | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.

@@ -30,9 +30,11 @@ Geneformer tokenizer.

 | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.

-| If one's data is in other formats besides .loom, .h5ad, or .zarr, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom, .h5ad, or .zarr format prior to running the transcriptome tokenizer.

-| OF NOTE: Use model_version to auto-select settings for model version other than current default. For V1 model series (original Geneformer pretrained in 2021 on ~30M cells), one must use correct corresponding token dictionary and gene median file, set special_token to False, and set model_input_size to 2048. This argument enables auto-selection of these settings. (For V2 model series, special_token must be True and model_input_size is 4096.)

 """

@@ -46,7 +48,6 @@ from collections import Counter
 from pathlib import Path
 from typing import Literal

-import anndata as ad
 import loompy as lp
 import numpy as np
 import pandas as pd
@@ -88,7 +89,6 @@ def sum_ensembl_ids(
     gene_mapping_dict,
     gene_token_dict,
     custom_attr_name_dict,
-    use_h5ad_index,
     file_format="loom",
     chunk_size=512,
 ):
@@ -201,19 +201,13 @@ def sum_ensembl_ids(
             dsout.add_columns(processed_array, col_attrs=view.ca)
         return dedup_filename

-    elif file_format in ["h5ad", "zarr"]:
         """
         Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
         Returns adata object with deduplicated Ensembl IDs.
         """

-        if file_format == "h5ad":
-            data = sc.read_h5ad(str(data_directory))
-        else:  # zarr
-            data = ad.read_zarr(str(data_directory))
-
-        if use_h5ad_index:
-            data.var["ensembl_id"] = list(data.var.index)

         assert (
             "ensembl_id" in data.var.columns
@@ -240,7 +234,7 @@ def sum_ensembl_ids(
             gene for gene in ensembl_ids if gene in gene_token_dict.keys()
         ]
         if len(ensembl_id_check) == len(set(ensembl_id_check)):
-            return data
         else:
             raise ValueError("Error: data Ensembl IDs non-unique.")

@@ -305,9 +299,6 @@ class TranscriptomeTokenizer:
         model_input_size=4096,
         special_token=True,
         collapse_gene_ids=True,
-        use_h5ad_index=False,
-        keep_counts=False,
-        model_version="V2",
         gene_median_file=GENE_MEDIAN_FILE,
         token_dictionary_file=TOKEN_DICTIONARY_FILE,
         gene_mapping_file=ENSEMBL_MAPPING_FILE,
@@ -327,23 +318,15 @@ class TranscriptomeTokenizer:
            | Chunk size for anndata tokenizer.
        model_input_size : int = 4096
            | Max input size of model to truncate input to.
-            | For the V1 model series, should be 2048. For the V2 model series, should be 4096.
        special_token : bool = True
            | Adds CLS token before and EOS token after rank value encoding.
-            | For the V1 model series, should be False. For the V2 model series, should be True.
        collapse_gene_ids : bool = True
            | Whether to collapse gene IDs based on gene mapping dictionary.
-        use_h5ad_index : bool = False
-            | use index as Ensembl IDs (only available for h5ad, only if collapse_gene_ids is True)
-        keep_counts : bool = False
-            | Whether to keep a dataset column that represents gene counts normalized by total cell counts
-            | Counts will be ordered by the gene rank order within the tokenized rank value encoding for each cell.
-        model_version : str
-            | To auto-select settings for model version other than current default.
-            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
        gene_median_file : Path
            | Path to pickle file containing dictionary of non-zero median
-            | gene expression values across Genecorpus.
        token_dictionary_file : Path
            | Path to pickle file containing token dictionary (Ensembl IDs:token).
        gene_mapping_file : None, Path
@@ -365,22 +348,8 @@ class TranscriptomeTokenizer:
         # add CLS and EOS tokens
         self.special_token = special_token

-        # CHANGE DEFAULTS TO BE FOR MODEL OTHER THAN CURRENT
-        self.model_version = model_version
-        if self.model_version not in ["V1", "V2"]:
-            logger.error(
-                "Unrecognized model version. Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells."
-            )
-        elif self.model_version == "V1":
-            self.model_input_size = 2048
-            self.special_token = False
-            from . import ENSEMBL_MAPPING_FILE_30M, GENE_MEDIAN_FILE_30M, TOKEN_DICTIONARY_FILE_30M
-            gene_median_file = GENE_MEDIAN_FILE_30M
-            token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
-            gene_mapping_file = ENSEMBL_MAPPING_FILE_30M
-
         # load dictionary of gene normalization factors
-        # (non-zero median value of expression across Genecorpus)
         with open(gene_median_file, "rb") as f:
             self.gene_median_dict = pickle.load(f)

@@ -403,18 +372,12 @@ class TranscriptomeTokenizer:
             "<eos>" in self.gene_token_dict.keys()
         ):
             logger.warning(
-                "<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for V2 model series, special_token should be True."
             )

         # if collapsing duplicate gene IDs
         self.collapse_gene_ids = collapse_gene_ids

-        # if using h5ad index as ensembl_ids
-        self.use_h5ad_index = use_h5ad_index
-
-        # if keeping counts within dataset column
-        self.keep_counts = keep_counts
-
         # load gene mappings dictionary (Ensembl IDs:Ensembl ID)
         if gene_mapping_file is not None:
             with open(gene_mapping_file, "rb") as f:
@@ -439,8 +402,7 @@ class TranscriptomeTokenizer:
         data_directory: Path | str,
         output_directory: Path | str,
         output_prefix: str,
-        file_format: Literal["loom", "h5ad", "zarr"] = "loom",
-        input_identifier: str = "",
         use_generator: bool = False,
     ):
         """
@@ -455,21 +417,17 @@ class TranscriptomeTokenizer:
        output_prefix : str
            | Prefix for output .dataset
        file_format : str
-            | Format of input files. Can be "loom", "h5ad", or "zarr".
-        input_identifier : str
-            | Substring identifier for input .loom, .h5ad, or .zarr, only matches are tokenized
-            | Default is no identifier, tokenizes all files in provided directory.
        use_generator : bool
            | Whether to use generator or dict for tokenization.

        """
-        tokenized_cells, cell_metadata, tokenized_counts = self.tokenize_files(
-            Path(data_directory), file_format, input_identifier
         )
         tokenized_dataset = self.create_dataset(
             tokenized_cells,
             cell_metadata,
-            tokenized_counts,
             use_generator=use_generator,
         )

@@ -477,10 +435,9 @@ class TranscriptomeTokenizer:
         tokenized_dataset.save_to_disk(str(output_path))

     def tokenize_files(
-        self, data_directory, file_format: Literal["loom", "h5ad", "zarr"] = "loom", input_identifier: str = ""
     ):
         tokenized_cells = []
-        tokenized_counts = []
         if self.custom_attr_name_dict is not None:
             cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
             cell_metadata = {
@@ -489,20 +446,15 @@ class TranscriptomeTokenizer:

         # loops through directories to tokenize .loom files
         file_found = 0
-        # loops through directories to tokenize .loom, .h5ad, or .zarr files
         tokenize_file_fn = (
             self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
         )
-        if input_identifier == "":
-            file_match = f"*.{file_format}"
-        else:
-            file_match = f"*{input_identifier}*.{file_format}"
-        for file_path in data_directory.glob(file_match):
             file_found = 1
             print(f"Tokenizing {file_path}")
-            file_tokenized_cells, file_cell_metadata, file_tokenized_counts = tokenize_file_fn(file_path, file_format=file_format)
             tokenized_cells += file_tokenized_cells
-            tokenized_counts += file_tokenized_counts
             if self.custom_attr_name_dict is not None:
                 for k in cell_attr:
                     cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[
@@ -516,17 +468,16 @@ class TranscriptomeTokenizer:
                 f"No .{file_format} files found in directory {data_directory}."
             )
             raise
-        return tokenized_cells, cell_metadata, tokenized_counts

-    def tokenize_anndata(self, adata_file_path, target_sum=10_000, file_format="h5ad"):
         adata = sum_ensembl_ids(
             adata_file_path,
             self.collapse_gene_ids,
             self.gene_mapping_dict,
             self.gene_token_dict,
-            self.custom_attr_name_dict,
-            self.use_h5ad_index,
-            file_format=file_format,
             chunk_size=self.chunk_size,
         )
@@ -565,7 +516,6 @@ class TranscriptomeTokenizer:
             filter_pass_loc = np.array([i for i in range(adata.shape[0])])

         tokenized_cells = []
-        tokenized_counts = []

         for i in range(0, len(filter_pass_loc), self.chunk_size):
             idx = filter_pass_loc[i : i + self.chunk_size]
@@ -573,23 +523,14 @@ class TranscriptomeTokenizer:
             n_counts = adata[idx].obs["n_counts"].values[:, None]
             X_view0 = adata[idx, :].X
             X_view = X_view0[:, coding_miRNA_loc]
-            X_norm_unscaled = X_view / n_counts * target_sum
-            X_norm = X_norm_unscaled / norm_factor_vector
             X_norm = sp.csr_matrix(X_norm)
-            X_norm_unscaled = sp.csr_matrix(X_norm_unscaled)

             tokenized_cells += [
                 rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
                 for i in range(X_norm.shape[0])
             ]

-            if self.keep_counts:
-                X_norm_unscaled = sp.csr_matrix(X_norm_unscaled)
-                tokenized_counts += [
-                    rank_genes(X_norm[i].data, X_norm_unscaled[i].data)
-                    for i in range(X_norm.shape[0])
-                ]
-
             # add custom attributes for subview to dict
             if self.custom_attr_name_dict is not None:
                 for k in file_cell_metadata.keys():
@@ -597,28 +538,9 @@ class TranscriptomeTokenizer:
             else:
                 file_cell_metadata = None

-            # ensure no tokenized_cells are empty
-            empty_cell_indices = [i for i, cell in enumerate(tokenized_cells) if cell.size == 0]
-            if len(empty_cell_indices) > 0:
-                logger.warning(
-                    "Warning: cells without any genes in token dictionary detected. This is unusual and may indicate empty droplets or otherwise invalid cells within the input data. Consider further QC prior to tokenization. Proceeding with excluding empty cells."
-                )
-                empty_cell_indices.sort(reverse=True)  # for safe deletion
-                for index in empty_cell_indices:
-                    del tokenized_cells[index]
-                    if self.keep_counts:
-                        del tokenized_counts[index]
611
- # remove corresponding metadata
612
- for k,v in file_cell_metadata.items():
613
- for index in empty_cell_indices:
614
- del v[index]
615
- file_cell_metadata[k] = v
616
-
617
- return tokenized_cells, file_cell_metadata, tokenized_counts
618
 
619
- def tokenize_loom(self, loom_file_path, target_sum=10_000, file_format="loom"):
620
- tokenized_counts = [] # keep_counts not implemented for tokenize_loom
621
-
622
  if self.custom_attr_name_dict is not None:
623
  file_cell_metadata = {
624
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
@@ -632,8 +554,7 @@ class TranscriptomeTokenizer:
632
  self.gene_mapping_dict,
633
  self.gene_token_dict,
634
  self.custom_attr_name_dict,
635
- use_h5ad_index=False,
636
- file_format=file_format,
637
  chunk_size=self.chunk_size,
638
  )
639
 
@@ -706,33 +627,29 @@ class TranscriptomeTokenizer:
706
  del data.ra["ensembl_id_collapsed"]
707
 
708
 
709
- return tokenized_cells, file_cell_metadata, tokenized_counts
710
 
711
  def create_dataset(
712
  self,
713
  tokenized_cells,
714
  cell_metadata,
715
- tokenized_counts,
716
  use_generator=False,
717
  keep_uncropped_input_ids=False,
718
  ):
719
  print("Creating dataset.")
720
  # create dict for dataset creation
721
  dataset_dict = {"input_ids": tokenized_cells}
722
- if self.keep_counts:
723
- dataset_dict["counts"] = tokenized_counts
724
-
725
  if self.custom_attr_name_dict is not None:
726
  dataset_dict.update(cell_metadata)
727
 
728
  # create dataset
729
  if use_generator:
 
730
  def dict_generator():
731
  for i in range(len(tokenized_cells)):
732
  yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
733
 
734
  output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
735
-
736
  else:
737
  output_dataset = Dataset.from_dict(dataset_dict)
738
 
@@ -755,23 +672,9 @@ class TranscriptomeTokenizer:
755
  len(example["input_ids"]),
756
  self.gene_token_dict.get("<eos>"),
757
  )
758
- if self.keep_counts:
759
- example["counts"] = example["counts"][
760
- 0 : self.model_input_size - 2
761
- ] # truncate to leave space for CLS and EOS token
762
- example["counts"] = np.insert(
763
- example["counts"], 0, 0.0
764
- )
765
- example["counts"] = np.insert(
766
- example["counts"],
767
- len(example["counts"]),
768
- 0.0,
769
- )
770
  else:
771
  # Truncate/Crop input_ids to input size
772
  example["input_ids"] = example["input_ids"][0 : self.model_input_size]
773
- if self.keep_counts:
774
- example["counts"] = example["counts"][0 : self.model_input_size]
775
  example["length"] = len(example["input_ids"])
776
 
777
  return example
 
  **Input data:**

+ | *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file.
  | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
  | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.

  **Description:**

+ | Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.

+ | The discussion below references the .loom file format, but the analogous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.

  | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl BioMart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.

  | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
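The "filter_pass" convention described above can be sketched with NumPy (a minimal sketch assuming the same 0/1 convention as the tokenizer; the example values are hypothetical):

```python
import numpy as np

# Hypothetical "filter_pass" column for five cells: 1 = include, anything else = exclude.
filter_pass = np.array([1, 0, 1, 1, 0])

# Indices of the cells that pass filtering; only these cells are tokenized.
filter_pass_loc = np.where(filter_pass == 1)[0]
print(filter_pass_loc.tolist())  # [0, 2, 3]
```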
+ | If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.

+ | OF NOTE: Take care that the correct token dictionary and gene median file are used for the correct model.

+ | OF NOTE: For the 95M model series, special_token should be True and model_input_size should be 4096. For the 30M model series, special_token should be False and model_input_size should be 2048.

  """
 
  from pathlib import Path
  from typing import Literal

  import loompy as lp
  import numpy as np
  import pandas as pd

          gene_mapping_dict,
          gene_token_dict,
          custom_attr_name_dict,
+         file_format="loom",
          chunk_size=512,
      ):
 
              dsout.add_columns(processed_array, col_attrs=view.ca)
          return dedup_filename

+     elif file_format == "h5ad":
          """
          Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
          Returns adata object with deduplicated Ensembl IDs.
          """

+         data = sc.read_h5ad(str(data_directory))

          assert (
              "ensembl_id" in data.var.columns

              gene for gene in ensembl_ids if gene in gene_token_dict.keys()
          ]
          if len(ensembl_id_check) == len(set(ensembl_id_check)):
+             return data_directory
          else:
              raise ValueError("Error: data Ensembl IDs non-unique.")
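The deduplication step above — summing counts for genes that collapse to the same Ensembl ID — can be sketched with pandas (the gene IDs and counts below are hypothetical):

```python
import pandas as pd

# Two rows map to the same (hypothetical) Ensembl ID after gene-ID collapsing,
# so their counts are summed into a single row.
counts = pd.DataFrame(
    {"cell_1": [3, 2, 5], "cell_2": [0, 1, 4]},
    index=["ENSG_A", "ENSG_A", "ENSG_B"],
)
collapsed = counts.groupby(level=0).sum()
print(collapsed.loc["ENSG_A", "cell_1"])  # 5
```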
 
 
          model_input_size=4096,
          special_token=True,
          collapse_gene_ids=True,
          gene_median_file=GENE_MEDIAN_FILE,
          token_dictionary_file=TOKEN_DICTIONARY_FILE,
          gene_mapping_file=ENSEMBL_MAPPING_FILE,

              | Chunk size for anndata tokenizer.
          model_input_size : int = 4096
              | Max input size of model to truncate input to.
+             | For the 30M model series, should be 2048. For the 95M model series, should be 4096.
          special_token : bool = True
              | Adds CLS token before and EOS token after rank value encoding.
+             | For the 30M model series, should be False. For the 95M model series, should be True.
          collapse_gene_ids : bool = True
              | Whether to collapse gene IDs based on gene mapping dictionary.

          gene_median_file : Path
              | Path to pickle file containing dictionary of non-zero median
+             | gene expression values across Genecorpus-30M.
          token_dictionary_file : Path
              | Path to pickle file containing token dictionary (Ensembl IDs:token).
          gene_mapping_file : None, Path
 
          # add CLS and EOS tokens
          self.special_token = special_token

          # load dictionary of gene normalization factors
+         # (non-zero median value of expression across Genecorpus-30M)
          with open(gene_median_file, "rb") as f:
              self.gene_median_dict = pickle.load(f)

              "<eos>" in self.gene_token_dict.keys()
          ):
              logger.warning(
+                 "<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for 95M model series, special_token should be True."
              )

          # if collapsing duplicate gene IDs
          self.collapse_gene_ids = collapse_gene_ids
 
 
 
 
 
 
 
          # load gene mappings dictionary (Ensembl IDs:Ensembl ID)
          if gene_mapping_file is not None:
              with open(gene_mapping_file, "rb") as f:

          data_directory: Path | str,
          output_directory: Path | str,
          output_prefix: str,
+         file_format: Literal["loom", "h5ad"] = "loom",
          use_generator: bool = False,
      ):
          """

          output_prefix : str
              | Prefix for output .dataset
          file_format : str
+             | Format of input files. Can be "loom" or "h5ad".
          use_generator : bool
              | Whether to use generator or dict for tokenization.

          """
+         tokenized_cells, cell_metadata = self.tokenize_files(
+             Path(data_directory), file_format
          )
          tokenized_dataset = self.create_dataset(
              tokenized_cells,
              cell_metadata,
              use_generator=use_generator,
          )

          tokenized_dataset.save_to_disk(str(output_path))

      def tokenize_files(
+         self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
      ):
          tokenized_cells = []
          if self.custom_attr_name_dict is not None:
              cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
              cell_metadata = {

          # loops through directories to tokenize .loom files
          file_found = 0
+         # loops through directories to tokenize .loom or .h5ad files
          tokenize_file_fn = (
              self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
          )
+         for file_path in data_directory.glob(f"*.{file_format}"):
              file_found = 1
              print(f"Tokenizing {file_path}")
+             file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
              tokenized_cells += file_tokenized_cells
              if self.custom_attr_name_dict is not None:
                  for k in cell_attr:
                      cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[

                  f"No .{file_format} files found in directory {data_directory}."
              )
              raise
+         return tokenized_cells, cell_metadata
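File discovery in tokenize_files comes down to a pathlib glob on the file extension; a minimal sketch with hypothetical file names:

```python
import tempfile
from pathlib import Path

# Only files whose extension matches file_format are picked up by the glob,
# so other files in the data directory are ignored.
with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    for name in ("sample1.loom", "sample2.h5ad", "notes.txt"):
        (root / name).touch()
    file_format = "loom"
    matches = sorted(p.name for p in root.glob(f"*.{file_format}"))
    print(matches)  # ['sample1.loom']
```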
+     def tokenize_anndata(self, adata_file_path, target_sum=10_000):
          adata = sum_ensembl_ids(
              adata_file_path,
              self.collapse_gene_ids,
              self.gene_mapping_dict,
              self.gene_token_dict,
+             self.custom_attr_name_dict,
+             file_format="h5ad",
              chunk_size=self.chunk_size,
          )

              filter_pass_loc = np.array([i for i in range(adata.shape[0])])

          tokenized_cells = []

          for i in range(0, len(filter_pass_loc), self.chunk_size):
              idx = filter_pass_loc[i : i + self.chunk_size]

              n_counts = adata[idx].obs["n_counts"].values[:, None]
              X_view0 = adata[idx, :].X
              X_view = X_view0[:, coding_miRNA_loc]
+             X_norm = X_view / n_counts * target_sum / norm_factor_vector
              X_norm = sp.csr_matrix(X_norm)

              tokenized_cells += [
                  rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
                  for i in range(X_norm.shape[0])
              ]
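The normalization and rank value encoding above can be sketched as follows (a minimal sketch: the counts, token ids, and median values are hypothetical, and rank_genes is assumed to order gene tokens by descending normalized expression):

```python
import numpy as np

def rank_genes(gene_vector, gene_tokens):
    # Order gene tokens by descending normalized expression (assumed behavior).
    return gene_tokens[np.argsort(-gene_vector)]

counts = np.array([5.0, 20.0, 10.0])            # raw counts for 3 genes in one cell
n_counts = counts.sum()                          # total counts in the cell ("n_counts")
target_sum = 10_000
norm_factor_vector = np.array([2.0, 4.0, 1.0])   # per-gene non-zero median expression

# Same normalization as above: scale counts to target_sum, then divide by gene medians.
X_norm = counts / n_counts * target_sum / norm_factor_vector
tokens = np.array([10, 11, 12])                  # hypothetical gene token ids
print(rank_genes(X_norm, tokens).tolist())  # [12, 11, 10]
```

Dividing by each gene's median expression deprioritizes ubiquitously high genes, so the rank ordering highlights genes that distinguish cell state.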
              # add custom attributes for subview to dict
              if self.custom_attr_name_dict is not None:
                  for k in file_cell_metadata.keys():

          else:
              file_cell_metadata = None

+         return tokenized_cells, file_cell_metadata

+     def tokenize_loom(self, loom_file_path, target_sum=10_000):
          if self.custom_attr_name_dict is not None:
              file_cell_metadata = {
                  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()

              self.gene_mapping_dict,
              self.gene_token_dict,
              self.custom_attr_name_dict,
+             file_format="loom",
              chunk_size=self.chunk_size,
          )

              del data.ra["ensembl_id_collapsed"]

+         return tokenized_cells, file_cell_metadata

      def create_dataset(
          self,
          tokenized_cells,
          cell_metadata,
          use_generator=False,
          keep_uncropped_input_ids=False,
      ):
          print("Creating dataset.")
          # create dict for dataset creation
          dataset_dict = {"input_ids": tokenized_cells}
          if self.custom_attr_name_dict is not None:
              dataset_dict.update(cell_metadata)

          # create dataset
          if use_generator:
+
              def dict_generator():
                  for i in range(len(tokenized_cells)):
                      yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}

              output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
          else:
              output_dataset = Dataset.from_dict(dataset_dict)

                  len(example["input_ids"]),
                  self.gene_token_dict.get("<eos>"),
              )
          else:
              # Truncate/Crop input_ids to input size
              example["input_ids"] = example["input_ids"][0 : self.model_input_size]
          example["length"] = len(example["input_ids"])

          return example
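The truncation logic above can be sketched as follows (a minimal sketch with hypothetical token ids; with special_token=True the input is cropped to model_input_size - 2 so the CLS and EOS tokens still fit):

```python
model_input_size = 8
cls_id, eos_id = 2, 3            # hypothetical special-token ids
input_ids = list(range(10, 20))  # 10 ranked gene tokens

# leave room for <cls> and <eos>, then add them at the two ends
input_ids = input_ids[: model_input_size - 2]
input_ids = [cls_id] + input_ids + [eos_id]
print(len(input_ids), input_ids[0], input_ids[-1])  # 8 2 3
```

With special_token=False, the only step is the plain crop `input_ids[:model_input_size]`, matching the else branch above.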
generation_config.json CHANGED
@@ -1,5 +1,5 @@
 {
   "_from_model_config": true,
   "pad_token_id": 0,
-  "transformers_version": "4.44.2"
+  "transformers_version": "4.37.1"
 }
gf-12L-30M-i2048/config.json ADDED
@@ -0,0 +1,23 @@
+{
+  "architectures": [
+    "BertForMaskedLM"
+  ],
+  "attention_probs_dropout_prob": 0.02,
+  "gradient_checkpointing": false,
+  "hidden_act": "relu",
+  "hidden_dropout_prob": 0.02,
+  "hidden_size": 512,
+  "initializer_range": 0.02,
+  "intermediate_size": 1024,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 2048,
+  "model_type": "bert",
+  "num_attention_heads": 8,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "position_embedding_type": "absolute",
+  "transformers_version": "4.6.0",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 25426
+}
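The added config is a standard BERT configuration; reading its key fields (reproduced from the file above) shows why this 30M-series model pairs with a tokenizer model_input_size of 2048:

```python
import json

# Key fields reproduced from gf-12L-30M-i2048/config.json above:
# max_position_embeddings bounds the tokenized input length, which is why
# model_input_size should be 2048 for this model series.
config = json.loads("""
{
  "model_type": "bert",
  "hidden_size": 512,
  "num_hidden_layers": 12,
  "num_attention_heads": 8,
  "max_position_embeddings": 2048,
  "vocab_size": 25426
}
""")
print(config["max_position_embeddings"], config["num_hidden_layers"])  # 2048 12
```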
gf-12L-30M-i2048/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:812f8d85e5ecf9d64c268f052f6ece2c1906bc4f1aecf70d5144b2598386b615
+size 158467410
{Geneformer-V2-316M → gf-12L-30M-i2048}/training_args.bin RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e45150f9a4ca34cb4e91ce79f65f3d99d9d66df9f66a37517a352d291008e0b8
-size 5432
+oid sha256:259cf6067211e24e198690d00f0a222ee5550ad57e23d04ced0d0ca2e1b3738e
+size 2607