geneformer/classifier.py CHANGED
@@ -801,7 +801,7 @@ class Classifier:
801
  # 5-fold cross-validate
802
  num_cells = len(data)
803
  fifth_cells = int(np.floor(num_cells * 0.2))
804
- num_eval = int(min((self.eval_size * num_cells), fifth_cells))
805
  start = i * fifth_cells
806
  end = start + num_eval
807
  eval_indices = [j for j in range(start, end)]
@@ -1313,7 +1313,6 @@ class Classifier:
1313
  predict=False,
1314
  output_directory=None,
1315
  output_prefix=None,
1316
- predict_metadata=None,
1317
  ):
1318
  """
1319
  Evaluate the fine-tuned model.
@@ -1339,11 +1338,9 @@ class Classifier:
1339
 
1340
  ##### Evaluate the model #####
1341
  labels = id_class_dict.keys()
1342
-
1343
- y_pred, y_true, logits_list, predict_metadata_all = eu.classifier_predict(
1344
- model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict, predict_metadata
1345
  )
1346
-
1347
  conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
1348
  y_pred, y_true, logits_list, num_classes, labels
1349
  )
@@ -1353,9 +1350,6 @@ class Classifier:
1353
  "label_ids": y_true,
1354
  "predictions": logits_list,
1355
  }
1356
- if predict_metadata is not None:
1357
- pred_dict["prediction_metadata"] = predict_metadata_all
1358
-
1359
  pred_dict_output_path = (
1360
  Path(output_directory) / f"{output_prefix}_pred_dict"
1361
  ).with_suffix(".pkl")
@@ -1376,7 +1370,6 @@ class Classifier:
1376
  output_directory,
1377
  output_prefix,
1378
  predict=True,
1379
- predict_metadata=None,
1380
  ):
1381
  """
1382
  Evaluate the fine-tuned model.
@@ -1396,8 +1389,6 @@ class Classifier:
1396
  | Prefix for output files
1397
  predict : bool
1398
  | Whether or not to save eval predictions
1399
- predict_metadata : None | list
1400
- | Metadata labels to output with predictions (columns in test_data_file)
1401
  """
1402
 
1403
  # load numerical id to class dictionary (id:class)
@@ -1410,15 +1401,6 @@ class Classifier:
1410
  # load previously filtered and prepared data
1411
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1412
 
1413
- if predict_metadata is not None:
1414
- absent_metadata = []
1415
- for predict_metadata_x in predict_metadata:
1416
- if predict_metadata_x not in test_data.features.keys():
1417
- absent_metadata += [predict_metadata_x]
1418
- if len(absent_metadata)>0:
1419
- logger.error(f"Following predict_metadata was not found as column in test_data_file: {absent_metadata}")
1420
- raise
1421
-
1422
  # load previously fine-tuned model
1423
  model = pu.load_model(
1424
  self.model_type,
@@ -1437,7 +1419,6 @@ class Classifier:
1437
  predict=predict,
1438
  output_directory=output_directory,
1439
  output_prefix=output_prefix,
1440
- predict_metadata=predict_metadata,
1441
  )
1442
 
1443
  all_conf_mat_df = pd.DataFrame(
 
801
  # 5-fold cross-validate
802
  num_cells = len(data)
803
  fifth_cells = int(np.floor(num_cells * 0.2))
804
+ num_eval = min((self.eval_size * num_cells), fifth_cells)
805
  start = i * fifth_cells
806
  end = start + num_eval
807
  eval_indices = [j for j in range(start, end)]
 
1313
  predict=False,
1314
  output_directory=None,
1315
  output_prefix=None,
 
1316
  ):
1317
  """
1318
  Evaluate the fine-tuned model.
 
1338
 
1339
  ##### Evaluate the model #####
1340
  labels = id_class_dict.keys()
1341
+ y_pred, y_true, logits_list = eu.classifier_predict(
1342
+ model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict
 
1343
  )
 
1344
  conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
1345
  y_pred, y_true, logits_list, num_classes, labels
1346
  )
 
1350
  "label_ids": y_true,
1351
  "predictions": logits_list,
1352
  }
 
 
 
1353
  pred_dict_output_path = (
1354
  Path(output_directory) / f"{output_prefix}_pred_dict"
1355
  ).with_suffix(".pkl")
 
1370
  output_directory,
1371
  output_prefix,
1372
  predict=True,
 
1373
  ):
1374
  """
1375
  Evaluate the fine-tuned model.
 
1389
  | Prefix for output files
1390
  predict : bool
1391
  | Whether or not to save eval predictions
 
 
1392
  """
1393
 
1394
  # load numerical id to class dictionary (id:class)
 
1401
  # load previously filtered and prepared data
1402
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1403
 
 
 
 
 
 
 
 
 
 
1404
  # load previously fine-tuned model
1405
  model = pu.load_model(
1406
  self.model_type,
 
1419
  predict=predict,
1420
  output_directory=output_directory,
1421
  output_prefix=output_prefix,
 
1422
  )
1423
 
1424
  all_conf_mat_df = pd.DataFrame(
geneformer/classifier_utils.py CHANGED
@@ -94,7 +94,7 @@ def remove_rare(data, rare_threshold, label, nproc):
94
  return data
95
 
96
 
97
- def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict=None):
98
  if classifier == "cell":
99
  label_set = set(data["label"])
100
  elif classifier == "gene":
@@ -570,27 +570,6 @@ def compute_metrics(pred):
570
  return {"accuracy": acc, "macro_f1": macro_f1}
571
 
572
 
573
- def robust_compute_objective(metrics: dict):
574
- # tries both prefixed ("eval_") and raw metric names to support different transformers versions
575
- metric_name = "macro_f1"
576
-
577
- # check for the prefixed version
578
- prefixed_metric_name = f"eval_{metric_name}"
579
- if prefixed_metric_name in metrics:
580
- return metrics[prefixed_metric_name]
581
-
582
- # fall back to the raw name
583
- elif metric_name in metrics:
584
- return metrics[metric_name]
585
-
586
- # if neither is found, raise a clear error to help with debugging
587
- raise KeyError(
588
- f"Could not find '{prefixed_metric_name}' or '{metric_name}' in the reported metrics. "
589
- f"Please check your `compute_metrics` function and `TrainingArguments`. "
590
- f"Available metrics: {list(metrics.keys())}"
591
- )
592
-
593
-
594
  def get_default_train_args(model, classifier, data, output_dir):
595
  num_layers = pu.quant_layers(model)
596
  freeze_layers = 0
 
94
  return data
95
 
96
 
97
+ def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict):
98
  if classifier == "cell":
99
  label_set = set(data["label"])
100
  elif classifier == "gene":
 
570
  return {"accuracy": acc, "macro_f1": macro_f1}
571
 
572
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
  def get_default_train_args(model, classifier, data, output_dir):
574
  num_layers = pu.quant_layers(model)
575
  freeze_layers = 0
geneformer/emb_extractor.py CHANGED
@@ -42,8 +42,6 @@ def get_embs(
42
  special_token=False,
43
  summary_stat=None,
44
  silent=False,
45
- save_tdigest=False,
46
- tdigest_path=None,
47
  ):
48
  model_input_size = pu.get_model_input_size(model)
49
  total_batch_length = len(filtered_input_data)
@@ -182,18 +180,12 @@ def get_embs(
182
  # calculate summary stat embs from approximated tdigests
183
  elif summary_stat is not None:
184
  if emb_mode == "cell":
185
- if save_tdigest:
186
- with open(f"{tdigest_path}","wb") as fp:
187
- pickle.dump(embs_tdigests, fp)
188
  if summary_stat == "mean":
189
  summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
190
  elif summary_stat == "median":
191
  summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
192
  embs_stack = torch.tensor(summary_emb_list)
193
  elif emb_mode == "gene":
194
- if save_tdigest:
195
- with open(f"{tdigest_path}","wb") as fp:
196
- pickle.dump(embs_tdigests_dict, fp)
197
  if summary_stat == "mean":
198
  [
199
  update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
@@ -260,7 +252,7 @@ def label_cell_embs(embs, downsampled_data, emb_labels):
260
  return embs_df
261
 
262
 
263
- def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mean_pool"):
264
  gene_set = {
265
  element for sublist in downsampled_data["input_ids"] for element in sublist
266
  }
@@ -275,39 +267,16 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mea
275
  )
276
  for k in dict_i.keys():
277
  gene_emb_dict[k].append(dict_i[k])
278
- if gene_emb_style != "all":
279
- for k in gene_emb_dict.keys():
280
- gene_emb_dict[k] = (
281
- torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
282
- .cpu()
283
- .numpy()
284
- )
285
- embs_df = pd.DataFrame(gene_emb_dict).T
286
- else:
287
- embs_df = dict_lol_to_df(gene_emb_dict)
288
  embs_df.index = [token_gene_dict[token] for token in embs_df.index]
289
  return embs_df
290
 
291
- def dict_lol_to_df(data_dict):
292
- # save dictionary with values being list of equal-length lists as dataframe
293
- df_data = []
294
- for key, list_of_lists in data_dict.items():
295
- for i, sublist in enumerate(list_of_lists):
296
- row_data = [key, i] + sublist.tolist()
297
- df_data.append(row_data)
298
-
299
- # determine column names based on the length of sublists
300
- # assuming all sublists have the same length
301
- num_columns_from_sublist = len(list(data_dict.values())[0][0])
302
- column_names = ['Gene', 'Identifier'] + [f'{j}' for j in range(num_columns_from_sublist)]
303
-
304
- # create the dataframe
305
- df = pd.DataFrame(df_data, columns=column_names)
306
-
307
- # set 'Gene' as the index
308
- df = df.set_index('Gene')
309
-
310
- return df
311
 
312
  def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
313
  only_embs_df = embs_df.iloc[:, :emb_dims]
@@ -435,7 +404,7 @@ class EmbExtractor:
435
  "num_classes": {int},
436
  "emb_mode": {"cls", "cell", "gene"},
437
  "cell_emb_style": {"mean_pool"},
438
- "gene_emb_style": {"mean_pool", "all"},
439
  "filter_data": {None, dict},
440
  "max_ncells": {None, int},
441
  "emb_layer": {-1, 0},
@@ -463,7 +432,6 @@ class EmbExtractor:
463
  forward_batch_size=100,
464
  nproc=4,
465
  summary_stat=None,
466
- save_tdigest=False,
467
  model_version="V2",
468
  token_dictionary_file=None,
469
  ):
@@ -483,9 +451,9 @@ class EmbExtractor:
483
  cell_emb_style : {"mean_pool"}
484
  | Method for summarizing cell embeddings if not using CLS token.
485
  | Currently only option is mean pooling of gene embeddings for given cell.
486
- gene_emb_style : {"mean_pool", "all}
487
  | Method for summarizing gene embeddings.
488
- | Currently only option is returning all or mean pooling of contextual gene embeddings for given gene.
489
  filter_data : None, dict
490
  | Default is to extract embeddings from all input data.
491
  | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
@@ -515,9 +483,6 @@ class EmbExtractor:
515
  | If mean or median, outputs only approximated mean or median embedding of input data.
516
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
517
  | Non-exact is slower but more memory-efficient.
518
- save_tdigest : bool
519
- | Whether to save a dictionary of tdigests for each gene and embedding dimension
520
- | Only applies when summary_stat is not None
521
  model_version : str
522
  | To auto-select settings for model version other than current default.
523
  | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
@@ -561,16 +526,9 @@ class EmbExtractor:
561
  else:
562
  self.summary_stat = summary_stat
563
  self.exact_summary_stat = None
564
- self.save_tdigest = save_tdigest
565
 
566
  self.validate_options()
567
 
568
- if (summary_stat is None) and (save_tdigest is True):
569
- logger.warning(
570
- "tdigests will not be saved since summary_stat is None."
571
- )
572
- save_tdigest = False
573
-
574
  if self.model_version == "V1":
575
  from . import TOKEN_DICTIONARY_FILE_30M
576
  self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
@@ -677,10 +635,6 @@ class EmbExtractor:
677
  self.model_type, self.num_classes, model_directory, mode="eval"
678
  )
679
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
680
- if self.save_tdigest:
681
- tdigest_path = (Path(output_directory) / f"{output_prefix}_tdigest").with_suffix(".pkl")
682
- else:
683
- tdigest_path = None
684
  embs = get_embs(
685
  model=model,
686
  filtered_input_data=downsampled_data,
@@ -690,8 +644,6 @@ class EmbExtractor:
690
  forward_batch_size=self.forward_batch_size,
691
  token_gene_dict=self.token_gene_dict,
692
  summary_stat=self.summary_stat,
693
- save_tdigest=self.save_tdigest,
694
- tdigest_path=tdigest_path,
695
  )
696
 
697
  if self.emb_mode == "cell":
@@ -701,7 +653,7 @@ class EmbExtractor:
701
  embs_df = pd.DataFrame(embs.cpu().numpy()).T
702
  elif self.emb_mode == "gene":
703
  if self.summary_stat is None:
704
- embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict, self.gene_emb_style)
705
  elif self.summary_stat is not None:
706
  embs_df = pd.DataFrame(embs).T
707
  embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
@@ -885,14 +837,14 @@ class EmbExtractor:
885
  raise
886
 
887
  if max_ncells_to_plot is not None:
888
- if self.max_ncells is not None:
889
- if max_ncells_to_plot > self.max_ncells:
890
- max_ncells_to_plot = self.max_ncells
891
- logger.warning(
892
- "max_ncells_to_plot must be <= max_ncells. "
893
- f"Changing max_ncells_to_plot to {self.max_ncells}."
894
- )
895
- embs = embs.sample(max_ncells_to_plot, axis=0)
896
 
897
  if self.emb_label is None:
898
  label_len = 0
 
42
  special_token=False,
43
  summary_stat=None,
44
  silent=False,
 
 
45
  ):
46
  model_input_size = pu.get_model_input_size(model)
47
  total_batch_length = len(filtered_input_data)
 
180
  # calculate summary stat embs from approximated tdigests
181
  elif summary_stat is not None:
182
  if emb_mode == "cell":
 
 
 
183
  if summary_stat == "mean":
184
  summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
185
  elif summary_stat == "median":
186
  summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
187
  embs_stack = torch.tensor(summary_emb_list)
188
  elif emb_mode == "gene":
 
 
 
189
  if summary_stat == "mean":
190
  [
191
  update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
 
252
  return embs_df
253
 
254
 
255
+ def label_gene_embs(embs, downsampled_data, token_gene_dict):
256
  gene_set = {
257
  element for sublist in downsampled_data["input_ids"] for element in sublist
258
  }
 
267
  )
268
  for k in dict_i.keys():
269
  gene_emb_dict[k].append(dict_i[k])
270
+ for k in gene_emb_dict.keys():
271
+ gene_emb_dict[k] = (
272
+ torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
273
+ .cpu()
274
+ .numpy()
275
+ )
276
+ embs_df = pd.DataFrame(gene_emb_dict).T
 
 
 
277
  embs_df.index = [token_gene_dict[token] for token in embs_df.index]
278
  return embs_df
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
282
  only_embs_df = embs_df.iloc[:, :emb_dims]
 
404
  "num_classes": {int},
405
  "emb_mode": {"cls", "cell", "gene"},
406
  "cell_emb_style": {"mean_pool"},
407
+ "gene_emb_style": {"mean_pool"},
408
  "filter_data": {None, dict},
409
  "max_ncells": {None, int},
410
  "emb_layer": {-1, 0},
 
432
  forward_batch_size=100,
433
  nproc=4,
434
  summary_stat=None,
 
435
  model_version="V2",
436
  token_dictionary_file=None,
437
  ):
 
451
  cell_emb_style : {"mean_pool"}
452
  | Method for summarizing cell embeddings if not using CLS token.
453
  | Currently only option is mean pooling of gene embeddings for given cell.
454
+ gene_emb_style : "mean_pool"
455
  | Method for summarizing gene embeddings.
456
+ | Currently only option is mean pooling of contextual gene embeddings for given gene.
457
  filter_data : None, dict
458
  | Default is to extract embeddings from all input data.
459
  | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
 
483
  | If mean or median, outputs only approximated mean or median embedding of input data.
484
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
485
  | Non-exact is slower but more memory-efficient.
 
 
 
486
  model_version : str
487
  | To auto-select settings for model version other than current default.
488
  | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
 
526
  else:
527
  self.summary_stat = summary_stat
528
  self.exact_summary_stat = None
 
529
 
530
  self.validate_options()
531
 
 
 
 
 
 
 
532
  if self.model_version == "V1":
533
  from . import TOKEN_DICTIONARY_FILE_30M
534
  self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
 
635
  self.model_type, self.num_classes, model_directory, mode="eval"
636
  )
637
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
 
 
 
 
638
  embs = get_embs(
639
  model=model,
640
  filtered_input_data=downsampled_data,
 
644
  forward_batch_size=self.forward_batch_size,
645
  token_gene_dict=self.token_gene_dict,
646
  summary_stat=self.summary_stat,
 
 
647
  )
648
 
649
  if self.emb_mode == "cell":
 
653
  embs_df = pd.DataFrame(embs.cpu().numpy()).T
654
  elif self.emb_mode == "gene":
655
  if self.summary_stat is None:
656
+ embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
657
  elif self.summary_stat is not None:
658
  embs_df = pd.DataFrame(embs).T
659
  embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
 
837
  raise
838
 
839
  if max_ncells_to_plot is not None:
840
+ if max_ncells_to_plot > self.max_ncells:
841
+ max_ncells_to_plot = self.max_ncells
842
+ logger.warning(
843
+ "max_ncells_to_plot must be <= max_ncells. "
844
+ f"Changing max_ncells_to_plot to {self.max_ncells}."
845
+ )
846
+ elif max_ncells_to_plot < self.max_ncells:
847
+ embs = embs.sample(max_ncells_to_plot, axis=0)
848
 
849
  if self.emb_label is None:
850
  label_len = 0
geneformer/evaluation_utils.py CHANGED
@@ -77,7 +77,7 @@ def py_softmax(vector):
77
  return e / e.sum()
78
 
79
 
80
- def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict, predict_metadata=None):
81
  if classifier_type == "gene":
82
  label_name = "labels"
83
  elif classifier_type == "cell":
@@ -85,14 +85,6 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
85
 
86
  predict_logits = []
87
  predict_labels = []
88
-
89
- predict_metadata_all = None
90
-
91
- if predict_metadata is not None:
92
- predict_metadata_all = dict()
93
- for metadata_name in predict_metadata:
94
- predict_metadata_all[metadata_name] = []
95
-
96
  model.eval()
97
 
98
  # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
@@ -107,15 +99,9 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
107
  for i in trange(0, evalset_len, forward_batch_size):
108
  max_range = min(i + forward_batch_size, evalset_len)
109
  batch_evalset = evalset.select([i for i in range(i, max_range)])
110
-
111
- if predict_metadata is not None:
112
- for metadata_name in predict_metadata:
113
- predict_metadata_all[metadata_name] += batch_evalset[metadata_name]
114
-
115
  padded_batch = preprocess_classifier_batch(
116
  batch_evalset, max_evalset_len, label_name, gene_token_dict
117
  )
118
-
119
  padded_batch.set_format(type="torch")
120
 
121
  # For datasets>=4.0.0, convert to dict to avoid format issues
@@ -148,8 +134,7 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
148
  y_pred = [vote(item[0]) for item in logit_label_paired]
149
  y_true = [item[1] for item in logit_label_paired]
150
  logits_list = [item[0] for item in logit_label_paired]
151
-
152
- return y_pred, y_true, logits_list, predict_metadata_all
153
 
154
 
155
  def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
 
77
  return e / e.sum()
78
 
79
 
80
+ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict):
81
  if classifier_type == "gene":
82
  label_name = "labels"
83
  elif classifier_type == "cell":
 
85
 
86
  predict_logits = []
87
  predict_labels = []
 
 
 
 
 
 
 
 
88
  model.eval()
89
 
90
  # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
 
99
  for i in trange(0, evalset_len, forward_batch_size):
100
  max_range = min(i + forward_batch_size, evalset_len)
101
  batch_evalset = evalset.select([i for i in range(i, max_range)])
 
 
 
 
 
102
  padded_batch = preprocess_classifier_batch(
103
  batch_evalset, max_evalset_len, label_name, gene_token_dict
104
  )
 
105
  padded_batch.set_format(type="torch")
106
 
107
  # For datasets>=4.0.0, convert to dict to avoid format issues
 
134
  y_pred = [vote(item[0]) for item in logit_label_paired]
135
  y_true = [item[1] for item in logit_label_paired]
136
  logits_list = [item[0] for item in logit_label_paired]
137
+ return y_pred, y_true, logits_list
 
138
 
139
 
140
  def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
geneformer/in_silico_perturber_stats.py CHANGED
@@ -726,9 +726,6 @@ class InSilicoPerturberStats:
726
  | "start_state": "dcm",
727
  | "goal_state": "nf",
728
  | "alt_states": ["hcm", "other1", "other2"]}
729
- pickle_suffix : None, str
730
- | Suffix to subselect intermediate raw files for analysis.
731
- | Default output of InSilicoPerturber uses suffix "_raw.pickle".
732
  model_version : str
733
  | To auto-select settings for model version other than current default.
734
  | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
 
726
  | "start_state": "dcm",
727
  | "goal_state": "nf",
728
  | "alt_states": ["hcm", "other1", "other2"]}
 
 
 
729
  model_version : str
730
  | To auto-select settings for model version other than current default.
731
  | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
geneformer/perturber_utils.py CHANGED
@@ -508,11 +508,6 @@ def make_perturbation_batch(
508
  def make_perturbation_batch_special(
509
  example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
510
  ) -> tuple[Dataset, List[int]]:
511
-
512
- # For datasets>=4.0.0, convert to dict to avoid format issues
513
- if int(datasets.__version__.split(".")[0]) >= 4:
514
- example_cell = example_cell[:]
515
-
516
  if combo_lvl == 0 and tokens_to_perturb == "all":
517
  if perturb_type in ["overexpress", "activate"]:
518
  range_start = 1
 
508
  def make_perturbation_batch_special(
509
  example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
510
  ) -> tuple[Dataset, List[int]]:
 
 
 
 
 
511
  if combo_lvl == 0 and tokens_to_perturb == "all":
512
  if perturb_type in ["overexpress", "activate"]:
513
  range_start = 1
geneformer/tokenizer.py CHANGED
@@ -3,7 +3,7 @@ Geneformer tokenizer.
3
 
4
  **Input data:**
5
 
6
- | *Required format:* raw counts scRNAseq data without feature selection as .loom, .h5ad, or .zarr file.
7
  | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
8
  | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
9
 
@@ -20,9 +20,9 @@ Geneformer tokenizer.
20
 
21
  **Description:**
22
 
23
- | Input data is a directory with .loom, .h5ad, or .zarr files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
24
 
25
- | The discussion below references the .loom file format, but the analagous labels are required for .h5ad and .zarr files, just that they will be column instead of row attributes and vice versa due to the transposed format of the file types.
26
 
27
  | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
28
 
@@ -30,7 +30,7 @@ Geneformer tokenizer.
30
 
31
  | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
32
 
33
- | If one's data is in other formats besides .loom, .h5ad, or .zarr, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom, .h5ad, or .zarr format prior to running the transcriptome tokenizer.
34
 
35
  | OF NOTE: Use model_version to auto-select settings for model version other than current default. For V1 model series (original Geneformer pretrained in 2021 on ~30M cells), one must use correct corresponding token dictionary and gene median file, set special_token to False, and set model_input_size to 2048. This argument enables auto-selection of these settings. (For V2 model series, special_token must be True and model_input_size is 4096.)
36
 
@@ -46,7 +46,6 @@ from collections import Counter
46
  from pathlib import Path
47
  from typing import Literal
48
 
49
- import anndata as ad
50
  import loompy as lp
51
  import numpy as np
52
  import pandas as pd
@@ -201,16 +200,13 @@ def sum_ensembl_ids(
201
  dsout.add_columns(processed_array, col_attrs=view.ca)
202
  return dedup_filename
203
 
204
- elif file_format in ["h5ad", "zarr"]:
205
  """
206
  Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
207
  Returns adata object with deduplicated Ensembl IDs.
208
  """
209
 
210
- if file_format == "h5ad":
211
- data = sc.read_h5ad(str(data_directory))
212
- else: # zarr
213
- data = ad.read_zarr(str(data_directory))
214
 
215
  if use_h5ad_index:
216
  data.var["ensembl_id"] = list(data.var.index)
@@ -240,7 +236,7 @@ def sum_ensembl_ids(
240
  gene for gene in ensembl_ids if gene in gene_token_dict.keys()
241
  ]
242
  if len(ensembl_id_check) == len(set(ensembl_id_check)):
243
- return data
244
  else:
245
  raise ValueError("Error: data Ensembl IDs non-unique.")
246
 
@@ -439,7 +435,7 @@ class TranscriptomeTokenizer:
439
  data_directory: Path | str,
440
  output_directory: Path | str,
441
  output_prefix: str,
442
- file_format: Literal["loom", "h5ad", "zarr"] = "loom",
443
  input_identifier: str = "",
444
  use_generator: bool = False,
445
  ):
@@ -455,9 +451,9 @@ class TranscriptomeTokenizer:
455
  output_prefix : str
456
  | Prefix for output .dataset
457
  file_format : str
458
- | Format of input files. Can be "loom", "h5ad", or "zarr".
459
  input_identifier : str
460
- | Substring identifier for input .loom, .h5ad, or .zarr, only matches are tokenized
461
  | Default is no identifier, tokenizes all files in provided directory.
462
  use_generator : bool
463
  | Whether to use generator or dict for tokenization.
@@ -477,7 +473,7 @@ class TranscriptomeTokenizer:
477
  tokenized_dataset.save_to_disk(str(output_path))
478
 
479
  def tokenize_files(
480
- self, data_directory, file_format: Literal["loom", "h5ad", "zarr"] = "loom", input_identifier: str = ""
481
  ):
482
  tokenized_cells = []
483
  tokenized_counts = []
@@ -489,7 +485,7 @@ class TranscriptomeTokenizer:
489
 
490
  # loops through directories to tokenize .loom files
491
  file_found = 0
492
- # loops through directories to tokenize .loom, .h5ad, or .zarr files
493
  tokenize_file_fn = (
494
  self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
495
  )
@@ -500,7 +496,7 @@ class TranscriptomeTokenizer:
500
  for file_path in data_directory.glob(file_match):
501
  file_found = 1
502
  print(f"Tokenizing {file_path}")
503
- file_tokenized_cells, file_cell_metadata, file_tokenized_counts = tokenize_file_fn(file_path, file_format=file_format)
504
  tokenized_cells += file_tokenized_cells
505
  tokenized_counts += file_tokenized_counts
506
  if self.custom_attr_name_dict is not None:
@@ -518,7 +514,7 @@ class TranscriptomeTokenizer:
518
  raise
519
  return tokenized_cells, cell_metadata, tokenized_counts
520
 
521
- def tokenize_anndata(self, adata_file_path, target_sum=10_000, file_format="h5ad"):
522
  adata = sum_ensembl_ids(
523
  adata_file_path,
524
  self.collapse_gene_ids,
@@ -526,7 +522,7 @@ class TranscriptomeTokenizer:
526
  self.gene_token_dict,
527
  self.custom_attr_name_dict,
528
  self.use_h5ad_index,
529
- file_format=file_format,
530
  chunk_size=self.chunk_size,
531
  )
532
 
@@ -616,9 +612,7 @@ class TranscriptomeTokenizer:
616
 
617
  return tokenized_cells, file_cell_metadata, tokenized_counts
618
 
619
- def tokenize_loom(self, loom_file_path, target_sum=10_000, file_format="loom"):
620
- tokenized_counts = [] # keep_counts not implemented for tokenize_loom
621
-
622
  if self.custom_attr_name_dict is not None:
623
  file_cell_metadata = {
624
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
@@ -633,7 +627,7 @@ class TranscriptomeTokenizer:
633
  self.gene_token_dict,
634
  self.custom_attr_name_dict,
635
  use_h5ad_index=False,
636
- file_format=file_format,
637
  chunk_size=self.chunk_size,
638
  )
639
 
@@ -706,7 +700,7 @@ class TranscriptomeTokenizer:
706
  del data.ra["ensembl_id_collapsed"]
707
 
708
 
709
- return tokenized_cells, file_cell_metadata, tokenized_counts
710
 
711
  def create_dataset(
712
  self,
 
3
 
4
  **Input data:**
5
 
6
+ | *Required format:* raw counts scRNAseq data without feature selection as a .loom or anndata file.
7
  | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
8
  | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
9
 
 
20
 
21
  **Description:**
22
 
23
+ | Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
24
 
25
+ | The discussion below references the .loom file format, but the analogous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.
26
 
27
  | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
28
 
 
30
 
31
  | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
32
 
33
+ | If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.
34
 
35
  | OF NOTE: Use model_version to auto-select settings for model version other than current default. For V1 model series (original Geneformer pretrained in 2021 on ~30M cells), one must use correct corresponding token dictionary and gene median file, set special_token to False, and set model_input_size to 2048. This argument enables auto-selection of these settings. (For V2 model series, special_token must be True and model_input_size is 4096.)
36
 
 
46
  from pathlib import Path
47
  from typing import Literal
48
 
 
49
  import loompy as lp
50
  import numpy as np
51
  import pandas as pd
 
200
  dsout.add_columns(processed_array, col_attrs=view.ca)
201
  return dedup_filename
202
 
203
+ elif file_format == "h5ad":
204
  """
205
  Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
206
  Returns adata object with deduplicated Ensembl IDs.
207
  """
208
 
209
+ data = sc.read_h5ad(str(data_directory))
 
 
 
210
 
211
  if use_h5ad_index:
212
  data.var["ensembl_id"] = list(data.var.index)
 
236
  gene for gene in ensembl_ids if gene in gene_token_dict.keys()
237
  ]
238
  if len(ensembl_id_check) == len(set(ensembl_id_check)):
239
+ return data_directory
240
  else:
241
  raise ValueError("Error: data Ensembl IDs non-unique.")
242
 
 
435
  data_directory: Path | str,
436
  output_directory: Path | str,
437
  output_prefix: str,
438
+ file_format: Literal["loom", "h5ad"] = "loom",
439
  input_identifier: str = "",
440
  use_generator: bool = False,
441
  ):
 
451
  output_prefix : str
452
  | Prefix for output .dataset
453
  file_format : str
454
+ | Format of input files. Can be "loom" or "h5ad".
455
  input_identifier : str
456
+ | Substring identifier for input .loom or .h5ad, only matches are tokenized
457
  | Default is no identifier, tokenizes all files in provided directory.
458
  use_generator : bool
459
  | Whether to use generator or dict for tokenization.
 
473
  tokenized_dataset.save_to_disk(str(output_path))
474
 
475
  def tokenize_files(
476
+ self, data_directory, file_format: Literal["loom", "h5ad"] = "loom", input_identifier: str = ""
477
  ):
478
  tokenized_cells = []
479
  tokenized_counts = []
 
485
 
486
  # loops through directories to tokenize .loom files
487
  file_found = 0
488
+ # loops through directories to tokenize .loom or .h5ad files
489
  tokenize_file_fn = (
490
  self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
491
  )
 
496
  for file_path in data_directory.glob(file_match):
497
  file_found = 1
498
  print(f"Tokenizing {file_path}")
499
+ file_tokenized_cells, file_cell_metadata, file_tokenized_counts = tokenize_file_fn(file_path)
500
  tokenized_cells += file_tokenized_cells
501
  tokenized_counts += file_tokenized_counts
502
  if self.custom_attr_name_dict is not None:
 
514
  raise
515
  return tokenized_cells, cell_metadata, tokenized_counts
516
 
517
+ def tokenize_anndata(self, adata_file_path, target_sum=10_000):
518
  adata = sum_ensembl_ids(
519
  adata_file_path,
520
  self.collapse_gene_ids,
 
522
  self.gene_token_dict,
523
  self.custom_attr_name_dict,
524
  self.use_h5ad_index,
525
+ file_format="h5ad",
526
  chunk_size=self.chunk_size,
527
  )
528
 
 
612
 
613
  return tokenized_cells, file_cell_metadata, tokenized_counts
614
 
615
+ def tokenize_loom(self, loom_file_path, target_sum=10_000):
 
 
616
  if self.custom_attr_name_dict is not None:
617
  file_cell_metadata = {
618
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
 
627
  self.gene_token_dict,
628
  self.custom_attr_name_dict,
629
  use_h5ad_index=False,
630
+ file_format="loom",
631
  chunk_size=self.chunk_size,
632
  )
633
 
 
700
  del data.ra["ensembl_id_collapsed"]
701
 
702
 
703
+ return tokenized_cells, file_cell_metadata
704
 
705
  def create_dataset(
706
  self,