pointing dictionaries from the mtl module's init
#397
by
madhavanvenkatesh
- opened
geneformer/mtl/collators.py
CHANGED
|
@@ -1,18 +1,18 @@
|
|
| 1 |
# imports
|
| 2 |
import torch
|
| 3 |
-
|
| 4 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
|
|
|
| 5 |
|
| 6 |
"""
|
| 7 |
Geneformer collator for multi-task cell classification.
|
| 8 |
"""
|
| 9 |
|
| 10 |
-
|
| 11 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
| 12 |
class_type = "cell"
|
| 13 |
|
| 14 |
def __init__(self, *args, **kwargs) -> None:
|
| 15 |
-
|
|
|
|
| 16 |
|
| 17 |
def _prepare_batch(self, features):
|
| 18 |
# Process inputs as usual
|
|
|
|
| 1 |
# imports
|
| 2 |
import torch
|
|
|
|
| 3 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
| 4 |
+
from . import TOKEN_DICTIONARY # import the token dictionary from the mtl module's init
|
| 5 |
|
| 6 |
"""
|
| 7 |
Geneformer collator for multi-task cell classification.
|
| 8 |
"""
|
| 9 |
|
|
|
|
| 10 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
| 11 |
class_type = "cell"
|
| 12 |
|
| 13 |
def __init__(self, *args, **kwargs) -> None:
|
| 14 |
+
# Use the loaded token dictionary from the mtl module's init
|
| 15 |
+
super().__init__(token_dictionary=TOKEN_DICTIONARY, *args, **kwargs)
|
| 16 |
|
| 17 |
def _prepare_batch(self, features):
|
| 18 |
# Process inputs as usual
|