Christina Theodoris committed · Commit 4bddd45 · Parent: 5a43832

add option for hyperparameter tuning to cc.validate

Files changed:
- examples/cell_classification.ipynb +3 -2
- examples/hyperparam_optimiz_for_disease_classifier.py +0 -226
- geneformer/classifier.py +235 -23
- geneformer/classifier_utils.py +21 -2
- requirements.txt +1 -0
examples/cell_classification.ipynb CHANGED

@@ -13,7 +13,7 @@
    "id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
    "metadata": {},
    "source": [
-    "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but
+    "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization."
    ]
   },
   {
@@ -266,7 +266,8 @@
    "                          id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
    "                          output_directory=output_dir,\n",
    "                          output_prefix=output_prefix,\n",
-    "                          split_id_dict=train_valid_id_split_dict)"
+    "                          split_id_dict=train_valid_id_split_dict)\n",
+    "                          # to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)"
    ]
   },
   {
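For reference, a minimal sketch of invoking the new option from the notebook's cc.validate call shown above. This is a sketch only: the first two argument names and the paths are assumptions based on the surrounding notebook context, not part of this diff.

    # Sketch: assumes cc, output_dir, output_prefix, and train_valid_id_split_dict
    # are defined as in the earlier notebook cells.
    all_metrics = cc.validate(
        model_directory="/path/to/pretrained_model/",  # placeholder path (assumption)
        prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",  # assumed filename
        id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
        output_directory=output_dir,
        output_prefix=output_prefix,
        split_id_dict=train_valid_id_split_dict,
        n_hyperopt_trials=100,  # >0 runs Ray Tune optimization; 0 trains with fixed hyperparameters
    )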
examples/hyperparam_optimiz_for_disease_classifier.py DELETED

@@ -1,226 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# hyperparameter optimization with raytune for disease classification
-
-# imports
-import os
-import subprocess
-GPU_NUMBER = [0,1,2,3]
-os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
-os.environ["NCCL_DEBUG"] = "INFO"
-os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
-os.environ["LD_LIBRARY_PATH"] = "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"
-
-# initiate runtime environment for raytune
-import pyarrow  # must occur prior to ray import
-import ray
-from ray import tune
-from ray.tune import ExperimentAnalysis
-from ray.tune.suggest.hyperopt import HyperOptSearch
-ray.shutdown()  # engage new ray session
-runtime_env = {"conda": "base",
-               "env_vars": {"LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"}}
-ray.init(runtime_env=runtime_env)
-
-def initialize_ray_with_check(ip_address):
-    """
-    Initialize Ray with a specified IP address and check its status and accessibility.
-
-    Args:
-    - ip_address (str): The IP address (with port) to initialize Ray.
-
-    Returns:
-    - bool: True if initialization was successful and dashboard is accessible, False otherwise.
-    """
-    try:
-        ray.init(address=ip_address)
-        print(ray.nodes())
-
-        services = ray.get_webui_url()
-        if not services:
-            raise RuntimeError("Ray dashboard is not accessible.")
-        else:
-            print(f"Ray dashboard is accessible at: {services}")
-            return True
-    except Exception as e:
-        print(f"Error initializing Ray: {e}")
-        return False
-
-# Usage:
-ip = 'your_ip:xxxx'  # Replace with your actual IP address and port
-if initialize_ray_with_check(ip):
-    print("Ray initialized successfully.")
-else:
-    print("Error during Ray initialization.")
-
-import datetime
-import numpy as np
-import pandas as pd
-import random
-import seaborn as sns; sns.set()
-from collections import Counter
-from datasets import load_from_disk
-from scipy.stats import ranksums
-from sklearn.metrics import accuracy_score
-from transformers import BertForSequenceClassification
-from transformers import Trainer
-from transformers.training_args import TrainingArguments
-
-from geneformer import DataCollatorForCellClassification
-
-# number of CPU cores
-num_proc=30
-
-# load train dataset with columns:
-# cell_type (annotation of each cell's type)
-# disease (healthy or disease state)
-# individual (unique ID for each patient)
-# length (length of that cell's rank value encoding)
-train_dataset=load_from_disk("/path/to/disease_train_data.dataset")
-
-# filter dataset for given cell_type
-def if_cell_type(example):
-    return example["cell_type"].startswith("Cardiomyocyte")
-
-trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)
-
-# create dictionary of disease states : label ids
-target_names = ["healthy", "disease1", "disease2"]
-target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
-
-trainset_v3 = trainset_v2.rename_column("disease","label")
-
-# change labels to numerical ids
-def classes_to_ids(example):
-    example["label"] = target_name_id_dict[example["label"]]
-    return example
-
-trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)
-
-# separate into train, validation, test sets
-indiv_set = set(trainset_v4["individual"])
-random.seed(42)
-train_indiv = random.sample(indiv_set,round(0.7*len(indiv_set)))
-eval_indiv = [indiv for indiv in indiv_set if indiv not in train_indiv]
-valid_indiv = random.sample(eval_indiv,round(0.5*len(eval_indiv)))
-test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv]
-
-def if_train(example):
-    return example["individual"] in train_indiv
-
-classifier_trainset = trainset_v4.filter(if_train,num_proc=num_proc).shuffle(seed=42)
-
-def if_valid(example):
-    return example["individual"] in valid_indiv
-
-classifier_validset = trainset_v4.filter(if_valid,num_proc=num_proc).shuffle(seed=42)
-
-# define output directory path
-current_date = datetime.datetime.now()
-datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
-output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"
-
-# ensure not overwriting previously saved model
-saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
-if os.path.isfile(saved_model_test) == True:
-    raise Exception("Model already saved to this directory.")
-
-# make output directory
-subprocess.call(f'mkdir {output_dir}', shell=True)
-
-# set training parameters
-# how many pretrained layers to freeze
-freeze_layers = 2
-# batch size for training and eval
-geneformer_batch_size = 12
-# number of epochs
-epochs = 1
-# logging steps
-logging_steps = round(len(classifier_trainset)/geneformer_batch_size/10)
-
-# define function to initiate model
-def model_init():
-    model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/",
-                                                          num_labels=len(target_names),
-                                                          output_attentions = False,
-                                                          output_hidden_states = False)
-    if freeze_layers is not None:
-        modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
-        for module in modules_to_freeze:
-            for param in module.parameters():
-                param.requires_grad = False
-
-    model = model.to("cuda:0")
-    return model
-
-# define metrics
-# note: macro f1 score recommended for imbalanced multiclass classifiers
-def compute_metrics(pred):
-    labels = pred.label_ids
-    preds = pred.predictions.argmax(-1)
-    # calculate accuracy using sklearn's function
-    acc = accuracy_score(labels, preds)
-    return {
-        'accuracy': acc,
-    }
-
-# set training arguments
-training_args = {
-    "do_train": True,
-    "do_eval": True,
-    "evaluation_strategy": "steps",
-    "eval_steps": logging_steps,
-    "logging_steps": logging_steps,
-    "group_by_length": True,
-    "length_column_name": "length",
-    "disable_tqdm": True,
-    "skip_memory_metrics": True,  # memory tracker causes errors in raytune
-    "per_device_train_batch_size": geneformer_batch_size,
-    "per_device_eval_batch_size": geneformer_batch_size,
-    "num_train_epochs": epochs,
-    "load_best_model_at_end": True,
-    "output_dir": output_dir,
-}
-
-training_args_init = TrainingArguments(**training_args)
-
-# create the trainer
-trainer = Trainer(
-    model_init=model_init,
-    args=training_args_init,
-    data_collator=DataCollatorForCellClassification(),
-    train_dataset=classifier_trainset,
-    eval_dataset=classifier_validset,
-    compute_metrics=compute_metrics,
-)
-
-# specify raytune hyperparameter search space
-ray_config = {
-    "num_train_epochs": tune.choice([epochs]),
-    "learning_rate": tune.loguniform(1e-6, 1e-3),
-    "weight_decay": tune.uniform(0.0, 0.3),
-    "lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
-    "warmup_steps": tune.uniform(100, 2000),
-    "seed": tune.uniform(0,100),
-    "per_device_train_batch_size": tune.choice([geneformer_batch_size])
-}
-
-hyperopt_search = HyperOptSearch(
-    metric="eval_accuracy", mode="max")
-
-# optimize hyperparameters
-trainer.hyperparameter_search(
-    direction="maximize",
-    backend="ray",
-    resources_per_trial={"cpu":8,"gpu":1},
-    hp_space=lambda _: ray_config,
-    search_alg=hyperopt_search,
-    n_trials=100,  # number of trials
-    progress_reporter=tune.CLIReporter(max_report_frequency=600,
-                                       sort_by_metric=True,
-                                       max_progress_rows=100,
-                                       mode="max",
-                                       metric="eval_accuracy",
-                                       metric_columns=["loss", "eval_loss", "eval_accuracy"])
-)
geneformer/classifier.py CHANGED

@@ -82,11 +82,12 @@ class Classifier:
         "training_args": {None, dict},
         "freeze_layers": {int},
         "num_crossval_splits": {0, 1, 5},
-        "
+        "split_sizes": {None, dict},
         "no_eval": {bool},
         "stratify_splits_col": {None, str},
         "forward_batch_size": {int},
         "nproc": {int},
+        "ngpu": {int},
     }

     def __init__(
@@ -99,13 +100,15 @@ class Classifier:
         max_ncells=None,
         max_ncells_per_class=None,
         training_args=None,
+        ray_config=None,
         freeze_layers=0,
         num_crossval_splits=1,
-
+        split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1},
         stratify_splits_col=None,
         no_eval=False,
         forward_batch_size=100,
         nproc=4,
+        ngpu=1,
     ):
         """
         Initialize Geneformer classifier.
@@ -152,15 +155,18 @@ class Classifier:
             | Otherwise, will use the Hugging Face defaults:
             | https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
             | Note: Hyperparameter tuning is highly recommended, rather than using defaults.
+        ray_config : None, dict
+            | Training argument ranges for tuning hyperparameters with Ray.
         freeze_layers : int
             | Number of layers to freeze from fine-tuning.
             | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
         num_crossval_splits : {0, 1, 5}
             | 0: train on all data without splitting
-            | 1: split data into train and eval sets by designated
-            | 5: split data into 5 folds of train and eval sets by designated
-
-
+            | 1: split data into train and eval sets by designated split_sizes["valid"]
+            | 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"]
+        split_sizes : None, dict
+            | Dictionary of proportion of data to hold out for train, validation, and test sets
+            | {"train": 0.8, "valid": 0.1, "test": 0.1} if intending 80/10/10 train/valid/test split
         stratify_splits_col : None, str
             | Name of column in .dataset to be used for stratified splitting.
             | Proportion of each class in this column will be the same in the splits as in the original dataset.
@@ -171,6 +177,8 @@ class Classifier:
             | Batch size for forward pass (for evaluation, not training).
         nproc : int
             | Number of CPU processes to use.
+        ngpu : int
+            | Number of GPUs available.

         """

@@ -182,13 +190,19 @@ class Classifier:
         self.max_ncells = max_ncells
         self.max_ncells_per_class = max_ncells_per_class
         self.training_args = training_args
+        self.ray_config = ray_config
         self.freeze_layers = freeze_layers
         self.num_crossval_splits = num_crossval_splits
-        self.
+        self.split_sizes = split_sizes
+        self.train_size = self.split_sizes["train"]
+        self.valid_size = self.split_sizes["valid"]
+        self.oos_test_size = self.split_sizes["test"]
+        self.eval_size = self.valid_size / (self.train_size + self.valid_size)
         self.stratify_splits_col = stratify_splits_col
         self.no_eval = no_eval
         self.forward_batch_size = forward_batch_size
         self.nproc = nproc
+        self.ngpu = ngpu

         if self.training_args is None:
             logger.warning(
@@ -301,6 +315,9 @@ class Classifier:
                 "Gene_class_dict should contain at least 2 gene classes to classify."
             )
             raise
+        if sum(self.split_sizes.values()) != 1:
+            logger.error("Train, validation, and test proportions should sum to 1.")
+            raise

     def prepare_data(
         self,
@@ -337,6 +354,7 @@ class Classifier:
         test_size : None, float
             | Proportion of data to be saved separately and held out for test set
             | (e.g. 0.2 if intending hold out 20%)
+            | If None, will inherit from split_sizes["test"] from Classifier
            | The training set will be further split to train / validation in self.validate
             | Note: only available for CellClassifiers
         attr_to_split : None, str
@@ -356,6 +374,9 @@ class Classifier:
             | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
         """

+        if test_size is None:
+            test_size = self.oos_test_size
+
         # prepare data and labels for classification
         data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)

@@ -555,6 +576,7 @@ class Classifier:
         save_eval_output=True,
         predict_eval=True,
         predict_trainer=False,
+        n_hyperopt_trials=0,
     ):
         """
         (Cross-)validate cell state or gene classifier.
@@ -604,6 +626,9 @@ class Classifier:
         predict_trainer : bool
             | Whether or not to save eval predictions from trainer
             | Saves as a pickle file of trainer predictions
+        n_hyperopt_trials : int
+            | Number of trials to run for hyperparameter optimization
+            | If 0, will not optimize hyperparameters
         """

         if self.num_crossval_splits == 0:
@@ -700,14 +725,30 @@ class Classifier:
                 ]
                 eval_data = data.select(eval_indices)
                 train_data = data.select(train_indices)
-                trainer = self.train_classifier(
-                    model_directory,
-                    num_classes,
-                    train_data,
-                    eval_data,
-                    ksplit_output_dir,
-                    predict_trainer,
-                )
+                if n_hyperopt_trials == 0:
+                    trainer = self.train_classifier(
+                        model_directory,
+                        num_classes,
+                        train_data,
+                        eval_data,
+                        ksplit_output_dir,
+                        predict_trainer,
+                    )
+                else:
+                    trainer = self.hyperopt_classifier(
+                        model_directory,
+                        num_classes,
+                        train_data,
+                        eval_data,
+                        ksplit_output_dir,
+                        n_trials=n_hyperopt_trials,
+                    )
+                    if iteration_num == self.num_crossval_splits:
+                        return
+                    else:
+                        iteration_num = iteration_num + 1
+                        continue
+
                 result = self.evaluate_model(
                     trainer.model,
                     num_classes,
@@ -752,14 +793,29 @@ class Classifier:
                     self.nproc,
                 )

-                trainer = self.train_classifier(
-                    model_directory,
-                    num_classes,
-                    train_data,
-                    eval_data,
-                    ksplit_output_dir,
-                    predict_trainer,
-                )
+                if n_hyperopt_trials == 0:
+                    trainer = self.train_classifier(
+                        model_directory,
+                        num_classes,
+                        train_data,
+                        eval_data,
+                        ksplit_output_dir,
+                        predict_trainer,
+                    )
+                else:
+                    trainer = self.hyperopt_classifier(
+                        model_directory,
+                        num_classes,
+                        train_data,
+                        eval_data,
+                        ksplit_output_dir,
+                        n_trials=n_hyperopt_trials,
+                    )
+                    if iteration_num == self.num_crossval_splits:
+                        return
+                    else:
+                        iteration_num = iteration_num + 1
+                        continue
                 result = self.evaluate_model(
                     trainer.model,
                     num_classes,
@@ -810,6 +866,162 @@ class Classifier:

         return all_metrics

+    def hyperopt_classifier(
+        self,
+        model_directory,
+        num_classes,
+        train_data,
+        eval_data,
+        output_directory,
+        n_trials=100,
+    ):
+        """
+        Fine-tune model for cell state or gene classification.
+
+        **Parameters**
+
+        model_directory : Path
+            | Path to directory containing model
+        num_classes : int
+            | Number of classes for classifier
+        train_data : Dataset
+            | Loaded training .dataset input
+            | For cell classifier, labels in column "label".
+            | For gene classifier, labels in column "labels".
+        eval_data : None, Dataset
+            | (Optional) Loaded evaluation .dataset input
+            | For cell classifier, labels in column "label".
+            | For gene classifier, labels in column "labels".
+        output_directory : Path
+            | Path to directory where fine-tuned model will be saved
+        n_trials : int
+            | Number of trials to run for hyperparameter optimization
+        """
+
+        # initiate runtime environment for raytune
+        import ray
+        from ray import tune
+        from ray.tune.search.hyperopt import HyperOptSearch
+
+        ray.shutdown()  # engage new ray session
+        ray.init()
+
+        ##### Validate and prepare data #####
+        train_data, eval_data = cu.validate_and_clean_cols(
+            train_data, eval_data, self.classifier
+        )
+
+        if (self.no_eval is True) and (eval_data is not None):
+            logger.warning(
+                "no_eval set to True; hyperparameter optimization requires eval, proceeding with eval"
+            )
+
+        # ensure not overwriting previously saved model
+        saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
+        if os.path.isfile(saved_model_test) is True:
+            logger.error("Model already saved to this designated output directory.")
+            raise
+        # make output directory
+        subprocess.call(f"mkdir {output_directory}", shell=True)
+
+        ##### Load model and training args #####
+        if self.classifier == "cell":
+            model_type = "CellClassifier"
+        elif self.classifier == "gene":
+            model_type = "GeneClassifier"
+
+        model = pu.load_model(model_type, num_classes, model_directory, "train")
+        def_training_args, def_freeze_layers = cu.get_default_train_args(
+            model, self.classifier, train_data, output_directory
+        )
+        del model
+
+        if self.training_args is not None:
+            def_training_args.update(self.training_args)
+        logging_steps = round(
+            len(train_data) / def_training_args["per_device_train_batch_size"] / 10
+        )
+        def_training_args["logging_steps"] = logging_steps
+        def_training_args["output_dir"] = output_directory
+        if eval_data is None:
+            def_training_args["evaluation_strategy"] = "no"
+            def_training_args["load_best_model_at_end"] = False
+        training_args_init = TrainingArguments(**def_training_args)
+
+        ##### Fine-tune the model #####
+        # define the data collator
+        if self.classifier == "cell":
+            data_collator = DataCollatorForCellClassification()
+        elif self.classifier == "gene":
+            data_collator = DataCollatorForGeneClassification()
+
+        # define function to initiate model
+        def model_init():
+            model = pu.load_model(model_type, num_classes, model_directory, "train")
+
+            if self.freeze_layers is not None:
+                def_freeze_layers = self.freeze_layers
+
+            if def_freeze_layers > 0:
+                modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
+                for module in modules_to_freeze:
+                    for param in module.parameters():
+                        param.requires_grad = False
+
+            model = model.to("cuda:0")
+            return model
+
+        # create the trainer
+        trainer = Trainer(
+            model_init=model_init,
+            args=training_args_init,
+            data_collator=data_collator,
+            train_dataset=train_data,
+            eval_dataset=eval_data,
+            compute_metrics=cu.compute_metrics,
+        )
+
+        # specify raytune hyperparameter search space
+        if self.ray_config is None:
+            logger.warning(
+                "No ray_config provided. Proceeding with default, but ranges may need adjustment depending on model."
+            )
+            def_ray_config = {
+                "num_train_epochs": tune.choice([1]),
+                "learning_rate": tune.loguniform(1e-6, 1e-3),
+                "weight_decay": tune.uniform(0.0, 0.3),
+                "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
+                "warmup_steps": tune.uniform(100, 2000),
+                "seed": tune.uniform(0, 100),
+                "per_device_train_batch_size": tune.choice(
+                    [def_training_args["per_device_train_batch_size"]]
+                ),
+            }
+
+        hyperopt_search = HyperOptSearch(metric="eval_macro_f1", mode="max")
+
+        # optimize hyperparameters
+        trainer.hyperparameter_search(
+            direction="maximize",
+            backend="ray",
+            resources_per_trial={"cpu": int(self.nproc / self.ngpu), "gpu": 1},
+            hp_space=lambda _: def_ray_config
+            if self.ray_config is None
+            else self.ray_config,
+            search_alg=hyperopt_search,
+            n_trials=n_trials,  # number of trials
+            progress_reporter=tune.CLIReporter(
+                max_report_frequency=600,
+                sort_by_metric=True,
+                max_progress_rows=n_trials,
+                mode="max",
+                metric="eval_macro_f1",
+                metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
+            ),
+        )
+
+        return trainer
+
     def train_classifier(
         self,
         model_directory,
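As a usage sketch (not part of this commit), the new constructor arguments work together as follows. The search-space ranges below simply mirror the def_ray_config default above; the other constructor arguments are illustrative assumptions, and Classifier is assumed to be exported at the package level.

    from ray import tune
    from geneformer import Classifier  # assumed package-level export

    # Custom search space passed via the new ray_config argument; keys mirror def_ray_config.
    custom_ray_config = {
        "num_train_epochs": tune.choice([1]),
        "learning_rate": tune.loguniform(1e-6, 1e-3),
        "weight_decay": tune.uniform(0.0, 0.3),
        "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
        "warmup_steps": tune.uniform(100, 2000),
        "seed": tune.uniform(0, 100),
        "per_device_train_batch_size": tune.choice([12]),
    }

    cc = Classifier(
        classifier="cell",            # illustrative; remaining required args omitted
        ray_config=custom_ray_config,
        split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1},  # checked to sum to 1
        nproc=16,
        ngpu=2,                       # each Ray trial gets int(nproc / ngpu) CPUs and 1 GPU
    )
    # cc.validate(..., n_hyperopt_trials=100) then searches these ranges with HyperOptSearch,
    # maximizing eval_macro_f1.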
geneformer/classifier_utils.py CHANGED

@@ -360,9 +360,23 @@ def get_num_classes(id_class_dict):
 def compute_metrics(pred):
     labels = pred.label_ids
     preds = pred.predictions.argmax(-1)
+
     # calculate accuracy and macro f1 using sklearn's function
-    acc = accuracy_score(labels, preds)
-    macro_f1 = f1_score(labels, preds, average="macro")
+    if len(labels.shape) == 1:
+        acc = accuracy_score(labels, preds)
+        macro_f1 = f1_score(labels, preds, average="macro")
+    else:
+        flat_labels = labels.flatten().tolist()
+        flat_preds = preds.flatten().tolist()
+        logit_label_paired = [
+            item for item in list(zip(flat_preds, flat_labels)) if item[1] != -100
+        ]
+        y_pred = [item[0] for item in logit_label_paired]
+        y_true = [item[1] for item in logit_label_paired]
+
+        acc = accuracy_score(y_true, y_pred)
+        macro_f1 = f1_score(y_true, y_pred, average="macro")
+
     return {"accuracy": acc, "macro_f1": macro_f1}


@@ -387,6 +401,11 @@ def get_default_train_args(model, classifier, data, output_dir):
             "per_device_train_batch_size": batch_size,
             "per_device_eval_batch_size": batch_size,
         }
+    else:
+        default_training_args = {
+            "per_device_train_batch_size": batch_size,
+            "per_device_eval_batch_size": batch_size,
+        }

     training_args = {
         "num_train_epochs": epochs,
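To illustrate the new masking path in compute_metrics, here is a self-contained sketch; SimpleNamespace stands in for the EvalPrediction object the Hugging Face Trainer passes.

    import numpy as np
    from types import SimpleNamespace
    from sklearn.metrics import accuracy_score, f1_score

    # gene-classifier case: 2D token-level labels, where -100 marks padded positions
    logits = np.array([[[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]]])  # (batch, seq_len, n_classes)
    labels = np.array([[0, 1, -100]])                          # last token is padding
    pred = SimpleNamespace(predictions=logits, label_ids=labels)

    preds = pred.predictions.argmax(-1)
    paired = [x for x in zip(preds.flatten().tolist(), labels.flatten().tolist()) if x[1] != -100]
    y_pred = [x[0] for x in paired]
    y_true = [x[1] for x in paired]
    print(accuracy_score(y_true, y_pred))             # 1.0 (padding excluded)
    print(f1_score(y_true, y_pred, average="macro"))  # 1.0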
requirements.txt CHANGED

@@ -1,5 +1,6 @@
 anndata>=0.9
 datasets>=2.12
+hyperopt>=0.2
 loompy>=3.0
 matplotlib>=3.7
 numpy>=1.23