Christina Theodoris
commited on
Commit
·
bfcada4
1
Parent(s):
e2ee685
fix gene class dict labeling
Browse files
geneformer/classifier_utils.py
CHANGED
|
@@ -115,13 +115,20 @@ def label_classes(classifier, data, gene_class_dict, nproc):
|
|
| 115 |
|
| 116 |
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
|
| 117 |
id_class_dict = {v: k for k, v in class_id_dict.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
def classes_to_ids(example):
|
| 120 |
if classifier == "cell":
|
| 121 |
example["label"] = class_id_dict[example["label"]]
|
| 122 |
elif classifier == "gene":
|
| 123 |
example["labels"] = label_gene_classes(
|
| 124 |
-
example, class_id_dict,
|
| 125 |
)
|
| 126 |
return example
|
| 127 |
|
|
@@ -129,9 +136,9 @@ def label_classes(classifier, data, gene_class_dict, nproc):
|
|
| 129 |
return data, id_class_dict
|
| 130 |
|
| 131 |
|
| 132 |
-
def label_gene_classes(example, class_id_dict,
|
| 133 |
return [
|
| 134 |
-
class_id_dict.get(
|
| 135 |
for token_id in example["input_ids"]
|
| 136 |
]
|
| 137 |
|
|
|
|
| 115 |
|
| 116 |
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
|
| 117 |
id_class_dict = {v: k for k, v in class_id_dict.items()}
|
| 118 |
+
inverse_gene_class_dict = {}
|
| 119 |
+
# Iterate over each key and list of values in the original dictionary
|
| 120 |
+
for key, value_list in gene_class_dict.items():
|
| 121 |
+
# Iterate over each value in the list
|
| 122 |
+
for value in value_list:
|
| 123 |
+
# Assign the value as a key and the original key as its value in the new dictionary
|
| 124 |
+
inverse_gene_class_dict[value] = key
|
| 125 |
|
| 126 |
def classes_to_ids(example):
|
| 127 |
if classifier == "cell":
|
| 128 |
example["label"] = class_id_dict[example["label"]]
|
| 129 |
elif classifier == "gene":
|
| 130 |
example["labels"] = label_gene_classes(
|
| 131 |
+
example, class_id_dict, inverse_gene_class_dict
|
| 132 |
)
|
| 133 |
return example
|
| 134 |
|
|
|
|
| 136 |
return data, id_class_dict
|
| 137 |
|
| 138 |
|
| 139 |
+
def label_gene_classes(example, class_id_dict, inverse_gene_class_dict):
|
| 140 |
return [
|
| 141 |
+
class_id_dict.get(inverse_gene_class_dict.get(token_id, -100), -100)
|
| 142 |
for token_id in example["input_ids"]
|
| 143 |
]
|
| 144 |
|