Christina Theodoris committed
Commit · be98543 · 1 Parent(s): fb6abe0

track metadata in predictions

Browse files:
- geneformer/classifier.py +22 -3
- geneformer/classifier_utils.py +21 -0
- geneformer/evaluation_utils.py +17 -2
geneformer/classifier.py
CHANGED
@@ -801,7 +801,7 @@ class Classifier:
             # 5-fold cross-validate
             num_cells = len(data)
             fifth_cells = int(np.floor(num_cells * 0.2))
-            num_eval = min((self.eval_size * num_cells), fifth_cells)
+            num_eval = int(min((self.eval_size * num_cells), fifth_cells))
             start = i * fifth_cells
             end = start + num_eval
             eval_indices = [j for j in range(start, end)]
@@ -1313,6 +1313,7 @@ class Classifier:
         predict=False,
         output_directory=None,
         output_prefix=None,
+        predict_metadata=None,
     ):
         """
         Evaluate the fine-tuned model.
@@ -1338,9 +1339,11 @@ class Classifier:

         ##### Evaluate the model #####
         labels = id_class_dict.keys()
-        y_pred, y_true, logits_list = eu.classifier_predict(
-            model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict
+
+        y_pred, y_true, logits_list, predict_metadata_all = eu.classifier_predict(
+            model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict, predict_metadata
         )
+
         conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
             y_pred, y_true, logits_list, num_classes, labels
         )
@@ -1350,6 +1353,9 @@ class Classifier:
                 "label_ids": y_true,
                 "predictions": logits_list,
             }
+            if predict_metadata is not None:
+                pred_dict["prediction_metadata"] = predict_metadata_all
+
             pred_dict_output_path = (
                 Path(output_directory) / f"{output_prefix}_pred_dict"
             ).with_suffix(".pkl")
@@ -1370,6 +1376,7 @@ class Classifier:
         output_directory,
         output_prefix,
         predict=True,
+        predict_metadata=None,
     ):
         """
         Evaluate the fine-tuned model.
@@ -1389,6 +1396,8 @@ class Classifier:
             | Prefix for output files
         predict : bool
             | Whether or not to save eval predictions
+        predict_metadata : None | list
+            | Metadata labels to output with predictions (columns in test_data_file)
         """

         # load numerical id to class dictionary (id:class)
@@ -1401,6 +1410,15 @@ class Classifier:
         # load previously filtered and prepared data
         test_data = pu.load_and_filter(None, self.nproc, test_data_file)

+        if predict_metadata is not None:
+            absent_metadata = []
+            for predict_metadata_x in predict_metadata:
+                if predict_metadata_x not in test_data.features.keys():
+                    absent_metadata += [predict_metadata_x]
+            if len(absent_metadata) > 0:
+                logger.error(f"Following predict_metadata was not found as column in test_data_file: {absent_metadata}")
+                raise
+
         # load previously fine-tuned model
         model = pu.load_model(
             self.model_type,
@@ -1419,6 +1437,7 @@ class Classifier:
             predict=predict,
             output_directory=output_directory,
             output_prefix=output_prefix,
+            predict_metadata=predict_metadata,
         )

         all_conf_mat_df = pd.DataFrame(
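Taken together, these classifier.py hunks thread the new predict_metadata keyword from the public evaluation entry point down to eu.classifier_predict and into the pickled prediction dictionary. A minimal usage sketch, assuming the second method shown above is Classifier.evaluate_saved_model (its parameter list matches these hunks); the task config, paths, and metadata column names are placeholders:

from geneformer import Classifier

# all values below are illustrative placeholders
cc = Classifier(
    classifier="cell",
    cell_state_dict={"state_key": "disease", "states": "all"},
    forward_batch_size=200,
    nproc=8,
)

all_metrics = cc.evaluate_saved_model(
    model_directory="/path/to/fine_tuned_model",
    id_class_dict_file="/path/to/id_class_dict.pkl",
    test_data_file="/path/to/test_data.dataset",
    output_directory="/path/to/output",
    output_prefix="eval_with_metadata",
    predict=True,  # predictions (and metadata) are only pickled when predict=True
    predict_metadata=["individual", "age"],  # must be column names in test_data_file
)

Requesting metadata costs nothing at the model level: the new validation block only checks that each requested name exists among test_data.features before the model is even loaded, so typos fail fast rather than after a full forward pass.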
geneformer/classifier_utils.py
CHANGED
@@ -570,6 +570,27 @@ def compute_metrics(pred):
     return {"accuracy": acc, "macro_f1": macro_f1}


+def robust_compute_objective(metrics: dict):
+    # tries both prefixed ("eval_") and raw metric names to support different transformers versions
+    metric_name = "macro_f1"
+
+    # check for the prefixed version
+    prefixed_metric_name = f"eval_{metric_name}"
+    if prefixed_metric_name in metrics:
+        return metrics[prefixed_metric_name]
+
+    # fall back to the raw name
+    elif metric_name in metrics:
+        return metrics[metric_name]
+
+    # if neither is found, raise a clear error to help with debugging
+    raise KeyError(
+        f"Could not find '{prefixed_metric_name}' or '{metric_name}' in the reported metrics. "
+        f"Please check your `compute_metrics` function and `TrainingArguments`. "
+        f"Available metrics: {list(metrics.keys())}"
+    )
+
+
 def get_default_train_args(model, classifier, data, output_dir):
     num_layers = pu.quant_layers(model)
     freeze_layers = 0
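The new robust_compute_objective helper matches the compute_objective contract of transformers' Trainer.hyperparameter_search, which reports metrics with or without the "eval_" prefix depending on the transformers version. The commit does not show its call site, so the wiring below is a hedged sketch; the trainer passed in is assumed to use the compute_metrics defined earlier in this file:

from transformers import Trainer

from geneformer.classifier_utils import robust_compute_objective

def tune_macro_f1(trainer: Trainer, n_trials: int = 10):
    # trainer must be constructed with a model_init and a compute_metrics
    # that reports "macro_f1"; the search then maximizes that objective
    return trainer.hyperparameter_search(
        direction="maximize",
        n_trials=n_trials,
        compute_objective=robust_compute_objective,
    )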
geneformer/evaluation_utils.py
CHANGED
@@ -77,7 +77,7 @@ def py_softmax(vector):
     return e / e.sum()


-def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict):
+def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict, predict_metadata=None):
     if classifier_type == "gene":
         label_name = "labels"
     elif classifier_type == "cell":
@@ -85,6 +85,14 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict)

     predict_logits = []
     predict_labels = []
+
+    predict_metadata_all = None
+
+    if predict_metadata is not None:
+        predict_metadata_all = dict()
+        for metadata_name in predict_metadata:
+            predict_metadata_all[metadata_name] = []
+
     model.eval()

     # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
@@ -99,9 +107,15 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict)
     for i in trange(0, evalset_len, forward_batch_size):
         max_range = min(i + forward_batch_size, evalset_len)
         batch_evalset = evalset.select([i for i in range(i, max_range)])
+
+        if predict_metadata is not None:
+            for metadata_name in predict_metadata:
+                predict_metadata_all[metadata_name] += batch_evalset[metadata_name]
+
         padded_batch = preprocess_classifier_batch(
             batch_evalset, max_evalset_len, label_name, gene_token_dict
         )
+
         padded_batch.set_format(type="torch")

         # For datasets>=4.0.0, convert to dict to avoid format issues
@@ -134,7 +148,8 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict)
     y_pred = [vote(item[0]) for item in logit_label_paired]
     y_true = [item[1] for item in logit_label_paired]
     logits_list = [item[0] for item in logit_label_paired]
-    return y_pred, y_true, logits_list
+
+    return y_pred, y_true, logits_list, predict_metadata_all


 def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
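With these changes, classifier_predict returns a fourth value, and the pickled prediction dictionary gains a "prediction_metadata" entry whenever metadata was requested. Each metadata list is filled batch by batch in dataset order, so it should stay index-aligned with the returned predictions. A small sketch of consuming the saved pickle (the filename prefix is a placeholder matching the f"{output_prefix}_pred_dict.pkl" path written by evaluate):

import pickle

# placeholder filename; substitute your own output_prefix
with open("eval_with_metadata_pred_dict.pkl", "rb") as fp:
    pred_dict = pickle.load(fp)

label_ids = pred_dict["label_ids"]    # true label ids, one per example
logits = pred_dict["predictions"]     # per-example logits
# {column_name: list}; the key is present only if predict_metadata was passed
metadata = pred_dict.get("prediction_metadata", {})

# index i of every metadata list refers to the same example as label_ids[i]
for i in range(min(3, len(label_ids))):
    extras = {name: values[i] for name, values in metadata.items()}
    print(label_ids[i], logits[i], extras)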