Christina Theodoris
commited on
Commit
·
cfc8cdb
1
Parent(s):
7b591f6
fix exact_mean and exact_median to subselect dataframe emb cols, not cell rows
Browse files
geneformer/emb_extractor.py
CHANGED
|
@@ -429,8 +429,8 @@ class EmbExtractor:
|
|
| 429 |
|
| 430 |
**Parameters:**
|
| 431 |
|
| 432 |
-
model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
|
| 433 |
-
| Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
|
| 434 |
num_classes : int
|
| 435 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
| 436 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
|
@@ -644,15 +644,19 @@ class EmbExtractor:
|
|
| 644 |
if self.exact_summary_stat == "exact_mean":
|
| 645 |
embs = embs.mean(dim=0)
|
| 646 |
emb_dims = pu.get_model_emb_dims(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 647 |
embs_df = pd.DataFrame(
|
| 648 |
-
embs_df[0 : emb_dims - 1].mean(axis="rows"),
|
| 649 |
columns=[self.exact_summary_stat],
|
| 650 |
).T
|
| 651 |
elif self.exact_summary_stat == "exact_median":
|
| 652 |
embs = torch.median(embs, dim=0)[0]
|
| 653 |
emb_dims = pu.get_model_emb_dims(model)
|
| 654 |
embs_df = pd.DataFrame(
|
| 655 |
-
embs_df[0 : emb_dims - 1].median(axis="rows"),
|
| 656 |
columns=[self.exact_summary_stat],
|
| 657 |
).T
|
| 658 |
|
|
|
|
| 429 |
|
| 430 |
**Parameters:**
|
| 431 |
|
| 432 |
+
model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "Pretrained-Quantized"}
|
| 433 |
+
| Whether model is the pretrained Geneformer (full or quantized) or a fine-tuned gene or cell classifier.
|
| 434 |
num_classes : int
|
| 435 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
| 436 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
|
|
|
| 644 |
if self.exact_summary_stat == "exact_mean":
|
| 645 |
embs = embs.mean(dim=0)
|
| 646 |
emb_dims = pu.get_model_emb_dims(model)
|
| 647 |
+
print(embs_df.shape)
|
| 648 |
+
print(embs_df)
|
| 649 |
+
print("#######")
|
| 650 |
+
print(embs_df.iloc[:, 0 : emb_dims - 1])
|
| 651 |
embs_df = pd.DataFrame(
|
| 652 |
+
embs_df.iloc[:, 0 : emb_dims - 1].mean(axis="rows"),
|
| 653 |
columns=[self.exact_summary_stat],
|
| 654 |
).T
|
| 655 |
elif self.exact_summary_stat == "exact_median":
|
| 656 |
embs = torch.median(embs, dim=0)[0]
|
| 657 |
emb_dims = pu.get_model_emb_dims(model)
|
| 658 |
embs_df = pd.DataFrame(
|
| 659 |
+
embs_df.iloc[:, 0 : emb_dims - 1].median(axis="rows"),
|
| 660 |
columns=[self.exact_summary_stat],
|
| 661 |
).T
|
| 662 |
|