Christina Theodoris commited on
Commit ·
4364d31
1
Parent(s): 76a78a0
move V1 autoformatting to after validate_options
Browse files
geneformer/classifier.py
CHANGED
|
@@ -234,10 +234,6 @@ class Classifier:
|
|
| 234 |
self.token_dictionary_file = token_dictionary_file
|
| 235 |
self.nproc = nproc
|
| 236 |
self.ngpu = ngpu
|
| 237 |
-
|
| 238 |
-
if self.model_version == "V1":
|
| 239 |
-
from . import TOKEN_DICTIONARY_FILE_30M
|
| 240 |
-
self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
| 241 |
|
| 242 |
if self.training_args is None:
|
| 243 |
logger.warning(
|
|
@@ -258,7 +254,10 @@ class Classifier:
|
|
| 258 |
] = self.cell_state_dict["states"]
|
| 259 |
|
| 260 |
# load token dictionary (Ensembl IDs:token)
|
| 261 |
-
if self.
|
|
|
|
|
|
|
|
|
|
| 262 |
self.token_dictionary_file = TOKEN_DICTIONARY_FILE
|
| 263 |
with open(self.token_dictionary_file, "rb") as f:
|
| 264 |
self.gene_token_dict = pickle.load(f)
|
|
|
|
| 234 |
self.token_dictionary_file = token_dictionary_file
|
| 235 |
self.nproc = nproc
|
| 236 |
self.ngpu = ngpu
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
if self.training_args is None:
|
| 239 |
logger.warning(
|
|
|
|
| 254 |
] = self.cell_state_dict["states"]
|
| 255 |
|
| 256 |
# load token dictionary (Ensembl IDs:token)
|
| 257 |
+
if self.model_version == "V1":
|
| 258 |
+
from . import TOKEN_DICTIONARY_FILE_30M
|
| 259 |
+
self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
| 260 |
+
elif self.token_dictionary_file is None:
|
| 261 |
self.token_dictionary_file = TOKEN_DICTIONARY_FILE
|
| 262 |
with open(self.token_dictionary_file, "rb") as f:
|
| 263 |
self.gene_token_dict = pickle.load(f)
|
geneformer/emb_extractor.py
CHANGED
|
@@ -518,6 +518,8 @@ class EmbExtractor:
|
|
| 518 |
self.summary_stat = summary_stat
|
| 519 |
self.exact_summary_stat = None
|
| 520 |
|
|
|
|
|
|
|
| 521 |
if self.model_version == "V1":
|
| 522 |
from . import TOKEN_DICTIONARY_FILE_30M
|
| 523 |
self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
|
@@ -527,8 +529,6 @@ class EmbExtractor:
|
|
| 527 |
"model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
|
| 528 |
)
|
| 529 |
|
| 530 |
-
self.validate_options()
|
| 531 |
-
|
| 532 |
# load token dictionary (Ensembl IDs:token)
|
| 533 |
if self.token_dictionary_file is None:
|
| 534 |
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
|
|
|
| 518 |
self.summary_stat = summary_stat
|
| 519 |
self.exact_summary_stat = None
|
| 520 |
|
| 521 |
+
self.validate_options()
|
| 522 |
+
|
| 523 |
if self.model_version == "V1":
|
| 524 |
from . import TOKEN_DICTIONARY_FILE_30M
|
| 525 |
self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
|
|
|
| 529 |
"model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
|
| 530 |
)
|
| 531 |
|
|
|
|
|
|
|
| 532 |
# load token dictionary (Ensembl IDs:token)
|
| 533 |
if self.token_dictionary_file is None:
|
| 534 |
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
geneformer/in_silico_perturber.py
CHANGED
|
@@ -231,7 +231,9 @@ class InSilicoPerturber:
|
|
| 231 |
self.nproc = nproc
|
| 232 |
self.model_version = model_version
|
| 233 |
self.token_dictionary_file = token_dictionary_file
|
| 234 |
-
self.clear_mem_ncells = clear_mem_ncells
|
|
|
|
|
|
|
| 235 |
|
| 236 |
if self.model_version == "V1":
|
| 237 |
from . import TOKEN_DICTIONARY_FILE_30M
|
|
@@ -245,10 +247,8 @@ class InSilicoPerturber:
|
|
| 245 |
self.emb_mode = "cell_and_gene"
|
| 246 |
logger.warning(
|
| 247 |
"model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a <cls> token."
|
| 248 |
-
)
|
| 249 |
-
|
| 250 |
-
self.validate_options()
|
| 251 |
-
|
| 252 |
# load token dictionary (Ensembl IDs:token)
|
| 253 |
if self.token_dictionary_file is None:
|
| 254 |
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
|
|
|
| 231 |
self.nproc = nproc
|
| 232 |
self.model_version = model_version
|
| 233 |
self.token_dictionary_file = token_dictionary_file
|
| 234 |
+
self.clear_mem_ncells = clear_mem_ncells
|
| 235 |
+
|
| 236 |
+
self.validate_options()
|
| 237 |
|
| 238 |
if self.model_version == "V1":
|
| 239 |
from . import TOKEN_DICTIONARY_FILE_30M
|
|
|
|
| 247 |
self.emb_mode = "cell_and_gene"
|
| 248 |
logger.warning(
|
| 249 |
"model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a <cls> token."
|
| 250 |
+
)
|
| 251 |
+
|
|
|
|
|
|
|
| 252 |
# load token dictionary (Ensembl IDs:token)
|
| 253 |
if self.token_dictionary_file is None:
|
| 254 |
token_dictionary_file = TOKEN_DICTIONARY_FILE
|