Update with gene classifier, custom token dict, and str validate options
#329
by hchen725
- opened
- geneformer/classifier.py +72 -38
geneformer/classifier.py
CHANGED
@@ -53,7 +53,6 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
 import seaborn as sns
-from sklearn.model_selection import StratifiedKFold
 from tqdm.auto import tqdm, trange
 from transformers import Trainer
 from transformers.training_args import TrainingArguments

@@ -86,6 +85,7 @@ class Classifier:
         "no_eval": {bool},
         "stratify_splits_col": {None, str},
         "forward_batch_size": {int},
+        "token_dictionary_file": {None, str},
         "nproc": {int},
         "ngpu": {int},
     }

@@ -107,6 +107,7 @@ class Classifier:
         stratify_splits_col=None,
         no_eval=False,
         forward_batch_size=100,
+        token_dictionary_file=None,
         nproc=4,
         ngpu=1,
     ):

@@ -175,6 +176,9 @@ class Classifier:
             | Otherwise, will perform eval during training.
         forward_batch_size : int
             | Batch size for forward pass (for evaluation, not training).
+        token_dictionary_file : None, str
+            | Default is to use token dictionary file from Geneformer
+            | Otherwise, will load custom gene token dictionary.
         nproc : int
             | Number of CPU processes to use.
         ngpu : int
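
With token_dictionary_file exposed as a constructor option, a gene classifier can be pointed at a model trained on a non-default vocabulary. A minimal usage sketch; the import path, gene classes, file paths, and the remaining keyword values are illustrative rather than taken from this PR:

    from geneformer import Classifier

    # Illustrative values only; the gene classes and paths are placeholders.
    cc = Classifier(
        classifier="gene",
        gene_class_dict={
            "dosage_sensitive": ["ENSG00000187634"],
            "dosage_insensitive": ["ENSG00000188976"],
        },
        max_ncells=10_000,
        freeze_layers=2,
        num_crossval_splits=1,
        forward_batch_size=100,
        token_dictionary_file="/path/to/custom_token_dictionary.pkl",  # new option
        nproc=4,
    )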

@@ -183,6 +187,10 @@ class Classifier:
         """

         self.classifier = classifier
+        if self.classifier == "cell":
+            self.model_type = "CellClassifier"
+        elif self.classifier == "gene":
+            self.model_type = "GeneClassifier"
         self.cell_state_dict = cell_state_dict
         self.gene_class_dict = gene_class_dict
         self.filter_data = filter_data

@@ -201,6 +209,7 @@ class Classifier:
         self.stratify_splits_col = stratify_splits_col
         self.no_eval = no_eval
         self.forward_batch_size = forward_batch_size
+        self.token_dictionary_file = token_dictionary_file
         self.nproc = nproc
         self.ngpu = ngpu

@@ -222,7 +231,9 @@ class Classifier:
         ] = self.cell_state_dict["states"]

         # load token dictionary (Ensembl IDs:token)
-        with open(TOKEN_DICTIONARY_FILE, "rb") as f:
+        if self.token_dictionary_file is None:
+            self.token_dictionary_file = TOKEN_DICTIONARY_FILE
+        with open(token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)

         self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
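
As the comment above notes, the token dictionary maps Ensembl IDs to integer tokens and is read with pickle, so a custom dictionary for the new option is simply a pickled dict. A small sketch; the mappings and file name are hypothetical:

    import pickle

    # Hypothetical vocabulary: special tokens plus Ensembl gene ID -> token ID.
    custom_token_dict = {
        "<pad>": 0,
        "<mask>": 1,
        "ENSG00000187634": 2,
        "ENSG00000188976": 3,
    }

    with open("custom_token_dictionary.pkl", "wb") as f:
        pickle.dump(custom_token_dict, f)

    # The path can then be passed as Classifier(token_dictionary_file=...).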

@@ -267,7 +278,7 @@ class Classifier:
                     continue
             valid_type = False
             for option in valid_options:
-                if (option in [int, float, list, dict, bool]) and isinstance(
+                if (option in [int, float, list, dict, bool, str]) and isinstance(
                     attr_value, option
                 ):
                     valid_type = True

@@ -630,7 +641,6 @@ class Classifier:
             | Number of trials to run for hyperparameter optimization
             | If 0, will not optimize hyperparameters
         """
-
         if self.num_crossval_splits == 0:
             logger.error("num_crossval_splits must be 1 or 5 to validate.")
             raise

@@ -772,17 +782,20 @@ class Classifier:
                 ]
             )
             assert len(targets) == len(labels)
-            n_splits = int(1 / self.eval_size)
-            skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
+            n_splits = int(1 / (1 - self.train_size))
+            skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
             # (Cross-)validate
-            for train_index, eval_index in tqdm(skf.split(targets, labels)):
+            test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
+            for train_index, eval_index, test_index in tqdm(
+                skf.split(targets, labels, test_ratio)
+            ):
                 print(
                     f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
                 )
                 ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
                 # filter data for examples containing classes for this split
                 # subsample to max_ncells and relabel data in column "labels"
-                train_data, eval_data = cu.
+                train_data, eval_data = cu.prep_gene_classifier_train_eval_split(
                     data,
                     targets,
                     labels,

@@ -793,6 +806,18 @@ class Classifier:
                     self.nproc,
                 )

+                if self.oos_test_size > 0:
+                    test_data = cu.prep_gene_classifier_split(
+                        data,
+                        targets,
+                        labels,
+                        test_index,
+                        "test",
+                        self.max_ncells,
+                        iteration_num,
+                        self.nproc,
+                    )
+
                 if n_hyperopt_trials == 0:
                     trainer = self.train_classifier(
                         model_directory,

@@ -802,6 +827,15 @@ class Classifier:
                         ksplit_output_dir,
                         predict_trainer,
                     )
+                    result = self.evaluate_model(
+                        trainer.model,
+                        num_classes,
+                        id_class_dict,
+                        eval_data,
+                        predict_eval,
+                        ksplit_output_dir,
+                        output_prefix,
+                    )
                 else:
                     trainer = self.hyperopt_classifier(
                         model_directory,

@@ -811,20 +845,27 @@ class Classifier:
                         ksplit_output_dir,
                         n_trials=n_hyperopt_trials,
                     )
-                    if iteration_num == self.num_crossval_splits:
-                        return
+
+                    model = cu.load_best_model(
+                        ksplit_output_dir, self.model_type, num_classes
+                    )
+
+                    if self.oos_test_size > 0:
+                        result = self.evaluate_model(
+                            model,
+                            num_classes,
+                            id_class_dict,
+                            test_data,
+                            predict_eval,
+                            ksplit_output_dir,
+                            output_prefix,
+                        )
                     else:
-                        iteration_num = iteration_num + 1
-                        continue
-                result = self.evaluate_model(
-                    trainer.model,
-                    num_classes,
-                    id_class_dict,
-                    eval_data,
-                    predict_eval,
-                    ksplit_output_dir,
-                    output_prefix,
-                )
+                        if iteration_num == self.num_crossval_splits:
+                            return
+                        else:
+                            iteration_num = iteration_num + 1
+                            continue
                 results += [result]
                 all_conf_mat = all_conf_mat + result["conf_mat"]
                 # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
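
The loop above now consumes a splitter that yields train, eval, and out-of-sample test indices for each fold instead of sklearn's two-way StratifiedKFold. With illustrative sizes train_size=0.7, eval_size=0.15, and oos_test_size=0.15, this gives n_splits = int(1 / (1 - 0.7)) = 3 folds and test_ratio = 0.15 / (0.15 + 0.15) = 0.5, i.e. each held-out 30% is divided evenly between eval and test. A rough sketch of that splitting behavior, written with plain sklearn utilities as an assumption about what cu.StratifiedKFold3 provides, not its actual implementation:

    import numpy as np
    from sklearn.model_selection import StratifiedKFold, train_test_split

    def three_way_stratified_splits(targets, labels, test_ratio, n_splits=3, seed=0):
        """Yield (train_idx, eval_idx, test_idx); each held-out fold is divided
        into eval and test portions, stratified by label."""
        targets, labels = np.asarray(targets), np.asarray(labels)
        skf = StratifiedKFold(n_splits=n_splits, random_state=seed, shuffle=True)
        for train_idx, holdout_idx in skf.split(targets, labels):
            if test_ratio == 0:
                yield train_idx, holdout_idx, np.array([], dtype=int)
                continue
            eval_idx, test_idx = train_test_split(
                holdout_idx,
                test_size=test_ratio,
                stratify=labels[holdout_idx],
                random_state=seed,
            )
            yield train_idx, eval_idx, test_idx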

@@ -925,12 +966,7 @@ class Classifier:
         subprocess.call(f"mkdir {output_directory}", shell=True)

         ##### Load model and training args #####
-        if self.classifier == "cell":
-            model_type = "CellClassifier"
-        elif self.classifier == "gene":
-            model_type = "GeneClassifier"
-
-        model = pu.load_model(model_type, num_classes, model_directory, "train")
+        model = pu.load_model(self.model_type, num_classes, model_directory, "train")
         def_training_args, def_freeze_layers = cu.get_default_train_args(
             model, self.classifier, train_data, output_directory
         )

@@ -946,6 +982,9 @@ class Classifier:
         if eval_data is None:
             def_training_args["evaluation_strategy"] = "no"
             def_training_args["load_best_model_at_end"] = False
+        def_training_args.update(
+            {"save_strategy": "epoch", "save_total_limit": 1}
+        )  # only save last model for each run
         training_args_init = TrainingArguments(**def_training_args)

         ##### Fine-tune the model #####
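
The update above caps checkpointing so that each run keeps only its most recent epoch checkpoint. For reference, an equivalent way to express those two settings directly on TrainingArguments; this is an illustrative subset, and the real defaults come from cu.get_default_train_args():

    from transformers import TrainingArguments

    # Illustrative values; only save_strategy / save_total_limit mirror the diff.
    args = TrainingArguments(
        output_dir="ksplit1",         # placeholder output directory
        save_strategy="epoch",        # write a checkpoint once per epoch...
        save_total_limit=1,           # ...and keep only the newest one
        evaluation_strategy="epoch",
        load_best_model_at_end=False,
    )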

@@ -957,7 +996,9 @@ class Classifier:

         # define function to initiate model
         def model_init():
-            model = pu.load_model(model_type, num_classes, model_directory, "train")
+            model = pu.load_model(
+                self.model_type, num_classes, model_directory, "train"
+            )

             if self.freeze_layers is not None:
                 def_freeze_layers = self.freeze_layers

@@ -1018,6 +1059,7 @@ class Classifier:
                 metric="eval_macro_f1",
                 metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
             ),
+            local_dir=output_directory,
         )

         return trainer

@@ -1080,11 +1122,7 @@ class Classifier:
         subprocess.call(f"mkdir {output_directory}", shell=True)

         ##### Load model and training args #####
-        if self.classifier == "cell":
-            model_type = "CellClassifier"
-        elif self.classifier == "gene":
-            model_type = "GeneClassifier"
-        model = pu.load_model(model_type, num_classes, model_directory, "train")
+        model = pu.load_model(self.model_type, num_classes, model_directory, "train")

         def_training_args, def_freeze_layers = cu.get_default_train_args(
             model, self.classifier, train_data, output_directory

@@ -1238,11 +1276,7 @@ class Classifier:
         test_data = pu.load_and_filter(None, self.nproc, test_data_file)

         # load previously fine-tuned model
-        if self.classifier == "cell":
-            model_type = "CellClassifier"
-        elif self.classifier == "gene":
-            model_type = "GeneClassifier"
-        model = pu.load_model(model_type, num_classes, model_directory, "eval")
+        model = pu.load_model(self.model_type, num_classes, model_directory, "eval")

         # evaluate the model
         result = self.evaluate_model(