incorporate prior changes
Browse files
geneformer/emb_extractor.py
CHANGED
|
@@ -49,8 +49,8 @@ def get_embs(
|
|
| 49 |
if summary_stat is None:
|
| 50 |
embs_list = []
|
| 51 |
elif summary_stat is not None:
|
| 52 |
-
#
|
| 53 |
-
emb_dims = pu.
|
| 54 |
if emb_mode == "cell":
|
| 55 |
# initiate tdigests for # of emb dims
|
| 56 |
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
|
@@ -76,7 +76,7 @@ def get_embs(
|
|
| 76 |
gene_token_dict = {v:k for k,v in token_gene_dict.items()}
|
| 77 |
cls_token_id = gene_token_dict["<cls>"]
|
| 78 |
assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
|
| 79 |
-
|
| 80 |
if cls_present:
|
| 81 |
logger.warning("CLS token present in token dictionary, excluding from average.")
|
| 82 |
if eos_present:
|
|
@@ -146,7 +146,7 @@ def get_embs(
|
|
| 146 |
del embs_h
|
| 147 |
del dict_h
|
| 148 |
elif emb_mode == "cls":
|
| 149 |
-
cls_embs = embs_i[:,0,:] # CLS token layer
|
| 150 |
embs_list.append(cls_embs)
|
| 151 |
del cls_embs
|
| 152 |
|
|
|
|
| 49 |
if summary_stat is None:
|
| 50 |
embs_list = []
|
| 51 |
elif summary_stat is not None:
|
| 52 |
+
# get # of emb dims
|
| 53 |
+
emb_dims = pu.get_model_emb_dims(model)
|
| 54 |
if emb_mode == "cell":
|
| 55 |
# initiate tdigests for # of emb dims
|
| 56 |
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
|
|
|
| 76 |
gene_token_dict = {v:k for k,v in token_gene_dict.items()}
|
| 77 |
cls_token_id = gene_token_dict["<cls>"]
|
| 78 |
assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
|
| 79 |
+
elif emb_mode == "cell":
|
| 80 |
if cls_present:
|
| 81 |
logger.warning("CLS token present in token dictionary, excluding from average.")
|
| 82 |
if eos_present:
|
|
|
|
| 146 |
del embs_h
|
| 147 |
del dict_h
|
| 148 |
elif emb_mode == "cls":
|
| 149 |
+
cls_embs = embs_i[:,0,:].clone().detach() # CLS token layer
|
| 150 |
embs_list.append(cls_embs)
|
| 151 |
del cls_embs
|
| 152 |
|