pull_geneformer
#557
by
spartanzhao
- opened
- geneformer/classifier.py +3 -22
- geneformer/classifier_utils.py +1 -22
- geneformer/emb_extractor.py +20 -68
- geneformer/evaluation_utils.py +2 -17
- geneformer/in_silico_perturber_stats.py +0 -3
- geneformer/perturber_utils.py +0 -5
- geneformer/tokenizer.py +18 -24
geneformer/classifier.py
CHANGED
|
@@ -801,7 +801,7 @@ class Classifier:
|
|
| 801 |
# 5-fold cross-validate
|
| 802 |
num_cells = len(data)
|
| 803 |
fifth_cells = int(np.floor(num_cells * 0.2))
|
| 804 |
-
num_eval =
|
| 805 |
start = i * fifth_cells
|
| 806 |
end = start + num_eval
|
| 807 |
eval_indices = [j for j in range(start, end)]
|
|
@@ -1313,7 +1313,6 @@ class Classifier:
|
|
| 1313 |
predict=False,
|
| 1314 |
output_directory=None,
|
| 1315 |
output_prefix=None,
|
| 1316 |
-
predict_metadata=None,
|
| 1317 |
):
|
| 1318 |
"""
|
| 1319 |
Evaluate the fine-tuned model.
|
|
@@ -1339,11 +1338,9 @@ class Classifier:
|
|
| 1339 |
|
| 1340 |
##### Evaluate the model #####
|
| 1341 |
labels = id_class_dict.keys()
|
| 1342 |
-
|
| 1343 |
-
|
| 1344 |
-
model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict, predict_metadata
|
| 1345 |
)
|
| 1346 |
-
|
| 1347 |
conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
|
| 1348 |
y_pred, y_true, logits_list, num_classes, labels
|
| 1349 |
)
|
|
@@ -1353,9 +1350,6 @@ class Classifier:
|
|
| 1353 |
"label_ids": y_true,
|
| 1354 |
"predictions": logits_list,
|
| 1355 |
}
|
| 1356 |
-
if predict_metadata is not None:
|
| 1357 |
-
pred_dict["prediction_metadata"] = predict_metadata_all
|
| 1358 |
-
|
| 1359 |
pred_dict_output_path = (
|
| 1360 |
Path(output_directory) / f"{output_prefix}_pred_dict"
|
| 1361 |
).with_suffix(".pkl")
|
|
@@ -1376,7 +1370,6 @@ class Classifier:
|
|
| 1376 |
output_directory,
|
| 1377 |
output_prefix,
|
| 1378 |
predict=True,
|
| 1379 |
-
predict_metadata=None,
|
| 1380 |
):
|
| 1381 |
"""
|
| 1382 |
Evaluate the fine-tuned model.
|
|
@@ -1396,8 +1389,6 @@ class Classifier:
|
|
| 1396 |
| Prefix for output files
|
| 1397 |
predict : bool
|
| 1398 |
| Whether or not to save eval predictions
|
| 1399 |
-
predict_metadata : None | list
|
| 1400 |
-
| Metadata labels to output with predictions (columns in test_data_file)
|
| 1401 |
"""
|
| 1402 |
|
| 1403 |
# load numerical id to class dictionary (id:class)
|
|
@@ -1410,15 +1401,6 @@ class Classifier:
|
|
| 1410 |
# load previously filtered and prepared data
|
| 1411 |
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
| 1412 |
|
| 1413 |
-
if predict_metadata is not None:
|
| 1414 |
-
absent_metadata = []
|
| 1415 |
-
for predict_metadata_x in predict_metadata:
|
| 1416 |
-
if predict_metadata_x not in test_data.features.keys():
|
| 1417 |
-
absent_metadata += [predict_metadata_x]
|
| 1418 |
-
if len(absent_metadata)>0:
|
| 1419 |
-
logger.error(f"Following predict_metadata was not found as column in test_data_file: {absent_metadata}")
|
| 1420 |
-
raise
|
| 1421 |
-
|
| 1422 |
# load previously fine-tuned model
|
| 1423 |
model = pu.load_model(
|
| 1424 |
self.model_type,
|
|
@@ -1437,7 +1419,6 @@ class Classifier:
|
|
| 1437 |
predict=predict,
|
| 1438 |
output_directory=output_directory,
|
| 1439 |
output_prefix=output_prefix,
|
| 1440 |
-
predict_metadata=predict_metadata,
|
| 1441 |
)
|
| 1442 |
|
| 1443 |
all_conf_mat_df = pd.DataFrame(
|
|
|
|
| 801 |
# 5-fold cross-validate
|
| 802 |
num_cells = len(data)
|
| 803 |
fifth_cells = int(np.floor(num_cells * 0.2))
|
| 804 |
+
num_eval = min((self.eval_size * num_cells), fifth_cells)
|
| 805 |
start = i * fifth_cells
|
| 806 |
end = start + num_eval
|
| 807 |
eval_indices = [j for j in range(start, end)]
|
|
|
|
| 1313 |
predict=False,
|
| 1314 |
output_directory=None,
|
| 1315 |
output_prefix=None,
|
|
|
|
| 1316 |
):
|
| 1317 |
"""
|
| 1318 |
Evaluate the fine-tuned model.
|
|
|
|
| 1338 |
|
| 1339 |
##### Evaluate the model #####
|
| 1340 |
labels = id_class_dict.keys()
|
| 1341 |
+
y_pred, y_true, logits_list = eu.classifier_predict(
|
| 1342 |
+
model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict
|
|
|
|
| 1343 |
)
|
|
|
|
| 1344 |
conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
|
| 1345 |
y_pred, y_true, logits_list, num_classes, labels
|
| 1346 |
)
|
|
|
|
| 1350 |
"label_ids": y_true,
|
| 1351 |
"predictions": logits_list,
|
| 1352 |
}
|
|
|
|
|
|
|
|
|
|
| 1353 |
pred_dict_output_path = (
|
| 1354 |
Path(output_directory) / f"{output_prefix}_pred_dict"
|
| 1355 |
).with_suffix(".pkl")
|
|
|
|
| 1370 |
output_directory,
|
| 1371 |
output_prefix,
|
| 1372 |
predict=True,
|
|
|
|
| 1373 |
):
|
| 1374 |
"""
|
| 1375 |
Evaluate the fine-tuned model.
|
|
|
|
| 1389 |
| Prefix for output files
|
| 1390 |
predict : bool
|
| 1391 |
| Whether or not to save eval predictions
|
|
|
|
|
|
|
| 1392 |
"""
|
| 1393 |
|
| 1394 |
# load numerical id to class dictionary (id:class)
|
|
|
|
| 1401 |
# load previously filtered and prepared data
|
| 1402 |
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
| 1403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1404 |
# load previously fine-tuned model
|
| 1405 |
model = pu.load_model(
|
| 1406 |
self.model_type,
|
|
|
|
| 1419 |
predict=predict,
|
| 1420 |
output_directory=output_directory,
|
| 1421 |
output_prefix=output_prefix,
|
|
|
|
| 1422 |
)
|
| 1423 |
|
| 1424 |
all_conf_mat_df = pd.DataFrame(
|
geneformer/classifier_utils.py
CHANGED
|
@@ -94,7 +94,7 @@ def remove_rare(data, rare_threshold, label, nproc):
|
|
| 94 |
return data
|
| 95 |
|
| 96 |
|
| 97 |
-
def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict
|
| 98 |
if classifier == "cell":
|
| 99 |
label_set = set(data["label"])
|
| 100 |
elif classifier == "gene":
|
|
@@ -570,27 +570,6 @@ def compute_metrics(pred):
|
|
| 570 |
return {"accuracy": acc, "macro_f1": macro_f1}
|
| 571 |
|
| 572 |
|
| 573 |
-
def robust_compute_objective(metrics: dict):
|
| 574 |
-
# tries both prefixed ("eval_") and raw metric names to support different transformers versions
|
| 575 |
-
metric_name = "macro_f1"
|
| 576 |
-
|
| 577 |
-
# check for the prefixed version
|
| 578 |
-
prefixed_metric_name = f"eval_{metric_name}"
|
| 579 |
-
if prefixed_metric_name in metrics:
|
| 580 |
-
return metrics[prefixed_metric_name]
|
| 581 |
-
|
| 582 |
-
# fall back to the raw name
|
| 583 |
-
elif metric_name in metrics:
|
| 584 |
-
return metrics[metric_name]
|
| 585 |
-
|
| 586 |
-
# if neither is found, raise a clear error to help with debugging
|
| 587 |
-
raise KeyError(
|
| 588 |
-
f"Could not find '{prefixed_metric_name}' or '{metric_name}' in the reported metrics. "
|
| 589 |
-
f"Please check your `compute_metrics` function and `TrainingArguments`. "
|
| 590 |
-
f"Available metrics: {list(metrics.keys())}"
|
| 591 |
-
)
|
| 592 |
-
|
| 593 |
-
|
| 594 |
def get_default_train_args(model, classifier, data, output_dir):
|
| 595 |
num_layers = pu.quant_layers(model)
|
| 596 |
freeze_layers = 0
|
|
|
|
| 94 |
return data
|
| 95 |
|
| 96 |
|
| 97 |
+
def label_classes(classifier, data, gene_class_dict, nproc, id_class_dict):
|
| 98 |
if classifier == "cell":
|
| 99 |
label_set = set(data["label"])
|
| 100 |
elif classifier == "gene":
|
|
|
|
| 570 |
return {"accuracy": acc, "macro_f1": macro_f1}
|
| 571 |
|
| 572 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
def get_default_train_args(model, classifier, data, output_dir):
|
| 574 |
num_layers = pu.quant_layers(model)
|
| 575 |
freeze_layers = 0
|
geneformer/emb_extractor.py
CHANGED
|
@@ -42,8 +42,6 @@ def get_embs(
|
|
| 42 |
special_token=False,
|
| 43 |
summary_stat=None,
|
| 44 |
silent=False,
|
| 45 |
-
save_tdigest=False,
|
| 46 |
-
tdigest_path=None,
|
| 47 |
):
|
| 48 |
model_input_size = pu.get_model_input_size(model)
|
| 49 |
total_batch_length = len(filtered_input_data)
|
|
@@ -182,18 +180,12 @@ def get_embs(
|
|
| 182 |
# calculate summary stat embs from approximated tdigests
|
| 183 |
elif summary_stat is not None:
|
| 184 |
if emb_mode == "cell":
|
| 185 |
-
if save_tdigest:
|
| 186 |
-
with open(f"{tdigest_path}","wb") as fp:
|
| 187 |
-
pickle.dump(embs_tdigests, fp)
|
| 188 |
if summary_stat == "mean":
|
| 189 |
summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
|
| 190 |
elif summary_stat == "median":
|
| 191 |
summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
|
| 192 |
embs_stack = torch.tensor(summary_emb_list)
|
| 193 |
elif emb_mode == "gene":
|
| 194 |
-
if save_tdigest:
|
| 195 |
-
with open(f"{tdigest_path}","wb") as fp:
|
| 196 |
-
pickle.dump(embs_tdigests_dict, fp)
|
| 197 |
if summary_stat == "mean":
|
| 198 |
[
|
| 199 |
update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
|
|
@@ -260,7 +252,7 @@ def label_cell_embs(embs, downsampled_data, emb_labels):
|
|
| 260 |
return embs_df
|
| 261 |
|
| 262 |
|
| 263 |
-
def label_gene_embs(embs, downsampled_data, token_gene_dict
|
| 264 |
gene_set = {
|
| 265 |
element for sublist in downsampled_data["input_ids"] for element in sublist
|
| 266 |
}
|
|
@@ -275,39 +267,16 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mea
|
|
| 275 |
)
|
| 276 |
for k in dict_i.keys():
|
| 277 |
gene_emb_dict[k].append(dict_i[k])
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
gene_emb_dict[k] =
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
embs_df = pd.DataFrame(gene_emb_dict).T
|
| 286 |
-
else:
|
| 287 |
-
embs_df = dict_lol_to_df(gene_emb_dict)
|
| 288 |
embs_df.index = [token_gene_dict[token] for token in embs_df.index]
|
| 289 |
return embs_df
|
| 290 |
|
| 291 |
-
def dict_lol_to_df(data_dict):
|
| 292 |
-
# save dictionary with values being list of equal-length lists as dataframe
|
| 293 |
-
df_data = []
|
| 294 |
-
for key, list_of_lists in data_dict.items():
|
| 295 |
-
for i, sublist in enumerate(list_of_lists):
|
| 296 |
-
row_data = [key, i] + sublist.tolist()
|
| 297 |
-
df_data.append(row_data)
|
| 298 |
-
|
| 299 |
-
# determine column names based on the length of sublists
|
| 300 |
-
# assuming all sublists have the same length
|
| 301 |
-
num_columns_from_sublist = len(list(data_dict.values())[0][0])
|
| 302 |
-
column_names = ['Gene', 'Identifier'] + [f'{j}' for j in range(num_columns_from_sublist)]
|
| 303 |
-
|
| 304 |
-
# create the dataframe
|
| 305 |
-
df = pd.DataFrame(df_data, columns=column_names)
|
| 306 |
-
|
| 307 |
-
# set 'Gene' as the index
|
| 308 |
-
df = df.set_index('Gene')
|
| 309 |
-
|
| 310 |
-
return df
|
| 311 |
|
| 312 |
def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
|
| 313 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
|
@@ -435,7 +404,7 @@ class EmbExtractor:
|
|
| 435 |
"num_classes": {int},
|
| 436 |
"emb_mode": {"cls", "cell", "gene"},
|
| 437 |
"cell_emb_style": {"mean_pool"},
|
| 438 |
-
"gene_emb_style": {"mean_pool"
|
| 439 |
"filter_data": {None, dict},
|
| 440 |
"max_ncells": {None, int},
|
| 441 |
"emb_layer": {-1, 0},
|
|
@@ -463,7 +432,6 @@ class EmbExtractor:
|
|
| 463 |
forward_batch_size=100,
|
| 464 |
nproc=4,
|
| 465 |
summary_stat=None,
|
| 466 |
-
save_tdigest=False,
|
| 467 |
model_version="V2",
|
| 468 |
token_dictionary_file=None,
|
| 469 |
):
|
|
@@ -483,9 +451,9 @@ class EmbExtractor:
|
|
| 483 |
cell_emb_style : {"mean_pool"}
|
| 484 |
| Method for summarizing cell embeddings if not using CLS token.
|
| 485 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
| 486 |
-
gene_emb_style :
|
| 487 |
| Method for summarizing gene embeddings.
|
| 488 |
-
| Currently only option is
|
| 489 |
filter_data : None, dict
|
| 490 |
| Default is to extract embeddings from all input data.
|
| 491 |
| Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
|
@@ -515,9 +483,6 @@ class EmbExtractor:
|
|
| 515 |
| If mean or median, outputs only approximated mean or median embedding of input data.
|
| 516 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
| 517 |
| Non-exact is slower but more memory-efficient.
|
| 518 |
-
save_tdigest : bool
|
| 519 |
-
| Whether to save a dictionary of tdigests for each gene and embedding dimension
|
| 520 |
-
| Only applies when summary_stat is not None
|
| 521 |
model_version : str
|
| 522 |
| To auto-select settings for model version other than current default.
|
| 523 |
| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
|
|
@@ -561,16 +526,9 @@ class EmbExtractor:
|
|
| 561 |
else:
|
| 562 |
self.summary_stat = summary_stat
|
| 563 |
self.exact_summary_stat = None
|
| 564 |
-
self.save_tdigest = save_tdigest
|
| 565 |
|
| 566 |
self.validate_options()
|
| 567 |
|
| 568 |
-
if (summary_stat is None) and (save_tdigest is True):
|
| 569 |
-
logger.warning(
|
| 570 |
-
"tdigests will not be saved since summary_stat is None."
|
| 571 |
-
)
|
| 572 |
-
save_tdigest = False
|
| 573 |
-
|
| 574 |
if self.model_version == "V1":
|
| 575 |
from . import TOKEN_DICTIONARY_FILE_30M
|
| 576 |
self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
|
@@ -677,10 +635,6 @@ class EmbExtractor:
|
|
| 677 |
self.model_type, self.num_classes, model_directory, mode="eval"
|
| 678 |
)
|
| 679 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
| 680 |
-
if self.save_tdigest:
|
| 681 |
-
tdigest_path = (Path(output_directory) / f"{output_prefix}_tdigest").with_suffix(".pkl")
|
| 682 |
-
else:
|
| 683 |
-
tdigest_path = None
|
| 684 |
embs = get_embs(
|
| 685 |
model=model,
|
| 686 |
filtered_input_data=downsampled_data,
|
|
@@ -690,8 +644,6 @@ class EmbExtractor:
|
|
| 690 |
forward_batch_size=self.forward_batch_size,
|
| 691 |
token_gene_dict=self.token_gene_dict,
|
| 692 |
summary_stat=self.summary_stat,
|
| 693 |
-
save_tdigest=self.save_tdigest,
|
| 694 |
-
tdigest_path=tdigest_path,
|
| 695 |
)
|
| 696 |
|
| 697 |
if self.emb_mode == "cell":
|
|
@@ -701,7 +653,7 @@ class EmbExtractor:
|
|
| 701 |
embs_df = pd.DataFrame(embs.cpu().numpy()).T
|
| 702 |
elif self.emb_mode == "gene":
|
| 703 |
if self.summary_stat is None:
|
| 704 |
-
embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict
|
| 705 |
elif self.summary_stat is not None:
|
| 706 |
embs_df = pd.DataFrame(embs).T
|
| 707 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
|
@@ -885,14 +837,14 @@ class EmbExtractor:
|
|
| 885 |
raise
|
| 886 |
|
| 887 |
if max_ncells_to_plot is not None:
|
| 888 |
-
if self.max_ncells
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
|
| 897 |
if self.emb_label is None:
|
| 898 |
label_len = 0
|
|
|
|
| 42 |
special_token=False,
|
| 43 |
summary_stat=None,
|
| 44 |
silent=False,
|
|
|
|
|
|
|
| 45 |
):
|
| 46 |
model_input_size = pu.get_model_input_size(model)
|
| 47 |
total_batch_length = len(filtered_input_data)
|
|
|
|
| 180 |
# calculate summary stat embs from approximated tdigests
|
| 181 |
elif summary_stat is not None:
|
| 182 |
if emb_mode == "cell":
|
|
|
|
|
|
|
|
|
|
| 183 |
if summary_stat == "mean":
|
| 184 |
summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
|
| 185 |
elif summary_stat == "median":
|
| 186 |
summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
|
| 187 |
embs_stack = torch.tensor(summary_emb_list)
|
| 188 |
elif emb_mode == "gene":
|
|
|
|
|
|
|
|
|
|
| 189 |
if summary_stat == "mean":
|
| 190 |
[
|
| 191 |
update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
|
|
|
|
| 252 |
return embs_df
|
| 253 |
|
| 254 |
|
| 255 |
+
def label_gene_embs(embs, downsampled_data, token_gene_dict):
|
| 256 |
gene_set = {
|
| 257 |
element for sublist in downsampled_data["input_ids"] for element in sublist
|
| 258 |
}
|
|
|
|
| 267 |
)
|
| 268 |
for k in dict_i.keys():
|
| 269 |
gene_emb_dict[k].append(dict_i[k])
|
| 270 |
+
for k in gene_emb_dict.keys():
|
| 271 |
+
gene_emb_dict[k] = (
|
| 272 |
+
torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
|
| 273 |
+
.cpu()
|
| 274 |
+
.numpy()
|
| 275 |
+
)
|
| 276 |
+
embs_df = pd.DataFrame(gene_emb_dict).T
|
|
|
|
|
|
|
|
|
|
| 277 |
embs_df.index = [token_gene_dict[token] for token in embs_df.index]
|
| 278 |
return embs_df
|
| 279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
|
| 281 |
def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
|
| 282 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
|
|
|
| 404 |
"num_classes": {int},
|
| 405 |
"emb_mode": {"cls", "cell", "gene"},
|
| 406 |
"cell_emb_style": {"mean_pool"},
|
| 407 |
+
"gene_emb_style": {"mean_pool"},
|
| 408 |
"filter_data": {None, dict},
|
| 409 |
"max_ncells": {None, int},
|
| 410 |
"emb_layer": {-1, 0},
|
|
|
|
| 432 |
forward_batch_size=100,
|
| 433 |
nproc=4,
|
| 434 |
summary_stat=None,
|
|
|
|
| 435 |
model_version="V2",
|
| 436 |
token_dictionary_file=None,
|
| 437 |
):
|
|
|
|
| 451 |
cell_emb_style : {"mean_pool"}
|
| 452 |
| Method for summarizing cell embeddings if not using CLS token.
|
| 453 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
| 454 |
+
gene_emb_style : "mean_pool"
|
| 455 |
| Method for summarizing gene embeddings.
|
| 456 |
+
| Currently only option is mean pooling of contextual gene embeddings for given gene.
|
| 457 |
filter_data : None, dict
|
| 458 |
| Default is to extract embeddings from all input data.
|
| 459 |
| Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
|
|
|
| 483 |
| If mean or median, outputs only approximated mean or median embedding of input data.
|
| 484 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
| 485 |
| Non-exact is slower but more memory-efficient.
|
|
|
|
|
|
|
|
|
|
| 486 |
model_version : str
|
| 487 |
| To auto-select settings for model version other than current default.
|
| 488 |
| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
|
|
|
|
| 526 |
else:
|
| 527 |
self.summary_stat = summary_stat
|
| 528 |
self.exact_summary_stat = None
|
|
|
|
| 529 |
|
| 530 |
self.validate_options()
|
| 531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
if self.model_version == "V1":
|
| 533 |
from . import TOKEN_DICTIONARY_FILE_30M
|
| 534 |
self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
|
|
|
| 635 |
self.model_type, self.num_classes, model_directory, mode="eval"
|
| 636 |
)
|
| 637 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
embs = get_embs(
|
| 639 |
model=model,
|
| 640 |
filtered_input_data=downsampled_data,
|
|
|
|
| 644 |
forward_batch_size=self.forward_batch_size,
|
| 645 |
token_gene_dict=self.token_gene_dict,
|
| 646 |
summary_stat=self.summary_stat,
|
|
|
|
|
|
|
| 647 |
)
|
| 648 |
|
| 649 |
if self.emb_mode == "cell":
|
|
|
|
| 653 |
embs_df = pd.DataFrame(embs.cpu().numpy()).T
|
| 654 |
elif self.emb_mode == "gene":
|
| 655 |
if self.summary_stat is None:
|
| 656 |
+
embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
|
| 657 |
elif self.summary_stat is not None:
|
| 658 |
embs_df = pd.DataFrame(embs).T
|
| 659 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
|
|
|
| 837 |
raise
|
| 838 |
|
| 839 |
if max_ncells_to_plot is not None:
|
| 840 |
+
if max_ncells_to_plot > self.max_ncells:
|
| 841 |
+
max_ncells_to_plot = self.max_ncells
|
| 842 |
+
logger.warning(
|
| 843 |
+
"max_ncells_to_plot must be <= max_ncells. "
|
| 844 |
+
f"Changing max_ncells_to_plot to {self.max_ncells}."
|
| 845 |
+
)
|
| 846 |
+
elif max_ncells_to_plot < self.max_ncells:
|
| 847 |
+
embs = embs.sample(max_ncells_to_plot, axis=0)
|
| 848 |
|
| 849 |
if self.emb_label is None:
|
| 850 |
label_len = 0
|
geneformer/evaluation_utils.py
CHANGED
|
@@ -77,7 +77,7 @@ def py_softmax(vector):
|
|
| 77 |
return e / e.sum()
|
| 78 |
|
| 79 |
|
| 80 |
-
def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict
|
| 81 |
if classifier_type == "gene":
|
| 82 |
label_name = "labels"
|
| 83 |
elif classifier_type == "cell":
|
|
@@ -85,14 +85,6 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
|
|
| 85 |
|
| 86 |
predict_logits = []
|
| 87 |
predict_labels = []
|
| 88 |
-
|
| 89 |
-
predict_metadata_all = None
|
| 90 |
-
|
| 91 |
-
if predict_metadata is not None:
|
| 92 |
-
predict_metadata_all = dict()
|
| 93 |
-
for metadata_name in predict_metadata:
|
| 94 |
-
predict_metadata_all[metadata_name] = []
|
| 95 |
-
|
| 96 |
model.eval()
|
| 97 |
|
| 98 |
# ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
|
|
@@ -107,15 +99,9 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
|
|
| 107 |
for i in trange(0, evalset_len, forward_batch_size):
|
| 108 |
max_range = min(i + forward_batch_size, evalset_len)
|
| 109 |
batch_evalset = evalset.select([i for i in range(i, max_range)])
|
| 110 |
-
|
| 111 |
-
if predict_metadata is not None:
|
| 112 |
-
for metadata_name in predict_metadata:
|
| 113 |
-
predict_metadata_all[metadata_name] += batch_evalset[metadata_name]
|
| 114 |
-
|
| 115 |
padded_batch = preprocess_classifier_batch(
|
| 116 |
batch_evalset, max_evalset_len, label_name, gene_token_dict
|
| 117 |
)
|
| 118 |
-
|
| 119 |
padded_batch.set_format(type="torch")
|
| 120 |
|
| 121 |
# For datasets>=4.0.0, convert to dict to avoid format issues
|
|
@@ -148,8 +134,7 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
|
|
| 148 |
y_pred = [vote(item[0]) for item in logit_label_paired]
|
| 149 |
y_true = [item[1] for item in logit_label_paired]
|
| 150 |
logits_list = [item[0] for item in logit_label_paired]
|
| 151 |
-
|
| 152 |
-
return y_pred, y_true, logits_list, predict_metadata_all
|
| 153 |
|
| 154 |
|
| 155 |
def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
|
|
|
|
| 77 |
return e / e.sum()
|
| 78 |
|
| 79 |
|
| 80 |
+
def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict):
|
| 81 |
if classifier_type == "gene":
|
| 82 |
label_name = "labels"
|
| 83 |
elif classifier_type == "cell":
|
|
|
|
| 85 |
|
| 86 |
predict_logits = []
|
| 87 |
predict_labels = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
model.eval()
|
| 89 |
|
| 90 |
# ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
|
|
|
|
| 99 |
for i in trange(0, evalset_len, forward_batch_size):
|
| 100 |
max_range = min(i + forward_batch_size, evalset_len)
|
| 101 |
batch_evalset = evalset.select([i for i in range(i, max_range)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
padded_batch = preprocess_classifier_batch(
|
| 103 |
batch_evalset, max_evalset_len, label_name, gene_token_dict
|
| 104 |
)
|
|
|
|
| 105 |
padded_batch.set_format(type="torch")
|
| 106 |
|
| 107 |
# For datasets>=4.0.0, convert to dict to avoid format issues
|
|
|
|
| 134 |
y_pred = [vote(item[0]) for item in logit_label_paired]
|
| 135 |
y_true = [item[1] for item in logit_label_paired]
|
| 136 |
logits_list = [item[0] for item in logit_label_paired]
|
| 137 |
+
return y_pred, y_true, logits_list
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
|
geneformer/in_silico_perturber_stats.py
CHANGED
|
@@ -726,9 +726,6 @@ class InSilicoPerturberStats:
|
|
| 726 |
| "start_state": "dcm",
|
| 727 |
| "goal_state": "nf",
|
| 728 |
| "alt_states": ["hcm", "other1", "other2"]}
|
| 729 |
-
pickle_suffix : None, str
|
| 730 |
-
| Suffix to subselect intermediate raw files for analysis.
|
| 731 |
-
| Default output of InSilicoPerturber uses suffix "_raw.pickle".
|
| 732 |
model_version : str
|
| 733 |
| To auto-select settings for model version other than current default.
|
| 734 |
| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
|
|
|
|
| 726 |
| "start_state": "dcm",
|
| 727 |
| "goal_state": "nf",
|
| 728 |
| "alt_states": ["hcm", "other1", "other2"]}
|
|
|
|
|
|
|
|
|
|
| 729 |
model_version : str
|
| 730 |
| To auto-select settings for model version other than current default.
|
| 731 |
| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
|
geneformer/perturber_utils.py
CHANGED
|
@@ -508,11 +508,6 @@ def make_perturbation_batch(
|
|
| 508 |
def make_perturbation_batch_special(
|
| 509 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 510 |
) -> tuple[Dataset, List[int]]:
|
| 511 |
-
|
| 512 |
-
# For datasets>=4.0.0, convert to dict to avoid format issues
|
| 513 |
-
if int(datasets.__version__.split(".")[0]) >= 4:
|
| 514 |
-
example_cell = example_cell[:]
|
| 515 |
-
|
| 516 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 517 |
if perturb_type in ["overexpress", "activate"]:
|
| 518 |
range_start = 1
|
|
|
|
| 508 |
def make_perturbation_batch_special(
|
| 509 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 510 |
) -> tuple[Dataset, List[int]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 512 |
if perturb_type in ["overexpress", "activate"]:
|
| 513 |
range_start = 1
|
geneformer/tokenizer.py
CHANGED
|
@@ -3,7 +3,7 @@ Geneformer tokenizer.
|
|
| 3 |
|
| 4 |
**Input data:**
|
| 5 |
|
| 6 |
-
| *Required format:* raw counts scRNAseq data without feature selection as .loom
|
| 7 |
| *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
|
| 8 |
| *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
|
| 9 |
|
|
@@ -20,9 +20,9 @@ Geneformer tokenizer.
|
|
| 20 |
|
| 21 |
**Description:**
|
| 22 |
|
| 23 |
-
| Input data is a directory with .loom
|
| 24 |
|
| 25 |
-
| The discussion below references the .loom file format, but the analogous labels are required for .h5ad
|
| 26 |
|
| 27 |
| Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
|
| 28 |
|
|
@@ -30,7 +30,7 @@ Geneformer tokenizer.
|
|
| 30 |
|
| 31 |
| Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
|
| 32 |
|
| 33 |
-
| If one's data is in other formats besides .loom
|
| 34 |
|
| 35 |
| OF NOTE: Use model_version to auto-select settings for model version other than current default. For V1 model series (original Geneformer pretrained in 2021 on ~30M cells), one must use correct corresponding token dictionary and gene median file, set special_token to False, and set model_input_size to 2048. This argument enables auto-selection of these settings. (For V2 model series, special_token must be True and model_input_size is 4096.)
|
| 36 |
|
|
@@ -46,7 +46,6 @@ from collections import Counter
|
|
| 46 |
from pathlib import Path
|
| 47 |
from typing import Literal
|
| 48 |
|
| 49 |
-
import anndata as ad
|
| 50 |
import loompy as lp
|
| 51 |
import numpy as np
|
| 52 |
import pandas as pd
|
|
@@ -201,16 +200,13 @@ def sum_ensembl_ids(
|
|
| 201 |
dsout.add_columns(processed_array, col_attrs=view.ca)
|
| 202 |
return dedup_filename
|
| 203 |
|
| 204 |
-
elif file_format
|
| 205 |
"""
|
| 206 |
Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
|
| 207 |
Returns adata object with deduplicated Ensembl IDs.
|
| 208 |
"""
|
| 209 |
|
| 210 |
-
|
| 211 |
-
data = sc.read_h5ad(str(data_directory))
|
| 212 |
-
else: # zarr
|
| 213 |
-
data = ad.read_zarr(str(data_directory))
|
| 214 |
|
| 215 |
if use_h5ad_index:
|
| 216 |
data.var["ensembl_id"] = list(data.var.index)
|
|
@@ -240,7 +236,7 @@ def sum_ensembl_ids(
|
|
| 240 |
gene for gene in ensembl_ids if gene in gene_token_dict.keys()
|
| 241 |
]
|
| 242 |
if len(ensembl_id_check) == len(set(ensembl_id_check)):
|
| 243 |
-
return
|
| 244 |
else:
|
| 245 |
raise ValueError("Error: data Ensembl IDs non-unique.")
|
| 246 |
|
|
@@ -439,7 +435,7 @@ class TranscriptomeTokenizer:
|
|
| 439 |
data_directory: Path | str,
|
| 440 |
output_directory: Path | str,
|
| 441 |
output_prefix: str,
|
| 442 |
-
file_format: Literal["loom", "h5ad"
|
| 443 |
input_identifier: str = "",
|
| 444 |
use_generator: bool = False,
|
| 445 |
):
|
|
@@ -455,9 +451,9 @@ class TranscriptomeTokenizer:
|
|
| 455 |
output_prefix : str
|
| 456 |
| Prefix for output .dataset
|
| 457 |
file_format : str
|
| 458 |
-
| Format of input files. Can be "loom"
|
| 459 |
input_identifier : str
|
| 460 |
-
| Substring identifier for input .loom
|
| 461 |
| Default is no identifier, tokenizes all files in provided directory.
|
| 462 |
use_generator : bool
|
| 463 |
| Whether to use generator or dict for tokenization.
|
|
@@ -477,7 +473,7 @@ class TranscriptomeTokenizer:
|
|
| 477 |
tokenized_dataset.save_to_disk(str(output_path))
|
| 478 |
|
| 479 |
def tokenize_files(
|
| 480 |
-
self, data_directory, file_format: Literal["loom", "h5ad"
|
| 481 |
):
|
| 482 |
tokenized_cells = []
|
| 483 |
tokenized_counts = []
|
|
@@ -489,7 +485,7 @@ class TranscriptomeTokenizer:
|
|
| 489 |
|
| 490 |
# loops through directories to tokenize .loom files
|
| 491 |
file_found = 0
|
| 492 |
-
# loops through directories to tokenize .loom
|
| 493 |
tokenize_file_fn = (
|
| 494 |
self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
|
| 495 |
)
|
|
@@ -500,7 +496,7 @@ class TranscriptomeTokenizer:
|
|
| 500 |
for file_path in data_directory.glob(file_match):
|
| 501 |
file_found = 1
|
| 502 |
print(f"Tokenizing {file_path}")
|
| 503 |
-
file_tokenized_cells, file_cell_metadata, file_tokenized_counts = tokenize_file_fn(file_path
|
| 504 |
tokenized_cells += file_tokenized_cells
|
| 505 |
tokenized_counts += file_tokenized_counts
|
| 506 |
if self.custom_attr_name_dict is not None:
|
|
@@ -518,7 +514,7 @@ class TranscriptomeTokenizer:
|
|
| 518 |
raise
|
| 519 |
return tokenized_cells, cell_metadata, tokenized_counts
|
| 520 |
|
| 521 |
-
def tokenize_anndata(self, adata_file_path, target_sum=10_000
|
| 522 |
adata = sum_ensembl_ids(
|
| 523 |
adata_file_path,
|
| 524 |
self.collapse_gene_ids,
|
|
@@ -526,7 +522,7 @@ class TranscriptomeTokenizer:
|
|
| 526 |
self.gene_token_dict,
|
| 527 |
self.custom_attr_name_dict,
|
| 528 |
self.use_h5ad_index,
|
| 529 |
-
file_format=
|
| 530 |
chunk_size=self.chunk_size,
|
| 531 |
)
|
| 532 |
|
|
@@ -616,9 +612,7 @@ class TranscriptomeTokenizer:
|
|
| 616 |
|
| 617 |
return tokenized_cells, file_cell_metadata, tokenized_counts
|
| 618 |
|
| 619 |
-
def tokenize_loom(self, loom_file_path, target_sum=10_000
|
| 620 |
-
tokenized_counts = [] # keep_counts not implemented for tokenize_loom
|
| 621 |
-
|
| 622 |
if self.custom_attr_name_dict is not None:
|
| 623 |
file_cell_metadata = {
|
| 624 |
attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
|
|
@@ -633,7 +627,7 @@ class TranscriptomeTokenizer:
|
|
| 633 |
self.gene_token_dict,
|
| 634 |
self.custom_attr_name_dict,
|
| 635 |
use_h5ad_index=False,
|
| 636 |
-
file_format=
|
| 637 |
chunk_size=self.chunk_size,
|
| 638 |
)
|
| 639 |
|
|
@@ -706,7 +700,7 @@ class TranscriptomeTokenizer:
|
|
| 706 |
del data.ra["ensembl_id_collapsed"]
|
| 707 |
|
| 708 |
|
| 709 |
-
return tokenized_cells, file_cell_metadata
|
| 710 |
|
| 711 |
def create_dataset(
|
| 712 |
self,
|
|
|
|
| 3 |
|
| 4 |
**Input data:**
|
| 5 |
|
| 6 |
+
| *Required format:* raw counts scRNAseq data without feature selection as a .loom or anndata file.
|
| 7 |
| *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
|
| 8 |
| *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
|
| 9 |
|
|
|
|
| 20 |
|
| 21 |
**Description:**
|
| 22 |
|
| 23 |
+
| Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
|
| 24 |
|
| 25 |
+
| The discussion below references the .loom file format, but the analogous labels are required for .h5ad files, except that they will be column instead of row attributes and vice versa, due to the transposed format of the two file types.
|
| 26 |
|
| 27 |
| Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
|
| 28 |
|
|
|
|
| 30 |
|
| 31 |
| Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
|
| 32 |
|
| 33 |
+
| If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.
|
| 34 |
|
| 35 |
| OF NOTE: Use model_version to auto-select settings for model version other than current default. For V1 model series (original Geneformer pretrained in 2021 on ~30M cells), one must use correct corresponding token dictionary and gene median file, set special_token to False, and set model_input_size to 2048. This argument enables auto-selection of these settings. (For V2 model series, special_token must be True and model_input_size is 4096.)
|
| 36 |
|
|
|
|
| 46 |
from pathlib import Path
|
| 47 |
from typing import Literal
|
| 48 |
|
|
|
|
| 49 |
import loompy as lp
|
| 50 |
import numpy as np
|
| 51 |
import pandas as pd
|
|
|
|
| 200 |
dsout.add_columns(processed_array, col_attrs=view.ca)
|
| 201 |
return dedup_filename
|
| 202 |
|
| 203 |
+
elif file_format == "h5ad":
|
| 204 |
"""
|
| 205 |
Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
|
| 206 |
Returns adata object with deduplicated Ensembl IDs.
|
| 207 |
"""
|
| 208 |
|
| 209 |
+
data = sc.read_h5ad(str(data_directory))
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
if use_h5ad_index:
|
| 212 |
data.var["ensembl_id"] = list(data.var.index)
|
|
|
|
| 236 |
gene for gene in ensembl_ids if gene in gene_token_dict.keys()
|
| 237 |
]
|
| 238 |
if len(ensembl_id_check) == len(set(ensembl_id_check)):
|
| 239 |
+
return data_directory
|
| 240 |
else:
|
| 241 |
raise ValueError("Error: data Ensembl IDs non-unique.")
|
| 242 |
|
|
|
|
| 435 |
data_directory: Path | str,
|
| 436 |
output_directory: Path | str,
|
| 437 |
output_prefix: str,
|
| 438 |
+
file_format: Literal["loom", "h5ad"] = "loom",
|
| 439 |
input_identifier: str = "",
|
| 440 |
use_generator: bool = False,
|
| 441 |
):
|
|
|
|
| 451 |
output_prefix : str
|
| 452 |
| Prefix for output .dataset
|
| 453 |
file_format : str
|
| 454 |
+
| Format of input files. Can be "loom" or "h5ad".
|
| 455 |
input_identifier : str
|
| 456 |
+
| Substring identifier for input .loom or .h5ad, only matches are tokenized
|
| 457 |
| Default is no identifier, tokenizes all files in provided directory.
|
| 458 |
use_generator : bool
|
| 459 |
| Whether to use generator or dict for tokenization.
|
|
|
|
| 473 |
tokenized_dataset.save_to_disk(str(output_path))
|
| 474 |
|
| 475 |
def tokenize_files(
|
| 476 |
+
self, data_directory, file_format: Literal["loom", "h5ad"] = "loom", input_identifier: str = ""
|
| 477 |
):
|
| 478 |
tokenized_cells = []
|
| 479 |
tokenized_counts = []
|
|
|
|
| 485 |
|
| 486 |
# loops through directories to tokenize .loom files
|
| 487 |
file_found = 0
|
| 488 |
+
# loops through directories to tokenize .loom or .h5ad files
|
| 489 |
tokenize_file_fn = (
|
| 490 |
self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
|
| 491 |
)
|
|
|
|
| 496 |
for file_path in data_directory.glob(file_match):
|
| 497 |
file_found = 1
|
| 498 |
print(f"Tokenizing {file_path}")
|
| 499 |
+
file_tokenized_cells, file_cell_metadata, file_tokenized_counts = tokenize_file_fn(file_path)
|
| 500 |
tokenized_cells += file_tokenized_cells
|
| 501 |
tokenized_counts += file_tokenized_counts
|
| 502 |
if self.custom_attr_name_dict is not None:
|
|
|
|
| 514 |
raise
|
| 515 |
return tokenized_cells, cell_metadata, tokenized_counts
|
| 516 |
|
| 517 |
+
def tokenize_anndata(self, adata_file_path, target_sum=10_000):
|
| 518 |
adata = sum_ensembl_ids(
|
| 519 |
adata_file_path,
|
| 520 |
self.collapse_gene_ids,
|
|
|
|
| 522 |
self.gene_token_dict,
|
| 523 |
self.custom_attr_name_dict,
|
| 524 |
self.use_h5ad_index,
|
| 525 |
+
file_format="h5ad",
|
| 526 |
chunk_size=self.chunk_size,
|
| 527 |
)
|
| 528 |
|
|
|
|
| 612 |
|
| 613 |
return tokenized_cells, file_cell_metadata, tokenized_counts
|
| 614 |
|
| 615 |
+
def tokenize_loom(self, loom_file_path, target_sum=10_000):
|
|
|
|
|
|
|
| 616 |
if self.custom_attr_name_dict is not None:
|
| 617 |
file_cell_metadata = {
|
| 618 |
attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
|
|
|
|
| 627 |
self.gene_token_dict,
|
| 628 |
self.custom_attr_name_dict,
|
| 629 |
use_h5ad_index=False,
|
| 630 |
+
file_format="loom",
|
| 631 |
chunk_size=self.chunk_size,
|
| 632 |
)
|
| 633 |
|
|
|
|
| 700 |
del data.ra["ensembl_id_collapsed"]
|
| 701 |
|
| 702 |
|
| 703 |
+
return tokenized_cells, file_cell_metadata
|
| 704 |
|
| 705 |
def create_dataset(
|
| 706 |
self,
|