Christina Theodoris committed
Commit be98543 · 1 Parent(s): fb6abe0

track metadata in predictions

geneformer/classifier.py CHANGED

@@ -801,7 +801,7 @@ class Classifier:
         # 5-fold cross-validate
         num_cells = len(data)
         fifth_cells = int(np.floor(num_cells * 0.2))
-        num_eval = min((self.eval_size * num_cells), fifth_cells)
+        num_eval = int(min((self.eval_size * num_cells), fifth_cells))
         start = i * fifth_cells
         end = start + num_eval
         eval_indices = [j for j in range(start, end)]

@@ -1313,6 +1313,7 @@ class Classifier:
         predict=False,
         output_directory=None,
         output_prefix=None,
+        predict_metadata=None,
     ):
         """
         Evaluate the fine-tuned model.

@@ -1338,9 +1339,11 @@ class Classifier:
 
         ##### Evaluate the model #####
         labels = id_class_dict.keys()
-        y_pred, y_true, logits_list = eu.classifier_predict(
-            model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict
+
+        y_pred, y_true, logits_list, predict_metadata_all = eu.classifier_predict(
+            model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict, predict_metadata
         )
+
         conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
             y_pred, y_true, logits_list, num_classes, labels
         )

@@ -1350,6 +1353,9 @@ class Classifier:
             "label_ids": y_true,
             "predictions": logits_list,
         }
+        if predict_metadata is not None:
+            pred_dict["prediction_metadata"] = predict_metadata_all
+
         pred_dict_output_path = (
             Path(output_directory) / f"{output_prefix}_pred_dict"
         ).with_suffix(".pkl")

@@ -1370,6 +1376,7 @@ class Classifier:
         output_directory,
         output_prefix,
         predict=True,
+        predict_metadata=None,
     ):
         """
         Evaluate the fine-tuned model.

@@ -1389,6 +1396,8 @@ class Classifier:
             | Prefix for output files
         predict : bool
             | Whether or not to save eval predictions
+        predict_metadata : None | list
+            | Metadata labels to output with predictions (columns in test_data_file)
         """
 
         # load numerical id to class dictionary (id:class)

@@ -1401,6 +1410,15 @@ class Classifier:
         # load previously filtered and prepared data
         test_data = pu.load_and_filter(None, self.nproc, test_data_file)
 
+        if predict_metadata is not None:
+            absent_metadata = []
+            for predict_metadata_x in predict_metadata:
+                if predict_metadata_x not in test_data.features.keys():
+                    absent_metadata += [predict_metadata_x]
+            if len(absent_metadata) > 0:
+                logger.error(f"Following predict_metadata was not found as column in test_data_file: {absent_metadata}")
+                raise
+
         # load previously fine-tuned model
         model = pu.load_model(
             self.model_type,

@@ -1419,6 +1437,7 @@ class Classifier:
             predict=predict,
             output_directory=output_directory,
             output_prefix=output_prefix,
+            predict_metadata=predict_metadata,
         )
 
         all_conf_mat_df = pd.DataFrame(
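Taken together, the classifier.py changes do two things: num_eval is now cast to int (self.eval_size * num_cells is a float, so range(start, end) would otherwise raise a TypeError), and the evaluation path can carry arbitrary metadata columns from the test dataset through to the saved prediction dictionary. A minimal usage sketch follows, assuming the second modified method is the evaluate_saved_model entry point (the method name is not visible in this diff) and using hypothetical paths and column names:

import pickle

# "cc" is a Classifier instance configured elsewhere; all paths and the
# metadata column names below are illustrative assumptions
cc.evaluate_saved_model(
    model_directory="fine_tuned_model/",
    id_class_dict_file="id_class_dict.pkl",
    test_data_file="test_data.dataset",
    output_directory="eval_out",
    output_prefix="cell_eval",
    predict=True,
    predict_metadata=["donor_id", "disease"],  # must be columns in test_data_file
)

# with predict=True, predictions are pickled to eval_out/cell_eval_pred_dict.pkl;
# the new key is present only when predict_metadata was supplied
with open("eval_out/cell_eval_pred_dict.pkl", "rb") as f:
    pred_dict = pickle.load(f)

pred_dict["label_ids"]            # true labels (y_true)
pred_dict["predictions"]          # per-example logits
pred_dict["prediction_metadata"]  # e.g. {"donor_id": [...], "disease": [...]}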
geneformer/classifier_utils.py CHANGED

@@ -570,6 +570,27 @@ def compute_metrics(pred):
     return {"accuracy": acc, "macro_f1": macro_f1}
 
 
+def robust_compute_objective(metrics: dict):
+    # tries both prefixed ("eval_") and raw metric names to support different transformers versions
+    metric_name = "macro_f1"
+
+    # check for the prefixed version
+    prefixed_metric_name = f"eval_{metric_name}"
+    if prefixed_metric_name in metrics:
+        return metrics[prefixed_metric_name]
+
+    # fall back to the raw name
+    elif metric_name in metrics:
+        return metrics[metric_name]
+
+    # if neither is found, raise a clear error to help with debugging
+    raise KeyError(
+        f"Could not find '{prefixed_metric_name}' or '{metric_name}' in the reported metrics. "
+        f"Please check your `compute_metrics` function and `TrainingArguments`. "
+        f"Available metrics: {list(metrics.keys())}"
+    )
+
+
 def get_default_train_args(model, classifier, data, output_dir):
     num_layers = pu.quant_layers(model)
     freeze_layers = 0
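robust_compute_objective matches the shape of the compute_objective callable accepted by transformers' Trainer.hyperparameter_search, which is the natural consumer of a function mapping a metrics dict to a single scalar; whether a given transformers version prefixes metric names with "eval_" is exactly what it guards against. A hedged sketch of how it could be wired in (the trainer setup is illustrative, not part of this commit):

from transformers import Trainer

# illustrative wiring: model_init, training_args, and the datasets are assumed
# to be defined elsewhere; compute_metrics is the function above that reports
# {"accuracy": ..., "macro_f1": ...}
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    compute_metrics=compute_metrics,
)

best_run = trainer.hyperparameter_search(
    direction="maximize",                        # macro_f1: higher is better
    compute_objective=robust_compute_objective,  # tolerant of the eval_ prefix
    n_trials=10,
)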
geneformer/evaluation_utils.py CHANGED

@@ -77,7 +77,7 @@ def py_softmax(vector):
     return e / e.sum()
 
 
-def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict):
+def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict, predict_metadata=None):
     if classifier_type == "gene":
         label_name = "labels"
     elif classifier_type == "cell":

@@ -85,6 +85,14 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
 
     predict_logits = []
     predict_labels = []
+
+    predict_metadata_all = None
+
+    if predict_metadata is not None:
+        predict_metadata_all = dict()
+        for metadata_name in predict_metadata:
+            predict_metadata_all[metadata_name] = []
+
     model.eval()
 
     # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims

@@ -99,9 +107,15 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
     for i in trange(0, evalset_len, forward_batch_size):
         max_range = min(i + forward_batch_size, evalset_len)
         batch_evalset = evalset.select([i for i in range(i, max_range)])
+
+        if predict_metadata is not None:
+            for metadata_name in predict_metadata:
+                predict_metadata_all[metadata_name] += batch_evalset[metadata_name]
+
         padded_batch = preprocess_classifier_batch(
             batch_evalset, max_evalset_len, label_name, gene_token_dict
         )
+
         padded_batch.set_format(type="torch")
 
         # For datasets>=4.0.0, convert to dict to avoid format issues

@@ -134,7 +148,8 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
     y_pred = [vote(item[0]) for item in logit_label_paired]
     y_true = [item[1] for item in logit_label_paired]
     logits_list = [item[0] for item in logit_label_paired]
-    return y_pred, y_true, logits_list
+
+    return y_pred, y_true, logits_list, predict_metadata_all
 
 
 def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
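The accumulation pattern added to classifier_predict leans on a property of Hugging Face datasets: after selecting a batch with Dataset.select, indexing a column name returns a plain Python list, so per-batch metadata can be concatenated with += and stays aligned with the order of the predictions. A standalone sketch of the pattern with toy data and a hypothetical donor_id column:

from datasets import Dataset

ds = Dataset.from_dict({
    "input_ids": [[1, 2], [3, 4], [5, 6], [7, 8]],
    "label": [0, 1, 0, 1],
    "donor_id": ["d1", "d1", "d2", "d2"],  # hypothetical metadata column
})

predict_metadata = ["donor_id"]
predict_metadata_all = {name: [] for name in predict_metadata}

batch_size = 2
for i in range(0, len(ds), batch_size):
    batch = ds.select(range(i, min(i + batch_size, len(ds))))
    for name in predict_metadata:
        predict_metadata_all[name] += batch[name]  # column access yields a list

print(predict_metadata_all)  # {'donor_id': ['d1', 'd1', 'd2', 'd2']}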