diff --git a/Geneformer-V2-104M/model.safetensors b/Geneformer-V2-104M/model.safetensors deleted file mode 100755 index 0cc349a7acab42b89dcd92a6a7b4e6bfaf53b5f4..0000000000000000000000000000000000000000 --- a/Geneformer-V2-104M/model.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fff5cba29ddd8792991fa77b4872246fbe548a178cebda3775cdc72b67780e7f -size 417571156 diff --git a/Geneformer-V2-104M/training_args.bin b/Geneformer-V2-104M/training_args.bin deleted file mode 100755 index 443011e54067545b01bda3cf3979ed4775713c0b..0000000000000000000000000000000000000000 --- a/Geneformer-V2-104M/training_args.bin +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0d8ddd9e4f35b5fe23a3adaae03aa4480705ca82eed546a488f970adb3752d9d -size 5496 diff --git a/Geneformer-V2-104M_CLcancer/model.safetensors b/Geneformer-V2-104M_CLcancer/model.safetensors deleted file mode 100755 index 659afe038802fc98201c0216f28bb466ed84f89d..0000000000000000000000000000000000000000 --- a/Geneformer-V2-104M_CLcancer/model.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:827738139bfed4bafa9d1f3df7c6146da2e3b85f7225076adc32c6eda0ba4357 -size 417571156 diff --git a/Geneformer-V2-104M_CLcancer/training_args.bin b/Geneformer-V2-104M_CLcancer/training_args.bin deleted file mode 100755 index 7199a4754edc3cbf4f76ce9659e4a8a176082dc5..0000000000000000000000000000000000000000 --- a/Geneformer-V2-104M_CLcancer/training_args.bin +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8cf8ce52b498253adc6df53197a99821fa145c19b8ae5eeb8d15be76b8b7ddb3 -size 4984 diff --git a/Geneformer-V2-316M/model.safetensors b/Geneformer-V2-316M/model.safetensors deleted file mode 100755 index f9c0c25c4a1df80ddc1455def715a9856152882c..0000000000000000000000000000000000000000 --- a/Geneformer-V2-316M/model.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:965ceccea81953d362081ef3843560a0e4fef88d396c28017881f1e94b1246f3 -size 1265455076 diff --git a/Geneformer-V2-316M/training_args.bin b/Geneformer-V2-316M/training_args.bin deleted file mode 100755 index 630bab56b199325b337a9969d30167f5b73b7815..0000000000000000000000000000000000000000 --- a/Geneformer-V2-316M/training_args.bin +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e45150f9a4ca34cb4e91ce79f65f3d99d9d66df9f66a37517a352d291008e0b8 -size 5432 diff --git a/MANIFEST.in b/MANIFEST.in index 764efdc5ada844c921601ae1ba29bcaf66001986..c3875d90a1e1ee1715279ba71ae3efc1a46643e8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,9 +1,4 @@ -include geneformer/gene_median_dictionary_gc104M.pkl -include geneformer/gene_name_id_dict_gc104M.pkl -include geneformer/ensembl_mapping_dict_gc104M.pkl -include geneformer/token_dictionary_gc104M.pkl - -include geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl -include geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl -include geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl -include geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl \ No newline at end of file +include geneformer/gene_median_dictionary_gc95M.pkl +include geneformer/gene_name_id_dict_gc95M.pkl +include geneformer/ensembl_mapping_dict_gc95M.pkl +include geneformer/token_dictionary_gc95M.pkl diff --git a/README.md b/README.md index 5610efbd5b32f59af3a1e70b7ace9a94f21ac21a..2d1ad4375703f99e682e4293131484adeb939522 100644 --- a/README.md +++ b/README.md @@ -9,28 +9,35 @@ tags: Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology. - See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies. -- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model, now trained on ~104 million transcriptomes, and our continual learning, multitask learning, and quantization strategies. +- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies. - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation. # Model Description -Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer V1 was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million human single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. The current updated Geneformer V2 is pretrained on ~104 million human single cell transcriptomes (non-cancer). The cancer continual learning V2 variant was continually pretrained on ~14 million cancer transcriptomes to yield a cancer domain-tuned model. +Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Then, in April 2024, Geneformer was pretrained on ~95 million non-cancer transcriptomes, followed by continual learning on ~14 million cancer transcriptomes to yield a cancer domain-tuned model. -Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus (~30M for V1, ~104M for V2). The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable. +Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable. The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels. We detail applications and results in [our manuscript](https://rdcu.be/ddrx0). -During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational AI model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets. +During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets. The repository includes the following pretrained models: -- Geneformer-V1-10M: original model trained June 2021 on ~30M human single cell transcriptomes, 10M parameters, input size 2048, vocabulary ~25K protein-coding or non-coding RNA genes -- Geneformer-V2-104M and Geneformer-V2-316M: updated model trained Dec 2024 on ~104M human single cell transcriptomes, 104M or 316M parameters, input size 4096, vocabulary ~20K protein-coding genes +L=layers\ +M=millions of cells used for pretraining\ +i=input size\ +(pretraining date) -The current default model in the main directory of the repository is Geneformer-V2-316M. +- GF-6L-30M-i2048 (June 2021) +- GF-12L-30M-i2048 (June 2021) +- GF-12L-95M-i4096 (April 2024) +- GF-20L-95M-i4096 (April 2024) -The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, Geneformer-V2-104M_CLcancer. +The current default model in the main directory of the repository is GF-12L-95M-i4096. + +The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer. # Application The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification. @@ -78,13 +85,9 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main - extracting and plotting cell embeddings - in silico perturbation -Please also see [here](https://tinyurl.com/geneformertutorial) for a quickstart tutorial for predicting candidate therapeutic targets with Geneformer. - -Complete documentation is available at https://geneformer.readthedocs.io/en/latest/. - Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications. -Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.). Importantly, as usual for deep learning models, there are no uniformly applicable default hyperparameters for Geneformer. +Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.). # Citations - C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors) diff --git a/config.json b/config.json index 6bc648aa565eabb748a6a43ee4def5032a0d5237..86e20c35e6f257f0daeb00ebb92a0751d12d8fff 100644 --- a/config.json +++ b/config.json @@ -2,22 +2,22 @@ "architectures": [ "BertForMaskedLM" ], - "attention_probs_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.02, "classifier_dropout": null, "hidden_act": "relu", - "hidden_dropout_prob": 0.1, - "hidden_size": 1152, + "hidden_dropout_prob": 0.02, + "hidden_size": 512, "initializer_range": 0.02, - "intermediate_size": 4608, + "intermediate_size": 1024, "layer_norm_eps": 1e-12, "max_position_embeddings": 4096, "model_type": "bert", - "num_attention_heads": 18, - "num_hidden_layers": 18, + "num_attention_heads": 8, + "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", "torch_dtype": "float32", - "transformers_version": "4.44.2", + "transformers_version": "4.37.1", "type_vocab_size": 2, "use_cache": true, "vocab_size": 20275 diff --git a/docs/source/geneformer.in_silico_perturber.rst b/docs/source/geneformer.in_silico_perturber.rst index 419b1c74c126d19e2c0a915606c90932330da0c0..fab76dea3c46244ab15d3d77552bc538535675e5 100644 --- a/docs/source/geneformer.in_silico_perturber.rst +++ b/docs/source/geneformer.in_silico_perturber.rst @@ -5,4 +5,4 @@ geneformer.in\_silico\_perturber :members: :undoc-members: :show-inheritance: - :exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, , isp_perturb_all_special, isp_perturb_set_special, update_perturbation_dictionary + :exclude-members: valid_option_dict, validate_options, apply_additional_filters, isp_perturb_all, isp_perturb_set, update_perturbation_dictionary diff --git a/examples/cell_classification.ipynb b/examples/cell_classification.ipynb index 64dd4323cb562a31b1641a000dab8aa5f59ac951..321187b9959abe460c6efc34996d6db0cf3488ed 100644 --- a/examples/cell_classification.ipynb +++ b/examples/cell_classification.ipynb @@ -13,7 +13,7 @@ "id": "1792e51c-86c3-406f-be5a-273c4e4aec20", "metadata": {}, "source": [ - "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters." + "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization." ] }, { @@ -69,7 +69,9 @@ " \"seed\": 73,\n", "}\n", "\n", - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", + "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", + "# (otherwise the Classifier will use the current default model dictionary)\n", + "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n", "cc = Classifier(classifier=\"cell\",\n", " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n", " filter_data=filter_data_dict,\n", @@ -78,7 +80,6 @@ " freeze_layers = 2,\n", " num_crossval_splits = 1,\n", " forward_batch_size=200,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -263,8 +264,8 @@ " \"train\": train_ids,\n", " \"eval\": eval_ids}\n", "\n", - "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n", - "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n", + "# Example 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n", + "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n", " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n", " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n", " output_directory=output_dir,\n", @@ -449,7 +450,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/examples/distributed_multitask_cell_classification.ipynb b/examples/distributed_multitask_cell_classification.ipynb deleted file mode 100644 index dfc601ddd6745982175cbd4c3cab6d2db581f1cf..0000000000000000000000000000000000000000 --- a/examples/distributed_multitask_cell_classification.ipynb +++ /dev/null @@ -1,149 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "b3266a7b", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import torch\n", - "from geneformer import MTLClassifier" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3e12ac9f", - "metadata": {}, - "outputs": [], - "source": [ - "# Define paths\n", - "pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n", - "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n", - "train_path = \"/path/to/train/data.dataset\"\n", - "val_path = \"/path/to/val/data.dataset\"\n", - "test_path = \"/path/to/test/data.dataset\"\n", - "results_dir = \"/path/to/results/directory\"\n", - "model_save_path = \"/path/to/model/save/path\"\n", - "tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n", - "\n", - "# Define tasks and hyperparameters\n", - "# task_columns should be a list of column names from your dataset\n", - "# Each column represents a specific classification task (e.g. cell type, disease state)\n", - "task_columns = [\"cell_type\", \"disease_state\"] # Example task columns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c9bd7562", - "metadata": {}, - "outputs": [], - "source": [ - "# Check GPU environment\n", - "num_gpus = torch.cuda.device_count()\n", - "use_distributed = num_gpus > 1\n", - "print(f\"Number of GPUs detected: {num_gpus}\")\n", - "print(f\"Using distributed training: {use_distributed}\")\n", - "\n", - "# Set environment variables for distributed training when multiple GPUs are available\n", - "if use_distributed:\n", - " os.environ[\"MASTER_ADDR\"] = \"localhost\" # hostname\n", - " os.environ[\"MASTER_PORT\"] = \"12355\" # Choose an available port\n", - " print(\"Distributed environment variables set.\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b6ff3618", - "metadata": {}, - "outputs": [], - "source": [ - "#Define Hyperparameters for Optimization\n", - "hyperparameters = {\n", - " \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n", - " \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n", - " \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n", - " \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n", - " \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n", - " \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0},\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f665c5a7", - "metadata": {}, - "outputs": [], - "source": [ - "mc = MTLClassifier(\n", - " task_columns=task_columns, # Our defined classification tasks\n", - " study_name=\"MTLClassifier_distributed\",\n", - " pretrained_path=pretrained_path,\n", - " train_path=train_path,\n", - " val_path=val_path,\n", - " test_path=test_path,\n", - " model_save_path=model_save_path,\n", - " results_dir=results_dir,\n", - " tensorboard_log_dir=tensorboard_log_dir,\n", - " hyperparameters=hyperparameters,\n", - " # Distributed training parameters\n", - " distributed_training=use_distributed, # Enable distributed training if multiple GPUs available\n", - " master_addr=\"localhost\" if use_distributed else None,\n", - " master_port=\"12355\" if use_distributed else None,\n", - " # Other training parameters\n", - " n_trials=15, # Number of trials for hyperparameter optimization\n", - " epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n", - " batch_size=8, # Adjust based on available GPU memory\n", - " gradient_accumulation_steps=4, # Accumulate gradients over multiple steps\n", - " gradient_clipping=True, # Enable gradient clipping for stability\n", - " max_grad_norm=1.0, # Set maximum gradient norm\n", - " seed=42\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f69f7b6a", - "metadata": {}, - "outputs": [], - "source": [ - "# Run Hyperparameter Optimization with Distributed Training\n", - "if __name__ == \"__main__\":\n", - " # This guard is required for distributed training to prevent\n", - " # infinite subprocess spawning when using torch.multiprocessing\n", - " mc.run_optuna_study()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3affd5dd", - "metadata": {}, - "outputs": [], - "source": [ - "# Evaluate the Model on Test Data\n", - "if __name__ == \"__main__\":\n", - " mc.load_and_evaluate_test_model()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "bio", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/extract_and_plot_cell_embeddings.ipynb b/examples/extract_and_plot_cell_embeddings.ipynb index 8571064ab3a3a3f35d3fcf2d90a64fc0ebcb071f..f00388708664a1cd0c774bfa13f0c01d0ee6578d 100644 --- a/examples/extract_and_plot_cell_embeddings.ipynb +++ b/examples/extract_and_plot_cell_embeddings.ipynb @@ -18,7 +18,8 @@ "outputs": [], "source": [ "# initiate EmbExtractor\n", - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", + "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", + "# (otherwise the EmbExtractor will use the current default model dictionary)\n", "embex = EmbExtractor(model_type=\"CellClassifier\",\n", " num_classes=3,\n", " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n", @@ -27,13 +28,13 @@ " emb_label=[\"disease\",\"cell_type\"],\n", " labels_to_plot=[\"disease\"],\n", " forward_batch_size=200,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", - " nproc=16)\n", + " nproc=16,\n", + " token_dictionary_file=\"./gene_dictionaries_30m/token_dictionary_gc30M.pkl\") # change from current default dictionary for 30M model series\n", "\n", "# extracts embedding from input data\n", "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n", - "# example dataset for V1 model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n", - "embs = embex.extract_embs(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n", + "# example dataset for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n", + "embs = embex.extract_embs(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n", " \"path/to/input_data/\",\n", " \"path/to/output_directory/\",\n", " \"output_prefix\")\n" @@ -131,7 +132,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/examples/gene_classification.ipynb b/examples/gene_classification.ipynb index b739754a95c23c8f74f9cf4a85e05da9c2af58a8..284da7a1cc5846566d8b599ac2b549f6dc20f4a4 100644 --- a/examples/gene_classification.ipynb +++ b/examples/gene_classification.ipynb @@ -13,7 +13,7 @@ "id": "79539e95-2c9c-4162-835c-f0d158abb15d", "metadata": {}, "source": [ - "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters." + "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications." ] }, { @@ -71,14 +71,15 @@ } ], "source": [ - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", + "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", + "# (otherwise the Classifier will use the current default model dictionary)\n", + "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n", "cc = Classifier(classifier=\"gene\",\n", " gene_class_dict = gene_class_dict,\n", " max_ncells = 10_000,\n", " freeze_layers = 4,\n", " num_crossval_splits = 5,\n", " forward_batch_size=200,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -842,8 +843,8 @@ } ], "source": [ - "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n", - "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n", + "# 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n", + "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n", " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n", " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n", " output_directory=output_dir,\n", @@ -1065,14 +1066,12 @@ } ], "source": [ - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "cc = Classifier(classifier=\"gene\",\n", " gene_class_dict = gene_class_dict,\n", " max_ncells = 10_000,\n", " freeze_layers = 4,\n", " num_crossval_splits = 0,\n", " forward_batch_size=200,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -1219,8 +1218,8 @@ } ], "source": [ - "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n", - "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n", + "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n", + "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\",\n", " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n", " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n", " output_directory=output_dir,\n", @@ -1244,7 +1243,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/examples/in_silico_perturbation.ipynb b/examples/in_silico_perturbation.ipynb index 00607d24415b5a6034fca467ec90c2e76ae72d43..f7102617ebd36956d07ba61f8e4bccdf0719515e 100644 --- a/examples/in_silico_perturbation.ipynb +++ b/examples/in_silico_perturbation.ipynb @@ -39,7 +39,9 @@ "\n", "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n", "\n", - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", + "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", + "# (otherwise the EmbExtractor will use the current default model dictionary)\n", + "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n", "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n", " num_classes=3,\n", " filter_data=filter_data_dict,\n", @@ -47,7 +49,6 @@ " emb_layer=0,\n", " summary_stat=\"exact_mean\",\n", " forward_batch_size=256,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)\n", "\n", "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n", @@ -66,7 +67,9 @@ }, "outputs": [], "source": [ - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", + "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", + "# (otherwise the InSilicoPerturber will use the current default model dictionary)\n", + "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n", "isp = InSilicoPerturber(perturb_type=\"delete\",\n", " perturb_rank_shift=None,\n", " genes_to_perturb=\"all\",\n", @@ -74,7 +77,7 @@ " anchor_gene=None,\n", " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n", " num_classes=3,\n", - " emb_mode=\"cell\", # OF NOTE: SET TO \"CELL\" FOR V1 MODEL. FOR V2, SHOULD BE \"CLS\" (current default).\n", + " emb_mode=\"cell\",\n", " cell_emb_style=\"mean_pool\",\n", " filter_data=filter_data_dict,\n", " cell_states_to_model=cell_states_to_model,\n", @@ -82,7 +85,6 @@ " max_ncells=2000,\n", " emb_layer=0,\n", " forward_batch_size=400,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -95,7 +97,7 @@ "source": [ "# outputs intermediate files from in silico perturbation\n", "\n", - "isp.perturb_data(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n", + "isp.perturb_data(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n", " \"path/to/input_data\",\n", " \"path/to/isp_output_directory\",\n", " \"output_prefix\")" @@ -108,13 +110,14 @@ "metadata": {}, "outputs": [], "source": [ - "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", + "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", + "# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n", + "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n", "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n", " genes_perturbed=\"all\",\n", " combos=0,\n", " anchor_gene=None,\n", - " cell_states_to_model=cell_states_to_model,\n", - " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL SINCE V1 WAS USED FOR IN SILICO PERTURBATION ABOVE)" + " cell_states_to_model=cell_states_to_model)" ] }, { @@ -148,7 +151,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/examples/multitask_cell_classification.ipynb b/examples/multitask_cell_classification.ipynb index 998e678d7eb812dbf6e5e3764b982e9d6620fd63..b3f13b7477c7fb8797bf871b90f943877fb61029 100644 --- a/examples/multitask_cell_classification.ipynb +++ b/examples/multitask_cell_classification.ipynb @@ -286,7 +286,7 @@ " filter_data_dict=filter_data_dict,\n", " max_ncells=1000, # Number of cells to extract embeddings for\n", " emb_layer=0, # Use the second to last layer\n", - " emb_mode = \"cls\", # Use CLS token embedding for V2 model\n", + " emb_mode = \"cls\",\n", " summary_stat=\"exact_mean\",\n", " forward_batch_size=8, # Adjust based on available GPU memory\n", " nproc=4\n", @@ -324,7 +324,7 @@ " perturb_type=perturb_type,\n", " genes_to_perturb=\"all\", # Perturb all genes\n", " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n", - " emb_mode=\"cls\", # Use CLS token embedding for V2 model\n", + " emb_mode=\"cls\", # Use CLS token embedding\n", " cell_states_to_model=cell_states_to_model,\n", " state_embs_dict=state_embs_dict,\n", " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n", @@ -412,7 +412,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/examples/tokenizing_scRNAseq_data.ipynb b/examples/tokenizing_scRNAseq_data.ipynb index 7f7331fb63d46567465ddcc6cea5560e09a40e24..58c629a166529b066ba3615c16a26e59dd46295f 100644 --- a/examples/tokenizing_scRNAseq_data.ipynb +++ b/examples/tokenizing_scRNAseq_data.ipynb @@ -34,8 +34,12 @@ "metadata": {}, "source": [ "**********************************************************************************************************\n", - "#### OF NOTE: Please ensure the correct token dictionary, gene median file, special token setting, and model input size is used for the correct model version.\n", - "#### Current defaults are for V2 model series. To auto-select the correct settings for V1, set model_version argument to \"V1\"." + "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n", + "#### 95M: current defaults; 30M: https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer/gene_dictionaries_30m\n", + "\n", + "#### ADDITIONALLY:\n", + "#### The 95M model series require the special_token argument to be set to True and model_input_size to be 4096. (current defaults)\n", + "#### The 30M model series require the special_token argument to be set to False and the model_input_size to be 2048." ] }, { @@ -55,7 +59,7 @@ "metadata": {}, "outputs": [], "source": [ - "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16) # for V1 model, set model_version=\"V1\"\n", + "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n", "tk.tokenize_data(\"loom_data_directory\", \n", " \"output_directory\", \n", " \"output_prefix\", \n", @@ -79,7 +83,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/Geneformer-V2-104M/config.json b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json similarity index 68% rename from Geneformer-V2-104M/config.json rename to fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json index f1f165b1d6df5149022c0c40a23a31fd258d98cf..bc8099f84af0bd3e35d700a7135dd417e38f6bea 100755 --- a/Geneformer-V2-104M/config.json +++ b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json @@ -2,22 +2,22 @@ "architectures": [ "BertForMaskedLM" ], - "attention_probs_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.02, "classifier_dropout": null, "hidden_act": "relu", - "hidden_dropout_prob": 0.1, - "hidden_size": 768, + "hidden_dropout_prob": 0.02, + "hidden_size": 512, "initializer_range": 0.02, - "intermediate_size": 3072, + "intermediate_size": 1024, "layer_norm_eps": 1e-12, "max_position_embeddings": 4096, "model_type": "bert", - "num_attention_heads": 12, + "num_attention_heads": 8, "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", "torch_dtype": "float32", - "transformers_version": "4.44.2", + "transformers_version": "4.37.2", "type_vocab_size": 2, "use_cache": true, "vocab_size": 20275 diff --git a/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin new file mode 100755 index 0000000000000000000000000000000000000000..87625b1b8fe02c6aa0fc3ffd8c746275570e589d --- /dev/null +++ b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4 +size 152363342 diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/config.json b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/config.json rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/optimizer.pt b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/optimizer.pt rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/pytorch_model.bin b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/pytorch_model.bin rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/rng_state.pth b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/rng_state.pth rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/scheduler.pt b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/scheduler.pt rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/trainer_state.json b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/trainer_state.json rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json diff --git a/fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/training_args.bin b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin similarity index 100% rename from fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224/training_args.bin rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin diff --git a/geneformer/__init__.py b/geneformer/__init__.py index 38cfecec8ec7924d7c74bf195165417b4c891829..52d43619d06f2a7c019b480d1958a82d287d26ff 100644 --- a/geneformer/__init__.py +++ b/geneformer/__init__.py @@ -4,15 +4,10 @@ from pathlib import Path warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip -GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc104M.pkl" -TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc104M.pkl" -ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc104M.pkl" -ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc104M.pkl" - -GENE_MEDIAN_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl" -TOKEN_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/token_dictionary_gc30M.pkl" -ENSEMBL_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl" -ENSEMBL_MAPPING_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl" +GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl" +TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl" +ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl" +ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl" from . import ( collator_for_classification, diff --git a/geneformer/classifier.py b/geneformer/classifier.py index ee99a8a9c0ac4d53e56b18b01a11aa3b1eb5d5c5..c281a7ec3de0a3e568a395d154ec7540cb0bd8bd 100644 --- a/geneformer/classifier.py +++ b/geneformer/classifier.py @@ -48,13 +48,11 @@ import logging import os import pickle import subprocess -from packaging.version import parse from pathlib import Path import numpy as np import pandas as pd import seaborn as sns -import transformers from tqdm.auto import tqdm, trange from transformers import Trainer from transformers.training_args import TrainingArguments @@ -73,7 +71,6 @@ sns.set() logger = logging.getLogger(__name__) -transformers_version = parse(transformers.__version__) class Classifier: valid_option_dict = { @@ -92,7 +89,6 @@ class Classifier: "no_eval": {bool}, "stratify_splits_col": {None, str}, "forward_batch_size": {int}, - "model_version": {"V1", "V2"}, "token_dictionary_file": {None, str}, "nproc": {int}, "ngpu": {int}, @@ -116,7 +112,6 @@ class Classifier: stratify_splits_col=None, no_eval=False, forward_batch_size=100, - model_version="V2", token_dictionary_file=None, nproc=4, ngpu=1, @@ -193,9 +188,6 @@ class Classifier: | Otherwise, will perform eval during training. forward_batch_size : int | Batch size for forward pass (for evaluation, not training). - model_version : str - | To auto-select settings for model version other than current default. - | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : None, str | Default is to use token dictionary file from Geneformer | Otherwise, will load custom gene token dictionary. @@ -230,16 +222,14 @@ class Classifier: self.stratify_splits_col = stratify_splits_col self.no_eval = no_eval self.forward_batch_size = forward_batch_size - self.model_version = model_version self.token_dictionary_file = token_dictionary_file self.nproc = nproc self.ngpu = ngpu - + if self.training_args is None: logger.warning( "Hyperparameter tuning is highly recommended for optimal results. " - "No training_args provided; using default hyperparameters. " - "Please note: these defaults are not recommended to be used uniformly across tasks." + "No training_args provided; using default hyperparameters." ) self.validate_options() @@ -254,10 +244,7 @@ class Classifier: ] = self.cell_state_dict["states"] # load token dictionary (Ensembl IDs:token) - if self.model_version == "V1": - from . import TOKEN_DICTIONARY_FILE_30M - self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M - elif self.token_dictionary_file is None: + if self.token_dictionary_file is None: self.token_dictionary_file = TOKEN_DICTIONARY_FILE with open(self.token_dictionary_file, "rb") as f: self.gene_token_dict = pickle.load(f) @@ -367,7 +354,6 @@ class Classifier: attr_to_balance=None, max_trials=100, pval_threshold=0.1, - id_class_dict_path=None, ): """ Prepare data for cell state or gene classification. @@ -410,10 +396,6 @@ class Classifier: pval_threshold : None, float | P-value threshold to use for attribute balancing across splits | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance - id_class_dict_path : Path - | Path to *_id_class_dict.pkl from prior run of prepare_data to reuse for labeling new data - | Dictionary with keys being numeric class labels and values being original dataset class labels - | Note: only available for CellClassifiers """ if test_size is None: @@ -457,13 +439,8 @@ class Classifier: data = cu.rename_cols(data, self.cell_state_dict["state_key"]) # convert classes to numerical labels and save as id_class_dict - if id_class_dict_path is not None: - with open(id_class_dict_path,"rb") as fp: - id_class_dict = pickle.load(fp) - else: - id_class_dict = None data, id_class_dict = cu.label_classes( - self.classifier, data, self.cell_state_dict, self.nproc, id_class_dict, + self.classifier, data, self.cell_state_dict, self.nproc ) elif self.classifier == "gene": @@ -801,7 +778,7 @@ class Classifier: # 5-fold cross-validate num_cells = len(data) fifth_cells = int(np.floor(num_cells * 0.2)) - num_eval = int(min((self.eval_size * num_cells), fifth_cells)) + num_eval = min((self.eval_size * num_cells), fifth_cells) start = i * fifth_cells end = start + num_eval eval_indices = [j for j in range(start, end)] @@ -1085,8 +1062,6 @@ class Classifier: if eval_data is None: def_training_args["evaluation_strategy"] = "no" def_training_args["load_best_model_at_end"] = False - if transformers_version >= parse("4.46"): - def_training_args["eval_strategy"] = def_training_args.pop("evaluation_strategy") def_training_args.update( {"save_strategy": "epoch", "save_total_limit": 1} ) # only save last model for each run @@ -1258,8 +1233,6 @@ class Classifier: if eval_data is None: def_training_args["evaluation_strategy"] = "no" def_training_args["load_best_model_at_end"] = False - if transformers_version >= parse("4.46"): - def_training_args["eval_strategy"] = def_training_args.pop("evaluation_strategy") training_args_init = TrainingArguments(**def_training_args) if self.freeze_layers is not None: @@ -1313,7 +1286,6 @@ class Classifier: predict=False, output_directory=None, output_prefix=None, - predict_metadata=None, ): """ Evaluate the fine-tuned model. @@ -1339,11 +1311,9 @@ class Classifier: ##### Evaluate the model ##### labels = id_class_dict.keys() - - y_pred, y_true, logits_list, predict_metadata_all = eu.classifier_predict( - model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict, predict_metadata + y_pred, y_true, logits_list = eu.classifier_predict( + model, self.classifier, eval_data, self.forward_batch_size ) - conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics( y_pred, y_true, logits_list, num_classes, labels ) @@ -1353,9 +1323,6 @@ class Classifier: "label_ids": y_true, "predictions": logits_list, } - if predict_metadata is not None: - pred_dict["prediction_metadata"] = predict_metadata_all - pred_dict_output_path = ( Path(output_directory) / f"{output_prefix}_pred_dict" ).with_suffix(".pkl") @@ -1376,7 +1343,6 @@ class Classifier: output_directory, output_prefix, predict=True, - predict_metadata=None, ): """ Evaluate the fine-tuned model. @@ -1396,8 +1362,6 @@ class Classifier: | Prefix for output files predict : bool | Whether or not to save eval predictions - predict_metadata : None | list - | Metadata labels to output with predictions (columns in test_data_file) """ # load numerical id to class dictionary (id:class) @@ -1410,15 +1374,6 @@ class Classifier: # load previously filtered and prepared data test_data = pu.load_and_filter(None, self.nproc, test_data_file) - if predict_metadata is not None: - absent_metadata = [] - for predict_metadata_x in predict_metadata: - if predict_metadata_x not in test_data.features.keys(): - absent_metadata += [predict_metadata_x] - if len(absent_metadata)>0: - logger.error(f"Following predict_metadata was not found as column in test_data_file: {absent_metadata}") - raise - # load previously fine-tuned model model = pu.load_model( self.model_type, @@ -1437,7 +1392,6 @@ class Classifier: predict=predict, output_directory=output_directory, output_prefix=output_prefix, - predict_metadata=predict_metadata, ) all_conf_mat_df = pd.DataFrame( diff --git a/geneformer/classifier_utils.py b/geneformer/classifier_utils.py index a30884c34629573b16c3951517d42f05e57ec108..a20e00c9116688b67f1a48d6ce89e0e9f3b7ad09 100644 --- a/geneformer/classifier_utils.py +++ b/geneformer/classifier_utils.py @@ -94,7 +94,7 @@ def remove_rare(data, rare_threshold, label, nproc): return data -def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None): +def label_classes(classifier, data, gene_class_dict, nproc): if classifier == "cell": label_set = set(data["label"]) elif classifier == "gene": @@ -113,11 +113,8 @@ def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None): ) raise - if id_class_dict is None: - class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))])) - id_class_dict = {v: k for k, v in class_id_dict.items()} - else: - class_id_dict = {v: k for k, v in id_class_dict.items()} + class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))])) + id_class_dict = {v: k for k, v in class_id_dict.items()} if classifier == "gene": inverse_gene_class_dict = {} @@ -570,27 +567,6 @@ def compute_metrics(pred): return {"accuracy": acc, "macro_f1": macro_f1} -def robust_compute_objective(metrics: dict): - # tries both prefixed ("eval_") and raw metric names to support different transformers versions - metric_name = "macro_f1" - - # check for the prefixed version - prefixed_metric_name = f"eval_{metric_name}" - if prefixed_metric_name in metrics: - return metrics[prefixed_metric_name] - - # fall back to the raw name - elif metric_name in metrics: - return metrics[metric_name] - - # if neither is found, raise a clear error to help with debugging - raise KeyError( - f"Could not find '{prefixed_metric_name}' or '{metric_name}' in the reported metrics. " - f"Please check your `compute_metrics` function and `TrainingArguments`. " - f"Available metrics: {list(metrics.keys())}" - ) - - def get_default_train_args(model, classifier, data, output_dir): num_layers = pu.quant_layers(model) freeze_layers = 0 diff --git a/geneformer/collator_for_classification.py b/geneformer/collator_for_classification.py index a34dda86cb56021a982d76111971f315b8e33e05..297fa666dbf0daeaa94e2ca203ace5f98570a30e 100644 --- a/geneformer/collator_for_classification.py +++ b/geneformer/collator_for_classification.py @@ -26,10 +26,9 @@ LARGE_INTEGER = int( 1e20 ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER -warnings.filterwarnings("ignore", message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach()", category=UserWarning, module="torch") - # precollator functions + class ExplicitEnum(Enum): """ Enum with more explicit error message for missing values. @@ -104,9 +103,6 @@ class PrecollatorForGeneAndCellClassification(SpecialTokensMixin): def pad_token_id(self): return self._pad_token_id - def save_pretrained(self, save_directory): - pass - def _get_padding_truncation_strategies( self, padding=True, @@ -645,8 +641,7 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification): def __call__(self, features): batch = self._prepare_batch(features) - # batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} - batch = {k: torch.tensor(v.clone().detach(), dtype=torch.int64) if isinstance(v, torch.Tensor) else torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} + batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} return batch diff --git a/geneformer/emb_extractor.py b/geneformer/emb_extractor.py index d46ad463a38896d7628d7b49249f000d9783a8e2..2ef1103dee2e492f87f751a51d0f4f12b1ce87d0 100644 --- a/geneformer/emb_extractor.py +++ b/geneformer/emb_extractor.py @@ -42,8 +42,6 @@ def get_embs( special_token=False, summary_stat=None, silent=False, - save_tdigest=False, - tdigest_path=None, ): model_input_size = pu.get_model_input_size(model) total_batch_length = len(filtered_input_data) @@ -182,18 +180,12 @@ def get_embs( # calculate summary stat embs from approximated tdigests elif summary_stat is not None: if emb_mode == "cell": - if save_tdigest: - with open(f"{tdigest_path}","wb") as fp: - pickle.dump(embs_tdigests, fp) if summary_stat == "mean": summary_emb_list = tdigest_mean(embs_tdigests, emb_dims) elif summary_stat == "median": summary_emb_list = tdigest_median(embs_tdigests, emb_dims) embs_stack = torch.tensor(summary_emb_list) elif emb_mode == "gene": - if save_tdigest: - with open(f"{tdigest_path}","wb") as fp: - pickle.dump(embs_tdigests_dict, fp) if summary_stat == "mean": [ update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims) @@ -260,7 +252,7 @@ def label_cell_embs(embs, downsampled_data, emb_labels): return embs_df -def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mean_pool"): +def label_gene_embs(embs, downsampled_data, token_gene_dict): gene_set = { element for sublist in downsampled_data["input_ids"] for element in sublist } @@ -275,52 +267,25 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mea ) for k in dict_i.keys(): gene_emb_dict[k].append(dict_i[k]) - if gene_emb_style != "all": - for k in gene_emb_dict.keys(): - gene_emb_dict[k] = ( - torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0) - .cpu() - .numpy() - ) - embs_df = pd.DataFrame(gene_emb_dict).T - else: - embs_df = dict_lol_to_df(gene_emb_dict) + for k in gene_emb_dict.keys(): + gene_emb_dict[k] = ( + torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0) + .cpu() + .numpy() + ) + embs_df = pd.DataFrame(gene_emb_dict).T embs_df.index = [token_gene_dict[token] for token in embs_df.index] return embs_df -def dict_lol_to_df(data_dict): - # save dictionary with values being list of equal-length lists as dataframe - df_data = [] - for key, list_of_lists in data_dict.items(): - for i, sublist in enumerate(list_of_lists): - row_data = [key, i] + sublist.tolist() - df_data.append(row_data) - - # determine column names based on the length of sublists - # assuming all sublists have the same length - num_columns_from_sublist = len(list(data_dict.values())[0][0]) - column_names = ['Gene', 'Identifier'] + [f'{j}' for j in range(num_columns_from_sublist)] - - # create the dataframe - df = pd.DataFrame(df_data, columns=column_names) - - # set 'Gene' as the index - df = df.set_index('Gene') - - return df - -def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0): + +def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0): only_embs_df = embs_df.iloc[:, :emb_dims] only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str) only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype( str ) vars_dict = {"embs": only_embs_df.columns} - - obs_dict = {"cell_id": list(only_embs_df.index)} - for label_i in labels_clean: - obs_dict[label_i] = list(embs_df[label_i]) - + obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])} adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict) sc.tl.pca(adata, svd_solver="arpack") sc.pp.neighbors(adata, random_state=seed) @@ -331,26 +296,21 @@ def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, if kwargs_dict is not None: default_kwargs_dict.update(kwargs_dict) - for label_i in labels_clean: - output_prefix_label = output_prefix + f"_umap_{label_i}" - output_file = ( - Path(output_directory) / output_prefix_label - ).with_suffix(".pdf") - - cats = set(embs_df[label_i]) - - with plt.rc_context(): - ax = sc.pl.umap(adata, color=label_i, show=False, **default_kwargs_dict) - ax.legend( - markerscale=2, - frameon=False, - loc="center left", - bbox_to_anchor=(1, 0.5), - ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3), - ) - plt.show() - plt.savefig(output_file, bbox_inches="tight") - + cats = set(embs_df[label]) + + with plt.rc_context(): + ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict) + ax.legend( + markerscale=2, + frameon=False, + loc="center left", + bbox_to_anchor=(1, 0.5), + ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3), + ) + plt.show() + plt.savefig(output_file, bbox_inches="tight") + + def gen_heatmap_class_colors(labels, df): pal = sns.cubehelix_palette( len(Counter(labels).keys()), @@ -435,14 +395,13 @@ class EmbExtractor: "num_classes": {int}, "emb_mode": {"cls", "cell", "gene"}, "cell_emb_style": {"mean_pool"}, - "gene_emb_style": {"mean_pool", "all"}, + "gene_emb_style": {"mean_pool"}, "filter_data": {None, dict}, "max_ncells": {None, int}, "emb_layer": {-1, 0}, "emb_label": {None, list}, "labels_to_plot": {None, list}, "forward_batch_size": {int}, - "model_version": {"V1", "V2"}, "token_dictionary_file": {None, str}, "nproc": {int}, "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"}, @@ -463,8 +422,6 @@ class EmbExtractor: forward_batch_size=100, nproc=4, summary_stat=None, - save_tdigest=False, - model_version="V2", token_dictionary_file=None, ): """ @@ -483,9 +440,9 @@ class EmbExtractor: cell_emb_style : {"mean_pool"} | Method for summarizing cell embeddings if not using CLS token. | Currently only option is mean pooling of gene embeddings for given cell. - gene_emb_style : {"mean_pool", "all} + gene_emb_style : "mean_pool" | Method for summarizing gene embeddings. - | Currently only option is returning all or mean pooling of contextual gene embeddings for given gene. + | Currently only option is mean pooling of contextual gene embeddings for given gene. filter_data : None, dict | Default is to extract embeddings from all input data. | Otherwise, dictionary specifying .dataset column name and list of values to filter by. @@ -515,12 +472,6 @@ class EmbExtractor: | If mean or median, outputs only approximated mean or median embedding of input data. | Non-exact recommended if encountering memory constraints while generating goal embedding positions. | Non-exact is slower but more memory-efficient. - save_tdigest : bool - | Whether to save a dictionary of tdigests for each gene and embedding dimension - | Only applies when summary_stat is not None - model_version : str - | To auto-select settings for model version other than current default. - | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : Path | Default is the Geneformer token dictionary | Path to pickle file containing token dictionary (Ensembl ID:token). @@ -551,7 +502,6 @@ class EmbExtractor: self.emb_layer = emb_layer self.emb_label = emb_label self.labels_to_plot = labels_to_plot - self.model_version = model_version self.token_dictionary_file = token_dictionary_file self.forward_batch_size = forward_batch_size self.nproc = nproc @@ -561,29 +511,13 @@ class EmbExtractor: else: self.summary_stat = summary_stat self.exact_summary_stat = None - self.save_tdigest = save_tdigest self.validate_options() - if (summary_stat is None) and (save_tdigest is True): - logger.warning( - "tdigests will not be saved since summary_stat is None." - ) - save_tdigest = False - - if self.model_version == "V1": - from . import TOKEN_DICTIONARY_FILE_30M - self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M - if self.emb_mode == "cls": - self.emb_mode = "cell" - logger.warning( - "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a token." - ) - # load token dictionary (Ensembl IDs:token) if self.token_dictionary_file is None: - self.token_dictionary_file = TOKEN_DICTIONARY_FILE - with open(self.token_dictionary_file, "rb") as f: + token_dictionary_file = TOKEN_DICTIONARY_FILE + with open(token_dictionary_file, "rb") as f: self.gene_token_dict = pickle.load(f) self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()} @@ -677,10 +611,6 @@ class EmbExtractor: self.model_type, self.num_classes, model_directory, mode="eval" ) layer_to_quant = pu.quant_layers(model) + self.emb_layer - if self.save_tdigest: - tdigest_path = (Path(output_directory) / f"{output_prefix}_tdigest").with_suffix(".pkl") - else: - tdigest_path = None embs = get_embs( model=model, filtered_input_data=downsampled_data, @@ -690,8 +620,6 @@ class EmbExtractor: forward_batch_size=self.forward_batch_size, token_gene_dict=self.token_gene_dict, summary_stat=self.summary_stat, - save_tdigest=self.save_tdigest, - tdigest_path=tdigest_path, ) if self.emb_mode == "cell": @@ -701,7 +629,7 @@ class EmbExtractor: embs_df = pd.DataFrame(embs.cpu().numpy()).T elif self.emb_mode == "gene": if self.summary_stat is None: - embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict, self.gene_emb_style) + embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict) elif self.summary_stat is not None: embs_df = pd.DataFrame(embs).T embs_df.index = [self.token_gene_dict[token] for token in embs_df.index] @@ -885,14 +813,14 @@ class EmbExtractor: raise if max_ncells_to_plot is not None: - if self.max_ncells is not None: - if max_ncells_to_plot > self.max_ncells: - max_ncells_to_plot = self.max_ncells - logger.warning( - "max_ncells_to_plot must be <= max_ncells. " - f"Changing max_ncells_to_plot to {self.max_ncells}." - ) - embs = embs.sample(max_ncells_to_plot, axis=0) + if max_ncells_to_plot > self.max_ncells: + max_ncells_to_plot = self.max_ncells + logger.warning( + "max_ncells_to_plot must be <= max_ncells. " + f"Changing max_ncells_to_plot to {self.max_ncells}." + ) + elif max_ncells_to_plot < self.max_ncells: + embs = embs.sample(max_ncells_to_plot, axis=0) if self.emb_label is None: label_len = 0 @@ -913,9 +841,12 @@ class EmbExtractor: f"Label {label} from labels_to_plot " f"not present in provided embeddings dataframe." ) - - labels_clean = [label for label in self.labels_to_plot if label in emb_labels] - plot_umap(embs, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict) + continue + output_prefix_label = output_prefix + f"_umap_{label}" + output_file = ( + Path(output_directory) / output_prefix_label + ).with_suffix(".pdf") + plot_umap(embs, emb_dims, label, output_file, kwargs_dict) if plot_style == "heatmap": for label in self.labels_to_plot: diff --git a/geneformer/ensembl_mapping_dict_gc104M.pkl b/geneformer/ensembl_mapping_dict_gc95M.pkl similarity index 100% rename from geneformer/ensembl_mapping_dict_gc104M.pkl rename to geneformer/ensembl_mapping_dict_gc95M.pkl diff --git a/geneformer/evaluation_utils.py b/geneformer/evaluation_utils.py index 88f850cda75150d997f4fb6b4c13f2b13b2e2072..e4bbc8326d33b0de0a62778e6cde0d0c4bd86b25 100644 --- a/geneformer/evaluation_utils.py +++ b/geneformer/evaluation_utils.py @@ -8,7 +8,6 @@ import numpy as np import pandas as pd import seaborn as sns import torch -import datasets from datasets.utils.logging import disable_progress_bar, enable_progress_bar from sklearn import preprocessing from sklearn.metrics import ( @@ -21,15 +20,20 @@ from sklearn.metrics import ( ) from tqdm.auto import trange +from . import TOKEN_DICTIONARY_FILE from .emb_extractor import make_colorbar logger = logging.getLogger(__name__) -def preprocess_classifier_batch(cell_batch, max_len, label_name, gene_token_dict): +def preprocess_classifier_batch(cell_batch, max_len, label_name): if max_len is None: max_len = max([len(i) for i in cell_batch["input_ids"]]) + # load token dictionary (Ensembl IDs:token) + with open(TOKEN_DICTIONARY_FILE, "rb") as f: + gene_token_dict = pickle.load(f) + def pad_label_example(example): example[label_name] = np.pad( example[label_name], @@ -77,7 +81,7 @@ def py_softmax(vector): return e / e.sum() -def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict, predict_metadata=None): +def classifier_predict(model, classifier_type, evalset, forward_batch_size): if classifier_type == "gene": label_name = "labels" elif classifier_type == "cell": @@ -85,14 +89,6 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene predict_logits = [] predict_labels = [] - - predict_metadata_all = None - - if predict_metadata is not None: - predict_metadata_all = dict() - for metadata_name in predict_metadata: - predict_metadata_all[metadata_name] = [] - model.eval() # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims @@ -107,21 +103,11 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene for i in trange(0, evalset_len, forward_batch_size): max_range = min(i + forward_batch_size, evalset_len) batch_evalset = evalset.select([i for i in range(i, max_range)]) - - if predict_metadata is not None: - for metadata_name in predict_metadata: - predict_metadata_all[metadata_name] += batch_evalset[metadata_name] - padded_batch = preprocess_classifier_batch( - batch_evalset, max_evalset_len, label_name, gene_token_dict + batch_evalset, max_evalset_len, label_name ) - padded_batch.set_format(type="torch") - # For datasets>=4.0.0, convert to dict to avoid format issues - if int(datasets.__version__.split(".")[0]) >= 4: - padded_batch = padded_batch[:] - input_data_batch = padded_batch["input_ids"] attn_msk_batch = padded_batch["attention_mask"] label_batch = padded_batch[label_name] @@ -148,8 +134,7 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene y_pred = [vote(item[0]) for item in logit_label_paired] y_true = [item[1] for item in logit_label_paired] logits_list = [item[0] for item in logit_label_paired] - - return y_pred, y_true, logits_list, predict_metadata_all + return y_pred, y_true, logits_list def get_metrics(y_pred, y_true, logits_list, num_classes, labels): diff --git a/geneformer/gene_median_dictionary_gc104M.pkl b/geneformer/gene_median_dictionary_gc95M.pkl similarity index 100% rename from geneformer/gene_median_dictionary_gc104M.pkl rename to geneformer/gene_median_dictionary_gc95M.pkl diff --git a/geneformer/gene_name_id_dict_gc104M.pkl b/geneformer/gene_name_id_dict_gc95M.pkl similarity index 100% rename from geneformer/gene_name_id_dict_gc104M.pkl rename to geneformer/gene_name_id_dict_gc95M.pkl diff --git a/geneformer/in_silico_perturber.py b/geneformer/in_silico_perturber.py index 35244c22237abd7e24938bb4324c2129551634a7..275244f771e344435734f9ef19f3749e294f0d2c 100644 --- a/geneformer/in_silico_perturber.py +++ b/geneformer/in_silico_perturber.py @@ -72,7 +72,6 @@ class InSilicoPerturber: "max_ncells": {None, int}, "cell_inds_to_perturb": {"all", dict}, "emb_layer": {-1, 0}, - "model_version": {"V1", "V2"}, "token_dictionary_file": {None, str}, "forward_batch_size": {int}, "nproc": {int}, @@ -97,7 +96,6 @@ class InSilicoPerturber: emb_layer=-1, forward_batch_size=100, nproc=4, - model_version="V2", token_dictionary_file=None, clear_mem_ncells=1000, ): @@ -186,9 +184,6 @@ class InSilicoPerturber: | Batch size for forward pass. nproc : int | Number of CPU processes to use. - model_version : str - | To auto-select settings for model version other than current default. - | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : Path | Path to pickle file containing token dictionary (Ensembl ID:token). clear_mem_ncells : int @@ -229,30 +224,15 @@ class InSilicoPerturber: self.emb_layer = emb_layer self.forward_batch_size = forward_batch_size self.nproc = nproc - self.model_version = model_version self.token_dictionary_file = token_dictionary_file - self.clear_mem_ncells = clear_mem_ncells + self.clear_mem_ncells = clear_mem_ncells self.validate_options() - if self.model_version == "V1": - from . import TOKEN_DICTIONARY_FILE_30M - self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M - if self.emb_mode == "cls": - self.emb_mode = "cell" - logger.warning( - "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a token." - ) - if self.emb_mode == "cls_and_gene": - self.emb_mode = "cell_and_gene" - logger.warning( - "model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a token." - ) - # load token dictionary (Ensembl IDs:token) if self.token_dictionary_file is None: - self.token_dictionary_file = TOKEN_DICTIONARY_FILE - with open(self.token_dictionary_file, "rb") as f: + token_dictionary_file = TOKEN_DICTIONARY_FILE + with open(token_dictionary_file, "rb") as f: self.gene_token_dict = pickle.load(f) self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()} diff --git a/geneformer/in_silico_perturber_stats.py b/geneformer/in_silico_perturber_stats.py index 40c2d5fffbd87de84b294a0a5d608391d07b5128..9ec98a8caee4e4ca623c5ecc7c18c36210806cce 100644 --- a/geneformer/in_silico_perturber_stats.py +++ b/geneformer/in_silico_perturber_stats.py @@ -676,7 +676,6 @@ class InSilicoPerturberStats: "anchor_gene": {None, str}, "cell_states_to_model": {None, dict}, "pickle_suffix": {None, str}, - "model_version": {"V1", "V2"}, } def __init__( @@ -687,7 +686,6 @@ class InSilicoPerturberStats: anchor_gene=None, cell_states_to_model=None, pickle_suffix="_raw.pickle", - model_version="V2", token_dictionary_file=TOKEN_DICTIONARY_FILE, gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE, ): @@ -715,7 +713,7 @@ class InSilicoPerturberStats: | analyzes data for anchor gene perturbed in combination with each other gene. | However, if combos=0 and anchor_gene="ENSG00000136574": | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene. - cell_states_to_model : None, dict + cell_states_to_model: None, dict | Cell states to model if testing perturbations that achieve goal state change. | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states | state_key: key specifying name of column in .dataset that defines the start/goal states @@ -726,12 +724,6 @@ class InSilicoPerturberStats: | "start_state": "dcm", | "goal_state": "nf", | "alt_states": ["hcm", "other1", "other2"]} - pickle_suffix : None, str - | Suffix to subselect intermediate raw files for analysis. - | Default output of InSilicoPerturber uses suffix "_raw.pickle". - model_version : str - | To auto-select settings for model version other than current default. - | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : Path | Path to pickle file containing token dictionary (Ensembl ID:token). gene_name_id_dictionary_file : Path @@ -744,15 +736,9 @@ class InSilicoPerturberStats: self.anchor_gene = anchor_gene self.cell_states_to_model = cell_states_to_model self.pickle_suffix = pickle_suffix - self.model_version = model_version self.validate_options() - if self.model_version == "V1": - from . import ENSEMBL_DICTIONARY_FILE_30M, TOKEN_DICTIONARY_FILE_30M - token_dictionary_file=TOKEN_DICTIONARY_FILE_30M - gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE_30M - # load token dictionary (Ensembl IDs:token) with open(token_dictionary_file, "rb") as f: self.gene_token_dict = pickle.load(f) diff --git a/geneformer/mtl/__init__.py b/geneformer/mtl/__init__.py index 5f5235c230cb0d1c51e2823e5b92118bb480dba9..06788a56ac11397d1698a74381d466b7b7bd98b7 100644 --- a/geneformer/mtl/__init__.py +++ b/geneformer/mtl/__init__.py @@ -1,4 +1 @@ -# ruff: noqa: F401 - -from . import eval_utils -from . import utils \ No newline at end of file +# ruff: noqa: F401 \ No newline at end of file diff --git a/geneformer/mtl/collators.py b/geneformer/mtl/collators.py index b575b23ba592a6f552928b306c7882de8dbe7f4f..63546f93a05c857781198be88de027f5fb9e827f 100644 --- a/geneformer/mtl/collators.py +++ b/geneformer/mtl/collators.py @@ -1,8 +1,8 @@ # imports import torch import pickle -from geneformer.collator_for_classification import DataCollatorForGeneClassification -from geneformer import TOKEN_DICTIONARY_FILE +from ..collator_for_classification import DataCollatorForGeneClassification +from .. import TOKEN_DICTIONARY_FILE """Geneformer collator for multi-task cell classification.""" diff --git a/geneformer/mtl/data.py b/geneformer/mtl/data.py index 7512f95a98a58e5e2d6076fa863bd628b0a24302..6bdec933d24b7197a1bd17747180217dec4a3d99 100644 --- a/geneformer/mtl/data.py +++ b/geneformer/mtl/data.py @@ -1,190 +1,126 @@ import os -import pickle -import torch -from torch.utils.data import DataLoader, Dataset -from datasets import load_from_disk - from .collators import DataCollatorForMultitaskCellClassification +from .imports import * + +def validate_columns(dataset, required_columns, dataset_type): + """Ensures required columns are present in the dataset.""" + missing_columns = [col for col in required_columns if col not in dataset.column_names] + if missing_columns: + raise KeyError( + f"Missing columns in {dataset_type} dataset: {missing_columns}. " + f"Available columns: {dataset.column_names}" + ) -class StreamingMultiTaskDataset(Dataset): - - def __init__(self, dataset_path, config, is_test=False, dataset_type=""): - """Initialize the streaming dataset.""" - self.dataset = load_from_disk(dataset_path) - self.config = config - self.is_test = is_test - self.dataset_type = dataset_type - self.cell_id_mapping = {} - - # Setup task and column mappings - self.task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))] - self.task_to_column = dict(zip(self.task_names, config["task_columns"])) - config["task_names"] = self.task_names - - # Check if unique_cell_id column exists in the dataset - self.has_unique_cell_ids = "unique_cell_id" in self.dataset.column_names - print(f"{'Found' if self.has_unique_cell_ids else 'No'} unique_cell_id column in {dataset_type} dataset") - - # Setup label mappings - self.label_mappings_path = os.path.join( - config["results_dir"], - f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl" - ) - - if not is_test: - self._validate_columns() - self.task_label_mappings, self.num_labels_list = self._create_label_mappings() - self._save_label_mappings() - else: - # Load existing mappings for test data - self.task_label_mappings = self._load_label_mappings() - self.num_labels_list = [len(mapping) for mapping in self.task_label_mappings.values()] - - def _validate_columns(self): - """Ensures required columns are present in the dataset.""" - missing_columns = [col for col in self.task_to_column.values() - if col not in self.dataset.column_names] - if missing_columns: - raise KeyError( - f"Missing columns in {self.dataset_type} dataset: {missing_columns}. " - f"Available columns: {self.dataset.column_names}" - ) - - def _create_label_mappings(self): - """Creates label mappings for the dataset.""" - task_label_mappings = {} - num_labels_list = [] - - for task, column in self.task_to_column.items(): - unique_values = sorted(set(self.dataset[column])) - mapping = {label: idx for idx, label in enumerate(unique_values)} - task_label_mappings[task] = mapping - num_labels_list.append(len(unique_values)) - - return task_label_mappings, num_labels_list - - def _save_label_mappings(self): - """Saves label mappings to a pickle file.""" - with open(self.label_mappings_path, "wb") as f: - pickle.dump(self.task_label_mappings, f) - - def _load_label_mappings(self): - """Loads label mappings from a pickle file.""" - with open(self.label_mappings_path, "rb") as f: - return pickle.load(f) - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - record = self.dataset[idx] - - # Store cell ID mapping - if self.has_unique_cell_ids: - unique_cell_id = record["unique_cell_id"] - self.cell_id_mapping[idx] = unique_cell_id - else: - self.cell_id_mapping[idx] = f"cell_{idx}" - - # Create transformed record +def create_label_mappings(dataset, task_to_column): + """Creates label mappings for the dataset.""" + task_label_mappings = {} + num_labels_list = [] + for task, column in task_to_column.items(): + unique_values = sorted(set(dataset[column])) + mapping = {label: idx for idx, label in enumerate(unique_values)} + task_label_mappings[task] = mapping + num_labels_list.append(len(unique_values)) + return task_label_mappings, num_labels_list + + +def save_label_mappings(mappings, path): + """Saves label mappings to a pickle file.""" + with open(path, "wb") as f: + pickle.dump(mappings, f) + + +def load_label_mappings(path): + """Loads label mappings from a pickle file.""" + with open(path, "rb") as f: + return pickle.load(f) + + +def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test): + """Transforms the dataset to the required format.""" + transformed_dataset = [] + cell_id_mapping = {} + + for idx, record in enumerate(dataset): transformed_record = { "input_ids": torch.tensor(record["input_ids"], dtype=torch.long), - "cell_id": idx, + "cell_id": idx, # Index-based cell ID } - - # Add labels - if not self.is_test: + + if not is_test: label_dict = { - task: self.task_label_mappings[task][record[column]] - for task, column in self.task_to_column.items() + task: task_label_mappings[task][record[column]] + for task, column in task_to_column.items() } else: - label_dict = {task: -1 for task in self.config["task_names"]} - + label_dict = {task: -1 for task in config["task_names"]} + transformed_record["label"] = label_dict - - return transformed_record + transformed_dataset.append(transformed_record) + cell_id_mapping[idx] = record.get("unique_cell_id", idx) + return transformed_dataset, cell_id_mapping -def get_data_loader(dataset, batch_size, sampler=None, shuffle=True): - """Create a DataLoader with the given dataset and parameters.""" - return DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - shuffle=shuffle if sampler is None else False, - num_workers=0, - pin_memory=True, - collate_fn=DataCollatorForMultitaskCellClassification(), - ) +def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""): + """Main function to load and preprocess data.""" + try: + dataset = load_from_disk(dataset_path) -def prepare_data_loaders(config, include_test=False): - """Prepare data loaders for training, validation, and optionally test.""" - result = {} - - # Process train data - train_dataset = StreamingMultiTaskDataset( - config["train_path"], - config, - dataset_type="train" - ) - result["train_loader"] = get_data_loader(train_dataset, config["batch_size"]) - - # Store the cell ID mapping from the dataset - result["train_cell_mapping"] = {k: v for k, v in train_dataset.cell_id_mapping.items()} - print(f"Collected {len(result['train_cell_mapping'])} cell IDs from training dataset") - - result["num_labels_list"] = train_dataset.num_labels_list - - # Process validation data - val_dataset = StreamingMultiTaskDataset( - config["val_path"], - config, - dataset_type="validation" - ) - result["val_loader"] = get_data_loader(val_dataset, config["batch_size"]) - - # Store the complete cell ID mapping for validation - for idx in range(len(val_dataset)): - _ = val_dataset[idx] - - result["val_cell_mapping"] = {k: v for k, v in val_dataset.cell_id_mapping.items()} - print(f"Collected {len(result['val_cell_mapping'])} cell IDs from validation dataset") - - # Validate label mappings - validate_label_mappings(config) - - # Process test data if requested - if include_test and "test_path" in config: - test_dataset = StreamingMultiTaskDataset( - config["test_path"], - config, - is_test=True, - dataset_type="test" + # Setup task and column mappings + task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))] + task_to_column = dict(zip(task_names, config["task_columns"])) + config["task_names"] = task_names + + label_mappings_path = os.path.join( + config["results_dir"], + f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl" ) - result["test_loader"] = get_data_loader(test_dataset, config["batch_size"]) - - for idx in range(len(test_dataset)): - _ = test_dataset[idx] - - result["test_cell_mapping"] = {k: v for k, v in test_dataset.cell_id_mapping.items()} - print(f"Collected {len(result['test_cell_mapping'])} cell IDs from test dataset") - - return result + + if not is_test: + validate_columns(dataset, task_to_column.values(), dataset_type) + + # Create and save label mappings + task_label_mappings, num_labels_list = create_label_mappings(dataset, task_to_column) + save_label_mappings(task_label_mappings, label_mappings_path) + else: + # Load existing mappings for test data + task_label_mappings = load_label_mappings(label_mappings_path) + num_labels_list = [len(mapping) for mapping in task_label_mappings.values()] + + # Transform dataset + transformed_dataset, cell_id_mapping = transform_dataset( + dataset, task_to_column, task_label_mappings, config, is_test + ) + + return transformed_dataset, cell_id_mapping, num_labels_list + + except KeyError as e: + raise ValueError(f"Configuration error or dataset key missing: {e}") + except Exception as e: + raise RuntimeError(f"Error during data loading or preprocessing: {e}") + + +def preload_and_process_data(config): + """Preloads and preprocesses train and validation datasets.""" + # Process train data and save mappings + train_data = load_and_preprocess_data(config["train_path"], config, dataset_type="train") + + # Process validation data and save mappings + val_data = load_and_preprocess_data(config["val_path"], config, dataset_type="validation") + + # Validate that the mappings match + validate_label_mappings(config) + + return (*train_data[:2], *val_data) # Return train and val data along with mappings def validate_label_mappings(config): """Ensures train and validation label mappings are consistent.""" train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl") val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl") - - with open(train_mappings_path, "rb") as f: - train_mappings = pickle.load(f) - - with open(val_mappings_path, "rb") as f: - val_mappings = pickle.load(f) + train_mappings = load_label_mappings(train_mappings_path) + val_mappings = load_label_mappings(val_mappings_path) for task_name in config["task_names"]: if train_mappings[task_name] != val_mappings[task_name]: @@ -195,43 +131,32 @@ def validate_label_mappings(config): ) -# Legacy functions for backward compatibility -def preload_and_process_data(config): - """Preloads and preprocesses train and validation datasets.""" - data = prepare_data_loaders(config) - - return ( - data["train_loader"].dataset, - data["train_cell_mapping"], - data["val_loader"].dataset, - data["val_cell_mapping"], - data["num_labels_list"] +def get_data_loader(preprocessed_dataset, batch_size): + """Creates a DataLoader with optimal settings.""" + return DataLoader( + preprocessed_dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=DataCollatorForMultitaskCellClassification(), + num_workers=os.cpu_count(), + pin_memory=True, ) def preload_data(config): """Preprocesses train and validation data for trials.""" - data = prepare_data_loaders(config) - return data["train_loader"], data["val_loader"] + train_loader = get_data_loader(*preload_and_process_data(config)[:2], config["batch_size"]) + val_loader = get_data_loader(*preload_and_process_data(config)[2:4], config["batch_size"]) + return train_loader, val_loader def load_and_preprocess_test_data(config): """Loads and preprocesses test data.""" - test_dataset = StreamingMultiTaskDataset( - config["test_path"], - config, - is_test=True, - dataset_type="test" - ) - - return ( - test_dataset, - test_dataset.cell_id_mapping, - test_dataset.num_labels_list - ) + return load_and_preprocess_data(config["test_path"], config, is_test=True) def prepare_test_loader(config): """Prepares DataLoader for test data.""" - data = prepare_data_loaders(config, include_test=True) - return data["test_loader"], data["test_cell_mapping"], data["num_labels_list"] \ No newline at end of file + test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config) + test_loader = get_data_loader(test_dataset, config["batch_size"]) + return test_loader, cell_id_mapping, num_labels_list diff --git a/geneformer/mtl/eval_utils.py b/geneformer/mtl/eval_utils.py index 29f98214362f23e4492814948359af2b145abba8..0a8ea4babe4ab1e48cc56280ee03423075cf7563 100644 --- a/geneformer/mtl/eval_utils.py +++ b/geneformer/mtl/eval_utils.py @@ -1,16 +1,19 @@ -import os -import json -import torch import pandas as pd -from .data import prepare_test_loader +from .imports import * # noqa # isort:skip +from .data import prepare_test_loader # noqa # isort:skip from .model import GeneformerMultiTask + def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config): task_pred_labels = {task_name: [] for task_name in config["task_names"]} task_pred_probs = {task_name: [] for task_name in config["task_names"]} cell_ids = [] + # # Load task label mappings from pickle file + # with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f: + # task_label_mappings = pickle.load(f) + model.eval() with torch.no_grad(): for batch in test_loader: @@ -82,4 +85,4 @@ def load_and_evaluate_test_model(config): best_model.to(device) evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config) - print("Evaluation completed.") \ No newline at end of file + print("Evaluation completed.") diff --git a/geneformer/mtl/imports.py b/geneformer/mtl/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe9e90945a10a3d79cc487fa15431f2915e5683 --- /dev/null +++ b/geneformer/mtl/imports.py @@ -0,0 +1,43 @@ +import functools +import gc +import json +import os +import pickle +import sys +import warnings +from enum import Enum +from itertools import chain +from typing import Dict, List, Optional, Union + +import numpy as np +import optuna +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from datasets import load_from_disk +from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder +from torch.utils.data import DataLoader +from transformers import ( + AdamW, + BatchEncoding, + BertConfig, + BertModel, + DataCollatorForTokenClassification, + SpecialTokensMixin, + get_cosine_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_scheduler, +) +from transformers.utils import logging, to_py_obj + +from .collators import DataCollatorForMultitaskCellClassification + +# local modules +from .data import get_data_loader, preload_and_process_data +from .model import GeneformerMultiTask +from .optuna_utils import create_optuna_study +from .utils import save_model diff --git a/geneformer/mtl/model.py b/geneformer/mtl/model.py index 0b03ff31bf81c458756b9afc7ead315e53c8d63b..393ebfad4f44f98d748845ea1ae81d66139988f5 100644 --- a/geneformer/mtl/model.py +++ b/geneformer/mtl/model.py @@ -118,4 +118,4 @@ class GeneformerMultiTask(nn.Module): f"Error during loss computation for task {task_id}: {e}" ) - return total_loss, logits, losses if labels is not None else logits \ No newline at end of file + return total_loss, logits, losses if labels is not None else logits diff --git a/geneformer/mtl/optuna_utils.py b/geneformer/mtl/optuna_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..47f375e90f4030e15feb7bc1245ffbba3e6a086e --- /dev/null +++ b/geneformer/mtl/optuna_utils.py @@ -0,0 +1,27 @@ +import optuna +from optuna.integration import TensorBoardCallback + + +def save_trial_callback(study, trial, trials_result_path): + with open(trials_result_path, "a") as f: + f.write( + f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n" + ) + + +def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir): + study = optuna.create_study(direction="maximize") + + # init TensorBoard callback + tensorboard_callback = TensorBoardCallback( + dirname=tensorboard_log_dir, metric_name="F1 Macro" + ) + + # callback and TensorBoard callback + callbacks = [ + lambda study, trial: save_trial_callback(study, trial, trials_result_path), + tensorboard_callback, + ] + + study.optimize(objective, n_trials=n_trials, callbacks=callbacks) + return study diff --git a/geneformer/mtl/train.py b/geneformer/mtl/train.py index 02c00c706111b6b4198aca3c1dbb7bfcba4846a5..5dee1fb8baf594fb137dce3802a44cc0118f1558 100644 --- a/geneformer/mtl/train.py +++ b/geneformer/mtl/train.py @@ -1,707 +1,380 @@ import os +import random + +import numpy as np import pandas as pd import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -import optuna -import functools -import time +from .imports import * from .model import GeneformerMultiTask -from .utils import ( - calculate_metrics, - get_layer_freeze_range, - set_seed, - initialize_wandb, - create_model, - setup_optimizer_and_scheduler, - save_model, - save_hyperparameters, - prepare_training_environment, - log_training_step, - log_validation_metrics, - save_validation_predictions, - setup_logging, - setup_distributed_environment, - train_distributed -) - - -class Trainer: - """Trainer class for multi-task learning""" - - def __init__(self, config): - self.config = config - self.device = None - self.model = None - self.optimizer = None - self.scheduler = None - self.writer = None - self.is_distributed = config.get("distributed_training", False) - self.local_rank = config.get("local_rank", 0) - self.is_main_process = not self.is_distributed or self.local_rank == 0 - - def train_epoch(self, train_loader, epoch): - """Train the model for one epoch.""" - epoch_start = time.time() - self.model.train() - - # For distributed training, we need to be aware of the global batch count - if self.is_distributed: - # Get world size for reporting - world_size = dist.get_world_size() - # Calculate total batches across all GPUs - total_batches_global = len(train_loader) * world_size if self.local_rank == 0 else len(train_loader) - else: - world_size = 1 - total_batches_global = len(train_loader) - - progress_bar = None - if self.is_main_process: - # Use the global batch count for progress reporting in distributed mode - progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}", - total=len(train_loader)) - iterator = progress_bar - - # Report distributed training information - if self.is_distributed: - print(f"Distributed training: {world_size} GPUs, {len(train_loader)} batches per GPU, " - f"{total_batches_global} total batches globally") - else: - iterator = train_loader - - batch_times = [] - forward_times = [] - backward_times = [] - optimizer_times = [] - - # Get gradient accumulation steps from config (default to 1 if not specified) - accumulation_steps = self.config.get("gradient_accumulation_steps", 1) - - # Zero gradients at the beginning - self.optimizer.zero_grad() - - # Track loss for the entire epoch - total_loss = 0.0 - num_batches = 0 - accumulated_loss = 0.0 - - for batch_idx, batch in enumerate(iterator): - batch_start = time.time() - - input_ids = batch["input_ids"].to(self.device) - attention_mask = batch["attention_mask"].to(self.device) +from .utils import calculate_task_specific_metrics, get_layer_freeze_range + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def initialize_wandb(config): + if config.get("use_wandb", False): + import wandb + + wandb.init(project=config["wandb_project"], config=config) + print("Weights & Biases (wandb) initialized and will be used for logging.") + else: + print( + "Weights & Biases (wandb) is not enabled. Logging will use other methods." + ) + + +def create_model(config, num_labels_list, device): + model = GeneformerMultiTask( + config["pretrained_path"], + num_labels_list, + dropout_rate=config["dropout_rate"], + use_task_weights=config["use_task_weights"], + task_weights=config["task_weights"], + max_layers_to_freeze=config["max_layers_to_freeze"], + use_attention_pooling=config["use_attention_pooling"], + ) + if config["use_data_parallel"]: + model = nn.DataParallel(model) + return model.to(device) + + +def setup_optimizer_and_scheduler(model, config, total_steps): + optimizer = AdamW( + model.parameters(), + lr=config["learning_rate"], + weight_decay=config["weight_decay"], + ) + warmup_steps = int(config["warmup_ratio"] * total_steps) + + if config["lr_scheduler_type"] == "linear": + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps + ) + elif config["lr_scheduler_type"] == "cosine": + scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=total_steps, + num_cycles=0.5, + ) + + return optimizer, scheduler + + +def train_epoch( + model, train_loader, optimizer, scheduler, device, config, writer, epoch +): + model.train() + progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}") + for batch_idx, batch in enumerate(progress_bar): + optimizer.zero_grad() + input_ids = batch["input_ids"].to(device) + attention_mask = batch["attention_mask"].to(device) + labels = [ + batch["labels"][task_name].to(device) for task_name in config["task_names"] + ] + + loss, _, _ = model(input_ids, attention_mask, labels) + loss.backward() + + if config["gradient_clipping"]: + torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"]) + + optimizer.step() + scheduler.step() + + writer.add_scalar( + "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx + ) + if config.get("use_wandb", False): + import wandb + + wandb.log({"Training Loss": loss.item()}) + + # Update progress bar + progress_bar.set_postfix({"loss": f"{loss.item():.4f}"}) + + return loss.item() # Return the last batch loss + + +def validate_model(model, val_loader, device, config): + model.eval() + val_loss = 0.0 + task_true_labels = {task_name: [] for task_name in config["task_names"]} + task_pred_labels = {task_name: [] for task_name in config["task_names"]} + task_pred_probs = {task_name: [] for task_name in config["task_names"]} + + with torch.no_grad(): + for batch in val_loader: + input_ids = batch["input_ids"].to(device) + attention_mask = batch["attention_mask"].to(device) labels = [ - batch["labels"][task_name].to(self.device) for task_name in self.config["task_names"] + batch["labels"][task_name].to(device) + for task_name in config["task_names"] ] + loss, logits, _ = model(input_ids, attention_mask, labels) + val_loss += loss.item() + + for sample_idx in range(len(batch["input_ids"])): + for i, task_name in enumerate(config["task_names"]): + true_label = batch["labels"][task_name][sample_idx].item() + pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item() + pred_prob = ( + torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy() + ) + task_true_labels[task_name].append(true_label) + task_pred_labels[task_name].append(pred_label) + task_pred_probs[task_name].append(pred_prob) + + val_loss /= len(val_loader) + return val_loss, task_true_labels, task_pred_labels, task_pred_probs + + +def log_metrics(task_metrics, val_loss, config, writer, epochs): + for task_name, metrics in task_metrics.items(): + print( + f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}" + ) + if config.get("use_wandb", False): + import wandb - forward_start = time.time() - loss, _, _ = self.model(input_ids, attention_mask, labels) - - # Scale loss by accumulation steps for gradient accumulation - if accumulation_steps > 1: - loss = loss / accumulation_steps - - forward_end = time.time() - forward_times.append(forward_end - forward_start) - - # Track loss - store the unscaled loss for reporting - unscaled_loss = loss.item() * (1 if accumulation_steps == 1 else accumulation_steps) - total_loss += unscaled_loss - num_batches += 1 - accumulated_loss += loss.item() # For gradient accumulation tracking - - backward_start = time.time() - - # Use no_sync() for all but the last accumulation step to avoid unnecessary communication - if self.is_distributed and accumulation_steps > 1: - # If this is not the last accumulation step or the last batch - if (batch_idx + 1) % accumulation_steps != 0 and (batch_idx + 1) != len(train_loader): - with self.model.no_sync(): - loss.backward() - else: - loss.backward() - else: - # Non-distributed training or accumulation_steps=1 - loss.backward() - - backward_end = time.time() - backward_times.append(backward_end - backward_start) - - # Only update weights after accumulation_steps or at the end of the epoch - if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader): - if self.config["gradient_clipping"]: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"]) - - optimizer_start = time.time() - self.optimizer.step() - self.scheduler.step() - self.optimizer.zero_grad() - optimizer_end = time.time() - optimizer_times.append(optimizer_end - optimizer_start) - - # Log after optimizer step - if self.is_main_process: - # Calculate running average loss - avg_loss = total_loss / num_batches - - log_training_step(avg_loss, self.writer, self.config, epoch, len(train_loader), batch_idx) - - # Update progress bar with just the running average loss - progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"}) - - accumulated_loss = 0.0 - else: - optimizer_times.append(0) # No optimizer step taken - - batch_end = time.time() - batch_times.append(batch_end - batch_start) - - epoch_end = time.time() - - # Calculate average loss for the epoch - epoch_avg_loss = total_loss / num_batches - - # If distributed, gather losses from all processes to compute global average - if self.is_distributed: - # Create a tensor to hold the loss - loss_tensor = torch.tensor([epoch_avg_loss], device=self.device) - # Gather losses from all processes - all_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())] - dist.all_gather(all_losses, loss_tensor) - # Compute the global average loss across all processes - epoch_avg_loss = torch.mean(torch.stack(all_losses)).item() - - if self.is_main_process: - # douhble check if batch_size has already been adjusted for world_size in the config - # This avoids double-counting the effective batch size - per_gpu_batch_size = self.config['batch_size'] - total_effective_batch = per_gpu_batch_size * accumulation_steps * world_size - - print(f"Epoch {epoch+1} timing:") - print(f" Total epoch time: {epoch_end - epoch_start:.2f}s") - print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s") - print(f" Average forward time: {sum(forward_times)/len(forward_times):.4f}s") - print(f" Average backward time: {sum(backward_times)/len(backward_times):.4f}s") - print(f" Average optimizer time: {sum([t for t in optimizer_times if t > 0])/max(1, len([t for t in optimizer_times if t > 0])):.4f}s") - print(f" Gradient accumulation steps: {accumulation_steps}") - print(f" Batch size per GPU: {per_gpu_batch_size}") - print(f" Effective global batch size: {total_effective_batch}") - print(f" Average training loss: {epoch_avg_loss:.4f}") - if self.is_distributed: - print(f" Total batches processed across all GPUs: {total_batches_global}") - print(f" Communication optimization: Using no_sync() for gradient accumulation") - - return epoch_avg_loss # Return the average loss for the epoch - - def validate_model(self, val_loader): - val_start = time.time() - self.model.eval() - val_loss = 0.0 - task_true_labels = {task_name: [] for task_name in self.config["task_names"]} - task_pred_labels = {task_name: [] for task_name in self.config["task_names"]} - task_pred_probs = {task_name: [] for task_name in self.config["task_names"]} - - val_cell_ids = {} - sample_counter = 0 - - batch_times = [] - - # Print validation dataset size - if self.is_main_process: - print(f"Validation dataset size: {len(val_loader.dataset)} samples") - print(f"Number of validation batches: {len(val_loader)}") - - if self.is_distributed: - world_size = dist.get_world_size() - print(f"Distributed validation: {world_size} GPUs") - if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'): - samples_per_gpu = val_loader.sampler.num_samples - print(f"Each GPU processes {samples_per_gpu} validation samples") - print(f"Total validation samples processed: {samples_per_gpu * world_size}") - - with torch.no_grad(): - for batch in val_loader: - batch_start = time.time() - input_ids = batch["input_ids"].to(self.device) - attention_mask = batch["attention_mask"].to(self.device) - labels = [ - batch["labels"][task_name].to(self.device) - for task_name in self.config["task_names"] - ] - loss, logits, _ = self.model(input_ids, attention_mask, labels) - val_loss += loss.item() - - if "cell_id" in batch: - for i, cell_id in enumerate(batch["cell_id"]): - # Store the actual index for later mapping to unique_cell_id - val_cell_ids[sample_counter + i] = cell_id.item() - - for sample_idx in range(len(batch["input_ids"])): - for i, task_name in enumerate(self.config["task_names"]): - true_label = batch["labels"][task_name][sample_idx].item() - pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item() - # Store the full probability distribution - pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy().tolist() - task_true_labels[task_name].append(true_label) - task_pred_labels[task_name].append(pred_label) - task_pred_probs[task_name].append(pred_prob) - - # Update current index for cell ID tracking - sample_counter += len(batch["input_ids"]) - - batch_end = time.time() - batch_times.append(batch_end - batch_start) - - # norm validation loss by the number of batches - val_loss /= len(val_loader) - - # distributed, gather results from all processes - if self.is_distributed: - # Create tensors to hold the local results - loss_tensor = torch.tensor([val_loss], device=self.device) - gathered_losses = [torch.zeros_like(loss_tensor) for _ in range(dist.get_world_size())] - dist.all_gather(gathered_losses, loss_tensor) - - # Compute average loss across all processes - val_loss = torch.mean(torch.cat(gathered_losses)).item() - - world_size = dist.get_world_size() - - if self.is_main_process: - print(f"Collected predictions from rank {self.local_rank}") - print(f"Number of samples processed by this rank: {sample_counter}") - - val_end = time.time() - - if self.is_main_process: - print(f"Validation timing:") - print(f" Total validation time: {val_end - val_start:.2f}s") - print(f" Average batch time: {sum(batch_times)/len(batch_times):.4f}s") - print(f" Collected {len(val_cell_ids)} cell indices from validation data") - print(f" Processed {sample_counter} total samples during validation") - - # Print number of samples per task - for task_name in self.config["task_names"]: - print(f" Task {task_name}: {len(task_true_labels[task_name])} samples") - - return val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids - - def train(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list): - """Train the model and return validation loss and trained model.""" - if self.config.get("use_wandb", False) and self.is_main_process: - initialize_wandb(self.config) - - # Create model - self.model = create_model(self.config, num_labels_list, self.device, self.is_distributed, self.local_rank) - - # Setup optimizer and scheduler - total_steps = len(train_loader) * self.config["epochs"] - self.optimizer, self.scheduler = setup_optimizer_and_scheduler(self.model, self.config, total_steps) - - # Training loop - if self.is_main_process: - epoch_progress = tqdm(range(self.config["epochs"]), desc="Training Progress") - else: - epoch_progress = range(self.config["epochs"]) - - best_val_loss = float('inf') - train_losses = [] - - with setup_logging(self.config) as self.writer: - for epoch in epoch_progress: - if self.is_distributed: - train_loader.sampler.set_epoch(epoch) - - train_loss = self.train_epoch(train_loader, epoch) - train_losses.append(train_loss) - - # Run validation after each epoch if configured to do so - if self.config.get("validate_each_epoch", False): - val_loss, _, _, _, _ = self.validate_model(val_loader) - if val_loss < best_val_loss: - best_val_loss = val_loss - - if self.is_main_process: - epoch_progress.set_postfix({ - "train_loss": f"{train_loss:.4f}", - "val_loss": f"{val_loss:.4f}", - "best_val_loss": f"{best_val_loss:.4f}" - }) - else: - if self.is_main_process: - epoch_progress.set_postfix({ - "train_loss": f"{train_loss:.4f}" - }) - - val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = self.validate_model(val_loader) - task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific") - - if self.is_main_process: - log_validation_metrics(task_metrics, val_loss, self.config, self.writer, self.config["epochs"]) - - # Save validation predictions - save_validation_predictions( - val_cell_ids, - task_true_labels, - task_pred_labels, - task_pred_probs, - {**self.config, "val_cell_mapping": val_cell_id_mapping} # Include the mapping - ) - - if self.config.get("use_wandb", False): - import wandb - wandb.finish() - - print(f"\nTraining Summary:") - print(f" Final Training Loss: {train_losses[-1]:.4f}") - print(f" Final Validation Loss: {val_loss:.4f}") - for task_name, metrics in task_metrics.items(): - print(f" {task_name} - F1: {metrics['f1']:.4f}, Accuracy: {metrics['accuracy']:.4f}") - - return val_loss, self.model # Return both the validation loss and the trained model - - def setup(self, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list): - if self.is_distributed: - self.device = torch.device(f"cuda:{self.local_rank}") - else: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.model = create_model(self.config, num_labels_list, self.device) - - # war model w DDP - if self.is_distributed: - self.model = DDP(self.model, device_ids=[self.local_rank]) - - # communication hook to optimize gradient synchronization - from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks - - # default hook which maintains full precision - self.model.register_comm_hook( - state=None, - hook=comm_hooks.allreduce_hook + wandb.log( + { + f"{task_name} Validation F1 Macro": metrics["f1"], + f"{task_name} Validation Accuracy": metrics["accuracy"], + } ) - - print(f"Rank {self.local_rank}: Registered communication hook for optimized gradient synchronization") - - print(f"Rank {self.local_rank}: Using samplers created in distributed worker") - print(f"Rank {self.local_rank}: Training dataset has {len(train_loader.dataset)} samples") - if hasattr(train_loader, 'sampler') and hasattr(train_loader.sampler, 'num_samples'): - print(f"Rank {self.local_rank}: This GPU will process {train_loader.sampler.num_samples} training samples per epoch") - - if hasattr(val_loader, 'sampler') and hasattr(val_loader.sampler, 'num_samples'): - print(f"Rank {self.local_rank}: This GPU will process {val_loader.sampler.num_samples} validation samples") - - # Set up optimizer and scheduler - self.optimizer, self.scheduler = setup_optimizer_and_scheduler( - self.model, self.config, len(train_loader) + + writer.add_scalar("Validation Loss", val_loss, epochs) + for task_name, metrics in task_metrics.items(): + writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs) + writer.add_scalar( + f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs ) - if self.is_main_process and self.config.get("use_wandb", False): - initialize_wandb(self.config) - - return train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list - - -def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list): - """Train a model with the given configuration and data.""" - # Check if distributed training is enabled - if config.get("distributed_training", False): - # Check if we have multiple GPUs - if torch.cuda.device_count() > 1: - result = train_distributed( - Trainer, - config, - train_loader, - val_loader, - train_cell_id_mapping, - val_cell_id_mapping, - num_labels_list + +def save_validation_predictions( + val_cell_id_mapping, + task_true_labels, + task_pred_labels, + task_pred_probs, + config, + trial_number=None, +): + if trial_number is not None: + trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}") + os.makedirs(trial_results_dir, exist_ok=True) + val_preds_file = os.path.join(trial_results_dir, "val_preds.csv") + else: + val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv") + + rows = [] + for sample_idx in range(len(val_cell_id_mapping)): + row = {"Cell ID": val_cell_id_mapping[sample_idx]} + for task_name in config["task_names"]: + row[f"{task_name} True"] = task_true_labels[task_name][sample_idx] + row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx] + row[f"{task_name} Probabilities"] = ",".join( + map(str, task_pred_probs[task_name][sample_idx]) ) - if result is not None: - return result - else: - print("Distributed training requested but only one GPU found. Falling back to single GPU training.") - config["distributed_training"] = False - - # Non-distributed training - trainer = Trainer(config) - trainer.device = device - return trainer.train(train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list) + rows.append(row) + df = pd.DataFrame(rows) + df.to_csv(val_preds_file, index=False) + print(f"Validation predictions saved to {val_preds_file}") -def objective( - trial, + +def train_model( + config, + device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, - config, - device, ): - """Objective function for Optuna hyperparameter optimization.""" set_seed(config["seed"]) initialize_wandb(config) - trial_config = config.copy() - - # Suggest hyperparameters for this trial - for param_name, param_config in config["hyperparameters"].items(): - if param_name == "lr_scheduler_type": - trial_config[param_name] = trial.suggest_categorical( - param_name, param_config["choices"] - ) - elif param_name == "task_weights" and config["use_task_weights"]: - weights = [ - trial.suggest_float( - f"task_weight_{i}", - param_config["low"], - param_config["high"], - ) - for i in range(len(num_labels_list)) - ] - weight_sum = sum(weights) - trial_config[param_name] = [w / weight_sum for w in weights] - elif "log" in param_config and param_config["log"]: - trial_config[param_name] = trial.suggest_float( - param_name, param_config["low"], param_config["high"], log=True - ) - else: - trial_config[param_name] = trial.suggest_float( - param_name, param_config["low"], param_config["high"] - ) - - # Set appropriate max layers to freeze based on pretrained model - if "max_layers_to_freeze" in trial_config: - freeze_range = get_layer_freeze_range(trial_config["pretrained_path"]) - trial_config["max_layers_to_freeze"] = int(trial.suggest_int( - "max_layers_to_freeze", - freeze_range["min"], - freeze_range["max"] - )) - - trial_config["run_name"] = f"trial_{trial.number}" - - # Handle distributed training for this trial - if trial_config.get("distributed_training", False) and torch.cuda.device_count() > 1: - manager = mp.Manager() - shared_dict = manager.dict() - - train_distributed( - Trainer, - trial_config, - train_loader, - val_loader, - train_cell_id_mapping, - val_cell_id_mapping, - num_labels_list, - trial.number, - shared_dict - ) - - val_loss = shared_dict.get('val_loss', float('inf')) - task_metrics = shared_dict.get('task_metrics', {}) - - trial.set_user_attr("model_state_dict", shared_dict.get('model_state_dict', {})) - trial.set_user_attr("task_weights", trial_config["task_weights"]) - - if config.get("use_wandb", False): - import wandb - wandb.log({ - "trial_number": trial.number, - "val_loss": val_loss, - **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()}, - **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()}, - }) - wandb.finish() - - return val_loss - - with setup_logging(trial_config) as writer: - trainer = Trainer(trial_config) - trainer.device = device - trainer.writer = writer - - # Create model with trial hyperparameters - trainer.model = create_model(trial_config, num_labels_list, device) - total_steps = len(train_loader) * config["epochs"] - trainer.optimizer, trainer.scheduler = setup_optimizer_and_scheduler(trainer.model, trial_config, total_steps) - - # Training loop - for epoch in range(config["epochs"]): - trainer.train_epoch(train_loader, epoch) - - val_loss, task_true_labels, task_pred_labels, task_pred_probs, val_cell_ids = trainer.validate_model(val_loader) - task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific") - - # Log metrics - log_validation_metrics(task_metrics, val_loss, trial_config, writer, config["epochs"]) - - # Save validation predictions - save_validation_predictions( - val_cell_ids, - task_true_labels, - task_pred_labels, - task_pred_probs, - {**trial_config, "val_cell_mapping": val_cell_id_mapping}, - trial.number, + model = create_model(config, num_labels_list, device) + total_steps = len(train_loader) * config["epochs"] + optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps) + + log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run") + writer = SummaryWriter(log_dir=log_dir) + + epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress") + for epoch in epoch_progress: + last_loss = train_epoch( + model, train_loader, optimizer, scheduler, device, config, writer, epoch ) + epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"}) - # Store model state dict and task weights in trial user attributes - trial.set_user_attr("model_state_dict", trainer.model.state_dict()) - trial.set_user_attr("task_weights", trial_config["task_weights"]) + val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model( + model, val_loader, device, config + ) + task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels) - # Report intermediate value to Optuna - trial.report(val_loss, config["epochs"]) - if trial.should_prune(): - raise optuna.TrialPruned() + log_metrics(task_metrics, val_loss, config, writer, config["epochs"]) + writer.close() - if config.get("use_wandb", False): - import wandb - wandb.log( - { - "trial_number": trial.number, - "val_loss": val_loss, - **{f"{task_name}_f1": metrics["f1"] for task_name, metrics in task_metrics.items()}, - **{f"{task_name}_accuracy": metrics["accuracy"] for task_name, metrics in task_metrics.items()}, - **{k: v for k, v in trial_config.items() if k in [ - "learning_rate", "warmup_ratio", "weight_decay", "dropout_rate", - "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze" - ]}, - } - ) - wandb.finish() + save_validation_predictions( + val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config + ) - return val_loss + if config.get("use_wandb", False): + import wandb + wandb.finish() -def run_manual_tuning(config): - """Run training with manually specified hyperparameters.""" - ( - device, - train_loader, - val_loader, - train_cell_id_mapping, - val_cell_id_mapping, - num_labels_list, - ) = prepare_training_environment(config) + print(f"\nFinal Validation Loss: {val_loss:.4f}") + return val_loss, model # Return both the validation loss and the trained model - print("\nManual hyperparameters being used:") - for key, value in config["manual_hyperparameters"].items(): - print(f"{key}: {value}") - print() - # Update config with manual hyperparameters - for key, value in config["manual_hyperparameters"].items(): - config[key] = value +def objective( + trial, + train_loader, + val_loader, + train_cell_id_mapping, + val_cell_id_mapping, + num_labels_list, + config, + device, +): + set_seed(config["seed"]) # Set the seed before each trial + initialize_wandb(config) - # Train the model - val_loss, trained_model = train_model( - config, - device, - train_loader, - val_loader, - train_cell_id_mapping, - val_cell_id_mapping, - num_labels_list, + # Hyperparameters + config["learning_rate"] = trial.suggest_float( + "learning_rate", + config["hyperparameters"]["learning_rate"]["low"], + config["hyperparameters"]["learning_rate"]["high"], + log=config["hyperparameters"]["learning_rate"]["log"], + ) + config["warmup_ratio"] = trial.suggest_float( + "warmup_ratio", + config["hyperparameters"]["warmup_ratio"]["low"], + config["hyperparameters"]["warmup_ratio"]["high"], + ) + config["weight_decay"] = trial.suggest_float( + "weight_decay", + config["hyperparameters"]["weight_decay"]["low"], + config["hyperparameters"]["weight_decay"]["high"], + ) + config["dropout_rate"] = trial.suggest_float( + "dropout_rate", + config["hyperparameters"]["dropout_rate"]["low"], + config["hyperparameters"]["dropout_rate"]["high"], + ) + config["lr_scheduler_type"] = trial.suggest_categorical( + "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"] + ) + config["use_attention_pooling"] = trial.suggest_categorical( + "use_attention_pooling", [False] ) - print(f"\nValidation loss with manual hyperparameters: {val_loss}") - - # Save the trained model - only if not using distributed training - # (distributed training saves the model in the worker) - if not config.get("distributed_training", False): - model_save_directory = os.path.join( - config["model_save_path"], "GeneformerMultiTask" - ) - save_model(trained_model, model_save_directory) - - # Save the hyperparameters - hyperparams_to_save = { - **config["manual_hyperparameters"], - "dropout_rate": config["dropout_rate"], - "use_task_weights": config["use_task_weights"], - "task_weights": config["task_weights"], - "max_layers_to_freeze": config["max_layers_to_freeze"], - "use_attention_pooling": config["use_attention_pooling"], - } - save_hyperparameters(model_save_directory, hyperparams_to_save) + if config["use_task_weights"]: + config["task_weights"] = [ + trial.suggest_float( + f"task_weight_{i}", + config["hyperparameters"]["task_weights"]["low"], + config["hyperparameters"]["task_weights"]["high"], + ) + for i in range(len(num_labels_list)) + ] + weight_sum = sum(config["task_weights"]) + config["task_weights"] = [ + weight / weight_sum for weight in config["task_weights"] + ] + else: + config["task_weights"] = None + + # Dynamic range for max_layers_to_freeze + freeze_range = get_layer_freeze_range(config["pretrained_path"]) + config["max_layers_to_freeze"] = trial.suggest_int( + "max_layers_to_freeze", + freeze_range["min"], + freeze_range["max"] + ) - return val_loss + model = create_model(config, num_labels_list, device) + total_steps = len(train_loader) * config["epochs"] + optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps) + log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}") + writer = SummaryWriter(log_dir=log_dir) -def run_optuna_study(config): - """Run hyperparameter optimization using Optuna.""" - # Prepare training environment - ( - device, - train_loader, - val_loader, - train_cell_id_mapping, - val_cell_id_mapping, - num_labels_list, - ) = prepare_training_environment(config) - - # If manual hyperparameters are specified, use them instead of running Optuna - if config.get("use_manual_hyperparameters", False): - return run_manual_tuning(config) - - # Create a partial function with fixed arguments for the objective - objective_with_config_and_data = functools.partial( - objective, - train_loader=train_loader, - val_loader=val_loader, - train_cell_id_mapping=train_cell_id_mapping, - val_cell_id_mapping=val_cell_id_mapping, - num_labels_list=num_labels_list, - config=config, - device=device, - ) + for epoch in range(config["epochs"]): + train_epoch( + model, train_loader, optimizer, scheduler, device, config, writer, epoch + ) - # Create and run the Optuna study - study = optuna.create_study( - direction="minimize", # Minimize validation loss - study_name=config["study_name"], - # storage=config["storage"], - load_if_exists=True, + val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model( + model, val_loader, device, config ) + task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels) - study.optimize(objective_with_config_and_data, n_trials=config["n_trials"]) + log_metrics(task_metrics, val_loss, config, writer, config["epochs"]) + writer.close() - # After finding the best trial - best_params = study.best_trial.params - best_task_weights = study.best_trial.user_attrs["task_weights"] - print("Saving the best model and its hyperparameters...") - - # Create a model with the best hyperparameters - best_model = GeneformerMultiTask( - config["pretrained_path"], - num_labels_list, - dropout_rate=best_params["dropout_rate"], - use_task_weights=config["use_task_weights"], - task_weights=best_task_weights, - max_layers_to_freeze=best_params.get("max_layers_to_freeze", 0), - use_attention_pooling=best_params.get("use_attention_pooling", False), + save_validation_predictions( + val_cell_id_mapping, + task_true_labels, + task_pred_labels, + task_pred_probs, + config, + trial.number, ) - # Get the best model state dictionary - best_model_state_dict = study.best_trial.user_attrs["model_state_dict"] + trial.set_user_attr("model_state_dict", model.state_dict()) + trial.set_user_attr("task_weights", config["task_weights"]) - best_model_state_dict = { - k.replace("module.", ""): v for k, v in best_model_state_dict.items() - } + trial.report(val_loss, config["epochs"]) - best_model.load_state_dict(best_model_state_dict, strict=False) + if trial.should_prune(): + raise optuna.TrialPruned() - model_save_directory = os.path.join( - config["model_save_path"], "GeneformerMultiTask" - ) - save_model(best_model, model_save_directory) + if config.get("use_wandb", False): + import wandb - save_hyperparameters(model_save_directory, {**best_params, "task_weights": best_task_weights}) + wandb.log( + { + "trial_number": trial.number, + "val_loss": val_loss, + **{ + f"{task_name}_f1": metrics["f1"] + for task_name, metrics in task_metrics.items() + }, + **{ + f"{task_name}_accuracy": metrics["accuracy"] + for task_name, metrics in task_metrics.items() + }, + **{ + k: v + for k, v in config.items() + if k + in [ + "learning_rate", + "warmup_ratio", + "weight_decay", + "dropout_rate", + "lr_scheduler_type", + "use_attention_pooling", + "max_layers_to_freeze", + ] + }, + } + ) + wandb.finish() - return study.best_trial.value # Return the best validation loss + return val_loss diff --git a/geneformer/mtl/train_utils.py b/geneformer/mtl/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..430994a37a53dcde99666a7b5a4d99532e9bc8ba --- /dev/null +++ b/geneformer/mtl/train_utils.py @@ -0,0 +1,161 @@ +import random + +from .data import get_data_loader, preload_and_process_data +from .imports import * +from .model import GeneformerMultiTask +from .train import objective, train_model +from .utils import save_model + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def run_manual_tuning(config): + # Set seed for reproducibility + set_seed(config["seed"]) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ( + train_dataset, + train_cell_id_mapping, + val_dataset, + val_cell_id_mapping, + num_labels_list, + ) = preload_and_process_data(config) + train_loader = get_data_loader(train_dataset, config["batch_size"]) + val_loader = get_data_loader(val_dataset, config["batch_size"]) + + # Print the manual hyperparameters being used + print("\nManual hyperparameters being used:") + for key, value in config["manual_hyperparameters"].items(): + print(f"{key}: {value}") + print() # Add an empty line for better readability + + # Use the manual hyperparameters + for key, value in config["manual_hyperparameters"].items(): + config[key] = value + + # Train the model + val_loss, trained_model = train_model( + config, + device, + train_loader, + val_loader, + train_cell_id_mapping, + val_cell_id_mapping, + num_labels_list, + ) + + print(f"\nValidation loss with manual hyperparameters: {val_loss}") + + # Save the trained model + model_save_directory = os.path.join( + config["model_save_path"], "GeneformerMultiTask" + ) + save_model(trained_model, model_save_directory) + + # Save the hyperparameters + hyperparams_to_save = { + **config["manual_hyperparameters"], + "dropout_rate": config["dropout_rate"], + "use_task_weights": config["use_task_weights"], + "task_weights": config["task_weights"], + "max_layers_to_freeze": config["max_layers_to_freeze"], + "use_attention_pooling": config["use_attention_pooling"], + } + hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json") + with open(hyperparams_path, "w") as f: + json.dump(hyperparams_to_save, f) + print(f"Manual hyperparameters saved to {hyperparams_path}") + + return val_loss + + +def run_optuna_study(config): + # Set seed for reproducibility + set_seed(config["seed"]) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ( + train_dataset, + train_cell_id_mapping, + val_dataset, + val_cell_id_mapping, + num_labels_list, + ) = preload_and_process_data(config) + train_loader = get_data_loader(train_dataset, config["batch_size"]) + val_loader = get_data_loader(val_dataset, config["batch_size"]) + + if config["use_manual_hyperparameters"]: + train_model( + config, + device, + train_loader, + val_loader, + train_cell_id_mapping, + val_cell_id_mapping, + num_labels_list, + ) + else: + objective_with_config_and_data = functools.partial( + objective, + train_loader=train_loader, + val_loader=val_loader, + train_cell_id_mapping=train_cell_id_mapping, + val_cell_id_mapping=val_cell_id_mapping, + num_labels_list=num_labels_list, + config=config, + device=device, + ) + + study = optuna.create_study( + direction="minimize", # Minimize validation loss + study_name=config["study_name"], + # storage=config["storage"], + load_if_exists=True, + ) + + study.optimize(objective_with_config_and_data, n_trials=config["n_trials"]) + + # After finding the best trial + best_params = study.best_trial.params + best_task_weights = study.best_trial.user_attrs["task_weights"] + print("Saving the best model and its hyperparameters...") + + # Saving model as before + best_model = GeneformerMultiTask( + config["pretrained_path"], + num_labels_list, + dropout_rate=best_params["dropout_rate"], + use_task_weights=config["use_task_weights"], + task_weights=best_task_weights, + ) + + # Get the best model state dictionary + best_model_state_dict = study.best_trial.user_attrs["model_state_dict"] + + # Remove the "module." prefix from the state dictionary keys if present + best_model_state_dict = { + k.replace("module.", ""): v for k, v in best_model_state_dict.items() + } + + # Load the modified state dictionary into the model, skipping unexpected keys + best_model.load_state_dict(best_model_state_dict, strict=False) + + model_save_directory = os.path.join( + config["model_save_path"], "GeneformerMultiTask" + ) + save_model(best_model, model_save_directory) + + # Additionally, save the best hyperparameters and task weights + hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json") + + with open(hyperparams_path, "w") as f: + json.dump({**best_params, "task_weights": best_task_weights}, f) + print(f"Best hyperparameters and task weights saved to {hyperparams_path}") diff --git a/geneformer/mtl/utils.py b/geneformer/mtl/utils.py index 1a59a63b4a4a9d2642ed5d504bd1a38b054d10be..5de5079ffdefb853a183038a6b3956de42f19978 100644 --- a/geneformer/mtl/utils.py +++ b/geneformer/mtl/utils.py @@ -1,641 +1,129 @@ -from typing import Dict, List, Optional, Union -import json import os -import pickle -import random -import torch -import numpy as np -import optuna +import shutil + from sklearn.metrics import accuracy_score, f1_score from sklearn.preprocessing import LabelEncoder -from torch.utils.tensorboard import SummaryWriter -from transformers import AutoConfig, BertConfig, BertModel, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup -from torch.optim import AdamW -import pandas as pd -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP -import torch.multiprocessing as mp -from contextlib import contextmanager - - -def set_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -def initialize_wandb(config): - if config.get("use_wandb", False): - import wandb - wandb.init( - project=config.get("wandb_project", "geneformer_multitask"), - name=config.get("run_name", "experiment"), - config=config, - reinit=True, - ) +from transformers import AutoConfig, BertConfig, BertModel - -def create_model(config, num_labels_list, device, is_distributed=False, local_rank=0): - """Create and initialize the model based on configuration.""" - from .model import GeneformerMultiTask - - model = GeneformerMultiTask( - config["pretrained_path"], - num_labels_list, - dropout_rate=config.get("dropout_rate", 0.1), - use_task_weights=config.get("use_task_weights", False), - task_weights=config.get("task_weights", None), - max_layers_to_freeze=config.get("max_layers_to_freeze", 0), - use_attention_pooling=config.get("use_attention_pooling", False), - ) - - # Move model to device - model.to(device) - - if is_distributed: - model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) - - return model - - -def setup_optimizer_and_scheduler(model, config, total_steps): - """Set up optimizer and learning rate scheduler.""" - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() - if not any(nd in n for nd in no_decay) and p.requires_grad], - "weight_decay": config["weight_decay"], - }, - { - "params": [p for n, p in model.named_parameters() - if any(nd in n for nd in no_decay) and p.requires_grad], - "weight_decay": 0.0, - }, - ] - - optimizer = AdamW( - optimizer_grouped_parameters, - lr=config["learning_rate"], - eps=config.get("adam_epsilon", 1e-8) - ) - - # Prepare scheduler - warmup_steps = int(total_steps * config["warmup_ratio"]) - - scheduler_map = { - "linear": get_linear_schedule_with_warmup, - "cosine": get_cosine_schedule_with_warmup - } - - scheduler_fn = scheduler_map.get(config["lr_scheduler_type"]) - if not scheduler_fn: - raise ValueError(f"Unsupported scheduler type: {config['lr_scheduler_type']}") - - scheduler = scheduler_fn(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps) - - return optimizer, scheduler +from .imports import * def save_model(model, model_save_directory): - """Save model weights and configuration.""" - os.makedirs(model_save_directory, exist_ok=True) - - # Handle DDP model - if isinstance(model, DDP): - model_to_save = model.module + if not os.path.exists(model_save_directory): + os.makedirs(model_save_directory) + + # Get the state dict + if isinstance(model, nn.DataParallel): + model_state_dict = ( + model.module.state_dict() + ) # Use model.module to access the underlying model else: - model_to_save = model - - model_state_dict = model_to_save.state_dict() + model_state_dict = model.state_dict() + + # Remove the "module." prefix from the keys if present + model_state_dict = { + k.replace("module.", ""): v for k, v in model_state_dict.items() + } model_save_path = os.path.join(model_save_directory, "pytorch_model.bin") torch.save(model_state_dict, model_save_path) # Save the model configuration - model_to_save.config.to_json_file(os.path.join(model_save_directory, "config.json")) - - print(f"Model and configuration saved to {model_save_directory}") - - -def save_hyperparameters(model_save_directory, hyperparams): - """Save hyperparameters to a JSON file.""" - hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json") - with open(hyperparams_path, "w") as f: - json.dump(hyperparams, f) - print(f"Hyperparameters saved to {hyperparams_path}") - - -def calculate_metrics(labels=None, preds=None, task_data=None, metric_type="task_specific", return_format="dict"): - if metric_type == "single": - # Calculate metrics for a single task - if labels is None or preds is None: - raise ValueError("Labels and predictions must be provided for single task metrics") - - task_name = None - if isinstance(labels, dict) and len(labels) == 1: - task_name = list(labels.keys())[0] - labels = labels[task_name] - preds = preds[task_name] - - f1 = f1_score(labels, preds, average="macro") - accuracy = accuracy_score(labels, preds) - - if return_format == "tuple": - return f1, accuracy - - result = {"f1": f1, "accuracy": accuracy} - if task_name: - return {task_name: result} - return result - - elif metric_type == "task_specific": - # Calculate metrics for multiple tasks - if task_data: - result = {} - for task_name, (task_labels, task_preds) in task_data.items(): - f1 = f1_score(task_labels, task_preds, average="macro") - accuracy = accuracy_score(task_labels, task_preds) - result[task_name] = {"f1": f1, "accuracy": accuracy} - return result - elif isinstance(labels, dict) and isinstance(preds, dict): - result = {} - for task_name in labels: - if task_name in preds: - f1 = f1_score(labels[task_name], preds[task_name], average="macro") - accuracy = accuracy_score(labels[task_name], preds[task_name]) - result[task_name] = {"f1": f1, "accuracy": accuracy} - return result - else: - raise ValueError("For task_specific metrics, either task_data or labels and preds dictionaries must be provided") - - elif metric_type == "combined": - # Calculate combined metrics across all tasks - if labels is None or preds is None: - raise ValueError("Labels and predictions must be provided for combined metrics") - - # Handle label encoding for non-numeric labels - if not all(isinstance(x, (int, float)) for x in labels + preds): - le = LabelEncoder() - le.fit(labels + preds) - labels = le.transform(labels) - preds = le.transform(preds) - - f1 = f1_score(labels, preds, average="macro") - accuracy = accuracy_score(labels, preds) - - if return_format == "tuple": - return f1, accuracy - return {"f1": f1, "accuracy": accuracy} - - else: - raise ValueError(f"Unknown metric_type: {metric_type}") - - -def get_layer_freeze_range(pretrained_path): - if not pretrained_path: - return {"min": 0, "max": 0} - - config = AutoConfig.from_pretrained(pretrained_path) - total_layers = config.num_hidden_layers - return {"min": 0, "max": total_layers - 1} - - -def prepare_training_environment(config): - """ - Prepare the training environment by setting seed and loading data. - - Returns: - tuple: (device, train_loader, val_loader, train_cell_id_mapping, - val_cell_id_mapping, num_labels_list) - """ - from .data import prepare_data_loaders - - # Set seed for reproducibility - set_seed(config["seed"]) - - # Set up device - for non-distributed training - if not config.get("distributed_training", False): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if isinstance(model, nn.DataParallel): + model.module.config.to_json_file( + os.path.join(model_save_directory, "config.json") + ) else: - # For distributed training, device will be set per process - device = None - - # Load data using the streaming dataset - data = prepare_data_loaders(config) - - # For distributed training, we'll set up samplers later in the distributed worker - # Don't create DistributedSampler here as process group isn't initialized yet - - return ( - device, - data["train_loader"], - data["val_loader"], - data["train_cell_mapping"], - data["val_cell_mapping"], - data["num_labels_list"], - ) + model.config.to_json_file(os.path.join(model_save_directory, "config.json")) + print(f"Model and configuration saved to {model_save_directory}") -# Optuna hyperparameter optimization utilities -def save_trial_callback(study, trial, trials_result_path): - """ - Callback to save Optuna trial results to a file. - - Args: - study: Optuna study object - trial: Current trial object - trials_result_path: Path to save trial results - """ - with open(trials_result_path, "a") as f: - f.write( - f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n" - ) +def calculate_task_specific_metrics(task_true_labels, task_pred_labels): + task_metrics = {} + for task_name in task_true_labels.keys(): + true_labels = task_true_labels[task_name] + pred_labels = task_pred_labels[task_name] + f1 = f1_score(true_labels, pred_labels, average="macro") + accuracy = accuracy_score(true_labels, pred_labels) + task_metrics[task_name] = {"f1": f1, "accuracy": accuracy} + return task_metrics -def create_optuna_study(objective, n_trials: int, trials_result_path: str, tensorboard_log_dir: str) -> optuna.Study: - """Create and run an Optuna study with TensorBoard logging.""" - from optuna.integration import TensorBoardCallback - - study = optuna.create_study(direction="maximize") - study.optimize( - objective, - n_trials=n_trials, - callbacks=[ - lambda study, trial: save_trial_callback(study, trial, trials_result_path), - TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro") - ] - ) - return study +def calculate_combined_f1(combined_labels, combined_preds): + # Initialize the LabelEncoder + le = LabelEncoder() -@contextmanager -def setup_logging(config): - run_name = config.get("run_name", "manual_run") - log_dir = os.path.join(config["tensorboard_log_dir"], run_name) - writer = SummaryWriter(log_dir=log_dir) - try: - yield writer - finally: - writer.close() + # Fit and transform combined labels and predictions to numerical values + le.fit(combined_labels + combined_preds) + encoded_true_labels = le.transform(combined_labels) + encoded_pred_labels = le.transform(combined_preds) + # Print out the mapping for sanity check + print("\nLabel Encoder Mapping:") + for index, class_label in enumerate(le.classes_): + print(f"'{class_label}': {index}") -def log_training_step(loss, writer, config, epoch, steps_per_epoch, batch_idx): - """Log training step metrics to TensorBoard and optionally W&B.""" - writer.add_scalar( - "Training Loss", loss, epoch * steps_per_epoch + batch_idx - ) - if config.get("use_wandb", False): - import wandb - wandb.log({"Training Loss": loss}) + # Calculate accuracy + accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels) + # Calculate F1 Macro score + f1 = f1_score(encoded_true_labels, encoded_pred_labels, average="macro") -def log_validation_metrics(task_metrics, val_loss, config, writer, epoch): - """Log validation metrics to console, TensorBoard, and optionally W&B.""" - for task_name, metrics in task_metrics.items(): - print( - f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}" - ) - if config.get("use_wandb", False): - import wandb - wandb.log( - { - f"{task_name} Validation F1 Macro": metrics["f1"], - f"{task_name} Validation Accuracy": metrics["accuracy"], - } - ) + return f1, accuracy - writer.add_scalar("Validation Loss", val_loss, epoch) - for task_name, metrics in task_metrics.items(): - writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epoch) - writer.add_scalar( - f"{task_name} - Validation Accuracy", metrics["accuracy"], epoch - ) +# def save_model_without_heads(original_model_save_directory): +# # Create a new directory for the model without heads +# new_model_save_directory = original_model_save_directory + "_No_Heads" +# if not os.path.exists(new_model_save_directory): +# os.makedirs(new_model_save_directory) -def load_label_mappings(results_dir: str, task_names: List[str]) -> Dict[str, Dict]: - """Load or initialize task label mappings.""" - label_mappings_path = os.path.join(results_dir, "task_label_mappings_val.pkl") - if os.path.exists(label_mappings_path): - with open(label_mappings_path, 'rb') as f: - return pickle.load(f) - return {task_name: {} for task_name in task_names} +# # Load the model state dictionary +# model_state_dict = torch.load( +# os.path.join(original_model_save_directory, "pytorch_model.bin") +# ) +# # Initialize a new BERT model without the classification heads +# config = BertConfig.from_pretrained( +# os.path.join(original_model_save_directory, "config.json") +# ) +# model_without_heads = BertModel(config) -def create_prediction_row(sample_idx: int, val_cell_indices: Dict, task_true_labels: Dict, - task_pred_labels: Dict, task_pred_probs: Dict, task_names: List[str], - inverted_mappings: Dict, val_cell_mapping: Dict) -> Dict: - """Create a row for validation predictions.""" - batch_cell_idx = val_cell_indices.get(sample_idx) - cell_id = val_cell_mapping.get(batch_cell_idx, f"unknown_cell_{sample_idx}") if batch_cell_idx is not None else f"unknown_cell_{sample_idx}" - - row = {"Cell ID": cell_id} - for task_name in task_names: - if task_name in task_true_labels and sample_idx < len(task_true_labels[task_name]): - true_idx = task_true_labels[task_name][sample_idx] - pred_idx = task_pred_labels[task_name][sample_idx] - true_label = inverted_mappings.get(task_name, {}).get(true_idx, f"Unknown-{true_idx}") - pred_label = inverted_mappings.get(task_name, {}).get(pred_idx, f"Unknown-{pred_idx}") - - row.update({ - f"{task_name}_true_idx": true_idx, - f"{task_name}_pred_idx": pred_idx, - f"{task_name}_true_label": true_label, - f"{task_name}_pred_label": pred_label - }) - - if task_name in task_pred_probs and sample_idx < len(task_pred_probs[task_name]): - probs = task_pred_probs[task_name][sample_idx] - if isinstance(probs, (list, np.ndarray)) or (hasattr(probs, '__iter__') and not isinstance(probs, str)): - prob_list = list(probs) if not isinstance(probs, list) else probs - row[f"{task_name}_all_probs"] = ",".join(map(str, prob_list)) - for class_idx, prob in enumerate(prob_list): - class_label = inverted_mappings.get(task_name, {}).get(class_idx, f"Unknown-{class_idx}") - row[f"{task_name}_prob_{class_label}"] = prob - else: - row[f"{task_name}_all_probs"] = str(probs) - - return row +# # Filter the state dict to exclude classification heads +# model_without_heads_state_dict = { +# k: v +# for k, v in model_state_dict.items() +# if not k.startswith("classification_heads") +# } +# # Load the filtered state dict into the model +# model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False) -def save_validation_predictions( - val_cell_indices, - task_true_labels, - task_pred_labels, - task_pred_probs, - config, - trial_number=None, -): - """Save validation predictions to a CSV file with class labels and probabilities.""" - os.makedirs(config["results_dir"], exist_ok=True) - - if trial_number is not None: - os.makedirs(os.path.join(config["results_dir"], f"trial_{trial_number}"), exist_ok=True) - val_preds_file = os.path.join(config["results_dir"], f"trial_{trial_number}/val_preds.csv") - else: - val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv") - - if not val_cell_indices or not task_true_labels: - pd.DataFrame().to_csv(val_preds_file, index=False) - return - - try: - label_mappings = load_label_mappings(config["results_dir"], config["task_names"]) - inverted_mappings = {task: {idx: label for label, idx in mapping.items()} for task, mapping in label_mappings.items()} - val_cell_mapping = config.get("val_cell_mapping", {}) - - # Determine maximum number of samples - max_samples = max( - [len(val_cell_indices)] + - [len(task_true_labels[task]) for task in task_true_labels] - ) - - rows = [ - create_prediction_row( - sample_idx, val_cell_indices, task_true_labels, task_pred_labels, - task_pred_probs, config["task_names"], inverted_mappings, val_cell_mapping - ) - for sample_idx in range(max_samples) - ] - - pd.DataFrame(rows).to_csv(val_preds_file, index=False) - except Exception as e: - pd.DataFrame([{"Error": str(e)}]).to_csv(val_preds_file, index=False) +# # Save the model without heads +# model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin") +# torch.save(model_without_heads.state_dict(), model_save_path) +# # Copy the configuration file +# shutil.copy( +# os.path.join(original_model_save_directory, "config.json"), +# new_model_save_directory, +# ) -def setup_distributed_environment(rank, world_size, config): - """ - Setup the distributed training environment. - - Args: - rank (int): The rank of the current process - world_size (int): Total number of processes - config (dict): Configuration dictionary - """ - os.environ['MASTER_ADDR'] = config.get('master_addr', 'localhost') - os.environ['MASTER_PORT'] = config.get('master_port', '12355') - - # Initialize the process group - dist.init_process_group( - backend='nccl', - init_method='env://', - world_size=world_size, - rank=rank - ) - - # Set the device for this process - torch.cuda.set_device(rank) +# print(f"Model without classification heads saved to {new_model_save_directory}") -def train_distributed(trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None): - """Run distributed training across multiple GPUs with fallback to single GPU.""" - world_size = torch.cuda.device_count() - - if world_size <= 1: - print("Distributed training requested but only one GPU found. Falling back to single GPU training.") - config["distributed_training"] = False - trainer = trainer_class(config) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - trainer.device = device - train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup( - train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list - ) - val_loss, model = trainer.train( - train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list - ) - model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask") - save_model(model, model_save_directory) - save_hyperparameters(model_save_directory, { - **get_config_value(config, "manual_hyperparameters", {}), - "dropout_rate": config["dropout_rate"], - "use_task_weights": config["use_task_weights"], - "task_weights": config["task_weights"], - "max_layers_to_freeze": config["max_layers_to_freeze"], - "use_attention_pooling": config["use_attention_pooling"], - }) - - if shared_dict is not None: - shared_dict['val_loss'] = val_loss - task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions(model, val_loader, device, config) - shared_dict['task_metrics'] = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific") - shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model.state_dict().items()} - - return val_loss, model - - print(f"Using distributed training with {world_size} GPUs") - mp.spawn( - _distributed_worker, - args=(world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number, shared_dict), - nprocs=world_size, - join=True - ) - - if trial_number is None and shared_dict is None: - model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask") - model_path = os.path.join(model_save_directory, "pytorch_model.bin") - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - model = create_model(config, num_labels_list, device) - model.load_state_dict(torch.load(model_path)) - return 0.0, model - - return None - - -def _distributed_worker(rank, world_size, trainer_class, config, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, trial_number=None, shared_dict=None): - """Worker function for distributed training.""" - setup_distributed_environment(rank, world_size, config) - config["local_rank"] = rank - - # Set up distributed samplers - from torch.utils.data import DistributedSampler - from .data import get_data_loader - - train_sampler = DistributedSampler(train_loader.dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False) - val_sampler = DistributedSampler(val_loader.dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False) - - train_loader = get_data_loader(train_loader.dataset, config["batch_size"], sampler=train_sampler, shuffle=False) - val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=val_sampler, shuffle=False) - - if rank == 0: - print(f"Rank {rank}: Training {len(train_sampler)} samples, Validation {len(val_sampler)} samples") - print(f"Total samples across {world_size} GPUs: Training {len(train_sampler) * world_size}, Validation {len(val_sampler) * world_size}") - - # Create and setup trainer - trainer = trainer_class(config) - train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list = trainer.setup( - train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list - ) - - # Train the model - val_loss, model = trainer.train( - train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list - ) - - # Save model only from the main process - if rank == 0: - model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask") - save_model(model, model_save_directory) - - save_hyperparameters(model_save_directory, { - **get_config_value(config, "manual_hyperparameters", {}), - "dropout_rate": config["dropout_rate"], - "use_task_weights": config["use_task_weights"], - "task_weights": config["task_weights"], - "max_layers_to_freeze": config["max_layers_to_freeze"], - "use_attention_pooling": config["use_attention_pooling"], - }) - - # For Optuna trials, store results in shared dictionary - if shared_dict is not None: - shared_dict['val_loss'] = val_loss - - # Run validation on full dataset from rank 0 for consistent metrics - full_val_loader = get_data_loader(val_loader.dataset, config["batch_size"], sampler=None, shuffle=False) - - # Get validation predictions using our utility function - task_true_labels, task_pred_labels, task_pred_probs = collect_validation_predictions( - model, full_val_loader, trainer.device, config - ) - - # Calculate metrics - task_metrics = calculate_metrics(labels=task_true_labels, preds=task_pred_labels, metric_type="task_specific") - shared_dict['task_metrics'] = task_metrics - - # Store model state dict - if isinstance(model, DDP): - model_state_dict = model.module.state_dict() - else: - model_state_dict = model.state_dict() - - shared_dict['model_state_dict'] = {k: v.cpu() for k, v in model_state_dict.items()} - - # Clean up distributed environment - dist.destroy_process_group() - - -def save_model_without_heads(model_directory): +def get_layer_freeze_range(pretrained_path): """ - Save a version of the fine-tuned model without classification heads. - + Dynamically determines the number of layers to freeze based on the model depth from its configuration. Args: - model_directory (str): Path to the directory containing the fine-tuned model + pretrained_path (str): Path to the pretrained model directory or model identifier. + Returns: + dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze. """ - import torch - from transformers import BertConfig, BertModel - - # Load the full model - model_path = os.path.join(model_directory, "pytorch_model.bin") - config_path = os.path.join(model_directory, "config.json") - - if not os.path.exists(model_path) or not os.path.exists(config_path): - raise FileNotFoundError(f"Model files not found in {model_directory}") - - # Load the configuration - config = BertConfig.from_json_file(config_path) - - # Load the model state dict - state_dict = torch.load(model_path, map_location=torch.device('cpu')) - - # Create a new model without heads - base_model = BertModel(config) - - # Filter out the classification head parameters - base_model_state_dict = {} - for key, value in state_dict.items(): - # Only keep parameters that belong to the base model (not classification heads) - if not key.startswith('classification_heads') and not key.startswith('attention_pool'): - base_model_state_dict[key] = value - - # Load the filtered state dict into the base model - base_model.load_state_dict(base_model_state_dict, strict=False) - - # Save the model without heads - output_dir = os.path.join(model_directory, "model_without_heads") - os.makedirs(output_dir, exist_ok=True) - - # Save the model weights - torch.save(base_model.state_dict(), os.path.join(output_dir, "pytorch_model.bin")) - - # Save the configuration - base_model.config.to_json_file(os.path.join(output_dir, "config.json")) - - print(f"Model without classification heads saved to {output_dir}") - return output_dir - - -def get_config_value(config: Dict, key: str, default=None): - - return config.get(key, default) - - -def collect_validation_predictions(model, val_loader, device, config) -> tuple: - task_true_labels = {} - task_pred_labels = {} - task_pred_probs = {} - - with torch.no_grad(): - for batch in val_loader: - input_ids = batch["input_ids"].to(device) - attention_mask = batch["attention_mask"].to(device) - labels = [batch["labels"][task_name].to(device) for task_name in config["task_names"]] - _, logits, _ = model(input_ids, attention_mask, labels) - - for sample_idx in range(len(batch["input_ids"])): - for i, task_name in enumerate(config["task_names"]): - if task_name not in task_true_labels: - task_true_labels[task_name] = [] - task_pred_labels[task_name] = [] - task_pred_probs[task_name] = [] - - true_label = batch["labels"][task_name][sample_idx].item() - pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item() - pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy() - - task_true_labels[task_name].append(true_label) - task_pred_labels[task_name].append(pred_label) - task_pred_probs[task_name].append(pred_prob) - - return task_true_labels, task_pred_labels, task_pred_probs + if pretrained_path: + config = AutoConfig.from_pretrained(pretrained_path) + total_layers = config.num_hidden_layers + return {"min": 0, "max": total_layers - 1} + else: + return {"min": 0, "max": 0} diff --git a/geneformer/mtl_classifier.py b/geneformer/mtl_classifier.py index d6f784c3b68d6e58a09a4ca445f62e7833705be5..68ee837a416e27d9e20156100e30718dec6778d0 100644 --- a/geneformer/mtl_classifier.py +++ b/geneformer/mtl_classifier.py @@ -29,7 +29,7 @@ Geneformer multi-task cell classifier. import logging import os -from .mtl import eval_utils, utils, train +from .mtl import eval_utils, train_utils, utils logger = logging.getLogger(__name__) @@ -49,9 +49,7 @@ class MTLClassifier: "max_layers_to_freeze": {None, dict}, "epochs": {None, int}, "tensorboard_log_dir": {None, str}, - "distributed_training": {None, bool}, - "master_addr": {None, str}, - "master_port": {None, str}, + "use_data_parallel": {None, bool}, "use_attention_pooling": {None, bool}, "use_task_weights": {None, bool}, "hyperparameters": {None, dict}, @@ -63,7 +61,6 @@ class MTLClassifier: "max_grad_norm": {None, int, float}, "seed": {None, int}, "trials_result_path": {None, str}, - "gradient_accumulation_steps": {None, int}, } def __init__( @@ -82,9 +79,7 @@ class MTLClassifier: max_layers_to_freeze=None, epochs=1, tensorboard_log_dir="/results/tblogdir", - distributed_training=False, - master_addr="localhost", - master_port="12355", + use_data_parallel=False, use_attention_pooling=True, use_task_weights=True, hyperparameters=None, # Default is None @@ -94,7 +89,6 @@ class MTLClassifier: wandb_project=None, gradient_clipping=False, max_grad_norm=None, - gradient_accumulation_steps=1, # Add this line with default value 1 seed=42, # Default seed value ): """ @@ -123,12 +117,8 @@ class MTLClassifier: | Path to directory to save results tensorboard_log_dir : None, str | Path to directory for Tensorboard logging results - distributed_training : None, bool - | Whether to use distributed data parallel training across multiple GPUs - master_addr : None, str - | Master address for distributed training (default: localhost) - master_port : None, str - | Master port for distributed training (default: 12355) + use_data_parallel : None, bool + | Whether to use data parallelization use_attention_pooling : None, bool | Whether to use attention pooling use_task_weights : None, bool @@ -160,8 +150,6 @@ class MTLClassifier: | Whether to use gradient clipping max_grad_norm : None, int, float | Maximum norm for gradient clipping - gradient_accumulation_steps : None, int - | Number of steps to accumulate gradients before performing a backward/update pass seed : None, int | Random seed """ @@ -177,7 +165,6 @@ class MTLClassifier: self.batch_size = batch_size self.n_trials = n_trials self.study_name = study_name - self.gradient_accumulation_steps = gradient_accumulation_steps if max_layers_to_freeze is None: # Dynamically determine the range of layers to freeze @@ -188,9 +175,7 @@ class MTLClassifier: self.epochs = epochs self.tensorboard_log_dir = tensorboard_log_dir - self.distributed_training = distributed_training - self.master_addr = master_addr - self.master_port = master_port + self.use_data_parallel = use_data_parallel self.use_attention_pooling = use_attention_pooling self.use_task_weights = use_task_weights self.hyperparameters = ( @@ -308,7 +293,7 @@ class MTLClassifier: self.config["manual_hyperparameters"] = self.manual_hyperparameters self.config["use_manual_hyperparameters"] = True - train.run_manual_tuning(self.config) + train_utils.run_manual_tuning(self.config) def validate_additional_options(self, req_var_dict): missing_variable = False @@ -345,7 +330,7 @@ class MTLClassifier: req_var_dict = dict(zip(required_variable_names, required_variables)) self.validate_additional_options(req_var_dict) - train.run_optuna_study(self.config) + train_utils.run_optuna_study(self.config) def load_and_evaluate_test_model( self, diff --git a/geneformer/perturber_utils.py b/geneformer/perturber_utils.py index 6970612fd36941c7bd68c04c37bb1ce2df9cab37..1a5a8378d2cf24717bba629756dc34731dc847f0 100644 --- a/geneformer/perturber_utils.py +++ b/geneformer/perturber_utils.py @@ -1,6 +1,5 @@ import itertools as it import logging -import os import pickle from collections import defaultdict from pathlib import Path @@ -9,7 +8,6 @@ from typing import List import numpy as np import pandas as pd import torch -import datasets from datasets import Dataset, load_from_disk from peft import LoraConfig, get_peft_model from transformers import ( @@ -19,6 +17,11 @@ from transformers import ( BitsAndBytesConfig, ) +from . import ( + TOKEN_DICTIONARY_FILE, + ENSEMBL_DICTIONARY_FILE, +) + logger = logging.getLogger(__name__) @@ -124,10 +127,7 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False): output_hidden_states = (mode == "eval") # Quantization logic - if isinstance(quantize, dict): - quantize_config = quantize.get("bnb_config", None) - peft_config = quantize.get("peft_config", None) - elif quantize: + if quantize: if inference_only: quantize_config = BitsAndBytesConfig(load_in_8bit=True) peft_config = None @@ -138,22 +138,13 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False): bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, ) - try: - peft_config = LoraConfig( - lora_alpha=128, - lora_dropout=0.1, - r=64, - bias="none", - task_type="TokenClassification", - ) - except ValueError as e: - peft_config = LoraConfig( - lora_alpha=128, - lora_dropout=0.1, - r=64, - bias="none", - task_type="TOKEN_CLS", - ) + peft_config = LoraConfig( + lora_alpha=128, + lora_dropout=0.1, + r=64, + bias="none", + task_type="TokenClassification", + ) else: quantize_config = None peft_config = None @@ -190,34 +181,17 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False): model.eval() # Handle device placement and PEFT - adapter_config_path = os.path.join(model_directory, "adapter_config.json") - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not quantize: # Only move non-quantized models - move_to_cuda(model) - elif os.path.exists(adapter_config_path): - # If adapter files exist, load them into the model using PEFT's from_pretrained - model = PeftModel.from_pretrained(model, model_directory) - move_to_cuda(model) - print("loading lora weights") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) elif peft_config: # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf) model.enable_input_require_grads() model = get_peft_model(model, peft_config) - move_to_cuda(model) return model - -def move_to_cuda(model): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # get what device model is currently on - model_device = next(model.parameters()).device - # Check if the model is on the CPU and move to cuda if necessary - if (model_device.type == "cpu") and (device.type == "cuda"): - model.to(device) - - def quant_layers(model): layer_nums = [] for name, parameter in model.named_parameters(): @@ -431,11 +405,6 @@ def remove_perturbed_indices_set( def make_perturbation_batch( example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc ) -> tuple[Dataset, List[int]]: - - # For datasets>=4.0.0, convert to dict to avoid format issues - if int(datasets.__version__.split(".")[0]) >= 4: - example_cell = example_cell[:] - if combo_lvl == 0 and tokens_to_perturb == "all": if perturb_type in ["overexpress", "activate"]: range_start = 1 @@ -508,11 +477,6 @@ def make_perturbation_batch( def make_perturbation_batch_special( example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc ) -> tuple[Dataset, List[int]]: - - # For datasets>=4.0.0, convert to dict to avoid format issues - if int(datasets.__version__.split(".")[0]) >= 4: - example_cell = example_cell[:] - if combo_lvl == 0 and tokens_to_perturb == "all": if perturb_type in ["overexpress", "activate"]: range_start = 1 @@ -913,4 +877,50 @@ def validate_cell_states_to_model(cell_states_to_model): "'goal_state': 'nf', " "'alt_states': ['hcm', 'other1', 'other2']}" ) - raise \ No newline at end of file + raise + + +class GeneIdHandler: + def __init__(self, raise_errors=False): + def invert_dict(dict_obj): + return {v: k for k, v in dict_obj.items()} + + self.raise_errors = raise_errors + + with open(TOKEN_DICTIONARY_FILE, "rb") as f: + self.gene_token_dict = pickle.load(f) + self.token_gene_dict = invert_dict(self.gene_token_dict) + + with open(ENSEMBL_DICTIONARY_FILE, "rb") as f: + self.id_gene_dict = pickle.load(f) + self.gene_id_dict = invert_dict(self.id_gene_dict) + + def ens_to_token(self, ens_id): + if not self.raise_errors: + return self.gene_token_dict.get(ens_id, ens_id) + else: + return self.gene_token_dict[ens_id] + + def token_to_ens(self, token): + if not self.raise_errors: + return self.token_gene_dict.get(token, token) + else: + return self.token_gene_dict[token] + + def ens_to_symbol(self, ens_id): + if not self.raise_errors: + return self.gene_id_dict.get(ens_id, ens_id) + else: + return self.gene_id_dict[ens_id] + + def symbol_to_ens(self, symbol): + if not self.raise_errors: + return self.id_gene_dict.get(symbol, symbol) + else: + return self.id_gene_dict[symbol] + + def token_to_symbol(self, token): + return self.ens_to_symbol(self.token_to_ens(token)) + + def symbol_to_token(self, symbol): + return self.ens_to_token(self.symbol_to_ens(symbol)) diff --git a/geneformer/token_dictionary_gc104M.pkl b/geneformer/token_dictionary_gc95M.pkl similarity index 100% rename from geneformer/token_dictionary_gc104M.pkl rename to geneformer/token_dictionary_gc95M.pkl diff --git a/geneformer/tokenizer.py b/geneformer/tokenizer.py index 8af0cfa0f336d007feb2b144129a96c88ad8a871..b460f028c9d85630b34722a290df6dd40f8908aa 100644 --- a/geneformer/tokenizer.py +++ b/geneformer/tokenizer.py @@ -3,7 +3,7 @@ Geneformer tokenizer. **Input data:** -| *Required format:* raw counts scRNAseq data without feature selection as .loom, .h5ad, or .zarr file. +| *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file. | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene. | *Required col (cell) attribute:* "n_counts"; total read counts in that cell. @@ -20,9 +20,9 @@ Geneformer tokenizer. **Description:** -| Input data is a directory with .loom, .h5ad, or .zarr files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function. +| Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function. -| The discussion below references the .loom file format, but the analagous labels are required for .h5ad and .zarr files, just that they will be column instead of row attributes and vice versa due to the transposed format of the file types. +| The discussion below references the .loom file format, but the analagous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types. | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization. @@ -30,9 +30,11 @@ Geneformer tokenizer. | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset. -| If one's data is in other formats besides .loom, .h5ad, or .zarr, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom, .h5ad, or .zarr format prior to running the transcriptome tokenizer. +| If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer. -| OF NOTE: Use model_version to auto-select settings for model version other than current default. For V1 model series (original Geneformer pretrained in 2021 on ~30M cells), one must use correct corresponding token dictionary and gene median file, set special_token to False, and set model_input_size to 2048. This argument enables auto-selection of these settings. (For V2 model series, special_token must be True and model_input_size is 4096.) +| OF NOTE: Take care that the correct token dictionary and gene median file is used for the correct model. + +| OF NOTE: For 95M model series, special_token should be True and model_input_size should be 4096. For 30M model series, special_token should be False and model_input_size should be 2048. """ @@ -46,7 +48,6 @@ from collections import Counter from pathlib import Path from typing import Literal -import anndata as ad import loompy as lp import numpy as np import pandas as pd @@ -88,7 +89,6 @@ def sum_ensembl_ids( gene_mapping_dict, gene_token_dict, custom_attr_name_dict, - use_h5ad_index, file_format="loom", chunk_size=512, ): @@ -201,19 +201,13 @@ def sum_ensembl_ids( dsout.add_columns(processed_array, col_attrs=view.ca) return dedup_filename - elif file_format in ["h5ad", "zarr"]: + elif file_format == "h5ad": """ Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together. Returns adata object with deduplicated Ensembl IDs. """ - if file_format == "h5ad": - data = sc.read_h5ad(str(data_directory)) - else: # zarr - data = ad.read_zarr(str(data_directory)) - - if use_h5ad_index: - data.var["ensembl_id"] = list(data.var.index) + data = sc.read_h5ad(str(data_directory)) assert ( "ensembl_id" in data.var.columns @@ -240,7 +234,7 @@ def sum_ensembl_ids( gene for gene in ensembl_ids if gene in gene_token_dict.keys() ] if len(ensembl_id_check) == len(set(ensembl_id_check)): - return data + return data_directory else: raise ValueError("Error: data Ensembl IDs non-unique.") @@ -305,9 +299,6 @@ class TranscriptomeTokenizer: model_input_size=4096, special_token=True, collapse_gene_ids=True, - use_h5ad_index=False, - keep_counts=False, - model_version="V2", gene_median_file=GENE_MEDIAN_FILE, token_dictionary_file=TOKEN_DICTIONARY_FILE, gene_mapping_file=ENSEMBL_MAPPING_FILE, @@ -327,23 +318,15 @@ class TranscriptomeTokenizer: | Chunk size for anndata tokenizer. model_input_size : int = 4096 | Max input size of model to truncate input to. - | For the V1 model series, should be 2048. For the V2 model series, should be 4096. + | For the 30M model series, should be 2048. For the 95M model series, should be 4096. special_token : bool = True | Adds CLS token before and EOS token after rank value encoding. - | For the V1 model series, should be False. For the V2 model series, should be True. + | For the 30M model series, should be False. For the 95M model series, should be True. collapse_gene_ids : bool = True | Whether to collapse gene IDs based on gene mapping dictionary. - use_h5ad_index : bool = False - | use index as Ensembl IDs (only available for h5ad, only if collapse_gene_ids is True) - keep_counts : bool = False - | Whether to keep a dataset column that represents gene counts normalized by total cell counts - | Counts will be ordered by the gene rank order within the tokenized rank value encoding for each cell. - model_version : str - | To auto-select settings for model version other than current default. - | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells gene_median_file : Path | Path to pickle file containing dictionary of non-zero median - | gene expression values across Genecorpus. + | gene expression values across Genecorpus-30M. token_dictionary_file : Path | Path to pickle file containing token dictionary (Ensembl IDs:token). gene_mapping_file : None, Path @@ -365,22 +348,8 @@ class TranscriptomeTokenizer: # add CLS and EOS tokens self.special_token = special_token - # CHANGE DEFAULTS TO BE FOR MODEL OTHER THAN CURRENT - self.model_version = model_version - if self.model_version not in ["V1","V2"]: - logger.error( - "Unrecognized model version. Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells." - ) - elif self.model_version == "V1": - self.model_input_size = 2048 - self.special_token = False - from . import ENSEMBL_MAPPING_FILE_30M, GENE_MEDIAN_FILE_30M, TOKEN_DICTIONARY_FILE_30M - gene_median_file = GENE_MEDIAN_FILE_30M - token_dictionary_file = TOKEN_DICTIONARY_FILE_30M - gene_mapping_file = ENSEMBL_MAPPING_FILE_30M - # load dictionary of gene normalization factors - # (non-zero median value of expression across Genecorpus) + # (non-zero median value of expression across Genecorpus-30M) with open(gene_median_file, "rb") as f: self.gene_median_dict = pickle.load(f) @@ -403,18 +372,12 @@ class TranscriptomeTokenizer: "" in self.gene_token_dict.keys() ): logger.warning( - " and are in gene_token_dict but special_token = False. Please note that for V2 model series, special_token should be True." + " and are in gene_token_dict but special_token = False. Please note that for 95M model series, special_token should be True." ) # if collapsing duplicate gene IDs self.collapse_gene_ids = collapse_gene_ids - # if using h5ad index as ensembl_ids - self.use_h5ad_index = use_h5ad_index - - # if keeping counts within dataset column - self.keep_counts = keep_counts - # load gene mappings dictionary (Ensembl IDs:Ensembl ID) if gene_mapping_file is not None: with open(gene_mapping_file, "rb") as f: @@ -439,8 +402,7 @@ class TranscriptomeTokenizer: data_directory: Path | str, output_directory: Path | str, output_prefix: str, - file_format: Literal["loom", "h5ad", "zarr"] = "loom", - input_identifier: str = "", + file_format: Literal["loom", "h5ad"] = "loom", use_generator: bool = False, ): """ @@ -455,21 +417,17 @@ class TranscriptomeTokenizer: output_prefix : str | Prefix for output .dataset file_format : str - | Format of input files. Can be "loom", "h5ad", or "zarr". - input_identifier : str - | Substring identifier for input .loom, .h5ad, or .zarr, only matches are tokenized - | Default is no identifier, tokenizes all files in provided directory. + | Format of input files. Can be "loom" or "h5ad". use_generator : bool | Whether to use generator or dict for tokenization. """ - tokenized_cells, cell_metadata, tokenized_counts = self.tokenize_files( - Path(data_directory), file_format, input_identifier + tokenized_cells, cell_metadata = self.tokenize_files( + Path(data_directory), file_format ) tokenized_dataset = self.create_dataset( tokenized_cells, cell_metadata, - tokenized_counts, use_generator=use_generator, ) @@ -477,10 +435,9 @@ class TranscriptomeTokenizer: tokenized_dataset.save_to_disk(str(output_path)) def tokenize_files( - self, data_directory, file_format: Literal["loom", "h5ad", "zarr"] = "loom", input_identifier: str = "" + self, data_directory, file_format: Literal["loom", "h5ad"] = "loom" ): tokenized_cells = [] - tokenized_counts = [] if self.custom_attr_name_dict is not None: cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()] cell_metadata = { @@ -489,20 +446,15 @@ class TranscriptomeTokenizer: # loops through directories to tokenize .loom files file_found = 0 - # loops through directories to tokenize .loom, .h5ad, or .zarr files + # loops through directories to tokenize .loom or .h5ad files tokenize_file_fn = ( self.tokenize_loom if file_format == "loom" else self.tokenize_anndata ) - if input_identifier == "": - file_match = f"*.{file_format}" - else: - file_match = f"*{input_identifier}*.{file_format}" - for file_path in data_directory.glob(file_match): + for file_path in data_directory.glob(f"*.{file_format}"): file_found = 1 print(f"Tokenizing {file_path}") - file_tokenized_cells, file_cell_metadata, file_tokenized_counts = tokenize_file_fn(file_path, file_format=file_format) + file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path) tokenized_cells += file_tokenized_cells - tokenized_counts += file_tokenized_counts if self.custom_attr_name_dict is not None: for k in cell_attr: cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[ @@ -516,17 +468,16 @@ class TranscriptomeTokenizer: f"No .{file_format} files found in directory {data_directory}." ) raise - return tokenized_cells, cell_metadata, tokenized_counts + return tokenized_cells, cell_metadata - def tokenize_anndata(self, adata_file_path, target_sum=10_000, file_format="h5ad"): + def tokenize_anndata(self, adata_file_path, target_sum=10_000): adata = sum_ensembl_ids( adata_file_path, self.collapse_gene_ids, self.gene_mapping_dict, self.gene_token_dict, - self.custom_attr_name_dict, - self.use_h5ad_index, - file_format=file_format, + self.custom_attr_name_dict, + file_format="h5ad", chunk_size=self.chunk_size, ) @@ -565,7 +516,6 @@ class TranscriptomeTokenizer: filter_pass_loc = np.array([i for i in range(adata.shape[0])]) tokenized_cells = [] - tokenized_counts = [] for i in range(0, len(filter_pass_loc), self.chunk_size): idx = filter_pass_loc[i : i + self.chunk_size] @@ -573,23 +523,14 @@ class TranscriptomeTokenizer: n_counts = adata[idx].obs["n_counts"].values[:, None] X_view0 = adata[idx, :].X X_view = X_view0[:, coding_miRNA_loc] - X_norm_unscaled = X_view / n_counts * target_sum - X_norm = X_norm_unscaled / norm_factor_vector + X_norm = X_view / n_counts * target_sum / norm_factor_vector X_norm = sp.csr_matrix(X_norm) - X_norm_unscaled = sp.csr_matrix(X_norm_unscaled) tokenized_cells += [ rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices]) for i in range(X_norm.shape[0]) ] - if self.keep_counts: - X_norm_unscaled = sp.csr_matrix(X_norm_unscaled) - tokenized_counts += [ - rank_genes(X_norm[i].data, X_norm_unscaled[i].data) - for i in range(X_norm.shape[0]) - ] - # add custom attributes for subview to dict if self.custom_attr_name_dict is not None: for k in file_cell_metadata.keys(): @@ -597,28 +538,9 @@ class TranscriptomeTokenizer: else: file_cell_metadata = None - # ensure no tokenized_cells are empty - empty_cell_indices = [i for i, cell in enumerate(tokenized_cells) if cell.size == 0] - if len(empty_cell_indices) > 0: - logger.warning( - "Warning: cells without any genes in token dictionary detected. This is unusual and may indicate empty droplets or otherwise invalid cells within the input data. Consider further QC prior to tokenization. Proceeding with excluding empty cells." - ) - empty_cell_indices.sort(reverse=True) # for safe deletion - for index in empty_cell_indices: - del tokenized_cells[index] - if self.keep_counts: - del tokenized_counts[index] - # remove corresponding metadata - for k,v in file_cell_metadata.items(): - for index in empty_cell_indices: - del v[index] - file_cell_metadata[k] = v - - return tokenized_cells, file_cell_metadata, tokenized_counts + return tokenized_cells, file_cell_metadata - def tokenize_loom(self, loom_file_path, target_sum=10_000, file_format="loom"): - tokenized_counts = [] # keep_counts not implemented for tokenize_loom - + def tokenize_loom(self, loom_file_path, target_sum=10_000): if self.custom_attr_name_dict is not None: file_cell_metadata = { attr_key: [] for attr_key in self.custom_attr_name_dict.keys() @@ -632,8 +554,7 @@ class TranscriptomeTokenizer: self.gene_mapping_dict, self.gene_token_dict, self.custom_attr_name_dict, - use_h5ad_index=False, - file_format=file_format, + file_format="loom", chunk_size=self.chunk_size, ) @@ -706,33 +627,29 @@ class TranscriptomeTokenizer: del data.ra["ensembl_id_collapsed"] - return tokenized_cells, file_cell_metadata, tokenized_counts + return tokenized_cells, file_cell_metadata def create_dataset( self, tokenized_cells, cell_metadata, - tokenized_counts, use_generator=False, keep_uncropped_input_ids=False, ): print("Creating dataset.") # create dict for dataset creation dataset_dict = {"input_ids": tokenized_cells} - if self.keep_counts: - dataset_dict["counts"] = tokenized_counts - if self.custom_attr_name_dict is not None: dataset_dict.update(cell_metadata) # create dataset if use_generator: + def dict_generator(): for i in range(len(tokenized_cells)): yield {k: dataset_dict[k][i] for k in dataset_dict.keys()} output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc) - else: output_dataset = Dataset.from_dict(dataset_dict) @@ -755,23 +672,9 @@ class TranscriptomeTokenizer: len(example["input_ids"]), self.gene_token_dict.get(""), ) - if self.keep_counts: - example["counts"] = example["counts"][ - 0 : self.model_input_size - 2 - ] # truncate to leave space for CLS and EOS token - example["counts"] = np.insert( - example["counts"], 0, 0.0 - ) - example["counts"] = np.insert( - example["counts"], - len(example["counts"]), - 0.0, - ) else: # Truncate/Crop input_ids to input size example["input_ids"] = example["input_ids"][0 : self.model_input_size] - if self.keep_counts: - example["counts"] = example["counts"][0 : self.model_input_size] example["length"] = len(example["input_ids"]) return example diff --git a/generation_config.json b/generation_config.json index 0786a9f4dc0de68ee18cbf78399931b05fbefee7..6f690c1f39b5b262e6b898b8891afd9d44978f11 100644 --- a/generation_config.json +++ b/generation_config.json @@ -1,5 +1,5 @@ { "_from_model_config": true, "pad_token_id": 0, - "transformers_version": "4.44.2" + "transformers_version": "4.37.1" } diff --git a/gf-12L-30M-i2048/config.json b/gf-12L-30M-i2048/config.json new file mode 100644 index 0000000000000000000000000000000000000000..52a12424cea85facdf0ca0c507908506daae7ea7 --- /dev/null +++ b/gf-12L-30M-i2048/config.json @@ -0,0 +1,23 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.02, + "gradient_checkpointing": false, + "hidden_act": "relu", + "hidden_dropout_prob": 0.02, + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 1024, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 2048, + "model_type": "bert", + "num_attention_heads": 8, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 25426 +} diff --git a/gf-12L-30M-i2048/pytorch_model.bin b/gf-12L-30M-i2048/pytorch_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..d706ef2ff77fd6809a91034e6ed24af0e1b33999 --- /dev/null +++ b/gf-12L-30M-i2048/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:812f8d85e5ecf9d64c268f052f6ece2c1906bc4f1aecf70d5144b2598386b615 +size 158467410 diff --git a/gf-12L-30M-i2048/training_args.bin b/gf-12L-30M-i2048/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..346383caaaa3b555cb6fcd8de8e4982ebf2a50d5 --- /dev/null +++ b/gf-12L-30M-i2048/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:259cf6067211e24e198690d00f0a222ee5550ad57e23d04ced0d0ca2e1b3738e +size 2607 diff --git a/Geneformer-V2-316M/config.json b/gf-12L-95M-i4096/config.json similarity index 63% rename from Geneformer-V2-316M/config.json rename to gf-12L-95M-i4096/config.json index 6bc648aa565eabb748a6a43ee4def5032a0d5237..86e20c35e6f257f0daeb00ebb92a0751d12d8fff 100755 --- a/Geneformer-V2-316M/config.json +++ b/gf-12L-95M-i4096/config.json @@ -2,22 +2,22 @@ "architectures": [ "BertForMaskedLM" ], - "attention_probs_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.02, "classifier_dropout": null, "hidden_act": "relu", - "hidden_dropout_prob": 0.1, - "hidden_size": 1152, + "hidden_dropout_prob": 0.02, + "hidden_size": 512, "initializer_range": 0.02, - "intermediate_size": 4608, + "intermediate_size": 1024, "layer_norm_eps": 1e-12, "max_position_embeddings": 4096, "model_type": "bert", - "num_attention_heads": 18, - "num_hidden_layers": 18, + "num_attention_heads": 8, + "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", "torch_dtype": "float32", - "transformers_version": "4.44.2", + "transformers_version": "4.37.1", "type_vocab_size": 2, "use_cache": true, "vocab_size": 20275 diff --git a/Geneformer-V2-104M_CLcancer/generation_config.json b/gf-12L-95M-i4096/generation_config.json similarity index 100% rename from Geneformer-V2-104M_CLcancer/generation_config.json rename to gf-12L-95M-i4096/generation_config.json diff --git a/gf-12L-95M-i4096/model.safetensors b/gf-12L-95M-i4096/model.safetensors new file mode 100755 index 0000000000000000000000000000000000000000..1069352219a29bed65fa8e13feb77004128174fa --- /dev/null +++ b/gf-12L-95M-i4096/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c +size 152012980 diff --git a/gf-12L-95M-i4096/training_args.bin b/gf-12L-95M-i4096/training_args.bin new file mode 100755 index 0000000000000000000000000000000000000000..18802f485a03e0262866d1ef7a3e4748a3b14ed3 --- /dev/null +++ b/gf-12L-95M-i4096/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d +size 4920 diff --git a/Geneformer-V2-104M_CLcancer/config.json b/gf-12L-95M-i4096_CLcancer/config.json similarity index 57% rename from Geneformer-V2-104M_CLcancer/config.json rename to gf-12L-95M-i4096_CLcancer/config.json index 8af4fee97b1880ff66cb6a183e824da7318c1416..a7793eb2ea27b28f1f4c5b9974d30c98b4afe8a6 100755 --- a/Geneformer-V2-104M_CLcancer/config.json +++ b/gf-12L-95M-i4096_CLcancer/config.json @@ -1,19 +1,19 @@ { - "_name_or_path": "/gladstone/theodoris/lab/ctheodoris/gf-104m/models/241127_143148_geneformer_94M_L12_emb768_SL4096_E3_B18_LR0.0002_LScosine_WR0.007_Oadamw_DS13/models", + "_name_or_path": "/gladstone/theodoris/lab/pretrained_models/encoder/240402_194213_geneformer_94M_L12_emb512_SL4096_E3_B4_LR0.0005_LScosine_WU5000_Oadamw_DS8/models", "architectures": [ "BertForMaskedLM" ], - "attention_probs_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.02, "classifier_dropout": null, "hidden_act": "relu", - "hidden_dropout_prob": 0.1, - "hidden_size": 768, + "hidden_dropout_prob": 0.02, + "hidden_size": 512, "initializer_range": 0.02, - "intermediate_size": 3072, + "intermediate_size": 1024, "layer_norm_eps": 1e-12, "max_position_embeddings": 4096, "model_type": "bert", - "num_attention_heads": 12, + "num_attention_heads": 8, "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", diff --git a/Geneformer-V2-316M/generation_config.json b/gf-12L-95M-i4096_CLcancer/generation_config.json similarity index 61% rename from Geneformer-V2-316M/generation_config.json rename to gf-12L-95M-i4096_CLcancer/generation_config.json index 0786a9f4dc0de68ee18cbf78399931b05fbefee7..6f690c1f39b5b262e6b898b8891afd9d44978f11 100755 --- a/Geneformer-V2-316M/generation_config.json +++ b/gf-12L-95M-i4096_CLcancer/generation_config.json @@ -1,5 +1,5 @@ { "_from_model_config": true, "pad_token_id": 0, - "transformers_version": "4.44.2" + "transformers_version": "4.37.1" } diff --git a/gf-12L-95M-i4096_CLcancer/model.safetensors b/gf-12L-95M-i4096_CLcancer/model.safetensors new file mode 100755 index 0000000000000000000000000000000000000000..cc620ee4b4243b7ab6d83ad518563e1425eab45b --- /dev/null +++ b/gf-12L-95M-i4096_CLcancer/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2451adeed240c165634fea60ccba17063da8a2843ea9fcdcc0ce185720bf0dc2 +size 152012980 diff --git a/gf-12L-95M-i4096_CLcancer/training_args.bin b/gf-12L-95M-i4096_CLcancer/training_args.bin new file mode 100755 index 0000000000000000000000000000000000000000..1669f5848710ca4a53db6e118e50b816f85381b7 --- /dev/null +++ b/gf-12L-95M-i4096_CLcancer/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37074f3ea62a6ba0a312c38526c20c2dccbb068a2c7ee8c7c73b435dd90ab7b1 +size 5048 diff --git a/gf-20L-95M-i4096/config.json b/gf-20L-95M-i4096/config.json new file mode 100755 index 0000000000000000000000000000000000000000..db949ba1ae442ad3b9e52fd8b7922c6b936ef98c --- /dev/null +++ b/gf-20L-95M-i4096/config.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.02, + "classifier_dropout": null, + "hidden_act": "relu", + "hidden_dropout_prob": 0.02, + "hidden_size": 896, + "initializer_range": 0.02, + "intermediate_size": 1792, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 4096, + "model_type": "bert", + "num_attention_heads": 14, + "num_hidden_layers": 20, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "torch_dtype": "float32", + "transformers_version": "4.37.1", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 20275 +} diff --git a/Geneformer-V2-104M/generation_config.json b/gf-20L-95M-i4096/generation_config.json similarity index 61% rename from Geneformer-V2-104M/generation_config.json rename to gf-20L-95M-i4096/generation_config.json index 0786a9f4dc0de68ee18cbf78399931b05fbefee7..6f690c1f39b5b262e6b898b8891afd9d44978f11 100755 --- a/Geneformer-V2-104M/generation_config.json +++ b/gf-20L-95M-i4096/generation_config.json @@ -1,5 +1,5 @@ { "_from_model_config": true, "pad_token_id": 0, - "transformers_version": "4.44.2" + "transformers_version": "4.37.1" } diff --git a/gf-20L-95M-i4096/model.safetensors b/gf-20L-95M-i4096/model.safetensors new file mode 100755 index 0000000000000000000000000000000000000000..37212863afb501a17425dd48766d71d534537d24 --- /dev/null +++ b/gf-20L-95M-i4096/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db85c081a6d392448955c7d0185e26aba74507518df991ca8c69ee9108ce8bbf +size 605292732 diff --git a/gf-20L-95M-i4096/training_args.bin b/gf-20L-95M-i4096/training_args.bin new file mode 100755 index 0000000000000000000000000000000000000000..3db61b0b99d299afb7c4a237d2b531baa253e5d3 --- /dev/null +++ b/gf-20L-95M-i4096/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5afed602918d6f0c4916c1b9335bcdb619bca2c6fd6c7e0dd2a86d195264b8cc +size 5048 diff --git a/Geneformer-V1-10M/config.json b/gf-6L-30M-i2048/config.json similarity index 100% rename from Geneformer-V1-10M/config.json rename to gf-6L-30M-i2048/config.json diff --git a/Geneformer-V1-10M/model.safetensors b/gf-6L-30M-i2048/model.safetensors similarity index 100% rename from Geneformer-V1-10M/model.safetensors rename to gf-6L-30M-i2048/model.safetensors diff --git a/Geneformer-V1-10M/pytorch_model.bin b/gf-6L-30M-i2048/pytorch_model.bin similarity index 100% rename from Geneformer-V1-10M/pytorch_model.bin rename to gf-6L-30M-i2048/pytorch_model.bin diff --git a/Geneformer-V1-10M/training_args.bin b/gf-6L-30M-i2048/training_args.bin similarity index 100% rename from Geneformer-V1-10M/training_args.bin rename to gf-6L-30M-i2048/training_args.bin diff --git a/model.safetensors b/model.safetensors index f9c0c25c4a1df80ddc1455def715a9856152882c..1069352219a29bed65fa8e13feb77004128174fa 100644 --- a/model.safetensors +++ b/model.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:965ceccea81953d362081ef3843560a0e4fef88d396c28017881f1e94b1246f3 -size 1265455076 +oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c +size 152012980 diff --git a/requirements.txt b/requirements.txt index 4b9b25e28193489e3416cc324eefe6db708bc004..0cb09a2593f3a727090f7cf9f7eacd36edd8ddbd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ anndata>=0.9 -bitsandbytes>=0.45.5 datasets>=2.12 hyperopt>=0.2 loompy>=3.0 @@ -23,4 +22,4 @@ tdigest>=0.5.2 tensorboard>=2.15 torch>=2.0.1 tqdm>=4.65 -transformers==4.46 \ No newline at end of file +transformers>=4.40 diff --git a/setup.py b/setup.py index 81d947cea5933da94d1ea315dd5eee59f18c6577..6dde9eefad8c76e3d1e41ae187f2215bdbc93db5 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ setup( include_package_data=True, install_requires=[ "anndata", - "bitsandbytes", "datasets", "loompy", "matplotlib", diff --git a/training_args.bin b/training_args.bin index 630bab56b199325b337a9969d30167f5b73b7815..18802f485a03e0262866d1ef7a3e4748a3b14ed3 100644 --- a/training_args.bin +++ b/training_args.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e45150f9a4ca34cb4e91ce79f65f3d99d9d66df9f66a37517a352d291008e0b8 -size 5432 +oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d +size 4920