Christina Theodoris committed on
Commit
fb6abe0
·
1 Parent(s): 5cada60

add emb extractor option for saving all gene embs

Browse files
Files changed (1) hide show
  1. geneformer/emb_extractor.py +60 -12
geneformer/emb_extractor.py CHANGED
@@ -42,6 +42,8 @@ def get_embs(
42
  special_token=False,
43
  summary_stat=None,
44
  silent=False,
 
 
45
  ):
46
  model_input_size = pu.get_model_input_size(model)
47
  total_batch_length = len(filtered_input_data)
@@ -180,12 +182,18 @@ def get_embs(
180
  # calculate summary stat embs from approximated tdigests
181
  elif summary_stat is not None:
182
  if emb_mode == "cell":
 
 
 
183
  if summary_stat == "mean":
184
  summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
185
  elif summary_stat == "median":
186
  summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
187
  embs_stack = torch.tensor(summary_emb_list)
188
  elif emb_mode == "gene":
 
 
 
189
  if summary_stat == "mean":
190
  [
191
  update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
@@ -252,7 +260,7 @@ def label_cell_embs(embs, downsampled_data, emb_labels):
252
  return embs_df
253
 
254
 
255
- def label_gene_embs(embs, downsampled_data, token_gene_dict):
256
  gene_set = {
257
  element for sublist in downsampled_data["input_ids"] for element in sublist
258
  }
@@ -267,16 +275,39 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict):
267
  )
268
  for k in dict_i.keys():
269
  gene_emb_dict[k].append(dict_i[k])
270
- for k in gene_emb_dict.keys():
271
- gene_emb_dict[k] = (
272
- torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
273
- .cpu()
274
- .numpy()
275
- )
276
- embs_df = pd.DataFrame(gene_emb_dict).T
 
 
 
277
  embs_df.index = [token_gene_dict[token] for token in embs_df.index]
278
  return embs_df
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
282
  only_embs_df = embs_df.iloc[:, :emb_dims]
@@ -404,7 +435,7 @@ class EmbExtractor:
404
  "num_classes": {int},
405
  "emb_mode": {"cls", "cell", "gene"},
406
  "cell_emb_style": {"mean_pool"},
407
- "gene_emb_style": {"mean_pool"},
408
  "filter_data": {None, dict},
409
  "max_ncells": {None, int},
410
  "emb_layer": {-1, 0},
@@ -432,6 +463,7 @@ class EmbExtractor:
432
  forward_batch_size=100,
433
  nproc=4,
434
  summary_stat=None,
 
435
  model_version="V2",
436
  token_dictionary_file=None,
437
  ):
@@ -451,9 +483,9 @@ class EmbExtractor:
451
  cell_emb_style : {"mean_pool"}
452
  | Method for summarizing cell embeddings if not using CLS token.
453
  | Currently only option is mean pooling of gene embeddings for given cell.
454
- gene_emb_style : "mean_pool"
455
  | Method for summarizing gene embeddings.
456
- | Currently only option is mean pooling of contextual gene embeddings for given gene.
457
  filter_data : None, dict
458
  | Default is to extract embeddings from all input data.
459
  | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
@@ -483,6 +515,9 @@ class EmbExtractor:
483
  | If mean or median, outputs only approximated mean or median embedding of input data.
484
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
485
  | Non-exact is slower but more memory-efficient.
 
 
 
486
  model_version : str
487
  | To auto-select settings for model version other than current default.
488
  | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
@@ -526,9 +561,16 @@ class EmbExtractor:
526
  else:
527
  self.summary_stat = summary_stat
528
  self.exact_summary_stat = None
 
529
 
530
  self.validate_options()
531
 
 
 
 
 
 
 
532
  if self.model_version == "V1":
533
  from . import TOKEN_DICTIONARY_FILE_30M
534
  self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
@@ -635,6 +677,10 @@ class EmbExtractor:
635
  self.model_type, self.num_classes, model_directory, mode="eval"
636
  )
637
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
 
 
 
 
638
  embs = get_embs(
639
  model=model,
640
  filtered_input_data=downsampled_data,
@@ -644,6 +690,8 @@ class EmbExtractor:
644
  forward_batch_size=self.forward_batch_size,
645
  token_gene_dict=self.token_gene_dict,
646
  summary_stat=self.summary_stat,
 
 
647
  )
648
 
649
  if self.emb_mode == "cell":
@@ -653,7 +701,7 @@ class EmbExtractor:
653
  embs_df = pd.DataFrame(embs.cpu().numpy()).T
654
  elif self.emb_mode == "gene":
655
  if self.summary_stat is None:
656
- embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
657
  elif self.summary_stat is not None:
658
  embs_df = pd.DataFrame(embs).T
659
  embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
 
42
  special_token=False,
43
  summary_stat=None,
44
  silent=False,
45
+ save_tdigest=False,
46
+ tdigest_path=None,
47
  ):
48
  model_input_size = pu.get_model_input_size(model)
49
  total_batch_length = len(filtered_input_data)
 
182
  # calculate summary stat embs from approximated tdigests
183
  elif summary_stat is not None:
184
  if emb_mode == "cell":
185
+ if save_tdigest:
186
+ with open(f"{tdigest_path}","wb") as fp:
187
+ pickle.dump(embs_tdigests, fp)
188
  if summary_stat == "mean":
189
  summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
190
  elif summary_stat == "median":
191
  summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
192
  embs_stack = torch.tensor(summary_emb_list)
193
  elif emb_mode == "gene":
194
+ if save_tdigest:
195
+ with open(f"{tdigest_path}","wb") as fp:
196
+ pickle.dump(embs_tdigests_dict, fp)
197
  if summary_stat == "mean":
198
  [
199
  update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
 
260
  return embs_df
261
 
262
 
263
+ def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mean_pool"):
264
  gene_set = {
265
  element for sublist in downsampled_data["input_ids"] for element in sublist
266
  }
 
275
  )
276
  for k in dict_i.keys():
277
  gene_emb_dict[k].append(dict_i[k])
278
+ if gene_emb_style != "all":
279
+ for k in gene_emb_dict.keys():
280
+ gene_emb_dict[k] = (
281
+ torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
282
+ .cpu()
283
+ .numpy()
284
+ )
285
+ embs_df = pd.DataFrame(gene_emb_dict).T
286
+ else:
287
+ embs_df = dict_lol_to_df(gene_emb_dict)
288
  embs_df.index = [token_gene_dict[token] for token in embs_df.index]
289
  return embs_df
290
 
291
def dict_lol_to_df(data_dict):
    """Convert a dict of lists of equal-length array-likes into a DataFrame.

    Each array-like (anything with a ``.tolist()`` method, e.g. a torch
    tensor or numpy array) becomes one row of the output. Rows are indexed
    by the dict key ('Gene'); an 'Identifier' column records the position
    of the row within that key's list, followed by one column per value
    ('0', '1', ...).

    Parameters
    ----------
    data_dict : dict
        Mapping of key -> list of equal-length array-likes.
        # assumes values support .tolist() and all have the same length

    Returns
    -------
    pd.DataFrame
        DataFrame indexed by 'Gene'. Empty (with only an 'Identifier'
        column) if data_dict is empty.
    """
    # guard: original code indexed values()[0][0], which raises IndexError
    # on an empty dict — return a well-formed empty frame instead
    if not data_dict:
        return pd.DataFrame(columns=["Gene", "Identifier"]).set_index("Gene")

    df_data = []
    for key, list_of_lists in data_dict.items():
        for i, sublist in enumerate(list_of_lists):
            # row = key, position within the key's list, then the values
            df_data.append([key, i] + sublist.tolist())

    # determine value-column count from the first sublist
    # (assumes all sublists have the same length)
    num_columns_from_sublist = len(list(data_dict.values())[0][0])
    column_names = ["Gene", "Identifier"] + [f"{j}" for j in range(num_columns_from_sublist)]

    # create the dataframe and set 'Gene' as the index
    df = pd.DataFrame(df_data, columns=column_names)
    df = df.set_index("Gene")

    return df
 
312
  def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
313
  only_embs_df = embs_df.iloc[:, :emb_dims]
 
435
  "num_classes": {int},
436
  "emb_mode": {"cls", "cell", "gene"},
437
  "cell_emb_style": {"mean_pool"},
438
+ "gene_emb_style": {"mean_pool", "all"},
439
  "filter_data": {None, dict},
440
  "max_ncells": {None, int},
441
  "emb_layer": {-1, 0},
 
463
  forward_batch_size=100,
464
  nproc=4,
465
  summary_stat=None,
466
+ save_tdigest=False,
467
  model_version="V2",
468
  token_dictionary_file=None,
469
  ):
 
483
  cell_emb_style : {"mean_pool"}
484
  | Method for summarizing cell embeddings if not using CLS token.
485
  | Currently only option is mean pooling of gene embeddings for given cell.
486
+ gene_emb_style : {"mean_pool", "all"}
487
  | Method for summarizing gene embeddings.
488
+ | Options are mean pooling of contextual gene embeddings for a given gene ("mean_pool") or returning all contextual gene embeddings ("all").
489
  filter_data : None, dict
490
  | Default is to extract embeddings from all input data.
491
  | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
 
515
  | If mean or median, outputs only approximated mean or median embedding of input data.
516
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
517
  | Non-exact is slower but more memory-efficient.
518
+ save_tdigest : bool
519
+ | Whether to save a dictionary of tdigests for each gene and embedding dimension
520
+ | Only applies when summary_stat is not None
521
  model_version : str
522
  | To auto-select settings for model version other than current default.
523
  | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
 
561
  else:
562
  self.summary_stat = summary_stat
563
  self.exact_summary_stat = None
564
+ self.save_tdigest = save_tdigest
565
 
566
  self.validate_options()
567
 
568
+ if (summary_stat is None) and (save_tdigest is True):
569
+ logger.warning(
570
+ "tdigests will not be saved since summary_stat is None."
571
+ )
572
+ save_tdigest = False
573
+
574
  if self.model_version == "V1":
575
  from . import TOKEN_DICTIONARY_FILE_30M
576
  self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
 
677
  self.model_type, self.num_classes, model_directory, mode="eval"
678
  )
679
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
680
+ if self.save_tdigest:
681
+ tdigest_path = (Path(output_directory) / f"{output_prefix}_tdigest").with_suffix(".pkl")
682
+ else:
683
+ tdigest_path = None
684
  embs = get_embs(
685
  model=model,
686
  filtered_input_data=downsampled_data,
 
690
  forward_batch_size=self.forward_batch_size,
691
  token_gene_dict=self.token_gene_dict,
692
  summary_stat=self.summary_stat,
693
+ save_tdigest=self.save_tdigest,
694
+ tdigest_path=tdigest_path,
695
  )
696
 
697
  if self.emb_mode == "cell":
 
701
  embs_df = pd.DataFrame(embs.cpu().numpy()).T
702
  elif self.emb_mode == "gene":
703
  if self.summary_stat is None:
704
+ embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict, self.gene_emb_style)
705
  elif self.summary_stat is not None:
706
  embs_df = pd.DataFrame(embs).T
707
  embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]