Upload 14 files
Browse files
Geneformer backbone by Theodoris et al.
- geneformer/.DS_Store +0 -0
- geneformer/Dataset_Create.py +79 -0
- geneformer/__init__.py +12 -0
- geneformer/collator_for_classification.py +602 -0
- geneformer/emb_extractor.py +806 -0
- geneformer/gene_median_dictionary.pkl +3 -0
- geneformer/gene_name_id_dict.pkl +3 -0
- geneformer/in_silico_perturber.py +915 -0
- geneformer/in_silico_perturber_stats.py +1042 -0
- geneformer/model.safetensors +3 -0
- geneformer/perturber_utils.py +699 -0
- geneformer/pretrainer.py +978 -0
- geneformer/token_dictionary.pkl +3 -0
- geneformer/tokenizer.py +369 -0
geneformer/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
geneformer/Dataset_Create.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from datasets import load_from_disk
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import BertForMaskedLM
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
from tqdm.notebook import tqdm
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
# sys.path.append('/Users/chenj0i/Desktop/Lab Work/Geneformer')
|
| 12 |
+
from geneformer.pretrainer import token_dictionary
|
| 13 |
+
|
| 14 |
+
import datetime
|
| 15 |
+
|
| 16 |
+
# imports
|
| 17 |
+
import os
|
| 18 |
+
import time
|
| 19 |
+
|
| 20 |
+
os.environ["NCCL_DEBUG"] = "INFO"
|
| 21 |
+
os.environ["OMPI_MCA_opal_cuda_support"] = "true"
|
| 22 |
+
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
|
| 23 |
+
|
| 24 |
+
import pickle
|
| 25 |
+
import random
|
| 26 |
+
import subprocess
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import pytz
|
| 30 |
+
import torch
|
| 31 |
+
from datasets import load_from_disk, Dataset
|
| 32 |
+
from transformers import BertConfig, BertForMaskedLM, TrainingArguments, TrainerCallback, Trainer, BertModel, BertPreTrainedModel
|
| 33 |
+
|
| 34 |
+
from geneformer import GeneformerPretrainer
|
| 35 |
+
|
| 36 |
+
from typing import Tuple
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
from transformers.modeling_outputs import MaskedLMOutput
|
| 39 |
+
from transformers.models.bert.modeling_bert import BertLMPredictionHead, BertOnlyMLMHead, BertPredictionHeadTransform
|
| 40 |
+
from transformers.activations import ACT2FN
|
| 41 |
+
from typing import List, Optional, Tuple, Union
|
| 42 |
+
import torch.nn.functional as F
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Randomly select a subset of Genecorpus-30M to conduct the training.
genecorpus = load_from_disk("/ibex/user/chenj0i/Geneformer/Genecorpus-30M/genecorpus_30M_2048.dataset")

subset_size = 1_200_000
test_size = 200_000

# Shuffle deterministically, then take the first `subset_size` examples.
# `Dataset.select` accepts any iterable of indices; a plain range avoids
# materializing an index list (the previous tqdm-wrapped generator only
# tracked index generation, not the actual selection work).
subset_sequences = genecorpus.shuffle(seed=42).select(range(subset_size))['input_ids']

# Hold out the last `test_size` sequences for evaluation.
subset_train_dataset = Dataset.from_dict({"input_ids": subset_sequences[:-test_size]})
subset_train_dataset.save_to_disk("/ibex/user/chenj0i/Geneformer/subset_1Mtrain_genecorpus.dataset")
subset_test_dataset = Dataset.from_dict({"input_ids": subset_sequences[-test_size:]})
subset_test_dataset.save_to_disk("/ibex/user/chenj0i/Geneformer/subset_200K_1Mtrain_genecorpus.dataset")


def _write_length_file(path, num_examples, example_length=2048):
    """Pickle a list of per-example lengths (all `example_length`), as expected by GeneformerPretrainer.

    Args:
        path: output .pkl file path.
        num_examples: number of entries in the length list.
        example_length: length recorded for every example (all sequences here are 2048 tokens).
    """
    lengths = [example_length] * num_examples
    with open(path, 'wb') as f:
        pickle.dump(lengths, f)
    # BUGFIX: the original printed `subset_size` (1,200,000) regardless of the
    # actual list length (1,000,000 train / 200,000 test).
    print(f"List with {num_examples} elements saved as {path}")


# Create length files for the training and test splits.
_write_length_file("sub_1Mtrain_genecorpus_30M_2048_lengths.pkl", subset_size - test_size)
_write_length_file("sub_200K_1Mtrain_genecorpus_30M_2048_lengths.pkl", test_size)
|
geneformer/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import tokenizer
|
| 2 |
+
from . import pretrainer
|
| 3 |
+
from . import collator_for_classification
|
| 4 |
+
from . import in_silico_perturber
|
| 5 |
+
from . import in_silico_perturber_stats
|
| 6 |
+
from .tokenizer import TranscriptomeTokenizer
|
| 7 |
+
from .pretrainer import GeneformerPretrainer
|
| 8 |
+
from .collator_for_classification import DataCollatorForGeneClassification
|
| 9 |
+
from .collator_for_classification import DataCollatorForCellClassification
|
| 10 |
+
from .emb_extractor import EmbExtractor
|
| 11 |
+
from .in_silico_perturber import InSilicoPerturber
|
| 12 |
+
from .in_silico_perturber_stats import InSilicoPerturberStats
|
geneformer/collator_for_classification.py
ADDED
|
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Geneformer collator for gene and cell classification.
|
| 3 |
+
|
| 4 |
+
Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
|
| 5 |
+
"""
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import warnings
|
| 9 |
+
from enum import Enum
|
| 10 |
+
from typing import Dict, List, Optional, Union
|
| 11 |
+
|
| 12 |
+
from transformers import (
|
| 13 |
+
DataCollatorForTokenClassification,
|
| 14 |
+
SpecialTokensMixin,
|
| 15 |
+
BatchEncoding,
|
| 16 |
+
)
|
| 17 |
+
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
| 18 |
+
from transformers.utils.generic import _is_tensorflow, _is_torch
|
| 19 |
+
|
| 20 |
+
from .pretrainer import token_dictionary
|
| 21 |
+
|
| 22 |
+
# Type alias: a single tokenized sequence represented as a list of token ids.
EncodedInput = List[int]
logger = logging.get_logger(__name__)
VERY_LARGE_INTEGER = int(
    1e30
)  # This is used to set the max input length for a model with infinite size input
LARGE_INTEGER = int(
    1e20
)  # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
|
| 30 |
+
|
| 31 |
+
# precollator functions
|
| 32 |
+
|
| 33 |
+
class ExplicitEnum(Enum):
    """Enum base that replaces the terse default lookup error with a message listing the valid values."""

    @classmethod
    def _missing_(cls, value):
        # Invoked by Enum machinery when `cls(value)` finds no matching member.
        valid_values = list(cls._value2member_map_.keys())
        raise ValueError(
            "%r is not a valid %s, please select one of %s" % (value, cls.__name__, str(valid_values))
        )
|
| 44 |
+
|
| 45 |
+
class TruncationStrategy(ExplicitEnum):
    """
    Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.
    """

    ONLY_FIRST = "only_first"  # truncate only the first sequence of a pair
    ONLY_SECOND = "only_second"  # truncate only the second sequence of a pair
    LONGEST_FIRST = "longest_first"  # iteratively trim tokens from whichever sequence is longer
    DO_NOT_TRUNCATE = "do_not_truncate"  # never truncate
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
    in an IDE.
    """

    LONGEST = "longest"  # pad to the longest sequence in the batch
    MAX_LENGTH = "max_length"  # pad to a fixed maximum length
    DO_NOT_PAD = "do_not_pad"  # leave sequences unpadded
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class TensorType(ExplicitEnum):
    """
    Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.
    """

    PYTORCH = "pt"  # torch.Tensor
    TENSORFLOW = "tf"  # tf.constant
    NUMPY = "np"  # np.ndarray
    JAX = "jax"  # jax arrays
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
    # Minimal tokenizer-like object used by the classification collators.
    # Special-token ids come from the Geneformer token dictionary
    # (token string -> integer id mapping loaded in .pretrainer).
    mask_token = "<mask>"
    mask_token_id = token_dictionary.get("<mask>")
    pad_token = "<pad>"
    pad_token_id = token_dictionary.get("<pad>")
    padding_side = "right"  # sequences are padded on the right
    # Only <mask> and <pad> are treated as special tokens.
    all_special_ids = [
        token_dictionary.get("<mask>"),
        token_dictionary.get("<pad>")
    ]
    model_input_names = ["input_ids"]
|
| 93 |
+
|
| 94 |
+
    def _get_padding_truncation_strategies(
        self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
    ):
        """
        Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
        and pad_to_max_length) and behaviors.

        Returns a ``(padding_strategy, truncation_strategy, max_length, kwargs)`` tuple where the strategies are
        :class:`PaddingStrategy` / :class:`TruncationStrategy` members and the deprecated keys have been popped
        out of ``kwargs``.
        """
        # Pop deprecated arguments so they are not forwarded downstream.
        old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
        old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)

        # Backward compatibility for previous behavior, maybe we should deprecate it:
        # If you only set max_length, it activates truncation for max_length
        if max_length is not None and padding is False and truncation is False:
            if verbose:
                # NOTE(review): relies on `self.deprecation_warnings` being a dict
                # initialized by the tokenizer base class — confirm SpecialTokensMixin provides it.
                if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
                    logger.warning(
                        "Truncation was not explicitly activated but `max_length` is provided a specific value, "
                        "please use `truncation=True` to explicitly truncate examples to max length. "
                        "Defaulting to 'longest_first' truncation strategy. "
                        "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
                        "more precisely by providing a specific strategy to `truncation`."
                    )
                self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
            truncation = "longest_first"

        # Get padding strategy
        if padding is False and old_pad_to_max_length:
            if verbose:
                warnings.warn(
                    "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
                    "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
                    "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
                    "maximal input size of the model (e.g. 512 for Bert).",
                    FutureWarning,
                )
            if max_length is None:
                padding_strategy = PaddingStrategy.LONGEST
            else:
                padding_strategy = PaddingStrategy.MAX_LENGTH
        elif padding is not False:
            if padding is True:
                padding_strategy = PaddingStrategy.LONGEST  # Default to pad to the longest sequence in the batch
            elif not isinstance(padding, PaddingStrategy):
                # Accept strings such as "longest" / "max_length".
                padding_strategy = PaddingStrategy(padding)
            elif isinstance(padding, PaddingStrategy):
                padding_strategy = padding
        else:
            padding_strategy = PaddingStrategy.DO_NOT_PAD

        # Get truncation strategy
        if truncation is False and old_truncation_strategy != "do_not_truncate":
            if verbose:
                warnings.warn(
                    "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
                    "use `truncation=True` to truncate examples to a max length. You can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
                    "maximal input size of the model (e.g. 512 for Bert). "
                    " If you have pairs of inputs, you can give a specific truncation strategy selected among "
                    "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
                    "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
                    "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
                    FutureWarning,
                )
            truncation_strategy = TruncationStrategy(old_truncation_strategy)
        elif truncation is not False:
            if truncation is True:
                truncation_strategy = (
                    TruncationStrategy.LONGEST_FIRST
                )  # Default to truncate the longest sequences in pairs of inputs
            elif not isinstance(truncation, TruncationStrategy):
                truncation_strategy = TruncationStrategy(truncation)
            elif isinstance(truncation, TruncationStrategy):
                truncation_strategy = truncation
        else:
            truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE

        # Set max length if needed
        if max_length is None:
            if padding_strategy == PaddingStrategy.MAX_LENGTH:
                # model_max_length above LARGE_INTEGER means "no predefined maximum".
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
                            logger.warning(
                                "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no padding."
                            )
                        self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
                    padding_strategy = PaddingStrategy.DO_NOT_PAD
                else:
                    max_length = self.model_max_length

            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
                            logger.warning(
                                "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no truncation."
                            )
                        self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
                    truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
                else:
                    max_length = self.model_max_length

        # Test if we have a padding token
        if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
            raise ValueError(
                "Asking to pad but the tokenizer does not have a padding token. "
                "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
                "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
            )

        # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
        if (
            truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
            and padding_strategy != PaddingStrategy.DO_NOT_PAD
            and pad_to_multiple_of is not None
            and max_length is not None
            and (max_length % pad_to_multiple_of != 0)
        ):
            raise ValueError(
                f"Truncation and padding are both activated but "
                f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
            )

        return padding_strategy, truncation_strategy, max_length, kwargs
|
| 221 |
+
|
| 222 |
+
    def pad(
        self,
        encoded_inputs: Union[
            BatchEncoding,
            List[BatchEncoding],
            Dict[str, EncodedInput],
            Dict[str, List[EncodedInput]],
            List[Dict[str, EncodedInput]],
        ],
        class_type,  # options: "gene" or "cell"
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        verbose: bool = True,
    ) -> BatchEncoding:
        """
        Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
        in the batch.

        Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
        ``self.pad_token_id`` and ``self.pad_token_type_id``)

        .. note::

            If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
            result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
            case of PyTorch tensors, you will lose the specific device of your tensors however.

        Args:
            encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
                Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
                List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
                List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
                well as in a PyTorch Dataloader collate function.

                Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
                see the note above for the return type.
            class_type (:obj:`str`):
                ``"gene"`` or ``"cell"``. For ``"cell"``, the per-example ``"label"`` key is removed from the
                batched output (the labels are handled elsewhere for cell classification).
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
                  single sequence if provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
                  different lengths).
            max_length (:obj:`int`, `optional`):
                Maximum length of the returned list and optionally padding length (see above).
            pad_to_multiple_of (:obj:`int`, `optional`):
                If set will pad the sequence to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                >= 7.5 (Volta).
            return_attention_mask (:obj:`bool`, `optional`):
                Whether to return the attention mask. If left to the default, will return the attention mask according
                to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.

                `What are attention masks? <../glossary.html#attention-mask>`__
            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
            verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to print more information and warnings.
        """
        # If we have a list of dicts, let's convert it in a dict of lists
        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

        # The model's main input name, usually `input_ids`, has to be passed for padding
        if self.model_input_names[0] not in encoded_inputs:
            raise ValueError(
                "You should supply an encoding or a list of encodings to this method"
                f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
            )

        required_input = encoded_inputs[self.model_input_names[0]]

        # Empty input: nothing to pad, optionally attach an empty attention mask.
        if not required_input:
            if return_attention_mask:
                encoded_inputs["attention_mask"] = []
            return encoded_inputs

        # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
        # and rebuild them afterwards if no return_tensors is specified
        # Note that we lose the specific device the tensor may be on for PyTorch

        first_element = required_input[0]
        if isinstance(first_element, (list, tuple)):
            # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
            index = 0
            while len(required_input[index]) == 0:
                index += 1
            if index < len(required_input):
                first_element = required_input[index][0]
        # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
        if not isinstance(first_element, (int, list, tuple)):
            # Infer the tensor framework from the element type so the output can be rebuilt in kind.
            if is_tf_available() and _is_tensorflow(first_element):
                return_tensors = "tf" if return_tensors is None else return_tensors
            elif is_torch_available() and _is_torch(first_element):
                return_tensors = "pt" if return_tensors is None else return_tensors
            elif isinstance(first_element, np.ndarray):
                return_tensors = "np" if return_tensors is None else return_tensors
            else:
                raise ValueError(
                    f"type of {first_element} unknown: {type(first_element)}. "
                    f"Should be one of a python, numpy, pytorch or tensorflow object."
                )

            for key, value in encoded_inputs.items():
                encoded_inputs[key] = to_py_obj(value)

        # Convert padding_strategy in PaddingStrategy
        padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
            padding=padding, max_length=max_length, verbose=verbose
        )

        required_input = encoded_inputs[self.model_input_names[0]]
        # Single (non-batched) input: pad it directly and return.
        if required_input and not isinstance(required_input[0], (list, tuple)):
            encoded_inputs = self._pad(
                encoded_inputs,
                class_type=class_type,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )
            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)

        batch_size = len(required_input)
        assert all(
            len(v) == batch_size for v in encoded_inputs.values()
        ), "Some items in the output dictionary have a different batch size than others."

        # For LONGEST, resolve the concrete target length now and pad each example to it.
        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = max(len(inputs) for inputs in required_input)
            padding_strategy = PaddingStrategy.MAX_LENGTH

        # Pad each example independently, then regroup per key.
        batch_outputs = {}
        for i in range(batch_size):
            inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
            outputs = self._pad(
                inputs,
                class_type=class_type,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )

            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)
        # For cell classification the per-example "label" key is dropped from the batch.
        # NOTE(review): assumes each example carries a "label" key when class_type == "cell" — confirm with callers.
        if class_type == "cell":
            del batch_outputs["label"]
        return BatchEncoding(batch_outputs, tensor_type=return_tensors)
|
| 385 |
+
|
| 386 |
+
def _pad(
|
| 387 |
+
self,
|
| 388 |
+
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
| 389 |
+
class_type, # options: "gene" or "cell"
|
| 390 |
+
max_length: Optional[int] = None,
|
| 391 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
|
| 392 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 393 |
+
return_attention_mask: Optional[bool] = True,
|
| 394 |
+
) -> dict:
|
| 395 |
+
"""
|
| 396 |
+
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
| 397 |
+
|
| 398 |
+
Args:
|
| 399 |
+
encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
| 400 |
+
max_length: maximum length of the returned list and optionally padding length (see below).
|
| 401 |
+
Will truncate by taking into account the special tokens.
|
| 402 |
+
padding_strategy: PaddingStrategy to use for padding.
|
| 403 |
+
|
| 404 |
+
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
| 405 |
+
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
| 406 |
+
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
| 407 |
+
The tokenizer padding sides are defined in self.padding_side:
|
| 408 |
+
|
| 409 |
+
- 'left': pads on the left of the sequences
|
| 410 |
+
- 'right': pads on the right of the sequences
|
| 411 |
+
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
| 412 |
+
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
| 413 |
+
>= 7.5 (Volta).
|
| 414 |
+
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
| 415 |
+
"""
|
| 416 |
+
# Load from model defaults
|
| 417 |
+
if return_attention_mask is None:
|
| 418 |
+
return_attention_mask = "attention_mask" in self.model_input_names
|
| 419 |
+
|
| 420 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
| 421 |
+
|
| 422 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
| 423 |
+
max_length = len(required_input)
|
| 424 |
+
|
| 425 |
+
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
| 426 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
| 427 |
+
|
| 428 |
+
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
| 429 |
+
|
| 430 |
+
if needs_to_be_padded:
|
| 431 |
+
difference = max_length - len(required_input)
|
| 432 |
+
if self.padding_side == "right":
|
| 433 |
+
if return_attention_mask:
|
| 434 |
+
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
|
| 435 |
+
if "token_type_ids" in encoded_inputs:
|
| 436 |
+
encoded_inputs["token_type_ids"] = (
|
| 437 |
+
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
|
| 438 |
+
)
|
| 439 |
+
if "special_tokens_mask" in encoded_inputs:
|
| 440 |
+
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
|
| 441 |
+
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
|
| 442 |
+
if class_type == "gene":
|
| 443 |
+
encoded_inputs["labels"] = encoded_inputs["labels"] + [-100] * difference
|
| 444 |
+
elif self.padding_side == "left":
|
| 445 |
+
if return_attention_mask:
|
| 446 |
+
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
|
| 447 |
+
if "token_type_ids" in encoded_inputs:
|
| 448 |
+
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
| 449 |
+
"token_type_ids"
|
| 450 |
+
]
|
| 451 |
+
if "special_tokens_mask" in encoded_inputs:
|
| 452 |
+
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
| 453 |
+
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
| 454 |
+
if class_type == "gene":
|
| 455 |
+
encoded_inputs["labels"] = [-100] * difference + encoded_inputs["labels"]
|
| 456 |
+
else:
|
| 457 |
+
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
| 458 |
+
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
| 459 |
+
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
| 460 |
+
|
| 461 |
+
return encoded_inputs
|
| 462 |
+
|
| 463 |
+
def get_special_tokens_mask(
|
| 464 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 465 |
+
) -> List[int]:
|
| 466 |
+
"""
|
| 467 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 468 |
+
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
| 469 |
+
Args:
|
| 470 |
+
token_ids_0 (:obj:`List[int]`):
|
| 471 |
+
List of ids of the first sequence.
|
| 472 |
+
token_ids_1 (:obj:`List[int]`, `optional`):
|
| 473 |
+
List of ids of the second sequence.
|
| 474 |
+
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 475 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 476 |
+
Returns:
|
| 477 |
+
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 478 |
+
"""
|
| 479 |
+
assert already_has_special_tokens and token_ids_1 is None, (
|
| 480 |
+
"You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
|
| 481 |
+
"Please use a slow (full python) tokenizer to activate this argument."
|
| 482 |
+
"Or set `return_special_tokens_mask=True` when calling the encoding method "
|
| 483 |
+
"to get the special tokens mask in any tokenizer. "
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
all_special_ids = self.all_special_ids # cache the property
|
| 487 |
+
|
| 488 |
+
special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
|
| 489 |
+
|
| 490 |
+
return special_tokens_mask
|
| 491 |
+
|
| 492 |
+
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
|
| 493 |
+
"""
|
| 494 |
+
Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
|
| 495 |
+
vocabulary.
|
| 496 |
+
Args:
|
| 497 |
+
tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
|
| 498 |
+
Returns:
|
| 499 |
+
:obj:`int` or :obj:`List[int]`: The token id or list of token ids.
|
| 500 |
+
"""
|
| 501 |
+
if tokens is None:
|
| 502 |
+
return None
|
| 503 |
+
|
| 504 |
+
if isinstance(tokens, str):
|
| 505 |
+
return self._convert_token_to_id_with_added_voc(tokens)
|
| 506 |
+
|
| 507 |
+
ids = []
|
| 508 |
+
for token in tokens:
|
| 509 |
+
ids.append(self._convert_token_to_id_with_added_voc(token))
|
| 510 |
+
return ids
|
| 511 |
+
|
| 512 |
+
def _convert_token_to_id_with_added_voc(self, token):
|
| 513 |
+
if token is None:
|
| 514 |
+
return None
|
| 515 |
+
|
| 516 |
+
return token_dictionary.get(token)
|
| 517 |
+
|
| 518 |
+
def __len__(self):
|
| 519 |
+
return len(token_dictionary)
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
# collator functions
|
| 523 |
+
|
| 524 |
+
class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
    """

    tokenizer = PrecollatorForGeneAndCellClassification()
    class_type = "gene"
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100

    def __init__(self, *args, **kwargs) -> None:
        # Forward the class-level defaults to the parent token-classification collator.
        super().__init__(
            tokenizer=self.tokenizer,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            label_pad_token_id=self.label_pad_token_id,
            *args, **kwargs)

    def _prepare_batch(self, features):
        # Pad inputs (and, since class_type == "gene", the per-token labels) via
        # the precollator's pad method. The `label_name`/`labels` locals formerly
        # computed here were never used and have been removed.
        batch = self.tokenizer.pad(
            features,
            class_type=self.class_type,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        return batch

    def __call__(self, features):
        batch = self._prepare_batch(features)

        # Ensure every field in the batch is an int64 tensor.
        batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
        return batch
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
class DataCollatorForCellClassification(DataCollatorForGeneClassification):
    """Collator for cell-level classification: pads inputs and attaches one label per example."""

    class_type = "cell"

    def _prepare_batch(self, features):
        # Pad inputs exactly as for gene classification.
        batch = super()._prepare_batch(features)

        # Special handling for labels: build the label tensor explicitly so that
        # its dtype matches the label type (int -> torch.long, else torch.float),
        # even though that should normally be inferred automatically.
        first = features[0]
        if "label" in first and first["label"] is not None:
            probe = first["label"]
            if isinstance(probe, torch.Tensor):
                probe = probe.item()
            dtype = torch.long if isinstance(probe, int) else torch.float
            batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)

        return batch
|
geneformer/emb_extractor.py
ADDED
|
@@ -0,0 +1,806 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Geneformer embedding extractor.
|
| 3 |
+
|
| 4 |
+
**Description:**
|
| 5 |
+
|
| 6 |
+
| Extracts gene or cell embeddings.
|
| 7 |
+
| Plots cell embeddings as heatmaps or UMAPs.
|
| 8 |
+
| Generates cell state embedding dictionary for use with InSilicoPerturber.
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
# imports
|
| 13 |
+
import logging
|
| 14 |
+
import pickle
|
| 15 |
+
from collections import Counter
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import anndata
|
| 19 |
+
import matplotlib.pyplot as plt
|
| 20 |
+
import numpy as np
|
| 21 |
+
import pandas as pd
|
| 22 |
+
import scanpy as sc
|
| 23 |
+
import seaborn as sns
|
| 24 |
+
import torch
|
| 25 |
+
from tdigest import TDigest
|
| 26 |
+
from tqdm.auto import trange
|
| 27 |
+
|
| 28 |
+
from . import perturber_utils as pu
|
| 29 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# extract embeddings
|
| 35 |
+
def get_embs(
    model,
    filtered_input_data,
    emb_mode,
    layer_to_quant,
    pad_token_id,
    forward_batch_size,
    summary_stat=None,
    silent=False,
):
    """
    Extract embeddings from the hidden layer ``layer_to_quant`` of ``model``.

    Parameters
    ----------
    model : torch model with ``output_hidden_states`` enabled; moved to CUDA by the caller
        (inputs are sent to "cuda" below — TODO confirm model device).
    filtered_input_data : dataset with "input_ids" and "length" columns (HF datasets-style;
        supports .select and .set_format).
    emb_mode : {"cell", "gene"}
        "cell": mean-pool over non-padding tokens per cell; "gene": keep per-token embeddings.
    layer_to_quant : int index into ``outputs.hidden_states``.
    pad_token_id : int used to pad input_ids.
    forward_batch_size : int minibatch size for the forward pass.
    summary_stat : None, "mean" or "median"
        If set, accumulate TDigests instead of raw embeddings and return only the summary.
    silent : bool
        If True, do not leave the progress bar on screen.

    Returns
    -------
    torch.Tensor (summary_stat is None, or cell summary) or dict of gene->tdigest lists
    (gene mode with summary_stat).
    """
    model_input_size = pu.get_model_input_size(model)
    total_batch_length = len(filtered_input_data)

    if summary_stat is None:
        embs_list = []
    elif summary_stat is not None:
        # test embedding extraction for example cell and extract # emb dims
        example = filtered_input_data.select([i for i in range(1)])
        example.set_format(type="torch")
        emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
        if emb_mode == "cell":
            # initiate tdigests for # of emb dims
            embs_tdigests = [TDigest() for _ in range(emb_dims)]
        if emb_mode == "gene":
            gene_set = list(
                {
                    element
                    for sublist in filtered_input_data["input_ids"]
                    for element in sublist
                }
            )
            # initiate dict with genes as keys and tdigests for # of emb dims as values
            embs_tdigests_dict = {
                k: [TDigest() for _ in range(emb_dims)] for k in gene_set
            }

    # tracks the longest sequence seen, used to pad gene-mode output at the end
    overall_max_len = 0

    for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
        max_range = min(i + forward_batch_size, total_batch_length)

        minibatch = filtered_input_data.select([i for i in range(i, max_range)])

        # lengths captured before set_format so padding can be masked out later
        max_len = int(max(minibatch["length"]))
        original_lens = torch.tensor(minibatch["length"], device="cuda")
        minibatch.set_format(type="torch")

        input_data_minibatch = minibatch["input_ids"]
        input_data_minibatch = pu.pad_tensor_list(
            input_data_minibatch, max_len, pad_token_id, model_input_size
        )

        with torch.no_grad():
            outputs = model(
                input_ids=input_data_minibatch.to("cuda"),
                attention_mask=pu.gen_attention_mask(minibatch),
            )

        embs_i = outputs.hidden_states[layer_to_quant]

        if emb_mode == "cell":
            # mean over non-padding token embeddings -> one vector per cell
            mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
            if summary_stat is None:
                embs_list.append(mean_embs)
            elif summary_stat is not None:
                # update tdigests with current batch for each emb dim
                accumulate_tdigests(embs_tdigests, mean_embs, emb_dims)
            del mean_embs
        elif emb_mode == "gene":
            if summary_stat is None:
                embs_list.append(embs_i)
            elif summary_stat is not None:
                for h in trange(len(minibatch)):
                    # only the first `length_h` positions are real tokens
                    length_h = minibatch[h]["length"]
                    input_ids_h = minibatch[h]["input_ids"][0:length_h]

                    # double check dimensions before unsqueezing
                    embs_i_dim = embs_i.dim()
                    if embs_i_dim != 3:
                        logger.error(
                            f"Embedding tensor should have 3 dimensions, not {embs_i_dim}"
                        )
                        # NOTE(review): bare `raise` outside an except block raises
                        # RuntimeError("No active exception to re-raise") — consider
                        # raising an explicit exception type here.
                        raise

                    embs_h = embs_i[h, :, :].unsqueeze(dim=1)
                    # token id -> embedding row for this cell
                    dict_h = dict(zip(input_ids_h, embs_h))
                    for k in dict_h.keys():
                        accumulate_tdigests(
                            embs_tdigests_dict[int(k)], dict_h[k], emb_dims
                        )

        overall_max_len = max(overall_max_len, max_len)
        # free per-batch tensors before the next iteration to limit GPU memory use
        del outputs
        del minibatch
        del input_data_minibatch
        del embs_i

        torch.cuda.empty_cache()

    if summary_stat is None:
        if emb_mode == "cell":
            embs_stack = torch.cat(embs_list, dim=0)
        elif emb_mode == "gene":
            # pad each batch's 3D tensor out to the overall max sequence length
            embs_stack = pu.pad_tensor_list(
                embs_list,
                overall_max_len,
                pad_token_id,
                model_input_size,
                1,
                pu.pad_3d_tensor,
            )

    # calculate summary stat embs from approximated tdigests
    elif summary_stat is not None:
        if emb_mode == "cell":
            if summary_stat == "mean":
                summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
            elif summary_stat == "median":
                summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
            embs_stack = torch.tensor(summary_emb_list)
        elif emb_mode == "gene":
            # comprehensions used for their side effect of replacing each gene's
            # tdigest list with its per-dimension summary values
            if summary_stat == "mean":
                [
                    update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
                    for gene in embs_tdigests_dict.keys()
                ]
            elif summary_stat == "median":
                [
                    update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims)
                    for gene in embs_tdigests_dict.keys()
                ]
            return embs_tdigests_dict

    return embs_stack
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
    """
    Update one TDigest per embedding dimension with a batch of embedding rows.

    **Parameters:**

    embs_tdigests : list
        | One TDigest-like object (with an ``update(value)`` method) per embedding dimension.
    mean_embs : torch.Tensor
        | 2D tensor of shape (n_rows, emb_dims) whose values are folded into the digests.
    emb_dims : int
        | Number of embedding dimensions to accumulate.

    **Returns:**

    The (mutated) ``embs_tdigests`` list, so callers that re-assign the result
    (e.g. ``update_tdigest_dict``) keep a valid digest list instead of ``None``.
    """
    # note: tdigest batch update known to be slow so updating serially;
    # explicit loops replace the former side-effect-only list comprehension
    for i in range(mean_embs.size(0)):
        for j in range(emb_dims):
            embs_tdigests[j].update(mean_embs[i, j].item())
    return embs_tdigests
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
    # Accumulate this gene's embeddings into its per-dimension tdigests.
    # NOTE(review): accumulate_tdigests has no explicit return value, so this
    # assignment stores None in the dict — confirm whether accumulate_tdigests
    # is intended to return the (mutated) tdigest list.
    embs_tdigests_dict[gene] = accumulate_tdigests(
        embs_tdigests_dict[gene], gene_embs, emb_dims
    )
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims):
    # Replace this gene's tdigest list with its per-dimension trimmed means.
    embs_tdigests_dict[gene] = tdigest_mean(embs_tdigests_dict[gene], emb_dims)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims):
    # Replace this gene's tdigest list with its per-dimension medians.
    embs_tdigests_dict[gene] = tdigest_median(embs_tdigests_dict[gene], emb_dims)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def summarize_gene_embs(h, minibatch, embs_i, embs_tdigests_dict, emb_dims):
    """Fold cell ``h``'s per-gene embeddings from batch tensor ``embs_i`` into the gene tdigest dict."""
    # only the first `length_h` positions are real tokens (rest is padding)
    length_h = minibatch[h]["length"]
    input_ids_h = minibatch[h]["input_ids"][0:length_h]
    embs_h = embs_i[h, :, :].unsqueeze(dim=1)
    # token id -> embedding row for this cell
    dict_h = dict(zip(input_ids_h, embs_h))
    # comprehension used for side effects: each gene's dict entry is re-assigned
    # from update_tdigest_dict
    [
        update_tdigest_dict(embs_tdigests_dict, k, dict_h[k], emb_dims)
        for k in dict_h.keys()
    ]
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def tdigest_mean(embs_tdigests, emb_dims):
    """Per-dimension trimmed mean (0th-100th percentile) of the accumulated tdigests."""
    means = []
    for dim in range(emb_dims):
        means.append(embs_tdigests[dim].trimmed_mean(0, 100))
    return means
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def tdigest_median(embs_tdigests, emb_dims):
    """Per-dimension median (50th percentile) of the accumulated tdigests."""
    medians = []
    for dim in range(emb_dims):
        medians.append(embs_tdigests[dim].percentile(50))
    return medians
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def test_emb(model, example, layer_to_quant):
    """Run one forward pass on `example` and return the width (# of embedding dims) of the target hidden layer."""
    with torch.no_grad():
        outputs = model(input_ids=example.to("cuda"))

    hidden = outputs.hidden_states[layer_to_quant]
    return hidden.size()[2]
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def label_cell_embs(embs, downsampled_data, emb_labels):
    """Convert cell embeddings to a DataFrame, appending any requested label columns from the dataset."""
    labeled_df = pd.DataFrame(embs.cpu().numpy())
    if emb_labels is not None:
        for column_name in emb_labels:
            labeled_df[column_name] = downsampled_data[column_name]
    return labeled_df
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def label_gene_embs(embs, downsampled_data, token_gene_dict):
    """
    Average each gene's contextual embeddings across all cells in which it appears.

    Returns a DataFrame with one row per gene, indexed by gene name via
    ``token_gene_dict`` (token id -> gene identifier).
    """
    # all distinct token ids seen anywhere in the dataset
    gene_set = {
        element for sublist in downsampled_data["input_ids"] for element in sublist
    }
    gene_emb_dict = {k: [] for k in gene_set}
    for i in range(embs.size()[0]):
        # only the first `length` positions of cell i are real tokens
        length = downsampled_data[i]["length"]
        dict_i = dict(
            zip(
                downsampled_data[i]["input_ids"][0:length],
                embs[i, :, :].unsqueeze(dim=1),
            )
        )
        for k in dict_i.keys():
            gene_emb_dict[k].append(dict_i[k])
    # mean over all occurrences of each gene, moved to CPU numpy for the DataFrame
    for k in gene_emb_dict.keys():
        gene_emb_dict[k] = (
            torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
            .cpu()
            .numpy()
        )
    # transpose so genes are rows and embedding dimensions are columns
    embs_df = pd.DataFrame(gene_emb_dict).T
    embs_df.index = [token_gene_dict[token] for token in embs_df.index]
    return embs_df
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
    """Plot a UMAP of the first ``emb_dims`` embedding columns of ``embs_df``, colored by ``label``."""
    # embedding-only view; AnnData requires string indices/columns
    only_embs_df = embs_df.iloc[:, :emb_dims]
    only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
    only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
        str
    )
    vars_dict = {"embs": only_embs_df.columns}
    obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
    adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
    # standard scanpy pipeline: PCA -> neighbor graph -> UMAP
    sc.tl.pca(adata, svd_solver="arpack")
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)
    sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
    sns.set_style("white")
    # caller-supplied kwargs override the defaults
    default_kwargs_dict = {"palette": "Set2", "size": 200}
    if kwargs_dict is not None:
        default_kwargs_dict.update(kwargs_dict)

    # `save` makes scanpy write the figure itself (to its figure directory, with prefix)
    sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def gen_heatmap_class_colors(labels, df):
    """Assign each distinct label a cubehelix color; return a per-row color Series aligned to df's index."""
    distinct_labels = Counter(labels).keys()
    palette = sns.cubehelix_palette(
        len(distinct_labels),
        light=0.9,
        dark=0.1,
        hue=1,
        reverse=True,
        start=1,
        rot=-2,
    )
    # lookup table: stringified label -> color
    lut = dict(zip(map(str, distinct_labels), palette))
    return pd.Series(labels, index=df.index).map(lut)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def gen_heatmap_class_dict(classes, label_colors_series):
    """Return a {class: color} mapping, keeping the first color seen for each class."""
    pairs = pd.DataFrame({"classes": classes, "color": label_colors_series})
    pairs = pairs.drop_duplicates(subset=["classes"])
    return dict(zip(pairs["classes"], pairs["color"]))
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def make_colorbar(embs_df, label):
    """
    Build the row-color annotation for a clustered heatmap.

    **Parameters:**

    embs_df : pd.DataFrame
        | Embedding DataFrame containing a ``label`` column.
    label : str
        | Column of ``embs_df`` whose classes determine the colors.

    **Returns:**

    (label_colors, label_color_dict): a one-column DataFrame of per-row colors and
    a {class: color} mapping for the legend.
    """
    labels = list(embs_df[label])

    cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
    label_colors = pd.DataFrame(cell_type_colors, columns=[label])

    # sanity check: report any row whose color is not a valid RGB triple
    # (row.iloc[0] replaces the deprecated positional row[0] lookup)
    for i, row in label_colors.iterrows():
        colors = row.iloc[0]
        if len(colors) != 3 or any(np.isnan(colors)):
            print(i, colors)

    # (a former no-op `label_colors.isna().sum()` statement was removed)

    # create dictionary for colors and classes
    label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
    return label_colors, label_color_dict
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
    """Plot a clustered heatmap of the first ``emb_dims`` embedding columns with per-class row colors and save it to ``output_file``."""
    sns.set_style("white")
    sns.set(font_scale=2)
    plt.figure(figsize=(15, 15), dpi=150)
    # per-row colors and {class: color} legend mapping
    label_colors, label_color_dict = make_colorbar(embs_df, label)

    default_kwargs_dict = {
        "row_cluster": True,
        "col_cluster": True,
        "row_colors": label_colors,
        "standard_scale": 1,  # scale each column to 0-1
        "linewidths": 0,
        "xticklabels": False,
        "yticklabels": False,
        "figsize": (15, 15),
        "center": 0,
        "cmap": "magma",
    }

    # caller-supplied kwargs override the defaults
    if kwargs_dict is not None:
        default_kwargs_dict.update(kwargs_dict)
    g = sns.clustermap(
        embs_df.iloc[:, 0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict
    )

    plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")

    # zero-size bars exist only to populate the legend with class colors
    for label_color in list(label_color_dict.keys()):
        g.ax_col_dendrogram.bar(
            0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
        )

    g.ax_col_dendrogram.legend(
        title=f"{label}",
        loc="lower center",
        ncol=4,
        bbox_to_anchor=(0.5, 1),
        facecolor="white",
    )

    plt.savefig(output_file, bbox_inches="tight")
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
class EmbExtractor:
|
| 362 |
+
valid_option_dict = {
|
| 363 |
+
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
| 364 |
+
"num_classes": {int},
|
| 365 |
+
"emb_mode": {"cell", "gene"},
|
| 366 |
+
"cell_emb_style": {"mean_pool"},
|
| 367 |
+
"gene_emb_style": {"mean_pool"},
|
| 368 |
+
"filter_data": {None, dict},
|
| 369 |
+
"max_ncells": {None, int},
|
| 370 |
+
"emb_layer": {-1, 0},
|
| 371 |
+
"emb_label": {None, list},
|
| 372 |
+
"labels_to_plot": {None, list},
|
| 373 |
+
"forward_batch_size": {int},
|
| 374 |
+
"nproc": {int},
|
| 375 |
+
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
    def __init__(
        self,
        model_type="Pretrained",
        num_classes=0,
        emb_mode="cell",
        cell_emb_style="mean_pool",
        gene_emb_style="mean_pool",
        filter_data=None,
        max_ncells=1000,
        emb_layer=-1,
        emb_label=None,
        labels_to_plot=None,
        forward_batch_size=100,
        nproc=4,
        summary_stat=None,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize embedding extractor.

        **Parameters:**

        model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
            | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
        num_classes : int
            | If model is a gene or cell classifier, specify number of classes it was trained to classify.
            | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
        emb_mode : {"cell", "gene"}
            | Whether to output cell or gene embeddings.
        cell_emb_style : "mean_pool"
            | Method for summarizing cell embeddings.
            | Currently only option is mean pooling of gene embeddings for given cell.
        gene_emb_style : "mean_pool"
            | Method for summarizing gene embeddings.
            | Currently only option is mean pooling of contextual gene embeddings for given gene.
        filter_data : None, dict
            | Default is to extract embeddings from all input data.
            | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
        max_ncells : None, int
            | Maximum number of cells to extract embeddings from.
            | Default is 1000 cells randomly sampled from input data.
            | If None, will extract embeddings from all cells.
        emb_layer : {-1, 0}
            | Embedding layer to extract.
            | The last layer is most specifically weighted to optimize the given learning objective.
            | Generally, it is best to extract the 2nd to last layer to get a more general representation.
            | -1: 2nd to last layer
            | 0: last layer
        emb_label : None, list
            | List of column name(s) in .dataset to add as labels to embedding output.
        labels_to_plot : None, list
            | Cell labels to plot.
            | Shown as color bar in heatmap.
            | Shown as cell color in umap.
            | Plotting umap requires labels to plot.
        forward_batch_size : int
            | Batch size for forward pass.
        nproc : int
            | Number of CPU processes to use.
        summary_stat : {None, "mean", "median", "exact_mean", "exact_median"}
            | If exact_mean or exact_median, outputs only exact mean or median embedding of input data.
            | If mean or median, outputs only approximated (tdigest-based) mean or median embedding.
            | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
            | Non-exact is slower but more memory-efficient.
        token_dictionary_file : Path
            | Path to pickle file containing token dictionary (Ensembl ID:token).

        **Examples:**

        .. code-block :: python

            >>> from geneformer import EmbExtractor
            >>> embex = EmbExtractor(model_type="CellClassifier",
            ...                      num_classes=3,
            ...                      emb_mode="cell",
            ...                      filter_data={"cell_type":["cardiomyocyte"]},
            ...                      max_ncells=1000,
            ...                      emb_layer=-1,
            ...                      emb_label=["disease", "cell_type"],
            ...                      labels_to_plot=["disease", "cell_type"])

        """

        self.model_type = model_type
        self.num_classes = num_classes
        self.emb_mode = emb_mode
        self.cell_emb_style = cell_emb_style
        self.gene_emb_style = gene_emb_style
        self.filter_data = filter_data
        self.max_ncells = max_ncells
        self.emb_layer = emb_layer
        self.emb_label = emb_label
        self.labels_to_plot = labels_to_plot
        self.forward_batch_size = forward_batch_size
        self.nproc = nproc
        # "exact_*" stats are computed directly on tensors downstream;
        # tdigest approximation is only used for plain "mean"/"median"
        if (summary_stat is not None) and ("exact" in summary_stat):
            self.summary_stat = None
            self.exact_summary_stat = summary_stat
        else:
            self.summary_stat = summary_stat
            self.exact_summary_stat = None

        self.validate_options()

        # load token dictionary (Ensembl IDs:token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        # reverse mapping (token -> Ensembl ID) for labeling gene embeddings
        self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
        self.pad_token_id = self.gene_token_dict.get("<pad>")
|
| 489 |
+
|
| 490 |
+
def validate_options(self):
|
| 491 |
+
# confirm arguments are within valid options and compatible with each other
|
| 492 |
+
for attr_name, valid_options in self.valid_option_dict.items():
|
| 493 |
+
attr_value = self.__dict__[attr_name]
|
| 494 |
+
if not isinstance(attr_value, (list, dict)):
|
| 495 |
+
if attr_value in valid_options:
|
| 496 |
+
continue
|
| 497 |
+
valid_type = False
|
| 498 |
+
for option in valid_options:
|
| 499 |
+
if (option in [int, list, dict, bool]) and isinstance(
|
| 500 |
+
attr_value, option
|
| 501 |
+
):
|
| 502 |
+
valid_type = True
|
| 503 |
+
break
|
| 504 |
+
if valid_type:
|
| 505 |
+
continue
|
| 506 |
+
logger.error(
|
| 507 |
+
f"Invalid option for {attr_name}. "
|
| 508 |
+
f"Valid options for {attr_name}: {valid_options}"
|
| 509 |
+
)
|
| 510 |
+
raise
|
| 511 |
+
|
| 512 |
+
if self.filter_data is not None:
|
| 513 |
+
for key, value in self.filter_data.items():
|
| 514 |
+
if not isinstance(value, list):
|
| 515 |
+
self.filter_data[key] = [value]
|
| 516 |
+
logger.warning(
|
| 517 |
+
"Values in filter_data dict must be lists. "
|
| 518 |
+
f"Changing {key} value to list ([{value}])."
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
def extract_embs(
    self,
    model_directory,
    input_data_file,
    output_directory,
    output_prefix,
    output_torch_embs=False,
    cell_state=None,
):
    """
    Extract embeddings from input data and save as results in output_directory.

    **Parameters:**

    model_directory : Path
        | Path to directory containing model
    input_data_file : Path
        | Path to directory containing .dataset inputs
    output_directory : Path
        | Path to directory where embedding data will be saved as csv
    output_prefix : str
        | Prefix for output file
    output_torch_embs : bool
        | Whether or not to also output the embeddings as a tensor.
        | Note, if true, will output embeddings as both dataframe and tensor.
    cell_state : dict
        | Cell state key and value for state embedding extraction.
        | When provided, the input data is additionally filtered to this
        | state, no csv is written, and the (possibly summarized) tensor
        | is returned directly for use by ``get_state_embs``.

    **Examples:**

    .. code-block :: python

        >>> embs = embex.extract_embs("path/to/model",
        ...                           "path/to/input_data",
        ...                           "path/to/output_directory",
        ...                           "output_prefix")

    """

    # Load the .dataset and apply any configured filter criteria.
    filtered_input_data = pu.load_and_filter(
        self.filter_data, self.nproc, input_data_file
    )
    # Narrow further to a single cell state when requested (state-embedding mode).
    if cell_state is not None:
        filtered_input_data = pu.filter_by_dict(
            filtered_input_data, cell_state, self.nproc
        )
    # Randomly downsample to max_ncells and sort for efficient batching.
    downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
    model = pu.load_model(self.model_type, self.num_classes, model_directory)
    # emb_layer is a relative offset (0 = last layer, -1 = 2nd to last).
    layer_to_quant = pu.quant_layers(model) + self.emb_layer
    embs = get_embs(
        model,
        downsampled_data,
        self.emb_mode,
        layer_to_quant,
        self.pad_token_id,
        self.forward_batch_size,
        self.summary_stat,
    )

    # Build a dataframe view of the embeddings; shape/labels depend on
    # whether per-cell or per-gene embeddings were requested and whether
    # a running summary statistic was applied inside get_embs.
    if self.emb_mode == "cell":
        if self.summary_stat is None:
            embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
        elif self.summary_stat is not None:
            embs_df = pd.DataFrame(embs.cpu().numpy()).T
    elif self.emb_mode == "gene":
        if self.summary_stat is None:
            embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
        elif self.summary_stat is not None:
            embs_df = pd.DataFrame(embs).T
            embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]

    # save embeddings to output_path (skipped in cell_state mode, where the
    # caller consumes the returned tensor instead)
    if cell_state is None:
        output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
        embs_df.to_csv(output_path)

    # Collapse to a single embedding position for exact summary stats.
    # NOTE(review): the 0:255 row slice is a hard-coded bound — presumably
    # the embedding dimensions of this model; confirm against model config.
    if self.exact_summary_stat == "exact_mean":
        embs = embs.mean(dim=0)
        embs_df = pd.DataFrame(
            embs_df[0:255].mean(axis="rows"), columns=[self.exact_summary_stat]
        ).T
    elif self.exact_summary_stat == "exact_median":
        embs = torch.median(embs, dim=0)[0]
        embs_df = pd.DataFrame(
            embs_df[0:255].median(axis="rows"), columns=[self.exact_summary_stat]
        ).T

    # In cell_state mode return the tensor; otherwise return dataframe
    # (and optionally the tensor as well).
    if cell_state is not None:
        return embs
    else:
        if output_torch_embs:
            return embs_df, embs
        else:
            return embs_df
|
| 615 |
+
|
| 616 |
+
def get_state_embs(
    self,
    cell_states_to_model,
    model_directory,
    input_data_file,
    output_directory,
    output_prefix,
    output_torch_embs=True,
):
    """
    Extract exact mean or exact median cell state embedding positions from input data and save as results in output_directory.

    **Parameters:**

    cell_states_to_model : None, dict
        | Cell states to model if testing perturbations that achieve goal state change.
        | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
        | state_key: key specifying name of column in .dataset that defines the start/goal states
        | start_state: value in the state_key column that specifies the start state
        | goal_state: value in the state_key column that specifies the goal end state
        | alt_states: list of values in the state_key column that specify the alternate end states
        | For example:
        | {"state_key": "disease",
        |  "start_state": "dcm",
        |  "goal_state": "nf",
        |  "alt_states": ["hcm", "other1", "other2"]}
    model_directory : Path
        | Path to directory containing model
    input_data_file : Path
        | Path to directory containing .dataset inputs
    output_directory : Path
        | Path to directory where embedding data will be saved as csv
    output_prefix : str
        | Prefix for output file
    output_torch_embs : bool
        | Whether or not to also output the embeddings as a tensor.
        | Note, if true, will output embeddings as both dataframe and tensor.

    **Outputs**

    | Outputs state_embs_dict for use with in silico perturber.
    | Format is dictionary of embedding positions of each cell state to model shifts from/towards.
    | Keys specify each possible cell state to model.
    | Values are target embedding positions as torch.tensor.
    | For example:
    | {"nf": emb_nf,
    |  "hcm": emb_hcm,
    |  "dcm": emb_dcm,
    |  "other1": emb_other1,
    |  "other2": emb_other2}

    Raises
    ------
    ValueError
        If ``summary_stat`` was not configured as "exact_mean" or "exact_median".
    """

    pu.validate_cell_states_to_model(cell_states_to_model)
    valid_summary_stats = ["exact_mean", "exact_median"]
    if self.exact_summary_stat not in valid_summary_stats:
        logger.error(
            "For extracting state embs, summary_stat in EmbExtractor "
            f"must be set to option in {valid_summary_stats}"
        )
        # Bug fix: bare `raise` with no active exception would itself fail
        # with RuntimeError; raise an explicit, catchable error instead.
        raise ValueError(
            "For extracting state embs, summary_stat in EmbExtractor "
            f"must be set to option in {valid_summary_stats}"
        )

    state_embs_dict = dict()
    state_key = cell_states_to_model["state_key"]
    # Extract one summarized embedding per modeled state; each call filters
    # the input data down to cells in that state via cell_state.
    for k, v in cell_states_to_model.items():
        if k == "state_key":
            continue
        elif (k == "start_state") or (k == "goal_state"):
            state_embs_dict[v] = self.extract_embs(
                model_directory,
                input_data_file,
                output_directory,
                output_prefix,
                output_torch_embs,
                cell_state={state_key: v},
            )
        else:  # k == "alt_states"
            for alt_state in v:
                state_embs_dict[alt_state] = self.extract_embs(
                    model_directory,
                    input_data_file,
                    output_directory,
                    output_prefix,
                    output_torch_embs,
                    cell_state={state_key: alt_state},
                )

    # Persist the dictionary for reuse by the in silico perturber.
    output_path = (Path(output_directory) / output_prefix).with_suffix(".pkl")
    with open(output_path, "wb") as fp:
        pickle.dump(state_embs_dict, fp)

    return state_embs_dict
|
| 707 |
+
|
| 708 |
+
def plot_embs(
    self,
    embs,
    plot_style,
    output_directory,
    output_prefix,
    max_ncells_to_plot=1000,
    kwargs_dict=None,
):
    """
    Plot embeddings, coloring by provided labels.

    **Parameters:**

    embs : pandas.core.frame.DataFrame
        | Pandas dataframe containing embeddings output from extract_embs
    plot_style : str
        | Style of plot: "heatmap" or "umap"
    output_directory : Path
        | Path to directory where plots will be saved as pdf
    output_prefix : str
        | Prefix for output file
    max_ncells_to_plot : None, int
        | Maximum number of cells to plot.
        | Default is 1000 cells randomly sampled from embeddings.
        | If None, will plot embeddings from all cells.
    kwargs_dict : dict
        | Dictionary of kwargs to pass to plotting function.

    **Examples:**

    .. code-block :: python

        >>> embex.plot_embs(embs=embs,
        ...                 plot_style="heatmap",
        ...                 output_directory="path/to/output_directory",
        ...                 output_prefix="output_prefix")

    Raises
    ------
    ValueError
        If ``plot_style`` is invalid, or "umap" is requested without
        ``labels_to_plot``.
    """

    if plot_style not in ["heatmap", "umap"]:
        logger.error(
            "Invalid option for 'plot_style'. " "Valid options: {'heatmap','umap'}"
        )
        # Bug fix: bare `raise` with no active exception fails with
        # RuntimeError; raise an explicit, catchable error instead.
        raise ValueError(
            "Invalid option for 'plot_style'. Valid options: {'heatmap','umap'}"
        )

    if (plot_style == "umap") and (self.labels_to_plot is None):
        logger.error("Plotting UMAP requires 'labels_to_plot'. ")
        raise ValueError("Plotting UMAP requires 'labels_to_plot'.")

    # Bug fix: the original compared `max_ncells_to_plot > self.max_ncells`
    # unconditionally, which raises TypeError when either value is None
    # (None is both the documented default for max_ncells and an allowed
    # value for max_ncells_to_plot). Guard both against None first.
    if (
        (max_ncells_to_plot is not None)
        and (self.max_ncells is not None)
        and (max_ncells_to_plot > self.max_ncells)
    ):
        max_ncells_to_plot = self.max_ncells
        logger.warning(
            "max_ncells_to_plot must be <= max_ncells. "
            f"Changing max_ncells_to_plot to {self.max_ncells}."
        )

    # Downsample rows for plotting when a cap was requested and the
    # dataframe may contain more cells than the cap.
    if (max_ncells_to_plot is not None) and (
        self.max_ncells is None or max_ncells_to_plot < self.max_ncells
    ):
        embs = embs.sample(max_ncells_to_plot, axis=0)

    # Trailing columns of embs hold label columns (if any); the leading
    # columns hold the embedding dimensions.
    if self.emb_label is None:
        label_len = 0
    else:
        label_len = len(self.emb_label)

    emb_dims = embs.shape[1] - label_len

    if self.emb_label is None:
        emb_labels = None
    else:
        emb_labels = embs.columns[emb_dims:]

    if plot_style == "umap":
        for label in self.labels_to_plot:
            if label not in emb_labels:
                logger.warning(
                    f"Label {label} from labels_to_plot "
                    f"not present in provided embeddings dataframe."
                )
                continue
            output_prefix_label = "_" + output_prefix + f"_umap_{label}"
            output_file = (
                Path(output_directory) / output_prefix_label
            ).with_suffix(".pdf")
            # NOTE(review): plot_umap receives output_prefix_label while
            # output_file is unused here (the heatmap branch passes
            # output_file) — looks inconsistent; confirm against plot_umap's
            # signature before changing.
            plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)

    if plot_style == "heatmap":
        for label in self.labels_to_plot:
            if label not in emb_labels:
                logger.warning(
                    f"Label {label} from labels_to_plot "
                    f"not present in provided embeddings dataframe."
                )
                continue
            output_prefix_label = output_prefix + f"_heatmap_{label}"
            output_file = (
                Path(output_directory) / output_prefix_label
            ).with_suffix(".pdf")
            plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
geneformer/gene_median_dictionary.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b3b589bb5ec75040d05fc44dd6bf0184cf87f3c362cf158d196a6ed3b7fe5f39
|
| 3 |
+
size 940965
|
geneformer/gene_name_id_dict.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:55e67962e79c0039a6c32d43c5c99f38e51964bbcfa32f736150ee1e285c438c
|
| 3 |
+
size 1117117
|
geneformer/in_silico_perturber.py
ADDED
|
@@ -0,0 +1,915 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Geneformer in silico perturber.
|
| 3 |
+
|
| 4 |
+
**Usage:**
|
| 5 |
+
|
| 6 |
+
.. code-block :: python
|
| 7 |
+
|
| 8 |
+
>>> from geneformer import InSilicoPerturber
|
| 9 |
+
>>> isp = InSilicoPerturber(perturb_type="delete",
|
| 10 |
+
... perturb_rank_shift=None,
|
| 11 |
+
... genes_to_perturb="all",
|
| 12 |
+
... model_type="CellClassifier",
|
| 13 |
+
... num_classes=0,
|
| 14 |
+
... emb_mode="cell",
|
| 15 |
+
... filter_data={"cell_type":["cardiomyocyte"]},
|
| 16 |
+
... cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
|
| 17 |
+
... state_embs_dict ={"nf": emb_nf, "hcm": emb_hcm, "dcm": emb_dcm, "other1": emb_other1, "other2": emb_other2},
|
| 18 |
+
... max_ncells=None,
|
| 19 |
+
... emb_layer=0,
|
| 20 |
+
... forward_batch_size=100,
|
| 21 |
+
... nproc=16)
|
| 22 |
+
>>> isp.perturb_data("path/to/model",
|
| 23 |
+
... "path/to/input_data",
|
| 24 |
+
... "path/to/output_directory",
|
| 25 |
+
... "output_prefix")
|
| 26 |
+
|
| 27 |
+
**Description:**
|
| 28 |
+
|
| 29 |
+
| Performs in silico perturbation (e.g. deletion or overexpression) of defined set of genes or all genes in sample of cells.
|
| 30 |
+
| Outputs impact of perturbation on cell or gene embeddings.
|
| 31 |
+
| Output files are analyzed with ``in_silico_perturber_stats``.
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
import logging
|
| 36 |
+
|
| 37 |
+
# imports
|
| 38 |
+
import os
|
| 39 |
+
import pickle
|
| 40 |
+
from collections import defaultdict
|
| 41 |
+
from typing import List
|
| 42 |
+
|
| 43 |
+
import seaborn as sns
|
| 44 |
+
import torch
|
| 45 |
+
from datasets import Dataset
|
| 46 |
+
from tqdm.auto import trange
|
| 47 |
+
|
| 48 |
+
from . import perturber_utils as pu
|
| 49 |
+
from .emb_extractor import get_embs
|
| 50 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
| 51 |
+
|
| 52 |
+
sns.set()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
logger = logging.getLogger(__name__)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class InSilicoPerturber:
|
| 59 |
+
# Maps each constructor attribute to the set of valid values and/or type
# objects it may hold; consumed by validate_options() for membership and
# isinstance checks.
valid_option_dict = {
    "perturb_type": {"delete", "overexpress", "inhibit", "activate"},
    "perturb_rank_shift": {None, 1, 2, 3},
    "genes_to_perturb": {"all", list},
    "combos": {0, 1},
    "anchor_gene": {None, str},
    "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
    "num_classes": {int},
    "emb_mode": {"cell", "cell_and_gene"},
    "cell_emb_style": {"mean_pool"},
    "filter_data": {None, dict},
    "cell_states_to_model": {None, dict},
    "state_embs_dict": {None, dict},
    "max_ncells": {None, int},
    "cell_inds_to_perturb": {"all", dict},
    "emb_layer": {-1, 0},
    "forward_batch_size": {int},
    "nproc": {int},
}
|
| 78 |
+
|
| 79 |
+
def __init__(
    self,
    perturb_type="delete",
    perturb_rank_shift=None,
    genes_to_perturb="all",
    combos=0,
    anchor_gene=None,
    model_type="Pretrained",
    num_classes=0,
    emb_mode="cell",
    cell_emb_style="mean_pool",
    filter_data=None,
    cell_states_to_model=None,
    state_embs_dict=None,
    max_ncells=None,
    cell_inds_to_perturb="all",
    emb_layer=-1,
    forward_batch_size=100,
    nproc=4,
    token_dictionary_file=TOKEN_DICTIONARY_FILE,
):
    """
    Initialize in silico perturber.

    **Parameters:**

    perturb_type : {"delete", "overexpress", "inhibit", "activate"}
        | Type of perturbation.
        | "delete": delete gene from rank value encoding
        | "overexpress": move gene to front of rank value encoding
        | *(TBA)* "inhibit": move gene to lower quartile of rank value encoding
        | *(TBA)* "activate": move gene to higher quartile of rank value encoding
    *(TBA)* perturb_rank_shift : None, {1,2,3}
        | Number of quartiles by which to shift rank of gene.
        | For example, if perturb_type="activate" and perturb_rank_shift=1:
        | genes in 4th quartile will move to middle of 3rd quartile.
        | genes in 3rd quartile will move to middle of 2nd quartile.
        | genes in 2nd quartile will move to middle of 1st quartile.
        | genes in 1st quartile will move to front of rank value encoding.
        | For example, if perturb_type="inhibit" and perturb_rank_shift=2:
        | genes in 1st quartile will move to middle of 3rd quartile.
        | genes in 2nd quartile will move to middle of 4th quartile.
        | genes in 3rd or 4th quartile will move to bottom of rank value encoding.
    genes_to_perturb : "all", list
        | Default is perturbing each gene detected in each cell in the dataset.
        | Otherwise, may provide a list of ENSEMBL IDs of genes to perturb.
        | If gene list is provided, then perturber will only test perturbing them all together
        | (rather than testing each possible combination of the provided genes).
    combos : {0,1}
        | Whether to perturb genes individually (0) or in pairs (1).
    anchor_gene : None, str
        | ENSEMBL ID of gene to use as anchor in combination perturbations.
        | For example, if combos=1 and anchor_gene="ENSG00000148400":
        | anchor gene will be perturbed in combination with each other gene.
    model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
        | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
    num_classes : int
        | If model is a gene or cell classifier, specify number of classes it was trained to classify.
        | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
    emb_mode : {"cell", "cell_and_gene"}
        | Whether to output impact of perturbation on cell and/or gene embeddings.
        | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
    cell_emb_style : "mean_pool"
        | Method for summarizing cell embeddings.
        | Currently only option is mean pooling of gene embeddings for given cell.
    filter_data : None, dict
        | Default is to use all input data for in silico perturbation study.
        | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
    cell_states_to_model : None, dict
        | Cell states to model if testing perturbations that achieve goal state change.
        | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
        | state_key: key specifying name of column in .dataset that defines the start/goal states
        | start_state: value in the state_key column that specifies the start state
        | goal_state: value in the state_key column that specifies the goal end state
        | alt_states: list of values in the state_key column that specify the alternate end states
        | For example: {"state_key": "disease",
        |               "start_state": "dcm",
        |               "goal_state": "nf",
        |               "alt_states": ["hcm", "other1", "other2"]}
    state_embs_dict : None, dict
        | Embedding positions of each cell state to model shifts from/towards (e.g. mean or median).
        | Dictionary with keys specifying each possible cell state to model.
        | Values are target embedding positions as torch.tensor.
        | For example: {"nf": emb_nf,
        |               "hcm": emb_hcm,
        |               "dcm": emb_dcm,
        |               "other1": emb_other1,
        |               "other2": emb_other2}
    max_ncells : None, int
        | Maximum number of cells to test.
        | If None, will test all cells.
    cell_inds_to_perturb : "all", list
        | Default is perturbing each cell in the dataset.
        | Otherwise, may provide a dict of indices of cells to perturb with keys start_ind and end_ind.
        | start_ind: the first index to perturb.
        | end_ind: the last index to perturb (exclusive).
        | Indices will be selected *after* the filter_data criteria and sorting.
        | Useful for splitting extremely large datasets across separate GPUs.
    emb_layer : {-1, 0}
        | Embedding layer to use for quantification.
        | 0: last layer (recommended for questions closely tied to model's training objective)
        | -1: 2nd to last layer (recommended for questions requiring more general representations)
    forward_batch_size : int
        | Batch size for forward pass.
    nproc : int
        | Number of CPU processes to use.
    token_dictionary_file : Path
        | Path to pickle file containing token dictionary (Ensembl ID:token).

    Raises
    ------
    ValueError
        If none of the provided genes_to_perturb are in the token dictionary.
    KeyError
        If anchor_gene is not in the token dictionary.
    """

    self.perturb_type = perturb_type
    self.perturb_rank_shift = perturb_rank_shift
    self.genes_to_perturb = genes_to_perturb
    self.combos = combos
    self.anchor_gene = anchor_gene
    if self.genes_to_perturb == "all":
        self.perturb_group = False
    else:
        self.perturb_group = True
        # An explicit gene list is always perturbed together, so anchor
        # gene and combination settings are incompatible — reset them.
        if (self.anchor_gene is not None) or (self.combos != 0):
            self.anchor_gene = None
            self.combos = 0
            logger.warning(
                "anchor_gene set to None and combos set to 0. "
                "If providing list of genes to perturb, "
                "list of genes_to_perturb will be perturbed together, "
                "without anchor gene or combinations."
            )
    self.model_type = model_type
    self.num_classes = num_classes
    self.emb_mode = emb_mode
    self.cell_emb_style = cell_emb_style
    self.filter_data = filter_data
    self.cell_states_to_model = cell_states_to_model
    self.state_embs_dict = state_embs_dict
    self.max_ncells = max_ncells
    self.cell_inds_to_perturb = cell_inds_to_perturb
    self.emb_layer = emb_layer
    self.forward_batch_size = forward_batch_size
    self.nproc = nproc

    self.validate_options()

    # load token dictionary (Ensembl IDs:token)
    with open(token_dictionary_file, "rb") as f:
        self.gene_token_dict = pickle.load(f)

    self.pad_token_id = self.gene_token_dict.get("<pad>")

    # Resolve anchor gene to its token (re-raising KeyError if unknown).
    if self.anchor_gene is None:
        self.anchor_token = None
    else:
        try:
            self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
        except KeyError:
            logger.error(f"Anchor gene {self.anchor_gene} not in token dictionary.")
            raise

    # Resolve genes_to_perturb to tokens, warning about (and tolerating)
    # partially-missing genes but failing when none can be resolved.
    if self.genes_to_perturb == "all":
        self.tokens_to_perturb = "all"
    else:
        missing_genes = [
            gene
            for gene in self.genes_to_perturb
            if gene not in self.gene_token_dict.keys()
        ]
        if len(missing_genes) == len(self.genes_to_perturb):
            logger.error(
                "None of the provided genes to perturb are in token dictionary."
            )
            # Bug fix: bare `raise` with no active exception fails with
            # RuntimeError; raise an explicit, catchable error instead.
            raise ValueError(
                "None of the provided genes to perturb are in token dictionary."
            )
        elif len(missing_genes) > 0:
            logger.warning(
                f"Genes to perturb {missing_genes} are not in token dictionary."
            )
        self.tokens_to_perturb = [
            self.gene_token_dict.get(gene) for gene in self.genes_to_perturb
        ]
|
| 257 |
+
|
| 258 |
+
def validate_options(self):
|
| 259 |
+
# first disallow options under development
|
| 260 |
+
if self.perturb_type in ["inhibit", "activate"]:
|
| 261 |
+
logger.error(
|
| 262 |
+
"In silico inhibition and activation currently under development. "
|
| 263 |
+
"Current valid options for 'perturb_type': 'delete' or 'overexpress'"
|
| 264 |
+
)
|
| 265 |
+
raise
|
| 266 |
+
if (self.combos > 0) and (self.anchor_token is None):
|
| 267 |
+
logger.error(
|
| 268 |
+
"Combination perturbation without anchor gene is currently under development. "
|
| 269 |
+
"Currently, must provide anchor gene for combination perturbation."
|
| 270 |
+
)
|
| 271 |
+
raise
|
| 272 |
+
|
| 273 |
+
# confirm arguments are within valid options and compatible with each other
|
| 274 |
+
for attr_name, valid_options in self.valid_option_dict.items():
|
| 275 |
+
attr_value = self.__dict__[attr_name]
|
| 276 |
+
if type(attr_value) not in {list, dict}:
|
| 277 |
+
if attr_value in valid_options:
|
| 278 |
+
continue
|
| 279 |
+
if attr_name in ["anchor_gene"]:
|
| 280 |
+
if type(attr_name) in {str}:
|
| 281 |
+
continue
|
| 282 |
+
valid_type = False
|
| 283 |
+
for option in valid_options:
|
| 284 |
+
if (option in [bool, int, list, dict]) and isinstance(
|
| 285 |
+
attr_value, option
|
| 286 |
+
):
|
| 287 |
+
valid_type = True
|
| 288 |
+
break
|
| 289 |
+
if valid_type:
|
| 290 |
+
continue
|
| 291 |
+
logger.error(
|
| 292 |
+
f"Invalid option for {attr_name}. "
|
| 293 |
+
f"Valid options for {attr_name}: {valid_options}"
|
| 294 |
+
)
|
| 295 |
+
raise
|
| 296 |
+
|
| 297 |
+
if self.perturb_type in ["delete", "overexpress"]:
|
| 298 |
+
if self.perturb_rank_shift is not None:
|
| 299 |
+
if self.perturb_type == "delete":
|
| 300 |
+
logger.warning(
|
| 301 |
+
"perturb_rank_shift set to None. "
|
| 302 |
+
"If perturb type is delete then gene is deleted entirely "
|
| 303 |
+
"rather than shifted by quartile"
|
| 304 |
+
)
|
| 305 |
+
elif self.perturb_type == "overexpress":
|
| 306 |
+
logger.warning(
|
| 307 |
+
"perturb_rank_shift set to None. "
|
| 308 |
+
"If perturb type is overexpress then gene is moved to front "
|
| 309 |
+
"of rank value encoding rather than shifted by quartile"
|
| 310 |
+
)
|
| 311 |
+
self.perturb_rank_shift = None
|
| 312 |
+
|
| 313 |
+
if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
|
| 314 |
+
self.emb_mode = "cell"
|
| 315 |
+
logger.warning(
|
| 316 |
+
"emb_mode set to 'cell'. "
|
| 317 |
+
"Currently, analysis with anchor gene "
|
| 318 |
+
"only outputs effect on cell embeddings."
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
if self.cell_states_to_model is not None:
|
| 322 |
+
pu.validate_cell_states_to_model(self.cell_states_to_model)
|
| 323 |
+
|
| 324 |
+
if self.anchor_gene is not None:
|
| 325 |
+
self.anchor_gene = None
|
| 326 |
+
logger.warning(
|
| 327 |
+
"anchor_gene set to None. "
|
| 328 |
+
"Currently, anchor gene not available "
|
| 329 |
+
"when modeling multiple cell states."
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
if self.state_embs_dict is None:
|
| 333 |
+
logger.error(
|
| 334 |
+
"state_embs_dict must be provided for mode with cell_states_to_model. "
|
| 335 |
+
"Format is dictionary with keys specifying each possible cell state to model. "
|
| 336 |
+
"Values are target embedding positions as torch.tensor."
|
| 337 |
+
)
|
| 338 |
+
raise
|
| 339 |
+
|
| 340 |
+
for state_emb in self.state_embs_dict.values():
|
| 341 |
+
if not torch.is_tensor(state_emb):
|
| 342 |
+
logger.error(
|
| 343 |
+
"state_embs_dict must be dictionary with values being torch.tensor."
|
| 344 |
+
)
|
| 345 |
+
raise
|
| 346 |
+
|
| 347 |
+
keys_absent = []
|
| 348 |
+
for k, v in self.cell_states_to_model.items():
|
| 349 |
+
if (k == "start_state") or (k == "goal_state"):
|
| 350 |
+
if v not in self.state_embs_dict.keys():
|
| 351 |
+
keys_absent.append(v)
|
| 352 |
+
if k == "alt_states":
|
| 353 |
+
for state in v:
|
| 354 |
+
if state not in self.state_embs_dict.keys():
|
| 355 |
+
keys_absent.append(state)
|
| 356 |
+
if len(keys_absent) > 0:
|
| 357 |
+
logger.error(
|
| 358 |
+
"Each start_state, goal_state, and alt_states in cell_states_to_model "
|
| 359 |
+
"must be a key in state_embs_dict with the value being "
|
| 360 |
+
"the state's embedding position as torch.tensor. "
|
| 361 |
+
f"Missing keys: {keys_absent}"
|
| 362 |
+
)
|
| 363 |
+
raise
|
| 364 |
+
|
| 365 |
+
if self.perturb_type in ["inhibit", "activate"]:
|
| 366 |
+
if self.perturb_rank_shift is None:
|
| 367 |
+
logger.error(
|
| 368 |
+
"If perturb_type is inhibit or activate then "
|
| 369 |
+
"quartile to shift by must be specified."
|
| 370 |
+
)
|
| 371 |
+
raise
|
| 372 |
+
|
| 373 |
+
if self.filter_data is not None:
|
| 374 |
+
for key, value in self.filter_data.items():
|
| 375 |
+
if not isinstance(value, list):
|
| 376 |
+
self.filter_data[key] = [value]
|
| 377 |
+
logger.warning(
|
| 378 |
+
"Values in filter_data dict must be lists. "
|
| 379 |
+
f"Changing {key} value to list ([{value}])."
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
if self.cell_inds_to_perturb != "all":
|
| 383 |
+
if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
|
| 384 |
+
logger.error(
|
| 385 |
+
"If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
|
| 386 |
+
)
|
| 387 |
+
raise
|
| 388 |
+
if (
|
| 389 |
+
self.cell_inds_to_perturb["start"] < 0
|
| 390 |
+
or self.cell_inds_to_perturb["end"] < 0
|
| 391 |
+
):
|
| 392 |
+
logger.error("cell_inds_to_perturb must be positive.")
|
| 393 |
+
raise
|
| 394 |
+
|
| 395 |
+
def perturb_data(
    self, model_directory, input_data_file, output_directory, output_prefix
):
    """
    Perturb genes in input data and save as results in output_directory.

    **Parameters:**

    model_directory : Path
        | Path to directory containing model
    input_data_file : Path
        | Path to directory containing .dataset inputs
    output_directory : Path
        | Path to directory where perturbation data will be saved as batched pickle files
    output_prefix : str
        | Prefix for output files
    """

    # Build the common prefix used for every output pickle.
    result_prefix = os.path.join(
        output_directory, f"in_silico_{self.perturb_type}_{output_prefix}"
    )

    # Load the model and derive forward-pass parameters from it.
    model = pu.load_model(self.model_type, self.num_classes, model_directory)
    self.max_len = pu.get_model_input_size(model)
    quant_layer = pu.quant_layers(model) + self.emb_layer

    # Apply the user-specified filter_data criteria, then the mode-dependent
    # filters (start state, perturbed tokens, anchor gene, downsampling).
    dataset = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)
    dataset = self.apply_additional_filters(dataset)

    # Dispatch to group-perturbation or exhaustive single-gene perturbation.
    if self.perturb_group is True:
        self.isp_perturb_set(model, dataset, quant_layer, result_prefix)
    else:
        self.isp_perturb_all(model, dataset, quant_layer, result_prefix)
| 439 |
+
def apply_additional_filters(self, filtered_input_data):
    """Apply mode-dependent filtering, downsampling, and slicing.

    Returns the filtered dataset; which filters run depends on the
    perturber configuration (cell states, tokens to perturb, anchor gene,
    max_ncells, cell_inds_to_perturb).
    """
    dataset = filtered_input_data

    # When modeling cell-state shifts, keep only cells in the start state.
    if self.cell_states_to_model is not None:
        dataset = pu.filter_data_by_start_state(
            dataset, self.cell_states_to_model, self.nproc
        )

    # Deletion of a specific gene set only makes sense for cells that
    # actually express those genes (overexpression needs no such filter).
    if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
        dataset = pu.filter_data_by_tokens_and_log(
            dataset,
            self.tokens_to_perturb,
            self.nproc,
            "genes_to_perturb",
        )

    # Anchor-gene analysis requires the anchor to be present in each cell.
    if self.anchor_token is not None:
        dataset = pu.filter_data_by_tokens_and_log(
            dataset, self.anchor_token, self.nproc, "anchor_gene"
        )

    # downsample and sort largest to smallest to encounter memory constraints earlier
    dataset = pu.downsample_and_sort(dataset, self.max_ncells)

    # Optionally restrict to a [start, end) slice of cells.
    if self.cell_inds_to_perturb != "all":
        dataset = pu.slice_by_inds_to_perturb(dataset, self.cell_inds_to_perturb)

    return dataset
| 475 |
+
def isp_perturb_set(
    self,
    model,
    filtered_input_data: Dataset,
    layer_to_quant: int,
    output_path_prefix: str,
):
    """Perturb the same gene set in every cell and write cosine-shift results.

    For each minibatch, embeds the original and perturbed cells, computes
    cosine similarities between them (gene-level, cell-level, or both,
    depending on ``emb_mode`` / ``cell_states_to_model``), and pickles the
    accumulated dictionaries to ``output_path_prefix``-based files.

    NOTE(review): tensors are created with device="cuda", so this method
    assumes a CUDA-capable GPU is available — confirm before running on CPU.
    """

    def make_group_perturbation_batch(example):
        # Record, per cell, where the perturbed tokens sit in the rank
        # value encoding, then apply the deletion/overexpression.
        example_input_ids = example["input_ids"]
        example["tokens_to_perturb"] = self.tokens_to_perturb
        indices_to_perturb = [
            example_input_ids.index(token) if token in example_input_ids else None
            for token in self.tokens_to_perturb
        ]
        indices_to_perturb = [
            item for item in indices_to_perturb if item is not None
        ]
        if len(indices_to_perturb) > 0:
            example["perturb_index"] = indices_to_perturb
        else:
            # -100 indicates tokens to overexpress are not present in rank value encoding
            example["perturb_index"] = [-100]
        if self.perturb_type == "delete":
            example = pu.delete_indices(example)
        elif self.perturb_type == "overexpress":
            example = pu.overexpress_tokens(example, self.max_len)
            example["n_overflow"] = pu.calc_n_overflow(
                self.max_len,
                example["length"],
                self.tokens_to_perturb,
                indices_to_perturb,
            )
        return example

    total_batch_length = len(filtered_input_data)
    # One flat dict when no goal states; otherwise one dict per possible state.
    if self.cell_states_to_model is None:
        cos_sims_dict = defaultdict(list)
    else:
        cos_sims_dict = {
            state: defaultdict(list)
            for state in pu.get_possible_states(self.cell_states_to_model)
        }

    perturbed_data = filtered_input_data.map(
        make_group_perturbation_batch, num_proc=self.nproc
    )
    if self.perturb_type == "overexpress":
        filtered_input_data = filtered_input_data.add_column(
            "n_overflow", perturbed_data["n_overflow"]
        )
        # remove overflow genes from original data so that embeddings are comparable
        # i.e. if original cell has genes 0:2047 and you want to overexpress new gene 2048,
        # then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
        # (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
        # rather than only adding 2048)
        filtered_input_data = filtered_input_data.map(
            pu.truncate_by_n_overflow, num_proc=self.nproc
        )

    if self.emb_mode == "cell_and_gene":
        stored_gene_embs_dict = defaultdict(list)

    # iterate through batches
    for i in trange(0, total_batch_length, self.forward_batch_size):
        max_range = min(i + self.forward_batch_size, total_batch_length)
        inds_select = [i for i in range(i, max_range)]

        minibatch = filtered_input_data.select(inds_select)
        perturbation_batch = perturbed_data.select(inds_select)

        if self.cell_emb_style == "mean_pool":
            # Token-level embeddings for the unperturbed cells.
            full_original_emb = get_embs(
                model,
                minibatch,
                "gene",
                layer_to_quant,
                self.pad_token_id,
                self.forward_batch_size,
                summary_stat=None,
                silent=True,
            )
            indices_to_perturb = perturbation_batch["perturb_index"]
            # remove indices that were perturbed
            original_emb = pu.remove_perturbed_indices_set(
                full_original_emb,
                self.perturb_type,
                indices_to_perturb,
                self.tokens_to_perturb,
                minibatch["length"],
            )
            full_perturbation_emb = get_embs(
                model,
                perturbation_batch,
                "gene",
                layer_to_quant,
                self.pad_token_id,
                self.forward_batch_size,
                summary_stat=None,
                silent=True,
            )

            # remove overexpressed genes
            if self.perturb_type == "overexpress":
                perturbation_emb = full_perturbation_emb[
                    :, len(self.tokens_to_perturb) :, :
                ]

            elif self.perturb_type == "delete":
                perturbation_emb = full_perturbation_emb[
                    :, : max(perturbation_batch["length"]), :
                ]

            n_perturbation_genes = perturbation_emb.size()[1]

            # if no goal states, the cosine similarties are the mean of gene cosine similarities
            if (
                self.cell_states_to_model is None
                or self.emb_mode == "cell_and_gene"
            ):
                gene_cos_sims = pu.quant_cos_sims(
                    perturbation_emb,
                    original_emb,
                    self.cell_states_to_model,
                    self.state_embs_dict,
                    emb_mode="gene",
                )

            # if there are goal states, the cosine similarities are the cell cosine similarities
            if self.cell_states_to_model is not None:
                original_cell_emb = pu.mean_nonpadding_embs(
                    full_original_emb,
                    torch.tensor(minibatch["length"], device="cuda"),
                    dim=1,
                )
                perturbation_cell_emb = pu.mean_nonpadding_embs(
                    full_perturbation_emb,
                    torch.tensor(perturbation_batch["length"], device="cuda"),
                    dim=1,
                )
                cell_cos_sims = pu.quant_cos_sims(
                    perturbation_cell_emb,
                    original_cell_emb,
                    self.cell_states_to_model,
                    self.state_embs_dict,
                    emb_mode="cell",
                )

            # get cosine similarities in gene embeddings
            # if getting gene embeddings, need gene names
            if self.emb_mode == "cell_and_gene":
                gene_list = minibatch["input_ids"]
                # need to truncate gene_list
                gene_list = [
                    [g for g in genes if g not in self.tokens_to_perturb][
                        :n_perturbation_genes
                    ]
                    for genes in gene_list
                ]

                for cell_i, genes in enumerate(gene_list):
                    for gene_j, affected_gene in enumerate(genes):
                        # dict key is a tuple for multi-gene perturbations,
                        # a single token otherwise
                        if len(self.genes_to_perturb) > 1:
                            tokens_to_perturb = tuple(self.tokens_to_perturb)
                        else:
                            tokens_to_perturb = self.tokens_to_perturb[0]

                        # fill in the gene cosine similarities
                        try:
                            stored_gene_embs_dict[
                                (tokens_to_perturb, affected_gene)
                            ].append(gene_cos_sims[cell_i, gene_j].item())
                        except KeyError:
                            stored_gene_embs_dict[
                                (tokens_to_perturb, affected_gene)
                            ] = gene_cos_sims[cell_i, gene_j].item()
            else:
                gene_list = None

            if self.cell_states_to_model is None:
                # calculate the mean of the gene cosine similarities for cell shift
                # tensor of nonpadding lengths for each cell
                if self.perturb_type == "overexpress":
                    # subtract number of genes that were overexpressed
                    # since they are removed before getting cos sims
                    n_overexpressed = len(self.tokens_to_perturb)
                    nonpadding_lens = [
                        x - n_overexpressed for x in perturbation_batch["length"]
                    ]
                else:
                    nonpadding_lens = perturbation_batch["length"]
                cos_sims_data = pu.mean_nonpadding_embs(
                    gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
                )
                cos_sims_dict = self.update_perturbation_dictionary(
                    cos_sims_dict,
                    cos_sims_data,
                    filtered_input_data,
                    indices_to_perturb,
                    gene_list,
                )
            else:
                cos_sims_data = cell_cos_sims
                for state in cos_sims_dict.keys():
                    cos_sims_dict[state] = self.update_perturbation_dictionary(
                        cos_sims_dict[state],
                        cos_sims_data[state],
                        filtered_input_data,
                        indices_to_perturb,
                        gene_list,
                    )
            # free large per-batch tensors before the next iteration
            del minibatch
            del perturbation_batch
            del original_emb
            del perturbation_emb
            del cos_sims_data

        torch.cuda.empty_cache()

    pu.write_perturbation_dictionary(
        cos_sims_dict,
        f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
    )

    if self.emb_mode == "cell_and_gene":
        pu.write_perturbation_dictionary(
            stored_gene_embs_dict,
            f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
        )
| 704 |
+
def isp_perturb_all(
    self,
    filtered_input_data: Dataset,
    layer_to_quant: int,
    output_path_prefix: str,
):
    """Perturb every gene (or every anchor-gene combination) in each cell.

    Iterates cell by cell, builds one perturbed copy per candidate gene,
    embeds originals and perturbations, and accumulates cosine shifts.
    Results are checkpointed to pickle files every 100 cells and the
    accumulators are reset (with a new pickle-batch index) every 1000 cells.

    NOTE(review): relies on torch.cuda.empty_cache(), so a CUDA device is
    assumed — confirm before running on CPU.
    """
    pickle_batch = -1
    if self.cell_states_to_model is None:
        cos_sims_dict = defaultdict(list)
    else:
        cos_sims_dict = {
            state: defaultdict(list)
            for state in pu.get_possible_states(self.cell_states_to_model)
        }

    if self.emb_mode == "cell_and_gene":
        stored_gene_embs_dict = defaultdict(list)
    for i in trange(len(filtered_input_data)):
        example_cell = filtered_input_data.select([i])
        # Token-level embeddings of the single unperturbed cell.
        full_original_emb = get_embs(
            model,
            example_cell,
            "gene",
            layer_to_quant,
            self.pad_token_id,
            self.forward_batch_size,
            summary_stat=None,
            silent=True,
        )

        # gene_list is used to assign cos sims back to genes
        # need to remove the anchor gene
        gene_list = example_cell["input_ids"][0][:]
        if self.anchor_token is not None:
            for token in self.anchor_token:
                gene_list.remove(token)

        # One perturbed copy of the cell per candidate gene/combination.
        perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
            example_cell,
            self.perturb_type,
            self.tokens_to_perturb,
            self.anchor_token,
            self.combos,
            self.nproc,
        )

        full_perturbation_emb = get_embs(
            model,
            perturbation_batch,
            "gene",
            layer_to_quant,
            self.pad_token_id,
            self.forward_batch_size,
            summary_stat=None,
            silent=True,
        )

        num_inds_perturbed = 1 + self.combos
        # need to remove overexpressed gene to quantify cosine shifts
        if self.perturb_type == "overexpress":
            perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
            gene_list = gene_list[
                num_inds_perturbed:
            ]  # index 0 is not overexpressed

        elif self.perturb_type == "delete":
            perturbation_emb = full_perturbation_emb

        # Align original embeddings with each perturbed copy.
        original_batch = pu.make_comparison_batch(
            full_original_emb, indices_to_perturb, perturb_group=False
        )

        if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
            gene_cos_sims = pu.quant_cos_sims(
                perturbation_emb,
                original_batch,
                self.cell_states_to_model,
                self.state_embs_dict,
                emb_mode="gene",
            )
        if self.cell_states_to_model is not None:
            original_cell_emb = pu.compute_nonpadded_cell_embedding(
                full_original_emb, "mean_pool"
            )
            perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
                full_perturbation_emb, "mean_pool"
            )

            cell_cos_sims = pu.quant_cos_sims(
                perturbation_cell_emb,
                original_cell_emb,
                self.cell_states_to_model,
                self.state_embs_dict,
                emb_mode="cell",
            )

        if self.emb_mode == "cell_and_gene":
            # remove perturbed index for gene list
            perturbed_gene_dict = {
                gene: gene_list[:i] + gene_list[i + 1 :]
                for i, gene in enumerate(gene_list)
            }

            for perturbation_i, perturbed_gene in enumerate(gene_list):
                for gene_j, affected_gene in enumerate(
                    perturbed_gene_dict[perturbed_gene]
                ):
                    try:
                        stored_gene_embs_dict[
                            (perturbed_gene, affected_gene)
                        ].append(gene_cos_sims[perturbation_i, gene_j].item())
                    except KeyError:
                        stored_gene_embs_dict[
                            (perturbed_gene, affected_gene)
                        ] = gene_cos_sims[perturbation_i, gene_j].item()

        if self.cell_states_to_model is None:
            # cell-level shift = mean of gene-level cosine shifts
            cos_sims_data = torch.mean(gene_cos_sims, dim=1)
            cos_sims_dict = self.update_perturbation_dictionary(
                cos_sims_dict,
                cos_sims_data,
                filtered_input_data,
                indices_to_perturb,
                gene_list,
            )
        else:
            cos_sims_data = cell_cos_sims
            for state in cos_sims_dict.keys():
                cos_sims_dict[state] = self.update_perturbation_dictionary(
                    cos_sims_dict[state],
                    cos_sims_data[state],
                    filtered_input_data,
                    indices_to_perturb,
                    gene_list,
                )

        # save dict to disk every 100 cells
        if i % 100 == 0:
            pu.write_perturbation_dictionary(
                cos_sims_dict,
                f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
            )
            if self.emb_mode == "cell_and_gene":
                pu.write_perturbation_dictionary(
                    stored_gene_embs_dict,
                    f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
                )

        # reset and clear memory every 1000 cells
        if i % 1000 == 0:
            pickle_batch += 1
            if self.cell_states_to_model is None:
                cos_sims_dict = defaultdict(list)
            else:
                cos_sims_dict = {
                    state: defaultdict(list)
                    for state in pu.get_possible_states(self.cell_states_to_model)
                }

            if self.emb_mode == "cell_and_gene":
                stored_gene_embs_dict = defaultdict(list)

            torch.cuda.empty_cache()

    # final flush of whatever accumulated since the last checkpoint
    pu.write_perturbation_dictionary(
        cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
    )

    if self.emb_mode == "cell_and_gene":
        pu.write_perturbation_dictionary(
            stored_gene_embs_dict,
            f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
        )
| 879 |
+
def update_perturbation_dictionary(
    self,
    cos_sims_dict: defaultdict,
    cos_sims_data: "torch.Tensor",
    filtered_input_data: "Dataset",
    indices_to_perturb: "List[List[int]]",
    gene_list=None,
):
    """Fold a batch of cosine-similarity results into the accumulator dict.

    Parameters
    ----------
    cos_sims_dict : defaultdict
        Accumulator mapping (perturbed_gene(s), "cell_emb") -> list of floats.
    cos_sims_data : torch.Tensor
        Cosine shifts for this batch; shape (batch, 1) in group mode,
        (n_perturbations,) otherwise.
    filtered_input_data : Dataset
        The (filtered) input dataset; currently unused here but kept for
        interface stability with callers.
    indices_to_perturb : List[List[int]]
        Perturbed positions per cell; currently unused here.
    gene_list : list, optional
        Per-perturbation gene identifiers (non-group mode); must align with
        cos_sims_data's first dimension.

    Returns
    -------
    defaultdict
        The updated cos_sims_dict.

    Raises
    ------
    ValueError
        If gene_list length does not match cos_sims_data's first dimension.
    """
    if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
        # BUGFIX: message previously read "len(cos_sims_data.shape[0])",
        # which is not a length; also replaced bare `raise` (which outside
        # an except clause raises an opaque RuntimeError) with ValueError.
        logger.error(
            f"cos_sims_data.shape[0] != len(gene_list). \n \
            cos_sims_data.shape[0] = {cos_sims_data.shape[0]}.\n \
            len(gene_list) = {len(gene_list)}."
        )
        raise ValueError("cos_sims_data.shape[0] != len(gene_list).")

    if self.perturb_group is True:
        if len(self.tokens_to_perturb) > 1:
            perturbed_genes = tuple(self.tokens_to_perturb)
        else:
            perturbed_genes = self.tokens_to_perturb[0]

        # if cell embeddings, can just append
        # shape will be (batch size, 1)
        cos_sims_data = torch.squeeze(cos_sims_data).tolist()

        # handle case of single cell left (squeeze yields a scalar)
        if not isinstance(cos_sims_data, list):
            cos_sims_data = [cos_sims_data]

        cos_sims_dict[(perturbed_genes, "cell_emb")] += cos_sims_data

    else:
        for i, cos in enumerate(cos_sims_data.tolist()):
            cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)

    return cos_sims_dict
|
geneformer/in_silico_perturber_stats.py
ADDED
|
@@ -0,0 +1,1042 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Geneformer in silico perturber stats generator.
|
| 3 |
+
|
| 4 |
+
**Usage:**
|
| 5 |
+
|
| 6 |
+
.. code-block :: python
|
| 7 |
+
|
| 8 |
+
>>> from geneformer import InSilicoPerturberStats
|
| 9 |
+
>>> ispstats = InSilicoPerturberStats(mode="goal_state_shift",
|
| 10 |
+
... cell_states_to_model={"state_key": "disease",
|
| 11 |
+
... "start_state": "dcm",
|
| 12 |
+
... "goal_state": "nf",
|
| 13 |
+
... "alt_states": ["hcm", "other1", "other2"]})
|
| 14 |
+
>>> ispstats.get_stats("path/to/input_data",
|
| 15 |
+
... None,
|
| 16 |
+
... "path/to/output_directory",
|
| 17 |
+
... "output_prefix")
|
| 18 |
+
|
| 19 |
+
**Description:**
|
| 20 |
+
|
| 21 |
+
| Aggregates data or calculates stats for in silico perturbations based on type of statistics specified in InSilicoPerturberStats.
|
| 22 |
+
| Input data is raw in silico perturbation results in the form of dictionaries outputted by ``in_silico_perturber``.
|
| 23 |
+
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
import logging
|
| 28 |
+
import os
|
| 29 |
+
import pickle
|
| 30 |
+
import random
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
import pandas as pd
|
| 35 |
+
import statsmodels.stats.multitest as smt
|
| 36 |
+
from scipy.stats import ranksums
|
| 37 |
+
from sklearn.mixture import GaussianMixture
|
| 38 |
+
from tqdm.auto import tqdm, trange
|
| 39 |
+
|
| 40 |
+
from .perturber_utils import flatten_list, validate_cell_states_to_model
|
| 41 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
| 42 |
+
|
| 43 |
+
GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
|
| 44 |
+
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# invert dictionary keys/values
def invert_dict(dictionary):
    """Return a new dict mapping each value of *dictionary* back to its key."""
    inverted = {}
    for key, value in dictionary.items():
        inverted[value] = key
    return inverted
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token):
    """Filter one raw cosine-similarity dict down to the entries of interest.

    Keys are tuples of (perturbed token(s), affected target); values are lists
    of observed cosine shifts. Only entries with non-empty value lists are kept.

    cell_or_gene_emb == "cell": keep entries whose key tuple contains "cell_emb".
    cell_or_gene_emb == "gene": keep all entries, or — if *anchor_token* is
    given — only entries whose first key element is the anchor token.

    Returns a single-element list wrapping the filtered dict (implicitly None
    for any other mode, matching the original contract).
    """
    if cell_or_gene_emb == "cell":
        filtered = {}
        for key, value in cos_sims_dict.items():
            if value and "cell_emb" in key:
                filtered[key] = value
        return [filtered]
    elif cell_or_gene_emb == "gene":
        if anchor_token is None:
            filtered = {key: value for key, value in cos_sims_dict.items() if value}
        else:
            filtered = {}
            for key, value in cos_sims_dict.items():
                if value and key[0] == anchor_token:
                    filtered[key] = value
        return [filtered]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# read raw dictionary files
def read_dictionaries(
    input_data_directory,
    cell_or_gene_emb,
    anchor_token,
    cell_states_to_model,
    pickle_suffix,
):
    """Load and merge all raw perturbation pickles from a directory.

    **Parameters:**

    input_data_directory : str
        | Directory containing raw pickled cosine-similarity dictionaries.
    cell_or_gene_emb : {"cell", "gene"}
        | Which embedding entries to retain (see ``read_dict``).
    anchor_token : int, None
        | Anchor gene token to filter on in "gene" mode, if any.
    cell_states_to_model : dict, None
        | If not None, raw dicts are nested per modeled cell state and merged
        | into one dict per state value; otherwise dicts are simply collected.
    pickle_suffix : str
        | Filename suffix identifying raw data files (e.g. "_raw.pickle").

    **Returns:**

    | If cell_states_to_model is None: list of filtered dicts.
    | Otherwise: {state_value: merged dict of cosine-shift lists}.

    Raises if no file with the given suffix is found.
    """
    file_found = False
    file_path_list = []
    if cell_states_to_model is None:
        dict_list = []
    else:
        validate_cell_states_to_model(cell_states_to_model)
        # keep only non-empty, non-None state entries (excluding the key name itself)
        cell_states_to_model_valid = {
            state: value
            for state, value in cell_states_to_model.items()
            if state != "state_key"
            and cell_states_to_model[state] is not None
            and cell_states_to_model[state] != []
        }
        cell_states_list = []
        # flatten all state values into list
        for state in cell_states_to_model_valid:
            value = cell_states_to_model_valid[state]
            if isinstance(value, list):
                cell_states_list += value
            else:
                cell_states_list.append(value)
        state_dict = {state_value: dict() for state_value in cell_states_list}
    for file in os.listdir(input_data_directory):
        # process only files with given suffix (e.g. "_raw.pickle")
        if file.endswith(pickle_suffix):
            file_found = True
            file_path_list += [f"{input_data_directory}/{file}"]
    for file_path in tqdm(file_path_list):
        with open(file_path, "rb") as fp:
            cos_sims_dict = pickle.load(fp)
        if cell_states_to_model is None:
            dict_list += read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token)
        else:
            # merge each state's entries across files, concatenating shift lists
            for state_value in cell_states_list:
                new_dict = read_dict(
                    cos_sims_dict[state_value], cell_or_gene_emb, anchor_token
                )[0]
                for key in new_dict:
                    try:
                        state_dict[state_value][key] += new_dict[key]
                    except KeyError:
                        state_dict[state_value][key] = new_dict[key]
    if not file_found:
        logger.error(
            "No raw data for processing found within provided directory. "
            # fix: original lacked the f-prefix, logging the literal
            # "{pickle_suffix}" instead of the actual suffix
            f"Please ensure data files end with '{pickle_suffix}'."
        )
        raise
    if cell_states_to_model is None:
        return dict_list
    else:
        return state_dict
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# get complete gene list
def get_gene_list(dict_list, mode):
    """Collect the sorted set of gene tokens present in the raw data.

    **Parameters:**

    dict_list : list, dict
        | Either a list of cosine-shift dicts, or (when modeling shifts to goal
        | states) a dict mapping state value -> cosine-shift dict.
    mode : {"cell", "gene"}
        | "cell": genes are the first element of each key tuple (perturbed gene).
        | "gene": genes are the second element (affected gene).

    **Returns:**

    Sorted list of gene tokens with non-empty data; in "gene" mode the
    "cell_emb" pseudo-token is excluded.
    """
    if mode == "cell":
        position = 0
    elif mode == "gene":
        position = 1
    gene_set = set()
    if isinstance(dict_list, list):
        for dict_i in dict_list:
            gene_set.update([k[position] for k, v in dict_i.items() if v])
    elif isinstance(dict_list, dict):
        for state, dict_i in dict_list.items():
            gene_set.update([k[position] for k, v in dict_i.items() if v])
    else:
        logger.error(
            "dict_list should be a list, or if modeling shift to goal states, a dict. "
            f"{type(dict_list)} is not the correct format."
        )
        raise
    if mode == "gene":
        # fix: use discard so that data without any cell-emb entries
        # (e.g. anchor-filtered gene dicts) no longer raises ValueError
        gene_set.discard("cell_emb")
    gene_list = sorted(gene_set)
    return gene_list
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
    """Map a tuple of gene tokens to a tuple of Ensembl IDs.

    Unknown tokens map to np.nan. If *token_tuple* is a single non-iterable
    token rather than a tuple, the scalar lookup result is returned instead.
    """
    try:
        ids = [gene_token_id_dict.get(token, np.nan) for token in token_tuple]
    except TypeError:
        # token_tuple was a single token, not an iterable of tokens
        return gene_token_id_dict.get(token_tuple, np.nan)
    return tuple(ids)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def n_detections(token, dict_list, mode, anchor_token):
    """Count how many cosine-shift observations exist for *token* across dicts.

    "cell" mode counts entries keyed (token, "cell_emb"); "gene" mode counts
    entries keyed (anchor_token, token). Any other mode yields 0.
    """
    count = 0
    for dict_i in dict_list:
        if mode == "cell":
            count += len(dict_i.get((token, "cell_emb"), []))
        elif mode == "gene":
            count += len(dict_i.get((anchor_token, token), []))
    return count
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def get_fdr(pvalues):
    """Benjamini-Hochberg FDR-adjusted p-values for the given sequence."""
    _, corrected, _, _ = smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")
    return list(corrected)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def get_impact_component(test_value, gaussian_mixture_model):
    """Label *test_value* as impact (1) vs non-impact (0) under a 2-component GMM.

    NOTE(review): assumes component 0 of the fitted mixture is the lower-mean
    ("impact") component and component 1 the higher-mean one — sklearn does not
    guarantee component ordering; confirm against the fitting code. Values
    beyond either component mean are labeled directly; values between the two
    means are resolved by the model's own prediction with the label flipped so
    that 1 always marks the impact component.
    """
    lower_mean = gaussian_mixture_model.means_[0][0]
    upper_mean = gaussian_mixture_model.means_[1][0]
    if test_value > upper_mean:
        return 0
    if test_value < lower_mean:
        return 1
    # ambiguous region between the two component means: defer to the model
    predicted = gaussian_mixture_model.predict([[test_value]])[0]
    if predicted == 1:
        impact_component = 0
    elif predicted == 0:
        impact_component = 1
    return impact_component
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# aggregate data for single perturbation in multiple cells
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
    """Collect all cosine shifts for the single perturbed gene (row 0 of
    *cos_sims_df*) across every raw dict, returned as a one-column DataFrame
    named "Cosine_shift"."""
    shifts = []
    perturbed_token = cos_sims_df["Gene"][0]
    for cos_dict in dict_list:
        shifts.extend(cos_dict.get((perturbed_token, "cell_emb"), []))
    return pd.DataFrame({"Cosine_shift": shifts})
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def find(variable, x):
    """Return True if *x* is contained in *variable* (when iterable) or equals it.

    Fix: the original returned implicit ``None`` when the membership test was
    False; the result is now always an explicit bool (backward compatible,
    since callers only test truthiness).
    """
    try:
        return x in variable  # membership test for iterable variables
    except (ValueError, TypeError):
        return x == variable  # non-iterable: fall back to direct equality
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def isp_aggregate_gene_shifts(
    cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict
):
    """Aggregate cosine shifts of affected genes in response to perturbation(s).

    For every perturbed gene in *cos_sims_df*, gathers all raw-dict entries whose
    perturbation key matches that gene, then reports mean/stdev/count of the
    cosine shifts per (perturbed, affected) pair, with gene names and Ensembl IDs
    resolved via the provided dictionaries. Rows for the whole-cell embedding
    ("cell_emb") are sorted to the top, then by descending mean shift.
    """
    # (perturbed-key, affected-target) -> concatenated list of cosine shifts
    cos_shift_data = dict()
    for i in trange(cos_sims_df.shape[0]):
        token = cos_sims_df["Gene"][i]
        for dict_i in dict_list:
            # match single-token and tuple-of-token perturbation keys alike
            affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
            for key in affected_pairs:
                if key in cos_shift_data.keys():
                    cos_shift_data[key] += dict_i.get(key, [])
                else:
                    cos_shift_data[key] = dict_i.get(key, [])

    # summarize each pair as [mean, stdev, n_observations]
    cos_data_mean = {
        k: [np.mean(v), np.std(v), len(v)] for k, v in cos_shift_data.items()
    }
    cos_sims_full_df = pd.DataFrame()
    cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
    # NOTE(review): the trailing [0] is a label lookup, which assumes the
    # matching row of cos_sims_df carries index label 0 — confirm upstream
    # always supplies a freshly-indexed frame
    cos_sims_full_df["Gene_name"] = [
        cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"][0]
        for k, v in cos_data_mean.items()
    ]
    cos_sims_full_df["Ensembl_ID"] = [
        cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"][0]
        for k, v in cos_data_mean.items()
    ]

    cos_sims_full_df["Affected"] = [k[1] for k, v in cos_data_mean.items()]
    # resolve affected tokens to names/IDs; unknown tokens (incl. "cell_emb")
    # fall through to np.nan
    cos_sims_full_df["Affected_gene_name"] = [
        gene_id_name_dict.get(gene_token_id_dict.get(token, np.nan), np.nan)
        for token in cos_sims_full_df["Affected"]
    ]
    cos_sims_full_df["Affected_Ensembl_ID"] = [
        gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
    ]
    cos_sims_full_df["Cosine_shift_mean"] = [v[0] for k, v in cos_data_mean.items()]
    cos_sims_full_df["Cosine_shift_stdev"] = [v[1] for k, v in cos_data_mean.items()]
    cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]

    specific_val = "cell_emb"
    cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
    # reorder so cell embs are at the top and all are subordered by magnitude of cosine shift
    cos_sims_full_df = cos_sims_full_df.sort_values(
        by=(["temp", "Cosine_shift_mean"]), ascending=[False, False]
    ).drop("temp", axis=1)

    return cos_sims_full_df
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
def isp_stats_to_goal_state(
    cos_sims_df, result_dict, cell_states_to_model, genes_perturbed
):
    """Statistics for perturbations driving cells toward a goal state.

    If *genes_perturbed* is not "all", a single grouped perturbation was run:
    only the mean shift toward the goal (and any alternate) end states is
    reported. If "all", each gene's shifts are compared against the pooled
    ("random") distribution over all genes with Wilcoxon rank-sum tests and
    Benjamini-Hochberg FDR correction.
    """
    # determine whether any alternate end states were modeled
    if (
        ("alt_states" not in cell_states_to_model.keys())
        or (len(cell_states_to_model["alt_states"]) == 0)
        or (cell_states_to_model["alt_states"] == [None])
    ):
        alt_end_state_exists = False
    elif (len(cell_states_to_model["alt_states"]) > 0) and (
        cell_states_to_model["alt_states"] != [None]
    ):
        alt_end_state_exists = True

    # for single perturbation in multiple cells, there are no random perturbations to compare to
    if genes_perturbed != "all":
        cos_sims_full_df = pd.DataFrame()

        cos_shift_data_end = []
        token = cos_sims_df["Gene"][0]
        cos_shift_data_end += result_dict[cell_states_to_model["goal_state"]].get(
            (token, "cell_emb"), []
        )
        cos_sims_full_df["Shift_to_goal_end"] = [np.mean(cos_shift_data_end)]
        if alt_end_state_exists is True:
            for alt_state in cell_states_to_model["alt_states"]:
                cos_shift_data_alt_state = []
                cos_shift_data_alt_state += result_dict.get(alt_state).get(
                    (token, "cell_emb"), []
                )
                cos_sims_full_df[f"Shift_to_alt_end_{alt_state}"] = [
                    np.mean(cos_shift_data_alt_state)
                ]

        # sort by shift to desired state
        cos_sims_full_df = cos_sims_full_df.sort_values(
            by=["Shift_to_goal_end"], ascending=[False]
        )
        return cos_sims_full_df

    elif genes_perturbed == "all":
        # pool every gene's goal-state shifts as the "random" background
        goal_end_random_megalist = []
        if alt_end_state_exists is True:
            alt_end_state_random_dict = {
                alt_state: [] for alt_state in cell_states_to_model["alt_states"]
            }
        for i in trange(cos_sims_df.shape[0]):
            token = cos_sims_df["Gene"][i]
            goal_end_random_megalist += result_dict[
                cell_states_to_model["goal_state"]
            ].get((token, "cell_emb"), [])
            if alt_end_state_exists is True:
                for alt_state in cell_states_to_model["alt_states"]:
                    alt_end_state_random_dict[alt_state] += result_dict[alt_state].get(
                        (token, "cell_emb"), []
                    )

        # downsample to improve speed of ranksums
        if len(goal_end_random_megalist) > 100_000:
            random.seed(42)
            goal_end_random_megalist = random.sample(
                goal_end_random_megalist, k=100_000
            )
        if alt_end_state_exists is True:
            for alt_state in cell_states_to_model["alt_states"]:
                if len(alt_end_state_random_dict[alt_state]) > 100_000:
                    random.seed(42)
                    alt_end_state_random_dict[alt_state] = random.sample(
                        alt_end_state_random_dict[alt_state], k=100_000
                    )

        names = [
            "Gene",
            "Gene_name",
            "Ensembl_ID",
            "Shift_to_goal_end",
            "Goal_end_vs_random_pval",
        ]
        if alt_end_state_exists is True:
            # insert per-alt-state shift columns before the p-value columns
            [
                names.append(f"Shift_to_alt_end_{alt_state}")
                for alt_state in cell_states_to_model["alt_states"]
            ]
            # move the goal p-value column after the alt-state shift columns
            names.append(names.pop(names.index("Goal_end_vs_random_pval")))
            [
                names.append(f"Alt_end_vs_random_pval_{alt_state}")
                for alt_state in cell_states_to_model["alt_states"]
            ]
        cos_sims_full_df = pd.DataFrame(columns=names)

        n_detections_dict = dict()
        for i in trange(cos_sims_df.shape[0]):
            token = cos_sims_df["Gene"][i]
            name = cos_sims_df["Gene_name"][i]
            ensembl_id = cos_sims_df["Ensembl_ID"][i]
            goal_end_cos_sim_megalist = result_dict[
                cell_states_to_model["goal_state"]
            ].get((token, "cell_emb"), [])
            n_detections_dict[token] = len(goal_end_cos_sim_megalist)
            mean_goal_end = np.mean(goal_end_cos_sim_megalist)
            # rank-sum test of this gene's shifts vs the pooled background
            pval_goal_end = ranksums(
                goal_end_random_megalist, goal_end_cos_sim_megalist
            ).pvalue

            if alt_end_state_exists is True:
                alt_end_state_dict = {
                    alt_state: [] for alt_state in cell_states_to_model["alt_states"]
                }
                for alt_state in cell_states_to_model["alt_states"]:
                    alt_end_state_dict[alt_state] = result_dict[alt_state].get(
                        (token, "cell_emb"), []
                    )
                    alt_end_state_dict[f"{alt_state}_mean"] = np.mean(
                        alt_end_state_dict[alt_state]
                    )
                    alt_end_state_dict[f"{alt_state}_pval"] = ranksums(
                        alt_end_state_random_dict[alt_state],
                        alt_end_state_dict[alt_state],
                    ).pvalue

            results_dict = dict()
            results_dict["Gene"] = token
            results_dict["Gene_name"] = name
            results_dict["Ensembl_ID"] = ensembl_id
            results_dict["Shift_to_goal_end"] = mean_goal_end
            results_dict["Goal_end_vs_random_pval"] = pval_goal_end
            if alt_end_state_exists is True:
                for alt_state in cell_states_to_model["alt_states"]:
                    results_dict[f"Shift_to_alt_end_{alt_state}"] = alt_end_state_dict[
                        f"{alt_state}_mean"
                    ]
                    results_dict[
                        f"Alt_end_vs_random_pval_{alt_state}"
                    ] = alt_end_state_dict[f"{alt_state}_pval"]

            cos_sims_df_i = pd.DataFrame(results_dict, index=[i])
            cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])

        # Benjamini-Hochberg correction across all tested genes
        cos_sims_full_df["Goal_end_FDR"] = get_fdr(
            list(cos_sims_full_df["Goal_end_vs_random_pval"])
        )
        if alt_end_state_exists is True:
            for alt_state in cell_states_to_model["alt_states"]:
                cos_sims_full_df[f"Alt_end_FDR_{alt_state}"] = get_fdr(
                    list(cos_sims_full_df[f"Alt_end_vs_random_pval_{alt_state}"])
                )

        # quantify number of detections of each gene
        cos_sims_full_df["N_Detections"] = [
            n_detections_dict[token] for token in cos_sims_full_df["Gene"]
        ]

        # sort by shift to desired state
        cos_sims_full_df["Sig"] = [
            1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]
        ]
        cos_sims_full_df = cos_sims_full_df.sort_values(
            by=["Sig", "Shift_to_goal_end", "Goal_end_FDR"],
            ascending=[False, False, True],
        )

        return cos_sims_full_df
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# stats comparing cos sim shifts of test perturbations vs null distribution
def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
    """Compare each gene's cosine shifts against a null-distribution dataset.

    **Parameters:**

    cos_sims_df : pd.DataFrame
        | One row per gene with "Gene", "Gene_name", "Ensembl_ID" columns.
    dict_list : list
        | Raw cosine-shift dicts from the test dataset.
    null_dict_list : list
        | Raw cosine-shift dicts from the null dataset.

    **Returns:**

    Input frame extended with per-gene test/null mean shifts, rank-sum
    p-values, BH-FDR values, detection counts, and a 0/1 "Sig" flag,
    sorted by significance and shift magnitude.
    """
    cos_sims_full_df = cos_sims_df.copy()

    cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
    cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
    cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(
        cos_sims_df.shape[0], dtype=float
    )
    cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
    cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
    cos_sims_full_df["N_Detections_test"] = np.zeros(
        cos_sims_df.shape[0], dtype="uint32"
    )
    cos_sims_full_df["N_Detections_null"] = np.zeros(
        cos_sims_df.shape[0], dtype="uint32"
    )

    for i in trange(cos_sims_df.shape[0]):
        token = cos_sims_df["Gene"][i]
        test_shifts = []
        null_shifts = []

        for dict_i in dict_list:
            test_shifts += dict_i.get((token, "cell_emb"), [])

        for dict_i in null_dict_list:
            null_shifts += dict_i.get((token, "cell_emb"), [])

        cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
        cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
        cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(
            test_shifts
        ) - np.mean(null_shifts)
        cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(
            test_shifts, null_shifts, nan_policy="omit"
        ).pvalue
        cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
        cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)

    # replace undefined p-values (e.g. from empty shift lists) with 1.
    # Fix: the original re-ran this whole-column np.where inside the per-row
    # loop (O(n^2)); hoisting it here after the loop is result-identical.
    cos_sims_full_df.Test_vs_null_pval = np.where(
        np.isnan(cos_sims_full_df.Test_vs_null_pval),
        1,
        cos_sims_full_df.Test_vs_null_pval,
    )

    cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(
        cos_sims_full_df["Test_vs_null_pval"]
    )

    cos_sims_full_df["Sig"] = [
        1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]
    ]
    cos_sims_full_df = cos_sims_full_df.sort_values(
        by=["Sig", "Test_vs_null_avg_shift", "Test_vs_null_FDR"],
        ascending=[False, False, True],
    )
    return cos_sims_full_df
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
# stats for identifying perturbations with largest effect within a given set of cells
# fits a mixture model to 2 components (impact vs. non-impact) and
# reports the most likely component for each test perturbation
# Note: because assumes given perturbation has a consistent effect in the cells tested,
# we recommend only using the mixture model strategy with uniform cell populations
def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
    """Classify perturbations into impact vs. non-impact via a 2-component GMM.

    First pass collects each gene's mean cosine shift and fits a Gaussian
    mixture to those means; second pass labels each gene (and each individual
    observation) with its impact component. combos == 0 uses single-gene
    shifts; combos == 1 expects (anchor, token, combo) shift triples and also
    reports combo-vs-sum decomposition columns.
    """
    names = ["Gene", "Gene_name", "Ensembl_ID"]

    if combos == 0:
        names += ["Test_avg_shift"]
    elif combos == 1:
        names += [
            "Anchor_shift",
            "Test_token_shift",
            "Sum_of_indiv_shifts",
            "Combo_shift",
            "Combo_minus_sum_shift",
        ]

    names += ["Impact_component", "Impact_component_percent"]

    cos_sims_full_df = pd.DataFrame(columns=names)
    avg_values = []
    gene_names = []

    # first pass: mean shift per gene, used to fit the mixture model
    for i in trange(cos_sims_df.shape[0]):
        token = cos_sims_df["Gene"][i]
        name = cos_sims_df["Gene_name"][i]
        ensembl_id = cos_sims_df["Ensembl_ID"][i]
        cos_shift_data = []

        for dict_i in dict_list:
            if (combos == 0) and (anchor_token is not None):
                # effect of anchor perturbation on this gene's embedding
                cos_shift_data += dict_i.get((anchor_token, token), [])
            else:
                cos_shift_data += dict_i.get((token, "cell_emb"), [])

        # Extract values for current gene
        if combos == 0:
            test_values = cos_shift_data
        elif combos == 1:
            # third element of each triple is the combo shift
            test_values = []
            for tup in cos_shift_data:
                test_values.append(tup[2])

        if len(test_values) > 0:
            avg_value = np.mean(test_values)
            avg_values.append(avg_value)
            gene_names.append(name)

    # fit Gaussian mixture model to dataset of mean for each gene
    avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
    gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)

    # second pass: label each gene (and its observations) with impact component
    for i in trange(cos_sims_df.shape[0]):
        token = cos_sims_df["Gene"][i]
        name = cos_sims_df["Gene_name"][i]
        ensembl_id = cos_sims_df["Ensembl_ID"][i]
        cos_shift_data = []

        for dict_i in dict_list:
            if (combos == 0) and (anchor_token is not None):
                cos_shift_data += dict_i.get((anchor_token, token), [])
            else:
                cos_shift_data += dict_i.get((token, "cell_emb"), [])

        if combos == 0:
            mean_test = np.mean(cos_shift_data)
            impact_components = [
                get_impact_component(value, gm) for value in cos_shift_data
            ]
        elif combos == 1:
            # decompose each (anchor, token, combo) triple into its parts
            anchor_cos_sim_megalist = [
                anchor for anchor, token, combo in cos_shift_data
            ]
            token_cos_sim_megalist = [token for anchor, token, combo in cos_shift_data]
            # additive expectation: 1 - sum of the two individual shifts
            anchor_plus_token_cos_sim_megalist = [
                1 - ((1 - anchor) + (1 - token))
                for anchor, token, combo in cos_shift_data
            ]
            combo_anchor_token_cos_sim_megalist = [
                combo for anchor, token, combo in cos_shift_data
            ]
            # deviation of the observed combo from the additive expectation
            combo_minus_sum_cos_sim_megalist = [
                combo - (1 - ((1 - anchor) + (1 - token)))
                for anchor, token, combo in cos_shift_data
            ]

            mean_anchor = np.mean(anchor_cos_sim_megalist)
            mean_token = np.mean(token_cos_sim_megalist)
            mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
            mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
            mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)

            impact_components = [
                get_impact_component(value, gm)
                for value in combo_anchor_token_cos_sim_megalist
            ]

        impact_component = get_impact_component(mean_test, gm)
        impact_component_percent = np.mean(impact_components) * 100

        data_i = [token, name, ensembl_id]
        if combos == 0:
            data_i += [mean_test]
        elif combos == 1:
            data_i += [
                mean_anchor,
                mean_token,
                mean_sum,
                mean_test,
                mean_combo_minus_sum,
            ]
        data_i += [impact_component, impact_component_percent]

        cos_sims_df_i = pd.DataFrame(dict(zip(names, data_i)), index=[i])
        cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])

    # quantify number of detections of each gene
    cos_sims_full_df["N_Detections"] = [
        n_detections(i, dict_list, "gene", anchor_token)
        for i in cos_sims_full_df["Gene"]
    ]

    if combos == 0:
        cos_sims_full_df = cos_sims_full_df.sort_values(
            by=["Impact_component", "Test_avg_shift"], ascending=[False, True]
        )
    elif combos == 1:
        cos_sims_full_df = cos_sims_full_df.sort_values(
            by=["Impact_component", "Combo_minus_sum_shift"], ascending=[False, True]
        )
    return cos_sims_full_df
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
class InSilicoPerturberStats:
|
| 626 |
+
valid_option_dict = {
|
| 627 |
+
"mode": {
|
| 628 |
+
"goal_state_shift",
|
| 629 |
+
"vs_null",
|
| 630 |
+
"mixture_model",
|
| 631 |
+
"aggregate_data",
|
| 632 |
+
"aggregate_gene_shifts",
|
| 633 |
+
},
|
| 634 |
+
"genes_perturbed": {"all", list},
|
| 635 |
+
"combos": {0, 1},
|
| 636 |
+
"anchor_gene": {None, str},
|
| 637 |
+
"cell_states_to_model": {None, dict},
|
| 638 |
+
"pickle_suffix": {None, str},
|
| 639 |
+
}
|
| 640 |
+
|
| 641 |
+
def __init__(
|
| 642 |
+
self,
|
| 643 |
+
mode="mixture_model",
|
| 644 |
+
genes_perturbed="all",
|
| 645 |
+
combos=0,
|
| 646 |
+
anchor_gene=None,
|
| 647 |
+
cell_states_to_model=None,
|
| 648 |
+
pickle_suffix="_raw.pickle",
|
| 649 |
+
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
| 650 |
+
gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
|
| 651 |
+
):
|
| 652 |
+
"""
|
| 653 |
+
Initialize in silico perturber stats generator.
|
| 654 |
+
|
| 655 |
+
**Parameters:**
|
| 656 |
+
|
| 657 |
+
mode : {"goal_state_shift", "vs_null", "mixture_model", "aggregate_data", "aggregate_gene_shifts"}
|
| 658 |
+
| Type of stats.
|
| 659 |
+
| "goal_state_shift": perturbation vs. random for desired cell state shift
|
| 660 |
+
| "vs_null": perturbation vs. null from provided null distribution dataset
|
| 661 |
+
| "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
|
| 662 |
+
| "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
|
| 663 |
+
| "aggregate_gene_shifts": aggregates cosine shifts of genes in response to perturbation(s)
|
| 664 |
+
genes_perturbed : "all", list
|
| 665 |
+
| Genes perturbed in isp experiment.
|
| 666 |
+
| Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
|
| 667 |
+
| Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
|
| 668 |
+
combos : {0,1,2}
|
| 669 |
+
| Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
|
| 670 |
+
anchor_gene : None, str
|
| 671 |
+
| ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
|
| 672 |
+
| For example, if combos=1 and anchor_gene="ENSG00000136574":
|
| 673 |
+
| analyzes data for anchor gene perturbed in combination with each other gene.
|
| 674 |
+
| However, if combos=0 and anchor_gene="ENSG00000136574":
|
| 675 |
+
| analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
|
| 676 |
+
cell_states_to_model: None, dict
|
| 677 |
+
| Cell states to model if testing perturbations that achieve goal state change.
|
| 678 |
+
| Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
|
| 679 |
+
| state_key: key specifying name of column in .dataset that defines the start/goal states
|
| 680 |
+
| start_state: value in the state_key column that specifies the start state
|
| 681 |
+
| goal_state: value in the state_key column taht specifies the goal end state
|
| 682 |
+
| alt_states: list of values in the state_key column that specify the alternate end states
|
| 683 |
+
| For example: {"state_key": "disease",
|
| 684 |
+
| "start_state": "dcm",
|
| 685 |
+
| "goal_state": "nf",
|
| 686 |
+
| "alt_states": ["hcm", "other1", "other2"]}
|
| 687 |
+
token_dictionary_file : Path
|
| 688 |
+
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
| 689 |
+
gene_name_id_dictionary_file : Path
|
| 690 |
+
| Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
|
| 691 |
+
"""
|
| 692 |
+
|
| 693 |
+
self.mode = mode
|
| 694 |
+
self.genes_perturbed = genes_perturbed
|
| 695 |
+
self.combos = combos
|
| 696 |
+
self.anchor_gene = anchor_gene
|
| 697 |
+
self.cell_states_to_model = cell_states_to_model
|
| 698 |
+
self.pickle_suffix = pickle_suffix
|
| 699 |
+
|
| 700 |
+
self.validate_options()
|
| 701 |
+
|
| 702 |
+
# load token dictionary (Ensembl IDs:token)
|
| 703 |
+
with open(token_dictionary_file, "rb") as f:
|
| 704 |
+
self.gene_token_dict = pickle.load(f)
|
| 705 |
+
|
| 706 |
+
# load gene name dictionary (gene name:Ensembl ID)
|
| 707 |
+
with open(gene_name_id_dictionary_file, "rb") as f:
|
| 708 |
+
self.gene_name_id_dict = pickle.load(f)
|
| 709 |
+
|
| 710 |
+
if anchor_gene is None:
|
| 711 |
+
self.anchor_token = None
|
| 712 |
+
else:
|
| 713 |
+
self.anchor_token = self.gene_token_dict[self.anchor_gene]
|
| 714 |
+
|
| 715 |
+
def validate_options(self):
    """Validate constructor options and enforce cross-option constraints.

    Checks each attribute against self.valid_option_dict (defined in
    __init__, outside this view), normalizes a legacy single-key
    cell_states_to_model format to the named-key format, and rejects
    incompatible combinations of mode / combos / anchor_gene /
    genes_perturbed.  Errors are logged then re-raised via a bare
    ``raise`` (which, outside an except block, surfaces as RuntimeError).
    """
    # generic whitelist check: scalar values must appear in valid_options,
    # otherwise the value's type must be one of the allowed types
    for attr_name, valid_options in self.valid_option_dict.items():
        attr_value = self.__dict__[attr_name]
        if type(attr_value) not in {list, dict}:
            if attr_name in {"anchor_gene"}:
                continue
            elif attr_value in valid_options:
                continue
        valid_type = False
        for option in valid_options:
            if (option in [str, int, list, dict]) and isinstance(
                attr_value, option
            ):
                valid_type = True
                break
        if not valid_type:
            logger.error(
                f"Invalid option for {attr_name}. "
                f"Valid options for {attr_name}: {valid_options}"
            )
            raise

    if self.cell_states_to_model is not None:
        if len(self.cell_states_to_model.items()) == 1:
            # legacy format: {state_key: ([start], [goal], [alt, ...])};
            # warn, sanity-check, then convert to the named-key format
            logger.warning(
                "The single value dictionary for cell_states_to_model will be "
                "replaced with a dictionary with named keys for start, goal, and alternate states. "
                "Please specify state_key, start_state, goal_state, and alt_states "
                "in the cell_states_to_model dictionary for future use. "
                "For example, cell_states_to_model={"
                "'state_key': 'disease', "
                "'start_state': 'dcm', "
                "'goal_state': 'nf', "
                "'alt_states': ['hcm', 'other1', 'other2']}"
            )
            for key, value in self.cell_states_to_model.items():
                # accept only a 3-tuple of lists with unique, single-element
                # start/goal entries
                if (len(value) == 3) and isinstance(value, tuple):
                    if (
                        isinstance(value[0], list)
                        and isinstance(value[1], list)
                        and isinstance(value[2], list)
                    ):
                        if len(value[0]) == 1 and len(value[1]) == 1:
                            all_values = value[0] + value[1] + value[2]
                            if len(all_values) == len(set(all_values)):
                                continue
            # reformat to the new named key format
            state_values = flatten_list(list(self.cell_states_to_model.values()))
            self.cell_states_to_model = {
                "state_key": list(self.cell_states_to_model.keys())[0],
                "start_state": state_values[0][0],
                "goal_state": state_values[1][0],
                "alt_states": state_values[2:][0],
            }
        elif set(self.cell_states_to_model.keys()) == {
            "state_key",
            "start_state",
            "goal_state",
            "alt_states",
        }:
            # named-key format: state_key/start/goal are mandatory
            if (
                (self.cell_states_to_model["state_key"] is None)
                or (self.cell_states_to_model["start_state"] is None)
                or (self.cell_states_to_model["goal_state"] is None)
            ):
                logger.error(
                    "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
                )
                raise

            if (
                self.cell_states_to_model["start_state"]
                == self.cell_states_to_model["goal_state"]
            ):
                logger.error("All states must be unique.")
                raise

            # alt_states may be None; if given it must be a unique-element list
            if self.cell_states_to_model["alt_states"] is not None:
                if not isinstance(self.cell_states_to_model["alt_states"], list):
                    logger.error(
                        "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
                    )
                    raise
                if len(self.cell_states_to_model["alt_states"]) != len(
                    set(self.cell_states_to_model["alt_states"])
                ):
                    logger.error("All states must be unique.")
                    raise

        else:
            logger.error(
                "cell_states_to_model must only have the following four keys: "
                "'state_key', 'start_state', 'goal_state', 'alt_states'."
                "For example, cell_states_to_model={"
                "'state_key': 'disease', "
                "'start_state': 'dcm', "
                "'goal_state': 'nf', "
                "'alt_states': ['hcm', 'other1', 'other2']}"
            )
            raise

        # anchor gene is not supported together with cell-state modeling
        if self.anchor_gene is not None:
            self.anchor_gene = None
            logger.warning(
                "anchor_gene set to None. "
                "Currently, anchor gene not available "
                "when modeling multiple cell states."
            )

    # combination perturbation stats require an anchor gene
    if self.combos > 0:
        if self.anchor_gene is None:
            logger.error(
                "Currently, stats are only supported for combination "
                "in silico perturbation run with anchor gene. Please add "
                "anchor gene when using with combos > 0. "
            )
            raise

    if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
        logger.error(
            "Mixture model mode requires multiple gene perturbations to fit model "
            "so is incompatible with a single grouped perturbation."
        )
        raise
    if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
        logger.error(
            "Simple data aggregation mode is for single perturbation in multiple cells "
            "so is incompatible with a genes_perturbed being 'all'."
        )
        raise
|
| 845 |
+
|
| 846 |
+
def get_stats(
    self,
    input_data_directory,
    null_dist_data_directory,
    output_directory,
    output_prefix,
    null_dict_list=None,
):
    """
    Get stats for in silico perturbation data and save as results in output_directory.

    **Parameters:**

    input_data_directory : Path
        | Path to directory containing cos_sim dictionary inputs
    null_dist_data_directory : Path
        | Path to directory containing null distribution cos_sim dictionary inputs
    output_directory : Path
        | Path to directory where perturbation data will be saved as .csv
    output_prefix : str
        | Prefix for output .csv
    null_dict_list: list[dict]
        | List of loaded null distribution dictionary if more than one comparison vs. the null is to be performed

    **Outputs:**

    Definition of possible columns in .csv output file.

    | Of note, not all columns will be present in all output files.
    | Some columns are specific to particular perturbation modes.

    | "Gene": gene token
    | "Gene_name": gene name
    | "Ensembl_ID": gene Ensembl ID
    | "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
    | "Sig": 1 if FDR<0.05, otherwise 0

    | "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
    | "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
    | "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
    |     pvalue compares shift caused by perturbing given gene compared to random genes
    | "Alt_end_vs_random_pval": pvalue of cosine shift from start state towards alternate end state by Wilcoxon
    |     pvalue compares shift caused by perturbing given gene compared to random genes
    | "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
    | "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"

    | "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
    | "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
    | "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
    |     (i.e. "Test_avg_shift" minus "Null_avg_shift")
    | "Test_vs_null_pval": pvalue of cosine shift in test vs. null distribution
    | "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
    | "N_Detections_test": "N_Detections" in cells from test distribution
    | "N_Detections_null": "N_Detections" in cells from null distribution

    | "Anchor_shift": cosine shift in response to given perturbation of anchor gene
    | "Test_token_shift": cosine shift in response to given perturbation of test gene
    | "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
    | "Combo_shift": cosine shift in response to given perturbation of both anchor and test gene(s) in combination
    | "Combo_minus_sum_shift": difference of cosine shifts in response combo perturbation vs. sum of individual perturbations
    |     (i.e. "Combo_shift" minus "Sum_of_indiv_shifts")
    | "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
    |     1: within impact component; 0: not within impact component
    | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component

    | In case of aggregating gene shifts:
    | "Perturbed": ID(s) of gene(s) being perturbed
    | "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
    | "Cosine_shift_mean": mean of cosine shift of modeled perturbation on affected gene or cell
    | "Cosine_shift_stdev": standard deviation of cosine shift of modeled perturbation on affected gene or cell
    """

    if self.mode not in [
        "goal_state_shift",
        "vs_null",
        "mixture_model",
        "aggregate_data",
        "aggregate_gene_shifts",
    ]:
        logger.error(
            "Currently, only modes available are stats for goal_state_shift, "
            "vs_null (comparing to null distribution), "
            "mixture_model (fitting mixture model for perturbations with or without impact), "
            "and aggregating data for single perturbations or for gene embedding shifts."
        )
        raise

    # build reverse lookups: token -> Ensembl ID, Ensembl ID -> gene name
    self.gene_token_id_dict = invert_dict(self.gene_token_dict)
    self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)

    # obtain total gene list
    if (self.combos == 0) and (self.anchor_token is not None):
        # cos sim data for effect of gene perturbation on the embedding of each other gene
        dict_list = read_dictionaries(
            input_data_directory,
            "gene",
            self.anchor_token,
            self.cell_states_to_model,
            self.pickle_suffix,
        )
        gene_list = get_gene_list(dict_list, "gene")
    elif (
        (self.combos == 0)
        and (self.anchor_token is None)
        and (self.mode == "aggregate_gene_shifts")
    ):
        # gene-embedding-shift aggregation: gene-mode dictionaries but
        # cell-keyed gene list
        dict_list = read_dictionaries(
            input_data_directory,
            "gene",
            self.anchor_token,
            self.cell_states_to_model,
            self.pickle_suffix,
        )
        gene_list = get_gene_list(dict_list, "cell")
    else:
        # cos sim data for effect of gene perturbation on the embedding of each cell
        dict_list = read_dictionaries(
            input_data_directory,
            "cell",
            self.anchor_token,
            self.cell_states_to_model,
            self.pickle_suffix,
        )
        gene_list = get_gene_list(dict_list, "cell")

    # initiate results dataframe
    cos_sims_df_initial = pd.DataFrame(
        {
            "Gene": gene_list,
            "Gene_name": [self.token_to_gene_name(item) for item in gene_list],
            "Ensembl_ID": [
                token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict)
                if self.genes_perturbed != "all"
                else self.gene_token_id_dict[genes[1]]
                if isinstance(genes, tuple)
                else self.gene_token_id_dict[genes]
                for genes in gene_list
            ],
        },
        index=[i for i in range(len(gene_list))],
    )

    # dispatch to the mode-specific statistics routine
    if self.mode == "goal_state_shift":
        cos_sims_df = isp_stats_to_goal_state(
            cos_sims_df_initial,
            dict_list,
            self.cell_states_to_model,
            self.genes_perturbed,
        )

    elif self.mode == "vs_null":
        if null_dict_list is None:
            null_dict_list = read_dictionaries(
                null_dist_data_directory,
                "cell",
                self.anchor_token,
                self.cell_states_to_model,
                self.pickle_suffix,
            )
        cos_sims_df = isp_stats_vs_null(
            cos_sims_df_initial, dict_list, null_dict_list
        )

    elif self.mode == "mixture_model":
        cos_sims_df = isp_stats_mixture_model(
            cos_sims_df_initial, dict_list, self.combos, self.anchor_token
        )

    elif self.mode == "aggregate_data":
        cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)

    elif self.mode == "aggregate_gene_shifts":
        cos_sims_df = isp_aggregate_gene_shifts(
            cos_sims_df_initial,
            dict_list,
            self.gene_token_id_dict,
            self.gene_id_name_dict,
        )

    # save perturbation stats to output_path
    output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
    cos_sims_df.to_csv(output_path)
|
| 1028 |
+
|
| 1029 |
+
def token_to_gene_name(self, item):
    """Map a token (or tuple of tokens) to gene name(s) via Ensembl ID.

    Unknown tokens/IDs map to np.nan; any other input type implicitly
    returns None.
    """
    if np.issubdtype(type(item), np.integer):
        ensembl_id = self.gene_token_id_dict.get(item, np.nan)
        return self.gene_id_name_dict.get(ensembl_id, np.nan)
    if isinstance(item, tuple):
        names = []
        for token in item:
            ensembl_id = self.gene_token_id_dict.get(token, np.nan)
            names.append(self.gene_id_name_dict.get(ensembl_id, np.nan))
        return tuple(names)
|
geneformer/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a5e33a757431643b3697de7ef6127950cdc49e06e58d4266b3a3ab191b683f14
|
| 3 |
+
size 41183536
|
geneformer/perturber_utils.py
ADDED
|
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools as it
|
| 2 |
+
import logging
|
| 3 |
+
import pickle
|
| 4 |
+
import re
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import seaborn as sns
|
| 11 |
+
import torch
|
| 12 |
+
from datasets import Dataset, load_from_disk
|
| 13 |
+
from transformers import (
|
| 14 |
+
BertForMaskedLM,
|
| 15 |
+
BertForSequenceClassification,
|
| 16 |
+
BertForTokenClassification,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
sns.set()
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# load data and filter by defined criteria
|
| 25 |
+
def load_and_filter(filter_data, nproc, input_data_file):
    """Load a dataset from disk, optionally filtering by filter_data criteria."""
    dataset = load_from_disk(input_data_file)
    if filter_data is None:
        return dataset
    return filter_by_dict(dataset, filter_data, nproc)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def filter_by_dict(data, filter_data, nproc):
    """Keep only cells whose value for every key in filter_data is among that
    key's allowed values; error out if nothing survives."""
    for key, allowed_values in filter_data.items():

        def matches_criteria(example):
            return example[key] in allowed_values

        data = data.filter(matches_criteria, num_proc=nproc)
    if len(data) == 0:
        logger.error("No cells remain after filtering. Check filtering criteria.")
        raise
    return data
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def filter_data_by_tokens(filtered_input_data, tokens, nproc):
    """Keep only cells whose input_ids contain every token in *tokens*."""

    def if_has_tokens(example):
        present = set(example["input_ids"]).intersection(tokens)
        return len(present) == len(tokens)

    return filtered_input_data.filter(if_has_tokens, num_proc=nproc)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def logging_filtered_data_len(filtered_input_data, filtered_tokens_categ):
    """Log how many cells remain after filtering; error if none do."""
    if len(filtered_input_data) == 0:
        logger.error(f"No cells in dataset contain {filtered_tokens_categ}.")
        raise
    logger.info(f"# cells with {filtered_tokens_categ}: {len(filtered_input_data)}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def filter_data_by_tokens_and_log(
    filtered_input_data, tokens, nproc, filtered_tokens_categ
):
    """Filter cells to those containing *tokens*, then log the resulting count."""
    # filter for cells with anchor gene
    remaining = filter_data_by_tokens(filtered_input_data, tokens, nproc)
    # logging length of filtered data
    logging_filtered_data_len(remaining, filtered_tokens_categ)
    return remaining
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def filter_data_by_start_state(filtered_input_data, cell_states_to_model, nproc):
    """Keep only cells whose state_key column equals the modeled start state.

    Errors (logged, then bare ``raise`` -> RuntimeError) if the start state
    never occurs in the dataset, to prevent futile filtering.
    """
    # confirm that start state is valid to prevent futile filtering
    state_key = cell_states_to_model["state_key"]
    state_values = filtered_input_data[state_key]
    start_state = cell_states_to_model["start_state"]
    if start_state not in state_values:
        logger.error(
            f"Start state {start_state} is not present "
            f"in the dataset's {state_key} attribute."
        )
        raise

    # filter for start state cells
    def filter_for_origin(example):
        return example[state_key] in [start_state]

    filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=nproc)
    return filtered_input_data
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
    """Select the half-open cell index range [start, end) from the dataset.

    NOTE(review): mutates cell_inds_to_perturb["end"] in place when clamping
    it to the dataset length — a caller reusing the dict will see the change.
    """
    if cell_inds_to_perturb["start"] >= len(filtered_input_data):
        logger.error(
            "cell_inds_to_perturb['start'] is larger than the filtered dataset."
        )
        raise
    if cell_inds_to_perturb["end"] > len(filtered_input_data):
        logger.warning(
            "cell_inds_to_perturb['end'] is larger than the filtered dataset. \
            Setting to the end of the filtered dataset."
        )
        cell_inds_to_perturb["end"] = len(filtered_input_data)
    filtered_input_data = filtered_input_data.select(
        [i for i in range(cell_inds_to_perturb["start"], cell_inds_to_perturb["end"])]
    )
    return filtered_input_data
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# load model to GPU
|
| 111 |
+
def load_model(model_type, num_classes, model_directory):
    """Load a Geneformer BERT model onto cuda:0 in eval mode.

    model_type: "Pretrained" (masked LM), "GeneClassifier" (token
    classification head), or "CellClassifier" (sequence classification head).
    num_classes: number of labels for the classifier heads (unused for
    "Pretrained").
    Raises ValueError for an unrecognized model_type (previously this fell
    through to an opaque NameError on model.eval()).
    """
    if model_type == "Pretrained":
        model = BertForMaskedLM.from_pretrained(
            model_directory, output_hidden_states=True, output_attentions=False
        )
    elif model_type == "GeneClassifier":
        model = BertForTokenClassification.from_pretrained(
            model_directory,
            num_labels=num_classes,
            output_hidden_states=True,
            output_attentions=False,
        )
    elif model_type == "CellClassifier":
        model = BertForSequenceClassification.from_pretrained(
            model_directory,
            num_labels=num_classes,
            output_hidden_states=True,
            output_attentions=False,
        )
    else:
        raise ValueError(
            f"Unknown model_type: {model_type}. "
            "Valid options: 'Pretrained', 'GeneClassifier', 'CellClassifier'."
        )
    # put the model in eval mode for fwd pass
    model.eval()
    model = model.to("cuda:0")
    return model
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def quant_layers(model):
    """Return the number of hidden layers in *model*, inferred from the
    highest ``layer.<N>.`` index found among its parameter names."""
    layer_nums = [
        int(name.split("layer.")[1].split(".")[0])
        for name, _param in model.named_parameters()
        if "layer" in name
    ]
    return max(layer_nums) + 1
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_model_input_size(model):
    """Return the model's maximum input length (position embedding size).

    Parses the repr of the position-embedding layer, e.g.
    "Embedding(2048, 256)" -> 2048.
    """
    # raw string: "\(" in a plain string is an invalid escape sequence
    # (SyntaxWarning today, an error in future Python versions)
    return int(re.split(r"\(|,", str(model.bert.embeddings.position_embeddings))[1])
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def flatten_list(megalist):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    flat = []
    for sublist in megalist:
        flat.extend(sublist)
    return flat
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def measure_length(example):
    """Record the cell's rank-value-encoding length under the "length" key."""
    input_ids = example["input_ids"]
    example["length"] = len(input_ids)
    return example
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def downsample_and_sort(data, max_ncells):
    """Shuffle-subsample *data* to at most max_ncells cells (seed 42), then
    sort longest cell first so any memory error surfaces early."""
    n_keep = len(data)
    # if max number of cells is defined, then shuffle and subsample to this max number
    if max_ncells is not None and n_keep > max_ncells:
        data = data.shuffle(seed=42)
        n_keep = max_ncells
    subset = data.select([i for i in range(n_keep)])
    return subset.sort("length", reverse=True)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def get_possible_states(cell_states_to_model):
    """Return all modeled states: start state, goal state, then alternates.

    Handles "alt_states" being absent OR explicitly None (both mean "no
    alternate states") — validate_options permits alt_states=None, and the
    previous `+= .get("alt_states", [])` raised TypeError in that case.
    """
    possible_states = []
    for key in ["start_state", "goal_state"]:
        possible_states += [cell_states_to_model[key]]
    possible_states += cell_states_to_model.get("alt_states") or []
    return possible_states
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def forward_pass_single_cell(model, example_cell, layer_to_quant):
    """Run one cell through the model and return the hidden-state embedding
    from layer ``layer_to_quant`` (squeezed to drop the batch dimension).
    Requires a CUDA device; inference only (no_grad).
    """
    example_cell.set_format(type="torch")
    input_data = example_cell["input_ids"]
    with torch.no_grad():
        outputs = model(input_ids=input_data.to("cuda"))
    emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
    # drop the full model output promptly to free GPU memory
    del outputs
    return emb
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def perturb_emb_by_index(emb, indices):
    """Return *emb* with the elements at *indices* removed (boolean mask)."""
    keep = torch.ones(emb.numel(), dtype=torch.bool)
    keep[indices] = False
    return emb[keep]
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def delete_indices(example):
    """Remove the tokens at example["perturb_index"] from the rank value
    encoding and refresh "length". Accepts flat or nested index lists."""
    positions = example["perturb_index"]
    if any(isinstance(entry, list) for entry in positions):
        positions = [pos for sub in positions for pos in sub]
    # delete back-to-front so earlier positions stay valid
    for position in sorted(positions, reverse=True):
        del example["input_ids"][position]

    example["length"] = len(example["input_ids"])
    return example
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# for genes_to_perturb = "all" where only genes within cell are overexpressed
|
| 206 |
+
# for genes_to_perturb = "all" where only genes within cell are overexpressed
def overexpress_indices(example):
    """Move the tokens at example["perturb_index"] to the front of the rank
    value encoding (top rank models overexpression); refresh "length"."""
    positions = example["perturb_index"]
    if any(isinstance(entry, list) for entry in positions):
        positions = [pos for sub in positions for pos in sub]
    for position in sorted(positions, reverse=True):
        example["input_ids"].insert(0, example["input_ids"].pop(position))

    example["length"] = len(example["input_ids"])
    return example
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
| 218 |
+
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
def overexpress_tokens(example, max_len):
    """Prepend example["tokens_to_perturb"] to the rank value encoding,
    first removing any copies already present, then truncate to max_len."""
    # -100 indicates tokens to overexpress are not present in rank value encoding
    if example["perturb_index"] != [-100]:
        example = delete_indices(example)
    for token in reversed(example["tokens_to_perturb"]):
        example["input_ids"].insert(0, token)

    # truncate to max input size, must also truncate original emb to be comparable
    if len(example["input_ids"]) > max_len:
        example["input_ids"] = example["input_ids"][:max_len]

    example["length"] = len(example["input_ids"])
    return example
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def calc_n_overflow(max_len, example_len, tokens_to_perturb, indices_to_perturb):
    """Number of tokens by which the cell would exceed max_len after
    overexpression (tokens added net of those already present)."""
    net_added = len(tokens_to_perturb) - len(indices_to_perturb)
    return example_len + net_added - max_len
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def truncate_by_n_overflow(example):
    """Drop the lowest-ranked tokens so the cell no longer overflows the
    model input size; refresh "length"."""
    keep = example["length"] - example["n_overflow"]
    example["input_ids"] = example["input_ids"][:keep]
    example["length"] = len(example["input_ids"])
    return example
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
    """Return *emb* with the positions in *indices_to_remove* dropped along
    dimension *gene_dim*; all other dimensions are kept in full.
    """
    # indices_to_remove is list of indices to remove
    indices_to_keep = [
        i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove
    ]
    num_dims = emb.dim()
    # index with a tuple: indexing a tensor with a Python *list* of slices
    # is deprecated and rejected by recent PyTorch versions
    emb_slice = tuple(
        slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)
    )
    sliced_emb = emb[emb_slice]
    return sliced_emb
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
    """Remove per-cell perturbed positions from a batched embedding, then
    right-pad each cell's gene axis back to a common length and re-stack.

    list_of_indices_to_remove: one index list per cell in the batch.
    pad_xd_tensor is defined elsewhere in this module — presumably pads
    with 0.0 along the (gene_dim - 1) axis up to batch_max; verify there.
    """
    output_batch_list = [
        remove_indices_from_emb(emb_batch[i, :, :], idxes, gene_dim - 1)
        for i, idxes in enumerate(list_of_indices_to_remove)
    ]
    # add padding given genes are sometimes added that are or are not in original cell
    batch_max = max([emb.size()[gene_dim - 1] for emb in output_batch_list])
    output_batch_list_padded = [
        pad_xd_tensor(emb, 0.000, batch_max, gene_dim - 1) for emb in output_batch_list
    ]
    return torch.stack(output_batch_list_padded)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# removes perturbed indices
|
| 275 |
+
# need to handle the various cases where a set of genes is overexpressed
|
| 276 |
+
# removes perturbed indices
# need to handle the various cases where a set of genes is overexpressed
def remove_perturbed_indices_set(
    emb,
    perturb_type: str,
    indices_to_perturb: List[List],
    tokens_to_perturb: List[List],
    original_lengths: List[int],
    input_ids=None,
):
    """Drop the originally-perturbed positions from a batched embedding.

    For "overexpress", an index entry of [-100] marks a token that was absent
    from the original cell (nothing to remove); it is mapped to [None].
    Returns *emb* unchanged when a single overexpressed token was absent from
    every cell in the batch.
    """
    if perturb_type == "overexpress":
        num_perturbed = len(tokens_to_perturb)
        if num_perturbed == 1:
            indices_to_perturb_orig = [
                idx if idx != [-100] else [None] for idx in indices_to_perturb
            ]
            # BUGFIX: was `all(v is [None] ...)` — an identity test against a
            # fresh list literal is always False, so this early return never
            # fired; compare by equality instead
            if all(v == [None] for v in indices_to_perturb_orig):
                return emb
        else:
            indices_to_perturb_orig = []

            for idx_list in indices_to_perturb:
                indices_to_perturb_orig.append(
                    [idx if idx != [-100] else [None] for idx in idx_list]
                )

    else:
        indices_to_perturb_orig = indices_to_perturb

    emb = remove_indices_from_emb_batch(emb, indices_to_perturb_orig, gene_dim=1)

    return emb
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def make_perturbation_batch(
    example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
) -> tuple[Dataset, List[int]]:
    """Expand one cell into a dataset of perturbed copies.

    Builds the list of rank positions to perturb (one entry per perturbation),
    replicates the cell's input_ids once per perturbation, and applies the
    delete/overexpress transform via Dataset.map.

    Returns the mapped perturbation Dataset and the list of perturbed indices.
    """
    # case 1: perturb every gene in the cell, one at a time
    if combo_lvl == 0 and tokens_to_perturb == "all":
        if perturb_type in ["overexpress", "activate"]:
            # start at 1: position 0 is where the overexpressed gene is inserted
            range_start = 1
        elif perturb_type in ["delete", "inhibit"]:
            range_start = 0
        indices_to_perturb = [
            [i] for i in range(range_start, example_cell["length"][0])
        ]
    # elif combo_lvl > 0 and anchor_token is None:
    ## to implement
    # case 2: pair the anchor gene with every other gene in the cell
    elif combo_lvl > 0 and (anchor_token is not None):
        example_input_ids = example_cell["input_ids"][0]
        anchor_index = example_input_ids.index(anchor_token[0])
        indices_to_perturb = [
            sorted([anchor_index, i]) if i != anchor_index else None
            for i in range(example_cell["length"][0])
        ]
        indices_to_perturb = [item for item in indices_to_perturb if item is not None]
    # case 3: perturb only the listed tokens (skipping tokens absent from the cell)
    else:
        example_input_ids = example_cell["input_ids"][0]
        indices_to_perturb = [
            [example_input_ids.index(token)] if token in example_input_ids else None
            for token in tokens_to_perturb
        ]
        indices_to_perturb = [item for item in indices_to_perturb if item is not None]

    # create all permutations of combo_lvl of modifiers from tokens_to_perturb
    if combo_lvl > 0 and (anchor_token is None):
        if tokens_to_perturb != "all":
            if len(tokens_to_perturb) == combo_lvl + 1:
                indices_to_perturb = [
                    list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
                ]
        else:
            # combine the fixed perturbation set with every remaining gene
            all_indices = [[i] for i in range(example_cell["length"][0])]
            all_indices = [
                index for index in all_indices if index not in indices_to_perturb
            ]
            indices_to_perturb = [
                [[j for i in indices_to_perturb for j in i], x] for x in all_indices
            ]

    length = len(indices_to_perturb)
    # one copy of the cell's input_ids per perturbation to apply
    perturbation_dataset = Dataset.from_dict(
        {
            "input_ids": example_cell["input_ids"] * length,
            "perturb_index": indices_to_perturb,
        }
    )

    # multiprocessing overhead is not worth it for small perturbation sets
    if length < 400:
        num_proc_i = 1
    else:
        num_proc_i = num_proc

    if perturb_type == "delete":
        perturbation_dataset = perturbation_dataset.map(
            delete_indices, num_proc=num_proc_i
        )
    elif perturb_type == "overexpress":
        perturbation_dataset = perturbation_dataset.map(
            overexpress_indices, num_proc=num_proc_i
        )

    # recompute lengths after the perturbation changed each sequence
    perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)

    return perturbation_dataset, indices_to_perturb
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# perturbed cell emb removing the activated/overexpressed/inhibited gene emb
|
| 381 |
+
# so that only non-perturbed gene embeddings are compared to each other
|
| 382 |
+
# in original or perturbed context
|
| 383 |
+
# perturbed cell emb removing the activated/overexpressed/inhibited gene emb
# so that only non-perturbed gene embeddings are compared to each other
# in original or perturbed context
def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
    """Build original-context embeddings aligned with the perturbed batch.

    For each (embedding, indices) pair, the perturbed positions are sliced
    out of the original embedding so that only non-perturbed genes remain;
    the results are padded to a common length and stacked.
    An indices entry of [-100] means the gene was absent, so the embedding
    is kept whole.
    """
    all_embs_list = []

    # if making comparison batch for multiple perturbations in single cell
    if perturb_group is False:
        # squeeze if single cell
        if original_emb_batch.ndim == 3 and original_emb_batch.size()[0] == 1:
            original_emb_batch = torch.squeeze(original_emb_batch)
        # reuse the same original embedding for every perturbation
        original_emb_list = [original_emb_batch] * len(indices_to_perturb)
    # if making comparison batch for single perturbation in multiple cells
    elif perturb_group is True:
        original_emb_list = original_emb_batch

    for original_emb, indices in zip(original_emb_list, indices_to_perturb):
        if indices == [-100]:
            # gene not present in this cell: keep the full embedding (copy view)
            all_embs_list += [original_emb[:]]
            continue

        emb_list = []
        start = 0
        # nested index lists (combo perturbations) are flattened before slicing
        if any(isinstance(el, list) for el in indices):
            indices = flatten_list(indices)

        # removes indices that were perturbed from the original embedding
        for i in sorted(indices):
            emb_list += [original_emb[start:i]]
            start = i + 1

        emb_list += [original_emb[start:]]
        all_embs_list += [torch.cat(emb_list)]

    # re-pad when different cells lost different numbers of positions
    len_set = set([emb.size()[0] for emb in all_embs_list])
    if len(len_set) > 1:
        max_len = max(len_set)
        all_embs_list = [pad_2d_tensor(emb, None, max_len, 0) for emb in all_embs_list]
    return torch.stack(all_embs_list)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def pad_list(input_ids, pad_token_id, max_len):
    """Right-pad a 1D sequence of token ids to max_len with pad_token_id.

    Returns a numpy array (np.pad converts list input).
    """
    tail = max_len - len(input_ids)
    return np.pad(input_ids, (0, tail), mode="constant", constant_values=pad_token_id)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def pad_xd_tensor(tensor, pad_token_id, max_len, dim):
    """Pad `tensor` up to max_len along dimension `dim` (appending at the end)."""
    extra = max_len - tensor.size()[dim]
    # torch.nn.functional.pad reads (before, after) pairs starting from the
    # LAST dimension, so the "after" slot for `dim` sits at index -2*dim - 1.
    config = [0] * (2 * tensor.dim())
    config[-2 * dim - 1] = extra
    return torch.nn.functional.pad(
        tensor, pad=config, mode="constant", value=pad_token_id
    )
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def pad_tensor(tensor, pad_token_id, max_len):
    """Right-pad a 1D tensor with pad_token_id up to max_len elements."""
    needed = max_len - tensor.numel()
    padded = torch.nn.functional.pad(
        tensor, pad=(0, needed), mode="constant", value=pad_token_id
    )
    return padded
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def pad_2d_tensor(tensor, pad_token_id, max_len, dim):
    """Pad a 2D tensor to max_len along dim (0 = rows, 1 = columns)."""
    shortfall = max_len - tensor.size()[dim]
    if dim == 0:
        pad_spec = (0, 0, 0, shortfall)  # second pair pads the row axis
    elif dim == 1:
        pad_spec = (0, shortfall, 0, 0)  # first pair pads the column axis
    return torch.nn.functional.pad(
        tensor, pad=pad_spec, mode="constant", value=pad_token_id
    )
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def pad_3d_tensor(tensor, pad_token_id, max_len, dim):
    """Pad a 3D tensor to max_len along dim 1 or 2; dim 0 (batch) is rejected."""
    if dim == 0:
        raise Exception("dim 0 usually does not need to be padded.")
    gap = max_len - tensor.size()[dim]
    if dim == 1:
        pad_spec = (0, 0, 0, gap)  # second pair pads the middle axis
    elif dim == 2:
        pad_spec = (0, gap, 0, 0)  # first pair pads the last axis
    return torch.nn.functional.pad(
        tensor, pad=pad_spec, mode="constant", value=pad_token_id
    )
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def pad_or_truncate_encoding(encoding, pad_token_id, max_len):
    """Force `encoding` (1D tensor or list) to exactly max_len items.

    Longer inputs are truncated from the tail; shorter ones are right-padded
    with pad_token_id via pad_tensor / pad_list. Equal-length inputs pass
    through unchanged.
    """
    if isinstance(encoding, torch.Tensor):
        current = encoding.size()[0]
    elif isinstance(encoding, list):
        current = len(encoding)
    if current > max_len:
        return encoding[0:max_len]
    if current < max_len:
        if isinstance(encoding, torch.Tensor):
            return pad_tensor(encoding, pad_token_id, max_len)
        elif isinstance(encoding, list):
            return pad_list(encoding, pad_token_id, max_len)
    return encoding
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
# pad list of tensors and convert to tensor
|
| 491 |
+
def pad_tensor_list(
|
| 492 |
+
tensor_list,
|
| 493 |
+
dynamic_or_constant,
|
| 494 |
+
pad_token_id,
|
| 495 |
+
model_input_size,
|
| 496 |
+
dim=None,
|
| 497 |
+
padding_func=None,
|
| 498 |
+
):
|
| 499 |
+
# determine maximum tensor length
|
| 500 |
+
if dynamic_or_constant == "dynamic":
|
| 501 |
+
max_len = max([tensor.squeeze().numel() for tensor in tensor_list])
|
| 502 |
+
elif isinstance(dynamic_or_constant, int):
|
| 503 |
+
max_len = dynamic_or_constant
|
| 504 |
+
else:
|
| 505 |
+
max_len = model_input_size
|
| 506 |
+
logger.warning(
|
| 507 |
+
"If padding style is constant, must provide integer value. "
|
| 508 |
+
f"Setting padding to max input size {model_input_size}."
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# pad all tensors to maximum length
|
| 512 |
+
if dim is None:
|
| 513 |
+
tensor_list = [
|
| 514 |
+
pad_tensor(tensor, pad_token_id, max_len) for tensor in tensor_list
|
| 515 |
+
]
|
| 516 |
+
else:
|
| 517 |
+
tensor_list = [
|
| 518 |
+
padding_func(tensor, pad_token_id, max_len, dim) for tensor in tensor_list
|
| 519 |
+
]
|
| 520 |
+
# return stacked tensors
|
| 521 |
+
if padding_func != pad_3d_tensor:
|
| 522 |
+
return torch.stack(tensor_list)
|
| 523 |
+
else:
|
| 524 |
+
return torch.cat(tensor_list, 0)
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def gen_attention_mask(minibatch_encoding, max_len=None, device="cuda"):
    """Build a 0/1 attention mask tensor for a minibatch.

    Each row gets 1s for its first `length` positions and 0s afterwards,
    clipped at max_len (default: the longest length in the batch).

    Parameters
    ----------
    minibatch_encoding : mapping with a "length" list of per-example lengths.
    max_len : int, optional
        Mask width; defaults to max(lengths).
    device : str, optional
        Device for the returned tensor. Defaults to "cuda" for backward
        compatibility, but can now be set (e.g. "cpu") so this helper also
        works without a GPU — previously the device was hard-coded.

    Returns
    -------
    torch.Tensor of shape (batch, max_len) with dtype int64.
    """
    original_lens = minibatch_encoding["length"]
    if max_len is None:
        max_len = max(original_lens)
    attention_mask = [
        [1] * original_len + [0] * (max_len - original_len)
        if original_len <= max_len
        else [1] * max_len
        for original_len in original_lens
    ]
    return torch.tensor(attention_mask, device=device)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
# get cell embeddings excluding padding
|
| 541 |
+
def mean_nonpadding_embs(embs, original_lens, dim=1):
|
| 542 |
+
# create a mask tensor based on padding lengths
|
| 543 |
+
mask = torch.arange(embs.size(dim), device=embs.device) < original_lens.unsqueeze(1)
|
| 544 |
+
if embs.dim() == 3:
|
| 545 |
+
# fill the masked positions in embs with zeros
|
| 546 |
+
masked_embs = embs.masked_fill(~mask.unsqueeze(2), 0.0)
|
| 547 |
+
|
| 548 |
+
# compute the mean across the non-padding dimensions
|
| 549 |
+
mean_embs = masked_embs.sum(dim) / original_lens.view(-1, 1).float()
|
| 550 |
+
|
| 551 |
+
elif embs.dim() == 2:
|
| 552 |
+
masked_embs = embs.masked_fill(~mask, 0.0)
|
| 553 |
+
mean_embs = masked_embs.sum(dim) / original_lens.float()
|
| 554 |
+
return mean_embs
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
# get cell embeddings when there is no padding
|
| 558 |
+
def compute_nonpadded_cell_embedding(embs, cell_emb_style):
|
| 559 |
+
if cell_emb_style == "mean_pool":
|
| 560 |
+
return torch.mean(embs, dim=embs.ndim - 2)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
# quantify shifts for a set of genes
|
| 564 |
+
def quant_cos_sims(
|
| 565 |
+
perturbation_emb,
|
| 566 |
+
original_emb,
|
| 567 |
+
cell_states_to_model,
|
| 568 |
+
state_embs_dict,
|
| 569 |
+
emb_mode="gene",
|
| 570 |
+
):
|
| 571 |
+
if emb_mode == "gene":
|
| 572 |
+
cos = torch.nn.CosineSimilarity(dim=2)
|
| 573 |
+
elif emb_mode == "cell":
|
| 574 |
+
cos = torch.nn.CosineSimilarity(dim=1)
|
| 575 |
+
|
| 576 |
+
if cell_states_to_model is None:
|
| 577 |
+
cos_sims = cos(perturbation_emb, original_emb).to("cuda")
|
| 578 |
+
else:
|
| 579 |
+
possible_states = get_possible_states(cell_states_to_model)
|
| 580 |
+
cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
|
| 581 |
+
for state in possible_states:
|
| 582 |
+
cos_sims[state] = cos_sim_shift(
|
| 583 |
+
original_emb,
|
| 584 |
+
perturbation_emb,
|
| 585 |
+
state_embs_dict[state].to("cuda"), # required to move to cuda here
|
| 586 |
+
cos,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
return cos_sims
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
# calculate cos sim shift of perturbation with respect to origin and alternative cell
|
| 593 |
+
def cos_sim_shift(original_emb, perturbed_emb, end_emb, cos):
|
| 594 |
+
origin_v_end = cos(original_emb, end_emb)
|
| 595 |
+
perturb_v_end = cos(perturbed_emb, end_emb)
|
| 596 |
+
|
| 597 |
+
return perturb_v_end - origin_v_end
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def concatenate_cos_sims(cos_sims):
    """Concatenate batched cosine-similarity chunks.

    Accepts either a list of tensors (returns one tensor) or a dict mapping
    state -> list of tensors (concatenates each value in place and returns
    the dict).
    """
    if isinstance(cos_sims, list):
        return torch.cat(cos_sims)
    for state in cos_sims.keys():
        cos_sims[state] = torch.cat(cos_sims[state])
    return cos_sims
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def write_perturbation_dictionary(cos_sims_dict: defaultdict, output_path_prefix: str):
    """Pickle the raw cosine-similarity dictionary to <prefix>_raw.pickle."""
    out_path = f"{output_path_prefix}_raw.pickle"
    with open(out_path, "wb") as fp:
        pickle.dump(cos_sims_dict, fp)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
def tensor_list_to_pd(tensor_list):
    """Concatenate a list of tensors and wrap the values in a DataFrame.

    Tensors are moved to CPU and converted to numpy before wrapping.
    """
    stacked = torch.cat(tensor_list).cpu().numpy()
    return pd.DataFrame(stacked)
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def validate_cell_states_to_model(cell_states_to_model):
    """Validate the cell_states_to_model dict and warn on the legacy format.

    Accepts either the legacy single-entry format {state_key: (start, goal,
    alts)} — which is reformatted locally to the named-key format — or the
    named-key format with 'state_key', 'start_state', 'goal_state' and
    optional 'alt_states'. Invalid input logs an error and raises.

    NOTE(review): the bare `raise` statements below execute outside any
    except block, so they raise RuntimeError("No active exception to
    re-raise") rather than a descriptive exception — the logged message is
    the only diagnostic. Consider raising ValueError instead.
    NOTE(review): the legacy-format reformat rebinds the local name only;
    the converted dict is not returned, so callers do not see it — confirm
    callers handle the legacy format themselves.
    """
    if cell_states_to_model is not None:
        # legacy format: a single {state_key: (start, goal, alts)} entry
        if len(cell_states_to_model.items()) == 1:
            logger.warning(
                "The single value dictionary for cell_states_to_model will be "
                "replaced with a dictionary with named keys for start, goal, and alternate states. "
                "Please specify state_key, start_state, goal_state, and alt_states "
                "in the cell_states_to_model dictionary for future use. "
                "For example, cell_states_to_model={"
                "'state_key': 'disease', "
                "'start_state': 'dcm', "
                "'goal_state': 'nf', "
                "'alt_states': ['hcm', 'other1', 'other2']}"
            )
            # sanity-check the legacy tuple-of-three-lists shape; well-formed
            # unique values skip via `continue` before the reformat below
            for key, value in cell_states_to_model.items():
                if (len(value) == 3) and isinstance(value, tuple):
                    if (
                        isinstance(value[0], list)
                        and isinstance(value[1], list)
                        and isinstance(value[2], list)
                    ):
                        if len(value[0]) == 1 and len(value[1]) == 1:
                            all_values = value[0] + value[1] + value[2]
                            if len(all_values) == len(set(all_values)):
                                continue
            # reformat to the new named key format
            state_values = flatten_list(list(cell_states_to_model.values()))

            cell_states_to_model = {
                "state_key": list(cell_states_to_model.keys())[0],
                "start_state": state_values[0][0],
                "goal_state": state_values[1][0],
                "alt_states": state_values[2:][0],
            }
        # named-key format: must at least have state_key/start_state/goal_state
        elif set(cell_states_to_model.keys()).issuperset(
            {"state_key", "start_state", "goal_state"}
        ):
            if (
                (cell_states_to_model["state_key"] is None)
                or (cell_states_to_model["start_state"] is None)
                or (cell_states_to_model["goal_state"] is None)
            ):
                logger.error(
                    "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
                )
                raise

            if (
                cell_states_to_model["start_state"]
                == cell_states_to_model["goal_state"]
            ):
                logger.error("All states must be unique.")
                raise

            # alt_states is optional; when present it must be a unique list
            if "alt_states" in set(cell_states_to_model.keys()):
                if cell_states_to_model["alt_states"] is not None:
                    if not isinstance(cell_states_to_model["alt_states"], list):
                        logger.error(
                            "cell_states_to_model['alt_states'] must be a list (even if it is one element)."
                        )
                        raise
                    if len(cell_states_to_model["alt_states"]) != len(
                        set(cell_states_to_model["alt_states"])
                    ):
                        logger.error("All states must be unique.")
                        raise
            else:
                cell_states_to_model["alt_states"] = []

        else:
            logger.error(
                "cell_states_to_model must only have the following four keys: "
                "'state_key', 'start_state', 'goal_state', 'alt_states'."
                "For example, cell_states_to_model={"
                "'state_key': 'disease', "
                "'start_state': 'dcm', "
                "'goal_state': 'nf', "
                "'alt_states': ['hcm', 'other1', 'other2']}"
            )
            raise
|
geneformer/pretrainer.py
ADDED
|
@@ -0,0 +1,978 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Geneformer precollator and pretrainer.
|
| 3 |
+
|
| 4 |
+
Huggingface data collator and trainer modified to accommodate single-cell transcriptomics data.
|
| 5 |
+
"""
|
| 6 |
+
import collections
|
| 7 |
+
import math
|
| 8 |
+
import pickle
|
| 9 |
+
import warnings
|
| 10 |
+
from enum import Enum
|
| 11 |
+
from typing import Dict, Iterator, List, Optional, Union
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from datasets import Dataset
|
| 16 |
+
from packaging import version
|
| 17 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 18 |
+
from torch.utils.data.sampler import RandomSampler
|
| 19 |
+
from transformers import (
|
| 20 |
+
BatchEncoding,
|
| 21 |
+
DataCollatorForLanguageModeling,
|
| 22 |
+
SpecialTokensMixin,
|
| 23 |
+
Trainer,
|
| 24 |
+
)
|
| 25 |
+
from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
|
| 26 |
+
from transformers.trainer_pt_utils import (
|
| 27 |
+
DistributedLengthGroupedSampler,
|
| 28 |
+
DistributedSamplerWithLoop,
|
| 29 |
+
LengthGroupedSampler,
|
| 30 |
+
)
|
| 31 |
+
from transformers.training_args import ParallelMode
|
| 32 |
+
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
| 33 |
+
from transformers.utils.generic import _is_tensorflow, _is_torch
|
| 34 |
+
|
| 35 |
+
from .tokenizer import TOKEN_DICTIONARY_FILE
|
| 36 |
+
|
| 37 |
+
logger = logging.get_logger(__name__)
# Type alias for a tokenized sequence of ids.
EncodedInput = List[int]
VERY_LARGE_INTEGER = int(
    1e30
)  # This is used to set the max input length for a model with infinite size input
LARGE_INTEGER = int(
    1e20
)  # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER

# Select the distributed backend: SageMaker data-parallel when running under
# SMDP, otherwise plain torch.distributed (both bound to the name `dist`).
if is_sagemaker_dp_enabled():
    import smdistributed.dataparallel.torch.distributed as dist
else:
    import torch.distributed as dist

# torch.Generator-based sampling is only usable on torch >= 1.6.
_is_torch_generator_available = False
if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_torch_generator_available = True

# Load the gene token vocabulary pickled alongside the package at import time.
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
    token_dictionary = pickle.load(f)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ExplicitEnum(Enum):
    """
    Enum with more explicit error message for missing values.
    """

    # Enum calls _missing_ when a lookup value matches no member; override it
    # so the ValueError lists every valid value, not just the bad one.
    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            "%r is not a valid %s, please select one of %s"
            % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
        )
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class TruncationStrategy(ExplicitEnum):
    """
    Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.
    """

    ONLY_FIRST = "only_first"  # truncate only the first sequence of a pair
    ONLY_SECOND = "only_second"  # truncate only the second sequence of a pair
    LONGEST_FIRST = "longest_first"  # iteratively truncate the longer sequence
    DO_NOT_TRUNCATE = "do_not_truncate"  # never truncate
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
    in an IDE.
    """

    LONGEST = "longest"  # pad to the longest sequence in the batch
    MAX_LENGTH = "max_length"  # pad to a fixed max_length
    DO_NOT_PAD = "do_not_pad"  # leave sequences unpadded
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class TensorType(ExplicitEnum):
    """
    Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.
    """

    PYTORCH = "pt"  # torch tensors
    TENSORFLOW = "tf"  # TensorFlow tensors
    NUMPY = "np"  # numpy arrays
    JAX = "jax"  # JAX arrays
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class GeneformerPreCollator(SpecialTokensMixin):
|
| 108 |
+
def __init__(self, *args, **kwargs) -> None:
|
| 109 |
+
|
| 110 |
+
super().__init__(mask_token = "<mask>", pad_token = "<pad>")
|
| 111 |
+
|
| 112 |
+
self.token_dictionary = kwargs.get("token_dictionary")
|
| 113 |
+
# self.mask_token = "<mask>"
|
| 114 |
+
# self.mask_token_id = self.token_dictionary.get("<mask>")
|
| 115 |
+
# self.pad_token = "<pad>"
|
| 116 |
+
# self.pad_token_id = self.token_dictionary.get("<pad>")
|
| 117 |
+
self.padding_side = "right"
|
| 118 |
+
# self.all_special_ids = [
|
| 119 |
+
# self.token_dictionary.get("<mask>"),
|
| 120 |
+
# self.token_dictionary.get("<pad>"),
|
| 121 |
+
# ]
|
| 122 |
+
self.model_input_names = ["input_ids"]
|
| 123 |
+
|
| 124 |
+
def convert_ids_to_tokens(self,value):
|
| 125 |
+
return self.token_dictionary.get(value)
|
| 126 |
+
|
| 127 |
+
    def _get_padding_truncation_strategies(
        self,
        padding=False,
        truncation=False,
        max_length=None,
        pad_to_multiple_of=None,
        verbose=True,
        **kwargs,
    ):
        """
        Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
        and pad_to_max_length) and behaviors.

        Returns a (padding_strategy, truncation_strategy, max_length, kwargs)
        tuple, where kwargs has the legacy arguments popped off.
        """
        # legacy arguments are consumed here so they don't propagate further
        old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
        old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)

        # Backward compatibility for previous behavior, maybe we should deprecate it:
        # If you only set max_length, it activates truncation for max_length
        if max_length is not None and padding is False and truncation is False:
            if verbose:
                # warn only once per process via the deprecation_warnings cache
                if not self.deprecation_warnings.get(
                    "Truncation-not-explicitly-activated", False
                ):
                    logger.warning(
                        "Truncation was not explicitly activated but `max_length` is provided a specific value, "
                        "please use `truncation=True` to explicitly truncate examples to max length. "
                        "Defaulting to 'longest_first' truncation strategy. "
                        "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
                        "more precisely by providing a specific strategy to `truncation`."
                    )
                self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
            truncation = "longest_first"

        # Get padding strategy
        if padding is False and old_pad_to_max_length:
            if verbose:
                warnings.warn(
                    "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
                    "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
                    "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
                    "maximal input size of the model (e.g. 512 for Bert).",
                    FutureWarning,
                )
            if max_length is None:
                padding_strategy = PaddingStrategy.LONGEST
            else:
                padding_strategy = PaddingStrategy.MAX_LENGTH
        elif padding is not False:
            if padding is True:
                padding_strategy = (
                    PaddingStrategy.LONGEST
                )  # Default to pad to the longest sequence in the batch
            elif not isinstance(padding, PaddingStrategy):
                # a string value is coerced through the enum (validates it too)
                padding_strategy = PaddingStrategy(padding)
            elif isinstance(padding, PaddingStrategy):
                padding_strategy = padding
        else:
            padding_strategy = PaddingStrategy.DO_NOT_PAD

        # Get truncation strategy
        if truncation is False and old_truncation_strategy != "do_not_truncate":
            if verbose:
                warnings.warn(
                    "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
                    "use `truncation=True` to truncate examples to a max length. You can give a specific "
                    "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
                    "maximal input size of the model (e.g. 512 for Bert). "
                    " If you have pairs of inputs, you can give a specific truncation strategy selected among "
                    "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
                    "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
                    "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
                    FutureWarning,
                )
            truncation_strategy = TruncationStrategy(old_truncation_strategy)
        elif truncation is not False:
            if truncation is True:
                truncation_strategy = (
                    TruncationStrategy.LONGEST_FIRST
                )  # Default to truncate the longest sequences in pairs of inputs
            elif not isinstance(truncation, TruncationStrategy):
                truncation_strategy = TruncationStrategy(truncation)
            elif isinstance(truncation, TruncationStrategy):
                truncation_strategy = truncation
        else:
            truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE

        # Set max length if needed
        if max_length is None:
            if padding_strategy == PaddingStrategy.MAX_LENGTH:
                # model_max_length > LARGE_INTEGER means "no model-defined limit"
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get(
                            "Asking-to-pad-to-max_length", False
                        ):
                            logger.warning(
                                "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no padding."
                            )
                        self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
                    padding_strategy = PaddingStrategy.DO_NOT_PAD
                else:
                    max_length = self.model_max_length

            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
                if self.model_max_length > LARGE_INTEGER:
                    if verbose:
                        if not self.deprecation_warnings.get(
                            "Asking-to-truncate-to-max_length", False
                        ):
                            logger.warning(
                                "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
                                "Default to no truncation."
                            )
                        self.deprecation_warnings[
                            "Asking-to-truncate-to-max_length"
                        ] = True
                    truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
                else:
                    max_length = self.model_max_length

        # Test if we have a padding token
        if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
            not self.pad_token or self.pad_token_id < 0
        ):
            raise ValueError(
                "Asking to pad but the tokenizer does not have a padding token. "
                "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
                "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
            )

        # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
        if (
            truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
            and padding_strategy != PaddingStrategy.DO_NOT_PAD
            and pad_to_multiple_of is not None
            and max_length is not None
            and (max_length % pad_to_multiple_of != 0)
        ):
            raise ValueError(
                f"Truncation and padding are both activated but "
                f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
            )

        return padding_strategy, truncation_strategy, max_length, kwargs
|
| 272 |
+
|
| 273 |
+
    def pad(
        self,
        encoded_inputs: Union[
            BatchEncoding,
            List[BatchEncoding],
            Dict[str, EncodedInput],
            Dict[str, List[EncodedInput]],
            List[Dict[str, EncodedInput]],
        ],
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        verbose: bool = True,
    ) -> BatchEncoding:
        """
        Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
        in the batch.

        Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
        ``self.pad_token_id`` and ``self.pad_token_type_id``)

        .. note::

            If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
            result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
            case of PyTorch tensors, you will lose the specific device of your tensors however.

        Args:
            encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
                Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
                List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
                List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
                well as in a PyTorch Dataloader collate function.

                Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
                see the note above for the return type.
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
                  single sequence if provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
                  different lengths).
            max_length (:obj:`int`, `optional`):
                Maximum length of the returned list and optionally padding length (see above).
            pad_to_multiple_of (:obj:`int`, `optional`):
                If set will pad the sequence to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                >= 7.5 (Volta).
            return_attention_mask (:obj:`bool`, `optional`):
                Whether to return the attention mask. If left to the default, will return the attention mask according
                to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.

                `What are attention masks? <../glossary.html#attention-mask>`__
            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
            verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to print more information and warnings.
        """
        # If we have a list of dicts, let's convert it in a dict of lists
        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
        if isinstance(encoded_inputs, (list, tuple)) and isinstance(
            encoded_inputs[0], (dict, BatchEncoding)
        ):
            encoded_inputs = {
                key: [example[key] for example in encoded_inputs]
                for key in encoded_inputs[0].keys()
            }

        # The model's main input name, usually `input_ids`, has be passed for padding
        if self.model_input_names[0] not in encoded_inputs:
            raise ValueError(
                "You should supply an encoding or a list of encodings to this method"
                f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
            )

        required_input = encoded_inputs[self.model_input_names[0]]

        # Empty batch: nothing to pad; return early (with an empty mask if requested).
        if not required_input:
            if return_attention_mask:
                encoded_inputs["attention_mask"] = []
            return encoded_inputs

        # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
        # and rebuild them afterwards if no return_tensors is specified
        # Note that we lose the specific device the tensor may be on for PyTorch

        first_element = required_input[0]
        if isinstance(first_element, (list, tuple)):
            # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
            index = 0
            while len(required_input[index]) == 0:
                index += 1
            if index < len(required_input):
                first_element = required_input[index][0]
        # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
        if not isinstance(first_element, (int, list, tuple)):
            # Infer which tensor framework the inputs came from so the output can be
            # rebuilt with the same type when `return_tensors` is not given.
            if is_tf_available() and _is_tensorflow(first_element):
                return_tensors = "tf" if return_tensors is None else return_tensors
            elif is_torch_available() and _is_torch(first_element):
                return_tensors = "pt" if return_tensors is None else return_tensors
            elif isinstance(first_element, np.ndarray):
                return_tensors = "np" if return_tensors is None else return_tensors
            else:
                raise ValueError(
                    f"type of {first_element} unknown: {type(first_element)}. "
                    f"Should be one of a python, numpy, pytorch or tensorflow object."
                )

            for key, value in encoded_inputs.items():
                encoded_inputs[key] = to_py_obj(value)

        # Convert padding_strategy in PaddingStrategy
        padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
            padding=padding, max_length=max_length, verbose=verbose
        )

        required_input = encoded_inputs[self.model_input_names[0]]
        # Single (non-batched) input: pad it directly and return.
        if required_input and not isinstance(required_input[0], (list, tuple)):
            encoded_inputs = self._pad(
                encoded_inputs,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )
            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)

        batch_size = len(required_input)
        assert all(
            len(v) == batch_size for v in encoded_inputs.values()
        ), "Some items in the output dictionary have a different batch size than others."

        if padding_strategy == PaddingStrategy.LONGEST:
            # Resolve LONGEST to a concrete target length so `_pad` only ever sees MAX_LENGTH.
            max_length = max(len(inputs) for inputs in required_input)
            padding_strategy = PaddingStrategy.MAX_LENGTH

        batch_outputs = {}
        for i in range(batch_size):
            # Pad each example individually, then regroup the per-example outputs by key.
            inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
            outputs = self._pad(
                inputs,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )

            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)

        return BatchEncoding(batch_outputs, tensor_type=return_tensors)
|
| 438 |
+
|
| 439 |
+
    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

        Args:
            encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The tokenizer padding sides are defined in self.padding_side:

                    - 'left': pads on the left of the sequences
                    - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                >= 7.5 (Volta).
            return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
        """
        # Load from model defaults
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        required_input = encoded_inputs[self.model_input_names[0]]

        # LONGEST on a single example degenerates to "pad to its own length" (no-op).
        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        if (
            max_length is not None
            and pad_to_multiple_of is not None
            and (max_length % pad_to_multiple_of != 0)
        ):
            # Round the target length up to the next multiple of pad_to_multiple_of.
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = (
            padding_strategy != PaddingStrategy.DO_NOT_PAD
            and len(required_input) != max_length
        )

        if needs_to_be_padded:
            difference = max_length - len(required_input)
            if self.padding_side == "right":
                # Right padding: real tokens first, then pad tokens; mask is 1s then 0s.
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [1] * len(required_input) + [
                        0
                    ] * difference
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = (
                        encoded_inputs["token_type_ids"]
                        + [self.pad_token_type_id] * difference
                    )
                if "special_tokens_mask" in encoded_inputs:
                    # Pad positions count as special tokens (mask value 1).
                    encoded_inputs["special_tokens_mask"] = (
                        encoded_inputs["special_tokens_mask"] + [1] * difference
                    )
                encoded_inputs[self.model_input_names[0]] = (
                    required_input + [self.pad_token_id] * difference
                )
            elif self.padding_side == "left":
                # Left padding: pad tokens first, then real tokens; mask is 0s then 1s.
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
                        required_input
                    )
                if "token_type_ids" in encoded_inputs:
                    encoded_inputs["token_type_ids"] = [
                        self.pad_token_type_id
                    ] * difference + encoded_inputs["token_type_ids"]
                if "special_tokens_mask" in encoded_inputs:
                    encoded_inputs["special_tokens_mask"] = [
                        1
                    ] * difference + encoded_inputs["special_tokens_mask"]
                encoded_inputs[self.model_input_names[0]] = [
                    self.pad_token_id
                ] * difference + required_input
            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
        elif return_attention_mask and "attention_mask" not in encoded_inputs:
            # No padding needed, but the caller still wants a (all-ones) attention mask.
            encoded_inputs["attention_mask"] = [1] * len(required_input)

        return encoded_inputs
|
| 530 |
+
|
| 531 |
+
def get_special_tokens_mask(
|
| 532 |
+
self,
|
| 533 |
+
token_ids_0: List[int],
|
| 534 |
+
token_ids_1: Optional[List[int]] = None,
|
| 535 |
+
already_has_special_tokens: bool = False,
|
| 536 |
+
) -> List[int]:
|
| 537 |
+
"""
|
| 538 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 539 |
+
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
| 540 |
+
Args:
|
| 541 |
+
token_ids_0 (:obj:`List[int]`):
|
| 542 |
+
List of ids of the first sequence.
|
| 543 |
+
token_ids_1 (:obj:`List[int]`, `optional`):
|
| 544 |
+
List of ids of the second sequence.
|
| 545 |
+
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 546 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 547 |
+
Returns:
|
| 548 |
+
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 549 |
+
"""
|
| 550 |
+
assert already_has_special_tokens and token_ids_1 is None, (
|
| 551 |
+
"You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
|
| 552 |
+
"Please use a slow (full python) tokenizer to activate this argument."
|
| 553 |
+
"Or set `return_special_tokens_mask=True` when calling the encoding method "
|
| 554 |
+
"to get the special tokens mask in any tokenizer. "
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
all_special_ids = self.all_special_ids # cache the property
|
| 558 |
+
|
| 559 |
+
special_tokens_mask = [
|
| 560 |
+
1 if token in all_special_ids else 0 for token in token_ids_0
|
| 561 |
+
]
|
| 562 |
+
|
| 563 |
+
return special_tokens_mask
|
| 564 |
+
|
| 565 |
+
def convert_tokens_to_ids(
|
| 566 |
+
self, tokens: Union[str, List[str]]
|
| 567 |
+
) -> Union[int, List[int]]:
|
| 568 |
+
"""
|
| 569 |
+
Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
|
| 570 |
+
vocabulary.
|
| 571 |
+
Args:
|
| 572 |
+
tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
|
| 573 |
+
Returns:
|
| 574 |
+
:obj:`int` or :obj:`List[int]`: The token id or list of token ids.
|
| 575 |
+
"""
|
| 576 |
+
if tokens is None:
|
| 577 |
+
return None
|
| 578 |
+
|
| 579 |
+
if isinstance(tokens, str):
|
| 580 |
+
return self._convert_token_to_id_with_added_voc(tokens)
|
| 581 |
+
|
| 582 |
+
ids = []
|
| 583 |
+
for token in tokens:
|
| 584 |
+
ids.append(self._convert_token_to_id_with_added_voc(token))
|
| 585 |
+
return ids
|
| 586 |
+
|
| 587 |
+
def _convert_token_to_id_with_added_voc(self, token):
|
| 588 |
+
if token is None:
|
| 589 |
+
return None
|
| 590 |
+
|
| 591 |
+
return self.token_dictionary.get(token)
|
| 592 |
+
|
| 593 |
+
    def __len__(self):
        """Vocabulary size: the number of entries in the token dictionary."""
        return len(self.token_dictionary)
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
class GeneformerPretrainer(Trainer):
    """Trainer subclass for Geneformer masked-language-model pretraining.

    Differences from the stock ``Trainer``:

    * If no ``data_collator`` is supplied, builds a
      ``DataCollatorForLanguageModeling`` (15% masking) on top of a
      ``GeneformerPreCollator`` constructed from the required
      ``token_dictionary`` keyword argument.
    * Requires an ``example_lengths_file`` keyword argument: a pickled list of
      per-example lengths, loaded into ``self.example_lengths`` so the
      length-grouped sampler does not have to scan the dataset.
    """

    def __init__(self, *args, **kwargs):
        # `token_dictionary` and `example_lengths_file` are custom kwargs consumed
        # here (popped) before the remaining kwargs are forwarded to Trainer.
        data_collator = kwargs.get("data_collator",None)
        token_dictionary = kwargs.pop("token_dictionary")

        if data_collator is None:
            precollator = GeneformerPreCollator(token_dictionary=token_dictionary)

            # # Data Collator Functions
            # MLM collator with a 15% masking probability.
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=precollator, mlm=True, mlm_probability=0.15
            )
            kwargs["data_collator"] = data_collator

        # load previously saved length vector for dataset to speed up LengthGroupedSampler
        # pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
        example_lengths_file = kwargs.pop("example_lengths_file")
        if example_lengths_file:
            with open(example_lengths_file, "rb") as f:
                self.example_lengths = pickle.load(f)
        else:
            raise Exception(
                "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
            )
        super().__init__(*args, **kwargs)
        # self.exp_logits_dir = exp_logits_dir
        # self.min_exp_logits = float('inf')
        # self.max_exp_logits = float('-inf')

    # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        """Return the training sampler, feeding the pre-loaded
        ``self.example_lengths`` to the length-grouped samplers instead of
        letting them query the dataset's length column.
        """
        if not isinstance(self.train_dataset, collections.abc.Sized):
            return None

        generator = None
        if self.args.world_size <= 1 and _is_torch_generator_available:
            # Seed a fresh generator from torch's global RNG state.
            generator = torch.Generator()
            generator.manual_seed(
                int(torch.empty((), dtype=torch.int64).random_().item())
            )

        # Build the sampler.
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, Dataset):
                # Use the precomputed lengths loaded in __init__ (the key change
                # vs. the stock Trainer implementation).
                lengths = self.example_lengths
            else:
                lengths = None
            model_input_name = (
                self.tokenizer.model_input_names[0]
                if self.tokenizer is not None
                else None
            )
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    dataset=self.train_dataset,
                    batch_size=self.args.train_batch_size,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            else:
                return CustomDistributedLengthGroupedSampler(
                    dataset=self.train_dataset,
                    batch_size=self.args.train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    seed=self.args.seed,
                )

        else:
            # Not grouping by length: fall back to the standard samplers.
            if self.args.world_size <= 1:
                if _is_torch_generator_available:
                    return RandomSampler(self.train_dataset, generator=generator)
                return RandomSampler(self.train_dataset)
            elif (
                self.args.parallel_mode
                in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=self.args.seed,
                )
            else:
                return DistributedSampler(
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=self.args.seed,
                )
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
    r"""
    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
    length while keeping a bit of randomness.

    NOTE(review): ``__init__`` deliberately does NOT call ``super().__init__()``;
    it sets all the distributed bookkeeping attributes itself so that ``lengths``
    can be supplied directly (avoiding a hang when inferring them from the
    dataset's length column).
    """
    # Copied and adapted from PyTorch DistributedSampler.
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        drop_last: bool = False,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
    ):
        # Fill in replica count / rank from the torch.distributed process group
        # when the caller does not provide them.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.seed = seed
        self.model_input_name = (
            model_input_name if model_input_name is not None else "input_ids"
        )

        if lengths is None:
            # Fallback: infer per-example lengths from the dataset items (slow;
            # callers are expected to pass `lengths` explicitly).
            print("Lengths is none - calculating lengths.")
            if (
                not (
                    isinstance(dataset[0], dict)
                    or isinstance(dataset[0], BatchEncoding)
                )
                or self.model_input_name not in dataset[0]
            ):
                raise ValueError(
                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                    f"'{self.model_input_name}' key."
                )
            lengths = [len(feature[self.model_input_name]) for feature in dataset]
        self.lengths = lengths

    def __iter__(self) -> Iterator:
        # Deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            indices += indices[: (self.total_size - len(indices))]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        # Each rank takes every num_replicas-th index, offset by its rank.
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
def get_length_grouped_indices(
    lengths, batch_size, mega_batch_mult=None, generator=None
):
    """
    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
    similar lengths. To do this, the indices are:

    - randomly permuted
    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
    - sorted by length in each mega-batch

    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
    maximum length placed first, so that an OOM happens sooner rather than later.

    Args:
        lengths: sequence of per-example lengths (indexable, with ``len``).
        batch_size: training batch size the indices will be consumed with.
        mega_batch_mult: number of batches per mega-batch; defaults to
            ``min(len(lengths) // (batch_size * 4), 1000)`` (at least 1).
        generator: optional ``torch.Generator`` making the permutation deterministic.

    Returns:
        ``list[int]``: a permutation of ``range(len(lengths))``; empty for empty input.
    """
    # Empty input: nothing to group. Guard here because the argmax over
    # megabatch maximums below would fail on an empty tensor.
    if len(lengths) == 0:
        return []

    # Default for mega_batch_mult: 1000 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [
        indices[i : i + megabatch_size].tolist()
        for i in range(0, len(lengths), megabatch_size)
    ]
    # Sort each mega-batch by descending length so consecutive batches hold
    # similarly-sized examples.
    megabatches = [
        list(sorted(megabatch, key=lambda i: lengths[i], reverse=True))
        for megabatch in megabatches
    ]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
    # Switch to put the longest element in first position
    megabatches[0][0], megabatches[max_idx][0] = (
        megabatches[max_idx][0],
        megabatches[0][0],
    )

    return [item for sublist in megabatches for item in sublist]
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
# from typing import Any, Tuple, Optional
|
| 829 |
+
|
| 830 |
+
# class CustomDataCollatorForMLM(DataCollatorForLanguageModeling):
|
| 831 |
+
# # def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
|
| 832 |
+
|
| 833 |
+
# # import torch
|
| 834 |
+
|
| 835 |
+
# # labels = inputs.clone()
|
| 836 |
+
# # # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
| 837 |
+
# # probability_matrix = torch.full(labels.shape, self.mlm_probability)
|
| 838 |
+
# # if special_tokens_mask is None:
|
| 839 |
+
# # special_tokens_mask = [
|
| 840 |
+
# # self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
| 841 |
+
# # ]
|
| 842 |
+
# # special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
|
| 843 |
+
# # else:
|
| 844 |
+
# # special_tokens_mask = special_tokens_mask.bool()
|
| 845 |
+
# # probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
|
| 846 |
+
# # masked_indices = torch.bernoulli(probability_matrix).bool()
|
| 847 |
+
# # labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 848 |
+
|
| 849 |
+
# # # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 850 |
+
# # indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
| 851 |
+
# # inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
| 852 |
+
|
| 853 |
+
# # # 15% of the time, we replace masked input tokens with random word
|
| 854 |
+
# # indices_random = torch.bernoulli(torch.full(labels.shape, 0.75)).bool() & masked_indices & ~indices_replaced
|
| 855 |
+
# # random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
| 856 |
+
# # inputs[indices_random] = random_words[indices_random]
|
| 857 |
+
|
| 858 |
+
# # # The rest of the time (5% of the time) we keep the masked input tokens unchanged
|
| 859 |
+
# # return inputs, labels
|
| 860 |
+
|
| 861 |
+
# def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
|
| 862 |
+
# import torch
|
| 863 |
+
|
| 864 |
+
# labels = inputs.clone()
|
| 865 |
+
# probability_matrix = torch.full(labels.shape, self.mlm_probability)
|
| 866 |
+
# if special_tokens_mask is None:
|
| 867 |
+
# special_tokens_mask = [
|
| 868 |
+
# self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
| 869 |
+
# ]
|
| 870 |
+
# special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
|
| 871 |
+
# else:
|
| 872 |
+
# special_tokens_mask = special_tokens_mask.bool()
|
| 873 |
+
|
| 874 |
+
# probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
|
| 875 |
+
# masked_indices = torch.bernoulli(probability_matrix).bool()
|
| 876 |
+
# labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 877 |
+
|
| 878 |
+
# # 100% of the time, replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 879 |
+
# inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
| 880 |
+
|
| 881 |
+
# return inputs, labels
|
| 882 |
+
|
| 883 |
+
# class CustomGeneformerPretrainer(Trainer):
|
| 884 |
+
# def __init__(self, *args, **kwargs):
|
| 885 |
+
# data_collator = kwargs.get("data_collator",None)
|
| 886 |
+
# token_dictionary = kwargs.pop("token_dictionary")
|
| 887 |
+
|
| 888 |
+
# if data_collator is None:
|
| 889 |
+
# precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
|
| 890 |
+
|
| 891 |
+
# # # Data Collator Functions
|
| 892 |
+
# data_collator = CustomDataCollatorForMLM(
|
| 893 |
+
# tokenizer=precollator, mlm=True, mlm_probability=0.15
|
| 894 |
+
# )
|
| 895 |
+
# kwargs["data_collator"] = data_collator
|
| 896 |
+
|
| 897 |
+
# # load previously saved length vector for dataset to speed up LengthGroupedSampler
|
| 898 |
+
# # pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
|
| 899 |
+
# example_lengths_file = kwargs.pop("example_lengths_file")
|
| 900 |
+
# if example_lengths_file:
|
| 901 |
+
# with open(example_lengths_file, "rb") as f:
|
| 902 |
+
# self.example_lengths = pickle.load(f)
|
| 903 |
+
# else:
|
| 904 |
+
# raise Exception(
|
| 905 |
+
# "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
|
| 906 |
+
# )
|
| 907 |
+
# super().__init__(*args, **kwargs)
|
| 908 |
+
# # self.exp_logits_dir = exp_logits_dir
|
| 909 |
+
# # self.min_exp_logits = float('inf')
|
| 910 |
+
# # self.max_exp_logits = float('-inf')
|
| 911 |
+
|
| 912 |
+
# # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
|
| 913 |
+
# def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
| 914 |
+
# if not isinstance(self.train_dataset, collections.abc.Sized):
|
| 915 |
+
# return None
|
| 916 |
+
|
| 917 |
+
# generator = None
|
| 918 |
+
# if self.args.world_size <= 1 and _is_torch_generator_available:
|
| 919 |
+
# generator = torch.Generator()
|
| 920 |
+
# generator.manual_seed(
|
| 921 |
+
# int(torch.empty((), dtype=torch.int64).random_().item())
|
| 922 |
+
# )
|
| 923 |
+
|
| 924 |
+
# # Build the sampler.
|
| 925 |
+
# if self.args.group_by_length:
|
| 926 |
+
# if is_datasets_available() and isinstance(self.train_dataset, Dataset):
|
| 927 |
+
# lengths = self.example_lengths
|
| 928 |
+
# else:
|
| 929 |
+
# lengths = None
|
| 930 |
+
# model_input_name = (
|
| 931 |
+
# self.tokenizer.model_input_names[0]
|
| 932 |
+
# if self.tokenizer is not None
|
| 933 |
+
# else None
|
| 934 |
+
# )
|
| 935 |
+
# if self.args.world_size <= 1:
|
| 936 |
+
# return LengthGroupedSampler(
|
| 937 |
+
# dataset=self.train_dataset,
|
| 938 |
+
# batch_size=self.args.train_batch_size,
|
| 939 |
+
# lengths=lengths,
|
| 940 |
+
# model_input_name=model_input_name,
|
| 941 |
+
# generator=generator,
|
| 942 |
+
# )
|
| 943 |
+
# else:
|
| 944 |
+
# return CustomDistributedLengthGroupedSampler(
|
| 945 |
+
# dataset=self.train_dataset,
|
| 946 |
+
# batch_size=self.args.train_batch_size,
|
| 947 |
+
# num_replicas=self.args.world_size,
|
| 948 |
+
# rank=self.args.process_index,
|
| 949 |
+
# lengths=lengths,
|
| 950 |
+
# model_input_name=model_input_name,
|
| 951 |
+
# seed=self.args.seed,
|
| 952 |
+
# )
|
| 953 |
+
|
| 954 |
+
# else:
|
| 955 |
+
# if self.args.world_size <= 1:
|
| 956 |
+
# if _is_torch_generator_available:
|
| 957 |
+
# return RandomSampler(self.train_dataset, generator=generator)
|
| 958 |
+
# return RandomSampler(self.train_dataset)
|
| 959 |
+
# elif (
|
| 960 |
+
# self.args.parallel_mode
|
| 961 |
+
# in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
|
| 962 |
+
# and not self.args.dataloader_drop_last
|
| 963 |
+
# ):
|
| 964 |
+
# # Use a loop for TPUs when drop_last is False to have all batches have the same size.
|
| 965 |
+
# return DistributedSamplerWithLoop(
|
| 966 |
+
# self.train_dataset,
|
| 967 |
+
# batch_size=self.args.per_device_train_batch_size,
|
| 968 |
+
# num_replicas=self.args.world_size,
|
| 969 |
+
# rank=self.args.process_index,
|
| 970 |
+
# seed=self.args.seed,
|
| 971 |
+
# )
|
| 972 |
+
# else:
|
| 973 |
+
# return DistributedSampler(
|
| 974 |
+
# self.train_dataset,
|
| 975 |
+
# num_replicas=self.args.world_size,
|
| 976 |
+
# rank=self.args.process_index,
|
| 977 |
+
# seed=self.args.seed,
|
| 978 |
+
# )
|
geneformer/token_dictionary.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ab9dc40973fa5224d77b793e2fd114cacf3d08423ed9c4c49caf0ba9c7f218f1
|
| 3 |
+
size 788424
|
geneformer/tokenizer.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Geneformer tokenizer.
|
| 3 |
+
|
| 4 |
+
**Input data:**
|
| 5 |
+
|
| 6 |
+
| *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file.
|
| 7 |
+
| *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
|
| 8 |
+
| *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
|
| 9 |
+
|
| 10 |
+
| *Optional col (cell) attribute:* "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria.
|
| 11 |
+
| *Optional col (cell) attributes:* any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below.
|
| 12 |
+
|
| 13 |
+
**Usage:**
|
| 14 |
+
|
| 15 |
+
.. code-block :: python
|
| 16 |
+
|
| 17 |
+
>>> from geneformer import TranscriptomeTokenizer
|
| 18 |
+
>>> tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ"}, nproc=4)
|
| 19 |
+
>>> tk.tokenize_data("data_directory", "output_directory", "output_prefix")
|
| 20 |
+
|
| 21 |
+
**Description:**
|
| 22 |
+
|
| 23 |
+
| Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
|
| 24 |
+
|
| 25 |
+
| The discussion below references the .loom file format, but the analogous labels are required for .h5ad files, except that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.
|
| 26 |
+
|
| 27 |
+
| Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
|
| 28 |
+
|
| 29 |
+
| No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes "cell_type" and "organ_major" and one would like to retain these attributes as labels in the tokenized dataset with the new names "cell_type" and "organ", respectively, the following custom attribute dictionary should be provided: {"cell_type": "cell_type", "organ_major": "organ"}.
|
| 30 |
+
|
| 31 |
+
| Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
|
| 32 |
+
|
| 33 |
+
| If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.
|
| 34 |
+
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from __future__ import annotations
|
| 38 |
+
|
| 39 |
+
import logging
|
| 40 |
+
import pickle
|
| 41 |
+
import warnings
|
| 42 |
+
from pathlib import Path
|
| 43 |
+
from typing import Literal
|
| 44 |
+
|
| 45 |
+
import anndata as ad
|
| 46 |
+
import loompy as lp
|
| 47 |
+
import numpy as np
|
| 48 |
+
import scipy.sparse as sp
|
| 49 |
+
from datasets import Dataset
|
| 50 |
+
|
| 51 |
+
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
|
| 52 |
+
logger = logging.getLogger(__name__)
|
| 53 |
+
|
| 54 |
+
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
|
| 55 |
+
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def rank_genes(gene_vector, gene_tokens):
    """
    Return gene tokens ordered by descending median-scaled expression value.
    """
    # Negating the values makes argsort's ascending order yield a
    # descending ranking of the original vector.
    descending_order = np.argsort(-gene_vector)
    return gene_tokens[descending_order]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def tokenize_cell(gene_vector, gene_tokens):
    """
    Convert a normalized gene expression vector to its tokenized
    rank value encoding.
    """
    # Keep only detected (nonzero) genes, then rank the survivors
    # by their median-scaled expression values.
    detected = np.nonzero(gene_vector)[0]
    return rank_genes(gene_vector[detected], gene_tokens[detected])
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class TranscriptomeTokenizer:
    """
    Tokenize raw-count scRNAseq data (.loom or .h5ad) into rank value encodings.

    Each cell's detected genes are normalized by the cell's total counts and by
    gene-specific nonzero medians, then ordered by descending normalized
    expression and mapped to integer tokens.
    """

    def __init__(
        self,
        custom_attr_name_dict=None,
        nproc=1,
        chunk_size=512,
        gene_median_file=None,
        token_dictionary_file=None,
    ):
        """
        Initialize tokenizer.

        **Parameters:**

        custom_attr_name_dict : None, dict
            | Dictionary of custom attributes to be added to the dataset.
            | Keys are the names of the attributes in the loom file.
            | Values are the names of the attributes in the dataset.
        nproc : int
            | Number of processes to use for dataset mapping.
        chunk_size : int
            | Chunk size for anndata tokenizer.
        gene_median_file : Path, None
            | Path to pickle file containing dictionary of non-zero median
            | gene expression values across Genecorpus-30M.
            | Defaults to the dictionary shipped next to this module.
        token_dictionary_file : Path, None
            | Path to pickle file containing token dictionary (Ensembl IDs:token).
            | Defaults to the dictionary shipped next to this module.
        """
        # Resolve the packaged default dictionaries lazily; equivalent to the
        # module-level GENE_MEDIAN_FILE / TOKEN_DICTIONARY_FILE constants.
        if gene_median_file is None:
            gene_median_file = Path(__file__).parent / "gene_median_dictionary.pkl"
        if token_dictionary_file is None:
            token_dictionary_file = Path(__file__).parent / "token_dictionary.pkl"

        # dictionary of custom attributes {output dataset column name: input .loom column name}
        self.custom_attr_name_dict = custom_attr_name_dict

        # number of processes for dataset mapping
        self.nproc = nproc

        # chunk size for anndata tokenizer
        self.chunk_size = chunk_size

        # load dictionary of gene normalization factors
        # (non-zero median value of expression across Genecorpus-30M)
        with open(gene_median_file, "rb") as f:
            self.gene_median_dict = pickle.load(f)

        # load token dictionary (Ensembl IDs:token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        # gene keys for full vocabulary
        self.gene_keys = list(self.gene_median_dict.keys())

        # protein-coding and miRNA gene list dictionary for selecting
        # .loom rows for tokenization
        self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))

    def tokenize_data(
        self,
        data_directory: Path | str,
        output_directory: Path | str,
        output_prefix: str,
        file_format: Literal["loom", "h5ad"] = "loom",
        use_generator: bool = False,
    ):
        """
        Tokenize files in data_directory and save as tokenized .dataset in output_directory.

        **Parameters:**

        data_directory : Path
            | Path to directory containing loom files or anndata files
        output_directory : Path
            | Path to directory where tokenized data will be saved as .dataset
        output_prefix : str
            | Prefix for output .dataset
        file_format : str
            | Format of input files. Can be "loom" or "h5ad".
        use_generator : bool
            | Whether to use generator or dict for tokenization.
        """
        tokenized_cells, cell_metadata = self.tokenize_files(
            Path(data_directory), file_format
        )
        tokenized_dataset = self.create_dataset(
            tokenized_cells, cell_metadata, use_generator=use_generator
        )

        # NOTE(review): with_suffix replaces anything after the last "." in
        # output_prefix — prefixes containing dots are silently truncated.
        output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
        tokenized_dataset.save_to_disk(output_path)

    def tokenize_files(
        self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
    ):
        """
        Tokenize every matching file in data_directory.

        Returns a tuple (tokenized_cells, cell_metadata); cell_metadata is
        None when no custom attributes were requested.

        Raises FileNotFoundError if no .loom/.h5ad files are found.
        """
        tokenized_cells = []
        cell_metadata = None
        if self.custom_attr_name_dict is not None:
            cell_attr = list(self.custom_attr_name_dict.keys())
            cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.values()
            }

        # loops through directory to tokenize .loom or .h5ad files
        file_found = False
        tokenize_file_fn = (
            self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
        )
        for file_path in data_directory.glob(f"*.{file_format}"):
            file_found = True
            print(f"Tokenizing {file_path}")
            file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
            tokenized_cells += file_tokenized_cells
            if self.custom_attr_name_dict is not None:
                for k in cell_attr:
                    cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[
                        k
                    ]

        if not file_found:
            msg = f"No .{file_format} files found in directory {data_directory}."
            logging.getLogger(__name__).error(msg)
            # Previously a bare `raise` with no active exception, which only
            # produced RuntimeError("No active exception to re-raise").
            raise FileNotFoundError(msg)
        return tokenized_cells, cell_metadata

    def tokenize_anndata(self, adata_file_path, target_sum=10_000):
        """
        Tokenize a single .h5ad file in chunks of self.chunk_size cells.

        Returns (tokenized_cells, file_cell_metadata); the metadata is None
        when no custom attributes were requested.
        """
        adata = ad.read(adata_file_path, backed="r")

        # Initialize up front so the return below is always defined, even if
        # the chunk loop runs zero times (e.g. no cells pass filtering).
        file_cell_metadata = None
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }

        # locate genes present in the model vocabulary and their norm factors
        coding_miRNA_loc = np.where(
            [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
        )[0]
        norm_factor_vector = np.array(
            [
                self.gene_median_dict[i]
                for i in adata.var["ensembl_id"][coding_miRNA_loc]
            ]
        )
        coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]
        coding_miRNA_tokens = np.array(
            [self.gene_token_dict[i] for i in coding_miRNA_ids]
        )

        # determine which cells pass upstream filtering (optional column)
        if "filter_pass" in adata.obs:
            filter_pass_loc = np.where([i == 1 for i in adata.obs["filter_pass"]])[0]
        else:
            print(
                f"{adata_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
            )
            filter_pass_loc = np.arange(adata.shape[0])

        tokenized_cells = []

        for start in range(0, len(filter_pass_loc), self.chunk_size):
            idx = filter_pass_loc[start : start + self.chunk_size]

            n_counts = adata[idx].obs["n_counts"].values[:, None]
            X_view = adata[idx, coding_miRNA_loc].X
            # total-count normalize, scale to target_sum to allocate bits to
            # precision, and divide by per-gene median normalization factors
            X_norm = X_view / n_counts * target_sum / norm_factor_vector
            X_norm = sp.csr_matrix(X_norm)

            tokenized_cells += [
                rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
                for i in range(X_norm.shape[0])
            ]

            # add custom attributes for this chunk to dict
            if self.custom_attr_name_dict is not None:
                for k in file_cell_metadata.keys():
                    file_cell_metadata[k] += adata[idx].obs[k].tolist()

        return tokenized_cells, file_cell_metadata

    def tokenize_loom(self, loom_file_path, target_sum=10_000):
        """
        Tokenize a single .loom file in batches of self.chunk_size cells.

        Returns (tokenized_cells, file_cell_metadata); the metadata is None
        when no custom attributes were requested.
        """
        # Initialize up front so the return below is always defined, even if
        # the scan yields zero batches (e.g. no cells pass filtering).
        file_cell_metadata = None
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }

        with lp.connect(str(loom_file_path)) as data:
            # define coordinates of detected protein-coding or miRNA genes
            # and vector of their normalization factors
            coding_miRNA_loc = np.where(
                [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
            )[0]
            norm_factor_vector = np.array(
                [
                    self.gene_median_dict[i]
                    for i in data.ra["ensembl_id"][coding_miRNA_loc]
                ]
            )
            coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
            coding_miRNA_tokens = np.array(
                [self.gene_token_dict[i] for i in coding_miRNA_ids]
            )

            # define coordinates of cells passing filters for inclusion (e.g. QC)
            try:
                data.ca["filter_pass"]
            except AttributeError:
                var_exists = False
            else:
                var_exists = True

            if var_exists:
                filter_pass_loc = np.where([i == 1 for i in data.ca["filter_pass"]])[0]
            else:
                print(
                    f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
                )
                filter_pass_loc = np.arange(data.shape[1])

            # scan through .loom file and tokenize cells batch by batch
            tokenized_cells = []
            for _ix, _selection, view in data.scan(
                items=filter_pass_loc, axis=1, batch_size=self.chunk_size
            ):
                # select subview with protein-coding and miRNA genes
                subview = view.view[coding_miRNA_loc, :]

                # normalize by total counts per cell, multiply by target_sum
                # to allocate bits to precision, and divide by per-gene
                # median normalization factors
                subview_norm_array = (
                    subview[:, :]
                    / subview.ca.n_counts
                    * target_sum
                    / norm_factor_vector[:, None]
                )
                # tokenize subview gene vectors
                tokenized_cells += [
                    tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
                    for i in range(subview_norm_array.shape[1])
                ]

                # add custom attributes for subview to dict
                if self.custom_attr_name_dict is not None:
                    for k in file_cell_metadata.keys():
                        file_cell_metadata[k] += subview.ca[k].tolist()

        return tokenized_cells, file_cell_metadata

    def create_dataset(
        self,
        tokenized_cells,
        cell_metadata,
        use_generator=False,
        keep_uncropped_input_ids=False,
    ):
        """
        Build a Hugging Face Dataset from tokenized cells plus optional
        metadata, cropping each cell's input_ids to 2,048 tokens.
        """
        print("Creating dataset.")
        # create dict for dataset creation
        dataset_dict = {"input_ids": tokenized_cells}
        if self.custom_attr_name_dict is not None:
            dataset_dict.update(cell_metadata)

        # create dataset, either streamed from a generator or from one dict
        if use_generator:

            def dict_generator():
                for i in range(len(tokenized_cells)):
                    yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}

            output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
        else:
            output_dataset = Dataset.from_dict(dataset_dict)

        def format_cell_features(example):
            # Store original uncropped input_ids in separate feature
            if keep_uncropped_input_ids:
                example["input_ids_uncropped"] = example["input_ids"]
                example["length_uncropped"] = len(example["input_ids"])

            # Truncate/Crop input_ids to size 2,048
            example["input_ids"] = example["input_ids"][0:2048]
            example["length"] = len(example["input_ids"])

            return example

        output_dataset_truncated = output_dataset.map(
            format_cell_features, num_proc=self.nproc
        )
        return output_dataset_truncated
|