embs_df with all model embeddings
Browse files — geneformer/emb_extractor.py (+7 lines, −15 lines)
geneformer/emb_extractor.py
CHANGED
|
@@ -50,9 +50,7 @@ def get_embs(
|
|
| 50 |
embs_list = []
|
| 51 |
elif summary_stat is not None:
|
| 52 |
# test embedding extraction for example cell and extract # emb dims
|
| 53 |
-
|
| 54 |
-
example.set_format(type="torch")
|
| 55 |
-
emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
|
| 56 |
if emb_mode == "cell":
|
| 57 |
# initiate tdigests for # of emb dims
|
| 58 |
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
|
@@ -78,7 +76,7 @@ def get_embs(
|
|
| 78 |
gene_token_dict = {v:k for k,v in token_gene_dict.items()}
|
| 79 |
cls_token_id = gene_token_dict["<cls>"]
|
| 80 |
assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
|
| 81 |
-
|
| 82 |
if cls_present:
|
| 83 |
logger.warning("CLS token present in token dictionary, excluding from average.")
|
| 84 |
if eos_present:
|
|
@@ -148,7 +146,7 @@ def get_embs(
|
|
| 148 |
del embs_h
|
| 149 |
del dict_h
|
| 150 |
elif emb_mode == "cls":
|
| 151 |
-
cls_embs = embs_i[:,0,:]
|
| 152 |
embs_list.append(cls_embs)
|
| 153 |
del cls_embs
|
| 154 |
|
|
@@ -239,14 +237,6 @@ def tdigest_median(embs_tdigests, emb_dims):
|
|
| 239 |
return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
| 240 |
|
| 241 |
|
| 242 |
-
def test_emb(model, example, layer_to_quant):
|
| 243 |
-
with torch.no_grad():
|
| 244 |
-
outputs = model(input_ids=example.to("cuda"))
|
| 245 |
-
|
| 246 |
-
embs_test = outputs.hidden_states[layer_to_quant]
|
| 247 |
-
return embs_test.size()[2]
|
| 248 |
-
|
| 249 |
-
|
| 250 |
def label_cell_embs(embs, downsampled_data, emb_labels):
|
| 251 |
embs_df = pd.DataFrame(embs.cpu().numpy())
|
| 252 |
if emb_labels is not None:
|
|
@@ -632,13 +622,15 @@ class EmbExtractor:
|
|
| 632 |
|
| 633 |
if self.exact_summary_stat == "exact_mean":
|
| 634 |
embs = embs.mean(dim=0)
|
|
|
|
| 635 |
embs_df = pd.DataFrame(
|
| 636 |
-
embs_df[0:255].mean(axis="rows"), columns=[self.exact_summary_stat]
|
| 637 |
).T
|
| 638 |
elif self.exact_summary_stat == "exact_median":
|
| 639 |
embs = torch.median(embs, dim=0)[0]
|
|
|
|
| 640 |
embs_df = pd.DataFrame(
|
| 641 |
-
embs_df[0:255].median(axis="rows"), columns=[self.exact_summary_stat]
|
| 642 |
).T
|
| 643 |
|
| 644 |
if cell_state is not None:
|
|
|
|
| 50 |
embs_list = []
|
| 51 |
elif summary_stat is not None:
|
| 52 |
# test embedding extraction for example cell and extract # emb dims
|
| 53 |
+
emb_dims = pu.get_model_embedding_dimensions(model)
|
|
|
|
|
|
|
| 54 |
if emb_mode == "cell":
|
| 55 |
# initiate tdigests for # of emb dims
|
| 56 |
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
|
|
|
| 76 |
gene_token_dict = {v:k for k,v in token_gene_dict.items()}
|
| 77 |
cls_token_id = gene_token_dict["<cls>"]
|
| 78 |
assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
|
| 79 |
+
else:
|
| 80 |
if cls_present:
|
| 81 |
logger.warning("CLS token present in token dictionary, excluding from average.")
|
| 82 |
if eos_present:
|
|
|
|
| 146 |
del embs_h
|
| 147 |
del dict_h
|
| 148 |
elif emb_mode == "cls":
|
| 149 |
+
cls_embs = embs_i[:,0,:] # CLS token layer
|
| 150 |
embs_list.append(cls_embs)
|
| 151 |
del cls_embs
|
| 152 |
|
|
|
|
| 237 |
return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
| 238 |
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
def label_cell_embs(embs, downsampled_data, emb_labels):
|
| 241 |
embs_df = pd.DataFrame(embs.cpu().numpy())
|
| 242 |
if emb_labels is not None:
|
|
|
|
| 622 |
|
| 623 |
if self.exact_summary_stat == "exact_mean":
|
| 624 |
embs = embs.mean(dim=0)
|
| 625 |
+
emb_dims = pu.get_model_embedding_dimensions(model)
|
| 626 |
embs_df = pd.DataFrame(
|
| 627 |
+
embs_df[0:emb_dims-1].mean(axis="rows"), columns=[self.exact_summary_stat]
|
| 628 |
).T
|
| 629 |
elif self.exact_summary_stat == "exact_median":
|
| 630 |
embs = torch.median(embs, dim=0)[0]
|
| 631 |
+
emb_dims = pu.get_model_embedding_dimensions(model)
|
| 632 |
embs_df = pd.DataFrame(
|
| 633 |
+
embs_df[0:emb_dims-1].median(axis="rows"), columns=[self.exact_summary_stat]
|
| 634 |
).T
|
| 635 |
|
| 636 |
if cell_state is not None:
|