Upload 56 files
Browse files
Downstream classification and zero-shot batch effect tasks
This view is limited to 50 files because it contains too many changes.
See raw diff
- Downstream_tasks/.DS_Store +0 -0
- Downstream_tasks/Classification/.DS_Store +0 -0
- Downstream_tasks/Classification/Cardio.py +1418 -0
- Downstream_tasks/Classification/Cardio_ML.ipynb +1404 -0
- Downstream_tasks/Classification/Gene_dosage.ipynb +0 -0
- Downstream_tasks/Classification/Gene_dosage_ML.ipynb +0 -0
- Downstream_tasks/Classification/Tissue_type.py +457 -0
- Downstream_tasks/Classification/Tissue_type_ML.ipynb +933 -0
- Downstream_tasks/Zero_shot_batch_effect/.DS_Store +0 -0
- Downstream_tasks/Zero_shot_batch_effect/.gitignore +419 -0
- Downstream_tasks/Zero_shot_batch_effect/CODE_OF_CONDUCT.md +9 -0
- Downstream_tasks/Zero_shot_batch_effect/LICENSE +21 -0
- Downstream_tasks/Zero_shot_batch_effect/README.md +162 -0
- Downstream_tasks/Zero_shot_batch_effect/SECURITY.md +41 -0
- Downstream_tasks/Zero_shot_batch_effect/SUPPORT.md +16 -0
- Downstream_tasks/Zero_shot_batch_effect/envs/conda_env.yml +21 -0
- Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/Dockerfile +28 -0
- Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/test.py +66 -0
- Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/test_docker.sh +13 -0
- Downstream_tasks/Zero_shot_batch_effect/envs/docker/jupyter/Dockerfile +12 -0
- Downstream_tasks/Zero_shot_batch_effect/envs/installation.sh +85 -0
- Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_Geneformer.ipynb +0 -0
- Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_HVG_and_scVI.ipynb +0 -0
- Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_evaluation_aggregated.ipynb +1058 -0
- Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_raw_data.ipynb +328 -0
- Downstream_tasks/Zero_shot_batch_effect/requirements.txt +12 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__init__.py +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/__init__.cpython-310.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/__init__.cpython-311.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/cell_embeddings.cpython-310.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/cell_embeddings.cpython-311.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/data.cpython-310.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/data.cpython-311.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/geneformer_forward.cpython-310.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/geneformer_forward.cpython-311.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/model_output.cpython-310.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/model_output.cpython-311.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/scgpt_forward.cpython-310.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/utils.cpython-310.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/utils.cpython-311.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/cell_embeddings.py +417 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/data.py +330 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/geneformer_forward.py +365 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__init__.py +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/__init__.cpython-310.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/__init__.cpython-311.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/custom_logging.cpython-310.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/custom_logging.cpython-311.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/umap.cpython-310.pyc +0 -0
- Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/umap.cpython-311.pyc +0 -0
Downstream_tasks/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
Downstream_tasks/Classification/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
Downstream_tasks/Classification/Cardio.py
ADDED
|
@@ -0,0 +1,1418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from tqdm.auto import tqdm, trange
|
| 3 |
+
GPU_NUMBER = [0]
|
| 4 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
|
| 5 |
+
os.environ["NCCL_DEBUG"] = "INFO"
|
| 6 |
+
|
| 7 |
+
# imports
|
| 8 |
+
from collections import Counter
|
| 9 |
+
import seaborn as sns; sns.set()
|
| 10 |
+
from datasets import load_from_disk
|
| 11 |
+
from sklearn.metrics import accuracy_score, f1_score
|
| 12 |
+
from transformers import Trainer
|
| 13 |
+
from transformers.training_args import TrainingArguments
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
| 16 |
+
from sklearn import preprocessing
|
| 17 |
+
from sklearn.metrics import (
|
| 18 |
+
ConfusionMatrixDisplay,
|
| 19 |
+
accuracy_score,
|
| 20 |
+
auc,
|
| 21 |
+
confusion_matrix,
|
| 22 |
+
f1_score,
|
| 23 |
+
roc_curve,
|
| 24 |
+
)
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
import sys
|
| 28 |
+
# sys.path.append('../Geneformer')
|
| 29 |
+
from geneformer import DataCollatorForCellClassification
|
| 30 |
+
from datasets import load_from_disk
|
| 31 |
+
import sys
|
| 32 |
+
from tqdm.notebook import tqdm
|
| 33 |
+
import seaborn as sns
|
| 34 |
+
import matplotlib.pyplot as plt
|
| 35 |
+
from geneformer.pretrainer import token_dictionary
|
| 36 |
+
import datetime
|
| 37 |
+
import time
|
| 38 |
+
import pickle
|
| 39 |
+
import random
|
| 40 |
+
import subprocess
|
| 41 |
+
import numpy as np
|
| 42 |
+
import pytz
|
| 43 |
+
import torch
|
| 44 |
+
from datasets import load_from_disk, Dataset
|
| 45 |
+
from transformers import (BertConfig, BertForMaskedLM, TrainingArguments, TrainerCallback,
|
| 46 |
+
Trainer, BertModel, BertPreTrainedModel, BertForSequenceClassification, BertForTokenClassification)
|
| 47 |
+
from geneformer import GeneformerPretrainer
|
| 48 |
+
from torch import Tensor
|
| 49 |
+
from transformers.modeling_outputs import MaskedLMOutput
|
| 50 |
+
from transformers.models.bert.modeling_bert import BertLMPredictionHead, BertOnlyMLMHead, BertPredictionHeadTransform
|
| 51 |
+
from transformers.activations import ACT2FN
|
| 52 |
+
from typing import List, Optional, Tuple, Union
|
| 53 |
+
import torch.nn.functional as F
|
| 54 |
+
|
| 55 |
+
macro_f1_list = []
|
| 56 |
+
acc_list = []
|
| 57 |
+
|
| 58 |
+
iter_step = 2
|
| 59 |
+
|
| 60 |
+
class CustomBertForMaskedLM(BertPreTrainedModel):
|
| 61 |
+
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
| 62 |
+
_tied_weights_keys = ["decoder.weight", "bert.embeddings.word_embeddings.weight"]
|
| 63 |
+
|
| 64 |
+
def __init__(self, config):
|
| 65 |
+
super().__init__(config)
|
| 66 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
| 67 |
+
self.transform = BertPredictionHeadTransform(config)
|
| 68 |
+
|
| 69 |
+
self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 70 |
+
|
| 71 |
+
self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size))
|
| 72 |
+
|
| 73 |
+
# Initialize weights
|
| 74 |
+
self.init_weights()
|
| 75 |
+
|
| 76 |
+
# Tie weights automatically
|
| 77 |
+
self.tie_weights()
|
| 78 |
+
|
| 79 |
+
# self.post_init()
|
| 80 |
+
|
| 81 |
+
def tie_weights(self):
|
| 82 |
+
"""
|
| 83 |
+
Ties the weights between the input embeddings and output decoder weights.
|
| 84 |
+
"""
|
| 85 |
+
self.decoder.weight = self.bert.embeddings.word_embeddings.weight
|
| 86 |
+
|
| 87 |
+
def probability_convert(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
|
| 88 |
+
device = probs.device
|
| 89 |
+
batch_size, seq_length, vocab_size = probs.size()
|
| 90 |
+
_, input_seq_length = input_ids.size()
|
| 91 |
+
non_mask = labels == -100
|
| 92 |
+
non_mask_indices = non_mask.nonzero(as_tuple=True)
|
| 93 |
+
known_gene_indices = input_ids[non_mask]
|
| 94 |
+
|
| 95 |
+
# Generate (1-p) matrix whiel assigning all known genes in the beginning
|
| 96 |
+
zeros = torch.zeros((batch_size, 1, vocab_size), device=device)
|
| 97 |
+
zeros[non_mask_indices[0], 0, known_gene_indices] = 1.0
|
| 98 |
+
probs_shifted = torch.cat((zeros, probs[:, :-1, :]), dim=1)
|
| 99 |
+
inv_probs_shifted = 1 - probs_shifted
|
| 100 |
+
|
| 101 |
+
# Cumulative product to get (1-p_1)*(1-p_2)*...*(p_i)
|
| 102 |
+
cumprod_inv_probs = torch.cumprod(inv_probs_shifted, dim=1)
|
| 103 |
+
modified_probs = probs * cumprod_inv_probs
|
| 104 |
+
|
| 105 |
+
# # Since we are assigning probabilities for already known genes,
|
| 106 |
+
# # (1-p_1)*(1-p_2)*...*(p_i) for these genes can result in 0, due to hard assignment of probs to be 1
|
| 107 |
+
# # Add 1e-18 to avoid dividing modified probs by 0
|
| 108 |
+
# # During dubugging stage, some issues occurred in the normalization step.
|
| 109 |
+
# # Since probabilities in each position do not necessarily need to sum up to one, leave out normalization.
|
| 110 |
+
normalized_probs = modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)
|
| 111 |
+
modified_probs = modified_probs / normalized_probs # Normalization after cumulative production
|
| 112 |
+
|
| 113 |
+
return modified_probs
|
| 114 |
+
|
| 115 |
+
def assign_known_gene_probs(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
|
| 116 |
+
|
| 117 |
+
device = probs.device
|
| 118 |
+
batch_size, seq_length, vocab_size = probs.size()
|
| 119 |
+
_, input_seq_length = input_ids.size()
|
| 120 |
+
|
| 121 |
+
# Truncate `labels` to match the length of `input_ids` along the sequence dimension
|
| 122 |
+
truncated_labels = labels[:, :input_seq_length]
|
| 123 |
+
|
| 124 |
+
non_mask = truncated_labels == -100
|
| 125 |
+
non_mask_indices = non_mask.nonzero(as_tuple=True)
|
| 126 |
+
|
| 127 |
+
ones = torch.ones((batch_size, seq_length, vocab_size), device=device)
|
| 128 |
+
zeros = torch.zeros((batch_size, seq_length, vocab_size), device=device)
|
| 129 |
+
|
| 130 |
+
known_gene_indices = input_ids[non_mask]
|
| 131 |
+
|
| 132 |
+
ones[non_mask_indices[0], non_mask_indices[1], :] = 0.0
|
| 133 |
+
zeros[non_mask_indices[0], non_mask_indices[1], known_gene_indices] = 1.0
|
| 134 |
+
|
| 135 |
+
# Modify already known genes' probabilities using the one-hot tensor
|
| 136 |
+
modified_probs = probs * ones
|
| 137 |
+
modified_probs = modified_probs + zeros
|
| 138 |
+
|
| 139 |
+
# Do the normalization
|
| 140 |
+
modified_probs = modified_probs / modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18) # Normalize
|
| 141 |
+
|
| 142 |
+
return modified_probs
|
| 143 |
+
|
| 144 |
+
def forward(
|
| 145 |
+
self,
|
| 146 |
+
input_ids: Tensor | None = None,
|
| 147 |
+
attention_mask: Tensor | None = None,
|
| 148 |
+
token_type_ids: Tensor | None = None,
|
| 149 |
+
position_ids: Tensor | None = None,
|
| 150 |
+
head_mask: Tensor | None = None,
|
| 151 |
+
inputs_embeds: Tensor | None = None,
|
| 152 |
+
encoder_hidden_states: Tensor | None = None,
|
| 153 |
+
encoder_attention_mask: Tensor | None = None,
|
| 154 |
+
labels: Tensor | None = None,
|
| 155 |
+
output_attentions: bool | None = None,
|
| 156 |
+
output_hidden_states: bool | None = None,
|
| 157 |
+
return_dict: bool | None = None) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
| 158 |
+
|
| 159 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 160 |
+
|
| 161 |
+
outputs = self.bert(
|
| 162 |
+
input_ids,
|
| 163 |
+
attention_mask=attention_mask,
|
| 164 |
+
token_type_ids=token_type_ids,
|
| 165 |
+
position_ids=position_ids,
|
| 166 |
+
head_mask=head_mask,
|
| 167 |
+
inputs_embeds=inputs_embeds,
|
| 168 |
+
output_attentions=output_attentions,
|
| 169 |
+
output_hidden_states=output_hidden_states,
|
| 170 |
+
return_dict=return_dict,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
hidden_states = outputs[0]
|
| 174 |
+
hidden_transform = self.transform(hidden_states)
|
| 175 |
+
logits = self.decoder(hidden_transform) + self.bias
|
| 176 |
+
|
| 177 |
+
probs = F.softmax(logits, dim=-1)
|
| 178 |
+
|
| 179 |
+
# Probability manipulations to avoid repeats from already known genes
|
| 180 |
+
probs = self.assign_known_gene_probs(probs, input_ids, labels)
|
| 181 |
+
convert_probs = self.probability_convert(probs, input_ids, labels)
|
| 182 |
+
assigned_probs = self.assign_known_gene_probs(convert_probs, input_ids, labels)
|
| 183 |
+
|
| 184 |
+
masked_lm_loss = None
|
| 185 |
+
if labels is not None:
|
| 186 |
+
probs_flat = assigned_probs.view(-1, self.config.vocab_size)
|
| 187 |
+
labels_flat = labels.view(-1)
|
| 188 |
+
mask = (labels != -100).float().view(-1)
|
| 189 |
+
|
| 190 |
+
# Compute masked cross-entropy loss
|
| 191 |
+
masked_lm_loss = -torch.log(torch.clamp(probs_flat[torch.arange(len(labels_flat)), labels_flat], min=1e-18)) * mask
|
| 192 |
+
masked_lm_loss = masked_lm_loss.sum() / mask.sum()
|
| 193 |
+
|
| 194 |
+
else:
|
| 195 |
+
loss = None
|
| 196 |
+
|
| 197 |
+
if not return_dict:
|
| 198 |
+
output = (assigned_probs,) + outputs[2:]
|
| 199 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 200 |
+
|
| 201 |
+
return MaskedLMOutput(
|
| 202 |
+
loss=masked_lm_loss,
|
| 203 |
+
logits=assigned_probs,
|
| 204 |
+
hidden_states=outputs.hidden_states,
|
| 205 |
+
attentions=outputs.attentions,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
| 209 |
+
input_shape = input_ids.shape
|
| 210 |
+
effective_batch_size = input_shape[0]
|
| 211 |
+
|
| 212 |
+
# add a dummy token
|
| 213 |
+
if self.config.pad_token_id is None:
|
| 214 |
+
raise ValueError("The PAD token should be defined for generation")
|
| 215 |
+
|
| 216 |
+
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
|
| 217 |
+
dummy_token = torch.full(
|
| 218 |
+
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
|
| 219 |
+
)
|
| 220 |
+
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
| 221 |
+
|
| 222 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 223 |
+
|
| 224 |
+
def prepare_data(
|
| 225 |
+
input_data_file,
|
| 226 |
+
output_directory,
|
| 227 |
+
output_prefix,
|
| 228 |
+
split_id_dict=None,
|
| 229 |
+
test_size=None,
|
| 230 |
+
attr_to_split=None,
|
| 231 |
+
attr_to_balance=None,
|
| 232 |
+
max_trials=100,
|
| 233 |
+
pval_threshold=0.1,
|
| 234 |
+
):
|
| 235 |
+
"""
|
| 236 |
+
Prepare data for cell state or gene classification.
|
| 237 |
+
|
| 238 |
+
**Parameters**
|
| 239 |
+
|
| 240 |
+
input_data_file : Path
|
| 241 |
+
| Path to directory containing .dataset input
|
| 242 |
+
output_directory : Path
|
| 243 |
+
| Path to directory where prepared data will be saved
|
| 244 |
+
output_prefix : str
|
| 245 |
+
| Prefix for output file
|
| 246 |
+
split_id_dict : None, dict
|
| 247 |
+
| Dictionary of IDs for train and test splits
|
| 248 |
+
| Three-item dictionary with keys: attr_key, train, test
|
| 249 |
+
| attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
|
| 250 |
+
| train: list of IDs in the attr_key column to include in the train split
|
| 251 |
+
| test: list of IDs in the attr_key column to include in the test split
|
| 252 |
+
| For example: {"attr_key": "individual",
|
| 253 |
+
| "train": ["patient1", "patient2", "patient3", "patient4"],
|
| 254 |
+
| "test": ["patient5", "patient6"]}
|
| 255 |
+
test_size : None, float
|
| 256 |
+
| Proportion of data to be saved separately and held out for test set
|
| 257 |
+
| (e.g. 0.2 if intending hold out 20%)
|
| 258 |
+
| If None, will inherit from split_sizes["test"] from Classifier
|
| 259 |
+
| The training set will be further split to train / validation in self.validate
|
| 260 |
+
| Note: only available for CellClassifiers
|
| 261 |
+
attr_to_split : None, str
|
| 262 |
+
| Key for attribute on which to split data while balancing potential confounders
|
| 263 |
+
| e.g. "patient_id" for splitting by patient while balancing other characteristics
|
| 264 |
+
| Note: only available for CellClassifiers
|
| 265 |
+
attr_to_balance : None, list
|
| 266 |
+
| List of attribute keys on which to balance data while splitting on attr_to_split
|
| 267 |
+
| e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
|
| 268 |
+
| Note: only available for CellClassifiers
|
| 269 |
+
max_trials : None, int
|
| 270 |
+
| Maximum number of trials of random splitting to try to achieve balanced other attributes
|
| 271 |
+
| If no split is found without significant (p<0.05) differences in other attributes, will select best
|
| 272 |
+
| Note: only available for CellClassifiers
|
| 273 |
+
pval_threshold : None, float
|
| 274 |
+
| P-value threshold to use for attribute balancing across splits
|
| 275 |
+
| E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
if test_size is None:
|
| 279 |
+
test_size = oos_test_size
|
| 280 |
+
|
| 281 |
+
# prepare data and labels for classification
|
| 282 |
+
data = load_and_filter(filter_data, nproc, input_data_file)
|
| 283 |
+
|
| 284 |
+
if classifier == "cell":
|
| 285 |
+
if "label" in data.features:
|
| 286 |
+
logger.error(
|
| 287 |
+
"Column name 'label' must be reserved for class IDs. Please rename column."
|
| 288 |
+
)
|
| 289 |
+
raise
|
| 290 |
+
elif classifier == "gene":
|
| 291 |
+
if "labels" in data.features:
|
| 292 |
+
logger.error(
|
| 293 |
+
"Column name 'labels' must be reserved for class IDs. Please rename column."
|
| 294 |
+
)
|
| 295 |
+
raise
|
| 296 |
+
|
| 297 |
+
if (attr_to_split is not None) and (attr_to_balance is None):
|
| 298 |
+
logger.error(
|
| 299 |
+
"Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
|
| 300 |
+
)
|
| 301 |
+
raise
|
| 302 |
+
|
| 303 |
+
if not isinstance(attr_to_balance, list):
|
| 304 |
+
attr_to_balance = [attr_to_balance]
|
| 305 |
+
|
| 306 |
+
if classifier == "cell":
|
| 307 |
+
# remove cell states representing < rare_threshold of cells
|
| 308 |
+
data = remove_rare(
|
| 309 |
+
data, rare_threshold, cell_state_dict["state_key"], nproc
|
| 310 |
+
)
|
| 311 |
+
# downsample max cells and max per class
|
| 312 |
+
data = downsample_and_shuffle(
|
| 313 |
+
data, max_ncells, None, cell_state_dict
|
| 314 |
+
)
|
| 315 |
+
# rename cell state column to "label"
|
| 316 |
+
data = rename_cols(data, cell_state_dict["state_key"])
|
| 317 |
+
|
| 318 |
+
# convert classes to numerical labels and save as id_class_dict
|
| 319 |
+
# of note, will label all genes in gene_class_dict
|
| 320 |
+
# if (cross-)validating, genes will be relabeled in column "labels" for each split
|
| 321 |
+
# at the time of training with Classifier.validate
|
| 322 |
+
data, id_class_dict = label_classes(
|
| 323 |
+
classifier, data, None, nproc
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
# save id_class_dict for future reference
|
| 327 |
+
id_class_output_path = (
|
| 328 |
+
Path(output_directory) / f"{output_prefix}_id_class_dict"
|
| 329 |
+
).with_suffix(".pkl")
|
| 330 |
+
with open(id_class_output_path, "wb") as f:
|
| 331 |
+
pickle.dump(id_class_dict, f)
|
| 332 |
+
|
| 333 |
+
if split_id_dict is not None:
|
| 334 |
+
data_dict = dict()
|
| 335 |
+
data_dict["train"] = filter_by_dict(
|
| 336 |
+
data, {split_id_dict["attr_key"]: split_id_dict["train"]}, nproc
|
| 337 |
+
)
|
| 338 |
+
data_dict["test"] = filter_by_dict(
|
| 339 |
+
data, {split_id_dict["attr_key"]: split_id_dict["test"]}, nproc
|
| 340 |
+
)
|
| 341 |
+
train_data_output_path = (
|
| 342 |
+
Path(output_directory) / f"{output_prefix}_labeled_train"
|
| 343 |
+
).with_suffix(".dataset")
|
| 344 |
+
test_data_output_path = (
|
| 345 |
+
Path(output_directory) / f"{output_prefix}_labeled_test"
|
| 346 |
+
).with_suffix(".dataset")
|
| 347 |
+
data_dict["train"].save_to_disk(str(train_data_output_path))
|
| 348 |
+
data_dict["test"].save_to_disk(str(test_data_output_path))
|
| 349 |
+
elif (test_size is not None) and (classifier == "cell"):
|
| 350 |
+
if 1 > test_size > 0:
|
| 351 |
+
if attr_to_split is None:
|
| 352 |
+
data_dict = data.train_test_split(
|
| 353 |
+
test_size=test_size,
|
| 354 |
+
stratify_by_column=None,
|
| 355 |
+
seed=42,
|
| 356 |
+
)
|
| 357 |
+
train_data_output_path = (
|
| 358 |
+
Path(output_directory) / f"{output_prefix}_labeled_train"
|
| 359 |
+
).with_suffix(".dataset")
|
| 360 |
+
test_data_output_path = (
|
| 361 |
+
Path(output_directory) / f"{output_prefix}_labeled_test"
|
| 362 |
+
).with_suffix(".dataset")
|
| 363 |
+
data_dict["train"].save_to_disk(str(train_data_output_path))
|
| 364 |
+
data_dict["test"].save_to_disk(str(test_data_output_path))
|
| 365 |
+
else:
|
| 366 |
+
data_dict, balance_df = cu.balance_attr_splits(
|
| 367 |
+
data,
|
| 368 |
+
attr_to_split,
|
| 369 |
+
attr_to_balance,
|
| 370 |
+
test_size,
|
| 371 |
+
max_trials,
|
| 372 |
+
pval_threshold,
|
| 373 |
+
cell_state_dict["state_key"],
|
| 374 |
+
nproc,
|
| 375 |
+
)
|
| 376 |
+
balance_df.to_csv(
|
| 377 |
+
f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
|
| 378 |
+
)
|
| 379 |
+
train_data_output_path = (
|
| 380 |
+
Path(output_directory) / f"{output_prefix}_labeled_train"
|
| 381 |
+
).with_suffix(".dataset")
|
| 382 |
+
test_data_output_path = (
|
| 383 |
+
Path(output_directory) / f"{output_prefix}_labeled_test"
|
| 384 |
+
).with_suffix(".dataset")
|
| 385 |
+
data_dict["train"].save_to_disk(str(train_data_output_path))
|
| 386 |
+
data_dict["test"].save_to_disk(str(test_data_output_path))
|
| 387 |
+
else:
|
| 388 |
+
data_output_path = (
|
| 389 |
+
Path(output_directory) / f"{output_prefix}_labeled"
|
| 390 |
+
).with_suffix(".dataset")
|
| 391 |
+
data.save_to_disk(str(data_output_path))
|
| 392 |
+
print(data_output_path)
|
| 393 |
+
else:
|
| 394 |
+
data_output_path = (
|
| 395 |
+
Path(output_directory) / f"{output_prefix}_labeled"
|
| 396 |
+
).with_suffix(".dataset")
|
| 397 |
+
data.save_to_disk(str(data_output_path))
|
| 398 |
+
|
| 399 |
+
def load_and_filter(filter_data, nproc, input_data_file):
    """Load a Hugging Face .dataset from disk, optionally pre-filtered.

    When *filter_data* is None the dataset is returned untouched; otherwise
    it is passed through ``filter_by_dict`` with *nproc* worker processes.
    """
    dataset = load_from_disk(input_data_file)
    if filter_data is None:
        return dataset
    return filter_by_dict(dataset, filter_data, nproc)
|
| 404 |
+
# get number of classes for classifier
def get_num_classes(id_class_dict):
    """Return the number of distinct class values in an id -> class mapping."""
    distinct_classes = set(id_class_dict.values())
    return len(distinct_classes)
|
| 407 |
+
|
| 408 |
+
def filter_by_dict(data, filter_data, nproc):
    """Keep only examples whose value in each filtered column is allowed.

    **Parameters**

    data : Dataset
        | Hugging Face dataset to filter.
    filter_data : dict
        | Mapping of column name -> collection of allowed values.
    nproc : int
        | Number of processes for the filter operation.

    Raises ValueError if no cells remain after filtering.
    """
    for key, value in filter_data.items():
        # bind key/value as defaults so each closure is independent of the
        # loop variables (late-binding closure pitfall)
        def filter_data_by_criteria(example, key=key, value=value):
            return example[key] in value

        data = data.filter(filter_data_by_criteria, num_proc=nproc)
    if len(data) == 0:
        logger.error("No cells remain after filtering. Check filtering criteria.")
        # fix: a bare `raise` outside an except block raises an unhelpful
        # RuntimeError; raise an explicit exception instead
        raise ValueError(
            "No cells remain after filtering. Check filtering criteria."
        )
    return data
|
| 419 |
+
def remove_rare(data, rare_threshold, label, nproc):
    """Remove examples whose class frequency is at or below rare_threshold.

    **Parameters**

    data : Dataset
        | Dataset with class labels in column `label`.
    rare_threshold : float
        | Minimum fraction of total cells a class must exceed to be kept;
        | 0 disables filtering.
    label : str
        | Name of the column holding class labels.
    nproc : int
        | Number of processes for the filter operation.
    """
    if rare_threshold > 0:
        total_cells = len(data)
        label_counter = Counter(data[label])
        # fix: iterating a Counter directly yields only keys; .items() is
        # required to unpack (class, count) pairs
        nonrare_label_dict = {
            label: [
                k
                for k, v in label_counter.items()
                if (v / total_cells) > rare_threshold
            ]
        }
        data = filter_by_dict(data, nonrare_label_dict, nproc)
    return data
|
| 428 |
+
def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
    """Shuffle the dataset deterministically, then apply optional size caps.

    max_ncells caps the overall number of cells kept after the shuffle;
    max_ncells_per_class caps the number of cells kept for each class label
    (the column named by cell_state_dict["state_key"]).
    """
    shuffled = data.shuffle(seed=42)
    total = len(shuffled)
    # global cap: keep only the first max_ncells examples of the shuffle
    if max_ncells is not None and total > max_ncells:
        shuffled = shuffled.select([i for i in range(max_ncells)])
    # per-class cap: subsample indices within each class label
    if max_ncells_per_class is not None:
        class_labels = shuffled[cell_state_dict["state_key"]]
        random.seed(42)
        keep_indices = subsample_by_class(class_labels, max_ncells_per_class)
        shuffled = shuffled.select(keep_indices)
    return shuffled
|
| 441 |
+
def rename_cols(data, state_key):
    """Rename the cell-state column to the "label" column expected downstream."""
    return data.rename_column(state_key, "label")
|
| 444 |
+
def label_classes(classifier, data, gene_class_dict, nproc):
    """Convert class labels in the dataset to numerical IDs.

    For a "cell" classifier, the label set is the distinct values of the
    "label" column. For a "gene" classifier, cells containing none of the
    target genes are removed and the label set is the keys of
    gene_class_dict.

    Returns (data with numeric labels, id->class dictionary).
    """
    if classifier == "cell":
        # cell classifier: classes are the distinct values of the label column
        label_set = set(data["label"])
    elif classifier == "gene":
        # remove cells without any of the target genes
        def if_contains_label(example):
            a = pu.flatten_list(gene_class_dict.values())
            b = example["input_ids"]
            return not set(a).isdisjoint(b)

        data = data.filter(if_contains_label, num_proc=nproc)
        label_set = gene_class_dict.keys()

    if len(data) == 0:
        logger.error(
            "No cells remain after filtering for target genes. Check target gene list."
        )
        raise

    # map each class label to a numeric id and build the reverse mapping
    class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
    id_class_dict = {v: k for k, v in class_id_dict.items()}

    def classes_to_ids(example):
        if classifier == "cell":
            example["label"] = class_id_dict[example["label"]]
        elif classifier == "gene":
            # gene classifier: per-token labels derived from the target genes
            example["labels"] = label_gene_classes(
                example, class_id_dict, gene_class_dict
            )
        return example

    data = data.map(classes_to_ids, num_proc=nproc)
    return data, id_class_dict
|
| 477 |
+
|
| 478 |
+
def train_classifier(
    model_directory,
    num_classes,
    train_data,
    eval_data,
    output_directory,
    predict=False,
    classifier='cell',
    no_eval=False,
    quantize=False,
    freeze_layers=2,
    training_args=None,
):
    """
    Fine-tune model for cell state or gene classification.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    num_classes : int
        | Number of classes for classifier
    train_data : Dataset
        | Loaded training .dataset input
        | For cell classifier, labels in column "label".
        | For gene classifier, labels in column "labels".
    eval_data : None, Dataset
        | (Optional) Loaded evaluation .dataset input
        | For cell classifier, labels in column "label".
        | For gene classifier, labels in column "labels".
    output_directory : Path
        | Path to directory where fine-tuned model will be saved
    predict : bool
        | Whether or not to save eval predictions from trainer
    classifier : str
        | Classifier type: "cell" or "gene"
    no_eval : bool
        | If True, train without evaluation even when eval_data is provided
    quantize : bool
        | Whether to load the model quantized
    freeze_layers : int, None
        | Number of initial encoder layers to freeze; None keeps the default
    training_args : None, dict
        | (Optional) dict of TrainingArguments overrides merged into the
        | defaults (fix: this was previously an undefined free variable)
    """

    ##### Validate and prepare data #####
    train_data, eval_data = validate_and_clean_cols(
        train_data, eval_data, classifier
    )

    if (no_eval is True) and (eval_data is not None):
        logger.warning(
            "no_eval set to True; model will be trained without evaluation."
        )
        eval_data = None

    if (classifier == "gene") and (predict is True):
        logger.warning(
            "Predictions during training not currently available for gene classifiers; setting predict to False."
        )
        predict = False

    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
    if os.path.isfile(saved_model_test) is True:
        logger.error("Model already saved to this designated output directory.")
        # fix: bare `raise` outside an except block raises an unhelpful
        # RuntimeError; raise an explicit exception instead
        raise FileExistsError(
            "Model already saved to this designated output directory."
        )
    # make output directory
    # fix: was os.makedirs(output_dir, ...) with an undefined name
    os.makedirs(output_directory, exist_ok=True)

    ##### Load model and training args #####
    model = load_model(
        "CellClassifier",
        num_classes,
        model_directory,
        "train",
        quantize=quantize,
    )
    # initialize the classifier's word embeddings from the pretrained model
    pretrained_model = CustomBertForMaskedLM.from_pretrained(model_directory)
    pretrained_word_embeddings = pretrained_model.bert.embeddings.word_embeddings.weight.clone()
    model.bert.embeddings.word_embeddings.load_state_dict(
        {"weight": pretrained_word_embeddings}
    )

    def_training_args, def_freeze_layers = get_default_train_args(
        model, classifier, train_data, output_directory
    )

    # merge caller-supplied overrides into the defaults
    if training_args is not None:
        def_training_args.update(training_args)
    # log roughly 10 times per epoch
    logging_steps = round(
        len(train_data) / def_training_args["per_device_train_batch_size"] / 10
    )
    def_training_args["logging_steps"] = logging_steps
    def_training_args["output_dir"] = output_directory
    if eval_data is None:
        def_training_args["evaluation_strategy"] = "no"
        def_training_args["load_best_model_at_end"] = False
    training_args_init = TrainingArguments(**def_training_args)

    if freeze_layers is not None:
        def_freeze_layers = freeze_layers

    if def_freeze_layers > 0:
        # freeze the first N encoder layers to reduce trainable parameters
        modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
        for module in modules_to_freeze:
            for param in module.parameters():
                param.requires_grad = False

    ##### Fine-tune the model #####
    # define the data collator
    # fix: was `self.classifier` inside a plain function, which raised
    # NameError when a gene classifier was requested
    if classifier == "cell":
        data_collator = DataCollatorForCellClassification()
    elif classifier == "gene":
        data_collator = DataCollatorForGeneClassification()

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=data_collator,
        train_dataset=train_data,
        eval_dataset=eval_data,
        compute_metrics=compute_metrics,
    )

    # train the classifier
    trainer.train()
    trainer.save_model(output_directory)
    if predict is True:
        # make eval predictions and save predictions and metrics
        predictions = trainer.predict(eval_data)
        prediction_output_path = f"{output_directory}/predictions.pkl"
        with open(prediction_output_path, "wb") as f:
            pickle.dump(predictions, f)
        trainer.save_metrics("eval", predictions.metrics)
    return trainer
|
| 606 |
+
|
| 607 |
+
def validate_and_clean_cols(train_data, eval_data, classifier):
    """Validate the expected label column and strip all other extra columns.

    **Parameters**

    train_data : Dataset
        | Training dataset; must contain the classifier's label column.
    eval_data : None, Dataset
        | Optional evaluation dataset; validated the same way when present.
    classifier : str
        | "cell" (label column "label") or "gene" (label column "labels").

    Raises ValueError when the label column is missing or the classifier
    type is unknown.
    """
    # validate that data has expected label column and remove others
    if classifier == "cell":
        label_col = "label"
    elif classifier == "gene":
        label_col = "labels"
    else:
        # fix: label_col was left unbound for any other classifier value,
        # causing a confusing NameError below
        raise ValueError(f"Unknown classifier type: {classifier}")

    cols_to_keep = [label_col] + ["input_ids", "length"]
    if label_col not in train_data.column_names:
        logger.error(f"train_data must contain column {label_col} with class labels.")
        # fix: bare `raise` outside an except block raises an unhelpful
        # RuntimeError; raise an explicit exception instead
        raise ValueError(
            f"train_data must contain column {label_col} with class labels."
        )
    else:
        train_data = remove_cols(train_data, cols_to_keep)

    if eval_data is not None:
        if label_col not in eval_data.column_names:
            logger.error(
                f"eval_data must contain column {label_col} with class labels."
            )
            raise ValueError(
                f"eval_data must contain column {label_col} with class labels."
            )
        else:
            eval_data = remove_cols(eval_data, cols_to_keep)
    return train_data, eval_data
|
| 630 |
+
|
| 631 |
+
def remove_cols(data, cols_to_keep):
    """Drop every dataset column that is not listed in cols_to_keep."""
    keep = set(cols_to_keep)
    drop = [col for col in data.features.keys() if col not in keep]
    return data.remove_columns(drop)
|
| 636 |
+
|
| 637 |
+
def load_model(model_type, num_classes, model_directory, mode, quantize=False):
    """Load a BERT model of the requested type, optionally quantized.

    **Parameters**

    model_type : str
        | One of "Pretrained", "GeneClassifier", "CellClassifier",
        | "MTLCellClassifier", or "MTLCellClassifier-Quantized".
    num_classes : int
        | Number of output classes (ignored for "Pretrained").
    model_directory : Path
        | Path to the pretrained/fine-tuned model weights.
    mode : str
        | "train" or "eval"; "eval" also enables hidden-state output.
    quantize : bool
        | Whether to load the model with bitsandbytes quantization.
    """
    # "MTLCellClassifier-Quantized" is an alias for MTLCellClassifier + 8-bit
    if model_type == "MTLCellClassifier-Quantized":
        model_type = "MTLCellClassifier"
        quantize = True

    # hidden states are only requested when extracting embeddings at eval time
    output_hidden_states = (mode == "eval")

    # Quantization logic
    if quantize:
        if model_type == "MTLCellClassifier":
            # 8-bit quantization, no LoRA adapter
            quantize_config = BitsAndBytesConfig(load_in_8bit=True)
            peft_config = None
        else:
            # 4-bit NF4 quantization with a LoRA adapter for fine-tuning
            quantize_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            peft_config = LoraConfig(
                lora_alpha=128,
                lora_dropout=0.1,
                r=64,
                bias="none",
                task_type="TokenClassification",
            )
    else:
        quantize_config = None
        peft_config = None

    # Model class selection
    model_classes = {
        "Pretrained": BertForMaskedLM,
        "GeneClassifier": BertForTokenClassification,
        "CellClassifier": BertForSequenceClassification,
        "MTLCellClassifier": BertForMaskedLM
    }

    model_class = model_classes.get(model_type)
    if not model_class:
        raise ValueError(f"Unknown model type: {model_type}")

    # Model loading
    model_args = {
        "pretrained_model_name_or_path": model_directory,
        "output_hidden_states": output_hidden_states,
        "output_attentions": False,
    }

    # classifier heads need the class count; the MLM head does not
    if model_type != "Pretrained":
        model_args["num_labels"] = num_classes

    if quantize_config:
        model_args["quantization_config"] = quantize_config

    # Load the model
    model = model_class.from_pretrained(**model_args)
    ###########################

    if mode == "eval":
        model.eval()

    # Handle device placement and PEFT
    if not quantize:
        # Only move non-quantized models; quantized weights manage placement
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
    elif peft_config:
        # Apply PEFT for quantized models (except MTLCellClassifier)
        model.enable_input_require_grads()
        model = get_peft_model(model, peft_config)

    return model
|
| 710 |
+
|
| 711 |
+
def get_default_train_args(model, classifier, data, output_dir):
    """Build the default TrainingArguments dict and default freeze-layer count.

    **Parameters**

    model : nn.Module
        | Loaded model; its encoder layer count selects tuned hyperparameters.
    classifier : str
        | "cell" (10 epochs, per-epoch eval) or "gene" (1 epoch, no eval).
    data : Dataset
        | Training data; its length sets the logging interval.
    output_dir : Path
        | Unused here beyond the signature; output_dir is set by the caller.

    Returns (training_args dict, default number of layers to freeze).
    """
    num_layers = quant_layers(model)
    freeze_layers_get = 0
    batch_size = 12
    if classifier == "cell":
        epochs = 10
        evaluation_strategy = "epoch"
        load_best_model_at_end = True
    else:
        epochs = 1
        evaluation_strategy = "no"
        load_best_model_at_end = False

    # learning-rate schedule tuned for the 6-layer model variant
    if num_layers == 6:
        default_training_args = {
            "learning_rate": 5e-5,
            "lr_scheduler_type": "linear",
            "warmup_steps": 500,
            "per_device_train_batch_size": batch_size,
            "per_device_eval_batch_size": batch_size,
        }
    else:
        default_training_args = {
            "per_device_train_batch_size": batch_size,
            "per_device_eval_batch_size": batch_size,
        }

    training_args = {
        "num_train_epochs": epochs,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": evaluation_strategy,
        # 8 logs per epoch; fix: np.floor returns a numpy float (and can be 0
        # for small datasets) but TrainingArguments requires a positive int
        "logging_steps": max(int(np.floor(len(data) / batch_size / 8)), 1),
        "save_strategy": "epoch",
        "group_by_length": False,
        "length_column_name": "length",
        "disable_tqdm": False,
        "weight_decay": 0.001,
        "load_best_model_at_end": load_best_model_at_end,
    }
    training_args.update(default_training_args)

    return training_args, freeze_layers_get
|
| 754 |
+
|
| 755 |
+
def quant_layers(model):
    """Return the number of encoder layers, inferred from parameter names.

    Parameter names of the form "...layer.<index>..." are scanned for the
    numeric index following "layer."; the count is the largest index + 1.
    """
    indices = [
        int(name.split("layer.")[1].split(".")[0])
        for name, _ in model.named_parameters()
        if "layer" in name
    ]
    return int(max(indices)) + 1
|
| 761 |
+
|
| 762 |
+
def compute_metrics(pred):
    """Compute accuracy and F1 scores from a Trainer prediction output.

    pred carries label_ids (true classes) and predictions (logits); the
    predicted class is the argmax over the last axis.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'macro_f1': f1_score(y_true, y_pred, average='macro'),
        'weighted_f1': f1_score(y_true, y_pred, average='weighted'),
    }
|
| 774 |
+
def evaluate_model(
    model,
    num_classes,
    id_class_dict,
    eval_data,
    predict=False,
    output_directory=None,
    output_prefix=None,
    classifier="cell",
):
    """
    Evaluate the fine-tuned model.

    **Parameters**

    model : nn.Module
        | Loaded fine-tuned model (e.g. trainer.model)
    num_classes : int
        | Number of classes for classifier
    id_class_dict : dict
        | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    eval_data : Dataset
        | Loaded evaluation .dataset input
    predict : bool
        | Whether or not to save eval predictions
    output_directory : Path
        | Path to directory where eval data will be saved
    output_prefix : str
        | Prefix for output files
    classifier : str
        | Classifier type: "cell" or "gene"
        | (fix: `classifier` was previously an undefined free variable in
        | the classifier_predict call)
    """

    ##### Evaluate the model #####
    labels = id_class_dict.keys()
    y_pred, y_true, logits_list = classifier_predict(
        model, classifier, eval_data, 100
    )
    conf_mat, macro_f1, acc, roc_metrics = get_metrics(
        y_pred, y_true, logits_list, num_classes, labels
    )
    if predict is True:
        # persist raw predictions alongside the metrics
        pred_dict = {
            "pred_ids": y_pred,
            "label_ids": y_true,
            "predictions": logits_list,
        }
        pred_dict_output_path = (
            Path(output_directory) / f"{output_prefix}_pred_dict"
        ).with_suffix(".pkl")
        with open(pred_dict_output_path, "wb") as f:
            pickle.dump(pred_dict, f)
    return {
        "conf_mat": conf_mat,
        "macro_f1": macro_f1,
        "acc": acc,
        "roc_metrics": roc_metrics,
    }
|
| 830 |
+
|
| 831 |
+
def classifier_predict(model, classifier_type, evalset, forward_batch_size):
    """Run batched forward passes over evalset and collect predictions.

    **Parameters**

    model : nn.Module
        | Fine-tuned classifier model; inputs are moved to CUDA, so the
        | model is assumed to be on a CUDA device.
    classifier_type : str
        | "gene" (labels in column "labels") or "cell" (column "label").
    evalset : Dataset
        | Evaluation dataset with "input_ids" and "length" columns.
    forward_batch_size : int
        | Number of examples per forward pass.

    Returns (y_pred, y_true, logits_list) with padded positions
    (label == -100) removed.
    """
    if classifier_type == "gene":
        label_name = "labels"
    elif classifier_type == "cell":
        label_name = "label"

    predict_logits = []
    predict_labels = []
    model.eval()

    # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible

    # global maximum sequence length, used to pad every batch consistently
    max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])

    disable_progress_bar()  # disable progress bar for preprocess_classifier_batch mapping
    for i in trange(0, evalset_len, forward_batch_size):
        max_range = min(i + forward_batch_size, evalset_len)
        batch_evalset = evalset.select([i for i in range(i, max_range)])
        padded_batch = preprocess_classifier_batch(
            batch_evalset, max_evalset_len, label_name
        )
        padded_batch.set_format(type="torch")

        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch[label_name]
        with torch.no_grad():
            outputs = model(
                input_ids=input_data_batch.to("cuda"),
                attention_mask=attn_msk_batch.to("cuda"),
                labels=label_batch.to("cuda"),
            )
            # collect logits/labels on CPU to avoid accumulating GPU memory
            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
            predict_labels += [torch.squeeze(label_batch.to("cpu"))]

    enable_progress_bar()
    # flatten batch outputs to one logits row per prediction unit
    logits_by_cell = torch.cat(predict_logits)
    last_dim = len(logits_by_cell.shape) - 1
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    # drop padded positions, which carry the ignore label -100
    logit_label_paired = [
        item
        for item in list(zip(all_logits.tolist(), all_labels.tolist()))
        if item[1] != -100
    ]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    return y_pred, y_true, logits_list
|
| 885 |
+
|
| 886 |
+
def find_largest_div(N, K):
    """Return the largest multiple of K that is less than or equal to N."""
    return N - (N % K)
|
| 892 |
+
def preprocess_classifier_batch(cell_batch, max_len, label_name):
    """Pad a batch of examples to max_len and build attention masks.

    Labels are padded with -100 (the ignore index, skipped downstream) and
    input_ids with the <pad> token id from the module-level gene_token_dict.
    NOTE(review): assumes gene_token_dict has been loaded at module scope
    before this is called — confirm.
    """
    if max_len is None:
        max_len = max([len(i) for i in cell_batch["input_ids"]])

    def pad_label_example(example):
        # labels padded with -100 so padded positions are excluded later
        example[label_name] = np.pad(
            example[label_name],
            (0, max_len - len(example["input_ids"])),
            mode="constant",
            constant_values=-100,
        )
        example["input_ids"] = np.pad(
            example["input_ids"],
            (0, max_len - len(example["input_ids"])),
            mode="constant",
            constant_values=gene_token_dict.get("<pad>"),
        )
        # attention mask: 1 for real tokens, 0 for padding positions
        example["attention_mask"] = (
            example["input_ids"] != gene_token_dict.get("<pad>")
        ).astype(int)
        return example

    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch
|
| 916 |
+
def vote(logit_list):
    """Return the argmax index of logit_list, or "tie" if the max is shared.

    **Parameters**

    logit_list : list of float
        | Logits for one prediction unit.
    """
    m = max(logit_list)
    # fix: removed dead statement `logit_list.index(m)` whose result was
    # discarded
    indices = [i for i, x in enumerate(logit_list) if x == m]
    if len(indices) > 1:
        # more than one position attains the maximum -> no unique winner
        return "tie"
    else:
        return indices[0]
|
| 924 |
+
def py_softmax(vector):
    """Return the softmax of a 1-D array-like as a numpy array.

    fix: subtract the maximum before exponentiating so large logits do not
    overflow np.exp; the result is mathematically unchanged.
    """
    shifted = np.asarray(vector) - np.max(vector)
    e = np.exp(shifted)
    return e / e.sum()
|
| 927 |
+
def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
    """Compute confusion matrix, macro F1, accuracy and (binary) ROC metrics.

    **Parameters**

    y_pred, y_true : list
        | Predicted and true class IDs.
    logits_list : list
        | Raw logits per prediction, used for the binary ROC curve.
    num_classes : int
        | Number of classes; ROC metrics only computed when this is 2.
    labels : iterable
        | Class IDs fixing the confusion-matrix row/column order.
    """
    conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
    macro_f1 = f1_score(y_true, y_pred, average="macro")
    acc = accuracy_score(y_true, y_pred)
    roc_metrics = None  # roc metrics not reported for multiclass
    if num_classes == 2:
        # positive-class probability obtained by softmaxing the logits
        y_score = [py_softmax(item)[1] for item in logits_list]
        fpr, tpr, _ = roc_curve(y_true, y_score)
        mean_fpr = np.linspace(0, 1, 100)
        # interpolate TPR onto a fixed FPR grid for cross-run averaging
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0
        tpr_wt = len(tpr)
        roc_auc = auc(fpr, tpr)
        roc_metrics = {
            "fpr": fpr,
            "tpr": tpr,
            "interp_tpr": interp_tpr,
            "auc": roc_auc,
            "tpr_wt": tpr_wt,
        }
    return conf_mat, macro_f1, acc, roc_metrics
|
| 948 |
+
def evaluate_saved_model(
    model_directory,
    id_class_dict_file,
    test_data_file,
    output_directory,
    output_prefix,
    predict=True,
    nproc=1,
    quantize=False,
):
    """
    Evaluate the fine-tuned model.

    **Parameters**

    model_directory : Path
        | Path to directory containing model
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    test_data_file : Path
        | Path to directory containing test .dataset
    output_directory : Path
        | Path to directory where eval data will be saved
    output_prefix : str
        | Prefix for output files
    predict : bool
        | Whether or not to save eval predictions
    nproc : int
        | Number of processes for data loading
        | (fix: was previously an undefined free variable)
    quantize : bool
        | Whether to load the model quantized
        | (fix: was previously an undefined free variable)
    """

    # load numerical id to class dictionary (id:class)
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)

    # get number of classes for classifier
    num_classes = get_num_classes(id_class_dict)

    # load previously filtered and prepared data
    test_data = load_and_filter(None, nproc, test_data_file)

    # load previously fine-tuned model
    model = load_model(
        "CellClassifier",
        num_classes,
        model_directory,
        "eval",
        quantize=quantize,
    )

    # evaluate the model
    # fix: output_prefix was hard-coded to "CellClassifier", ignoring the
    # caller-supplied prefix and causing pred_dict files to collide
    result = evaluate_model(
        model,
        num_classes,
        id_class_dict,
        test_data,
        predict=predict,
        output_directory=output_directory,
        output_prefix=output_prefix,
    )

    all_conf_mat_df = pd.DataFrame(
        result["conf_mat"],
        columns=id_class_dict.values(),
        index=id_class_dict.values(),
    )
    all_metrics = {
        "conf_matrix": all_conf_mat_df,
        "macro_f1": result["macro_f1"],
        "acc": result["acc"],
    }
    all_roc_metrics = None  # roc metrics not reported for multiclass

    if num_classes == 2:
        mean_fpr = np.linspace(0, 1, 100)
        mean_tpr = result["roc_metrics"]["interp_tpr"]
        all_roc_auc = result["roc_metrics"]["auc"]
        all_roc_metrics = {
            "mean_tpr": mean_tpr,
            "mean_fpr": mean_fpr,
            "all_roc_auc": all_roc_auc,
        }
    all_metrics["all_roc_metrics"] = all_roc_metrics
    test_metrics_output_path = (
        Path(output_directory) / f"{output_prefix}_test_metrics_dict"
    ).with_suffix(".pkl")
    with open(test_metrics_output_path, "wb") as f:
        pickle.dump(all_metrics, f)

    return all_metrics
|
| 1035 |
+
|
| 1036 |
+
def plot_conf_mat(
    conf_mat_dict,
    output_directory,
    output_prefix,
    custom_class_order=None,
):
    """
    Plot confusion matrix results of evaluating the fine-tuned model.

    **Parameters**

    conf_mat_dict : dict
        | Dictionary of model_name : confusion_matrix_DataFrame
        | (all_metrics["conf_matrix"] from self.validate)
    output_directory : Path
        | Path to directory where plots will be saved
    output_prefix : str
        | Prefix for output file
    custom_class_order : None, list
        | List of classes in custom order for plots.
        | Same order will be used for all models.
    """
    # one confusion-matrix plot per model in the dictionary
    for model_name, conf_mat_df in conf_mat_dict.items():
        plot_confusion_matrix(
            conf_mat_df,
            model_name,
            output_directory,
            output_prefix,
            custom_class_order,
        )
|
| 1067 |
+
def plot_confusion_matrix(
    conf_mat_df, title, output_dir, output_prefix, custom_class_order
):
    """Plot a row-normalized confusion matrix and save it as a PDF.

    conf_mat_df : DataFrame of raw confusion-matrix counts (classes as both
        index and columns).
    title : plot title (typically the model name).
    output_dir / output_prefix : location and name prefix of the saved PDF.
    custom_class_order : None or list fixing the class order of both axes.
    """
    fig = plt.figure()
    fig.set_size_inches(10, 10)
    sns.set(font_scale=1)
    sns.set_style("whitegrid", {"axes.grid": False})
    if custom_class_order is not None:
        conf_mat_df = conf_mat_df.reindex(
            index=custom_class_order, columns=custom_class_order
        )
    display_labels = generate_display_labels(conf_mat_df)
    # L1 row normalization converts raw counts into per-class fractions
    conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
    display = ConfusionMatrixDisplay(
        confusion_matrix=conf_mat, display_labels=display_labels
    )
    display.plot(cmap="Blues", values_format=".2g")
    plt.title(title)
    plt.show()

    output_file = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
    display.figure_.savefig(output_file, bbox_inches="tight")
|
| 1089 |
+
def generate_display_labels(conf_mat_df):
    """Build tick labels of the form "<class>\\nn=<row total>" for each row."""
    return [
        f"{label}\nn={conf_mat_df.iloc[row, :].sum():.0f}"
        for row, label in enumerate(conf_mat_df.index)
    ]
|
| 1096 |
+
|
| 1097 |
+
def plot_predictions(
    predictions_file,
    id_class_dict_file,
    title,
    output_directory,
    output_prefix,
    custom_class_order=None,
    kwargs_dict=None,
):
    """
    Plot prediction results of evaluating the fine-tuned model.

    **Parameters**

    predictions_file : path
        | Path of model predictions output to plot
        | (saved output from self.validate if predict_eval=True)
        | (or saved output from self.evaluate_saved_model)
    id_class_dict_file : Path
        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
        | (dictionary of format: numerical IDs: class_labels)
    title : str
        | Title for legend containing class labels.
    output_directory : Path
        | Path to directory where plots will be saved
    output_prefix : str
        | Prefix for output file
    custom_class_order : None, list
        | List of classes in custom order for plots.
        | Same order will be used for all models.
    kwargs_dict : None, dict
        | Dictionary of kwargs to pass to plotting function.
    """
    # load predictions
    with open(predictions_file, "rb") as f:
        predictions = pickle.load(f)

    # load numerical id to class dictionary (id:class)
    with open(id_class_dict_file, "rb") as f:
        id_class_dict = pickle.load(f)

    # the two supported prediction formats are distinguished by type:
    # a dict (evaluate_saved_model output) vs a Trainer PredictionOutput
    if isinstance(predictions, dict):
        if all(
            [
                key in predictions.keys()
                for key in ["pred_ids", "label_ids", "predictions"]
            ]
        ):
            # format is output from self.evaluate_saved_model
            predictions_logits = np.array(predictions["predictions"])
            true_ids = predictions["label_ids"]
    else:
        # format is output from self.validate if predict_eval=True
        predictions_logits = predictions.predictions
        true_ids = predictions.label_ids

    num_classes = len(id_class_dict.keys())
    num_predict_classes = predictions_logits.shape[1]
    assert num_classes == num_predict_classes
    classes = id_class_dict.values()
    true_labels = [id_class_dict[idx] for idx in true_ids]
    predictions_df = pd.DataFrame(predictions_logits, columns=classes)
    if custom_class_order is not None:
        predictions_df = predictions_df.reindex(columns=custom_class_order)
    predictions_df["true"] = true_labels
    custom_dict = dict(zip(classes, [i for i in range(len(classes))]))
    if custom_class_order is not None:
        custom_dict = dict(
            zip(custom_class_order, [i for i in range(len(custom_class_order))])
        )
    # sort rows so cells of the same true class are adjacent in the heatmap
    predictions_df = predictions_df.sort_values(
        by=["true"], key=lambda x: x.map(custom_dict)
    )

    plot_predictions_eu(
        predictions_df, title, output_directory, output_prefix, kwargs_dict
    )
|
| 1174 |
+
def plot_predictions_eu(predictions_df, title, output_dir, output_prefix, kwargs_dict):
    """Render a clustermap of prediction logits and save it as a PDF.

    predictions_df : DataFrame of logits (one row per cell, one column per
        class) with the true class label in column "true".
    title : legend title.
    output_dir / output_prefix : location and name prefix of the saved PDF.
    kwargs_dict : None or dict of overrides merged into the default
        sns.clustermap kwargs.
    """
    sns.set(font_scale=2)
    plt.figure(figsize=(10, 10), dpi=150)
    # colorbars annotate each row (true class) and column (predicted class)
    label_colors, label_color_dict = make_colorbar(predictions_df, "true")
    predictions_df = predictions_df.drop(columns=["true"])
    predict_colors_list = [label_color_dict[label] for label in predictions_df.columns]
    predict_label_list = [label for label in predictions_df.columns]
    predict_colors = pd.DataFrame(
        pd.Series(predict_colors_list, index=predict_label_list), columns=["predicted"]
    )

    default_kwargs_dict = {
        "row_cluster": False,
        "col_cluster": False,
        "row_colors": label_colors,
        "col_colors": predict_colors,
        "linewidths": 0,
        "xticklabels": False,
        "yticklabels": False,
        "center": 0,
        "cmap": "vlag",
    }

    if kwargs_dict is not None:
        default_kwargs_dict.update(kwargs_dict)
    g = sns.clustermap(predictions_df, **default_kwargs_dict)

    plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")

    # draw zero-size bars solely to populate the legend with class colors
    for label_color in list(label_color_dict.keys()):
        g.ax_col_dendrogram.bar(
            0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
        )

    g.ax_col_dendrogram.legend(
        title=f"{title}",
        loc="lower center",
        ncol=4,
        bbox_to_anchor=(0.5, 1),
        facecolor="white",
    )

    output_file = (Path(output_dir) / f"{output_prefix}_pred").with_suffix(".pdf")
    plt.savefig(output_file, bbox_inches="tight")
|
| 1218 |
+
def make_colorbar(embs_df, label):
    """Build row-annotation colors for a heatmap plus a class->color mapping.

    Returns a one-column DataFrame (column named `label`) giving a color per
    row of `embs_df`, and a dict mapping each class value to its color.
    """
    class_values = embs_df[label].tolist()

    # One color per row, keyed to that row's class.
    row_color_series = gen_heatmap_class_colors(class_values, embs_df)
    color_frame = pd.DataFrame(row_color_series, columns=[label])

    # Collapse per-row colors down to a single color per class.
    color_lookup = gen_heatmap_class_dict(class_values, color_frame[label])
    return color_frame, color_lookup
|
| 1227 |
+
def gen_heatmap_class_colors(labels, df):
    """Assign a cubehelix color to every label, one hue step per class.

    Returns a pandas Series aligned to ``df.index`` whose values are the
    palette colors for the corresponding entries of ``labels``.
    """
    # Distinct classes in first-seen order (Counter preserves insertion order).
    distinct = list(Counter(labels).keys())
    palette = sns.cubehelix_palette(
        len(distinct),
        light=0.9,
        dark=0.1,
        hue=1,
        reverse=True,
        start=1,
        rot=-2,
    )
    # NOTE(review): the lookup table is keyed by str(label); non-string labels
    # would map to NaN here, matching the original behavior.
    color_lut = dict(zip(map(str, distinct), palette))
    return pd.Series(labels, index=df.index).map(color_lut)
|
| 1240 |
+
def gen_heatmap_class_dict(classes, label_colors_series):
    """Map each class to the color of its first occurrence.

    ``classes`` pairs positionally with ``label_colors_series``; duplicate
    classes keep the color seen first, yielding one legend swatch per class.
    """
    pairs = pd.DataFrame(
        {"classes": classes, "color": label_colors_series}
    ).drop_duplicates(subset=["classes"])
    return {cls: color for cls, color in zip(pairs["classes"], pairs["color"])}
|
| 1246 |
+
|
| 1247 |
+
|
| 1248 |
+
# Repeat the full fine-tune / evaluate cycle `iter_step` times so the reported
# metrics capture run-to-run variability; each iteration writes to its own
# timestamped output directory.
for i in range(iter_step):

    model_directory = "model path"  # pretrained Geneformer checkpoint

    # Load the gene <-> token vocabulary produced at pretraining time.
    corpus_dir = "Pretrain_data"
    with open(corpus_dir + "/token_dictionary.pkl", "rb") as fp:
        gene_token_dict = pickle.load(fp)
    token_gene_dict = {v: k for k, v in gene_token_dict.items()}

    # Restrict the dataset to cardiomyocyte cells only.
    filter_data_dict = {"cell_type": ["Cardiomyocyte1", "Cardiomyocyte2", "Cardiomyocyte3"]}
    # Fine-tuning hyperparameters (values presumably from a prior search).
    training_args = {
        "num_train_epochs": 0.9,
        "learning_rate": 0.000804,
        "lr_scheduler_type": "polynomial",
        "warmup_steps": 1812,
        "weight_decay": 0.258828,
        "per_device_train_batch_size": 12,
        "seed": 73,
    }

    # Classify cells by disease state; keep all disease states.
    cell_state_dict = {"state_key": "disease", "states": "all"}
    classifier = "cell"
    filter_data = filter_data_dict
    split_sizes = {"train": 0.8, "valid": 0.1, "test": 0.1}
    train_size = split_sizes["train"]
    valid_size = split_sizes["valid"]
    oos_test_size = split_sizes["test"]
    max_ncells = None
    freeze_layers = 2
    num_crossval_splits = 1
    forward_batch_size = 200
    nproc = 16
    rare_threshold = 0
    quantize = None

    # Patient-level splits: an individual never appears in more than one
    # split, preventing leakage between train / validation / held-out test.
    train_ids = ["1447", "1600", "1462", "1558", "1300", "1508", "1358", "1678", "1561", "1304", "1610", "1430", "1472", "1707", "1726", "1504", "1425", "1617", "1631", "1735", "1582", "1722", "1622", "1630", "1290", "1479", "1371", "1549", "1515"]
    eval_ids = ["1422", "1510", "1539", "1606", "1702"]
    test_ids = ["1437", "1516", "1602", "1685", "1718"]

    train_test_id_split_dict = {"attr_key": "individual",
                                "train": train_ids + eval_ids,
                                "test": test_ids}
    train_valid_id_split_dict = {"attr_key": "individual",
                                 "train": train_ids,
                                 "eval": eval_ids}

    # Define output directory path (timestamped so iterations do not collide).
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.strftime('%X').replace(':','')}"
    datestamp_min = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_directory = "output path"

    if output_directory[-1:] != "/":  # add slash for dir if not present
        output_directory = output_directory + "/"
    output_dir = f"{output_directory}{datestamp}_geneformer_diseaseClassifier/"
    output_prefix = "cm_classifier_test"
    # os.makedirs alone is sufficient (and portable); the previous
    # subprocess.call(f"mkdir {output_dir}", shell=True) was redundant.
    os.makedirs(output_dir, exist_ok=True)

    # Label the data and write patient-level train/test splits to disk.
    prepare_data(
        input_data_file="example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset",
        output_directory=output_dir,
        output_prefix=output_prefix,
        split_id_dict=train_test_id_split_dict,
    )

    with open(f"{output_dir}/{output_prefix}_id_class_dict.pkl", "rb") as f:
        id_class_dict = pickle.load(f)
    class_id_dict = {v: k for k, v in id_class_dict.items()}

    num_classes = get_num_classes(id_class_dict)

    data = load_and_filter(None, nproc, f"{output_dir}/{output_prefix}_labeled_train.dataset")
    data = data.shuffle(seed=42)

    ##### (Cross-)validate the model #####
    results = []
    all_conf_mat = np.zeros((num_classes, num_classes))
    iteration_num = 1
    split_id_dict = train_valid_id_split_dict

    # Loop variable renamed from `i` to avoid shadowing the outer loop index.
    for split_idx in trange(num_crossval_splits):
        print(
            f"****** Validation split: {iteration_num}/{num_crossval_splits} ******\n"
        )
        ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
        if num_crossval_splits == 1:
            # single 1-eval_size:eval_size split
            if split_id_dict is not None:
                data_dict = dict()
                data_dict["train"] = filter_by_dict(
                    data,
                    {split_id_dict["attr_key"]: split_id_dict["train"]},
                    nproc,
                )
                data_dict["test"] = filter_by_dict(
                    data,
                    {split_id_dict["attr_key"]: split_id_dict["eval"]},
                    nproc,
                )
                train_data = data_dict["train"]
                eval_data = data_dict["test"]

        trainer = train_classifier(
            model_directory,
            num_classes,
            train_data,
            eval_data,
            ksplit_output_dir,
        )

        result = evaluate_model(
            trainer.model,
            num_classes,
            id_class_dict,
            eval_data,
            True,
            ksplit_output_dir,
            output_prefix,
        )
        results += [result]
        all_conf_mat = all_conf_mat + result["conf_mat"]
        iteration_num = iteration_num + 1

    # Aggregate confusion matrix and per-split metrics.
    all_conf_mat_df = pd.DataFrame(
        all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
    )
    all_metrics = {
        "conf_matrix": all_conf_mat_df,
        "macro_f1": [result["macro_f1"] for result in results],
        "acc": [result["acc"] for result in results],
    }
    all_roc_metrics = None  # roc metrics not reported for multiclass
    if num_classes == 2:
        mean_fpr = np.linspace(0, 1, 100)
        all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
        all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
        all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
        mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
            all_tpr, all_roc_auc, all_tpr_wt
        )
        all_roc_metrics = {
            "mean_tpr": mean_tpr,
            "mean_fpr": mean_fpr,
            "all_roc_auc": all_roc_auc,
            "roc_auc": roc_auc,
            "roc_auc_sd": roc_auc_sd,
        }
    all_metrics["all_roc_metrics"] = all_roc_metrics
    save_eval_output = True
    if save_eval_output is True:
        # Filename now derived from output_prefix (same value as the previous
        # hard-coded "cm_classifier_test", kept consistent with other outputs).
        eval_metrics_output_path = (
            Path(output_dir) / f"{output_prefix}_eval_metrics_dict"
        ).with_suffix(".pkl")
        with open(eval_metrics_output_path, "wb") as f:
            pickle.dump(all_metrics, f)

    datestamp_min = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':','')}"
    # Evaluate the fine-tuned model on the held-out (out-of-sample) patients.
    all_metrics_test = evaluate_saved_model(
        model_directory=f"{output_dir}/ksplit1/",
        id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
        test_data_file=f"{output_dir}/{output_prefix}_labeled_test.dataset",
        output_directory=output_dir,
        output_prefix=output_prefix,
    )

    macro_f1_list.append(all_metrics_test['macro_f1'])
    acc_list.append(all_metrics_test['acc'])


print("Macro F1: ", macro_f1_list)
print("Accuracy: ", acc_list)
|
Downstream_tasks/Classification/Cardio_ML.ipynb
ADDED
|
@@ -0,0 +1,1404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import os\n",
|
| 10 |
+
"import numpy as np\n",
|
| 11 |
+
"from tqdm.auto import tqdm, trange\n",
|
| 12 |
+
"GPU_NUMBER = [0]\n",
|
| 13 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([str(s) for s in GPU_NUMBER])\n",
|
| 14 |
+
"os.environ[\"NCCL_DEBUG\"] = \"INFO\"\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"# imports\n",
|
| 17 |
+
"from collections import Counter\n",
|
| 18 |
+
"import datetime\n",
|
| 19 |
+
"import pickle\n",
|
| 20 |
+
"import subprocess\n",
|
| 21 |
+
"import seaborn as sns; sns.set()\n",
|
| 22 |
+
"from datasets import load_from_disk\n",
|
| 23 |
+
"from sklearn.metrics import accuracy_score, f1_score\n",
|
| 24 |
+
"from transformers import BertForSequenceClassification, BertForMaskedLM, BertForTokenClassification\n",
|
| 25 |
+
"from transformers import Trainer\n",
|
| 26 |
+
"from transformers.training_args import TrainingArguments\n",
|
| 27 |
+
"import torch\n",
|
| 28 |
+
"import pandas as pd\n",
|
| 29 |
+
"from datasets.utils.logging import disable_progress_bar, enable_progress_bar\n",
|
| 30 |
+
"from sklearn import preprocessing\n",
|
| 31 |
+
"from sklearn.metrics import (\n",
|
| 32 |
+
" ConfusionMatrixDisplay,\n",
|
| 33 |
+
" accuracy_score,\n",
|
| 34 |
+
" auc,\n",
|
| 35 |
+
" confusion_matrix,\n",
|
| 36 |
+
" f1_score,\n",
|
| 37 |
+
" roc_curve,\n",
|
| 38 |
+
")\n",
|
| 39 |
+
"from pathlib import Path\n",
|
| 40 |
+
"import matplotlib.pyplot as plt\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"import sys\n",
|
| 43 |
+
"# sys.path.append('geneformer')\n",
|
| 44 |
+
"from geneformer import DataCollatorForCellClassification\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"macro_f1_list = []\n",
|
| 47 |
+
"acc_list = []\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"iter_step = 2\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"def prepare_data(\n",
|
| 52 |
+
" input_data_file,\n",
|
| 53 |
+
" output_directory,\n",
|
| 54 |
+
" output_prefix,\n",
|
| 55 |
+
" split_id_dict=None,\n",
|
| 56 |
+
" test_size=None,\n",
|
| 57 |
+
" attr_to_split=None,\n",
|
| 58 |
+
" attr_to_balance=None,\n",
|
| 59 |
+
" max_trials=100,\n",
|
| 60 |
+
" pval_threshold=0.1,\n",
|
| 61 |
+
"):\n",
|
| 62 |
+
" \"\"\"\n",
|
| 63 |
+
" Prepare data for cell state or gene classification.\n",
|
| 64 |
+
"\n",
|
| 65 |
+
" **Parameters**\n",
|
| 66 |
+
"\n",
|
| 67 |
+
" input_data_file : Path\n",
|
| 68 |
+
" | Path to directory containing .dataset input\n",
|
| 69 |
+
" output_directory : Path\n",
|
| 70 |
+
" | Path to directory where prepared data will be saved\n",
|
| 71 |
+
" output_prefix : str\n",
|
| 72 |
+
" | Prefix for output file\n",
|
| 73 |
+
" split_id_dict : None, dict\n",
|
| 74 |
+
" | Dictionary of IDs for train and test splits\n",
|
| 75 |
+
" | Three-item dictionary with keys: attr_key, train, test\n",
|
| 76 |
+
" | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits\n",
|
| 77 |
+
" | train: list of IDs in the attr_key column to include in the train split\n",
|
| 78 |
+
" | test: list of IDs in the attr_key column to include in the test split\n",
|
| 79 |
+
" | For example: {\"attr_key\": \"individual\",\n",
|
| 80 |
+
" | \"train\": [\"patient1\", \"patient2\", \"patient3\", \"patient4\"],\n",
|
| 81 |
+
" | \"test\": [\"patient5\", \"patient6\"]}\n",
|
| 82 |
+
" test_size : None, float\n",
|
| 83 |
+
" | Proportion of data to be saved separately and held out for test set\n",
|
| 84 |
+
" | (e.g. 0.2 if intending hold out 20%)\n",
|
| 85 |
+
" | If None, will inherit from split_sizes[\"test\"] from Classifier\n",
|
| 86 |
+
" | The training set will be further split to train / validation in self.validate\n",
|
| 87 |
+
" | Note: only available for CellClassifiers\n",
|
| 88 |
+
" attr_to_split : None, str\n",
|
| 89 |
+
" | Key for attribute on which to split data while balancing potential confounders\n",
|
| 90 |
+
" | e.g. \"patient_id\" for splitting by patient while balancing other characteristics\n",
|
| 91 |
+
" | Note: only available for CellClassifiers\n",
|
| 92 |
+
" attr_to_balance : None, list\n",
|
| 93 |
+
" | List of attribute keys on which to balance data while splitting on attr_to_split\n",
|
| 94 |
+
" | e.g. [\"age\", \"sex\"] for balancing these characteristics while splitting by patient\n",
|
| 95 |
+
" | Note: only available for CellClassifiers\n",
|
| 96 |
+
" max_trials : None, int\n",
|
| 97 |
+
" | Maximum number of trials of random splitting to try to achieve balanced other attributes\n",
|
| 98 |
+
" | If no split is found without significant (p<0.05) differences in other attributes, will select best\n",
|
| 99 |
+
" | Note: only available for CellClassifiers\n",
|
| 100 |
+
" pval_threshold : None, float\n",
|
| 101 |
+
" | P-value threshold to use for attribute balancing across splits\n",
|
| 102 |
+
" | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance\n",
|
| 103 |
+
" \"\"\"\n",
|
| 104 |
+
"\n",
|
| 105 |
+
" if test_size is None:\n",
|
| 106 |
+
" test_size = oos_test_size\n",
|
| 107 |
+
"\n",
|
| 108 |
+
" # prepare data and labels for classification\n",
|
| 109 |
+
" data = load_and_filter(filter_data, nproc, input_data_file)\n",
|
| 110 |
+
"\n",
|
| 111 |
+
" if classifier == \"cell\":\n",
|
| 112 |
+
" if \"label\" in data.features:\n",
|
| 113 |
+
" logger.error(\n",
|
| 114 |
+
" \"Column name 'label' must be reserved for class IDs. Please rename column.\"\n",
|
| 115 |
+
" )\n",
|
| 116 |
+
" raise\n",
|
| 117 |
+
" elif classifier == \"gene\":\n",
|
| 118 |
+
" if \"labels\" in data.features:\n",
|
| 119 |
+
" logger.error(\n",
|
| 120 |
+
" \"Column name 'labels' must be reserved for class IDs. Please rename column.\"\n",
|
| 121 |
+
" )\n",
|
| 122 |
+
" raise\n",
|
| 123 |
+
"\n",
|
| 124 |
+
" if (attr_to_split is not None) and (attr_to_balance is None):\n",
|
| 125 |
+
" logger.error(\n",
|
| 126 |
+
" \"Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined.\"\n",
|
| 127 |
+
" )\n",
|
| 128 |
+
" raise\n",
|
| 129 |
+
"\n",
|
| 130 |
+
" if not isinstance(attr_to_balance, list):\n",
|
| 131 |
+
" attr_to_balance = [attr_to_balance]\n",
|
| 132 |
+
"\n",
|
| 133 |
+
" if classifier == \"cell\":\n",
|
| 134 |
+
" # remove cell states representing < rare_threshold of cells\n",
|
| 135 |
+
" data = remove_rare(\n",
|
| 136 |
+
" data, rare_threshold, cell_state_dict[\"state_key\"], nproc\n",
|
| 137 |
+
" )\n",
|
| 138 |
+
" # downsample max cells and max per class\n",
|
| 139 |
+
" data = downsample_and_shuffle(\n",
|
| 140 |
+
" data, max_ncells, None, cell_state_dict\n",
|
| 141 |
+
" )\n",
|
| 142 |
+
" # rename cell state column to \"label\"\n",
|
| 143 |
+
" data = rename_cols(data, cell_state_dict[\"state_key\"])\n",
|
| 144 |
+
"\n",
|
| 145 |
+
" # convert classes to numerical labels and save as id_class_dict\n",
|
| 146 |
+
" # of note, will label all genes in gene_class_dict\n",
|
| 147 |
+
" # if (cross-)validating, genes will be relabeled in column \"labels\" for each split\n",
|
| 148 |
+
" # at the time of training with Classifier.validate\n",
|
| 149 |
+
" data, id_class_dict = label_classes(\n",
|
| 150 |
+
" classifier, data, None, nproc\n",
|
| 151 |
+
" )\n",
|
| 152 |
+
"\n",
|
| 153 |
+
" # save id_class_dict for future reference\n",
|
| 154 |
+
" id_class_output_path = (\n",
|
| 155 |
+
" Path(output_directory) / f\"{output_prefix}_id_class_dict\"\n",
|
| 156 |
+
" ).with_suffix(\".pkl\")\n",
|
| 157 |
+
" with open(id_class_output_path, \"wb\") as f:\n",
|
| 158 |
+
" pickle.dump(id_class_dict, f)\n",
|
| 159 |
+
"\n",
|
| 160 |
+
" if split_id_dict is not None:\n",
|
| 161 |
+
" data_dict = dict()\n",
|
| 162 |
+
" data_dict[\"train\"] = filter_by_dict(\n",
|
| 163 |
+
" data, {split_id_dict[\"attr_key\"]: split_id_dict[\"train\"]}, nproc\n",
|
| 164 |
+
" )\n",
|
| 165 |
+
" data_dict[\"test\"] = filter_by_dict(\n",
|
| 166 |
+
" data, {split_id_dict[\"attr_key\"]: split_id_dict[\"test\"]}, nproc\n",
|
| 167 |
+
" )\n",
|
| 168 |
+
" train_data_output_path = (\n",
|
| 169 |
+
" Path(output_directory) / f\"{output_prefix}_labeled_train\"\n",
|
| 170 |
+
" ).with_suffix(\".dataset\")\n",
|
| 171 |
+
" test_data_output_path = (\n",
|
| 172 |
+
" Path(output_directory) / f\"{output_prefix}_labeled_test\"\n",
|
| 173 |
+
" ).with_suffix(\".dataset\")\n",
|
| 174 |
+
" data_dict[\"train\"].save_to_disk(str(train_data_output_path))\n",
|
| 175 |
+
" data_dict[\"test\"].save_to_disk(str(test_data_output_path))\n",
|
| 176 |
+
" elif (test_size is not None) and (classifier == \"cell\"):\n",
|
| 177 |
+
" if 1 > test_size > 0:\n",
|
| 178 |
+
" if attr_to_split is None:\n",
|
| 179 |
+
" data_dict = data.train_test_split(\n",
|
| 180 |
+
" test_size=test_size,\n",
|
| 181 |
+
" stratify_by_column=None,\n",
|
| 182 |
+
" seed=42,\n",
|
| 183 |
+
" )\n",
|
| 184 |
+
" train_data_output_path = (\n",
|
| 185 |
+
" Path(output_directory) / f\"{output_prefix}_labeled_train\"\n",
|
| 186 |
+
" ).with_suffix(\".dataset\")\n",
|
| 187 |
+
" test_data_output_path = (\n",
|
| 188 |
+
" Path(output_directory) / f\"{output_prefix}_labeled_test\"\n",
|
| 189 |
+
" ).with_suffix(\".dataset\")\n",
|
| 190 |
+
" data_dict[\"train\"].save_to_disk(str(train_data_output_path))\n",
|
| 191 |
+
" data_dict[\"test\"].save_to_disk(str(test_data_output_path))\n",
|
| 192 |
+
" else:\n",
|
| 193 |
+
" data_dict, balance_df = cu.balance_attr_splits(\n",
|
| 194 |
+
" data,\n",
|
| 195 |
+
" attr_to_split,\n",
|
| 196 |
+
" attr_to_balance,\n",
|
| 197 |
+
" test_size,\n",
|
| 198 |
+
" max_trials,\n",
|
| 199 |
+
" pval_threshold,\n",
|
| 200 |
+
" cell_state_dict[\"state_key\"],\n",
|
| 201 |
+
" nproc,\n",
|
| 202 |
+
" )\n",
|
| 203 |
+
" balance_df.to_csv(\n",
|
| 204 |
+
" f\"{output_directory}/{output_prefix}_train_test_balance_df.csv\"\n",
|
| 205 |
+
" )\n",
|
| 206 |
+
" train_data_output_path = (\n",
|
| 207 |
+
" Path(output_directory) / f\"{output_prefix}_labeled_train\"\n",
|
| 208 |
+
" ).with_suffix(\".dataset\")\n",
|
| 209 |
+
" test_data_output_path = (\n",
|
| 210 |
+
" Path(output_directory) / f\"{output_prefix}_labeled_test\"\n",
|
| 211 |
+
" ).with_suffix(\".dataset\")\n",
|
| 212 |
+
" data_dict[\"train\"].save_to_disk(str(train_data_output_path))\n",
|
| 213 |
+
" data_dict[\"test\"].save_to_disk(str(test_data_output_path))\n",
|
| 214 |
+
" else:\n",
|
| 215 |
+
" data_output_path = (\n",
|
| 216 |
+
" Path(output_directory) / f\"{output_prefix}_labeled\"\n",
|
| 217 |
+
" ).with_suffix(\".dataset\")\n",
|
| 218 |
+
" data.save_to_disk(str(data_output_path))\n",
|
| 219 |
+
" print(data_output_path)\n",
|
| 220 |
+
" else:\n",
|
| 221 |
+
" data_output_path = (\n",
|
| 222 |
+
" Path(output_directory) / f\"{output_prefix}_labeled\"\n",
|
| 223 |
+
" ).with_suffix(\".dataset\")\n",
|
| 224 |
+
" data.save_to_disk(str(data_output_path))\n",
|
| 225 |
+
"\n",
|
| 226 |
+
"def load_and_filter(filter_data, nproc, input_data_file):\n",
|
| 227 |
+
" data = load_from_disk(input_data_file)\n",
|
| 228 |
+
" if filter_data is not None:\n",
|
| 229 |
+
" data = filter_by_dict(data, filter_data, nproc)\n",
|
| 230 |
+
" return data\n",
|
| 231 |
+
"# get number of classes for classifier\n",
|
| 232 |
+
"def get_num_classes(id_class_dict):\n",
|
| 233 |
+
" return len(set(id_class_dict.values()))\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"def filter_by_dict(data, filter_data, nproc):\n",
|
| 236 |
+
" for key, value in filter_data.items():\n",
|
| 237 |
+
"\n",
|
| 238 |
+
" def filter_data_by_criteria(example):\n",
|
| 239 |
+
" return example[key] in value\n",
|
| 240 |
+
"\n",
|
| 241 |
+
" data = data.filter(filter_data_by_criteria, num_proc=nproc)\n",
|
| 242 |
+
" if len(data) == 0:\n",
|
| 243 |
+
" logger.error(\"No cells remain after filtering. Check filtering criteria.\")\n",
|
| 244 |
+
" raise\n",
|
| 245 |
+
" return data\n",
|
| 246 |
+
"def remove_rare(data, rare_threshold, label, nproc):\n",
"    if rare_threshold > 0:\n",
"        total_cells = len(data)\n",
"        label_counter = Counter(data[label])\n",
"        nonrare_label_dict = {\n",
"            label: [k for k, v in label_counter.items() if (v / total_cells) > rare_threshold]\n",
"        }\n",
"        data = filter_by_dict(data, nonrare_label_dict, nproc)\n",
"    return data\n",
|
| 255 |
+
"def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):\n",
|
| 256 |
+
" data = data.shuffle(seed=42)\n",
|
| 257 |
+
" num_cells = len(data)\n",
|
| 258 |
+
" # if max number of cells is defined, then subsample to this max number\n",
|
| 259 |
+
" if max_ncells is not None:\n",
|
| 260 |
+
" if num_cells > max_ncells:\n",
|
| 261 |
+
" data = data.select([i for i in range(max_ncells)])\n",
|
| 262 |
+
" if max_ncells_per_class is not None:\n",
|
| 263 |
+
" class_labels = data[cell_state_dict[\"state_key\"]]\n",
|
| 264 |
+
" random.seed(42)\n",
|
| 265 |
+
" subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)\n",
|
| 266 |
+
" data = data.select(subsample_indices)\n",
|
| 267 |
+
" return data\n",
|
| 268 |
+
"def rename_cols(data, state_key):\n",
|
| 269 |
+
" data = data.rename_column(state_key, \"label\")\n",
|
| 270 |
+
" return data\n",
|
| 271 |
+
"def label_classes(classifier, data, gene_class_dict, nproc):\n",
|
| 272 |
+
" if classifier == \"cell\":\n",
|
| 273 |
+
" label_set = set(data[\"label\"])\n",
|
| 274 |
+
" elif classifier == \"gene\":\n",
|
| 275 |
+
" # remove cells without any of the target genes\n",
|
| 276 |
+
" def if_contains_label(example):\n",
|
| 277 |
+
" a = pu.flatten_list(gene_class_dict.values())\n",
|
| 278 |
+
" b = example[\"input_ids\"]\n",
|
| 279 |
+
" return not set(a).isdisjoint(b)\n",
|
| 280 |
+
"\n",
|
| 281 |
+
" data = data.filter(if_contains_label, num_proc=nproc)\n",
|
| 282 |
+
" label_set = gene_class_dict.keys()\n",
|
| 283 |
+
"\n",
|
| 284 |
+
" if len(data) == 0:\n",
|
| 285 |
+
" logger.error(\n",
|
| 286 |
+
" \"No cells remain after filtering for target genes. Check target gene list.\"\n",
|
| 287 |
+
" )\n",
|
| 288 |
+
" raise\n",
|
| 289 |
+
"\n",
|
| 290 |
+
" class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))\n",
|
| 291 |
+
" id_class_dict = {v: k for k, v in class_id_dict.items()}\n",
|
| 292 |
+
"\n",
|
| 293 |
+
" def classes_to_ids(example):\n",
|
| 294 |
+
" if classifier == \"cell\":\n",
|
| 295 |
+
" example[\"label\"] = class_id_dict[example[\"label\"]]\n",
|
| 296 |
+
" elif classifier == \"gene\":\n",
|
| 297 |
+
" example[\"labels\"] = label_gene_classes(\n",
|
| 298 |
+
" example, class_id_dict, gene_class_dict\n",
|
| 299 |
+
" )\n",
|
| 300 |
+
" return example\n",
|
| 301 |
+
"\n",
|
| 302 |
+
" data = data.map(classes_to_ids, num_proc=nproc)\n",
|
| 303 |
+
" return data, id_class_dict\n",
|
| 304 |
+
"\n",
|
| 305 |
+
"def train_classifier(\n",
|
| 306 |
+
" model_directory,\n",
|
| 307 |
+
" num_classes,\n",
|
| 308 |
+
" train_data,\n",
|
| 309 |
+
" eval_data,\n",
|
| 310 |
+
" output_directory,\n",
|
| 311 |
+
" predict=False,\n",
|
| 312 |
+
" classifier='cell',\n",
|
| 313 |
+
" no_eval=False,\n",
|
| 314 |
+
" quantize = False,\n",
|
| 315 |
+
" freeze_layers=2,\n",
|
| 316 |
+
" ):\n",
|
| 317 |
+
" \"\"\"\n",
|
| 318 |
+
" Fine-tune model for cell state or gene classification.\n",
|
| 319 |
+
"\n",
|
| 320 |
+
" **Parameters**\n",
|
| 321 |
+
"\n",
|
| 322 |
+
" model_directory : Path\n",
|
| 323 |
+
" | Path to directory containing model\n",
|
| 324 |
+
" num_classes : int\n",
|
| 325 |
+
" | Number of classes for classifier\n",
|
| 326 |
+
" train_data : Dataset\n",
|
| 327 |
+
" | Loaded training .dataset input\n",
|
| 328 |
+
" | For cell classifier, labels in column \"label\".\n",
|
| 329 |
+
" | For gene classifier, labels in column \"labels\".\n",
|
| 330 |
+
" eval_data : None, Dataset\n",
|
| 331 |
+
" | (Optional) Loaded evaluation .dataset input\n",
|
| 332 |
+
" | For cell classifier, labels in column \"label\".\n",
|
| 333 |
+
" | For gene classifier, labels in column \"labels\".\n",
|
| 334 |
+
" output_directory : Path\n",
|
| 335 |
+
" | Path to directory where fine-tuned model will be saved\n",
|
| 336 |
+
" predict : bool\n",
|
| 337 |
+
" | Whether or not to save eval predictions from trainer\n",
|
| 338 |
+
" \"\"\"\n",
|
| 339 |
+
"\n",
|
| 340 |
+
" ##### Validate and prepare data #####\n",
|
| 341 |
+
" train_data, eval_data = validate_and_clean_cols(\n",
|
| 342 |
+
" train_data, eval_data, classifier\n",
|
| 343 |
+
" )\n",
|
| 344 |
+
" \n",
|
| 345 |
+
" if (no_eval is True) and (eval_data is not None):\n",
|
| 346 |
+
" logger.warning(\n",
|
| 347 |
+
" \"no_eval set to True; model will be trained without evaluation.\"\n",
|
| 348 |
+
" )\n",
|
| 349 |
+
" eval_data = None\n",
|
| 350 |
+
"\n",
|
| 351 |
+
" if (classifier == \"gene\") and (predict is True):\n",
|
| 352 |
+
" logger.warning(\n",
|
| 353 |
+
" \"Predictions during training not currently available for gene classifiers; setting predict to False.\"\n",
|
| 354 |
+
" )\n",
|
| 355 |
+
" predict = False\n",
|
| 356 |
+
"\n",
|
| 357 |
+
" # ensure not overwriting previously saved model\n",
|
| 358 |
+
" saved_model_test = os.path.join(output_directory, \"pytorch_model.bin\")\n",
|
| 359 |
+
" if os.path.isfile(saved_model_test) is True:\n",
|
| 360 |
+
" logger.error(\"Model already saved to this designated output directory.\")\n",
|
| 361 |
+
" raise\n",
|
| 362 |
+
" # make output directory\n",
|
| 363 |
+
" # subprocess.call(f\"mkdir {output_directory}\", shell=True)\n",
|
| 364 |
+
" os.makedirs(output_dir, exist_ok=True)\n",
|
| 365 |
+
"\n",
|
| 366 |
+
" ##### Load model and training args #####\n",
|
| 367 |
+
" model = load_model(\n",
|
| 368 |
+
" \"CellClassifier\",\n",
|
| 369 |
+
" num_classes,\n",
|
| 370 |
+
" model_directory,\n",
|
| 371 |
+
" \"train\",\n",
|
| 372 |
+
" quantize=quantize,\n",
|
| 373 |
+
" )\n",
|
| 374 |
+
" def_training_args, def_freeze_layers = get_default_train_args(\n",
|
| 375 |
+
" model, classifier, train_data, output_directory\n",
|
| 376 |
+
" )\n",
|
| 377 |
+
"\n",
|
| 378 |
+
" if training_args is not None:\n",
|
| 379 |
+
" def_training_args.update(training_args)\n",
|
| 380 |
+
" logging_steps = round(\n",
|
| 381 |
+
" len(train_data) / def_training_args[\"per_device_train_batch_size\"] / 10\n",
|
| 382 |
+
" )\n",
|
| 383 |
+
" def_training_args[\"logging_steps\"] = logging_steps\n",
|
| 384 |
+
" def_training_args[\"output_dir\"] = output_directory\n",
|
| 385 |
+
" if eval_data is None:\n",
|
| 386 |
+
" def_training_args[\"evaluation_strategy\"] = \"no\"\n",
|
| 387 |
+
" def_training_args[\"load_best_model_at_end\"] = False\n",
|
| 388 |
+
" training_args_init = TrainingArguments(**def_training_args)\n",
|
| 389 |
+
"\n",
|
| 390 |
+
" if freeze_layers is not None:\n",
|
| 391 |
+
" def_freeze_layers = freeze_layers\n",
|
| 392 |
+
"\n",
|
| 393 |
+
" if def_freeze_layers > 0:\n",
|
| 394 |
+
" modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]\n",
|
| 395 |
+
" for module in modules_to_freeze:\n",
|
| 396 |
+
" for param in module.parameters():\n",
|
| 397 |
+
" param.requires_grad = False\n",
|
| 398 |
+
"\n",
|
| 399 |
+
" ##### Fine-tune the model #####\n",
|
| 400 |
+
" # define the data collator\n",
|
| 401 |
+
" if classifier == \"cell\":\n",
|
| 402 |
+
" data_collator = DataCollatorForCellClassification()\n",
|
| 403 |
+
" elif self.classifier == \"gene\":\n",
|
| 404 |
+
" data_collator = DataCollatorForGeneClassification()\n",
|
| 405 |
+
"\n",
|
| 406 |
+
" # create the trainer\n",
|
| 407 |
+
" trainer = Trainer(\n",
|
| 408 |
+
" model=model,\n",
|
| 409 |
+
" args=training_args_init,\n",
|
| 410 |
+
" data_collator=data_collator,\n",
|
| 411 |
+
" train_dataset=train_data,\n",
|
| 412 |
+
" eval_dataset=eval_data,\n",
|
| 413 |
+
" compute_metrics=compute_metrics,\n",
|
| 414 |
+
" )\n",
|
| 415 |
+
"\n",
|
| 416 |
+
" # train the classifier\n",
|
| 417 |
+
" trainer.train()\n",
|
| 418 |
+
" trainer.save_model(output_directory)\n",
|
| 419 |
+
" if predict is True:\n",
|
| 420 |
+
" # make eval predictions and save predictions and metrics\n",
|
| 421 |
+
" predictions = trainer.predict(eval_data)\n",
|
| 422 |
+
" prediction_output_path = f\"{output_directory}/predictions.pkl\"\n",
|
| 423 |
+
" with open(prediction_output_path, \"wb\") as f:\n",
|
| 424 |
+
" pickle.dump(predictions, f)\n",
|
| 425 |
+
" trainer.save_metrics(\"eval\", predictions.metrics)\n",
|
| 426 |
+
" return trainer\n",
|
| 427 |
+
" \n",
|
| 428 |
+
"def validate_and_clean_cols(train_data, eval_data, classifier):\n",
|
| 429 |
+
" # validate that data has expected label column and remove others\n",
|
| 430 |
+
" if classifier == \"cell\":\n",
|
| 431 |
+
" label_col = \"label\"\n",
|
| 432 |
+
" elif classifier == \"gene\":\n",
|
| 433 |
+
" label_col = \"labels\"\n",
|
| 434 |
+
"\n",
|
| 435 |
+
" cols_to_keep = [label_col] + [\"input_ids\", \"length\"]\n",
|
| 436 |
+
" if label_col not in train_data.column_names:\n",
|
| 437 |
+
" logger.error(f\"train_data must contain column {label_col} with class labels.\")\n",
|
| 438 |
+
" raise\n",
|
| 439 |
+
" else:\n",
|
| 440 |
+
" train_data = remove_cols(train_data, cols_to_keep)\n",
|
| 441 |
+
"\n",
|
| 442 |
+
" if eval_data is not None:\n",
|
| 443 |
+
" if label_col not in eval_data.column_names:\n",
|
| 444 |
+
" logger.error(\n",
|
| 445 |
+
" f\"eval_data must contain column {label_col} with class labels.\"\n",
|
| 446 |
+
" )\n",
|
| 447 |
+
" raise\n",
|
| 448 |
+
" else:\n",
|
| 449 |
+
" eval_data = remove_cols(eval_data, cols_to_keep)\n",
|
| 450 |
+
" return train_data, eval_data\n",
|
| 451 |
+
" \n",
|
| 452 |
+
"def remove_cols(data, cols_to_keep):\n",
|
| 453 |
+
" other_cols = list(data.features.keys())\n",
|
| 454 |
+
" other_cols = [ele for ele in other_cols if ele not in cols_to_keep]\n",
|
| 455 |
+
" data = data.remove_columns(other_cols)\n",
|
| 456 |
+
" return data\n",
|
| 457 |
+
"\n",
|
| 458 |
+
"def load_model(model_type, num_classes, model_directory, mode, quantize=False):\n",
|
| 459 |
+
" if model_type == \"MTLCellClassifier-Quantized\":\n",
|
| 460 |
+
" model_type = \"MTLCellClassifier\"\n",
|
| 461 |
+
" quantize = True\n",
|
| 462 |
+
"\n",
|
| 463 |
+
" output_hidden_states = (mode == \"eval\")\n",
|
| 464 |
+
"\n",
|
| 465 |
+
" # Quantization logic\n",
|
| 466 |
+
" if quantize:\n",
|
| 467 |
+
" if model_type == \"MTLCellClassifier\":\n",
|
| 468 |
+
" quantize_config = BitsAndBytesConfig(load_in_8bit=True)\n",
|
| 469 |
+
" peft_config = None\n",
|
| 470 |
+
" else:\n",
|
| 471 |
+
" quantize_config = BitsAndBytesConfig(\n",
|
| 472 |
+
" load_in_4bit=True,\n",
|
| 473 |
+
" bnb_4bit_use_double_quant=True,\n",
|
| 474 |
+
" bnb_4bit_quant_type=\"nf4\",\n",
|
| 475 |
+
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
|
| 476 |
+
" )\n",
|
| 477 |
+
" peft_config = LoraConfig(\n",
|
| 478 |
+
" lora_alpha=128,\n",
|
| 479 |
+
" lora_dropout=0.1,\n",
|
| 480 |
+
" r=64,\n",
|
| 481 |
+
" bias=\"none\",\n",
|
| 482 |
+
" task_type=\"TokenClassification\",\n",
|
| 483 |
+
" )\n",
|
| 484 |
+
" else:\n",
|
| 485 |
+
" quantize_config = None\n",
|
| 486 |
+
" peft_config = None\n",
|
| 487 |
+
"\n",
|
| 488 |
+
" # Model class selection\n",
|
| 489 |
+
" model_classes = {\n",
|
| 490 |
+
" \"Pretrained\": BertForMaskedLM,\n",
|
| 491 |
+
" \"GeneClassifier\": BertForTokenClassification,\n",
|
| 492 |
+
" \"CellClassifier\": BertForSequenceClassification,\n",
|
| 493 |
+
" \"MTLCellClassifier\": BertForMaskedLM\n",
|
| 494 |
+
" }\n",
|
| 495 |
+
"\n",
|
| 496 |
+
" model_class = model_classes.get(model_type)\n",
|
| 497 |
+
" if not model_class:\n",
|
| 498 |
+
" raise ValueError(f\"Unknown model type: {model_type}\")\n",
|
| 499 |
+
"\n",
|
| 500 |
+
" # Model loading\n",
|
| 501 |
+
" model_args = {\n",
|
| 502 |
+
" \"pretrained_model_name_or_path\": model_directory,\n",
|
| 503 |
+
" \"output_hidden_states\": output_hidden_states,\n",
|
| 504 |
+
" \"output_attentions\": False,\n",
|
| 505 |
+
" }\n",
|
| 506 |
+
"\n",
|
| 507 |
+
" if model_type != \"Pretrained\":\n",
|
| 508 |
+
" model_args[\"num_labels\"] = num_classes\n",
|
| 509 |
+
"\n",
|
| 510 |
+
" if quantize_config:\n",
|
| 511 |
+
" model_args[\"quantization_config\"] = quantize_config\n",
|
| 512 |
+
" \n",
|
| 513 |
+
" # Load the model\n",
|
| 514 |
+
" model = model_class.from_pretrained(**model_args)\n",
|
| 515 |
+
" ###########################\n",
|
| 516 |
+
"\n",
|
| 517 |
+
" if mode == \"eval\":\n",
|
| 518 |
+
" model.eval()\n",
|
| 519 |
+
"\n",
|
| 520 |
+
" # Handle device placement and PEFT\n",
|
| 521 |
+
" if not quantize:\n",
|
| 522 |
+
" # Only move non-quantized models\n",
|
| 523 |
+
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 524 |
+
" model = model.to(device)\n",
|
| 525 |
+
" elif peft_config:\n",
|
| 526 |
+
" # Apply PEFT for quantized models (except MTLCellClassifier)\n",
|
| 527 |
+
" model.enable_input_require_grads()\n",
|
| 528 |
+
" model = get_peft_model(model, peft_config)\n",
|
| 529 |
+
"\n",
|
| 530 |
+
" return model\n",
|
| 531 |
+
"\n",
|
| 532 |
+
"def get_default_train_args(model, classifier, data, output_dir):\n",
|
| 533 |
+
" num_layers = quant_layers(model)\n",
|
| 534 |
+
" freeze_layers_get = 0\n",
|
| 535 |
+
" batch_size = 12\n",
|
| 536 |
+
" if classifier == \"cell\":\n",
|
| 537 |
+
" epochs = 10\n",
|
| 538 |
+
" evaluation_strategy = \"epoch\"\n",
|
| 539 |
+
" load_best_model_at_end = True\n",
|
| 540 |
+
" else:\n",
|
| 541 |
+
" epochs = 1\n",
|
| 542 |
+
" evaluation_strategy = \"no\"\n",
|
| 543 |
+
" load_best_model_at_end = False\n",
|
| 544 |
+
"\n",
|
| 545 |
+
" if num_layers == 6:\n",
|
| 546 |
+
" default_training_args = {\n",
|
| 547 |
+
" \"learning_rate\": 5e-5,\n",
|
| 548 |
+
" \"lr_scheduler_type\": \"linear\",\n",
|
| 549 |
+
" \"warmup_steps\": 500,\n",
|
| 550 |
+
" \"per_device_train_batch_size\": batch_size,\n",
|
| 551 |
+
" \"per_device_eval_batch_size\": batch_size,\n",
|
| 552 |
+
" }\n",
|
| 553 |
+
" else:\n",
|
| 554 |
+
" default_training_args = {\n",
|
| 555 |
+
" \"per_device_train_batch_size\": batch_size,\n",
|
| 556 |
+
" \"per_device_eval_batch_size\": batch_size,\n",
|
| 557 |
+
" }\n",
|
| 558 |
+
"\n",
|
| 559 |
+
" training_args = {\n",
|
| 560 |
+
" \"num_train_epochs\": epochs,\n",
|
| 561 |
+
" \"do_train\": True,\n",
|
| 562 |
+
" \"do_eval\": True,\n",
|
| 563 |
+
" \"evaluation_strategy\": evaluation_strategy,\n",
|
| 564 |
+
" \"logging_steps\": np.floor(len(data) / batch_size / 8), # 8 evals per epoch\n",
|
| 565 |
+
" \"save_strategy\": \"epoch\",\n",
|
| 566 |
+
" \"group_by_length\": False,\n",
|
| 567 |
+
" \"length_column_name\": \"length\",\n",
|
| 568 |
+
" \"disable_tqdm\": False,\n",
|
| 569 |
+
" \"weight_decay\": 0.001,\n",
|
| 570 |
+
" \"load_best_model_at_end\": load_best_model_at_end,\n",
|
| 571 |
+
" }\n",
|
| 572 |
+
" training_args.update(default_training_args)\n",
|
| 573 |
+
"\n",
|
| 574 |
+
" return training_args, freeze_layers_get\n",
|
| 575 |
+
"\n",
|
| 576 |
+
"def quant_layers(model):\n",
|
| 577 |
+
" layer_nums = []\n",
|
| 578 |
+
" for name, parameter in model.named_parameters():\n",
|
| 579 |
+
" if \"layer\" in name:\n",
|
| 580 |
+
" layer_nums += [int(name.split(\"layer.\")[1].split(\".\")[0])]\n",
|
| 581 |
+
" return int(max(layer_nums)) + 1\n",
|
| 582 |
+
"\n",
|
| 583 |
+
"def compute_metrics(pred):\n",
|
| 584 |
+
" labels = pred.label_ids\n",
|
| 585 |
+
" preds = pred.predictions.argmax(-1)\n",
|
| 586 |
+
" # calculate accuracy and macro f1 using sklearn's function\n",
|
| 587 |
+
" acc = accuracy_score(labels, preds)\n",
|
| 588 |
+
" macro_f1 = f1_score(labels, preds, average='macro')\n",
|
| 589 |
+
" weighted_f1 = f1_score(labels, preds, average='weighted')\n",
|
| 590 |
+
" return {\n",
|
| 591 |
+
" 'accuracy': acc,\n",
|
| 592 |
+
" 'macro_f1': macro_f1,\n",
|
| 593 |
+
" 'weighted_f1': weighted_f1\n",
|
| 594 |
+
" }\n",
|
| 595 |
+
"def evaluate_model(\n",
|
| 596 |
+
" model,\n",
|
| 597 |
+
" num_classes,\n",
|
| 598 |
+
" id_class_dict,\n",
|
| 599 |
+
" eval_data,\n",
|
| 600 |
+
" predict=False,\n",
|
| 601 |
+
" output_directory=None,\n",
|
| 602 |
+
" output_prefix=None,\n",
|
| 603 |
+
"):\n",
|
| 604 |
+
" \"\"\"\n",
|
| 605 |
+
" Evaluate the fine-tuned model.\n",
|
| 606 |
+
"\n",
|
| 607 |
+
" **Parameters**\n",
|
| 608 |
+
"\n",
|
| 609 |
+
" model : nn.Module\n",
|
| 610 |
+
" | Loaded fine-tuned model (e.g. trainer.model)\n",
|
| 611 |
+
" num_classes : int\n",
|
| 612 |
+
" | Number of classes for classifier\n",
|
| 613 |
+
" id_class_dict : dict\n",
|
| 614 |
+
" | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data\n",
|
| 615 |
+
" | (dictionary of format: numerical IDs: class_labels)\n",
|
| 616 |
+
" eval_data : Dataset\n",
|
| 617 |
+
" | Loaded evaluation .dataset input\n",
|
| 618 |
+
" predict : bool\n",
|
| 619 |
+
" | Whether or not to save eval predictions\n",
|
| 620 |
+
" output_directory : Path\n",
|
| 621 |
+
" | Path to directory where eval data will be saved\n",
|
| 622 |
+
" output_prefix : str\n",
|
| 623 |
+
" | Prefix for output files\n",
|
| 624 |
+
" \"\"\"\n",
|
| 625 |
+
"\n",
|
| 626 |
+
" ##### Evaluate the model #####\n",
|
| 627 |
+
" labels = id_class_dict.keys()\n",
|
| 628 |
+
" y_pred, y_true, logits_list = classifier_predict(\n",
|
| 629 |
+
" model, classifier, eval_data, 100\n",
|
| 630 |
+
" )\n",
|
| 631 |
+
" conf_mat, macro_f1, acc, roc_metrics = get_metrics(\n",
|
| 632 |
+
" y_pred, y_true, logits_list, num_classes, labels\n",
|
| 633 |
+
" )\n",
|
| 634 |
+
" if predict is True:\n",
|
| 635 |
+
" pred_dict = {\n",
|
| 636 |
+
" \"pred_ids\": y_pred,\n",
|
| 637 |
+
" \"label_ids\": y_true,\n",
|
| 638 |
+
" \"predictions\": logits_list,\n",
|
| 639 |
+
" }\n",
|
| 640 |
+
" pred_dict_output_path = (\n",
|
| 641 |
+
" Path(output_directory) / f\"{output_prefix}_pred_dict\"\n",
|
| 642 |
+
" ).with_suffix(\".pkl\")\n",
|
| 643 |
+
" with open(pred_dict_output_path, \"wb\") as f:\n",
|
| 644 |
+
" pickle.dump(pred_dict, f)\n",
|
| 645 |
+
" return {\n",
|
| 646 |
+
" \"conf_mat\": conf_mat,\n",
|
| 647 |
+
" \"macro_f1\": macro_f1,\n",
|
| 648 |
+
" \"acc\": acc,\n",
|
| 649 |
+
" \"roc_metrics\": roc_metrics,\n",
|
| 650 |
+
" }\n",
|
| 651 |
+
" \n",
|
| 652 |
+
"def classifier_predict(model, classifier_type, evalset, forward_batch_size):\n",
|
| 653 |
+
" if classifier_type == \"gene\":\n",
|
| 654 |
+
" label_name = \"labels\"\n",
|
| 655 |
+
" elif classifier_type == \"cell\":\n",
|
| 656 |
+
" label_name = \"label\"\n",
|
| 657 |
+
"\n",
|
| 658 |
+
" predict_logits = []\n",
|
| 659 |
+
" predict_labels = []\n",
|
| 660 |
+
" model.eval()\n",
|
| 661 |
+
"\n",
|
| 662 |
+
" # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims\n",
|
| 663 |
+
" evalset_len = len(evalset)\n",
|
| 664 |
+
" max_divisible = find_largest_div(evalset_len, forward_batch_size)\n",
|
| 665 |
+
" if len(evalset) - max_divisible == 1:\n",
|
| 666 |
+
" evalset_len = max_divisible\n",
|
| 667 |
+
"\n",
|
| 668 |
+
" max_evalset_len = max(evalset.select([i for i in range(evalset_len)])[\"length\"])\n",
|
| 669 |
+
"\n",
|
| 670 |
+
" disable_progress_bar() # disable progress bar for preprocess_classifier_batch mapping\n",
|
| 671 |
+
" for i in trange(0, evalset_len, forward_batch_size):\n",
|
| 672 |
+
" max_range = min(i + forward_batch_size, evalset_len)\n",
|
| 673 |
+
" batch_evalset = evalset.select([i for i in range(i, max_range)])\n",
|
| 674 |
+
" padded_batch = preprocess_classifier_batch(\n",
|
| 675 |
+
" batch_evalset, max_evalset_len, label_name\n",
|
| 676 |
+
" )\n",
|
| 677 |
+
" padded_batch.set_format(type=\"torch\")\n",
|
| 678 |
+
"\n",
|
| 679 |
+
" input_data_batch = padded_batch[\"input_ids\"]\n",
|
| 680 |
+
" attn_msk_batch = padded_batch[\"attention_mask\"]\n",
|
| 681 |
+
" label_batch = padded_batch[label_name]\n",
|
| 682 |
+
" with torch.no_grad():\n",
|
| 683 |
+
" outputs = model(\n",
|
| 684 |
+
" input_ids=input_data_batch.to(\"cuda\"),\n",
|
| 685 |
+
" attention_mask=attn_msk_batch.to(\"cuda\"),\n",
|
| 686 |
+
" labels=label_batch.to(\"cuda\"),\n",
|
| 687 |
+
" )\n",
|
| 688 |
+
" predict_logits += [torch.squeeze(outputs.logits.to(\"cpu\"))]\n",
|
| 689 |
+
" predict_labels += [torch.squeeze(label_batch.to(\"cpu\"))]\n",
|
| 690 |
+
"\n",
|
| 691 |
+
" enable_progress_bar()\n",
|
| 692 |
+
" logits_by_cell = torch.cat(predict_logits)\n",
|
| 693 |
+
" last_dim = len(logits_by_cell.shape) - 1\n",
|
| 694 |
+
" all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])\n",
|
| 695 |
+
" labels_by_cell = torch.cat(predict_labels)\n",
|
| 696 |
+
" all_labels = torch.flatten(labels_by_cell)\n",
|
| 697 |
+
" logit_label_paired = [\n",
|
| 698 |
+
" item\n",
|
| 699 |
+
" for item in list(zip(all_logits.tolist(), all_labels.tolist()))\n",
|
| 700 |
+
" if item[1] != -100\n",
|
| 701 |
+
" ]\n",
|
| 702 |
+
" y_pred = [vote(item[0]) for item in logit_label_paired]\n",
|
| 703 |
+
" y_true = [item[1] for item in logit_label_paired]\n",
|
| 704 |
+
" logits_list = [item[0] for item in logit_label_paired]\n",
|
| 705 |
+
" return y_pred, y_true, logits_list\n",
|
| 706 |
+
"\n",
|
| 707 |
+
"def find_largest_div(N, K):\n",
|
| 708 |
+
" rem = N % K\n",
|
| 709 |
+
" if rem == 0:\n",
|
| 710 |
+
" return N\n",
|
| 711 |
+
" else:\n",
|
| 712 |
+
" return N - rem\n",
|
| 713 |
+
"def preprocess_classifier_batch(cell_batch, max_len, label_name):\n",
|
| 714 |
+
" if max_len is None:\n",
|
| 715 |
+
" max_len = max([len(i) for i in cell_batch[\"input_ids\"]])\n",
|
| 716 |
+
"\n",
|
| 717 |
+
" def pad_label_example(example):\n",
|
| 718 |
+
" example[label_name] = np.pad(\n",
|
| 719 |
+
" example[label_name],\n",
|
| 720 |
+
" (0, max_len - len(example[\"input_ids\"])),\n",
|
| 721 |
+
" mode=\"constant\",\n",
|
| 722 |
+
" constant_values=-100,\n",
|
| 723 |
+
" )\n",
|
| 724 |
+
" example[\"input_ids\"] = np.pad(\n",
|
| 725 |
+
" example[\"input_ids\"],\n",
|
| 726 |
+
" (0, max_len - len(example[\"input_ids\"])),\n",
|
| 727 |
+
" mode=\"constant\",\n",
|
| 728 |
+
" constant_values=gene_token_dict.get(\"<pad>\"),\n",
|
| 729 |
+
" )\n",
|
| 730 |
+
" example[\"attention_mask\"] = (\n",
|
| 731 |
+
" example[\"input_ids\"] != gene_token_dict.get(\"<pad>\")\n",
|
| 732 |
+
" ).astype(int)\n",
|
| 733 |
+
" return example\n",
|
| 734 |
+
"\n",
|
| 735 |
+
" padded_batch = cell_batch.map(pad_label_example)\n",
|
| 736 |
+
" return padded_batch\n",
|
| 737 |
+
"def vote(logit_list):\n",
|
| 738 |
+
" m = max(logit_list)\n",
|
| 739 |
+
" logit_list.index(m)\n",
|
| 740 |
+
" indices = [i for i, x in enumerate(logit_list) if x == m]\n",
|
| 741 |
+
" if len(indices) > 1:\n",
|
| 742 |
+
" return \"tie\"\n",
|
| 743 |
+
" else:\n",
|
| 744 |
+
" return indices[0]\n",
|
| 745 |
+
"def py_softmax(vector):\n",
|
| 746 |
+
" e = np.exp(vector)\n",
|
| 747 |
+
" return e / e.sum()\n",
|
| 748 |
+
"def get_metrics(y_pred, y_true, logits_list, num_classes, labels):\n",
|
| 749 |
+
" conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))\n",
|
| 750 |
+
" macro_f1 = f1_score(y_true, y_pred, average=\"macro\")\n",
|
| 751 |
+
" acc = accuracy_score(y_true, y_pred)\n",
|
| 752 |
+
" roc_metrics = None # roc metrics not reported for multiclass\n",
|
| 753 |
+
" if num_classes == 2:\n",
|
| 754 |
+
" y_score = [py_softmax(item)[1] for item in logits_list]\n",
|
| 755 |
+
" fpr, tpr, _ = roc_curve(y_true, y_score)\n",
|
| 756 |
+
" mean_fpr = np.linspace(0, 1, 100)\n",
|
| 757 |
+
" interp_tpr = np.interp(mean_fpr, fpr, tpr)\n",
|
| 758 |
+
" interp_tpr[0] = 0.0\n",
|
| 759 |
+
" tpr_wt = len(tpr)\n",
|
| 760 |
+
" roc_auc = auc(fpr, tpr)\n",
|
| 761 |
+
" roc_metrics = {\n",
|
| 762 |
+
" \"fpr\": fpr,\n",
|
| 763 |
+
" \"tpr\": tpr,\n",
|
| 764 |
+
" \"interp_tpr\": interp_tpr,\n",
|
| 765 |
+
" \"auc\": roc_auc,\n",
|
| 766 |
+
" \"tpr_wt\": tpr_wt,\n",
|
| 767 |
+
" }\n",
|
| 768 |
+
" return conf_mat, macro_f1, acc, roc_metrics\n",
|
| 769 |
+
"def evaluate_saved_model(\n",
|
| 770 |
+
" model_directory,\n",
|
| 771 |
+
" id_class_dict_file,\n",
|
| 772 |
+
" test_data_file,\n",
|
| 773 |
+
" output_directory,\n",
|
| 774 |
+
" output_prefix,\n",
|
| 775 |
+
" predict=True,\n",
|
| 776 |
+
"):\n",
|
| 777 |
+
" \"\"\"\n",
|
| 778 |
+
" Evaluate the fine-tuned model.\n",
|
| 779 |
+
"\n",
|
| 780 |
+
" **Parameters**\n",
|
| 781 |
+
"\n",
|
| 782 |
+
" model_directory : Path\n",
|
| 783 |
+
" | Path to directory containing model\n",
|
| 784 |
+
" id_class_dict_file : Path\n",
|
| 785 |
+
" | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data\n",
|
| 786 |
+
" | (dictionary of format: numerical IDs: class_labels)\n",
|
| 787 |
+
" test_data_file : Path\n",
|
| 788 |
+
" | Path to directory containing test .dataset\n",
|
| 789 |
+
" output_directory : Path\n",
|
| 790 |
+
" | Path to directory where eval data will be saved\n",
|
| 791 |
+
" output_prefix : str\n",
|
| 792 |
+
" | Prefix for output files\n",
|
| 793 |
+
" predict : bool\n",
|
| 794 |
+
" | Whether or not to save eval predictions\n",
|
| 795 |
+
" \"\"\"\n",
|
| 796 |
+
"\n",
|
| 797 |
+
" # load numerical id to class dictionary (id:class)\n",
|
| 798 |
+
" with open(id_class_dict_file, \"rb\") as f:\n",
|
| 799 |
+
" id_class_dict = pickle.load(f)\n",
|
| 800 |
+
"\n",
|
| 801 |
+
" # get number of classes for classifier\n",
|
| 802 |
+
" num_classes = get_num_classes(id_class_dict)\n",
|
| 803 |
+
"\n",
|
| 804 |
+
" # load previously filtered and prepared data\n",
|
| 805 |
+
" test_data = load_and_filter(None, nproc, test_data_file)\n",
|
| 806 |
+
"\n",
|
| 807 |
+
" # load previously fine-tuned model\n",
|
| 808 |
+
" model = load_model(\n",
|
| 809 |
+
" \"CellClassifier\",\n",
|
| 810 |
+
" num_classes,\n",
|
| 811 |
+
" model_directory,\n",
|
| 812 |
+
" \"eval\",\n",
|
| 813 |
+
" quantize=quantize,\n",
|
| 814 |
+
" )\n",
|
| 815 |
+
"\n",
|
| 816 |
+
" # evaluate the model\n",
|
| 817 |
+
" result = evaluate_model(\n",
|
| 818 |
+
" model,\n",
|
| 819 |
+
" num_classes,\n",
|
| 820 |
+
" id_class_dict,\n",
|
| 821 |
+
" test_data,\n",
|
| 822 |
+
" predict=predict,\n",
|
| 823 |
+
" output_directory=output_directory,\n",
|
| 824 |
+
" output_prefix=\"CellClassifier\",\n",
|
| 825 |
+
" )\n",
|
| 826 |
+
"\n",
|
| 827 |
+
" all_conf_mat_df = pd.DataFrame(\n",
|
| 828 |
+
" result[\"conf_mat\"],\n",
|
| 829 |
+
" columns=id_class_dict.values(),\n",
|
| 830 |
+
" index=id_class_dict.values(),\n",
|
| 831 |
+
" )\n",
|
| 832 |
+
" all_metrics = {\n",
|
| 833 |
+
" \"conf_matrix\": all_conf_mat_df,\n",
|
| 834 |
+
" \"macro_f1\": result[\"macro_f1\"],\n",
|
| 835 |
+
" \"acc\": result[\"acc\"],\n",
|
| 836 |
+
" }\n",
|
| 837 |
+
" all_roc_metrics = None # roc metrics not reported for multiclass\n",
|
| 838 |
+
"\n",
|
| 839 |
+
" if num_classes == 2:\n",
|
| 840 |
+
" mean_fpr = np.linspace(0, 1, 100)\n",
|
| 841 |
+
" mean_tpr = result[\"roc_metrics\"][\"interp_tpr\"]\n",
|
| 842 |
+
" all_roc_auc = result[\"roc_metrics\"][\"auc\"]\n",
|
| 843 |
+
" all_roc_metrics = {\n",
|
| 844 |
+
" \"mean_tpr\": mean_tpr,\n",
|
| 845 |
+
" \"mean_fpr\": mean_fpr,\n",
|
| 846 |
+
" \"all_roc_auc\": all_roc_auc,\n",
|
| 847 |
+
" }\n",
|
| 848 |
+
" all_metrics[\"all_roc_metrics\"] = all_roc_metrics\n",
|
| 849 |
+
" test_metrics_output_path = (\n",
|
| 850 |
+
" Path(output_directory) / f\"{output_prefix}_test_metrics_dict\"\n",
|
| 851 |
+
" ).with_suffix(\".pkl\")\n",
|
| 852 |
+
" with open(test_metrics_output_path, \"wb\") as f:\n",
|
| 853 |
+
" pickle.dump(all_metrics, f)\n",
|
| 854 |
+
"\n",
|
| 855 |
+
" return all_metrics\n",
|
| 856 |
+
"\n",
|
| 857 |
+
"def plot_conf_mat(\n",
|
| 858 |
+
" conf_mat_dict,\n",
|
| 859 |
+
" output_directory,\n",
|
| 860 |
+
" output_prefix,\n",
|
| 861 |
+
" custom_class_order=None,\n",
|
| 862 |
+
"):\n",
|
| 863 |
+
" \"\"\"\n",
|
| 864 |
+
" Plot confusion matrix results of evaluating the fine-tuned model.\n",
|
| 865 |
+
"\n",
|
| 866 |
+
" **Parameters**\n",
|
| 867 |
+
"\n",
|
| 868 |
+
" conf_mat_dict : dict\n",
|
| 869 |
+
" | Dictionary of model_name : confusion_matrix_DataFrame\n",
|
| 870 |
+
" | (all_metrics[\"conf_matrix\"] from self.validate)\n",
|
| 871 |
+
" output_directory : Path\n",
|
| 872 |
+
" | Path to directory where plots will be saved\n",
|
| 873 |
+
" output_prefix : str\n",
|
| 874 |
+
" | Prefix for output file\n",
|
| 875 |
+
" custom_class_order : None, list\n",
|
| 876 |
+
" | List of classes in custom order for plots.\n",
|
| 877 |
+
" | Same order will be used for all models.\n",
|
| 878 |
+
" \"\"\"\n",
|
| 879 |
+
"\n",
|
| 880 |
+
" for model_name in conf_mat_dict.keys():\n",
|
| 881 |
+
" plot_confusion_matrix(\n",
|
| 882 |
+
" conf_mat_dict[model_name],\n",
|
| 883 |
+
" model_name,\n",
|
| 884 |
+
" output_directory,\n",
|
| 885 |
+
" output_prefix,\n",
|
| 886 |
+
" custom_class_order,\n",
|
| 887 |
+
" )\n",
|
| 888 |
+
"def plot_confusion_matrix(\n",
|
| 889 |
+
" conf_mat_df, title, output_dir, output_prefix, custom_class_order\n",
|
| 890 |
+
"):\n",
|
| 891 |
+
" fig = plt.figure()\n",
|
| 892 |
+
" fig.set_size_inches(10, 10)\n",
|
| 893 |
+
" sns.set(font_scale=1)\n",
|
| 894 |
+
" sns.set_style(\"whitegrid\", {\"axes.grid\": False})\n",
|
| 895 |
+
" if custom_class_order is not None:\n",
|
| 896 |
+
" conf_mat_df = conf_mat_df.reindex(\n",
|
| 897 |
+
" index=custom_class_order, columns=custom_class_order\n",
|
| 898 |
+
" )\n",
|
| 899 |
+
" display_labels = generate_display_labels(conf_mat_df)\n",
|
| 900 |
+
" conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm=\"l1\")\n",
|
| 901 |
+
" display = ConfusionMatrixDisplay(\n",
|
| 902 |
+
" confusion_matrix=conf_mat, display_labels=display_labels\n",
|
| 903 |
+
" )\n",
|
| 904 |
+
" display.plot(cmap=\"Blues\", values_format=\".2g\")\n",
|
| 905 |
+
" plt.title(title)\n",
|
| 906 |
+
" plt.show()\n",
|
| 907 |
+
"\n",
|
| 908 |
+
" output_file = (Path(output_dir) / f\"{output_prefix}_conf_mat\").with_suffix(\".pdf\")\n",
|
| 909 |
+
" display.figure_.savefig(output_file, bbox_inches=\"tight\")\n",
|
| 910 |
+
"def generate_display_labels(conf_mat_df):\n",
|
| 911 |
+
" display_labels = []\n",
|
| 912 |
+
" i = 0\n",
|
| 913 |
+
" for label in conf_mat_df.index:\n",
|
| 914 |
+
" display_labels += [f\"{label}\\nn={conf_mat_df.iloc[i,:].sum():.0f}\"]\n",
|
| 915 |
+
" i = i + 1\n",
|
| 916 |
+
" return display_labels\n",
|
| 917 |
+
"\n",
|
| 918 |
+
"def plot_predictions(\n",
|
| 919 |
+
" predictions_file,\n",
|
| 920 |
+
" id_class_dict_file,\n",
|
| 921 |
+
" title,\n",
|
| 922 |
+
" output_directory,\n",
|
| 923 |
+
" output_prefix,\n",
|
| 924 |
+
" custom_class_order=None,\n",
|
| 925 |
+
" kwargs_dict=None,\n",
|
| 926 |
+
"):\n",
|
| 927 |
+
" \"\"\"\n",
|
| 928 |
+
" Plot prediction results of evaluating the fine-tuned model.\n",
|
| 929 |
+
"\n",
|
| 930 |
+
" **Parameters**\n",
|
| 931 |
+
"\n",
|
| 932 |
+
" predictions_file : path\n",
|
| 933 |
+
" | Path of model predictions output to plot\n",
|
| 934 |
+
" | (saved output from self.validate if predict_eval=True)\n",
|
| 935 |
+
" | (or saved output from self.evaluate_saved_model)\n",
|
| 936 |
+
" id_class_dict_file : Path\n",
|
| 937 |
+
" | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data\n",
|
| 938 |
+
" | (dictionary of format: numerical IDs: class_labels)\n",
|
| 939 |
+
" title : str\n",
|
| 940 |
+
" | Title for legend containing class labels.\n",
|
| 941 |
+
" output_directory : Path\n",
|
| 942 |
+
" | Path to directory where plots will be saved\n",
|
| 943 |
+
" output_prefix : str\n",
|
| 944 |
+
" | Prefix for output file\n",
|
| 945 |
+
" custom_class_order : None, list\n",
|
| 946 |
+
" | List of classes in custom order for plots.\n",
|
| 947 |
+
" | Same order will be used for all models.\n",
|
| 948 |
+
" kwargs_dict : None, dict\n",
|
| 949 |
+
" | Dictionary of kwargs to pass to plotting function.\n",
|
| 950 |
+
" \"\"\"\n",
|
| 951 |
+
" # load predictions\n",
|
| 952 |
+
" with open(predictions_file, \"rb\") as f:\n",
|
| 953 |
+
" predictions = pickle.load(f)\n",
|
| 954 |
+
"\n",
|
| 955 |
+
" # load numerical id to class dictionary (id:class)\n",
|
| 956 |
+
" with open(id_class_dict_file, \"rb\") as f:\n",
|
| 957 |
+
" id_class_dict = pickle.load(f)\n",
|
| 958 |
+
"\n",
|
| 959 |
+
" if isinstance(predictions, dict):\n",
|
| 960 |
+
" if all(\n",
|
| 961 |
+
" [\n",
|
| 962 |
+
" key in predictions.keys()\n",
|
| 963 |
+
" for key in [\"pred_ids\", \"label_ids\", \"predictions\"]\n",
|
| 964 |
+
" ]\n",
|
| 965 |
+
" ):\n",
|
| 966 |
+
" # format is output from self.evaluate_saved_model\n",
|
| 967 |
+
" predictions_logits = np.array(predictions[\"predictions\"])\n",
|
| 968 |
+
" true_ids = predictions[\"label_ids\"]\n",
|
| 969 |
+
" else:\n",
|
| 970 |
+
" # format is output from self.validate if predict_eval=True\n",
|
| 971 |
+
" predictions_logits = predictions.predictions\n",
|
| 972 |
+
" true_ids = predictions.label_ids\n",
|
| 973 |
+
"\n",
|
| 974 |
+
" num_classes = len(id_class_dict.keys())\n",
|
| 975 |
+
" num_predict_classes = predictions_logits.shape[1]\n",
|
| 976 |
+
" assert num_classes == num_predict_classes\n",
|
| 977 |
+
" classes = id_class_dict.values()\n",
|
| 978 |
+
" true_labels = [id_class_dict[idx] for idx in true_ids]\n",
|
| 979 |
+
" predictions_df = pd.DataFrame(predictions_logits, columns=classes)\n",
|
| 980 |
+
" if custom_class_order is not None:\n",
|
| 981 |
+
" predictions_df = predictions_df.reindex(columns=custom_class_order)\n",
|
| 982 |
+
" predictions_df[\"true\"] = true_labels\n",
|
| 983 |
+
" custom_dict = dict(zip(classes, [i for i in range(len(classes))]))\n",
|
| 984 |
+
" if custom_class_order is not None:\n",
|
| 985 |
+
" custom_dict = dict(\n",
|
| 986 |
+
" zip(custom_class_order, [i for i in range(len(custom_class_order))])\n",
|
| 987 |
+
" )\n",
|
| 988 |
+
" predictions_df = predictions_df.sort_values(\n",
|
| 989 |
+
" by=[\"true\"], key=lambda x: x.map(custom_dict)\n",
|
| 990 |
+
" )\n",
|
| 991 |
+
"\n",
|
| 992 |
+
" plot_predictions_eu(\n",
|
| 993 |
+
" predictions_df, title, output_directory, output_prefix, kwargs_dict\n",
|
| 994 |
+
" )\n",
|
| 995 |
+
"def plot_predictions_eu(predictions_df, title, output_dir, output_prefix, kwargs_dict):\n",
|
| 996 |
+
" sns.set(font_scale=2)\n",
|
| 997 |
+
" plt.figure(figsize=(10, 10), dpi=150)\n",
|
| 998 |
+
" label_colors, label_color_dict = make_colorbar(predictions_df, \"true\")\n",
|
| 999 |
+
" predictions_df = predictions_df.drop(columns=[\"true\"])\n",
|
| 1000 |
+
" predict_colors_list = [label_color_dict[label] for label in predictions_df.columns]\n",
|
| 1001 |
+
" predict_label_list = [label for label in predictions_df.columns]\n",
|
| 1002 |
+
" predict_colors = pd.DataFrame(\n",
|
| 1003 |
+
" pd.Series(predict_colors_list, index=predict_label_list), columns=[\"predicted\"]\n",
|
| 1004 |
+
" )\n",
|
| 1005 |
+
"\n",
|
| 1006 |
+
" default_kwargs_dict = {\n",
|
| 1007 |
+
" \"row_cluster\": False,\n",
|
| 1008 |
+
" \"col_cluster\": False,\n",
|
| 1009 |
+
" \"row_colors\": label_colors,\n",
|
| 1010 |
+
" \"col_colors\": predict_colors,\n",
|
| 1011 |
+
" \"linewidths\": 0,\n",
|
| 1012 |
+
" \"xticklabels\": False,\n",
|
| 1013 |
+
" \"yticklabels\": False,\n",
|
| 1014 |
+
" \"center\": 0,\n",
|
| 1015 |
+
" \"cmap\": \"vlag\",\n",
|
| 1016 |
+
" }\n",
|
| 1017 |
+
"\n",
|
| 1018 |
+
" if kwargs_dict is not None:\n",
|
| 1019 |
+
" default_kwargs_dict.update(kwargs_dict)\n",
|
| 1020 |
+
" g = sns.clustermap(predictions_df, **default_kwargs_dict)\n",
|
| 1021 |
+
"\n",
|
| 1022 |
+
" plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha=\"right\")\n",
|
| 1023 |
+
"\n",
|
| 1024 |
+
" for label_color in list(label_color_dict.keys()):\n",
|
| 1025 |
+
" g.ax_col_dendrogram.bar(\n",
|
| 1026 |
+
" 0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0\n",
|
| 1027 |
+
" )\n",
|
| 1028 |
+
"\n",
|
| 1029 |
+
" g.ax_col_dendrogram.legend(\n",
|
| 1030 |
+
" title=f\"{title}\",\n",
|
| 1031 |
+
" loc=\"lower center\",\n",
|
| 1032 |
+
" ncol=4,\n",
|
| 1033 |
+
" bbox_to_anchor=(0.5, 1),\n",
|
| 1034 |
+
" facecolor=\"white\",\n",
|
| 1035 |
+
" )\n",
|
| 1036 |
+
"\n",
|
| 1037 |
+
" output_file = (Path(output_dir) / f\"{output_prefix}_pred\").with_suffix(\".pdf\")\n",
|
| 1038 |
+
" plt.savefig(output_file, bbox_inches=\"tight\")\n",
|
| 1039 |
+
"def make_colorbar(embs_df, label):\n",
|
| 1040 |
+
" labels = list(embs_df[label])\n",
|
| 1041 |
+
"\n",
|
| 1042 |
+
" cell_type_colors = gen_heatmap_class_colors(labels, embs_df)\n",
|
| 1043 |
+
" label_colors = pd.DataFrame(cell_type_colors, columns=[label])\n",
|
| 1044 |
+
"\n",
|
| 1045 |
+
" # create dictionary for colors and classes\n",
|
| 1046 |
+
" label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])\n",
|
| 1047 |
+
" return label_colors, label_color_dict\n",
|
| 1048 |
+
"def gen_heatmap_class_colors(labels, df):\n",
|
| 1049 |
+
" pal = sns.cubehelix_palette(\n",
|
| 1050 |
+
" len(Counter(labels).keys()),\n",
|
| 1051 |
+
" light=0.9,\n",
|
| 1052 |
+
" dark=0.1,\n",
|
| 1053 |
+
" hue=1,\n",
|
| 1054 |
+
" reverse=True,\n",
|
| 1055 |
+
" start=1,\n",
|
| 1056 |
+
" rot=-2,\n",
|
| 1057 |
+
" )\n",
|
| 1058 |
+
" lut = dict(zip(map(str, Counter(labels).keys()), pal))\n",
|
| 1059 |
+
" colors = pd.Series(labels, index=df.index).map(lut)\n",
|
| 1060 |
+
" return colors\n",
|
| 1061 |
+
"def gen_heatmap_class_dict(classes, label_colors_series):\n",
|
| 1062 |
+
" class_color_dict_df = pd.DataFrame(\n",
|
| 1063 |
+
" {\"classes\": classes, \"color\": label_colors_series}\n",
|
| 1064 |
+
" )\n",
|
| 1065 |
+
" class_color_dict_df = class_color_dict_df.drop_duplicates(subset=[\"classes\"])\n",
|
| 1066 |
+
" return dict(zip(class_color_dict_df[\"classes\"], class_color_dict_df[\"color\"]))"
|
| 1067 |
+
]
|
| 1068 |
+
},
|
| 1069 |
+
{
|
| 1070 |
+
"cell_type": "code",
|
| 1071 |
+
"execution_count": null,
|
| 1072 |
+
"metadata": {},
|
| 1073 |
+
"outputs": [
|
| 1074 |
+
{
|
| 1075 |
+
"data": {
|
| 1076 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1077 |
+
"model_id": "7a260f2ee53e46cda883751b4f9ee36f",
|
| 1078 |
+
"version_major": 2,
|
| 1079 |
+
"version_minor": 0
|
| 1080 |
+
},
|
| 1081 |
+
"text/plain": [
|
| 1082 |
+
"Saving the dataset (0/3 shards): 0%| | 0/115367 [00:00<?, ? examples/s]"
|
| 1083 |
+
]
|
| 1084 |
+
},
|
| 1085 |
+
"metadata": {},
|
| 1086 |
+
"output_type": "display_data"
|
| 1087 |
+
},
|
| 1088 |
+
{
|
| 1089 |
+
"data": {
|
| 1090 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1091 |
+
"model_id": "56bf186783134b349bece0953132c491",
|
| 1092 |
+
"version_major": 2,
|
| 1093 |
+
"version_minor": 0
|
| 1094 |
+
},
|
| 1095 |
+
"text/plain": [
|
| 1096 |
+
"Saving the dataset (0/1 shards): 0%| | 0/17228 [00:00<?, ? examples/s]"
|
| 1097 |
+
]
|
| 1098 |
+
},
|
| 1099 |
+
"metadata": {},
|
| 1100 |
+
"output_type": "display_data"
|
| 1101 |
+
},
|
| 1102 |
+
{
|
| 1103 |
+
"data": {
|
| 1104 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1105 |
+
"model_id": "cccf5a6fd66f4005b6ebd2aef3772229",
|
| 1106 |
+
"version_major": 2,
|
| 1107 |
+
"version_minor": 0
|
| 1108 |
+
},
|
| 1109 |
+
"text/plain": [
|
| 1110 |
+
" 0%| | 0/1 [00:00<?, ?it/s]"
|
| 1111 |
+
]
|
| 1112 |
+
},
|
| 1113 |
+
"metadata": {},
|
| 1114 |
+
"output_type": "display_data"
|
| 1115 |
+
},
|
| 1116 |
+
{
|
| 1117 |
+
"name": "stdout",
|
| 1118 |
+
"output_type": "stream",
|
| 1119 |
+
"text": [
|
| 1120 |
+
"****** Validation split: 1/1 ******\n",
|
| 1121 |
+
"\n"
|
| 1122 |
+
]
|
| 1123 |
+
},
|
| 1124 |
+
{
|
| 1125 |
+
"data": {
|
| 1126 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1127 |
+
"model_id": "7c1733b61dd14cb4a9e36cee4704a218",
|
| 1128 |
+
"version_major": 2,
|
| 1129 |
+
"version_minor": 0
|
| 1130 |
+
},
|
| 1131 |
+
"text/plain": [
|
| 1132 |
+
"Filter (num_proc=16): 0%| | 0/115367 [00:00<?, ? examples/s]"
|
| 1133 |
+
]
|
| 1134 |
+
},
|
| 1135 |
+
"metadata": {},
|
| 1136 |
+
"output_type": "display_data"
|
| 1137 |
+
},
|
| 1138 |
+
{
|
| 1139 |
+
"data": {
|
| 1140 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1141 |
+
"model_id": "fb09533c6da74363a7e26f20d777fce8",
|
| 1142 |
+
"version_major": 2,
|
| 1143 |
+
"version_minor": 0
|
| 1144 |
+
},
|
| 1145 |
+
"text/plain": [
|
| 1146 |
+
"Filter (num_proc=16): 0%| | 0/115367 [00:00<?, ? examples/s]"
|
| 1147 |
+
]
|
| 1148 |
+
},
|
| 1149 |
+
"metadata": {},
|
| 1150 |
+
"output_type": "display_data"
|
| 1151 |
+
}
|
| 1152 |
+
],
|
| 1153 |
+
"source": [
|
| 1154 |
+
"corpus_dir = \"Pretrain_data\"\n",
|
| 1155 |
+
"with open(corpus_dir + \"/token_dictionary.pkl\", \"rb\") as fp:\n",
|
| 1156 |
+
" gene_token_dict = pickle.load(fp)\n",
|
| 1157 |
+
"token_gene_dict = {v: k for k, v in gene_token_dict.items()}\n",
|
| 1158 |
+
"\n",
|
| 1159 |
+
"filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
|
| 1160 |
+
"training_args = {\n",
|
| 1161 |
+
" \"num_train_epochs\": 0.9,\n",
|
| 1162 |
+
" \"learning_rate\": 0.000804,\n",
|
| 1163 |
+
" \"lr_scheduler_type\": \"polynomial\",\n",
|
| 1164 |
+
" \"warmup_steps\": 1812,\n",
|
| 1165 |
+
" \"weight_decay\":0.258828,\n",
|
| 1166 |
+
" \"per_device_train_batch_size\": 12,\n",
|
| 1167 |
+
" \"seed\": 73,\n",
|
| 1168 |
+
"}\n",
|
| 1169 |
+
"\n",
|
| 1170 |
+
"cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"}\n",
|
| 1171 |
+
"classifier='cell'\n",
|
| 1172 |
+
"filter_data=filter_data_dict\n",
|
| 1173 |
+
"split_sizes={\"train\": 0.8, \"valid\": 0.1, \"test\": 0.1}\n",
|
| 1174 |
+
"train_size = split_sizes[\"train\"]\n",
|
| 1175 |
+
"valid_size = split_sizes[\"valid\"]\n",
|
| 1176 |
+
"oos_test_size = split_sizes[\"test\"]\n",
|
| 1177 |
+
"max_ncells=None\n",
|
| 1178 |
+
"freeze_layers = 2\n",
|
| 1179 |
+
"num_crossval_splits = 1\n",
|
| 1180 |
+
"forward_batch_size=200\n",
|
| 1181 |
+
"nproc=16\n",
|
| 1182 |
+
"rare_threshold=0\n",
|
| 1183 |
+
"quantize=None\n",
|
| 1184 |
+
"\n",
|
| 1185 |
+
"\n",
|
| 1186 |
+
"train_ids = [\"1447\", \"1600\", \"1462\", \"1558\", \"1300\", \"1508\", \"1358\", \"1678\", \"1561\", \"1304\", \"1610\", \"1430\", \"1472\", \"1707\", \"1726\", \"1504\", \"1425\", \"1617\", \"1631\", \"1735\", \"1582\", \"1722\", \"1622\", \"1630\", \"1290\", \"1479\", \"1371\", \"1549\", \"1515\"]\n",
|
| 1187 |
+
"eval_ids = [\"1422\", \"1510\", \"1539\", \"1606\", \"1702\"]\n",
|
| 1188 |
+
"test_ids = [\"1437\", \"1516\", \"1602\", \"1685\", \"1718\"]\n",
|
| 1189 |
+
"\n",
|
| 1190 |
+
"train_test_id_split_dict = {\"attr_key\": \"individual\",\n",
|
| 1191 |
+
" \"train\": train_ids+eval_ids,\n",
|
| 1192 |
+
" \"test\": test_ids}\n",
|
| 1193 |
+
"train_valid_id_split_dict = {\"attr_key\": \"individual\",\n",
|
| 1194 |
+
" \"train\": train_ids,\n",
|
| 1195 |
+
" \"eval\": eval_ids}\n",
|
| 1196 |
+
"\n",
|
| 1197 |
+
"# define output directory path\n",
|
| 1198 |
+
"current_date = datetime.datetime.now()\n",
|
| 1199 |
+
"datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.strftime('%X').replace(':','')}\"\n",
|
| 1200 |
+
"datestamp_min = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n",
|
| 1201 |
+
"output_directory = \"output path\"\n",
|
| 1202 |
+
"\n",
|
| 1203 |
+
"if output_directory[-1:] != \"/\": # add slash for dir if not present\n",
|
| 1204 |
+
" output_directory = output_directory + \"/\"\n",
|
| 1205 |
+
"output_dir = f\"{output_directory}{datestamp}_geneformer_diseaseClassifier/\"\n",
|
| 1206 |
+
"output_prefix = \"cm_classifier_test\"\n",
|
| 1207 |
+
"subprocess.call(f\"mkdir {output_dir}\", shell=True)\n",
|
| 1208 |
+
"os.makedirs(output_dir, exist_ok=True)\n",
|
| 1209 |
+
"\n",
|
| 1210 |
+
"prepare_data(input_data_file=\"example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\",\n",
|
| 1211 |
+
" output_directory=output_dir,\n",
|
| 1212 |
+
" output_prefix=output_prefix,\n",
|
| 1213 |
+
" split_id_dict=train_test_id_split_dict)\n",
|
| 1214 |
+
"\n",
|
| 1215 |
+
"with open(f\"{output_dir}/{output_prefix}_id_class_dict.pkl\", \"rb\") as f:\n",
|
| 1216 |
+
" id_class_dict = pickle.load(f)\n",
|
| 1217 |
+
"class_id_dict = {v: k for k, v in id_class_dict.items()}\n",
|
| 1218 |
+
"\n",
|
| 1219 |
+
"num_classes = get_num_classes(id_class_dict)\n",
|
| 1220 |
+
"\n",
|
| 1221 |
+
"data = load_and_filter(None, nproc, f\"{output_dir}/{output_prefix}_labeled_train.dataset\")\n",
|
| 1222 |
+
"data = data.shuffle(seed=42)\n",
|
| 1223 |
+
"\n",
|
| 1224 |
+
"##### (Cross-)validate the model #####\n",
|
| 1225 |
+
"results = []\n",
|
| 1226 |
+
"all_conf_mat = np.zeros((num_classes, num_classes))\n",
|
| 1227 |
+
"iteration_num = 1\n",
|
| 1228 |
+
"split_id_dict=train_valid_id_split_dict\n",
|
| 1229 |
+
"\n",
|
| 1230 |
+
"for i in trange(num_crossval_splits):\n",
|
| 1231 |
+
" print(\n",
|
| 1232 |
+
" f\"****** Validation split: {iteration_num}/{num_crossval_splits} ******\\n\"\n",
|
| 1233 |
+
" )\n",
|
| 1234 |
+
" ksplit_output_dir = os.path.join(output_dir, f\"ksplit{iteration_num}\")\n",
|
| 1235 |
+
" if num_crossval_splits == 1:\n",
|
| 1236 |
+
" # single 1-eval_size:eval_size split\n",
|
| 1237 |
+
" if split_id_dict is not None:\n",
|
| 1238 |
+
" data_dict = dict()\n",
|
| 1239 |
+
" data_dict[\"train\"] = filter_by_dict(\n",
|
| 1240 |
+
" data,\n",
|
| 1241 |
+
" {split_id_dict[\"attr_key\"]: split_id_dict[\"train\"]},\n",
|
| 1242 |
+
" nproc,\n",
|
| 1243 |
+
" )\n",
|
| 1244 |
+
" data_dict[\"test\"] = filter_by_dict(\n",
|
| 1245 |
+
" data,\n",
|
| 1246 |
+
" {split_id_dict[\"attr_key\"]: split_id_dict[\"eval\"]},\n",
|
| 1247 |
+
" nproc,\n",
|
| 1248 |
+
" )\n",
|
| 1249 |
+
" train_data = data_dict[\"train\"]\n",
|
| 1250 |
+
" eval_data = data_dict[\"test\"]"
|
| 1251 |
+
]
|
| 1252 |
+
},
|
| 1253 |
+
{
|
| 1254 |
+
"cell_type": "code",
|
| 1255 |
+
"execution_count": null,
|
| 1256 |
+
"metadata": {},
|
| 1257 |
+
"outputs": [
|
| 1258 |
+
{
|
| 1259 |
+
"name": "stdout",
|
| 1260 |
+
"output_type": "stream",
|
| 1261 |
+
"text": [
|
| 1262 |
+
"Converting training dataset...\n"
|
| 1263 |
+
]
|
| 1264 |
+
},
|
| 1265 |
+
{
|
| 1266 |
+
"name": "stderr",
|
| 1267 |
+
"output_type": "stream",
|
| 1268 |
+
"text": [
|
| 1269 |
+
"Converting sequences: 100%|██████████| 93589/93589 [00:02<00:00, 41967.40seq/s] \n"
|
| 1270 |
+
]
|
| 1271 |
+
},
|
| 1272 |
+
{
|
| 1273 |
+
"name": "stdout",
|
| 1274 |
+
"output_type": "stream",
|
| 1275 |
+
"text": [
|
| 1276 |
+
"Converting evaluation dataset...\n"
|
| 1277 |
+
]
|
| 1278 |
+
},
|
| 1279 |
+
{
|
| 1280 |
+
"name": "stderr",
|
| 1281 |
+
"output_type": "stream",
|
| 1282 |
+
"text": [
|
| 1283 |
+
"Converting sequences: 100%|██████████| 21778/21778 [00:00<00:00, 151581.39seq/s]\n"
|
| 1284 |
+
]
|
| 1285 |
+
},
|
| 1286 |
+
{
|
| 1287 |
+
"name": "stdout",
|
| 1288 |
+
"output_type": "stream",
|
| 1289 |
+
"text": [
|
| 1290 |
+
"Training RandomForest...\n",
|
| 1291 |
+
"Training LogisticRegression...\n",
|
| 1292 |
+
" Accuracy Macro F1 Weighted F1 Weighted Precision\n",
|
| 1293 |
+
"RandomForest 0.618055 0.457959 0.649440 0.687780\n",
|
| 1294 |
+
"LogisticRegression 0.592065 0.440782 0.608307 0.645992\n"
|
| 1295 |
+
]
|
| 1296 |
+
}
|
| 1297 |
+
],
|
| 1298 |
+
"source": [
|
| 1299 |
+
"from sklearn.ensemble import RandomForestClassifier\n",
|
| 1300 |
+
"from sklearn.linear_model import LogisticRegression\n",
|
| 1301 |
+
"from sklearn.svm import SVC\n",
|
| 1302 |
+
"from sklearn.metrics import accuracy_score, f1_score, precision_score\n",
|
| 1303 |
+
"import numpy as np\n",
|
| 1304 |
+
"from tqdm import tqdm\n",
|
| 1305 |
+
"\n",
|
| 1306 |
+
"def pad_or_truncate(seq, max_len):\n",
|
| 1307 |
+
" if len(seq) < max_len:\n",
|
| 1308 |
+
" return seq + [0] * (max_len - len(seq))\n",
|
| 1309 |
+
" else:\n",
|
| 1310 |
+
" return seq[:max_len]\n",
|
| 1311 |
+
"\n",
|
| 1312 |
+
"def dataset_to_numpy(hf_dataset, max_len=256):\n",
|
| 1313 |
+
" X = []\n",
|
| 1314 |
+
" for seq in tqdm(hf_dataset[\"input_ids\"], desc=\"Converting sequences\", unit=\"seq\"):\n",
|
| 1315 |
+
" X.append(pad_or_truncate(seq, max_len))\n",
|
| 1316 |
+
" y = np.array(hf_dataset[\"label\"])\n",
|
| 1317 |
+
" return np.array(X), y\n",
|
| 1318 |
+
"\n",
|
| 1319 |
+
"print(\"Converting training dataset...\")\n",
|
| 1320 |
+
"X_train, y_train = dataset_to_numpy(train_data)\n",
|
| 1321 |
+
"print(\"Converting evaluation dataset...\")\n",
|
| 1322 |
+
"X_eval, y_eval = dataset_to_numpy(eval_data)\n",
|
| 1323 |
+
"\n",
|
| 1324 |
+
"models = {\n",
|
| 1325 |
+
" \"RandomForest\": RandomForestClassifier(n_estimators=100, random_state=42),\n",
|
| 1326 |
+
" \"LogisticRegression\": LogisticRegression(max_iter=1000, random_state=42),\n",
|
| 1327 |
+
" \"SVM\": SVC(kernel=\"linear\", probability=True, random_state=42),\n",
|
| 1328 |
+
" \"SVM\": make_pipeline(StandardScaler(), SVC(kernel=\"rbf\", probability=True, random_state=42)),\n",
|
| 1329 |
+
"}\n",
|
| 1330 |
+
"\n",
|
| 1331 |
+
"results = {}\n",
|
| 1332 |
+
"for name, model in models.items():\n",
|
| 1333 |
+
" print(f\"Training {name}...\")\n",
|
| 1334 |
+
" model.fit(X_train, y_train)\n",
|
| 1335 |
+
" y_pred = model.predict(X_eval)\n",
|
| 1336 |
+
" \n",
|
| 1337 |
+
" acc = accuracy_score(y_eval, y_pred)\n",
|
| 1338 |
+
" macro_f1 = f1_score(y_eval, y_pred, average=\"macro\")\n",
|
| 1339 |
+
" weighted_f1 = f1_score(y_eval, y_pred, average=\"weighted\")\n",
|
| 1340 |
+
" precision = precision_score(y_eval, y_pred, average=\"weighted\")\n",
|
| 1341 |
+
" \n",
|
| 1342 |
+
" results[name] = {\n",
|
| 1343 |
+
" \"Accuracy\": acc,\n",
|
| 1344 |
+
" \"Macro F1\": macro_f1,\n",
|
| 1345 |
+
" \"Weighted F1\": weighted_f1,\n",
|
| 1346 |
+
" \"Weighted Precision\": precision\n",
|
| 1347 |
+
" }\n",
|
| 1348 |
+
"\n",
|
| 1349 |
+
"# Display results\n",
|
| 1350 |
+
"import pandas as pd\n",
|
| 1351 |
+
"results_df = pd.DataFrame(results).T\n",
|
| 1352 |
+
"print(results_df)\n"
|
| 1353 |
+
]
|
| 1354 |
+
},
|
| 1355 |
+
{
|
| 1356 |
+
"cell_type": "code",
|
| 1357 |
+
"execution_count": 4,
|
| 1358 |
+
"metadata": {},
|
| 1359 |
+
"outputs": [
|
| 1360 |
+
{
|
| 1361 |
+
"data": {
|
| 1362 |
+
"text/plain": [
|
| 1363 |
+
"{'RandomForest': {'Accuracy': 0.6180549178069612,\n",
|
| 1364 |
+
" 'Macro F1': 0.45795920359758124,\n",
|
| 1365 |
+
" 'Weighted F1': 0.6494402066016174,\n",
|
| 1366 |
+
" 'Weighted Precision': 0.687779833202143},\n",
|
| 1367 |
+
" 'LogisticRegression': {'Accuracy': 0.5920653870878868,\n",
|
| 1368 |
+
" 'Macro F1': 0.4407815175765883,\n",
|
| 1369 |
+
" 'Weighted F1': 0.6083068177204959,\n",
|
| 1370 |
+
" 'Weighted Precision': 0.6459924332028076}}"
|
| 1371 |
+
]
|
| 1372 |
+
},
|
| 1373 |
+
"execution_count": 4,
|
| 1374 |
+
"metadata": {},
|
| 1375 |
+
"output_type": "execute_result"
|
| 1376 |
+
}
|
| 1377 |
+
],
|
| 1378 |
+
"source": [
|
| 1379 |
+
"results"
|
| 1380 |
+
]
|
| 1381 |
+
}
|
| 1382 |
+
],
|
| 1383 |
+
"metadata": {
|
| 1384 |
+
"kernelspec": {
|
| 1385 |
+
"display_name": "Python 3",
|
| 1386 |
+
"language": "python",
|
| 1387 |
+
"name": "python3"
|
| 1388 |
+
},
|
| 1389 |
+
"language_info": {
|
| 1390 |
+
"codemirror_mode": {
|
| 1391 |
+
"name": "ipython",
|
| 1392 |
+
"version": 3
|
| 1393 |
+
},
|
| 1394 |
+
"file_extension": ".py",
|
| 1395 |
+
"mimetype": "text/x-python",
|
| 1396 |
+
"name": "python",
|
| 1397 |
+
"nbconvert_exporter": "python",
|
| 1398 |
+
"pygments_lexer": "ipython3",
|
| 1399 |
+
"version": "3.11.7"
|
| 1400 |
+
}
|
| 1401 |
+
},
|
| 1402 |
+
"nbformat": 4,
|
| 1403 |
+
"nbformat_minor": 2
|
| 1404 |
+
}
|
Downstream_tasks/Classification/Gene_dosage.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Downstream_tasks/Classification/Gene_dosage_ML.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Downstream_tasks/Classification/Tissue_type.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from tqdm.auto import tqdm, trange
|
| 3 |
+
GPU_NUMBER = [0]
|
| 4 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
|
| 5 |
+
os.environ["NCCL_DEBUG"] = "INFO"
|
| 6 |
+
|
| 7 |
+
# imports
|
| 8 |
+
from collections import Counter
|
| 9 |
+
import seaborn as sns; sns.set()
|
| 10 |
+
from datasets import load_from_disk
|
| 11 |
+
from sklearn.metrics import accuracy_score, f1_score
|
| 12 |
+
from transformers import Trainer
|
| 13 |
+
from transformers.training_args import TrainingArguments
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
| 16 |
+
from sklearn import preprocessing
|
| 17 |
+
from sklearn.metrics import (
|
| 18 |
+
ConfusionMatrixDisplay,
|
| 19 |
+
accuracy_score,
|
| 20 |
+
auc,
|
| 21 |
+
confusion_matrix,
|
| 22 |
+
f1_score,
|
| 23 |
+
roc_curve,
|
| 24 |
+
)
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
import sys
|
| 28 |
+
# sys.path.append('../Geneformer')
|
| 29 |
+
from geneformer import DataCollatorForCellClassification
|
| 30 |
+
from datasets import load_from_disk
|
| 31 |
+
import sys
|
| 32 |
+
from tqdm.notebook import tqdm
|
| 33 |
+
import seaborn as sns
|
| 34 |
+
import matplotlib.pyplot as plt
|
| 35 |
+
from geneformer.pretrainer import token_dictionary
|
| 36 |
+
import datetime
|
| 37 |
+
import time
|
| 38 |
+
import pickle
|
| 39 |
+
import random
|
| 40 |
+
import subprocess
|
| 41 |
+
import numpy as np
|
| 42 |
+
import pytz
|
| 43 |
+
import torch
|
| 44 |
+
from datasets import load_from_disk, Dataset
|
| 45 |
+
from transformers import (BertConfig, BertForMaskedLM, TrainingArguments, TrainerCallback,
|
| 46 |
+
Trainer, BertModel, BertPreTrainedModel, BertForSequenceClassification, BertForTokenClassification)
|
| 47 |
+
from geneformer import GeneformerPretrainer
|
| 48 |
+
from torch import Tensor
|
| 49 |
+
from transformers.modeling_outputs import MaskedLMOutput
|
| 50 |
+
from transformers.models.bert.modeling_bert import BertLMPredictionHead, BertOnlyMLMHead, BertPredictionHeadTransform
|
| 51 |
+
from transformers.activations import ACT2FN
|
| 52 |
+
from typing import List, Optional, Tuple, Union
|
| 53 |
+
import torch.nn.functional as F
|
| 54 |
+
|
| 55 |
+
model_path = 'model path'
|
| 56 |
+
prefix = 'CAB5_1M'
|
| 57 |
+
total_iter = 1
|
| 58 |
+
|
| 59 |
+
class CustomBertForMaskedLM(BertPreTrainedModel):
|
| 60 |
+
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
| 61 |
+
_tied_weights_keys = ["decoder.weight", "bert.embeddings.word_embeddings.weight"]
|
| 62 |
+
|
| 63 |
+
def __init__(self, config):
|
| 64 |
+
super().__init__(config)
|
| 65 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
| 66 |
+
self.transform = BertPredictionHeadTransform(config)
|
| 67 |
+
|
| 68 |
+
self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 69 |
+
|
| 70 |
+
self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size))
|
| 71 |
+
|
| 72 |
+
# Initialize weights
|
| 73 |
+
self.init_weights()
|
| 74 |
+
|
| 75 |
+
# Tie weights automatically
|
| 76 |
+
self.tie_weights()
|
| 77 |
+
|
| 78 |
+
# self.post_init()
|
| 79 |
+
|
| 80 |
+
def tie_weights(self):
|
| 81 |
+
"""
|
| 82 |
+
Ties the weights between the input embeddings and output decoder weights.
|
| 83 |
+
"""
|
| 84 |
+
self.decoder.weight = self.bert.embeddings.word_embeddings.weight
|
| 85 |
+
|
| 86 |
+
def probability_convert(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
|
| 87 |
+
device = probs.device
|
| 88 |
+
batch_size, seq_length, vocab_size = probs.size()
|
| 89 |
+
_, input_seq_length = input_ids.size()
|
| 90 |
+
|
| 91 |
+
# truncated_labels = labels[:, :input_seq_length]
|
| 92 |
+
# non_mask = truncated_labels == -100
|
| 93 |
+
non_mask = labels == -100
|
| 94 |
+
non_mask_indices = non_mask.nonzero(as_tuple=True)
|
| 95 |
+
known_gene_indices = input_ids[non_mask]
|
| 96 |
+
|
| 97 |
+
# Generate (1-p) matrix whiel assigning all known genes in the beginning
|
| 98 |
+
zeros = torch.zeros((batch_size, 1, vocab_size), device=device)
|
| 99 |
+
zeros[non_mask_indices[0], 0, known_gene_indices] = 1.0
|
| 100 |
+
probs_shifted = torch.cat((zeros, probs[:, :-1, :]), dim=1)
|
| 101 |
+
inv_probs_shifted = 1 - probs_shifted
|
| 102 |
+
|
| 103 |
+
# Cumulative product to get (1-p_1)*(1-p_2)*...*(p_i)
|
| 104 |
+
cumprod_inv_probs = torch.cumprod(inv_probs_shifted, dim=1)
|
| 105 |
+
modified_probs = probs * cumprod_inv_probs
|
| 106 |
+
|
| 107 |
+
# # Since we are assigning probabilities for already known genes,
|
| 108 |
+
# # (1-p_1)*(1-p_2)*...*(p_i) for these genes can result in 0, due to hard assignment of probs to be 1
|
| 109 |
+
# # Add 1e-18 to avoid dividing modified probs by 0
|
| 110 |
+
# # During dubugging stage, some issues occurred in the normalization step.
|
| 111 |
+
# # Since probabilities in each position do not necessarily need to sum up to one, leave out normalization.
|
| 112 |
+
normalized_probs = modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)
|
| 113 |
+
modified_probs = modified_probs / normalized_probs # Normalization after cumulative production
|
| 114 |
+
|
| 115 |
+
return modified_probs
|
| 116 |
+
|
| 117 |
+
def assign_known_gene_probs(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
|
| 118 |
+
|
| 119 |
+
device = probs.device
|
| 120 |
+
batch_size, seq_length, vocab_size = probs.size()
|
| 121 |
+
_, input_seq_length = input_ids.size()
|
| 122 |
+
|
| 123 |
+
# Truncate `labels` to match the length of `input_ids` along the sequence dimension
|
| 124 |
+
truncated_labels = labels[:, :input_seq_length]
|
| 125 |
+
|
| 126 |
+
non_mask = truncated_labels == -100
|
| 127 |
+
non_mask_indices = non_mask.nonzero(as_tuple=True)
|
| 128 |
+
|
| 129 |
+
ones = torch.ones((batch_size, seq_length, vocab_size), device=device)
|
| 130 |
+
zeros = torch.zeros((batch_size, seq_length, vocab_size), device=device)
|
| 131 |
+
|
| 132 |
+
known_gene_indices = input_ids[non_mask]
|
| 133 |
+
|
| 134 |
+
ones[non_mask_indices[0], non_mask_indices[1], :] = 0.0
|
| 135 |
+
zeros[non_mask_indices[0], non_mask_indices[1], known_gene_indices] = 1.0
|
| 136 |
+
|
| 137 |
+
# Modify already known genes' probabilities using the one-hot tensor
|
| 138 |
+
modified_probs = probs * ones
|
| 139 |
+
modified_probs = modified_probs + zeros
|
| 140 |
+
|
| 141 |
+
# Do the normalization
|
| 142 |
+
modified_probs = modified_probs / modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18) # Normalize
|
| 143 |
+
|
| 144 |
+
return modified_probs
|
| 145 |
+
|
| 146 |
+
def compute_similarity_on_probs(self, probs: Tensor, labels: Tensor) -> Tensor:
|
| 147 |
+
"""
|
| 148 |
+
Optimized computation of average cosine similarity across all positions in each sequence and batch.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
probs (torch.Tensor): Probability tensor of shape (batch_size, seq_length, vocab_size).
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
torch.Tensor: Average similarity term for loss computation.
|
| 155 |
+
"""
|
| 156 |
+
batch_size, seq_length, vocab_size = probs.size()
|
| 157 |
+
device = probs.device
|
| 158 |
+
|
| 159 |
+
non_mask = labels == -100
|
| 160 |
+
non_mask_indices = non_mask.nonzero(as_tuple=True)
|
| 161 |
+
|
| 162 |
+
mask_sim = torch.ones((batch_size, seq_length, seq_length), device=device)
|
| 163 |
+
mask_sim[non_mask_indices[0], non_mask_indices[1], :] = 0.0
|
| 164 |
+
|
| 165 |
+
seq_mask = torch.triu(torch.ones(seq_length, seq_length, device=device), diagonal=1)
|
| 166 |
+
batch_mask = seq_mask.unsqueeze(0).expand(batch_size, seq_length, seq_length)
|
| 167 |
+
mask_sim = mask_sim * batch_mask
|
| 168 |
+
|
| 169 |
+
# Normalize along the vocab_size dimension
|
| 170 |
+
probs_norm = F.normalize(probs, dim=-1) # Shape: (batch_size, seq_length, vocab_size)
|
| 171 |
+
|
| 172 |
+
# Compute pairwise cosine similarity using einsum
|
| 173 |
+
similarities = torch.einsum("biv,bjv->bij", probs_norm, probs_norm) # Shape: (batch_size, seq_length, seq_length), listing pair-wise similarity values across all positions
|
| 174 |
+
|
| 175 |
+
# Mask out lower triangle (to consider only i < j pairs)
|
| 176 |
+
# mask_sim = torch.triu(torch.ones(seq_length, seq_length, device=probs.device), diagonal=1)
|
| 177 |
+
valid_similarities = similarities * mask_sim # Shape: (batch_size, seq_length, seq_length)
|
| 178 |
+
|
| 179 |
+
# Compute average similarity
|
| 180 |
+
total_similarity = valid_similarities.sum()
|
| 181 |
+
total_comparisons = mask_sim.sum().item()
|
| 182 |
+
|
| 183 |
+
if total_comparisons == 0:
|
| 184 |
+
return torch.tensor(0.0, device=device)
|
| 185 |
+
|
| 186 |
+
return total_similarity / total_comparisons
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def forward(
|
| 190 |
+
self,
|
| 191 |
+
input_ids: Tensor | None = None,
|
| 192 |
+
attention_mask: Tensor | None = None,
|
| 193 |
+
token_type_ids: Tensor | None = None,
|
| 194 |
+
position_ids: Tensor | None = None,
|
| 195 |
+
head_mask: Tensor | None = None,
|
| 196 |
+
inputs_embeds: Tensor | None = None,
|
| 197 |
+
encoder_hidden_states: Tensor | None = None,
|
| 198 |
+
encoder_attention_mask: Tensor | None = None,
|
| 199 |
+
labels: Tensor | None = None,
|
| 200 |
+
output_attentions: bool | None = None,
|
| 201 |
+
output_hidden_states: bool | None = None,
|
| 202 |
+
return_dict: bool | None = None) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
| 203 |
+
|
| 204 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 205 |
+
|
| 206 |
+
outputs = self.bert(
|
| 207 |
+
input_ids,
|
| 208 |
+
attention_mask=attention_mask,
|
| 209 |
+
token_type_ids=token_type_ids,
|
| 210 |
+
position_ids=position_ids,
|
| 211 |
+
head_mask=head_mask,
|
| 212 |
+
inputs_embeds=inputs_embeds,
|
| 213 |
+
output_attentions=output_attentions,
|
| 214 |
+
output_hidden_states=output_hidden_states,
|
| 215 |
+
return_dict=return_dict,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
hidden_states = outputs[0]
|
| 219 |
+
hidden_transform = self.transform(hidden_states)
|
| 220 |
+
logits = self.decoder(hidden_transform) + self.bias
|
| 221 |
+
|
| 222 |
+
# temperature = 0.75
|
| 223 |
+
# logits = logits / temperature
|
| 224 |
+
|
| 225 |
+
probs = F.softmax(logits, dim=-1)
|
| 226 |
+
|
| 227 |
+
# Probability manipulations to avoid repeats from already known genes
|
| 228 |
+
probs = self.assign_known_gene_probs(probs, input_ids, labels)
|
| 229 |
+
convert_probs = self.probability_convert(probs, input_ids, labels)
|
| 230 |
+
assigned_probs = self.assign_known_gene_probs(convert_probs, input_ids, labels)
|
| 231 |
+
|
| 232 |
+
masked_lm_loss = None
|
| 233 |
+
if labels is not None:
|
| 234 |
+
probs_flat = assigned_probs.view(-1, self.config.vocab_size)
|
| 235 |
+
labels_flat = labels.view(-1)
|
| 236 |
+
mask = (labels != -100).float().view(-1)
|
| 237 |
+
|
| 238 |
+
# Compute masked cross-entropy loss
|
| 239 |
+
masked_lm_loss = -torch.log(torch.clamp(probs_flat[torch.arange(len(labels_flat)), labels_flat], min=1e-18)) * mask
|
| 240 |
+
masked_lm_loss = masked_lm_loss.sum() / mask.sum()
|
| 241 |
+
|
| 242 |
+
similarity_loss = self.compute_similarity_on_probs(assigned_probs, labels)
|
| 243 |
+
lambda_similarity = 5.0 # Adjust this value through experimentation
|
| 244 |
+
masked_lm_loss = masked_lm_loss + lambda_similarity * similarity_loss
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
else:
|
| 248 |
+
loss = None
|
| 249 |
+
|
| 250 |
+
if not return_dict:
|
| 251 |
+
output = (assigned_probs,) + outputs[2:]
|
| 252 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 253 |
+
|
| 254 |
+
return MaskedLMOutput(
|
| 255 |
+
loss=masked_lm_loss,
|
| 256 |
+
logits=assigned_probs,
|
| 257 |
+
hidden_states=outputs.hidden_states,
|
| 258 |
+
attentions=outputs.attentions,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
    """Prepare model inputs for one generation step.

    Appends a dummy PAD token to ``input_ids`` (the position whose prediction
    will be read out) and extends ``attention_mask`` with a 0 so that the
    dummy position is not attended to.

    Args:
        input_ids: LongTensor of shape (batch, seq_len) with token ids.
        attention_mask: optional mask of shape (batch, seq_len); if ``None``,
            every existing position is assumed attendable.
        **model_kwargs: ignored extra generation kwargs.

    Returns:
        dict with extended ``input_ids`` (batch, seq_len + 1) and matching
        ``attention_mask``.

    Raises:
        ValueError: if ``self.config.pad_token_id`` is not set.
    """
    input_shape = input_ids.shape
    effective_batch_size = input_shape[0]

    # The dummy token is the PAD token, so it must be defined.
    if self.config.pad_token_id is None:
        raise ValueError("The PAD token should be defined for generation")

    # Fix: the parameter defaults to None, but the original code always
    # dereferenced it (attention_mask.new_zeros), crashing on the default.
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_shape)

    # Mask out the appended dummy position.
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
    )
    dummy_token = torch.full(
        (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
    )
    input_ids = torch.cat([input_ids, dummy_token], dim=1)

    return {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Load the tokenized cell-type datasets (all tissues together).
train_dataset = load_from_disk("example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset")
eval_dataset = load_from_disk("example_input_files/cell_classification/cell_type_annotation/cell_type_test_data.dataset")

dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []

# Build one labeled train/eval split per organ.
# Immune and bone marrow are fine-tuned together under the "immune" key.
for organ in Counter(train_dataset["organ_major"]).keys():
    if organ == "bone_marrow":
        continue
    if organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list.append("immune")
    else:
        organ_ids = [organ]
        organ_list.append(organ)

    # Keep only cells belonging to the current organ group.
    def if_organ(example):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=16)

    # Per the scDeepsort published method, drop cell types that represent
    # less than 0.5% of the organ's cells.
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [name for name, count in celltype_counter.items() if count > 0.005 * total_cells]
    def if_not_rare_celltype(example):
        return example["cell_type"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)

    # Shuffle deterministically and normalize column names for the Trainer.
    shuffled = trainset_organ_subset.shuffle(seed=42)
    shuffled = shuffled.rename_column("cell_type", "label")
    shuffled = shuffled.remove_columns("organ_major")

    # Map each cell-type name to a numeric label id.
    target_names = list(Counter(shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names, range(len(target_names))))
    target_dict_list.append(target_name_id_dict)

    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = shuffled.map(classes_to_ids, num_proc=16)

    # 80/20 train/eval split.
    split_point = round(len(labeled_trainset) * 0.8)
    labeled_train_split = labeled_trainset.select(range(split_point))
    labeled_eval_split = labeled_trainset.select(range(split_point, len(labeled_trainset)))

    # Keep only eval cells whose label actually occurs in the training split.
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)

    dataset_list.append(labeled_train_split)
    evalset_list.append(labeled_eval_split_subset)

trainset_dict = dict(zip(organ_list, dataset_list))
traintargetdict_dict = dict(zip(organ_list, target_dict_list))

evalset_dict = dict(zip(organ_list, evalset_list))
|
| 345 |
+
|
| 346 |
+
def compute_metrics(pred):
    """Compute accuracy and F1 scores for a HF Trainer prediction output.

    Args:
        pred: object with ``label_ids`` (true labels) and ``predictions``
            (per-class logits) attributes, as passed by ``Trainer``.

    Returns:
        dict with 'accuracy', 'macro_f1', and 'weighted_f1' floats.
    """
    y_true = pred.label_ids
    # argmax over the class axis turns logits into hard predictions
    y_pred = pred.predictions.argmax(-1)
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'macro_f1': f1_score(y_true, y_pred, average='macro'),
        'weighted_f1': f1_score(y_true, y_pred, average='weighted'),
    }
|
| 358 |
+
|
| 359 |
+
# ---- Model parameters ----
max_input_size = 2 ** 11  # maximum input size (2048 tokens)

# ---- Training hyperparameters ----
max_lr = 5e-5                # peak learning rate
freeze_layers = 0            # number of pretrained layers to freeze
num_gpus = 1                 # number of GPUs
num_proc = 16                # number of CPU cores
geneformer_batch_size = 12   # batch size for both training and eval
lr_schedule_fn = "linear"    # learning-rate schedule
warmup_steps = 500           # warmup steps
epochs = 10                  # number of epochs
optimizer = "adamw"          # optimizer
|
| 382 |
+
|
| 383 |
+
# Fine-tune one cell-type classifier per organ.
for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]

    # Log roughly 10 times per epoch.
    logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)

    # Reload the pretrained model with a fresh classification head sized
    # to this organ's label set.
    model = BertForSequenceClassification.from_pretrained(
        model_path,
        num_labels=len(organ_label_dict.keys()),
        output_attentions=False,
        output_hidden_states=False,
    ).to("cuda")

    # #############
    # Transplant the word embeddings from the custom MLM-pretrained model
    # into the classifier, replacing the stock embeddings.
    pretrained_model = CustomBertForMaskedLM.from_pretrained(model_path)
    pretrained_word_embeddings = pretrained_model.bert.embeddings.word_embeddings.weight.clone()
    model.bert.embeddings.word_embeddings.load_state_dict({"weight": pretrained_word_embeddings})
    # ############

    # Define output directory path (datestamped per organ).
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"/ibex/user/chenj0i/Geneformer/Downstream_tasks/Cell_Classify/{prefix}/{datestamp}_geneformer_CellClassifier_{organ}/"

    # Ensure we are not overwriting a previously saved model.
    # (fixed: plain string instead of placeholder-free f-string, and
    # idiomatic truthiness instead of "== True")
    saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
    if os.path.isfile(saved_model_test):
        raise Exception("Model already saved to this directory.")

    # Make output directory.
    os.makedirs(output_dir, exist_ok=True)

    # Set training arguments.
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
    }

    training_args_init = TrainingArguments(**training_args)

    # Create the trainer.
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics,
    )
    # Train the cell type classifier, then persist predictions and metrics.
    trainer.train()
    predictions = trainer.predict(organ_evalset)
    # fixed: build the path with os.path.join instead of string concatenation
    with open(os.path.join(output_dir, "predictions.pickle"), "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval", predictions.metrics)
    trainer.save_model(output_dir)
|
Downstream_tasks/Classification/Tissue_type_ML.ipynb
ADDED
|
@@ -0,0 +1,933 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import os\n",
|
| 10 |
+
"from collections import Counter\n",
|
| 11 |
+
"import datetime\n",
|
| 12 |
+
"import pickle\n",
|
| 13 |
+
"import numpy as np\n",
|
| 14 |
+
"from datasets import load_from_disk\n",
|
| 15 |
+
"from sklearn.metrics import accuracy_score, f1_score\n",
|
| 16 |
+
"from sklearn.ensemble import RandomForestClassifier\n",
|
| 17 |
+
"from sklearn.linear_model import LogisticRegression\n",
|
| 18 |
+
"from sklearn.svm import SVC\n",
|
| 19 |
+
"from sklearn.preprocessing import StandardScaler\n",
|
| 20 |
+
"from sklearn.pipeline import make_pipeline\n",
|
| 21 |
+
"from tqdm import tqdm\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"# Load datasets\n",
|
| 24 |
+
"train_dataset = load_from_disk(\"example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset\")\n",
|
| 25 |
+
"eval_dataset = load_from_disk(\"example_input_files/cell_classification/cell_type_annotation/cell_type_test_data.dataset\")\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"dataset_list, evalset_list, organ_list, target_dict_list = [], [], [], []\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"for organ in Counter(train_dataset[\"organ_major\"]).keys():\n",
|
| 30 |
+
" if organ in [\"bone_marrow\"]: \n",
|
| 31 |
+
" continue\n",
|
| 32 |
+
" elif organ == \"immune\":\n",
|
| 33 |
+
" organ_ids = [\"immune\", \"bone_marrow\"]\n",
|
| 34 |
+
" organ_list += [\"immune\"]\n",
|
| 35 |
+
" else:\n",
|
| 36 |
+
" organ_ids = [organ]\n",
|
| 37 |
+
" organ_list += [organ]\n",
|
| 38 |
+
" \n",
|
| 39 |
+
" def if_organ(example):\n",
|
| 40 |
+
" return example[\"organ_major\"] in organ_ids\n",
|
| 41 |
+
" trainset_organ = train_dataset.filter(if_organ, num_proc=16)\n",
|
| 42 |
+
" \n",
|
| 43 |
+
" celltype_counter = Counter(trainset_organ[\"cell_type\"])\n",
|
| 44 |
+
" total_cells = sum(celltype_counter.values())\n",
|
| 45 |
+
" cells_to_keep = [k for k, v in celltype_counter.items() if v > (0.005 * total_cells)]\n",
|
| 46 |
+
" \n",
|
| 47 |
+
" def if_not_rare_celltype(example):\n",
|
| 48 |
+
" return example[\"cell_type\"] in cells_to_keep\n",
|
| 49 |
+
" trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)\n",
|
| 50 |
+
" \n",
|
| 51 |
+
" trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)\n",
|
| 52 |
+
" trainset_organ_shuffled = trainset_organ_shuffled.rename_column(\"cell_type\", \"label\")\n",
|
| 53 |
+
" trainset_organ_shuffled = trainset_organ_shuffled.remove_columns(\"organ_major\")\n",
|
| 54 |
+
" \n",
|
| 55 |
+
" target_names = list(Counter(trainset_organ_shuffled[\"label\"]).keys())\n",
|
| 56 |
+
" target_name_id_dict = dict(zip(target_names, range(len(target_names))))\n",
|
| 57 |
+
" target_dict_list.append(target_name_id_dict)\n",
|
| 58 |
+
" \n",
|
| 59 |
+
" def classes_to_ids(example):\n",
|
| 60 |
+
" example[\"label\"] = target_name_id_dict[example[\"label\"]]\n",
|
| 61 |
+
" return example\n",
|
| 62 |
+
" labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)\n",
|
| 63 |
+
" \n",
|
| 64 |
+
" labeled_train_split = labeled_trainset.select(range(0, round(len(labeled_trainset) * 0.8)))\n",
|
| 65 |
+
" labeled_eval_split = labeled_trainset.select(range(round(len(labeled_trainset) * 0.8), len(labeled_trainset)))\n",
|
| 66 |
+
" \n",
|
| 67 |
+
" trained_labels = list(Counter(labeled_train_split[\"label\"]).keys())\n",
|
| 68 |
+
" def if_trained_label(example):\n",
|
| 69 |
+
" return example[\"label\"] in trained_labels\n",
|
| 70 |
+
" labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)\n",
|
| 71 |
+
" \n",
|
| 72 |
+
" dataset_list.append(labeled_train_split)\n",
|
| 73 |
+
" evalset_list.append(labeled_eval_split_subset)\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"trainset_dict = dict(zip(organ_list, dataset_list))\n",
|
| 76 |
+
"traintargetdict_dict = dict(zip(organ_list, target_dict_list))\n",
|
| 77 |
+
"evalset_dict = dict(zip(organ_list, evalset_list))"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": 2,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"outputs": [
|
| 85 |
+
{
|
| 86 |
+
"name": "stdout",
|
| 87 |
+
"output_type": "stream",
|
| 88 |
+
"text": [
|
| 89 |
+
"\n",
|
| 90 |
+
"===== Organ: spleen =====\n"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"name": "stderr",
|
| 95 |
+
"output_type": "stream",
|
| 96 |
+
"text": [
|
| 97 |
+
"padding...: 12330it [00:00, 76763.11it/s]\n",
|
| 98 |
+
"padding...: 3083it [00:00, 75593.59it/s]\n",
|
| 99 |
+
"spleen models: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"name": "stdout",
|
| 104 |
+
"output_type": "stream",
|
| 105 |
+
"text": [
|
| 106 |
+
"Training RandomForest...\n"
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"name": "stderr",
|
| 111 |
+
"output_type": "stream",
|
| 112 |
+
"text": [
|
| 113 |
+
"spleen models: 50%|█████ | 1/2 [00:00<00:00, 1.99it/s]/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
|
| 114 |
+
" warnings.warn(\n",
|
| 115 |
+
" "
|
| 116 |
+
]
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"name": "stdout",
|
| 120 |
+
"output_type": "stream",
|
| 121 |
+
"text": [
|
| 122 |
+
"RandomForest - Acc: 0.5864, Macro F1: 0.1947, Weighted F1: 0.5845\n",
|
| 123 |
+
"Training LogisticRegression...\n",
|
| 124 |
+
"LogisticRegression - Acc: 0.7415, Macro F1: 0.1419, Weighted F1: 0.6331\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"===== Organ: kidney =====\n"
|
| 127 |
+
]
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"name": "stderr",
|
| 131 |
+
"output_type": "stream",
|
| 132 |
+
"text": [
|
| 133 |
+
"padding...: 35199it [00:00, 54605.10it/s]\n",
|
| 134 |
+
"padding...: 8800it [00:00, 57420.64it/s]\n",
|
| 135 |
+
"kidney models: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"name": "stdout",
|
| 140 |
+
"output_type": "stream",
|
| 141 |
+
"text": [
|
| 142 |
+
"Training RandomForest...\n"
|
| 143 |
+
]
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"name": "stderr",
|
| 147 |
+
"output_type": "stream",
|
| 148 |
+
"text": [
|
| 149 |
+
"kidney models: 50%|█████ | 1/2 [00:01<00:01, 1.65s/it]"
|
| 150 |
+
]
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"name": "stdout",
|
| 154 |
+
"output_type": "stream",
|
| 155 |
+
"text": [
|
| 156 |
+
"RandomForest - Acc: 0.1755, Macro F1: 0.0826, Weighted F1: 0.1772\n",
|
| 157 |
+
"Training LogisticRegression...\n"
|
| 158 |
+
]
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
"name": "stderr",
|
| 162 |
+
"output_type": "stream",
|
| 163 |
+
"text": [
|
| 164 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
|
| 165 |
+
" warnings.warn(\n",
|
| 166 |
+
" "
|
| 167 |
+
]
|
| 168 |
+
},
|
| 169 |
+
{
|
| 170 |
+
"name": "stdout",
|
| 171 |
+
"output_type": "stream",
|
| 172 |
+
"text": [
|
| 173 |
+
"LogisticRegression - Acc: 0.3287, Macro F1: 0.0713, Weighted F1: 0.2267\n",
|
| 174 |
+
"\n",
|
| 175 |
+
"===== Organ: lung =====\n"
|
| 176 |
+
]
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"name": "stderr",
|
| 180 |
+
"output_type": "stream",
|
| 181 |
+
"text": [
|
| 182 |
+
"padding...: 26098it [00:00, 63650.72it/s]\n",
|
| 183 |
+
"padding...: 6525it [00:00, 61571.18it/s]\n",
|
| 184 |
+
"lung models: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 185 |
+
]
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"name": "stdout",
|
| 189 |
+
"output_type": "stream",
|
| 190 |
+
"text": [
|
| 191 |
+
"Training RandomForest...\n"
|
| 192 |
+
]
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"name": "stderr",
|
| 196 |
+
"output_type": "stream",
|
| 197 |
+
"text": [
|
| 198 |
+
"lung models: 50%|█████ | 1/2 [00:00<00:00, 1.05it/s]"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"name": "stdout",
|
| 203 |
+
"output_type": "stream",
|
| 204 |
+
"text": [
|
| 205 |
+
"RandomForest - Acc: 0.2077, Macro F1: 0.0910, Weighted F1: 0.2066\n",
|
| 206 |
+
"Training LogisticRegression...\n"
|
| 207 |
+
]
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
"name": "stderr",
|
| 211 |
+
"output_type": "stream",
|
| 212 |
+
"text": [
|
| 213 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
|
| 214 |
+
" warnings.warn(\n",
|
| 215 |
+
" "
|
| 216 |
+
]
|
| 217 |
+
},
|
| 218 |
+
{
|
| 219 |
+
"name": "stdout",
|
| 220 |
+
"output_type": "stream",
|
| 221 |
+
"text": [
|
| 222 |
+
"LogisticRegression - Acc: 0.3099, Macro F1: 0.0761, Weighted F1: 0.2399\n",
|
| 223 |
+
"\n",
|
| 224 |
+
"===== Organ: brain =====\n"
|
| 225 |
+
]
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"name": "stderr",
|
| 229 |
+
"output_type": "stream",
|
| 230 |
+
"text": [
|
| 231 |
+
"padding...: 10656it [00:00, 67287.79it/s]\n",
|
| 232 |
+
"padding...: 2664it [00:00, 75149.65it/s]\n",
|
| 233 |
+
"brain models: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 234 |
+
]
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"name": "stdout",
|
| 238 |
+
"output_type": "stream",
|
| 239 |
+
"text": [
|
| 240 |
+
"Training RandomForest...\n"
|
| 241 |
+
]
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"name": "stderr",
|
| 245 |
+
"output_type": "stream",
|
| 246 |
+
"text": [
|
| 247 |
+
"brain models: 50%|█████ | 1/2 [00:00<00:00, 2.21it/s]"
|
| 248 |
+
]
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"name": "stdout",
|
| 252 |
+
"output_type": "stream",
|
| 253 |
+
"text": [
|
| 254 |
+
"RandomForest - Acc: 0.7459, Macro F1: 0.1863, Weighted F1: 0.7495\n",
|
| 255 |
+
"Training LogisticRegression...\n"
|
| 256 |
+
]
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"name": "stderr",
|
| 260 |
+
"output_type": "stream",
|
| 261 |
+
"text": [
|
| 262 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
|
| 263 |
+
" warnings.warn(\n",
|
| 264 |
+
" "
|
| 265 |
+
]
|
| 266 |
+
},
|
| 267 |
+
{
|
| 268 |
+
"name": "stdout",
|
| 269 |
+
"output_type": "stream",
|
| 270 |
+
"text": [
|
| 271 |
+
"LogisticRegression - Acc: 0.8622, Macro F1: 0.1543, Weighted F1: 0.7985\n",
|
| 272 |
+
"\n",
|
| 273 |
+
"===== Organ: placenta =====\n"
|
| 274 |
+
]
|
| 275 |
+
},
|
| 276 |
+
{
|
| 277 |
+
"name": "stderr",
|
| 278 |
+
"output_type": "stream",
|
| 279 |
+
"text": [
|
| 280 |
+
"padding...: 7415it [00:00, 54391.55it/s]\n",
|
| 281 |
+
"padding...: 1854it [00:00, 57379.91it/s]\n",
|
| 282 |
+
"placenta models: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 283 |
+
]
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"name": "stdout",
|
| 287 |
+
"output_type": "stream",
|
| 288 |
+
"text": [
|
| 289 |
+
"Training RandomForest...\n"
|
| 290 |
+
]
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"name": "stderr",
|
| 294 |
+
"output_type": "stream",
|
| 295 |
+
"text": [
|
| 296 |
+
"placenta models: 50%|█████ | 1/2 [00:00<00:00, 1.88it/s]/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
|
| 297 |
+
" warnings.warn(\n",
|
| 298 |
+
" "
|
| 299 |
+
]
|
| 300 |
+
},
|
| 301 |
+
{
|
| 302 |
+
"name": "stdout",
|
| 303 |
+
"output_type": "stream",
|
| 304 |
+
"text": [
|
| 305 |
+
"RandomForest - Acc: 0.6009, Macro F1: 0.3471, Weighted F1: 0.5959\n",
|
| 306 |
+
"Training LogisticRegression...\n",
|
| 307 |
+
"LogisticRegression - Acc: 0.7406, Macro F1: 0.2836, Weighted F1: 0.6302\n",
|
| 308 |
+
"\n",
|
| 309 |
+
"===== Organ: immune =====\n"
|
| 310 |
+
]
|
| 311 |
+
},
|
| 312 |
+
{
|
| 313 |
+
"name": "stderr",
|
| 314 |
+
"output_type": "stream",
|
| 315 |
+
"text": [
|
| 316 |
+
"padding...: 20562it [00:00, 74370.86it/s]\n",
|
| 317 |
+
"padding...: 5140it [00:00, 70895.86it/s]\n",
|
| 318 |
+
"immune models: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 319 |
+
]
|
| 320 |
+
},
|
| 321 |
+
{
|
| 322 |
+
"name": "stdout",
|
| 323 |
+
"output_type": "stream",
|
| 324 |
+
"text": [
|
| 325 |
+
"Training RandomForest...\n"
|
| 326 |
+
]
|
| 327 |
+
},
|
| 328 |
+
{
|
| 329 |
+
"name": "stderr",
|
| 330 |
+
"output_type": "stream",
|
| 331 |
+
"text": [
|
| 332 |
+
"immune models: 50%|█████ | 1/2 [00:00<00:00, 1.25it/s]"
|
| 333 |
+
]
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"name": "stdout",
|
| 337 |
+
"output_type": "stream",
|
| 338 |
+
"text": [
|
| 339 |
+
"RandomForest - Acc: 0.2008, Macro F1: 0.1312, Weighted F1: 0.2005\n",
|
| 340 |
+
"Training LogisticRegression...\n"
|
| 341 |
+
]
|
| 342 |
+
},
|
| 343 |
+
{
|
| 344 |
+
"name": "stderr",
|
| 345 |
+
"output_type": "stream",
|
| 346 |
+
"text": [
|
| 347 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
|
| 348 |
+
" warnings.warn(\n",
|
| 349 |
+
" "
|
| 350 |
+
]
|
| 351 |
+
},
|
| 352 |
+
{
|
| 353 |
+
"name": "stdout",
|
| 354 |
+
"output_type": "stream",
|
| 355 |
+
"text": [
|
| 356 |
+
"LogisticRegression - Acc: 0.2749, Macro F1: 0.0921, Weighted F1: 0.1488\n",
|
| 357 |
+
"\n",
|
| 358 |
+
"===== Organ: large_intestine =====\n"
|
| 359 |
+
]
|
| 360 |
+
},
|
| 361 |
+
{
|
| 362 |
+
"name": "stderr",
|
| 363 |
+
"output_type": "stream",
|
| 364 |
+
"text": [
|
| 365 |
+
"padding...: 39678it [00:00, 74202.67it/s]\n",
|
| 366 |
+
"padding...: 9920it [00:00, 77582.36it/s]\n",
|
| 367 |
+
"large_intestine models: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 368 |
+
]
|
| 369 |
+
},
|
| 370 |
+
{
|
| 371 |
+
"name": "stdout",
|
| 372 |
+
"output_type": "stream",
|
| 373 |
+
"text": [
|
| 374 |
+
"Training RandomForest...\n"
|
| 375 |
+
]
|
| 376 |
+
},
|
| 377 |
+
{
|
| 378 |
+
"name": "stderr",
|
| 379 |
+
"output_type": "stream",
|
| 380 |
+
"text": [
|
| 381 |
+
"large_intestine models: 50%|█████ | 1/2 [00:01<00:01, 1.47s/it]"
|
| 382 |
+
]
|
| 383 |
+
},
|
| 384 |
+
{
|
| 385 |
+
"name": "stdout",
|
| 386 |
+
"output_type": "stream",
|
| 387 |
+
"text": [
|
| 388 |
+
"RandomForest - Acc: 0.2541, Macro F1: 0.0983, Weighted F1: 0.2556\n",
|
| 389 |
+
"Training LogisticRegression...\n"
|
| 390 |
+
]
|
| 391 |
+
},
|
| 392 |
+
{
|
| 393 |
+
"name": "stderr",
|
| 394 |
+
"output_type": "stream",
|
| 395 |
+
"text": [
|
| 396 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
|
| 397 |
+
" warnings.warn(\n",
|
| 398 |
+
" "
|
| 399 |
+
]
|
| 400 |
+
},
|
| 401 |
+
{
|
| 402 |
+
"name": "stdout",
|
| 403 |
+
"output_type": "stream",
|
| 404 |
+
"text": [
|
| 405 |
+
"LogisticRegression - Acc: 0.3095, Macro F1: 0.0843, Weighted F1: 0.2555\n",
|
| 406 |
+
"\n",
|
| 407 |
+
"===== Organ: pancreas =====\n"
|
| 408 |
+
]
|
| 409 |
+
},
|
| 410 |
+
{
|
| 411 |
+
"name": "stderr",
|
| 412 |
+
"output_type": "stream",
|
| 413 |
+
"text": [
|
| 414 |
+
"padding...: 21934it [00:00, 63776.95it/s]\n",
|
| 415 |
+
"padding...: 5484it [00:00, 71125.95it/s]\n",
|
| 416 |
+
"pancreas models: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 417 |
+
]
|
| 418 |
+
},
|
| 419 |
+
{
|
| 420 |
+
"name": "stdout",
|
| 421 |
+
"output_type": "stream",
|
| 422 |
+
"text": [
|
| 423 |
+
"Training RandomForest...\n"
|
| 424 |
+
]
|
| 425 |
+
},
|
| 426 |
+
{
|
| 427 |
+
"name": "stderr",
|
| 428 |
+
"output_type": "stream",
|
| 429 |
+
"text": [
|
| 430 |
+
"pancreas models: 50%|█████ | 1/2 [00:00<00:00, 1.19it/s]"
|
| 431 |
+
]
|
| 432 |
+
},
|
| 433 |
+
{
|
| 434 |
+
"name": "stdout",
|
| 435 |
+
"output_type": "stream",
|
| 436 |
+
"text": [
|
| 437 |
+
"RandomForest - Acc: 0.2438, Macro F1: 0.1438, Weighted F1: 0.2424\n",
|
| 438 |
+
"Training LogisticRegression...\n"
|
| 439 |
+
]
|
| 440 |
+
},
|
| 441 |
+
{
|
| 442 |
+
"name": "stderr",
|
| 443 |
+
"output_type": "stream",
|
| 444 |
+
"text": [
|
| 445 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
|
| 446 |
+
" warnings.warn(\n",
|
| 447 |
+
" "
|
| 448 |
+
]
|
| 449 |
+
},
|
| 450 |
+
{
|
| 451 |
+
"name": "stdout",
|
| 452 |
+
"output_type": "stream",
|
| 453 |
+
"text": [
|
| 454 |
+
"LogisticRegression - Acc: 0.3485, Macro F1: 0.1330, Weighted F1: 0.2601\n",
|
| 455 |
+
"\n",
|
| 456 |
+
"===== Organ: liver =====\n"
|
| 457 |
+
]
|
| 458 |
+
},
|
| 459 |
+
{
|
| 460 |
+
"name": "stderr",
|
| 461 |
+
"output_type": "stream",
|
| 462 |
+
"text": [
|
| 463 |
+
"padding...: 22427it [00:00, 64230.25it/s]\n",
|
| 464 |
+
"padding...: 5607it [00:00, 62494.75it/s]\n",
|
| 465 |
+
"liver models: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 466 |
+
]
|
| 467 |
+
},
|
| 468 |
+
{
|
| 469 |
+
"name": "stdout",
|
| 470 |
+
"output_type": "stream",
|
| 471 |
+
"text": [
|
| 472 |
+
"Training RandomForest...\n"
|
| 473 |
+
]
|
| 474 |
+
},
|
| 475 |
+
{
|
| 476 |
+
"name": "stderr",
|
| 477 |
+
"output_type": "stream",
|
| 478 |
+
"text": [
|
| 479 |
+
"liver models: 50%|█████ | 1/2 [00:00<00:00, 1.26it/s]"
|
| 480 |
+
]
|
| 481 |
+
},
|
| 482 |
+
{
|
| 483 |
+
"name": "stdout",
|
| 484 |
+
"output_type": "stream",
|
| 485 |
+
"text": [
|
| 486 |
+
"RandomForest - Acc: 0.2814, Macro F1: 0.1262, Weighted F1: 0.2809\n",
|
| 487 |
+
"Training LogisticRegression...\n"
|
| 488 |
+
]
|
| 489 |
+
},
|
| 490 |
+
{
|
| 491 |
+
"name": "stderr",
|
| 492 |
+
"output_type": "stream",
|
| 493 |
+
"text": [
|
| 494 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
|
| 495 |
+
" warnings.warn(\n",
|
| 496 |
+
" "
|
| 497 |
+
]
|
| 498 |
+
},
|
| 499 |
+
{
|
| 500 |
+
"name": "stdout",
|
| 501 |
+
"output_type": "stream",
|
| 502 |
+
"text": [
|
| 503 |
+
"LogisticRegression - Acc: 0.3512, Macro F1: 0.0738, Weighted F1: 0.2633\n"
|
| 504 |
+
]
|
| 505 |
+
},
|
| 506 |
+
{
|
| 507 |
+
"name": "stderr",
|
| 508 |
+
"output_type": "stream",
|
| 509 |
+
"text": [
|
| 510 |
+
"\r"
|
| 511 |
+
]
|
| 512 |
+
}
|
| 513 |
+
],
|
| 514 |
+
"source": [
|
| 515 |
+
"def extract_features(dataset):\n",
|
| 516 |
+
" seqs = dataset[\"input_ids\"]\n",
|
| 517 |
+
" max_len = max(len(s) for s in seqs)\n",
|
| 518 |
+
" padded = np.zeros((len(seqs), max_len), dtype=np.int64)\n",
|
| 519 |
+
" for i, s in tqdm(enumerate(seqs), desc=\"padding...\", colour=\"blue\"):\n",
|
| 520 |
+
" padded[i, :len(s)] = s\n",
|
| 521 |
+
" X = np.mean(padded, axis=1)[:, None] # simple mean pooling\n",
|
| 522 |
+
" y = np.array(dataset[\"label\"])\n",
|
| 523 |
+
" return X, y\n",
|
| 524 |
+
"\n",
|
| 525 |
+
"results = {}\n",
|
| 526 |
+
"\n",
|
| 527 |
+
"for organ in organ_list:\n",
|
| 528 |
+
" print(f\"\\n===== Organ: {organ} =====\")\n",
|
| 529 |
+
" organ_trainset = trainset_dict[organ]\n",
|
| 530 |
+
" organ_evalset = evalset_dict[organ]\n",
|
| 531 |
+
" \n",
|
| 532 |
+
" X_train, y_train = extract_features(organ_trainset)\n",
|
| 533 |
+
" X_test, y_test = extract_features(organ_evalset)\n",
|
| 534 |
+
" \n",
|
| 535 |
+
" classifiers = {\n",
|
| 536 |
+
" \"RandomForest\": RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1),\n",
|
| 537 |
+
" # \"SVM\": make_pipeline(StandardScaler(), SVC(kernel=\"rbf\", probability=True, random_state=42)),\n",
|
| 538 |
+
" \"LogisticRegression\": make_pipeline(StandardScaler(), LogisticRegression(max_iter=500, multi_class=\"multinomial\"))\n",
|
| 539 |
+
" }\n",
|
| 540 |
+
" \n",
|
| 541 |
+
" organ_results = {}\n",
|
| 542 |
+
" for clf_name, clf in tqdm(classifiers.items(), desc=f\"{organ} models\", leave=False):\n",
|
| 543 |
+
" print(f\"Training {clf_name}...\")\n",
|
| 544 |
+
" clf.fit(X_train, y_train)\n",
|
| 545 |
+
" preds = clf.predict(X_test)\n",
|
| 546 |
+
" acc = accuracy_score(y_test, preds)\n",
|
| 547 |
+
" macro_f1 = f1_score(y_test, preds, average=\"macro\")\n",
|
| 548 |
+
" weighted_f1 = f1_score(y_test, preds, average=\"weighted\")\n",
|
| 549 |
+
" organ_results[clf_name] = {\n",
|
| 550 |
+
" \"accuracy\": acc,\n",
|
| 551 |
+
" \"macro_f1\": macro_f1,\n",
|
| 552 |
+
" \"weighted_f1\": weighted_f1\n",
|
| 553 |
+
" }\n",
|
| 554 |
+
" print(f\"{clf_name} - Acc: {acc:.4f}, Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}\")\n",
|
| 555 |
+
" \n",
|
| 556 |
+
" results[organ] = organ_results\n"
|
| 557 |
+
]
|
| 558 |
+
},
|
| 559 |
+
{
|
| 560 |
+
"cell_type": "code",
|
| 561 |
+
"execution_count": 4,
|
| 562 |
+
"metadata": {},
|
| 563 |
+
"outputs": [
|
| 564 |
+
{
|
| 565 |
+
"name": "stdout",
|
| 566 |
+
"output_type": "stream",
|
| 567 |
+
"text": [
|
| 568 |
+
"\n",
|
| 569 |
+
"===== Organ: spleen =====\n"
|
| 570 |
+
]
|
| 571 |
+
},
|
| 572 |
+
{
|
| 573 |
+
"name": "stderr",
|
| 574 |
+
"output_type": "stream",
|
| 575 |
+
"text": [
|
| 576 |
+
"padding...: 12330it [00:00, 74149.68it/s]\n",
|
| 577 |
+
"padding...: 3083it [00:00, 79566.32it/s]\n",
|
| 578 |
+
"spleen models: 0%| | 0/1 [00:00<?, ?it/s]"
|
| 579 |
+
]
|
| 580 |
+
},
|
| 581 |
+
{
|
| 582 |
+
"name": "stdout",
|
| 583 |
+
"output_type": "stream",
|
| 584 |
+
"text": [
|
| 585 |
+
"Training SVM...\n"
|
| 586 |
+
]
|
| 587 |
+
},
|
| 588 |
+
{
|
| 589 |
+
"name": "stderr",
|
| 590 |
+
"output_type": "stream",
|
| 591 |
+
"text": [
|
| 592 |
+
" "
|
| 593 |
+
]
|
| 594 |
+
},
|
| 595 |
+
{
|
| 596 |
+
"name": "stdout",
|
| 597 |
+
"output_type": "stream",
|
| 598 |
+
"text": [
|
| 599 |
+
"SVM - Acc: 0.7434, Macro F1: 0.1421, Weighted F1: 0.6340\n",
|
| 600 |
+
"\n",
|
| 601 |
+
"===== Organ: kidney =====\n"
|
| 602 |
+
]
|
| 603 |
+
},
|
| 604 |
+
{
|
| 605 |
+
"name": "stderr",
|
| 606 |
+
"output_type": "stream",
|
| 607 |
+
"text": [
|
| 608 |
+
"padding...: 35199it [00:00, 54654.42it/s]\n",
|
| 609 |
+
"padding...: 8800it [00:00, 54786.08it/s]\n",
|
| 610 |
+
"kidney models: 0%| | 0/1 [00:00<?, ?it/s]"
|
| 611 |
+
]
|
| 612 |
+
},
|
| 613 |
+
{
|
| 614 |
+
"name": "stdout",
|
| 615 |
+
"output_type": "stream",
|
| 616 |
+
"text": [
|
| 617 |
+
"Training SVM...\n"
|
| 618 |
+
]
|
| 619 |
+
},
|
| 620 |
+
{
|
| 621 |
+
"name": "stderr",
|
| 622 |
+
"output_type": "stream",
|
| 623 |
+
"text": [
|
| 624 |
+
" "
|
| 625 |
+
]
|
| 626 |
+
},
|
| 627 |
+
{
|
| 628 |
+
"name": "stdout",
|
| 629 |
+
"output_type": "stream",
|
| 630 |
+
"text": [
|
| 631 |
+
"SVM - Acc: 0.3340, Macro F1: 0.0731, Weighted F1: 0.2334\n",
|
| 632 |
+
"\n",
|
| 633 |
+
"===== Organ: lung =====\n"
|
| 634 |
+
]
|
| 635 |
+
},
|
| 636 |
+
{
|
| 637 |
+
"name": "stderr",
|
| 638 |
+
"output_type": "stream",
|
| 639 |
+
"text": [
|
| 640 |
+
"padding...: 26098it [00:00, 63652.31it/s]\n",
|
| 641 |
+
"padding...: 6525it [00:00, 63915.46it/s]\n",
|
| 642 |
+
"lung models: 0%| | 0/1 [00:00<?, ?it/s]"
|
| 643 |
+
]
|
| 644 |
+
},
|
| 645 |
+
{
|
| 646 |
+
"name": "stdout",
|
| 647 |
+
"output_type": "stream",
|
| 648 |
+
"text": [
|
| 649 |
+
"Training SVM...\n"
|
| 650 |
+
]
|
| 651 |
+
},
|
| 652 |
+
{
|
| 653 |
+
"name": "stderr",
|
| 654 |
+
"output_type": "stream",
|
| 655 |
+
"text": [
|
| 656 |
+
" "
|
| 657 |
+
]
|
| 658 |
+
},
|
| 659 |
+
{
|
| 660 |
+
"name": "stdout",
|
| 661 |
+
"output_type": "stream",
|
| 662 |
+
"text": [
|
| 663 |
+
"SVM - Acc: 0.3137, Macro F1: 0.0773, Weighted F1: 0.2438\n",
|
| 664 |
+
"\n",
|
| 665 |
+
"===== Organ: brain =====\n"
|
| 666 |
+
]
|
| 667 |
+
},
|
| 668 |
+
{
|
| 669 |
+
"name": "stderr",
|
| 670 |
+
"output_type": "stream",
|
| 671 |
+
"text": [
|
| 672 |
+
"padding...: 10656it [00:00, 73057.45it/s]\n",
|
| 673 |
+
"padding...: 2664it [00:00, 75210.35it/s]\n",
|
| 674 |
+
"brain models: 0%| | 0/1 [00:00<?, ?it/s]"
|
| 675 |
+
]
|
| 676 |
+
},
|
| 677 |
+
{
|
| 678 |
+
"name": "stdout",
|
| 679 |
+
"output_type": "stream",
|
| 680 |
+
"text": [
|
| 681 |
+
"Training SVM...\n"
|
| 682 |
+
]
|
| 683 |
+
},
|
| 684 |
+
{
|
| 685 |
+
"name": "stderr",
|
| 686 |
+
"output_type": "stream",
|
| 687 |
+
"text": [
|
| 688 |
+
" "
|
| 689 |
+
]
|
| 690 |
+
},
|
| 691 |
+
{
|
| 692 |
+
"name": "stdout",
|
| 693 |
+
"output_type": "stream",
|
| 694 |
+
"text": [
|
| 695 |
+
"SVM - Acc: 0.8622, Macro F1: 0.1543, Weighted F1: 0.7985\n",
|
| 696 |
+
"\n",
|
| 697 |
+
"===== Organ: placenta =====\n"
|
| 698 |
+
]
|
| 699 |
+
},
|
| 700 |
+
{
|
| 701 |
+
"name": "stderr",
|
| 702 |
+
"output_type": "stream",
|
| 703 |
+
"text": [
|
| 704 |
+
"padding...: 7415it [00:00, 54724.23it/s]\n",
|
| 705 |
+
"padding...: 1854it [00:00, 57124.05it/s]\n",
|
| 706 |
+
"placenta models: 0%| | 0/1 [00:00<?, ?it/s]"
|
| 707 |
+
]
|
| 708 |
+
},
|
| 709 |
+
{
|
| 710 |
+
"name": "stdout",
|
| 711 |
+
"output_type": "stream",
|
| 712 |
+
"text": [
|
| 713 |
+
"Training SVM...\n"
|
| 714 |
+
]
|
| 715 |
+
},
|
| 716 |
+
{
|
| 717 |
+
"name": "stderr",
|
| 718 |
+
"output_type": "stream",
|
| 719 |
+
"text": [
|
| 720 |
+
" "
|
| 721 |
+
]
|
| 722 |
+
},
|
| 723 |
+
{
|
| 724 |
+
"name": "stdout",
|
| 725 |
+
"output_type": "stream",
|
| 726 |
+
"text": [
|
| 727 |
+
"SVM - Acc: 0.7406, Macro F1: 0.2836, Weighted F1: 0.6302\n",
|
| 728 |
+
"\n",
|
| 729 |
+
"===== Organ: immune =====\n"
|
| 730 |
+
]
|
| 731 |
+
},
|
| 732 |
+
{
|
| 733 |
+
"name": "stderr",
|
| 734 |
+
"output_type": "stream",
|
| 735 |
+
"text": [
|
| 736 |
+
"padding...: 20562it [00:00, 74360.35it/s]\n",
|
| 737 |
+
"padding...: 5140it [00:00, 73610.91it/s]\n",
|
| 738 |
+
"immune models: 0%| | 0/1 [00:00<?, ?it/s]"
|
| 739 |
+
]
|
| 740 |
+
},
|
| 741 |
+
{
|
| 742 |
+
"name": "stdout",
|
| 743 |
+
"output_type": "stream",
|
| 744 |
+
"text": [
|
| 745 |
+
"Training SVM...\n"
|
| 746 |
+
]
|
| 747 |
+
},
|
| 748 |
+
{
|
| 749 |
+
"name": "stderr",
|
| 750 |
+
"output_type": "stream",
|
| 751 |
+
"text": [
|
| 752 |
+
" "
|
| 753 |
+
]
|
| 754 |
+
},
|
| 755 |
+
{
|
| 756 |
+
"name": "stdout",
|
| 757 |
+
"output_type": "stream",
|
| 758 |
+
"text": [
|
| 759 |
+
"SVM - Acc: 0.2969, Macro F1: 0.1286, Weighted F1: 0.2058\n",
|
| 760 |
+
"\n",
|
| 761 |
+
"===== Organ: large_intestine =====\n"
|
| 762 |
+
]
|
| 763 |
+
},
|
| 764 |
+
{
|
| 765 |
+
"name": "stderr",
|
| 766 |
+
"output_type": "stream",
|
| 767 |
+
"text": [
|
| 768 |
+
"padding...: 39678it [00:00, 78336.69it/s]\n",
|
| 769 |
+
"padding...: 9920it [00:00, 77432.63it/s]\n",
|
| 770 |
+
"large_intestine models: 0%| | 0/1 [00:00<?, ?it/s]"
|
| 771 |
+
]
|
| 772 |
+
},
|
| 773 |
+
{
|
| 774 |
+
"name": "stdout",
|
| 775 |
+
"output_type": "stream",
|
| 776 |
+
"text": [
|
| 777 |
+
"Training SVM...\n"
|
| 778 |
+
]
|
| 779 |
+
},
|
| 780 |
+
{
|
| 781 |
+
"name": "stderr",
|
| 782 |
+
"output_type": "stream",
|
| 783 |
+
"text": [
|
| 784 |
+
" "
|
| 785 |
+
]
|
| 786 |
+
},
|
| 787 |
+
{
|
| 788 |
+
"name": "stdout",
|
| 789 |
+
"output_type": "stream",
|
| 790 |
+
"text": [
|
| 791 |
+
"SVM - Acc: 0.3850, Macro F1: 0.1027, Weighted F1: 0.3283\n",
|
| 792 |
+
"\n",
|
| 793 |
+
"===== Organ: pancreas =====\n"
|
| 794 |
+
]
|
| 795 |
+
},
|
| 796 |
+
{
|
| 797 |
+
"name": "stderr",
|
| 798 |
+
"output_type": "stream",
|
| 799 |
+
"text": [
|
| 800 |
+
"padding...: 21934it [00:00, 76007.99it/s]\n",
|
| 801 |
+
"padding...: 5484it [00:00, 75661.05it/s]\n",
|
| 802 |
+
"pancreas models: 0%| | 0/1 [00:00<?, ?it/s]"
|
| 803 |
+
]
|
| 804 |
+
},
|
| 805 |
+
{
|
| 806 |
+
"name": "stdout",
|
| 807 |
+
"output_type": "stream",
|
| 808 |
+
"text": [
|
| 809 |
+
"Training SVM...\n"
|
| 810 |
+
]
|
| 811 |
+
},
|
| 812 |
+
{
|
| 813 |
+
"name": "stderr",
|
| 814 |
+
"output_type": "stream",
|
| 815 |
+
"text": [
|
| 816 |
+
" "
|
| 817 |
+
]
|
| 818 |
+
},
|
| 819 |
+
{
|
| 820 |
+
"name": "stdout",
|
| 821 |
+
"output_type": "stream",
|
| 822 |
+
"text": [
|
| 823 |
+
"SVM - Acc: 0.3769, Macro F1: 0.1398, Weighted F1: 0.2843\n",
|
| 824 |
+
"\n",
|
| 825 |
+
"===== Organ: liver =====\n"
|
| 826 |
+
]
|
| 827 |
+
},
|
| 828 |
+
{
|
| 829 |
+
"name": "stderr",
|
| 830 |
+
"output_type": "stream",
|
| 831 |
+
"text": [
|
| 832 |
+
"padding...: 22427it [00:00, 65347.56it/s]\n",
|
| 833 |
+
"padding...: 5607it [00:00, 66067.53it/s]\n",
|
| 834 |
+
"liver models: 0%| | 0/1 [00:00<?, ?it/s]"
|
| 835 |
+
]
|
| 836 |
+
},
|
| 837 |
+
{
|
| 838 |
+
"name": "stdout",
|
| 839 |
+
"output_type": "stream",
|
| 840 |
+
"text": [
|
| 841 |
+
"Training SVM...\n"
|
| 842 |
+
]
|
| 843 |
+
},
|
| 844 |
+
{
|
| 845 |
+
"name": "stderr",
|
| 846 |
+
"output_type": "stream",
|
| 847 |
+
"text": [
|
| 848 |
+
" "
|
| 849 |
+
]
|
| 850 |
+
},
|
| 851 |
+
{
|
| 852 |
+
"name": "stdout",
|
| 853 |
+
"output_type": "stream",
|
| 854 |
+
"text": [
|
| 855 |
+
"SVM - Acc: 0.3820, Macro F1: 0.1061, Weighted F1: 0.3183\n"
|
| 856 |
+
]
|
| 857 |
+
},
|
| 858 |
+
{
|
| 859 |
+
"name": "stderr",
|
| 860 |
+
"output_type": "stream",
|
| 861 |
+
"text": [
|
| 862 |
+
"\r"
|
| 863 |
+
]
|
| 864 |
+
}
|
| 865 |
+
],
|
| 866 |
+
"source": [
|
| 867 |
+
"def extract_features(dataset):\n",
|
| 868 |
+
" seqs = dataset[\"input_ids\"]\n",
|
| 869 |
+
" max_len = max(len(s) for s in seqs)\n",
|
| 870 |
+
" padded = np.zeros((len(seqs), max_len), dtype=np.int64)\n",
|
| 871 |
+
" for i, s in tqdm(enumerate(seqs), desc=\"padding...\", colour=\"blue\"):\n",
|
| 872 |
+
" padded[i, :len(s)] = s\n",
|
| 873 |
+
" X = np.mean(padded, axis=1)[:, None] # simple mean pooling\n",
|
| 874 |
+
" y = np.array(dataset[\"label\"])\n",
|
| 875 |
+
" return X, y\n",
|
| 876 |
+
"\n",
|
| 877 |
+
"results = {}\n",
|
| 878 |
+
"\n",
|
| 879 |
+
"for organ in organ_list:\n",
|
| 880 |
+
" print(f\"\\n===== Organ: {organ} =====\")\n",
|
| 881 |
+
" organ_trainset = trainset_dict[organ]\n",
|
| 882 |
+
" organ_evalset = evalset_dict[organ]\n",
|
| 883 |
+
" \n",
|
| 884 |
+
" X_train, y_train = extract_features(organ_trainset)\n",
|
| 885 |
+
" X_test, y_test = extract_features(organ_evalset)\n",
|
| 886 |
+
" \n",
|
| 887 |
+
" classifiers = {\n",
|
| 888 |
+
" # \"RandomForest\": RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1),\n",
|
| 889 |
+
" \"SVM\": make_pipeline(StandardScaler(), SVC(kernel=\"rbf\", probability=True, random_state=42)),\n",
|
| 890 |
+
" # \"LogisticRegression\": make_pipeline(StandardScaler(), LogisticRegression(max_iter=500, multi_class=\"multinomial\"))\n",
|
| 891 |
+
" }\n",
|
| 892 |
+
" \n",
|
| 893 |
+
" organ_results = {}\n",
|
| 894 |
+
" for clf_name, clf in tqdm(classifiers.items(), desc=f\"{organ} models\", leave=False):\n",
|
| 895 |
+
" print(f\"Training {clf_name}...\")\n",
|
| 896 |
+
" clf.fit(X_train, y_train)\n",
|
| 897 |
+
" preds = clf.predict(X_test)\n",
|
| 898 |
+
" acc = accuracy_score(y_test, preds)\n",
|
| 899 |
+
" macro_f1 = f1_score(y_test, preds, average=\"macro\")\n",
|
| 900 |
+
" weighted_f1 = f1_score(y_test, preds, average=\"weighted\")\n",
|
| 901 |
+
" organ_results[clf_name] = {\n",
|
| 902 |
+
" \"accuracy\": acc,\n",
|
| 903 |
+
" \"macro_f1\": macro_f1,\n",
|
| 904 |
+
" \"weighted_f1\": weighted_f1\n",
|
| 905 |
+
" }\n",
|
| 906 |
+
" print(f\"{clf_name} - Acc: {acc:.4f}, Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}\")\n",
|
| 907 |
+
" \n",
|
| 908 |
+
" results[organ] = organ_results\n"
|
| 909 |
+
]
|
| 910 |
+
}
|
| 911 |
+
],
|
| 912 |
+
"metadata": {
|
| 913 |
+
"kernelspec": {
|
| 914 |
+
"display_name": "Python 3",
|
| 915 |
+
"language": "python",
|
| 916 |
+
"name": "python3"
|
| 917 |
+
},
|
| 918 |
+
"language_info": {
|
| 919 |
+
"codemirror_mode": {
|
| 920 |
+
"name": "ipython",
|
| 921 |
+
"version": 3
|
| 922 |
+
},
|
| 923 |
+
"file_extension": ".py",
|
| 924 |
+
"mimetype": "text/x-python",
|
| 925 |
+
"name": "python",
|
| 926 |
+
"nbconvert_exporter": "python",
|
| 927 |
+
"pygments_lexer": "ipython3",
|
| 928 |
+
"version": "3.11.7"
|
| 929 |
+
}
|
| 930 |
+
},
|
| 931 |
+
"nbformat": 4,
|
| 932 |
+
"nbformat_minor": 2
|
| 933 |
+
}
|
Downstream_tasks/Zero_shot_batch_effect/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/.gitignore
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Ignore specific files from this repository, below is a long list of defaults
|
| 2 |
+
## to ignore from various code editors and IDEs
|
| 3 |
+
# Python related
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.pyc
|
| 6 |
+
*.pyo
|
| 7 |
+
*.egg-info/
|
| 8 |
+
|
| 9 |
+
# Output folder with outputs of notebooks
|
| 10 |
+
output/
|
| 11 |
+
|
| 12 |
+
# Data should be downloaded from Zenodo, not stored in the repository
|
| 13 |
+
data/
|
| 14 |
+
|
| 15 |
+
# Build directory
|
| 16 |
+
build/
|
| 17 |
+
|
| 18 |
+
# big model files
|
| 19 |
+
*.pkl
|
| 20 |
+
*.bin
|
| 21 |
+
|
| 22 |
+
## Ignore Visual Studio temporary files, build results, and
|
| 23 |
+
## files generated by popular Visual Studio add-ons.
|
| 24 |
+
##
|
| 25 |
+
## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
|
| 26 |
+
|
| 27 |
+
# User-specific files
|
| 28 |
+
*.rsuser
|
| 29 |
+
*.suo
|
| 30 |
+
*.user
|
| 31 |
+
*.userosscache
|
| 32 |
+
*.sln.docstates
|
| 33 |
+
|
| 34 |
+
# User-specific files (MonoDevelop/Xamarin Studio)
|
| 35 |
+
*.userprefs
|
| 36 |
+
|
| 37 |
+
# Mono auto generated files
|
| 38 |
+
mono_crash.*
|
| 39 |
+
|
| 40 |
+
# Build results
|
| 41 |
+
[Dd]ebug/
|
| 42 |
+
[Dd]ebugPublic/
|
| 43 |
+
[Rr]elease/
|
| 44 |
+
[Rr]eleases/
|
| 45 |
+
x64/
|
| 46 |
+
x86/
|
| 47 |
+
[Ww][Ii][Nn]32/
|
| 48 |
+
[Aa][Rr][Mm]/
|
| 49 |
+
[Aa][Rr][Mm]64/
|
| 50 |
+
bld/
|
| 51 |
+
[Bb]in/
|
| 52 |
+
[Oo]bj/
|
| 53 |
+
[Ll]og/
|
| 54 |
+
[Ll]ogs/
|
| 55 |
+
|
| 56 |
+
# Visual Studio 2015/2017 cache/options directory
|
| 57 |
+
.vs/
|
| 58 |
+
# Uncomment if you have tasks that create the project's static files in wwwroot
|
| 59 |
+
#wwwroot/
|
| 60 |
+
|
| 61 |
+
# Visual Studio 2017 auto generated files
|
| 62 |
+
Generated\ Files/
|
| 63 |
+
|
| 64 |
+
# MSTest test Results
|
| 65 |
+
[Tt]est[Rr]esult*/
|
| 66 |
+
[Bb]uild[Ll]og.*
|
| 67 |
+
|
| 68 |
+
# NUnit
|
| 69 |
+
*.VisualState.xml
|
| 70 |
+
TestResult.xml
|
| 71 |
+
nunit-*.xml
|
| 72 |
+
|
| 73 |
+
# Build Results of an ATL Project
|
| 74 |
+
[Dd]ebugPS/
|
| 75 |
+
[Rr]eleasePS/
|
| 76 |
+
dlldata.c
|
| 77 |
+
|
| 78 |
+
# Benchmark Results
|
| 79 |
+
BenchmarkDotNet.Artifacts/
|
| 80 |
+
|
| 81 |
+
# .NET Core
|
| 82 |
+
project.lock.json
|
| 83 |
+
project.fragment.lock.json
|
| 84 |
+
artifacts/
|
| 85 |
+
|
| 86 |
+
# ASP.NET Scaffolding
|
| 87 |
+
ScaffoldingReadMe.txt
|
| 88 |
+
|
| 89 |
+
# StyleCop
|
| 90 |
+
StyleCopReport.xml
|
| 91 |
+
|
| 92 |
+
# Files built by Visual Studio
|
| 93 |
+
*_i.c
|
| 94 |
+
*_p.c
|
| 95 |
+
*_h.h
|
| 96 |
+
*.ilk
|
| 97 |
+
*.meta
|
| 98 |
+
*.obj
|
| 99 |
+
*.iobj
|
| 100 |
+
*.pch
|
| 101 |
+
*.pdb
|
| 102 |
+
*.ipdb
|
| 103 |
+
*.pgc
|
| 104 |
+
*.pgd
|
| 105 |
+
*.rsp
|
| 106 |
+
*.sbr
|
| 107 |
+
*.tlb
|
| 108 |
+
*.tli
|
| 109 |
+
*.tlh
|
| 110 |
+
*.tmp
|
| 111 |
+
*.tmp_proj
|
| 112 |
+
*_wpftmp.csproj
|
| 113 |
+
*.log
|
| 114 |
+
*.tlog
|
| 115 |
+
*.vspscc
|
| 116 |
+
*.vssscc
|
| 117 |
+
.builds
|
| 118 |
+
*.pidb
|
| 119 |
+
*.svclog
|
| 120 |
+
*.scc
|
| 121 |
+
|
| 122 |
+
# Chutzpah Test files
|
| 123 |
+
_Chutzpah*
|
| 124 |
+
|
| 125 |
+
# Visual C++ cache files
|
| 126 |
+
ipch/
|
| 127 |
+
*.aps
|
| 128 |
+
*.ncb
|
| 129 |
+
*.opendb
|
| 130 |
+
*.opensdf
|
| 131 |
+
*.sdf
|
| 132 |
+
*.cachefile
|
| 133 |
+
*.VC.db
|
| 134 |
+
*.VC.VC.opendb
|
| 135 |
+
|
| 136 |
+
# Visual Studio profiler
|
| 137 |
+
*.psess
|
| 138 |
+
*.vsp
|
| 139 |
+
*.vspx
|
| 140 |
+
*.sap
|
| 141 |
+
|
| 142 |
+
# Visual Studio Trace Files
|
| 143 |
+
*.e2e
|
| 144 |
+
|
| 145 |
+
# TFS 2012 Local Workspace
|
| 146 |
+
$tf/
|
| 147 |
+
|
| 148 |
+
# Guidance Automation Toolkit
|
| 149 |
+
*.gpState
|
| 150 |
+
|
| 151 |
+
# ReSharper is a .NET coding add-in
|
| 152 |
+
_ReSharper*/
|
| 153 |
+
*.[Rr]e[Ss]harper
|
| 154 |
+
*.DotSettings.user
|
| 155 |
+
|
| 156 |
+
# TeamCity is a build add-in
|
| 157 |
+
_TeamCity*
|
| 158 |
+
|
| 159 |
+
# DotCover is a Code Coverage Tool
|
| 160 |
+
*.dotCover
|
| 161 |
+
|
| 162 |
+
# AxoCover is a Code Coverage Tool
|
| 163 |
+
.axoCover/*
|
| 164 |
+
!.axoCover/settings.json
|
| 165 |
+
|
| 166 |
+
# Coverlet is a free, cross platform Code Coverage Tool
|
| 167 |
+
coverage*.json
|
| 168 |
+
coverage*.xml
|
| 169 |
+
coverage*.info
|
| 170 |
+
|
| 171 |
+
# Visual Studio code coverage results
|
| 172 |
+
*.coverage
|
| 173 |
+
*.coveragexml
|
| 174 |
+
|
| 175 |
+
# NCrunch
|
| 176 |
+
_NCrunch_*
|
| 177 |
+
.*crunch*.local.xml
|
| 178 |
+
nCrunchTemp_*
|
| 179 |
+
|
| 180 |
+
# MightyMoose
|
| 181 |
+
*.mm.*
|
| 182 |
+
AutoTest.Net/
|
| 183 |
+
|
| 184 |
+
# Web workbench (sass)
|
| 185 |
+
.sass-cache/
|
| 186 |
+
|
| 187 |
+
# Installshield output folder
|
| 188 |
+
[Ee]xpress/
|
| 189 |
+
|
| 190 |
+
# DocProject is a documentation generator add-in
|
| 191 |
+
DocProject/buildhelp/
|
| 192 |
+
DocProject/Help/*.HxT
|
| 193 |
+
DocProject/Help/*.HxC
|
| 194 |
+
DocProject/Help/*.hhc
|
| 195 |
+
DocProject/Help/*.hhk
|
| 196 |
+
DocProject/Help/*.hhp
|
| 197 |
+
DocProject/Help/Html2
|
| 198 |
+
DocProject/Help/html
|
| 199 |
+
|
| 200 |
+
# Click-Once directory
|
| 201 |
+
publish/
|
| 202 |
+
|
| 203 |
+
# Publish Web Output
|
| 204 |
+
*.[Pp]ublish.xml
|
| 205 |
+
*.azurePubxml
|
| 206 |
+
# Note: Comment the next line if you want to checkin your web deploy settings,
|
| 207 |
+
# but database connection strings (with potential passwords) will be unencrypted
|
| 208 |
+
*.pubxml
|
| 209 |
+
*.publishproj
|
| 210 |
+
|
| 211 |
+
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
| 212 |
+
# checkin your Azure Web App publish settings, but sensitive information contained
|
| 213 |
+
# in these scripts will be unencrypted
|
| 214 |
+
PublishScripts/
|
| 215 |
+
|
| 216 |
+
# NuGet Packages
|
| 217 |
+
*.nupkg
|
| 218 |
+
# NuGet Symbol Packages
|
| 219 |
+
*.snupkg
|
| 220 |
+
# The packages folder can be ignored because of Package Restore
|
| 221 |
+
**/[Pp]ackages/*
|
| 222 |
+
# except build/, which is used as an MSBuild target.
|
| 223 |
+
!**/[Pp]ackages/build/
|
| 224 |
+
# Uncomment if necessary however generally it will be regenerated when needed
|
| 225 |
+
#!**/[Pp]ackages/repositories.config
|
| 226 |
+
# NuGet v3's project.json files produces more ignorable files
|
| 227 |
+
*.nuget.props
|
| 228 |
+
*.nuget.targets
|
| 229 |
+
|
| 230 |
+
# Microsoft Azure Build Output
|
| 231 |
+
csx/
|
| 232 |
+
*.build.csdef
|
| 233 |
+
|
| 234 |
+
# Microsoft Azure Emulator
|
| 235 |
+
ecf/
|
| 236 |
+
rcf/
|
| 237 |
+
|
| 238 |
+
# Windows Store app package directories and files
|
| 239 |
+
AppPackages/
|
| 240 |
+
BundleArtifacts/
|
| 241 |
+
Package.StoreAssociation.xml
|
| 242 |
+
_pkginfo.txt
|
| 243 |
+
*.appx
|
| 244 |
+
*.appxbundle
|
| 245 |
+
*.appxupload
|
| 246 |
+
|
| 247 |
+
# Visual Studio cache files
|
| 248 |
+
# files ending in .cache can be ignored
|
| 249 |
+
*.[Cc]ache
|
| 250 |
+
# but keep track of directories ending in .cache
|
| 251 |
+
!?*.[Cc]ache/
|
| 252 |
+
|
| 253 |
+
# Others
|
| 254 |
+
ClientBin/
|
| 255 |
+
~$*
|
| 256 |
+
*~
|
| 257 |
+
*.dbmdl
|
| 258 |
+
*.dbproj.schemaview
|
| 259 |
+
*.jfm
|
| 260 |
+
*.pfx
|
| 261 |
+
*.publishsettings
|
| 262 |
+
orleans.codegen.cs
|
| 263 |
+
|
| 264 |
+
# Including strong name files can present a security risk
|
| 265 |
+
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
| 266 |
+
#*.snk
|
| 267 |
+
|
| 268 |
+
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
| 269 |
+
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
| 270 |
+
#bower_components/
|
| 271 |
+
|
| 272 |
+
# RIA/Silverlight projects
|
| 273 |
+
Generated_Code/
|
| 274 |
+
|
| 275 |
+
# Backup & report files from converting an old project file
|
| 276 |
+
# to a newer Visual Studio version. Backup files are not needed,
|
| 277 |
+
# because we have git ;-)
|
| 278 |
+
_UpgradeReport_Files/
|
| 279 |
+
Backup*/
|
| 280 |
+
UpgradeLog*.XML
|
| 281 |
+
UpgradeLog*.htm
|
| 282 |
+
ServiceFabricBackup/
|
| 283 |
+
*.rptproj.bak
|
| 284 |
+
|
| 285 |
+
# SQL Server files
|
| 286 |
+
*.mdf
|
| 287 |
+
*.ldf
|
| 288 |
+
*.ndf
|
| 289 |
+
|
| 290 |
+
# Business Intelligence projects
|
| 291 |
+
*.rdl.data
|
| 292 |
+
*.bim.layout
|
| 293 |
+
*.bim_*.settings
|
| 294 |
+
*.rptproj.rsuser
|
| 295 |
+
*- [Bb]ackup.rdl
|
| 296 |
+
*- [Bb]ackup ([0-9]).rdl
|
| 297 |
+
*- [Bb]ackup ([0-9][0-9]).rdl
|
| 298 |
+
|
| 299 |
+
# Microsoft Fakes
|
| 300 |
+
FakesAssemblies/
|
| 301 |
+
|
| 302 |
+
# GhostDoc plugin setting file
|
| 303 |
+
*.GhostDoc.xml
|
| 304 |
+
|
| 305 |
+
# Node.js Tools for Visual Studio
|
| 306 |
+
.ntvs_analysis.dat
|
| 307 |
+
node_modules/
|
| 308 |
+
|
| 309 |
+
# Visual Studio 6 build log
|
| 310 |
+
*.plg
|
| 311 |
+
|
| 312 |
+
# Visual Studio 6 workspace options file
|
| 313 |
+
*.opt
|
| 314 |
+
|
| 315 |
+
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
| 316 |
+
*.vbw
|
| 317 |
+
|
| 318 |
+
# Visual Studio 6 auto-generated project file (contains which files were open etc.)
|
| 319 |
+
*.vbp
|
| 320 |
+
|
| 321 |
+
# Visual Studio 6 workspace and project file (working project files containing files to include in project)
|
| 322 |
+
*.dsw
|
| 323 |
+
*.dsp
|
| 324 |
+
|
| 325 |
+
# Visual Studio 6 technical files
|
| 326 |
+
*.ncb
|
| 327 |
+
*.aps
|
| 328 |
+
|
| 329 |
+
# Visual Studio LightSwitch build output
|
| 330 |
+
**/*.HTMLClient/GeneratedArtifacts
|
| 331 |
+
**/*.DesktopClient/GeneratedArtifacts
|
| 332 |
+
**/*.DesktopClient/ModelManifest.xml
|
| 333 |
+
**/*.Server/GeneratedArtifacts
|
| 334 |
+
**/*.Server/ModelManifest.xml
|
| 335 |
+
_Pvt_Extensions
|
| 336 |
+
|
| 337 |
+
# Paket dependency manager
|
| 338 |
+
.paket/paket.exe
|
| 339 |
+
paket-files/
|
| 340 |
+
|
| 341 |
+
# FAKE - F# Make
|
| 342 |
+
.fake/
|
| 343 |
+
|
| 344 |
+
# CodeRush personal settings
|
| 345 |
+
.cr/personal
|
| 346 |
+
|
| 347 |
+
# Python Tools for Visual Studio (PTVS)
|
| 348 |
+
__pycache__/
|
| 349 |
+
*.pyc
|
| 350 |
+
|
| 351 |
+
# Cake - Uncomment if you are using it
|
| 352 |
+
# tools/**
|
| 353 |
+
# !tools/packages.config
|
| 354 |
+
|
| 355 |
+
# Tabs Studio
|
| 356 |
+
*.tss
|
| 357 |
+
|
| 358 |
+
# Telerik's JustMock configuration file
|
| 359 |
+
*.jmconfig
|
| 360 |
+
|
| 361 |
+
# BizTalk build output
|
| 362 |
+
*.btp.cs
|
| 363 |
+
*.btm.cs
|
| 364 |
+
*.odx.cs
|
| 365 |
+
*.xsd.cs
|
| 366 |
+
|
| 367 |
+
# OpenCover UI analysis results
|
| 368 |
+
OpenCover/
|
| 369 |
+
|
| 370 |
+
# Azure Stream Analytics local run output
|
| 371 |
+
ASALocalRun/
|
| 372 |
+
|
| 373 |
+
# MSBuild Binary and Structured Log
|
| 374 |
+
*.binlog
|
| 375 |
+
|
| 376 |
+
# NVidia Nsight GPU debugger configuration file
|
| 377 |
+
*.nvuser
|
| 378 |
+
|
| 379 |
+
# MFractors (Xamarin productivity tool) working folder
|
| 380 |
+
.mfractor/
|
| 381 |
+
|
| 382 |
+
# Local History for Visual Studio
|
| 383 |
+
.localhistory/
|
| 384 |
+
|
| 385 |
+
# Visual Studio History (VSHistory) files
|
| 386 |
+
.vshistory/
|
| 387 |
+
|
| 388 |
+
# BeatPulse healthcheck temp database
|
| 389 |
+
healthchecksdb
|
| 390 |
+
|
| 391 |
+
# Backup folder for Package Reference Convert tool in Visual Studio 2017
|
| 392 |
+
MigrationBackup/
|
| 393 |
+
|
| 394 |
+
# Ionide (cross platform F# VS Code tools) working folder
|
| 395 |
+
.ionide/
|
| 396 |
+
|
| 397 |
+
# Fody - auto-generated XML schema
|
| 398 |
+
FodyWeavers.xsd
|
| 399 |
+
|
| 400 |
+
# VS Code files for those working on multiple tools
|
| 401 |
+
.vscode/*
|
| 402 |
+
!.vscode/settings.json
|
| 403 |
+
!.vscode/tasks.json
|
| 404 |
+
!.vscode/launch.json
|
| 405 |
+
!.vscode/extensions.json
|
| 406 |
+
*.code-workspace
|
| 407 |
+
|
| 408 |
+
# Local History for Visual Studio Code
|
| 409 |
+
.history/
|
| 410 |
+
|
| 411 |
+
# Windows Installer files from build outputs
|
| 412 |
+
*.cab
|
| 413 |
+
*.msi
|
| 414 |
+
*.msix
|
| 415 |
+
*.msm
|
| 416 |
+
*.msp
|
| 417 |
+
|
| 418 |
+
# JetBrains Rider
|
| 419 |
+
*.sln.iml
|
Downstream_tasks/Zero_shot_batch_effect/CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Microsoft Open Source Code of Conduct
|
| 2 |
+
|
| 3 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
| 4 |
+
|
| 5 |
+
Resources:
|
| 6 |
+
|
| 7 |
+
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
| 8 |
+
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
| 9 |
+
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
Downstream_tasks/Zero_shot_batch_effect/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) Microsoft Corporation.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE
|
Downstream_tasks/Zero_shot_batch_effect/README.md
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Foundation models in single-cell biology: evaluating zero-shot capabilities
|
| 2 |
+
|
| 3 |
+
[](https://www.biorxiv.org/content/10.1101/2023.10.16.561085) [](https://doi.org/10.6084/m9.figshare.24747228)
|
| 4 |
+
|
| 5 |
+
This repository contains the code that accompanies our paper, **Assessing the limits of zero-shot foundation models in single-cell biology**. You can find the preprint of the paper [here](https://www.biorxiv.org/content/10.1101/2023.10.16.561085).
|
| 6 |
+
|
| 7 |
+
## Project overview
|
| 8 |
+
|
| 9 |
+
In this project, we assess two proposed foundation models in the context of single-cell RNA-seq: Geneformer ([pub](https://www.nature.com/articles/s41586-023-06139-9), [code](https://huggingface.co/ctheodoris/Geneformer)) and scGPT ([pub](https://www.biorxiv.org/content/10.1101/2023.04.30.538439v2), [code](https://github.com/bowang-lab/scGPT)). We focus on evaluating the zero-shot capabilities of these models, specifically their ability to generalize beyond their original training objectives. Our evaluation targets two main tasks: cell type clustering and batch integration. In these tasks, we compare the performance of Geneformer and scGPT against two baselines: scVI ([pub](https://www.nature.com/articles/s41592-018-0229-2), [code](https://docs.scvi-tools.org/en/stable/user_guide/models/scvi.html)) and a heuristic method that selects highly variable genes (HVGs). We also investigate the performance of the models in reconstructing the gene expression profiles of cells, and compare it against the baselines - such as a mean expression value or average ranking.
|
| 10 |
+
|
| 11 |
+
## Dependencies
|
| 12 |
+
|
| 13 |
+
Currently the code requires the GPUs supported by flash attention, required for scGPT to run.
|
| 14 |
+
|
| 15 |
+
GPUs supported by flash attention are:
|
| 16 |
+
|
| 17 |
+
- Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100).
|
| 18 |
+
- Turing GPUs (T4, RTX 2080)
|
| 19 |
+
|
| 20 |
+
<details>
|
| 21 |
+
<summary>Packages version</summary>
|
| 22 |
+
|
| 23 |
+
This code has been tested with the following versions of the packages:
|
| 24 |
+
|
| 25 |
+
- Python - tested with `3.9`
|
| 26 |
+
- PyTorch - tested with - `1.13`
|
| 27 |
+
- CUDA - tested with `11.7`
|
| 28 |
+
- [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/v1.0.4) - depends on `v1.0.4`
|
| 29 |
+
- [scGPT](https://github.com/bowang-lab/scGPT/tree/v0.1.6) - depends on `v0.1.6`
|
| 30 |
+
- [Geneformer](https://huggingface.co/ctheodoris/Geneformer/tree/5d0082c1e188ab88997efa87891414fdc6e4f6ff) - depends on commit `5d0082c`
|
| 31 |
+
- [scIB](https://github.com/theislab/scib/tree/v1.0.4) - tested with `v1.0.4`
|
| 32 |
+
- [sc_foundation_evals](https://github.com/microsoft/zero-shot-scfoundation) `v0.1.0`
|
| 33 |
+
|
| 34 |
+
</details>
|
| 35 |
+
|
| 36 |
+
## Installation
|
| 37 |
+
|
| 38 |
+
Below you can find the instructions on how to install the dependencies for this project. We provide two options: using conda/mamba or using Docker.
|
| 39 |
+
|
| 40 |
+
<details>
|
| 41 |
+
<summary>Conda / Mamba</summary>
|
| 42 |
+
|
| 43 |
+
### Conda / Mamba
|
| 44 |
+
|
| 45 |
+
You can install the dependencies using conda. To do so, you need to have conda installed on your machine. If you don't have it, you can install it from [here](https://docs.conda.io/en/latest/miniconda.html).
|
| 46 |
+
|
| 47 |
+
We recommend using [mamba](https://mamba.readthedocs.io/en/latest/user_guide/mamba.html), since it is faster in our experience. You can install mamba following the guide [here](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html#operating-system-package-managers).
|
| 48 |
+
|
| 49 |
+
To simplify installation, we provide the installation script that creates a new conda environment with all the dependencies installed. You can run the following command to create the environment:
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
bash envs/installation.sh
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
If the installation is successful, you will see the following message:
|
| 56 |
+
|
| 57 |
+
```console
|
| 58 |
+
2024-08-22 19:49:26 SUCCESS: All packages installed successfully.
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
And you can activate the environment by running:
|
| 62 |
+
|
| 63 |
+
```bash
|
| 64 |
+
conda activate sc_foundation_evals
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
</details>
|
| 68 |
+
|
| 69 |
+
<details>
|
| 70 |
+
<summary>Docker</summary>
|
| 71 |
+
|
| 72 |
+
### Docker
|
| 73 |
+
|
| 74 |
+
The docker image is available on DockerHub [here](https://hub.docker.com/repository/docker/kzkedzierska/sc_foundation_evals/general). You can pull the image by running:
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
docker pull kzkedzierska/sc_foundation_evals
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
The image is based on the `cnstark/pytorch:1.13.0-py3.9.12-cuda11.7.1-ubuntu20.04` image, and has all the dependencies installed. The Dockerfile used to build the image can be found in the `envs/docker` directory.
|
| 81 |
+
|
| 82 |
+
You can also skip pulling the image since `docker` will pull it if needed. To run the interactive session with the image, you can use the following command:
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
docker run --gpus all -it kzkedzierska/sc_foundation_evals
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
If you want to be able to run the notebooks, run the image with the following tag:
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
docker run --gpus all -it --rm -p 8888:8888 -v ./:/workspace kzkedzierska/sc_foundation_evals:latest_notebook
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
And open the link provided in the terminal in your browser. It should look like this:
|
| 95 |
+
|
| 96 |
+
```console
|
| 97 |
+
[I 2024-08-23 22:15:13.015 ServerApp] Serving notebooks from local directory: /workspace
|
| 98 |
+
[I 2024-08-23 22:15:13.015 ServerApp] Jupyter Server 2.14.2 is running at:
|
| 99 |
+
[I 2024-08-23 22:15:13.015 ServerApp] http://localhost:8888/tree
|
| 100 |
+
[I 2024-08-23 22:15:13.015 ServerApp] http://127.0.0.1:8888/tree
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
For running the command on the server, consult the documentation of the server provider on how to forward the ports properly.
|
| 104 |
+
|
| 105 |
+
</details>
|
| 106 |
+
|
| 107 |
+
## Running the code
|
| 108 |
+
|
| 109 |
+
### Downloading the weights
|
| 110 |
+
|
| 111 |
+
To run notebooks you also need to have the weights of the models downloaded. scGPT weights are available [here](https://github.com/bowang-lab/scGPT#pretrained-scgpt-model-zoo) and Geneformer weights are available in its repository. As per the instructions in the Geneformer repository, make sure you have `git lfs` installed before downloading the weights via repository cloning.
|
| 112 |
+
|
| 113 |
+
### Copying this repository
|
| 114 |
+
|
| 115 |
+
To run the code, you need to clone this repository.
|
| 116 |
+
|
| 117 |
+
```bash
|
| 118 |
+
git clone https://github.com/microsoft/zero-shot-scfoundation
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
And download and unpack the data, stored at figshare (see [here](https://doi.org/10.6084/m9.figshare.24747228) for more details).
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
cd zero-shot-scfoundation
|
| 125 |
+
# download and unpack the data
|
| 126 |
+
wget https://figshare.com/ndownloader/files/43480497 -O data.zip
|
| 127 |
+
unzip data.zip && rm data.zip
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
### Notebooks
|
| 131 |
+
|
| 132 |
+
To best understand the code and its organization, please have a look at the notebooks. The `notebooks` directory currently contains the following notebooks:
|
| 133 |
+
|
| 134 |
+
- [scGPT_zero_shot](notebooks/scGPT_zero_shot.ipynb) - notebook for running scGPT zero-shot evaluation
|
| 135 |
+
- [zero_shot_Geneformer](notebooks/zero_shot_Geneformer.ipynb) - notebook for running Geneformer zero-shot evaluation
|
| 136 |
+
- [zero_shot_HVG_and_scVI](notebooks/zero_shot_HVG_and_scVI.ipynb) - notebook for running the baselines used in the paper, i.e. HVG and scVI.
|
| 137 |
+
|
| 138 |
+
## Any questions?
|
| 139 |
+
|
| 140 |
+
If you have any questions, or find any issues with the code, please open an issue in this repository. You can find more information on how to file an issue in [here](/SUPPORT.md). We also welcome any contributions to the code - be sure to checkout the **Contributing** section below.
|
| 141 |
+
|
| 142 |
+
## Contributing
|
| 143 |
+
|
| 144 |
+
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
| 145 |
+
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
| 146 |
+
the rights to use your contribution. For details, visit <https://cla.opensource.microsoft.com>.
|
| 147 |
+
|
| 148 |
+
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
| 149 |
+
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
| 150 |
+
provided by the bot. You will only need to do this once across all repos using our CLA.
|
| 151 |
+
|
| 152 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
| 153 |
+
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
| 154 |
+
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
| 155 |
+
|
| 156 |
+
## Trademarks
|
| 157 |
+
|
| 158 |
+
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
| 159 |
+
trademarks or logos is subject to and must follow
|
| 160 |
+
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
| 161 |
+
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
| 162 |
+
Any use of third-party trademarks or logos are subject to those third-party's policies.
|
Downstream_tasks/Zero_shot_batch_effect/SECURITY.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
|
| 2 |
+
|
| 3 |
+
## Security
|
| 4 |
+
|
| 5 |
+
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
|
| 6 |
+
|
| 7 |
+
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
|
| 8 |
+
|
| 9 |
+
## Reporting Security Issues
|
| 10 |
+
|
| 11 |
+
**Please do not report security vulnerabilities through public GitHub issues.**
|
| 12 |
+
|
| 13 |
+
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
|
| 14 |
+
|
| 15 |
+
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
|
| 16 |
+
|
| 17 |
+
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
| 18 |
+
|
| 19 |
+
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
| 20 |
+
|
| 21 |
+
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
| 22 |
+
* Full paths of source file(s) related to the manifestation of the issue
|
| 23 |
+
* The location of the affected source code (tag/branch/commit or direct URL)
|
| 24 |
+
* Any special configuration required to reproduce the issue
|
| 25 |
+
* Step-by-step instructions to reproduce the issue
|
| 26 |
+
* Proof-of-concept or exploit code (if possible)
|
| 27 |
+
* Impact of the issue, including how an attacker might exploit the issue
|
| 28 |
+
|
| 29 |
+
This information will help us triage your report more quickly.
|
| 30 |
+
|
| 31 |
+
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
|
| 32 |
+
|
| 33 |
+
## Preferred Languages
|
| 34 |
+
|
| 35 |
+
We prefer all communications to be in English.
|
| 36 |
+
|
| 37 |
+
## Policy
|
| 38 |
+
|
| 39 |
+
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
|
| 40 |
+
|
| 41 |
+
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
Downstream_tasks/Zero_shot_batch_effect/SUPPORT.md
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Support
|
| 2 |
+
|
| 3 |
+
## How to file issues and get help
|
| 4 |
+
|
| 5 |
+
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
|
| 6 |
+
issues before filing new issues to avoid duplicates. For new issues, file your bug, ask a question
|
| 7 |
+
or request a feature as a new Issue.
|
| 8 |
+
|
| 9 |
+
If you face an issue with installation or running the code, on top of the error message please describe
|
| 10 |
+
your environment well (what operating system do you use, if you use conda or virtual environment,
|
| 11 |
+
please list what versions of the packages are installed and available in your PATH at the time of
|
| 12 |
+
running the code). We will try to respond and help.
|
| 13 |
+
|
| 14 |
+
## Microsoft Support Policy
|
| 15 |
+
|
| 16 |
+
Support for this PROJECT is limited to the resources listed above.
|
Downstream_tasks/Zero_shot_batch_effect/envs/conda_env.yml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: sc_foundation_evals
|
| 2 |
+
channels:
|
| 3 |
+
- nvidia/label/cuda-11.7.0
|
| 4 |
+
- conda-forge
|
| 5 |
+
- bioconda
|
| 6 |
+
- defaults
|
| 7 |
+
dependencies:
|
| 8 |
+
- python=3.10
|
| 9 |
+
- cudatoolkit
|
| 10 |
+
- r-base=4.2.3
|
| 11 |
+
- ninja
|
| 12 |
+
- rpy2
|
| 13 |
+
- packaging
|
| 14 |
+
- gxx=11.4
|
| 15 |
+
- git-lfs
|
| 16 |
+
- pip>=21.1
|
| 17 |
+
- pip:
|
| 18 |
+
- --index-url https://download.pytorch.org/whl/cu117
|
| 19 |
+
- torch==1.13
|
| 20 |
+
- torchvision
|
| 21 |
+
- torchaudio
|
Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/Dockerfile
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM cnstark/pytorch:1.13.0-py3.9.12-cuda11.7.1-ubuntu20.04
|
| 2 |
+
|
| 3 |
+
# NAME sc_foundation_evals
|
| 4 |
+
|
| 5 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 6 |
+
|
| 7 |
+
RUN apt-get update && apt-get install -y wget git git-lfs && \
|
| 8 |
+
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb && \
|
| 9 |
+
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
| 10 |
+
rm cuda-keyring_1.1-1_all.deb && \
|
| 11 |
+
apt-get update && \
|
| 12 |
+
echo "tzdata tzdata/Areas select Europe" > /tmp/preseed.txt; \
|
| 13 |
+
echo "tzdata tzdata/Zones/Europe select Warsaw" >> /tmp/preseed.txt; \
|
| 14 |
+
debconf-set-selections /tmp/preseed.txt && \
|
| 15 |
+
apt-get install -y cuda-toolkit-11-7 && \
|
| 16 |
+
apt-get install -y r-base && \
|
| 17 |
+
apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
|
| 18 |
+
|
| 19 |
+
ENV PATH=/usr/local/cuda-11.7/bin${PATH:+:${PATH}}
|
| 20 |
+
ENV LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
|
| 21 |
+
|
| 22 |
+
RUN pip install packaging && \
|
| 23 |
+
pip install flash-attn==1.0.4 --no-build-isolation
|
| 24 |
+
|
| 25 |
+
RUN pip install scib[kBET,rpy2] colorlog PyComplexHeatmap wandb && \
|
| 26 |
+
pip install git+https://github.com/bowang-lab/scGPT.git@v0.1.6 && \
|
| 27 |
+
pip install git+https://huggingface.co/ctheodoris/Geneformer.git@5d0082c1e188ab88997efa87891414fdc6e4f6ff && \
|
| 28 |
+
pip install git+https://github.com/microsoft/zero-shot-scfoundation.git@v0.1.0
|
Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/test.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#! /usr/bin/env python
|
| 2 |
+
|
| 3 |
+
try:
|
| 4 |
+
from sc_foundation_evals.helpers.custom_logging import log
|
| 5 |
+
except ImportError:
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
logging.basicConfig(level=logging.DEBUG)
|
| 9 |
+
log = logging.getLogger(__name__)
|
| 10 |
+
msg = "Cannot load sc_foundation_evals custom logging module. Exiting..."
|
| 11 |
+
log.error(msg)
|
| 12 |
+
raise ImportError(msg)
|
| 13 |
+
|
| 14 |
+
log.info("Hello from the test script! This is to test the build process.")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def import_package(package_name):
|
| 18 |
+
"""
|
| 19 |
+
Try to import a package and return the package if successful.
|
| 20 |
+
Logs and raises an error if the package is not available.
|
| 21 |
+
"""
|
| 22 |
+
try:
|
| 23 |
+
package = __import__(package_name)
|
| 24 |
+
version = getattr(package, "__version__", None)
|
| 25 |
+
log.info(
|
| 26 |
+
f"Successfully imported {package_name}. "
|
| 27 |
+
f"Version: {version if version else 'unknown'}"
|
| 28 |
+
)
|
| 29 |
+
return package
|
| 30 |
+
|
| 31 |
+
except ImportError as e:
|
| 32 |
+
msg = f"Could not import required package: {package_name}"
|
| 33 |
+
log.error(f"{msg}: {e}")
|
| 34 |
+
raise ImportError(msg)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_cuda_availability():
|
| 38 |
+
"""
|
| 39 |
+
Check if CUDA is available and log the result.
|
| 40 |
+
"""
|
| 41 |
+
torch = import_package("torch")
|
| 42 |
+
if torch.cuda.is_available():
|
| 43 |
+
log.info("Success -- CUDA is available!")
|
| 44 |
+
else:
|
| 45 |
+
log.error(
|
| 46 |
+
"CUDA is not available. Please check your system configuration."
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def main():
|
| 51 |
+
try:
|
| 52 |
+
log.debug("Testing CUDA availability...")
|
| 53 |
+
test_cuda_availability()
|
| 54 |
+
log.debug("Testing loading scGPT...")
|
| 55 |
+
import_package("scgpt")
|
| 56 |
+
log.debug("Testing loading Geneformer...")
|
| 57 |
+
import_package("geneformer")
|
| 58 |
+
log.info("All tests passed successfully! :)")
|
| 59 |
+
|
| 60 |
+
except Exception as e:
|
| 61 |
+
log.error(f"An error occurred during the testing process: {e}")
|
| 62 |
+
raise
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
|
| 66 |
+
main()
|
Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/test_docker.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#! /bin/bash
|
| 2 |
+
|
| 3 |
+
# This script is used to test the docker image built by the Dockerfile in the same directory.
|
| 4 |
+
# The docker image is built by the following command:
|
| 5 |
+
# docker build -t kzkedzierska/sc_foundation_evals[:tag] .
|
| 6 |
+
|
| 7 |
+
# The script runs the docker image and executes the test.py script in the container.
|
| 8 |
+
# The test.py script is a simple script that imports the sc_foundation_evals package and prints the version of the package.
|
| 9 |
+
|
| 10 |
+
docker run \
|
| 11 |
+
--gpus all \
|
| 12 |
+
-v "$(pwd)":/workspace kzkedzierska/sc_foundation_evals \
|
| 13 |
+
python test.py
|
Downstream_tasks/Zero_shot_batch_effect/envs/docker/jupyter/Dockerfile
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM kzkedzierska/sc_foundation_evals:latest
|
| 2 |
+
|
| 3 |
+
# Install Jupyter Notebook
|
| 4 |
+
RUN pip install notebook
|
| 5 |
+
|
| 6 |
+
WORKDIR /workspace
|
| 7 |
+
|
| 8 |
+
# Expose the port Jupyter will run on
|
| 9 |
+
EXPOSE 8888
|
| 10 |
+
|
| 11 |
+
# Set the default command to run when starting the container
|
| 12 |
+
CMD ["jupyter", "notebook", "--ip=0.0.0.0", "--port=8888", "--no-browser", "--allow-root", "--NotebookApp.token=''", "--NotebookApp.password=''"]
|
Downstream_tasks/Zero_shot_batch_effect/envs/installation.sh
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#! /bin/bash
|
| 2 |
+
# exit on error
|
| 3 |
+
set -e
|
| 4 |
+
|
| 5 |
+
_script_name=$(basename "$0")
|
| 6 |
+
|
| 7 |
+
ENV_NAME="sc_foundation_evals"
|
| 8 |
+
|
| 9 |
+
warning() {
|
| 10 |
+
yellow='\033[0;33m'
|
| 11 |
+
nc='\033[0m'
|
| 12 |
+
echo -e "${yellow}$(date '+%Y-%m-%d %H:%M:%S') WARNING: $@${nc}" 1>&2
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
success() {
|
| 16 |
+
green='\033[0;32m'
|
| 17 |
+
nc='\033[0m'
|
| 18 |
+
echo -e "${green}$(date '+%Y-%m-%d %H:%M:%S') SUCCESS: $@${nc}"
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
error() {
|
| 22 |
+
red='\033[0;31m'
|
| 23 |
+
nc='\033[0m'
|
| 24 |
+
echo -e "${red}$(date '+%Y-%m-%d %H:%M:%S') ERROR: $@${nc}" 1>&2
|
| 25 |
+
usage_and_exit 1
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
msg() {
|
| 29 |
+
echo -e "$(date '+%Y-%m-%d %H:%M:%S') INFO: $@"
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
usage() {
|
| 33 |
+
echo -e "
|
| 34 |
+
|
| 35 |
+
USAGE: bash ${_script_name}
|
| 36 |
+
|
| 37 |
+
Script to install the package and set up the Conda environment.
|
| 38 |
+
|
| 39 |
+
EXAMPLES:
|
| 40 |
+
Install the package and set up the Conda environment:
|
| 41 |
+
bash ${_script_name}
|
| 42 |
+
"
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
usage_and_exit() {
|
| 46 |
+
usage
|
| 47 |
+
exit $1
|
| 48 |
+
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
# if mamba available, use it
|
| 52 |
+
if command -v mamba &>/dev/null; then
|
| 53 |
+
conda_cli=mamba
|
| 54 |
+
else
|
| 55 |
+
conda_cli=conda
|
| 56 |
+
fi
|
| 57 |
+
msg "Using '${conda_cli}' as the Conda CLI."
|
| 58 |
+
|
| 59 |
+
${conda_cli} env create -f envs/conda_env.yml -n ${ENV_NAME} ||
|
| 60 |
+
error "Failed to create the Conda environment '${ENV_NAME}'."
|
| 61 |
+
success "Conda environment '${ENV_NAME}' created successfully."
|
| 62 |
+
|
| 63 |
+
${conda_cli} run \
|
| 64 |
+
-n ${ENV_NAME} pip install flash-attn==1.0.4 --no-build-isolation
|
| 65 |
+
success "Flash attention installed successfully."
|
| 66 |
+
|
| 67 |
+
${conda_cli} run \
|
| 68 |
+
-n ${ENV_NAME} pip install 'setuptools>=65.2' wandb colorlog \
|
| 69 |
+
PyComplexHeatmap scib[kBET,rpy2]==1.0.4 ||
|
| 70 |
+
error "Failed to install the wandb, colorlog, PyComplexHeatmap or scib."
|
| 71 |
+
|
| 72 |
+
${conda_cli} run \
|
| 73 |
+
-n ${ENV_NAME} pip install git+https://github.com/bowang-lab/scGPT.git@v0.1.6 ||
|
| 74 |
+
error "Failed to install the scGPT."
|
| 75 |
+
|
| 76 |
+
${conda_cli} run \
|
| 77 |
+
-n ${ENV_NAME} pip install \
|
| 78 |
+
git+https://huggingface.co/ctheodoris/Geneformer.git@5d0082c1e188ab88997efa87891414fdc6e4f6ff ||
|
| 79 |
+
error "Failed to install the Geneformer."
|
| 80 |
+
|
| 81 |
+
${conda_cli} run \
|
| 82 |
+
-n ${ENV_NAME} pip install git+https://github.com/microsoft/zero-shot-scfoundation ||
|
| 83 |
+
error "Failed to install the sc_foundation_evals."
|
| 84 |
+
|
| 85 |
+
success "All packages installed successfully."
|
Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_Geneformer.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_HVG_and_scVI.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_evaluation_aggregated.ipynb
ADDED
|
@@ -0,0 +1,1058 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Geneformer"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "code",
|
| 12 |
+
"execution_count": null,
|
| 13 |
+
"metadata": {},
|
| 14 |
+
"outputs": [],
|
| 15 |
+
"source": [
|
| 16 |
+
"import os\n",
|
| 17 |
+
"import logging\n",
|
| 18 |
+
"import warnings\n",
|
| 19 |
+
"import sys\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
|
| 22 |
+
"warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"from sc_foundation_evals import geneformer_forward as gf\n",
|
| 25 |
+
"from sc_foundation_evals import data, cell_embeddings, model_output\n",
|
| 26 |
+
"from sc_foundation_evals.helpers.custom_logging import log\n",
|
| 27 |
+
"log.setLevel(logging.INFO)"
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "code",
|
| 32 |
+
"execution_count": null,
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"outputs": [],
|
| 35 |
+
"source": [
|
| 36 |
+
"geneformer_data = \"model path\"\n",
|
| 37 |
+
"# path to the pre-trained model, can work with the huggingface model hub\n",
|
| 38 |
+
"# i.e. ctheodoris/Geneformer\n",
|
| 39 |
+
"model_dir = os.path.join(geneformer_data)\n",
|
| 40 |
+
"# path to dictionaries in geneformer repo\n",
|
| 41 |
+
"dict_dir = \"Pretrain_data/\"\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"# batch_size depends on available GPU memory\n",
|
| 44 |
+
"batch_size = 24\n",
|
| 45 |
+
"# output_dir is the path to which the results should be saved\n",
|
| 46 |
+
"output_dir = \"zero_shot_results/\"\n",
|
| 47 |
+
"# path to where we will store the embeddings and other evaluation outputs\n",
|
| 48 |
+
"model_out = os.path.join(output_dir, \"model_outputs\")\n",
|
| 49 |
+
"# if you can use multithreading specify num_workers, -1 means use all available\n",
|
| 50 |
+
"num_workers = -1"
|
| 51 |
+
]
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"execution_count": null,
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"source": [
|
| 59 |
+
"# specify the path to anndata object\n",
|
| 60 |
+
"in_dataset_path = \"Zero_shot_batch_data/pbmc.h5ad\"\n",
|
| 61 |
+
"# dataset_name is inferred from in_dataset_path\n",
|
| 62 |
+
"dataset_name = os.path.basename(in_dataset_path).split(\".\")[0]\n",
|
| 63 |
+
"# specify the path for the output of the pre-processing\n",
|
| 64 |
+
"preprocessed_path = f\"zero_shot_preprocess/{dataset_name}/\"\n",
|
| 65 |
+
"# create the preprocessed path if it does not exist\n",
|
| 66 |
+
"os.makedirs(preprocessed_path, exist_ok=True)\n",
|
| 67 |
+
"# in which column in adata.obs are gene names stored? if they are in index, the index will be copied to a column with this name\n",
|
| 68 |
+
"gene_col = \"gene_symbols\"\n",
|
| 69 |
+
"# batch column found in adata.obs\n",
|
| 70 |
+
"batch_col = \"batch\"\n",
|
| 71 |
+
"# where are labels stored in adata.obs? \n",
|
| 72 |
+
"label_col = \"celltype\" #\"str_labels\"\n",
|
| 73 |
+
"# where the raw counts are stored?\n",
|
| 74 |
+
"layer_key = \"counts\" #\"X\" "
|
| 75 |
+
]
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"cell_type": "code",
|
| 79 |
+
"execution_count": null,
|
| 80 |
+
"metadata": {},
|
| 81 |
+
"outputs": [],
|
| 82 |
+
"source": [
|
| 83 |
+
"geneform = gf.Geneformer_instance(save_dir = output_dir, \n",
|
| 84 |
+
" saved_model_path = model_dir,\n",
|
| 85 |
+
" explicit_save_dir = True,\n",
|
| 86 |
+
" num_workers = num_workers)\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"geneform.load_pretrained_model()\n",
|
| 89 |
+
"geneform.load_vocab(dict_dir)\n",
|
| 90 |
+
"# input_data = data.InputData(adata_dataset_path = in_dataset_path)"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "markdown",
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"source": [
|
| 97 |
+
"## Create dataset"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": 5,
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"outputs": [],
|
| 105 |
+
"source": [
|
| 106 |
+
"# input_data.preprocess_data(gene_col = gene_col,\n",
|
| 107 |
+
"# model_type = \"geneformer\",\n",
|
| 108 |
+
"# save_ext = \"loom\",\n",
|
| 109 |
+
"# gene_name_id_dict = geneform.gene_name_id,\n",
|
| 110 |
+
"# preprocessed_path = preprocessed_path)\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"# geneform.tokenize_data(adata_path = os.path.join(preprocessed_path, \n",
|
| 113 |
+
"# f\"{dataset_name}.loom\"),\n",
|
| 114 |
+
"# dataset_path = preprocessed_path,\n",
|
| 115 |
+
"# cell_type_col = label_col)"
|
| 116 |
+
]
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"cell_type": "markdown",
|
| 120 |
+
"metadata": {},
|
| 121 |
+
"source": [
|
| 122 |
+
"## Load dataset"
|
| 123 |
+
]
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"cell_type": "code",
|
| 127 |
+
"execution_count": 6,
|
| 128 |
+
"metadata": {},
|
| 129 |
+
"outputs": [
|
| 130 |
+
{
|
| 131 |
+
"name": "stderr",
|
| 132 |
+
"output_type": "stream",
|
| 133 |
+
"text": [
|
| 134 |
+
"\u001b[32mINFO \u001b[0m | 2025-07-17 11:57:03 | \u001b[32mLoading data from /ibex/user/chenj0i/Geneformer/zero_shot_preprocess/pbmc/pbmc.loom\u001b[0m\n"
|
| 135 |
+
]
|
| 136 |
+
}
|
| 137 |
+
],
|
| 138 |
+
"source": [
|
| 139 |
+
"geneform.load_tokenized_dataset(os.path.join(preprocessed_path, f\"{dataset_name}.dataset\"))\n",
|
| 140 |
+
"input_data = data.InputData(adata_dataset_path = os.path.join(preprocessed_path, f\"{dataset_name}.loom\"))"
|
| 141 |
+
]
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"cell_type": "markdown",
|
| 145 |
+
"metadata": {},
|
| 146 |
+
"source": [
|
| 147 |
+
"## Embeddings extraction"
|
| 148 |
+
]
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"cell_type": "code",
|
| 152 |
+
"execution_count": 7,
|
| 153 |
+
"metadata": {},
|
| 154 |
+
"outputs": [
|
| 155 |
+
{
|
| 156 |
+
"data": {
|
| 157 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 158 |
+
"model_id": "c89a344776044c95bfef70882ebd4ff8",
|
| 159 |
+
"version_major": 2,
|
| 160 |
+
"version_minor": 0
|
| 161 |
+
},
|
| 162 |
+
"text/plain": [
|
| 163 |
+
"Geneformer (extracting embeddings): 0%| | 0/500 [00:00<?, ?it/s]"
|
| 164 |
+
]
|
| 165 |
+
},
|
| 166 |
+
"metadata": {},
|
| 167 |
+
"output_type": "display_data"
|
| 168 |
+
}
|
| 169 |
+
],
|
| 170 |
+
"source": [
|
| 171 |
+
"geneform.extract_embeddings(data = input_data,\n",
|
| 172 |
+
" batch_size = batch_size, \n",
|
| 173 |
+
" layer = -2\n",
|
| 174 |
+
" # layer = -1\n",
|
| 175 |
+
" # layer = 0\n",
|
| 176 |
+
" )"
|
| 177 |
+
]
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"cell_type": "code",
|
| 181 |
+
"execution_count": null,
|
| 182 |
+
"metadata": {},
|
| 183 |
+
"outputs": [
|
| 184 |
+
{
|
| 185 |
+
"data": {
|
| 186 |
+
"text/plain": [
|
| 187 |
+
"AnnData object with n_obs × n_vars = 11990 × 3226\n",
|
| 188 |
+
" obs: 'adata_order', 'batch', 'celltype', 'labels', 'n_counts', 'n_genes', 'n_genes_by_counts', 'obs_names', 'str_labels', 'total_counts'\n",
|
| 189 |
+
" var: 'ensembl_id', 'gene_symbols', 'has_ensembl_match', 'mean_counts', 'n_cells', 'n_cells_by_counts', 'n_counts', 'n_counts-0', 'n_counts-1', 'pct_dropout_by_counts', 'total_counts', 'var_names'\n",
|
| 190 |
+
" obsm: 'geneformer'"
|
| 191 |
+
]
|
| 192 |
+
},
|
| 193 |
+
"execution_count": 8,
|
| 194 |
+
"metadata": {},
|
| 195 |
+
"output_type": "execute_result"
|
| 196 |
+
}
|
| 197 |
+
],
|
| 198 |
+
"source": [
|
| 199 |
+
"input_data.adata"
|
| 200 |
+
]
|
| 201 |
+
},
|
| 202 |
+
{
|
| 203 |
+
"cell_type": "code",
|
| 204 |
+
"execution_count": null,
|
| 205 |
+
"metadata": {},
|
| 206 |
+
"outputs": [],
|
| 207 |
+
"source": [
|
| 208 |
+
"from typing import Dict, Optional\n",
|
| 209 |
+
"import numpy as np\n",
|
| 210 |
+
"import scanpy as sc\n",
|
| 211 |
+
"import scib\n",
|
| 212 |
+
"from anndata import AnnData\n",
|
| 213 |
+
"from sklearn.metrics import silhouette_score\n",
|
| 214 |
+
"from tqdm import tqdm\n",
|
| 215 |
+
"import pandas as pd\n",
|
| 216 |
+
"import logging\n",
|
| 217 |
+
"\n",
|
| 218 |
+
"log = logging.getLogger(__name__)\n",
|
| 219 |
+
"\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"def eval_clustering_metrics(\n",
|
| 222 |
+
" adata: AnnData,\n",
|
| 223 |
+
" batch_key: Optional[str] = \"str_batch\",\n",
|
| 224 |
+
" label_key: str = \"cell_type\",\n",
|
| 225 |
+
" embedding_key: str = \"X\", # \"X\" for raw, or embedding key in .obsm\n",
|
| 226 |
+
" resolutions: Optional[list] = None,\n",
|
| 227 |
+
" use_progress_bar: bool = True,\n",
|
| 228 |
+
" verbose: bool = False,\n",
|
| 229 |
+
" subsample_frac: Optional[float] = 0.25,\n",
|
| 230 |
+
") -> Dict[str, float]:\n",
|
| 231 |
+
" \"\"\"Evaluate biological and batch mixing metrics on an embedding or raw expression.\"\"\"\n",
|
| 232 |
+
" \n",
|
| 233 |
+
" results_dict = {}\n",
|
| 234 |
+
"\n",
|
| 235 |
+
" if subsample_frac is not None and 0 < subsample_frac < 1:\n",
|
| 236 |
+
" adata = adata.copy()\n",
|
| 237 |
+
" sc.pp.subsample(adata, fraction=subsample_frac, copy=False)\n",
|
| 238 |
+
" if verbose:\n",
|
| 239 |
+
" log.info(f\"Subsampled adata to {subsample_frac * 100:.1f}% of original cells.\")\n",
|
| 240 |
+
"\n",
|
| 241 |
+
" # Determine whether to use .X or .obsm[embedding_key]\n",
|
| 242 |
+
" if embedding_key == \"X\":\n",
|
| 243 |
+
" use_rep = \"X\"\n",
|
| 244 |
+
" adata.obsm[\"X\"] = adata.X\n",
|
| 245 |
+
" elif embedding_key in adata.obsm:\n",
|
| 246 |
+
" use_rep = embedding_key\n",
|
| 247 |
+
" else:\n",
|
| 248 |
+
" raise ValueError(f\"embedding_key '{embedding_key}' not found in adata.obsm or is not 'X'\")\n",
|
| 249 |
+
"\n",
|
| 250 |
+
" # Clear stale neighbors\n",
|
| 251 |
+
" if \"neighbors\" in adata.uns:\n",
|
| 252 |
+
" if verbose:\n",
|
| 253 |
+
" log.warning(f\"Removing stale neighbors computed from other representations.\")\n",
|
| 254 |
+
" adata.uns.pop(\"neighbors\", None)\n",
|
| 255 |
+
"\n",
|
| 256 |
+
" sc.pp.neighbors(adata, use_rep=use_rep)\n",
|
| 257 |
+
"\n",
|
| 258 |
+
" # Run Louvain across multiple resolutions\n",
|
| 259 |
+
" if resolutions is None:\n",
|
| 260 |
+
" resolutions = [2 * i / 20 for i in range(1, 21)] # Default: 20 steps from 0.1 to 2.0\n",
|
| 261 |
+
" # resolutions = [4 * i / 40 for i in range(1, 41)] # Default: 20 steps from 0.1 to 2.0\n",
|
| 262 |
+
"\n",
|
| 263 |
+
" best_nmi = -1\n",
|
| 264 |
+
" best_res = None\n",
|
| 265 |
+
" best_clustering = None\n",
|
| 266 |
+
"\n",
|
| 267 |
+
" if verbose:\n",
|
| 268 |
+
" log.info(f\"Searching for optimal clustering resolution on {use_rep}...\")\n",
|
| 269 |
+
"\n",
|
| 270 |
+
" for res in tqdm(resolutions, disable=not use_progress_bar, desc=\"Louvain clustering\"):\n",
|
| 271 |
+
" sc.tl.louvain(adata, resolution=res, key_added=\"temp_cluster\")\n",
|
| 272 |
+
" nmi = scib.metrics.nmi(adata, \"temp_cluster\", label_key)\n",
|
| 273 |
+
" if nmi > best_nmi:\n",
|
| 274 |
+
" best_nmi = nmi\n",
|
| 275 |
+
" best_res = res\n",
|
| 276 |
+
" best_clustering = adata.obs[\"temp_cluster\"].copy()\n",
|
| 277 |
+
" del adata.obs[\"temp_cluster\"]\n",
|
| 278 |
+
"\n",
|
| 279 |
+
" if verbose:\n",
|
| 280 |
+
" log.info(f\"Best resolution: {best_res:.2f} with NMI = {best_nmi:.4f}\")\n",
|
| 281 |
+
"\n",
|
| 282 |
+
" adata.obs[\"cluster\"] = best_clustering\n",
|
| 283 |
+
"\n",
|
| 284 |
+
" # Biological conservation metrics\n",
|
| 285 |
+
" results_dict[\"NMI_cluster/label\"] = scib.metrics.nmi(adata, \"cluster\", label_key, \"arithmetic\")\n",
|
| 286 |
+
" results_dict[\"ARI_cluster/label\"] = scib.metrics.ari(adata, \"cluster\", label_key)\n",
|
| 287 |
+
" results_dict[\"ASW_label\"] = scib.metrics.silhouette(adata, label_key, use_rep, \"euclidean\")\n",
|
| 288 |
+
"\n",
|
| 289 |
+
" # Batch effect metrics (if batch_key valid)\n",
|
| 290 |
+
" if batch_key is not None and batch_key in adata.obs and adata.obs[batch_key].nunique() > 1:\n",
|
| 291 |
+
" adata.obs[label_key] = adata.obs[label_key].astype(\"category\")\n",
|
| 292 |
+
" results_dict[\"graph_conn\"] = scib.metrics.graph_connectivity(adata, label_key)\n",
|
| 293 |
+
" results_dict[\"ASW_batch\"] = scib.metrics.silhouette(adata, batch_key, use_rep, \"euclidean\")\n",
|
| 294 |
+
" results_dict[\"ASW_label/batch\"] = scib.metrics.silhouette_batch(\n",
|
| 295 |
+
" adata, batch_key, label_key, embed=use_rep, metric=\"euclidean\", return_all=False\n",
|
| 296 |
+
" )\n",
|
| 297 |
+
" results_dict[\"PCR_batch\"] = scib.metrics.pcr(\n",
|
| 298 |
+
" adata, covariate=batch_key, embed=use_rep, recompute_pca=True, n_comps=50, verbose=False\n",
|
| 299 |
+
" )\n",
|
| 300 |
+
" results_dict[\"Average_Batch_Score\"] = (\n",
|
| 301 |
+
" results_dict[\"ASW_batch\"] + results_dict[\"PCR_batch\"]\n",
|
| 302 |
+
" ) / 2\n",
|
| 303 |
+
" else:\n",
|
| 304 |
+
" if verbose:\n",
|
| 305 |
+
" log.info(\"Skipping batch metrics — only one batch present or invalid batch_key.\")\n",
|
| 306 |
+
" \n",
|
| 307 |
+
" results_dict[\"avg_bio\"] = np.mean([\n",
|
| 308 |
+
" results_dict[\"NMI_cluster/label\"],\n",
|
| 309 |
+
" results_dict[\"ARI_cluster/label\"],\n",
|
| 310 |
+
" results_dict[\"ASW_label\"]\n",
|
| 311 |
+
" ])\n",
|
| 312 |
+
"\n",
|
| 313 |
+
" # Filter NaNs\n",
|
| 314 |
+
" results_dict = {k: v for k, v in results_dict.items() if not np.isnan(v)}\n",
|
| 315 |
+
"\n",
|
| 316 |
+
" return results_dict\n"
|
| 317 |
+
]
|
| 318 |
+
},
|
| 319 |
+
{
|
| 320 |
+
"cell_type": "markdown",
|
| 321 |
+
"metadata": {},
|
| 322 |
+
"source": [
|
| 323 |
+
"# Embeddings metrics"
|
| 324 |
+
]
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"cell_type": "code",
|
| 328 |
+
"execution_count": null,
|
| 329 |
+
"metadata": {},
|
| 330 |
+
"outputs": [
|
| 331 |
+
{
|
| 332 |
+
"name": "stderr",
|
| 333 |
+
"output_type": "stream",
|
| 334 |
+
"text": [
|
| 335 |
+
"Louvain clustering: 100%|██████████| 20/20 [00:02<00:00, 7.68it/s]\n"
|
| 336 |
+
]
|
| 337 |
+
},
|
| 338 |
+
{
|
| 339 |
+
"name": "stdout",
|
| 340 |
+
"output_type": "stream",
|
| 341 |
+
"text": [
|
| 342 |
+
"mean silhouette per group: silhouette_score\n",
|
| 343 |
+
"group \n",
|
| 344 |
+
"B cells 0.990590\n",
|
| 345 |
+
"CD14+ Monocytes 0.979706\n",
|
| 346 |
+
"CD4 T cells 0.987594\n",
|
| 347 |
+
"CD8 T cells 0.991305\n",
|
| 348 |
+
"Dendritic Cells 0.958009\n",
|
| 349 |
+
"FCGR3A+ Monocytes 0.990665\n",
|
| 350 |
+
"Megakaryocytes 0.857295\n",
|
| 351 |
+
"NK cells 0.977292\n",
|
| 352 |
+
"Other 0.933587\n"
|
| 353 |
+
]
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"data": {
|
| 357 |
+
"text/plain": [
|
| 358 |
+
"{'NMI_cluster/label': 0.6061048617613637,\n",
|
| 359 |
+
" 'ARI_cluster/label': 0.503784927975462,\n",
|
| 360 |
+
" 'ASW_label': 0.510432125069201,\n",
|
| 361 |
+
" 'graph_conn': 0.8852579724762832,\n",
|
| 362 |
+
" 'ASW_batch': 0.5012279110960662,\n",
|
| 363 |
+
" 'ASW_label/batch': 0.9628935503212096,\n",
|
| 364 |
+
" 'PCR_batch': 0.0007131078007747846,\n",
|
| 365 |
+
" 'Average_Batch_Score': 0.25097050944842053,\n",
|
| 366 |
+
" 'avg_bio': 0.5401073049353422}"
|
| 367 |
+
]
|
| 368 |
+
},
|
| 369 |
+
"execution_count": 10,
|
| 370 |
+
"metadata": {},
|
| 371 |
+
"output_type": "execute_result"
|
| 372 |
+
}
|
| 373 |
+
],
|
| 374 |
+
"source": [
|
| 375 |
+
"results_dict = eval_clustering_metrics(adata=input_data.adata, \n",
|
| 376 |
+
" batch_key=\"batch\",\n",
|
| 377 |
+
" label_key=\"celltype\",\n",
|
| 378 |
+
" embedding_key=\"geneformer\", # or \"X_scGPT\", etc.\n",
|
| 379 |
+
" verbose=True)\n",
|
| 380 |
+
"results_dict"
|
| 381 |
+
]
|
| 382 |
+
},
|
| 383 |
+
{
|
| 384 |
+
"cell_type": "code",
|
| 385 |
+
"execution_count": null,
|
| 386 |
+
"metadata": {},
|
| 387 |
+
"outputs": [
|
| 388 |
+
{
|
| 389 |
+
"data": {
|
| 390 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 391 |
+
"model_id": "12c31089634046939fc59c2ef27adb59",
|
| 392 |
+
"version_major": 2,
|
| 393 |
+
"version_minor": 0
|
| 394 |
+
},
|
| 395 |
+
"text/plain": [
|
| 396 |
+
" 0%| | 0/2 [00:00<?, ?it/s]"
|
| 397 |
+
]
|
| 398 |
+
},
|
| 399 |
+
"metadata": {},
|
| 400 |
+
"output_type": "display_data"
|
| 401 |
+
},
|
| 402 |
+
{
|
| 403 |
+
"name": "stdout",
|
| 404 |
+
"output_type": "stream",
|
| 405 |
+
"text": [
|
| 406 |
+
" Rank-Geneformer\n",
|
| 407 |
+
"0 0.805556\n"
|
| 408 |
+
]
|
| 409 |
+
}
|
| 410 |
+
],
|
| 411 |
+
"source": [
|
| 412 |
+
"from scGraph import scGraph\n",
|
| 413 |
+
"\n",
|
| 414 |
+
"scg = scGraph(adata=input_data.adata, batch_key=\"batch\", label_key=\"celltype\", \n",
|
| 415 |
+
" trim_rate=0.05, thres_batch=1, thres_celltype=1)\n",
|
| 416 |
+
"scg.preprocess()\n",
|
| 417 |
+
"scg.compute()\n",
|
| 418 |
+
"results = scg.evaluate()\n",
|
| 419 |
+
"print(results)"
|
| 420 |
+
]
|
| 421 |
+
},
|
| 422 |
+
{
|
| 423 |
+
"cell_type": "markdown",
|
| 424 |
+
"metadata": {},
|
| 425 |
+
"source": [
|
| 426 |
+
"# OOD Dataset raw metrics"
|
| 427 |
+
]
|
| 428 |
+
},
|
| 429 |
+
{
|
| 430 |
+
"cell_type": "code",
|
| 431 |
+
"execution_count": null,
|
| 432 |
+
"metadata": {},
|
| 433 |
+
"outputs": [],
|
| 434 |
+
"source": [
|
| 435 |
+
"# import scanpy as sc \n",
|
| 436 |
+
"\n",
|
| 437 |
+
"# cdata = sc.read_h5ad(\"zero_shot_data/ood_celltype_data1_expand.h5ad\")\n",
|
| 438 |
+
"# adata = cdata.copy()\n",
|
| 439 |
+
"# sc.pp.subsample(adata, fraction=0.05, copy=False)"
|
| 440 |
+
]
|
| 441 |
+
},
|
| 442 |
+
{
|
| 443 |
+
"cell_type": "code",
|
| 444 |
+
"execution_count": 19,
|
| 445 |
+
"metadata": {},
|
| 446 |
+
"outputs": [],
|
| 447 |
+
"source": [
|
| 448 |
+
"# use_rep = \"X\"\n",
|
| 449 |
+
"# adata.obsm[\"X\"] = adata.X\n",
|
| 450 |
+
"# adata.uns.pop(\"neighbors\", None)\n",
|
| 451 |
+
"\n",
|
| 452 |
+
"# sc.pp.neighbors(adata, use_rep=use_rep)\n",
|
| 453 |
+
"# resolutions = [2 * i / 20 for i in range(1, 21)] # Default: 20 steps from 0.1 to 2.0\n",
|
| 454 |
+
"# best_nmi = -1\n",
|
| 455 |
+
"# best_res = None\n",
|
| 456 |
+
"# best_clustering = None"
|
| 457 |
+
]
|
| 458 |
+
},
|
| 459 |
+
{
|
| 460 |
+
"cell_type": "code",
|
| 461 |
+
"execution_count": 20,
|
| 462 |
+
"metadata": {},
|
| 463 |
+
"outputs": [
|
| 464 |
+
{
|
| 465 |
+
"name": "stderr",
|
| 466 |
+
"output_type": "stream",
|
| 467 |
+
"text": [
|
| 468 |
+
"Louvain clustering: 100%|██████████| 20/20 [00:22<00:00, 1.14s/it]\n"
|
| 469 |
+
]
|
| 470 |
+
},
|
| 471 |
+
{
|
| 472 |
+
"name": "stdout",
|
| 473 |
+
"output_type": "stream",
|
| 474 |
+
"text": [
|
| 475 |
+
"mean silhouette per group: silhouette_score\n",
|
| 476 |
+
"group \n",
|
| 477 |
+
"CL:0000077 0.951371\n",
|
| 478 |
+
"CL:0000091 0.905183\n",
|
| 479 |
+
"CL:0000099 0.856871\n",
|
| 480 |
+
"CL:0000164 0.913159\n",
|
| 481 |
+
"CL:0000189 0.934462\n",
|
| 482 |
+
"CL:0000312 0.933951\n",
|
| 483 |
+
"CL:0000453 0.966310\n",
|
| 484 |
+
"CL:0000575 0.779139\n",
|
| 485 |
+
"CL:0000750 0.991985\n",
|
| 486 |
+
"CL:0000767 0.977141\n",
|
| 487 |
+
"CL:0000771 0.893556\n",
|
| 488 |
+
"CL:0000776 0.932994\n",
|
| 489 |
+
"CL:0000810 0.913306\n",
|
| 490 |
+
"CL:0000817 0.931130\n",
|
| 491 |
+
"CL:0000837 0.967683\n",
|
| 492 |
+
"CL:0000843 0.948814\n",
|
| 493 |
+
"CL:0000861 0.841148\n",
|
| 494 |
+
"CL:0000915 0.945803\n",
|
| 495 |
+
"CL:0000957 0.970545\n",
|
| 496 |
+
"CL:0001029 0.950351\n",
|
| 497 |
+
"CL:0001057 0.946863\n",
|
| 498 |
+
"CL:0001074 0.936960\n",
|
| 499 |
+
"CL:0002028 0.935891\n",
|
| 500 |
+
"CL:0002045 0.950375\n",
|
| 501 |
+
"CL:0002064 0.926107\n",
|
| 502 |
+
"CL:0002075 0.759782\n",
|
| 503 |
+
"CL:0002201 0.973459\n",
|
| 504 |
+
"CL:0002393 0.966944\n",
|
| 505 |
+
"CL:0002518 0.911847\n",
|
| 506 |
+
"CL:0005012 0.961174\n",
|
| 507 |
+
"CL:0009009 0.957441\n",
|
| 508 |
+
"CL:0009010 0.933421\n",
|
| 509 |
+
"CL:0009017 0.952055\n",
|
| 510 |
+
"CL:0009042 0.943946\n",
|
| 511 |
+
"CL:0009095 0.863287\n",
|
| 512 |
+
"CL:0011024 0.925223\n",
|
| 513 |
+
"CL:0017000 0.943662\n",
|
| 514 |
+
"CL:1000398 0.954797\n",
|
| 515 |
+
"CL:1000487 0.973023\n",
|
| 516 |
+
"CL:1000488 0.950142\n",
|
| 517 |
+
"CL:1001432 0.984860\n"
|
| 518 |
+
]
|
| 519 |
+
},
|
| 520 |
+
{
|
| 521 |
+
"data": {
|
| 522 |
+
"text/plain": [
|
| 523 |
+
"{'NMI_cluster/label': 0.7833172618112929,\n",
|
| 524 |
+
" 'ARI_cluster/label': 0.5728303202672791,\n",
|
| 525 |
+
" 'ASW_label': 0.4911566338564166,\n",
|
| 526 |
+
" 'graph_conn': 0.7769019941103583,\n",
|
| 527 |
+
" 'ASW_batch': 0.5006964505924973,\n",
|
| 528 |
+
" 'ASW_label/batch': 0.9306380360099057,\n",
|
| 529 |
+
" 'PCR_batch': 0.757978241899424,\n",
|
| 530 |
+
" 'Average_Batch_Score': 0.6293373462459606,\n",
|
| 531 |
+
" 'avg_bio': 0.6157680719783295}"
|
| 532 |
+
]
|
| 533 |
+
},
|
| 534 |
+
"execution_count": 20,
|
| 535 |
+
"metadata": {},
|
| 536 |
+
"output_type": "execute_result"
|
| 537 |
+
}
|
| 538 |
+
],
|
| 539 |
+
"source": [
|
| 540 |
+
"# label_key = \"celltype\"\n",
|
| 541 |
+
"# results_dict = {}\n",
|
| 542 |
+
"# for res in tqdm(resolutions, disable=not True, desc=\"Louvain clustering\"):\n",
|
| 543 |
+
"# sc.tl.louvain(adata, resolution=res, key_added=\"temp_cluster\")\n",
|
| 544 |
+
"# nmi = scib.metrics.nmi(adata, \"temp_cluster\", label_key)\n",
|
| 545 |
+
"# if nmi > best_nmi:\n",
|
| 546 |
+
"# best_nmi = nmi\n",
|
| 547 |
+
"# best_res = res\n",
|
| 548 |
+
"# best_clustering = adata.obs[\"temp_cluster\"].copy()\n",
|
| 549 |
+
"# del adata.obs[\"temp_cluster\"]\n",
|
| 550 |
+
"\n",
|
| 551 |
+
"# adata.obs[\"cluster\"] = best_clustering\n",
|
| 552 |
+
"# # Biological conservation metrics\n",
|
| 553 |
+
"# results_dict[\"NMI_cluster/label\"] = scib.metrics.nmi(adata, \"cluster\", label_key, \"arithmetic\")\n",
|
| 554 |
+
"# results_dict[\"ARI_cluster/label\"] = scib.metrics.ari(adata, \"cluster\", label_key)\n",
|
| 555 |
+
"# results_dict[\"ASW_label\"] = scib.metrics.silhouette(adata, label_key, use_rep, \"euclidean\")\n",
|
| 556 |
+
"\n",
|
| 557 |
+
"# # Batch effect metrics (if batch_key valid)\n",
|
| 558 |
+
"# batch_key = \"batch\"\n",
|
| 559 |
+
"# if batch_key is not None and batch_key in adata.obs and adata.obs[batch_key].nunique() > 1:\n",
|
| 560 |
+
"# adata.obs[label_key] = adata.obs[label_key].astype(\"category\")\n",
|
| 561 |
+
"# results_dict[\"graph_conn\"] = scib.metrics.graph_connectivity(adata, label_key)\n",
|
| 562 |
+
"# results_dict[\"ASW_batch\"] = (1 - scib.metrics.silhouette(adata, batch_key, use_rep, \"euclidean\"))\n",
|
| 563 |
+
"# results_dict[\"ASW_label/batch\"] = scib.metrics.silhouette_batch(\n",
|
| 564 |
+
"# adata, batch_key, label_key, embed=use_rep, metric=\"euclidean\", return_all=False\n",
|
| 565 |
+
"# )\n",
|
| 566 |
+
"# results_dict[\"PCR_batch\"] = scib.metrics.pcr(\n",
|
| 567 |
+
"# adata, covariate=batch_key, embed=use_rep, recompute_pca=True, n_comps=50, verbose=False\n",
|
| 568 |
+
"# )\n",
|
| 569 |
+
"# results_dict[\"Average_Batch_Score\"] = (\n",
|
| 570 |
+
"# results_dict[\"ASW_batch\"] + results_dict[\"PCR_batch\"]\n",
|
| 571 |
+
"# ) / 2\n",
|
| 572 |
+
"# else:\n",
|
| 573 |
+
"# if verbose:\n",
|
| 574 |
+
"# log.info(\"Skipping batch metrics — only one batch present or invalid batch_key.\")\n",
|
| 575 |
+
"\n",
|
| 576 |
+
"# results_dict[\"avg_bio\"] = np.mean([\n",
|
| 577 |
+
"# results_dict[\"NMI_cluster/label\"],\n",
|
| 578 |
+
"# results_dict[\"ARI_cluster/label\"],\n",
|
| 579 |
+
"# results_dict[\"ASW_label\"]\n",
|
| 580 |
+
"# ])\n",
|
| 581 |
+
"\n",
|
| 582 |
+
"# # Filter NaNs\n",
|
| 583 |
+
"# results_dict = {k: v for k, v in results_dict.items() if not np.isnan(v)}\n",
|
| 584 |
+
"\n",
|
| 585 |
+
"# results_dict"
|
| 586 |
+
]
|
| 587 |
+
},
|
| 588 |
+
{
|
| 589 |
+
"cell_type": "markdown",
|
| 590 |
+
"metadata": {},
|
| 591 |
+
"source": [
|
| 592 |
+
"# Raw data metrics"
|
| 593 |
+
]
|
| 594 |
+
},
|
| 595 |
+
{
|
| 596 |
+
"cell_type": "code",
|
| 597 |
+
"execution_count": null,
|
| 598 |
+
"metadata": {},
|
| 599 |
+
"outputs": [
|
| 600 |
+
{
|
| 601 |
+
"name": "stderr",
|
| 602 |
+
"output_type": "stream",
|
| 603 |
+
"text": [
|
| 604 |
+
"Louvain clustering: 100%|██████████| 20/20 [00:02<00:00, 6.97it/s]\n"
|
| 605 |
+
]
|
| 606 |
+
},
|
| 607 |
+
{
|
| 608 |
+
"name": "stdout",
|
| 609 |
+
"output_type": "stream",
|
| 610 |
+
"text": [
|
| 611 |
+
"mean silhouette per group: silhouette_score\n",
|
| 612 |
+
"group \n",
|
| 613 |
+
"B cells 0.971033\n",
|
| 614 |
+
"CD14+ Monocytes 0.942456\n",
|
| 615 |
+
"CD4 T cells 0.988742\n",
|
| 616 |
+
"CD8 T cells 0.987412\n",
|
| 617 |
+
"Dendritic Cells 0.938792\n",
|
| 618 |
+
"FCGR3A+ Monocytes 0.950513\n",
|
| 619 |
+
"Megakaryocytes 0.752894\n",
|
| 620 |
+
"NK cells 0.890206\n",
|
| 621 |
+
"Other 0.914109\n"
|
| 622 |
+
]
|
| 623 |
+
},
|
| 624 |
+
{
|
| 625 |
+
"data": {
|
| 626 |
+
"text/plain": [
|
| 627 |
+
"{'NMI_cluster/label': 0.6505152890434263,\n",
|
| 628 |
+
" 'ARI_cluster/label': 0.5759899223104351,\n",
|
| 629 |
+
" 'ASW_label': 0.5245759263634682,\n",
|
| 630 |
+
" 'graph_conn': 0.8891452955038966,\n",
|
| 631 |
+
" 'ASW_batch': 0.4964794989209622,\n",
|
| 632 |
+
" 'ASW_label/batch': 0.9262396008669715,\n",
|
| 633 |
+
" 'PCR_batch': 0.0007824623021499673,\n",
|
| 634 |
+
" 'Average_Batch_Score': 0.24863098061155608,\n",
|
| 635 |
+
" 'avg_bio': 0.5836937125724432}"
|
| 636 |
+
]
|
| 637 |
+
},
|
| 638 |
+
"execution_count": 16,
|
| 639 |
+
"metadata": {},
|
| 640 |
+
"output_type": "execute_result"
|
| 641 |
+
}
|
| 642 |
+
],
|
| 643 |
+
"source": [
|
| 644 |
+
"results_dict_raw = eval_clustering_metrics(adata=input_data.adata, \n",
|
| 645 |
+
" batch_key=\"batch\",\n",
|
| 646 |
+
" label_key=\"celltype\",\n",
|
| 647 |
+
" embedding_key=\"X\", # or \"X_scGPT\", etc.\n",
|
| 648 |
+
" verbose=True)\n",
|
| 649 |
+
"results_dict_raw"
|
| 650 |
+
]
|
| 651 |
+
},
|
| 652 |
+
{
|
| 653 |
+
"cell_type": "code",
|
| 654 |
+
"execution_count": null,
|
| 655 |
+
"metadata": {},
|
| 656 |
+
"outputs": [],
|
| 657 |
+
"source": [
|
| 658 |
+
"from scGraph import scGraph\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"scg = scGraph(adata=input_data.adata, batch_key=\"batch\", label_key=\"celltype\", \n",
|
| 661 |
+
" trim_rate=0.05, thres_batch=1, thres_celltype=1, embedding_key=\"X\")\n",
|
| 662 |
+
"scg.preprocess()\n",
|
| 663 |
+
"scg.compute()\n",
|
| 664 |
+
"results = scg.evaluate()\n",
|
| 665 |
+
"print(results)"
|
| 666 |
+
]
|
| 667 |
+
},
|
| 668 |
+
{
|
| 669 |
+
"cell_type": "markdown",
|
| 670 |
+
"metadata": {},
|
| 671 |
+
"source": [
|
| 672 |
+
"# HVG & scVI"
|
| 673 |
+
]
|
| 674 |
+
},
|
| 675 |
+
{
|
| 676 |
+
"cell_type": "markdown",
|
| 677 |
+
"metadata": {},
|
| 678 |
+
"source": [
|
| 679 |
+
"## HVG"
|
| 680 |
+
]
|
| 681 |
+
},
|
| 682 |
+
{
|
| 683 |
+
"cell_type": "code",
|
| 684 |
+
"execution_count": null,
|
| 685 |
+
"metadata": {},
|
| 686 |
+
"outputs": [],
|
| 687 |
+
"source": [
|
| 688 |
+
"import os\n",
|
| 689 |
+
"import logging\n",
|
| 690 |
+
"\n",
|
| 691 |
+
"import numpy as np\n",
|
| 692 |
+
"import pandas as pd\n",
|
| 693 |
+
"import scanpy as sc\n",
|
| 694 |
+
"from scipy import sparse\n",
|
| 695 |
+
"import scvi\n",
|
| 696 |
+
"\n",
|
| 697 |
+
"import sys\n",
|
| 698 |
+
"sys.path.append(\"zero_shot_batch_effect\")\n",
|
| 699 |
+
"from sc_foundation_evals import utils\n",
|
| 700 |
+
"from sc_foundation_evals.helpers.custom_logging import log\n",
|
| 701 |
+
"\n",
|
| 702 |
+
"log.setLevel(logging.INFO)\n",
|
| 703 |
+
"\n",
|
| 704 |
+
"import warnings\n",
|
| 705 |
+
"os.environ[\"KMP_WARNINGS\"] = \"off\"\n",
|
| 706 |
+
"warnings.filterwarnings(\"ignore\")"
|
| 707 |
+
]
|
| 708 |
+
},
|
| 709 |
+
{
|
| 710 |
+
"cell_type": "code",
|
| 711 |
+
"execution_count": null,
|
| 712 |
+
"metadata": {},
|
| 713 |
+
"outputs": [
|
| 714 |
+
{
|
| 715 |
+
"data": {
|
| 716 |
+
"text/plain": [
|
| 717 |
+
"AnnData object with n_obs × n_vars = 11990 × 3346\n",
|
| 718 |
+
" obs: 'n_counts', 'batch', 'labels', 'str_labels', 'celltype'\n",
|
| 719 |
+
" var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts'\n",
|
| 720 |
+
" uns: 'cell_types'\n",
|
| 721 |
+
" obsm: 'design', 'normalized_qc', 'qc_pc', 'raw_qc'"
|
| 722 |
+
]
|
| 723 |
+
},
|
| 724 |
+
"execution_count": 18,
|
| 725 |
+
"metadata": {},
|
| 726 |
+
"output_type": "execute_result"
|
| 727 |
+
}
|
| 728 |
+
],
|
| 729 |
+
"source": [
|
| 730 |
+
"# specify the path to anndata object\n",
|
| 731 |
+
"adata_path = in_dataset_path\n",
|
| 732 |
+
"# dataset_name is inferred from in_dataset_path\n",
|
| 733 |
+
"dataset_name = os.path.basename(adata_path).split(\".\")[0]\n",
|
| 734 |
+
"\n",
|
| 735 |
+
"# batch column found in adata.obs\n",
|
| 736 |
+
"batch_col = \"batch\"\n",
|
| 737 |
+
"# where are labels stored in adata.obs? \n",
|
| 738 |
+
"label_col = \"celltype\"\n",
|
| 739 |
+
"# where the raw counts are stored?\n",
|
| 740 |
+
"layer_key = \"counts\"\n",
|
| 741 |
+
"\n",
|
| 742 |
+
"adata = sc.read(adata_path)\n",
|
| 743 |
+
"adata"
|
| 744 |
+
]
|
| 745 |
+
},
|
| 746 |
+
{
|
| 747 |
+
"cell_type": "code",
|
| 748 |
+
"execution_count": null,
|
| 749 |
+
"metadata": {},
|
| 750 |
+
"outputs": [],
|
| 751 |
+
"source": [
|
| 752 |
+
"if layer_key == \"X\":\n",
|
| 753 |
+
" adata.layers[\"counts\"] = adata.X\n",
|
| 754 |
+
"elif layer_key != \"counts\":\n",
|
| 755 |
+
" adata.layers[\"counts\"] = adata.layers[layer_key]"
|
| 756 |
+
]
|
| 757 |
+
},
|
| 758 |
+
{
|
| 759 |
+
"cell_type": "code",
|
| 760 |
+
"execution_count": null,
|
| 761 |
+
"metadata": {},
|
| 762 |
+
"outputs": [],
|
| 763 |
+
"source": [
|
| 764 |
+
"sc.pp.filter_cells(adata, min_genes=10)\n",
|
| 765 |
+
"sc.pp.filter_genes(adata, min_cells=10)\n",
|
| 766 |
+
"sc.pp.normalize_total(adata, target_sum=1e4)\n",
|
| 767 |
+
"sc.pp.log1p(adata)"
|
| 768 |
+
]
|
| 769 |
+
},
|
| 770 |
+
{
|
| 771 |
+
"cell_type": "code",
|
| 772 |
+
"execution_count": null,
|
| 773 |
+
"metadata": {},
|
| 774 |
+
"outputs": [],
|
| 775 |
+
"source": [
|
| 776 |
+
"sc.pp.highly_variable_genes(adata, flavor='seurat', subset=False, n_top_genes=2000)\n",
|
| 777 |
+
"\n",
|
| 778 |
+
"# hvg_mask = adata.var[\"highly_variable\"].values\n",
|
| 779 |
+
"\n",
|
| 780 |
+
"adata.obsm[\"X_genes\"] = adata.X[:, adata.var.highly_variable.values]\n",
|
| 781 |
+
"\n",
|
| 782 |
+
"# check if adata.obsm[\"X_genes\"] is sparse and if so, convert to dense\n",
|
| 783 |
+
"if sparse.issparse(adata.obsm[\"X_genes\"]):\n",
|
| 784 |
+
" adata.obsm[\"X_genes\"] = np.asarray(adata.obsm[\"X_genes\"].todense())"
|
| 785 |
+
]
|
| 786 |
+
},
|
| 787 |
+
{
|
| 788 |
+
"cell_type": "code",
|
| 789 |
+
"execution_count": null,
|
| 790 |
+
"metadata": {},
|
| 791 |
+
"outputs": [
|
| 792 |
+
{
|
| 793 |
+
"name": "stderr",
|
| 794 |
+
"output_type": "stream",
|
| 795 |
+
"text": [
|
| 796 |
+
"\u001b[32mINFO \u001b[0m | 2025-06-22 14:32:11 | \u001b[32mSubsampled adata to 25.0% of original cells.\u001b[0m\n",
|
| 797 |
+
"\u001b[32mINFO \u001b[0m | 2025-06-22 14:32:12 | \u001b[32mSearching for optimal clustering resolution on X_genes...\u001b[0m\n",
|
| 798 |
+
"Louvain clustering: 100%|██████████| 20/20 [00:02<00:00, 8.92it/s]\n",
|
| 799 |
+
"\u001b[32mINFO \u001b[0m | 2025-06-22 14:32:14 | \u001b[32mBest resolution: 0.70 with NMI = 0.6944\u001b[0m\n"
|
| 800 |
+
]
|
| 801 |
+
},
|
| 802 |
+
{
|
| 803 |
+
"name": "stdout",
|
| 804 |
+
"output_type": "stream",
|
| 805 |
+
"text": [
|
| 806 |
+
"mean silhouette per group: silhouette_score\n",
|
| 807 |
+
"group \n",
|
| 808 |
+
"B cells 0.990475\n",
|
| 809 |
+
"CD14+ Monocytes 0.994091\n",
|
| 810 |
+
"CD4 T cells 0.994429\n",
|
| 811 |
+
"CD8 T cells 0.996067\n",
|
| 812 |
+
"Dendritic Cells 0.990181\n",
|
| 813 |
+
"FCGR3A+ Monocytes 0.997131\n",
|
| 814 |
+
"Megakaryocytes 0.973109\n",
|
| 815 |
+
"NK cells 0.997118\n",
|
| 816 |
+
"Other 0.982645\n"
|
| 817 |
+
]
|
| 818 |
+
},
|
| 819 |
+
{
|
| 820 |
+
"data": {
|
| 821 |
+
"text/plain": [
|
| 822 |
+
"{'NMI_cluster/label': 0.6944194464119003,\n",
|
| 823 |
+
" 'ARI_cluster/label': 0.6730602977338459,\n",
|
| 824 |
+
" 'ASW_label': 0.513224795460701,\n",
|
| 825 |
+
" 'graph_conn': 0.8757625892165339,\n",
|
| 826 |
+
" 'ASW_batch': 0.4997675784834428,\n",
|
| 827 |
+
" 'ASW_label/batch': 0.9905828886755944,\n",
|
| 828 |
+
" 'PCR_batch': 0.0008402505807411988,\n",
|
| 829 |
+
" 'Average_Batch_Score': 0.250303914532092,\n",
|
| 830 |
+
" 'avg_bio': 0.626901513202149}"
|
| 831 |
+
]
|
| 832 |
+
},
|
| 833 |
+
"execution_count": 22,
|
| 834 |
+
"metadata": {},
|
| 835 |
+
"output_type": "execute_result"
|
| 836 |
+
}
|
| 837 |
+
],
|
| 838 |
+
"source": [
|
| 839 |
+
"results_dict_hvg = eval_clustering_metrics(adata=adata, \n",
|
| 840 |
+
" batch_key=batch_col,\n",
|
| 841 |
+
" label_key=label_col,\n",
|
| 842 |
+
" embedding_key=\"X_genes\", # or \"X_scGPT\", etc.\n",
|
| 843 |
+
" verbose=True)\n",
|
| 844 |
+
"results_dict_hvg"
|
| 845 |
+
]
|
| 846 |
+
},
|
| 847 |
+
{
|
| 848 |
+
"cell_type": "code",
|
| 849 |
+
"execution_count": null,
|
| 850 |
+
"metadata": {},
|
| 851 |
+
"outputs": [],
|
| 852 |
+
"source": [
|
| 853 |
+
"from scGraph import scGraph\n",
|
| 854 |
+
"\n",
|
| 855 |
+
"scg = scGraph(adata=input_data.adata, batch_key=\"batch\", label_key=\"celltype\", \n",
|
| 856 |
+
" trim_rate=0.05, thres_batch=1, thres_celltype=1, embedding_key=\"X_genes\")\n",
|
| 857 |
+
"scg.preprocess()\n",
|
| 858 |
+
"scg.compute()\n",
|
| 859 |
+
"results = scg.evaluate()\n",
|
| 860 |
+
"print(results)"
|
| 861 |
+
]
|
| 862 |
+
},
|
| 863 |
+
{
|
| 864 |
+
"cell_type": "markdown",
|
| 865 |
+
"metadata": {},
|
| 866 |
+
"source": [
|
| 867 |
+
"## scVI"
|
| 868 |
+
]
|
| 869 |
+
},
|
| 870 |
+
{
|
| 871 |
+
"cell_type": "code",
|
| 872 |
+
"execution_count": null,
|
| 873 |
+
"metadata": {},
|
| 874 |
+
"outputs": [],
|
| 875 |
+
"source": [
|
| 876 |
+
"if \"counts\" not in adata.layers.keys():\n",
|
| 877 |
+
" adata.layers[\"counts\"] = adata.X.copy()"
|
| 878 |
+
]
|
| 879 |
+
},
|
| 880 |
+
{
|
| 881 |
+
"cell_type": "code",
|
| 882 |
+
"execution_count": null,
|
| 883 |
+
"metadata": {},
|
| 884 |
+
"outputs": [
|
| 885 |
+
{
|
| 886 |
+
"data": {
|
| 887 |
+
"text/plain": [
|
| 888 |
+
"AnnData object with n_obs × n_vars = 11990 × 3345\n",
|
| 889 |
+
" obs: 'n_counts', 'batch', 'labels', 'str_labels', 'celltype', 'n_genes'\n",
|
| 890 |
+
" var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'\n",
|
| 891 |
+
" uns: 'cell_types', 'log1p', 'hvg'\n",
|
| 892 |
+
" obsm: 'design', 'normalized_qc', 'qc_pc', 'raw_qc', 'X_genes'\n",
|
| 893 |
+
" layers: 'counts'"
|
| 894 |
+
]
|
| 895 |
+
},
|
| 896 |
+
"execution_count": 24,
|
| 897 |
+
"metadata": {},
|
| 898 |
+
"output_type": "execute_result"
|
| 899 |
+
}
|
| 900 |
+
],
|
| 901 |
+
"source": [
|
| 902 |
+
"adata"
|
| 903 |
+
]
|
| 904 |
+
},
|
| 905 |
+
{
|
| 906 |
+
"cell_type": "code",
|
| 907 |
+
"execution_count": null,
|
| 908 |
+
"metadata": {},
|
| 909 |
+
"outputs": [
|
| 910 |
+
{
|
| 911 |
+
"name": "stderr",
|
| 912 |
+
"output_type": "stream",
|
| 913 |
+
"text": [
|
| 914 |
+
"GPU available: True (cuda), used: True\n",
|
| 915 |
+
"TPU available: False, using: 0 TPU cores\n",
|
| 916 |
+
"HPU available: False, using: 0 HPUs\n",
|
| 917 |
+
"You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
|
| 918 |
+
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
|
| 919 |
+
"SLURM auto-requeueing enabled. Setting signal handlers.\n"
|
| 920 |
+
]
|
| 921 |
+
},
|
| 922 |
+
{
|
| 923 |
+
"data": {
|
| 924 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 925 |
+
"model_id": "f654545481b64af2b59385925c0f992a",
|
| 926 |
+
"version_major": 2,
|
| 927 |
+
"version_minor": 0
|
| 928 |
+
},
|
| 929 |
+
"text/plain": [
|
| 930 |
+
"Training: 0%| | 0/400 [00:00<?, ?it/s]"
|
| 931 |
+
]
|
| 932 |
+
},
|
| 933 |
+
"metadata": {},
|
| 934 |
+
"output_type": "display_data"
|
| 935 |
+
},
|
| 936 |
+
{
|
| 937 |
+
"name": "stderr",
|
| 938 |
+
"output_type": "stream",
|
| 939 |
+
"text": [
|
| 940 |
+
"`Trainer.fit` stopped: `max_epochs=400` reached.\n"
|
| 941 |
+
]
|
| 942 |
+
}
|
| 943 |
+
],
|
| 944 |
+
"source": [
|
| 945 |
+
"scvi.model.SCVI.setup_anndata(adata, layer=\"counts\", batch_key=batch_col)\n",
|
| 946 |
+
"model = scvi.model.SCVI(adata, n_layers=2, n_latent=30, gene_likelihood=\"nb\")\n",
|
| 947 |
+
"model.train()\n",
|
| 948 |
+
"adata.obsm[\"X_scVI\"] = model.get_latent_representation()"
|
| 949 |
+
]
|
| 950 |
+
},
|
| 951 |
+
{
|
| 952 |
+
"cell_type": "code",
|
| 953 |
+
"execution_count": null,
|
| 954 |
+
"metadata": {},
|
| 955 |
+
"outputs": [],
|
| 956 |
+
"source": [
|
| 957 |
+
"adata.obsm[\"X_scVI\"] = model.get_latent_representation()"
|
| 958 |
+
]
|
| 959 |
+
},
|
| 960 |
+
{
|
| 961 |
+
"cell_type": "code",
|
| 962 |
+
"execution_count": null,
|
| 963 |
+
"metadata": {},
|
| 964 |
+
"outputs": [
|
| 965 |
+
{
|
| 966 |
+
"name": "stderr",
|
| 967 |
+
"output_type": "stream",
|
| 968 |
+
"text": [
|
| 969 |
+
"\u001b[32mINFO \u001b[0m | 2025-06-22 14:36:48 | \u001b[32mSubsampled adata to 25.0% of original cells.\u001b[0m\n",
|
| 970 |
+
"\u001b[32mINFO \u001b[0m | 2025-06-22 14:36:48 | \u001b[32mSearching for optimal clustering resolution on X_scVI...\u001b[0m\n",
|
| 971 |
+
"Louvain clustering: 100%|██████████| 20/20 [00:02<00:00, 7.97it/s]\n",
|
| 972 |
+
"\u001b[32mINFO \u001b[0m | 2025-06-22 14:36:51 | \u001b[32mBest resolution: 1.20 with NMI = 0.7544\u001b[0m\n"
|
| 973 |
+
]
|
| 974 |
+
},
|
| 975 |
+
{
|
| 976 |
+
"name": "stdout",
|
| 977 |
+
"output_type": "stream",
|
| 978 |
+
"text": [
|
| 979 |
+
"mean silhouette per group: silhouette_score\n",
|
| 980 |
+
"group \n",
|
| 981 |
+
"B cells 0.991501\n",
|
| 982 |
+
"CD14+ Monocytes 0.976939\n",
|
| 983 |
+
"CD4 T cells 0.987053\n",
|
| 984 |
+
"CD8 T cells 0.980696\n",
|
| 985 |
+
"Dendritic Cells 0.931121\n",
|
| 986 |
+
"FCGR3A+ Monocytes 0.974440\n",
|
| 987 |
+
"Megakaryocytes 0.910766\n",
|
| 988 |
+
"NK cells 0.971491\n",
|
| 989 |
+
"Other 0.899360\n"
|
| 990 |
+
]
|
| 991 |
+
},
|
| 992 |
+
{
|
| 993 |
+
"data": {
|
| 994 |
+
"text/plain": [
|
| 995 |
+
"{'NMI_cluster/label': 0.7543923134993394,\n",
|
| 996 |
+
" 'ARI_cluster/label': 0.6471385261878778,\n",
|
| 997 |
+
" 'ASW_label': 0.482499361038208,\n",
|
| 998 |
+
" 'graph_conn': 0.9461266173017836,\n",
|
| 999 |
+
" 'ASW_batch': 0.5024425515439361,\n",
|
| 1000 |
+
" 'ASW_label/batch': 0.9581518028443176,\n",
|
| 1001 |
+
" 'PCR_batch': 0.00044665558752302455,\n",
|
| 1002 |
+
" 'Average_Batch_Score': 0.25144460356572956,\n",
|
| 1003 |
+
" 'avg_bio': 0.628010066908475}"
|
| 1004 |
+
]
|
| 1005 |
+
},
|
| 1006 |
+
"execution_count": 27,
|
| 1007 |
+
"metadata": {},
|
| 1008 |
+
"output_type": "execute_result"
|
| 1009 |
+
}
|
| 1010 |
+
],
|
| 1011 |
+
"source": [
|
| 1012 |
+
"results_dict_scvi = eval_clustering_metrics(adata=adata, \n",
|
| 1013 |
+
" batch_key=batch_col,\n",
|
| 1014 |
+
" label_key=label_col,\n",
|
| 1015 |
+
" embedding_key=\"X_scVI\", # or \"X_scGPT\", etc.\n",
|
| 1016 |
+
" verbose=True)\n",
|
| 1017 |
+
"results_dict_scvi"
|
| 1018 |
+
]
|
| 1019 |
+
},
|
| 1020 |
+
{
|
| 1021 |
+
"cell_type": "code",
|
| 1022 |
+
"execution_count": null,
|
| 1023 |
+
"metadata": {},
|
| 1024 |
+
"outputs": [],
|
| 1025 |
+
"source": [
|
| 1026 |
+
"from scGraph import scGraph\n",
|
| 1027 |
+
"\n",
|
| 1028 |
+
"scg = scGraph(adata=input_data.adata, batch_key=\"batch\", label_key=\"celltype\", \n",
|
| 1029 |
+
" trim_rate=0.05, thres_batch=1, thres_celltype=1, embedding_key=\"X_scVI\")\n",
|
| 1030 |
+
"scg.preprocess()\n",
|
| 1031 |
+
"scg.compute()\n",
|
| 1032 |
+
"results = scg.evaluate()\n",
|
| 1033 |
+
"print(results)"
|
| 1034 |
+
]
|
| 1035 |
+
}
|
| 1036 |
+
],
|
| 1037 |
+
"metadata": {
|
| 1038 |
+
"kernelspec": {
|
| 1039 |
+
"display_name": "Python 3",
|
| 1040 |
+
"language": "python",
|
| 1041 |
+
"name": "python3"
|
| 1042 |
+
},
|
| 1043 |
+
"language_info": {
|
| 1044 |
+
"codemirror_mode": {
|
| 1045 |
+
"name": "ipython",
|
| 1046 |
+
"version": 3
|
| 1047 |
+
},
|
| 1048 |
+
"file_extension": ".py",
|
| 1049 |
+
"mimetype": "text/x-python",
|
| 1050 |
+
"name": "python",
|
| 1051 |
+
"nbconvert_exporter": "python",
|
| 1052 |
+
"pygments_lexer": "ipython3",
|
| 1053 |
+
"version": "3.11.7"
|
| 1054 |
+
}
|
| 1055 |
+
},
|
| 1056 |
+
"nbformat": 4,
|
| 1057 |
+
"nbformat_minor": 2
|
| 1058 |
+
}
|
Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_raw_data.ipynb
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 2,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"from typing import Dict, Optional\n",
|
| 10 |
+
"import numpy as np\n",
|
| 11 |
+
"import scanpy as sc\n",
|
| 12 |
+
"import scib\n",
|
| 13 |
+
"from anndata import AnnData\n",
|
| 14 |
+
"from sklearn.metrics import silhouette_score\n",
|
| 15 |
+
"from tqdm import tqdm\n",
|
| 16 |
+
"import pandas as pd\n",
|
| 17 |
+
"import logging\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"log = logging.getLogger(__name__)\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"def eval_clustering_metrics(\n",
|
| 23 |
+
" adata: AnnData,\n",
|
| 24 |
+
" batch_key: Optional[str] = \"str_batch\",\n",
|
| 25 |
+
" label_key: str = \"cell_type\",\n",
|
| 26 |
+
" embedding_key: str = \"X\", # \"X\" for raw, or embedding key in .obsm\n",
|
| 27 |
+
" resolutions: Optional[list] = None,\n",
|
| 28 |
+
" use_progress_bar: bool = True,\n",
|
| 29 |
+
" verbose: bool = False,\n",
|
| 30 |
+
") -> Dict[str, float]:\n",
|
| 31 |
+
" \"\"\"Evaluate biological and batch mixing metrics on an embedding or raw expression.\"\"\"\n",
|
| 32 |
+
" \n",
|
| 33 |
+
" results_dict = {}\n",
|
| 34 |
+
"\n",
|
| 35 |
+
" # Determine whether to use .X or .obsm[embedding_key]\n",
|
| 36 |
+
" if embedding_key == \"X\":\n",
|
| 37 |
+
" use_rep = \"X\"\n",
|
| 38 |
+
" adata.obsm[\"X\"] = adata.X\n",
|
| 39 |
+
" elif embedding_key in adata.obsm:\n",
|
| 40 |
+
" use_rep = embedding_key\n",
|
| 41 |
+
" else:\n",
|
| 42 |
+
" raise ValueError(f\"embedding_key '{embedding_key}' not found in adata.obsm or is not 'X'\")\n",
|
| 43 |
+
"\n",
|
| 44 |
+
" # Clear stale neighbors\n",
|
| 45 |
+
" if \"neighbors\" in adata.uns:\n",
|
| 46 |
+
" if verbose:\n",
|
| 47 |
+
" log.warning(f\"Removing stale neighbors computed from other representations.\")\n",
|
| 48 |
+
" adata.uns.pop(\"neighbors\", None)\n",
|
| 49 |
+
"\n",
|
| 50 |
+
" sc.pp.neighbors(adata, use_rep=use_rep)\n",
|
| 51 |
+
"\n",
|
| 52 |
+
" # Run Louvain across multiple resolutions\n",
|
| 53 |
+
" if resolutions is None:\n",
|
| 54 |
+
" resolutions = [2 * i / 20 for i in range(1, 21)] # Default: 20 steps from 0.1 to 2.0\n",
|
| 55 |
+
"\n",
|
| 56 |
+
" best_nmi = -1\n",
|
| 57 |
+
" best_res = None\n",
|
| 58 |
+
" best_clustering = None\n",
|
| 59 |
+
"\n",
|
| 60 |
+
" if verbose:\n",
|
| 61 |
+
" log.info(f\"Searching for optimal clustering resolution on {use_rep}...\")\n",
|
| 62 |
+
"\n",
|
| 63 |
+
" for res in tqdm(resolutions, disable=not use_progress_bar, desc=\"Louvain clustering\"):\n",
|
| 64 |
+
" sc.tl.louvain(adata, resolution=res, key_added=\"temp_cluster\")\n",
|
| 65 |
+
" nmi = scib.metrics.nmi(adata, \"temp_cluster\", label_key)\n",
|
| 66 |
+
" if nmi > best_nmi:\n",
|
| 67 |
+
" best_nmi = nmi\n",
|
| 68 |
+
" best_res = res\n",
|
| 69 |
+
" best_clustering = adata.obs[\"temp_cluster\"].copy()\n",
|
| 70 |
+
" del adata.obs[\"temp_cluster\"]\n",
|
| 71 |
+
"\n",
|
| 72 |
+
" if verbose:\n",
|
| 73 |
+
" log.info(f\"Best resolution: {best_res:.2f} with NMI = {best_nmi:.4f}\")\n",
|
| 74 |
+
"\n",
|
| 75 |
+
" adata.obs[\"cluster\"] = best_clustering\n",
|
| 76 |
+
"\n",
|
| 77 |
+
" # Biological conservation metrics\n",
|
| 78 |
+
" results_dict[\"NMI_cluster/label\"] = scib.metrics.nmi(adata, \"cluster\", label_key, \"arithmetic\")\n",
|
| 79 |
+
" results_dict[\"ARI_cluster/label\"] = scib.metrics.ari(adata, \"cluster\", label_key)\n",
|
| 80 |
+
" results_dict[\"ASW_label\"] = scib.metrics.silhouette(adata, label_key, use_rep, \"euclidean\")\n",
|
| 81 |
+
"\n",
|
| 82 |
+
" # Batch effect metrics (if batch_key valid)\n",
|
| 83 |
+
" if batch_key is not None and batch_key in adata.obs and adata.obs[batch_key].nunique() > 1:\n",
|
| 84 |
+
" results_dict[\"graph_conn\"] = scib.metrics.graph_connectivity(adata, label_key)\n",
|
| 85 |
+
" results_dict[\"ASW_batch\"] = scib.metrics.silhouette(adata, batch_key, use_rep, \"euclidean\")\n",
|
| 86 |
+
" results_dict[\"ASW_label/batch\"] = scib.metrics.silhouette_batch(\n",
|
| 87 |
+
" adata, batch_key, label_key, embed=use_rep, metric=\"euclidean\", return_all=False\n",
|
| 88 |
+
" )\n",
|
| 89 |
+
" results_dict[\"PCR_batch\"] = scib.metrics.pcr(\n",
|
| 90 |
+
" adata, covariate=batch_key, embed=use_rep, recompute_pca=True, n_comps=50, verbose=False\n",
|
| 91 |
+
" )\n",
|
| 92 |
+
" else:\n",
|
| 93 |
+
" if verbose:\n",
|
| 94 |
+
" log.info(\"Skipping batch metrics — only one batch present or invalid batch_key.\")\n",
|
| 95 |
+
" \n",
|
| 96 |
+
" results_dict[\"avg_bio\"] = np.mean([\n",
|
| 97 |
+
" results_dict[\"NMI_cluster/label\"],\n",
|
| 98 |
+
" results_dict[\"ARI_cluster/label\"],\n",
|
| 99 |
+
" results_dict[\"ASW_label\"]\n",
|
| 100 |
+
" ])\n",
|
| 101 |
+
"\n",
|
| 102 |
+
" # Filter NaNs\n",
|
| 103 |
+
" results_dict = {k: v for k, v in results_dict.items() if not np.isnan(v)}\n",
|
| 104 |
+
"\n",
|
| 105 |
+
" return results_dict\n"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "code",
|
| 110 |
+
"execution_count": null,
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"outputs": [
|
| 113 |
+
{
|
| 114 |
+
"name": "stderr",
|
| 115 |
+
"output_type": "stream",
|
| 116 |
+
"text": [
|
| 117 |
+
"Louvain clustering: 100%|██████████| 20/20 [00:15<00:00, 1.32it/s]\n",
|
| 118 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 119 |
+
" tab = pd.value_counts(labels)\n",
|
| 120 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 121 |
+
" tab = pd.value_counts(labels)\n",
|
| 122 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 123 |
+
" tab = pd.value_counts(labels)\n",
|
| 124 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 125 |
+
" tab = pd.value_counts(labels)\n",
|
| 126 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 127 |
+
" tab = pd.value_counts(labels)\n",
|
| 128 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 129 |
+
" tab = pd.value_counts(labels)\n",
|
| 130 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 131 |
+
" tab = pd.value_counts(labels)\n",
|
| 132 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 133 |
+
" tab = pd.value_counts(labels)\n",
|
| 134 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 135 |
+
" tab = pd.value_counts(labels)\n"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"name": "stdout",
|
| 140 |
+
"output_type": "stream",
|
| 141 |
+
"text": [
|
| 142 |
+
"mean silhouette per group: silhouette_score\n",
|
| 143 |
+
"group \n",
|
| 144 |
+
"B cells 0.986484\n",
|
| 145 |
+
"CD14+ Monocytes 0.943531\n",
|
| 146 |
+
"CD4 T cells 0.980745\n",
|
| 147 |
+
"CD8 T cells 0.951482\n",
|
| 148 |
+
"Dendritic Cells 0.956119\n",
|
| 149 |
+
"FCGR3A+ Monocytes 0.986242\n",
|
| 150 |
+
"Megakaryocytes 0.856766\n",
|
| 151 |
+
"NK cells 0.953083\n",
|
| 152 |
+
"Other 0.930244\n"
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"name": "stderr",
|
| 157 |
+
"output_type": "stream",
|
| 158 |
+
"text": [
|
| 159 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/anndata/_core/anndata.py:522: FutureWarning: The dtype argument is deprecated and will be removed in late 2024.\n",
|
| 160 |
+
" warnings.warn(\n"
|
| 161 |
+
]
|
| 162 |
+
}
|
| 163 |
+
],
|
| 164 |
+
"source": [
|
| 165 |
+
"import scanpy as sc \n",
|
| 166 |
+
"adata = sc.read_h5ad(\"zero_shot_batch_data/pbmc.h5ad\") \n",
|
| 167 |
+
"\n",
|
| 168 |
+
"results_dict = eval_clustering_metrics(adata=adata, \n",
|
| 169 |
+
" batch_key=\"batch\",\n",
|
| 170 |
+
" label_key=\"celltype\",\n",
|
| 171 |
+
" embedding_key=\"X\", # or \"X_scGPT\", etc.\n",
|
| 172 |
+
" verbose=True)"
|
| 173 |
+
]
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"cell_type": "code",
|
| 177 |
+
"execution_count": 12,
|
| 178 |
+
"metadata": {},
|
| 179 |
+
"outputs": [
|
| 180 |
+
{
|
| 181 |
+
"data": {
|
| 182 |
+
"text/plain": [
|
| 183 |
+
"{'NMI_cluster/label': 0.7043350648326699,\n",
|
| 184 |
+
" 'ARI_cluster/label': 0.6456273245075416,\n",
|
| 185 |
+
" 'ASW_label': 0.5333220548927784,\n",
|
| 186 |
+
" 'graph_conn': 0.9038879996225364,\n",
|
| 187 |
+
" 'ASW_batch': 0.4965497492812574,\n",
|
| 188 |
+
" 'ASW_label/batch': 0.9494108132303586,\n",
|
| 189 |
+
" 'PCR_batch': 0.0009914006163016576,\n",
|
| 190 |
+
" 'avg_bio': 0.6277614814109966}"
|
| 191 |
+
]
|
| 192 |
+
},
|
| 193 |
+
"execution_count": 12,
|
| 194 |
+
"metadata": {},
|
| 195 |
+
"output_type": "execute_result"
|
| 196 |
+
}
|
| 197 |
+
],
|
| 198 |
+
"source": [
|
| 199 |
+
"results_dict"
|
| 200 |
+
]
|
| 201 |
+
},
|
| 202 |
+
{
|
| 203 |
+
"cell_type": "code",
|
| 204 |
+
"execution_count": 5,
|
| 205 |
+
"metadata": {},
|
| 206 |
+
"outputs": [
|
| 207 |
+
{
|
| 208 |
+
"name": "stderr",
|
| 209 |
+
"output_type": "stream",
|
| 210 |
+
"text": [
|
| 211 |
+
"/tmp/ipykernel_786097/2986997571.py:30: ImplicitModificationWarning: Setting element `.obsm['X']` of view, initializing view as actual.\n",
|
| 212 |
+
" adata.obsm[\"X\"] = adata.X\n",
|
| 213 |
+
"Louvain clustering: 100%|██████��███| 20/20 [00:11<00:00, 1.68it/s]\n",
|
| 214 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 215 |
+
" tab = pd.value_counts(labels)\n",
|
| 216 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 217 |
+
" tab = pd.value_counts(labels)\n",
|
| 218 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 219 |
+
" tab = pd.value_counts(labels)\n",
|
| 220 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 221 |
+
" tab = pd.value_counts(labels)\n",
|
| 222 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 223 |
+
" tab = pd.value_counts(labels)\n",
|
| 224 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 225 |
+
" tab = pd.value_counts(labels)\n",
|
| 226 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 227 |
+
" tab = pd.value_counts(labels)\n",
|
| 228 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 229 |
+
" tab = pd.value_counts(labels)\n",
|
| 230 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 231 |
+
" tab = pd.value_counts(labels)\n",
|
| 232 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 233 |
+
" tab = pd.value_counts(labels)\n",
|
| 234 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 235 |
+
" tab = pd.value_counts(labels)\n",
|
| 236 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 237 |
+
" tab = pd.value_counts(labels)\n",
|
| 238 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 239 |
+
" tab = pd.value_counts(labels)\n",
|
| 240 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 241 |
+
" tab = pd.value_counts(labels)\n",
|
| 242 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 243 |
+
" tab = pd.value_counts(labels)\n",
|
| 244 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 245 |
+
" tab = pd.value_counts(labels)\n",
|
| 246 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 247 |
+
" tab = pd.value_counts(labels)\n",
|
| 248 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 249 |
+
" tab = pd.value_counts(labels)\n",
|
| 250 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 251 |
+
" tab = pd.value_counts(labels)\n",
|
| 252 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
|
| 253 |
+
" tab = pd.value_counts(labels)\n"
|
| 254 |
+
]
|
| 255 |
+
},
|
| 256 |
+
{
|
| 257 |
+
"name": "stdout",
|
| 258 |
+
"output_type": "stream",
|
| 259 |
+
"text": [
|
| 260 |
+
"mean silhouette per group: nan\n"
|
| 261 |
+
]
|
| 262 |
+
},
|
| 263 |
+
{
|
| 264 |
+
"name": "stderr",
|
| 265 |
+
"output_type": "stream",
|
| 266 |
+
"text": [
|
| 267 |
+
"/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/anndata/_core/anndata.py:522: FutureWarning: The dtype argument is deprecated and will be removed in late 2024.\n",
|
| 268 |
+
" warnings.warn(\n"
|
| 269 |
+
]
|
| 270 |
+
}
|
| 271 |
+
],
|
| 272 |
+
"source": [
|
| 273 |
+
"results_dict_ood = eval_clustering_metrics(adata=adata_ood[:15000],\n",
|
| 274 |
+
" batch_key=\"batch\",\n",
|
| 275 |
+
" label_key=\"cell_type\",\n",
|
| 276 |
+
" embedding_key=\"X\", \n",
|
| 277 |
+
" verbose=True)"
|
| 278 |
+
]
|
| 279 |
+
},
|
| 280 |
+
{
|
| 281 |
+
"cell_type": "code",
|
| 282 |
+
"execution_count": 6,
|
| 283 |
+
"metadata": {},
|
| 284 |
+
"outputs": [
|
| 285 |
+
{
|
| 286 |
+
"data": {
|
| 287 |
+
"text/plain": [
|
| 288 |
+
"{'NMI_cluster/label': 0.9334102174490695,\n",
|
| 289 |
+
" 'ARI_cluster/label': 0.9699361136567832,\n",
|
| 290 |
+
" 'ASW_label': 0.5538543930108312,\n",
|
| 291 |
+
" 'graph_conn': 0.9231509101914211,\n",
|
| 292 |
+
" 'ASW_batch': 0.6438532075334105,\n",
|
| 293 |
+
" 'PCR_batch': 0.042066597759588056,\n",
|
| 294 |
+
" 'avg_bio': 0.8190669080388946}"
|
| 295 |
+
]
|
| 296 |
+
},
|
| 297 |
+
"execution_count": 6,
|
| 298 |
+
"metadata": {},
|
| 299 |
+
"output_type": "execute_result"
|
| 300 |
+
}
|
| 301 |
+
],
|
| 302 |
+
"source": [
|
| 303 |
+
"results_dict_ood"
|
| 304 |
+
]
|
| 305 |
+
}
|
| 306 |
+
],
|
| 307 |
+
"metadata": {
|
| 308 |
+
"kernelspec": {
|
| 309 |
+
"display_name": "Python 3",
|
| 310 |
+
"language": "python",
|
| 311 |
+
"name": "python3"
|
| 312 |
+
},
|
| 313 |
+
"language_info": {
|
| 314 |
+
"codemirror_mode": {
|
| 315 |
+
"name": "ipython",
|
| 316 |
+
"version": 3
|
| 317 |
+
},
|
| 318 |
+
"file_extension": ".py",
|
| 319 |
+
"mimetype": "text/x-python",
|
| 320 |
+
"name": "python",
|
| 321 |
+
"nbconvert_exporter": "python",
|
| 322 |
+
"pygments_lexer": "ipython3",
|
| 323 |
+
"version": "3.11.7"
|
| 324 |
+
}
|
| 325 |
+
},
|
| 326 |
+
"nbformat": 4,
|
| 327 |
+
"nbformat_minor": 2
|
| 328 |
+
}
|
Downstream_tasks/Zero_shot_batch_effect/requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
anndata==0.9.2
|
| 2 |
+
colorlog==6.7.0
|
| 3 |
+
scgpt==0.1.6
|
| 4 |
+
geneformer==0.0.1
|
| 5 |
+
PyComplexHeatmap
|
| 6 |
+
numpy
|
| 7 |
+
pandas
|
| 8 |
+
scanpy
|
| 9 |
+
scipy
|
| 10 |
+
seaborn
|
| 11 |
+
scib
|
| 12 |
+
scvi-tools
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__init__.py
ADDED
|
File without changes
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (177 Bytes). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (193 Bytes). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/cell_embeddings.cpython-310.pyc
ADDED
|
Binary file (9.72 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/cell_embeddings.cpython-311.pyc
ADDED
|
Binary file (21.4 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/data.cpython-310.pyc
ADDED
|
Binary file (5.76 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/data.cpython-311.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/geneformer_forward.cpython-310.pyc
ADDED
|
Binary file (9.1 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/geneformer_forward.cpython-311.pyc
ADDED
|
Binary file (17.6 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/model_output.cpython-310.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/model_output.cpython-311.pyc
ADDED
|
Binary file (31.6 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/scgpt_forward.cpython-310.pyc
ADDED
|
Binary file (19.5 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (21.8 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/cell_embeddings.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Copyright (c) Microsoft Corporation.
|
| 2 |
+
## Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from typing import List, Optional, Tuple, Dict, Union
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
plt.style.use('fivethirtyeight')
|
| 9 |
+
|
| 10 |
+
import seaborn as sns
|
| 11 |
+
import scanpy as sc
|
| 12 |
+
|
| 13 |
+
from .helpers import umap
|
| 14 |
+
from .helpers.custom_logging import log
|
| 15 |
+
|
| 16 |
+
from . import data, utils
|
| 17 |
+
from .geneformer_forward import Geneformer_instance
|
| 18 |
+
# from .scgpt_forward import scGPT_instance
|
| 19 |
+
|
| 20 |
+
class CellEmbeddingsEval():
|
| 21 |
+
def __init__(self,
|
| 22 |
+
# model_instance: Union[scGPT_instance,
|
| 23 |
+
# Geneformer_instance],
|
| 24 |
+
model_instance: Union[Geneformer_instance],
|
| 25 |
+
data: data.InputData,
|
| 26 |
+
label_key: Union[str, List[str]] = "cell_type",
|
| 27 |
+
batch_key: Optional[str] = None,
|
| 28 |
+
output_dir: Optional[str] = None,
|
| 29 |
+
log_wandb: bool = False) -> None:
|
| 30 |
+
|
| 31 |
+
# test if model_instance is an instance of scGPT_instance or Geneformer_instance
|
| 32 |
+
# if not isinstance(model_instance,
|
| 33 |
+
# (scGPT_instance, Geneformer_instance)):
|
| 34 |
+
# msg = ("scgpt_instance must be an instance of "
|
| 35 |
+
# "scGPT_instance or Geneformer_instance")
|
| 36 |
+
if not isinstance(model_instance,
|
| 37 |
+
(Geneformer_instance)):
|
| 38 |
+
msg = ("scgpt_instance must be an instance of "
|
| 39 |
+
"scGPT_instance or Geneformer_instance")
|
| 40 |
+
log.error(msg)
|
| 41 |
+
raise ValueError(msg)
|
| 42 |
+
|
| 43 |
+
# test if instance is properly processed
|
| 44 |
+
if not hasattr(model_instance, "cell_embeddings"):
|
| 45 |
+
msg = "Cell embeddings need to be extracted first"
|
| 46 |
+
log.error(msg)
|
| 47 |
+
raise ValueError(msg)
|
| 48 |
+
|
| 49 |
+
# if wandb set to true and not initialized, throw error
|
| 50 |
+
if log_wandb and not model_instance._wandb:
|
| 51 |
+
msg = "wandb is not initialized in model_instance"
|
| 52 |
+
log.error(msg)
|
| 53 |
+
raise ValueError(msg)
|
| 54 |
+
|
| 55 |
+
self._wandb = model_instance._wandb
|
| 56 |
+
|
| 57 |
+
self.eval_instance = model_instance
|
| 58 |
+
self.data = data
|
| 59 |
+
|
| 60 |
+
if batch_key is not None:
|
| 61 |
+
if batch_key not in self.data.adata.obs.columns:
|
| 62 |
+
msg = f"batch_key {batch_key} not found in adata.obs"
|
| 63 |
+
log.error(msg)
|
| 64 |
+
raise ValueError(msg)
|
| 65 |
+
else:
|
| 66 |
+
self.batch_key = batch_key
|
| 67 |
+
else:
|
| 68 |
+
try:
|
| 69 |
+
self.batch_key = self.data.batch_str_col
|
| 70 |
+
except AttributeError:
|
| 71 |
+
msg = "batch_key not provided and not found in data object"
|
| 72 |
+
log.error(msg)
|
| 73 |
+
raise ValueError(msg)
|
| 74 |
+
|
| 75 |
+
if output_dir is not None:
|
| 76 |
+
# if output dir is provided, use it
|
| 77 |
+
self.output_dir = output_dir
|
| 78 |
+
# check if output_dir exists
|
| 79 |
+
if not os.path.exists(self.output_dir):
|
| 80 |
+
log.warning(f"Creating the output directory {self.output_dir}")
|
| 81 |
+
os.makedirs(self.output_dir)
|
| 82 |
+
else:
|
| 83 |
+
# use the same output_dir as the scgpt_instance
|
| 84 |
+
self.output_dir = self.eval_instance.output_dir
|
| 85 |
+
|
| 86 |
+
# if label_key is string, convert to list
|
| 87 |
+
if isinstance(label_key, str):
|
| 88 |
+
label_key = [label_key]
|
| 89 |
+
self.label_key = label_key
|
| 90 |
+
|
| 91 |
+
# make sure that each label exists and is categorical in adata.obs
|
| 92 |
+
for label in self.label_key:
|
| 93 |
+
if label not in self.data.adata.obs.columns:
|
| 94 |
+
msg = f"Label {label} not found in adata.obs"
|
| 95 |
+
log.error(msg)
|
| 96 |
+
raise ValueError(msg)
|
| 97 |
+
self.data.adata.obs[label] = self.data.adata.obs[label].astype("category")
|
| 98 |
+
|
| 99 |
+
def evaluate(self,
|
| 100 |
+
embedding_key: str = "X_scGPT",
|
| 101 |
+
n_cells: int = 7500) -> pd.DataFrame:
|
| 102 |
+
|
| 103 |
+
adata_ = self.data.adata.copy()
|
| 104 |
+
|
| 105 |
+
# if adata_ too big, take a subset
|
| 106 |
+
if adata_.n_obs > n_cells:
|
| 107 |
+
log.warning(f"adata_ has {adata_.n_obs} cells. "
|
| 108 |
+
f"Taking a subset of {n_cells} cells.")
|
| 109 |
+
sc.pp.subsample(adata_, n_obs = n_cells, copy = False)
|
| 110 |
+
|
| 111 |
+
met_df = pd.DataFrame(columns = ["metric", "label", "value"])
|
| 112 |
+
|
| 113 |
+
# get unique values in self.label_key preserving the order
|
| 114 |
+
label_cols = [x for i, x in enumerate(self.label_key)
|
| 115 |
+
if x not in self.label_key[:i]]
|
| 116 |
+
# remove label columns that are not in adata_.obs
|
| 117 |
+
label_cols = [x for x in label_cols if x in adata_.obs.columns]
|
| 118 |
+
|
| 119 |
+
if len(label_cols) == 0:
|
| 120 |
+
msg = f"No label columns {self.label_key} found in adata.obs"
|
| 121 |
+
log.error(msg)
|
| 122 |
+
raise ValueError(msg)
|
| 123 |
+
|
| 124 |
+
# check if the embeddings are in adata
|
| 125 |
+
if embedding_key not in adata_.obsm.keys():
|
| 126 |
+
msg = f"Embeddings {embedding_key} not found in adata.obsm"
|
| 127 |
+
log.error(msg)
|
| 128 |
+
raise ValueError(msg)
|
| 129 |
+
|
| 130 |
+
for label in label_cols:
|
| 131 |
+
log.debug(f"Computing metrics for {label}")
|
| 132 |
+
|
| 133 |
+
metrics = utils.eval_scib_metrics(adata_,
|
| 134 |
+
batch_key = self.batch_key,
|
| 135 |
+
label_key = label,
|
| 136 |
+
embedding_key = embedding_key)
|
| 137 |
+
for metric in metrics.keys():
|
| 138 |
+
log.debug(f"{metric} for {label}: {metrics[metric]}")
|
| 139 |
+
|
| 140 |
+
# log to wandb if initialized
|
| 141 |
+
if self._wandb:
|
| 142 |
+
self._wandb.log({f"{embedding_key}/{label}/{metric}": metrics[metric]})
|
| 143 |
+
|
| 144 |
+
# add row to the dataframe
|
| 145 |
+
met_df.loc[len(met_df)] = [metric, label, metrics[metric]]
|
| 146 |
+
|
| 147 |
+
met_df.to_csv(os.path.join(self.output_dir,
|
| 148 |
+
f"{embedding_key}__metrics.csv"),
|
| 149 |
+
index = False)
|
| 150 |
+
|
| 151 |
+
if self._wandb:
|
| 152 |
+
wandb_df = self._wandb.Table(data = met_df)
|
| 153 |
+
self._wandb.log({f"{embedding_key}/metrics": wandb_df})
|
| 154 |
+
return met_df
|
| 155 |
+
|
| 156 |
+
def create_original_umap(self,
|
| 157 |
+
out_emb: str = "X_umap_input") -> None:
|
| 158 |
+
|
| 159 |
+
sc.pp.neighbors(self.data.adata)
|
| 160 |
+
temp = sc.tl.umap(self.data.adata, min_dist = 0.3, copy=True)
|
| 161 |
+
self.data.adata.obsm[out_emb] = temp.obsm["X_umap"].copy()
|
| 162 |
+
|
| 163 |
+
# TODO: this should be a more generic function that can plot any embedding
|
| 164 |
+
def visualize(self,
|
| 165 |
+
embedding_key: str = "X_scGPT",
|
| 166 |
+
return_fig: bool = False,
|
| 167 |
+
plot_size: Tuple[float, float] = (9, 7),
|
| 168 |
+
plot_title: Optional[str] = None,
|
| 169 |
+
plot_type: [List, str] = "simple",
|
| 170 |
+
n_cells: int = 7500
|
| 171 |
+
) -> Optional[Dict[str, plt.figure]]:
|
| 172 |
+
|
| 173 |
+
raw_emb = "X_umap_input"
|
| 174 |
+
|
| 175 |
+
if embedding_key == raw_emb:
|
| 176 |
+
# if the umap_raw embedding is used, create it first
|
| 177 |
+
self.create_original_umap(out_emb = embedding_key)
|
| 178 |
+
|
| 179 |
+
# if adata already has a umap embedding warn that it will be overwritten
|
| 180 |
+
if "X_umap" in self.data.adata.obsm.keys():
|
| 181 |
+
old_umap_name = "X_umap_old"
|
| 182 |
+
log.warning(f"Copying existing UMAP embedding to {old_umap_name} "
|
| 183 |
+
"and overwriting X_umap.")
|
| 184 |
+
self.data.adata.obsm[old_umap_name] = self.data.adata.obsm["X_umap"].copy()
|
| 185 |
+
|
| 186 |
+
# check if the embeddings are in adata
|
| 187 |
+
if embedding_key not in self.data.adata.obsm.keys():
|
| 188 |
+
msg = f"Embeddings {embedding_key} not found in adata."
|
| 189 |
+
log.error(msg)
|
| 190 |
+
raise ValueError(msg)
|
| 191 |
+
|
| 192 |
+
# if embedding_key contains the string umap, do not compute umap again
|
| 193 |
+
if embedding_key != raw_emb:
|
| 194 |
+
# compute umap embeddings
|
| 195 |
+
sc.pp.neighbors(self.data.adata, use_rep = embedding_key)
|
| 196 |
+
sc.tl.umap(self.data.adata, min_dist = 0.3)
|
| 197 |
+
|
| 198 |
+
adata_ = self.data.adata.copy()
|
| 199 |
+
# if adata_ too big, take a subset
|
| 200 |
+
if adata_.n_obs > n_cells:
|
| 201 |
+
log.warning(f"adata_ has {adata_.n_obs} cells. "
|
| 202 |
+
f"Taking a subset of {n_cells} cells.")
|
| 203 |
+
sc.pp.subsample(adata_, n_obs = n_cells, copy = False)
|
| 204 |
+
# save the subsetted adata.obs
|
| 205 |
+
adata_.obs.to_csv(os.path.join(self.output_dir,
|
| 206 |
+
"adata_obs_subset.csv"))
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# make sure plot size is a tuple of numbers
|
| 211 |
+
try:
|
| 212 |
+
w, h = plot_size
|
| 213 |
+
if not isinstance(h, (int, float)) or not isinstance(w, (int, float)):
|
| 214 |
+
msg = f"Height (h = {h}) or width (w = {w}) not valid."
|
| 215 |
+
log.error(msg)
|
| 216 |
+
raise TypeError(msg)
|
| 217 |
+
except TypeError:
|
| 218 |
+
msg = f"Plot size {plot_size} is not a tuple of numbers."
|
| 219 |
+
log.error(msg)
|
| 220 |
+
raise TypeError(msg)
|
| 221 |
+
|
| 222 |
+
# get unique values in self.label_key preserving the order
|
| 223 |
+
label_cols = self.label_key + [self.batch_key]
|
| 224 |
+
label_cols = [x for i, x in enumerate(label_cols)
|
| 225 |
+
if x not in label_cols[:i]]
|
| 226 |
+
# remove label columns that are not in adata_.obs
|
| 227 |
+
label_cols = [x for x in label_cols
|
| 228 |
+
if x in self.data.adata.obs.columns]
|
| 229 |
+
|
| 230 |
+
if len(label_cols) == 0:
|
| 231 |
+
msg = f"No label columns {self.label_key} found in adata.obs"
|
| 232 |
+
log.error(msg)
|
| 233 |
+
raise ValueError(msg)
|
| 234 |
+
|
| 235 |
+
# set the colors for the labels
|
| 236 |
+
labels = dict()
|
| 237 |
+
labels_colors = dict()
|
| 238 |
+
palettes = ['viridis', 'inferno',
|
| 239 |
+
'mako', 'rocket',
|
| 240 |
+
'tab20', 'colorblind',
|
| 241 |
+
'tab20b', 'tab20c']
|
| 242 |
+
|
| 243 |
+
if len(label_cols) > len(palettes):
|
| 244 |
+
log.warning("More labels than palettes. Adding random colors.")
|
| 245 |
+
palettes = palettes + ["random"] * (len(label_cols) - len(palettes))
|
| 246 |
+
|
| 247 |
+
# creating palettes for the labels
|
| 248 |
+
for i, label in enumerate(label_cols):
|
| 249 |
+
labels[label] = self.data.adata.obs[label].unique()
|
| 250 |
+
if len(labels[label]) > 10:
|
| 251 |
+
log.warning(f"More than 10 labels for {label}."
|
| 252 |
+
f"The plots might be hard to read.")
|
| 253 |
+
labels_colors[label] = dict(zip(labels[label],
|
| 254 |
+
umap.generate_pallette(n = len(labels[label]),
|
| 255 |
+
cmap = palettes[i])))
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
figs = {}
|
| 260 |
+
|
| 261 |
+
# if plot_type a string, convert to list
|
| 262 |
+
if isinstance(plot_type, str):
|
| 263 |
+
plot_type = [plot_type]
|
| 264 |
+
|
| 265 |
+
plot_type = [x.lower() for x in plot_type]
|
| 266 |
+
# get unique values in plot_type
|
| 267 |
+
plot_type = [x for i, x in enumerate(plot_type)
|
| 268 |
+
if x not in plot_type[:i]]
|
| 269 |
+
old_plot_type = plot_type
|
| 270 |
+
# check if plot_type is valid
|
| 271 |
+
valid_plot_types = ["simple", "wide", "scanpy"]
|
| 272 |
+
|
| 273 |
+
# create a subset of plot_type that is valid
|
| 274 |
+
plot_type = [x for x in plot_type if x in valid_plot_types]
|
| 275 |
+
if len(plot_type) == 0:
|
| 276 |
+
msg = f"Plot type {plot_type} is not valid. Valid plot types are {valid_plot_types}"
|
| 277 |
+
log.error(msg)
|
| 278 |
+
raise ValueError(msg)
|
| 279 |
+
|
| 280 |
+
# print a warning if plot_type is not valid
|
| 281 |
+
if len(plot_type) < len(old_plot_type):
|
| 282 |
+
log.warning(f"Some plot type(s) {old_plot_type} is not valid. "
|
| 283 |
+
f"Valid plot types are {valid_plot_types}. "
|
| 284 |
+
f"Plotting only {plot_type}")
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
plt_emb = "X_umap" if embedding_key != raw_emb else embedding_key
|
| 288 |
+
|
| 289 |
+
plot_title = (plot_title
|
| 290 |
+
if plot_title is not None
|
| 291 |
+
else "UMAP of the cell embeddings")
|
| 292 |
+
|
| 293 |
+
if "simple" in plot_type:
|
| 294 |
+
fig, axs = plt.subplots(ncols = len(label_cols),
|
| 295 |
+
figsize = (len(label_cols) * w, h),
|
| 296 |
+
squeeze = False)
|
| 297 |
+
|
| 298 |
+
axs = axs.flatten()
|
| 299 |
+
|
| 300 |
+
# basic plotting, problematic: size of the points
|
| 301 |
+
embedding = self.data.adata.obsm[plt_emb]
|
| 302 |
+
for i, label in enumerate(label_cols):
|
| 303 |
+
log.debug(f"Plotting the embeddings for {label}")
|
| 304 |
+
# remove axis and grid from the plot
|
| 305 |
+
axs[i].axis('off')
|
| 306 |
+
# plot umap embeddings, add color by cell type
|
| 307 |
+
axs[i].scatter(embedding[:, 0], embedding[:, 1],
|
| 308 |
+
# make points smaller
|
| 309 |
+
s = 0.5,
|
| 310 |
+
c = [labels_colors[label][x] for x
|
| 311 |
+
in self.data.adata.obs[label]])
|
| 312 |
+
legend_handles = [axs[i].plot([], [],
|
| 313 |
+
marker = "o", ls = "",
|
| 314 |
+
color = c, label = l)[0]
|
| 315 |
+
for l, c in labels_colors[label].items()]
|
| 316 |
+
axs[i].legend(handles = legend_handles,
|
| 317 |
+
bbox_to_anchor = (1.05, 1),
|
| 318 |
+
loc = 'upper left')
|
| 319 |
+
|
| 320 |
+
# Add a title to the plot
|
| 321 |
+
axs[i].title.set_text(f"{label}")
|
| 322 |
+
|
| 323 |
+
fig.suptitle(plot_title, fontsize = 16)
|
| 324 |
+
fig.tight_layout()
|
| 325 |
+
fig.subplots_adjust(top = 0.85)
|
| 326 |
+
|
| 327 |
+
fig_savefig = os.path.join(self.output_dir,
|
| 328 |
+
f"umap__{embedding_key}.png")
|
| 329 |
+
fig.savefig(fig_savefig)
|
| 330 |
+
|
| 331 |
+
# if wandb initialized, log the figure
|
| 332 |
+
if self._wandb:
|
| 333 |
+
self._wandb.log({f"umap__{embedding_key}": self._wandb.Image(fig_savefig)})
|
| 334 |
+
|
| 335 |
+
if return_fig:
|
| 336 |
+
figs["umap"] = fig
|
| 337 |
+
|
| 338 |
+
# wide plotting
|
| 339 |
+
if "wide" in plot_type:
|
| 340 |
+
df = pd.DataFrame(self.data.adata.obsm[plt_emb],
|
| 341 |
+
columns = ["umap_1", "umap_2"])
|
| 342 |
+
for i, label in enumerate(label_cols):
|
| 343 |
+
if self.data.adata.obs[label].unique().shape[0] <= 10:
|
| 344 |
+
df[label] = self.data.adata.obs[label].tolist()
|
| 345 |
+
wide_plot = sns.relplot(data = df,
|
| 346 |
+
col = label,
|
| 347 |
+
x = "umap_1",
|
| 348 |
+
y = "umap_2",
|
| 349 |
+
hue = label,
|
| 350 |
+
style = label,
|
| 351 |
+
legend = "full",
|
| 352 |
+
palette = palettes[i])
|
| 353 |
+
# switch off axes
|
| 354 |
+
for axes in wide_plot.axes.flat:
|
| 355 |
+
axes.set_axis_off()
|
| 356 |
+
sns.move_legend(wide_plot, "upper left", bbox_to_anchor=(1, 1))
|
| 357 |
+
wide_plot.fig.suptitle(plot_title, fontsize = 16)
|
| 358 |
+
wide_plot.fig.tight_layout()
|
| 359 |
+
wide_plot.fig.subplots_adjust(top = 0.85)
|
| 360 |
+
|
| 361 |
+
wide_plot_savefig = os.path.join(self.output_dir,
|
| 362 |
+
f"umap_wide__{embedding_key}_{label}.png")
|
| 363 |
+
wide_plot.savefig(wide_plot_savefig)
|
| 364 |
+
|
| 365 |
+
# if wandb initialized, log the figure
|
| 366 |
+
if self._wandb:
|
| 367 |
+
self._wandb.log({f"umap_wide__{embedding_key}_{label}": self._wandb.Image(wide_plot_savefig)})
|
| 368 |
+
if return_fig:
|
| 369 |
+
figs[label] = wide_plot
|
| 370 |
+
else:
|
| 371 |
+
msg = f"More than 10 labels for {label}. Skipping wide plot."
|
| 372 |
+
log.warning(msg)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
if "scanpy" in plot_type:
|
| 376 |
+
# scanpy plotting
|
| 377 |
+
labels_colors_flat = {k: v for d in labels_colors
|
| 378 |
+
for k, v in labels_colors[d].items()}
|
| 379 |
+
if embedding_key == raw_emb:
|
| 380 |
+
# TODO: this needs rewriting
|
| 381 |
+
adata_temp__ = self.data.adata.copy()
|
| 382 |
+
adata_temp__.obsm["X_umap"] = self.data.adata.obsm[raw_emb].copy()
|
| 383 |
+
fig2 = sc.pl.umap(adata_temp__,
|
| 384 |
+
color = label_cols,
|
| 385 |
+
add_outline = True,
|
| 386 |
+
layer = plt_emb,
|
| 387 |
+
legend_loc = 'on data',
|
| 388 |
+
palette = labels_colors_flat,
|
| 389 |
+
return_fig = True)
|
| 390 |
+
# remove the temporary adata
|
| 391 |
+
del adata_temp__
|
| 392 |
+
else:
|
| 393 |
+
fig2 = sc.pl.umap(self.data.adata,
|
| 394 |
+
color = label_cols,
|
| 395 |
+
add_outline = True,
|
| 396 |
+
layer = plt_emb,
|
| 397 |
+
legend_loc = 'on data',
|
| 398 |
+
palette = labels_colors_flat,
|
| 399 |
+
return_fig = True)
|
| 400 |
+
fig2.suptitle(plot_title, fontsize = 16)
|
| 401 |
+
fig2.tight_layout()
|
| 402 |
+
fig2.subplots_adjust(top = 0.85)
|
| 403 |
+
|
| 404 |
+
fig2_savefig = os.path.join(self.output_dir,
|
| 405 |
+
f"umap_scanpy__{embedding_key}.png")
|
| 406 |
+
fig2.savefig(fig2_savefig)
|
| 407 |
+
|
| 408 |
+
# if wandb initialized, log the figure
|
| 409 |
+
if self._wandb:
|
| 410 |
+
self._wandb.log({f"umap_scanpy/{embedding_key}": self._wandb.Image(fig2_savefig)})
|
| 411 |
+
|
| 412 |
+
if return_fig:
|
| 413 |
+
figs["umap_scanpy"] = fig2
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
if return_fig:
|
| 417 |
+
return figs
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/data.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Copyright (c) Microsoft Corporation.
|
| 2 |
+
## Licensed under the MIT license.
|
| 3 |
+
import os
|
| 4 |
+
import scanpy as sc
|
| 5 |
+
|
| 6 |
+
from typing import List, Optional, Union, Dict, Literal
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
# from scgpt.preprocess import Preprocessor
|
| 10 |
+
|
| 11 |
+
from .helpers.custom_logging import log
|
| 12 |
+
|
| 13 |
+
# switch of warnings
|
| 14 |
+
import warnings
|
| 15 |
+
os.environ["KMP_WARNINGS"] = "off"
|
| 16 |
+
warnings.filterwarnings('ignore')
|
| 17 |
+
|
| 18 |
+
class InputData():
|
| 19 |
+
def __init__(self,
|
| 20 |
+
adata_dataset_path: str) -> None:
|
| 21 |
+
|
| 22 |
+
# check if the dataset exists
|
| 23 |
+
if not os.path.isfile(adata_dataset_path):
|
| 24 |
+
msg = f"Dataset {adata_dataset_path} does not exist!"
|
| 25 |
+
log.error(msg)
|
| 26 |
+
raise ValueError(msg)
|
| 27 |
+
|
| 28 |
+
msg = f"Loading data from {adata_dataset_path}"
|
| 29 |
+
log.info(msg)
|
| 30 |
+
|
| 31 |
+
self.dataset_name = os.path.basename(adata_dataset_path).split(".")[0]
|
| 32 |
+
self.adata_path = adata_dataset_path
|
| 33 |
+
# read in the dataset
|
| 34 |
+
self.adata = sc.read(adata_dataset_path)
|
| 35 |
+
|
| 36 |
+
self.data_config = dict(
|
| 37 |
+
data_path = adata_dataset_path,
|
| 38 |
+
)
|
| 39 |
+
# this will be updated if add_batch_labels is called
|
| 40 |
+
self.batch_key = None
|
| 41 |
+
|
| 42 |
+
def add_batch_labels(self,
|
| 43 |
+
batch_key: Optional[str] = None,
|
| 44 |
+
batch_str_col: str = "str_batch",
|
| 45 |
+
batch_id_col: str = "batch_id") -> int:
|
| 46 |
+
|
| 47 |
+
self.batch_key = batch_key
|
| 48 |
+
self.batch_id_col = batch_id_col
|
| 49 |
+
self.batch_str_col = batch_str_col
|
| 50 |
+
|
| 51 |
+
if self.batch_key is None:
|
| 52 |
+
# try guessing which column contains batch info
|
| 53 |
+
# get the columns that contain "batch"
|
| 54 |
+
batch_cols = [col for col in
|
| 55 |
+
self.adata.obs.columns if "batch" in col.lower()]
|
| 56 |
+
if len(batch_cols) == 1:
|
| 57 |
+
ori_batch_col = batch_cols[0]
|
| 58 |
+
log.info(f"Using {ori_batch_col} as batch column")
|
| 59 |
+
else:
|
| 60 |
+
msg = "Cannot determine which column contains batch information!"
|
| 61 |
+
log.error(msg)
|
| 62 |
+
raise ValueError(msg)
|
| 63 |
+
else:
|
| 64 |
+
ori_batch_col = self.batch_key
|
| 65 |
+
log.info(f"Using {ori_batch_col} as batch column")
|
| 66 |
+
|
| 67 |
+
self.adata.obs[self.batch_str_col] = (
|
| 68 |
+
self
|
| 69 |
+
.adata
|
| 70 |
+
.obs[ori_batch_col]
|
| 71 |
+
.astype(str)
|
| 72 |
+
)
|
| 73 |
+
batch_id_labels = (
|
| 74 |
+
self.adata
|
| 75 |
+
.obs[self.batch_str_col]
|
| 76 |
+
.astype("category")
|
| 77 |
+
.cat
|
| 78 |
+
.codes
|
| 79 |
+
.values
|
| 80 |
+
)
|
| 81 |
+
self.adata.obs[self.batch_id_col] = batch_id_labels
|
| 82 |
+
log.debug(self.adata.obs[self.batch_id_col].value_counts())
|
| 83 |
+
num_batch_types = len(set(batch_id_labels))
|
| 84 |
+
log.debug(f"Number of batch types: {num_batch_types}")
|
| 85 |
+
return num_batch_types
|
| 86 |
+
|
| 87 |
+
def preprocess_data(self,
|
| 88 |
+
gene_col: str = "gene_name",
|
| 89 |
+
vocab_source: str = "model_default",
|
| 90 |
+
fract_matching: float = 0.5,
|
| 91 |
+
model_type: str = "scGPT",
|
| 92 |
+
# arguments for Geneformer preprocessing
|
| 93 |
+
gene_name_id_dict: Optional[Dict[str, str]] = None,
|
| 94 |
+
filter_gene_by_cells: Optional[int] = 10,
|
| 95 |
+
filter_cell_by_genes: Optional[int] = 10,
|
| 96 |
+
preprocessed_path: Optional[str] = None,
|
| 97 |
+
save_ext: Optional[str] = "loom",
|
| 98 |
+
# arguments for scGPT preprocessing
|
| 99 |
+
gene_vocab: Optional[List[str]] = None,
|
| 100 |
+
data_is_raw: Optional[bool] = True,
|
| 101 |
+
counts_layer: Optional[str] = "X",
|
| 102 |
+
filter_gene_by_counts: Optional[int] = 3,
|
| 103 |
+
filter_cell_by_counts: Optional[Union[int, bool]] = False,
|
| 104 |
+
n_hvg: Optional[Union[int, bool]] = 1200,
|
| 105 |
+
normalize_total: Optional[int] = 1e4,
|
| 106 |
+
n_bins: Optional[int] = 50,
|
| 107 |
+
**kwargs) -> None:
|
| 108 |
+
|
| 109 |
+
if gene_col not in self.adata.var.columns:
|
| 110 |
+
self.adata.var[gene_col] = self.adata.var.index.tolist()
|
| 111 |
+
log.warning(f"Gene names not found in var columns. Using index instead.")
|
| 112 |
+
|
| 113 |
+
self.gene_col = gene_col
|
| 114 |
+
self.data_config["gene_col"] = gene_col
|
| 115 |
+
|
| 116 |
+
# check if model_type is valid
|
| 117 |
+
model_type = model_type.lower()
|
| 118 |
+
valid_model_types = ["scgpt", "geneformer"]
|
| 119 |
+
|
| 120 |
+
if model_type not in valid_model_types:
|
| 121 |
+
msg = (f"Model type {model_type} not supported! "
|
| 122 |
+
f"Valid options are: {valid_model_types}.")
|
| 123 |
+
log.error(msg)
|
| 124 |
+
raise ValueError(msg)
|
| 125 |
+
|
| 126 |
+
self.data_config["model_type"] = model_type
|
| 127 |
+
self.data_config["vocab_source"] = vocab_source
|
| 128 |
+
|
| 129 |
+
# note raw data shape
|
| 130 |
+
self.data_config["input__n_cells"] = self.adata.shape[0]
|
| 131 |
+
self.data_config["input__n_genes"] = self.adata.shape[1]
|
| 132 |
+
|
| 133 |
+
# check if scgpt found in lowercase model string
|
| 134 |
+
if model_type == "scgpt":
|
| 135 |
+
|
| 136 |
+
self.data_config["data_is_raw"] = data_is_raw
|
| 137 |
+
self._preprocess_data_scGPT(gene_vocab = gene_vocab,
|
| 138 |
+
fract_matching = fract_matching,
|
| 139 |
+
input_key = counts_layer,
|
| 140 |
+
filter_gene_by_counts = filter_gene_by_counts,
|
| 141 |
+
filter_cell_by_counts = filter_cell_by_counts,
|
| 142 |
+
normalize_total = normalize_total,
|
| 143 |
+
n_hvg = n_hvg,
|
| 144 |
+
n_bins = n_bins,
|
| 145 |
+
preprocessed_path = preprocessed_path,
|
| 146 |
+
**kwargs)
|
| 147 |
+
|
| 148 |
+
elif model_type == "geneformer":
|
| 149 |
+
|
| 150 |
+
self._preprocess_data_geneformer(preprocessed_path = preprocessed_path,
|
| 151 |
+
save_ext = save_ext,
|
| 152 |
+
gene_name_id_dict = gene_name_id_dict,
|
| 153 |
+
fract_matching = fract_matching,
|
| 154 |
+
filter_cell_by_genes = filter_cell_by_genes,
|
| 155 |
+
filter_gene_by_cells = filter_gene_by_cells)
|
| 156 |
+
|
| 157 |
+
# note raw preprocessed shape
|
| 158 |
+
self.data_config["preprocessed__n_cells"] = self.adata.shape[0]
|
| 159 |
+
self.data_config["preprocessed__n_genes"] = self.adata.shape[1]
|
| 160 |
+
|
| 161 |
+
# def _preprocess_data_scGPT(self,
|
| 162 |
+
# gene_vocab: List[str],
|
| 163 |
+
# fract_matching: float = 0.5,
|
| 164 |
+
# input_key: str = "X",
|
| 165 |
+
# filter_gene_by_counts: int = 3,
|
| 166 |
+
# filter_cell_by_counts: Union[int, bool] = False,
|
| 167 |
+
# normalize_total: int = 1e4,
|
| 168 |
+
# n_hvg: Union[int, bool] = 1200,
|
| 169 |
+
# n_bins: int = 51,
|
| 170 |
+
# normed_key: str = "X_normed",
|
| 171 |
+
# log1p_key: str = "X_log1p",
|
| 172 |
+
# binned_key: str = "X_binned",
|
| 173 |
+
# preprocessed_path: Optional[str] = None) -> None:
|
| 174 |
+
|
| 175 |
+
# # preprocess the data
|
| 176 |
+
# self.adata.var["id_in_vocab"] = [
|
| 177 |
+
# 1 if gene in gene_vocab else -1
|
| 178 |
+
# for gene in self.adata.var[self.gene_col]
|
| 179 |
+
# ]
|
| 180 |
+
# gene_ids_in_vocab = np.array(self.adata.var["id_in_vocab"])
|
| 181 |
+
# fract = np.sum(gene_ids_in_vocab >= 0)/len(gene_ids_in_vocab)
|
| 182 |
+
|
| 183 |
+
# if fract < fract_matching:
|
| 184 |
+
# msg = f"Only {fract*100:.2f}% genes in the dataset are in the vocabulary!"
|
| 185 |
+
# log.error(msg)
|
| 186 |
+
# raise ValueError(msg)
|
| 187 |
+
|
| 188 |
+
# self.adata = self.adata[:, self.adata.var["id_in_vocab"] >= 0]
|
| 189 |
+
# self.data_config["fract_genes_in_vocab"] = fract
|
| 190 |
+
|
| 191 |
+
# log.info(
|
| 192 |
+
# f"Matched {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)}"
|
| 193 |
+
# f" genes in vocabulary of size {len(gene_vocab)}."
|
| 194 |
+
# )
|
| 195 |
+
|
| 196 |
+
# if n_hvg < 1:
|
| 197 |
+
# n_hvg = False
|
| 198 |
+
# # append preprocessing parameters to run config
|
| 199 |
+
# d_ = {
|
| 200 |
+
# "preprocesing__input_key": input_key,
|
| 201 |
+
# "preprocesing__filter_gene_by_counts": filter_gene_by_counts,
|
| 202 |
+
# "preprocesing__filter_cell_by_counts": filter_cell_by_counts,
|
| 203 |
+
# "preprocesing__normalize_total": normalize_total,
|
| 204 |
+
# "preprocesing__normed_key": normed_key,
|
| 205 |
+
# "preprocesing__log1p_key": log1p_key,
|
| 206 |
+
# "preprocesing__binned_key": binned_key,
|
| 207 |
+
# "preprocesing__n_bins": n_bins,
|
| 208 |
+
# "preprocesing__n_hvg": n_hvg,
|
| 209 |
+
# }
|
| 210 |
+
|
| 211 |
+
# self.data_config.update(d_)
|
| 212 |
+
|
| 213 |
+
# msg = "Preprocessing data"
|
| 214 |
+
# log.info(msg)
|
| 215 |
+
|
| 216 |
+
# # Preprocess the data following the scGPT data pre-processing pipeline
|
| 217 |
+
# preprocessor = Preprocessor(
|
| 218 |
+
# # the key in adata.layers to use as raw data
|
| 219 |
+
# use_key = input_key,
|
| 220 |
+
# # step 1
|
| 221 |
+
# filter_gene_by_counts = filter_gene_by_counts,
|
| 222 |
+
# # step 2
|
| 223 |
+
# filter_cell_by_counts = filter_cell_by_counts,
|
| 224 |
+
# # 3. whether to normalize the raw data and to what sum
|
| 225 |
+
# normalize_total = normalize_total,
|
| 226 |
+
# # the key in adata.layers to store the normalized data
|
| 227 |
+
# result_normed_key = normed_key,
|
| 228 |
+
# # 4. whether to log1p the normalized data
|
| 229 |
+
# log1p = self.data_config["data_is_raw"],
|
| 230 |
+
# result_log1p_key = log1p_key,
|
| 231 |
+
# # 5. whether to subset the raw data to highly variable genes
|
| 232 |
+
# subset_hvg = n_hvg,
|
| 233 |
+
# hvg_flavor = ("seurat_v3"
|
| 234 |
+
# if self.data_config["data_is_raw"]
|
| 235 |
+
# else "cell_ranger"),
|
| 236 |
+
# # 6. whether to bin the raw data and to what number of bins
|
| 237 |
+
# binning = n_bins,
|
| 238 |
+
# # the key in adata.layers to store the binned data
|
| 239 |
+
# result_binned_key = binned_key,
|
| 240 |
+
# )
|
| 241 |
+
|
| 242 |
+
# preprocessor(self.adata, batch_key = self.batch_key)
|
| 243 |
+
|
| 244 |
+
# if preprocessed_path is not None:
|
| 245 |
+
# # check if path exists
|
| 246 |
+
# if os.path.exists(preprocessed_path):
|
| 247 |
+
# msg = (f"Saving {self.dataset_name} preprocessed data "
|
| 248 |
+
# f"to {preprocessed_path}")
|
| 249 |
+
# self.adata.write(os.path.join(preprocessed_path,
|
| 250 |
+
# f"{self.dataset_name}.h5ad"))
|
| 251 |
+
# else:
|
| 252 |
+
# msg = (f"Directory {preprocessed_path} does not exist! "
|
| 253 |
+
# "Skipping saving preprocessed data.")
|
| 254 |
+
# log.warning(msg)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def _preprocess_data_geneformer(self,
                                preprocessed_path: str,
                                gene_name_id_dict: Dict[str, str],
                                save_ext: Literal["loom", "h5ad"] = "loom",
                                fract_matching: float = 0.5,
                                filter_cell_by_genes: int = 10,
                                filter_gene_by_cells: int = 10) -> None:
    """Preprocess ``self.adata`` for Geneformer and save the result to disk.

    Pipeline: compute QC metrics, filter low-quality cells/genes, map gene
    names (column ``self.gene_col``) to Ensembl ids via ``gene_name_id_dict``,
    drop genes without a match, record the original cell order in
    ``obs['adata_order']`` (cells get sorted later to speed up the forward
    pass), and write the dataset as loom or h5ad into ``preprocessed_path``.

    Args:
        preprocessed_path: Existing directory to save the output files into.
        gene_name_id_dict: Mapping gene name -> Ensembl id (Geneformer's
            dictionary).
        save_ext: Output format, "loom" (default) or "h5ad".
        fract_matching: Minimum fraction of genes that must map to an
            Ensembl id; below this threshold a ValueError is raised.
        filter_cell_by_genes: Minimum genes per cell (``sc.pp.filter_cells``).
        filter_gene_by_cells: Minimum cells per gene (``sc.pp.filter_genes``).

    Raises:
        ValueError: If ``preprocessed_path`` is None or does not exist, or
            if fewer than ``fract_matching`` of the genes match the
            vocabulary.

    NOTE(review): a ``save_ext`` value other than "loom"/"h5ad" silently
    saves nothing (no else branch) — confirm whether that should raise.
    """

    # for geneformer we need the path to save the data, check if exists
    if preprocessed_path is None or not os.path.exists(preprocessed_path):
        msg = ("For Geneformer, preprocessed_path needs to be specified "
               "and exists to save the dataset. Provided path: "
               f"{preprocessed_path}")
        log.error(msg)
        raise ValueError(msg)

    # QC metrics fill adata.obs['total_counts'], which Geneformer's
    # tokenizer expects under the name 'n_counts'
    sc.pp.calculate_qc_metrics(self.adata,
                               percent_top = None,
                               log1p = False,
                               inplace = True)
    self.adata.obs['n_counts'] = self.adata.obs['total_counts']
    sc.pp.filter_cells(self.adata, min_genes=int(filter_cell_by_genes))
    sc.pp.filter_genes(self.adata, min_cells=int(filter_gene_by_cells))

    # for now, assuming gene names and using geneformer dictionary
    # to match gene name to ensembl id; TODO: look into better way?
    # this is tricky because ensembl ids change, in a way
    # gene names are more constant; however they aren't necessarily unique
    # and might be missing from the geneformer dictionary/be different
    # for now, make sure to report the fraction of genes that are matched
    # and save the match/not matched
    self.adata.var['ensembl_id'] = self.adata.var[self.gene_col].map(gene_name_id_dict)
    self.adata.var['has_ensembl_match'] = self.adata.var['ensembl_id'].notnull()

    n_all_genes = self.adata.var.shape[0]
    n_matched = self.adata.var.has_ensembl_match.sum()
    fract = n_matched / n_all_genes

    # abort if the gene-name -> Ensembl-id match rate is too low to trust
    if fract < fract_matching:
        msg = f"Only {fract*100:.2f}% genes in the dataset are in the vocabulary!"
        log.error(msg)
        raise ValueError(msg)

    # save the adata.var dataframe (keeps the matched/unmatched record
    # before unmatched genes are dropped below)
    self.adata.var.to_csv(os.path.join(preprocessed_path,
                                       f"{self.dataset_name}_var.csv"),
                          index = False)

    # filter out genes that don't have a match
    self.adata = self.adata[:, self.adata.var.has_ensembl_match]

    # additionally, add the order of the samples, since they will be sorted
    # to speed up forward pass
    self.adata.obs['adata_order'] = self.adata.obs.index.tolist()

    self.data_config["fract_genes_in_vocab"] = fract

    log.info(
        f"Matched {fract*100:.2f}% genes ({n_matched}/{n_all_genes})"
        f" genes in vocabulary of size {len(gene_name_id_dict)}."
    )

    if save_ext == "loom":
        self.adata.write_loom(os.path.join(preprocessed_path,
                                           f"{self.dataset_name}.loom"))
    elif save_ext == "h5ad":
        self.adata.write_h5ad(os.path.join(preprocessed_path,
                                           f"{self.dataset_name}.h5ad"))
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def get_config(self):
    """Return the data-configuration dictionary tracked by this object."""
    config = self.data_config
    return config
|
| 330 |
+
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/geneformer_forward.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Copyright (c) Microsoft Corporation.
|
| 2 |
+
## Licensed under the MIT license.
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import importlib.util
|
| 6 |
+
import pickle
|
| 7 |
+
|
| 8 |
+
from typing import Dict, Optional, List
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from transformers import BertForMaskedLM
|
| 15 |
+
from geneformer.tokenizer import TranscriptomeTokenizer
|
| 16 |
+
|
| 17 |
+
# from geneformer import EmbExtractor
|
| 18 |
+
from tqdm.auto import trange
|
| 19 |
+
from datasets import Dataset, load_from_disk
|
| 20 |
+
from . import utils
|
| 21 |
+
from .data import InputData
|
| 22 |
+
from .helpers.custom_logging import log
|
| 23 |
+
|
| 24 |
+
from GF_CAB import CustomBertForMaskedLM
|
| 25 |
+
|
| 26 |
+
import warnings
|
| 27 |
+
os.environ["KMP_WARNINGS"] = "off"
|
| 28 |
+
warnings.filterwarnings("ignore")
|
| 29 |
+
|
| 30 |
+
def pad_tensor(t: torch.Tensor,
               max_size: int,
               pad_token_id: int = 0) -> torch.Tensor:
    """Right-pad a 1-D tensor with ``pad_token_id`` up to ``max_size`` elements.

    If ``max_size`` is smaller than ``t.numel()``, the pad width is negative
    and ``F.pad`` trims elements from the end instead.
    """
    n_missing = max_size - t.numel()
    return F.pad(t, (0, n_missing), mode='constant', value=pad_token_id)
|
| 39 |
+
|
| 40 |
+
# get cell embeddings excluding padding
def mean_nonpadding_embs(embs, original_lens):
    """Mean-pool token embeddings over the non-padded positions of each row.

    Args:
        embs: (batch, seq_len, dim) tensor of token embeddings.
        original_lens: (batch,) tensor of true (unpadded) sequence lengths;
            moved to ``embs``'s device for the comparison.

    Returns:
        (batch, dim) tensor of per-cell mean embeddings.
    """
    # mask based on padding lengths; built on embs.device rather than a
    # hard-coded "cuda" so this also works on CPU-only machines (and
    # matches the sibling helper average_embeddings)
    mask = (torch.arange(embs.size(1), device=embs.device).unsqueeze(0)
            < original_lens.to(embs.device).unsqueeze(1))

    # extend mask dimensions to match the embeddings tensor
    mask = mask.unsqueeze(2).expand_as(embs)

    # use the mask to zero out the embeddings in padded areas
    masked_embs = embs * mask.float()

    # sum and divide by the lengths to get the mean of non-padding embs
    mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
    return mean_embs
|
| 54 |
+
|
| 55 |
+
def average_embeddings(embs: torch.Tensor,
                       org_lengths: torch.Tensor) -> torch.Tensor:
    """Device-agnostic mean of each row's first ``org_lengths[i]`` positions.

    Positions at or beyond a row's original length are masked out before
    summing, so padding does not contribute to the average.
    """
    # boolean mask of valid (non-padded) positions, on embs' device
    positions = torch.arange(embs.size(1)).unsqueeze(0).to(embs.device)
    valid = positions < org_lengths.unsqueeze(1)

    # broadcast the mask over the embedding dimension when embs is 3-D
    if embs.dim() > 2:
        valid = valid.unsqueeze(2).expand_as(embs)

    # zero out padded positions, sum over the sequence axis,
    # then normalize by the true lengths
    totals = (embs * valid.float()).sum(dim=1)
    return totals / org_lengths.view(-1, 1).float()
|
| 75 |
+
|
| 76 |
+
class Geneformer_instance():
    """Wrapper around a (custom) Geneformer BERT model for zero-shot evaluation.

    Handles loading the pretrained model, tokenizing AnnData into a Hugging
    Face dataset, loading the token/gene-name vocabularies, and extracting
    per-cell embeddings and output rankings from a chosen hidden layer.
    """

    def __init__(self,
                 saved_model_path: Optional[str] = None,
                 model_run: str = "pretrained",
                 # NOTE(review): mutable default argument — shared across
                 # instances; only safe as long as it is never mutated.
                 model_files: Dict[str, str] = {
                     "model_args": "config.json",
                     "model_training": "training_args.bin",
                     "model_weights": "pytorch_model.bin"
                 },
                 save_dir: Optional[str] = None,
                 explicit_save_dir: bool = False,
                 num_workers: int = 0,
                 log_wandb: bool = False,
                 project_name: str = "Geneformer_eval",
                 ) -> None:
        """Validate the run configuration and set up output dir, device, wandb.

        Args:
            saved_model_path: Directory holding the pretrained model files.
            model_run: Only "pretrained" is currently supported.
            model_files: Expected filenames inside ``saved_model_path``.
            save_dir: Base output directory; used verbatim when
                ``explicit_save_dir`` is True, otherwise extended by run id.
            explicit_save_dir: Treat ``save_dir`` as the final output dir
                (and allow it to already exist).
            num_workers: -1 = all CPUs available to this process, 0 = 1.
            log_wandb: If True, require and store the wandb module.
            project_name: Name used for the evaluation project.

        Raises:
            ValueError: Unsupported ``model_run``, or the resolved output
                directory already exists without ``explicit_save_dir``.
            RuntimeError: ``log_wandb`` is True but wandb is not installed.
        """
        # check if the model run is supported
        supported_model_runs = ["pretrained"] #, "random", "finetune", "train"]
        if model_run not in supported_model_runs:
            msg = f"model_run must be one of {supported_model_runs}"
            log.error(msg)
            raise ValueError(msg)
        self.model_run = model_run

        self.saved_model_path = saved_model_path
        self.model_files = model_files

        # -1 means "use every CPU this process is allowed to run on"
        if num_workers == -1:
            num_workers = len(os.sched_getaffinity(0))

        if num_workers == 0:
            num_workers = 1

        self.num_workers = num_workers

        # check if output directory exists
        if save_dir is not None:
            if explicit_save_dir:
                self.output_dir = save_dir
            else:
                # NOTE(review): self.run_id is never assigned anywhere in
                # this class — both non-explicit branches would raise
                # AttributeError; confirm the intended run-id source.
                self.output_dir = os.path.join(save_dir,
                                               self.run_id)
                # if the top out directory does not exist, create it
                if not os.path.exists(save_dir):
                    log.warning(f"Creating the top output directory {save_dir}")
                    os.makedirs(save_dir)
        else:
            # save in a current path
            self.output_dir = os.path.join(os.getcwd(), self.run_id)

        # if the out directory already exists, raise an error
        if os.path.exists(self.output_dir) and not explicit_save_dir:
            msg = f"Output directory: {self.output_dir} exists. Something is wrong!"
            log.error(msg)
            raise ValueError(msg)

        os.makedirs(self.output_dir, exist_ok=True)

        # prefer GPU when available; all tensors/model are moved here later
        self.device = torch.device("cuda"
                                   if torch.cuda.is_available()
                                   else "cpu")

        log.info(f"Using device {self.device}")

        self.project_name = project_name
        if log_wandb:
            has_wandb = importlib.util.find_spec("wandb") is not None
            if not has_wandb:
                msg = "Wandb is not installed. Please install wandb to log to wandb."
                log.error(msg)
                raise RuntimeError(msg)
            if has_wandb:
                import wandb
                self._wandb = wandb
        else:
            self._wandb = None

        # update this when saved config so that when training it only is saved once
        self.config_saved = False

    def _check_attr(self,
                    attr: str,
                    not_none: bool = True) -> bool:
        """
        Check if the argument is in the class

        Returns True when ``self`` has attribute ``attr`` (and, if
        ``not_none``, the attribute is not None).
        """
        out = hasattr(self, attr)
        if not_none and out:
            out = getattr(self, attr) is not None
        return out

    def load_pretrained_model(self) -> None:
        """Load the masked-LM model from ``self.saved_model_path`` onto the device.

        Uses CustomBertForMaskedLM (project-local) with hidden states
        enabled, so embedding extraction can read intermediate layers.
        """
        # self.model = BertForMaskedLM.from_pretrained(self.saved_model_path,
        #                                              output_attentions=False,
        #                                              output_hidden_states=True)
        self.model = CustomBertForMaskedLM.from_pretrained(self.saved_model_path,
                                                           output_attentions=False,
                                                           output_hidden_states=True)

        self.model = self.model.to(self.device)
        log.info(f"Model successfully loaded from {self.saved_model_path}")


    def load_tokenized_dataset(self,
                               dataset_path: str) -> None:
        """Load a previously tokenized Hugging Face dataset from disk."""
        self.tokenized_dataset = load_from_disk(dataset_path)

    def tokenize_data(self,
                      adata_path: str,
                      dataset_path: str,
                      cell_type_col: str = "cell_type",
                      # NOTE(review): mutable default argument — not
                      # mutated here, but consider a None sentinel.
                      columns_to_keep: List[str] = ["adata_order"]):
        """Tokenize a loom/h5ad file with Geneformer's tokenizer, then load it.

        Args:
            adata_path: Path to the .loom or .h5ad input file; its basename
                (without extension) becomes the dataset name.
            dataset_path: Directory where the tokenized dataset is written
                (as ``<dataset_name>.dataset``) and then loaded from.
            cell_type_col: obs column to expose to the tokenizer as
                'cell_type'.
            columns_to_keep: Additional obs columns to carry through.

        Raises:
            ValueError: If the file extension is neither loom nor h5ad.
        """
        dataset_name = os.path.basename(adata_path).split(".")[0]

        # map source obs column names -> names the tokenizer should emit
        cols_to_keep = dict(zip([cell_type_col] + columns_to_keep,
                                ['cell_type'] + columns_to_keep))
        # initialize tokenizer
        self.tokenizer = TranscriptomeTokenizer(cols_to_keep,
                                                nproc = self.num_workers)

        # get the extension from adata_path
        _, ext = os.path.splitext(adata_path)
        ext = ext.strip(".")

        if ext not in ["loom", "h5ad"]:
            msg = f"adata_path must be a loom or h5ad file. Got {ext}"
            log.error(msg)
            raise ValueError(msg)

        if ext == "h5ad":
            msg = ("using h5ad file. This sometimes causes issues. "
                   "If not working try with loom.")
            log.warning(msg)

        # get the top directory of the adata_path
        # (the tokenizer scans a directory, not a single file)
        adata_dir = os.path.dirname(adata_path)

        self.tokenizer.tokenize_data(adata_dir,
                                     dataset_path,
                                     dataset_name,
                                     file_format=ext)


        # tokenizer does not return the dataset
        # load the dataset
        self.load_tokenized_dataset(os.path.join(dataset_path,
                                                 f"{dataset_name}.dataset"))


    def load_vocab(self,
                   dict_paths: str) -> None:
        """Load Geneformer's token and gene-name dictionaries from pickles.

        Sets ``self.vocab`` (token -> id), ``self.pad_token_id``,
        ``self.vocab_size`` and ``self.gene_name_id`` (gene name ->
        Ensembl id).
        """
        token_dictionary_path = os.path.join(dict_paths,
                                             "token_dictionary.pkl")
        with open(token_dictionary_path, "rb") as f:
            self.vocab = pickle.load(f)

        # NOTE(review): .get returns None if "<pad>" is absent — padding
        # would then fail later; confirm the dictionary always has it.
        self.pad_token_id = self.vocab.get("<pad>")

        # size of vocabulary
        self.vocab_size = len(self.vocab)

        gene_name_id_path = os.path.join(dict_paths,
                                         "gene_name_id_dict.pkl")
        with open(gene_name_id_path, "rb") as f:
            self.gene_name_id = pickle.load(f)


    def _extend_batch(self,
                      batch_dataset: Dataset,
                      return_attention_mask: bool = True):
        """Pad a batch's input_ids to equal length and stack onto the device.

        Returns the (batch, max_len) id tensor, plus a 1/0 attention mask
        of the same shape when ``return_attention_mask`` is True.
        """
        max_size = max(batch_dataset['length'])

        batch_ = [pad_tensor(x, max_size, self.pad_token_id)
                  for x in batch_dataset['input_ids']]

        batch_ = torch.stack(batch_).to(self.device)

        if return_attention_mask:
            # 1 over real tokens, 0 over padding
            mask_ = [[1] * l + [0] * (max_size - l)
                     for l in batch_dataset['length']]
            mask_ = torch.tensor(mask_).to(self.device)
            return batch_, mask_

        return batch_

    def _pass_batch(self,
                    batch_ids: torch.Tensor,
                    attention_mask: torch.Tensor,
                    **kwargs) -> torch.Tensor:
        """Run one no-grad forward pass and return the raw model output."""
        # make sure that batch and attn_mask on the same device
        batch_ids = batch_ids.to(self.device)
        attn_mask = attention_mask.to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids = batch_ids,
                                 attention_mask = attn_mask,
                                 **kwargs)

        return outputs


    def extract_embeddings(self,
                           data: InputData,
                           batch_size: int = 48,
                           embedding_key: str = "geneformer",
                           layer: int = -2):
        """Extract per-cell embeddings and output rankings from the model.

        Iterates the tokenized dataset in batches, mean-pools the hidden
        states of ``layer`` over non-padded tokens, and stores:
        ``self.cell_embeddings`` (numpy array), ``self.output_rankings``
        (argmax of the logits per position), ``self.input_rankings``
        (the original input id rankings), plus writes the embeddings into
        ``data.adata.obsm[embedding_key]`` and saves ``adata.obs`` next to
        the outputs.

        Args:
            data: InputData wrapper whose ``adata`` receives the embeddings.
            batch_size: Cells per forward pass.
            layer: Hidden-state index (negative counts from the end).

        Raises:
            RuntimeError: If no tokenized dataset has been loaded.
            ValueError: If ``layer`` is outside the model's layer range.
        """
        # check if tokenized dataset is loaded
        if not self._check_attr("tokenized_dataset"):
            msg = "Tokenized dataset not loaded. Please load the tokenized dataset."
            log.error(msg)
            raise RuntimeError(msg)

        # check if layer is valid
        n_layers = self.model.config.num_hidden_layers
        if layer >= n_layers or layer < -n_layers:
            msg = (f"Layer {layer} is not valid. There are only {n_layers} "
                   f"Acceptable values are between {-n_layers} (if counting "
                   f"forwards) and {n_layers - 1} (if counting backwards)")
            log.error(msg)
            raise ValueError(msg)

        # save the embeddings to subdir
        embeddings_subdir = os.path.join(self.output_dir, "model_outputs")
        os.makedirs(embeddings_subdir, exist_ok=True)

        cell_embs_list = []
        rankings_list = []

        size = len(self.tokenized_dataset)

        for i in trange(0, size, batch_size,
                        desc = "Geneformer (extracting embeddings)"):

            max_range = min(i+batch_size, size)
            batch_dataset = self.tokenized_dataset.select(list(range(i, max_range)))
            batch_dataset.set_format(type = 'torch')

            org_lengths = torch.tensor(batch_dataset['length']).to(self.device)

            batch, attn_mask = self._extend_batch(batch_dataset)

            model_output = self._pass_batch(batch,
                                            attention_mask = attn_mask)

            embs = model_output.hidden_states[layer]

            # cell_embs = average_embeddings(embs, org_lengths)
            # NOTE(review): mean_nonpadding_embs builds its mask with a
            # hard-coded "cuda" — this call path requires a GPU as written.
            cell_embs = mean_nonpadding_embs(embs, org_lengths)

            # add cell embeddings to the list
            cell_embs_list.extend(cell_embs.detach().cpu().numpy())

            # now, get the ranking reconstruction
            out_rankings = (model_output.logits
                            .argmax(axis=-1)
                            .detach().cpu().numpy())

            # save the rankings with the original order
            rankings_list.extend(out_rankings)

            # free GPU memory between batches
            torch.cuda.empty_cache()
            del model_output
            del batch
            del attn_mask
            del embs
            del cell_embs

        self.cell_embeddings = np.array(cell_embs_list)

        self.output_rankings = rankings_list
        self.input_rankings = [np.array(item)
                               for item
                               in self.tokenized_dataset['input_ids']]

        # add embeddings to adata
        data.adata.obsm[embedding_key] = self.cell_embeddings

        # for plotting later, save the data.adata.obs
        # order here agrees with the order of the embeddings
        data.adata.obs.to_csv(os.path.join(embeddings_subdir,
                                           "adata_obs.csv"))
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__init__.py
ADDED
|
File without changes
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (185 Bytes). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (201 Bytes). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/custom_logging.cpython-310.pyc
ADDED
|
Binary file (625 Bytes). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/custom_logging.cpython-311.pyc
ADDED
|
Binary file (1.04 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/umap.cpython-310.pyc
ADDED
|
Binary file (3.89 kB). View file
|
|
|
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/umap.cpython-311.pyc
ADDED
|
Binary file (6.04 kB). View file
|
|
|