Update geneformer/classifier.py
Browse filesPrepare data cell level pass cell state dict instead of genes
- geneformer/classifier.py +9 -3
geneformer/classifier.py
CHANGED
|
@@ -437,14 +437,20 @@ class Classifier:
|
|
| 437 |
)
|
| 438 |
# rename cell state column to "label"
|
| 439 |
data = cu.rename_cols(data, self.cell_state_dict["state_key"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
|
|
|
|
| 441 |
# convert classes to numerical labels and save as id_class_dict
|
| 442 |
# of note, will label all genes in gene_class_dict
|
| 443 |
# if (cross-)validating, genes will be relabeled in column "labels" for each split
|
| 444 |
# at the time of training with Classifier.validate
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
|
| 449 |
# save id_class_dict for future reference
|
| 450 |
id_class_output_path = (
|
|
|
|
| 437 |
)
|
| 438 |
# rename cell state column to "label"
|
| 439 |
data = cu.rename_cols(data, self.cell_state_dict["state_key"])
|
| 440 |
+
|
| 441 |
+
# convert classes to numerical labels and save as id_class_dict
|
| 442 |
+
data, id_class_dict = cu.label_classes(
|
| 443 |
+
self.classifier, data, self.cell_state_dict, self.nproc
|
| 444 |
+
)
|
| 445 |
|
| 446 |
+
elif self.classifier == "gene":
|
| 447 |
# convert classes to numerical labels and save as id_class_dict
|
| 448 |
# of note, will label all genes in gene_class_dict
|
| 449 |
# if (cross-)validating, genes will be relabeled in column "labels" for each split
|
| 450 |
# at the time of training with Classifier.validate
|
| 451 |
+
data, id_class_dict = cu.label_classes(
|
| 452 |
+
self.classifier, data, self.gene_class_dict, self.nproc
|
| 453 |
+
)
|
| 454 |
|
| 455 |
# save id_class_dict for future reference
|
| 456 |
id_class_output_path = (
|