Christina Theodoris commited on
Commit
cfc8cdb
·
1 Parent(s): 7b591f6

fix exact_mean and exact_median to subselect dataframe emb cols, not cell rows

Browse files
Files changed (1) hide show
  1. geneformer/emb_extractor.py +8 -4
geneformer/emb_extractor.py CHANGED
@@ -429,8 +429,8 @@ class EmbExtractor:
429
 
430
  **Parameters:**
431
 
432
- model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
433
- | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
434
  num_classes : int
435
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
436
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
@@ -644,15 +644,19 @@ class EmbExtractor:
644
  if self.exact_summary_stat == "exact_mean":
645
  embs = embs.mean(dim=0)
646
  emb_dims = pu.get_model_emb_dims(model)
 
 
 
 
647
  embs_df = pd.DataFrame(
648
- embs_df[0 : emb_dims - 1].mean(axis="rows"),
649
  columns=[self.exact_summary_stat],
650
  ).T
651
  elif self.exact_summary_stat == "exact_median":
652
  embs = torch.median(embs, dim=0)[0]
653
  emb_dims = pu.get_model_emb_dims(model)
654
  embs_df = pd.DataFrame(
655
- embs_df[0 : emb_dims - 1].median(axis="rows"),
656
  columns=[self.exact_summary_stat],
657
  ).T
658
 
 
429
 
430
  **Parameters:**
431
 
432
+ model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "Pretrained-Quantized"}
433
+ | Whether model is the pretrained Geneformer (full or quantized) or a fine-tuned gene or cell classifier.
434
  num_classes : int
435
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
436
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
 
644
  if self.exact_summary_stat == "exact_mean":
645
  embs = embs.mean(dim=0)
646
  emb_dims = pu.get_model_emb_dims(model)
647
+ print(embs_df.shape)
648
+ print(embs_df)
649
+ print("#######")
650
+ print(embs_df.iloc[:, 0 : emb_dims - 1])
651
  embs_df = pd.DataFrame(
652
+ embs_df.iloc[:, 0 : emb_dims - 1].mean(axis="rows"),
653
  columns=[self.exact_summary_stat],
654
  ).T
655
  elif self.exact_summary_stat == "exact_median":
656
  embs = torch.median(embs, dim=0)[0]
657
  emb_dims = pu.get_model_emb_dims(model)
658
  embs_df = pd.DataFrame(
659
+ embs_df.iloc[:, 0 : emb_dims - 1].median(axis="rows"),
660
  columns=[self.exact_summary_stat],
661
  ).T
662