Update for gene classification
#330
by
hchen725
- opened
- geneformer/classifier_utils.py +72 -33
geneformer/classifier_utils.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
|
|
| 1 |
import logging
|
|
|
|
| 2 |
import random
|
| 3 |
from collections import Counter, defaultdict
|
| 4 |
|
|
@@ -6,6 +8,7 @@ import numpy as np
|
|
| 6 |
import pandas as pd
|
| 7 |
from scipy.stats import chisquare, ranksums
|
| 8 |
from sklearn.metrics import accuracy_score, f1_score
|
|
|
|
| 9 |
|
| 10 |
from . import perturber_utils as pu
|
| 11 |
|
|
@@ -133,61 +136,55 @@ def label_gene_classes(example, class_id_dict, gene_class_dict):
|
|
| 133 |
]
|
| 134 |
|
| 135 |
|
| 136 |
-
def
|
| 137 |
data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
):
|
| 139 |
# generate cross-validation splits
|
| 140 |
targets = np.array(targets)
|
| 141 |
labels = np.array(labels)
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
label_dict_eval = dict(zip(targets_eval, labels_eval))
|
| 146 |
|
| 147 |
# function to filter by whether contains train or eval labels
|
| 148 |
-
def
|
| 149 |
-
a =
|
| 150 |
-
b = example["input_ids"]
|
| 151 |
-
return not set(a).isdisjoint(b)
|
| 152 |
-
|
| 153 |
-
def if_contains_eval_label(example):
|
| 154 |
-
a = targets_eval
|
| 155 |
b = example["input_ids"]
|
| 156 |
return not set(a).isdisjoint(b)
|
| 157 |
|
| 158 |
# filter dataset for examples containing classes for this split
|
| 159 |
-
logger.info(f"Filtering
|
| 160 |
-
|
| 161 |
logger.info(
|
| 162 |
-
f"Filtered {round((1-len(
|
| 163 |
-
)
|
| 164 |
-
logger.info(f"Filtering evalation data for genes in split {iteration_num}")
|
| 165 |
-
eval_data = data.filter(if_contains_eval_label, num_proc=num_proc)
|
| 166 |
-
logger.info(
|
| 167 |
-
f"Filtered {round((1-len(eval_data)/len(data))*100)}%; {len(eval_data)} remain\n"
|
| 168 |
)
|
| 169 |
|
| 170 |
# subsample to max_ncells
|
| 171 |
-
|
| 172 |
-
eval_data = downsample_and_shuffle(eval_data, max_ncells, None, None)
|
| 173 |
|
| 174 |
# relabel genes for this split
|
| 175 |
-
def
|
| 176 |
example["labels"] = [
|
| 177 |
-
|
| 178 |
]
|
| 179 |
return example
|
| 180 |
|
| 181 |
-
|
| 182 |
-
example["labels"] = [
|
| 183 |
-
label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
|
| 184 |
-
]
|
| 185 |
-
return example
|
| 186 |
|
| 187 |
-
|
| 188 |
-
eval_data = eval_data.map(eval_classes_to_ids, num_proc=num_proc)
|
| 189 |
-
|
| 190 |
-
return train_data, eval_data
|
| 191 |
|
| 192 |
|
| 193 |
def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
|
|
@@ -423,3 +420,45 @@ def get_default_train_args(model, classifier, data, output_dir):
|
|
| 423 |
training_args.update(default_training_args)
|
| 424 |
|
| 425 |
return training_args, freeze_layers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
import logging
|
| 3 |
+
import os
|
| 4 |
import random
|
| 5 |
from collections import Counter, defaultdict
|
| 6 |
|
|
|
|
| 8 |
import pandas as pd
|
| 9 |
from scipy.stats import chisquare, ranksums
|
| 10 |
from sklearn.metrics import accuracy_score, f1_score
|
| 11 |
+
from sklearn.model_selection import StratifiedKFold, train_test_split
|
| 12 |
|
| 13 |
from . import perturber_utils as pu
|
| 14 |
|
|
|
|
| 136 |
]
|
| 137 |
|
| 138 |
|
| 139 |
+
def prep_gene_classifier_train_eval_split(
    data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
):
    """Prepare the train and eval datasets for one split.

    Delegates to ``prep_gene_classifier_split`` twice: once with
    ``train_index`` and the subset name ``"train"``, then once with
    ``eval_index`` and the subset name ``"eval"``. All other arguments are
    forwarded unchanged to both calls.

    Returns:
        A ``(train_data, eval_data)`` tuple of the two prepared subsets.
    """

    def _prep_subset(index, subset_name):
        # Everything except the index and its name is shared by both subsets.
        return prep_gene_classifier_split(
            data,
            targets,
            labels,
            index,
            subset_name,
            max_ncells,
            iteration_num,
            num_proc,
        )

    train_data = _prep_subset(train_index, "train")
    eval_data = _prep_subset(eval_index, "eval")
    return train_data, eval_data
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def prep_gene_classifier_split(
    data, targets, labels, index, subset_name, max_ncells, iteration_num, num_proc
):
    """Prepare one named subset (e.g. "train" or "eval") of a split.

    Selects the targets and labels at ``index``, keeps only the examples
    whose ``input_ids`` contain at least one selected target, passes the
    result through ``downsample_and_shuffle``, and writes a per-token
    ``labels`` field to every remaining example.

    Args:
        data: Dataset supporting ``len``, ``filter`` and ``map`` whose
            examples expose ``input_ids`` (presumably a Hugging Face
            ``Dataset`` -- confirm against the caller).
        targets: Sequence of target token ids, parallel to ``labels``.
        labels: Sequence of class labels, parallel to ``targets``.
        index: Positions within ``targets``/``labels`` that belong to this
            subset.
        subset_name: Name of the subset; only used in the log message.
        max_ncells: Forwarded to ``downsample_and_shuffle``.
        iteration_num: Split number; only used in the log message.
        num_proc: Forwarded to ``data.filter`` and ``map`` as ``num_proc``.

    Returns:
        The filtered, downsampled and relabeled dataset for this subset.
    """
    # restrict targets and labels to the entries belonging to this subset
    targets_subset = np.array(targets)[index]
    labels_subset = np.array(labels)[index]
    label_dict_subset = dict(zip(targets_subset, labels_subset))

    # Build the membership set once; the predicate below runs for every
    # example, so constructing the set inside it repeats the same work.
    targets_subset_set = set(targets_subset)

    # function to filter by whether the example contains any subset target
    def if_contains_subset_label(example):
        return not targets_subset_set.isdisjoint(example["input_ids"])

    # filter dataset for examples containing classes for this split
    logger.info(f"Filtering data for {subset_name} genes in split {iteration_num}")
    subset_data = data.filter(if_contains_subset_label, num_proc=num_proc)
    logger.info(
        f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
    )

    # subsample to max_ncells
    subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)

    # relabel genes for this split; tokens that are not subset targets get -100
    def subset_classes_to_ids(example):
        example["labels"] = [
            label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
        ]
        return example

    subset_data = subset_data.map(subset_classes_to_ids, num_proc=num_proc)

    return subset_data
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
|
| 190 |
def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
|
|
|
|
| 420 |
training_args.update(default_training_args)
|
| 421 |
|
| 422 |
return training_args, freeze_layers
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def load_best_model(directory, model_type, num_classes, mode="eval"):
    """Load the model belonging to the run with the highest ``eval_macro_f1``.

    Walks ``directory`` for files whose name ends in ``result.json`` and
    reads ``eval_macro_f1`` from each. The directory of the best-scoring
    result is split on ``"_objective_"``; the text before the first ``"_"``
    of the third piece is used as the run id. The model is then loaded from
    the last directory under ``{directory}/run-<id>`` that contains a file
    ending in ``model.safetensors``.

    Args:
        directory: Root directory holding the result files and run folders.
        model_type: Forwarded to ``pu.load_model``.
        num_classes: Forwarded to ``pu.load_model``.
        mode: Forwarded to ``pu.load_model``; defaults to ``"eval"``.

    Returns:
        The object returned by ``pu.load_model`` for the selected directory.

    Raises:
        FileNotFoundError: If no ``result.json`` file exists under
            ``directory``, or no ``model.safetensors`` file exists under the
            selected run directory.
    """
    # map each result directory to its reported macro F1
    file_dict = dict()
    for subdir, _dirs, files in os.walk(directory):
        for file in files:
            if file.endswith("result.json"):
                with open(f"{subdir}/{file}", "rb") as fp:
                    result_json = json.load(fp)
                file_dict[f"{subdir}"] = result_json["eval_macro_f1"]

    # Fail early with a clear message instead of an opaque pandas error.
    if not file_dict:
        raise FileNotFoundError(f"No result.json files found under {directory}")

    file_df = pd.DataFrame(
        {"dir": list(file_dict.keys()), "eval_macro_f1": list(file_dict.values())}
    )
    best_dir = file_df.iloc[file_df["eval_macro_f1"].idxmax()]["dir"]
    model_superdir = "run-" + best_dir.split("_objective_")[2].split("_")[0]

    # Collect candidate checkpoint directories first so that only one model
    # is loaded; the last match is used, which is the one the previous
    # load-in-a-loop implementation ended up returning.
    model_dirs = [
        subdir
        for subdir, _dirs, files in os.walk(f"{directory}/{model_superdir}")
        if any(file.endswith("model.safetensors") for file in files)
    ]
    if not model_dirs:
        raise FileNotFoundError(
            f"No model.safetensors file found under {directory}/{model_superdir}"
        )

    return pu.load_model(model_type, num_classes, f"{model_dirs[-1]}", mode)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
class StratifiedKFold3(StratifiedKFold):
    """Stratified K-fold that yields three index sets per fold.

    Each held-out fold produced by ``StratifiedKFold.split`` is further
    divided into validation and test indices with ``train_test_split``.
    """

    def split(self, targets, labels, test_ratio=0.5, groups=None):
        """Yield ``(train, valid, test)`` index arrays for every fold.

        Args:
            targets: Passed to ``StratifiedKFold.split`` as the samples.
            labels: Passed to ``StratifiedKFold.split`` as the stratification
                labels, and used again to stratify the validation/test split.
            test_ratio: Passed as ``test_size`` to ``train_test_split`` for
                the held-out fold. When it equals 0 the whole held-out fold
                is yielded as the second element and the third is ``None``.
            groups: Passed through to ``StratifiedKFold.split``.

        Yields:
            Tuples ``(train_indxs, valid_indxs, test_indxs)``.
        """
        for train_indxs, held_out_indxs in super().split(targets, labels, groups):
            # no test portion requested: the held-out fold is used whole
            if test_ratio == 0:
                yield train_indxs, held_out_indxs, None
                continue

            # stratify the validation/test division by the held-out labels;
            # the seed is fixed at 0, so the division is the same on every call
            held_out_labels = np.array(labels)[held_out_indxs]
            valid_indxs, test_indxs = train_test_split(
                held_out_indxs,
                stratify=held_out_labels,
                test_size=test_ratio,
                random_state=0,
            )
            yield train_indxs, valid_indxs, test_indxs
|